
Commit 018d026

addressed comments

1 parent 0b20c07

6 files changed: +103 -106 lines changed

oml/datasets/images.py (+80 -82)
@@ -47,7 +47,7 @@
 from oml.registry.transforms import get_transforms
 from oml.transforms.images.utils import TTransforms, get_im_reader_for_transforms
 from oml.utils.dataframe_format import check_retrieval_dataframe_format
-from oml.utils.images.images import TImReader, get_img_with_bbox, square_pad
+from oml.utils.images.images import TImReader, get_img_with_bbox

 # todo 522: general comment on Datasets
 # We will remove using keys in __getitem__ for:
@@ -82,9 +82,6 @@ class ImageBaseDataset(IBaseDataset, IVisualizableDataset):

     """

-    input_tensors_key: str
-    index_key: str
-
     def __init__(
         self,
         paths: List[Path],
@@ -106,7 +103,7 @@ def __init__(
         """

         Args:
-            paths: Paths to images. Will be concatenated with ``dataset_root`` is provided.
+            paths: Paths to images. Will be concatenated with ``dataset_root`` if provided.
             dataset_root: Path to the images' dir, set ``None`` if you provided the absolute paths in your dataframe
             bboxes: Bounding boxes of images. Some of the images may not have bounding bboxes.
             extra_data: Dictionary containing records of some additional information.
@@ -128,20 +125,20 @@ def __init__(
             assert all(
                 len(record) == len(paths) for record in extra_data.values()
             ), "All the extra records need to have the size equal to the dataset's size"
+            self.extra_data = extra_data
+        else:
+            self.extra_data = {}

         self.input_tensors_key = input_tensors_key
         self.index_key = index_key

         if dataset_root is not None:
-            self._paths = list(map(lambda x: str(Path(dataset_root) / x), paths))
-        else:
-            self._paths = list(map(str, paths))
-
-        self.extra_data = extra_data
+            paths = list(map(lambda x: Path(dataset_root) / x, paths))  # type: ignore

+        self._paths = list(map(str, paths))
         self._bboxes = bboxes
         self._transform = transform if transform else get_transforms("norm_albu")
-        self._f_imread = f_imread or get_im_reader_for_transforms(transform)
+        self._f_imread = f_imread or get_im_reader_for_transforms(self._transform)

         if cache_size:
             self.read_bytes = lru_cache(maxsize=cache_size)(self._read_bytes)  # type: ignore
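The per-instance byte cache in the last line above is a handy pattern: wrap a static reader with functools.lru_cache at construction time, so repeated reads of the same path hit memory instead of disk. A minimal standalone sketch (the class and names are illustrative, not library code):

    from functools import lru_cache

    class CachedReader:
        def __init__(self, cache_size: int = 100):
            # Bind an LRU-cached copy of the static reader to this instance;
            # a falsy cache_size would skip the wrapping, as in the diff above.
            self.read_bytes = lru_cache(maxsize=cache_size)(self._read_bytes)

        @staticmethod
        def _read_bytes(path: str) -> bytes:
            with open(path, "rb") as fin:
                return fin.read()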
@@ -163,14 +160,14 @@ def _read_bytes(path: Union[Path, str]) -> bytes:
         with open(str(path), "rb") as fin:
             return fin.read()

-    def __getitem__(self, idx: int) -> Dict[str, Union[FloatTensor, int]]:
-        img_bytes = self.read_bytes(self._paths[idx])
+    def __getitem__(self, item: int) -> Dict[str, Union[FloatTensor, int]]:
+        img_bytes = self.read_bytes(self._paths[item])
         img = self._f_imread(img_bytes)

         im_h, im_w = img.shape[:2] if isinstance(img, np.ndarray) else img.size[::-1]

-        if (self._bboxes is not None) and (self._bboxes[idx] is not None):
-            x1, y1, x2, y2 = self._bboxes[idx]
+        if (self._bboxes is not None) and (self._bboxes[item] is not None):
+            x1, y1, x2, y2 = self._bboxes[item]
         else:
             x1, y1, x2, y2 = 0, 0, im_w, im_h
@@ -182,34 +179,32 @@ def __getitem__(self, idx: int) -> Dict[str, Union[FloatTensor, int]]:
             img = img.crop((x1, y1, x2, y2))
         image_tensor = self._transform(img)

-        item = {
+        data = {
             self.input_tensors_key: image_tensor,
-            self.index_key: idx,
+            self.index_key: item,
         }

-        if self.extra_data:
-            for key, record in self.extra_data.items():
-                if key in item:
-                    raise ValueError(f"<extra_data> and dataset share the same key: {key}")
-                else:
-                    item[key] = record[idx]
+        for key, record in self.extra_data.items():
+            if key in data:
+                raise ValueError(f"<extra_data> and dataset share the same key: {key}")
+            else:
+                data[key] = record[item]

         # todo 522: remove
-        item[self.x1_key] = x1
-        item[self.y1_key] = y1
-        item[self.x2_key] = x2
-        item[self.y2_key] = y2
-        item[self.paths_key] = self._paths[idx]
+        data[self.x1_key] = x1
+        data[self.y1_key] = y1
+        data[self.x2_key] = x2
+        data[self.y2_key] = y2
+        data[self.paths_key] = self._paths[item]

-        return item
+        return data

     def __len__(self) -> int:
         return len(self._paths)

-    def visualize(self, idx: int, color: TColor = BLACK) -> np.ndarray:
-        bbox = torch.tensor(self._bboxes[idx]) if (self._bboxes is not None) else torch.tensor([torch.nan] * 4)
-        image = get_img_with_bbox(im_path=self._paths[idx], bbox=bbox, color=color)
-        image = square_pad(image)
+    def visualize(self, item: int, color: TColor = BLACK) -> np.ndarray:
+        bbox = torch.tensor(self._bboxes[item]) if (self._bboxes is not None) else torch.tensor([torch.nan] * 4)
+        image = get_img_with_bbox(im_path=self._paths[item], bbox=bbox, color=color)

         return image
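The extra_data merge now runs unconditionally, since the attribute defaults to an empty dict, and it still guards against key collisions with the dataset's own keys. A self-contained sketch of that merge (values are illustrative):

    data = {"input_tensors": [0.0], "idx": 0}
    extra_data = {"caption": ["a red car"], "source": ["web"]}

    item = 0
    for key, record in extra_data.items():
        if key in data:
            raise ValueError(f"<extra_data> and dataset share the same key: {key}")
        data[key] = record[item]

    print(sorted(data))  # ['caption', 'idx', 'input_tensors', 'source']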

@@ -245,12 +240,10 @@ def __init__(
         y1_key: str = Y1_KEY,
         y2_key: str = Y2_KEY,
     ):
-        assert (LABELS_COLUMN in df) and (PATHS_COLUMN in df), "There are only 2 required columns."
+        assert all(x in df.columns for x in (LABELS_COLUMN, PATHS_COLUMN))
         self.labels_key = labels_key
         self.df = df

-        extra_data = {} if extra_data is None else extra_data
-
         super().__init__(
             paths=self.df[PATHS_COLUMN].tolist(),
             bboxes=parse_bboxes(self.df),
@@ -273,18 +266,18 @@ def __init__(
         self.categories_key = categories_key if (CATEGORIES_COLUMN in df.columns) else None
         self.sequence_key = sequence_key if (SEQUENCE_COLUMN in df.columns) else None

-    def __getitem__(self, idx: int) -> Dict[str, Any]:
-        item = super().__getitem__(idx)
-        item[self.labels_key] = self.df.iloc[idx][LABELS_COLUMN]
+    def __getitem__(self, item: int) -> Dict[str, Any]:
+        data = super().__getitem__(item)
+        data[self.labels_key] = self.df.iloc[item][LABELS_COLUMN]

         # todo 522: remove
         if self.sequence_key:
-            item[self.sequence_key] = self.df[SEQUENCE_COLUMN][idx]
+            data[self.sequence_key] = self.df[SEQUENCE_COLUMN][item]

         if self.categories_key:
-            item[self.categories_key] = self.df[CATEGORIES_COLUMN][idx]
+            data[self.categories_key] = self.df[CATEGORIES_COLUMN][item]

-        return item
+        return data

     def get_labels(self) -> np.ndarray:
         return np.array(self.df[LABELS_COLUMN])
@@ -299,7 +292,20 @@ def get_label2category(self) -> Optional[Dict[int, Union[str, int]]]:
         return label2category


-class ImageQueryGalleryDataset(ImageBaseDataset, IQueryGalleryDataset):
+class ImageQueryGalleryLabeledDataset(ImageLabeledDataset, IQueryGalleryLabeledDataset):
+    """
+    The annotated dataset of images having `query`/`gallery` split.
+
+    Note that some datasets used as benchmarks in Metric Learning
+    explicitly provide the splitting information (for example, ``DeepFashion InShop`` dataset), but some of them
+    don't (for example, ``CARS196`` or ``CUB200``). The validation idea for the latter is to perform `1 vs rest`
+    validation, where every query is evaluated versus the whole validation dataset (except for this exact query).
+
+    So, if you want an item to participate in validation as both query and gallery, you should mark this item as
+    ``is_query == True`` and ``is_gallery == True``, as it's done in the `CARS196` or `CUB200` dataset.
+
+    """
+
     def __init__(
         self,
         df: pd.DataFrame,
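For benchmarks without an explicit split, the `1 vs rest` protocol described in this docstring boils down to marking every validation row as both query and gallery. A toy frame in that style (column names are assumed to match the constants in oml.const):

    import pandas as pd

    # Every item serves as both query and gallery entry, CARS196/CUB200-style.
    df = pd.DataFrame(
        {
            "label": [0, 0, 1, 1],
            "path": ["im0.jpg", "im1.jpg", "im2.jpg", "im3.jpg"],
            "is_query": [True, True, True, True],
            "is_gallery": [True, True, True, True],
        }
    )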
@@ -309,6 +315,7 @@ def __init__(
         f_imread: Optional[TImReader] = None,
         cache_size: Optional[int] = 0,
         input_tensors_key: str = INPUT_TENSORS_KEY,
+        labels_key: str = LABELS_KEY,
         # todo 522: remove
         paths_key: str = PATHS_KEY,
         categories_key: Optional[str] = CATEGORIES_KEY,
@@ -320,71 +327,53 @@ def __init__(
         is_query_key: str = IS_QUERY_KEY,
         is_gallery_key: str = IS_GALLERY_KEY,
     ):
-        """
-        This is a not annotated dataset of images having `query`/`gallery` split.
-
-        """
-
-        assert all(x in df.columns for x in (IS_QUERY_COLUMN, IS_GALLERY_COLUMN, PATHS_COLUMN))
+        assert all(x in df.columns for x in (IS_QUERY_COLUMN, IS_GALLERY_COLUMN, LABELS_COLUMN, PATHS_COLUMN))
         self.df = df

         super().__init__(
-            paths=self.df[PATHS_COLUMN].tolist(),
+            df=df,
             extra_data=extra_data,
             dataset_root=dataset_root,
             transform=transform,
             f_imread=f_imread,
             cache_size=cache_size,
             input_tensors_key=input_tensors_key,
+            labels_key=labels_key,
             # todo 522: remove
             x1_key=x1_key,
             y2_key=y2_key,
             x2_key=x2_key,
             y1_key=y1_key,
             paths_key=paths_key,
+            categories_key=categories_key,
+            sequence_key=sequence_key,
         )

         # todo 522: remove
         self.is_query_key = is_query_key
         self.is_gallery_key = is_gallery_key

-        self.categories_key = categories_key if (CATEGORIES_COLUMN in df.columns) else None
-        self.sequence_key = sequence_key if (SEQUENCE_COLUMN in df.columns) else None
-
     def get_query_ids(self) -> LongTensor:
         return BoolTensor(self.df[IS_QUERY_COLUMN]).nonzero().squeeze()

     def get_gallery_ids(self) -> LongTensor:
         return BoolTensor(self.df[IS_GALLERY_COLUMN]).nonzero().squeeze()

     def __getitem__(self, idx: int) -> Dict[str, Any]:
-        item = super().__getitem__(idx)
+        data = super().__getitem__(idx)
+        data[self.labels_key] = self.df.iloc[idx][LABELS_COLUMN]

         # todo 522: remove
-        item[self.is_query_key] = bool(self.df[IS_QUERY_COLUMN][idx])
-        item[self.is_gallery_key] = bool(self.df[IS_GALLERY_COLUMN][idx])
+        data[self.is_query_key] = bool(self.df[IS_QUERY_COLUMN][idx])
+        data[self.is_gallery_key] = bool(self.df[IS_GALLERY_COLUMN][idx])

-        # todo 522: remove
-        if self.sequence_key:
-            item[self.sequence_key] = self.df[SEQUENCE_COLUMN][idx]
+        return data

-        if self.categories_key:
-            item[self.categories_key] = self.df[CATEGORIES_COLUMN][idx]

-        return item
-
-
-class ImageQueryGalleryLabeledDataset(ImageQueryGalleryDataset, ImageLabeledDataset, IQueryGalleryLabeledDataset):
+class ImageQueryGalleryDataset(IVisualizableDataset, IQueryGalleryDataset):
     """
-    This is an annotated dataset of images having `query`/`gallery` split.
-
-    Note, that some datasets used as benchmarks in Metric Learning
-    explicitly provide the splitting information (for example, ``DeepFashion InShop`` dataset), but some of them
-    don't (for example, ``CARS196`` or ``CUB200``). The validation idea for the latter is to perform `1 vs rest`
-    validation, where every query is evaluated versus the whole validation dataset (except for this exact query).
+    The NOT annotated dataset of images having `query`/`gallery` split.

-    So, if you want an item participate in validation as both: query and gallery, you should mark this item as
-    ``is_query == True`` and ``is_gallery == True``, as it's done in the `CARS196` or `CUB200` dataset.
     """

     def __init__(
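get_query_ids() and get_gallery_ids() above turn the boolean dataframe columns into tensors of row indices via nonzero(). A tiny illustration with made-up values:

    import torch

    is_query = torch.tensor([True, False, True, True])
    ids = is_query.nonzero().squeeze()
    print(ids)  # tensor([0, 2, 3])
    # Caveat: a column with a single True would squeeze down to a 0-dim tensor.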
@@ -396,7 +385,6 @@ def __init__(
         f_imread: Optional[TImReader] = None,
         cache_size: Optional[int] = 0,
         input_tensors_key: str = INPUT_TENSORS_KEY,
-        labels_key: str = LABELS_KEY,
         # todo 522: remove
         paths_key: str = PATHS_KEY,
         categories_key: Optional[str] = CATEGORIES_KEY,
@@ -408,17 +396,20 @@ def __init__(
         is_query_key: str = IS_QUERY_KEY,
         is_gallery_key: str = IS_GALLERY_KEY,
     ):
-        assert all(x in df.columns for x in (LABELS_COLUMN, IS_GALLERY_COLUMN, IS_QUERY_COLUMN, PATHS_COLUMN))
-        self.df = df
+        assert all(x in df.columns for x in (IS_QUERY_COLUMN, IS_GALLERY_COLUMN, PATHS_COLUMN))
+        # instead of implementing the whole logic let's just re-use QGL dataset, but with dropped labels
+        df = df.copy()
+        df[LABELS_COLUMN] = "fake_label"

-        super().__init__(
+        self.__dataset = ImageQueryGalleryLabeledDataset(
             df=df,
             extra_data=extra_data,
             dataset_root=dataset_root,
             transform=transform,
             f_imread=f_imread,
             cache_size=cache_size,
             input_tensors_key=input_tensors_key,
+            labels_key=LABELS_COLUMN,
             # todo 522: remove
             x1_key=x1_key,
             y2_key=y2_key,
@@ -430,13 +421,20 @@
             is_query_key=is_query_key,
             is_gallery_key=is_gallery_key,
         )
-        self.labels_key = labels_key

-    def __getitem__(self, idx: int) -> Dict[str, Any]:
-        item = super().__getitem__(idx)
-        item[self.labels_key] = self.df.iloc[idx][LABELS_COLUMN]
+    def __getitem__(self, item: int) -> Dict[str, Any]:
+        batch = self.__dataset[item]
+        del batch[self.__dataset.labels_key]
+        return batch
+
+    def get_query_ids(self) -> LongTensor:
+        return self.__dataset.get_query_ids()
+
+    def get_gallery_ids(self) -> LongTensor:
+        return self.__dataset.get_gallery_ids()

-        return item
+    def visualize(self, item: int, color: TColor = BLACK) -> np.ndarray:
+        return self.__dataset.visualize(item, color)


 def get_retrieval_images_datasets(
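The rewritten ImageQueryGalleryDataset is now a thin view over the labeled variant: it injects a constant fake label, delegates every method, and strips the label key on the way out. Hypothetical usage, with df as in the toy frame sketched earlier (no labels column required):

    # The frame only needs path / is_query / is_gallery columns.
    dataset = ImageQueryGalleryDataset(df=df, dataset_root="data/")
    batch = dataset[0]                     # dict without any label key
    queries = dataset.get_query_ids()      # LongTensor of query row indices
    galleries = dataset.get_gallery_ids()  # LongTensor of gallery row indices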

oml/interfaces/datasets.py (+1 -1)
@@ -111,7 +111,7 @@ class IVisualizableDataset(Dataset, ABC):
     """

     @abstractmethod
-    def visualize(self, idx: int, color: TColor) -> np.ndarray:
+    def visualize(self, item: int, color: TColor) -> np.ndarray:
         raise NotImplementedError()

tests/test_integrations/test_lightning/test_pipeline.py (+2 -2)
@@ -28,9 +28,9 @@ def __init__(self, labels: List[int], im_size: int):
         self.labels = labels
         self.im_size = im_size

-    def __getitem__(self, idx: int) -> Dict[str, Any]:
+    def __getitem__(self, item: int) -> Dict[str, Any]:
         input_tensors = torch.rand((3, self.im_size, self.im_size))
-        label = torch.tensor(self.labels[idx]).long()
+        label = torch.tensor(self.labels[item]).long()
         return {INPUT_TENSORS_KEY: input_tensors, LABELS_KEY: label, IS_QUERY_KEY: True, IS_GALLERY_KEY: True}

     def __len__(self) -> int:

tests/test_integrations/test_lightning/test_train_with_sequence.py (+2 -2)
@@ -5,7 +5,7 @@
 from tqdm import tqdm

 from oml.const import LABELS_COLUMN, MOCK_DATASET_PATH, SEQUENCE_COLUMN
-from oml.datasets.base import DatasetQueryGallery
+from oml.datasets.images import ImageQueryGalleryLabeledDataset
 from oml.metrics.embeddings import EmbeddingMetrics, TMetricsDict_ByLabels
 from oml.utils.download_mock_dataset import download_mock_dataset
 from oml.utils.misc import compare_dicts_recursively, set_global_seed

@@ -15,7 +15,7 @@ def validation(df: pd.DataFrame) -> TMetricsDict_ByLabels:
     set_global_seed(42)
     extractor = nn.Flatten()

-    val_dataset = DatasetQueryGallery(df, dataset_root=MOCK_DATASET_PATH)
+    val_dataset = ImageQueryGalleryLabeledDataset(df, dataset_root=MOCK_DATASET_PATH)

     val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, num_workers=0)
     calculator = EmbeddingMetrics(extra_keys=("paths",), sequence_key=val_dataset.sequence_key, cmc_top_k=(1,))

tests/test_integrations/test_train_with_mining.py (+2 -2)
@@ -23,8 +23,8 @@ def __init__(self, n_labels: int, n_samples_min: int):
             self.labels.extend([i] * randint(n_samples_min, 2 * n_samples_min))
         shuffle(self.labels)

-    def __getitem__(self, idx: int) -> Dict[str, Any]:
-        return {INPUT_TENSORS_KEY: torch.tensor(self.labels[idx]), LABELS_KEY: self.labels[idx]}
+    def __getitem__(self, item: int) -> Dict[str, Any]:
+        return {INPUT_TENSORS_KEY: torch.tensor(self.labels[item]), LABELS_KEY: self.labels[item]}

     def __len__(self) -> int:
         return len(self.labels)
