@@ -34,50 +34,130 @@ xtuple(xs...) = xcall(:tuple, xs...)
34
34
35
35
# `concrete` maps a type to the type it returns for it:
#   * a `DataType` is returned unchanged,
#   * `Type{T}` is mapped to `typeof(T)`,
#   * anything else falls back to `Any`.

function concrete(T::DataType)
  return T
end

function concrete(::Type{Type{T}}) where {T}
  return typeof(T)
end

# The fallback never inspects its argument, so it is not specialised on it.
function concrete(@nospecialize(_))
  return Any
end
38
38
39
39
# True when `b` is the first block (id 1) or the last block (id equal to the
# number of blocks) of the IR it belongs to.
function runonce(b)
  nblocks = length(b.ir.blocks)
  return b.id == 1 || b.id == nblocks
end
40
40
41
# TODO use a more efficient algorithm such as Johnson (1975)
# https://epubs.siam.org/doi/abs/10.1137/0204007

# Does block `bid` lie on a cycle, i.e. can it be reached again from itself by
# following successor edges of `cfg`? `visited` is scratch space that the search
# mutates; callers that reuse one set must empty it between queries.
function self_reaching(cfg, bid, visited = BitSet())
  return reaches(cfg, bid, bid, visited)
end

# Depth-first search over `cfg[from]` (the successors of `from`) looking for
# `to`. Every successor that is explored is added to `visited` first, so no
# block is expanded twice. `from` itself is not marked as visited.
function reaches(cfg, from, to, visited)
  for next in cfg[from]
    next === to && return true
    next in visited && continue
    push!(visited, next)
    reaches(cfg, next, to, visited) && return true
  end
  return false
end
57
+
41
58
# Rewrite the primal IR (`adj.primal`) so that it also hands back every value
# the adjoint needs, and describe how each of those values has to be read.
#
# For every alpha used by an adjoint block (`alphauses(block(adj.adjoint, id))`)
# one entry is pushed to `recs` (the values captured in the pullback tuple) and,
# at the same position, one `(block id, α, uses_stack)` entry is pushed to `stks`.
# Blocks are visited last-to-first and each pending `(α, bid)` pair is handled
# in one of three ways:
#   1. the visited block is in `doms[bid]`: the value is recorded directly;
#   2. block `bid` can reach itself (it is in a loop): a stack is created and
#      the value is pushed onto it inside block `bid`;
#   3. otherwise: α is threaded backwards as a block argument and the pair is
#      kept so that it is revisited in earlier blocks.
#
# `F` only parameterises the `Pullback` type that is constructed.
# Returns `(pr, stks)`, where `pr` is the mutated primal IR.
function forward_stacks!(adj, F)
  stks, recs = Tuple{Int, Alpha, Bool}[], Variable[]
  pr = adj.primal
  blks = blocks(pr)
  last_block = length(blks)
  cfg = IRTools.CFG(pr)
  cfgᵀ = cfg'
  doms = IRTools.dominators(cfg)

  # in_loop[i] is true when block i can reach itself through the CFG.
  # A single BitSet is reused as scratch space and cleared before every query.
  reaching_visited = BitSet()
  in_loop = map(1:last_block) do b
    empty!(reaching_visited)
    self_reaching(cfg, b, reaching_visited)
  end
  # Block argument through which a threaded alpha arrives in the last block.
  alphavars = Dict{Alpha, Variable}()
  # Work list of `α => id of the block whose adjoint uses α`.
  alpha_blocks = [α => b.id for b in blks for α in alphauses(block(adj.adjoint, b.id))]
  for b in Iterators.reverse(blks)
    # The closure returns `true` to keep a pair for earlier blocks and `false`
    # once the pair has been fully handled.
    filter!(alpha_blocks) do (α, bid)
      if b.id in doms[bid]
        # If a block dominates this block, α is guaranteed to be present here
        αvar = Variable(α)
        for br in branches(b)
          map!(a -> a === α ? αvar : a, br.args, br.args)
        end
        push!(recs, b.id === last_block ? αvar : alphavars[α])
        push!(stks, (bid, α, false))
      elseif in_loop[bid]
        # This block is in a loop, so we're forced to insert stacks
        # Note: all alphas in loops will have stacks after the first iteration
        stk = pushfirst!(pr, xstack(Any))
        push!(recs, stk)
        push!(block(pr, bid), xcall(Zygote, :_push!, stk, Variable(α)))
        push!(stks, (bid, α, true))
      else
        # Fallback case, propagate alpha back through the CFG
        argvar = nothing
        if b.id > 1
          # Need to make sure all predecessors have a branch to add arguments to
          IRTools.explicitbranch!(b)
          argvar = argument!(b, insert=false)
        end
        if b.id === last_block
          # This alpha has been threaded all the way through to the exit block
          alphavars[α] = argvar
        end
        for br in branches(b)
          map!(a -> a === α ? argvar : a, br.args, br.args)
        end
        for pred in cfgᵀ[b.id]
          pred >= b.id && continue # TODO is this needed?
          pred_branches = branches(block(pr, pred))
          idx = findfirst(br -> br.block === b.id, pred_branches)
          if idx === nothing
            # `error` already throws, so no surrounding `throw` is needed.
            error("Predecessor $pred of block $(b.id) has no branch to $(b.id)")
          end
          branch_here = pred_branches[idx]
          push!(branch_here.args, α)
        end
        # We're not done with this alpha yet, revisit in predecessors
        return true
      end
      return false
    end
    # Prune any alphas that don't exist on this path through the CFG
    for br in branches(b)
      map!(a -> a isa Alpha ? nothing : a, br.args, br.args)
    end
  end
  # Internal invariant: every pair must have been resolved by now.
  @assert isempty(alpha_blocks)

  rec = push!(pr, xtuple(recs...))
  # Pullback{F,Any} reduces specialisation
  P = length(pr.blocks) == 1 ? Pullback{F} : Pullback{F,Any}
  rec = push!(pr, Expr(:call, P, rec))
  # Replace the first argument of the last branch of the last block (the value
  # being returned) with the tuple `(original value, pullback)`.
  ret = xtuple(pr.blocks[end].branches[end].args[1], rec)
  ret = push!(pr, ret)
  pr.blocks[end].branches[end].args[1] = ret
  return pr, stks
end
64
137
138
# Helps constrain pullback function type in the backwards pass
# If we had the type, we could make this a PiNode
#
# Returns `x` unchanged unless it is `nothing`, in which case an
# `ErrorException` is raised. The message is non-empty so that a failure in
# generated code can be traced back to this check.
notnothing(::Nothing) = error("notnothing: got `nothing` where a recorded value was required")
notnothing(x) = x
142
+
65
143
function reverse_stacks! (adj, stks)
66
144
ir = adj. adjoint
67
- entry = blocks (ir)[end ]
145
+ blcks = blocks (ir)
146
+ entry = blcks[end ]
68
147
self = argument! (entry, at = 1 )
69
- t = pushfirst! (blocks (ir)[end ], xcall (:getfield , self, QuoteNode (:t )))
70
- repl = Dict ()
71
- runonce (b) = b. id in (1 , length (ir. blocks))
72
- for b in blocks (ir)
73
- for (i, (b′, α)) in enumerate (stks)
148
+ t = pushfirst! (entry, xcall (:getfield , self, QuoteNode (:t )))
149
+ repl = Dict {Alpha,Variable} ()
150
+ for b in blcks
151
+ for (i, (b′, α, use_stack)) in enumerate (stks)
74
152
b. id == b′ || continue
75
- if runonce (b)
76
- val = insertafter! (ir, t, xcall (:getindex , t, i))
77
- else
78
- stk = push! (entry, xcall (:getindex , t, i))
79
- stk = push! (entry, xcall (Zygote, :Stack , stk))
153
+ # i.e. recs[i] from forward_stacks!
154
+ val = insertafter! (ir, t, xcall (:getindex , t, i))
155
+ if use_stack
156
+ stk = push! (entry, xcall (Zygote, :Stack , val))
80
157
val = pushfirst! (b, xcall (:pop! , stk))
158
+ elseif ! runonce (b)
159
+ # The first and last blocks always run, so this check is redundant there
160
+ val = pushfirst! (b, xcall (Zygote, :notnothing , val))
81
161
end
82
162
repl[α] = val
83
163
end
87
167
88
168
function stacks! (adj, T)
89
169
forw, stks = forward_stacks! (adj, T)
170
+ IRTools. domorder! (forw)
90
171
back = reverse_stacks! (adj, stks)
91
172
permute! (back, length (back. blocks): - 1 : 1 )
92
173
IRTools. domorder! (back)
@@ -97,6 +178,7 @@ varargs(m::Method, n) = m.isva ? n - m.nargs + 1 : nothing
97
178
98
179
function _generate_pullback_via_decomposition (T)
99
180
(m = meta (T)) === nothing && return
181
+ # Core.println("decomp: ", T)
100
182
va = varargs (m. method, length (T. parameters))
101
183
forw, back = stacks! (Adjoint (IR (m), varargs = va, normalise = false ), T)
102
184
m, forw, back
0 commit comments