Skip to content

Commit c9db2d6

Browse files
authored
Add new complex dataframe transform test for 2d cell data (scikit-learn-contrib#254)
* add new test for 2d data in dataframe cell * remove dbg statement * fixes for flake8 * improve test
1 parent fa3b726 commit c9db2d6

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

tests/test_dataframe_mapper.py

+45
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,24 @@ def transform(self, X):
8989
return X - self.min
9090

9191

92+
class MockImageTransformer(BaseEstimator, TransformerMixin):
93+
"""
94+
Example transformer that takes the max of a 2d vector
95+
then scales the result.
96+
"""
97+
def __init__(self, multiplier=10.0):
98+
self.multiplier = multiplier
99+
100+
def fit(self, X, y=None):
101+
return self
102+
103+
def transform(self, X):
104+
assert isinstance(X, pd.DataFrame)
105+
for col in X.columns:
106+
X[col] = X[col].map(lambda img: np.max(img))
107+
return X * self.multiplier
108+
109+
92110
@pytest.fixture
93111
def simple_dataframe():
94112
return pd.DataFrame({'a': [1, 2, 3]})
@@ -101,6 +119,15 @@ def complex_dataframe():
101119
'feat2': [1, 2, 3, 2, 3, 4]})
102120

103121

122+
@pytest.fixture
123+
def complex_object_dataframe():
124+
return pd.DataFrame({'target': ['a', 'a', 'b', 'b', 'c', 'c'],
125+
'feat1': [1, 2, 3, 4, 5, 6],
126+
'feat2': [1, 2, 3, 2, 3, 4],
127+
'img2d': [1*np.eye(2), 2*np.eye(2), 3*np.eye(2),
128+
4*np.eye(2), 5*np.eye(2), 6*np.eye(2)]})
129+
130+
104131
@pytest.fixture
105132
def multiindex_dataframe():
106133
"""Example MultiIndex DataFrame, taken from pandas documentation
@@ -264,6 +291,24 @@ def test_complex_df(complex_dataframe):
264291
assert len(transformed[c]) == len(df[c])
265292

266293

294+
def test_complex_object_df(complex_object_dataframe):
295+
"""
296+
Get a dataframe from a complex dataframe with 2d features
297+
"""
298+
df = complex_object_dataframe
299+
img_scale = 10
300+
mapper = DataFrameMapper(
301+
[('target', None), ('feat1', None),
302+
(make_column_selector('feat2'), StandardScaler()),
303+
(make_column_selector('img2d'), MockImageTransformer(img_scale))],
304+
df_out=True, input_df=True)
305+
transformed = mapper.fit_transform(df)
306+
assert len(transformed) == len(complex_object_dataframe)
307+
assert np.isclose(
308+
np.sum(transformed['img2d']),
309+
np.max(np.sum(df['img2d'])) * img_scale, atol=1e-12)
310+
311+
267312
def test_numeric_column_names(complex_dataframe):
268313
"""
269314
Get a dataframe from a complex mapped dataframe with numeric column names

0 commit comments

Comments
 (0)