Skip to content

Commit 728c525

Browse files
Merge pull request #105 from vpuri3/vp-krylovkit
KrylovKit wrapper
2 parents 5b9b492 + cc2d5da commit 728c525

File tree

3 files changed

+55
-1
lines changed

3 files changed

+55
-1
lines changed

src/LinearSolve.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
5050
UMFPACKFactorization, KLUFactorization
5151
export KrylovJL, KrylovJL_CG, KrylovJL_GMRES, KrylovJL_BICGSTAB, KrylovJL_MINRES,
5252
IterativeSolversJL, IterativeSolversJL_CG, IterativeSolversJL_GMRES,
53-
IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES
53+
IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES,
54+
KrylovKitJL, KrylovKitJL_CG, KrylovKitJL_GMRES
5455

5556
end

src/iterative_wrappers.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,3 +231,43 @@ function purge_history!(iter::IterativeSolvers.GMRESIterable, x, b)
231231
iter.β = iter.residual.current
232232
nothing
233233
end
234+
235+
## KrylovKit.jl
236+
237+
struct KrylovKitJL{F,A,I,K} <: AbstractKrylovSubspaceMethod
238+
KrylovAlg::F
239+
gmres_restart::I
240+
args::A
241+
kwargs::K
242+
end
243+
244+
function KrylovKitJL(args...;
245+
KrylovAlg = KrylovKit.GMRES, gmres_restart = 0,
246+
kwargs...)
247+
return KrylovJL(KrylovAlg, gmres_restart, args, kwargs)
248+
end
249+
250+
KrylovKitJL_CG(args...;kwargs...) =
251+
KrylovKitJL(args...; KrylovAlg=KrylovKit.CG, kwargs...)
252+
KrylovKitJL_GMRES(args...;kwargs...) =
253+
KrylovKitJL(args...; KrylovAlg=KrylovKit.GMRES, kwargs...)
254+
255+
function SciMLBase.solve(cache::LinearCache, alg::KrylovKitJL, kwargs...)
256+
257+
atol = float(cache.abstol)
258+
rtol = float(cache.reltol)
259+
maxiter = cache.maxiters
260+
verbosity = cache.verbose ? 1 : 0
261+
krylovdim = (alg.gmres_restart == 0) ? min(20, size(A,1)) : alg.gmres_restart
262+
263+
kwargs = (atol=atol, rtol=rtol, maxiter=maxiter, verbosity=verbosity,
264+
krylovdim = krylovdim, alg.kwargs...)
265+
266+
x, info = KrylovKit.linsolve(cache.A, cache.b, cache.u, alg.KrylovAlg)
267+
268+
copy!(cache.u, x)
269+
resid = info.normres
270+
retcode = info.converged == 1 ? :Default : :DidNotConverge
271+
iters = info.numiter
272+
return SciMLBase.build_linear_solution(alg, cache.u, resid, cache; retcode = retcode, iters = iters)
273+
end

test/basictests.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,19 @@ end
161161
end
162162
end
163163

164+
@testset "KrylovKit" begin
165+
kwargs = (;gmres_restart=5,)
166+
for alg in (
167+
("Default", KrylovKitJL(kwargs...)),
168+
("CG", KrylovKitJL_CG(kwargs...)),
169+
("GMRES",KrylovKitJL_GMRES(kwargs...)),
170+
)
171+
@testset "$(alg[1])" begin
172+
test_interface(alg[2], prob1, prob2)
173+
end
174+
end
175+
end
176+
164177
@testset "PardisoJL" begin
165178
@test_throws UndefVarError alg = PardisoJL()
166179

0 commit comments

Comments
 (0)