-from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Dict, List, Tuple

from torch import Tensor

-from oml.const import INDEX_KEY, PAIR_1ST_KEY, PAIR_2ND_KEY, TBBoxes
-from oml.datasets.images import ImageBaseDataset
-from oml.interfaces.datasets import IPairsDataset
-from oml.transforms.images.torchvision import get_normalisation_torch
-from oml.transforms.images.utils import TTransforms
-from oml.utils.images.images import TImReader, imread_pillow
+from oml.const import INDEX_KEY, PAIR_1ST_KEY, PAIR_2ND_KEY
+from oml.interfaces.datasets import IBaseDataset, IPairDataset

-# todo 522: make one modality agnostic instead of these two

-
-class EmbeddingPairsDataset(IPairsDataset):
+class PairDataset(IPairDataset):
    """
-    Dataset to iterate over pairs of embeddings.
+    Dataset to iterate over pairs of items.

    """

    def __init__(
        self,
-        embeddings1: Tensor,
-        embeddings2: Tensor,
+        base_dataset: IBaseDataset,
+        pair_ids: List[Tuple[int, int]],
        pair_1st_key: str = PAIR_1ST_KEY,
        pair_2nd_key: str = PAIR_2ND_KEY,
        index_key: str = INDEX_KEY,
    ):
-        """
-
-        Args:
-            embeddings1: The first input embeddings
-            embeddings2: The second input embeddings
-            pair_1st_key: Key to put ``embeddings1`` into the batches
-            pair_2nd_key: Key to put ``embeddings2`` into the batches
-            index_key: Key to put samples' ids into the batches
-
-        """
-        assert embeddings1.shape == embeddings2.shape
-        assert embeddings1.ndim >= 2
+        self.base_dataset = base_dataset
+        self.pair_ids = pair_ids

        self.pair_1st_key = pair_1st_key
        self.pair_2nd_key = pair_2nd_key
-        self.index_key = index_key
-
-        self.embeddings1 = embeddings1
-        self.embeddings2 = embeddings2
+        self.index_key: str = index_key

    def __getitem__(self, idx: int) -> Dict[str, Tensor]:
-        return {self.pair_1st_key: self.embeddings1[idx], self.pair_2nd_key: self.embeddings2[idx], self.index_key: idx}
-
-    def __len__(self) -> int:
-        return len(self.embeddings1)
-
-
-class ImagePairsDataset(IPairsDataset):
-    """
-    Dataset to iterate over pairs of images.
-
-    """
-
-    def __init__(
-        self,
-        paths1: List[Path],
-        paths2: List[Path],
-        bboxes1: Optional[TBBoxes] = None,
-        bboxes2: Optional[TBBoxes] = None,
-        transform: Optional[TTransforms] = None,
-        f_imread: TImReader = imread_pillow,
-        pair_1st_key: str = PAIR_1ST_KEY,
-        pair_2nd_key: str = PAIR_2ND_KEY,
-        index_key: str = INDEX_KEY,
-        cache_size: Optional[int] = 0,
-    ):
-        """
-        Args:
-            paths1: Paths to the 1st input images
-            paths2: Paths to the 2nd input images
-            bboxes1: Should be either ``None`` or a sequence of bboxes.
-                If an image has ``N`` boxes, duplicate its
-                path ``N`` times and provide a bounding box for each of them.
-                If you want to get an embedding for the whole image, set the bbox to ``None`` for
-                this particular image path. The format is ``x1, y1, x2, y2``.
-            bboxes2: The same as ``bboxes1``, but for the second inputs.
-            transform: Augmentations for the images; set ``None`` to perform only normalisation and casting to tensor
-            f_imread: Function to read the images
-            pair_1st_key: Key to put the 1st images into the batches
-            pair_2nd_key: Key to put the 2nd images into the batches
-            index_key: Key to put samples' ids into the batches
-            cache_size: Size of the dataset's cache
-
-        """
-        assert len(paths1) == len(paths2)
-
-        if transform is None:
-            transform = get_normalisation_torch()
-
-        cache_size = cache_size // 2 if cache_size else None
-        dataset_args = {"transform": transform, "f_imread": f_imread, "cache_size": cache_size}
-        self.dataset1 = ImageBaseDataset(paths=paths1, bboxes=bboxes1, **dataset_args)
-        self.dataset2 = ImageBaseDataset(paths=paths2, bboxes=bboxes2, **dataset_args)
-
-        self.pair_1st_key = pair_1st_key
-        self.pair_2nd_key = pair_2nd_key
-        self.index_key = index_key
-
-    def __getitem__(self, idx: int) -> Dict[str, Union[int, Dict[str, Any]]]:
-        return {self.pair_1st_key: self.dataset1[idx], self.pair_2nd_key: self.dataset2[idx], self.index_key: idx}
+        i1, i2 = self.pair_ids[idx]
+        key = self.base_dataset.input_tensors_key
+        return {
+            self.pair_1st_key: self.base_dataset[i1][key],
+            self.pair_2nd_key: self.base_dataset[i2][key],
+            self.index_key: idx,
+        }

    def __len__(self) -> int:
-        return len(self.dataset1)
+        return len(self.pair_ids)


-__all__ = ["EmbeddingPairsDataset", "ImagePairsDataset"]
+__all__ = ["PairDataset"]
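
For context, a minimal usage sketch of the new `PairDataset` (not part of this change). It assumes `PairDataset` from the module above is importable; `ToyBaseDataset` is a hypothetical stand-in for an `IBaseDataset` implementation, and the batch keys are read back from the dataset's own attributes rather than hard-coded.

```python
# Minimal usage sketch, assuming `PairDataset` as defined above is in scope.
# `ToyBaseDataset` is a hypothetical stand-in for an IBaseDataset: each item
# is a dict holding a feature tensor under `input_tensors_key`, which is the
# contract `PairDataset.__getitem__` relies on.
import torch
from torch.utils.data import DataLoader


class ToyBaseDataset:
    input_tensors_key = "input_tensors"

    def __init__(self, features: torch.Tensor):
        self.features = features

    def __getitem__(self, idx: int):
        return {self.input_tensors_key: self.features[idx]}

    def __len__(self) -> int:
        return len(self.features)


base = ToyBaseDataset(torch.randn(4, 8))  # 4 items with 8-dim features
pairs = PairDataset(base_dataset=base, pair_ids=[(0, 1), (2, 3)])

for batch in DataLoader(pairs, batch_size=2):
    # Read the keys off the dataset instead of hard-coding the constants.
    x1 = batch[pairs.pair_1st_key]  # (2, 8): first elements of both pairs
    x2 = batch[pairs.pair_2nd_key]  # (2, 8): second elements of both pairs
    print(x1.shape, x2.shape, batch[pairs.index_key])
```

This is what the deleted `# todo 522` asked for: pairing logic is now modality-agnostic, so any base dataset that serves tensors under its `input_tensors_key` (images, embeddings, or otherwise) can be paired without a dedicated wrapper class.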