Skip to content

Fix #429 -- add branch-free paths for simple cases like CuArray{Float32} #430

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 45 additions & 14 deletions src/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@ Base.propertynames(p::ProjectTo) = propertynames(backing(p))
backing(project::ProjectTo) = getfield(project, :info)

# Return the type parameter `T` of a `ProjectTo{T}`.
project_type(::ProjectTo{T}) where {T} = T
# Return `eltype(T)`, where `T` is the type parameter of a `ProjectTo{T}`.
project_eltype(::ProjectTo{T}) where {T} = eltype(T)

# Build a single projector from a collection of projectors, by promoting their
# `project_type`s. The point of this function is to make `p.element` for arrays.
# Not in use yet!
function project_promote_type(projectors)
    promoted = mapreduce(project_type, promote_type, projectors)
    # Anything that does not promote to a `Number` type gets `ProjectTo{Any}()`.
    promoted <: Number || return ProjectTo{Any}()
    return ProjectTo(zero(promoted))
end

function Base.show(io::IO, project::ProjectTo{T}) where {T}
print(io, "ProjectTo{")
Expand Down Expand Up @@ -181,9 +192,19 @@ end
# If we don't have a more specialized `ProjectTo` rule, we just assume that there is
# no structure worth re-imposing. Then any array is acceptable as a gradient.

# For arrays of numbers, just store one projector:
function ProjectTo(x::AbstractArray{T}) where {T<:Number}
return ProjectTo{AbstractArray}(; element=_eltype_projectto(T), axes=axes(x))
# For arrays of numbers, just store one projector, and construct it without branches:
ProjectTo(x::AbstractArray{<:Number}) = _array_projectto(x, axes(x))
function _array_projectto(x::AbstractArray{T,N}, axes::NTuple{N,<:Base.OneTo{Int}}) where {T,N}
element = _eltype_projectto(T)
S = project_type(element)
# Fastest path: N means they are OneTo, hence reshape can be skipped
return ProjectTo{AbstractArray{S,N}}(; element=element, axes=axes)
Comment on lines +200 to +201
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is slightly weird, but my idea for avoiding the `if axes(dx) == projector.axes` branch is to include `N` in the type parameter only when the axes are all `Base.OneTo`. If `ndims(dx)` matches, and its eltype matches, then `dx` can pass through just by dispatch.

This means it won't check for quite as many size mismatches. But it will still reshape for OffsetArrays, SArrays, etc.; those will go down the "slow path" as before.

end
# Method for any other tuple of axes (OffsetArrays, SArrays, etc.).
# `N` is deliberately left out of the `ProjectTo` type parameter here, which means
# reshape will be called when the projector is applied.
function _array_projectto(x::AbstractArray{T,N}, axes::Tuple) where {T,N}
    elem_projector = _eltype_projectto(T)
    return ProjectTo{AbstractArray{project_type(elem_projector)}}(;
        element=elem_projector, axes=axes
    )
end
# Arrays of `Bool` get a `ProjectTo{NoTangent}` projector, which stores nothing about `x`.
ProjectTo(x::AbstractArray{Bool}) = ProjectTo{NoTangent}()

Expand All @@ -201,7 +222,7 @@ function ProjectTo(xs::AbstractArray)
end
end

function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M}
function (project::ProjectTo{<:AbstractArray})(dx::AbstractArray{S,M}) where {S,M}
# First deal with shape. The rule is that we reshape to add or remove trivial dimensions
# like dx = ones(4,1), where x = ones(4), but throw an error on dx = ones(1,4) etc.
dy = if axes(dx) == project.axes
Expand All @@ -225,24 +246,34 @@ function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M}
return dz
end

# Fast paths, for arrays of numbers. Which one runs is decided purely by dispatch on
# `ndims(dx)` and `eltype(dx)` relative to the projector's `AbstractArray{T,N}` parameter.

# Same number of dimensions and eltype `S <: T`: `dx` is returned as it is.
function (::ProjectTo{AbstractArray{T,N}})(dx::AbstractArray{S,N}) where {S<:T} where {T,N}
    return dx
end

# Eltype `S <: T`, number of dimensions not matched above: only reshape to the stored axes.
function (project::ProjectTo{AbstractArray{T,N}})(dx::AbstractArray{S}) where {S<:T} where {T,N}
    return reshape(dx, project.axes)
end

# Same number of dimensions, eltype not matched above: only project each element.
function (project::ProjectTo{AbstractArray{T,N}})(dx::AbstractArray{S,N}) where {S,T,N}
    return map(project.element, dx)
end

# Anything else: reshape to the stored axes, then project each element.
function (project::ProjectTo{AbstractArray{T,N}})(dx::AbstractArray) where {T,N}
    reshaped = reshape(dx, project.axes)
    return map(project.element, reshaped)
end

# Trivial case: an array whose eltype is a subtype of `AbstractZero` projects to
# `NoTangent()`. This won't collapse Any[NoTangent(), NoTangent()] (its eltype is `Any`),
# but that's OK.
(project::ProjectTo{AbstractArray})(dx::AbstractArray{<:AbstractZero}) = NoTangent()
(project::ProjectTo{AbstractArray{T,N}})(dx::AbstractArray{<:AbstractZero}) where {T,N} = NoTangent()

# Row vectors aren't acceptable as gradients for 1-row matrices:
function (project::ProjectTo{AbstractArray})(dx::LinearAlgebra.AdjOrTransAbsVec)
# function (project::ProjectTo{<:AbstractArray})(dx::LinearAlgebra.AdjOrTransAbsVec)
# return project(reshape(vec(dx), 1, :))
# end
# Turn the row vector into an ordinary 1-row array (via `vec` then `reshape`) and
# project that instead.
function (project::ProjectTo{AbstractArray{T,N}})(dx::LinearAlgebra.AdjOrTransAbsVec) where {T,N}
    row_matrix = reshape(vec(dx), 1, :)
    return project(row_matrix)
end

# Zero-dimensional arrays -- these have a habit of going missing,
# although really Ref() is probably a better structure.
# Restore a zero-dimensional array from a plain number.
# Throws `DimensionMismatch` unless the stored axes are the empty tuple.
function (project::ProjectTo{AbstractArray})(dx::Number)
    project.axes isa Tuple{} || throw(
        DimensionMismatch(
            "array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number"
        ),
    )
    # Project the number itself, then wrap it in a 0-dimensional array.
    return fill(project.element(dx))
end
# function (project::ProjectTo{<:AbstractArray})(dx::Number) # ... so we restore from numbers
# if !(project.axes isa Tuple{})
# throw(DimensionMismatch(
# "array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number",
# ))
# end
# return fill(project.element(dx))
# end
# Restore a zero-dimensional array of numbers from a plain number: the number is projected
# by `project.element` and wrapped with `fill`.
#
# The eltype is bound by a `where` clause because type parameters are invariant:
# `ProjectTo{AbstractArray{<:Number,0}}` does not match a projector of type
# `ProjectTo{AbstractArray{Float64,0}}`, so a method written that way is never selected.
function (project::ProjectTo{AbstractArray{T,0}})(dx::Number) where {T<:Number}
    return fill(project.element(dx))
end

function _projection_mismatch(axes_x::Tuple, size_dx::Tuple)
size_x = map(length, axes_x)
Expand Down
10 changes: 9 additions & 1 deletion test/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
@test ProjectTo(2.0+3.0im)(1+1im) === 1.0+1.0im
@test ProjectTo(2.0)(1+1im) === 1.0


# storage
@test ProjectTo(1)(pi) === pi
@test ProjectTo(1 + im)(pi) === ComplexF64(pi)
Expand Down Expand Up @@ -285,6 +284,15 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
##### `OffsetArrays`
#####

# function ProjectTo(x::OffsetArray{T,N}) where {T<:Number,N}
# # As usual:
# element = ChainRulesCore._eltype_projectto(T)
# S = ChainRulesCore.project_type(element)
# # But don't save N? Avoids fast path?
# # Or perhaps the default constructor can check whether axes(x) is NTuple{N,OneTo}?
# return ProjectTo{AbstractArray{S}}(; element=element, axes=axes(x))
# end

@testset "OffsetArrays" begin
# While there is no code for this, the rule that it checks axes(x) == axes(dx) else
# reshape means that it restores offsets. (It throws an error on nontrivial size mismatch.)
Expand Down