Skip to content

Commit aba52f7

Browse files
authored
Improve hot start API (#199)
* Improve hot start API * Fix visibility * Use HotStartMode enum in EgorConfig
1 parent 625fdd4 commit aba52f7

File tree

5 files changed

+101
-19
lines changed

5 files changed

+101
-19
lines changed

ego/src/egor.rs

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//! * Trust-region EGO optional activation
77
//! * Infill criteria: EI, WB2, WB2S
88
//! * Multi-point infill strategy (aka qEI)
9+
//! * Warm/hot start
910
//!
1011
//! See refences below.
1112
//!
@@ -103,6 +104,7 @@ use crate::gpmix::mixint::*;
103104
use crate::types::*;
104105
use crate::EgorConfig;
105106
use crate::EgorState;
107+
use crate::HotStartMode;
106108
use crate::{to_xtypes, EgorSolver};
107109
use crate::{CheckpointingFrequency, HotStartCheckpoint};
108110

@@ -213,12 +215,12 @@ impl<O: GroupFunc, SB: SurrogateBuilder + DeserializeOwned> Egor<O, SB> {
213215

214216
let exec = Executor::new(self.fobj.clone(), self.solver.clone());
215217

216-
let exec = if let Some(ext_iters) = self.solver.config.hot_start {
218+
let exec = if self.solver.config.hot_start != HotStartMode::Disabled {
217219
let checkpoint = HotStartCheckpoint::new(
218220
".checkpoints",
219221
"egor",
220222
CheckpointingFrequency::Always,
221-
ext_iters,
223+
self.solver.config.hot_start.clone(),
222224
);
223225
exec.checkpointing(checkpoint)
224226
} else {
@@ -423,7 +425,12 @@ mod tests {
423425
let _ = std::fs::remove_file(".checkpoints/egor.arg");
424426
let n_iter = 1;
425427
let res = EgorBuilder::optimize(xsinx)
426-
.configure(|config| config.max_iters(n_iter).seed(42).hot_start(Some(0)))
428+
.configure(|config| {
429+
config
430+
.max_iters(n_iter)
431+
.seed(42)
432+
.hot_start(HotStartMode::Enabled)
433+
})
427434
.min_within(&array![[0.0, 25.0]])
428435
.run()
429436
.expect("Egor should minimize");
@@ -432,7 +439,12 @@ mod tests {
432439

433440
// without hostart we reach the same point
434441
let res = EgorBuilder::optimize(xsinx)
435-
.configure(|config| config.max_iters(n_iter).seed(42).hot_start(None))
442+
.configure(|config| {
443+
config
444+
.max_iters(n_iter)
445+
.seed(42)
446+
.hot_start(HotStartMode::Disabled)
447+
})
436448
.min_within(&array![[0.0, 25.0]])
437449
.run()
438450
.expect("Egor should minimize");
@@ -442,7 +454,11 @@ mod tests {
442454
// with hot start we continue
443455
let ext_iters = 3;
444456
let res = EgorBuilder::optimize(xsinx)
445-
.configure(|config| config.seed(42).hot_start(Some(ext_iters)))
457+
.configure(|config| {
458+
config
459+
.seed(42)
460+
.hot_start(HotStartMode::ExtendedIters(ext_iters))
461+
})
446462
.min_within(&array![[0.0, 25.0]])
447463
.run()
448464
.expect("Egor should minimize");

ego/src/lib.rs

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
//! * specify the initial doe,
1212
//! * parameterize internal optimization,
1313
//! * parameterize mixture of experts,
14-
//! * save intermediate results and allow warm restart,
14+
//! * save intermediate results and allow warm/hot restart,
15+
//! * handling of mixed-integer variables
16+
//! * activation of TREGO algorithm variation
1517
//!
1618
//! # Examples
1719
//!
@@ -149,13 +151,47 @@
149151
//! In the above example all GP with combinations of regression and correlation will be tested and the best combination for
150152
//! each modeled function will be retained. You can also simply specify `RegressionSpec::ALL` and `CorrelationSpec::ALL` to
151153
//! test all available combinations but remember that the more you test the slower it runs.
154+
//!
155+
//! * the TREGO algorithm described in \[[Diouane2023](#Diouane2023)\] can be activated
156+
//!
157+
//! ```no_run
158+
//! # use egobox_ego::{EgorConfig, RegressionSpec, CorrelationSpec};
159+
//! # let egor_config = EgorConfig::default();
160+
//! egor_config.trego(true);
161+
//! ```
162+
//!
163+
//! * Intermediate results can be logged at each iteration when `outdir` directory is specified.
164+
//! The following files :
165+
//! * egor_config.json: Egor configuration,
166+
//! * egor_initial_doe.npy: initial DOE (x, y) as numpy array,
167+
//! * egor_doe.npy: DOE (x, y) as numpy array,
168+
//! * egor_history.npy: best (x, y) wrt to iteration number as (n_iters, nx + ny) numpy array
152169
//!
170+
//! ```no_run
171+
//! # use egobox_ego::EgorConfig;
172+
//! # let egor_config = EgorConfig::default();
173+
//! egor_config.outdir("./.output");
174+
//! ```
175+
//! If warm_start is set to `true`, the algorithm starts from the saved `egor_doe.npy`
176+
//!
177+
//! * Hot start checkpointing can be enabled with `hot_start` option specifying a number of
178+
//! extra iterations beyond max iters. This mechanism allows to restart after an interruption
179+
//! from the last saved checkpoint. While warm_start restart from saved doe for another max_iters
180+
//! iterations, hot start allows to continue from the last saved optimizer state till max_iters
181+
//! is reached with optinal extra iterations.
182+
//!
183+
//! ```no_run
184+
//! # use egobox_ego::{EgorConfig, HotStartMode};
185+
//! # let egor_config = EgorConfig::default();
186+
//! egor_config.hot_start(HotStartMode::Enabled);
187+
//! ```
188+
//!
153189
//! # Implementation notes
154190
//!
155191
//! * Mixture of experts and PLS dimension reduction is explained in \[[Bartoli2019](#Bartoli2019)\]
156192
//! * Parallel optimization is available through the selection of a qei strategy. See in \[[Ginsbourger2010](#Ginsbourger2010)\]
157193
//! * Mixed integer approach is implemented using continuous relaxation. See \[[Garrido2018](#Garrido2018)\]
158-
//! * TREGO algorithm is enabled by default. See \[[Diouane2023](#Diouane2023)\]
194+
//! * TREGO algorithm is not enabled by default. See \[[Diouane2023](#Diouane2023)\]
159195
//!
160196
//! # References
161197
//!
@@ -210,7 +246,7 @@ pub use crate::gpmix::spec::{CorrelationSpec, RegressionSpec};
210246
pub use crate::solver::*;
211247
pub use crate::types::*;
212248
pub use crate::utils::{
213-
find_best_result_index, Checkpoint, CheckpointingFrequency, HotStartCheckpoint,
249+
find_best_result_index, Checkpoint, CheckpointingFrequency, HotStartCheckpoint, HotStartMode,
214250
};
215251

216252
mod optimizers;

ego/src/solver/egor_config.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! Egor optimizer configuration.
22
use crate::criteria::*;
33
use crate::types::*;
4+
use crate::HotStartMode;
45
use egobox_moe::{CorrelationSpec, RegressionSpec};
56
use ndarray::Array1;
67
use ndarray::Array2;
@@ -81,7 +82,7 @@ pub struct EgorConfig {
8182
/// If true use `outdir` to retrieve and start from previous results
8283
pub(crate) warm_start: bool,
8384
/// If some enable checkpointing allowing to restart for given ext_iters number of iteration from last checkpointed iteration
84-
pub(crate) hot_start: Option<u64>,
85+
pub(crate) hot_start: HotStartMode,
8586
/// List of x types allowing the handling of discrete input variables
8687
pub(crate) xtypes: Vec<XType>,
8788
/// A random generator seed used to get reproductible results.
@@ -111,7 +112,7 @@ impl Default for EgorConfig {
111112
target: f64::NEG_INFINITY,
112113
outdir: None,
113114
warm_start: false,
114-
hot_start: None,
115+
hot_start: HotStartMode::Disabled,
115116
xtypes: vec![],
116117
seed: None,
117118
trego: TregoConfig::default(),
@@ -269,7 +270,7 @@ impl EgorConfig {
269270
}
270271

271272
/// Whether checkpointing is enabled allowing hot start from previous checkpointed iteration if any
272-
pub fn hot_start(mut self, hot_start: Option<u64>) -> Self {
273+
pub fn hot_start(mut self, hot_start: HotStartMode) -> Self {
273274
self.hot_start = hot_start;
274275
self
275276
}

ego/src/utils/hot_start.rs

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,44 @@
11
pub use argmin::core::checkpointing::{Checkpoint, CheckpointingFrequency};
22
use argmin::core::Error;
3-
use serde::{de::DeserializeOwned, Serialize};
3+
use serde::{de::DeserializeOwned, Deserialize, Serialize};
44
use std::fs::File;
55
use std::io::{BufReader, BufWriter};
66
use std::path::PathBuf;
77

88
use crate::EgorState;
99

10+
/// An enum to specify hot start mode
11+
#[derive(Clone, Eq, PartialEq, Debug, Hash, Default, Serialize, Deserialize)]
12+
pub enum HotStartMode {
13+
/// Hot start checkpoints are not saved
14+
#[default]
15+
Disabled,
16+
/// Hot start checkpoints are saved and optionally used if it already exists
17+
Enabled,
18+
/// Hot start checkpoints are saved and optionally used if it already exists
19+
/// and optimization is run with an extended iteration budget
20+
ExtendedIters(u64),
21+
}
22+
23+
impl std::convert::From<Option<u64>> for HotStartMode {
24+
fn from(value: Option<u64>) -> Self {
25+
if let Some(ext_iters) = value {
26+
if ext_iters == 0 {
27+
HotStartMode::Enabled
28+
} else {
29+
HotStartMode::ExtendedIters(ext_iters)
30+
}
31+
} else {
32+
HotStartMode::Disabled
33+
}
34+
}
35+
}
36+
1037
/// Handles saving a checkpoint to disk as a binary file.
1138
#[derive(Clone, Eq, PartialEq, Debug, Hash)]
1239
pub struct HotStartCheckpoint {
1340
/// Extended iteration number
14-
pub extension_iters: u64,
41+
pub mode: HotStartMode,
1542
/// Indicates how often a checkpoint is created
1643
pub frequency: CheckpointingFrequency,
1744
/// Directory where the checkpoints are saved to
@@ -24,7 +51,7 @@ impl Default for HotStartCheckpoint {
2451
/// Create a default `HotStartCheckpoint` instance.
2552
fn default() -> HotStartCheckpoint {
2653
HotStartCheckpoint {
27-
extension_iters: 0,
54+
mode: HotStartMode::default(),
2855
frequency: CheckpointingFrequency::default(),
2956
directory: PathBuf::from(".checkpoints"),
3057
filename: PathBuf::from("egor.arg"),
@@ -38,10 +65,10 @@ impl HotStartCheckpoint {
3865
directory: N,
3966
name: N,
4067
frequency: CheckpointingFrequency,
41-
ext_iters: u64,
68+
ext_iters: HotStartMode,
4269
) -> Self {
4370
HotStartCheckpoint {
44-
extension_iters: ext_iters,
71+
mode: ext_iters,
4572
frequency,
4673
directory: PathBuf::from(directory.as_ref()),
4774
filename: PathBuf::from(format!("{}.arg", name.as_ref())),
@@ -81,7 +108,9 @@ where
81108
let file = File::open(path)?;
82109
let reader = BufReader::new(file);
83110
let (solver, mut state): (_, EgorState<_>) = bincode::deserialize_from(reader)?;
84-
state.extend_max_iters(self.extension_iters);
111+
if let HotStartMode::ExtendedIters(n_iters) = self.mode {
112+
state.extend_max_iters(n_iters);
113+
}
85114
Ok(Some((solver, state)))
86115
}
87116

src/egor.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ pub(crate) fn to_specs(py: Python, xlimits: Vec<Vec<f64>>) -> PyResult<PyObject>
138138
/// warm_start (bool)
139139
/// Start by loading initial doe from <outdir> directory
140140
///
141-
/// hot_start (int or None)
141+
/// hot_start (int >= 0 or None)
142142
/// When hot_start>=0 saves optimizer state at each iteration and starts from a previous checkpoint
143143
/// if any for the given hot_start number of iterations beyond the max_iters nb of iterations.
144144
/// In an unstable environment were there can be crashes it allows to restart the optimization
@@ -476,7 +476,7 @@ impl Egor {
476476
.n_optmod(self.n_optmod)
477477
.target(self.target)
478478
.warm_start(self.warm_start)
479-
.hot_start(self.hot_start);
479+
.hot_start(self.hot_start.into());
480480
if let Some(doe) = doe {
481481
config = config.doe(doe);
482482
};

0 commit comments

Comments
 (0)