Skip to content

Commit 58acdf2

Browse files
committed
iter
1 parent d4aabf8 commit 58acdf2

File tree

5 files changed

+28
-10
lines changed

5 files changed

+28
-10
lines changed

imblearn/base.py

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def fit(self, X, y):
5353
if (not dask_collection or
5454
(dask_collection and self.validate_if_dask_collection)):
5555
X, y, _ = self._check_X_y(X, y)
56+
5657
self._classes_counts = get_classes_counts(y)
5758
self.sampling_strategy_ = check_sampling_strategy(
5859
self.sampling_strategy, self._classes_counts, self._sampling_type

imblearn/datasets/_imbalance.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99

1010
from ..under_sampling import RandomUnderSampler
1111
from ..utils import check_sampling_strategy
12-
from ..utils._validation import _deprecate_positional_args
12+
from ..utils._validation import (
13+
_deprecate_positional_args,
14+
get_classes_counts,
15+
)
1316

1417

1518
@_deprecate_positional_args
@@ -87,11 +90,11 @@ def make_imbalance(
8790
>>> print('Distribution after imbalancing: {}'.format(Counter(y_res)))
8891
Distribution after imbalancing: Counter({2: 30, 1: 20, 0: 10})
8992
"""
90-
target_stats = Counter(y)
93+
target_stats = get_classes_counts(y)
9194
# restrict ratio to be a dict or a callable
9295
if isinstance(sampling_strategy, dict) or callable(sampling_strategy):
9396
sampling_strategy_ = check_sampling_strategy(
94-
sampling_strategy, y, "under-sampling", **kwargs
97+
sampling_strategy, target_stats, "under-sampling", **kwargs
9598
)
9699
else:
97100
raise ValueError(

imblearn/ensemble/_bagging.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
from ..utils import Substitution, check_target_type, check_sampling_strategy
1919
from ..utils._docstring import _n_jobs_docstring
2020
from ..utils._docstring import _random_state_docstring
21-
from ..utils._validation import _deprecate_positional_args
21+
from ..utils._validation import (
22+
_deprecate_positional_args,
23+
get_classes_counts,
24+
)
2225

2326

2427
@Substitution(
@@ -216,11 +219,12 @@ def __init__(
216219

217220
def _validate_y(self, y):
218221
y_encoded = super()._validate_y(y)
222+
classes_counts = get_classes_counts(y)
219223
if isinstance(self.sampling_strategy, dict):
220224
self._sampling_strategy = {
221225
np.where(self.classes_ == key)[0][0]: value
222226
for key, value in check_sampling_strategy(
223-
self.sampling_strategy, y, 'under-sampling',
227+
self.sampling_strategy, classes_counts, 'under-sampling',
224228
).items()
225229
}
226230
else:

imblearn/ensemble/_easy_ensemble.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
from ..utils import Substitution, check_target_type, check_sampling_strategy
1818
from ..utils._docstring import _n_jobs_docstring
1919
from ..utils._docstring import _random_state_docstring
20-
from ..utils._validation import _deprecate_positional_args
20+
from ..utils._validation import (
21+
_deprecate_positional_args,
22+
get_classes_counts,
23+
)
2124
from ..pipeline import Pipeline
2225

2326
MAX_INT = np.iinfo(np.int32).max
@@ -156,11 +159,14 @@ def __init__(
156159

157160
def _validate_y(self, y):
158161
y_encoded = super()._validate_y(y)
162+
classes_counts = get_classes_counts(y)
159163
if isinstance(self.sampling_strategy, dict):
160164
self._sampling_strategy = {
161165
np.where(self.classes_ == key)[0][0]: value
162166
for key, value in check_sampling_strategy(
163-
self.sampling_strategy, y, 'under-sampling',
167+
self.sampling_strategy,
168+
classes_counts,
169+
"under-sampling",
164170
).items()
165171
}
166172
else:

imblearn/ensemble/_forest.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,11 @@
3333
from ..utils import Substitution
3434
from ..utils._docstring import _n_jobs_docstring
3535
from ..utils._docstring import _random_state_docstring
36-
from ..utils._validation import check_sampling_strategy
37-
from ..utils._validation import _deprecate_positional_args
36+
from ..utils._validation import (
37+
check_sampling_strategy,
38+
_deprecate_positional_args,
39+
get_classes_counts,
40+
)
3841

3942
MAX_INT = np.iinfo(np.int32).max
4043

@@ -457,10 +460,11 @@ def fit(self, X, y, sample_weight=None):
457460
y_encoded = np.ascontiguousarray(y_encoded, dtype=DOUBLE)
458461

459462
if isinstance(self.sampling_strategy, dict):
463+
classes_counts = get_classes_counts(y)
460464
self._sampling_strategy = {
461465
np.where(self.classes_[0] == key)[0][0]: value
462466
for key, value in check_sampling_strategy(
463-
self.sampling_strategy, y, 'under-sampling',
467+
self.sampling_strategy, classes_counts, 'under-sampling',
464468
).items()
465469
}
466470
else:

0 commit comments

Comments
 (0)