
Commit 9086bfd

add auglag tests and run formatter
1 parent a2598c3 commit 9086bfd

7 files changed  (+145 -135 lines changed)


docs/src/index.md  (+11 -11)

@@ -88,7 +88,7 @@ to add the specific wrapper packages.
   - Second order
   - Zeroth order
   - Box Constraints
-  - Constrained 🟡
+  - Constrained
   - <strong>Global Methods</strong>
   - Zeroth order
   - Unconstrained
@@ -126,21 +126,21 @@ to add the specific wrapper packages.
   - Zeroth order
   - Unconstrained
   - Box Constraints
-  - Constrained 🟡
+  - Constrained
 </details>
 <details>
 <summary><strong>NLopt</strong></summary>
   - <strong>Local Methods</strong>
   - First order
   - Zeroth order
-  - Second order 🟡
+  - Second order
   - Box Constraints
-  - Local Constrained 🟡
+  - Local Constrained
   - <strong>Global Methods</strong>
   - Zeroth order
   - First order
   - Unconstrained
-  - Constrained 🟡
+  - Constrained
 </details>
 <details>
 <summary><strong>Optim</strong></summary>
@@ -158,21 +158,21 @@ to add the specific wrapper packages.
 <details>
 <summary><strong>PRIMA</strong></summary>
   - <strong>Local Methods</strong>
-  - Derivative-Free:
+  - Derivative-Free:
   - **Constraints**
-  - Box Constraints:
-  - Local Constrained:
+  - Box Constraints:
+  - Local Constrained:
 </details>
 <details>
 <summary><strong>QuadDIRECT</strong></summary>
   - **Constraints**
-  - Box Constraints:
+  - Box Constraints:
   - <strong>Global Methods</strong>
-  - Unconstrained:
+  - Unconstrained:
 </details>
 ```

-🟡 = supported in downstream library but not yet implemented in `Optimization.jl`; PR to add this functionality are welcome
+= supported in downstream library but not yet implemented in `Optimization.jl`; PR to add this functionality are welcome

 ## Citation

docs/src/optimization_packages/optimization.md  (+20 -20)

@@ -4,28 +4,28 @@ There are some solvers that are available in the Optimization.jl package directl

 ## Methods

-  - `LBFGS`: The popular quasi-Newton method that leverages limited memory BFGS approximation of the inverse of the Hessian. Through a wrapper over the [L-BFGS-B](https://users.iems.northwestern.edu/%7Enocedal/lbfgsb.html) fortran routine accessed from the [LBFGSB.jl](https://github.com/Gnimuc/LBFGSB.jl/) package. It directly supports box-constraints.
-
-    This can also handle arbitrary non-linear constraints through a Augmented Lagrangian method with bounds constraints described in 17.4 of Numerical Optimization by Nocedal and Wright. Thus serving as a general-purpose nonlinear optimization solver available directly in Optimization.jl.
+  - `LBFGS`: The popular quasi-Newton method that leverages limited memory BFGS approximation of the inverse of the Hessian. Through a wrapper over the [L-BFGS-B](https://users.iems.northwestern.edu/%7Enocedal/lbfgsb.html) fortran routine accessed from the [LBFGSB.jl](https://github.com/Gnimuc/LBFGSB.jl/) package. It directly supports box-constraints.
+
+    This can also handle arbitrary non-linear constraints through a Augmented Lagrangian method with bounds constraints described in 17.4 of Numerical Optimization by Nocedal and Wright. Thus serving as a general-purpose nonlinear optimization solver available directly in Optimization.jl.

-  - `Sophia`: Based on the recent paper https://arxiv.org/abs/2305.14342. It incorporates second order information in the form of the diagonal of the Hessian matrix hence avoiding the need to compute the complete hessian. It has been shown to converge faster than other first order methods such as Adam and SGD.
+  - `Sophia`: Based on the recent paper https://arxiv.org/abs/2305.14342. It incorporates second order information in the form of the diagonal of the Hessian matrix hence avoiding the need to compute the complete hessian. It has been shown to converge faster than other first order methods such as Adam and SGD.
+
+      + `solve(problem, Sophia(; η, βs, ϵ, λ, k, ρ))`

-      + `solve(problem, Sophia(; η, βs, ϵ, λ, k, ρ))`
-
-      + `η` is the learning rate
-      + `βs` are the decay of momentums
-      + `ϵ` is the epsilon value
-      + `λ` is the weight decay parameter
-      + `k` is the number of iterations to re-compute the diagonal of the Hessian matrix
-      + `ρ` is the momentum
-      + Defaults:
-
-          * `η = 0.001`
-          * `βs = (0.9, 0.999)`
-          * `ϵ = 1e-8`
-          * `λ = 0.1`
-          * `k = 10`
-          * `ρ = 0.04`
+      + `η` is the learning rate
+      + `βs` are the decay of momentums
+      + `ϵ` is the epsilon value
+      + `λ` is the weight decay parameter
+      + `k` is the number of iterations to re-compute the diagonal of the Hessian matrix
+      + `ρ` is the momentum
+      + Defaults:
+
+          * `η = 0.001`
+          * `βs = (0.9, 0.999)`
+          * `ϵ = 1e-8`
+          * `λ = 0.1`
+          * `k = 10`
+          * `ρ = 0.04`

 ## Examples
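
As a usage illustration of the two native solvers documented in this file, here is a minimal sketch (not part of the commit; the Rosenbrock objective and the ForwardDiff backend are assumptions) of how `Optimization.LBFGS` and `Optimization.Sophia` are typically invoked:

```julia
using Optimization, ForwardDiff

# Hypothetical test objective; any differentiable f(u, p) works here.
rosenbrock(u, p) = (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2
optf = OptimizationFunction(rosenbrock, AutoForwardDiff())

# Native L-BFGS wrapper with box constraints.
prob = OptimizationProblem(optf, zeros(2), [1.0, 100.0],
    lb = [-1.0, -1.0], ub = [2.0, 2.0])
sol_lbfgs = solve(prob, Optimization.LBFGS(), maxiters = 100)

# Sophia with its documented defaults spelled out explicitly.
sol_sophia = solve(OptimizationProblem(optf, zeros(2), [1.0, 100.0]),
    Optimization.Sophia(; η = 0.001, βs = (0.9, 0.999), ϵ = 1e-8, λ = 0.1, k = 10, ρ = 0.04),
    maxiters = 1000)
```

As documented above, supplying `lcons`/`ucons` on the `OptimizationProblem` makes `Optimization.LBFGS` handle general nonlinear constraints through the augmented Lagrangian path rather than plain box-constrained L-BFGS-B.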

lib/OptimizationOptimJL/src/OptimizationOptimJL.jl  (+1 -2)

@@ -38,13 +38,12 @@ function __map_optimizer_args(cache::OptimizationCache,
         abstol::Union{Number, Nothing} = nothing,
         reltol::Union{Number, Nothing} = nothing,
         kwargs...)
-
     mapped_args = (; extended_trace = true, kwargs...)

     if !isnothing(abstol)
         mapped_args = (; mapped_args..., f_abstol = abstol)
     end
-
+
     if !isnothing(callback)
         mapped_args = (; mapped_args..., callback = callback)
     end

src/auglag.jl  (+98 -96)

@@ -1,5 +1,5 @@
 @kwdef struct AugLag
-    inner
+    inner::Any
     τ = 0.5
     γ = 10.0
     λmin = -1e20
@@ -73,107 +73,109 @@ function SciMLBase.__solve(cache::OptimizationCache{
         P,
         C
 }
-    maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)
-
-    local x
-
-    solver_kwargs = __map_optimizer_args(cache, cache.opt; maxiters, cache.solver_args...)
-
-    if !isnothing(cache.f.cons)
-        eq_inds = [cache.lcons[i] == cache.ucons[i] for i in eachindex(cache.lcons)]
-        ineq_inds = (!).(eq_inds)
-
-        τ = cache.opt.τ
-        γ = cache.opt.γ
-        λmin = cache.opt.λmin
-        λmax = cache.opt.λmax
-        μmin = cache.opt.μmin
-        μmax = cache.opt.μmax
-        ϵ = cache.opt.ϵ
-
-        λ = zeros(eltype(cache.u0), sum(eq_inds))
-        μ = zeros(eltype(cache.u0), sum(ineq_inds))
-
-        cons_tmp = zeros(eltype(cache.u0), length(cache.lcons))
-        cache.f.cons(cons_tmp, cache.u0)
-        ρ = max(1e-6, min(10, 2 * (abs(cache.f(cache.u0, iterate(cache.p)[1]))) / norm(cons_tmp)))
-
-        _loss = function (θ, p = cache.p)
-            x = cache.f(θ, p)
-            cons_tmp .= zero(eltype(θ))
-            cache.f.cons(cons_tmp, θ)
-            cons_tmp[eq_inds] .= cons_tmp[eq_inds] - cache.lcons[eq_inds]
-            cons_tmp[ineq_inds] .= cons_tmp[ineq_inds] .- cache.ucons[ineq_inds]
-            opt_state = Optimization.OptimizationState(u = θ, objective = x[1])
-            if cache.callback(opt_state, x...)
-                error("Optimization halted by callback.")
+    maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)
+
+    local x
+
+    solver_kwargs = __map_optimizer_args(cache, cache.opt; maxiters, cache.solver_args...)
+
+    if !isnothing(cache.f.cons)
+        eq_inds = [cache.lcons[i] == cache.ucons[i] for i in eachindex(cache.lcons)]
+        ineq_inds = (!).(eq_inds)
+
+        τ = cache.opt.τ
+        γ = cache.opt.γ
+        λmin = cache.opt.λmin
+        λmax = cache.opt.λmax
+        μmin = cache.opt.μmin
+        μmax = cache.opt.μmax
+        ϵ = cache.opt.ϵ
+
+        λ = zeros(eltype(cache.u0), sum(eq_inds))
+        μ = zeros(eltype(cache.u0), sum(ineq_inds))
+
+        cons_tmp = zeros(eltype(cache.u0), length(cache.lcons))
+        cache.f.cons(cons_tmp, cache.u0)
+        ρ = max(1e-6,
+            min(10, 2 * (abs(cache.f(cache.u0, iterate(cache.p)[1]))) / norm(cons_tmp)))
+
+        _loss = function (θ, p = cache.p)
+            x = cache.f(θ, p)
+            cons_tmp .= zero(eltype(θ))
+            cache.f.cons(cons_tmp, θ)
+            cons_tmp[eq_inds] .= cons_tmp[eq_inds] - cache.lcons[eq_inds]
+            cons_tmp[ineq_inds] .= cons_tmp[ineq_inds] .- cache.ucons[ineq_inds]
+            opt_state = Optimization.OptimizationState(u = θ, objective = x[1])
+            if cache.callback(opt_state, x...)
+                error("Optimization halted by callback.")
+            end
+            return x[1] + sum(@. λ * cons_tmp[eq_inds] + ρ / 2 * (cons_tmp[eq_inds] .^ 2)) +
+                   1 / (2 * ρ) * sum((max.(Ref(0.0), μ .+ (ρ .* cons_tmp[ineq_inds]))) .^ 2)
             end
-            return x[1] + sum(@. λ * cons_tmp[eq_inds] + ρ / 2 * (cons_tmp[eq_inds] .^ 2)) +
-                   1 / (2 * ρ) * sum((max.(Ref(0.0), μ .+ (ρ .* cons_tmp[ineq_inds]))) .^ 2)
-        end

-        prev_eqcons = zero(λ)
-        θ = cache.u0
-        β = max.(cons_tmp[ineq_inds], Ref(0.0))
-        prevβ = zero(β)
-        eqidxs = [eq_inds[i] > 0 ? i : nothing for i in eachindex(ineq_inds)]
-        ineqidxs = [ineq_inds[i] > 0 ? i : nothing for i in eachindex(ineq_inds)]
-        eqidxs = eqidxs[eqidxs .!= nothing]
-        ineqidxs = ineqidxs[ineqidxs .!= nothing]
-        function aug_grad(G, θ, p)
-            cache.f.grad(G, θ, p)
-            if !isnothing(cache.f.cons_jac_prototype)
-                J = Float64.(cache.f.cons_jac_prototype)
-            else
-                J = zeros((length(cache.lcons), length(θ)))
+        prev_eqcons = zero(λ)
+        θ = cache.u0
+        β = max.(cons_tmp[ineq_inds], Ref(0.0))
+        prevβ = zero(β)
+        eqidxs = [eq_inds[i] > 0 ? i : nothing for i in eachindex(ineq_inds)]
+        ineqidxs = [ineq_inds[i] > 0 ? i : nothing for i in eachindex(ineq_inds)]
+        eqidxs = eqidxs[eqidxs .!= nothing]
+        ineqidxs = ineqidxs[ineqidxs .!= nothing]
+        function aug_grad(G, θ, p)
+            cache.f.grad(G, θ, p)
+            if !isnothing(cache.f.cons_jac_prototype)
+                J = Float64.(cache.f.cons_jac_prototype)
+            else
+                J = zeros((length(cache.lcons), length(θ)))
+            end
+            cache.f.cons_j(J, θ)
+            __tmp = zero(cons_tmp)
+            cache.f.cons(__tmp, θ)
+            __tmp[eq_inds] .= __tmp[eq_inds] .- cache.lcons[eq_inds]
+            __tmp[ineq_inds] .= __tmp[ineq_inds] .- cache.ucons[ineq_inds]
+            G .+= sum(
+                λ[i] .* J[idx, :] + ρ * (__tmp[idx] .* J[idx, :])
+                for (i, idx) in enumerate(eqidxs);
+                init = zero(G)) #should be jvp
+            G .+= sum(
+                1 / ρ * (max.(Ref(0.0), μ[i] .+ (ρ .* __tmp[idx])) .* J[idx, :])
+                for (i, idx) in enumerate(ineqidxs);
+                init = zero(G)) #should be jvp
             end
-            cache.f.cons_j(J, θ)
-            __tmp = zero(cons_tmp)
-            cache.f.cons(__tmp, θ)
-            __tmp[eq_inds] .= __tmp[eq_inds] .- cache.lcons[eq_inds]
-            __tmp[ineq_inds] .= __tmp[ineq_inds] .- cache.ucons[ineq_inds]
-            G .+= sum(
-                λ[i] .* J[idx, :] + ρ * (__tmp[idx] .* J[idx, :])
-                for (i, idx) in enumerate(eqidxs);
-                init = zero(G)) #should be jvp
-            G .+= sum(
-                1 / ρ * (max.(Ref(0.0), μ[i] .+ (ρ .* __tmp[idx])) .* J[idx, :])
-                for (i, idx) in enumerate(ineqidxs);
-                init = zero(G)) #should be jvp
-        end

-        opt_ret = ReturnCode.MaxIters
-        n = length(cache.u0)
-
-        augprob = OptimizationProblem(OptimizationFunction(_loss; grad = aug_grad), cache.u0, cache.p)
-
-        solver_kwargs = Base.structdiff(solver_kwargs, (; lb = nothing, ub = nothing))
-
-        for i in 1:(maxiters/10)
-            prev_eqcons .= cons_tmp[eq_inds] .- cache.lcons[eq_inds]
-            prevβ .= copy(β)
-            res = solve(augprob, cache.opt.inner, maxiters = maxiters / 10)
-            θ = res.u
-            cons_tmp .= 0.0
-            cache.f.cons(cons_tmp, θ)
-            λ = max.(min.(λmax, λ .+ ρ * (cons_tmp[eq_inds] .- cache.lcons[eq_inds])), λmin)
-            β = max.(cons_tmp[ineq_inds], -1 .* μ ./ ρ)
-            μ = min.(μmax, max.(μ .+ ρ * cons_tmp[ineq_inds], μmin))
-            if max(norm(cons_tmp[eq_inds] .- cache.lcons[eq_inds], Inf), norm(β, Inf)) >
-               τ * max(norm(prev_eqcons, Inf), norm(prevβ, Inf))
-                ρ = γ * ρ
-            end
-            if norm(
-                (cons_tmp[eq_inds] .- cache.lcons[eq_inds]) ./ cons_tmp[eq_inds], Inf) <
-               ϵ && norm(β, Inf) < ϵ
-                opt_ret = ReturnCode.Success
-                break
+        opt_ret = ReturnCode.MaxIters
+        n = length(cache.u0)
+
+        augprob = OptimizationProblem(
+            OptimizationFunction(_loss; grad = aug_grad), cache.u0, cache.p)
+
+        solver_kwargs = Base.structdiff(solver_kwargs, (; lb = nothing, ub = nothing))
+
+        for i in 1:(maxiters / 10)
+            prev_eqcons .= cons_tmp[eq_inds] .- cache.lcons[eq_inds]
+            prevβ .= copy(β)
+            res = solve(augprob, cache.opt.inner, maxiters = maxiters / 10)
+            θ = res.u
+            cons_tmp .= 0.0
+            cache.f.cons(cons_tmp, θ)
+            λ = max.(min.(λmax, λ .+ ρ * (cons_tmp[eq_inds] .- cache.lcons[eq_inds])), λmin)
+            β = max.(cons_tmp[ineq_inds], -1 .* μ ./ ρ)
+            μ = min.(μmax, max.(μ .+ ρ * cons_tmp[ineq_inds], μmin))
+            if max(norm(cons_tmp[eq_inds] .- cache.lcons[eq_inds], Inf), norm(β, Inf)) >
+               τ * max(norm(prev_eqcons, Inf), norm(prevβ, Inf))
+                ρ = γ * ρ
+            end
+            if norm(
+                (cons_tmp[eq_inds] .- cache.lcons[eq_inds]) ./ cons_tmp[eq_inds], Inf) <
+               ϵ && norm(β, Inf) < ϵ
+                opt_ret = ReturnCode.Success
+                break
+            end
             end
         end
-    end
-    stats = Optimization.OptimizationStats(; iterations = maxiters,
+        stats = Optimization.OptimizationStats(; iterations = maxiters,
             time = 0.0, fevals = maxiters, gevals = maxiters)
-    return SciMLBase.build_solution(
+        return SciMLBase.build_solution(
             cache, cache.opt, θ, x,
             stats = stats, retcode = opt_ret)
+    end
 end
-end
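
For orientation, the `_loss` closure in the hunk above assembles the standard bound-constrained augmented Lagrangian of section 17.4 in Nocedal and Wright, and the outer loop applies the usual multiplier and penalty updates. The following standalone sketch restates that scheme outside the package; the constraint functions (`g` for equalities, `h` for inequalities) and several of the default constants are illustrative rather than taken from this commit:

```julia
using LinearAlgebra

# Augmented Lagrangian value:
#   L_ρ(θ) = f(θ) + Σᵢ λᵢ gᵢ(θ) + (ρ/2) Σᵢ gᵢ(θ)² + (1/(2ρ)) Σⱼ max(0, μⱼ + ρ hⱼ(θ))²
function auglag_objective(f, g, h, θ, λ, μ, ρ)
    geq = g(θ)      # equality residuals, zero when feasible
    hin = h(θ)      # inequality residuals, nonpositive when feasible
    return f(θ) + sum(λ .* geq .+ (ρ / 2) .* geq .^ 2; init = 0.0) +
           (1 / (2ρ)) * sum(max.(0.0, μ .+ ρ .* hin) .^ 2; init = 0.0)
end

# One outer-iteration update of multipliers and penalty, mirroring the loop above.
function auglag_update(g, h, θ, λ, μ, ρ, prev_viol;
        τ = 0.5, γ = 10.0, λmin = -1e20, λmax = 1e20, μmin = 0.0, μmax = 1e20)
    λnew = clamp.(λ .+ ρ .* g(θ), λmin, λmax)   # equality multipliers
    μnew = clamp.(μ .+ ρ .* h(θ), μmin, μmax)   # inequality multipliers
    viol = max(norm(g(θ), Inf), norm(max.(h(θ), -μ ./ ρ), Inf))
    ρnew = viol > τ * prev_viol ? γ * ρ : ρ     # grow the penalty if progress stalls
    return λnew, μnew, ρnew, viol
end
```

In the package code, the inner minimisation of this objective is delegated to the `inner` solver stored on `AugLag` (for example `Adam()` in the new test), and the outer loop runs for roughly `maxiters / 10` iterations.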

test/diffeqfluxtests.jl  (+2 -1)

@@ -70,7 +70,8 @@ ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))
 dudt2 = Lux.Chain(x -> x .^ 3,
     Lux.Dense(2, 50, tanh),
     Lux.Dense(50, 2))
-prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps, abstol = 1e-8, reltol = 1e-8)
+prob_neuralode = NeuralODE(
+    dudt2, tspan, Tsit5(), saveat = tsteps, abstol = 1e-8, reltol = 1e-8)
 pp, st = Lux.setup(rng, dudt2)
 pp = ComponentArray(pp)

test/lbfgsb.jl renamed to test/native.jl  (+12 -4)

@@ -28,7 +28,7 @@ prob = OptimizationProblem(optf, x0, lcons = [1.0, -Inf],

 using MLUtils, OptimizationOptimisers

-x0 = -pi:0.001:pi
+x0 = (-pi):0.001:pi
 y0 = sin.(x0)
 data = MLUtils.DataLoader((x0, y0), batchsize = 100)
 function loss(coeffs, data)
@@ -44,12 +44,20 @@ end
 optf = OptimizationFunction(loss, AutoSparseForwardDiff(), cons = cons1)
 callback = (st, l) -> (@show l; return false)

-prob = OptimizationProblem(optf, rand(5), (x0, y0), lcons = [-0.5], ucons = [0.5], lb = [-10.0, -10.0, -10.0, -10.0, -10.0], ub = [10.0, 10.0, 10.0, 10.0, 10.0])
+initpars = rand(5)
+l0 = optf(initpars, (x0, y0))
+prob = OptimizationProblem(optf, initpars, (x0, y0), lcons = [-Inf], ucons = [0.5],
+    lb = [-10.0, -10.0, -10.0, -10.0, -10.0], ub = [10.0, 10.0, 10.0, 10.0, 10.0])
 opt1 = solve(prob, Optimization.LBFGS(), maxiters = 1000, callback = callback)
+@test opt1.objective < l0

-prob = OptimizationProblem(optf, rand(5), data, lcons = [0.0], ucons = [0.0], lb = [-10.0, -10.0, -10.0, -10.0, -10.0], ub = [10.0, 10.0, 10.0, 10.0, 10.0])
-opt = solve(prob, Optimization.AugLag(; inner = Adam()), maxiters = 500, callback = callback)
+prob = OptimizationProblem(optf, initpars, data, lcons = [-Inf], ucons = [1],
+    lb = [-10.0, -10.0, -10.0, -10.0, -10.0], ub = [10.0, 10.0, 10.0, 10.0, 10.0])
+opt = solve(
+    prob, Optimization.AugLag(; inner = Adam()), maxiters = 10000, callback = callback)
+@test opt.objective < l0

 optf1 = OptimizationFunction(loss, AutoSparseForwardDiff())
 prob1 = OptimizationProblem(optf1, rand(5), data)
 sol1 = solve(prob1, OptimizationOptimisers.Adam(), maxiters = 1000, callback = callback)
+@test sol1.objective < l0
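
The hunks above omit the `loss` and `cons1` definitions from earlier in the renamed file. A self-contained sketch of the kind of setup this test exercises — a mini-batched polynomial fit to sine data with a single inequality constraint — is given below; the polynomial form, the constraint, and the use of `AutoForwardDiff` (instead of the test's `AutoSparseForwardDiff`) are illustrative guesses, not the test's actual definitions:

```julia
using Optimization, OptimizationOptimisers, MLUtils, ForwardDiff, Test

x0 = (-pi):0.001:pi
y0 = sin.(x0)
data = MLUtils.DataLoader((x0, y0), batchsize = 100)

# Hypothetical mini-batch loss: mean squared error of a degree-4 polynomial fit.
function loss(coeffs, data)
    x, y = data
    ŷ = evalpoly.(x, Ref(coeffs))
    return sum(abs2, ŷ .- y) / length(x)
end

# Hypothetical scalar inequality constraint on the coefficient norm.
function cons1(res, coeffs, p)
    res[1] = sum(abs2, coeffs)
    return nothing
end

optf = OptimizationFunction(loss, AutoForwardDiff(), cons = cons1)
callback = (st, l) -> (@show l; return false)

initpars = rand(5)
prob = OptimizationProblem(optf, initpars, data, lcons = [-Inf], ucons = [1.0],
    lb = fill(-10.0, 5), ub = fill(10.0, 5))
opt = solve(prob, Optimization.AugLag(; inner = Adam()), maxiters = 10000,
    callback = callback)
@test opt.objective < loss(initpars, first(data))
```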

test/runtests.jl  (+1 -1)

@@ -36,7 +36,7 @@ end
     include("AD_performance_regression.jl")
 end
 @safetestset "Optimization" begin
-    include("lbfgsb.jl")
+    include("native.jl")
 end
 @safetestset "Mini batching" begin
     include("minibatch.jl")
