diff --git a/ignite/metrics/mean_average_precision.py b/ignite/metrics/mean_average_precision.py
index d82505b54468..867b53cd2f2f 100644
--- a/ignite/metrics/mean_average_precision.py
+++ b/ignite/metrics/mean_average_precision.py
@@ -2,6 +2,7 @@
 from typing import Callable, cast, List, Optional, Sequence, Tuple, Union

 import torch
+from packaging.version import Version
 from typing_extensions import Literal

 import ignite.distributed as idist
@@ -11,6 +12,9 @@
 from ignite.utils import to_onehot


+_torch_version_lt_113 = Version(torch.__version__) < Version("1.13.0")
+
+
 class _BaseAveragePrecision:
     def __init__(
         self,
@@ -97,9 +101,12 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens
         if self.rec_thresholds is not None:
             rec_thresholds = self.rec_thresholds.repeat((*recall.shape[:-1], 1))
             rec_thresh_indices = torch.searchsorted(recall, rec_thresholds)
-            precision = precision.take_along_dim(
-                rec_thresh_indices.where(rec_thresh_indices != recall.size(-1), 0), dim=-1
-            ).where(rec_thresh_indices != recall.size(-1), 0)
+            rec_mask = rec_thresh_indices != recall.size(-1)
+            precision = torch.where(
+                rec_mask,
+                precision.take_along_dim(torch.where(rec_mask, rec_thresh_indices, 0), dim=-1),
+                0.0,
+            )
             recall = rec_thresholds
         recall_differential = recall.diff(
             dim=-1, prepend=torch.zeros((*recall.shape[:-1], 1), device=recall.device, dtype=recall.dtype)
@@ -335,9 +342,10 @@ def _compute_recall_and_precision(
         Returns:
             `(recall, precision)`
         """
-        indices = torch.argsort(y_pred, stable=True, descending=True)
+        kwargs = {} if _torch_version_lt_113 else {"stable": True}
+        indices = torch.argsort(y_pred, descending=True, **kwargs)
         tp_summation = y_true[indices].cumsum(dim=0)
-        if tp_summation.device != torch.device("mps"):
+        if tp_summation.device.type != "mps":
             tp_summation = tp_summation.double()

         # Adopted from Scikit-learn's implementation
@@ -354,7 +362,7 @@
         recall = tp_summation / y_true_positive_count

         predicted_positive = tp_summation + fp_summation
-        precision = tp_summation / torch.where(predicted_positive == 0, 1, predicted_positive)
+        precision = tp_summation / torch.where(predicted_positive == 0, 1.0, predicted_positive)
         return recall, precision

     def compute(self) -> Union[torch.Tensor, float]:
@@ -371,7 +379,7 @@
             torch.long if self._type == "multiclass" else torch.uint8,
             self._device,
         )
-        fp_precision = torch.double if self._device != torch.device("mps") else torch.float32
+        fp_precision = torch.double if self._device.type != "mps" else torch.float32
         y_pred = _cat_and_agg_tensors(self._y_pred, (num_classes,), fp_precision, self._device)

         if self._type == "multiclass":
diff --git a/ignite/metrics/vision/object_detection_average_precision_recall.py b/ignite/metrics/vision/object_detection_average_precision_recall.py
index 881c9ccc5ed8..c15d8acb84fc 100644
--- a/ignite/metrics/vision/object_detection_average_precision_recall.py
+++ b/ignite/metrics/vision/object_detection_average_precision_recall.py
@@ -1,6 +1,7 @@
 from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union

 import torch
+from packaging.version import Version
 from typing_extensions import Literal

 from ignite.metrics import MetricGroup
@@ -9,6 +10,9 @@
 from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce


+_torch_version_lt_113 = Version(torch.__version__) < Version("1.13.0")
+
+
 def coco_tensor_list_to_dict_list(
     output: Tuple[
         Union[List[torch.Tensor], List[Dict[str, torch.Tensor]]],
@@ -213,7 +217,8 @@ def _compute_recall_and_precision(
         Returns:
             `(recall, precision)`
         """
-        indices = torch.argsort(scores, dim=-1, stable=True, descending=True)
+        kwargs = {} if _torch_version_lt_113 else {"stable": True}
+        indices = torch.argsort(scores, descending=True, **kwargs)
         tp = TP[..., indices]
         tp_summation = tp.cumsum(dim=-1)
         if tp_summation.device.type != "mps":
@@ -226,7 +231,7 @@
         recall = tp_summation / y_true_count

         predicted_positive = tp_summation + fp_summation
-        precision = tp_summation / torch.where(predicted_positive == 0, 1, predicted_positive)
+        precision = tp_summation / torch.where(predicted_positive == 0, 1.0, predicted_positive)

         return recall, precision

@@ -258,9 +263,12 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens
             if recall.size(-1) != 0
             else torch.LongTensor([], device=self._device)
         )
-        precision_integrand = precision_integrand.take_along_dim(
-            rec_thresh_indices.where(rec_thresh_indices != recall.size(-1), 0), dim=-1
-        ).where(rec_thresh_indices != recall.size(-1), 0)
+        recall_mask = rec_thresh_indices != recall.size(-1)
+        precision_integrand = torch.where(
+            recall_mask,
+            precision_integrand.take_along_dim(torch.where(recall_mask, rec_thresh_indices, 0), dim=-1),
+            0.0,
+        )
         return torch.sum(precision_integrand, dim=-1) / len(cast(torch.Tensor, self.rec_thresholds))

     @reinit__is_reduced
@@ -298,6 +306,7 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor
                                     This key is optional.
         ========= ================= =================================================
         """
+        kwargs = {} if _torch_version_lt_113 else {"stable": True}
         self._check_matching_input(output)
         for pred, target in zip(*output):
             labels = target["labels"]
@@ -312,7 +321,7 @@

             # Matching logic of object detection mAP, according to COCO reference implementation.
             if len(pred["labels"]):
-                best_detections_index = torch.argsort(pred["scores"], stable=True, descending=True)
+                best_detections_index = torch.argsort(pred["scores"], descending=True, **kwargs)
                 max_best_detections_index = torch.cat(
                     [
                         best_detections_index[pred["labels"][best_detections_index] == c][
diff --git a/tests/ignite/metrics/test_mean_average_precision.py b/tests/ignite/metrics/test_mean_average_precision.py
index 16be8b7fbb04..8ba5630213a0 100644
--- a/tests/ignite/metrics/test_mean_average_precision.py
+++ b/tests/ignite/metrics/test_mean_average_precision.py
@@ -45,7 +45,7 @@ def test__prepare_output():
     metric = MeanAveragePrecision()

     metric._type = "binary"
-    scores, y = metric._prepare_output((torch.rand((5, 4, 3, 2)), torch.randint(0, 2, (5, 4, 3, 2)).bool()))
+    scores, y = metric._prepare_output((torch.rand((5, 4, 3, 2)), torch.randint(0, 2, (5, 4, 3, 2))))
     assert scores.shape == y.shape == (1, 120)

     metric._type = "multiclass"
@@ -53,14 +53,14 @@
     assert scores.shape == (4, 30) and y.shape == (30,)

     metric._type = "multilabel"
-    scores, y = metric._prepare_output((torch.rand((5, 4, 3, 2)), torch.randint(0, 2, (5, 4, 3, 2)).bool()))
+    scores, y = metric._prepare_output((torch.rand((5, 4, 3, 2)), torch.randint(0, 2, (5, 4, 3, 2))))
     assert scores.shape == y.shape == (4, 30)


 def test_update():
     metric = MeanAveragePrecision()
     assert len(metric._y_pred) == len(metric._y_true) == 0
-    metric.update((torch.rand((5, 4)), torch.randint(0, 2, (5, 4)).bool()))
+    metric.update((torch.rand((5, 4)), torch.randint(0, 2, (5, 4))))
     assert len(metric._y_pred) == len(metric._y_true) == 1


@@ -68,7 +68,7 @@ def test__compute_recall_and_precision():
     m = MeanAveragePrecision()

     scores = torch.rand((50,))
-    y_true = torch.randint(0, 2, (50,)).bool()
+    y_true = torch.randint(0, 2, (50,))
     precision, recall, _ = precision_recall_curve(y_true.numpy(), scores.numpy())
     P = y_true.sum(dim=-1)
     ignite_recall, ignite_precision = m._compute_recall_and_precision(y_true, scores, P)
@@ -77,7 +77,7 @@

     # When there's no actual positive. Numpy expectedly raises warning.
     scores = torch.rand((50,))
-    y_true = torch.zeros((50,)).bool()
+    y_true = torch.zeros((50,))
     precision, recall, _ = precision_recall_curve(y_true.numpy(), scores.numpy())
     P = torch.tensor(0)
     ignite_recall, ignite_precision = m._compute_recall_and_precision(y_true, scores, P)
@@ -147,7 +147,7 @@ def test_compute_nonbinary_data(class_mean):

     # Multilabel
     m = MeanAveragePrecision(is_multilabel=True, class_mean=class_mean)
-    y_true = torch.randint(0, 2, (130, 5, 2, 2)).bool()
+    y_true = torch.randint(0, 2, (130, 5, 2, 2))
     m.update((scores[:50], y_true[:50]))
     m.update((scores[50:], y_true[50:]))
     ignite_map = m.compute().numpy()
diff --git a/tests/ignite/metrics/vision/test_object_detection_map.py b/tests/ignite/metrics/vision/test_object_detection_map.py
index 712b2fdebdf9..9a060911dd26 100644
--- a/tests/ignite/metrics/vision/test_object_detection_map.py
+++ b/tests/ignite/metrics/vision/test_object_detection_map.py
@@ -864,7 +864,7 @@ def test__compute_recall_and_precision():

 def test_compute(sample):
     device = idist.device()
-    if device == torch.device("mps"):
+    if device.type == "mps":
         pytest.skip("Due to MPS backend out of memory")

     # AP@.5...95, AP@.5, AP@.75, AP-S, AP-M, AP-L, AR-1, AR-10, AR-100, AR-S, AR-M, AR-L
@@ -924,7 +924,7 @@ def test_integration(sample):
     bs = 3

     device = idist.device()
-    if device == torch.device("mps"):
+    if device.type == "mps":
         pytest.skip("Due to MPS backend out of memory")

     def update(engine, i):
@@ -995,7 +995,7 @@ def test_distrib_update_compute(distributed, sample):

     device = idist.device()

-    if device == torch.device("mps"):
+    if device.type == "mps":
         pytest.skip("Due to MPS backend out of memory")

     metric_device = "cpu" if device.type == "xla" else device
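Note: the compatibility pattern applied throughout this patch (passing argsort's `stable` keyword only when the installed torch supports it, and replacing `Tensor.where(cond, 0)` with a boolean mask plus `torch.where(..., 0.0)` and float fill values) can be exercised in isolation. The following is a minimal, illustrative sketch only; the helper name `recall_precision_at_thresholds`, its tensor shapes, and the sample data are made up and are not part of ignite's API. It assumes only `torch` and `packaging` are installed.

# Illustrative sketch of the compatibility pattern used in this patch (hypothetical helper, not ignite API).
import torch
from packaging.version import Version

_torch_version_lt_113 = Version(torch.__version__) < Version("1.13.0")


def recall_precision_at_thresholds(scores: torch.Tensor, y_true: torch.Tensor, rec_thresholds: torch.Tensor):
    # Older torch (< 1.13) does not accept `stable` for argsort, hence the conditional kwargs,
    # mirroring the version gate introduced in the diff.
    kwargs = {} if _torch_version_lt_113 else {"stable": True}
    indices = torch.argsort(scores, descending=True, **kwargs)

    tp = y_true[indices].cumsum(dim=0).double()
    fp = (~y_true[indices]).cumsum(dim=0).double()
    recall = tp / y_true.sum()
    predicted_positive = tp + fp
    # Guard against division by zero; the float fill value (1.0) matches the change made in the diff.
    precision = tp / torch.where(predicted_positive == 0, 1.0, predicted_positive)

    # Mask out-of-range searchsorted results with torch.where over a boolean mask,
    # the same rewrite the diff applies in place of Tensor.where(cond, 0).
    idx = torch.searchsorted(recall, rec_thresholds)
    mask = idx != recall.size(-1)
    precision_at_thr = torch.where(mask, precision.take_along_dim(torch.where(mask, idx, 0), dim=-1), 0.0)
    return recall, precision_at_thr


# Usage example with random scores and boolean targets (assumed data, for illustration only).
scores = torch.rand(20)
y_true = torch.randint(0, 2, (20,)).bool()
recall, prec_at_thr = recall_precision_at_thresholds(scores, y_true, torch.linspace(0, 1, 11, dtype=torch.double))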