Skip to content

Commit 2c3df71

Browse files
committed
Clarify issue with scalar products
1 parent 0993510 commit 2c3df71

File tree

2 files changed

+41
-22
lines changed

2 files changed

+41
-22
lines changed

src/chain_rules.jl

+13-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
# The publlback depends on the scalar product on the polynomials
2+
# With the scalar product `LinearAlgebra.dot(p, q) = p * q`, there is no pullback for `differentiate`
3+
# With the scalar product `_dot(p, q)` of `test/chain_rules.jl`, there is a pullback for `differentiate`
4+
# and the pullback for `*` changes.
5+
# We give the one for the scalar product `_dot`.
6+
17
import ChainRulesCore
28

39
ChainRulesCore.@scalar_rule +(x::APL) true
@@ -22,7 +28,7 @@ function ChainRulesCore.frule((_, Δp, Δq), ::typeof(*), p::APL, q::APL)
2228
return p * q, MA.add_mul!!(p * Δq, q, Δp)
2329
end
2430

25-
function _adjoint_mult(op::F, ts, p, Δ) where {F<:Function}
31+
function _mult_pullback(op::F, ts, p, Δ) where {F<:Function}
2632
for t in terms(p)
2733
c = coefficient(t)
2834
m = monomial(t)
@@ -38,20 +44,23 @@ function _adjoint_mult(op::F, ts, p, Δ) where {F<:Function}
3844
end
3945
function adjoint_mult_left(p, Δ)
4046
ts = MA.promote_operation(*, MA.promote_operation(adjoint, termtype(p)), termtype(Δ))[]
41-
return _adjoint_mult(ts, p, Δ) do c, d
47+
return _mult_pullback(ts, p, Δ) do c, d
4248
c' * d
4349
end
4450
end
4551
function adjoint_mult_right(p, Δ)
4652
ts = MA.promote_operation(*, termtype(Δ), MA.promote_operation(adjoint, termtype(p)))[]
47-
return _adjoint_mult(ts, p, Δ) do c, d
53+
return _mult_pullback(ts, p, Δ) do c, d
4854
d * c'
4955
end
5056
end
5157

5258
function ChainRulesCore.rrule(::typeof(*), p::APL, q::APL)
5359
function times_pullback2(ΔΩ̇)
60+
# This is for the scalar product `_dot`:
5461
return (ChainRulesCore.NoTangent(), adjoint_mult_right(q, ΔΩ̇), adjoint_mult_left(p, ΔΩ̇))
62+
# For the scalar product `dot`, it would be instead:
63+
return (ChainRulesCore.NoTangent(), ΔΩ̇ * q', p' * ΔΩ̇)
5564
end
5665
return p * q, times_pullback2
5766
end
@@ -82,6 +91,7 @@ end
8291
function ChainRulesCore.frule((_, Δp, _), ::typeof(differentiate), p, x)
8392
return differentiate(p, x), differentiate(Δp, x)
8493
end
94+
# This is for the scalar product `_dot`, there is no pullback for the scalar product `dot`
8595
function differentiate_pullback(Δdpdx, x)
8696
return ChainRulesCore.NoTangent(), x * differentiate(x * Δdpdx, x), ChainRulesCore.NoTangent()
8797
end

test/chain_rules.jl

+28-19
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,20 @@ function test_chain_rule(dot, op, args, Δin, Δout)
1212
@test dot(Δin, rΔin[2:end]) dot(fΔout, Δout)
1313
end
1414

15+
function _dot(p, q)
16+
monos = monovec([monomials(p); monomials(q)])
17+
return dot(coefficient.(p, monos), coefficient.(q, monos))
18+
end
19+
function _dot(px::Tuple, qx::Tuple)
20+
return _dot(first(px), first(qx)) + _dot(Base.tail(px), Base.tail(qx))
21+
end
22+
function _dot(::Tuple{}, ::Tuple{})
23+
return MultivariatePolynomials.MA.Zero()
24+
end
25+
function _dot(::NoTangent, ::NoTangent)
26+
return MultivariatePolynomials.MA.Zero()
27+
end
28+
1529
@testset "ChainRulesCore" begin
1630
Mod.@polyvar x y
1731
p = 1.1x + y
@@ -42,30 +56,25 @@ end
4256
@test pullback(q) == (NoTangent(), (-0.2 + 2im) * x^2 - x*y, NoTangent())
4357
@test pullback(1x) == (NoTangent(), 2x^2, NoTangent())
4458

45-
test_chain_rule(dot, +, (p,), (q,), p)
46-
test_chain_rule(dot, +, (q,), (p,), q)
59+
for d in [dot, _dot]
60+
test_chain_rule(d, +, (p,), (q,), p)
61+
test_chain_rule(d, +, (q,), (p,), q)
4762

48-
test_chain_rule(dot, -, (p,), (q,), p)
49-
test_chain_rule(dot, -, (p,), (p,), q)
63+
test_chain_rule(d, -, (p,), (q,), p)
64+
test_chain_rule(d, -, (p,), (p,), q)
5065

51-
test_chain_rule(dot, +, (p, q), (q, p), p)
52-
test_chain_rule(dot, +, (p, q), (p, q), q)
66+
test_chain_rule(d, +, (p, q), (q, p), p)
67+
test_chain_rule(d, +, (p, q), (p, q), q)
5368

54-
test_chain_rule(dot, -, (p, q), (q, p), p)
55-
test_chain_rule(dot, -, (p, q), (p, q), q)
69+
test_chain_rule(d, -, (p, q), (q, p), p)
70+
test_chain_rule(d, -, (p, q), (p, q), q)
71+
end
5672

57-
test_chain_rule(dot, *, (p, q), (q, p), p * q)
58-
test_chain_rule(dot, *, (p, q), (p, q), q * q)
59-
test_chain_rule(dot, *, (q, p), (p, q), q * q)
60-
test_chain_rule(dot, *, (p, q), (q, p), q * q)
73+
test_chain_rule(_dot, *, (p, q), (q, p), p * q)
74+
test_chain_rule(_dot, *, (p, q), (p, q), q * q)
75+
test_chain_rule(_dot, *, (q, p), (p, q), q * q)
76+
test_chain_rule(_dot, *, (p, q), (q, p), q * q)
6177

62-
function _dot(p, q)
63-
monos = monomials(p + q)
64-
return dot(coefficient.(p, monos), coefficient.(q, monos))
65-
end
66-
function _dot(px::Tuple{<:AbstractPolynomial,NoTangent}, qx::Tuple{<:AbstractPolynomial,NoTangent})
67-
return _dot(px[1], qx[1])
68-
end
6978
test_chain_rule(_dot, differentiate, (p, x), (q, NoTangent()), p)
7079
test_chain_rule(_dot, differentiate, (p, x), (q, NoTangent()), differentiate(p, x))
7180
test_chain_rule(_dot, differentiate, (p, x), (q, NoTangent()), differentiate(q, x))

0 commit comments

Comments
 (0)