Skip to content

Commit d4aabf8

Browse files
committed
iter
1 parent 8bfa040 commit d4aabf8

File tree

7 files changed

+168
-136
lines changed

7 files changed

+168
-136
lines changed

imblearn/base.py

+28-8
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,13 @@
1111
from sklearn.base import BaseEstimator
1212
from sklearn.preprocessing import label_binarize
1313

14+
from .dask._support import is_dask_collection
1415
from .utils import check_sampling_strategy, check_target_type
15-
from .utils._validation import ArraysTransformer
16-
from .utils._validation import _deprecate_positional_args
16+
from .utils._validation import (
17+
ArraysTransformer,
18+
_deprecate_positional_args,
19+
get_classes_counts,
20+
)
1721
from .utils.wrapper import check_classification_targets
1822

1923

@@ -45,9 +49,13 @@ def fit(self, X, y):
4549
self : object
4650
Return the instance itself.
4751
"""
48-
X, y, _ = self._check_X_y(X, y)
52+
dask_collection = any([is_dask_collection(arr) for arr in (X, y)])
53+
if (not dask_collection or
54+
(dask_collection and self.validate_if_dask_collection)):
55+
X, y, _ = self._check_X_y(X, y)
56+
self._classes_counts = get_classes_counts(y)
4957
self.sampling_strategy_ = check_sampling_strategy(
50-
self.sampling_strategy, y, self._sampling_type
58+
self.sampling_strategy, self._classes_counts, self._sampling_type
5159
)
5260
return self
5361

@@ -72,12 +80,19 @@ def fit_resample(self, X, y):
7280
y_resampled : array-like of shape (n_samples_new,)
7381
The corresponding label of `X_resampled`.
7482
"""
75-
check_classification_targets(y)
7683
arrays_transformer = ArraysTransformer(X, y)
77-
X, y, binarize_y = self._check_X_y(X, y)
84+
dask_collection = any([is_dask_collection(arr) for arr in (X, y)])
85+
if (not dask_collection or
86+
(dask_collection and self.validate_if_dask_collection)):
87+
check_classification_targets(y)
88+
X, y, binarize_y = self._check_X_y(X, y)
89+
else:
90+
X, y = arrays_transformer.to_dask_array(X, y)
91+
binarize_y = False
7892

93+
self._classes_counts = get_classes_counts(y)
7994
self.sampling_strategy_ = check_sampling_strategy(
80-
self.sampling_strategy, y, self._sampling_type
95+
self.sampling_strategy, self._classes_counts, self._sampling_type
8196
)
8297

8398
output = self._fit_resample(X, y)
@@ -125,8 +140,13 @@ class BaseSampler(SamplerMixin):
125140
instead.
126141
"""
127142

128-
def __init__(self, sampling_strategy="auto"):
143+
def __init__(
144+
self,
145+
sampling_strategy="auto",
146+
validate_if_dask_collection=False,
147+
):
129148
self.sampling_strategy = sampling_strategy
149+
self.validate_if_dask_collection = validate_if_dask_collection
130150

131151
def _check_X_y(self, X, y, accept_sparse=None):
132152
if accept_sparse is None:

imblearn/dask/_support.py

+8-12
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
1-
_REGISTERED_DASK_CONTAINER = []
1+
def is_dask_collection(container):
2+
try:
3+
# to keep dask as an optional dependency, keep the statement in a
4+
# try/except statement
5+
from dask import is_dask_collection
26

3-
try:
4-
from dask import array, dataframe
5-
_REGISTERED_DASK_CONTAINER += [
6-
array.Array, dataframe.Series, dataframe.DataFrame,
7-
]
8-
except ImportError:
9-
pass
10-
11-
12-
def is_dask_container(container):
13-
return isinstance(container, tuple(_REGISTERED_DASK_CONTAINER))
7+
return is_dask_collection(container)
8+
except ImportError:
9+
return False

imblearn/under_sampling/_prototype_selection/_random_under_sampler.py

+25-19
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,20 @@
1010
from sklearn.utils import _safe_indexing
1111

1212
from ..base import BaseUnderSampler
13-
from ...dask._support import is_dask_container
13+
from ...dask._support import is_dask_collection
1414
from ...utils import check_target_type
1515
from ...utils import Substitution
16-
from ...utils._docstring import _random_state_docstring
16+
from ...utils._docstring import (
17+
_random_state_docstring,
18+
_validate_if_dask_collection_docstring
19+
)
1720
from ...utils._validation import _deprecate_positional_args
1821

1922

2023
@Substitution(
2124
sampling_strategy=BaseUnderSampler._sampling_strategy_docstring,
2225
random_state=_random_state_docstring,
26+
validate_if_dask_collection=_validate_if_dask_collection_docstring,
2327
)
2428
class RandomUnderSampler(BaseUnderSampler):
2529
"""Class to perform random under-sampling.
@@ -38,6 +42,8 @@ class RandomUnderSampler(BaseUnderSampler):
3842
replacement : bool, default=False
3943
Whether the sample is with or without replacement.
4044
45+
{validate_if_dask_collection}
46+
4147
Attributes
4248
----------
4349
sample_indices_ : ndarray of shape (n_new_samples,)
@@ -74,22 +80,23 @@ class RandomUnderSampler(BaseUnderSampler):
7480

7581
@_deprecate_positional_args
7682
def __init__(
77-
self, *, sampling_strategy="auto", random_state=None, replacement=False
83+
self,
84+
*,
85+
sampling_strategy="auto",
86+
random_state=None,
87+
replacement=False,
88+
validate_if_dask_collection=False,
7889
):
79-
super().__init__(sampling_strategy=sampling_strategy)
90+
super().__init__(
91+
sampling_strategy=sampling_strategy,
92+
validate_if_dask_collection=validate_if_dask_collection,
93+
)
8094
self.random_state = random_state
8195
self.replacement = replacement
8296

8397
def _check_X_y(self, X, y):
84-
if is_dask_container(y) and hasattr(y, "to_dask_array"):
85-
y = y.to_dask_array()
86-
y.compute_chunk_sizes()
87-
y, binarize_y, self._uniques = check_target_type(
88-
y,
89-
indicate_one_vs_all=True,
90-
return_unique=True,
91-
)
92-
if not any([is_dask_container(arr) for arr in (X, y)]):
98+
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
99+
if not any([is_dask_collection(arr) for arr in (X, y)]):
93100
X, y = self._validate_data(
94101
X,
95102
y,
@@ -98,24 +105,23 @@ def _check_X_y(self, X, y):
98105
dtype=None,
99106
force_all_finite=False,
100107
)
101-
elif is_dask_container(X) and hasattr(X, "to_dask_array"):
102-
X = X.to_dask_array()
103-
X.compute_chunk_sizes()
104108
return X, y, binarize_y
105109

106110
@staticmethod
107111
def _find_target_class_indices(y, target_class):
108112
target_class_indices = np.flatnonzero(y == target_class)
109-
if is_dask_container(y):
110-
return target_class_indices.compute()
113+
if is_dask_collection(y):
114+
from dask import compute
115+
116+
return compute(target_class_indices)[0]
111117
return target_class_indices
112118

113119
def _fit_resample(self, X, y):
114120
random_state = check_random_state(self.random_state)
115121

116122
idx_under = []
117123

118-
for target_class in self._uniques:
124+
for target_class in self._classes_counts:
119125
target_class_indices = self._find_target_class_indices(
120126
y, target_class
121127
)

imblearn/utils/_docstring.py

+7
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,10 @@ def __call__(self, obj):
4141
`Glossary <https://scikit-learn.org/stable/glossary.html#term-n-jobs>`_
4242
for more details.
4343
""".rstrip()
44+
45+
_validate_if_dask_collection_docstring = \
46+
"""validate_if_dask_collection : bool, default=False
47+
Whether or not `X` and `y` should be validated. This parameter applies
48+
only when `X` and `y` are Dask collections where validation might be
49+
potentially costly.
50+
""".rstrip()

0 commit comments

Comments
 (0)