diff --git a/README.md b/README.md
index 5ca5ea938..6ac9c5c63 100644
--- a/README.md
+++ b/README.md
@@ -23,7 +23,8 @@ OML is a PyTorch-based framework to train and validate the models producing high
ㅤㅤ
ㅤㅤ
ㅤㅤ
-
+
ㅤㅤ
+
ㅤㅤ
diff --git a/docs/readme/header.md b/docs/readme/header.md
index 426c555a7..e38011127 100644
--- a/docs/readme/header.md
+++ b/docs/readme/header.md
@@ -23,7 +23,8 @@ OML is a PyTorch-based framework to train and validate the models producing high
ㅤㅤ
ㅤㅤ
ㅤㅤ
-
+
ㅤㅤ
+
ㅤㅤ
diff --git a/oml/datasets/dataframe.py b/oml/datasets/dataframe.py
new file mode 100644
index 000000000..f0f2054a5
--- /dev/null
+++ b/oml/datasets/dataframe.py
@@ -0,0 +1,149 @@
+from typing import Any, Dict, Optional, Union
+
+import numpy as np
+import pandas as pd
+from torch import BoolTensor, LongTensor
+
+from oml.const import (
+ CATEGORIES_COLUMN,
+ IS_GALLERY_COLUMN,
+ IS_QUERY_COLUMN,
+ LABELS_COLUMN,
+ LABELS_KEY,
+ SEQUENCE_COLUMN,
+)
+from oml.interfaces.datasets import (
+ IBaseDataset,
+ ILabeledDataset,
+ IQueryGalleryDataset,
+ IQueryGalleryLabeledDataset,
+)
+
+
+def update_dataset_extra_data(
+    dataset: IBaseDataset, df: pd.DataFrame, extra_data: Optional[Dict[str, Any]] = None
+) -> IBaseDataset:
+    """Merge dataframe-derived extras (categories, sequences) into ``dataset.extra_data``.
+
+    Args:
+        dataset: Dataset whose ``extra_data`` dict is updated in-place.
+        df: Dataframe that may hold category / sequence columns.
+        extra_data: Optional user-provided extras merged in as well.
+
+    Returns:
+        The same ``dataset`` instance (mutated), for call chaining.
+    """
+    # Bug fix: `dict() or extra_data` evaluated to `extra_data` (an empty dict
+    # is falsy), so a `None` argument crashed below on item assignment /
+    # `update(None)`. Copy the caller's dict so we never mutate their input.
+    extra_data = dict(extra_data) if extra_data else dict()
+
+    if CATEGORIES_COLUMN in df.columns:
+        extra_data[CATEGORIES_COLUMN] = df[CATEGORIES_COLUMN].copy()
+
+    if SEQUENCE_COLUMN in df.columns:
+        extra_data[SEQUENCE_COLUMN] = df[SEQUENCE_COLUMN].copy()
+
+    dataset.extra_data.update(extra_data)
+
+    return dataset
+
+
+def label_to_category(df: pd.DataFrame) -> Optional[Dict[Union[str, int], str]]:
+    """Build a label -> category mapping from ``df``.
+
+    Returns ``None`` when the dataframe has no categories column — hence the
+    ``Optional`` return type, which the original annotation omitted.
+    """
+    if CATEGORIES_COLUMN in df.columns:
+        label2category = dict(zip(df[LABELS_COLUMN], df[CATEGORIES_COLUMN]))
+    else:
+        label2category = None
+
+    return label2category
+
+
+class DFLabeledDataset(ILabeledDataset):
+    """Labeled dataset backed by a dataframe: wraps a base dataset and
+    attaches the label from ``LABELS_COLUMN`` to every returned item.
+    """
+
+    def __init__(
+        self,
+        dataset: IBaseDataset,
+        df: pd.DataFrame,
+        extra_data: Optional[Dict[str, Any]] = None,
+        labels_key: str = LABELS_KEY,
+    ):
+        # Args:
+        #   dataset: wrapped dataset providing the actual samples.
+        #   df: must contain LABELS_COLUMN; assumes row i of `df` corresponds
+        #       to item i of `dataset` — TODO(review): confirm with callers.
+        #   extra_data: optional dict merged into `dataset.extra_data`.
+        #   labels_key: key under which the label appears in returned items.
+        assert LABELS_COLUMN in df.columns
+
+        # `update_dataset_extra_data` mutates and returns the same dataset.
+        self.__dataset = update_dataset_extra_data(dataset, df, extra_data)
+
+        self.df = df
+        self.labels_key = labels_key
+        # Mirror the wrapped dataset's keys so consumers see a uniform API.
+        self.index_key = self.__dataset.index_key
+        self.input_tensors_key = self.__dataset.input_tensors_key
+
+    def __len__(self) -> int:
+        return len(self.__dataset)
+
+    def __getitem__(self, item: int) -> Dict[str, Any]:
+        # Delegate to the wrapped dataset, then append this row's label.
+        data = self.__dataset[item]
+        data[self.labels_key] = self.df.iloc[item][LABELS_COLUMN]
+        return data
+
+    def get_labels(self) -> np.ndarray:
+        return np.array(self.df[LABELS_COLUMN])
+
+    def get_label2category(self) -> Optional[Dict[Union[str, int], str]]:
+        # Returns None when the dataframe has no categories column.
+        return label_to_category(self.df)
+
+
+class DFQueryGalleryDataset(IQueryGalleryDataset):
+    """Unlabeled dataset with a query/gallery split taken from the dataframe's
+    boolean ``IS_QUERY_COLUMN`` / ``IS_GALLERY_COLUMN`` flags.
+    """
+
+    def __init__(
+        self,
+        dataset: IBaseDataset,
+        df: pd.DataFrame,
+        extra_data: Optional[Dict[str, Any]] = None,
+    ):
+        assert all(x in df.columns for x in (IS_QUERY_COLUMN, IS_GALLERY_COLUMN))
+
+        # `update_dataset_extra_data` mutates and returns the same dataset.
+        self.__dataset = update_dataset_extra_data(dataset, df, extra_data)
+
+        self.df = df
+        # Row positions where the flag is truthy.
+        # NOTE(review): `.squeeze()` yields a 0-dim tensor when exactly one
+        # query (or gallery) item exists — confirm downstream indexing copes.
+        self._query_ids = BoolTensor(self.df[IS_QUERY_COLUMN]).nonzero().squeeze()
+        self._gallery_ids = BoolTensor(self.df[IS_GALLERY_COLUMN]).nonzero().squeeze()
+
+        # Mirror the wrapped dataset's keys so consumers see a uniform API.
+        self.index_key = self.__dataset.index_key
+        self.input_tensors_key = self.__dataset.input_tensors_key
+
+    def __len__(self) -> int:
+        return len(self.__dataset)
+
+    def __getitem__(self, item: int) -> Dict[str, Any]:
+        # Pure delegation: no labels exist in this unannotated variant.
+        return self.__dataset[item]
+
+    def get_query_ids(self) -> LongTensor:
+        return self._query_ids
+
+    def get_gallery_ids(self) -> LongTensor:
+        return self._gallery_ids
+
+
+class DFQueryGalleryLabeledDataset(IQueryGalleryLabeledDataset):
+    """Labeled dataset with a query/gallery split, both taken from the dataframe
+    (``IS_QUERY_COLUMN`` / ``IS_GALLERY_COLUMN`` flags and ``LABELS_COLUMN``).
+    """
+
+    def __init__(
+        self,
+        dataset: IBaseDataset,
+        df: pd.DataFrame,
+        extra_data: Optional[Dict[str, Any]] = None,
+        labels_key: str = LABELS_KEY,
+    ):
+        assert all(x in df.columns for x in (IS_QUERY_COLUMN, IS_GALLERY_COLUMN, LABELS_COLUMN))
+
+        # `update_dataset_extra_data` mutates and returns the same dataset.
+        self.__dataset = update_dataset_extra_data(dataset, df, extra_data)
+
+        self.df = df
+        # NOTE(review): `.squeeze()` yields a 0-dim tensor for a single
+        # query/gallery item — confirm downstream indexing copes.
+        self._query_ids = BoolTensor(self.df[IS_QUERY_COLUMN]).nonzero().squeeze()
+        self._gallery_ids = BoolTensor(self.df[IS_GALLERY_COLUMN]).nonzero().squeeze()
+
+        self.index_key = self.__dataset.index_key
+        self.input_tensors_key = self.__dataset.input_tensors_key
+        self.labels_key = labels_key
+
+    def __len__(self) -> int:
+        return len(self.__dataset)
+
+    def __getitem__(self, item: int) -> Dict[str, Any]:
+        # Bug fix: a labeled dataset must expose the label in each item (the
+        # pre-refactor ImageQueryGalleryLabeledDataset did this via its parent's
+        # __getitem__); the original body returned items without labels.
+        data = self.__dataset[item]
+        data[self.labels_key] = self.df.iloc[item][LABELS_COLUMN]
+        return data
+
+    def get_query_ids(self) -> LongTensor:
+        return self._query_ids
+
+    def get_gallery_ids(self) -> LongTensor:
+        return self._gallery_ids
+
+    def get_labels(self) -> np.ndarray:
+        return np.array(self.df[LABELS_COLUMN])
+
+    def get_label2category(self) -> Optional[Dict[Union[str, int], str]]:
+        # Returns None when the dataframe lacks a categories column.
+        return label_to_category(self.df)
+
+
+__all__ = ["DFLabeledDataset", "DFQueryGalleryDataset", "DFQueryGalleryLabeledDataset"]
diff --git a/oml/datasets/images.py b/oml/datasets/images.py
index aa6332ff9..630a1408f 100644
--- a/oml/datasets/images.py
+++ b/oml/datasets/images.py
@@ -28,6 +28,11 @@
TBBoxes,
TColor,
)
+# Bug fix: this PR adds the module oml/datasets/dataframe.py with classes named
+# DFLabeledDataset / DFQueryGalleryDataset / DFQueryGalleryLabeledDataset;
+# the original import targeted a non-existent module and non-existent names.
+from oml.datasets.dataframe import (
+    DFLabeledDataset,
+    DFQueryGalleryDataset,
+    DFQueryGalleryLabeledDataset,
+)
from oml.interfaces.datasets import (
IBaseDataset,
ILabeledDataset,
@@ -169,7 +174,7 @@ def visualize(self, item: int, color: TColor = BLACK) -> np.ndarray:
         return image
-class ImageLabeledDataset(ImageBaseDataset, ILabeledDataset):
+class ImageLabeledDataset(DFLabeledDataset, IVisualizableDataset):
     """
     The dataset of images having their ground truth labels.
@@ -187,22 +192,10 @@ def __init__(
         labels_key: str = LABELS_KEY,
         index_key: str = INDEX_KEY,
     ):
-        assert all(x in df.columns for x in (LABELS_COLUMN, PATHS_COLUMN))
-        self.labels_key = labels_key
-        self.df = df
-
-        extra_data = extra_data or dict()
-
-        if CATEGORIES_COLUMN in df.columns:
-            extra_data[CATEGORIES_COLUMN] = df[CATEGORIES_COLUMN].copy()
-
-        if SEQUENCE_COLUMN in df.columns:
-            extra_data[SEQUENCE_COLUMN] = df[SEQUENCE_COLUMN].copy()
-
-        super().__init__(
-            paths=self.df[PATHS_COLUMN].tolist(),
-            bboxes=parse_bboxes(self.df),
-            extra_data=extra_data,
+        # Build the wrapped image dataset from the local `df` argument:
+        # `self.df` is only assigned later, inside DFLabeledDataset.__init__.
+        dataset = ImageBaseDataset(
+            paths=df[PATHS_COLUMN].tolist(),
+            bboxes=parse_bboxes(df),
+            extra_data=None,
             dataset_root=dataset_root,
             transform=transform,
             f_imread=f_imread,
@@ -210,25 +203,13 @@ def __init__(
             input_tensors_key=input_tensors_key,
             index_key=index_key,
         )
+        # Keep our own reference to the wrapped dataset: the parent stores it
+        # in a name-mangled private attribute (_DFLabeledDataset__dataset)
+        # that `self.__dataset` in this class cannot reach.
+        self.__dataset = dataset
+        super().__init__(dataset=dataset, df=df, extra_data=extra_data, labels_key=labels_key)
-    def __getitem__(self, item: int) -> Dict[str, Any]:
-        data = super().__getitem__(item)
-        data[self.labels_key] = self.df.iloc[item][LABELS_COLUMN]
-        return data
-
-    def get_labels(self) -> np.ndarray:
-        return np.array(self.df[LABELS_COLUMN])
-
-    def get_label2category(self) -> Optional[Dict[int, Union[str, int]]]:
-        if CATEGORIES_COLUMN in self.df.columns:
-            label2category = dict(zip(self.df[LABELS_COLUMN], self.df[CATEGORIES_COLUMN]))
-        else:
-            label2category = None
-
-        return label2category
+    # Keep the BLACK default so existing callers remain source-compatible.
+    def visualize(self, item: int, color: TColor = BLACK) -> np.ndarray:
+        return self.__dataset.visualize(item=item, color=color)
-class ImageQueryGalleryLabeledDataset(ImageLabeledDataset, IQueryGalleryLabeledDataset):
+class ImageQueryGalleryLabeledDataset(DFQueryGalleryLabeledDataset, IVisualizableDataset):
     """
     The annotated dataset of images having `query`/`gallery` split.
@@ -253,31 +234,24 @@ def __init__(
         input_tensors_key: str = INPUT_TENSORS_KEY,
         labels_key: str = LABELS_KEY,
     ):
-        assert all(x in df.columns for x in (IS_QUERY_COLUMN, IS_GALLERY_COLUMN, LABELS_COLUMN, PATHS_COLUMN))
-        self.df = df
-
-        super().__init__(
-            df=df,
-            extra_data=extra_data,
+        # Use the local `df` argument: `self.df` does not exist yet at this
+        # point (it is assigned by the parent constructor below).
+        dataset = ImageBaseDataset(
+            paths=df[PATHS_COLUMN].tolist(),
+            bboxes=parse_bboxes(df),
+            extra_data=None,
             dataset_root=dataset_root,
             transform=transform,
             f_imread=f_imread,
             cache_size=cache_size,
             input_tensors_key=input_tensors_key,
-            labels_key=labels_key,
         )
-        self._query_ids = BoolTensor(self.df[IS_QUERY_COLUMN]).nonzero().squeeze()
-        self._gallery_ids = BoolTensor(self.df[IS_GALLERY_COLUMN]).nonzero().squeeze()
+        # Private copy of the reference: the parent's attribute is name-mangled
+        # to its own class and unreachable as `self.__dataset` from here.
+        self.__dataset = dataset
+        super().__init__(dataset=dataset, df=df, extra_data=extra_data, labels_key=labels_key)
-    def get_query_ids(self) -> LongTensor:
-        return self._query_ids
-
-    def get_gallery_ids(self) -> LongTensor:
-        return self._gallery_ids
+    # Keep the BLACK default so existing callers remain source-compatible.
+    def visualize(self, item: int, color: TColor = BLACK) -> np.ndarray:
+        return self.__dataset.visualize(item=item, color=color)
-class ImageQueryGalleryDataset(IVisualizableDataset, IQueryGalleryDataset):
+class ImageQueryGalleryDataset(DFQueryGalleryDataset, IVisualizableDataset):
     """
     The NOT annotated dataset of images having `query`/`gallery` split.
@@ -293,42 +267,19 @@ def __init__(
         cache_size: Optional[int] = 0,
         input_tensors_key: str = INPUT_TENSORS_KEY,
     ):
-        assert all(x in df.columns for x in (IS_QUERY_COLUMN, IS_GALLERY_COLUMN, PATHS_COLUMN))
-
-        # instead of implementing the whole logic let's just re-use QGL dataset, but with dropped labels
-        df = deepcopy(df)
-        df[LABELS_COLUMN] = "fake_label"
-
-        self.__dataset = ImageQueryGalleryLabeledDataset(
-            df=df,
-            extra_data=extra_data,
+        # Use the local `df` argument: `self.df` does not exist yet at this
+        # point (it is assigned by the parent constructor below).
+        dataset = ImageBaseDataset(
+            paths=df[PATHS_COLUMN].tolist(),
+            bboxes=parse_bboxes(df),
+            extra_data=None,
             dataset_root=dataset_root,
             transform=transform,
             f_imread=f_imread,
             cache_size=cache_size,
             input_tensors_key=input_tensors_key,
-            labels_key=LABELS_COLUMN,
         )
+        # `visualize` below reads `self.__dataset`; with the old wrapping trick
+        # removed, this attribute must be set here or visualize would crash.
+        self.__dataset = dataset
+        super().__init__(dataset=dataset, df=df, extra_data=extra_data)
-        self.extra_data = self.__dataset.extra_data
-        self.input_tensors_key = self.__dataset.input_tensors_key
-        self.index_key = self.__dataset.index_key
-
-    def __getitem__(self, item: int) -> Dict[str, Any]:
-        batch = self.__dataset[item]
-        del batch[self.__dataset.labels_key]
-        return batch
-
-    def __len__(self) -> int:
-        return len(self.__dataset)
-
-    def get_query_ids(self) -> LongTensor:
-        return self.__dataset.get_query_ids()
-
-    def get_gallery_ids(self) -> LongTensor:
-        return self.__dataset.get_gallery_ids()
-
+    # Keep the original signature (BLACK default) for backward compatibility.
     def visualize(self, item: int, color: TColor = BLACK) -> np.ndarray:
         return self.__dataset.visualize(item=item, color=color)