-
Notifications
You must be signed in to change notification settings - Fork 68
/
Copy pathretrieval_results.py
182 lines (146 loc) · 6.47 KB
/
retrieval_results.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
from typing import List
import matplotlib.pyplot as plt
import pandas as pd
from torch import FloatTensor, LongTensor
from oml.const import (
BLACK,
BLUE,
GRAY,
GREEN,
N_GT_SHOW_EMBEDDING_METRICS,
RED,
SEQUENCE_COLUMN,
)
from oml.functional.knn import batched_knn
from oml.interfaces.datasets import (
IQueryGalleryDataset,
IQueryGalleryLabeledDataset,
IVisualizableDataset,
)
class RetrievalResults:
    """
    Container for the outcome of a retrieval run.

    For every query it keeps the distances to the retrieved gallery items, the ids of these items and,
    when they are known, the ids of the gallery items relevant to the query (ground truths).
    """

    def __init__(
        self,
        distances: FloatTensor,
        retrieved_ids: LongTensor,
        gt_ids: List[LongTensor] = None,
    ):
        """
        Args:
            distances: Sorted distances to the first ``top_n`` gallery items with the shape of ``[n_query, top_n]``.
            retrieved_ids: Top N gallery ids retrieved for every query with the shape of ``[n_query, top_n]``.
                Every element is within the range ``(0, n_gallery - 1)``.
            gt_ids: Gallery ids relevant to every query, list of ``n_query`` elements where every element may
                have an arbitrary length. Every element is within the range ``(0, n_gallery - 1)``.
                Pass ``None`` (default) if ground truths are unknown.

        Raises:
            AssertionError: If ``distances`` and ``retrieved_ids`` differ in shape, are not 2-dimensional,
                or if the number of elements in ``gt_ids`` differs from the number of queries.
            RuntimeError: If ``gt_ids`` is provided and some query has no relevant gallery ids.

        """
        assert distances.shape == retrieved_ids.shape
        assert distances.ndim == 2

        if gt_ids is not None:
            assert distances.shape[0] == len(gt_ids)

            if not all(len(x) > 0 for x in gt_ids):
                raise RuntimeError("Every query must have at least one relevant gallery id.")

        self.distances = distances
        self.retrieved_ids = retrieved_ids
        self.gt_ids = gt_ids

    @property
    def n_retrieved_items(self) -> int:
        """Number of gallery items stored for every query (the second dimension of ``retrieved_ids``)."""
        return self.retrieved_ids.shape[1]

    @classmethod
    def compute_from_embeddings(
        cls,
        embeddings: FloatTensor,
        dataset: IQueryGalleryDataset,
        n_items_to_retrieve: int = 1_000,
    ) -> "RetrievalResults":
        """
        Build retrieval results by running ``batched_knn`` on the embeddings of the dataset's items.

        Args:
            embeddings: The result of inference with the shape of ``[dataset_len, emb_dim]``.
            dataset: Dataset having query/gallery split.
            n_items_to_retrieve: Number of the closest gallery items to retrieve.
                It may be clipped by gallery size if needed.

        Returns:
            A new instance of the class this method is called on.

        Raises:
            AssertionError: If the number of embeddings differs from the dataset length.

        """
        assert len(embeddings) == len(dataset), "Embeddings and dataset must have the same size."

        # todo 522: rework taking sequence
        if hasattr(dataset, "df") and SEQUENCE_COLUMN in dataset.df:
            # Encode sequences as integer codes locally: the caller's dataframe must stay untouched,
            # so the codes are not written back into ``dataset.df``.
            sequence_codes, _ = pd.factorize(dataset.df[SEQUENCE_COLUMN])
            sequence_ids = LongTensor(sequence_codes.tolist())
        else:
            sequence_ids = None

        # Labels are only available for labeled datasets; without them ground truths stay unknown.
        labels_gt = dataset.get_labels() if isinstance(dataset, IQueryGalleryLabeledDataset) else None

        distances, retrieved_ids, gt_ids = batched_knn(
            embeddings=embeddings,
            ids_query=dataset.get_query_ids(),
            ids_gallery=dataset.get_gallery_ids(),
            labels_gt=labels_gt,
            sequence_ids=sequence_ids,
            top_n=n_items_to_retrieve,
        )

        # Use ``cls`` so that subclasses get an instance of their own type.
        return cls(distances=distances, retrieved_ids=retrieved_ids, gt_ids=gt_ids)

    def __repr__(self) -> str:
        """Return a human-readable summary: distances, retrieved ids and ground truths (if known)."""
        txt = (
            f"You retrieved {self.n_retrieved_items} items.\n"
            f"Distances to the retrieved items:\n{self.distances}.\n"
            f"Ids of the retrieved gallery items:\n{self.retrieved_ids}.\n"
        )

        if self.gt_ids is None:
            txt += "Ground truths are unknown.\n"
        else:
            gt_ids_list = [x.tolist() for x in self.gt_ids]
            txt += f"Ground truth gallery ids are:\n{gt_ids_list}.\n"

        return txt

    def visualize(
        self,
        query_ids: List[int],
        dataset: IQueryGalleryDataset,
        n_galleries_to_show: int = 5,
        verbose: bool = False,
    ) -> plt.Figure:
        """
        Draw one row per query: the query itself, its closest gallery items and, if ground truths
        are known, up to ``N_GT_SHOW_EMBEDDING_METRICS`` of the relevant gallery items.

        Args:
            query_ids: Query indices within the range of ``(0, n_query - 1)``.
            dataset: Dataset that provides query-gallery split and supports visualisation.
            n_galleries_to_show: Number of closest gallery items to show.
                It is clipped by the number of retrieved items.
            verbose: Set ``True`` to allow prints.

        Returns:
            The created matplotlib figure.

        Raises:
            ValueError: If the dataset does not implement both the visualisation and
                the query/gallery interfaces.

        """
        # Both interfaces are required: ``isinstance`` with a tuple would accept a dataset that
        # implements only one of them.
        if not (isinstance(dataset, IVisualizableDataset) and isinstance(dataset, IQueryGalleryDataset)):
            raise ValueError(
                f"Dataset has to support {IVisualizableDataset.__name__} and "
                f"{IQueryGalleryDataset.__name__} interfaces. Got {type(dataset)}."
            )

        if verbose:
            print(f"Visualizing {n_galleries_to_show} for the following query ids: {query_ids}.")

        ii_query = dataset.get_query_ids()
        ii_gallery = dataset.get_gallery_ids()

        n_galleries_to_show = min(n_galleries_to_show, self.n_retrieved_items)
        n_gt_to_show = N_GT_SHOW_EMBEDDING_METRICS if (self.gt_ids is not None) else 0

        # Row layout: 1 query cell, then the retrieved items, then the ground truth cells.
        n_rows, n_cols = len(query_ids), n_galleries_to_show + 1 + n_gt_to_show

        fig = plt.figure(figsize=(16, 16 / n_cols * n_rows))

        # iterate over queries
        for j, query_idx in enumerate(query_ids):
            # Subplot indices are 1-based, so cells of the row ``j`` start at ``row_offset + 1``.
            row_offset = j * n_cols

            plt.subplot(n_rows, n_cols, row_offset + 1)

            img = dataset.visualize(item=ii_query[query_idx].item(), color=BLUE)

            plt.imshow(img)
            plt.title(f"Query #{query_idx}")
            plt.axis("off")

            # iterate over retrieved items
            for i, ret_idx in enumerate(self.retrieved_ids[query_idx, :n_galleries_to_show]):
                if self.gt_ids is not None:
                    # Relevant retrieved items are highlighted with green, irrelevant ones with red.
                    color = GREEN if ret_idx in self.gt_ids[query_idx] else RED
                else:
                    color = BLACK

                plt.subplot(n_rows, n_cols, row_offset + i + 2)
                img = dataset.visualize(item=ii_gallery[ret_idx].item(), color=color)

                plt.title(f"Gallery #{ret_idx} - {round(self.distances[query_idx, i].item(), 3)}")
                plt.imshow(img)
                plt.axis("off")

            if self.gt_ids is not None:
                # Ground truths go right after the retrieved items in the same row.
                for k, gt_idx in enumerate(self.gt_ids[query_idx][:n_gt_to_show]):
                    plt.subplot(n_rows, n_cols, row_offset + k + n_galleries_to_show + 2)

                    img = dataset.visualize(item=ii_gallery[gt_idx].item(), color=GRAY)
                    plt.title("GT")
                    plt.imshow(img)
                    plt.axis("off")

        fig.tight_layout()
        return fig
# Names exported by ``from <this module> import *``.
__all__ = ["RetrievalResults"]