Skip to content

Commit 4cd9116

Browse files
committed
iter
1 parent 20b44c6 commit 4cd9116

File tree

4 files changed

+66
-26
lines changed

4 files changed

+66
-26
lines changed

imblearn/base.py

+19-8
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
_deprecate_positional_args,
1919
get_classes_counts,
2020
)
21-
from .utils.wrapper import check_classification_targets
21+
from .utils.wrapper import (
22+
check_classification_targets,
23+
label_binarize,
24+
)
2225

2326

2427
class SamplerMixin(BaseEstimator, metaclass=ABCMeta):
@@ -49,7 +52,11 @@ def fit(self, X, y):
4952
self : object
5053
Return the instance itself.
5154
"""
55+
arrays_transformer = ArraysTransformer(X, y)
5256
dask_collection = any([is_dask_collection(arr) for arr in (X, y)])
57+
if dask_collection:
58+
X, y = arrays_transformer.to_dask_array(X, y)
59+
5360
if (not dask_collection or
5461
(dask_collection and self.validate_if_dask_collection)):
5562
X, y, _ = self._check_X_y(X, y)
@@ -83,12 +90,14 @@ def fit_resample(self, X, y):
8390
"""
8491
arrays_transformer = ArraysTransformer(X, y)
8592
dask_collection = any([is_dask_collection(arr) for arr in (X, y)])
93+
if dask_collection:
94+
X, y = arrays_transformer.to_dask_array(X, y)
95+
8696
if (not dask_collection or
8797
(dask_collection and self.validate_if_dask_collection)):
8898
check_classification_targets(y)
8999
X, y, binarize_y = self._check_X_y(X, y)
90100
else:
91-
X, y = arrays_transformer.to_dask_array(X, y)
92101
binarize_y = False
93102

94103
self._classes_counts = get_classes_counts(y)
@@ -98,9 +107,10 @@ def fit_resample(self, X, y):
98107

99108
output = self._fit_resample(X, y)
100109

101-
# TODO: label binarize is not implemented with dask
102-
y_ = (label_binarize(output[1], np.unique(y))
103-
if binarize_y else output[1])
110+
if binarize_y:
111+
y_ = label_binarize(output[1], classes=np.unique(y))
112+
else:
113+
y_ = output[1]
104114

105115
X_, y_ = arrays_transformer.transform(output[0], y_)
106116
return (X_, y_) if len(output) == 2 else (X_, y_, output[2])
@@ -281,9 +291,10 @@ def fit_resample(self, X, y):
281291
output = self._fit_resample(X, y)
282292

283293
if self.validate:
284-
285-
y_ = (label_binarize(output[1], np.unique(y))
286-
if binarize_y else output[1])
294+
if binarize_y:
295+
y_ = label_binarize(output[1], classes=np.unique(y))
296+
else:
297+
y_ = output[1]
287298
X_, y_ = arrays_transformer.transform(output[0], y_)
288299
return (X_, y_) if len(output) == 2 else (X_, y_, output[2])
289300

imblearn/dask/preprocessing.py

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import numpy as np
2+
3+
4+
def label_binarize(y, *, classes):
    """Binarize labels in a one-vs-all fashion using dask collections.

    Parameters
    ----------
    y : array-like
        Target labels to encode. It is passed to
        ``dask.dataframe.from_array``, so it must be a container accepted
        by that function.

    classes : array-like
        The class labels, used as the categories of the intermediate
        categorical dtype. A lazy dask collection is accepted and is
        computed before use.

    Returns
    -------
    y_bin : dask array
        The output of ``dask.dataframe.get_dummies`` converted to a dask
        array.
    """
    # Imported at call time so that importing this module does not itself
    # require pandas or dask.
    import pandas as pd
    from dask import dataframe, is_dask_collection

    # Callers may pass `np.unique(y)` computed on a dask collection, which
    # is itself lazy. The categorical dtype needs concrete category values,
    # so materialize them here. Concrete inputs are left untouched.
    if is_dask_collection(classes):
        classes = classes.compute()

    cat_dtype = pd.CategoricalDtype(categories=classes)
    # Casting to a categorical dtype with explicit categories gives
    # `get_dummies` known categories to build its indicator columns from.
    y = dataframe.from_array(y).astype(cat_dtype)
    return dataframe.get_dummies(y).to_dask_array()

imblearn/utils/estimator_checks.py

+28-18
Original file line numberDiff line numberDiff line change
@@ -295,9 +295,10 @@ def check_samplers_pandas(name, sampler):
295295
assert_allclose(y_res_s.to_numpy(), y_res)
296296

297297

298-
def check_samplers_dask_array(name, sampler):
298+
def check_samplers_dask_array(name, sampler_orig):
299299
pytest.importorskip("dask")
300300
from dask import array
301+
sampler = clone(sampler_orig)
301302
# Check that the samplers handle dask array
302303
X, y = make_classification(
303304
n_samples=1000,
@@ -309,20 +310,25 @@ def check_samplers_dask_array(name, sampler):
309310
X_dask = array.from_array(X, chunks=100)
310311
y_dask = array.from_array(y, chunks=100)
311312

312-
X_res_dask, y_res_dask = sampler.fit_resample(X_dask, y_dask)
313-
X_res, y_res = sampler.fit_resample(X, y)
313+
for validate_if_dask_collection in (True, False):
314+
sampler.set_params(
315+
validate_if_dask_collection=validate_if_dask_collection
316+
)
317+
X_res_dask, y_res_dask = sampler.fit_resample(X_dask, y_dask)
318+
X_res, y_res = sampler.fit_resample(X, y)
314319

315-
# check that we return the same type for dataframes or series types
316-
assert isinstance(X_res_dask, array.Array)
317-
assert isinstance(y_res_dask, array.Array)
320+
# check that we return dask arrays when dask arrays are given as input
321+
assert isinstance(X_res_dask, array.Array)
322+
assert isinstance(y_res_dask, array.Array)
318323

319-
assert_allclose(X_res_dask, X_res)
320-
assert_allclose(y_res_dask, y_res)
324+
assert_allclose(X_res_dask, X_res)
325+
assert_allclose(y_res_dask, y_res)
321326

322327

323-
def check_samplers_dask_dataframe(name, sampler):
328+
def check_samplers_dask_dataframe(name, sampler_orig):
324329
pytest.importorskip("dask")
325330
from dask import dataframe
331+
sampler = clone(sampler_orig)
326332
# Check that the samplers handle dask dataframe and dask series
327333
X, y = make_classification(
328334
n_samples=1000,
@@ -337,18 +343,22 @@ def check_samplers_dask_dataframe(name, sampler):
337343
y_s = dataframe.from_array(y)
338344
y_s = y_s.rename("target")
339345

340-
X_res_df, y_res_s = sampler.fit_resample(X_df, y_s)
341-
X_res, y_res = sampler.fit_resample(X, y)
346+
for validate_if_dask_collection in (True, False):
347+
sampler.set_params(
348+
validate_if_dask_collection=validate_if_dask_collection
349+
)
350+
X_res_df, y_res_s = sampler.fit_resample(X_df, y_s)
351+
X_res, y_res = sampler.fit_resample(X, y)
342352

343-
# check that we return the same type for dataframes or series types
344-
assert isinstance(X_res_df, dataframe.DataFrame)
345-
assert isinstance(y_res_s, dataframe.Series)
353+
# check that we return the same type for dataframes or series types
354+
assert isinstance(X_res_df, dataframe.DataFrame)
355+
assert isinstance(y_res_s, dataframe.Series)
346356

347-
assert X_df.columns.to_list() == X_res_df.columns.to_list()
348-
assert y_s.name == y_res_s.name
357+
assert X_df.columns.to_list() == X_res_df.columns.to_list()
358+
assert y_s.name == y_res_s.name
349359

350-
assert_allclose(np.array(X_res_df), X_res)
351-
assert_allclose(np.array(y_res_s), y_res)
360+
assert_allclose(np.array(X_res_df), X_res)
361+
assert_allclose(np.array(y_res_s), y_res)
352362

353363

354364
def check_samplers_list(name, sampler):

imblearn/utils/wrapper.py

+9
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22

3+
from sklearn.preprocessing import label_binarize as sklearn_label_binarize
34
from sklearn.utils.multiclass import check_classification_targets as \
45
sklearn_check_classification_targets
56
from sklearn.utils.multiclass import type_of_target as sklearn_type_of_target
@@ -49,3 +50,11 @@ def check_classification_targets(y):
4950

5051
return dask_check_classification_targets(y)
5152
return sklearn_check_classification_targets(y)
53+
54+
55+
def label_binarize(y, *, classes):
    """Binarize labels, dispatching on the container type of `y`.

    Parameters
    ----------
    y : array-like or dask collection
        Target labels to encode.

    classes : array-like
        The class labels, forwarded unchanged to the selected backend.

    Returns
    -------
    y_bin
        Whatever the selected backend returns: the scikit-learn
        implementation for regular inputs, the dask implementation when
        `y` is a dask collection.
    """
    # Common case first: anything that is not a dask collection goes
    # straight to scikit-learn.
    if not is_dask_collection(y):
        return sklearn_label_binarize(y, classes=classes)

    # Deferred import: the dask-backed implementation is only loaded when a
    # dask collection is actually passed in.
    from ..dask.preprocessing import label_binarize as _binarize_with_dask

    return _binarize_with_dask(y, classes=classes)

0 commit comments

Comments
 (0)