2
2
# forward #
3
3
# ##########
4
4
5
- function dualwrap {N,T,A} (duals:: AbstractArray{Dual{N,T}} , :: Type{A} , tp:: Nullable{Tape} )
5
+ function retrack_duals {N,T,A} (duals:: AbstractArray{Dual{N,T}} , :: Type{A} , tp:: Nullable{Tape} )
6
6
ts = similar (duals, Tracked{T,A})
7
7
for i in eachindex (duals)
8
8
ts[i] = Tracked (value (duals[i]), A, tp)
@@ -19,7 +19,7 @@ for A in ARRAY_TYPES
19
19
fdual = t -> fopt. f (Dual (value (t), one (V)))
20
20
duals = $ (g)(fdual, x)
21
21
tp = tape (x)
22
- out = dualwrap (duals, S, tp)
22
+ out = retrack_duals (duals, S, tp)
23
23
record! (tp, $ (g), x, out, duals)
24
24
return out
25
25
end
@@ -31,10 +31,32 @@ for A in ARRAY_TYPES
31
31
Dual (value (t2), zero (V2), one (V2)))
32
32
duals = $ (g)(fdual, x1, x2)
33
33
tp = tape (x1, x2)
34
- out = dualwrap (duals, S, tp)
34
+ out = retrack_duals (duals, S, tp)
35
35
record! (tp, $ (g), (x1, x2), out, duals)
36
36
return out
37
37
end
38
+
39
# Elementwise `$(g)` (map-like) where only the FIRST array argument is tracked.
# `g` and `A` are interpolated by the enclosing `@eval for A in ARRAY_TYPES` loop.
# Each tracked entry of `x1` is seeded with a one-epsilon Dual, so `partials(d, 1)`
# of every result dual is the derivative w.r.t. the corresponding `x1` element;
# `x2` carries no tracked values and is passed through as-is.
function Base.$(g){F,V,S}(fopt::ForwardOptimize{F},
                          x1::$(A){Tracked{V,S}},
                          x2::$(A))
    fdual = (t1, t2) -> fopt.f(Dual(value(t1), one(V)), t2)
    duals = $(g)(fdual, x1, x2)
    # Only x1 participates in the tape lookup — x2 is untracked.
    tp = tape(x1)
    out = retrack_duals(duals, S, tp)
    # Record only the tracked input; the reverse pass reads partials index 1 for it.
    record!(tp, $(g), x1, out, duals)
    return out
end
49
+
50
# Elementwise `$(g)` (map-like) where only the SECOND array argument is tracked.
# Mirror image of the (tracked, untracked) method: `x2`'s entries get the
# one-epsilon Dual seed, so `partials(d, 1)` is the derivative w.r.t. `x2`.
function Base.$(g){F,V,S}(fopt::ForwardOptimize{F},
                          x1::$(A),
                          x2::$(A){Tracked{V,S}})
    fdual = (t1, t2) -> fopt.f(t1, Dual(value(t2), one(V)))
    duals = $(g)(fdual, x1, x2)
    # Only x2 participates in the tape lookup — x1 is untracked.
    tp = tape(x2)
    out = retrack_duals(duals, S, tp)
    record!(tp, $(g), x2, out, duals)
    return out
end
38
60
end
39
61
end
40
62
@@ -58,7 +80,7 @@ for A in ARRAY_TYPES
58
80
fdual = t -> fopt. f (ndual, Dual (value (t), zero (X), one (X)))
59
81
duals = broadcast (fdual, x)
60
82
tp = tape (n, x)
61
- out = dualwrap (duals, S, tp)
83
+ out = retrack_duals (duals, S, tp)
62
84
record! (tp, broadcast, (n, x), out, duals)
63
85
return out
64
86
end
@@ -68,10 +90,30 @@ for A in ARRAY_TYPES
68
90
fdual = t -> fopt. f (Dual (value (t), one (X), zero (X)), ndual)
69
91
duals = broadcast (fdual, x)
70
92
tp = tape (n, x)
71
- out = dualwrap (duals, S, tp)
93
+ out = retrack_duals (duals, S, tp)
72
94
record! (tp, broadcast, (x, n), out, duals)
73
95
return out
74
96
end
97
+
98
# broadcast of a tracked scalar `n` against an UNTRACKED array `x`.
# `n` is seeded with a one-epsilon Dual, so partials index 1 of every result
# dual is d(out)/d(n); `x`'s elements contribute no derivative information.
function Base.broadcast{F,V,S}(fopt::ForwardOptimize{F}, n::Tracked{V,S}, x::$(A))
    ndual = Dual(value(n), one(V))
    fdual = t -> fopt.f(ndual, t)
    duals = broadcast(fdual, x)
    # Only the scalar is tracked, so the tape comes from `n` alone.
    tp = tape(n)
    out = retrack_duals(duals, S, tp)
    record!(tp, broadcast, n, out, duals)
    return out
end
107
+
108
# broadcast of an UNTRACKED array `x` against a tracked scalar `n`
# (argument order swapped relative to the method above). The seeding is the
# same: `n` carries the single epsilon, so partials index 1 is d(out)/d(n).
function Base.broadcast{F,V,S}(fopt::ForwardOptimize{F}, x::$(A), n::Tracked{V,S})
    ndual = Dual(value(n), one(V))
    fdual = t -> fopt.f(t, ndual)
    duals = broadcast(fdual, x)
    tp = tape(n)
    out = retrack_duals(duals, S, tp)
    # Only `n` is recorded as the input — `x` needs no adjoint accumulation.
    record!(tp, broadcast, n, out, duals)
    return out
end
75
117
end
76
118
77
119
# standard elementwise operations (.+, .-, .*, etc.) #
@@ -82,13 +124,29 @@ for A in ARRAY_TYPES
82
124
return broadcast (ForwardOptimize ($ (f)), x, y)
83
125
end
84
126
127
# Mixed tracked/untracked array-array elementwise operators (`.+`, `.*`, ...):
# one operand holds Tracked elements, the other is a plain array. Both simply
# forward to the generic broadcast machinery via ForwardOptimize.
@inline function Base.$(f){X<:Tracked}(x::$(A){X}, y::$(A))
    return broadcast(ForwardOptimize($(f)), x, y)
end

@inline function Base.$(f){Y<:Tracked}(x::$(A), y::$(A){Y})
    return broadcast(ForwardOptimize($(f)), x, y)
end
134
+
85
135
# Tracked scalar combined with a tracked-element array: dispatch to the
# broadcast implementation, which seeds duals for both operands.
@inline function Base.$(f){T<:Tracked}(n::Tracked, x::$(A){T})
    return broadcast(ForwardOptimize($(f)), n, x)
end

@inline function Base.$(f){T<:Tracked}(x::$(A){T}, n::Tracked)
    return broadcast(ForwardOptimize($(f)), x, n)
end
142
+
143
# Tracked scalar combined with an UNTRACKED array: dispatch to the
# scalar-tracking broadcast methods, which record only `n` on the tape.
@inline function Base.$(f)(n::Tracked, x::$(A))
    return broadcast(ForwardOptimize($(f)), n, x)
end

@inline function Base.$(f)(x::$(A), n::Tracked)
    return broadcast(ForwardOptimize($(f)), x, n)
end
92
150
end
93
151
for R in REAL_TYPES
94
152
@eval begin
@@ -138,26 +196,34 @@ function special_reverse_step!{A,B}(::typeof(broadcast), inputs::Tuple{A,B}, out
138
196
if size (a) == size (b)
139
197
special_reverse_step! (map, inputs, output, duals)
140
198
else
141
- for i in eachindex (duals)
142
- duals[i] *= adjoint (output[i])
143
- end
144
- s = sumover (1 , a, duals)
145
- increment_adjoint! (a, s)
146
- increment_adjoint! (b, sumover (2 , b, duals))
199
+ broadcast_adjoint_reduce! (a, output, duals, 1 )
200
+ broadcast_adjoint_reduce! (b, output, duals, 2 )
147
201
end
148
202
return nothing
149
203
end
150
204
151
- # Inference here is pretty wonky (see JuliaLang/julia#10533),
152
- # so it's important that we allocate the array for the sum
153
- # result ourselves. Otherwise, `reducedim_init` tries to
154
- # allocate an array of the wrong type in some cases, which
155
- # leads to conversion errors.
156
- function sumover {N,M,T} (p, x:: AbstractArray , duals:: AbstractArray{Dual{N,T},M} )
157
- dims = (size (x, i) != size (duals, i) ? 1 : size (duals, i) for i in 1 : ndims (duals))
158
- result = similar (duals, T, (dims... ):: NTuple{M,Int} )
159
- sum! (d -> partials (d, p), result, duals)
160
- return result
205
# Reverse pass for a recorded broadcast whose single recorded input is a scalar.
# Partials index 1 matches the one-epsilon seeding used by the scalar-tracking
# broadcast methods at record time.
function special_reverse_step!(::typeof(broadcast), input::Number, output, duals)
    broadcast_adjoint_reduce!(input, output, duals, 1)
    return nothing
end
209
+
210
# This strategy should be pretty fast, but it might be prone to numerical error if the
# accumulated adjoint becomes too large compared to the individual terms being added to
# it. This can be overcome by using the divide-and-conquer strategy used by
# Base.mapreducedim, but that strategy is less cache efficient and more complicated to
# implement.
#
# Accumulate `adjoint(output[i]) * partials(duals[i], p)` into `input`'s adjoint,
# reducing over any dimensions along which `input` was broadcast to produce `output`.
function broadcast_adjoint_reduce!{T,N}(input::AbstractArray, output::AbstractArray{T,N}, duals, p)
    # For dimensions where `input`'s extent differs from `duals`' (i.e. `input`
    # was broadcast along a singleton dimension), clamp the index to 1 so every
    # output element maps back to the source element that produced it.
    dims = (size(input, i) != size(duals, i) ? 1 : size(duals, i) for i in 1:ndims(duals))
    max_input_index = CartesianIndex((dims...)::NTuple{N,Int})
    # Iterate over the FULL broadcast extent (`duals`/`output`), not `size(input)`:
    # this helper is called precisely when shapes differ, and sweeping only
    # `input`'s extent would skip the broadcast-produced elements, dropping their
    # contributions to `input`'s adjoint.
    for i in CartesianRange(size(duals))
        increment_adjoint!(input[min(max_input_index, i)], adjoint(output[i]) * partials(duals[i], p))
    end
    return nothing
end
162
223
163
- sumover (p, x:: Real , duals) = sum (d -> partials (d, p), duals)
224
# Scalar-input counterpart: a broadcast scalar contributes to every output
# element, so every element of `duals` folds into the single adjoint of `input`.
function broadcast_adjoint_reduce!{T,N}(input::Number, output::AbstractArray{T,N}, duals, p)
    for i in eachindex(duals)
        increment_adjoint!(input, adjoint(output[i]) * partials(duals[i], p))
    end
    return nothing
end
0 commit comments