Skip to content

Commit 89e6d11

Browse files
committed
Testing pipeline with GridSearchCV
1 parent ae88b71 commit 89e6d11

File tree

2 files changed

+39
-5
lines changed

2 files changed

+39
-5
lines changed

sklearn_pandas/dataframe_mapper.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ def _build_feature(columns, transformers, options={}):
3939
return (columns, _build_transformer(transformers), options)
4040

4141

42+
def _build_feature_name(values):
43+
if isinstance(values, list):
44+
values = '-'.join([str(value) for value in values])
45+
return values
46+
47+
4248
def _get_feature_names(estimator):
4349
"""
4450
Attempt to extract feature names based on a given estimator
@@ -420,7 +426,11 @@ def get_params(self, deep=True):
420426
return out
421427

422428
def set_params(self, **params):
423-
features = dict(self.features)
429+
features = {}
430+
for column_names, transformers in self.features:
431+
key = _build_feature_name(column_names)
432+
features[key] = transformers
433+
424434
assignment = defaultdict(dict)
425435

426436
for key, value in params.items():
@@ -442,5 +452,3 @@ def set_params(self, **params):
442452

443453
for instance in transformers_instances:
444454
instance.set_params(**assignment[id(instance)])
445-
446-

tests/test_dataframe_mapper.py

+28-2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from scipy import sparse
1717
from sklearn import __version__ as sklearn_version
1818
from sklearn.cross_validation import cross_val_score as sklearn_cv_score
19+
from sklearn.grid_search import GridSearchCV as sklearn_grid_search
1920
from sklearn.datasets import load_iris
2021
from sklearn.pipeline import Pipeline
2122
from sklearn.svm import SVC
@@ -113,7 +114,7 @@ def transform(self, X):
113114

114115
class Adder(BaseEstimator, TransformerMixin):
115116

116-
def __init__(self, num_to_add):
117+
def __init__(self, num_to_add=0):
117118
self.num_to_add = num_to_add
118119

119120
def fit(self, X, y=None):
@@ -125,7 +126,7 @@ def transform(self, X):
125126

126127
class Divider(BaseEstimator, TransformerMixin):
127128

128-
def __init__(self, denominator):
129+
def __init__(self, denominator=1):
129130
self.denominator = denominator
130131

131132
def fit(self, X, y=None):
@@ -1068,3 +1069,28 @@ def test_setting_parameters_to_a_list_of_transformers():
10681069

10691070
assert adder.num_to_add == 0
10701071
assert divider.denominator == 1
1072+
1073+
1074+
def test_compliant_with_grid_search(iris_dataframe):
1075+
pipeline = Pipeline([
1076+
('mapper', DataFrameMapper([
1077+
(['petal length (cm)'], StandardScaler()),
1078+
(['petal width (cm)'], StandardScaler()),
1079+
(['sepal length (cm)'], StandardScaler()),
1080+
(['sepal width (cm)'], StandardScaler()),
1081+
])),
1082+
('classifier', SVC(kernel='linear'))
1083+
])
1084+
param_grid = {
1085+
'mapper__petal length (cm)__with_mean': [True, False],
1086+
'mapper__petal width (cm)__with_mean': [True, False],
1087+
'mapper__sepal length (cm)__with_mean': [True, False],
1088+
'mapper__sepal width (cm)__with_mean': [True, False]
1089+
}
1090+
data = iris_dataframe.drop("species", axis=1)
1091+
labels = iris_dataframe["species"]
1092+
1093+
grid_search = sklearn_grid_search(pipeline, param_grid=param_grid)
1094+
grid_search.fit(data, labels)
1095+
1096+
assert len(grid_search.grid_scores_) == 2**len(param_grid)

0 commit comments

Comments
 (0)