diff --git a/src/chain_rules.jl b/src/chain_rules.jl index fdeb75d8..2e3b9919 100644 --- a/src/chain_rules.jl +++ b/src/chain_rules.jl @@ -18,10 +18,10 @@ end function ChainRulesCore.frule((_, Δp, _), ::typeof(differentiate), p, x) return differentiate(p, x), differentiate(Δp, x) end -function pullback(Δdpdx, x) +function pullback_differentiate_polynomial(Δdpdx, x) return ChainRulesCore.NoTangent(), x * differentiate(x * Δdpdx, x), ChainRulesCore.NoTangent() end -function ChainRulesCore.rrule(::typeof(differentiate), p, x) +function ChainRulesCore.rrule(::typeof(differentiate), p::APL, x) dpdx = differentiate(p, x) - return dpdx, Base.Fix2(pullback, x) + return dpdx, Base.Fix2(pullback_differentiate_polynomial, x) end