Skip to content

multiplication specialization for RectDiagonal and DiagonalFill #405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions src/FillArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,7 @@ Base.replace_in_print_matrix(A::RectDiagonal, i::Integer, j::Integer, s::Abstrac
const RectOrDiagonal{T,V,Axes} = Union{RectDiagonal{T,V,Axes}, Diagonal{T,V}}
const RectOrDiagonalFill{T,V<:AbstractFillVector{T},Axes} = RectOrDiagonal{T,V,Axes}
const RectDiagonalFill{T,V<:AbstractFillVector{T}} = RectDiagonal{T,V}
const DiagonalFill{T,V<:AbstractFillVector{T}} = Diagonal{T,V}
const SquareEye{T,Axes} = Diagonal{T,Ones{T,1,Tuple{Axes}}}
const Eye{T,Axes} = RectOrDiagonal{T,Ones{T,1,Tuple{Axes}}}

Expand Down
240 changes: 240 additions & 0 deletions src/fillalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
end
$OP(a::AbstractOnesMatrix) = fillsimilar(a, reverse(axes(a)))
$OP(a::FillMatrix) = Fill($OP(a.value), reverse(a.axes))
$OP(a::RectDiagonal) = RectDiagonal(vec($OP(a.diag)), reverse(a.axes))
end
end

Expand Down Expand Up @@ -80,10 +81,57 @@

*(a::AbstractFillMatrix, b::AbstractFillMatrix) = mult_fill(a,b)
*(a::AbstractFillMatrix, b::AbstractFillVector) = mult_fill(a,b)
for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:AbstractOnesVector}, AdjointAbsVec{<:Any,<:AbstractFillVector}, TransposeAbsVec{<:Any,<:AbstractFillVector})
@eval begin
function *(A::AbstractFillVector, B::$type)
size(A,2) == size(B,1) ||
throw(DimensionMismatch("second dimension of A, $(size(A,2)) does not match first dimension of B, $(size(B,1))"))
Fill(getindex_value(A) * getindex_value(B), size(A, 1), size(B, 2))
end
function *(A::AbstractFillMatrix, B::$type)
size(A,2) == size(B,1) ||
throw(DimensionMismatch("second dimension of A, $(size(A,2)) does not match first dimension of B, $(size(B,1))"))
Fill(getindex_value(A) * getindex_value(B), size(A, 1), size(B, 2))
end
end
end

# this treats a size (n,) vector as a nx1 matrix, so b needs to have 1 row
# special cased, as OnesMatrix * OnesMatrix isn't a Ones
*(a::AbstractOnesVector, b::AbstractOnesMatrix) = mult_ones(a, b)
for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:AbstractOnesVector})
@eval begin
*(A::AbstractOnesVector, B::$type) = mult_ones(A, B)
*(A::AbstractOnesMatrix, B::$type) = mult_ones(A, B)
end
end

for type2 in (AdjointAbsVec{<:Any,<:AbstractZerosVector}, TransposeAbsVec{<:Any,<:AbstractZerosVector})
for type1 in (AbstractFillVector, AbstractZerosVector, AbstractOnesVector, AbstractFillMatrix, AbstractZerosMatrix, AbstractOnesMatrix)
@eval begin
function *(A::$type1, B::$type2)
size(A,2) == size(B,1) ||
throw(DimensionMismatch("second dimension of A, $(size(A,2)) does not match first dimension of B, $(size(B,1))"))
Zeros{promote_type(eltype(A), eltype(B))}(size(A, 1), size(B, 2))
end
end
end
end

for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:AbstractOnesVector}, AdjointAbsVec{<:Any,<:AbstractFillVector}, TransposeAbsVec{<:Any,<:AbstractFillVector}, )
@eval begin
function *(A::AbstractZerosVector, B::$type)
size(A,2) == size(B,1) ||
throw(DimensionMismatch("second dimension of A, $(size(A,2)) does not match first dimension of B, $(size(B,1))"))
Zeros{promote_type(eltype(A), eltype(B))}(size(A, 1), size(B, 2))
end
function *(A::AbstractZerosMatrix, B::$type)
size(A,2) == size(B,1) ||
throw(DimensionMismatch("second dimension of A, $(size(A,2)) does not match first dimension of B, $(size(B,1))"))
Zeros{promote_type(eltype(A), eltype(B))}(size(A, 1), size(B, 2))
end
end
end

*(a::AbstractZerosMatrix, b::AbstractZerosMatrix) = mult_zeros(a, b)
*(a::AbstractZerosMatrix, b::AbstractZerosVector) = mult_zeros(a, b)
Expand Down Expand Up @@ -485,6 +533,198 @@
@inline elconvert(::Type{T}, A::AbstractUnitRange) where T<:Integer = AbstractUnitRange{T}(A)
@inline elconvert(::Type{T}, A::AbstractArray) where T = AbstractArray{T}(A)

# RectDiagonal Multiplication
const RectDiagonalZeros{T,V<:AbstractZerosVector{T}} = RectDiagonal{T,V}
const RectDiagonalOnes{T,V<:AbstractOnesVector{T}} = RectDiagonal{T,V}

function *(A::RectDiagonal, B::Diagonal)
check_matmul_sizes(A, B)
len = minimum(size(A))
RectDiagonal(view(A.diag, Base.OneTo(len)) .* view(B.diag, Base.OneTo(len)), (size(A, 1), size(B, 2)))
end
function *(A::Diagonal, B::RectDiagonal)
check_matmul_sizes(A, B)
len = minimum(size(B))
RectDiagonal(view(A.diag, Base.OneTo(len)) .* view(B.diag, Base.OneTo(len)), (size(A, 1), size(B, 2)))
end

for type in (AbstractMatrix, AbstractTriangular, AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec{<:Any,<:AbstractZerosVector})
@eval begin
function *(A::RectDiagonal, B::$type)
check_matmul_sizes(A, B)
TS = Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B))
diag = A.diag
out = fill!(similar(diag, TS, axes(A,1), axes(B,2)), 0)
len = Base.OneTo(minimum(size(A)))
out[len, :] .= view(diag, len) .* view(B, len, :)
out

Check warning on line 560 in src/fillalgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/fillalgebra.jl#L560

Added line #L560 was not covered by tests
end

function *(A::$type, B::RectDiagonal)
check_matmul_sizes(A, B)
TS = Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B))
out = fill!(similar(A, TS, axes(A,1), axes(B, 2)), 0)
len = Base.OneTo(minimum(size(B)))
out[:, len] .= view(A, :, len) .* view(reshape(B.diag, 1, :), Base.OneTo(1), len)
out

Check warning on line 569 in src/fillalgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/fillalgebra.jl#L569

Added line #L569 was not covered by tests
end
end
end

function *(A::RectDiagonal, x::AbstractVector)
check_matmul_sizes(A, x)
TS = Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(x))
diag = A.diag
out = fill!(similar(diag, TS, axes(A,1)), 0)
len = Base.OneTo(minimum(size(A)))
out[len] .= view(diag, len) .* view(x, len)
out

Check warning on line 581 in src/fillalgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/fillalgebra.jl#L581

Added line #L581 was not covered by tests
end

function *(A::RectDiagonal, B::RectDiagonal)
check_matmul_sizes(A, B)
TS = Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B))
out = fill!(similar(A.diag, TS, min(size(A, 1), size(B, 2))), 0)
len = Base.OneTo(min(minimum(size(A)), minimum(size(B))))
out[len] .= view(A.diag, len) .* view(B.diag, len)
RectDiagonal(out, (size(A,1), size(B,2)))
end

for type in (RectDiagonal, RectDiagonalZeros)
@eval begin
function *(A::$type, B::AbstractZerosMatrix)
check_matmul_sizes(A, B)
Zeros{promote_type(eltype(A),eltype(B))}(size(A, 1), size(B, 2))
end

function *(A::$type, B::AbstractZerosVector)
check_matmul_sizes(A, B)
Zeros{promote_type(eltype(A),eltype(B))}(size(A, 1))
end

function *(A::AbstractZerosMatrix, B::$type)
check_matmul_sizes(A, B)
Zeros{promote_type(eltype(A),eltype(B))}(size(A, 1), size(B, 2))
end

*(A::AdjointAbsVec{<:Any,<:AbstractZerosVector}, B::$type) = Zeros(A) * B
*(A::TransposeAbsVec{<:Any,<:AbstractZerosVector}, B::$type) = Zeros(A) * B
*(A::$type, B::AdjointAbsVec{<:Any,<:AbstractZerosVector}) = A * Zeros(B)
*(A::$type, B::TransposeAbsVec{<:Any,<:AbstractZerosVector}) = A * Zeros(B)
end
end

for type in (AbstractMatrix, RectDiagonal, Diagonal, AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec{<:Any,<:AbstractZerosVector}, AbstractTriangular)
@eval begin
function *(A::$type, B::RectDiagonalZeros)
check_matmul_sizes(A, B)
Zeros{promote_type(eltype(A),eltype(B))}(size(A,1), size(B,2))
end
function *(A::RectDiagonalZeros, B::$type)
check_matmul_sizes(A, B)
Zeros{promote_type(eltype(A),eltype(B))}(size(A,1), size(B,2))
end
end
end
function *(A::RectDiagonalZeros, B::AbstractVector)
check_matmul_sizes(A, B)
Zeros{promote_type(eltype(A),eltype(B))}(size(A,1))
end
function *(A::RectDiagonalZeros, B::RectDiagonalZeros)
check_matmul_sizes(A, B)
Zeros{promote_type(eltype(A),eltype(B))}(size(A,1), size(B,2))
end

*(a::RectDiagonalFill, b::Number) = RectDiagonal(a.diag * b, a.axes)
*(a::Number, b::RectDiagonalFill) = RectDiagonal(a * b.diag, b.axes)

# DiagonalFill Multiplication
const DiagonalZeros{T,V<:AbstractZerosVector{T}} = Diagonal{T,V}
const DiagonalOnes{T,V<:AbstractOnesVector{T}} = Diagonal{T,V}
mat_types = (AbstractMatrix, RectDiagonal, AbstractZerosMatrix,
AbstractFillMatrix, AdjointAbsVec, TransposeAbsVec, UnitUpperTriangular, UnitLowerTriangular,
LowerTriangular, UpperTriangular, AbstractTriangular, Symmetric, Hermitian, LinearAlgebra.HermOrSym,
SymTridiagonal, UpperHessenberg, LinearAlgebra.AdjOrTransAbsMat, AdjOrTransAbsVec{<:Any,<:AbstractZerosVector})#, OneElement)
for type in tuple(AbstractVector, AbstractZerosVector, mat_types...)
@eval begin
function *(A::DiagonalFill, B::$type)
check_matmul_sizes(A, B)
getindex_value(A.diag) * B
end
*(A::DiagonalZeros, B::$type) = Zeros(A) * B
function *(A::DiagonalOnes, B::$type)
check_matmul_sizes(A, B)
convert(AbstractArray{promote_type(eltype(A), eltype(B))}, deepcopy(B))
end
end
end
*(A::DiagonalOnes, B::AbstractRange) = one(eltype(A)) * B

for type in mat_types
@eval begin
function *(A::$type, B::DiagonalFill)
check_matmul_sizes(A, B)
getindex_value(B.diag) * A
end
*(A::$type, B::DiagonalZeros) = A * Zeros(B)
function *(A::$type, B::DiagonalOnes)
check_matmul_sizes(A, B)
convert(AbstractMatrix{promote_type(eltype(A), eltype(B))}, deepcopy(A))
end
end
end

for type1 in (DiagonalFill, DiagonalOnes, DiagonalZeros)
for type2 in (AdjointAbsVec{<:Any,<:AbstractZerosVector}, TransposeAbsVec{<:Any,<:AbstractZerosVector}, RectDiagonalZeros)
@eval begin
*(A::$type2, B::$type1) = Zeros(A) * B
*(A::$type1, B::$type2) = A * Zeros(B)
end
end
@eval begin
*(A::Diagonal, B::$type1) = Diagonal(A.diag .* B.diag)
*(A::$type1, B::Diagonal) = Diagonal(A.diag .* B.diag)
end
end

for type1 in (DiagonalFill, DiagonalOnes, DiagonalZeros)
for type2 in (DiagonalFill, DiagonalOnes, DiagonalZeros)
@eval begin
*(A::$type1, B::$type2) = Diagonal(A.diag .* B.diag)
end
end
end

*(A::RectDiagonalFill, B::DiagonalZeros) = A * Zeros(B)
*(A::DiagonalZeros, B::RectDiagonalFill) = Zeros(A) * B
for type in (DiagonalFill, DiagonalOnes)
@eval begin
function *(A::$type, B::RectDiagonalFill)
check_matmul_sizes(A, B)
len = Base.OneTo(minimum(size(B)))
RectDiagonal(view(A.diag, len) .* view(B.diag, len), size(B))
end

function *(A::RectDiagonalFill, B::$type)
check_matmul_sizes(A, B)
len = Base.OneTo(minimum(size(A)))
RectDiagonal(view(A.diag, len) .* view(B.diag, len), size(A))
end
end
end

function *(Da::Diagonal, A::RectDiagonal, Db::Diagonal)
check_matmul_sizes(Da, A)
check_matmul_sizes(A, Db)
len = Base.OneTo(minimum(size(A)))
diag = view(Da.diag, len) .* view(A.diag, len) .* view(Db.diag, len)
if diag isa Zeros
Zeros{eltype(diag)}(axes(A))
else
RectDiagonal(diag, axes(A))
end
end

####
# norm
####
Expand Down
15 changes: 15 additions & 0 deletions src/fillbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,21 @@ broadcasted(::DefaultArrayStyle{N}, op, x::Number, r::AbstractFill{T,N}) where {
broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Ref) where {T,N} = broadcasted_fill(op, r, op(getindex_value(r),x[]), axes(r))
broadcasted(::DefaultArrayStyle{N}, op, x::Ref, r::AbstractFill{T,N}) where {T,N} = broadcasted_fill(op, r, op(x[], getindex_value(r)), axes(r))

# ternary broadcasting
for type1 in (AbstractArray, AbstractFill, AbstractZeros)
for type2 in (AbstractArray, AbstractFill, AbstractZeros)
for type3 in (AbstractArray, AbstractFill, AbstractZeros)
if type1 === AbstractZeros || type2 === AbstractZeros || type3 === AbstractZeros
@eval begin
broadcasted(::DefaultArrayStyle, ::typeof(*), a::$type1, b::$type2, c::$type3) = Zeros{promote_type(eltype(a),eltype(b),eltype(c))}(broadcast_shape(axes(a), axes(b), axes(c)))
end
end
end
end
end
broadcasted(::DefaultArrayStyle, ::typeof(*), a::AbstractOnes, b::AbstractOnes, c::AbstractOnes) = Ones{promote_type(eltype(a),eltype(b),eltype(c))}(broadcast_shape(axes(a), axes(b), axes(c)))
broadcasted(::DefaultArrayStyle, ::typeof(*), a::AbstractFill, b::AbstractFill, c::AbstractFill) = Fill(getindex_value(a)*getindex_value(b)*getindex_value(c), broadcast_shape(axes(a), axes(b), axes(c)))

# support AbstractFill .^ k
broadcasted(::DefaultArrayStyle{N}, op::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::AbstractFill{T,N}, ::Base.RefValue{Val{k}}) where {T,N,k} = broadcasted_fill(op, r, getindex_value(r)^k, axes(r))
broadcasted(::DefaultArrayStyle{N}, op::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::AbstractOnes{T,N}, ::Base.RefValue{Val{k}}) where {T,N,k} = broadcasted_ones(op, r, T, axes(r))
Expand Down
24 changes: 24 additions & 0 deletions src/oneelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,30 @@ function *(D::Diagonal, A::OneElementMatrix)
OneElement(val, A.ind, size(A))
end

function *(A::OneElementMatrix, D::DiagonalZeros)
check_matmul_sizes(A, D)
Zeros{promote_type(eltype(A),eltype(D))}(size(A, 1), size(D, 2))
end

function *(D::DiagonalZeros, A::OneElementMatrix)
check_matmul_sizes(D, A)
Zeros{promote_type(eltype(A),eltype(D))}(size(D, 1), size(A, 2))
end

for type in (DiagonalFill, DiagonalOnes)
@eval begin
function *(A::OneElementMatrix, D::$type)
check_matmul_sizes(A, D)
getindex_value(D.diag) * A
end

function *(D::$type, A::OneElementMatrix)
check_matmul_sizes(D, A)
getindex_value(D.diag) * A
end
end
end

# Inplace multiplication

# We use this for out overloads for _mul! for OneElement because its more efficient
Expand Down
Loading
Loading