Skip to content

Commit 00c0a26

Browse files
committed
iter
1 parent 7aae9d9 commit 00c0a26

File tree

6 files changed

+74
-17
lines changed

6 files changed

+74
-17
lines changed

imblearn/base.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010

1111
from sklearn.base import BaseEstimator
1212
from sklearn.preprocessing import label_binarize
13-
from sklearn.utils.multiclass import check_classification_targets
1413

1514
from .utils import check_sampling_strategy, check_target_type
1615
from .utils._validation import ArraysTransformer
1716
from .utils._validation import _deprecate_positional_args
17+
from .utils.wrapper import check_classification_targets
1818

1919

2020
class SamplerMixin(BaseEstimator, metaclass=ABCMeta):
@@ -82,6 +82,7 @@ def fit_resample(self, X, y):
8282

8383
output = self._fit_resample(X, y)
8484

85+
# TODO: label binarize is not implemented with dask
8586
y_ = (label_binarize(output[1], np.unique(y))
8687
if binarize_y else output[1])
8788

imblearn/dask/utils.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@ def is_multilabel(y):
99
if not (y.ndim == 2 and y.shape[1] > 1):
1010
return False
1111

12-
labels = np.unique(y).compute()
12+
if hasattr(y, "unique"):
13+
labels = np.asarray(y.unique())
14+
else:
15+
labels = np.unique(y).compute()
1316

1417
return len(labels) < 3 and (
1518
y.dtype.kind in 'biu' or _is_integral_float(labels)
@@ -39,7 +42,10 @@ def type_of_target(y):
3942
# NOTE: we don't check for infinite values
4043
return 'continuous' + suffix
4144

42-
labels = np.unique(y).compute()
45+
if hasattr(y, "unique"):
46+
labels = np.asarray(y.unique())
47+
else:
48+
labels = np.unique(y).compute()
4349
if (len((labels)) > 2) or (y.ndim >= 2 and len(y[0]) > 1):
4450
# [1, 2, 3] or [[1., 2., 3]] or [[1, 2]]
4551
return 'multiclass' + suffix
@@ -63,3 +69,10 @@ def column_or_1d(y, *, warn=False):
6369
raise ValueError(
6470
f"y should be a 1d array. Got an array of shape {shape} instead."
6571
)
72+
73+
74+
def check_classification_targets(y):
75+
y_type = type_of_target(y)
76+
if y_type not in ['binary', 'multiclass', 'multiclass-multioutput',
77+
'multilabel-indicator', 'multilabel-sequences']:
78+
raise ValueError("Unknown label type: %r" % y_type)

imblearn/under_sampling/_prototype_selection/_random_under_sampler.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ def __init__(
8181
self.replacement = replacement
8282

8383
def _check_X_y(self, X, y):
84+
if is_dask_container(y) and hasattr(y, "to_dask_array"):
85+
y = y.to_dask_array()
86+
y.compute_chunk_sizes()
8487
y, binarize_y, self._uniques = check_target_type(
8588
y,
8689
indicate_one_vs_all=True,
@@ -95,6 +98,9 @@ def _check_X_y(self, X, y):
9598
dtype=None,
9699
force_all_finite=False,
97100
)
101+
elif is_dask_container(X) and hasattr(X, "to_dask_array"):
102+
X = X.to_dask_array()
103+
X.compute_chunk_sizes()
98104
return X, y, binarize_y
99105

100106
@staticmethod
@@ -140,7 +146,7 @@ def _more_tags(self):
140146
"2darray",
141147
"string",
142148
"dask-array",
143-
# "dask-dataframe"
149+
"dask-dataframe"
144150
],
145151
"sample_indices": True,
146152
"allow_nan": True,

imblearn/utils/_validation.py

+26-2
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ def transform(self, X, y):
4747
def _gets_props(self, array):
4848
props = {}
4949
props["type"] = array.__class__.__name__
50+
if props["type"].lower() in ("series", "dataframe"):
51+
suffix = "dask-" if is_dask_container(array) else "pandas-"
52+
props["type"] = suffix + props["type"]
5053
props["columns"] = getattr(array, "columns", None)
5154
props["name"] = getattr(array, "name", None)
5255
props["dtypes"] = getattr(array, "dtypes", None)
@@ -56,13 +59,34 @@ def _transfrom_one(self, array, props):
5659
type_ = props["type"].lower()
5760
if type_ == "list":
5861
ret = array.tolist()
59-
elif type_ == "dataframe":
62+
elif type_ == "pandas-dataframe":
6063
import pandas as pd
64+
6165
ret = pd.DataFrame(array, columns=props["columns"])
6266
ret = ret.astype(props["dtypes"])
63-
elif type_ == "series":
67+
elif type_ == "pandas-series":
6468
import pandas as pd
69+
6570
ret = pd.Series(array, dtype=props["dtypes"], name=props["name"])
71+
elif type_ == "dask-dataframe":
72+
from dask import dataframe
73+
74+
if is_dask_container(array):
75+
ret = dataframe.from_dask_array(
76+
array, columns=props["columns"]
77+
)
78+
else:
79+
ret = dataframe.from_array(array, columns=props["columns"])
80+
ret = ret.astype(props["dtypes"])
81+
elif type_ == "dask-series":
82+
from dask import dataframe
83+
84+
if is_dask_container(array):
85+
ret = dataframe.from_dask_array(array)
86+
else:
87+
ret = dataframe.from_array(array)
88+
ret = ret.astype(props["dtypes"])
89+
ret = ret.rename(props["name"])
6690
else:
6791
ret = array
6892
return ret

imblearn/utils/estimator_checks.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,7 @@ def check_samplers_dask_dataframe(name, sampler):
333333
X, columns=[str(i) for i in range(X.shape[1])]
334334
)
335335
y_s = dask.dataframe.from_array(y)
336+
y_s = y_s.rename("target")
336337

337338
X_res_df, y_res_s = sampler.fit_resample(X_df, y_s)
338339
X_res, y_res = sampler.fit_resample(X, y)
@@ -341,13 +342,11 @@ def check_samplers_dask_dataframe(name, sampler):
341342
assert isinstance(X_res_df, dask.dataframe.DataFrame)
342343
assert isinstance(y_res_s, dask.dataframe.Series)
343344

344-
# assert X_df.columns.to_list() == X_res_df.columns.to_list()
345-
# assert y_df.columns.to_list() == y_res_df.columns.to_list()
346-
# assert y_s.name == y_res_s.name
345+
assert X_df.columns.to_list() == X_res_df.columns.to_list()
346+
assert y_s.name == y_res_s.name
347347

348-
# assert_allclose(X_res_df.to_numpy(), X_res)
349-
# assert_allclose(y_res_df.to_numpy().ravel(), y_res)
350-
# assert_allclose(y_res_s.to_numpy(), y_res)
348+
assert_allclose(np.array(X_res_df), X_res)
349+
assert_allclose(np.array(y_res_s), y_res)
351350

352351

353352
def check_samplers_list(name, sampler):

imblearn/utils/wrapper.py

+19-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import numpy as np
22

3+
from sklearn.utils.multiclass import check_classification_targets as \
4+
sklearn_check_classification_targets
35
from sklearn.utils.multiclass import type_of_target as sklearn_type_of_target
46
from sklearn.utils.validation import column_or_1d as sklearn_column_or_1d
57

@@ -30,8 +32,20 @@ def column_or_1d(y, *, warn=False):
3032
return sklearn_column_or_1d(y, warn=warn)
3133

3234

33-
def unique(*args, **kwargs):
34-
output = np.unique(args, kwargs)
35-
if is_dask_container(output):
36-
return (arr.compute() for arr in output)
37-
return output
35+
def unique(arr, **kwargs):
36+
if is_dask_container(arr):
37+
if hasattr(arr, "unique"):
38+
output = np.asarray(arr.unique(**kwargs))
39+
else:
40+
output = np.unique(arr).compute()
41+
return output
42+
return np.unique(arr, **kwargs)
43+
44+
45+
def check_classification_targets(y):
46+
if is_dask_container(y):
47+
from ..dask.utils import check_classification_targets as \
48+
dask_check_classification_targets
49+
50+
return dask_check_classification_targets(y)
51+
return sklearn_check_classification_targets(y)

0 commit comments

Comments
 (0)