Skip to content

Commit 0e560c6

Browse files
mcabbottoxinabox
andauthored
Add ProjectTo(::Any) = identity (#458)
* add ProjectTo(::Any) = identity * Apply 3 suggestions Co-authored-by: Lyndon White <[email protected]> Co-authored-by: Lyndon White <[email protected]>
1 parent 1893e82 commit 0e560c6

File tree

2 files changed

+27
-7
lines changed

2 files changed

+27
-7
lines changed

src/projection.jl

+10-3
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,10 @@ _maybe_call(f, x) = f
8383
Returns a `ProjectTo{T}` functor which projects a differential `dx` onto the
8484
relevant tangent space for `x`.
8585
86-
At present this undersands only `x::Number`, `x::AbstractArray` and `x::Ref`.
87-
It should not be called on arguments of an `rrule` method which accepts other types.
86+
Custom `ProjectTo` methods are provided for many subtypes of `Number` (to e.g. ensure precision),
87+
and `AbstractArray` (to e.g. ensure sparsity structure is maintained by tangent).
88+
Called on unknown types it will (as of v1.5.0) simply return `identity`, thus can be safely
89+
applied to arbitrary `rrule` arguments.
8890
8991
# Examples
9092
```jldoctest
@@ -112,7 +114,7 @@ julia> ProjectTo([1 2; 3 4]') # no special structure, integers are promoted to
112114
ProjectTo{AbstractArray}(element = ProjectTo{Float64}(), axes = (Base.OneTo(2), Base.OneTo(2)))
113115
```
114116
"""
115-
ProjectTo(::Any) # just to attach docstring
117+
ProjectTo(::Any) = identity
116118

117119
# Generic
118120
(::ProjectTo{T})(dx::AbstractZero) where {T} = dx
@@ -143,6 +145,11 @@ ProjectTo{P}(::NamedTuple{T, <:Tuple{_PZ, Vararg{<:_PZ}}}) where {P,T} = Project
143145
# Bool
144146
ProjectTo(::Bool) = ProjectTo{NoTangent}() # same projector as ProjectTo(::AbstractZero) above
145147

148+
# Other never-differentiable types
149+
for T in (:Symbol, :Char, :AbstractString, :RoundingMode, :IndexStyle)
150+
@eval ProjectTo(::$T) = ProjectTo{NoTangent}()
151+
end
152+
146153
# Numbers
147154
ProjectTo(::Real) = ProjectTo{Real}()
148155
ProjectTo(::Complex) = ProjectTo{Complex}()

test/projection.jl

+17-4
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ Base.real(x::Dual) = x
1111
Base.float(x::Dual) = Dual(float(x.value), float(x.partial))
1212
Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
1313

14+
# Trivial struct
15+
struct NoSuperType end
16+
1417
@testset "projection" begin
1518

1619
#####
@@ -24,7 +27,6 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
2427
@test ProjectTo(2.0+3.0im)(1+1im) === 1.0+1.0im
2528
@test ProjectTo(2.0)(1+1im) === 1.0
2629

27-
2830
# storage
2931
@test ProjectTo(1)(pi) === pi
3032
@test ProjectTo(1 + im)(pi) === ComplexF64(pi)
@@ -94,9 +96,10 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
9496
@test y1[1] == [1 2]
9597
@test !(y1 isa Adjoint) && !(y1[1] isa Adjoint)
9698

97-
# arrays of unknown things
98-
@test_throws MethodError ProjectTo([:x, :y])
99-
@test_throws MethodError ProjectTo(Any[:x, :y])
99+
# arrays of other things
100+
@test ProjectTo([:x, :y]) isa ProjectTo{NoTangent}
101+
@test ProjectTo(Any['x', "y"]) isa ProjectTo{NoTangent}
102+
@test ProjectTo([(1,2), (3,4), (5,6)]) isa ProjectTo{AbstractArray}
100103

101104
@test ProjectTo(Any[1, 2])(1:2) == [1.0, 2.0] # projects each number.
102105
@test Tuple(ProjectTo(Any[1, 2 + 3im])(1:2)) === (1.0, 2.0 + 0.0im)
@@ -140,6 +143,12 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
140143
@test ProjectTo(Ref([false]')) isa ProjectTo{NoTangent}
141144
end
142145

146+
@testset "Base: non-diff" begin
147+
@test ProjectTo(:a)(1) == NoTangent()
148+
@test ProjectTo('b')(2) == NoTangent()
149+
@test ProjectTo("cde")(345) == NoTangent()
150+
end
151+
143152
#####
144153
##### `LinearAlgebra`
145154
#####
@@ -301,6 +310,10 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
301310
##### `ChainRulesCore`
302311
#####
303312

313+
@testset "pass-through" begin
314+
@test ProjectTo(NoSuperType()) === identity
315+
end
316+
304317
@testset "AbstractZero" begin
305318
pz = ProjectTo(ZeroTangent())
306319
pz(0) == NoTangent()

0 commit comments

Comments
 (0)