1
+ # coding: utf-8
2
+
1
3
"""Testing the metric for classification with imbalanced dataset"""
2
4
# Authors: Guillaume Lemaitre <[email protected] >
3
5
# Christos Aridas
@@ -221,7 +223,6 @@ def test_geometric_mean_multiclass():
221
223
222
224
y_true = [0 , 0 , 0 , 0 ]
223
225
y_pred = [1 , 1 , 1 , 1 ]
224
- print (geometric_mean_score (y_true , y_pred ))
225
226
assert_allclose (geometric_mean_score (y_true , y_pred ), 0.0 , rtol = R_TOL )
226
227
227
228
cor = 0.001
@@ -316,9 +317,9 @@ def test_classification_report_imbalanced_multiclass():
316
317
317
318
# print classification report with class names
318
319
expected_report = ('pre rec spe f1 geo iba sup setosa 0.83 0.79 0.92 '
319
- '0.81 0.86 0.74 24 versicolor 0.33 0.10 0.86 0.15 '
320
- '0.44 0.19 31 virginica 0.42 0.90 0.55 0.57 0.63 '
321
- '0.37 20 avg / total 0.51 0.53 0.80 0.47 0.62 0.41 75' )
320
+ '0.81 0.85 0.72 24 versicolor 0.33 0.10 0.86 0.15 '
321
+ '0.29 0.08 31 virginica 0.42 0.90 0.55 0.57 0.70 '
322
+ '0.51 20 avg / total 0.51 0.53 0.80 0.47 0.58 0.40 75' )
322
323
323
324
report = classification_report_imbalanced (
324
325
y_true ,
@@ -328,9 +329,9 @@ def test_classification_report_imbalanced_multiclass():
328
329
assert _format_report (report ) == expected_report
329
330
# print classification report with label detection
330
331
expected_report = ('pre rec spe f1 geo iba sup 0 0.83 0.79 0.92 0.81 '
331
- '0.86 0.74 24 1 0.33 0.10 0.86 0.15 0.44 0.19 31 2 '
332
- '0.42 0.90 0.55 0.57 0.63 0.37 20 avg / total 0.51 '
333
- '0.53 0.80 0.47 0.62 0.41 75' )
332
+ '0.85 0.72 24 1 0.33 0.10 0.86 0.15 0.29 0.08 31 '
333
+ '2 0.42 0.90 0.55 0.57 0.70 0.51 20 avg / total '
334
+ '0.51 0. 53 0.80 0.47 0.58 0.40 75' )
334
335
335
336
report = classification_report_imbalanced (y_true , y_pred )
336
337
assert _format_report (report ) == expected_report
@@ -342,11 +343,11 @@ def test_classification_report_imbalanced_multiclass_with_digits():
342
343
343
344
# print classification report with class names
344
345
expected_report = ('pre rec spe f1 geo iba sup setosa 0.82609 0.79167 '
345
- '0.92157 0.80851 0.86409 0.74085 24 versicolor '
346
- '0.33333 0.09677 0.86364 0.15000 0.43809 0.18727 31 '
347
- 'virginica 0.41860 0.90000 0.54545 0.57143 0.62645 '
348
- '0.37208 20 avg / total 0.51375 0.53333 0.79733 '
349
- '0.47310 0.62464 0.41370 75' )
346
+ '0.92157 0.80851 0.85415 0.72010 24 versicolor '
347
+ '0.33333 0.09677 0.86364 0.15000 0.28910 0.07717 '
348
+ '31 virginica 0.41860 0.90000 0.54545 0.57143 0.70065 '
349
+ '0.50831 20 avg / total 0.51375 0.53333 0.79733 '
350
+ '0.47310 0.57966 0.39788 75' )
350
351
report = classification_report_imbalanced (
351
352
y_true ,
352
353
y_pred ,
@@ -356,9 +357,9 @@ def test_classification_report_imbalanced_multiclass_with_digits():
356
357
assert _format_report (report ) == expected_report
357
358
# print classification report with label detection
358
359
expected_report = ('pre rec spe f1 geo iba sup 0 0.83 0.79 0.92 0.81 '
359
- '0.86 0.74 24 1 0.33 0.10 0.86 0.15 0.44 0.19 31 2 '
360
- '0.42 0.90 0.55 0.57 0.63 0.37 20 avg / total 0.51 '
361
- '0.53 0.80 0.47 0.62 0.41 75' )
360
+ '0.85 0.72 24 1 0.33 0.10 0.86 0.15 0.29 0.08 31 '
361
+ '2 0.42 0.90 0.55 0.57 0.70 0.51 20 avg / total 0.51 '
362
+ '0.53 0.80 0.47 0.58 0.40 75' )
362
363
report = classification_report_imbalanced (y_true , y_pred )
363
364
assert _format_report (report ) == expected_report
364
365
@@ -369,17 +370,17 @@ def test_classification_report_imbalanced_multiclass_with_string_label():
369
370
y_true = np .array (["blue" , "green" , "red" ])[y_true ]
370
371
y_pred = np .array (["blue" , "green" , "red" ])[y_pred ]
371
372
372
- expected_report = ('pre rec spe f1 geo iba sup blue 0.83 0.79 0.92 '
373
- '0.81 0.86 0.74 24 green 0.33 0.10 0.86 0.15 0.44 '
374
- '0.19 31 red 0.42 0.90 0.55 0.57 0.63 0.37 20 '
375
- 'avg / total 0.51 0.53 0.80 0.47 0.62 0.41 75' )
373
+ expected_report = ('pre rec spe f1 geo iba sup blue 0.83 0.79 0.92 0.81 '
374
+ '0.85 0.72 24 green 0.33 0.10 0.86 0.15 0.29 0.08 31 '
375
+ 'red 0.42 0.90 0.55 0.57 0.70 0.51 20 avg / total '
376
+ '0.51 0.53 0.80 0.47 0.58 0.40 75' )
376
377
report = classification_report_imbalanced (y_true , y_pred )
377
378
assert _format_report (report ) == expected_report
378
379
379
- expected_report = ('pre rec spe f1 geo iba sup a 0.83 0.79 0.92 0.81 '
380
- '0.86 0.74 24 b 0.33 0.10 0.86 0.15 0.44 0.19 31 '
381
- 'c 0.42 0. 90 0.55 0.57 0.63 0.37 20 avg / total '
382
- '0.51 0.53 0. 80 0.47 0.62 0.41 75' )
380
+ expected_report = ('pre rec spe f1 geo iba sup a 0.83 0.79 0.92 0.81 0.85 '
381
+ '0.72 24 b 0.33 0.10 0.86 0.15 0.29 0.08 31 c 0.42 '
382
+ '0. 90 0.55 0.57 0.70 0.51 20 avg / total 0.51 0.53 '
383
+ '0.80 0.47 0.58 0.40 75' )
383
384
report = classification_report_imbalanced (
384
385
y_true , y_pred , target_names = ["a" , "b" , "c" ])
385
386
assert _format_report (report ) == expected_report
@@ -392,10 +393,10 @@ def test_classification_report_imbalanced_multiclass_with_unicode_label():
392
393
y_true = labels [y_true ]
393
394
y_pred = labels [y_pred ]
394
395
395
- expected_report = (u'pre rec spe f1 geo iba sup blue\xa2 0.83 0.79 '
396
- u'0.92 0.81 0.86 0.74 24 green\xa2 0.33 0.10 0.86 '
397
- u'0.15 0.44 0.19 31 red\xa2 0.42 0.90 0.55 0.57 0.63 '
398
- u'0.37 20 avg / total 0. 51 0.53 0.80 0.47 0.62 0.41 75' )
396
+ expected_report = (u'pre rec spe f1 geo iba sup blue¢ 0.83 0.79 0.92 0.81 '
397
+ u'0.85 0.72 24 green¢ 0.33 0.10 0.86 0.15 0.29 0.08 31 '
398
+ u'red¢ 0.42 0.90 0.55 0.57 0.70 0.51 20 avg / total '
399
+ u'0.51 0.53 0.80 0.47 0.58 0.40 75' )
399
400
if np_version [:3 ] < (1 , 7 , 0 ):
400
401
with raises (RuntimeError , match = "NumPy < 1.7.0" ):
401
402
classification_report_imbalanced (y_true , y_pred )
@@ -412,9 +413,9 @@ def test_classification_report_imbalanced_multiclass_with_long_string_label():
412
413
y_pred = labels [y_pred ]
413
414
414
415
expected_report = ('pre rec spe f1 geo iba sup blue 0.83 0.79 0.92 0.81 '
415
- '0.86 0.74 24 greengreengreengreengreen 0.33 0.10 '
416
- '0.86 0.15 0.44 0.19 31 red 0.42 0.90 0.55 0.57 0.63 '
417
- '0.37 20 avg / total 0.51 0.53 0.80 0.47 0.62 0.41 75' )
416
+ '0.85 0.72 24 greengreengreengreengreen 0.33 0.10 '
417
+ '0.86 0.15 0.29 0.08 31 red 0.42 0.90 0.55 0.57 0.70 '
418
+ '0.51 20 avg / total 0.51 0.53 0.80 0.47 0.58 0.40 75' )
418
419
419
420
report = classification_report_imbalanced (y_true , y_pred )
420
421
assert _format_report (report ) == expected_report
0 commit comments