From 6ce5e5d600f23b0ad9e8a7b20fa918aa907c8d45 Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 25 Jan 2024 14:20:15 +0800 Subject: [PATCH 1/2] WIP: premapping based cyclic type zeros --- src/tangent_types/abstract_zero.jl | 44 ++++++++++++++++++++++++++++- test/tangent_types/abstract_zero.jl | 18 ++++++------ 2 files changed, 52 insertions(+), 10 deletions(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index f921db29d..d52f213d0 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -138,7 +138,19 @@ end ) Expr(:kw, fname, fval) end - return if has_mutable_tangent(primal) + + # easy case exit early, can't hold references, can't be a reference. + if isbitstype(primal) + return :($Tangent{$primal}($(Expr(:parameters, zfield_exprs...)))) + end + + # hard case need to be prepared for cycic references to this, or that are contained within this + quote + counts = $count_references!(primal) + end + +## TODO rewrite below + has_mutable_tangent(primal) any_mask = map(fieldnames(primal), fieldtypes(primal)) do fname, ftype # If it is is unassigned, or if it doesn't have a concrete type, let it take any value for its tangent fdef = :(!isdefined(primal, $(QuoteNode(fname))) || !isconcretetype($ftype)) @@ -171,6 +183,36 @@ function zero_tangent(x::Array{P,N}) where {P,N} return y end +############################################### +count_references!(x) = count_references(IdDict{Any, Int}(), x) +function count_references!(counts::IdDict{Any, Int}, x) + isbits(x) && return counts # can't be a refernece and can't hold a reference + counts[x] = get(counts, x, 0) + 1 # Increment *before* recursing + if counts[x] == 1 # Only recurse the first time + for ii in fieldcount(typeof(x)) + field = getfield(x, ii) + count_references!(counts, field) + end + end + return counts +end + +function count_references!(counts::IdDict{Any, Int}, x::Array) + counts[x] = get(counts, x, 0) + 1 # increment before recursing + isbitstype(eltype(x)) && return counts # no need to look inside, it can't hold references + if counts[x] == 1 # only recurse the first time + for ele in x + count_references!(counts, ele) + end + end + return counts +end + +count_references!(counts::IdDict{Any, Int}, ::DataType) = counts + +############################################### + + # Sad heauristic methods we need because of unassigned values guess_zero_tangent_type(::Type{T}) where {T<:Number} = T guess_zero_tangent_type(::Type{T}) where {T<:Integer} = typeof(float(zero(T))) diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 245d9a29d..08cdaaf28 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -303,24 +303,24 @@ end lk = Link(1.5) lk.next = lk - @test_broken d = zero_tangent(lk) - @test_broken d.data == 0.0 - @test_broken d.next === d + d = zero_tangent(lk) + @test d.data == 0.0 + @test d.next === d struct CarryingArray x::Vector end ca = CarryingArray(Any[1.5]) push!(ca.x, ca) - @test_broken d_ca = zero_tangent(ca) - @test_broken d_ca[1] == 0.0 - @test_broken d_ca[2] === _ca + @test d_ca = zero_tangent(ca) + @test d_ca[1] == 0.0 + @test d_ca[2] === _ca # Idea: check if typeof(xs) <: eltype(xs), if so need to cache it before computing xs = Any[1.5] push!(xs, xs) - @test_broken d_xs = zero_tangent(xs) - @test_broken d_xs[1] == 0.0 - @test_broken d_xs[2] == d_xs + @test d_xs = zero_tangent(xs) + @test d_xs[1] == 0.0 + @test d_xs[2] == d_xs end end From 9df22e9350e3b68230fb13fd76878ce564e248fd Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 25 Jan 2024 17:55:02 +0800 Subject: [PATCH 2/2] WIP --- src/tangent_types/abstract_zero.jl | 44 ++++++++++++++++-------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index d52f213d0..d9df07baa 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -126,36 +126,38 @@ end @generated function zero_tangent(primal) fieldcount(primal) == 0 && return NoTangent() # no tangent space at all, no need for structural zero. - zfield_exprs = map(fieldnames(primal)) do fname - fval = :( - if isdefined(primal, $(QuoteNode(fname))) - zero_tangent(getfield(primal, $(QuoteNode(fname)))) - else - # This is going to be potentially bad, but that's what they get for not giving us a primal - # This will never me mutated inplace, rather it will alway be replaced with an actual value first - ZeroTangent() - end - ) - Expr(:kw, fname, fval) - end - + # easy case exit early, can't hold references, can't be a reference. if isbitstype(primal) + zfield_exprs = map(fieldnames(primal)) do fname + fval = :(zero_tangent(getfield(primal, $(QuoteNode(fname))))) + Expr(:kw, fname, fval) + end return :($Tangent{$primal}($(Expr(:parameters, zfield_exprs...)))) end - # hard case need to be prepared for cycic references to this, or that are contained within this + # hard case need to be prepared for references to this, or that are contained within this quote - counts = $count_references!(primal) + counts = $count_references(primal) + any_mask = $(Expr(:tuple, Expr(:parameters, map(fieldnames(primal), fieldtypes(primal)) do fname, ftype + # If it is is unassigned, or if it doesn't have a concrete type, or we have multiple reference to it + # then let it take any value for its tangent + fdef = :( + !isdefined(primal, $(QuoteNode(fname))) || + !isconcretetype($ftype) || + get(counts, $(QuoteNode(fname)), 0) > 1 + ) + Expr(:kw, fname, fdef) + end...))) + + # Construct tangents + + # Go back and fill in tangents that were not ready end ## TODO rewrite below has_mutable_tangent(primal) - any_mask = map(fieldnames(primal), fieldtypes(primal)) do fname, ftype - # If it is is unassigned, or if it doesn't have a concrete type, let it take any value for its tangent - fdef = :(!isdefined(primal, $(QuoteNode(fname))) || !isconcretetype($ftype)) - Expr(:kw, fname, fdef) - end + any_mask = :($MutableTangent{$primal}( $(Expr(:tuple, Expr(:parameters, any_mask...))), $(Expr(:tuple, Expr(:parameters, zfield_exprs...))), @@ -184,7 +186,7 @@ function zero_tangent(x::Array{P,N}) where {P,N} end ############################################### -count_references!(x) = count_references(IdDict{Any, Int}(), x) +count_references(x) = count_references(IdDict{Any, Int}(), x) function count_references!(counts::IdDict{Any, Int}, x) isbits(x) && return counts # can't be a refernece and can't hold a reference counts[x] = get(counts, x, 0) + 1 # Increment *before* recursing