Skip to content

Commit 1808553

Browse files
committed
fix: extension for forward AD support
1 parent 0b70d62 commit 1808553

File tree

4 files changed

+7
-5
lines changed

4 files changed

+7
-5
lines changed

lib/BracketingNonlinearSolve/Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1414
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1515

1616
[extensions]
17-
NonlinearSolveBaseForwardDiffExt = "ForwardDiff"
17+
BracketingNonlinearSolveForwardDiffExt = "ForwardDiff"
1818

1919
[compat]
2020
CommonSolve = "0.2.4"

lib/BracketingNonlinearSolve/test/rootfind_tests.jl

-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
@testsnippet RootfindingTestSnippet begin
2-
using NonlinearSolveBase, BracketingNonlinearSolve
3-
42
quadratic_f(u, p) = u .* u .- p
53
quadratic_f!(du, u, p) = (du .= u .* u .- p)
64
quadratic_f2(u, p) = @. p[1] * u * u - p[2]

lib/NonlinearSolveBase/Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "1.0.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
8+
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
89
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
910
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
1011
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
@@ -24,6 +25,7 @@ NonlinearSolveBaseSparseArraysExt = "SparseArrays"
2425

2526
[compat]
2627
ArrayInterface = "7.9"
28+
CommonSolve = "0.2.4"
2729
Compat = "4.15"
2830
ConcreteStructs = "0.2.3"
2931
FastClosures = "0.3"

lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils
1010

1111
Utils.value(::Type{Dual{T, V, N}}) where {T, V, N} = V
1212
Utils.value(x::Dual) = Utils.value(ForwardDiff.value(x))
13+
Utils.value(x::AbstractArray{<:Dual}) = Utils.value.(x)
1314

1415
function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
1516
prob::Union{IntervalNonlinearProblem, NonlinearProblem, ImmutableNonlinearProblem},
@@ -43,7 +44,7 @@ function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
4344
end
4445

4546
function nonlinearsolve_∂f_∂p(prob, f::F, u, p) where {F}
46-
if isinplace(prob)
47+
if SciMLBase.isinplace(prob)
4748
f = @closure p -> begin
4849
du = Utils.safe_similar(u, promote_type(eltype(u), eltype(p)))
4950
f(du, u, p)
@@ -62,10 +63,11 @@ function nonlinearsolve_∂f_∂p(prob, f::F, u, p) where {F}
6263
end
6364

6465
function nonlinearsolve_∂f_∂u(prob, f::F, u, p) where {F}
65-
if isinplace(prob)
66+
if SciMLBase.isinplace(prob)
6667
return ForwardDiff.jacobian(
6768
@closure((du, u)->f(du, u, p)), Utils.safe_similar(u), u)
6869
end
70+
u isa Number && return ForwardDiff.derivative(Base.Fix2(f, p), u)
6971
return ForwardDiff.jacobian(Base.Fix2(f, p), u)
7072
end
7173

0 commit comments

Comments
 (0)