Skip to content

Commit 8f7023a

Browse files
authored
Made inference modality agnostic in re-ranking and other parts of the repo
Changelog (all the functions and classes on the right side are modality agnostic): * `EmbeddingPairsDataset`, `ImagePairsDataset` -> `PairDataset` * `pairwise_inference_on_images`, `pairwise_inference_on_embeddings` -> `pairwise_inference` * `IDistancesPostprocessor` -> (mostly renamed) -> `IRetrievalPostprocessor` * `PairwisePostprocessor`, `PairwiseEmbeddingsPostprocessor`, `PairwiseImagesPostprocessor` -> `PairwiseReranker` * `inference_on_images` -> `inference` * `inference_on_dataframe` -> `inference_cached` Also: * `EmbeddingMetrics` takes optional `dataset` argument in order to perform postprocessing. * Made postprocessing tests a bit more informative via making dummy models a bit less trivial (added bias to their outputs) Examples changed: * `train + val` and `prediction` for postprocessor * `retrieval usage` * + added `global_paths` parameter to `download_mock_dataset` so it looks nicer
1 parent 749c326 commit 8f7023a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+819
-1108
lines changed

README.md

+19-23
Original file line numberDiff line numberDiff line change
@@ -301,13 +301,12 @@ from oml.models import ViTExtractor
301301
from oml.samplers.balance import BalanceSampler
302302
from oml.utils.download_mock_dataset import download_mock_dataset
303303

304-
dataset_root = "mock_dataset/"
305-
df_train, _ = download_mock_dataset(dataset_root)
304+
df_train, _ = download_mock_dataset(global_paths=True)
306305

307306
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train()
308307
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)
309308

310-
train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root)
309+
train_dataset = DatasetWithLabels(df_train)
311310
criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner(), need_logs=True)
312311
sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=2)
313312
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=sampler)
@@ -342,12 +341,11 @@ from oml.metrics.embeddings import EmbeddingMetrics
342341
from oml.models import ViTExtractor
343342
from oml.utils.download_mock_dataset import download_mock_dataset
344343

345-
dataset_root = "mock_dataset/"
346-
_, df_val = download_mock_dataset(dataset_root)
344+
_, df_val = download_mock_dataset(global_paths=True)
347345

348346
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).eval()
349347

350-
val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root)
348+
val_dataset = DatasetQueryGallery(df_val)
351349

352350
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)
353351
calculator = EmbeddingMetrics(extra_keys=("paths",))
@@ -401,21 +399,20 @@ from oml.lightning.pipelines.logging import (
401399
WandBPipelineLogger,
402400
)
403401

404-
dataset_root = "mock_dataset/"
405-
df_train, df_val = download_mock_dataset(dataset_root)
402+
df_train, df_val = download_mock_dataset(global_paths=True)
406403

407404
# model
408405
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False)
409406

410407
# train
411408
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)
412-
train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root)
409+
train_dataset = DatasetWithLabels(df_train)
413410
criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner())
414411
batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=3)
415412
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=batch_sampler)
416413

417414
# val
418-
val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root)
415+
val_dataset = DatasetQueryGallery(df_val)
419416
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)
420417
metric_callback = MetricValCallback(metric=EmbeddingMetrics(extra_keys=[train_dataset.paths_key,]), log_images=True)
421418

@@ -455,37 +452,36 @@ trainer.fit(pl_model, train_dataloaders=train_loader, val_dataloaders=val_loader
455452
```python
456453
import torch
457454

458-
from oml.const import MOCK_DATASET_PATH
459-
from oml.inference.flat import inference_on_images
455+
from oml.datasets import ImageQueryGalleryDataset
456+
from oml.inference import inference
460457
from oml.models import ViTExtractor
461458
from oml.registry.transforms import get_transforms_for_pretrained
462459
from oml.utils.download_mock_dataset import download_mock_dataset
463460
from oml.utils.misc_torch import pairwise_dist
464461

465-
_, df_val = download_mock_dataset(MOCK_DATASET_PATH)
466-
df_val["path"] = df_val["path"].apply(lambda x: MOCK_DATASET_PATH / x)
467-
queries = df_val[df_val["is_query"]]["path"].tolist()
468-
galleries = df_val[df_val["is_gallery"]]["path"].tolist()
462+
_, df_test = download_mock_dataset(global_paths=True)
463+
del df_test["label"]  # we don't need ground-truth labels for making predictions
469464

470465
extractor = ViTExtractor.from_pretrained("vits16_dino")
471466
transform, _ = get_transforms_for_pretrained("vits16_dino")
472467

473-
args = {"num_workers": 0, "batch_size": 8}
474-
features_queries = inference_on_images(extractor, paths=queries, transform=transform, **args)
475-
features_galleries = inference_on_images(extractor, paths=galleries, transform=transform, **args)
468+
dataset = ImageQueryGalleryDataset(df_test, transform=transform)
469+
470+
embeddings = inference(extractor, dataset, batch_size=4, num_workers=0)
471+
embeddings_query = embeddings[dataset.get_query_ids()]
472+
embeddings_gallery = embeddings[dataset.get_gallery_ids()]
476473

477474
# Now we can explicitly build a pairwise matrix of distances or save your RAM by using kNN
478475
use_knn = False
479476
top_k = 3
480477

481478
if use_knn:
482479
from sklearn.neighbors import NearestNeighbors
483-
knn = NearestNeighbors(algorithm="auto", p=2)
484-
knn.fit(features_galleries)
485-
dists, ii_closest = knn.kneighbors(features_queries, n_neighbors=top_k, return_distance=True)
480+
knn = NearestNeighbors(algorithm="auto", p=2).fit(embeddings_gallery)
481+
dists, ii_closest = knn.kneighbors(embeddings_query, n_neighbors=top_k, return_distance=True)
486482

487483
else:
488-
dist_mat = pairwise_dist(x1=features_queries, x2=features_galleries)
484+
dist_mat = pairwise_dist(x1=embeddings_query, x2=embeddings_gallery, p=2)
489485
dists, ii_closest = torch.topk(dist_mat, dim=1, k=top_k, largest=False)
490486

491487
print(f"Top {top_k} items closest to queries are:\n {ii_closest}")

docs/readme/examples_source/extractor/retrieval_usage.md

+12-13
Original file line numberDiff line numberDiff line change
@@ -6,37 +6,36 @@
66
```python
77
import torch
88

9-
from oml.const import MOCK_DATASET_PATH
10-
from oml.inference.flat import inference_on_images
9+
from oml.datasets import ImageQueryGalleryDataset
10+
from oml.inference import inference
1111
from oml.models import ViTExtractor
1212
from oml.registry.transforms import get_transforms_for_pretrained
1313
from oml.utils.download_mock_dataset import download_mock_dataset
1414
from oml.utils.misc_torch import pairwise_dist
1515

16-
_, df_val = download_mock_dataset(MOCK_DATASET_PATH)
17-
df_val["path"] = df_val["path"].apply(lambda x: MOCK_DATASET_PATH / x)
18-
queries = df_val[df_val["is_query"]]["path"].tolist()
19-
galleries = df_val[df_val["is_gallery"]]["path"].tolist()
16+
_, df_test = download_mock_dataset(global_paths=True)
17+
del df_test["label"]  # we don't need ground-truth labels for making predictions
2018

2119
extractor = ViTExtractor.from_pretrained("vits16_dino")
2220
transform, _ = get_transforms_for_pretrained("vits16_dino")
2321

24-
args = {"num_workers": 0, "batch_size": 8}
25-
features_queries = inference_on_images(extractor, paths=queries, transform=transform, **args)
26-
features_galleries = inference_on_images(extractor, paths=galleries, transform=transform, **args)
22+
dataset = ImageQueryGalleryDataset(df_test, transform=transform)
23+
24+
embeddings = inference(extractor, dataset, batch_size=4, num_workers=0)
25+
embeddings_query = embeddings[dataset.get_query_ids()]
26+
embeddings_gallery = embeddings[dataset.get_gallery_ids()]
2727

2828
# Now we can explicitly build a pairwise matrix of distances or save your RAM by using kNN
2929
use_knn = False
3030
top_k = 3
3131

3232
if use_knn:
3333
from sklearn.neighbors import NearestNeighbors
34-
knn = NearestNeighbors(algorithm="auto", p=2)
35-
knn.fit(features_galleries)
36-
dists, ii_closest = knn.kneighbors(features_queries, n_neighbors=top_k, return_distance=True)
34+
knn = NearestNeighbors(algorithm="auto", p=2).fit(embeddings_gallery)
35+
dists, ii_closest = knn.kneighbors(embeddings_query, n_neighbors=top_k, return_distance=True)
3736

3837
else:
39-
dist_mat = pairwise_dist(x1=features_queries, x2=features_galleries)
38+
dist_mat = pairwise_dist(x1=embeddings_query, x2=embeddings_gallery, p=2)
4039
dists, ii_closest = torch.topk(dist_mat, dim=1, k=top_k, largest=False)
4140

4241
print(f"Top {top_k} items closest to queries are:\n {ii_closest}")

docs/readme/examples_source/extractor/train.md

+2-3
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,12 @@ from oml.models import ViTExtractor
1414
from oml.samplers.balance import BalanceSampler
1515
from oml.utils.download_mock_dataset import download_mock_dataset
1616

17-
dataset_root = "mock_dataset/"
18-
df_train, _ = download_mock_dataset(dataset_root)
17+
df_train, _ = download_mock_dataset(global_paths=True)
1918

2019
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train()
2120
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)
2221

23-
train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root)
22+
train_dataset = DatasetWithLabels(df_train)
2423
criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner(), need_logs=True)
2524
sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=2)
2625
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=sampler)

docs/readme/examples_source/extractor/train_2loaders_val.md

+3-6
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,18 @@ from oml.models import ViTExtractor
1515
from oml.transforms.images.torchvision import get_normalisation_resize_torch
1616
from oml.utils.download_mock_dataset import download_mock_dataset
1717

18-
dataset_root = "mock_dataset/"
19-
_, df_val = download_mock_dataset(dataset_root)
18+
_, df_val = download_mock_dataset(global_paths=True)
2019

2120
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False)
2221

2322
# 1st validation dataset (big images)
24-
val_dataset_1 = DatasetQueryGallery(df_val, dataset_root=dataset_root,
25-
transform=get_normalisation_resize_torch(im_size=224))
23+
val_dataset_1 = DatasetQueryGallery(df_val, transform=get_normalisation_resize_torch(im_size=224))
2624
val_loader_1 = torch.utils.data.DataLoader(val_dataset_1, batch_size=4)
2725
metric_callback_1 = MetricValCallback(metric=EmbeddingMetrics(extra_keys=[val_dataset_1.paths_key,]),
2826
log_images=True, loader_idx=0)
2927

3028
# 2nd validation dataset (small images)
31-
val_dataset_2 = DatasetQueryGallery(df_val, dataset_root=dataset_root,
32-
transform=get_normalisation_resize_torch(im_size=48))
29+
val_dataset_2 = DatasetQueryGallery(df_val, transform=get_normalisation_resize_torch(im_size=48))
3330
val_loader_2 = torch.utils.data.DataLoader(val_dataset_2, batch_size=4)
3431
metric_callback_2 = MetricValCallback(metric=EmbeddingMetrics(extra_keys=[val_dataset_2.paths_key,]),
3532
log_images=True, loader_idx=1)

docs/readme/examples_source/extractor/train_val_pl.md

+3-4
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,20 @@ from oml.lightning.pipelines.logging import (
2424
WandBPipelineLogger,
2525
)
2626

27-
dataset_root = "mock_dataset/"
28-
df_train, df_val = download_mock_dataset(dataset_root)
27+
df_train, df_val = download_mock_dataset(global_paths=True)
2928

3029
# model
3130
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False)
3231

3332
# train
3433
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)
35-
train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root)
34+
train_dataset = DatasetWithLabels(df_train)
3635
criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner())
3736
batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=3)
3837
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=batch_sampler)
3938

4039
# val
41-
val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root)
40+
val_dataset = DatasetQueryGallery(df_val)
4241
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)
4342
metric_callback = MetricValCallback(metric=EmbeddingMetrics(extra_keys=[train_dataset.paths_key,]), log_images=True)
4443

docs/readme/examples_source/extractor/train_val_pl_ddp.md

+3-4
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,20 @@ from oml.samplers.balance import BalanceSampler
1919
from oml.utils.download_mock_dataset import download_mock_dataset
2020
from pytorch_lightning.strategies import DDPStrategy
2121

22-
dataset_root = "mock_dataset/"
23-
df_train, df_val = download_mock_dataset(dataset_root)
22+
df_train, df_val = download_mock_dataset(global_paths=True)
2423

2524
# model
2625
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False)
2726

2827
# train
2928
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)
30-
train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root)
29+
train_dataset = DatasetWithLabels(df_train)
3130
criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner())
3231
batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=3)
3332
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=batch_sampler)
3433

3534
# val
36-
val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root)
35+
val_dataset = DatasetQueryGallery(df_val)
3736
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)
3837
metric_callback = MetricValCallbackDDP(metric=EmbeddingMetricsDDP()) # DDP specific
3938

docs/readme/examples_source/extractor/train_with_pml.md

+2-3
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,12 @@ from oml.utils.download_mock_dataset import download_mock_dataset
1313

1414
from pytorch_metric_learning import losses, distances, reducers, miners
1515

16-
dataset_root = "mock_dataset/"
17-
df_train, _ = download_mock_dataset(dataset_root)
16+
df_train, _ = download_mock_dataset(global_paths=True)
1817

1918
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train()
2019
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)
2120

22-
train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root)
21+
train_dataset = DatasetWithLabels(df_train)
2322

2423
# PML specific
2524
# criterion = losses.TripletMarginLoss(margin=0.2, triplets_per_anchor="all")

docs/readme/examples_source/extractor/train_with_pml_advanced.md

+2-3
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,12 @@ from oml.utils.download_mock_dataset import download_mock_dataset
1313

1414
from pytorch_metric_learning import losses, distances, reducers, miners
1515

16-
dataset_root = "mock_dataset/"
17-
df_train, _ = download_mock_dataset(dataset_root)
16+
df_train, _ = download_mock_dataset(global_paths=True)
1817

1918
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train()
2019
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)
2120

22-
train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root)
21+
train_dataset = DatasetWithLabels(df_train)
2322

2423
# PML specific
2524
distance = distances.LpDistance(p=2)

docs/readme/examples_source/extractor/val.md

+2-3
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,11 @@ from oml.metrics.embeddings import EmbeddingMetrics
1212
from oml.models import ViTExtractor
1313
from oml.utils.download_mock_dataset import download_mock_dataset
1414

15-
dataset_root = "mock_dataset/"
16-
_, df_val = download_mock_dataset(dataset_root)
15+
_, df_val = download_mock_dataset(global_paths=True)
1716

1817
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).eval()
1918

20-
val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root)
19+
val_dataset = DatasetQueryGallery(df_val)
2120

2221
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)
2322
calculator = EmbeddingMetrics(extra_keys=("paths",))

docs/readme/examples_source/extractor/val_with_sequence.md

+2-3
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,11 @@ from oml.metrics.embeddings import EmbeddingMetrics
4242
from oml.models import ViTExtractor
4343
from oml.utils.download_mock_dataset import download_mock_dataset
4444

45-
dataset_root = "mock_dataset/"
46-
_, df_val = download_mock_dataset(dataset_root, df_name="df_with_sequence.csv") # <- sequence info is in the file
45+
_, df_val = download_mock_dataset(global_paths=True, df_name="df_with_sequence.csv") # <- sequence info is in the file
4746

4847
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).eval()
4948

50-
val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root)
49+
val_dataset = DatasetQueryGallery(df_val)
5150

5251
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)
5352
calculator = EmbeddingMetrics(extra_keys=("paths",), sequence_key=val_dataset.sequence_key)

docs/readme/examples_source/postprocessing/predict.md

+19-23
Original file line numberDiff line numberDiff line change
@@ -5,44 +5,40 @@
55
[comment]:postprocessor-pred-start
66
```python
77
import torch
8-
from torch.utils.data import DataLoader
98

10-
from oml.const import PATHS_COLUMN
11-
from oml.datasets.base import DatasetQueryGallery
12-
from oml.inference.flat import inference_on_dataframe
9+
from oml.datasets import ImageQueryGalleryDataset
10+
from oml.inference import inference
1311
from oml.models import ConcatSiamese, ViTExtractor
1412
from oml.registry.transforms import get_transforms_for_pretrained
15-
from oml.retrieval.postprocessors.pairwise import PairwiseImagesPostprocessor
13+
from oml.retrieval.postprocessors.pairwise import PairwiseReranker
1614
from oml.utils.download_mock_dataset import download_mock_dataset
1715
from oml.utils.misc_torch import pairwise_dist
1816

19-
dataset_root = "mock_dataset/"
20-
download_mock_dataset(dataset_root)
17+
_, df_test = download_mock_dataset(global_paths=True)
18+
del df_test["label"] # we don't need gt labels for doing predictions
2119

22-
# 1. Let's use feature extractor to get predictions
2320
extractor = ViTExtractor.from_pretrained("vits16_dino")
2421
transforms, _ = get_transforms_for_pretrained("vits16_dino")
2522

26-
_, emb_val, _, df_val = inference_on_dataframe(dataset_root, "df.csv", extractor, transforms=transforms)
23+
dataset = ImageQueryGalleryDataset(df_test, transform=transforms)
2724

28-
is_query = df_val["is_query"].astype('bool').values
29-
distances = pairwise_dist(x1=emb_val[is_query], x2=emb_val[~is_query])
25+
# 1. Let's get top 5 galleries closest to every query...
26+
embeddings = inference(extractor, dataset, batch_size=4, num_workers=0)
27+
embeddings_query = embeddings[dataset.get_query_ids()]
28+
embeddings_gallery = embeddings[dataset.get_gallery_ids()]
3029

31-
print("\nOriginal predictions:\n", torch.topk(distances, dim=1, k=3, largest=False)[1])
30+
distances = pairwise_dist(x1=embeddings_query, x2=embeddings_gallery, p=2)
31+
ii_closest = torch.topk(distances, dim=1, k=5, largest=False)[1]
3232

33-
# 2. Let's initialise a random pairwise postprocessor to perform re-ranking
33+
# 2. ... and let's re-rank first 3 of them
3434
siamese = ConcatSiamese(extractor=extractor, mlp_hidden_dims=[100]) # Note! Replace it with your trained postprocessor
35-
postprocessor = PairwiseImagesPostprocessor(top_n=3, pairwise_model=siamese, transforms=transforms)
36-
37-
dataset = DatasetQueryGallery(df_val, extra_data={"embeddings": emb_val}, transform=transforms)
38-
loader = DataLoader(dataset, batch_size=4)
39-
40-
query_paths = df_val[PATHS_COLUMN][is_query].values
41-
gallery_paths = df_val[PATHS_COLUMN][~is_query].values
42-
distances_upd = postprocessor.process(distances=distances, queries=query_paths, galleries=gallery_paths)
43-
44-
print("\nPredictions after postprocessing:\n", torch.topk(distances_upd, dim=1, k=3, largest=False)[1])
35+
postprocessor = PairwiseReranker(top_n=3, pairwise_model=siamese, batch_size=4, num_workers=0)
36+
distances_upd = postprocessor.process(distances, dataset=dataset)
37+
ii_closest_upd = torch.topk(distances_upd, dim=1, k=5, largest=False)[1]
4538

39+
# You may see the first 3 positions have changed, but the rest remain the same:
40+
print("\nClosest galleries:\n", ii_closest)
41+
print("\nClosest galleries after re-ranking:\n", ii_closest_upd)
4642
```
4743
[comment]:postprocessor-pred-end
4844
</p>

0 commit comments

Comments
 (0)