Skip to content

Commit 08d7048

Browse files
committed
simplified postprocessing and inference
1 parent 4d381ed commit 08d7048

File tree

24 files changed

+275
-667
lines changed

24 files changed

+275
-667
lines changed

docs/readme/examples_source/postprocessing/predict.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ from oml.datasets.base import DatasetQueryGallery
1212
from oml.inference.flat import inference_on_dataframe
1313
from oml.models import ConcatSiamese, ViTExtractor
1414
from oml.registry.transforms import get_transforms_for_pretrained
15-
from oml.retrieval.postprocessors.pairwise import PairwiseImagesPostprocessor
15+
from oml.retrieval.postprocessors.pairwise import PairwiseReranker
1616
from oml.utils.download_mock_dataset import download_mock_dataset
1717
from oml.utils.misc_torch import pairwise_dist
1818

@@ -32,7 +32,7 @@ print("\nOriginal predictions:\n", torch.topk(distances, dim=1, k=3, largest=Fal
3232

3333
# 2. Let's initialise a random pairwise postprocessor to perform re-ranking
3434
siamese = ConcatSiamese(extractor=extractor, mlp_hidden_dims=[100]) # Note! Replace it with your trained postprocessor
35-
postprocessor = PairwiseImagesPostprocessor(top_n=3, pairwise_model=siamese, transforms=transforms)
35+
postprocessor = PairwiseReranker(top_n=3, pairwise_model=siamese, transforms=transforms)
3636

3737
dataset = DatasetQueryGallery(df_val, extra_data={"embeddings": emb_val}, transform=transforms)
3838
loader = DataLoader(dataset, batch_size=4)

docs/readme/examples_source/postprocessing/train_val.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ from oml.inference.flat import inference_on_dataframe
1515
from oml.metrics.embeddings import EmbeddingMetrics
1616
from oml.miners.pairs import PairsMiner
1717
from oml.models import ConcatSiamese, ViTExtractor
18-
from oml.retrieval.postprocessors.pairwise import PairwiseImagesPostprocessor
18+
from oml.retrieval.postprocessors.pairwise import PairwiseReranker
1919
from oml.samplers.balance import BalanceSampler
2020
from oml.transforms.images.torchvision import get_normalisation_resize_torch
2121
from oml.utils.download_mock_dataset import download_mock_dataset
@@ -54,7 +54,7 @@ for batch in train_loader:
5454
val_dataset = DatasetQueryGallery(df=df_val, extra_data={"embeddings": embeddings_val}, transform=transform)
5555
valid_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)
5656

57-
postprocessor = PairwiseImagesPostprocessor(top_n=3, pairwise_model=siamese, transforms=transform)
57+
postprocessor = PairwiseReranker(top_n=3, pairwise_model=siamese, transforms=transform)
5858
calculator = EmbeddingMetrics(postprocessor=postprocessor)
5959
calculator.setup(num_samples=len(val_dataset))
6060

docs/source/contents/datasets.rst

+2-11
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,9 @@ ImageQueryGalleryLabeledDataset
3737
.. automethod:: get_query_ids
3838
.. automethod:: get_gallery_ids
3939

40-
EmbeddingPairsDataset
40+
PairDataset
4141
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
42-
.. autoclass:: oml.datasets.pairs.EmbeddingPairsDataset
43-
:undoc-members:
44-
:show-inheritance:
45-
46-
.. automethod:: __init__
47-
.. automethod:: __getitem__
48-
49-
ImagePairsDataset
50-
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
51-
.. autoclass:: oml.datasets.pairs.ImagePairsDataset
42+
.. autoclass:: oml.datasets.pairs.PairDataset
5243
:undoc-members:
5344
:show-inheritance:
5445

docs/source/contents/interfaces.rst

+10-2
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,9 @@ IQueryGalleryLabeledDataset
8686
.. automethod:: get_gallery_ids
8787
.. automethod:: get_labels
8888

89-
IPairsDataset
89+
IPairDataset
9090
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
91-
.. autoclass:: oml.interfaces.datasets.IPairsDataset
91+
.. autoclass:: oml.interfaces.datasets.IPairDataset
9292
:undoc-members:
9393
:show-inheritance:
9494

@@ -138,3 +138,11 @@ IPipelineLogger
138138

139139
.. automethod:: log_figure
140140
.. automethod:: log_pipeline_info
141+
142+
IRetrievalPostprocessor
143+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
144+
.. autoclass:: oml.interfaces.retrieval.IRetrievalPostprocessor
145+
:undoc-members:
146+
:show-inheritance:
147+
148+
.. automethod:: process

docs/source/contents/postprocessing.rst

+3-29
Original file line numberDiff line numberDiff line change
@@ -7,37 +7,11 @@ Retrieval Post-Processing
77
.. contents::
88
:local:
99

10-
IDistancesPostprocessor
10+
PairwiseReranker
1111
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
12-
.. autoclass:: oml.interfaces.retrieval.IDistancesPostprocessor
13-
:undoc-members:
14-
:show-inheritance:
15-
16-
.. automethod:: process
17-
18-
PairwisePostprocessor
19-
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
20-
.. autoclass:: oml.retrieval.postprocessors.pairwise.PairwisePostprocessor
21-
:undoc-members:
22-
:show-inheritance:
23-
24-
.. automethod:: process
25-
.. automethod:: inference
26-
27-
PairwiseEmbeddingsPostprocessor
28-
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
29-
.. autoclass:: oml.retrieval.postprocessors.pairwise.PairwiseEmbeddingsPostprocessor
12+
.. autoclass:: oml.retrieval.postprocessors.pairwise.PairwiseReranker
3013
:undoc-members:
3114
:show-inheritance:
3215

3316
.. automethod:: __init__
34-
.. automethod:: inference
35-
36-
PairwiseImagesPostprocessor
37-
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
38-
.. autoclass:: oml.retrieval.postprocessors.pairwise.PairwiseImagesPostprocessor
39-
:undoc-members:
40-
:show-inheritance:
41-
42-
.. automethod:: __init__
43-
.. automethod:: inference
17+
.. automethod:: process

oml/configs/postprocessor/pairwise_embeddings.yaml

-11
This file was deleted.

oml/configs/postprocessor/pairwise_images.yaml oml/configs/postprocessor/pairwise_reranker.yaml

+1-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: pairwise_images
1+
name: pairwise_reranker
22
args:
33
top_n: 3
44
pairwise_model:
@@ -12,10 +12,6 @@ args:
1212
remove_fc: True
1313
normalise_features: False
1414
weights: resnet50_moco_v2
15-
transforms:
16-
name: norm_resize_torch
17-
args:
18-
im_size: 224
1915
num_workers: 0
2016
batch_size: 4
2117
verbose: False

oml/datasets/pairs.py

+19-91
Original file line numberDiff line numberDiff line change
@@ -1,115 +1,43 @@
1-
from pathlib import Path
2-
from typing import Any, Dict, List, Optional, Union
1+
from typing import Dict, List, Tuple
32

43
from torch import Tensor
54

6-
from oml.const import INDEX_KEY, PAIR_1ST_KEY, PAIR_2ND_KEY, TBBoxes
7-
from oml.datasets.images import ImageBaseDataset
8-
from oml.interfaces.datasets import IPairsDataset
9-
from oml.transforms.images.torchvision import get_normalisation_torch
10-
from oml.transforms.images.utils import TTransforms
11-
from oml.utils.images.images import TImReader, imread_pillow
5+
from oml.const import INDEX_KEY, PAIR_1ST_KEY, PAIR_2ND_KEY
6+
from oml.interfaces.datasets import IBaseDataset, IPairDataset
127

13-
# todo 522: make one modality agnostic instead of these two
148

15-
16-
class EmbeddingPairsDataset(IPairsDataset):
9+
class PairDataset(IPairDataset):
1710
"""
18-
Dataset to iterate over pairs of embeddings.
11+
Dataset to iterate over pairs of items.
1912
2013
"""
2114

2215
def __init__(
2316
self,
24-
embeddings1: Tensor,
25-
embeddings2: Tensor,
17+
base_dataset: IBaseDataset,
18+
pair_ids: List[Tuple[int, int]],
2619
pair_1st_key: str = PAIR_1ST_KEY,
2720
pair_2nd_key: str = PAIR_2ND_KEY,
2821
index_key: str = INDEX_KEY,
2922
):
30-
"""
31-
32-
Args:
33-
embeddings1: The first input embeddings
34-
embeddings2: The second input embeddings
35-
pair_1st_key: Key to put ``embeddings1`` into the batches
36-
pair_2nd_key: Key to put ``embeddings2`` into the batches
37-
index_key: Key to put samples' ids into the batches
38-
39-
"""
40-
assert embeddings1.shape == embeddings2.shape
41-
assert embeddings1.ndim >= 2
23+
self.base_dataset = base_dataset
24+
self.pair_ids = pair_ids
4225

4326
self.pair_1st_key = pair_1st_key
4427
self.pair_2nd_key = pair_2nd_key
45-
self.index_key = index_key
46-
47-
self.embeddings1 = embeddings1
48-
self.embeddings2 = embeddings2
28+
self.index_key: str = index_key
4929

5030
def __getitem__(self, idx: int) -> Dict[str, Tensor]:
51-
return {self.pair_1st_key: self.embeddings1[idx], self.pair_2nd_key: self.embeddings2[idx], self.index_key: idx}
52-
53-
def __len__(self) -> int:
54-
return len(self.embeddings1)
55-
56-
57-
class ImagePairsDataset(IPairsDataset):
58-
"""
59-
Dataset to iterate over pairs of images.
60-
61-
"""
62-
63-
def __init__(
64-
self,
65-
paths1: List[Path],
66-
paths2: List[Path],
67-
bboxes1: Optional[TBBoxes] = None,
68-
bboxes2: Optional[TBBoxes] = None,
69-
transform: Optional[TTransforms] = None,
70-
f_imread: TImReader = imread_pillow,
71-
pair_1st_key: str = PAIR_1ST_KEY,
72-
pair_2nd_key: str = PAIR_2ND_KEY,
73-
index_key: str = INDEX_KEY,
74-
cache_size: Optional[int] = 0,
75-
):
76-
"""
77-
Args:
78-
paths1: Paths to the 1st input images
79-
paths2: Paths to the 2nd input images
80-
bboxes1: Should be either ``None`` or a sequence of bboxes.
81-
If an image has ``N`` boxes, duplicate its
82-
path ``N`` times and provide bounding box for each of them.
83-
If you want to get an embedding for the whole image, set bbox to ``None`` for
84-
this particular image path. The format is ``x1, y1, x2, y2``.
85-
bboxes2: The same as ``bboxes1``, but for the second inputs.
86-
transform: Augmentations for the images, set ``None`` to perform only normalisation and casting to tensor
87-
f_imread: Function to read the images
88-
pair_1st_key: Key to put the 1st images into the batches
89-
pair_2nd_key: Key to put the 2nd images into the batches
90-
index_key: Key to put samples' ids into the batches
91-
cache_size: Size of the dataset's cache
92-
93-
"""
94-
assert len(paths1) == len(paths2)
95-
96-
if transform is None:
97-
transform = get_normalisation_torch()
98-
99-
cache_size = cache_size // 2 if cache_size else None
100-
dataset_args = {"transform": transform, "f_imread": f_imread, "cache_size": cache_size}
101-
self.dataset1 = ImageBaseDataset(paths=paths1, bboxes=bboxes1, **dataset_args)
102-
self.dataset2 = ImageBaseDataset(paths=paths2, bboxes=bboxes2, **dataset_args)
103-
104-
self.pair_1st_key = pair_1st_key
105-
self.pair_2nd_key = pair_2nd_key
106-
self.index_key = index_key
107-
108-
def __getitem__(self, idx: int) -> Dict[str, Union[int, Dict[str, Any]]]:
109-
return {self.pair_1st_key: self.dataset1[idx], self.pair_2nd_key: self.dataset2[idx], self.index_key: idx}
31+
i1, i2 = self.pair_ids[idx]
32+
key = self.base_dataset.input_tensors_key
33+
return {
34+
self.pair_1st_key: self.base_dataset[i1][key],
35+
self.pair_2nd_key: self.base_dataset[i2][key],
36+
self.index_key: idx,
37+
}
11038

11139
def __len__(self) -> int:
112-
return len(self.dataset1)
40+
return len(self.pair_ids)
11341

11442

115-
__all__ = ["EmbeddingPairsDataset", "ImagePairsDataset"]
43+
__all__ = ["PairDataset"]

oml/inference/flat.py

-100
This file was deleted.

0 commit comments

Comments
 (0)