From e64f76e6cf8874838b6a9768926bb13ce3ebab0c Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 10 Aug 2021 20:52:48 -0400 Subject: [PATCH] add fast paths for arrays of correct numbers, same ndims --- src/projection.jl | 59 +++++++++++++++++++++++++++++++++++----------- test/projection.jl | 10 +++++++- 2 files changed, 54 insertions(+), 15 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index b4532390c..7a251e6bf 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -40,6 +40,17 @@ Base.propertynames(p::ProjectTo) = propertynames(backing(p)) backing(project::ProjectTo) = getfield(project, :info) project_type(p::ProjectTo{T}) where {T} = T +project_eltype(p::ProjectTo{T}) where {T} = eltype(T) + +function project_promote_type(projectors) + T = mapreduce(project_type, promote_type, projectors) + if T <: Number + # The point of this function is to make p.element for arrays. Not in use yet! + return ProjectTo(zero(T)) + else + return ProjectTo{Any}() + end +end function Base.show(io::IO, project::ProjectTo{T}) where {T} print(io, "ProjectTo{") @@ -181,9 +192,19 @@ end # If we don't have a more specialized `ProjectTo` rule, we just assume that there is # no structure worth re-imposing. Then any array is acceptable as a gradient. -# For arrays of numbers, just store one projector: -function ProjectTo(x::AbstractArray{T}) where {T<:Number} - return ProjectTo{AbstractArray}(; element=_eltype_projectto(T), axes=axes(x)) +# For arrays of numbers, just store one projector, and construct it without branches: +ProjectTo(x::AbstractArray{<:Number}) = _array_projectto(x, axes(x)) +function _array_projectto(x::AbstractArray{T,N}, axes::NTuple{N,<:Base.OneTo{Int}}) where {T,N} + element = _eltype_projectto(T) + S = project_type(element) + # Fastest path: N means they are OneTo, hence reshape can be skipped + return ProjectTo{AbstractArray{S,N}}(; element=element, axes=axes) +end +function _array_projectto(x::AbstractArray{T,N}, axes::Tuple) where {T,N} + element = _eltype_projectto(T) + S = project_type(element) + # Omitting N means reshape will be called, for OffsetArrays, SArrays, etc. + return ProjectTo{AbstractArray{S}}(; element=element, axes=axes) end ProjectTo(x::AbstractArray{Bool}) = ProjectTo{NoTangent}() @@ -201,7 +222,7 @@ function ProjectTo(xs::AbstractArray) end end -function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M} +function (project::ProjectTo{<:AbstractArray})(dx::AbstractArray{S,M}) where {S,M} # First deal with shape. The rule is that we reshape to add or remove trivial dimensions # like dx = ones(4,1), where x = ones(4), but throw an error on dx = ones(1,4) etc. dy = if axes(dx) == project.axes @@ -225,24 +246,34 @@ function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M} return dz end +# Fast paths, for arrays of numbers: +(::ProjectTo{AbstractArray{T,N}})(dx::AbstractArray{S,N}) where {S<:T} where {T,N} = dx +(project::ProjectTo{AbstractArray{T,N}})(dx::AbstractArray{S}) where {S<:T} where {T,N} = reshape(dx, project.axes) +(project::ProjectTo{AbstractArray{T,N}})(dx::AbstractArray{S,N}) where {S,T,N} = map(project.element, dx) +(project::ProjectTo{AbstractArray{T,N}})(dx::AbstractArray) where {T,N} = map(project.element, reshape(dx, project.axes)) + # Trivial case, this won't collapse Any[NoTangent(), NoTangent()] but that's OK. -(project::ProjectTo{AbstractArray})(dx::AbstractArray{<:AbstractZero}) = NoTangent() +(project::ProjectTo{AbstractArray{T,N}})(dx::AbstractArray{<:AbstractZero}) where {T,N} = NoTangent() # Row vectors aren't acceptable as gradients for 1-row matrices: -function (project::ProjectTo{AbstractArray})(dx::LinearAlgebra.AdjOrTransAbsVec) +# function (project::ProjectTo{<:AbstractArray})(dx::LinearAlgebra.AdjOrTransAbsVec) +# return project(reshape(vec(dx), 1, :)) +# end +function (project::ProjectTo{AbstractArray{T,N}})(dx::LinearAlgebra.AdjOrTransAbsVec) where {T,N} return project(reshape(vec(dx), 1, :)) end # Zero-dimensional arrays -- these have a habit of going missing, # although really Ref() is probably a better structure. -function (project::ProjectTo{AbstractArray})(dx::Number) # ... so we restore from numbers - if !(project.axes isa Tuple{}) - throw(DimensionMismatch( - "array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number", - )) - end - return fill(project.element(dx)) -end +# function (project::ProjectTo{<:AbstractArray})(dx::Number) # ... so we restore from numbers +# if !(project.axes isa Tuple{}) +# throw(DimensionMismatch( +# "array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number", +# )) +# end +# return fill(project.element(dx)) +# end +(project::ProjectTo{AbstractArray{<:Number,0}})(dx::Number) = fill(project.element(dx)) function _projection_mismatch(axes_x::Tuple, size_dx::Tuple) size_x = map(length, axes_x) diff --git a/test/projection.jl b/test/projection.jl index 53e3e0bcb..d5049e74a 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -24,7 +24,6 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial)) @test ProjectTo(2.0+3.0im)(1+1im) === 1.0+1.0im @test ProjectTo(2.0)(1+1im) === 1.0 - # storage @test ProjectTo(1)(pi) === pi @test ProjectTo(1 + im)(pi) === ComplexF64(pi) @@ -285,6 +284,15 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial)) ##### `OffsetArrays` ##### +# function ProjectTo(x::OffsetArray{T,N}) where {T<:Number,N} +# # As usual: +# element = ChainRulesCore._eltype_projectto(T) +# S = ChainRulesCore.project_type(element) +# # But don't save N? Avoids fast path? +# # Or perhaps the default constructor can check whether axes(x) is NTuple{N,OneTo}? +# return ProjectTo{AbstractArray{S}}(; element=element, axes=axes(x)) +# end + @testset "OffsetArrays" begin # While there is no code for this, the rule that it checks axes(x) == axes(dx) else # reshape means that it restores offsets. (It throws an error on nontrivial size mismatch.)