Skip to content

Commit 473cdfc

Browse files
committed
refactor: Try to follow similar pattern to other APIs (#180)
Co-authored-by: Luis Moreno <[email protected]>
1 parent ad2e6c2 commit 473cdfc

File tree

3 files changed

+163
-56
lines changed

3 files changed

+163
-56
lines changed
+161-54
Original file line numberDiff line numberDiff line change
@@ -1,80 +1,175 @@
11
use crate::{
2-
api::Predictor,
2+
api::{Predictor, SupervisedEstimator},
33
error::{Failed, FailedError},
44
linalg::Matrix,
55
math::num::RealNumber,
66
};
77

88
use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};
99

10-
/// grid search results.
11-
#[derive(Clone, Debug)]
12-
pub struct GridSearchResult<T: RealNumber, I: Clone> {
13-
/// Vector with test scores on each cv split
14-
pub cross_validation_result: CrossValidationResult<T>,
15-
/// Vector with training scores on each cv split
16-
pub parameters: I,
17-
}
18-
19-
/// Search for the best estimator by testing all possible combinations with cross-validation using given metric.
20-
/// * `fit_estimator` - a `fit` function of an estimator
21-
/// * `x` - features, matrix of size _NxM_ where _N_ is number of samples and _M_ is number of attributes.
22-
/// * `y` - target values, should be of size _N_
23-
/// * `parameter_search` - an iterator for parameters that will be tested.
24-
/// * `cv` - the cross-validation splitting strategy, should be an instance of [`BaseKFold`](./trait.BaseKFold.html)
25-
/// * `score` - a metric to use for evaluation, see [metrics](../metrics/index.html)
26-
pub fn grid_search<T, M, I, E, K, F, S>(
27-
fit_estimator: F,
28-
x: &M,
29-
y: &M::RowVector,
30-
parameter_search: I,
31-
cv: K,
32-
score: S,
33-
) -> Result<GridSearchResult<T, I::Item>, Failed>
34-
where
10+
/// Parameters for GridSearchCV
11+
#[derive(Debug)]
12+
pub struct GridSearchCVParameters<
3513
T: RealNumber,
3614
M: Matrix<T>,
37-
I: Iterator,
38-
I::Item: Clone,
15+
C: Clone,
16+
I: Iterator<Item = C>,
3917
E: Predictor<M, M::RowVector>,
18+
F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
4019
K: BaseKFold,
41-
F: Fn(&M, &M::RowVector, I::Item) -> Result<E, Failed>,
4220
S: Fn(&M::RowVector, &M::RowVector) -> T,
21+
> {
22+
_phantom: std::marker::PhantomData<(T, M)>,
23+
24+
parameters_search: I,
25+
estimator: F,
26+
score: S,
27+
cv: K,
28+
}
29+
30+
impl<
31+
T: RealNumber,
32+
M: Matrix<T>,
33+
C: Clone,
34+
I: Iterator<Item = C>,
35+
E: Predictor<M, M::RowVector>,
36+
F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
37+
K: BaseKFold,
38+
S: Fn(&M::RowVector, &M::RowVector) -> T,
39+
> GridSearchCVParameters<T, M, C, I, E, F, K, S>
4340
{
44-
let mut best_result: Option<CrossValidationResult<T>> = None;
45-
let mut best_parameters = None;
41+
/// Create new GridSearchCVParameters
42+
pub fn new(parameters_search: I, estimator: F, score: S, cv: K) -> Self {
43+
GridSearchCVParameters {
44+
_phantom: std::marker::PhantomData,
45+
parameters_search,
46+
estimator,
47+
score,
48+
cv,
49+
}
50+
}
51+
}
52+
/// Exhaustive search over specified parameter values for an estimator.
53+
#[derive(Debug)]
54+
pub struct GridSearchCV<T: RealNumber, M: Matrix<T>, C: Clone, E: Predictor<M, M::RowVector>> {
55+
_phantom: std::marker::PhantomData<(T, M)>,
56+
predictor: E,
57+
/// Cross-validation results obtained with the best parameter set.
58+
pub cross_validation_result: CrossValidationResult<T>,
59+
/// Parameter set that achieved the highest mean test score.
60+
pub best_parameter: C,
61+
}
62+
63+
impl<T: RealNumber, M: Matrix<T>, E: Predictor<M, M::RowVector>, C: Clone>
64+
GridSearchCV<T, M, C, E>
65+
{
66+
/// Search for the best estimator by testing all possible combinations with cross-validation using given metric.
67+
/// * `x` - features, matrix of size _NxM_ where _N_ is number of samples and _M_ is number of attributes.
68+
/// * `y` - target values, should be of size _N_
69+
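/// * `gs_parameters` - search configuration (see `GridSearchCVParameters`): the iterator of parameter sets to try, the estimator `fit` function, the scoring metric and the cross-validation splitter.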
70+
pub fn fit<
71+
I: Iterator<Item = C>,
72+
K: BaseKFold,
73+
F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
74+
S: Fn(&M::RowVector, &M::RowVector) -> T,
75+
>(
76+
x: &M,
77+
y: &M::RowVector,
78+
gs_parameters: GridSearchCVParameters<T, M, C, I, E, F, K, S>,
79+
) -> Result<Self, Failed> {
80+
let mut best_result: Option<CrossValidationResult<T>> = None;
81+
let mut best_parameters = None;
82+
let parameters_search = gs_parameters.parameters_search;
83+
let estimator = gs_parameters.estimator;
84+
let cv = gs_parameters.cv;
85+
let score = gs_parameters.score;
4686

47-
for parameters in parameter_search {
48-
let result = cross_validate(&fit_estimator, x, y, &parameters, &cv, &score)?;
49-
if best_result.is_none()
50-
|| result.mean_test_score() > best_result.as_ref().unwrap().mean_test_score()
87+
for parameters in parameters_search {
88+
let result = cross_validate(&estimator, x, y, &parameters, &cv, &score)?;
89+
if best_result.is_none()
90+
|| result.mean_test_score() > best_result.as_ref().unwrap().mean_test_score()
91+
{
92+
best_parameters = Some(parameters);
93+
best_result = Some(result);
94+
}
95+
}
96+
97+
if let (Some(best_parameter), Some(cross_validation_result)) =
98+
(best_parameters, best_result)
5199
{
52-
best_parameters = Some(parameters);
53-
best_result = Some(result);
100+
let predictor = estimator(x, y, best_parameter.clone())?;
101+
Ok(Self {
102+
_phantom: gs_parameters._phantom,
103+
predictor,
104+
cross_validation_result,
105+
best_parameter,
106+
})
107+
} else {
108+
Err(Failed::because(
109+
FailedError::FindFailed,
110+
"there were no parameter sets found",
111+
))
54112
}
55113
}
56114

57-
if let (Some(parameters), Some(cross_validation_result)) = (best_parameters, best_result) {
58-
Ok(GridSearchResult {
59-
cross_validation_result,
60-
parameters,
61-
})
62-
} else {
63-
Err(Failed::because(
64-
FailedError::FindFailed,
65-
"there were no parameter sets found",
66-
))
115+
/// Return grid search cross validation results
116+
pub fn cv_results(&self) -> &CrossValidationResult<T> {
117+
&self.cross_validation_result
118+
}
119+
120+
/// Return best parameters found
121+
pub fn best_parameters(&self) -> &C {
122+
&self.best_parameter
123+
}
124+
125+
/// Call `predict` on the estimator that was fitted on the `x` and `y` given to `fit`, using the best parameters found
126+
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
127+
self.predictor.predict(x)
128+
}
129+
}
130+
131+
impl<
132+
T: RealNumber,
133+
M: Matrix<T>,
134+
C: Clone,
135+
I: Iterator<Item = C>,
136+
E: Predictor<M, M::RowVector>,
137+
F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
138+
K: BaseKFold,
139+
S: Fn(&M::RowVector, &M::RowVector) -> T,
140+
> SupervisedEstimator<M, M::RowVector, GridSearchCVParameters<T, M, C, I, E, F, K, S>>
141+
for GridSearchCV<T, M, C, E>
142+
{
143+
fn fit(
144+
x: &M,
145+
y: &M::RowVector,
146+
parameters: GridSearchCVParameters<T, M, C, I, E, F, K, S>,
147+
) -> Result<Self, Failed> {
148+
GridSearchCV::fit(x, y, parameters)
149+
}
150+
}
151+
152+
impl<T: RealNumber, M: Matrix<T>, C: Clone, E: Predictor<M, M::RowVector>>
153+
Predictor<M, M::RowVector> for GridSearchCV<T, M, C, E>
154+
{
155+
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
156+
self.predict(x)
67157
}
68158
}
69159

70160
#[cfg(test)]
71161
mod tests {
162+
72163
use crate::{
73164
linalg::naive::dense_matrix::DenseMatrix,
74165
linear::logistic_regression::{LogisticRegression, LogisticRegressionSearchParameters},
75166
metrics::accuracy,
76-
model_selection::{hyper_tuning::grid_search, KFold},
167+
model_selection::{
168+
hyper_tuning::grid_search::{self, GridSearchCVParameters},
169+
KFold,
170+
},
77171
};
172+
use grid_search::GridSearchCV;
78173

79174
#[test]
80175
fn test_grid_search() {
@@ -114,16 +209,28 @@ mod tests {
114209
..Default::default()
115210
};
116211

117-
let results = grid_search(
118-
LogisticRegression::fit,
212+
let grid_search = GridSearchCV::fit(
119213
&x,
120214
&y,
121-
parameters.into_iter(),
122-
cv,
123-
&accuracy,
215+
GridSearchCVParameters {
216+
estimator: LogisticRegression::fit,
217+
score: accuracy,
218+
cv,
219+
parameters_search: parameters.into_iter(),
220+
_phantom: Default::default(),
221+
},
124222
)
125223
.unwrap();
224+
let best_parameters = grid_search.best_parameters();
225+
226+
assert!([1.].contains(&best_parameters.alpha));
227+
228+
let cv_results = grid_search.cv_results();
229+
230+
assert_eq!(cv_results.mean_test_score(), 0.9);
126231

127-
assert!([0., 1.].contains(&results.parameters.alpha));
232+
let x = DenseMatrix::from_2d_array(&[&[5., 3., 1., 0.]]);
233+
let result = grid_search.predict(&x).unwrap();
234+
assert_eq!(result, vec![0.]);
128235
}
129236
}
+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
mod grid_search;
2-
pub use grid_search::{grid_search, GridSearchResult};
2+
pub use grid_search::{GridSearchCV, GridSearchCVParameters};

src/model_selection/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ use rand::seq::SliceRandom;
113113
pub(crate) mod hyper_tuning;
114114
pub(crate) mod kfold;
115115

116-
pub use hyper_tuning::{grid_search, GridSearchResult};
116+
pub use hyper_tuning::{GridSearchCV, GridSearchCVParameters};
117117
pub use kfold::{KFold, KFoldIter};
118118

119119
/// An interface for the K-Folds cross-validator

0 commit comments

Comments
 (0)