303
303
# built-in infix operations #
304
304
# ############################
305
305
306
+ typealias TrackedType Union{TrackedArray,TrackedReal}
307
+
306
308
# dispatch #
307
309
# ----------#
308
310
@@ -311,7 +313,7 @@ for (f, broadcast_f) in ((:.+, :broadcast_plus),
311
313
(:.* , :broadcast_mul ),
312
314
(:./ , :broadcast_rdiv ),
313
315
(:.\ , :broadcast_ldiv ),
314
- (:.^ , :broadcast_exp ))
316
+ (:.^ , :broadcast_pow ))
315
317
@eval begin
316
318
@inline Base.$ (f){X,Y,D}(x:: TrackedArray{X,D} , y:: TrackedArray{Y,D} ) = $ (broadcast_f)(x, y, D)
317
319
@inline Base.$ (f){X,Y,D}(x:: TrackedReal{X,D} , y:: TrackedArray{Y,D} ) = $ (broadcast_f)(x, y, D)
@@ -431,13 +433,14 @@ denom_partials(n, d) = broadcast(denom_partials_kernel, n, d)
431
433
denom_partials! (out:: Ref , n, d) = (out[] = denom_partials_kernel (n, d); nothing )
432
434
denom_partials! (out:: AbstractArray , n, d) = (broadcast! (denom_partials_kernel, out, n, d); nothing )
433
435
436
+ rdiv_cache (x:: TrackedType , y:: TrackedType ) = (numer_partials (value (y)), denom_partials (value (x), value (y)))
437
+ rdiv_cache (x:: TrackedType , y) = (numer_partials (value (y)), nothing )
438
+ rdiv_cache (x, y:: TrackedType ) = (nothing , denom_partials (value (x), value (y)))
439
+
434
440
function broadcast_rdiv {D} (x, y, :: Type{D} )
435
441
tp = tape (x, y)
436
442
out = track (value (x) ./ value (y), D, tp)
437
- n_partials = numer_partials (value (y))
438
- d_partials = denom_partials (value (x), value (y))
439
- cache = (n_partials, d_partials)
440
- record! (tp, SpecialInstruction, Base.:(./ ), (x, y), out, cache)
443
+ record! (tp, SpecialInstruction, Base.:(./ ), (x, y), out, rdiv_cache (x, y))
441
444
return out
442
445
end
443
446
449
452
pull_value! (a)
450
453
pull_value! (b)
451
454
broadcast! (/ , value (output), a_value, b_value)
452
- numer_partials! (n_partials, b_value)
453
- denom_partials! (d_partials, a_value, b_value)
455
+ ! ( isa (n_partials, Void)) && numer_partials! (n_partials, b_value)
456
+ ! ( isa (d_partials, Void)) && denom_partials! (d_partials, a_value, b_value)
454
457
return nothing
455
458
end
456
459
470
473
function broadcast_ldiv {D} (x, y, :: Type{D} )
471
474
tp = tape (x, y)
472
475
out = track (value (x) .\ value (y), D, tp)
473
- n_partials = numer_partials (value (x))
474
- d_partials = denom_partials (value (y), value (x))
475
- cache = (n_partials, d_partials)
476
- record! (tp, SpecialInstruction, Base.:(.\ ), (x, y), out, cache)
476
+ record! (tp, SpecialInstruction, Base.:(.\ ), (x, y), out, rdiv_cache (y, x))
477
477
return out
478
478
end
479
479
485
485
pull_value! (a)
486
486
pull_value! (b)
487
487
broadcast! (\ , value (output), a_value, b_value)
488
- numer_partials! (n_partials, a_value)
489
- denom_partials! (d_partials, b_value, a_value)
488
+ ! ( isa (n_partials, Void)) && numer_partials! (n_partials, a_value)
489
+ ! ( isa (d_partials, Void)) && denom_partials! (d_partials, b_value, a_value)
490
490
return nothing
491
491
end
492
492
@@ -515,13 +515,14 @@ exp_partials(b, e) = broadcast(exp_partials_kernel, b, e)
515
515
exp_partials! (out:: Ref , b, e) = (out[] = exp_partials_kernel (b, e); nothing )
516
516
exp_partials! (out:: AbstractArray , b, e) = (broadcast! (exp_partials_kernel, out, b, e); nothing )
517
517
518
- function broadcast_exp {D} (x, y, :: Type{D} )
518
+ pow_cache (x:: TrackedType , y:: TrackedType ) = (base_partials (value (x), value (y)), exp_partials (value (x), value (y)))
519
+ pow_cache (x:: TrackedType , y) = (base_partials (value (x), value (y)), nothing )
520
+ pow_cache (x, y:: TrackedType ) = (nothing , exp_partials (value (x), value (y)))
521
+
522
+ function broadcast_pow {D} (x, y, :: Type{D} )
519
523
tp = tape (x, y)
520
524
out = track (value (x) .^ value (y), D, tp)
521
- bs_partials = base_partials (value (x), value (y))
522
- ex_partials = exp_partials (value (x), value (y))
523
- cache = (bs_partials, ex_partials)
524
- record! (tp, SpecialInstruction, Base.:(.^ ), (x, y), out, cache)
525
+ record! (tp, SpecialInstruction, Base.:(.^ ), (x, y), out, pow_cache (x, y))
525
526
return out
526
527
end
527
528
533
534
pull_value! (a)
534
535
pull_value! (b)
535
536
broadcast! (^ , value (output), a_value, b_value)
536
- base_partials! (bs_partials, a_value, b_value)
537
- exp_partials! (ex_partials, a_value, b_value)
537
+ ! ( isa (bs_partials, Void)) && base_partials! (bs_partials, a_value, b_value)
538
+ ! ( isa (ex_partials, Void)) && exp_partials! (ex_partials, a_value, b_value)
538
539
return nothing
539
540
end
540
541
0 commit comments