Skip to content

Commit ebdde79

Browse files
authored
Add checkpointing and hot start (#197)
* Implement hot_start using argmin checkpointing * Make it work and validate hot_start * Improve py documentation * Restore trego default config
1 parent d5e0617 commit ebdde79

13 files changed

+201
-12
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ input.txt
2525
output.txt
2626
mopta08.exe
2727
mopta08_elf64.bin
28+
**/.checkpoints
2829

2930
# JOSS
3031
joss/paper.jats
3132
joss/paper.pdf
33+

ego/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ nlopt = { version = "0.7.0", optional = true }
4141

4242
rand_xoshiro = { version = "0.6", features = ["serde1"] }
4343
argmin = { version = "0.10.0", features = ["serde1", "ctrlc"] }
44+
bincode = { version = "1.3.0" }
4445
web-time = "1.1.0"
4546
libm = "0.2.6"
4647
finitediff = { version = "0.1", features = ["ndarray"] }

ego/src/egor.rs

+54-2
Original file line numberDiff line numberDiff line change
@@ -104,15 +104,18 @@ use crate::types::*;
104104
use crate::EgorConfig;
105105
use crate::EgorState;
106106
use crate::{to_xtypes, EgorSolver};
107+
use crate::{CheckpointingFrequency, HotStartCheckpoint};
107108

108109
use argmin::core::observers::ObserverMode;
110+
109111
use egobox_moe::GpMixtureParams;
110112
use log::info;
111113
use ndarray::{concatenate, Array2, ArrayBase, Axis, Data, Ix2};
112114
use ndarray_rand::rand::SeedableRng;
113115
use rand_xoshiro::Xoshiro256Plus;
114116

115117
use argmin::core::{observers::Observe, Error, Executor, State, KV};
118+
use serde::de::DeserializeOwned;
116119

117120
/// Json filename for configuration
118121
pub const CONFIG_FILE: &str = "egor_config.json";
@@ -191,12 +194,12 @@ impl<O: GroupFunc> EgorBuilder<O> {
191194
/// Egor optimizer structure used to parameterize the underlying `argmin::Solver`
192195
/// and trigger the optimization using `argmin::Executor`.
193196
#[derive(Clone)]
194-
pub struct Egor<O: GroupFunc, SB: SurrogateBuilder> {
197+
pub struct Egor<O: GroupFunc, SB: SurrogateBuilder + DeserializeOwned> {
195198
fobj: ObjFunc<O>,
196199
solver: EgorSolver<SB>,
197200
}
198201

199-
impl<O: GroupFunc, SB: SurrogateBuilder> Egor<O, SB> {
202+
impl<O: GroupFunc, SB: SurrogateBuilder + DeserializeOwned> Egor<O, SB> {
200203
/// Runs the (constrained) optimization of the objective function.
201204
pub fn run(&self) -> Result<OptimResult<f64>> {
202205
let xtypes = self.solver.config.xtypes.clone();
@@ -209,12 +212,26 @@ impl<O: GroupFunc, SB: SurrogateBuilder> Egor<O, SB> {
209212
}
210213

211214
let exec = Executor::new(self.fobj.clone(), self.solver.clone());
215+
216+
let exec = if let Some(ext_iters) = self.solver.config.hot_start {
217+
let checkpoint = HotStartCheckpoint::new(
218+
".checkpoints",
219+
"egor",
220+
CheckpointingFrequency::Always,
221+
ext_iters,
222+
);
223+
exec.checkpointing(checkpoint)
224+
} else {
225+
exec
226+
};
227+
212228
let result = if let Some(outdir) = self.solver.config.outdir.as_ref() {
213229
let hist = OptimizationObserver::new(outdir.clone());
214230
exec.add_observer(hist, ObserverMode::Always).run()?
215231
} else {
216232
exec.run()?
217233
};
234+
218235
info!("{}", result);
219236
let (x_data, y_data) = result.state().clone().take_data().unwrap();
220237

@@ -399,6 +416,41 @@ mod tests {
399416
assert_abs_diff_eq!(expected, res.x_opt, epsilon = 1e-1);
400417
}
401418

419+
#[test]
420+
#[serial]
421+
fn test_xsinx_checkpoint_egor() {
422+
let _ = std::fs::remove_file(".checkpoints/egor.arg");
423+
let n_iter = 1;
424+
let res = EgorBuilder::optimize(xsinx)
425+
.configure(|config| config.max_iters(n_iter).seed(42).hot_start(Some(0)))
426+
.min_within(&array![[0.0, 25.0]])
427+
.run()
428+
.expect("Egor should minimize");
429+
let expected = array![19.1];
430+
assert_abs_diff_eq!(expected, res.x_opt, epsilon = 1e-1);
431+
432+
// without hot start we reach the same point
433+
let res = EgorBuilder::optimize(xsinx)
434+
.configure(|config| config.max_iters(n_iter).seed(42).hot_start(None))
435+
.min_within(&array![[0.0, 25.0]])
436+
.run()
437+
.expect("Egor should minimize");
438+
let expected = array![19.1];
439+
assert_abs_diff_eq!(expected, res.x_opt, epsilon = 1e-1);
440+
441+
// with hot start we continue
442+
let ext_iters = 3;
443+
let res = EgorBuilder::optimize(xsinx)
444+
.configure(|config| config.seed(42).hot_start(Some(ext_iters)))
445+
.min_within(&array![[0.0, 25.0]])
446+
.run()
447+
.expect("Egor should minimize");
448+
let expected = array![18.9];
449+
assert_abs_diff_eq!(expected, res.x_opt, epsilon = 1e-1);
450+
assert_eq!(n_iter as u64 + ext_iters, res.state.get_iter());
451+
let _ = std::fs::remove_file(".checkpoints/egor.arg");
452+
}
453+
402454
#[test]
403455
#[serial]
404456
fn test_xsinx_auto_clustering_egor_builder() {

ego/src/lib.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,9 @@ pub use crate::errors::*;
209209
pub use crate::gpmix::spec::{CorrelationSpec, RegressionSpec};
210210
pub use crate::solver::*;
211211
pub use crate::types::*;
212-
pub use crate::utils::find_best_result_index;
212+
pub use crate::utils::{
213+
find_best_result_index, Checkpoint, CheckpointingFrequency, HotStartCheckpoint,
214+
};
213215

214216
mod optimizers;
215217
mod utils;

ego/src/solver/egor_config.rs

+9
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ pub struct EgorConfig {
8080
pub(crate) outdir: Option<String>,
8181
/// If true use `outdir` to retrieve and start from previous results
8282
pub(crate) warm_start: bool,
83+
/// If `Some(ext_iters)`, enable checkpointing and allow restarting from the last checkpointed iteration for `ext_iters` additional iterations
84+
pub(crate) hot_start: Option<u64>,
8385
/// List of x types allowing the handling of discrete input variables
8486
pub(crate) xtypes: Vec<XType>,
8587
/// A random generator seed used to get reproducible results.
@@ -109,6 +111,7 @@ impl Default for EgorConfig {
109111
target: f64::NEG_INFINITY,
110112
outdir: None,
111113
warm_start: false,
114+
hot_start: None,
112115
xtypes: vec![],
113116
seed: None,
114117
trego: TregoConfig::default(),
@@ -265,6 +268,12 @@ impl EgorConfig {
265268
self
266269
}
267270

271+
/// Enables checkpointing when `Some(ext_iters)`, allowing a hot start from the previous checkpointed iteration, if any, extended by `ext_iters` iterations
272+
pub fn hot_start(mut self, hot_start: Option<u64>) -> Self {
273+
self.hot_start = hot_start;
274+
self
275+
}
276+
268277
/// Allow to specify a seed for random number generator to allow
269278
/// reproducible runs.
270279
pub fn seed(mut self, seed: u64) -> Self {

ego/src/solver/egor_impl.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ use ndarray::{
2222
use ndarray_stats::QuantileExt;
2323
use rand_xoshiro::Xoshiro256Plus;
2424
use rayon::prelude::*;
25+
use serde::de::DeserializeOwned;
2526

26-
impl<SB: SurrogateBuilder> EgorSolver<SB> {
27+
impl<SB: SurrogateBuilder + DeserializeOwned> EgorSolver<SB> {
2728
/// Constructor of the optimization of the function `f` with specified random generator
2829
/// to get reproducibility.
2930
///
@@ -80,7 +81,7 @@ impl<SB: SurrogateBuilder> EgorSolver<SB> {
8081

8182
impl<SB> EgorSolver<SB>
8283
where
83-
SB: SurrogateBuilder,
84+
SB: SurrogateBuilder + DeserializeOwned,
8485
{
8586
pub fn have_to_recluster(&self, added: usize, prev_added: usize) -> bool {
8687
self.config.n_clusters == 0 && (added != 0 && added % 10 == 0 && added - prev_added > 0)

ego/src/solver/egor_service.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ use egobox_moe::GpMixtureParams;
4646
use ndarray::{Array2, ArrayBase, Data, Ix2};
4747
use ndarray_rand::rand::SeedableRng;
4848
use rand_xoshiro::Xoshiro256Plus;
49+
use serde::de::DeserializeOwned;
4950

5051
/// EGO optimizer service builder allowing to use Egor optimizer
5152
/// as a service.
@@ -114,11 +115,11 @@ impl EgorServiceBuilder {
114115

115116
/// Egor optimizer service.
116117
#[derive(Clone)]
117-
pub struct EgorService<SB: SurrogateBuilder> {
118+
pub struct EgorService<SB: SurrogateBuilder + DeserializeOwned> {
118119
solver: EgorSolver<SB>,
119120
}
120121

121-
impl<SB: SurrogateBuilder> EgorService<SB> {
122+
impl<SB: SurrogateBuilder + DeserializeOwned> EgorService<SB> {
122123
/// Given an evaluated doe (x, y) data, return the next promising x point
123124
/// where optimum may be located with regard to the infill criterion.
124125
/// This function inverses the control of the optimization and can be used

ego/src/solver/egor_solver.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ use argmin::core::{
120120
};
121121

122122
use rand_xoshiro::Xoshiro256Plus;
123-
use serde::{Deserialize, Serialize};
123+
use serde::{de::DeserializeOwned, Deserialize, Serialize};
124124
use std::time::Instant;
125125

126126
/// Numpy filename for initial DOE dump
@@ -161,7 +161,7 @@ pub fn to_xtypes(xlimits: &ArrayBase<impl Data<Elem = f64>, Ix2>) -> Vec<XType>
161161
impl<O, SB> Solver<O, EgorState<f64>> for EgorSolver<SB>
162162
where
163163
O: CostFunction<Param = Array2<f64>, Output = Array2<f64>>,
164-
SB: SurrogateBuilder,
164+
SB: SurrogateBuilder + DeserializeOwned,
165165
{
166166
const NAME: &'static str = "Egor";
167167

@@ -304,7 +304,7 @@ where
304304

305305
impl<SB> EgorSolver<SB>
306306
where
307-
SB: SurrogateBuilder,
307+
SB: SurrogateBuilder + DeserializeOwned,
308308
{
309309
/// Iteration of EGO algorithm
310310
fn ego_iteration<O: CostFunction<Param = Array2<f64>, Output = Array2<f64>>>(

ego/src/solver/egor_state.rs

+10
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,16 @@ where
289289
}
290290
}
291291

292+
impl<F> EgorState<F>
293+
where
294+
F: Float + ArgminFloat,
295+
{
296+
/// Allow hot start feature by extending current max_iters
297+
pub fn extend_max_iters(&mut self, ext_iters: u64) {
298+
self.max_iters += ext_iters;
299+
}
300+
}
301+
292302
impl<F> State for EgorState<F>
293303
where
294304
F: Float + ArgminFloat,

ego/src/solver/trego.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ use ndarray::Zip;
2121
use ndarray::{s, Array, Array1, Array2, ArrayView1, Axis};
2222

2323
use rayon::prelude::*;
24+
use serde::de::DeserializeOwned;
2425

25-
impl<SB: SurrogateBuilder> EgorSolver<SB> {
26+
impl<SB: SurrogateBuilder + DeserializeOwned> EgorSolver<SB> {
2627
/// Local step where infill criterion is optimized within trust region
2728
pub fn trego_step<O: CostFunction<Param = Array2<f64>, Output = Array2<f64>>>(
2829
&mut self,

ego/src/utils/hot_start.rs

+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
pub use argmin::core::checkpointing::{Checkpoint, CheckpointingFrequency};
2+
use argmin::core::Error;
3+
use serde::{de::DeserializeOwned, Serialize};
4+
use std::fs::File;
5+
use std::io::{BufReader, BufWriter};
6+
use std::path::PathBuf;
7+
8+
use crate::EgorState;
9+
10+
/// Handles saving a checkpoint to disk as a binary file.
11+
#[derive(Clone, Eq, PartialEq, Debug, Hash)]
12+
pub struct HotStartCheckpoint {
13+
/// Number of iterations added to the state's `max_iters` when a checkpoint is loaded
14+
pub extension_iters: u64,
15+
/// Indicates how often a checkpoint is created
16+
pub frequency: CheckpointingFrequency,
17+
/// Directory where the checkpoints are saved to
18+
pub directory: PathBuf,
19+
/// Name of the checkpoint files
20+
pub filename: PathBuf,
21+
}
22+
23+
impl Default for HotStartCheckpoint {
24+
/// Create a default `HotStartCheckpoint` instance.
25+
fn default() -> HotStartCheckpoint {
26+
HotStartCheckpoint {
27+
extension_iters: 0,
28+
frequency: CheckpointingFrequency::default(),
29+
directory: PathBuf::from(".checkpoints"),
30+
filename: PathBuf::from("egor.arg"),
31+
}
32+
}
33+
}
34+
35+
impl HotStartCheckpoint {
36+
/// Create a new `HotStartCheckpoint` instance
37+
pub fn new<N: AsRef<str>>(
38+
directory: N,
39+
name: N,
40+
frequency: CheckpointingFrequency,
41+
ext_iters: u64,
42+
) -> Self {
43+
HotStartCheckpoint {
44+
extension_iters: ext_iters,
45+
frequency,
46+
directory: PathBuf::from(directory.as_ref()),
47+
filename: PathBuf::from(format!("{}.arg", name.as_ref())),
48+
}
49+
}
50+
}
51+
52+
impl<S> Checkpoint<S, EgorState<f64>> for HotStartCheckpoint
53+
where
54+
S: Serialize + DeserializeOwned,
55+
{
56+
/// Writes checkpoint to disk.
57+
///
58+
/// If the directory does not exist already, it will be created. It uses `bincode` to serialize
59+
/// the data.
60+
/// It will return an error if creating the directory or file or serialization failed.
61+
fn save(&self, solver: &S, state: &EgorState<f64>) -> Result<(), Error> {
62+
if !self.directory.exists() {
63+
std::fs::create_dir_all(&self.directory)?
64+
}
65+
let fname = self.directory.join(&self.filename);
66+
let f = BufWriter::new(File::create(fname)?);
67+
bincode::serialize_into(f, &(solver, state))?;
68+
Ok(())
69+
}
70+
71+
/// Load a checkpoint from disk.
72+
///
73+
///
74+
/// If there is no checkpoint on disk, it will return `Ok(None)`.
75+
/// Returns an error if opening the file or deserialization failed.
76+
fn load(&self) -> Result<Option<(S, EgorState<f64>)>, Error> {
77+
let path = &self.directory.join(&self.filename);
78+
if !path.exists() {
79+
return Ok(None);
80+
}
81+
let file = File::open(path)?;
82+
let reader = BufReader::new(file);
83+
let (solver, mut state): (_, EgorState<_>) = bincode::deserialize_from(reader)?;
84+
state.extend_max_iters(self.extension_iters);
85+
Ok(Some((solver, state)))
86+
}
87+
88+
/// Returns how often a checkpoint is to be saved.
89+
///
90+
/// Used internally by [`save_cond`](`argmin::core::checkpointing::Checkpoint::save_cond`).
91+
fn frequency(&self) -> CheckpointingFrequency {
92+
self.frequency
93+
}
94+
}

ego/src/utils/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
mod find_result;
2+
mod hot_start;
23
mod misc;
34
mod sort_axis;
45

56
pub use find_result::*;
7+
pub use hot_start::*;
68
pub use misc::*;

0 commit comments

Comments
 (0)