1
-
2
- SciMLBase . supports_opt_cache_interface ( :: LBFGS ) = true
3
- SciMLBase . allowsbounds ( :: LBFGS ) = true
4
- SciMLBase . requiresgradient ( :: LBFGS ) = true
5
- SciMLBase . allowsconstraints ( :: LBFGS ) = true
6
- SciMLBase . requiresconsjac ( :: LBFGS ) = true
7
-
8
- function task_message_to_string (task :: Vector{UInt8} )
9
- return String (task)
1
"""
    AugLag(; inner, τ = 0.5, γ = 10.0, λmin = -1e20, λmax = 1e20,
            μmin = 0.0, μmax = 1e20, ϵ = 1e-8)

Augmented-Lagrangian method for constrained optimization. Each outer
iteration solves an unconstrained augmented-Lagrangian subproblem with the
`inner` optimizer and then updates the multiplier estimates.

# Keywords
- `inner`: optimizer used for the unconstrained subproblems (required; any
  SciML-compatible optimizer).
- `τ`: NOTE(review): presumably the sufficient-decrease ratio for the
  constraint violation between outer iterations — confirm against the
  update loop.
- `γ`: NOTE(review): presumably the penalty-growth factor applied when the
  violation does not shrink enough — confirm.
- `λmin`, `λmax`: clamping bounds for the equality-constraint multipliers.
- `μmin`, `μmax`: clamping bounds for the inequality-constraint multipliers.
- `ϵ`: convergence tolerance — presumably on the constraint violation;
  verify against the termination test.
"""
@kwdef struct AugLag
    inner                 # inner (subproblem) optimizer
    τ = 0.5
    γ = 10.0
    λmin = -1e20
    λmax = 1e20
    μmin = 0.0
    μmax = 1e20
    ϵ = 1e-8
end
11
11
12
- function __map_optimizer_args (cache:: Optimization.OptimizationCache , opt:: LBFGS ;
12
# Solver-capability traits queried by SciMLBase/Optimization for AugLag:
# it supports the reusable optimization-cache (init/solve!) interface,
# accepts box bounds, requires objective gradients, allows general
# constraints, and requires the constraint Jacobian.
SciMLBase.supports_opt_cache_interface(::AugLag) = true
SciMLBase.allowsbounds(::AugLag) = true
SciMLBase.requiresgradient(::AugLag) = true
SciMLBase.allowsconstraints(::AugLag) = true
SciMLBase.requiresconsjac(::AugLag) = true
18
+ function __map_optimizer_args (cache:: Optimization.OptimizationCache , opt:: AugLag ;
13
19
callback = nothing ,
14
20
maxiters:: Union{Number, Nothing} = nothing ,
15
21
maxtime:: Union{Number, Nothing} = nothing ,
@@ -62,7 +68,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
62
68
UC,
63
69
S,
64
70
O < :
65
- LBFGS ,
71
+ AugLag ,
66
72
D,
67
73
P,
68
74
C
@@ -90,10 +96,10 @@ if !isnothing(cache.f.cons)
90
96
91
97
cons_tmp = zeros (eltype (cache. u0), length (cache. lcons))
92
98
cache. f. cons (cons_tmp, cache. u0)
93
- ρ = max (1e-6 , min (10 , 2 * (abs (cache. f (cache. u0, cache. p))) / norm (cons_tmp)))
99
+ ρ = max (1e-6 , min (10 , 2 * (abs (cache. f (cache. u0, iterate ( cache. p)[ 1 ] ))) / norm (cons_tmp)))
94
100
95
- _loss = function (θ)
96
- x = cache. f (θ, cache . p)
101
+ _loss = function (θ, p = cache . p )
102
+ x = cache. f (θ, p)
97
103
cons_tmp .= zero (eltype (θ))
98
104
cache. f. cons (cons_tmp, θ)
99
105
cons_tmp[eq_inds] .= cons_tmp[eq_inds] - cache. lcons[eq_inds]
@@ -114,8 +120,8 @@ if !isnothing(cache.f.cons)
114
120
ineqidxs = [ineq_inds[i] > 0 ? i : nothing for i in eachindex (ineq_inds)]
115
121
eqidxs = eqidxs[eqidxs .!= nothing ]
116
122
ineqidxs = ineqidxs[ineqidxs .!= nothing ]
117
- function aug_grad (G, θ)
118
- cache. f. grad (G, θ)
123
+ function aug_grad (G, θ, p )
124
+ cache. f. grad (G, θ, p )
119
125
if ! isnothing (cache. f. cons_jac_prototype)
120
126
J = Float64 .(cache. f. cons_jac_prototype)
121
127
else
@@ -139,23 +145,15 @@ if !isnothing(cache.f.cons)
139
145
opt_ret = ReturnCode. MaxIters
140
146
n = length (cache. u0)
141
147
142
- sol = solve ( .... )
148
+ augprob = OptimizationProblem ( OptimizationFunction (_loss; grad = aug_grad), cache . u0, cache . p )
143
149
144
150
solver_kwargs = Base. structdiff (solver_kwargs, (; lb = nothing , ub = nothing ))
145
151
146
- for i in 1 : maxiters
152
+ for i in 1 : ( maxiters/ 10 )
147
153
prev_eqcons .= cons_tmp[eq_inds] .- cache. lcons[eq_inds]
148
154
prevβ .= copy (β)
149
- res = optimizer (_loss, aug_grad, θ, bounds; solver_kwargs... ,
150
- m = cache. opt. m, pgtol = sqrt (ϵ), maxiter = maxiters / 100 )
151
- # @show res[2]
152
- # @show res[1]
153
- # @show cons_tmp
154
- # @show λ
155
- # @show β
156
- # @show μ
157
- # @show ρ
158
- θ = res[2 ]
155
+ res = solve (augprob, cache. opt. inner, maxiters = maxiters / 10 )
156
+ θ = res. u
159
157
cons_tmp .= 0.0
160
158
cache. f. cons (cons_tmp, θ)
161
159
λ = max .(min .(λmax, λ .+ ρ * (cons_tmp[eq_inds] .- cache. lcons[eq_inds])), λmin)
@@ -172,11 +170,10 @@ if !isnothing(cache.f.cons)
172
170
break
173
171
end
174
172
end
175
- end
176
-
177
- stats = Optimization. OptimizationStats (; iterations = maxiters,
173
+ stats = Optimization. OptimizationStats (; iterations = maxiters,
178
174
time = 0.0 , fevals = maxiters, gevals = maxiters)
179
- return SciMLBase. build_solution (
180
- cache, cache. opt, res[ 2 ], cache . f (res[ 2 ], cache . p)[ 1 ] ,
175
+ return SciMLBase. build_solution (
176
+ cache, cache. opt, θ, x ,
181
177
stats = stats, retcode = opt_ret)
178
+ end
182
179
end
0 commit comments