Commit 0855c80
Introducing container for storing Retrieval Results and memory optimized kNN
1 parent 1040235 commit 0855c80

8 files changed: +443 −2 lines
Makefile (+8)

@@ -114,3 +114,11 @@ upload_to_pip: build_wheel
 .PHONY: pip_install_actual_oml
 pip_install_actual_oml:
 	pip install open-metric-learning==$(OML_VERSION)
+
+# ====================================== MISC =============================
+.PHONY: clean
+clean:
+	find . -type d -name "__pycache__" -exec rm -r {} +
+	find . -type f -name "*.log" -exec rm {} +
+	find . -type f -name "*.predictions.json" -exec rm {} +
+	rm -rf docs/build
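The new target is invoked as `make clean`; per the recipe above, it removes `__pycache__` directories, stray `*.log` and `*.predictions.json` files, and the built docs under `docs/build`.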

oml/const.py (+2)

@@ -53,6 +53,8 @@ def get_cache_folder() -> Path:
 BLACK = (0, 0, 0)
 PAD_COLOR = (255, 255, 255)
 
+BS_KNN = 5_000
+
 TCfg = Union[Dict[str, Any], DictConfig]
 
 # ImageNet Params
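(For scale: with fp32 distances, one `BS_KNN`-sized batch of 5,000 queries against a 100,000-item gallery occupies about 5,000 × 100,000 × 4 B ≈ 2 GB, whereas the full distance matrix for 100,000 queries would need roughly 40 GB — hence the batching in `batched_knn` below.)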

oml/datasets/images.py (+6 −2)

@@ -397,8 +397,9 @@ def __init__(
         is_gallery_key: str = IS_GALLERY_KEY,
     ):
         assert all(x in df.columns for x in (IS_QUERY_COLUMN, IS_GALLERY_COLUMN, PATHS_COLUMN))
+        self.df = df.copy()
+
         # instead of implementing the whole logic let's just re-use QGL dataset, but with dropped labels
-        df = df.copy()
         df[LABELS_COLUMN] = "fake_label"
 
         self.__dataset = ImageQueryGalleryLabeledDataset(
@@ -427,14 +428,17 @@ def __getitem__(self, item: int) -> Dict[str, Any]:
         del batch[self.__dataset.labels_key]
         return batch
 
+    def __len__(self) -> int:
+        return len(self.__dataset)
+
     def get_query_ids(self) -> LongTensor:
         return self.__dataset.get_query_ids()
 
     def get_gallery_ids(self) -> LongTensor:
         return self.__dataset.get_gallery_ids()
 
     def visualize(self, item: int, color: TColor = BLACK) -> np.ndarray:
-        return self.__dataset.visualize(item, color)
+        return self.__dataset.visualize(item=item, color=color)
 
 
 def get_retrieval_images_datasets(
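Note: the newly exposed `self.df` matches the `hasattr(dataset, "df")` check in `RetrievalResults.compute_from_embeddings` below, which reads the optional sequence column from it, and the new `__len__` is what allows the dataset size to be validated against the number of embeddings.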

oml/functional/knn.py (new file, +80)

from typing import List, Optional, Tuple

import torch
from torch import FloatTensor, LongTensor

from oml.const import BS_KNN
from oml.utils.misc_torch import pairwise_dist


def batched_knn(
    embeddings: FloatTensor,
    ids_query: LongTensor,
    ids_gallery: LongTensor,
    top_n: int,
    sequence_ids: Optional[LongTensor] = None,
    labels_gt: Optional[LongTensor] = None,
    bs: int = BS_KNN,
) -> Tuple[FloatTensor, LongTensor, Optional[List[LongTensor]]]:
    """

    Args:
        embeddings: Matrix with the shape of ``[L, dim]``.
        ids_query: Tensor with the size of ``Q``, where ``Q <= L``. Each element is within the range ``(0, L - 1)``.
        ids_gallery: Tensor with the size of ``G``, where ``G <= L``. Each element is within the range ``(0, L - 1)``.
                     May overlap with ``ids_query``.
        top_n: Number of neighbors to find and return.
        sequence_ids: Sequence identifiers with the size of ``L`` (if known).
        labels_gt: Ground truth labels of every element with the size of ``L`` (if known).
        bs: Batch size for computing distances to avoid OOM errors when processing the whole matrix at once.

    Returns:
        distances: Sorted distances from every query to the closest ``top_n`` galleries with the size of ``(Q, top_n)``.
        retrieved_ids: The corresponding ids of gallery items with the shape of ``(Q, top_n)``.
                       Each element is within the range ``(0, G - 1)``.
        gt_ids: Ids of the gallery items relevant to every query. Each element is within the range ``(0, G - 1)``.
                It's only computed if ``labels_gt`` is provided.

    """
    assert (ids_query.ndim == 1) and (ids_gallery.ndim == 1) and (embeddings.ndim == 2)
    assert len(embeddings) <= len(ids_query) + len(ids_gallery)
    assert (sequence_ids is None) or ((len(sequence_ids) == len(embeddings)) and (sequence_ids.ndim == 1))
    assert (labels_gt is None) or ((len(labels_gt) == embeddings.shape[0]) and (labels_gt.ndim == 1))

    top_n = min(top_n, len(ids_gallery))

    emb_q = embeddings[ids_query]
    emb_g = embeddings[ids_gallery]

    nq = len(ids_query)
    retrieved_ids = LongTensor(nq, top_n)
    distances = FloatTensor(nq, top_n)
    gt_ids = []

    # we do batching over the first (queries) dimension
    for i in range(0, nq, bs):
        distances_b = pairwise_dist(x1=emb_q[i : i + bs, :], x2=emb_g)
        ids_query_b = ids_query[i : i + bs]

        # the logic behind: we want to ignore an item during search if it was used as both query and gallery
        mask_to_ignore_b = ids_query_b[..., None] == ids_gallery[None, ...]
        if sequence_ids is not None:
            # sometimes our items may be packed into groups, so we ignore other members of a group during search;
            # more info in the docs: search for "Handling sequences of photos"
            mask_sequence = sequence_ids[ids_query_b][..., None] == sequence_ids[ids_gallery][None, ...]
            mask_to_ignore_b = torch.logical_or(mask_to_ignore_b, mask_sequence)

        if labels_gt is not None:
            mask_gt_b = labels_gt[ids_query_b][..., None] == labels_gt[ids_gallery][None, ...]
            mask_gt_b[mask_to_ignore_b] = False
            gt_ids.extend([LongTensor(row.nonzero()).view(-1) for row in mask_gt_b])  # type: ignore

        distances_b[mask_to_ignore_b] = float("inf")
        distances[i : i + bs, :], retrieved_ids[i : i + bs, :] = torch.topk(
            distances_b, k=top_n, largest=False, sorted=True
        )

    return distances, retrieved_ids, gt_ids or None


__all__ = ["batched_knn"]

oml/retrieval/retrieval_results.py (new file, +184)

from typing import List, Optional

import matplotlib.pyplot as plt
import pandas as pd
from torch import FloatTensor, LongTensor

from oml.const import (
    BLACK,
    BLUE,
    GRAY,
    GREEN,
    N_GT_SHOW_EMBEDDING_METRICS,
    RED,
    SEQUENCE_COLUMN,
)
from oml.functional.knn import batched_knn
from oml.interfaces.datasets import (
    IQueryGalleryDataset,
    IQueryGalleryLabeledDataset,
    IVisualizableDataset,
)


class RetrievalResults:
    def __init__(
        self,
        distances: FloatTensor,
        retrieved_ids: LongTensor,
        gt_ids: Optional[List[LongTensor]] = None,
    ):
        """
        Args:
            distances: Sorted distances to the first ``top_n`` gallery items with the shape of ``[n_query, top_n]``.
            retrieved_ids: Top N gallery ids retrieved for every query with the shape of ``[n_query, top_n]``.
                           Every element is within the range ``(0, n_gallery - 1)``.
            gt_ids: Gallery ids relevant to every query: a list of ``n_query`` elements where every element may
                    have an arbitrary length. Every element is within the range ``(0, n_gallery - 1)``.

        """
        assert distances.shape == retrieved_ids.shape
        assert distances.ndim == 2

        if gt_ids is not None:
            assert distances.shape[0] == len(gt_ids)
            if any(len(x) == 0 for x in gt_ids):
                raise RuntimeError("Every query must have at least one relevant gallery id.")

        self.distances = distances
        self.retrieved_ids = retrieved_ids
        self.gt_ids = gt_ids

    @property
    def n_retrieved_items(self) -> int:
        return self.retrieved_ids.shape[1]

    @classmethod
    def compute_from_embeddings(
        cls,
        embeddings: FloatTensor,
        dataset: IQueryGalleryDataset,
        n_items_to_retrieve: int = 100,
    ) -> "RetrievalResults":
        """
        Args:
            embeddings: The result of inference with the shape of ``[dataset_len, emb_dim]``.
            dataset: Dataset having a query/gallery split.
            n_items_to_retrieve: Number of the closest gallery items to retrieve.
                                 It may be clipped by the gallery size if needed.

        """
        assert len(embeddings) == len(dataset), "Embeddings and dataset must have the same size."

        # todo 522: rework taking sequence
        if hasattr(dataset, "df") and SEQUENCE_COLUMN in dataset.df:
            dataset.df[SEQUENCE_COLUMN], _ = pd.factorize(dataset.df[SEQUENCE_COLUMN])
            sequence_ids = LongTensor(dataset.df[SEQUENCE_COLUMN])
        else:
            sequence_ids = None

        labels_gt = dataset.get_labels() if isinstance(dataset, IQueryGalleryLabeledDataset) else None

        distances, retrieved_ids, gt_ids = batched_knn(
            embeddings=embeddings,
            ids_query=dataset.get_query_ids(),
            ids_gallery=dataset.get_gallery_ids(),
            labels_gt=labels_gt,
            sequence_ids=sequence_ids,
            top_n=n_items_to_retrieve,
        )

        return RetrievalResults(distances=distances, retrieved_ids=retrieved_ids, gt_ids=gt_ids)

    def __str__(self) -> str:
        txt = (
            f"You retrieved {self.n_retrieved_items} items.\n"
            f"Distances to the retrieved items:\n{self.distances}.\n"
            f"Ids of the retrieved gallery items:\n{self.retrieved_ids}.\n"
        )

        if self.gt_ids is None:
            txt += "Ground truths are unknown.\n"
        else:
            gt_ids_list = [x.tolist() for x in self.gt_ids]
            txt += f"Ground truth gallery ids are:\n{gt_ids_list}.\n"

        return txt

    def visualize(
        self,
        query_ids: List[int],
        dataset: IQueryGalleryDataset,
        n_galleries_to_show: int = 5,
        n_gt_to_show: int = N_GT_SHOW_EMBEDDING_METRICS,
        verbose: bool = False,
    ) -> plt.Figure:
        """
        Args:
            query_ids: Query indices within the range of ``(0, n_query - 1)``.
            dataset: Dataset that provides the query-gallery split and supports visualisation.
            n_galleries_to_show: Number of the closest gallery items to show.
            n_gt_to_show: Number of ground truth gallery items to show for reference (if available).
            verbose: Set ``True`` to allow prints.

        """
        if not isinstance(dataset, (IVisualizableDataset, IQueryGalleryDataset)):
            raise TypeError(
                f"Dataset has to support {IVisualizableDataset.__name__} and "
                f"{IQueryGalleryDataset.__name__} interfaces. Got {type(dataset)}."
            )

        if verbose:
            print(f"Visualizing {n_galleries_to_show} galleries for the following query ids: {query_ids}.")

        ii_query = dataset.get_query_ids()
        ii_gallery = dataset.get_gallery_ids()

        n_galleries_to_show = min(n_galleries_to_show, self.n_retrieved_items)
        n_gt_to_show = n_gt_to_show if (self.gt_ids is not None) else 0

        fig = plt.figure(figsize=(16, 16 / (n_galleries_to_show + n_gt_to_show + 1) * len(query_ids)))
        n_rows, n_cols = len(query_ids), n_galleries_to_show + 1 + n_gt_to_show

        # iterate over queries
        for i, query_idx in enumerate(query_ids):

            plt.subplot(n_rows, n_cols, i * (n_galleries_to_show + 1 + n_gt_to_show) + 1)

            img = dataset.visualize(item=ii_query[query_idx].item(), color=BLUE)

            plt.imshow(img)
            plt.title(f"Query #{query_idx}")
            plt.axis("off")

            # iterate over retrieved items
            for j, ret_idx in enumerate(self.retrieved_ids[query_idx, :][:n_galleries_to_show]):
                if self.gt_ids is not None:
                    color = GREEN if ret_idx in self.gt_ids[query_idx] else RED
                else:
                    color = BLACK

                plt.subplot(n_rows, n_cols, i * (n_galleries_to_show + 1 + n_gt_to_show) + j + 2)
                img = dataset.visualize(item=ii_gallery[ret_idx].item(), color=color)

                plt.title(f"Gallery #{ret_idx} - {round(self.distances[query_idx, j].item(), 3)}")
                plt.imshow(img)
                plt.axis("off")

            if self.gt_ids is not None:

                for k, gt_idx in enumerate(self.gt_ids[query_idx][:n_gt_to_show]):
                    plt.subplot(
                        n_rows, n_cols, i * (n_galleries_to_show + 1 + n_gt_to_show) + k + n_galleries_to_show + 2
                    )

                    img = dataset.visualize(item=ii_gallery[gt_idx].item(), color=GRAY)
                    plt.title("GT")
                    plt.imshow(img)
                    plt.axis("off")

        fig.tight_layout()
        return fig


__all__ = ["RetrievalResults"]
