@@ -31,9 +31,6 @@ def compute_intermediate_confusion_matrices(
31
31
A 2-D confusion matrix with shape (n_labels + 1, n_labels + 1).
32
32
"""
33
33
34
- n_gt_labels = groundtruth_labels .size
35
- n_pd_labels = prediction_labels .size
36
-
37
34
groundtruth_counts = groundtruths .sum (axis = 1 )
38
35
prediction_counts = predictions .sum (axis = 1 )
39
36
@@ -42,33 +39,23 @@ def compute_intermediate_confusion_matrices(
42
39
).sum ()
43
40
44
41
intersection_counts = np .logical_and (
45
- groundtruths . reshape ( n_gt_labels , 1 , - 1 ) ,
46
- predictions . reshape ( 1 , n_pd_labels , - 1 ) ,
42
+ groundtruths [:, None , :] ,
43
+ predictions [ None , :, :] ,
47
44
).sum (axis = 2 )
48
-
49
45
intersected_groundtruth_counts = intersection_counts .sum (axis = 1 )
50
46
intersected_prediction_counts = intersection_counts .sum (axis = 0 )
51
47
52
48
confusion_matrix = np .zeros ((n_labels + 1 , n_labels + 1 ), dtype = np .int32 )
53
49
confusion_matrix [0 , 0 ] = background_counts
54
- for gidx in range (n_gt_labels ):
55
- gt_label_idx = groundtruth_labels [gidx ]
56
- for pidx in range (n_pd_labels ):
57
- pd_label_idx = prediction_labels [pidx ]
58
- confusion_matrix [
59
- gt_label_idx + 1 ,
60
- pd_label_idx + 1 ,
61
- ] = intersection_counts [gidx , pidx ]
62
-
63
- if gidx == 0 :
64
- confusion_matrix [0 , pd_label_idx + 1 ] = (
65
- prediction_counts [pidx ]
66
- - intersected_prediction_counts [pidx ]
67
- )
68
-
69
- confusion_matrix [gt_label_idx + 1 , 0 ] = (
70
- groundtruth_counts [gidx ] - intersected_groundtruth_counts [gidx ]
71
- )
50
+ confusion_matrix [
51
+ np .ix_ (groundtruth_labels + 1 , prediction_labels + 1 )
52
+ ] = intersection_counts
53
+ confusion_matrix [0 , prediction_labels + 1 ] = (
54
+ prediction_counts - intersected_prediction_counts
55
+ )
56
+ confusion_matrix [groundtruth_labels + 1 , 0 ] = (
57
+ groundtruth_counts - intersected_groundtruth_counts
58
+ )
72
59
73
60
return confusion_matrix
74
61
0 commit comments