Skip to content

Commit d8d9eea

Browse files
Merge pull request #399 from SciML/ap/nlls_term
Use a different termination norm for NLLS
2 parents e231d64 + fb0f1eb commit d8d9eea

10 files changed

+106
-71
lines changed

.github/workflows/Downstream.yml

+3
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ jobs:
5555
@info "Not compatible with this release. No problem." exception=err
5656
exit(0) # Exit immediately, as a success
5757
end
58+
env:
59+
RETESTITEMS_NWORKERS: 4
60+
RETESTITEMS_NWORKER_THREADS: 2
5861
- uses: julia-actions/julia-processcoverage@v1
5962
with:
6063
directories: src,ext

Project.toml

+10-10
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "3.8.4"
4+
version = "3.9.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -58,15 +58,15 @@ NonlinearSolveZygoteExt = "Zygote"
5858
[compat]
5959
ADTypes = "0.2.6"
6060
Aqua = "0.8"
61-
ArrayInterface = "7.7"
61+
ArrayInterface = "7.9"
6262
BandedMatrices = "1.4"
6363
BenchmarkTools = "1.4"
6464
ConcreteStructs = "0.2.3"
65-
CUDA = "5.1"
66-
DiffEqBase = "6.146.0"
65+
CUDA = "5.2"
66+
DiffEqBase = "6.149.0"
6767
Enzyme = "0.11.15"
6868
FastBroadcast = "0.2.8"
69-
FastClosures = "0.3"
69+
FastClosures = "0.3.2"
7070
FastLevenbergMarquardt = "0.1"
7171
FiniteDiff = "2.21"
7272
FixedPointAcceleration = "0.3"
@@ -82,21 +82,21 @@ NLsolve = "4.5"
8282
NLSolvers = "0.5"
8383
NaNMath = "1"
8484
NonlinearProblemLibrary = "0.1.2"
85-
OrdinaryDiffEq = "6.63"
85+
OrdinaryDiffEq = "6.74"
8686
Pkg = "1.10"
8787
PrecompileTools = "1.2"
8888
Preferences = "1.4"
8989
Printf = "1.10"
9090
Random = "1.91"
9191
ReTestItems = "1"
92-
RecursiveArrayTools = "3.4"
92+
RecursiveArrayTools = "3.8"
9393
Reexport = "1.2"
9494
SIAMFANLEquations = "1.0.1"
9595
SafeTestsets = "0.1"
96-
SciMLBase = "2.19.0"
96+
SciMLBase = "2.28.0"
9797
SimpleNonlinearSolve = "1.2"
9898
SparseArrays = "1.10"
99-
SparseDiffTools = "2.14"
99+
SparseDiffTools = "2.17"
100100
SpeedMapping = "0.3"
101101
StableRNGs = "1"
102102
StaticArrays = "1.7"
@@ -105,7 +105,7 @@ Sundials = "4.23.1"
105105
Symbolics = "5.13"
106106
Test = "1.10"
107107
TimerOutputs = "0.5.23"
108-
Zygote = "0.6.67"
108+
Zygote = "0.6.69"
109109
julia = "1.10"
110110

111111
[extras]

src/core/approximate_jacobian.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ function SciMLBase.__init(
167167
prob, alg.initialization, alg, f, fu, u, p; linsolve, maxiters, internalnorm)
168168

169169
abstol, reltol, termination_cache = init_termination_cache(
170-
abstol, reltol, fu, u, termination_condition)
170+
prob, abstol, reltol, fu, u, termination_condition)
171171
linsolve_kwargs = merge((; abstol, reltol), linsolve_kwargs)
172172

173173
J = initialization_cache(nothing)
@@ -206,7 +206,7 @@ function SciMLBase.__init(
206206
update_rule_cache = __internal_init(
207207
prob, alg.update_rule, J, fu, u, du; internalnorm)
208208

209-
trace = init_nonlinearsolve_trace(alg, u, fu, ApplyArray(__zero, J), du;
209+
trace = init_nonlinearsolve_trace(prob, alg, u, fu, ApplyArray(__zero, J), du;
210210
uses_jacobian_inverse = Val(INV), kwargs...)
211211

212212
return ApproximateJacobianSolveCache{INV, GB, iip, maxtime !== nothing}(

src/core/generalized_first_order.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ function SciMLBase.__init(
156156
linsolve = get_linear_solver(alg.descent)
157157

158158
abstol, reltol, termination_cache = init_termination_cache(
159-
abstol, reltol, fu, u, termination_condition)
159+
prob, abstol, reltol, fu, u, termination_condition)
160160
linsolve_kwargs = merge((; abstol, reltol), linsolve_kwargs)
161161

162162
jac_cache = JacobianCache(
@@ -191,7 +191,8 @@ function SciMLBase.__init(
191191
GB = :LineSearch
192192
end
193193

194-
trace = init_nonlinearsolve_trace(alg, u, fu, ApplyArray(__zero, J), du; kwargs...)
194+
trace = init_nonlinearsolve_trace(
195+
prob, alg, u, fu, ApplyArray(__zero, J), du; kwargs...)
195196

196197
return GeneralizedFirstOrderAlgorithmCache{iip, GB, maxtime !== nothing}(
197198
fu, u, u_cache, p, du, J, alg, prob, jac_cache, descent_cache, linesearch_cache,

src/core/spectral_methods.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane
133133
prob, alg.linesearch, prob.f, fu, u, prob.p; maxiters, internalnorm, kwargs...)
134134

135135
abstol, reltol, tc_cache = init_termination_cache(
136-
abstol, reltol, fu, u_cache, termination_condition)
137-
trace = init_nonlinearsolve_trace(alg, u, fu, nothing, du; kwargs...)
136+
prob, abstol, reltol, fu, u_cache, termination_condition)
137+
trace = init_nonlinearsolve_trace(prob, alg, u, fu, nothing, du; kwargs...)
138138

139139
if alg.σ_1 === nothing
140140
σ_n = dot(u, u) / dot(u, fu)

src/internal/termination.jl

+19-5
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,23 @@
1-
function init_termination_cache(abstol, reltol, du, u, ::Nothing)
2-
return init_termination_cache(
3-
abstol, reltol, du, u, AbsSafeBestTerminationMode(; max_stalled_steps = 32))
1+
function init_termination_cache(prob::NonlinearProblem, abstol, reltol, du, u, ::Nothing)
2+
return init_termination_cache(prob, abstol, reltol, du, u,
3+
AbsSafeBestTerminationMode(Base.Fix1(maximum, abs); max_stalled_steps = 32))
44
end
5-
function init_termination_cache(abstol, reltol, du, u, tc::AbstractNonlinearTerminationMode)
6-
tc_cache = init(du, u, tc; abstol, reltol, use_deprecated_retcodes = Val(false))
5+
function init_termination_cache(
6+
prob::NonlinearLeastSquaresProblem, abstol, reltol, du, u, ::Nothing)
7+
return init_termination_cache(prob, abstol, reltol, du, u,
8+
AbsSafeBestTerminationMode(Base.Fix2(norm, 2); max_stalled_steps = 32))
9+
end
10+
11+
function init_termination_cache(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
12+
abstol, reltol, du, u, tc::AbstractNonlinearTerminationMode)
13+
tc_ = if hasfield(typeof(tc), :internalnorm) && tc.internalnorm === nothing
14+
internalnorm = ifelse(
15+
prob isa NonlinearProblem, Base.Fix1(maximum, abs), Base.Fix2(norm, 2))
16+
DiffEqBase.set_termination_mode_internalnorm(tc, internalnorm)
17+
else
18+
tc
19+
end
20+
tc_cache = init(du, u, tc_; abstol, reltol, use_deprecated_retcodes = Val(false))
721
return DiffEqBase.get_abstol(tc_cache), DiffEqBase.get_reltol(tc_cache), tc_cache
822
end
923

src/internal/tracing.jl

+50-33
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ for Tr in (:TraceMinimal, :TraceWithJacobianConditionNumber, :TraceAll)
5252
end
5353

5454
# NonlinearSolve Tracing Utilities
55-
@concrete struct NonlinearSolveTraceEntry
55+
@concrete struct NonlinearSolveTraceEntry{nType}
5656
iteration::Int
5757
fnorm
5858
stepnorm
@@ -63,19 +63,27 @@ end
6363
δu
6464
end
6565

66-
function __show_top_level(io::IO, entry::NonlinearSolveTraceEntry)
66+
function __show_top_level(io::IO, entry::NonlinearSolveTraceEntry{nType}) where {nType}
6767
if entry.condJ === nothing
6868
@printf io "%-8s %-20s %-20s\n" "----" "-------------" "-----------"
69-
@printf io "%-8s %-20s %-20s\n" "Iter" "f(u) inf-norm" "Step 2-norm"
69+
if nType === :L2
70+
@printf io "%-8s %-20s %-20s\n" "Iter" "f(u) 2-norm" "Step 2-norm"
71+
else
72+
@printf io "%-8s %-20s %-20s\n" "Iter" "f(u) inf-norm" "Step 2-norm"
73+
end
7074
@printf io "%-8s %-20s %-20s\n" "----" "-------------" "-----------"
7175
else
7276
@printf io "%-8s %-20s %-20s %-20s\n" "----" "-------------" "-----------" "-------"
73-
@printf io "%-8s %-20s %-20s %-20s\n" "Iter" "f(u) inf-norm" "Step 2-norm" "cond(J)"
77+
if nType === :L2
78+
@printf io "%-8s %-20s %-20s %-20s\n" "Iter" "f(u) 2-norm" "Step 2-norm" "cond(J)"
79+
else
80+
@printf io "%-8s %-20s %-20s %-20s\n" "Iter" "f(u) inf-norm" "Step 2-norm" "cond(J)"
81+
end
7482
@printf io "%-8s %-20s %-20s %-20s\n" "----" "-------------" "-----------" "-------"
7583
end
7684
end
7785

78-
function Base.show(io::IO, entry::NonlinearSolveTraceEntry)
86+
function Base.show(io::IO, entry::NonlinearSolveTraceEntry{nType}) where {nType}
7987
entry.iteration == 0 && __show_top_level(io, entry)
8088
if entry.iteration < 0
8189
# Special case for final entry
@@ -89,25 +97,32 @@ function Base.show(io::IO, entry::NonlinearSolveTraceEntry)
8997
return nothing
9098
end
9199

92-
function NonlinearSolveTraceEntry(iteration, fu, δu)
93-
return NonlinearSolveTraceEntry(
94-
iteration, norm(fu, Inf), norm(δu, 2), nothing, nothing, nothing, nothing, nothing)
100+
function NonlinearSolveTraceEntry(prob::AbstractNonlinearProblem, iteration, fu, δu)
101+
nType = ifelse(prob isa NonlinearLeastSquaresProblem, :L2, :Inf)
102+
fnorm = prob isa NonlinearLeastSquaresProblem ? norm(fu, 2) : norm(fu, Inf)
103+
return NonlinearSolveTraceEntry{nType}(
104+
iteration, fnorm, norm(δu, 2), nothing, nothing, nothing, nothing, nothing)
95105
end
96106

97-
function NonlinearSolveTraceEntry(iteration, fu, δu, J)
98-
return NonlinearSolveTraceEntry(iteration, norm(fu, Inf), norm(δu, 2),
99-
__cond(J), nothing, nothing, nothing, nothing)
107+
function NonlinearSolveTraceEntry(prob::AbstractNonlinearProblem, iteration, fu, δu, J)
108+
nType = ifelse(prob isa NonlinearLeastSquaresProblem, :L2, :Inf)
109+
fnorm = prob isa NonlinearLeastSquaresProblem ? norm(fu, 2) : norm(fu, Inf)
110+
return NonlinearSolveTraceEntry{nType}(
111+
iteration, fnorm, norm(δu, 2), __cond(J), nothing, nothing, nothing, nothing)
100112
end
101113

102-
function NonlinearSolveTraceEntry(iteration, fu, δu, J, u)
103-
return NonlinearSolveTraceEntry(iteration, norm(fu, Inf), norm(δu, 2), __cond(J),
114+
function NonlinearSolveTraceEntry(prob::AbstractNonlinearProblem, iteration, fu, δu, J, u)
115+
nType = ifelse(prob isa NonlinearLeastSquaresProblem, :L2, :Inf)
116+
fnorm = prob isa NonlinearLeastSquaresProblem ? norm(fu, 2) : norm(fu, Inf)
117+
return NonlinearSolveTraceEntry{nType}(iteration, fnorm, norm(δu, 2), __cond(J),
104118
__copy(J), __copy(u), __copy(fu), __copy(δu))
105119
end
106120

107121
@concrete struct NonlinearSolveTrace{
108122
show_trace, store_trace, Tr <: AbstractNonlinearSolveTraceLevel}
109123
history
110124
trace_level::Tr
125+
prob
111126
end
112127

113128
function reset!(trace::NonlinearSolveTrace)
@@ -123,61 +138,63 @@ function Base.show(io::IO, trace::NonlinearSolveTrace)
123138
return nothing
124139
end
125140

126-
function init_nonlinearsolve_trace(alg, u, fu, J, δu; show_trace::Val = Val(false),
141+
function init_nonlinearsolve_trace(prob, alg, u, fu, J, δu; show_trace::Val = Val(false),
127142
trace_level::AbstractNonlinearSolveTraceLevel = TraceMinimal(),
128143
store_trace::Val = Val(false), uses_jac_inverse = Val(false), kwargs...)
129144
return init_nonlinearsolve_trace(
130-
alg, show_trace, trace_level, store_trace, u, fu, J, δu, uses_jac_inverse)
145+
prob, alg, show_trace, trace_level, store_trace, u, fu, J, δu, uses_jac_inverse)
131146
end
132147

133-
function init_nonlinearsolve_trace(
134-
alg, ::Val{show_trace}, trace_level::AbstractNonlinearSolveTraceLevel,
135-
::Val{store_trace}, u, fu, J, δu,
136-
::Val{uses_jac_inverse}) where {show_trace, store_trace, uses_jac_inverse}
148+
function init_nonlinearsolve_trace(prob::AbstractNonlinearProblem, alg, ::Val{show_trace},
149+
trace_level::AbstractNonlinearSolveTraceLevel, ::Val{store_trace}, u, fu, J,
150+
δu, ::Val{uses_jac_inverse}) where {show_trace, store_trace, uses_jac_inverse}
137151
if show_trace
138152
print("\nAlgorithm: ")
139153
Base.printstyled(alg, "\n\n"; color = :green, bold = true)
140154
end
141155
J_ = uses_jac_inverse ? (trace_level isa TraceMinimal ? J : __safe_inv(J)) : J
142156
history = __init_trace_history(
143-
Val{show_trace}(), trace_level, Val{store_trace}(), u, fu, J_, δu)
144-
return NonlinearSolveTrace{show_trace, store_trace}(history, trace_level)
157+
prob, Val{show_trace}(), trace_level, Val{store_trace}(), u, fu, J_, δu)
158+
return NonlinearSolveTrace{show_trace, store_trace}(history, trace_level, prob)
145159
end
146160

147-
function __init_trace_history(::Val{show_trace}, trace_level, ::Val{store_trace},
148-
u, fu, J, δu) where {show_trace, store_trace}
161+
function __init_trace_history(
162+
prob::AbstractNonlinearProblem, ::Val{show_trace}, trace_level,
163+
::Val{store_trace}, u, fu, J, δu) where {show_trace, store_trace}
149164
!store_trace && !show_trace && return nothing
150-
entry = __trace_entry(trace_level, 0, u, fu, J, δu)
165+
entry = __trace_entry(prob, trace_level, 0, u, fu, J, δu)
151166
show_trace && show(entry)
152167
store_trace && return NonlinearSolveTraceEntry[entry]
153168
return nothing
154169
end
155170

156-
function __trace_entry(::TraceMinimal, iter, u, fu, J, δu, α = 1)
157-
return NonlinearSolveTraceEntry(iter, fu, δu .* α)
171+
function __trace_entry(prob, ::TraceMinimal, iter, u, fu, J, δu, α = 1)
172+
return NonlinearSolveTraceEntry(prob, iter, fu, δu .* α)
158173
end
159-
function __trace_entry(::TraceWithJacobianConditionNumber, iter, u, fu, J, δu, α = 1)
160-
return NonlinearSolveTraceEntry(iter, fu, δu .* α, J)
174+
function __trace_entry(prob, ::TraceWithJacobianConditionNumber, iter, u, fu, J, δu, α = 1)
175+
return NonlinearSolveTraceEntry(prob, iter, fu, δu .* α, J)
161176
end
162-
function __trace_entry(::TraceAll, iter, u, fu, J, δu, α = 1)
163-
return NonlinearSolveTraceEntry(iter, fu, δu .* α, J, u)
177+
function __trace_entry(prob, ::TraceAll, iter, u, fu, J, δu, α = 1)
178+
return NonlinearSolveTraceEntry(prob, iter, fu, δu .* α, J, u)
164179
end
165180

166181
function update_trace!(trace::NonlinearSolveTrace{ShT, StT}, iter, u, fu, J, δu,
167182
α = 1; last::Val{L} = Val(false)) where {ShT, StT, L}
168183
!StT && !ShT && return nothing
169184

170185
if L
171-
entry = NonlinearSolveTraceEntry(
172-
-1, norm(fu, Inf), NaN32, nothing, nothing, nothing, nothing, nothing)
186+
nType = ifelse(trace.prob isa NonlinearLeastSquaresProblem, :L2, :Inf)
187+
fnorm = trace.prob isa NonlinearLeastSquaresProblem ? norm(fu, 2) : norm(fu, Inf)
188+
entry = NonlinearSolveTraceEntry{nType}(
189+
-1, fnorm, NaN32, nothing, nothing, nothing, nothing, nothing)
173190
ShT && show(entry)
174191
return trace
175192
end
176193

177194
show_now = ShT && (mod1(iter, trace.trace_level.print_frequency) == 1)
178195
store_now = StT && (mod1(iter, trace.trace_level.store_frequency) == 1)
179196
(show_now || store_now) &&
180-
(entry = __trace_entry(trace.trace_level, iter, u, fu, J, δu, α))
197+
(entry = __trace_entry(trace.prob, trace.trace_level, iter, u, fu, J, δu, α))
181198
store_now && push!(trace.history, entry)
182199
show_now && show(entry)
183200
return trace

test/core/forward_ad_tests.jl

-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ end
7979
gs = abs.(ForwardDiff.derivative(solve_with(Val{mode}(), u0, alg), p))
8080
gs_true = abs.(jacobian_f(u0, p))
8181
if !(isapprox(gs, gs_true, atol = 1e-5))
82-
@show sol.retcode, sol.u
8382
@error "ForwardDiff Failed for u0=$(u0) and p=$(p) with $(alg)" forwardiff_gradient=gs true_gradient=gs_true
8483
else
8584
@test abs.(gs)≈abs.(gs_true) atol=1e-5

test/core/nlls_tests.jl

+9-8
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@ using Reexport
66
true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])
77
true_function(y, x, θ) = (@. y = θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4]))
88

9-
θ_true = [1.0, 0.1, 2.0, 0.5]
9+
const θ_true = [1.0, 0.1, 2.0, 0.5]
1010

11-
x = [-1.0, -0.5, 0.0, 0.5, 1.0]
11+
const x = [-1.0, -0.5, 0.0, 0.5, 1.0]
1212

13-
y_target = true_function(x, θ_true)
13+
const y_target = true_function(x, θ_true)
1414

1515
function loss_function(θ, p)
1616
ŷ = true_function(p, θ)
@@ -23,7 +23,7 @@ function loss_function(resid, θ, p)
2323
return resid
2424
end
2525

26-
θ_init = θ_true .+ randn!(StableRNG(0), similar(θ_true)) * 0.1
26+
const θ_init = θ_true .+ randn!(StableRNG(0), similar(θ_true)) * 0.1
2727

2828
solvers = []
2929
for linsolve in [nothing, LUFactorization(), KrylovJL_GMRES(), KrylovJL_LSMR()]
@@ -56,9 +56,9 @@ end
5656
nlls_problems = [prob_oop, prob_iip]
5757

5858
for prob in nlls_problems, solver in solvers
59-
sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
59+
sol = solve(prob, solver; maxiters = 10000, abstol = 1e-6)
6060
@test SciMLBase.successful_retcode(sol)
61-
@test maximum(abs, sol.resid) < 1e-6
61+
@test norm(sol.resid, 2) < 1e-6
6262
end
6363
end
6464

@@ -90,8 +90,9 @@ end
9090
x)]
9191

9292
for prob in probs, solver in solvers
93-
sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
94-
@test maximum(abs, sol.resid) < 1e-6
93+
sol = solve(prob, solver; maxiters = 10000, abstol = 1e-6)
94+
@test SciMLBase.successful_retcode(sol)
95+
@test norm(sol.resid, 2) < 1e-6
9596
end
9697
end
9798

0 commit comments

Comments
 (0)