Skip to content

Commit ddd1647

Browse files
committed
WIP
1 parent 8514b3e commit ddd1647

File tree

1 file changed

+23
-21
lines changed

1 file changed

+23
-21
lines changed

src/tangent_types/abstract_zero.jl

+23-21
Original file line numberDiff line numberDiff line change
@@ -126,36 +126,38 @@ end
126126

127127
@generated function zero_tangent(primal)
128128
fieldcount(primal) == 0 && return NoTangent() # no tangent space at all, no need for structural zero.
129-
zfield_exprs = map(fieldnames(primal)) do fname
130-
fval = :(
131-
if isdefined(primal, $(QuoteNode(fname)))
132-
zero_tangent(getfield(primal, $(QuoteNode(fname))))
133-
else
134-
# This is going to be potentially bad, but that's what they get for not giving us a primal
135-
# This will never me mutated inplace, rather it will alway be replaced with an actual value first
136-
ZeroTangent()
137-
end
138-
)
139-
Expr(:kw, fname, fval)
140-
end
141-
129+
142130
# easy case exit early, can't hold references, can't be a reference.
143131
if isbitstype(primal)
132+
zfield_exprs = map(fieldnames(primal)) do fname
133+
fval = :(zero_tangent(getfield(primal, $(QuoteNode(fname)))))
134+
Expr(:kw, fname, fval)
135+
end
144136
return :($Tangent{$primal}($(Expr(:parameters, zfield_exprs...))))
145137
end
146138

147-
# hard case need to be prepared for cycic references to this, or that are contained within this
139+
# hard case need to be prepared for references to this, or that are contained within this
148140
quote
149-
counts = $count_references!(primal)
141+
counts = $count_references(primal)
142+
any_mask = $(Expr(:tuple, Expr(:parameters, map(fieldnames(primal), fieldtypes(primal)) do fname, ftype
143+
# If it is is unassigned, or if it doesn't have a concrete type, or we have multiple reference to it
144+
# then let it take any value for its tangent
145+
fdef = :(
146+
!isdefined(primal, $(QuoteNode(fname))) ||
147+
!isconcretetype($ftype) ||
148+
get(counts, $(QuoteNode(fname)), 0) > 1
149+
)
150+
Expr(:kw, fname, fdef)
151+
end...)))
152+
153+
# Construct tangents
154+
155+
# Go back and fill in tangents that were not ready
150156
end
151157

152158
## TODO rewrite below
153159
has_mutable_tangent(primal)
154-
any_mask = map(fieldnames(primal), fieldtypes(primal)) do fname, ftype
155-
# If it is is unassigned, or if it doesn't have a concrete type, let it take any value for its tangent
156-
fdef = :(!isdefined(primal, $(QuoteNode(fname))) || !isconcretetype($ftype))
157-
Expr(:kw, fname, fdef)
158-
end
160+
any_mask =
159161
:($MutableTangent{$primal}(
160162
$(Expr(:tuple, Expr(:parameters, any_mask...))),
161163
$(Expr(:tuple, Expr(:parameters, zfield_exprs...))),
@@ -184,7 +186,7 @@ function zero_tangent(x::Array{P,N}) where {P,N}
184186
end
185187

186188
###############################################
187-
count_references!(x) = count_references(IdDict{Any, Int}(), x)
189+
count_references(x) = count_references(IdDict{Any, Int}(), x)
188190
function count_references!(counts::IdDict{Any, Int}, x)
189191
isbits(x) && return counts # can't be a refernece and can't hold a reference
190192
counts[x] = get(counts, x, 0) + 1 # Increment *before* recursing

0 commit comments

Comments
 (0)