@@ -89,6 +89,24 @@ def transform(self, X):
89
89
return X - self .min
90
90
91
91
92
class MockImageTransformer(BaseEstimator, TransformerMixin):
    """
    Example transformer that reduces every cell of a DataFrame to its
    maximum (via ``np.max``) and then scales the result.

    Intended for columns whose cells hold array-like "images" (e.g. 2d
    numpy arrays); each cell is collapsed to a single scalar.

    Parameters
    ----------
    multiplier : float, default 10.0
        Factor applied to every per-cell maximum.
    """

    def __init__(self, multiplier=10.0):
        self.multiplier = multiplier

    def fit(self, X, y=None):
        """Nothing is learned from the data; return the transformer itself."""
        return self

    def transform(self, X):
        """
        Return a new DataFrame of ``np.max(cell) * multiplier``.

        X must be a pandas DataFrame. It is left unmodified: the
        reduction is carried out on a copy so the caller's columns are
        not overwritten.
        """
        assert isinstance(X, pd.DataFrame)
        # Work on a copy; assigning into X directly would replace the
        # caller's image columns with scalars as a side effect.
        reduced = X.copy()
        for col in reduced.columns:
            reduced[col] = reduced[col].map(np.max)
        return reduced * self.multiplier
108
+
109
+
92
110
@pytest.fixture
def simple_dataframe():
    """Provide a one-column DataFrame: column 'a' holding 1, 2, 3."""
    columns = {'a': [1, 2, 3]}
    return pd.DataFrame(columns)
@@ -101,6 +119,15 @@ def complex_dataframe():
101
119
'feat2' : [1 , 2 , 3 , 2 , 3 , 4 ]})
102
120
103
121
122
@pytest.fixture
def complex_object_dataframe():
    """
    Provide a six-row DataFrame mixing a string column ('target'), two
    integer columns ('feat1', 'feat2') and an object column ('img2d')
    whose cells are 2x2 identity matrices scaled by 1 through 6.
    """
    images = [scale * np.eye(2) for scale in range(1, 7)]
    columns = {
        'target': ['a', 'a', 'b', 'b', 'c', 'c'],
        'feat1': [1, 2, 3, 4, 5, 6],
        'feat2': [1, 2, 3, 2, 3, 4],
        'img2d': images,
    }
    return pd.DataFrame(columns)
129
+
130
+
104
131
@pytest .fixture
105
132
def multiindex_dataframe ():
106
133
"""Example MultiIndex DataFrame, taken from pandas documentation
@@ -264,6 +291,24 @@ def test_complex_df(complex_dataframe):
264
291
assert len (transformed [c ]) == len (df [c ])
265
292
266
293
294
def test_complex_object_df(complex_object_dataframe):
    """
    Map a dataframe that contains a column of 2d arrays and check that
    the output keeps the row count and that the summed 'img2d' output
    equals the max of the summed input images times the scale factor.
    """
    df = complex_object_dataframe
    img_scale = 10
    features = [
        ('target', None),
        ('feat1', None),
        (make_column_selector('feat2'), StandardScaler()),
        (make_column_selector('img2d'), MockImageTransformer(img_scale)),
    ]
    mapper = DataFrameMapper(features, df_out=True, input_df=True)
    transformed = mapper.fit_transform(df)

    assert len(transformed) == len(complex_object_dataframe)

    actual_total = np.sum(transformed['img2d'])
    expected_total = np.max(np.sum(df['img2d'])) * img_scale
    assert np.isclose(actual_total, expected_total, atol=1e-12)
310
+
311
+
267
312
def test_numeric_column_names (complex_dataframe ):
268
313
"""
269
314
Get a dataframe from a complex mapped dataframe with numeric column names
0 commit comments