Skip to content

Commit af4551c

Browse files
Merge pull request #88 from SciML/pardiso
Fix Pardiso precaching form
2 parents 8213cd7 + d865767 commit af4551c

File tree

3 files changed

+49
-14
lines changed

3 files changed

+49
-14
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,9 @@ julia = "1.6"
3838
[extras]
3939
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
4040
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
41+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
4142
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
4243
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4344

4445
[targets]
45-
test = ["Test", "Pardiso", "Pkg", "SafeTestsets"]
46+
test = ["Test", "Pardiso", "Pkg", "Random", "SafeTestsets"]

src/pardiso.jl

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,15 @@ Base.@kwdef struct PardisoJL <: SciMLLinearSolveAlgorithm
66
dparm::Union{Vector{Tuple{Int,Int}}, Nothing} = nothing
77
end
88

9-
MKLPardisoFactorize(;kwargs...) = PardisoJL(;kwargs...)
10-
MKLPardisoIterate(;kwargs...) = PardisoJL(;kwargs...)
9+
MKLPardisoFactorize(;kwargs...) = PardisoJL(;solver_type = 0,kwargs...)
10+
MKLPardisoIterate(;kwargs...) = PardisoJL(;solver_type = 1,kwargs...)
11+
needs_concrete_A(alg::PardisoJL) = true
1112

1213
# TODO schur complement functionality
1314

1415
function init_cacheval(alg::PardisoJL, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
1516
@unpack nprocs, solver_type, matrix_type, iparm, dparm = alg
16-
17-
if A isa DiffEqArrayOperator
18-
A = A.A
19-
end
17+
A = convert(AbstractMatrix,A)
2018

2119
solver =
2220
if Pardiso.PARDISO_LOADED[]
@@ -59,6 +57,9 @@ function init_cacheval(alg::PardisoJL, A, b, u, Pl, Pr, maxiters, abstol, reltol
5957
end
6058
end
6159

60+
# Make sure to say it's transposed because its CSC not CSR
61+
Pardiso.set_iparm!(solver,12, 1)
62+
6263
Pardiso.set_phase!(solver, Pardiso.ANALYSIS)
6364
Pardiso.pardiso(solver, u, A, b)
6465

@@ -67,23 +68,18 @@ end
6768

6869
function SciMLBase.solve(cache::LinearCache, alg::PardisoJL; kwargs...)
6970
@unpack A, b, u = cache
70-
if A isa DiffEqArrayOperator
71-
A = A.A
72-
end
71+
A = convert(AbstractMatrix,A)
7372

7473
if cache.isfresh
7574
Pardiso.set_phase!(cache.cacheval, Pardiso.NUM_FACT)
76-
Pardiso.pardiso(cache.cacheval, cache.u, cache.A, cache.b)
75+
Pardiso.pardiso(cache.cacheval, u, A, b)
7776
end
78-
7977
Pardiso.set_phase!(cache.cacheval, Pardiso.SOLVE_ITERATIVE_REFINE)
8078
Pardiso.pardiso(cache.cacheval, u, A, b)
8179

8280
return SciMLBase.build_linear_solution(alg,cache.u,nothing,cache)
8381
end
8482

85-
needsconcreteA(alg::PardisoJL) = true
86-
8783
# Add finalizer to release memory
8884
# Pardiso.set_phase!(cache.cacheval, Pardiso.RELEASE_ALL)
8985

test/basictests.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using LinearSolve, LinearAlgebra, SparseArrays
22
using Test
3+
import Random
34

45
n = 8
56
A = Matrix(I,n,n)
@@ -232,4 +233,41 @@ end
232233
end
233234
end
234235

236+
@testset "Sparse Precaching" begin
237+
n = 4
238+
Random.seed!(10)
239+
A = sprand(n,n,0.8); A2 = 2.0 .* A
240+
b1 = rand(n); b2 = rand(n)
241+
242+
prob = LinearProblem(copy(A), copy(b1))
243+
linsolve = init(prob,UMFPACKFactorization())
244+
sol11 = solve(linsolve)
245+
linsolve = LinearSolve.set_b(sol11.cache,copy(b2))
246+
sol12 = solve(linsolve)
247+
linsolve = LinearSolve.set_A(sol12.cache,copy(A2))
248+
sol13 = solve(linsolve)
249+
250+
prob = LinearProblem(copy(A), copy(b1))
251+
linsolve = init(prob,KLUFactorization())
252+
sol21 = solve(linsolve)
253+
linsolve = LinearSolve.set_b(sol21.cache,copy(b2))
254+
sol22 = solve(linsolve)
255+
linsolve = LinearSolve.set_A(sol22.cache,copy(A2))
256+
sol23 = solve(linsolve)
257+
258+
linsolve = init(prob,MKLPardisoFactorize())
259+
sol31 = solve(linsolve)
260+
linsolve = LinearSolve.set_b(sol31.cache,copy(b2))
261+
sol32 = solve(linsolve)
262+
linsolve = LinearSolve.set_A(sol32.cache,copy(A2))
263+
sol33 = solve(linsolve)
264+
265+
@test sol11.u sol21.u
266+
@test sol11.u sol31.u
267+
@test sol12.u sol22.u
268+
@test sol12.u sol32.u
269+
@test sol13.u sol23.u
270+
@test sol13.u sol33.u
271+
end
272+
235273
end # testset

0 commit comments

Comments
 (0)