diff --git a/lite/valor_lite/semantic_segmentation/computation.py b/lite/valor_lite/semantic_segmentation/computation.py index 556a0ef9b..807c2bde6 100644 --- a/lite/valor_lite/semantic_segmentation/computation.py +++ b/lite/valor_lite/semantic_segmentation/computation.py @@ -31,9 +31,6 @@ def compute_intermediate_confusion_matrices( A 2-D confusion matrix with shape (n_labels + 1, n_labels + 1). """ - n_gt_labels = groundtruth_labels.size - n_pd_labels = prediction_labels.size - groundtruth_counts = groundtruths.sum(axis=1) prediction_counts = predictions.sum(axis=1) @@ -42,33 +39,23 @@ def compute_intermediate_confusion_matrices( ).sum() intersection_counts = np.logical_and( - groundtruths.reshape(n_gt_labels, 1, -1), - predictions.reshape(1, n_pd_labels, -1), + groundtruths[:, None, :], + predictions[None, :, :], ).sum(axis=2) - intersected_groundtruth_counts = intersection_counts.sum(axis=1) intersected_prediction_counts = intersection_counts.sum(axis=0) confusion_matrix = np.zeros((n_labels + 1, n_labels + 1), dtype=np.int32) confusion_matrix[0, 0] = background_counts - for gidx in range(n_gt_labels): - gt_label_idx = groundtruth_labels[gidx] - for pidx in range(n_pd_labels): - pd_label_idx = prediction_labels[pidx] - confusion_matrix[ - gt_label_idx + 1, - pd_label_idx + 1, - ] = intersection_counts[gidx, pidx] - - if gidx == 0: - confusion_matrix[0, pd_label_idx + 1] = ( - prediction_counts[pidx] - - intersected_prediction_counts[pidx] - ) - - confusion_matrix[gt_label_idx + 1, 0] = ( - groundtruth_counts[gidx] - intersected_groundtruth_counts[gidx] - ) + confusion_matrix[ + np.ix_(groundtruth_labels + 1, prediction_labels + 1) + ] = intersection_counts + confusion_matrix[0, prediction_labels + 1] = ( + prediction_counts - intersected_prediction_counts + ) + confusion_matrix[groundtruth_labels + 1, 0] = ( + groundtruth_counts - intersected_groundtruth_counts + ) return confusion_matrix