Skip to content

Commit 7fdc39a

Browse files
kristofvedukebody
authored andcommitted
Fix column naming for DataFrames with MultiIndex columns (#166)
1 parent 757cc33 commit 7fdc39a

File tree

3 files changed

+54
-2
lines changed

3 files changed

+54
-2
lines changed

.gitignore

+3-1
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,6 @@
33
.tox/
44
build/
55
dist/
6-
.cache/
6+
.cache/
7+
.idea/
8+
.pytest_cache/

sklearn_pandas/dataframe_mapper.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def get_names(self, columns, transformer, x, alias=None):
233233
if alias is not None:
234234
name = alias
235235
elif isinstance(columns, list):
236-
name = '_'.join(columns)
236+
name = '_'.join(map(str, columns))
237237
else:
238238
name = columns
239239
num_cols = x.shape[1] if len(x.shape) > 1 else 1

tests/test_dataframe_mapper.py

+50
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,29 @@ def complex_dataframe():
108108
'feat2': [1, 2, 3, 2, 3, 4]})
109109

110110

111+
@pytest.fixture
112+
def multiindex_dataframe():
113+
"""Example MultiIndex DataFrame, taken from pandas documentation
114+
"""
115+
iterables = [['bar', 'baz', 'foo', 'qux'], ['one', 'two']]
116+
index = pd.MultiIndex.from_product(iterables, names=['first', 'second'])
117+
df = pd.DataFrame(np.random.randn(10, 8), columns=index)
118+
return df
119+
120+
121+
@pytest.fixture
122+
def multiindex_dataframe_incomplete(multiindex_dataframe):
123+
"""Example MultiIndex DataFrame with missing entries
124+
"""
125+
df = multiindex_dataframe
126+
mask_array = np.zeros(df.size)
127+
mask_array[:20] = 1
128+
np.random.shuffle(mask_array)
129+
mask = mask_array.reshape(df.shape).astype(bool)
130+
df.mask(mask, inplace=True)
131+
return df
132+
133+
111134
def test_transformed_names_simple(simple_dataframe):
112135
"""
113136
Get transformed names of features in `transformed_names` attribute
@@ -234,6 +257,33 @@ def test_complex_df(complex_dataframe):
234257
assert len(transformed[c]) == len(df[c])
235258

236259

260+
def test_numeric_column_names(complex_dataframe):
261+
"""
262+
Get a dataframe from a complex mapped dataframe with numeric column names
263+
"""
264+
df = complex_dataframe
265+
df.columns = [0, 1, 2]
266+
mapper = DataFrameMapper(
267+
[(0, None), (1, None), (2, None)], df_out=True)
268+
transformed = mapper.fit_transform(df)
269+
assert len(transformed) == len(complex_dataframe)
270+
for c in df.columns:
271+
assert len(transformed[c]) == len(df[c])
272+
273+
274+
def test_multiindex_df(multiindex_dataframe_incomplete):
275+
"""
276+
Get a dataframe from a multiindex dataframe with missing data
277+
"""
278+
df = multiindex_dataframe_incomplete
279+
mapper = DataFrameMapper([([c], Imputer()) for c in df.columns],
280+
df_out=True)
281+
transformed = mapper.fit_transform(df)
282+
assert len(transformed) == len(multiindex_dataframe_incomplete)
283+
for c in df.columns:
284+
assert len(transformed[str(c)]) == len(df[c])
285+
286+
237287
def test_binarizer_df():
238288
"""
239289
Check level names from LabelBinarizer

0 commit comments

Comments
 (0)