Skip to content

Commit b59ce2b

Browse files
add test and run
1 parent 808f789 commit b59ce2b

File tree

2 files changed

+44
-8
lines changed

2 files changed

+44
-8
lines changed

src/pardiso.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,13 @@ end
6868

6969
function SciMLBase.solve(cache::LinearCache, alg::PardisoJL; kwargs...)
7070
@unpack A, b, u = cache
71-
A = copy(convert(AbstractMatrix,A))
72-
73-
#if cache.isfresh
74-
# Pardiso.set_phase!(cache.cacheval, Pardiso.NUM_FACT)
75-
# Pardiso.pardiso(cache.cacheval, u, A, b)
76-
#end
77-
#Pardiso.set_phase!(cache.cacheval, Pardiso.SOLVE_ITERATIVE_REFINE)
71+
A = convert(AbstractMatrix,A)
7872

79-
Pardiso.set_phase!(cache.cacheval, Pardiso.ANALYSIS_NUM_FACT_SOLVE_REFINE)
73+
if cache.isfresh
74+
Pardiso.set_phase!(cache.cacheval, Pardiso.NUM_FACT)
75+
Pardiso.pardiso(cache.cacheval, u, A, b)
76+
end
77+
Pardiso.set_phase!(cache.cacheval, Pardiso.SOLVE_ITERATIVE_REFINE)
8078
Pardiso.pardiso(cache.cacheval, u, A, b)
8179

8280
return SciMLBase.build_linear_solution(alg,cache.u,nothing,cache)

test/basictests.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using LinearSolve, LinearAlgebra, SparseArrays
22
using Test
3+
import Random
34

45
n = 8
56
A = Matrix(I,n,n)
@@ -232,4 +233,41 @@ end
232233
end
233234
end
234235

236+
@testset "Sparse Precaching" begin
237+
n = 4
238+
Random.seed!(10)
239+
A = sprand(n,n,0.8); A2 = 2.0 .* A
240+
b1 = rand(n); b2 = rand(n)
241+
242+
prob = LinearProblem(copy(A), copy(b1))
243+
linsolve = init(prob,UMFPACKFactorization())
244+
sol11 = solve(linsolve)
245+
linsolve = LinearSolve.set_b(sol1.cache,copy(b2))
246+
sol12 = solve(linsolve)
247+
linsolve = LinearSolve.set_A(sol2.cache,copy(A2))
248+
sol13 = solve(linsolve)
249+
250+
prob = LinearProblem(copy(A), copy(b1))
251+
linsolve = init(prob,KLUFactorization())
252+
sol21 = solve(linsolve)
253+
linsolve = LinearSolve.set_b(sol1.cache,copy(b2))
254+
sol22 = solve(linsolve)
255+
linsolve = LinearSolve.set_A(sol2.cache,copy(A2))
256+
sol23 = solve(linsolve)
257+
258+
linsolve = init(prob,MKLPardisoFactorize())
259+
sol31 = solve(linsolve)
260+
linsolve = LinearSolve.set_b(sol1.cache,copy(b2))
261+
sol32 = solve(linsolve)
262+
linsolve = LinearSolve.set_A(sol2.cache,copy(A2))
263+
sol33 = solve(linsolve)
264+
265+
@test sol11.u sol21.u
266+
@test sol11.u sol31.u
267+
@test sol12.u sol22.u
268+
@test sol12.u sol32.u
269+
@test sol13.u sol23.u
270+
@test sol13.u sol33.u
271+
end
272+
235273
end # testset

0 commit comments

Comments
 (0)