Skip to content

Commit 40ff915

Browse files
committed
Move random_seed setting in configuration
1 parent 65508bb commit 40ff915

File tree

6 files changed

+44
-52
lines changed

6 files changed

+44
-52
lines changed

ego/examples/g24.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ fn main() {
3434

3535
// We use Egor optimizer as a service
3636
let egor = EgorServiceBuilder::optimize()
37-
.configure(|config| config.n_cstr(2))
38-
.random_seed(42)
37+
.configure(|config| config.n_cstr(2).random_seed(42))
3938
.min_within(&xlimits);
4039

4140
let mut y_doe = f_g24(&doe.view());

ego/src/egor.rs

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,6 @@ use argmin::core::{Executor, State};
109109
pub struct EgorBuilder<O: GroupFunc> {
110110
fobj: O,
111111
config: EgorConfig,
112-
seed: Option<u64>,
113112
}
114113

115114
impl<O: GroupFunc> EgorBuilder<O> {
@@ -122,7 +121,6 @@ impl<O: GroupFunc> EgorBuilder<O> {
122121
EgorBuilder {
123122
fobj,
124123
config: EgorConfig::default(),
125-
seed: None,
126124
}
127125
}
128126

@@ -131,13 +129,6 @@ impl<O: GroupFunc> EgorBuilder<O> {
131129
self
132130
}
133131

134-
/// Allow to specify a seed for random number generator to allow
135-
/// reproducible runs.
136-
pub fn random_seed(mut self, seed: u64) -> Self {
137-
self.seed = Some(seed);
138-
self
139-
}
140-
141132
/// Build an Egor optimizer to minimize the function within
142133
/// the continuous `xlimits` specified as [[lower, upper], ...] array where the
143134
/// number of rows gives the dimension of the inputs (continuous optimization)
@@ -146,7 +137,7 @@ impl<O: GroupFunc> EgorBuilder<O> {
146137
self,
147138
xlimits: &ArrayBase<impl Data<Elem = f64>, Ix2>,
148139
) -> Egor<O, MoeParams<f64, Xoshiro256Plus>> {
149-
let rng = if let Some(seed) = self.seed {
140+
let rng = if let Some(seed) = self.config.seed {
150141
Xoshiro256Plus::seed_from_u64(seed)
151142
} else {
152143
Xoshiro256Plus::from_entropy()
@@ -161,7 +152,7 @@ impl<O: GroupFunc> EgorBuilder<O> {
161152
/// inputs specified with given xtypes where some of components may be
162153
/// discrete variables (mixed-integer optimization).
163154
pub fn min_within_mixint_space(self, xtypes: &[XType]) -> Egor<O, MixintMoeParams> {
164-
let rng = if let Some(seed) = self.seed {
155+
let rng = if let Some(seed) = self.config.seed {
165156
Xoshiro256Plus::seed_from_u64(seed)
166157
} else {
167158
Xoshiro256Plus::from_entropy()
@@ -320,17 +311,27 @@ mod tests {
320311
let xlimits = array![[0.0, 25.0]];
321312
let doe = Lhs::new(&xlimits).sample(10);
322313
let res = EgorBuilder::optimize(xsinx)
323-
.configure(|config| config.n_iter(15).doe(&doe).outdir("target/tests"))
324-
.random_seed(42)
314+
.configure(|config| {
315+
config
316+
.n_iter(15)
317+
.doe(&doe)
318+
.outdir("target/tests")
319+
.random_seed(42)
320+
})
325321
.min_within(&xlimits)
326322
.run()
327323
.expect("Minimize failure");
328324
let expected = array![18.9];
329325
assert_abs_diff_eq!(expected, res.x_opt, epsilon = 1e-1);
330326

331327
let res = EgorBuilder::optimize(xsinx)
332-
.configure(|config| config.n_iter(5).outdir("target/tests").hot_start(true))
333-
.random_seed(42)
328+
.configure(|config| {
329+
config
330+
.n_iter(5)
331+
.outdir("target/tests")
332+
.hot_start(true)
333+
.random_seed(42)
334+
})
334335
.min_within(&xlimits)
335336
.run()
336337
.expect("Egor should minimize xsinx");
@@ -362,8 +363,8 @@ mod tests {
362363
.regression_spec(RegressionSpec::ALL)
363364
.correlation_spec(CorrelationSpec::ALL)
364365
.target(1e-2)
366+
.random_seed(42)
365367
})
366-
.random_seed(42)
367368
.min_within(&xlimits)
368369
.run()
369370
.expect("Minimize failure");
@@ -407,8 +408,7 @@ mod tests {
407408
.with_rng(Xoshiro256Plus::seed_from_u64(42))
408409
.sample(3);
409410
let res = EgorBuilder::optimize(f_g24)
410-
.configure(|config| config.n_cstr(2).doe(&doe).n_iter(20))
411-
.random_seed(42)
411+
.configure(|config| config.n_cstr(2).doe(&doe).n_iter(20).random_seed(42))
412412
.min_within(&xlimits)
413413
.run()
414414
.expect("Minimize failure");
@@ -436,8 +436,8 @@ mod tests {
436436
.doe(&doe)
437437
.target(-5.5030)
438438
.n_iter(30)
439+
.random_seed(42)
439440
})
440-
.random_seed(42)
441441
.min_within(&xlimits)
442442
.run()
443443
.expect("Egor minimization");
@@ -470,8 +470,8 @@ mod tests {
470470
.n_iter(n_iter)
471471
.target(-15.1)
472472
.infill_strategy(InfillStrategy::EI)
473+
.random_seed(42)
473474
})
474-
.random_seed(42)
475475
.min_within_mixint_space(&xtypes)
476476
.run()
477477
.unwrap();
@@ -492,8 +492,8 @@ mod tests {
492492
.n_iter(n_iter)
493493
.target(-15.1)
494494
.infill_strategy(InfillStrategy::EI)
495+
.random_seed(42)
495496
})
496-
.random_seed(42)
497497
.min_within_mixint_space(&xtypes)
498498
.run()
499499
.unwrap();
@@ -512,8 +512,8 @@ mod tests {
512512
.regression_spec(egobox_moe::RegressionSpec::CONSTANT)
513513
.correlation_spec(egobox_moe::CorrelationSpec::SQUAREDEXPONENTIAL)
514514
.n_iter(n_iter)
515+
.random_seed(42)
515516
})
516-
.random_seed(42)
517517
.min_within_mixint_space(&xtypes)
518518
.run()
519519
.unwrap();
@@ -563,8 +563,8 @@ mod tests {
563563
.regression_spec(egobox_moe::RegressionSpec::CONSTANT)
564564
.correlation_spec(egobox_moe::CorrelationSpec::SQUAREDEXPONENTIAL)
565565
.n_iter(n_iter)
566+
.random_seed(42)
566567
})
567-
.random_seed(42)
568568
.min_within_mixint_space(&xtypes)
569569
.run()
570570
.unwrap();

ego/src/egor_config.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ use crate::types::*;
44
use egobox_moe::{CorrelationSpec, RegressionSpec};
55
use ndarray::Array1;
66
use ndarray::Array2;
7-
use rand_xoshiro::rand_core::SeedableRng;
8-
use rand_xoshiro::Xoshiro256Plus;
97

108
use serde::{Deserialize, Serialize};
119

@@ -59,9 +57,8 @@ pub struct EgorConfig {
5957
pub(crate) xtypes: Option<Vec<XType>>,
6058
/// Flag for discrete handling, true if mixed-integer type present in xtypes, otherwise false
6159
pub(crate) no_discrete: bool,
62-
/// A random generator used to get reproductible results.
63-
/// For instance: Xoshiro256Plus::from_u64_seed(42) for reproducibility
64-
pub(crate) rng: Xoshiro256Plus,
60+
/// A random generator seed used to get reproductible results.
61+
pub(crate) seed: Option<u64>,
6562
}
6663

6764
impl Default for EgorConfig {
@@ -86,7 +83,7 @@ impl Default for EgorConfig {
8683
hot_start: false,
8784
xtypes: None,
8885
no_discrete: true,
89-
rng: Xoshiro256Plus::from_entropy(),
86+
seed: None,
9087
}
9188
}
9289
}
@@ -233,4 +230,11 @@ impl EgorConfig {
233230
self.hot_start = hot_start;
234231
self
235232
}
233+
234+
/// Allow to specify a seed for random number generator to allow
235+
/// reproducible runs.
236+
pub fn random_seed(mut self, seed: u64) -> Self {
237+
self.seed = Some(seed);
238+
self
239+
}
236240
}

ego/src/egor_service.rs

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
//! conf.regression_spec(RegressionSpec::ALL)
2121
//! .correlation_spec(CorrelationSpec::ALL)
2222
//! .infill_strategy(InfillStrategy::EI)
23+
//! .random_seed(42)
2324
//! })
24-
//! .random_seed(42)
2525
//! .min_within(&array![[0., 25.]]);
2626
//!
2727
//! let mut doe = array![[0.], [7.], [20.], [25.]];
@@ -54,7 +54,6 @@ use rand_xoshiro::Xoshiro256Plus;
5454
///
5555
pub struct EgorServiceBuilder {
5656
config: EgorConfig,
57-
seed: Option<u64>,
5857
}
5958

6059
impl EgorServiceBuilder {
@@ -66,7 +65,6 @@ impl EgorServiceBuilder {
6665
pub fn optimize() -> Self {
6766
EgorServiceBuilder {
6867
config: EgorConfig::default(),
69-
seed: None,
7068
}
7169
}
7270

@@ -75,13 +73,6 @@ impl EgorServiceBuilder {
7573
self
7674
}
7775

78-
/// Allow to specify a seed for random number generator to allow
79-
/// reproducible runs.
80-
pub fn random_seed(mut self, seed: u64) -> Self {
81-
self.seed = Some(seed);
82-
self
83-
}
84-
8576
/// Build an Egor optimizer to minimize the function within
8677
/// the continuous `xlimits` specified as [[lower, upper], ...] array where the
8778
/// number of rows gives the dimension of the inputs (continuous optimization)
@@ -90,7 +81,7 @@ impl EgorServiceBuilder {
9081
self,
9182
xlimits: &ArrayBase<impl Data<Elem = f64>, Ix2>,
9283
) -> EgorService<MoeParams<f64, Xoshiro256Plus>> {
93-
let rng = if let Some(seed) = self.seed {
84+
let rng = if let Some(seed) = self.config.seed {
9485
Xoshiro256Plus::seed_from_u64(seed)
9586
} else {
9687
Xoshiro256Plus::from_entropy()
@@ -105,7 +96,7 @@ impl EgorServiceBuilder {
10596
/// inputs specified with given xtypes where some of components may be
10697
/// discrete variables (mixed-integer optimization).
10798
pub fn min_within_mixint_space(self, xtypes: &[XType]) -> EgorService<MixintMoeParams> {
108-
let rng = if let Some(seed) = self.seed {
99+
let rng = if let Some(seed) = self.config.seed {
109100
Xoshiro256Plus::seed_from_u64(seed)
110101
} else {
111102
Xoshiro256Plus::from_entropy()
@@ -162,8 +153,8 @@ mod tests {
162153
conf.regression_spec(RegressionSpec::ALL)
163154
.correlation_spec(CorrelationSpec::ALL)
164155
.infill_strategy(InfillStrategy::EI)
156+
.random_seed(42)
165157
})
166-
.random_seed(42)
167158
.min_within(&array![[0., 25.]]);
168159

169160
let mut doe = array![[0.], [7.], [20.], [25.]];

ego/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@
7474
//! .configure(|config|
7575
//! config.doe(&doe) // we pass the initial doe
7676
//! .n_iter(n_iter)
77-
//! .infill_strategy(InfillStrategy::EI))
78-
//! .random_seed(42)
77+
//! .infill_strategy(InfillStrategy::EI)
78+
//! .random_seed(42))
7979
//! .min_within_mixint_space(&xtypes) // We build a mixed-integer optimizer
8080
//! .run()
8181
//! .expect("Egor minimization");

src/egor.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -302,15 +302,10 @@ impl Egor {
302302
.collect();
303303
println!("{:?}", xtypes);
304304

305-
let mut mixintegor_build = egobox_ego::EgorBuilder::optimize(obj);
306-
if let Some(seed) = self.seed {
307-
mixintegor_build = mixintegor_build.random_seed(seed);
308-
};
309-
310305
let cstr_tol = self.cstr_tol.clone().unwrap_or(vec![0.0; self.n_cstr]);
311306
let cstr_tol = Array1::from_vec(cstr_tol);
312307

313-
let mixintegor = mixintegor_build
308+
let mixintegor = egobox_ego::EgorBuilder::optimize(obj)
314309
.configure(|config| {
315310
let mut config = config
316311
.n_cstr(self.n_cstr)
@@ -342,6 +337,9 @@ impl Egor {
342337
if let Some(outdir) = self.outdir.as_ref().cloned() {
343338
config = config.outdir(outdir);
344339
};
340+
if let Some(seed) = self.seed {
341+
config = config.random_seed(seed);
342+
};
345343
config
346344
})
347345
.min_within_mixint_space(&xtypes);

0 commit comments

Comments
 (0)