Skip to content

Commit d065a44

Browse files
committed
add UtilsTests
1 parent 1b604e8 commit d065a44

File tree

5 files changed

+218
-6
lines changed

5 files changed

+218
-6
lines changed

src/Tracked.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ Base.convert{V,A}(::Type{Tracked{V,A}}, t::Tracked) = Tracked(V(value(t)), A(adj
4444
Base.convert{T<:Tracked}(::Type{T}, t::T) = t
4545

4646
Base.promote_rule{R<:Real,V,A}(::Type{R}, ::Type{Tracked{V,A}}) = Tracked{promote_type(R,V),A}
47-
Base.promote_rule{V1,V2,A}(::Type{Tracked{V1,A}}, ::Type{Tracked{V2,A}}) = Tracked{promote_type(V1,V2),A}
47+
Base.promote_rule{V1,V2,A1,A2}(::Type{Tracked{V1,A1}}, ::Type{Tracked{V2,A2}}) = Tracked{promote_type(V1,V2),promote_type(A1,A2)}
4848

4949
Base.promote_array_type{T<:Tracked, F<:AbstractFloat}(_, ::Type{T}, ::Type{F}) = promote_type(T, F)
5050
Base.promote_array_type{T<:Tracked, F<:AbstractFloat, P}(_, ::Type{T}, ::Type{F}, ::Type{P}) = P

src/utils.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
################
44

55
track(x, tp::Tape = Tape()) = track(x, eltype(x), tp)
6-
track(xts::Tuple, tp::Tape = Tape()) = track(xts, map(eltype, xts), tp)
7-
track(xs::Tuple, types::Tuple{Vararg{DataType}}, tp::Tape = Tape()) = map((x, A) -> track(x, A, tp), xs, types)
6+
track(xts::Tuple, tp::Tape = Tape()) = track(xts, eltype(first(xts)), tp)
7+
track{A}(xs::Tuple, ::Type{A}, tp::Tape = Tape()) = map(x -> track(x, A, tp), xs)
88

9-
track{T,A}(x::AbstractArray{T}, ::Type{A}, tp::Tape = Tape()) = track(x,A, Nullable(tp))
9+
track{T,A}(x::AbstractArray{T}, ::Type{A}, tp::Tape = Tape()) = track(x, A, Nullable(tp))
1010
track{A}(x::Number, ::Type{A}, tp::Tape = Tape()) = track(x, A, Nullable(tp))
1111
track!(xts, xs, tp::Tape = Tape()) = track!(xts, xs, Nullable(tp))
1212

@@ -63,7 +63,9 @@ function tape(arr::AbstractArray)
6363
return Nullable{Tape}()
6464
end
6565

66-
tape(a, b::AbstractArray) = isnull(tape(a)) ? tape(b) : tape(a)
66+
tape(a::AbstractArray, b::AbstractArray) = (tp = tape(a); isnull(tp) ? tape(b) : tp)
67+
tape(a, b::AbstractArray) = (tp = tape(a); isnull(tp) ? tape(b) : tp)
68+
tape(a::AbstractArray, b) = (tp = tape(a); isnull(tp) ? tape(b) : tp)
6769

6870
#####################
6971
# seeding/unseeding #

test/TrackedTests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ t = Tracked(x, Int, ntp)
7777
@test convert(typeof(t), t) === t
7878

7979
@test promote_type(Int64, Tracked{Int32,Int32}) === Tracked{Int64,Int32}
80-
@test promote_type(Tracked{Int64,Int32}, Tracked{Int32,Int32}) === Tracked{Int64,Int32}
80+
@test promote_type(Tracked{Int64,Int32}, Tracked{Int32,Int64}) === Tracked{Int64,Int64}
8181

8282
@test Base.promote_array_type(nothing, Tracked{Int,Int}, Float64) === Tracked{Float64,Int}
8383
@test Base.promote_array_type(nothing, Tracked{Int,Int}, Float64, Int) === Int

test/UtilsTests.jl

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
module UtilsTests
2+
3+
using ReverseDiffPrototype, Base.Test
4+
using ReverseDiffPrototype: Tape, TapeNode, Tracked, track, track!, value, adjoint, tape
5+
6+
const RDP = ReverseDiffPrototype
7+
8+
include("utils.jl")
9+
10+
println("testing utilities...")
11+
tic()
12+
13+
############################################################################################
14+
15+
tracked_is(a, b) = value(a) === value(b) && adjoint(a) === adjoint(b) && tape(a) === tape(b)
16+
tracked_is(a::AbstractArray, b::AbstractArray) = all(map(tracked_is, a, b))
17+
tracked_is(a::Tuple, b::Tuple) = all(map(tracked_is, a, b))
18+
19+
################
20+
# track/track! #
21+
################
22+
23+
tp = Tape()
24+
ntp = Nullable(tp)
25+
n, m = rand(), rand()
26+
x, y = rand(3), rand(3, 3)
27+
tn_int, tm_int = Tracked(n, Int, ntp), Tracked(m, Int, ntp)
28+
tn_float, tm_float = Tracked(n, ntp), Tracked(m, ntp)
29+
tx_int, ty_int = map(t -> Tracked(t, Int, ntp), x), map(t -> Tracked(t, Int, ntp), y)
30+
tx_float, ty_float = map(t -> Tracked(t, ntp), x), map(t -> Tracked(t, ntp), y)
31+
32+
@test tracked_is(track(n, tp), tn_float)
33+
@test tracked_is(track(n, Int, tp), tn_int)
34+
@test tracked_is(track(n, Int, ntp), tn_int)
35+
36+
@test tracked_is(track(x, tp), tx_float)
37+
@test tracked_is(track(x, Int, tp), tx_int)
38+
@test tracked_is(track(x, Int, ntp), tx_int)
39+
40+
@test tracked_is(track((n, m), tp), (tn_float, tm_float))
41+
@test tracked_is(track((n, m), Int, tp), (tn_int, tm_int))
42+
43+
@test tracked_is(track((x, y), tp), (tx_float, ty_float))
44+
@test tracked_is(track((x, y), Int, tp), (tx_int, ty_int))
45+
46+
tx_int_sim = similar(tx_int)
47+
tx_float_sim = similar(tx_float)
48+
track!(tx_int_sim, x, tp)
49+
@test tracked_is(tx_int_sim, tx_int)
50+
track!(tx_float_sim, x, tp)
51+
@test tracked_is(tx_float_sim, tx_float)
52+
53+
tx_int_sim = similar(tx_int)
54+
tx_float_sim = similar(tx_float)
55+
track!(tx_int_sim, x, ntp)
56+
@test tracked_is(tx_int_sim, tx_int)
57+
track!(tx_float_sim, x, ntp)
58+
@test tracked_is(tx_float_sim, tx_float)
59+
60+
tx_int_sim, ty_int_sim = similar(tx_int), similar(ty_int)
61+
tx_float_sim, ty_float_sim = similar(tx_float), similar(ty_float)
62+
track!((tx_int_sim, ty_int_sim), (x, y), tp)
63+
@test tracked_is(tx_int_sim, tx_int)
64+
@test tracked_is(ty_int_sim, ty_int)
65+
track!((tx_float_sim, ty_float_sim), (x, y), tp)
66+
@test tracked_is(tx_float_sim, tx_float)
67+
@test tracked_is(ty_float_sim, ty_float)
68+
69+
tx_int_sim, ty_int_sim = similar(tx_int), similar(ty_int)
70+
tx_float_sim, ty_float_sim = similar(tx_float), similar(ty_float)
71+
track!((tx_int_sim, ty_int_sim), (x, y), ntp)
72+
@test tracked_is(tx_int_sim, tx_int)
73+
@test tracked_is(ty_int_sim, ty_int)
74+
track!((tx_float_sim, ty_float_sim), (x, y), ntp)
75+
@test tracked_is(tx_float_sim, tx_float)
76+
@test tracked_is(ty_float_sim, ty_float)
77+
78+
track!(tn_int, m, tp)
79+
@test tracked_is(tn_int, tm_int)
80+
track!(tn_int, n, tp)
81+
82+
track!(tn_float, m, tp)
83+
@test tracked_is(tn_float, tm_float)
84+
track!(tn_float, n, tp)
85+
86+
track!((tm_int, tn_int), (n, m), tp)
87+
@test value(tm_int) === n
88+
@test value(tn_int) === m
89+
track!((tm_int, tn_int), (m, n), tp)
90+
91+
track!(tn_int, m, ntp)
92+
@test tracked_is(tn_int, tm_int)
93+
track!(tn_int, n, ntp)
94+
95+
track!(tn_float, m, ntp)
96+
@test tracked_is(tn_float, tm_float)
97+
track!(tn_float, n, ntp)
98+
99+
track!((tm_int, tn_int), (n, m), ntp)
100+
@test value(tm_int) === n
101+
@test value(tn_int) === m
102+
track!((tm_int, tn_int), (m, n), ntp)
103+
104+
##################################
105+
# array accessors/tape selection #
106+
##################################
107+
108+
tp = Tape()
109+
ntp = Nullable(tp)
110+
x, y = rand(3), rand(3, 2)
111+
tx, ty = track(x, tp), track(y, tp)
112+
113+
@test value(tx) == x
114+
@test value(ty) == y
115+
116+
x_sim, y_sim = similar(x), similar(y)
117+
RDP.value!(x_sim, tx)
118+
RDP.value!(y_sim, ty)
119+
@test x_sim == x
120+
@test y_sim == y
121+
122+
@test all(adjoint(tx) .== zero(first(x)))
123+
@test all(adjoint(ty) .== zero(first(y)))
124+
125+
x_sim, y_sim = similar(x), similar(y)
126+
RDP.adjoint!(x_sim, tx)
127+
RDP.adjoint!(y_sim, ty)
128+
@test all(x_sim .== zero(first(x)))
129+
@test all(y_sim .== zero(first(y)))
130+
131+
@test tape(tx) === ntp
132+
@test tape(tx, ty) === ntp
133+
@test tape(tx, first(tx)) === ntp
134+
@test tape(first(tx), tx) === ntp
135+
@test tape(tx, Tracked(1)) === ntp
136+
@test tape(Tracked(1), tx) === ntp
137+
@test tape(tx, [Tracked(1)]) === ntp
138+
@test tape([Tracked(1)], tx) === ntp
139+
@test tape([Tracked(1), first(tx)]) === ntp
140+
@test isnull(tape([Tracked(1)]))
141+
@test isnull(tape([Tracked(1)], [Tracked(1)]))
142+
143+
#################
144+
# seed!/unseed! #
145+
#################
146+
147+
tp = Tape()
148+
ntp = Nullable(tp)
149+
150+
@test tracked_is(RDP.seed!(Tracked(1, 0, ntp)), Tracked(1, 1, ntp))
151+
152+
node = TapeNode(+, (Tracked(2, ntp), Tracked(1, ntp)), Tracked(3, ntp), nothing)
153+
@test tracked_is(RDP.seed!(node).outputs, Tracked(3, 1, ntp))
154+
155+
@test tracked_is(RDP.unseed!(Tracked(1, 1, ntp)), Tracked(1, 0, ntp))
156+
157+
node = TapeNode(+, (Tracked(2, 2, ntp), Tracked(1, 3, ntp)), Tracked(3, 4, ntp), nothing)
158+
RDP.unseed!(node)
159+
@test adjoint(node.inputs[1]) === 0
160+
@test adjoint(node.inputs[2]) === 0
161+
@test adjoint(node.outputs) === 0
162+
163+
tp2 = [TapeNode(+, (Tracked(2, 2, ntp), Tracked(1, 3, ntp)), Tracked(3, 4, ntp), nothing),
164+
TapeNode(+, Tracked(1.0, 3.0, ntp), (Tracked(51.4, 3.1, ntp), Tracked(3, 4, ntp)), nothing)]
165+
RDP.unseed!(tp2)
166+
@test adjoint(tp2[1].inputs[1]) === 0
167+
@test adjoint(tp2[1].inputs[2]) === 0
168+
@test adjoint(tp2[1].outputs) === 0
169+
@test adjoint(tp2[2].inputs) === 0.0
170+
@test adjoint(tp2[2].outputs[1]) === 0.0
171+
@test adjoint(tp2[2].outputs[2]) === 0
172+
173+
#######################
174+
# adjoint propagation #
175+
#######################
176+
177+
tp = Tape()
178+
ntp = Nullable(tp)
179+
genarr = () -> [Tracked(rand(), rand(), ntp) for i in 1:3]
180+
181+
x, y = genarr(), genarr()
182+
xadj, yadj = adjoint(x), adjoint(y)
183+
RDP.extract_and_decrement_adjoint!(x, y)
184+
@test adjoint(x) == (xadj - yadj)
185+
186+
x, y = genarr(), genarr()
187+
xadj, yadj = adjoint(x), adjoint(y)
188+
RDP.extract_and_increment_adjoint!(x, y)
189+
@test adjoint(x) == (xadj + yadj)
190+
191+
x, y = genarr(), rand(3)
192+
xadj = adjoint(x)
193+
RDP.increment_adjoint!(x, y)
194+
@test adjoint(x) == (xadj + y)
195+
196+
x = genarr()
197+
xadj = adjoint(x)
198+
RDP.increment_adjoint!(x, 3)
199+
@test adjoint(x) == (xadj + 3)
200+
201+
k = Tracked(1, 3, ntp)
202+
RDP.increment_adjoint!(k, 3)
203+
@test tracked_is(k, Tracked(1, 6, ntp))
204+
205+
############################################################################################
206+
207+
println("done (took $(toq()) seconds)")
208+
209+
end # module

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ testprintln(kind, f, pad = " ") = println(pad, "testing $(kind): `$(f)`...")
44

55
include(joinpath(TESTDIR, "TapeTests.jl"))
66
include(joinpath(TESTDIR, "TrackedTests.jl"))
7+
include(joinpath(TESTDIR, "UtilsTests.jl"))
78
include(joinpath(TESTDIR, "api/GradientTests.jl"))
89
include(joinpath(TESTDIR, "api/JacobianTests.jl"))
910
include(joinpath(TESTDIR, "api/HessianTests.jl"))

0 commit comments

Comments
 (0)