Skip to content

Commit f200520

Browse files
authored
add missing elementwise derivative propagation methods (#30)
* add missing elementwise derivative propagation methods * always checkout DiffBase master in tests, in order to grab the most recent test functions * fix typo * fix Jacobian tests for functions with general output dimension
1 parent 689e789 commit f200520

File tree

3 files changed

+71
-36
lines changed

3 files changed

+71
-36
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ notifications:
99
email: false
1010
script:
1111
- if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
12-
- julia -e 'Pkg.clone(pwd()); Pkg.build("ReverseDiff"); Pkg.test("ReverseDiff"; coverage=true)'
12+
- julia -e 'Pkg.clone(pwd()); Pkg.build("ReverseDiff"); Pkg.checkout("DiffBase"); Pkg.test("ReverseDiff"; coverage=true)'
1313
after_success:
1414
# push coverage results to Coveralls
1515
- julia -e 'cd(Pkg.dir("ReverseDiff")); Pkg.add("Coverage"); using Coverage; Coveralls.submit(Coveralls.process_folder())'

src/derivatives/elementwise.jl

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -613,8 +613,10 @@ function broadcast_duals_increment!(input::TrackedReal, output, duals, p::Int)
613613
return nothing
614614
end
615615

616-
# broadcast_deriv_increment! #
617-
#----------------------------#
616+
# broadcast_deriv_increment!/broadcast_deriv_decrement! #
617+
#-------------------------------------------------------#
618+
619+
#######
618620

619621
function broadcast_deriv_increment!(input::TrackedArray, output)
620622
max_input_index = max_leftover_index(input, output)
@@ -625,6 +627,17 @@ function broadcast_deriv_increment!(input::TrackedArray, output)
625627
return nothing
626628
end
627629

630+
function broadcast_deriv_increment!(input::AbstractArray, output)
631+
max_input_index = max_leftover_index(input, output)
632+
output_deriv = deriv(output)
633+
for i in CartesianRange(size(output))
634+
increment_deriv!(input[min(max_input_index, i)], output_deriv[i])
635+
end
636+
return nothing
637+
end
638+
639+
#######
640+
628641
function broadcast_deriv_decrement!(input::TrackedArray, output)
629642
max_input_index = max_leftover_index(input, output)
630643
input_deriv, output_deriv = deriv(input), deriv(output)
@@ -634,6 +647,17 @@ function broadcast_deriv_decrement!(input::TrackedArray, output)
634647
return nothing
635648
end
636649

650+
function broadcast_deriv_decrement!(input::AbstractArray, output)
651+
max_input_index = max_leftover_index(input, output)
652+
output_deriv = deriv(output)
653+
for i in CartesianRange(size(output))
654+
decrement_deriv!(input[min(max_input_index, i)], output_deriv[i])
655+
end
656+
return nothing
657+
end
658+
659+
#######
660+
637661
function broadcast_deriv_increment!(input::TrackedArray, output, partials::AbstractArray)
638662
max_input_index = max_leftover_index(input, output)
639663
max_partials_index = max_leftover_index(partials, output)
@@ -644,15 +668,18 @@ function broadcast_deriv_increment!(input::TrackedArray, output, partials::Abstr
644668
return nothing
645669
end
646670

647-
function broadcast_deriv_increment!(input::AbstractArray, output, partials::Real)
648-
max_input_index = max_leftover_index(input, output)
671+
function broadcast_deriv_increment!(input::AbstractArray, output, partials::AbstractArray)
672+
max_input_index = max_leftover_index(input, output)
673+
max_partials_index = max_leftover_index(partials, output)
649674
output_deriv = deriv(output)
650675
for i in CartesianRange(size(output))
651-
increment_deriv!(input[min(max_input_index, i)], output_deriv[i] * partials)
676+
increment_deriv!(input[min(max_input_index, i)], output_deriv[i] * partials[min(max_partials_index, i)])
652677
end
653678
return nothing
654679
end
655680

681+
#######
682+
656683
function broadcast_deriv_increment!(input::TrackedArray, output, partials::Real)
657684
max_input_index = max_leftover_index(input, output)
658685
input_deriv, output_deriv = deriv(input), deriv(output)
@@ -662,16 +689,17 @@ function broadcast_deriv_increment!(input::TrackedArray, output, partials::Real)
662689
return nothing
663690
end
664691

665-
function broadcast_deriv_increment!(input::AbstractArray, output, partials::AbstractArray)
666-
max_input_index = max_leftover_index(input, output)
667-
max_partials_index = max_leftover_index(partials, output)
692+
function broadcast_deriv_increment!(input::AbstractArray, output, partials::Real)
693+
max_input_index = max_leftover_index(input, output)
668694
output_deriv = deriv(output)
669695
for i in CartesianRange(size(output))
670-
increment_deriv!(input[min(max_input_index, i)], output_deriv[i] * partials[min(max_partials_index, i)])
696+
increment_deriv!(input[min(max_input_index, i)], output_deriv[i] * partials)
671697
end
672698
return nothing
673699
end
674700

701+
#######
702+
675703
function broadcast_deriv_increment!(input::TrackedReal, output::TrackedArray, partials::AbstractArray)
676704
output_deriv = deriv(output)
677705
for i in eachindex(output_deriv)
@@ -680,6 +708,8 @@ function broadcast_deriv_increment!(input::TrackedReal, output::TrackedArray, pa
680708
return nothing
681709
end
682710

711+
#######
712+
683713
function broadcast_deriv_increment!(input::TrackedReal, output::TrackedArray, partials::Real)
684714
output_deriv = deriv(output)
685715
for i in eachindex(output_deriv)
@@ -688,6 +718,8 @@ function broadcast_deriv_increment!(input::TrackedReal, output::TrackedArray, pa
688718
return nothing
689719
end
690720

721+
#######
722+
691723
function broadcast_deriv_increment!(input::TrackedReal, output::TrackedArray)
692724
output_deriv = deriv(output)
693725
for i in eachindex(output_deriv)
@@ -696,6 +728,8 @@ function broadcast_deriv_increment!(input::TrackedReal, output::TrackedArray)
696728
return nothing
697729
end
698730

731+
#######
732+
699733
function broadcast_deriv_decrement!(input::TrackedReal, output::TrackedArray)
700734
output_deriv = deriv(output)
701735
for i in eachindex(output_deriv)

test/api/JacobianTests.jl

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ tic()
1010
############################################################################################
1111

1212
function test_unary_jacobian(f, x)
13-
test = ForwardDiff.jacobian!(DiffBase.JacobianResult(x), f, x, ForwardDiff.JacobianConfig(x))
13+
test_val = f(x)
14+
test = ForwardDiff.jacobian!(DiffBase.JacobianResult(test_val, x), f, x, ForwardDiff.JacobianConfig(x))
1415

1516
# without JacobianConfig
1617

@@ -20,7 +21,7 @@ function test_unary_jacobian(f, x)
2021
ReverseDiff.jacobian!(out, f, x)
2122
@test_approx_eq_eps out DiffBase.jacobian(test) EPS
2223

23-
result = DiffBase.JacobianResult(x)
24+
result = DiffBase.JacobianResult(test_val, x)
2425
ReverseDiff.jacobian!(result, f, x)
2526
@test_approx_eq_eps DiffBase.value(result) DiffBase.value(test) EPS
2627
@test_approx_eq_eps DiffBase.jacobian(result) DiffBase.jacobian(test) EPS
@@ -35,7 +36,7 @@ function test_unary_jacobian(f, x)
3536
ReverseDiff.jacobian!(out, f, x, cfg)
3637
@test_approx_eq_eps out DiffBase.jacobian(test) EPS
3738

38-
result = DiffBase.JacobianResult(x)
39+
result = DiffBase.JacobianResult(test_val, x)
3940
ReverseDiff.jacobian!(result, f, x, cfg)
4041
@test_approx_eq_eps DiffBase.value(result) DiffBase.value(test) EPS
4142
@test_approx_eq_eps DiffBase.jacobian(result) DiffBase.jacobian(test) EPS
@@ -50,7 +51,7 @@ function test_unary_jacobian(f, x)
5051
ReverseDiff.jacobian!(out, tp, x)
5152
@test_approx_eq_eps out DiffBase.jacobian(test) EPS
5253

53-
result = DiffBase.JacobianResult(x)
54+
result = DiffBase.JacobianResult(test_val, x)
5455
ReverseDiff.jacobian!(result, tp, x)
5556
@test_approx_eq_eps DiffBase.value(result) DiffBase.value(test) EPS
5657
@test_approx_eq_eps DiffBase.jacobian(result) DiffBase.jacobian(test) EPS
@@ -71,12 +72,12 @@ function test_unary_jacobian(f, x)
7172
Jf!(out, x)
7273
@test_approx_eq_eps out DiffBase.jacobian(test) EPS
7374

74-
result = DiffBase.JacobianResult(x)
75+
result = DiffBase.JacobianResult(test_val, x)
7576
ReverseDiff.jacobian!(result, ctp, x)
7677
@test_approx_eq_eps DiffBase.value(result) DiffBase.value(test) EPS
7778
@test_approx_eq_eps DiffBase.jacobian(result) DiffBase.jacobian(test) EPS
7879

79-
result = DiffBase.JacobianResult(x)
80+
result = DiffBase.JacobianResult(test_val, x)
8081
Jf!(result, x)
8182
@test_approx_eq_eps DiffBase.value(result) DiffBase.value(test) EPS
8283
@test_approx_eq_eps DiffBase.jacobian(result) DiffBase.jacobian(test) EPS
@@ -186,14 +187,14 @@ function test_binary_jacobian(f, a, b)
186187
@test_approx_eq_eps Ja test_a EPS
187188
@test_approx_eq_eps Jb test_b EPS
188189

189-
Ja = similar(a, length(a), length(b))
190-
Jb = copy(Ja)
190+
Ja = similar(a, length(test_val), length(a))
191+
Jb = similar(b, length(test_val), length(b))
191192
ReverseDiff.jacobian!((Ja, Jb), f, (a, b))
192193
@test_approx_eq_eps Ja test_a EPS
193194
@test_approx_eq_eps Jb test_b EPS
194195

195-
Ja = DiffBase.JacobianResult(a, b)
196-
Jb = copy(Ja)
196+
Ja = DiffBase.JacobianResult(test_val, a)
197+
Jb = DiffBase.JacobianResult(test_val, b)
197198
ReverseDiff.jacobian!((Ja, Jb), f, (a, b))
198199
@test_approx_eq_eps DiffBase.value(Ja) test_val EPS
199200
@test_approx_eq_eps DiffBase.value(Jb) test_val EPS
@@ -208,14 +209,14 @@ function test_binary_jacobian(f, a, b)
208209
@test_approx_eq_eps Ja test_a EPS
209210
@test_approx_eq_eps Jb test_b EPS
210211

211-
Ja = similar(a, length(a), length(b))
212-
Jb = copy(Ja)
212+
Ja = similar(a, length(test_val), length(a))
213+
Jb = similar(b, length(test_val), length(b))
213214
ReverseDiff.jacobian!((Ja, Jb), f, (a, b), cfg)
214215
@test_approx_eq_eps Ja test_a EPS
215216
@test_approx_eq_eps Jb test_b EPS
216217

217-
Ja = DiffBase.JacobianResult(a, b)
218-
Jb = copy(Ja)
218+
Ja = DiffBase.JacobianResult(test_val, a)
219+
Jb = DiffBase.JacobianResult(test_val, b)
219220
ReverseDiff.jacobian!((Ja, Jb), f, (a, b), cfg)
220221
@test_approx_eq_eps DiffBase.value(Ja) test_val EPS
221222
@test_approx_eq_eps DiffBase.value(Jb) test_val EPS
@@ -230,14 +231,14 @@ function test_binary_jacobian(f, a, b)
230231
@test_approx_eq_eps Ja test_a EPS
231232
@test_approx_eq_eps Jb test_b EPS
232233

233-
Ja = similar(a, length(a), length(b))
234-
Jb = copy(Ja)
234+
Ja = similar(a, length(test_val), length(a))
235+
Jb = similar(b, length(test_val), length(b))
235236
ReverseDiff.jacobian!((Ja, Jb), tp, (a, b))
236237
@test_approx_eq_eps Ja test_a EPS
237238
@test_approx_eq_eps Jb test_b EPS
238239

239-
Ja = DiffBase.JacobianResult(a, b)
240-
Jb = copy(Ja)
240+
Ja = DiffBase.JacobianResult(test_val, a)
241+
Jb = DiffBase.JacobianResult(test_val, b)
241242
ReverseDiff.jacobian!((Ja, Jb), tp, (a, b))
242243
@test_approx_eq_eps DiffBase.value(Ja) test_val EPS
243244
@test_approx_eq_eps DiffBase.value(Jb) test_val EPS
@@ -254,28 +255,28 @@ function test_binary_jacobian(f, a, b)
254255
@test_approx_eq_eps Ja test_a EPS
255256
@test_approx_eq_eps Jb test_b EPS
256257

257-
Ja = similar(a, length(a), length(b))
258-
Jb = copy(Ja)
258+
Ja = similar(a, length(test_val), length(a))
259+
Jb = similar(b, length(test_val), length(b))
259260
ReverseDiff.jacobian!((Ja, Jb), ctp, (a, b))
260261
@test_approx_eq_eps Ja test_a EPS
261262
@test_approx_eq_eps Jb test_b EPS
262263

263-
Ja = similar(a, length(a), length(b))
264-
Jb = copy(Ja)
264+
Ja = similar(a, length(test_val), length(a))
265+
Jb = similar(b, length(test_val), length(b))
265266
Jf!((Ja, Jb), (a, b))
266267
@test_approx_eq_eps Ja test_a EPS
267268
@test_approx_eq_eps Jb test_b EPS
268269

269-
Ja = DiffBase.JacobianResult(a, b)
270-
Jb = copy(Ja)
270+
Ja = DiffBase.JacobianResult(test_val, a)
271+
Jb = DiffBase.JacobianResult(test_val, b)
271272
ReverseDiff.jacobian!((Ja, Jb), ctp, (a, b))
272273
@test_approx_eq_eps DiffBase.value(Ja) test_val EPS
273274
@test_approx_eq_eps DiffBase.value(Jb) test_val EPS
274275
@test_approx_eq_eps DiffBase.gradient(Ja) test_a EPS
275276
@test_approx_eq_eps DiffBase.gradient(Jb) test_b EPS
276277

277-
Ja = DiffBase.JacobianResult(a, b)
278-
Jb = copy(Ja)
278+
Ja = DiffBase.JacobianResult(test_val, a)
279+
Jb = DiffBase.JacobianResult(test_val, b)
279280
Jf!((Ja, Jb), (a, b))
280281
@test_approx_eq_eps DiffBase.value(Ja) test_val EPS
281282
@test_approx_eq_eps DiffBase.value(Jb) test_val EPS

0 commit comments

Comments
 (0)