Skip to content

Commit 7948ff1

Browse files
authored
Fix row vector views (#84)
* Fix row vector views * sub views of BroadcastArray are BroadcastLayout * Update broadcasttests.jl * Update broadcasttests.jl * improve code
1 parent 7bbf717 commit 7948ff1

File tree

5 files changed

+213
-120
lines changed

5 files changed

+213
-120
lines changed

src/LazyArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, broadcas
3737
materialize!, eltypes
3838

3939
import LinearAlgebra: AbstractTriangular, AbstractQ, checksquare, pinv, fill!, tilebufsize, Abuf, Bbuf, Cbuf, dot, factorize, qr, lu, cholesky,
40-
norm2, norm1, normInf, normMinusInf, det, tr
40+
norm2, norm1, normInf, normMinusInf, det, tr, AdjOrTrans
4141

4242
import LinearAlgebra.BLAS: BlasFloat, BlasReal, BlasComplex
4343

src/lazybroadcasting.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ BroadcastMatrix(A::BroadcastMatrix) = A
4040

4141
Broadcasted(A::BroadcastArray) = instantiate(broadcasted(A.f, A.args...))
4242

43+
@inline BroadcastArray(A::AbstractArray) = BroadcastArray(call(A), arguments(A)...)
44+
4345
axes(A::BroadcastArray) = axes(Broadcasted(A))
4446
size(A::BroadcastArray) = map(length, axes(A))
4547

@@ -174,3 +176,34 @@ for op in (:+, :-)
174176
rowsupport(::BroadcastLayout{typeof($op)}, A, j) = convexunion(_broadcast_rowsupport.(Ref(size(A)), A.args, j)...)
175177
end
176178
end
179+
180+
181+
###
182+
# SubArray
183+
###
184+
185+
call(b::BroadcastLayout, a::SubArray) = call(b, parent(a))
186+
187+
sublayout(b::BroadcastLayout, _) = b
188+
189+
190+
_broadcastviewinds(::Tuple{}, inds) = ()
191+
_broadcastviewinds(sz, inds) =
192+
tuple(isone(sz[1]) ? OneTo(sz[1]) : inds[1], _broadcastviewinds(tail(sz), tail(inds))...)
193+
194+
_broadcastview(a, inds) = view(a, _broadcastviewinds(size(a), inds)...)
195+
196+
function arguments(b::BroadcastLayout, V::SubArray)
197+
args = arguments(parent(V))
198+
_broadcastview.(args, Ref(parentindices(V)))
199+
end
200+
201+
###
202+
# Transpose
203+
###
204+
205+
call(b::BroadcastLayout, a::AdjOrTrans) = call(b, parent(a))
206+
207+
transposelayout(b::BroadcastLayout) = b
208+
arguments(b::BroadcastLayout, A::Adjoint) = map(adjoint, arguments(b, parent(A)))
209+
arguments(b::BroadcastLayout, A::Transpose) = map(transpose, arguments(b, parent(A)))

src/linalg/mul.jl

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -191,26 +191,31 @@ sublayout(::ApplyLayout{typeof(*)}, _...) = ApplyLayout{typeof(*)}()
191191

192192
call(::ApplyLayout{typeof(*)}, V::SubArray) = *
193193

194-
function _mat_mul_arguments(V)
195-
P = parent(V)
196-
kr, jr = parentindices(V)
197-
as = arguments(P)
198-
kjr = intersect.(_mul_args_rows(kr, as...), _mul_args_cols(jr, reverse(as)...))
199-
view.(as, (kr, kjr...), (kjr..., jr))
194+
function _mat_mul_arguments(args, (kr,jr))
195+
kjr = intersect.(_mul_args_rows(kr, args...), _mul_args_cols(jr, reverse(args)...))
196+
view.(args, (kr, kjr...), (kjr..., jr))
200197
end
201198

202-
203199
_vec_mul_view(a...) = view(a...)
204200
_vec_mul_view(a::AbstractVector, kr, ::Colon) = view(a, kr)
205201

206-
function _vec_mul_arguments(V)
207-
P = parent(V)
208-
kr, = parentindices(V)
209-
as = arguments(P)
210-
kjr = intersect.(_mul_args_rows(kr, as...), _mul_args_cols(Base.OneTo(1), reverse(as)...))
211-
_vec_mul_view.(as, (kr, kjr...), (kjr..., :))
202+
# this is a vector view of a MulVector
203+
function _vec_mul_arguments(args, (kr,))
204+
kjr = intersect.(_mul_args_rows(kr, args...), _mul_args_cols(Base.OneTo(1), reverse(args)...))
205+
_vec_mul_view.(args, (kr, kjr...), (kjr..., :))
212206
end
213207

208+
# this is a vector view of a MulMatrix
209+
_vec_mul_arguments(args, (kr,jr)::Tuple{AbstractVector,Number}) =
210+
_mat_mul_arguments(args, (kr,jr))
211+
212+
# this is a row-vector view
213+
_vec_mul_arguments(args, (kr,jr)::Tuple{Number,AbstractVector}) =
214+
_vec_mul_arguments(reverse(map(transpose, args)), (jr,kr))
215+
216+
_mat_mul_arguments(V) = _mat_mul_arguments(arguments(parent(V)), parentindices(V))
217+
_vec_mul_arguments(V) = _vec_mul_arguments(arguments(parent(V)), parentindices(V))
218+
214219
arguments(::ApplyLayout{typeof(*)}, V::SubArray{<:Any,2}) = _mat_mul_arguments(V)
215220
arguments(::ApplyLayout{typeof(*)}, V::SubArray{<:Any,1}) = _vec_mul_arguments(V)
216221

test/broadcasttests.jl

Lines changed: 150 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -1,108 +1,152 @@
1-
using LazyArrays, Test
2-
3-
@testset "BroadcastArray" begin
4-
a = randn(6)
5-
b = BroadcastArray(exp, a)
6-
@test BroadcastArray(b) == BroadcastVector(b) == b
7-
8-
@test b == Vector(b) == exp.(a)
9-
@test b[2:5] isa BroadcastVector
10-
@test b[2:5] == exp.(a[2:5])
11-
12-
@test exp.(b) isa BroadcastVector
13-
@test b .+ SVector(1,2,3,4,5,6) isa BroadcastVector
14-
@test SVector(1,2,3,4,5,6) .+ b isa BroadcastVector
15-
16-
A = randn(6,6)
17-
B = BroadcastArray(exp, A)
18-
19-
@test Matrix(B) == exp.(A)
20-
21-
22-
C = BroadcastArray(+, A, 2)
23-
@test C == A .+ 2
24-
D = BroadcastArray(+, A, C)
25-
@test D == A + C
26-
27-
@test sum(B) sum(exp, A)
28-
@test sum(C) sum(A .+ 2)
29-
@test prod(B) prod(exp, A)
30-
@test prod(C) prod(A .+ 2)
31-
32-
x = Vcat([3,4], [1,1,1,1,1], 1:3)
33-
@test x .+ (1:10) isa Vcat
34-
@test (1:10) .+ x isa Vcat
35-
@test x + (1:10) isa Vcat
36-
@test (1:10) + x isa Vcat
37-
@test x .+ (1:10) == (1:10) .+ x == (1:10) + x == x + (1:10) == Vector(x) + (1:10)
38-
39-
@test exp.(x) isa Vcat
40-
@test exp.(x) == exp.(Vector(x))
41-
@test x .+ 2 isa Vcat
42-
@test (x .+ 2).args[end] x.args[end] .+ 2 3:5
43-
@test x .* 2 isa Vcat
44-
@test 2 .+ x isa Vcat
45-
@test 2 .* x isa Vcat
46-
47-
A = Vcat([[1 2; 3 4]], [[4 5; 6 7]])
48-
@test A .+ Ref(I) == Ref(I) .+ A == Vcat([[2 2; 3 5]], [[5 5; 6 8]])
49-
50-
@test_broken BroadcastArray(*,1.1,[1 2])[1] == 1.1
51-
52-
B = BroadcastArray(*, Diagonal(randn(5)), randn(5,5))
53-
@test B == broadcast(*,B.args...)
54-
@test colsupport(B,1) == rowsupport(B,1) == 1:1
55-
@test colsupport(B,3) == rowsupport(B,3) == 3:3
56-
@test colsupport(B,5) == rowsupport(B,5) == 5:5
57-
B = BroadcastArray(*, Diagonal(randn(5)), 2)
58-
@test B == broadcast(*,B.args...)
59-
@test colsupport(B,1) == rowsupport(B,1) == 1:1
60-
@test colsupport(B,3) == rowsupport(B,3) == 3:3
61-
@test colsupport(B,5) == rowsupport(B,5) == 5:5
62-
B = BroadcastArray(*, Diagonal(randn(5)), randn(5))
63-
@test B == broadcast(*,B.args...)
64-
@test colsupport(B,1) == rowsupport(B,1) == 1:1
65-
@test colsupport(B,3) == rowsupport(B,3) == 3:3
66-
@test colsupport(B,5) == rowsupport(B,5) == 5:5
67-
68-
B = BroadcastArray(+, Diagonal(randn(5)), 2)
69-
@test colsupport(B,1) == rowsupport(B,1) == 1:5
70-
@test colsupport(B,3) == rowsupport(B,3) == 1:5
71-
@test colsupport(B,5) == rowsupport(B,5) == 1:5
72-
end
73-
74-
@testset "vector*matrix broadcasting #27" begin
75-
H = [1., 0.]
76-
@test Mul(H, H') .+ 1 == H*H' .+ 1
77-
B = randn(2,2)
78-
@test Mul(H, H') .+ B == H*H' .+ B
79-
end
80-
81-
@testset "BroadcastArray +" begin
82-
a = BroadcastArray(+, randn(400), randn(400))
83-
b = similar(a)
84-
copyto!(b, a)
85-
if VERSION v"1.1"
86-
@test @allocated(copyto!(b, a)) == 0
1+
using LazyArrays, ArrayLayouts, LinearAlgebra, FillArrays, StaticArrays, Test
2+
import LazyArrays: BroadcastLayout, arguments, LazyArrayStyle
3+
import Base: broadcasted
4+
5+
@testset "Broadcasting" begin
6+
@testset "BroadcastArray" begin
7+
a = randn(6)
8+
b = BroadcastArray(exp, a)
9+
@test BroadcastArray(b) == BroadcastVector(b) == b
10+
11+
@test b == Vector(b) == exp.(a)
12+
@test b[2:5] isa BroadcastVector
13+
@test b[2:5] == exp.(a[2:5])
14+
15+
@test exp.(b) isa BroadcastVector
16+
@test b .+ SVector(1,2,3,4,5,6) isa BroadcastVector
17+
@test SVector(1,2,3,4,5,6) .+ b isa BroadcastVector
18+
19+
A = randn(6,6)
20+
B = BroadcastArray(exp, A)
21+
22+
@test Matrix(B) == exp.(A)
23+
24+
25+
C = BroadcastArray(+, A, 2)
26+
@test C == A .+ 2
27+
D = BroadcastArray(+, A, C)
28+
@test D == A + C
29+
30+
@test sum(B) sum(exp, A)
31+
@test sum(C) sum(A .+ 2)
32+
@test prod(B) prod(exp, A)
33+
@test prod(C) prod(A .+ 2)
34+
35+
x = Vcat([3,4], [1,1,1,1,1], 1:3)
36+
@test x .+ (1:10) isa Vcat
37+
@test (1:10) .+ x isa Vcat
38+
@test x + (1:10) isa Vcat
39+
@test (1:10) + x isa Vcat
40+
@test x .+ (1:10) == (1:10) .+ x == (1:10) + x == x + (1:10) == Vector(x) + (1:10)
41+
42+
@test exp.(x) isa Vcat
43+
@test exp.(x) == exp.(Vector(x))
44+
@test x .+ 2 isa Vcat
45+
@test (x .+ 2).args[end] x.args[end] .+ 2 3:5
46+
@test x .* 2 isa Vcat
47+
@test 2 .+ x isa Vcat
48+
@test 2 .* x isa Vcat
49+
50+
A = Vcat([[1 2; 3 4]], [[4 5; 6 7]])
51+
@test A .+ Ref(I) == Ref(I) .+ A == Vcat([[2 2; 3 5]], [[5 5; 6 8]])
52+
53+
@test_broken BroadcastArray(*,1.1,[1 2])[1] == 1.1
54+
55+
B = BroadcastArray(*, Diagonal(randn(5)), randn(5,5))
56+
@test B == broadcast(*,B.args...)
57+
@test colsupport(B,1) == rowsupport(B,1) == 1:1
58+
@test colsupport(B,3) == rowsupport(B,3) == 3:3
59+
@test colsupport(B,5) == rowsupport(B,5) == 5:5
60+
B = BroadcastArray(*, Diagonal(randn(5)), 2)
61+
@test B == broadcast(*,B.args...)
62+
@test colsupport(B,1) == rowsupport(B,1) == 1:1
63+
@test colsupport(B,3) == rowsupport(B,3) == 3:3
64+
@test colsupport(B,5) == rowsupport(B,5) == 5:5
65+
B = BroadcastArray(*, Diagonal(randn(5)), randn(5))
66+
@test B == broadcast(*,B.args...)
67+
@test colsupport(B,1) == rowsupport(B,1) == 1:1
68+
@test colsupport(B,3) == rowsupport(B,3) == 3:3
69+
@test colsupport(B,5) == rowsupport(B,5) == 5:5
70+
71+
B = BroadcastArray(+, Diagonal(randn(5)), 2)
72+
@test colsupport(B,1) == rowsupport(B,1) == 1:5
73+
@test colsupport(B,3) == rowsupport(B,3) == 1:5
74+
@test colsupport(B,5) == rowsupport(B,5) == 1:5
75+
end
76+
77+
@testset "vector*matrix broadcasting #27" begin
78+
H = [1., 0.]
79+
@test Mul(H, H') .+ 1 == H*H' .+ 1
80+
B = randn(2,2)
81+
@test Mul(H, H') .+ B == H*H' .+ B
82+
end
83+
84+
@testset "BroadcastArray +" begin
85+
a = BroadcastArray(+, randn(400), randn(400))
86+
b = similar(a)
87+
copyto!(b, a)
88+
if VERSION v"1.1"
89+
@test @allocated(copyto!(b, a)) == 0
90+
end
91+
@test b == a
92+
end
93+
94+
@testset "Lazy range" begin
95+
@test broadcasted(LazyArrayStyle{1}(), +, 1:5) 1:5
96+
@test broadcasted(LazyArrayStyle{1}(), +, 1, 1:5) 2:6
97+
@test broadcasted(LazyArrayStyle{1}(), +, 1:5, 1) 2:6
98+
99+
@test broadcasted(LazyArrayStyle{1}(), +, Fill(2,5)) Fill(2,5)
100+
@test broadcasted(LazyArrayStyle{1}(), +, 1, Fill(2,5)) Fill(3,5)
101+
@test broadcasted(LazyArrayStyle{1}(), +, Fill(2,5), 1) Fill(3,5)
102+
@test broadcasted(LazyArrayStyle{1}(), +, Ref(1), Fill(2,5)) Fill(3,5)
103+
@test broadcasted(LazyArrayStyle{1}(), +, Fill(2,5), Ref(1)) Fill(3,5)
104+
@test broadcasted(LazyArrayStyle{1}(), +, 1, Fill(2,5)) Fill(3,5)
105+
@test broadcasted(LazyArrayStyle{1}(), +, Fill(2,5), Fill(3,5)) Fill(5,5)
106+
107+
@test broadcasted(LazyArrayStyle{1}(), *, Zeros(5), Zeros(5)) Zeros(5)
108+
b = BroadcastArray(exp, randn(5))
109+
@test b .* Zeros(5) Zeros(5)
110+
@test Zeros(5) .* b Zeros(5)
111+
end
112+
113+
@testset "Sub-broadcast" begin
114+
A = BroadcastArray(exp,randn(5,5))
115+
V = view(A, 1:2,2:3)
116+
@test MemoryLayout(typeof(V)) isa BroadcastLayout{typeof(exp)}
117+
@test BroadcastArray(V) == A[1:2,2:3] == Array(A)[1:2,2:3]
118+
119+
B = BroadcastArray(-, randn(5,5), randn(5))
120+
V = view(B, 1:2,2:3)
121+
@test MemoryLayout(typeof(V)) isa BroadcastLayout{typeof(-)}
122+
@test BroadcastArray(V) == B[1:2,2:3] == Array(B)[1:2,2:3]
123+
end
124+
125+
@testset "AdjTrans" begin
126+
A = BroadcastArray(exp,randn(5,5))
127+
@test MemoryLayout(typeof(transpose(A))) isa BroadcastLayout{typeof(exp)}
128+
@test MemoryLayout(typeof(A')) isa BroadcastLayout{typeof(exp)}
129+
@test BroadcastArray(A') == BroadcastArray(transpose(A)) == A' == Array(A)'
130+
131+
B = BroadcastArray(-, randn(5,5), randn(5))
132+
@test MemoryLayout(typeof(transpose(B))) isa BroadcastLayout{typeof(-)}
133+
@test MemoryLayout(typeof(B')) isa BroadcastLayout{typeof(-)}
134+
@test BroadcastArray(B') == BroadcastArray(transpose(B)) == B' == Array(B)'
135+
136+
Vc = view(B', 1:2,1:3)
137+
Vt = view(transpose(B), 1:2,1:3)
138+
@test MemoryLayout(typeof(Vc)) isa BroadcastLayout{typeof(-)}
139+
@test MemoryLayout(typeof(Vt)) isa BroadcastLayout{typeof(-)}
140+
@test arguments(Vc) == (B.args[1][1:3,1:2]', permutedims(B.args[2][1:3]))
141+
@test arguments(Vt) == (transpose(B.args[1][1:3,1:2]), permutedims(B.args[2][1:3]))
142+
@test BroadcastArray(Vc) == BroadcastArray(Vt) == Vc == Vt == (Array(B)')[1:2,1:3]
143+
144+
Vc = view(B,1:3,1:2)'
145+
Vt = transpose(view(B,1:3,1:2))
146+
@test MemoryLayout(typeof(Vc)) isa BroadcastLayout{typeof(-)}
147+
@test MemoryLayout(typeof(Vt)) isa BroadcastLayout{typeof(-)}
148+
@test arguments(Vc) == (B.args[1][1:3,1:2]', permutedims(B.args[2][1:3]))
149+
@test arguments(Vt) == (transpose(B.args[1][1:3,1:2]), permutedims(B.args[2][1:3]))
150+
@test BroadcastArray(Vc) == BroadcastArray(Vt) == Vc == (Array(B)')[1:2,1:3]
87151
end
88-
@test b == a
89-
end
90-
91-
@testset "Lazy range" begin
92-
@test broadcasted(LazyArrayStyle{1}(), +, 1:5) 1:5
93-
@test broadcasted(LazyArrayStyle{1}(), +, 1, 1:5) 2:6
94-
@test broadcasted(LazyArrayStyle{1}(), +, 1:5, 1) 2:6
95-
96-
@test broadcasted(LazyArrayStyle{1}(), +, Fill(2,5)) Fill(2,5)
97-
@test broadcasted(LazyArrayStyle{1}(), +, 1, Fill(2,5)) Fill(3,5)
98-
@test broadcasted(LazyArrayStyle{1}(), +, Fill(2,5), 1) Fill(3,5)
99-
@test broadcasted(LazyArrayStyle{1}(), +, Ref(1), Fill(2,5)) Fill(3,5)
100-
@test broadcasted(LazyArrayStyle{1}(), +, Fill(2,5), Ref(1)) Fill(3,5)
101-
@test broadcasted(LazyArrayStyle{1}(), +, 1, Fill(2,5)) Fill(3,5)
102-
@test broadcasted(LazyArrayStyle{1}(), +, Fill(2,5), Fill(3,5)) Fill(5,5)
103-
104-
@test broadcasted(LazyArrayStyle{1}(), *, Zeros(5), Zeros(5)) Zeros(5)
105-
b = BroadcastArray(exp, randn(5))
106-
@test b .* Zeros(5) Zeros(5)
107-
@test Zeros(5) .* b Zeros(5)
108152
end

test/multests.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,4 +1073,15 @@ end
10731073
@test MemoryLayout(typeof(transpose(V))) isa ApplyLayout{typeof(*)}
10741074
@test transpose(V) == ApplyArray(transpose(V))
10751075
end
1076+
1077+
@testset "row-vec * fix" begin
1078+
A = ApplyArray(*,[1 2; 3 4], Vcat(Fill(1,1,3),Fill(2,1,3)))
1079+
V = view(A, 2, 1:2)
1080+
@test arguments(V) == ([1 2; 1 2], [3,4])
1081+
@test ApplyArray(V) == Array(V) == A[2,1:2] == Array(A)[2,1:2]
1082+
1083+
V = view(A, 1:2, 2)
1084+
@test arguments(V) == ([1 2; 3 4], [1,2])
1085+
@test ApplyArray(V) == Array(V) == A[1:2,2] == Array(A)[1:2,2]
1086+
end
10761087
end

0 commit comments

Comments
 (0)