Skip to content

Commit 8514b3e

Browse files
committed
WIP: premapping based cyclic type zeros
1 parent f9ade6e commit 8514b3e

File tree

2 files changed

+52
-10
lines changed

2 files changed

+52
-10
lines changed

src/tangent_types/abstract_zero.jl

+43-1
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,19 @@ end
138138
)
139139
Expr(:kw, fname, fval)
140140
end
141-
return if has_mutable_tangent(primal)
141+
142+
# easy case exit early, can't hold references, can't be a reference.
143+
if isbitstype(primal)
144+
return :($Tangent{$primal}($(Expr(:parameters, zfield_exprs...))))
145+
end
146+
147+
# hard case need to be prepared for cycic references to this, or that are contained within this
148+
quote
149+
counts = $count_references!(primal)
150+
end
151+
152+
## TODO rewrite below
153+
has_mutable_tangent(primal)
142154
any_mask = map(fieldnames(primal), fieldtypes(primal)) do fname, ftype
143155
# If it is is unassigned, or if it doesn't have a concrete type, let it take any value for its tangent
144156
fdef = :(!isdefined(primal, $(QuoteNode(fname))) || !isconcretetype($ftype))
@@ -171,6 +183,36 @@ function zero_tangent(x::Array{P,N}) where {P,N}
171183
return y
172184
end
173185

186+
###############################################
187+
count_references!(x) = count_references(IdDict{Any, Int}(), x)
188+
function count_references!(counts::IdDict{Any, Int}, x)
189+
isbits(x) && return counts # can't be a refernece and can't hold a reference
190+
counts[x] = get(counts, x, 0) + 1 # Increment *before* recursing
191+
if counts[x] == 1 # Only recurse the first time
192+
for ii in fieldcount(typeof(x))
193+
field = getfield(x, ii)
194+
count_references!(counts, field)
195+
end
196+
end
197+
return counts
198+
end
199+
200+
function count_references!(counts::IdDict{Any, Int}, x::Array)
201+
counts[x] = get(counts, x, 0) + 1 # increment before recursing
202+
isbitstype(eltype(x)) && return counts # no need to look inside, it can't hold references
203+
if counts[x] == 1 # only recurse the first time
204+
for ele in x
205+
count_references!(counts, ele)
206+
end
207+
end
208+
return counts
209+
end
210+
211+
count_references!(counts::IdDict{Any, Int}, ::DataType) = counts
212+
213+
###############################################
214+
215+
174216
# Sad heauristic methods we need because of unassigned values
175217
guess_zero_tangent_type(::Type{T}) where {T<:Number} = T
176218
guess_zero_tangent_type(::Type{T}) where {T<:Integer} = typeof(float(zero(T)))

test/tangent_types/abstract_zero.jl

+9-9
Original file line numberDiff line numberDiff line change
@@ -303,24 +303,24 @@ end
303303
lk = Link(1.5)
304304
lk.next = lk
305305

306-
@test_broken d = zero_tangent(lk)
307-
@test_broken d.data == 0.0
308-
@test_broken d.next === d
306+
d = zero_tangent(lk)
307+
@test d.data == 0.0
308+
@test d.next === d
309309

310310
struct CarryingArray
311311
x::Vector
312312
end
313313
ca = CarryingArray(Any[1.5])
314314
push!(ca.x, ca)
315-
@test_broken d_ca = zero_tangent(ca)
316-
@test_broken d_ca[1] == 0.0
317-
@test_broken d_ca[2] === _ca
315+
@test d_ca = zero_tangent(ca)
316+
@test d_ca[1] == 0.0
317+
@test d_ca[2] === _ca
318318

319319
# Idea: check if typeof(xs) <: eltype(xs), if so need to cache it before computing
320320
xs = Any[1.5]
321321
push!(xs, xs)
322-
@test_broken d_xs = zero_tangent(xs)
323-
@test_broken d_xs[1] == 0.0
324-
@test_broken d_xs[2] == d_xs
322+
@test d_xs = zero_tangent(xs)
323+
@test d_xs[1] == 0.0
324+
@test d_xs[2] == d_xs
325325
end
326326
end

0 commit comments

Comments
 (0)