Skip to content

Commit 55d8f27

Browse files
author
chkoar
committed
Better in-out support for pandas.
1 parent 42cd496 commit 55d8f27

File tree

6 files changed

+75
-100
lines changed

6 files changed

+75
-100
lines changed

imblearn/base.py

+7-42
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from sklearn.utils.multiclass import check_classification_targets
1515

1616
from .utils import check_sampling_strategy, check_target_type
17+
from .utils._validation import _OutputReconstructor
1718

1819

1920
class SamplerMixin(BaseEstimator, metaclass=ABCMeta):
@@ -80,21 +81,10 @@ def fit_resample(self, X, y):
8081

8182
output = self._fit_resample(X, y)
8283

83-
if self._X_columns is not None or self._y_name is not None:
84-
import pandas as pd
85-
86-
if self._X_columns is not None:
87-
X_ = pd.DataFrame(output[0], columns=self._X_columns)
88-
X_ = X_.astype(self._X_dtypes)
89-
else:
90-
X_ = output[0]
91-
9284
y_ = (label_binarize(output[1], np.unique(y))
9385
if binarize_y else output[1])
9486

95-
if self._y_name is not None:
96-
y_ = pd.Series(y_, dtype=self._y_dtype, name=self._y_name)
97-
87+
X_, y_ = self._reconstructor.reconstruct(output[0], y_)
9888
return (X_, y_) if len(output) == 2 else (X_, y_, output[2])
9989

10090
# define an alias for back-compatibility
@@ -137,22 +127,7 @@ def __init__(self, sampling_strategy="auto"):
137127
self.sampling_strategy = sampling_strategy
138128

139129
def _check_X_y(self, X, y, accept_sparse=None):
140-
if hasattr(X, "loc"):
141-
# store information to build dataframe
142-
self._X_columns = X.columns
143-
self._X_dtypes = X.dtypes
144-
else:
145-
self._X_columns = None
146-
self._X_dtypes = None
147-
148-
if hasattr(y, "loc"):
149-
# store information to build a series
150-
self._y_name = y.name
151-
self._y_dtype = y.dtype
152-
else:
153-
self._y_name = None
154-
self._y_dtype = None
155-
130+
self._reconstructor = _OutputReconstructor(X, y)
156131
if accept_sparse is None:
157132
accept_sparse = ["csr", "csc"]
158133
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
@@ -265,8 +240,8 @@ def fit_resample(self, X, y):
265240
y_resampled : array-like of shape (n_samples_new,)
266241
The corresponding label of `X_resampled`.
267242
"""
268-
# store the columns name to reconstruct a dataframe
269-
self._columns = X.columns if hasattr(X, "loc") else None
243+
self._reconstructor = _OutputReconstructor(X, y)
244+
270245
if self.validate:
271246
check_classification_targets(y)
272247
X, y, binarize_y = self._check_X_y(
@@ -280,22 +255,12 @@ def fit_resample(self, X, y):
280255
output = self._fit_resample(X, y)
281256

282257
if self.validate:
283-
if self._X_columns is not None or self._y_name is not None:
284-
import pandas as pd
285-
286-
if self._X_columns is not None:
287-
X_ = pd.DataFrame(output[0], columns=self._X_columns)
288-
X_ = X_.astype(self._X_dtypes)
289-
else:
290-
X_ = output[0]
291258

292259
y_ = (label_binarize(output[1], np.unique(y))
293260
if binarize_y else output[1])
294-
295-
if self._y_name is not None:
296-
y_ = pd.Series(y_, dtype=self._y_dtype, name=self._y_name)
297-
261+
X_, y_ = self._reconstructor.reconstruct(output[0], y_)
298262
return (X_, y_) if len(output) == 2 else (X_, y_, output[2])
263+
299264
return output
300265

301266
def _fit_resample(self, X, y):

imblearn/over_sampling/_random_over_sampler.py

+2-16
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from ..utils import check_target_type
1616
from ..utils import Substitution
1717
from ..utils._docstring import _random_state_docstring
18+
from ..utils._validation import _OutputReconstructor
1819

1920

2021
@Substitution(
@@ -75,22 +76,7 @@ def __init__(self, sampling_strategy="auto", random_state=None):
7576
self.random_state = random_state
7677

7778
def _check_X_y(self, X, y):
78-
if hasattr(X, "loc"):
79-
# store information to build dataframe
80-
self._X_columns = X.columns
81-
self._X_dtypes = X.dtypes
82-
else:
83-
self._X_columns = None
84-
self._X_dtypes = None
85-
86-
if hasattr(y, "loc"):
87-
# store information to build a series
88-
self._y_name = y.name
89-
self._y_dtype = y.dtype
90-
else:
91-
self._y_name = None
92-
self._y_dtype = None
93-
79+
self._reconstructor = _OutputReconstructor(X, y)
9480
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
9581
X = check_array(X, accept_sparse=["csr", "csc"], dtype=None,
9682
force_all_finite=False)

imblearn/over_sampling/_smote.py

+2-16
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from ..utils import Substitution
3232
from ..utils._docstring import _n_jobs_docstring
3333
from ..utils._docstring import _random_state_docstring
34+
from ..utils._validation import _OutputReconstructor
3435

3536

3637
class BaseSMOTE(BaseOverSampler):
@@ -891,22 +892,7 @@ def _check_X_y(self, X, y):
891892
"""Overwrite the checking to let pass some string for categorical
892893
features.
893894
"""
894-
if hasattr(X, "loc"):
895-
# store information to build dataframe
896-
self._X_columns = X.columns
897-
self._X_dtypes = X.dtypes
898-
else:
899-
self._X_columns = None
900-
self._X_dtypes = None
901-
902-
if hasattr(y, "loc"):
903-
# store information to build a series
904-
self._y_name = y.name
905-
self._y_dtype = y.dtype
906-
else:
907-
self._y_name = None
908-
self._y_dtype = None
909-
895+
self._reconstructor = _OutputReconstructor(X, y)
910896
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
911897
X, y = check_X_y(X, y, accept_sparse=["csr", "csc"], dtype=None)
912898
return X, y, binarize_y

imblearn/under_sampling/_prototype_selection/_random_under_sampler.py

+2-16
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from ...utils import check_target_type
1616
from ...utils import Substitution
1717
from ...utils._docstring import _random_state_docstring
18+
from ...utils._validation import _OutputReconstructor
1819

1920

2021
@Substitution(
@@ -81,22 +82,7 @@ def __init__(
8182
self.replacement = replacement
8283

8384
def _check_X_y(self, X, y):
84-
if hasattr(X, "loc"):
85-
# store information to build dataframe
86-
self._X_columns = X.columns
87-
self._X_dtypes = X.dtypes
88-
else:
89-
self._X_columns = None
90-
self._X_dtypes = None
91-
92-
if hasattr(y, "loc"):
93-
# store information to build a series
94-
self._y_name = y.name
95-
self._y_dtype = y.dtype
96-
else:
97-
self._y_name = None
98-
self._y_dtype = None
99-
85+
self._reconstructor = _OutputReconstructor(X, y)
10086
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
10187
X = check_array(X, accept_sparse=["csr", "csc"], dtype=None,
10288
force_all_finite=False)

imblearn/utils/_validation.py

+45
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,51 @@
2727
TARGET_KIND = ("binary", "multiclass", "multilabel-indicator")
2828

2929

30+
class _OutputReconstructor:
31+
"""A class for converting input types to numpy and back."""
32+
33+
def __init__(self, X, y):
34+
self.x_props = self._gets_props(X)
35+
self.y_props = self._gets_props(y)
36+
37+
def reconstruct(self, X, y):
38+
X = self._transfrom(X, self.x_props)
39+
y = self._transfrom(y, self.y_props)
40+
return X, y
41+
42+
def _gets_props(self, array):
43+
props = {}
44+
props["type"] = array.__class__.__name__
45+
props["columns"] = getattr(array, "columns", None)
46+
props["name"] = getattr(array, "name", None)
47+
props["dtypes"] = getattr(array, "dtypes", None)
48+
return props
49+
50+
def _transfrom(self, array, props):
51+
type_ = props["type"].lower()
52+
msg="Could not convert to {}".format(type_)
53+
if type_ == "list":
54+
ret = array.tolist()
55+
elif type_ == "dataframe":
56+
try:
57+
import pandas as pd
58+
ret = pd.DataFrame(array, columns=props["columns"])
59+
ret = ret.astype(props["dtypes"])
60+
except Exception:
61+
warnings.warn(msg)
62+
elif type_ == "series":
63+
try:
64+
import pandas as pd
65+
ret = pd.Series(array,
66+
dtype=props["dtypes"],
67+
name=props["name"])
68+
except Exception:
69+
warnings.warn(msg)
70+
else:
71+
ret = array
72+
return ret
73+
74+
3075
def check_neighbors_object(nn_name, nn_object, additional_neighbor=0):
3176
"""Check the objects is consistent to be a NN.
3277

imblearn/utils/estimator_checks.py

+17-10
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,9 @@ def check_samplers_pandas(name, Sampler):
242242
weights=[0.2, 0.3, 0.5],
243243
random_state=0,
244244
)
245-
X_pd = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])])
246-
y_pd = pd.Series(y, name="class")
245+
X_df = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])])
246+
y_df = pd.DataFrame(y)
247+
y_s = pd.Series(y, name="class")
247248
sampler = Sampler()
248249
if isinstance(Sampler(), NearMiss):
249250
samplers = [Sampler(version=version) for version in (1, 2, 3)]
@@ -253,16 +254,22 @@ def check_samplers_pandas(name, Sampler):
253254

254255
for sampler in samplers:
255256
set_random_state(sampler)
256-
X_res_pd, y_res_pd = sampler.fit_resample(X_pd, y_pd)
257+
X_res_df, y_res_s = sampler.fit_resample(X_df, y_s)
258+
X_res_df, y_res_df = sampler.fit_resample(X_df, y_df)
257259
X_res, y_res = sampler.fit_resample(X, y)
258260

259-
# check that we return a pandas dataframe if a dataframe was given in
260-
assert isinstance(X_res_pd, pd.DataFrame)
261-
assert isinstance(y_res_pd, pd.Series)
262-
assert X_pd.columns.to_list() == X_res_pd.columns.to_list()
263-
assert y_pd.name == y_res_pd.name
264-
assert_allclose(X_res_pd.to_numpy(), X_res)
265-
assert_allclose(y_res_pd.to_numpy(), y_res)
261+
# check that we return the same type for dataframes or series types
262+
assert isinstance(X_res_df, pd.DataFrame)
263+
assert isinstance(y_res_df, pd.DataFrame)
264+
assert isinstance(y_res_s, pd.Series)
265+
266+
assert X_df.columns.to_list() == X_res_df.columns.to_list()
267+
assert y_df.columns.to_list() == y_res_df.columns.to_list()
268+
assert y_s.name == y_res_s.name
269+
270+
assert_allclose(X_res_df.to_numpy(), X_res)
271+
assert_allclose(y_res_df.to_numpy().ravel(), y_res)
272+
assert_allclose(y_res_s.to_numpy(), y_res)
266273

267274

268275
def check_samplers_multiclass_ova(name, Sampler):

0 commit comments

Comments
 (0)