Skip to content

Commit 54bb51c

Browse files
committed
Add n_start argument to control hyperparams optimization restarts
1 parent 47fb019 commit 54bb51c

File tree

9 files changed

+108
-25
lines changed

9 files changed

+108
-25
lines changed

gp/src/parameters.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ impl<F: Float, Mean: RegressionModel<F>, Corr: CorrelationModel<F>> GpParams<F,
112112
self
113113
}
114114

115-
/// Set the number of internal GP hyperparameter theta optimization restart
115+
/// Set the number of internal GP hyperparameter theta optimization restarts
116116
pub fn n_start(mut self, n_start: usize) -> Self {
117117
self.0.n_start = n_start;
118118
self

gp/src/sgp_parameters.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,12 @@ impl<F: Float, Corr: CorrelationModel<F>> SgpParams<F, Corr> {
168168
self
169169
}
170170

171+
/// Set the number of internal hyperparameters optimization restarts
172+
pub fn n_start(mut self, n_start: usize) -> Self {
173+
self.0.gp_params.n_start = n_start;
174+
self
175+
}
176+
171177
/// Set nugget value.
172178
///
173179
/// Nugget is used to improve numerical stability

moe/src/gp_algorithm.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ impl<R: Rng + SeedableRng + Clone> GpMixValidParams<f64, R> {
270270
};
271271
let mut expert_params = best_expert_params?;
272272
expert_params.kpls_dim(self.kpls_dim());
273+
expert_params.n_start(self.n_start());
273274
let expert = expert_params.train(&xtrain.view(), &ytrain.view());
274275
if let Some(v) = best.1 {
275276
info!("Best expert {} accuracy={}", best.0, v);

moe/src/gp_parameters.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ pub struct GpMixValidParams<F: Float, R: Rng + Clone> {
3232
/// Number of PLS components, should be used when problem size
3333
/// is over ten variables or so.
3434
kpls_dim: Option<usize>,
35+
/// Number of GP hyperparameters optimization restarts
36+
n_start: usize,
3537
/// Gaussian Mixture model used to cluster
3638
gmm: Option<Box<GaussianMixtureModel<F>>>,
3739
/// GaussianMixture preset
@@ -48,6 +50,7 @@ impl<F: Float, R: Rng + SeedableRng + Clone> Default for GpMixValidParams<F, R>
4850
regression_spec: RegressionSpec::ALL,
4951
correlation_spec: CorrelationSpec::ALL,
5052
kpls_dim: None,
53+
n_start: 10,
5154
gmm: None,
5255
gmx: None,
5356
rng: R::from_entropy(),
@@ -81,6 +84,11 @@ impl<F: Float, R: Rng + Clone> GpMixValidParams<F, R> {
8184
self.kpls_dim
8285
}
8386

87+
/// The number of hyperparameters optimization restarts
88+
pub fn n_start(&self) -> usize {
89+
self.n_start
90+
}
91+
8492
/// An optional gaussian mixture to be fitted to generate multivariate normal
8593
/// in turns used to cluster
8694
pub fn gmm(&self) -> &Option<Box<GaussianMixtureModel<F>>> {
@@ -135,6 +143,7 @@ impl<F: Float, R: Rng + Clone> GpMixParams<F, R> {
135143
regression_spec: RegressionSpec::ALL,
136144
correlation_spec: CorrelationSpec::ALL,
137145
kpls_dim: None,
146+
n_start: 10,
138147
gmm: None,
139148
gmx: None,
140149
rng,
@@ -179,6 +188,12 @@ impl<F: Float, R: Rng + Clone> GpMixParams<F, R> {
179188
self
180189
}
181190

191+
/// Sets the number of hyperparameters optimization restarts
192+
pub fn n_start(mut self, n_start: usize) -> Self {
193+
self.0.n_start = n_start;
194+
self
195+
}
196+
182197
#[doc(hidden)]
183198
/// Sets the gaussian mixture (used to find the optimal number of clusters)
184199
pub(crate) fn gmm(mut self, gmm: Option<Box<GaussianMixtureModel<F>>>) -> Self {
@@ -205,6 +220,7 @@ impl<F: Float, R: Rng + Clone> GpMixParams<F, R> {
205220
regression_spec: self.0.regression_spec(),
206221
correlation_spec: self.0.correlation_spec(),
207222
kpls_dim: self.0.kpls_dim(),
223+
n_start: self.0.n_start(),
208224
gmm: self.0.gmm().clone(),
209225
gmx: self.0.gmx().clone(),
210226
rng,

moe/src/sgp_algorithm.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,9 @@ impl<R: Rng + SeedableRng + Clone> SparseGpMixtureValidParams<f64, R> {
258258
};
259259
let mut expert_params = best_expert_params?;
260260
let seed = self.rng().gen();
261-
expert_params.kpls_dim(self.kpls_dim());
262261
expert_params.initial_theta(self.initial_theta());
262+
expert_params.kpls_dim(self.kpls_dim());
263+
expert_params.n_start(self.n_start());
263264
expert_params.sparse_method(self.sparse_method());
264265
expert_params.seed(seed);
265266
debug!("Train best expert...");

moe/src/sgp_parameters.rs

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ pub struct SparseGpMixtureValidParams<F: Float, R: Rng + Clone> {
3535
/// Number of PLS components, should be used when problem size
3636
/// is over ten variables or so.
3737
kpls_dim: Option<usize>,
38+
/// Number of GP hyperparameters optimization restarts
39+
n_start: usize,
3840
/// Used sparse method
3941
sparse_method: SparseMethod,
4042
/// Inducings
@@ -54,8 +56,9 @@ impl<F: Float, R: Rng + SeedableRng + Clone> Default for SparseGpMixtureValidPar
5456
recombination: Recombination::Smooth(Some(F::one())),
5557
regression_spec: RegressionSpec::CONSTANT,
5658
correlation_spec: CorrelationSpec::SQUAREDEXPONENTIAL,
57-
kpls_dim: None,
5859
initial_theta: None,
60+
kpls_dim: None,
61+
n_start: 10,
5962
sparse_method: SparseMethod::default(),
6063
inducings: Inducings::default(),
6164
gmm: None,
@@ -86,14 +89,19 @@ impl<F: Float, R: Rng + Clone> SparseGpMixtureValidParams<F, R> {
8689
self.correlation_spec
8790
}
8891

92+
/// The optional initial guess for GP theta hyperparameters
93+
pub fn initial_theta(&self) -> Option<Vec<F>> {
94+
self.initial_theta.clone()
95+
}
96+
8997
/// The optional number of PLS components
9098
pub fn kpls_dim(&self) -> Option<usize> {
9199
self.kpls_dim
92100
}
93101

94-
/// The optional initial guess for GP theta hyperparameters
95-
pub fn initial_theta(&self) -> Option<Vec<F>> {
96-
self.initial_theta.clone()
102+
/// The number of hyperparameters optimization restarts
103+
pub fn n_start(&self) -> usize {
104+
self.n_start
97105
}
98106

99107
/// The sparse method used
@@ -158,8 +166,9 @@ impl<F: Float, R: Rng + Clone> SparseGpMixtureParams<F, R> {
158166
recombination: Recombination::Smooth(Some(F::one())),
159167
regression_spec: RegressionSpec::CONSTANT,
160168
correlation_spec: CorrelationSpec::SQUAREDEXPONENTIAL,
161-
kpls_dim: None,
162169
initial_theta: None,
170+
kpls_dim: None,
171+
n_start: 10,
163172
sparse_method: SparseMethod::default(),
164173
inducings,
165174
gmm: None,
@@ -205,6 +214,12 @@ impl<F: Float, R: Rng + Clone> SparseGpMixtureParams<F, R> {
205214
self
206215
}
207216

217+
/// Sets the number of hyperparameters optimization restarts
218+
pub fn n_start(mut self, n_start: usize) -> Self {
219+
self.0.n_start = n_start;
220+
self
221+
}
222+
208223
/// Sets
209224
///
210225
/// None means no PLS dimension reduction applied.
@@ -246,8 +261,9 @@ impl<F: Float, R: Rng + Clone> SparseGpMixtureParams<F, R> {
246261
recombination: self.0.recombination(),
247262
regression_spec: self.0.regression_spec(),
248263
correlation_spec: self.0.correlation_spec(),
249-
kpls_dim: self.0.kpls_dim(),
250264
initial_theta: self.0.initial_theta(),
265+
kpls_dim: self.0.kpls_dim(),
266+
n_start: self.0.n_start(),
251267
sparse_method: self.0.sparse_method(),
252268
inducings: self.0.inducings().clone(),
253269
gmm: self.0.gmm().clone(),

moe/src/surrogates.rs

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,28 @@ use crate::MoeError;
1616
use std::fs;
1717
#[cfg(feature = "persistent")]
1818
use std::io::Write;
19-
/// A trait for Gp surrogate parameters to build surrogate once fitted.
19+
/// A trait for Gp surrogate parameters to build surrogate.
2020
pub trait GpSurrogateParams {
2121
/// Set initial theta
2222
fn initial_theta(&mut self, theta: Option<Vec<f64>>);
2323
/// Set the number of PLS components
2424
fn kpls_dim(&mut self, kpls_dim: Option<usize>);
25+
/// Set the number of internal optimization restarts
26+
fn n_start(&mut self, n_start: usize);
2527
/// Set the nugget parameter to improve numerical stability
2628
fn nugget(&mut self, nugget: f64);
2729
/// Train the surrogate
2830
fn train(&self, x: &ArrayView2<f64>, y: &ArrayView2<f64>) -> Result<Box<dyn FullGpSurrogate>>;
2931
}
3032

31-
/// A trait for sparse GP surrogate parameters to build surrogate once fitted.
33+
/// A trait for sparse GP surrogate parameters to build surrogate.
3234
pub trait SgpSurrogateParams {
3335
/// Set initial theta
3436
fn initial_theta(&mut self, theta: Option<Vec<f64>>);
3537
/// Set the number of PLS components
3638
fn kpls_dim(&mut self, kpls_dim: Option<usize>);
39+
/// Set the number of internal optimization restarts
40+
fn n_start(&mut self, n_start: usize);
3741
/// Set the sparse method
3842
fn sparse_method(&mut self, method: SparseMethod);
3943
/// Set random generator seed
@@ -106,6 +110,10 @@ macro_rules! declare_surrogate {
106110
self.0 = self.0.clone().kpls_dim(kpls_dim);
107111
}
108112

113+
fn n_start(&mut self, n_start: usize) {
114+
self.0 = self.0.clone().n_start(n_start);
115+
}
116+
109117
fn nugget(&mut self, nugget: f64) {
110118
self.0 = self.0.clone().nugget(nugget);
111119
}
@@ -216,14 +224,18 @@ macro_rules! declare_sgp_surrogate {
216224
self.0 = self.0.clone().initial_theta(theta);
217225
}
218226

219-
fn sparse_method(&mut self, method: SparseMethod) {
220-
self.0 = self.0.clone().sparse_method(method);
221-
}
222-
223227
fn kpls_dim(&mut self, kpls_dim: Option<usize>) {
224228
self.0 = self.0.clone().kpls_dim(kpls_dim);
225229
}
226230

231+
fn n_start(&mut self, n_start: usize) {
232+
self.0 = self.0.clone().n_start(n_start);
233+
}
234+
235+
fn sparse_method(&mut self, method: SparseMethod) {
236+
self.0 = self.0.clone().sparse_method(method);
237+
}
238+
227239
fn seed(&mut self, seed: Option<u64>) {
228240
self.0 = self.0.clone().seed(seed);
229241
}

src/gp_mix.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ use rand_xoshiro::Xoshiro256Plus;
4949
/// Number of components to be used when PLS projection is used (a.k.a KPLS method).
5050
/// This is used to address high-dimensional problems typically when nx > 9.
5151
///
52+
/// n_start (int >= 0)
53+
/// Number of internal GP hyperparameters optimization restarts (multistart)
54+
///
5255
/// seed (int >= 0)
5356
/// Random generator seed to allow computation reproducibility.
5457
///
@@ -59,6 +62,7 @@ pub(crate) struct GpMix {
5962
pub correlation_spec: CorrelationSpec,
6063
pub recombination: Recombination,
6164
pub kpls_dim: Option<usize>,
65+
pub n_start: usize,
6266
pub seed: Option<u64>,
6367
}
6468

@@ -71,6 +75,7 @@ impl GpMix {
7175
corr_spec = CorrelationSpec::SQUARED_EXPONENTIAL,
7276
recombination = Recombination::Smooth,
7377
kpls_dim = None,
78+
n_start = 10,
7479
seed = None
7580
))]
7681
#[allow(clippy::too_many_arguments)]
@@ -80,6 +85,7 @@ impl GpMix {
8085
corr_spec: u8,
8186
recombination: Recombination,
8287
kpls_dim: Option<usize>,
88+
n_start: usize,
8389
seed: Option<u64>,
8490
) -> Self {
8591
GpMix {
@@ -88,6 +94,7 @@ impl GpMix {
8894
correlation_spec: CorrelationSpec(corr_spec),
8995
recombination,
9096
kpls_dim,
97+
n_start,
9198
seed,
9299
}
93100
}
@@ -124,6 +131,7 @@ impl GpMix {
124131
egobox_moe::CorrelationSpec::from_bits(self.correlation_spec.0).unwrap(),
125132
)
126133
.kpls_dim(self.kpls_dim)
134+
.n_start(self.n_start)
127135
.with_rng(rng)
128136
.fit(&dataset)
129137
.expect("MoE model training")
@@ -149,6 +157,7 @@ impl Gpx {
149157
corr_spec = CorrelationSpec::SQUARED_EXPONENTIAL,
150158
recombination = Recombination::Smooth,
151159
kpls_dim = None,
160+
n_start = 10,
152161
seed = None
153162
))]
154163
fn builder(
@@ -157,6 +166,7 @@ impl Gpx {
157166
corr_spec: u8,
158167
recombination: Recombination,
159168
kpls_dim: Option<usize>,
169+
n_start: usize,
160170
seed: Option<u64>,
161171
) -> GpMix {
162172
GpMix::new(
@@ -165,6 +175,7 @@ impl Gpx {
165175
corr_spec,
166176
recombination,
167177
kpls_dim,
178+
n_start,
168179
seed,
169180
)
170181
}

0 commit comments

Comments
 (0)