Skip to content

Commit

Permalink
When there are no gts we raise warning instead of error
Browse files Browse the repository at this point in the history
When there are no gts we raise warning instead of error
  • Loading branch information
AlekseySh authored Jun 18, 2024
1 parent ec08e2b commit 9cf93be
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 3 deletions.
3 changes: 2 additions & 1 deletion oml/retrieval/retrieval_results.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from copy import deepcopy
from pprint import pformat
from typing import Callable, List, Optional, Sequence, Tuple, Union
Expand Down Expand Up @@ -70,7 +71,7 @@ def __init__(
if gt_ids is not None:
assert len(distances) == len(gt_ids)
if any(len(x) == 0 for x in gt_ids):
raise RuntimeError("Every query must have at least one relevant gallery id.")
warnings.warn("Some of the queries don't have available gts.")

self._distances = tuple(distances)
self._retrieved_ids = tuple(retrieved_ids)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_invariants_in_validation_with_sequences_3() -> None:
df_with_seq = df.copy()
df_with_seq[SEQUENCE_COLUMN] = df_with_seq[LABELS_COLUMN]

with pytest.raises(RuntimeError):
with pytest.warns(UserWarning, match="Some of the queries don't have available gts."):
validation(df)
validation(df_with_seq)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def test_visualisation_for_different_number_of_retrieved_items() -> None:

def test_retrieval_results_creation() -> None:
# there is a query with no gt
with pytest.raises(RuntimeError):
with pytest.warns(UserWarning, match="Some of the queries don't have available gts."):
RetrievalResults(
distances=[torch.arange(3).float(), torch.arange(3).float()],
retrieved_ids=[LongTensor([1, 0, 2]), LongTensor([4, 0, 1])],
Expand Down

0 comments on commit 9cf93be

Please sign in to comment.