
Commit 370286b

Copy auglag from lbfgsb.jl
1 parent f6a7301 commit 370286b
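
This file copies the augmented Lagrangian driver previously embedded in `lbfgsb.jl`: the L-BFGS inner solve is wrapped in an outer loop that adds multiplier and penalty terms for the constraints and updates them between inner solves. As a rough illustration of the objective that the `_loss` closure below assembles, here is a standalone sketch (not part of this commit; `f`, `cons_eq`, and `cons_ineq` are hypothetical stand-ins):

# Illustrative sketch only, not from this commit: the augmented Lagrangian value
# for equality residuals g(θ) (want g .== 0) and inequality residuals h(θ) (want h .<= 0).
function augmented_lagrangian(f, cons_eq, cons_ineq, θ, λ, μ, ρ)
    g = cons_eq(θ)
    h = cons_ineq(θ)
    return f(θ) +
           sum(@. λ * g + ρ / 2 * g^2) +                   # multiplier + quadratic penalty
           1 / (2 * ρ) * sum(max.(0.0, μ .+ ρ .* h) .^ 2)  # shifted penalty for inequalities
end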

File tree

1 file changed (+182, -0 lines)


src/auglag.jl

+182
@@ -0,0 +1,182 @@
SciMLBase.supports_opt_cache_interface(::LBFGS) = true
SciMLBase.allowsbounds(::LBFGS) = true
SciMLBase.requiresgradient(::LBFGS) = true
SciMLBase.allowsconstraints(::LBFGS) = true
SciMLBase.requiresconsjac(::LBFGS) = true

function task_message_to_string(task::Vector{UInt8})
    return String(task)
end

function __map_optimizer_args(cache::Optimization.OptimizationCache, opt::LBFGS;
        callback = nothing,
        maxiters::Union{Number, Nothing} = nothing,
        maxtime::Union{Number, Nothing} = nothing,
        abstol::Union{Number, Nothing} = nothing,
        reltol::Union{Number, Nothing} = nothing,
        verbose::Bool = false,
        kwargs...)
    if !isnothing(abstol)
        @warn "common abstol is currently not used by $(opt)"
    end
    if !isnothing(maxtime)
        @warn "common maxtime is currently not used by $(opt)"
    end

    mapped_args = (;)

    if cache.lb !== nothing && cache.ub !== nothing
        mapped_args = (; mapped_args..., lb = cache.lb, ub = cache.ub)
    end

    if !isnothing(maxiters)
        mapped_args = (; mapped_args..., maxiter = maxiters)
    end

    if !isnothing(reltol)
        mapped_args = (; mapped_args..., pgtol = reltol)
    end

    return mapped_args
end

function SciMLBase.__solve(cache::OptimizationCache{
        F, RC, LB, UB, LC, UC, S, O, D, P, C
}) where {F, RC, LB, UB, LC, UC, S, O <: LBFGS, D, P, C}
    maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)

    local x

    solver_kwargs = __map_optimizer_args(cache, cache.opt; maxiters, cache.solver_args...)

    if !isnothing(cache.f.cons)
        eq_inds = [cache.lcons[i] == cache.ucons[i] for i in eachindex(cache.lcons)]
        ineq_inds = (!).(eq_inds)

        τ = cache.opt.τ
        γ = cache.opt.γ
        λmin = cache.opt.λmin
        λmax = cache.opt.λmax
        μmin = cache.opt.μmin
        μmax = cache.opt.μmax
        ϵ = cache.opt.ϵ

        λ = zeros(eltype(cache.u0), sum(eq_inds))
        μ = zeros(eltype(cache.u0), sum(ineq_inds))

        cons_tmp = zeros(eltype(cache.u0), length(cache.lcons))
        cache.f.cons(cons_tmp, cache.u0)
        ρ = max(1e-6, min(10, 2 * (abs(cache.f(cache.u0, cache.p))) / norm(cons_tmp)))

        # Augmented Lagrangian objective: objective value plus multiplier and
        # quadratic penalty terms for the equality and inequality violations.
        _loss = function (θ)
            x = cache.f(θ, cache.p)
            cons_tmp .= zero(eltype(θ))
            cache.f.cons(cons_tmp, θ)
            cons_tmp[eq_inds] .= cons_tmp[eq_inds] - cache.lcons[eq_inds]
            cons_tmp[ineq_inds] .= cons_tmp[ineq_inds] .- cache.ucons[ineq_inds]
            opt_state = Optimization.OptimizationState(u = θ, objective = x[1])
            if cache.callback(opt_state, x...)
                error("Optimization halted by callback.")
            end
            return x[1] + sum(@. λ * cons_tmp[eq_inds] + ρ / 2 * (cons_tmp[eq_inds] .^ 2)) +
                   1 / (2 * ρ) * sum((max.(Ref(0.0), μ .+ (ρ .* cons_tmp[ineq_inds]))) .^ 2)
        end

        prev_eqcons = zero(λ)
        θ = cache.u0
        β = max.(cons_tmp[ineq_inds], Ref(0.0))
        prevβ = zero(β)
        eqidxs = [eq_inds[i] > 0 ? i : nothing for i in eachindex(ineq_inds)]
        ineqidxs = [ineq_inds[i] > 0 ? i : nothing for i in eachindex(ineq_inds)]
        eqidxs = eqidxs[eqidxs .!= nothing]
        ineqidxs = ineqidxs[ineqidxs .!= nothing]
        # Gradient of the augmented Lagrangian: objective gradient plus constraint
        # Jacobian rows weighted by the multipliers and penalty terms.
        function aug_grad(G, θ)
            cache.f.grad(G, θ)
            if !isnothing(cache.f.cons_jac_prototype)
                J = Float64.(cache.f.cons_jac_prototype)
            else
                J = zeros((length(cache.lcons), length(θ)))
            end
            cache.f.cons_j(J, θ)
            __tmp = zero(cons_tmp)
            cache.f.cons(__tmp, θ)
            __tmp[eq_inds] .= __tmp[eq_inds] .- cache.lcons[eq_inds]
            __tmp[ineq_inds] .= __tmp[ineq_inds] .- cache.ucons[ineq_inds]
            G .+= sum(
                λ[i] .* J[idx, :] + ρ * (__tmp[idx] .* J[idx, :])
                for (i, idx) in enumerate(eqidxs);
                init = zero(G)) # should be a jvp
            G .+= sum(
                1 / ρ * (max.(Ref(0.0), μ[i] .+ (ρ .* __tmp[idx])) .* J[idx, :])
                for (i, idx) in enumerate(ineqidxs);
                init = zero(G)) # should be a jvp
        end

        opt_ret = ReturnCode.MaxIters
        n = length(cache.u0)

        sol = solve(....)

        solver_kwargs = Base.structdiff(solver_kwargs, (; lb = nothing, ub = nothing))

        # Outer augmented Lagrangian loop: run the inner solver, then update the
        # multipliers λ and μ and, if progress stalls, grow the penalty ρ.
        for i in 1:maxiters
            prev_eqcons .= cons_tmp[eq_inds] .- cache.lcons[eq_inds]
            prevβ .= copy(β)
            res = optimizer(_loss, aug_grad, θ, bounds; solver_kwargs...,
                m = cache.opt.m, pgtol = sqrt(ϵ), maxiter = maxiters / 100)
            # @show res[2]
            # @show res[1]
            # @show cons_tmp
            # @show λ
            # @show β
            # @show μ
            # @show ρ
            θ = res[2]
            cons_tmp .= 0.0
            cache.f.cons(cons_tmp, θ)
            λ = max.(min.(λmax, λ .+ ρ * (cons_tmp[eq_inds] .- cache.lcons[eq_inds])), λmin)
            β = max.(cons_tmp[ineq_inds], -1 .* μ ./ ρ)
            μ = min.(μmax, max.(μ .+ ρ * cons_tmp[ineq_inds], μmin))
            if max(norm(cons_tmp[eq_inds] .- cache.lcons[eq_inds], Inf), norm(β, Inf)) >
               τ * max(norm(prev_eqcons, Inf), norm(prevβ, Inf))
                ρ = γ * ρ
            end
            if norm(
                (cons_tmp[eq_inds] .- cache.lcons[eq_inds]) ./ cons_tmp[eq_inds], Inf) <
               ϵ && norm(β, Inf) < ϵ
                opt_ret = ReturnCode.Success
                break
            end
        end
    end

    stats = Optimization.OptimizationStats(; iterations = maxiters,
        time = 0.0, fevals = maxiters, gevals = maxiters)
    return SciMLBase.build_solution(
        cache, cache.opt, res[2], cache.f(res[2], cache.p)[1],
        stats = stats, retcode = opt_ret)
end
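
For context, a constrained solve that would exercise this path might look like the following sketch. This usage example is not part of the commit and assumes the algorithm is exposed as `Optimization.LBFGS()`, with gradients and constraint Jacobians supplied by an AD backend:

# Hypothetical usage sketch, not part of this commit.
using Optimization, ForwardDiff

rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
cons(res, x, p) = (res[1] = x[1]^2 + x[2]^2; nothing)  # single constraint value

optf = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff(); cons = cons)
prob = OptimizationProblem(optf, [0.5, 0.5], [1.0, 100.0];
    lcons = [1.0], ucons = [1.0])  # equality constraint: x1^2 + x2^2 == 1
sol = solve(prob, Optimization.LBFGS(), maxiters = 100)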
