@@ -12,6 +12,20 @@ function test_chain_rule(dot, op, args, Δin, Δout)
12
12
@test dot (Δin, rΔin[2 : end ]) ≈ dot (fΔout, Δout)
13
13
end
14
14
15
+ function _dot (p, q)
16
+ monos = monovec ([monomials (p); monomials (q)])
17
+ return dot (coefficient .(p, monos), coefficient .(q, monos))
18
+ end
19
+ function _dot (px:: Tuple , qx:: Tuple )
20
+ return _dot (first (px), first (qx)) + _dot (Base. tail (px), Base. tail (qx))
21
+ end
22
+ function _dot (:: Tuple{} , :: Tuple{} )
23
+ return MultivariatePolynomials. MA. Zero ()
24
+ end
25
+ function _dot (:: NoTangent , :: NoTangent )
26
+ return MultivariatePolynomials. MA. Zero ()
27
+ end
28
+
15
29
@testset " ChainRulesCore" begin
16
30
Mod. @polyvar x y
17
31
p = 1.1 x + y
42
56
@test pullback (q) == (NoTangent (), (- 0.2 + 2im ) * x^ 2 - x* y, NoTangent ())
43
57
@test pullback (1 x) == (NoTangent (), 2 x^ 2 , NoTangent ())
44
58
45
- test_chain_rule (dot, + , (p,), (q,), p)
46
- test_chain_rule (dot, + , (q,), (p,), q)
59
+ for d in [dot, _dot]
60
+ test_chain_rule (d, + , (p,), (q,), p)
61
+ test_chain_rule (d, + , (q,), (p,), q)
47
62
48
- test_chain_rule (dot , - , (p,), (q,), p)
49
- test_chain_rule (dot , - , (p,), (p,), q)
63
+ test_chain_rule (d , - , (p,), (q,), p)
64
+ test_chain_rule (d , - , (p,), (p,), q)
50
65
51
- test_chain_rule (dot , + , (p, q), (q, p), p)
52
- test_chain_rule (dot , + , (p, q), (p, q), q)
66
+ test_chain_rule (d , + , (p, q), (q, p), p)
67
+ test_chain_rule (d , + , (p, q), (p, q), q)
53
68
54
- test_chain_rule (dot, - , (p, q), (q, p), p)
55
- test_chain_rule (dot, - , (p, q), (p, q), q)
69
+ test_chain_rule (d, - , (p, q), (q, p), p)
70
+ test_chain_rule (d, - , (p, q), (p, q), q)
71
+ end
56
72
57
- test_chain_rule (dot , * , (p, q), (q, p), p * q)
58
- test_chain_rule (dot , * , (p, q), (p, q), q * q)
59
- test_chain_rule (dot , * , (q, p), (p, q), q * q)
60
- test_chain_rule (dot , * , (p, q), (q, p), q * q)
73
+ test_chain_rule (_dot , * , (p, q), (q, p), p * q)
74
+ test_chain_rule (_dot , * , (p, q), (p, q), q * q)
75
+ test_chain_rule (_dot , * , (q, p), (p, q), q * q)
76
+ test_chain_rule (_dot , * , (p, q), (q, p), q * q)
61
77
62
- function _dot (p, q)
63
- monos = monomials (p + q)
64
- return dot (coefficient .(p, monos), coefficient .(q, monos))
65
- end
66
- function _dot (px:: Tuple{<:AbstractPolynomial,NoTangent} , qx:: Tuple{<:AbstractPolynomial,NoTangent} )
67
- return _dot (px[1 ], qx[1 ])
68
- end
69
78
test_chain_rule (_dot, differentiate, (p, x), (q, NoTangent ()), p)
70
79
test_chain_rule (_dot, differentiate, (p, x), (q, NoTangent ()), differentiate (p, x))
71
80
test_chain_rule (_dot, differentiate, (p, x), (q, NoTangent ()), differentiate (q, x))
0 commit comments