From 7c0641c90f7dc9efc01983d696e44062b192e1df Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 19 Sep 2023 16:17:28 +0800 Subject: [PATCH 01/15] rename testset --- test/AbstractDifferentiationTests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/AbstractDifferentiationTests.jl b/test/AbstractDifferentiationTests.jl index 66bcc1fd..e02186b2 100644 --- a/test/AbstractDifferentiationTests.jl +++ b/test/AbstractDifferentiationTests.jl @@ -29,7 +29,7 @@ end # standard tests from AbstractDifferentiation.test_utils include(joinpath(pathof(AbstractDifferentiation), "..", "..", "test", "test_utils.jl")) -@testset "ForwardDiffBackend" begin +@testset "Standard AbstractDifferentiation.test_utils tests" begin backends = [ @inferred(Diffractor.DiffractorForwardBackend()) ] From be091aa37e267f8abcdc9a1a43d3bd0be2a2a9bc Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 19 Sep 2023 21:36:38 +0800 Subject: [PATCH 02/15] Initial pass at switching CompositeBundle over to TaylorTangentBundle --- src/stage1/forward.jl | 13 +++++++++++ src/stage1/recurse_fwd.jl | 40 +++++++++++++++++++++++++++++++--- test/forward.jl | 46 +++++++++++++++++++++++++++++++++++++-- 3 files changed, 94 insertions(+), 5 deletions(-) diff --git a/src/stage1/forward.jl b/src/stage1/forward.jl index 5393f68d..2b6f36e7 100644 --- a/src/stage1/forward.jl +++ b/src/stage1/forward.jl @@ -51,6 +51,18 @@ function shuffle_down(b::CompositeBundle{N, B}) where {N, B} z end +function shuffle_up(r::TaylorBundle{1, Tuple{B1,B2}}) where {B1,B2} + z₀ = primal(r)[1] + z₁ = partial(r, 1)[1] + z₂ = primal(r)[2] + z₁₂ = partial(r, 1)[2] + if z₁ == z₂ + return TaylorBundle{2}(z₀, (z₁, z₁₂)) + else + return ExplicitTangentBundle{2}(z₀, (z₁, z₂, z₁₂)) + end +end + function shuffle_up(r::CompositeBundle{1}) z₀ = primal(r.tup[1]) z₁ = partial(r.tup[1], 1) @@ -76,6 +88,7 @@ isswifty(::UniformBundle) = true isswifty(b::CompositeBundle) = all(isswifty, b.tup) isswifty(::Any) = false +#TODO: port this to TaylorTangent: function shuffle_up(r::CompositeBundle{N}) where {N} a, b = r.tup if isswifty(a) && isswifty(b) && taylor_compatible(a, b) diff --git a/src/stage1/recurse_fwd.jl b/src/stage1/recurse_fwd.jl index 0c5ab07c..a1c23bc4 100644 --- a/src/stage1/recurse_fwd.jl +++ b/src/stage1/recurse_fwd.jl @@ -4,13 +4,47 @@ struct ∂vararg{N}; end (::∂vararg{N})() where {N} = ZeroBundle{N}(()) function (::∂vararg{N})(a::AbstractTangentBundle{N}...) where N - CompositeBundle{N, Tuple{map(x->basespace(typeof(x)), a)...}}(a) + B = Tuple{map(x->basespace(Core.Typeof(x)), a)...} + return (∂☆new{N}())(B, a...) end struct ∂☆new{N}; end -(::∂☆new{N})(B::Type, a::AbstractTangentBundle{N}...) where {N} = - CompositeBundle{N, B}(a) +# we split out the 1st order derivative as a special case for performance +# but the nth order case does also work for this +function (::∂☆new{1})(B::Type, xs::AbstractTangentBundle{1}...) + primal_args = map(primal, xs) + the_primal = B <: Tuple ? B(primal_args) : B(primal_args...) + + tangent_tup = map(x->partial(x, 1), xs) + the_partial = if B<:Tuple + Tangent{B, typeof(tangent_tup)}(tangent_tup) + else + names = fieldnames(B) + tangent_nt = NamedTuple{names}(tangent_tup) + Tangent{B, typeof(tangent_nt)}(tangent_nt) + end + return TaylorBundle{1, B}(the_primal, (the_partial,)) +end + +function (::∂☆new{N})(B::Type, xs::AbstractTangentBundle{N}...) where {N} + primal_args = map(primal, xs) + the_primal = B <: Tuple ? B(primal_args) : B(primal_args...) + + the_partials = ntuple(Val{N}()) do ii + iith_order_type = ii==1 ? B : Any # the type of the higher order tangents isn't worth tracking + tangent_tup = map(x->partial(x, ii), xs) + tangent = if B<:Tuple + Tangent{iith_order_type, typeof(tangent_tup)}(tangent_tup) + else + names = fieldnames(B) + tangent_nt = NamedTuple{names}(tangent_tup) + Tangent{iith_order_type, typeof(tangent_nt)}(tangent_nt) + end + return tangent + end + return TaylorBundle{N, B}(the_primal, the_partials) +end @generated (::∂☆new{N})(B::Type) where {N} = return :(ZeroBundle{$N}($(Expr(:new, :B)))) diff --git a/test/forward.jl b/test/forward.jl index 9ac5bea1..8ea24cde 100644 --- a/test/forward.jl +++ b/test/forward.jl @@ -25,8 +25,7 @@ let var"'" = Diffractor.PrimeDerivativeFwd # Integration tests @test recursive_sin'(1.0) == cos(1.0) @test recursive_sin''(1.0) == -sin(1.0) - # Error: ArgumentError: Tangent for the primal Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}} - # should be backed by a NamedTuple type, not by Tuple{Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}}. + @test_broken recursive_sin'''(1.0) == -cos(1.0) @test_broken recursive_sin''''(1.0) == sin(1.0) @test_broken recursive_sin'''''(1.0) == cos(1.0) @@ -40,6 +39,7 @@ let var"'" = Diffractor.PrimeDerivativeFwd end # Some Basic Mixed Mode tests +# TODO: unbreak this function sin_twice_fwd(x) let var"'" = Diffractor.PrimeDerivativeFwd sin''(x) @@ -90,4 +90,46 @@ end end end + +@testset "structs" begin + struct IDemo + x::Float64 + y::Float64 + end + + function foo(a) + obj = IDemo(2.0, a) + return obj.x * obj.y + end + + let var"'" = Diffractor.PrimeDerivativeFwd + @test foo'(100.0) == 2.0 + @test foo''(100.0) == 0.0 + end +end + +@testset "tuples" begin + function foo(a) + tup = (2.0, a) + return first(tup) * tup[2] + end + + let var"'" = Diffractor.PrimeDerivativeFwd + @test foo'(100.0) == 2.0 + @test foo''(100.0) == 0.0 + end +end + +@testset "vararg" begin + function foo(a) + tup = (2.0, a) + return *(tup...) + end + + let var"'" = Diffractor.PrimeDerivativeFwd + @test foo'(100.0) == 2.0 + @test foo''(100.0) == 0.0 + end +end + end From c205e3424484c2d6b1fae8142bb0c16846557c9f Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 19 Sep 2023 22:00:22 +0800 Subject: [PATCH 03/15] Remove most uses of CompositeTangent --- src/AbstractDifferentiation.jl | 14 +------- src/stage1/forward.jl | 50 ++++++++-------------------- src/tangent.jl | 27 --------------- test/AbstractDifferentiationTests.jl | 6 ++-- test/tangent.jl | 15 +++++---- 5 files changed, 25 insertions(+), 87 deletions(-) diff --git a/src/AbstractDifferentiation.jl b/src/AbstractDifferentiation.jl index e70604cd..3320e3ae 100644 --- a/src/AbstractDifferentiation.jl +++ b/src/AbstractDifferentiation.jl @@ -10,19 +10,7 @@ This is more or less the Diffractor equivelent of ForwardDiff.jl's `Dual` type. """ function bundle end bundle(x, dx::ChainRulesCore.AbstractZero) = UniformBundle{1, typeof(x), typeof(dx)}(x, dx) -bundle(x::Number, dx::Number) = TaylorBundle{1}(x, (dx,)) -bundle(x::AbstractArray{<:Number}, dx::AbstractArray{<:Number}) = TaylorBundle{1}(x, (dx,)) -bundle(x::P, dx::Tangent{P}) where P = _bundle(x, ChainRulesCore.canonicalize(dx)) - -"helper that assumes tangent is in canonical form" -function _bundle(x::P, dx::Tangent{P}) where P - # SoA to AoS flip (hate this, hate it even more cos we just undo it later when we hit chainrules) - the_bundle = ntuple(Val{fieldcount(P)}()) do ii - bundle(getfield(x, ii), getproperty(dx, ii)) - end - return CompositeBundle{1, P}(the_bundle) -end - +bundle(x, dx) = TaylorBundle{1}(x, (dx,)) AD.@primitive function pushforward_function(b::DiffractorForwardBackend, f, args...) return function pushforward(vs) diff --git a/src/stage1/forward.jl b/src/stage1/forward.jl index 2b6f36e7..e771da98 100644 --- a/src/stage1/forward.jl +++ b/src/stage1/forward.jl @@ -4,14 +4,6 @@ partial(x::TaylorTangent, i) = getfield(getfield(x, :coeffs), i) partial(x::UniformTangent, i) = getfield(x, :val) partial(x::ProductTangent, i) = ProductTangent(map(x->partial(x, i), getfield(x, :factors))) partial(x::AbstractZero, i) = x -partial(x::CompositeBundle{N, B}, i) where {N, B<:Tuple} = Tangent{B}(map(x->partial(x, i), getfield(x, :tup))...) -function partial(x::CompositeBundle{N, B}, i) where {N, B} - # This is tangent for a struct, but fields partials are each stored in a plain tuple - # so we add the names back using the primal `B` - # TODO: If required this can be done as a `@generated` function so it is type-stable - backing = NamedTuple{fieldnames(B)}(map(x->partial(x, i), getfield(x, :tup))) - return Tangent{B, typeof(backing)}(backing) -end primal(x::AbstractTangentBundle) = x.primal @@ -42,14 +34,6 @@ function shuffle_down(b::TaylorBundle{N, B}) where {N, B} ntuple(_sdown, N-1)) end -function shuffle_down(b::CompositeBundle{N, B}) where {N, B} - z = CompositeBundle{N-1, CompositeBundle{1, B}}( - (CompositeBundle{N-1, Tuple}( - map(shuffle_down, b.tup) - ),) - ) - z -end function shuffle_up(r::TaylorBundle{1, Tuple{B1,B2}}) where {B1,B2} z₀ = primal(r)[1] @@ -63,18 +47,7 @@ function shuffle_up(r::TaylorBundle{1, Tuple{B1,B2}}) where {B1,B2} end end -function shuffle_up(r::CompositeBundle{1}) - z₀ = primal(r.tup[1]) - z₁ = partial(r.tup[1], 1) - z₂ = primal(r.tup[2]) - z₁₂ = partial(r.tup[2], 1) - if z₁ == z₂ - return TaylorBundle{2}(z₀, (z₁, z₁₂)) - else - return ExplicitTangentBundle{2}(z₀, (z₁, z₂, z₁₂)) - end -end - +#== function taylor_compatible(a::ATB{N}, b::ATB{N}) where {N} primal(b) === a[TaylorTangentIndex(1)] || return false return all(1:(N-1)) do i @@ -88,7 +61,7 @@ isswifty(::UniformBundle) = true isswifty(b::CompositeBundle) = all(isswifty, b.tup) isswifty(::Any) = false -#TODO: port this to TaylorTangent: +#TODO: port this to TaylorTangent over composite structures function shuffle_up(r::CompositeBundle{N}) where {N} a, b = r.tup if isswifty(a) && isswifty(b) && taylor_compatible(a, b) @@ -102,6 +75,7 @@ function shuffle_up(r::CompositeBundle{N}) where {N} ntuple(i->partial(b,i), 1<<(N+1)-1)...)) end end +==# function shuffle_up(r::UniformBundle{N, B, U}) where {N, B, U} (a, b) = primal(r) @@ -198,13 +172,6 @@ end map(y->lifted_getfield(y, s), x.tangent.coeffs)) end -@Base.constprop :aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N}, s::AbstractTangentBundle{N, Int}) where {N} - x.tup[primal(s)] -end - -@Base.constprop :aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N, B}, s::AbstractTangentBundle{N, Symbol}) where {N, B} - x.tup[Base.fieldindex(B, primal(s))] -end @Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}) where {N, U} UniformBundle{N,<:Any,U}(getfield(primal(x), primal(s)), x.tangent.val) @@ -223,9 +190,12 @@ struct FwdMap{N, T<:AbstractTangentBundle{N}} end (f::FwdMap{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆{N}()(f.f, args...) +#== +# TODO port this to TaylorBundle over composite structure function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, tup::CompositeBundle{N, <:Tuple}) where {N} ∂vararg{N}()(map(FwdMap(f), tup.tup)...) end +==# function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N, <:AbstractArray}...) where {N} # TODO: This could do an inplace map! to avoid the extra rebundling @@ -267,23 +237,28 @@ function (this::∂☆{N})(::ZeroBundle{N, typeof(Core._apply_iterate)}, iterate Core._apply_iterate(FwdIterate(iterate), this, (f,), args...) end +#== +#TODO: port this to TaylorTangent over composite structures function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::CompositeBundle{N, <:Tuple}) where {N} r = iterate(t.tup) r === nothing && return ZeroBundle{N}(nothing) ∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) end +#TODO: port this to TaylorTangent over composite structures function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::CompositeBundle{N, <:Tuple}, a::ATB{N}, args::ATB{N}...) where {N} r = iterate(t.tup, primal(a), map(primal, args)...) r === nothing && return ZeroBundle{N}(nothing) ∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) end +#TODO: port this to TaylorTangent over composite structures function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::CompositeBundle{N, <:Tuple}, i::ATB{N}) where {N} r = Base.indexed_iterate(t.tup, primal(i)) ∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) end +#TODO: port this to TaylorTangent over composite structures function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::CompositeBundle{N, <:Tuple}, i::ATB{N}, st1::ATB{N}, st::ATB{N}...) where {N} r = Base.indexed_iterate(t.tup, primal(i), primal(st1), map(primal, st)...) ∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) @@ -293,10 +268,11 @@ function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::Tan ∂vararg{N}()(this(ZeroBundle{N}(getfield), t, i), ZeroBundle{N}(primal(i) + 1)) end - +#TODO: port this to TaylorTangent over composite structures function (this::∂☆{N})(::ZeroBundle{N, typeof(getindex)}, t::CompositeBundle{N, <:Tuple}, i::ZeroBundle) where {N} t.tup[primal(i)] end +==# function (this::∂☆{N})(::ZeroBundle{N, typeof(typeof)}, x::ATB{N}) where {N} DNEBundle{N}(typeof(primal(x))) diff --git a/src/tangent.jl b/src/tangent.jl index 10fe9ada..a350e468 100644 --- a/src/tangent.jl +++ b/src/tangent.jl @@ -290,33 +290,6 @@ end Base.getindex(u::UniformBundle, ::TaylorTangentIndex) = u.tangent.val -""" - CompositeBundle{N, B, B <: Tuple} - -Represents the tagent bundle where the base space is some tuple or struct type. -Mathematically, this tangent bundle is the product bundle of the individual -element bundles. -""" -struct CompositeBundle{N, B, T<:Tuple{Vararg{AbstractTangentBundle{N}}}} <: AbstractTangentBundle{N, B} - tup::T -end -CompositeBundle{N, B}(tup::T) where {N, B, T} = CompositeBundle{N, B, T}(tup) - -function Base.getindex(tb::CompositeBundle{N, B} where N, tti::TaylorTangentIndex) where {B} - B <: SArray && error() - return partial(tb, tti.i) -end - -primal(b::CompositeBundle{N, <:Tuple} where N) = map(primal, b.tup) -function primal(b::CompositeBundle{N, T} where N) where T<:CompositeBundle - T(map(primal, b.tup)...) -end -@generated primal(b::CompositeBundle{N, B} where N) where {B} = - quote - x = map(primal, b.tup) - $(Expr(:splatnew, B, :x)) - end - expand_singleton_to_array(asize, a::AbstractZero) = fill(a, asize...) expand_singleton_to_array(asize, a::AbstractArray) = a diff --git a/test/AbstractDifferentiationTests.jl b/test/AbstractDifferentiationTests.jl index e02186b2..821adc99 100644 --- a/test/AbstractDifferentiationTests.jl +++ b/test/AbstractDifferentiationTests.jl @@ -7,14 +7,14 @@ backend = Diffractor.DiffractorForwardBackend() @test bundle(1.0, 2.0) isa Diffractor.TaylorBundle{1} @test bundle([1.0, 2.0], [2.0, 3.0]) isa Diffractor.TaylorBundle{1} - @test bundle(1.5=>2.5, Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0)) isa Diffractor.CompositeBundle{1} + @test bundle(1.5=>2.5, Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0)) isa Diffractor.TaylorBundle{1} @test bundle(1.1, ChainRulesCore.ZeroTangent()) isa Diffractor.ZeroBundle{1} - @test bundle(1.5=>2.5=>3.5, Tangent{Pair{Float64, Pair{Float64, Float64}}}(first=1.0, second=Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0))) isa Diffractor.CompositeBundle{1} + @test bundle(1.5=>2.5=>3.5, Tangent{Pair{Float64, Pair{Float64, Float64}}}(first=1.0, second=Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0))) isa Diffractor.TaylorBundle{1} # noncanonical structural tangent b = bundle(1.5=>2.5=>3.5, Tangent{Pair{Float64, Pair{Float64, Float64}}}(second=Tangent{Pair{Float64, Float64}}(second=2.0, first=1.0))) t = Diffractor.first_partial(b) - @test b isa Diffractor.CompositeBundle{1} + @test b isa Diffractor.TaylorBundle{1} @test iszero(t.first) @test t.second.first == 1.0 @test t.second.second == 2.0 diff --git a/test/tangent.jl b/test/tangent.jl index fae513f5..8115c653 100644 --- a/test/tangent.jl +++ b/test/tangent.jl @@ -1,7 +1,7 @@ module tagent using Diffractor using Diffractor: AbstractZeroBundle, ZeroBundle, DNEBundle -using Diffractor: TaylorBundle, TaylorTangentIndex, CompositeBundle +using Diffractor: TaylorBundle, TaylorTangentIndex using Diffractor: ExplicitTangent, TaylorTangent, truncate using ChainRulesCore using Test @@ -28,21 +28,22 @@ using Test end @testset "AD through constructor" begin - #https://github.com/JuliaDiff/Diffractor.jl/issues/152 - # hits `getindex(::CompositeBundle{Foo152}, ::TaylorTangentIndex)` + # https://github.com/JuliaDiff/Diffractor.jl/issues/152 + # Though we have now removed the underlying cause, we keep this as a regression test just in case struct Foo152 x::Float64 end # Unit Test - cb = CompositeBundle{1, Foo152}((TaylorBundle{1, Float64}(23.5, (1.0,)),)) + cb = TaylorBundle{1, Foo152}(Foo152(23.5), (Tangent{Foo152}(;x=1.0),)) tti = TaylorTangentIndex(1,) @test cb[tti] == Tangent{Foo152}(; x=1.0) # Integration Test - var"'" = Diffractor.PrimeDerivativeFwd - f(x) = Foo152(x) - @test f'(23.5) == Tangent{Foo152}(; x=1.0) + let var"'" = Diffractor.PrimeDerivativeFwd + f(x) = Foo152(x) + @test f'(23.5) == Tangent{Foo152}(; x=1.0) + end end @testset "truncate" begin From cf137a085c7587e7d16bbad7cc7f7c322d2525dd Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 20 Sep 2023 14:35:22 +0800 Subject: [PATCH 04/15] fix sin_twice_fwd test --- src/stage1/recurse_fwd.jl | 12 ++++++++---- src/tangent.jl | 2 +- test/forward.jl | 1 - 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/stage1/recurse_fwd.jl b/src/stage1/recurse_fwd.jl index a1c23bc4..797f49d4 100644 --- a/src/stage1/recurse_fwd.jl +++ b/src/stage1/recurse_fwd.jl @@ -14,9 +14,9 @@ struct ∂☆new{N}; end # but the nth order case does also work for this function (::∂☆new{1})(B::Type, xs::AbstractTangentBundle{1}...) primal_args = map(primal, xs) - the_primal = B <: Tuple ? B(primal_args) : B(primal_args...) - - tangent_tup = map(x->partial(x, 1), xs) + the_primal = _construct(B, primal_args) + + tangent_tup = map(first_partial, xs) the_partial = if B<:Tuple Tangent{B, typeof(tangent_tup)}(tangent_tup) else @@ -29,7 +29,7 @@ end function (::∂☆new{N})(B::Type, xs::AbstractTangentBundle{N}...) where {N} primal_args = map(primal, xs) - the_primal = B <: Tuple ? B(primal_args) : B(primal_args...) + the_primal = _construct(B, primal_args) the_partials = ntuple(Val{N}()) do ii iith_order_type = ii==1 ? B : Any # the type of the higher order tangents isn't worth tracking @@ -46,6 +46,10 @@ function (::∂☆new{N})(B::Type, xs::AbstractTangentBundle{N}...) where {N} return TaylorBundle{N, B}(the_primal, the_partials) end +_construct(::Type{B}, args) where B<:Tuple = B(args) +# Hack for making things that do not have public constructors constructable: +@generated _construct(B::Type, args) = :($(Expr(:splatnew, :B, :args))) + @generated (::∂☆new{N})(B::Type) where {N} = return :(ZeroBundle{$N}($(Expr(:new, :B)))) # Sometimes we don't know whether or not we need to the ZeroBundle when doing diff --git a/src/tangent.jl b/src/tangent.jl index a350e468..2e21a6ea 100644 --- a/src/tangent.jl +++ b/src/tangent.jl @@ -202,7 +202,7 @@ end const TaylorBundle{N, B, P} = TangentBundle{N, B, TaylorTangent{P}} function TaylorBundle{N, B}(primal::B, coeffs) where {N, B} - check_taylor_invariants(coeffs, primal, N) +# check_taylor_invariants(coeffs, primal, N) # TODO: renable this _TangentBundle(Val{N}(), primal, TaylorTangent(coeffs)) end diff --git a/test/forward.jl b/test/forward.jl index 8ea24cde..48ccc76b 100644 --- a/test/forward.jl +++ b/test/forward.jl @@ -39,7 +39,6 @@ let var"'" = Diffractor.PrimeDerivativeFwd end # Some Basic Mixed Mode tests -# TODO: unbreak this function sin_twice_fwd(x) let var"'" = Diffractor.PrimeDerivativeFwd sin''(x) From 13762e7c4553478f185b0795fddff978eff42bbf Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 20 Sep 2023 14:51:07 +0800 Subject: [PATCH 05/15] Remove check that if primal is TangentBundle so is coeff --- src/tangent.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/tangent.jl b/src/tangent.jl index 2e21a6ea..8debab20 100644 --- a/src/tangent.jl +++ b/src/tangent.jl @@ -202,15 +202,13 @@ end const TaylorBundle{N, B, P} = TangentBundle{N, B, TaylorTangent{P}} function TaylorBundle{N, B}(primal::B, coeffs) where {N, B} -# check_taylor_invariants(coeffs, primal, N) # TODO: renable this + check_taylor_invariants(coeffs, primal, N) _TangentBundle(Val{N}(), primal, TaylorTangent(coeffs)) end function check_taylor_invariants(coeffs, primal, N) @assert length(coeffs) == N - if isa(primal, TangentBundle) - @assert isa(coeffs[1], TangentBundle) - end + end @ChainRulesCore.non_differentiable check_taylor_invariants(coeffs, primal, N) From f4f996dbdfe19a74bb04913da0ea16dc6f03d654 Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 20 Sep 2023 15:46:00 +0800 Subject: [PATCH 06/15] remove condition on getindex overload --- src/stage1/compiler_utils.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/stage1/compiler_utils.jl b/src/stage1/compiler_utils.jl index 1b5d6413..c2442ce0 100644 --- a/src/stage1/compiler_utils.jl +++ b/src/stage1/compiler_utils.jl @@ -7,9 +7,7 @@ function Base.push!(cfg::CFG, bb::BasicBlock) push!(cfg.index, bb.stmts.start) end -if VERSION <= v"1.11.0-DEV.116" - Base.getindex(ir::IRCode, ssa::SSAValue) = Core.Compiler.getindex(ir, ssa) -end +Base.getindex(ir::IRCode, ssa::SSAValue) = Core.Compiler.getindex(ir, ssa) Base.copy(ir::IRCode) = Core.Compiler.copy(ir) From 50383b659d9ce19afb9873b286da69095a1b41e2 Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 20 Sep 2023 16:37:23 +0800 Subject: [PATCH 07/15] Port over shuffle_up --- src/stage1/forward.jl | 42 +++++++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/src/stage1/forward.jl b/src/stage1/forward.jl index e771da98..9b5d89d6 100644 --- a/src/stage1/forward.jl +++ b/src/stage1/forward.jl @@ -47,7 +47,6 @@ function shuffle_up(r::TaylorBundle{1, Tuple{B1,B2}}) where {B1,B2} end end -#== function taylor_compatible(a::ATB{N}, b::ATB{N}) where {N} primal(b) === a[TaylorTangentIndex(1)] || return false return all(1:(N-1)) do i @@ -55,27 +54,32 @@ function taylor_compatible(a::ATB{N}, b::ATB{N}) where {N} end end -# Check whether the tangent bundle element is taylor-like -isswifty(::TaylorBundle) = true -isswifty(::UniformBundle) = true -isswifty(b::CompositeBundle) = all(isswifty, b.tup) -isswifty(::Any) = false - -#TODO: port this to TaylorTangent over composite structures -function shuffle_up(r::CompositeBundle{N}) where {N} - a, b = r.tup - if isswifty(a) && isswifty(b) && taylor_compatible(a, b) - return TaylorBundle{N+1}(primal(a), - ntuple(i->i == N+1 ? - b[TaylorTangentIndex(i-1)] : a[TaylorTangentIndex(i)], - N+1)) +function taylor_compatible(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2} + partial(r, 1)[1] = primal(r)[2] || return false + return all(1:N-1) do ii + partial(r, i+1)[1] == partial(r, i)[2] + end +end +function shuffle_up(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2} + the_primal = primal(r)[1] + if taylor_compatible(r) + the_partials = ntuple(N+1) do i + if ii <= N + partial(r, i)[1] # == `partial(r,i-1)[2]` (except first which is primal(r)[2]) + else # ii = N+1 + partial(r, i-1)[2] + end + end + return TaylorBundle{N+1}(the_primal, the_partials) else - return TangentBundle{N+1}(r.tup[1].primal, - (r.tup[1].tangent.partials..., primal(b), - ntuple(i->partial(b,i), 1<<(N+1)-1)...)) + #XXX: am dubious of the correctness of this + a_partials = ntuple(i->partial(r, ii)[1], N) + b_partials = ntuple(i->partial(r, ii)[2], N) + the_partials = (a_partials..., primal_b, b_partials...) + return TangentBundle{N+1}(the_primal, the_partials) end end -==# + function shuffle_up(r::UniformBundle{N, B, U}) where {N, B, U} (a, b) = primal(r) From bfd99c25f07c5e6300037154cf625b9332185066 Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 20 Sep 2023 18:39:11 +0800 Subject: [PATCH 08/15] Port map special handling use fieldcount taylor bundle not taylor tangent --- src/stage1/forward.jl | 7 ++----- src/tangent.jl | 12 ++++++++++++ 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/stage1/forward.jl b/src/stage1/forward.jl index 9b5d89d6..1027632c 100644 --- a/src/stage1/forward.jl +++ b/src/stage1/forward.jl @@ -194,12 +194,9 @@ struct FwdMap{N, T<:AbstractTangentBundle{N}} end (f::FwdMap{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆{N}()(f.f, args...) -#== -# TODO port this to TaylorBundle over composite structure -function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, tup::CompositeBundle{N, <:Tuple}) where {N} - ∂vararg{N}()(map(FwdMap(f), tup.tup)...) +function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, tup::TaylorBundle{N, <:Tuple}) where {N} + ∂vararg{N}()(map(FwdMap(f), destructure(tup))...) end -==# function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N, <:AbstractArray}...) where {N} # TODO: This could do an inplace map! to avoid the extra rebundling diff --git a/src/tangent.jl b/src/tangent.jl index 8debab20..6af77211 100644 --- a/src/tangent.jl +++ b/src/tangent.jl @@ -228,6 +228,18 @@ function Base.getindex(tb::TaylorBundle, tti::CanonicalTangentIndex) tb.tangent.coeffs[count_ones(tti.i)] end +"for a TaylorTangent{N, <:Tuple} this breaks it up unto 1 TaylorTangent{N} for each element of the primal tuple" +function destructure(r::TaylorBundle{N, B}) where {N, B<:Tuple} + return ntuple(fieldcount(B)) do field_ii + the_primal = primal(r)[field_ii] + the_partials = ntuple(N) do order_ii + partial(r, order_ii)[field_ii] + end + return TaylorBundle{N}(the_primal, the_partials) + end +end + + function truncate(tt::TaylorTangent, order::Val{N}) where {N} TaylorTangent(tt.coeffs[1:N]) end From 121ec20a10f64ba28852785a4a5bf15e7ef33357 Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 20 Sep 2023 19:26:44 +0800 Subject: [PATCH 09/15] =?UTF-8?q?port=20remaining=20=E2=88=82=E2=98=86=20o?= =?UTF-8?q?verloads=20to=20TaylorTangent?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/stage1/forward.jl | 30 +++++++++++++----------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/src/stage1/forward.jl b/src/stage1/forward.jl index 1027632c..3e0da306 100644 --- a/src/stage1/forward.jl +++ b/src/stage1/forward.jl @@ -238,30 +238,26 @@ function (this::∂☆{N})(::ZeroBundle{N, typeof(Core._apply_iterate)}, iterate Core._apply_iterate(FwdIterate(iterate), this, (f,), args...) end -#== -#TODO: port this to TaylorTangent over composite structures -function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::CompositeBundle{N, <:Tuple}) where {N} - r = iterate(t.tup) + +function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}) where {N} + r = iterate(destructure(t)) r === nothing && return ZeroBundle{N}(nothing) ∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) end -#TODO: port this to TaylorTangent over composite structures -function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::CompositeBundle{N, <:Tuple}, a::ATB{N}, args::ATB{N}...) where {N} - r = iterate(t.tup, primal(a), map(primal, args)...) +function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}, a::ATB{N}, args::ATB{N}...) where {N} + r = iterate(destructure(t), primal(a), map(primal, args)...) r === nothing && return ZeroBundle{N}(nothing) ∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) end -#TODO: port this to TaylorTangent over composite structures -function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::CompositeBundle{N, <:Tuple}, i::ATB{N}) where {N} - r = Base.indexed_iterate(t.tup, primal(i)) +function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}) where {N} + r = Base.indexed_iterate(destructure(t), primal(i)) ∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) end -#TODO: port this to TaylorTangent over composite structures -function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::CompositeBundle{N, <:Tuple}, i::ATB{N}, st1::ATB{N}, st::ATB{N}...) where {N} - r = Base.indexed_iterate(t.tup, primal(i), primal(st1), map(primal, st)...) +function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}, st1::ATB{N}, st::ATB{N}...) where {N} + r = Base.indexed_iterate(destructure(t), primal(i), primal(st1), map(primal, st)...) ∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) end @@ -269,11 +265,11 @@ function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::Tan ∂vararg{N}()(this(ZeroBundle{N}(getfield), t, i), ZeroBundle{N}(primal(i) + 1)) end -#TODO: port this to TaylorTangent over composite structures -function (this::∂☆{N})(::ZeroBundle{N, typeof(getindex)}, t::CompositeBundle{N, <:Tuple}, i::ZeroBundle) where {N} - t.tup[primal(i)] +function (this::∂☆{N})(::ZeroBundle{N, typeof(getindex)}, t::TaylorBundle{N, <:Tuple}, i::ZeroBundle) where {N} + field_ind = primal(i) + the_partials = ntuple(order_ind->partial(t, order_ind)[field_ind], N) + TaylorBundle{N}(primal(t)[field_ind], the_partials) end -==# function (this::∂☆{N})(::ZeroBundle{N, typeof(typeof)}, x::ATB{N}) where {N} DNEBundle{N}(typeof(primal(x))) From b257eeb3e18e9728f7b26f8435c50b8fa3c7bdeb Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 20 Sep 2023 19:55:19 +0800 Subject: [PATCH 10/15] Fix/update version check on if to define getindex(ir,ssa) --- src/stage1/compiler_utils.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/stage1/compiler_utils.jl b/src/stage1/compiler_utils.jl index c2442ce0..651a063d 100644 --- a/src/stage1/compiler_utils.jl +++ b/src/stage1/compiler_utils.jl @@ -7,7 +7,9 @@ function Base.push!(cfg::CFG, bb::BasicBlock) push!(cfg.index, bb.stmts.start) end -Base.getindex(ir::IRCode, ssa::SSAValue) = Core.Compiler.getindex(ir, ssa) +if VERSION < v"1.11.0-DEV.258" + Base.getindex(ir::IRCode, ssa::SSAValue) = Core.Compiler.getindex(ir, ssa) +end Base.copy(ir::IRCode) = Core.Compiler.copy(ir) From b9f74ce0fb9b69a5e70ece7adb2dd250633d36da Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 21 Sep 2023 10:04:33 +0800 Subject: [PATCH 11/15] use i --- src/stage1/forward.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stage1/forward.jl b/src/stage1/forward.jl index 3e0da306..29fcba33 100644 --- a/src/stage1/forward.jl +++ b/src/stage1/forward.jl @@ -73,8 +73,8 @@ function shuffle_up(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2} return TaylorBundle{N+1}(the_primal, the_partials) else #XXX: am dubious of the correctness of this - a_partials = ntuple(i->partial(r, ii)[1], N) - b_partials = ntuple(i->partial(r, ii)[2], N) + a_partials = ntuple(i->partial(r, i)[1], N) + b_partials = ntuple(i->partial(r, i)[2], N) the_partials = (a_partials..., primal_b, b_partials...) return TangentBundle{N+1}(the_primal, the_partials) end From 34cedf89855ad2c975817b88f6157849983bcffb Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 21 Sep 2023 10:25:27 +0800 Subject: [PATCH 12/15] remove redudant interpolation --- src/stage1/recurse_fwd.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stage1/recurse_fwd.jl b/src/stage1/recurse_fwd.jl index 797f49d4..2c561e73 100644 --- a/src/stage1/recurse_fwd.jl +++ b/src/stage1/recurse_fwd.jl @@ -48,7 +48,7 @@ end _construct(::Type{B}, args) where B<:Tuple = B(args) # Hack for making things that do not have public constructors constructable: -@generated _construct(B::Type, args) = :($(Expr(:splatnew, :B, :args))) +@generated _construct(B::Type, args) = Expr(:splatnew, :B, :args) @generated (::∂☆new{N})(B::Type) where {N} = return :(ZeroBundle{$N}($(Expr(:new, :B)))) From 97946665ab38bbf143686fcdac954b4939f0d664 Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 21 Sep 2023 14:20:17 +0800 Subject: [PATCH 13/15] fix and test taylor_compatible --- src/stage1/forward.jl | 6 +++--- test/forward.jl | 43 ++++++++++++++++++++++++++++++++++++++----- 2 files changed, 41 insertions(+), 8 deletions(-) diff --git a/src/stage1/forward.jl b/src/stage1/forward.jl index 29fcba33..13a5f7a8 100644 --- a/src/stage1/forward.jl +++ b/src/stage1/forward.jl @@ -55,8 +55,8 @@ function taylor_compatible(a::ATB{N}, b::ATB{N}) where {N} end function taylor_compatible(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2} - partial(r, 1)[1] = primal(r)[2] || return false - return all(1:N-1) do ii + partial(r, 1)[1] == primal(r)[2] || return false + return all(1:N-1) do i partial(r, i+1)[1] == partial(r, i)[2] end end @@ -64,7 +64,7 @@ function shuffle_up(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2} the_primal = primal(r)[1] if taylor_compatible(r) the_partials = ntuple(N+1) do i - if ii <= N + if i <= N partial(r, i)[1] # == `partial(r,i-1)[2]` (except first which is primal(r)[2]) else # ii = N+1 partial(r, i-1)[2] diff --git a/test/forward.jl b/test/forward.jl index 48ccc76b..70bd0e8e 100644 --- a/test/forward.jl +++ b/test/forward.jl @@ -1,16 +1,12 @@ module forward_tests using Diffractor -using Diffractor: var"'", ∂⃖, DiffractorRuleConfig, ZeroBundle +using Diffractor: TaylorBundle using ChainRules using ChainRulesCore using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad using LinearAlgebra - using Test -const fwd = Diffractor.PrimeDerivativeFwd -const bwd = Diffractor.PrimeDerivativeBack - # Minimal 2-nd order forward smoke test @@ -131,4 +127,41 @@ end end end + +@testset "taylor_compatible" begin + taylor_compatible = Diffractor.taylor_compatible + + @test taylor_compatible( + TaylorBundle{1}(10.0, (20.0,)), + TaylorBundle{1}(20.0, (30.0,)) + ) + @test !taylor_compatible( + TaylorBundle{1}(10.0, (20.0,)), + TaylorBundle{1}(21.0, (30.0,)) + ) + @test taylor_compatible( + TaylorBundle{2}(10.0, (20.0, 30.)), + TaylorBundle{2}(20.0, (30.0, 40.)) + ) + @test !taylor_compatible( + TaylorBundle{2}(10.0, (20.0, 30.0)), + TaylorBundle{2}(20.0, (31.0, 40.0)) + ) + + + tuptan(args...) = Tangent{typeof(args)}(args...) + @test taylor_compatible( + TaylorBundle{1}((10.0, 20.0), (tuptan(20.0, 30.0),)), + ) + @test taylor_compatible( + TaylorBundle{2}((10.0, 20.0), (tuptan(20.0, 30.0),tuptan(30.0, 40.0))), + ) + @test !taylor_compatible( + TaylorBundle{1}((10.0, 20.0), (tuptan(21.0, 30.0),)), + ) + @test !taylor_compatible( + TaylorBundle{2}((10.0, 20.0), (tuptan(20.0, 31.0),tuptan(30.0, 40.0))), + ) +end + end From ba5841e04e5e2c294d1c3e8a2e1754c827279c9d Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 26 Sep 2023 14:50:55 +0800 Subject: [PATCH 14/15] remove excess whitepace Co-authored-by: Elliot Saba --- test/tangent.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/tangent.jl b/test/tangent.jl index 8115c653..01a54607 100644 --- a/test/tangent.jl +++ b/test/tangent.jl @@ -35,7 +35,7 @@ end end # Unit Test - cb = TaylorBundle{1, Foo152}(Foo152(23.5), (Tangent{Foo152}(;x=1.0),)) + cb = TaylorBundle{1, Foo152}(Foo152(23.5), (Tangent{Foo152}(;x=1.0),)) tti = TaylorTangentIndex(1,) @test cb[tti] == Tangent{Foo152}(; x=1.0) From 40da34b26d55f169fccbb8147f9fdf97569b8054 Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 26 Sep 2023 16:10:53 +0800 Subject: [PATCH 15/15] fix missing imports --- test/forward.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/test/forward.jl b/test/forward.jl index 70bd0e8e..53b65059 100644 --- a/test/forward.jl +++ b/test/forward.jl @@ -1,6 +1,6 @@ module forward_tests using Diffractor -using Diffractor: TaylorBundle +using Diffractor: TaylorBundle, ZeroBundle using ChainRules using ChainRulesCore using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad @@ -10,8 +10,10 @@ using Test # Minimal 2-nd order forward smoke test -@test Diffractor.∂☆{2}()(Diffractor.ZeroBundle{2}(sin), - Diffractor.ExplicitTangentBundle{2}(1.0, (1.0, 1.0, 0.0)))[Diffractor.CanonicalTangentIndex(1)] == sin'(1.0) +let var"'" = Diffractor.PrimeDerivativeFwd + @test Diffractor.∂☆{2}()(ZeroBundle{2}(sin), + Diffractor.ExplicitTangentBundle{2}(1.0, (1.0, 1.0, 0.0)))[Diffractor.CanonicalTangentIndex(1)] == sin'(1.0) +end # Simple Forward Mode tests let var"'" = Diffractor.PrimeDerivativeFwd