@@ -414,10 +414,10 @@ impl GpSurrogate for GpMixture {
414414 }
415415 }
416416
417- fn predict_variances ( & self , x : & ArrayView2 < f64 > ) -> Result < Array2 < f64 > > {
417+ fn predic_var ( & self , x : & ArrayView2 < f64 > ) -> Result < Array2 < f64 > > {
418418 match self . recombination {
419- Recombination :: Hard => self . predict_variances_hard ( x) ,
420- Recombination :: Smooth ( _) => self . predict_variances_smooth ( x) ,
419+ Recombination :: Hard => self . predic_var_hard ( x) ,
420+ Recombination :: Smooth ( _) => self . predic_var_smooth ( x) ,
421421 }
422422 }
423423 /// Save Moe model in given file.
@@ -514,7 +514,7 @@ impl GpMixture {
514514 /// Gaussian Mixture is used to get the probability of the point to belongs to one cluster
515515 /// or another (ie responsabilities).
516516 /// The smooth recombination of each cluster expert responsabilty is used to get the result.
517- pub fn predict_variances_smooth (
517+ pub fn predic_var_smooth (
518518 & self ,
519519 x : & ArrayBase < impl Data < Elem = f64 > , Ix2 > ,
520520 ) -> Result < Array2 < f64 > > {
@@ -529,7 +529,7 @@ impl GpMixture {
529529 let preds: Array1 < f64 > = self
530530 . experts
531531 . iter ( )
532- . map ( |gp| gp. predict_variances ( & x) . unwrap ( ) [ [ 0 , 0 ] ] )
532+ . map ( |gp| gp. predic_var ( & x) . unwrap ( ) [ [ 0 , 0 ] ] )
533533 . collect ( ) ;
534534 * y = ( preds * p * p) . sum ( ) ;
535535 } ) ;
@@ -606,7 +606,7 @@ impl GpMixture {
606606 let preds: Array1 < f64 > = self
607607 . experts
608608 . iter ( )
609- . map ( |gp| gp. predict_variances ( & xii) . unwrap ( ) [ [ 0 , 0 ] ] )
609+ . map ( |gp| gp. predic_var ( & xii) . unwrap ( ) [ [ 0 , 0 ] ] )
610610 . collect ( ) ;
611611 let drvs: Vec < Array1 < f64 > > = self
612612 . experts
@@ -665,7 +665,7 @@ impl GpMixture {
665665 /// Gaussian Mixture is used to get the cluster where the point belongs (highest responsability)
666666 /// The expert of the cluster is used to predict variance value.
667667 /// Returns the variances as a (n, 1) column vector
668- pub fn predict_variances_hard (
668+ pub fn predic_var_hard (
669669 & self ,
670670 x : & ArrayBase < impl Data < Elem = f64 > , Ix2 > ,
671671 ) -> Result < Array2 < f64 > > {
@@ -678,7 +678,7 @@ impl GpMixture {
678678 . for_each ( |mut y, x, & c| {
679679 y. assign (
680680 & self . experts [ c]
681- . predict_variances ( & x. insert_axis ( Axis ( 0 ) ) )
681+ . predic_var ( & x. insert_axis ( Axis ( 0 ) ) )
682682 . unwrap ( )
683683 . row ( 0 ) ,
684684 ) ;
@@ -747,11 +747,8 @@ impl GpMixture {
747747 <GpMixture as GpSurrogate >:: predict ( self , & x. view ( ) )
748748 }
749749
750- pub fn predict_variances (
751- & self ,
752- x : & ArrayBase < impl Data < Elem = f64 > , Ix2 > ,
753- ) -> Result < Array2 < f64 > > {
754- <GpMixture as GpSurrogate >:: predict_variances ( self , & x. view ( ) )
750+ pub fn predic_var ( & self , x : & ArrayBase < impl Data < Elem = f64 > , Ix2 > ) -> Result < Array2 < f64 > > {
751+ <GpMixture as GpSurrogate >:: predic_var ( self , & x. view ( ) )
755752 }
756753
757754 pub fn predict_derivatives (
@@ -829,10 +826,7 @@ impl<'a, D: Data<Elem = f64>> PredictInplace<ArrayBase<D, Ix2>, Array2<f64>>
829826 "The number of data points must match the number of output targets."
830827 ) ;
831828
832- let values = self
833- . 0
834- . predict_variances ( x)
835- . expect ( "MoE variances prediction" ) ;
829+ let values = self . 0 . predic_var ( x) . expect ( "MoE variances prediction" ) ;
836830 * y = values;
837831 }
838832
@@ -1003,7 +997,7 @@ mod tests {
1003997 . expect ( "MOE fitted" ) ;
1004998 // Smoke test: prediction is pretty good hence variance is very low
1005999 let x = Array1 :: linspace ( 0. , 1. , 20 ) . insert_axis ( Axis ( 1 ) ) ;
1006- let variances = moe. predict_variances ( & x) . expect ( "MOE variances prediction" ) ;
1000+ let variances = moe. predic_var ( & x) . expect ( "MOE variances prediction" ) ;
10071001 assert_abs_diff_eq ! ( * variances. max( ) . unwrap( ) , 0. , epsilon = 1e-10 ) ;
10081002 }
10091003
@@ -1163,7 +1157,7 @@ mod tests {
11631157 assert_rel_or_abs_error ( y_deriv[ [ 0 , 0 ] ] , diff_g) ;
11641158 assert_rel_or_abs_error ( y_deriv[ [ 0 , 1 ] ] , diff_d) ;
11651159
1166- let y_pred = moe. predict_variances ( & x) . unwrap ( ) ;
1160+ let y_pred = moe. predic_var ( & x) . unwrap ( ) ;
11671161 let y_deriv = moe. predict_variance_derivatives ( & x) . unwrap ( ) ;
11681162
11691163 let diff_g = ( y_pred[ [ 1 , 0 ] ] - y_pred[ [ 2 , 0 ] ] ) / ( 2. * e) ;
0 commit comments