Skip to content

Commit 8fabaa8

Browse files
authored
Integrated the previous changes with RetrievalResults class. Removed keys. Changed signature of EmbeddingMetrics.
**CHANGELOG** * removed keys: `IS_QUERY_KEY`, `IS_GALLERY_KEY`, `CATEGORIES_KEY`, `PATHS_KEY`, `X1_KEY`, `X2_KEY`, `Y1_KEY`, `Y2_KEY`, `SEQUENCE_KEY`. Categories and sequences are passed through `extra_data` instead. The rest is encapsulated in Dataset. * Removed `IMetricDDP`, `EmbeddingMetricsDDP`. Reason: having `EmbeddingMetrics` is enough, because we do accumulator sync there anyway. * Changed signatures of `EmbeddingMetrics`: keys replaced by providing a dataset; removed the `.sync()` and `.visualisation()` methods, and so on. * Updated `.md` examples and `.rst` docs Minor: * removed `calc_distance_matrix`, `validate_dataset` -- this logic happens in `RetrievalResults`; also removed `find_first_occurrences` -- we have `unique_by_ids` instead. * removed `DummyDataset` in tests (used `EmbeddingsQueryGalleryLabeledDataset` instead).
1 parent 47e67f0 commit 8fabaa8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+409
-1008
lines changed

Makefile

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,11 @@ pip_install_actual_oml:
119119
.PHONY: clean
120120
clean:
121121
find . -type d -name "__pycache__" -exec rm -r {} +
122+
find . -type d -name "lightning_logs" -exec rm -r {} +
123+
find . -type d -name "ml-runs" -exec rm -r {} +
124+
find . -type d -name "logs" -exec rm -r {} +
125+
find . -type d -name ".ipynb_checkpoints" -exec rm -r {} +
122126
find . -type f -name "*.log" -exec rm {} +
123-
find . -type f -name "*.predictions.json" -exec rm {} +
127+
find . -type f -name "*predictions.json" -exec rm {} +
124128
rm -rf docs/build
129+
rm -rf outputs/

README.md

Lines changed: 30 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ docker pull omlteam/oml:cpu
294294
import torch
295295
from tqdm import tqdm
296296

297-
from oml.datasets.base import DatasetWithLabels
297+
from oml.datasets import ImageLabeledDataset
298298
from oml.losses.triplet import TripletLossWithMiner
299299
from oml.miners.inbatch_all_tri import AllTripletsMiner
300300
from oml.models import ViTExtractor
@@ -306,7 +306,7 @@ df_train, _ = download_mock_dataset(global_paths=True)
306306
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train()
307307
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)
308308

309-
train_dataset = DatasetWithLabels(df_train)
309+
train_dataset = ImageLabeledDataset(df_train)
310310
criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner(), need_logs=True)
311311
sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=2)
312312
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=sampler)
@@ -333,39 +333,28 @@ for batch in tqdm(train_loader):
333333

334334
[comment]:vanilla-validation-start
335335
```python
336-
import torch
337-
from tqdm import tqdm
338336

339-
from oml.datasets.base import DatasetQueryGallery
340-
from oml.metrics.embeddings import EmbeddingMetrics
337+
from oml.datasets import ImageQueryGalleryLabeledDataset
338+
from oml.inference import inference
339+
from oml.metrics import calc_retrieval_metrics_rr
341340
from oml.models import ViTExtractor
341+
from oml.retrieval import RetrievalResults
342342
from oml.utils.download_mock_dataset import download_mock_dataset
343+
from oml.registry.transforms import get_transforms_for_pretrained
343344

344-
_, df_val = download_mock_dataset(global_paths=True)
345-
346-
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).eval()
347-
348-
val_dataset = DatasetQueryGallery(df_val)
349-
350-
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)
351-
calculator = EmbeddingMetrics(extra_keys=("paths",))
352-
calculator.setup(num_samples=len(val_dataset))
345+
extractor = ViTExtractor.from_pretrained("vits16_dino")
346+
transform, _ = get_transforms_for_pretrained("vits16_dino")
353347

354-
with torch.no_grad():
355-
for batch in tqdm(val_loader):
356-
batch["embeddings"] = extractor(batch["input_tensors"])
357-
calculator.update_data(batch)
348+
_, df_val = download_mock_dataset(global_paths=True)
349+
dataset = ImageQueryGalleryLabeledDataset(df_val, transform=transform)
358350

359-
metrics = calculator.compute_metrics()
351+
embeddings = inference(extractor, dataset, batch_size=4)
360352

361-
# Logging
362-
print(calculator.metrics) # metrics
363-
print(calculator.metrics_unreduced) # metrics without averaging over queries
353+
rr = RetrievalResults.compute_from_embeddings(embeddings, dataset, n_items_to_retrieve=5)
354+
metrics = calc_retrieval_metrics_rr(rr, map_top_k=(3, 5), precision_top_k=(5,), cmc_top_k=(3,))
364355

365-
# Visualisation
366-
calculator.get_plot_for_queries(query_ids=[0, 2], n_instances=5) # draw predictions on predefined queries
367-
calculator.get_plot_for_worst_queries(metric_name="OVERALL/map/5", n_queries=2, n_instances=5) # draw mistakes
368-
calculator.visualize() # draw mistakes for all the available metrics
356+
print(rr, "\n", metrics)
357+
rr.visualize(query_ids=[2, 1], dataset=dataset).show()
369358

370359
```
371360
[comment]:vanilla-validation-end
@@ -380,9 +369,10 @@ calculator.visualize() # draw mistakes for all the available metrics
380369
[comment]:lightning-start
381370
```python
382371
import pytorch_lightning as pl
383-
import torch
372+
from torch.utils.data import DataLoader
373+
from torch.optim import SGD
384374

385-
from oml.datasets.base import DatasetQueryGallery, DatasetWithLabels
375+
from oml.datasets import ImageLabeledDataset, ImageQueryGalleryLabeledDataset
386376
from oml.lightning.modules.extractor import ExtractorModule
387377
from oml.lightning.callbacks.metric import MetricValCallback
388378
from oml.losses.triplet import TripletLossWithMiner
@@ -405,16 +395,16 @@ df_train, df_val = download_mock_dataset(global_paths=True)
405395
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False)
406396

407397
# train
408-
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)
409-
train_dataset = DatasetWithLabels(df_train)
398+
optimizer = SGD(extractor.parameters(), lr=1e-6)
399+
train_dataset = ImageLabeledDataset(df_train)
410400
criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner())
411401
batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=3)
412-
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=batch_sampler)
402+
train_loader = DataLoader(train_dataset, batch_sampler=batch_sampler)
413403

414404
# val
415-
val_dataset = DatasetQueryGallery(df_val)
416-
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)
417-
metric_callback = MetricValCallback(metric=EmbeddingMetrics(extra_keys=[train_dataset.paths_key,]), log_images=True)
405+
val_dataset = ImageQueryGalleryLabeledDataset(df_val)
406+
val_loader = DataLoader(val_dataset, batch_size=4)
407+
metric_callback = MetricValCallback(metric=EmbeddingMetrics(dataset=val_dataset), log_images=True)
418408

419409
# 1) Logging with Tensorboard
420410
logger = TensorBoardPipelineLogger(".")
@@ -450,14 +440,13 @@ trainer.fit(pl_model, train_dataloaders=train_loader, val_dataloaders=val_loader
450440

451441
[comment]:usage-retrieval-start
452442
```python
453-
import torch
454-
455443
from oml.datasets import ImageQueryGalleryDataset
456444
from oml.inference import inference
457445
from oml.models import ViTExtractor
458446
from oml.registry.transforms import get_transforms_for_pretrained
459447
from oml.utils.download_mock_dataset import download_mock_dataset
460-
from oml.utils.misc_torch import pairwise_dist
448+
from oml.retrieval.retrieval_results import RetrievalResults
449+
461450

462451
_, df_test = download_mock_dataset(global_paths=True)
463452
del df_test["label"] # we don't need gt labels for doing predictions
@@ -466,25 +455,14 @@ extractor = ViTExtractor.from_pretrained("vits16_dino")
466455
transform, _ = get_transforms_for_pretrained("vits16_dino")
467456

468457
dataset = ImageQueryGalleryDataset(df_test, transform=transform)
469-
470458
embeddings = inference(extractor, dataset, batch_size=4, num_workers=0)
471-
embeddings_query = embeddings[dataset.get_query_ids()]
472-
embeddings_gallery = embeddings[dataset.get_gallery_ids()]
473459

474-
# Now we can explicitly build pairwise matrix of distances or save you RAM via using kNN
475-
use_knn = False
476-
top_k = 3
460+
retrieval_results = RetrievalResults.compute_from_embeddings(embeddings, dataset, n_items_to_retrieve=5)
477461

478-
if use_knn:
479-
from sklearn.neighbors import NearestNeighbors
480-
knn = NearestNeighbors(algorithm="auto", p=2).fit(embeddings_query)
481-
dists, ii_closest = knn.kneighbors(embeddings_gallery, n_neighbors=top_k, return_distance=True)
462+
retrieval_results.visualize(query_ids=[0, 1], dataset=dataset).show()
482463

483-
else:
484-
dist_mat = pairwise_dist(x1=embeddings_query, x2=embeddings_gallery, p=2)
485-
dists, ii_closest = torch.topk(dist_mat, dim=1, k=top_k, largest=False)
464+
print(retrieval_results) # you get the ids of retrieved items and the corresponding distances
486465

487-
print(f"Top {top_k} items closest to queries are:\n {ii_closest}")
488466
```
489467
[comment]:usage-retrieval-end
490468
</p>

docs/readme/examples_source/extractor/retrieval_usage.md

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@
44

55
[comment]:usage-retrieval-start
66
```python
7-
import torch
8-
97
from oml.datasets import ImageQueryGalleryDataset
108
from oml.inference import inference
119
from oml.models import ViTExtractor
1210
from oml.registry.transforms import get_transforms_for_pretrained
1311
from oml.utils.download_mock_dataset import download_mock_dataset
14-
from oml.utils.misc_torch import pairwise_dist
12+
from oml.retrieval.retrieval_results import RetrievalResults
13+
1514

1615
_, df_test = download_mock_dataset(global_paths=True)
1716
del df_test["label"] # we don't need gt labels for doing predictions
@@ -20,25 +19,14 @@ extractor = ViTExtractor.from_pretrained("vits16_dino")
2019
transform, _ = get_transforms_for_pretrained("vits16_dino")
2120

2221
dataset = ImageQueryGalleryDataset(df_test, transform=transform)
23-
2422
embeddings = inference(extractor, dataset, batch_size=4, num_workers=0)
25-
embeddings_query = embeddings[dataset.get_query_ids()]
26-
embeddings_gallery = embeddings[dataset.get_gallery_ids()]
2723

28-
# Now we can explicitly build pairwise matrix of distances or save you RAM via using kNN
29-
use_knn = False
30-
top_k = 3
24+
retrieval_results = RetrievalResults.compute_from_embeddings(embeddings, dataset, n_items_to_retrieve=5)
3125

32-
if use_knn:
33-
from sklearn.neighbors import NearestNeighbors
34-
knn = NearestNeighbors(algorithm="auto", p=2).fit(embeddings_query)
35-
dists, ii_closest = knn.kneighbors(embeddings_gallery, n_neighbors=top_k, return_distance=True)
26+
retrieval_results.visualize(query_ids=[0, 1], dataset=dataset).show()
3627

37-
else:
38-
dist_mat = pairwise_dist(x1=embeddings_query, x2=embeddings_gallery, p=2)
39-
dists, ii_closest = torch.topk(dist_mat, dim=1, k=top_k, largest=False)
28+
print(retrieval_results) # you get the ids of retrieved items and the corresponding distances
4029

41-
print(f"Top {top_k} items closest to queries are:\n {ii_closest}")
4230
```
4331
[comment]:usage-retrieval-end
4432
</p>

docs/readme/examples_source/extractor/train.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch
88
from tqdm import tqdm
99

10-
from oml.datasets.base import DatasetWithLabels
10+
from oml.datasets import ImageLabeledDataset
1111
from oml.losses.triplet import TripletLossWithMiner
1212
from oml.miners.inbatch_all_tri import AllTripletsMiner
1313
from oml.models import ViTExtractor
@@ -19,7 +19,7 @@ df_train, _ = download_mock_dataset(global_paths=True)
1919
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train()
2020
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)
2121

22-
train_dataset = DatasetWithLabels(df_train)
22+
train_dataset = ImageLabeledDataset(df_train)
2323
criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner(), need_logs=True)
2424
sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=2)
2525
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=sampler)

docs/readme/examples_source/extractor/train_2loaders_val.md

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
[comment]:lightning-2loaders-start
66
```python
77
import pytorch_lightning as pl
8-
import torch
98

10-
from oml.datasets.base import DatasetQueryGallery
9+
from torch.utils.data import DataLoader
10+
11+
from oml.datasets import ImageQueryGalleryLabeledDataset
1112
from oml.lightning.callbacks.metric import MetricValCallback
1213
from oml.lightning.modules.extractor import ExtractorModule
1314
from oml.metrics.embeddings import EmbeddingMetrics
@@ -20,24 +21,24 @@ _, df_val = download_mock_dataset(global_paths=True)
2021
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False)
2122

2223
# 1st validation dataset (big images)
23-
val_dataset_1 = DatasetQueryGallery(df_val, transform=get_normalisation_resize_torch(im_size=224))
24-
val_loader_1 = torch.utils.data.DataLoader(val_dataset_1, batch_size=4)
25-
metric_callback_1 = MetricValCallback(metric=EmbeddingMetrics(extra_keys=[val_dataset_1.paths_key,]),
24+
val_dataset_1 = ImageQueryGalleryLabeledDataset(df_val, transform=get_normalisation_resize_torch(im_size=224))
25+
val_loader_1 = DataLoader(val_dataset_1, batch_size=4)
26+
metric_callback_1 = MetricValCallback(metric=EmbeddingMetrics(dataset=val_dataset_1),
2627
log_images=True, loader_idx=0)
2728

2829
# 2nd validation dataset (small images)
29-
val_dataset_2 = DatasetQueryGallery(df_val, transform=get_normalisation_resize_torch(im_size=48))
30-
val_loader_2 = torch.utils.data.DataLoader(val_dataset_2, batch_size=4)
31-
metric_callback_2 = MetricValCallback(metric=EmbeddingMetrics(extra_keys=[val_dataset_2.paths_key,]),
30+
val_dataset_2 = ImageQueryGalleryLabeledDataset(df_val, transform=get_normalisation_resize_torch(im_size=48))
31+
val_loader_2 = DataLoader(val_dataset_2, batch_size=4)
32+
metric_callback_2 = MetricValCallback(metric=EmbeddingMetrics(dataset=val_dataset_2),
3233
log_images=True, loader_idx=1)
3334

3435
# run validation
3536
pl_model = ExtractorModule(extractor, None, None)
3637
trainer = pl.Trainer(max_epochs=3, callbacks=[metric_callback_1, metric_callback_2], num_sanity_val_steps=0)
3738
trainer.validate(pl_model, dataloaders=(val_loader_1, val_loader_2))
3839

39-
print(metric_callback_1.metric.metrics)
40-
print(metric_callback_2.metric.metrics)
40+
print(metric_callback_1.metric.retrieval_results)
41+
print(metric_callback_2.metric.retrieval_results)
4142
```
4243
[comment]:lightning-2loaders-end
4344
</p>

docs/readme/examples_source/extractor/train_val_pl.md

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
[comment]:lightning-start
66
```python
77
import pytorch_lightning as pl
8-
import torch
8+
from torch.utils.data import DataLoader
9+
from torch.optim import SGD
910

10-
from oml.datasets.base import DatasetQueryGallery, DatasetWithLabels
11+
from oml.datasets import ImageLabeledDataset, ImageQueryGalleryLabeledDataset
1112
from oml.lightning.modules.extractor import ExtractorModule
1213
from oml.lightning.callbacks.metric import MetricValCallback
1314
from oml.losses.triplet import TripletLossWithMiner
@@ -30,16 +31,16 @@ df_train, df_val = download_mock_dataset(global_paths=True)
3031
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False)
3132

3233
# train
33-
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)
34-
train_dataset = DatasetWithLabels(df_train)
34+
optimizer = SGD(extractor.parameters(), lr=1e-6)
35+
train_dataset = ImageLabeledDataset(df_train)
3536
criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner())
3637
batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=3)
37-
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=batch_sampler)
38+
train_loader = DataLoader(train_dataset, batch_sampler=batch_sampler)
3839

3940
# val
40-
val_dataset = DatasetQueryGallery(df_val)
41-
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)
42-
metric_callback = MetricValCallback(metric=EmbeddingMetrics(extra_keys=[train_dataset.paths_key,]), log_images=True)
41+
val_dataset = ImageQueryGalleryLabeledDataset(df_val)
42+
val_loader = DataLoader(val_dataset, batch_size=4)
43+
metric_callback = MetricValCallback(metric=EmbeddingMetrics(dataset=val_dataset), log_images=True)
4344

4445
# 1) Logging with Tensorboard
4546
logger = TensorBoardPipelineLogger(".")

docs/readme/examples_source/extractor/train_val_pl_ddp.md

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66
[comment]:lightning-ddp-start
77
```python
88
import pytorch_lightning as pl
9-
import torch
9+
from torch.utils.data import DataLoader
10+
from torch.optim import SGD
1011

11-
from oml.datasets.base import DatasetQueryGallery, DatasetWithLabels
12+
from oml.datasets import ImageQueryGalleryLabeledDataset, ImageLabeledDataset
1213
from oml.lightning.modules.extractor import ExtractorModuleDDP
1314
from oml.lightning.callbacks.metric import MetricValCallbackDDP
1415
from oml.losses.triplet import TripletLossWithMiner
15-
from oml.metrics.embeddings import EmbeddingMetricsDDP
16+
from oml.metrics.embeddings import EmbeddingMetrics
1617
from oml.miners.inbatch_all_tri import AllTripletsMiner
1718
from oml.models import ViTExtractor
1819
from oml.samplers.balance import BalanceSampler
@@ -25,16 +26,16 @@ df_train, df_val = download_mock_dataset(global_paths=True)
2526
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False)
2627

2728
# train
28-
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)
29-
train_dataset = DatasetWithLabels(df_train)
29+
optimizer = SGD(extractor.parameters(), lr=1e-6)
30+
train_dataset = ImageLabeledDataset(df_train)
3031
criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner())
3132
batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=3)
32-
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=batch_sampler)
33+
train_loader = DataLoader(train_dataset, batch_sampler=batch_sampler)
3334

3435
# val
35-
val_dataset = DatasetQueryGallery(df_val)
36-
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)
37-
metric_callback = MetricValCallbackDDP(metric=EmbeddingMetricsDDP()) # DDP specific
36+
val_dataset = ImageQueryGalleryLabeledDataset(df_val)
37+
val_loader = DataLoader(val_dataset, batch_size=4)
38+
metric_callback = MetricValCallbackDDP(metric=EmbeddingMetrics(dataset=val_dataset)) # DDP specific
3839

3940
# run
4041
pl_model = ExtractorModuleDDP(extractor=extractor, criterion=criterion, optimizer=optimizer,

docs/readme/examples_source/extractor/train_with_pml.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch
77
from tqdm import tqdm
88

9-
from oml.datasets.base import DatasetWithLabels
9+
from oml.datasets import ImageLabeledDataset
1010
from oml.models import ViTExtractor
1111
from oml.samplers.balance import BalanceSampler
1212
from oml.utils.download_mock_dataset import download_mock_dataset
@@ -18,7 +18,7 @@ df_train, _ = download_mock_dataset(global_paths=True)
1818
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train()
1919
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)
2020

21-
train_dataset = DatasetWithLabels(df_train)
21+
train_dataset = ImageLabeledDataset(df_train)
2222

2323
# PML specific
2424
# criterion = losses.TripletMarginLoss(margin=0.2, triplets_per_anchor="all")

docs/readme/examples_source/extractor/train_with_pml_advanced.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch
77
from tqdm import tqdm
88

9-
from oml.datasets.base import DatasetWithLabels
9+
from oml.datasets import ImageLabeledDataset
1010
from oml.models import ViTExtractor
1111
from oml.samplers.balance import BalanceSampler
1212
from oml.utils.download_mock_dataset import download_mock_dataset
@@ -18,7 +18,7 @@ df_train, _ = download_mock_dataset(global_paths=True)
1818
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train()
1919
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)
2020

21-
train_dataset = DatasetWithLabels(df_train)
21+
train_dataset = ImageLabeledDataset(df_train)
2222

2323
# PML specific
2424
distance = distances.LpDistance(p=2)

0 commit comments

Comments
 (0)