Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into czaloom-synthetic-sem…
Browse files Browse the repository at this point in the history
…seg-bench
  • Loading branch information
czaloom committed Nov 13, 2024
2 parents 5108d94 + 582eb93 commit c974f41
Showing 1 changed file with 11 additions and 24 deletions.
35 changes: 11 additions & 24 deletions lite/valor_lite/semantic_segmentation/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@ def compute_intermediate_confusion_matrices(
A 2-D confusion matrix with shape (n_labels + 1, n_labels + 1).
"""

n_gt_labels = groundtruth_labels.size
n_pd_labels = prediction_labels.size

groundtruth_counts = groundtruths.sum(axis=1)
prediction_counts = predictions.sum(axis=1)

Expand All @@ -42,33 +39,23 @@ def compute_intermediate_confusion_matrices(
).sum()

intersection_counts = np.logical_and(
groundtruths.reshape(n_gt_labels, 1, -1),
predictions.reshape(1, n_pd_labels, -1),
groundtruths[:, None, :],
predictions[None, :, :],
).sum(axis=2)

intersected_groundtruth_counts = intersection_counts.sum(axis=1)
intersected_prediction_counts = intersection_counts.sum(axis=0)

confusion_matrix = np.zeros((n_labels + 1, n_labels + 1), dtype=np.int32)
confusion_matrix[0, 0] = background_counts
for gidx in range(n_gt_labels):
gt_label_idx = groundtruth_labels[gidx]
for pidx in range(n_pd_labels):
pd_label_idx = prediction_labels[pidx]
confusion_matrix[
gt_label_idx + 1,
pd_label_idx + 1,
] = intersection_counts[gidx, pidx]

if gidx == 0:
confusion_matrix[0, pd_label_idx + 1] = (
prediction_counts[pidx]
- intersected_prediction_counts[pidx]
)

confusion_matrix[gt_label_idx + 1, 0] = (
groundtruth_counts[gidx] - intersected_groundtruth_counts[gidx]
)
confusion_matrix[
np.ix_(groundtruth_labels + 1, prediction_labels + 1)
] = intersection_counts
confusion_matrix[0, prediction_labels + 1] = (
prediction_counts - intersected_prediction_counts
)
confusion_matrix[groundtruth_labels + 1, 0] = (
groundtruth_counts - intersected_groundtruth_counts
)

return confusion_matrix

Expand Down

0 comments on commit c974f41

Please sign in to comment.