From 7b8e86c96100272f384e5f2f37800923cd75ffc5 Mon Sep 17 00:00:00 2001 From: alekseysh Date: Fri, 19 Apr 2024 08:06:38 +0700 Subject: [PATCH 01/23] reworking datasets --- .gitignore | 1 + oml/const.py | 6 +- oml/datasets/base.py | 346 +------------- oml/datasets/images.py | 428 ++++++++++++++++++ oml/datasets/list_dataset.py | 16 +- oml/datasets/pairs.py | 6 +- oml/interfaces/datasets.py | 96 ++-- oml/lightning/pipelines/parser.py | 4 +- oml/lightning/pipelines/train.py | 4 +- oml/lightning/pipelines/validate.py | 4 +- .../test_lightning/test_pipeline.py | 2 + .../test_retrieval_validation.py | 54 +-- tests/test_integrations/utils.py | 72 ++- .../test_datasets/test_list_dataest.py | 4 +- tests/test_oml/test_registry/test_registry.py | 9 +- 15 files changed, 610 insertions(+), 442 deletions(-) create mode 100644 oml/datasets/images.py diff --git a/.gitignore b/.gitignore index 95cecde2d..fa99851dc 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ tmp.ipynb *outputs* *.hydra* *predictions.json* +*ml-runs* # __________________________________ diff --git a/oml/const.py b/oml/const.py index 62045d248..1890c8d94 100644 --- a/oml/const.py +++ b/oml/const.py @@ -2,7 +2,7 @@ import tempfile from pathlib import Path from sys import platform -from typing import Any, Dict, Tuple, Union +from typing import Any, Dict, Optional, Sequence, Tuple, Union from omegaconf import DictConfig @@ -50,6 +50,7 @@ def get_cache_folder() -> Path: GREEN = (0, 255, 0) BLUE = (0, 0, 255) GRAY = (120, 120, 120) +BLACK = (0, 0, 0) PAD_COLOR = (255, 255, 255) TCfg = Union[Dict[str, Any], DictConfig] @@ -62,6 +63,9 @@ def get_cache_folder() -> Path: MEAN_CLIP = (0.48145466, 0.4578275, 0.40821073) STD_CLIP = (0.26862954, 0.26130258, 0.27577711) +TBBox = Tuple[int, int, int, int] +TBBoxes = Sequence[Optional[TBBox]] + CROP_KEY = "crop" # the format is [x1, y1, x2, y2] # Required dataset format: diff --git a/oml/datasets/base.py b/oml/datasets/base.py index f76ac3b26..67096d0b0 100644 --- a/oml/datasets/base.py +++ b/oml/datasets/base.py @@ -1,344 +1,14 @@ -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional, Tuple, Union +from oml.datasets.images import ImagesDatasetQueryGallery, ImagesDatasetWithLabels -import albumentations as albu -import numpy as np -import pandas as pd -import torchvision -from torch.utils.data import Dataset -from oml.const import ( - CATEGORIES_COLUMN, - CATEGORIES_KEY, - INDEX_KEY, - INPUT_TENSORS_KEY, - IS_GALLERY_COLUMN, - IS_GALLERY_KEY, - IS_QUERY_COLUMN, - IS_QUERY_KEY, - LABELS_COLUMN, - LABELS_KEY, - PATHS_COLUMN, - PATHS_KEY, - SEQUENCE_COLUMN, - SEQUENCE_KEY, - SPLIT_COLUMN, - X1_COLUMN, - X1_KEY, - X2_COLUMN, - X2_KEY, - Y1_COLUMN, - Y1_KEY, - Y2_COLUMN, - Y2_KEY, -) -from oml.interfaces.datasets import IDatasetQueryGallery, IDatasetWithLabels -from oml.registry.transforms import get_transforms -from oml.transforms.images.utils import TTransforms, get_im_reader_for_transforms -from oml.utils.dataframe_format import check_retrieval_dataframe_format -from oml.utils.images.images import TImReader +class DatasetWithLabels(ImagesDatasetWithLabels): + # this class allows to have back compatibility + pass -class BaseDataset(Dataset): - """ - Base class for the retrieval datasets. +class DatasetQueryGallery(ImagesDatasetQueryGallery): + # this class allows to have back compatibility + pass - """ - def __init__( - self, - df: pd.DataFrame, - extra_data: Optional[Dict[str, Any]] = None, - transform: Optional[TTransforms] = None, - dataset_root: Optional[Union[str, Path]] = None, - f_imread: Optional[TImReader] = None, - cache_size: Optional[int] = 0, - input_tensors_key: str = INPUT_TENSORS_KEY, - labels_key: str = LABELS_KEY, - paths_key: str = PATHS_KEY, - categories_key: Optional[str] = CATEGORIES_KEY, - sequence_key: Optional[str] = SEQUENCE_KEY, - x1_key: str = X1_KEY, - x2_key: str = X2_KEY, - y1_key: str = Y1_KEY, - y2_key: str = Y2_KEY, - index_key: str = INDEX_KEY, - ): - """ - - Args: - df: Table with the following obligatory columns: - - ``LABELS_COLUMN``, ``PATHS_COLUMN`` - - and the optional ones: - - ``X1_COLUMN``, ``X2_COLUMN``, ``Y1_COLUMN``, ``Y2_COLUMN``, ``CATEGORIES_COLUMN`` - - extra_data: Dictionary with additional information which we want to put into batches. We assume that - the length of each record in this structure is the same as dataset's size. - transform: Augmentations for the images, set ``None`` to perform only normalisation and casting to tensor - dataset_root: Path to the images' dir, set ``None`` if you provided the absolute paths in your dataframe - f_imread: Function to read the images, pass ``None`` so we pick it autmatically based on provided transforms - cache_size: Size of the dataset's cache - input_tensors_key: Key to put tensors into the batches - labels_key: Key to put labels into the batches - paths_key: Key put paths into the batches - categories_key: Key to put categories into the batches - sequence_key: Key to put sequence ids into the batches - x1_key: Key to put ``x1`` into the batches - x2_key: Key to put ``x2`` into the batches - y1_key: Key to put ``y1`` into the batches - y2_key: Key to put ``y2`` into the batches - index_key: Key to put samples' ids into the batches - - """ - df = df.copy() - - if extra_data is not None: - assert all( - len(record) == len(df) for record in extra_data.values() - ), "All the extra records need to have the size equal to the dataset's size" - - assert all(x in df.columns for x in (LABELS_COLUMN, PATHS_COLUMN)) - - self.input_tensors_key = input_tensors_key - self.labels_key = labels_key - self.paths_key = paths_key - self.categories_key = categories_key if (CATEGORIES_COLUMN in df.columns) else None - self.sequence_key = sequence_key if (SEQUENCE_COLUMN in df.columns) else None - self.index_key = index_key - - self.bboxes_exist = all(coord in df.columns for coord in (X1_COLUMN, X2_COLUMN, Y1_COLUMN, Y2_COLUMN)) - if self.bboxes_exist: - self.x1_key, self.x2_key, self.y1_key, self.y2_key = x1_key, x2_key, y1_key, y2_key - else: - self.x1_key, self.x2_key, self.y1_key, self.y2_key = None, None, None, None - - if dataset_root is not None: - dataset_root = Path(dataset_root) - df[PATHS_COLUMN] = df[PATHS_COLUMN].apply(lambda x: str(dataset_root / x)) - else: - df[PATHS_COLUMN] = df[PATHS_COLUMN].astype(str) - - self.df = df - self.extra_data = extra_data - self.transform = transform if transform else get_transforms("norm_albu") - self.f_imread = f_imread or get_im_reader_for_transforms(transform) - self.read_bytes_image = ( - lru_cache(maxsize=cache_size)(self._read_bytes_image) if cache_size else self._read_bytes_image - ) - - available_augs_types = (albu.Compose, torchvision.transforms.Compose) - assert isinstance(self.transform, available_augs_types), f"Type of transforms must be in {available_augs_types}" - - @staticmethod - def _read_bytes_image(path: Union[Path, str]) -> bytes: - with open(str(path), "rb") as fin: - return fin.read() - - def __getitem__(self, idx: int) -> Dict[str, Any]: - row = self.df.iloc[idx] - - img_bytes = self.read_bytes_image(row[PATHS_COLUMN]) # type: ignore - img = self.f_imread(img_bytes) - - im_h, im_w = img.shape[:2] if isinstance(img, np.ndarray) else img.size[::-1] - - if (not self.bboxes_exist) or any( - pd.isna(coord) for coord in [row[X1_COLUMN], row[X2_COLUMN], row[Y1_COLUMN], row[Y2_COLUMN]] - ): - x1, y1, x2, y2 = 0, 0, im_w, im_h - else: - x1, y1, x2, y2 = int(row[X1_COLUMN]), int(row[Y1_COLUMN]), int(row[X2_COLUMN]), int(row[Y2_COLUMN]) - - if isinstance(self.transform, albu.Compose): - img = img[y1:y2, x1:x2, :] # todo: since albu may handle bboxes we should move it to augs - image_tensor = self.transform(image=img)["image"] - else: - # torchvision.transforms - img = img.crop((x1, y1, x2, y2)) - image_tensor = self.transform(img) - - item = { - self.input_tensors_key: image_tensor, - self.labels_key: row[LABELS_COLUMN], - self.paths_key: row[PATHS_COLUMN], - self.index_key: idx, - } - - if self.categories_key: - item[self.categories_key] = row[CATEGORIES_COLUMN] - - if self.sequence_key: - item[self.sequence_key] = row[SEQUENCE_COLUMN] - - if self.bboxes_exist: - item.update( - { - self.x1_key: x1, - self.y1_key: y1, - self.x2_key: x2, - self.y2_key: y2, - } - ) - - if self.extra_data: - for key, record in self.extra_data.items(): - if key in item: - raise ValueError(f" and dataset share the same key: {key}") - else: - item[key] = record[idx] - - return item - - def __len__(self) -> int: - return len(self.df) - - @property - def bboxes_keys(self) -> Tuple[str, ...]: - if self.bboxes_exist: - return self.x1_key, self.y1_key, self.x2_key, self.y2_key - else: - return tuple() - - -class DatasetWithLabels(BaseDataset, IDatasetWithLabels): - """ - The main purpose of this class is to be used as a dataset during - the training stage. - - It has to know how to return its labels, which is required information - to perform the training with the combinations-based losses. - Particularly, these labels will be passed to `Sampler` to form the batches and - batches will be passed to `Miner` to form the combinations (triplets). - - """ - - def get_labels(self) -> np.ndarray: - return np.array(self.df[LABELS_COLUMN].tolist()) - - def get_label2category(self) -> Optional[Dict[int, Union[str, int]]]: - """ - Returns: - Label to category mapping if there was category information in DataFrame, None otherwise. - - """ - if CATEGORIES_COLUMN in self.df.columns: - label2category = dict(zip(self.df[LABELS_COLUMN], self.df[CATEGORIES_COLUMN])) - else: - label2category = None - - return label2category - - -class DatasetQueryGallery(BaseDataset, IDatasetQueryGallery): - """ - The main purpose of this class is to be used as a dataset during - the validation stage. It has to provide information - about its `query`/`gallery` split. - - Note, that some datasets used as benchmarks in Metric Learning - provide the splitting information (for example, ``DeepFashion InShop`` dataset), but some of them - don't (for example, ``CARS196`` or ``CUB200``). - The validation idea for the latter is to calculate the embeddings for the whole validation set, - then for every item find ``top-k`` nearest neighbors and calculate the desired retrieval metric. - In other words, for the desired query item, the gallery is the rest of the validation dataset. - - Thus, if you want to perform this kind of validation process (`1 vs rest`) you should simply return - ``is_query == True`` and ``is_gallery == True`` for every item in the dataset as the same time. - - """ - - def __init__( - self, - df: pd.DataFrame, - extra_data: Optional[Dict[str, Any]] = None, - dataset_root: Optional[Union[str, Path]] = None, - transform: Optional[albu.Compose] = None, - f_imread: Optional[TImReader] = None, - cache_size: Optional[int] = 0, - input_tensors_key: str = INPUT_TENSORS_KEY, - labels_key: str = LABELS_KEY, - paths_key: str = PATHS_KEY, - categories_key: str = CATEGORIES_KEY, - x1_key: str = X1_KEY, - x2_key: str = X2_KEY, - y1_key: str = Y1_KEY, - y2_key: str = Y2_KEY, - is_query_key: str = IS_QUERY_KEY, - is_gallery_key: str = IS_GALLERY_KEY, - ): - super(DatasetQueryGallery, self).__init__( - df=df, - extra_data=extra_data, - dataset_root=dataset_root, - transform=transform, - f_imread=f_imread, - cache_size=cache_size, - input_tensors_key=input_tensors_key, - labels_key=labels_key, - paths_key=paths_key, - categories_key=categories_key, - x1_key=x1_key, - x2_key=x2_key, - y1_key=y1_key, - y2_key=y2_key, - ) - assert all(x in df.columns for x in (IS_QUERY_COLUMN, IS_GALLERY_COLUMN)) - - self.is_query_key = is_query_key - self.is_gallery_key = is_gallery_key - - def __getitem__(self, idx: int) -> Dict[str, Any]: - item = super().__getitem__(idx) - item[self.is_query_key] = bool(self.df.iloc[idx][IS_QUERY_COLUMN]) - item[self.is_gallery_key] = bool(self.df.iloc[idx][IS_GALLERY_COLUMN]) - return item - - -def get_retrieval_datasets( - dataset_root: Path, - transforms_train: Any, - transforms_val: Any, - f_imread_train: Optional[TImReader] = None, - f_imread_val: Optional[TImReader] = None, - dataframe_name: str = "df.csv", - cache_size: Optional[int] = 0, - verbose: bool = True, -) -> Tuple[DatasetWithLabels, DatasetQueryGallery]: - df = pd.read_csv(dataset_root / dataframe_name, index_col=False) - - check_retrieval_dataframe_format(df, dataset_root=dataset_root, verbose=verbose) - - # first half will consist of "train" split, second one of "val" - # so labels in train will be from 0 to N-1 and labels in test will be from N to K - mapper = {l: i for i, l in enumerate(df.sort_values(by=[SPLIT_COLUMN])[LABELS_COLUMN].unique())} - - # train - df_train = df[df[SPLIT_COLUMN] == "train"].reset_index(drop=True) - df_train[LABELS_COLUMN] = df_train[LABELS_COLUMN].map(mapper) - - train_dataset = DatasetWithLabels( - df=df_train, - dataset_root=dataset_root, - transform=transforms_train, - cache_size=cache_size, - f_imread=f_imread_train, - ) - - # val (query + gallery) - df_query_gallery = df[df[SPLIT_COLUMN] == "validation"].reset_index(drop=True) - valid_dataset = DatasetQueryGallery( - df=df_query_gallery, - dataset_root=dataset_root, - transform=transforms_val, - cache_size=cache_size, - f_imread=f_imread_val, - ) - - return train_dataset, valid_dataset - - -__all__ = ["BaseDataset", "DatasetWithLabels", "DatasetQueryGallery", "get_retrieval_datasets"] +__all__ = ["DatasetWithLabels", "DatasetQueryGallery"] diff --git a/oml/datasets/images.py b/oml/datasets/images.py new file mode 100644 index 000000000..8e2da5adf --- /dev/null +++ b/oml/datasets/images.py @@ -0,0 +1,428 @@ +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +import albumentations as albu +import numpy as np +import pandas as pd +import torch +import torchvision +from torch import BoolTensor, FloatTensor, LongTensor + +from oml.const import ( + BLACK, + CATEGORIES_COLUMN, + CATEGORIES_KEY, + INDEX_KEY, + INPUT_TENSORS_KEY, + IS_GALLERY_COLUMN, + IS_GALLERY_KEY, + IS_QUERY_COLUMN, + IS_QUERY_KEY, + LABELS_COLUMN, + LABELS_KEY, + PATHS_COLUMN, + PATHS_KEY, + SEQUENCE_COLUMN, + SEQUENCE_KEY, + SPLIT_COLUMN, + X1_COLUMN, + X1_KEY, + X2_COLUMN, + X2_KEY, + Y1_COLUMN, + Y1_KEY, + Y2_COLUMN, + Y2_KEY, + TBBoxes, + TColor, +) +from oml.interfaces.datasets import ( + IBaseDataset, + IDatasetQueryGallery, + IDatasetWithLabels, + IVisualizableDataset, +) +from oml.registry.transforms import get_transforms +from oml.transforms.images.utils import TTransforms, get_im_reader_for_transforms +from oml.utils.dataframe_format import check_retrieval_dataframe_format +from oml.utils.images.images import TImReader, get_img_with_bbox, square_pad + +# todo 522: general comment on Datasets +# We will remove using keys in __getitem__ for: +# Passing extra information (like categories or sequence id) -> we will use .extra_data instead +# Modality related info (like bboxes or paths) -> they may only exist as internals of the datasets +# is_query_key, is_gallery_key -> get_query_ids() and get_gallery_ids() methods +# Before this, we temporary keep both approaches + + +def parse_bboxes(df: pd.DataFrame) -> Optional[TBBoxes]: + n_existing_columns = sum([x in df for x in [X1_COLUMN, X2_COLUMN, Y1_COLUMN, Y2_COLUMN]]) + + if n_existing_columns == 4: + bboxes = [] + for row in df.iterrows(): + bbox = int(row[X1_COLUMN]), int(row[X2_COLUMN]), int(row[Y1_COLUMN]), int(row[Y2_COLUMN]) + bbox = None if any(coord is None for coord in bbox) else bbox + bboxes.append(bbox) + + elif n_existing_columns == 0: + bboxes = None + + else: + raise ValueError(f"Found {n_existing_columns} bounding bboxes columns instead of 4. Check your dataframe.") + + return bboxes + + +class ImagesBaseDataset(IBaseDataset, IVisualizableDataset): + """ + The base class that handles image specific logic. + + """ + + input_tensors_key: str + index_key: str + + def __init__( + self, + paths: List[str], + dataset_root: Optional[Union[str, Path]] = None, + bboxes: Optional[TBBoxes] = None, + extra_data: Optional[Dict[str, Any]] = None, + transform: Optional[TTransforms] = None, + f_imread: Optional[TImReader] = None, + cache_size: Optional[int] = 0, + input_tensors_key: str = INPUT_TENSORS_KEY, + index_key: str = INDEX_KEY, + # todo 522: remove + paths_key: str = PATHS_KEY, + x1_key: str = X1_KEY, + x2_key: str = X2_KEY, + y1_key: str = Y1_KEY, + y2_key: str = Y2_KEY, + ): + """ + + Args: + paths: Paths to images. Will be concatenated with ``dataset_root`` is provided. + dataset_root: Path to the images' dir, set ``None`` if you provided the absolute paths in your dataframe + bboxes: Bounding boxes of images. Some of the images may not have bounding bboxes. + extra_data: Dictionary containing records of some additional information. + transform: Augmentations for the images, set ``None`` to perform only normalisation and casting to tensor + f_imread: Function to read the images, pass ``None`` to pick it automatically based on provided transforms + cache_size: Size of the dataset's cache + input_tensors_key: Key to put tensors into the batches + index_key: Key to put samples' ids into the batches + paths_key: Key put paths into the batches # todo 522: remove + x1_key: Key to put ``x1`` into the batches # todo 522: remove + x2_key: Key to put ``x2`` into the batches # todo 522: remove + y1_key: Key to put ``y1`` into the batches # todo 522: remove + y2_key: Key to put ``y2`` into the batches # todo 522: remove + + """ + assert (bboxes is None) or (len(paths) == len(bboxes)) + + if extra_data is not None: + assert all( + len(record) == len(paths) for record in extra_data.values() + ), "All the extra records need to have the size equal to the dataset's size" + + self.input_tensors_key = input_tensors_key + self.index_key = index_key + + if dataset_root is not None: + self._paths = list(map(lambda x: str(Path(dataset_root) / x), paths)) + else: + self._paths = paths + + self.extra_data = extra_data + + self._bboxes = bboxes + self._transform = transform if transform else get_transforms("norm_albu") + self._f_imread = f_imread or get_im_reader_for_transforms(transform) + + if cache_size: + self.read_bytes = lru_cache(maxsize=cache_size)(self._read_bytes) # type: ignore + else: + self.read_bytes = self._read_bytes # type: ignore + + available_transforms = (albu.Compose, torchvision.transforms.Compose) + assert isinstance(self._transform, available_transforms), f"Transforms must one of: {available_transforms}" + + # todo 522: remove + self.paths_key = paths_key + self.x1_key = x1_key + self.x2_key = x2_key + self.y1_key = y1_key + self.y2_key = y2_key + + @staticmethod + def _read_bytes(path: Union[Path, str]) -> bytes: + with open(str(path), "rb") as fin: + return fin.read() + + def __getitem__(self, idx: int) -> Dict[str, Union[FloatTensor, int]]: + img_bytes = self.read_bytes(self._paths[idx]) + img = self._f_imread(img_bytes) + + im_h, im_w = img.shape[:2] if isinstance(img, np.ndarray) else img.size[::-1] + + if (self._bboxes is not None) and (self._bboxes[idx] is not None): + x1, y1, x2, y2 = self._bboxes[idx] + else: + x1, y1, x2, y2 = 0, 0, im_w, im_h + + if isinstance(self._transform, albu.Compose): + img = img[y1:y2, x1:x2, :] + image_tensor = self._transform(image=img)["image"] + else: + # torchvision.transforms + img = img.crop((x1, y1, x2, y2)) + image_tensor = self._transform(img) + + item = { + self.input_tensors_key: image_tensor, + self.index_key: idx, + } + + if self.extra_data: + for key, record in self.extra_data.items(): + if key in item: + raise ValueError(f" and dataset share the same key: {key}") + else: + item[key] = record[idx] + + # todo 522: remove + item[self.x1_key] = x1 + item[self.y1_key] = y1 + item[self.x2_key] = x2 + item[self.y2_key] = y2 + item[self.paths_key] = str(self._paths[idx]) + + return item + + def __len__(self) -> int: + return len(self._paths) + + def visualize(self, idx: int, color: TColor = BLACK) -> np.ndarray: + bbox = torch.tensor(self._bboxes[idx]) if (self._bboxes is not None) else torch.tensor([None] * 4) + image = get_img_with_bbox(im_path=self._paths[idx], bbox=bbox, color=color) + image = square_pad(image) + + return image + + # todo 522: remove + @property + def bboxes_keys(self) -> Tuple[str, ...]: + return self.x1_key, self.y1_key, self.x2_key, self.y2_key + + +class ImagesDatasetWithLabels(ImagesBaseDataset, IDatasetWithLabels): + """ + The dataset of images having their ground truth labels. + + """ + + def __init__( + self, + df: pd.DataFrame, + extra_data: Optional[Dict[str, Any]] = None, + dataset_root: Optional[Union[str, Path]] = None, + transform: Optional[albu.Compose] = None, + f_imread: Optional[TImReader] = None, + cache_size: Optional[int] = 0, + input_tensors_key: str = INPUT_TENSORS_KEY, + labels_key: str = LABELS_KEY, + index_key: str = INDEX_KEY, + # todo 522: remove + paths_key: str = PATHS_KEY, + categories_key: Optional[str] = CATEGORIES_KEY, + sequence_key: Optional[str] = SEQUENCE_KEY, + x1_key: str = X1_KEY, + x2_key: str = X2_KEY, + y1_key: str = Y1_KEY, + y2_key: str = Y2_KEY, + ): + assert (LABELS_COLUMN in df) and (PATHS_COLUMN in df), "There are only 2 required columns." + self.labels_key = labels_key + self.df = df + + extra_data = {} if extra_data is None else extra_data + + super().__init__( + paths=self.df[PATHS_COLUMN].tolist(), + bboxes=parse_bboxes(self.df), + extra_data=extra_data, + dataset_root=dataset_root, + transform=transform, + f_imread=f_imread, + cache_size=cache_size, + input_tensors_key=input_tensors_key, + index_key=index_key, + # todo 522: remove + x1_key=x1_key, + y2_key=y2_key, + x2_key=x2_key, + y1_key=y1_key, + paths_key=paths_key, + ) + + # todo 522: remove + self.categories_key = categories_key if (CATEGORIES_COLUMN in df.columns) else None + self.sequence_key = sequence_key if (SEQUENCE_COLUMN in df.columns) else None + + def __getitem__(self, idx: int) -> Dict[str, Any]: + item = super().__getitem__(idx) + item[self.labels_key] = self.df.iloc[idx][LABELS_COLUMN] + + # todo 522: remove + if self.sequence_key: + item[self.sequence_key] = self.df[SEQUENCE_COLUMN][idx] + + if self.categories_key: + item[self.categories_key] = self.df[CATEGORIES_COLUMN][idx] + + return item + + def get_labels(self) -> np.ndarray: + return np.array(self.df[LABELS_COLUMN]) + + # todo 522: remove + def get_label2category(self) -> Optional[Dict[int, Union[str, int]]]: + if CATEGORIES_COLUMN in self.df.columns: + label2category = dict(zip(self.df[LABELS_COLUMN], self.df[CATEGORIES_COLUMN])) + else: + label2category = None + + return label2category + + +class ImagesDatasetQueryGallery(ImagesDatasetWithLabels, IDatasetQueryGallery): + """ + The dataset of images having `query`/`gallery` split. + + Note, that some datasets used as benchmarks in Metric Learning + explicitly provide the splitting information (for example, ``DeepFashion InShop`` dataset), but some of them + don't (for example, ``CARS196`` or ``CUB200``). The validation idea for the latter is to perform `1 vs rest` + validation, where every query is evaluated versus the whole validation dataset (except for this exact query). + + So, if you want an item participate in validation as both: query and gallery, you should mark this item as + ``is_query == True`` and ``is_gallery == True``, as it's done in the `CARS196` or `CUB200` dataset. + + """ + + def __init__( + self, + df: pd.DataFrame, + extra_data: Optional[Dict[str, Any]] = None, + dataset_root: Optional[Union[str, Path]] = None, + transform: Optional[albu.Compose] = None, + f_imread: Optional[TImReader] = None, + cache_size: Optional[int] = 0, + input_tensors_key: str = INPUT_TENSORS_KEY, + labels_key: str = LABELS_KEY, + # todo 522: remove + paths_key: str = PATHS_KEY, + categories_key: Optional[str] = CATEGORIES_KEY, + sequence_key: Optional[str] = SEQUENCE_KEY, + x1_key: str = X1_KEY, + x2_key: str = X2_KEY, + y1_key: str = Y1_KEY, + y2_key: str = Y2_KEY, + is_query_key: str = IS_QUERY_KEY, + is_gallery_key: str = IS_GALLERY_KEY, + ): + assert all(x in df.columns for x in (IS_QUERY_COLUMN, IS_GALLERY_COLUMN, LABELS_COLUMN)) + self._df = df + + super().__init__( + df=df, + extra_data=extra_data, + dataset_root=dataset_root, + transform=transform, + f_imread=f_imread, + cache_size=cache_size, + input_tensors_key=input_tensors_key, + labels_key=labels_key, + # todo 522: remove + x1_key=x1_key, + y2_key=y2_key, + x2_key=x2_key, + y1_key=y1_key, + paths_key=paths_key, + categories_key=categories_key, + sequence_key=sequence_key, + ) + + # todo 522: remove + self.is_query_key = is_query_key + self.is_gallery_key = is_gallery_key + + def get_query_ids(self) -> LongTensor: + return BoolTensor(self._df[IS_QUERY_COLUMN]).nonzero().squeeze() + + def get_gallery_ids(self) -> LongTensor: + return BoolTensor(self._df[IS_GALLERY_COLUMN]).nonzero().squeeze() + + def __getitem__(self, idx: int) -> Dict[str, Any]: + item = super().__getitem__(idx) + item[self.labels_key] = self._df.iloc[idx][LABELS_COLUMN] + + # todo 522: remove + item[self.is_query_key] = bool(self._df[IS_QUERY_COLUMN][idx]) + item[self.is_gallery_key] = bool(self._df[IS_GALLERY_COLUMN][idx]) + + return item + + +def get_retrieval_images_datasets( + dataset_root: Path, + transforms_train: Any, + transforms_val: Any, + f_imread_train: Optional[TImReader] = None, + f_imread_val: Optional[TImReader] = None, + dataframe_name: str = "df.csv", + cache_size: Optional[int] = 0, + verbose: bool = True, +) -> Tuple[IDatasetWithLabels, IDatasetQueryGallery]: + df = pd.read_csv(dataset_root / dataframe_name, index_col=False) + + check_retrieval_dataframe_format(df, dataset_root=dataset_root, verbose=verbose) + + # todo 522: why do we need it? + # first half will consist of "train" split, second one of "val" + # so labels in train will be from 0 to N-1 and labels in test will be from N to K + mapper = {l: i for i, l in enumerate(df.sort_values(by=[SPLIT_COLUMN])[LABELS_COLUMN].unique())} + + # train + df_train = df[df[SPLIT_COLUMN] == "train"].reset_index(drop=True) + df_train[LABELS_COLUMN] = df_train[LABELS_COLUMN].map(mapper) + + train_dataset = ImagesDatasetWithLabels( + df=df_train, + dataset_root=dataset_root, + transform=transforms_train, + cache_size=cache_size, + f_imread=f_imread_train, + ) + + # val (query + gallery) + df_query_gallery = df[df[SPLIT_COLUMN] == "validation"].reset_index(drop=True) + valid_dataset = ImagesDatasetQueryGallery( + df=df_query_gallery, + dataset_root=dataset_root, + transform=transforms_val, + cache_size=cache_size, + f_imread=f_imread_val, + ) + + return train_dataset, valid_dataset + + +__all__ = [ + "ImagesBaseDataset", + "ImagesDatasetWithLabels", + "ImagesDatasetQueryGallery", + "get_retrieval_images_datasets", +] diff --git a/oml/datasets/list_dataset.py b/oml/datasets/list_dataset.py index d430c9027..8820afd37 100644 --- a/oml/datasets/list_dataset.py +++ b/oml/datasets/list_dataset.py @@ -1,8 +1,7 @@ from collections import defaultdict from pathlib import Path -from typing import Any, Dict, Optional, Sequence, Tuple +from typing import Any, Dict, Optional, Sequence -import pandas as pd from torch.utils.data import Dataset from oml.const import ( @@ -14,17 +13,17 @@ X2_COLUMN, Y1_COLUMN, Y2_COLUMN, + TBBoxes, ) -from oml.datasets.base import BaseDataset +from oml.datasets.images import ImagesBaseDataset from oml.transforms.images.torchvision import get_normalisation_torch from oml.transforms.images.utils import TTransforms from oml.utils.images.images import TImReader -TBBox = Tuple[int, int, int, int] -TBBoxes = Sequence[Optional[TBBox]] - class ListDataset(Dataset): + # todo 522: remove the whole dataset + """This is a dataset to iterate over a list of images.""" def __init__( @@ -68,8 +67,9 @@ def __init__( data[X2_COLUMN].append(x2) # type: ignore data[Y2_COLUMN].append(y2) # type: ignore - self._dataset = BaseDataset( - df=pd.DataFrame(data), + self._dataset = ImagesBaseDataset( + paths=list(map(str, filenames_list)), + bboxes=bboxes, transform=transform, f_imread=f_imread, input_tensors_key=input_tensors_key, diff --git a/oml/datasets/pairs.py b/oml/datasets/pairs.py index bcd928ed1..dbae225f9 100644 --- a/oml/datasets/pairs.py +++ b/oml/datasets/pairs.py @@ -3,13 +3,15 @@ from torch import Tensor -from oml.const import INDEX_KEY, PAIR_1ST_KEY, PAIR_2ND_KEY -from oml.datasets.list_dataset import ListDataset, TBBoxes +from oml.const import INDEX_KEY, PAIR_1ST_KEY, PAIR_2ND_KEY, TBBoxes +from oml.datasets.list_dataset import ListDataset from oml.interfaces.datasets import IPairsDataset from oml.transforms.images.torchvision import get_normalisation_torch from oml.transforms.images.utils import TTransforms from oml.utils.images.images import TImReader, imread_pillow +# todo 522: make one modality agnostic instead of these two + class EmbeddingPairsDataset(IPairsDataset): """ diff --git a/oml/interfaces/datasets.py b/oml/interfaces/datasets.py index 369693445..e396e5a5c 100644 --- a/oml/interfaces/datasets.py +++ b/oml/interfaces/datasets.py @@ -2,29 +2,39 @@ from typing import Any, Dict import numpy as np +from torch import LongTensor from torch.utils.data import Dataset -from oml.const import ( # noqa - INDEX_KEY, - INPUT_TENSORS_KEY, - IS_GALLERY_KEY, - IS_QUERY_KEY, - LABELS_KEY, - PAIR_1ST_KEY, - PAIR_2ND_KEY, -) -from oml.samplers.balance import BalanceSampler # noqa +from oml.const import INDEX_KEY, LABELS_KEY, PAIR_1ST_KEY, PAIR_2ND_KEY, TColor -class IDatasetWithLabels(Dataset, ABC): +class IBaseDataset(Dataset): + input_tensors_key: str + index_key: str + extra_data: Dict[str, Any] + + def __getitem__(self, item: int) -> Dict[str, Any]: + """ + + Args: + item: Idx of the sample + + Returns: + Dictionary including the following keys: + ``self.input_tensors_key`` + ``self.index_key: int = item`` + + """ + raise NotImplementedError() + + +class IDatasetWithLabels(IBaseDataset, ABC): """ - This is an interface for the datasets which can provide their labels. + This is an interface for the datasets which provide labels of containing items. """ - input_tensors_key: str = INPUT_TENSORS_KEY labels_key: str = LABELS_KEY - index_key: str = INDEX_KEY def __getitem__(self, item: int) -> Dict[str, Any]: """ @@ -33,11 +43,9 @@ def __getitem__(self, item: int) -> Dict[str, Any]: item: Idx of the sample Returns: - Dictionary with the following keys: + Dictionary including the following keys: - ``self.input_tensors_key`` ``self.labels_key`` - ``self.index_key`` """ raise NotImplementedError() @@ -47,36 +55,27 @@ def get_labels(self) -> np.ndarray: raise NotImplementedError() -class IDatasetQueryGallery(Dataset, ABC): +class IDatasetQueryGalleryPrediction(IBaseDataset, ABC): """ - This is an interface for the datasets which can provide the information on how to split - the validation set into the two parts: query and gallery. + This is an interface for the datasets which hold the information on how to split + the data into the query and gallery. The query and gallery ids may overlap. + It doesn't need the ground truth labels, so it can be used for prediction on not annotated data. """ - input_tensors_key: str = INPUT_TENSORS_KEY - labels_key: str = LABELS_KEY - is_query_key: str = IS_QUERY_KEY - is_gallery_key: str = IS_GALLERY_KEY - index_key: str = INDEX_KEY - @abstractmethod - def __getitem__(self, item: int) -> Dict[str, Any]: - """ - Args: - item: Idx of the sample + def get_query_ids(self) -> LongTensor: + raise NotImplementedError() - Returns: - Dictionary with the following keys: + @abstractmethod + def get_gallery_ids(self) -> LongTensor: + raise NotImplementedError() - ``self.input_tensors_key`` - ``self.labels_key`` - ``self.is_query_key`` - ``self.is_gallery_key`` - ``self.index_key`` - """ - raise NotImplementedError() +class IDatasetQueryGallery(IDatasetQueryGalleryPrediction, IDatasetWithLabels, ABC): + """ + This class is similar to "IDatasetQueryGalleryPrediction", but we also have ground truth labels. + """ class IPairsDataset(Dataset, ABC): @@ -106,4 +105,21 @@ def __getitem__(self, item: int) -> Dict[str, Any]: raise NotImplementedError() -__all__ = ["IDatasetWithLabels", "IDatasetQueryGallery", "IPairsDataset"] +class IVisualizableDataset(Dataset, ABC): + """ + Base class for the datasets which know how to visualise their items. + """ + + @abstractmethod + def visualize(self, idx: int, color: TColor) -> np.ndarray: + raise NotImplementedError() + + +__all__ = [ + "IBaseDataset", + "IDatasetWithLabels", + "IDatasetQueryGallery", + "IDatasetQueryGalleryPrediction", + "IPairsDataset", + "IVisualizableDataset", +] diff --git a/oml/lightning/pipelines/parser.py b/oml/lightning/pipelines/parser.py index 213e98a95..c0bad14bd 100644 --- a/oml/lightning/pipelines/parser.py +++ b/oml/lightning/pipelines/parser.py @@ -6,7 +6,7 @@ from pytorch_lightning.strategies import DDPStrategy from oml.const import TCfg -from oml.datasets.base import DatasetWithLabels +from oml.interfaces.datasets import IDatasetWithLabels from oml.interfaces.loggers import IPipelineLogger from oml.interfaces.samplers import IBatchSampler from oml.lightning.pipelines.logging import TensorBoardPipelineLogger @@ -76,7 +76,7 @@ def parse_scheduler_from_config(cfg: TCfg, optimizer: torch.optim.Optimizer) -> return scheduler_kwargs -def parse_sampler_from_config(cfg: TCfg, dataset: DatasetWithLabels) -> Optional[IBatchSampler]: +def parse_sampler_from_config(cfg: TCfg, dataset: IDatasetWithLabels) -> Optional[IBatchSampler]: if ( (not dataset.categories_key) and (cfg["sampler"] is not None) diff --git a/oml/lightning/pipelines/train.py b/oml/lightning/pipelines/train.py index b59d8878b..dd86f7218 100644 --- a/oml/lightning/pipelines/train.py +++ b/oml/lightning/pipelines/train.py @@ -6,7 +6,7 @@ from torch.utils.data import DataLoader from oml.const import TCfg -from oml.datasets.base import get_retrieval_datasets +from oml.datasets.images import get_retrieval_images_datasets from oml.lightning.callbacks.metric import MetricValCallback, MetricValCallbackDDP from oml.lightning.modules.extractor import ExtractorModule, ExtractorModuleDDP from oml.lightning.pipelines.parser import ( @@ -26,7 +26,7 @@ def get_retrieval_loaders(cfg: TCfg) -> Tuple[DataLoader, DataLoader]: - train_dataset, valid_dataset = get_retrieval_datasets( + train_dataset, valid_dataset = get_retrieval_images_datasets( dataset_root=Path(cfg["dataset_root"]), transforms_train=get_transforms_by_cfg(cfg["transforms_train"]), transforms_val=get_transforms_by_cfg(cfg["transforms_val"]), diff --git a/oml/lightning/pipelines/validate.py b/oml/lightning/pipelines/validate.py index b60edcf95..543522bf4 100644 --- a/oml/lightning/pipelines/validate.py +++ b/oml/lightning/pipelines/validate.py @@ -6,7 +6,7 @@ from torch.utils.data import DataLoader from oml.const import TCfg -from oml.datasets.base import get_retrieval_datasets +from oml.datasets.images import get_retrieval_images_datasets from oml.lightning.callbacks.metric import MetricValCallback, MetricValCallbackDDP from oml.lightning.modules.extractor import ExtractorModule, ExtractorModuleDDP from oml.lightning.pipelines.parser import ( @@ -35,7 +35,7 @@ def extractor_validation_pipeline(cfg: TCfg) -> Tuple[pl.Trainer, Dict[str, Any] pprint(cfg) - _, valid_dataset = get_retrieval_datasets( + _, valid_dataset = get_retrieval_images_datasets( dataset_root=Path(cfg["dataset_root"]), transforms_train=None, transforms_val=get_transforms_by_cfg(cfg["transforms_val"]), diff --git a/tests/test_integrations/test_lightning/test_pipeline.py b/tests/test_integrations/test_lightning/test_pipeline.py index 3f99822b7..afb99a1cc 100644 --- a/tests/test_integrations/test_lightning/test_pipeline.py +++ b/tests/test_integrations/test_lightning/test_pipeline.py @@ -1,3 +1,4 @@ +import sys import tempfile from functools import partial from typing import Any, Dict, List @@ -92,6 +93,7 @@ def create_retrieval_callback(loader_idx: int, samples_in_getitem: int) -> Metri return metric_callback +@pytest.mark.skipif(sys.platform == "darwin", reason="Does not run on macOS") @pytest.mark.parametrize( "samples_in_getitem, is_error_expected, pipeline", [ diff --git a/tests/test_integrations/test_retrieval_validation.py b/tests/test_integrations/test_retrieval_validation.py index e03bd798b..cbaef4f61 100644 --- a/tests/test_integrations/test_retrieval_validation.py +++ b/tests/test_integrations/test_retrieval_validation.py @@ -1,24 +1,19 @@ from math import isclose -from typing import Any, Dict, Tuple +from typing import Tuple import pytest import torch -from torch import Tensor +from torch import BoolTensor, FloatTensor, LongTensor from torch.utils.data import DataLoader -from oml.const import ( - EMBEDDINGS_KEY, - INPUT_TENSORS_KEY, - IS_GALLERY_KEY, - IS_QUERY_KEY, - LABELS_KEY, - OVERALL_CATEGORIES_KEY, -) -from oml.interfaces.datasets import IDatasetQueryGallery +from oml.const import EMBEDDINGS_KEY, INPUT_TENSORS_KEY, OVERALL_CATEGORIES_KEY from oml.metrics.embeddings import EmbeddingMetrics -from tests.test_integrations.utils import IdealClusterEncoder +from tests.test_integrations.utils import ( + EmbeddingsQueryGalleryDataset, + IdealClusterEncoder, +) -TData = Tuple[Tensor, Tensor, Tensor, Tensor, float] +TData = Tuple[LongTensor, BoolTensor, BoolTensor, FloatTensor, float] def get_separate_query_gallery() -> TData: @@ -33,7 +28,7 @@ def get_separate_query_gallery() -> TData: cmc_gt = 3 / 5 - return labels, query_mask, gallery_mask, input_tensors, cmc_gt + return labels.long(), query_mask.bool(), gallery_mask.bool(), input_tensors, cmc_gt def get_shared_query_gallery() -> TData: @@ -46,28 +41,7 @@ def get_shared_query_gallery() -> TData: cmc_gt = 7 / 8 - return labels, query_mask, gallery_mask, input_tensors, cmc_gt - - -class DummyQGDataset(IDatasetQueryGallery): - def __init__(self, labels: Tensor, gallery_mask: Tensor, query_mask: Tensor, input_tensors: Tensor): - assert len(labels) == len(gallery_mask) == len(query_mask) - - self.labels = labels - self.gallery_mask = gallery_mask - self.query_mask = query_mask - self.input_tensors = input_tensors - - def __getitem__(self, idx: int) -> Dict[str, Any]: - return { - LABELS_KEY: self.labels[idx], - INPUT_TENSORS_KEY: self.input_tensors[idx], - IS_QUERY_KEY: self.query_mask[idx], - IS_GALLERY_KEY: self.gallery_mask[idx], - } - - def __len__(self) -> int: - return len(self.labels) + return labels.long(), query_mask.bool(), gallery_mask.bool(), input_tensors, cmc_gt @pytest.mark.parametrize("batch_size", [1, 5]) @@ -77,11 +51,11 @@ def __len__(self) -> int: def test_retrieval_validation(batch_size: int, shuffle: bool, num_workers: int, data: TData) -> None: labels, query_mask, gallery_mask, input_tensors, cmc_gt = data - dataset = DummyQGDataset( + dataset = EmbeddingsQueryGalleryDataset( labels=labels, - input_tensors=input_tensors, - query_mask=query_mask, - gallery_mask=gallery_mask, + embeddings=input_tensors, + is_query=query_mask, + is_gallery=gallery_mask, ) loader = DataLoader( diff --git a/tests/test_integrations/utils.py b/tests/test_integrations/utils.py index 48d4fa7a2..2cb79cbac 100644 --- a/tests/test_integrations/utils.py +++ b/tests/test_integrations/utils.py @@ -1,6 +1,19 @@ +from typing import Any, Dict, Optional + +import numpy as np import torch -from torch import nn +from torch import BoolTensor, FloatTensor, LongTensor, nn +from oml.const import ( + CATEGORIES_COLUMN, + INDEX_KEY, + INPUT_TENSORS_KEY, + IS_GALLERY_KEY, + IS_QUERY_KEY, + LABELS_KEY, + SEQUENCE_COLUMN, +) +from oml.interfaces.datasets import IDatasetQueryGallery from oml.utils.misc import one_hot @@ -20,3 +33,60 @@ def forward(self, labels: torch.Tensor, need_noise: bool = True) -> torch.Tensor embeddings = labels + need_noise * 0.01 * torch.randn_like(labels, dtype=torch.float) embeddings = embeddings.view((len(labels), 1)) return embeddings + + +class EmbeddingsQueryGalleryDataset(IDatasetQueryGallery): + def __init__( + self, + embeddings: FloatTensor, + labels: LongTensor, + is_query: BoolTensor, + is_gallery: BoolTensor, + categories: Optional[np.ndarray] = None, + sequence: Optional[np.ndarray] = None, + input_tensors_key: str = INPUT_TENSORS_KEY, + labels_key: str = LABELS_KEY, + index_key: str = INDEX_KEY, + ): + super().__init__() + assert len(embeddings) == len(labels) == len(is_query) == len(is_gallery) + + self._embeddings = embeddings + self._labels = labels + self._is_query = is_query + self._is_gallery = is_gallery + + self.extra_data = {} + if categories: + self.extra_data[CATEGORIES_COLUMN] = categories + + if sequence: + self.extra_data[SEQUENCE_COLUMN] = sequence + + self.input_tensors_key = input_tensors_key + self.labels_key = labels_key + self.index_key = index_key + + def __getitem__(self, idx: int) -> Dict[str, Any]: + batch = { + self.input_tensors_key: self._embeddings[idx], + self.labels_key: self._labels[idx], + self.index_key: idx, + # todo 522: remove + IS_QUERY_KEY: self._is_query[idx], + IS_GALLERY_KEY: self._is_gallery[idx], + } + + return batch + + def __len__(self) -> int: + return len(self._embeddings) + + def get_query_ids(self) -> LongTensor: + return self._is_query.nonzero().squeeze() + + def get_gallery_ids(self) -> LongTensor: + return self._is_gallery.nonzero().squeeze() + + def get_labels(self) -> np.ndarray: + return np.array(self._labels) diff --git a/tests/test_oml/test_datasets/test_list_dataest.py b/tests/test_oml/test_datasets/test_list_dataest.py index b0d6391ee..c431b5e44 100644 --- a/tests/test_oml/test_datasets/test_list_dataest.py +++ b/tests/test_oml/test_datasets/test_list_dataest.py @@ -6,8 +6,8 @@ import torch from torch.utils.data import DataLoader -from oml.const import MOCK_DATASET_PATH -from oml.datasets.list_dataset import ListDataset, TBBox +from oml.const import MOCK_DATASET_PATH, TBBox +from oml.datasets.list_dataset import ListDataset @pytest.fixture diff --git a/tests/test_oml/test_registry/test_registry.py b/tests/test_oml/test_registry/test_registry.py index 6654d5722..76eb71342 100644 --- a/tests/test_oml/test_registry/test_registry.py +++ b/tests/test_oml/test_registry/test_registry.py @@ -140,9 +140,10 @@ def test_saving_transforms_as_files() -> None: """ cfg = yaml.safe_load(cfg) - save_transforms_as_files(cfg) + names_files = save_transforms_as_files(cfg) - assert Path("transforms_train.yaml").exists() - assert not Path("transforms_val.yaml").exists() + assert len(names_files) == 1, "Check that we only saved train transforms as expected" - Path("transforms_train.yaml").unlink() + file = Path(names_files[0][1]) + assert file.exists() + Path(file).unlink() From cf3f1dbd51d5ca72f1cec933d70b7d8d71327806 Mon Sep 17 00:00:00 2001 From: alekseysh Date: Fri, 19 Apr 2024 08:27:23 +0700 Subject: [PATCH 02/23] removed ListDataset --- docs/source/contents/datasets.rst | 9 -- oml/datasets/images.py | 6 +- oml/datasets/list_dataset.py | 91 ------------------- oml/datasets/pairs.py | 6 +- oml/inference/flat.py | 4 +- oml/lightning/pipelines/predict.py | 4 +- .../test_datasets/test_list_dataest.py | 79 ---------------- 7 files changed, 10 insertions(+), 189 deletions(-) delete mode 100644 oml/datasets/list_dataset.py delete mode 100644 tests/test_oml/test_datasets/test_list_dataest.py diff --git a/docs/source/contents/datasets.rst b/docs/source/contents/datasets.rst index bcb909bfd..63f284616 100644 --- a/docs/source/contents/datasets.rst +++ b/docs/source/contents/datasets.rst @@ -35,15 +35,6 @@ DatasetQueryGallery .. automethod:: __init__ .. automethod:: __getitem__ -ListDataset -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.datasets.list_dataset.ListDataset - :undoc-members: - :show-inheritance: - - .. automethod:: __init__ - .. automethod:: __getitem__ - EmbeddingPairsDataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: oml.datasets.pairs.EmbeddingPairsDataset diff --git a/oml/datasets/images.py b/oml/datasets/images.py index 8e2da5adf..a1ed54f15 100644 --- a/oml/datasets/images.py +++ b/oml/datasets/images.py @@ -86,7 +86,7 @@ class ImagesBaseDataset(IBaseDataset, IVisualizableDataset): def __init__( self, - paths: List[str], + paths: List[Path], dataset_root: Optional[Union[str, Path]] = None, bboxes: Optional[TBBoxes] = None, extra_data: Optional[Dict[str, Any]] = None, @@ -134,7 +134,7 @@ def __init__( if dataset_root is not None: self._paths = list(map(lambda x: str(Path(dataset_root) / x), paths)) else: - self._paths = paths + self._paths = list(map(str, paths)) self.extra_data = extra_data @@ -198,7 +198,7 @@ def __getitem__(self, idx: int) -> Dict[str, Union[FloatTensor, int]]: item[self.y1_key] = y1 item[self.x2_key] = x2 item[self.y2_key] = y2 - item[self.paths_key] = str(self._paths[idx]) + item[self.paths_key] = self._paths[idx] return item diff --git a/oml/datasets/list_dataset.py b/oml/datasets/list_dataset.py deleted file mode 100644 index 8820afd37..000000000 --- a/oml/datasets/list_dataset.py +++ /dev/null @@ -1,91 +0,0 @@ -from collections import defaultdict -from pathlib import Path -from typing import Any, Dict, Optional, Sequence - -from torch.utils.data import Dataset - -from oml.const import ( - INDEX_KEY, - INPUT_TENSORS_KEY, - LABELS_COLUMN, - PATHS_COLUMN, - X1_COLUMN, - X2_COLUMN, - Y1_COLUMN, - Y2_COLUMN, - TBBoxes, -) -from oml.datasets.images import ImagesBaseDataset -from oml.transforms.images.torchvision import get_normalisation_torch -from oml.transforms.images.utils import TTransforms -from oml.utils.images.images import TImReader - - -class ListDataset(Dataset): - # todo 522: remove the whole dataset - - """This is a dataset to iterate over a list of images.""" - - def __init__( - self, - filenames_list: Sequence[Path], - bboxes: Optional[TBBoxes] = None, - transform: TTransforms = get_normalisation_torch(), - f_imread: Optional[TImReader] = None, - input_tensors_key: str = INPUT_TENSORS_KEY, - cache_size: Optional[int] = 0, - index_key: str = INDEX_KEY, - ): - """ - Args: - filenames_list: list of paths to images - bboxes: Should be either ``None`` or a sequence of bboxes. - If an image has ``N`` boxes, duplicate its - path ``N`` times and provide bounding box for each of them. - If you want to get an embedding for the whole image, set bbox to ``None`` for - this particular image path. The format is ``x1, y1, x2, y2``. - transform: torchvision or albumentations augmentations - f_imread: Function to read images, pass ``None`` so we pick it automatically based on provided transforms - input_tensors_key: Key to put tensors into the batches - cache_size: cache_size: Size of the dataset's cache - index_key: Key to put samples' ids into the batches - - """ - data = defaultdict(list) - data[PATHS_COLUMN] = list(map(str, filenames_list)) - data[LABELS_COLUMN] = ["none"] * len(filenames_list) - - if bboxes is not None: - for bbox in bboxes: - if bbox is not None: - x1, y1, x2, y2 = bbox - else: - x1, y1, x2, y2 = None, None, None, None - - data[X1_COLUMN].append(x1) # type: ignore - data[Y1_COLUMN].append(y1) # type: ignore - data[X2_COLUMN].append(x2) # type: ignore - data[Y2_COLUMN].append(y2) # type: ignore - - self._dataset = ImagesBaseDataset( - paths=list(map(str, filenames_list)), - bboxes=bboxes, - transform=transform, - f_imread=f_imread, - input_tensors_key=input_tensors_key, - cache_size=cache_size, - index_key=index_key, - ) - - self.input_tensors_key = input_tensors_key - self.index_key = index_key - self.paths_key = self._dataset.paths_key - - def __getitem__(self, idx: int) -> Dict[str, Any]: - return self._dataset[idx] - - def __len__(self) -> int: - return len(self._dataset) - - -__all__ = ["ListDataset"] diff --git a/oml/datasets/pairs.py b/oml/datasets/pairs.py index dbae225f9..7d755e1ff 100644 --- a/oml/datasets/pairs.py +++ b/oml/datasets/pairs.py @@ -4,7 +4,7 @@ from torch import Tensor from oml.const import INDEX_KEY, PAIR_1ST_KEY, PAIR_2ND_KEY, TBBoxes -from oml.datasets.list_dataset import ListDataset +from oml.datasets.images import ImagesBaseDataset from oml.interfaces.datasets import IPairsDataset from oml.transforms.images.torchvision import get_normalisation_torch from oml.transforms.images.utils import TTransforms @@ -98,8 +98,8 @@ def __init__( cache_size = cache_size // 2 if cache_size else None dataset_args = {"transform": transform, "f_imread": f_imread, "cache_size": cache_size} - self.dataset1 = ListDataset(paths1, bboxes=bboxes1, **dataset_args) - self.dataset2 = ListDataset(paths2, bboxes=bboxes2, **dataset_args) + self.dataset1 = ImagesBaseDataset(paths=paths1, bboxes=bboxes1, **dataset_args) + self.dataset2 = ImagesBaseDataset(paths=paths2, bboxes=bboxes2, **dataset_args) self.pair_1st_key = pair_1st_key self.pair_2nd_key = pair_2nd_key diff --git a/oml/inference/flat.py b/oml/inference/flat.py index 8c1584a28..00c98b959 100644 --- a/oml/inference/flat.py +++ b/oml/inference/flat.py @@ -7,7 +7,7 @@ from torch import Tensor, nn from oml.const import PATHS_COLUMN, SPLIT_COLUMN -from oml.datasets.list_dataset import ListDataset +from oml.datasets.images import ImagesBaseDataset from oml.inference.abstract import _inference from oml.interfaces.models import IExtractor from oml.transforms.images.utils import TTransforms @@ -28,7 +28,7 @@ def inference_on_images( use_fp16: bool = False, accumulate_on_cpu: bool = True, ) -> Tensor: - dataset = ListDataset(paths, bboxes=None, transform=transform, f_imread=f_imread, cache_size=0) + dataset = ImagesBaseDataset(paths=paths, bboxes=None, transform=transform, f_imread=f_imread, cache_size=0) device = get_device(model) def _apply(model_: nn.Module, batch_: Dict[str, Any]) -> Tensor: diff --git a/oml/lightning/pipelines/predict.py b/oml/lightning/pipelines/predict.py index 1e188187d..12849bf9d 100644 --- a/oml/lightning/pipelines/predict.py +++ b/oml/lightning/pipelines/predict.py @@ -7,7 +7,7 @@ from torch.utils.data import DataLoader from oml.const import IMAGE_EXTENSIONS, TCfg -from oml.datasets.list_dataset import ListDataset +from oml.datasets.images import ImagesBaseDataset from oml.ddp.utils import get_world_size_safe, is_main_process, sync_dicts_ddp from oml.lightning.modules.extractor import ExtractorModule from oml.lightning.pipelines.parser import parse_engine_params_from_config @@ -41,7 +41,7 @@ def extractor_prediction_pipeline(cfg: TCfg) -> None: if broken_images: raise ValueError(f"There are images that cannot be open:\n {broken_images}.") - dataset = ListDataset(filenames_list=filenames, transform=transforms, f_imread=f_imread) + dataset = ImagesBaseDataset(paths=filenames, transform=transforms, f_imread=f_imread) loader = DataLoader( dataset=dataset, batch_size=cfg["bs"], num_workers=cfg["num_workers"], shuffle=False, drop_last=False diff --git a/tests/test_oml/test_datasets/test_list_dataest.py b/tests/test_oml/test_datasets/test_list_dataest.py deleted file mode 100644 index c431b5e44..000000000 --- a/tests/test_oml/test_datasets/test_list_dataest.py +++ /dev/null @@ -1,79 +0,0 @@ -from pathlib import Path -from typing import Iterator, List, Optional, Sequence, Tuple - -import pandas as pd -import pytest -import torch -from torch.utils.data import DataLoader - -from oml.const import MOCK_DATASET_PATH, TBBox -from oml.datasets.list_dataset import ListDataset - - -@pytest.fixture -def images() -> Iterator[List[Path]]: - yield list((MOCK_DATASET_PATH / "images").iterdir()) - - -def get_images_and_boxes() -> Tuple[List[Path], List[TBBox]]: - df = pd.read_csv(MOCK_DATASET_PATH / "df_with_bboxes.csv") - sub_df = df[["path", "x_1", "y_1", "x_2", "y_2"]] - sub_df["path"] = sub_df["path"].apply(lambda p: MOCK_DATASET_PATH / p) - paths, bboxes = [], [] - for row in sub_df.iterrows(): - path, x1, y1, x2, y2 = row[1] - paths.append(path) - bboxes.append((x1, y1, x2, y2)) - return paths, bboxes - - -def get_images_and_boxes_with_nones() -> Tuple[List[Path], List[Optional[TBBox]]]: - import random - - random.seed(42) - - paths, bboxes = [], [] - for path, bbox in zip(*get_images_and_boxes()): - paths.append(path) - bboxes.append(bbox) - if random.random() > 0.5: - paths.append(path) - bboxes.append(None) - return paths, bboxes - - -def test_dataset_len(images: List[Path]) -> None: - assert len(images) > 0 - dataset = ListDataset(images) - - assert len(dataset) == len(images) - - -def test_dataset_iter(images: List[Path]) -> None: - dataset = ListDataset(images) - - for batch in dataset: - assert isinstance(batch[dataset.input_tensors_key], torch.Tensor) - - -def test_dataloader_iter(images: List[Path]) -> None: - dataset = ListDataset(images) - dataloader = DataLoader(dataset) - - for batch in dataloader: - assert batch[dataset.input_tensors_key].ndim == 4 - - -@pytest.mark.parametrize("im_paths,bboxes", [get_images_and_boxes(), get_images_and_boxes_with_nones()]) -def test_list_dataset_iter(im_paths: Sequence[Path], bboxes: Sequence[Optional[TBBox]]) -> None: - dataset = ListDataset(im_paths, bboxes) - - dataloader = DataLoader(dataset) - for batch, box in zip(dataloader, bboxes): - image = batch[dataset.input_tensors_key] - if box is not None: - x1, y1, x2, y2 = box - else: - x1, y1, x2, y2 = 0, 0, image.size()[2], image.size()[3] - assert image.ndim == 4 - assert image.size() == (1, 3, x2 - x1, y2 - y1) From d2e821514967926d2e3c53d8532f43e61b3b7803 Mon Sep 17 00:00:00 2001 From: alekseysh Date: Fri, 19 Apr 2024 08:49:58 +0700 Subject: [PATCH 03/23] upd --- docs/source/contents/datasets.rst | 16 +++++----- docs/source/contents/interfaces.rst | 31 +++++++++++++++++-- oml/datasets/base.py | 6 ++-- oml/datasets/images.py | 18 +++++------ oml/interfaces/datasets.py | 10 +++--- oml/lightning/pipelines/parser.py | 4 +-- .../pipelines/train_postprocessor.py | 6 ++-- .../test_train_with_mining.py | 4 +-- tests/test_integrations/utils.py | 4 +-- .../test_transforms/test_image_augs.py | 6 ++-- 10 files changed, 66 insertions(+), 39 deletions(-) diff --git a/docs/source/contents/datasets.rst b/docs/source/contents/datasets.rst index 63f284616..6d83e210b 100644 --- a/docs/source/contents/datasets.rst +++ b/docs/source/contents/datasets.rst @@ -7,33 +7,35 @@ Datasets .. contents:: :local: -BaseDataset +ImagesBaseDataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.datasets.base.BaseDataset +.. autoclass:: oml.datasets.images.BaseDataset :undoc-members: :show-inheritance: .. automethod:: __init__ + .. automethod:: visualize -DatasetWithLabels +ImagesDatasetLabeled ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.datasets.base.DatasetWithLabels +.. autoclass:: oml.datasets.images.ImagesDatasetLabeled :undoc-members: :show-inheritance: .. automethod:: __init__ .. automethod:: __getitem__ .. automethod:: get_labels - .. automethod:: get_label2category -DatasetQueryGallery +ImagesDatasetQueryGalleryLabeled ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.datasets.base.DatasetQueryGallery +.. autoclass:: oml.datasets.images.DatasetQueryGallery :undoc-members: :show-inheritance: .. automethod:: __init__ .. automethod:: __getitem__ + .. automethod:: get_query_ids + .. automethod:: get_gallery_ids EmbeddingPairsDataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/contents/interfaces.rst b/docs/source/contents/interfaces.rst index aa5a8a01e..c4cf3c182 100644 --- a/docs/source/contents/interfaces.rst +++ b/docs/source/contents/interfaces.rst @@ -52,9 +52,15 @@ ITripletLossWithMiner .. automethod:: forward -IDatasetWithLabels +IBaseDataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.interfaces.datasets.IDatasetWithLabels +.. autoclass:: oml.interfaces.datasets.IBaseDataset + :undoc-members: + :show-inheritance: + +IDatasetLabeled +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: oml.interfaces.datasets.IDatasetLabeled :undoc-members: :show-inheritance: @@ -67,7 +73,18 @@ IDatasetQueryGallery :undoc-members: :show-inheritance: - .. automethod:: __getitem__ + .. automethod:: get_query_ids + .. automethod:: get_gallery_ids + +IDatasetQueryGalleryLabeled +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: oml.interfaces.datasets.IDatasetQueryGalleryLabeled + :undoc-members: + :show-inheritance: + + .. automethod:: get_query_ids + .. automethod:: get_gallery_ids + .. automethod:: get_labels IPairsDataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -78,6 +95,14 @@ IPairsDataset .. automethod:: __init__ .. automethod:: __getitem__ +IVisualizableDataset +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: oml.interfaces.datasets.IVisualizableDataset + :undoc-members: + :show-inheritance: + + .. automethod:: visualize + IBasicMetric ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: oml.interfaces.metrics.IBasicMetric diff --git a/oml/datasets/base.py b/oml/datasets/base.py index 67096d0b0..5b943a640 100644 --- a/oml/datasets/base.py +++ b/oml/datasets/base.py @@ -1,12 +1,12 @@ -from oml.datasets.images import ImagesDatasetQueryGallery, ImagesDatasetWithLabels +from oml.datasets.images import ImagesDatasetLabeled, ImagesDatasetQueryGalleryLabeled -class DatasetWithLabels(ImagesDatasetWithLabels): +class DatasetWithLabels(ImagesDatasetLabeled): # this class allows to have back compatibility pass -class DatasetQueryGallery(ImagesDatasetQueryGallery): +class DatasetQueryGallery(ImagesDatasetQueryGalleryLabeled): # this class allows to have back compatibility pass diff --git a/oml/datasets/images.py b/oml/datasets/images.py index a1ed54f15..d23f6aeb3 100644 --- a/oml/datasets/images.py +++ b/oml/datasets/images.py @@ -39,8 +39,8 @@ ) from oml.interfaces.datasets import ( IBaseDataset, - IDatasetQueryGallery, - IDatasetWithLabels, + IDatasetLabeled, + IDatasetQueryGalleryLabeled, IVisualizableDataset, ) from oml.registry.transforms import get_transforms @@ -218,7 +218,7 @@ def bboxes_keys(self) -> Tuple[str, ...]: return self.x1_key, self.y1_key, self.x2_key, self.y2_key -class ImagesDatasetWithLabels(ImagesBaseDataset, IDatasetWithLabels): +class ImagesDatasetLabeled(ImagesBaseDataset, IDatasetLabeled): """ The dataset of images having their ground truth labels. @@ -298,7 +298,7 @@ def get_label2category(self) -> Optional[Dict[int, Union[str, int]]]: return label2category -class ImagesDatasetQueryGallery(ImagesDatasetWithLabels, IDatasetQueryGallery): +class ImagesDatasetQueryGalleryLabeled(ImagesDatasetLabeled, IDatasetQueryGalleryLabeled): """ The dataset of images having `query`/`gallery` split. @@ -385,7 +385,7 @@ def get_retrieval_images_datasets( dataframe_name: str = "df.csv", cache_size: Optional[int] = 0, verbose: bool = True, -) -> Tuple[IDatasetWithLabels, IDatasetQueryGallery]: +) -> Tuple[IDatasetLabeled, IDatasetQueryGalleryLabeled]: df = pd.read_csv(dataset_root / dataframe_name, index_col=False) check_retrieval_dataframe_format(df, dataset_root=dataset_root, verbose=verbose) @@ -399,7 +399,7 @@ def get_retrieval_images_datasets( df_train = df[df[SPLIT_COLUMN] == "train"].reset_index(drop=True) df_train[LABELS_COLUMN] = df_train[LABELS_COLUMN].map(mapper) - train_dataset = ImagesDatasetWithLabels( + train_dataset = ImagesDatasetLabeled( df=df_train, dataset_root=dataset_root, transform=transforms_train, @@ -409,7 +409,7 @@ def get_retrieval_images_datasets( # val (query + gallery) df_query_gallery = df[df[SPLIT_COLUMN] == "validation"].reset_index(drop=True) - valid_dataset = ImagesDatasetQueryGallery( + valid_dataset = ImagesDatasetQueryGalleryLabeled( df=df_query_gallery, dataset_root=dataset_root, transform=transforms_val, @@ -422,7 +422,7 @@ def get_retrieval_images_datasets( __all__ = [ "ImagesBaseDataset", - "ImagesDatasetWithLabels", - "ImagesDatasetQueryGallery", + "ImagesDatasetLabeled", + "ImagesDatasetQueryGalleryLabeled", "get_retrieval_images_datasets", ] diff --git a/oml/interfaces/datasets.py b/oml/interfaces/datasets.py index e396e5a5c..fe4f6a52f 100644 --- a/oml/interfaces/datasets.py +++ b/oml/interfaces/datasets.py @@ -28,7 +28,7 @@ def __getitem__(self, item: int) -> Dict[str, Any]: raise NotImplementedError() -class IDatasetWithLabels(IBaseDataset, ABC): +class IDatasetLabeled(IBaseDataset, ABC): """ This is an interface for the datasets which provide labels of containing items. @@ -55,7 +55,7 @@ def get_labels(self) -> np.ndarray: raise NotImplementedError() -class IDatasetQueryGalleryPrediction(IBaseDataset, ABC): +class IDatasetQueryGallery(IBaseDataset, ABC): """ This is an interface for the datasets which hold the information on how to split the data into the query and gallery. The query and gallery ids may overlap. @@ -72,7 +72,7 @@ def get_gallery_ids(self) -> LongTensor: raise NotImplementedError() -class IDatasetQueryGallery(IDatasetQueryGalleryPrediction, IDatasetWithLabels, ABC): +class IDatasetQueryGalleryLabeled(IDatasetQueryGallery, IDatasetLabeled, ABC): """ This class is similar to "IDatasetQueryGalleryPrediction", but we also have ground truth labels. """ @@ -117,9 +117,9 @@ def visualize(self, idx: int, color: TColor) -> np.ndarray: __all__ = [ "IBaseDataset", - "IDatasetWithLabels", + "IDatasetLabeled", + "IDatasetQueryGalleryLabeled", "IDatasetQueryGallery", - "IDatasetQueryGalleryPrediction", "IPairsDataset", "IVisualizableDataset", ] diff --git a/oml/lightning/pipelines/parser.py b/oml/lightning/pipelines/parser.py index c0bad14bd..7c0bc75db 100644 --- a/oml/lightning/pipelines/parser.py +++ b/oml/lightning/pipelines/parser.py @@ -6,7 +6,7 @@ from pytorch_lightning.strategies import DDPStrategy from oml.const import TCfg -from oml.interfaces.datasets import IDatasetWithLabels +from oml.interfaces.datasets import IDatasetLabeled from oml.interfaces.loggers import IPipelineLogger from oml.interfaces.samplers import IBatchSampler from oml.lightning.pipelines.logging import TensorBoardPipelineLogger @@ -76,7 +76,7 @@ def parse_scheduler_from_config(cfg: TCfg, optimizer: torch.optim.Optimizer) -> return scheduler_kwargs -def parse_sampler_from_config(cfg: TCfg, dataset: IDatasetWithLabels) -> Optional[IBatchSampler]: +def parse_sampler_from_config(cfg: TCfg, dataset: IDatasetLabeled) -> Optional[IBatchSampler]: if ( (not dataset.categories_key) and (cfg["sampler"] is not None) diff --git a/oml/lightning/pipelines/train_postprocessor.py b/oml/lightning/pipelines/train_postprocessor.py index 97485328e..f8a4f3363 100644 --- a/oml/lightning/pipelines/train_postprocessor.py +++ b/oml/lightning/pipelines/train_postprocessor.py @@ -11,7 +11,7 @@ from torch.utils.data import DataLoader from oml.const import BBOXES_COLUMNS, EMBEDDINGS_KEY, TCfg -from oml.datasets.base import DatasetQueryGallery, DatasetWithLabels +from oml.datasets.base import ImagesDatasetLabeled, ImagesDatasetQueryGalleryLabeled from oml.inference.flat import inference_on_dataframe from oml.interfaces.models import IPairwiseModel from oml.lightning.callbacks.metric import MetricValCallback, MetricValCallbackDDP @@ -81,13 +81,13 @@ def get_loaders_with_embeddings(cfg: TCfg) -> Tuple[DataLoader, DataLoader]: use_fp16=int(cfg.get("precision", 32)) == 16, ) - train_dataset = DatasetWithLabels( + train_dataset = ImagesDatasetLabeled( df=df_train, transform=get_transforms_by_cfg(cfg["transforms_train"]), extra_data={EMBEDDINGS_KEY: emb_train}, ) - valid_dataset = DatasetQueryGallery( + valid_dataset = ImagesDatasetQueryGalleryLabeled( df=df_val, # we don't care about transforms, since the only goal of this dataset is to deliver embeddings transform=get_normalisation_resize_torch(im_size=8), diff --git a/tests/test_integrations/test_train_with_mining.py b/tests/test_integrations/test_train_with_mining.py index a869cf2b8..34ed3d185 100644 --- a/tests/test_integrations/test_train_with_mining.py +++ b/tests/test_integrations/test_train_with_mining.py @@ -8,7 +8,7 @@ from torch.utils.data import DataLoader from oml.const import INPUT_TENSORS_KEY, LABELS_KEY -from oml.interfaces.datasets import IDatasetWithLabels +from oml.interfaces.datasets import IDatasetLabeled from oml.losses.triplet import TripletLossWithMiner from oml.miners.cross_batch import TripletMinerWithMemory from oml.registry.miners import get_miner @@ -16,7 +16,7 @@ from tests.test_integrations.utils import IdealOneHotModel -class DummyDataset(IDatasetWithLabels): +class DummyDataset(IDatasetLabeled): def __init__(self, n_labels: int, n_samples_min: int): self.labels = [] for i in range(n_labels): diff --git a/tests/test_integrations/utils.py b/tests/test_integrations/utils.py index 2cb79cbac..62f47a07e 100644 --- a/tests/test_integrations/utils.py +++ b/tests/test_integrations/utils.py @@ -13,7 +13,7 @@ LABELS_KEY, SEQUENCE_COLUMN, ) -from oml.interfaces.datasets import IDatasetQueryGallery +from oml.interfaces.datasets import IDatasetQueryGalleryLabeled from oml.utils.misc import one_hot @@ -35,7 +35,7 @@ def forward(self, labels: torch.Tensor, need_noise: bool = True) -> torch.Tensor return embeddings -class EmbeddingsQueryGalleryDataset(IDatasetQueryGallery): +class EmbeddingsQueryGalleryDataset(IDatasetQueryGalleryLabeled): def __init__( self, embeddings: FloatTensor, diff --git a/tests/test_oml/test_transforms/test_image_augs.py b/tests/test_oml/test_transforms/test_image_augs.py index fcbb1005b..ae21f9752 100644 --- a/tests/test_oml/test_transforms/test_image_augs.py +++ b/tests/test_oml/test_transforms/test_image_augs.py @@ -5,7 +5,7 @@ from omegaconf import OmegaConf from oml.const import CONFIGS_PATH, MOCK_DATASET_PATH -from oml.datasets.base import DatasetWithLabels +from oml.datasets.images import ImagesDatasetLabeled from oml.registry.transforms import TRANSFORMS_REGISTRY, get_transforms_by_cfg @@ -14,7 +14,7 @@ def test_transforms(aug_name: Optional[str]) -> None: df = pd.read_csv(MOCK_DATASET_PATH / "df.csv") transforms = get_transforms_by_cfg(OmegaConf.load(CONFIGS_PATH / "transforms" / f"{aug_name}.yaml")) - dataset = DatasetWithLabels(df=df, dataset_root=MOCK_DATASET_PATH, transform=transforms) + dataset = ImagesDatasetLabeled(df=df, dataset_root=MOCK_DATASET_PATH, transform=transforms) _ = dataset[0] @@ -23,6 +23,6 @@ def test_transforms(aug_name: Optional[str]) -> None: def test_default_transforms() -> None: df = pd.read_csv(MOCK_DATASET_PATH / "df.csv") - dataset = DatasetWithLabels(df=df, dataset_root=MOCK_DATASET_PATH, transform=None) + dataset = ImagesDatasetLabeled(df=df, dataset_root=MOCK_DATASET_PATH, transform=None) _ = dataset[0] assert True From ad3fe93afd58057faf133c258d43e50632d85269 Mon Sep 17 00:00:00 2001 From: alekseysh Date: Fri, 19 Apr 2024 09:05:18 +0700 Subject: [PATCH 04/23] minor --- docs/source/contents/datasets.rst | 12 ++++++------ oml/datasets/base.py | 10 +++++----- oml/datasets/images.py | 16 ++++++++-------- oml/datasets/pairs.py | 6 +++--- oml/inference/flat.py | 4 ++-- oml/interfaces/datasets.py | 2 +- oml/lightning/pipelines/predict.py | 4 ++-- oml/lightning/pipelines/train_postprocessor.py | 6 +++--- .../test_oml/test_transforms/test_image_augs.py | 6 +++--- 9 files changed, 33 insertions(+), 33 deletions(-) diff --git a/docs/source/contents/datasets.rst b/docs/source/contents/datasets.rst index 6d83e210b..e4a27e234 100644 --- a/docs/source/contents/datasets.rst +++ b/docs/source/contents/datasets.rst @@ -7,18 +7,18 @@ Datasets .. contents:: :local: -ImagesBaseDataset +ImageBaseDataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.datasets.images.BaseDataset +.. autoclass:: oml.datasets.images.ImageBaseDataset :undoc-members: :show-inheritance: .. automethod:: __init__ .. automethod:: visualize -ImagesDatasetLabeled +ImageDatasetLabeled ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.datasets.images.ImagesDatasetLabeled +.. autoclass:: oml.datasets.images.ImageDatasetLabeled :undoc-members: :show-inheritance: @@ -26,9 +26,9 @@ ImagesDatasetLabeled .. automethod:: __getitem__ .. automethod:: get_labels -ImagesDatasetQueryGalleryLabeled +ImageDatasetQueryGalleryLabeled ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.datasets.images.DatasetQueryGallery +.. autoclass:: oml.datasets.images.ImageDatasetQueryGalleryLabeled :undoc-members: :show-inheritance: diff --git a/oml/datasets/base.py b/oml/datasets/base.py index 5b943a640..ed4aadd36 100644 --- a/oml/datasets/base.py +++ b/oml/datasets/base.py @@ -1,13 +1,13 @@ -from oml.datasets.images import ImagesDatasetLabeled, ImagesDatasetQueryGalleryLabeled +from oml.datasets.images import ImageDatasetLabeled, ImageDatasetQueryGalleryLabeled -class DatasetWithLabels(ImagesDatasetLabeled): - # this class allows to have back compatibility +class DatasetWithLabels(ImageDatasetLabeled): + # this class allows to have backward compatibility pass -class DatasetQueryGallery(ImagesDatasetQueryGalleryLabeled): - # this class allows to have back compatibility +class DatasetQueryGallery(ImageDatasetQueryGalleryLabeled): + # this class allows to have backward compatibility pass diff --git a/oml/datasets/images.py b/oml/datasets/images.py index d23f6aeb3..bf4d35624 100644 --- a/oml/datasets/images.py +++ b/oml/datasets/images.py @@ -75,7 +75,7 @@ def parse_bboxes(df: pd.DataFrame) -> Optional[TBBoxes]: return bboxes -class ImagesBaseDataset(IBaseDataset, IVisualizableDataset): +class ImageBaseDataset(IBaseDataset, IVisualizableDataset): """ The base class that handles image specific logic. @@ -218,7 +218,7 @@ def bboxes_keys(self) -> Tuple[str, ...]: return self.x1_key, self.y1_key, self.x2_key, self.y2_key -class ImagesDatasetLabeled(ImagesBaseDataset, IDatasetLabeled): +class ImageDatasetLabeled(ImageBaseDataset, IDatasetLabeled): """ The dataset of images having their ground truth labels. @@ -298,7 +298,7 @@ def get_label2category(self) -> Optional[Dict[int, Union[str, int]]]: return label2category -class ImagesDatasetQueryGalleryLabeled(ImagesDatasetLabeled, IDatasetQueryGalleryLabeled): +class ImageDatasetQueryGalleryLabeled(ImageDatasetLabeled, IDatasetQueryGalleryLabeled): """ The dataset of images having `query`/`gallery` split. @@ -399,7 +399,7 @@ def get_retrieval_images_datasets( df_train = df[df[SPLIT_COLUMN] == "train"].reset_index(drop=True) df_train[LABELS_COLUMN] = df_train[LABELS_COLUMN].map(mapper) - train_dataset = ImagesDatasetLabeled( + train_dataset = ImageDatasetLabeled( df=df_train, dataset_root=dataset_root, transform=transforms_train, @@ -409,7 +409,7 @@ def get_retrieval_images_datasets( # val (query + gallery) df_query_gallery = df[df[SPLIT_COLUMN] == "validation"].reset_index(drop=True) - valid_dataset = ImagesDatasetQueryGalleryLabeled( + valid_dataset = ImageDatasetQueryGalleryLabeled( df=df_query_gallery, dataset_root=dataset_root, transform=transforms_val, @@ -421,8 +421,8 @@ def get_retrieval_images_datasets( __all__ = [ - "ImagesBaseDataset", - "ImagesDatasetLabeled", - "ImagesDatasetQueryGalleryLabeled", + "ImageBaseDataset", + "ImageDatasetLabeled", + "ImageDatasetQueryGalleryLabeled", "get_retrieval_images_datasets", ] diff --git a/oml/datasets/pairs.py b/oml/datasets/pairs.py index 7d755e1ff..0c7b44d10 100644 --- a/oml/datasets/pairs.py +++ b/oml/datasets/pairs.py @@ -4,7 +4,7 @@ from torch import Tensor from oml.const import INDEX_KEY, PAIR_1ST_KEY, PAIR_2ND_KEY, TBBoxes -from oml.datasets.images import ImagesBaseDataset +from oml.datasets.images import ImageBaseDataset from oml.interfaces.datasets import IPairsDataset from oml.transforms.images.torchvision import get_normalisation_torch from oml.transforms.images.utils import TTransforms @@ -98,8 +98,8 @@ def __init__( cache_size = cache_size // 2 if cache_size else None dataset_args = {"transform": transform, "f_imread": f_imread, "cache_size": cache_size} - self.dataset1 = ImagesBaseDataset(paths=paths1, bboxes=bboxes1, **dataset_args) - self.dataset2 = ImagesBaseDataset(paths=paths2, bboxes=bboxes2, **dataset_args) + self.dataset1 = ImageBaseDataset(paths=paths1, bboxes=bboxes1, **dataset_args) + self.dataset2 = ImageBaseDataset(paths=paths2, bboxes=bboxes2, **dataset_args) self.pair_1st_key = pair_1st_key self.pair_2nd_key = pair_2nd_key diff --git a/oml/inference/flat.py b/oml/inference/flat.py index 00c98b959..e8e4c063f 100644 --- a/oml/inference/flat.py +++ b/oml/inference/flat.py @@ -7,7 +7,7 @@ from torch import Tensor, nn from oml.const import PATHS_COLUMN, SPLIT_COLUMN -from oml.datasets.images import ImagesBaseDataset +from oml.datasets.images import ImageBaseDataset from oml.inference.abstract import _inference from oml.interfaces.models import IExtractor from oml.transforms.images.utils import TTransforms @@ -28,7 +28,7 @@ def inference_on_images( use_fp16: bool = False, accumulate_on_cpu: bool = True, ) -> Tensor: - dataset = ImagesBaseDataset(paths=paths, bboxes=None, transform=transform, f_imread=f_imread, cache_size=0) + dataset = ImageBaseDataset(paths=paths, bboxes=None, transform=transform, f_imread=f_imread, cache_size=0) device = get_device(model) def _apply(model_: nn.Module, batch_: Dict[str, Any]) -> Tensor: diff --git a/oml/interfaces/datasets.py b/oml/interfaces/datasets.py index fe4f6a52f..37e7211b5 100644 --- a/oml/interfaces/datasets.py +++ b/oml/interfaces/datasets.py @@ -74,7 +74,7 @@ def get_gallery_ids(self) -> LongTensor: class IDatasetQueryGalleryLabeled(IDatasetQueryGallery, IDatasetLabeled, ABC): """ - This class is similar to "IDatasetQueryGalleryPrediction", but we also have ground truth labels. + This interface is similar to `IDatasetQueryGallery`, but there are ground truth labels. """ diff --git a/oml/lightning/pipelines/predict.py b/oml/lightning/pipelines/predict.py index 12849bf9d..8a4e8f0dc 100644 --- a/oml/lightning/pipelines/predict.py +++ b/oml/lightning/pipelines/predict.py @@ -7,7 +7,7 @@ from torch.utils.data import DataLoader from oml.const import IMAGE_EXTENSIONS, TCfg -from oml.datasets.images import ImagesBaseDataset +from oml.datasets.images import ImageBaseDataset from oml.ddp.utils import get_world_size_safe, is_main_process, sync_dicts_ddp from oml.lightning.modules.extractor import ExtractorModule from oml.lightning.pipelines.parser import parse_engine_params_from_config @@ -41,7 +41,7 @@ def extractor_prediction_pipeline(cfg: TCfg) -> None: if broken_images: raise ValueError(f"There are images that cannot be open:\n {broken_images}.") - dataset = ImagesBaseDataset(paths=filenames, transform=transforms, f_imread=f_imread) + dataset = ImageBaseDataset(paths=filenames, transform=transforms, f_imread=f_imread) loader = DataLoader( dataset=dataset, batch_size=cfg["bs"], num_workers=cfg["num_workers"], shuffle=False, drop_last=False diff --git a/oml/lightning/pipelines/train_postprocessor.py b/oml/lightning/pipelines/train_postprocessor.py index f8a4f3363..1472eafc5 100644 --- a/oml/lightning/pipelines/train_postprocessor.py +++ b/oml/lightning/pipelines/train_postprocessor.py @@ -11,7 +11,7 @@ from torch.utils.data import DataLoader from oml.const import BBOXES_COLUMNS, EMBEDDINGS_KEY, TCfg -from oml.datasets.base import ImagesDatasetLabeled, ImagesDatasetQueryGalleryLabeled +from oml.datasets.base import ImageDatasetLabeled, ImageDatasetQueryGalleryLabeled from oml.inference.flat import inference_on_dataframe from oml.interfaces.models import IPairwiseModel from oml.lightning.callbacks.metric import MetricValCallback, MetricValCallbackDDP @@ -81,13 +81,13 @@ def get_loaders_with_embeddings(cfg: TCfg) -> Tuple[DataLoader, DataLoader]: use_fp16=int(cfg.get("precision", 32)) == 16, ) - train_dataset = ImagesDatasetLabeled( + train_dataset = ImageDatasetLabeled( df=df_train, transform=get_transforms_by_cfg(cfg["transforms_train"]), extra_data={EMBEDDINGS_KEY: emb_train}, ) - valid_dataset = ImagesDatasetQueryGalleryLabeled( + valid_dataset = ImageDatasetQueryGalleryLabeled( df=df_val, # we don't care about transforms, since the only goal of this dataset is to deliver embeddings transform=get_normalisation_resize_torch(im_size=8), diff --git a/tests/test_oml/test_transforms/test_image_augs.py b/tests/test_oml/test_transforms/test_image_augs.py index ae21f9752..1445d0e84 100644 --- a/tests/test_oml/test_transforms/test_image_augs.py +++ b/tests/test_oml/test_transforms/test_image_augs.py @@ -5,7 +5,7 @@ from omegaconf import OmegaConf from oml.const import CONFIGS_PATH, MOCK_DATASET_PATH -from oml.datasets.images import ImagesDatasetLabeled +from oml.datasets.images import ImageDatasetLabeled from oml.registry.transforms import TRANSFORMS_REGISTRY, get_transforms_by_cfg @@ -14,7 +14,7 @@ def test_transforms(aug_name: Optional[str]) -> None: df = pd.read_csv(MOCK_DATASET_PATH / "df.csv") transforms = get_transforms_by_cfg(OmegaConf.load(CONFIGS_PATH / "transforms" / f"{aug_name}.yaml")) - dataset = ImagesDatasetLabeled(df=df, dataset_root=MOCK_DATASET_PATH, transform=transforms) + dataset = ImageDatasetLabeled(df=df, dataset_root=MOCK_DATASET_PATH, transform=transforms) _ = dataset[0] @@ -23,6 +23,6 @@ def test_transforms(aug_name: Optional[str]) -> None: def test_default_transforms() -> None: df = pd.read_csv(MOCK_DATASET_PATH / "df.csv") - dataset = ImagesDatasetLabeled(df=df, dataset_root=MOCK_DATASET_PATH, transform=None) + dataset = ImageDatasetLabeled(df=df, dataset_root=MOCK_DATASET_PATH, transform=None) _ = dataset[0] assert True From 6919668580c16c7ceb8f17d49b6276a9263717f8 Mon Sep 17 00:00:00 2001 From: alekseysh Date: Fri, 19 Apr 2024 09:25:24 +0700 Subject: [PATCH 05/23] fix --- oml/datasets/images.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/oml/datasets/images.py b/oml/datasets/images.py index bf4d35624..5f9d43757 100644 --- a/oml/datasets/images.py +++ b/oml/datasets/images.py @@ -57,12 +57,12 @@ def parse_bboxes(df: pd.DataFrame) -> Optional[TBBoxes]: - n_existing_columns = sum([x in df for x in [X1_COLUMN, X2_COLUMN, Y1_COLUMN, Y2_COLUMN]]) + n_existing_columns = sum([x in df for x in [X1_COLUMN, Y1_COLUMN, X2_COLUMN, Y2_COLUMN]]) if n_existing_columns == 4: bboxes = [] - for row in df.iterrows(): - bbox = int(row[X1_COLUMN]), int(row[X2_COLUMN]), int(row[Y1_COLUMN]), int(row[Y2_COLUMN]) + for _, row in df.iterrows(): + bbox = int(row[X1_COLUMN]), int(row[Y1_COLUMN]), int(row[X2_COLUMN]), int(row[Y2_COLUMN]) bbox = None if any(coord is None for coord in bbox) else bbox bboxes.append(bbox) @@ -206,7 +206,7 @@ def __len__(self) -> int: return len(self._paths) def visualize(self, idx: int, color: TColor = BLACK) -> np.ndarray: - bbox = torch.tensor(self._bboxes[idx]) if (self._bboxes is not None) else torch.tensor([None] * 4) + bbox = torch.tensor(self._bboxes[idx]) if (self._bboxes is not None) else torch.tensor([torch.nan] * 4) image = get_img_with_bbox(im_path=self._paths[idx], bbox=bbox, color=color) image = square_pad(image) From e75f99124e0b46b5011f299019b1211d7ab4a673 Mon Sep 17 00:00:00 2001 From: alekseysh Date: Fri, 19 Apr 2024 11:01:32 +0700 Subject: [PATCH 06/23] upd --- tests/test_runs/test_pipelines/configs/validate.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_runs/test_pipelines/configs/validate.yaml b/tests/test_runs/test_pipelines/configs/validate.yaml index 942e99296..5096f57a6 100644 --- a/tests/test_runs/test_pipelines/configs/validate.yaml +++ b/tests/test_runs/test_pipelines/configs/validate.yaml @@ -2,7 +2,7 @@ accelerator: cpu devices: 1 dataset_root: path_to_replace # we will replace it in runtime with the default dataset folder -dataframe_name: df.csv +dataframe_name: df_with_bboxes.csv transforms_val: name: norm_resize_torch From 4d381ed740f3f83639d8b08afa42575d0b6c2eea Mon Sep 17 00:00:00 2001 From: alekseysh Date: Fri, 19 Apr 2024 15:43:50 +0700 Subject: [PATCH 07/23] update_naming --- docs/source/contents/datasets.rst | 8 ++++---- docs/source/contents/interfaces.rst | 12 ++++++------ oml/datasets/base.py | 6 +++--- oml/datasets/images.py | 18 +++++++++--------- oml/interfaces/datasets.py | 14 +++++++------- oml/lightning/pipelines/parser.py | 4 ++-- oml/lightning/pipelines/train_postprocessor.py | 6 +++--- .../test_train_with_mining.py | 4 ++-- tests/test_integrations/utils.py | 4 ++-- .../test_transforms/test_image_augs.py | 6 +++--- 10 files changed, 41 insertions(+), 41 deletions(-) diff --git a/docs/source/contents/datasets.rst b/docs/source/contents/datasets.rst index e4a27e234..41bbc9cc0 100644 --- a/docs/source/contents/datasets.rst +++ b/docs/source/contents/datasets.rst @@ -16,9 +16,9 @@ ImageBaseDataset .. automethod:: __init__ .. automethod:: visualize -ImageDatasetLabeled +ImageLabeledDataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.datasets.images.ImageDatasetLabeled +.. autoclass:: oml.datasets.images.ImageLabeledDataset :undoc-members: :show-inheritance: @@ -26,9 +26,9 @@ ImageDatasetLabeled .. automethod:: __getitem__ .. automethod:: get_labels -ImageDatasetQueryGalleryLabeled +ImageQueryGalleryLabeledDataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.datasets.images.ImageDatasetQueryGalleryLabeled +.. autoclass:: oml.datasets.images.ImageQueryGalleryLabeledDataset :undoc-members: :show-inheritance: diff --git a/docs/source/contents/interfaces.rst b/docs/source/contents/interfaces.rst index c4cf3c182..7e7224491 100644 --- a/docs/source/contents/interfaces.rst +++ b/docs/source/contents/interfaces.rst @@ -58,27 +58,27 @@ IBaseDataset :undoc-members: :show-inheritance: -IDatasetLabeled +ILabeledDataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.interfaces.datasets.IDatasetLabeled +.. autoclass:: oml.interfaces.datasets.ILabeledDataset :undoc-members: :show-inheritance: .. automethod:: __getitem__ .. automethod:: get_labels -IDatasetQueryGallery +IQueryGalleryDataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.interfaces.datasets.IDatasetQueryGallery +.. autoclass:: oml.interfaces.datasets.IQueryGalleryDataset :undoc-members: :show-inheritance: .. automethod:: get_query_ids .. automethod:: get_gallery_ids -IDatasetQueryGalleryLabeled +IQueryGalleryLabeledDataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.interfaces.datasets.IDatasetQueryGalleryLabeled +.. autoclass:: oml.interfaces.datasets.IQueryGalleryLabeledDataset :undoc-members: :show-inheritance: diff --git a/oml/datasets/base.py b/oml/datasets/base.py index ed4aadd36..bd1aafd97 100644 --- a/oml/datasets/base.py +++ b/oml/datasets/base.py @@ -1,12 +1,12 @@ -from oml.datasets.images import ImageDatasetLabeled, ImageDatasetQueryGalleryLabeled +from oml.datasets.images import ImageLabeledDataset, ImageQueryGalleryLabeledDataset -class DatasetWithLabels(ImageDatasetLabeled): +class DatasetWithLabels(ImageLabeledDataset): # this class allows to have backward compatibility pass -class DatasetQueryGallery(ImageDatasetQueryGalleryLabeled): +class DatasetQueryGallery(ImageQueryGalleryLabeledDataset): # this class allows to have backward compatibility pass diff --git a/oml/datasets/images.py b/oml/datasets/images.py index 5f9d43757..699651b9b 100644 --- a/oml/datasets/images.py +++ b/oml/datasets/images.py @@ -39,8 +39,8 @@ ) from oml.interfaces.datasets import ( IBaseDataset, - IDatasetLabeled, - IDatasetQueryGalleryLabeled, + ILabeledDataset, + IQueryGalleryLabeledDataset, IVisualizableDataset, ) from oml.registry.transforms import get_transforms @@ -218,7 +218,7 @@ def bboxes_keys(self) -> Tuple[str, ...]: return self.x1_key, self.y1_key, self.x2_key, self.y2_key -class ImageDatasetLabeled(ImageBaseDataset, IDatasetLabeled): +class ImageLabeledDataset(ImageBaseDataset, ILabeledDataset): """ The dataset of images having their ground truth labels. @@ -298,7 +298,7 @@ def get_label2category(self) -> Optional[Dict[int, Union[str, int]]]: return label2category -class ImageDatasetQueryGalleryLabeled(ImageDatasetLabeled, IDatasetQueryGalleryLabeled): +class ImageQueryGalleryLabeledDataset(ImageLabeledDataset, IQueryGalleryLabeledDataset): """ The dataset of images having `query`/`gallery` split. @@ -385,7 +385,7 @@ def get_retrieval_images_datasets( dataframe_name: str = "df.csv", cache_size: Optional[int] = 0, verbose: bool = True, -) -> Tuple[IDatasetLabeled, IDatasetQueryGalleryLabeled]: +) -> Tuple[ILabeledDataset, IQueryGalleryLabeledDataset]: df = pd.read_csv(dataset_root / dataframe_name, index_col=False) check_retrieval_dataframe_format(df, dataset_root=dataset_root, verbose=verbose) @@ -399,7 +399,7 @@ def get_retrieval_images_datasets( df_train = df[df[SPLIT_COLUMN] == "train"].reset_index(drop=True) df_train[LABELS_COLUMN] = df_train[LABELS_COLUMN].map(mapper) - train_dataset = ImageDatasetLabeled( + train_dataset = ImageLabeledDataset( df=df_train, dataset_root=dataset_root, transform=transforms_train, @@ -409,7 +409,7 @@ def get_retrieval_images_datasets( # val (query + gallery) df_query_gallery = df[df[SPLIT_COLUMN] == "validation"].reset_index(drop=True) - valid_dataset = ImageDatasetQueryGalleryLabeled( + valid_dataset = ImageQueryGalleryLabeledDataset( df=df_query_gallery, dataset_root=dataset_root, transform=transforms_val, @@ -422,7 +422,7 @@ def get_retrieval_images_datasets( __all__ = [ "ImageBaseDataset", - "ImageDatasetLabeled", - "ImageDatasetQueryGalleryLabeled", + "ImageLabeledDataset", + "ImageQueryGalleryLabeledDataset", "get_retrieval_images_datasets", ] diff --git a/oml/interfaces/datasets.py b/oml/interfaces/datasets.py index 37e7211b5..4726b7b36 100644 --- a/oml/interfaces/datasets.py +++ b/oml/interfaces/datasets.py @@ -28,7 +28,7 @@ def __getitem__(self, item: int) -> Dict[str, Any]: raise NotImplementedError() -class IDatasetLabeled(IBaseDataset, ABC): +class ILabeledDataset(IBaseDataset, ABC): """ This is an interface for the datasets which provide labels of containing items. @@ -55,7 +55,7 @@ def get_labels(self) -> np.ndarray: raise NotImplementedError() -class IDatasetQueryGallery(IBaseDataset, ABC): +class IQueryGalleryDataset(IBaseDataset, ABC): """ This is an interface for the datasets which hold the information on how to split the data into the query and gallery. The query and gallery ids may overlap. @@ -72,9 +72,9 @@ def get_gallery_ids(self) -> LongTensor: raise NotImplementedError() -class IDatasetQueryGalleryLabeled(IDatasetQueryGallery, IDatasetLabeled, ABC): +class IQueryGalleryLabeledDataset(IQueryGalleryDataset, ILabeledDataset, ABC): """ - This interface is similar to `IDatasetQueryGallery`, but there are ground truth labels. + This interface is similar to `IQueryGalleryDataset`, but there are ground truth labels. """ @@ -117,9 +117,9 @@ def visualize(self, idx: int, color: TColor) -> np.ndarray: __all__ = [ "IBaseDataset", - "IDatasetLabeled", - "IDatasetQueryGalleryLabeled", - "IDatasetQueryGallery", + "ILabeledDataset", + "IQueryGalleryLabeledDataset", + "IQueryGalleryDataset", "IPairsDataset", "IVisualizableDataset", ] diff --git a/oml/lightning/pipelines/parser.py b/oml/lightning/pipelines/parser.py index 7c0bc75db..306e5ce7a 100644 --- a/oml/lightning/pipelines/parser.py +++ b/oml/lightning/pipelines/parser.py @@ -6,7 +6,7 @@ from pytorch_lightning.strategies import DDPStrategy from oml.const import TCfg -from oml.interfaces.datasets import IDatasetLabeled +from oml.interfaces.datasets import ILabeledDataset from oml.interfaces.loggers import IPipelineLogger from oml.interfaces.samplers import IBatchSampler from oml.lightning.pipelines.logging import TensorBoardPipelineLogger @@ -76,7 +76,7 @@ def parse_scheduler_from_config(cfg: TCfg, optimizer: torch.optim.Optimizer) -> return scheduler_kwargs -def parse_sampler_from_config(cfg: TCfg, dataset: IDatasetLabeled) -> Optional[IBatchSampler]: +def parse_sampler_from_config(cfg: TCfg, dataset: ILabeledDataset) -> Optional[IBatchSampler]: if ( (not dataset.categories_key) and (cfg["sampler"] is not None) diff --git a/oml/lightning/pipelines/train_postprocessor.py b/oml/lightning/pipelines/train_postprocessor.py index 1472eafc5..cc302941f 100644 --- a/oml/lightning/pipelines/train_postprocessor.py +++ b/oml/lightning/pipelines/train_postprocessor.py @@ -11,7 +11,7 @@ from torch.utils.data import DataLoader from oml.const import BBOXES_COLUMNS, EMBEDDINGS_KEY, TCfg -from oml.datasets.base import ImageDatasetLabeled, ImageDatasetQueryGalleryLabeled +from oml.datasets.base import ImageLabeledDataset, ImageQueryGalleryLabeledDataset from oml.inference.flat import inference_on_dataframe from oml.interfaces.models import IPairwiseModel from oml.lightning.callbacks.metric import MetricValCallback, MetricValCallbackDDP @@ -81,13 +81,13 @@ def get_loaders_with_embeddings(cfg: TCfg) -> Tuple[DataLoader, DataLoader]: use_fp16=int(cfg.get("precision", 32)) == 16, ) - train_dataset = ImageDatasetLabeled( + train_dataset = ImageLabeledDataset( df=df_train, transform=get_transforms_by_cfg(cfg["transforms_train"]), extra_data={EMBEDDINGS_KEY: emb_train}, ) - valid_dataset = ImageDatasetQueryGalleryLabeled( + valid_dataset = ImageQueryGalleryLabeledDataset( df=df_val, # we don't care about transforms, since the only goal of this dataset is to deliver embeddings transform=get_normalisation_resize_torch(im_size=8), diff --git a/tests/test_integrations/test_train_with_mining.py b/tests/test_integrations/test_train_with_mining.py index 34ed3d185..95b17fb0f 100644 --- a/tests/test_integrations/test_train_with_mining.py +++ b/tests/test_integrations/test_train_with_mining.py @@ -8,7 +8,7 @@ from torch.utils.data import DataLoader from oml.const import INPUT_TENSORS_KEY, LABELS_KEY -from oml.interfaces.datasets import IDatasetLabeled +from oml.interfaces.datasets import ILabeledDataset from oml.losses.triplet import TripletLossWithMiner from oml.miners.cross_batch import TripletMinerWithMemory from oml.registry.miners import get_miner @@ -16,7 +16,7 @@ from tests.test_integrations.utils import IdealOneHotModel -class DummyDataset(IDatasetLabeled): +class DummyDataset(ILabeledDataset): def __init__(self, n_labels: int, n_samples_min: int): self.labels = [] for i in range(n_labels): diff --git a/tests/test_integrations/utils.py b/tests/test_integrations/utils.py index 62f47a07e..7c8627a1c 100644 --- a/tests/test_integrations/utils.py +++ b/tests/test_integrations/utils.py @@ -13,7 +13,7 @@ LABELS_KEY, SEQUENCE_COLUMN, ) -from oml.interfaces.datasets import IDatasetQueryGalleryLabeled +from oml.interfaces.datasets import IQueryGalleryLabeledDataset from oml.utils.misc import one_hot @@ -35,7 +35,7 @@ def forward(self, labels: torch.Tensor, need_noise: bool = True) -> torch.Tensor return embeddings -class EmbeddingsQueryGalleryDataset(IDatasetQueryGalleryLabeled): +class EmbeddingsQueryGalleryDataset(IQueryGalleryLabeledDataset): def __init__( self, embeddings: FloatTensor, diff --git a/tests/test_oml/test_transforms/test_image_augs.py b/tests/test_oml/test_transforms/test_image_augs.py index 1445d0e84..6773a5bae 100644 --- a/tests/test_oml/test_transforms/test_image_augs.py +++ b/tests/test_oml/test_transforms/test_image_augs.py @@ -5,7 +5,7 @@ from omegaconf import OmegaConf from oml.const import CONFIGS_PATH, MOCK_DATASET_PATH -from oml.datasets.images import ImageDatasetLabeled +from oml.datasets.images import ImageLabeledDataset from oml.registry.transforms import TRANSFORMS_REGISTRY, get_transforms_by_cfg @@ -14,7 +14,7 @@ def test_transforms(aug_name: Optional[str]) -> None: df = pd.read_csv(MOCK_DATASET_PATH / "df.csv") transforms = get_transforms_by_cfg(OmegaConf.load(CONFIGS_PATH / "transforms" / f"{aug_name}.yaml")) - dataset = ImageDatasetLabeled(df=df, dataset_root=MOCK_DATASET_PATH, transform=transforms) + dataset = ImageLabeledDataset(df=df, dataset_root=MOCK_DATASET_PATH, transform=transforms) _ = dataset[0] @@ -23,6 +23,6 @@ def test_transforms(aug_name: Optional[str]) -> None: def test_default_transforms() -> None: df = pd.read_csv(MOCK_DATASET_PATH / "df.csv") - dataset = ImageDatasetLabeled(df=df, dataset_root=MOCK_DATASET_PATH, transform=None) + dataset = ImageLabeledDataset(df=df, dataset_root=MOCK_DATASET_PATH, transform=None) _ = dataset[0] assert True From 08d704876ac77204bf3838ba36103e286bfea6eb Mon Sep 17 00:00:00 2001 From: alekseysh Date: Fri, 19 Apr 2024 16:18:33 +0700 Subject: [PATCH 08/23] simplified postprocessing and inference --- .../examples_source/postprocessing/predict.md | 4 +- .../postprocessing/train_val.md | 4 +- docs/source/contents/datasets.rst | 13 +- docs/source/contents/interfaces.rst | 12 +- docs/source/contents/postprocessing.rst | 32 +-- .../postprocessor/pairwise_embeddings.yaml | 11 - ...ise_images.yaml => pairwise_reranker.yaml} | 6 +- oml/datasets/pairs.py | 110 ++------ oml/inference/flat.py | 100 ------- oml/inference/pairs.py | 97 ------- oml/interfaces/datasets.py | 4 +- oml/interfaces/retrieval.py | 53 +--- .../pipelines/train_postprocessor.py | 4 +- oml/registry/postprocessors.py | 17 +- oml/retrieval/postprocessors/pairwise.py | 255 ++++-------------- .../postprocessor_train.yaml | 7 +- .../postprocessor_validate.yaml | 7 +- .../validate_postprocessor.py | 166 +++++++++++- .../visualisation.ipynb | 7 +- .../test_metrics/test_embedding_metrics.py | 6 +- .../test_pairwise_embeddings.py | 10 +- .../test_pairwise_images.py | 4 +- .../configs/train_postprocessor.yaml | 7 +- .../test_pipelines/configs/validate.yaml | 6 +- 24 files changed, 275 insertions(+), 667 deletions(-) delete mode 100644 oml/configs/postprocessor/pairwise_embeddings.yaml rename oml/configs/postprocessor/{pairwise_images.yaml => pairwise_reranker.yaml} (77%) delete mode 100644 oml/inference/flat.py delete mode 100644 oml/inference/pairs.py diff --git a/docs/readme/examples_source/postprocessing/predict.md b/docs/readme/examples_source/postprocessing/predict.md index 22d6cc5f3..0c6edeafc 100644 --- a/docs/readme/examples_source/postprocessing/predict.md +++ b/docs/readme/examples_source/postprocessing/predict.md @@ -12,7 +12,7 @@ from oml.datasets.base import DatasetQueryGallery from oml.inference.flat import inference_on_dataframe from oml.models import ConcatSiamese, ViTExtractor from oml.registry.transforms import get_transforms_for_pretrained -from oml.retrieval.postprocessors.pairwise import PairwiseImagesPostprocessor +from oml.retrieval.postprocessors.pairwise import PairwiseReranker from oml.utils.download_mock_dataset import download_mock_dataset from oml.utils.misc_torch import pairwise_dist @@ -32,7 +32,7 @@ print("\nOriginal predictions:\n", torch.topk(distances, dim=1, k=3, largest=Fal # 2. Let's initialise a random pairwise postprocessor to perform re-ranking siamese = ConcatSiamese(extractor=extractor, mlp_hidden_dims=[100]) # Note! Replace it with your trained postprocessor -postprocessor = PairwiseImagesPostprocessor(top_n=3, pairwise_model=siamese, transforms=transforms) +postprocessor = PairwiseReranker(top_n=3, pairwise_model=siamese, transforms=transforms) dataset = DatasetQueryGallery(df_val, extra_data={"embeddings": emb_val}, transform=transforms) loader = DataLoader(dataset, batch_size=4) diff --git a/docs/readme/examples_source/postprocessing/train_val.md b/docs/readme/examples_source/postprocessing/train_val.md index 345f30c8a..f2ee90382 100644 --- a/docs/readme/examples_source/postprocessing/train_val.md +++ b/docs/readme/examples_source/postprocessing/train_val.md @@ -15,7 +15,7 @@ from oml.inference.flat import inference_on_dataframe from oml.metrics.embeddings import EmbeddingMetrics from oml.miners.pairs import PairsMiner from oml.models import ConcatSiamese, ViTExtractor -from oml.retrieval.postprocessors.pairwise import PairwiseImagesPostprocessor +from oml.retrieval.postprocessors.pairwise import PairwiseReranker from oml.samplers.balance import BalanceSampler from oml.transforms.images.torchvision import get_normalisation_resize_torch from oml.utils.download_mock_dataset import download_mock_dataset @@ -54,7 +54,7 @@ for batch in train_loader: val_dataset = DatasetQueryGallery(df=df_val, extra_data={"embeddings": embeddings_val}, transform=transform) valid_loader = DataLoader(val_dataset, batch_size=4, shuffle=False) -postprocessor = PairwiseImagesPostprocessor(top_n=3, pairwise_model=siamese, transforms=transform) +postprocessor = PairwiseReranker(top_n=3, pairwise_model=siamese, transforms=transform) calculator = EmbeddingMetrics(postprocessor=postprocessor) calculator.setup(num_samples=len(val_dataset)) diff --git a/docs/source/contents/datasets.rst b/docs/source/contents/datasets.rst index 41bbc9cc0..b46ae9e20 100644 --- a/docs/source/contents/datasets.rst +++ b/docs/source/contents/datasets.rst @@ -37,18 +37,9 @@ ImageQueryGalleryLabeledDataset .. automethod:: get_query_ids .. automethod:: get_gallery_ids -EmbeddingPairsDataset +PairDataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.datasets.pairs.EmbeddingPairsDataset - :undoc-members: - :show-inheritance: - - .. automethod:: __init__ - .. automethod:: __getitem__ - -ImagePairsDataset -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.datasets.pairs.ImagePairsDataset +.. autoclass:: oml.datasets.pairs.PairDataset :undoc-members: :show-inheritance: diff --git a/docs/source/contents/interfaces.rst b/docs/source/contents/interfaces.rst index 7e7224491..146b0389f 100644 --- a/docs/source/contents/interfaces.rst +++ b/docs/source/contents/interfaces.rst @@ -86,9 +86,9 @@ IQueryGalleryLabeledDataset .. automethod:: get_gallery_ids .. automethod:: get_labels -IPairsDataset +IPairDataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.interfaces.datasets.IPairsDataset +.. autoclass:: oml.interfaces.datasets.IPairDataset :undoc-members: :show-inheritance: @@ -138,3 +138,11 @@ IPipelineLogger .. automethod:: log_figure .. automethod:: log_pipeline_info + +IRetrievalPostprocessor +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: oml.interfaces.retrieval.IRetrievalPostprocessor + :undoc-members: + :show-inheritance: + + .. automethod:: process diff --git a/docs/source/contents/postprocessing.rst b/docs/source/contents/postprocessing.rst index a2e0be7b4..5af9aae67 100644 --- a/docs/source/contents/postprocessing.rst +++ b/docs/source/contents/postprocessing.rst @@ -7,37 +7,11 @@ Retrieval Post-Processing .. contents:: :local: -IDistancesPostprocessor +PairwiseReranker ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.interfaces.retrieval.IDistancesPostprocessor - :undoc-members: - :show-inheritance: - - .. automethod:: process - -PairwisePostprocessor -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.retrieval.postprocessors.pairwise.PairwisePostprocessor - :undoc-members: - :show-inheritance: - - .. automethod:: process - .. automethod:: inference - -PairwiseEmbeddingsPostprocessor -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.retrieval.postprocessors.pairwise.PairwiseEmbeddingsPostprocessor +.. autoclass:: oml.retrieval.postprocessors.pairwise.PairwiseReranker :undoc-members: :show-inheritance: .. automethod:: __init__ - .. automethod:: inference - -PairwiseImagesPostprocessor -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.retrieval.postprocessors.pairwise.PairwiseImagesPostprocessor - :undoc-members: - :show-inheritance: - - .. automethod:: __init__ - .. automethod:: inference + .. automethod:: process diff --git a/oml/configs/postprocessor/pairwise_embeddings.yaml b/oml/configs/postprocessor/pairwise_embeddings.yaml deleted file mode 100644 index 4399dc9e5..000000000 --- a/oml/configs/postprocessor/pairwise_embeddings.yaml +++ /dev/null @@ -1,11 +0,0 @@ -name: pairwise_embeddings -args: - top_n: 5 - pairwise_model: - name: linear_siamese - args: - feat_dim: 16 - identity_init: True - num_workers: 0 - batch_size: 4 - verbose: False diff --git a/oml/configs/postprocessor/pairwise_images.yaml b/oml/configs/postprocessor/pairwise_reranker.yaml similarity index 77% rename from oml/configs/postprocessor/pairwise_images.yaml rename to oml/configs/postprocessor/pairwise_reranker.yaml index 5cda25a92..50eeaf19d 100644 --- a/oml/configs/postprocessor/pairwise_images.yaml +++ b/oml/configs/postprocessor/pairwise_reranker.yaml @@ -1,4 +1,4 @@ -name: pairwise_images +name: pairwise_reranker args: top_n: 3 pairwise_model: @@ -12,10 +12,6 @@ args: remove_fc: True normalise_features: False weights: resnet50_moco_v2 - transforms: - name: norm_resize_torch - args: - im_size: 224 num_workers: 0 batch_size: 4 verbose: False diff --git a/oml/datasets/pairs.py b/oml/datasets/pairs.py index 0c7b44d10..cb3a2e41b 100644 --- a/oml/datasets/pairs.py +++ b/oml/datasets/pairs.py @@ -1,115 +1,43 @@ -from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Tuple from torch import Tensor -from oml.const import INDEX_KEY, PAIR_1ST_KEY, PAIR_2ND_KEY, TBBoxes -from oml.datasets.images import ImageBaseDataset -from oml.interfaces.datasets import IPairsDataset -from oml.transforms.images.torchvision import get_normalisation_torch -from oml.transforms.images.utils import TTransforms -from oml.utils.images.images import TImReader, imread_pillow +from oml.const import INDEX_KEY, PAIR_1ST_KEY, PAIR_2ND_KEY +from oml.interfaces.datasets import IBaseDataset, IPairDataset -# todo 522: make one modality agnostic instead of these two - -class EmbeddingPairsDataset(IPairsDataset): +class PairDataset(IPairDataset): """ - Dataset to iterate over pairs of embeddings. + Dataset to iterate over pairs of items. """ def __init__( self, - embeddings1: Tensor, - embeddings2: Tensor, + base_dataset: IBaseDataset, + pair_ids: List[Tuple[int, int]], pair_1st_key: str = PAIR_1ST_KEY, pair_2nd_key: str = PAIR_2ND_KEY, index_key: str = INDEX_KEY, ): - """ - - Args: - embeddings1: The first input embeddings - embeddings2: The second input embeddings - pair_1st_key: Key to put ``embeddings1`` into the batches - pair_2nd_key: Key to put ``embeddings2`` into the batches - index_key: Key to put samples' ids into the batches - - """ - assert embeddings1.shape == embeddings2.shape - assert embeddings1.ndim >= 2 + self.base_dataset = base_dataset + self.pair_ids = pair_ids self.pair_1st_key = pair_1st_key self.pair_2nd_key = pair_2nd_key - self.index_key = index_key - - self.embeddings1 = embeddings1 - self.embeddings2 = embeddings2 + self.index_key: str = index_key def __getitem__(self, idx: int) -> Dict[str, Tensor]: - return {self.pair_1st_key: self.embeddings1[idx], self.pair_2nd_key: self.embeddings2[idx], self.index_key: idx} - - def __len__(self) -> int: - return len(self.embeddings1) - - -class ImagePairsDataset(IPairsDataset): - """ - Dataset to iterate over pairs of images. - - """ - - def __init__( - self, - paths1: List[Path], - paths2: List[Path], - bboxes1: Optional[TBBoxes] = None, - bboxes2: Optional[TBBoxes] = None, - transform: Optional[TTransforms] = None, - f_imread: TImReader = imread_pillow, - pair_1st_key: str = PAIR_1ST_KEY, - pair_2nd_key: str = PAIR_2ND_KEY, - index_key: str = INDEX_KEY, - cache_size: Optional[int] = 0, - ): - """ - Args: - paths1: Paths to the 1st input images - paths2: Paths to the 2nd input images - bboxes1: Should be either ``None`` or a sequence of bboxes. - If an image has ``N`` boxes, duplicate its - path ``N`` times and provide bounding box for each of them. - If you want to get an embedding for the whole image, set bbox to ``None`` for - this particular image path. The format is ``x1, y1, x2, y2``. - bboxes2: The same as ``bboxes2``, but for the second inputs. - transform: Augmentations for the images, set ``None`` to perform only normalisation and casting to tensor - f_imread: Function to read the images - pair_1st_key: Key to put the 1st images into the batches - pair_2nd_key: Key to put the 2nd images into the batches - index_key: Key to put samples' ids into the batches - cache_size: Size of the dataset's cache - - """ - assert len(paths1) == len(paths2) - - if transform is None: - transform = get_normalisation_torch() - - cache_size = cache_size // 2 if cache_size else None - dataset_args = {"transform": transform, "f_imread": f_imread, "cache_size": cache_size} - self.dataset1 = ImageBaseDataset(paths=paths1, bboxes=bboxes1, **dataset_args) - self.dataset2 = ImageBaseDataset(paths=paths2, bboxes=bboxes2, **dataset_args) - - self.pair_1st_key = pair_1st_key - self.pair_2nd_key = pair_2nd_key - self.index_key = index_key - - def __getitem__(self, idx: int) -> Dict[str, Union[int, Dict[str, Any]]]: - return {self.pair_1st_key: self.dataset1[idx], self.pair_2nd_key: self.dataset2[idx], self.index_key: idx} + i1, i2 = self.pair_ids[idx] + key = self.base_dataset.input_tensors_key + return { + self.pair_1st_key: self.base_dataset[i1][key], + self.pair_2nd_key: self.base_dataset[i2][key], + self.index_key: idx, + } def __len__(self) -> int: - return len(self.dataset1) + return len(self.pair_ids) -__all__ = ["EmbeddingPairsDataset", "ImagePairsDataset"] +__all__ = ["PairDataset"] diff --git a/oml/inference/flat.py b/oml/inference/flat.py deleted file mode 100644 index e8e4c063f..000000000 --- a/oml/inference/flat.py +++ /dev/null @@ -1,100 +0,0 @@ -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union - -import pandas as pd -import torch -from pandas import DataFrame -from torch import Tensor, nn - -from oml.const import PATHS_COLUMN, SPLIT_COLUMN -from oml.datasets.images import ImageBaseDataset -from oml.inference.abstract import _inference -from oml.interfaces.models import IExtractor -from oml.transforms.images.utils import TTransforms -from oml.utils.dataframe_format import check_retrieval_dataframe_format -from oml.utils.images.images import TImReader -from oml.utils.misc_torch import get_device - - -@torch.no_grad() -def inference_on_images( - model: nn.Module, - paths: List[Path], - transform: TTransforms, - num_workers: int, - batch_size: int, - verbose: bool = False, - f_imread: Optional[TImReader] = None, - use_fp16: bool = False, - accumulate_on_cpu: bool = True, -) -> Tensor: - dataset = ImageBaseDataset(paths=paths, bboxes=None, transform=transform, f_imread=f_imread, cache_size=0) - device = get_device(model) - - def _apply(model_: nn.Module, batch_: Dict[str, Any]) -> Tensor: - return model_(batch_[dataset.input_tensors_key].to(device)) - - outputs = _inference( - model=model, - apply_model=_apply, - dataset=dataset, - num_workers=num_workers, - batch_size=batch_size, - verbose=verbose, - use_fp16=use_fp16, - accumulate_on_cpu=accumulate_on_cpu, - ) - - return outputs - - -def inference_on_dataframe( - dataset_root: Union[Path, str], - dataframe_name: str, - extractor: IExtractor, - transforms: TTransforms, - output_cache_path: Optional[Union[str, Path]] = None, - num_workers: int = 0, - batch_size: int = 128, - use_fp16: bool = False, -) -> Tuple[Tensor, Tensor, DataFrame, DataFrame]: - df = pd.read_csv(Path(dataset_root) / dataframe_name) - - # it has now affect if paths are global already - df[PATHS_COLUMN] = df[PATHS_COLUMN].apply(lambda x: Path(dataset_root) / x) - - check_retrieval_dataframe_format(df) - - if (output_cache_path is not None) and Path(output_cache_path).is_file(): - embeddings = torch.load(output_cache_path, map_location="cpu") - print("Embeddings have been loaded from the disk.") - else: - embeddings = inference_on_images( - model=extractor, - paths=df[PATHS_COLUMN], - transform=transforms, - num_workers=num_workers, - batch_size=batch_size, - verbose=True, - use_fp16=use_fp16, - accumulate_on_cpu=True, - ) - if output_cache_path is not None: - torch.save(embeddings, output_cache_path) - print("Embeddings have been saved to the disk.") - - train_mask = df[SPLIT_COLUMN] == "train" - - emb_train = embeddings[train_mask] - emb_val = embeddings[~train_mask] - - df_train = df[train_mask] - df_train.reset_index(inplace=True, drop=True) - - df_val = df[~train_mask] - df_val.reset_index(inplace=True, drop=True) - - return emb_train, emb_val, df_train, df_val - - -__all__ = ["inference_on_images", "inference_on_dataframe"] diff --git a/oml/inference/pairs.py b/oml/inference/pairs.py deleted file mode 100644 index 42317c97c..000000000 --- a/oml/inference/pairs.py +++ /dev/null @@ -1,97 +0,0 @@ -from pathlib import Path -from typing import Any, Dict, List, Optional - -from torch import Tensor - -from oml.datasets.pairs import EmbeddingPairsDataset, ImagePairsDataset -from oml.inference.abstract import _inference -from oml.interfaces.models import IPairwiseModel -from oml.transforms.images.utils import TTransforms -from oml.utils.images.images import TImReader -from oml.utils.misc_torch import get_device - - -def pairwise_inference_on_images( - model: IPairwiseModel, - paths1: List[Path], - paths2: List[Path], - transform: TTransforms, - num_workers: int, - batch_size: int, - verbose: bool = True, - f_imread: Optional[TImReader] = None, - use_fp16: bool = False, - accumulate_on_cpu: bool = True, -) -> Tensor: - device = get_device(model) - - dataset = ImagePairsDataset( - paths1=paths1, - paths2=paths2, - transform=transform, - f_imread=f_imread, - cache_size=0, - ) - - def _apply( - model_: IPairwiseModel, - batch_: Dict[str, Any], - ) -> Tensor: - pair1 = batch_[dataset.pair_1st_key][dataset.dataset1.input_tensors_key].to(device) - pair2 = batch_[dataset.pair_2nd_key][dataset.dataset2.input_tensors_key].to(device) - return model_.predict(pair1, pair2) - - output = _inference( - model=model, - apply_model=_apply, - dataset=dataset, - num_workers=num_workers, - batch_size=batch_size, - verbose=verbose, - use_fp16=use_fp16, - accumulate_on_cpu=accumulate_on_cpu, - ) - - return output - - -def pairwise_inference_on_embeddings( - model: IPairwiseModel, - embeddings1: Tensor, - embeddings2: Tensor, - num_workers: int, - batch_size: int, - verbose: bool = False, - use_fp16: bool = False, - accumulate_on_cpu: bool = True, -) -> Tensor: - device = get_device(model) - - dataset = EmbeddingPairsDataset(embeddings1=embeddings1, embeddings2=embeddings2) - - def _apply( - model_: IPairwiseModel, - batch_: Dict[str, Any], - ) -> Tensor: - pair1 = batch_[dataset.pair_1st_key].to(device) - pair2 = batch_[dataset.pair_2nd_key].to(device) - return model_.predict(pair1, pair2) - - output = _inference( - model=model, - apply_model=_apply, - dataset=dataset, - num_workers=num_workers, - batch_size=batch_size, - verbose=verbose, - use_fp16=use_fp16, - accumulate_on_cpu=accumulate_on_cpu, - ) - - return output - - -__all__ = [ - "pairwise_inference_on_images", - "pairwise_inference_on_embeddings", -] diff --git a/oml/interfaces/datasets.py b/oml/interfaces/datasets.py index 4726b7b36..bd3c9c93b 100644 --- a/oml/interfaces/datasets.py +++ b/oml/interfaces/datasets.py @@ -78,7 +78,7 @@ class IQueryGalleryLabeledDataset(IQueryGalleryDataset, ILabeledDataset, ABC): """ -class IPairsDataset(Dataset, ABC): +class IPairDataset(Dataset, ABC): """ This is an interface for the datasets which return pair of something. @@ -120,6 +120,6 @@ def visualize(self, idx: int, color: TColor) -> np.ndarray: "ILabeledDataset", "IQueryGalleryLabeledDataset", "IQueryGalleryDataset", - "IPairsDataset", + "IPairDataset", "IVisualizableDataset", ] diff --git a/oml/interfaces/retrieval.py b/oml/interfaces/retrieval.py index c593e8c14..ebf2ba552 100644 --- a/oml/interfaces/retrieval.py +++ b/oml/interfaces/retrieval.py @@ -1,56 +1,15 @@ -from typing import Any, Dict, List +from typing import Any -from torch import Tensor - -class IDistancesPostprocessor: +class IRetrievalPostprocessor: """ - This is a parent class for the classes which apply some postprocessing - after query-to-gallery distance matrix has been calculated. - For example, we may want to apply one of re-ranking techniques. + This is a base interface for the classes which somehow postprocess retrieval results. """ - def process(self, distances: Tensor, queries: Any, galleries: Any) -> Tensor: - """ - This method takes all the needed variables and returns - the modified matrix of distances, where some distances are - replaced with new ones. - - Args: - distances: Matrix with the shape of ``[Q, G]`` - queries: Queries in the amount of ``Q`` - galleries: Galleries in the amount of ``G`` - - Returns: - An updated distances matrix with the shape of ``[Q, G]`` - - """ - raise NotImplementedError() - - def process_by_dict(self, distances: Tensor, data: Dict[str, Any]) -> Tensor: - """ - This method is the analogue of ``process``, but data is passed as a dictionary, - so we need to use the corresponding keys, which also have to be obtainable by - ``needed_keys`` property. - - Args: - distances: Matrix with the shape of ``[Q, G]`` - data: Dictionary of data - - Returns: - An updated distances matrix with the shape of ``[Q, G]`` - - """ - raise NotImplementedError() - - @property - def needed_keys(self) -> List[str]: - """ - Returns: Keys that will be used to process data using ``process_by_dict`` - - """ + def process(self, *args, **kwargs) -> Any: + # todo 522: add actual signature later raise NotImplementedError() -__all__ = ["IDistancesPostprocessor"] +__all__ = ["IRetrievalPostprocessor"] diff --git a/oml/lightning/pipelines/train_postprocessor.py b/oml/lightning/pipelines/train_postprocessor.py index cc302941f..b5128eb0b 100644 --- a/oml/lightning/pipelines/train_postprocessor.py +++ b/oml/lightning/pipelines/train_postprocessor.py @@ -33,7 +33,7 @@ from oml.registry.optimizers import get_optimizer_by_cfg from oml.registry.postprocessors import get_postprocessor_by_cfg from oml.registry.transforms import get_transforms_by_cfg -from oml.retrieval.postprocessors.pairwise import PairwiseImagesPostprocessor +from oml.retrieval.postprocessors.pairwise import PairwiseReranker from oml.transforms.images.torchvision import get_normalisation_resize_torch from oml.utils.misc import dictconfig_to_dict, flatten_dict, set_global_seed @@ -128,7 +128,7 @@ def postprocessor_training_pipeline(cfg: DictConfig) -> None: loader_train, loader_val = get_loaders_with_embeddings(cfg) postprocessor = None if not cfg.get("postprocessor", None) else get_postprocessor_by_cfg(cfg["postprocessor"]) - assert isinstance(postprocessor, PairwiseImagesPostprocessor), "We support only images processing in this pipeline." + assert isinstance(postprocessor, PairwiseReranker), "We support only images processing in this pipeline." assert isinstance(postprocessor.model, IPairwiseModel), f"You model must be a child of {IPairwiseModel.__name__}" criterion = torch.nn.BCEWithLogitsLoss() diff --git a/oml/registry/postprocessors.py b/oml/registry/postprocessors.py index 30e8ee57d..76436427c 100644 --- a/oml/registry/postprocessors.py +++ b/oml/registry/postprocessors.py @@ -1,25 +1,18 @@ from typing import Any, Dict from oml.const import TCfg -from oml.interfaces.retrieval import IDistancesPostprocessor +from oml.interfaces.retrieval import IRetrievalPostprocessor from oml.registry.models import get_pairwise_model_by_cfg -from oml.registry.transforms import get_transforms_by_cfg -from oml.retrieval.postprocessors.pairwise import ( - PairwiseEmbeddingsPostprocessor, - PairwiseImagesPostprocessor, -) +from oml.retrieval.postprocessors.pairwise import PairwiseReranker from oml.utils.misc import dictconfig_to_dict POSTPROCESSORS_REGISTRY = { - "pairwise_images": PairwiseImagesPostprocessor, - "pairwise_embeddings": PairwiseEmbeddingsPostprocessor, + "pairwise_reranker": PairwiseReranker, } -def get_postprocessor(name: str, **kwargs: Dict[str, Any]) -> IDistancesPostprocessor: +def get_postprocessor(name: str, **kwargs: Dict[str, Any]) -> IRetrievalPostprocessor: constructor = POSTPROCESSORS_REGISTRY[name] - if "transforms" in kwargs: - kwargs["transforms"] = get_transforms_by_cfg(kwargs["transforms"]) if "pairwise_model" in kwargs: kwargs["pairwise_model"] = get_pairwise_model_by_cfg(kwargs["pairwise_model"]) @@ -27,7 +20,7 @@ def get_postprocessor(name: str, **kwargs: Dict[str, Any]) -> IDistancesPostproc return constructor(**kwargs) -def get_postprocessor_by_cfg(cfg: TCfg) -> IDistancesPostprocessor: +def get_postprocessor_by_cfg(cfg: TCfg) -> IRetrievalPostprocessor: cfg = dictconfig_to_dict(cfg) postprocessor = get_postprocessor(cfg["name"], **cfg["args"]) return postprocessor diff --git a/oml/retrieval/postprocessors/pairwise.py b/oml/retrieval/postprocessors/pairwise.py index 39ee69447..05a7b87fe 100644 --- a/oml/retrieval/postprocessors/pairwise.py +++ b/oml/retrieval/postprocessors/pairwise.py @@ -1,100 +1,14 @@ -import itertools -from abc import ABC -from pathlib import Path -from typing import Any, Dict, List - -import numpy as np import torch -from torch import Tensor -from oml.const import EMBEDDINGS_KEY, IS_GALLERY_KEY, IS_QUERY_KEY, PATHS_KEY -from oml.inference.pairs import ( - pairwise_inference_on_embeddings, - pairwise_inference_on_images, -) +from oml.inference.abstract import pairwise_inference +from oml.interfaces.datasets import IDatasetQueryGallery from oml.interfaces.models import IPairwiseModel -from oml.interfaces.retrieval import IDistancesPostprocessor -from oml.transforms.images.utils import TTransforms -from oml.utils.misc_torch import assign_2d - - -class PairwisePostprocessor(IDistancesPostprocessor, ABC): - """ - This postprocessor allows us to re-estimate the distances between queries and ``top-n`` galleries - closest to them. It creates pairs of queries and galleries and feeds them to a pairwise model. - - """ - - top_n: int - verbose: bool = False - - def process(self, distances: Tensor, queries: Any, galleries: Any) -> Tensor: - """ - Args: - distances: Matrix with the shape of ``[Q, G]`` - queries: Queries in the amount of ``Q`` - galleries: Galleries in the amount of ``G`` - - Returns: - Distance matrix with the shape of ``[Q, G]``, - where ``top_n`` minimal values in each row have been updated by the pairwise model, - other distances are shifted by a margin to keep the relative order. - - """ - n_queries = len(queries) - n_galleries = len(galleries) - - assert list(distances.shape) == [n_queries, n_galleries] - - # 1. Adjust top_n with respect to the actual gallery size and find top-n pairs - top_n = min(self.top_n, n_galleries) - ii_top = torch.topk(distances, k=top_n, largest=False)[1] - - # 2. Create (n_queries * top_n) pairs of each query and related galleries and re-estimate distances for them - if self.verbose: - print("\nPostprocessor's inference has been started...") - distances_upd = self.inference(queries=queries, galleries=galleries, ii_top=ii_top, top_n=top_n) - distances_upd = distances_upd.to(distances.device).to(distances.dtype) - - # 3. Update distances for top-n galleries - # The idea is that we somehow permute top-n galleries, but rest of the galleries - # we keep in the end of the list as before permutation. - # To do so, we add an offset to these galleries' distances (which haven't participated in the permutation) - if top_n < n_galleries: - # Here we use the fact that distances not participating in permutation start with top_n + 1 position - min_in_old_distances = torch.topk(distances, k=top_n + 1, largest=False)[0][:, -1] - max_in_new_distances = distances_upd.max(dim=1)[0] - offset = max_in_new_distances - min_in_old_distances + 1e-5 # we also need some eps if max == min - distances += offset.unsqueeze(-1) - else: - # Pairwise postprocessor has been applied to all possible pairs, so, there are no rest distances. - # Thus, we don't need to care about order and offset at all. - pass - - distances = assign_2d(x=distances, indices=ii_top, new_values=distances_upd) - - assert list(distances.shape) == [n_queries, n_galleries] +from oml.interfaces.retrieval import IRetrievalPostprocessor +from oml.retrieval.prediction import RetrievalPrediction +from oml.utils.misc_torch import cat_two_sorted_tensors_and_keep_it_sorted, take_2d - return distances - def inference(self, queries: Any, galleries: Any, ii_top: Tensor, top_n: int) -> Tensor: - """ - Depends on the exact types of queries/galleries this method may be implemented differently. - - Args: - queries: Queries in the amount of ``Q`` - galleries: Galleries in the amount of ``G`` - ii_top: Indices of the closest galleries with the shape of ``[Q, top_n]`` - top_n: Number of the closest galleries to re-rank - - Returns: - An updated distance matrix with the shape of ``[Q, G]`` - - """ - raise NotImplementedError() - - -class PairwiseEmbeddingsPostprocessor(PairwisePostprocessor): +class PairwiseReranker(IRetrievalPostprocessor): def __init__( self, top_n: int, @@ -103,155 +17,84 @@ def __init__( batch_size: int, verbose: bool = False, use_fp16: bool = False, - is_query_key: str = IS_QUERY_KEY, - is_gallery_key: str = IS_GALLERY_KEY, - embeddings_key: str = EMBEDDINGS_KEY, ): """ + Args: top_n: Model will be applied to the ``num_queries * top_n`` pairs formed by each query and ``top_n`` most relevant galleries. - pairwise_model: Model which is able to take two embeddings as inputs + pairwise_model: Model which is able to take two items as inputs and estimate the *distance* (not in a strictly mathematical sense) between them. num_workers: Number of workers in DataLoader batch_size: Batch size that will be used in DataLoader verbose: Set ``True`` if you want to see progress bar for an inference use_fp16: Set ``True`` if you want to use half precision - is_query_key: Key to access a binary mask indicates queries in case of using ``process_by_dict`` - is_gallery_key: Key to access a binary mask indicates galleries in case of using ``process_by_dict`` - embeddings_key: Key to access embeddings in case of using ``process_by_dict`` """ - assert top_n > 1, "Number of galleries for each query to process has to be greater than 1." + assert top_n > 1, "The number of the retrieved results for each query to process has to be greater than 1." self.top_n = top_n self.model = pairwise_model + self.num_workers = num_workers self.batch_size = batch_size self.verbose = verbose self.use_fp16 = use_fp16 - self.is_query_key = is_query_key - self.is_gallery_key = is_gallery_key - self.embeddings_key = embeddings_key - - def inference(self, queries: Tensor, galleries: Tensor, ii_top: Tensor, top_n: int) -> Tensor: + def process(self, prediction: RetrievalPrediction, dataset: IDatasetQueryGallery) -> RetrievalPrediction: """ - Args: - queries: Queries representations with the shape of ``[Q, *]`` - galleries: Galleries representations with the shape of ``[G, *]`` - ii_top: Indices of the closest galleries with the shape of ``[Q, top_n]`` - top_n: Number of the closest galleries to re-rank - Returns: - Updated distance matrix with the shape of ``[Q, G]`` + Note, the new distances to the ``top_n`` items produced by the pairwise model may be adjusted + to remain distances sorted. Here is an example: + ``original_distances = [0.1, 0.2, 0.3, 0.5, 0.6], top_n = 3`` + Imagine, the postprocessor didn't change the order of the first 3 items (it's just a convenient example, + the logic remains the same), however the new values have a bigger scale: + ``distances_upd = [1, 2, 5, 0.5, 0.6]``. + Thus, we need to rescale the first three distances, so they don't go above ``0.5``. + The scaling factor is ``s = min(0.5, 0.6) / max(1, 2, 5) = 0.1``. Finally: + ``distances_upd_scaled = [0.1, 0.2, 0.5, 0.5, 0.6]``. + If concatenation of two distances is already sorted, we keep it untouched. """ - n_queries = len(queries) - queries = queries.repeat_interleave(top_n, dim=0) - galleries = galleries[ii_top.view(-1)] - distances_upd = pairwise_inference_on_embeddings( + top_n = min(self.top_n, prediction.top_n) + + retrieved_ids = prediction.retrieved_ids.clone() + distances = prediction.distances.clone() + + # let's list pairs of (query_i, gallery_j) we need to process + ids_q = dataset.get_query_ids().unsqueeze(-1).repeat_interleave(top_n) + ii_g = dataset.get_gallery_ids().unsqueeze(-1) + ids_g = ii_g[retrieved_ids[:, :top_n]].flatten() + assert len(ids_q) == len(ids_g) + pairs = list(zip(ids_q.tolist(), ids_g.tolist())) + + distances_top = pairwise_inference( model=self.model, - embeddings1=queries, - embeddings2=galleries, + base_dataset=dataset, + pair_ids=pairs, num_workers=self.num_workers, batch_size=self.batch_size, verbose=self.verbose, use_fp16=self.use_fp16, ) - distances_upd = distances_upd.view(n_queries, top_n) - return distances_upd - - def process_by_dict(self, distances: Tensor, data: Dict[str, Any]) -> Tensor: - queries = data[self.embeddings_key][data[self.is_query_key]] - galleries = data[self.embeddings_key][data[self.is_gallery_key]] - return self.process(distances=distances, queries=queries, galleries=galleries) - - @property - def needed_keys(self) -> List[str]: - return [self.is_query_key, self.is_gallery_key, self.embeddings_key] + distances_top = distances_top.view(distances.shape[0], top_n) + distances_upd, ii_rerank = distances_top.sort() + retrieved_ids_upd = take_2d(retrieved_ids, ii_rerank) -class PairwiseImagesPostprocessor(PairwisePostprocessor): - def __init__( - self, - top_n: int, - pairwise_model: IPairwiseModel, - transforms: TTransforms, - num_workers: int = 0, - batch_size: int = 128, - verbose: bool = True, - use_fp16: bool = False, - is_query_key: str = IS_QUERY_KEY, - is_gallery_key: str = IS_GALLERY_KEY, - paths_key: str = PATHS_KEY, - ): - """ - Args: - top_n: Model will be applied to the ``num_queries * top_n`` pairs formed by each query - and its ``top_n`` most relevant galleries. - pairwise_model: Model which is able to take two images as inputs - and estimate the *distance* (not in a strictly mathematical sense) between them. - transforms: Transforms that will be applied to an image - num_workers: Number of workers in DataLoader - batch_size: Batch size that will be used in DataLoader - verbose: Set ``True`` if you want to see progress bar for an inference - use_fp16: Set ``True`` if you want to use half precision - is_query_key: Key to access a binary mask indicates queries in case of using ``process_by_dict`` - is_gallery_key: Key to access a binary mask indicates galleries in case of using ``process_by_dict`` - paths_key: Key to access paths to images in case of using ``process_by_dict`` - - """ - assert top_n > 1, "Number of galleries for each query to process has to be greater than 1." - - self.top_n = top_n - self.model = pairwise_model - self.image_transforms = transforms - self.num_workers = num_workers - self.batch_size = batch_size - self.verbose = verbose - self.use_fp16 = use_fp16 - - self.is_query_key = is_query_key - self.is_gallery_key = is_gallery_key - self.paths_key = paths_key + # Stack with the unprocessed values outside the first top_n items + if top_n < distances.shape[1]: + distances_upd = cat_two_sorted_tensors_and_keep_it_sorted(distances_upd, distances[:, top_n:]) + retrieved_ids_upd = torch.concat([retrieved_ids_upd, retrieved_ids[:, top_n:]], dim=1).long() - def inference(self, queries: List[Path], galleries: List[Path], ii_top: Tensor, top_n: int) -> Tensor: - """ - Args: - queries: Paths to queries with the length of ``Q`` - galleries: Paths to galleries with the length of ``G`` - ii_top: Indices of the closest galleries with the shape of ``[Q, top_n]`` - top_n: Number of the closest galleries to re-rank + assert distances_upd.shape == distances.shape + assert retrieved_ids_upd.shape == retrieved_ids.shape - Returns: - Updated distance matrix with the shape of ``[Q, G]`` - - """ - n_queries = len(queries) - queries = list(itertools.chain.from_iterable(itertools.repeat(x, top_n) for x in queries)) - galleries = [galleries[i] for i in ii_top.view(-1)] - distances_upd = pairwise_inference_on_images( - model=self.model, - paths1=queries, - paths2=galleries, - transform=self.image_transforms, - num_workers=self.num_workers, - batch_size=self.batch_size, - verbose=self.verbose, - use_fp16=self.use_fp16, - ) - distances_upd = distances_upd.view(n_queries, top_n) - return distances_upd + prediction_upd = RetrievalPrediction(distances_upd, retrieved_ids=retrieved_ids_upd, gt_ids=prediction.gt_ids) + return prediction_upd - def process_by_dict(self, distances: Tensor, data: Dict[str, Any]) -> Tensor: - queries = np.array(data[self.paths_key])[data[self.is_query_key]] - galleries = np.array(data[self.paths_key])[data[self.is_gallery_key]] - return self.process(distances=distances, queries=queries, galleries=galleries) - @property - def needed_keys(self) -> List[str]: - return [self.is_query_key, self.is_gallery_key, self.paths_key] +__all__ = ["PairwiseReranker"] -__all__ = ["PairwisePostprocessor", "PairwiseEmbeddingsPostprocessor", "PairwiseImagesPostprocessor"] +__all__ = ["PairwiseImagesPostprocessor"] diff --git a/pipelines/postprocessing/pairwise_postprocessing/postprocessor_train.yaml b/pipelines/postprocessing/pairwise_postprocessing/postprocessor_train.yaml index e57305d61..220023636 100644 --- a/pipelines/postprocessing/pairwise_postprocessing/postprocessor_train.yaml +++ b/pipelines/postprocessing/pairwise_postprocessing/postprocessor_train.yaml @@ -84,15 +84,10 @@ transforms_train: batch_size_inference: 128 postprocessor: - name: pairwise_images + name: pairwise_reranker args: top_n: 5 pairwise_model: ${pairwise_model} - transforms: - name: norm_resize_hypvit_torch - args: - im_size: 224 - crop_size: 224 num_workers: ${num_workers} batch_size: ${batch_size_inference} verbose: True diff --git a/pipelines/postprocessing/pairwise_postprocessing/postprocessor_validate.yaml b/pipelines/postprocessing/pairwise_postprocessing/postprocessor_validate.yaml index 0ed09cbe7..749a4d0a8 100644 --- a/pipelines/postprocessing/pairwise_postprocessing/postprocessor_validate.yaml +++ b/pipelines/postprocessing/pairwise_postprocessing/postprocessor_validate.yaml @@ -26,7 +26,7 @@ extractor: weights: ${extractor_weights} postprocessor: - name: pairwise_images + name: pairwise_reranker args: top_n: 5 pairwise_model: @@ -42,11 +42,6 @@ postprocessor: normalise_features: False use_multi_scale: False weights: null - transforms: - name: norm_resize_hypvit_torch - args: - im_size: 224 - crop_size: 224 num_workers: ${num_workers} batch_size: ${bs_val} verbose: True diff --git a/pipelines/postprocessing/pairwise_postprocessing/validate_postprocessor.py b/pipelines/postprocessing/pairwise_postprocessing/validate_postprocessor.py index 9f075ab9e..1af275f68 100644 --- a/pipelines/postprocessing/pairwise_postprocessing/validate_postprocessor.py +++ b/pipelines/postprocessing/pairwise_postprocessing/validate_postprocessor.py @@ -1,14 +1,162 @@ -import hydra -from omegaconf import DictConfig +from pathlib import Path +from typing import Any, Callable, Dict, List, Tuple -from oml.const import HYDRA_BEHAVIOUR -from oml.lightning.pipelines.validate import extractor_validation_pipeline +import torch +from torch import FloatTensor, Tensor, nn +from torch.utils.data import DataLoader, Dataset +from tqdm.auto import tqdm +from oml.datasets import PairDataset +from oml.ddp.patching import patch_dataloader_to_ddp +from oml.ddp.utils import get_world_size_safe, is_ddp, sync_dicts_ddp +from oml.interfaces.datasets import IBaseDataset +from oml.interfaces.models import IPairwiseModel +from oml.utils.misc_torch import ( + drop_duplicates_by_ids, + get_device, + temporary_setting_model_mode, +) -@hydra.main(config_path=".", config_name="postprocessor_validate.yaml", version_base=HYDRA_BEHAVIOUR) -def main_hydra(cfg: DictConfig) -> None: - extractor_validation_pipeline(cfg) +@torch.no_grad() +def _inference( + model: nn.Module, + apply_model: Callable[[nn.Module, Dict[str, Any]], FloatTensor], + dataset: Dataset, # type: ignore + num_workers: int, + batch_size: int, + verbose: bool, + use_fp16: bool, + accumulate_on_cpu: bool = True, +) -> FloatTensor: + # todo: rework hasattr later + assert hasattr(dataset, "index_key"), "We expect that your dataset returns samples ids in __getitem__ method" -if __name__ == "__main__": - main_hydra() + loader = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False) + + if is_ddp(): + loader = patch_dataloader_to_ddp(loader) + + if verbose: + loader = tqdm(loader, desc=str(get_device(model))) + + outputs_list = [] + ids = [] + + with torch.autocast(device_type="cuda", dtype=torch.float16 if use_fp16 else torch.float32): + with temporary_setting_model_mode(model, set_train=False): + for batch in loader: + out = apply_model(model, batch) + if accumulate_on_cpu: + out = out.cpu() + outputs_list.append(out) + ids.extend(batch[dataset.index_key].long().tolist()) + + outputs = torch.cat(outputs_list).detach() + + data_to_sync = {"outputs": outputs, "ids": ids} + data_synced = sync_dicts_ddp(data_to_sync, world_size=get_world_size_safe()) + outputs, ids = data_synced["outputs"], data_synced["ids"] + + ids, outputs = drop_duplicates_by_ids(ids=ids, data=outputs, sort=True) + + assert len(outputs) == len(dataset), "Data was not collected correctly after DDP sync." + assert list(range(len(dataset))) == ids, "Data was not collected correctly after DDP sync." + + return outputs.float() + + +@torch.no_grad() +def inference( + model: nn.Module, + dataset: IBaseDataset, + batch_size: int, + num_workers: int = 0, + verbose: bool = False, + use_fp16: bool = False, + accumulate_on_cpu: bool = True, +) -> FloatTensor: + device = get_device(model) + + # Inference on IBaseDataset + + def apply(model_: nn.Module, batch_: Dict[str, Any]) -> FloatTensor: + return model_(batch_[dataset.input_tensors_key].to(device)) + + return _inference( + model=model, + apply_model=apply, + dataset=dataset, + num_workers=num_workers, + batch_size=batch_size, + verbose=verbose, + use_fp16=use_fp16, + accumulate_on_cpu=accumulate_on_cpu, + ) + + +def pairwise_inference( + model: IPairwiseModel, + base_dataset: IBaseDataset, + pair_ids: List[Tuple[int, int]], + num_workers: int, + batch_size: int, + verbose: bool = True, + use_fp16: bool = False, + accumulate_on_cpu: bool = True, +) -> FloatTensor: + device = get_device(model) + + dataset = PairDataset(base_dataset=base_dataset, pair_ids=pair_ids) + + def _apply( + model_: IPairwiseModel, + batch_: Dict[str, Any], + ) -> Tensor: + pair1 = batch_[dataset.pair_1st_key].to(device) + pair2 = batch_[dataset.pair_2nd_key].to(device) + return model_.predict(pair1, pair2) + + output = _inference( + model=model, + apply_model=_apply, + dataset=dataset, + num_workers=num_workers, + batch_size=batch_size, + verbose=verbose, + use_fp16=use_fp16, + accumulate_on_cpu=accumulate_on_cpu, + ) + + return output + + +def inference_cached( + dataset: IBaseDataset, + extractor: nn.Module, + output_cache_path: str = "inference_cache.pth", + num_workers: int = 0, + batch_size: int = 128, + use_fp16: bool = False, +) -> FloatTensor: + if Path(output_cache_path).is_file(): + outputs = torch.load(output_cache_path, map_location="cpu") + print(f"Model outputs have been loaded from {output_cache_path}.") + else: + outputs = inference( + model=extractor, + dataset=dataset, + num_workers=num_workers, + batch_size=batch_size, + use_fp16=use_fp16, + verbose=True, + accumulate_on_cpu=True, + ) + + torch.save(outputs, output_cache_path) + print(f"Model outputs have been saved to {output_cache_path}.") + + return outputs + + +__all__ = ["inference", "pairwise_inference", "inference_cached"] diff --git a/pipelines/postprocessing/pairwise_postprocessing/visualisation.ipynb b/pipelines/postprocessing/pairwise_postprocessing/visualisation.ipynb index d0fbe1d27..92ae65cc8 100644 --- a/pipelines/postprocessing/pairwise_postprocessing/visualisation.ipynb +++ b/pipelines/postprocessing/pairwise_postprocessing/visualisation.ipynb @@ -85,7 +85,7 @@ "source": [ "cfg_p = cfg + f\"\"\"\n", " postprocessor:\n", - " name: pairwise_images\n", + " name: pairwise_reranker\n", " args:\n", " top_n: 5\n", " pairwise_model:\n", @@ -100,11 +100,6 @@ " normalise_features: False\n", " use_multi_scale: False\n", " weights: null\n", - " transforms:\n", - " name: norm_resize_hypvit_torch\n", - " args:\n", - " im_size: 224\n", - " crop_size: 224\n", " num_workers: 10\n", " batch_size: 128\n", " verbose: True\n", diff --git a/tests/test_oml/test_metrics/test_embedding_metrics.py b/tests/test_oml/test_metrics/test_embedding_metrics.py index 89a2c3bc4..9ccca3993 100644 --- a/tests/test_oml/test_metrics/test_embedding_metrics.py +++ b/tests/test_oml/test_metrics/test_embedding_metrics.py @@ -18,16 +18,16 @@ ) from oml.metrics.embeddings import EmbeddingMetrics from oml.models.meta.siamese import LinearTrivialDistanceSiamese -from oml.retrieval.postprocessors.pairwise import PairwiseEmbeddingsPostprocessor +from oml.retrieval.postprocessors.pairwise import PairwiseReranker from oml.utils.misc import compare_dicts_recursively, one_hot FEAT_DIM = 8 oh = partial(one_hot, dim=FEAT_DIM) -def get_trivial_postprocessor(top_n: int) -> PairwiseEmbeddingsPostprocessor: +def get_trivial_postprocessor(top_n: int) -> PairwiseReranker: model = LinearTrivialDistanceSiamese(feat_dim=FEAT_DIM, identity_init=True) - processor = PairwiseEmbeddingsPostprocessor(pairwise_model=model, top_n=top_n, num_workers=0, batch_size=64) + processor = PairwiseReranker(pairwise_model=model, top_n=top_n, num_workers=0, batch_size=64) return processor diff --git a/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py b/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py index 3940c35ae..cd8e9715f 100644 --- a/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py +++ b/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py @@ -10,7 +10,7 @@ from oml.functional.metrics import calc_distance_matrix, calc_retrieval_metrics from oml.interfaces.models import IPairwiseModel from oml.models.meta.siamese import LinearTrivialDistanceSiamese -from oml.retrieval.postprocessors.pairwise import PairwiseEmbeddingsPostprocessor +from oml.retrieval.postprocessors.pairwise import PairwiseReranker from oml.utils.misc import flatten_dict, one_hot from oml.utils.misc_torch import normalise, pairwise_dist @@ -62,7 +62,7 @@ def test_trivial_processing_does_not_change_distances_order( distances = calc_distance_matrix(embeddings, is_query, is_gallery) model = LinearTrivialDistanceSiamese(feat_dim=embeddings.shape[-1], identity_init=True) - processor = PairwiseEmbeddingsPostprocessor(pairwise_model=model, top_n=top_n, num_workers=0, batch_size=64) + processor = PairwiseReranker(pairwise_model=model, top_n=top_n, num_workers=0, batch_size=64) distances_processed = processor.process( queries=embeddings_query, @@ -128,7 +128,7 @@ def test_trivial_processing_fixes_broken_perfect_case() -> None: # Metrics after broken distances have been fixed model = LinearTrivialDistanceSiamese(feat_dim=gallery_embeddings.shape[-1], identity_init=True) - processor = PairwiseEmbeddingsPostprocessor(pairwise_model=model, top_n=top_n, batch_size=16, num_workers=0) + processor = PairwiseReranker(pairwise_model=model, top_n=top_n, batch_size=16, num_workers=0) distances_upd = processor.process(distances, query_embeddings, gallery_embeddings) metrics_upd = flatten_dict(calc_retrieval_metrics(distances=distances_upd, **args)) @@ -169,7 +169,7 @@ def test_trivial_processing_fixes_broken_perfect_case_2() -> None: # Now let's fix the error with dummy pairwise model model = DummyPairwise(distances_to_return=torch.tensor([3.5, 2.5])) - processor = PairwiseEmbeddingsPostprocessor(pairwise_model=model, top_n=2, batch_size=128, num_workers=0) + processor = PairwiseReranker(pairwise_model=model, top_n=2, batch_size=128, num_workers=0) distances_upd = processor.process( distances=distances, queries=torch.randn((1, FEAT_SIZE)), galleries=torch.randn((5, FEAT_SIZE)) ) @@ -215,7 +215,7 @@ def test_processing_not_changing_non_sensitive_metrics(top_n: int) -> None: metrics_before = calc_retrieval_metrics(distances=distances, **args) model = RandomPairwise() - processor = PairwiseEmbeddingsPostprocessor(pairwise_model=model, top_n=top_n, batch_size=4, num_workers=0) + processor = PairwiseReranker(pairwise_model=model, top_n=top_n, batch_size=4, num_workers=0) distances_upd = processor.process(distances=distances, queries=query_embeddings, galleries=gallery_embeddings) metrics_after = calc_retrieval_metrics(distances=distances_upd, **args) diff --git a/tests/test_oml/test_postprocessor/test_pairwise_images.py b/tests/test_oml/test_postprocessor/test_pairwise_images.py index cc1477929..5f5473038 100644 --- a/tests/test_oml/test_postprocessor/test_pairwise_images.py +++ b/tests/test_oml/test_postprocessor/test_pairwise_images.py @@ -9,7 +9,7 @@ from oml.inference.flat import inference_on_images from oml.models.meta.siamese import TrivialDistanceSiamese from oml.models.resnet.extractor import ResnetExtractor -from oml.retrieval.postprocessors.pairwise import PairwiseImagesPostprocessor +from oml.retrieval.postprocessors.pairwise import PairwiseReranker from oml.transforms.images.torchvision import get_normalisation_resize_torch from oml.transforms.images.utils import TTransforms from oml.utils.download_mock_dataset import download_mock_dataset @@ -51,7 +51,7 @@ def test_trivial_processing_does_not_change_distances_order(top_n: int) -> None: distances, queries, galleries = get_validation_results(model=extractor, transforms=transforms) - postprocessor = PairwiseImagesPostprocessor( + postprocessor = PairwiseReranker( top_n=top_n, pairwise_model=pairwise_model, transforms=transforms, diff --git a/tests/test_runs/test_pipelines/configs/train_postprocessor.yaml b/tests/test_runs/test_pipelines/configs/train_postprocessor.yaml index 22dbabd68..8e6f43110 100644 --- a/tests/test_runs/test_pipelines/configs/train_postprocessor.yaml +++ b/tests/test_runs/test_pipelines/configs/train_postprocessor.yaml @@ -78,15 +78,10 @@ transforms_train: batch_size_inference: 128 postprocessor: - name: pairwise_images + name: pairwise_reranker args: top_n: 5 pairwise_model: ${pairwise_model} - transforms: - name: norm_resize_hypvit_torch - args: - im_size: 32 - crop_size: 32 num_workers: 0 batch_size: ${batch_size_inference} verbose: True diff --git a/tests/test_runs/test_pipelines/configs/validate.yaml b/tests/test_runs/test_pipelines/configs/validate.yaml index 5096f57a6..1e9d53f11 100644 --- a/tests/test_runs/test_pipelines/configs/validate.yaml +++ b/tests/test_runs/test_pipelines/configs/validate.yaml @@ -16,7 +16,7 @@ num_workers: 0 bs_val: 2 postprocessor: - name: pairwise_images + name: pairwise_reranker args: top_n: 3 pairwise_model: @@ -30,10 +30,6 @@ postprocessor: remove_fc: True normalise_features: False weights: resnet50_moco_v2 - transforms: - name: norm_resize_torch - args: - im_size: 64 num_workers: 0 batch_size: 4 verbose: True From 05b814446d0344f65935998d08fee8b62ffcc864 Mon Sep 17 00:00:00 2001 From: alekseysh Date: Fri, 19 Apr 2024 16:33:16 +0700 Subject: [PATCH 09/23] upd --- oml/inference/abstract.py | 108 +++++++++++- oml/metrics/embeddings.py | 4 +- oml/retrieval/postprocessors/pairwise.py | 7 +- oml/utils/misc_torch.py | 24 +++ .../validate_postprocessor.py | 166 +----------------- tests/test_oml/test_utils/test_misc_torch.py | 37 +++- 6 files changed, 176 insertions(+), 170 deletions(-) diff --git a/oml/inference/abstract.py b/oml/inference/abstract.py index 096b12e4d..8157b5eb6 100644 --- a/oml/inference/abstract.py +++ b/oml/inference/abstract.py @@ -1,12 +1,16 @@ -from typing import Any, Callable, Dict +from pathlib import Path +from typing import Any, Callable, Dict, List, Tuple import torch -from torch import Tensor, nn +from torch import FloatTensor, Tensor, nn from torch.utils.data import DataLoader, Dataset from tqdm.auto import tqdm +from oml.datasets import PairsDataset from oml.ddp.patching import patch_dataloader_to_ddp from oml.ddp.utils import get_world_size_safe, is_ddp, sync_dicts_ddp +from oml.interfaces.datasets import IBaseDataset +from oml.interfaces.models import IPairwiseModel from oml.utils.misc_torch import ( drop_duplicates_by_ids, get_device, @@ -17,14 +21,15 @@ @torch.no_grad() def _inference( model: nn.Module, - apply_model: Callable[[nn.Module, Dict[str, Any]], Tensor], + apply_model: Callable[[nn.Module, Dict[str, Any]], FloatTensor], dataset: Dataset, # type: ignore num_workers: int, batch_size: int, verbose: bool, use_fp16: bool, accumulate_on_cpu: bool = True, -) -> Tensor: +) -> FloatTensor: + # todo: rework hasattr later assert hasattr(dataset, "index_key"), "We expect that your dataset returns samples ids in __getitem__ method" loader = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False) @@ -58,7 +63,100 @@ def _inference( assert len(outputs) == len(dataset), "Data was not collected correctly after DDP sync." assert list(range(len(dataset))) == ids, "Data was not collected correctly after DDP sync." + return outputs.float() + + +@torch.no_grad() +def inference( + model: nn.Module, + dataset: IBaseDataset, + batch_size: int, + num_workers: int = 0, + verbose: bool = False, + use_fp16: bool = False, + accumulate_on_cpu: bool = True, +) -> FloatTensor: + device = get_device(model) + + # Inference on IBaseDataset + + def apply(model_: nn.Module, batch_: Dict[str, Any]) -> FloatTensor: + return model_(batch_[dataset.input_tensors_key].to(device)) + + return _inference( + model=model, + apply_model=apply, + dataset=dataset, + num_workers=num_workers, + batch_size=batch_size, + verbose=verbose, + use_fp16=use_fp16, + accumulate_on_cpu=accumulate_on_cpu, + ) + + +def pairwise_inference( + model: IPairwiseModel, + base_dataset: IBaseDataset, + pair_ids: List[Tuple[int, int]], + num_workers: int, + batch_size: int, + verbose: bool = True, + use_fp16: bool = False, + accumulate_on_cpu: bool = True, +) -> FloatTensor: + device = get_device(model) + + dataset = PairsDataset(base_dataset=base_dataset, pair_ids=pair_ids) + + def _apply( + model_: IPairwiseModel, + batch_: Dict[str, Any], + ) -> Tensor: + pair1 = batch_[dataset.pair_1st_key].to(device) + pair2 = batch_[dataset.pair_2nd_key].to(device) + return model_.predict(pair1, pair2) + + output = _inference( + model=model, + apply_model=_apply, + dataset=dataset, + num_workers=num_workers, + batch_size=batch_size, + verbose=verbose, + use_fp16=use_fp16, + accumulate_on_cpu=accumulate_on_cpu, + ) + + return output + + +def inference_cached( + dataset: IBaseDataset, + extractor: nn.Module, + output_cache_path: str = "inference_cache.pth", + num_workers: int = 0, + batch_size: int = 128, + use_fp16: bool = False, +) -> FloatTensor: + if Path(output_cache_path).is_file(): + outputs = torch.load(output_cache_path, map_location="cpu") + print(f"Model outputs have been loaded from {output_cache_path}.") + else: + outputs = inference( + model=extractor, + dataset=dataset, + num_workers=num_workers, + batch_size=batch_size, + use_fp16=use_fp16, + verbose=True, + accumulate_on_cpu=True, + ) + + torch.save(outputs, output_cache_path) + print(f"Model outputs have been saved to {output_cache_path}.") + return outputs -__all__ = ["_inference"] +__all__ = ["inference", "pairwise_inference", "inference_cached"] diff --git a/oml/metrics/embeddings.py b/oml/metrics/embeddings.py index 53367a0a8..246be00d7 100644 --- a/oml/metrics/embeddings.py +++ b/oml/metrics/embeddings.py @@ -38,7 +38,7 @@ reduce_metrics, ) from oml.interfaces.metrics import IMetricDDP, IMetricVisualisable -from oml.interfaces.retrieval import IDistancesPostprocessor +from oml.interfaces.retrieval import IRetrievalPostprocessor from oml.metrics.accumulation import Accumulator from oml.utils.images.images import get_img_with_bbox, square_pad from oml.utils.misc import flatten_dict @@ -79,7 +79,7 @@ def __init__( pcf_variance: Tuple[float, ...] = (0.5,), categories_key: Optional[str] = None, sequence_key: Optional[str] = None, - postprocessor: Optional[IDistancesPostprocessor] = None, + postprocessor: Optional[IRetrievalPostprocessor] = None, metrics_to_exclude_from_visualization: Iterable[str] = (), return_only_overall_category: bool = False, visualize_only_overall_category: bool = True, diff --git a/oml/retrieval/postprocessors/pairwise.py b/oml/retrieval/postprocessors/pairwise.py index 05a7b87fe..e8f62295a 100644 --- a/oml/retrieval/postprocessors/pairwise.py +++ b/oml/retrieval/postprocessors/pairwise.py @@ -1,7 +1,7 @@ import torch from oml.inference.abstract import pairwise_inference -from oml.interfaces.datasets import IDatasetQueryGallery +from oml.interfaces.datasets import IQueryGalleryDataset from oml.interfaces.models import IPairwiseModel from oml.interfaces.retrieval import IRetrievalPostprocessor from oml.retrieval.prediction import RetrievalPrediction @@ -41,7 +41,7 @@ def __init__( self.verbose = verbose self.use_fp16 = use_fp16 - def process(self, prediction: RetrievalPrediction, dataset: IDatasetQueryGallery) -> RetrievalPrediction: + def process(self, prediction: RetrievalPrediction, dataset: IQueryGalleryDataset) -> RetrievalPrediction: """ Note, the new distances to the ``top_n`` items produced by the pairwise model may be adjusted @@ -95,6 +95,3 @@ def process(self, prediction: RetrievalPrediction, dataset: IDatasetQueryGallery __all__ = ["PairwiseReranker"] - - -__all__ = ["PairwiseImagesPostprocessor"] diff --git a/oml/utils/misc_torch.py b/oml/utils/misc_torch.py index 022a85b46..815ff59fb 100644 --- a/oml/utils/misc_torch.py +++ b/oml/utils/misc_torch.py @@ -59,6 +59,29 @@ def assign_2d(x: Tensor, indices: Tensor, new_values: Tensor) -> Tensor: return x +def cat_two_sorted_tensors_and_keep_it_sorted(x1: Tensor, x2: Tensor, eps: float = 1e-7) -> Tensor: + """ + Args: + x1: Sorted tensor with the shape of ``[N, M]`` + x2: Sorted tensor with the shape of ``[N, P]`` + eps: Eps to have a gap between the last x1 and the first x2 + + Returns: + Concatenation of two sorted tensors. + The first tensor may be rescaled if needed to keep the order sorted. + + """ + assert eps >= 0 + assert x1.shape[0] == x2.shape[0] + + scale = (x2[:, 0] / x1[:, -1]).view(-1, 1) + need_scaling = x1[:, -1] > x2[:, 0] + x1[need_scaling] = x1[need_scaling] * scale[need_scaling] - eps + + x = torch.concatenate([x1, x2], dim=1).float() + + return x + def elementwise_dist(x1: Tensor, x2: Tensor, p: int = 2) -> Tensor: """ Args: @@ -455,6 +478,7 @@ def _check_dimensions(self, n_components: int) -> None: __all__ = [ "elementwise_dist", + "cat_two_sorted_tensors_and_keep_it_sorted", "pairwise_dist", "OnlineCalc", "AvgOnline", diff --git a/pipelines/postprocessing/pairwise_postprocessing/validate_postprocessor.py b/pipelines/postprocessing/pairwise_postprocessing/validate_postprocessor.py index 1af275f68..9f075ab9e 100644 --- a/pipelines/postprocessing/pairwise_postprocessing/validate_postprocessor.py +++ b/pipelines/postprocessing/pairwise_postprocessing/validate_postprocessor.py @@ -1,162 +1,14 @@ -from pathlib import Path -from typing import Any, Callable, Dict, List, Tuple +import hydra +from omegaconf import DictConfig -import torch -from torch import FloatTensor, Tensor, nn -from torch.utils.data import DataLoader, Dataset -from tqdm.auto import tqdm +from oml.const import HYDRA_BEHAVIOUR +from oml.lightning.pipelines.validate import extractor_validation_pipeline -from oml.datasets import PairDataset -from oml.ddp.patching import patch_dataloader_to_ddp -from oml.ddp.utils import get_world_size_safe, is_ddp, sync_dicts_ddp -from oml.interfaces.datasets import IBaseDataset -from oml.interfaces.models import IPairwiseModel -from oml.utils.misc_torch import ( - drop_duplicates_by_ids, - get_device, - temporary_setting_model_mode, -) +@hydra.main(config_path=".", config_name="postprocessor_validate.yaml", version_base=HYDRA_BEHAVIOUR) +def main_hydra(cfg: DictConfig) -> None: + extractor_validation_pipeline(cfg) -@torch.no_grad() -def _inference( - model: nn.Module, - apply_model: Callable[[nn.Module, Dict[str, Any]], FloatTensor], - dataset: Dataset, # type: ignore - num_workers: int, - batch_size: int, - verbose: bool, - use_fp16: bool, - accumulate_on_cpu: bool = True, -) -> FloatTensor: - # todo: rework hasattr later - assert hasattr(dataset, "index_key"), "We expect that your dataset returns samples ids in __getitem__ method" - loader = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False) - - if is_ddp(): - loader = patch_dataloader_to_ddp(loader) - - if verbose: - loader = tqdm(loader, desc=str(get_device(model))) - - outputs_list = [] - ids = [] - - with torch.autocast(device_type="cuda", dtype=torch.float16 if use_fp16 else torch.float32): - with temporary_setting_model_mode(model, set_train=False): - for batch in loader: - out = apply_model(model, batch) - if accumulate_on_cpu: - out = out.cpu() - outputs_list.append(out) - ids.extend(batch[dataset.index_key].long().tolist()) - - outputs = torch.cat(outputs_list).detach() - - data_to_sync = {"outputs": outputs, "ids": ids} - data_synced = sync_dicts_ddp(data_to_sync, world_size=get_world_size_safe()) - outputs, ids = data_synced["outputs"], data_synced["ids"] - - ids, outputs = drop_duplicates_by_ids(ids=ids, data=outputs, sort=True) - - assert len(outputs) == len(dataset), "Data was not collected correctly after DDP sync." - assert list(range(len(dataset))) == ids, "Data was not collected correctly after DDP sync." - - return outputs.float() - - -@torch.no_grad() -def inference( - model: nn.Module, - dataset: IBaseDataset, - batch_size: int, - num_workers: int = 0, - verbose: bool = False, - use_fp16: bool = False, - accumulate_on_cpu: bool = True, -) -> FloatTensor: - device = get_device(model) - - # Inference on IBaseDataset - - def apply(model_: nn.Module, batch_: Dict[str, Any]) -> FloatTensor: - return model_(batch_[dataset.input_tensors_key].to(device)) - - return _inference( - model=model, - apply_model=apply, - dataset=dataset, - num_workers=num_workers, - batch_size=batch_size, - verbose=verbose, - use_fp16=use_fp16, - accumulate_on_cpu=accumulate_on_cpu, - ) - - -def pairwise_inference( - model: IPairwiseModel, - base_dataset: IBaseDataset, - pair_ids: List[Tuple[int, int]], - num_workers: int, - batch_size: int, - verbose: bool = True, - use_fp16: bool = False, - accumulate_on_cpu: bool = True, -) -> FloatTensor: - device = get_device(model) - - dataset = PairDataset(base_dataset=base_dataset, pair_ids=pair_ids) - - def _apply( - model_: IPairwiseModel, - batch_: Dict[str, Any], - ) -> Tensor: - pair1 = batch_[dataset.pair_1st_key].to(device) - pair2 = batch_[dataset.pair_2nd_key].to(device) - return model_.predict(pair1, pair2) - - output = _inference( - model=model, - apply_model=_apply, - dataset=dataset, - num_workers=num_workers, - batch_size=batch_size, - verbose=verbose, - use_fp16=use_fp16, - accumulate_on_cpu=accumulate_on_cpu, - ) - - return output - - -def inference_cached( - dataset: IBaseDataset, - extractor: nn.Module, - output_cache_path: str = "inference_cache.pth", - num_workers: int = 0, - batch_size: int = 128, - use_fp16: bool = False, -) -> FloatTensor: - if Path(output_cache_path).is_file(): - outputs = torch.load(output_cache_path, map_location="cpu") - print(f"Model outputs have been loaded from {output_cache_path}.") - else: - outputs = inference( - model=extractor, - dataset=dataset, - num_workers=num_workers, - batch_size=batch_size, - use_fp16=use_fp16, - verbose=True, - accumulate_on_cpu=True, - ) - - torch.save(outputs, output_cache_path) - print(f"Model outputs have been saved to {output_cache_path}.") - - return outputs - - -__all__ = ["inference", "pairwise_inference", "inference_cached"] +if __name__ == "__main__": + main_hydra() diff --git a/tests/test_oml/test_utils/test_misc_torch.py b/tests/test_oml/test_utils/test_misc_torch.py index efbf348ef..fb8f3f0dd 100644 --- a/tests/test_oml/test_utils/test_misc_torch.py +++ b/tests/test_oml/test_utils/test_misc_torch.py @@ -9,7 +9,7 @@ assign_2d, drop_duplicates_by_ids, elementwise_dist, - take_2d, + take_2d, cat_two_sorted_tensors_and_keep_it_sorted, ) @@ -23,6 +23,41 @@ def test_elementwise_dist() -> None: assert torch.isclose(val_torch, torch.tensor(val_custom)).all() +@pytest.mark.parametrize( + "x1,x2,e,expected", + [ + ( + # x1 + torch.tensor([[10, 20, 30], [40, 50, 60]]).float(), + # x2 + torch.tensor([[0.3, 0.4, 0.5], [0.6, 0.8, 0.9]]).float(), + # e + 0.001, + # expected: rescaling is needed + torch.tensor( + [ + [0.1 - 0.001, 0.2 - 0.001, 0.3 - 0.001, 0.3, 0.4, 0.5], + [0.4 - 0.001, 0.5 - 0.001, 0.6 - 0.001, 0.6, 0.8, 0.9], + ] + ).float(), + ), + ( + # x1 + torch.tensor([[-10, -5], [-20, -8]]).float(), + # x2 + torch.tensor([[0.3, 0.4, 0.5], [0.6, 0.8, 0.9]]).float(), + # e + 0.001, + # expected: rescaling is not needed, we jast concat + torch.tensor([[-10, -5, 0.3, 0.4, 0.5], [-20, -8, 0.6, 0.8, 0.9]]).float(), + ), + ], +) +def test_concat_two_sorted_tensors_with_rescaling(x1, x2, e, expected): # type: ignore + out = cat_two_sorted_tensors_and_keep_it_sorted(x1, x2, eps=e) + assert torch.isclose(expected, out).all() + + # fmt: off def test_take_2d() -> None: x = torch.tensor([ From 5ebd696cf1e75ee65a75dbbc5b628609f1ce21e0 Mon Sep 17 00:00:00 2001 From: alekseysh Date: Sat, 20 Apr 2024 17:22:42 +0700 Subject: [PATCH 10/23] upd --- oml/datasets/__init__.py | 2 + oml/inference/__init__.py | 1 + oml/inference/abstract.py | 4 +- oml/models/meta/siamese.py | 6 +- oml/retrieval/postprocessors/pairwise.py | 48 ++++++++++------ oml/utils/misc_torch.py | 1 + .../test_pairwise_images.py | 57 +++++++------------ tests/test_oml/test_utils/test_misc_torch.py | 45 ++++++++------- 8 files changed, 86 insertions(+), 78 deletions(-) diff --git a/oml/datasets/__init__.py b/oml/datasets/__init__.py index e69de29bb..0e1bd4358 100644 --- a/oml/datasets/__init__.py +++ b/oml/datasets/__init__.py @@ -0,0 +1,2 @@ +from oml.datasets.images import ImageQueryGalleryLabeledDataset, ImageBaseDataset, ImageLabeledDataset +from oml.datasets.pairs import PairDataset diff --git a/oml/inference/__init__.py b/oml/inference/__init__.py index e69de29bb..63420852e 100644 --- a/oml/inference/__init__.py +++ b/oml/inference/__init__.py @@ -0,0 +1 @@ +from oml.inference.abstract import inference, inference_cached, pairwise_inference \ No newline at end of file diff --git a/oml/inference/abstract.py b/oml/inference/abstract.py index 8157b5eb6..1af275f68 100644 --- a/oml/inference/abstract.py +++ b/oml/inference/abstract.py @@ -6,7 +6,7 @@ from torch.utils.data import DataLoader, Dataset from tqdm.auto import tqdm -from oml.datasets import PairsDataset +from oml.datasets import PairDataset from oml.ddp.patching import patch_dataloader_to_ddp from oml.ddp.utils import get_world_size_safe, is_ddp, sync_dicts_ddp from oml.interfaces.datasets import IBaseDataset @@ -107,7 +107,7 @@ def pairwise_inference( ) -> FloatTensor: device = get_device(model) - dataset = PairsDataset(base_dataset=base_dataset, pair_ids=pair_ids) + dataset = PairDataset(base_dataset=base_dataset, pair_ids=pair_ids) def _apply( model_: IPairwiseModel, diff --git a/oml/models/meta/siamese.py b/oml/models/meta/siamese.py index 4c843f90e..3dbb23345 100644 --- a/oml/models/meta/siamese.py +++ b/oml/models/meta/siamese.py @@ -145,14 +145,16 @@ class TrivialDistanceSiamese(IPairwiseModel): pretrained_models: Dict[str, Any] = {} - def __init__(self, extractor: IExtractor) -> None: + def __init__(self, extractor: IExtractor, output_bias: float = 0) -> None: """ Args: extractor: Instance of ``IExtractor`` (e.g. ``ViTExtractor``) + output_bias: Bias added to the distances. """ super(TrivialDistanceSiamese, self).__init__() self.extractor = extractor + self.output_bias = output_bias def forward(self, x1: Tensor, x2: Tensor) -> Tensor: """ @@ -166,7 +168,7 @@ def forward(self, x1: Tensor, x2: Tensor) -> Tensor: """ x1 = self.extractor(x1) x2 = self.extractor(x2) - return elementwise_dist(x1, x2, p=2) + return elementwise_dist(x1, x2, p=2) + self.output_bias def predict(self, x1: Tensor, x2: Tensor) -> Tensor: return self.forward(x1=x1, x2=x2) diff --git a/oml/retrieval/postprocessors/pairwise.py b/oml/retrieval/postprocessors/pairwise.py index e8f62295a..08a248605 100644 --- a/oml/retrieval/postprocessors/pairwise.py +++ b/oml/retrieval/postprocessors/pairwise.py @@ -1,22 +1,24 @@ +from typing import Tuple + import torch +from torch import Tensor from oml.inference.abstract import pairwise_inference from oml.interfaces.datasets import IQueryGalleryDataset from oml.interfaces.models import IPairwiseModel from oml.interfaces.retrieval import IRetrievalPostprocessor -from oml.retrieval.prediction import RetrievalPrediction -from oml.utils.misc_torch import cat_two_sorted_tensors_and_keep_it_sorted, take_2d +from oml.utils.misc_torch import cat_two_sorted_tensors_and_keep_it_sorted, take_2d, assign_2d class PairwiseReranker(IRetrievalPostprocessor): def __init__( - self, - top_n: int, - pairwise_model: IPairwiseModel, - num_workers: int, - batch_size: int, - verbose: bool = False, - use_fp16: bool = False, + self, + top_n: int, + pairwise_model: IPairwiseModel, + num_workers: int, + batch_size: int, + verbose: bool = False, + use_fp16: bool = False, ): """ @@ -41,7 +43,25 @@ def __init__( self.verbose = verbose self.use_fp16 = use_fp16 - def process(self, prediction: RetrievalPrediction, dataset: IQueryGalleryDataset) -> RetrievalPrediction: + def process(self, distances: Tensor, dataset: IQueryGalleryDataset) -> Tensor: + """ + Args: + distances: Distances among queries and galleries with the shape of ``[Q, G]``. + dataset: Dataset having query-gallery split. + + Returns: + The same distances matrix, but `top_n` smallest values are updated. + """ + # todo 522: + # after we introduce RetrievalPrediction the signature of the method will change: so, we directly call + # self.process_neigh. Thus, the code below is temporary to support the current interface. + distances_neigh, ii_neigh = torch.topk(distances, k=min(distances.shape[1], self.top_n)) + distances_neigh_upd, ii_neigh_upd = self.process_neigh(distances_neigh, ii_neigh, dataset) + distances_upd = assign_2d(x=distances, indices=ii_neigh_upd, new_values=distances_neigh_upd) + return distances_upd + + def process_neigh(self, distances: Tensor, retrieved_ids: Tensor, dataset: IQueryGalleryDataset) -> Tuple[ + Tensor, Tensor]: """ Note, the new distances to the ``top_n`` items produced by the pairwise model may be adjusted @@ -56,10 +76,7 @@ def process(self, prediction: RetrievalPrediction, dataset: IQueryGalleryDataset If concatenation of two distances is already sorted, we keep it untouched. """ - top_n = min(self.top_n, prediction.top_n) - - retrieved_ids = prediction.retrieved_ids.clone() - distances = prediction.distances.clone() + top_n = min(self.top_n, distances.shape[1]) # let's list pairs of (query_i, gallery_j) we need to process ids_q = dataset.get_query_ids().unsqueeze(-1).repeat_interleave(top_n) @@ -90,8 +107,7 @@ def process(self, prediction: RetrievalPrediction, dataset: IQueryGalleryDataset assert distances_upd.shape == distances.shape assert retrieved_ids_upd.shape == retrieved_ids.shape - prediction_upd = RetrievalPrediction(distances_upd, retrieved_ids=retrieved_ids_upd, gt_ids=prediction.gt_ids) - return prediction_upd + return distances_upd, retrieved_ids_upd __all__ = ["PairwiseReranker"] diff --git a/oml/utils/misc_torch.py b/oml/utils/misc_torch.py index 815ff59fb..68b4a1299 100644 --- a/oml/utils/misc_torch.py +++ b/oml/utils/misc_torch.py @@ -82,6 +82,7 @@ def cat_two_sorted_tensors_and_keep_it_sorted(x1: Tensor, x2: Tensor, eps: float return x + def elementwise_dist(x1: Tensor, x2: Tensor, p: int = 2) -> Tensor: """ Args: diff --git a/tests/test_oml/test_postprocessor/test_pairwise_images.py b/tests/test_oml/test_postprocessor/test_pairwise_images.py index 5f5473038..a568c946f 100644 --- a/tests/test_oml/test_postprocessor/test_pairwise_images.py +++ b/tests/test_oml/test_postprocessor/test_pairwise_images.py @@ -1,12 +1,13 @@ from typing import Tuple -import numpy as np import pytest import torch from torch import Tensor, nn from oml.const import MOCK_DATASET_PATH -from oml.inference.flat import inference_on_images +from oml.datasets.images import ImageQueryGalleryLabeledDataset +from oml.inference import inference +from oml.interfaces.datasets import IQueryGalleryDataset from oml.models.meta.siamese import TrivialDistanceSiamese from oml.models.resnet.extractor import ResnetExtractor from oml.retrieval.postprocessors.pairwise import PairwiseReranker @@ -16,58 +17,42 @@ from oml.utils.misc_torch import pairwise_dist -def get_validation_results(model: nn.Module, transforms: TTransforms) -> Tuple[Tensor, Tensor, Tensor]: +def get_validation_results(model: nn.Module, transforms: TTransforms) -> Tuple[Tensor, IQueryGalleryDataset]: _, df_val = download_mock_dataset(MOCK_DATASET_PATH) - is_query = np.array(df_val["is_query"]).astype(bool) - is_gallery = np.array(df_val["is_gallery"]).astype(bool) - paths = np.array(df_val["path"].apply(lambda x: MOCK_DATASET_PATH / x)) - queries = paths[is_query] - galleries = paths[is_gallery] + dataset = ImageQueryGalleryLabeledDataset(df=df_val, transform=transforms, dataset_root=MOCK_DATASET_PATH) - embeddings = inference_on_images( - model=model, - paths=paths.tolist(), - transform=transforms, - num_workers=0, - batch_size=4, - verbose=False, - use_fp16=True, - ) + embeddings = inference(model, dataset, batch_size=4) - distances = pairwise_dist(x1=embeddings[is_query], x2=embeddings[is_gallery], p=2) + distances = pairwise_dist(x1=embeddings[dataset.get_query_ids()], x2=embeddings[dataset.get_gallery_ids()], p=2) - return distances, queries, galleries + return distances, dataset @pytest.mark.long @pytest.mark.parametrize("top_n", [2, 5, 100]) -def test_trivial_processing_does_not_change_distances_order(top_n: int) -> None: +@pytest.mark.parametrize("pairwise_distances_bias", [0, 100]) +def test_trivial_processing_does_not_change_distances_order(top_n: int, pairwise_distances_bias: float) -> None: extractor = ResnetExtractor(weights=None, arch="resnet18", normalise_features=True, gem_p=None, remove_fc=True) - - pairwise_model = TrivialDistanceSiamese(extractor) + pairwise_model = TrivialDistanceSiamese(extractor, output_bias=pairwise_distances_bias) transforms = get_normalisation_resize_torch(im_size=32) - - distances, queries, galleries = get_validation_results(model=extractor, transforms=transforms) + distances, dataset = get_validation_results(model=extractor, transforms=transforms) postprocessor = PairwiseReranker( top_n=top_n, pairwise_model=pairwise_model, - transforms=transforms, num_workers=0, batch_size=4, verbose=False, use_fp16=True, ) - distances_processed = postprocessor.process(distances=distances.clone(), queries=queries, galleries=galleries) - - order = distances.argsort() - order_processed = distances_processed.argsort() - - assert (order == order_processed).all(), (order, order_processed) - - if top_n <= len(galleries): - min_orig_distances = torch.topk(distances, k=top_n, largest=False).values - min_processed_distances = torch.topk(distances_processed, k=top_n, largest=False).values - assert torch.allclose(min_orig_distances, min_processed_distances) + distances_processed = postprocessor.process(distances=distances.clone(), dataset=dataset) + + if pairwise_distances_bias == 0: + # distances are literally the same + assert torch.allclose(distances_processed.argsort(), distances.argsort()) + else: + # since pairwise distances have been shifted, the relative order remain the same, but the values are different + assert torch.allclose(distances_processed.argsort(), distances.argsort()) + assert not torch.allclose(distances_processed, distances) diff --git a/tests/test_oml/test_utils/test_misc_torch.py b/tests/test_oml/test_utils/test_misc_torch.py index fb8f3f0dd..fc19ee2c5 100644 --- a/tests/test_oml/test_utils/test_misc_torch.py +++ b/tests/test_oml/test_utils/test_misc_torch.py @@ -7,9 +7,10 @@ from oml.utils.misc_torch import ( PCA, assign_2d, + cat_two_sorted_tensors_and_keep_it_sorted, drop_duplicates_by_ids, elementwise_dist, - take_2d, cat_two_sorted_tensors_and_keep_it_sorted, + take_2d, ) @@ -27,29 +28,29 @@ def test_elementwise_dist() -> None: "x1,x2,e,expected", [ ( - # x1 - torch.tensor([[10, 20, 30], [40, 50, 60]]).float(), - # x2 - torch.tensor([[0.3, 0.4, 0.5], [0.6, 0.8, 0.9]]).float(), - # e - 0.001, - # expected: rescaling is needed - torch.tensor( - [ - [0.1 - 0.001, 0.2 - 0.001, 0.3 - 0.001, 0.3, 0.4, 0.5], - [0.4 - 0.001, 0.5 - 0.001, 0.6 - 0.001, 0.6, 0.8, 0.9], - ] - ).float(), + # x1 + torch.tensor([[10, 20, 30], [40, 50, 60]]).float(), + # x2 + torch.tensor([[0.3, 0.4, 0.5], [0.6, 0.8, 0.9]]).float(), + # e + 0.001, + # expected: rescaling is needed + torch.tensor( + [ + [0.1 - 0.001, 0.2 - 0.001, 0.3 - 0.001, 0.3, 0.4, 0.5], + [0.4 - 0.001, 0.5 - 0.001, 0.6 - 0.001, 0.6, 0.8, 0.9], + ] + ).float(), ), ( - # x1 - torch.tensor([[-10, -5], [-20, -8]]).float(), - # x2 - torch.tensor([[0.3, 0.4, 0.5], [0.6, 0.8, 0.9]]).float(), - # e - 0.001, - # expected: rescaling is not needed, we jast concat - torch.tensor([[-10, -5, 0.3, 0.4, 0.5], [-20, -8, 0.6, 0.8, 0.9]]).float(), + # x1 + torch.tensor([[-10, -5], [-20, -8]]).float(), + # x2 + torch.tensor([[0.3, 0.4, 0.5], [0.6, 0.8, 0.9]]).float(), + # e + 0.001, + # expected: rescaling is not needed, we jast concat + torch.tensor([[-10, -5, 0.3, 0.4, 0.5], [-20, -8, 0.6, 0.8, 0.9]]).float(), ), ], ) From 33f11ae0f5045e2e95e0672c795b0e7757b9e666 Mon Sep 17 00:00:00 2001 From: alekseysh Date: Sat, 20 Apr 2024 23:17:43 +0700 Subject: [PATCH 11/23] upd --- oml/datasets/__init__.py | 6 +- oml/datasets/images.py | 1 + oml/inference/__init__.py | 2 +- oml/inference/abstract.py | 58 +++--- oml/interfaces/retrieval.py | 2 +- .../pipelines/train_postprocessor.py | 57 +++--- oml/lightning/pipelines/validate.py | 1 + oml/metrics/embeddings.py | 9 +- oml/models/meta/siamese.py | 8 +- oml/registry/postprocessors.py | 4 +- oml/retrieval/postprocessors/pairwise.py | 27 +-- .../test_retrieval_validation.py | 4 +- tests/test_integrations/utils.py | 72 +++++-- tests/test_oml/test_ddp/test_ddp_inference.py | 10 +- .../test_metrics/test_embedding_metrics.py | 187 +++++++----------- .../test_pairwise_embeddings.py | 148 ++++++-------- .../test_pairwise_images.py | 4 +- 17 files changed, 297 insertions(+), 303 deletions(-) diff --git a/oml/datasets/__init__.py b/oml/datasets/__init__.py index 0e1bd4358..538f3360e 100644 --- a/oml/datasets/__init__.py +++ b/oml/datasets/__init__.py @@ -1,2 +1,6 @@ -from oml.datasets.images import ImageQueryGalleryLabeledDataset, ImageBaseDataset, ImageLabeledDataset +from oml.datasets.images import ( + ImageBaseDataset, + ImageLabeledDataset, + ImageQueryGalleryLabeledDataset, +) from oml.datasets.pairs import PairDataset diff --git a/oml/datasets/images.py b/oml/datasets/images.py index 699651b9b..57b7670da 100644 --- a/oml/datasets/images.py +++ b/oml/datasets/images.py @@ -186,6 +186,7 @@ def __getitem__(self, idx: int) -> Dict[str, Union[FloatTensor, int]]: self.index_key: idx, } + # todo 522: avoid passing extra data as keys if self.extra_data: for key, record in self.extra_data.items(): if key in item: diff --git a/oml/inference/__init__.py b/oml/inference/__init__.py index 63420852e..753875058 100644 --- a/oml/inference/__init__.py +++ b/oml/inference/__init__.py @@ -1 +1 @@ -from oml.inference.abstract import inference, inference_cached, pairwise_inference \ No newline at end of file +from oml.inference.abstract import inference, inference_cached, pairwise_inference diff --git a/oml/inference/abstract.py b/oml/inference/abstract.py index 1af275f68..b062a05ef 100644 --- a/oml/inference/abstract.py +++ b/oml/inference/abstract.py @@ -95,6 +95,36 @@ def apply(model_: nn.Module, batch_: Dict[str, Any]) -> FloatTensor: ) +def inference_cached( + model: nn.Module, + dataset: IBaseDataset, + cache_path: str = "inference_cache.pth", + num_workers: int = 0, + batch_size: int = 128, + use_fp16: bool = False, + verbose: bool = True, + accumulate_on_cpu: bool = True, +) -> FloatTensor: + if Path(cache_path).is_file(): + outputs = torch.load(cache_path, map_location="cpu") + print(f"Model outputs have been loaded from {cache_path}.") + else: + outputs = inference( + model=model, + dataset=dataset, + num_workers=num_workers, + batch_size=batch_size, + use_fp16=use_fp16, + verbose=verbose, + accumulate_on_cpu=accumulate_on_cpu, + ) + + torch.save(outputs, cache_path) + print(f"Model outputs have been saved to {cache_path}.") + + return outputs + + def pairwise_inference( model: IPairwiseModel, base_dataset: IBaseDataset, @@ -131,32 +161,4 @@ def _apply( return output -def inference_cached( - dataset: IBaseDataset, - extractor: nn.Module, - output_cache_path: str = "inference_cache.pth", - num_workers: int = 0, - batch_size: int = 128, - use_fp16: bool = False, -) -> FloatTensor: - if Path(output_cache_path).is_file(): - outputs = torch.load(output_cache_path, map_location="cpu") - print(f"Model outputs have been loaded from {output_cache_path}.") - else: - outputs = inference( - model=extractor, - dataset=dataset, - num_workers=num_workers, - batch_size=batch_size, - use_fp16=use_fp16, - verbose=True, - accumulate_on_cpu=True, - ) - - torch.save(outputs, output_cache_path) - print(f"Model outputs have been saved to {output_cache_path}.") - - return outputs - - __all__ = ["inference", "pairwise_inference", "inference_cached"] diff --git a/oml/interfaces/retrieval.py b/oml/interfaces/retrieval.py index ebf2ba552..7422f6c35 100644 --- a/oml/interfaces/retrieval.py +++ b/oml/interfaces/retrieval.py @@ -7,7 +7,7 @@ class IRetrievalPostprocessor: """ - def process(self, *args, **kwargs) -> Any: + def process(self, *args, **kwargs) -> Any: # type: ignore # todo 522: add actual signature later raise NotImplementedError() diff --git a/oml/lightning/pipelines/train_postprocessor.py b/oml/lightning/pipelines/train_postprocessor.py index b5128eb0b..b7dbb0b9f 100644 --- a/oml/lightning/pipelines/train_postprocessor.py +++ b/oml/lightning/pipelines/train_postprocessor.py @@ -3,16 +3,16 @@ from pprint import pprint from typing import Any, Dict, Tuple -import pandas as pd import pytorch_lightning as pl import torch from omegaconf import DictConfig from torch import device as tdevice from torch.utils.data import DataLoader -from oml.const import BBOXES_COLUMNS, EMBEDDINGS_KEY, TCfg +from oml.const import EMBEDDINGS_KEY, TCfg from oml.datasets.base import ImageLabeledDataset, ImageQueryGalleryLabeledDataset -from oml.inference.flat import inference_on_dataframe +from oml.datasets.images import get_retrieval_images_datasets +from oml.inference import inference, inference_cached from oml.interfaces.models import IPairwiseModel from oml.lightning.callbacks.metric import MetricValCallback, MetricValCallbackDDP from oml.lightning.modules.pairwise_postprocessing import ( @@ -34,7 +34,6 @@ from oml.registry.postprocessors import get_postprocessor_by_cfg from oml.registry.transforms import get_transforms_by_cfg from oml.retrieval.postprocessors.pairwise import PairwiseReranker -from oml.transforms.images.torchvision import get_normalisation_resize_torch from oml.utils.misc import dictconfig_to_dict, flatten_dict, set_global_seed @@ -56,46 +55,49 @@ def dict2str(dictionary: Dict[str, Any]) -> str: def get_loaders_with_embeddings(cfg: TCfg) -> Tuple[DataLoader, DataLoader]: - # todo: support bounding bboxes - df = pd.read_csv(Path(cfg["dataset_root"]) / cfg["dataframe_name"]) - assert not set(BBOXES_COLUMNS).intersection( - df.columns - ), "We've found bboxes in the dataframe, but they're not supported yet." - device = tdevice("cuda:0") if parse_engine_params_from_config(cfg)["accelerator"] == "gpu" else tdevice("cpu") extractor = get_extractor_by_cfg(cfg["extractor"]).to(device) - if cfg["embeddings_cache_dir"] is not None: - cache_file = Path(cfg["embeddings_cache_dir"]) / f"embeddings_{get_hash_of_extraction_stage_cfg(cfg)[:5]}.pkl" - else: - cache_file = None - - emb_train, emb_val, df_train, df_val = inference_on_dataframe( - extractor=extractor, + train_extraction, val_extraction = get_retrieval_images_datasets( dataset_root=cfg["dataset_root"], - output_cache_path=cache_file, dataframe_name=cfg["dataframe_name"], - transforms=get_transforms_by_cfg(cfg["transforms_extraction"]), - num_workers=cfg["num_workers"], - batch_size=cfg["batch_size_inference"], - use_fp16=int(cfg.get("precision", 32)) == 16, + transforms_train=get_transforms_by_cfg(cfg["transforms_extraction"]), + transforms_val=get_transforms_by_cfg(cfg["transforms_extraction"]), ) + args = { + "model": extractor, + "num_workers": cfg["num_workers"], + "batch_size": cfg["batch_size_inference"], + "use_fp16": int(cfg.get("precision", 32)) == 16, + } + + if cfg["embeddings_cache_dir"] is not None: + hash_ = get_hash_of_extraction_stage_cfg(cfg)[:5] + dir_ = Path(cfg["embeddings_cache_dir"]) + emb_train = inference_cached(dataset=train_extraction, cache_path=str(dir_ / f"emb_train_{hash_}.pkl"), **args) + emb_val = inference_cached(dataset=val_extraction, cache_path=str(dir_ / f"emb_val_{hash_}.pkl"), **args) + else: + emb_train = inference(dataset=train_extraction, **args) + emb_val = inference(dataset=val_extraction, **args) + train_dataset = ImageLabeledDataset( - df=df_train, + dataset_root=cfg["dataset_root"], + df=train_extraction.df, transform=get_transforms_by_cfg(cfg["transforms_train"]), extra_data={EMBEDDINGS_KEY: emb_train}, ) valid_dataset = ImageQueryGalleryLabeledDataset( - df=df_val, - # we don't care about transforms, since the only goal of this dataset is to deliver embeddings - transform=get_normalisation_resize_torch(im_size=8), + dataset_root=cfg["dataset_root"], + df=val_extraction.df, + transform=get_transforms_by_cfg(cfg["transforms_extraction"]), extra_data={EMBEDDINGS_KEY: emb_val}, ) sampler = parse_sampler_from_config(cfg, dataset=train_dataset) - assert sampler is not None + assert sampler is not None, "We will be training on pairs, so, having sampler is obligatory." + loader_train = DataLoader(batch_sampler=sampler, dataset=train_dataset, num_workers=cfg["num_workers"]) loader_val = DataLoader( @@ -157,6 +159,7 @@ def postprocessor_training_pipeline(cfg: DictConfig) -> None: metrics_constructor = EmbeddingMetricsDDP if is_ddp else EmbeddingMetrics metrics_calc = metrics_constructor( + dataset=loader_val.dataset, embeddings_key=pl_module.embeddings_key, categories_key=loader_val.dataset.categories_key, labels_key=loader_val.dataset.labels_key, diff --git a/oml/lightning/pipelines/validate.py b/oml/lightning/pipelines/validate.py index 543522bf4..43a28110c 100644 --- a/oml/lightning/pipelines/validate.py +++ b/oml/lightning/pipelines/validate.py @@ -69,6 +69,7 @@ def extractor_validation_pipeline(cfg: TCfg) -> Tuple[pl.Trainer, Dict[str, Any] metrics_constructor = EmbeddingMetricsDDP if is_ddp else EmbeddingMetrics metrics_calc = metrics_constructor( + dataset=valid_dataset, embeddings_key=pl_model.embeddings_key, categories_key=valid_dataset.categories_key, labels_key=valid_dataset.labels_key, diff --git a/oml/metrics/embeddings.py b/oml/metrics/embeddings.py index 246be00d7..792e8cd7e 100644 --- a/oml/metrics/embeddings.py +++ b/oml/metrics/embeddings.py @@ -37,6 +37,7 @@ calc_topological_metrics, reduce_metrics, ) +from oml.interfaces.datasets import IQueryGalleryLabeledDataset from oml.interfaces.metrics import IMetricDDP, IMetricVisualisable from oml.interfaces.retrieval import IRetrievalPostprocessor from oml.metrics.accumulation import Accumulator @@ -67,6 +68,7 @@ class EmbeddingMetrics(IMetricVisualisable): def __init__( self, + dataset: Optional[IQueryGalleryLabeledDataset] = None, embeddings_key: str = EMBEDDINGS_KEY, labels_key: str = LABELS_KEY, is_query_key: str = IS_QUERY_KEY, @@ -88,6 +90,7 @@ def __init__( """ Args: + dataset: Annotated dataset having query-gallery split. todo 522: This argument will not be Optional soon. embeddings_key: Key to take the embeddings from the batches labels_key: Key to take the labels from the batches is_query_key: Key to take the information whether every batch sample belongs to the query @@ -115,6 +118,7 @@ def __init__( verbose: Set ``True`` if you want to print metrics """ + self.dataset = dataset self.embeddings_key = embeddings_key self.labels_key = labels_key self.is_query_key = is_query_key @@ -148,8 +152,6 @@ def __init__( keys_to_accumulate.append(self.sequence_key) if self.extra_keys: keys_to_accumulate.extend(list(extra_keys)) - if self.postprocessor: - keys_to_accumulate.extend(self.postprocessor.needed_keys) self.keys_to_accumulate = tuple(set(keys_to_accumulate)) self.acc = Accumulator(keys_to_accumulate=self.keys_to_accumulate) @@ -187,7 +189,8 @@ def _calc_matrices(self) -> None: validate_dataset(mask_gt=self.mask_gt, mask_to_ignore=mask_to_ignore) if self.postprocessor: - self.distance_matrix = self.postprocessor.process_by_dict(self.distance_matrix, data=self.acc.storage) + assert self.dataset, "You must pass dataset to init to make postprocessing." + self.distance_matrix = self.postprocessor.process(self.distance_matrix, dataset=self.dataset) def compute_metrics(self) -> TMetricsDict_ByLabels: # type: ignore if not self.acc.is_storage_full(): diff --git a/oml/models/meta/siamese.py b/oml/models/meta/siamese.py index 3dbb23345..b0f0b6d3d 100644 --- a/oml/models/meta/siamese.py +++ b/oml/models/meta/siamese.py @@ -18,16 +18,18 @@ class LinearTrivialDistanceSiamese(IPairwiseModel): """ - def __init__(self, feat_dim: int, identity_init: bool): + def __init__(self, feat_dim: int, identity_init: bool, output_bias: float = 0): """ Args: feat_dim: Expected size of each input. identity_init: If ``True``, models' weights initialised in a way when the model simply estimates L2 distance between the original embeddings. + output_bias: Value to add to the output. """ super(LinearTrivialDistanceSiamese, self).__init__() self.feat_dim = feat_dim + self.output_bias = output_bias self.proj = torch.nn.Linear(in_features=feat_dim, out_features=feat_dim, bias=False) @@ -46,7 +48,7 @@ def forward(self, x1: Tensor, x2: Tensor) -> Tensor: """ x1 = self.proj(x1) x2 = self.proj(x2) - y = elementwise_dist(x1, x2, p=2) + y = elementwise_dist(x1, x2, p=2) + self.output_bias return y def predict(self, x1: Tensor, x2: Tensor) -> Tensor: @@ -149,7 +151,7 @@ def __init__(self, extractor: IExtractor, output_bias: float = 0) -> None: """ Args: extractor: Instance of ``IExtractor`` (e.g. ``ViTExtractor``) - output_bias: Bias added to the distances. + output_bias: Value to add to the outputs. """ super(TrivialDistanceSiamese, self).__init__() diff --git a/oml/registry/postprocessors.py b/oml/registry/postprocessors.py index 76436427c..015971b77 100644 --- a/oml/registry/postprocessors.py +++ b/oml/registry/postprocessors.py @@ -15,9 +15,9 @@ def get_postprocessor(name: str, **kwargs: Dict[str, Any]) -> IRetrievalPostproc constructor = POSTPROCESSORS_REGISTRY[name] if "pairwise_model" in kwargs: - kwargs["pairwise_model"] = get_pairwise_model_by_cfg(kwargs["pairwise_model"]) + kwargs["pairwise_model"] = get_pairwise_model_by_cfg(kwargs["pairwise_model"]) # type: ignore - return constructor(**kwargs) + return constructor(**kwargs) # type: ignore def get_postprocessor_by_cfg(cfg: TCfg) -> IRetrievalPostprocessor: diff --git a/oml/retrieval/postprocessors/pairwise.py b/oml/retrieval/postprocessors/pairwise.py index 08a248605..0ac1cd839 100644 --- a/oml/retrieval/postprocessors/pairwise.py +++ b/oml/retrieval/postprocessors/pairwise.py @@ -7,18 +7,22 @@ from oml.interfaces.datasets import IQueryGalleryDataset from oml.interfaces.models import IPairwiseModel from oml.interfaces.retrieval import IRetrievalPostprocessor -from oml.utils.misc_torch import cat_two_sorted_tensors_and_keep_it_sorted, take_2d, assign_2d +from oml.utils.misc_torch import ( + assign_2d, + cat_two_sorted_tensors_and_keep_it_sorted, + take_2d, +) class PairwiseReranker(IRetrievalPostprocessor): def __init__( - self, - top_n: int, - pairwise_model: IPairwiseModel, - num_workers: int, - batch_size: int, - verbose: bool = False, - use_fp16: bool = False, + self, + top_n: int, + pairwise_model: IPairwiseModel, + num_workers: int, + batch_size: int, + verbose: bool = False, + use_fp16: bool = False, ): """ @@ -43,7 +47,7 @@ def __init__( self.verbose = verbose self.use_fp16 = use_fp16 - def process(self, distances: Tensor, dataset: IQueryGalleryDataset) -> Tensor: + def process(self, distances: Tensor, dataset: IQueryGalleryDataset) -> Tensor: # type: ignore """ Args: distances: Distances among queries and galleries with the shape of ``[Q, G]``. @@ -60,8 +64,9 @@ def process(self, distances: Tensor, dataset: IQueryGalleryDataset) -> Tensor: distances_upd = assign_2d(x=distances, indices=ii_neigh_upd, new_values=distances_neigh_upd) return distances_upd - def process_neigh(self, distances: Tensor, retrieved_ids: Tensor, dataset: IQueryGalleryDataset) -> Tuple[ - Tensor, Tensor]: + def process_neigh( + self, distances: Tensor, retrieved_ids: Tensor, dataset: IQueryGalleryDataset + ) -> Tuple[Tensor, Tensor]: """ Note, the new distances to the ``top_n`` items produced by the pairwise model may be adjusted diff --git a/tests/test_integrations/test_retrieval_validation.py b/tests/test_integrations/test_retrieval_validation.py index cbaef4f61..91f9ee6d3 100644 --- a/tests/test_integrations/test_retrieval_validation.py +++ b/tests/test_integrations/test_retrieval_validation.py @@ -9,7 +9,7 @@ from oml.const import EMBEDDINGS_KEY, INPUT_TENSORS_KEY, OVERALL_CATEGORIES_KEY from oml.metrics.embeddings import EmbeddingMetrics from tests.test_integrations.utils import ( - EmbeddingsQueryGalleryDataset, + EmbeddingsQueryGalleryLabeledDataset, IdealClusterEncoder, ) @@ -51,7 +51,7 @@ def get_shared_query_gallery() -> TData: def test_retrieval_validation(batch_size: int, shuffle: bool, num_workers: int, data: TData) -> None: labels, query_mask, gallery_mask, input_tensors, cmc_gt = data - dataset = EmbeddingsQueryGalleryDataset( + dataset = EmbeddingsQueryGalleryLabeledDataset( labels=labels, embeddings=input_tensors, is_query=query_mask, diff --git a/tests/test_integrations/utils.py b/tests/test_integrations/utils.py index 7c8627a1c..8702c001b 100644 --- a/tests/test_integrations/utils.py +++ b/tests/test_integrations/utils.py @@ -5,15 +5,15 @@ from torch import BoolTensor, FloatTensor, LongTensor, nn from oml.const import ( - CATEGORIES_COLUMN, + CATEGORIES_KEY, INDEX_KEY, INPUT_TENSORS_KEY, IS_GALLERY_KEY, IS_QUERY_KEY, LABELS_KEY, - SEQUENCE_COLUMN, + SEQUENCE_KEY, ) -from oml.interfaces.datasets import IQueryGalleryLabeledDataset +from oml.interfaces.datasets import IQueryGalleryDataset, IQueryGalleryLabeledDataset from oml.utils.misc import one_hot @@ -35,48 +35,58 @@ def forward(self, labels: torch.Tensor, need_noise: bool = True) -> torch.Tensor return embeddings -class EmbeddingsQueryGalleryDataset(IQueryGalleryLabeledDataset): +class EmbeddingsQueryGalleryDataset(IQueryGalleryDataset): def __init__( self, embeddings: FloatTensor, - labels: LongTensor, is_query: BoolTensor, is_gallery: BoolTensor, categories: Optional[np.ndarray] = None, sequence: Optional[np.ndarray] = None, input_tensors_key: str = INPUT_TENSORS_KEY, - labels_key: str = LABELS_KEY, index_key: str = INDEX_KEY, + # todo 522: remove keys later + categories_key: str = CATEGORIES_KEY, + sequence_key: str = SEQUENCE_KEY, ): super().__init__() - assert len(embeddings) == len(labels) == len(is_query) == len(is_gallery) + assert len(embeddings) == len(is_query) == len(is_gallery) self._embeddings = embeddings - self._labels = labels self._is_query = is_query self._is_gallery = is_gallery + # todo 522: remove keys + self.categories_key = categories_key + self.sequence_key = sequence_key + self.extra_data = {} - if categories: - self.extra_data[CATEGORIES_COLUMN] = categories + if categories is not None: + self.extra_data[self.categories_key] = categories - if sequence: - self.extra_data[SEQUENCE_COLUMN] = sequence + if sequence is not None: + self.extra_data[self.sequence_key] = sequence self.input_tensors_key = input_tensors_key - self.labels_key = labels_key self.index_key = index_key def __getitem__(self, idx: int) -> Dict[str, Any]: batch = { self.input_tensors_key: self._embeddings[idx], - self.labels_key: self._labels[idx], self.index_key: idx, # todo 522: remove IS_QUERY_KEY: self._is_query[idx], IS_GALLERY_KEY: self._is_gallery[idx], } + # todo 522: avoid passing extra data as keys + if self.extra_data: + for key, record in self.extra_data.items(): + if key in batch: + raise ValueError(f" and dataset share the same key: {key}") + else: + batch[key] = record[idx] + return batch def __len__(self) -> int: @@ -88,5 +98,39 @@ def get_query_ids(self) -> LongTensor: def get_gallery_ids(self) -> LongTensor: return self._is_gallery.nonzero().squeeze() + +class EmbeddingsQueryGalleryLabeledDataset(EmbeddingsQueryGalleryDataset, IQueryGalleryLabeledDataset): + def __init__( + self, + embeddings: FloatTensor, + labels: LongTensor, + is_query: BoolTensor, + is_gallery: BoolTensor, + categories: Optional[np.ndarray] = None, + sequence: Optional[np.ndarray] = None, + input_tensors_key: str = INPUT_TENSORS_KEY, + labels_key: str = LABELS_KEY, + index_key: str = INDEX_KEY, + ): + super().__init__( + embeddings=embeddings, + is_query=is_query, + is_gallery=is_gallery, + categories=categories, + sequence=sequence, + input_tensors_key=input_tensors_key, + index_key=index_key, + ) + + assert len(embeddings) == len(labels) + + self._labels = labels + self.labels_key = labels_key + + def __getitem__(self, idx: int) -> Dict[str, Any]: + item = super().__getitem__(idx) + item[self.labels_key] = self._labels[idx] + return item + def get_labels(self) -> np.ndarray: return np.array(self._labels) diff --git a/tests/test_oml/test_ddp/test_ddp_inference.py b/tests/test_oml/test_ddp/test_ddp_inference.py index 505205055..aafde4b3f 100644 --- a/tests/test_oml/test_ddp/test_ddp_inference.py +++ b/tests/test_oml/test_ddp/test_ddp_inference.py @@ -1,3 +1,4 @@ +from pathlib import Path from typing import List import pytest @@ -5,7 +6,8 @@ from torchvision.models import resnet18 from oml.const import MOCK_DATASET_PATH -from oml.inference.flat import inference_on_images +from oml.datasets import ImageBaseDataset +from oml.inference import inference from oml.transforms.images.torchvision import get_normalisation_resize_torch from oml.utils.download_mock_dataset import download_mock_dataset from tests.test_oml.test_ddp.utils import init_ddp, run_in_ddp @@ -34,13 +36,11 @@ def run_with_handling_duplicates(rank: int, world_size: int, device: str, paths: args = { "model": model, - "paths": paths, - "transform": transform, + "dataset": ImageBaseDataset(paths=[Path(x) for x in paths], transform=transform), "num_workers": 0, "verbose": True, - "f_imread": None, "batch_size": batch_size, } - output = inference_on_images(**args) + output = inference(**args) assert len(paths) == len(output), (len(paths), len(output)) diff --git a/tests/test_oml/test_metrics/test_embedding_metrics.py b/tests/test_oml/test_metrics/test_embedding_metrics.py index 9ccca3993..68e64a57f 100644 --- a/tests/test_oml/test_metrics/test_embedding_metrics.py +++ b/tests/test_oml/test_metrics/test_embedding_metrics.py @@ -3,9 +3,11 @@ from functools import partial from typing import Any, Tuple +import numpy as np import pytest import torch from torch import Tensor +from torch.utils.data import DataLoader from oml.const import ( CATEGORIES_KEY, @@ -20,6 +22,7 @@ from oml.models.meta.siamese import LinearTrivialDistanceSiamese from oml.retrieval.postprocessors.pairwise import PairwiseReranker from oml.utils.misc import compare_dicts_recursively, one_hot +from tests.test_integrations.utils import EmbeddingsQueryGalleryLabeledDataset FEAT_DIM = 8 oh = partial(one_hot, dim=FEAT_DIM) @@ -46,22 +49,13 @@ def perfect_case() -> Any: Thus, we expect all of the metrics equals to 1. """ - - batch1 = { - EMBEDDINGS_KEY: torch.stack([oh(0), oh(1), oh(1)]), - LABELS_KEY: torch.tensor([0, 1, 1]), - IS_QUERY_KEY: torch.tensor([True, True, True]), - IS_GALLERY_KEY: torch.tensor([False, False, False]), - CATEGORIES_KEY: ["cat", "dog", "dog"], - } - - batch2 = { - EMBEDDINGS_KEY: torch.stack([oh(0), oh(1), oh(1)]), - LABELS_KEY: torch.tensor([0, 1, 1]), - IS_QUERY_KEY: torch.tensor([False, False, False]), - IS_GALLERY_KEY: torch.tensor([True, True, True]), - CATEGORIES_KEY: ["cat", "dog", "dog"], - } + dataset = EmbeddingsQueryGalleryLabeledDataset( + embeddings=torch.stack([oh(0), oh(1), oh(1), oh(0), oh(1), oh(1)]).float(), + labels=torch.tensor([0, 1, 1, 0, 1, 1]).long(), + is_query=torch.tensor([True, True, True, False, False, False]).bool(), + is_gallery=torch.tensor([False, False, False, True, True, True]).bool(), + categories=np.array(["cat", "dog", "dog", "cat", "dog", "dog"]), + ) k = 1 metrics = defaultdict(lambda: defaultdict(dict)) # type: ignore @@ -69,26 +63,18 @@ def perfect_case() -> Any: metrics["cat"]["cmc"][k] = 1.0 metrics["dog"]["cmc"][k] = 1.0 - return (batch1, batch2), (metrics, k) + return dataset, (metrics, k) @pytest.fixture() def imperfect_case() -> Any: - batch1 = { - EMBEDDINGS_KEY: torch.stack([oh(0), oh(1), oh(3)]), # 3d embedding pretends to be an error - LABELS_KEY: torch.tensor([0, 1, 1]), - IS_QUERY_KEY: torch.tensor([True, True, True]), - IS_GALLERY_KEY: torch.tensor([False, False, False]), - CATEGORIES_KEY: torch.tensor([10, 20, 20]), - } - - batch2 = { - EMBEDDINGS_KEY: torch.stack([oh(0), oh(1), oh(1)]), - LABELS_KEY: torch.tensor([0, 1, 1]), - IS_QUERY_KEY: torch.tensor([False, False, False]), - IS_GALLERY_KEY: torch.tensor([True, True, True]), - CATEGORIES_KEY: torch.tensor([10, 20, 20]), - } + dataset = EmbeddingsQueryGalleryLabeledDataset( + embeddings=torch.stack([oh(0), oh(1), oh(3), oh(0), oh(1), oh(1)]).float(), # 3d val pretends to be an error + labels=torch.tensor([0, 1, 1, 0, 1, 1]).long(), + is_query=torch.tensor([True, True, True, False, False, False]).bool(), + is_gallery=torch.tensor([False, False, False, True, True, True]).bool(), + categories=np.array([10, 20, 20, 10, 20, 20]), + ) k = 1 metrics = defaultdict(lambda: defaultdict(dict)) # type: ignore @@ -96,26 +82,18 @@ def imperfect_case() -> Any: metrics[10]["cmc"][k] = 1.0 metrics[20]["cmc"][k] = 0.5 - return (batch1, batch2), (metrics, k) + return dataset, (metrics, k) @pytest.fixture() def worst_case() -> Any: - batch1 = { - EMBEDDINGS_KEY: torch.stack([oh(1), oh(0), oh(0)]), # 3d embedding pretends to be an error - LABELS_KEY: torch.tensor([0, 1, 1]), - IS_QUERY_KEY: torch.tensor([True, True, True]), - IS_GALLERY_KEY: torch.tensor([False, False, False]), - CATEGORIES_KEY: torch.tensor([10, 20, 20]), - } - - batch2 = { - EMBEDDINGS_KEY: torch.stack([oh(0), oh(1), oh(1)]), - LABELS_KEY: torch.tensor([0, 1, 1]), - IS_QUERY_KEY: torch.tensor([False, False, False]), - IS_GALLERY_KEY: torch.tensor([True, True, True]), - CATEGORIES_KEY: torch.tensor([10, 20, 20]), - } + dataset = EmbeddingsQueryGalleryLabeledDataset( + embeddings=torch.stack([oh(1), oh(0), oh(0), oh(0), oh(1), oh(1)]).float(), # all are errors + labels=torch.tensor([0, 1, 1, 0, 1, 1]).long(), + is_query=torch.tensor([True, True, True, False, False, False]).bool(), + is_gallery=torch.tensor([False, False, False, True, True, True]).bool(), + categories=np.array([10, 20, 20, 10, 20, 20]), + ) k = 1 metrics = defaultdict(lambda: defaultdict(dict)) # type: ignore @@ -123,38 +101,32 @@ def worst_case() -> Any: metrics[10]["cmc"][k] = 0 metrics[20]["cmc"][k] = 0 - return (batch1, batch2), (metrics, k) + return dataset, (metrics, k) @pytest.fixture() -def case_for_distance_check() -> Any: - batch1 = { - EMBEDDINGS_KEY: torch.stack([oh(1) * 2, oh(1) * 3, oh(0)]), - LABELS_KEY: torch.tensor([0, 1, 1]), - IS_QUERY_KEY: torch.tensor([True, True, True]), - IS_GALLERY_KEY: torch.tensor([False, False, False]), - CATEGORIES_KEY: torch.tensor([10, 20, 20]), - } - - batch2 = { - EMBEDDINGS_KEY: torch.stack([oh(0), oh(1), oh(1)]), - LABELS_KEY: torch.tensor([0, 1, 1]), - IS_QUERY_KEY: torch.tensor([False, False, False]), - IS_GALLERY_KEY: torch.tensor([True, True, True]), - CATEGORIES_KEY: torch.tensor([10, 20, 20]), - } - ids_ranked_by_distance = [0, 2, 1] - return (batch1, batch2), ids_ranked_by_distance +def case_for_finding_worst_queries() -> Any: + dataset = EmbeddingsQueryGalleryLabeledDataset( + embeddings=torch.stack([oh(0), oh(1), oh(2), oh(0), oh(5), oh(5)]).float(), # last 2 are errors + labels=torch.tensor([0, 1, 2, 0, 1, 2]).long(), + is_query=torch.tensor([True, True, True, False, False, False]).bool(), + is_gallery=torch.tensor([False, False, False, True, True, True]).bool(), + categories=np.array([10, 20, 20, 10, 20, 20]), + ) + + worst_two_queries = {1, 2} + return dataset, worst_two_queries def run_retrieval_metrics(case) -> None: # type: ignore - (batch1, batch2), (gt_metrics, k) = case + dataset, (gt_metrics, k) = case top_k = (k,) - num_samples = len(batch1[LABELS_KEY]) + len(batch2[LABELS_KEY]) + num_samples = len(dataset) calc = EmbeddingMetrics( - embeddings_key=EMBEDDINGS_KEY, + dataset=dataset, + embeddings_key=dataset.input_tensors_key, labels_key=LABELS_KEY, is_query_key=IS_QUERY_KEY, is_gallery_key=IS_GALLERY_KEY, @@ -168,8 +140,9 @@ def run_retrieval_metrics(case) -> None: # type: ignore ) calc.setup(num_samples=num_samples) - calc.update_data(batch1) - calc.update_data(batch2) + + for batch in DataLoader(dataset, batch_size=4, shuffle=False): + calc.update_data(batch) metrics = calc.compute_metrics() @@ -182,15 +155,15 @@ def run_retrieval_metrics(case) -> None: # type: ignore assert calc.acc.collected_samples == num_samples # type: ignore -def run_across_epochs(case1, case2) -> None: # type: ignore - (batch11, batch12), (gt_metrics1, k1) = case1 - (batch21, batch22), (gt_metrics2, k2) = case2 - assert k1 == k2 +def run_across_epochs(case) -> None: # type: ignore + dataset, (gt_metrics, k) = case - top_k = (k1,) + top_k = (k,) + num_samples = len(dataset) calc = EmbeddingMetrics( - embeddings_key=EMBEDDINGS_KEY, + dataset=dataset, + embeddings_key=dataset.input_tensors_key, labels_key=LABELS_KEY, is_query_key=IS_QUERY_KEY, is_gallery_key=IS_GALLERY_KEY, @@ -203,32 +176,22 @@ def run_across_epochs(case1, case2) -> None: # type: ignore postprocessor=get_trivial_postprocessor(top_n=3), ) - def epoch_case(batch_a, batch_b, ground_truth_metrics) -> None: # type: ignore - num_samples = len(batch_a[LABELS_KEY]) + len(batch_b[LABELS_KEY]) - calc.setup(num_samples=num_samples) - calc.update_data(batch_a) - calc.update_data(batch_b) - metrics = calc.compute_metrics() - - compare_dicts_recursively(metrics, ground_truth_metrics) + metrics_all_epochs = [] - # the euclidean distance between any one-hots is always sqrt(2) or 0 - assert compare_tensors_as_sets(calc.distance_matrix, torch.tensor([0, math.sqrt(2)])) # type: ignore + for _ in range(2): # epochs + calc.setup(num_samples=num_samples) - assert (calc.mask_gt.unique() == torch.tensor([0, 1])).all() # type: ignore - assert calc.acc.collected_samples == num_samples + for batch in DataLoader(dataset, batch_size=2, num_workers=0, shuffle=False, drop_last=False): + calc.update_data(batch) - # 1st epoch - epoch_case(batch11, batch12, gt_metrics1) + metrics_all_epochs.append(calc.compute_metrics()) - # 2nd epoch - epoch_case(batch21, batch22, gt_metrics2) + assert compare_dicts_recursively(metrics_all_epochs[0], metrics_all_epochs[-1]) - # 3d epoch - epoch_case(batch11, batch12, gt_metrics1) + # the euclidean distance between any one-hots is always sqrt(2) or 0 + assert compare_tensors_as_sets(calc.distance_matrix, torch.tensor([0, math.sqrt(2)])) - # 4th epoch - epoch_case(batch21, batch22, gt_metrics2) + assert calc.acc.collected_samples == num_samples def test_perfect_case(perfect_case) -> None: # type: ignore @@ -243,37 +206,37 @@ def test_worst_case(worst_case) -> None: # type: ignore run_retrieval_metrics(worst_case) -def test_mixed_epochs(perfect_case, imperfect_case, worst_case): # type: ignore - cases = [perfect_case, imperfect_case, worst_case] - for case1 in cases: - for case2 in cases: - run_across_epochs(case1, case2) +def test_several_epochs(perfect_case, imperfect_case, worst_case): # type: ignore + run_across_epochs(perfect_case) + run_across_epochs(imperfect_case) + run_across_epochs(worst_case) -def test_worst_k(case_for_distance_check) -> None: # type: ignore - (batch1, batch2), gt_ids = case_for_distance_check +def test_worst_k(case_for_finding_worst_queries) -> None: # type: ignore + dataset, worst_query_ids = case_for_finding_worst_queries - num_samples = len(batch1[LABELS_KEY]) + len(batch2[LABELS_KEY]) + num_samples = len(dataset) calc = EmbeddingMetrics( - embeddings_key=EMBEDDINGS_KEY, + dataset=dataset, + embeddings_key=dataset.input_tensors_key, labels_key=LABELS_KEY, is_query_key=IS_QUERY_KEY, is_gallery_key=IS_GALLERY_KEY, categories_key=CATEGORIES_KEY, - cmc_top_k=(), + cmc_top_k=(1,), precision_top_k=(), - map_top_k=(2,), + map_top_k=(), fmr_vals=tuple(), postprocessor=get_trivial_postprocessor(top_n=1_000), ) calc.setup(num_samples=num_samples) - calc.update_data(batch1) - calc.update_data(batch2) + for batch in DataLoader(dataset, batch_size=4, shuffle=False): + calc.update_data(batch) calc.compute_metrics() - assert calc.get_worst_queries_ids(f"{OVERALL_CATEGORIES_KEY}/map/2", 3) == gt_ids + assert set(calc.get_worst_queries_ids(f"{OVERALL_CATEGORIES_KEY}/cmc/1", 2)) == worst_query_ids @pytest.mark.parametrize("extra_keys", [[], [PATHS_KEY], [PATHS_KEY, "a"], ["a"]]) diff --git a/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py b/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py index cd8e9715f..4cb920d24 100644 --- a/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py +++ b/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py @@ -1,4 +1,3 @@ -import math from functools import partial from random import randint, random from typing import Tuple @@ -7,88 +6,94 @@ import torch from torch import Tensor -from oml.functional.metrics import calc_distance_matrix, calc_retrieval_metrics +from oml.functional.metrics import calc_retrieval_metrics +from oml.interfaces.datasets import IQueryGalleryDataset, IQueryGalleryLabeledDataset from oml.interfaces.models import IPairwiseModel from oml.models.meta.siamese import LinearTrivialDistanceSiamese from oml.retrieval.postprocessors.pairwise import PairwiseReranker from oml.utils.misc import flatten_dict, one_hot from oml.utils.misc_torch import normalise, pairwise_dist +from tests.test_integrations.utils import ( + EmbeddingsQueryGalleryDataset, + EmbeddingsQueryGalleryLabeledDataset, +) FEAT_SIZE = 8 oh = partial(one_hot, dim=FEAT_SIZE) @pytest.fixture -def independent_query_gallery_case() -> Tuple[Tensor, Tensor, Tensor]: +def independent_query_gallery_case() -> Tuple[IQueryGalleryDataset, Tensor]: sz = 7 feat_dim = 12 - embeddings = torch.randn((sz, feat_dim)) - embeddings = normalise(embeddings) - is_query = torch.ones(sz).bool() is_query[: sz // 2] = False is_gallery = torch.ones(sz).bool() is_gallery[sz // 2 :] = False - return embeddings, is_query, is_gallery + embeddings = normalise(torch.randn((sz, feat_dim))).float() + + dataset = EmbeddingsQueryGalleryDataset(embeddings=embeddings, is_query=is_query, is_gallery=is_gallery) + + embeddings_inference = embeddings.clone() # pretend it's our inference + + return dataset, embeddings_inference @pytest.fixture -def shared_query_gallery_case() -> Tuple[Tensor, Tensor, Tensor]: +def shared_query_gallery_case() -> Tuple[IQueryGalleryDataset, Tensor]: sz = 7 feat_dim = 4 - embeddings = torch.randn((sz, feat_dim)) - embeddings = normalise(embeddings) + embeddings = normalise(torch.randn((sz, feat_dim))).float() - is_query = torch.ones(sz).bool() - is_gallery = torch.ones(sz).bool() + dataset = EmbeddingsQueryGalleryDataset( + embeddings=embeddings, is_query=torch.ones(sz).bool(), is_gallery=torch.ones(sz).bool() + ) + + embeddings_inference = embeddings.clone() # pretend it's our inference - return embeddings, is_query, is_gallery + return dataset, embeddings_inference @pytest.mark.long @pytest.mark.parametrize("top_n", [2, 5, 100]) +@pytest.mark.parametrize("pairwise_distances_bias", [0, 100]) @pytest.mark.parametrize("fixture_name", ["independent_query_gallery_case", "shared_query_gallery_case"]) def test_trivial_processing_does_not_change_distances_order( - request: pytest.FixtureRequest, fixture_name: str, top_n: int + request: pytest.FixtureRequest, fixture_name: str, top_n: int, pairwise_distances_bias: float ) -> None: - embeddings, is_query, is_gallery = request.getfixturevalue(fixture_name) - embeddings_query = embeddings[is_query] - embeddings_gallery = embeddings[is_gallery] + dataset, embeddings = request.getfixturevalue(fixture_name) - distances = calc_distance_matrix(embeddings, is_query, is_gallery) + distances = pairwise_dist(x1=embeddings[dataset.get_query_ids()], x2=embeddings[dataset.get_gallery_ids()], p=2) - model = LinearTrivialDistanceSiamese(feat_dim=embeddings.shape[-1], identity_init=True) + model = LinearTrivialDistanceSiamese(embeddings.shape[-1], output_bias=pairwise_distances_bias, identity_init=True) processor = PairwiseReranker(pairwise_model=model, top_n=top_n, num_workers=0, batch_size=64) - distances_processed = processor.process( - queries=embeddings_query, - galleries=embeddings_gallery, - distances=distances.clone(), - ) - - order = distances.argsort() - order_processed = distances_processed.argsort() + distances_processed = processor.process(distances=distances.clone(), dataset=dataset) - assert (order == order_processed).all(), (order, order_processed) + if pairwise_distances_bias == 0: + assert torch.allclose(distances_processed, distances) + else: + assert (distances_processed.argsort() == distances.argsort()).all() + assert not torch.allclose(distances_processed, distances) - if top_n <= is_gallery.sum(): - min_orig_distances = torch.topk(distances, k=top_n, largest=False).values - min_processed_distances = torch.topk(distances_processed, k=top_n, largest=False).values - assert torch.allclose(min_orig_distances, min_processed_distances) +def perfect_case() -> Tuple[IQueryGalleryLabeledDataset, Tensor]: + embeddings = torch.stack([oh(1), oh(2), oh(3), oh(1), oh(2), oh(1), oh(2), oh(3)]).float() -def perfect_case() -> Tuple[Tensor, Tensor, Tensor, Tensor]: - query_labels = torch.tensor([1, 2, 3]).long() - query_embeddings = torch.stack([oh(1), oh(2), oh(3)]) + dataset = EmbeddingsQueryGalleryLabeledDataset( + embeddings=embeddings, + labels=torch.tensor([1, 2, 3, 1, 2, 1, 2, 3]).long(), + is_query=torch.tensor([1, 1, 1, 1, 0, 0, 0, 0]).bool(), + is_gallery=torch.tensor([0, 0, 0, 0, 1, 1, 1, 1]).bool(), + ) - gallery_labels = torch.tensor([1, 2, 1, 2, 3]).long() - gallery_embeddings = torch.stack([oh(1), oh(2), oh(1), oh(2), oh(3)]) + embeddings_inference = embeddings.clone() - return query_embeddings, gallery_embeddings, query_labels, gallery_labels + return dataset, embeddings_inference @pytest.mark.long @@ -105,9 +110,12 @@ def test_trivial_processing_fixes_broken_perfect_case() -> None: n_repetitions = 20 for _ in range(n_repetitions): - query_embeddings, gallery_embeddings, query_labels, gallery_labels = perfect_case() - distances = pairwise_dist(query_embeddings, gallery_embeddings) - mask_gt = query_labels.unsqueeze(-1) == gallery_labels + dataset, embeddings = perfect_case() + distances = pairwise_dist(embeddings[dataset.get_query_ids()], embeddings[dataset.get_gallery_ids()], p=2) + + labels_q = torch.tensor(dataset.get_labels()[dataset.get_query_ids()]) + labels_g = torch.tensor(dataset.get_labels()[dataset.get_gallery_ids()]) + mask_gt = labels_q.unsqueeze(-1) == labels_g nq, ng = distances.shape @@ -127,9 +135,9 @@ def test_trivial_processing_fixes_broken_perfect_case() -> None: metrics = flatten_dict(calc_retrieval_metrics(distances=distances, **args)) # Metrics after broken distances have been fixed - model = LinearTrivialDistanceSiamese(feat_dim=gallery_embeddings.shape[-1], identity_init=True) + model = LinearTrivialDistanceSiamese(feat_dim=embeddings.shape[-1], identity_init=True, output_bias=10) processor = PairwiseReranker(pairwise_model=model, top_n=top_n, batch_size=16, num_workers=0) - distances_upd = processor.process(distances, query_embeddings, gallery_embeddings) + distances_upd = processor.process(distances, dataset) metrics_upd = flatten_dict(calc_retrieval_metrics(distances=distances_upd, **args)) for key in metrics.keys(): @@ -138,46 +146,6 @@ def test_trivial_processing_fixes_broken_perfect_case() -> None: assert metric_upd >= metric, (key, metric, metric_upd) -class DummyPairwise(IPairwiseModel): - def __init__(self, distances_to_return: Tensor): - super(DummyPairwise, self).__init__() - self.distances_to_return = distances_to_return - self.parameter = torch.nn.Linear(1, 1) - - def forward(self, x1: Tensor, x2: Tensor) -> Tensor: - return self.distances_to_return - - def predict(self, x1: Tensor, x2: Tensor) -> Tensor: - return self.distances_to_return - - -@pytest.mark.long -def test_trivial_processing_fixes_broken_perfect_case_2() -> None: - """ - The idea of the test is similar to "test_trivial_processing_fixes_broken_perfect_case", - but this time we check the exact metrics values. - - """ - distances = torch.tensor([[0.8, 0.3, 0.2, 0.4, 0.5]]) - mask_gt = torch.tensor([[1, 1, 0, 1, 0]]).bool() - - args = {"mask_gt": mask_gt, "precision_top_k": (1, 3)} - - precisions = calc_retrieval_metrics(distances=distances, **args)["precision"] - assert math.isclose(precisions[1], 0) - assert math.isclose(precisions[3], 2 / 3, abs_tol=1e-5) - - # Now let's fix the error with dummy pairwise model - model = DummyPairwise(distances_to_return=torch.tensor([3.5, 2.5])) - processor = PairwiseReranker(pairwise_model=model, top_n=2, batch_size=128, num_workers=0) - distances_upd = processor.process( - distances=distances, queries=torch.randn((1, FEAT_SIZE)), galleries=torch.randn((5, FEAT_SIZE)) - ) - precisions_upd = calc_retrieval_metrics(distances=distances_upd, **args)["precision"] - assert math.isclose(precisions_upd[1], 1) - assert math.isclose(precisions_upd[3], 2 / 3, abs_tol=1e-5) - - class RandomPairwise(IPairwiseModel): def __init__(self): # type: ignore super(RandomPairwise, self).__init__() @@ -196,13 +164,13 @@ def test_processing_not_changing_non_sensitive_metrics(top_n: int) -> None: # The idea of the test is that postprocessing of first n elements # cannot change cmc@n and precision@n - # Let's construct some random input - query_embeddings_perfect, gallery_embeddings_perfect, query_labels, gallery_labels = perfect_case() - query_embeddings = torch.rand_like(query_embeddings_perfect) - gallery_embeddings = torch.rand_like(gallery_embeddings_perfect) - mask_gt = query_labels.unsqueeze(-1) == gallery_labels + dataset, embeddings = perfect_case() + + distances = pairwise_dist(embeddings[dataset.get_query_ids()], embeddings[dataset.get_gallery_ids()], p=2) - distances = pairwise_dist(query_embeddings, gallery_embeddings) + labels_q = torch.tensor(dataset.get_labels()[dataset.get_query_ids()]) + labels_g = torch.tensor(dataset.get_labels()[dataset.get_gallery_ids()]) + mask_gt = labels_q.unsqueeze(-1) == labels_g args = { "cmc_top_k": (top_n,), @@ -216,7 +184,7 @@ def test_processing_not_changing_non_sensitive_metrics(top_n: int) -> None: model = RandomPairwise() processor = PairwiseReranker(pairwise_model=model, top_n=top_n, batch_size=4, num_workers=0) - distances_upd = processor.process(distances=distances, queries=query_embeddings, galleries=gallery_embeddings) + distances_upd = processor.process(distances=distances, dataset=dataset) metrics_after = calc_retrieval_metrics(distances=distances_upd, **args) diff --git a/tests/test_oml/test_postprocessor/test_pairwise_images.py b/tests/test_oml/test_postprocessor/test_pairwise_images.py index a568c946f..45ca9bf1c 100644 --- a/tests/test_oml/test_postprocessor/test_pairwise_images.py +++ b/tests/test_oml/test_postprocessor/test_pairwise_images.py @@ -50,9 +50,7 @@ def test_trivial_processing_does_not_change_distances_order(top_n: int, pairwise distances_processed = postprocessor.process(distances=distances.clone(), dataset=dataset) if pairwise_distances_bias == 0: - # distances are literally the same assert torch.allclose(distances_processed.argsort(), distances.argsort()) else: - # since pairwise distances have been shifted, the relative order remain the same, but the values are different - assert torch.allclose(distances_processed.argsort(), distances.argsort()) + assert (distances_processed.argsort() == distances.argsort()).all() assert not torch.allclose(distances_processed, distances) From db2d7e45e1d2aac462606360b01031dd2989ce92 Mon Sep 17 00:00:00 2001 From: alekseysh Date: Sat, 20 Apr 2024 23:31:23 +0700 Subject: [PATCH 12/23] upd --- tests/test_runs/test_code_from_markdown.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_runs/test_code_from_markdown.py b/tests/test_runs/test_code_from_markdown.py index 81260ed58..9c352a32c 100644 --- a/tests/test_runs/test_code_from_markdown.py +++ b/tests/test_runs/test_code_from_markdown.py @@ -30,10 +30,11 @@ def find_code_block(file: Path, start_indicator: str, end_indicator: str) -> str ("extractor/train_val_pl.md", "[comment]:lightning-start\n", "[comment]:lightning-end\n"), ("extractor/train_val_pl_ddp.md", "[comment]:lightning-ddp-start\n", "[comment]:lightning-ddp-end\n"), ("extractor/train_2loaders_val.md", "[comment]:lightning-2loaders-start\n", "[comment]:lightning-2loaders-end\n"), # noqa - ("extractor/retrieval_usage.md", "[comment]:usage-retrieval-start\n", "[comment]:usage-retrieval-end\n"), ("zoo/models_usage.md", "[comment]:zoo-start\n", "[comment]:zoo-end\n"), - ("postprocessing/train_val.md", "[comment]:postprocessor-start\n", "[comment]:postprocessor-end\n"), - ("postprocessing/predict.md", "[comment]:postprocessor-pred-start\n", "[comment]:postprocessor-pred-end\n"), + # todo 522: update this examples affected by reworking inference functions and reworking re-ranker + # ("extractor/retrieval_usage.md", "[comment]:usage-retrieval-start\n", "[comment]:usage-retrieval-end\n"), + # ("postprocessing/train_val.md", "[comment]:postprocessor-start\n", "[comment]:postprocessor-end\n"), + # ("postprocessing/predict.md", "[comment]:postprocessor-pred-start\n", "[comment]:postprocessor-pred-end\n"), ], ) # fmt: skip def test_code_blocks_in_readme(fname: str, start_indicator: str, end_indicator: str) -> None: From bb216280812e39c2501491ae8e32094c339047cb Mon Sep 17 00:00:00 2001 From: alekseysh Date: Sun, 21 Apr 2024 07:13:41 +0700 Subject: [PATCH 13/23] minor --- oml/datasets/images.py | 1 - 1 file changed, 1 deletion(-) diff --git a/oml/datasets/images.py b/oml/datasets/images.py index 6ce966d33..e7bae7635 100644 --- a/oml/datasets/images.py +++ b/oml/datasets/images.py @@ -450,7 +450,6 @@ def get_retrieval_images_datasets( check_retrieval_dataframe_format(df, dataset_root=dataset_root, verbose=verbose) - # todo 522: why do we need it? # first half will consist of "train" split, second one of "val" # so labels in train will be from 0 to N-1 and labels in test will be from N to K mapper = {l: i for i, l in enumerate(df.sort_values(by=[SPLIT_COLUMN])[LABELS_COLUMN].unique())} From 6212be9c91e818d4bbe5af5b1b8c9bf0cdaeeb76 Mon Sep 17 00:00:00 2001 From: alekseysh Date: Sun, 21 Apr 2024 22:34:56 +0700 Subject: [PATCH 14/23] addressed comments and introduced IIndexedDataset --- oml/retrieval/postprocessors/pairwise.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/oml/retrieval/postprocessors/pairwise.py b/oml/retrieval/postprocessors/pairwise.py index 0ac1cd839..e5d41a5e8 100644 --- a/oml/retrieval/postprocessors/pairwise.py +++ b/oml/retrieval/postprocessors/pairwise.py @@ -68,7 +68,6 @@ def process_neigh( self, distances: Tensor, retrieved_ids: Tensor, dataset: IQueryGalleryDataset ) -> Tuple[Tensor, Tensor]: """ - Note, the new distances to the ``top_n`` items produced by the pairwise model may be adjusted to remain distances sorted. Here is an example: ``original_distances = [0.1, 0.2, 0.3, 0.5, 0.6], top_n = 3`` @@ -81,6 +80,8 @@ def process_neigh( If concatenation of two distances is already sorted, we keep it untouched. """ + # todo 522: explain what's going on here + top_n = min(self.top_n, distances.shape[1]) # let's list pairs of (query_i, gallery_j) we need to process From 27369223f697f1bbf2d2d2cd437929e40e516c76 Mon Sep 17 00:00:00 2001 From: alekseysh Date: Mon, 22 Apr 2024 01:02:42 +0700 Subject: [PATCH 15/23] updated examples --- .../extractor/retrieval_usage.md | 25 +++++----- .../examples_source/postprocessing/predict.md | 40 +++++++--------- .../postprocessing/train_val.md | 48 +++++++++---------- oml/datasets/__init__.py | 1 + oml/datasets/images.py | 10 ++++ oml/interfaces/datasets.py | 2 +- oml/retrieval/postprocessors/pairwise.py | 4 +- oml/utils/download_mock_dataset.py | 8 +++- tests/test_runs/test_code_from_markdown.py | 7 ++- 9 files changed, 79 insertions(+), 66 deletions(-) diff --git a/docs/readme/examples_source/extractor/retrieval_usage.md b/docs/readme/examples_source/extractor/retrieval_usage.md index 57954b738..a514bd69e 100644 --- a/docs/readme/examples_source/extractor/retrieval_usage.md +++ b/docs/readme/examples_source/extractor/retrieval_usage.md @@ -6,24 +6,24 @@ ```python import torch -from oml.const import MOCK_DATASET_PATH -from oml.inference.flat import inference_on_images +from oml.datasets import ImageQueryGalleryDataset +from oml.inference import inference from oml.models import ViTExtractor from oml.registry.transforms import get_transforms_for_pretrained from oml.utils.download_mock_dataset import download_mock_dataset from oml.utils.misc_torch import pairwise_dist -_, df_val = download_mock_dataset(MOCK_DATASET_PATH) -df_val["path"] = df_val["path"].apply(lambda x: MOCK_DATASET_PATH / x) -queries = df_val[df_val["is_query"]]["path"].tolist() -galleries = df_val[df_val["is_gallery"]]["path"].tolist() +_, df_test = download_mock_dataset(global_paths=True) +del df_test["label"] # we don't need gt labels for doing predictions extractor = ViTExtractor.from_pretrained("vits16_dino") transform, _ = get_transforms_for_pretrained("vits16_dino") -args = {"num_workers": 0, "batch_size": 8} -features_queries = inference_on_images(extractor, paths=queries, transform=transform, **args) -features_galleries = inference_on_images(extractor, paths=galleries, transform=transform, **args) +dataset = ImageQueryGalleryDataset(df_test, transform=transform) + +embeddings = inference(extractor, dataset, batch_size=4, num_workers=0) +embeddings_query = embeddings[dataset.get_query_ids()] +embeddings_gallery = embeddings[dataset.get_gallery_ids()] # Now we can explicitly build pairwise matrix of distances or save you RAM via using kNN use_knn = False @@ -31,12 +31,11 @@ top_k = 3 if use_knn: from sklearn.neighbors import NearestNeighbors - knn = NearestNeighbors(algorithm="auto", p=2) - knn.fit(features_galleries) - dists, ii_closest = knn.kneighbors(features_queries, n_neighbors=top_k, return_distance=True) + knn = NearestNeighbors(algorithm="auto", p=2).fit(embeddings_query) + dists, ii_closest = knn.kneighbors(embeddings_gallery, n_neighbors=top_k, return_distance=True) else: - dist_mat = pairwise_dist(x1=features_queries, x2=features_galleries) + dist_mat = pairwise_dist(x1=embeddings_query, x2=embeddings_gallery, p=2) dists, ii_closest = torch.topk(dist_mat, dim=1, k=top_k, largest=False) print(f"Top {top_k} items closest to queries are:\n {ii_closest}") diff --git a/docs/readme/examples_source/postprocessing/predict.md b/docs/readme/examples_source/postprocessing/predict.md index 0c6edeafc..35fb69604 100644 --- a/docs/readme/examples_source/postprocessing/predict.md +++ b/docs/readme/examples_source/postprocessing/predict.md @@ -5,44 +5,40 @@ [comment]:postprocessor-pred-start ```python import torch -from torch.utils.data import DataLoader -from oml.const import PATHS_COLUMN -from oml.datasets.base import DatasetQueryGallery -from oml.inference.flat import inference_on_dataframe +from oml.datasets import ImageQueryGalleryDataset +from oml.inference import inference from oml.models import ConcatSiamese, ViTExtractor from oml.registry.transforms import get_transforms_for_pretrained from oml.retrieval.postprocessors.pairwise import PairwiseReranker from oml.utils.download_mock_dataset import download_mock_dataset from oml.utils.misc_torch import pairwise_dist -dataset_root = "mock_dataset/" -download_mock_dataset(dataset_root) +_, df_test = download_mock_dataset(global_paths=True) +del df_test["label"] # we don't need gt labels for doing predictions -# 1. Let's use feature extractor to get predictions extractor = ViTExtractor.from_pretrained("vits16_dino") transforms, _ = get_transforms_for_pretrained("vits16_dino") -_, emb_val, _, df_val = inference_on_dataframe(dataset_root, "df.csv", extractor, transforms=transforms) +dataset = ImageQueryGalleryDataset(df_test, transform=transforms) -is_query = df_val["is_query"].astype('bool').values -distances = pairwise_dist(x1=emb_val[is_query], x2=emb_val[~is_query]) +# 1. Let's get top 5 galleries closest to every query... +embeddings = inference(extractor, dataset, batch_size=4, num_workers=0) +embeddings_query = embeddings[dataset.get_query_ids()] +embeddings_gallery = embeddings[dataset.get_gallery_ids()] -print("\nOriginal predictions:\n", torch.topk(distances, dim=1, k=3, largest=False)[1]) +distances = pairwise_dist(x1=embeddings_query, x2=embeddings_gallery, p=2) +ii_closest = torch.topk(distances, dim=1, k=5, largest=False)[1] -# 2. Let's initialise a random pairwise postprocessor to perform re-ranking +# 2. ... and let's re-rank first 3 of them siamese = ConcatSiamese(extractor=extractor, mlp_hidden_dims=[100]) # Note! Replace it with your trained postprocessor -postprocessor = PairwiseReranker(top_n=3, pairwise_model=siamese, transforms=transforms) - -dataset = DatasetQueryGallery(df_val, extra_data={"embeddings": emb_val}, transform=transforms) -loader = DataLoader(dataset, batch_size=4) - -query_paths = df_val[PATHS_COLUMN][is_query].values -gallery_paths = df_val[PATHS_COLUMN][~is_query].values -distances_upd = postprocessor.process(distances=distances, queries=query_paths, galleries=gallery_paths) - -print("\nPredictions after postprocessing:\n", torch.topk(distances_upd, dim=1, k=3, largest=False)[1]) +postprocessor = PairwiseReranker(top_n=3, pairwise_model=siamese, batch_size=4, num_workers=0) +distances_upd = postprocessor.process(distances, dataset=dataset) +ii_closest_upd = torch.topk(distances_upd, dim=1, k=5, largest=False)[1] +# You may see the first 3 positions have changed, but the rest remain the same: +print("\Closest galleries:\n", ii_closest) +print("\nClosest galleries updates:\n", ii_closest_upd) ``` [comment]:postprocessor-pred-end

diff --git a/docs/readme/examples_source/postprocessing/train_val.md b/docs/readme/examples_source/postprocessing/train_val.md index f2ee90382..8384a6dd7 100644 --- a/docs/readme/examples_source/postprocessing/train_val.md +++ b/docs/readme/examples_source/postprocessing/train_val.md @@ -10,52 +10,52 @@ import torch from torch.nn import BCEWithLogitsLoss from torch.utils.data import DataLoader -from oml.datasets.base import DatasetWithLabels, DatasetQueryGallery -from oml.inference.flat import inference_on_dataframe +from oml.datasets import ImageLabeledDataset, ImageQueryGalleryLabeledDataset, ImageBaseDataset +from oml.inference import inference from oml.metrics.embeddings import EmbeddingMetrics from oml.miners.pairs import PairsMiner from oml.models import ConcatSiamese, ViTExtractor +from oml.registry.transforms import get_transforms_for_pretrained from oml.retrieval.postprocessors.pairwise import PairwiseReranker from oml.samplers.balance import BalanceSampler -from oml.transforms.images.torchvision import get_normalisation_resize_torch from oml.utils.download_mock_dataset import download_mock_dataset +from oml.transforms.images.torchvision import get_augs_torch -# Let's start with saving embeddings of a pretrained extractor for which we want to build a postprocessor -dataset_root = "mock_dataset/" -download_mock_dataset(dataset_root) +# In these example we will train a pairwise model as a re-ranker for ViT +extractor = ViTExtractor.from_pretrained("vits16_dino") +transforms, _ = get_transforms_for_pretrained("vits16_dino") +df_train, df_val = download_mock_dataset(global_paths=True) -extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False) -transform = get_normalisation_resize_torch(im_size=64) +# SAVE VIT EMBEDDINGS +# - training ones are needed for hard negative sampling when training pairwise model +# - validation ones are needed to construct the original prediction (which we will re-rank) +embeddings_train = inference(extractor, ImageBaseDataset(df_train["path"].tolist(), transform=transforms), batch_size=4, num_workers=0) +embeddings_valid = inference(extractor, ImageBaseDataset(df_val["path"].tolist(), transform=transforms), batch_size=4, num_workers=0) -embeddings_train, embeddings_val, df_train, df_val = \ - inference_on_dataframe(dataset_root, "df.csv", extractor=extractor, transforms=transform) - -# We are building Siamese model on top of existing weights and train it to recognize positive/negative pairs -siamese = ConcatSiamese(extractor=extractor, mlp_hidden_dims=[100]) -optimizer = torch.optim.SGD(siamese.parameters(), lr=1e-6) +# TRAIN PAIRWISE MODEL +train_dataset = ImageLabeledDataset(df_train, transform=get_augs_torch(224), extra_data={"embeddings": embeddings_train}) +pairwise_model = ConcatSiamese(extractor=extractor, mlp_hidden_dims=[100]) +optimizer = torch.optim.SGD(pairwise_model.parameters(), lr=1e-6) miner = PairsMiner(hard_mining=True) criterion = BCEWithLogitsLoss() -train_dataset = DatasetWithLabels(df=df_train, transform=transform, extra_data={"embeddings": embeddings_train}) -batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=2) -train_loader = DataLoader(train_dataset, batch_sampler=batch_sampler) +train_loader = DataLoader(train_dataset, batch_sampler=BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=2)) for batch in train_loader: - # We sample pairs on which the original model struggled most + # We sample positive and negative pairs on which the original model struggled most ids1, ids2, is_negative_pair = miner.sample(features=batch["embeddings"], labels=batch["labels"]) - probs = siamese(x1=batch["input_tensors"][ids1], x2=batch["input_tensors"][ids2]) + probs = pairwise_model(x1=batch["input_tensors"][ids1], x2=batch["input_tensors"][ids2]) loss = criterion(probs, is_negative_pair.float()) - loss.backward() optimizer.step() optimizer.zero_grad() -# Siamese re-ranks top-n retrieval outputs of the original model performing inference on pairs (query, output_i) -val_dataset = DatasetQueryGallery(df=df_val, extra_data={"embeddings": embeddings_val}, transform=transform) +# VALIDATE RE-RANKING MODEL +val_dataset = ImageQueryGalleryLabeledDataset(df=df_val, transform=transforms, extra_data={"embeddings": embeddings_valid}) valid_loader = DataLoader(val_dataset, batch_size=4, shuffle=False) -postprocessor = PairwiseReranker(top_n=3, pairwise_model=siamese, transforms=transform) -calculator = EmbeddingMetrics(postprocessor=postprocessor) +postprocessor = PairwiseReranker(top_n=3, pairwise_model=pairwise_model, num_workers=0, batch_size=4) +calculator = EmbeddingMetrics(dataset=val_dataset, postprocessor=postprocessor) calculator.setup(num_samples=len(val_dataset)) for batch in valid_loader: diff --git a/oml/datasets/__init__.py b/oml/datasets/__init__.py index 538f3360e..682b6a93b 100644 --- a/oml/datasets/__init__.py +++ b/oml/datasets/__init__.py @@ -1,6 +1,7 @@ from oml.datasets.images import ( ImageBaseDataset, ImageLabeledDataset, + ImageQueryGalleryDataset, ImageQueryGalleryLabeledDataset, ) from oml.datasets.pairs import PairDataset diff --git a/oml/datasets/images.py b/oml/datasets/images.py index d51d8044a..adbd58d4f 100644 --- a/oml/datasets/images.py +++ b/oml/datasets/images.py @@ -421,6 +421,13 @@ def __init__( is_gallery_key=is_gallery_key, ) + self.input_tensors_key = self.__dataset.input_tensors_key + self.index_key = self.__dataset.index_key + + # todo 522: remove + self.is_query_key = self.__dataset.is_query_key + self.is_gallery_key = self.__dataset.is_gallery_key + def __getitem__(self, item: int) -> Dict[str, Any]: batch = self.__dataset[item] del batch[self.__dataset.labels_key] @@ -435,6 +442,9 @@ def get_gallery_ids(self) -> LongTensor: def visualize(self, item: int, color: TColor = BLACK) -> np.ndarray: return self.__dataset.visualize(item, color) + def __len__(self) -> int: + return len(self.__dataset) + def get_retrieval_images_datasets( dataset_root: Path, diff --git a/oml/interfaces/datasets.py b/oml/interfaces/datasets.py index 66b0db274..e5a967443 100644 --- a/oml/interfaces/datasets.py +++ b/oml/interfaces/datasets.py @@ -96,7 +96,7 @@ class IQueryGalleryLabeledDataset(IQueryGalleryDataset, ILabeledDataset, ABC): """ -class IPairDataset(Dataset, IIndexedDataset): +class IPairDataset(IIndexedDataset): """ This is an interface for the datasets which return pair of something. diff --git a/oml/retrieval/postprocessors/pairwise.py b/oml/retrieval/postprocessors/pairwise.py index e5d41a5e8..e82c6e841 100644 --- a/oml/retrieval/postprocessors/pairwise.py +++ b/oml/retrieval/postprocessors/pairwise.py @@ -59,7 +59,9 @@ def process(self, distances: Tensor, dataset: IQueryGalleryDataset) -> Tensor: # todo 522: # after we introduce RetrievalPrediction the signature of the method will change: so, we directly call # self.process_neigh. Thus, the code below is temporary to support the current interface. - distances_neigh, ii_neigh = torch.topk(distances, k=min(distances.shape[1], self.top_n)) + distances_neigh, ii_neigh = torch.topk( + distances, k=min(distances.shape[1], self.top_n), largest=False + ) # todo 522: test it!!! distances_neigh_upd, ii_neigh_upd = self.process_neigh(distances_neigh, ii_neigh, dataset) distances_upd = assign_2d(x=distances, indices=ii_neigh_upd, new_values=distances_neigh_upd) return distances_upd diff --git a/oml/utils/download_mock_dataset.py b/oml/utils/download_mock_dataset.py index eb193495e..202d3945c 100644 --- a/oml/utils/download_mock_dataset.py +++ b/oml/utils/download_mock_dataset.py @@ -17,7 +17,10 @@ def get_argparser() -> ArgumentParser: def download_mock_dataset( - dataset_root: Union[str, Path], check_md5: bool = True, df_name: str = "df.csv" + dataset_root: Union[str, Path] = MOCK_DATASET_PATH, + check_md5: bool = True, + df_name: str = "df.csv", + global_paths: bool = False, ) -> Tuple[pd.DataFrame, pd.DataFrame]: """ Function to download mock dataset which is already prepared in the required format. @@ -48,6 +51,9 @@ def download_mock_dataset( df = pd.read_csv(Path(dataset_root) / df_name) + if global_paths: + df["path"] = df["path"].apply(lambda x: str(Path(dataset_root) / x)) + df_train = df[df["split"] == "train"].reset_index(drop=True) df_val = df[df["split"] == "validation"].reset_index(drop=True) diff --git a/tests/test_runs/test_code_from_markdown.py b/tests/test_runs/test_code_from_markdown.py index 9c352a32c..4621484a8 100644 --- a/tests/test_runs/test_code_from_markdown.py +++ b/tests/test_runs/test_code_from_markdown.py @@ -31,10 +31,9 @@ def find_code_block(file: Path, start_indicator: str, end_indicator: str) -> str ("extractor/train_val_pl_ddp.md", "[comment]:lightning-ddp-start\n", "[comment]:lightning-ddp-end\n"), ("extractor/train_2loaders_val.md", "[comment]:lightning-2loaders-start\n", "[comment]:lightning-2loaders-end\n"), # noqa ("zoo/models_usage.md", "[comment]:zoo-start\n", "[comment]:zoo-end\n"), - # todo 522: update this examples affected by reworking inference functions and reworking re-ranker - # ("extractor/retrieval_usage.md", "[comment]:usage-retrieval-start\n", "[comment]:usage-retrieval-end\n"), - # ("postprocessing/train_val.md", "[comment]:postprocessor-start\n", "[comment]:postprocessor-end\n"), - # ("postprocessing/predict.md", "[comment]:postprocessor-pred-start\n", "[comment]:postprocessor-pred-end\n"), + ("extractor/retrieval_usage.md", "[comment]:usage-retrieval-start\n", "[comment]:usage-retrieval-end\n"), + ("postprocessing/train_val.md", "[comment]:postprocessor-start\n", "[comment]:postprocessor-end\n"), + ("postprocessing/predict.md", "[comment]:postprocessor-pred-start\n", "[comment]:postprocessor-pred-end\n"), ], ) # fmt: skip def test_code_blocks_in_readme(fname: str, start_indicator: str, end_indicator: str) -> None: From aedec8b9a7ae326f5aa258d25b388519ead614ef Mon Sep 17 00:00:00 2001 From: alekseysh Date: Mon, 22 Apr 2024 06:05:36 +0700 Subject: [PATCH 16/23] update --- README.md | 42 ++++++------- .../readme/examples_source/extractor/train.md | 5 +- .../extractor/train_2loaders_val.md | 9 +-- .../examples_source/extractor/train_val_pl.md | 7 +-- .../extractor/train_val_pl_ddp.md | 7 +-- .../extractor/train_with_pml.md | 5 +- .../extractor/train_with_pml_advanced.md | 5 +- docs/readme/examples_source/extractor/val.md | 5 +- .../extractor/val_with_sequence.md | 5 +- oml/retrieval/postprocessors/pairwise.py | 63 +++++++++++++------ oml/utils/download_mock_dataset.py | 1 + .../test_pairwise_embeddings.py | 33 ++++++---- .../test_pairwise_images.py | 9 +-- 13 files changed, 107 insertions(+), 89 deletions(-) diff --git a/README.md b/README.md index c20a2aa0e..60810a3fe 100644 --- a/README.md +++ b/README.md @@ -301,13 +301,12 @@ from oml.models import ViTExtractor from oml.samplers.balance import BalanceSampler from oml.utils.download_mock_dataset import download_mock_dataset -dataset_root = "mock_dataset/" -df_train, _ = download_mock_dataset(dataset_root) +df_train, _ = download_mock_dataset(global_paths=True) extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train() optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6) -train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root) +train_dataset = DatasetWithLabels(df_train) criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner(), need_logs=True) sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=2) train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=sampler) @@ -342,12 +341,11 @@ from oml.metrics.embeddings import EmbeddingMetrics from oml.models import ViTExtractor from oml.utils.download_mock_dataset import download_mock_dataset -dataset_root = "mock_dataset/" -_, df_val = download_mock_dataset(dataset_root) +_, df_val = download_mock_dataset(global_paths=True) extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).eval() -val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root) +val_dataset = DatasetQueryGallery(df_val) val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4) calculator = EmbeddingMetrics(extra_keys=("paths",)) @@ -401,21 +399,20 @@ from oml.lightning.pipelines.logging import ( WandBPipelineLogger, ) -dataset_root = "mock_dataset/" -df_train, df_val = download_mock_dataset(dataset_root) +df_train, df_val = download_mock_dataset(global_paths=True) # model extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False) # train optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6) -train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root) +train_dataset = DatasetWithLabels(df_train) criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner()) batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=3) train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=batch_sampler) # val -val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root) +val_dataset = DatasetQueryGallery(df_val) val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4) metric_callback = MetricValCallback(metric=EmbeddingMetrics(extra_keys=[train_dataset.paths_key,]), log_images=True) @@ -455,24 +452,24 @@ trainer.fit(pl_model, train_dataloaders=train_loader, val_dataloaders=val_loader ```python import torch -from oml.const import MOCK_DATASET_PATH -from oml.inference.flat import inference_on_images +from oml.datasets import ImageQueryGalleryDataset +from oml.inference import inference from oml.models import ViTExtractor from oml.registry.transforms import get_transforms_for_pretrained from oml.utils.download_mock_dataset import download_mock_dataset from oml.utils.misc_torch import pairwise_dist -_, df_val = download_mock_dataset(MOCK_DATASET_PATH) -df_val["path"] = df_val["path"].apply(lambda x: MOCK_DATASET_PATH / x) -queries = df_val[df_val["is_query"]]["path"].tolist() -galleries = df_val[df_val["is_gallery"]]["path"].tolist() +_, df_test = download_mock_dataset(global_paths=True) +del df_test["label"] # we don't need gt labels for doing predictions extractor = ViTExtractor.from_pretrained("vits16_dino") transform, _ = get_transforms_for_pretrained("vits16_dino") -args = {"num_workers": 0, "batch_size": 8} -features_queries = inference_on_images(extractor, paths=queries, transform=transform, **args) -features_galleries = inference_on_images(extractor, paths=galleries, transform=transform, **args) +dataset = ImageQueryGalleryDataset(df_test, transform=transform) + +embeddings = inference(extractor, dataset, batch_size=4, num_workers=0) +embeddings_query = embeddings[dataset.get_query_ids()] +embeddings_gallery = embeddings[dataset.get_gallery_ids()] # Now we can explicitly build pairwise matrix of distances or save you RAM via using kNN use_knn = False @@ -480,12 +477,11 @@ top_k = 3 if use_knn: from sklearn.neighbors import NearestNeighbors - knn = NearestNeighbors(algorithm="auto", p=2) - knn.fit(features_galleries) - dists, ii_closest = knn.kneighbors(features_queries, n_neighbors=top_k, return_distance=True) + knn = NearestNeighbors(algorithm="auto", p=2).fit(embeddings_query) + dists, ii_closest = knn.kneighbors(embeddings_gallery, n_neighbors=top_k, return_distance=True) else: - dist_mat = pairwise_dist(x1=features_queries, x2=features_galleries) + dist_mat = pairwise_dist(x1=embeddings_query, x2=embeddings_gallery, p=2) dists, ii_closest = torch.topk(dist_mat, dim=1, k=top_k, largest=False) print(f"Top {top_k} items closest to queries are:\n {ii_closest}") diff --git a/docs/readme/examples_source/extractor/train.md b/docs/readme/examples_source/extractor/train.md index 50145ff5a..01110eeca 100644 --- a/docs/readme/examples_source/extractor/train.md +++ b/docs/readme/examples_source/extractor/train.md @@ -14,13 +14,12 @@ from oml.models import ViTExtractor from oml.samplers.balance import BalanceSampler from oml.utils.download_mock_dataset import download_mock_dataset -dataset_root = "mock_dataset/" -df_train, _ = download_mock_dataset(dataset_root) +df_train, _ = download_mock_dataset(global_paths=True) extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train() optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6) -train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root) +train_dataset = DatasetWithLabels(df_train) criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner(), need_logs=True) sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=2) train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=sampler) diff --git a/docs/readme/examples_source/extractor/train_2loaders_val.md b/docs/readme/examples_source/extractor/train_2loaders_val.md index eb1676da6..c277b133d 100644 --- a/docs/readme/examples_source/extractor/train_2loaders_val.md +++ b/docs/readme/examples_source/extractor/train_2loaders_val.md @@ -15,21 +15,18 @@ from oml.models import ViTExtractor from oml.transforms.images.torchvision import get_normalisation_resize_torch from oml.utils.download_mock_dataset import download_mock_dataset -dataset_root = "mock_dataset/" -_, df_val = download_mock_dataset(dataset_root) +_, df_val = download_mock_dataset(global_paths=True) extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False) # 1st validation dataset (big images) -val_dataset_1 = DatasetQueryGallery(df_val, dataset_root=dataset_root, - transform=get_normalisation_resize_torch(im_size=224)) +val_dataset_1 = DatasetQueryGallery(df_val, transform=get_normalisation_resize_torch(im_size=224)) val_loader_1 = torch.utils.data.DataLoader(val_dataset_1, batch_size=4) metric_callback_1 = MetricValCallback(metric=EmbeddingMetrics(extra_keys=[val_dataset_1.paths_key,]), log_images=True, loader_idx=0) # 2nd validation dataset (small images) -val_dataset_2 = DatasetQueryGallery(df_val, dataset_root=dataset_root, - transform=get_normalisation_resize_torch(im_size=48)) +val_dataset_2 = DatasetQueryGallery(df_val, transform=get_normalisation_resize_torch(im_size=48)) val_loader_2 = torch.utils.data.DataLoader(val_dataset_2, batch_size=4) metric_callback_2 = MetricValCallback(metric=EmbeddingMetrics(extra_keys=[val_dataset_2.paths_key,]), log_images=True, loader_idx=1) diff --git a/docs/readme/examples_source/extractor/train_val_pl.md b/docs/readme/examples_source/extractor/train_val_pl.md index fc9764a97..c182bb675 100644 --- a/docs/readme/examples_source/extractor/train_val_pl.md +++ b/docs/readme/examples_source/extractor/train_val_pl.md @@ -24,21 +24,20 @@ from oml.lightning.pipelines.logging import ( WandBPipelineLogger, ) -dataset_root = "mock_dataset/" -df_train, df_val = download_mock_dataset(dataset_root) +df_train, df_val = download_mock_dataset(global_paths=True) # model extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False) # train optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6) -train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root) +train_dataset = DatasetWithLabels(df_train) criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner()) batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=3) train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=batch_sampler) # val -val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root) +val_dataset = DatasetQueryGallery(df_val) val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4) metric_callback = MetricValCallback(metric=EmbeddingMetrics(extra_keys=[train_dataset.paths_key,]), log_images=True) diff --git a/docs/readme/examples_source/extractor/train_val_pl_ddp.md b/docs/readme/examples_source/extractor/train_val_pl_ddp.md index 1aea03452..dc9cfd726 100644 --- a/docs/readme/examples_source/extractor/train_val_pl_ddp.md +++ b/docs/readme/examples_source/extractor/train_val_pl_ddp.md @@ -19,21 +19,20 @@ from oml.samplers.balance import BalanceSampler from oml.utils.download_mock_dataset import download_mock_dataset from pytorch_lightning.strategies import DDPStrategy -dataset_root = "mock_dataset/" -df_train, df_val = download_mock_dataset(dataset_root) +df_train, df_val = download_mock_dataset(global_paths=True) # model extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False) # train optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6) -train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root) +train_dataset = DatasetWithLabels(df_train) criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner()) batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=3) train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=batch_sampler) # val -val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root) +val_dataset = DatasetQueryGallery(df_val) val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4) metric_callback = MetricValCallbackDDP(metric=EmbeddingMetricsDDP()) # DDP specific diff --git a/docs/readme/examples_source/extractor/train_with_pml.md b/docs/readme/examples_source/extractor/train_with_pml.md index 138cc07ee..b5f249c82 100644 --- a/docs/readme/examples_source/extractor/train_with_pml.md +++ b/docs/readme/examples_source/extractor/train_with_pml.md @@ -13,13 +13,12 @@ from oml.utils.download_mock_dataset import download_mock_dataset from pytorch_metric_learning import losses, distances, reducers, miners -dataset_root = "mock_dataset/" -df_train, _ = download_mock_dataset(dataset_root) +df_train, _ = download_mock_dataset(global_paths=True) extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train() optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6) -train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root) +train_dataset = DatasetWithLabels(df_train) # PML specific # criterion = losses.TripletMarginLoss(margin=0.2, triplets_per_anchor="all") diff --git a/docs/readme/examples_source/extractor/train_with_pml_advanced.md b/docs/readme/examples_source/extractor/train_with_pml_advanced.md index c0b245253..33a27d59b 100644 --- a/docs/readme/examples_source/extractor/train_with_pml_advanced.md +++ b/docs/readme/examples_source/extractor/train_with_pml_advanced.md @@ -13,13 +13,12 @@ from oml.utils.download_mock_dataset import download_mock_dataset from pytorch_metric_learning import losses, distances, reducers, miners -dataset_root = "mock_dataset/" -df_train, _ = download_mock_dataset(dataset_root) +df_train, _ = download_mock_dataset(global_paths=True) extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train() optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6) -train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root) +train_dataset = DatasetWithLabels(df_train) # PML specific distance = distances.LpDistance(p=2) diff --git a/docs/readme/examples_source/extractor/val.md b/docs/readme/examples_source/extractor/val.md index 39181f17d..3f71ebd25 100644 --- a/docs/readme/examples_source/extractor/val.md +++ b/docs/readme/examples_source/extractor/val.md @@ -12,12 +12,11 @@ from oml.metrics.embeddings import EmbeddingMetrics from oml.models import ViTExtractor from oml.utils.download_mock_dataset import download_mock_dataset -dataset_root = "mock_dataset/" -_, df_val = download_mock_dataset(dataset_root) +_, df_val = download_mock_dataset(global_paths=True) extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).eval() -val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root) +val_dataset = DatasetQueryGallery(df_val) val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4) calculator = EmbeddingMetrics(extra_keys=("paths",)) diff --git a/docs/readme/examples_source/extractor/val_with_sequence.md b/docs/readme/examples_source/extractor/val_with_sequence.md index 2b015ea40..a0cfd916a 100644 --- a/docs/readme/examples_source/extractor/val_with_sequence.md +++ b/docs/readme/examples_source/extractor/val_with_sequence.md @@ -42,12 +42,11 @@ from oml.metrics.embeddings import EmbeddingMetrics from oml.models import ViTExtractor from oml.utils.download_mock_dataset import download_mock_dataset -dataset_root = "mock_dataset/" -_, df_val = download_mock_dataset(dataset_root, df_name="df_with_sequence.csv") # <- sequence info is in the file +_, df_val = download_mock_dataset(global_paths=True, df_name="df_with_sequence.csv") # <- sequence info is in the file extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).eval() -val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root) +val_dataset = DatasetQueryGallery(df_val) val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4) calculator = EmbeddingMetrics(extra_keys=("paths",), sequence_key=val_dataset.sequence_key) diff --git a/oml/retrieval/postprocessors/pairwise.py b/oml/retrieval/postprocessors/pairwise.py index e82c6e841..e57961795 100644 --- a/oml/retrieval/postprocessors/pairwise.py +++ b/oml/retrieval/postprocessors/pairwise.py @@ -7,11 +7,7 @@ from oml.interfaces.datasets import IQueryGalleryDataset from oml.interfaces.models import IPairwiseModel from oml.interfaces.retrieval import IRetrievalPostprocessor -from oml.utils.misc_torch import ( - assign_2d, - cat_two_sorted_tensors_and_keep_it_sorted, - take_2d, -) +from oml.utils.misc_torch import cat_two_sorted_tensors_and_keep_it_sorted, take_2d class PairwiseReranker(IRetrievalPostprocessor): @@ -50,27 +46,54 @@ def __init__( def process(self, distances: Tensor, dataset: IQueryGalleryDataset) -> Tensor: # type: ignore """ Args: - distances: Distances among queries and galleries with the shape of ``[Q, G]``. + distances: Where ``distances[i, j]`` is a distance between i-th query and j-th gallery. dataset: Dataset having query-gallery split. Returns: - The same distances matrix, but `top_n` smallest values are updated. + Distances: Where ``distances[i, j]`` is a distance between i-th query and j-th gallery, + but the distances to the first ``top_n`` galleries have been updated. """ # todo 522: - # after we introduce RetrievalPrediction the signature of the method will change: so, we directly call - # self.process_neigh. Thus, the code below is temporary to support the current interface. - distances_neigh, ii_neigh = torch.topk( - distances, k=min(distances.shape[1], self.top_n), largest=False - ) # todo 522: test it!!! - distances_neigh_upd, ii_neigh_upd = self.process_neigh(distances_neigh, ii_neigh, dataset) - distances_upd = assign_2d(x=distances, indices=ii_neigh_upd, new_values=distances_neigh_upd) - return distances_upd + # This function is needed only during the migration time. We will directly use `process_neigh` later. + # Thus, the code above is just an adapter for input and output of the `process_neigh` function. + + assert distances.shape == (len(dataset.get_query_ids()), len(dataset.get_gallery_ids())) + + distances, ii_retrieved = distances.sort() + distances, ii_retrieved_upd = self.process_neigh( + retrieved_ids=ii_retrieved, distances=distances, dataset=dataset + ) + distances = take_2d(distances, ii_retrieved_upd.argsort()) + + assert distances.shape == (len(dataset.get_query_ids()), len(dataset.get_gallery_ids())) + + return distances def process_neigh( - self, distances: Tensor, retrieved_ids: Tensor, dataset: IQueryGalleryDataset + self, retrieved_ids: Tensor, distances: Tensor, dataset: IQueryGalleryDataset ) -> Tuple[Tensor, Tensor]: """ - Note, the new distances to the ``top_n`` items produced by the pairwise model may be adjusted + + Args: + retrieved_ids: Ids of galleries closest to every query with the shape of ``[n_query, n_retrieved]`` sorted + by their distances. + distances: The corresponding distances (in sorted order). + dataset: Dataset having query/gallery split. + + Returns: + After model is applied to the ``top_n`` retrieved items, the updated ids and distances are returned. + Thus, you can expect permutation among first ``top_n`` ids and distances, but the rest remains untouched. + + **Example 1:** + ``retrieved_ids: [3, 2, 1, 0, 4 ]`` + ``distances: [0.1, 0.2, 0.5, 0.6, 0.7]`` + Let's say postprocessor has been applied to the first 3 elements and new distances are: ``[0.4, 0.2, 0.3]`` + In this case, updated values are: + ``[2, 1, 3, 0, 4 ]`` + ``[0.2, 0.3, 0.4, 0.6, 0.7]`` + + **Example 2:** + Note, the new distances to the ``top_n`` items produced by the pairwise model may be rescaled to remain distances sorted. Here is an example: ``original_distances = [0.1, 0.2, 0.3, 0.5, 0.6], top_n = 3`` Imagine, the postprocessor didn't change the order of the first 3 items (it's just a convenient example, @@ -79,10 +102,12 @@ def process_neigh( Thus, we need to rescale the first three distances, so they don't go above ``0.5``. The scaling factor is ``s = min(0.5, 0.6) / max(1, 2, 5) = 0.1``. Finally: ``distances_upd_scaled = [0.1, 0.2, 0.5, 0.5, 0.6]``. - If concatenation of two distances is already sorted, we keep it untouched. + If concatenation of the new and old distances is already sorted, we don't apply any scaling. """ - # todo 522: explain what's going on here + assert retrieved_ids.shape == distances.shape + assert len(retrieved_ids) == len(dataset.get_query_ids()) + assert retrieved_ids.shape[1] <= len(dataset.get_gallery_ids()) top_n = min(self.top_n, distances.shape[1]) diff --git a/oml/utils/download_mock_dataset.py b/oml/utils/download_mock_dataset.py index 202d3945c..b4a5ddfa4 100644 --- a/oml/utils/download_mock_dataset.py +++ b/oml/utils/download_mock_dataset.py @@ -29,6 +29,7 @@ def download_mock_dataset( dataset_root: Path to save the dataset check_md5: Set ``True`` to check md5sum df_name: Name of csv file for which output DataFrames will be returned + global_paths: Set ``True`` to cancat paths and ``dataset_root`` Returns: Dataframes for the training and validation stages diff --git a/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py b/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py index 4cb920d24..ab1d8b132 100644 --- a/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py +++ b/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py @@ -11,7 +11,7 @@ from oml.interfaces.models import IPairwiseModel from oml.models.meta.siamese import LinearTrivialDistanceSiamese from oml.retrieval.postprocessors.pairwise import PairwiseReranker -from oml.utils.misc import flatten_dict, one_hot +from oml.utils.misc import flatten_dict, one_hot, set_global_seed from oml.utils.misc_torch import normalise, pairwise_dist from tests.test_integrations.utils import ( EmbeddingsQueryGalleryDataset, @@ -37,7 +37,7 @@ def independent_query_gallery_case() -> Tuple[IQueryGalleryDataset, Tensor]: dataset = EmbeddingsQueryGalleryDataset(embeddings=embeddings, is_query=is_query, is_gallery=is_gallery) - embeddings_inference = embeddings.clone() # pretend it's our inference + embeddings_inference = embeddings.clone() # pretend it's our inference results return dataset, embeddings_inference @@ -53,14 +53,14 @@ def shared_query_gallery_case() -> Tuple[IQueryGalleryDataset, Tensor]: embeddings=embeddings, is_query=torch.ones(sz).bool(), is_gallery=torch.ones(sz).bool() ) - embeddings_inference = embeddings.clone() # pretend it's our inference + embeddings_inference = embeddings.clone() # pretend it's our inference results return dataset, embeddings_inference @pytest.mark.long @pytest.mark.parametrize("top_n", [2, 5, 100]) -@pytest.mark.parametrize("pairwise_distances_bias", [0, 100]) +@pytest.mark.parametrize("pairwise_distances_bias", [0, -100, +100]) @pytest.mark.parametrize("fixture_name", ["independent_query_gallery_case", "shared_query_gallery_case"]) def test_trivial_processing_does_not_change_distances_order( request: pytest.FixtureRequest, fixture_name: str, top_n: int, pairwise_distances_bias: float @@ -74,11 +74,7 @@ def test_trivial_processing_does_not_change_distances_order( distances_processed = processor.process(distances=distances.clone(), dataset=dataset) - if pairwise_distances_bias == 0: - assert torch.allclose(distances_processed, distances) - else: - assert (distances_processed.argsort() == distances.argsort()).all() - assert not torch.allclose(distances_processed, distances) + assert (distances_processed.argsort() == distances.argsort()).all() def perfect_case() -> Tuple[IQueryGalleryLabeledDataset, Tensor]: @@ -97,7 +93,8 @@ def perfect_case() -> Tuple[IQueryGalleryLabeledDataset, Tensor]: @pytest.mark.long -def test_trivial_processing_fixes_broken_perfect_case() -> None: +@pytest.mark.parametrize("pairwise_distances_bias", [0, -100, +100]) +def test_trivial_processing_fixes_broken_perfect_case(pairwise_distances_bias: float) -> None: """ The idea of the test is the following: @@ -135,7 +132,9 @@ def test_trivial_processing_fixes_broken_perfect_case() -> None: metrics = flatten_dict(calc_retrieval_metrics(distances=distances, **args)) # Metrics after broken distances have been fixed - model = LinearTrivialDistanceSiamese(feat_dim=embeddings.shape[-1], identity_init=True, output_bias=10) + model = LinearTrivialDistanceSiamese( + feat_dim=embeddings.shape[-1], identity_init=True, output_bias=pairwise_distances_bias + ) processor = PairwiseReranker(pairwise_model=model, top_n=top_n, batch_size=16, num_workers=0) distances_upd = processor.process(distances, dataset) metrics_upd = flatten_dict(calc_retrieval_metrics(distances=distances_upd, **args)) @@ -164,7 +163,13 @@ def test_processing_not_changing_non_sensitive_metrics(top_n: int) -> None: # The idea of the test is that postprocessing of first n elements # cannot change cmc@n and precision@n + set_global_seed(42) + + # let's get some random inputs dataset, embeddings = perfect_case() + embeddings = torch.randn_like(embeddings).float() + + top_n = min(top_n, embeddings.shape[1]) distances = pairwise_dist(embeddings[dataset.get_query_ids()], embeddings[dataset.get_gallery_ids()], p=2) @@ -181,11 +186,17 @@ def test_processing_not_changing_non_sensitive_metrics(top_n: int) -> None: } metrics_before = calc_retrieval_metrics(distances=distances, **args) + ii_closest_before = torch.argsort(distances) model = RandomPairwise() processor = PairwiseReranker(pairwise_model=model, top_n=top_n, batch_size=4, num_workers=0) distances_upd = processor.process(distances=distances, dataset=dataset) metrics_after = calc_retrieval_metrics(distances=distances_upd, **args) + ii_closest_after = torch.argsort(distances_upd) assert metrics_before == metrics_after + + # also check that we only re-ranked the first top_n items + assert (ii_closest_before[:, :top_n] != ii_closest_after[:, :top_n]).any() + assert (ii_closest_before[:, top_n:] == ii_closest_after[:, top_n:]).all() diff --git a/tests/test_oml/test_postprocessor/test_pairwise_images.py b/tests/test_oml/test_postprocessor/test_pairwise_images.py index 45ca9bf1c..c71b7054a 100644 --- a/tests/test_oml/test_postprocessor/test_pairwise_images.py +++ b/tests/test_oml/test_postprocessor/test_pairwise_images.py @@ -1,7 +1,6 @@ from typing import Tuple import pytest -import torch from torch import Tensor, nn from oml.const import MOCK_DATASET_PATH @@ -31,7 +30,7 @@ def get_validation_results(model: nn.Module, transforms: TTransforms) -> Tuple[T @pytest.mark.long @pytest.mark.parametrize("top_n", [2, 5, 100]) -@pytest.mark.parametrize("pairwise_distances_bias", [0, 100]) +@pytest.mark.parametrize("pairwise_distances_bias", [0, -100, +100]) def test_trivial_processing_does_not_change_distances_order(top_n: int, pairwise_distances_bias: float) -> None: extractor = ResnetExtractor(weights=None, arch="resnet18", normalise_features=True, gem_p=None, remove_fc=True) pairwise_model = TrivialDistanceSiamese(extractor, output_bias=pairwise_distances_bias) @@ -49,8 +48,4 @@ def test_trivial_processing_does_not_change_distances_order(top_n: int, pairwise ) distances_processed = postprocessor.process(distances=distances.clone(), dataset=dataset) - if pairwise_distances_bias == 0: - assert torch.allclose(distances_processed.argsort(), distances.argsort()) - else: - assert (distances_processed.argsort() == distances.argsort()).all() - assert not torch.allclose(distances_processed, distances) + assert (distances_processed.argsort() == distances.argsort()).all() From 5e38ea0e7890a56dbe11be257e1b41c38ed19bbc Mon Sep 17 00:00:00 2001 From: alekseysh Date: Mon, 22 Apr 2024 06:41:56 +0700 Subject: [PATCH 17/23] upd --- docs/source/contents/datasets.rst | 9 +++ docs/source/contents/postprocessing.rst | 2 + oml/datasets/pairs.py | 4 +- oml/retrieval/postprocessors/pairwise.py | 58 ++++++++++++------- .../test_pairwise_embeddings.py | 17 ++++-- 5 files changed, 60 insertions(+), 30 deletions(-) diff --git a/docs/source/contents/datasets.rst b/docs/source/contents/datasets.rst index a349a0f3b..34cb41e59 100644 --- a/docs/source/contents/datasets.rst +++ b/docs/source/contents/datasets.rst @@ -52,3 +52,12 @@ ImageQueryGalleryDataset .. automethod:: get_query_ids .. automethod:: get_gallery_ids .. automethod:: visualize + +PairDataset +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: oml.datasets.pairs.PairDataset + :undoc-members: + :show-inheritance: + + .. automethod:: __init__ + .. automethod:: __getitem__ diff --git a/docs/source/contents/postprocessing.rst b/docs/source/contents/postprocessing.rst index 5af9aae67..ea1f41227 100644 --- a/docs/source/contents/postprocessing.rst +++ b/docs/source/contents/postprocessing.rst @@ -15,3 +15,5 @@ PairwiseReranker .. automethod:: __init__ .. automethod:: process + .. automethod:: process_neigh + diff --git a/oml/datasets/pairs.py b/oml/datasets/pairs.py index b31e5fd83..2aca4c530 100644 --- a/oml/datasets/pairs.py +++ b/oml/datasets/pairs.py @@ -5,12 +5,10 @@ from oml.const import INDEX_KEY, INPUT_TENSORS_KEY_1, INPUT_TENSORS_KEY_2 from oml.interfaces.datasets import IBaseDataset, IPairDataset -# todo 522: make one modality agnostic instead of these two - class PairDataset(IPairDataset): """ - Dataset to iterate over pairs of items. + Dataset to iterate over pairs of items of any modality. """ diff --git a/oml/retrieval/postprocessors/pairwise.py b/oml/retrieval/postprocessors/pairwise.py index e57961795..89c0dbf7c 100644 --- a/oml/retrieval/postprocessors/pairwise.py +++ b/oml/retrieval/postprocessors/pairwise.py @@ -50,8 +50,9 @@ def process(self, distances: Tensor, dataset: IQueryGalleryDataset) -> Tensor: dataset: Dataset having query-gallery split. Returns: - Distances: Where ``distances[i, j]`` is a distance between i-th query and j-th gallery, - but the distances to the first ``top_n`` galleries have been updated. + Distances, where ``distances[i, j]`` is a distance between i-th query and j-th gallery, + but the distances to the first ``top_n`` galleries have been updated INPLACE. + """ # todo 522: # This function is needed only during the migration time. We will directly use `process_neigh` later. @@ -84,25 +85,40 @@ def process_neigh( After model is applied to the ``top_n`` retrieved items, the updated ids and distances are returned. Thus, you can expect permutation among first ``top_n`` ids and distances, but the rest remains untouched. - **Example 1:** - ``retrieved_ids: [3, 2, 1, 0, 4 ]`` - ``distances: [0.1, 0.2, 0.5, 0.6, 0.7]`` - Let's say postprocessor has been applied to the first 3 elements and new distances are: ``[0.4, 0.2, 0.3]`` - In this case, updated values are: - ``[2, 1, 3, 0, 4 ]`` - ``[0.2, 0.3, 0.4, 0.6, 0.7]`` - - **Example 2:** - Note, the new distances to the ``top_n`` items produced by the pairwise model may be rescaled - to remain distances sorted. Here is an example: - ``original_distances = [0.1, 0.2, 0.3, 0.5, 0.6], top_n = 3`` - Imagine, the postprocessor didn't change the order of the first 3 items (it's just a convenient example, - the logic remains the same), however the new values have a bigger scale: - ``distances_upd = [1, 2, 5, 0.5, 0.6]``. - Thus, we need to rescale the first three distances, so they don't go above ``0.5``. - The scaling factor is ``s = min(0.5, 0.6) / max(1, 2, 5) = 0.1``. Finally: - ``distances_upd_scaled = [0.1, 0.2, 0.5, 0.5, 0.6]``. - If concatenation of the new and old distances is already sorted, we don't apply any scaling. + **Example 1** (for one query): + + .. code-block:: python + + retrieved_ids = [3, 2, 1, 0, 4 ] + distances = [0.1, 0.2, 0.5, 0.6, 0.7] + + # Let's say a postprocessor has been applied to the + # first 3 elements and the new distances are: [0.4, 0.2, 0.3] + + # In this case, the updated values will be: + retrievied_ids = [2, 1, 3, 0, 4 ] + distances: = [0.2, 0.3, 0.4, 0.6, 0.7] + + **Example 2** (for one query): + + .. code-block:: python + + # Note, the new distances to the top_n items produced by the pairwise model + # may be rescaled to keep the distances order. Here is an example: + original_distances = [0.1, 0.2, 0.3, 0.5, 0.6] + top_n = 3 + + # Imagine, the postprocessor didn't change the order of the first 3 items + # (it's just a convenient example, the general logic remains the same), + # however the new values have a bigger scale: + distances_upd = [1, 2, 5, 0.5, 0.6] + + # Thus, we need to downscale the first 3 distances, so they are lower than 0.5: + scale = 5 / 0.5 = 0.1 + # Finally, let's apply the found scale to the top 3 distances: + distances_upd_scaled = [0.1, 0.2, 0.5, 0.5, 0.6] + + # Note, if new and old distances are already sorted, we don't apply any scaling. """ assert retrieved_ids.shape == distances.shape diff --git a/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py b/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py index ab1d8b132..03df79e6c 100644 --- a/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py +++ b/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py @@ -65,16 +65,21 @@ def shared_query_gallery_case() -> Tuple[IQueryGalleryDataset, Tensor]: def test_trivial_processing_does_not_change_distances_order( request: pytest.FixtureRequest, fixture_name: str, top_n: int, pairwise_distances_bias: float ) -> None: - dataset, embeddings = request.getfixturevalue(fixture_name) + for _ in range(10): + dataset, embeddings = request.getfixturevalue(fixture_name) - distances = pairwise_dist(x1=embeddings[dataset.get_query_ids()], x2=embeddings[dataset.get_gallery_ids()], p=2) + distances = pairwise_dist(x1=embeddings[dataset.get_query_ids()], x2=embeddings[dataset.get_gallery_ids()], p=2) - model = LinearTrivialDistanceSiamese(embeddings.shape[-1], output_bias=pairwise_distances_bias, identity_init=True) - processor = PairwiseReranker(pairwise_model=model, top_n=top_n, num_workers=0, batch_size=64) + print(distances, "zzzz") - distances_processed = processor.process(distances=distances.clone(), dataset=dataset) + model = LinearTrivialDistanceSiamese( + embeddings.shape[-1], output_bias=pairwise_distances_bias, identity_init=True + ) + processor = PairwiseReranker(pairwise_model=model, top_n=top_n, num_workers=0, batch_size=64) + + distances_processed = processor.process(distances=distances.clone(), dataset=dataset) - assert (distances_processed.argsort() == distances.argsort()).all() + assert (distances_processed.argsort() == distances.argsort()).all() def perfect_case() -> Tuple[IQueryGalleryLabeledDataset, Tensor]: From daf0c475c33039f52b031b8359af9f4dde87351d Mon Sep 17 00:00:00 2001 From: alekseysh Date: Mon, 22 Apr 2024 07:30:48 +0700 Subject: [PATCH 18/23] upd --- oml/metrics/embeddings.py | 9 ++++++-- .../test_embedding_visualizations.py | 3 +++ .../test_pairwise_embeddings.py | 23 ++++++++++--------- .../test_pairwise_images.py | 6 +++++ tests/test_runs/test_code_from_markdown.py | 2 +- 5 files changed, 29 insertions(+), 14 deletions(-) diff --git a/oml/metrics/embeddings.py b/oml/metrics/embeddings.py index 44f273993..712edee92 100644 --- a/oml/metrics/embeddings.py +++ b/oml/metrics/embeddings.py @@ -12,6 +12,7 @@ EMBEDDINGS_KEY, GRAY, GREEN, + INDEX_KEY, IS_GALLERY_KEY, IS_QUERY_KEY, LABELS_KEY, @@ -68,7 +69,7 @@ class EmbeddingMetrics(IMetricVisualisable): def __init__( self, - dataset: Optional[IQueryGalleryLabeledDataset] = None, + dataset: Optional[IQueryGalleryLabeledDataset] = None, # todo 522: This argument will not be Optional soon. embeddings_key: str = EMBEDDINGS_KEY, labels_key: str = LABELS_KEY, is_query_key: str = IS_QUERY_KEY, @@ -90,7 +91,7 @@ def __init__( """ Args: - dataset: Annotated dataset having query-gallery split. todo 522: This argument will not be Optional soon. + dataset: Annotated dataset having query-gallery split. embeddings_key: Key to take the embeddings from the batches labels_key: Key to take the labels from the batches is_query_key: Key to take the information whether every batch sample belongs to the query @@ -146,6 +147,7 @@ def __init__( self.verbose = verbose keys_to_accumulate = [self.embeddings_key, self.is_query_key, self.is_gallery_key, self.labels_key] + keys_to_accumulate += [INDEX_KEY] # todo 522: remove it after we make "indices" not optional in .update_data() if self.categories_key: keys_to_accumulate.append(self.categories_key) if self.sequence_key: @@ -200,6 +202,9 @@ def _calc_matrices(self) -> None: if self.postprocessor: assert self.dataset, "You must pass dataset to init to make postprocessing." + # todo 522: remove this assert after "indices" become not optional + ii_aligned = list(range(len(self.dataset))) + assert ii_aligned == self.acc.storage[INDEX_KEY].tolist(), "The data is shuffled!" # type: ignore self.distance_matrix = self.postprocessor.process(self.distance_matrix, dataset=self.dataset) def compute_metrics(self) -> TMetricsDict_ByLabels: # type: ignore diff --git a/tests/test_oml/test_metrics/test_embedding_visualizations.py b/tests/test_oml/test_metrics/test_embedding_visualizations.py index 2c883187c..ee2e90df7 100644 --- a/tests/test_oml/test_metrics/test_embedding_visualizations.py +++ b/tests/test_oml/test_metrics/test_embedding_visualizations.py @@ -6,6 +6,7 @@ from oml.const import ( CATEGORIES_KEY, EMBEDDINGS_KEY, + INDEX_KEY, IS_GALLERY_KEY, IS_QUERY_KEY, LABELS_KEY, @@ -32,6 +33,7 @@ def test_visualization() -> None: IS_GALLERY_KEY: torch.tensor([False, False, False]), CATEGORIES_KEY: torch.tensor([10, 20, 20]), PATHS_KEY: [cf / "temp.png", cf / "temp.png", cf / "temp.png"], + INDEX_KEY: torch.tensor([0, 1, 2]), } batch2 = { @@ -41,6 +43,7 @@ def test_visualization() -> None: IS_GALLERY_KEY: torch.tensor([True, True, True]), CATEGORIES_KEY: torch.tensor([10, 20, 20]), PATHS_KEY: [cf / "temp.png", cf / "temp.png", cf / "temp.png"], + INDEX_KEY: torch.tensor([3, 4, 5]), } calc = EmbeddingMetrics( diff --git a/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py b/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py index 03df79e6c..27ed03062 100644 --- a/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py +++ b/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py @@ -60,26 +60,27 @@ def shared_query_gallery_case() -> Tuple[IQueryGalleryDataset, Tensor]: @pytest.mark.long @pytest.mark.parametrize("top_n", [2, 5, 100]) -@pytest.mark.parametrize("pairwise_distances_bias", [0, -100, +100]) +@pytest.mark.parametrize("pairwise_distances_bias", [0, -5, +5]) @pytest.mark.parametrize("fixture_name", ["independent_query_gallery_case", "shared_query_gallery_case"]) def test_trivial_processing_does_not_change_distances_order( request: pytest.FixtureRequest, fixture_name: str, top_n: int, pairwise_distances_bias: float ) -> None: - for _ in range(10): - dataset, embeddings = request.getfixturevalue(fixture_name) + set_global_seed(10) # todo 522: make it work on seed == 55, 42 + dataset, embeddings = request.getfixturevalue(fixture_name) - distances = pairwise_dist(x1=embeddings[dataset.get_query_ids()], x2=embeddings[dataset.get_gallery_ids()], p=2) + distances = pairwise_dist(x1=embeddings[dataset.get_query_ids()], x2=embeddings[dataset.get_gallery_ids()], p=2) - print(distances, "zzzz") + model = LinearTrivialDistanceSiamese(embeddings.shape[-1], output_bias=pairwise_distances_bias, identity_init=True) + processor = PairwiseReranker(pairwise_model=model, top_n=top_n, num_workers=0, batch_size=64) - model = LinearTrivialDistanceSiamese( - embeddings.shape[-1], output_bias=pairwise_distances_bias, identity_init=True - ) - processor = PairwiseReranker(pairwise_model=model, top_n=top_n, num_workers=0, batch_size=64) + distances_processed = processor.process(distances=distances.clone(), dataset=dataset) - distances_processed = processor.process(distances=distances.clone(), dataset=dataset) + assert (distances_processed.argsort() == distances.argsort()).all() - assert (distances_processed.argsort() == distances.argsort()).all() + if pairwise_distances_bias == 0: + assert torch.allclose(distances, distances_processed) + else: + assert not torch.allclose(distances, distances_processed) def perfect_case() -> Tuple[IQueryGalleryLabeledDataset, Tensor]: diff --git a/tests/test_oml/test_postprocessor/test_pairwise_images.py b/tests/test_oml/test_postprocessor/test_pairwise_images.py index c71b7054a..ba37b666f 100644 --- a/tests/test_oml/test_postprocessor/test_pairwise_images.py +++ b/tests/test_oml/test_postprocessor/test_pairwise_images.py @@ -1,6 +1,7 @@ from typing import Tuple import pytest +import torch from torch import Tensor, nn from oml.const import MOCK_DATASET_PATH @@ -49,3 +50,8 @@ def test_trivial_processing_does_not_change_distances_order(top_n: int, pairwise distances_processed = postprocessor.process(distances=distances.clone(), dataset=dataset) assert (distances_processed.argsort() == distances.argsort()).all() + + if pairwise_distances_bias == 0: + assert torch.allclose(distances_processed, distances) + else: + assert not torch.allclose(distances_processed, distances) diff --git a/tests/test_runs/test_code_from_markdown.py b/tests/test_runs/test_code_from_markdown.py index 4621484a8..81260ed58 100644 --- a/tests/test_runs/test_code_from_markdown.py +++ b/tests/test_runs/test_code_from_markdown.py @@ -30,8 +30,8 @@ def find_code_block(file: Path, start_indicator: str, end_indicator: str) -> str ("extractor/train_val_pl.md", "[comment]:lightning-start\n", "[comment]:lightning-end\n"), ("extractor/train_val_pl_ddp.md", "[comment]:lightning-ddp-start\n", "[comment]:lightning-ddp-end\n"), ("extractor/train_2loaders_val.md", "[comment]:lightning-2loaders-start\n", "[comment]:lightning-2loaders-end\n"), # noqa - ("zoo/models_usage.md", "[comment]:zoo-start\n", "[comment]:zoo-end\n"), ("extractor/retrieval_usage.md", "[comment]:usage-retrieval-start\n", "[comment]:usage-retrieval-end\n"), + ("zoo/models_usage.md", "[comment]:zoo-start\n", "[comment]:zoo-end\n"), ("postprocessing/train_val.md", "[comment]:postprocessor-start\n", "[comment]:postprocessor-end\n"), ("postprocessing/predict.md", "[comment]:postprocessor-pred-start\n", "[comment]:postprocessor-pred-end\n"), ], From 19a6ed222e39d7a3064de55a5383ddcce7a990af Mon Sep 17 00:00:00 2001 From: alekseysh Date: Mon, 22 Apr 2024 07:49:40 +0700 Subject: [PATCH 19/23] upd --- oml/retrieval/postprocessors/pairwise.py | 4 ++-- oml/utils/misc_torch.py | 2 +- .../test_postprocessor/test_pairwise_embeddings.py | 7 ++++--- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/oml/retrieval/postprocessors/pairwise.py b/oml/retrieval/postprocessors/pairwise.py index 89c0dbf7c..43187abc9 100644 --- a/oml/retrieval/postprocessors/pairwise.py +++ b/oml/retrieval/postprocessors/pairwise.py @@ -96,8 +96,8 @@ def process_neigh( # first 3 elements and the new distances are: [0.4, 0.2, 0.3] # In this case, the updated values will be: - retrievied_ids = [2, 1, 3, 0, 4 ] - distances: = [0.2, 0.3, 0.4, 0.6, 0.7] + retrieved_ids = [2, 1, 3, 0, 4 ] + distances: = [0.2, 0.3, 0.4, 0.6, 0.7] **Example 2** (for one query): diff --git a/oml/utils/misc_torch.py b/oml/utils/misc_torch.py index f02e2beac..a8bc10ded 100644 --- a/oml/utils/misc_torch.py +++ b/oml/utils/misc_torch.py @@ -57,7 +57,7 @@ def assign_2d(x: Tensor, indices: Tensor, new_values: Tensor) -> Tensor: return x -def cat_two_sorted_tensors_and_keep_it_sorted(x1: Tensor, x2: Tensor, eps: float = 1e-7) -> Tensor: +def cat_two_sorted_tensors_and_keep_it_sorted(x1: Tensor, x2: Tensor, eps: float = 1e-6) -> Tensor: """ Args: x1: Sorted tensor with the shape of ``[N, M]`` diff --git a/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py b/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py index 27ed03062..83432deb4 100644 --- a/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py +++ b/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py @@ -59,13 +59,14 @@ def shared_query_gallery_case() -> Tuple[IQueryGalleryDataset, Tensor]: @pytest.mark.long -@pytest.mark.parametrize("top_n", [2, 5, 100]) -@pytest.mark.parametrize("pairwise_distances_bias", [0, -5, +5]) +# @pytest.mark.parametrize("top_n", [2, 5, 100]) +# @pytest.mark.parametrize("pairwise_distances_bias", [0, -5, +5]) +@pytest.mark.parametrize("pairwise_distances_bias", [5]) +@pytest.mark.parametrize("top_n", [5]) @pytest.mark.parametrize("fixture_name", ["independent_query_gallery_case", "shared_query_gallery_case"]) def test_trivial_processing_does_not_change_distances_order( request: pytest.FixtureRequest, fixture_name: str, top_n: int, pairwise_distances_bias: float ) -> None: - set_global_seed(10) # todo 522: make it work on seed == 55, 42 dataset, embeddings = request.getfixturevalue(fixture_name) distances = pairwise_dist(x1=embeddings[dataset.get_query_ids()], x2=embeddings[dataset.get_gallery_ids()], p=2) From 06864e21809ad13597665c165fa5e0722be799c8 Mon Sep 17 00:00:00 2001 From: alekseysh Date: Mon, 22 Apr 2024 18:41:53 +0700 Subject: [PATCH 20/23] put_back_test --- .../test_oml/test_postprocessor/test_pairwise_embeddings.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py b/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py index 83432deb4..20cffadd1 100644 --- a/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py +++ b/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py @@ -59,10 +59,8 @@ def shared_query_gallery_case() -> Tuple[IQueryGalleryDataset, Tensor]: @pytest.mark.long -# @pytest.mark.parametrize("top_n", [2, 5, 100]) -# @pytest.mark.parametrize("pairwise_distances_bias", [0, -5, +5]) -@pytest.mark.parametrize("pairwise_distances_bias", [5]) -@pytest.mark.parametrize("top_n", [5]) +@pytest.mark.parametrize("top_n", [2, 5, 100]) +@pytest.mark.parametrize("pairwise_distances_bias", [0, -5, +5]) @pytest.mark.parametrize("fixture_name", ["independent_query_gallery_case", "shared_query_gallery_case"]) def test_trivial_processing_does_not_change_distances_order( request: pytest.FixtureRequest, fixture_name: str, top_n: int, pairwise_distances_bias: float From b54b6a1a331999d94762d01a47e5424f581854c3 Mon Sep 17 00:00:00 2001 From: alekseysh Date: Mon, 22 Apr 2024 22:13:08 +0700 Subject: [PATCH 21/23] fixes: type of dataset root, model link in validation --- oml/lightning/pipelines/train_postprocessor.py | 2 +- oml/lightning/pipelines/validate.py | 2 +- tests/test_runs/test_pipelines/predict.py | 2 +- tests/test_runs/test_pipelines/train.py | 2 +- tests/test_runs/test_pipelines/train_arcface_with_categories.py | 2 +- tests/test_runs/test_pipelines/train_postprocessor.py | 2 +- tests/test_runs/test_pipelines/train_with_categories.py | 2 +- tests/test_runs/test_pipelines/train_with_sequence.py | 2 +- tests/test_runs/test_pipelines/validate.py | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/oml/lightning/pipelines/train_postprocessor.py b/oml/lightning/pipelines/train_postprocessor.py index 4cdfa6886..e372627cb 100644 --- a/oml/lightning/pipelines/train_postprocessor.py +++ b/oml/lightning/pipelines/train_postprocessor.py @@ -61,7 +61,7 @@ def get_loaders_with_embeddings(cfg: TCfg) -> Tuple[DataLoader, DataLoader]: transforms_extraction = get_transforms_by_cfg(cfg["transforms_extraction"]) train_extraction, val_extraction = get_retrieval_images_datasets( - dataset_root=cfg["dataset_root"], + dataset_root=Path(cfg["dataset_root"]), dataframe_name=cfg["dataframe_name"], transforms_train=transforms_extraction, transforms_val=transforms_extraction, diff --git a/oml/lightning/pipelines/validate.py b/oml/lightning/pipelines/validate.py index 43a28110c..64baa1b7a 100644 --- a/oml/lightning/pipelines/validate.py +++ b/oml/lightning/pipelines/validate.py @@ -65,7 +65,7 @@ def extractor_validation_pipeline(cfg: TCfg) -> Tuple[pl.Trainer, Dict[str, Any] postprocessor = None if not cfg.get("postprocessor", None) else get_postprocessor_by_cfg(cfg["postprocessor"]) # Note! We add the link to our extractor to a Lightning's Module, so it can recognize it and manipulate its devices - pl_model.model_link_ = getattr(postprocessor, "extractor", None) + pl_model.model_link_ = postprocessor.model # type: ignore metrics_constructor = EmbeddingMetricsDDP if is_ddp else EmbeddingMetrics metrics_calc = metrics_constructor( diff --git a/tests/test_runs/test_pipelines/predict.py b/tests/test_runs/test_pipelines/predict.py index f10c0e9a4..7c5f3b43e 100644 --- a/tests/test_runs/test_pipelines/predict.py +++ b/tests/test_runs/test_pipelines/predict.py @@ -11,7 +11,7 @@ def main_hydra(cfg: DictConfig) -> None: cfg = dictconfig_to_dict(cfg) download_mock_dataset(MOCK_DATASET_PATH) - cfg["data_dir"] = MOCK_DATASET_PATH + cfg["data_dir"] = str(MOCK_DATASET_PATH) extractor_prediction_pipeline(cfg) diff --git a/tests/test_runs/test_pipelines/train.py b/tests/test_runs/test_pipelines/train.py index 2505e4c54..3b16fefe0 100644 --- a/tests/test_runs/test_pipelines/train.py +++ b/tests/test_runs/test_pipelines/train.py @@ -11,7 +11,7 @@ def main_hydra(cfg: DictConfig) -> None: cfg = dictconfig_to_dict(cfg) download_mock_dataset(MOCK_DATASET_PATH) - cfg["dataset_root"] = MOCK_DATASET_PATH + cfg["dataset_root"] = str(MOCK_DATASET_PATH) extractor_training_pipeline(cfg) diff --git a/tests/test_runs/test_pipelines/train_arcface_with_categories.py b/tests/test_runs/test_pipelines/train_arcface_with_categories.py index 65f30807e..74177ba3e 100644 --- a/tests/test_runs/test_pipelines/train_arcface_with_categories.py +++ b/tests/test_runs/test_pipelines/train_arcface_with_categories.py @@ -11,7 +11,7 @@ def main_hydra(cfg: DictConfig) -> None: cfg = dictconfig_to_dict(cfg) download_mock_dataset(MOCK_DATASET_PATH) - cfg["dataset_root"] = MOCK_DATASET_PATH + cfg["dataset_root"] = str(MOCK_DATASET_PATH) extractor_training_pipeline(cfg) diff --git a/tests/test_runs/test_pipelines/train_postprocessor.py b/tests/test_runs/test_pipelines/train_postprocessor.py index 6ac23cfeb..f01a6a516 100644 --- a/tests/test_runs/test_pipelines/train_postprocessor.py +++ b/tests/test_runs/test_pipelines/train_postprocessor.py @@ -11,7 +11,7 @@ def main_hydra(cfg: DictConfig) -> None: cfg = dictconfig_to_dict(cfg) download_mock_dataset(MOCK_DATASET_PATH) - cfg["dataset_root"] = MOCK_DATASET_PATH + cfg["dataset_root"] = str(MOCK_DATASET_PATH) postprocessor_training_pipeline(cfg) diff --git a/tests/test_runs/test_pipelines/train_with_categories.py b/tests/test_runs/test_pipelines/train_with_categories.py index 90328ac9e..db967f38f 100644 --- a/tests/test_runs/test_pipelines/train_with_categories.py +++ b/tests/test_runs/test_pipelines/train_with_categories.py @@ -11,7 +11,7 @@ def main_hydra(cfg: DictConfig) -> None: cfg = dictconfig_to_dict(cfg) download_mock_dataset(MOCK_DATASET_PATH) - cfg["dataset_root"] = MOCK_DATASET_PATH + cfg["dataset_root"] = str(MOCK_DATASET_PATH) extractor_training_pipeline(cfg) diff --git a/tests/test_runs/test_pipelines/train_with_sequence.py b/tests/test_runs/test_pipelines/train_with_sequence.py index ac7d1597f..3ecaa5701 100644 --- a/tests/test_runs/test_pipelines/train_with_sequence.py +++ b/tests/test_runs/test_pipelines/train_with_sequence.py @@ -11,7 +11,7 @@ def main_hydra(cfg: DictConfig) -> None: cfg = dictconfig_to_dict(cfg) download_mock_dataset(MOCK_DATASET_PATH) - cfg["dataset_root"] = MOCK_DATASET_PATH + cfg["dataset_root"] = str(MOCK_DATASET_PATH) extractor_training_pipeline(cfg) diff --git a/tests/test_runs/test_pipelines/validate.py b/tests/test_runs/test_pipelines/validate.py index 1226f9a97..c8890e6a4 100644 --- a/tests/test_runs/test_pipelines/validate.py +++ b/tests/test_runs/test_pipelines/validate.py @@ -11,7 +11,7 @@ def main_hydra(cfg: DictConfig) -> None: cfg = dictconfig_to_dict(cfg) download_mock_dataset(MOCK_DATASET_PATH) - cfg["dataset_root"] = MOCK_DATASET_PATH + cfg["dataset_root"] = str(MOCK_DATASET_PATH) extractor_validation_pipeline(cfg) From b9d02a6523c664cd05598cccbdca5a97f3058918 Mon Sep 17 00:00:00 2001 From: alekseysh Date: Mon, 22 Apr 2024 23:42:28 +0700 Subject: [PATCH 22/23] optimizer postprocessing; solved issue with half precision; removed confusing test --- oml/retrieval/postprocessors/pairwise.py | 26 ++++++++++++++----- oml/utils/misc_torch.py | 2 +- .../configs/train_postprocessor.yaml | 2 +- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/oml/retrieval/postprocessors/pairwise.py b/oml/retrieval/postprocessors/pairwise.py index 43187abc9..9274dc21c 100644 --- a/oml/retrieval/postprocessors/pairwise.py +++ b/oml/retrieval/postprocessors/pairwise.py @@ -7,7 +7,11 @@ from oml.interfaces.datasets import IQueryGalleryDataset from oml.interfaces.models import IPairwiseModel from oml.interfaces.retrieval import IRetrievalPostprocessor -from oml.utils.misc_torch import cat_two_sorted_tensors_and_keep_it_sorted, take_2d +from oml.utils.misc_torch import ( + assign_2d, + cat_two_sorted_tensors_and_keep_it_sorted, + take_2d, +) class PairwiseReranker(IRetrievalPostprocessor): @@ -55,16 +59,24 @@ def process(self, distances: Tensor, dataset: IQueryGalleryDataset) -> Tensor: """ # todo 522: - # This function is needed only during the migration time. We will directly use `process_neigh` later. - # Thus, the code above is just an adapter for input and output of the `process_neigh` function. + # This function and the code below is only needed during the migration time. + # We will directly use `process_neigh` later on. + # So, the code below is just a format adapter: + # 1) it takes the top (dists + ii) of the big distance matrix, + # 2) passes this top to the `process_neigh()` + # 3) puts the processed outputs on their places in the big distance matrix assert distances.shape == (len(dataset.get_query_ids()), len(dataset.get_gallery_ids())) - distances, ii_retrieved = distances.sort() - distances, ii_retrieved_upd = self.process_neigh( - retrieved_ids=ii_retrieved, distances=distances, dataset=dataset + # we need this "+10" to activate rescaling if needed (so we have both: new and old distances in proces_neigh. + # anyway, this code is temporary + distances_top, ii_retrieved_top = torch.topk( + distances, k=min(self.top_n + 10, distances.shape[1]), largest=False ) - distances = take_2d(distances, ii_retrieved_upd.argsort()) + distances_top_upd, ii_retrieved_upd = self.process_neigh( + retrieved_ids=ii_retrieved_top, distances=distances_top, dataset=dataset + ) + distances = assign_2d(x=distances, indices=ii_retrieved_upd, new_values=distances_top_upd) assert distances.shape == (len(dataset.get_query_ids()), len(dataset.get_gallery_ids())) diff --git a/oml/utils/misc_torch.py b/oml/utils/misc_torch.py index a8bc10ded..b055aaf8e 100644 --- a/oml/utils/misc_torch.py +++ b/oml/utils/misc_torch.py @@ -72,7 +72,7 @@ def cat_two_sorted_tensors_and_keep_it_sorted(x1: Tensor, x2: Tensor, eps: float assert eps >= 0 assert x1.shape[0] == x2.shape[0] - scale = (x2[:, 0] / x1[:, -1]).view(-1, 1) + scale = (x2[:, 0] / x1[:, -1]).view(-1, 1).type_as(x1) need_scaling = x1[:, -1] > x2[:, 0] x1[need_scaling] = x1[need_scaling] * scale[need_scaling] - eps diff --git a/tests/test_runs/test_pipelines/configs/train_postprocessor.yaml b/tests/test_runs/test_pipelines/configs/train_postprocessor.yaml index 8e6f43110..c8a1e34f0 100644 --- a/tests/test_runs/test_pipelines/configs/train_postprocessor.yaml +++ b/tests/test_runs/test_pipelines/configs/train_postprocessor.yaml @@ -1,7 +1,7 @@ postfix: "postprocessing" seed: 42 -precision: 16 +precision: 32 accelerator: cpu devices: 2 find_unused_parameters: False From f2be9b256910791e28987c05e819eeccc14da5c5 Mon Sep 17 00:00:00 2001 From: alekseysh Date: Tue, 23 Apr 2024 00:02:30 +0700 Subject: [PATCH 23/23] minor: hotfix none postproc + raise error if no images in predict --- oml/lightning/pipelines/predict.py | 3 +++ oml/lightning/pipelines/validate.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/oml/lightning/pipelines/predict.py b/oml/lightning/pipelines/predict.py index 8a4e8f0dc..63b170d86 100644 --- a/oml/lightning/pipelines/predict.py +++ b/oml/lightning/pipelines/predict.py @@ -34,6 +34,9 @@ def extractor_prediction_pipeline(cfg: TCfg) -> None: filenames = [list(Path(cfg["data_dir"]).glob(f"**/*.{ext}")) for ext in IMAGE_EXTENSIONS] filenames = list(itertools.chain(*filenames)) + if len(filenames) == 0: + raise RuntimeError(f"There are no images in the provided directory: {cfg['data_dir']}") + f_imread = get_im_reader_for_transforms(transforms) print("Let's check if there are broken images:") diff --git a/oml/lightning/pipelines/validate.py b/oml/lightning/pipelines/validate.py index 64baa1b7a..598c5e70e 100644 --- a/oml/lightning/pipelines/validate.py +++ b/oml/lightning/pipelines/validate.py @@ -65,7 +65,8 @@ def extractor_validation_pipeline(cfg: TCfg) -> Tuple[pl.Trainer, Dict[str, Any] postprocessor = None if not cfg.get("postprocessor", None) else get_postprocessor_by_cfg(cfg["postprocessor"]) # Note! We add the link to our extractor to a Lightning's Module, so it can recognize it and manipulate its devices - pl_model.model_link_ = postprocessor.model # type: ignore + if postprocessor is not None: + pl_model.model_link_ = postprocessor.model # type: ignore metrics_constructor = EmbeddingMetricsDDP if is_ddp else EmbeddingMetrics metrics_calc = metrics_constructor(