Skip to content

Commit 4ba2803

Browse files
authored
BUG Better in-out support with different arrays type (#681)
1 parent 3ede269 commit 4ba2803

File tree

8 files changed

+146
-103
lines changed

8 files changed

+146
-103
lines changed

imblearn/base.py

+7-42
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from sklearn.utils.multiclass import check_classification_targets
1515

1616
from .utils import check_sampling_strategy, check_target_type
17+
from .utils._validation import ArraysTransformer
1718

1819

1920
class SamplerMixin(BaseEstimator, metaclass=ABCMeta):
@@ -72,6 +73,7 @@ def fit_resample(self, X, y):
7273
The corresponding label of `X_resampled`.
7374
"""
7475
check_classification_targets(y)
76+
arrays_transformer = ArraysTransformer(X, y)
7577
X, y, binarize_y = self._check_X_y(X, y)
7678

7779
self.sampling_strategy_ = check_sampling_strategy(
@@ -80,21 +82,10 @@ def fit_resample(self, X, y):
8082

8183
output = self._fit_resample(X, y)
8284

83-
if self._X_columns is not None or self._y_name is not None:
84-
import pandas as pd
85-
86-
if self._X_columns is not None:
87-
X_ = pd.DataFrame(output[0], columns=self._X_columns)
88-
X_ = X_.astype(self._X_dtypes)
89-
else:
90-
X_ = output[0]
91-
9285
y_ = (label_binarize(output[1], np.unique(y))
9386
if binarize_y else output[1])
9487

95-
if self._y_name is not None:
96-
y_ = pd.Series(y_, dtype=self._y_dtype, name=self._y_name)
97-
88+
X_, y_ = arrays_transformer.transform(output[0], y_)
9889
return (X_, y_) if len(output) == 2 else (X_, y_, output[2])
9990

10091
# define an alias for back-compatibility
@@ -137,22 +128,6 @@ def __init__(self, sampling_strategy="auto"):
137128
self.sampling_strategy = sampling_strategy
138129

139130
def _check_X_y(self, X, y, accept_sparse=None):
140-
if hasattr(X, "loc"):
141-
# store information to build dataframe
142-
self._X_columns = X.columns
143-
self._X_dtypes = X.dtypes
144-
else:
145-
self._X_columns = None
146-
self._X_dtypes = None
147-
148-
if hasattr(y, "loc"):
149-
# store information to build a series
150-
self._y_name = y.name
151-
self._y_dtype = y.dtype
152-
else:
153-
self._y_name = None
154-
self._y_dtype = None
155-
156131
if accept_sparse is None:
157132
accept_sparse = ["csr", "csc"]
158133
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
@@ -265,8 +240,8 @@ def fit_resample(self, X, y):
265240
y_resampled : array-like of shape (n_samples_new,)
266241
The corresponding label of `X_resampled`.
267242
"""
268-
# store the columns name to reconstruct a dataframe
269-
self._columns = X.columns if hasattr(X, "loc") else None
243+
arrays_transformer = ArraysTransformer(X, y)
244+
270245
if self.validate:
271246
check_classification_targets(y)
272247
X, y, binarize_y = self._check_X_y(
@@ -280,22 +255,12 @@ def fit_resample(self, X, y):
280255
output = self._fit_resample(X, y)
281256

282257
if self.validate:
283-
if self._X_columns is not None or self._y_name is not None:
284-
import pandas as pd
285-
286-
if self._X_columns is not None:
287-
X_ = pd.DataFrame(output[0], columns=self._X_columns)
288-
X_ = X_.astype(self._X_dtypes)
289-
else:
290-
X_ = output[0]
291258

292259
y_ = (label_binarize(output[1], np.unique(y))
293260
if binarize_y else output[1])
294-
295-
if self._y_name is not None:
296-
y_ = pd.Series(y_, dtype=self._y_dtype, name=self._y_name)
297-
261+
X_, y_ = arrays_transformer.transform(output[0], y_)
298262
return (X_, y_) if len(output) == 2 else (X_, y_, output[2])
263+
299264
return output
300265

301266
def _fit_resample(self, X, y):

imblearn/over_sampling/_random_over_sampler.py

-17
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from ..utils import Substitution
1717
from ..utils._docstring import _random_state_docstring
1818

19-
2019
@Substitution(
2120
sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
2221
random_state=_random_state_docstring,
@@ -75,22 +74,6 @@ def __init__(self, sampling_strategy="auto", random_state=None):
7574
self.random_state = random_state
7675

7776
def _check_X_y(self, X, y):
78-
if hasattr(X, "loc"):
79-
# store information to build dataframe
80-
self._X_columns = X.columns
81-
self._X_dtypes = X.dtypes
82-
else:
83-
self._X_columns = None
84-
self._X_dtypes = None
85-
86-
if hasattr(y, "loc"):
87-
# store information to build a series
88-
self._y_name = y.name
89-
self._y_dtype = y.dtype
90-
else:
91-
self._y_name = None
92-
self._y_dtype = None
93-
9477
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
9578
X = check_array(X, accept_sparse=["csr", "csc"], dtype=None,
9679
force_all_finite=False)

imblearn/over_sampling/_smote.py

-16
Original file line numberDiff line numberDiff line change
@@ -891,22 +891,6 @@ def _check_X_y(self, X, y):
891891
"""Overwrite the checking to let pass some string for categorical
892892
features.
893893
"""
894-
if hasattr(X, "loc"):
895-
# store information to build dataframe
896-
self._X_columns = X.columns
897-
self._X_dtypes = X.dtypes
898-
else:
899-
self._X_columns = None
900-
self._X_dtypes = None
901-
902-
if hasattr(y, "loc"):
903-
# store information to build a series
904-
self._y_name = y.name
905-
self._y_dtype = y.dtype
906-
else:
907-
self._y_name = None
908-
self._y_dtype = None
909-
910894
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
911895
X, y = check_X_y(X, y, accept_sparse=["csr", "csc"], dtype=None)
912896
return X, y, binarize_y

imblearn/under_sampling/_prototype_selection/_nearmiss.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,10 @@ def _fit_resample(self, X, y):
247247
_safe_indexing(X, minority_class_indices)
248248
)
249249
idx_vec_farthest = np.unique(idx_vec.reshape(-1))
250-
X_class_selected = _safe_indexing(X_class, idx_vec_farthest)
251-
y_class_selected = _safe_indexing(y_class, idx_vec_farthest)
250+
X_class_selected = _safe_indexing(
251+
X_class, idx_vec_farthest)
252+
y_class_selected = _safe_indexing(
253+
y_class, idx_vec_farthest)
252254

253255
dist_vec, idx_vec = self.nn_.kneighbors(
254256
X_class_selected, n_neighbors=self.nn_.n_neighbors

imblearn/under_sampling/_prototype_selection/_random_under_sampler.py

-16
Original file line numberDiff line numberDiff line change
@@ -81,22 +81,6 @@ def __init__(
8181
self.replacement = replacement
8282

8383
def _check_X_y(self, X, y):
84-
if hasattr(X, "loc"):
85-
# store information to build dataframe
86-
self._X_columns = X.columns
87-
self._X_dtypes = X.dtypes
88-
else:
89-
self._X_columns = None
90-
self._X_dtypes = None
91-
92-
if hasattr(y, "loc"):
93-
# store information to build a series
94-
self._y_name = y.name
95-
self._y_dtype = y.dtype
96-
else:
97-
self._y_name = None
98-
self._y_dtype = None
99-
10084
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
10185
X = check_array(X, accept_sparse=["csr", "csc"], dtype=None,
10286
force_all_finite=False)

imblearn/utils/_validation.py

+36
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,42 @@
2727
TARGET_KIND = ("binary", "multiclass", "multilabel-indicator")
2828

2929

30+
class ArraysTransformer:
    """Convert sampler output arrays back to the container types of the inputs.

    The class name and the ``columns``, ``name`` and ``dtypes`` attributes
    (when present) of the original ``X`` and ``y`` are recorded at
    construction time. :meth:`transform` then re-wraps resampled arrays so
    that they have the same container type as the original inputs.

    Only inputs whose class is named ``list``, ``DataFrame`` or ``Series``
    (case-insensitive) are converted; any other input type leaves the
    corresponding output untouched.

    Parameters
    ----------
    X : object
        The original data passed to the sampler.

    y : object
        The original target passed to the sampler.
    """

    def __init__(self, X, y):
        self.x_props = self._gets_props(X)
        self.y_props = self._gets_props(y)

    def transform(self, X, y):
        """Re-wrap the resampled ``X`` and ``y`` in their original types.

        Parameters
        ----------
        X : array-like or sparse matrix
            The resampled data.

        y : array-like
            The resampled target.

        Returns
        -------
        X, y : tuple
            The inputs converted to the container types recorded at
            construction time.
        """
        X = self._transfrom_one(X, self.x_props)
        y = self._transfrom_one(y, self.y_props)
        return X, y

    def _gets_props(self, array):
        """Record the class name and the pandas-like metadata of ``array``."""
        return {
            "type": array.__class__.__name__,
            "columns": getattr(array, "columns", None),
            "name": getattr(array, "name", None),
            "dtypes": getattr(array, "dtypes", None),
        }

    def _transfrom_one(self, array, props):
        """Convert ``array`` to the container type described by ``props``."""
        type_ = props["type"].lower()
        if type_ == "list":
            if hasattr(array, "toarray"):
                # sparse matrices have no `tolist`; densify them first
                array = array.toarray()
            # an output which is already a plain sequence has no `tolist`
            ret = array.tolist() if hasattr(array, "tolist") else list(array)
        elif type_ == "dataframe":
            import pandas as pd
            from scipy import sparse

            if sparse.issparse(array):
                # the regular constructor does not handle sparse matrices
                ret = pd.DataFrame.sparse.from_spmatrix(
                    array, columns=props["columns"]
                )
            else:
                ret = pd.DataFrame(array, columns=props["columns"])
            ret = ret.astype(props["dtypes"])
        elif type_ == "series":
            import pandas as pd

            ret = pd.Series(array, dtype=props["dtypes"], name=props["name"])
        else:
            ret = array
        return ret
64+
65+
3066
def check_neighbors_object(nn_name, nn_object, additional_neighbor=0):
3167
"""Check the objects is consistent to be a NN.
3268

imblearn/utils/estimator_checks.py

+48-10
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def _yield_sampler_checks(name, Estimator):
4141
yield check_samplers_sampling_strategy_fit_resample
4242
yield check_samplers_sparse
4343
yield check_samplers_pandas
44+
yield check_samplers_list
4445
yield check_samplers_multiclass_ova
4546
yield check_samplers_preserve_dtype
4647
yield check_samplers_sample_indices
@@ -242,8 +243,9 @@ def check_samplers_pandas(name, Sampler):
242243
weights=[0.2, 0.3, 0.5],
243244
random_state=0,
244245
)
245-
X_pd = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])])
246-
y_pd = pd.Series(y, name="class")
246+
X_df = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])])
247+
y_df = pd.DataFrame(y)
248+
y_s = pd.Series(y, name="class")
247249
sampler = Sampler()
248250
if isinstance(Sampler(), NearMiss):
249251
samplers = [Sampler(version=version) for version in (1, 2, 3)]
@@ -253,16 +255,52 @@ def check_samplers_pandas(name, Sampler):
253255

254256
for sampler in samplers:
255257
set_random_state(sampler)
256-
X_res_pd, y_res_pd = sampler.fit_resample(X_pd, y_pd)
258+
X_res_df, y_res_s = sampler.fit_resample(X_df, y_s)
259+
X_res_df, y_res_df = sampler.fit_resample(X_df, y_df)
257260
X_res, y_res = sampler.fit_resample(X, y)
258261

259-
# check that we return a pandas dataframe if a dataframe was given in
260-
assert isinstance(X_res_pd, pd.DataFrame)
261-
assert isinstance(y_res_pd, pd.Series)
262-
assert X_pd.columns.to_list() == X_res_pd.columns.to_list()
263-
assert y_pd.name == y_res_pd.name
264-
assert_allclose(X_res_pd.to_numpy(), X_res)
265-
assert_allclose(y_res_pd.to_numpy(), y_res)
262+
# check that we return the same type for dataframes or series types
263+
assert isinstance(X_res_df, pd.DataFrame)
264+
assert isinstance(y_res_df, pd.DataFrame)
265+
assert isinstance(y_res_s, pd.Series)
266+
267+
assert X_df.columns.to_list() == X_res_df.columns.to_list()
268+
assert y_df.columns.to_list() == y_res_df.columns.to_list()
269+
assert y_s.name == y_res_s.name
270+
271+
assert_allclose(X_res_df.to_numpy(), X_res)
272+
assert_allclose(y_res_df.to_numpy().ravel(), y_res)
273+
assert_allclose(y_res_s.to_numpy(), y_res)
274+
275+
276+
def check_samplers_list(name, Sampler):
    # Check that the samplers can handle plain Python lists: list inputs must
    # give list outputs holding the same values as the array-based run.
    X, y = make_classification(
        n_samples=1000,
        n_classes=3,
        n_informative=4,
        weights=[0.2, 0.3, 0.5],
        random_state=0,
    )
    X_list, y_list = X.tolist(), y.tolist()

    default_sampler = Sampler()
    if isinstance(default_sampler, NearMiss):
        # exercise each NearMiss version (1, 2 and 3)
        candidates = [Sampler(version=version) for version in (1, 2, 3)]
    else:
        candidates = [default_sampler]

    for candidate in candidates:
        set_random_state(candidate)
        X_res, y_res = candidate.fit_resample(X, y)
        X_res_list, y_res_list = candidate.fit_resample(X_list, y_list)

        # the container type of the inputs is preserved
        assert isinstance(X_res_list, list)
        assert isinstance(y_res_list, list)

        # and the resampled values match the array-based run
        assert_allclose(X_res, X_res_list)
        assert_allclose(y_res, y_res_list)
266304

267305

268306
def check_samplers_multiclass_ova(name, Sampler):

imblearn/utils/tests/test_validation.py

+51
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from imblearn.utils import check_neighbors_object
1818
from imblearn.utils import check_sampling_strategy
1919
from imblearn.utils import check_target_type
20+
from imblearn.utils._validation import ArraysTransformer
2021

2122
multiclass_target = np.array([1] * 50 + [2] * 100 + [3] * 25)
2223
binary_target = np.array([1] * 25 + [0] * 100)
@@ -315,3 +316,53 @@ def test_sampling_strategy_check_order(
315316
sampling_strategy, y, sampling_type
316317
)
317318
assert sampling_strategy_ == expected_result
319+
320+
321+
def test_arrays_transformer_plain_list():
    # plain lists given at construction must come back as plain lists
    X = np.array([[0, 0], [1, 1]])
    y = np.array([[0, 0], [1, 1]])
    X_list, y_list = X.tolist(), y.tolist()

    transformer = ArraysTransformer(X_list, y_list)
    X_res, y_res = transformer.transform(X, y)

    assert isinstance(X_res, list)
    assert isinstance(y_res, list)
329+
330+
331+
def test_arrays_transformer_numpy():
    # numpy arrays given at construction must stay numpy arrays
    X = np.array([[0, 0], [1, 1]])
    y = np.array([[0, 0], [1, 1]])

    transformer = ArraysTransformer(X, y)
    X_res, y_res = transformer.transform(X, y)

    assert isinstance(X_res, np.ndarray)
    assert isinstance(y_res, np.ndarray)
339+
340+
341+
def test_arrays_transformer_pandas():
    # pandas containers given at construction must be rebuilt with the same
    # columns / name and dtypes
    pd = pytest.importorskip("pandas")

    X = np.array([[0, 0], [1, 1]])
    y = np.array([0, 1])

    X_df = pd.DataFrame(X, columns=["a", "b"]).astype(int)
    y_df = pd.DataFrame(y, columns=["target"]).astype(int)
    y_s = pd.Series(y, name="target", dtype=int)

    # DataFrame and DataFrame case
    transformer = ArraysTransformer(X_df, y_df)
    X_res, y_res = transformer.transform(X, y)

    assert isinstance(X_res, pd.DataFrame)
    assert_array_equal(X_res.columns, X_df.columns)
    assert_array_equal(X_res.dtypes, X_df.dtypes)

    assert isinstance(y_res, pd.DataFrame)
    assert_array_equal(y_res.columns, y_df.columns)
    assert_array_equal(y_res.dtypes, y_df.dtypes)

    # DataFrames and Series case
    transformer = ArraysTransformer(X_df, y_s)
    _, y_res = transformer.transform(X, y)

    assert isinstance(y_res, pd.Series)
    assert_array_equal(y_res.name, y_s.name)
    assert_array_equal(y_res.dtype, y_s.dtype)

0 commit comments

Comments
 (0)