17
17
from imblearn .utils import check_neighbors_object
18
18
from imblearn .utils import check_sampling_strategy
19
19
from imblearn .utils import check_target_type
20
+ from imblearn .utils import get_classes_counts
20
21
from imblearn .utils ._validation import ArraysTransformer
21
22
from imblearn .utils ._validation import _deprecate_positional_args
22
23
23
24
multiclass_target = np .array ([1 ] * 50 + [2 ] * 100 + [3 ] * 25 )
25
+ multiclass_classes_counts = get_classes_counts (multiclass_target )
24
26
binary_target = np .array ([1 ] * 25 + [0 ] * 100 )
27
+ binary_classes_counts = get_classes_counts (binary_target )
25
28
26
29
27
30
def test_check_neighbors_object ():
@@ -70,11 +73,11 @@ def test_check_target_type_ova(target, output_target, is_ova):
70
73
assert binarize_target == is_ova
71
74
72
75
73
- def test_check_sampling_strategy_warning ():
76
+ def test_check_sampling_strategy_error_dict_cleaning_methods ():
74
77
msg = "dict for cleaning methods is not supported"
75
78
with pytest .raises (ValueError , match = msg ):
76
79
check_sampling_strategy (
77
- {1 : 0 , 2 : 0 , 3 : 0 }, multiclass_target , "clean-sampling"
80
+ {1 : 0 , 2 : 0 , 3 : 0 }, multiclass_classes_counts , "clean-sampling"
78
81
)
79
82
80
83
@@ -83,19 +86,19 @@ def test_check_sampling_strategy_warning():
83
86
[
84
87
(
85
88
0.5 ,
86
- binary_target ,
89
+ binary_classes_counts ,
87
90
"clean-sampling" ,
88
91
"'clean-sampling' methods do let the user specify the sampling ratio" , # noqa
89
92
),
90
93
(
91
94
0.1 ,
92
- np .array ([0 ] * 10 + [1 ] * 20 ),
95
+ get_classes_counts ( np .array ([0 ] * 10 + [1 ] * 20 ) ),
93
96
"over-sampling" ,
94
97
"remove samples from the minority class while trying to generate new" , # noqa
95
98
),
96
99
(
97
100
0.1 ,
98
- np .array ([0 ] * 10 + [1 ] * 20 ),
101
+ get_classes_counts ( np .array ([0 ] * 10 + [1 ] * 20 ) ),
99
102
"under-sampling" ,
100
103
"generate new sample in the majority class while trying to remove" ,
101
104
),
@@ -108,15 +111,21 @@ def test_check_sampling_strategy_float_error(ratio, y, type, err_msg):
108
111
109
112
def test_check_sampling_strategy_error ():
110
113
with pytest .raises (ValueError , match = "'sampling_type' should be one of" ):
111
- check_sampling_strategy ("auto" , np .array ([1 , 2 , 3 ]), "rnd" )
114
+ check_sampling_strategy (
115
+ "auto" , get_classes_counts (np .array ([1 , 2 , 3 ])), "rnd"
116
+ )
112
117
113
118
error_regex = "The target 'y' needs to have more than 1 class."
114
119
with pytest .raises (ValueError , match = error_regex ):
115
- check_sampling_strategy ("auto" , np .ones ((10 ,)), "over-sampling" )
120
+ check_sampling_strategy (
121
+ "auto" , get_classes_counts (np .ones ((10 ,))), "over-sampling"
122
+ )
116
123
117
124
error_regex = "When 'sampling_strategy' is a string, it needs to be one of"
118
125
with pytest .raises (ValueError , match = error_regex ):
119
- check_sampling_strategy ("rnd" , np .array ([1 , 2 , 3 ]), "over-sampling" )
126
+ check_sampling_strategy (
127
+ "rnd" , get_classes_counts (np .array ([1 , 2 , 3 ])), "over-sampling"
128
+ )
120
129
121
130
122
131
@pytest .mark .parametrize (
@@ -136,7 +145,9 @@ def test_check_sampling_strategy_error_wrong_string(
136
145
),
137
146
):
138
147
check_sampling_strategy (
139
- sampling_strategy , np .array ([1 , 2 , 3 ]), sampling_type
148
+ sampling_strategy ,
149
+ get_classes_counts (np .array ([1 , 2 , 3 ])),
150
+ sampling_type ,
140
151
)
141
152
142
153
@@ -153,14 +164,18 @@ def test_sampling_strategy_class_target_unknown(
153
164
):
154
165
y = np .array ([1 ] * 50 + [2 ] * 100 + [3 ] * 25 )
155
166
with pytest .raises (ValueError , match = "are not present in the data." ):
156
- check_sampling_strategy (sampling_strategy , y , sampling_method )
167
+ check_sampling_strategy (
168
+ sampling_strategy , get_classes_counts (y ), sampling_method
169
+ )
157
170
158
171
159
172
def test_sampling_strategy_dict_error ():
160
173
y = np .array ([1 ] * 50 + [2 ] * 100 + [3 ] * 25 )
161
174
sampling_strategy = {1 : - 100 , 2 : 50 , 3 : 25 }
162
175
with pytest .raises (ValueError , match = "in a class cannot be negative." ):
163
- check_sampling_strategy (sampling_strategy , y , "under-sampling" )
176
+ check_sampling_strategy (
177
+ sampling_strategy , get_classes_counts (y ), "under-sampling"
178
+ )
164
179
sampling_strategy = {1 : 45 , 2 : 100 , 3 : 70 }
165
180
error_regex = (
166
181
"With over-sampling methods, the number of samples in a"
@@ -169,7 +184,9 @@ def test_sampling_strategy_dict_error():
169
184
" samples are asked."
170
185
)
171
186
with pytest .raises (ValueError , match = error_regex ):
172
- check_sampling_strategy (sampling_strategy , y , "over-sampling" )
187
+ check_sampling_strategy (
188
+ sampling_strategy , get_classes_counts (y ), "over-sampling"
189
+ )
173
190
174
191
error_regex = (
175
192
"With under-sampling methods, the number of samples in a"
@@ -178,21 +195,27 @@ def test_sampling_strategy_dict_error():
178
195
" are asked."
179
196
)
180
197
with pytest .raises (ValueError , match = error_regex ):
181
- check_sampling_strategy (sampling_strategy , y , "under-sampling" )
198
+ check_sampling_strategy (
199
+ sampling_strategy , get_classes_counts (y ), "under-sampling"
200
+ )
182
201
183
202
184
203
@pytest .mark .parametrize ("sampling_strategy" , [- 10 , 10 ])
185
204
def test_sampling_strategy_float_error_not_in_range (sampling_strategy ):
186
205
y = np .array ([1 ] * 50 + [2 ] * 100 )
187
206
with pytest .raises (ValueError , match = "it should be in the range" ):
188
- check_sampling_strategy (sampling_strategy , y , "under-sampling" )
207
+ check_sampling_strategy (
208
+ sampling_strategy , get_classes_counts (y ), "under-sampling"
209
+ )
189
210
190
211
191
212
def test_sampling_strategy_float_error_not_binary ():
192
213
y = np .array ([1 ] * 50 + [2 ] * 100 + [3 ] * 25 )
193
214
with pytest .raises (ValueError , match = "the type of target is binary" ):
194
215
sampling_strategy = 0.5
195
- check_sampling_strategy (sampling_strategy , y , "under-sampling" )
216
+ check_sampling_strategy (
217
+ sampling_strategy , get_classes_counts (y ), "under-sampling"
218
+ )
196
219
197
220
198
221
@pytest .mark .parametrize (
@@ -202,7 +225,9 @@ def test_sampling_strategy_list_error_not_clean_sampling(sampling_method):
202
225
y = np .array ([1 ] * 50 + [2 ] * 100 + [3 ] * 25 )
203
226
with pytest .raises (ValueError , match = "cannot be a list for samplers" ):
204
227
sampling_strategy = [1 , 2 , 3 ]
205
- check_sampling_strategy (sampling_strategy , y , sampling_method )
228
+ check_sampling_strategy (
229
+ sampling_strategy , get_classes_counts (y ), sampling_method
230
+ )
206
231
207
232
208
233
def _sampling_strategy_func (y ):
@@ -215,42 +240,87 @@ def _sampling_strategy_func(y):
215
240
@pytest .mark .parametrize (
216
241
"sampling_strategy, sampling_type, expected_sampling_strategy, target" ,
217
242
[
218
- ("auto" , "under-sampling" , {1 : 25 , 2 : 25 }, multiclass_target ),
219
- ("auto" , "clean-sampling" , {1 : 25 , 2 : 25 }, multiclass_target ),
220
- ("auto" , "over-sampling" , {1 : 50 , 3 : 75 }, multiclass_target ),
221
- ("all" , "over-sampling" , {1 : 50 , 2 : 0 , 3 : 75 }, multiclass_target ),
222
- ("all" , "under-sampling" , {1 : 25 , 2 : 25 , 3 : 25 }, multiclass_target ),
223
- ("all" , "clean-sampling" , {1 : 25 , 2 : 25 , 3 : 25 }, multiclass_target ),
224
- ("majority" , "under-sampling" , {2 : 25 }, multiclass_target ),
225
- ("majority" , "clean-sampling" , {2 : 25 }, multiclass_target ),
226
- ("minority" , "over-sampling" , {3 : 75 }, multiclass_target ),
227
- ("not minority" , "over-sampling" , {1 : 50 , 2 : 0 }, multiclass_target ),
228
- ("not minority" , "under-sampling" , {1 : 25 , 2 : 25 }, multiclass_target ),
229
- ("not minority" , "clean-sampling" , {1 : 25 , 2 : 25 }, multiclass_target ),
230
- ("not majority" , "over-sampling" , {1 : 50 , 3 : 75 }, multiclass_target ),
231
- ("not majority" , "under-sampling" , {1 : 25 , 3 : 25 }, multiclass_target ),
232
- ("not majority" , "clean-sampling" , {1 : 25 , 3 : 25 }, multiclass_target ),
243
+ ("auto" , "under-sampling" , {1 : 25 , 2 : 25 }, multiclass_classes_counts ),
244
+ ("auto" , "clean-sampling" , {1 : 25 , 2 : 25 }, multiclass_classes_counts ),
245
+ ("auto" , "over-sampling" , {1 : 50 , 3 : 75 }, multiclass_classes_counts ),
246
+ (
247
+ "all" ,
248
+ "over-sampling" ,
249
+ {1 : 50 , 2 : 0 , 3 : 75 },
250
+ multiclass_classes_counts ,
251
+ ),
252
+ (
253
+ "all" ,
254
+ "under-sampling" ,
255
+ {1 : 25 , 2 : 25 , 3 : 25 },
256
+ multiclass_classes_counts ,
257
+ ),
258
+ (
259
+ "all" ,
260
+ "clean-sampling" ,
261
+ {1 : 25 , 2 : 25 , 3 : 25 },
262
+ multiclass_classes_counts ,
263
+ ),
264
+ ("majority" , "under-sampling" , {2 : 25 }, multiclass_classes_counts ),
265
+ ("majority" , "clean-sampling" , {2 : 25 }, multiclass_classes_counts ),
266
+ ("minority" , "over-sampling" , {3 : 75 }, multiclass_classes_counts ),
267
+ (
268
+ "not minority" ,
269
+ "over-sampling" ,
270
+ {1 : 50 , 2 : 0 },
271
+ multiclass_classes_counts ,
272
+ ),
273
+ (
274
+ "not minority" ,
275
+ "under-sampling" ,
276
+ {1 : 25 , 2 : 25 },
277
+ multiclass_classes_counts ,
278
+ ),
279
+ (
280
+ "not minority" ,
281
+ "clean-sampling" ,
282
+ {1 : 25 , 2 : 25 },
283
+ multiclass_classes_counts ,
284
+ ),
285
+ (
286
+ "not majority" ,
287
+ "over-sampling" ,
288
+ {1 : 50 , 3 : 75 },
289
+ multiclass_classes_counts ,
290
+ ),
291
+ (
292
+ "not majority" ,
293
+ "under-sampling" ,
294
+ {1 : 25 , 3 : 25 },
295
+ multiclass_classes_counts ,
296
+ ),
297
+ (
298
+ "not majority" ,
299
+ "clean-sampling" ,
300
+ {1 : 25 , 3 : 25 },
301
+ multiclass_classes_counts ,
302
+ ),
233
303
(
234
304
{1 : 70 , 2 : 100 , 3 : 70 },
235
305
"over-sampling" ,
236
306
{1 : 20 , 2 : 0 , 3 : 45 },
237
- multiclass_target ,
307
+ multiclass_classes_counts ,
238
308
),
239
309
(
240
310
{1 : 30 , 2 : 45 , 3 : 25 },
241
311
"under-sampling" ,
242
312
{1 : 30 , 2 : 45 , 3 : 25 },
243
- multiclass_target ,
313
+ multiclass_classes_counts ,
244
314
),
245
- ([1 ], "clean-sampling" , {1 : 25 }, multiclass_target ),
315
+ ([1 ], "clean-sampling" , {1 : 25 }, multiclass_classes_counts ),
246
316
(
247
317
_sampling_strategy_func ,
248
318
"over-sampling" ,
249
319
{1 : 50 , 2 : 0 , 3 : 75 },
250
- multiclass_target ,
320
+ multiclass_classes_counts ,
251
321
),
252
- (0.5 , "over-sampling" , {1 : 25 }, binary_target ),
253
- (0.5 , "under-sampling" , {0 : 50 }, binary_target ),
322
+ (0.5 , "over-sampling" , {1 : 25 }, binary_classes_counts ),
323
+ (0.5 , "under-sampling" , {0 : 50 }, binary_classes_counts ),
254
324
],
255
325
)
256
326
def test_check_sampling_strategy (
@@ -271,23 +341,27 @@ def test_sampling_strategy_dict_over_sampling():
271
341
r" the majority class \(class #2 -> 100\)"
272
342
)
273
343
with warns (UserWarning , expected_msg ):
274
- check_sampling_strategy (sampling_strategy , y , "over-sampling" )
344
+ check_sampling_strategy (
345
+ sampling_strategy , get_classes_counts (y ), "over-sampling"
346
+ )
275
347
276
348
277
349
def test_sampling_strategy_callable_args ():
278
350
y = np .array ([1 ] * 50 + [2 ] * 100 + [3 ] * 25 )
279
351
multiplier = {1 : 1.5 , 2 : 1 , 3 : 3 }
280
352
281
- def sampling_strategy_func (y , multiplier ):
353
+ def sampling_strategy_func (classes_counts , multiplier ):
282
354
"""samples such that each class will be affected by the multiplier."""
283
- target_stats = Counter (y )
284
355
return {
285
356
key : int (values * multiplier [key ])
286
- for key , values in target_stats .items ()
357
+ for key , values in classes_counts .items ()
287
358
}
288
359
289
360
sampling_strategy_ = check_sampling_strategy (
290
- sampling_strategy_func , y , "over-sampling" , multiplier = multiplier
361
+ sampling_strategy_func ,
362
+ get_classes_counts (y ),
363
+ "over-sampling" ,
364
+ multiplier = multiplier ,
291
365
)
292
366
assert sampling_strategy_ == {1 : 25 , 2 : 0 , 3 : 50 }
293
367
@@ -314,11 +388,20 @@ def test_sampling_strategy_check_order(
314
388
# dictionary is sorted. Refer to issue #428.
315
389
y = np .array ([1 ] * 50 + [2 ] * 100 + [3 ] * 25 )
316
390
sampling_strategy_ = check_sampling_strategy (
317
- sampling_strategy , y , sampling_type
391
+ sampling_strategy , get_classes_counts ( y ) , sampling_type
318
392
)
319
393
assert sampling_strategy_ == expected_result
320
394
321
395
396
+ # FIXME: remove in 0.9
397
+ def test_sampling_strategy_deprecation_array_target ():
398
+ # Check that we raise a FutureWarning when an array of target is passed
399
+ with pytest .warns (FutureWarning ):
400
+ sampling_strategy = "auto"
401
+ check_sampling_strategy (
402
+ sampling_strategy , binary_target , "under-sampling" ,
403
+ )
404
+
322
405
def test_arrays_transformer_plain_list ():
323
406
X = np .array ([[0 , 0 ], [1 , 1 ]])
324
407
y = np .array ([[0 , 0 ], [1 , 1 ]])
0 commit comments