Skip to content

Commit f5219b5

Browse files
committed
add scalar derivative tests
1 parent b3cc86e commit f5219b5

File tree

3 files changed

+87
-2
lines changed

3 files changed

+87
-2
lines changed

test/derivatives/MacrosTests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# module MacrosTests
1+
module MacrosTests
22

33
using ReverseDiffPrototype, Base.Test
44
using ForwardDiff: Dual, partials
@@ -190,4 +190,4 @@ test_skip(g6, a, b, tp)
190190

191191
println("done (took $(toq()) seconds)")
192192

193-
# end # module
193+
end # module

test/derivatives/ScalarTests.jl

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
module ScalarTests
2+
3+
using ReverseDiffPrototype, ForwardDiff, Base.Test
4+
5+
include("../utils.jl")
6+
7+
println("testing scalar derivatives (both forward and reverse)")
8+
tic()
9+
10+
############################################################################################
11+
x, a, b = rand(3)
12+
tp = Tape()
13+
14+
function test_forward(f, x, tp)
15+
tx = track(x, tp)
16+
y = f(x)
17+
ty = f(tx)
18+
@test ty == y
19+
RDP.seed!(ty)
20+
RDP.reverse_pass!(tp)
21+
@test adjoint(tx) == ForwardDiff.derivative(f, x)
22+
empty!(tp)
23+
end
24+
25+
function test_forward(f, a, b, tp)
26+
ta, tb = track(a, tp), track(b, tp)
27+
c = f(a, b)
28+
tc = f(ta, tb)
29+
@test tc == c
30+
RDP.seed!(tc)
31+
RDP.reverse_pass!(tp)
32+
@test_approx_eq_eps adjoint(ta) ForwardDiff.derivative(x -> f(x, b), a) EPS
33+
@test_approx_eq_eps adjoint(tb) ForwardDiff.derivative(x -> f(a, x), b) EPS
34+
empty!(tp)
35+
end
36+
37+
function test_skip(f, x, tp)
38+
tx = track(x, tp)
39+
y = f(x)
40+
ty = f(tx)
41+
@test ty == y
42+
@test isempty(tp)
43+
end
44+
45+
function test_skip(f, a, b, tp)
46+
ta, tb = track(a, tp), track(b, tp)
47+
c = f(a, b)
48+
tc = f(ta, tb)
49+
@test tc == c
50+
@test isempty(tp)
51+
end
52+
53+
DOMAIN_ERR_FUNCS = (:asec, :acsc, :asecd, :acscd, :acoth, :acosh)
54+
55+
testprintln("FORWARD_UNARY_SCALAR_FUNCS", "(too many to print)")
56+
for f in RDP.FORWARD_UNARY_SCALAR_FUNCS
57+
n = in(f, DOMAIN_ERR_FUNCS) ? x + 1 : x
58+
test_forward(eval(f), n, tp)
59+
end
60+
61+
testprintln("FORWARD_BINARY_SCALAR_FUNCS", "(too many to print)")
62+
for f in RDP.FORWARD_BINARY_SCALAR_FUNCS
63+
test_forward(eval(f), a, b, tp)
64+
end
65+
66+
INT_ONLY_FUNCS = (:iseven, :isodd)
67+
68+
testprintln("SKIPPED_UNARY_SCALAR_FUNCS", "(too many to print)")
69+
for f in RDP.SKIPPED_UNARY_SCALAR_FUNCS
70+
n = in(f, DOMAIN_ERR_FUNCS) ? x + 1 : x
71+
n = in(f, INT_ONLY_FUNCS) ? ceil(Int, n) : n
72+
test_skip(eval(f), n, tp)
73+
end
74+
75+
testprintln("SKIPPED_BINARY_SCALAR_FUNCS", "(too many to print)")
76+
for f in RDP.SKIPPED_BINARY_SCALAR_FUNCS
77+
test_skip(eval(f), a, b, tp)
78+
end
79+
80+
############################################################################################
81+
82+
println("done (took $(toq()) seconds)")
83+
84+
end # module

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ include(joinpath(TESTDIR, "TapeTests.jl"))
66
include(joinpath(TESTDIR, "TrackedTests.jl"))
77
include(joinpath(TESTDIR, "UtilsTests.jl"))
88
include(joinpath(TESTDIR, "derivatives/MacrosTests.jl"))
9+
include(joinpath(TESTDIR, "derivatives/ScalarTests.jl"))
910
include(joinpath(TESTDIR, "api/OptionsTests.jl"))
1011
include(joinpath(TESTDIR, "api/GradientTests.jl"))
1112
include(joinpath(TESTDIR, "api/JacobianTests.jl"))

0 commit comments

Comments
 (0)