diff --git a/Makefile b/Makefile
index 9d6c16a1a..39ecc2b2c 100644
--- a/Makefile
+++ b/Makefile
@@ -114,3 +114,11 @@ upload_to_pip: build_wheel
 .PHONY: pip_install_actual_oml
 pip_install_actual_oml:
 	pip install open-metric-learning==$(OML_VERSION)
+
+# ====================================== MISC =============================
+.PHONY: clean
+clean:
+	find . -type d -name "__pycache__" -exec rm -r {} +
+	find . -type f -name "*.log" -exec rm {} +
+	find . -type f -name "*.predictions.json" -exec rm {} +
+	rm -rf docs/build
diff --git a/oml/const.py b/oml/const.py
index 1890c8d94..907b0ca0f 100644
--- a/oml/const.py
+++ b/oml/const.py
@@ -53,6 +53,8 @@ def get_cache_folder() -> Path:
 BLACK = (0, 0, 0)
 PAD_COLOR = (255, 255, 255)
 
+BS_KNN = 5_000
+
 TCfg = Union[Dict[str, Any], DictConfig]
 
 # ImageNet Params
diff --git a/oml/datasets/images.py b/oml/datasets/images.py
index 8828b66fc..28b16fe12 100644
--- a/oml/datasets/images.py
+++ b/oml/datasets/images.py
@@ -397,8 +397,9 @@ def __init__(
         is_gallery_key: str = IS_GALLERY_KEY,
     ):
         assert all(x in df.columns for x in (IS_QUERY_COLUMN, IS_GALLERY_COLUMN, PATHS_COLUMN))
+        self.df = df.copy()
+
         # instead of implementing the whole logic let's just re-use QGL dataset, but with dropped labels
-        df = df.copy()
         df[LABELS_COLUMN] = "fake_label"
 
         self.__dataset = ImageQueryGalleryLabeledDataset(
@@ -427,6 +428,9 @@ def __getitem__(self, item: int) -> Dict[str, Any]:
         del batch[self.__dataset.labels_key]
         return batch
 
+    def __len__(self) -> int:
+        return len(self.__dataset)
+
     def get_query_ids(self) -> LongTensor:
         return self.__dataset.get_query_ids()
 
@@ -434,7 +438,7 @@ def get_gallery_ids(self) -> LongTensor:
         return self.__dataset.get_gallery_ids()
 
     def visualize(self, item: int, color: TColor = BLACK) -> np.ndarray:
-        return self.__dataset.visualize(item, color)
+        return self.__dataset.visualize(item=item, color=color)
 
 
 def get_retrieval_images_datasets(
diff --git a/oml/functional/knn.py b/oml/functional/knn.py
new file mode 100644
index 000000000..785a7c2f9
--- /dev/null
+++ b/oml/functional/knn.py
@@ -0,0 +1,80 @@
+from typing import List, Optional, Tuple
+
+import torch
+from torch import FloatTensor, LongTensor
+
+from oml.const import BS_KNN
+from oml.utils.misc_torch import pairwise_dist
+
+
+def batched_knn(
+    embeddings: FloatTensor,
+    ids_query: LongTensor,
+    ids_gallery: LongTensor,
+    top_n: int,
+    sequence_ids: Optional[LongTensor] = None,
+    labels_gt: Optional[LongTensor] = None,
+    bs: int = BS_KNN,
+) -> Tuple[FloatTensor, LongTensor, Optional[List[LongTensor]]]:
+    """
+
+    Args:
+        embeddings: Matrix with the shape of ``[L, dim]``
+        ids_query: Tensor with the size of ``Q``, where ``Q <= L``. Each element is within the range ``(0, L - 1)``.
+        ids_gallery: Tensor with the size of ``G``, where ``G <= L``. Each element is within the range ``(0, L - 1)``.
+            May overlap with ``ids_query``.
+        top_n: Number of neighbors to find and return.
+        sequence_ids: Sequence identifiers with the size of ``L`` (if known).
+        labels_gt: Ground truth labels of every element with the size of ``L`` (if known).
+        bs: Batch size for computing distances to avoid OOM errors when processing the whole matrix at once.
+
+    Returns:
+        distances: Sorted distances from every query to the closest ``top_n`` galleries with the size of ``(Q, top_n)``.
+        retrieved_ids: The corresponding ids of gallery items with the shape of ``(Q, top_n)``.
+            Each element is within the range ``(0, G - 1)``.
+        gt_ids: Ids of the gallery items relevant to every query.
+            Each element is within the range ``(0, G - 1)``. It's only computed if ``labels_gt`` is provided.
+
+    """
+    assert (ids_query.ndim == 1) and (ids_gallery.ndim == 1) and (embeddings.ndim == 2)
+    assert len(embeddings) <= len(ids_query) + len(ids_gallery)
+    assert (sequence_ids is None) or ((len(sequence_ids) == len(embeddings)) and (sequence_ids.ndim == 1))
+    assert (labels_gt is None) or ((len(labels_gt) == embeddings.shape[0]) and (labels_gt.ndim == 1))
+
+    top_n = min(top_n, len(ids_gallery))
+
+    emb_q = embeddings[ids_query]
+    emb_g = embeddings[ids_gallery]
+
+    nq = len(ids_query)
+    retrieved_ids = LongTensor(nq, top_n)
+    distances = FloatTensor(nq, top_n)
+    gt_ids = []
+
+    # we do batching over the first (queries) dimension
+    for i in range(0, nq, bs):
+        distances_b = pairwise_dist(x1=emb_q[i : i + bs, :], x2=emb_g)
+        ids_query_b = ids_query[i : i + bs]
+
+        # we want to ignore an item during the search if it's used as both a query and a gallery item
+        mask_to_ignore_b = ids_query_b[..., None] == ids_gallery[None, ...]
+        if sequence_ids is not None:
+            # items may be packed into groups, so we also ignore other members of the same group during the search;
+            # for more info search the docs for "Handling sequences of photos"
+            mask_sequence = sequence_ids[ids_query_b][..., None] == sequence_ids[ids_gallery][None, ...]
+            mask_to_ignore_b = torch.logical_or(mask_to_ignore_b, mask_sequence)
+
+        if labels_gt is not None:
+            mask_gt_b = labels_gt[ids_query_b][..., None] == labels_gt[ids_gallery][None, ...]
+            mask_gt_b[mask_to_ignore_b] = False
+            gt_ids.extend([LongTensor(row.nonzero()).view(-1) for row in mask_gt_b])  # type: ignore
+
+        distances_b[mask_to_ignore_b] = float("inf")
+        distances[i : i + bs, :], retrieved_ids[i : i + bs, :] = torch.topk(
+            distances_b, k=top_n, largest=False, sorted=True
+        )
+
+    return distances, retrieved_ids, gt_ids or None
+
+
+__all__ = ["batched_knn"]
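As a quick illustration of how the new ``batched_knn`` is meant to be called, here is a minimal sketch with made-up tensors and sizes (not part of the patch):

    import torch
    from torch import LongTensor

    from oml.functional.knn import batched_knn

    embeddings = torch.randn(10, 8)  # 10 items in total: the first 4 are queries, the rest are galleries
    ids_query = LongTensor([0, 1, 2, 3])
    ids_gallery = LongTensor([4, 5, 6, 7, 8, 9])

    distances, retrieved_ids, gt_ids = batched_knn(
        embeddings=embeddings,
        ids_query=ids_query,
        ids_gallery=ids_gallery,
        top_n=3,
        bs=2,  # tiny batch size just to exercise the batching loop; it defaults to BS_KNN
    )

    print(distances.shape)      # torch.Size([4, 3])
    print(retrieved_ids.shape)  # torch.Size([4, 3])
    print(gt_ids)               # None, since labels_gt was not passed

Because distances are computed for at most ``bs`` queries at a time, the full ``Q x G`` distance matrix is never materialised at once, which is the motivation for the new ``BS_KNN`` constant in ``oml/const.py``.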
diff --git a/oml/retrieval/retrieval_results.py b/oml/retrieval/retrieval_results.py
new file mode 100644
index 000000000..cfbe633db
--- /dev/null
+++ b/oml/retrieval/retrieval_results.py
@@ -0,0 +1,184 @@
+from typing import List, Optional
+
+import matplotlib.pyplot as plt
+import pandas as pd
+from torch import FloatTensor, LongTensor
+
+from oml.const import (
+    BLACK,
+    BLUE,
+    GRAY,
+    GREEN,
+    N_GT_SHOW_EMBEDDING_METRICS,
+    RED,
+    SEQUENCE_COLUMN,
+)
+from oml.functional.knn import batched_knn
+from oml.interfaces.datasets import (
+    IQueryGalleryDataset,
+    IQueryGalleryLabeledDataset,
+    IVisualizableDataset,
+)
+
+
+class RetrievalResults:
+    def __init__(
+        self,
+        distances: FloatTensor,
+        retrieved_ids: LongTensor,
+        gt_ids: Optional[List[LongTensor]] = None,
+    ):
+        """
+        Args:
+            distances: Sorted distances to the first ``top_n`` gallery items with the shape of ``[n_query, top_n]``.
+            retrieved_ids: Top N gallery ids retrieved for every query with the shape of ``[n_query, top_n]``.
+                Every element is within the range ``(0, n_gallery - 1)``.
+            gt_ids: Gallery ids relevant to every query, list of ``n_query`` elements where every element may
+                have an arbitrary length. Every element is within the range ``(0, n_gallery - 1)``.
+
+        """
+        assert distances.shape == retrieved_ids.shape
+        assert distances.ndim == 2
+
+        if gt_ids is not None:
+            assert distances.shape[0] == len(gt_ids)
+            if any(len(x) == 0 for x in gt_ids):
+                raise RuntimeError("Every query must have at least one relevant gallery id.")
+
+        self.distances = distances
+        self.retrieved_ids = retrieved_ids
+        self.gt_ids = gt_ids
+
+    @property
+    def n_retrieved_items(self) -> int:
+        return self.retrieved_ids.shape[1]
+
+    @classmethod
+    def compute_from_embeddings(
+        cls,
+        embeddings: FloatTensor,
+        dataset: IQueryGalleryDataset,
+        n_items_to_retrieve: int = 100,
+    ) -> "RetrievalResults":
+        """
+        Args:
+            embeddings: The result of inference with the shape of ``[dataset_len, emb_dim]``.
+            dataset: Dataset having query/gallery split.
+            n_items_to_retrieve: Number of the closest gallery items to retrieve.
+                It may be clipped by the gallery size if needed.
+
+        """
+        assert len(embeddings) == len(dataset), "Embeddings and dataset must have the same size."
+
+        # todo 522: rework taking sequence
+        if hasattr(dataset, "df") and SEQUENCE_COLUMN in dataset.df:
+            dataset.df[SEQUENCE_COLUMN], _ = pd.factorize(dataset.df[SEQUENCE_COLUMN])
+            sequence_ids = LongTensor(dataset.df[SEQUENCE_COLUMN])
+        else:
+            sequence_ids = None
+
+        labels_gt = dataset.get_labels() if isinstance(dataset, IQueryGalleryLabeledDataset) else None
+
+        distances, retrieved_ids, gt_ids = batched_knn(
+            embeddings=embeddings,
+            ids_query=dataset.get_query_ids(),
+            ids_gallery=dataset.get_gallery_ids(),
+            labels_gt=labels_gt,
+            sequence_ids=sequence_ids,
+            top_n=n_items_to_retrieve,
+        )
+
+        return RetrievalResults(distances=distances, retrieved_ids=retrieved_ids, gt_ids=gt_ids)
+
+    def __str__(self) -> str:
+        txt = (
+            f"You retrieved {self.n_retrieved_items} items.\n"
+            f"Distances to the retrieved items:\n{self.distances}.\n"
+            f"Ids of the retrieved gallery items:\n{self.retrieved_ids}.\n"
+        )
+
+        if self.gt_ids is None:
+            txt += "Ground truths are unknown.\n"
+        else:
+            gt_ids_list = [x.tolist() for x in self.gt_ids]
+            txt += f"Ground truth gallery ids are:\n{gt_ids_list}.\n"
+
+        return txt
+
+    def visualize(
+        self,
+        query_ids: List[int],
+        dataset: IQueryGalleryDataset,
+        n_galleries_to_show: int = 5,
+        n_gt_to_show: int = N_GT_SHOW_EMBEDDING_METRICS,
+        verbose: bool = False,
+    ) -> plt.Figure:
+        """
+        Args:
+            query_ids: Query indices within the range of ``(0, n_query - 1)``.
+            dataset: Dataset that provides query-gallery split and supports visualisation.
+            n_galleries_to_show: Number of the closest gallery items to show.
+            n_gt_to_show: Number of ground truth gallery items to show for reference (if available).
+            verbose: Set ``True`` to allow prints.
+
+        """
+        if not (isinstance(dataset, IVisualizableDataset) and isinstance(dataset, IQueryGalleryDataset)):
+            raise TypeError(
+                f"Dataset has to support {IVisualizableDataset.__name__} and "
+                f"{IQueryGalleryDataset.__name__} interfaces. Got {type(dataset)}."
+            )
+
+        if verbose:
+            print(f"Visualizing {n_galleries_to_show} galleries for the following query ids: {query_ids}.")
+
+        ii_query = dataset.get_query_ids()
+        ii_gallery = dataset.get_gallery_ids()
+
+        n_galleries_to_show = min(n_galleries_to_show, self.n_retrieved_items)
+        n_gt_to_show = n_gt_to_show if (self.gt_ids is not None) else 0
+
+        fig = plt.figure(figsize=(16, 16 / (n_galleries_to_show + n_gt_to_show + 1) * len(query_ids)))
+        n_rows, n_cols = len(query_ids), n_galleries_to_show + 1 + n_gt_to_show
+
+        # iterate over queries
+        for i, query_idx in enumerate(query_ids):
+
+            plt.subplot(n_rows, n_cols, i * (n_galleries_to_show + 1 + n_gt_to_show) + 1)
+
+            img = dataset.visualize(item=ii_query[query_idx].item(), color=BLUE)
+
+            plt.imshow(img)
+            plt.title(f"Query #{query_idx}")
+            plt.axis("off")
+
+            # iterate over retrieved items
+            for j, ret_idx in enumerate(self.retrieved_ids[query_idx, :][:n_galleries_to_show]):
+                if self.gt_ids is not None:
+                    color = GREEN if ret_idx in self.gt_ids[query_idx] else RED
+                else:
+                    color = BLACK
+
+                plt.subplot(n_rows, n_cols, i * (n_galleries_to_show + 1 + n_gt_to_show) + j + 2)
+                img = dataset.visualize(item=ii_gallery[ret_idx].item(), color=color)
+
+                plt.title(f"Gallery #{ret_idx} - {round(self.distances[query_idx, j].item(), 3)}")
+                plt.imshow(img)
+                plt.axis("off")
+
+            if self.gt_ids is not None:
+
+                for k, gt_idx in enumerate(self.gt_ids[query_idx][:n_gt_to_show]):
+                    plt.subplot(
+                        n_rows, n_cols, i * (n_galleries_to_show + 1 + n_gt_to_show) + k + n_galleries_to_show + 2
+                    )
+
+                    img = dataset.visualize(item=ii_gallery[gt_idx].item(), color=GRAY)
+                    plt.title("GT")
+                    plt.imshow(img)
+                    plt.axis("off")
+
+        fig.tight_layout()
+        return fig
+
+
+__all__ = ["RetrievalResults"]
diff --git a/tests/test_oml/test_functional/test_knn.py b/tests/test_oml/test_functional/test_knn.py
new file mode 100644
index 000000000..7297ce040
--- /dev/null
+++ b/tests/test_oml/test_functional/test_knn.py
@@ -0,0 +1,105 @@
+from random import randint
+from typing import List, Optional, Tuple
+
+import pytest
+import torch
+from torch import FloatTensor, LongTensor
+
+from oml.functional.knn import batched_knn
+from oml.utils.misc_torch import pairwise_dist
+
+
+def straight_knn(
+    embeddings: FloatTensor,
+    ids_query: LongTensor,
+    ids_gallery: LongTensor,
+    labels: Optional[LongTensor],
+    sequence_ids: Optional[LongTensor],
+    top_n: int,
+) -> Tuple[FloatTensor, LongTensor, Optional[List[LongTensor]]]:
+    top_n = min(top_n, len(ids_gallery))
+
+    mask_to_ignore = ids_query[..., None] == ids_gallery[None, ...]
+    if sequence_ids is not None:
+        mask_sequence = sequence_ids[ids_query][..., None] == sequence_ids[ids_gallery][None, ...]
+        mask_to_ignore = torch.logical_or(mask_to_ignore, mask_sequence)
+
+    distances_all = pairwise_dist(x1=embeddings[ids_query], x2=embeddings[ids_gallery], p=2)
+    distances_all[mask_to_ignore] = float("inf")
+    distances, retrieved_ids = torch.topk(distances_all, k=top_n, largest=False, sorted=True)
+
+    if labels is not None:
+        mask_gt = labels[ids_query][..., None] == labels[ids_gallery][None, ...]
+        mask_gt[mask_to_ignore] = False
+        gt_ids = [LongTensor(row.nonzero()).view(-1) for row in mask_gt]
+    else:
+        gt_ids = None
+
+    return distances, retrieved_ids, gt_ids
+
+
+def generate_data(
+    dataset_len: int, n_classes: Optional[int], n_sequences: Optional[int], separate_query_gallery: bool
+) -> Tuple[FloatTensor, LongTensor, LongTensor, Optional[LongTensor], Optional[LongTensor]]:
+    if separate_query_gallery:
+        n_query = randint(1, dataset_len - 1)
+        ii = torch.randperm(n=dataset_len)
+        ids_query, ids_gallery = ii[:n_query], ii[n_query:]
+
+    else:
+        n_query = randint(1 + dataset_len // 2, dataset_len)
+        n_gallery = randint(1 + dataset_len // 2, dataset_len)
+        ids_query = torch.randperm(n=dataset_len)[:n_query]
+        ids_gallery = torch.randperm(n=dataset_len)[:n_gallery]
+        assert set(ids_query.tolist()).intersection(ids_gallery.tolist()), "Query and gallery don't intersect!"
+
+    embeddings = torch.randn((dataset_len, 8)).float()
+    labels = torch.randint(0, n_classes, size=(dataset_len,)) if n_classes else None
+    sequence_ids = torch.randint(0, n_sequences, size=(dataset_len,)) if n_sequences else None
+
+    return embeddings, ids_query, ids_gallery, labels, sequence_ids
+
+
+@pytest.mark.parametrize("dataset_len", [2, 10, 30])
+@pytest.mark.parametrize("need_sequence", [True, False])
+@pytest.mark.parametrize("need_gt", [True, False])
+@pytest.mark.parametrize("separate_query_gallery", [True, False])
+def test_batched_knn(dataset_len: int, need_sequence: bool, need_gt: bool, separate_query_gallery: bool) -> None:
+    for i in range(5):
+        batch_size_knn = randint(1, dataset_len)
+        n_classes = randint(1, 5) if need_gt else None
+        n_sequences = randint(1, 4) if need_sequence else None
+        top_n = randint(1, int(1.5 * dataset_len))
+
+        embeddings, ids_query, ids_gallery, labels, sequence_ids = generate_data(
+            dataset_len=dataset_len,
+            n_classes=n_classes,
+            n_sequences=n_sequences,
+            separate_query_gallery=separate_query_gallery,
+        )
+
+        distances_, retrieved_ids_, gt_ids_ = straight_knn(
+            embeddings=embeddings,
+            ids_query=ids_query,
+            ids_gallery=ids_gallery,
+            top_n=top_n,
+            labels=labels,
+            sequence_ids=sequence_ids,
+        )
+
+        distances, retrieved_ids, gt_ids = batched_knn(
+            embeddings=embeddings,
+            ids_query=ids_query,
+            ids_gallery=ids_gallery,
+            top_n=top_n,
+            labels_gt=labels,
+            bs=batch_size_knn,
+            sequence_ids=sequence_ids,
+        )
+
+        assert torch.allclose(distances, distances_)
+        assert (retrieved_ids == retrieved_ids_).all()
+
+        if need_gt:
+            for (ii, ii_) in zip(gt_ids, gt_ids_):
+                assert (ii == ii_).all()
diff --git a/tests/test_oml/test_retrieval_results/__init__.py b/tests/test_oml/test_retrieval_results/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/test_oml/test_retrieval_results/test_retrieval_results.py b/tests/test_oml/test_retrieval_results/test_retrieval_results.py
new file mode 100644
index 000000000..701475299
--- /dev/null
+++ b/tests/test_oml/test_retrieval_results/test_retrieval_results.py
@@ -0,0 +1,58 @@
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import pytest
+import torch
+
+from oml.const import IS_QUERY_COLUMN, LABELS_COLUMN, MOCK_DATASET_PATH, PATHS_COLUMN
+from oml.datasets.images import (
+    ImageQueryGalleryDataset,
+    ImageQueryGalleryLabeledDataset,
+)
+from oml.inference.flat import inference_on_images
+from oml.models import ResnetExtractor
+from oml.retrieval.retrieval_results import RetrievalResults
+from oml.transforms.images.torchvision import get_normalisation_torch
+from oml.utils.download_mock_dataset import download_mock_dataset
+
+
+@pytest.mark.parametrize("with_gt_labels", [False, True])
+@pytest.mark.parametrize("df_name", ["df.csv", "df_with_bboxes.csv", "df_with_sequence.csv"])
+def test_retrieval_results_on_images(with_gt_labels: bool, df_name: str) -> None:
+    # todo 522: add test on Embeddings after we merge unified inference
+
+    _, df_val = download_mock_dataset(dataset_root=MOCK_DATASET_PATH, df_name=df_name)
+    df_val[PATHS_COLUMN] = df_val[PATHS_COLUMN].apply(lambda x: Path(MOCK_DATASET_PATH) / x)
+
+    n_query = df_val[IS_QUERY_COLUMN].sum()
+
+    if with_gt_labels:
+        dataset = ImageQueryGalleryLabeledDataset(df_val)
+    else:
+        del df_val[LABELS_COLUMN]
+        dataset = ImageQueryGalleryDataset(df_val)
+
+    model = ResnetExtractor(weights=None, arch="resnet18", gem_p=None, remove_fc=True, normalise_features=False)
+    embeddings = inference_on_images(
+        model=model,
+        paths=df_val[PATHS_COLUMN].tolist(),
+        transform=get_normalisation_torch(),
+        num_workers=0,
+        batch_size=4,
+    ).float()
+
+    top_n = 2
+    rr = RetrievalResults.compute_from_embeddings(embeddings=embeddings, dataset=dataset, n_items_to_retrieve=top_n)
+
+    assert rr.distances.shape == (n_query, top_n)
+    assert rr.retrieved_ids.shape == (n_query, top_n)
+    assert torch.allclose(rr.distances.clone().sort()[0], rr.distances)
+
+    if with_gt_labels:
+        assert rr.gt_ids is not None
+
+    fig = rr.visualize(query_ids=[0, 3], dataset=dataset, n_galleries_to_show=3)
+    fig.show()
+    plt.close(fig=fig)
+
+    assert True
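For reference, the new pieces compose end-to-end as follows; this is a minimal sketch mirroring the test above (the mock dataset, the untrained ``resnet18`` and the retrieval size are illustrative only, not part of the patch):

    from pathlib import Path

    from oml.const import MOCK_DATASET_PATH, PATHS_COLUMN
    from oml.datasets.images import ImageQueryGalleryLabeledDataset
    from oml.inference.flat import inference_on_images
    from oml.models import ResnetExtractor
    from oml.retrieval.retrieval_results import RetrievalResults
    from oml.transforms.images.torchvision import get_normalisation_torch
    from oml.utils.download_mock_dataset import download_mock_dataset

    _, df_val = download_mock_dataset(dataset_root=MOCK_DATASET_PATH, df_name="df.csv")
    df_val[PATHS_COLUMN] = df_val[PATHS_COLUMN].apply(lambda x: Path(MOCK_DATASET_PATH) / x)

    dataset = ImageQueryGalleryLabeledDataset(df_val)

    extractor = ResnetExtractor(weights=None, arch="resnet18", gem_p=None, remove_fc=True, normalise_features=False)
    embeddings = inference_on_images(
        model=extractor,
        paths=df_val[PATHS_COLUMN].tolist(),
        transform=get_normalisation_torch(),
        num_workers=0,
        batch_size=4,
    ).float()

    rr = RetrievalResults.compute_from_embeddings(embeddings=embeddings, dataset=dataset, n_items_to_retrieve=5)
    print(rr)  # prints distances, retrieved ids and, since the dataset is labeled, ground truth ids

    fig = rr.visualize(query_ids=[0, 3], dataset=dataset, n_galleries_to_show=3)
    fig.savefig("retrieval_examples.png")  # or fig.show()

``compute_from_embeddings`` clips ``n_items_to_retrieve`` by the gallery size and delegates the search to ``batched_knn``, so the same memory guarantees apply regardless of the dataset size.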