diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 66dad48..a2d2b26 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -489,6 +489,7 @@ Base.replace_in_print_matrix(A::RectDiagonal, i::Integer, j::Integer, s::Abstrac const RectOrDiagonal{T,V,Axes} = Union{RectDiagonal{T,V,Axes}, Diagonal{T,V}} const RectOrDiagonalFill{T,V<:AbstractFillVector{T},Axes} = RectOrDiagonal{T,V,Axes} const RectDiagonalFill{T,V<:AbstractFillVector{T}} = RectDiagonal{T,V} +const DiagonalFill{T,V<:AbstractFillVector{T}} = Diagonal{T,V} const SquareEye{T,Axes} = Diagonal{T,Ones{T,1,Tuple{Axes}}} const Eye{T,Axes} = RectOrDiagonal{T,Ones{T,1,Tuple{Axes}}} diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index f98ae60..0b81d71 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -15,6 +15,7 @@ for OP in (:transpose, :adjoint) end $OP(a::AbstractOnesMatrix) = fillsimilar(a, reverse(axes(a))) $OP(a::FillMatrix) = Fill($OP(a.value), reverse(a.axes)) + $OP(a::RectDiagonal) = RectDiagonal(vec($OP(a.diag)), reverse(a.axes)) end end @@ -80,10 +81,57 @@ mult_ones(a, b) = mult_ones(a, b, mult_axes(a, b)) *(a::AbstractFillMatrix, b::AbstractFillMatrix) = mult_fill(a,b) *(a::AbstractFillMatrix, b::AbstractFillVector) = mult_fill(a,b) +for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:AbstractOnesVector}, AdjointAbsVec{<:Any,<:AbstractFillVector}, TransposeAbsVec{<:Any,<:AbstractFillVector}) + @eval begin + function *(A::AbstractFillVector, B::$type) + size(A,2) == size(B,1) || + throw(DimensionMismatch("second dimension of A, $(size(A,2)) does not match first dimension of B, $(size(B,1))")) + Fill(getindex_value(A) * getindex_value(B), size(A, 1), size(B, 2)) + end + function *(A::AbstractFillMatrix, B::$type) + size(A,2) == size(B,1) || + throw(DimensionMismatch("second dimension of A, $(size(A,2)) does not match first dimension of B, $(size(B,1))")) + Fill(getindex_value(A) * getindex_value(B), size(A, 1), size(B, 2)) + end + end +end # this treats a size (n,) vector as a nx1 matrix, so b needs to have 1 row # special cased, as OnesMatrix * OnesMatrix isn't a Ones *(a::AbstractOnesVector, b::AbstractOnesMatrix) = mult_ones(a, b) +for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:AbstractOnesVector}) + @eval begin + *(A::AbstractOnesVector, B::$type) = mult_ones(A, B) + *(A::AbstractOnesMatrix, B::$type) = mult_ones(A, B) + end +end + +for type2 in (AdjointAbsVec{<:Any,<:AbstractZerosVector}, TransposeAbsVec{<:Any,<:AbstractZerosVector}) + for type1 in (AbstractFillVector, AbstractZerosVector, AbstractOnesVector, AbstractFillMatrix, AbstractZerosMatrix, AbstractOnesMatrix) + @eval begin + function *(A::$type1, B::$type2) + size(A,2) == size(B,1) || + throw(DimensionMismatch("second dimension of A, $(size(A,2)) does not match first dimension of B, $(size(B,1))")) + Zeros{promote_type(eltype(A), eltype(B))}(size(A, 1), size(B, 2)) + end + end + end +end + +for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:AbstractOnesVector}, AdjointAbsVec{<:Any,<:AbstractFillVector}, TransposeAbsVec{<:Any,<:AbstractFillVector}, ) + @eval begin + function *(A::AbstractZerosVector, B::$type) + size(A,2) == size(B,1) || + throw(DimensionMismatch("second dimension of A, $(size(A,2)) does not match first dimension of B, $(size(B,1))")) + Zeros{promote_type(eltype(A), eltype(B))}(size(A, 1), size(B, 2)) + end + function *(A::AbstractZerosMatrix, B::$type) + size(A,2) == size(B,1) || + throw(DimensionMismatch("second dimension of A, $(size(A,2)) does not match first dimension of B, $(size(B,1))")) + Zeros{promote_type(eltype(A), eltype(B))}(size(A, 1), size(B, 2)) + end + end +end *(a::AbstractZerosMatrix, b::AbstractZerosMatrix) = mult_zeros(a, b) *(a::AbstractZerosMatrix, b::AbstractZerosVector) = mult_zeros(a, b) @@ -485,6 +533,198 @@ end @inline elconvert(::Type{T}, A::AbstractUnitRange) where T<:Integer = AbstractUnitRange{T}(A) @inline elconvert(::Type{T}, A::AbstractArray) where T = AbstractArray{T}(A) +# RectDiagonal Multiplication +const RectDiagonalZeros{T,V<:AbstractZerosVector{T}} = RectDiagonal{T,V} +const RectDiagonalOnes{T,V<:AbstractOnesVector{T}} = RectDiagonal{T,V} + +function *(A::RectDiagonal, B::Diagonal) + check_matmul_sizes(A, B) + len = minimum(size(A)) + RectDiagonal(view(A.diag, Base.OneTo(len)) .* view(B.diag, Base.OneTo(len)), (size(A, 1), size(B, 2))) +end +function *(A::Diagonal, B::RectDiagonal) + check_matmul_sizes(A, B) + len = minimum(size(B)) + RectDiagonal(view(A.diag, Base.OneTo(len)) .* view(B.diag, Base.OneTo(len)), (size(A, 1), size(B, 2))) +end + +for type in (AbstractMatrix, AbstractTriangular, AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec{<:Any,<:AbstractZerosVector}) + @eval begin + function *(A::RectDiagonal, B::$type) + check_matmul_sizes(A, B) + TS = Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B)) + diag = A.diag + out = fill!(similar(diag, TS, axes(A,1), axes(B,2)), 0) + len = Base.OneTo(minimum(size(A))) + out[len, :] .= view(diag, len) .* view(B, len, :) + out + end + + function *(A::$type, B::RectDiagonal) + check_matmul_sizes(A, B) + TS = Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B)) + out = fill!(similar(A, TS, axes(A,1), axes(B, 2)), 0) + len = Base.OneTo(minimum(size(B))) + out[:, len] .= view(A, :, len) .* view(reshape(B.diag, 1, :), Base.OneTo(1), len) + out + end + end +end + +function *(A::RectDiagonal, x::AbstractVector) + check_matmul_sizes(A, x) + TS = Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(x)) + diag = A.diag + out = fill!(similar(diag, TS, axes(A,1)), 0) + len = Base.OneTo(minimum(size(A))) + out[len] .= view(diag, len) .* view(x, len) + out +end + +function *(A::RectDiagonal, B::RectDiagonal) + check_matmul_sizes(A, B) + TS = Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B)) + out = fill!(similar(A.diag, TS, min(size(A, 1), size(B, 2))), 0) + len = Base.OneTo(min(minimum(size(A)), minimum(size(B)))) + out[len] .= view(A.diag, len) .* view(B.diag, len) + RectDiagonal(out, (size(A,1), size(B,2))) +end + +for type in (RectDiagonal, RectDiagonalZeros) + @eval begin + function *(A::$type, B::AbstractZerosMatrix) + check_matmul_sizes(A, B) + Zeros{promote_type(eltype(A),eltype(B))}(size(A, 1), size(B, 2)) + end + + function *(A::$type, B::AbstractZerosVector) + check_matmul_sizes(A, B) + Zeros{promote_type(eltype(A),eltype(B))}(size(A, 1)) + end + + function *(A::AbstractZerosMatrix, B::$type) + check_matmul_sizes(A, B) + Zeros{promote_type(eltype(A),eltype(B))}(size(A, 1), size(B, 2)) + end + + *(A::AdjointAbsVec{<:Any,<:AbstractZerosVector}, B::$type) = Zeros(A) * B + *(A::TransposeAbsVec{<:Any,<:AbstractZerosVector}, B::$type) = Zeros(A) * B + *(A::$type, B::AdjointAbsVec{<:Any,<:AbstractZerosVector}) = A * Zeros(B) + *(A::$type, B::TransposeAbsVec{<:Any,<:AbstractZerosVector}) = A * Zeros(B) + end +end + +for type in (AbstractMatrix, RectDiagonal, Diagonal, AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec{<:Any,<:AbstractZerosVector}, AbstractTriangular) + @eval begin + function *(A::$type, B::RectDiagonalZeros) + check_matmul_sizes(A, B) + Zeros{promote_type(eltype(A),eltype(B))}(size(A,1), size(B,2)) + end + function *(A::RectDiagonalZeros, B::$type) + check_matmul_sizes(A, B) + Zeros{promote_type(eltype(A),eltype(B))}(size(A,1), size(B,2)) + end + end +end +function *(A::RectDiagonalZeros, B::AbstractVector) + check_matmul_sizes(A, B) + Zeros{promote_type(eltype(A),eltype(B))}(size(A,1)) +end +function *(A::RectDiagonalZeros, B::RectDiagonalZeros) + check_matmul_sizes(A, B) + Zeros{promote_type(eltype(A),eltype(B))}(size(A,1), size(B,2)) +end + +*(a::RectDiagonalFill, b::Number) = RectDiagonal(a.diag * b, a.axes) +*(a::Number, b::RectDiagonalFill) = RectDiagonal(a * b.diag, b.axes) + +# DiagonalFill Multiplication +const DiagonalZeros{T,V<:AbstractZerosVector{T}} = Diagonal{T,V} +const DiagonalOnes{T,V<:AbstractOnesVector{T}} = Diagonal{T,V} +mat_types = (AbstractMatrix, RectDiagonal, AbstractZerosMatrix, + AbstractFillMatrix, AdjointAbsVec, TransposeAbsVec, UnitUpperTriangular, UnitLowerTriangular, + LowerTriangular, UpperTriangular, AbstractTriangular, Symmetric, Hermitian, LinearAlgebra.HermOrSym, + SymTridiagonal, UpperHessenberg, LinearAlgebra.AdjOrTransAbsMat, AdjOrTransAbsVec{<:Any,<:AbstractZerosVector})#, OneElement) +for type in tuple(AbstractVector, AbstractZerosVector, mat_types...) + @eval begin + function *(A::DiagonalFill, B::$type) + check_matmul_sizes(A, B) + getindex_value(A.diag) * B + end + *(A::DiagonalZeros, B::$type) = Zeros(A) * B + function *(A::DiagonalOnes, B::$type) + check_matmul_sizes(A, B) + convert(AbstractArray{promote_type(eltype(A), eltype(B))}, deepcopy(B)) + end + end +end +*(A::DiagonalOnes, B::AbstractRange) = one(eltype(A)) * B + +for type in mat_types + @eval begin + function *(A::$type, B::DiagonalFill) + check_matmul_sizes(A, B) + getindex_value(B.diag) * A + end + *(A::$type, B::DiagonalZeros) = A * Zeros(B) + function *(A::$type, B::DiagonalOnes) + check_matmul_sizes(A, B) + convert(AbstractMatrix{promote_type(eltype(A), eltype(B))}, deepcopy(A)) + end + end +end + +for type1 in (DiagonalFill, DiagonalOnes, DiagonalZeros) + for type2 in (AdjointAbsVec{<:Any,<:AbstractZerosVector}, TransposeAbsVec{<:Any,<:AbstractZerosVector}, RectDiagonalZeros) + @eval begin + *(A::$type2, B::$type1) = Zeros(A) * B + *(A::$type1, B::$type2) = A * Zeros(B) + end + end + @eval begin + *(A::Diagonal, B::$type1) = Diagonal(A.diag .* B.diag) + *(A::$type1, B::Diagonal) = Diagonal(A.diag .* B.diag) + end +end + +for type1 in (DiagonalFill, DiagonalOnes, DiagonalZeros) + for type2 in (DiagonalFill, DiagonalOnes, DiagonalZeros) + @eval begin + *(A::$type1, B::$type2) = Diagonal(A.diag .* B.diag) + end + end +end + +*(A::RectDiagonalFill, B::DiagonalZeros) = A * Zeros(B) +*(A::DiagonalZeros, B::RectDiagonalFill) = Zeros(A) * B +for type in (DiagonalFill, DiagonalOnes) + @eval begin + function *(A::$type, B::RectDiagonalFill) + check_matmul_sizes(A, B) + len = Base.OneTo(minimum(size(B))) + RectDiagonal(view(A.diag, len) .* view(B.diag, len), size(B)) + end + + function *(A::RectDiagonalFill, B::$type) + check_matmul_sizes(A, B) + len = Base.OneTo(minimum(size(A))) + RectDiagonal(view(A.diag, len) .* view(B.diag, len), size(A)) + end + end +end + +function *(Da::Diagonal, A::RectDiagonal, Db::Diagonal) + check_matmul_sizes(Da, A) + check_matmul_sizes(A, Db) + len = Base.OneTo(minimum(size(A))) + diag = view(Da.diag, len) .* view(A.diag, len) .* view(Db.diag, len) + if diag isa Zeros + Zeros{eltype(diag)}(axes(A)) + else + RectDiagonal(diag, axes(A)) + end +end + #### # norm #### diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index 2b5ea59..0b0b4c3 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -259,6 +259,21 @@ broadcasted(::DefaultArrayStyle{N}, op, x::Number, r::AbstractFill{T,N}) where { broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Ref) where {T,N} = broadcasted_fill(op, r, op(getindex_value(r),x[]), axes(r)) broadcasted(::DefaultArrayStyle{N}, op, x::Ref, r::AbstractFill{T,N}) where {T,N} = broadcasted_fill(op, r, op(x[], getindex_value(r)), axes(r)) +# ternary broadcasting +for type1 in (AbstractArray, AbstractFill, AbstractZeros) + for type2 in (AbstractArray, AbstractFill, AbstractZeros) + for type3 in (AbstractArray, AbstractFill, AbstractZeros) + if type1 === AbstractZeros || type2 === AbstractZeros || type3 === AbstractZeros + @eval begin + broadcasted(::DefaultArrayStyle, ::typeof(*), a::$type1, b::$type2, c::$type3) = Zeros{promote_type(eltype(a),eltype(b),eltype(c))}(broadcast_shape(axes(a), axes(b), axes(c))) + end + end + end + end +end +broadcasted(::DefaultArrayStyle, ::typeof(*), a::AbstractOnes, b::AbstractOnes, c::AbstractOnes) = Ones{promote_type(eltype(a),eltype(b),eltype(c))}(broadcast_shape(axes(a), axes(b), axes(c))) +broadcasted(::DefaultArrayStyle, ::typeof(*), a::AbstractFill, b::AbstractFill, c::AbstractFill) = Fill(getindex_value(a)*getindex_value(b)*getindex_value(c), broadcast_shape(axes(a), axes(b), axes(c))) + # support AbstractFill .^ k broadcasted(::DefaultArrayStyle{N}, op::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::AbstractFill{T,N}, ::Base.RefValue{Val{k}}) where {T,N,k} = broadcasted_fill(op, r, getindex_value(r)^k, axes(r)) broadcasted(::DefaultArrayStyle{N}, op::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::AbstractOnes{T,N}, ::Base.RefValue{Val{k}}) where {T,N,k} = broadcasted_ones(op, r, T, axes(r)) diff --git a/src/oneelement.jl b/src/oneelement.jl index 9b76d35..6e6bd20 100644 --- a/src/oneelement.jl +++ b/src/oneelement.jl @@ -190,6 +190,30 @@ function *(D::Diagonal, A::OneElementMatrix) OneElement(val, A.ind, size(A)) end +function *(A::OneElementMatrix, D::DiagonalZeros) + check_matmul_sizes(A, D) + Zeros{promote_type(eltype(A),eltype(D))}(size(A, 1), size(D, 2)) +end + +function *(D::DiagonalZeros, A::OneElementMatrix) + check_matmul_sizes(D, A) + Zeros{promote_type(eltype(A),eltype(D))}(size(D, 1), size(A, 2)) +end + +for type in (DiagonalFill, DiagonalOnes) + @eval begin + function *(A::OneElementMatrix, D::$type) + check_matmul_sizes(A, D) + getindex_value(D.diag) * A + end + + function *(D::$type, A::OneElementMatrix) + check_matmul_sizes(D, A) + getindex_value(D.diag) * A + end + end +end + # Inplace multiplication # We use this for out overloads for _mul! for OneElement because its more efficient diff --git a/test/runtests.jl b/test/runtests.jl index 32cd2ed..9035fa7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -386,6 +386,512 @@ end @test stringmime("text/plain", D) == "3×2 RectDiagonal{Float64, Vector{Float64}, Tuple{Base.OneTo{$Int}, Base.OneTo{$Int}}}:\n 1.0 ⋅ \n ⋅ 2.0\n ⋅ ⋅ " end +@testset "RectDiagonal multiplication" begin + using FillArrays: RectDiagonalFill, RectDiagonalZeros, RectDiagonalOnes, DiagonalFill, DiagonalOnes, DiagonalZeros + + val = 2.0 + + n = 3 + square_mat_instances = Dict( + :RectDiagonal => RectDiagonal(rand(n), n, n), + :RectDiagonalFill => RectDiagonal(Fill(val, n), n, n), + :RectDiagonalZeros => RectDiagonal(Zeros(n), n, n), + :RectDiagonalOnes => RectDiagonal(Ones(n), n, n), + :Diagonal => Diagonal(rand(n)), + :DiagonalFill => Diagonal(Fill(val, n)), + :DiagonalZeros => Diagonal(Zeros(n)), + :DiagonalOnes => Diagonal(Ones(n)), + :Zeros => Zeros(n, n), + :Ones => Ones(n, n), + :Fill => Fill(val, n, n), + :Mat => rand(n, n), + ) + + m = 1 + n = 3 + row_mat_instances = Dict( + :RectDiagonal => RectDiagonal(rand(m), m, n), + :RectDiagonalFill => RectDiagonal(Fill(val, min(m, n)), m, n), + :RectDiagonalZeros => RectDiagonal(Zeros(min(m, n)), m, n), + :RectDiagonalOnes => RectDiagonal(Ones(min(m, n)), m, n), + :Zeros => Zeros(m, n), + :Ones => Ones(m, n), + :Fill => Fill(val, m, n), + :Mat => rand(m, n), + ) + + m = 3 + n = 1 + col_mat_instances = Dict( + :RectDiagonal => RectDiagonal(rand(m), m, n), + :RectDiagonalFill => RectDiagonal(Fill(val, min(m, n)), m, n), + :RectDiagonalZeros => RectDiagonal(Zeros(min(m, n)), m, n), + :RectDiagonalOnes => RectDiagonal(Ones(min(m, n)), m, n), + :Zeros => Zeros(m, n), + :Ones => Ones(m, n), + :Fill => Fill(val, m, n), + :Mat => rand(m, n), + ) + + n = 3 + trans_vec_instances = Dict( + :TransVec => rand(n), + :TransZerosVec => Zeros(n), + :TransOnesVec => Ones(n), + :TransFillVec => Fill(val, n), + ) + + vec_instances = Dict( + :Vec => rand(n), + :ZerosVec => Zeros(n), + :OnesVec => Ones(n), + :FillVec => Fill(val, n) + ) + + n = 1 + one_dim_mat_instances = Dict( + :RectDiagonal => RectDiagonal(rand(n), n, n), + :RectDiagonalFill => RectDiagonal(Fill(val, n), n, n), + :RectDiagonalZeros => RectDiagonal(Zeros(n), n, n), + :RectDiagonalOnes => RectDiagonal(Ones(n), n, n), + :Diagonal => Diagonal(rand(n)), + :DiagonalFill => Diagonal(Fill(val, n)), + :DiagonalZeros => Diagonal(Zeros(n)), + :DiagonalOnes => Diagonal(Ones(n)), + :Zeros => Zeros(n, n), + :Ones => Ones(n, n), + :Fill => Fill(val, n, n), + :Mat => rand(n, n), + ) + + + # Expected outcome table. + # The header (in order) corresponds to the following instance symbols: + # :RectDiagonal, :RectDiagonalFill, :RectDiagonalZeros, :RectDiagonalOnes, + # :Diagonal, :DiagonalFill, :DiagonalZeros, :DiagonalOnes, :Zeros, :Ones, :Fill, :Mat + # Each row gives the expected resultant type when doing multiplication, + expected = Dict( + :RectDiagonal => Dict( + :RectDiagonal => RectDiagonal, + :RectDiagonalFill => RectDiagonal, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => RectDiagonal, + :Diagonal => RectDiagonal, + :DiagonalFill => RectDiagonal, + :DiagonalZeros => Zeros, + :DiagonalOnes => RectDiagonal, + :Zeros => Zeros, + :Ones => Array, + :Fill => Array, + :Mat => Array, + :Vec => Vector, + :ZerosVec => Zeros, + :OnesVec => Vector, + :FillVec => Vector, + :TransVec => Array, + :TransZerosVec => Zeros, + :TransOnesVec => Array, + :TransFillVec => Array, + ), + :RectDiagonalFill => Dict( + :RectDiagonal => RectDiagonal, + :RectDiagonalFill => RectDiagonal, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => RectDiagonal, + :Diagonal => RectDiagonal, + :DiagonalFill => RectDiagonalFill, + :DiagonalZeros => Zeros, + :DiagonalOnes => RectDiagonalFill, + :Zeros => Zeros, + :Ones => Array, + :Fill => Array, + :Mat => Array, + :Vec => Vector, + :ZerosVec => Zeros, + :OnesVec => Vector, + :FillVec => Vector, + :TransVec => Array, + :TransZerosVec => Zeros, + :TransOnesVec => Array, + :TransFillVec => Array, + ), + :RectDiagonalZeros => Dict( + :RectDiagonal => Zeros, + :RectDiagonalFill => Zeros, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => Zeros, + :Diagonal => Zeros, + :DiagonalFill => Zeros, + :DiagonalZeros => Zeros, + :DiagonalOnes => Zeros, + :Zeros => Zeros, + :Ones => Zeros, + :Fill => Zeros, + :Mat => Zeros, + :Vec => Zeros, + :ZerosVec => Zeros, + :OnesVec => Zeros, + :FillVec => Zeros, + :TransVec => Zeros, + :TransZerosVec => Zeros, + :TransOnesVec => Zeros, + :TransFillVec => Zeros, + ), + :RectDiagonalOnes => Dict( + :RectDiagonal => RectDiagonal, + :RectDiagonalFill => RectDiagonal, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => RectDiagonal, + :Diagonal => RectDiagonal, + :DiagonalFill => RectDiagonalFill, + :DiagonalZeros => Zeros, + :DiagonalOnes => RectDiagonalOnes, + :Zeros => Zeros, + :Ones => Array, + :Fill => Array, + :Mat => Array, + :Vec => Vector, + :ZerosVec => Zeros, + :OnesVec => Vector, + :FillVec => Vector, + :TransVec => Array, + :TransZerosVec => Zeros, + :TransOnesVec => Array, + :TransFillVec => Array, + ), + :Diagonal => Dict( + :RectDiagonal => RectDiagonal, + :RectDiagonalFill => RectDiagonal, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => RectDiagonal, + :Diagonal => Diagonal, + :DiagonalFill => Diagonal, + :DiagonalZeros => DiagonalZeros, + :DiagonalOnes => Diagonal, + :Zeros => Zeros, + :Ones => Array, + :Fill => Array, + :Mat => Array, + :Vec => Vector, + :ZerosVec => Zeros, + :OnesVec => Vector, + :FillVec => Vector, + :TransVec => Array, + :TransZerosVec => Zeros, + :TransOnesVec => Array, + :TransFillVec => Array, + ), + :DiagonalFill => Dict( + :RectDiagonal => RectDiagonal, + :RectDiagonalFill => RectDiagonalFill, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => RectDiagonalFill, + :Diagonal => Diagonal, + :DiagonalFill => DiagonalFill, + :DiagonalZeros => DiagonalZeros, + :DiagonalOnes => DiagonalFill, + :Zeros => Zeros, + :Ones => Fill, + :Fill => Fill, + :Mat => Array, + :Vec => Vector, + :ZerosVec => Zeros, + :OnesVec => Fill, + :FillVec => Fill, + :TransVec => Array, + :TransZerosVec => Zeros, + :TransOnesVec => Fill, + :TransFillVec => Fill, + ), + :DiagonalZeros => Dict( + :RectDiagonal => Zeros, + :RectDiagonalFill => Zeros, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => Zeros, + :Diagonal => DiagonalZeros, + :DiagonalFill => DiagonalZeros, + :DiagonalZeros => DiagonalZeros, + :DiagonalOnes => DiagonalZeros, + :Zeros => Zeros, + :Ones => Zeros, + :Fill => Zeros, + :Mat => Zeros, + :Vec => Zeros, + :ZerosVec => Zeros, + :OnesVec => Zeros, + :FillVec => Zeros, + :TransVec => Zeros, + :TransZerosVec => Zeros, + :TransOnesVec => Zeros, + :TransFillVec => Zeros, + ), + :DiagonalOnes => Dict( + :RectDiagonal => RectDiagonal, + :RectDiagonalFill => RectDiagonalFill, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => RectDiagonalOnes, + :Diagonal => Diagonal, + :DiagonalFill => DiagonalFill, + :DiagonalZeros => DiagonalZeros, + :DiagonalOnes => DiagonalOnes, + :Zeros => Zeros, + :Ones => Ones, + :Fill => Fill, + :Mat => Array, + :Vec => Vector, + :ZerosVec => Zeros, + :OnesVec => Ones, + :FillVec => Fill, + :TransVec => Array, + :TransZerosVec => Zeros, + :TransOnesVec => Ones, + :TransFillVec => Fill, + ), + :Zeros => Dict( + :RectDiagonal => Zeros, + :RectDiagonalFill => Zeros, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => Zeros, + :Diagonal => Zeros, + :DiagonalFill => Zeros, + :DiagonalZeros => Zeros, + :DiagonalOnes => Zeros, + :Zeros => Zeros, + :Ones => Zeros, + :Fill => Zeros, + :Mat => Zeros, + :Vec => Zeros, + :ZerosVec => Zeros, + :OnesVec => Zeros, + :FillVec => Zeros, + :TransVec => Zeros, + :TransZerosVec => Zeros, + :TransOnesVec => Zeros, + :TransFillVec => Zeros, + ), + :Ones => Dict( + :RectDiagonal => Array, + :RectDiagonalFill => Array, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => Array, + :Diagonal => Array, + :DiagonalFill => Fill, + :DiagonalZeros => Zeros, + :DiagonalOnes => Ones, + :Zeros => Zeros, + :Ones => Fill, + :Fill => Fill, + :Mat => Array, + :Vec => Fill, + :ZerosVec => Zeros, + :OnesVec => Fill, + :FillVec => Fill, + :TransVec => Array, + :TransZerosVec => Zeros, + :TransOnesVec => Ones, + :TransFillVec => Fill, + ), + :Fill => Dict( + :RectDiagonal => Array, + :RectDiagonalFill => Array, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => Array, + :Diagonal => Array, + :DiagonalFill => Fill, + :DiagonalZeros => Zeros, + :DiagonalOnes => Fill, + :Zeros => Zeros, + :Ones => Fill, + :Fill => Fill, + :Mat => Array, + :Vec => Fill, + :ZerosVec => Zeros, + :OnesVec => Fill, + :FillVec => Fill, + :TransVec => Array, + :TransZerosVec => Zeros, + :TransOnesVec => Fill, + :TransFillVec => Fill, + ), + :Mat => Dict( + :RectDiagonal => Array, + :RectDiagonalFill => Array, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => Array, + :Diagonal => Array, + :DiagonalFill => Array, + :DiagonalZeros => Zeros, + :DiagonalOnes => Array, + :Zeros => Zeros, + :Ones => Array, + :Fill => Array, + :Mat => Array, + :Vec => Vector, + :ZerosVec => Zeros, + :OnesVec => Vector, + :FillVec => Vector, + :TransVec => Array, + :TransZerosVec => Zeros, + :TransOnesVec => Array, + :TransFillVec => Array, + ), + :Vec => Dict( + :TransVec => Array, + :TransZerosVec => Zeros, + :TransOnesVec => Array, + :TransFillVec => Array, + ), + :ZerosVec => Dict( + :TransVec => Zeros, + :TransZerosVec => Zeros, + :TransOnesVec => Zeros, + :TransFillVec => Zeros, + ), + :OnesVec => Dict( + :TransVec => Array, + :TransZerosVec => Zeros, + :TransOnesVec => Ones, + :TransFillVec => Fill, + ), + :FillVec => Dict( + :TransVec => Array, + :TransZerosVec => Zeros, + :TransOnesVec => Fill, + :TransFillVec => Fill, + ), + :TransVec => Dict( + :RectDiagonal => Array, + :RectDiagonalFill => Array, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => Array, + :Diagonal => Array, + :DiagonalFill => Array, + :DiagonalZeros => Zeros, + :DiagonalOnes => Array, + :Zeros => Zeros, + :Ones => Fill, + :Fill => Fill, + :Mat => Array, + ), + :TransZerosVec => Dict( + :RectDiagonal => Zeros, + :RectDiagonalFill => Zeros, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => Zeros, + :Diagonal => Zeros, + :DiagonalFill => Zeros, + :DiagonalZeros => Zeros, + :DiagonalOnes => Zeros, + :Zeros => Zeros, + :Ones => Zeros, + :Fill => Zeros, + :Mat => Zeros, + ), + :TransOnesVec => Dict( + :RectDiagonal => Array, + :RectDiagonalFill => Array, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => Array, + :Diagonal => Array, + :DiagonalFill => Fill, + :DiagonalZeros => Zeros, + :DiagonalOnes => Ones, + :Zeros => Zeros, + :Ones => Fill, + :Fill => Fill, + :Mat => Array, + ), + :TransFillVec => Dict( + :RectDiagonal => Array, + :RectDiagonalFill => Array, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => Array, + :Diagonal => Array, + :DiagonalFill => Fill, + :DiagonalZeros => Zeros, + :DiagonalOnes => Fill, + :Zeros => Zeros, + :Ones => Fill, + :Fill => Fill, + :Mat => Array, + ), + ) + + for (k2, B) in square_mat_instances + for op2 in (adjoint, transpose, identity) + for (k1, A) in square_mat_instances + for op1 in (adjoint, transpose, identity) + result = op1(A) * op2(B) + @test result isa expected[k1][k2] || result.parent isa expected[k1][k2] + end + end + + for (k1, A) in trans_vec_instances + for op1 in (adjoint, transpose) + result = op1(A) * op2(B) + @test result isa expected[k1][k2] || result.parent isa expected[k1][k2] + end + end + end + end + + for (k2, B) in vec_instances + for (k1, A) in square_mat_instances + for op1 in (adjoint, transpose, identity) + @test op1(A) * B isa expected[k1][k2] + end + + end + for (k1, A) in trans_vec_instances + for op1 in (adjoint, transpose) + @test op1(A)*B isa Number + @test B*op1(A) isa expected[k2][k1] + end + end + end + + for (k1, A) in Iterators.flatten((col_mat_instances, one_dim_mat_instances)) + for (k2, B) in trans_vec_instances + for op2 in (adjoint, transpose) + result = A * op2(B) + @test result isa expected[k1][k2] || result.parent isa expected[k1][k2] + end + end + end + + for (k1, A) in row_mat_instances + for op1 in (adjoint, transpose) + for (k2, B) in trans_vec_instances + for op2 in (adjoint, transpose) + @test op1(A) * op2(B) isa expected[k1][k2] + end + end + end + end + + num = 3. + @test square_mat_instances[:RectDiagonalFill] * num == num *square_mat_instances[:RectDiagonalFill] == num * Matrix(square_mat_instances[:RectDiagonalFill]) + + for (k1, Da) in square_mat_instances + for (k2, Db) in square_mat_instances + for (k3, A) in square_mat_instances + @test typeof(Da * A * Db) === typeof((Da * A) * Db) === typeof((Da * (A * Db))) + if !(typeof(Da * A * Db) === typeof((Da * A) * Db) === typeof((Da * (A * Db)))) + @show typeof(Da) typeof(A) typeof(Db) typeof(Da*A*Db) typeof((Da*A)*Db) + end + end + end + end + + ind = (1, 2) + sz = (3, 3) + oneele = OneElement(val, ind, sz) + @test oneele * Diagonal(Zeros(3)) === Diagonal(Zeros(3)) * oneele === Zeros(3,3) + @test oneele * Diagonal(Fill(val, 3)) === Diagonal(Fill(val, 3)) * oneele === OneElement(val*val, ind, sz) + @test oneele * Diagonal(Ones(3)) === Diagonal(Ones(3)) * oneele === oneele + +end + # Check that all pair-wise combinations of + / - elements of As and Bs yield the correct # type, and produce numerically correct results. as_array(x::AbstractArray) = Array(x)