@@ -322,6 +322,20 @@ impl SupervisedModel {
322
322
/// model.train();
323
323
/// ```
324
324
pub fn train ( & mut self ) {
325
+ // Train any necessary preprocessing
326
+ if let PreProcessing :: ReplaceWithPCA {
327
+ number_of_components,
328
+ } = self . settings . preprocessing
329
+ {
330
+ self . train_pca ( self . x_train . clone ( ) , number_of_components) ;
331
+ }
332
+ if let PreProcessing :: ReplaceWithSVD {
333
+ number_of_components,
334
+ } = self . settings . preprocessing
335
+ {
336
+ self . train_svd ( self . x_train . clone ( ) , number_of_components) ;
337
+ }
338
+
325
339
// Preprocess the data
326
340
self . x_train = self . preprocess ( self . x_train . clone ( ) ) ;
327
341
@@ -727,17 +741,18 @@ impl SupervisedModel {
727
741
x
728
742
}
729
743
730
- fn pca_features ( & mut self , x : DenseMatrix < f32 > , n : usize ) -> DenseMatrix < f32 > {
731
- if let None = self . preprocessing . 0 {
732
- let pca = PCA :: fit (
733
- & x,
734
- PCAParameters :: default ( )
735
- . with_n_components ( n)
736
- . with_use_correlation_matrix ( true ) ,
737
- )
738
- . unwrap ( ) ;
739
- self . preprocessing . 0 = Some ( pca) ;
740
- }
744
+ fn train_pca ( & mut self , x : DenseMatrix < f32 > , n : usize ) {
745
+ let pca = PCA :: fit (
746
+ & x,
747
+ PCAParameters :: default ( )
748
+ . with_n_components ( n)
749
+ . with_use_correlation_matrix ( true ) ,
750
+ )
751
+ . unwrap ( ) ;
752
+ self . preprocessing . 0 = Some ( pca) ;
753
+ }
754
+
755
+ fn pca_features ( & self , x : DenseMatrix < f32 > , n : usize ) -> DenseMatrix < f32 > {
741
756
self . preprocessing
742
757
. 0
743
758
. as_ref ( )
@@ -746,11 +761,12 @@ impl SupervisedModel {
746
761
. unwrap ( )
747
762
}
748
763
749
- fn svd_features ( & mut self , x : DenseMatrix < f32 > , n : usize ) -> DenseMatrix < f32 > {
750
- if let None = self . preprocessing . 1 {
751
- let svd = SVD :: fit ( & x, SVDParameters :: default ( ) . with_n_components ( n) ) . unwrap ( ) ;
752
- self . preprocessing . 1 = Some ( svd) ;
753
- }
764
+ fn train_svd ( & mut self , x : DenseMatrix < f32 > , n : usize ) {
765
+ let svd = SVD :: fit ( & x, SVDParameters :: default ( ) . with_n_components ( n) ) . unwrap ( ) ;
766
+ self . preprocessing . 1 = Some ( svd) ;
767
+ }
768
+
769
+ fn svd_features ( & self , x : DenseMatrix < f32 > , n : usize ) -> DenseMatrix < f32 > {
754
770
self . preprocessing
755
771
. 1
756
772
. as_ref ( )
@@ -759,7 +775,7 @@ impl SupervisedModel {
759
775
. unwrap ( )
760
776
}
761
777
762
- fn preprocess ( & mut self , x : DenseMatrix < f32 > ) -> DenseMatrix < f32 > {
778
+ fn preprocess ( & self , x : DenseMatrix < f32 > ) -> DenseMatrix < f32 > {
763
779
match self . settings . preprocessing {
764
780
PreProcessing :: None => x,
765
781
PreProcessing :: AddInteractions => SupervisedModel :: interaction_features ( x) ,
0 commit comments