@@ -1,80 +1,175 @@
 use crate::{
-    api::Predictor,
+    api::{Predictor, SupervisedEstimator},
     error::{Failed, FailedError},
     linalg::Matrix,
     math::num::RealNumber,
 };
 
 use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};
 
-/// grid search results.
-#[derive(Clone, Debug)]
-pub struct GridSearchResult<T: RealNumber, I: Clone> {
-    /// Vector with test scores on each cv split
-    pub cross_validation_result: CrossValidationResult<T>,
-    /// Vector with training scores on each cv split
-    pub parameters: I,
-}
-
-/// Search for the best estimator by testing all possible combinations with cross-validation using given metric.
-/// * `fit_estimator` - a `fit` function of an estimator
-/// * `x` - features, matrix of size _NxM_ where _N_ is number of samples and _M_ is number of attributes.
-/// * `y` - target values, should be of size _N_
-/// * `parameter_search` - an iterator for parameters that will be tested.
-/// * `cv` - the cross-validation splitting strategy, should be an instance of [`BaseKFold`](./trait.BaseKFold.html)
-/// * `score` - a metric to use for evaluation, see [metrics](../metrics/index.html)
-pub fn grid_search<T, M, I, E, K, F, S>(
-    fit_estimator: F,
-    x: &M,
-    y: &M::RowVector,
-    parameter_search: I,
-    cv: K,
-    score: S,
-) -> Result<GridSearchResult<T, I::Item>, Failed>
-where
+/// Parameters for GridSearchCV
+#[derive(Debug)]
+pub struct GridSearchCVParameters<
     T: RealNumber,
     M: Matrix<T>,
-    I: Iterator,
-    I::Item: Clone,
+    C: Clone,
+    I: Iterator<Item = C>,
     E: Predictor<M, M::RowVector>,
+    F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
     K: BaseKFold,
-    F: Fn(&M, &M::RowVector, I::Item) -> Result<E, Failed>,
     S: Fn(&M::RowVector, &M::RowVector) -> T,
+> {
+    _phantom: std::marker::PhantomData<(T, M)>,
+
+    parameters_search: I,
+    estimator: F,
+    score: S,
+    cv: K,
+}
+
+impl<
+        T: RealNumber,
+        M: Matrix<T>,
+        C: Clone,
+        I: Iterator<Item = C>,
+        E: Predictor<M, M::RowVector>,
+        F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
+        K: BaseKFold,
+        S: Fn(&M::RowVector, &M::RowVector) -> T,
+    > GridSearchCVParameters<T, M, C, I, E, F, K, S>
 {
-    let mut best_result: Option<CrossValidationResult<T>> = None;
-    let mut best_parameters = None;
+    /// Creates a new `GridSearchCVParameters` instance.
+    pub fn new(parameters_search: I, estimator: F, score: S, cv: K) -> Self {
+        GridSearchCVParameters {
+            _phantom: std::marker::PhantomData,
+            parameters_search,
+            estimator,
+            score,
+            cv,
+        }
+    }
+}
+/// Exhaustive search over specified parameter values for an estimator.
+#[derive(Debug)]
+pub struct GridSearchCV<T: RealNumber, M: Matrix<T>, C: Clone, E: Predictor<M, M::RowVector>> {
+    _phantom: std::marker::PhantomData<(T, M)>,
+    predictor: E,
+    /// Cross-validation results for the best parameter set.
+    pub cross_validation_result: CrossValidationResult<T>,
+    /// Best parameter set found by the search.
+    pub best_parameter: C,
+}
+
+impl<T: RealNumber, M: Matrix<T>, E: Predictor<M, M::RowVector>, C: Clone>
+    GridSearchCV<T, M, C, E>
+{
+    /// Search for the best estimator by testing all possible parameter combinations with cross-validation, using the given metric.
+    /// * `x` - features, matrix of size _NxM_ where _N_ is the number of samples and _M_ is the number of attributes.
+    /// * `y` - target values, should be of size _N_
+    /// * `gs_parameters` - a `GridSearchCVParameters` struct
+    pub fn fit<
+        I: Iterator<Item = C>,
+        K: BaseKFold,
+        F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
+        S: Fn(&M::RowVector, &M::RowVector) -> T,
+    >(
+        x: &M,
+        y: &M::RowVector,
+        gs_parameters: GridSearchCVParameters<T, M, C, I, E, F, K, S>,
+    ) -> Result<Self, Failed> {
+        let mut best_result: Option<CrossValidationResult<T>> = None;
+        let mut best_parameters = None;
+        let parameters_search = gs_parameters.parameters_search;
+        let estimator = gs_parameters.estimator;
+        let cv = gs_parameters.cv;
+        let score = gs_parameters.score;
 
-    for parameters in parameter_search {
-        let result = cross_validate(&fit_estimator, x, y, &parameters, &cv, &score)?;
-        if best_result.is_none()
-            || result.mean_test_score() > best_result.as_ref().unwrap().mean_test_score()
+        for parameters in parameters_search {
+            let result = cross_validate(&estimator, x, y, &parameters, &cv, &score)?;
+            if best_result.is_none()
+                || result.mean_test_score() > best_result.as_ref().unwrap().mean_test_score()
+            {
+                best_parameters = Some(parameters);
+                best_result = Some(result);
+            }
+        }
+
+        if let (Some(best_parameter), Some(cross_validation_result)) =
+            (best_parameters, best_result)
         {
-            best_parameters = Some(parameters);
-            best_result = Some(result);
+            let predictor = estimator(x, y, best_parameter.clone())?;
+            Ok(Self {
+                _phantom: gs_parameters._phantom,
+                predictor,
+                cross_validation_result,
+                best_parameter,
+            })
+        } else {
+            Err(Failed::because(
+                FailedError::FindFailed,
+                "there were no parameter sets found",
+            ))
         }
     }
 
-    if let (Some(parameters), Some(cross_validation_result)) = (best_parameters, best_result) {
-        Ok(GridSearchResult {
-            cross_validation_result,
-            parameters,
-        })
-    } else {
-        Err(Failed::because(
-            FailedError::FindFailed,
-            "there were no parameter sets found",
-        ))
+    /// Returns the cross-validation results of the grid search.
+    pub fn cv_results(&self) -> &CrossValidationResult<T> {
+        &self.cross_validation_result
+    }
+
+    /// Returns the best parameters found by the search.
+    pub fn best_parameters(&self) -> &C {
+        &self.best_parameter
+    }
+
+    /// Calls `predict` on the estimator fitted with the best parameters found.
+    pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
+        self.predictor.predict(x)
+    }
+}
+
+impl<
+        T: RealNumber,
+        M: Matrix<T>,
+        C: Clone,
+        I: Iterator<Item = C>,
+        E: Predictor<M, M::RowVector>,
+        F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
+        K: BaseKFold,
+        S: Fn(&M::RowVector, &M::RowVector) -> T,
+    > SupervisedEstimator<M, M::RowVector, GridSearchCVParameters<T, M, C, I, E, F, K, S>>
+    for GridSearchCV<T, M, C, E>
+{
+    fn fit(
+        x: &M,
+        y: &M::RowVector,
+        parameters: GridSearchCVParameters<T, M, C, I, E, F, K, S>,
+    ) -> Result<Self, Failed> {
+        GridSearchCV::fit(x, y, parameters)
+    }
+}
+
+impl<T: RealNumber, M: Matrix<T>, C: Clone, E: Predictor<M, M::RowVector>>
+    Predictor<M, M::RowVector> for GridSearchCV<T, M, C, E>
+{
+    fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
+        self.predict(x)
     }
 }
 
 #[cfg(test)]
 mod tests {
+
     use crate::{
         linalg::naive::dense_matrix::DenseMatrix,
         linear::logistic_regression::{LogisticRegression, LogisticRegressionSearchParameters},
         metrics::accuracy,
-        model_selection::{hyper_tuning::grid_search, KFold},
+        model_selection::{
+            hyper_tuning::grid_search::{self, GridSearchCVParameters},
+            KFold,
+        },
     };
+    use grid_search::GridSearchCV;
 
     #[test]
     fn test_grid_search() {
@@ -114,16 +209,28 @@ mod tests {
             ..Default::default()
         };
 
-        let results = grid_search(
-            LogisticRegression::fit,
+        let grid_search = GridSearchCV::fit(
             &x,
             &y,
-            parameters.into_iter(),
-            cv,
-            &accuracy,
+            GridSearchCVParameters {
+                estimator: LogisticRegression::fit,
+                score: accuracy,
+                cv,
+                parameters_search: parameters.into_iter(),
+                _phantom: Default::default(),
+            },
         )
         .unwrap();
+        let best_parameters = grid_search.best_parameters();
+
+        assert!([1.].contains(&best_parameters.alpha));
+
+        let cv_results = grid_search.cv_results();
+
+        assert_eq!(cv_results.mean_test_score(), 0.9);
 
-        assert!([0., 1.].contains(&results.parameters.alpha));
+        let x = DenseMatrix::from_2d_array(&[&[5., 3., 1., 0.]]);
+        let result = grid_search.predict(&x).unwrap();
+        assert_eq!(result, vec![0.]);
     }
 }
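
For context, a minimal usage sketch of the API this commit introduces (not part of the diff). It mirrors the test above and assumes the surrounding crate is smartcore, that `hyper_tuning::grid_search` is publicly reachable under `model_selection`, and that `LogisticRegressionSearchParameters::into_iter()` yields one `LogisticRegressionParameters` candidate per grid point, as the test implies. The data is made up.

```rust
use smartcore::linalg::naive::dense_matrix::DenseMatrix;
use smartcore::linear::logistic_regression::{
    LogisticRegression, LogisticRegressionSearchParameters,
};
use smartcore::metrics::accuracy;
use smartcore::model_selection::{
    hyper_tuning::grid_search::{GridSearchCV, GridSearchCVParameters},
    KFold,
};

fn main() {
    // Toy two-class data; any NxM matrix whose N matches y's length works.
    let x = DenseMatrix::from_2d_array(&[
        &[1., 2.],
        &[2., 1.],
        &[2., 3.],
        &[3., 2.],
        &[6., 5.],
        &[7., 6.],
        &[6., 7.],
        &[8., 6.],
    ]);
    let y = vec![0., 0., 0., 0., 1., 1., 1., 1.];

    // The grid to search: each alpha in the list becomes one candidate
    // parameter set via `into_iter()`.
    let grid = LogisticRegressionSearchParameters {
        alpha: vec![0., 1.],
        ..Default::default()
    };

    // `new` (added in this commit) hides the `_phantom` field that the
    // test constructs by hand.
    let params = GridSearchCVParameters::new(
        grid.into_iter(),
        LogisticRegression::fit,
        accuracy,
        KFold::default(),
    );

    let gs = GridSearchCV::fit(&x, &y, params).unwrap();
    println!("best alpha = {}", gs.best_parameters().alpha);
    println!("mean CV accuracy = {}", gs.cv_results().mean_test_score());
    let predictions = gs.predict(&x).unwrap();
    println!("{:?}", predictions);
}
```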
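
The `SupervisedEstimator` and `Predictor` impls at the end of the diff are what let a fitted `GridSearchCV` stand in for a plain model in generic code. A sketch of that idea, relying only on the `fit`/`predict` signatures visible above; `train_and_predict` is a hypothetical helper, not part of the crate:

```rust
use smartcore::api::{Predictor, SupervisedEstimator};
use smartcore::error::Failed;

// Hypothetical helper: fits any SupervisedEstimator (a bare model, or a
// GridSearchCV wrapper that tunes and refits in one step) and then
// predicts on the training matrix.
fn train_and_predict<X, Y, P, E>(x: &X, y: &Y, parameters: P) -> Result<Y, Failed>
where
    E: SupervisedEstimator<X, Y, P> + Predictor<X, Y>,
{
    let model = E::fit(x, y, parameters)?;
    model.predict(x)
}
```

With `E = GridSearchCV<...>` and a `GridSearchCVParameters` value as `parameters`, the same call site performs a full grid search before predicting; the caller selects `E` via the turbofish.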