@@ -2,10 +2,6 @@ module SimpleImplicitDiscreteSolve
2
2
3
3
using SciMLBase
4
4
using SimpleNonlinearSolve
5
- using UnPack
6
- using SymbolicIndexingInterface: parameter_symbols
7
- import OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, alg_cache, OrdinaryDiffEqMutableCache, OrdinaryDiffEqConstantCache, get_fsalfirstlast, isfsal, initialize!, perform_step!, isdiscretecache, isdiscretealg, alg_order, beta2_default, beta1_default, dt_required, _initialize_dae!, DefaultInit, BrownFullBasicInit, OverrideInit
8
-
9
5
using Reexport
10
6
@reexport using DiffEqBase
11
7
@@ -14,11 +10,85 @@ using Reexport
14
10
15
11
Simple solver for `ImplicitDiscreteSystems`. Uses `SimpleNewtonRaphson` to solve for the next state at every timestep.
16
12
"""
17
- struct SimpleIDSolve <: OrdinaryDiffEqAlgorithm end
13
+ struct SimpleIDSolve <: SciMLBase.AbstractODEAlgorithm end
14
+
15
+ function DiffEqBase. __init (prob:: ImplicitDiscreteProblem , alg:: SimpleIDSolve ; dt = 1 )
16
+ u0 = prob. u0
17
+ p = prob. p
18
+ f = prob. f
19
+ t = prob. tspan[1 ]
20
+
21
+ nlf = isinplace (f) ? (out, u, p) -> f (out, u, u0, p, t) : (u, p) -> f (u, u0, p, t)
22
+ prob = NonlinearProblem {isinplace(f)} (nlf, u0, p)
23
+ sol = solve (prob, SimpleNewtonRaphson ())
24
+ sol, (sol. retcode != ReturnCode. Success)
25
+ end
26
+
27
+ function DiffEqBase. solve (prob:: ImplicitDiscreteProblem , alg:: SimpleIDSolve ;
28
+ dt = 1 ,
29
+ save_everystep = true ,
30
+ save_start = true ,
31
+ adaptive = false ,
32
+ dense = false ,
33
+ save_end = true ,
34
+ kwargs... )
35
+
36
+ @assert ! adaptive
37
+ @assert ! dense
38
+ (initsol, initfail) = DiffEqBase. __init (prob, alg; dt)
39
+ if initfail
40
+ sol = DiffEqBase. build_solution (prob, alg, prob. tspan[1 ], u0, k = nothing , stats = nothing , calculate_error = false )
41
+ return SciMLBase. solution_new_retcode (sol, ReturnCode. InitialFailure)
42
+ end
18
43
19
- include (" cache.jl" )
20
- include (" solve.jl" )
21
- include (" alg_utils.jl" )
44
+ u0 = initsol. u
45
+ tspan = prob. tspan
46
+ f = prob. f
47
+ p = prob. p
48
+ t = tspan[1 ]
49
+ tf = prob. tspan[2 ]
50
+ ts = tspan[1 ]: dt: tspan[2 ]
51
+
52
+ if save_everystep && save_start
53
+ us = Vector {typeof(u0)} (undef, length (ts))
54
+ us[1 ] = u0
55
+ elseif save_everystep
56
+ us = Vector {typeof(u0)} (undef, length (ts) - 1 )
57
+ elseif save_start
58
+ us = Vector {typeof(u0)} (undef, 2 )
59
+ us[1 ] = u0
60
+ else
61
+ us = Vector {typeof(u0)} (undef, 1 ) # for interface compatibility
62
+ end
63
+
64
+ u = u0
65
+ convfail = false
66
+ for i in 2 : length (ts)
67
+ uprev = u
68
+ t = ts[i]
69
+ nlf = isinplace (f) ? (out, u, p) -> f (out, u, uprev, p, t) : (u, p) -> f (u, uprev, p, t)
70
+ nlprob = NonlinearProblem {isinplace(f)} (nlf, uprev, p)
71
+ nlsol = solve (nlprob, SimpleNewtonRaphson ())
72
+ u = nlsol. u
73
+ save_everystep && (us[i] = u)
74
+ convfail = (nlsol. retcode != ReturnCode. Success)
75
+
76
+ if convfail
77
+ sol = DiffEqBase. build_solution (prob, alg, ts[1 : i], us[1 : i], k = nothing , stats = nothing , calculate_error = false )
78
+ sol = SciMLBase. solution_new_retcode (sol, ReturnCode. ConvergenceFailure)
79
+ return sol
80
+ end
81
+ end
82
+
83
+ ! save_everystep && save_end && (us[end ] = u)
84
+ sol = DiffEqBase. build_solution (prob, alg, ts, us,
85
+ k = nothing , stats = nothing ,
86
+ calculate_error = false )
87
+
88
+ DiffEqBase. has_analytic (prob. f) &&
89
+ DiffEqBase. calculate_solution_errors! (sol; timeseries_errors = true , dense_errors = false )
90
+ sol
91
+ end
22
92
23
93
export SimpleIDSolve
24
94
0 commit comments