|
138 | 138 | )
|
139 | 139 | Expr(:kw, fname, fval)
|
140 | 140 | end
|
141 |
| - return if has_mutable_tangent(primal) |
| 141 | + |
| 142 | + # easy case exit early, can't hold references, can't be a reference. |
| 143 | + if isbitstype(primal) |
| 144 | + return :($Tangent{$primal}($(Expr(:parameters, zfield_exprs...)))) |
| 145 | + end |
| 146 | + |
| 147 | + # hard case need to be prepared for cycic references to this, or that are contained within this |
| 148 | + quote |
| 149 | + counts = $count_references!(primal) |
| 150 | + end |
| 151 | + |
| 152 | +## TODO rewrite below |
| 153 | + has_mutable_tangent(primal) |
142 | 154 | any_mask = map(fieldnames(primal), fieldtypes(primal)) do fname, ftype
|
143 | 155 | # If it is is unassigned, or if it doesn't have a concrete type, let it take any value for its tangent
|
144 | 156 | fdef = :(!isdefined(primal, $(QuoteNode(fname))) || !isconcretetype($ftype))
|
@@ -171,6 +183,36 @@ function zero_tangent(x::Array{P,N}) where {P,N}
|
171 | 183 | return y
|
172 | 184 | end
|
173 | 185 |
|
| 186 | +############################################### |
| 187 | +count_references!(x) = count_references(IdDict{Any, Int}(), x) |
| 188 | +function count_references!(counts::IdDict{Any, Int}, x) |
| 189 | + isbits(x) && return counts # can't be a refernece and can't hold a reference |
| 190 | + counts[x] = get(counts, x, 0) + 1 # Increment *before* recursing |
| 191 | + if counts[x] == 1 # Only recurse the first time |
| 192 | + for ii in fieldcount(typeof(x)) |
| 193 | + field = getfield(x, ii) |
| 194 | + count_references!(counts, field) |
| 195 | + end |
| 196 | + end |
| 197 | + return counts |
| 198 | +end |
| 199 | + |
| 200 | +function count_references!(counts::IdDict{Any, Int}, x::Array) |
| 201 | + counts[x] = get(counts, x, 0) + 1 # increment before recursing |
| 202 | + isbitstype(eltype(x)) && return counts # no need to look inside, it can't hold references |
| 203 | + if counts[x] == 1 # only recurse the first time |
| 204 | + for ele in x |
| 205 | + count_references!(counts, ele) |
| 206 | + end |
| 207 | + end |
| 208 | + return counts |
| 209 | +end |
| 210 | + |
| 211 | +count_references!(counts::IdDict{Any, Int}, ::DataType) = counts |
| 212 | + |
| 213 | +############################################### |
| 214 | + |
| 215 | + |
174 | 216 | # Sad heauristic methods we need because of unassigned values
|
175 | 217 | guess_zero_tangent_type(::Type{T}) where {T<:Number} = T
|
176 | 218 | guess_zero_tangent_type(::Type{T}) where {T<:Integer} = typeof(float(zero(T)))
|
|
0 commit comments