16
16
from scipy import sparse
17
17
from sklearn import __version__ as sklearn_version
18
18
from sklearn .cross_validation import cross_val_score as sklearn_cv_score
19
+ from sklearn .grid_search import GridSearchCV as sklearn_grid_search
19
20
from sklearn .datasets import load_iris
20
21
from sklearn .pipeline import Pipeline
21
22
from sklearn .svm import SVC
@@ -113,7 +114,7 @@ def transform(self, X):
113
114
114
115
class Adder (BaseEstimator , TransformerMixin ):
115
116
116
- def __init__ (self , num_to_add ):
117
+ def __init__ (self , num_to_add = 0 ):
117
118
self .num_to_add = num_to_add
118
119
119
120
def fit (self , X , y = None ):
@@ -125,7 +126,7 @@ def transform(self, X):
125
126
126
127
class Divider (BaseEstimator , TransformerMixin ):
127
128
128
- def __init__ (self , denominator ):
129
+ def __init__ (self , denominator = 1 ):
129
130
self .denominator = denominator
130
131
131
132
def fit (self , X , y = None ):
@@ -1068,3 +1069,28 @@ def test_setting_parameters_to_a_list_of_transformers():
1068
1069
1069
1070
assert adder .num_to_add == 0
1070
1071
assert divider .denominator == 1
1072
+
1073
+
1074
+ def test_compliant_with_grid_search (iris_dataframe ):
1075
+ pipeline = Pipeline ([
1076
+ ('mapper' , DataFrameMapper ([
1077
+ (['petal length (cm)' ], StandardScaler ()),
1078
+ (['petal width (cm)' ], StandardScaler ()),
1079
+ (['sepal length (cm)' ], StandardScaler ()),
1080
+ (['sepal width (cm)' ], StandardScaler ()),
1081
+ ])),
1082
+ ('classifier' , SVC (kernel = 'linear' ))
1083
+ ])
1084
+ param_grid = {
1085
+ 'mapper__petal length (cm)__with_mean' : [True , False ],
1086
+ 'mapper__petal width (cm)__with_mean' : [True , False ],
1087
+ 'mapper__sepal length (cm)__with_mean' : [True , False ],
1088
+ 'mapper__sepal width (cm)__with_mean' : [True , False ]
1089
+ }
1090
+ data = iris_dataframe .drop ("species" , axis = 1 )
1091
+ labels = iris_dataframe ["species" ]
1092
+
1093
+ grid_search = sklearn_grid_search (pipeline , param_grid = param_grid )
1094
+ grid_search .fit (data , labels )
1095
+
1096
+ assert len (grid_search .grid_scores_ ) == 2 ** len (param_grid )
0 commit comments