14
14
from sklearn .utils .multiclass import check_classification_targets
15
15
16
16
from .utils import check_sampling_strategy , check_target_type
17
+ from .utils ._validation import _OutputReconstructor
17
18
18
19
19
20
class SamplerMixin (BaseEstimator , metaclass = ABCMeta ):
@@ -80,21 +81,10 @@ def fit_resample(self, X, y):
80
81
81
82
output = self ._fit_resample (X , y )
82
83
83
- if self ._X_columns is not None or self ._y_name is not None :
84
- import pandas as pd
85
-
86
- if self ._X_columns is not None :
87
- X_ = pd .DataFrame (output [0 ], columns = self ._X_columns )
88
- X_ = X_ .astype (self ._X_dtypes )
89
- else :
90
- X_ = output [0 ]
91
-
92
84
y_ = (label_binarize (output [1 ], np .unique (y ))
93
85
if binarize_y else output [1 ])
94
86
95
- if self ._y_name is not None :
96
- y_ = pd .Series (y_ , dtype = self ._y_dtype , name = self ._y_name )
97
-
87
+ X_ , y_ = self ._reconstructor .reconstruct (output [0 ], y_ )
98
88
return (X_ , y_ ) if len (output ) == 2 else (X_ , y_ , output [2 ])
99
89
100
90
# define an alias for back-compatibility
@@ -137,22 +127,7 @@ def __init__(self, sampling_strategy="auto"):
137
127
self .sampling_strategy = sampling_strategy
138
128
139
129
def _check_X_y (self , X , y , accept_sparse = None ):
140
- if hasattr (X , "loc" ):
141
- # store information to build dataframe
142
- self ._X_columns = X .columns
143
- self ._X_dtypes = X .dtypes
144
- else :
145
- self ._X_columns = None
146
- self ._X_dtypes = None
147
-
148
- if hasattr (y , "loc" ):
149
- # store information to build a series
150
- self ._y_name = y .name
151
- self ._y_dtype = y .dtype
152
- else :
153
- self ._y_name = None
154
- self ._y_dtype = None
155
-
130
+ self ._reconstructor = _OutputReconstructor (X , y )
156
131
if accept_sparse is None :
157
132
accept_sparse = ["csr" , "csc" ]
158
133
y , binarize_y = check_target_type (y , indicate_one_vs_all = True )
@@ -265,8 +240,8 @@ def fit_resample(self, X, y):
265
240
y_resampled : array-like of shape (n_samples_new,)
266
241
The corresponding label of `X_resampled`.
267
242
"""
268
- # store the columns name to reconstruct a dataframe
269
- self . _columns = X . columns if hasattr ( X , "loc" ) else None
243
+ self . _reconstructor = _OutputReconstructor ( X , y )
244
+
270
245
if self .validate :
271
246
check_classification_targets (y )
272
247
X , y , binarize_y = self ._check_X_y (
@@ -280,22 +255,12 @@ def fit_resample(self, X, y):
280
255
output = self ._fit_resample (X , y )
281
256
282
257
if self .validate :
283
- if self ._X_columns is not None or self ._y_name is not None :
284
- import pandas as pd
285
-
286
- if self ._X_columns is not None :
287
- X_ = pd .DataFrame (output [0 ], columns = self ._X_columns )
288
- X_ = X_ .astype (self ._X_dtypes )
289
- else :
290
- X_ = output [0 ]
291
258
292
259
y_ = (label_binarize (output [1 ], np .unique (y ))
293
260
if binarize_y else output [1 ])
294
-
295
- if self ._y_name is not None :
296
- y_ = pd .Series (y_ , dtype = self ._y_dtype , name = self ._y_name )
297
-
261
+ X_ , y_ = self ._reconstructor .reconstruct (output [0 ], y_ )
298
262
return (X_ , y_ ) if len (output ) == 2 else (X_ , y_ , output [2 ])
263
+
299
264
return output
300
265
301
266
def _fit_resample (self , X , y ):
0 commit comments