@@ -294,7 +294,7 @@ docker pull omlteam/oml:cpu
import torch
from tqdm import tqdm

- from oml.datasets.base import DatasetWithLabels
+ from oml.datasets import ImageLabeledDataset
from oml.losses.triplet import TripletLossWithMiner
from oml.miners.inbatch_all_tri import AllTripletsMiner
from oml.models import ViTExtractor
@@ -306,7 +306,7 @@ df_train, _ = download_mock_dataset(global_paths=True)
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train()
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)

- train_dataset = DatasetWithLabels(df_train)
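+ # ImageLabeledDataset supersedes DatasetWithLabels: it yields images with their labels read from the dataframe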
+ train_dataset = ImageLabeledDataset(df_train)
criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner(), need_logs=True)
sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=sampler)
@@ -333,39 +333,28 @@ for batch in tqdm(train_loader):

[comment]:vanilla-validation-start
```python
- import torch
- from tqdm import tqdm

- from oml.datasets.base import DatasetQueryGallery
- from oml.metrics.embeddings import EmbeddingMetrics
+ from oml.datasets import ImageQueryGalleryLabeledDataset
+ from oml.inference import inference
+ from oml.metrics import calc_retrieval_metrics_rr
from oml.models import ViTExtractor
+ from oml.retrieval import RetrievalResults
from oml.utils.download_mock_dataset import download_mock_dataset
+ from oml.registry.transforms import get_transforms_for_pretrained

- _, df_val = download_mock_dataset(global_paths=True)
-
- extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).eval()
-
- val_dataset = DatasetQueryGallery(df_val)
-
- val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)
- calculator = EmbeddingMetrics(extra_keys=("paths",))
- calculator.setup(num_samples=len(val_dataset))
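+ # from_pretrained loads the checkpoint by its registered name, with no manual arch / weights arguments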
+ extractor = ViTExtractor.from_pretrained("vits16_dino")
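+ # the transforms matching this checkpoint; the second returned value (the image reader) is unused here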
+ transform, _ = get_transforms_for_pretrained("vits16_dino")

- with torch.no_grad():
-     for batch in tqdm(val_loader):
-         batch["embeddings"] = extractor(batch["input_tensors"])
-         calculator.update_data(batch)
+ _, df_val = download_mock_dataset(global_paths=True)
+ dataset = ImageQueryGalleryLabeledDataset(df_val, transform=transform)

- metrics = calculator.compute_metrics()
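+ # batched forward pass over the whole dataset; returns one embedding per item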
+ embeddings = inference(extractor, dataset, batch_size=4)

- # Logging
- print(calculator.metrics)  # metrics
- print(calculator.metrics_unreduced)  # metrics without averaging over queries
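+ # for each query, retrieve the 5 closest gallery items by embedding distance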
+ rr = RetrievalResults.compute_from_embeddings(embeddings, dataset, n_items_to_retrieve=5)
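+ # CMC@3, MAP@3 and MAP@5, and Precision@5, averaged over queries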
+ metrics = calc_retrieval_metrics_rr(rr, map_top_k=(3, 5), precision_top_k=(5,), cmc_top_k=(3,))

- # Visualisation
- calculator.get_plot_for_queries(query_ids=[0, 2], n_instances=5)  # draw predictions on predefined queries
- calculator.get_plot_for_worst_queries(metric_name="OVERALL/map/5", n_queries=2, n_instances=5)  # draw mistakes
- calculator.visualize()  # draw mistakes for all the available metrics
+ print(rr, "\n", metrics)
+ rr.visualize(query_ids=[2, 1], dataset=dataset).show()

```
[comment]:vanilla-validation-end
@@ -380,9 +369,10 @@ calculator.visualize() # draw mistakes for all the available metrics
[comment]:lightning-start
```python
import pytorch_lightning as pl
- import torch
+ from torch.utils.data import DataLoader
+ from torch.optim import SGD

- from oml.datasets.base import DatasetQueryGallery, DatasetWithLabels
+ from oml.datasets import ImageLabeledDataset, ImageQueryGalleryLabeledDataset
from oml.lightning.modules.extractor import ExtractorModule
from oml.lightning.callbacks.metric import MetricValCallback
from oml.losses.triplet import TripletLossWithMiner
@@ -405,16 +395,16 @@ df_train, df_val = download_mock_dataset(global_paths=True)
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False)

# train
- optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)
- train_dataset = DatasetWithLabels(df_train)
+ optimizer = SGD(extractor.parameters(), lr=1e-6)
+ train_dataset = ImageLabeledDataset(df_train)
criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner())
batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=3)
- train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=batch_sampler)
+ train_loader = DataLoader(train_dataset, batch_sampler=batch_sampler)

# val
- val_dataset = DatasetQueryGallery(df_val)
- val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)
- metric_callback = MetricValCallback(metric=EmbeddingMetrics(extra_keys=[train_dataset.paths_key,]), log_images=True)
+ val_dataset = ImageQueryGalleryLabeledDataset(df_val)
+ val_loader = DataLoader(val_dataset, batch_size=4)
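+ # EmbeddingMetrics is now configured with the dataset itself rather than with extra batch keys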
+ metric_callback = MetricValCallback(metric=EmbeddingMetrics(dataset=val_dataset), log_images=True)

# 1) Logging with Tensorboard
logger = TensorBoardPipelineLogger(".")
@@ -450,14 +440,13 @@ trainer.fit(pl_model, train_dataloaders=train_loader, val_dataloaders=val_loader

[comment]:usage-retrieval-start
```python
- import torch
-
from oml.datasets import ImageQueryGalleryDataset
from oml.inference import inference
from oml.models import ViTExtractor
from oml.registry.transforms import get_transforms_for_pretrained
from oml.utils.download_mock_dataset import download_mock_dataset
- from oml.utils.misc_torch import pairwise_dist
+ from oml.retrieval.retrieval_results import RetrievalResults
+

_, df_test = download_mock_dataset(global_paths=True)
del df_test["label"]  # we don't need gt labels for doing predictions
@@ -466,25 +455,14 @@ extractor = ViTExtractor.from_pretrained("vits16_dino")
transform, _ = get_transforms_for_pretrained("vits16_dino")

dataset = ImageQueryGalleryDataset(df_test, transform=transform)
-
embeddings = inference(extractor, dataset, batch_size=4, num_workers=0)
- embeddings_query = embeddings[dataset.get_query_ids()]
- embeddings_gallery = embeddings[dataset.get_gallery_ids()]

- # Now we can explicitly build a pairwise matrix of distances or save your RAM by using kNN
- use_knn = False
- top_k = 3
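+ # RetrievalResults replaces the manual pairwise-distance / kNN search removed above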
+ retrieval_results = RetrievalResults.compute_from_embeddings(embeddings, dataset, n_items_to_retrieve=5)

- if use_knn:
-     from sklearn.neighbors import NearestNeighbors
-     knn = NearestNeighbors(algorithm="auto", p=2).fit(embeddings_query)
-     dists, ii_closest = knn.kneighbors(embeddings_gallery, n_neighbors=top_k, return_distance=True)
+ retrieval_results.visualize(query_ids=[0, 1], dataset=dataset).show()

- else:
-     dist_mat = pairwise_dist(x1=embeddings_query, x2=embeddings_gallery, p=2)
-     dists, ii_closest = torch.topk(dist_mat, dim=1, k=top_k, largest=False)
+ print(retrieval_results)  # you get the ids of retrieved items and the corresponding distances

- print(f"Top {top_k} items closest to queries are:\n{ii_closest}")
```
[comment]:usage-retrieval-end
</p>