Skip to content

Commit c7bdc74

Browse files
committed
Check that a FutureWarning is raised
1 parent 36a0aa3 commit c7bdc74

File tree

3 files changed

+129
-44
lines changed

3 files changed

+129
-44
lines changed

doc/api.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,6 @@ Imbalance-learn provides some fast-prototyping tools.
248248
:toctree: generated/
249249
:template: function.rst
250250

251-
utils.estimator_checks.parametrize_with_checks
252251
utils.check_neighbors_object
253252
utils.check_sampling_strategy
253+
utils.get_classes_counts

imblearn/utils/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77
from ._validation import check_neighbors_object
88
from ._validation import check_target_type
99
from ._validation import check_sampling_strategy
10+
from ._validation import get_classes_counts
1011

1112
__all__ = [
1213
"check_neighbors_object",
1314
"check_sampling_strategy",
1415
"check_target_type",
16+
"get_classes_counts",
1517
"Substitution",
1618
]

imblearn/utils/tests/test_validation.py

+126-43
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,14 @@
1717
from imblearn.utils import check_neighbors_object
1818
from imblearn.utils import check_sampling_strategy
1919
from imblearn.utils import check_target_type
20+
from imblearn.utils import get_classes_counts
2021
from imblearn.utils._validation import ArraysTransformer
2122
from imblearn.utils._validation import _deprecate_positional_args
2223

2324
multiclass_target = np.array([1] * 50 + [2] * 100 + [3] * 25)
25+
multiclass_classes_counts = get_classes_counts(multiclass_target)
2426
binary_target = np.array([1] * 25 + [0] * 100)
27+
binary_classes_counts = get_classes_counts(binary_target)
2528

2629

2730
def test_check_neighbors_object():
@@ -70,11 +73,11 @@ def test_check_target_type_ova(target, output_target, is_ova):
7073
assert binarize_target == is_ova
7174

7275

73-
def test_check_sampling_strategy_warning():
76+
def test_check_sampling_strategy_error_dict_cleaning_methods():
7477
msg = "dict for cleaning methods is not supported"
7578
with pytest.raises(ValueError, match=msg):
7679
check_sampling_strategy(
77-
{1: 0, 2: 0, 3: 0}, multiclass_target, "clean-sampling"
80+
{1: 0, 2: 0, 3: 0}, multiclass_classes_counts, "clean-sampling"
7881
)
7982

8083

@@ -83,19 +86,19 @@ def test_check_sampling_strategy_warning():
8386
[
8487
(
8588
0.5,
86-
binary_target,
89+
binary_classes_counts,
8790
"clean-sampling",
8891
"'clean-sampling' methods do let the user specify the sampling ratio", # noqa
8992
),
9093
(
9194
0.1,
92-
np.array([0] * 10 + [1] * 20),
95+
get_classes_counts(np.array([0] * 10 + [1] * 20)),
9396
"over-sampling",
9497
"remove samples from the minority class while trying to generate new", # noqa
9598
),
9699
(
97100
0.1,
98-
np.array([0] * 10 + [1] * 20),
101+
get_classes_counts(np.array([0] * 10 + [1] * 20)),
99102
"under-sampling",
100103
"generate new sample in the majority class while trying to remove",
101104
),
@@ -108,15 +111,21 @@ def test_check_sampling_strategy_float_error(ratio, y, type, err_msg):
108111

109112
def test_check_sampling_strategy_error():
110113
with pytest.raises(ValueError, match="'sampling_type' should be one of"):
111-
check_sampling_strategy("auto", np.array([1, 2, 3]), "rnd")
114+
check_sampling_strategy(
115+
"auto", get_classes_counts(np.array([1, 2, 3])), "rnd"
116+
)
112117

113118
error_regex = "The target 'y' needs to have more than 1 class."
114119
with pytest.raises(ValueError, match=error_regex):
115-
check_sampling_strategy("auto", np.ones((10,)), "over-sampling")
120+
check_sampling_strategy(
121+
"auto", get_classes_counts(np.ones((10,))), "over-sampling"
122+
)
116123

117124
error_regex = "When 'sampling_strategy' is a string, it needs to be one of"
118125
with pytest.raises(ValueError, match=error_regex):
119-
check_sampling_strategy("rnd", np.array([1, 2, 3]), "over-sampling")
126+
check_sampling_strategy(
127+
"rnd", get_classes_counts(np.array([1, 2, 3])), "over-sampling"
128+
)
120129

121130

122131
@pytest.mark.parametrize(
@@ -136,7 +145,9 @@ def test_check_sampling_strategy_error_wrong_string(
136145
),
137146
):
138147
check_sampling_strategy(
139-
sampling_strategy, np.array([1, 2, 3]), sampling_type
148+
sampling_strategy,
149+
get_classes_counts(np.array([1, 2, 3])),
150+
sampling_type,
140151
)
141152

142153

@@ -153,14 +164,18 @@ def test_sampling_strategy_class_target_unknown(
153164
):
154165
y = np.array([1] * 50 + [2] * 100 + [3] * 25)
155166
with pytest.raises(ValueError, match="are not present in the data."):
156-
check_sampling_strategy(sampling_strategy, y, sampling_method)
167+
check_sampling_strategy(
168+
sampling_strategy, get_classes_counts(y), sampling_method
169+
)
157170

158171

159172
def test_sampling_strategy_dict_error():
160173
y = np.array([1] * 50 + [2] * 100 + [3] * 25)
161174
sampling_strategy = {1: -100, 2: 50, 3: 25}
162175
with pytest.raises(ValueError, match="in a class cannot be negative."):
163-
check_sampling_strategy(sampling_strategy, y, "under-sampling")
176+
check_sampling_strategy(
177+
sampling_strategy, get_classes_counts(y), "under-sampling"
178+
)
164179
sampling_strategy = {1: 45, 2: 100, 3: 70}
165180
error_regex = (
166181
"With over-sampling methods, the number of samples in a"
@@ -169,7 +184,9 @@ def test_sampling_strategy_dict_error():
169184
" samples are asked."
170185
)
171186
with pytest.raises(ValueError, match=error_regex):
172-
check_sampling_strategy(sampling_strategy, y, "over-sampling")
187+
check_sampling_strategy(
188+
sampling_strategy, get_classes_counts(y), "over-sampling"
189+
)
173190

174191
error_regex = (
175192
"With under-sampling methods, the number of samples in a"
@@ -178,21 +195,27 @@ def test_sampling_strategy_dict_error():
178195
" are asked."
179196
)
180197
with pytest.raises(ValueError, match=error_regex):
181-
check_sampling_strategy(sampling_strategy, y, "under-sampling")
198+
check_sampling_strategy(
199+
sampling_strategy, get_classes_counts(y), "under-sampling"
200+
)
182201

183202

184203
@pytest.mark.parametrize("sampling_strategy", [-10, 10])
185204
def test_sampling_strategy_float_error_not_in_range(sampling_strategy):
186205
y = np.array([1] * 50 + [2] * 100)
187206
with pytest.raises(ValueError, match="it should be in the range"):
188-
check_sampling_strategy(sampling_strategy, y, "under-sampling")
207+
check_sampling_strategy(
208+
sampling_strategy, get_classes_counts(y), "under-sampling"
209+
)
189210

190211

191212
def test_sampling_strategy_float_error_not_binary():
192213
y = np.array([1] * 50 + [2] * 100 + [3] * 25)
193214
with pytest.raises(ValueError, match="the type of target is binary"):
194215
sampling_strategy = 0.5
195-
check_sampling_strategy(sampling_strategy, y, "under-sampling")
216+
check_sampling_strategy(
217+
sampling_strategy, get_classes_counts(y), "under-sampling"
218+
)
196219

197220

198221
@pytest.mark.parametrize(
@@ -202,7 +225,9 @@ def test_sampling_strategy_list_error_not_clean_sampling(sampling_method):
202225
y = np.array([1] * 50 + [2] * 100 + [3] * 25)
203226
with pytest.raises(ValueError, match="cannot be a list for samplers"):
204227
sampling_strategy = [1, 2, 3]
205-
check_sampling_strategy(sampling_strategy, y, sampling_method)
228+
check_sampling_strategy(
229+
sampling_strategy, get_classes_counts(y), sampling_method
230+
)
206231

207232

208233
def _sampling_strategy_func(y):
@@ -215,42 +240,87 @@ def _sampling_strategy_func(y):
215240
@pytest.mark.parametrize(
216241
"sampling_strategy, sampling_type, expected_sampling_strategy, target",
217242
[
218-
("auto", "under-sampling", {1: 25, 2: 25}, multiclass_target),
219-
("auto", "clean-sampling", {1: 25, 2: 25}, multiclass_target),
220-
("auto", "over-sampling", {1: 50, 3: 75}, multiclass_target),
221-
("all", "over-sampling", {1: 50, 2: 0, 3: 75}, multiclass_target),
222-
("all", "under-sampling", {1: 25, 2: 25, 3: 25}, multiclass_target),
223-
("all", "clean-sampling", {1: 25, 2: 25, 3: 25}, multiclass_target),
224-
("majority", "under-sampling", {2: 25}, multiclass_target),
225-
("majority", "clean-sampling", {2: 25}, multiclass_target),
226-
("minority", "over-sampling", {3: 75}, multiclass_target),
227-
("not minority", "over-sampling", {1: 50, 2: 0}, multiclass_target),
228-
("not minority", "under-sampling", {1: 25, 2: 25}, multiclass_target),
229-
("not minority", "clean-sampling", {1: 25, 2: 25}, multiclass_target),
230-
("not majority", "over-sampling", {1: 50, 3: 75}, multiclass_target),
231-
("not majority", "under-sampling", {1: 25, 3: 25}, multiclass_target),
232-
("not majority", "clean-sampling", {1: 25, 3: 25}, multiclass_target),
243+
("auto", "under-sampling", {1: 25, 2: 25}, multiclass_classes_counts),
244+
("auto", "clean-sampling", {1: 25, 2: 25}, multiclass_classes_counts),
245+
("auto", "over-sampling", {1: 50, 3: 75}, multiclass_classes_counts),
246+
(
247+
"all",
248+
"over-sampling",
249+
{1: 50, 2: 0, 3: 75},
250+
multiclass_classes_counts,
251+
),
252+
(
253+
"all",
254+
"under-sampling",
255+
{1: 25, 2: 25, 3: 25},
256+
multiclass_classes_counts,
257+
),
258+
(
259+
"all",
260+
"clean-sampling",
261+
{1: 25, 2: 25, 3: 25},
262+
multiclass_classes_counts,
263+
),
264+
("majority", "under-sampling", {2: 25}, multiclass_classes_counts),
265+
("majority", "clean-sampling", {2: 25}, multiclass_classes_counts),
266+
("minority", "over-sampling", {3: 75}, multiclass_classes_counts),
267+
(
268+
"not minority",
269+
"over-sampling",
270+
{1: 50, 2: 0},
271+
multiclass_classes_counts,
272+
),
273+
(
274+
"not minority",
275+
"under-sampling",
276+
{1: 25, 2: 25},
277+
multiclass_classes_counts,
278+
),
279+
(
280+
"not minority",
281+
"clean-sampling",
282+
{1: 25, 2: 25},
283+
multiclass_classes_counts,
284+
),
285+
(
286+
"not majority",
287+
"over-sampling",
288+
{1: 50, 3: 75},
289+
multiclass_classes_counts,
290+
),
291+
(
292+
"not majority",
293+
"under-sampling",
294+
{1: 25, 3: 25},
295+
multiclass_classes_counts,
296+
),
297+
(
298+
"not majority",
299+
"clean-sampling",
300+
{1: 25, 3: 25},
301+
multiclass_classes_counts,
302+
),
233303
(
234304
{1: 70, 2: 100, 3: 70},
235305
"over-sampling",
236306
{1: 20, 2: 0, 3: 45},
237-
multiclass_target,
307+
multiclass_classes_counts,
238308
),
239309
(
240310
{1: 30, 2: 45, 3: 25},
241311
"under-sampling",
242312
{1: 30, 2: 45, 3: 25},
243-
multiclass_target,
313+
multiclass_classes_counts,
244314
),
245-
([1], "clean-sampling", {1: 25}, multiclass_target),
315+
([1], "clean-sampling", {1: 25}, multiclass_classes_counts),
246316
(
247317
_sampling_strategy_func,
248318
"over-sampling",
249319
{1: 50, 2: 0, 3: 75},
250-
multiclass_target,
320+
multiclass_classes_counts,
251321
),
252-
(0.5, "over-sampling", {1: 25}, binary_target),
253-
(0.5, "under-sampling", {0: 50}, binary_target),
322+
(0.5, "over-sampling", {1: 25}, binary_classes_counts),
323+
(0.5, "under-sampling", {0: 50}, binary_classes_counts),
254324
],
255325
)
256326
def test_check_sampling_strategy(
@@ -271,23 +341,27 @@ def test_sampling_strategy_dict_over_sampling():
271341
r" the majority class \(class #2 -> 100\)"
272342
)
273343
with warns(UserWarning, expected_msg):
274-
check_sampling_strategy(sampling_strategy, y, "over-sampling")
344+
check_sampling_strategy(
345+
sampling_strategy, get_classes_counts(y), "over-sampling"
346+
)
275347

276348

277349
def test_sampling_strategy_callable_args():
278350
y = np.array([1] * 50 + [2] * 100 + [3] * 25)
279351
multiplier = {1: 1.5, 2: 1, 3: 3}
280352

281-
def sampling_strategy_func(y, multiplier):
353+
def sampling_strategy_func(classes_counts, multiplier):
282354
"""samples such that each class will be affected by the multiplier."""
283-
target_stats = Counter(y)
284355
return {
285356
key: int(values * multiplier[key])
286-
for key, values in target_stats.items()
357+
for key, values in classes_counts.items()
287358
}
288359

289360
sampling_strategy_ = check_sampling_strategy(
290-
sampling_strategy_func, y, "over-sampling", multiplier=multiplier
361+
sampling_strategy_func,
362+
get_classes_counts(y),
363+
"over-sampling",
364+
multiplier=multiplier,
291365
)
292366
assert sampling_strategy_ == {1: 25, 2: 0, 3: 50}
293367

@@ -314,11 +388,20 @@ def test_sampling_strategy_check_order(
314388
# dictionary is sorted. Refer to issue #428.
315389
y = np.array([1] * 50 + [2] * 100 + [3] * 25)
316390
sampling_strategy_ = check_sampling_strategy(
317-
sampling_strategy, y, sampling_type
391+
sampling_strategy, get_classes_counts(y), sampling_type
318392
)
319393
assert sampling_strategy_ == expected_result
320394

321395

396+
# FIXME: remove in 0.9
397+
def test_sampling_strategy_deprecation_array_target():
398+
# Check that we raise a FutureWarning when a target array is passed
399+
with pytest.warns(FutureWarning):
400+
sampling_strategy = "auto"
401+
check_sampling_strategy(
402+
sampling_strategy, binary_target, "under-sampling",
403+
)
404+
322405
def test_arrays_transformer_plain_list():
323406
X = np.array([[0, 0], [1, 1]])
324407
y = np.array([[0, 0], [1, 1]])

0 commit comments

Comments
 (0)