Skip to content

Commit 2c3fb73

Browse files
authored
Overload more GPU operations (#164)
* make adapt work on GPU * overload more gpu operations * update
1 parent c409a9e commit 2c3fb73

File tree

2 files changed

+44
-3
lines changed

2 files changed

+44
-3
lines changed

src/compat/gpuarrays.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,23 @@ function GPUArrays.Adapt.adapt_structure(to, x::ComponentArray)
77
return ComponentArray(data, getaxes(x))
88
end
99

10+
GPUArrays.Adapt.adapt_storage(::Type{ComponentArray{T,N,A,Ax}}, xs::AT) where {T,N,A,Ax,AT<:AbstractArray} =
11+
GPUArrays.Adapt.adapt_storage(A, xs)
12+
13+
function Base.fill!(A::GPUComponentArray{T}, x) where {T}
14+
length(A) == 0 && return A
15+
GPUArrays.gpu_call(A, convert(T, x)) do ctx, a, val
16+
idx = GPUArrays.@linearidx(a)
17+
@inbounds a[idx] = val
18+
return
19+
end
20+
A
21+
end
22+
23+
LinearAlgebra.dot(x::GPUComponentArray, y::GPUComponentArray) = dot(getdata(x), getdata(y))
24+
LinearAlgebra.norm(ca::GPUComponentArray, p::Real) = norm(getdata(ca), p)
25+
LinearAlgebra.rmul!(ca::GPUComponentArray, b::Number) = GPUArrays.generic_rmul!(ca, b)
26+
1027
function Base.map(f, x::GPUComponentArray, args...)
1128
data = map(f, getdata(x), getdata.(args)...)
1229
return ComponentArray(data, getaxes(x))
@@ -46,4 +63,4 @@ for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)),
4663
Base.$(fname!)(f::Function, r::GPUComponentArray, A::GPUComponentArray{T}) where T =
4764
GPUArrays.mapreducedim!(f, $(op), getdata(r), getdata(A); init=neutral_element($(op), T))
4865
end
49-
end
66+
end

test/gpu_tests.jl

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,42 @@ jlca = ComponentArray(jla, Axis(a=1:2, b=3:4))
77

88
@testset "Broadcasting" begin
99
@test identity.(jlca + jla) ./ 2 == jlca
10-
10+
1111
@test getdata(map(identity, jlca)) isa JLArray
1212
@test all(==(0), map(-, jlca, jla))
1313
@test all(map(-, jlca, jlca) .== 0)
1414
@test all(==(0), map(-, jla, jlca))
1515

1616
@test any(==(1), jlca)
1717
@test count(>(2), jlca) == 2
18-
18+
1919
# Make sure mapreducing multiple arrays works
2020
@test mapreduce(==, +, jlca, jla) == 4
2121
@test mapreduce(abs2, +, jlca) == 30
2222

2323
@test all(map(sin, jlca) .== sin.(jlca) .== sin.(jla) .≈ sin.(1:4))
2424
end
25+
26+
@testset "adapt" begin
27+
x = [1 2; 3 4]
28+
jlx = JLArrays.Adapt.adapt(typeof(jlca), x)
29+
@test jlx isa JLArray
30+
end
31+
32+
@testset "Linear Algebra" begin
33+
@testset "fill!" begin
34+
jlca2 = deepcopy(jlca)
35+
jlca2 = fill!(jlca2, 2)
36+
@test jlca2 == ComponentArray(jl([2,2,2,2]), Axis(a=1:2, b=3:4))
37+
end
38+
39+
@testset "norm" begin
40+
@test norm(jlca, 2) == norm(jla,2)
41+
@test norm(jlca, Inf) == norm(jla,Inf)
42+
end
43+
44+
@testset "rmul!" begin
45+
jlca3 = deepcopy(jlca)
46+
@test rmul!(jlca3, 2) == ComponentArray(jla .* 2, Axis(a=1:2, b=3:4))
47+
end
48+
end

0 commit comments

Comments
 (0)