Skip to content

Commit b3cc86e

Browse files
committed
add MacrosTests
1 parent 734739c commit b3cc86e

File tree

3 files changed

+213
-6
lines changed

3 files changed

+213
-6
lines changed

src/derivatives/macros.jl

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,28 @@ works for the following formats:
55
- `@forward f = (args...) -> ...`
66
=#
77
function annotate_func_expr(typesym, expr)
8-
if isa(expr, Expr) && expr.head == :(=)
8+
if isa(expr, Expr) && (expr.head == :(=) || expr.head == :function)
99
lhs = expr.args[1]
1010
if isa(lhs, Expr) && lhs.head == :call # named function definition site
11-
name = lhs.args[1]
12-
hidden_name = gensym(name)
13-
lhs.args[1] = hidden_name
11+
name_and_types = lhs.args[1]
12+
args_signature = lhs.args[2:end]
13+
old_name_and_types = deepcopy(name_and_types)
14+
if isa(name_and_types, Expr) && name_and_types.head == :curly
15+
name = name_and_types.args[1]
16+
hidden_name = gensym(name)
17+
name_and_types.args[1] = hidden_name
18+
19+
elseif isa(name_and_types, Symbol)
20+
name = name_and_types
21+
hidden_name = gensym(name)
22+
lhs.args[1] = hidden_name
23+
else
24+
error("potentially malformed function signature: $(signature)")
25+
end
1426
return quote
1527
$expr
16-
@inline function $(name)(args...)
17-
return ReverseDiffPrototype.$(typesym)($(hidden_name))(args...)
28+
@inline function $(old_name_and_types)($(args_signature...))
29+
return ReverseDiffPrototype.$(typesym)($(hidden_name))($(args_signature...))
1830
end
1931
end
2032
elseif isa(lhs, Symbol) # variable assignment site

test/derivatives/MacrosTests.jl

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
# module MacrosTests
2+
3+
using ReverseDiffPrototype, Base.Test
4+
using ForwardDiff: Dual, partials
5+
6+
include("../utils.jl")
7+
8+
println("testing macros (@forward, @skip, etc.)...")
9+
tic()
10+
11+
############################################################################################
12+
13+
tp = Tape()
14+
x, a, b = rand(3)
15+
16+
############
17+
# @forward #
18+
############
19+
20+
f0(x) = 1. / (1. + exp(-x))
21+
f0(a, b) = sqrt(a^2 + b^2)
22+
23+
RDP.@forward f1{T<:Real}(x::T) = 1. / (1. + exp(-x))
24+
RDP.@forward f1{A,B<:Real}(a::A, b::B) = sqrt(a^2 + b^2)
25+
26+
RDP.@forward f2(x) = 1. / (1. + exp(-x))
27+
RDP.@forward f2(a, b) = sqrt(a^2 + b^2)
28+
29+
RDP.@forward function f3{T<:Real}(x::T)
30+
return 1. / (1. + exp(-x))
31+
end
32+
33+
RDP.@forward function f3{A,B<:Real}(a::A, b::B)
34+
return sqrt(a^2 + b^2)
35+
end
36+
37+
RDP.@forward function f4(x)
38+
return 1. / (1. + exp(-x))
39+
end
40+
41+
RDP.@forward function f4(a, b)
42+
return sqrt(a^2 + b^2)
43+
end
44+
45+
function test_forward(f, x, tp)
46+
tx = track(x, tp)
47+
48+
y = f(x)
49+
@test isempty(tp)
50+
51+
ty = f(tx)
52+
@test ty == y
53+
dual = f(Dual(x, one(x)))
54+
@test length(tp) == 1
55+
node = first(tp)
56+
@test node.func === nothing
57+
@test node.inputs === tx
58+
@test node.outputs === ty
59+
@test node.cache === partials(dual)
60+
empty!(tp)
61+
end
62+
63+
function test_forward(f, a, b, tp)
64+
ta, tb = track(a, tp), track(b, tp)
65+
66+
c = f(a, b)
67+
@test isempty(tp)
68+
69+
tc = f(ta, b)
70+
@test tc == c
71+
dual = f(Dual(a, one(a)), b)
72+
@test length(tp) == 1
73+
node = first(tp)
74+
@test node.func === nothing
75+
@test node.inputs === ta
76+
@test node.outputs === tc
77+
@test node.cache === partials(dual)
78+
empty!(tp)
79+
80+
tc = f(a, tb)
81+
@test tc == c
82+
dual = f(a, Dual(b, one(b)))
83+
@test length(tp) == 1
84+
node = first(tp)
85+
@test node.func === nothing
86+
@test node.inputs === tb
87+
@test node.outputs === tc
88+
@test node.cache === partials(dual)
89+
empty!(tp)
90+
91+
tc = f(ta, tb)
92+
@test tc == c
93+
dual = f(Dual(a, one(a), zero(a)), Dual(b, zero(b), one(b)))
94+
@test length(tp) == 1
95+
node = first(tp)
96+
@test node.func === nothing
97+
@test node.inputs === (ta, tb)
98+
@test node.outputs === tc
99+
@test node.cache === partials(dual)
100+
empty!(tp)
101+
end
102+
103+
for f in (RDP.@forward(f0), f1, f2, f3, f4)
104+
testprintln("@forward named functions", f)
105+
test_forward(f, x, tp)
106+
test_forward(f, a, b, tp)
107+
end
108+
109+
RDP.@forward f5 = (x) -> 1. / (1. + exp(-x))
110+
testprintln("@forward anonymous functions", f5)
111+
test_forward(f5, x, tp)
112+
113+
RDP.@forward f6 = (a, b) -> sqrt(a^2 + b^2)
114+
testprintln("@forward anonymous functions", f6)
115+
test_forward(f6, a, b, tp)
116+
117+
#########
118+
# @skip #
119+
#########
120+
121+
g0 = f0
122+
123+
RDP.@skip g1{T<:Real}(x::T) = 1. / (1. + exp(-x))
124+
RDP.@skip g1{A,B<:Real}(a::A, b::B) = sqrt(a^2 + b^2)
125+
126+
RDP.@skip g2(x) = 1. / (1. + exp(-x))
127+
RDP.@skip g2(a, b) = sqrt(a^2 + b^2)
128+
129+
RDP.@skip function g3{T<:Real}(x::T)
130+
return 1. / (1. + exp(-x))
131+
end
132+
133+
RDP.@skip function g3{A,B<:Real}(a::A, b::B)
134+
return sqrt(a^2 + b^2)
135+
end
136+
137+
RDP.@skip function g4(x)
138+
return 1. / (1. + exp(-x))
139+
end
140+
141+
RDP.@skip function g4(a, b)
142+
return sqrt(a^2 + b^2)
143+
end
144+
145+
function test_skip(g, x, tp)
146+
tx = track(x, tp)
147+
148+
y = g(x)
149+
@test isempty(tp)
150+
151+
ty = g(tx)
152+
@test ty === y
153+
@test isempty(tp)
154+
end
155+
156+
function test_skip(g, a, b, tp)
157+
ta, tb = track(a, tp), track(b, tp)
158+
159+
c = g(a, b)
160+
@test isempty(tp)
161+
162+
tc = g(ta, b)
163+
@test tc === c
164+
@test isempty(tp)
165+
166+
tc = g(a, tb)
167+
@test tc === c
168+
@test isempty(tp)
169+
170+
tc = g(ta, tb)
171+
@test tc === c
172+
@test isempty(tp)
173+
end
174+
175+
for g in (RDP.@skip(g0), g1, g2, g3, g4)
176+
testprintln("@skip named functions", g)
177+
test_skip(g, x, tp)
178+
test_skip(g, a, b, tp)
179+
end
180+
181+
RDP.@skip g5 = (x) -> 1. / (1. + exp(-x))
182+
testprintln("@skip anonymous functions", g5)
183+
test_skip(g5, x, tp)
184+
185+
RDP.@skip g6 = (a, b) -> sqrt(a^2 + b^2)
186+
testprintln("@skip anonymous functions", g6)
187+
test_skip(g6, a, b, tp)
188+
189+
############################################################################################
190+
191+
println("done (took $(toq()) seconds)")
192+
193+
# end # module

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ testprintln(kind, f, pad = " ") = println(pad, "testing $(kind): `$(f)`...")
55
include(joinpath(TESTDIR, "TapeTests.jl"))
66
include(joinpath(TESTDIR, "TrackedTests.jl"))
77
include(joinpath(TESTDIR, "UtilsTests.jl"))
8+
include(joinpath(TESTDIR, "derivatives/MacrosTests.jl"))
9+
include(joinpath(TESTDIR, "api/OptionsTests.jl"))
810
include(joinpath(TESTDIR, "api/GradientTests.jl"))
911
include(joinpath(TESTDIR, "api/JacobianTests.jl"))
1012
include(joinpath(TESTDIR, "api/HessianTests.jl"))

0 commit comments

Comments
 (0)