Skip to content

Commit f3cf7f6

Browse files
authored
Merge pull request scikit-learn-contrib#177 from ragrawal/astransformer
ActAsTransformer
2 parents 0024cf7 + c83a2de commit f3cf7f6

File tree

3 files changed

+37
-0
lines changed

3 files changed

+37
-0
lines changed

sklearn_pandas/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
from .cross_validation import cross_val_score, GridSearchCV, RandomizedSearchCV # NOQA
55
from .categorical_imputer import CategoricalImputer # NOQA
66
from .features_generator import gen_features # NOQA
7+
from .act_as_transformer import ActAsTransformer # NOQA

sklearn_pandas/act_as_transformer.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from sklearn.base import BaseEstimator, TransformerMixin
2+
import numpy as np
3+
4+
5+
class ActAsTransformer(BaseEstimator, TransformerMixin):
6+
"""
7+
Use this class to convert a random function into a
8+
transformer.
9+
"""
10+
11+
def __init__(self, func):
12+
self.__func = func
13+
14+
def fit(self, x, y=None):
15+
return self
16+
17+
def transform(self, x):
18+
return np.vectorize(self.__func)(x)
19+
20+
def __call__(self, *args, **kwargs):
21+
return self.__func(*args, **kwargs)

tests/test_dataframe_mapper.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from sklearn_pandas import DataFrameMapper, cross_val_score
3333
from sklearn_pandas.dataframe_mapper import _handle_feature, _build_transformer
3434
from sklearn_pandas.pipeline import TransformerPipeline
35+
from sklearn_pandas import ActAsTransformer
3536

3637

3738
class MockXTransformer(object):
@@ -950,3 +951,17 @@ def test_heterogeneous_output_types_input_df():
950951
dft = M.fit_transform(df)
951952
assert dft['feat1'].dtype == np.dtype('int64')
952953
assert dft['feat2'].dtype == np.dtype('float64')
954+
955+
956+
def test_actastransformer():
957+
"""
958+
Test whether random transformations works
959+
"""
960+
df = pd.DataFrame({
961+
'feat1': [10, 100],
962+
})
963+
M = DataFrameMapper([
964+
('feat1', ActAsTransformer(np.log10))
965+
], input_df=True, df_out=True, default=None)
966+
dft = M.fit_transform(df)
967+
assert_array_equal([1., 2.], dft['feat1'].values)

0 commit comments

Comments
 (0)