Skip to content

Commit 993e4a6

Browse files
committed
Make Egor.suggest available in Python
1 parent 36bafbd commit 993e4a6

File tree

2 files changed

+127
-10
lines changed

2 files changed

+127
-10
lines changed

python/egobox/tests/test_egor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,18 @@ def test_constructor(self):
174174
self.assertRaises(TypeError, egx.Egor)
175175
egx.Egor(xsinx, egx.to_specs([[0.0, 25.0]]), 22, n_doe=10)
176176

177+
def test_egor_service(self):
178+
xlimits = egx.to_specs([[0.0, 25.0]])
179+
egor = egx.Egor(xlimits, seed=42)
180+
x_doe = egx.lhs(xlimits, 3, seed=42)
181+
print(x_doe)
182+
y_doe = xsinx(x_doe)
183+
print(y_doe)
184+
for _ in range(10):
185+
x = egor.suggest(x_doe, y_doe)
186+
x_doe = np.concatenate((x_doe, x))
187+
y_doe = np.concatenate((y_doe, xsinx(x)))
188+
177189

178190
if __name__ == "__main__":
179191
unittest.main()

src/egor.rs

Lines changed: 115 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
//!
1212
1313
use crate::types::*;
14-
use ndarray::Array1;
14+
use ndarray::{concatenate, Array1, Axis};
1515
use numpy::ndarray::{Array2, ArrayView2};
1616
use numpy::{IntoPyArray, PyArray1, PyArray2, PyReadonlyArray2};
1717
use pyo3::exceptions::PyValueError;
@@ -133,7 +133,6 @@ pub(crate) fn to_specs(py: Python, xlimits: Vec<Vec<f64>>) -> PyResult<PyObject>
133133
///
134134
#[pyclass]
135135
pub(crate) struct Egor {
136-
pub fun: PyObject,
137136
pub xspecs: PyObject,
138137
pub n_cstr: usize,
139138
pub cstr_tol: Option<Vec<f64>>,
@@ -170,7 +169,6 @@ pub(crate) struct OptimResult {
170169
impl Egor {
171170
#[new]
172171
#[pyo3(signature = (
173-
fun,
174172
xspecs,
175173
n_cstr = 0,
176174
cstr_tol = None,
@@ -192,8 +190,7 @@ impl Egor {
192190
))]
193191
#[allow(clippy::too_many_arguments)]
194192
fn new(
195-
py: Python,
196-
fun: PyObject,
193+
_py: Python,
197194
xspecs: PyObject,
198195
n_cstr: usize,
199196
cstr_tol: Option<Vec<f64>>,
@@ -215,7 +212,6 @@ impl Egor {
215212
) -> Self {
216213
let doe = doe.map(|x| x.to_owned_array());
217214
Egor {
218-
fun: fun.to_object(py),
219215
xspecs,
220216
n_cstr,
221217
cstr_tol,
@@ -248,9 +244,9 @@ impl Egor {
248244
/// x_opt (array[1, nx]): x value where fun is at its minimum subject to constraint
249245
/// y_opt (array[1, nx]): fun(x_opt)
250246
///
251-
#[pyo3(signature = (max_iters = 20))]
252-
fn minimize(&self, py: Python, max_iters: usize) -> PyResult<OptimResult> {
253-
let fun = self.fun.to_object(py);
247+
#[pyo3(signature = (fun, max_iters = 20))]
248+
fn minimize(&self, py: Python, fun: PyObject, max_iters: usize) -> PyResult<OptimResult> {
249+
let fun = fun.to_object(py);
254250
let obj = move |x: &ArrayView2<f64>| -> Array2<f64> {
255251
Python::with_gil(|py| {
256252
let args = (x.to_owned().into_pyarray(py),);
@@ -300,7 +296,6 @@ impl Egor {
300296
}
301297
})
302298
.collect();
303-
println!("{:?}", xtypes);
304299

305300
let cstr_tol = self.cstr_tol.clone().unwrap_or(vec![0.0; self.n_cstr]);
306301
let cstr_tol = Array1::from_vec(cstr_tol);
@@ -360,4 +355,114 @@ impl Egor {
360355
y_hist,
361356
})
362357
}
358+
359+
/// This function gives the next best location where to evaluate the function
360+
/// under optimization wrt to previous evaluations.
361+
/// The function returns several point when multi point qEI strategy is used.
362+
///
363+
/// # Parameters
364+
/// x_doe (array[ns, nx]): ns samples where function has been evaluated
365+
/// y_doe (array[ns, 1 + n_cstr]): ns values of objetcive and constraints
366+
///
367+
///
368+
/// # Returns
369+
/// (array[1, nx]): suggested location where to evaluate objective and constraints
370+
///
371+
#[pyo3(signature = (x_doe, y_doe))]
372+
fn suggest(
373+
&self,
374+
py: Python,
375+
x_doe: PyReadonlyArray2<f64>,
376+
y_doe: PyReadonlyArray2<f64>,
377+
) -> Py<PyArray2<f64>> {
378+
let x_doe = x_doe.as_array();
379+
let y_doe = y_doe.as_array();
380+
381+
let doe = concatenate(Axis(1), &[x_doe.view(), y_doe.view()]).unwrap();
382+
383+
let infill_strategy = match self.infill_strategy {
384+
InfillStrategy::Ei => egobox_ego::InfillStrategy::EI,
385+
InfillStrategy::Wb2 => egobox_ego::InfillStrategy::WB2,
386+
InfillStrategy::Wb2s => egobox_ego::InfillStrategy::WB2S,
387+
};
388+
389+
let qei_strategy = match self.par_infill_strategy {
390+
ParInfillStrategy::Kb => egobox_ego::QEiStrategy::KrigingBeliever,
391+
ParInfillStrategy::Kblb => egobox_ego::QEiStrategy::KrigingBelieverLowerBound,
392+
ParInfillStrategy::Kbub => egobox_ego::QEiStrategy::KrigingBelieverUpperBound,
393+
ParInfillStrategy::Clmin => egobox_ego::QEiStrategy::ConstantLiarMinimum,
394+
};
395+
396+
let infill_optimizer = match self.infill_optimizer {
397+
InfillOptimizer::Cobyla => egobox_ego::InfillOptimizer::Cobyla,
398+
InfillOptimizer::Slsqp => egobox_ego::InfillOptimizer::Slsqp,
399+
};
400+
401+
let xspecs: Vec<XSpec> = self.xspecs.extract(py).expect("Error in xspecs conversion");
402+
if xspecs.is_empty() {
403+
panic!("Error: xspecs argument cannot be empty")
404+
}
405+
406+
let xtypes: Vec<egobox_ego::XType> = xspecs
407+
.iter()
408+
.map(|spec| match spec.xtype {
409+
XType::Float => egobox_ego::XType::Cont(spec.xlimits[0], spec.xlimits[1]),
410+
XType::Int => {
411+
egobox_ego::XType::Int(spec.xlimits[0] as i32, spec.xlimits[1] as i32)
412+
}
413+
XType::Ord => egobox_ego::XType::Ord(spec.xlimits.clone()),
414+
XType::Enum => {
415+
if spec.tags.is_empty() {
416+
egobox_ego::XType::Enum(spec.xlimits[0] as usize)
417+
} else {
418+
egobox_ego::XType::Enum(spec.tags.len())
419+
}
420+
}
421+
})
422+
.collect();
423+
424+
let cstr_tol = self.cstr_tol.clone().unwrap_or(vec![0.0; self.n_cstr]);
425+
let cstr_tol = Array1::from_vec(cstr_tol);
426+
427+
let mixintegor = egobox_ego::EgorServiceBuilder::optimize()
428+
.configure(|config| {
429+
let mut config = config
430+
.n_cstr(self.n_cstr)
431+
.n_start(self.n_start)
432+
.doe(&doe)
433+
.cstr_tol(&cstr_tol)
434+
.regression_spec(
435+
egobox_moe::RegressionSpec::from_bits(self.regression_spec.0).unwrap(),
436+
)
437+
.correlation_spec(
438+
egobox_moe::CorrelationSpec::from_bits(self.correlation_spec.0).unwrap(),
439+
)
440+
.infill_strategy(infill_strategy)
441+
.q_points(self.q_points)
442+
.qei_strategy(qei_strategy)
443+
.infill_optimizer(infill_optimizer)
444+
.target(self.target)
445+
.hot_start(false); // when used as a service no hotstart
446+
if let Some(doe) = self.doe.as_ref() {
447+
config = config.doe(doe);
448+
};
449+
if let Some(kpls_dim) = self.kpls_dim {
450+
config = config.kpls_dim(kpls_dim);
451+
};
452+
if let Some(n_clusters) = self.n_clusters {
453+
config = config.n_clusters(n_clusters);
454+
};
455+
if let Some(outdir) = self.outdir.as_ref().cloned() {
456+
config = config.outdir(outdir);
457+
};
458+
if let Some(seed) = self.seed {
459+
config = config.random_seed(seed);
460+
};
461+
config
462+
})
463+
.min_within_mixint_space(&xtypes);
464+
465+
let x_suggested = py.allow_threads(|| mixintegor.suggest(&x_doe, &y_doe));
466+
x_suggested.into_pyarray(py).to_owned()
467+
}
363468
}

0 commit comments

Comments
 (0)