10
10
from sklearn .utils import _safe_indexing
11
11
12
12
from ..base import BaseUnderSampler
13
- from ...dask ._support import is_dask_container
13
+ from ...dask ._support import is_dask_collection
14
14
from ...utils import check_target_type
15
15
from ...utils import Substitution
16
- from ...utils ._docstring import _random_state_docstring
16
+ from ...utils ._docstring import (
17
+ _random_state_docstring ,
18
+ _validate_if_dask_collection_docstring
19
+ )
17
20
from ...utils ._validation import _deprecate_positional_args
18
21
19
22
20
23
@Substitution (
21
24
sampling_strategy = BaseUnderSampler ._sampling_strategy_docstring ,
22
25
random_state = _random_state_docstring ,
26
+ validate_if_dask_collection = _validate_if_dask_collection_docstring ,
23
27
)
24
28
class RandomUnderSampler (BaseUnderSampler ):
25
29
"""Class to perform random under-sampling.
@@ -38,6 +42,8 @@ class RandomUnderSampler(BaseUnderSampler):
38
42
replacement : bool, default=False
39
43
Whether the sample is with or without replacement.
40
44
45
+ {validate_if_dask_collection}
46
+
41
47
Attributes
42
48
----------
43
49
sample_indices_ : ndarray of shape (n_new_samples,)
@@ -74,22 +80,23 @@ class RandomUnderSampler(BaseUnderSampler):
74
80
75
81
@_deprecate_positional_args
76
82
def __init__ (
77
- self , * , sampling_strategy = "auto" , random_state = None , replacement = False
83
+ self ,
84
+ * ,
85
+ sampling_strategy = "auto" ,
86
+ random_state = None ,
87
+ replacement = False ,
88
+ validate_if_dask_collection = False ,
78
89
):
79
- super ().__init__ (sampling_strategy = sampling_strategy )
90
+ super ().__init__ (
91
+ sampling_strategy = sampling_strategy ,
92
+ validate_if_dask_collection = validate_if_dask_collection ,
93
+ )
80
94
self .random_state = random_state
81
95
self .replacement = replacement
82
96
83
97
def _check_X_y (self , X , y ):
84
- if is_dask_container (y ) and hasattr (y , "to_dask_array" ):
85
- y = y .to_dask_array ()
86
- y .compute_chunk_sizes ()
87
- y , binarize_y , self ._uniques = check_target_type (
88
- y ,
89
- indicate_one_vs_all = True ,
90
- return_unique = True ,
91
- )
92
- if not any ([is_dask_container (arr ) for arr in (X , y )]):
98
+ y , binarize_y = check_target_type (y , indicate_one_vs_all = True )
99
+ if not any ([is_dask_collection (arr ) for arr in (X , y )]):
93
100
X , y = self ._validate_data (
94
101
X ,
95
102
y ,
@@ -98,24 +105,23 @@ def _check_X_y(self, X, y):
98
105
dtype = None ,
99
106
force_all_finite = False ,
100
107
)
101
- elif is_dask_container (X ) and hasattr (X , "to_dask_array" ):
102
- X = X .to_dask_array ()
103
- X .compute_chunk_sizes ()
104
108
return X , y , binarize_y
105
109
106
110
@staticmethod
107
111
def _find_target_class_indices (y , target_class ):
108
112
target_class_indices = np .flatnonzero (y == target_class )
109
- if is_dask_container (y ):
110
- return target_class_indices .compute ()
113
+ if is_dask_collection (y ):
114
+ from dask import compute
115
+
116
+ return compute (target_class_indices )[0 ]
111
117
return target_class_indices
112
118
113
119
def _fit_resample (self , X , y ):
114
120
random_state = check_random_state (self .random_state )
115
121
116
122
idx_under = []
117
123
118
- for target_class in self ._uniques :
124
+ for target_class in self ._classes_counts :
119
125
target_class_indices = self ._find_target_class_indices (
120
126
y , target_class
121
127
)
0 commit comments