Skip to content

Commit 0922e4a

Browse files
committed
add fast paths for arrays of correct numbers, same ndims
1 parent 598e891 commit 0922e4a

File tree

2 files changed

+50
-14
lines changed

2 files changed

+50
-14
lines changed

src/projection.jl

+41-13
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,17 @@ Base.propertynames(p::ProjectTo) = propertynames(backing(p))
4040
backing(project::ProjectTo) = getfield(project, :info)
4141

4242
project_type(p::ProjectTo{T}) where {T} = T
43+
project_eltype(p::ProjectTo{T}) where {T} = eltype(T)
44+
45+
function project_promote_type(projectors)
46+
T = mapreduce(project_type, promote_type, projectors)
47+
if T <: Number
48+
# The point of this function is to make p.element for arrays. Not in use yet!
49+
return ProjectTo(zero(T))
50+
else
51+
return ProjectTo{Any}()
52+
end
53+
end
4354

4455
function Base.show(io::IO, project::ProjectTo{T}) where {T}
4556
print(io, "ProjectTo{")
@@ -178,8 +189,16 @@ end
178189
# no structure worth re-imposing. Then any array is acceptable as a gradient.
179190

180191
# For arrays of numbers, just store one projector:
181-
function ProjectTo(x::AbstractArray{T}) where {T<:Number}
182-
return ProjectTo{AbstractArray}(; element=_eltype_projectto(T), axes=axes(x))
192+
function ProjectTo(x::AbstractArray{T,N}) where {T<:Number,N}
193+
element = _eltype_projectto(T)
194+
S = project_type(element) # new idea -- for any number, S is enough.
195+
# Store .element for now too, although it's redundant? Reconstruct from eltype?
196+
if axes(x) isa NTuple{N,Base.OneTo{Int}}
197+
return ProjectTo{AbstractArray{S,N}}(; element=element, axes=axes(x))
198+
else
199+
# Omitting N prohibits the fast path, and thus won't skip reshape for OffsetArrays, SArrays, etc.
200+
return ProjectTo{AbstractArray{S}}(; element=element, axes=axes(x))
201+
end
183202
end
184203
ProjectTo(x::AbstractArray{Bool}) = ProjectTo{NoTangent}()
185204

@@ -197,7 +216,7 @@ function ProjectTo(xs::AbstractArray)
197216
end
198217
end
199218

200-
function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M}
219+
function (project::ProjectTo{<:AbstractArray})(dx::AbstractArray{S,M}) where {S,M}
201220
# First deal with shape. The rule is that we reshape to add or remove trivial dimensions
202221
# like dx = ones(4,1), where x = ones(4), but throw an error on dx = ones(1,4) etc.
203222
dy = if axes(dx) == project.axes
@@ -221,24 +240,33 @@ function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M}
221240
return dz
222241
end
223242

243+
# Fast path, for arrays of numbers:
244+
# (::ProjectTo{AbstractArray{T,N}})(dx::AbstractArray{T,N}) where {T,N} = (@info "fast 1"; dx)
245+
(::ProjectTo{AbstractArray{T,N}})(dx::AbstractArray{S,N}) where {S<:T} where {T,N} = dx # (@info "fast 2"; dx)
246+
(project::ProjectTo{AbstractArray{T,N}})(dx::AbstractArray{S,N}) where {S,T,N} = (@info "fast 3"; map(project.element, dx))
247+
224248
# Trivial case, this won't collapse Any[NoTangent(), NoTangent()] but that's OK.
225-
(project::ProjectTo{AbstractArray})(dx::AbstractArray{<:AbstractZero}) = NoTangent()
249+
(project::ProjectTo{AbstractArray{T,N}})(dx::AbstractArray{<:AbstractZero}) where {T,N} = NoTangent()
226250

227251
# Row vectors aren't acceptable as gradients for 1-row matrices:
228-
function (project::ProjectTo{AbstractArray})(dx::LinearAlgebra.AdjOrTransAbsVec)
252+
# function (project::ProjectTo{<:AbstractArray})(dx::LinearAlgebra.AdjOrTransAbsVec)
253+
# return project(reshape(vec(dx), 1, :))
254+
# end
255+
function (project::ProjectTo{AbstractArray{T,N}})(dx::LinearAlgebra.AdjOrTransAbsVec) where {T,N}
229256
return project(reshape(vec(dx), 1, :))
230257
end
231258

232259
# Zero-dimensional arrays -- these have a habit of going missing,
233260
# although really Ref() is probably a better structure.
234-
function (project::ProjectTo{AbstractArray})(dx::Number) # ... so we restore from numbers
235-
if !(project.axes isa Tuple{})
236-
throw(DimensionMismatch(
237-
"array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number",
238-
))
239-
end
240-
return fill(project.element(dx))
241-
end
261+
# function (project::ProjectTo{<:AbstractArray})(dx::Number) # ... so we restore from numbers
262+
# if !(project.axes isa Tuple{})
263+
# throw(DimensionMismatch(
264+
# "array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number",
265+
# ))
266+
# end
267+
# return fill(project.element(dx))
268+
# end
269+
(project::ProjectTo{AbstractArray{<:Number,0}})(dx::Number) = fill(project.element(dx))
242270

243271
# Ref -- works like a zero-array, also allows restoration from a number:
244272
ProjectTo(x::Ref) = ProjectTo{Ref}(; x=ProjectTo(x[]))

test/projection.jl

+9-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
2424
@test ProjectTo(2.0+3.0im)(1+1im) === 1.0+1.0im
2525
@test ProjectTo(2.0)(1+1im) === 1.0
2626

27-
2827
# storage
2928
@test ProjectTo(1)(pi) === pi
3029
@test ProjectTo(1 + im)(pi) === ComplexF64(pi)
@@ -265,6 +264,15 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
265264
##### `OffsetArrays`
266265
#####
267266

267+
# function ProjectTo(x::OffsetArray{T,N}) where {T<:Number,N}
268+
# # As usual:
269+
# element = ChainRulesCore._eltype_projectto(T)
270+
# S = ChainRulesCore.project_type(element)
271+
# # But don't save N? Avoids fast path?
272+
# # Or perhaps the default constructor can check whether axes(x) is NTuple{N,OneTo}?
273+
# return ProjectTo{AbstractArray{S}}(; element=element, axes=axes(x))
274+
# end
275+
268276
@testset "OffsetArrays" begin
269277
# While there is no code for this, the rule that it checks axes(x) == axes(dx) else
270278
# reshape means that it restores offsets. (It throws an error on nontrivial size mismatch.)

0 commit comments

Comments
 (0)