Skip to content

Commit cc4f3cf

Browse files
committed
fix: finish broken termination conditions
1 parent 5481e82 commit cc4f3cf

File tree

1 file changed

+73
-6
lines changed

1 file changed

+73
-6
lines changed

lib/NonlinearSolveBase/src/termination_conditions.jl

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,83 @@ function update_u!!(cache::NonlinearTerminationModeCache, u)
3030
end
3131
end
3232

33-
function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T}, T},
34-
mode::AbstractNonlinearTerminationMode, saved_value_prototype...;
35-
abstol = nothing, reltol = nothing, kwargs...) where {T <: Number}
36-
error("Not yet implemented...")
33+
function SciMLBase.init(
34+
du, u, mode::AbstractNonlinearTerminationMode, saved_value_prototype...;
35+
abstol = nothing, reltol = nothing, kwargs...)
36+
T = promote_type(eltype(du), eltype(u))
37+
abstol = get_tolerance(abstol, T)
38+
reltol = get_tolerance(reltol, T)
39+
TT = typeof(abstol)
40+
41+
u_unaliased = mode isa AbstractSafeBestNonlinearTerminationMode ?
42+
(ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing
43+
44+
if mode isa AbstractSafeNonlinearTerminationMode
45+
if mode isa AbsNormSafeTerminationMode || mode isa AbsNormSafeBestTerminationMode
46+
initial_objective = Linf_NORM(du)
47+
u0_norm = nothing
48+
else
49+
initial_objective = Linf_NORM(du) /
50+
(Utils.nonallocating_maximum(+, du, u) + eps(TT))
51+
u0_norm = mode.max_stalled_steps === nothing ? nothing : L2_NORM(u)
52+
end
53+
objectives_trace = Vector{TT}(undef, mode.patience_steps)
54+
step_norm_trace = mode.max_stalled_steps === nothing ? nothing :
55+
Vector{TT}(undef, mode.max_stalled_steps)
56+
if step_norm_trace !== nothing &&
57+
ArrayInterface.can_setindex(u_unaliased) &&
58+
!(u_unaliased isa Number)
59+
u_diff_cache = similar(u_unaliased)
60+
else
61+
u_diff_cache = u_unaliased
62+
end
63+
else
64+
initial_objective = nothing
65+
objectives_trace = nothing
66+
u0_norm = nothing
67+
step_norm_trace = nothing
68+
best_value = Utils.convert_real(T, Inf)
69+
max_stalled_steps = nothing
70+
u_diff_cache = u_unaliased
71+
end
72+
73+
length(saved_value_prototype) == 0 && (saved_value_prototype = nothing)
74+
75+
return NonlinearTerminationModeCache(
76+
u_unaliased, ReturnCode.Default, abstol, reltol, best_value, mode,
77+
initial_objective, objectives_trace, 0, saved_value_prototype,
78+
u0_norm, step_norm_trace, max_stalled_steps, u_diff_cache)
3779
end
3880

3981
function SciMLBase.reinit!(
4082
cache::NonlinearTerminationModeCache, du, u, saved_value_prototype...;
41-
abstol = nothing, reltol = nothing, kwargs...)
42-
error("Not yet implemented...")
83+
abstol = cache.abstol, reltol = cache.reltol, kwargs...)
84+
T = eltype(cache.abstol)
85+
length(saved_value_prototype) != 0 && (cache.saved_values = saved_value_prototype)
86+
87+
mode = cache.mode
88+
u_unaliased = mode isa AbstractSafeBestNonlinearTerminationMode ?
89+
(ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing
90+
cache.u = u_unaliased
91+
cache.retcode = ReturnCode.Default
92+
93+
cache.abstol = get_tolerance(abstol, T)
94+
cache.reltol = get_tolerance(reltol, T)
95+
cache.nsteps = 0
96+
TT = typeof(cache.abstol)
97+
98+
if mode isa AbstractSafeNonlinearTerminationMode
99+
if mode isa AbsNormSafeTerminationMode || mode isa AbsNormSafeBestTerminationMode
100+
cache.initial_objective = Linf_NORM(du)
101+
else
102+
cache.initial_objective = Linf_NORM(du) /
103+
(Utils.nonallocating_maximum(+, du, u) + eps(TT))
104+
cache.max_stalled_steps !== nothing && (cache.u0_norm = L2_NORM(u))
105+
end
106+
cache.best_objective_value = cache.initial_objective
107+
else
108+
cache.best_objective_value = Utils.convert_real(T, Inf)
109+
end
43110
end
44111

45112
## This dispatch is needed based on how Terminating Callback works!

0 commit comments

Comments
 (0)