
Introducing container for storing Retrieval Results #544

Merged · 6 commits · Apr 25, 2024
8 changes: 8 additions & 0 deletions Makefile
@@ -114,3 +114,11 @@ upload_to_pip: build_wheel
.PHONY: pip_install_actual_oml
pip_install_actual_oml:
	pip install open-metric-learning==$(OML_VERSION)

# ====================================== MISC ======================================
.PHONY: clean
clean:
	find . -type d -name "__pycache__" -exec rm -r {} +
	find . -type f -name "*.log" -exec rm {} +
	find . -type f -name "*.predictions.json" -exec rm {} +
	rm -rf docs/build
Collaborator:
Poor docs ((

Contributor Author:
It's annoying when you search the full repo and find some trash in the docs HTML files, so yep.

2 changes: 2 additions & 0 deletions oml/const.py
@@ -53,6 +53,8 @@ def get_cache_folder() -> Path:
BLACK = (0, 0, 0)
PAD_COLOR = (255, 255, 255)

BS_KNN = 5_000

TCfg = Union[Dict[str, Any], DictConfig]

# ImageNet Params
8 changes: 6 additions & 2 deletions oml/datasets/images.py
@@ -397,8 +397,9 @@ def __init__(
        is_gallery_key: str = IS_GALLERY_KEY,
    ):
        assert all(x in df.columns for x in (IS_QUERY_COLUMN, IS_GALLERY_COLUMN, PATHS_COLUMN))
        self.df = df.copy()

        # instead of implementing the whole logic, let's just re-use the QGL dataset, but with dropped labels
        df = df.copy()
        df[LABELS_COLUMN] = "fake_label"

        self.__dataset = ImageQueryGalleryLabeledDataset(
@@ -427,14 +428,17 @@ def __getitem__(self, item: int) -> Dict[str, Any]:
        del batch[self.__dataset.labels_key]
        return batch

    def __len__(self) -> int:
        return len(self.__dataset)

    def get_query_ids(self) -> LongTensor:
        return self.__dataset.get_query_ids()

    def get_gallery_ids(self) -> LongTensor:
        return self.__dataset.get_gallery_ids()

    def visualize(self, item: int, color: TColor = BLACK) -> np.ndarray:
        return self.__dataset.visualize(item, color)
        return self.__dataset.visualize(item=item, color=color)


def get_retrieval_images_datasets(
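Note: a minimal usage sketch of the dataset edited above (hedged: the class is assumed to be `ImageQueryGalleryDataset` from this module, and the exact constructor arguments may differ):

    import pandas as pd

    from oml.const import IS_GALLERY_COLUMN, IS_QUERY_COLUMN, PATHS_COLUMN
    from oml.datasets.images import ImageQueryGalleryDataset

    # a toy dataframe with the columns asserted in __init__ above
    df = pd.DataFrame(
        {
            PATHS_COLUMN: ["q.jpg", "g1.jpg", "g2.jpg"],
            IS_QUERY_COLUMN: [True, False, False],
            IS_GALLERY_COLUMN: [False, True, True],
        }
    )

    dataset = ImageQueryGalleryDataset(df)
    print(dataset.get_query_ids())    # indices of the queries within the dataset
    print(dataset.get_gallery_ids())  # indices of the galleries within the dataset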
80 changes: 80 additions & 0 deletions oml/functional/knn.py
@@ -0,0 +1,80 @@
from typing import List, Optional, Tuple

import torch
from torch import FloatTensor, LongTensor

from oml.const import BS_KNN
from oml.utils.misc_torch import pairwise_dist


def batched_knn(
Collaborator:
I think this func has nothing to do with kNN. It's an exact calculation of the distance matrix with truncation. However, I don't know what to call it...

Contributor Author:
it's the same as kNN from sklearn:

    from sklearn.neighbors import NearestNeighbors
    knn = NearestNeighbors(algorithm="auto", p=2)
    knn.fit(features_galleries)
    dists, ii_closest = knn.kneighbors(features_queries, n_neighbors=top_k, return_distance=True)

By the way, didn't you mean ANN (approximate NN)? But that's not what we are doing here.

    embeddings: FloatTensor,
    ids_query: LongTensor,
    ids_gallery: LongTensor,
    top_n: int,
    sequence_ids: Optional[LongTensor] = None,
    labels_gt: Optional[LongTensor] = None,
    bs: int = BS_KNN,
) -> Tuple[FloatTensor, LongTensor, Optional[List[LongTensor]]]:
"""

Args:
embeddings: Matrix with the shape of ``[L, dim]``
ids_query: Tensor with the size of ``Q``, where ``Q <= n``. Each element is within the range ``(0, L - 1)``.
ids_gallery: Tensor with the size of ``G`` where ``G <= n``. Each element is within the range ``(0, L - 1)``.
May overlap with ``ids_query``.
top_n: Number of neighbors to find and return.
sequence_ids: Sequence identifiers with the size of ``L`` (if known).
labels_gt: Ground truth labels of every element with the size of ``L`` (if known).
bs: Batch size for computing distances to avoid OOM errors when processing the whole matrix at once.

Returns:
distances: Sorted distances from every query to the closest ``top_n`` galleries with the size of ``(Q, top_n)``.
retrieved_ids: The corresponding ids of gallery items with the shape of ``(Q, top_n)``.
Each element is within the range ``(0, G - 1)``.
gt_ids: Ids of the gallery items relevant to every query. Each element is within the range ``(0, G - 1)``.
It's only computed if ``labels_gt`` is provided.

"""
    assert (ids_query.ndim == 1) and (ids_gallery.ndim == 1) and (embeddings.ndim == 2)
    assert len(embeddings) <= len(ids_query) + len(ids_gallery)
    assert (sequence_ids is None) or ((len(sequence_ids) == len(embeddings)) and (sequence_ids.ndim == 1))
    assert (labels_gt is None) or ((len(labels_gt) == embeddings.shape[0]) and (labels_gt.ndim == 1))

    top_n = min(top_n, len(ids_gallery))

    emb_q = embeddings[ids_query]
    emb_g = embeddings[ids_gallery]

    nq = len(ids_query)
    retrieved_ids = LongTensor(nq, top_n)
    distances = FloatTensor(nq, top_n)
    gt_ids = []

    # we batch over the first (queries) dimension
    for i in range(0, nq, bs):
        distances_b = pairwise_dist(x1=emb_q[i : i + bs, :], x2=emb_g)
        ids_query_b = ids_query[i : i + bs]

        # the logic behind it: we want to ignore an item during the search if it was used as both query and gallery
        mask_to_ignore_b = ids_query_b[..., None] == ids_gallery[None, ...]
        if sequence_ids is not None:
            # sometimes our items are packed into groups, so we ignore the other members of a group during the search;
            # more info in the docs: search for "Handling sequences of photos"
            mask_sequence = sequence_ids[ids_query_b][..., None] == sequence_ids[ids_gallery][None, ...]
            mask_to_ignore_b = torch.logical_or(mask_to_ignore_b, mask_sequence)

        if labels_gt is not None:
            mask_gt_b = labels_gt[ids_query_b][..., None] == labels_gt[ids_gallery][None, ...]
            mask_gt_b[mask_to_ignore_b] = False
            gt_ids.extend([LongTensor(row.nonzero()).view(-1) for row in mask_gt_b])  # type: ignore
Collaborator:
If labels_gt is not None we can allocate memory for gt_ids before the for-loop.

Contributor Author:
I think we cannot, because every query has an arbitrary number of GTs, so we end up with a list of tensors as a result.


        distances_b[mask_to_ignore_b] = float("inf")
        distances[i : i + bs, :], retrieved_ids[i : i + bs, :] = torch.topk(
            distances_b, k=top_n, largest=False, sorted=True
        )

    return distances, retrieved_ids, gt_ids or None


__all__ = ["batched_knn"]
184 changes: 184 additions & 0 deletions oml/retrieval/retrieval_results.py
@@ -0,0 +1,184 @@
from typing import List, Optional

import matplotlib.pyplot as plt
import pandas as pd
from torch import FloatTensor, LongTensor

from oml.const import (
    BLACK,
    BLUE,
    GRAY,
    GREEN,
    N_GT_SHOW_EMBEDDING_METRICS,
    RED,
    SEQUENCE_COLUMN,
)
from oml.functional.knn import batched_knn
from oml.interfaces.datasets import (
    IQueryGalleryDataset,
    IQueryGalleryLabeledDataset,
    IVisualizableDataset,
)


class RetrievalResults:
    def __init__(
        self,
        distances: FloatTensor,
        retrieved_ids: LongTensor,
        gt_ids: Optional[List[LongTensor]] = None,
    ):
        """
        Args:
            distances: Sorted distances to the first ``top_n`` gallery items with the shape of ``[n_query, top_n]``.
            retrieved_ids: Top N gallery ids retrieved for every query with the shape of ``[n_query, top_n]``.
                Every element is within the range ``(0, n_gallery - 1)``.
            gt_ids: Gallery ids relevant to every query, list of ``n_query`` elements where every element may
                have an arbitrary length. Every element is within the range ``(0, n_gallery - 1)``.

        """
        assert distances.shape == retrieved_ids.shape
        assert distances.ndim == 2

        if gt_ids is not None:
            assert distances.shape[0] == len(gt_ids)
            if any(len(x) == 0 for x in gt_ids):
                raise RuntimeError("Every query must have at least one relevant gallery id.")

        self.distances = distances
        self.retrieved_ids = retrieved_ids
        self.gt_ids = gt_ids

    @property
    def n_retrieved_items(self) -> int:
        return self.retrieved_ids.shape[1]

    @classmethod
    def compute_from_embeddings(
        cls,
        embeddings: FloatTensor,
        dataset: IQueryGalleryDataset,
        n_items_to_retrieve: int = 100,
    ) -> "RetrievalResults":
        """
        Args:
            embeddings: The result of inference with the shape of ``[dataset_len, emb_dim]``.
            dataset: Dataset having query/gallery split.
            n_items_to_retrieve: Number of the closest gallery items to retrieve.
                It may be clipped by gallery size if needed.

        """
        assert len(embeddings) == len(dataset), "Embeddings and dataset must have the same size."

        # todo 522: rework taking sequence
        if hasattr(dataset, "df") and SEQUENCE_COLUMN in dataset.df:
            dataset.df[SEQUENCE_COLUMN], _ = pd.factorize(dataset.df[SEQUENCE_COLUMN])
            sequence_ids = LongTensor(dataset.df[SEQUENCE_COLUMN])
        else:
            sequence_ids = None

        labels_gt = dataset.get_labels() if isinstance(dataset, IQueryGalleryLabeledDataset) else None

        distances, retrieved_ids, gt_ids = batched_knn(
            embeddings=embeddings,
            ids_query=dataset.get_query_ids(),
            ids_gallery=dataset.get_gallery_ids(),
            labels_gt=labels_gt,
            sequence_ids=sequence_ids,
            top_n=n_items_to_retrieve,
        )

        return RetrievalResults(distances=distances, retrieved_ids=retrieved_ids, gt_ids=gt_ids)

    def __str__(self) -> str:
        txt = (
            f"You retrieved {self.n_retrieved_items} items.\n"
            f"Distances to the retrieved items:\n{self.distances}.\n"
            f"Ids of the retrieved gallery items:\n{self.retrieved_ids}.\n"
        )

        if self.gt_ids is None:
            txt += "Ground truths are unknown.\n"
        else:
            gt_ids_list = [x.tolist() for x in self.gt_ids]
            txt += f"Ground truth gallery ids are:\n{gt_ids_list}.\n"

        return txt

    def visualize(
        self,
        query_ids: List[int],
        dataset: IQueryGalleryDataset,
        n_galleries_to_show: int = 5,
        n_gt_to_show: int = N_GT_SHOW_EMBEDDING_METRICS,
        verbose: bool = False,
    ) -> plt.Figure:
        """
        Args:
            query_ids: Query indices within the range of ``(0, n_query - 1)``.
            dataset: Dataset that provides query-gallery split and supports visualisation.
            n_galleries_to_show: Number of the closest gallery items to show.
            n_gt_to_show: Number of ground truth gallery items to show for reference (if available).
            verbose: Set ``True`` to allow prints.

        """
        if not (isinstance(dataset, IVisualizableDataset) and isinstance(dataset, IQueryGalleryDataset)):
            raise TypeError(
                f"Dataset has to support {IVisualizableDataset.__name__} and "
                f"{IQueryGalleryDataset.__name__} interfaces. Got {type(dataset)}."
            )

        if verbose:
            print(f"Visualizing {n_galleries_to_show} retrieved items for the following query ids: {query_ids}.")

        ii_query = dataset.get_query_ids()
        ii_gallery = dataset.get_gallery_ids()

        n_galleries_to_show = min(n_galleries_to_show, self.n_retrieved_items)
        n_gt_to_show = n_gt_to_show if (self.gt_ids is not None) else 0

        fig = plt.figure(figsize=(16, 16 / (n_galleries_to_show + n_gt_to_show + 1) * len(query_ids)))
        n_rows, n_cols = len(query_ids), n_galleries_to_show + 1 + n_gt_to_show

        # iterate over queries
        for i, query_idx in enumerate(query_ids):

            plt.subplot(n_rows, n_cols, i * (n_galleries_to_show + 1 + n_gt_to_show) + 1)

            img = dataset.visualize(item=ii_query[query_idx].item(), color=BLUE)

            plt.imshow(img)
            plt.title(f"Query #{query_idx}")
            plt.axis("off")

            # iterate over the retrieved items
            for j, ret_idx in enumerate(self.retrieved_ids[query_idx, :][:n_galleries_to_show]):
                if self.gt_ids is not None:
                    color = GREEN if ret_idx in self.gt_ids[query_idx] else RED
                else:
                    color = BLACK

                plt.subplot(n_rows, n_cols, i * (n_galleries_to_show + 1 + n_gt_to_show) + j + 2)
                img = dataset.visualize(item=ii_gallery[ret_idx].item(), color=color)

                plt.title(f"Gallery #{ret_idx} - {round(self.distances[query_idx, j].item(), 3)}")
                plt.imshow(img)
                plt.axis("off")

            if self.gt_ids is not None:

                for k, gt_idx in enumerate(self.gt_ids[query_idx][:n_gt_to_show]):
                    plt.subplot(
                        n_rows, n_cols, i * (n_galleries_to_show + 1 + n_gt_to_show) + k + n_galleries_to_show + 2
                    )

                    img = dataset.visualize(item=ii_gallery[gt_idx].item(), color=GRAY)
                    plt.title("GT")
                    plt.imshow(img)
                    plt.axis("off")

        fig.tight_layout()
        return fig


__all__ = ["RetrievalResults"]