Skip to content

Commit b7e50f0

Browse files
Merge pull request #547 from oscardssmith/os/precsisfresh
use isfresh like mechinisim for precs
2 parents d7f0fb3 + 8d33f02 commit b7e50f0

File tree

5 files changed

+58
-27
lines changed

5 files changed

+58
-27
lines changed

ext/LinearSolveHYPREExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,13 @@ function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm,
8686
assumptions)
8787
Tc = typeof(cacheval)
8888
isfresh = true
89+
precsisfresh = false
8990

9091
cache = LinearCache{
9192
typeof(A), typeof(b), typeof(u0), typeof(p), typeof(alg), Tc,
9293
typeof(Pl), typeof(Pr), typeof(reltol),
9394
typeof(__issquare(assumptions)), typeof(sensealg)
94-
}(A, b, u0, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
95+
}(A, b, u0, p, alg, cacheval, isfresh, precsisfresh, Pl, Pr, abstol, reltol,
9596
maxiters, verbose, assumptions, sensealg)
9697
return cache
9798
end

ext/LinearSolveIterativeSolversExt.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,12 @@ function LinearSolve.init_cacheval(alg::IterativeSolversJL, A, b, u, Pl, Pr, max
9090
end
9191

9292
function SciMLBase.solve!(cache::LinearCache, alg::IterativeSolversJL; kwargs...)
93+
if cache.precsisfresh && !isnothing(alg.precs)
94+
Pl, Pr = alg.precs(cache.Pl, cache.Pr)
95+
cache.Pl = Pl
96+
cache.Pr = Pr
97+
cache.precsisfresh = false
98+
end
9399
if cache.isfresh || !(alg isa IterativeSolvers.GMRESIterable)
94100
solver = LinearSolve.init_cacheval(alg, cache.A, cache.b, cache.u, cache.Pl,
95101
cache.Pr,

src/common.jl

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq, S}
7373
alg::Talg
7474
cacheval::Tc # store alg cache here
7575
isfresh::Bool # false => cacheval is set wrt A, true => update cacheval wrt A
76+
precsisfresh::Bool # false => PR,PL is set wrt A, true => update PR,PL wrt A
7677
Pl::Tl # preconditioners
7778
Pr::Tr
7879
abstol::Ttol
@@ -85,18 +86,10 @@ end
8586

8687
function Base.setproperty!(cache::LinearCache, name::Symbol, x)
8788
if name === :A
88-
if hasproperty(cache.alg, :precs) && !isnothing(cache.alg.precs)
89-
Pl, Pr = cache.alg.precs(x, cache.p)
90-
setfield!(cache, :Pl, Pl)
91-
setfield!(cache, :Pr, Pr)
92-
end
9389
setfield!(cache, :isfresh, true)
90+
setfield!(cache, :precsisfresh, true)
9491
elseif name === :p
95-
if hasproperty(cache.alg, :precs) && !isnothing(cache.alg.precs)
96-
Pl, Pr = cache.alg.precs(cache.A, x)
97-
setfield!(cache, :Pl, Pl)
98-
setfield!(cache, :Pr, Pr)
99-
end
92+
setfield!(cache, :precsisfresh, true)
10093
elseif name === :b
10194
# In case there is something that needs to be done when b is updated
10295
update_cacheval!(cache, :b, x)
@@ -208,11 +201,12 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
208201
cacheval = init_cacheval(alg, A, b, u0_, Pl, Pr, maxiters, abstol, reltol, verbose,
209202
assumptions)
210203
isfresh = true
204+
precsisfresh = false
211205
Tc = typeof(cacheval)
212206

213207
cache = LinearCache{typeof(A), typeof(b), typeof(u0_), typeof(p), typeof(alg), Tc,
214208
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq),
215-
typeof(sensealg)}(A, b, u0_, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
209+
typeof(sensealg)}(A, b, u0_, p, alg, cacheval, isfresh, precsisfresh, Pl, Pr, abstol, reltol,
216210
maxiters, verbose, assumptions, sensealg)
217211
return cache
218212
end
@@ -223,27 +217,26 @@ function SciMLBase.reinit!(cache::LinearCache;
223217
b = cache.b,
224218
u = cache.u,
225219
p = nothing,
226-
reinit_cache = false,)
220+
reinit_cache = false,
221+
reuse_precs = false)
227222
(; alg, cacheval, abstol, reltol, maxiters, verbose, assumptions, sensealg) = cache
228223

229-
precs = (hasproperty(alg, :precs) && !isnothing(alg.precs)) ? alg.precs : DEFAULT_PRECS
230-
Pl, Pr = if isnothing(A) || isnothing(p)
231-
if isnothing(A)
232-
A = cache.A
233-
end
234-
if isnothing(p)
235-
p = cache.p
236-
end
237-
precs(A, p)
238-
else
239-
(cache.Pl, cache.Pr)
240-
end
241-
isfresh = true
242224

225+
isfresh = !isnothing(A)
226+
precsisfresh = !reuse_precs && (isfresh || !isnothing(p))
227+
isfresh |= cache.isfresh
228+
precsisfresh |= cache.precsisfresh
229+
230+
A = isnothing(A) ? cache.A : A
231+
b = isnothing(b) ? cache.b : b
232+
u = isnothing(u) ? cache.u : u
233+
p = isnothing(p) ? cache.p : p
234+
Pl = cache.Pl
235+
Pr = cache.Pr
243236
if reinit_cache
244237
return LinearCache{typeof(A), typeof(b), typeof(u), typeof(p), typeof(alg), typeof(cacheval),
245238
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq),
246-
typeof(sensealg)}(A, b, u, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
239+
typeof(sensealg)}(A, b, u, p, alg, cacheval, precsisfresh, isfresh, Pl, Pr, abstol, reltol,
247240
maxiters, verbose, assumptions, sensealg)
248241
else
249242
cache.A = A
@@ -253,6 +246,7 @@ function SciMLBase.reinit!(cache::LinearCache;
253246
cache.Pl = Pl
254247
cache.Pr = Pr
255248
cache.isfresh = true
249+
cache.precsisfresh = precsisfresh
256250
end
257251
end
258252

src/iterative_wrappers.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,12 @@ function init_cacheval(alg::KrylovJL, A, b, u, Pl, Pr, maxiters::Int, abstol, re
225225
end
226226

227227
function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...)
228+
if cache.precsisfresh && !isnothing(alg.precs)
229+
Pl, Pr = alg.precs(cache.A, cache.p)
230+
cache.Pl = Pl
231+
cache.Pr = Pr
232+
cache.precsisfresh = false
233+
end
228234
if cache.isfresh
229235
solver = init_cacheval(alg, cache.A, cache.b, cache.u, cache.Pl, cache.Pr,
230236
cache.maxiters, cache.abstol, cache.reltol, cache.verbose,

test/basictests.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,30 @@ end
284284
end
285285
end
286286

287+
@testset "Reuse precs" begin
288+
num_precs_calls = 0
289+
290+
function countingprecs(A, p = nothing)
291+
num_precs_calls += 1
292+
(BlockJacobiPreconditioner(A, 2), I)
293+
end
294+
295+
n = 10
296+
A = spdiagm(-1 => -ones(n - 1), 0 => fill(10.0, n), 1 => -ones(n - 1))
297+
b = rand(n)
298+
p = LinearProblem(A, b)
299+
x0 = solve(p, KrylovJL_CG(precs = countingprecs, ldiv = false))
300+
cache = x0.cache
301+
x0 = copy(x0)
302+
for i in 4:(n - 3)
303+
A[i, i + 3] -= 1.0e-4
304+
A[i - 3, i] -= 1.0e-4
305+
end
306+
LinearSolve.reinit!(cache; A, reuse_precs = true)
307+
x1 = copy(solve!(cache))
308+
@test all(x0 .< x1) && num_precs_calls == 1
309+
end
310+
287311
if VERSION >= v"1.9-"
288312
@testset "IterativeSolversJL" begin
289313
kwargs = (; gmres_restart = 5)

0 commit comments

Comments
 (0)