Skip to content

Commit 0f8b777

Browse files
authored
add extension for MutableArithmetics (#488)
* add extension for MutableArithmetics * two places
1 parent f6df2a9 commit 0f8b777

File tree

5 files changed

+114
-0
lines changed

5 files changed

+114
-0
lines changed

Project.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,23 @@ version = "3.2.8"
88
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1010
MakieCore = "20f20a25-4f0e-4fdf-b5d1-57303727442b"
11+
MutableArithmetics = "d8a4904e-b15c-11e9-3269-09a3773c0cb0"
1112
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1213

1314
[weakdeps]
1415
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1516
MakieCore = "20f20a25-4f0e-4fdf-b5d1-57303727442b"
17+
MutableArithmetics = "d8a4904e-b15c-11e9-3269-09a3773c0cb0"
1618

1719
[extensions]
1820
PolynomialsChainRulesCoreExt = "ChainRulesCore"
1921
PolynomialsMakieCoreExt = "MakieCore"
22+
PolynomialsMutableArithmeticsExt = "MutableArithmetics"
2023

2124
[compat]
2225
ChainRulesCore = "1"
2326
MakieCore = "0.6"
27+
MutableArithmetics = "1"
2428
RecipesBase = "0.7, 0.8, 1"
2529
julia = "1.6"
2630

@@ -31,6 +35,7 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3135
DualNumbers = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74"
3236
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
3337
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
38+
MutableArithmetics = "d8a4904e-b15c-11e9-3269-09a3773c0cb0"
3439
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
3540
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
3641
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
module PolynomialsMutableArithmeticsExt
2+
3+
using Polynomials
4+
import MutableArithmetics
5+
6+
const MA = MutableArithmetics
7+
8+
function _resize_zeros!(v::Vector, new_len)
9+
old_len = length(v)
10+
if old_len < new_len
11+
resize!(v, new_len)
12+
for i in (old_len + 1):new_len
13+
v[i] = zero(eltype(v))
14+
end
15+
end
16+
end
17+
18+
"""
19+
add_conv(out::Vector{T}, E::Vector{T}, k::Vector{T})
20+
Returns the vector `out + fastconv(E, k)`. Note that only
21+
`MA.buffered_operate!` is implemented.
22+
"""
23+
function add_conv end
24+
25+
# The buffer we need is the buffer needed by the `MA.add_mul` operation.
26+
# For instance, `BigInt`s need a `BigInt` buffer to store `E[x] * k[i]` before
27+
# adding it to `out[j]`.
28+
function MA.buffer_for(::typeof(add_conv), ::Type{Vector{T}}, ::Type{Vector{T}}, ::Type{Vector{T}}) where {T}
29+
return MA.buffer_for(MA.add_mul, T, T, T)
30+
end
31+
32+
function MA.buffered_operate!(buffer, ::typeof(add_conv), out::Vector{T}, E::Vector{T}, k::Vector{T}) where {T}
33+
for x in eachindex(E)
34+
for i in eachindex(k)
35+
j = x + i - 1
36+
out[j] = MA.buffered_operate!(buffer, MA.add_mul, out[j], E[x], k[i])
37+
end
38+
end
39+
return out
40+
end
41+
42+
"""
43+
@register_mutable_arithmetic
44+
Register polynomial type (with vector based backend) to work with MutableArithmetics
45+
"""
46+
macro register_mutable_arithmetic(name)
47+
poly = esc(name)
48+
quote
49+
MA.mutability(::Type{<:$poly}) = MA.IsMutable()
50+
51+
function MA.promote_operation(::Union{typeof(+), typeof(*)},
52+
::Type{$poly{S,X}}, ::Type{$poly{T,X}}) where {S,T,X}
53+
R = promote_type(S,T)
54+
return $poly{R,X}
55+
end
56+
57+
function MA.buffer_for(::typeof(MA.add_mul),
58+
::Type{<:$poly{T,X}},
59+
::Type{<:$poly{T,X}}, ::Type{<:$poly{T,X}}) where {T,X}
60+
V = Vector{T}
61+
return MA.buffer_for(add_conv, V, V, V)
62+
end
63+
64+
function MA.buffered_operate!(buffer, ::typeof(MA.add_mul),
65+
p::$poly, q::$poly, r::$poly)
66+
ps, qs, rs = coeffs(p), coeffs(q), coeffs(r)
67+
_resize_zeros!(ps, length(qs) + length(rs) - 1)
68+
MA.buffered_operate!(buffer, add_conv, ps, qs, rs)
69+
return p
70+
end
71+
end
72+
end
73+
74+
@register_mutable_arithmetic Polynomials.Polynomial
75+
@register_mutable_arithmetic Polynomials.PnPolynomial
76+
77+
## Ambiguities. Issue #435
78+
#Base.:+(p::P, ::MutableArithmetics.Zero) where {T, X, P<:Polynomials.AbstractPolynomial{T, X}} = p
79+
#Base.:+(p::P, ::T) where {T<:MutableArithmetics.Zero, P<:Polynomials.StandardBasisPolynomial{T}} = p
80+
81+
end

src/Polynomials.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ include("polynomials/Poly.jl")
3838
if !isdefined(Base, :get_extension)
3939
include("../ext/PolynomialsChainRulesCoreExt.jl")
4040
include("../ext/PolynomialsMakieCoreExt.jl")
41+
include("../ext/PolynomialsMutableArithmeticsExt.jl")
4142
end
4243

4344
include("precompiles.jl")

test/mutable-arithmetics.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import MutableArithmetics
2+
const MA = MutableArithmetics
3+
4+
function alloc_test(f, n)
5+
f() # compile
6+
@test n == @allocated f()
7+
end
8+
9+
10+
@testset "PolynomialsMutableArithmetics.jl" begin
11+
d = m = n = 4
12+
p(d) = Polynomial(big.(1:d))
13+
z(d) = Polynomial([zero(BigInt) for i in 1:d])
14+
A = [p(d) for i in 1:m, j in 1:n]
15+
b = [p(d) for i in 1:n]
16+
c = [z(2d - 1) for i in 1:m]
17+
buffer = MA.buffer_for(MA.add_mul, typeof(c), typeof(A), typeof(b))
18+
@test buffer isa BigInt
19+
c = [z(2d - 1) for i in 1:m]
20+
MA.buffered_operate!(buffer, MA.add_mul, c, A, b)
21+
@test c == A * b
22+
@test c == MA.operate(*, A, b)
23+
@test 0 == @allocated MA.buffered_operate!(buffer, MA.add_mul, c, A, b)
24+
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,6 @@ using OffsetArrays
1212
@testset "ChebyshevT" begin include("ChebyshevT.jl") end
1313
@testset "Rational functions" begin include("rational-functions.jl") end
1414
@testset "Poly, Pade (compatability)" begin include("Poly.jl") end
15+
if VERSION >= v"1.9.0-"
16+
@testset "MutableArithmetics" begin include("mutable-arithmetics.jl") end
17+
end

0 commit comments

Comments
 (0)