Skip to content

Commit

Permalink
Miscellaneous style improvements (#225)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Aug 20, 2024
1 parent 55fdd9f commit 66ee350
Showing 1 changed file with 55 additions and 76 deletions.
131 changes: 55 additions & 76 deletions src/NLopt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,21 +121,21 @@ mutable struct Opt
finalizer(destroy, opt)
return opt
end
end

function Opt(algorithm::Algorithm, n::Integer)
if n < 0
throw(ArgumentError("invalid dimension $n < 0"))
end
p = nlopt_create(algorithm, n)
if p == C_NULL
error("Error in nlopt_create")
end
return Opt(p)
function Opt(algorithm::Algorithm, n::Integer)
if n < 0
throw(ArgumentError("invalid dimension $n < 0"))
end

function Opt(algorithm::Union{Integer,Symbol}, n::Integer)
return Opt(Algorithm(algorithm), n)
p = nlopt_create(algorithm, n)
if p == C_NULL
error("Error in nlopt_create")
end
return Opt(p)
end

function Opt(algorithm::Union{Integer,Symbol}, n::Integer)
return Opt(Algorithm(algorithm), n)
end

Base.unsafe_convert(::Type{Ptr{Cvoid}}, o::Opt) = getfield(o, :opt)
Expand Down Expand Up @@ -206,7 +206,7 @@ end
struct ForcedStop <: Exception end

# cache current exception for forced stop
nlopt_exception = nothing
global nlopt_exception = nothing

function errmsg(o::Opt)
msg = nlopt_get_errmsg(o)
Expand Down Expand Up @@ -347,19 +347,14 @@ end
xtol_abs!(o::Opt, val::Real) = chk(o, nlopt_set_xtol_abs1(o, val))

function local_optimizer!(o::Opt, lo::Opt)
ret = nlopt_set_local_optimizer(o, lo)
return chk(o, ret)
return chk(o, nlopt_set_local_optimizer(o, lo))
end

# the initial-stepsize stuff is a bit different than GETSET_VEC,
# since the heuristics depend on the position x.

function default_initial_step!(o::Opt, x::Vector{Cdouble})
if length(x) != ndims(o)
throw(BoundsError())
end
ret = nlopt_set_default_initial_step(o, x)
return chk(o, ret)
return chk(o, nlopt_set_default_initial_step(o, x))
end

function default_initial_step!(o::Opt, x::AbstractVector{<:Real})
Expand All @@ -370,25 +365,22 @@ function initial_step!(o::Opt, dx::Vector{Cdouble})
if length(dx) != ndims(o)
throw(BoundsError())
end
ret = nlopt_set_initial_step(o, dx)
return chk(o, ret)
return chk(o, nlopt_set_initial_step(o, dx))
end

function initial_step!(o::Opt, dx::AbstractVector{<:Real})
return initial_step!(o, Array{Cdouble}(dx))
end

function initial_step!(o::Opt, dx::Real)
ret = nlopt_set_initial_step1(o, dx)
return chk(o, ret)
return chk(o, nlopt_set_initial_step1(o, dx))
end

function initial_step(o::Opt, x::Vector{Cdouble}, dx::Vector{Cdouble})
if length(x) != ndims(o) || length(dx) != ndims(o)
throw(BoundsError())
end
ret::Result = nlopt_get_initial_step(o, x, dx)
chk(o, ret)
chk(o, nlopt_get_initial_step(o, x, dx))
return dx
end

Expand Down Expand Up @@ -436,29 +428,25 @@ srand_time() = nlopt_srand_time()
############################################################################
# Objective function:

const empty_grad = Cdouble[] # for passing when grad == C_NULL

function nlopt_callback_wrapper(
n::Cuint,
x::Ptr{Cdouble},
grad::Ptr{Cdouble},
p_x::Ptr{Cdouble},
p_grad::Ptr{Cdouble},
d_::Ptr{Cvoid},
)
)::Cdouble
d = unsafe_pointer_to_objref(d_)::Callback_Data
x = unsafe_wrap(Array, p_x, (n,))
grad = unsafe_wrap(Array, p_grad, (n,))
try
x_vec = unsafe_wrap(Array, x, (convert(Int, n),))
grad_vec = unsafe_wrap(Array, grad, (convert(Int, n),))
res =
convert(Cdouble, d.f(x_vec, grad == C_NULL ? empty_grad : grad_vec))
return res::Cdouble
return d.f(x, p_grad == C_NULL ? Cdouble[] : grad)
catch e
if e isa ForcedStop
global nlopt_exception = e
else
global nlopt_exception = CapturedException(e, catch_backtrace())
end
force_stop!(d.o::Opt)
return 0.0 # ignored by nlopt
return NaN
end
end

Expand All @@ -470,8 +458,7 @@ function min_objective!(o::Opt, f::Function)
Cdouble,
(Cuint, Ptr{Cdouble}, Ptr{Cdouble}, Ptr{Cvoid})
)
ret = nlopt_set_min_objective(o, c_fn, cb)
return chk(o, ret)
return chk(o, nlopt_set_min_objective(o, c_fn, cb))
end

function max_objective!(o::Opt, f::Function)
Expand All @@ -482,8 +469,7 @@ function max_objective!(o::Opt, f::Function)
Cdouble,
(Cuint, Ptr{Cdouble}, Ptr{Cdouble}, Ptr{Cvoid})
)
ret = nlopt_set_max_objective(o, c_fn, cb)
return chk(o, ret)
return chk(o, nlopt_set_max_objective(o, c_fn, cb))
end

############################################################################
Expand All @@ -497,8 +483,7 @@ function inequality_constraint!(o::Opt, f::Function, tol::Real = 0.0)
Cdouble,
(Cuint, Ptr{Cdouble}, Ptr{Cdouble}, Ptr{Cvoid})
)
ret::Result = nlopt_add_inequality_constraint(o, c_fn, cb, tol)
return chk(o, ret)
return chk(o, nlopt_add_inequality_constraint(o, c_fn, cb, tol))
end

function equality_constraint!(o::Opt, f::Function, tol::Real = 0.0)
Expand All @@ -509,40 +494,35 @@ function equality_constraint!(o::Opt, f::Function, tol::Real = 0.0)
Cdouble,
(Cuint, Ptr{Cdouble}, Ptr{Cdouble}, Ptr{Cvoid})
)
ret::Result = nlopt_add_equality_constraint(o, c_fn, cb, tol)
return chk(o, ret)
return chk(o, nlopt_add_equality_constraint(o, c_fn, cb, tol))
end

function remove_constraints!(o::Opt)
resize!(getfield(o, :cb), 1)
ret = nlopt_remove_inequality_constraints(o)
chk(o, ret)
# TODO(odow): why is this called twice?
ret = nlopt_remove_equality_constraints(o)
return chk(o, ret)
chk(o, nlopt_remove_inequality_constraints(o))
chk(o, nlopt_remove_equality_constraints(o))
return
end

############################################################################
# Vector-valued constraints

const empty_jac = Array{Cdouble}(undef, 0, 0) # for passing when grad == C_NULL

function nlopt_vcallback_wrapper(
m::Cuint,
res::Ptr{Cdouble},
p_res::Ptr{Cdouble},
n::Cuint,
x::Ptr{Cdouble},
grad::Ptr{Cdouble},
p_x::Ptr{Cdouble},
p_grad::Ptr{Cdouble},
d_::Ptr{Cvoid},
)
d = unsafe_pointer_to_objref(d_)::Callback_Data
res = unsafe_wrap(Array, p_res, (m,))
x = unsafe_wrap(Array, p_x, (n,))
grad =
p_grad == C_NULL ? zeros(Cdouble, 0, 0) :
unsafe_wrap(Array, p_grad, (n, m))
try
d.f(
unsafe_wrap(Array, res, (convert(Int, m),)),
unsafe_wrap(Array, x, (convert(Int, n),)),
grad == C_NULL ? empty_jac :
unsafe_wrap(Array, grad, (convert(Int, n), convert(Int, m))),
)
d.f(res, x, grad)
catch e
if e isa ForcedStop
global nlopt_exception = e
Expand All @@ -551,7 +531,7 @@ function nlopt_vcallback_wrapper(
end
force_stop!(d.o::Opt)
end
return nothing
return
end

function inequality_constraint!(o::Opt, f::Function, tol::Vector{Cdouble})
Expand All @@ -562,8 +542,7 @@ function inequality_constraint!(o::Opt, f::Function, tol::Vector{Cdouble})
Cvoid,
(Cuint, Ptr{Cdouble}, Cuint, Ptr{Cdouble}, Ptr{Cdouble}, Ptr{Cvoid}),
)
ret::Result =
nlopt_add_inequality_mconstraint(o, length(tol), c_fn, cb, tol)
ret = nlopt_add_inequality_mconstraint(o, length(tol), c_fn, cb, tol)
return chk(o, ret)
end

Expand Down Expand Up @@ -592,8 +571,7 @@ function equality_constraint!(o::Opt, f::Function, tol::Vector{Cdouble})
Cvoid,
(Cuint, Ptr{Cdouble}, Cuint, Ptr{Cdouble}, Ptr{Cdouble}, Ptr{Cvoid}),
)
ret::Result = nlopt_add_equality_mconstraint(o, length(tol), c_fn, cb, tol)
return chk(o, ret)
return chk(o, nlopt_add_equality_mconstraint(o, length(tol), c_fn, cb, tol))
end

function equality_constraint!(o::Opt, f::Function, tol::AbstractVector{<:Real})
Expand All @@ -611,8 +589,9 @@ end
OptParams <: AbstractDict{String, Float64}
Dictionary-like structure for accessing algorithm-specific parameters for
an NLopt optimization object `opt`, returned by `opt.params`. Use this
object to both set and view these string-keyed numeric parameters.
an NLopt optimization object `opt`, returned by `opt.params`.
Use this object to both set and view these string-keyed numeric parameters.
"""
struct OptParams <: AbstractDict{String,Float64}
o::Opt
Expand All @@ -622,13 +601,13 @@ Base.length(p::OptParams)::Int = nlopt_num_params(p.o)

Base.haskey(p::OptParams, s::AbstractString)::Bool = nlopt_has_param(p.o, s)

function Base.get(p::OptParams, s::AbstractString, defaultval::Float64)
return nlopt_get_param(p.o, s, defaultval)
function Base.get(p::OptParams, s::AbstractString, default::Float64)
return nlopt_get_param(p.o, s, default)
end

function Base.get(p::OptParams, s::AbstractString, defaultval)
function Base.get(p::OptParams, s::AbstractString, default)
if !haskey(p, s)
return defaultval
return default
end
return nlopt_get_param(p.o, s, NaN)
end
Expand All @@ -643,13 +622,13 @@ function Base.setindex!(p::OptParams, v::Algorithm, s::AbstractString)
end

function Base.iterate(p::OptParams, state = 0)
if state length(p)
if state >= length(p)
return nothing
end
name_ptr = nlopt_nth_param(p.o, state)
@assert name_ptr != C_NULL
name = unsafe_string(name_ptr)
return (name => p[name], state + 1)
return name => p[name], state + 1
end

############################################################################
Expand Down Expand Up @@ -780,7 +759,7 @@ function optimize!(o::Opt, x::Vector{Cdouble})
if length(x) != ndims(o)
throw(BoundsError())
end
opt_f = Array{Cdouble}(undef, 1)
opt_f = Ref{Cdouble}(NaN)
ret::Result = nlopt_optimize(o, x, opt_f)
# We do not need to check the value of `ret`, except if it is a FORCED_STOP
# with a Julia-related exception from a callback
Expand All @@ -792,7 +771,7 @@ function optimize!(o::Opt, x::Vector{Cdouble})
throw(e)
end
end
return opt_f[1], x, Symbol(ret)
return opt_f[], x, Symbol(ret)
end

function optimize(o::Opt, x::AbstractVector{<:Real})
Expand Down

0 comments on commit 66ee350

Please sign in to comment.