1
- # XXX : dispatch on `__solve` & `__init`
2
- function SciMLBase. solve (
3
- prob:: NonlinearProblem {<: Union{Number, <:AbstractArray} , iip,
4
- <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} },
5
- alg:: Union{Nothing, AbstractNonlinearAlgorithm} ,
6
- args... ;
7
- kwargs... ) where {T, V, P, iip}
8
- sol, partials = nonlinearsolve_forwarddiff_solve (prob, alg, args... ; kwargs... )
9
- dual_soln = nonlinearsolve_dual_solution (sol. u, partials, prob. p)
10
- return SciMLBase. build_solution (
11
- prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
1
+ const DualNonlinearProblem = NonlinearProblem{<: Union{Number, <:AbstractArray} , iip,
2
+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} } where {iip, T, V, P}
3
+ const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
4
+ <: Union{Number, <:AbstractArray} , iip,
5
+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} } where {iip, T, V, P}
6
+ const DualAbstractNonlinearProblem = Union{
7
+ DualNonlinearProblem, DualNonlinearLeastSquaresProblem}
8
+
9
+ for algType in (
10
+ Nothing, AbstractNonlinearSolveAlgorithm, GeneralizedDFSane,
11
+ GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm,
12
+ LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, NLSolversJL,
13
+ SpeedMappingJL, FixedPointAccelerationJL, SIAMFANLEquationsJL
14
+ )
15
+ @eval function SciMLBase. __solve (
16
+ prob:: DualNonlinearProblem , alg:: $ (algType), args... ; kwargs... )
17
+ sol, partials = nonlinearsolve_forwarddiff_solve (prob, alg, args... ; kwargs... )
18
+ dual_soln = nonlinearsolve_dual_solution (sol. u, partials, prob. p)
19
+ return SciMLBase. build_solution (
20
+ prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
21
+ end
12
22
end
13
23
14
24
@concrete mutable struct NonlinearSolveForwardDiffCache
@@ -32,17 +42,21 @@ function reinit_cache!(cache::NonlinearSolveForwardDiffCache;
32
42
return cache
33
43
end
34
44
35
- function SciMLBase. init (
36
- prob:: NonlinearProblem {<: Union{Number, <:AbstractArray} , iip,
37
- <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} },
38
- alg:: Union{Nothing, AbstractNonlinearAlgorithm} ,
39
- args... ;
40
- kwargs... ) where {T, V, P, iip}
41
- p = __value (prob. p)
42
- newprob = NonlinearProblem (prob. f, __value (prob. u0), p; prob. kwargs... )
43
- cache = init (newprob, alg, args... ; kwargs... )
44
- return NonlinearSolveForwardDiffCache (
45
- cache, newprob, alg, prob. p, p, ForwardDiff. partials (prob. p))
45
+ for algType in (
46
+ Nothing, AbstractNonlinearSolveAlgorithm, GeneralizedDFSane,
47
+ SimpleNonlinearSolve. AbstractSimpleNonlinearSolveAlgorithm,
48
+ GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm,
49
+ LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, NLSolversJL,
50
+ SpeedMappingJL, FixedPointAccelerationJL, SIAMFANLEquationsJL
51
+ )
52
+ @eval function SciMLBase. __init (
53
+ prob:: DualNonlinearProblem , alg:: $ (algType), args... ; kwargs... )
54
+ p = __value (prob. p)
55
+ newprob = NonlinearProblem (prob. f, __value (prob. u0), p; prob. kwargs... )
56
+ cache = init (newprob, alg, args... ; kwargs... )
57
+ return NonlinearSolveForwardDiffCache (
58
+ cache, newprob, alg, prob. p, p, ForwardDiff. partials (prob. p))
59
+ end
46
60
end
47
61
48
62
function SciMLBase. solve! (cache:: NonlinearSolveForwardDiffCache )
0 commit comments