Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Made inference modality agnostic in re-ranking and other parts of the repo #542

Merged
merged 27 commits into from
Apr 28, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/readme/examples_source/postprocessing/predict.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ from oml.datasets.base import DatasetQueryGallery
from oml.inference.flat import inference_on_dataframe
from oml.models import ConcatSiamese, ViTExtractor
from oml.registry.transforms import get_transforms_for_pretrained
from oml.retrieval.postprocessors.pairwise import PairwiseImagesPostprocessor
from oml.retrieval.postprocessors.pairwise import PairwiseReranker
from oml.utils.download_mock_dataset import download_mock_dataset
from oml.utils.misc_torch import pairwise_dist

Expand All @@ -32,7 +32,7 @@ print("\nOriginal predictions:\n", torch.topk(distances, dim=1, k=3, largest=Fal

# 2. Let's initialise a random pairwise postprocessor to perform re-ranking
siamese = ConcatSiamese(extractor=extractor, mlp_hidden_dims=[100]) # Note! Replace it with your trained postprocessor
postprocessor = PairwiseImagesPostprocessor(top_n=3, pairwise_model=siamese, transforms=transforms)
postprocessor = PairwiseReranker(top_n=3, pairwise_model=siamese, transforms=transforms)

dataset = DatasetQueryGallery(df_val, extra_data={"embeddings": emb_val}, transform=transforms)
loader = DataLoader(dataset, batch_size=4)
Expand Down
4 changes: 2 additions & 2 deletions docs/readme/examples_source/postprocessing/train_val.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ from oml.inference.flat import inference_on_dataframe
from oml.metrics.embeddings import EmbeddingMetrics
from oml.miners.pairs import PairsMiner
from oml.models import ConcatSiamese, ViTExtractor
from oml.retrieval.postprocessors.pairwise import PairwiseImagesPostprocessor
from oml.retrieval.postprocessors.pairwise import PairwiseReranker
from oml.samplers.balance import BalanceSampler
from oml.transforms.images.torchvision import get_normalisation_resize_torch
from oml.utils.download_mock_dataset import download_mock_dataset
Expand Down Expand Up @@ -54,7 +54,7 @@ for batch in train_loader:
val_dataset = DatasetQueryGallery(df=df_val, extra_data={"embeddings": embeddings_val}, transform=transform)
valid_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

postprocessor = PairwiseImagesPostprocessor(top_n=3, pairwise_model=siamese, transforms=transform)
postprocessor = PairwiseReranker(top_n=3, pairwise_model=siamese, transforms=transform)
calculator = EmbeddingMetrics(postprocessor=postprocessor)
calculator.setup(num_samples=len(val_dataset))

Expand Down
18 changes: 0 additions & 18 deletions docs/source/contents/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,21 +52,3 @@ ImageQueryGalleryDataset
.. automethod:: get_query_ids
.. automethod:: get_gallery_ids
.. automethod:: visualize

EmbeddingPairsDataset
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: oml.datasets.pairs.EmbeddingPairsDataset
:undoc-members:
:show-inheritance:

.. automethod:: __init__
.. automethod:: __getitem__

ImagePairsDataset
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: oml.datasets.pairs.ImagePairsDataset
:undoc-members:
:show-inheritance:

.. automethod:: __init__
.. automethod:: __getitem__
14 changes: 12 additions & 2 deletions docs/source/contents/interfaces.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ IBaseDataset
:undoc-members:
:show-inheritance:

.. automethod:: __getitem__

ILabeledDataset
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: oml.interfaces.datasets.ILabeledDataset
Expand Down Expand Up @@ -86,9 +88,9 @@ IQueryGalleryLabeledDataset
.. automethod:: get_gallery_ids
.. automethod:: get_labels

IPairsDataset
IPairDataset
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: oml.interfaces.datasets.IPairsDataset
.. autoclass:: oml.interfaces.datasets.IPairDataset
:undoc-members:
:show-inheritance:

Expand Down Expand Up @@ -138,3 +140,11 @@ IPipelineLogger

.. automethod:: log_figure
.. automethod:: log_pipeline_info

IRetrievalPostprocessor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: oml.interfaces.retrieval.IRetrievalPostprocessor
:undoc-members:
:show-inheritance:

.. automethod:: process
32 changes: 3 additions & 29 deletions docs/source/contents/postprocessing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,37 +7,11 @@ Retrieval Post-Processing
.. contents::
:local:

IDistancesPostprocessor
PairwiseReranker
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: oml.interfaces.retrieval.IDistancesPostprocessor
:undoc-members:
:show-inheritance:

.. automethod:: process

PairwisePostprocessor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: oml.retrieval.postprocessors.pairwise.PairwisePostprocessor
:undoc-members:
:show-inheritance:

.. automethod:: process
.. automethod:: inference

PairwiseEmbeddingsPostprocessor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: oml.retrieval.postprocessors.pairwise.PairwiseEmbeddingsPostprocessor
.. autoclass:: oml.retrieval.postprocessors.pairwise.PairwiseReranker
:undoc-members:
:show-inheritance:

.. automethod:: __init__
.. automethod:: inference

PairwiseImagesPostprocessor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: oml.retrieval.postprocessors.pairwise.PairwiseImagesPostprocessor
:undoc-members:
:show-inheritance:

.. automethod:: __init__
.. automethod:: inference
.. automethod:: process
11 changes: 0 additions & 11 deletions oml/configs/postprocessor/pairwise_embeddings.yaml

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: pairwise_images
name: pairwise_reranker
args:
top_n: 3
pairwise_model:
Expand All @@ -12,10 +12,6 @@ args:
remove_fc: True
normalise_features: False
weights: resnet50_moco_v2
transforms:
name: norm_resize_torch
args:
im_size: 224
num_workers: 0
batch_size: 4
verbose: False
6 changes: 6 additions & 0 deletions oml/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from oml.datasets.images import (
ImageBaseDataset,
ImageLabeledDataset,
ImageQueryGalleryLabeledDataset,
)
from oml.datasets.pairs import PairDataset
2 changes: 0 additions & 2 deletions oml/datasets/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,6 @@ def get_gallery_ids(self) -> LongTensor:

def __getitem__(self, idx: int) -> Dict[str, Any]:
data = super().__getitem__(idx)
data[self.labels_key] = self.df.iloc[idx][LABELS_COLUMN]

# todo 522: remove
data[self.is_query_key] = bool(self.df[IS_QUERY_COLUMN][idx])
Expand Down Expand Up @@ -451,7 +450,6 @@ def get_retrieval_images_datasets(

check_retrieval_dataframe_format(df, dataset_root=dataset_root, verbose=verbose)

# todo 522: why do we need it?
# first half will consist of "train" split, second one of "val"
# so labels in train will be from 0 to N-1 and labels in test will be from N to K
mapper = {l: i for i, l in enumerate(df.sort_values(by=[SPLIT_COLUMN])[LABELS_COLUMN].unique())}
Expand Down
108 changes: 19 additions & 89 deletions oml/datasets/pairs.py
Original file line number Diff line number Diff line change
@@ -1,115 +1,45 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Dict, List, Tuple

from torch import Tensor

from oml.const import INDEX_KEY, PAIR_1ST_KEY, PAIR_2ND_KEY, TBBoxes
from oml.datasets.images import ImageBaseDataset
from oml.interfaces.datasets import IPairsDataset
from oml.transforms.images.torchvision import get_normalisation_torch
from oml.transforms.images.utils import TTransforms
from oml.utils.images.images import TImReader, imread_pillow
from oml.const import INDEX_KEY, PAIR_1ST_KEY, PAIR_2ND_KEY
from oml.interfaces.datasets import IBaseDataset, IPairDataset

# todo 522: make one modality agnostic instead of these two


class EmbeddingPairsDataset(IPairsDataset):
class PairDataset(IPairDataset):
"""
Dataset to iterate over pairs of embeddings.
Dataset to iterate over pairs of items.

"""

def __init__(
self,
embeddings1: Tensor,
embeddings2: Tensor,
base_dataset: IBaseDataset,
pair_ids: List[Tuple[int, int]],
pair_1st_key: str = PAIR_1ST_KEY,
pair_2nd_key: str = PAIR_2ND_KEY,
index_key: str = INDEX_KEY,
):
"""

Args:
embeddings1: The first input embeddings
embeddings2: The second input embeddings
pair_1st_key: Key to put ``embeddings1`` into the batches
pair_2nd_key: Key to put ``embeddings2`` into the batches
index_key: Key to put samples' ids into the batches

"""
assert embeddings1.shape == embeddings2.shape
assert embeddings1.ndim >= 2
self.base_dataset = base_dataset
self.pair_ids = pair_ids

self.pair_1st_key = pair_1st_key
self.pair_2nd_key = pair_2nd_key
self.index_key = index_key

self.embeddings1 = embeddings1
self.embeddings2 = embeddings2
self.index_key: str = index_key

def __getitem__(self, idx: int) -> Dict[str, Tensor]:
return {self.pair_1st_key: self.embeddings1[idx], self.pair_2nd_key: self.embeddings2[idx], self.index_key: idx}

def __len__(self) -> int:
return len(self.embeddings1)


class ImagePairsDataset(IPairsDataset):
"""
Dataset to iterate over pairs of images.

"""

def __init__(
self,
paths1: List[Path],
paths2: List[Path],
bboxes1: Optional[TBBoxes] = None,
bboxes2: Optional[TBBoxes] = None,
transform: Optional[TTransforms] = None,
f_imread: TImReader = imread_pillow,
pair_1st_key: str = PAIR_1ST_KEY,
pair_2nd_key: str = PAIR_2ND_KEY,
index_key: str = INDEX_KEY,
cache_size: Optional[int] = 0,
):
"""
Args:
paths1: Paths to the 1st input images
paths2: Paths to the 2nd input images
bboxes1: Should be either ``None`` or a sequence of bboxes.
If an image has ``N`` boxes, duplicate its
path ``N`` times and provide bounding box for each of them.
If you want to get an embedding for the whole image, set bbox to ``None`` for
this particular image path. The format is ``x1, y1, x2, y2``.
bboxes2: The same as ``bboxes1``, but for the second inputs.
transform: Augmentations for the images, set ``None`` to perform only normalisation and casting to tensor
f_imread: Function to read the images
pair_1st_key: Key to put the 1st images into the batches
pair_2nd_key: Key to put the 2nd images into the batches
index_key: Key to put samples' ids into the batches
cache_size: Size of the dataset's cache

"""
assert len(paths1) == len(paths2)

if transform is None:
transform = get_normalisation_torch()

cache_size = cache_size // 2 if cache_size else None
dataset_args = {"transform": transform, "f_imread": f_imread, "cache_size": cache_size}
self.dataset1 = ImageBaseDataset(paths=paths1, bboxes=bboxes1, **dataset_args)
self.dataset2 = ImageBaseDataset(paths=paths2, bboxes=bboxes2, **dataset_args)

self.pair_1st_key = pair_1st_key
self.pair_2nd_key = pair_2nd_key
self.index_key = index_key

def __getitem__(self, idx: int) -> Dict[str, Union[int, Dict[str, Any]]]:
return {self.pair_1st_key: self.dataset1[idx], self.pair_2nd_key: self.dataset2[idx], self.index_key: idx}
i1, i2 = self.pair_ids[idx]
key = self.base_dataset.input_tensors_key
return {
self.pair_1st_key: self.base_dataset[i1][key],
self.pair_2nd_key: self.base_dataset[i2][key],
self.index_key: idx,
}

def __len__(self) -> int:
return len(self.dataset1)
return len(self.pair_ids)


__all__ = ["EmbeddingPairsDataset", "ImagePairsDataset"]
__all__ = ["PairDataset"]
1 change: 1 addition & 0 deletions oml/inference/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from oml.inference.abstract import inference, inference_cached, pairwise_inference
Loading
Loading