-
Notifications
You must be signed in to change notification settings - Fork 68
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introducing container for storing Retrieval Results #544
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from typing import List, Optional, Tuple | ||
|
||
import torch | ||
from torch import FloatTensor, LongTensor | ||
|
||
from oml.const import BS_KNN | ||
from oml.utils.misc_torch import pairwise_dist | ||
|
||
|
||
def batched_knn(
    embeddings: FloatTensor,
    ids_query: LongTensor,
    ids_gallery: LongTensor,
    top_n: int,
    sequence_ids: Optional[LongTensor] = None,
    labels_gt: Optional[LongTensor] = None,
    bs: int = BS_KNN,
) -> Tuple[FloatTensor, LongTensor, Optional[List[LongTensor]]]:
    """
    Find the ``top_n`` closest gallery items for every query.

    This is an exact search (in the spirit of kNN from sklearn), not an approximate one:
    distances from a batch of queries to the whole gallery are computed with ``pairwise_dist``
    and then truncated to the ``top_n`` smallest values. Batching over queries only limits
    the peak memory, it does not change the result.

    Args:
        embeddings: Matrix with the shape of ``[L, dim]``
        ids_query: Tensor with the size of ``Q``, where ``Q <= L``. Each element is within the range ``(0, L - 1)``.
        ids_gallery: Tensor with the size of ``G`` where ``G <= L``. Each element is within the range ``(0, L - 1)``.
                     May overlap with ``ids_query``.
        top_n: Number of neighbors to find and return. It is clipped by the gallery size ``G``.
        sequence_ids: Sequence identifiers with the size of ``L`` (if known).
        labels_gt: Ground truth labels of every element with the size of ``L`` (if known).
        bs: Batch size for computing distances to avoid OOM errors when processing the whole matrix at once.

    Returns:
        distances: Sorted distances from every query to the closest ``top_n`` galleries with the size of ``(Q, top_n)``.
                   Ignored gallery items (see below) get the distance of ``inf``.
        retrieved_ids: The corresponding ids of gallery items with the shape of ``(Q, top_n)``.
                       Each element is within the range ``(0, G - 1)``.
        gt_ids: Ids of the gallery items relevant to every query. Each element is within the range ``(0, G - 1)``.
                It's a list of ``Q`` tensors because every query may have an arbitrary number of ground truths.
                It's only computed if ``labels_gt`` is provided, otherwise ``None`` is returned.

    """
    assert (ids_query.ndim == 1) and (ids_gallery.ndim == 1) and (embeddings.ndim == 2)
    assert len(embeddings) <= len(ids_query) + len(ids_gallery)
    assert (sequence_ids is None) or ((len(sequence_ids) == len(embeddings)) and (sequence_ids.ndim == 1))
    assert (labels_gt is None) or ((len(labels_gt) == embeddings.shape[0]) and (labels_gt.ndim == 1))

    top_n = min(top_n, len(ids_gallery))

    emb_q = embeddings[ids_query]
    emb_g = embeddings[ids_gallery]

    # gallery-side lookups do not depend on the batch, so we do them once
    sequence_ids_g = sequence_ids[ids_gallery] if sequence_ids is not None else None
    labels_g = labels_gt[ids_gallery] if labels_gt is not None else None

    nq = len(ids_query)
    # uninitialised buffers, every row is filled inside the loop below
    retrieved_ids = LongTensor(nq, top_n)
    distances = FloatTensor(nq, top_n)
    gt_ids: List[LongTensor] = []

    # we do batching over first (queries) dimension
    for i in range(0, nq, bs):
        distances_b = pairwise_dist(x1=emb_q[i : i + bs, :], x2=emb_g)
        ids_query_b = ids_query[i : i + bs]

        # the logic behind: we want to ignore the item during search if it was used for both: query and gallery
        mask_to_ignore_b = ids_query_b[..., None] == ids_gallery[None, ...]
        if sequence_ids_g is not None:
            # sometimes our items may be packed into the groups, so we ignore other members of this group during search
            # more info in the docs: find for "Handling sequences of photos"
            mask_sequence = sequence_ids[ids_query_b][..., None] == sequence_ids_g[None, ...]  # type: ignore
            mask_to_ignore_b = torch.logical_or(mask_to_ignore_b, mask_sequence)

        if labels_g is not None:
            mask_gt_b = labels_gt[ids_query_b][..., None] == labels_g[None, ...]  # type: ignore
            mask_gt_b[mask_to_ignore_b] = False
            # every query has an arbitrary number of ground truths, so the result cannot be packed
            # into a single rectangular tensor: we keep one 1D tensor of gallery ids per query
            gt_ids.extend([row.nonzero().view(-1) for row in mask_gt_b])  # type: ignore

        # ignored items are pushed to the very end of the ranking
        distances_b[mask_to_ignore_b] = float("inf")
        distances[i : i + bs, :], retrieved_ids[i : i + bs, :] = torch.topk(
            distances_b, k=top_n, largest=False, sorted=True
        )

    # the presence of ground truths is defined by the input, not by the list being non-empty
    return distances, retrieved_ids, (gt_ids if labels_gt is not None else None)
|
||
|
||
__all__ = ["batched_knn"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
from typing import List | ||
|
||
import matplotlib.pyplot as plt | ||
import pandas as pd | ||
from torch import FloatTensor, LongTensor | ||
|
||
from oml.const import ( | ||
BLACK, | ||
BLUE, | ||
GRAY, | ||
GREEN, | ||
N_GT_SHOW_EMBEDDING_METRICS, | ||
RED, | ||
SEQUENCE_COLUMN, | ||
) | ||
from oml.functional.knn import batched_knn | ||
from oml.interfaces.datasets import ( | ||
IQueryGalleryDataset, | ||
IQueryGalleryLabeledDataset, | ||
IVisualizableDataset, | ||
) | ||
|
||
|
||
class RetrievalResults:
    """
    Container for storing retrieval results: sorted distances from queries to their closest
    gallery items, ids of these gallery items and (optionally) ids of the relevant gallery items.

    """

    def __init__(
        self,
        distances: FloatTensor,
        retrieved_ids: LongTensor,
        gt_ids: List[LongTensor] = None,
    ):
        """
        Args:
            distances: Sorted distances to the first ``top_n`` gallery items with the shape of ``[n_query, top_n]``.
            retrieved_ids: Top N gallery ids retrieved for every query with the shape of ``[n_query, top_n]``.
                Every element is within the range ``(0, n_gallery - 1)``.
            gt_ids: Gallery ids relevant to every query, list of ``n_query`` elements where every element may
                have an arbitrary length. Every element is within the range ``(0, n_gallery - 1)``.
                Pass ``None`` if ground truths are unknown.

        Raises:
            RuntimeError: If ``gt_ids`` is provided and some query has no relevant gallery ids.

        """
        assert distances.shape == retrieved_ids.shape
        assert distances.ndim == 2

        if gt_ids is not None:
            assert distances.shape[0] == len(gt_ids)
            if any(len(x) == 0 for x in gt_ids):
                raise RuntimeError("Every query must have at least one relevant gallery id.")

        self.distances = distances
        self.retrieved_ids = retrieved_ids
        self.gt_ids = gt_ids

    @property
    def n_retrieved_items(self) -> int:
        """Number of gallery items stored for every query (the second dimension of ``retrieved_ids``)."""
        return self.retrieved_ids.shape[1]

    @classmethod
    def compute_from_embeddings(
        cls,
        embeddings: FloatTensor,
        dataset: IQueryGalleryDataset,
        n_items_to_retrieve: int = 100,
    ) -> "RetrievalResults":
        """
        Args:
            embeddings: The result of inference with the shape of ``[dataset_len, emb_dim]``.
            dataset: Dataset having query/gallery split.
            n_items_to_retrieve: Number of the closest gallery items to retrieve.
                It may be clipped by gallery size if needed.

        Returns:
            Retrieval results built on top of ``batched_knn`` output. Ground truths are only filled
            if the dataset is an instance of ``IQueryGalleryLabeledDataset``.

        """
        assert len(embeddings) == len(dataset), "Embeddings and dataset must have the same size."

        # todo 522: rework taking sequence
        if hasattr(dataset, "df") and SEQUENCE_COLUMN in dataset.df:
            # we only need integer codes locally, so we don't overwrite the column in the dataset's dataframe
            sequence_codes, _ = pd.factorize(dataset.df[SEQUENCE_COLUMN])
            sequence_ids = LongTensor(sequence_codes.tolist())
        else:
            sequence_ids = None

        labels_gt = dataset.get_labels() if isinstance(dataset, IQueryGalleryLabeledDataset) else None

        distances, retrieved_ids, gt_ids = batched_knn(
            embeddings=embeddings,
            ids_query=dataset.get_query_ids(),
            ids_gallery=dataset.get_gallery_ids(),
            labels_gt=labels_gt,
            sequence_ids=sequence_ids,
            top_n=n_items_to_retrieve,
        )

        # ``cls`` instead of the hard-coded class name, so subclasses get instances of themselves
        return cls(distances=distances, retrieved_ids=retrieved_ids, gt_ids=gt_ids)

    def __str__(self) -> str:
        """Human readable summary: number of retrieved items, distances, retrieved ids and ground truths."""
        txt = (
            f"You retrieved {self.n_retrieved_items} items.\n"
            f"Distances to the retrieved items:\n{self.distances}.\n"
            f"Ids of the retrieved gallery items:\n{self.retrieved_ids}.\n"
        )

        if self.gt_ids is None:
            txt += "Ground truths are unknown.\n"
        else:
            gt_ids_list = [x.tolist() for x in self.gt_ids]
            txt += f"Ground truth gallery ids are:\n{gt_ids_list}.\n"

        return txt

    def visualize(
        self,
        query_ids: List[int],
        dataset: IQueryGalleryDataset,
        n_galleries_to_show: int = 5,
        n_gt_to_show: int = N_GT_SHOW_EMBEDDING_METRICS,
        verbose: bool = False,
    ) -> plt.Figure:
        """
        Draw one row per query: the query itself, its closest gallery items and, if ground truths
        are known, several relevant gallery items for reference.

        Args:
            query_ids: Query indices within the range of ``(0, n_query - 1)``.
            dataset: Dataset that provides query-gallery split and supports visualisation.
            n_galleries_to_show: Number of closest gallery items to show.
                It is clipped by the number of retrieved items.
            n_gt_to_show: Number of ground truth gallery items to show for reference (if available).
            verbose: Set ``True`` to allow prints.

        Returns:
            Figure with ``len(query_ids)`` rows.

        Raises:
            TypeError: If ``dataset`` does not support both required interfaces.

        """
        # ``isinstance`` with a tuple means "any of them", but we need both interfaces here:
        # the split is used to map ids and ``dataset.visualize()`` is used to draw items
        if not (isinstance(dataset, IVisualizableDataset) and isinstance(dataset, IQueryGalleryDataset)):
            raise TypeError(
                f"Dataset has to support {IVisualizableDataset.__name__} and "
                f"{IQueryGalleryDataset.__name__} interfaces. Got {type(dataset)}."
            )

        if verbose:
            print(f"Visualizing {n_galleries_to_show} for the following query ids: {query_ids}.")

        # mappings from query / gallery indices to the indices of items in the dataset
        ii_query = dataset.get_query_ids()
        ii_gallery = dataset.get_gallery_ids()

        n_galleries_to_show = min(n_galleries_to_show, self.n_retrieved_items)
        n_gt_to_show = n_gt_to_show if (self.gt_ids is not None) else 0

        # row layout: [query | retrieved gallery items | ground truths]
        n_rows, n_cols = len(query_ids), n_galleries_to_show + 1 + n_gt_to_show
        fig = plt.figure(figsize=(16, 16 / n_cols * len(query_ids)))

        # iterate over queries
        for i, query_idx in enumerate(query_ids):
            # number of subplot cells before the current row (subplot indices start from 1)
            row_offset = i * n_cols

            plt.subplot(n_rows, n_cols, row_offset + 1)

            img = dataset.visualize(item=ii_query[query_idx].item(), color=BLUE)

            plt.imshow(img)
            plt.title(f"Query #{query_idx}")
            plt.axis("off")

            # iterate over retrieved items
            for j, ret_idx in enumerate(self.retrieved_ids[query_idx, :][:n_galleries_to_show]):
                # green / red frame for relevant / irrelevant items, black one if ground truths are unknown
                if self.gt_ids is not None:
                    color = GREEN if ret_idx in self.gt_ids[query_idx] else RED
                else:
                    color = BLACK

                plt.subplot(n_rows, n_cols, row_offset + j + 2)
                img = dataset.visualize(item=ii_gallery[ret_idx].item(), color=color)

                plt.title(f"Gallery #{ret_idx} - {round(self.distances[query_idx, j].item(), 3)}")
                plt.imshow(img)
                plt.axis("off")

            if self.gt_ids is not None:

                # ground truths go right after the retrieved items
                for k, gt_idx in enumerate(self.gt_ids[query_idx][:n_gt_to_show]):
                    plt.subplot(n_rows, n_cols, row_offset + k + n_galleries_to_show + 2)

                    img = dataset.visualize(item=ii_gallery[gt_idx].item(), color=GRAY)
                    plt.title("GT")
                    plt.imshow(img)
                    plt.axis("off")

        fig.tight_layout()
        return fig
|
||
|
||
__all__ = ["RetrievalResults"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Poor docs ((
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's annoying when you search the whole repo and find some trash from the docs HTML files,
so yep