Skip to content

Commit b6e3a9a

Browse files
chkoarglemaitre
authored andcommitted
BUG Accept column vectors when having binary or multiclass targets (#673)
1 parent e766a16 commit b6e3a9a

File tree

3 files changed

+20
-0
lines changed

3 files changed

+20
-0
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ var/
2222
*.egg-info/
2323
.installed.cfg
2424
*.egg
25+
Pipfile
26+
Pipfile.lock
2527

2628
# PyInstaller
2729
# Usually these files are written by a python script from a template

imblearn/utils/_validation.py

+3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from sklearn.base import clone
1313
from sklearn.neighbors._base import KNeighborsMixin
1414
from sklearn.neighbors import NearestNeighbors
15+
from sklearn.utils import column_or_1d
1516
from sklearn.utils.multiclass import type_of_target
1617

1718
from ..exceptions import raise_isinstance_error
@@ -96,6 +97,8 @@ def check_target_type(y, indicate_one_vs_all=False):
9697
"multioutput targets are not supported."
9798
)
9899
y = y.argmax(axis=1)
100+
else:
101+
y = column_or_1d(y)
99102

100103
return (y, type_y == "multilabel-indicator") if indicate_one_vs_all else y
101104

imblearn/utils/estimator_checks.py

+15
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def _yield_sampler_checks(name, Estimator):
4444
yield check_samplers_multiclass_ova
4545
yield check_samplers_preserve_dtype
4646
yield check_samplers_sample_indices
47+
yield check_samplers_2d_target
4748

4849

4950
def _yield_classifier_checks(name, Estimator):
@@ -283,6 +284,20 @@ def check_samplers_multiclass_ova(name, Sampler):
283284
assert_allclose(y_res, y_res_ova.argmax(axis=1))
284285

285286

287+
def check_samplers_2d_target(name, Sampler):
288+
X, y = make_classification(
289+
n_samples=100,
290+
n_classes=3,
291+
n_informative=4,
292+
weights=[0.2, 0.3, 0.5],
293+
random_state=0,
294+
)
295+
296+
y = y.reshape(-1, 1) # Make the target 2d
297+
sampler = Sampler()
298+
sampler.fit_resample(X, y)
299+
300+
286301
def check_samplers_preserve_dtype(name, Sampler):
287302
X, y = make_classification(
288303
n_samples=1000,

0 commit comments

Comments
 (0)