40
40
from oml .interfaces .datasets import (
41
41
IBaseDataset ,
42
42
ILabeledDataset ,
43
+ IQueryGalleryDataset ,
43
44
IQueryGalleryLabeledDataset ,
44
45
IVisualizableDataset ,
45
46
)
@@ -298,9 +299,84 @@ def get_label2category(self) -> Optional[Dict[int, Union[str, int]]]:
298
299
return label2category
299
300
300
301
301
- class ImageQueryGalleryLabeledDataset (ImageLabeledDataset , IQueryGalleryLabeledDataset ):
302
+ class ImageQueryGalleryDataset (ImageBaseDataset , IQueryGalleryDataset ):
303
+ def __init__ (
304
+ self ,
305
+ df : pd .DataFrame ,
306
+ extra_data : Optional [Dict [str , Any ]] = None ,
307
+ dataset_root : Optional [Union [str , Path ]] = None ,
308
+ transform : Optional [albu .Compose ] = None ,
309
+ f_imread : Optional [TImReader ] = None ,
310
+ cache_size : Optional [int ] = 0 ,
311
+ input_tensors_key : str = INPUT_TENSORS_KEY ,
312
+ # todo 522: remove
313
+ paths_key : str = PATHS_KEY ,
314
+ categories_key : Optional [str ] = CATEGORIES_KEY ,
315
+ sequence_key : Optional [str ] = SEQUENCE_KEY ,
316
+ x1_key : str = X1_KEY ,
317
+ x2_key : str = X2_KEY ,
318
+ y1_key : str = Y1_KEY ,
319
+ y2_key : str = Y2_KEY ,
320
+ is_query_key : str = IS_QUERY_KEY ,
321
+ is_gallery_key : str = IS_GALLERY_KEY ,
322
+ ):
323
+ """
324
+ This is a not annotated dataset of images having `query`/`gallery` split.
325
+
326
+ """
327
+
328
+ assert all (x in df .columns for x in (IS_QUERY_COLUMN , IS_GALLERY_COLUMN , PATHS_COLUMN ))
329
+ self .df = df
330
+
331
+ super ().__init__ (
332
+ paths = self .df [PATHS_COLUMN ].tolist (),
333
+ extra_data = extra_data ,
334
+ dataset_root = dataset_root ,
335
+ transform = transform ,
336
+ f_imread = f_imread ,
337
+ cache_size = cache_size ,
338
+ input_tensors_key = input_tensors_key ,
339
+ # todo 522: remove
340
+ x1_key = x1_key ,
341
+ y2_key = y2_key ,
342
+ x2_key = x2_key ,
343
+ y1_key = y1_key ,
344
+ paths_key = paths_key ,
345
+ )
346
+
347
+ # todo 522: remove
348
+ self .is_query_key = is_query_key
349
+ self .is_gallery_key = is_gallery_key
350
+
351
+ self .categories_key = categories_key if (CATEGORIES_COLUMN in df .columns ) else None
352
+ self .sequence_key = sequence_key if (SEQUENCE_COLUMN in df .columns ) else None
353
+
354
+ def get_query_ids (self ) -> LongTensor :
355
+ return BoolTensor (self .df [IS_QUERY_COLUMN ]).nonzero ().squeeze ()
356
+
357
+ def get_gallery_ids (self ) -> LongTensor :
358
+ return BoolTensor (self .df [IS_GALLERY_COLUMN ]).nonzero ().squeeze ()
359
+
360
+ def __getitem__ (self , idx : int ) -> Dict [str , Any ]:
361
+ item = super ().__getitem__ (idx )
362
+
363
+ # todo 522: remove
364
+ item [self .is_query_key ] = bool (self .df [IS_QUERY_COLUMN ][idx ])
365
+ item [self .is_gallery_key ] = bool (self .df [IS_GALLERY_COLUMN ][idx ])
366
+
367
+ # todo 522: remove
368
+ if self .sequence_key :
369
+ item [self .sequence_key ] = self .df [SEQUENCE_COLUMN ][idx ]
370
+
371
+ if self .categories_key :
372
+ item [self .categories_key ] = self .df [CATEGORIES_COLUMN ][idx ]
373
+
374
+ return item
375
+
376
+
377
+ class ImageQueryGalleryLabeledDataset (ImageQueryGalleryDataset , ImageLabeledDataset , IQueryGalleryLabeledDataset ):
302
378
"""
303
- The dataset of images having `query`/`gallery` split.
379
+ This is an annotated dataset of images having `query`/`gallery` split.
304
380
305
381
Note, that some datasets used as benchmarks in Metric Learning
306
382
explicitly provide the splitting information (for example, ``DeepFashion InShop`` dataset), but some of them
@@ -309,7 +385,6 @@ class ImageQueryGalleryLabeledDataset(ImageLabeledDataset, IQueryGalleryLabeledD
309
385
310
386
So, if you want an item participate in validation as both: query and gallery, you should mark this item as
311
387
``is_query == True`` and ``is_gallery == True``, as it's done in the `CARS196` or `CUB200` dataset.
312
-
313
388
"""
314
389
315
390
def __init__ (
@@ -333,8 +408,8 @@ def __init__(
333
408
is_query_key : str = IS_QUERY_KEY ,
334
409
is_gallery_key : str = IS_GALLERY_KEY ,
335
410
):
336
- assert all (x in df .columns for x in (IS_QUERY_COLUMN , IS_GALLERY_COLUMN , LABELS_COLUMN ))
337
- self ._df = df
411
+ assert all (x in df .columns for x in (LABELS_COLUMN , IS_GALLERY_COLUMN , IS_QUERY_COLUMN , PATHS_COLUMN ))
412
+ self .df = df
338
413
339
414
super ().__init__ (
340
415
df = df ,
@@ -344,7 +419,6 @@ def __init__(
344
419
f_imread = f_imread ,
345
420
cache_size = cache_size ,
346
421
input_tensors_key = input_tensors_key ,
347
- labels_key = labels_key ,
348
422
# todo 522: remove
349
423
x1_key = x1_key ,
350
424
y2_key = y2_key ,
@@ -353,25 +427,14 @@ def __init__(
353
427
paths_key = paths_key ,
354
428
categories_key = categories_key ,
355
429
sequence_key = sequence_key ,
430
+ is_query_key = is_query_key ,
431
+ is_gallery_key = is_gallery_key ,
356
432
)
357
-
358
- # todo 522: remove
359
- self .is_query_key = is_query_key
360
- self .is_gallery_key = is_gallery_key
361
-
362
- def get_query_ids (self ) -> LongTensor :
363
- return BoolTensor (self ._df [IS_QUERY_COLUMN ]).nonzero ().squeeze ()
364
-
365
- def get_gallery_ids (self ) -> LongTensor :
366
- return BoolTensor (self ._df [IS_GALLERY_COLUMN ]).nonzero ().squeeze ()
433
+ self .labels_key = labels_key
367
434
368
435
def __getitem__ (self , idx : int ) -> Dict [str , Any ]:
369
436
item = super ().__getitem__ (idx )
370
- item [self .labels_key ] = self ._df .iloc [idx ][LABELS_COLUMN ]
371
-
372
- # todo 522: remove
373
- item [self .is_query_key ] = bool (self ._df [IS_QUERY_COLUMN ][idx ])
374
- item [self .is_gallery_key ] = bool (self ._df [IS_GALLERY_COLUMN ][idx ])
437
+ item [self .labels_key ] = self .df .iloc [idx ][LABELS_COLUMN ]
375
438
376
439
return item
377
440
@@ -423,6 +486,7 @@ def get_retrieval_images_datasets(
423
486
__all__ = [
424
487
"ImageBaseDataset" ,
425
488
"ImageLabeledDataset" ,
489
+ "ImageQueryGalleryDataset" ,
426
490
"ImageQueryGalleryLabeledDataset" ,
427
491
"get_retrieval_images_datasets" ,
428
492
]
0 commit comments