Improve hot start API (#199)
* Improve hot start API

* Fix visibility

* Use HotStartMode enum in EgorConfig
relf authored Sep 30, 2024
1 parent 625fdd4 commit aba52f7
Showing 5 changed files with 101 additions and 19 deletions.
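At a glance, this commit replaces the `Option<u64>`-based `hot_start` setting with an explicit `HotStartMode` enum. A minimal before/after sketch of a caller migration, using only names visible in the diffs below:

```rust
use egobox_ego::{EgorConfig, HotStartMode};

fn main() {
    // Before this commit: config.hot_start(Some(3))
    // After: the intent is spelled out by an explicit mode.
    let _config = EgorConfig::default()
        // resume from the last checkpoint, then run 3 extra iterations
        .hot_start(HotStartMode::ExtendedIters(3));
}
```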
26 changes: 21 additions & 5 deletions ego/src/egor.rs
@@ -6,6 +6,7 @@
 //! * Trust-region EGO optional activation
 //! * Infill criteria: EI, WB2, WB2S
 //! * Multi-point infill strategy (aka qEI)
+//! * Warm/hot start
 //!
 //! See references below.
 //!
@@ -103,6 +104,7 @@ use crate::gpmix::mixint::*;
 use crate::types::*;
 use crate::EgorConfig;
 use crate::EgorState;
+use crate::HotStartMode;
 use crate::{to_xtypes, EgorSolver};
 use crate::{CheckpointingFrequency, HotStartCheckpoint};

@@ -213,12 +215,12 @@ impl<O: GroupFunc, SB: SurrogateBuilder + DeserializeOwned> Egor<O, SB> {

         let exec = Executor::new(self.fobj.clone(), self.solver.clone());

-        let exec = if let Some(ext_iters) = self.solver.config.hot_start {
+        let exec = if self.solver.config.hot_start != HotStartMode::Disabled {
             let checkpoint = HotStartCheckpoint::new(
                 ".checkpoints",
                 "egor",
                 CheckpointingFrequency::Always,
-                ext_iters,
+                self.solver.config.hot_start.clone(),
             );
             exec.checkpointing(checkpoint)
         } else {
@@ -423,7 +425,12 @@ mod tests {
         let _ = std::fs::remove_file(".checkpoints/egor.arg");
         let n_iter = 1;
         let res = EgorBuilder::optimize(xsinx)
-            .configure(|config| config.max_iters(n_iter).seed(42).hot_start(Some(0)))
+            .configure(|config| {
+                config
+                    .max_iters(n_iter)
+                    .seed(42)
+                    .hot_start(HotStartMode::Enabled)
+            })
             .min_within(&array![[0.0, 25.0]])
             .run()
             .expect("Egor should minimize");
@@ -432,7 +439,12 @@

         // without hot start we reach the same point
         let res = EgorBuilder::optimize(xsinx)
-            .configure(|config| config.max_iters(n_iter).seed(42).hot_start(None))
+            .configure(|config| {
+                config
+                    .max_iters(n_iter)
+                    .seed(42)
+                    .hot_start(HotStartMode::Disabled)
+            })
             .min_within(&array![[0.0, 25.0]])
             .run()
             .expect("Egor should minimize");
@@ -442,7 +454,11 @@
         // with hot start we continue
         let ext_iters = 3;
         let res = EgorBuilder::optimize(xsinx)
-            .configure(|config| config.seed(42).hot_start(Some(ext_iters)))
+            .configure(|config| {
+                config
+                    .seed(42)
+                    .hot_start(HotStartMode::ExtendedIters(ext_iters))
+            })
             .min_within(&array![[0.0, 25.0]])
             .run()
             .expect("Egor should minimize");
42 changes: 39 additions & 3 deletions ego/src/lib.rs
@@ -11,7 +11,9 @@
 //! * specify the initial DOE,
 //! * parameterize internal optimization,
 //! * parameterize mixture of experts,
-//! * save intermediate results and allow warm restart,
+//! * save intermediate results and allow warm/hot restart,
+//! * handle mixed-integer variables,
+//! * activate the TREGO algorithm variation.
 //!
 //! # Examples
 //!
@@ -149,13 +151,47 @@
 //! In the above example, all GP models with combinations of regression and correlation will be tested and the best combination for
 //! each modeled function will be retained. You can also simply specify `RegressionSpec::ALL` and `CorrelationSpec::ALL` to
 //! test all available combinations, but remember that the more you test the slower it runs.
+//!
+//! * the TREGO algorithm described in \[[Diouane2023](#Diouane2023)\] can be activated
+//!
+//! ```no_run
+//! # use egobox_ego::{EgorConfig, RegressionSpec, CorrelationSpec};
+//! # let egor_config = EgorConfig::default();
+//! egor_config.trego(true);
+//! ```
+//!
+//! * Intermediate results can be logged at each iteration when the `outdir` directory is specified.
+//!   The following files are written:
+//!   * egor_config.json: the Egor configuration,
+//!   * egor_initial_doe.npy: the initial DOE (x, y) as a numpy array,
+//!   * egor_doe.npy: the DOE (x, y) as a numpy array,
+//!   * egor_history.npy: the best (x, y) per iteration, as an (n_iters, nx + ny) numpy array.
+//!
+//! ```no_run
+//! # use egobox_ego::EgorConfig;
+//! # let egor_config = EgorConfig::default();
+//! egor_config.outdir("./.output");
+//! ```
+//! If `warm_start` is set to `true`, the algorithm starts from the saved `egor_doe.npy`.
+//!
+//! * Hot start checkpointing can be enabled with the `hot_start` option, optionally specifying a number of
+//!   extra iterations beyond `max_iters`. This mechanism allows restarting from the last saved checkpoint
+//!   after an interruption. While `warm_start` restarts from the saved DOE for another `max_iters` iterations,
+//!   hot start continues from the last saved optimizer state until `max_iters` is reached, plus the
+//!   optional extra iterations.
+//!
+//! ```no_run
+//! # use egobox_ego::{EgorConfig, HotStartMode};
+//! # let egor_config = EgorConfig::default();
+//! egor_config.hot_start(HotStartMode::Enabled);
+//! ```
 //!
 //! # Implementation notes
 //!
 //! * Mixture of experts and PLS dimension reduction is explained in \[[Bartoli2019](#Bartoli2019)\]
 //! * Parallel optimization is available through the selection of a qEI strategy. See \[[Ginsbourger2010](#Ginsbourger2010)\]
 //! * The mixed-integer approach is implemented using continuous relaxation. See \[[Garrido2018](#Garrido2018)\]
-//! * TREGO algorithm is enabled by default. See \[[Diouane2023](#Diouane2023)\]
+//! * The TREGO algorithm is not enabled by default. See \[[Diouane2023](#Diouane2023)\]
 //!
 //! # References
 //!
@@ -210,7 +246,7 @@ pub use crate::gpmix::spec::{CorrelationSpec, RegressionSpec};
 pub use crate::solver::*;
 pub use crate::types::*;
 pub use crate::utils::{
-    find_best_result_index, Checkpoint, CheckpointingFrequency, HotStartCheckpoint,
+    find_best_result_index, Checkpoint, CheckpointingFrequency, HotStartCheckpoint, HotStartMode,
 };

 mod optimizers;
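To contrast the two restart flavors documented above, here is a hedged sketch: warm start reloads the saved DOE from `outdir`, while hot start resumes from the checkpointed optimizer state. Method names are those used in this diff; the output directory is illustrative:

```rust
use egobox_ego::{EgorConfig, HotStartMode};

fn main() {
    // Warm start: reload `egor_doe.npy` from `outdir` and run another `max_iters` iterations.
    let _warm = EgorConfig::default()
        .outdir("./.output")
        .warm_start(true);

    // Hot start: resume from the last saved optimizer state; `ExtendedIters(n)`
    // continues for `n` iterations beyond `max_iters`.
    let _hot = EgorConfig::default()
        .hot_start(HotStartMode::ExtendedIters(3));
}
```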
7 changes: 4 additions & 3 deletions ego/src/solver/egor_config.rs
@@ -1,6 +1,7 @@
 //! Egor optimizer configuration.
 use crate::criteria::*;
 use crate::types::*;
+use crate::HotStartMode;
 use egobox_moe::{CorrelationSpec, RegressionSpec};
 use ndarray::Array1;
 use ndarray::Array2;
@@ -81,7 +82,7 @@ pub struct EgorConfig {
     /// If true, use `outdir` to retrieve and start from previous results
     pub(crate) warm_start: bool,
     /// Hot start mode: when not disabled, enables checkpointing, allowing a restart from the last checkpointed iteration
-    pub(crate) hot_start: Option<u64>,
+    pub(crate) hot_start: HotStartMode,
     /// List of x types allowing the handling of discrete input variables
     pub(crate) xtypes: Vec<XType>,
     /// A random generator seed used to get reproducible results.
@@ -111,7 +112,7 @@ impl Default for EgorConfig {
             target: f64::NEG_INFINITY,
             outdir: None,
             warm_start: false,
-            hot_start: None,
+            hot_start: HotStartMode::Disabled,
             xtypes: vec![],
             seed: None,
             trego: TregoConfig::default(),
@@ -269,7 +270,7 @@ impl EgorConfig {
     }

     /// Whether checkpointing is enabled, allowing a hot start from a previous checkpointed iteration, if any
-    pub fn hot_start(mut self, hot_start: Option<u64>) -> Self {
+    pub fn hot_start(mut self, hot_start: HotStartMode) -> Self {
         self.hot_start = hot_start;
         self
     }
41 changes: 35 additions & 6 deletions ego/src/utils/hot_start.rs
@@ -1,17 +1,44 @@
 pub use argmin::core::checkpointing::{Checkpoint, CheckpointingFrequency};
 use argmin::core::Error;
-use serde::{de::DeserializeOwned, Serialize};
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use std::fs::File;
 use std::io::{BufReader, BufWriter};
 use std::path::PathBuf;

 use crate::EgorState;

+/// An enum to specify the hot start mode
+#[derive(Clone, Eq, PartialEq, Debug, Hash, Default, Serialize, Deserialize)]
+pub enum HotStartMode {
+    /// Hot start checkpoints are not saved
+    #[default]
+    Disabled,
+    /// Hot start checkpoints are saved and used if one already exists
+    Enabled,
+    /// Hot start checkpoints are saved and used if one already exists,
+    /// and the optimization is run with an extended iteration budget
+    ExtendedIters(u64),
+}
+
+impl std::convert::From<Option<u64>> for HotStartMode {
+    fn from(value: Option<u64>) -> Self {
+        if let Some(ext_iters) = value {
+            if ext_iters == 0 {
+                HotStartMode::Enabled
+            } else {
+                HotStartMode::ExtendedIters(ext_iters)
+            }
+        } else {
+            HotStartMode::Disabled
+        }
+    }
+}
+
 /// Handles saving a checkpoint to disk as a binary file.
 #[derive(Clone, Eq, PartialEq, Debug, Hash)]
 pub struct HotStartCheckpoint {
     /// Extended iteration number
-    pub extension_iters: u64,
+    pub mode: HotStartMode,
     /// Indicates how often a checkpoint is created
     pub frequency: CheckpointingFrequency,
     /// Directory where the checkpoints are saved to
@@ -24,7 +51,7 @@ impl Default for HotStartCheckpoint {
     /// Create a default `HotStartCheckpoint` instance.
     fn default() -> HotStartCheckpoint {
         HotStartCheckpoint {
-            extension_iters: 0,
+            mode: HotStartMode::default(),
             frequency: CheckpointingFrequency::default(),
             directory: PathBuf::from(".checkpoints"),
             filename: PathBuf::from("egor.arg"),
@@ -38,10 +65,10 @@ impl HotStartCheckpoint {
         directory: N,
         name: N,
         frequency: CheckpointingFrequency,
-        ext_iters: u64,
+        ext_iters: HotStartMode,
     ) -> Self {
         HotStartCheckpoint {
-            extension_iters: ext_iters,
+            mode: ext_iters,
             frequency,
             directory: PathBuf::from(directory.as_ref()),
             filename: PathBuf::from(format!("{}.arg", name.as_ref())),
@@ -81,7 +108,9 @@ where
         let file = File::open(path)?;
         let reader = BufReader::new(file);
         let (solver, mut state): (_, EgorState<_>) = bincode::deserialize_from(reader)?;
-        state.extend_max_iters(self.extension_iters);
+        if let HotStartMode::ExtendedIters(n_iters) = self.mode {
+            state.extend_max_iters(n_iters);
+        }
         Ok(Some((solver, state)))

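The `From<Option<u64>>` implementation above keeps old `Option<u64>` call sites working; the Python binding below relies on it via `self.hot_start.into()`. A small sketch of the mapping, assuming the exports shown in this commit:

```rust
use egobox_ego::HotStartMode;

fn main() {
    // None -> Disabled, Some(0) -> Enabled, Some(n > 0) -> ExtendedIters(n)
    assert_eq!(HotStartMode::from(None::<u64>), HotStartMode::Disabled);
    assert_eq!(HotStartMode::from(Some(0)), HotStartMode::Enabled);
    assert_eq!(HotStartMode::from(Some(3)), HotStartMode::ExtendedIters(3));
}
```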
4 changes: 2 additions & 2 deletions src/egor.rs
@@ -138,7 +138,7 @@ pub(crate) fn to_specs(py: Python, xlimits: Vec<Vec<f64>>) -> PyResult<PyObject>
 ///     warm_start (bool)
 ///         Start by loading the initial DOE from the <outdir> directory
 ///
-///     hot_start (int or None)
+///     hot_start (int >= 0 or None)
 ///         When hot_start >= 0, saves the optimizer state at each iteration and starts from a previous checkpoint,
 ///         if any, for the given hot_start number of iterations beyond the max_iters number of iterations.
 ///         In an unstable environment where there can be crashes, it allows restarting the optimization
@@ -476,7 +476,7 @@ impl Egor {
             .n_optmod(self.n_optmod)
             .target(self.target)
             .warm_start(self.warm_start)
-            .hot_start(self.hot_start);
+            .hot_start(self.hot_start.into());
         if let Some(doe) = doe {
             config = config.doe(doe);
         };
