Skip to content

Commit b7bffc1

Browse files
authored
Make use of mul indirection (#1215)
1 parent 4c9e125 commit b7bffc1

File tree

7 files changed

+29
-30
lines changed

7 files changed

+29
-30
lines changed

src/bidiag.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -1189,25 +1189,25 @@ function _dibimul!(C::Bidiagonal, A::Diagonal, B::Bidiagonal, _add)
11891189
C
11901190
end
11911191

1192-
function *(A::UpperOrUnitUpperTriangular, B::Bidiagonal)
1192+
function mul(A::UpperOrUnitUpperTriangular, B::Bidiagonal)
11931193
TS = promote_op(matprod, eltype(A), eltype(B))
11941194
C = mul!(similar(A, TS, size(A)), A, B)
11951195
return B.uplo == 'U' ? UpperTriangular(C) : C
11961196
end
11971197

1198-
function *(A::LowerOrUnitLowerTriangular, B::Bidiagonal)
1198+
function mul(A::LowerOrUnitLowerTriangular, B::Bidiagonal)
11991199
TS = promote_op(matprod, eltype(A), eltype(B))
12001200
C = mul!(similar(A, TS, size(A)), A, B)
12011201
return B.uplo == 'L' ? LowerTriangular(C) : C
12021202
end
12031203

1204-
function *(A::Bidiagonal, B::UpperOrUnitUpperTriangular)
1204+
function mul(A::Bidiagonal, B::UpperOrUnitUpperTriangular)
12051205
TS = promote_op(matprod, eltype(A), eltype(B))
12061206
C = mul!(similar(B, TS, size(B)), A, B)
12071207
return A.uplo == 'U' ? UpperTriangular(C) : C
12081208
end
12091209

1210-
function *(A::Bidiagonal, B::LowerOrUnitLowerTriangular)
1210+
function mul(A::Bidiagonal, B::LowerOrUnitLowerTriangular)
12111211
TS = promote_op(matprod, eltype(A), eltype(B))
12121212
C = mul!(similar(B, TS, size(B)), A, B)
12131213
return A.uplo == 'L' ? LowerTriangular(C) : C

src/diagonal.jl

+7-17
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ Base.literal_pow(::typeof(^), D::Diagonal, valp::Val) =
320320
Diagonal(Base.literal_pow.(^, D.diag, valp)) # for speed
321321
Base.literal_pow(::typeof(^), D::Diagonal, ::Val{-1}) = inv(D) # for disambiguation
322322

323-
function (*)(Da::Diagonal, Db::Diagonal)
323+
function mul(Da::Diagonal, Db::Diagonal)
324324
matmul_size_check(size(Da), size(Db))
325325
return Diagonal(Da.diag .* Db.diag)
326326
end
@@ -330,26 +330,19 @@ function (*)(D::Diagonal, V::AbstractVector)
330330
return D.diag .* V
331331
end
332332

333-
function _diag_adj_mul(A::AdjOrTransAbsMat, D::Diagonal)
333+
function mul(A::AdjOrTransAbsMat, D::Diagonal)
334334
adj = wrapperop(A)
335335
copy(adj(adj(D) * adj(A)))
336336
end
337-
function _diag_adj_mul(A::AdjOrTransAbsMat{<:Number, <:StridedMatrix}, D::Diagonal{<:Number})
338-
@invoke *(A::AbstractMatrix, D::AbstractMatrix)
337+
function mul(A::AdjOrTransAbsMat{<:Number, <:StridedMatrix}, D::Diagonal{<:Number})
338+
@invoke mul(A::AbstractMatrix, D::AbstractMatrix)
339339
end
340-
function _diag_adj_mul(D::Diagonal, A::AdjOrTransAbsMat)
340+
function mul(D::Diagonal, A::AdjOrTransAbsMat)
341341
adj = wrapperop(A)
342342
copy(adj(adj(A) * adj(D)))
343343
end
344-
function _diag_adj_mul(D::Diagonal{<:Number}, A::AdjOrTransAbsMat{<:Number, <:StridedMatrix})
345-
@invoke *(D::AbstractMatrix, A::AbstractMatrix)
346-
end
347-
348-
function (*)(A::AdjOrTransAbsMat, D::Diagonal)
349-
_diag_adj_mul(A, D)
350-
end
351-
function (*)(D::Diagonal, A::AdjOrTransAbsMat)
352-
_diag_adj_mul(D, A)
344+
function mul(D::Diagonal{<:Number}, A::AdjOrTransAbsMat{<:Number, <:StridedMatrix})
345+
@invoke mul(D::AbstractMatrix, A::AbstractMatrix)
353346
end
354347

355348
function rmul!(A::AbstractMatrix, D::Diagonal)
@@ -1088,9 +1081,6 @@ end
10881081
*(x::TransposeAbsVec, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y)
10891082
/(u::AdjointAbsVec, D::Diagonal) = (D' \ u')'
10901083
/(u::TransposeAbsVec, D::Diagonal) = transpose(transpose(D) \ transpose(u))
1091-
# disambiguation methods: Call unoptimized version for user defined AbstractTriangular.
1092-
*(A::AbstractTriangular, D::Diagonal) = @invoke *(A::AbstractMatrix, D::Diagonal)
1093-
*(D::Diagonal, A::AbstractTriangular) = @invoke *(D::Diagonal, A::AbstractMatrix)
10941084

10951085
_opnorm1(A::Diagonal) = maximum(norm(x) for x in A.diag)
10961086
_opnormInf(A::Diagonal) = maximum(norm(x) for x in A.diag)

src/hessenberg.jl

+7-3
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ for T = (:UniformScaling, :Diagonal, :Bidiagonal, :Tridiagonal, :SymTridiagonal,
137137
end
138138
end
139139

140-
for T = (:Number, :UniformScaling, :Diagonal)
140+
for T = (:Number, :UniformScaling)
141141
@eval begin
142142
*(H::UpperHessenberg, x::$T) = UpperHessenberg(H.data * x)
143143
*(x::$T, H::UpperHessenberg) = UpperHessenberg(x * H.data)
@@ -146,19 +146,23 @@ for T = (:Number, :UniformScaling, :Diagonal)
146146
end
147147
end
148148

149-
function *(H::UpperHessenberg, U::UpperOrUnitUpperTriangular)
149+
mul(H::UpperHessenberg, D::Diagonal) = UpperHessenberg(H.data * D)
150+
mul(D::Diagonal, H::UpperHessenberg) = UpperHessenberg(D * H.data)
151+
function mul(H::UpperHessenberg, U::UpperOrUnitUpperTriangular)
150152
HH = mul!(matprod_dest(H, U, promote_op(matprod, eltype(H), eltype(U))), H, U)
151153
UpperHessenberg(HH)
152154
end
153-
function *(U::UpperOrUnitUpperTriangular, H::UpperHessenberg)
155+
function mul(U::UpperOrUnitUpperTriangular, H::UpperHessenberg)
154156
HH = mul!(matprod_dest(U, H, promote_op(matprod, eltype(U), eltype(H))), U, H)
155157
UpperHessenberg(HH)
156158
end
157159

160+
/(H::UpperHessenberg, D::Diagonal) = UpperHessenberg(H.data / D)
158161
function /(H::UpperHessenberg, U::UpperTriangular)
159162
HH = _rdiv!(matprod_dest(H, U, promote_op(/, eltype(H), eltype(U))), H, U)
160163
UpperHessenberg(HH)
161164
end
165+
\(D::Diagonal, H::UpperHessenberg) = UpperHessenberg(D \ H.data)
162166
function /(H::UpperHessenberg, U::UnitUpperTriangular)
163167
HH = _rdiv!(matprod_dest(H, U, promote_op(/, eltype(H), eltype(U))), H, U)
164168
UpperHessenberg(HH)

src/special.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,12 @@ _mul!(C::AbstractMatrix, A::AbstractTriangular, B::BandedMatrix, alpha::Number,
120120
_mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractTriangular, alpha::Number, beta::Number) =
121121
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
122122

123-
function *(H::UpperHessenberg, B::Bidiagonal)
123+
function mul(H::UpperHessenberg, B::Bidiagonal)
124124
T = promote_op(matprod, eltype(H), eltype(B))
125125
A = mul!(similar(H, T, size(H)), H, B)
126126
return B.uplo == 'U' ? UpperHessenberg(A) : A
127127
end
128-
function *(B::Bidiagonal, H::UpperHessenberg)
128+
function mul(B::Bidiagonal, H::UpperHessenberg)
129129
T = promote_op(matprod, eltype(B), eltype(H))
130130
A = mul!(similar(H, T, size(H)), B, H)
131131
return B.uplo == 'U' ? UpperHessenberg(A) : A

src/symmetric.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -702,15 +702,15 @@ for f in (:+, :-)
702702
end
703703
end
704704

705-
*(A::HermOrSym, B::HermOrSym) = A * copyto!(similar(parent(B)), B)
705+
mul(A::HermOrSym, B::HermOrSym) = A * copyto!(similar(parent(B)), B)
706706
# catch a few potential BLAS-cases
707-
function *(A::HermOrSym{<:BlasFloat,<:StridedMatrix}, B::AdjOrTrans{<:BlasFloat,<:StridedMatrix})
707+
function mul(A::HermOrSym{<:BlasFloat,<:StridedMatrix}, B::AdjOrTrans{<:BlasFloat,<:StridedMatrix})
708708
T = promote_type(eltype(A), eltype(B))
709709
mul!(similar(B, T, (size(A, 1), size(B, 2))),
710710
convert(AbstractMatrix{T}, A),
711711
copy_oftype(B, T)) # make sure the AdjOrTrans wrapper is resolved
712712
end
713-
function *(A::AdjOrTrans{<:BlasFloat,<:StridedMatrix}, B::HermOrSym{<:BlasFloat,<:StridedMatrix})
713+
function mul(A::AdjOrTrans{<:BlasFloat,<:StridedMatrix}, B::HermOrSym{<:BlasFloat,<:StridedMatrix})
714714
T = promote_type(eltype(A), eltype(B))
715715
mul!(similar(B, T, (size(A, 1), size(B, 2))),
716716
copy_oftype(A, T), # make sure the AdjOrTrans wrapper is resolved

src/triangular.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1886,7 +1886,7 @@ end
18861886

18871887
## Some Triangular-Triangular cases. We might want to write tailored methods
18881888
## for these cases, but I'm not sure it is worth it.
1889-
for f in (:*, :\)
1889+
for f in (:mul, :\)
18901890
@eval begin
18911891
($f)(A::LowerTriangular, B::LowerTriangular) =
18921892
LowerTriangular(@invoke $f(A::LowerTriangular, B::AbstractMatrix))

test/triangular.jl

+5
Original file line numberDiff line numberDiff line change
@@ -703,6 +703,7 @@ end
703703
@testset "(l/r)mul! and (l/r)div! for generic triangular" begin
704704
@testset for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular)
705705
M = MyTriangular(T(rand(4,4)))
706+
D = Diagonal(randn(4))
706707
A = rand(4,4)
707708
Ac = similar(A)
708709
@testset "lmul!" begin
@@ -725,6 +726,10 @@ end
725726
rdiv!(Ac, M)
726727
@test Ac A / M
727728
end
729+
@testset "diagonal mul" begin
730+
@test D * M D * M.data
731+
@test M * D M.data * D
732+
end
728733
end
729734
end
730735

0 commit comments

Comments
 (0)