Skip to content

Commit 36ef4dd

Browse files
committed
refactor: SimpleImplicitDiscreteSolve
1 parent 5009dc3 commit 36ef4dd

File tree

6 files changed

+85
-134
lines changed

6 files changed

+85
-134
lines changed

lib/SimpleImplicitDiscreteSolve/Project.toml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,17 @@ version = "0.1.0"
55

66
[deps]
77
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
8-
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
98
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
109
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1110
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
12-
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
13-
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
1411

1512
[compat]
1613
DiffEqBase = "6.164.1"
17-
OrdinaryDiffEqCore = "1.18.1"
1814
OrdinaryDiffEqSDIRK = "1.2.0"
1915
Reexport = "1.2.2"
2016
SciMLBase = "2.74.1"
2117
SimpleNonlinearSolve = "2.1.0"
22-
SymbolicIndexingInterface = "0.3.38"
2318
Test = "1.11.0"
24-
UnPack = "1.0.2"
2519

2620
[extras]
2721
OrdinaryDiffEqSDIRK = "2d112036-d095-4a1e-ab9a-08536f3ecdbf"

lib/SimpleImplicitDiscreteSolve/src/SimpleImplicitDiscreteSolve.jl

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,6 @@ module SimpleImplicitDiscreteSolve
22

33
using SciMLBase
44
using SimpleNonlinearSolve
5-
using UnPack
6-
using SymbolicIndexingInterface: parameter_symbols
7-
import OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, alg_cache, OrdinaryDiffEqMutableCache, OrdinaryDiffEqConstantCache, get_fsalfirstlast, isfsal, initialize!, perform_step!, isdiscretecache, isdiscretealg, alg_order, beta2_default, beta1_default, dt_required, _initialize_dae!, DefaultInit, BrownFullBasicInit, OverrideInit
8-
95
using Reexport
106
@reexport using DiffEqBase
117

@@ -14,11 +10,85 @@ using Reexport
1410
1511
Simple solver for `ImplicitDiscreteSystems`. Uses `SimpleNewtonRaphson` to solve for the next state at every timestep.
1612
"""
17-
struct SimpleIDSolve <: OrdinaryDiffEqAlgorithm end
13+
struct SimpleIDSolve <: SciMLBase.AbstractODEAlgorithm end
14+
15+
function DiffEqBase.__init(prob::ImplicitDiscreteProblem, alg::SimpleIDSolve; dt = 1)
16+
u0 = prob.u0
17+
p = prob.p
18+
f = prob.f
19+
t = prob.tspan[1]
20+
21+
nlf = isinplace(f) ? (out, u, p) -> f(out, u, u0, p, t) : (u, p) -> f(u, u0, p, t)
22+
prob = NonlinearProblem{isinplace(f)}(nlf, u0, p)
23+
sol = solve(prob, SimpleNewtonRaphson())
24+
sol, (sol.retcode != ReturnCode.Success)
25+
end
26+
27+
function DiffEqBase.solve(prob::ImplicitDiscreteProblem, alg::SimpleIDSolve;
28+
dt = 1,
29+
save_everystep = true,
30+
save_start = true,
31+
adaptive = false,
32+
dense = false,
33+
save_end = true,
34+
kwargs...)
35+
36+
@assert !adaptive
37+
@assert !dense
38+
(initsol, initfail) = DiffEqBase.__init(prob, alg; dt)
39+
if initfail
40+
sol = DiffEqBase.build_solution(prob, alg, prob.tspan[1], u0, k = nothing, stats = nothing, calculate_error = false)
41+
return SciMLBase.solution_new_retcode(sol, ReturnCode.InitialFailure)
42+
end
1843

19-
include("cache.jl")
20-
include("solve.jl")
21-
include("alg_utils.jl")
44+
u0 = initsol.u
45+
tspan = prob.tspan
46+
f = prob.f
47+
p = prob.p
48+
t = tspan[1]
49+
tf = prob.tspan[2]
50+
ts = tspan[1]:dt:tspan[2]
51+
52+
if save_everystep && save_start
53+
us = Vector{typeof(u0)}(undef, length(ts))
54+
us[1] = u0
55+
elseif save_everystep
56+
us = Vector{typeof(u0)}(undef, length(ts) - 1)
57+
elseif save_start
58+
us = Vector{typeof(u0)}(undef, 2)
59+
us[1] = u0
60+
else
61+
us = Vector{typeof(u0)}(undef, 1) # for interface compatibility
62+
end
63+
64+
u = u0
65+
convfail = false
66+
for i in 2:length(ts)
67+
uprev = u
68+
t = ts[i]
69+
nlf = isinplace(f) ? (out, u, p) -> f(out, u, uprev, p, t) : (u, p) -> f(u, uprev, p, t)
70+
nlprob = NonlinearProblem{isinplace(f)}(nlf, uprev, p)
71+
nlsol = solve(nlprob, SimpleNewtonRaphson())
72+
u = nlsol.u
73+
save_everystep && (us[i] = u)
74+
convfail = (nlsol.retcode != ReturnCode.Success)
75+
76+
if convfail
77+
sol = DiffEqBase.build_solution(prob, alg, ts[1:i], us[1:i], k = nothing, stats = nothing, calculate_error = false)
78+
sol = SciMLBase.solution_new_retcode(sol, ReturnCode.ConvergenceFailure)
79+
return sol
80+
end
81+
end
82+
83+
!save_everystep && save_end && (us[end] = u)
84+
sol = DiffEqBase.build_solution(prob, alg, ts, us,
85+
k = nothing, stats = nothing,
86+
calculate_error = false)
87+
88+
DiffEqBase.has_analytic(prob.f) &&
89+
DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true, dense_errors = false)
90+
sol
91+
end
2292

2393
export SimpleIDSolve
2494

lib/SimpleImplicitDiscreteSolve/src/alg_utils.jl

Lines changed: 0 additions & 19 deletions
This file was deleted.

lib/SimpleImplicitDiscreteSolve/src/cache.jl

Lines changed: 0 additions & 38 deletions
This file was deleted.

lib/SimpleImplicitDiscreteSolve/src/solve.jl

Lines changed: 0 additions & 54 deletions
This file was deleted.

lib/SimpleImplicitDiscreteSolve/test/runtests.jl

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#runtests
22
using Test
33
using SimpleImplicitDiscreteSolve
4-
using OrdinaryDiffEqCore
54
using OrdinaryDiffEqSDIRK
65

76
# Test implicit Euler using ImplicitDiscreteProblem
@@ -19,13 +18,13 @@ using OrdinaryDiffEqSDIRK
1918
u0 = [1., 1.]
2019
tspan = (0., 0.5)
2120

22-
idprob = ImplicitDiscreteProblem(f!, u0, tspan, []; dt = 0.01)
23-
idsol = solve(idprob, SimpleIDSolve())
21+
idprob = ImplicitDiscreteProblem(f!, u0, tspan, [])
22+
idsol = solve(idprob, SimpleIDSolve(), dt = 0.01)
2423

2524
oprob = ODEProblem(lotkavolterra, u0, tspan)
2625
osol = solve(oprob, ImplicitEuler())
2726

28-
@test isapprox(idsol[end], osol[end], atol = 0.1)
27+
@test isapprox(idsol[end-1], osol[end], atol = 0.1)
2928

3029
### free-fall
3130
# y, dy
@@ -40,15 +39,15 @@ using OrdinaryDiffEqSDIRK
4039
nothing
4140
end
4241
u0 = [10., 0.]
43-
tspan = (0, 0.2)
42+
tspan = (0, 0.5)
4443

45-
idprob = ImplicitDiscreteProblem(g!, u0, tspan, []; dt = 0.01)
46-
idsol = solve(idprob, SimpleIDSolve())
44+
idprob = ImplicitDiscreteProblem(g!, u0, tspan, [])
45+
idsol = solve(idprob, SimpleIDSolve(); dt = 0.01)
4746

4847
oprob = ODEProblem(ff, u0, tspan)
4948
osol = solve(oprob, ImplicitEuler())
5049

51-
@test isapprox(idsol[end], osol[end], atol = 0.1)
50+
@test isapprox(idsol[end-1], osol[end], atol = 0.1)
5251
end
5352

5453
@testset "Solver initializes" begin
@@ -65,7 +64,6 @@ end
6564

6665
for ts in 1:tsteps
6766
step!(integ)
68-
@show integ.u
6967
@test integ.u[1]^2 + integ.u[2]^2 16
7068
end
7169
end

0 commit comments

Comments
 (0)