Skip to content

Commit 0faead7

Browse files
committed
Fit_transform all transformers iteratively when transformers is a list.
1 parent 3328aa5 commit 0faead7

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

sklearn_pandas/__init__.py

+10 -6
Original file line number | Diff line number | Diff line change
@@ -122,8 +122,12 @@ def fit(self, X, y=None):
122122
for columns, transformers in self.features:
123123
if transformers is not None:
124124
if isinstance(transformers, list):
125-
for transformer in transformers:
126-
transformer.fit(self._get_col_subset(X, columns))
125+
# first fit_transform all transformers except the last one
126+
Xt = self._get_col_subset(X, columns)
127+
for transformer in transformers[:-1]:
128+
Xt = transformer.fit_transform(Xt)
129+
# then fit the last one without transformation
130+
transformers[-1].fit(Xt)
127131
else:
128132
transformers.fit(self._get_col_subset(X, columns))
129133
return self
@@ -139,14 +143,14 @@ def transform(self, X):
139143
# columns could be a string or list of
140144
# strings; we don't care because pandas
141145
# will handle either.
142-
final_feature = self._get_col_subset(X, columns)
146+
Xt = self._get_col_subset(X, columns)
143147
if transformers is not None:
144148
if isinstance(transformers, list):
145149
for transformer in transformers:
146-
final_feature = transformer.transform(final_feature)
150+
Xt = transformer.transform(Xt)
147151
else:
148-
final_feature = transformers.transform(final_feature)
149-
extracted.append(_handle_feature(final_feature))
152+
Xt = transformers.transform(Xt)
153+
extracted.append(_handle_feature(Xt))
150154

151155
# combine the feature outputs into one array.
152156
# at this point we lose track of which features

tests/test_dataframe_mapper.py

+18 -1
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
from sklearn.pipeline import Pipeline
77
from sklearn.svm import SVC
88
from sklearn.feature_extraction.text import CountVectorizer
9+
from sklearn.preprocessing import Imputer, StandardScaler
910
import numpy as np
1011

1112
from sklearn_pandas import (
@@ -54,11 +55,27 @@ def test_with_iris_dataframe(iris_dataframe):
5455
def test_with_car_dataframe(cars_dataframe):
5556
pipeline = Pipeline([
5657
("preprocess", DataFrameMapper([
57-
("description", [PassthroughTransformer(), CountVectorizer()]),
58+
("description", CountVectorizer()),
5859
])),
5960
("classify", SVC(kernel='linear'))
6061
])
6162
data = cars_dataframe.drop("model", axis=1)
6263
labels = cars_dataframe["model"]
6364
scores = cross_val_score(pipeline, data, labels)
6465
assert scores.mean() > 0.30
66+
67+
68+
def test_list_transformers():
69+
dataframe = pd.DataFrame({"a": [1, np.nan, 3], "b": [1, 5, 7]})
70+
71+
mapper = DataFrameMapper([
72+
(["a"], [Imputer(), StandardScaler()]),
73+
(["b"], StandardScaler()),
74+
])
75+
dmatrix = mapper.fit_transform(dataframe)
76+
77+
assert pd.isnull(dmatrix).sum() == 0 # no null values
78+
79+
# all features have mean 0 and std deviation 1 (standardized)
80+
assert (abs(dmatrix.mean(axis=0) - 0) <= 1e-6).all()
81+
assert (abs(dmatrix.std(axis=0) - 1) <= 1e-6).all()

0 commit comments

Comments
 (0)