@@ -4,27 +4,95 @@ ChainRulesCore.@scalar_rule +(x::APL) true
4
4
ChainRulesCore. @scalar_rule - (x:: APL ) - 1
5
5
6
6
ChainRulesCore. @scalar_rule + (x:: APL , y:: APL ) (true , true )
7
+ function plusconstant1_pullback (Δ)
8
+ return ChainRulesCore. NoTangent (), Δ, coefficient (Δ, constantmonomial (Δ))
9
+ end
10
+ function ChainRulesCore. rrule (:: typeof (plusconstant), p:: APL , α)
11
+ return plusconstant (p, α), plusconstant1_pullback
12
+ end
13
+ function plusconstant2_pullback (Δ)
14
+ return ChainRulesCore. NoTangent (), coefficient (Δ, constantmonomial (Δ)), Δ
15
+ end
16
+ function ChainRulesCore. rrule (:: typeof (plusconstant), α, p:: APL )
17
+ return plusconstant (α, p), plusconstant2_pullback
18
+ end
7
19
ChainRulesCore. @scalar_rule - (x:: APL , y:: APL ) (true , - 1 )
8
20
9
21
function ChainRulesCore. frule ((_, Δp, Δq), :: typeof (* ), p:: APL , q:: APL )
10
22
return p * q, MA. add_mul!! (p * Δq, q, Δp)
11
23
end
24
+
25
+ function _adjoint_mult (op:: F , ts, p, Δ) where {F<: Function }
26
+ for t in terms (p)
27
+ c = coefficient (t)
28
+ m = monomial (t)
29
+ for δ in Δ
30
+ if divides (m, δ)
31
+ coef = op (c, coefficient (δ))
32
+ mono = _div (monomial (δ), m)
33
+ push! (ts, term (coef, mono))
34
+ end
35
+ end
36
+ end
37
+ return polynomial (ts)
38
+ end
39
+ function adjoint_mult_left (p, Δ)
40
+ ts = MA. promote_operation (* , MA. promote_operation (adjoint, termtype (p)), termtype (Δ))[]
41
+ return _adjoint_mult (ts, p, Δ) do c, d
42
+ c' * d
43
+ end
44
+ end
45
+ function adjoint_mult_right (p, Δ)
46
+ ts = MA. promote_operation (* , termtype (Δ), MA. promote_operation (adjoint, termtype (p)))[]
47
+ return _adjoint_mult (ts, p, Δ) do c, d
48
+ d * c'
49
+ end
50
+ end
51
+
12
52
function ChainRulesCore. rrule (:: typeof (* ), p:: APL , q:: APL )
13
53
function times_pullback2 (ΔΩ̇)
14
- # ΔΩ = ChainRulesCore.unthunk(Ω̇)
15
- # return (ChainRulesCore.NoTangent(), ChainRulesCore.ProjectTo(p)(ΔΩ * q'), ChainRulesCore.ProjectTo(q)(p' * ΔΩ))
16
- return (ChainRulesCore. NoTangent (), ΔΩ̇ * q' , p' * ΔΩ̇)
54
+ return (ChainRulesCore. NoTangent (), adjoint_mult_right (q, ΔΩ̇), adjoint_mult_left (p, ΔΩ̇))
17
55
end
18
56
return p * q, times_pullback2
19
57
end
20
58
59
+ function ChainRulesCore. rrule (:: typeof (multconstant), α, p:: APL )
60
+ function times_pullback2 (ΔΩ̇)
61
+ # TODO we could make it faster, don't need to compute `Δα` entirely if we only care about the constant term.
62
+ Δα = adjoint_mult_right (p, ΔΩ̇)
63
+ return (ChainRulesCore. NoTangent (), coefficient (Δα, constantmonomial (Δα)), α' * ΔΩ̇)
64
+ end
65
+ return multconstant (α, p), times_pullback2
66
+ end
67
+
68
+ function ChainRulesCore. rrule (:: typeof (multconstant), p:: APL , α)
69
+ function times_pullback2 (ΔΩ̇)
70
+ # TODO we could make it faster, don't need to compute `Δα` entirely if we only care about the constant term.
71
+ Δα = adjoint_mult_left (p, ΔΩ̇)
72
+ return (ChainRulesCore. NoTangent (), ΔΩ̇ * α' , coefficient (Δα, constantmonomial (Δα)))
73
+ end
74
+ return multconstant (p, α), times_pullback2
75
+ end
76
+
77
+ notangent3 (Δ) = ChainRulesCore. NoTangent (), ChainRulesCore. NoTangent (), ChainRulesCore. NoTangent ()
78
+ function ChainRulesCore. rrule (:: typeof (^ ), mono:: AbstractMonomialLike , i:: Integer )
79
+ return mono^ i, notangent3
80
+ end
81
+
21
82
function ChainRulesCore. frule ((_, Δp, _), :: typeof (differentiate), p, x)
22
83
return differentiate (p, x), differentiate (Δp, x)
23
84
end
24
- function pullback (Δdpdx, x)
85
+ function differentiate_pullback (Δdpdx, x)
25
86
return ChainRulesCore. NoTangent (), x * differentiate (x * Δdpdx, x), ChainRulesCore. NoTangent ()
26
87
end
27
88
function ChainRulesCore. rrule (:: typeof (differentiate), p, x)
28
89
dpdx = differentiate (p, x)
29
- return dpdx, Base. Fix2 (pullback, x)
90
+ return dpdx, Base. Fix2 (differentiate_pullback, x)
91
+ end
92
+
93
+ function coefficient_pullback (Δ, m:: AbstractMonomialLike )
94
+ return ChainRulesCore. NoTangent (), polynomial (term (Δ, m)), ChainRulesCore. NoTangent ()
95
+ end
96
+ function ChainRulesCore. rrule (:: typeof (coefficient), p:: APL , m:: AbstractMonomialLike )
97
+ return coefficient (p, m), Base. Fix2 (coefficient_pullback, m)
30
98
end
0 commit comments