Skip to content

Commit cf7acfb

Browse files
ToucheSir and Keno committed
Elide stack generation outside of non-looping control flow
Co-authored-by: Keno Fischer <[email protected]>
1 parent 843a52d commit cf7acfb

File tree

6 files changed

+156
-34
lines changed

6 files changed

+156
-34
lines changed

Diff for: src/compiler/emit.jl

+105-23
Original file line numberDiff line numberDiff line change
@@ -34,50 +34,130 @@ xtuple(xs...) = xcall(:tuple, xs...)
3434

3535
concrete(T::DataType) = T
3636
concrete(::Type{Type{T}}) where T = typeof(T)
37-
concrete(T) = Any
37+
concrete(@nospecialize _) = Any
3838

3939
runonce(b) = b.id in (1, length(b.ir.blocks))
4040

41+
# TODO use a more efficient algorithm such as Johnson (1975)
42+
# https://epubs.siam.org/doi/abs/10.1137/0204007
43+
self_reaching(cfg, bid, visited = BitSet()) = reaches(cfg, bid, bid, visited)
44+
function reaches(cfg, from, to, visited)
45+
for succ in cfg[from]
46+
if succ === to
47+
return true
48+
elseif succ ∉ visited
49+
push!(visited, succ)
50+
if reaches(cfg, succ, to, visited)
51+
return true
52+
end
53+
end
54+
end
55+
return false
56+
end
57+
4158
function forward_stacks!(adj, F)
42-
stks, recs = [], []
59+
stks, recs = Tuple{Int, Alpha, Bool}[], Variable[]
4360
pr = adj.primal
44-
for b in blocks(pr), α in alphauses(block(adj.adjoint, b.id))
45-
if runonce(b)
46-
push!(recs, Variable(α))
47-
else
48-
stk = pushfirst!(pr, xstack(Any))
49-
push!(recs, stk)
50-
push!(b, xcall(Zygote, :_push!, stk, Variable(α)))
61+
blks = blocks(pr)
62+
last_block = length(blks)
63+
cfg = IRTools.CFG(pr)
64+
cfgᵀ = cfg'
65+
doms = IRTools.dominators(cfg)
66+
67+
reaching_visited = BitSet()
68+
in_loop = map(1:last_block) do b
69+
empty!(reaching_visited)
70+
self_reaching(cfg, b, reaching_visited)
71+
end
72+
alphavars = Dict{Alpha, Variable}()
73+
alpha_blocks = [α => b.id for b in blks for α in alphauses(block(adj.adjoint, b.id))]
74+
for b in Iterators.reverse(blks)
75+
filter!(alpha_blocks) do (α, bid)
76+
if b.id in doms[bid]
77+
# If a block dominates this block, α is guaranteed to be present here
78+
αvar = Variable(α)
79+
for br in branches(b)
80+
map!(a -> a === α ? αvar : a, br.args, br.args)
81+
end
82+
push!(recs, b.id === last_block ? αvar : alphavars[α])
83+
push!(stks, (bid, α, false))
84+
elseif in_loop[bid]
85+
# This block is in a loop, so we're forced to insert stacks
86+
# Note: all alphas in loops will have stacks after the first iteration
87+
stk = pushfirst!(pr, xstack(Any))
88+
push!(recs, stk)
89+
push!(block(pr, bid), xcall(Zygote, :_push!, stk, Variable(α)))
90+
push!(stks, (bid, α, true))
91+
else
92+
# Fallback case, propagate alpha back through the CFG
93+
argvar = nothing
94+
if b.id > 1
95+
# Need to make sure all predecessors have a branch to add arguments to
96+
IRTools.explicitbranch!(b)
97+
argvar = argument!(b, insert=false)
98+
end
99+
if b.id === last_block
100+
# This alpha has been threaded all the way through to the exit block
101+
alphavars[α] = argvar
102+
end
103+
for br in branches(b)
104+
map!(a -> a === α ? argvar : a, br.args, br.args)
105+
end
106+
for pred in cfgᵀ[b.id]
107+
pred >= b.id && continue # TODO is this needed?
108+
pred_branches = branches(block(pr, pred))
109+
idx = findfirst(br -> br.block === b.id, pred_branches)
110+
if idx === nothing
111+
throw(error("Predecessor $pred of block $(b.id) has no branch to $(b.id)"))
112+
end
113+
branch_here = pred_branches[idx]
114+
push!(branch_here.args, α)
115+
end
116+
# We're not done with this alpha yet, revisit in predecessors
117+
return true
118+
end
119+
return false
120+
end
121+
# Prune any alphas that don't exist on this path through the CFG
122+
for br in branches(b)
123+
map!(a -> a isa Alpha ? nothing : a, br.args, br.args)
51124
end
52-
push!(stks, (b.id, alpha(α)))
53125
end
54-
args = arguments(pr)[3:end]
126+
@assert isempty(alpha_blocks)
127+
55128
rec = push!(pr, xtuple(recs...))
129+
# Pullback{F,Any} reduces specialisation
56130
P = length(pr.blocks) == 1 ? Pullback{F} : Pullback{F,Any}
57-
# P = Pullback{F,Any} # reduce specialisation
58131
rec = push!(pr, Expr(:call, P, rec))
59132
ret = xtuple(pr.blocks[end].branches[end].args[1], rec)
60133
ret = push!(pr, ret)
61134
pr.blocks[end].branches[end].args[1] = ret
62135
return pr, stks
63136
end
64137

138+
# Helps constrain pullback function type in the backwards pass
139+
# If we had the type, we could make this a PiNode
140+
notnothing(::Nothing) = error()
141+
notnothing(x) = x
142+
65143
function reverse_stacks!(adj, stks)
66144
ir = adj.adjoint
67-
entry = blocks(ir)[end]
145+
blcks = blocks(ir)
146+
entry = blcks[end]
68147
self = argument!(entry, at = 1)
69-
t = pushfirst!(blocks(ir)[end], xcall(:getfield, self, QuoteNode(:t)))
70-
repl = Dict()
71-
runonce(b) = b.id in (1, length(ir.blocks))
72-
for b in blocks(ir)
73-
for (i, (b′, α)) in enumerate(stks)
148+
t = pushfirst!(entry, xcall(:getfield, self, QuoteNode(:t)))
149+
repl = Dict{Alpha,Variable}()
150+
for b in blcks
151+
for (i, (b′, α, use_stack)) in enumerate(stks)
74152
b.id == b′ || continue
75-
if runonce(b)
76-
val = insertafter!(ir, t, xcall(:getindex, t, i))
77-
else
78-
stk = push!(entry, xcall(:getindex, t, i))
79-
stk = push!(entry, xcall(Zygote, :Stack, stk))
153+
# i.e. recs[i] from forward_stacks!
154+
val = insertafter!(ir, t, xcall(:getindex, t, i))
155+
if use_stack
156+
stk = push!(entry, xcall(Zygote, :Stack, val))
80157
val = pushfirst!(b, xcall(:pop!, stk))
158+
elseif !runonce(b)
159+
# The first and last blocks always run, so this check is redundant there
160+
val = pushfirst!(b, xcall(Zygote, :notnothing, val))
81161
end
82162
repl[α] = val
83163
end
@@ -87,6 +167,7 @@ end
87167

88168
function stacks!(adj, T)
89169
forw, stks = forward_stacks!(adj, T)
170+
IRTools.domorder!(forw)
90171
back = reverse_stacks!(adj, stks)
91172
permute!(back, length(back.blocks):-1:1)
92173
IRTools.domorder!(back)
@@ -97,6 +178,7 @@ varargs(m::Method, n) = m.isva ? n - m.nargs + 1 : nothing
97178

98179
function _generate_pullback_via_decomposition(T)
99180
(m = meta(T)) === nothing && return
181+
# Core.println("decomp: ", T)
100182
va = varargs(m.method, length(T.parameters))
101183
forw, back = stacks!(Adjoint(IR(m), varargs = va, normalise = false), T)
102184
m, forw, back

Diff for: src/compiler/interface2.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@ end
3333
meta, forw, _ = g
3434
argnames!(meta, Symbol("#self#"), :ctx, :f, :args)
3535
forw = varargs!(meta, forw, 3)
36-
# IRTools.verify(forw)
36+
# verify(forw)
3737
forw = slots!(pis!(inlineable!(forw)))
3838
# be ready to swap to using chainrule if one is declared
39-
cr_edge != nothing && edge!(meta, cr_edge)
39+
cr_edge !== nothing && edge!(meta, cr_edge)
4040
return update!(meta.code, forw)
4141
end
4242

@@ -53,7 +53,7 @@ end
5353
end
5454
meta, _, back = g
5555
argnames!(meta, Symbol("#self#"), :Δ)
56-
# IRTools.verify(back)
56+
# verify(back)
5757
back = slots!(inlineable!(back))
5858
return update!(meta.code, back)
5959
end

Diff for: src/lib/lib.jl

+3
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ function accum_global(cx::Context, ref, x̄)
6969
return
7070
end
7171

72+
# Needed for nested AD
73+
@nograd accum_global
74+
7275
unwrap(x) = x
7376

7477
@adjoint unwrap(x) = unwrap(x), x̄ -> (accum_param(__context__, x, x̄),)

Diff for: test/compiler.jl

+44-7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Zygote, Test
22
using Zygote: pullback, @adjoint
3+
using IRTools
34

45
macro test_inferred(ex)
56
:(let res = nothing
@@ -30,11 +31,11 @@ y, back = pullback(badly, 2)
3031
@test_throws Exception back(1)
3132
bt = try back(1) catch e stacktrace(catch_backtrace()) end
3233

33-
@test trace_contains(bt, nothing, "compiler.jl", 20)
34+
@test trace_contains(bt, nothing, "compiler.jl", 21)
3435
if VERSION >= v"1.6-"
35-
@test_broken trace_contains(bt, :badly, "compiler.jl", 24)
36+
@test_broken trace_contains(bt, :badly, "compiler.jl", 25)
3637
else
37-
@test trace_contains(bt, :badly, "compiler.jl", 24)
38+
@test trace_contains(bt, :badly, "compiler.jl", 25)
3839
end
3940

4041
# Type inference checks
@@ -58,10 +59,9 @@ y, back = @test_inferred pullback(f, 5)
5859
y, back = @test_inferred pullback(Core._apply, +, (1, 2, 3))
5960
@test_inferred back(1)
6061

61-
# TODO fix bcast inference
62-
# bcast(x) = x .* 5
63-
# y, back = @test_inferred pullback(bcast, [1,2,3])
64-
# @test_inferred back([1,1,1])
62+
bcast(x) = x .* 5
63+
y, back = @test_inferred pullback(bcast, [1,2,3])
64+
@test_inferred back([1,1,1])
6565

6666
foo = let a = 4
6767
x -> x*a
@@ -91,6 +91,43 @@ struct Funky
9191
y
9292
end
9393

94+
@testset "stack elision" begin
95+
function stackfree(T)
96+
_, forw = Zygote._generate_pullback_via_decomposition(T)
97+
for b in IRTools.blocks(forw)
98+
bb = IRTools.BasicBlock(b)
99+
for stmt in bb.stmts
100+
expr = stmt.expr
101+
expr.head == :call && expr.args[1:2] == [Zygote, :_push!] && return false
102+
end
103+
end
104+
return true
105+
end
106+
107+
function knockoff_pow(x, n)
108+
n == 0 && return 1
109+
n == 1 && return x
110+
n == 2 && return x * x
111+
n == 3 && return x * x * x
112+
return x ^ n
113+
end
114+
115+
function roundabout_trig(x, fancy_sin, fancy_cos, fancy_tan)
116+
if fancy_tan
117+
s = fancy_sin ? inv(csc(x)) : sin(x)
118+
c = fancy_cos ? inv(sec(x)) : cos(x)
119+
s += 0
120+
c *= 1
121+
return s / c
122+
else
123+
return tan(x)
124+
end
125+
end
126+
127+
@test stackfree(Tuple{typeof(knockoff_pow), Int, Int})
128+
@test stackfree(Tuple{typeof(roundabout_trig), Float64, Bool, Bool, Bool})
129+
end
130+
94131
@testset "issue #851" begin
95132
f = Funky(1, 1);
96133
function Base.getproperty(f::Funky, i::Symbol)

Diff for: test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using Zygote, Test
22
using Zygote: gradient, ZygoteRuleConfig
33
using CUDA
44
using CUDA: has_cuda
5+
using LinearAlgebra
56

67
@testset "all" begin # Overall testset ensures it keeps running after failure
78

Diff for: test/utils.jl

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
using LinearAlgebra
21
using ForwardDiff
32
using Zygote: hessian_dual, hessian_reverse
43

0 commit comments

Comments
 (0)