2
2
# forward #
3
3
# ##########
4
4
5
+ const A_MUL_B_FUNCS = (:* , :A_mul_Bt , :At_mul_B , :At_mul_Bt , :A_mul_Bc , :Ac_mul_B , :Ac_mul_Bc )
6
+
7
+ const A_MUL_B!_FUNCS = ((:A_mul_B! , :* ),
8
+ (:A_mul_Bt! , :A_mul_Bt ), (:At_mul_B! , :At_mul_B ), (:At_mul_Bt! , :At_mul_Bt ),
9
+ (:A_mul_Bc! , :A_mul_Bc ), (:Ac_mul_B! , :Ac_mul_B ), (:Ac_mul_Bc! , :Ac_mul_Bc ))
10
+
5
11
for A in ARRAY_TYPES
6
12
7
13
# addition/subtraction #
@@ -25,6 +31,29 @@ for A in ARRAY_TYPES
25
31
record! (tp, $ (f), (x, y), out)
26
32
return out
27
33
end
34
+
35
+ if f != :-
36
+ @eval function Base. $ (f){X,S}(x:: $ (A){Tracked{X,S}}, y:: $ (A))
37
+ tp = tape (x)
38
+ out = track ($ (f)(value (x), y), S, tp)
39
+ record! (tp, $ (f), x, out)
40
+ return out
41
+ end
42
+ end
43
+
44
+ @eval function Base. $ (f){Y,S}(x:: $ (A), y:: $ (A){Tracked{Y,S}})
45
+ tp = tape (y)
46
+ out = track ($ (f)(x, value (y)), S, tp)
47
+ record! (tp, $ (f), y, out)
48
+ return out
49
+ end
50
+ end
51
+
52
+ @eval function Base.:- {X,S}(x:: $ (A){Tracked{X,S}}, y:: $ (A))
53
+ tp = tape (x)
54
+ out = track (value (x) - y, S, tp)
55
+ record! (tp, + , x, out)
56
+ return out
28
57
end
29
58
30
59
@eval function Base.:- {V,S}(x:: $ (A){Tracked{V,S}})
@@ -37,22 +66,34 @@ for A in ARRAY_TYPES
37
66
# A_mul_B family #
38
67
# ----------------#
39
68
40
- for f in (:* ,
41
- :A_mul_Bt , :At_mul_B , :At_mul_Bt ,
42
- :A_mul_Bc , :Ac_mul_B , :Ac_mul_Bc )
69
+ for f in A_MUL_B_FUNCS
43
70
@eval function Base. $ (f){X,Y,S}(x:: $ (A){Tracked{X,S}}, y:: $ (A){Tracked{Y,S}})
44
71
tp = tape (x, y)
45
72
xval, yval = value (x), value (y)
46
73
out = track ($ (f)(xval, yval), S, tp)
47
74
record! (tp, $ (f), (x, y), out, (xval, yval))
48
75
return out
49
76
end
77
+
78
+ @eval function Base. $ (f){X,S}(x:: $ (A){Tracked{X,S}}, y:: $ (A))
79
+ tp = tape (x)
80
+ xval = value (x)
81
+ out = track ($ (f)(xval, y), S, tp)
82
+ record! (tp, $ (f), (x, nothing ), out, y)
83
+ return out
84
+ end
85
+
86
+ @eval function Base. $ (f){Y,S}(x:: $ (A), y:: $ (A){Tracked{Y,S}})
87
+ tp = tape (y)
88
+ yval = value (y)
89
+ out = track ($ (f)(x, yval), S, tp)
90
+ record! (tp, $ (f), (nothing , y), out, x)
91
+ return out
92
+ end
50
93
end
51
94
52
95
# in-place versions
53
- for (f!, f) in ((:A_mul_B! , :* ),
54
- (:A_mul_Bt! , :A_mul_Bt ), (:At_mul_B! , :At_mul_B ), (:At_mul_Bt! , :At_mul_Bt ),
55
- (:A_mul_Bc! , :A_mul_Bc ), (:Ac_mul_B! , :Ac_mul_B ), (:Ac_mul_Bc! , :Ac_mul_Bc ))
96
+ for (f!, f) in A_MUL_B!_FUNCS
56
97
@eval function Base. $ (f!){V,X,Y,S}(out:: $ (A){Tracked{V,S}},
57
98
x:: $ (A){Tracked{X,S}},
58
99
y:: $ (A){Tracked{Y,S}})
@@ -62,6 +103,26 @@ for A in ARRAY_TYPES
62
103
record! (tp, $ (f), (x, y), out, (xval, yval))
63
104
return out
64
105
end
106
+
107
+ @eval function Base. $ (f!){V,X,S}(out:: $ (A){Tracked{V,S}},
108
+ x:: $ (A){Tracked{X,S}},
109
+ y:: $ (A))
110
+ tp = tape (x)
111
+ xval = value (x)
112
+ track! (out, $ (f)(xval, y), tp)
113
+ record! (tp, $ (f), (x, nothing ), out, y)
114
+ return out
115
+ end
116
+
117
+ @eval function Base. $ (f!){V,Y,S}(out:: $ (A){Tracked{V,S}},
118
+ x:: $ (A),
119
+ y:: $ (A){Tracked{Y,S}})
120
+ tp = tape (y)
121
+ yval = value (y)
122
+ track! (out, $ (f)(x, yval), tp)
123
+ record! (tp, $ (f), (nothing , y), out, x)
124
+ return out
125
+ end
65
126
end
66
127
67
128
# linear algebra #
@@ -107,8 +168,8 @@ function special_reverse_step!{A,B}(::typeof(+), inputs::Tuple{A,B}, output::Abs
107
168
return nothing
108
169
end
109
170
110
- function special_reverse_step! (:: typeof (- ), input, output, _)
111
- extract_and_decrement_adjoint ! (input, output)
171
+ function special_reverse_step! (:: typeof (+ ), input, output, _)
172
+ extract_and_increment_adjoint ! (input, output)
112
173
return nothing
113
174
end
114
175
@@ -118,9 +179,16 @@ function special_reverse_step!{A,B}(::typeof(-), inputs::Tuple{A,B}, output::Abs
118
179
return nothing
119
180
end
120
181
182
+ function special_reverse_step! (:: typeof (- ), input, output, _)
183
+ extract_and_decrement_adjoint! (input, output)
184
+ return nothing
185
+ end
186
+
121
187
# A_mul_B family #
122
188
# ----------------#
123
189
190
+ # *
191
+
124
192
function special_reverse_step! {A,B} (:: typeof (* ), inputs:: Tuple{A,B} , output, vals)
125
193
a, b = inputs
126
194
aval, bval = vals
@@ -130,6 +198,24 @@ function special_reverse_step!{A,B}(::typeof(*), inputs::Tuple{A,B}, output, val
130
198
return nothing
131
199
end
132
200
201
+ function special_reverse_step! {T} (:: typeof (* ), inputs:: Tuple{T,Void} , output, vals)
202
+ a, _ = inputs
203
+ bval = vals
204
+ output_adjoint = adjoint (output)
205
+ increment_adjoint! (a, output_adjoint * bval' )
206
+ return nothing
207
+ end
208
+
209
+ function special_reverse_step! {T} (:: typeof (* ), inputs:: Tuple{Void,T} , output, vals)
210
+ _, b = inputs
211
+ aval = vals
212
+ output_adjoint = adjoint (output)
213
+ increment_adjoint! (b, aval' * output_adjoint)
214
+ return nothing
215
+ end
216
+
217
+ # A_mul_Bt
218
+
133
219
function special_reverse_step! {A,B} (:: typeof (A_mul_Bt), inputs:: Tuple{A,B} , output, vals)
134
220
a, b = inputs
135
221
aval, bval = vals
@@ -139,6 +225,24 @@ function special_reverse_step!{A,B}(::typeof(A_mul_Bt), inputs::Tuple{A,B}, outp
139
225
return nothing
140
226
end
141
227
228
+ function special_reverse_step! {T} (:: typeof (A_mul_Bt), inputs:: Tuple{T,Void} , output, vals)
229
+ a, _ = inputs
230
+ bval = vals
231
+ output_adjoint = adjoint (output)
232
+ increment_adjoint! (a, output_adjoint * bval)
233
+ return nothing
234
+ end
235
+
236
+ function special_reverse_step! {T} (:: typeof (A_mul_Bt), inputs:: Tuple{Void,T} , output, vals)
237
+ _, b = inputs
238
+ aval = vals
239
+ output_adjoint = adjoint (output)
240
+ increment_adjoint! (b, output_adjoint.' * aval)
241
+ return nothing
242
+ end
243
+
244
+ # At_mul_B
245
+
142
246
function special_reverse_step! {A,B} (:: typeof (At_mul_B), inputs:: Tuple{A,B} , output, vals)
143
247
a, b = inputs
144
248
aval, bval = vals
@@ -148,6 +252,24 @@ function special_reverse_step!{A,B}(::typeof(At_mul_B), inputs::Tuple{A,B}, outp
148
252
return nothing
149
253
end
150
254
255
+ function special_reverse_step! {T} (:: typeof (At_mul_B), inputs:: Tuple{T,Void} , output, vals)
256
+ a, _ = inputs
257
+ bval = vals
258
+ output_adjoint = adjoint (output)
259
+ increment_adjoint! (a, bval * output_adjoint.' )
260
+ return nothing
261
+ end
262
+
263
+ function special_reverse_step! {T} (:: typeof (At_mul_B), inputs:: Tuple{Void,T} , output, vals)
264
+ _, b = inputs
265
+ aval = vals
266
+ output_adjoint = adjoint (output)
267
+ increment_adjoint! (b, aval * output_adjoint)
268
+ return nothing
269
+ end
270
+
271
+ # At_mul_Bt
272
+
151
273
function special_reverse_step! {A,B} (:: typeof (At_mul_Bt), inputs:: Tuple{A,B} , output, vals)
152
274
a, b = inputs
153
275
aval, bval = vals
@@ -157,6 +279,24 @@ function special_reverse_step!{A,B}(::typeof(At_mul_Bt), inputs::Tuple{A,B}, out
157
279
return nothing
158
280
end
159
281
282
+ function special_reverse_step! {T} (:: typeof (At_mul_Bt), inputs:: Tuple{T,Void} , output, vals)
283
+ a, _ = inputs
284
+ bval = vals
285
+ output_adjoint = adjoint (output)
286
+ increment_adjoint! (a, (output_adjoint * bval).' )
287
+ return nothing
288
+ end
289
+
290
+ function special_reverse_step! {T} (:: typeof (At_mul_Bt), inputs:: Tuple{Void,T} , output, vals)
291
+ _, b = inputs
292
+ aval = vals
293
+ output_adjoint = adjoint (output)
294
+ increment_adjoint! (b, (aval * output_adjoint).' )
295
+ return nothing
296
+ end
297
+
298
+ # A_mul_Bc
299
+
160
300
function special_reverse_step! {A,B} (:: typeof (A_mul_Bc), inputs:: Tuple{A,B} , output, vals)
161
301
a, b = inputs
162
302
aval, bval = vals
@@ -166,6 +306,24 @@ function special_reverse_step!{A,B}(::typeof(A_mul_Bc), inputs::Tuple{A,B}, outp
166
306
return nothing
167
307
end
168
308
309
+ function special_reverse_step! {T} (:: typeof (A_mul_Bc), inputs:: Tuple{T,Void} , output, vals)
310
+ a, _ = inputs
311
+ bval = vals
312
+ output_adjoint = adjoint (output)
313
+ increment_adjoint! (a, output_adjoint * bval)
314
+ return nothing
315
+ end
316
+
317
+ function special_reverse_step! {T} (:: typeof (A_mul_Bc), inputs:: Tuple{Void,T} , output, vals)
318
+ _, b = inputs
319
+ aval = vals
320
+ output_adjoint = adjoint (output)
321
+ increment_adjoint! (b, output_adjoint' * aval)
322
+ return nothing
323
+ end
324
+
325
+ # Ac_mul_B
326
+
169
327
function special_reverse_step! {A,B} (:: typeof (Ac_mul_B), inputs:: Tuple{A,B} , output, vals)
170
328
a, b = inputs
171
329
aval, bval = vals
@@ -175,6 +333,24 @@ function special_reverse_step!{A,B}(::typeof(Ac_mul_B), inputs::Tuple{A,B}, outp
175
333
return nothing
176
334
end
177
335
336
+ function special_reverse_step! {T} (:: typeof (Ac_mul_B), inputs:: Tuple{T,Void} , output, vals)
337
+ a, _ = inputs
338
+ bval = vals
339
+ output_adjoint = adjoint (output)
340
+ increment_adjoint! (a, bval * output_adjoint' )
341
+ return nothing
342
+ end
343
+
344
+ function special_reverse_step! {T} (:: typeof (Ac_mul_B), inputs:: Tuple{Void,T} , output, vals)
345
+ _, b = inputs
346
+ aval = vals
347
+ output_adjoint = adjoint (output)
348
+ increment_adjoint! (b, aval * output_adjoint)
349
+ return nothing
350
+ end
351
+
352
+ # Ac_mul_Bc
353
+
178
354
function special_reverse_step! {A,B} (:: typeof (Ac_mul_Bc), inputs:: Tuple{A,B} , output, vals)
179
355
a, b = inputs
180
356
aval, bval = vals
@@ -184,6 +360,22 @@ function special_reverse_step!{A,B}(::typeof(Ac_mul_Bc), inputs::Tuple{A,B}, out
184
360
return nothing
185
361
end
186
362
363
+ function special_reverse_step! {T} (:: typeof (Ac_mul_Bc), inputs:: Tuple{T,Void} , output, vals)
364
+ a, _ = inputs
365
+ bval = vals
366
+ output_adjoint = adjoint (output)
367
+ increment_adjoint! (a, (output_adjoint * bval)' )
368
+ return nothing
369
+ end
370
+
371
+ function special_reverse_step! {T} (:: typeof (Ac_mul_Bc), inputs:: Tuple{Void,T} , output, vals)
372
+ _, b = inputs
373
+ aval = vals
374
+ output_adjoint = adjoint (output)
375
+ increment_adjoint! (b, (aval * output_adjoint)' )
376
+ return nothing
377
+ end
378
+
187
379
# special functions #
188
380
# -------------------#
189
381
0 commit comments