Skip to content

Commit 481d02e

Browse files
committed
add elementwise tests (broadcast and map)
1 parent 7c83dbd commit 481d02e

File tree

5 files changed

+313
-24
lines changed

5 files changed

+313
-24
lines changed

src/derivatives/elementwise.jl

+88-22
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# forward #
33
###########
44

5-
function dualwrap{N,T,A}(duals::AbstractArray{Dual{N,T}}, ::Type{A}, tp::Nullable{Tape})
5+
function retrack_duals{N,T,A}(duals::AbstractArray{Dual{N,T}}, ::Type{A}, tp::Nullable{Tape})
66
ts = similar(duals, Tracked{T,A})
77
for i in eachindex(duals)
88
ts[i] = Tracked(value(duals[i]), A, tp)
@@ -19,7 +19,7 @@ for A in ARRAY_TYPES
1919
fdual = t -> fopt.f(Dual(value(t), one(V)))
2020
duals = $(g)(fdual, x)
2121
tp = tape(x)
22-
out = dualwrap(duals, S, tp)
22+
out = retrack_duals(duals, S, tp)
2323
record!(tp, $(g), x, out, duals)
2424
return out
2525
end
@@ -31,10 +31,32 @@ for A in ARRAY_TYPES
3131
Dual(value(t2), zero(V2), one(V2)))
3232
duals = $(g)(fdual, x1, x2)
3333
tp = tape(x1, x2)
34-
out = dualwrap(duals, S, tp)
34+
out = retrack_duals(duals, S, tp)
3535
record!(tp, $(g), (x1, x2), out, duals)
3636
return out
3737
end
38+
39+
function Base.$(g){F,V,S}(fopt::ForwardOptimize{F},
40+
x1::$(A){Tracked{V,S}},
41+
x2::$(A))
42+
fdual = (t1, t2) -> fopt.f(Dual(value(t1), one(V)), t2)
43+
duals = $(g)(fdual, x1, x2)
44+
tp = tape(x1)
45+
out = retrack_duals(duals, S, tp)
46+
record!(tp, $(g), x1, out, duals)
47+
return out
48+
end
49+
50+
function Base.$(g){F,V,S}(fopt::ForwardOptimize{F},
51+
x1::$(A),
52+
x2::$(A){Tracked{V,S}})
53+
fdual = (t1, t2) -> fopt.f(t1, Dual(value(t2), one(V)))
54+
duals = $(g)(fdual, x1, x2)
55+
tp = tape(x2)
56+
out = retrack_duals(duals, S, tp)
57+
record!(tp, $(g), x2, out, duals)
58+
return out
59+
end
3860
end
3961
end
4062

@@ -58,7 +80,7 @@ for A in ARRAY_TYPES
5880
fdual = t -> fopt.f(ndual, Dual(value(t), zero(X), one(X)))
5981
duals = broadcast(fdual, x)
6082
tp = tape(n, x)
61-
out = dualwrap(duals, S, tp)
83+
out = retrack_duals(duals, S, tp)
6284
record!(tp, broadcast, (n, x), out, duals)
6385
return out
6486
end
@@ -68,10 +90,30 @@ for A in ARRAY_TYPES
6890
fdual = t -> fopt.f(Dual(value(t), one(X), zero(X)), ndual)
6991
duals = broadcast(fdual, x)
7092
tp = tape(n, x)
71-
out = dualwrap(duals, S, tp)
93+
out = retrack_duals(duals, S, tp)
7294
record!(tp, broadcast, (x, n), out, duals)
7395
return out
7496
end
97+
98+
function Base.broadcast{F,V,S}(fopt::ForwardOptimize{F}, n::Tracked{V,S}, x::$(A))
99+
ndual = Dual(value(n), one(V))
100+
fdual = t -> fopt.f(ndual, t)
101+
duals = broadcast(fdual, x)
102+
tp = tape(n)
103+
out = retrack_duals(duals, S, tp)
104+
record!(tp, broadcast, n, out, duals)
105+
return out
106+
end
107+
108+
function Base.broadcast{F,V,S}(fopt::ForwardOptimize{F}, x::$(A), n::Tracked{V,S})
109+
ndual = Dual(value(n), one(V))
110+
fdual = t -> fopt.f(t, ndual)
111+
duals = broadcast(fdual, x)
112+
tp = tape(n)
113+
out = retrack_duals(duals, S, tp)
114+
record!(tp, broadcast, n, out, duals)
115+
return out
116+
end
75117
end
76118

77119
# standard elementwise operations (.+, .-, .*, etc.) #
@@ -82,13 +124,29 @@ for A in ARRAY_TYPES
82124
return broadcast(ForwardOptimize($(f)), x, y)
83125
end
84126

127+
@inline function Base.$(f){X<:Tracked}(x::$(A){X}, y::$(A))
128+
return broadcast(ForwardOptimize($(f)), x, y)
129+
end
130+
131+
@inline function Base.$(f){Y<:Tracked}(x::$(A), y::$(A){Y})
132+
return broadcast(ForwardOptimize($(f)), x, y)
133+
end
134+
85135
@inline function Base.$(f){T<:Tracked}(n::Tracked, x::$(A){T})
86136
return broadcast(ForwardOptimize($(f)), n, x)
87137
end
88138

89139
@inline function Base.$(f){T<:Tracked}(x::$(A){T}, n::Tracked)
90140
return broadcast(ForwardOptimize($(f)), x, n)
91141
end
142+
143+
@inline function Base.$(f)(n::Tracked, x::$(A))
144+
return broadcast(ForwardOptimize($(f)), n, x)
145+
end
146+
147+
@inline function Base.$(f)(x::$(A), n::Tracked)
148+
return broadcast(ForwardOptimize($(f)), x, n)
149+
end
92150
end
93151
for R in REAL_TYPES
94152
@eval begin
@@ -138,26 +196,34 @@ function special_reverse_step!{A,B}(::typeof(broadcast), inputs::Tuple{A,B}, out
138196
if size(a) == size(b)
139197
special_reverse_step!(map, inputs, output, duals)
140198
else
141-
for i in eachindex(duals)
142-
duals[i] *= adjoint(output[i])
143-
end
144-
s = sumover(1, a, duals)
145-
increment_adjoint!(a, s)
146-
increment_adjoint!(b, sumover(2, b, duals))
199+
broadcast_adjoint_reduce!(a, output, duals, 1)
200+
broadcast_adjoint_reduce!(b, output, duals, 2)
147201
end
148202
return nothing
149203
end
150204

151-
# Inference here is pretty wonky (see JuliaLang/julia#10533),
152-
# so it's important that we allocate the array for the sum
153-
# result ourselves. Otherwise, `reducedim_init` tries to
154-
# allocate an array of the wrong type in some cases, which
155-
# leads to conversion errors.
156-
function sumover{N,M,T}(p, x::AbstractArray, duals::AbstractArray{Dual{N,T},M})
157-
dims = (size(x, i) != size(duals, i) ? 1 : size(duals, i) for i in 1:ndims(duals))
158-
result = similar(duals, T, (dims...)::NTuple{M,Int})
159-
sum!(d -> partials(d, p), result, duals)
160-
return result
205+
function special_reverse_step!(::typeof(broadcast), input::Number, output, duals)
206+
broadcast_adjoint_reduce!(input, output, duals, 1)
207+
return nothing
208+
end
209+
210+
# This strategy should be pretty fast, but it might be prone to numerical error if the
211+
# accumulated adjoint becomes too large compared to the individual terms being added to
212+
# it. This can be overcome by using the divide-and-conquer strategy used by
213+
# Base.mapreducedim, but that strategy is less cache efficient and more complicated to
214+
# implement.
215+
function broadcast_adjoint_reduce!{T,N}(input::AbstractArray, output::AbstractArray{T,N}, duals, p)
216+
dims = (size(input, i) != size(duals, i) ? 1 : size(duals, i) for i in 1:ndims(duals))
217+
max_index = CartesianIndex((dims...)::NTuple{N,Int})
218+
for i in CartesianRange(size(input))
219+
increment_adjoint!(input[min(max_index, i)], adjoint(output[i]) * partials(duals[i], p))
220+
end
221+
return nothing
161222
end
162223

163-
sumover(p, x::Real, duals) = sum(d -> partials(d, p), duals)
224+
function broadcast_adjoint_reduce!{T,N}(input::Number, output::AbstractArray{T,N}, duals, p)
225+
for i in eachindex(duals)
226+
increment_adjoint!(input, adjoint(output[i]) * partials(duals[i], p))
227+
end
228+
return nothing
229+
end

test/derivatives/ElementwiseTests.jl

+222
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
# module ElementwiseTests
2+
3+
using ReverseDiffPrototype, ForwardDiff, Base.Test
4+
5+
include("../utils.jl")
6+
7+
println("testing elementwise derivatives (both forward and reverse passes)")
8+
tic()
9+
10+
############################################################################################
11+
x, a, b, n = rand(3, 3), rand(3, 3), rand(3, 3), rand()
12+
tp = Tape()
13+
14+
function test_elementwise(f, x, tp)
15+
xt = track(x, tp)
16+
y = map(f, x)
17+
18+
out = similar(y, (length(x), length(x)))
19+
yt = map(RDP.@forward(f), xt)
20+
@test yt == y
21+
@test length(tp) == 1
22+
RDP.jacobian_reverse_pass!(out, yt, xt, tp)
23+
@test_approx_eq_eps out ForwardDiff.jacobian(z -> map(f, z), x) EPS
24+
empty!(tp)
25+
26+
y = broadcast(RDP.@forward(f), x)
27+
out = similar(y, (length(x), length(x)))
28+
yt = broadcast(RDP.@forward(f), xt)
29+
@test yt == y
30+
@test length(tp) == 1
31+
RDP.jacobian_reverse_pass!(out, yt, xt, tp)
32+
@test_approx_eq_eps out ForwardDiff.jacobian(z -> broadcast(f, z), x) EPS
33+
empty!(tp)
34+
end
35+
36+
function test_map(f, a, b, tp)
37+
at, bt = track(a, tp), track(b, tp)
38+
c = map(f, a, b)
39+
40+
out = similar(c, (length(a), length(a)))
41+
ct = map(RDP.@forward(f), at, b)
42+
@test ct == c
43+
@test length(tp) == 1
44+
RDP.jacobian_reverse_pass!(out, ct, at, tp)
45+
@test_approx_eq_eps out ForwardDiff.jacobian(x -> map(f, x, b), a) EPS
46+
RDP.unseed!(tp)
47+
empty!(tp)
48+
49+
out = similar(c, (length(a), length(a)))
50+
ct = map(RDP.@forward(f), a, bt)
51+
@test ct == c
52+
@test length(tp) == 1
53+
RDP.jacobian_reverse_pass!(out, ct, bt, tp)
54+
@test_approx_eq_eps out ForwardDiff.jacobian(x -> map(f, a, x), b) EPS
55+
RDP.unseed!(tp)
56+
empty!(tp)
57+
58+
out_a = similar(c, (length(a), length(a)))
59+
out_b = similar(c, (length(a), length(a)))
60+
ct = map(RDP.@forward(f), at, bt)
61+
@test ct == c
62+
@test length(tp) == 1
63+
RDP.jacobian_reverse_pass!(out_a, ct, at, tp)
64+
RDP.jacobian_reverse_pass!(out_b, ct, bt, tp)
65+
@test_approx_eq_eps out_a ForwardDiff.jacobian(x -> map(f, x, b), a) EPS
66+
@test_approx_eq_eps out_b ForwardDiff.jacobian(x -> map(f, a, x), b) EPS
67+
RDP.unseed!(tp)
68+
empty!(tp)
69+
end
70+
71+
function test_broadcast(f, a::AbstractArray, b::AbstractArray, tp, builtin = false)
72+
at, bt = track(a, tp), track(b, tp)
73+
74+
if builtin
75+
g = RDP.@forward(f)
76+
else
77+
g = (x, y) -> broadcast(RDP.@forward(f), x, y)
78+
end
79+
80+
c = g(a, b)
81+
82+
out = similar(c, (length(c), length(a)))
83+
ct = g(at, b)
84+
@test ct == c
85+
@test length(tp) == 1
86+
RDP.jacobian_reverse_pass!(out, ct, at, tp)
87+
@test_approx_eq_eps out ForwardDiff.jacobian(x -> g(x, b), a) EPS
88+
RDP.unseed!(tp)
89+
empty!(tp)
90+
91+
out = similar(c, (length(c), length(b)))
92+
ct = g(a, bt)
93+
@test ct == c
94+
@test length(tp) == 1
95+
RDP.jacobian_reverse_pass!(out, ct, bt, tp)
96+
@test_approx_eq_eps out ForwardDiff.jacobian(x -> g(a, x), b) EPS
97+
RDP.unseed!(tp)
98+
empty!(tp)
99+
100+
out_a = similar(c, (length(c), length(a)))
101+
out_b = similar(c, (length(c), length(b)))
102+
ct = g(at, bt)
103+
@test ct == c
104+
@test length(tp) == 1
105+
RDP.jacobian_reverse_pass!(out_a, ct, at, tp)
106+
RDP.jacobian_reverse_pass!(out_b, ct, bt, tp)
107+
@test_approx_eq_eps out_a ForwardDiff.jacobian(x -> g(x, b), a) EPS
108+
@test_approx_eq_eps out_b ForwardDiff.jacobian(x -> g(a, x), b) EPS
109+
RDP.unseed!(tp)
110+
empty!(tp)
111+
end
112+
113+
function test_broadcast(f, n::Number, x::AbstractArray, tp, builtin = false)
114+
nt, xt = track(n, tp), track(x, tp)
115+
116+
if builtin
117+
g = RDP.@forward(f)
118+
else
119+
g = (x, y) -> broadcast(RDP.@forward(f), x, y)
120+
end
121+
122+
y = g(n, x)
123+
124+
out = similar(y)
125+
yt = g(nt, x)
126+
@test yt == y
127+
@test length(tp) == 1
128+
RDP.jacobian_reverse_pass!(out, yt, [nt], tp)
129+
@test_approx_eq_eps out ForwardDiff.derivative(z -> g(z, x), n) EPS
130+
RDP.unseed!(tp)
131+
empty!(tp)
132+
133+
out = similar(y, (length(y), length(x)))
134+
yt = g(n, xt)
135+
@test yt == y
136+
@test length(tp) == 1
137+
RDP.jacobian_reverse_pass!(out, yt, xt, tp)
138+
@test_approx_eq_eps out ForwardDiff.jacobian(z -> g(n, z), x) EPS
139+
RDP.unseed!(tp)
140+
empty!(tp)
141+
142+
out_n = similar(y)
143+
out_x = similar(y, (length(y), length(x)))
144+
yt = g(nt, xt)
145+
@test yt == y
146+
@test length(tp) == 1
147+
RDP.jacobian_reverse_pass!(out_n, yt, [nt], tp)
148+
RDP.jacobian_reverse_pass!(out_x, yt, xt, tp)
149+
@test_approx_eq_eps out_n ForwardDiff.derivative(z -> g(z, x), n) EPS
150+
@test_approx_eq_eps out_x ForwardDiff.jacobian(z -> g(n, z), x) EPS
151+
RDP.unseed!(tp)
152+
empty!(tp)
153+
end
154+
155+
function test_broadcast(f, x::AbstractArray, n::Number, tp, builtin = false)
156+
xt, nt = track(x, tp), track(n, tp)
157+
158+
if builtin
159+
g = RDP.@forward(f)
160+
else
161+
g = (x, y) -> broadcast(RDP.@forward(f), x, y)
162+
end
163+
164+
y = g(x, n)
165+
166+
out = similar(y)
167+
yt = g(x, nt)
168+
@test yt == y
169+
@test length(tp) == 1
170+
RDP.jacobian_reverse_pass!(out, yt, [nt], tp)
171+
@test_approx_eq_eps out ForwardDiff.derivative(z -> g(x, z), n) EPS
172+
RDP.unseed!(tp)
173+
empty!(tp)
174+
175+
out = similar(y, (length(y), length(x)))
176+
yt = g(xt, n)
177+
@test yt == y
178+
@test length(tp) == 1
179+
RDP.jacobian_reverse_pass!(out, yt, xt, tp)
180+
@test_approx_eq_eps out ForwardDiff.jacobian(z -> g(z, n), x) EPS
181+
RDP.unseed!(tp)
182+
empty!(tp)
183+
184+
out_n = similar(y)
185+
out_x = similar(y, (length(y), length(x)))
186+
yt = g(xt, nt)
187+
@test yt == y
188+
@test length(tp) == 1
189+
RDP.jacobian_reverse_pass!(out_n, yt, [nt], tp)
190+
RDP.jacobian_reverse_pass!(out_x, yt, xt, tp)
191+
@test_approx_eq_eps out_n ForwardDiff.derivative(z -> g(x, z), n) EPS
192+
@test_approx_eq_eps out_x ForwardDiff.jacobian(z -> g(z, n), x) EPS
193+
RDP.unseed!(tp)
194+
empty!(tp)
195+
end
196+
197+
for f in (sin, cos, tan, exp, x -> 1. / (1. + exp(-x)))
198+
testprintln("unary scalar functions", f)
199+
test_elementwise(f, x, tp)
200+
end
201+
202+
for fsym in RDP.FORWARD_BINARY_SCALAR_FUNCS
203+
f = eval(fsym)
204+
testprintln("binary scalar functions", f)
205+
test_map(f, a, b, tp)
206+
test_broadcast(f, a, b, tp)
207+
test_broadcast(f, n, x, tp)
208+
test_broadcast(f, x, n, tp)
209+
end
210+
211+
for f in (.+, .-, .*, ./, .\, .^)
212+
testprintln("built-in broadcast functions", f)
213+
test_broadcast(f, a, b, tp, true)
214+
test_broadcast(f, n, x, tp, true)
215+
test_broadcast(f, x, n, tp, true)
216+
end
217+
218+
############################################################################################
219+
220+
println("done (took $(toq()) seconds)")
221+
222+
# end # module

0 commit comments

Comments
 (0)