Skip to content

Commit a2598c3

Browse files
Work with stochastic optimizers too
1 parent 370286b commit a2598c3

File tree

3 files changed

+59
-33
lines changed

3 files changed

+59
-33
lines changed

src/Optimization.jl

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ include("utils.jl")
2424
include("state.jl")
2525
include("lbfgsb.jl")
2626
include("sophia.jl")
27+
include("auglag.jl")
2728

2829
export solve
2930

src/auglag.jl

+30-33
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
1-
2-
SciMLBase.supports_opt_cache_interface(::LBFGS) = true
3-
SciMLBase.allowsbounds(::LBFGS) = true
4-
SciMLBase.requiresgradient(::LBFGS) = true
5-
SciMLBase.allowsconstraints(::LBFGS) = true
6-
SciMLBase.requiresconsjac(::LBFGS) = true
7-
8-
function task_message_to_string(task::Vector{UInt8})
9-
return String(task)
1+
"""
    AugLag(; inner, τ = 0.5, γ = 10.0, λmin = -1e20, λmax = 1e20,
           μmin = 0.0, μmax = 1e20, ϵ = 1e-8)

Keyword-constructible options object selecting the `AugLag` solver.

`inner` has no default and must always be supplied; every other field falls
back to the default shown in the signature. All fields are untyped (`Any`).

# Fields
- `inner`: the optimizer handed to the nested `solve` call.
- `τ`, `γ`: numeric tuning parameters (defaults `0.5` and `10.0`).
  NOTE(review): presumably the progress threshold and penalty growth factor
  of the augmented-Lagrangian update — confirm against the solve loop.
- `λmin`, `λmax`: NOTE(review): presumably the clamp range applied to the
  equality multipliers `λ` — confirm.
- `μmin`, `μmax`: NOTE(review): presumably the clamp range applied to the
  inequality multipliers `μ` — confirm.
- `ϵ`: tolerance (default `1e-8`). NOTE(review): exact use not verified here.
"""
Base.@kwdef struct AugLag
    # Required: there is deliberately no default for the wrapped optimizer.
    inner::Any
    τ::Any = 0.5
    γ::Any = 10.0
    λmin::Any = -1e20
    λmax::Any = 1e20
    μmin::Any = 0.0
    μmax::Any = 1e20
    ϵ::Any = 1e-8
end
1111

12-
function __map_optimizer_args(cache::Optimization.OptimizationCache, opt::LBFGS;
12+
# SciMLBase trait methods for `AugLag`; every one of them returns `true`.
# These mirror the set previously declared for `LBFGS` in this file.
SciMLBase.supports_opt_cache_interface(::AugLag) = true
13+
# Box bounds are accepted.
SciMLBase.allowsbounds(::AugLag) = true
14+
# A gradient of the objective is required.
SciMLBase.requiresgradient(::AugLag) = true
15+
# Constraints are accepted.
SciMLBase.allowsconstraints(::AugLag) = true
16+
# A constraint Jacobian is required.
SciMLBase.requiresconsjac(::AugLag) = true
17+
18+
function __map_optimizer_args(cache::Optimization.OptimizationCache, opt::AugLag;
1319
callback = nothing,
1420
maxiters::Union{Number, Nothing} = nothing,
1521
maxtime::Union{Number, Nothing} = nothing,
@@ -62,7 +68,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
6268
UC,
6369
S,
6470
O <:
65-
LBFGS,
71+
AugLag,
6672
D,
6773
P,
6874
C
@@ -90,10 +96,10 @@ if !isnothing(cache.f.cons)
9096

9197
cons_tmp = zeros(eltype(cache.u0), length(cache.lcons))
9298
cache.f.cons(cons_tmp, cache.u0)
93-
ρ = max(1e-6, min(10, 2 * (abs(cache.f(cache.u0, cache.p))) / norm(cons_tmp)))
99+
ρ = max(1e-6, min(10, 2 * (abs(cache.f(cache.u0, iterate(cache.p)[1]))) / norm(cons_tmp)))
94100

95-
_loss = function (θ)
96-
x = cache.f(θ, cache.p)
101+
_loss = function (θ, p = cache.p)
102+
x = cache.f(θ, p)
97103
cons_tmp .= zero(eltype(θ))
98104
cache.f.cons(cons_tmp, θ)
99105
cons_tmp[eq_inds] .= cons_tmp[eq_inds] - cache.lcons[eq_inds]
@@ -114,8 +120,8 @@ if !isnothing(cache.f.cons)
114120
ineqidxs = [ineq_inds[i] > 0 ? i : nothing for i in eachindex(ineq_inds)]
115121
eqidxs = eqidxs[eqidxs .!= nothing]
116122
ineqidxs = ineqidxs[ineqidxs .!= nothing]
117-
function aug_grad(G, θ)
118-
cache.f.grad(G, θ)
123+
function aug_grad(G, θ, p)
124+
cache.f.grad(G, θ, p)
119125
if !isnothing(cache.f.cons_jac_prototype)
120126
J = Float64.(cache.f.cons_jac_prototype)
121127
else
@@ -139,23 +145,15 @@ if !isnothing(cache.f.cons)
139145
opt_ret = ReturnCode.MaxIters
140146
n = length(cache.u0)
141147

142-
sol = solve(....)
148+
augprob = OptimizationProblem(OptimizationFunction(_loss; grad = aug_grad), cache.u0, cache.p)
143149

144150
solver_kwargs = Base.structdiff(solver_kwargs, (; lb = nothing, ub = nothing))
145151

146-
for i in 1:maxiters
152+
for i in 1:(maxiters/10)
147153
prev_eqcons .= cons_tmp[eq_inds] .- cache.lcons[eq_inds]
148154
prevβ .= copy(β)
149-
res = optimizer(_loss, aug_grad, θ, bounds; solver_kwargs...,
150-
m = cache.opt.m, pgtol = sqrt(ϵ), maxiter = maxiters / 100)
151-
# @show res[2]
152-
# @show res[1]
153-
# @show cons_tmp
154-
# @show λ
155-
# @show β
156-
# @show μ
157-
# @show ρ
158-
θ = res[2]
155+
res = solve(augprob, cache.opt.inner, maxiters = maxiters / 10)
156+
θ = res.u
159157
cons_tmp .= 0.0
160158
cache.f.cons(cons_tmp, θ)
161159
λ = max.(min.(λmax, λ .+ ρ * (cons_tmp[eq_inds] .- cache.lcons[eq_inds])), λmin)
@@ -172,11 +170,10 @@ if !isnothing(cache.f.cons)
172170
break
173171
end
174172
end
175-
end
176-
177-
stats = Optimization.OptimizationStats(; iterations = maxiters,
173+
stats = Optimization.OptimizationStats(; iterations = maxiters,
178174
time = 0.0, fevals = maxiters, gevals = maxiters)
179-
return SciMLBase.build_solution(
180-
cache, cache.opt, res[2], cache.f(res[2], cache.p)[1],
175+
return SciMLBase.build_solution(
176+
cache, cache.opt, θ, x,
181177
stats = stats, retcode = opt_ret)
178+
end
182179
end

test/lbfgsb.jl

+28
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,31 @@ prob = OptimizationProblem(optf, x0, lcons = [1.0, -Inf],
2525
ub = [1.0, 1.0])
2626
@time res = solve(prob, Optimization.LBFGS(), maxiters = 100)
2727
@test res.retcode == SciMLBase.ReturnCode.Success
28+
29+
using MLUtils, OptimizationOptimisers
30+
31+
# Sample points: a uniform grid over [-π, π] with step 0.001.
x0 = -pi:0.001:pi
32+
# Targets: sin evaluated at every grid point.
y0 = sin.(x0)
33+
# Batched view of the (x0, y0) pairs, 100 samples per batch.
data = MLUtils.DataLoader((x0, y0), batchsize = 100)
34+
"""
    loss(coeffs, data)

Return the sum of squared residuals between the polynomial with coefficients
`coeffs` (as interpreted by `evalpoly`) evaluated at the points `data[1]`
and the targets `data[2]`.
"""
function loss(coeffs, data)
    xs = data[1]
    ys = data[2]
    # Evaluate the polynomial index by index over the sample points.
    ypred = map(eachindex(xs)) do i
        return evalpoly(xs[i], coeffs)
    end
    residual = ypred .- ys
    return sum(abs2, residual)
end
38+
39+
"""
    cons1(res, coeffs, p = nothing)

In-place constraint function: store `coeffs[1] * coeffs[5] - 1` in `res[1]`.
`p` is accepted but not used. Returns `nothing`.
"""
function cons1(res, coeffs, p = nothing)
    endpoint_product = coeffs[1] * coeffs[5]
    res[1] = endpoint_product - 1
    return nothing
end
43+
44+
# Constrained objective: `loss` with AutoSparseForwardDiff as the AD backend
# and `cons1` as the single constraint.
optf = OptimizationFunction(loss, AutoSparseForwardDiff(), cons = cons1)
45+
# Print the loss on every callback invocation and return `false`.
# NOTE(review): `false` presumably means "do not halt" — confirm.
callback = (st, l) -> (@show l; return false)
46+
47+
# Scenario 1: the full (x0, y0) tuple as parameters, constraint bounded to
# [-0.5, 0.5], box bounds of ±10 on all five coefficients, solved with LBFGS.
prob = OptimizationProblem(optf, rand(5), (x0, y0), lcons = [-0.5], ucons = [0.5], lb = [-10.0, -10.0, -10.0, -10.0, -10.0], ub = [10.0, 10.0, 10.0, 10.0, 10.0])
48+
opt1 = solve(prob, Optimization.LBFGS(), maxiters = 1000, callback = callback)
49+
50+
# Scenario 2: the batched `data` as parameters, constraint pinned to 0
# (lcons == ucons), same box bounds, solved with AugLag wrapping Adam.
prob = OptimizationProblem(optf, rand(5), data, lcons = [0.0], ucons = [0.0], lb = [-10.0, -10.0, -10.0, -10.0, -10.0], ub = [10.0, 10.0, 10.0, 10.0, 10.0])
51+
opt = solve(prob, Optimization.AugLag(; inner = Adam()), maxiters = 500, callback = callback)
52+
53+
# Scenario 3: unconstrained and unbounded problem on the batched `data`,
# solved directly with Adam.
# NOTE(review): none of `opt1`, `opt`, `sol1` is checked with `@test`, and the
# initial points come from unseeded `rand(5)` — consider adding assertions
# and a fixed seed.
optf1 = OptimizationFunction(loss, AutoSparseForwardDiff())
54+
prob1 = OptimizationProblem(optf1, rand(5), data)
55+
sol1 = solve(prob1, OptimizationOptimisers.Adam(), maxiters = 1000, callback = callback)

0 commit comments

Comments
 (0)