Skip to content

Commit

Permalink
Make Egor.suggest available in Python
Browse files Browse the repository at this point in the history
  • Loading branch information
relf committed Nov 22, 2023
1 parent 36bafbd commit 993e4a6
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 10 deletions.
12 changes: 12 additions & 0 deletions python/egobox/tests/test_egor.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,18 @@ def test_constructor(self):
self.assertRaises(TypeError, egx.Egor)
egx.Egor(xsinx, egx.to_specs([[0.0, 25.0]]), 22, n_doe=10)

def test_egor_service(self):
xlimits = egx.to_specs([[0.0, 25.0]])
egor = egx.Egor(xlimits, seed=42)
x_doe = egx.lhs(xlimits, 3, seed=42)
print(x_doe)
y_doe = xsinx(x_doe)
print(y_doe)
for _ in range(10):
x = egor.suggest(x_doe, y_doe)
x_doe = np.concatenate((x_doe, x))
y_doe = np.concatenate((y_doe, xsinx(x)))


if __name__ == "__main__":
unittest.main()
125 changes: 115 additions & 10 deletions src/egor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
//!
use crate::types::*;
use ndarray::Array1;
use ndarray::{concatenate, Array1, Axis};
use numpy::ndarray::{Array2, ArrayView2};
use numpy::{IntoPyArray, PyArray1, PyArray2, PyReadonlyArray2};
use pyo3::exceptions::PyValueError;
Expand Down Expand Up @@ -133,7 +133,6 @@ pub(crate) fn to_specs(py: Python, xlimits: Vec<Vec<f64>>) -> PyResult<PyObject>
///
#[pyclass]
pub(crate) struct Egor {
pub fun: PyObject,
pub xspecs: PyObject,
pub n_cstr: usize,
pub cstr_tol: Option<Vec<f64>>,
Expand Down Expand Up @@ -170,7 +169,6 @@ pub(crate) struct OptimResult {
impl Egor {
#[new]
#[pyo3(signature = (
fun,
xspecs,
n_cstr = 0,
cstr_tol = None,
Expand All @@ -192,8 +190,7 @@ impl Egor {
))]
#[allow(clippy::too_many_arguments)]
fn new(
py: Python,
fun: PyObject,
_py: Python,
xspecs: PyObject,
n_cstr: usize,
cstr_tol: Option<Vec<f64>>,
Expand All @@ -215,7 +212,6 @@ impl Egor {
) -> Self {
let doe = doe.map(|x| x.to_owned_array());
Egor {
fun: fun.to_object(py),
xspecs,
n_cstr,
cstr_tol,
Expand Down Expand Up @@ -248,9 +244,9 @@ impl Egor {
/// x_opt (array[1, nx]): x value where fun is at its minimum subject to constraint
/// y_opt (array[1, nx]): fun(x_opt)
///
#[pyo3(signature = (max_iters = 20))]
fn minimize(&self, py: Python, max_iters: usize) -> PyResult<OptimResult> {
let fun = self.fun.to_object(py);
#[pyo3(signature = (fun, max_iters = 20))]
fn minimize(&self, py: Python, fun: PyObject, max_iters: usize) -> PyResult<OptimResult> {
let fun = fun.to_object(py);
let obj = move |x: &ArrayView2<f64>| -> Array2<f64> {
Python::with_gil(|py| {
let args = (x.to_owned().into_pyarray(py),);
Expand Down Expand Up @@ -300,7 +296,6 @@ impl Egor {
}
})
.collect();
println!("{:?}", xtypes);

let cstr_tol = self.cstr_tol.clone().unwrap_or(vec![0.0; self.n_cstr]);
let cstr_tol = Array1::from_vec(cstr_tol);
Expand Down Expand Up @@ -360,4 +355,114 @@ impl Egor {
y_hist,
})
}

/// This function gives the next best location where to evaluate the function
/// under optimization wrt to previous evaluations.
/// The function returns several point when multi point qEI strategy is used.
///
/// # Parameters
/// x_doe (array[ns, nx]): ns samples where function has been evaluated
/// y_doe (array[ns, 1 + n_cstr]): ns values of objetcive and constraints
///
///
/// # Returns
/// (array[1, nx]): suggested location where to evaluate objective and constraints
///
#[pyo3(signature = (x_doe, y_doe))]
fn suggest(
&self,
py: Python,
x_doe: PyReadonlyArray2<f64>,
y_doe: PyReadonlyArray2<f64>,
) -> Py<PyArray2<f64>> {
let x_doe = x_doe.as_array();
let y_doe = y_doe.as_array();

let doe = concatenate(Axis(1), &[x_doe.view(), y_doe.view()]).unwrap();

let infill_strategy = match self.infill_strategy {
InfillStrategy::Ei => egobox_ego::InfillStrategy::EI,
InfillStrategy::Wb2 => egobox_ego::InfillStrategy::WB2,
InfillStrategy::Wb2s => egobox_ego::InfillStrategy::WB2S,
};

let qei_strategy = match self.par_infill_strategy {
ParInfillStrategy::Kb => egobox_ego::QEiStrategy::KrigingBeliever,
ParInfillStrategy::Kblb => egobox_ego::QEiStrategy::KrigingBelieverLowerBound,
ParInfillStrategy::Kbub => egobox_ego::QEiStrategy::KrigingBelieverUpperBound,
ParInfillStrategy::Clmin => egobox_ego::QEiStrategy::ConstantLiarMinimum,
};

let infill_optimizer = match self.infill_optimizer {
InfillOptimizer::Cobyla => egobox_ego::InfillOptimizer::Cobyla,
InfillOptimizer::Slsqp => egobox_ego::InfillOptimizer::Slsqp,
};

let xspecs: Vec<XSpec> = self.xspecs.extract(py).expect("Error in xspecs conversion");
if xspecs.is_empty() {
panic!("Error: xspecs argument cannot be empty")
}

let xtypes: Vec<egobox_ego::XType> = xspecs
.iter()
.map(|spec| match spec.xtype {
XType::Float => egobox_ego::XType::Cont(spec.xlimits[0], spec.xlimits[1]),
XType::Int => {
egobox_ego::XType::Int(spec.xlimits[0] as i32, spec.xlimits[1] as i32)
}
XType::Ord => egobox_ego::XType::Ord(spec.xlimits.clone()),
XType::Enum => {
if spec.tags.is_empty() {
egobox_ego::XType::Enum(spec.xlimits[0] as usize)
} else {
egobox_ego::XType::Enum(spec.tags.len())
}
}
})
.collect();

let cstr_tol = self.cstr_tol.clone().unwrap_or(vec![0.0; self.n_cstr]);
let cstr_tol = Array1::from_vec(cstr_tol);

let mixintegor = egobox_ego::EgorServiceBuilder::optimize()
.configure(|config| {
let mut config = config
.n_cstr(self.n_cstr)
.n_start(self.n_start)
.doe(&doe)
.cstr_tol(&cstr_tol)
.regression_spec(
egobox_moe::RegressionSpec::from_bits(self.regression_spec.0).unwrap(),
)
.correlation_spec(
egobox_moe::CorrelationSpec::from_bits(self.correlation_spec.0).unwrap(),
)
.infill_strategy(infill_strategy)
.q_points(self.q_points)
.qei_strategy(qei_strategy)
.infill_optimizer(infill_optimizer)
.target(self.target)
.hot_start(false); // when used as a service no hotstart
if let Some(doe) = self.doe.as_ref() {
config = config.doe(doe);
};
if let Some(kpls_dim) = self.kpls_dim {
config = config.kpls_dim(kpls_dim);
};
if let Some(n_clusters) = self.n_clusters {
config = config.n_clusters(n_clusters);
};
if let Some(outdir) = self.outdir.as_ref().cloned() {
config = config.outdir(outdir);
};
if let Some(seed) = self.seed {
config = config.random_seed(seed);
};
config
})
.min_within_mixint_space(&xtypes);

let x_suggested = py.allow_threads(|| mixintegor.suggest(&x_doe, &y_doe));
x_suggested.into_pyarray(py).to_owned()
}
}

0 comments on commit 993e4a6

Please sign in to comment.