Introducing container for storing Retrieval Results #544
Changes from 5 commits
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
from typing import List, Optional, Tuple | ||
|
||
import torch | ||
from torch import FloatTensor, LongTensor | ||
|
||
from oml.const import BS_KNN | ||
from oml.utils.misc_torch import pairwise_dist | ||
|
||
|
||
def batched_knn( | ||
Comment: I don't think this function has anything to do with kNN. It's an exact calculation of the distance matrix with truncation. However, I don't know what to call it...
Reply: It's the same as kNN from sklearn:
By the way, didn't you mean aNN (approximate NN)? That's not what we are doing here. |
||
embeddings: FloatTensor, | ||
ids_query: LongTensor, | ||
ids_gallery: LongTensor, | ||
top_n: int, | ||
sequence_ids: Optional[LongTensor] = None, | ||
labels_gt: Optional[LongTensor] = None, | ||
bs: int = BS_KNN, | ||
) -> Tuple[FloatTensor, LongTensor, Optional[List[LongTensor]]]: | ||
""" | ||
|
||
Args: | ||
embeddings: Matrix with the shape of ``[n, dim]`` | ||
Comment: Usage of
Reply: agree |
||
ids_query: Tensor with the size of ``Q``, where ``Q <= n``. Each element is withing the range ``(0, n - 1)``. | ||
Comment: typo: withing
Reply: done, removed everywhere |
||
ids_gallery: Tensor with the size of ``G`` where ``G <= n``. Each element is withing the range ``(0, n - 1)``. | ||
May overlap with ``ids_query``. | ||
top_n: Number of neighbors to find and return. | ||
sequence_ids: Sequence identifiers with the size of ``n`` (if known). | ||
labels_gt: Ground truth labels of every element with the size of ``n`` (if known). | ||
bs: Batch size for computing distances to avoid OOM errors when processing the whole matrix at once. | ||
|
||
Returns: | ||
distances: Sorted distances from every query to the closest ``top_n`` galleries with the size of ``(Q, top_n)``. | ||
retrieved_ids: The corresponding ids of gallery items with the shape of ``(Q, top_n)``. | ||
Each element is withing the range ``(0, G - 1)``. | ||
gt_ids: Ids of the gallery items relevant to every query. Each element is withing the range ``(0, G - 1)``. | ||
It's only computed if ``labels_gt`` is provided. | ||
|
||
""" | ||
assert (ids_query.ndim == 1) and (ids_gallery.ndim == 1) and (embeddings.ndim == 2) | ||
assert len(embeddings) <= len(ids_query) + len(ids_gallery) | ||
assert (sequence_ids is None) or (len(sequence_ids) == len(embeddings) and (sequence_ids.ndim == 1)) | ||
assert (labels_gt is None) or (len(labels_gt) <= len(ids_query) + len(ids_gallery) and (labels_gt.ndim == 1)) | ||
Comment: or
Reply: agree |
||
|
||
top_n = min(top_n, len(ids_gallery)) | ||
|
||
emb_q = embeddings[ids_query] | ||
emb_g = embeddings[ids_gallery] | ||
|
||
nq = len(ids_query) | ||
retrieved_ids = LongTensor(nq, top_n) | ||
distances = FloatTensor(nq, top_n) | ||
gt_ids = [] | ||
|
||
# we do batching over first (queries) dimension | ||
for i in range(0, nq, bs): | ||
distances_b = pairwise_dist(x1=emb_q[i : i + bs], x2=emb_g) | ||
ids_query_b = ids_query[i : i + bs] | ||
|
||
# the logic behind: we want to ignore the item during search if it was used for both: query and gallery | ||
mask_to_ignore_b = ids_query_b[..., None] == ids_gallery[None, ...] | ||
if sequence_ids is not None: | ||
# sometimes our items may be packed into the groups, so we ignore other members of this group during search | ||
# more info in the docs: find for "Handling sequences of photos" | ||
mask_sequence = sequence_ids[ids_query_b][..., None] == sequence_ids[ids_gallery][None, ...] | ||
mask_to_ignore_b = torch.logical_or(mask_to_ignore_b, mask_sequence) | ||
|
||
if labels_gt is not None: | ||
mask_gt_b = labels_gt[ids_query_b][..., None] == labels_gt[ids_gallery][None, ...] | ||
mask_gt_b[mask_to_ignore_b] = False | ||
gt_ids.extend([LongTensor(row.nonzero()).view(-1) for row in mask_gt_b]) # type: ignore | ||
Comment: If
Reply: I think we cannot, because every query has an arbitrary number of gt, so we have a list of tensors as a result |
||
|
||
distances_b[mask_to_ignore_b] = float("inf") | ||
distances[i : i + bs], retrieved_ids[i : i + bs] = torch.topk(distances_b, k=top_n, largest=False, sorted=True) | ||
Comment: add the second dimension for clarity
Reply: okay |
||
|
||
return distances, retrieved_ids, gt_ids or None | ||
|
||
|
||
__all__ = ["batched_knn"] |
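The naming thread above comes down to what the function computes: an exact (not approximate) nearest-neighbor search, truncated to `top_n`, where a query never retrieves itself. Below is a minimal plain-Python sketch of that logic, not the PR's implementation. `knn_sketch` is a hypothetical name, `math.dist` stands in for `pairwise_dist`, and batching and `sequence_ids` handling are left out for brevity.

```python
import heapq
import math
from typing import List, Optional, Sequence, Tuple


def knn_sketch(
    embeddings: Sequence[Sequence[float]],
    ids_query: Sequence[int],
    ids_gallery: Sequence[int],
    top_n: int,
    labels_gt: Optional[Sequence[int]] = None,
) -> Tuple[List[List[float]], List[List[int]], Optional[List[List[int]]]]:
    # Hypothetical helper: exact top-n search, one query at a time.
    top_n = min(top_n, len(ids_gallery))
    distances, retrieved, gt_ids = [], [], []

    for q in ids_query:
        # Pairs of (distance, gallery position); an item used as both
        # query and gallery gets an infinite distance to itself.
        scored = [
            (math.inf if g == q else math.dist(embeddings[q], embeddings[g]), pos)
            for pos, g in enumerate(ids_gallery)
        ]
        best = heapq.nsmallest(top_n, scored)  # sorted, closest first
        distances.append([d for d, _ in best])
        retrieved.append([pos for _, pos in best])

        if labels_gt is not None:
            # Ragged by nature: each query may have a different number of matches.
            gt_ids.append(
                [pos for pos, g in enumerate(ids_gallery) if g != q and labels_gt[g] == labels_gt[q]]
            )

    return distances, retrieved, (gt_ids if labels_gt is not None else None)
```

As in the diff, the retrieved ids and ground-truth ids are positions within the gallery (range `0..G-1`), and the ground truths stay a list of variable-length lists rather than a single tensor.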
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
from typing import List | ||
|
||
import matplotlib.pyplot as plt | ||
import pandas as pd | ||
from torch import FloatTensor, LongTensor | ||
|
||
from oml.const import ( | ||
BLACK, | ||
BLUE, | ||
GRAY, | ||
GREEN, | ||
N_GT_SHOW_EMBEDDING_METRICS, | ||
RED, | ||
SEQUENCE_COLUMN, | ||
) | ||
from oml.functional.knn import batched_knn | ||
from oml.interfaces.datasets import ( | ||
IQueryGalleryDataset, | ||
IQueryGalleryLabeledDataset, | ||
IVisualizableDataset, | ||
) | ||
|
||
|
||
class RetrievalResults: | ||
def __init__( | ||
self, | ||
distances: FloatTensor, | ||
retrieved_ids: LongTensor, | ||
gt_ids: List[LongTensor] = None, | ||
): | ||
""" | ||
Args: | ||
distances: Sorted distances to the first ``top_n`` gallery items with the shape of ``[n_query, top_n]``. | ||
retrieved_ids: Top N gallery ids retrieved for every query with the shape of ``[n_query, top_n]``. | ||
Every element is within the range ``(0, n_gallery - 1)``. | ||
gt_ids: Gallery ids relevant to every query, list of ``n_query`` elements where every element may | ||
have an arbitrary length. Every element is within the range ``(0, n_gallery - 1)`` | ||
|
||
""" | ||
assert distances.shape == retrieved_ids.shape | ||
assert distances.ndim == 2 | ||
|
||
if gt_ids is not None: | ||
assert distances.shape[0] == len(gt_ids) | ||
if not all(len(x) > 0 for x in gt_ids): | ||
Comment: Wouldn't it be faster to evaluate
Reply: I think you are right |
||
raise RuntimeError("Every query must have at least one relevant gallery id.") | ||
|
||
self.distances = distances | ||
self.retrieved_ids = retrieved_ids | ||
self.gt_ids = gt_ids | ||
|
||
@property | ||
def n_retrieved_items(self) -> int: | ||
return self.retrieved_ids.shape[1] | ||
|
||
@classmethod | ||
def compute_from_embeddings( | ||
cls, | ||
embeddings: FloatTensor, | ||
dataset: IQueryGalleryDataset, | ||
n_items_to_retrieve: int = 1_000, | ||
Comment: I'd suggest setting the default value for
Reply: let it be in the middle, I set 100 :) |
||
) -> "RetrievalResults": | ||
""" | ||
Args: | ||
embeddings: The result of inference with the shape of ``[dataset_len, emb_dim]``. | ||
dataset: Dataset having query/gallery split. | ||
n_items_to_retrieve: Number of the closest gallery items to retrieve. | ||
It may be clipped by gallery size if needed. | ||
|
||
""" | ||
assert len(embeddings) == len(dataset), "Embeddings and dataset must have the same size." | ||
|
||
# todo 522: rework taking sequence | ||
if hasattr(dataset, "df") and SEQUENCE_COLUMN in dataset.df: | ||
dataset.df[SEQUENCE_COLUMN], _ = pd.factorize(dataset.df[SEQUENCE_COLUMN]) | ||
sequence_ids = LongTensor(dataset.df[SEQUENCE_COLUMN]) | ||
else: | ||
sequence_ids = None | ||
|
||
labels_gt = dataset.get_labels() if isinstance(dataset, IQueryGalleryLabeledDataset) else None | ||
|
||
distances, retrieved_ids, gt_ids = batched_knn( | ||
embeddings=embeddings, | ||
ids_query=dataset.get_query_ids(), | ||
ids_gallery=dataset.get_gallery_ids(), | ||
labels_gt=labels_gt, | ||
sequence_ids=sequence_ids, | ||
top_n=n_items_to_retrieve, | ||
) | ||
|
||
return RetrievalResults(distances=distances, retrieved_ids=retrieved_ids, gt_ids=gt_ids) | ||
|
||
def __repr__(self) -> str: | ||
Comment: To me it looks like a violation of the agreement about the purpose of
https://docs.python.org/3/reference/datamodel.html#object.__repr__ Maybe it would be better to use
Reply: done |
||
txt = ( | ||
f"You retrieved {self.n_retrieved_items} items.\n" | ||
f"Distances to the retrieved items:\n{self.distances}.\n" | ||
f"Ids of the retrieved gallery items:\n{self.retrieved_ids}.\n" | ||
) | ||
|
||
if self.gt_ids is None: | ||
txt += "Ground truths are unknown.\n" | ||
else: | ||
gt_ids_list = [x.tolist() for x in self.gt_ids] | ||
txt += f"Ground truth gallery ids are:\n{gt_ids_list}.\n" | ||
|
||
return txt | ||
|
||
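The Python data model page linked in the thread above describes `__repr__` as the "official" representation, ideally a valid expression that could recreate the object, and `__str__` as the informal, printable one. Here is a minimal sketch of that split, using a hypothetical `ResultsSketch` class rather than the PR's `RetrievalResults`:

```python
from typing import List


class ResultsSketch:
    """Hypothetical stand-in for a results container; not the PR's class."""

    def __init__(self, retrieved_ids: List[List[int]]) -> None:
        self.retrieved_ids = retrieved_ids

    def __repr__(self) -> str:
        # Unambiguous form aimed at developers; reads like a constructor call.
        return f"ResultsSketch(retrieved_ids={self.retrieved_ids!r})"

    def __str__(self) -> str:
        # Human-readable summary, used by print() and str().
        n_items = len(self.retrieved_ids[0]) if self.retrieved_ids else 0
        return f"You retrieved {n_items} items per query."


rr = ResultsSketch([[1, 2], [3, 1]])
print(rr)        # goes through __str__
print(repr(rr))  # goes through __repr__
```

When a class defines only `__repr__`, `print()` falls back to it, which is why a long multi-line summary in `__repr__` also shows up in debuggers and container dumps.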
def visualize( | ||
self, | ||
query_ids: List[int], | ||
dataset: IQueryGalleryDataset, | ||
n_galleries_to_show: int = 5, | ||
verbose: bool = False, | ||
) -> plt.Figure: | ||
""" | ||
Args: | ||
query_ids: Query indices within the range of ``(0, n_query - 1)``. | ||
dataset: Dataset that provides query-gallery split and supports visualisation. | ||
n_galleries_to_show: Number of closest gallery items to show. | ||
verbose: Set ``True`` to allow prints. | ||
|
||
""" | ||
if not isinstance(dataset, (IVisualizableDataset, IQueryGalleryDataset)): | ||
raise ValueError( | ||
Comment: TypeError
Reply: changed |
||
f"Dataset has to support {IVisualizableDataset.__name__} and " | ||
f"{IQueryGalleryDataset} interfaces. Got {type(dataset)}." | ||
Comment: Add
Reply: oh right |
||
) | ||
|
||
if verbose: | ||
print(f"Visualizing {n_galleries_to_show} for the following query ids: {query_ids}.") | ||
|
||
ii_query = dataset.get_query_ids() | ||
ii_gallery = dataset.get_gallery_ids() | ||
|
||
n_galleries_to_show = min(n_galleries_to_show, self.n_retrieved_items) | ||
n_gt_to_show = N_GT_SHOW_EMBEDDING_METRICS if (self.gt_ids is not None) else 0 | ||
Comment: What about adding this as an argument with the default
Reply: added |
||
|
||
fig = plt.figure(figsize=(16, 16 / (n_galleries_to_show + n_gt_to_show + 1) * len(query_ids))) | ||
n_rows, n_cols = len(query_ids), n_galleries_to_show + 1 + n_gt_to_show | ||
|
||
# iterate over queries | ||
for j, query_idx in enumerate(query_ids): | ||
Reply: in mine as well :) swapped |
||
|
||
plt.subplot(n_rows, n_cols, j * (n_galleries_to_show + 1 + n_gt_to_show) + 1) | ||
|
||
img = dataset.visualize(item=ii_query[query_idx].item(), color=BLUE) | ||
|
||
plt.imshow(img) | ||
plt.title(f"Query #{query_idx}") | ||
plt.axis("off") | ||
|
||
# iterate over retrieved items | ||
for i, ret_idx in enumerate(self.retrieved_ids[query_idx, :][:n_galleries_to_show]): | ||
if self.gt_ids is not None: | ||
color = GREEN if ret_idx in self.gt_ids[query_idx] else RED | ||
else: | ||
color = BLACK | ||
|
||
plt.subplot(n_rows, n_cols, j * (n_galleries_to_show + 1 + n_gt_to_show) + i + 2) | ||
img = dataset.visualize(item=ii_gallery[ret_idx].item(), color=color) | ||
|
||
plt.title(f"Gallery #{ret_idx} - {round(self.distances[query_idx, i].item(), 3)}") | ||
plt.imshow(img) | ||
plt.axis("off") | ||
|
||
if self.gt_ids is not None: | ||
|
||
for k, gt_idx in enumerate(self.gt_ids[query_idx][:n_gt_to_show]): | ||
plt.subplot( | ||
n_rows, n_cols, j * (n_galleries_to_show + 1 + n_gt_to_show) + k + n_galleries_to_show + 2 | ||
) | ||
|
||
img = dataset.visualize(item=ii_gallery[gt_idx].item(), color=GRAY) | ||
plt.title("GT") | ||
plt.imshow(img) | ||
plt.axis("off") | ||
|
||
fig.tight_layout() | ||
return fig | ||
|
||
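The three `plt.subplot` calls in `visualize` all use the same layout rule: one row per query, with the query image first, then the retrieved items, then the ground truths. `plt.subplot(n_rows, n_cols, index)` takes a 1-based, row-major index, so each expression reduces to `row * n_cols + col + 1`. A small sketch that makes this explicit follows; the helper names and the numbers are illustrative and not part of the PR.

```python
def subplot_index(row: int, col: int, n_cols: int) -> int:
    """Return the 1-based, row-major index that plt.subplot(n_rows, n_cols, index) expects."""
    return row * n_cols + col + 1


# Illustrative values, not taken from the PR.
n_galleries_to_show, n_gt_to_show = 5, 2
n_cols = 1 + n_galleries_to_show + n_gt_to_show  # query + retrieved + GT


def query_slot(j: int) -> int:
    # Column 0 of row j holds the query image.
    return subplot_index(j, 0, n_cols)


def retrieved_slot(j: int, i: int) -> int:
    # Columns 1..n_galleries_to_show hold the retrieved items.
    return subplot_index(j, 1 + i, n_cols)


def gt_slot(j: int, k: int) -> int:
    # The remaining columns hold the ground-truth items.
    return subplot_index(j, 1 + n_galleries_to_show + k, n_cols)
```

Naming the column offsets once makes it easier to see that the GT block starts right after the retrieved block and that the last GT of a row lands on the row's final cell.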
|
||
__all__ = ["RetrievalResults"] |
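One detail in the `visualize` guard may be worth a second look. `isinstance(obj, (A, B))` is true when `obj` is an instance of either type, while the error message asks for both interfaces. Whether that is intended is not settled in this thread. The sketch below contrasts the two checks using hypothetical stand-in classes (not the `oml` interfaces) and raises `TypeError`, as agreed in the review.

```python
class IVisualizable:
    """Hypothetical stand-in interface."""


class IQueryGallery:
    """Hypothetical stand-in interface."""


class OnlyVisualizable(IVisualizable):
    pass


class Both(IVisualizable, IQueryGallery):
    pass


def supports_any(obj: object) -> bool:
    # A tuple passed to isinstance means "instance of at least one of these".
    return isinstance(obj, (IVisualizable, IQueryGallery))


def supports_all(obj: object) -> bool:
    # Requiring every interface needs one isinstance call per type.
    return isinstance(obj, IVisualizable) and isinstance(obj, IQueryGallery)


def check_dataset(dataset: object) -> None:
    if not supports_all(dataset):
        raise TypeError(
            f"Dataset has to support {IVisualizable.__name__} and "
            f"{IQueryGallery.__name__} interfaces. Got {type(dataset).__name__}."
        )
```

With the tuple form, an object implementing only one interface passes the guard and fails later, at the first call to a missing method.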
Comment: Poor docs ((
Reply: it's annoying when you search the full repo and find some trash in the docs HTML files, so yep