Skip to content

Commit bc0f4ee

Browse files
Merge pull request #626 from SciML/cudss
Fix CUDSS dispatches
2 parents 1d2fa41 + e07ed9f commit bc0f4ee

File tree

6 files changed

+57
-4
lines changed

6 files changed

+57
-4
lines changed

.buildkite/pipeline.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,6 @@ steps:
1212
GROUP: 'LinearSolveCUDA'
1313
JULIA_PKG_SERVER: "" # it often struggles with our large artifacts
1414
# SECRET_CODECOV_TOKEN: "..."
15-
timeout_in_minutes: 30
15+
timeout_in_minutes: 180
1616
# Don't run Buildkite if the commit message includes the text [skip tests]
1717
if: build.message !~ /\[skip tests\]/

ext/LinearSolveCUDAExt.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ using LinearSolve
55
using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface
66
using SciMLBase: AbstractSciMLOperator
77

8+
function LinearSolve.is_cusparse(A::Union{CUDA.CUSPARSE.CuSparseMatrixCSR, CUDA.CUSPARSE.CuSparseMatrixCSC})
9+
true
10+
end
11+
812
function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSR{Tv, Ti}, b,
913
assump::OperatorAssumptions{Bool}) where {Tv, Ti}
1014
if LinearSolve.cudss_loaded(A)

ext/LinearSolveSparseArraysExt.jl

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module LinearSolveSparseArraysExt
33
using LinearSolve, LinearAlgebra
44
using SparseArrays
55
using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, getcolptr
6-
using LinearSolve: BLASELTYPES, pattern_changed
6+
using LinearSolve: BLASELTYPES, pattern_changed, ArrayInterface
77

88
# Can't `using KLU` because cannot have a dependency in there without
99
# requiring the user does `using KLU`
@@ -100,15 +100,31 @@ function LinearSolve.init_cacheval(
100100
Pl, Pr,
101101
maxiters::Int, abstol, reltol,
102102
verbose::Bool, assumptions::OperatorAssumptions) where {T<:BLASELTYPES}
103-
SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC{T, Int64}(zero(Int64), zero(Int64), [Int64(1)], Int64[], T[]))
103+
if LinearSolve.is_cusparse(A)  # qualified like the Int32 method below; `is_cusparse` is not in this module's explicit `using LinearSolve: ...` list
104+
ArrayInterface.lu_instance(A)
105+
else
106+
SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC{T, Int64}(zero(Int64), zero(Int64), [Int64(1)], Int64[], T[]))
107+
end
104108
end
105109

106110
function LinearSolve.init_cacheval(
107111
alg::LUFactorization, A::AbstractSparseArray{T, Int32}, b, u,
108112
Pl, Pr,
109113
maxiters::Int, abstol, reltol,
110114
verbose::Bool, assumptions::OperatorAssumptions) where {T<:BLASELTYPES}
111-
SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC{T, Int32}(zero(Int32), zero(Int32), [Int32(1)], Int32[], T[]))
115+
if LinearSolve.is_cusparse(A)
116+
ArrayInterface.lu_instance(A)
117+
else
118+
SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC{T, Int32}(zero(Int32), zero(Int32), [Int32(1)], Int32[], T[]))
119+
end
120+
end
121+
122+
function LinearSolve.init_cacheval(
123+
alg::LUFactorization, A::LinearSolve.GPUArraysCore.AnyGPUArray, b, u,
124+
Pl, Pr,
125+
maxiters::Int, abstol, reltol,
126+
verbose::Bool, assumptions::OperatorAssumptions)
127+
ArrayInterface.lu_instance(A)
112128
end
113129

114130
function LinearSolve.init_cacheval(
@@ -120,6 +136,14 @@ function LinearSolve.init_cacheval(
120136
PREALLOCATED_UMFPACK
121137
end
122138

139+
function LinearSolve.init_cacheval(
140+
alg::UMFPACKFactorization, A::LinearSolve.GPUArraysCore.AnyGPUArray, b, u,
141+
Pl, Pr,
142+
maxiters::Int, abstol, reltol,
143+
verbose::Bool, assumptions::OperatorAssumptions)
144+
nothing
145+
end
146+
123147
function LinearSolve.init_cacheval(
124148
alg::UMFPACKFactorization, A::AbstractSparseArray{T, Int64}, b, u,
125149
Pl, Pr,
@@ -191,6 +215,14 @@ function LinearSolve.init_cacheval(
191215
PREALLOCATED_KLU
192216
end
193217

218+
function LinearSolve.init_cacheval(
219+
alg::KLUFactorization, A::LinearSolve.GPUArraysCore.AnyGPUArray, b, u,
220+
Pl, Pr,
221+
maxiters::Int, abstol, reltol,
222+
verbose::Bool, assumptions::OperatorAssumptions)
223+
nothing
224+
end
225+
194226
function LinearSolve.init_cacheval(
195227
alg::KLUFactorization, A::AbstractSparseArray{Float64, Int32}, b, u, Pl, Pr,
196228
maxiters::Int, abstol,

src/LinearSolve.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ end
217217
ALREADY_WARNED_CUDSS = Ref{Bool}(false)
218218
error_no_cudss_lu(A) = nothing
219219
cudss_loaded(A) = false
220+
is_cusparse(A) = false
220221

221222
export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
222223
GenericLUFactorization, SimpleLUFactorization, RFLUFactorization,

test/gpu/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[deps]
22
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
33
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
4+
CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e"
45
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
56
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
67
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

test/gpu/cuda.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using LinearSolve, CUDA, LinearAlgebra, SparseArrays, StableRNGs
2+
using CUDA.CUSPARSE, CUDSS
23
using Test
34

45
CUDA.allowscalar(false)
@@ -91,3 +92,17 @@ prob2 = LinearProblem(transpose(A), b)
9192
sol = solve(prob2, alg; alias = LinearAliasSpecifier(alias_A = false))
9293
@test norm(transpose(A) * sol.u .- b) < 1e-5
9394
end
95+
96+
@testset "CUDSS" begin
97+
T = Float32
98+
n = 100
99+
A_cpu = sprand(T, n, n, 0.05) + I
100+
x_cpu = zeros(T, n)
101+
b_cpu = rand(T, n)
102+
103+
A_gpu_csr = CuSparseMatrixCSR(A_cpu)
104+
b_gpu = CuVector(b_cpu)
105+
106+
prob = LinearProblem(A_gpu_csr, b_gpu)
107+
sol = solve(prob)
108+
end

0 commit comments

Comments
 (0)