Skip to content

Commit 7329003

Browse files
patch default for CUDA
1 parent 7a4dace commit 7329003

File tree

3 files changed

+11
-6
lines changed

3 files changed

+11
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LinearSolve"
22
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
33
authors = ["SciML"]
4-
version = "1.0.0"
4+
version = "1.0.1"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/default.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
4343
# This catches the case where A is a CuMatrix
4444
# Which does not have LU fully defined
4545
elseif !(A isa AbstractDiffEqOperator)
46-
alg = QRFactorization()
46+
alg = QRFactorization(false)
4747
SciMLBase.solve(cache, alg, args...; kwargs...)
4848

4949
# Not factorizable operator, default to only using A*x
@@ -98,7 +98,7 @@ function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters, abstol, reltol,
9898
# This catches the case where A is a CuMatrix
9999
# Which does not have LU fully defined
100100
elseif !(A isa AbstractDiffEqOperator)
101-
alg = QRFactorization()
101+
alg = QRFactorization(false)
102102
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
103103

104104
# Not factorizable operator, default to only using A*x

src/factorization.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,15 +129,16 @@ end
129129
struct QRFactorization{P} <: AbstractFactorization
130130
pivot::P
131131
blocksize::Int
132+
inplace::Bool
132133
end
133134

134-
function QRFactorization()
135+
function QRFactorization(inplace = true)
135136
pivot = @static if VERSION < v"1.7beta"
136137
Val(false)
137138
else
138139
NoPivot()
139140
end
140-
QRFactorization(pivot, 16)
141+
QRFactorization(pivot, 16, inplace)
141142
end
142143

143144
function do_factorization(alg::QRFactorization, A, b, u)
@@ -147,7 +148,11 @@ function do_factorization(alg::QRFactorization, A, b, u)
147148
if A isa AbstractDiffEqOperator
148149
A = A.A
149150
end
150-
fact = qr!(A, alg.pivot; blocksize = alg.blocksize)
151+
if alg.inplace
152+
fact = qr!(A, alg.pivot; blocksize = alg.blocksize)
153+
else
154+
fact = qr(A) # CUDA.jl does not allow other args!
155+
end
151156
return fact
152157
end
153158

0 commit comments

Comments
 (0)