Rework retrieval metrics #525

Merged: 14 commits merged into main_rework_validation from rework_retrieval_metrics on Apr 14, 2024

Conversation

AlekseySh (Contributor) commented on Apr 10, 2024:

Checklist:

  • Add a few new tests for the case of clipped predictions
  • Don't forget to call topological metrics
  • Find "todo 522" occurrences in the code and check whether they can be solved now

@AlekseySh self-assigned this on Apr 10, 2024
@AlekseySh mentioned this pull request on Apr 10, 2024
@AlekseySh linked an issue on Apr 10, 2024 that may be closed by this pull request
@AlekseySh requested a review from DaloroAT on Apr 12, 2024 at 01:24
AlekseySh (Contributor Author) commented:

Changelog:

  • Changed the inputs of calc_retrieval_metrics (see the sketch after this list)
  • PCF is moved outside of calc_retrieval_metrics
  • FNMR & PCF now return a list of floats (for consistency with the other metrics)
  • Precomputed metrics are reused for computing category-based metrics
  • Some type annotations were narrowed to more specific Tensor types
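
To make the reworked inputs concrete, here is a minimal, self-contained sketch of computing CMC@k from `retrieved_ids` and `gt_ids`. The toy data and the `cmc_at_k` helper are hypothetical, inferred from the diff snippets below; this is not the library's actual implementation.

```python
import torch

# Hypothetical toy data in the new id-based format:
# retrieved_ids[i] -- gallery ids for query i, sorted by ascending distance
# gt_ids[i]        -- ground-truth gallery ids for query i
retrieved_ids = torch.tensor([[3, 1, 7], [0, 2, 5]])
gt_ids = [torch.tensor([1, 4]), torch.tensor([9])]

def cmc_at_k(retrieved_ids, gt_ids, k):
    # CMC@k is 1.0 if any relevant id appears among the top-k, else 0.0;
    # returned as a list of floats, mirroring the changelog item above
    hits = [torch.isin(retrieved_ids[i][:k], gt_ids[i]).any() for i in range(len(gt_ids))]
    return [float(h) for h in hits]

print(cmc_at_k(retrieved_ids, gt_ids, k=2))  # [1.0, 0.0]
```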

```python
gt_tops = take_2d(mask_gt, ii_top_k)
n_gt = mask_gt.sum(dim=1)
# let's mark every correctly retrieved item as True and vice versa
gt_tops = stack([isin(retrieved_ids[i], tensor(gt_ids[i])) for i in range(n_queries)]).bool()
```

Contributor:
Remove the redundant `.bool()` call; `torch.isin` already returns a bool tensor.

Contributor Author:
You are right, but type checkers and PyCharm are not that smart, so this way I explicitly annotate the type.
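
For reference, a quick standalone check (toy tensors, not from the diff) showing that `torch.isin` already yields a bool tensor, so the trailing `.bool()` is purely an explicit annotation:

```python
import torch

retrieved = torch.tensor([3, 1, 7])
gt = torch.tensor([1, 2, 3])

mask = torch.isin(retrieved, gt)
print(mask.dtype)  # torch.bool -- calling .bool() on top of this is a no-op
print(mask)        # tensor([ True,  True, False])
```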

```python
output: TMetricsDict = {}

for k, v in metrics.items():
    if isinstance(v, (Tensor, np.ndarray)):
```

Contributor:
Remove the redundant `np.ndarray` check, given the `TMetricsDict` definition: `TMetricsDict = Dict[str, Dict[Union[int, float], Union[float, FloatTensor]]]`; alternatively, change the definition of `TMetricsDict`.

Contributor Author:
Agree; I prefer to remove the check.
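
A minimal sketch of how the loop simplifies once `np.ndarray` is dropped; the `metrics` sample values are hypothetical, and only the `TMetricsDict` definition is quoted from the comment above:

```python
from typing import Dict, Union

import torch
from torch import FloatTensor

# TMetricsDict as quoted in the review comment above
TMetricsDict = Dict[str, Dict[Union[int, float], Union[float, FloatTensor]]]

# hypothetical sample metrics mixing Tensor and float values
metrics: TMetricsDict = {"cmc": {1: torch.tensor(0.75), 5: 0.9}}

output: TMetricsDict = {}
for name, values in metrics.items():
    # with np.ndarray gone from the union, a single Tensor check suffices
    output[name] = {k: float(v) if isinstance(v, torch.Tensor) else v for k, v in values.items()}

print(output)  # {'cmc': {1: 0.75, 5: 0.9}}
```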

```python
max_k_arg = max([*self.cmc_top_k, *self.precision_top_k, *self.map_top_k])
k = min(self.distance_matrix.shape[1], max_k_arg)  # type: ignore
_, retrieved_ids = torch.topk(self.distance_matrix, largest=False, k=k)
gt_ids = [torch.nonzero(row, as_tuple=True)[0].tolist() for row in self.mask_gt]  # type: ignore
```

Contributor:
Why retrieve `gt_ids` as `List[List[int]]` instead of `List[Tensor]`? I propose updating the `calc_retrieval_metrics` function to accept `gt_ids` of type `List[Tensor]`, to eliminate the need for double conversion between tensor and list.

Contributor Author:
agree
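
A small standalone illustration (toy mask, assumed names) of the `List[Tensor]` variant, which avoids the tensor-to-list-to-tensor round trip:

```python
import torch

mask_gt = torch.tensor([[True, False, True],
                        [False, True, False]])

# keep ids as tensors instead of converting to lists and back
gt_ids = [torch.nonzero(row, as_tuple=True)[0] for row in mask_gt]
print(gt_ids)  # [tensor([0, 2]), tensor([1])]

# they can be consumed directly, e.g. by torch.isin, with no tensor(...) wrapping
print(torch.isin(torch.tensor([2, 1, 0]), gt_ids[0]))  # tensor([ True, False,  True])
```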

Comment on lines 152 to 155:

```python
if isinstance(d2[k], torch.Tensor) and isinstance(v, torch.Tensor):
    is_equal = torch.all(torch.isclose(d2[k], v))
elif isinstance(d2[k], float) and isinstance(v, float):
    is_equal = math.isclose(d2[k], v, rel_tol=1e-6)
```

Contributor:
Let's ensure consistency in the tolerance values for numerical comparisons. Currently, `torch.isclose` defaults to `rtol=1e-05` and `atol=1e-08`, while `math.isclose` uses `rel_tol=1e-09` and `abs_tol=0.0`. I suggest updating both calls to use the same values, specifically `rtol=1e-06` and `atol=1e-08`.

Contributor Author:
agree
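
A sketch of a unified comparison under the tolerances proposed above; `values_close` is a hypothetical helper, not the code under review:

```python
import math

import torch

RTOL, ATOL = 1e-06, 1e-08  # the unified tolerances proposed above

def values_close(a, b) -> bool:
    # use the same tolerances for both the Tensor and the float branches
    if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
        return bool(torch.all(torch.isclose(a, b, rtol=RTOL, atol=ATOL)))
    return math.isclose(a, b, rel_tol=RTOL, abs_tol=ATOL)

print(values_close(torch.tensor([0.5]), torch.tensor([0.5 + 1e-7])))  # True
print(values_close(0.5, 0.5 + 1e-7))                                  # True
```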

@AlekseySh merged commit a018e62 into main_rework_validation on Apr 14, 2024
8 checks passed
@AlekseySh deleted the rework_retrieval_metrics branch on Apr 14, 2024 at 18:17
```python
gt_tops = take_2d(mask_gt, ii_top_k)
n_gt = mask_gt.sum(dim=1)
# let's mark every correctly retrieved item as True and vice versa
gt_tops = stack([isin(retrieved_ids[i], gt_ids[i]) for i in range(n_queries)]).bool()
```

Collaborator:

Nice :)

Projects: Status: Done

Successfully merging this pull request may close these issues:

  • [EPIC] Release OML 3.0

3 participants