Commit 5be5670

AYY7 and glemaitre authored
FIX do not ignore target_names when output_dict=True in classification_report_imbalanced (#989)
Co-authored-by: Guillaume Lemaitre <[email protected]>
1 parent e9d120a · commit 5be5670

File tree: 3 files changed (+42, -6 lines)


doc/whats_new/v0.11.rst

Lines changed: 4 additions & 0 deletions
@@ -9,6 +9,10 @@ Changelog
 Bug fixes
 .........
 
+- Fix a bug in :func:`~imblearn.metrics.classification_report_imbalanced` where the
+  parameter `target_names` was not taken into account when `output_dict=True`.
+  :pr:`989` by :user:`AYY7 <AYY7>`.
+
 - :class:`~imblearn.over_sampling.SMOTENC` now handles mix types of data type such as
   `bool` and `pd.category` by delegating the conversion to scikit-learn encoder.
   :pr:`1002` by :user:`Guillaume Lemaitre <glemaitre>`.
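
In practical terms, the entry above means that when `output_dict=True` and `target_names` is passed, the per-class entries of the returned dictionary are keyed by those names instead of the raw labels. A minimal sketch of the fixed behaviour, using toy data rather than anything from this commit:

# Minimal sketch of the fixed behaviour (toy data; not part of the commit).
from imblearn.metrics import classification_report_imbalanced

y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 1, 2, 2, 2]

report = classification_report_imbalanced(
    y_true,
    y_pred,
    target_names=["setosa", "versicolor", "virginica"],
    output_dict=True,
)

# With this fix the per-class entries are keyed by the given names rather than
# by the integer labels 0, 1, 2; the avg_* aggregates and total_support are
# unchanged.
print(sorted(report))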

imblearn/metrics/_classification.py

Lines changed: 1 addition & 1 deletion
@@ -1038,7 +1038,7 @@ class 2 1.00 0.67 1.00 0.80 0.82 0.64\
         report_dict_label[headers[-1]] = support[i]
         report += fmt % tuple(values)
 
-        report_dict[label] = report_dict_label
+        report_dict[target_names[i]] = report_dict_label
 
     report += "\n"
 
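
The one-line change above sits inside the per-label loop of `classification_report_imbalanced`: the per-class sub-dictionary was stored under the raw `label`, so a user-supplied `target_names` never showed up as a key of the returned dictionary. A simplified sketch of that loop pattern, with placeholder metric values instead of the real computation:

# Simplified sketch of the patched loop (placeholder metric values; not the
# library's actual implementation).
labels = [0, 1, 2]
target_names = ["setosa", "versicolor", "virginica"]

report_dict = {}
for i, label in enumerate(labels):
    report_dict_label = {"pre": 1.0, "rec": 1.0, "spe": 1.0, "f1": 1.0,
                         "geo": 1.0, "iba": 1.0, "sup": 2}
    # Before the fix: report_dict[label] = report_dict_label  -> keys 0, 1, 2
    report_dict[target_names[i]] = report_dict_label  # keys follow target_names

print(list(report_dict))  # ['setosa', 'versicolor', 'virginica']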

imblearn/metrics/tests/test_classification.py

Lines changed: 37 additions & 5 deletions
@@ -459,7 +459,7 @@ def test_iba_error_y_score_prob_error(score_loss):
         aps(y_true, y_pred)
 
 
-def test_classification_report_imbalanced_dict():
+def test_classification_report_imbalanced_dict_with_target_names():
     iris = datasets.load_iris()
     y_true, y_pred, _ = make_prediction(dataset=iris, binary=False)
 
@@ -471,12 +471,44 @@ def test_classification_report_imbalanced_dict():
         output_dict=True,
     )
     outer_keys = set(report.keys())
-    inner_keys = set(report[0].keys())
+    inner_keys = set(report["setosa"].keys())
 
     expected_outer_keys = {
-        0,
-        1,
-        2,
+        "setosa",
+        "versicolor",
+        "virginica",
+        "avg_pre",
+        "avg_rec",
+        "avg_spe",
+        "avg_f1",
+        "avg_geo",
+        "avg_iba",
+        "total_support",
+    }
+    expected_inner_keys = {"spe", "f1", "sup", "rec", "geo", "iba", "pre"}
+
+    assert outer_keys == expected_outer_keys
+    assert inner_keys == expected_inner_keys
+
+
+def test_classification_report_imbalanced_dict_without_target_names():
+    iris = datasets.load_iris()
+    y_true, y_pred, _ = make_prediction(dataset=iris, binary=False)
+    print(iris.target_names)
+    report = classification_report_imbalanced(
+        y_true,
+        y_pred,
+        labels=np.arange(len(iris.target_names)),
+        output_dict=True,
+    )
+    print(report.keys())
+    outer_keys = set(report.keys())
+    inner_keys = set(report["0"].keys())
+
+    expected_outer_keys = {
+        "0",
+        "1",
+        "2",
         "avg_pre",
         "avg_rec",
         "avg_spe",