1
- # XXX : dispatch on `__solve` & `__init`
2
- function SciMLBase. solve (
3
- prob:: NonlinearProblem {<: Union{Number, <:AbstractArray} , iip,
4
- <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} },
5
- alg:: Union{Nothing, AbstractNonlinearAlgorithm} ,
6
- args... ;
7
- kwargs... ) where {T, V, P, iip}
8
- sol, partials = nonlinearsolve_forwarddiff_solve (prob, alg, args... ; kwargs... )
9
- dual_soln = nonlinearsolve_dual_solution (sol. u, partials, prob. p)
10
- return SciMLBase. build_solution (
11
- prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
1
+ const DualNonlinearProblem = NonlinearProblem{<: Union{Number, <:AbstractArray} , iip,
2
+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} } where {iip, T, V, P}
3
+ const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
4
+ <: Union{Number, <:AbstractArray} , iip,
5
+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} } where {iip, T, V, P}
6
+ const DualAbstractNonlinearProblem = Union{
7
+ DualNonlinearProblem, DualNonlinearLeastSquaresProblem}
8
+
9
+ for algType in (Nothing, AbstractNonlinearSolveAlgorithm)
10
+ @eval function SciMLBase. __solve (
11
+ prob:: DualNonlinearProblem , alg:: $ (algType), args... ; kwargs... )
12
+ sol, partials = nonlinearsolve_forwarddiff_solve (prob, alg, args... ; kwargs... )
13
+ dual_soln = nonlinearsolve_dual_solution (sol. u, partials, prob. p)
14
+ return SciMLBase. build_solution (
15
+ prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
16
+ end
12
17
end
13
18
14
19
@concrete mutable struct NonlinearSolveForwardDiffCache
@@ -32,17 +37,19 @@ function reinit_cache!(cache::NonlinearSolveForwardDiffCache;
32
37
return cache
33
38
end
34
39
35
- function SciMLBase. init (
36
- prob:: NonlinearProblem {<: Union{Number, <:AbstractArray} , iip,
37
- <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} },
38
- alg:: Union{Nothing, AbstractNonlinearAlgorithm} ,
39
- args... ;
40
- kwargs... ) where {T, V, P, iip}
41
- p = __value (prob. p)
42
- newprob = NonlinearProblem (prob. f, __value (prob. u0), p; prob. kwargs... )
43
- cache = init (newprob, alg, args... ; kwargs... )
44
- return NonlinearSolveForwardDiffCache (
45
- cache, newprob, alg, prob. p, p, ForwardDiff. partials (prob. p))
40
+ for algType in (
41
+ Nothing, AbstractNonlinearSolveAlgorithm, GeneralizedDFSane,
42
+ SimpleNonlinearSolve. AbstractSimpleNonlinearSolveAlgorithm,
43
+ GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm
44
+ )
45
+ @eval function SciMLBase. __init (
46
+ prob:: DualNonlinearProblem , alg:: $ (algType), args... ; kwargs... )
47
+ p = __value (prob. p)
48
+ newprob = NonlinearProblem (prob. f, __value (prob. u0), p; prob. kwargs... )
49
+ cache = init (newprob, alg, args... ; kwargs... )
50
+ return NonlinearSolveForwardDiffCache (
51
+ cache, newprob, alg, prob. p, p, ForwardDiff. partials (prob. p))
52
+ end
46
53
end
47
54
48
55
function SciMLBase. solve! (cache:: NonlinearSolveForwardDiffCache )
0 commit comments