Skip to content

Commit 1b7cf5b

Browse files
committed
Predict no longer requires mut
1 parent 1a574c1 commit 1b7cf5b

File tree

1 file changed

+33
-17
lines changed

1 file changed

+33
-17
lines changed

src/lib.rs

+33-17
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,20 @@ impl SupervisedModel {
322322
/// model.train();
323323
/// ```
324324
pub fn train(&mut self) {
325+
// Train any necessary preprocessing
326+
if let PreProcessing::ReplaceWithPCA {
327+
number_of_components,
328+
} = self.settings.preprocessing
329+
{
330+
self.train_pca(self.x_train.clone(), number_of_components);
331+
}
332+
if let PreProcessing::ReplaceWithSVD {
333+
number_of_components,
334+
} = self.settings.preprocessing
335+
{
336+
self.train_svd(self.x_train.clone(), number_of_components);
337+
}
338+
325339
// Preprocess the data
326340
self.x_train = self.preprocess(self.x_train.clone());
327341

@@ -727,17 +741,18 @@ impl SupervisedModel {
727741
x
728742
}
729743

730-
fn pca_features(&mut self, x: DenseMatrix<f32>, n: usize) -> DenseMatrix<f32> {
731-
if let None = self.preprocessing.0 {
732-
let pca = PCA::fit(
733-
&x,
734-
PCAParameters::default()
735-
.with_n_components(n)
736-
.with_use_correlation_matrix(true),
737-
)
738-
.unwrap();
739-
self.preprocessing.0 = Some(pca);
740-
}
744+
fn train_pca(&mut self, x: DenseMatrix<f32>, n: usize) {
745+
let pca = PCA::fit(
746+
&x,
747+
PCAParameters::default()
748+
.with_n_components(n)
749+
.with_use_correlation_matrix(true),
750+
)
751+
.unwrap();
752+
self.preprocessing.0 = Some(pca);
753+
}
754+
755+
fn pca_features(&self, x: DenseMatrix<f32>, n: usize) -> DenseMatrix<f32> {
741756
self.preprocessing
742757
.0
743758
.as_ref()
@@ -746,11 +761,12 @@ impl SupervisedModel {
746761
.unwrap()
747762
}
748763

749-
fn svd_features(&mut self, x: DenseMatrix<f32>, n: usize) -> DenseMatrix<f32> {
750-
if let None = self.preprocessing.1 {
751-
let svd = SVD::fit(&x, SVDParameters::default().with_n_components(n)).unwrap();
752-
self.preprocessing.1 = Some(svd);
753-
}
764+
fn train_svd(&mut self, x: DenseMatrix<f32>, n: usize) {
765+
let svd = SVD::fit(&x, SVDParameters::default().with_n_components(n)).unwrap();
766+
self.preprocessing.1 = Some(svd);
767+
}
768+
769+
fn svd_features(&self, x: DenseMatrix<f32>, n: usize) -> DenseMatrix<f32> {
754770
self.preprocessing
755771
.1
756772
.as_ref()
@@ -759,7 +775,7 @@ impl SupervisedModel {
759775
.unwrap()
760776
}
761777

762-
fn preprocess(&mut self, x: DenseMatrix<f32>) -> DenseMatrix<f32> {
778+
fn preprocess(&self, x: DenseMatrix<f32>) -> DenseMatrix<f32> {
763779
match self.settings.preprocessing {
764780
PreProcessing::None => x,
765781
PreProcessing::AddInteractions => SupervisedModel::interaction_features(x),

0 commit comments

Comments
 (0)