Skip to content

Commit 77cd271

Browse files
committed
feat: automatic backend selection for autodiff
1 parent ab4c9f8 commit 77cd271

File tree

4 files changed

+132
-1
lines changed

4 files changed

+132
-1
lines changed

lib/NonlinearSolveBase/Project.toml

+6
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@ authors = ["Avik Pal <[email protected]> and contributors"]
44
version = "1.0.0"
55

66
[deps]
7+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
78
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
89
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
910
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
1011
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
12+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
13+
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
1114
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
1215
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1316
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
@@ -24,10 +27,13 @@ NonlinearSolveBaseForwardDiffExt = "ForwardDiff"
2427
NonlinearSolveBaseSparseArraysExt = "SparseArrays"
2528

2629
[compat]
30+
ADTypes = "1.9"
2731
ArrayInterface = "7.9"
2832
CommonSolve = "0.2.4"
2933
Compat = "4.15"
3034
ConcreteStructs = "0.2.3"
35+
DifferentiationInterface = "0.6.1"
36+
EnzymeCore = "0.8"
3137
FastClosures = "0.3"
3238
ForwardDiff = "0.10.36"
3339
LinearAlgebra = "1.10"

lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl

+8-1
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
module NonlinearSolveBaseForwardDiffExt
22

3+
using ADTypes: ADTypes, AutoForwardDiff, AutoPolyesterForwardDiff
34
using CommonSolve: solve
45
using FastClosures: @closure
56
using ForwardDiff: ForwardDiff, Dual
6-
using SciMLBase: SciMLBase, IntervalNonlinearProblem, NonlinearProblem,
7+
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
8+
NonlinearProblem,
79
NonlinearLeastSquaresProblem, remake
810

911
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils
1012

13+
function NonlinearSolveBase.additional_incompatible_backend_check(
14+
prob::AbstractNonlinearProblem, ::Union{AutoForwardDiff, AutoPolyesterForwardDiff})
15+
return !ForwardDiff.can_dual(eltype(prob.u0))
16+
end
17+
1118
Utils.value(::Type{Dual{T, V, N}}) where {T, V, N} = V
1219
Utils.value(x::Dual) = Utils.value(ForwardDiff.value(x))
1320
Utils.value(x::AbstractArray{<:Dual}) = Utils.value.(x)

lib/NonlinearSolveBase/src/NonlinearSolveBase.jl

+9
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
module NonlinearSolveBase
22

3+
using ADTypes: ADTypes, AbstractADType, ForwardMode, ReverseMode
34
using ArrayInterface: ArrayInterface
45
using Compat: @compat
56
using ConcreteStructs: @concrete
7+
using DifferentiationInterface: DifferentiationInterface
8+
using EnzymeCore: EnzymeCore
69
using FastClosures: @closure
710
using LinearAlgebra: norm
811
using Markdown: @doc_str
@@ -13,16 +16,22 @@ using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinear
1316
isinplace, warn_paramtype
1417
using StaticArraysCore: StaticArray
1518

19+
const DI = DifferentiationInterface
20+
1621
include("public.jl")
1722
include("utils.jl")
1823

1924
include("immutable_problem.jl")
2025
include("common_defaults.jl")
2126
include("termination_conditions.jl")
2227

28+
include("autodiff.jl")
29+
2330
# Unexported Public API
2431
@compat(public, (L2_NORM, Linf_NORM, NAN_CHECK, UNITLESS_ABS2, get_tolerance))
2532
@compat(public, (nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution))
33+
@compat(public, (select_forward_mode_autodiff, select_reverse_mode_autodiff,
34+
select_jacobian_autodiff))
2635

2736
export RelTerminationMode, AbsTerminationMode, NormTerminationMode, RelNormTerminationMode,
2837
AbsNormTerminationMode, RelNormSafeTerminationMode, AbsNormSafeTerminationMode,
+109
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Here we determine the preferred AD backend. We have a predefined list of ADs and then
2+
# we select the first one that is avialable and would work with the problem.
3+
4+
# Ordering is important here. We want to select the first one that is compatible with the
5+
# problem.
6+
const ReverseADs = [
7+
ADTypes.AutoEnzyme(; mode = EnzymeCore.Reverse),
8+
ADTypes.AutoZygote(),
9+
ADTypes.AutoTracker(),
10+
ADTypes.AutoReverseDiff(),
11+
ADTypes.AutoFiniteDiff()
12+
]
13+
14+
const ForwardADs = [
15+
ADTypes.AutoEnzyme(; mode = EnzymeCore.Forward),
16+
ADTypes.AutoPolyesterForwardDiff(),
17+
ADTypes.AutoForwardDiff(),
18+
ADTypes.AutoFiniteDiff()
19+
]
20+
21+
# TODO: Handle Sparsity
22+
23+
function select_forward_mode_autodiff(
24+
prob::AbstractNonlinearProblem, ad::AbstractADType; warn_check_mode::Bool = true)
25+
if warn_check_mode && !(ADTypes.mode(ad) isa ADTypes.ForwardMode)
26+
@warn "The chosen AD backend $(ad) is not a forward mode AD. Use with caution."
27+
end
28+
if incompatible_backend_and_problem(prob, ad)
29+
adₙ = select_forward_mode_autodiff(prob, nothing; warn_check_mode)
30+
@warn "The chosen AD backend `$(ad)` does not support the chosen problem. After \
31+
running autodiff selection detected `$(adₙ)` as a potential forward mode \
32+
backend."
33+
return adₙ
34+
end
35+
return ad
36+
end
37+
38+
function select_forward_mode_autodiff(prob::AbstractNonlinearProblem, ::Nothing;
39+
warn_check_mode::Bool = true)
40+
idx = findfirst(!Base.Fix1(incompatible_backend_and_problem, prob), ForwardADs)
41+
idx !== nothing && return ForwardADs[idx]
42+
throw(ArgumentError("No forward mode AD backend is compatible with the chosen problem. \
43+
This could be because no forward mode autodiff backend is loaded \
44+
or the loaded backends don't support the problem."))
45+
end
46+
47+
function select_reverse_mode_autodiff(
48+
prob::AbstractNonlinearProblem, ad::AbstractADType; warn_check_mode::Bool = true)
49+
if warn_check_mode && !(ADTypes.mode(ad) isa ADTypes.ReverseMode)
50+
if !is_finite_differences_backend(ad)
51+
@warn "The chosen AD backend $(ad) is not a reverse mode AD. Use with caution."
52+
else
53+
@warn "The chosen AD backend $(ad) is a finite differences backend. This might \
54+
be slow and inaccurate. Use with caution."
55+
end
56+
end
57+
if incompatible_backend_and_problem(prob, ad)
58+
adₙ = select_reverse_mode_autodiff(prob, nothing; warn_check_mode)
59+
@warn "The chosen AD backend `$(ad)` does not support the chosen problem. After \
60+
running autodiff selection detected `$(adₙ)` as a potential reverse mode \
61+
backend."
62+
return adₙ
63+
end
64+
return ad
65+
end
66+
67+
function select_reverse_mode_autodiff(prob::AbstractNonlinearProblem, ::Nothing;
68+
warn_check_mode::Bool = true)
69+
idx = findfirst(!Base.Fix1(incompatible_backend_and_problem, prob), ReverseADs)
70+
idx !== nothing && return ReverseADs[idx]
71+
throw(ArgumentError("No reverse mode AD backend is compatible with the chosen problem. \
72+
This could be because no reverse mode autodiff backend is loaded \
73+
or the loaded backends don't support the problem."))
74+
end
75+
76+
function select_jacobian_autodiff(prob::AbstractNonlinearProblem, ad::AbstractADType)
77+
if incompatible_backend_and_problem(prob, ad)
78+
adₙ = select_jacobian_autodiff(prob, nothing)
79+
@warn "The chosen AD backend `$(ad)` does not support the chosen problem. After \
80+
running autodiff selection detected `$(adₙ)` as a potential jacobian \
81+
backend."
82+
return adₙ
83+
end
84+
return ad
85+
end
86+
87+
function select_jacobian_autodiff(prob::AbstractNonlinearProblem, ::Nothing)
88+
idx = findfirst(!Base.Fix1(incompatible_backend_and_problem, prob), ForwardADs)
89+
idx !== nothing && !is_finite_differences_backend(ForwardADs[idx]) &&
90+
return ForwardADs[idx]
91+
idx = findfirst(!Base.Fix1(incompatible_backend_and_problem, prob), ReverseADs)
92+
idx !== nothing && return ReverseADs[idx]
93+
throw(ArgumentError("No jacobian AD backend is compatible with the chosen problem. \
94+
This could be because no jacobian autodiff backend is loaded \
95+
or the loaded backends don't support the problem."))
96+
end
97+
98+
function incompatible_backend_and_problem(
99+
prob::AbstractNonlinearProblem, ad::AbstractADType)
100+
!DI.check_available(ad) && return true
101+
SciMLBase.isinplace(prob) && !DI.check_inplace(ad) && return true
102+
return additional_incompatible_backend_check(prob, ad)
103+
end
104+
105+
additional_incompatible_backend_check(::AbstractNonlinearProblem, ::AbstractADType) = false
106+
107+
is_finite_differences_backend(ad::AbstractADType) = false
108+
is_finite_differences_backend(::ADTypes.AutoFiniteDiff) = true
109+
is_finite_differences_backend(::ADTypes.AutoFiniteDifferences) = true

0 commit comments

Comments
 (0)