Skip to content

Commit 65538c4

Browse files
ragrawalkitmonisitragrawal
authored
Enable regex and other ways to dynamically select columns. (#246)
Closes #239 #137 Co-authored-by: Kit Monisit <[email protected]> Co-authored-by: ragrawal <[email protected]>
1 parent e842746 commit 65538c4

File tree

6 files changed

+127
-16
lines changed

6 files changed

+127
-16
lines changed

README.rst

+59-3
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@ The examples in this file double as basic sanity tests. To run them, use ``docte
3030

3131
# python -m doctest README.rst
3232

33+
3334
Usage
3435
-----
3536

37+
3638
Import
3739
******
3840

@@ -50,25 +52,33 @@ For these examples, we'll also use pandas, numpy, and sklearn::
5052
>>> import pandas as pd
5153
>>> import numpy as np
5254
>>> import sklearn.preprocessing, sklearn.decomposition, \
53-
... sklearn.linear_model, sklearn.pipeline, sklearn.metrics
55+
... sklearn.linear_model, sklearn.pipeline, sklearn.metrics, \
56+
... sklearn.compose
5457
>>> from sklearn.feature_extraction.text import CountVectorizer
5558

59+
5660
Load some Data
5761
**************
5862

63+
5964
Normally you'll read the data from a file, but for demonstration purposes we'll create a data frame from a Python dict::
6065

6166
>>> data = pd.DataFrame({'pet': ['cat', 'dog', 'dog', 'fish', 'cat', 'dog', 'cat', 'fish'],
6267
... 'children': [4., 6, 3, 3, 2, 3, 5, 4],
6368
... 'salary': [90., 24, 44, 27, 32, 59, 36, 27]})
6469

70+
6571
Transformation Mapping
6672
----------------------
6773

74+
6875
Map the Columns to Transformations
6976
**********************************
7077

71-
The mapper takes a list of tuples. The first element of each tuple is a column name from the pandas DataFrame, or a list containing one or multiple columns (we will see an example with multiple columns later). The second element is an object which will perform the transformation which will be applied to that column. The third one is optional and is a dictionary containing the transformation options, if applicable (see "custom column names for transformed features" below).
78+
The mapper takes a list of tuples. Each tuple has three elements:
79+
1. column name(s): The first element is a column name from the pandas DataFrame, or a list containing one or multiple columns (we will see an example with multiple columns later) or an instance of a callable function such as `make_column_selector <https://scikit-learn.org/stable/modules/generated/sklearn.compose.make_column_selector.html>`
80+
2. transformer(s): The second element is an object which will perform the transformation which will be applied to that column.
81+
3. attributes: The third one is optional and is a dictionary containing the transformation options, if applicable (see "custom column names for transformed features" below).
7282

7383
Let's see an example::
7484

@@ -77,7 +87,7 @@ Let's see an example::
7787
... (['children'], sklearn.preprocessing.StandardScaler())
7888
... ])
7989

80-
The difference between specifying the column selector as ``'column'`` (as a simple string) and ``['column']`` (as a list with one element) is the shape of the array that is passed to the transformer. In the first case, a one dimensional array will be passed, while in the second case it will be a 2-dimensional array with one column, i.e. a column vector.
90+
The difference between specifying the column selector as ``'column'`` (as a simple string) and ``['column']`` (as a list with one element) is the shape of the array that is passed to the transformer. In the first case, a one dimensional array will be passed, while in the second case it will be a 2-dimensional array with one column, i.e. a column vector.
8191

8292
This behaviour mimics the same pattern as pandas' dataframes ``__getitem__`` indexing:
8393

@@ -88,6 +98,7 @@ This behaviour mimics the same pattern as pandas' dataframes ``__getitem__`` in
8898

8999
Be aware that some transformers expect a 1-dimensional input (the label-oriented ones) while some others, like ``OneHotEncoder`` or ``Imputer``, expect 2-dimensional input, with the shape ``[n_samples, n_features]``.
90100

101+
91102
Test the Transformation
92103
***********************
93104

@@ -150,6 +161,46 @@ Alternatively, you can also specify prefix and/or suffix to add to the column na
150161
>>> mapper_alias.transformed_names_
151162
['standard_scaled_children', 'children_raw']
152163

164+
165+
Dynamic Columns
166+
***********************
167+
In some situations the columns are not known before hand and we would like to dynamically select them during the fit operation. As shown below, in such situations you can provide either a custom callable or use `make_column_selector <https://scikit-learn.org/stable/modules/generated/sklearn.compose.make_column_selector.html>`.
168+
169+
170+
>>> class GetColumnsStartingWith:
171+
... def __init__(self, start_str):
172+
... self.pattern = start_str
173+
...
174+
... def __call__(self, X:pd.DataFrame=None):
175+
... return [c for c in X.columns if c.startswith(self.pattern)]
176+
...
177+
>>> df = pd.DataFrame({
178+
... 'sepal length (cm)': [1.0, 2.0, 3.0],
179+
... 'sepal width (cm)': [1.0, 2.0, 3.0],
180+
... 'petal length (cm)': [1.0, 2.0, 3.0],
181+
... 'petal width (cm)': [1.0, 2.0, 3.0]
182+
... })
183+
>>> t = DataFrameMapper([
184+
... (
185+
... sklearn.compose.make_column_selector(dtype_include=float),
186+
... sklearn.preprocessing.StandardScaler(),
187+
... {'alias': 'x'}
188+
... ),
189+
... (
190+
... GetColumnsStartingWith('petal'),
191+
... None,
192+
... {'alias': 'petal'}
193+
... )], df_out=True, default=False)
194+
>>> t.fit(df).transform(df).shape
195+
(3, 6)
196+
>>> t.transformed_names_
197+
['x_0', 'x_1', 'x_2', 'x_3', 'petal_0', 'petal_1']
198+
199+
200+
201+
Above we use `make_column_selector` to select all columns that are of type float and also use a custom callable function to select columns that start with the word 'petal'.
202+
203+
153204
Passing Series/DataFrames to the transformers
154205
*********************************************
155206

@@ -463,6 +514,11 @@ Changelog
463514
---------
464515

465516

517+
2.2.0 (2021-05-07)
518+
******************
519+
* Added an ability to provide callable functions instead of static column list.
520+
521+
466522
2.1.0 (2021-02-26)
467523
******************
468524
* Removed test for Python 3.6 and added Python 3.9

sklearn_pandas/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = '2.1.0'
1+
__version__ = '2.2.0'
22

33
import logging
44
logger = logging.getLogger(__name__)

sklearn_pandas/dataframe_mapper.py

+19-11
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
import contextlib
2-
32
from datetime import datetime
43
import pandas as pd
54
import numpy as np
65
from scipy import sparse
76
from sklearn.base import BaseEstimator, TransformerMixin
8-
97
from .cross_validation import DataWrapper
108
from .pipeline import make_transformer_pipeline, _call_fit, TransformerPipeline
119
from . import logger
@@ -29,8 +27,14 @@ def _build_transformer(transformers):
2927
return transformers
3028

3129

32-
def _build_feature(columns, transformers, options={}):
33-
return (columns, _build_transformer(transformers), options)
30+
def _build_feature(columns, transformers, options={}, X=None):
31+
if X is None:
32+
return (columns, _build_transformer(transformers), options)
33+
return (
34+
columns(X) if callable(columns) else columns,
35+
_build_transformer(transformers),
36+
options
37+
)
3438

3539

3640
def _elapsed_secs(t1):
@@ -116,14 +120,16 @@ def __init__(self, features, default=False, sparse=False, df_out=False,
116120
if (df_out and (sparse or default)):
117121
raise ValueError("Can not use df_out with sparse or default")
118122

119-
def _build(self):
123+
def _build(self, X=None):
120124
"""
121125
Build attributes built_features and built_default.
122126
"""
123127
if isinstance(self.features, list):
124-
self.built_features = [_build_feature(*f) for f in self.features]
128+
self.built_features = [
129+
_build_feature(*f, X=X) for f in self.features
130+
]
125131
else:
126-
self.built_features = self.features
132+
self.built_features = _build_feature(*self.features, X=X)
127133
self.built_default = _build_transformer(self.default)
128134

129135
@property
@@ -185,11 +191,13 @@ def _get_col_subset(self, X, cols, input_df=False):
185191
Get a subset of columns from the given table X.
186192
187193
X a Pandas dataframe; the table to select columns from
188-
cols a string or list of strings representing the columns
189-
to select
194+
cols a string or list of strings representing the columns to select.
195+
It can also be a callable that returns True or False, i.e.
196+
compatible with the built-in filter function.
190197
191198
Returns a numpy array with the data from the selected columns
192199
"""
200+
193201
if isinstance(cols, string_types):
194202
return_vector = True
195203
cols = [cols]
@@ -226,7 +234,7 @@ def fit(self, X, y=None):
226234
y the target vector relative to X, optional
227235
228236
"""
229-
self._build()
237+
self._build(X=X)
230238

231239
for columns, transformers, options in self.built_features:
232240
t1 = datetime.now()
@@ -315,7 +323,7 @@ def _transform(self, X, y=None, do_fit=False):
315323
fit_transform.
316324
"""
317325
if do_fit:
318-
self._build()
326+
self._build(X=X)
319327

320328
extracted = []
321329
transformed_names_ = []

sklearn_pandas/transformers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(self, func):
3333
"""
3434

3535
warnings.warn("""
36-
NumericalTransformer will be deprecated in 2.2 version.
36+
NumericalTransformer will be deprecated in 3.0 version.
3737
Please use Sklearn.base.TransformerMixin to write
3838
customer transformers
3939
""", DeprecationWarning)

test.py

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import pytest
2+
from unittest.mock import Mock
3+
import numpy as np
4+
import pandas as pd
5+
from sklearn_pandas import DataFrameMapper
6+
from sklearn.compose import make_column_selector
7+
from sklearn.preprocessing import StandardScaler
8+
9+
10+
class GetStartWith:
11+
def __init__(self, start_str):
12+
self.start_str = start_str
13+
14+
def __call__(self, X: pd.DataFrame) -> list:
15+
return [c for c in X.columns if c.startswith(self.start_str)]
16+
17+
18+
df = pd.DataFrame({
19+
'sepal length (cm)': [1.0, 2.0, 3.0],
20+
'sepal width (cm)': [1.0, 2.0, 3.0],
21+
'petal length (cm)': [1.0, 2.0, 3.0],
22+
'petal width (cm)': [1.0, 2.0, 3.0]
23+
})
24+
t = DataFrameMapper([
25+
(make_column_selector(dtype_include=float), StandardScaler(), {'alias': 'x'}),
26+
(GetStartWith('petal'), None, {'alias': 'petal'})
27+
], df_out=True, default=False)
28+
29+
t.fit(df)
30+
print(t.transform(df).shape)

tests/test_dataframe_mapper.py

+17
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import numpy as np
2121
from numpy.testing import assert_array_equal
2222
import pickle
23+
from sklearn.compose import make_column_selector
2324

2425
from sklearn_pandas import DataFrameMapper
2526
from sklearn_pandas.dataframe_mapper import _handle_feature, _build_transformer
@@ -969,3 +970,19 @@ def test_heterogeneous_output_types_input_df():
969970
dft = M.fit_transform(df)
970971
assert dft['feat1'].dtype == np.dtype('int64')
971972
assert dft['feat2'].dtype == np.dtype('float64')
973+
974+
975+
def test_make_column_selector(iris_dataframe):
976+
t = DataFrameMapper([
977+
(make_column_selector(dtype_include=float), None, {'alias': 'x'}),
978+
('sepal length (cm)', None),
979+
], df_out=True, default=False)
980+
981+
xt = t.fit(iris_dataframe).transform(iris_dataframe)
982+
expected = ['x_0', 'x_1', 'x_2', 'x_3', 'sepal length (cm)']
983+
assert list(xt.columns) == expected
984+
985+
pickled = pickle.dumps(t)
986+
t2 = pickle.loads(pickled)
987+
xt2 = t2.transform(iris_dataframe)
988+
assert np.array_equal(xt.values, xt2.values)

0 commit comments

Comments
 (0)