Skip to content

Commit b4c8b67

Browse files
committed
merge
2 parents 3df8de4 + 749c326 commit b4c8b67

File tree

5 files changed

+108
-92
lines changed

5 files changed

+108
-92
lines changed

oml/functional/metrics.py

Lines changed: 81 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -4,110 +4,99 @@
44

55
import numpy as np
66
import torch
7-
from torch import Tensor
7+
from torch import BoolTensor, FloatTensor, LongTensor, Tensor, isin, stack, tensor
88

99
from oml.utils.misc import check_if_nonempty_positive_integers, clip_max
10-
from oml.utils.misc_torch import PCA, pairwise_dist, take_2d
10+
from oml.utils.misc_torch import PCA, pairwise_dist
1111

1212
TMetricsDict = Dict[str, Dict[Union[int, float], Union[float, Tensor]]]
1313

1414

1515
def calc_retrieval_metrics(
16-
distances: Tensor,
17-
mask_gt: Tensor,
18-
mask_to_ignore: Optional[Tensor] = None,
16+
retrieved_ids: LongTensor,
17+
gt_ids: List[LongTensor],
1918
cmc_top_k: Tuple[int, ...] = (5,),
2019
precision_top_k: Tuple[int, ...] = (5,),
2120
map_top_k: Tuple[int, ...] = (5,),
22-
fmr_vals: Tuple[int, ...] = (1,),
2321
reduce: bool = True,
2422
) -> TMetricsDict:
2523
"""
2624
Function to count different retrieval metrics.
2725
2826
Args:
29-
distances: Distance matrix with the shape of ``[query_size, gallery_size]``
30-
mask_gt: ``(i,j)`` element indicates if for ``i``-th query ``j``-th gallery is the correct prediction
31-
mask_to_ignore: Binary matrix to indicate that some elements in the gallery cannot be used
32-
as answers and must be ignored
27+
retrieved_ids: Top N gallery ids retrieved for every query with the shape of ``[n_query, top_n]``.
28+
Every element is within the range ``(0, n_gallery - 1)``.
29+
gt_ids: Gallery ids relevant to every query, list of ``n_query`` elements where every element may
30+
have an arbitrary length. Every element is within the range ``(0, n_gallery - 1)``
3331
cmc_top_k: Values of ``k`` to calculate ``cmc@k`` (`Cumulative Matching Characteristic`)
3432
precision_top_k: Values of ``k`` to calculate ``precision@k``
3533
map_top_k: Values of ``k`` to calculate ``map@k`` (`Mean Average Precision`)
36-
fmr_vals: Values of ``fmr`` (measured in quantiles) to calculate ``fnmr@fmr`` (`False Non Match Rate
37-
at the given False Match Rate`).
38-
For example, if ``fmr_values`` is (0.2, 0.4) we will calculate ``fnmr@fmr=0.2`` and ``fnmr@fmr=0.4``
3934
reduce: If ``False`` return metrics for each query without averaging
4035
4136
Returns:
4237
Metrics dictionary.
4338
4439
"""
45-
top_k_args = [cmc_top_k, precision_top_k, map_top_k]
46-
47-
if not any(top_k_args + [fmr_vals]):
48-
raise ValueError("You must specify arguments for at leas 1 metric to calculate it")
49-
50-
if distances.shape != mask_gt.shape:
51-
raise ValueError(
52-
f"Distances matrix has the shape of {distances.shape}, "
53-
f"but mask_to_ignore has the shape of {mask_gt.shape}."
54-
)
55-
56-
if (mask_to_ignore is not None) and (mask_to_ignore.shape != distances.shape):
57-
raise ValueError(
58-
f"Distances matrix has the shape of {distances.shape}, "
59-
f"but mask_to_ignore has the shape of {mask_to_ignore.shape}."
60-
)
61-
62-
query_sz, gallery_sz = distances.shape
63-
64-
for top_k_arg in top_k_args:
65-
for k in top_k_arg:
66-
if k > gallery_sz:
67-
warnings.warn(
68-
f"Your desired k={k} more than gallery_size={gallery_sz}. "
69-
f"We'll calculate metrics with k limited by the gallery size."
70-
)
71-
72-
if mask_to_ignore is not None:
73-
distances, mask_gt = apply_mask_to_ignore(distances=distances, mask_gt=mask_gt, mask_to_ignore=mask_to_ignore)
40+
# todo 522: clipping
7441

75-
cmc_top_k_clipped = clip_max(cmc_top_k, gallery_sz)
76-
precision_top_k_clipped = clip_max(precision_top_k, gallery_sz)
77-
map_top_k_clipped = clip_max(map_top_k, gallery_sz)
42+
assert retrieved_ids.ndim == 2, "Retrieved ids must be a tensor with the shape of [n_query, top_n]."
43+
assert len(retrieved_ids) == len(gt_ids), "Numbers of queries have be the same."
44+
n_queries = len(retrieved_ids)
7845

79-
max_k = max([*cmc_top_k, *precision_top_k, *map_top_k])
80-
max_k = min(max_k, gallery_sz)
81-
82-
_, ii_top_k = torch.topk(distances, k=max_k, largest=False)
83-
gt_tops = take_2d(mask_gt, ii_top_k)
84-
n_gt = mask_gt.sum(dim=1)
46+
# let's mark every correctly retrieved item as True and vice versa
47+
gt_tops = stack([isin(retrieved_ids[i], gt_ids[i]) for i in range(n_queries)]).bool()
48+
n_gts = tensor([len(ids) for ids in gt_ids]).long()
8549

8650
metrics: TMetricsDict = defaultdict(dict)
8751

8852
if cmc_top_k:
89-
cmc = calc_cmc(gt_tops, cmc_top_k_clipped)
53+
cmc = calc_cmc(gt_tops, cmc_top_k)
9054
metrics["cmc"] = dict(zip(cmc_top_k, cmc))
9155

9256
if precision_top_k:
93-
precision = calc_precision(gt_tops, n_gt, precision_top_k_clipped)
57+
precision = calc_precision(gt_tops, n_gts, precision_top_k)
9458
metrics["precision"] = dict(zip(precision_top_k, precision))
9559

9660
if map_top_k:
97-
map = calc_map(gt_tops, n_gt, map_top_k_clipped)
61+
map = calc_map(gt_tops, n_gts, map_top_k)
9862
metrics["map"] = dict(zip(map_top_k, map))
9963

100-
if fmr_vals:
101-
pos_dist, neg_dist = extract_pos_neg_dists(distances, mask_gt, mask_to_ignore)
102-
fnmr_at_fmr = calc_fnmr_at_fmr(pos_dist, neg_dist, fmr_vals)
103-
metrics["fnmr@fmr"] = dict(zip(fmr_vals, fnmr_at_fmr))
104-
10564
if reduce:
10665
metrics = reduce_metrics(metrics)
10766

10867
return metrics
10968

11069

70+
def calc_retrieval_metrics_on_full(
71+
distances: Tensor,
72+
mask_gt: Tensor,
73+
mask_to_ignore: Optional[Tensor] = None,
74+
cmc_top_k: Tuple[int, ...] = (5,),
75+
precision_top_k: Tuple[int, ...] = (5,),
76+
map_top_k: Tuple[int, ...] = (5,),
77+
reduce: bool = True,
78+
) -> TMetricsDict:
79+
# todo 522: get rid of this tmp function or at least move to the tests
80+
if mask_to_ignore is not None:
81+
distances, mask_gt = apply_mask_to_ignore(distances=distances, mask_gt=mask_gt, mask_to_ignore=mask_to_ignore)
82+
83+
max_k_arg = max([*cmc_top_k, *precision_top_k, *map_top_k])
84+
k = min(distances.shape[1], max_k_arg)
85+
_, retrieved_ids = torch.topk(distances, largest=False, k=k)
86+
87+
gt_ids = [LongTensor(row.nonzero()).view(-1) for row in mask_gt]
88+
89+
metrics = calc_retrieval_metrics(
90+
cmc_top_k=cmc_top_k,
91+
precision_top_k=precision_top_k,
92+
map_top_k=map_top_k,
93+
reduce=reduce,
94+
gt_ids=gt_ids,
95+
retrieved_ids=retrieved_ids,
96+
)
97+
return metrics
98+
99+
111100
def calc_topological_metrics(embeddings: Tensor, pcf_variance: Tuple[float, ...]) -> TMetricsDict:
112101
"""
113102
Function to evaluate different topological metrics.
@@ -145,6 +134,20 @@ def reduce_metrics(metrics_to_reduce: TMetricsDict) -> TMetricsDict:
145134
return output
146135

147136

137+
def take_unreduced_metrics_by_mask(metrics: TMetricsDict, mask: BoolTensor) -> TMetricsDict:
138+
output: TMetricsDict = {}
139+
140+
for k, v in metrics.items():
141+
if isinstance(v, Tensor):
142+
output[k] = v[mask] if v.numel() > 1 else v
143+
elif isinstance(v, (float, int)):
144+
output[k] = v
145+
else:
146+
output[k] = take_unreduced_metrics_by_mask(v, mask) # type: ignore
147+
148+
return output
149+
150+
148151
def apply_mask_to_ignore(distances: Tensor, mask_gt: Tensor, mask_to_ignore: Tensor) -> Tuple[Tensor, Tensor]:
149152
distances[mask_to_ignore] = float("inf")
150153
mask_gt[mask_to_ignore] = False
@@ -466,6 +469,19 @@ def calc_fnmr_at_fmr(pos_dist: Tensor, neg_dist: Tensor, fmr_vals: Tuple[float,
466469
return fnmr_at_fmr
467470

468471

472+
def calc_fnmr_at_fmr_from_matrices(
473+
distance_matrix: FloatTensor, mask_gt: BoolTensor, fmr_vals: Tuple[float, ...]
474+
) -> TMetricsDict:
475+
metrics: TMetricsDict = dict()
476+
477+
if fmr_vals:
478+
pos_dist, neg_dist = extract_pos_neg_dists(distance_matrix, mask_gt)
479+
fnmr_at_fmr = calc_fnmr_at_fmr(pos_dist, neg_dist, fmr_vals)
480+
metrics["fnmr@fmr"] = dict(zip(fmr_vals, fnmr_at_fmr))
481+
482+
return metrics
483+
484+
469485
def calc_pcf(embeddings: Tensor, pcf_variance: Tuple[float, ...]) -> List[Tensor]:
470486
"""
471487
Function estimates the Principal Components Fraction (PCF) of embeddings using Principal Component Analysis.
@@ -529,28 +545,20 @@ def calc_pcf(embeddings: Tensor, pcf_variance: Tuple[float, ...]) -> List[Tensor
529545
return metric
530546

531547

532-
def extract_pos_neg_dists(
533-
distances: Tensor, mask_gt: Tensor, mask_to_ignore: Optional[Tensor]
534-
) -> Tuple[Tensor, Tensor]:
548+
def extract_pos_neg_dists(distances: Tensor, mask_gt: Tensor) -> Tuple[Tensor, Tensor]:
535549
"""
536550
Extract distances between relevant samples, and distances between non-relevant samples.
537551
538552
Args:
539553
distances: Distance matrix with the shape of ``[query_size, gallery_size]``
540554
mask_gt: ``(i,j)`` element indicates if for i-th query j-th gallery is the correct prediction
541-
mask_to_ignore: Binary matrix to indicate that some elements in gallery cannot be used
542-
as answers and must be ignored
555+
543556
Returns:
544557
pos_dist: Tensor of distances between relevant samples
545558
neg_dist: Tensor of distances between non-relevant samples
546559
"""
547-
if mask_to_ignore is not None:
548-
mask_to_not_ignore = ~mask_to_ignore
549-
pos_dist = distances[mask_gt & mask_to_not_ignore]
550-
neg_dist = distances[~mask_gt & mask_to_not_ignore]
551-
else:
552-
pos_dist = distances[mask_gt]
553-
neg_dist = distances[~mask_gt]
560+
pos_dist = distances[mask_gt]
561+
neg_dist = distances[~mask_gt]
554562
return pos_dist, neg_dist
555563

556564

@@ -592,10 +600,12 @@ def _check_if_in_range(vals: Sequence[float], min_: float, max_: float, name: st
592600
__all__ = [
593601
"TMetricsDict",
594602
"calc_retrieval_metrics",
603+
"calc_retrieval_metrics_on_full",
595604
"calc_topological_metrics",
596605
"apply_mask_to_ignore",
597606
"calc_gt_mask",
598607
"calc_mask_to_ignore",
599608
"calc_distance_matrix",
600609
"reduce_metrics",
610+
"take_unreduced_metrics_by_mask",
601611
]

oml/metrics/embeddings.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,13 @@
3232
TMetricsDict,
3333
apply_mask_to_ignore,
3434
calc_distance_matrix,
35+
calc_fnmr_at_fmr_from_matrices,
3536
calc_gt_mask,
3637
calc_mask_to_ignore,
37-
calc_retrieval_metrics,
38+
calc_retrieval_metrics_on_full,
3839
calc_topological_metrics,
3940
reduce_metrics,
41+
take_unreduced_metrics_by_mask,
4042
)
4143
from oml.interfaces.datasets import IQueryGalleryLabeledDataset
4244
from oml.interfaces.metrics import IMetricDDP, IMetricVisualisable
@@ -221,14 +223,12 @@ def compute_metrics(self) -> TMetricsDict_ByLabels: # type: ignore
221223
"cmc_top_k": self.cmc_top_k,
222224
"precision_top_k": self.precision_top_k,
223225
"map_top_k": self.map_top_k,
224-
"fmr_vals": self.fmr_vals,
225226
}
226-
args_topological_metrics = {"pcf_variance": self.pcf_variance}
227227

228228
metrics: TMetricsDict_ByLabels = dict()
229229

230230
# note, here we do micro averaging
231-
metrics[self.overall_categories_key] = calc_retrieval_metrics(
231+
metrics[self.overall_categories_key] = calc_retrieval_metrics_on_full(
232232
distances=self.distance_matrix,
233233
mask_gt=self.mask_gt,
234234
reduce=False,
@@ -237,26 +237,28 @@ def compute_metrics(self) -> TMetricsDict_ByLabels: # type: ignore
237237
)
238238

239239
embeddings = self.acc.storage[self.embeddings_key]
240-
metrics[self.overall_categories_key].update(calc_topological_metrics(embeddings, **args_topological_metrics))
240+
metrics[self.overall_categories_key].update(calc_topological_metrics(embeddings, self.pcf_variance))
241+
metrics[self.overall_categories_key].update(
242+
calc_fnmr_at_fmr_from_matrices(self.distance_matrix, self.mask_gt, self.fmr_vals)
243+
)
241244

242245
if self.categories_key is not None:
243246
categories = np.array(self.acc.storage[self.categories_key])
244247
is_query = self.acc.storage[self.is_query_key]
245248
query_categories = categories[is_query]
246249

247250
for category in np.unique(query_categories):
248-
mask = query_categories == category
249-
250-
metrics[category] = calc_retrieval_metrics(
251-
distances=self.distance_matrix[mask], # type: ignore
252-
mask_gt=self.mask_gt[mask], # type: ignore
253-
reduce=False,
254-
mask_to_ignore=None, # we already applied it
255-
**args_retrieval_metrics, # type: ignore
251+
mask_query_sz = query_categories == category
252+
253+
metrics[category] = take_unreduced_metrics_by_mask(metrics[self.overall_categories_key], mask_query_sz)
254+
metrics[category].update(
255+
calc_fnmr_at_fmr_from_matrices(
256+
self.distance_matrix[mask_query_sz], self.mask_gt[mask_query_sz], self.fmr_vals # type: ignore
257+
)
256258
)
257259

258-
mask = categories == category
259-
metrics[category].update(calc_topological_metrics(embeddings[mask], **args_topological_metrics))
260+
mask_dataset_sz = categories == category
261+
metrics[category].update(calc_topological_metrics(embeddings[mask_dataset_sz], self.pcf_variance))
260262

261263
self.metrics_unreduced = metrics # type: ignore
262264
self.metrics = reduce_metrics(metrics) # type: ignore

tests/test_oml/test_functional/test_metrics/test_cmc_metric_old.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
import torch
77
from torch import Tensor
88

9-
from oml.functional.metrics import calc_retrieval_metrics
9+
from oml.functional.metrics import (
10+
calc_retrieval_metrics_on_full as calc_retrieval_metrics,
11+
)
1012
from oml.utils.misc_torch import pairwise_dist
1113

1214

@@ -18,7 +20,6 @@ def cmc_score_count(distances: Tensor, mask_gt: Tensor, topk: int, mask_to_ignor
1820
cmc_top_k=(topk,),
1921
map_top_k=tuple(),
2022
precision_top_k=tuple(),
21-
fmr_vals=tuple(),
2223
)
2324
return metrics["cmc"][topk]
2425

tests/test_oml/test_functional/test_metrics/test_retrieval_metrics.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
calc_map,
1515
calc_mask_to_ignore,
1616
calc_precision,
17-
calc_retrieval_metrics,
17+
)
18+
from oml.functional.metrics import (
19+
calc_retrieval_metrics_on_full as calc_retrieval_metrics,
1820
)
1921
from oml.metrics.embeddings import validate_dataset
2022
from oml.utils.misc import remove_unused_kwargs

tests/test_oml/test_postprocessor/test_pairwise_embeddings.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
import torch
77
from torch import Tensor
88

9-
from oml.functional.metrics import calc_retrieval_metrics
9+
from oml.functional.metrics import (
10+
calc_retrieval_metrics_on_full as calc_retrieval_metrics,
11+
)
1012
from oml.interfaces.datasets import IQueryGalleryDataset, IQueryGalleryLabeledDataset
1113
from oml.interfaces.models import IPairwiseModel
1214
from oml.models.meta.siamese import LinearTrivialDistanceSiamese
@@ -131,7 +133,7 @@ def test_trivial_processing_fixes_broken_perfect_case(pairwise_distances_bias: f
131133
top_k = (randint(1, ng - 1),)
132134
top_n = randint(2, 10)
133135

134-
args = {"mask_gt": mask_gt, "precision_top_k": top_k, "map_top_k": top_k, "cmc_top_k": top_k, "fmr_vals": ()}
136+
args = {"mask_gt": mask_gt, "precision_top_k": top_k, "map_top_k": top_k, "cmc_top_k": top_k}
135137

136138
# Metrics before
137139
metrics = flatten_dict(calc_retrieval_metrics(distances=distances, **args))
@@ -185,7 +187,6 @@ def test_processing_not_changing_non_sensitive_metrics(top_n: int) -> None:
185187
args = {
186188
"cmc_top_k": (top_n,),
187189
"precision_top_k": (top_n,),
188-
"fmr_vals": tuple(),
189190
"map_top_k": tuple(),
190191
"mask_gt": mask_gt,
191192
}

0 commit comments

Comments
 (0)