@@ -30,16 +30,83 @@ function update_u!!(cache::NonlinearTerminationModeCache, u)
30
30
end
31
31
end
32
32
33
- function SciMLBase. init (du:: Union{AbstractArray{T}, T} , u:: Union{AbstractArray{T}, T} ,
34
- mode:: AbstractNonlinearTerminationMode , saved_value_prototype... ;
35
- abstol = nothing , reltol = nothing , kwargs... ) where {T <: Number }
36
- error (" Not yet implemented..." )
33
+ function SciMLBase. init (
34
+ du, u, mode:: AbstractNonlinearTerminationMode , saved_value_prototype... ;
35
+ abstol = nothing , reltol = nothing , kwargs... )
36
+ T = promote_type (eltype (du), eltype (u))
37
+ abstol = get_tolerance (abstol, T)
38
+ reltol = get_tolerance (reltol, T)
39
+ TT = typeof (abstol)
40
+
41
+ u_unaliased = mode isa AbstractSafeBestNonlinearTerminationMode ?
42
+ (ArrayInterface. can_setindex (u) ? copy (u) : u) : nothing
43
+
44
+ if mode isa AbstractSafeNonlinearTerminationMode
45
+ if mode isa AbsNormSafeTerminationMode || mode isa AbsNormSafeBestTerminationMode
46
+ initial_objective = Linf_NORM (du)
47
+ u0_norm = nothing
48
+ else
49
+ initial_objective = Linf_NORM (du) /
50
+ (Utils. nonallocating_maximum (+ , du, u) + eps (TT))
51
+ u0_norm = mode. max_stalled_steps === nothing ? nothing : L2_NORM (u)
52
+ end
53
+ objectives_trace = Vector {TT} (undef, mode. patience_steps)
54
+ step_norm_trace = mode. max_stalled_steps === nothing ? nothing :
55
+ Vector {TT} (undef, mode. max_stalled_steps)
56
+ if step_norm_trace != = nothing &&
57
+ ArrayInterface. can_setindex (u_unaliased) &&
58
+ ! (u_unaliased isa Number)
59
+ u_diff_cache = similar (u_unaliased)
60
+ else
61
+ u_diff_cache = u_unaliased
62
+ end
63
+ else
64
+ initial_objective = nothing
65
+ objectives_trace = nothing
66
+ u0_norm = nothing
67
+ step_norm_trace = nothing
68
+ best_value = Utils. convert_real (T, Inf )
69
+ max_stalled_steps = nothing
70
+ u_diff_cache = u_unaliased
71
+ end
72
+
73
+ length (saved_value_prototype) == 0 && (saved_value_prototype = nothing )
74
+
75
+ return NonlinearTerminationModeCache (
76
+ u_unaliased, ReturnCode. Default, abstol, reltol, best_value, mode,
77
+ initial_objective, objectives_trace, 0 , saved_value_prototype,
78
+ u0_norm, step_norm_trace, max_stalled_steps, u_diff_cache)
37
79
end
38
80
39
81
function SciMLBase. reinit! (
40
82
cache:: NonlinearTerminationModeCache , du, u, saved_value_prototype... ;
41
- abstol = nothing , reltol = nothing , kwargs... )
42
- error (" Not yet implemented..." )
83
+ abstol = cache. abstol, reltol = cache. reltol, kwargs... )
84
+ T = eltype (cache. abstol)
85
+ length (saved_value_prototype) != 0 && (cache. saved_values = saved_value_prototype)
86
+
87
+ mode = cache. mode
88
+ u_unaliased = mode isa AbstractSafeBestNonlinearTerminationMode ?
89
+ (ArrayInterface. can_setindex (u) ? copy (u) : u) : nothing
90
+ cache. u = u_unaliased
91
+ cache. retcode = ReturnCode. Default
92
+
93
+ cache. abstol = get_tolerance (abstol, T)
94
+ cache. reltol = get_tolerance (reltol, T)
95
+ cache. nsteps = 0
96
+ TT = typeof (cache. abstol)
97
+
98
+ if mode isa AbstractSafeNonlinearTerminationMode
99
+ if mode isa AbsNormSafeTerminationMode || mode isa AbsNormSafeBestTerminationMode
100
+ cache. initial_objective = Linf_NORM (du)
101
+ else
102
+ cache. initial_objective = Linf_NORM (du) /
103
+ (Utils. nonallocating_maximum (+ , du, u) + eps (TT))
104
+ cache. max_stalled_steps != = nothing && (cache. u0_norm = L2_NORM (u))
105
+ end
106
+ cache. best_objective_value = cache. initial_objective
107
+ else
108
+ cache. best_objective_value = Utils. convert_real (T, Inf )
109
+ end
43
110
end
44
111
45
112
# # This dispatch is needed based on how Terminating Callback works!
0 commit comments