diff --git a/README.md b/README.md
index 5ca5ea938..6ac9c5c63 100644
--- a/README.md
+++ b/README.md
@@ -23,7 +23,8 @@ OML is a PyTorch-based framework to train and validate the models producing high
 ㅤㅤ
 ㅤㅤ
 ㅤㅤ
-
+ㅤㅤ
+
 ㅤㅤ
diff --git a/docs/readme/header.md b/docs/readme/header.md
index 426c555a7..e38011127 100644
--- a/docs/readme/header.md
+++ b/docs/readme/header.md
@@ -23,7 +23,8 @@ OML is a PyTorch-based framework to train and validate the models producing high
 ㅤㅤ
 ㅤㅤ
 ㅤㅤ
-
+ㅤㅤ
+
 ㅤㅤ
diff --git a/oml/datasets/dataframe.py b/oml/datasets/dataframe.py
new file mode 100644
index 000000000..f0f2054a5
--- /dev/null
+++ b/oml/datasets/dataframe.py
@@ -0,0 +1,149 @@
+from typing import Any, Dict, Optional, Union
+
+import numpy as np
+import pandas as pd
+from torch import BoolTensor, LongTensor
+
+from oml.const import (
+    CATEGORIES_COLUMN,
+    IS_GALLERY_COLUMN,
+    IS_QUERY_COLUMN,
+    LABELS_COLUMN,
+    LABELS_KEY,
+    SEQUENCE_COLUMN,
+)
+from oml.interfaces.datasets import (
+    IBaseDataset,
+    ILabeledDataset,
+    IQueryGalleryDataset,
+    IQueryGalleryLabeledDataset,
+)
+
+
+def update_dataset_extra_data(dataset: IBaseDataset, df: pd.DataFrame, extra_data: Optional[Dict[str, Any]] = None) -> IBaseDataset:
+    extra_data = extra_data or dict()  # fix: was "dict() or extra_data", which always yields extra_data and crashes below when it is None
+
+    if CATEGORIES_COLUMN in df.columns:
+        extra_data[CATEGORIES_COLUMN] = df[CATEGORIES_COLUMN].copy()
+
+    if SEQUENCE_COLUMN in df.columns:
+        extra_data[SEQUENCE_COLUMN] = df[SEQUENCE_COLUMN].copy()
+
+    dataset.extra_data.update(extra_data)
+
+    return dataset
+
+
+def label_to_category(df: pd.DataFrame) -> Optional[Dict[Union[str, int], str]]:
+    if CATEGORIES_COLUMN in df.columns:
+        label2category = dict(zip(df[LABELS_COLUMN], df[CATEGORIES_COLUMN]))
+    else:
+        label2category = None
+
+    return label2category
+
+
+class DFLabeledDataset(ILabeledDataset):
+    def __init__(
+        self,
+        dataset: IBaseDataset,
+        df: pd.DataFrame,
+        extra_data: Optional[Dict[str, Any]] = None,
+        labels_key: str = LABELS_KEY,
+    ):
+        assert LABELS_COLUMN in df.columns
+
+        self._dataset = update_dataset_extra_data(dataset, df, extra_data)  # single underscore: mangled "__dataset" is unreachable from subclasses
+
+        self.df = df
+        self.labels_key = labels_key
+        self.index_key = self._dataset.index_key
+        self.input_tensors_key = self._dataset.input_tensors_key
+
+    def __len__(self) -> int:
+        return len(self._dataset)
+
+    def __getitem__(self, item: int) -> Dict[str, Any]:
+        data = self._dataset[item]
+        data[self.labels_key] = self.df.iloc[item][LABELS_COLUMN]
+        return data
+
+    def get_labels(self) -> np.ndarray:
+        return np.array(self.df[LABELS_COLUMN])
+
+    def get_label2category(self) -> Optional[Dict[int, Union[str, int]]]:
+        return label_to_category(self.df)
+
+
+class DFQueryGalleryDataset(IQueryGalleryDataset):
+    def __init__(
+        self,
+        dataset: IBaseDataset,
+        df: pd.DataFrame,
+        extra_data: Optional[Dict[str, Any]] = None,
+    ):
+        assert all(x in df.columns for x in (IS_QUERY_COLUMN, IS_GALLERY_COLUMN))
+
+        self._dataset = update_dataset_extra_data(dataset, df, extra_data)
+
+        self.df = df
+        self._query_ids = BoolTensor(self.df[IS_QUERY_COLUMN]).nonzero().squeeze()
+        self._gallery_ids = BoolTensor(self.df[IS_GALLERY_COLUMN]).nonzero().squeeze()
+
+        self.index_key = self._dataset.index_key
+        self.input_tensors_key = self._dataset.input_tensors_key
+
+    def __len__(self) -> int:
+        return len(self._dataset)
+
+    def __getitem__(self, item: int) -> Dict[str, Any]:
+        return self._dataset[item]
+
+    def get_query_ids(self) -> LongTensor:
+        return self._query_ids
+
+    def get_gallery_ids(self) -> LongTensor:
+        return self._gallery_ids
+
+
+class DFQueryGalleryLabeledDataset(IQueryGalleryLabeledDataset):
+    def __init__(
+        self,
+        dataset: IBaseDataset,
+        df: pd.DataFrame,
+        extra_data: Optional[Dict[str, Any]] = None,
+        labels_key: str = LABELS_KEY,
+    ):
+
+        assert all(x in df.columns for x in (IS_QUERY_COLUMN, IS_GALLERY_COLUMN, LABELS_COLUMN))
+
+        self._dataset = update_dataset_extra_data(dataset, df, extra_data)
+
+        self.df = df
+        self._query_ids = BoolTensor(self.df[IS_QUERY_COLUMN]).nonzero().squeeze()
+        self._gallery_ids = BoolTensor(self.df[IS_GALLERY_COLUMN]).nonzero().squeeze()
+
+        self.index_key = self._dataset.index_key
+        self.input_tensors_key = self._dataset.input_tensors_key
+        self.labels_key = labels_key
+
+    def __len__(self) -> int:
+        return len(self._dataset)
+
+    def __getitem__(self, item: int) -> Dict[str, Any]:
+        return self._dataset[item]
+
+    def get_query_ids(self) -> LongTensor:
+        return self._query_ids
+
+    def get_gallery_ids(self) -> LongTensor:
+        return self._gallery_ids
+
+    def get_labels(self) -> np.ndarray:
+        return np.array(self.df[LABELS_COLUMN])
+
+    def get_label2category(self) -> Optional[Dict[int, Union[str, int]]]:
+        return label_to_category(self.df)
+
+
+__all__ = ["DFLabeledDataset", "DFQueryGalleryDataset", "DFQueryGalleryLabeledDataset"]
diff --git a/oml/datasets/images.py b/oml/datasets/images.py
index aa6332ff9..630a1408f 100644
--- a/oml/datasets/images.py
+++ b/oml/datasets/images.py
@@ -28,6 +28,11 @@
     TBBoxes,
     TColor,
 )
+from oml.datasets.dataframe import (
+    DFLabeledDataset,
+    DFQueryGalleryDataset,
+    DFQueryGalleryLabeledDataset,
+)
 from oml.interfaces.datasets import (
     IBaseDataset,
     ILabeledDataset,
@@ -169,7 +174,7 @@ def visualize(self, item: int, color: TColor = BLACK) -> np.ndarray:
         return image
 
 
-class ImageLabeledDataset(ImageBaseDataset, ILabeledDataset):
+class ImageLabeledDataset(DFLabeledDataset, IVisualizableDataset):
     """
     The dataset of images having their ground truth labels.
@@ -187,22 +192,10 @@ def __init__(
         labels_key: str = LABELS_KEY,
         index_key: str = INDEX_KEY,
     ):
-        assert all(x in df.columns for x in (LABELS_COLUMN, PATHS_COLUMN))
-        self.labels_key = labels_key
-        self.df = df
-
-        extra_data = extra_data or dict()
-
-        if CATEGORIES_COLUMN in df.columns:
-            extra_data[CATEGORIES_COLUMN] = df[CATEGORIES_COLUMN].copy()
-
-        if SEQUENCE_COLUMN in df.columns:
-            extra_data[SEQUENCE_COLUMN] = df[SEQUENCE_COLUMN].copy()
-
-        super().__init__(
-            paths=self.df[PATHS_COLUMN].tolist(),
-            bboxes=parse_bboxes(self.df),
-            extra_data=extra_data,
+        self._dataset = dataset = ImageBaseDataset(  # own handle kept for visualize()
+            paths=df[PATHS_COLUMN].tolist(),
+            bboxes=parse_bboxes(df),
+            extra_data=None,
             dataset_root=dataset_root,
             transform=transform,
             f_imread=f_imread,
@@ -210,25 +203,13 @@ def __init__(
             input_tensors_key=input_tensors_key,
             index_key=index_key,
         )
+        super().__init__(dataset=dataset, df=df, extra_data=extra_data, labels_key=labels_key)
 
-    def __getitem__(self, item: int) -> Dict[str, Any]:
-        data = super().__getitem__(item)
-        data[self.labels_key] = self.df.iloc[item][LABELS_COLUMN]
-        return data
-
-    def get_labels(self) -> np.ndarray:
-        return np.array(self.df[LABELS_COLUMN])
-
-    def get_label2category(self) -> Optional[Dict[int, Union[str, int]]]:
-        if CATEGORIES_COLUMN in self.df.columns:
-            label2category = dict(zip(self.df[LABELS_COLUMN], self.df[CATEGORIES_COLUMN]))
-        else:
-            label2category = None
-
-        return label2category
+    def visualize(self, item: int, color: TColor = BLACK) -> np.ndarray:
+        return self._dataset.visualize(item=item, color=color)
 
 
-class ImageQueryGalleryLabeledDataset(ImageLabeledDataset, IQueryGalleryLabeledDataset):
+class ImageQueryGalleryLabeledDataset(DFQueryGalleryLabeledDataset, IVisualizableDataset):
     """
     The annotated dataset of images having `query`/`gallery` split.
@@ -253,31 +234,24 @@ def __init__(
         input_tensors_key: str = INPUT_TENSORS_KEY,
         labels_key: str = LABELS_KEY,
     ):
-        assert all(x in df.columns for x in (IS_QUERY_COLUMN, IS_GALLERY_COLUMN, LABELS_COLUMN, PATHS_COLUMN))
-        self.df = df
-
-        super().__init__(
-            df=df,
-            extra_data=extra_data,
+        self._dataset = dataset = ImageBaseDataset(  # own handle kept for visualize()
+            paths=df[PATHS_COLUMN].tolist(),
+            bboxes=parse_bboxes(df),
+            extra_data=None,
             dataset_root=dataset_root,
             transform=transform,
             f_imread=f_imread,
             cache_size=cache_size,
             input_tensors_key=input_tensors_key,
-            labels_key=labels_key,
         )
-        self._query_ids = BoolTensor(self.df[IS_QUERY_COLUMN]).nonzero().squeeze()
-        self._gallery_ids = BoolTensor(self.df[IS_GALLERY_COLUMN]).nonzero().squeeze()
+        super().__init__(dataset=dataset, df=df, extra_data=extra_data, labels_key=labels_key)
 
-    def get_query_ids(self) -> LongTensor:
-        return self._query_ids
-
-    def get_gallery_ids(self) -> LongTensor:
-        return self._gallery_ids
+    def visualize(self, item: int, color: TColor = BLACK) -> np.ndarray:
+        return self._dataset.visualize(item=item, color=color)
 
 
-class ImageQueryGalleryDataset(IVisualizableDataset, IQueryGalleryDataset):
+class ImageQueryGalleryDataset(DFQueryGalleryDataset, IVisualizableDataset):
     """
     The NOT annotated dataset of images having `query`/`gallery` split.
@@ -293,42 +267,19 @@ def __init__(
         cache_size: Optional[int] = 0,
         input_tensors_key: str = INPUT_TENSORS_KEY,
     ):
-        assert all(x in df.columns for x in (IS_QUERY_COLUMN, IS_GALLERY_COLUMN, PATHS_COLUMN))
-
-        # instead of implementing the whole logic let's just re-use QGL dataset, but with dropped labels
-        df = deepcopy(df)
-        df[LABELS_COLUMN] = "fake_label"
-
-        self.__dataset = ImageQueryGalleryLabeledDataset(
-            df=df,
-            extra_data=extra_data,
+        self._dataset = dataset = ImageBaseDataset(  # own handle kept for visualize()
+            paths=df[PATHS_COLUMN].tolist(),
+            bboxes=parse_bboxes(df),
+            extra_data=None,
             dataset_root=dataset_root,
             transform=transform,
             f_imread=f_imread,
             cache_size=cache_size,
             input_tensors_key=input_tensors_key,
-            labels_key=LABELS_COLUMN,
         )
+        super().__init__(dataset=dataset, df=df, extra_data=extra_data)
 
-        self.extra_data = self.__dataset.extra_data
-        self.input_tensors_key = self.__dataset.input_tensors_key
-        self.index_key = self.__dataset.index_key
-
-    def __getitem__(self, item: int) -> Dict[str, Any]:
-        batch = self.__dataset[item]
-        del batch[self.__dataset.labels_key]
-        return batch
-
-    def __len__(self) -> int:
-        return len(self.__dataset)
-
-    def get_query_ids(self) -> LongTensor:
-        return self.__dataset.get_query_ids()
-
-    def get_gallery_ids(self) -> LongTensor:
-        return self.__dataset.get_gallery_ids()
-
     def visualize(self, item: int, color: TColor = BLACK) -> np.ndarray:
-        return self.__dataset.visualize(item=item, color=color)
+        return self._dataset.visualize(item=item, color=color)