Skip to content

Commit 0dbacc7

Browse files
klizterglemaitre
authored andcommitted
FIX bug in classification_imbalanced_report where y_pred and y_true were inversed (#397)
1 parent d482829 commit 0dbacc7

File tree

3 files changed

+40
-35
lines changed

3 files changed

+40
-35
lines changed

doc/whats_new/v0.0.4.rst

+4
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ Enhancement
1818
Bug fixes
1919
.........
2020

21+
- Fix bug in :func:`metrics.classification_report_imbalanced` for which
22+
`y_pred` and `y_true` where inversed. :issue:`394` by :user:`Ole Silvig
23+
<klizter>.`
24+
2125
- Fix bug in ADASYN to consider only samples from the current class when
2226
generating new samples. :issue:`354` by :user:`Guillaume Lemaitre
2327
<glemaitre>`.

imblearn/metrics/classification.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -815,14 +815,14 @@ def classification_report_imbalanced(y_true,
815815
pre rec spe f1 geo iba\
816816
sup
817817
<BLANKLINE>
818-
class 0 0.50 1.00 0.75 0.67 0.71 0.48\
818+
class 0 0.50 1.00 0.75 0.67 0.87 0.77\
819819
1
820820
class 1 0.00 0.00 0.75 0.00 0.00 0.00\
821821
1
822-
class 2 1.00 0.67 1.00 0.80 0.82 0.69\
822+
class 2 1.00 0.67 1.00 0.80 0.82 0.64\
823823
3
824824
<BLANKLINE>
825-
avg / total 0.70 0.60 0.90 0.61 0.63 0.51\
825+
avg / total 0.70 0.60 0.90 0.61 0.66 0.54\
826826
5
827827
<BLANKLINE>
828828
@@ -867,17 +867,17 @@ class 2 1.00 0.67 1.00 0.80 0.82 0.69\
867867
sample_weight=sample_weight)
868868
# Geometric mean
869869
geo_mean = geometric_mean_score(
870-
y_pred,
871870
y_true,
871+
y_pred,
872872
labels=labels,
873873
average=None,
874874
sample_weight=sample_weight)
875875
# Index balanced accuracy
876876
iba_gmean = make_index_balanced_accuracy(
877877
alpha=alpha, squared=True)(geometric_mean_score)
878878
iba = iba_gmean(
879-
y_pred,
880879
y_true,
880+
y_pred,
881881
labels=labels,
882882
average=None,
883883
sample_weight=sample_weight)

imblearn/metrics/tests/test_classification.py

+31-30
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# coding: utf-8
2+
13
"""Testing the metric for classification with imbalanced dataset"""
24
# Authors: Guillaume Lemaitre <[email protected]>
35
# Christos Aridas
@@ -221,7 +223,6 @@ def test_geometric_mean_multiclass():
221223

222224
y_true = [0, 0, 0, 0]
223225
y_pred = [1, 1, 1, 1]
224-
print(geometric_mean_score(y_true, y_pred))
225226
assert_allclose(geometric_mean_score(y_true, y_pred), 0.0, rtol=R_TOL)
226227

227228
cor = 0.001
@@ -316,9 +317,9 @@ def test_classification_report_imbalanced_multiclass():
316317

317318
# print classification report with class names
318319
expected_report = ('pre rec spe f1 geo iba sup setosa 0.83 0.79 0.92 '
319-
'0.81 0.86 0.74 24 versicolor 0.33 0.10 0.86 0.15 '
320-
'0.44 0.19 31 virginica 0.42 0.90 0.55 0.57 0.63 '
321-
'0.37 20 avg / total 0.51 0.53 0.80 0.47 0.62 0.41 75')
320+
'0.81 0.85 0.72 24 versicolor 0.33 0.10 0.86 0.15 '
321+
'0.29 0.08 31 virginica 0.42 0.90 0.55 0.57 0.70 '
322+
'0.51 20 avg / total 0.51 0.53 0.80 0.47 0.58 0.40 75')
322323

323324
report = classification_report_imbalanced(
324325
y_true,
@@ -328,9 +329,9 @@ def test_classification_report_imbalanced_multiclass():
328329
assert _format_report(report) == expected_report
329330
# print classification report with label detection
330331
expected_report = ('pre rec spe f1 geo iba sup 0 0.83 0.79 0.92 0.81 '
331-
'0.86 0.74 24 1 0.33 0.10 0.86 0.15 0.44 0.19 31 2 '
332-
'0.42 0.90 0.55 0.57 0.63 0.37 20 avg / total 0.51 '
333-
'0.53 0.80 0.47 0.62 0.41 75')
332+
'0.85 0.72 24 1 0.33 0.10 0.86 0.15 0.29 0.08 31 '
333+
'2 0.42 0.90 0.55 0.57 0.70 0.51 20 avg / total '
334+
'0.51 0.53 0.80 0.47 0.58 0.40 75')
334335

335336
report = classification_report_imbalanced(y_true, y_pred)
336337
assert _format_report(report) == expected_report
@@ -342,11 +343,11 @@ def test_classification_report_imbalanced_multiclass_with_digits():
342343

343344
# print classification report with class names
344345
expected_report = ('pre rec spe f1 geo iba sup setosa 0.82609 0.79167 '
345-
'0.92157 0.80851 0.86409 0.74085 24 versicolor '
346-
'0.33333 0.09677 0.86364 0.15000 0.43809 0.18727 31 '
347-
'virginica 0.41860 0.90000 0.54545 0.57143 0.62645 '
348-
'0.37208 20 avg / total 0.51375 0.53333 0.79733 '
349-
'0.47310 0.62464 0.41370 75')
346+
'0.92157 0.80851 0.85415 0.72010 24 versicolor '
347+
'0.33333 0.09677 0.86364 0.15000 0.28910 0.07717 '
348+
'31 virginica 0.41860 0.90000 0.54545 0.57143 0.70065 '
349+
'0.50831 20 avg / total 0.51375 0.53333 0.79733 '
350+
'0.47310 0.57966 0.39788 75')
350351
report = classification_report_imbalanced(
351352
y_true,
352353
y_pred,
@@ -356,9 +357,9 @@ def test_classification_report_imbalanced_multiclass_with_digits():
356357
assert _format_report(report) == expected_report
357358
# print classification report with label detection
358359
expected_report = ('pre rec spe f1 geo iba sup 0 0.83 0.79 0.92 0.81 '
359-
'0.86 0.74 24 1 0.33 0.10 0.86 0.15 0.44 0.19 31 2 '
360-
'0.42 0.90 0.55 0.57 0.63 0.37 20 avg / total 0.51 '
361-
'0.53 0.80 0.47 0.62 0.41 75')
360+
'0.85 0.72 24 1 0.33 0.10 0.86 0.15 0.29 0.08 31 '
361+
'2 0.42 0.90 0.55 0.57 0.70 0.51 20 avg / total 0.51 '
362+
'0.53 0.80 0.47 0.58 0.40 75')
362363
report = classification_report_imbalanced(y_true, y_pred)
363364
assert _format_report(report) == expected_report
364365

@@ -369,17 +370,17 @@ def test_classification_report_imbalanced_multiclass_with_string_label():
369370
y_true = np.array(["blue", "green", "red"])[y_true]
370371
y_pred = np.array(["blue", "green", "red"])[y_pred]
371372

372-
expected_report = ('pre rec spe f1 geo iba sup blue 0.83 0.79 0.92 '
373-
'0.81 0.86 0.74 24 green 0.33 0.10 0.86 0.15 0.44 '
374-
'0.19 31 red 0.42 0.90 0.55 0.57 0.63 0.37 20 '
375-
'avg / total 0.51 0.53 0.80 0.47 0.62 0.41 75')
373+
expected_report = ('pre rec spe f1 geo iba sup blue 0.83 0.79 0.92 0.81 '
374+
'0.85 0.72 24 green 0.33 0.10 0.86 0.15 0.29 0.08 31 '
375+
'red 0.42 0.90 0.55 0.57 0.70 0.51 20 avg / total '
376+
'0.51 0.53 0.80 0.47 0.58 0.40 75')
376377
report = classification_report_imbalanced(y_true, y_pred)
377378
assert _format_report(report) == expected_report
378379

379-
expected_report = ('pre rec spe f1 geo iba sup a 0.83 0.79 0.92 0.81 '
380-
'0.86 0.74 24 b 0.33 0.10 0.86 0.15 0.44 0.19 31 '
381-
'c 0.42 0.90 0.55 0.57 0.63 0.37 20 avg / total '
382-
'0.51 0.53 0.80 0.47 0.62 0.41 75')
380+
expected_report = ('pre rec spe f1 geo iba sup a 0.83 0.79 0.92 0.81 0.85 '
381+
'0.72 24 b 0.33 0.10 0.86 0.15 0.29 0.08 31 c 0.42 '
382+
'0.90 0.55 0.57 0.70 0.51 20 avg / total 0.51 0.53 '
383+
'0.80 0.47 0.58 0.40 75')
383384
report = classification_report_imbalanced(
384385
y_true, y_pred, target_names=["a", "b", "c"])
385386
assert _format_report(report) == expected_report
@@ -392,10 +393,10 @@ def test_classification_report_imbalanced_multiclass_with_unicode_label():
392393
y_true = labels[y_true]
393394
y_pred = labels[y_pred]
394395

395-
expected_report = (u'pre rec spe f1 geo iba sup blue\xa2 0.83 0.79 '
396-
u'0.92 0.81 0.86 0.74 24 green\xa2 0.33 0.10 0.86 '
397-
u'0.15 0.44 0.19 31 red\xa2 0.42 0.90 0.55 0.57 0.63 '
398-
u'0.37 20 avg / total 0.51 0.53 0.80 0.47 0.62 0.41 75')
396+
expected_report = (u'pre rec spe f1 geo iba sup blue¢ 0.83 0.79 0.92 0.81 '
397+
u'0.85 0.72 24 green¢ 0.33 0.10 0.86 0.15 0.29 0.08 31 '
398+
u'red¢ 0.42 0.90 0.55 0.57 0.70 0.51 20 avg / total '
399+
u'0.51 0.53 0.80 0.47 0.58 0.40 75')
399400
if np_version[:3] < (1, 7, 0):
400401
with raises(RuntimeError, match="NumPy < 1.7.0"):
401402
classification_report_imbalanced(y_true, y_pred)
@@ -412,9 +413,9 @@ def test_classification_report_imbalanced_multiclass_with_long_string_label():
412413
y_pred = labels[y_pred]
413414

414415
expected_report = ('pre rec spe f1 geo iba sup blue 0.83 0.79 0.92 0.81 '
415-
'0.86 0.74 24 greengreengreengreengreen 0.33 0.10 '
416-
'0.86 0.15 0.44 0.19 31 red 0.42 0.90 0.55 0.57 0.63 '
417-
'0.37 20 avg / total 0.51 0.53 0.80 0.47 0.62 0.41 75')
416+
'0.85 0.72 24 greengreengreengreengreen 0.33 0.10 '
417+
'0.86 0.15 0.29 0.08 31 red 0.42 0.90 0.55 0.57 0.70 '
418+
'0.51 20 avg / total 0.51 0.53 0.80 0.47 0.58 0.40 75')
418419

419420
report = classification_report_imbalanced(y_true, y_pred)
420421
assert _format_report(report) == expected_report

0 commit comments

Comments
 (0)