@@ -301,13 +301,12 @@ from oml.models import ViTExtractor
301
301
from oml.samplers.balance import BalanceSampler
302
302
from oml.utils.download_mock_dataset import download_mock_dataset
303
303
304
- dataset_root = " mock_dataset/"
305
- df_train, _ = download_mock_dataset(dataset_root)
304
+ df_train, _ = download_mock_dataset(global_paths = True )
306
305
307
306
extractor = ViTExtractor(" vits16_dino" , arch = " vits16" , normalise_features = False ).train()
308
307
optimizer = torch.optim.SGD(extractor.parameters(), lr = 1e-6 )
309
308
310
- train_dataset = DatasetWithLabels(df_train, dataset_root = dataset_root )
309
+ train_dataset = DatasetWithLabels(df_train)
311
310
criterion = TripletLossWithMiner(margin = 0.1 , miner = AllTripletsMiner(), need_logs = True )
312
311
sampler = BalanceSampler(train_dataset.get_labels(), n_labels = 2 , n_instances = 2 )
313
312
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler = sampler)
@@ -342,12 +341,11 @@ from oml.metrics.embeddings import EmbeddingMetrics
342
341
from oml.models import ViTExtractor
343
342
from oml.utils.download_mock_dataset import download_mock_dataset
344
343
345
- dataset_root = " mock_dataset/"
346
- _, df_val = download_mock_dataset(dataset_root)
344
+ _, df_val = download_mock_dataset(global_paths = True )
347
345
348
346
extractor = ViTExtractor(" vits16_dino" , arch = " vits16" , normalise_features = False ).eval()
349
347
350
- val_dataset = DatasetQueryGallery(df_val, dataset_root = dataset_root )
348
+ val_dataset = DatasetQueryGallery(df_val)
351
349
352
350
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = 4 )
353
351
calculator = EmbeddingMetrics(extra_keys = (" paths" ,))
@@ -401,21 +399,20 @@ from oml.lightning.pipelines.logging import (
401
399
WandBPipelineLogger,
402
400
)
403
401
404
- dataset_root = " mock_dataset/"
405
- df_train, df_val = download_mock_dataset(dataset_root)
402
+ df_train, df_val = download_mock_dataset(global_paths = True )
406
403
407
404
# model
408
405
extractor = ViTExtractor(" vits16_dino" , arch = " vits16" , normalise_features = False )
409
406
410
407
# train
411
408
optimizer = torch.optim.SGD(extractor.parameters(), lr = 1e-6 )
412
- train_dataset = DatasetWithLabels(df_train, dataset_root = dataset_root )
409
+ train_dataset = DatasetWithLabels(df_train)
413
410
criterion = TripletLossWithMiner(margin = 0.1 , miner = AllTripletsMiner())
414
411
batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels = 2 , n_instances = 3 )
415
412
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler = batch_sampler)
416
413
417
414
# val
418
- val_dataset = DatasetQueryGallery(df_val, dataset_root = dataset_root )
415
+ val_dataset = DatasetQueryGallery(df_val)
419
416
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = 4 )
420
417
metric_callback = MetricValCallback(metric = EmbeddingMetrics(extra_keys = [train_dataset.paths_key,]), log_images = True )
421
418
@@ -455,37 +452,36 @@ trainer.fit(pl_model, train_dataloaders=train_loader, val_dataloaders=val_loader
455
452
``` python
456
453
import torch
457
454
458
- from oml.const import MOCK_DATASET_PATH
459
- from oml.inference.flat import inference_on_images
455
+ from oml.datasets import ImageQueryGalleryDataset
456
+ from oml.inference import inference
460
457
from oml.models import ViTExtractor
461
458
from oml.registry.transforms import get_transforms_for_pretrained
462
459
from oml.utils.download_mock_dataset import download_mock_dataset
463
460
from oml.utils.misc_torch import pairwise_dist
464
461
465
- _, df_val = download_mock_dataset(MOCK_DATASET_PATH )
466
- df_val[" path" ] = df_val[" path" ].apply(lambda x : MOCK_DATASET_PATH / x)
467
- queries = df_val[df_val[" is_query" ]][" path" ].tolist()
468
- galleries = df_val[df_val[" is_gallery" ]][" path" ].tolist()
462
+ _, df_test = download_mock_dataset(global_paths = True )
463
+ del df_test[" label" ] # we don't need gt labels for doing predictions
469
464
470
465
extractor = ViTExtractor.from_pretrained(" vits16_dino" )
471
466
transform, _ = get_transforms_for_pretrained(" vits16_dino" )
472
467
473
- args = {" num_workers" : 0 , " batch_size" : 8 }
474
- features_queries = inference_on_images(extractor, paths = queries, transform = transform, ** args)
475
- features_galleries = inference_on_images(extractor, paths = galleries, transform = transform, ** args)
468
+ dataset = ImageQueryGalleryDataset(df_test, transform = transform)
469
+
470
+ embeddings = inference(extractor, dataset, batch_size = 4 , num_workers = 0 )
471
+ embeddings_query = embeddings[dataset.get_query_ids()]
472
+ embeddings_gallery = embeddings[dataset.get_gallery_ids()]
476
473
477
474
# Now we can explicitly build pairwise matrix of distances or save you RAM via using kNN
478
475
use_knn = False
479
476
top_k = 3
480
477
481
478
if use_knn:
482
479
from sklearn.neighbors import NearestNeighbors
483
- knn = NearestNeighbors(algorithm = " auto" , p = 2 )
484
- knn.fit(features_galleries)
485
- dists, ii_closest = knn.kneighbors(features_queries, n_neighbors = top_k, return_distance = True )
480
+ knn = NearestNeighbors(algorithm = " auto" , p = 2 ).fit(embeddings_query)
481
+ dists, ii_closest = knn.kneighbors(embeddings_gallery, n_neighbors = top_k, return_distance = True )
486
482
487
483
else :
488
- dist_mat = pairwise_dist(x1 = features_queries , x2 = features_galleries )
484
+ dist_mat = pairwise_dist(x1 = embeddings_query , x2 = embeddings_gallery, p = 2 )
489
485
dists, ii_closest = torch.topk(dist_mat, dim = 1 , k = top_k, largest = False )
490
486
491
487
print (f " Top { top_k} items closest to queries are: \n { ii_closest} " )
0 commit comments