Skip to content

Commit 36453ca

Browse files
ToucheSirKeno
andcommitted
Elide stack generation outside of non-looping control flow
Co-authored-by: Keno Fischer <[email protected]>
1 parent d39ab59 commit 36453ca

File tree

6 files changed

+157
-36
lines changed

6 files changed

+157
-36
lines changed

Diff for: src/compiler/emit.jl

+105-23
Original file line numberDiff line numberDiff line change
@@ -34,50 +34,130 @@ xtuple(xs...) = xcall(:tuple, xs...)
3434

3535
concrete(T::DataType) = T
3636
concrete(::Type{Type{T}}) where T = typeof(T)
37-
concrete(T) = Any
37+
concrete(@nospecialize _) = Any
3838

3939
runonce(b) = b.id in (1, length(b.ir.blocks))
4040

41+
# TODO use a more efficient algorithm such as Johnson (1975)
42+
# https://epubs.siam.org/doi/abs/10.1137/0204007
43+
self_reaching(cfg, bid, visited = BitSet()) = reaches(cfg, bid, bid, visited)
44+
function reaches(cfg, from, to, visited)
45+
for succ in cfg[from]
46+
if succ === to
47+
return true
48+
elseif succ visited
49+
push!(visited, succ)
50+
if reaches(cfg, succ, to, visited)
51+
return true
52+
end
53+
end
54+
end
55+
return false
56+
end
57+
4158
function forward_stacks!(adj, F)
42-
stks, recs = [], []
59+
stks, recs = Tuple{Int, Alpha, Bool}[], Variable[]
4360
pr = adj.primal
44-
for b in blocks(pr), α in alphauses(block(adj.adjoint, b.id))
45-
if runonce(b)
46-
push!(recs, Variable(α))
47-
else
48-
stk = pushfirst!(pr, xstack(Any))
49-
push!(recs, stk)
50-
push!(b, xcall(Zygote, :_push!, stk, Variable(α)))
61+
blks = blocks(pr)
62+
last_block = length(blks)
63+
cfg = IRTools.CFG(pr)
64+
cfgᵀ = cfg'
65+
doms = IRTools.dominators(cfg)
66+
67+
reaching_visited = BitSet()
68+
in_loop = map(1:last_block) do b
69+
empty!(reaching_visited)
70+
self_reaching(cfg, b, reaching_visited)
71+
end
72+
alphavars = Dict{Alpha, Variable}()
73+
alpha_blocks ==> b.id for b in blks for α in alphauses(block(adj.adjoint, b.id))]
74+
for b in Iterators.reverse(blks)
75+
filter!(alpha_blocks) do (α, bid)
76+
if b.id in doms[bid]
77+
# If a block dominates this block, α is guaranteed to be present here
78+
αvar = Variable(α)
79+
for br in branches(b)
80+
map!(a -> a === α ? αvar : a, br.args, br.args)
81+
end
82+
push!(recs, b.id === last_block ? αvar : alphavars[α])
83+
push!(stks, (bid, α, false))
84+
elseif in_loop[bid]
85+
# This block is in a loop, so we're forced to insert stacks
86+
# Note: all alphas in loops will have stacks after the first iteration
87+
stk = pushfirst!(pr, xstack(Any))
88+
push!(recs, stk)
89+
push!(block(pr, bid), xcall(Zygote, :_push!, stk, Variable(α)))
90+
push!(stks, (bid, α, true))
91+
else
92+
# Fallback case, propagate alpha back through the CFG
93+
argvar = nothing
94+
if b.id > 1
95+
# Need to make sure all predecessors have a branch to add arguments to
96+
IRTools.explicitbranch!(b)
97+
argvar = argument!(b, insert=false)
98+
end
99+
if b.id === last_block
100+
# This alpha has been threaded all the way through to the exit block
101+
alphavars[α] = argvar
102+
end
103+
for br in branches(b)
104+
map!(a -> a === α ? argvar : a, br.args, br.args)
105+
end
106+
for pred in cfgᵀ[b.id]
107+
pred >= b.id && continue # TODO is this needed?
108+
pred_branches = branches(block(pr, pred))
109+
idx = findfirst(br -> br.block === b.id, pred_branches)
110+
if idx === nothing
111+
throw(error("Predecessor $pred of block $(b.id) has no branch to $(b.id)"))
112+
end
113+
branch_here = pred_branches[idx]
114+
push!(branch_here.args, α)
115+
end
116+
# We're not done with this alpha yet, revisit in predecessors
117+
return true
118+
end
119+
return false
120+
end
121+
# Prune any alphas that don't exist on this path through the CFG
122+
for br in branches(b)
123+
map!(a -> a isa Alpha ? nothing : a, br.args, br.args)
51124
end
52-
push!(stks, (b.id, alpha(α)))
53125
end
54-
args = arguments(pr)[3:end]
126+
@assert isempty(alpha_blocks)
127+
55128
rec = push!(pr, xtuple(recs...))
129+
# Pullback{F,Any} reduces specialisation
56130
P = length(pr.blocks) == 1 ? Pullback{F} : Pullback{F,Any}
57-
# P = Pullback{F,Any} # reduce specialisation
58131
rec = push!(pr, Expr(:call, P, rec))
59132
ret = xtuple(pr.blocks[end].branches[end].args[1], rec)
60133
ret = push!(pr, ret)
61134
pr.blocks[end].branches[end].args[1] = ret
62135
return pr, stks
63136
end
64137

138+
# Helps constrain pullback function type in the backwards pass
139+
# If we had the type, we could make this a PiNode
140+
notnothing(::Nothing) = error()
141+
notnothing(x) = x
142+
65143
function reverse_stacks!(adj, stks)
66144
ir = adj.adjoint
67-
entry = blocks(ir)[end]
145+
blcks = blocks(ir)
146+
entry = blcks[end]
68147
self = argument!(entry, at = 1)
69-
t = pushfirst!(blocks(ir)[end], xcall(:getfield, self, QuoteNode(:t)))
70-
repl = Dict()
71-
runonce(b) = b.id in (1, length(ir.blocks))
72-
for b in blocks(ir)
73-
for (i, (b′, α)) in enumerate(stks)
148+
t = pushfirst!(entry, xcall(:getfield, self, QuoteNode(:t)))
149+
repl = Dict{Alpha,Variable}()
150+
for b in blcks
151+
for (i, (b′, α, use_stack)) in enumerate(stks)
74152
b.id == b′ || continue
75-
if runonce(b)
76-
val = insertafter!(ir, t, xcall(:getindex, t, i))
77-
else
78-
stk = push!(entry, xcall(:getindex, t, i))
79-
stk = push!(entry, xcall(Zygote, :Stack, stk))
153+
# i.e. recs[i] from forward_stacks!
154+
val = insertafter!(ir, t, xcall(:getindex, t, i))
155+
if use_stack
156+
stk = push!(entry, xcall(Zygote, :Stack, val))
80157
val = pushfirst!(b, xcall(:pop!, stk))
158+
elseif !runonce(b)
159+
# The first and last blocks always run, so this check is redundant there
160+
val = pushfirst!(b, xcall(Zygote, :notnothing, val))
81161
end
82162
repl[α] = val
83163
end
@@ -87,6 +167,7 @@ end
87167

88168
function stacks!(adj, T)
89169
forw, stks = forward_stacks!(adj, T)
170+
IRTools.domorder!(forw)
90171
back = reverse_stacks!(adj, stks)
91172
permute!(back, length(back.blocks):-1:1)
92173
IRTools.domorder!(back)
@@ -97,6 +178,7 @@ varargs(m::Method, n) = m.isva ? n - m.nargs + 1 : nothing
97178

98179
function _generate_pullback_via_decomposition(T)
99180
(m = meta(T)) === nothing && return
181+
# Core.println("decomp: ", T)
100182
va = varargs(m.method, length(T.parameters))
101183
forw, back = stacks!(Adjoint(IR(m), varargs = va, normalise = false), T)
102184
m, forw, back

Diff for: src/compiler/interface2.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@ end
3333
meta, forw, _ = g
3434
argnames!(meta, Symbol("#self#"), :ctx, :f, :args)
3535
forw = varargs!(meta, forw, 3)
36-
# IRTools.verify(forw)
36+
# verify(forw)
3737
forw = slots!(pis!(inlineable!(forw)))
3838
# be ready to swap to using chainrule if one is declared
39-
cr_edge != nothing && edge!(meta, cr_edge)
39+
cr_edge !== nothing && edge!(meta, cr_edge)
4040
return update!(meta.code, forw)
4141
end
4242

@@ -53,7 +53,7 @@ end
5353
end
5454
meta, _, back = g
5555
argnames!(meta, Symbol("#self#"), )
56-
# IRTools.verify(back)
56+
# verify(back)
5757
back = slots!(inlineable!(back))
5858
return update!(meta.code, back)
5959
end

Diff for: src/lib/lib.jl

+3
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ function accum_global(cx::Context, ref, x̄)
6868
return
6969
end
7070

71+
# Needed for nested AD
72+
@nograd accum_global
73+
7174
unwrap(x) = x
7275

7376
@adjoint unwrap(x) = unwrap(x), x̄ -> (accum_param(__context__, x, x̄),)

Diff for: test/compiler.jl

+45-9
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Zygote, Test
1+
using Zygote, IRTools, Test
22
using Zygote: pullback, @adjoint, Context
33

44
macro test_inferred(ex)
@@ -30,11 +30,11 @@ y, back = pullback(badly, 2)
3030
@test_throws Exception back(1)
3131
bt = try back(1) catch e stacktrace(catch_backtrace()) end
3232

33-
@test trace_contains(bt, nothing, "compiler.jl", 20)
33+
@test trace_contains(bt, nothing, "compiler.jl", 21)
3434
if VERSION >= v"1.6-"
35-
@test_broken trace_contains(bt, :badly, "compiler.jl", 24)
35+
@test_broken trace_contains(bt, :badly, "compiler.jl", 25)
3636
else
37-
@test trace_contains(bt, :badly, "compiler.jl", 24)
37+
@test trace_contains(bt, :badly, "compiler.jl", 25)
3838
end
3939

4040
# Type inference checks
@@ -58,10 +58,9 @@ y, back = @test_inferred pullback(f, 5)
5858
y, back = @test_inferred pullback(Core._apply, +, (1, 2, 3))
5959
@test_inferred back(1)
6060

61-
# TODO fix bcast inference
62-
# bcast(x) = x .* 5
63-
# y, back = @test_inferred pullback(bcast, [1,2,3])
64-
# @test_inferred back([1,1,1])
61+
bcast(x) = x .* 5
62+
y, back = @test_inferred pullback(bcast, [1,2,3])
63+
@test_inferred back([1,1,1])
6564

6665
foo = let a = 4
6766
x -> x*a
@@ -91,6 +90,43 @@ struct Funky
9190
y
9291
end
9392

93+
@testset "stack elision" begin
94+
function stackfree(T)
95+
_, forw = Zygote._generate_pullback_via_decomposition(T)
96+
for b in IRTools.blocks(forw)
97+
bb = IRTools.BasicBlock(b)
98+
for stmt in bb.stmts
99+
expr = stmt.expr
100+
expr.head == :call && expr.args[1:2] == [Zygote, :_push!] && return false
101+
end
102+
end
103+
return true
104+
end
105+
106+
function knockoff_pow(x, n)
107+
n == 0 && return 1
108+
n == 1 && return x
109+
n == 2 && return x * x
110+
n == 3 && return x * x * x
111+
return x ^ n
112+
end
113+
114+
function roundabout_trig(x, fancy_sin, fancy_cos, fancy_tan)
115+
if fancy_tan
116+
s = fancy_sin ? inv(csc(x)) : sin(x)
117+
c = fancy_cos ? inv(sec(x)) : cos(x)
118+
s += 0
119+
c *= 1
120+
return s / c
121+
else
122+
return tan(x)
123+
end
124+
end
125+
126+
@test stackfree(Tuple{typeof(knockoff_pow), Int, Int})
127+
@test stackfree(Tuple{typeof(roundabout_trig), Float64, Bool, Bool, Bool})
128+
end
129+
94130
@testset "issue #851" begin
95131
f = Funky(1, 1);
96132
function Base.getproperty(f::Funky, i::Symbol)
@@ -128,7 +164,7 @@ end
128164
d_two = Zygote.pullback(two_svds, X)[2](Δoutput)
129165
d_one = Zygote.pullback(one_svd, X)[2](Δoutput)
130166
@test d_one == d_two
131-
end
167+
end
132168

133169
# this test fails if adjoint for literal_getproperty is added
134170
# https://github.com/FluxML/Zygote.jl/issues/922#issuecomment-804128905

Diff for: test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using Zygote, Test
22
using Zygote: gradient, ZygoteRuleConfig
33
using CUDA
44
using CUDA: has_cuda
5+
using LinearAlgebra
56

67
@testset "all" begin # Overall testset ensures it keeps running after failure
78

Diff for: test/utils.jl

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
using LinearAlgebra
21
using ForwardDiff
32
using Zygote: hessian_dual, hessian_reverse
43

0 commit comments

Comments
 (0)