From 370286ba92ed6fb6b332bc60a5e1e4b26d11249a Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Thu, 26 Sep 2024 07:53:51 -0400 Subject: [PATCH 1/3] Copy auglag from lbfgsb.jl --- src/auglag.jl | 182 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 182 insertions(+) create mode 100644 src/auglag.jl diff --git a/src/auglag.jl b/src/auglag.jl new file mode 100644 index 000000000..e7234334a --- /dev/null +++ b/src/auglag.jl @@ -0,0 +1,182 @@ + +SciMLBase.supports_opt_cache_interface(::LBFGS) = true +SciMLBase.allowsbounds(::LBFGS) = true +SciMLBase.requiresgradient(::LBFGS) = true +SciMLBase.allowsconstraints(::LBFGS) = true +SciMLBase.requiresconsjac(::LBFGS) = true + +function task_message_to_string(task::Vector{UInt8}) + return String(task) +end + +function __map_optimizer_args(cache::Optimization.OptimizationCache, opt::LBFGS; + callback = nothing, + maxiters::Union{Number, Nothing} = nothing, + maxtime::Union{Number, Nothing} = nothing, + abstol::Union{Number, Nothing} = nothing, + reltol::Union{Number, Nothing} = nothing, + verbose::Bool = false, + kwargs...) + if !isnothing(abstol) + @warn "common abstol is currently not used by $(opt)" + end + if !isnothing(maxtime) + @warn "common abstol is currently not used by $(opt)" + end + + mapped_args = (;) + + if cache.lb !== nothing && cache.ub !== nothing + mapped_args = (; mapped_args..., lb = cache.lb, ub = cache.ub) + end + + if !isnothing(maxiters) + mapped_args = (; mapped_args..., maxiter = maxiters) + end + + if !isnothing(reltol) + mapped_args = (; mapped_args..., pgtol = reltol) + end + + return mapped_args +end + +function SciMLBase.__solve(cache::OptimizationCache{ + F, + RC, + LB, + UB, + LC, + UC, + S, + O, + D, + P, + C +}) where { + F, + RC, + LB, + UB, + LC, + UC, + S, + O <: + LBFGS, + D, + P, + C +} +maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters) + +local x + +solver_kwargs = __map_optimizer_args(cache, cache.opt; maxiters, cache.solver_args...) + +if !isnothing(cache.f.cons) + eq_inds = [cache.lcons[i] == cache.ucons[i] for i in eachindex(cache.lcons)] + ineq_inds = (!).(eq_inds) + + τ = cache.opt.τ + γ = cache.opt.γ + λmin = cache.opt.λmin + λmax = cache.opt.λmax + μmin = cache.opt.μmin + μmax = cache.opt.μmax + ϵ = cache.opt.ϵ + + λ = zeros(eltype(cache.u0), sum(eq_inds)) + μ = zeros(eltype(cache.u0), sum(ineq_inds)) + + cons_tmp = zeros(eltype(cache.u0), length(cache.lcons)) + cache.f.cons(cons_tmp, cache.u0) + ρ = max(1e-6, min(10, 2 * (abs(cache.f(cache.u0, cache.p))) / norm(cons_tmp))) + + _loss = function (θ) + x = cache.f(θ, cache.p) + cons_tmp .= zero(eltype(θ)) + cache.f.cons(cons_tmp, θ) + cons_tmp[eq_inds] .= cons_tmp[eq_inds] - cache.lcons[eq_inds] + cons_tmp[ineq_inds] .= cons_tmp[ineq_inds] .- cache.ucons[ineq_inds] + opt_state = Optimization.OptimizationState(u = θ, objective = x[1]) + if cache.callback(opt_state, x...) + error("Optimization halted by callback.") + end + return x[1] + sum(@. λ * cons_tmp[eq_inds] + ρ / 2 * (cons_tmp[eq_inds] .^ 2)) + + 1 / (2 * ρ) * sum((max.(Ref(0.0), μ .+ (ρ .* cons_tmp[ineq_inds]))) .^ 2) + end + + prev_eqcons = zero(λ) + θ = cache.u0 + β = max.(cons_tmp[ineq_inds], Ref(0.0)) + prevβ = zero(β) + eqidxs = [eq_inds[i] > 0 ? i : nothing for i in eachindex(ineq_inds)] + ineqidxs = [ineq_inds[i] > 0 ? 
i : nothing for i in eachindex(ineq_inds)] + eqidxs = eqidxs[eqidxs .!= nothing] + ineqidxs = ineqidxs[ineqidxs .!= nothing] + function aug_grad(G, θ) + cache.f.grad(G, θ) + if !isnothing(cache.f.cons_jac_prototype) + J = Float64.(cache.f.cons_jac_prototype) + else + J = zeros((length(cache.lcons), length(θ))) + end + cache.f.cons_j(J, θ) + __tmp = zero(cons_tmp) + cache.f.cons(__tmp, θ) + __tmp[eq_inds] .= __tmp[eq_inds] .- cache.lcons[eq_inds] + __tmp[ineq_inds] .= __tmp[ineq_inds] .- cache.ucons[ineq_inds] + G .+= sum( + λ[i] .* J[idx, :] + ρ * (__tmp[idx] .* J[idx, :]) + for (i, idx) in enumerate(eqidxs); + init = zero(G)) #should be jvp + G .+= sum( + 1 / ρ * (max.(Ref(0.0), μ[i] .+ (ρ .* __tmp[idx])) .* J[idx, :]) + for (i, idx) in enumerate(ineqidxs); + init = zero(G)) #should be jvp + end + + opt_ret = ReturnCode.MaxIters + n = length(cache.u0) + + sol = solve(....) + + solver_kwargs = Base.structdiff(solver_kwargs, (; lb = nothing, ub = nothing)) + + for i in 1:maxiters + prev_eqcons .= cons_tmp[eq_inds] .- cache.lcons[eq_inds] + prevβ .= copy(β) + res = optimizer(_loss, aug_grad, θ, bounds; solver_kwargs..., + m = cache.opt.m, pgtol = sqrt(ϵ), maxiter = maxiters / 100) + # @show res[2] + # @show res[1] + # @show cons_tmp + # @show λ + # @show β + # @show μ + # @show ρ + θ = res[2] + cons_tmp .= 0.0 + cache.f.cons(cons_tmp, θ) + λ = max.(min.(λmax, λ .+ ρ * (cons_tmp[eq_inds] .- cache.lcons[eq_inds])), λmin) + β = max.(cons_tmp[ineq_inds], -1 .* μ ./ ρ) + μ = min.(μmax, max.(μ .+ ρ * cons_tmp[ineq_inds], μmin)) + if max(norm(cons_tmp[eq_inds] .- cache.lcons[eq_inds], Inf), norm(β, Inf)) > + τ * max(norm(prev_eqcons, Inf), norm(prevβ, Inf)) + ρ = γ * ρ + end + if norm( + (cons_tmp[eq_inds] .- cache.lcons[eq_inds]) ./ cons_tmp[eq_inds], Inf) < + ϵ && norm(β, Inf) < ϵ + opt_ret = ReturnCode.Success + break + end + end +end + +stats = Optimization.OptimizationStats(; iterations = maxiters, + time = 0.0, fevals = maxiters, gevals = maxiters) +return SciMLBase.build_solution( + cache, cache.opt, res[2], cache.f(res[2], cache.p)[1], + stats = stats, retcode = opt_ret) +end \ No newline at end of file From a2598c3dd31c69bcb0996ec5a4da691963138998 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Fri, 25 Oct 2024 23:08:41 -0400 Subject: [PATCH 2/3] Work with stochastic optimizers too --- src/Optimization.jl | 1 + src/auglag.jl | 63 +++++++++++++++++++++------------------------ test/lbfgsb.jl | 28 ++++++++++++++++++++ 3 files changed, 59 insertions(+), 33 deletions(-) diff --git a/src/Optimization.jl b/src/Optimization.jl index 4cfeead6e..8d0257dd1 100644 --- a/src/Optimization.jl +++ b/src/Optimization.jl @@ -24,6 +24,7 @@ include("utils.jl") include("state.jl") include("lbfgsb.jl") include("sophia.jl") +include("auglag.jl") export solve diff --git a/src/auglag.jl b/src/auglag.jl index e7234334a..390659c8e 100644 --- a/src/auglag.jl +++ b/src/auglag.jl @@ -1,15 +1,21 @@ - -SciMLBase.supports_opt_cache_interface(::LBFGS) = true -SciMLBase.allowsbounds(::LBFGS) = true -SciMLBase.requiresgradient(::LBFGS) = true -SciMLBase.allowsconstraints(::LBFGS) = true -SciMLBase.requiresconsjac(::LBFGS) = true - -function task_message_to_string(task::Vector{UInt8}) - return String(task) +@kwdef struct AugLag + inner + τ = 0.5 + γ = 10.0 + λmin = -1e20 + λmax = 1e20 + μmin = 0.0 + μmax = 1e20 + ϵ = 1e-8 end -function __map_optimizer_args(cache::Optimization.OptimizationCache, opt::LBFGS; +SciMLBase.supports_opt_cache_interface(::AugLag) = true +SciMLBase.allowsbounds(::AugLag) = true 
+SciMLBase.requiresgradient(::AugLag) = true +SciMLBase.allowsconstraints(::AugLag) = true +SciMLBase.requiresconsjac(::AugLag) = true + +function __map_optimizer_args(cache::Optimization.OptimizationCache, opt::AugLag; callback = nothing, maxiters::Union{Number, Nothing} = nothing, maxtime::Union{Number, Nothing} = nothing, @@ -62,7 +68,7 @@ function SciMLBase.__solve(cache::OptimizationCache{ UC, S, O <: - LBFGS, + AugLag, D, P, C @@ -90,10 +96,10 @@ if !isnothing(cache.f.cons) cons_tmp = zeros(eltype(cache.u0), length(cache.lcons)) cache.f.cons(cons_tmp, cache.u0) - ρ = max(1e-6, min(10, 2 * (abs(cache.f(cache.u0, cache.p))) / norm(cons_tmp))) + ρ = max(1e-6, min(10, 2 * (abs(cache.f(cache.u0, iterate(cache.p)[1]))) / norm(cons_tmp))) - _loss = function (θ) - x = cache.f(θ, cache.p) + _loss = function (θ, p = cache.p) + x = cache.f(θ, p) cons_tmp .= zero(eltype(θ)) cache.f.cons(cons_tmp, θ) cons_tmp[eq_inds] .= cons_tmp[eq_inds] - cache.lcons[eq_inds] @@ -114,8 +120,8 @@ if !isnothing(cache.f.cons) ineqidxs = [ineq_inds[i] > 0 ? i : nothing for i in eachindex(ineq_inds)] eqidxs = eqidxs[eqidxs .!= nothing] ineqidxs = ineqidxs[ineqidxs .!= nothing] - function aug_grad(G, θ) - cache.f.grad(G, θ) + function aug_grad(G, θ, p) + cache.f.grad(G, θ, p) if !isnothing(cache.f.cons_jac_prototype) J = Float64.(cache.f.cons_jac_prototype) else @@ -139,23 +145,15 @@ if !isnothing(cache.f.cons) opt_ret = ReturnCode.MaxIters n = length(cache.u0) - sol = solve(....) + augprob = OptimizationProblem(OptimizationFunction(_loss; grad = aug_grad), cache.u0, cache.p) solver_kwargs = Base.structdiff(solver_kwargs, (; lb = nothing, ub = nothing)) - for i in 1:maxiters + for i in 1:(maxiters/10) prev_eqcons .= cons_tmp[eq_inds] .- cache.lcons[eq_inds] prevβ .= copy(β) - res = optimizer(_loss, aug_grad, θ, bounds; solver_kwargs..., - m = cache.opt.m, pgtol = sqrt(ϵ), maxiter = maxiters / 100) - # @show res[2] - # @show res[1] - # @show cons_tmp - # @show λ - # @show β - # @show μ - # @show ρ - θ = res[2] + res = solve(augprob, cache.opt.inner, maxiters = maxiters / 10) + θ = res.u cons_tmp .= 0.0 cache.f.cons(cons_tmp, θ) λ = max.(min.(λmax, λ .+ ρ * (cons_tmp[eq_inds] .- cache.lcons[eq_inds])), λmin) @@ -172,11 +170,10 @@ if !isnothing(cache.f.cons) break end end -end - -stats = Optimization.OptimizationStats(; iterations = maxiters, + stats = Optimization.OptimizationStats(; iterations = maxiters, time = 0.0, fevals = maxiters, gevals = maxiters) -return SciMLBase.build_solution( - cache, cache.opt, res[2], cache.f(res[2], cache.p)[1], + return SciMLBase.build_solution( + cache, cache.opt, θ, x, stats = stats, retcode = opt_ret) +end end \ No newline at end of file diff --git a/test/lbfgsb.jl b/test/lbfgsb.jl index 2b5ec1691..b981cc2f2 100644 --- a/test/lbfgsb.jl +++ b/test/lbfgsb.jl @@ -25,3 +25,31 @@ prob = OptimizationProblem(optf, x0, lcons = [1.0, -Inf], ub = [1.0, 1.0]) @time res = solve(prob, Optimization.LBFGS(), maxiters = 100) @test res.retcode == SciMLBase.ReturnCode.Success + +using MLUtils, OptimizationOptimisers + +x0 = -pi:0.001:pi +y0 = sin.(x0) +data = MLUtils.DataLoader((x0, y0), batchsize = 100) +function loss(coeffs, data) + ypred = [evalpoly(data[1][i], coeffs) for i in eachindex(data[1])] + return sum(abs2, ypred .- data[2]) +end + +function cons1(res, coeffs, p = nothing) + res[1] = coeffs[1] * coeffs[5] - 1 + return nothing +end + +optf = OptimizationFunction(loss, AutoSparseForwardDiff(), cons = cons1) +callback = (st, l) -> (@show l; return false) + +prob = OptimizationProblem(optf, 
rand(5), (x0, y0), lcons = [-0.5], ucons = [0.5], lb = [-10.0, -10.0, -10.0, -10.0, -10.0], ub = [10.0, 10.0, 10.0, 10.0, 10.0]) +opt1 = solve(prob, Optimization.LBFGS(), maxiters = 1000, callback = callback) + +prob = OptimizationProblem(optf, rand(5), data, lcons = [0.0], ucons = [0.0], lb = [-10.0, -10.0, -10.0, -10.0, -10.0], ub = [10.0, 10.0, 10.0, 10.0, 10.0]) +opt = solve(prob, Optimization.AugLag(; inner = Adam()), maxiters = 500, callback = callback) + +optf1 = OptimizationFunction(loss, AutoSparseForwardDiff()) +prob1 = OptimizationProblem(optf1, rand(5), data) +sol1 = solve(prob1, OptimizationOptimisers.Adam(), maxiters = 1000, callback = callback) From 9086bfd630125b437bf77d1b2749163a43d2f3e3 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Fri, 18 Apr 2025 13:21:25 -0400 Subject: [PATCH 3/3] add auglag tests and run formatter --- docs/src/index.md | 22 +- .../src/optimization_packages/optimization.md | 40 ++-- .../src/OptimizationOptimJL.jl | 3 +- src/auglag.jl | 194 +++++++++--------- test/diffeqfluxtests.jl | 3 +- test/{lbfgsb.jl => native.jl} | 16 +- test/runtests.jl | 2 +- 7 files changed, 145 insertions(+), 135 deletions(-) rename test/{lbfgsb.jl => native.jl} (75%) diff --git a/docs/src/index.md b/docs/src/index.md index 34f3edd07..9fbabead3 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -88,7 +88,7 @@ to add the specific wrapper packages. - Second order - Zeroth order - Box Constraints - - Constrained 🟡 + - Constrained - Global Methods - Zeroth order - Unconstrained @@ -126,21 +126,21 @@ to add the specific wrapper packages. - Zeroth order - Unconstrained - Box Constraints - - Constrained 🟡 + - Constrained
NLopt - Local Methods - First order - Zeroth order - - Second order 🟡 + - Second order - Box Constraints - - Local Constrained 🟡 + - Local Constrained - Global Methods - Zeroth order - First order - Unconstrained - - Constrained 🟡 + - Constrained
Optim @@ -158,21 +158,21 @@ to add the specific wrapper packages.
PRIMA - Local Methods - - Derivative-Free: ✅ + - Derivative-Free: - **Constraints** - - Box Constraints: ✅ - - Local Constrained: ✅ + - Box Constraints: + - Local Constrained:
QuadDIRECT - **Constraints** - - Box Constraints: ✅ + - Box Constraints: - Global Methods - - Unconstrained: ✅ + - Unconstrained:
``` -🟡 = supported in downstream library but not yet implemented in `Optimization.jl`; PR to add this functionality are welcome += supported in downstream library but not yet implemented in `Optimization.jl`; PR to add this functionality are welcome ## Citation diff --git a/docs/src/optimization_packages/optimization.md b/docs/src/optimization_packages/optimization.md index e36728b11..ddd3bf062 100644 --- a/docs/src/optimization_packages/optimization.md +++ b/docs/src/optimization_packages/optimization.md @@ -4,28 +4,28 @@ There are some solvers that are available in the Optimization.jl package directl ## Methods -- `LBFGS`: The popular quasi-Newton method that leverages limited memory BFGS approximation of the inverse of the Hessian. Through a wrapper over the [L-BFGS-B](https://users.iems.northwestern.edu/%7Enocedal/lbfgsb.html) fortran routine accessed from the [LBFGSB.jl](https://github.com/Gnimuc/LBFGSB.jl/) package. It directly supports box-constraints. - - This can also handle arbitrary non-linear constraints through a Augmented Lagrangian method with bounds constraints described in 17.4 of Numerical Optimization by Nocedal and Wright. Thus serving as a general-purpose nonlinear optimization solver available directly in Optimization.jl. + - `LBFGS`: The popular quasi-Newton method that leverages limited memory BFGS approximation of the inverse of the Hessian. Through a wrapper over the [L-BFGS-B](https://users.iems.northwestern.edu/%7Enocedal/lbfgsb.html) fortran routine accessed from the [LBFGSB.jl](https://github.com/Gnimuc/LBFGSB.jl/) package. It directly supports box-constraints. + + This can also handle arbitrary non-linear constraints through a Augmented Lagrangian method with bounds constraints described in 17.4 of Numerical Optimization by Nocedal and Wright. Thus serving as a general-purpose nonlinear optimization solver available directly in Optimization.jl. -- `Sophia`: Based on the recent paper https://arxiv.org/abs/2305.14342. It incorporates second order information in the form of the diagonal of the Hessian matrix hence avoiding the need to compute the complete hessian. It has been shown to converge faster than other first order methods such as Adam and SGD. + - `Sophia`: Based on the recent paper https://arxiv.org/abs/2305.14342. It incorporates second order information in the form of the diagonal of the Hessian matrix hence avoiding the need to compute the complete hessian. It has been shown to converge faster than other first order methods such as Adam and SGD. 
+ + + `solve(problem, Sophia(; η, βs, ϵ, λ, k, ρ))` - + `solve(problem, Sophia(; η, βs, ϵ, λ, k, ρ))` - - + `η` is the learning rate - + `βs` are the decay of momentums - + `ϵ` is the epsilon value - + `λ` is the weight decay parameter - + `k` is the number of iterations to re-compute the diagonal of the Hessian matrix - + `ρ` is the momentum - + Defaults: - - * `η = 0.001` - * `βs = (0.9, 0.999)` - * `ϵ = 1e-8` - * `λ = 0.1` - * `k = 10` - * `ρ = 0.04` + + `η` is the learning rate + + `βs` are the decay of momentums + + `ϵ` is the epsilon value + + `λ` is the weight decay parameter + + `k` is the number of iterations to re-compute the diagonal of the Hessian matrix + + `ρ` is the momentum + + Defaults: + + * `η = 0.001` + * `βs = (0.9, 0.999)` + * `ϵ = 1e-8` + * `λ = 0.1` + * `k = 10` + * `ρ = 0.04` ## Examples diff --git a/lib/OptimizationOptimJL/src/OptimizationOptimJL.jl b/lib/OptimizationOptimJL/src/OptimizationOptimJL.jl index 24de50d31..34a2ae679 100644 --- a/lib/OptimizationOptimJL/src/OptimizationOptimJL.jl +++ b/lib/OptimizationOptimJL/src/OptimizationOptimJL.jl @@ -38,13 +38,12 @@ function __map_optimizer_args(cache::OptimizationCache, abstol::Union{Number, Nothing} = nothing, reltol::Union{Number, Nothing} = nothing, kwargs...) - mapped_args = (; extended_trace = true, kwargs...) if !isnothing(abstol) mapped_args = (; mapped_args..., f_abstol = abstol) end - + if !isnothing(callback) mapped_args = (; mapped_args..., callback = callback) end diff --git a/src/auglag.jl b/src/auglag.jl index 390659c8e..f3a15036d 100644 --- a/src/auglag.jl +++ b/src/auglag.jl @@ -1,5 +1,5 @@ @kwdef struct AugLag - inner + inner::Any τ = 0.5 γ = 10.0 λmin = -1e20 @@ -73,107 +73,109 @@ function SciMLBase.__solve(cache::OptimizationCache{ P, C } -maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters) - -local x - -solver_kwargs = __map_optimizer_args(cache, cache.opt; maxiters, cache.solver_args...) - -if !isnothing(cache.f.cons) - eq_inds = [cache.lcons[i] == cache.ucons[i] for i in eachindex(cache.lcons)] - ineq_inds = (!).(eq_inds) - - τ = cache.opt.τ - γ = cache.opt.γ - λmin = cache.opt.λmin - λmax = cache.opt.λmax - μmin = cache.opt.μmin - μmax = cache.opt.μmax - ϵ = cache.opt.ϵ - - λ = zeros(eltype(cache.u0), sum(eq_inds)) - μ = zeros(eltype(cache.u0), sum(ineq_inds)) - - cons_tmp = zeros(eltype(cache.u0), length(cache.lcons)) - cache.f.cons(cons_tmp, cache.u0) - ρ = max(1e-6, min(10, 2 * (abs(cache.f(cache.u0, iterate(cache.p)[1]))) / norm(cons_tmp))) - - _loss = function (θ, p = cache.p) - x = cache.f(θ, p) - cons_tmp .= zero(eltype(θ)) - cache.f.cons(cons_tmp, θ) - cons_tmp[eq_inds] .= cons_tmp[eq_inds] - cache.lcons[eq_inds] - cons_tmp[ineq_inds] .= cons_tmp[ineq_inds] .- cache.ucons[ineq_inds] - opt_state = Optimization.OptimizationState(u = θ, objective = x[1]) - if cache.callback(opt_state, x...) - error("Optimization halted by callback.") + maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters) + + local x + + solver_kwargs = __map_optimizer_args(cache, cache.opt; maxiters, cache.solver_args...) 
+ + if !isnothing(cache.f.cons) + eq_inds = [cache.lcons[i] == cache.ucons[i] for i in eachindex(cache.lcons)] + ineq_inds = (!).(eq_inds) + + τ = cache.opt.τ + γ = cache.opt.γ + λmin = cache.opt.λmin + λmax = cache.opt.λmax + μmin = cache.opt.μmin + μmax = cache.opt.μmax + ϵ = cache.opt.ϵ + + λ = zeros(eltype(cache.u0), sum(eq_inds)) + μ = zeros(eltype(cache.u0), sum(ineq_inds)) + + cons_tmp = zeros(eltype(cache.u0), length(cache.lcons)) + cache.f.cons(cons_tmp, cache.u0) + ρ = max(1e-6, + min(10, 2 * (abs(cache.f(cache.u0, iterate(cache.p)[1]))) / norm(cons_tmp))) + + _loss = function (θ, p = cache.p) + x = cache.f(θ, p) + cons_tmp .= zero(eltype(θ)) + cache.f.cons(cons_tmp, θ) + cons_tmp[eq_inds] .= cons_tmp[eq_inds] - cache.lcons[eq_inds] + cons_tmp[ineq_inds] .= cons_tmp[ineq_inds] .- cache.ucons[ineq_inds] + opt_state = Optimization.OptimizationState(u = θ, objective = x[1]) + if cache.callback(opt_state, x...) + error("Optimization halted by callback.") + end + return x[1] + sum(@. λ * cons_tmp[eq_inds] + ρ / 2 * (cons_tmp[eq_inds] .^ 2)) + + 1 / (2 * ρ) * sum((max.(Ref(0.0), μ .+ (ρ .* cons_tmp[ineq_inds]))) .^ 2) end - return x[1] + sum(@. λ * cons_tmp[eq_inds] + ρ / 2 * (cons_tmp[eq_inds] .^ 2)) + - 1 / (2 * ρ) * sum((max.(Ref(0.0), μ .+ (ρ .* cons_tmp[ineq_inds]))) .^ 2) - end - prev_eqcons = zero(λ) - θ = cache.u0 - β = max.(cons_tmp[ineq_inds], Ref(0.0)) - prevβ = zero(β) - eqidxs = [eq_inds[i] > 0 ? i : nothing for i in eachindex(ineq_inds)] - ineqidxs = [ineq_inds[i] > 0 ? i : nothing for i in eachindex(ineq_inds)] - eqidxs = eqidxs[eqidxs .!= nothing] - ineqidxs = ineqidxs[ineqidxs .!= nothing] - function aug_grad(G, θ, p) - cache.f.grad(G, θ, p) - if !isnothing(cache.f.cons_jac_prototype) - J = Float64.(cache.f.cons_jac_prototype) - else - J = zeros((length(cache.lcons), length(θ))) + prev_eqcons = zero(λ) + θ = cache.u0 + β = max.(cons_tmp[ineq_inds], Ref(0.0)) + prevβ = zero(β) + eqidxs = [eq_inds[i] > 0 ? i : nothing for i in eachindex(ineq_inds)] + ineqidxs = [ineq_inds[i] > 0 ? 
i : nothing for i in eachindex(ineq_inds)] + eqidxs = eqidxs[eqidxs .!= nothing] + ineqidxs = ineqidxs[ineqidxs .!= nothing] + function aug_grad(G, θ, p) + cache.f.grad(G, θ, p) + if !isnothing(cache.f.cons_jac_prototype) + J = Float64.(cache.f.cons_jac_prototype) + else + J = zeros((length(cache.lcons), length(θ))) + end + cache.f.cons_j(J, θ) + __tmp = zero(cons_tmp) + cache.f.cons(__tmp, θ) + __tmp[eq_inds] .= __tmp[eq_inds] .- cache.lcons[eq_inds] + __tmp[ineq_inds] .= __tmp[ineq_inds] .- cache.ucons[ineq_inds] + G .+= sum( + λ[i] .* J[idx, :] + ρ * (__tmp[idx] .* J[idx, :]) + for (i, idx) in enumerate(eqidxs); + init = zero(G)) #should be jvp + G .+= sum( + 1 / ρ * (max.(Ref(0.0), μ[i] .+ (ρ .* __tmp[idx])) .* J[idx, :]) + for (i, idx) in enumerate(ineqidxs); + init = zero(G)) #should be jvp end - cache.f.cons_j(J, θ) - __tmp = zero(cons_tmp) - cache.f.cons(__tmp, θ) - __tmp[eq_inds] .= __tmp[eq_inds] .- cache.lcons[eq_inds] - __tmp[ineq_inds] .= __tmp[ineq_inds] .- cache.ucons[ineq_inds] - G .+= sum( - λ[i] .* J[idx, :] + ρ * (__tmp[idx] .* J[idx, :]) - for (i, idx) in enumerate(eqidxs); - init = zero(G)) #should be jvp - G .+= sum( - 1 / ρ * (max.(Ref(0.0), μ[i] .+ (ρ .* __tmp[idx])) .* J[idx, :]) - for (i, idx) in enumerate(ineqidxs); - init = zero(G)) #should be jvp - end - opt_ret = ReturnCode.MaxIters - n = length(cache.u0) - - augprob = OptimizationProblem(OptimizationFunction(_loss; grad = aug_grad), cache.u0, cache.p) - - solver_kwargs = Base.structdiff(solver_kwargs, (; lb = nothing, ub = nothing)) - - for i in 1:(maxiters/10) - prev_eqcons .= cons_tmp[eq_inds] .- cache.lcons[eq_inds] - prevβ .= copy(β) - res = solve(augprob, cache.opt.inner, maxiters = maxiters / 10) - θ = res.u - cons_tmp .= 0.0 - cache.f.cons(cons_tmp, θ) - λ = max.(min.(λmax, λ .+ ρ * (cons_tmp[eq_inds] .- cache.lcons[eq_inds])), λmin) - β = max.(cons_tmp[ineq_inds], -1 .* μ ./ ρ) - μ = min.(μmax, max.(μ .+ ρ * cons_tmp[ineq_inds], μmin)) - if max(norm(cons_tmp[eq_inds] .- cache.lcons[eq_inds], Inf), norm(β, Inf)) > - τ * max(norm(prev_eqcons, Inf), norm(prevβ, Inf)) - ρ = γ * ρ - end - if norm( - (cons_tmp[eq_inds] .- cache.lcons[eq_inds]) ./ cons_tmp[eq_inds], Inf) < - ϵ && norm(β, Inf) < ϵ - opt_ret = ReturnCode.Success - break + opt_ret = ReturnCode.MaxIters + n = length(cache.u0) + + augprob = OptimizationProblem( + OptimizationFunction(_loss; grad = aug_grad), cache.u0, cache.p) + + solver_kwargs = Base.structdiff(solver_kwargs, (; lb = nothing, ub = nothing)) + + for i in 1:(maxiters / 10) + prev_eqcons .= cons_tmp[eq_inds] .- cache.lcons[eq_inds] + prevβ .= copy(β) + res = solve(augprob, cache.opt.inner, maxiters = maxiters / 10) + θ = res.u + cons_tmp .= 0.0 + cache.f.cons(cons_tmp, θ) + λ = max.(min.(λmax, λ .+ ρ * (cons_tmp[eq_inds] .- cache.lcons[eq_inds])), λmin) + β = max.(cons_tmp[ineq_inds], -1 .* μ ./ ρ) + μ = min.(μmax, max.(μ .+ ρ * cons_tmp[ineq_inds], μmin)) + if max(norm(cons_tmp[eq_inds] .- cache.lcons[eq_inds], Inf), norm(β, Inf)) > + τ * max(norm(prev_eqcons, Inf), norm(prevβ, Inf)) + ρ = γ * ρ + end + if norm( + (cons_tmp[eq_inds] .- cache.lcons[eq_inds]) ./ cons_tmp[eq_inds], Inf) < + ϵ && norm(β, Inf) < ϵ + opt_ret = ReturnCode.Success + break + end end - end - stats = Optimization.OptimizationStats(; iterations = maxiters, + stats = Optimization.OptimizationStats(; iterations = maxiters, time = 0.0, fevals = maxiters, gevals = maxiters) - return SciMLBase.build_solution( + return SciMLBase.build_solution( cache, cache.opt, θ, x, stats = stats, retcode = opt_ret) + end end 
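For reference, the penalized objective assembled by `_loss` above is a bound-constrained augmented Lagrangian. Writing `h_i(θ) = c_i(θ) - lcons_i` for the equality rows and `g_j(θ) = c_j(θ) - ucons_j` for the inequality rows (notation introduced here for readability, not in the patch), the inner problem minimizes

```math
\mathcal{L}_{\rho}(\theta, \lambda, \mu) = f(\theta)
 + \sum_{i \in \mathcal{E}} \left( \lambda_i\, h_i(\theta) + \frac{\rho}{2}\, h_i(\theta)^2 \right)
 + \frac{1}{2\rho} \sum_{j \in \mathcal{I}} \max\!\left(0,\; \mu_j + \rho\, g_j(\theta)\right)^2 .
```

This differs from the classical PHR augmented Lagrangian only by the θ-independent constant `-μ_j^2/(2ρ)`, so the minimizer of the inner problem is unchanged. The outer loop above then applies clamped first-order updates to `λ` and `μ` and grows `ρ` by the factor `γ` whenever the measured constraint violation has not shrunk by at least the factor `τ` relative to the previous outer iteration.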
-end \ No newline at end of file diff --git a/test/diffeqfluxtests.jl b/test/diffeqfluxtests.jl index 6ec24e2cd..4a6a170c0 100644 --- a/test/diffeqfluxtests.jl +++ b/test/diffeqfluxtests.jl @@ -70,7 +70,8 @@ ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps)) dudt2 = Lux.Chain(x -> x .^ 3, Lux.Dense(2, 50, tanh), Lux.Dense(50, 2)) -prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps, abstol = 1e-8, reltol = 1e-8) +prob_neuralode = NeuralODE( + dudt2, tspan, Tsit5(), saveat = tsteps, abstol = 1e-8, reltol = 1e-8) pp, st = Lux.setup(rng, dudt2) pp = ComponentArray(pp) diff --git a/test/lbfgsb.jl b/test/native.jl similarity index 75% rename from test/lbfgsb.jl rename to test/native.jl index b981cc2f2..1f54f47d0 100644 --- a/test/lbfgsb.jl +++ b/test/native.jl @@ -28,7 +28,7 @@ prob = OptimizationProblem(optf, x0, lcons = [1.0, -Inf], using MLUtils, OptimizationOptimisers -x0 = -pi:0.001:pi +x0 = (-pi):0.001:pi y0 = sin.(x0) data = MLUtils.DataLoader((x0, y0), batchsize = 100) function loss(coeffs, data) @@ -44,12 +44,20 @@ end optf = OptimizationFunction(loss, AutoSparseForwardDiff(), cons = cons1) callback = (st, l) -> (@show l; return false) -prob = OptimizationProblem(optf, rand(5), (x0, y0), lcons = [-0.5], ucons = [0.5], lb = [-10.0, -10.0, -10.0, -10.0, -10.0], ub = [10.0, 10.0, 10.0, 10.0, 10.0]) +initpars = rand(5) +l0 = optf(initpars, (x0, y0)) +prob = OptimizationProblem(optf, initpars, (x0, y0), lcons = [-Inf], ucons = [0.5], + lb = [-10.0, -10.0, -10.0, -10.0, -10.0], ub = [10.0, 10.0, 10.0, 10.0, 10.0]) opt1 = solve(prob, Optimization.LBFGS(), maxiters = 1000, callback = callback) +@test opt1.objective < l0 -prob = OptimizationProblem(optf, rand(5), data, lcons = [0.0], ucons = [0.0], lb = [-10.0, -10.0, -10.0, -10.0, -10.0], ub = [10.0, 10.0, 10.0, 10.0, 10.0]) -opt = solve(prob, Optimization.AugLag(; inner = Adam()), maxiters = 500, callback = callback) +prob = OptimizationProblem(optf, initpars, data, lcons = [-Inf], ucons = [1], + lb = [-10.0, -10.0, -10.0, -10.0, -10.0], ub = [10.0, 10.0, 10.0, 10.0, 10.0]) +opt = solve( + prob, Optimization.AugLag(; inner = Adam()), maxiters = 10000, callback = callback) +@test opt.objective < l0 optf1 = OptimizationFunction(loss, AutoSparseForwardDiff()) prob1 = OptimizationProblem(optf1, rand(5), data) sol1 = solve(prob1, OptimizationOptimisers.Adam(), maxiters = 1000, callback = callback) +@test sol1.objective < l0 diff --git a/test/runtests.jl b/test/runtests.jl index ba1714ca2..d15543484 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -36,7 +36,7 @@ end include("AD_performance_regression.jl") end @safetestset "Optimization" begin - include("lbfgsb.jl") + include("native.jl") end @safetestset "Mini batching" begin include("minibatch.jl")
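Taken together, the three patches let the new solver be driven exactly like the existing `Optimization.LBFGS` path, with any `OptimizationProblem`-compatible optimizer nested as the unconstrained inner solver. The sketch below condenses the data-fitting case from `test/native.jl`; the `using` lines, the availability of ForwardDiff for the `AutoSparseForwardDiff()` backend, and the omission of the callback are assumptions about the test environment rather than part of the patch, and the hyperparameters simply mirror the test.

```julia
using Optimization, OptimizationOptimisers, MLUtils

# Fit a degree-4 polynomial to sin(x) from minibatched data, subject to one
# nonlinear inequality constraint, using AugLag with Adam as the inner optimizer.
x0 = (-pi):0.001:pi
y0 = sin.(x0)
data = MLUtils.DataLoader((x0, y0), batchsize = 100)

function loss(coeffs, data)
    ypred = [evalpoly(data[1][i], coeffs) for i in eachindex(data[1])]
    return sum(abs2, ypred .- data[2])
end

# Inequality constraint: coeffs[1] * coeffs[5] - 1 must stay at or below ucons[1].
function cons1(res, coeffs, p = nothing)
    res[1] = coeffs[1] * coeffs[5] - 1
    return nothing
end

optf = OptimizationFunction(loss, AutoSparseForwardDiff(), cons = cons1)
prob = OptimizationProblem(optf, rand(5), data, lcons = [-Inf], ucons = [1.0],
    lb = fill(-10.0, 5), ub = fill(10.0, 5))
sol = solve(prob, Optimization.AugLag(; inner = Adam()), maxiters = 10000)
```

Because the second patch replaced the hard-wired L-BFGS-B call with a nested `solve(augprob, cache.opt.inner, ...)`, the `inner` field accepts stochastic optimizers such as `Adam()` as well as deterministic ones, which is what the minibatched `DataLoader` problem above exercises.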