Skip to content

Commit 30487a1

Browse files
committed
feat: check for branching for ReverseDiff(compile=true)
1 parent 1a31a42 commit 30487a1

File tree

3 files changed

+17
-4
lines changed

3 files changed

+17
-4
lines changed

lib/NonlinearSolveBase/Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
1212
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1313
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
1414
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
15+
FunctionProperties = "f62d2435-5019-4c03-9749-2d4c77af0cbc"
1516
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1617
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
1718
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
@@ -36,6 +37,7 @@ DifferentiationInterface = "0.6.1"
3637
EnzymeCore = "0.8"
3738
FastClosures = "0.3"
3839
ForwardDiff = "0.10.36"
40+
FunctionProperties = "0.1.2"
3941
LinearAlgebra = "1.10"
4042
Markdown = "1.10"
4143
RecursiveArrayTools = "3"

lib/NonlinearSolveBase/src/NonlinearSolveBase.jl

+6-4
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@ using ConcreteStructs: @concrete
77
using DifferentiationInterface: DifferentiationInterface
88
using EnzymeCore: EnzymeCore
99
using FastClosures: @closure
10+
using FunctionProperties: hasbranching
1011
using LinearAlgebra: norm
1112
using Markdown: @doc_str
1213
using RecursiveArrayTools: AbstractVectorOfArray, ArrayPartition
1314
using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinearProblem,
1415
NonlinearProblem, NonlinearLeastSquaresProblem, AbstractNonlinearFunction,
15-
@add_kwonly, StandardNonlinearProblem, NullParameters, NonlinearProblem,
16-
isinplace, warn_paramtype
16+
@add_kwonly, StandardNonlinearProblem, NullParameters, isinplace,
17+
warn_paramtype
1718
using StaticArraysCore: StaticArray
1819

1920
const DI = DifferentiationInterface
@@ -30,8 +31,9 @@ include("autodiff.jl")
3031
# Unexported Public API
3132
@compat(public, (L2_NORM, Linf_NORM, NAN_CHECK, UNITLESS_ABS2, get_tolerance))
3233
@compat(public, (nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution))
33-
@compat(public, (select_forward_mode_autodiff, select_reverse_mode_autodiff,
34-
select_jacobian_autodiff))
34+
@compat(public,
35+
(select_forward_mode_autodiff, select_reverse_mode_autodiff,
36+
select_jacobian_autodiff))
3537

3638
export RelTerminationMode, AbsTerminationMode, NormTerminationMode, RelNormTerminationMode,
3739
AbsNormTerminationMode, RelNormSafeTerminationMode, AbsNormSafeTerminationMode,

lib/NonlinearSolveBase/src/autodiff.jl

+9
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ const ReverseADs = [
77
ADTypes.AutoEnzyme(; mode = EnzymeCore.Reverse),
88
ADTypes.AutoZygote(),
99
ADTypes.AutoTracker(),
10+
ADTypes.AutoReverseDiff(; compile = true),
1011
ADTypes.AutoReverseDiff(),
1112
ADTypes.AutoFiniteDiff()
1213
]
@@ -103,6 +104,14 @@ function incompatible_backend_and_problem(
103104
end
104105

105106
additional_incompatible_backend_check(::AbstractNonlinearProblem, ::AbstractADType) = false
107+
function additional_incompatible_backend_check(prob::AbstractNonlinearProblem,
108+
::ADTypes.AutoReverseDiff{true})
109+
if SciMLBase.isinplace(prob)
110+
fu = prob.f.resid_prototype === nothing ? zero(prob.u0) : prob.f.resid_prototype
111+
return hasbranching(prob.f, fu, prob.u0, prob.p)
112+
end
113+
return hasbranching(prob.f, prob.u0, prob.p)
114+
end
106115

107116
is_finite_differences_backend(ad::AbstractADType) = false
108117
is_finite_differences_backend(::ADTypes.AutoFiniteDiff) = true

0 commit comments

Comments
 (0)