Introducing container for storing Retrieval Results #544

Merged: 6 commits, Apr 25, 2024
Changes from 5 commits
8 changes: 8 additions & 0 deletions Makefile
@@ -114,3 +114,11 @@ upload_to_pip: build_wheel
.PHONY: pip_install_actual_oml
pip_install_actual_oml:
pip install open-metric-learning==$(OML_VERSION)

# ====================================== MISC =============================
.PHONY: clean
clean:
find . -type d -name "__pycache__" -exec rm -r {} +
find . -type f -name "*.log" -exec rm {} +
find . -type f -name "*.predictions.json" -exec rm {} +
rm -rf docs/build
Collaborator:
Poor docs ((

Contributor Author:
it's annoying when you search the whole repo and find some trash in the docs' HTML files,
so yep
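
For readers without make, the new clean target boils down to roughly this Python sketch (an illustrative equivalent only; the glob patterns come from the target above):

    # illustrative Python equivalent of the `clean` target above
    import shutil
    from pathlib import Path

    for cache_dir in Path(".").rglob("__pycache__"):
        shutil.rmtree(cache_dir, ignore_errors=True)

    for pattern in ("*.log", "*.predictions.json"):
        for file in Path(".").rglob(pattern):
            file.unlink(missing_ok=True)

    shutil.rmtree("docs/build", ignore_errors=True)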

2 changes: 2 additions & 0 deletions oml/const.py
@@ -53,6 +53,8 @@ def get_cache_folder() -> Path:
BLACK = (0, 0, 0)
PAD_COLOR = (255, 255, 255)

BS_KNN = 5_000

TCfg = Union[Dict[str, Any], DictConfig]

# ImageNet Params
8 changes: 6 additions & 2 deletions oml/datasets/images.py
@@ -397,8 +397,9 @@ def __init__(
is_gallery_key: str = IS_GALLERY_KEY,
):
assert all(x in df.columns for x in (IS_QUERY_COLUMN, IS_GALLERY_COLUMN, PATHS_COLUMN))
self.df = df.copy()

# instead of implementing the whole logic, let's just re-use the QGL dataset, but with dropped labels
df = df.copy()
df[LABELS_COLUMN] = "fake_label"

self.__dataset = ImageQueryGalleryLabeledDataset(
@@ -427,14 +428,17 @@ def __getitem__(self, item: int) -> Dict[str, Any]:
del batch[self.__dataset.labels_key]
return batch

def __len__(self) -> int:
return len(self.__dataset)

def get_query_ids(self) -> LongTensor:
return self.__dataset.get_query_ids()

def get_gallery_ids(self) -> LongTensor:
return self.__dataset.get_gallery_ids()

def visualize(self, item: int, color: TColor = BLACK) -> np.ndarray:
return self.__dataset.visualize(item, color)
return self.__dataset.visualize(item=item, color=color)


def get_retrieval_images_datasets(
78 changes: 78 additions & 0 deletions oml/functional/knn.py
@@ -0,0 +1,78 @@
from typing import List, Optional, Tuple

import torch
from torch import FloatTensor, LongTensor

from oml.const import BS_KNN
from oml.utils.misc_torch import pairwise_dist


def batched_knn(
Collaborator:
I think this func has nothing to do with KNN. It's an exact calculation of the distance matrix with truncation. However, I don't know what else to call it...

Contributor Author:
it's the same as kNN from sklearn:

    from sklearn.neighbors import NearestNeighbors
    knn = NearestNeighbors(algorithm="auto", p=2)
    knn.fit(features_galleries)
    dists, ii_closest = knn.kneighbors(features_queries, n_neighbors=top_k, return_distance=True)

By the way, didn't you mean ANN (approximate NN)? But that's not what we are doing here.

embeddings: FloatTensor,
ids_query: LongTensor,
ids_gallery: LongTensor,
top_n: int,
sequence_ids: Optional[LongTensor] = None,
labels_gt: Optional[LongTensor] = None,
bs: int = BS_KNN,
) -> Tuple[FloatTensor, LongTensor, Optional[List[LongTensor]]]:
"""

Args:
embeddings: Matrix with the shape of ``[n, dim]``
Collaborator:
Usage of n here and in top_n is confusing. Maybe L (referring to len) instead of lowercase n would be better.

Contributor Author:
agree

ids_query: Tensor with the size of ``Q``, where ``Q <= n``. Each element is withing the range ``(0, n - 1)``.
Collaborator:
typo: withing
navigate over the repo to find the same errors

Contributor Author:
done, removed everywhere

ids_gallery: Tensor with the size of ``G`` where ``G <= n``. Each element is withing the range ``(0, n - 1)``.
May overlap with ``ids_query``.
top_n: Number of neighbors to find and return.
sequence_ids: Sequence identifiers with the size of ``n`` (if known).
labels_gt: Ground truth labels of every element with the size of ``n`` (if known).
bs: Batch size for computing distances to avoid OOM errors when processing the whole matrix at once.

Returns:
distances: Sorted distances from every query to the closest ``top_n`` galleries with the size of ``(Q, top_n)``.
retrieved_ids: The corresponding ids of gallery items with the shape of ``(Q, top_n)``.
Each element is withing the range ``(0, G - 1)``.
gt_ids: Ids of the gallery items relevant to every query. Each element is withing the range ``(0, G - 1)``.
It's only computed if ``labels_gt`` is provided.

"""
assert (ids_query.ndim == 1) and (ids_gallery.ndim == 1) and (embeddings.ndim == 2)
assert len(embeddings) <= len(ids_query) + len(ids_gallery)
assert (sequence_ids is None) or (len(sequence_ids) == len(embeddings) and (sequence_ids.ndim == 1))
assert (labels_gt is None) or (len(labels_gt) <= len(ids_query) + len(ids_gallery) and (labels_gt.ndim == 1))
Collaborator:
or len(labels_gt) == embeddings.shape[0]?

Contributor Author:
agree


top_n = min(top_n, len(ids_gallery))

emb_q = embeddings[ids_query]
emb_g = embeddings[ids_gallery]

nq = len(ids_query)
retrieved_ids = LongTensor(nq, top_n)
distances = FloatTensor(nq, top_n)
gt_ids = []

# we do batching over first (queries) dimension
for i in range(0, nq, bs):
distances_b = pairwise_dist(x1=emb_q[i : i + bs], x2=emb_g)
ids_query_b = ids_query[i : i + bs]

# the logic behind it: we want to ignore an item during search if it is used as both query and gallery
mask_to_ignore_b = ids_query_b[..., None] == ids_gallery[None, ...]
if sequence_ids is not None:
# sometimes our items may be packed into groups, so we ignore other members of the group during search
# more info in the docs: search for "Handling sequences of photos"
mask_sequence = sequence_ids[ids_query_b][..., None] == sequence_ids[ids_gallery][None, ...]
mask_to_ignore_b = torch.logical_or(mask_to_ignore_b, mask_sequence)

if labels_gt is not None:
mask_gt_b = labels_gt[ids_query_b][..., None] == labels_gt[ids_gallery][None, ...]
mask_gt_b[mask_to_ignore_b] = False
gt_ids.extend([LongTensor(row.nonzero()).view(-1) for row in mask_gt_b]) # type: ignore
Collaborator:
If labels_gt is not None we can allocate memory for gt_ids before the for-loop.

Contributor Author:
i think we cannot, because every query has an arbitrary number of gt, so we have a list of tensors as a result
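
To illustrate the point (a hypothetical example, not part of the diff): with three queries the ground truth is ragged, so there is no fixed second dimension to preallocate:

    # hypothetical ragged ground truth: one tensor of relevant gallery ids per query
    from torch import LongTensor

    gt_ids = [LongTensor([0, 3]), LongTensor([5]), LongTensor([1, 2, 7])]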


distances_b[mask_to_ignore_b] = float("inf")
distances[i : i + bs], retrieved_ids[i : i + bs] = torch.topk(distances_b, k=top_n, largest=False, sorted=True)
Collaborator:
add the second dimension for clarity

Contributor Author:
okay


return distances, retrieved_ids, gt_ids or None


__all__ = ["batched_knn"]
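
A minimal usage sketch of batched_knn (values and shapes are illustrative, assuming the signature above; items 0-1 serve as queries and items 2-5 as galleries):

    import torch

    from oml.functional.knn import batched_knn

    embeddings = torch.randn(6, 16)  # 6 items in total, 16-dim embeddings
    distances, retrieved_ids, gt_ids = batched_knn(
        embeddings=embeddings,
        ids_query=torch.tensor([0, 1]),
        ids_gallery=torch.tensor([2, 3, 4, 5]),
        labels_gt=torch.tensor([0, 1, 0, 1, 0, 1]),
        top_n=3,
    )
    # distances and retrieved_ids have the shape [2, 3];
    # retrieved_ids index into the gallery, not into the whole dataset;
    # gt_ids is a list of 2 tensors with the relevant gallery positions per query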
182 changes: 182 additions & 0 deletions oml/retrieval/retrieval_results.py
@@ -0,0 +1,182 @@
from typing import List, Optional

import matplotlib.pyplot as plt
import pandas as pd
from torch import FloatTensor, LongTensor

from oml.const import (
BLACK,
BLUE,
GRAY,
GREEN,
N_GT_SHOW_EMBEDDING_METRICS,
RED,
SEQUENCE_COLUMN,
)
from oml.functional.knn import batched_knn
from oml.interfaces.datasets import (
IQueryGalleryDataset,
IQueryGalleryLabeledDataset,
IVisualizableDataset,
)


class RetrievalResults:
def __init__(
self,
distances: FloatTensor,
retrieved_ids: LongTensor,
gt_ids: Optional[List[LongTensor]] = None,
):
"""
Args:
distances: Sorted distances to the first ``top_n`` gallery items with the shape of ``[n_query, top_n]``.
retrieved_ids: Top N gallery ids retrieved for every query with the shape of ``[n_query, top_n]``.
Every element is within the range ``(0, n_gallery - 1)``.
gt_ids: Gallery ids relevant to every query, list of ``n_query`` elements where every element may
have an arbitrary length. Every element is within the range ``(0, n_gallery - 1)``

"""
assert distances.shape == retrieved_ids.shape
assert distances.ndim == 2

if gt_ids is not None:
assert distances.shape[0] == len(gt_ids)
if not all(len(x) > 0 for x in gt_ids):
Collaborator:
Wouldn't it be faster to evaluate any(len(x) == 0 for x in gt_ids)?

Contributor Author:
i think you are right

raise RuntimeError("Every query must have at least one relevant gallery id.")

self.distances = distances
self.retrieved_ids = retrieved_ids
self.gt_ids = gt_ids

@property
def n_retrieved_items(self) -> int:
return self.retrieved_ids.shape[1]

@classmethod
def compute_from_embeddings(
cls,
embeddings: FloatTensor,
dataset: IQueryGalleryDataset,
n_items_to_retrieve: int = 1_000,
Collaborator:
I'd suggest setting the default value of n_items_to_retrieve to 10 or even 5. 1000 looks too big.

Contributor Author:
let it be in the middle, I set 100 :)

) -> "RetrievalResults":
"""
Args:
embeddings: The result of inference with the shape of ``[dataset_len, emb_dim]``.
dataset: Dataset having query/gallery split.
n_items_to_retrieve: Number of the closest gallery items to retrieve.
It may be clipped by gallery size if needed.

"""
assert len(embeddings) == len(dataset), "Embeddings and dataset must have the same size."

# todo 522: rework taking sequence
if hasattr(dataset, "df") and SEQUENCE_COLUMN in dataset.df:
dataset.df[SEQUENCE_COLUMN], _ = pd.factorize(dataset.df[SEQUENCE_COLUMN])
sequence_ids = LongTensor(dataset.df[SEQUENCE_COLUMN])
else:
sequence_ids = None

labels_gt = dataset.get_labels() if isinstance(dataset, IQueryGalleryLabeledDataset) else None

distances, retrieved_ids, gt_ids = batched_knn(
embeddings=embeddings,
ids_query=dataset.get_query_ids(),
ids_gallery=dataset.get_gallery_ids(),
labels_gt=labels_gt,
sequence_ids=sequence_ids,
top_n=n_items_to_retrieve,
)

return RetrievalResults(distances=distances, retrieved_ids=retrieved_ids, gt_ids=gt_ids)

def __repr__(self) -> str:
Collaborator:
For me it looks like a violation of the agreement about the purpose of the __repr__ function.

Called by the repr() built-in function to compute the “official” string representation of an object. If at all possible, this should look like a valid Python expression that could be used to recreate an object with the same value (given an appropriate environment). If this is not possible, a string of the form <...some useful description...> should be returned.

https://docs.python.org/3/reference/datamodel.html#object.__repr__

Maybe it would be better to use __str__ for this message?

Contributor Author:
done

txt = (
f"You retrieved {self.n_retrieved_items} items.\n"
f"Distances to the retrieved items:\n{self.distances}.\n"
f"Ids of the retrieved gallery items:\n{self.retrieved_ids}.\n"
)

if self.gt_ids is None:
txt += "Ground truths are unknown.\n"
else:
gt_ids_list = [x.tolist() for x in self.gt_ids]
txt += f"Ground truth gallery ids are:\n{gt_ids_list}.\n"

return txt

def visualize(
self,
query_ids: List[int],
dataset: IQueryGalleryDataset,
n_galleries_to_show: int = 5,
verbose: bool = False,
) -> plt.Figure:
"""
Args:
query_ids: Query indices within the range of ``(0, n_query - 1)``.
dataset: Dataset that provides query-gallery split and supports visualisation.
n_galleries_to_show: Number of closest gallery items to show.
verbose: Set ``True`` to allow prints.

"""
if not (isinstance(dataset, IVisualizableDataset) and isinstance(dataset, IQueryGalleryDataset)):
raise ValueError(
Collaborator:
TypeError

Contributor Author:
changed

f"Dataset has to support {IVisualizableDataset.__name__} and "
f"{IQueryGalleryDataset} interfaces. Got {type(dataset)}."
Collaborator:
Add .__name__ as for {IVisualizableDataset.__name__}?

Contributor Author:
oh right

)

if verbose:
print(f"Visualizing {n_galleries_to_show} for the following query ids: {query_ids}.")

ii_query = dataset.get_query_ids()
ii_gallery = dataset.get_gallery_ids()

n_galleries_to_show = min(n_galleries_to_show, self.n_retrieved_items)
n_gt_to_show = N_GT_SHOW_EMBEDDING_METRICS if (self.gt_ids is not None) else 0
Collaborator:
What about adding this as an argument with the default N_GT_SHOW_EMBEDDING_METRICS? 2 is OK for experiment logging, but might not be enough for development.

Contributor Author:
added


fig = plt.figure(figsize=(16, 16 / (n_galleries_to_show + n_gt_to_show + 1) * len(query_ids)))
n_rows, n_cols = len(query_ids), n_galleries_to_show + 1 + n_gt_to_show

# iterate over queries
for j, query_idx in enumerate(query_ids):
Collaborator:
j for rows and i for cols, like [j, i]? My entire life it has usually been [i, j] 😁

Contributor Author:
in mine as well :) swapped


plt.subplot(n_rows, n_cols, j * (n_galleries_to_show + 1 + n_gt_to_show) + 1)

img = dataset.visualize(item=ii_query[query_idx].item(), color=BLUE)

plt.imshow(img)
plt.title(f"Query #{query_idx}")
plt.axis("off")

# iterate over retrieved items
for i, ret_idx in enumerate(self.retrieved_ids[query_idx, :][:n_galleries_to_show]):
if self.gt_ids is not None:
color = GREEN if ret_idx in self.gt_ids[query_idx] else RED
else:
color = BLACK

plt.subplot(n_rows, n_cols, j * (n_galleries_to_show + 1 + n_gt_to_show) + i + 2)
img = dataset.visualize(item=ii_gallery[ret_idx].item(), color=color)

plt.title(f"Gallery #{ret_idx} - {round(self.distances[query_idx, i].item(), 3)}")
plt.imshow(img)
plt.axis("off")

if self.gt_ids is not None:

for k, gt_idx in enumerate(self.gt_ids[query_idx][:n_gt_to_show]):
plt.subplot(
n_rows, n_cols, j * (n_galleries_to_show + 1 + n_gt_to_show) + k + n_galleries_to_show + 2
)

img = dataset.visualize(item=ii_gallery[gt_idx].item(), color=GRAY)
plt.title("GT")
plt.imshow(img)
plt.axis("off")

fig.tight_layout()
return fig


__all__ = ["RetrievalResults"]
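
A self-contained sketch of the container itself (illustrative values; in practice one would usually call RetrievalResults.compute_from_embeddings(embeddings, dataset) as defined above):

    from torch import FloatTensor, LongTensor

    from oml.retrieval.retrieval_results import RetrievalResults

    distances = FloatTensor([[0.1, 0.5], [0.2, 0.9]])  # [n_query=2, top_n=2], sorted per row
    retrieved_ids = LongTensor([[3, 0], [1, 2]])  # gallery positions per query
    gt_ids = [LongTensor([3]), LongTensor([1, 2])]  # relevant galleries, ragged per query

    rr = RetrievalResults(distances=distances, retrieved_ids=retrieved_ids, gt_ids=gt_ids)
    print(rr.n_retrieved_items)  # 2
    print(rr)  # human-readable summary of distances, retrieved ids and ground truths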