Skip to content

Commit 2b2b380

Browse files
ptiederafaqz
andauthored
Improve GPU functionality (#780)
* Improve GPU functionality * Add missing weakdeps * Update src/array/broadcast.jl Co-authored-by: Rafael Schouten <[email protected]> * Update src/array/broadcast.jl Co-authored-by: Rafael Schouten <[email protected]> * Push materialize fix * Clean up mapreduce and add a bunch of tests for JLArray broadcast * Add some more JLArray tests * Just return dest in broadcast * Update src/array/methods.jl Co-authored-by: Rafael Schouten <[email protected]> * Format --------- Co-authored-by: Rafael Schouten <[email protected]>
1 parent fe39de7 commit 2b2b380

File tree

5 files changed

+347
-45
lines changed

5 files changed

+347
-45
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ Interfaces = "0.3"
5454
IntervalSets = "0.5, 0.6, 0.7"
5555
InvertedIndices = "1"
5656
IteratorInterfaceExtensions = "1"
57+
JLArrays = "0.1"
5758
LinearAlgebra = "1"
5859
Makie = "0.19, 0.20, 0.21"
5960
OffsetArrays = "1"
@@ -85,6 +86,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
8586
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
8687
ImageFiltering = "6a3955dd-da59-5b1f-98d4-e7296123deb5"
8788
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
89+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
8890
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
8991
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
9092
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
@@ -95,4 +97,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
9597
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
9698

9799
[targets]
98-
test = ["Aqua", "ArrayInterface", "BenchmarkTools", "CategoricalArrays", "ColorTypes", "Combinatorics", "CoordinateTransformations", "DataFrames", "Distributions", "Documenter", "ImageFiltering", "ImageTransformations", "CairoMakie", "OffsetArrays", "Plots", "Random", "SafeTestsets", "StatsPlots", "Test", "Unitful"]
100+
test = ["Aqua", "ArrayInterface", "BenchmarkTools", "CategoricalArrays", "ColorTypes", "Combinatorics", "CoordinateTransformations", "DataFrames", "Distributions", "Documenter", "ImageFiltering", "ImageTransformations", "JLArrays", "CairoMakie", "OffsetArrays", "Plots", "Random", "SafeTestsets", "StatsPlots", "Test", "Unitful"]

src/array/broadcast.jl

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -47,26 +47,24 @@ function Broadcast.copy(bc::Broadcasted{DimensionalStyle{S}}) where S
4747
end
4848

4949
function Base.copyto!(dest::AbstractArray, bc::Broadcasted{DimensionalStyle{S}}) where S
50-
_dims = comparedims(dims(dest), _broadcasted_dims(bc); ignore_length_one=true, order=true)
50+
#TODO: this will cause a comparisson to happen twice. We should avoid that
51+
comparedims(dims(dest), _broadcasted_dims(bc); ignore_length_one=true, order=true)
5152
copyto!(dest, _unwrap_broadcasted(bc))
52-
A = _firstdimarray(bc)
53-
if A isa Nothing || _dims isa Nothing
54-
dest
55-
else
56-
rebuild(A, dest, _dims, refdims(A))
57-
end
53+
return dest
5854
end
59-
function Base.copyto!(dest::AbstractDimArray, bc::Broadcasted{DimensionalStyle{S}}) where S
60-
_dims = comparedims(dims(dest), _broadcasted_dims(bc); ignore_length_one=true, order=true)
61-
copyto!(parent(dest), _unwrap_broadcasted(bc))
62-
A = _firstdimarray(bc)
63-
if A isa Nothing || _dims isa Nothing
64-
dest
65-
else
66-
rebuild(A, parent(dest), _dims, refdims(A))
67-
end
55+
56+
57+
@inline function Base.Broadcast.materialize!(dest::AbstractDimArray, bc::Base.Broadcast.Broadcasted{<:Any})
58+
# needed because we need to check whether the dims are compatible in dest which are already
59+
# stripped when sent to copyto!
60+
comparedims(dims(dest), _broadcasted_dims(bc); ignore_length_one=true, order=true)
61+
style = DimensionalData.DimensionalStyle(Base.Broadcast.combine_styles(parent(dest), bc))
62+
Base.Broadcast.materialize!(style, parent(dest), bc)
63+
return dest
6864
end
6965

66+
67+
7068
function Base.similar(bc::Broadcast.Broadcasted{DimensionalStyle{S}}, ::Type{T}) where {S,T}
7169
A = _firstdimarray(bc)
7270
rebuildsliced(A, similar(_unwrap_broadcasted(bc), T, axes(bc)...), axes(bc), Symbol(""))

src/array/methods.jl

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -53,25 +53,9 @@ for (m, f) in ((:Statistics, :median), (:Base, :any), (:Base, :all))
5353
end
5454
end
5555

56-
# These are not exported but it makes a lot of things easier using them
57-
function Base._mapreduce_dim(f, op, nt::NamedTuple{(),<:Tuple}, A::AbstractDimArray, dims)
58-
rebuild(A, Base._mapreduce_dim(f, op, nt, parent(A), dimnum(A, _astuple(dims))), reducedims(A, dims))
59-
end
60-
function Base._mapreduce_dim(f, op, nt::NamedTuple{(),<:Tuple}, A::AbstractDimArray, dims::Colon)
61-
Base._mapreduce_dim(f, op, nt, parent(A), dims)
62-
end
63-
function Base._mapreduce_dim(f, op, nt, A::AbstractDimArray, dims)
64-
rebuild(A, Base._mapreduce_dim(f, op, nt, parent(A), dimnum(A, dims)), reducedims(A, dims))
65-
end
66-
function Base._mapreduce_dim(f, op, nt, A::AbstractDimArray, dims::Colon)
67-
rebuild(A, Base._mapreduce_dim(f, op, nt, parent(A), dimnum(A, dims)), reducedims(A, dims))
68-
end
69-
70-
function Base._mapreduce_dim(f, op, nt::Base._InitialValue, A::AbstractDimArray, dims)
71-
rebuild(A, Base._mapreduce_dim(f, op, nt, parent(A), dimnum(A, dims)), reducedims(A, dims))
72-
end
73-
function Base._mapreduce_dim(f, op, nt::Base._InitialValue, A::AbstractDimArray, dims::Colon)
74-
Base._mapreduce_dim(f, op, nt, parent(A), dims)
56+
function Base.mapreduce(f, op, A::AbstractDimArray; dims=Base.Colon(), kw...)
57+
dims === Colon() && return mapreduce(f, op, parent(A); kw...)
58+
rebuild(A, mapreduce(f, op, parent(A); dims=dimnum(A, dims), kw...), reducedims(A, dims))
7559
end
7660

7761

test/broadcast.jl

Lines changed: 112 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
using DimensionalData, Test
2-
2+
using JLArrays
33
using DimensionalData: NoLookup
44

55
# Tests taken from NamedDims. Thanks @oxinabox
66

77
da = ones(X(3))
8+
dajl = rebuild(da, JLArray(parent(da)));
89
@test Base.BroadcastStyle(typeof(da)) isa DimensionalData.DimensionalStyle
910

1011
@testset "standard case" begin
@@ -19,18 +20,35 @@ end
1920
@test da2 .* da2[:, 1:1] == [1, 4, 9, 16] * (1:2:8)'
2021
end
2122

23+
@testset "JLArray broadcast over length one dimension" begin
24+
da2 = DimArray(JLArray((1:4) * (1:2:8)'), (X, Y))
25+
@test Array(da2 .* da2[:, 1:1]) == [1, 4, 9, 16] * (1:2:8)'
26+
end
27+
2228
@testset "in place" begin
2329
@test parent(da .= 1 .* da .+ 7) == 8 * ones(3)
2430
@test dims(da .= 1 .* da .+ 7) == dims(da)
2531
end
2632

33+
@testset "JLArray in place" begin
34+
@test Array(parent(dajl .= 1 .* dajl .+ 7)) == 8 * ones(3)
35+
@test dims(dajl .= 1 .* dajl .+ 7) == dims(da)
36+
end
37+
2738
@testset "Dimension disagreement" begin
2839
@test_throws DimensionMismatch begin
2940
DimArray(zeros(3, 3, 3), (X, Y, Z)) .+
3041
DimArray(ones(3, 3, 3), (Y, Z, X))
3142
end
3243
end
3344

45+
@testset "JLArray Dimension disagreement" begin
46+
@test_throws DimensionMismatch begin
47+
DimArray(JLArray(zeros(3, 3, 3)), (X, Y, Z)) .+
48+
DimArray(JLArray(ones(3, 3, 3)), (Y, Z, X))
49+
end
50+
end
51+
3452
@testset "dims and regular" begin
3553
da = DimArray(ones(3, 3, 3), (X, Y, Z))
3654
left_sum = da .+ ones(3, 3, 3)
@@ -41,6 +59,16 @@ end
4159
@test dims(right_sum) == dims(da)
4260
end
4361

62+
@testset "JLArray dims and regular" begin
63+
da = DimArray(JLArray(ones(3, 3, 3)), (X, Y, Z))
64+
left_sum = da .+ ones(3, 3, 3)
65+
@test Array(left_sum) == fill(2, 3, 3, 3)
66+
@test dims(left_sum) == dims(da)
67+
right_sum = ones(3, 3, 3) .+ da
68+
@test Array(right_sum) == fill(2, 3, 3, 3)
69+
@test dims(right_sum) == dims(da)
70+
end
71+
4472
@testset "changing type" begin
4573
@test (da .> 0) isa DimArray
4674
@test (da .* da .> 0) isa DimArray
@@ -51,6 +79,16 @@ end
5179
@test (rand(3) .> 1 .> 0 .* da) isa DimArray
5280
end
5381

82+
@testset "JLArray changing type" begin
83+
@test (dajl .> 0) isa DimArray
84+
@test (dajl .* dajl .> 0) isa DimArray
85+
@test (dajl .> 0 .> rand(3)) isa DimArray
86+
@test (dajl .* rand(3) .> 0.0) isa DimArray
87+
@test (0 .> dajl .> 0 .> rand(3)) isa DimArray
88+
@test (rand(3) .> dajl .> 0 .* rand(3)) isa DimArray
89+
@test (rand(3) .> 1 .> 0 .* dajl) isa DimArray
90+
end
91+
5492
@testset "trailng dimensions" begin
5593
@test zeros(X(10), Y(5)) .* zeros(X(10), Y(1)) ==
5694
zeros(X(10), Y(5)) .* zeros(X(1), Y(1)) ==
@@ -79,6 +117,18 @@ end
79117
@test dims(s .+ v .+ m) == dims(m .+ s .+ v)
80118
end
81119

120+
@testset "JLArray broadcasting" begin
121+
v = DimArray(JLArray(zeros(3,)), X)
122+
m = DimArray(JLArray(ones(3, 3)), (X, Y))
123+
s = 0
124+
@test Array(v .+ m) == ones(3, 3) == Array(m .+ v)
125+
@test Array(s .+ m) == ones(3, 3) == Array(m .+ s)
126+
@test Array(s .+ v .+ m) == ones(3, 3) == Array(m .+ s .+ v)
127+
@test dims(v .+ m) == dims(m .+ v)
128+
@test dims(s .+ m) == dims(m .+ s)
129+
@test dims(s .+ v .+ m) == dims(m .+ s .+ v)
130+
end
131+
82132
@testset "adjoint broadcasting" begin
83133
a = DimArray(reshape(1:12, (4, 3)), (X, Y))
84134
b = DimArray(1:3, Y)
@@ -88,6 +138,17 @@ end
88138
@test dims(a .* b') == dims(a)
89139
end
90140

141+
@testset "JLArray adjoint broadcasting" begin
142+
a = DimArray(JLArray(reshape(1:12, (4, 3))), (X, Y))
143+
b = DimArray(JLArray(1:3), Y)
144+
@test_throws DimensionMismatch a .* b
145+
@test_throws DimensionMismatch parent(a) .* parent(b)
146+
@test Array(parent(a) .* parent(b)') == Array(parent(a .* b'))
147+
@test dims(a .* b') == dims(a)
148+
end
149+
150+
151+
91152
@testset "Mixed array types" begin
92153
casts = (
93154
A -> DimArray(A, (X, Y)), # Named Matrix
@@ -121,13 +182,26 @@ end
121182
@test_throws DimensionMismatch ac .= ab .+ ba
122183

123184
# check that dest is written into:
124-
@test dims(z .= ab .+ ba') == dims(ab .+ ba')
185+
z .= ab .+ ba'
125186
@test z == (ab.data .+ ba.data')
187+
end
126188

127-
@test dims(z .= ab .+ a_) ==
128-
(X(NoLookup(Base.OneTo(2))), Y(NoLookup(Base.OneTo(2))))
129-
@test dims(a_ .= ba' .+ ab) ==
130-
(X(NoLookup(Base.OneTo(2))), Y(NoLookup(Base.OneTo(2))))
189+
@testset "JLArray in-place assignment .=" begin
190+
ab = DimArray(JLArray(rand(2,2)), (X, Y))
191+
ba = DimArray(JLArray(rand(2,2)), (Y, X))
192+
ac = DimArray(JLArray(rand(2,2)), (X, Z))
193+
a_ = DimArray(JLArray(rand(2,2)), (X(), DimensionalData.AnonDim()))
194+
z = JLArray(zeros(2,2))
195+
196+
@test_throws DimensionMismatch z .= ab .+ ba
197+
@test_throws DimensionMismatch z .= ab .+ ac
198+
@test_throws DimensionMismatch a_ .= ab .+ ac
199+
@test_throws DimensionMismatch ab .= a_ .+ ac
200+
@test_throws DimensionMismatch ac .= ab .+ ba
201+
202+
# check that dest is written into:
203+
z .= ab .+ ba'
204+
@test z == (ab.data .+ ba.data')
131205
end
132206

133207
@testset "assign using named indexing and dotview" begin
@@ -137,6 +211,13 @@ end
137211
@test A == [1.0 1.0; 2.0 2.0; 7.0 7.0]
138212
end
139213

214+
@testset "JLArray assign using named indexing and dotview" begin
215+
A = DimArray(JLArray(zeros(3,2)), (X, Y))
216+
A[X=1:2] .= JLArray([1, 2])
217+
A[X=3] .= 7
218+
@test Array(A) == [1.0 1.0; 2.0 2.0; 7.0 7.0]
219+
end
220+
140221
@testset "0-dimensional array broadcasting" begin
141222
x = DimArray(fill(3), ())
142223
y = DimArray(fill(4), ())
@@ -168,6 +249,31 @@ end
168249
@test A[DimSelectors(sub)] == C[DimSelectors(sub)]
169250
end
170251

252+
@testset "JLArray DimIndices broadcasting" begin
253+
ds = X(1.0:0.2:2.0), Y(10:2:20)
254+
_A = (rand(ds))
255+
_B = (zeros(ds))
256+
_C = (zeros(ds))
257+
258+
A = rebuild(_A, JLArray(parent(_A)))
259+
B = rebuild(_B, JLArray(parent(_B)))
260+
C = rebuild(_C, JLArray(parent(_C)))
261+
262+
B[DimIndices(B)] .+= A
263+
C[DimSelectors(C)] .+= A
264+
@test Array(A) == Array(B) == Array(C)
265+
sub = A[1:4, 1:3]
266+
B .= 0
267+
C .= 0
268+
B[DimIndices(sub)] .+= sub
269+
C[DimSelectors(sub)] .+= sub
270+
@test Array(A[DimIndices(sub)]) == Array(B[DimIndices(sub)]) == Array(C[DimIndices(sub)])
271+
sub = A[2:4, 2:5]
272+
C .= 0
273+
C[DimSelectors(sub)] .+= sub
274+
@test Array(A[DimSelectors(sub)]) == Array(C[DimSelectors(sub)])
275+
end
276+
171277
# @testset "Competing Wrappers" begin
172278
# da = DimArray(ones(4), X)
173279
# ta = TrackedArray(5 * ones(4))

0 commit comments

Comments
 (0)