Make auglag generic and reusable with all solvers #833

Merged · 3 commits · Apr 18, 2025
22 changes: 11 additions & 11 deletions docs/src/index.md
@@ -88,7 +88,7 @@ to add the specific wrapper packages.
- Second order
- Zeroth order
- Box Constraints
- Constrained 🟡
- Constrained
- <strong>Global Methods</strong>
- Zeroth order
- Unconstrained
@@ -126,21 +126,21 @@ to add the specific wrapper packages.
- Zeroth order
- Unconstrained
- Box Constraints
- Constrained 🟡
- Constrained
</details>
<details>
<summary><strong>NLopt</strong></summary>
- <strong>Local Methods</strong>
- First order
- Zeroth order
- Second order 🟡
- Second order
- Box Constraints
- Local Constrained 🟡
- Local Constrained
- <strong>Global Methods</strong>
- Zeroth order
- First order
- Unconstrained
- Constrained 🟡
- Constrained
</details>
<details>
<summary><strong>Optim</strong></summary>
@@ -158,21 +158,21 @@ to add the specific wrapper packages.
<details>
<summary><strong>PRIMA</strong></summary>
- <strong>Local Methods</strong>
- Derivative-Free:
- **Constraints**
- Box Constraints:
- Local Constrained:
</details>
<details>
<summary><strong>QuadDIRECT</strong></summary>
- **Constraints**
- Box Constraints:
- <strong>Global Methods</strong>
- Unconstrained:
</details>
```

🟡 = supported in downstream library but not yet implemented in `Optimization.jl`; PRs to add this functionality are welcome

## Citation

40 changes: 20 additions & 20 deletions docs/src/optimization_packages/optimization.md
@@ -4,28 +4,28 @@ There are some solvers that are available in the Optimization.jl package directly

## Methods

- `LBFGS`: The popular quasi-Newton method that leverages a limited-memory BFGS approximation of the inverse of the Hessian, through a wrapper over the [L-BFGS-B](https://users.iems.northwestern.edu/%7Enocedal/lbfgsb.html) Fortran routine accessed from the [LBFGSB.jl](https://github.com/Gnimuc/LBFGSB.jl/) package. It directly supports box constraints.

  It can also handle arbitrary non-linear constraints through an Augmented Lagrangian method with bound constraints, described in Section 17.4 of Numerical Optimization by Nocedal and Wright, thus serving as a general-purpose nonlinear optimization solver available directly in Optimization.jl. See the usage sketch after this list.

- `Sophia`: Based on the recent paper https://arxiv.org/abs/2305.14342. It incorporates second-order information in the form of the diagonal of the Hessian matrix, avoiding the need to compute the complete Hessian. It has been shown to converge faster than first-order methods such as Adam and SGD. See the usage sketch after this list.

  + `solve(problem, Sophia(; η, βs, ϵ, λ, k, ρ))`

  + `η` is the learning rate
  + `βs` are the momentum decay parameters
  + `ϵ` is the epsilon value
  + `λ` is the weight decay parameter
  + `k` is the number of iterations after which the diagonal of the Hessian matrix is re-computed
  + `ρ` is the momentum
  + Defaults:

    * `η = 0.001`
    * `βs = (0.9, 0.999)`
    * `ϵ = 1e-8`
    * `λ = 0.1`
    * `k = 10`
    * `ρ = 0.04`
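
The sketch below illustrates both methods and is not taken from this PR: the Rosenbrock objective, the equality constraint, the `Optimization.AutoZygote()` backend, and all hyperparameter values are illustrative assumptions (Zygote is assumed to supply the derivatives, including the Hessian-vector products Sophia needs).

```julia
using Optimization, Zygote

rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2

# LBFGS with one nonlinear equality constraint x₁² + x₂² = 1, handled through the
# augmented Lagrangian path described above.
cons(res, x, p) = (res .= [x[1]^2 + x[2]^2])
optf = OptimizationFunction(rosenbrock, Optimization.AutoZygote(); cons = cons)
prob = OptimizationProblem(optf, zeros(2), [1.0, 100.0]; lcons = [1.0], ucons = [1.0])
sol = solve(prob, Optimization.LBFGS(), maxiters = 1000)

# Sophia on the unconstrained problem; hyperparameters mirror the documented defaults.
optf_u = OptimizationFunction(rosenbrock, Optimization.AutoZygote())
prob_u = OptimizationProblem(optf_u, zeros(2), [1.0, 100.0])
sol_u = solve(prob_u, Optimization.Sophia(η = 0.001, k = 10), maxiters = 500)
```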

## Examples

3 changes: 1 addition & 2 deletions lib/OptimizationOptimJL/src/OptimizationOptimJL.jl
@@ -38,13 +38,12 @@ function __map_optimizer_args(cache::OptimizationCache,
abstol::Union{Number, Nothing} = nothing,
reltol::Union{Number, Nothing} = nothing,
kwargs...)

mapped_args = (; extended_trace = true, kwargs...)

if !isnothing(abstol)
mapped_args = (; mapped_args..., f_abstol = abstol)
end

if !isnothing(callback)
mapped_args = (; mapped_args..., callback = callback)
end
1 change: 1 addition & 0 deletions src/Optimization.jl
@@ -24,6 +24,7 @@ include("utils.jl")
include("state.jl")
include("lbfgsb.jl")
include("sophia.jl")
include("auglag.jl")

export solve

181 changes: 181 additions & 0 deletions src/auglag.jl
@@ -0,0 +1,181 @@
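# Generic augmented Lagrangian wrapper: `inner` is the solver used for the unconstrained
# subproblems; τ and γ control when and how fast the penalty ρ grows; λmin/λmax and μmin/μmax
# clamp the equality and inequality multipliers; ϵ is the constraint-violation tolerance.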
@kwdef struct AugLag
inner::Any
τ = 0.5
γ = 10.0
λmin = -1e20
λmax = 1e20
μmin = 0.0
μmax = 1e20
ϵ = 1e-8
end

SciMLBase.supports_opt_cache_interface(::AugLag) = true
SciMLBase.allowsbounds(::AugLag) = true
SciMLBase.requiresgradient(::AugLag) = true
SciMLBase.allowsconstraints(::AugLag) = true
SciMLBase.requiresconsjac(::AugLag) = true

function __map_optimizer_args(cache::Optimization.OptimizationCache, opt::AugLag;
callback = nothing,
maxiters::Union{Number, Nothing} = nothing,
maxtime::Union{Number, Nothing} = nothing,
abstol::Union{Number, Nothing} = nothing,
reltol::Union{Number, Nothing} = nothing,
verbose::Bool = false,
kwargs...)
if !isnothing(abstol)
@warn "common abstol is currently not used by $(opt)"
end
if !isnothing(maxtime)
@warn "common maxtime is currently not used by $(opt)"
end

mapped_args = (;)

if cache.lb !== nothing && cache.ub !== nothing
mapped_args = (; mapped_args..., lb = cache.lb, ub = cache.ub)
end

if !isnothing(maxiters)
mapped_args = (; mapped_args..., maxiter = maxiters)
end

if !isnothing(reltol)
mapped_args = (; mapped_args..., pgtol = reltol)
end

return mapped_args
end

function SciMLBase.__solve(cache::OptimizationCache{
F,
RC,
LB,
UB,
LC,
UC,
S,
O,
D,
P,
C
}) where {
F,
RC,
LB,
UB,
LC,
UC,
S,
O <:
AugLag,
D,
P,
C
}
maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)

local x

solver_kwargs = __map_optimizer_args(cache, cache.opt; maxiters, cache.solver_args...)

if !isnothing(cache.f.cons)
eq_inds = [cache.lcons[i] == cache.ucons[i] for i in eachindex(cache.lcons)]
ineq_inds = (!).(eq_inds)

τ = cache.opt.τ
γ = cache.opt.γ
λmin = cache.opt.λmin
λmax = cache.opt.λmax
μmin = cache.opt.μmin
μmax = cache.opt.μmax
ϵ = cache.opt.ϵ

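# Multiplier estimates: λ for the equality constraints, μ for the inequality constraints.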
λ = zeros(eltype(cache.u0), sum(eq_inds))
μ = zeros(eltype(cache.u0), sum(ineq_inds))

cons_tmp = zeros(eltype(cache.u0), length(cache.lcons))
cache.f.cons(cons_tmp, cache.u0)
ρ = max(1e-6,
min(10, 2 * (abs(cache.f(cache.u0, iterate(cache.p)[1]))) / norm(cons_tmp)))

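# Augmented Lagrangian objective: the original objective plus multiplier terms and
# quadratic penalties on the constraint residuals.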
_loss = function (θ, p = cache.p)
x = cache.f(θ, p)
cons_tmp .= zero(eltype(θ))
cache.f.cons(cons_tmp, θ)
cons_tmp[eq_inds] .= cons_tmp[eq_inds] - cache.lcons[eq_inds]
cons_tmp[ineq_inds] .= cons_tmp[ineq_inds] .- cache.ucons[ineq_inds]
opt_state = Optimization.OptimizationState(u = θ, objective = x[1])
if cache.callback(opt_state, x...)
error("Optimization halted by callback.")
end
return x[1] + sum(@. λ * cons_tmp[eq_inds] + ρ / 2 * (cons_tmp[eq_inds] .^ 2)) +
1 / (2 * ρ) * sum((max.(Ref(0.0), μ .+ (ρ .* cons_tmp[ineq_inds]))) .^ 2)
end

prev_eqcons = zero(λ)
θ = cache.u0
β = max.(cons_tmp[ineq_inds], Ref(0.0))
prevβ = zero(β)
eqidxs = [eq_inds[i] > 0 ? i : nothing for i in eachindex(eq_inds)]
ineqidxs = [ineq_inds[i] > 0 ? i : nothing for i in eachindex(ineq_inds)]
eqidxs = eqidxs[eqidxs .!= nothing]
ineqidxs = ineqidxs[ineqidxs .!= nothing]
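# Gradient of the augmented Lagrangian: objective gradient plus constraint-Jacobian contributions.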
function aug_grad(G, θ, p)
cache.f.grad(G, θ, p)
if !isnothing(cache.f.cons_jac_prototype)
J = Float64.(cache.f.cons_jac_prototype)
else
J = zeros((length(cache.lcons), length(θ)))
end
cache.f.cons_j(J, θ)
__tmp = zero(cons_tmp)
cache.f.cons(__tmp, θ)
__tmp[eq_inds] .= __tmp[eq_inds] .- cache.lcons[eq_inds]
__tmp[ineq_inds] .= __tmp[ineq_inds] .- cache.ucons[ineq_inds]
G .+= sum(
λ[i] .* J[idx, :] + ρ * (__tmp[idx] .* J[idx, :])
for (i, idx) in enumerate(eqidxs);
init = zero(G)) #should be jvp
G .+= sum(
1 / ρ * (max.(Ref(0.0), μ[i] .+ (ρ .* __tmp[idx])) .* J[idx, :])
for (i, idx) in enumerate(ineqidxs);
init = zero(G)) #should be jvp
end

opt_ret = ReturnCode.MaxIters
n = length(cache.u0)

augprob = OptimizationProblem(
OptimizationFunction(_loss; grad = aug_grad), cache.u0, cache.p)

solver_kwargs = Base.structdiff(solver_kwargs, (; lb = nothing, ub = nothing))

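# Outer loop: solve the subproblem with the inner solver, update the multipliers,
# and grow ρ if the constraint violation has not decreased enough.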
for i in 1:(maxiters / 10)
prev_eqcons .= cons_tmp[eq_inds] .- cache.lcons[eq_inds]
prevβ .= copy(β)
res = solve(augprob, cache.opt.inner, maxiters = maxiters / 10)
θ = res.u
cons_tmp .= 0.0
cache.f.cons(cons_tmp, θ)
λ = max.(min.(λmax, λ .+ ρ * (cons_tmp[eq_inds] .- cache.lcons[eq_inds])), λmin)
β = max.(cons_tmp[ineq_inds], -1 .* μ ./ ρ)
μ = min.(μmax, max.(μ .+ ρ * cons_tmp[ineq_inds], μmin))
if max(norm(cons_tmp[eq_inds] .- cache.lcons[eq_inds], Inf), norm(β, Inf)) >
τ * max(norm(prev_eqcons, Inf), norm(prevβ, Inf))
ρ = γ * ρ
end
if norm(
(cons_tmp[eq_inds] .- cache.lcons[eq_inds]) ./ cons_tmp[eq_inds], Inf) <
ϵ && norm(β, Inf) < ϵ
opt_ret = ReturnCode.Success
break
end
end
stats = Optimization.OptimizationStats(; iterations = maxiters,
time = 0.0, fevals = maxiters, gevals = maxiters)
return SciMLBase.build_solution(
cache, cache.opt, θ, x,
stats = stats, retcode = opt_ret)
end
end
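
For reference, a minimal end-to-end sketch of using the wrapper added above, under stated assumptions: `AugLag` is reachable as `Optimization.AugLag` (its export is not shown in this diff), `Optimization.LBFGS()` is used as the inner solver, and the toy objective, constraint, and dummy parameter vector are made up for illustration.

```julia
using Optimization, Zygote

# `p` is unused by the objective but supplied because the ρ initialization above iterates over it.
objective(x, p) = (x[1] - 1.0)^2 + (x[2] - 2.0)^2
cons(res, x, p) = (res .= [x[1] + x[2]])   # single equality constraint x₁ + x₂ = 2

optf = OptimizationFunction(objective, Optimization.AutoZygote(); cons = cons)
prob = OptimizationProblem(optf, zeros(2), [0.0]; lcons = [2.0], ucons = [2.0])

# Any gradient-based method that can solve the unconstrained subproblems can be passed as `inner`.
sol = solve(prob, Optimization.AugLag(inner = Optimization.LBFGS()), maxiters = 1000)
```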
3 changes: 2 additions & 1 deletion test/diffeqfluxtests.jl
@@ -70,7 +70,8 @@ ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))
dudt2 = Lux.Chain(x -> x .^ 3,
Lux.Dense(2, 50, tanh),
Lux.Dense(50, 2))
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps, abstol = 1e-8, reltol = 1e-8)
prob_neuralode = NeuralODE(
dudt2, tspan, Tsit5(), saveat = tsteps, abstol = 1e-8, reltol = 1e-8)
pp, st = Lux.setup(rng, dudt2)
pp = ComponentArray(pp)

27 changes: 0 additions & 27 deletions test/lbfgsb.jl

This file was deleted.
