@@ -433,9 +433,7 @@ denom_partials(n, d) = broadcast(denom_partials_kernel, n, d)
433
433
denom_partials! (out:: Ref , n, d) = (out[] = denom_partials_kernel (n, d); nothing )
434
434
denom_partials! (out:: AbstractArray , n, d) = (broadcast! (denom_partials_kernel, out, n, d); nothing )
435
435
436
- rdiv_cache (x:: TrackedType , y:: TrackedType ) = (numer_partials (value (y)), denom_partials (value (x), value (y)))
437
- rdiv_cache (x:: TrackedType , y) = (numer_partials (value (y)), nothing )
438
- rdiv_cache (x, y:: TrackedType ) = (nothing , denom_partials (value (x), value (y)))
436
+ rdiv_cache (x, y) = (numer_partials (value (y)), denom_partials (value (x), value (y)))
439
437
440
438
function broadcast_rdiv {D} (x, y, :: Type{D} )
441
439
tp = tape (x, y)
452
450
pull_value! (a)
453
451
pull_value! (b)
454
452
broadcast! (/ , value (output), a_value, b_value)
455
- ! ( isa (n_partials, Void) ) && numer_partials! (n_partials, b_value)
456
- ! ( isa (d_partials, Void) ) && denom_partials! (d_partials, a_value, b_value)
453
+ istracked (a ) && numer_partials! (n_partials, b_value)
454
+ istracked (b ) && denom_partials! (d_partials, a_value, b_value)
457
455
return nothing
458
456
end
459
457
485
483
pull_value! (a)
486
484
pull_value! (b)
487
485
broadcast! (\ , value (output), a_value, b_value)
488
- ! ( isa (n_partials, Void) ) && numer_partials! (n_partials, a_value)
489
- ! ( isa (d_partials, Void) ) && denom_partials! (d_partials, b_value, a_value)
486
+ istracked (b ) && numer_partials! (n_partials, a_value)
487
+ istracked (a ) && denom_partials! (d_partials, b_value, a_value)
490
488
return nothing
491
489
end
492
490
@@ -515,9 +513,7 @@ exp_partials(b, e) = broadcast(exp_partials_kernel, b, e)
515
513
exp_partials! (out:: Ref , b, e) = (out[] = exp_partials_kernel (b, e); nothing )
516
514
exp_partials! (out:: AbstractArray , b, e) = (broadcast! (exp_partials_kernel, out, b, e); nothing )
517
515
518
- pow_cache (x:: TrackedType , y:: TrackedType ) = (base_partials (value (x), value (y)), exp_partials (value (x), value (y)))
519
- pow_cache (x:: TrackedType , y) = (base_partials (value (x), value (y)), nothing )
520
- pow_cache (x, y:: TrackedType ) = (nothing , exp_partials (value (x), value (y)))
516
+ pow_cache (x, y) = (base_partials (value (x), value (y)), exp_partials (value (x), value (y)))
521
517
522
518
function broadcast_pow {D} (x, y, :: Type{D} )
523
519
tp = tape (x, y)
534
530
pull_value! (a)
535
531
pull_value! (b)
536
532
broadcast! (^ , value (output), a_value, b_value)
537
- ! ( isa (bs_partials, Void) ) && base_partials! (bs_partials, a_value, b_value)
538
- ! ( isa (ex_partials, Void) ) && exp_partials! (ex_partials, a_value, b_value)
533
+ istracked (a ) && base_partials! (bs_partials, a_value, b_value)
534
+ istracked (b ) && exp_partials! (ex_partials, a_value, b_value)
539
535
return nothing
540
536
end
541
537
0 commit comments