Rework retrieval metrics #525
Changelog:
oml/functional/metrics.py
Outdated
gt_tops = take_2d(mask_gt, ii_top_k)
n_gt = mask_gt.sum(dim=1)
# let's mark every correctly retrieved item as True and vice versa
gt_tops = stack([isin(retrieved_ids[i], tensor(gt_ids[i])) for i in range(n_queries)]).bool()
Remove the redundant `.bool()` call: the output dtype of `torch.isin` is already bool.
you are right, but type checkers and PyCharm are not as smart, so this way I explicitly annotate types
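The dtype claim in this thread can be checked with a minimal sketch. The tensors below are toy values chosen for illustration and are not taken from the PR; only `torch.isin`, `Tensor.bool`, and `torch.equal` are real PyTorch calls.

```python
import torch

# Toy data (hypothetical): ids retrieved for one query and its ground-truth ids.
retrieved = torch.tensor([2, 0, 5])
gt = torch.tensor([0, 2])

# Element-wise membership test: is each retrieved id among the ground-truth ids?
hits = torch.isin(retrieved, gt)

print(hits.dtype)                      # torch.bool
print(torch.equal(hits, hits.bool()))  # True: .bool() changes nothing here
```

So the extra `.bool()` has no runtime effect on the values; as the author notes, it only serves as an explicit type hint for tooling.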
oml/functional/metrics.py
Outdated
output: TMetricsDict = {}

for k, v in metrics.items():
    if isinstance(v, (Tensor, np.ndarray)):
Remove the redundant check for `np.ndarray`, given the `TMetricsDict` definition: `TMetricsDict = Dict[str, Dict[Union[int, float], Union[float, FloatTensor]]]`. Alternatively, change the `TMetricsDict` definition.
agree, I prefer to remove the check
oml/metrics/embeddings.py
Outdated
max_k_arg = max([*self.cmc_top_k, *self.precision_top_k, *self.map_top_k])
k = min(self.distance_matrix.shape[1], max_k_arg)  # type: ignore
_, retrieved_ids = torch.topk(self.distance_matrix, largest=False, k=k)
gt_ids = [torch.nonzero(row, as_tuple=True)[0].tolist() for row in self.mask_gt]  # type: ignore
Why retrieve `gt_ids` as `List[List[int]]` instead of `List[Tensor]`? I propose updating the `calc_retrieval_metrics` function to accept `gt_ids` of type `List[Tensor]`, which eliminates the double conversion between tensor and list.
agree
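A rough sketch of what the proposal amounts to, under stated assumptions: `mask_gt` and `distance_matrix` below are small made-up tensors, and the code runs at module level rather than inside `calc_retrieval_metrics`, whose real signature is not shown in this thread. Dropping `.tolist()` keeps each query's ground-truth ids as a tensor, so `torch.isin` can consume them without a `tensor(...)` round trip.

```python
import torch

# Toy inputs (hypothetical): 2 queries, 4 gallery items.
mask_gt = torch.tensor([[True, False, True, False],
                        [False, False, False, True]])
distance_matrix = torch.tensor([[0.9, 0.1, 0.2, 0.8],
                                [0.3, 0.7, 0.6, 0.4]])

# Ids of the k nearest gallery items per query (smallest distances first).
k = min(distance_matrix.shape[1], 2)
_, retrieved_ids = torch.topk(distance_matrix, k=k, largest=False)

# Ground-truth ids kept as List[Tensor]: no .tolist(), so no later tensor(...) call.
gt_ids = [torch.nonzero(row, as_tuple=True)[0] for row in mask_gt]

# Mark every correctly retrieved item as True.
n_queries = len(gt_ids)
gt_tops = torch.stack([torch.isin(retrieved_ids[i], gt_ids[i]) for i in range(n_queries)])

print(gt_tops.tolist())  # [[False, True], [False, True]]
```

A list of tensors (rather than one 2D tensor) is still needed because queries can have different numbers of ground-truth items.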
oml/utils/misc.py
Outdated
if isinstance(d2[k], torch.Tensor) and isinstance(v, torch.Tensor):
    is_equal = torch.all(torch.isclose(d2[k], v))
elif isinstance(d2[k], float) and isinstance(v, float):
    is_equal = math.isclose(d2[k], v, rel_tol=1e-6)
Let's keep the tolerance values consistent across numerical comparisons. Currently, `torch.isclose` runs with its defaults, `rtol=1e-05` and `atol=1e-08`, while `math.isclose` is called with `rel_tol=1e-6` and keeps its default `abs_tol=0.0` (its default `rel_tol` is `1e-09`). I suggest passing the same values to both functions, specifically `rtol=1e-06` and `atol=1e-08`.
agree
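To illustrate why the mismatch matters, here is a standard-library-only sketch. `torch_style_isclose` is a hypothetical helper written for this example; it mirrors the rule documented for `torch.isclose` on finite values, `|a - b| <= atol + rtol * |b|`, and is not the library function itself. `RTOL` and `ATOL` are the values suggested above.

```python
import math

RTOL = 1e-6  # shared relative tolerance suggested in the review
ATOL = 1e-8  # shared absolute tolerance suggested in the review


def torch_style_isclose(a: float, b: float, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Hypothetical scalar mirror of the documented torch.isclose rule."""
    return abs(a - b) <= atol + rtol * abs(b)


a, b = 1.0, 1.000005  # relative difference of about 5e-6

# Current state: the two branches disagree about the same pair of values.
tensor_branch = torch_style_isclose(a, b)          # default rtol=1e-5
float_branch = math.isclose(a, b, rel_tol=1e-6)    # abs_tol defaults to 0.0
print(tensor_branch, float_branch)  # True False

# With shared tolerances both branches give the same verdict.
tensor_aligned = torch_style_isclose(a, b, rtol=RTOL, atol=ATOL)
float_aligned = math.isclose(a, b, rel_tol=RTOL, abs_tol=ATOL)
print(tensor_aligned, float_aligned)  # False False
```

Even with identical values the two rules are not the same formula: `math.isclose` takes the larger of the relative and absolute bounds, while the `torch.isclose` rule adds them, so results can still differ very close to the threshold.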
gt_tops = take_2d(mask_gt, ii_top_k)
n_gt = mask_gt.sum(dim=1)
# let's mark every correctly retrieved item as True and vice versa
gt_tops = stack([isin(retrieved_ids[i], gt_ids[i]) for i in range(n_queries)]).bool()
Nice :)
Check list: