Skip to content

Commit 98720a9

Browse files
authored
overload mul! for gpu arryays (#167)
* mul * Update gpuarrays.jl * Update gpu_tests.jl * Update gpuarrays.jl * Update gpuarrays.jl * Update gpu_tests.jl * Update gpu_tests.jl * Update gpu_tests.jl * Update gpu_tests.jl * more test * Update gpu_tests.jl * Update gpu_tests.jl
1 parent d96f434 commit 98720a9

File tree

2 files changed

+218
-0
lines changed

2 files changed

+218
-0
lines changed

src/compat/gpuarrays.jl

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
const GPUComponentArray = ComponentArray{T,N,<:GPUArrays.AbstractGPUArray,Ax} where {T,N,Ax}
2+
const GPUComponentVector{T,Ax} = ComponentArray{T,1,<:GPUArrays.AbstractGPUVector,Ax}
3+
const GPUComponentMatrix{T,Ax} = ComponentArray{T,2,<:GPUArrays.AbstractGPUMatrix,Ax}
4+
const GPUComponentVecorMat{T,Ax} = Union{GPUComponentVector{T,Ax},GPUComponentMatrix{T,Ax}}
25

36
GPUArrays.backend(x::ComponentArray) = GPUArrays.backend(getdata(x))
47

@@ -71,3 +74,200 @@ function ComponentArray(nt::NamedTuple{names,<:Tuple{Vararg{Union{GPUArrays.Abst
7174
G = Base.typename(typeof(gpuarray)).wrapper # SciMLBase.parameterless_type(gpuarray)
7275
return GPUArrays.adapt(G, ComponentArray(NamedTuple{names}(map(GPUArrays.adapt(Array{T}), nt))))
7376
end
77+
78+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
79+
A::GPUComponentVecorMat,
80+
B::GPUComponentVecorMat, a::Number, b::Number)
81+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
82+
end
83+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
84+
A::GPUComponentVecorMat,
85+
B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
86+
a::Number, b::Number)
87+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
88+
end
89+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
90+
A::GPUComponentVecorMat,
91+
B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
92+
a::Number, b::Number)
93+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
94+
end
95+
96+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
97+
A::GPUComponentVecorMat,
98+
B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
99+
a::Number, b::Number)
100+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
101+
end
102+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
103+
A::GPUComponentVecorMat,
104+
B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
105+
}, a::Number, b::Number)
106+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
107+
end
108+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
109+
A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
110+
B::GPUComponentVecorMat, a::Number, b::Number)
111+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
112+
end
113+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
114+
A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
115+
B::GPUComponentVecorMat, a::Number, b::Number)
116+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
117+
end
118+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
119+
A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
120+
B::GPUComponentVecorMat, a::Number, b::Number)
121+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
122+
end
123+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
124+
A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
125+
}, B::GPUComponentVecorMat,
126+
a::Number, b::Number)
127+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
128+
end
129+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
130+
A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
131+
B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
132+
a::Number, b::Number)
133+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
134+
end
135+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
136+
A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
137+
},
138+
B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
139+
a::Number, b::Number)
140+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
141+
end
142+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
143+
A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
144+
B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
145+
a::Number, b::Number)
146+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
147+
end
148+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
149+
A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
150+
B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
151+
}, a::Number, b::Number)
152+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
153+
end
154+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
155+
A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
156+
B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
157+
a::Number, b::Number)
158+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
159+
end
160+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
161+
A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
162+
B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
163+
a::Number, b::Number)
164+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
165+
end
166+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
167+
A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
168+
B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
169+
a::Number, b::Number)
170+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
171+
end
172+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
173+
A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
174+
},
175+
B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
176+
}, a::Number, b::Number)
177+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
178+
end
179+
180+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
181+
A::GPUComponentVecorMat,
182+
B::GPUComponentVecorMat, a::Real, b::Real)
183+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
184+
end
185+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
186+
A::GPUComponentVecorMat,
187+
B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, a::Real,
188+
b::Real)
189+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
190+
end
191+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
192+
A::GPUComponentVecorMat,
193+
B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
194+
a::Real, b::Real)
195+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
196+
end
197+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
198+
A::GPUComponentVecorMat,
199+
B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
200+
}, a::Real, b::Real)
201+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
202+
end
203+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
204+
A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
205+
B::GPUComponentVecorMat, a::Real, b::Real)
206+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
207+
end
208+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
209+
A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
210+
B::GPUComponentVecorMat, a::Real, b::Real)
211+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
212+
end
213+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
214+
A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
215+
B::GPUComponentVecorMat, a::Real, b::Real)
216+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
217+
end
218+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
219+
A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
220+
}, B::GPUComponentVecorMat,
221+
a::Real, b::Real)
222+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
223+
end
224+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
225+
A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
226+
B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, a::Real,
227+
b::Real)
228+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
229+
end
230+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
231+
A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
232+
},
233+
B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
234+
a::Real, b::Real)
235+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
236+
end
237+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
238+
A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
239+
B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
240+
a::Real, b::Real)
241+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
242+
end
243+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
244+
A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
245+
B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
246+
}, a::Real, b::Real)
247+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
248+
end
249+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
250+
A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
251+
B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, a::Real,
252+
b::Real)
253+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
254+
end
255+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
256+
A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
257+
B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
258+
a::Real, b::Real)
259+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
260+
end
261+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
262+
A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
263+
B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
264+
a::Real, b::Real)
265+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
266+
end
267+
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
268+
A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
269+
},
270+
B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
271+
}, a::Real, b::Real)
272+
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
273+
end

test/gpu_tests.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,22 @@ end
4545
jlca3 = deepcopy(jlca)
4646
@test rmul!(jlca3, 2) == ComponentArray(jla .* 2, Axis(a=1:2, b=3:4))
4747
end
48+
@testset "mul!" begin
49+
A = jlca .* jlca';
50+
@test_nowarn mul!(deepcopy(A), A, A, 1, 2);
51+
@test_nowarn mul!(deepcopy(A), A', A', 1, 2);
52+
@test_nowarn mul!(deepcopy(A), A', A, 1, 2);
53+
@test_nowarn mul!(deepcopy(A), A, A', 1, 2);
54+
@test_nowarn mul!(deepcopy(A), A, getdata(A'), 1, 2);
55+
@test_nowarn mul!(deepcopy(A), getdata(A'), A, 1, 2);
56+
@test_nowarn mul!(deepcopy(A), getdata(A'), getdata(A'), 1, 2);
57+
@test_nowarn mul!(deepcopy(A), transpose(A), A, 1, 2);
58+
@test_nowarn mul!(deepcopy(A), A, transpose(A), 1, 2);
59+
@test_nowarn mul!(deepcopy(A), transpose(A), transpose(A), 1, 2);
60+
@test_nowarn mul!(deepcopy(A), transpose(getdata(A)), A, 1, 2);
61+
@test_nowarn mul!(deepcopy(A), A, transpose(getdata(A)), 1, 2);
62+
@test_nowarn mul!(deepcopy(A), transpose(getdata(A)), transpose(getdata(A)), 1, 2);
63+
@test_nowarn mul!(deepcopy(A), transpose(A), A', 1, 2);
64+
@test_nowarn mul!(deepcopy(A), A', transpose(A), 1, 2);
65+
end
4866
end

0 commit comments

Comments
 (0)