Skip to content

Commit f5a06cb

Browse files
committed
fix: dispatch forwarddiff on __init and __solve
1 parent 89d76b0 commit f5a06cb

File tree

2 files changed

+38
-23
lines changed

2 files changed

+38
-23
lines changed

src/NonlinearSolve.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ include("descent/damped_newton.jl")
6767
include("descent/geodesic_acceleration.jl")
6868

6969
include("internal/jacobian.jl")
70-
include("internal/forward_diff.jl")
7170
include("internal/linear_solve.jl")
7271
include("internal/termination.jl")
7372
include("internal/tracing.jl")
@@ -93,6 +92,8 @@ include("algorithms/levenberg_marquardt.jl")
9392
include("algorithms/trust_region.jl")
9493
include("algorithms/extension_algs.jl")
9594

95+
include("internal/forward_diff.jl") # we need to define after the algorithms
96+
9697
include("utils.jl")
9798
include("default.jl")
9899

src/internal/forward_diff.jl

+36-22
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,24 @@
1-
# XXX: dispatch on `__solve` & `__init`
2-
function SciMLBase.solve(
3-
prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
4-
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
5-
alg::Union{Nothing, AbstractNonlinearAlgorithm},
6-
args...;
7-
kwargs...) where {T, V, P, iip}
8-
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
9-
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
10-
return SciMLBase.build_solution(
11-
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
1+
const DualNonlinearProblem = NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
2+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}} where {iip, T, V, P}
3+
const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
4+
<:Union{Number, <:AbstractArray}, iip,
5+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}} where {iip, T, V, P}
6+
const DualAbstractNonlinearProblem = Union{
7+
DualNonlinearProblem, DualNonlinearLeastSquaresProblem}
8+
9+
for algType in (
10+
Nothing, AbstractNonlinearSolveAlgorithm, GeneralizedDFSane,
11+
GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm,
12+
LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, NLSolversJL,
13+
SpeedMappingJL, FixedPointAccelerationJL, SIAMFANLEquationsJL
14+
)
15+
@eval function SciMLBase.__solve(
16+
prob::DualNonlinearProblem, alg::$(algType), args...; kwargs...)
17+
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
18+
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
19+
return SciMLBase.build_solution(
20+
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
21+
end
1222
end
1323

1424
@concrete mutable struct NonlinearSolveForwardDiffCache
@@ -32,17 +42,21 @@ function reinit_cache!(cache::NonlinearSolveForwardDiffCache;
3242
return cache
3343
end
3444

35-
function SciMLBase.init(
36-
prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
37-
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
38-
alg::Union{Nothing, AbstractNonlinearAlgorithm},
39-
args...;
40-
kwargs...) where {T, V, P, iip}
41-
p = __value(prob.p)
42-
newprob = NonlinearProblem(prob.f, __value(prob.u0), p; prob.kwargs...)
43-
cache = init(newprob, alg, args...; kwargs...)
44-
return NonlinearSolveForwardDiffCache(
45-
cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p))
45+
for algType in (
46+
Nothing, AbstractNonlinearSolveAlgorithm, GeneralizedDFSane,
47+
SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm,
48+
GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm,
49+
LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, NLSolversJL,
50+
SpeedMappingJL, FixedPointAccelerationJL, SIAMFANLEquationsJL
51+
)
52+
@eval function SciMLBase.__init(
53+
prob::DualNonlinearProblem, alg::$(algType), args...; kwargs...)
54+
p = __value(prob.p)
55+
newprob = NonlinearProblem(prob.f, __value(prob.u0), p; prob.kwargs...)
56+
cache = init(newprob, alg, args...; kwargs...)
57+
return NonlinearSolveForwardDiffCache(
58+
cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p))
59+
end
4660
end
4761

4862
function SciMLBase.solve!(cache::NonlinearSolveForwardDiffCache)

0 commit comments

Comments
 (0)