Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reworking datasets #541

Merged
merged 10 commits into from
Apr 20, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
addressed comments
AlekseySh committed Apr 20, 2024
commit 018d026bd20a409360db8465e075317d2318b7e6
162 changes: 80 additions & 82 deletions oml/datasets/images.py
Original file line number Diff line number Diff line change
@@ -47,7 +47,7 @@
from oml.registry.transforms import get_transforms
from oml.transforms.images.utils import TTransforms, get_im_reader_for_transforms
from oml.utils.dataframe_format import check_retrieval_dataframe_format
from oml.utils.images.images import TImReader, get_img_with_bbox, square_pad
from oml.utils.images.images import TImReader, get_img_with_bbox

# todo 522: general comment on Datasets
# We will remove using keys in __getitem__ for:
@@ -82,9 +82,6 @@ class ImageBaseDataset(IBaseDataset, IVisualizableDataset):
"""

input_tensors_key: str
index_key: str

def __init__(
self,
paths: List[Path],
@@ -106,7 +103,7 @@ def __init__(
"""
Args:
paths: Paths to images. Will be concatenated with ``dataset_root`` is provided.
paths: Paths to images. Will be concatenated with ``dataset_root`` if provided.
dataset_root: Path to the images' dir, set ``None`` if you provided the absolute paths in your dataframe
bboxes: Bounding boxes of images. Some of the images may not have bounding boxes.
extra_data: Dictionary containing records of some additional information.
@@ -128,20 +125,20 @@ def __init__(
assert all(
len(record) == len(paths) for record in extra_data.values()
), "All the extra records need to have the size equal to the dataset's size"
self.extra_data = extra_data
else:
self.extra_data = {}

self.input_tensors_key = input_tensors_key
self.index_key = index_key

if dataset_root is not None:
self._paths = list(map(lambda x: str(Path(dataset_root) / x), paths))
else:
self._paths = list(map(str, paths))

self.extra_data = extra_data
paths = list(map(lambda x: Path(dataset_root) / x), paths) # type: ignore

self._paths = list(map(str, paths))
self._bboxes = bboxes
self._transform = transform if transform else get_transforms("norm_albu")
self._f_imread = f_imread or get_im_reader_for_transforms(transform)
self._f_imread = f_imread or get_im_reader_for_transforms(self._transform)

if cache_size:
self.read_bytes = lru_cache(maxsize=cache_size)(self._read_bytes) # type: ignore
@@ -163,14 +160,14 @@ def _read_bytes(path: Union[Path, str]) -> bytes:
with open(str(path), "rb") as fin:
return fin.read()

def __getitem__(self, idx: int) -> Dict[str, Union[FloatTensor, int]]:
img_bytes = self.read_bytes(self._paths[idx])
def __getitem__(self, item: int) -> Dict[str, Union[FloatTensor, int]]:
img_bytes = self.read_bytes(self._paths[item])
img = self._f_imread(img_bytes)

im_h, im_w = img.shape[:2] if isinstance(img, np.ndarray) else img.size[::-1]

if (self._bboxes is not None) and (self._bboxes[idx] is not None):
x1, y1, x2, y2 = self._bboxes[idx]
if (self._bboxes is not None) and (self._bboxes[item] is not None):
x1, y1, x2, y2 = self._bboxes[item]
else:
x1, y1, x2, y2 = 0, 0, im_w, im_h

@@ -182,34 +179,32 @@ def __getitem__(self, idx: int) -> Dict[str, Union[FloatTensor, int]]:
img = img.crop((x1, y1, x2, y2))
image_tensor = self._transform(img)

item = {
data = {
self.input_tensors_key: image_tensor,
self.index_key: idx,
self.index_key: item,
}

if self.extra_data:
for key, record in self.extra_data.items():
if key in item:
raise ValueError(f"<extra_data> and dataset share the same key: {key}")
else:
item[key] = record[idx]
for key, record in self.extra_data.items():
if key in data:
raise ValueError(f"<extra_data> and dataset share the same key: {key}")
else:
data[key] = record[item]

# todo 522: remove
item[self.x1_key] = x1
item[self.y1_key] = y1
item[self.x2_key] = x2
item[self.y2_key] = y2
item[self.paths_key] = self._paths[idx]
data[self.x1_key] = x1
data[self.y1_key] = y1
data[self.x2_key] = x2
data[self.y2_key] = y2
data[self.paths_key] = self._paths[item]

return item
return data

def __len__(self) -> int:
return len(self._paths)

def visualize(self, idx: int, color: TColor = BLACK) -> np.ndarray:
bbox = torch.tensor(self._bboxes[idx]) if (self._bboxes is not None) else torch.tensor([torch.nan] * 4)
image = get_img_with_bbox(im_path=self._paths[idx], bbox=bbox, color=color)
image = square_pad(image)
def visualize(self, item: int, color: TColor = BLACK) -> np.ndarray:
bbox = torch.tensor(self._bboxes[item]) if (self._bboxes is not None) else torch.tensor([torch.nan] * 4)
image = get_img_with_bbox(im_path=self._paths[item], bbox=bbox, color=color)

return image

@@ -245,12 +240,10 @@ def __init__(
y1_key: str = Y1_KEY,
y2_key: str = Y2_KEY,
):
assert (LABELS_COLUMN in df) and (PATHS_COLUMN in df), "There are only 2 required columns."
assert (x in df.columns for x in (LABELS_COLUMN, PATHS_COLUMN))
self.labels_key = labels_key
self.df = df

extra_data = {} if extra_data is None else extra_data

super().__init__(
paths=self.df[PATHS_COLUMN].tolist(),
bboxes=parse_bboxes(self.df),
@@ -273,18 +266,18 @@ def __init__(
self.categories_key = categories_key if (CATEGORIES_COLUMN in df.columns) else None
self.sequence_key = sequence_key if (SEQUENCE_COLUMN in df.columns) else None

def __getitem__(self, idx: int) -> Dict[str, Any]:
item = super().__getitem__(idx)
item[self.labels_key] = self.df.iloc[idx][LABELS_COLUMN]
def __getitem__(self, item: int) -> Dict[str, Any]:
data = super().__getitem__(item)
data[self.labels_key] = self.df.iloc[item][LABELS_COLUMN]

# todo 522: remove
if self.sequence_key:
item[self.sequence_key] = self.df[SEQUENCE_COLUMN][idx]
data[self.sequence_key] = self.df[SEQUENCE_COLUMN][item]

if self.categories_key:
item[self.categories_key] = self.df[CATEGORIES_COLUMN][idx]
data[self.categories_key] = self.df[CATEGORIES_COLUMN][item]

return item
return data

def get_labels(self) -> np.ndarray:
return np.array(self.df[LABELS_COLUMN])
@@ -299,7 +292,20 @@ def get_label2category(self) -> Optional[Dict[int, Union[str, int]]]:
return label2category


class ImageQueryGalleryDataset(ImageBaseDataset, IQueryGalleryDataset):
class ImageQueryGalleryLabeledDataset(ImageLabeledDataset, IQueryGalleryLabeledDataset):
"""
The annotated dataset of images having `query`/`gallery` split.
Note, that some datasets used as benchmarks in Metric Learning
explicitly provide the splitting information (for example, ``DeepFashion InShop`` dataset), but some of them
don't (for example, ``CARS196`` or ``CUB200``). The validation idea for the latter is to perform `1 vs rest`
validation, where every query is evaluated versus the whole validation dataset (except for this exact query).
So, if you want an item to participate in validation as both query and gallery, you should mark this item as
``is_query == True`` and ``is_gallery == True``, as it's done in the `CARS196` or `CUB200` dataset.
"""

def __init__(
self,
df: pd.DataFrame,
@@ -309,6 +315,7 @@ def __init__(
f_imread: Optional[TImReader] = None,
cache_size: Optional[int] = 0,
input_tensors_key: str = INPUT_TENSORS_KEY,
labels_key: str = LABELS_KEY,
# todo 522: remove
paths_key: str = PATHS_KEY,
categories_key: Optional[str] = CATEGORIES_KEY,
@@ -320,71 +327,53 @@ def __init__(
is_query_key: str = IS_QUERY_KEY,
is_gallery_key: str = IS_GALLERY_KEY,
):
"""
This is an unannotated dataset of images having `query`/`gallery` split.
"""

assert all(x in df.columns for x in (IS_QUERY_COLUMN, IS_GALLERY_COLUMN, PATHS_COLUMN))
assert all(x in df.columns for x in (IS_QUERY_COLUMN, IS_GALLERY_COLUMN, LABELS_COLUMN, PATHS_COLUMN))
self.df = df

super().__init__(
paths=self.df[PATHS_COLUMN].tolist(),
df=df,
extra_data=extra_data,
dataset_root=dataset_root,
transform=transform,
f_imread=f_imread,
cache_size=cache_size,
input_tensors_key=input_tensors_key,
labels_key=labels_key,
# todo 522: remove
x1_key=x1_key,
y2_key=y2_key,
x2_key=x2_key,
y1_key=y1_key,
paths_key=paths_key,
categories_key=categories_key,
sequence_key=sequence_key,
)

# todo 522: remove
self.is_query_key = is_query_key
self.is_gallery_key = is_gallery_key

self.categories_key = categories_key if (CATEGORIES_COLUMN in df.columns) else None
self.sequence_key = sequence_key if (SEQUENCE_COLUMN in df.columns) else None

def get_query_ids(self) -> LongTensor:
return BoolTensor(self.df[IS_QUERY_COLUMN]).nonzero().squeeze()

def get_gallery_ids(self) -> LongTensor:
return BoolTensor(self.df[IS_GALLERY_COLUMN]).nonzero().squeeze()

def __getitem__(self, idx: int) -> Dict[str, Any]:
item = super().__getitem__(idx)
data = super().__getitem__(idx)
data[self.labels_key] = self.df.iloc[idx][LABELS_COLUMN]

# todo 522: remove
item[self.is_query_key] = bool(self.df[IS_QUERY_COLUMN][idx])
item[self.is_gallery_key] = bool(self.df[IS_GALLERY_COLUMN][idx])
data[self.is_query_key] = bool(self.df[IS_QUERY_COLUMN][idx])
data[self.is_gallery_key] = bool(self.df[IS_GALLERY_COLUMN][idx])

# todo 522: remove
if self.sequence_key:
item[self.sequence_key] = self.df[SEQUENCE_COLUMN][idx]
return data

if self.categories_key:
item[self.categories_key] = self.df[CATEGORIES_COLUMN][idx]

return item


class ImageQueryGalleryLabeledDataset(ImageQueryGalleryDataset, ImageLabeledDataset, IQueryGalleryLabeledDataset):
class ImageQueryGalleryDataset(IVisualizableDataset, IQueryGalleryDataset):
"""
This is an annotated dataset of images having `query`/`gallery` split.
Note, that some datasets used as benchmarks in Metric Learning
explicitly provide the splitting information (for example, ``DeepFashion InShop`` dataset), but some of them
don't (for example, ``CARS196`` or ``CUB200``). The validation idea for the latter is to perform `1 vs rest`
validation, where every query is evaluated versus the whole validation dataset (except for this exact query).
The NOT annotated dataset of images having `query`/`gallery` split.
So, if you want an item to participate in validation as both query and gallery, you should mark this item as
``is_query == True`` and ``is_gallery == True``, as it's done in the `CARS196` or `CUB200` dataset.
"""

def __init__(
@@ -396,7 +385,6 @@ def __init__(
f_imread: Optional[TImReader] = None,
cache_size: Optional[int] = 0,
input_tensors_key: str = INPUT_TENSORS_KEY,
labels_key: str = LABELS_KEY,
# todo 522: remove
paths_key: str = PATHS_KEY,
categories_key: Optional[str] = CATEGORIES_KEY,
@@ -408,17 +396,20 @@ def __init__(
is_query_key: str = IS_QUERY_KEY,
is_gallery_key: str = IS_GALLERY_KEY,
):
assert all(x in df.columns for x in (LABELS_COLUMN, IS_GALLERY_COLUMN, IS_QUERY_COLUMN, PATHS_COLUMN))
self.df = df
assert all(x in df.columns for x in (IS_QUERY_COLUMN, IS_GALLERY_COLUMN, PATHS_COLUMN))
# instead of implementing the whole logic let's just re-use QGL dataset, but with dropped labels
df = df.copy()
df[LABELS_COLUMN] = "fake_label"

super().__init__(
self.__dataset = ImageQueryGalleryLabeledDataset(
df=df,
extra_data=extra_data,
dataset_root=dataset_root,
transform=transform,
f_imread=f_imread,
cache_size=cache_size,
input_tensors_key=input_tensors_key,
labels_key=LABELS_COLUMN,
# todo 522: remove
x1_key=x1_key,
y2_key=y2_key,
@@ -430,13 +421,20 @@ def __init__(
is_query_key=is_query_key,
is_gallery_key=is_gallery_key,
)
self.labels_key = labels_key

def __getitem__(self, idx: int) -> Dict[str, Any]:
item = super().__getitem__(idx)
item[self.labels_key] = self.df.iloc[idx][LABELS_COLUMN]
def __getitem__(self, item: int) -> Dict[str, Any]:
batch = self.__dataset[item]
del batch[self.__dataset.labels_key]
return batch

def get_query_ids(self) -> LongTensor:
return self.__dataset.get_query_ids()

def get_gallery_ids(self) -> LongTensor:
return self.__dataset.get_gallery_ids()

return item
def visualize(self, item: int, color: TColor = BLACK) -> np.ndarray:
return self.__dataset.visualize(item, color)


def get_retrieval_images_datasets(
2 changes: 1 addition & 1 deletion oml/interfaces/datasets.py
Original file line number Diff line number Diff line change
@@ -111,7 +111,7 @@ class IVisualizableDataset(Dataset, ABC):
"""

@abstractmethod
def visualize(self, idx: int, color: TColor) -> np.ndarray:
def visualize(self, item: int, color: TColor) -> np.ndarray:
raise NotImplementedError()


4 changes: 2 additions & 2 deletions tests/test_integrations/test_lightning/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -28,9 +28,9 @@ def __init__(self, labels: List[int], im_size: int):
self.labels = labels
self.im_size = im_size

def __getitem__(self, idx: int) -> Dict[str, Any]:
def __getitem__(self, item: int) -> Dict[str, Any]:
input_tensors = torch.rand((3, self.im_size, self.im_size))
label = torch.tensor(self.labels[idx]).long()
label = torch.tensor(self.labels[item]).long()
return {INPUT_TENSORS_KEY: input_tensors, LABELS_KEY: label, IS_QUERY_KEY: True, IS_GALLERY_KEY: True}

def __len__(self) -> int:
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@
from tqdm import tqdm

from oml.const import LABELS_COLUMN, MOCK_DATASET_PATH, SEQUENCE_COLUMN
from oml.datasets.base import DatasetQueryGallery
from oml.datasets.images import ImageQueryGalleryLabeledDataset
from oml.metrics.embeddings import EmbeddingMetrics, TMetricsDict_ByLabels
from oml.utils.download_mock_dataset import download_mock_dataset
from oml.utils.misc import compare_dicts_recursively, set_global_seed
@@ -15,7 +15,7 @@ def validation(df: pd.DataFrame) -> TMetricsDict_ByLabels:
set_global_seed(42)
extractor = nn.Flatten()

val_dataset = DatasetQueryGallery(df, dataset_root=MOCK_DATASET_PATH)
val_dataset = ImageQueryGalleryLabeledDataset(df, dataset_root=MOCK_DATASET_PATH)

val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, num_workers=0)
calculator = EmbeddingMetrics(extra_keys=("paths",), sequence_key=val_dataset.sequence_key, cmc_top_k=(1,))
4 changes: 2 additions & 2 deletions tests/test_integrations/test_train_with_mining.py
Original file line number Diff line number Diff line change
@@ -23,8 +23,8 @@ def __init__(self, n_labels: int, n_samples_min: int):
self.labels.extend([i] * randint(n_samples_min, 2 * n_samples_min))
shuffle(self.labels)

def __getitem__(self, idx: int) -> Dict[str, Any]:
return {INPUT_TENSORS_KEY: torch.tensor(self.labels[idx]), LABELS_KEY: self.labels[idx]}
def __getitem__(self, item: int) -> Dict[str, Any]:
return {INPUT_TENSORS_KEY: torch.tensor(self.labels[item]), LABELS_KEY: self.labels[item]}

def __len__(self) -> int:
return len(self.labels)
33 changes: 16 additions & 17 deletions tests/test_integrations/utils.py
Original file line number Diff line number Diff line change
@@ -70,24 +70,23 @@ def __init__(
self.input_tensors_key = input_tensors_key
self.index_key = index_key

def __getitem__(self, idx: int) -> Dict[str, Any]:
batch = {
self.input_tensors_key: self._embeddings[idx],
self.index_key: idx,
def __getitem__(self, item: int) -> Dict[str, Any]:
data = {
self.input_tensors_key: self._embeddings[item],
self.index_key: item,
# todo 522: remove
IS_QUERY_KEY: self._is_query[idx],
IS_GALLERY_KEY: self._is_gallery[idx],
IS_QUERY_KEY: self._is_query[item],
IS_GALLERY_KEY: self._is_gallery[item],
}

# todo 522: avoid passing extra data as keys
if self.extra_data:
for key, record in self.extra_data.items():
if key in batch:
raise ValueError(f"<extra_data> and dataset share the same key: {key}")
else:
batch[key] = record[idx]
for key, record in self.extra_data.items():
if key in data:
raise ValueError(f"<extra_data> and dataset share the same key: {key}")
else:
data[key] = record[item]

return batch
return data

def __len__(self) -> int:
return len(self._embeddings)
@@ -127,10 +126,10 @@ def __init__(
self._labels = labels
self.labels_key = labels_key

def __getitem__(self, idx: int) -> Dict[str, Any]:
item = super().__getitem__(idx)
item[self.labels_key] = self._labels[idx]
return item
def __getitem__(self, item: int) -> Dict[str, Any]:
data = super().__getitem__(item)
data[self.labels_key] = self._labels[item]
return data

def get_labels(self) -> np.ndarray:
return np.array(self._labels)