1- # XXX : dispatch on `__solve` & `__init`
2- function SciMLBase. solve (
3- prob:: NonlinearProblem {<: Union{Number, <:AbstractArray} , iip,
4- <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} },
5- alg:: Union{Nothing, AbstractNonlinearAlgorithm} ,
6- args... ;
7- kwargs... ) where {T, V, P, iip}
8- sol, partials = nonlinearsolve_forwarddiff_solve (prob, alg, args... ; kwargs... )
9- dual_soln = nonlinearsolve_dual_solution (sol. u, partials, prob. p)
10- return SciMLBase. build_solution (
11- prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
1+ const DualNonlinearProblem = NonlinearProblem{<: Union{Number, <:AbstractArray} , iip,
2+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} } where {iip, T, V, P}
3+ const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
4+ <: Union{Number, <:AbstractArray} , iip,
5+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} } where {iip, T, V, P}
6+ const DualAbstractNonlinearProblem = Union{
7+ DualNonlinearProblem, DualNonlinearLeastSquaresProblem}
8+
9+ for algType in (Nothing, AbstractNonlinearSolveAlgorithm)
10+ @eval function SciMLBase. __solve (
11+ prob:: DualNonlinearProblem , alg:: $ (algType), args... ; kwargs... )
12+ sol, partials = nonlinearsolve_forwarddiff_solve (prob, alg, args... ; kwargs... )
13+ dual_soln = nonlinearsolve_dual_solution (sol. u, partials, prob. p)
14+ return SciMLBase. build_solution (
15+ prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
16+ end
1217end
1318
1419@concrete mutable struct NonlinearSolveForwardDiffCache
@@ -32,17 +37,19 @@ function reinit_cache!(cache::NonlinearSolveForwardDiffCache;
3237 return cache
3338end
3439
35- function SciMLBase. init (
36- prob:: NonlinearProblem {<: Union{Number, <:AbstractArray} , iip,
37- <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} },
38- alg:: Union{Nothing, AbstractNonlinearAlgorithm} ,
39- args... ;
40- kwargs... ) where {T, V, P, iip}
41- p = __value (prob. p)
42- newprob = NonlinearProblem (prob. f, __value (prob. u0), p; prob. kwargs... )
43- cache = init (newprob, alg, args... ; kwargs... )
44- return NonlinearSolveForwardDiffCache (
45- cache, newprob, alg, prob. p, p, ForwardDiff. partials (prob. p))
40+ for algType in (
41+ Nothing, AbstractNonlinearSolveAlgorithm, GeneralizedDFSane,
42+ SimpleNonlinearSolve. AbstractSimpleNonlinearSolveAlgorithm,
43+ GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm
44+ )
45+ @eval function SciMLBase. __init (
46+ prob:: DualNonlinearProblem , alg:: $ (algType), args... ; kwargs... )
47+ p = __value (prob. p)
48+ newprob = NonlinearProblem (prob. f, __value (prob. u0), p; prob. kwargs... )
49+ cache = init (newprob, alg, args... ; kwargs... )
50+ return NonlinearSolveForwardDiffCache (
51+ cache, newprob, alg, prob. p, p, ForwardDiff. partials (prob. p))
52+ end
4653end
4754
4855function SciMLBase. solve! (cache:: NonlinearSolveForwardDiffCache )
0 commit comments