|
4 | 4 |
|
5 | 5 | import numpy as np
|
6 | 6 | import torch
|
7 |
| -from torch import Tensor |
| 7 | +from torch import BoolTensor, FloatTensor, LongTensor, Tensor, isin, stack, tensor |
8 | 8 |
|
9 | 9 | from oml.utils.misc import check_if_nonempty_positive_integers, clip_max
|
10 |
| -from oml.utils.misc_torch import PCA, pairwise_dist, take_2d |
| 10 | +from oml.utils.misc_torch import PCA, pairwise_dist |
11 | 11 |
|
12 | 12 | TMetricsDict = Dict[str, Dict[Union[int, float], Union[float, Tensor]]]
|
13 | 13 |
|
14 | 14 |
|
15 | 15 | def calc_retrieval_metrics(
|
16 |
| - distances: Tensor, |
17 |
| - mask_gt: Tensor, |
18 |
| - mask_to_ignore: Optional[Tensor] = None, |
| 16 | + retrieved_ids: LongTensor, |
| 17 | + gt_ids: List[LongTensor], |
19 | 18 | cmc_top_k: Tuple[int, ...] = (5,),
|
20 | 19 | precision_top_k: Tuple[int, ...] = (5,),
|
21 | 20 | map_top_k: Tuple[int, ...] = (5,),
|
22 |
| - fmr_vals: Tuple[int, ...] = (1,), |
23 | 21 | reduce: bool = True,
|
24 | 22 | ) -> TMetricsDict:
|
25 | 23 | """
|
26 | 24 | Function to count different retrieval metrics.
|
27 | 25 |
|
28 | 26 | Args:
|
29 |
| - distances: Distance matrix with the shape of ``[query_size, gallery_size]`` |
30 |
| - mask_gt: ``(i,j)`` element indicates if for ``i``-th query ``j``-th gallery is the correct prediction |
31 |
| - mask_to_ignore: Binary matrix to indicate that some elements in the gallery cannot be used |
32 |
| - as answers and must be ignored |
| 27 | + retrieved_ids: Top N gallery ids retrieved for every query with the shape of ``[n_query, top_n]``. |
| 28 | + Every element is within the range ``(0, n_gallery - 1)``. |
| 29 | + gt_ids: Gallery ids relevant to every query, list of ``n_query`` elements where every element may |
| 30 | + have an arbitrary length. Every element is within the range ``(0, n_gallery - 1)`` |
33 | 31 | cmc_top_k: Values of ``k`` to calculate ``cmc@k`` (`Cumulative Matching Characteristic`)
|
34 | 32 | precision_top_k: Values of ``k`` to calculate ``precision@k``
|
35 | 33 | map_top_k: Values of ``k`` to calculate ``map@k`` (`Mean Average Precision`)
|
36 |
| - fmr_vals: Values of ``fmr`` (measured in quantiles) to calculate ``fnmr@fmr`` (`False Non Match Rate |
37 |
| - at the given False Match Rate`). |
38 |
| - For example, if ``fmr_values`` is (0.2, 0.4) we will calculate ``fnmr@fmr=0.2`` and ``fnmr@fmr=0.4`` |
39 | 34 | reduce: If ``False`` return metrics for each query without averaging
|
40 | 35 |
|
41 | 36 | Returns:
|
42 | 37 | Metrics dictionary.
|
43 | 38 |
|
44 | 39 | """
|
45 |
| - top_k_args = [cmc_top_k, precision_top_k, map_top_k] |
46 |
| - |
47 |
| - if not any(top_k_args + [fmr_vals]): |
48 |
| - raise ValueError("You must specify arguments for at leas 1 metric to calculate it") |
49 |
| - |
50 |
| - if distances.shape != mask_gt.shape: |
51 |
| - raise ValueError( |
52 |
| - f"Distances matrix has the shape of {distances.shape}, " |
53 |
| - f"but mask_to_ignore has the shape of {mask_gt.shape}." |
54 |
| - ) |
55 |
| - |
56 |
| - if (mask_to_ignore is not None) and (mask_to_ignore.shape != distances.shape): |
57 |
| - raise ValueError( |
58 |
| - f"Distances matrix has the shape of {distances.shape}, " |
59 |
| - f"but mask_to_ignore has the shape of {mask_to_ignore.shape}." |
60 |
| - ) |
61 |
| - |
62 |
| - query_sz, gallery_sz = distances.shape |
63 |
| - |
64 |
| - for top_k_arg in top_k_args: |
65 |
| - for k in top_k_arg: |
66 |
| - if k > gallery_sz: |
67 |
| - warnings.warn( |
68 |
| - f"Your desired k={k} more than gallery_size={gallery_sz}. " |
69 |
| - f"We'll calculate metrics with k limited by the gallery size." |
70 |
| - ) |
71 |
| - |
72 |
| - if mask_to_ignore is not None: |
73 |
| - distances, mask_gt = apply_mask_to_ignore(distances=distances, mask_gt=mask_gt, mask_to_ignore=mask_to_ignore) |
| 40 | + # todo 522: clipping |
74 | 41 |
|
75 |
| - cmc_top_k_clipped = clip_max(cmc_top_k, gallery_sz) |
76 |
| - precision_top_k_clipped = clip_max(precision_top_k, gallery_sz) |
77 |
| - map_top_k_clipped = clip_max(map_top_k, gallery_sz) |
| 42 | + assert retrieved_ids.ndim == 2, "Retrieved ids must be a tensor with the shape of [n_query, top_n]." |
| 43 | + assert len(retrieved_ids) == len(gt_ids), "Numbers of queries have be the same." |
| 44 | + n_queries = len(retrieved_ids) |
78 | 45 |
|
79 |
| - max_k = max([*cmc_top_k, *precision_top_k, *map_top_k]) |
80 |
| - max_k = min(max_k, gallery_sz) |
81 |
| - |
82 |
| - _, ii_top_k = torch.topk(distances, k=max_k, largest=False) |
83 |
| - gt_tops = take_2d(mask_gt, ii_top_k) |
84 |
| - n_gt = mask_gt.sum(dim=1) |
| 46 | + # let's mark every correctly retrieved item as True and vice versa |
| 47 | + gt_tops = stack([isin(retrieved_ids[i], gt_ids[i]) for i in range(n_queries)]).bool() |
| 48 | + n_gts = tensor([len(ids) for ids in gt_ids]).long() |
85 | 49 |
|
86 | 50 | metrics: TMetricsDict = defaultdict(dict)
|
87 | 51 |
|
88 | 52 | if cmc_top_k:
|
89 |
| - cmc = calc_cmc(gt_tops, cmc_top_k_clipped) |
| 53 | + cmc = calc_cmc(gt_tops, cmc_top_k) |
90 | 54 | metrics["cmc"] = dict(zip(cmc_top_k, cmc))
|
91 | 55 |
|
92 | 56 | if precision_top_k:
|
93 |
| - precision = calc_precision(gt_tops, n_gt, precision_top_k_clipped) |
| 57 | + precision = calc_precision(gt_tops, n_gts, precision_top_k) |
94 | 58 | metrics["precision"] = dict(zip(precision_top_k, precision))
|
95 | 59 |
|
96 | 60 | if map_top_k:
|
97 |
| - map = calc_map(gt_tops, n_gt, map_top_k_clipped) |
| 61 | + map = calc_map(gt_tops, n_gts, map_top_k) |
98 | 62 | metrics["map"] = dict(zip(map_top_k, map))
|
99 | 63 |
|
100 |
| - if fmr_vals: |
101 |
| - pos_dist, neg_dist = extract_pos_neg_dists(distances, mask_gt, mask_to_ignore) |
102 |
| - fnmr_at_fmr = calc_fnmr_at_fmr(pos_dist, neg_dist, fmr_vals) |
103 |
| - metrics["fnmr@fmr"] = dict(zip(fmr_vals, fnmr_at_fmr)) |
104 |
| - |
105 | 64 | if reduce:
|
106 | 65 | metrics = reduce_metrics(metrics)
|
107 | 66 |
|
108 | 67 | return metrics
|
109 | 68 |
|
110 | 69 |
|
| 70 | +def calc_retrieval_metrics_on_full( |
| 71 | + distances: Tensor, |
| 72 | + mask_gt: Tensor, |
| 73 | + mask_to_ignore: Optional[Tensor] = None, |
| 74 | + cmc_top_k: Tuple[int, ...] = (5,), |
| 75 | + precision_top_k: Tuple[int, ...] = (5,), |
| 76 | + map_top_k: Tuple[int, ...] = (5,), |
| 77 | + reduce: bool = True, |
| 78 | +) -> TMetricsDict: |
| 79 | + # todo 522: get rid of this tmp function or at least move to the tests |
| 80 | + if mask_to_ignore is not None: |
| 81 | + distances, mask_gt = apply_mask_to_ignore(distances=distances, mask_gt=mask_gt, mask_to_ignore=mask_to_ignore) |
| 82 | + |
| 83 | + max_k_arg = max([*cmc_top_k, *precision_top_k, *map_top_k]) |
| 84 | + k = min(distances.shape[1], max_k_arg) |
| 85 | + _, retrieved_ids = torch.topk(distances, largest=False, k=k) |
| 86 | + |
| 87 | + gt_ids = [LongTensor(row.nonzero()).view(-1) for row in mask_gt] |
| 88 | + |
| 89 | + metrics = calc_retrieval_metrics( |
| 90 | + cmc_top_k=cmc_top_k, |
| 91 | + precision_top_k=precision_top_k, |
| 92 | + map_top_k=map_top_k, |
| 93 | + reduce=reduce, |
| 94 | + gt_ids=gt_ids, |
| 95 | + retrieved_ids=retrieved_ids, |
| 96 | + ) |
| 97 | + return metrics |
| 98 | + |
| 99 | + |
111 | 100 | def calc_topological_metrics(embeddings: Tensor, pcf_variance: Tuple[float, ...]) -> TMetricsDict:
|
112 | 101 | """
|
113 | 102 | Function to evaluate different topological metrics.
|
@@ -145,6 +134,20 @@ def reduce_metrics(metrics_to_reduce: TMetricsDict) -> TMetricsDict:
|
145 | 134 | return output
|
146 | 135 |
|
147 | 136 |
|
| 137 | +def take_unreduced_metrics_by_mask(metrics: TMetricsDict, mask: BoolTensor) -> TMetricsDict: |
| 138 | + output: TMetricsDict = {} |
| 139 | + |
| 140 | + for k, v in metrics.items(): |
| 141 | + if isinstance(v, Tensor): |
| 142 | + output[k] = v[mask] if v.numel() > 1 else v |
| 143 | + elif isinstance(v, (float, int)): |
| 144 | + output[k] = v |
| 145 | + else: |
| 146 | + output[k] = take_unreduced_metrics_by_mask(v, mask) # type: ignore |
| 147 | + |
| 148 | + return output |
| 149 | + |
| 150 | + |
148 | 151 | def apply_mask_to_ignore(distances: Tensor, mask_gt: Tensor, mask_to_ignore: Tensor) -> Tuple[Tensor, Tensor]:
|
149 | 152 | distances[mask_to_ignore] = float("inf")
|
150 | 153 | mask_gt[mask_to_ignore] = False
|
@@ -466,6 +469,19 @@ def calc_fnmr_at_fmr(pos_dist: Tensor, neg_dist: Tensor, fmr_vals: Tuple[float,
|
466 | 469 | return fnmr_at_fmr
|
467 | 470 |
|
468 | 471 |
|
| 472 | +def calc_fnmr_at_fmr_from_matrices( |
| 473 | + distance_matrix: FloatTensor, mask_gt: BoolTensor, fmr_vals: Tuple[float, ...] |
| 474 | +) -> TMetricsDict: |
| 475 | + metrics: TMetricsDict = dict() |
| 476 | + |
| 477 | + if fmr_vals: |
| 478 | + pos_dist, neg_dist = extract_pos_neg_dists(distance_matrix, mask_gt) |
| 479 | + fnmr_at_fmr = calc_fnmr_at_fmr(pos_dist, neg_dist, fmr_vals) |
| 480 | + metrics["fnmr@fmr"] = dict(zip(fmr_vals, fnmr_at_fmr)) |
| 481 | + |
| 482 | + return metrics |
| 483 | + |
| 484 | + |
469 | 485 | def calc_pcf(embeddings: Tensor, pcf_variance: Tuple[float, ...]) -> List[Tensor]:
|
470 | 486 | """
|
471 | 487 | Function estimates the Principal Components Fraction (PCF) of embeddings using Principal Component Analysis.
|
@@ -529,28 +545,20 @@ def calc_pcf(embeddings: Tensor, pcf_variance: Tuple[float, ...]) -> List[Tensor
|
529 | 545 | return metric
|
530 | 546 |
|
531 | 547 |
|
532 |
| -def extract_pos_neg_dists( |
533 |
| - distances: Tensor, mask_gt: Tensor, mask_to_ignore: Optional[Tensor] |
534 |
| -) -> Tuple[Tensor, Tensor]: |
| 548 | +def extract_pos_neg_dists(distances: Tensor, mask_gt: Tensor) -> Tuple[Tensor, Tensor]: |
535 | 549 | """
|
536 | 550 | Extract distances between relevant samples, and distances between non-relevant samples.
|
537 | 551 |
|
538 | 552 | Args:
|
539 | 553 | distances: Distance matrix with the shape of ``[query_size, gallery_size]``
|
540 | 554 | mask_gt: ``(i,j)`` element indicates if for i-th query j-th gallery is the correct prediction
|
541 |
| - mask_to_ignore: Binary matrix to indicate that some elements in gallery cannot be used |
542 |
| - as answers and must be ignored |
| 555 | +
|
543 | 556 | Returns:
|
544 | 557 | pos_dist: Tensor of distances between relevant samples
|
545 | 558 | neg_dist: Tensor of distances between non-relevant samples
|
546 | 559 | """
|
547 |
| - if mask_to_ignore is not None: |
548 |
| - mask_to_not_ignore = ~mask_to_ignore |
549 |
| - pos_dist = distances[mask_gt & mask_to_not_ignore] |
550 |
| - neg_dist = distances[~mask_gt & mask_to_not_ignore] |
551 |
| - else: |
552 |
| - pos_dist = distances[mask_gt] |
553 |
| - neg_dist = distances[~mask_gt] |
| 560 | + pos_dist = distances[mask_gt] |
| 561 | + neg_dist = distances[~mask_gt] |
554 | 562 | return pos_dist, neg_dist
|
555 | 563 |
|
556 | 564 |
|
@@ -592,10 +600,12 @@ def _check_if_in_range(vals: Sequence[float], min_: float, max_: float, name: st
|
592 | 600 | __all__ = [
|
593 | 601 | "TMetricsDict",
|
594 | 602 | "calc_retrieval_metrics",
|
| 603 | + "calc_retrieval_metrics_on_full", |
595 | 604 | "calc_topological_metrics",
|
596 | 605 | "apply_mask_to_ignore",
|
597 | 606 | "calc_gt_mask",
|
598 | 607 | "calc_mask_to_ignore",
|
599 | 608 | "calc_distance_matrix",
|
600 | 609 | "reduce_metrics",
|
| 610 | + "take_unreduced_metrics_by_mask", |
601 | 611 | ]
|
0 commit comments