Skip to content

Commit c8c47d6

Browse files
authored
avoid unnecessary computation for built-in ^ and / broadcast functions (#33)
1 parent 7377f9a commit c8c47d6

File tree

1 file changed

+21
-20
lines changed

1 file changed

+21
-20
lines changed

src/derivatives/elementwise.jl

+21-20
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,8 @@ end
303303
# built-in infix operations #
304304
#############################
305305

306+
typealias TrackedType Union{TrackedArray,TrackedReal}
307+
306308
# dispatch #
307309
#----------#
308310

@@ -311,7 +313,7 @@ for (f, broadcast_f) in ((:.+, :broadcast_plus),
311313
(:.*, :broadcast_mul),
312314
(:./, :broadcast_rdiv),
313315
(:.\, :broadcast_ldiv),
314-
(:.^, :broadcast_exp))
316+
(:.^, :broadcast_pow))
315317
@eval begin
316318
@inline Base.$(f){X,Y,D}(x::TrackedArray{X,D}, y::TrackedArray{Y,D}) = $(broadcast_f)(x, y, D)
317319
@inline Base.$(f){X,Y,D}(x::TrackedReal{X,D}, y::TrackedArray{Y,D}) = $(broadcast_f)(x, y, D)
@@ -431,13 +433,14 @@ denom_partials(n, d) = broadcast(denom_partials_kernel, n, d)
431433
denom_partials!(out::Ref, n, d) = (out[] = denom_partials_kernel(n, d); nothing)
432434
denom_partials!(out::AbstractArray, n, d) = (broadcast!(denom_partials_kernel, out, n, d); nothing)
433435

436+
rdiv_cache(x::TrackedType, y::TrackedType) = (numer_partials(value(y)), denom_partials(value(x), value(y)))
437+
rdiv_cache(x::TrackedType, y) = (numer_partials(value(y)), nothing)
438+
rdiv_cache(x, y::TrackedType) = (nothing, denom_partials(value(x), value(y)))
439+
434440
function broadcast_rdiv{D}(x, y, ::Type{D})
435441
tp = tape(x, y)
436442
out = track(value(x) ./ value(y), D, tp)
437-
n_partials = numer_partials(value(y))
438-
d_partials = denom_partials(value(x), value(y))
439-
cache = (n_partials, d_partials)
440-
record!(tp, SpecialInstruction, Base.:(./), (x, y), out, cache)
443+
record!(tp, SpecialInstruction, Base.:(./), (x, y), out, rdiv_cache(x, y))
441444
return out
442445
end
443446

@@ -449,8 +452,8 @@ end
449452
pull_value!(a)
450453
pull_value!(b)
451454
broadcast!(/, value(output), a_value, b_value)
452-
numer_partials!(n_partials, b_value)
453-
denom_partials!(d_partials, a_value, b_value)
455+
!(isa(n_partials, Void)) && numer_partials!(n_partials, b_value)
456+
!(isa(d_partials, Void)) && denom_partials!(d_partials, a_value, b_value)
454457
return nothing
455458
end
456459

@@ -470,10 +473,7 @@ end
470473
function broadcast_ldiv{D}(x, y, ::Type{D})
471474
tp = tape(x, y)
472475
out = track(value(x) .\ value(y), D, tp)
473-
n_partials = numer_partials(value(x))
474-
d_partials = denom_partials(value(y), value(x))
475-
cache = (n_partials, d_partials)
476-
record!(tp, SpecialInstruction, Base.:(.\), (x, y), out, cache)
476+
record!(tp, SpecialInstruction, Base.:(.\), (x, y), out, rdiv_cache(y, x))
477477
return out
478478
end
479479

@@ -485,8 +485,8 @@ end
485485
pull_value!(a)
486486
pull_value!(b)
487487
broadcast!(\, value(output), a_value, b_value)
488-
numer_partials!(n_partials, a_value)
489-
denom_partials!(d_partials, b_value, a_value)
488+
!(isa(n_partials, Void)) && numer_partials!(n_partials, a_value)
489+
!(isa(d_partials, Void)) && denom_partials!(d_partials, b_value, a_value)
490490
return nothing
491491
end
492492

@@ -515,13 +515,14 @@ exp_partials(b, e) = broadcast(exp_partials_kernel, b, e)
515515
exp_partials!(out::Ref, b, e) = (out[] = exp_partials_kernel(b, e); nothing)
516516
exp_partials!(out::AbstractArray, b, e) = (broadcast!(exp_partials_kernel, out, b, e); nothing)
517517

518-
function broadcast_exp{D}(x, y, ::Type{D})
518+
pow_cache(x::TrackedType, y::TrackedType) = (base_partials(value(x), value(y)), exp_partials(value(x), value(y)))
519+
pow_cache(x::TrackedType, y) = (base_partials(value(x), value(y)), nothing)
520+
pow_cache(x, y::TrackedType) = (nothing, exp_partials(value(x), value(y)))
521+
522+
function broadcast_pow{D}(x, y, ::Type{D})
519523
tp = tape(x, y)
520524
out = track(value(x) .^ value(y), D, tp)
521-
bs_partials = base_partials(value(x), value(y))
522-
ex_partials = exp_partials(value(x), value(y))
523-
cache = (bs_partials, ex_partials)
524-
record!(tp, SpecialInstruction, Base.:(.^), (x, y), out, cache)
525+
record!(tp, SpecialInstruction, Base.:(.^), (x, y), out, pow_cache(x, y))
525526
return out
526527
end
527528

@@ -533,8 +534,8 @@ end
533534
pull_value!(a)
534535
pull_value!(b)
535536
broadcast!(^, value(output), a_value, b_value)
536-
base_partials!(bs_partials, a_value, b_value)
537-
exp_partials!(ex_partials, a_value, b_value)
537+
!(isa(bs_partials, Void)) && base_partials!(bs_partials, a_value, b_value)
538+
!(isa(ex_partials, Void)) && exp_partials!(ex_partials, a_value, b_value)
538539
return nothing
539540
end
540541

0 commit comments

Comments
 (0)