Skip to content

Commit 3ab6851

Browse files
committed
Refactoring
1 parent fcfc964 commit 3ab6851

File tree

1 file changed

+73
-127
lines changed

1 file changed

+73
-127
lines changed

src/egor.rs

Lines changed: 73 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -255,88 +255,10 @@ impl Egor {
255255
pyarray.to_owned_array()
256256
})
257257
};
258-
259-
let infill_strategy = match self.infill_strategy {
260-
InfillStrategy::Ei => egobox_ego::InfillStrategy::EI,
261-
InfillStrategy::Wb2 => egobox_ego::InfillStrategy::WB2,
262-
InfillStrategy::Wb2s => egobox_ego::InfillStrategy::WB2S,
263-
};
264-
265-
let qei_strategy = match self.par_infill_strategy {
266-
ParInfillStrategy::Kb => egobox_ego::QEiStrategy::KrigingBeliever,
267-
ParInfillStrategy::Kblb => egobox_ego::QEiStrategy::KrigingBelieverLowerBound,
268-
ParInfillStrategy::Kbub => egobox_ego::QEiStrategy::KrigingBelieverUpperBound,
269-
ParInfillStrategy::Clmin => egobox_ego::QEiStrategy::ConstantLiarMinimum,
270-
};
271-
272-
let infill_optimizer = match self.infill_optimizer {
273-
InfillOptimizer::Cobyla => egobox_ego::InfillOptimizer::Cobyla,
274-
InfillOptimizer::Slsqp => egobox_ego::InfillOptimizer::Slsqp,
275-
};
276-
277-
let xspecs: Vec<XSpec> = self.xspecs.extract(py).expect("Error in xspecs conversion");
278-
if xspecs.is_empty() {
279-
panic!("Error: xspecs argument cannot be empty")
280-
}
281-
282-
let xtypes: Vec<egobox_ego::XType> = xspecs
283-
.iter()
284-
.map(|spec| match spec.xtype {
285-
XType::Float => egobox_ego::XType::Cont(spec.xlimits[0], spec.xlimits[1]),
286-
XType::Int => {
287-
egobox_ego::XType::Int(spec.xlimits[0] as i32, spec.xlimits[1] as i32)
288-
}
289-
XType::Ord => egobox_ego::XType::Ord(spec.xlimits.clone()),
290-
XType::Enum => {
291-
if spec.tags.is_empty() {
292-
egobox_ego::XType::Enum(spec.xlimits[0] as usize)
293-
} else {
294-
egobox_ego::XType::Enum(spec.tags.len())
295-
}
296-
}
297-
})
298-
.collect();
299-
300-
let cstr_tol = self.cstr_tol.clone().unwrap_or(vec![0.0; self.n_cstr]);
301-
let cstr_tol = Array1::from_vec(cstr_tol);
258+
let xtypes: Vec<egobox_ego::XType> = self.xtypes(py);
302259

303260
let mixintegor = egobox_ego::EgorBuilder::optimize(obj)
304-
.configure(|config| {
305-
let mut config = config
306-
.n_cstr(self.n_cstr)
307-
.max_iters(max_iters)
308-
.n_start(self.n_start)
309-
.n_doe(self.n_doe)
310-
.cstr_tol(&cstr_tol)
311-
.regression_spec(
312-
egobox_moe::RegressionSpec::from_bits(self.regression_spec.0).unwrap(),
313-
)
314-
.correlation_spec(
315-
egobox_moe::CorrelationSpec::from_bits(self.correlation_spec.0).unwrap(),
316-
)
317-
.infill_strategy(infill_strategy)
318-
.q_points(self.q_points)
319-
.qei_strategy(qei_strategy)
320-
.infill_optimizer(infill_optimizer)
321-
.target(self.target)
322-
.hot_start(self.hot_start);
323-
if let Some(doe) = self.doe.as_ref() {
324-
config = config.doe(doe);
325-
};
326-
if let Some(kpls_dim) = self.kpls_dim {
327-
config = config.kpls_dim(kpls_dim);
328-
};
329-
if let Some(n_clusters) = self.n_clusters {
330-
config = config.n_clusters(n_clusters);
331-
};
332-
if let Some(outdir) = self.outdir.as_ref().cloned() {
333-
config = config.outdir(outdir);
334-
};
335-
if let Some(seed) = self.seed {
336-
config = config.random_seed(seed);
337-
};
338-
config
339-
})
261+
.configure(|config| self.apply_config(config, Some(max_iters), self.doe.as_ref()))
340262
.min_within_mixint_space(&xtypes);
341263

342264
let res = py.allow_threads(|| {
@@ -377,27 +299,44 @@ impl Egor {
377299
) -> Py<PyArray2<f64>> {
378300
let x_doe = x_doe.as_array();
379301
let y_doe = y_doe.as_array();
380-
381302
let doe = concatenate(Axis(1), &[x_doe.view(), y_doe.view()]).unwrap();
303+
let xtypes: Vec<egobox_ego::XType> = self.xtypes(py);
304+
305+
let mixintegor = egobox_ego::EgorServiceBuilder::optimize()
306+
.configure(|config| self.apply_config(config, Some(1), Some(&doe)))
307+
.min_within_mixint_space(&xtypes);
308+
309+
let x_suggested = py.allow_threads(|| mixintegor.suggest(&x_doe, &y_doe));
310+
x_suggested.into_pyarray(py).to_owned()
311+
}
312+
}
382313

383-
let infill_strategy = match self.infill_strategy {
314+
impl Egor {
315+
fn infill_strategy(&self) -> egobox_ego::InfillStrategy {
316+
match self.infill_strategy {
384317
InfillStrategy::Ei => egobox_ego::InfillStrategy::EI,
385318
InfillStrategy::Wb2 => egobox_ego::InfillStrategy::WB2,
386319
InfillStrategy::Wb2s => egobox_ego::InfillStrategy::WB2S,
387-
};
320+
}
321+
}
388322

389-
let qei_strategy = match self.par_infill_strategy {
323+
fn qei_strategy(&self) -> egobox_ego::QEiStrategy {
324+
match self.par_infill_strategy {
390325
ParInfillStrategy::Kb => egobox_ego::QEiStrategy::KrigingBeliever,
391326
ParInfillStrategy::Kblb => egobox_ego::QEiStrategy::KrigingBelieverLowerBound,
392327
ParInfillStrategy::Kbub => egobox_ego::QEiStrategy::KrigingBelieverUpperBound,
393328
ParInfillStrategy::Clmin => egobox_ego::QEiStrategy::ConstantLiarMinimum,
394-
};
329+
}
330+
}
395331

396-
let infill_optimizer = match self.infill_optimizer {
332+
fn infill_optimizer(&self) -> egobox_ego::InfillOptimizer {
333+
match self.infill_optimizer {
397334
InfillOptimizer::Cobyla => egobox_ego::InfillOptimizer::Cobyla,
398335
InfillOptimizer::Slsqp => egobox_ego::InfillOptimizer::Slsqp,
399-
};
336+
}
337+
}
400338

339+
fn xtypes(&self, py: Python) -> Vec<egobox_ego::XType> {
401340
let xspecs: Vec<XSpec> = self.xspecs.extract(py).expect("Error in xspecs conversion");
402341
if xspecs.is_empty() {
403342
panic!("Error: xspecs argument cannot be empty")
@@ -420,49 +359,56 @@ impl Egor {
420359
}
421360
})
422361
.collect();
362+
xtypes
363+
}
423364

365+
fn cstr_tol(&self) -> Array1<f64> {
424366
let cstr_tol = self.cstr_tol.clone().unwrap_or(vec![0.0; self.n_cstr]);
425-
let cstr_tol = Array1::from_vec(cstr_tol);
367+
Array1::from_vec(cstr_tol)
368+
}
426369

427-
let mixintegor = egobox_ego::EgorServiceBuilder::optimize()
428-
.configure(|config| {
429-
let mut config = config
430-
.n_cstr(self.n_cstr)
431-
.n_start(self.n_start)
432-
.doe(&doe)
433-
.cstr_tol(&cstr_tol)
434-
.regression_spec(
435-
egobox_moe::RegressionSpec::from_bits(self.regression_spec.0).unwrap(),
436-
)
437-
.correlation_spec(
438-
egobox_moe::CorrelationSpec::from_bits(self.correlation_spec.0).unwrap(),
439-
)
440-
.infill_strategy(infill_strategy)
441-
.q_points(self.q_points)
442-
.qei_strategy(qei_strategy)
443-
.infill_optimizer(infill_optimizer)
444-
.target(self.target)
445-
.hot_start(false); // when used as a service no hotstart
446-
if let Some(doe) = self.doe.as_ref() {
447-
config = config.doe(doe);
448-
};
449-
if let Some(kpls_dim) = self.kpls_dim {
450-
config = config.kpls_dim(kpls_dim);
451-
};
452-
if let Some(n_clusters) = self.n_clusters {
453-
config = config.n_clusters(n_clusters);
454-
};
455-
if let Some(outdir) = self.outdir.as_ref().cloned() {
456-
config = config.outdir(outdir);
457-
};
458-
if let Some(seed) = self.seed {
459-
config = config.random_seed(seed);
460-
};
461-
config
462-
})
463-
.min_within_mixint_space(&xtypes);
370+
fn apply_config(
371+
&self,
372+
config: egobox_ego::EgorConfig,
373+
max_iters: Option<usize>,
374+
doe: Option<&Array2<f64>>,
375+
) -> egobox_ego::EgorConfig {
376+
let infill_strategy = self.infill_strategy();
377+
let qei_strategy = self.qei_strategy();
378+
let infill_optimizer = self.infill_optimizer();
379+
let cstr_tol = self.cstr_tol();
464380

465-
let x_suggested = py.allow_threads(|| mixintegor.suggest(&x_doe, &y_doe));
466-
x_suggested.into_pyarray(py).to_owned()
381+
let mut config = config
382+
.n_cstr(self.n_cstr)
383+
.max_iters(max_iters.unwrap_or(1))
384+
.n_start(self.n_start)
385+
.n_doe(self.n_doe)
386+
.cstr_tol(&cstr_tol)
387+
.regression_spec(egobox_moe::RegressionSpec::from_bits(self.regression_spec.0).unwrap())
388+
.correlation_spec(
389+
egobox_moe::CorrelationSpec::from_bits(self.correlation_spec.0).unwrap(),
390+
)
391+
.infill_strategy(infill_strategy)
392+
.q_points(self.q_points)
393+
.qei_strategy(qei_strategy)
394+
.infill_optimizer(infill_optimizer)
395+
.target(self.target)
396+
.hot_start(self.hot_start); // when used as a service no hotstart
397+
if let Some(doe) = doe {
398+
config = config.doe(doe);
399+
};
400+
if let Some(kpls_dim) = self.kpls_dim {
401+
config = config.kpls_dim(kpls_dim);
402+
};
403+
if let Some(n_clusters) = self.n_clusters {
404+
config = config.n_clusters(n_clusters);
405+
};
406+
if let Some(outdir) = self.outdir.as_ref().cloned() {
407+
config = config.outdir(outdir);
408+
};
409+
if let Some(seed) = self.seed {
410+
config = config.random_seed(seed);
411+
};
412+
config
467413
}
468414
}

0 commit comments

Comments
 (0)