47
47
from oml .registry .transforms import get_transforms
48
48
from oml .transforms .images .utils import TTransforms , get_im_reader_for_transforms
49
49
from oml .utils .dataframe_format import check_retrieval_dataframe_format
50
- from oml .utils .images .images import TImReader , get_img_with_bbox , square_pad
50
+ from oml .utils .images .images import TImReader , get_img_with_bbox
51
51
52
52
# todo 522: general comment on Datasets
53
53
# We will remove using keys in __getitem__ for:
@@ -82,9 +82,6 @@ class ImageBaseDataset(IBaseDataset, IVisualizableDataset):
82
82
83
83
"""
84
84
85
- input_tensors_key : str
86
- index_key : str
87
-
88
85
def __init__ (
89
86
self ,
90
87
paths : List [Path ],
@@ -106,7 +103,7 @@ def __init__(
106
103
"""
107
104
108
105
Args:
109
- paths: Paths to images. Will be concatenated with ``dataset_root`` is provided.
106
+ paths: Paths to images. Will be concatenated with ``dataset_root`` if provided.
110
107
dataset_root: Path to the images' dir, set ``None`` if you provided the absolute paths in your dataframe
111
108
bboxes: Bounding boxes of images. Some of the images may not have bounding bboxes.
112
109
extra_data: Dictionary containing records of some additional information.
@@ -128,20 +125,20 @@ def __init__(
128
125
assert all (
129
126
len (record ) == len (paths ) for record in extra_data .values ()
130
127
), "All the extra records need to have the size equal to the dataset's size"
128
+ self .extra_data = extra_data
129
+ else :
130
+ self .extra_data = {}
131
131
132
132
self .input_tensors_key = input_tensors_key
133
133
self .index_key = index_key
134
134
135
135
if dataset_root is not None :
136
- self ._paths = list (map (lambda x : str (Path (dataset_root ) / x ), paths ))
137
- else :
138
- self ._paths = list (map (str , paths ))
139
-
140
- self .extra_data = extra_data
136
+ paths = list (map (lambda x : Path (dataset_root ) / x ), paths ) # type: ignore
141
137
138
+ self ._paths = list (map (str , paths ))
142
139
self ._bboxes = bboxes
143
140
self ._transform = transform if transform else get_transforms ("norm_albu" )
144
- self ._f_imread = f_imread or get_im_reader_for_transforms (transform )
141
+ self ._f_imread = f_imread or get_im_reader_for_transforms (self . _transform )
145
142
146
143
if cache_size :
147
144
self .read_bytes = lru_cache (maxsize = cache_size )(self ._read_bytes ) # type: ignore
@@ -163,14 +160,14 @@ def _read_bytes(path: Union[Path, str]) -> bytes:
163
160
with open (str (path ), "rb" ) as fin :
164
161
return fin .read ()
165
162
166
- def __getitem__ (self , idx : int ) -> Dict [str , Union [FloatTensor , int ]]:
167
- img_bytes = self .read_bytes (self ._paths [idx ])
163
+ def __getitem__ (self , item : int ) -> Dict [str , Union [FloatTensor , int ]]:
164
+ img_bytes = self .read_bytes (self ._paths [item ])
168
165
img = self ._f_imread (img_bytes )
169
166
170
167
im_h , im_w = img .shape [:2 ] if isinstance (img , np .ndarray ) else img .size [::- 1 ]
171
168
172
- if (self ._bboxes is not None ) and (self ._bboxes [idx ] is not None ):
173
- x1 , y1 , x2 , y2 = self ._bboxes [idx ]
169
+ if (self ._bboxes is not None ) and (self ._bboxes [item ] is not None ):
170
+ x1 , y1 , x2 , y2 = self ._bboxes [item ]
174
171
else :
175
172
x1 , y1 , x2 , y2 = 0 , 0 , im_w , im_h
176
173
@@ -182,34 +179,32 @@ def __getitem__(self, idx: int) -> Dict[str, Union[FloatTensor, int]]:
182
179
img = img .crop ((x1 , y1 , x2 , y2 ))
183
180
image_tensor = self ._transform (img )
184
181
185
- item = {
182
+ data = {
186
183
self .input_tensors_key : image_tensor ,
187
- self .index_key : idx ,
184
+ self .index_key : item ,
188
185
}
189
186
190
- if self .extra_data :
191
- for key , record in self .extra_data .items ():
192
- if key in item :
193
- raise ValueError (f"<extra_data> and dataset share the same key: { key } " )
194
- else :
195
- item [key ] = record [idx ]
187
+ for key , record in self .extra_data .items ():
188
+ if key in data :
189
+ raise ValueError (f"<extra_data> and dataset share the same key: { key } " )
190
+ else :
191
+ data [key ] = record [item ]
196
192
197
193
# todo 522: remove
198
- item [self .x1_key ] = x1
199
- item [self .y1_key ] = y1
200
- item [self .x2_key ] = x2
201
- item [self .y2_key ] = y2
202
- item [self .paths_key ] = self ._paths [idx ]
194
+ data [self .x1_key ] = x1
195
+ data [self .y1_key ] = y1
196
+ data [self .x2_key ] = x2
197
+ data [self .y2_key ] = y2
198
+ data [self .paths_key ] = self ._paths [item ]
203
199
204
- return item
200
+ return data
205
201
206
202
def __len__ (self ) -> int :
207
203
return len (self ._paths )
208
204
209
- def visualize (self , idx : int , color : TColor = BLACK ) -> np .ndarray :
210
- bbox = torch .tensor (self ._bboxes [idx ]) if (self ._bboxes is not None ) else torch .tensor ([torch .nan ] * 4 )
211
- image = get_img_with_bbox (im_path = self ._paths [idx ], bbox = bbox , color = color )
212
- image = square_pad (image )
205
+ def visualize (self , item : int , color : TColor = BLACK ) -> np .ndarray :
206
+ bbox = torch .tensor (self ._bboxes [item ]) if (self ._bboxes is not None ) else torch .tensor ([torch .nan ] * 4 )
207
+ image = get_img_with_bbox (im_path = self ._paths [item ], bbox = bbox , color = color )
213
208
214
209
return image
215
210
@@ -245,12 +240,10 @@ def __init__(
245
240
y1_key : str = Y1_KEY ,
246
241
y2_key : str = Y2_KEY ,
247
242
):
248
- assert (LABELS_COLUMN in df ) and ( PATHS_COLUMN in df ), "There are only 2 required columns."
243
+ assert (x in df . columns for x in ( LABELS_COLUMN , PATHS_COLUMN ))
249
244
self .labels_key = labels_key
250
245
self .df = df
251
246
252
- extra_data = {} if extra_data is None else extra_data
253
-
254
247
super ().__init__ (
255
248
paths = self .df [PATHS_COLUMN ].tolist (),
256
249
bboxes = parse_bboxes (self .df ),
@@ -273,18 +266,18 @@ def __init__(
273
266
self .categories_key = categories_key if (CATEGORIES_COLUMN in df .columns ) else None
274
267
self .sequence_key = sequence_key if (SEQUENCE_COLUMN in df .columns ) else None
275
268
276
- def __getitem__ (self , idx : int ) -> Dict [str , Any ]:
277
- item = super ().__getitem__ (idx )
278
- item [self .labels_key ] = self .df .iloc [idx ][LABELS_COLUMN ]
269
+ def __getitem__ (self , item : int ) -> Dict [str , Any ]:
270
+ data = super ().__getitem__ (item )
271
+ data [self .labels_key ] = self .df .iloc [item ][LABELS_COLUMN ]
279
272
280
273
# todo 522: remove
281
274
if self .sequence_key :
282
- item [self .sequence_key ] = self .df [SEQUENCE_COLUMN ][idx ]
275
+ data [self .sequence_key ] = self .df [SEQUENCE_COLUMN ][item ]
283
276
284
277
if self .categories_key :
285
- item [self .categories_key ] = self .df [CATEGORIES_COLUMN ][idx ]
278
+ data [self .categories_key ] = self .df [CATEGORIES_COLUMN ][item ]
286
279
287
- return item
280
+ return data
288
281
289
282
def get_labels (self ) -> np .ndarray :
290
283
return np .array (self .df [LABELS_COLUMN ])
@@ -299,7 +292,20 @@ def get_label2category(self) -> Optional[Dict[int, Union[str, int]]]:
299
292
return label2category
300
293
301
294
302
- class ImageQueryGalleryDataset (ImageBaseDataset , IQueryGalleryDataset ):
295
+ class ImageQueryGalleryLabeledDataset (ImageLabeledDataset , IQueryGalleryLabeledDataset ):
296
+ """
297
+ The annotated dataset of images having `query`/`gallery` split.
298
+
299
+ Note, that some datasets used as benchmarks in Metric Learning
300
+ explicitly provide the splitting information (for example, ``DeepFashion InShop`` dataset), but some of them
301
+ don't (for example, ``CARS196`` or ``CUB200``). The validation idea for the latter is to perform `1 vs rest`
302
+ validation, where every query is evaluated versus the whole validation dataset (except for this exact query).
303
+
304
+ So, if you want an item participate in validation as both: query and gallery, you should mark this item as
305
+ ``is_query == True`` and ``is_gallery == True``, as it's done in the `CARS196` or `CUB200` dataset.
306
+
307
+ """
308
+
303
309
def __init__ (
304
310
self ,
305
311
df : pd .DataFrame ,
@@ -309,6 +315,7 @@ def __init__(
309
315
f_imread : Optional [TImReader ] = None ,
310
316
cache_size : Optional [int ] = 0 ,
311
317
input_tensors_key : str = INPUT_TENSORS_KEY ,
318
+ labels_key : str = LABELS_KEY ,
312
319
# todo 522: remove
313
320
paths_key : str = PATHS_KEY ,
314
321
categories_key : Optional [str ] = CATEGORIES_KEY ,
@@ -320,71 +327,53 @@ def __init__(
320
327
is_query_key : str = IS_QUERY_KEY ,
321
328
is_gallery_key : str = IS_GALLERY_KEY ,
322
329
):
323
- """
324
- This is a not annotated dataset of images having `query`/`gallery` split.
325
-
326
- """
327
-
328
- assert all (x in df .columns for x in (IS_QUERY_COLUMN , IS_GALLERY_COLUMN , PATHS_COLUMN ))
330
+ assert all (x in df .columns for x in (IS_QUERY_COLUMN , IS_GALLERY_COLUMN , LABELS_COLUMN , PATHS_COLUMN ))
329
331
self .df = df
330
332
331
333
super ().__init__ (
332
- paths = self . df [ PATHS_COLUMN ]. tolist () ,
334
+ df = df ,
333
335
extra_data = extra_data ,
334
336
dataset_root = dataset_root ,
335
337
transform = transform ,
336
338
f_imread = f_imread ,
337
339
cache_size = cache_size ,
338
340
input_tensors_key = input_tensors_key ,
341
+ labels_key = labels_key ,
339
342
# todo 522: remove
340
343
x1_key = x1_key ,
341
344
y2_key = y2_key ,
342
345
x2_key = x2_key ,
343
346
y1_key = y1_key ,
344
347
paths_key = paths_key ,
348
+ categories_key = categories_key ,
349
+ sequence_key = sequence_key ,
345
350
)
346
351
347
352
# todo 522: remove
348
353
self .is_query_key = is_query_key
349
354
self .is_gallery_key = is_gallery_key
350
355
351
- self .categories_key = categories_key if (CATEGORIES_COLUMN in df .columns ) else None
352
- self .sequence_key = sequence_key if (SEQUENCE_COLUMN in df .columns ) else None
353
-
354
356
def get_query_ids (self ) -> LongTensor :
355
357
return BoolTensor (self .df [IS_QUERY_COLUMN ]).nonzero ().squeeze ()
356
358
357
359
def get_gallery_ids (self ) -> LongTensor :
358
360
return BoolTensor (self .df [IS_GALLERY_COLUMN ]).nonzero ().squeeze ()
359
361
360
362
def __getitem__ (self , idx : int ) -> Dict [str , Any ]:
361
- item = super ().__getitem__ (idx )
363
+ data = super ().__getitem__ (idx )
364
+ data [self .labels_key ] = self .df .iloc [idx ][LABELS_COLUMN ]
362
365
363
366
# todo 522: remove
364
- item [self .is_query_key ] = bool (self .df [IS_QUERY_COLUMN ][idx ])
365
- item [self .is_gallery_key ] = bool (self .df [IS_GALLERY_COLUMN ][idx ])
367
+ data [self .is_query_key ] = bool (self .df [IS_QUERY_COLUMN ][idx ])
368
+ data [self .is_gallery_key ] = bool (self .df [IS_GALLERY_COLUMN ][idx ])
366
369
367
- # todo 522: remove
368
- if self .sequence_key :
369
- item [self .sequence_key ] = self .df [SEQUENCE_COLUMN ][idx ]
370
+ return data
370
371
371
- if self .categories_key :
372
- item [self .categories_key ] = self .df [CATEGORIES_COLUMN ][idx ]
373
372
374
- return item
375
-
376
-
377
- class ImageQueryGalleryLabeledDataset (ImageQueryGalleryDataset , ImageLabeledDataset , IQueryGalleryLabeledDataset ):
373
+ class ImageQueryGalleryDataset (IVisualizableDataset , IQueryGalleryDataset ):
378
374
"""
379
- This is an annotated dataset of images having `query`/`gallery` split.
380
-
381
- Note, that some datasets used as benchmarks in Metric Learning
382
- explicitly provide the splitting information (for example, ``DeepFashion InShop`` dataset), but some of them
383
- don't (for example, ``CARS196`` or ``CUB200``). The validation idea for the latter is to perform `1 vs rest`
384
- validation, where every query is evaluated versus the whole validation dataset (except for this exact query).
375
+ The NOT annotated dataset of images having `query`/`gallery` split.
385
376
386
- So, if you want an item participate in validation as both: query and gallery, you should mark this item as
387
- ``is_query == True`` and ``is_gallery == True``, as it's done in the `CARS196` or `CUB200` dataset.
388
377
"""
389
378
390
379
def __init__ (
@@ -396,7 +385,6 @@ def __init__(
396
385
f_imread : Optional [TImReader ] = None ,
397
386
cache_size : Optional [int ] = 0 ,
398
387
input_tensors_key : str = INPUT_TENSORS_KEY ,
399
- labels_key : str = LABELS_KEY ,
400
388
# todo 522: remove
401
389
paths_key : str = PATHS_KEY ,
402
390
categories_key : Optional [str ] = CATEGORIES_KEY ,
@@ -408,17 +396,20 @@ def __init__(
408
396
is_query_key : str = IS_QUERY_KEY ,
409
397
is_gallery_key : str = IS_GALLERY_KEY ,
410
398
):
411
- assert all (x in df .columns for x in (LABELS_COLUMN , IS_GALLERY_COLUMN , IS_QUERY_COLUMN , PATHS_COLUMN ))
412
- self .df = df
399
+ assert all (x in df .columns for x in (IS_QUERY_COLUMN , IS_GALLERY_COLUMN , PATHS_COLUMN ))
400
+ # instead of implementing the whole logic let's just re-use QGL dataset, but with dropped labels
401
+ df = df .copy ()
402
+ df [LABELS_COLUMN ] = "fake_label"
413
403
414
- super (). __init__ (
404
+ self . __dataset = ImageQueryGalleryLabeledDataset (
415
405
df = df ,
416
406
extra_data = extra_data ,
417
407
dataset_root = dataset_root ,
418
408
transform = transform ,
419
409
f_imread = f_imread ,
420
410
cache_size = cache_size ,
421
411
input_tensors_key = input_tensors_key ,
412
+ labels_key = LABELS_COLUMN ,
422
413
# todo 522: remove
423
414
x1_key = x1_key ,
424
415
y2_key = y2_key ,
@@ -430,13 +421,20 @@ def __init__(
430
421
is_query_key = is_query_key ,
431
422
is_gallery_key = is_gallery_key ,
432
423
)
433
- self .labels_key = labels_key
434
424
435
- def __getitem__ (self , idx : int ) -> Dict [str , Any ]:
436
- item = super ().__getitem__ (idx )
437
- item [self .labels_key ] = self .df .iloc [idx ][LABELS_COLUMN ]
425
+ def __getitem__ (self , item : int ) -> Dict [str , Any ]:
426
+ batch = self .__dataset [item ]
427
+ del batch [self .__dataset .labels_key ]
428
+ return batch
429
+
430
+ def get_query_ids (self ) -> LongTensor :
431
+ return self .__dataset .get_query_ids ()
432
+
433
+ def get_gallery_ids (self ) -> LongTensor :
434
+ return self .__dataset .get_gallery_ids ()
438
435
439
- return item
436
+ def visualize (self , item : int , color : TColor = BLACK ) -> np .ndarray :
437
+ return self .__dataset .visualize (item , color )
440
438
441
439
442
440
def get_retrieval_images_datasets (
0 commit comments