Skip to content

Commit 28edcd4

Browse files
authored
Merge pull request #393 from SciML/ap/tstable_findmin
Make `__findmin` type stable
2 parents 3925219 + e6ff3aa commit 28edcd4

File tree

2 files changed

+21
-8
lines changed

2 files changed

+21
-8
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "3.8.1"
4+
version = "3.8.2"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/utils.jl

+20-7
Original file line numberDiff line numberDiff line change
@@ -94,16 +94,29 @@ LazyArrays.applied_axes(::typeof(__zero), x) = axes(x)
9494
@inline __is_complex(::Type{Complex}) = true
9595
@inline __is_complex(::Type{T}) where {T} = false
9696

97-
@inline __findmin_caches(f, caches) = __findmin(f get_fu, caches)
97+
@inline __findmin_caches(f::F, caches) where {F} = __findmin(f get_fu, caches)
9898
# FIXME: DEFAULT_NORM makes an Array of NaNs not a NaN (atleast according to `isnan`)
99-
@inline __findmin(::typeof(DEFAULT_NORM), x) = __findmin(Base.Fix1(maximum, abs), x)
100-
@inline function __findmin(f, x)
99+
@generated function __findmin(f::F, x) where {F}
100+
# JET shows dynamic dispatch if this is not written as a generated function
101+
if F === typeof(DEFAULT_NORM)
102+
return :(return __findmin_impl(Base.Fix1(maximum, abs), x))
103+
end
104+
return :(return __findmin_impl(f, x))
105+
end
106+
@inline @views function __findmin_impl(f::F, x) where {F}
107+
idx = findfirst(Base.Fix2(!==, nothing), x)
108+
# This is an internal function so we assume that inputs are consistent and there is
109+
# atleast one non-`nothing` value
110+
fx_idx = f(x[idx])
111+
idx == length(x) && return fx_idx, idx
101112
fmin = @closure xᵢ -> begin
102-
xᵢ === nothing && return Inf
113+
xᵢ === nothing && return oftype(fx_idx, Inf)
103114
fx = f(xᵢ)
104-
return ifelse(isnan(fx), Inf, fx)
115+
return ifelse(isnan(fx), oftype(fx, Inf), fx)
105116
end
106-
return findmin(fmin, x)
117+
x_min, x_min_idx = findmin(fmin, x[(idx + 1):length(x)])
118+
x_min < fx_idx && return x_min, x_min_idx + idx
119+
return fx_idx, idx
107120
end
108121

109122
@inline __can_setindex(x) = can_setindex(x)
@@ -130,7 +143,7 @@ Statistics from the nonlinear equation solver about the solution process.
130143
- nf: Number of function evaluations.
131144
- njacs: Number of Jacobians created during the solve.
132145
- nfactors: Number of factorzations of the jacobian required for the solve.
133-
- nsolve: Number of linear solves `W\b` required for the solve.
146+
- nsolve: Number of linear solves `W \\ b` required for the solve.
134147
- nsteps: Total number of iterations for the nonlinear solver.
135148
"""
136149
struct ImmutableNLStats

0 commit comments

Comments
 (0)