@@ -104,15 +104,18 @@ use crate::types::*;
104
104
use crate :: EgorConfig ;
105
105
use crate :: EgorState ;
106
106
use crate :: { to_xtypes, EgorSolver } ;
107
+ use crate :: { CheckpointingFrequency , HotStartCheckpoint } ;
107
108
108
109
use argmin:: core:: observers:: ObserverMode ;
110
+
109
111
use egobox_moe:: GpMixtureParams ;
110
112
use log:: info;
111
113
use ndarray:: { concatenate, Array2 , ArrayBase , Axis , Data , Ix2 } ;
112
114
use ndarray_rand:: rand:: SeedableRng ;
113
115
use rand_xoshiro:: Xoshiro256Plus ;
114
116
115
117
use argmin:: core:: { observers:: Observe , Error , Executor , State , KV } ;
118
+ use serde:: de:: DeserializeOwned ;
116
119
117
120
/// Json filename for configuration
118
121
pub const CONFIG_FILE : & str = "egor_config.json" ;
@@ -191,12 +194,12 @@ impl<O: GroupFunc> EgorBuilder<O> {
191
194
/// Egor optimizer structure used to parameterize the underlying `argmin::Solver`
192
195
/// and trigger the optimization using `argmin::Executor`.
193
196
#[ derive( Clone ) ]
194
- pub struct Egor < O : GroupFunc , SB : SurrogateBuilder > {
197
+ pub struct Egor < O : GroupFunc , SB : SurrogateBuilder + DeserializeOwned > {
195
198
fobj : ObjFunc < O > ,
196
199
solver : EgorSolver < SB > ,
197
200
}
198
201
199
- impl < O : GroupFunc , SB : SurrogateBuilder > Egor < O , SB > {
202
+ impl < O : GroupFunc , SB : SurrogateBuilder + DeserializeOwned > Egor < O , SB > {
200
203
/// Runs the (constrained) optimization of the objective function.
201
204
pub fn run ( & self ) -> Result < OptimResult < f64 > > {
202
205
let xtypes = self . solver . config . xtypes . clone ( ) ;
@@ -209,12 +212,26 @@ impl<O: GroupFunc, SB: SurrogateBuilder> Egor<O, SB> {
209
212
}
210
213
211
214
let exec = Executor :: new ( self . fobj . clone ( ) , self . solver . clone ( ) ) ;
215
+
216
+ let exec = if let Some ( ext_iters) = self . solver . config . hot_start {
217
+ let checkpoint = HotStartCheckpoint :: new (
218
+ ".checkpoints" ,
219
+ "egor" ,
220
+ CheckpointingFrequency :: Always ,
221
+ ext_iters,
222
+ ) ;
223
+ exec. checkpointing ( checkpoint)
224
+ } else {
225
+ exec
226
+ } ;
227
+
212
228
let result = if let Some ( outdir) = self . solver . config . outdir . as_ref ( ) {
213
229
let hist = OptimizationObserver :: new ( outdir. clone ( ) ) ;
214
230
exec. add_observer ( hist, ObserverMode :: Always ) . run ( ) ?
215
231
} else {
216
232
exec. run ( ) ?
217
233
} ;
234
+
218
235
info ! ( "{}" , result) ;
219
236
let ( x_data, y_data) = result. state ( ) . clone ( ) . take_data ( ) . unwrap ( ) ;
220
237
@@ -399,6 +416,41 @@ mod tests {
399
416
assert_abs_diff_eq ! ( expected, res. x_opt, epsilon = 1e-1 ) ;
400
417
}
401
418
419
+ #[ test]
420
+ #[ serial]
421
+ fn test_xsinx_checkpoint_egor ( ) {
422
+ let _ = std:: fs:: remove_file ( ".checkpoints/egor.arg" ) ;
423
+ let n_iter = 1 ;
424
+ let res = EgorBuilder :: optimize ( xsinx)
425
+ . configure ( |config| config. max_iters ( n_iter) . seed ( 42 ) . hot_start ( Some ( 0 ) ) )
426
+ . min_within ( & array ! [ [ 0.0 , 25.0 ] ] )
427
+ . run ( )
428
+ . expect ( "Egor should minimize" ) ;
429
+ let expected = array ! [ 19.1 ] ;
430
+ assert_abs_diff_eq ! ( expected, res. x_opt, epsilon = 1e-1 ) ;
431
+
432
+ // without hostart we reach the same point
433
+ let res = EgorBuilder :: optimize ( xsinx)
434
+ . configure ( |config| config. max_iters ( n_iter) . seed ( 42 ) . hot_start ( None ) )
435
+ . min_within ( & array ! [ [ 0.0 , 25.0 ] ] )
436
+ . run ( )
437
+ . expect ( "Egor should minimize" ) ;
438
+ let expected = array ! [ 19.1 ] ;
439
+ assert_abs_diff_eq ! ( expected, res. x_opt, epsilon = 1e-1 ) ;
440
+
441
+ // with hot start we continue
442
+ let ext_iters = 3 ;
443
+ let res = EgorBuilder :: optimize ( xsinx)
444
+ . configure ( |config| config. seed ( 42 ) . hot_start ( Some ( ext_iters) ) )
445
+ . min_within ( & array ! [ [ 0.0 , 25.0 ] ] )
446
+ . run ( )
447
+ . expect ( "Egor should minimize" ) ;
448
+ let expected = array ! [ 18.9 ] ;
449
+ assert_abs_diff_eq ! ( expected, res. x_opt, epsilon = 1e-1 ) ;
450
+ assert_eq ! ( n_iter as u64 + ext_iters, res. state. get_iter( ) ) ;
451
+ let _ = std:: fs:: remove_file ( ".checkpoints/egor.arg" ) ;
452
+ }
453
+
402
454
#[ test]
403
455
#[ serial]
404
456
fn test_xsinx_auto_clustering_egor_builder ( ) {
0 commit comments