feat: Create GridSearchCV #180

Merged: 1 commit, Oct 1, 2022
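At a glance, the new API bundles the search configuration into a `GridSearchCVParameters` value and hands it to `GridSearchCV::fit`. Below is a minimal end-to-end sketch of the intended usage, assuming the crate is consumed as `smartcore` (the crate name is not stated in this diff) and reusing the `LogisticRegression`, `accuracy`, and `KFold` types exercised in the test further down; the data values are purely illustrative.

```rust
use smartcore::linalg::naive::dense_matrix::DenseMatrix;
use smartcore::linear::logistic_regression::{
    LogisticRegression, LogisticRegressionSearchParameters,
};
use smartcore::metrics::accuracy;
use smartcore::model_selection::{GridSearchCV, GridSearchCVParameters, KFold};

fn main() {
    // Illustrative toy data: 8 samples, 4 attributes, binary targets.
    let x = DenseMatrix::from_2d_array(&[
        &[5., 3., 1., 0.],
        &[7., 3., 6., 2.],
        &[4., 2., 1., 0.],
        &[6., 4., 5., 2.],
        &[5., 4., 1., 0.],
        &[7., 4., 6., 2.],
        &[4., 3., 1., 0.],
        &[6., 3., 5., 2.],
    ]);
    let y = vec![0., 1., 0., 1., 0., 1., 0., 1.];

    // Candidate regularization strengths to search over.
    let parameters = LogisticRegressionSearchParameters {
        alpha: vec![0., 1.],
        ..Default::default()
    };

    let grid_search = GridSearchCV::fit(
        &x,
        &y,
        GridSearchCVParameters::new(
            parameters.into_iter(),  // iterator over candidate parameter sets
            LogisticRegression::fit, // the estimator's `fit` function
            accuracy,                // scoring metric
            KFold::default(),        // cross-validation strategy
        ),
    )
    .unwrap();

    // Best parameters, plus predictions from the refitted best estimator.
    let _best = grid_search.best_parameters();
    let _predictions = grid_search.predict(&x).unwrap();
}
```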
215 changes: 161 additions & 54 deletions src/model_selection/hyper_tuning/grid_search.rs
Resulting file after the change:
use crate::{
    api::{Predictor, SupervisedEstimator},
    error::{Failed, FailedError},
    linalg::Matrix,
    math::num::RealNumber,
};

use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};

/// Parameters for GridSearchCV
#[derive(Debug)]
pub struct GridSearchCVParameters<
    T: RealNumber,
    M: Matrix<T>,
    C: Clone,
    I: Iterator<Item = C>,
    E: Predictor<M, M::RowVector>,
    F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
    K: BaseKFold,
    S: Fn(&M::RowVector, &M::RowVector) -> T,
> {
    _phantom: std::marker::PhantomData<(T, M)>,

    parameters_search: I,
    estimator: F,
    score: S,
    cv: K,
}

impl<
        T: RealNumber,
        M: Matrix<T>,
        C: Clone,
        I: Iterator<Item = C>,
        E: Predictor<M, M::RowVector>,
        F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
        K: BaseKFold,
        S: Fn(&M::RowVector, &M::RowVector) -> T,
    > GridSearchCVParameters<T, M, C, I, E, F, K, S>
{
    /// Create new GridSearchCVParameters
    pub fn new(parameters_search: I, estimator: F, score: S, cv: K) -> Self {
        GridSearchCVParameters {
            _phantom: std::marker::PhantomData,
            parameters_search,
            estimator,
            score,
            cv,
        }
    }
}
/// Exhaustive search over specified parameter values for an estimator.
#[derive(Debug)]
pub struct GridSearchCV<T: RealNumber, M: Matrix<T>, C: Clone, E: Predictor<M, M::RowVector>> {
    _phantom: std::marker::PhantomData<(T, M)>,
    predictor: E,
    /// Cross-validation results.
    pub cross_validation_result: CrossValidationResult<T>,
    /// Best parameter set found during the search.
    pub best_parameter: C,
}

impl<T: RealNumber, M: Matrix<T>, E: Predictor<M, M::RowVector>, C: Clone>
    GridSearchCV<T, M, C, E>
{
    /// Search for the best estimator by testing all possible parameter combinations with cross-validation, using the given metric.
    /// * `x` - features, matrix of size _NxM_ where _N_ is the number of samples and _M_ is the number of attributes.
    /// * `y` - target values, should be of size _N_
    /// * `gs_parameters` - a `GridSearchCVParameters` struct
    pub fn fit<
        I: Iterator<Item = C>,
        K: BaseKFold,
        F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
        S: Fn(&M::RowVector, &M::RowVector) -> T,
    >(
        x: &M,
        y: &M::RowVector,
        gs_parameters: GridSearchCVParameters<T, M, C, I, E, F, K, S>,
    ) -> Result<Self, Failed> {
        let mut best_result: Option<CrossValidationResult<T>> = None;
        let mut best_parameters = None;
        let parameters_search = gs_parameters.parameters_search;
        let estimator = gs_parameters.estimator;
        let cv = gs_parameters.cv;
        let score = gs_parameters.score;
        for parameters in parameters_search {
            let result = cross_validate(&estimator, x, y, &parameters, &cv, &score)?;
            if best_result.is_none()
                || result.mean_test_score() > best_result.as_ref().unwrap().mean_test_score()
            {
                best_parameters = Some(parameters);
                best_result = Some(result);
            }
        }

        if let (Some(best_parameter), Some(cross_validation_result)) =
            (best_parameters, best_result)
        {
            // Refit the estimator on the full dataset with the best parameters found.
            let predictor = estimator(x, y, best_parameter.clone())?;
            Ok(Self {
                _phantom: gs_parameters._phantom,
                predictor,
                cross_validation_result,
                best_parameter,
            })
        } else {
            Err(Failed::because(
                FailedError::FindFailed,
                "there were no parameter sets found",
            ))
        }
    }

    /// Return grid search cross validation results
    pub fn cv_results(&self) -> &CrossValidationResult<T> {
        &self.cross_validation_result
    }

    /// Return the best parameters found
    pub fn best_parameters(&self) -> &C {
        &self.best_parameter
    }

    /// Call `predict` on the estimator fitted with the best parameters found
    pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
        self.predictor.predict(x)
    }
}

impl<
        T: RealNumber,
        M: Matrix<T>,
        C: Clone,
        I: Iterator<Item = C>,
        E: Predictor<M, M::RowVector>,
        F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
        K: BaseKFold,
        S: Fn(&M::RowVector, &M::RowVector) -> T,
    > SupervisedEstimator<M, M::RowVector, GridSearchCVParameters<T, M, C, I, E, F, K, S>>
    for GridSearchCV<T, M, C, E>
{
    fn fit(
        x: &M,
        y: &M::RowVector,
        parameters: GridSearchCVParameters<T, M, C, I, E, F, K, S>,
    ) -> Result<Self, Failed> {
        GridSearchCV::fit(x, y, parameters)
    }
}

impl<T: RealNumber, M: Matrix<T>, C: Clone, E: Predictor<M, M::RowVector>>
    Predictor<M, M::RowVector> for GridSearchCV<T, M, C, E>
{
    fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
        // Delegates to the inherent `predict`, which uses the refitted best estimator.
        self.predict(x)
    }
}

#[cfg(test)]
mod tests {

    use crate::{
        linalg::naive::dense_matrix::DenseMatrix,
        linear::logistic_regression::{LogisticRegression, LogisticRegressionSearchParameters},
        metrics::accuracy,
        model_selection::{
            hyper_tuning::grid_search::{self, GridSearchCVParameters},
            KFold,
        },
    };
    use grid_search::GridSearchCV;

    #[test]
    fn test_grid_search() {
        // … (test setup collapsed in the diff view: `x`, `y`, `cv`, and the
        // start of the `parameters` struct literal are not shown) …
            ..Default::default()
        };

        let grid_search = GridSearchCV::fit(
            &x,
            &y,
            GridSearchCVParameters {
                estimator: LogisticRegression::fit,
                score: accuracy,
                cv,
                parameters_search: parameters.into_iter(),
                _phantom: Default::default(),
            },
        )
        .unwrap();

        let best_parameters = grid_search.best_parameters();
        assert!([1.].contains(&best_parameters.alpha));

        let cv_results = grid_search.cv_results();
        assert_eq!(cv_results.mean_test_score(), 0.9);

        let x = DenseMatrix::from_2d_array(&[&[5., 3., 1., 0.]]);
        let result = grid_search.predict(&x).unwrap();
        assert_eq!(result, vec![0.]);
    }
}
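One nuance in the new API: the test above constructs `GridSearchCVParameters` as a struct literal, which only works inside the crate because `_phantom` and the other fields are private. External callers would go through the `new` constructor instead; a minimal sketch, reusing the `parameters` and `cv` names from the test:

```rust
// Sketch of the public construction path; argument order follows
// `GridSearchCVParameters::new(parameters_search, estimator, score, cv)`.
let params = GridSearchCVParameters::new(
    parameters.into_iter(),  // candidate parameter sets
    LogisticRegression::fit, // estimator `fit` function
    accuracy,                // scoring metric
    cv,                      // cross-validation strategy
);
```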
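The two trait impls are what make the wrapper composable: `SupervisedEstimator` lets `GridSearchCV::fit` be called through the same generic interface as any other model, and `Predictor` lets the fitted search object stand in wherever predictions are expected. A sketch, assuming `DenseMatrix<f64>`'s row vector type is `Vec<f64>` (consistent with the test's `vec![0.]`) and that the `api` module is public when the crate is consumed as `smartcore`:

```rust
use smartcore::api::Predictor;
use smartcore::linalg::naive::dense_matrix::DenseMatrix;

// Any fitted model, whether a bare LogisticRegression or a whole
// GridSearchCV, can be passed here through the trait bound.
fn report<P: Predictor<DenseMatrix<f64>, Vec<f64>>>(model: &P, x: &DenseMatrix<f64>) {
    if let Ok(predictions) = model.predict(x) {
        println!("{:?}", predictions);
    }
}
```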
2 changes: 1 addition & 1 deletion src/model_selection/hyper_tuning/mod.rs
Resulting file after the change:
mod grid_search;
pub use grid_search::{GridSearchCV, GridSearchCVParameters};
2 changes: 1 addition & 1 deletion src/model_selection/mod.rs
Resulting fragment after the change:

use rand::seq::SliceRandom;
pub(crate) mod hyper_tuning;
pub(crate) mod kfold;

pub use hyper_tuning::{GridSearchCV, GridSearchCVParameters};
pub use kfold::{KFold, KFoldIter};

/// An interface for the K-Folds cross-validator
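With both re-exports in place, downstream code reaches the new types from `model_selection` directly. Assuming the crate is consumed under the name `smartcore` (not stated in this diff):

```rust
use smartcore::model_selection::{GridSearchCV, GridSearchCVParameters, KFold};
```

Note that the old `grid_search` free function and `GridSearchResult` are no longer exported, so this is a breaking change for existing callers.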