Skip to content

Commit 1b604e8

Browse files
committed
add TapeTests
1 parent 389e76f commit 1b604e8

File tree

8 files changed

+83
-7
lines changed

8 files changed

+83
-7
lines changed

src/Tape.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ end
2727
@inline capture(state::Tuple{Vararg{Number}}) = state
2828
@inline capture(state::Tuple) = map(capture, state)
2929

30+
function Base.:(==)(a::TapeNode, b::TapeNode)
31+
return (a.func == b.func &&
32+
a.inputs == b.inputs &&
33+
a.outputs == b.outputs &&
34+
a.cache == b.cache)
35+
end
36+
3037
################################################
3138
# reverse pass (backpropagation over the tape) #
3239
################################################

test/TapeTests.jl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
module TapeTests
2+
3+
using ReverseDiffPrototype, Base.Test
4+
using ReverseDiffPrototype: TapeNode, Tape, Tracked
5+
6+
const RDP = ReverseDiffPrototype
7+
8+
include("utils.jl")
9+
10+
println("testing Tape/TapeNode types...")
11+
tic()
12+
13+
############################################################################################
14+
15+
#################
16+
# TapeNode/Tape #
17+
#################
18+
19+
x, y, k = [1, 2, 3], [4, 5, 6], 7
20+
z = x + y + k
21+
c = []
22+
tn = TapeNode(+, (x, y, k), z, c)
23+
@test tn.func === +
24+
@test tn.inputs === (x, y, k)
25+
@test tn.outputs === z
26+
@test tn.cache === c
27+
28+
tp = Tape()
29+
ntp = Nullable(tp)
30+
RDP.record!(ntp, +, (x, y, k), z, c)
31+
tp1 = first(tp)
32+
@test tp1 == tn
33+
@test tp1.inputs[1] !== x
34+
@test tp1.inputs[2] !== y
35+
@test tp1.inputs[3] === k
36+
@test tp1.outputs !== z
37+
@test tp1.cache === c
38+
39+
ntp = Nullable{Tape}()
40+
RDP.record!(ntp, +, (x, y, k), z, c)
41+
@test ntp === Nullable{Tape}()
42+
43+
t = Tracked(1)
44+
x = [t, t]
45+
@test RDP.capture(t) === t
46+
47+
cx = RDP.capture(x)
48+
@test cx !== x
49+
@test cx == x
50+
@test cx[1] === x[1]
51+
@test cx[2] === x[2]
52+
53+
cs = RDP.capture((x, t, x))
54+
@test cs[1] !== x
55+
@test cs[1] == x
56+
@test cs[1][1] === x[1]
57+
@test cs[1][2] === x[2]
58+
@test cs[2] === t
59+
@test cs[3] !== x
60+
@test cs[3] == x
61+
@test cs[3][1] === x[1]
62+
@test cs[3][2] === x[2]
63+
64+
############################################################################################
65+
66+
println("done (took $(toq()) seconds)")
67+
68+
end # module

test/TrackedTests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module TrackedTests
22

3-
using DiffBase, ForwardDiff, ReverseDiffPrototype, Base.Test
3+
using ReverseDiffPrototype, Base.Test
44
using ReverseDiffPrototype: Tape, Tracked, value, adjoint, tape, valtype, adjtype
55

66
const RDP = ReverseDiffPrototype

test/UtilsTests.jl

Whitespace-only changes.

test/GradientTests.jl renamed to test/api/GradientTests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using DiffBase, ForwardDiff, ReverseDiffPrototype, Base.Test
44

55
const RDP = ReverseDiffPrototype
66

7-
include("utils.jl")
7+
include("../utils.jl")
88

99
println("testing gradient/gradient!...")
1010
tic()

test/HessianTests.jl renamed to test/api/HessianTests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using DiffBase, ForwardDiff, ReverseDiffPrototype, Base.Test
44

55
const RDP = ReverseDiffPrototype
66

7-
include("utils.jl")
7+
include("../utils.jl")
88

99
println("testing hessian/hessian!...")
1010
tic()

test/JacobianTests.jl renamed to test/api/JacobianTests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using DiffBase, ForwardDiff, ReverseDiffPrototype, Base.Test
44

55
const RDP = ReverseDiffPrototype
66

7-
include("utils.jl")
7+
include("../utils.jl")
88

99
println("testing jacobian/jacobian!...")
1010
tic()

test/runtests.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ const TESTDIR = dirname(@__FILE__)
22

33
testprintln(kind, f, pad = " ") = println(pad, "testing $(kind): `$(f)`...")
44

5+
include(joinpath(TESTDIR, "TapeTests.jl"))
56
include(joinpath(TESTDIR, "TrackedTests.jl"))
6-
include(joinpath(TESTDIR, "GradientTests.jl"))
7-
include(joinpath(TESTDIR, "JacobianTests.jl"))
8-
include(joinpath(TESTDIR, "HessianTests.jl"))
7+
include(joinpath(TESTDIR, "api/GradientTests.jl"))
8+
include(joinpath(TESTDIR, "api/JacobianTests.jl"))
9+
include(joinpath(TESTDIR, "api/HessianTests.jl"))

0 commit comments

Comments
 (0)