diff --git a/src/model_selection/hyper_tuning/grid_search.rs b/src/model_selection/hyper_tuning/grid_search.rs
index 053611a1..1544faf0 100644
--- a/src/model_selection/hyper_tuning/grid_search.rs
+++ b/src/model_selection/hyper_tuning/grid_search.rs
@@ -1,5 +1,5 @@
 use crate::{
-    api::Predictor,
+    api::{Predictor, SupervisedEstimator},
     error::{Failed, FailedError},
     linalg::Matrix,
     math::num::RealNumber,
@@ -7,74 +7,169 @@ use crate::{
 
 use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};
 
-/// grid search results.
-#[derive(Clone, Debug)]
-pub struct GridSearchResult<T: RealNumber, I> {
-    /// Vector with test scores on each cv split
-    pub cross_validation_result: CrossValidationResult<T>,
-    /// Vector with training scores on each cv split
-    pub parameters: I,
-}
-
-/// Search for the best estimator by testing all possible combinations with cross-validation using given metric.
-/// * `fit_estimator` - a `fit` function of an estimator
-/// * `x` - features, matrix of size _NxM_ where _N_ is number of samples and _M_ is number of attributes.
-/// * `y` - target values, should be of size _N_
-/// * `parameter_search` - an iterator for parameters that will be tested.
-/// * `cv` - the cross-validation splitting strategy, should be an instance of [`BaseKFold`](./trait.BaseKFold.html)
-/// * `score` - a metric to use for evaluation, see [metrics](../metrics/index.html)
-pub fn grid_search<T, M, I, E, K, F, S>(
-    fit_estimator: F,
-    x: &M,
-    y: &M::RowVector,
-    parameter_search: I,
-    cv: K,
-    score: S,
-) -> Result<GridSearchResult<T, I::Item>, Failed>
-where
+/// Parameters for GridSearchCV
+#[derive(Debug)]
+pub struct GridSearchCVParameters<
     T: RealNumber,
     M: Matrix<T>,
-    I: Iterator,
-    I::Item: Clone,
+    C: Clone,
+    I: Iterator<Item = C>,
     E: Predictor<M, M::RowVector>,
+    F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
     K: BaseKFold,
-    F: Fn(&M, &M::RowVector, I::Item) -> Result<E, Failed>,
     S: Fn(&M::RowVector, &M::RowVector) -> T,
+> {
+    _phantom: std::marker::PhantomData<(T, M)>,
+
+    parameters_search: I,
+    estimator: F,
+    score: S,
+    cv: K,
+}
+
+impl<
+        T: RealNumber,
+        M: Matrix<T>,
+        C: Clone,
+        I: Iterator<Item = C>,
+        E: Predictor<M, M::RowVector>,
+        F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
+        K: BaseKFold,
+        S: Fn(&M::RowVector, &M::RowVector) -> T,
+    > GridSearchCVParameters<T, M, C, I, E, F, K, S>
 {
-    let mut best_result: Option<CrossValidationResult<T>> = None;
-    let mut best_parameters = None;
+    /// Create new GridSearchCVParameters
+    pub fn new(parameters_search: I, estimator: F, score: S, cv: K) -> Self {
+        GridSearchCVParameters {
+            _phantom: std::marker::PhantomData,
+            parameters_search,
+            estimator,
+            score,
+            cv,
+        }
+    }
+}
+/// Exhaustive search over specified parameter values for an estimator.
+#[derive(Debug)]
+pub struct GridSearchCV<T: RealNumber, M: Matrix<T>, C: Clone, E: Predictor<M, M::RowVector>> {
+    _phantom: std::marker::PhantomData<(T, M)>,
+    predictor: E,
+    /// Cross validation results.
+    pub cross_validation_result: CrossValidationResult<T>,
+    /// best parameter
+    pub best_parameter: C,
+}
+
+impl<T: RealNumber, M: Matrix<T>, E: Predictor<M, M::RowVector>, C: Clone>
+    GridSearchCV<T, M, C, E>
+{
+    /// Search for the best estimator by testing all possible combinations with cross-validation using given metric.
+    /// * `x` - features, matrix of size _NxM_ where _N_ is number of samples and _M_ is number of attributes.
+    /// * `y` - target values, should be of size _N_
+    /// * `gs_parameters` - GridSearchCVParameters struct
+    pub fn fit<
+        I: Iterator<Item = C>,
+        K: BaseKFold,
+        F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
+        S: Fn(&M::RowVector, &M::RowVector) -> T,
+    >(
+        x: &M,
+        y: &M::RowVector,
+        gs_parameters: GridSearchCVParameters<T, M, C, I, E, F, K, S>,
+    ) -> Result<Self, Failed> {
+        let mut best_result: Option<CrossValidationResult<T>> = None;
+        let mut best_parameters = None;
+        let parameters_search = gs_parameters.parameters_search;
+        let estimator = gs_parameters.estimator;
+        let cv = gs_parameters.cv;
+        let score = gs_parameters.score;
 
-    for parameters in parameter_search {
-        let result = cross_validate(&fit_estimator, x, y, &parameters, &cv, &score)?;
-        if best_result.is_none()
-            || result.mean_test_score() > best_result.as_ref().unwrap().mean_test_score()
+        for parameters in parameters_search {
+            let result = cross_validate(&estimator, x, y, &parameters, &cv, &score)?;
+            if best_result.is_none()
+                || result.mean_test_score() > best_result.as_ref().unwrap().mean_test_score()
+            {
+                best_parameters = Some(parameters);
+                best_result = Some(result);
+            }
+        }
+
+        if let (Some(best_parameter), Some(cross_validation_result)) =
+            (best_parameters, best_result)
         {
-            best_parameters = Some(parameters);
-            best_result = Some(result);
+            let predictor = estimator(x, y, best_parameter.clone())?;
+            Ok(Self {
+                _phantom: gs_parameters._phantom,
+                predictor,
+                cross_validation_result,
+                best_parameter,
+            })
+        } else {
+            Err(Failed::because(
+                FailedError::FindFailed,
+                "there were no parameter sets found",
+            ))
         }
     }
 
-    if let (Some(parameters), Some(cross_validation_result)) = (best_parameters, best_result) {
-        Ok(GridSearchResult {
-            cross_validation_result,
-            parameters,
-        })
-    } else {
-        Err(Failed::because(
-            FailedError::FindFailed,
-            "there were no parameter sets found",
-        ))
-    }
-}
+    /// Return grid search cross validation results
+    pub fn cv_results(&self) -> &CrossValidationResult<T> {
+        &self.cross_validation_result
+    }
+
+    /// Return best parameters found
+    pub fn best_parameters(&self) -> &C {
+        &self.best_parameter
+    }
+
+    /// Call predict on the estimator with the best found parameters
+    pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
+        self.predictor.predict(x)
+    }
+}
+
+impl<
+        T: RealNumber,
+        M: Matrix<T>,
+        C: Clone,
+        I: Iterator<Item = C>,
+        E: Predictor<M, M::RowVector>,
+        F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
+        K: BaseKFold,
+        S: Fn(&M::RowVector, &M::RowVector) -> T,
+    > SupervisedEstimator<M, M::RowVector, GridSearchCVParameters<T, M, C, I, E, F, K, S>>
+    for GridSearchCV<T, M, C, E>
+{
+    fn fit(
+        x: &M,
+        y: &M::RowVector,
+        parameters: GridSearchCVParameters<T, M, C, I, E, F, K, S>,
+    ) -> Result<Self, Failed> {
+        GridSearchCV::fit(x, y, parameters)
+    }
+}
+
+impl<T: RealNumber, M: Matrix<T>, C: Clone, E: Predictor<M, M::RowVector>>
+    Predictor<M, M::RowVector> for GridSearchCV<T, M, C, E>
+{
+    fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
+        self.predict(x)
+    }
+}
 
 #[cfg(test)]
 mod tests {
+
     use crate::{
         linalg::naive::dense_matrix::DenseMatrix,
         linear::logistic_regression::{LogisticRegression, LogisticRegressionSearchParameters},
         metrics::accuracy,
-        model_selection::{hyper_tuning::grid_search, KFold},
+        model_selection::{
+            hyper_tuning::grid_search::{self, GridSearchCVParameters},
+            KFold,
+        },
     };
+    use grid_search::GridSearchCV;
 
     #[test]
     fn test_grid_search() {
@@ -114,16 +209,28 @@ mod tests {
             ..Default::default()
         };
 
-        let results = grid_search(
-            LogisticRegression::fit,
+        let grid_search = GridSearchCV::fit(
             &x,
             &y,
-            parameters.into_iter(),
-            cv,
-            &accuracy,
+            GridSearchCVParameters {
+                estimator: LogisticRegression::fit,
+                score: accuracy,
+                cv,
+                parameters_search: parameters.into_iter(),
+                _phantom: Default::default(),
+            },
        )
        .unwrap();
+        let best_parameters = grid_search.best_parameters();
+
+        assert!([1.].contains(&best_parameters.alpha));
+
+        let cv_results = grid_search.cv_results();
+
+        assert_eq!(cv_results.mean_test_score(), 0.9);
 
-        assert!([0., 1.].contains(&results.parameters.alpha));
+        let x = DenseMatrix::from_2d_array(&[&[5., 3., 1., 0.]]);
+        let result = grid_search.predict(&x).unwrap();
+        assert_eq!(result, vec![0.]);
     }
 }
diff --git a/src/model_selection/hyper_tuning/mod.rs b/src/model_selection/hyper_tuning/mod.rs
index 6810d1a4..dfe0d06b 100644
--- a/src/model_selection/hyper_tuning/mod.rs
+++ b/src/model_selection/hyper_tuning/mod.rs
@@ -1,2 +1,2 @@
 mod grid_search;
-pub use grid_search::{grid_search, GridSearchResult};
+pub use grid_search::{GridSearchCV, GridSearchCVParameters};
diff --git a/src/model_selection/mod.rs b/src/model_selection/mod.rs
index 943c143a..f16b9559 100644
--- a/src/model_selection/mod.rs
+++ b/src/model_selection/mod.rs
@@ -113,7 +113,7 @@ use rand::seq::SliceRandom;
 pub(crate) mod hyper_tuning;
 pub(crate) mod kfold;
 
-pub use hyper_tuning::{grid_search, GridSearchResult};
+pub use hyper_tuning::{GridSearchCV, GridSearchCVParameters};
 pub use kfold::{KFold, KFoldIter};
 
 /// An interface for the K-Folds cross-validator
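For reference, a minimal usage sketch of the new public API from outside the crate, using the `GridSearchCVParameters::new` constructor added in this patch rather than the struct literal the test uses. The dataset, fold count, and printed values are placeholders invented for illustration, not values from this patch; it assumes the same smartcore items the test above imports (DenseMatrix, KFold, LogisticRegression, accuracy).

```rust
use smartcore::linalg::naive::dense_matrix::DenseMatrix;
use smartcore::linear::logistic_regression::{
    LogisticRegression, LogisticRegressionSearchParameters,
};
use smartcore::metrics::accuracy;
use smartcore::model_selection::{GridSearchCV, GridSearchCVParameters, KFold};

fn main() {
    // Placeholder data: 8 samples, 2 features, classes alternating so each fold sees both.
    let x = DenseMatrix::from_2d_array(&[
        &[1.0, 2.0],
        &[9.0, 8.0],
        &[1.5, 2.5],
        &[8.5, 9.5],
        &[0.5, 1.5],
        &[9.5, 8.5],
        &[1.2, 2.2],
        &[8.8, 9.2],
    ]);
    let y = vec![0., 1., 0., 1., 0., 1., 0., 1.];

    // Candidate parameter sets come from the search-parameters iterator.
    let parameters = LogisticRegressionSearchParameters {
        alpha: vec![0., 1.],
        ..Default::default()
    };

    let grid_search = GridSearchCV::fit(
        &x,
        &y,
        GridSearchCVParameters::new(
            parameters.into_iter(),            // parameters_search
            LogisticRegression::fit,           // estimator
            accuracy,                          // score
            KFold::default().with_n_splits(2), // cv
        ),
    )
    .unwrap();

    // Inspect the search results and predict with the refitted best estimator.
    println!("best alpha: {}", grid_search.best_parameters().alpha);
    println!("mean test score: {}", grid_search.cv_results().mean_test_score());
    let y_hat = grid_search.predict(&x).unwrap();
    println!("{:?}", y_hat);
}
```

The only difference from the unit test is the constructor: `new` fills in the private `_phantom` field, so callers outside the module do not need `_phantom: Default::default()`.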