@@ -126,36 +126,38 @@ end
126
126
127
127
@generated function zero_tangent (primal)
128
128
fieldcount (primal) == 0 && return NoTangent () # no tangent space at all, no need for structural zero.
129
- zfield_exprs = map (fieldnames (primal)) do fname
130
- fval = :(
131
- if isdefined (primal, $ (QuoteNode (fname)))
132
- zero_tangent (getfield (primal, $ (QuoteNode (fname))))
133
- else
134
- # This is going to be potentially bad, but that's what they get for not giving us a primal
135
- # This will never me mutated inplace, rather it will alway be replaced with an actual value first
136
- ZeroTangent ()
137
- end
138
- )
139
- Expr (:kw , fname, fval)
140
- end
141
-
129
+
142
130
# easy case exit early, can't hold references, can't be a reference.
143
131
if isbitstype (primal)
132
+ zfield_exprs = map (fieldnames (primal)) do fname
133
+ fval = :(zero_tangent (getfield (primal, $ (QuoteNode (fname)))))
134
+ Expr (:kw , fname, fval)
135
+ end
144
136
return :($ Tangent {$primal} ($ (Expr (:parameters , zfield_exprs... ))))
145
137
end
146
138
147
- # hard case need to be prepared for cycic references to this, or that are contained within this
139
+ # hard case need to be prepared for references to this, or that are contained within this
148
140
quote
149
- counts = $ count_references! (primal)
141
+ counts = $ count_references (primal)
142
+ any_mask = $ (Expr (:tuple , Expr (:parameters , map (fieldnames (primal), fieldtypes (primal)) do fname, ftype
143
+ # If it is is unassigned, or if it doesn't have a concrete type, or we have multiple reference to it
144
+ # then let it take any value for its tangent
145
+ fdef = :(
146
+ ! isdefined (primal, $ (QuoteNode (fname))) ||
147
+ ! isconcretetype ($ ftype) ||
148
+ get (counts, $ (QuoteNode (fname)), 0 ) > 1
149
+ )
150
+ Expr (:kw , fname, fdef)
151
+ end ... )))
152
+
153
+ # Construct tangents
154
+
155
+ # Go back and fill in tangents that were not ready
150
156
end
151
157
152
158
# # TODO rewrite below
153
159
has_mutable_tangent (primal)
154
- any_mask = map (fieldnames (primal), fieldtypes (primal)) do fname, ftype
155
- # If it is is unassigned, or if it doesn't have a concrete type, let it take any value for its tangent
156
- fdef = :(! isdefined (primal, $ (QuoteNode (fname))) || ! isconcretetype ($ ftype))
157
- Expr (:kw , fname, fdef)
158
- end
160
+ any_mask =
159
161
:($ MutableTangent {$primal} (
160
162
$ (Expr (:tuple , Expr (:parameters , any_mask... ))),
161
163
$ (Expr (:tuple , Expr (:parameters , zfield_exprs... ))),
@@ -184,7 +186,7 @@ function zero_tangent(x::Array{P,N}) where {P,N}
184
186
end
185
187
186
188
# ##############################################
187
- count_references! (x) = count_references (IdDict {Any, Int} (), x)
189
+ count_references (x) = count_references (IdDict {Any, Int} (), x)
188
190
function count_references! (counts:: IdDict{Any, Int} , x)
189
191
isbits (x) && return counts # can't be a refernece and can't hold a reference
190
192
counts[x] = get (counts, x, 0 ) + 1 # Increment *before* recursing
0 commit comments