Skip to content

Commit b49e42d

Browse files
committed
fix: dispatch forwarddiff on __init and __solve
1 parent 89d76b0 commit b49e42d

File tree

2 files changed

+31
-23
lines changed

2 files changed

+31
-23
lines changed

src/NonlinearSolve.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ include("descent/damped_newton.jl")
6767
include("descent/geodesic_acceleration.jl")
6868

6969
include("internal/jacobian.jl")
70-
include("internal/forward_diff.jl")
7170
include("internal/linear_solve.jl")
7271
include("internal/termination.jl")
7372
include("internal/tracing.jl")
@@ -82,6 +81,8 @@ include("core/generalized_first_order.jl")
8281
include("core/spectral_methods.jl")
8382
include("core/noinit.jl")
8483

84+
include("internal/forward_diff.jl") # we need to define after the algorithms
85+
8586
include("algorithms/raphson.jl")
8687
include("algorithms/pseudo_transient.jl")
8788
include("algorithms/broyden.jl")

src/internal/forward_diff.jl

+29-22
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
1-
# XXX: dispatch on `__solve` & `__init`
2-
function SciMLBase.solve(
3-
prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
4-
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
5-
alg::Union{Nothing, AbstractNonlinearAlgorithm},
6-
args...;
7-
kwargs...) where {T, V, P, iip}
8-
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
9-
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
10-
return SciMLBase.build_solution(
11-
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
1+
const DualNonlinearProblem = NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
2+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}} where {iip, T, V, P}
3+
const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
4+
<:Union{Number, <:AbstractArray}, iip,
5+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}} where {iip, T, V, P}
6+
const DualAbstractNonlinearProblem = Union{
7+
DualNonlinearProblem, DualNonlinearLeastSquaresProblem}
8+
9+
for algType in (Nothing, AbstractNonlinearSolveAlgorithm)
10+
@eval function SciMLBase.__solve(
11+
prob::DualNonlinearProblem, alg::$(algType), args...; kwargs...)
12+
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
13+
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
14+
return SciMLBase.build_solution(
15+
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
16+
end
1217
end
1318

1419
@concrete mutable struct NonlinearSolveForwardDiffCache
@@ -32,17 +37,19 @@ function reinit_cache!(cache::NonlinearSolveForwardDiffCache;
3237
return cache
3338
end
3439

35-
function SciMLBase.init(
36-
prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
37-
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
38-
alg::Union{Nothing, AbstractNonlinearAlgorithm},
39-
args...;
40-
kwargs...) where {T, V, P, iip}
41-
p = __value(prob.p)
42-
newprob = NonlinearProblem(prob.f, __value(prob.u0), p; prob.kwargs...)
43-
cache = init(newprob, alg, args...; kwargs...)
44-
return NonlinearSolveForwardDiffCache(
45-
cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p))
40+
for algType in (
41+
Nothing, AbstractNonlinearSolveAlgorithm, GeneralizedDFSane,
42+
SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm,
43+
GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm
44+
)
45+
@eval function SciMLBase.__init(
46+
prob::DualNonlinearProblem, alg::$(algType), args...; kwargs...)
47+
p = __value(prob.p)
48+
newprob = NonlinearProblem(prob.f, __value(prob.u0), p; prob.kwargs...)
49+
cache = init(newprob, alg, args...; kwargs...)
50+
return NonlinearSolveForwardDiffCache(
51+
cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p))
52+
end
4653
end
4754

4855
function SciMLBase.solve!(cache::NonlinearSolveForwardDiffCache)

0 commit comments

Comments
 (0)