diff --git a/src/model_selection/hyper_tuning.rs b/src/model_selection/hyper_tuning/grid_search.rs
similarity index 53%
rename from src/model_selection/hyper_tuning.rs
rename to src/model_selection/hyper_tuning/grid_search.rs
index cb69da18..053611a1 100644
--- a/src/model_selection/hyper_tuning.rs
+++ b/src/model_selection/hyper_tuning/grid_search.rs
@@ -1,3 +1,12 @@
+use crate::{
+    api::Predictor,
+    error::{Failed, FailedError},
+    linalg::Matrix,
+    math::num::RealNumber,
+};
+
+use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};
+
 /// grid search results.
 #[derive(Clone, Debug)]
 pub struct GridSearchResult {
@@ -60,58 +69,61 @@ where
 
 #[cfg(test)]
 mod tests {
-    use crate::linear::logistic_regression::{
-        LogisticRegression, LogisticRegressionSearchParameters,
-};
+    use crate::{
+        linalg::naive::dense_matrix::DenseMatrix,
+        linear::logistic_regression::{LogisticRegression, LogisticRegressionSearchParameters},
+        metrics::accuracy,
+        model_selection::{hyper_tuning::grid_search, KFold},
+    };
 
-    #[test]
-    fn test_grid_search() {
-        let x = DenseMatrix::from_2d_array(&[
-            &[5.1, 3.5, 1.4, 0.2],
-            &[4.9, 3.0, 1.4, 0.2],
-            &[4.7, 3.2, 1.3, 0.2],
-            &[4.6, 3.1, 1.5, 0.2],
-            &[5.0, 3.6, 1.4, 0.2],
-            &[5.4, 3.9, 1.7, 0.4],
-            &[4.6, 3.4, 1.4, 0.3],
-            &[5.0, 3.4, 1.5, 0.2],
-            &[4.4, 2.9, 1.4, 0.2],
-            &[4.9, 3.1, 1.5, 0.1],
-            &[7.0, 3.2, 4.7, 1.4],
-            &[6.4, 3.2, 4.5, 1.5],
-            &[6.9, 3.1, 4.9, 1.5],
-            &[5.5, 2.3, 4.0, 1.3],
-            &[6.5, 2.8, 4.6, 1.5],
-            &[5.7, 2.8, 4.5, 1.3],
-            &[6.3, 3.3, 4.7, 1.6],
-            &[4.9, 2.4, 3.3, 1.0],
-            &[6.6, 2.9, 4.6, 1.3],
-            &[5.2, 2.7, 3.9, 1.4],
-        ]);
-        let y = vec![
-            0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
-        ];
+    #[test]
+    fn test_grid_search() {
+        let x = DenseMatrix::from_2d_array(&[
+            &[5.1, 3.5, 1.4, 0.2],
+            &[4.9, 3.0, 1.4, 0.2],
+            &[4.7, 3.2, 1.3, 0.2],
+            &[4.6, 3.1, 1.5, 0.2],
+            &[5.0, 3.6, 1.4, 0.2],
+            &[5.4, 3.9, 1.7, 0.4],
+            &[4.6, 3.4, 1.4, 0.3],
+            &[5.0, 3.4, 1.5, 0.2],
+            &[4.4, 2.9, 1.4, 0.2],
+            &[4.9, 3.1, 1.5, 0.1],
+            &[7.0, 3.2, 4.7, 1.4],
+            &[6.4, 3.2, 4.5, 1.5],
+            &[6.9, 3.1, 4.9, 1.5],
+            &[5.5, 2.3, 4.0, 1.3],
+            &[6.5, 2.8, 4.6, 1.5],
+            &[5.7, 2.8, 4.5, 1.3],
+            &[6.3, 3.3, 4.7, 1.6],
+            &[4.9, 2.4, 3.3, 1.0],
+            &[6.6, 2.9, 4.6, 1.3],
+            &[5.2, 2.7, 3.9, 1.4],
+        ]);
+        let y = vec![
+            0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
+        ];
 
-        let cv = KFold {
-            n_splits: 5,
-            ..KFold::default()
-        };
+        let cv = KFold {
+            n_splits: 5,
+            ..KFold::default()
+        };
 
-        let parameters = LogisticRegressionSearchParameters {
-            alpha: vec![0., 1.],
-            ..Default::default()
-        };
+        let parameters = LogisticRegressionSearchParameters {
+            alpha: vec![0., 1.],
+            ..Default::default()
+        };
 
-        let results = grid_search(
-            LogisticRegression::fit,
-            &x,
-            &y,
-            parameters.into_iter(),
-            cv,
-            &accuracy,
-        )
-        .unwrap();
+        let results = grid_search(
+            LogisticRegression::fit,
+            &x,
+            &y,
+            parameters.into_iter(),
+            cv,
+            &accuracy,
+        )
+        .unwrap();
 
-        assert!([0., 1.].contains(&results.parameters.alpha));
-    }
+        assert!([0., 1.].contains(&results.parameters.alpha));
+    }
 }
diff --git a/src/model_selection/hyper_tuning/mod.rs b/src/model_selection/hyper_tuning/mod.rs
new file mode 100644
index 00000000..6810d1a4
--- /dev/null
+++ b/src/model_selection/hyper_tuning/mod.rs
@@ -0,0 +1,2 @@
+mod grid_search;
+pub use grid_search::{grid_search, GridSearchResult};
diff --git a/src/model_selection/mod.rs b/src/model_selection/mod.rs
index 21cf7ed3..943c143a 100644
--- a/src/model_selection/mod.rs
+++ b/src/model_selection/mod.rs
@@ -110,8 +110,10 @@ use crate::math::num::RealNumber;
 use crate::rand::get_rng_impl;
 use rand::seq::SliceRandom;
 
+pub(crate) mod hyper_tuning;
 pub(crate) mod kfold;
 
+pub use hyper_tuning::{grid_search, GridSearchResult};
 pub use kfold::{KFold, KFoldIter};
 
 /// An interface for the K-Folds cross-validator
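
For reference, a minimal sketch of how downstream code would call the re-exported API after this change. The crate name smartcore and the alpha grid are assumptions; the call shape mirrors the test in grid_search.rs above.

use smartcore::linalg::naive::dense_matrix::DenseMatrix;
use smartcore::linear::logistic_regression::{
    LogisticRegression, LogisticRegressionSearchParameters,
};
use smartcore::metrics::accuracy;
// grid_search (and GridSearchResult) are now reachable from the model_selection root.
use smartcore::model_selection::{grid_search, KFold};

/// Returns the alpha value that scores best under 5-fold cross-validated accuracy.
/// The grid of candidate alphas here is illustrative, not taken from the diff.
fn best_alpha(x: &DenseMatrix<f64>, y: &Vec<f64>) -> f64 {
    let cv = KFold {
        n_splits: 5,
        ..KFold::default()
    };
    let parameters = LogisticRegressionSearchParameters {
        alpha: vec![0., 0.5, 1.],
        ..Default::default()
    };
    let results = grid_search(
        LogisticRegression::fit,
        x,
        y,
        parameters.into_iter(),
        cv,
        &accuracy,
    )
    .unwrap();
    results.parameters.alpha
}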