Skip to content

Commit 4491a56

Browse files
fix GPUArrays usage
1 parent 9cb9c36 commit 4491a56

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/default.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ function defaultalg(A,b)
99
# Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when
1010
# it makes sense according to the benchmarks, which is dependent on
1111
# whether MKL or OpenBLAS is being used
12-
if (A === nothing && b isa AbstractGPUArray) || A isa Matrix
12+
if (A === nothing && b isa GPUArrays.AbstractGPUArray) || A isa Matrix
1313
if (A === nothing || eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64}) &&
1414
ArrayInterface.can_setindex(b)
1515
if length(b) <= 10
@@ -39,7 +39,7 @@ function defaultalg(A,b)
3939

4040
# This catches the case where A is a CuMatrix
4141
# Which does not have LU fully defined
42-
elseif A isa AbstractGPUArray || b isa AbstractGPUArray
42+
elseif A isa GPUArrays.AbstractGPUArray || b isa GPUArrays.AbstractGPUArray
4343
alg = QRFactorization(false)
4444

4545
# Not factorizable operator, default to only using A*x
@@ -100,7 +100,7 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
100100

101101
# This catches the case where A is a CuMatrix
102102
# Which does not have LU fully defined
103-
elseif A isa AbstractGPUArray
103+
elseif A isa GPUArrays.AbstractGPUArray
104104
alg = QRFactorization(false)
105105
SciMLBase.solve(cache, alg, args...; kwargs...)
106106

@@ -158,7 +158,7 @@ function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters, abstol, reltol,
158158

159159
# This catches the case where A is a CuMatrix
160160
# Which does not have LU fully defined
161-
elseif A isa AbstractGPUArray
161+
elseif A isa GPUArrays.AbstractGPUArray
162162
alg = QRFactorization(false)
163163
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
164164

0 commit comments

Comments
 (0)