14
14
from sklearn .utils .multiclass import check_classification_targets
15
15
16
16
from .utils import check_sampling_strategy , check_target_type
17
+ from .utils ._validation import ArraysTransformer
17
18
18
19
19
20
class SamplerMixin (BaseEstimator , metaclass = ABCMeta ):
@@ -72,6 +73,7 @@ def fit_resample(self, X, y):
72
73
The corresponding label of `X_resampled`.
73
74
"""
74
75
check_classification_targets (y )
76
+ arrays_transformer = ArraysTransformer (X , y )
75
77
X , y , binarize_y = self ._check_X_y (X , y )
76
78
77
79
self .sampling_strategy_ = check_sampling_strategy (
@@ -80,21 +82,10 @@ def fit_resample(self, X, y):
80
82
81
83
output = self ._fit_resample (X , y )
82
84
83
- if self ._X_columns is not None or self ._y_name is not None :
84
- import pandas as pd
85
-
86
- if self ._X_columns is not None :
87
- X_ = pd .DataFrame (output [0 ], columns = self ._X_columns )
88
- X_ = X_ .astype (self ._X_dtypes )
89
- else :
90
- X_ = output [0 ]
91
-
92
85
y_ = (label_binarize (output [1 ], np .unique (y ))
93
86
if binarize_y else output [1 ])
94
87
95
- if self ._y_name is not None :
96
- y_ = pd .Series (y_ , dtype = self ._y_dtype , name = self ._y_name )
97
-
88
+ X_ , y_ = arrays_transformer .transform (output [0 ], y_ )
98
89
return (X_ , y_ ) if len (output ) == 2 else (X_ , y_ , output [2 ])
99
90
100
91
# define an alias for back-compatibility
@@ -137,22 +128,6 @@ def __init__(self, sampling_strategy="auto"):
137
128
self .sampling_strategy = sampling_strategy
138
129
139
130
def _check_X_y (self , X , y , accept_sparse = None ):
140
- if hasattr (X , "loc" ):
141
- # store information to build dataframe
142
- self ._X_columns = X .columns
143
- self ._X_dtypes = X .dtypes
144
- else :
145
- self ._X_columns = None
146
- self ._X_dtypes = None
147
-
148
- if hasattr (y , "loc" ):
149
- # store information to build a series
150
- self ._y_name = y .name
151
- self ._y_dtype = y .dtype
152
- else :
153
- self ._y_name = None
154
- self ._y_dtype = None
155
-
156
131
if accept_sparse is None :
157
132
accept_sparse = ["csr" , "csc" ]
158
133
y , binarize_y = check_target_type (y , indicate_one_vs_all = True )
@@ -265,8 +240,8 @@ def fit_resample(self, X, y):
265
240
y_resampled : array-like of shape (n_samples_new,)
266
241
The corresponding label of `X_resampled`.
267
242
"""
268
- # store the columns name to reconstruct a dataframe
269
- self . _columns = X . columns if hasattr ( X , "loc" ) else None
243
+ arrays_transformer = ArraysTransformer ( X , y )
244
+
270
245
if self .validate :
271
246
check_classification_targets (y )
272
247
X , y , binarize_y = self ._check_X_y (
@@ -280,22 +255,12 @@ def fit_resample(self, X, y):
280
255
output = self ._fit_resample (X , y )
281
256
282
257
if self .validate :
283
- if self ._X_columns is not None or self ._y_name is not None :
284
- import pandas as pd
285
-
286
- if self ._X_columns is not None :
287
- X_ = pd .DataFrame (output [0 ], columns = self ._X_columns )
288
- X_ = X_ .astype (self ._X_dtypes )
289
- else :
290
- X_ = output [0 ]
291
258
292
259
y_ = (label_binarize (output [1 ], np .unique (y ))
293
260
if binarize_y else output [1 ])
294
-
295
- if self ._y_name is not None :
296
- y_ = pd .Series (y_ , dtype = self ._y_dtype , name = self ._y_name )
297
-
261
+ X_ , y_ = arrays_transformer .transform (output [0 ], y_ )
298
262
return (X_ , y_ ) if len (output ) == 2 else (X_ , y_ , output [2 ])
263
+
299
264
return output
300
265
301
266
def _fit_resample (self , X , y ):
0 commit comments