Skip to content

Commit 0993510

Browse files
blegat and harris-mit committed
Fix rrule for * and add support for constant operations
Co-authored-by: Mitchell Harris <[email protected]>
1 parent c5f360b commit 0993510

File tree

1 file changed

+73
-5
lines changed

1 file changed

+73
-5
lines changed

src/chain_rules.jl

Lines changed: 73 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,95 @@ ChainRulesCore.@scalar_rule +(x::APL) true
44
# Scalar rule for unary minus on an `APL`: the partial with respect to `x` is `-1`.
ChainRulesCore.@scalar_rule -(x::APL) -1
55

66
# Scalar rule for binary plus on two `APL`s: the partials with respect to `x`
# and `y` are both given as `true`.
ChainRulesCore.@scalar_rule +(x::APL, y::APL) (true, true)
7+
# Pullback used by the reverse rule of `plusconstant(p, α)`.
# Returns a 3-tuple: no tangent for the function itself, `Δ` unchanged for the
# polynomial argument, and the coefficient of the constant monomial of `Δ` for
# the constant argument.
function plusconstant1_pullback(Δ)
    constant_part = coefficient(Δ, constantmonomial(Δ))
    return ChainRulesCore.NoTangent(), Δ, constant_part
end
10+
# Reverse rule for `plusconstant(p, α)` with the polynomial in first position.
# Returns the primal value together with `plusconstant1_pullback`.
function ChainRulesCore.rrule(::typeof(plusconstant), p::APL, α)
    Ω = plusconstant(p, α)
    return Ω, plusconstant1_pullback
end
13+
# Pullback used by the reverse rule of `plusconstant(α, p)`.
# Returns a 3-tuple: no tangent for the function itself, the coefficient of the
# constant monomial of `Δ` for the constant argument, and `Δ` unchanged for the
# polynomial argument.
function plusconstant2_pullback(Δ)
    constant_part = coefficient(Δ, constantmonomial(Δ))
    return ChainRulesCore.NoTangent(), constant_part, Δ
end
16+
# Reverse rule for `plusconstant(α, p)` with the polynomial in second position.
# Returns the primal value together with `plusconstant2_pullback`.
function ChainRulesCore.rrule(::typeof(plusconstant), α, p::APL)
    Ω = plusconstant(α, p)
    return Ω, plusconstant2_pullback
end
719
# Scalar rule for binary minus on two `APL`s: the partial with respect to `x`
# is `true` and the partial with respect to `y` is `-1`.
ChainRulesCore.@scalar_rule -(x::APL, y::APL) (true, -1)
820

921
# Forward rule for the product of two polynomials.
# Returns the primal `p * q` and the tangent built from `p * Δq` to which the
# product of `q` and `Δp` is added through `MA.add_mul!!`.
# NOTE(review): the second summand multiplies as `q * Δp`, not `Δp * q`; the two
# only coincide when multiplication commutes — confirm this is intended for
# non-commutative variables.
function ChainRulesCore.frule((_, Δp, Δq), ::typeof(*), p::APL, q::APL)
    Ω = p * q
    ΔΩ = MA.add_mul!!(p * Δq, q, Δp)
    return Ω, ΔΩ
end
24+
25+
# Shared kernel of `adjoint_mult_left` and `adjoint_mult_right`.
#
# For every term `t` of `p` and every element `δ` obtained by iterating `Δ`,
# whenever `divides(monomial(t), δ)` holds, a new term is pushed to `ts` whose
# coefficient is `op(coefficient(t), coefficient(δ))` and whose monomial is
# `_div(monomial(δ), monomial(t))`. The collected terms are then turned into a
# polynomial with `polynomial(ts)`.
#
# `op` is the first argument so that callers can use `do`-block syntax; `ts` is
# mutated in place.
function _adjoint_mult(op::F, ts, p, Δ) where {F<:Function}
    for t in terms(p)
        c, m = coefficient(t), monomial(t)
        for δ in Δ
            # Skip the pairs for which `m` does not divide `δ`.
            divides(m, δ) || continue
            push!(ts, term(op(c, coefficient(δ)), _div(monomial(δ), m)))
        end
    end
    return polynomial(ts)
end
39+
# Runs `_adjoint_mult` on `p` and `Δ`, combining coefficients as `c' * d`
# (adjoint of the coefficient from `p` on the left).
# The term buffer is an empty vector whose element type is the promoted type of
# the product of an adjoint term of `p` with a term of `Δ`.
function adjoint_mult_left(p, Δ)
    adjoint_term = MA.promote_operation(adjoint, termtype(p))
    product_term = MA.promote_operation(*, adjoint_term, termtype(Δ))
    return _adjoint_mult((c, d) -> c' * d, product_term[], p, Δ)
end
45+
# Runs `_adjoint_mult` on `p` and `Δ`, combining coefficients as `d * c'`
# (adjoint of the coefficient from `p` on the right).
# The term buffer is an empty vector whose element type is the promoted type of
# the product of a term of `Δ` with an adjoint term of `p`.
function adjoint_mult_right(p, Δ)
    cotangent_term = termtype(Δ)
    adjoint_term = MA.promote_operation(adjoint, termtype(p))
    product_term = MA.promote_operation(*, cotangent_term, adjoint_term)
    return _adjoint_mult((c, d) -> d * c', product_term[], p, Δ)
end
51+
1252
# Reverse rule for the product of two polynomials.
# Returns the primal `p * q` and a pullback that maps the output cotangent to
# `(NoTangent(), adjoint_mult_right(q, ΔΩ), adjoint_mult_left(p, ΔΩ))`.
function ChainRulesCore.rrule(::typeof(*), p::APL, q::APL)
    function times_pullback2(ΔΩ)
        ∂p = adjoint_mult_right(q, ΔΩ)
        ∂q = adjoint_mult_left(p, ΔΩ)
        return (ChainRulesCore.NoTangent(), ∂p, ∂q)
    end
    return p * q, times_pullback2
end
2058

59+
# Reverse rule for `multconstant(α, p)` with the constant in first position.
# The cotangent of `α` is the coefficient of the constant monomial of
# `adjoint_mult_right(p, ΔΩ)`; the cotangent of `p` is `α' * ΔΩ`.
function ChainRulesCore.rrule(::typeof(multconstant), α, p::APL)
    function times_pullback2(ΔΩ)
        # TODO we could make it faster, don't need to compute `Δα` entirely if we only care about the constant term.
        Δα = adjoint_mult_right(p, ΔΩ)
        ∂α = coefficient(Δα, constantmonomial(Δα))
        return (ChainRulesCore.NoTangent(), ∂α, α' * ΔΩ)
    end
    return multconstant(α, p), times_pullback2
end
67+
68+
# Reverse rule for `multconstant(p, α)` with the constant in second position.
# The cotangent of `p` is `ΔΩ * α'`; the cotangent of `α` is the coefficient of
# the constant monomial of `adjoint_mult_left(p, ΔΩ)`.
function ChainRulesCore.rrule(::typeof(multconstant), p::APL, α)
    function times_pullback2(ΔΩ)
        # TODO we could make it faster, don't need to compute `Δα` entirely if we only care about the constant term.
        Δα = adjoint_mult_left(p, ΔΩ)
        ∂p = ΔΩ * α'
        ∂α = coefficient(Δα, constantmonomial(Δα))
        return (ChainRulesCore.NoTangent(), ∂p, ∂α)
    end
    return multconstant(p, α), times_pullback2
end
76+
77+
# Pullback that ignores its cotangent `Δ` and returns three `NoTangent()`s.
function notangent3(Δ)
    no = ChainRulesCore.NoTangent()
    return no, no, no
end
78+
# Reverse rule for raising a monomial-like object to an integer power.
# Returns the primal `mono^i` with `notangent3` as pullback, so no tangent is
# propagated to any of the arguments.
function ChainRulesCore.rrule(::typeof(^), mono::AbstractMonomialLike, i::Integer)
    power = mono^i
    return power, notangent3
end
81+
2182
# Forward rule for `differentiate(p, x)`.
# The tangent of the result is obtained by differentiating the tangent `Δp`
# with respect to the same `x`; the tangent of `x` is ignored.
function ChainRulesCore.frule((_, Δp, _), ::typeof(differentiate), p, x)
    dpdx = differentiate(p, x)
    Δdpdx = differentiate(Δp, x)
    return dpdx, Δdpdx
end
24-
function pullback(Δdpdx, x)
85+
# Pullback used by the reverse rule of `differentiate(p, x)`.
# The cotangent for `p` is `x * differentiate(x * Δdpdx, x)`; the function
# itself and `x` receive `NoTangent()`.
function differentiate_pullback(Δdpdx, x)
    Δp = x * differentiate(x * Δdpdx, x)
    return ChainRulesCore.NoTangent(), Δp, ChainRulesCore.NoTangent()
end
2788
# Reverse rule for `differentiate(p, x)`.
# Returns the derivative together with `differentiate_pullback` partially
# applied to `x` as its second argument.
function ChainRulesCore.rrule(::typeof(differentiate), p, x)
    derivative = differentiate(p, x)
    pullback_at_x = Base.Fix2(differentiate_pullback, x)
    return derivative, pullback_at_x
end
92+
93+
# Pullback used by the reverse rule of `coefficient(p, m)`.
# The cotangent for `p` is the polynomial made of the single term `term(Δ, m)`;
# the function itself and `m` receive `NoTangent()`.
function coefficient_pullback(Δ, m::AbstractMonomialLike)
    Δp = polynomial(term(Δ, m))
    return ChainRulesCore.NoTangent(), Δp, ChainRulesCore.NoTangent()
end
96+
# Reverse rule for `coefficient(p, m)`.
# Returns the coefficient together with `coefficient_pullback` partially
# applied to `m` as its second argument.
function ChainRulesCore.rrule(::typeof(coefficient), p::APL, m::AbstractMonomialLike)
    value = coefficient(p, m)
    pullback_at_m = Base.Fix2(coefficient_pullback, m)
    return value, pullback_at_m
end

0 commit comments

Comments
 (0)