Skip to content

Commit f17107e

Browse files
authored
FIX: raise an error when multilabel does not encode multiclass (#490)
1 parent 61cace9 commit f17107e

File tree

11 files changed

+38
-54
lines changed

11 files changed

+38
-54
lines changed

doc/whats_new.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
Release history
55
===============
66

7-
.. include:: whats_new/v0.0.4.rst
7+
.. include:: whats_new/v0.4.rst
88

9-
.. include:: whats_new/v0.0.3.rst
9+
.. include:: whats_new/v0.3.rst
1010

11-
.. include:: whats_new/v0.0.2.rst
11+
.. include:: whats_new/v0.2.rst
1212

13-
.. include:: whats_new/v0.0.1.rst
13+
.. include:: whats_new/v0.1.rst
File renamed without changes.
File renamed without changes.
File renamed without changes.

doc/whats_new/v0.0.4.rst renamed to doc/whats_new/v0.4.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ Bug fixes
1414
deviation.
1515
By :user:`Guillaume Lemaitre <glemaitre>` in :issue:`491`.
1616

17+
- Raise an error when passing target which is not supported, i.e. regression
18+
target or multilabel targets. Imbalanced-learn does not support this case.
19+
By :user:`Guillaume Lemaitre <glemaitre>` in :issue:`490`.
20+
1721
Version 0.4
1822
===========
1923

imblearn/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from sklearn.externals import six
1616
from sklearn.preprocessing import label_binarize
1717
from sklearn.utils import check_X_y
18+
from sklearn.utils.multiclass import check_classification_targets
1819

1920
from .utils import check_sampling_strategy, check_target_type
2021
from .utils.deprecation import deprecate_parameter
@@ -77,6 +78,7 @@ def fit_resample(self, X, y):
7778
"""
7879
self._deprecate_ratio()
7980

81+
check_classification_targets(y)
8082
X, y, binarize_y = self._check_X_y(X, y)
8183

8284
self.sampling_strategy_ = check_sampling_strategy(

imblearn/over_sampling/tests/test_smote_nc.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,12 @@ def test_smotenc_check_target_type():
127127
y = np.linspace(0, 1, 30)
128128
smote = SMOTENC(categorical_features=categorical_features,
129129
random_state=0)
130-
with pytest.warns(UserWarning, match='should be of types'):
130+
with pytest.raises(ValueError, match="Unknown label type: 'continuous'"):
131+
smote.fit_resample(X, y)
132+
rng = np.random.RandomState(42)
133+
y = rng.randint(2, size=(20, 3))
134+
with pytest.raises(ValueError, match="'y' should encode the multiclass"):
131135
smote.fit_resample(X, y)
132-
133136

134137
def test_smotenc_samplers_one_label():
135138
X, _, categorical_features = data_heterogneous_unordered()

imblearn/utils/_validation.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -86,17 +86,14 @@ def check_target_type(y, indicate_one_vs_all=False):
8686
8787
"""
8888
type_y = type_of_target(y)
89-
if type_y not in TARGET_KIND:
90-
# FIXME: perfectly we should raise an error but the sklearn API does
91-
# not allow for it
92-
warnings.warn("'y' should be of types {} only. Got {} instead.".format(
93-
TARGET_KIND, type_of_target(y)))
94-
95-
if indicate_one_vs_all:
96-
return (y.argmax(axis=1) if type_y == 'multilabel-indicator' else y,
97-
type_y == 'multilabel-indicator')
98-
else:
99-
return y.argmax(axis=1) if type_y == 'multilabel-indicator' else y
89+
if type_y == 'multilabel-indicator':
90+
if np.any(y.sum(axis=1) > 1):
91+
raise ValueError(
92+
"When 'y' corresponds to '{}', 'y' should encode the "
93+
"multiclass (a single 1 by row).".format(type_y))
94+
y = y.argmax(axis=1)
95+
96+
return (y, type_y == 'multilabel-indicator') if indicate_one_vs_all else y
10097

10198

10299
def _sampling_strategy_all(y, sampling_type):

imblearn/utils/estimator_checks.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@
3333
from imblearn.over_sampling import SMOTE
3434
from imblearn.under_sampling import NearMiss, ClusterCentroids
3535

36-
from imblearn.utils.testing import warns
37-
3836
DONT_SUPPORT_RATIO = ['SVMSMOTE', 'BorderlineSMOTE']
3937
SUPPORT_STRING = ['RandomUnderSampler', 'RandomOverSampler']
4038
HAVE_SAMPLE_INDICES = [
@@ -54,7 +52,6 @@ def monkey_patch_check_dtype_object(name, estimator_orig):
5452
X = rng.rand(40, 10).astype(object)
5553
y = np.array([0] * 10 + [1] * 30, dtype=np.int)
5654
estimator = clone(estimator_orig)
57-
5855
estimator.fit(X, y)
5956

6057
try:
@@ -123,14 +120,20 @@ def check_estimator(Estimator, run_sampler_tests=True):
123120

124121

125122
def check_target_type(name, Estimator):
123+
# should raise warning if the target is continuous (we cannot raise error)
126124
X = np.random.random((20, 2))
127125
y = np.linspace(0, 1, 20)
128126
estimator = Estimator()
129127
# FIXME: in 0.6 set the random_state for all
130128
if name not in DONT_HAVE_RANDOM_STATE:
131129
set_random_state(estimator)
132-
with warns(UserWarning, match='should be of types'):
133-
estimator.fit(X, y)
130+
with pytest.raises(ValueError, match="Unknown label type: 'continuous'"):
131+
estimator.fit_resample(X, y)
132+
# if the target is multilabel then we should raise an error
133+
rng = np.random.RandomState(42)
134+
y = rng.randint(2, size=(20, 3))
135+
with pytest.raises(ValueError, match="'y' should encode the multiclass"):
136+
estimator.fit_resample(X, y)
134137

135138

136139
def check_samplers_one_label(name, Sampler):
@@ -139,7 +142,7 @@ def check_samplers_one_label(name, Sampler):
139142
X = np.random.random((20, 2))
140143
y = np.zeros(20)
141144
try:
142-
sampler.fit(X, y)
145+
sampler.fit_resample(X, y)
143146
except ValueError as e:
144147
if 'class' not in repr(e):
145148
print(error_string_fit, Sampler, e)
@@ -157,15 +160,15 @@ def check_samplers_fit(name, Sampler):
157160
sampler = Sampler()
158161
X = np.random.random((30, 2))
159162
y = np.array([1] * 20 + [0] * 10)
160-
sampler.fit(X, y)
163+
sampler.fit_resample(X, y)
161164
assert hasattr(sampler, 'sampling_strategy_'), \
162165
"No fitted attribute sampling_strategy_"
163166

164167

165168
def check_samplers_fit_resample(name, Sampler):
166169
sampler = Sampler()
167170
X, y = make_classification(n_samples=1000, n_classes=3, n_informative=4,
168-
weights=[0.2, 0.3, 0.5], random_state=0)
171+
weights=[0.2, 0.3, 0.5], random_state=0)
169172
target_stats = Counter(y)
170173
X_res, y_res = sampler.fit_resample(X, y)
171174
if isinstance(sampler, BaseOverSampler):

imblearn/utils/tests/test_estimator_checks.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from sklearn.base import BaseEstimator
55
from sklearn.utils import check_X_y
6+
from sklearn.utils.multiclass import check_classification_targets
67

78
from imblearn.base import BaseSampler
89
from imblearn.utils.estimator_checks import check_estimator
@@ -17,6 +18,8 @@ def fit(self, X, y):
1718
return self
1819

1920
def fit_resample(self, X, y):
21+
check_classification_targets(y)
22+
self.fit(X, y)
2023
return X, y
2124

2225

@@ -27,36 +30,15 @@ def fit(self, X, y):
2730
X, y = check_X_y(X, y, accept_sparse=True)
2831
return self
2932

30-
def fit_resample(self, X, y):
31-
self.fit(X, y)
32-
return X, y
33-
3433

3534
class NoAcceptingSparseSampler(BaseBadSampler):
3635
"""Sampler which does not accept sparse matrix."""
3736
def fit(self, X, y):
38-
X, y = check_X_y(X, y, accept_sparse=False)
39-
y, _ = check_target_type(y, indicate_one_vs_all=True)
40-
self.sampling_strategy_ = 'sampling_strategy_'
41-
return self
42-
43-
def fit_resample(self, X, y):
44-
self.fit(X, y)
45-
return X, y
46-
47-
48-
class NotTransformingTargetOvR(BaseBadSampler):
49-
"""Sampler which does not transform OvR enconding."""
50-
def fit(self, X, y):
51-
X, y = check_X_y(X, y, accept_sparse=True)
5237
y, _ = check_target_type(y, indicate_one_vs_all=True)
38+
X, y = check_X_y(X, y, accept_sparse=False)
5339
self.sampling_strategy_ = 'sampling_strategy_'
5440
return self
5541

56-
def fit_resample(self, X, y):
57-
self.fit(X, y)
58-
return X, y
59-
6042

6143
class NotPreservingDtypeSampler(BaseSampler):
6244
_sampling_type = 'bypass'
@@ -72,7 +54,6 @@ def _fit_resample(self, X, y):
7254
[(BaseBadSampler, AssertionError, "TypeError not raised by fit"),
7355
(NotFittedSampler, AssertionError, "No fitted attribute"),
7456
(NoAcceptingSparseSampler, TypeError, "A sparse matrix was passed"),
75-
(NotTransformingTargetOvR, ValueError, "bad input shape"),
7657
(NotPreservingDtypeSampler, AssertionError, "X dytype is not preserved")]
7758
)
7859
def test_check_estimator(Estimator, err_type, err_msg):

0 commit comments

Comments
 (0)