Skip to content

Commit c974f41

Browse files
committed
Merge remote-tracking branch 'origin/main' into czaloom-synthetic-semseg-bench
2 parents 5108d94 + 582eb93 commit c974f41

File tree

1 file changed

+11
-24
lines changed

1 file changed

+11
-24
lines changed

lite/valor_lite/semantic_segmentation/computation.py

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,6 @@ def compute_intermediate_confusion_matrices(
3131
A 2-D confusion matrix with shape (n_labels + 1, n_labels + 1).
3232
"""
3333

34-
n_gt_labels = groundtruth_labels.size
35-
n_pd_labels = prediction_labels.size
36-
3734
groundtruth_counts = groundtruths.sum(axis=1)
3835
prediction_counts = predictions.sum(axis=1)
3936

@@ -42,33 +39,23 @@ def compute_intermediate_confusion_matrices(
4239
).sum()
4340

4441
intersection_counts = np.logical_and(
45-
groundtruths.reshape(n_gt_labels, 1, -1),
46-
predictions.reshape(1, n_pd_labels, -1),
42+
groundtruths[:, None, :],
43+
predictions[None, :, :],
4744
).sum(axis=2)
48-
4945
intersected_groundtruth_counts = intersection_counts.sum(axis=1)
5046
intersected_prediction_counts = intersection_counts.sum(axis=0)
5147

5248
confusion_matrix = np.zeros((n_labels + 1, n_labels + 1), dtype=np.int32)
5349
confusion_matrix[0, 0] = background_counts
54-
for gidx in range(n_gt_labels):
55-
gt_label_idx = groundtruth_labels[gidx]
56-
for pidx in range(n_pd_labels):
57-
pd_label_idx = prediction_labels[pidx]
58-
confusion_matrix[
59-
gt_label_idx + 1,
60-
pd_label_idx + 1,
61-
] = intersection_counts[gidx, pidx]
62-
63-
if gidx == 0:
64-
confusion_matrix[0, pd_label_idx + 1] = (
65-
prediction_counts[pidx]
66-
- intersected_prediction_counts[pidx]
67-
)
68-
69-
confusion_matrix[gt_label_idx + 1, 0] = (
70-
groundtruth_counts[gidx] - intersected_groundtruth_counts[gidx]
71-
)
50+
confusion_matrix[
51+
np.ix_(groundtruth_labels + 1, prediction_labels + 1)
52+
] = intersection_counts
53+
confusion_matrix[0, prediction_labels + 1] = (
54+
prediction_counts - intersected_prediction_counts
55+
)
56+
confusion_matrix[groundtruth_labels + 1, 0] = (
57+
groundtruth_counts - intersected_groundtruth_counts
58+
)
7259

7360
return confusion_matrix
7461

0 commit comments

Comments
 (0)