Skip to content

Commit e9a81e0

Browse files
committed
upd
1 parent fac3178 commit e9a81e0

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

oml/metrics/embeddings.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def compute_metrics(self) -> TMetricsDict_ByLabels: # type: ignore
212212
# note, here we do micro averaging
213213
metrics[self.overall_categories_key] = calc_retrieval_metrics(
214214
distances=self.distance_matrix,
215-
gt_ids=self.mask_gt,
215+
mask_gt=self.mask_gt,
216216
reduce=False,
217217
mask_to_ignore=None, # we already applied it
218218
**args_retrieval_metrics, # type: ignore
@@ -231,7 +231,7 @@ def compute_metrics(self) -> TMetricsDict_ByLabels: # type: ignore
231231

232232
metrics[category] = calc_retrieval_metrics(
233233
distances=self.distance_matrix[mask], # type: ignore
234-
gt_ids=self.mask_gt[mask], # type: ignore
234+
mask_gt=self.mask_gt[mask], # type: ignore
235235
reduce=False,
236236
mask_to_ignore=None, # we already applied it
237237
**args_retrieval_metrics, # type: ignore

tests/test_oml/test_functional/test_metrics/test_cmc_metric_old.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
def cmc_score_count(distances: Tensor, mask_gt: Tensor, topk: int, mask_to_ignore: Optional[Tensor] = None) -> float:
1414
metrics = calc_retrieval_metrics(
1515
distances=distances,
16-
gt_ids=mask_gt,
16+
mask_gt=mask_gt,
1717
mask_to_ignore=mask_to_ignore,
1818
cmc_top_k=(topk,),
1919
map_top_k=tuple(),

0 commit comments

Comments
 (0)