Skip to content

Commit a51b0e2

Browse files
Merge pull request #608 from SciML/nonsquare
Fix non-square sparse matrix preallocations for defaults
2 parents 1fca18c + 23f73e4 commit a51b0e2

File tree

4 files changed

+118
-31
lines changed

4 files changed

+118
-31
lines changed

ext/LinearSolveSparseArraysExt.jl

Lines changed: 76 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,6 @@ function LinearSolve.init_cacheval(alg::RFLUFactorization,
2727
nothing, nothing
2828
end
2929

30-
function LinearSolve.init_cacheval(
31-
alg::QRFactorization, A::Symmetric{<:Number, <:SparseMatrixCSC}, b, u, Pl, Pr,
32-
maxiters::Int, abstol, reltol, verbose::Bool,
33-
assumptions::OperatorAssumptions)
34-
return nothing
35-
end
36-
3730
function LinearSolve.handle_sparsematrixcsc_lu(A::AbstractSparseMatrixCSC)
3831
lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
3932
check = false)
@@ -71,22 +64,51 @@ const PREALLOCATED_UMFPACK = SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC(0, 0
7164
Int[], Float64[]))
7265

7366
function LinearSolve.init_cacheval(
74-
alg::UMFPACKFactorization, A::SparseMatrixCSC{Float64, Int}, b, u,
67+
alg::Union{LUFactorization, GenericLUFactorization}, A::AbstractSparseArray{<:Number, <:Integer}, b, u,
68+
Pl, Pr,
69+
maxiters::Int, abstol, reltol,
70+
verbose::Bool, assumptions::OperatorAssumptions)
71+
nothing
72+
end
73+
74+
function LinearSolve.init_cacheval(
75+
alg::UMFPACKFactorization, A::AbstractArray, b, u,
76+
Pl, Pr,
77+
maxiters::Int, abstol, reltol,
78+
verbose::Bool, assumptions::OperatorAssumptions)
79+
nothing
80+
end
81+
82+
function LinearSolve.init_cacheval(
83+
alg::Union{LUFactorization, GenericLUFactorization}, A::AbstractSparseArray{Float64, Int64}, b, u,
7584
Pl, Pr,
7685
maxiters::Int, abstol, reltol,
7786
verbose::Bool, assumptions::OperatorAssumptions)
7887
PREALLOCATED_UMFPACK
7988
end
8089

8190
function LinearSolve.init_cacheval(
82-
alg::UMFPACKFactorization, A::AbstractSparseArray{Float64}, b, u, Pl, Pr,
91+
alg::Union{LUFactorization, GenericLUFactorization}, A::AbstractSparseArray{Float64, Int32}, b, u,
92+
Pl, Pr,
93+
maxiters::Int, abstol, reltol,
94+
verbose::Bool, assumptions::OperatorAssumptions)
95+
SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC{Float64, Int32}(0, 0, [Int32(1)], Int32[], Float64[]))
96+
end
97+
98+
function LinearSolve.init_cacheval(
99+
alg::UMFPACKFactorization, A::AbstractSparseArray{Float64, Int}, b, u, Pl, Pr,
83100
maxiters::Int, abstol,
84101
reltol,
85102
verbose::Bool, assumptions::OperatorAssumptions)
86-
A = convert(AbstractMatrix, A)
87-
zerobased = SparseArrays.getcolptr(A)[1] == 0
88-
return SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC(size(A)..., getcolptr(A),
89-
rowvals(A), nonzeros(A)))
103+
PREALLOCATED_UMFPACK
104+
end
105+
106+
function LinearSolve.init_cacheval(
107+
alg::UMFPACKFactorization, A::AbstractSparseArray{Float64, Int32}, b, u,
108+
Pl, Pr,
109+
maxiters::Int, abstol, reltol,
110+
verbose::Bool, assumptions::OperatorAssumptions)
111+
SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC{Float64, Int32}(0, 0, [Int32(1)], Int32[], Float64[]))
90112
end
91113

92114
function SciMLBase.solve!(
@@ -118,7 +140,7 @@ function SciMLBase.solve!(
118140
F = LinearSolve.@get_cacheval(cache, :UMFPACKFactorization)
119141
if F.status == SparseArrays.UMFPACK.UMFPACK_OK
120142
y = ldiv!(cache.u, F, cache.b)
121-
SciMLBase.build_linear_solution(alg, y, nothing, cache)
143+
SciMLBase.build_linear_solution(alg, y, nothing, cache; retcode = ReturnCode.Success)
122144
else
123145
SciMLBase.build_linear_solution(
124146
alg, cache.u, nothing, cache; retcode = ReturnCode.Infeasible)
@@ -129,21 +151,27 @@ const PREALLOCATED_KLU = KLU.KLUFactorization(SparseMatrixCSC(0, 0, [1], Int[],
129151
Float64[]))
130152

131153
function LinearSolve.init_cacheval(
132-
alg::KLUFactorization, A::SparseMatrixCSC{Float64, Int}, b, u, Pl,
154+
alg::KLUFactorization, A::AbstractArray, b, u, Pl,
133155
Pr,
134156
maxiters::Int, abstol, reltol,
135157
verbose::Bool, assumptions::OperatorAssumptions)
158+
nothing
159+
end
160+
161+
function LinearSolve.init_cacheval(
162+
alg::KLUFactorization, A::AbstractSparseArray{Float64, Int64}, b, u, Pl, Pr,
163+
maxiters::Int, abstol,
164+
reltol,
165+
verbose::Bool, assumptions::OperatorAssumptions)
136166
PREALLOCATED_KLU
137167
end
138168

139169
function LinearSolve.init_cacheval(
140-
alg::KLUFactorization, A::AbstractSparseArray{Float64}, b, u, Pl, Pr,
170+
alg::KLUFactorization, A::AbstractSparseArray{Float64, Int32}, b, u, Pl, Pr,
141171
maxiters::Int, abstol,
142172
reltol,
143173
verbose::Bool, assumptions::OperatorAssumptions)
144-
A = convert(AbstractMatrix, A)
145-
return KLU.KLUFactorization(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
146-
nonzeros(A)))
174+
KLU.KLUFactorization(SparseMatrixCSC{Float64, Int32}(0, 0, [Int32(1)], Int32[], Float64[]))
147175
end
148176

149177
# TODO: guard this against errors
@@ -173,7 +201,7 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::KLUFactorization;
173201
F = LinearSolve.@get_cacheval(cache, :KLUFactorization)
174202
if F.common.status == KLU.KLU_OK
175203
y = ldiv!(cache.u, F, cache.b)
176-
SciMLBase.build_linear_solution(alg, y, nothing, cache)
204+
SciMLBase.build_linear_solution(alg, y, nothing, cache; retcode = ReturnCode.Success)
177205
else
178206
SciMLBase.build_linear_solution(
179207
alg, cache.u, nothing, cache; retcode = ReturnCode.Infeasible)
@@ -249,6 +277,34 @@ function LinearSolve.defaultalg(
249277
end
250278
end
251279

280+
# SPQR Handling
281+
function LinearSolve.init_cacheval(
282+
alg::QRFactorization, A::AbstractSparseArray{<:Number, <:Integer}, b, u,
283+
Pl, Pr,
284+
maxiters::Int, abstol, reltol,
285+
verbose::Bool, assumptions::OperatorAssumptions)
286+
nothing
287+
end
288+
289+
function LinearSolve.init_cacheval(alg::QRFactorization, A::SparseMatrixCSC{Float64, Int}, b, u, Pl, Pr,
290+
maxiters::Int, abstol, reltol, verbose::Bool,
291+
assumptions::OperatorAssumptions)
292+
LinearSolve.ArrayInterface.qr_instance(convert(AbstractMatrix, A), alg.pivot)
293+
end
294+
295+
function LinearSolve.init_cacheval(alg::QRFactorization, A::SparseMatrixCSC{Float64, Int32}, b, u, Pl, Pr,
296+
maxiters::Int, abstol, reltol, verbose::Bool,
297+
assumptions::OperatorAssumptions)
298+
LinearSolve.ArrayInterface.qr_instance(convert(AbstractMatrix, A), alg.pivot)
299+
end
300+
301+
function LinearSolve.init_cacheval(
302+
alg::QRFactorization, A::Symmetric{<:Number, <:SparseMatrixCSC}, b, u, Pl, Pr,
303+
maxiters::Int, abstol, reltol, verbose::Bool,
304+
assumptions::OperatorAssumptions)
305+
return nothing
306+
end
307+
252308
LinearSolve.PrecompileTools.@compile_workload begin
253309
A = sprand(4, 4, 0.3) + I
254310
b = rand(4)

ext/LinearSolveSparspakExt.jl

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,23 @@ function LinearSolve.init_cacheval(
1717
end
1818

1919
function LinearSolve.init_cacheval(
20-
::SparspakFactorization, A::AbstractSparseMatrixCSC, b, u, Pl, Pr, maxiters::Int, abstol,
20+
::SparspakFactorization, A::AbstractSparseMatrixCSC{Tv, Ti}, b, u, Pl, Pr, maxiters::Int, abstol,
2121
reltol,
22-
verbose::Bool, assumptions::OperatorAssumptions)
23-
A = convert(AbstractMatrix, A)
24-
if A isa SparseArrays.AbstractSparseArray
25-
return sparspaklu(
26-
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
27-
nonzeros(A)),
28-
factorize = false)
22+
verbose::Bool, assumptions::OperatorAssumptions) where {Tv, Ti}
23+
24+
if size(A,1) == size(A,2)
25+
A = convert(AbstractMatrix, A)
26+
if A isa SparseArrays.AbstractSparseArray
27+
return sparspaklu(
28+
SparseMatrixCSC{Tv, Ti}(size(A)..., getcolptr(A), rowvals(A),
29+
nonzeros(A)),
30+
factorize = false)
31+
else
32+
return sparspaklu(SparseMatrixCSC{Tv, Ti}(zero(Ti), zero(Ti), [one(Ti)], Ti[], eltype(A)[]),
33+
factorize = false)
34+
end
2935
else
30-
return sparspaklu(SparseMatrixCSC(0, 0, [1], Int[], eltype(A)[]),
31-
factorize = false)
36+
PREALLOCATED_SPARSEPAK
3237
end
3338
end
3439

src/LinearSolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ end
174174

175175
y = _ldiv!(cache.u, @get_cacheval(cache, $(Meta.quot(defaultalg_symbol(alg)))),
176176
cache.b)
177-
return SciMLBase.build_linear_solution(alg, y, nothing, cache)
177+
return SciMLBase.build_linear_solution(alg, y, nothing, cache; retcode = ReturnCode.Success)
178178
end
179179
end
180180

test/default_algs.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,29 @@ cache.A = [2.0 1.0
144144
sol = solve!(cache)
145145

146146
@test !SciMLBase.successful_retcode(sol.retcode)
147+
148+
## Non-square Sparse Defaults
149+
# https://github.com/SciML/NonlinearSolve.jl/issues/599
150+
A = SparseMatrixCSC{Float64, Int64}([
151+
1.0 0.0
152+
1.0 1.0
153+
])
154+
b = ones(2)
155+
A2 = hcat(A,A)
156+
prob = LinearProblem(A, b)
157+
@test SciMLBase.successful_retcode(solve(prob))
158+
159+
prob2 = LinearProblem(A2, b)
160+
@test SciMLBase.successful_retcode(solve(prob2))
161+
162+
A = SparseMatrixCSC{Float64, Int32}([
163+
1.0 0.0
164+
1.0 1.0
165+
])
166+
b = ones(2)
167+
A2 = hcat(A,A)
168+
prob = LinearProblem(A, b)
169+
@test_broken SciMLBase.successful_retcode(solve(prob))
170+
171+
prob2 = LinearProblem(A2, b)
172+
@test SciMLBase.successful_retcode(solve(prob2))

0 commit comments

Comments
 (0)