Skip to content

Commit 4492817

Browse files
authored
BroadcastStyle for lazy triangular or HermOrSym, improved broadcast * (#360)
* BroadcastStyle for lazy triangular or HermOrSym * Update inv.jl * Update broadcasttests.jl * broadcast mul with / and \ * Update LazyArraysBandedMatricesExt.jl * Add tests and back compat * Update lazybroadcasting.jl * simplifiable for broadcast mul * Update LazyArraysBandedMatricesExt.jl * increase coverage * Update broadcasttests.jl * Update cachetests.jl * increase coverage
1 parent 5d7d524 commit 4492817

File tree

10 files changed

+107
-14
lines changed

10 files changed

+107
-14
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "LazyArrays"
22
uuid = "5078a376-72f3-5289-bfd5-ec5146d43c02"
3-
version = "2.3.2"
3+
version = "2.4"
44

55
[deps]
66
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

ext/LazyArraysBandedMatricesExt.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import LazyArrays: sublayout, symmetriclayout, hermitianlayout, applylayout, cac
1010
AbstractPaddedLayout, PaddedLayout, AbstractLazyBandedLayout, LazyBandedLayout, PaddedRows,
1111
PaddedColumns, CachedArray, CachedMatrix, LazyLayout, BroadcastLayout, ApplyLayout,
1212
paddeddata, resizedata!, broadcastlayout, _broadcastarray2broadcasted, _broadcast_sub_arguments,
13-
arguments, call, applybroadcaststyle, simplify, simplifiable, islazy_layout, lazymaterialize, _broadcast_mul_mul,
13+
arguments, call, applybroadcaststyle, simplify, simplifiable, islazy_layout, lazymaterialize, _broadcast_mul_mul, _broadcast_mul_simplifiable,
1414
triangularlayout, AbstractCachedMatrix, _mulbanded_copyto!, ApplyBandedLayout, BroadcastBandedLayout
1515
import Base: BroadcastStyle, similar, copy, broadcasted, getindex, OneTo, oneto, tail, sign, abs
1616
import BandedMatrices: bandedbroadcaststyle, bandwidths, isbanded, bandedcolumns, bandeddata, BandedStyle,
@@ -539,6 +539,7 @@ bandeddata(R::ApplyMatrix{<:Any,typeof(rot180)}) = @view(bandeddata(arguments(R)
539539
const BandedLazyLayouts = Union{AbstractLazyBandedLayout, BandedColumns{LazyLayout}, BandedRows{LazyLayout},
540540
TriangularLayout{UPLO,UNIT,BandedRows{LazyLayout}} where {UPLO,UNIT},
541541
TriangularLayout{UPLO,UNIT,BandedColumns{LazyLayout}} where {UPLO,UNIT},
542+
TriangularLayout{UPLO,UNIT,LazyBandedLayout} where {UPLO,UNIT},
542543
SymTridiagonalLayout{LazyLayout}, BidiagonalLayout{LazyLayout}, TridiagonalLayout{LazyLayout},
543544
SymmetricLayout{BandedColumns{LazyLayout}}, HermitianLayout{BandedColumns{LazyLayout}}}
544545

@@ -551,7 +552,12 @@ copy(M::Mul{<:LazyLayouts, <:BandedLazyLayouts}) = simplify(M)
551552
copy(M::Mul{<:Any, <:BandedLazyLayouts}) = simplify(M)
552553
copy(M::Mul{<:BandedLazyLayouts, <:AbstractLazyLayout}) = simplify(M)
553554
copy(M::Mul{<:AbstractLazyLayout, <:BandedLazyLayouts}) = simplify(M)
554-
copy(M::Mul{BroadcastLayout{typeof(*)}, <:BandedLazyLayouts}) = _broadcast_mul_mul(arguments(BroadcastLayout{typeof(*)}(), M.A), M.B)
555+
for op in (:*, :/, :\)
556+
@eval begin
557+
simplifiable(M::Mul{BroadcastLayout{typeof($op)}, <:BandedLazyLayouts}) = _broadcast_mul_simplifiable($op, arguments(BroadcastLayout{typeof($op)}(), M.A), M.B)
558+
copy(M::Mul{BroadcastLayout{typeof($op)}, <:BandedLazyLayouts}) = _broadcast_mul_mul($op, arguments(BroadcastLayout{typeof($op)}(), M.A), M.B)
559+
end
560+
end
555561
copy(M::Mul{<:BandedLazyLayouts, <:DiagonalLayout}) = simplify(M)
556562
copy(M::Mul{<:DiagonalLayout, <:BandedLazyLayouts}) = simplify(M)
557563

src/lazybroadcasting.jl

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,11 @@ BroadcastStyle(::Type{<:Transpose{<:Any,<:LazyVector}}) = LazyArrayStyle{2}()
137137
BroadcastStyle(::Type{<:Adjoint{<:Any,<:LazyMatrix}}) = LazyArrayStyle{2}()
138138
BroadcastStyle(::Type{<:Transpose{<:Any,<:LazyMatrix}}) = LazyArrayStyle{2}()
139139
BroadcastStyle(::Type{<:SubArray{<:Any,1,<:LazyMatrix,<:Tuple{Slice,Any}}}) = LazyArrayStyle{1}()
140+
141+
BroadcastStyle(::Type{<:UpperOrLowerTriangular{<:Any,<:LazyMatrix}}) = LazyArrayStyle{2}()
142+
BroadcastStyle(::Type{<:LinearAlgebra.HermOrSym{<:Any,<:LazyMatrix}}) = LazyArrayStyle{2}()
143+
144+
140145
BroadcastStyle(L::LazyArrayStyle{N}, ::StructuredMatrixStyle) where N = L
141146

142147

@@ -397,11 +402,37 @@ end
397402
###
398403

399404
_broadcast_mul_mul(A, B) = simplify(Mul(broadcast(*, A...), B))
400-
_broadcast_mul_mul((a,B)::Tuple{AbstractVector,AbstractMatrix}, C) = a .* (B*C)
401-
_broadcast_mul_mul((A,b)::Tuple{AbstractMatrix,AbstractVector}, C) = b .* (A*C)
402-
@inline copy(M::Mul{BroadcastLayout{typeof(*)}}) = _broadcast_mul_mul(arguments(BroadcastLayout{typeof(*)}(), M.A), M.B)
403-
@inline copy(M::Mul{BroadcastLayout{typeof(*)},<:AbstractLazyLayout}) = _broadcast_mul_mul(arguments(BroadcastLayout{typeof(*)}(), M.A), M.B)
404-
@inline copy(M::Mul{BroadcastLayout{typeof(*)},ApplyLayout{typeof(*)}}) = _broadcast_mul_mul(arguments(BroadcastLayout{typeof(*)}(), M.A), M.B)
405+
_broadcast_mul_mul(::typeof(*), A, B) = _broadcast_mul_mul(A, B) # maintain back-compatibility with Quasi/ContiuumArrays.jl
406+
_broadcast_mul_simplifiable(op, A, B) = Val(false)
407+
_broadcast_mul_mul(op, A, B) = simplify(Mul(broadcast(op, A...), B))
408+
409+
for op in (:*, :\)
410+
@eval begin
411+
_broadcast_mul_simplifiable(::typeof($op), (a,B)::Tuple{Union{AbstractVector,Number},AbstractMatrix}, C) = simplifiable(*, B, C)
412+
_broadcast_mul_mul(::typeof($op), (a,B)::Tuple{Union{AbstractVector,Number},AbstractMatrix}, C) = broadcast($op, a, (B*C))
413+
end
414+
end
415+
416+
for op in (:*, :/)
417+
@eval begin
418+
_broadcast_mul_simplifiable(::typeof($op), (A,b)::Tuple{AbstractMatrix,Union{AbstractVector,Number}}, C) = simplifiable(*, A, C)
419+
_broadcast_mul_mul(::typeof($op), (A,b)::Tuple{AbstractMatrix,Union{AbstractVector,Number}}, C) = broadcast($op, (A*C), b)
420+
end
421+
end
422+
423+
424+
425+
for op in (:*, :/, :\)
426+
@eval begin
427+
@inline simplifiable(M::Mul{BroadcastLayout{typeof($op)}}) = _broadcast_mul_simplifiable($op, arguments(BroadcastLayout{typeof($op)}(), M.A), M.B)
428+
@inline simplifiable(M::Mul{BroadcastLayout{typeof($op)},<:LazyLayouts}) = _broadcast_mul_simplifiable($op, arguments(BroadcastLayout{typeof($op)}(), M.A), M.B)
429+
@inline simplifiable(M::Mul{BroadcastLayout{typeof($op)},ApplyLayout{typeof(*)}}) = _broadcast_mul_simplifiable($op, arguments(BroadcastLayout{typeof($op)}(), M.A), M.B)
430+
@inline copy(M::Mul{BroadcastLayout{typeof($op)}}) = _broadcast_mul_mul($op, arguments(BroadcastLayout{typeof($op)}(), M.A), M.B)
431+
@inline copy(M::Mul{BroadcastLayout{typeof($op)},<:LazyLayouts}) = _broadcast_mul_mul($op, arguments(BroadcastLayout{typeof($op)}(), M.A), M.B)
432+
@inline copy(M::Mul{BroadcastLayout{typeof($op)},ApplyLayout{typeof(*)}}) = _broadcast_mul_mul($op, arguments(BroadcastLayout{typeof($op)}(), M.A), M.B)
433+
end
434+
end
435+
405436

406437
for op in (:*, :\, :/)
407438
@eval begin

src/linalg/inv.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,13 @@ applylayout(::Type{typeof(pinv)}, ::A) where A = PInvLayout{A}()
8888
simplifiable(::Mul{<:AbstractInvLayout}) = Val(true)
8989

9090
copy(M::Mul{<:AbstractInvLayout}) = ArrayLayouts.ldiv(pinv(M.A), M.B)
91-
copy(M::Mul{<:AbstractInvLayout,<:AbstractLazyLayout}) = ArrayLayouts.ldiv(pinv(M.A), M.B)
92-
@inline copy(M::Mul{<:AbstractInvLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = copy(mulreduce(M))
93-
@inline copy(M::Mul{<:AbstractInvLayout,ApplyLayout{typeof(*)}}) = simplify(M)
91+
copy(M::Mul{<:AbstractInvLayout, <:AbstractLazyLayout}) = ArrayLayouts.ldiv(pinv(M.A), M.B)
92+
@inline copy(M::Mul{<:AbstractInvLayout, <:DiagonalLayout{<:AbstractFillLayout}}) = copy(mulreduce(M))
93+
@inline copy(M::Mul{<:AbstractInvLayout, ApplyLayout{typeof(*)}}) = simplify(M)
9494
copy(L::Ldiv{<:AbstractInvLayout}) = pinv(L.A) * L.B
95+
copy(L::Ldiv{<:AbstractInvLayout, <:AbstractLazyLayout}) = pinv(L.A) * L.B
96+
copy(L::Ldiv{<:AbstractInvLayout, <:AbstractInvLayout}) = pinv(L.A) * L.B
97+
copy(L::Ldiv{<:AbstractInvLayout, ApplyLayout{typeof(*)}}) = pinv(L.A) * L.B
9598
Ldiv(A::Applied{<:Any,typeof(\)}) = Ldiv(A.args...)
9699

97100

src/padded.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,9 @@ copy(M::Mul{<:Union{TriangularLayout{'U', 'N', <:AbstractLazyLayout}, Triangular
510510
simplifiable(::Mul{<:Union{TriangularLayout{'U', 'N', <:AbstractLazyLayout}, TriangularLayout{'U', 'U', <:AbstractLazyLayout}}, <:Union{PaddedColumns,PaddedLayout}}) = Val(true)
511511

512512

513+
@inline simplifiable(M::Mul{BroadcastLayout{typeof(*)},<:Union{PaddedColumns,PaddedLayout}}) = simplifiable(Mul{BroadcastLayout{typeof(*)},UnknownLayout}(M.A,M.B))
514+
@inline copy(M::Mul{BroadcastLayout{typeof(*)},<:Union{PaddedColumns,PaddedLayout}}) = copy(Mul{BroadcastLayout{typeof(*)},UnknownLayout}(M.A,M.B))
515+
513516
simplifiable(::Mul{<:DualLayout{<:AbstractLazyLayout}, <:Union{PaddedColumns,PaddedLayout}}) = Val(true)
514517
copy(M::Mul{<:DualLayout{<:AbstractLazyLayout}, <:Union{PaddedColumns,PaddedLayout}}) = copy(mulreduce(M))
515518
simplifiable(::Mul{<:DiagonalLayout{<:AbstractFillLayout}, <:Union{PaddedColumns,PaddedLayout}}) = Val(true)

test/bandedtests.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,13 @@ LinearAlgebra.lmul!(β::Number, A::PseudoBandedMatrix) = (lmul!(β, A.data); A)
425425
@test MemoryLayout(BroadcastMatrix(cos, A)) isa BroadcastLayout
426426
end
427427
end
428+
429+
@testset "broadcast_mul_mul" begin
430+
A = BroadcastMatrix(*, randn(5,5), randn(5,5))
431+
B = ApplyArray(*, brand(5,5,1,2), brand(5,5,2,1))
432+
@test A * UpperTriangular(B) Matrix(A) * UpperTriangular(B)
433+
@test simplifiable(*, A, UpperTriangular(B)) == Val(false) # TODO: probably should be true
434+
end
428435
end
429436

430437
@testset "Cache" begin

test/broadcasttests.jl

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module BroadcastTests
22

33
using LazyArrays, ArrayLayouts, LinearAlgebra, FillArrays, Base64, Test
44
using StaticArrays, Tracker
5-
import LazyArrays: BroadcastLayout, arguments, LazyArrayStyle, sub_materialize
5+
import LazyArrays: BroadcastLayout, arguments, LazyArrayStyle, sub_materialize, simplifiable
66
import Base: broadcasted
77

88
using ..InfiniteArrays
@@ -224,6 +224,9 @@ using Infinities
224224
@test A[:,2] Ã[:,2] Matrix(A)[:,2]
225225
@test C*b Matrix(C)*b
226226

227+
@test simplifiable(*, A, B) == Val(true)
228+
@test simplifiable(*, Ã, B) == Val(true)
229+
227230
D = Diagonal(Fill(2,4))
228231
@test A*D Matrix(A)*D
229232
end
@@ -396,7 +399,7 @@ using Infinities
396399
@test a[:,1:3] isa Adjoint{Int,Vector{Int}}
397400
end
398401

399-
@testset "broadcast with adjtrans" begin
402+
@testset "broadcast with adjtrans/triangular/hermsym" begin
400403
a = BroadcastArray(real, ((1:5) .+ im))
401404
b = BroadcastArray(exp, ((1:5) .+ im))
402405
@test exp.(transpose(a)) isa Transpose{<:Any,<:BroadcastVector}
@@ -407,8 +410,17 @@ using Infinities
407410
@test exp.(b') isa BroadcastMatrix
408411
@test exp.(transpose(b)) == transpose(exp.(b))
409412
@test exp.(b') == exp.(b)'
413+
414+
A = BroadcastArray(*, ((1:5) .+ im), (1:5)')
415+
@test exp.(UpperTriangular(A)) isa BroadcastArray
416+
@test exp.(Symmetric(A)) isa BroadcastArray
417+
@test exp.(Hermitian(A)) isa BroadcastArray
418+
@test exp.(UpperTriangular(A)) == exp.(UpperTriangular(Matrix(A)))
419+
@test exp.(Symmetric(A)) == exp.(Symmetric(Matrix(A)))
420+
@test exp.(Hermitian(A)) == exp.(Hermitian(Matrix(A)))
410421
end
411422

423+
412424
@testset "linear indexing" begin
413425
a = BroadcastArray(real, ((1:5) .+ im))
414426
b = BroadcastArray(exp, ((1:5) .+ im))
@@ -421,6 +433,22 @@ using Infinities
421433
a = BroadcastArray(Base.literal_pow, ^, 1:5, Val(2))
422434
@test last(a) == 25
423435
end
436+
437+
@testset "BroadcastArray(*) * MulArray" begin
438+
A = BroadcastArray(*, 1:3, randn(3,4))
439+
B = ApplyArray(*, randn(4,3), randn(3,4))
440+
@test A*B Matrix(A)*Matrix(B)
441+
@test A*UpperTriangular(B) Matrix(A)*UpperTriangular(Matrix(B))
442+
@test simplifiable(*,A,B) == Val(false) # TODO: Why False?
443+
@test simplifiable(*,A,UpperTriangular(B)) == Val(false) # TODO: Why False?
444+
end
445+
446+
@testset "/" begin
447+
A = BroadcastArray(/, randn(3,4), randn(3,4))
448+
B = randn(4,3)
449+
@test A*B Matrix(A)*Matrix(B)
450+
@test simplifiable(*,A,B) == Val(false) # TODO: Why False?
451+
end
424452
end
425453

426454
end #module

test/cachetests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ using Infinities
189189
@testset "linalg" begin
190190
c = cache(Fill(3,3,3))
191191
@test fill(2,1,3) * c == fill(18,1,3)
192-
@test ApplyMatrix(exp,fill(3,3,3)) * c == exp(fill(3,3,3)) * fill(3,3,3)
192+
@test ApplyMatrix(exp,fill(3,3,3)) * c exp(fill(3,3,3)) * fill(3,3,3)
193193
@test BroadcastMatrix(exp,fill(3,3,3)) * c == exp.(fill(3,3,3)) * fill(3,3,3)
194194
@test fill(2,3)' * c == fill(18,1,3)
195195
@test fill(2,3,1)' * c == fill(18,1,3)

test/ldivtests.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,4 +205,12 @@ end
205205
@test rowsupport(invL, ()) == 1:0
206206
end
207207

208+
@testset "Inv \\ Lazy" begin
209+
A = randn(5,5)
210+
Ai = InvMatrix(A)
211+
@test Ai \ Ai I
212+
@test Ai \ BroadcastArray(exp, A) Ai \ exp.(A) A*exp.(A)
213+
@test Ai \ ApplyArray(*, A, A) A^3
214+
end
215+
208216
end # module

test/paddedtests.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,5 +430,12 @@ paddeddata(a::PaddedPadded) = a
430430
@test_throws SingularException ArrayLayouts.ldiv!(Bidiagonal(-1:3, 1:4, :L), c)
431431
@test_throws SingularException ArrayLayouts.ldiv!(Bidiagonal(-4:0, 1:4, :L), c)
432432
end
433+
434+
@testset "Broadcast * Padded" begin
435+
B = BroadcastArray(*, 1:8, (2:9)')
436+
p = Vcat(1:2, Zeros(6))
437+
@test B*p == Matrix(B)*p
438+
@test simplifiable(*,B,p) == Val(true)
439+
end
433440
end
434441
end # module

0 commit comments

Comments
 (0)