Skip to content

Commit 0b70d62

Browse files
committed
feat: ForwardDiff support in NonlinearSolveBase
1 parent d669bb8 commit 0b70d62

File tree

8 files changed

+136
-6
lines changed

8 files changed

+136
-6
lines changed

lib/BracketingNonlinearSolve/Project.toml

+9-1
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,26 @@ NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
1010
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1111
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1212

13+
[weakdeps]
14+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
15+
16+
[extensions]
17+
NonlinearSolveBaseForwardDiffExt = "ForwardDiff"
18+
1319
[compat]
1420
CommonSolve = "0.2.4"
1521
ConcreteStructs = "0.2.3"
22+
ForwardDiff = "0.10.36"
1623
NonlinearSolveBase = "1"
1724
PrecompileTools = "1.2.1"
1825
SciMLBase = "2.50"
1926
julia = "1.10"
2027

2128
[extras]
2229
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
30+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2331
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2432
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
2533

2634
[targets]
27-
test = ["InteractiveUtils", "Test", "TestItemRunner"]
35+
test = ["InteractiveUtils", "ForwardDiff", "Test", "TestItemRunner"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
module BracketingNonlinearSolveForwardDiffExt
2+
3+
using CommonSolve: CommonSolve
4+
using ForwardDiff: ForwardDiff, Dual
5+
using NonlinearSolveBase: nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution
6+
using SciMLBase: SciMLBase, IntervalNonlinearProblem
7+
8+
using BracketingNonlinearSolve: Bisection, Brent, Alefeld, Falsi, ITP, Ridder
9+
10+
for algT in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
11+
@eval function CommonSolve.solve(
12+
prob::IntervalNonlinearProblem{
13+
uType, iip, <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
14+
alg::$(algT),
15+
args...;
16+
kwargs...) where {uType, iip, T, V, P}
17+
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
18+
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
19+
return SciMLBase.build_solution(
20+
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats,
21+
sol.original, left = Dual{T, V, P}(sol.left, partials),
22+
right = Dual{T, V, P}(sol.right, partials))
23+
end
24+
end
25+
26+
end

lib/BracketingNonlinearSolve/test/rootfind_tests.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
end
88

99
@testitem "Interval Nonlinear Problems" setup=[RootfindingTestSnippet] tags=[:core] begin
10+
using ForwardDiff
11+
1012
@testset for alg in (Bisection(), Falsi(), Ridder(), Brent(), ITP(), Alefeld())
1113
tspan = (1.0, 20.0)
1214

@@ -17,7 +19,7 @@ end
1719

1820
@testset for p in 1.1:0.1:100.0
1921
@test g(p)sqrt(p) atol=1e-3 rtol=1e-3
20-
# @test ForwardDiff.derivative(g, p)≈1 / (2 * sqrt(p)) atol=1e-3 rtol=1e-3
22+
@test ForwardDiff.derivative(g, p)1 / (2 * sqrt(p)) atol=1e-3 rtol=1e-3
2123
end
2224

2325
t = (p) -> [sqrt(p[2] / p[1])]
@@ -30,7 +32,7 @@ end
3032
end
3133

3234
@test g2(p)[sqrt(p[2] / p[1])] atol=1e-3 rtol=1e-3
33-
# @test ForwardDiff.jacobian(g2, p)≈ForwardDiff.jacobian(t, p) atol=1e-3 rtol=1e-3
35+
@test ForwardDiff.jacobian(g2, p)ForwardDiff.jacobian(t, p) atol=1e-3 rtol=1e-3
3436

3537
probB = IntervalNonlinearProblem{false}(quadratic_f, (1.0, 2.0), 2.0)
3638
sol = solve(probB, alg; abstol = 1e-9)
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,82 @@
11
module NonlinearSolveBaseForwardDiffExt
22

3+
using CommonSolve: solve
4+
using FastClosures: @closure
35
using ForwardDiff: ForwardDiff, Dual
4-
using NonlinearSolveBase: Utils
6+
using SciMLBase: SciMLBase, IntervalNonlinearProblem, NonlinearProblem,
7+
NonlinearLeastSquaresProblem, remake
8+
9+
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils
510

611
Utils.value(::Type{Dual{T, V, N}}) where {T, V, N} = V
712
Utils.value(x::Dual) = Utils.value(ForwardDiff.value(x))
813

14+
function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
15+
prob::Union{IntervalNonlinearProblem, NonlinearProblem, ImmutableNonlinearProblem},
16+
alg, args...; kwargs...)
17+
p = Utils.value(prob.p)
18+
if prob isa IntervalNonlinearProblem
19+
tspan = Utils.value.(prob.tspan)
20+
newprob = IntervalNonlinearProblem(prob.f, tspan, p; prob.kwargs...)
21+
else
22+
newprob = remake(prob; p, u0 = Utils.value(prob.u0))
23+
end
24+
25+
sol = solve(newprob, alg, args...; kwargs...)
26+
27+
uu = sol.u
28+
Jₚ = nonlinearsolve_∂f_∂p(prob, prob.f, uu, p)
29+
Jᵤ = nonlinearsolve_∂f_∂u(prob, prob.f, uu, p)
30+
z = -Jᵤ \ Jₚ
31+
pp = prob.p
32+
sumfun = ((z, p),) -> map(Base.Fix2(*, ForwardDiff.partials(p)), z)
33+
34+
if uu isa Number
35+
partials = sum(sumfun, zip(z, pp))
36+
elseif p isa Number
37+
partials = sumfun((z, pp))
38+
else
39+
partials = sum(sumfun, zip(eachcol(z), pp))
40+
end
41+
42+
return sol, partials
43+
end
44+
45+
function nonlinearsolve_∂f_∂p(prob, f::F, u, p) where {F}
46+
if isinplace(prob)
47+
f = @closure p -> begin
48+
du = Utils.safe_similar(u, promote_type(eltype(u), eltype(p)))
49+
f(du, u, p)
50+
return du
51+
end
52+
else
53+
f = Base.Fix1(f, u)
54+
end
55+
if p isa Number
56+
return Utils.safe_reshape(ForwardDiff.derivative(f, p), :, 1)
57+
elseif u isa Number
58+
return Utils.safe_reshape(ForwardDiff.gradient(f, p), 1, :)
59+
else
60+
return ForwardDiff.jacobian(f, p)
61+
end
62+
end
63+
64+
function nonlinearsolve_∂f_∂u(prob, f::F, u, p) where {F}
65+
if isinplace(prob)
66+
return ForwardDiff.jacobian(
67+
@closure((du, u)->f(du, u, p)), Utils.safe_similar(u), u)
68+
end
69+
return ForwardDiff.jacobian(Base.Fix2(f, p), u)
70+
end
71+
72+
function NonlinearSolveBase.nonlinearsolve_dual_solution(u::Number, partials,
73+
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
74+
return Dual{T, V, P}(u, partials)
75+
end
76+
77+
function NonlinearSolveBase.nonlinearsolve_dual_solution(u::AbstractArray, partials,
78+
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
79+
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, Utils.restructure(u, partials)))
80+
end
81+
982
end

lib/NonlinearSolveBase/src/NonlinearSolveBase.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ include("utils.jl")
1717

1818
include("common_defaults.jl")
1919
include("termination_conditions.jl")
20-
include("autodiff.jl")
2120
include("immutable_problem.jl")
2221

2322
# Unexported Public API
2423
@compat(public, (L2_NORM, Linf_NORM, NAN_CHECK, UNITLESS_ABS2, get_tolerance))
24+
@compat(public, (nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution))
2525

2626
export RelTerminationMode, AbsTerminationMode, NormTerminationMode, RelNormTerminationMode,
2727
AbsNormTerminationMode, RelNormSafeTerminationMode, AbsNormSafeTerminationMode,

lib/NonlinearSolveBase/src/autodiff.jl

-1
This file was deleted.

lib/NonlinearSolveBase/src/public.jl

+4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ function L2_NORM end
55
function Linf_NORM end
66
function get_tolerance end
77

8+
# Forward declarations of functions for forward mode AD
9+
function nonlinearsolve_forwarddiff_solve end
10+
function nonlinearsolve_dual_solution end
11+
812
# Nonlinear Solve Termination Conditions
913
abstract type AbstractNonlinearTerminationMode end
1014
abstract type AbstractSafeNonlinearTerminationMode <: AbstractNonlinearTerminationMode end

lib/NonlinearSolveBase/src/utils.jl

+18
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,22 @@ apply_norm(f::F, x, y) where {F} = norm_op(standardize_norm(f), +, x, y)
7272
convert_real(::Type{T}, ::Nothing) where {T} = nothing
7373
convert_real(::Type{T}, x) where {T} = real(T(x))
7474

75+
restructure(::Number, x::Number) = x
76+
restructure(y, x) = ArrayInterface.restructure(y, x)
77+
78+
function safe_similar(x, args...; kwargs...)
79+
y = similar(x, args...; kwargs...)
80+
return init_bigfloat_array!!(y)
81+
end
82+
83+
init_bigfloat_array!!(x) = x
84+
85+
function init_bigfloat_array!!(x::AbstractArray{<:BigFloat})
86+
ArrayInterface.can_setindex(x) && fill!(x, BigFloat(0))
87+
return x
88+
end
89+
90+
safe_reshape(x::Number, args...) = x
91+
safe_reshape(x, args...) = reshape(x, args...)
92+
7593
end

0 commit comments

Comments
 (0)