diff --git a/.gitignore b/.gitignore index 95cecde2d..fa99851dc 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ tmp.ipynb *outputs* *.hydra* *predictions.json* +*ml-runs* # __________________________________ diff --git a/docs/source/contents/datasets.rst b/docs/source/contents/datasets.rst index bcb909bfd..9af621a1c 100644 --- a/docs/source/contents/datasets.rst +++ b/docs/source/contents/datasets.rst @@ -7,42 +7,51 @@ Datasets .. contents:: :local: -BaseDataset +ImageBaseDataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.datasets.base.BaseDataset +.. autoclass:: oml.datasets.images.ImageBaseDataset :undoc-members: :show-inheritance: .. automethod:: __init__ + .. automethod:: __getitem__ + .. automethod:: visualize -DatasetWithLabels +ImageLabeledDataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.datasets.base.DatasetWithLabels +.. autoclass:: oml.datasets.images.ImageLabeledDataset :undoc-members: :show-inheritance: .. automethod:: __init__ .. automethod:: __getitem__ .. automethod:: get_labels - .. automethod:: get_label2category + .. automethod:: visualize -DatasetQueryGallery +ImageQueryGalleryLabeledDataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.datasets.base.DatasetQueryGallery +.. autoclass:: oml.datasets.images.ImageQueryGalleryLabeledDataset :undoc-members: :show-inheritance: .. automethod:: __init__ .. automethod:: __getitem__ + .. automethod:: get_query_ids + .. automethod:: get_gallery_ids + .. automethod:: get_labels + .. automethod:: visualize -ListDataset +ImageQueryGalleryDataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.datasets.list_dataset.ListDataset +.. autoclass:: oml.datasets.images.ImageQueryGalleryDataset :undoc-members: :show-inheritance: .. automethod:: __init__ .. automethod:: __getitem__ + .. automethod:: get_query_ids + .. automethod:: get_gallery_ids + .. automethod:: visualize EmbeddingPairsDataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/contents/interfaces.rst b/docs/source/contents/interfaces.rst index aa5a8a01e..7e7224491 100644 --- a/docs/source/contents/interfaces.rst +++ b/docs/source/contents/interfaces.rst @@ -52,22 +52,39 @@ ITripletLossWithMiner .. automethod:: forward -IDatasetWithLabels +IBaseDataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.interfaces.datasets.IDatasetWithLabels +.. autoclass:: oml.interfaces.datasets.IBaseDataset + :undoc-members: + :show-inheritance: + +ILabeledDataset +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: oml.interfaces.datasets.ILabeledDataset :undoc-members: :show-inheritance: .. automethod:: __getitem__ .. automethod:: get_labels -IDatasetQueryGallery +IQueryGalleryDataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.interfaces.datasets.IDatasetQueryGallery +.. autoclass:: oml.interfaces.datasets.IQueryGalleryDataset :undoc-members: :show-inheritance: - .. automethod:: __getitem__ + .. automethod:: get_query_ids + .. 
automethod:: get_gallery_ids + +IQueryGalleryLabeledDataset +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: oml.interfaces.datasets.IQueryGalleryLabeledDataset + :undoc-members: + :show-inheritance: + + .. automethod:: get_query_ids + .. automethod:: get_gallery_ids + .. automethod:: get_labels IPairsDataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -78,6 +95,14 @@ IPairsDataset .. automethod:: __init__ .. automethod:: __getitem__ +IVisualizableDataset +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: oml.interfaces.datasets.IVisualizableDataset + :undoc-members: + :show-inheritance: + + .. automethod:: visualize + IBasicMetric ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: oml.interfaces.metrics.IBasicMetric diff --git a/oml/const.py b/oml/const.py index 62045d248..1890c8d94 100644 --- a/oml/const.py +++ b/oml/const.py @@ -2,7 +2,7 @@ import tempfile from pathlib import Path from sys import platform -from typing import Any, Dict, Tuple, Union +from typing import Any, Dict, Optional, Sequence, Tuple, Union from omegaconf import DictConfig @@ -50,6 +50,7 @@ def get_cache_folder() -> Path: GREEN = (0, 255, 0) BLUE = (0, 0, 255) GRAY = (120, 120, 120) +BLACK = (0, 0, 0) PAD_COLOR = (255, 255, 255) TCfg = Union[Dict[str, Any], DictConfig] @@ -62,6 +63,9 @@ def get_cache_folder() -> Path: MEAN_CLIP = (0.48145466, 0.4578275, 0.40821073) STD_CLIP = (0.26862954, 0.26130258, 0.27577711) +TBBox = Tuple[int, int, int, int] +TBBoxes = Sequence[Optional[TBBox]] + CROP_KEY = "crop" # the format is [x1, y1, x2, y2] # Required dataset format: diff --git a/oml/datasets/base.py b/oml/datasets/base.py index f76ac3b26..bd1aafd97 100644 --- a/oml/datasets/base.py +++ b/oml/datasets/base.py @@ -1,344 +1,14 @@ -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional, Tuple, Union +from oml.datasets.images import ImageLabeledDataset, ImageQueryGalleryLabeledDataset -import albumentations as albu -import numpy as np -import pandas as pd -import torchvision -from torch.utils.data import Dataset -from oml.const import ( - CATEGORIES_COLUMN, - CATEGORIES_KEY, - INDEX_KEY, - INPUT_TENSORS_KEY, - IS_GALLERY_COLUMN, - IS_GALLERY_KEY, - IS_QUERY_COLUMN, - IS_QUERY_KEY, - LABELS_COLUMN, - LABELS_KEY, - PATHS_COLUMN, - PATHS_KEY, - SEQUENCE_COLUMN, - SEQUENCE_KEY, - SPLIT_COLUMN, - X1_COLUMN, - X1_KEY, - X2_COLUMN, - X2_KEY, - Y1_COLUMN, - Y1_KEY, - Y2_COLUMN, - Y2_KEY, -) -from oml.interfaces.datasets import IDatasetQueryGallery, IDatasetWithLabels -from oml.registry.transforms import get_transforms -from oml.transforms.images.utils import TTransforms, get_im_reader_for_transforms -from oml.utils.dataframe_format import check_retrieval_dataframe_format -from oml.utils.images.images import TImReader +class DatasetWithLabels(ImageLabeledDataset): + # this class allows to have backward compatibility + pass -class BaseDataset(Dataset): - """ - Base class for the retrieval datasets. 
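For readers migrating existing code, a minimal sketch (not part of the diff, and assuming the mock dataset has been downloaded) of how the backward-compatibility aliases behave: the old names from `oml.datasets.base` are now empty subclasses of the new image datasets.

```python
import pandas as pd

from oml.const import MOCK_DATASET_PATH
from oml.datasets.base import DatasetWithLabels       # legacy name, kept as an alias
from oml.datasets.images import ImageLabeledDataset   # new name

df = pd.read_csv(MOCK_DATASET_PATH / "df.csv")
dataset = DatasetWithLabels(df=df, dataset_root=MOCK_DATASET_PATH)

# the alias is an empty subclass, so it behaves exactly like ImageLabeledDataset
assert isinstance(dataset, ImageLabeledDataset)
print(dataset[0].keys())
```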
+class DatasetQueryGallery(ImageQueryGalleryLabeledDataset): + # this class allows to have backward compatibility + pass - """ - def __init__( - self, - df: pd.DataFrame, - extra_data: Optional[Dict[str, Any]] = None, - transform: Optional[TTransforms] = None, - dataset_root: Optional[Union[str, Path]] = None, - f_imread: Optional[TImReader] = None, - cache_size: Optional[int] = 0, - input_tensors_key: str = INPUT_TENSORS_KEY, - labels_key: str = LABELS_KEY, - paths_key: str = PATHS_KEY, - categories_key: Optional[str] = CATEGORIES_KEY, - sequence_key: Optional[str] = SEQUENCE_KEY, - x1_key: str = X1_KEY, - x2_key: str = X2_KEY, - y1_key: str = Y1_KEY, - y2_key: str = Y2_KEY, - index_key: str = INDEX_KEY, - ): - """ - - Args: - df: Table with the following obligatory columns: - - ``LABELS_COLUMN``, ``PATHS_COLUMN`` - - and the optional ones: - - ``X1_COLUMN``, ``X2_COLUMN``, ``Y1_COLUMN``, ``Y2_COLUMN``, ``CATEGORIES_COLUMN`` - - extra_data: Dictionary with additional information which we want to put into batches. We assume that - the length of each record in this structure is the same as dataset's size. - transform: Augmentations for the images, set ``None`` to perform only normalisation and casting to tensor - dataset_root: Path to the images' dir, set ``None`` if you provided the absolute paths in your dataframe - f_imread: Function to read the images, pass ``None`` so we pick it autmatically based on provided transforms - cache_size: Size of the dataset's cache - input_tensors_key: Key to put tensors into the batches - labels_key: Key to put labels into the batches - paths_key: Key put paths into the batches - categories_key: Key to put categories into the batches - sequence_key: Key to put sequence ids into the batches - x1_key: Key to put ``x1`` into the batches - x2_key: Key to put ``x2`` into the batches - y1_key: Key to put ``y1`` into the batches - y2_key: Key to put ``y2`` into the batches - index_key: Key to put samples' ids into the batches - - """ - df = df.copy() - - if extra_data is not None: - assert all( - len(record) == len(df) for record in extra_data.values() - ), "All the extra records need to have the size equal to the dataset's size" - - assert all(x in df.columns for x in (LABELS_COLUMN, PATHS_COLUMN)) - - self.input_tensors_key = input_tensors_key - self.labels_key = labels_key - self.paths_key = paths_key - self.categories_key = categories_key if (CATEGORIES_COLUMN in df.columns) else None - self.sequence_key = sequence_key if (SEQUENCE_COLUMN in df.columns) else None - self.index_key = index_key - - self.bboxes_exist = all(coord in df.columns for coord in (X1_COLUMN, X2_COLUMN, Y1_COLUMN, Y2_COLUMN)) - if self.bboxes_exist: - self.x1_key, self.x2_key, self.y1_key, self.y2_key = x1_key, x2_key, y1_key, y2_key - else: - self.x1_key, self.x2_key, self.y1_key, self.y2_key = None, None, None, None - - if dataset_root is not None: - dataset_root = Path(dataset_root) - df[PATHS_COLUMN] = df[PATHS_COLUMN].apply(lambda x: str(dataset_root / x)) - else: - df[PATHS_COLUMN] = df[PATHS_COLUMN].astype(str) - - self.df = df - self.extra_data = extra_data - self.transform = transform if transform else get_transforms("norm_albu") - self.f_imread = f_imread or get_im_reader_for_transforms(transform) - self.read_bytes_image = ( - lru_cache(maxsize=cache_size)(self._read_bytes_image) if cache_size else self._read_bytes_image - ) - - available_augs_types = (albu.Compose, torchvision.transforms.Compose) - assert isinstance(self.transform, available_augs_types), f"Type of transforms 
must be in {available_augs_types}" - - @staticmethod - def _read_bytes_image(path: Union[Path, str]) -> bytes: - with open(str(path), "rb") as fin: - return fin.read() - - def __getitem__(self, idx: int) -> Dict[str, Any]: - row = self.df.iloc[idx] - - img_bytes = self.read_bytes_image(row[PATHS_COLUMN]) # type: ignore - img = self.f_imread(img_bytes) - - im_h, im_w = img.shape[:2] if isinstance(img, np.ndarray) else img.size[::-1] - - if (not self.bboxes_exist) or any( - pd.isna(coord) for coord in [row[X1_COLUMN], row[X2_COLUMN], row[Y1_COLUMN], row[Y2_COLUMN]] - ): - x1, y1, x2, y2 = 0, 0, im_w, im_h - else: - x1, y1, x2, y2 = int(row[X1_COLUMN]), int(row[Y1_COLUMN]), int(row[X2_COLUMN]), int(row[Y2_COLUMN]) - - if isinstance(self.transform, albu.Compose): - img = img[y1:y2, x1:x2, :] # todo: since albu may handle bboxes we should move it to augs - image_tensor = self.transform(image=img)["image"] - else: - # torchvision.transforms - img = img.crop((x1, y1, x2, y2)) - image_tensor = self.transform(img) - - item = { - self.input_tensors_key: image_tensor, - self.labels_key: row[LABELS_COLUMN], - self.paths_key: row[PATHS_COLUMN], - self.index_key: idx, - } - - if self.categories_key: - item[self.categories_key] = row[CATEGORIES_COLUMN] - - if self.sequence_key: - item[self.sequence_key] = row[SEQUENCE_COLUMN] - - if self.bboxes_exist: - item.update( - { - self.x1_key: x1, - self.y1_key: y1, - self.x2_key: x2, - self.y2_key: y2, - } - ) - - if self.extra_data: - for key, record in self.extra_data.items(): - if key in item: - raise ValueError(f" and dataset share the same key: {key}") - else: - item[key] = record[idx] - - return item - - def __len__(self) -> int: - return len(self.df) - - @property - def bboxes_keys(self) -> Tuple[str, ...]: - if self.bboxes_exist: - return self.x1_key, self.y1_key, self.x2_key, self.y2_key - else: - return tuple() - - -class DatasetWithLabels(BaseDataset, IDatasetWithLabels): - """ - The main purpose of this class is to be used as a dataset during - the training stage. - - It has to know how to return its labels, which is required information - to perform the training with the combinations-based losses. - Particularly, these labels will be passed to `Sampler` to form the batches and - batches will be passed to `Miner` to form the combinations (triplets). - - """ - - def get_labels(self) -> np.ndarray: - return np.array(self.df[LABELS_COLUMN].tolist()) - - def get_label2category(self) -> Optional[Dict[int, Union[str, int]]]: - """ - Returns: - Label to category mapping if there was category information in DataFrame, None otherwise. - - """ - if CATEGORIES_COLUMN in self.df.columns: - label2category = dict(zip(self.df[LABELS_COLUMN], self.df[CATEGORIES_COLUMN])) - else: - label2category = None - - return label2category - - -class DatasetQueryGallery(BaseDataset, IDatasetQueryGallery): - """ - The main purpose of this class is to be used as a dataset during - the validation stage. It has to provide information - about its `query`/`gallery` split. - - Note, that some datasets used as benchmarks in Metric Learning - provide the splitting information (for example, ``DeepFashion InShop`` dataset), but some of them - don't (for example, ``CARS196`` or ``CUB200``). - The validation idea for the latter is to calculate the embeddings for the whole validation set, - then for every item find ``top-k`` nearest neighbors and calculate the desired retrieval metric. - In other words, for the desired query item, the gallery is the rest of the validation dataset. 
- - Thus, if you want to perform this kind of validation process (`1 vs rest`) you should simply return - ``is_query == True`` and ``is_gallery == True`` for every item in the dataset as the same time. - - """ - - def __init__( - self, - df: pd.DataFrame, - extra_data: Optional[Dict[str, Any]] = None, - dataset_root: Optional[Union[str, Path]] = None, - transform: Optional[albu.Compose] = None, - f_imread: Optional[TImReader] = None, - cache_size: Optional[int] = 0, - input_tensors_key: str = INPUT_TENSORS_KEY, - labels_key: str = LABELS_KEY, - paths_key: str = PATHS_KEY, - categories_key: str = CATEGORIES_KEY, - x1_key: str = X1_KEY, - x2_key: str = X2_KEY, - y1_key: str = Y1_KEY, - y2_key: str = Y2_KEY, - is_query_key: str = IS_QUERY_KEY, - is_gallery_key: str = IS_GALLERY_KEY, - ): - super(DatasetQueryGallery, self).__init__( - df=df, - extra_data=extra_data, - dataset_root=dataset_root, - transform=transform, - f_imread=f_imread, - cache_size=cache_size, - input_tensors_key=input_tensors_key, - labels_key=labels_key, - paths_key=paths_key, - categories_key=categories_key, - x1_key=x1_key, - x2_key=x2_key, - y1_key=y1_key, - y2_key=y2_key, - ) - assert all(x in df.columns for x in (IS_QUERY_COLUMN, IS_GALLERY_COLUMN)) - - self.is_query_key = is_query_key - self.is_gallery_key = is_gallery_key - - def __getitem__(self, idx: int) -> Dict[str, Any]: - item = super().__getitem__(idx) - item[self.is_query_key] = bool(self.df.iloc[idx][IS_QUERY_COLUMN]) - item[self.is_gallery_key] = bool(self.df.iloc[idx][IS_GALLERY_COLUMN]) - return item - - -def get_retrieval_datasets( - dataset_root: Path, - transforms_train: Any, - transforms_val: Any, - f_imread_train: Optional[TImReader] = None, - f_imread_val: Optional[TImReader] = None, - dataframe_name: str = "df.csv", - cache_size: Optional[int] = 0, - verbose: bool = True, -) -> Tuple[DatasetWithLabels, DatasetQueryGallery]: - df = pd.read_csv(dataset_root / dataframe_name, index_col=False) - - check_retrieval_dataframe_format(df, dataset_root=dataset_root, verbose=verbose) - - # first half will consist of "train" split, second one of "val" - # so labels in train will be from 0 to N-1 and labels in test will be from N to K - mapper = {l: i for i, l in enumerate(df.sort_values(by=[SPLIT_COLUMN])[LABELS_COLUMN].unique())} - - # train - df_train = df[df[SPLIT_COLUMN] == "train"].reset_index(drop=True) - df_train[LABELS_COLUMN] = df_train[LABELS_COLUMN].map(mapper) - - train_dataset = DatasetWithLabels( - df=df_train, - dataset_root=dataset_root, - transform=transforms_train, - cache_size=cache_size, - f_imread=f_imread_train, - ) - - # val (query + gallery) - df_query_gallery = df[df[SPLIT_COLUMN] == "validation"].reset_index(drop=True) - valid_dataset = DatasetQueryGallery( - df=df_query_gallery, - dataset_root=dataset_root, - transform=transforms_val, - cache_size=cache_size, - f_imread=f_imread_val, - ) - - return train_dataset, valid_dataset - - -__all__ = ["BaseDataset", "DatasetWithLabels", "DatasetQueryGallery", "get_retrieval_datasets"] +__all__ = ["DatasetWithLabels", "DatasetQueryGallery"] diff --git a/oml/datasets/images.py b/oml/datasets/images.py new file mode 100644 index 000000000..8828b66fc --- /dev/null +++ b/oml/datasets/images.py @@ -0,0 +1,490 @@ +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +import albumentations as albu +import numpy as np +import pandas as pd +import torch +import torchvision +from torch import BoolTensor, FloatTensor, LongTensor + +from 
oml.const import ( + BLACK, + CATEGORIES_COLUMN, + CATEGORIES_KEY, + INDEX_KEY, + INPUT_TENSORS_KEY, + IS_GALLERY_COLUMN, + IS_GALLERY_KEY, + IS_QUERY_COLUMN, + IS_QUERY_KEY, + LABELS_COLUMN, + LABELS_KEY, + PATHS_COLUMN, + PATHS_KEY, + SEQUENCE_COLUMN, + SEQUENCE_KEY, + SPLIT_COLUMN, + X1_COLUMN, + X1_KEY, + X2_COLUMN, + X2_KEY, + Y1_COLUMN, + Y1_KEY, + Y2_COLUMN, + Y2_KEY, + TBBoxes, + TColor, +) +from oml.interfaces.datasets import ( + IBaseDataset, + ILabeledDataset, + IQueryGalleryDataset, + IQueryGalleryLabeledDataset, + IVisualizableDataset, +) +from oml.registry.transforms import get_transforms +from oml.transforms.images.utils import TTransforms, get_im_reader_for_transforms +from oml.utils.dataframe_format import check_retrieval_dataframe_format +from oml.utils.images.images import TImReader, get_img_with_bbox + +# todo 522: general comment on Datasets +# We will remove using keys in __getitem__ for: +# Passing extra information (like categories or sequence id) -> we will use .extra_data instead +# Modality related info (like bboxes or paths) -> they may only exist as internals of the datasets +# is_query_key, is_gallery_key -> get_query_ids() and get_gallery_ids() methods +# Before this, we temporary keep both approaches + + +def parse_bboxes(df: pd.DataFrame) -> Optional[TBBoxes]: + n_existing_columns = sum([x in df for x in [X1_COLUMN, Y1_COLUMN, X2_COLUMN, Y2_COLUMN]]) + + if n_existing_columns == 4: + bboxes = [] + for _, row in df.iterrows(): + bbox = int(row[X1_COLUMN]), int(row[Y1_COLUMN]), int(row[X2_COLUMN]), int(row[Y2_COLUMN]) + bbox = None if any(coord is None for coord in bbox) else bbox + bboxes.append(bbox) + + elif n_existing_columns == 0: + bboxes = None + + else: + raise ValueError(f"Found {n_existing_columns} bounding bboxes columns instead of 4. Check your dataframe.") + + return bboxes + + +class ImageBaseDataset(IBaseDataset, IVisualizableDataset): + """ + The base class that handles image specific logic. + + """ + + def __init__( + self, + paths: List[Path], + dataset_root: Optional[Union[str, Path]] = None, + bboxes: Optional[TBBoxes] = None, + extra_data: Optional[Dict[str, Any]] = None, + transform: Optional[TTransforms] = None, + f_imread: Optional[TImReader] = None, + cache_size: Optional[int] = 0, + input_tensors_key: str = INPUT_TENSORS_KEY, + index_key: str = INDEX_KEY, + # todo 522: remove + paths_key: str = PATHS_KEY, + x1_key: str = X1_KEY, + x2_key: str = X2_KEY, + y1_key: str = Y1_KEY, + y2_key: str = Y2_KEY, + ): + """ + + Args: + paths: Paths to images. Will be concatenated with ``dataset_root`` if provided. + dataset_root: Path to the images' dir, set ``None`` if you provided the absolute paths in your dataframe + bboxes: Bounding boxes of images. Some of the images may not have bounding bboxes. + extra_data: Dictionary containing records of some additional information. 
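A toy illustration (a sketch, not part of the change) of the `parse_bboxes` helper introduced above; column names come from `oml.const`.

```python
import pandas as pd

from oml.const import X1_COLUMN, X2_COLUMN, Y1_COLUMN, Y2_COLUMN
from oml.datasets.images import parse_bboxes

df = pd.DataFrame(
    {
        X1_COLUMN: [10, 0],
        Y1_COLUMN: [20, 0],
        X2_COLUMN: [110, 64],
        Y2_COLUMN: [120, 64],
    }
)

print(parse_bboxes(df))  # [(10, 20, 110, 120), (0, 0, 64, 64)]

# with none of the four bbox columns present the helper returns None;
# having only some of them present is treated as an error
print(parse_bboxes(pd.DataFrame({"some_column": [1, 2]})))  # None
```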
+ transform: Augmentations for the images, set ``None`` to perform only normalisation and casting to tensor + f_imread: Function to read the images, pass ``None`` to pick it automatically based on provided transforms + cache_size: Size of the dataset's cache + input_tensors_key: Key to put tensors into the batches + index_key: Key to put samples' ids into the batches + paths_key: Key put paths into the batches # todo 522: remove + x1_key: Key to put ``x1`` into the batches # todo 522: remove + x2_key: Key to put ``x2`` into the batches # todo 522: remove + y1_key: Key to put ``y1`` into the batches # todo 522: remove + y2_key: Key to put ``y2`` into the batches # todo 522: remove + + """ + assert (bboxes is None) or (len(paths) == len(bboxes)) + + if extra_data is not None: + assert all( + len(record) == len(paths) for record in extra_data.values() + ), "All the extra records need to have the size equal to the dataset's size" + self.extra_data = extra_data + else: + self.extra_data = {} + + self.input_tensors_key = input_tensors_key + self.index_key = index_key + + if dataset_root is not None: + paths = list(map(lambda x: Path(dataset_root) / x, paths)) + + self._paths = list(map(str, paths)) + self._bboxes = bboxes + self._transform = transform if transform else get_transforms("norm_albu") + self._f_imread = f_imread or get_im_reader_for_transforms(self._transform) + + if cache_size: + self.read_bytes = lru_cache(maxsize=cache_size)(self._read_bytes) # type: ignore + else: + self.read_bytes = self._read_bytes # type: ignore + + available_transforms = (albu.Compose, torchvision.transforms.Compose) + assert isinstance(self._transform, available_transforms), f"Transforms must one of: {available_transforms}" + + # todo 522: remove + self.paths_key = paths_key + self.x1_key = x1_key + self.x2_key = x2_key + self.y1_key = y1_key + self.y2_key = y2_key + + @staticmethod + def _read_bytes(path: Union[Path, str]) -> bytes: + with open(str(path), "rb") as fin: + return fin.read() + + def __getitem__(self, item: int) -> Dict[str, Union[FloatTensor, int]]: + img_bytes = self.read_bytes(self._paths[item]) + img = self._f_imread(img_bytes) + + im_h, im_w = img.shape[:2] if isinstance(img, np.ndarray) else img.size[::-1] + + if (self._bboxes is not None) and (self._bboxes[item] is not None): + x1, y1, x2, y2 = self._bboxes[item] + else: + x1, y1, x2, y2 = 0, 0, im_w, im_h + + if isinstance(self._transform, albu.Compose): + img = img[y1:y2, x1:x2, :] + image_tensor = self._transform(image=img)["image"] + else: + # torchvision.transforms + img = img.crop((x1, y1, x2, y2)) + image_tensor = self._transform(img) + + data = { + self.input_tensors_key: image_tensor, + self.index_key: item, + } + + for key, record in self.extra_data.items(): + if key in data: + raise ValueError(f" and dataset share the same key: {key}") + else: + data[key] = record[item] + + # todo 522: remove + data[self.x1_key] = x1 + data[self.y1_key] = y1 + data[self.x2_key] = x2 + data[self.y2_key] = y2 + data[self.paths_key] = self._paths[item] + + return data + + def __len__(self) -> int: + return len(self._paths) + + def visualize(self, item: int, color: TColor = BLACK) -> np.ndarray: + bbox = torch.tensor(self._bboxes[item]) if (self._bboxes is not None) else torch.tensor([torch.nan] * 4) + image = get_img_with_bbox(im_path=self._paths[item], bbox=bbox, color=color) + + return image + + # todo 522: remove + @property + def bboxes_keys(self) -> Tuple[str, ...]: + return self.x1_key, self.y1_key, self.x2_key, self.y2_key + + +class 
ImageLabeledDataset(ImageBaseDataset, ILabeledDataset): + """ + The dataset of images having their ground truth labels. + + """ + + def __init__( + self, + df: pd.DataFrame, + extra_data: Optional[Dict[str, Any]] = None, + dataset_root: Optional[Union[str, Path]] = None, + transform: Optional[albu.Compose] = None, + f_imread: Optional[TImReader] = None, + cache_size: Optional[int] = 0, + input_tensors_key: str = INPUT_TENSORS_KEY, + labels_key: str = LABELS_KEY, + index_key: str = INDEX_KEY, + # todo 522: remove + paths_key: str = PATHS_KEY, + categories_key: Optional[str] = CATEGORIES_KEY, + sequence_key: Optional[str] = SEQUENCE_KEY, + x1_key: str = X1_KEY, + x2_key: str = X2_KEY, + y1_key: str = Y1_KEY, + y2_key: str = Y2_KEY, + ): + assert (x in df.columns for x in (LABELS_COLUMN, PATHS_COLUMN)) + self.labels_key = labels_key + self.df = df + + super().__init__( + paths=self.df[PATHS_COLUMN].tolist(), + bboxes=parse_bboxes(self.df), + extra_data=extra_data, + dataset_root=dataset_root, + transform=transform, + f_imread=f_imread, + cache_size=cache_size, + input_tensors_key=input_tensors_key, + index_key=index_key, + # todo 522: remove + x1_key=x1_key, + y2_key=y2_key, + x2_key=x2_key, + y1_key=y1_key, + paths_key=paths_key, + ) + + # todo 522: remove + self.categories_key = categories_key if (CATEGORIES_COLUMN in df.columns) else None + self.sequence_key = sequence_key if (SEQUENCE_COLUMN in df.columns) else None + + def __getitem__(self, item: int) -> Dict[str, Any]: + data = super().__getitem__(item) + data[self.labels_key] = self.df.iloc[item][LABELS_COLUMN] + + # todo 522: remove + if self.sequence_key: + data[self.sequence_key] = self.df[SEQUENCE_COLUMN][item] + + if self.categories_key: + data[self.categories_key] = self.df[CATEGORIES_COLUMN][item] + + return data + + def get_labels(self) -> np.ndarray: + return np.array(self.df[LABELS_COLUMN]) + + # todo 522: remove + def get_label2category(self) -> Optional[Dict[int, Union[str, int]]]: + if CATEGORIES_COLUMN in self.df.columns: + label2category = dict(zip(self.df[LABELS_COLUMN], self.df[CATEGORIES_COLUMN])) + else: + label2category = None + + return label2category + + +class ImageQueryGalleryLabeledDataset(ImageLabeledDataset, IQueryGalleryLabeledDataset): + """ + The annotated dataset of images having `query`/`gallery` split. + + Note, that some datasets used as benchmarks in Metric Learning + explicitly provide the splitting information (for example, ``DeepFashion InShop`` dataset), but some of them + don't (for example, ``CARS196`` or ``CUB200``). The validation idea for the latter is to perform `1 vs rest` + validation, where every query is evaluated versus the whole validation dataset (except for this exact query). + + So, if you want an item participate in validation as both: query and gallery, you should mark this item as + ``is_query == True`` and ``is_gallery == True``, as it's done in the `CARS196` or `CUB200` dataset. 
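To make the `1 vs rest` protocol described above concrete, a hedged sketch (paths and root are placeholders) of a validation dataframe where every item participates as both query and gallery:

```python
import pandas as pd

from oml.const import IS_GALLERY_COLUMN, IS_QUERY_COLUMN, LABELS_COLUMN, PATHS_COLUMN
from oml.datasets.images import ImageQueryGalleryLabeledDataset

df_val = pd.DataFrame(
    {
        PATHS_COLUMN: ["0.jpg", "1.jpg", "2.jpg"],
        LABELS_COLUMN: [0, 0, 1],
        IS_QUERY_COLUMN: [True, True, True],
        IS_GALLERY_COLUMN: [True, True, True],
    }
)

dataset = ImageQueryGalleryLabeledDataset(df=df_val, dataset_root="/path/to/images")
print(dataset.get_query_ids())    # tensor([0, 1, 2])
print(dataset.get_gallery_ids())  # tensor([0, 1, 2])
print(dataset.get_labels())       # [0 0 1]
```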
+ + """ + + def __init__( + self, + df: pd.DataFrame, + extra_data: Optional[Dict[str, Any]] = None, + dataset_root: Optional[Union[str, Path]] = None, + transform: Optional[albu.Compose] = None, + f_imread: Optional[TImReader] = None, + cache_size: Optional[int] = 0, + input_tensors_key: str = INPUT_TENSORS_KEY, + labels_key: str = LABELS_KEY, + # todo 522: remove + paths_key: str = PATHS_KEY, + categories_key: Optional[str] = CATEGORIES_KEY, + sequence_key: Optional[str] = SEQUENCE_KEY, + x1_key: str = X1_KEY, + x2_key: str = X2_KEY, + y1_key: str = Y1_KEY, + y2_key: str = Y2_KEY, + is_query_key: str = IS_QUERY_KEY, + is_gallery_key: str = IS_GALLERY_KEY, + ): + assert all(x in df.columns for x in (IS_QUERY_COLUMN, IS_GALLERY_COLUMN, LABELS_COLUMN, PATHS_COLUMN)) + self.df = df + + super().__init__( + df=df, + extra_data=extra_data, + dataset_root=dataset_root, + transform=transform, + f_imread=f_imread, + cache_size=cache_size, + input_tensors_key=input_tensors_key, + labels_key=labels_key, + # todo 522: remove + x1_key=x1_key, + y2_key=y2_key, + x2_key=x2_key, + y1_key=y1_key, + paths_key=paths_key, + categories_key=categories_key, + sequence_key=sequence_key, + ) + + # todo 522: remove + self.is_query_key = is_query_key + self.is_gallery_key = is_gallery_key + + def get_query_ids(self) -> LongTensor: + return BoolTensor(self.df[IS_QUERY_COLUMN]).nonzero().squeeze() + + def get_gallery_ids(self) -> LongTensor: + return BoolTensor(self.df[IS_GALLERY_COLUMN]).nonzero().squeeze() + + def __getitem__(self, idx: int) -> Dict[str, Any]: + data = super().__getitem__(idx) + data[self.labels_key] = self.df.iloc[idx][LABELS_COLUMN] + + # todo 522: remove + data[self.is_query_key] = bool(self.df[IS_QUERY_COLUMN][idx]) + data[self.is_gallery_key] = bool(self.df[IS_GALLERY_COLUMN][idx]) + + return data + + +class ImageQueryGalleryDataset(IVisualizableDataset, IQueryGalleryDataset): + """ + The NOT annotated dataset of images having `query`/`gallery` split. 
+ + """ + + def __init__( + self, + df: pd.DataFrame, + extra_data: Optional[Dict[str, Any]] = None, + dataset_root: Optional[Union[str, Path]] = None, + transform: Optional[albu.Compose] = None, + f_imread: Optional[TImReader] = None, + cache_size: Optional[int] = 0, + input_tensors_key: str = INPUT_TENSORS_KEY, + # todo 522: remove + paths_key: str = PATHS_KEY, + categories_key: Optional[str] = CATEGORIES_KEY, + sequence_key: Optional[str] = SEQUENCE_KEY, + x1_key: str = X1_KEY, + x2_key: str = X2_KEY, + y1_key: str = Y1_KEY, + y2_key: str = Y2_KEY, + is_query_key: str = IS_QUERY_KEY, + is_gallery_key: str = IS_GALLERY_KEY, + ): + assert all(x in df.columns for x in (IS_QUERY_COLUMN, IS_GALLERY_COLUMN, PATHS_COLUMN)) + # instead of implementing the whole logic let's just re-use QGL dataset, but with dropped labels + df = df.copy() + df[LABELS_COLUMN] = "fake_label" + + self.__dataset = ImageQueryGalleryLabeledDataset( + df=df, + extra_data=extra_data, + dataset_root=dataset_root, + transform=transform, + f_imread=f_imread, + cache_size=cache_size, + input_tensors_key=input_tensors_key, + labels_key=LABELS_COLUMN, + # todo 522: remove + x1_key=x1_key, + y2_key=y2_key, + x2_key=x2_key, + y1_key=y1_key, + paths_key=paths_key, + categories_key=categories_key, + sequence_key=sequence_key, + is_query_key=is_query_key, + is_gallery_key=is_gallery_key, + ) + + def __getitem__(self, item: int) -> Dict[str, Any]: + batch = self.__dataset[item] + del batch[self.__dataset.labels_key] + return batch + + def get_query_ids(self) -> LongTensor: + return self.__dataset.get_query_ids() + + def get_gallery_ids(self) -> LongTensor: + return self.__dataset.get_gallery_ids() + + def visualize(self, item: int, color: TColor = BLACK) -> np.ndarray: + return self.__dataset.visualize(item, color) + + +def get_retrieval_images_datasets( + dataset_root: Path, + transforms_train: Any, + transforms_val: Any, + f_imread_train: Optional[TImReader] = None, + f_imread_val: Optional[TImReader] = None, + dataframe_name: str = "df.csv", + cache_size: Optional[int] = 0, + verbose: bool = True, +) -> Tuple[ILabeledDataset, IQueryGalleryLabeledDataset]: + df = pd.read_csv(dataset_root / dataframe_name, index_col=False) + + check_retrieval_dataframe_format(df, dataset_root=dataset_root, verbose=verbose) + + # todo 522: why do we need it? 
+ # first half will consist of "train" split, second one of "val" + # so labels in train will be from 0 to N-1 and labels in test will be from N to K + mapper = {l: i for i, l in enumerate(df.sort_values(by=[SPLIT_COLUMN])[LABELS_COLUMN].unique())} + + # train + df_train = df[df[SPLIT_COLUMN] == "train"].reset_index(drop=True) + df_train[LABELS_COLUMN] = df_train[LABELS_COLUMN].map(mapper) + + train_dataset = ImageLabeledDataset( + df=df_train, + dataset_root=dataset_root, + transform=transforms_train, + cache_size=cache_size, + f_imread=f_imread_train, + ) + + # val (query + gallery) + df_query_gallery = df[df[SPLIT_COLUMN] == "validation"].reset_index(drop=True) + valid_dataset = ImageQueryGalleryLabeledDataset( + df=df_query_gallery, + dataset_root=dataset_root, + transform=transforms_val, + cache_size=cache_size, + f_imread=f_imread_val, + ) + + return train_dataset, valid_dataset + + +__all__ = [ + "ImageBaseDataset", + "ImageLabeledDataset", + "ImageQueryGalleryDataset", + "ImageQueryGalleryLabeledDataset", + "get_retrieval_images_datasets", +] diff --git a/oml/datasets/list_dataset.py b/oml/datasets/list_dataset.py deleted file mode 100644 index d430c9027..000000000 --- a/oml/datasets/list_dataset.py +++ /dev/null @@ -1,91 +0,0 @@ -from collections import defaultdict -from pathlib import Path -from typing import Any, Dict, Optional, Sequence, Tuple - -import pandas as pd -from torch.utils.data import Dataset - -from oml.const import ( - INDEX_KEY, - INPUT_TENSORS_KEY, - LABELS_COLUMN, - PATHS_COLUMN, - X1_COLUMN, - X2_COLUMN, - Y1_COLUMN, - Y2_COLUMN, -) -from oml.datasets.base import BaseDataset -from oml.transforms.images.torchvision import get_normalisation_torch -from oml.transforms.images.utils import TTransforms -from oml.utils.images.images import TImReader - -TBBox = Tuple[int, int, int, int] -TBBoxes = Sequence[Optional[TBBox]] - - -class ListDataset(Dataset): - """This is a dataset to iterate over a list of images.""" - - def __init__( - self, - filenames_list: Sequence[Path], - bboxes: Optional[TBBoxes] = None, - transform: TTransforms = get_normalisation_torch(), - f_imread: Optional[TImReader] = None, - input_tensors_key: str = INPUT_TENSORS_KEY, - cache_size: Optional[int] = 0, - index_key: str = INDEX_KEY, - ): - """ - Args: - filenames_list: list of paths to images - bboxes: Should be either ``None`` or a sequence of bboxes. - If an image has ``N`` boxes, duplicate its - path ``N`` times and provide bounding box for each of them. - If you want to get an embedding for the whole image, set bbox to ``None`` for - this particular image path. The format is ``x1, y1, x2, y2``. 
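For reference, a short usage sketch of `get_retrieval_images_datasets` (assuming the mock dataset has been downloaded; the chosen transforms are arbitrary):

```python
from oml.const import MOCK_DATASET_PATH
from oml.datasets.images import get_retrieval_images_datasets
from oml.registry.transforms import get_transforms

train_dataset, valid_dataset = get_retrieval_images_datasets(
    dataset_root=MOCK_DATASET_PATH,
    transforms_train=get_transforms("norm_albu"),
    transforms_val=get_transforms("norm_albu"),
    dataframe_name="df.csv",
)

print(len(train_dataset), len(valid_dataset))
```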
- transform: torchvision or albumentations augmentations - f_imread: Function to read images, pass ``None`` so we pick it automatically based on provided transforms - input_tensors_key: Key to put tensors into the batches - cache_size: cache_size: Size of the dataset's cache - index_key: Key to put samples' ids into the batches - - """ - data = defaultdict(list) - data[PATHS_COLUMN] = list(map(str, filenames_list)) - data[LABELS_COLUMN] = ["none"] * len(filenames_list) - - if bboxes is not None: - for bbox in bboxes: - if bbox is not None: - x1, y1, x2, y2 = bbox - else: - x1, y1, x2, y2 = None, None, None, None - - data[X1_COLUMN].append(x1) # type: ignore - data[Y1_COLUMN].append(y1) # type: ignore - data[X2_COLUMN].append(x2) # type: ignore - data[Y2_COLUMN].append(y2) # type: ignore - - self._dataset = BaseDataset( - df=pd.DataFrame(data), - transform=transform, - f_imread=f_imread, - input_tensors_key=input_tensors_key, - cache_size=cache_size, - index_key=index_key, - ) - - self.input_tensors_key = input_tensors_key - self.index_key = index_key - self.paths_key = self._dataset.paths_key - - def __getitem__(self, idx: int) -> Dict[str, Any]: - return self._dataset[idx] - - def __len__(self) -> int: - return len(self._dataset) - - -__all__ = ["ListDataset"] diff --git a/oml/datasets/pairs.py b/oml/datasets/pairs.py index bcd928ed1..0c7b44d10 100644 --- a/oml/datasets/pairs.py +++ b/oml/datasets/pairs.py @@ -3,13 +3,15 @@ from torch import Tensor -from oml.const import INDEX_KEY, PAIR_1ST_KEY, PAIR_2ND_KEY -from oml.datasets.list_dataset import ListDataset, TBBoxes +from oml.const import INDEX_KEY, PAIR_1ST_KEY, PAIR_2ND_KEY, TBBoxes +from oml.datasets.images import ImageBaseDataset from oml.interfaces.datasets import IPairsDataset from oml.transforms.images.torchvision import get_normalisation_torch from oml.transforms.images.utils import TTransforms from oml.utils.images.images import TImReader, imread_pillow +# todo 522: make one modality agnostic instead of these two + class EmbeddingPairsDataset(IPairsDataset): """ @@ -96,8 +98,8 @@ def __init__( cache_size = cache_size // 2 if cache_size else None dataset_args = {"transform": transform, "f_imread": f_imread, "cache_size": cache_size} - self.dataset1 = ListDataset(paths1, bboxes=bboxes1, **dataset_args) - self.dataset2 = ListDataset(paths2, bboxes=bboxes2, **dataset_args) + self.dataset1 = ImageBaseDataset(paths=paths1, bboxes=bboxes1, **dataset_args) + self.dataset2 = ImageBaseDataset(paths=paths2, bboxes=bboxes2, **dataset_args) self.pair_1st_key = pair_1st_key self.pair_2nd_key = pair_2nd_key diff --git a/oml/inference/flat.py b/oml/inference/flat.py index 8c1584a28..e8e4c063f 100644 --- a/oml/inference/flat.py +++ b/oml/inference/flat.py @@ -7,7 +7,7 @@ from torch import Tensor, nn from oml.const import PATHS_COLUMN, SPLIT_COLUMN -from oml.datasets.list_dataset import ListDataset +from oml.datasets.images import ImageBaseDataset from oml.inference.abstract import _inference from oml.interfaces.models import IExtractor from oml.transforms.images.utils import TTransforms @@ -28,7 +28,7 @@ def inference_on_images( use_fp16: bool = False, accumulate_on_cpu: bool = True, ) -> Tensor: - dataset = ListDataset(paths, bboxes=None, transform=transform, f_imread=f_imread, cache_size=0) + dataset = ImageBaseDataset(paths=paths, bboxes=None, transform=transform, f_imread=f_imread, cache_size=0) device = get_device(model) def _apply(model_: nn.Module, batch_: Dict[str, Any]) -> Tensor: diff --git a/oml/interfaces/datasets.py 
b/oml/interfaces/datasets.py index 369693445..e5c3009e8 100644 --- a/oml/interfaces/datasets.py +++ b/oml/interfaces/datasets.py @@ -2,29 +2,39 @@ from typing import Any, Dict import numpy as np +from torch import LongTensor from torch.utils.data import Dataset -from oml.const import ( # noqa - INDEX_KEY, - INPUT_TENSORS_KEY, - IS_GALLERY_KEY, - IS_QUERY_KEY, - LABELS_KEY, - PAIR_1ST_KEY, - PAIR_2ND_KEY, -) -from oml.samplers.balance import BalanceSampler # noqa +from oml.const import INDEX_KEY, LABELS_KEY, PAIR_1ST_KEY, PAIR_2ND_KEY, TColor -class IDatasetWithLabels(Dataset, ABC): +class IBaseDataset(Dataset): + input_tensors_key: str + index_key: str + extra_data: Dict[str, Any] + + def __getitem__(self, item: int) -> Dict[str, Any]: + """ + + Args: + item: Idx of the sample + + Returns: + Dictionary including the following keys: + ``self.input_tensors_key`` + ``self.index_key: int = item`` + + """ + raise NotImplementedError() + + +class ILabeledDataset(IBaseDataset, ABC): """ - This is an interface for the datasets which can provide their labels. + This is an interface for the datasets which provide labels of containing items. """ - input_tensors_key: str = INPUT_TENSORS_KEY labels_key: str = LABELS_KEY - index_key: str = INDEX_KEY def __getitem__(self, item: int) -> Dict[str, Any]: """ @@ -33,11 +43,9 @@ def __getitem__(self, item: int) -> Dict[str, Any]: item: Idx of the sample Returns: - Dictionary with the following keys: + Dictionary including the following keys: - ``self.input_tensors_key`` ``self.labels_key`` - ``self.index_key`` """ raise NotImplementedError() @@ -47,36 +55,27 @@ def get_labels(self) -> np.ndarray: raise NotImplementedError() -class IDatasetQueryGallery(Dataset, ABC): +class IQueryGalleryDataset(IBaseDataset, ABC): """ - This is an interface for the datasets which can provide the information on how to split - the validation set into the two parts: query and gallery. + This is an interface for the datasets which hold the information on how to split + the data into the query and gallery. The query and gallery ids may overlap. + It doesn't need the ground truth labels, so it can be used for prediction on not annotated data. """ - input_tensors_key: str = INPUT_TENSORS_KEY - labels_key: str = LABELS_KEY - is_query_key: str = IS_QUERY_KEY - is_gallery_key: str = IS_GALLERY_KEY - index_key: str = INDEX_KEY - @abstractmethod - def __getitem__(self, item: int) -> Dict[str, Any]: - """ - Args: - item: Idx of the sample + def get_query_ids(self) -> LongTensor: + raise NotImplementedError() - Returns: - Dictionary with the following keys: + @abstractmethod + def get_gallery_ids(self) -> LongTensor: + raise NotImplementedError() - ``self.input_tensors_key`` - ``self.labels_key`` - ``self.is_query_key`` - ``self.is_gallery_key`` - ``self.index_key`` - """ - raise NotImplementedError() +class IQueryGalleryLabeledDataset(IQueryGalleryDataset, ILabeledDataset, ABC): + """ + This interface is similar to `IQueryGalleryDataset`, but there are ground truth labels. + """ class IPairsDataset(Dataset, ABC): @@ -106,4 +105,21 @@ def __getitem__(self, item: int) -> Dict[str, Any]: raise NotImplementedError() -__all__ = ["IDatasetWithLabels", "IDatasetQueryGallery", "IPairsDataset"] +class IVisualizableDataset(Dataset, ABC): + """ + Base class for the datasets which know how to visualise their items. 
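A usage sketch for the new `IVisualizableDataset` interface (assuming the mock dataset with bounding boxes is available locally and matplotlib is installed):

```python
import matplotlib.pyplot as plt
import pandas as pd

from oml.const import GREEN, MOCK_DATASET_PATH
from oml.datasets.images import ImageLabeledDataset

df = pd.read_csv(MOCK_DATASET_PATH / "df_with_bboxes.csv")
dataset = ImageLabeledDataset(df=df, dataset_root=MOCK_DATASET_PATH)

# numpy array of the image with its bounding box drawn in the requested colour
image = dataset.visualize(item=0, color=GREEN)
plt.imshow(image)
plt.show()
```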
+ """ + + @abstractmethod + def visualize(self, item: int, color: TColor) -> np.ndarray: + raise NotImplementedError() + + +__all__ = [ + "IBaseDataset", + "ILabeledDataset", + "IQueryGalleryLabeledDataset", + "IQueryGalleryDataset", + "IPairsDataset", + "IVisualizableDataset", +] diff --git a/oml/lightning/pipelines/parser.py b/oml/lightning/pipelines/parser.py index 213e98a95..306e5ce7a 100644 --- a/oml/lightning/pipelines/parser.py +++ b/oml/lightning/pipelines/parser.py @@ -6,7 +6,7 @@ from pytorch_lightning.strategies import DDPStrategy from oml.const import TCfg -from oml.datasets.base import DatasetWithLabels +from oml.interfaces.datasets import ILabeledDataset from oml.interfaces.loggers import IPipelineLogger from oml.interfaces.samplers import IBatchSampler from oml.lightning.pipelines.logging import TensorBoardPipelineLogger @@ -76,7 +76,7 @@ def parse_scheduler_from_config(cfg: TCfg, optimizer: torch.optim.Optimizer) -> return scheduler_kwargs -def parse_sampler_from_config(cfg: TCfg, dataset: DatasetWithLabels) -> Optional[IBatchSampler]: +def parse_sampler_from_config(cfg: TCfg, dataset: ILabeledDataset) -> Optional[IBatchSampler]: if ( (not dataset.categories_key) and (cfg["sampler"] is not None) diff --git a/oml/lightning/pipelines/predict.py b/oml/lightning/pipelines/predict.py index 1e188187d..8a4e8f0dc 100644 --- a/oml/lightning/pipelines/predict.py +++ b/oml/lightning/pipelines/predict.py @@ -7,7 +7,7 @@ from torch.utils.data import DataLoader from oml.const import IMAGE_EXTENSIONS, TCfg -from oml.datasets.list_dataset import ListDataset +from oml.datasets.images import ImageBaseDataset from oml.ddp.utils import get_world_size_safe, is_main_process, sync_dicts_ddp from oml.lightning.modules.extractor import ExtractorModule from oml.lightning.pipelines.parser import parse_engine_params_from_config @@ -41,7 +41,7 @@ def extractor_prediction_pipeline(cfg: TCfg) -> None: if broken_images: raise ValueError(f"There are images that cannot be open:\n {broken_images}.") - dataset = ListDataset(filenames_list=filenames, transform=transforms, f_imread=f_imread) + dataset = ImageBaseDataset(paths=filenames, transform=transforms, f_imread=f_imread) loader = DataLoader( dataset=dataset, batch_size=cfg["bs"], num_workers=cfg["num_workers"], shuffle=False, drop_last=False diff --git a/oml/lightning/pipelines/train.py b/oml/lightning/pipelines/train.py index b59d8878b..dd86f7218 100644 --- a/oml/lightning/pipelines/train.py +++ b/oml/lightning/pipelines/train.py @@ -6,7 +6,7 @@ from torch.utils.data import DataLoader from oml.const import TCfg -from oml.datasets.base import get_retrieval_datasets +from oml.datasets.images import get_retrieval_images_datasets from oml.lightning.callbacks.metric import MetricValCallback, MetricValCallbackDDP from oml.lightning.modules.extractor import ExtractorModule, ExtractorModuleDDP from oml.lightning.pipelines.parser import ( @@ -26,7 +26,7 @@ def get_retrieval_loaders(cfg: TCfg) -> Tuple[DataLoader, DataLoader]: - train_dataset, valid_dataset = get_retrieval_datasets( + train_dataset, valid_dataset = get_retrieval_images_datasets( dataset_root=Path(cfg["dataset_root"]), transforms_train=get_transforms_by_cfg(cfg["transforms_train"]), transforms_val=get_transforms_by_cfg(cfg["transforms_val"]), diff --git a/oml/lightning/pipelines/train_postprocessor.py b/oml/lightning/pipelines/train_postprocessor.py index 97485328e..cc302941f 100644 --- a/oml/lightning/pipelines/train_postprocessor.py +++ b/oml/lightning/pipelines/train_postprocessor.py @@ -11,7 
+11,7 @@ from torch.utils.data import DataLoader from oml.const import BBOXES_COLUMNS, EMBEDDINGS_KEY, TCfg -from oml.datasets.base import DatasetQueryGallery, DatasetWithLabels +from oml.datasets.base import ImageLabeledDataset, ImageQueryGalleryLabeledDataset from oml.inference.flat import inference_on_dataframe from oml.interfaces.models import IPairwiseModel from oml.lightning.callbacks.metric import MetricValCallback, MetricValCallbackDDP @@ -81,13 +81,13 @@ def get_loaders_with_embeddings(cfg: TCfg) -> Tuple[DataLoader, DataLoader]: use_fp16=int(cfg.get("precision", 32)) == 16, ) - train_dataset = DatasetWithLabels( + train_dataset = ImageLabeledDataset( df=df_train, transform=get_transforms_by_cfg(cfg["transforms_train"]), extra_data={EMBEDDINGS_KEY: emb_train}, ) - valid_dataset = DatasetQueryGallery( + valid_dataset = ImageQueryGalleryLabeledDataset( df=df_val, # we don't care about transforms, since the only goal of this dataset is to deliver embeddings transform=get_normalisation_resize_torch(im_size=8), diff --git a/oml/lightning/pipelines/validate.py b/oml/lightning/pipelines/validate.py index b60edcf95..543522bf4 100644 --- a/oml/lightning/pipelines/validate.py +++ b/oml/lightning/pipelines/validate.py @@ -6,7 +6,7 @@ from torch.utils.data import DataLoader from oml.const import TCfg -from oml.datasets.base import get_retrieval_datasets +from oml.datasets.images import get_retrieval_images_datasets from oml.lightning.callbacks.metric import MetricValCallback, MetricValCallbackDDP from oml.lightning.modules.extractor import ExtractorModule, ExtractorModuleDDP from oml.lightning.pipelines.parser import ( @@ -35,7 +35,7 @@ def extractor_validation_pipeline(cfg: TCfg) -> Tuple[pl.Trainer, Dict[str, Any] pprint(cfg) - _, valid_dataset = get_retrieval_datasets( + _, valid_dataset = get_retrieval_images_datasets( dataset_root=Path(cfg["dataset_root"]), transforms_train=None, transforms_val=get_transforms_by_cfg(cfg["transforms_val"]), diff --git a/tests/test_integrations/test_lightning/test_pipeline.py b/tests/test_integrations/test_lightning/test_pipeline.py index 3f99822b7..b7c408e72 100644 --- a/tests/test_integrations/test_lightning/test_pipeline.py +++ b/tests/test_integrations/test_lightning/test_pipeline.py @@ -1,3 +1,4 @@ +import sys import tempfile from functools import partial from typing import Any, Dict, List @@ -27,9 +28,9 @@ def __init__(self, labels: List[int], im_size: int): self.labels = labels self.im_size = im_size - def __getitem__(self, idx: int) -> Dict[str, Any]: + def __getitem__(self, item: int) -> Dict[str, Any]: input_tensors = torch.rand((3, self.im_size, self.im_size)) - label = torch.tensor(self.labels[idx]).long() + label = torch.tensor(self.labels[item]).long() return {INPUT_TENSORS_KEY: input_tensors, LABELS_KEY: label, IS_QUERY_KEY: True, IS_GALLERY_KEY: True} def __len__(self) -> int: @@ -92,6 +93,7 @@ def create_retrieval_callback(loader_idx: int, samples_in_getitem: int) -> Metri return metric_callback +@pytest.mark.skipif(sys.platform == "darwin", reason="Does not run on macOS") @pytest.mark.parametrize( "samples_in_getitem, is_error_expected, pipeline", [ diff --git a/tests/test_integrations/test_lightning/test_train_with_sequence.py b/tests/test_integrations/test_lightning/test_train_with_sequence.py index 3371a2192..39bc24188 100644 --- a/tests/test_integrations/test_lightning/test_train_with_sequence.py +++ b/tests/test_integrations/test_lightning/test_train_with_sequence.py @@ -5,7 +5,7 @@ from tqdm import tqdm from oml.const import 
LABELS_COLUMN, MOCK_DATASET_PATH, SEQUENCE_COLUMN -from oml.datasets.base import DatasetQueryGallery +from oml.datasets.images import ImageQueryGalleryLabeledDataset from oml.metrics.embeddings import EmbeddingMetrics, TMetricsDict_ByLabels from oml.utils.download_mock_dataset import download_mock_dataset from oml.utils.misc import compare_dicts_recursively, set_global_seed @@ -15,7 +15,7 @@ def validation(df: pd.DataFrame) -> TMetricsDict_ByLabels: set_global_seed(42) extractor = nn.Flatten() - val_dataset = DatasetQueryGallery(df, dataset_root=MOCK_DATASET_PATH) + val_dataset = ImageQueryGalleryLabeledDataset(df, dataset_root=MOCK_DATASET_PATH) val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, num_workers=0) calculator = EmbeddingMetrics(extra_keys=("paths",), sequence_key=val_dataset.sequence_key, cmc_top_k=(1,)) diff --git a/tests/test_integrations/test_retrieval_validation.py b/tests/test_integrations/test_retrieval_validation.py index e03bd798b..91f9ee6d3 100644 --- a/tests/test_integrations/test_retrieval_validation.py +++ b/tests/test_integrations/test_retrieval_validation.py @@ -1,24 +1,19 @@ from math import isclose -from typing import Any, Dict, Tuple +from typing import Tuple import pytest import torch -from torch import Tensor +from torch import BoolTensor, FloatTensor, LongTensor from torch.utils.data import DataLoader -from oml.const import ( - EMBEDDINGS_KEY, - INPUT_TENSORS_KEY, - IS_GALLERY_KEY, - IS_QUERY_KEY, - LABELS_KEY, - OVERALL_CATEGORIES_KEY, -) -from oml.interfaces.datasets import IDatasetQueryGallery +from oml.const import EMBEDDINGS_KEY, INPUT_TENSORS_KEY, OVERALL_CATEGORIES_KEY from oml.metrics.embeddings import EmbeddingMetrics -from tests.test_integrations.utils import IdealClusterEncoder +from tests.test_integrations.utils import ( + EmbeddingsQueryGalleryLabeledDataset, + IdealClusterEncoder, +) -TData = Tuple[Tensor, Tensor, Tensor, Tensor, float] +TData = Tuple[LongTensor, BoolTensor, BoolTensor, FloatTensor, float] def get_separate_query_gallery() -> TData: @@ -33,7 +28,7 @@ def get_separate_query_gallery() -> TData: cmc_gt = 3 / 5 - return labels, query_mask, gallery_mask, input_tensors, cmc_gt + return labels.long(), query_mask.bool(), gallery_mask.bool(), input_tensors, cmc_gt def get_shared_query_gallery() -> TData: @@ -46,28 +41,7 @@ def get_shared_query_gallery() -> TData: cmc_gt = 7 / 8 - return labels, query_mask, gallery_mask, input_tensors, cmc_gt - - -class DummyQGDataset(IDatasetQueryGallery): - def __init__(self, labels: Tensor, gallery_mask: Tensor, query_mask: Tensor, input_tensors: Tensor): - assert len(labels) == len(gallery_mask) == len(query_mask) - - self.labels = labels - self.gallery_mask = gallery_mask - self.query_mask = query_mask - self.input_tensors = input_tensors - - def __getitem__(self, idx: int) -> Dict[str, Any]: - return { - LABELS_KEY: self.labels[idx], - INPUT_TENSORS_KEY: self.input_tensors[idx], - IS_QUERY_KEY: self.query_mask[idx], - IS_GALLERY_KEY: self.gallery_mask[idx], - } - - def __len__(self) -> int: - return len(self.labels) + return labels.long(), query_mask.bool(), gallery_mask.bool(), input_tensors, cmc_gt @pytest.mark.parametrize("batch_size", [1, 5]) @@ -77,11 +51,11 @@ def __len__(self) -> int: def test_retrieval_validation(batch_size: int, shuffle: bool, num_workers: int, data: TData) -> None: labels, query_mask, gallery_mask, input_tensors, cmc_gt = data - dataset = DummyQGDataset( + dataset = EmbeddingsQueryGalleryLabeledDataset( labels=labels, - input_tensors=input_tensors, - 
query_mask=query_mask, - gallery_mask=gallery_mask, + embeddings=input_tensors, + is_query=query_mask, + is_gallery=gallery_mask, ) loader = DataLoader( diff --git a/tests/test_integrations/test_train_with_mining.py b/tests/test_integrations/test_train_with_mining.py index a869cf2b8..09f96cf09 100644 --- a/tests/test_integrations/test_train_with_mining.py +++ b/tests/test_integrations/test_train_with_mining.py @@ -8,7 +8,7 @@ from torch.utils.data import DataLoader from oml.const import INPUT_TENSORS_KEY, LABELS_KEY -from oml.interfaces.datasets import IDatasetWithLabels +from oml.interfaces.datasets import ILabeledDataset from oml.losses.triplet import TripletLossWithMiner from oml.miners.cross_batch import TripletMinerWithMemory from oml.registry.miners import get_miner @@ -16,15 +16,15 @@ from tests.test_integrations.utils import IdealOneHotModel -class DummyDataset(IDatasetWithLabels): +class DummyDataset(ILabeledDataset): def __init__(self, n_labels: int, n_samples_min: int): self.labels = [] for i in range(n_labels): self.labels.extend([i] * randint(n_samples_min, 2 * n_samples_min)) shuffle(self.labels) - def __getitem__(self, idx: int) -> Dict[str, Any]: - return {INPUT_TENSORS_KEY: torch.tensor(self.labels[idx]), LABELS_KEY: self.labels[idx]} + def __getitem__(self, item: int) -> Dict[str, Any]: + return {INPUT_TENSORS_KEY: torch.tensor(self.labels[item]), LABELS_KEY: self.labels[item]} def __len__(self) -> int: return len(self.labels) diff --git a/tests/test_integrations/utils.py b/tests/test_integrations/utils.py index 48d4fa7a2..e3a65cb96 100644 --- a/tests/test_integrations/utils.py +++ b/tests/test_integrations/utils.py @@ -1,6 +1,19 @@ +from typing import Any, Dict, Optional + +import numpy as np import torch -from torch import nn +from torch import BoolTensor, FloatTensor, LongTensor, nn +from oml.const import ( + CATEGORIES_KEY, + INDEX_KEY, + INPUT_TENSORS_KEY, + IS_GALLERY_KEY, + IS_QUERY_KEY, + LABELS_KEY, + SEQUENCE_KEY, +) +from oml.interfaces.datasets import IQueryGalleryDataset, IQueryGalleryLabeledDataset from oml.utils.misc import one_hot @@ -20,3 +33,103 @@ def forward(self, labels: torch.Tensor, need_noise: bool = True) -> torch.Tensor embeddings = labels + need_noise * 0.01 * torch.randn_like(labels, dtype=torch.float) embeddings = embeddings.view((len(labels), 1)) return embeddings + + +class EmbeddingsQueryGalleryDataset(IQueryGalleryDataset): + def __init__( + self, + embeddings: FloatTensor, + is_query: BoolTensor, + is_gallery: BoolTensor, + categories: Optional[np.ndarray] = None, + sequence: Optional[np.ndarray] = None, + input_tensors_key: str = INPUT_TENSORS_KEY, + index_key: str = INDEX_KEY, + # todo 522: remove keys later + categories_key: str = CATEGORIES_KEY, + sequence_key: str = SEQUENCE_KEY, + ): + super().__init__() + assert len(embeddings) == len(is_query) == len(is_gallery) + + self._embeddings = embeddings + self._is_query = is_query + self._is_gallery = is_gallery + + # todo 522: remove keys + self.categories_key = categories_key + self.sequence_key = sequence_key + + self.extra_data = {} + if categories is not None: + self.extra_data[self.categories_key] = categories + + if sequence is not None: + self.extra_data[self.sequence_key] = sequence + + self.input_tensors_key = input_tensors_key + self.index_key = index_key + + def __getitem__(self, item: int) -> Dict[str, Any]: + data = { + self.input_tensors_key: self._embeddings[item], + self.index_key: item, + # todo 522: remove + IS_QUERY_KEY: self._is_query[item], + IS_GALLERY_KEY: 
self._is_gallery[item], + } + + # todo 522: avoid passing extra data as keys + for key, record in self.extra_data.items(): + if key in data: + raise ValueError(f" and dataset share the same key: {key}") + else: + data[key] = record[item] + + return data + + def __len__(self) -> int: + return len(self._embeddings) + + def get_query_ids(self) -> LongTensor: + return self._is_query.nonzero().squeeze() + + def get_gallery_ids(self) -> LongTensor: + return self._is_gallery.nonzero().squeeze() + + +class EmbeddingsQueryGalleryLabeledDataset(EmbeddingsQueryGalleryDataset, IQueryGalleryLabeledDataset): + def __init__( + self, + embeddings: FloatTensor, + labels: LongTensor, + is_query: BoolTensor, + is_gallery: BoolTensor, + categories: Optional[np.ndarray] = None, + sequence: Optional[np.ndarray] = None, + input_tensors_key: str = INPUT_TENSORS_KEY, + labels_key: str = LABELS_KEY, + index_key: str = INDEX_KEY, + ): + super().__init__( + embeddings=embeddings, + is_query=is_query, + is_gallery=is_gallery, + categories=categories, + sequence=sequence, + input_tensors_key=input_tensors_key, + index_key=index_key, + ) + + assert len(embeddings) == len(labels) + + self._labels = labels + self.labels_key = labels_key + + def __getitem__(self, item: int) -> Dict[str, Any]: + data = super().__getitem__(item) + data[self.labels_key] = self._labels[item] + return data + + def get_labels(self) -> np.ndarray: + return np.array(self._labels) diff --git a/tests/test_oml/test_datasets/test_list_dataest.py b/tests/test_oml/test_datasets/test_list_dataest.py deleted file mode 100644 index b0d6391ee..000000000 --- a/tests/test_oml/test_datasets/test_list_dataest.py +++ /dev/null @@ -1,79 +0,0 @@ -from pathlib import Path -from typing import Iterator, List, Optional, Sequence, Tuple - -import pandas as pd -import pytest -import torch -from torch.utils.data import DataLoader - -from oml.const import MOCK_DATASET_PATH -from oml.datasets.list_dataset import ListDataset, TBBox - - -@pytest.fixture -def images() -> Iterator[List[Path]]: - yield list((MOCK_DATASET_PATH / "images").iterdir()) - - -def get_images_and_boxes() -> Tuple[List[Path], List[TBBox]]: - df = pd.read_csv(MOCK_DATASET_PATH / "df_with_bboxes.csv") - sub_df = df[["path", "x_1", "y_1", "x_2", "y_2"]] - sub_df["path"] = sub_df["path"].apply(lambda p: MOCK_DATASET_PATH / p) - paths, bboxes = [], [] - for row in sub_df.iterrows(): - path, x1, y1, x2, y2 = row[1] - paths.append(path) - bboxes.append((x1, y1, x2, y2)) - return paths, bboxes - - -def get_images_and_boxes_with_nones() -> Tuple[List[Path], List[Optional[TBBox]]]: - import random - - random.seed(42) - - paths, bboxes = [], [] - for path, bbox in zip(*get_images_and_boxes()): - paths.append(path) - bboxes.append(bbox) - if random.random() > 0.5: - paths.append(path) - bboxes.append(None) - return paths, bboxes - - -def test_dataset_len(images: List[Path]) -> None: - assert len(images) > 0 - dataset = ListDataset(images) - - assert len(dataset) == len(images) - - -def test_dataset_iter(images: List[Path]) -> None: - dataset = ListDataset(images) - - for batch in dataset: - assert isinstance(batch[dataset.input_tensors_key], torch.Tensor) - - -def test_dataloader_iter(images: List[Path]) -> None: - dataset = ListDataset(images) - dataloader = DataLoader(dataset) - - for batch in dataloader: - assert batch[dataset.input_tensors_key].ndim == 4 - - -@pytest.mark.parametrize("im_paths,bboxes", [get_images_and_boxes(), get_images_and_boxes_with_nones()]) -def test_list_dataset_iter(im_paths: 
Sequence[Path], bboxes: Sequence[Optional[TBBox]]) -> None: - dataset = ListDataset(im_paths, bboxes) - - dataloader = DataLoader(dataset) - for batch, box in zip(dataloader, bboxes): - image = batch[dataset.input_tensors_key] - if box is not None: - x1, y1, x2, y2 = box - else: - x1, y1, x2, y2 = 0, 0, image.size()[2], image.size()[3] - assert image.ndim == 4 - assert image.size() == (1, 3, x2 - x1, y2 - y1) diff --git a/tests/test_oml/test_registry/test_registry.py b/tests/test_oml/test_registry/test_registry.py index 6654d5722..76eb71342 100644 --- a/tests/test_oml/test_registry/test_registry.py +++ b/tests/test_oml/test_registry/test_registry.py @@ -140,9 +140,10 @@ def test_saving_transforms_as_files() -> None: """ cfg = yaml.safe_load(cfg) - save_transforms_as_files(cfg) + names_files = save_transforms_as_files(cfg) - assert Path("transforms_train.yaml").exists() - assert not Path("transforms_val.yaml").exists() + assert len(names_files) == 1, "Check that we only saved train transforms as expected" - Path("transforms_train.yaml").unlink() + file = Path(names_files[0][1]) + assert file.exists() + Path(file).unlink() diff --git a/tests/test_oml/test_transforms/test_image_augs.py b/tests/test_oml/test_transforms/test_image_augs.py index fcbb1005b..6773a5bae 100644 --- a/tests/test_oml/test_transforms/test_image_augs.py +++ b/tests/test_oml/test_transforms/test_image_augs.py @@ -5,7 +5,7 @@ from omegaconf import OmegaConf from oml.const import CONFIGS_PATH, MOCK_DATASET_PATH -from oml.datasets.base import DatasetWithLabels +from oml.datasets.images import ImageLabeledDataset from oml.registry.transforms import TRANSFORMS_REGISTRY, get_transforms_by_cfg @@ -14,7 +14,7 @@ def test_transforms(aug_name: Optional[str]) -> None: df = pd.read_csv(MOCK_DATASET_PATH / "df.csv") transforms = get_transforms_by_cfg(OmegaConf.load(CONFIGS_PATH / "transforms" / f"{aug_name}.yaml")) - dataset = DatasetWithLabels(df=df, dataset_root=MOCK_DATASET_PATH, transform=transforms) + dataset = ImageLabeledDataset(df=df, dataset_root=MOCK_DATASET_PATH, transform=transforms) _ = dataset[0] @@ -23,6 +23,6 @@ def test_transforms(aug_name: Optional[str]) -> None: def test_default_transforms() -> None: df = pd.read_csv(MOCK_DATASET_PATH / "df.csv") - dataset = DatasetWithLabels(df=df, dataset_root=MOCK_DATASET_PATH, transform=None) + dataset = ImageLabeledDataset(df=df, dataset_root=MOCK_DATASET_PATH, transform=None) _ = dataset[0] assert True diff --git a/tests/test_runs/test_pipelines/configs/validate.yaml b/tests/test_runs/test_pipelines/configs/validate.yaml index 942e99296..5096f57a6 100644 --- a/tests/test_runs/test_pipelines/configs/validate.yaml +++ b/tests/test_runs/test_pipelines/configs/validate.yaml @@ -2,7 +2,7 @@ accelerator: cpu devices: 1 dataset_root: path_to_replace # we will replace it in runtime with the default dataset folder -dataframe_name: df.csv +dataframe_name: df_with_bboxes.csv transforms_val: name: norm_resize_torch