Skip to content

Commit 3691fc0

Browse files
committed
Rename GP predict_variance_derivatives in predict_var_derivatives
1 parent 40c46a4 commit 3691fc0

File tree

12 files changed

+84
-94
lines changed

12 files changed

+84
-94
lines changed

ego/src/criteria/ei.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ impl InfillCriterion for ExpectedImprovement {
2323
) -> f64 {
2424
let pt = ArrayView::from_shape((1, x.len()), x).unwrap();
2525
if let Ok(p) = obj_model.predict(&pt) {
26-
if let Ok(s) = obj_model.predic_var(&pt) {
26+
if let Ok(s) = obj_model.predict_var(&pt) {
2727
let pred = p[[0, 0]];
2828
let sigma = s[[0, 0]].sqrt();
2929
let args0 = (f_min - pred) / sigma;
@@ -49,7 +49,7 @@ impl InfillCriterion for ExpectedImprovement {
4949
) -> Array1<f64> {
5050
let pt = ArrayView::from_shape((1, x.len()), x).unwrap();
5151
if let Ok(p) = obj_model.predict(&pt) {
52-
if let Ok(s) = obj_model.predic_var(&pt) {
52+
if let Ok(s) = obj_model.predict_var(&pt) {
5353
let sigma = s[[0, 0]].sqrt();
5454
if sigma.abs() < 1e-12 {
5555
Array1::zeros(pt.len())
@@ -59,7 +59,7 @@ impl InfillCriterion for ExpectedImprovement {
5959
let arg = (f_min - pred) / sigma;
6060
let y_prime = obj_model.predict_derivatives(&pt).unwrap();
6161
let y_prime = y_prime.row(0);
62-
let sig_2_prime = obj_model.predict_variance_derivatives(&pt).unwrap();
62+
let sig_2_prime = obj_model.predict_var_derivatives(&pt).unwrap();
6363

6464
let sig_2_prime = sig_2_prime.row(0);
6565
let sig_prime = sig_2_prime.mapv(|v| v / (2. * sigma));

ego/src/criteria/wb2.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -165,11 +165,11 @@ mod tests {
165165
let xtest12 = array![[x1 - h, x2]];
166166
let xtest21 = array![[x1, x2 + h]];
167167
let xtest22 = array![[x1, x2 - h]];
168-
let fdiff1 = (bgp.predic_var(&xtest11.view()).unwrap()
169-
- bgp.predic_var(&xtest12.view()).unwrap())
168+
let fdiff1 = (bgp.predict_var(&xtest11.view()).unwrap()
169+
- bgp.predict_var(&xtest12.view()).unwrap())
170170
/ (2. * h);
171-
let fdiff2 = (bgp.predic_var(&xtest21.view()).unwrap()
172-
- bgp.predic_var(&xtest22.view()).unwrap())
171+
let fdiff2 = (bgp.predict_var(&xtest21.view()).unwrap()
172+
- bgp.predict_var(&xtest22.view()).unwrap())
173173
/ (2. * h);
174174
println!(
175175
"gp var fdiff({}) = [[{}, {}]]",
@@ -180,7 +180,7 @@ mod tests {
180180
println!(
181181
"GP predict variances derivatives({}) = {}",
182182
xtest,
183-
bgp.predict_variance_derivatives(&basetest.view()).unwrap()
183+
bgp.predict_var_derivatives(&basetest.view()).unwrap()
184184
);
185185
}
186186
}

ego/src/egor_solver.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -883,7 +883,7 @@ where
883883
} else {
884884
let x = &xk.view().insert_axis(Axis(0));
885885
let pred = obj_model.predict(x)?[[0, 0]];
886-
let var = obj_model.predic_var(x)?[[0, 0]];
886+
let var = obj_model.predict_var(x)?[[0, 0]];
887887
let conf = match self.config.q_ei {
888888
QEiStrategy::KrigingBeliever => 0.,
889889
QEiStrategy::KrigingBelieverLowerBound => -3.,

ego/src/mixint.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -541,14 +541,14 @@ impl GpSurrogate for MixintMoe {
541541
self.moe.predict(&xcast)
542542
}
543543

544-
fn predic_var(&self, x: &ArrayView2<f64>) -> egobox_moe::Result<Array2<f64>> {
544+
fn predict_var(&self, x: &ArrayView2<f64>) -> egobox_moe::Result<Array2<f64>> {
545545
let mut xcast = if self.work_in_folded_space {
546546
unfold_with_enum_mask(&self.xtypes, x)
547547
} else {
548548
x.to_owned()
549549
};
550550
cast_to_discrete_values_mut(&self.xtypes, &mut xcast);
551-
self.moe.predic_var(&xcast)
551+
self.moe.predict_var(&xcast)
552552
}
553553

554554
/// Save Moe model in given file.
@@ -576,14 +576,14 @@ impl GpSurrogateExt for MixintMoe {
576576
self.moe.predict_derivatives(&xcast)
577577
}
578578

579-
fn predict_variance_derivatives(&self, x: &ArrayView2<f64>) -> egobox_moe::Result<Array2<f64>> {
579+
fn predict_var_derivatives(&self, x: &ArrayView2<f64>) -> egobox_moe::Result<Array2<f64>> {
580580
let mut xcast = if self.work_in_folded_space {
581581
unfold_with_enum_mask(&self.xtypes, x)
582582
} else {
583583
x.to_owned()
584584
};
585585
cast_to_discrete_values_mut(&self.xtypes, &mut xcast);
586-
self.moe.predict_variance_derivatives(&xcast)
586+
self.moe.predict_var_derivatives(&xcast)
587587
}
588588

589589
fn sample(&self, x: &ArrayView2<f64>, n_traj: usize) -> egobox_moe::Result<Array2<f64>> {
@@ -633,7 +633,7 @@ impl<'a, D: Data<Elem = f64>> PredictInplace<ArrayBase<D, Ix2>, Array2<f64>>
633633

634634
let values = self
635635
.0
636-
.predic_var(x)
636+
.predict_var(x)
637637
.expect("MixintMoE variances prediction");
638638
*y = values;
639639
}
@@ -840,7 +840,7 @@ mod tests {
840840
let xtest = Array::linspace(0.0, 4.0, num).insert_axis(Axis(1));
841841
let ytest = mixi_moe.predict(&xtest.view()).expect("Predict val fail");
842842
let yvar = mixi_moe
843-
.predic_var(&xtest.view())
843+
.predict_var(&xtest.view())
844844
.expect("Predict var fail");
845845
println!("{ytest:?}");
846846
assert_abs_diff_eq!(

gp/examples/kriging.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ fn main() {
2424
let ypred = kriging.predict(&xtest).expect("Kriging prediction");
2525
// predict standard deviation
2626
let ysigma = kriging
27-
.predic_var(&xtest)
27+
.predict_var(&xtest)
2828
.expect("Kriging prediction")
2929
.map(|v| v.sqrt());
3030

gp/src/algorithm.rs

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ impl<F: Float> Clone for GpInnerParams<F> {
159159
/// let ytest = xsinx(&xtest);
160160
///
161161
/// let ypred = kriging.predict(&xtest).expect("Kriging prediction");
162-
/// let yvariances = kriging.predic_var(&xtest).expect("Kriging prediction");
162+
/// let yvariances = kriging.predict_var(&xtest).expect("Kriging prediction");
163163
///```
164164
///
165165
/// # Reference:
@@ -270,7 +270,7 @@ impl<F: Float, Mean: RegressionModel<F>, Corr: CorrelationModel<F>> GaussianProc
270270

271271
/// Predict variance values at n given `x` points of nx components specified as a (n, nx) matrix.
272272
/// Returns n variance values as (n, 1) column vector.
273-
pub fn predic_var(&self, x: &ArrayBase<impl Data<Elem = F>, Ix2>) -> Result<Array2<F>> {
273+
pub fn predict_var(&self, x: &ArrayBase<impl Data<Elem = F>, Ix2>) -> Result<Array2<F>> {
274274
let (rt, u, _) = self._compute_rt_u(x);
275275

276276
let mut b = Array::ones(rt.ncols()) - rt.mapv(|v| v * v).sum_axis(Axis(0))
@@ -561,7 +561,7 @@ impl<F: Float, Mean: RegressionModel<F>, Corr: CorrelationModel<F>> GaussianProc
561561
/// Predict variance derivatives at a point `x` specified as a (nx,) vector where x has nx components.
562562
/// Returns a (nx,) vector containing variance derivatives at `x` wrt each nx components
563563
#[cfg(not(feature = "blas"))]
564-
pub fn predict_variance_derivatives_single(
564+
pub fn predict_var_derivatives_single(
565565
&self,
566566
x: &ArrayBase<impl Data<Elem = F>, Ix1>,
567567
) -> Array1<F> {
@@ -630,7 +630,7 @@ impl<F: Float, Mean: RegressionModel<F>, Corr: CorrelationModel<F>> GaussianProc
630630

631631
/// See non blas version
632632
#[cfg(feature = "blas")]
633-
pub fn predict_variance_derivatives_single(
633+
pub fn predict_var_derivatives_single(
634634
&self,
635635
x: &ArrayBase<impl Data<Elem = F>, Ix1>,
636636
) -> Array1<F> {
@@ -709,14 +709,11 @@ impl<F: Float, Mean: RegressionModel<F>, Corr: CorrelationModel<F>> GaussianProc
709709

710710
/// Predict variance derivatives at a set of points `x` specified as a (n, nx) matrix where x has nx components.
711711
/// Returns a (n, nx) matrix containing variance derivatives at `x` wrt each nx components
712-
pub fn predict_variance_derivatives(
713-
&self,
714-
x: &ArrayBase<impl Data<Elem = F>, Ix2>,
715-
) -> Array2<F> {
712+
pub fn predict_var_derivatives(&self, x: &ArrayBase<impl Data<Elem = F>, Ix2>) -> Array2<F> {
716713
let mut derivs = Array::zeros((x.nrows(), x.ncols()));
717714
Zip::from(derivs.rows_mut())
718715
.and(x.rows())
719-
.for_each(|mut der, x| der.assign(&self.predict_variance_derivatives_single(&x)));
716+
.for_each(|mut der, x| der.assign(&self.predict_var_derivatives_single(&x)));
720717
derivs
721718
}
722719
}
@@ -767,7 +764,7 @@ where
767764
"The number of data points must match the number of output targets."
768765
);
769766

770-
let values = self.0.predic_var(x).expect("GP Prediction");
767+
let values = self.0.predict_var(x).expect("GP Prediction");
771768
*y = values;
772769
}
773770

@@ -1300,12 +1297,12 @@ mod tests {
13001297
let gpr_vals = gp.predict(&xplot).unwrap();
13011298

13021299
let yvars = gp
1303-
.predic_var(&arr2(&[[1.0], [3.5]]))
1300+
.predict_var(&arr2(&[[1.0], [3.5]]))
13041301
.expect("prediction error");
13051302
let expected_vars = arr2(&[[0.], [0.1]]);
13061303
assert_abs_diff_eq!(expected_vars, yvars, epsilon = 0.5);
13071304

1308-
let gpr_vars = gp.predic_var(&xplot).unwrap();
1305+
let gpr_vars = gp.predict_var(&xplot).unwrap();
13091306

13101307
let test_dir = "target/tests";
13111308
std::fs::create_dir_all(test_dir).ok();
@@ -1602,9 +1599,9 @@ mod tests {
16021599
println!("value at [{},{}] = {}", xa, xb, y_pred);
16031600
let y_deriv = gp.predict_derivatives(&x);
16041601
println!("deriv at [{},{}] = {}", xa, xb, y_deriv);
1605-
let y_pred = gp.predic_var(&x).unwrap();
1602+
let y_pred = gp.predict_var(&x).unwrap();
16061603
println!("variance at [{},{}] = {}", xa, xb, y_pred);
1607-
let y_deriv = gp.predict_variance_derivatives(&x);
1604+
let y_deriv = gp.predict_var_derivatives(&x);
16081605
println!("variance deriv at [{},{}] = {}", xa, xb, y_deriv);
16091606

16101607
let diff_g = (y_pred[[1, 0]] - y_pred[[2, 0]]) / (2. * e);
@@ -1658,9 +1655,9 @@ mod tests {
16581655
[xa, xb + e],
16591656
[xa, xb - e]
16601657
];
1661-
let y_pred = gp.predic_var(&x).unwrap();
1658+
let y_pred = gp.predict_var(&x).unwrap();
16621659
println!("variance at [{xa},{xb}] = {y_pred}");
1663-
let y_deriv = gp.predict_variance_derivatives(&x);
1660+
let y_deriv = gp.predict_var_derivatives(&x);
16641661
println!("variance deriv at [{xa},{xb}] = {y_deriv}");
16651662

16661663
let diff_g = (y_pred[[1, 0]] - y_pred[[2, 0]]) / (2. * e);

gp/src/sparse_algorithm.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ impl<F: Float> Clone for WoodburyData<F> {
121121
/// // Predict with our trained SGP
122122
/// let xplot = Array::linspace(-1., 1., 100).insert_axis(Axis(1));
123123
/// let sgp_vals = sgp.predict(&xplot).unwrap();
124-
/// let sgp_vars = sgp.predic_var(&xplot).unwrap();
124+
/// let sgp_vars = sgp.predict_var(&xplot).unwrap();
125125
/// ```
126126
///
127127
/// # Reference
@@ -239,7 +239,7 @@ impl<F: Float, Corr: CorrelationModel<F>> SparseGaussianProcess<F, Corr> {
239239

240240
/// Predict variance values at n given `x` points of nx components specified as a (n, nx) matrix.
241241
/// Returns n variance values as (n, 1) column vector.
242-
pub fn predic_var(&self, x: &ArrayBase<impl Data<Elem = F>, Ix2>) -> Result<Array2<F>> {
242+
pub fn predict_var(&self, x: &ArrayBase<impl Data<Elem = F>, Ix2>) -> Result<Array2<F>> {
243243
let kx = self.compute_k(&self.inducings, x, &self.w_star, &self.theta, self.sigma2);
244244
let kxx = Array::from_elem(x.nrows(), self.sigma2);
245245
let var = kxx - (self.w_data.inv.t().clone().dot(&kx) * &kx).sum_axis(Axis(0));
@@ -340,7 +340,7 @@ where
340340
"The number of data points must match the number of output targets."
341341
);
342342

343-
let values = self.0.predic_var(x).expect("GP Prediction");
343+
let values = self.0.predict_var(x).expect("GP Prediction");
344344
*y = values;
345345
}
346346

@@ -863,7 +863,7 @@ mod tests {
863863
let yplot = f_obj(&xplot);
864864
let errvals = (yplot - &sgp_vals).mapv(|v| v.abs());
865865
assert_abs_diff_eq!(errvals, Array2::zeros((xplot.nrows(), 1)), epsilon = 0.5);
866-
let sgp_vars = sgp.predic_var(&xplot).unwrap();
866+
let sgp_vars = sgp.predict_var(&xplot).unwrap();
867867
let errvars = (&sgp_vars - Array2::from_elem((xplot.nrows(), 1), 0.01)).mapv(|v| v.abs());
868868
assert_abs_diff_eq!(errvars, Array2::zeros((xplot.nrows(), 1)), epsilon = 0.2);
869869

@@ -906,7 +906,7 @@ mod tests {
906906
assert_abs_diff_eq!(eta2, sgp.noise_variance());
907907

908908
let sgp_vals = sgp.predict(&xplot).unwrap();
909-
let sgp_vars = sgp.predic_var(&xplot).unwrap();
909+
let sgp_vars = sgp.predict_var(&xplot).unwrap();
910910

911911
save_data(&xt, &yt, sgp.inducings(), &xplot, &sgp_vals, &sgp_vars);
912912
}
@@ -953,7 +953,7 @@ mod tests {
953953
assert_abs_diff_eq!(&z, sgp.inducings(), epsilon = 0.0015);
954954

955955
let sgp_vals = sgp.predict(&xplot).unwrap();
956-
let sgp_vars = sgp.predic_var(&xplot).unwrap();
956+
let sgp_vars = sgp.predict_var(&xplot).unwrap();
957957

958958
save_data(&xt, &yt, &z, &xplot, &sgp_vals, &sgp_vars);
959959
}

0 commit comments

Comments
 (0)