diff --git a/src/AbstractDifferentiation.jl b/src/AbstractDifferentiation.jl index e70604cd..3320e3ae 100644 --- a/src/AbstractDifferentiation.jl +++ b/src/AbstractDifferentiation.jl @@ -10,19 +10,7 @@ This is more or less the Diffractor equivelent of ForwardDiff.jl's `Dual` type. """ function bundle end bundle(x, dx::ChainRulesCore.AbstractZero) = UniformBundle{1, typeof(x), typeof(dx)}(x, dx) -bundle(x::Number, dx::Number) = TaylorBundle{1}(x, (dx,)) -bundle(x::AbstractArray{<:Number}, dx::AbstractArray{<:Number}) = TaylorBundle{1}(x, (dx,)) -bundle(x::P, dx::Tangent{P}) where P = _bundle(x, ChainRulesCore.canonicalize(dx)) - -"helper that assumes tangent is in canonical form" -function _bundle(x::P, dx::Tangent{P}) where P - # SoA to AoS flip (hate this, hate it even more cos we just undo it later when we hit chainrules) - the_bundle = ntuple(Val{fieldcount(P)}()) do ii - bundle(getfield(x, ii), getproperty(dx, ii)) - end - return CompositeBundle{1, P}(the_bundle) -end - +bundle(x, dx) = TaylorBundle{1}(x, (dx,)) AD.@primitive function pushforward_function(b::DiffractorForwardBackend, f, args...) return function pushforward(vs) diff --git a/src/stage1/compiler_utils.jl b/src/stage1/compiler_utils.jl index 1b5d6413..651a063d 100644 --- a/src/stage1/compiler_utils.jl +++ b/src/stage1/compiler_utils.jl @@ -7,7 +7,7 @@ function Base.push!(cfg::CFG, bb::BasicBlock) push!(cfg.index, bb.stmts.start) end -if VERSION <= v"1.11.0-DEV.116" +if VERSION < v"1.11.0-DEV.258" Base.getindex(ir::IRCode, ssa::SSAValue) = Core.Compiler.getindex(ir, ssa) end diff --git a/src/stage1/forward.jl b/src/stage1/forward.jl index 5393f68d..13a5f7a8 100644 --- a/src/stage1/forward.jl +++ b/src/stage1/forward.jl @@ -4,14 +4,6 @@ partial(x::TaylorTangent, i) = getfield(getfield(x, :coeffs), i) partial(x::UniformTangent, i) = getfield(x, :val) partial(x::ProductTangent, i) = ProductTangent(map(x->partial(x, i), getfield(x, :factors))) partial(x::AbstractZero, i) = x -partial(x::CompositeBundle{N, B}, i) where {N, B<:Tuple} = Tangent{B}(map(x->partial(x, i), getfield(x, :tup))...) -function partial(x::CompositeBundle{N, B}, i) where {N, B} - # This is tangent for a struct, but fields partials are each stored in a plain tuple - # so we add the names back using the primal `B` - # TODO: If required this can be done as a `@generated` function so it is type-stable - backing = NamedTuple{fieldnames(B)}(map(x->partial(x, i), getfield(x, :tup))) - return Tangent{B, typeof(backing)}(backing) -end primal(x::AbstractTangentBundle) = x.primal @@ -42,20 +34,12 @@ function shuffle_down(b::TaylorBundle{N, B}) where {N, B} ntuple(_sdown, N-1)) end -function shuffle_down(b::CompositeBundle{N, B}) where {N, B} - z = CompositeBundle{N-1, CompositeBundle{1, B}}( - (CompositeBundle{N-1, Tuple}( - map(shuffle_down, b.tup) - ),) - ) - z -end -function shuffle_up(r::CompositeBundle{1}) - z₀ = primal(r.tup[1]) - z₁ = partial(r.tup[1], 1) - z₂ = primal(r.tup[2]) - z₁₂ = partial(r.tup[2], 1) +function shuffle_up(r::TaylorBundle{1, Tuple{B1,B2}}) where {B1,B2} + z₀ = primal(r)[1] + z₁ = partial(r, 1)[1] + z₂ = primal(r)[2] + z₁₂ = partial(r, 1)[2] if z₁ == z₂ return TaylorBundle{2}(z₀, (z₁, z₁₂)) else @@ -70,26 +54,33 @@ function taylor_compatible(a::ATB{N}, b::ATB{N}) where {N} end end -# Check whether the tangent bundle element is taylor-like -isswifty(::TaylorBundle) = true -isswifty(::UniformBundle) = true -isswifty(b::CompositeBundle) = all(isswifty, b.tup) -isswifty(::Any) = false - -function shuffle_up(r::CompositeBundle{N}) where {N} - a, b = r.tup - if isswifty(a) && isswifty(b) && taylor_compatible(a, b) - return TaylorBundle{N+1}(primal(a), - ntuple(i->i == N+1 ? - b[TaylorTangentIndex(i-1)] : a[TaylorTangentIndex(i)], - N+1)) +function taylor_compatible(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2} + partial(r, 1)[1] == primal(r)[2] || return false + return all(1:N-1) do i + partial(r, i+1)[1] == partial(r, i)[2] + end +end +function shuffle_up(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2} + the_primal = primal(r)[1] + if taylor_compatible(r) + the_partials = ntuple(N+1) do i + if i <= N + partial(r, i)[1] # == `partial(r,i-1)[2]` (except first which is primal(r)[2]) + else # ii = N+1 + partial(r, i-1)[2] + end + end + return TaylorBundle{N+1}(the_primal, the_partials) else - return TangentBundle{N+1}(r.tup[1].primal, - (r.tup[1].tangent.partials..., primal(b), - ntuple(i->partial(b,i), 1<<(N+1)-1)...)) + #XXX: am dubious of the correctness of this + a_partials = ntuple(i->partial(r, i)[1], N) + b_partials = ntuple(i->partial(r, i)[2], N) + the_partials = (a_partials..., primal_b, b_partials...) + return TangentBundle{N+1}(the_primal, the_partials) end end + function shuffle_up(r::UniformBundle{N, B, U}) where {N, B, U} (a, b) = primal(r) if r.tangent.val === b @@ -185,13 +176,6 @@ end map(y->lifted_getfield(y, s), x.tangent.coeffs)) end -@Base.constprop :aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N}, s::AbstractTangentBundle{N, Int}) where {N} - x.tup[primal(s)] -end - -@Base.constprop :aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N, B}, s::AbstractTangentBundle{N, Symbol}) where {N, B} - x.tup[Base.fieldindex(B, primal(s))] -end @Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}) where {N, U} UniformBundle{N,<:Any,U}(getfield(primal(x), primal(s)), x.tangent.val) @@ -210,8 +194,8 @@ struct FwdMap{N, T<:AbstractTangentBundle{N}} end (f::FwdMap{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆{N}()(f.f, args...) -function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, tup::CompositeBundle{N, <:Tuple}) where {N} - ∂vararg{N}()(map(FwdMap(f), tup.tup)...) +function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, tup::TaylorBundle{N, <:Tuple}) where {N} + ∂vararg{N}()(map(FwdMap(f), destructure(tup))...) end function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N, <:AbstractArray}...) where {N} @@ -254,25 +238,26 @@ function (this::∂☆{N})(::ZeroBundle{N, typeof(Core._apply_iterate)}, iterate Core._apply_iterate(FwdIterate(iterate), this, (f,), args...) end -function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::CompositeBundle{N, <:Tuple}) where {N} - r = iterate(t.tup) + +function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}) where {N} + r = iterate(destructure(t)) r === nothing && return ZeroBundle{N}(nothing) ∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) end -function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::CompositeBundle{N, <:Tuple}, a::ATB{N}, args::ATB{N}...) where {N} - r = iterate(t.tup, primal(a), map(primal, args)...) +function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}, a::ATB{N}, args::ATB{N}...) where {N} + r = iterate(destructure(t), primal(a), map(primal, args)...) r === nothing && return ZeroBundle{N}(nothing) ∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) end -function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::CompositeBundle{N, <:Tuple}, i::ATB{N}) where {N} - r = Base.indexed_iterate(t.tup, primal(i)) +function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}) where {N} + r = Base.indexed_iterate(destructure(t), primal(i)) ∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) end -function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::CompositeBundle{N, <:Tuple}, i::ATB{N}, st1::ATB{N}, st::ATB{N}...) where {N} - r = Base.indexed_iterate(t.tup, primal(i), primal(st1), map(primal, st)...) +function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}, st1::ATB{N}, st::ATB{N}...) where {N} + r = Base.indexed_iterate(destructure(t), primal(i), primal(st1), map(primal, st)...) ∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) end @@ -280,9 +265,10 @@ function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::Tan ∂vararg{N}()(this(ZeroBundle{N}(getfield), t, i), ZeroBundle{N}(primal(i) + 1)) end - -function (this::∂☆{N})(::ZeroBundle{N, typeof(getindex)}, t::CompositeBundle{N, <:Tuple}, i::ZeroBundle) where {N} - t.tup[primal(i)] +function (this::∂☆{N})(::ZeroBundle{N, typeof(getindex)}, t::TaylorBundle{N, <:Tuple}, i::ZeroBundle) where {N} + field_ind = primal(i) + the_partials = ntuple(order_ind->partial(t, order_ind)[field_ind], N) + TaylorBundle{N}(primal(t)[field_ind], the_partials) end function (this::∂☆{N})(::ZeroBundle{N, typeof(typeof)}, x::ATB{N}) where {N} diff --git a/src/stage1/recurse_fwd.jl b/src/stage1/recurse_fwd.jl index 0c5ab07c..2c561e73 100644 --- a/src/stage1/recurse_fwd.jl +++ b/src/stage1/recurse_fwd.jl @@ -4,13 +4,51 @@ struct ∂vararg{N}; end (::∂vararg{N})() where {N} = ZeroBundle{N}(()) function (::∂vararg{N})(a::AbstractTangentBundle{N}...) where N - CompositeBundle{N, Tuple{map(x->basespace(typeof(x)), a)...}}(a) + B = Tuple{map(x->basespace(Core.Typeof(x)), a)...} + return (∂☆new{N}())(B, a...) end struct ∂☆new{N}; end -(::∂☆new{N})(B::Type, a::AbstractTangentBundle{N}...) where {N} = - CompositeBundle{N, B}(a) +# we split out the 1st order derivative as a special case for performance +# but the nth order case does also work for this +function (::∂☆new{1})(B::Type, xs::AbstractTangentBundle{1}...) + primal_args = map(primal, xs) + the_primal = _construct(B, primal_args) + + tangent_tup = map(first_partial, xs) + the_partial = if B<:Tuple + Tangent{B, typeof(tangent_tup)}(tangent_tup) + else + names = fieldnames(B) + tangent_nt = NamedTuple{names}(tangent_tup) + Tangent{B, typeof(tangent_nt)}(tangent_nt) + end + return TaylorBundle{1, B}(the_primal, (the_partial,)) +end + +function (::∂☆new{N})(B::Type, xs::AbstractTangentBundle{N}...) where {N} + primal_args = map(primal, xs) + the_primal = _construct(B, primal_args) + + the_partials = ntuple(Val{N}()) do ii + iith_order_type = ii==1 ? B : Any # the type of the higher order tangents isn't worth tracking + tangent_tup = map(x->partial(x, ii), xs) + tangent = if B<:Tuple + Tangent{iith_order_type, typeof(tangent_tup)}(tangent_tup) + else + names = fieldnames(B) + tangent_nt = NamedTuple{names}(tangent_tup) + Tangent{iith_order_type, typeof(tangent_nt)}(tangent_nt) + end + return tangent + end + return TaylorBundle{N, B}(the_primal, the_partials) +end + +_construct(::Type{B}, args) where B<:Tuple = B(args) +# Hack for making things that do not have public constructors constructable: +@generated _construct(B::Type, args) = Expr(:splatnew, :B, :args) @generated (::∂☆new{N})(B::Type) where {N} = return :(ZeroBundle{$N}($(Expr(:new, :B)))) diff --git a/src/tangent.jl b/src/tangent.jl index 10fe9ada..6af77211 100644 --- a/src/tangent.jl +++ b/src/tangent.jl @@ -208,9 +208,7 @@ end function check_taylor_invariants(coeffs, primal, N) @assert length(coeffs) == N - if isa(primal, TangentBundle) - @assert isa(coeffs[1], TangentBundle) - end + end @ChainRulesCore.non_differentiable check_taylor_invariants(coeffs, primal, N) @@ -230,6 +228,18 @@ function Base.getindex(tb::TaylorBundle, tti::CanonicalTangentIndex) tb.tangent.coeffs[count_ones(tti.i)] end +"for a TaylorTangent{N, <:Tuple} this breaks it up unto 1 TaylorTangent{N} for each element of the primal tuple" +function destructure(r::TaylorBundle{N, B}) where {N, B<:Tuple} + return ntuple(fieldcount(B)) do field_ii + the_primal = primal(r)[field_ii] + the_partials = ntuple(N) do order_ii + partial(r, order_ii)[field_ii] + end + return TaylorBundle{N}(the_primal, the_partials) + end +end + + function truncate(tt::TaylorTangent, order::Val{N}) where {N} TaylorTangent(tt.coeffs[1:N]) end @@ -290,33 +300,6 @@ end Base.getindex(u::UniformBundle, ::TaylorTangentIndex) = u.tangent.val -""" - CompositeBundle{N, B, B <: Tuple} - -Represents the tagent bundle where the base space is some tuple or struct type. -Mathematically, this tangent bundle is the product bundle of the individual -element bundles. -""" -struct CompositeBundle{N, B, T<:Tuple{Vararg{AbstractTangentBundle{N}}}} <: AbstractTangentBundle{N, B} - tup::T -end -CompositeBundle{N, B}(tup::T) where {N, B, T} = CompositeBundle{N, B, T}(tup) - -function Base.getindex(tb::CompositeBundle{N, B} where N, tti::TaylorTangentIndex) where {B} - B <: SArray && error() - return partial(tb, tti.i) -end - -primal(b::CompositeBundle{N, <:Tuple} where N) = map(primal, b.tup) -function primal(b::CompositeBundle{N, T} where N) where T<:CompositeBundle - T(map(primal, b.tup)...) -end -@generated primal(b::CompositeBundle{N, B} where N) where {B} = - quote - x = map(primal, b.tup) - $(Expr(:splatnew, B, :x)) - end - expand_singleton_to_array(asize, a::AbstractZero) = fill(a, asize...) expand_singleton_to_array(asize, a::AbstractArray) = a diff --git a/test/AbstractDifferentiationTests.jl b/test/AbstractDifferentiationTests.jl index 66bcc1fd..821adc99 100644 --- a/test/AbstractDifferentiationTests.jl +++ b/test/AbstractDifferentiationTests.jl @@ -7,14 +7,14 @@ backend = Diffractor.DiffractorForwardBackend() @test bundle(1.0, 2.0) isa Diffractor.TaylorBundle{1} @test bundle([1.0, 2.0], [2.0, 3.0]) isa Diffractor.TaylorBundle{1} - @test bundle(1.5=>2.5, Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0)) isa Diffractor.CompositeBundle{1} + @test bundle(1.5=>2.5, Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0)) isa Diffractor.TaylorBundle{1} @test bundle(1.1, ChainRulesCore.ZeroTangent()) isa Diffractor.ZeroBundle{1} - @test bundle(1.5=>2.5=>3.5, Tangent{Pair{Float64, Pair{Float64, Float64}}}(first=1.0, second=Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0))) isa Diffractor.CompositeBundle{1} + @test bundle(1.5=>2.5=>3.5, Tangent{Pair{Float64, Pair{Float64, Float64}}}(first=1.0, second=Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0))) isa Diffractor.TaylorBundle{1} # noncanonical structural tangent b = bundle(1.5=>2.5=>3.5, Tangent{Pair{Float64, Pair{Float64, Float64}}}(second=Tangent{Pair{Float64, Float64}}(second=2.0, first=1.0))) t = Diffractor.first_partial(b) - @test b isa Diffractor.CompositeBundle{1} + @test b isa Diffractor.TaylorBundle{1} @test iszero(t.first) @test t.second.first == 1.0 @test t.second.second == 2.0 @@ -29,7 +29,7 @@ end # standard tests from AbstractDifferentiation.test_utils include(joinpath(pathof(AbstractDifferentiation), "..", "..", "test", "test_utils.jl")) -@testset "ForwardDiffBackend" begin +@testset "Standard AbstractDifferentiation.test_utils tests" begin backends = [ @inferred(Diffractor.DiffractorForwardBackend()) ] diff --git a/test/forward.jl b/test/forward.jl index 9ac5bea1..53b65059 100644 --- a/test/forward.jl +++ b/test/forward.jl @@ -1,21 +1,19 @@ module forward_tests using Diffractor -using Diffractor: var"'", ∂⃖, DiffractorRuleConfig, ZeroBundle +using Diffractor: TaylorBundle, ZeroBundle using ChainRules using ChainRulesCore using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad using LinearAlgebra - using Test -const fwd = Diffractor.PrimeDerivativeFwd -const bwd = Diffractor.PrimeDerivativeBack - # Minimal 2-nd order forward smoke test -@test Diffractor.∂☆{2}()(Diffractor.ZeroBundle{2}(sin), - Diffractor.ExplicitTangentBundle{2}(1.0, (1.0, 1.0, 0.0)))[Diffractor.CanonicalTangentIndex(1)] == sin'(1.0) +let var"'" = Diffractor.PrimeDerivativeFwd + @test Diffractor.∂☆{2}()(ZeroBundle{2}(sin), + Diffractor.ExplicitTangentBundle{2}(1.0, (1.0, 1.0, 0.0)))[Diffractor.CanonicalTangentIndex(1)] == sin'(1.0) +end # Simple Forward Mode tests let var"'" = Diffractor.PrimeDerivativeFwd @@ -25,8 +23,7 @@ let var"'" = Diffractor.PrimeDerivativeFwd # Integration tests @test recursive_sin'(1.0) == cos(1.0) @test recursive_sin''(1.0) == -sin(1.0) - # Error: ArgumentError: Tangent for the primal Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}} - # should be backed by a NamedTuple type, not by Tuple{Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}}. + @test_broken recursive_sin'''(1.0) == -cos(1.0) @test_broken recursive_sin''''(1.0) == sin(1.0) @test_broken recursive_sin'''''(1.0) == cos(1.0) @@ -90,4 +87,83 @@ end end end + +@testset "structs" begin + struct IDemo + x::Float64 + y::Float64 + end + + function foo(a) + obj = IDemo(2.0, a) + return obj.x * obj.y + end + + let var"'" = Diffractor.PrimeDerivativeFwd + @test foo'(100.0) == 2.0 + @test foo''(100.0) == 0.0 + end +end + +@testset "tuples" begin + function foo(a) + tup = (2.0, a) + return first(tup) * tup[2] + end + + let var"'" = Diffractor.PrimeDerivativeFwd + @test foo'(100.0) == 2.0 + @test foo''(100.0) == 0.0 + end +end + +@testset "vararg" begin + function foo(a) + tup = (2.0, a) + return *(tup...) + end + + let var"'" = Diffractor.PrimeDerivativeFwd + @test foo'(100.0) == 2.0 + @test foo''(100.0) == 0.0 + end +end + + +@testset "taylor_compatible" begin + taylor_compatible = Diffractor.taylor_compatible + + @test taylor_compatible( + TaylorBundle{1}(10.0, (20.0,)), + TaylorBundle{1}(20.0, (30.0,)) + ) + @test !taylor_compatible( + TaylorBundle{1}(10.0, (20.0,)), + TaylorBundle{1}(21.0, (30.0,)) + ) + @test taylor_compatible( + TaylorBundle{2}(10.0, (20.0, 30.)), + TaylorBundle{2}(20.0, (30.0, 40.)) + ) + @test !taylor_compatible( + TaylorBundle{2}(10.0, (20.0, 30.0)), + TaylorBundle{2}(20.0, (31.0, 40.0)) + ) + + + tuptan(args...) = Tangent{typeof(args)}(args...) + @test taylor_compatible( + TaylorBundle{1}((10.0, 20.0), (tuptan(20.0, 30.0),)), + ) + @test taylor_compatible( + TaylorBundle{2}((10.0, 20.0), (tuptan(20.0, 30.0),tuptan(30.0, 40.0))), + ) + @test !taylor_compatible( + TaylorBundle{1}((10.0, 20.0), (tuptan(21.0, 30.0),)), + ) + @test !taylor_compatible( + TaylorBundle{2}((10.0, 20.0), (tuptan(20.0, 31.0),tuptan(30.0, 40.0))), + ) +end + end diff --git a/test/tangent.jl b/test/tangent.jl index fae513f5..01a54607 100644 --- a/test/tangent.jl +++ b/test/tangent.jl @@ -1,7 +1,7 @@ module tagent using Diffractor using Diffractor: AbstractZeroBundle, ZeroBundle, DNEBundle -using Diffractor: TaylorBundle, TaylorTangentIndex, CompositeBundle +using Diffractor: TaylorBundle, TaylorTangentIndex using Diffractor: ExplicitTangent, TaylorTangent, truncate using ChainRulesCore using Test @@ -28,21 +28,22 @@ using Test end @testset "AD through constructor" begin - #https://github.com/JuliaDiff/Diffractor.jl/issues/152 - # hits `getindex(::CompositeBundle{Foo152}, ::TaylorTangentIndex)` + # https://github.com/JuliaDiff/Diffractor.jl/issues/152 + # Though we have now removed the underlying cause, we keep this as a regression test just in case struct Foo152 x::Float64 end # Unit Test - cb = CompositeBundle{1, Foo152}((TaylorBundle{1, Float64}(23.5, (1.0,)),)) + cb = TaylorBundle{1, Foo152}(Foo152(23.5), (Tangent{Foo152}(;x=1.0),)) tti = TaylorTangentIndex(1,) @test cb[tti] == Tangent{Foo152}(; x=1.0) # Integration Test - var"'" = Diffractor.PrimeDerivativeFwd - f(x) = Foo152(x) - @test f'(23.5) == Tangent{Foo152}(; x=1.0) + let var"'" = Diffractor.PrimeDerivativeFwd + f(x) = Foo152(x) + @test f'(23.5) == Tangent{Foo152}(; x=1.0) + end end @testset "truncate" begin