Skip to content

Commit 0b20c07

Browse files
committed
updated classes
1 parent 4d381ed commit 0b20c07

File tree

3 files changed

+145
-37
lines changed

3 files changed

+145
-37
lines changed

oml/datasets/images.py

+85-21
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from oml.interfaces.datasets import (
4141
IBaseDataset,
4242
ILabeledDataset,
43+
IQueryGalleryDataset,
4344
IQueryGalleryLabeledDataset,
4445
IVisualizableDataset,
4546
)
@@ -298,9 +299,84 @@ def get_label2category(self) -> Optional[Dict[int, Union[str, int]]]:
298299
return label2category
299300

300301

301-
class ImageQueryGalleryLabeledDataset(ImageLabeledDataset, IQueryGalleryLabeledDataset):
302+
class ImageQueryGalleryDataset(ImageBaseDataset, IQueryGalleryDataset):
303+
def __init__(
304+
self,
305+
df: pd.DataFrame,
306+
extra_data: Optional[Dict[str, Any]] = None,
307+
dataset_root: Optional[Union[str, Path]] = None,
308+
transform: Optional[albu.Compose] = None,
309+
f_imread: Optional[TImReader] = None,
310+
cache_size: Optional[int] = 0,
311+
input_tensors_key: str = INPUT_TENSORS_KEY,
312+
# todo 522: remove
313+
paths_key: str = PATHS_KEY,
314+
categories_key: Optional[str] = CATEGORIES_KEY,
315+
sequence_key: Optional[str] = SEQUENCE_KEY,
316+
x1_key: str = X1_KEY,
317+
x2_key: str = X2_KEY,
318+
y1_key: str = Y1_KEY,
319+
y2_key: str = Y2_KEY,
320+
is_query_key: str = IS_QUERY_KEY,
321+
is_gallery_key: str = IS_GALLERY_KEY,
322+
):
323+
"""
324+
This is a not annotated dataset of images having `query`/`gallery` split.
325+
326+
"""
327+
328+
assert all(x in df.columns for x in (IS_QUERY_COLUMN, IS_GALLERY_COLUMN, PATHS_COLUMN))
329+
self.df = df
330+
331+
super().__init__(
332+
paths=self.df[PATHS_COLUMN].tolist(),
333+
extra_data=extra_data,
334+
dataset_root=dataset_root,
335+
transform=transform,
336+
f_imread=f_imread,
337+
cache_size=cache_size,
338+
input_tensors_key=input_tensors_key,
339+
# todo 522: remove
340+
x1_key=x1_key,
341+
y2_key=y2_key,
342+
x2_key=x2_key,
343+
y1_key=y1_key,
344+
paths_key=paths_key,
345+
)
346+
347+
# todo 522: remove
348+
self.is_query_key = is_query_key
349+
self.is_gallery_key = is_gallery_key
350+
351+
self.categories_key = categories_key if (CATEGORIES_COLUMN in df.columns) else None
352+
self.sequence_key = sequence_key if (SEQUENCE_COLUMN in df.columns) else None
353+
354+
def get_query_ids(self) -> LongTensor:
355+
return BoolTensor(self.df[IS_QUERY_COLUMN]).nonzero().squeeze()
356+
357+
def get_gallery_ids(self) -> LongTensor:
358+
return BoolTensor(self.df[IS_GALLERY_COLUMN]).nonzero().squeeze()
359+
360+
def __getitem__(self, idx: int) -> Dict[str, Any]:
361+
item = super().__getitem__(idx)
362+
363+
# todo 522: remove
364+
item[self.is_query_key] = bool(self.df[IS_QUERY_COLUMN][idx])
365+
item[self.is_gallery_key] = bool(self.df[IS_GALLERY_COLUMN][idx])
366+
367+
# todo 522: remove
368+
if self.sequence_key:
369+
item[self.sequence_key] = self.df[SEQUENCE_COLUMN][idx]
370+
371+
if self.categories_key:
372+
item[self.categories_key] = self.df[CATEGORIES_COLUMN][idx]
373+
374+
return item
375+
376+
377+
class ImageQueryGalleryLabeledDataset(ImageQueryGalleryDataset, ImageLabeledDataset, IQueryGalleryLabeledDataset):
302378
"""
303-
The dataset of images having `query`/`gallery` split.
379+
This is an annotated dataset of images having `query`/`gallery` split.
304380
305381
Note, that some datasets used as benchmarks in Metric Learning
306382
explicitly provide the splitting information (for example, ``DeepFashion InShop`` dataset), but some of them
@@ -309,7 +385,6 @@ class ImageQueryGalleryLabeledDataset(ImageLabeledDataset, IQueryGalleryLabeledD
309385
310386
So, if you want an item participate in validation as both: query and gallery, you should mark this item as
311387
``is_query == True`` and ``is_gallery == True``, as it's done in the `CARS196` or `CUB200` dataset.
312-
313388
"""
314389

315390
def __init__(
@@ -333,8 +408,8 @@ def __init__(
333408
is_query_key: str = IS_QUERY_KEY,
334409
is_gallery_key: str = IS_GALLERY_KEY,
335410
):
336-
assert all(x in df.columns for x in (IS_QUERY_COLUMN, IS_GALLERY_COLUMN, LABELS_COLUMN))
337-
self._df = df
411+
assert all(x in df.columns for x in (LABELS_COLUMN, IS_GALLERY_COLUMN, IS_QUERY_COLUMN, PATHS_COLUMN))
412+
self.df = df
338413

339414
super().__init__(
340415
df=df,
@@ -344,7 +419,6 @@ def __init__(
344419
f_imread=f_imread,
345420
cache_size=cache_size,
346421
input_tensors_key=input_tensors_key,
347-
labels_key=labels_key,
348422
# todo 522: remove
349423
x1_key=x1_key,
350424
y2_key=y2_key,
@@ -353,25 +427,14 @@ def __init__(
353427
paths_key=paths_key,
354428
categories_key=categories_key,
355429
sequence_key=sequence_key,
430+
is_query_key=is_query_key,
431+
is_gallery_key=is_gallery_key,
356432
)
357-
358-
# todo 522: remove
359-
self.is_query_key = is_query_key
360-
self.is_gallery_key = is_gallery_key
361-
362-
def get_query_ids(self) -> LongTensor:
363-
return BoolTensor(self._df[IS_QUERY_COLUMN]).nonzero().squeeze()
364-
365-
def get_gallery_ids(self) -> LongTensor:
366-
return BoolTensor(self._df[IS_GALLERY_COLUMN]).nonzero().squeeze()
433+
self.labels_key = labels_key
367434

368435
def __getitem__(self, idx: int) -> Dict[str, Any]:
369436
item = super().__getitem__(idx)
370-
item[self.labels_key] = self._df.iloc[idx][LABELS_COLUMN]
371-
372-
# todo 522: remove
373-
item[self.is_query_key] = bool(self._df[IS_QUERY_COLUMN][idx])
374-
item[self.is_gallery_key] = bool(self._df[IS_GALLERY_COLUMN][idx])
437+
item[self.labels_key] = self.df.iloc[idx][LABELS_COLUMN]
375438

376439
return item
377440

@@ -423,6 +486,7 @@ def get_retrieval_images_datasets(
423486
__all__ = [
424487
"ImageBaseDataset",
425488
"ImageLabeledDataset",
489+
"ImageQueryGalleryDataset",
426490
"ImageQueryGalleryLabeledDataset",
427491
"get_retrieval_images_datasets",
428492
]

tests/test_integrations/test_retrieval_validation.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from oml.const import EMBEDDINGS_KEY, INPUT_TENSORS_KEY, OVERALL_CATEGORIES_KEY
1010
from oml.metrics.embeddings import EmbeddingMetrics
1111
from tests.test_integrations.utils import (
12-
EmbeddingsQueryGalleryDataset,
12+
EmbeddingsQueryGalleryLabeledDataset,
1313
IdealClusterEncoder,
1414
)
1515

@@ -51,7 +51,7 @@ def get_shared_query_gallery() -> TData:
5151
def test_retrieval_validation(batch_size: int, shuffle: bool, num_workers: int, data: TData) -> None:
5252
labels, query_mask, gallery_mask, input_tensors, cmc_gt = data
5353

54-
dataset = EmbeddingsQueryGalleryDataset(
54+
dataset = EmbeddingsQueryGalleryLabeledDataset(
5555
labels=labels,
5656
embeddings=input_tensors,
5757
is_query=query_mask,

tests/test_integrations/utils.py

+58-14
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
from torch import BoolTensor, FloatTensor, LongTensor, nn
66

77
from oml.const import (
8-
CATEGORIES_COLUMN,
8+
CATEGORIES_KEY,
99
INDEX_KEY,
1010
INPUT_TENSORS_KEY,
1111
IS_GALLERY_KEY,
1212
IS_QUERY_KEY,
1313
LABELS_KEY,
14-
SEQUENCE_COLUMN,
14+
SEQUENCE_KEY,
1515
)
16-
from oml.interfaces.datasets import IQueryGalleryLabeledDataset
16+
from oml.interfaces.datasets import IQueryGalleryDataset, IQueryGalleryLabeledDataset
1717
from oml.utils.misc import one_hot
1818

1919

@@ -35,48 +35,58 @@ def forward(self, labels: torch.Tensor, need_noise: bool = True) -> torch.Tensor
3535
return embeddings
3636

3737

38-
class EmbeddingsQueryGalleryDataset(IQueryGalleryLabeledDataset):
38+
class EmbeddingsQueryGalleryDataset(IQueryGalleryDataset):
3939
def __init__(
4040
self,
4141
embeddings: FloatTensor,
42-
labels: LongTensor,
4342
is_query: BoolTensor,
4443
is_gallery: BoolTensor,
4544
categories: Optional[np.ndarray] = None,
4645
sequence: Optional[np.ndarray] = None,
4746
input_tensors_key: str = INPUT_TENSORS_KEY,
48-
labels_key: str = LABELS_KEY,
4947
index_key: str = INDEX_KEY,
48+
# todo 522: remove keys later
49+
categories_key: str = CATEGORIES_KEY,
50+
sequence_key: str = SEQUENCE_KEY,
5051
):
5152
super().__init__()
52-
assert len(embeddings) == len(labels) == len(is_query) == len(is_gallery)
53+
assert len(embeddings) == len(is_query) == len(is_gallery)
5354

5455
self._embeddings = embeddings
55-
self._labels = labels
5656
self._is_query = is_query
5757
self._is_gallery = is_gallery
5858

59+
# todo 522: remove keys
60+
self.categories_key = categories_key
61+
self.sequence_key = sequence_key
62+
5963
self.extra_data = {}
60-
if categories:
61-
self.extra_data[CATEGORIES_COLUMN] = categories
64+
if categories is not None:
65+
self.extra_data[self.categories_key] = categories
6266

63-
if sequence:
64-
self.extra_data[SEQUENCE_COLUMN] = sequence
67+
if sequence is not None:
68+
self.extra_data[self.sequence_key] = sequence
6569

6670
self.input_tensors_key = input_tensors_key
67-
self.labels_key = labels_key
6871
self.index_key = index_key
6972

7073
def __getitem__(self, idx: int) -> Dict[str, Any]:
7174
batch = {
7275
self.input_tensors_key: self._embeddings[idx],
73-
self.labels_key: self._labels[idx],
7476
self.index_key: idx,
7577
# todo 522: remove
7678
IS_QUERY_KEY: self._is_query[idx],
7779
IS_GALLERY_KEY: self._is_gallery[idx],
7880
}
7981

82+
# todo 522: avoid passing extra data as keys
83+
if self.extra_data:
84+
for key, record in self.extra_data.items():
85+
if key in batch:
86+
raise ValueError(f"<extra_data> and dataset share the same key: {key}")
87+
else:
88+
batch[key] = record[idx]
89+
8090
return batch
8191

8292
def __len__(self) -> int:
@@ -88,5 +98,39 @@ def get_query_ids(self) -> LongTensor:
8898
def get_gallery_ids(self) -> LongTensor:
8999
return self._is_gallery.nonzero().squeeze()
90100

101+
102+
class EmbeddingsQueryGalleryLabeledDataset(EmbeddingsQueryGalleryDataset, IQueryGalleryLabeledDataset):
103+
def __init__(
104+
self,
105+
embeddings: FloatTensor,
106+
labels: LongTensor,
107+
is_query: BoolTensor,
108+
is_gallery: BoolTensor,
109+
categories: Optional[np.ndarray] = None,
110+
sequence: Optional[np.ndarray] = None,
111+
input_tensors_key: str = INPUT_TENSORS_KEY,
112+
labels_key: str = LABELS_KEY,
113+
index_key: str = INDEX_KEY,
114+
):
115+
super().__init__(
116+
embeddings=embeddings,
117+
is_query=is_query,
118+
is_gallery=is_gallery,
119+
categories=categories,
120+
sequence=sequence,
121+
input_tensors_key=input_tensors_key,
122+
index_key=index_key,
123+
)
124+
125+
assert len(embeddings) == len(labels)
126+
127+
self._labels = labels
128+
self.labels_key = labels_key
129+
130+
def __getitem__(self, idx: int) -> Dict[str, Any]:
131+
item = super().__getitem__(idx)
132+
item[self.labels_key] = self._labels[idx]
133+
return item
134+
91135
def get_labels(self) -> np.ndarray:
92136
return np.array(self._labels)

0 commit comments

Comments
 (0)