Skip to content

Commit 7c83dbd

Browse files
committed
add LinAlg derivative tests
1 parent f5219b5 commit 7c83dbd

11 files changed

+535
-171
lines changed

Diff for: src/ReverseDiffPrototype.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using ForwardDiff: Dual, Partials, partials
99
# Not all operations will be valid over all of these types, but that's okay; such cases
1010
# will simply error when they hit the original operation in the overloaded definition.
1111
const ARRAY_TYPES = (:AbstractArray, :AbstractVector, :AbstractMatrix, :Array, :Vector, :Matrix)
12-
const REAL_TYPES = (:Bool, :Integer, :Rational, :Real, :Dual)
12+
const REAL_TYPES = (:Bool, :Integer, :Rational, :AbstractFloat, :Real, :Dual)
1313

1414
const FORWARD_UNARY_SCALAR_FUNCS = (ForwardDiff.AUTO_DEFINED_UNARY_FUNCS..., :-, :abs, :conj)
1515
const FORWARD_BINARY_SCALAR_FUNCS = (:*, :/, :+, :-, :^, :atan2)

Diff for: src/Tape.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ end
2222
# Ensure that the external state is "captured" so that external
2323
# reference-breaking (e.g. destructive assignment) doesn't break
2424
# internal TapeNode state.
25-
@inline capture(state::Number) = state
25+
@inline capture(state) = state
2626
@inline capture(state::AbstractArray) = copy(state)
2727
@inline capture(state::Tuple{Vararg{Number}}) = state
2828
@inline capture(state::Tuple) = map(capture, state)

Diff for: src/derivatives/linalg.jl

+200-8
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
# forward #
33
###########
44

5+
const A_MUL_B_FUNCS = (:*, :A_mul_Bt, :At_mul_B, :At_mul_Bt, :A_mul_Bc, :Ac_mul_B, :Ac_mul_Bc)
6+
7+
const A_MUL_B!_FUNCS = ((:A_mul_B!, :*),
8+
(:A_mul_Bt!, :A_mul_Bt), (:At_mul_B!, :At_mul_B), (:At_mul_Bt!, :At_mul_Bt),
9+
(:A_mul_Bc!, :A_mul_Bc), (:Ac_mul_B!, :Ac_mul_B), (:Ac_mul_Bc!, :Ac_mul_Bc))
10+
511
for A in ARRAY_TYPES
612

713
# addition/subtraction #
@@ -25,6 +31,29 @@ for A in ARRAY_TYPES
2531
record!(tp, $(f), (x, y), out)
2632
return out
2733
end
34+
35+
if f != :-
36+
@eval function Base.$(f){X,S}(x::$(A){Tracked{X,S}}, y::$(A))
37+
tp = tape(x)
38+
out = track($(f)(value(x), y), S, tp)
39+
record!(tp, $(f), x, out)
40+
return out
41+
end
42+
end
43+
44+
@eval function Base.$(f){Y,S}(x::$(A), y::$(A){Tracked{Y,S}})
45+
tp = tape(y)
46+
out = track($(f)(x, value(y)), S, tp)
47+
record!(tp, $(f), y, out)
48+
return out
49+
end
50+
end
51+
52+
@eval function Base.:-{X,S}(x::$(A){Tracked{X,S}}, y::$(A))
53+
tp = tape(x)
54+
out = track(value(x) - y, S, tp)
55+
record!(tp, +, x, out)
56+
return out
2857
end
2958

3059
@eval function Base.:-{V,S}(x::$(A){Tracked{V,S}})
@@ -37,22 +66,34 @@ for A in ARRAY_TYPES
3766
# A_mul_B family #
3867
#----------------#
3968

40-
for f in (:*,
41-
:A_mul_Bt, :At_mul_B, :At_mul_Bt,
42-
:A_mul_Bc, :Ac_mul_B, :Ac_mul_Bc)
69+
for f in A_MUL_B_FUNCS
4370
@eval function Base.$(f){X,Y,S}(x::$(A){Tracked{X,S}}, y::$(A){Tracked{Y,S}})
4471
tp = tape(x, y)
4572
xval, yval = value(x), value(y)
4673
out = track($(f)(xval, yval), S, tp)
4774
record!(tp, $(f), (x, y), out, (xval, yval))
4875
return out
4976
end
77+
78+
@eval function Base.$(f){X,S}(x::$(A){Tracked{X,S}}, y::$(A))
79+
tp = tape(x)
80+
xval = value(x)
81+
out = track($(f)(xval, y), S, tp)
82+
record!(tp, $(f), (x, nothing), out, y)
83+
return out
84+
end
85+
86+
@eval function Base.$(f){Y,S}(x::$(A), y::$(A){Tracked{Y,S}})
87+
tp = tape(y)
88+
yval = value(y)
89+
out = track($(f)(x, yval), S, tp)
90+
record!(tp, $(f), (nothing, y), out, x)
91+
return out
92+
end
5093
end
5194

5295
# in-place versions
53-
for (f!, f) in ((:A_mul_B!, :*),
54-
(:A_mul_Bt!, :A_mul_Bt), (:At_mul_B!, :At_mul_B), (:At_mul_Bt!, :At_mul_Bt),
55-
(:A_mul_Bc!, :A_mul_Bc), (:Ac_mul_B!, :Ac_mul_B), (:Ac_mul_Bc!, :Ac_mul_Bc))
96+
for (f!, f) in A_MUL_B!_FUNCS
5697
@eval function Base.$(f!){V,X,Y,S}(out::$(A){Tracked{V,S}},
5798
x::$(A){Tracked{X,S}},
5899
y::$(A){Tracked{Y,S}})
@@ -62,6 +103,26 @@ for A in ARRAY_TYPES
62103
record!(tp, $(f), (x, y), out, (xval, yval))
63104
return out
64105
end
106+
107+
@eval function Base.$(f!){V,X,S}(out::$(A){Tracked{V,S}},
108+
x::$(A){Tracked{X,S}},
109+
y::$(A))
110+
tp = tape(x)
111+
xval = value(x)
112+
track!(out, $(f)(xval, y), tp)
113+
record!(tp, $(f), (x, nothing), out, y)
114+
return out
115+
end
116+
117+
@eval function Base.$(f!){V,Y,S}(out::$(A){Tracked{V,S}},
118+
x::$(A),
119+
y::$(A){Tracked{Y,S}})
120+
tp = tape(y)
121+
yval = value(y)
122+
track!(out, $(f)(x, yval), tp)
123+
record!(tp, $(f), (nothing, y), out, x)
124+
return out
125+
end
65126
end
66127

67128
# linear algebra #
@@ -107,8 +168,8 @@ function special_reverse_step!{A,B}(::typeof(+), inputs::Tuple{A,B}, output::Abs
107168
return nothing
108169
end
109170

110-
function special_reverse_step!(::typeof(-), input, output, _)
111-
extract_and_decrement_adjoint!(input, output)
171+
function special_reverse_step!(::typeof(+), input, output, _)
172+
extract_and_increment_adjoint!(input, output)
112173
return nothing
113174
end
114175

@@ -118,9 +179,16 @@ function special_reverse_step!{A,B}(::typeof(-), inputs::Tuple{A,B}, output::Abs
118179
return nothing
119180
end
120181

182+
function special_reverse_step!(::typeof(-), input, output, _)
183+
extract_and_decrement_adjoint!(input, output)
184+
return nothing
185+
end
186+
121187
# A_mul_B family #
122188
#----------------#
123189

190+
# *
191+
124192
function special_reverse_step!{A,B}(::typeof(*), inputs::Tuple{A,B}, output, vals)
125193
a, b = inputs
126194
aval, bval = vals
@@ -130,6 +198,24 @@ function special_reverse_step!{A,B}(::typeof(*), inputs::Tuple{A,B}, output, val
130198
return nothing
131199
end
132200

201+
function special_reverse_step!{T}(::typeof(*), inputs::Tuple{T,Void}, output, vals)
202+
a, _ = inputs
203+
bval = vals
204+
output_adjoint = adjoint(output)
205+
increment_adjoint!(a, output_adjoint * bval')
206+
return nothing
207+
end
208+
209+
function special_reverse_step!{T}(::typeof(*), inputs::Tuple{Void,T}, output, vals)
210+
_, b = inputs
211+
aval = vals
212+
output_adjoint = adjoint(output)
213+
increment_adjoint!(b, aval' * output_adjoint)
214+
return nothing
215+
end
216+
217+
# A_mul_Bt
218+
133219
function special_reverse_step!{A,B}(::typeof(A_mul_Bt), inputs::Tuple{A,B}, output, vals)
134220
a, b = inputs
135221
aval, bval = vals
@@ -139,6 +225,24 @@ function special_reverse_step!{A,B}(::typeof(A_mul_Bt), inputs::Tuple{A,B}, outp
139225
return nothing
140226
end
141227

228+
function special_reverse_step!{T}(::typeof(A_mul_Bt), inputs::Tuple{T,Void}, output, vals)
229+
a, _ = inputs
230+
bval = vals
231+
output_adjoint = adjoint(output)
232+
increment_adjoint!(a, output_adjoint * bval)
233+
return nothing
234+
end
235+
236+
function special_reverse_step!{T}(::typeof(A_mul_Bt), inputs::Tuple{Void,T}, output, vals)
237+
_, b = inputs
238+
aval = vals
239+
output_adjoint = adjoint(output)
240+
increment_adjoint!(b, output_adjoint.' * aval)
241+
return nothing
242+
end
243+
244+
# At_mul_B
245+
142246
function special_reverse_step!{A,B}(::typeof(At_mul_B), inputs::Tuple{A,B}, output, vals)
143247
a, b = inputs
144248
aval, bval = vals
@@ -148,6 +252,24 @@ function special_reverse_step!{A,B}(::typeof(At_mul_B), inputs::Tuple{A,B}, outp
148252
return nothing
149253
end
150254

255+
function special_reverse_step!{T}(::typeof(At_mul_B), inputs::Tuple{T,Void}, output, vals)
256+
a, _ = inputs
257+
bval = vals
258+
output_adjoint = adjoint(output)
259+
increment_adjoint!(a, bval * output_adjoint.')
260+
return nothing
261+
end
262+
263+
function special_reverse_step!{T}(::typeof(At_mul_B), inputs::Tuple{Void,T}, output, vals)
264+
_, b = inputs
265+
aval = vals
266+
output_adjoint = adjoint(output)
267+
increment_adjoint!(b, aval * output_adjoint)
268+
return nothing
269+
end
270+
271+
# At_mul_Bt
272+
151273
function special_reverse_step!{A,B}(::typeof(At_mul_Bt), inputs::Tuple{A,B}, output, vals)
152274
a, b = inputs
153275
aval, bval = vals
@@ -157,6 +279,24 @@ function special_reverse_step!{A,B}(::typeof(At_mul_Bt), inputs::Tuple{A,B}, out
157279
return nothing
158280
end
159281

282+
function special_reverse_step!{T}(::typeof(At_mul_Bt), inputs::Tuple{T,Void}, output, vals)
283+
a, _ = inputs
284+
bval = vals
285+
output_adjoint = adjoint(output)
286+
increment_adjoint!(a, (output_adjoint * bval).')
287+
return nothing
288+
end
289+
290+
function special_reverse_step!{T}(::typeof(At_mul_Bt), inputs::Tuple{Void,T}, output, vals)
291+
_, b = inputs
292+
aval = vals
293+
output_adjoint = adjoint(output)
294+
increment_adjoint!(b, (aval * output_adjoint).')
295+
return nothing
296+
end
297+
298+
# A_mul_Bc
299+
160300
function special_reverse_step!{A,B}(::typeof(A_mul_Bc), inputs::Tuple{A,B}, output, vals)
161301
a, b = inputs
162302
aval, bval = vals
@@ -166,6 +306,24 @@ function special_reverse_step!{A,B}(::typeof(A_mul_Bc), inputs::Tuple{A,B}, outp
166306
return nothing
167307
end
168308

309+
function special_reverse_step!{T}(::typeof(A_mul_Bc), inputs::Tuple{T,Void}, output, vals)
310+
a, _ = inputs
311+
bval = vals
312+
output_adjoint = adjoint(output)
313+
increment_adjoint!(a, output_adjoint * bval)
314+
return nothing
315+
end
316+
317+
function special_reverse_step!{T}(::typeof(A_mul_Bc), inputs::Tuple{Void,T}, output, vals)
318+
_, b = inputs
319+
aval = vals
320+
output_adjoint = adjoint(output)
321+
increment_adjoint!(b, output_adjoint' * aval)
322+
return nothing
323+
end
324+
325+
# Ac_mul_B
326+
169327
function special_reverse_step!{A,B}(::typeof(Ac_mul_B), inputs::Tuple{A,B}, output, vals)
170328
a, b = inputs
171329
aval, bval = vals
@@ -175,6 +333,24 @@ function special_reverse_step!{A,B}(::typeof(Ac_mul_B), inputs::Tuple{A,B}, outp
175333
return nothing
176334
end
177335

336+
function special_reverse_step!{T}(::typeof(Ac_mul_B), inputs::Tuple{T,Void}, output, vals)
337+
a, _ = inputs
338+
bval = vals
339+
output_adjoint = adjoint(output)
340+
increment_adjoint!(a, bval * output_adjoint')
341+
return nothing
342+
end
343+
344+
function special_reverse_step!{T}(::typeof(Ac_mul_B), inputs::Tuple{Void,T}, output, vals)
345+
_, b = inputs
346+
aval = vals
347+
output_adjoint = adjoint(output)
348+
increment_adjoint!(b, aval * output_adjoint)
349+
return nothing
350+
end
351+
352+
# Ac_mul_Bc
353+
178354
function special_reverse_step!{A,B}(::typeof(Ac_mul_Bc), inputs::Tuple{A,B}, output, vals)
179355
a, b = inputs
180356
aval, bval = vals
@@ -184,6 +360,22 @@ function special_reverse_step!{A,B}(::typeof(Ac_mul_Bc), inputs::Tuple{A,B}, out
184360
return nothing
185361
end
186362

363+
function special_reverse_step!{T}(::typeof(Ac_mul_Bc), inputs::Tuple{T,Void}, output, vals)
364+
a, _ = inputs
365+
bval = vals
366+
output_adjoint = adjoint(output)
367+
increment_adjoint!(a, (output_adjoint * bval)')
368+
return nothing
369+
end
370+
371+
function special_reverse_step!{T}(::typeof(Ac_mul_Bc), inputs::Tuple{Void,T}, output, vals)
372+
_, b = inputs
373+
aval = vals
374+
output_adjoint = adjoint(output)
375+
increment_adjoint!(b, (aval * output_adjoint)')
376+
return nothing
377+
end
378+
187379
# special functions #
188380
#-------------------#
189381

Diff for: src/utils.jl

+1
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ seed!(t::TapeNode) = (seed!(t.outputs); return t)
7676

7777
unseed!(t::Tracked) = (t.adjoint = zero(adjtype(t)); return t)
7878
unseed!(t::TapeNode) = (unseed!(t.inputs); unseed!(t.outputs); return t)
79+
unseed!(::Void) = nothing
7980
unseed!(ts) = for t in ts; unseed!(t); end
8081

8182
#######################

0 commit comments

Comments
 (0)