Skip to content

Commit 40c46a4

Browse files
committed
Rename GP predict_variances to predict_var
1 parent ba5e027 commit 40c46a4

File tree

12 files changed

+58
-78
lines changed

12 files changed

+58
-78
lines changed

ego/src/criteria/ei.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ impl InfillCriterion for ExpectedImprovement {
2323
) -> f64 {
2424
let pt = ArrayView::from_shape((1, x.len()), x).unwrap();
2525
if let Ok(p) = obj_model.predict(&pt) {
26-
if let Ok(s) = obj_model.predict_variances(&pt) {
26+
if let Ok(s) = obj_model.predic_var(&pt) {
2727
let pred = p[[0, 0]];
2828
let sigma = s[[0, 0]].sqrt();
2929
let args0 = (f_min - pred) / sigma;
@@ -49,7 +49,7 @@ impl InfillCriterion for ExpectedImprovement {
4949
) -> Array1<f64> {
5050
let pt = ArrayView::from_shape((1, x.len()), x).unwrap();
5151
if let Ok(p) = obj_model.predict(&pt) {
52-
if let Ok(s) = obj_model.predict_variances(&pt) {
52+
if let Ok(s) = obj_model.predic_var(&pt) {
5353
let sigma = s[[0, 0]].sqrt();
5454
if sigma.abs() < 1e-12 {
5555
Array1::zeros(pt.len())

ego/src/criteria/wb2.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,11 +165,11 @@ mod tests {
165165
let xtest12 = array![[x1 - h, x2]];
166166
let xtest21 = array![[x1, x2 + h]];
167167
let xtest22 = array![[x1, x2 - h]];
168-
let fdiff1 = (bgp.predict_variances(&xtest11.view()).unwrap()
169-
- bgp.predict_variances(&xtest12.view()).unwrap())
168+
let fdiff1 = (bgp.predic_var(&xtest11.view()).unwrap()
169+
- bgp.predic_var(&xtest12.view()).unwrap())
170170
/ (2. * h);
171-
let fdiff2 = (bgp.predict_variances(&xtest21.view()).unwrap()
172-
- bgp.predict_variances(&xtest22.view()).unwrap())
171+
let fdiff2 = (bgp.predic_var(&xtest21.view()).unwrap()
172+
- bgp.predic_var(&xtest22.view()).unwrap())
173173
/ (2. * h);
174174
println!(
175175
"gp var fdiff({}) = [[{}, {}]]",

ego/src/egor_solver.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -883,7 +883,7 @@ where
883883
} else {
884884
let x = &xk.view().insert_axis(Axis(0));
885885
let pred = obj_model.predict(x)?[[0, 0]];
886-
let var = obj_model.predict_variances(x)?[[0, 0]];
886+
let var = obj_model.predic_var(x)?[[0, 0]];
887887
let conf = match self.config.q_ei {
888888
QEiStrategy::KrigingBeliever => 0.,
889889
QEiStrategy::KrigingBelieverLowerBound => -3.,

ego/src/mixint.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -541,14 +541,14 @@ impl GpSurrogate for MixintMoe {
541541
self.moe.predict(&xcast)
542542
}
543543

544-
fn predict_variances(&self, x: &ArrayView2<f64>) -> egobox_moe::Result<Array2<f64>> {
544+
fn predic_var(&self, x: &ArrayView2<f64>) -> egobox_moe::Result<Array2<f64>> {
545545
let mut xcast = if self.work_in_folded_space {
546546
unfold_with_enum_mask(&self.xtypes, x)
547547
} else {
548548
x.to_owned()
549549
};
550550
cast_to_discrete_values_mut(&self.xtypes, &mut xcast);
551-
self.moe.predict_variances(&xcast)
551+
self.moe.predic_var(&xcast)
552552
}
553553

554554
/// Save Moe model in given file.
@@ -633,7 +633,7 @@ impl<'a, D: Data<Elem = f64>> PredictInplace<ArrayBase<D, Ix2>, Array2<f64>>
633633

634634
let values = self
635635
.0
636-
.predict_variances(x)
636+
.predic_var(x)
637637
.expect("MixintMoE variances prediction");
638638
*y = values;
639639
}
@@ -840,7 +840,7 @@ mod tests {
840840
let xtest = Array::linspace(0.0, 4.0, num).insert_axis(Axis(1));
841841
let ytest = mixi_moe.predict(&xtest.view()).expect("Predict val fail");
842842
let yvar = mixi_moe
843-
.predict_variances(&xtest.view())
843+
.predic_var(&xtest.view())
844844
.expect("Predict var fail");
845845
println!("{ytest:?}");
846846
assert_abs_diff_eq!(

gp/examples/kriging.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ fn main() {
2424
let ypred = kriging.predict(&xtest).expect("Kriging prediction");
2525
// predict standard deviation
2626
let ysigma = kriging
27-
.predict_variances(&xtest)
27+
.predic_var(&xtest)
2828
.expect("Kriging prediction")
2929
.map(|v| v.sqrt());
3030

gp/src/algorithm.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ impl<F: Float> Clone for GpInnerParams<F> {
159159
/// let ytest = xsinx(&xtest);
160160
///
161161
/// let ypred = kriging.predict(&xtest).expect("Kriging prediction");
162-
/// let yvariances = kriging.predict_variances(&xtest).expect("Kriging prediction");
162+
/// let yvariances = kriging.predic_var(&xtest).expect("Kriging prediction");
163163
///```
164164
///
165165
/// # Reference:
@@ -270,7 +270,7 @@ impl<F: Float, Mean: RegressionModel<F>, Corr: CorrelationModel<F>> GaussianProc
270270

271271
/// Predict variance values at n given `x` points of nx components specified as a (n, nx) matrix.
272272
/// Returns n variance values as (n, 1) column vector.
273-
pub fn predict_variances(&self, x: &ArrayBase<impl Data<Elem = F>, Ix2>) -> Result<Array2<F>> {
273+
pub fn predic_var(&self, x: &ArrayBase<impl Data<Elem = F>, Ix2>) -> Result<Array2<F>> {
274274
let (rt, u, _) = self._compute_rt_u(x);
275275

276276
let mut b = Array::ones(rt.ncols()) - rt.mapv(|v| v * v).sum_axis(Axis(0))
@@ -767,7 +767,7 @@ where
767767
"The number of data points must match the number of output targets."
768768
);
769769

770-
let values = self.0.predict_variances(x).expect("GP Prediction");
770+
let values = self.0.predic_var(x).expect("GP Prediction");
771771
*y = values;
772772
}
773773

@@ -1300,12 +1300,12 @@ mod tests {
13001300
let gpr_vals = gp.predict(&xplot).unwrap();
13011301

13021302
let yvars = gp
1303-
.predict_variances(&arr2(&[[1.0], [3.5]]))
1303+
.predic_var(&arr2(&[[1.0], [3.5]]))
13041304
.expect("prediction error");
13051305
let expected_vars = arr2(&[[0.], [0.1]]);
13061306
assert_abs_diff_eq!(expected_vars, yvars, epsilon = 0.5);
13071307

1308-
let gpr_vars = gp.predict_variances(&xplot).unwrap();
1308+
let gpr_vars = gp.predic_var(&xplot).unwrap();
13091309

13101310
let test_dir = "target/tests";
13111311
std::fs::create_dir_all(test_dir).ok();
@@ -1602,7 +1602,7 @@ mod tests {
16021602
println!("value at [{},{}] = {}", xa, xb, y_pred);
16031603
let y_deriv = gp.predict_derivatives(&x);
16041604
println!("deriv at [{},{}] = {}", xa, xb, y_deriv);
1605-
let y_pred = gp.predict_variances(&x).unwrap();
1605+
let y_pred = gp.predic_var(&x).unwrap();
16061606
println!("variance at [{},{}] = {}", xa, xb, y_pred);
16071607
let y_deriv = gp.predict_variance_derivatives(&x);
16081608
println!("variance deriv at [{},{}] = {}", xa, xb, y_deriv);
@@ -1658,7 +1658,7 @@ mod tests {
16581658
[xa, xb + e],
16591659
[xa, xb - e]
16601660
];
1661-
let y_pred = gp.predict_variances(&x).unwrap();
1661+
let y_pred = gp.predic_var(&x).unwrap();
16621662
println!("variance at [{xa},{xb}] = {y_pred}");
16631663
let y_deriv = gp.predict_variance_derivatives(&x);
16641664
println!("variance deriv at [{xa},{xb}] = {y_deriv}");

gp/src/sparse_algorithm.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ impl<F: Float> Clone for WoodburyData<F> {
121121
/// // Predict with our trained SGP
122122
/// let xplot = Array::linspace(-1., 1., 100).insert_axis(Axis(1));
123123
/// let sgp_vals = sgp.predict(&xplot).unwrap();
124-
/// let sgp_vars = sgp.predict_variances(&xplot).unwrap();
124+
/// let sgp_vars = sgp.predic_var(&xplot).unwrap();
125125
/// ```
126126
///
127127
/// # Reference
@@ -239,7 +239,7 @@ impl<F: Float, Corr: CorrelationModel<F>> SparseGaussianProcess<F, Corr> {
239239

240240
/// Predict variance values at n given `x` points of nx components specified as a (n, nx) matrix.
241241
/// Returns n variance values as (n, 1) column vector.
242-
pub fn predict_variances(&self, x: &ArrayBase<impl Data<Elem = F>, Ix2>) -> Result<Array2<F>> {
242+
pub fn predic_var(&self, x: &ArrayBase<impl Data<Elem = F>, Ix2>) -> Result<Array2<F>> {
243243
let kx = self.compute_k(&self.inducings, x, &self.w_star, &self.theta, self.sigma2);
244244
let kxx = Array::from_elem(x.nrows(), self.sigma2);
245245
let var = kxx - (self.w_data.inv.t().clone().dot(&kx) * &kx).sum_axis(Axis(0));
@@ -340,7 +340,7 @@ where
340340
"The number of data points must match the number of output targets."
341341
);
342342

343-
let values = self.0.predict_variances(x).expect("GP Prediction");
343+
let values = self.0.predic_var(x).expect("GP Prediction");
344344
*y = values;
345345
}
346346

@@ -863,7 +863,7 @@ mod tests {
863863
let yplot = f_obj(&xplot);
864864
let errvals = (yplot - &sgp_vals).mapv(|v| v.abs());
865865
assert_abs_diff_eq!(errvals, Array2::zeros((xplot.nrows(), 1)), epsilon = 0.5);
866-
let sgp_vars = sgp.predict_variances(&xplot).unwrap();
866+
let sgp_vars = sgp.predic_var(&xplot).unwrap();
867867
let errvars = (&sgp_vars - Array2::from_elem((xplot.nrows(), 1), 0.01)).mapv(|v| v.abs());
868868
assert_abs_diff_eq!(errvars, Array2::zeros((xplot.nrows(), 1)), epsilon = 0.2);
869869

@@ -906,7 +906,7 @@ mod tests {
906906
assert_abs_diff_eq!(eta2, sgp.noise_variance());
907907

908908
let sgp_vals = sgp.predict(&xplot).unwrap();
909-
let sgp_vars = sgp.predict_variances(&xplot).unwrap();
909+
let sgp_vars = sgp.predic_var(&xplot).unwrap();
910910

911911
save_data(&xt, &yt, sgp.inducings(), &xplot, &sgp_vals, &sgp_vars);
912912
}
@@ -953,7 +953,7 @@ mod tests {
953953
assert_abs_diff_eq!(&z, sgp.inducings(), epsilon = 0.0015);
954954

955955
let sgp_vals = sgp.predict(&xplot).unwrap();
956-
let sgp_vars = sgp.predict_variances(&xplot).unwrap();
956+
let sgp_vars = sgp.predic_var(&xplot).unwrap();
957957

958958
save_data(&xt, &yt, &z, &xplot, &sgp_vals, &sgp_vars);
959959
}

moe/src/gp_algorithm.rs

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -414,10 +414,10 @@ impl GpSurrogate for GpMixture {
414414
}
415415
}
416416

417-
fn predict_variances(&self, x: &ArrayView2<f64>) -> Result<Array2<f64>> {
417+
fn predic_var(&self, x: &ArrayView2<f64>) -> Result<Array2<f64>> {
418418
match self.recombination {
419-
Recombination::Hard => self.predict_variances_hard(x),
420-
Recombination::Smooth(_) => self.predict_variances_smooth(x),
419+
Recombination::Hard => self.predic_var_hard(x),
420+
Recombination::Smooth(_) => self.predic_var_smooth(x),
421421
}
422422
}
423423
/// Save Moe model in given file.
@@ -514,7 +514,7 @@ impl GpMixture {
514514
/// Gaussian Mixture is used to get the probability of the point to belongs to one cluster
515515
/// or another (ie responsabilities).
516516
/// The smooth recombination of each cluster expert responsabilty is used to get the result.
517-
pub fn predict_variances_smooth(
517+
pub fn predic_var_smooth(
518518
&self,
519519
x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
520520
) -> Result<Array2<f64>> {
@@ -529,7 +529,7 @@ impl GpMixture {
529529
let preds: Array1<f64> = self
530530
.experts
531531
.iter()
532-
.map(|gp| gp.predict_variances(&x).unwrap()[[0, 0]])
532+
.map(|gp| gp.predic_var(&x).unwrap()[[0, 0]])
533533
.collect();
534534
*y = (preds * p * p).sum();
535535
});
@@ -606,7 +606,7 @@ impl GpMixture {
606606
let preds: Array1<f64> = self
607607
.experts
608608
.iter()
609-
.map(|gp| gp.predict_variances(&xii).unwrap()[[0, 0]])
609+
.map(|gp| gp.predic_var(&xii).unwrap()[[0, 0]])
610610
.collect();
611611
let drvs: Vec<Array1<f64>> = self
612612
.experts
@@ -665,7 +665,7 @@ impl GpMixture {
665665
/// Gaussian Mixture is used to get the cluster where the point belongs (highest responsability)
666666
/// The expert of the cluster is used to predict variance value.
667667
/// Returns the variances as a (n, 1) column vector
668-
pub fn predict_variances_hard(
668+
pub fn predic_var_hard(
669669
&self,
670670
x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
671671
) -> Result<Array2<f64>> {
@@ -678,7 +678,7 @@ impl GpMixture {
678678
.for_each(|mut y, x, &c| {
679679
y.assign(
680680
&self.experts[c]
681-
.predict_variances(&x.insert_axis(Axis(0)))
681+
.predic_var(&x.insert_axis(Axis(0)))
682682
.unwrap()
683683
.row(0),
684684
);
@@ -747,11 +747,8 @@ impl GpMixture {
747747
<GpMixture as GpSurrogate>::predict(self, &x.view())
748748
}
749749

750-
pub fn predict_variances(
751-
&self,
752-
x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
753-
) -> Result<Array2<f64>> {
754-
<GpMixture as GpSurrogate>::predict_variances(self, &x.view())
750+
pub fn predic_var(&self, x: &ArrayBase<impl Data<Elem = f64>, Ix2>) -> Result<Array2<f64>> {
751+
<GpMixture as GpSurrogate>::predic_var(self, &x.view())
755752
}
756753

757754
pub fn predict_derivatives(
@@ -829,10 +826,7 @@ impl<'a, D: Data<Elem = f64>> PredictInplace<ArrayBase<D, Ix2>, Array2<f64>>
829826
"The number of data points must match the number of output targets."
830827
);
831828

832-
let values = self
833-
.0
834-
.predict_variances(x)
835-
.expect("MoE variances prediction");
829+
let values = self.0.predic_var(x).expect("MoE variances prediction");
836830
*y = values;
837831
}
838832

@@ -1003,7 +997,7 @@ mod tests {
1003997
.expect("MOE fitted");
1004998
// Smoke test: prediction is pretty good hence variance is very low
1005999
let x = Array1::linspace(0., 1., 20).insert_axis(Axis(1));
1006-
let variances = moe.predict_variances(&x).expect("MOE variances prediction");
1000+
let variances = moe.predic_var(&x).expect("MOE variances prediction");
10071001
assert_abs_diff_eq!(*variances.max().unwrap(), 0., epsilon = 1e-10);
10081002
}
10091003

@@ -1163,7 +1157,7 @@ mod tests {
11631157
assert_rel_or_abs_error(y_deriv[[0, 0]], diff_g);
11641158
assert_rel_or_abs_error(y_deriv[[0, 1]], diff_d);
11651159

1166-
let y_pred = moe.predict_variances(&x).unwrap();
1160+
let y_pred = moe.predic_var(&x).unwrap();
11671161
let y_deriv = moe.predict_variance_derivatives(&x).unwrap();
11681162

11691163
let diff_g = (y_pred[[1, 0]] - y_pred[[2, 0]]) / (2. * e);

moe/src/sgp_algorithm.rs

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -408,10 +408,10 @@ impl GpSurrogate for SparseGpMixture {
408408
}
409409
}
410410

411-
fn predict_variances(&self, x: &ArrayView2<f64>) -> Result<Array2<f64>> {
411+
fn predic_var(&self, x: &ArrayView2<f64>) -> Result<Array2<f64>> {
412412
match self.recombination {
413-
Recombination::Hard => self.predict_variances_hard(x),
414-
Recombination::Smooth(_) => self.predict_variances_smooth(x),
413+
Recombination::Hard => self.predic_var_hard(x),
414+
Recombination::Smooth(_) => self.predic_var_smooth(x),
415415
}
416416
}
417417
/// Save Sgp model in given file.
@@ -479,7 +479,7 @@ impl SparseGpMixture {
479479
/// Gaussian Mixture is used to get the probability of the point to belongs to one cluster
480480
/// or another (ie responsabilities).
481481
/// The smooth recombination of each cluster expert responsabilty is used to get the result.
482-
pub fn predict_variances_smooth(
482+
pub fn predic_var_smooth(
483483
&self,
484484
x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
485485
) -> Result<Array2<f64>> {
@@ -494,7 +494,7 @@ impl SparseGpMixture {
494494
let preds: Array1<f64> = self
495495
.experts
496496
.iter()
497-
.map(|gp| gp.predict_variances(&x).unwrap()[[0, 0]])
497+
.map(|gp| gp.predic_var(&x).unwrap()[[0, 0]])
498498
.collect();
499499
*y = (preds * p * p).sum();
500500
});
@@ -527,7 +527,7 @@ impl SparseGpMixture {
527527
/// Gaussian Mixture is used to get the cluster where the point belongs (highest responsability)
528528
/// The expert of the cluster is used to predict variance value.
529529
/// Returns the variances as a (n, 1) column vector
530-
pub fn predict_variances_hard(
530+
pub fn predic_var_hard(
531531
&self,
532532
x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
533533
) -> Result<Array2<f64>> {
@@ -540,7 +540,7 @@ impl SparseGpMixture {
540540
.for_each(|mut y, x, &c| {
541541
y.assign(
542542
&self.experts[c]
543-
.predict_variances(&x.insert_axis(Axis(0)))
543+
.predic_var(&x.insert_axis(Axis(0)))
544544
.unwrap()
545545
.row(0),
546546
);
@@ -552,11 +552,8 @@ impl SparseGpMixture {
552552
<SparseGpMixture as GpSurrogate>::predict(self, &x.view())
553553
}
554554

555-
pub fn predict_variances(
556-
&self,
557-
x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
558-
) -> Result<Array2<f64>> {
559-
<SparseGpMixture as GpSurrogate>::predict_variances(self, &x.view())
555+
pub fn predic_var(&self, x: &ArrayBase<impl Data<Elem = f64>, Ix2>) -> Result<Array2<f64>> {
556+
<SparseGpMixture as GpSurrogate>::predic_var(self, &x.view())
560557
}
561558

562559
#[cfg(feature = "persistent")]
@@ -611,10 +608,7 @@ impl<'a, D: Data<Elem = f64>> PredictInplace<ArrayBase<D, Ix2>, Array2<f64>>
611608
"The number of data points must match the number of output targets."
612609
);
613610

614-
let values = self
615-
.0
616-
.predict_variances(x)
617-
.expect("Sgp variances prediction");
611+
let values = self.0.predic_var(x).expect("Sgp variances prediction");
618612
*y = values;
619613
}
620614

@@ -675,7 +669,7 @@ mod tests {
675669
let yplot = f_obj(&xplot);
676670
let errvals = (yplot - &sgp_vals).mapv(|v| v.abs());
677671
assert_abs_diff_eq!(errvals, Array2::zeros((xplot.nrows(), 1)), epsilon = 1.0);
678-
let sgp_vars = sgp.predict_variances(&xplot).unwrap();
672+
let sgp_vars = sgp.predic_var(&xplot).unwrap();
679673
let errvars = (&sgp_vars - Array2::from_elem((xplot.nrows(), 1), 0.01)).mapv(|v| v.abs());
680674
assert_abs_diff_eq!(errvars, Array2::zeros((xplot.nrows(), 1)), epsilon = 0.05);
681675
}

0 commit comments

Comments
 (0)