Skip to content

Commit 69dd923

Browse files
authored
use istracked instead of pure type checking to decide whether to update partials cache for pow/div broadcast methods (#37)
1 parent c8c47d6 commit 69dd923

File tree

1 file changed

+8
-12
lines changed

1 file changed

+8
-12
lines changed

src/derivatives/elementwise.jl

+8-12
Original file line numberDiff line numberDiff line change
@@ -433,9 +433,7 @@ denom_partials(n, d) = broadcast(denom_partials_kernel, n, d)
433433
denom_partials!(out::Ref, n, d) = (out[] = denom_partials_kernel(n, d); nothing)
434434
denom_partials!(out::AbstractArray, n, d) = (broadcast!(denom_partials_kernel, out, n, d); nothing)
435435

436-
rdiv_cache(x::TrackedType, y::TrackedType) = (numer_partials(value(y)), denom_partials(value(x), value(y)))
437-
rdiv_cache(x::TrackedType, y) = (numer_partials(value(y)), nothing)
438-
rdiv_cache(x, y::TrackedType) = (nothing, denom_partials(value(x), value(y)))
436+
rdiv_cache(x, y) = (numer_partials(value(y)), denom_partials(value(x), value(y)))
439437

440438
function broadcast_rdiv{D}(x, y, ::Type{D})
441439
tp = tape(x, y)
@@ -452,8 +450,8 @@ end
452450
pull_value!(a)
453451
pull_value!(b)
454452
broadcast!(/, value(output), a_value, b_value)
455-
!(isa(n_partials, Void)) && numer_partials!(n_partials, b_value)
456-
!(isa(d_partials, Void)) && denom_partials!(d_partials, a_value, b_value)
453+
istracked(a) && numer_partials!(n_partials, b_value)
454+
istracked(b) && denom_partials!(d_partials, a_value, b_value)
457455
return nothing
458456
end
459457

@@ -485,8 +483,8 @@ end
485483
pull_value!(a)
486484
pull_value!(b)
487485
broadcast!(\, value(output), a_value, b_value)
488-
!(isa(n_partials, Void)) && numer_partials!(n_partials, a_value)
489-
!(isa(d_partials, Void)) && denom_partials!(d_partials, b_value, a_value)
486+
istracked(b) && numer_partials!(n_partials, a_value)
487+
istracked(a) && denom_partials!(d_partials, b_value, a_value)
490488
return nothing
491489
end
492490

@@ -515,9 +513,7 @@ exp_partials(b, e) = broadcast(exp_partials_kernel, b, e)
515513
exp_partials!(out::Ref, b, e) = (out[] = exp_partials_kernel(b, e); nothing)
516514
exp_partials!(out::AbstractArray, b, e) = (broadcast!(exp_partials_kernel, out, b, e); nothing)
517515

518-
pow_cache(x::TrackedType, y::TrackedType) = (base_partials(value(x), value(y)), exp_partials(value(x), value(y)))
519-
pow_cache(x::TrackedType, y) = (base_partials(value(x), value(y)), nothing)
520-
pow_cache(x, y::TrackedType) = (nothing, exp_partials(value(x), value(y)))
516+
pow_cache(x, y) = (base_partials(value(x), value(y)), exp_partials(value(x), value(y)))
521517

522518
function broadcast_pow{D}(x, y, ::Type{D})
523519
tp = tape(x, y)
@@ -534,8 +530,8 @@ end
534530
pull_value!(a)
535531
pull_value!(b)
536532
broadcast!(^, value(output), a_value, b_value)
537-
!(isa(bs_partials, Void)) && base_partials!(bs_partials, a_value, b_value)
538-
!(isa(ex_partials, Void)) && exp_partials!(ex_partials, a_value, b_value)
533+
istracked(a) && base_partials!(bs_partials, a_value, b_value)
534+
istracked(b) && exp_partials!(ex_partials, a_value, b_value)
539535
return nothing
540536
end
541537

0 commit comments

Comments
 (0)