Skip to content

Commit b7b9d3f

Browse files
Remove unnecessary @generated functions for isinplace checks
SciMLBase.isinplace extracts a type parameter and returns a compile-time constant (true/false). The compiler constant-folds it, so @generated functions are unnecessary for eliminating dead branches. Regular functions with if SciMLBase.isinplace(prob) infer identically. Reverted: incompatible_backend_and_problem, evaluate_f!!, evaluate_f, should_cache_fx, and prepare_jacobian back to regular functions. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent e94e86c commit b7b9d3f

File tree

3 files changed

+31
-61
lines changed

3 files changed

+31
-61
lines changed

lib/NonlinearSolveBase/src/autodiff.jl

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -105,22 +105,12 @@ function select_jacobian_autodiff(prob::AbstractNonlinearProblem, ::Nothing)
105105
or the loaded backends don't support the problem."))
106106
end
107107

108-
@generated function incompatible_backend_and_problem(
108+
function incompatible_backend_and_problem(
109109
prob::AbstractNonlinearProblem, ad::AbstractADType
110110
)
111-
iip = prob <: AbstractNonlinearProblem{<:Any, true}
112-
if iip
113-
return quote
114-
!DI.check_available(ad) && return true
115-
!DI.check_inplace(ad) && return true
116-
return additional_incompatible_backend_check(prob, ad)
117-
end
118-
else
119-
return quote
120-
!DI.check_available(ad) && return true
121-
return additional_incompatible_backend_check(prob, ad)
122-
end
123-
end
111+
!DI.check_available(ad) && return true
112+
SciMLBase.isinplace(prob) && !DI.check_inplace(ad) && return true
113+
return additional_incompatible_backend_check(prob, ad)
124114
end
125115

126116
additional_incompatible_backend_check(::AbstractNonlinearProblem, ::AbstractADType) = false

lib/NonlinearSolveBase/src/utils.jl

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using LinearAlgebra: LinearAlgebra, Diagonal, Symmetric, norm, dot, cond, diagin
77
using MaybeInplace: @bb
88
using RecursiveArrayTools: AbstractVectorOfArray, ArrayPartition, recursivecopy!
99
using SciMLOperators: AbstractSciMLOperator
10-
using SciMLBase: SciMLBase, AbstractNonlinearProblem, NonlinearFunction, AbstractNonlinearFunction
10+
using SciMLBase: SciMLBase, AbstractNonlinearProblem, NonlinearFunction
1111
using StaticArraysCore: StaticArray, SArray, SMatrix
1212

1313
using ..NonlinearSolveBase: NonlinearSolveBase, L2_NORM, Linf_NORM
@@ -174,33 +174,22 @@ function evaluate_f!!(prob::AbstractNonlinearProblem, fu, u, p = prob.p)
174174
return evaluate_f!!(prob.f, fu, u, p)
175175
end
176176

177-
@generated function evaluate_f!!(f::NonlinearFunction, fu, u, p)
178-
iip = f <: AbstractNonlinearFunction{true}
179-
if iip
180-
return quote
181-
f(fu, u, p)
182-
return fu
183-
end
184-
else
185-
return quote
186-
return f(u, p)
187-
end
177+
function evaluate_f!!(f::NonlinearFunction, fu, u, p)
178+
if SciMLBase.isinplace(f)
179+
f(fu, u, p)
180+
return fu
188181
end
182+
return f(u, p)
189183
end
190184

191-
@generated function evaluate_f(prob::AbstractNonlinearProblem, u)
192-
iip = prob <: AbstractNonlinearProblem{<:Any, true}
193-
if iip
194-
return quote
195-
fu = prob.f.resid_prototype === nothing ? similar(u) :
196-
similar(prob.f.resid_prototype)
197-
prob.f(fu, u, prob.p)
198-
return fu
199-
end
185+
function evaluate_f(prob::AbstractNonlinearProblem, u)
186+
if SciMLBase.isinplace(prob)
187+
fu = prob.f.resid_prototype === nothing ? similar(u) :
188+
similar(prob.f.resid_prototype)
189+
prob.f(fu, u, prob.p)
190+
return fu
200191
else
201-
return quote
202-
return prob.f(u, prob.p)
203-
end
192+
return prob.f(u, prob.p)
204193
end
205194
end
206195

lib/SimpleNonlinearSolve/src/utils.jl

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,8 @@ const NLBUtils = NonlinearSolveBase.Utils
1919
end
2020

2121
# GPU-compatible helper to check if fx should be cached
22-
@generated function should_cache_fx(prob::SciMLBase.AbstractNonlinearProblem, f)
23-
iip = prob <: SciMLBase.AbstractNonlinearProblem{<:Any, true}
24-
return quote
25-
$iip && !SciMLBase.has_jac(f)
26-
end
22+
@inline function should_cache_fx(prob::SciMLBase.AbstractNonlinearProblem, f)
23+
return SciMLBase.isinplace(prob) && !SciMLBase.has_jac(f)
2724
end
2825

2926
function identity_jacobian(u::Number, fu::Number, α = true)
@@ -98,27 +95,21 @@ function prepare_jacobian(prob, autodiff, _, x::Number)
9895
return DINoPreparation()
9996
end
10097

101-
@generated function prepare_jacobian(prob, autodiff, fx, x)
102-
iip = prob <: SciMLBase.AbstractNonlinearProblem{<:Any, true}
103-
if iip
104-
return quote
105-
SciMLBase.has_jac(prob.f) && return AnalyticJacobian()
106-
return DIExtras(
107-
DI.prepare_jacobian(
108-
prob.f, fx, autodiff, x, Constant(prob.p), strict = Val(false)
109-
)
98+
function prepare_jacobian(prob, autodiff, fx, x)
99+
SciMLBase.has_jac(prob.f) && return AnalyticJacobian()
100+
if SciMLBase.isinplace(prob.f)
101+
return DIExtras(
102+
DI.prepare_jacobian(
103+
prob.f, fx, autodiff, x, Constant(prob.p), strict = Val(false)
110104
)
111-
end
105+
)
112106
else
113-
return quote
114-
SciMLBase.has_jac(prob.f) && return AnalyticJacobian()
115-
x isa SArray && return DINoPreparation()
116-
return DIExtras(
117-
DI.prepare_jacobian(
118-
prob.f, autodiff, x, Constant(prob.p), strict = Val(false)
119-
)
107+
x isa SArray && return DINoPreparation()
108+
return DIExtras(
109+
DI.prepare_jacobian(
110+
prob.f, autodiff, x, Constant(prob.p), strict = Val(false)
120111
)
121-
end
112+
)
122113
end
123114
end
124115

0 commit comments

Comments
 (0)