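# L-BFGS-B backend: box constraints are handled natively by the Fortran solver,
# and general nonlinear constraints are handled with an augmented-Lagrangian
# outer loop around it. This file assumes the enclosing module already loads
# Optimization, SciMLBase, LBFGSB, and LinearAlgebra (for `norm`), and defines
# the `LBFGS` options struct (fields m, τ, γ, λmin, λmax, μmin, μmax, ϵ).

# Solver capability traits for `LBFGS`.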
SciMLBase.supports_opt_cache_interface(::LBFGS) = true
SciMLBase.allowsbounds(::LBFGS) = true
SciMLBase.requiresgradient(::LBFGS) = true
SciMLBase.allowsconstraints(::LBFGS) = true
SciMLBase.requiresconsjac(::LBFGS) = true

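# Convert the `task` status buffer returned by the Fortran L-BFGS-B routine
# into a readable `String`.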
function task_message_to_string(task::Vector{UInt8})
    return String(task)
end

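# Map Optimization.jl's common keyword arguments onto LBFGSB.jl's names:
# `maxiters` becomes `maxiter` and `reltol` becomes the projected-gradient
# tolerance `pgtol`; `abstol` and `maxtime` are not supported and only warn.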
function __map_optimizer_args(cache::Optimization.OptimizationCache, opt::LBFGS;
        callback = nothing,
        maxiters::Union{Number, Nothing} = nothing,
        maxtime::Union{Number, Nothing} = nothing,
        abstol::Union{Number, Nothing} = nothing,
        reltol::Union{Number, Nothing} = nothing,
        verbose::Bool = false,
        kwargs...)
    if !isnothing(abstol)
        @warn "common abstol is currently not used by $(opt)"
    end
    if !isnothing(maxtime)
        @warn "common maxtime is currently not used by $(opt)"
    end

    mapped_args = (;)

    if cache.lb !== nothing && cache.ub !== nothing
        mapped_args = (; mapped_args..., lb = cache.lb, ub = cache.ub)
    end

    if !isnothing(maxiters)
        mapped_args = (; mapped_args..., maxiter = maxiters)
    end

    if !isnothing(reltol)
        mapped_args = (; mapped_args..., pgtol = reltol)
    end

    return mapped_args
end

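# Solve an `OptimizationProblem` with `LBFGS`. Problems with general
# constraints go through the augmented-Lagrangian loop below; the inner
# subproblems are solved by L-BFGS-B.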
function SciMLBase.__solve(cache::OptimizationCache{F, RC, LB, UB, LC, UC, S, O,
        D, P, C}) where {F, RC, LB, UB, LC, UC, S, O <: LBFGS, D, P, C}
    maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)

    local x

    solver_kwargs = __map_optimizer_args(cache, cache.opt; maxiters, cache.solver_args...)

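    # Constrained path: a constraint row is an equality when lcons == ucons.
    # The inner L-BFGS-B solver minimizes the augmented Lagrangian
    #   L_ρ(θ, λ, μ) = f(θ) + Σᵢ λᵢ hᵢ(θ) + (ρ / 2) Σᵢ hᵢ(θ)²
    #                  + (1 / 2ρ) Σⱼ max(0, μⱼ + ρ gⱼ(θ))²
    # where h are the equality residuals and g the inequality residuals.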
    if !isnothing(cache.f.cons)
        eq_inds = [cache.lcons[i] == cache.ucons[i] for i in eachindex(cache.lcons)]
        ineq_inds = (!).(eq_inds)

        τ = cache.opt.τ
        γ = cache.opt.γ
        λmin = cache.opt.λmin
        λmax = cache.opt.λmax
        μmin = cache.opt.μmin
        μmax = cache.opt.μmax
        ϵ = cache.opt.ϵ

        λ = zeros(eltype(cache.u0), sum(eq_inds))
        μ = zeros(eltype(cache.u0), sum(ineq_inds))

        cons_tmp = zeros(eltype(cache.u0), length(cache.lcons))
        cache.f.cons(cons_tmp, cache.u0)
        ρ = max(1e-6, min(10, 2 * (abs(cache.f(cache.u0, cache.p))) / norm(cons_tmp)))

        _loss = function (θ)
            x = cache.f(θ, cache.p)
            cons_tmp .= zero(eltype(θ))
            cache.f.cons(cons_tmp, θ)
            cons_tmp[eq_inds] .= cons_tmp[eq_inds] .- cache.lcons[eq_inds]
            cons_tmp[ineq_inds] .= cons_tmp[ineq_inds] .- cache.ucons[ineq_inds]
            opt_state = Optimization.OptimizationState(u = θ, objective = x[1])
            if cache.callback(opt_state, x...)
                error("Optimization halted by callback.")
            end
            return x[1] + sum(@. λ * cons_tmp[eq_inds] + ρ / 2 * (cons_tmp[eq_inds] .^ 2)) +
                   1 / (2 * ρ) * sum((max.(Ref(0.0), μ .+ (ρ .* cons_tmp[ineq_inds]))) .^ 2)
        end
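
        # Outer-loop bookkeeping: prev_eqcons and prevβ track the previous
        # violation so the penalty-growth test can compare progress; eqidxs and
        # ineqidxs hold the integer row indices of the two constraint classes.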
        prev_eqcons = zero(λ)
        θ = cache.u0
        β = max.(cons_tmp[ineq_inds], Ref(0.0))
        prevβ = zero(β)
        eqidxs = [eq_inds[i] > 0 ? i : nothing for i in eachindex(eq_inds)]
        ineqidxs = [ineq_inds[i] > 0 ? i : nothing for i in eachindex(ineq_inds)]
        eqidxs = eqidxs[eqidxs .!= nothing]
        ineqidxs = ineqidxs[ineqidxs .!= nothing]
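        # Gradient of the augmented Lagrangian: the objective gradient plus
        # Jacobian-weighted multiplier and penalty terms for each constraint.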
        function aug_grad(G, θ)
            cache.f.grad(G, θ)
            if !isnothing(cache.f.cons_jac_prototype)
                J = Float64.(cache.f.cons_jac_prototype)
            else
                J = zeros((length(cache.lcons), length(θ)))
            end
            cache.f.cons_j(J, θ)
            __tmp = zero(cons_tmp)
            cache.f.cons(__tmp, θ)
            __tmp[eq_inds] .= __tmp[eq_inds] .- cache.lcons[eq_inds]
            __tmp[ineq_inds] .= __tmp[ineq_inds] .- cache.ucons[ineq_inds]
            G .+= sum(
                λ[i] .* J[idx, :] + ρ * (__tmp[idx] .* J[idx, :])
                for (i, idx) in enumerate(eqidxs);
                init = zero(G)) # should be a jvp
            G .+= sum(
                1 / ρ * (max.(Ref(0.0), μ[i] .+ (ρ .* __tmp[idx])) .* J[idx, :])
                for (i, idx) in enumerate(ineqidxs);
                init = zero(G)) # should be a jvp
        end

        opt_ret = ReturnCode.MaxIters
        n = length(cache.u0)

        # Hedged reconstruction: the upstream commit left an unfinished
        # `sol = solve(....)` placeholder here, but the loop below needs an
        # L-BFGS-B callable and its 3×n bounds matrix (LBFGSB.jl convention:
        # row 1 is the bound type, 0 = free, 2 = lower and upper; rows 2 and 3
        # hold the lower and upper bounds).
        bounds = zeros(3, n)
        if cache.lb !== nothing && cache.ub !== nothing
            bounds[1, :] .= 2
            bounds[2, :] .= cache.lb
            bounds[3, :] .= cache.ub
        end
        optimizer = LBFGSB.L_BFGS_B(n, cache.opt.m)

        # The L-BFGS-B callable does not accept lb/ub keywords; the bounds are
        # passed positionally, so drop them before splatting.
        solver_kwargs = Base.structdiff(solver_kwargs, (; lb = nothing, ub = nothing))
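
        # Augmented-Lagrangian outer iterations: solve the inner subproblem,
        # update the multipliers λ and μ, and grow the penalty ρ by γ whenever
        # the violation has not shrunk by at least the factor τ.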
        for i in 1:maxiters
            prev_eqcons .= cons_tmp[eq_inds] .- cache.lcons[eq_inds]
            prevβ .= copy(β)
            res = optimizer(_loss, aug_grad, θ, bounds; solver_kwargs...,
                m = cache.opt.m, pgtol = sqrt(ϵ), maxiter = maxiters / 100)
            θ = res[2]
            cons_tmp .= 0.0
            cache.f.cons(cons_tmp, θ)
            λ = max.(min.(λmax, λ .+ ρ * (cons_tmp[eq_inds] .- cache.lcons[eq_inds])), λmin)
            β = max.(cons_tmp[ineq_inds], -1 .* μ ./ ρ)
            μ = min.(μmax, max.(μ .+ ρ * cons_tmp[ineq_inds], μmin))
            if max(norm(cons_tmp[eq_inds] .- cache.lcons[eq_inds], Inf), norm(β, Inf)) >
               τ * max(norm(prev_eqcons, Inf), norm(prevβ, Inf))
                ρ = γ * ρ
            end
            if norm((cons_tmp[eq_inds] .- cache.lcons[eq_inds]) ./ cons_tmp[eq_inds],
                Inf) < ϵ && norm(β, Inf) < ϵ
                opt_ret = ReturnCode.Success
                break
            end
        end
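    else
        # Hedged sketch, not in the original: the unconstrained (or purely
        # box-constrained) path, which this version leaves unhandled even
        # though `res` and `opt_ret` are used below. Mirrors the constrained
        # call, minus the augmented-Lagrangian terms.
        _loss = function (θ)
            x = cache.f(θ, cache.p)
            opt_state = Optimization.OptimizationState(u = θ, objective = x[1])
            if cache.callback(opt_state, x...)
                error("Optimization halted by callback.")
            end
            return x[1]
        end
        n = length(cache.u0)
        bounds = zeros(3, n)
        if cache.lb !== nothing && cache.ub !== nothing
            bounds[1, :] .= 2
            bounds[2, :] .= cache.lb
            bounds[3, :] .= cache.ub
        end
        optimizer = LBFGSB.L_BFGS_B(n, cache.opt.m)
        solver_kwargs = Base.structdiff(solver_kwargs, (; lb = nothing, ub = nothing))
        res = optimizer(_loss, (G, θ) -> cache.f.grad(G, θ), cache.u0, bounds;
            solver_kwargs..., m = cache.opt.m)
        opt_ret = ReturnCode.Success # assumed; the L-BFGS-B task message could be inspected instead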
    end

    stats = Optimization.OptimizationStats(; iterations = maxiters,
        time = 0.0, fevals = maxiters, gevals = maxiters)
    return SciMLBase.build_solution(
        cache, cache.opt, res[2], cache.f(res[2], cache.p)[1],
        stats = stats, retcode = opt_ret)
end
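
# Usage sketch (hedged): assumes `LBFGS` is the solver type defined alongside
# this file in Optimization.jl and that an AD backend is available for the
# required gradient. Illustrative only.
#
#   using Optimization
#   rosenbrock(u, p) = (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2
#   optf = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
#   prob = OptimizationProblem(optf, zeros(2), [1.0, 100.0];
#       lb = [-1.0, -1.0], ub = [1.0, 1.0])
#   sol = solve(prob, Optimization.LBFGS(); maxiters = 100)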