diff --git a/README.md b/README.md index c20a2aa0e..60810a3fe 100644 --- a/README.md +++ b/README.md @@ -301,13 +301,12 @@ from oml.models import ViTExtractor from oml.samplers.balance import BalanceSampler from oml.utils.download_mock_dataset import download_mock_dataset -dataset_root = "mock_dataset/" -df_train, _ = download_mock_dataset(dataset_root) +df_train, _ = download_mock_dataset(global_paths=True) extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train() optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6) -train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root) +train_dataset = DatasetWithLabels(df_train) criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner(), need_logs=True) sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=2) train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=sampler) @@ -342,12 +341,11 @@ from oml.metrics.embeddings import EmbeddingMetrics from oml.models import ViTExtractor from oml.utils.download_mock_dataset import download_mock_dataset -dataset_root = "mock_dataset/" -_, df_val = download_mock_dataset(dataset_root) +_, df_val = download_mock_dataset(global_paths=True) extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).eval() -val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root) +val_dataset = DatasetQueryGallery(df_val) val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4) calculator = EmbeddingMetrics(extra_keys=("paths",)) @@ -401,21 +399,20 @@ from oml.lightning.pipelines.logging import ( WandBPipelineLogger, ) -dataset_root = "mock_dataset/" -df_train, df_val = download_mock_dataset(dataset_root) +df_train, df_val = download_mock_dataset(global_paths=True) # model extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False) # train optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6) -train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root) +train_dataset = DatasetWithLabels(df_train) criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner()) batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=3) train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=batch_sampler) # val -val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root) +val_dataset = DatasetQueryGallery(df_val) val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4) metric_callback = MetricValCallback(metric=EmbeddingMetrics(extra_keys=[train_dataset.paths_key,]), log_images=True) @@ -455,24 +452,24 @@ trainer.fit(pl_model, train_dataloaders=train_loader, val_dataloaders=val_loader ```python import torch -from oml.const import MOCK_DATASET_PATH -from oml.inference.flat import inference_on_images +from oml.datasets import ImageQueryGalleryDataset +from oml.inference import inference from oml.models import ViTExtractor from oml.registry.transforms import get_transforms_for_pretrained from oml.utils.download_mock_dataset import download_mock_dataset from oml.utils.misc_torch import pairwise_dist -_, df_val = download_mock_dataset(MOCK_DATASET_PATH) -df_val["path"] = df_val["path"].apply(lambda x: MOCK_DATASET_PATH / x) -queries = df_val[df_val["is_query"]]["path"].tolist() -galleries = df_val[df_val["is_gallery"]]["path"].tolist() +_, df_test = download_mock_dataset(global_paths=True) +del df_test["label"] # we don't need gt labels for doing predictions extractor = 
ViTExtractor.from_pretrained("vits16_dino") transform, _ = get_transforms_for_pretrained("vits16_dino") -args = {"num_workers": 0, "batch_size": 8} -features_queries = inference_on_images(extractor, paths=queries, transform=transform, **args) -features_galleries = inference_on_images(extractor, paths=galleries, transform=transform, **args) +dataset = ImageQueryGalleryDataset(df_test, transform=transform) + +embeddings = inference(extractor, dataset, batch_size=4, num_workers=0) +embeddings_query = embeddings[dataset.get_query_ids()] +embeddings_gallery = embeddings[dataset.get_gallery_ids()] # Now we can explicitly build pairwise matrix of distances or save you RAM via using kNN use_knn = False @@ -480,12 +477,11 @@ top_k = 3 if use_knn: from sklearn.neighbors import NearestNeighbors - knn = NearestNeighbors(algorithm="auto", p=2) - knn.fit(features_galleries) - dists, ii_closest = knn.kneighbors(features_queries, n_neighbors=top_k, return_distance=True) + knn = NearestNeighbors(algorithm="auto", p=2).fit(embeddings_gallery) + dists, ii_closest = knn.kneighbors(embeddings_query, n_neighbors=top_k, return_distance=True) else: - dist_mat = pairwise_dist(x1=features_queries, x2=features_galleries) + dist_mat = pairwise_dist(x1=embeddings_query, x2=embeddings_gallery, p=2) dists, ii_closest = torch.topk(dist_mat, dim=1, k=top_k, largest=False) print(f"Top {top_k} items closest to queries are:\n {ii_closest}") diff --git a/docs/readme/examples_source/extractor/retrieval_usage.md b/docs/readme/examples_source/extractor/retrieval_usage.md index 57954b738..a514bd69e 100644 --- a/docs/readme/examples_source/extractor/retrieval_usage.md +++ b/docs/readme/examples_source/extractor/retrieval_usage.md @@ -6,24 +6,24 @@ ```python import torch -from oml.const import MOCK_DATASET_PATH -from oml.inference.flat import inference_on_images +from oml.datasets import ImageQueryGalleryDataset +from oml.inference import inference from oml.models import ViTExtractor from oml.registry.transforms import get_transforms_for_pretrained from oml.utils.download_mock_dataset import download_mock_dataset from oml.utils.misc_torch import pairwise_dist -_, df_val = download_mock_dataset(MOCK_DATASET_PATH) -df_val["path"] = df_val["path"].apply(lambda x: MOCK_DATASET_PATH / x) -queries = df_val[df_val["is_query"]]["path"].tolist() -galleries = df_val[df_val["is_gallery"]]["path"].tolist() +_, df_test = download_mock_dataset(global_paths=True) +del df_test["label"] # we don't need gt labels for doing predictions extractor = ViTExtractor.from_pretrained("vits16_dino") transform, _ = get_transforms_for_pretrained("vits16_dino") -args = {"num_workers": 0, "batch_size": 8} -features_queries = inference_on_images(extractor, paths=queries, transform=transform, **args) -features_galleries = inference_on_images(extractor, paths=galleries, transform=transform, **args) +dataset = ImageQueryGalleryDataset(df_test, transform=transform) + +embeddings = inference(extractor, dataset, batch_size=4, num_workers=0) +embeddings_query = embeddings[dataset.get_query_ids()] +embeddings_gallery = embeddings[dataset.get_gallery_ids()] # Now we can explicitly build pairwise matrix of distances or save you RAM via using kNN use_knn = False @@ -31,12 +31,11 @@ top_k = 3 if use_knn: from sklearn.neighbors import NearestNeighbors - knn = NearestNeighbors(algorithm="auto", p=2) - knn.fit(features_galleries) - dists, ii_closest = knn.kneighbors(features_queries, n_neighbors=top_k, return_distance=True) + knn = NearestNeighbors(algorithm="auto", 
p=2).fit(embeddings_gallery) + dists, ii_closest = knn.kneighbors(embeddings_query, n_neighbors=top_k, return_distance=True) else: - dist_mat = pairwise_dist(x1=features_queries, x2=features_galleries) + dist_mat = pairwise_dist(x1=embeddings_query, x2=embeddings_gallery, p=2) dists, ii_closest = torch.topk(dist_mat, dim=1, k=top_k, largest=False) print(f"Top {top_k} items closest to queries are:\n {ii_closest}") diff --git a/docs/readme/examples_source/extractor/train.md b/docs/readme/examples_source/extractor/train.md index 50145ff5a..01110eeca 100644 --- a/docs/readme/examples_source/extractor/train.md +++ b/docs/readme/examples_source/extractor/train.md @@ -14,13 +14,12 @@ from oml.models import ViTExtractor from oml.samplers.balance import BalanceSampler from oml.utils.download_mock_dataset import download_mock_dataset -dataset_root = "mock_dataset/" -df_train, _ = download_mock_dataset(dataset_root) +df_train, _ = download_mock_dataset(global_paths=True) extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train() optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6) -train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root) +train_dataset = DatasetWithLabels(df_train) criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner(), need_logs=True) sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=2) train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=sampler) diff --git a/docs/readme/examples_source/extractor/train_2loaders_val.md b/docs/readme/examples_source/extractor/train_2loaders_val.md index eb1676da6..c277b133d 100644 --- a/docs/readme/examples_source/extractor/train_2loaders_val.md +++ b/docs/readme/examples_source/extractor/train_2loaders_val.md @@ -15,21 +15,18 @@ from oml.models import ViTExtractor from oml.transforms.images.torchvision import get_normalisation_resize_torch from oml.utils.download_mock_dataset import download_mock_dataset -dataset_root = "mock_dataset/" -_, df_val = download_mock_dataset(dataset_root) +_, df_val = download_mock_dataset(global_paths=True) extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False) # 1st validation dataset (big images) -val_dataset_1 = DatasetQueryGallery(df_val, dataset_root=dataset_root, - transform=get_normalisation_resize_torch(im_size=224)) +val_dataset_1 = DatasetQueryGallery(df_val, transform=get_normalisation_resize_torch(im_size=224)) val_loader_1 = torch.utils.data.DataLoader(val_dataset_1, batch_size=4) metric_callback_1 = MetricValCallback(metric=EmbeddingMetrics(extra_keys=[val_dataset_1.paths_key,]), log_images=True, loader_idx=0) # 2nd validation dataset (small images) -val_dataset_2 = DatasetQueryGallery(df_val, dataset_root=dataset_root, - transform=get_normalisation_resize_torch(im_size=48)) +val_dataset_2 = DatasetQueryGallery(df_val, transform=get_normalisation_resize_torch(im_size=48)) val_loader_2 = torch.utils.data.DataLoader(val_dataset_2, batch_size=4) metric_callback_2 = MetricValCallback(metric=EmbeddingMetrics(extra_keys=[val_dataset_2.paths_key,]), log_images=True, loader_idx=1) diff --git a/docs/readme/examples_source/extractor/train_val_pl.md b/docs/readme/examples_source/extractor/train_val_pl.md index fc9764a97..c182bb675 100644 --- a/docs/readme/examples_source/extractor/train_val_pl.md +++ b/docs/readme/examples_source/extractor/train_val_pl.md @@ -24,21 +24,20 @@ from oml.lightning.pipelines.logging import ( WandBPipelineLogger, ) -dataset_root = "mock_dataset/" 
-df_train, df_val = download_mock_dataset(dataset_root) +df_train, df_val = download_mock_dataset(global_paths=True) # model extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False) # train optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6) -train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root) +train_dataset = DatasetWithLabels(df_train) criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner()) batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=3) train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=batch_sampler) # val -val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root) +val_dataset = DatasetQueryGallery(df_val) val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4) metric_callback = MetricValCallback(metric=EmbeddingMetrics(extra_keys=[train_dataset.paths_key,]), log_images=True) diff --git a/docs/readme/examples_source/extractor/train_val_pl_ddp.md b/docs/readme/examples_source/extractor/train_val_pl_ddp.md index 1aea03452..dc9cfd726 100644 --- a/docs/readme/examples_source/extractor/train_val_pl_ddp.md +++ b/docs/readme/examples_source/extractor/train_val_pl_ddp.md @@ -19,21 +19,20 @@ from oml.samplers.balance import BalanceSampler from oml.utils.download_mock_dataset import download_mock_dataset from pytorch_lightning.strategies import DDPStrategy -dataset_root = "mock_dataset/" -df_train, df_val = download_mock_dataset(dataset_root) +df_train, df_val = download_mock_dataset(global_paths=True) # model extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False) # train optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6) -train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root) +train_dataset = DatasetWithLabels(df_train) criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner()) batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=3) train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=batch_sampler) # val -val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root) +val_dataset = DatasetQueryGallery(df_val) val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4) metric_callback = MetricValCallbackDDP(metric=EmbeddingMetricsDDP()) # DDP specific diff --git a/docs/readme/examples_source/extractor/train_with_pml.md b/docs/readme/examples_source/extractor/train_with_pml.md index 138cc07ee..b5f249c82 100644 --- a/docs/readme/examples_source/extractor/train_with_pml.md +++ b/docs/readme/examples_source/extractor/train_with_pml.md @@ -13,13 +13,12 @@ from oml.utils.download_mock_dataset import download_mock_dataset from pytorch_metric_learning import losses, distances, reducers, miners -dataset_root = "mock_dataset/" -df_train, _ = download_mock_dataset(dataset_root) +df_train, _ = download_mock_dataset(global_paths=True) extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train() optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6) -train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root) +train_dataset = DatasetWithLabels(df_train) # PML specific # criterion = losses.TripletMarginLoss(margin=0.2, triplets_per_anchor="all") diff --git a/docs/readme/examples_source/extractor/train_with_pml_advanced.md b/docs/readme/examples_source/extractor/train_with_pml_advanced.md index c0b245253..33a27d59b 100644 --- a/docs/readme/examples_source/extractor/train_with_pml_advanced.md +++ 
b/docs/readme/examples_source/extractor/train_with_pml_advanced.md @@ -13,13 +13,12 @@ from oml.utils.download_mock_dataset import download_mock_dataset from pytorch_metric_learning import losses, distances, reducers, miners -dataset_root = "mock_dataset/" -df_train, _ = download_mock_dataset(dataset_root) +df_train, _ = download_mock_dataset(global_paths=True) extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train() optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6) -train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root) +train_dataset = DatasetWithLabels(df_train) # PML specific distance = distances.LpDistance(p=2) diff --git a/docs/readme/examples_source/extractor/val.md b/docs/readme/examples_source/extractor/val.md index 39181f17d..3f71ebd25 100644 --- a/docs/readme/examples_source/extractor/val.md +++ b/docs/readme/examples_source/extractor/val.md @@ -12,12 +12,11 @@ from oml.metrics.embeddings import EmbeddingMetrics from oml.models import ViTExtractor from oml.utils.download_mock_dataset import download_mock_dataset -dataset_root = "mock_dataset/" -_, df_val = download_mock_dataset(dataset_root) +_, df_val = download_mock_dataset(global_paths=True) extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).eval() -val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root) +val_dataset = DatasetQueryGallery(df_val) val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4) calculator = EmbeddingMetrics(extra_keys=("paths",)) diff --git a/docs/readme/examples_source/extractor/val_with_sequence.md b/docs/readme/examples_source/extractor/val_with_sequence.md index 2b015ea40..a0cfd916a 100644 --- a/docs/readme/examples_source/extractor/val_with_sequence.md +++ b/docs/readme/examples_source/extractor/val_with_sequence.md @@ -42,12 +42,11 @@ from oml.metrics.embeddings import EmbeddingMetrics from oml.models import ViTExtractor from oml.utils.download_mock_dataset import download_mock_dataset -dataset_root = "mock_dataset/" -_, df_val = download_mock_dataset(dataset_root, df_name="df_with_sequence.csv") # <- sequence info is in the file +_, df_val = download_mock_dataset(global_paths=True, df_name="df_with_sequence.csv") # <- sequence info is in the file extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).eval() -val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root) +val_dataset = DatasetQueryGallery(df_val) val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4) calculator = EmbeddingMetrics(extra_keys=("paths",), sequence_key=val_dataset.sequence_key) diff --git a/docs/readme/examples_source/postprocessing/predict.md b/docs/readme/examples_source/postprocessing/predict.md index 22d6cc5f3..35fb69604 100644 --- a/docs/readme/examples_source/postprocessing/predict.md +++ b/docs/readme/examples_source/postprocessing/predict.md @@ -5,44 +5,40 @@ [comment]:postprocessor-pred-start ```python import torch -from torch.utils.data import DataLoader -from oml.const import PATHS_COLUMN -from oml.datasets.base import DatasetQueryGallery -from oml.inference.flat import inference_on_dataframe +from oml.datasets import ImageQueryGalleryDataset +from oml.inference import inference from oml.models import ConcatSiamese, ViTExtractor from oml.registry.transforms import get_transforms_for_pretrained -from oml.retrieval.postprocessors.pairwise import PairwiseImagesPostprocessor +from oml.retrieval.postprocessors.pairwise import PairwiseReranker from 
oml.utils.download_mock_dataset import download_mock_dataset from oml.utils.misc_torch import pairwise_dist -dataset_root = "mock_dataset/" -download_mock_dataset(dataset_root) +_, df_test = download_mock_dataset(global_paths=True) +del df_test["label"] # we don't need gt labels for doing predictions -# 1. Let's use feature extractor to get predictions extractor = ViTExtractor.from_pretrained("vits16_dino") transforms, _ = get_transforms_for_pretrained("vits16_dino") -_, emb_val, _, df_val = inference_on_dataframe(dataset_root, "df.csv", extractor, transforms=transforms) +dataset = ImageQueryGalleryDataset(df_test, transform=transforms) -is_query = df_val["is_query"].astype('bool').values -distances = pairwise_dist(x1=emb_val[is_query], x2=emb_val[~is_query]) +# 1. Let's get top 5 galleries closest to every query... +embeddings = inference(extractor, dataset, batch_size=4, num_workers=0) +embeddings_query = embeddings[dataset.get_query_ids()] +embeddings_gallery = embeddings[dataset.get_gallery_ids()] -print("\nOriginal predictions:\n", torch.topk(distances, dim=1, k=3, largest=False)[1]) +distances = pairwise_dist(x1=embeddings_query, x2=embeddings_gallery, p=2) +ii_closest = torch.topk(distances, dim=1, k=5, largest=False)[1] -# 2. Let's initialise a random pairwise postprocessor to perform re-ranking +# 2. ... and let's re-rank the first 3 of them siamese = ConcatSiamese(extractor=extractor, mlp_hidden_dims=[100]) # Note! Replace it with your trained postprocessor -postprocessor = PairwiseImagesPostprocessor(top_n=3, pairwise_model=siamese, transforms=transforms) - -dataset = DatasetQueryGallery(df_val, extra_data={"embeddings": emb_val}, transform=transforms) -loader = DataLoader(dataset, batch_size=4) - -query_paths = df_val[PATHS_COLUMN][is_query].values -gallery_paths = df_val[PATHS_COLUMN][~is_query].values -distances_upd = postprocessor.process(distances=distances, queries=query_paths, galleries=gallery_paths) - -print("\nPredictions after postprocessing:\n", torch.topk(distances_upd, dim=1, k=3, largest=False)[1]) +postprocessor = PairwiseReranker(top_n=3, pairwise_model=siamese, batch_size=4, num_workers=0) +distances_upd = postprocessor.process(distances, dataset=dataset) +ii_closest_upd = torch.topk(distances_upd, dim=1, k=5, largest=False)[1] +# You may see that the first 3 positions have changed, but the rest remain the same: +print("\nClosest galleries:\n", ii_closest) +print("\nClosest galleries after re-ranking:\n", ii_closest_upd) ``` [comment]:postprocessor-pred-end

diff --git a/docs/readme/examples_source/postprocessing/train_val.md b/docs/readme/examples_source/postprocessing/train_val.md index 345f30c8a..8384a6dd7 100644 --- a/docs/readme/examples_source/postprocessing/train_val.md +++ b/docs/readme/examples_source/postprocessing/train_val.md @@ -10,52 +10,52 @@ import torch from torch.nn import BCEWithLogitsLoss from torch.utils.data import DataLoader -from oml.datasets.base import DatasetWithLabels, DatasetQueryGallery -from oml.inference.flat import inference_on_dataframe +from oml.datasets import ImageLabeledDataset, ImageQueryGalleryLabeledDataset, ImageBaseDataset +from oml.inference import inference from oml.metrics.embeddings import EmbeddingMetrics from oml.miners.pairs import PairsMiner from oml.models import ConcatSiamese, ViTExtractor -from oml.retrieval.postprocessors.pairwise import PairwiseImagesPostprocessor +from oml.registry.transforms import get_transforms_for_pretrained +from oml.retrieval.postprocessors.pairwise import PairwiseReranker from oml.samplers.balance import BalanceSampler -from oml.transforms.images.torchvision import get_normalisation_resize_torch from oml.utils.download_mock_dataset import download_mock_dataset +from oml.transforms.images.torchvision import get_augs_torch -# Let's start with saving embeddings of a pretrained extractor for which we want to build a postprocessor -dataset_root = "mock_dataset/" -download_mock_dataset(dataset_root) +# In this example we will train a pairwise model as a re-ranker for ViT +extractor = ViTExtractor.from_pretrained("vits16_dino") +transforms, _ = get_transforms_for_pretrained("vits16_dino") +df_train, df_val = download_mock_dataset(global_paths=True) -extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False) -transform = get_normalisation_resize_torch(im_size=64) +# SAVE VIT EMBEDDINGS +# - training ones are needed for hard negative sampling when training the pairwise model +# - validation ones are needed to construct the original prediction (which we will re-rank) +embeddings_train = inference(extractor, ImageBaseDataset(df_train["path"].tolist(), transform=transforms), batch_size=4, num_workers=0) +embeddings_valid = inference(extractor, ImageBaseDataset(df_val["path"].tolist(), transform=transforms), batch_size=4, num_workers=0) -embeddings_train, embeddings_val, df_train, df_val = \ - inference_on_dataframe(dataset_root, "df.csv", extractor=extractor, transforms=transform) - -# We are building Siamese model on top of existing weights and train it to recognize positive/negative pairs -siamese = ConcatSiamese(extractor=extractor, mlp_hidden_dims=[100]) -optimizer = torch.optim.SGD(siamese.parameters(), lr=1e-6) +# TRAIN PAIRWISE MODEL +train_dataset = ImageLabeledDataset(df_train, transform=get_augs_torch(224), extra_data={"embeddings": embeddings_train}) +pairwise_model = ConcatSiamese(extractor=extractor, mlp_hidden_dims=[100]) +optimizer = torch.optim.SGD(pairwise_model.parameters(), lr=1e-6) miner = PairsMiner(hard_mining=True) criterion = BCEWithLogitsLoss() -train_dataset = DatasetWithLabels(df=df_train, transform=transform, extra_data={"embeddings": embeddings_train}) -batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=2) -train_loader = DataLoader(train_dataset, batch_sampler=batch_sampler) +train_loader = DataLoader(train_dataset, batch_sampler=BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=2)) for batch in train_loader: - # We sample pairs on which the original model struggled most + # We 
sample positive and negative pairs on which the original model struggled most ids1, ids2, is_negative_pair = miner.sample(features=batch["embeddings"], labels=batch["labels"]) - probs = siamese(x1=batch["input_tensors"][ids1], x2=batch["input_tensors"][ids2]) + probs = pairwise_model(x1=batch["input_tensors"][ids1], x2=batch["input_tensors"][ids2]) loss = criterion(probs, is_negative_pair.float()) - loss.backward() optimizer.step() optimizer.zero_grad() -# Siamese re-ranks top-n retrieval outputs of the original model performing inference on pairs (query, output_i) -val_dataset = DatasetQueryGallery(df=df_val, extra_data={"embeddings": embeddings_val}, transform=transform) +# VALIDATE RE-RANKING MODEL +val_dataset = ImageQueryGalleryLabeledDataset(df=df_val, transform=transforms, extra_data={"embeddings": embeddings_valid}) valid_loader = DataLoader(val_dataset, batch_size=4, shuffle=False) -postprocessor = PairwiseImagesPostprocessor(top_n=3, pairwise_model=siamese, transforms=transform) -calculator = EmbeddingMetrics(postprocessor=postprocessor) +postprocessor = PairwiseReranker(top_n=3, pairwise_model=pairwise_model, num_workers=0, batch_size=4) +calculator = EmbeddingMetrics(dataset=val_dataset, postprocessor=postprocessor) calculator.setup(num_samples=len(val_dataset)) for batch in valid_loader: diff --git a/docs/source/contents/datasets.rst b/docs/source/contents/datasets.rst index 9af621a1c..34cb41e59 100644 --- a/docs/source/contents/datasets.rst +++ b/docs/source/contents/datasets.rst @@ -53,18 +53,9 @@ ImageQueryGalleryDataset .. automethod:: get_gallery_ids .. automethod:: visualize -EmbeddingPairsDataset +PairDataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.datasets.pairs.EmbeddingPairsDataset - :undoc-members: - :show-inheritance: - - .. automethod:: __init__ - .. automethod:: __getitem__ - -ImagePairsDataset -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.datasets.pairs.ImagePairsDataset +.. autoclass:: oml.datasets.pairs.PairDataset :undoc-members: :show-inheritance: diff --git a/docs/source/contents/interfaces.rst b/docs/source/contents/interfaces.rst index 7e7224491..81acf03b7 100644 --- a/docs/source/contents/interfaces.rst +++ b/docs/source/contents/interfaces.rst @@ -52,12 +52,22 @@ ITripletLossWithMiner .. automethod:: forward +IIndexedDataset +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: oml.interfaces.datasets.IIndexedDataset + :undoc-members: + :show-inheritance: + + .. automethod:: __getitem__ + IBaseDataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: oml.interfaces.datasets.IBaseDataset :undoc-members: :show-inheritance: + .. automethod:: __getitem__ + ILabeledDataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: oml.interfaces.datasets.ILabeledDataset @@ -86,9 +96,9 @@ IQueryGalleryLabeledDataset .. automethod:: get_gallery_ids .. automethod:: get_labels -IPairsDataset +IPairDataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.interfaces.datasets.IPairsDataset +.. autoclass:: oml.interfaces.datasets.IPairDataset :undoc-members: :show-inheritance: @@ -138,3 +148,11 @@ IPipelineLogger .. automethod:: log_figure .. 
automethod:: log_pipeline_info + +IRetrievalPostprocessor +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: oml.interfaces.retrieval.IRetrievalPostprocessor + :undoc-members: + :show-inheritance: + + .. automethod:: process diff --git a/docs/source/contents/postprocessing.rst b/docs/source/contents/postprocessing.rst index a2e0be7b4..ea1f41227 100644 --- a/docs/source/contents/postprocessing.rst +++ b/docs/source/contents/postprocessing.rst @@ -7,37 +7,13 @@ Retrieval Post-Processing .. contents:: :local: -IDistancesPostprocessor +PairwiseReranker ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.interfaces.retrieval.IDistancesPostprocessor - :undoc-members: - :show-inheritance: - - .. automethod:: process - -PairwisePostprocessor -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.retrieval.postprocessors.pairwise.PairwisePostprocessor - :undoc-members: - :show-inheritance: - - .. automethod:: process - .. automethod:: inference - -PairwiseEmbeddingsPostprocessor -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.retrieval.postprocessors.pairwise.PairwiseEmbeddingsPostprocessor +.. autoclass:: oml.retrieval.postprocessors.pairwise.PairwiseReranker :undoc-members: :show-inheritance: .. automethod:: __init__ - .. automethod:: inference - -PairwiseImagesPostprocessor -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: oml.retrieval.postprocessors.pairwise.PairwiseImagesPostprocessor - :undoc-members: - :show-inheritance: + .. automethod:: process + .. automethod:: process_neigh - .. automethod:: __init__ - .. 
automethod:: inference diff --git a/oml/configs/postprocessor/pairwise_embeddings.yaml b/oml/configs/postprocessor/pairwise_embeddings.yaml deleted file mode 100644 index 4399dc9e5..000000000 --- a/oml/configs/postprocessor/pairwise_embeddings.yaml +++ /dev/null @@ -1,11 +0,0 @@ -name: pairwise_embeddings -args: - top_n: 5 - pairwise_model: - name: linear_siamese - args: - feat_dim: 16 - identity_init: True - num_workers: 0 - batch_size: 4 - verbose: False diff --git a/oml/configs/postprocessor/pairwise_images.yaml b/oml/configs/postprocessor/pairwise_reranker.yaml similarity index 77% rename from oml/configs/postprocessor/pairwise_images.yaml rename to oml/configs/postprocessor/pairwise_reranker.yaml index 5cda25a92..50eeaf19d 100644 --- a/oml/configs/postprocessor/pairwise_images.yaml +++ b/oml/configs/postprocessor/pairwise_reranker.yaml @@ -1,4 +1,4 @@ -name: pairwise_images +name: pairwise_reranker args: top_n: 3 pairwise_model: @@ -12,10 +12,6 @@ args: remove_fc: True normalise_features: False weights: resnet50_moco_v2 - transforms: - name: norm_resize_torch - args: - im_size: 224 num_workers: 0 batch_size: 4 verbose: False diff --git a/oml/const.py b/oml/const.py index 907b0ca0f..8301a47f7 100644 --- a/oml/const.py +++ b/oml/const.py @@ -102,8 +102,8 @@ def get_cache_folder() -> Path: INDEX_KEY = "idx" SEQUENCE_KEY = "sequence" -PAIR_1ST_KEY = "input_tensors_1" -PAIR_2ND_KEY = "input_tensors_2" +INPUT_TENSORS_KEY_1 = "input_tensors_1" +INPUT_TENSORS_KEY_2 = "input_tensors_2" IMAGE_EXTENSIONS = ["jpg", "jpeg", "JPG", "JPEG", "png"] diff --git a/oml/datasets/__init__.py b/oml/datasets/__init__.py index e69de29bb..682b6a93b 100644 --- a/oml/datasets/__init__.py +++ b/oml/datasets/__init__.py @@ -0,0 +1,7 @@ +from oml.datasets.images import ( + ImageBaseDataset, + ImageLabeledDataset, + ImageQueryGalleryDataset, + ImageQueryGalleryLabeledDataset, +) +from oml.datasets.pairs import PairDataset diff --git a/oml/datasets/images.py b/oml/datasets/images.py index 28b16fe12..4d1e2f2ec 100644 --- a/oml/datasets/images.py +++ b/oml/datasets/images.py @@ -359,13 +359,12 @@ def get_query_ids(self) -> LongTensor: def get_gallery_ids(self) -> LongTensor: return BoolTensor(self.df[IS_GALLERY_COLUMN]).nonzero().squeeze() - def __getitem__(self, idx: int) -> Dict[str, Any]: - data = super().__getitem__(idx) - data[self.labels_key] = self.df.iloc[idx][LABELS_COLUMN] + def __getitem__(self, item: int) -> Dict[str, Any]: + data = super().__getitem__(item) # todo 522: remove - data[self.is_query_key] = bool(self.df[IS_QUERY_COLUMN][idx]) - data[self.is_gallery_key] = bool(self.df[IS_GALLERY_COLUMN][idx]) + data[self.is_query_key] = bool(self.df[IS_QUERY_COLUMN][item]) + data[self.is_gallery_key] = bool(self.df[IS_GALLERY_COLUMN][item]) return data @@ -423,6 +422,13 @@ def __init__( is_gallery_key=is_gallery_key, ) + self.input_tensors_key = self.__dataset.input_tensors_key + self.index_key = self.__dataset.index_key + + # todo 522: remove + self.is_query_key = self.__dataset.is_query_key + self.is_gallery_key = self.__dataset.is_gallery_key + def __getitem__(self, item: int) -> Dict[str, Any]: batch = self.__dataset[item] del batch[self.__dataset.labels_key] @@ -455,7 +461,6 @@ def get_retrieval_images_datasets( check_retrieval_dataframe_format(df, dataset_root=dataset_root, verbose=verbose) - # todo 522: why do we need it? 
# first half will consist of "train" split, second one of "val" # so labels in train will be from 0 to N-1 and labels in test will be from N to K mapper = {l: i for i, l in enumerate(df.sort_values(by=[SPLIT_COLUMN])[LABELS_COLUMN].unique())} diff --git a/oml/datasets/pairs.py b/oml/datasets/pairs.py index 0c7b44d10..2aca4c530 100644 --- a/oml/datasets/pairs.py +++ b/oml/datasets/pairs.py @@ -1,115 +1,43 @@ -from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Tuple, Union from torch import Tensor -from oml.const import INDEX_KEY, PAIR_1ST_KEY, PAIR_2ND_KEY, TBBoxes -from oml.datasets.images import ImageBaseDataset -from oml.interfaces.datasets import IPairsDataset -from oml.transforms.images.torchvision import get_normalisation_torch -from oml.transforms.images.utils import TTransforms -from oml.utils.images.images import TImReader, imread_pillow +from oml.const import INDEX_KEY, INPUT_TENSORS_KEY_1, INPUT_TENSORS_KEY_2 +from oml.interfaces.datasets import IBaseDataset, IPairDataset -# todo 522: make one modality agnostic instead of these two - -class EmbeddingPairsDataset(IPairsDataset): +class PairDataset(IPairDataset): """ - Dataset to iterate over pairs of embeddings. + Dataset to iterate over pairs of items of any modality. """ def __init__( self, - embeddings1: Tensor, - embeddings2: Tensor, - pair_1st_key: str = PAIR_1ST_KEY, - pair_2nd_key: str = PAIR_2ND_KEY, + base_dataset: IBaseDataset, + pair_ids: List[Tuple[int, int]], + input_tensors_key_1: str = INPUT_TENSORS_KEY_1, + input_tensors_key_2: str = INPUT_TENSORS_KEY_2, index_key: str = INDEX_KEY, ): - """ - - Args: - embeddings1: The first input embeddings - embeddings2: The second input embeddings - pair_1st_key: Key to put ``embeddings1`` into the batches - pair_2nd_key: Key to put ``embeddings2`` into the batches - index_key: Key to put samples' ids into the batches - - """ - assert embeddings1.shape == embeddings2.shape - assert embeddings1.ndim >= 2 - - self.pair_1st_key = pair_1st_key - self.pair_2nd_key = pair_2nd_key - self.index_key = index_key - - self.embeddings1 = embeddings1 - self.embeddings2 = embeddings2 - - def __getitem__(self, idx: int) -> Dict[str, Tensor]: - return {self.pair_1st_key: self.embeddings1[idx], self.pair_2nd_key: self.embeddings2[idx], self.index_key: idx} - - def __len__(self) -> int: - return len(self.embeddings1) - - -class ImagePairsDataset(IPairsDataset): - """ - Dataset to iterate over pairs of images. - - """ - - def __init__( - self, - paths1: List[Path], - paths2: List[Path], - bboxes1: Optional[TBBoxes] = None, - bboxes2: Optional[TBBoxes] = None, - transform: Optional[TTransforms] = None, - f_imread: TImReader = imread_pillow, - pair_1st_key: str = PAIR_1ST_KEY, - pair_2nd_key: str = PAIR_2ND_KEY, - index_key: str = INDEX_KEY, - cache_size: Optional[int] = 0, - ): - """ - Args: - paths1: Paths to the 1st input images - paths2: Paths to the 2nd input images - bboxes1: Should be either ``None`` or a sequence of bboxes. - If an image has ``N`` boxes, duplicate its - path ``N`` times and provide bounding box for each of them. - If you want to get an embedding for the whole image, set bbox to ``None`` for - this particular image path. The format is ``x1, y1, x2, y2``. - bboxes2: The same as ``bboxes2``, but for the second inputs. 
- transform: Augmentations for the images, set ``None`` to perform only normalisation and casting to tensor - f_imread: Function to read the images - pair_1st_key: Key to put the 1st images into the batches - pair_2nd_key: Key to put the 2nd images into the batches - index_key: Key to put samples' ids into the batches - cache_size: Size of the dataset's cache - - """ - assert len(paths1) == len(paths2) - - if transform is None: - transform = get_normalisation_torch() - - cache_size = cache_size // 2 if cache_size else None - dataset_args = {"transform": transform, "f_imread": f_imread, "cache_size": cache_size} - self.dataset1 = ImageBaseDataset(paths=paths1, bboxes=bboxes1, **dataset_args) - self.dataset2 = ImageBaseDataset(paths=paths2, bboxes=bboxes2, **dataset_args) - - self.pair_1st_key = pair_1st_key - self.pair_2nd_key = pair_2nd_key - self.index_key = index_key - - def __getitem__(self, idx: int) -> Dict[str, Union[int, Dict[str, Any]]]: - return {self.pair_1st_key: self.dataset1[idx], self.pair_2nd_key: self.dataset2[idx], self.index_key: idx} + self.base_dataset = base_dataset + self.pair_ids = pair_ids + + self.input_tensors_key_1 = input_tensors_key_1 + self.input_tensors_key_2 = input_tensors_key_2 + self.index_key: str = index_key + + def __getitem__(self, item: int) -> Dict[str, Union[Tensor, int]]: + i1, i2 = self.pair_ids[item] + key = self.base_dataset.input_tensors_key + return { + self.input_tensors_key_1: self.base_dataset[i1][key], + self.input_tensors_key_2: self.base_dataset[i2][key], + self.index_key: item, + } def __len__(self) -> int: - return len(self.dataset1) + return len(self.pair_ids) -__all__ = ["EmbeddingPairsDataset", "ImagePairsDataset"] +__all__ = ["PairDataset"] diff --git a/oml/inference/__init__.py b/oml/inference/__init__.py index e69de29bb..753875058 100644 --- a/oml/inference/__init__.py +++ b/oml/inference/__init__.py @@ -0,0 +1 @@ +from oml.inference.abstract import inference, inference_cached, pairwise_inference diff --git a/oml/inference/abstract.py b/oml/inference/abstract.py index be106f025..ac572c395 100644 --- a/oml/inference/abstract.py +++ b/oml/inference/abstract.py @@ -1,12 +1,16 @@ -from typing import Any, Callable, Dict +from pathlib import Path +from typing import Any, Callable, Dict, List, Tuple import torch -from torch import Tensor, nn -from torch.utils.data import DataLoader, Dataset +from torch import FloatTensor, Tensor, nn +from torch.utils.data import DataLoader from tqdm.auto import tqdm +from oml.datasets import PairDataset from oml.ddp.patching import patch_dataloader_to_ddp from oml.ddp.utils import get_world_size_safe, is_ddp, sync_dicts_ddp +from oml.interfaces.datasets import IBaseDataset, IIndexedDataset +from oml.interfaces.models import IPairwiseModel from oml.utils.misc_torch import get_device, temporary_setting_model_mode, unique_by_ids @@ -14,15 +18,13 @@ def _inference( model: nn.Module, apply_model: Callable[[nn.Module, Dict[str, Any]], Tensor], - dataset: Dataset, # type: ignore + dataset: IIndexedDataset, num_workers: int, batch_size: int, verbose: bool, use_fp16: bool, accumulate_on_cpu: bool = True, ) -> Tensor: - assert hasattr(dataset, "index_key"), "We expect that your dataset returns samples ids in __getitem__ method" - loader = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False) if is_ddp(): @@ -57,4 +59,97 @@ def _inference( return outputs -__all__ = ["_inference"] +@torch.no_grad() +def inference( + model: nn.Module, + dataset: IBaseDataset, + batch_size: int, 
+ num_workers: int = 0, + verbose: bool = False, + use_fp16: bool = False, + accumulate_on_cpu: bool = True, +) -> Tensor: + device = get_device(model) + + def apply(model_: nn.Module, batch_: Dict[str, Any]) -> FloatTensor: + return model_(batch_[dataset.input_tensors_key].to(device)) + + return _inference( + model=model, + apply_model=apply, + dataset=dataset, + num_workers=num_workers, + batch_size=batch_size, + verbose=verbose, + use_fp16=use_fp16, + accumulate_on_cpu=accumulate_on_cpu, + ) + + +def inference_cached( + model: nn.Module, + dataset: IBaseDataset, + batch_size: int, + num_workers: int = 0, + use_fp16: bool = False, + verbose: bool = True, + accumulate_on_cpu: bool = True, + cache_path: str = "inference_cache.pth", +) -> Tensor: + if Path(cache_path).is_file(): + outputs = torch.load(cache_path, map_location="cpu") + print(f"Model outputs have been loaded from {cache_path}.") + else: + outputs = inference( + model=model, + dataset=dataset, + num_workers=num_workers, + batch_size=batch_size, + use_fp16=use_fp16, + verbose=verbose, + accumulate_on_cpu=accumulate_on_cpu, + ) + + torch.save(outputs, cache_path) + print(f"Model outputs have been saved to {cache_path}.") + + return outputs + + +def pairwise_inference( + model: IPairwiseModel, + base_dataset: IBaseDataset, + pair_ids: List[Tuple[int, int]], + num_workers: int, + batch_size: int, + verbose: bool = True, + use_fp16: bool = False, + accumulate_on_cpu: bool = True, +) -> Tensor: + device = get_device(model) + + dataset = PairDataset(base_dataset=base_dataset, pair_ids=pair_ids) + + def _apply( + model_: IPairwiseModel, + batch_: Dict[str, Any], + ) -> Tensor: + pair1 = batch_[dataset.input_tensors_key_1].to(device) + pair2 = batch_[dataset.input_tensors_key_2].to(device) + return model_.predict(pair1, pair2) + + output = _inference( + model=model, + apply_model=_apply, + dataset=dataset, + num_workers=num_workers, + batch_size=batch_size, + verbose=verbose, + use_fp16=use_fp16, + accumulate_on_cpu=accumulate_on_cpu, + ) + + return output + + +__all__ = ["inference", "pairwise_inference", "inference_cached"] diff --git a/oml/inference/flat.py b/oml/inference/flat.py deleted file mode 100644 index e8e4c063f..000000000 --- a/oml/inference/flat.py +++ /dev/null @@ -1,100 +0,0 @@ -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union - -import pandas as pd -import torch -from pandas import DataFrame -from torch import Tensor, nn - -from oml.const import PATHS_COLUMN, SPLIT_COLUMN -from oml.datasets.images import ImageBaseDataset -from oml.inference.abstract import _inference -from oml.interfaces.models import IExtractor -from oml.transforms.images.utils import TTransforms -from oml.utils.dataframe_format import check_retrieval_dataframe_format -from oml.utils.images.images import TImReader -from oml.utils.misc_torch import get_device - - -@torch.no_grad() -def inference_on_images( - model: nn.Module, - paths: List[Path], - transform: TTransforms, - num_workers: int, - batch_size: int, - verbose: bool = False, - f_imread: Optional[TImReader] = None, - use_fp16: bool = False, - accumulate_on_cpu: bool = True, -) -> Tensor: - dataset = ImageBaseDataset(paths=paths, bboxes=None, transform=transform, f_imread=f_imread, cache_size=0) - device = get_device(model) - - def _apply(model_: nn.Module, batch_: Dict[str, Any]) -> Tensor: - return model_(batch_[dataset.input_tensors_key].to(device)) - - outputs = _inference( - model=model, - apply_model=_apply, - dataset=dataset, - num_workers=num_workers, 
- batch_size=batch_size, - verbose=verbose, - use_fp16=use_fp16, - accumulate_on_cpu=accumulate_on_cpu, - ) - - return outputs - - -def inference_on_dataframe( - dataset_root: Union[Path, str], - dataframe_name: str, - extractor: IExtractor, - transforms: TTransforms, - output_cache_path: Optional[Union[str, Path]] = None, - num_workers: int = 0, - batch_size: int = 128, - use_fp16: bool = False, -) -> Tuple[Tensor, Tensor, DataFrame, DataFrame]: - df = pd.read_csv(Path(dataset_root) / dataframe_name) - - # it has now affect if paths are global already - df[PATHS_COLUMN] = df[PATHS_COLUMN].apply(lambda x: Path(dataset_root) / x) - - check_retrieval_dataframe_format(df) - - if (output_cache_path is not None) and Path(output_cache_path).is_file(): - embeddings = torch.load(output_cache_path, map_location="cpu") - print("Embeddings have been loaded from the disk.") - else: - embeddings = inference_on_images( - model=extractor, - paths=df[PATHS_COLUMN], - transform=transforms, - num_workers=num_workers, - batch_size=batch_size, - verbose=True, - use_fp16=use_fp16, - accumulate_on_cpu=True, - ) - if output_cache_path is not None: - torch.save(embeddings, output_cache_path) - print("Embeddings have been saved to the disk.") - - train_mask = df[SPLIT_COLUMN] == "train" - - emb_train = embeddings[train_mask] - emb_val = embeddings[~train_mask] - - df_train = df[train_mask] - df_train.reset_index(inplace=True, drop=True) - - df_val = df[~train_mask] - df_val.reset_index(inplace=True, drop=True) - - return emb_train, emb_val, df_train, df_val - - -__all__ = ["inference_on_images", "inference_on_dataframe"] diff --git a/oml/inference/pairs.py b/oml/inference/pairs.py deleted file mode 100644 index 42317c97c..000000000 --- a/oml/inference/pairs.py +++ /dev/null @@ -1,97 +0,0 @@ -from pathlib import Path -from typing import Any, Dict, List, Optional - -from torch import Tensor - -from oml.datasets.pairs import EmbeddingPairsDataset, ImagePairsDataset -from oml.inference.abstract import _inference -from oml.interfaces.models import IPairwiseModel -from oml.transforms.images.utils import TTransforms -from oml.utils.images.images import TImReader -from oml.utils.misc_torch import get_device - - -def pairwise_inference_on_images( - model: IPairwiseModel, - paths1: List[Path], - paths2: List[Path], - transform: TTransforms, - num_workers: int, - batch_size: int, - verbose: bool = True, - f_imread: Optional[TImReader] = None, - use_fp16: bool = False, - accumulate_on_cpu: bool = True, -) -> Tensor: - device = get_device(model) - - dataset = ImagePairsDataset( - paths1=paths1, - paths2=paths2, - transform=transform, - f_imread=f_imread, - cache_size=0, - ) - - def _apply( - model_: IPairwiseModel, - batch_: Dict[str, Any], - ) -> Tensor: - pair1 = batch_[dataset.pair_1st_key][dataset.dataset1.input_tensors_key].to(device) - pair2 = batch_[dataset.pair_2nd_key][dataset.dataset2.input_tensors_key].to(device) - return model_.predict(pair1, pair2) - - output = _inference( - model=model, - apply_model=_apply, - dataset=dataset, - num_workers=num_workers, - batch_size=batch_size, - verbose=verbose, - use_fp16=use_fp16, - accumulate_on_cpu=accumulate_on_cpu, - ) - - return output - - -def pairwise_inference_on_embeddings( - model: IPairwiseModel, - embeddings1: Tensor, - embeddings2: Tensor, - num_workers: int, - batch_size: int, - verbose: bool = False, - use_fp16: bool = False, - accumulate_on_cpu: bool = True, -) -> Tensor: - device = get_device(model) - - dataset = EmbeddingPairsDataset(embeddings1=embeddings1, 
embeddings2=embeddings2) - - def _apply( - model_: IPairwiseModel, - batch_: Dict[str, Any], - ) -> Tensor: - pair1 = batch_[dataset.pair_1st_key].to(device) - pair2 = batch_[dataset.pair_2nd_key].to(device) - return model_.predict(pair1, pair2) - - output = _inference( - model=model, - apply_model=_apply, - dataset=dataset, - num_workers=num_workers, - batch_size=batch_size, - verbose=verbose, - use_fp16=use_fp16, - accumulate_on_cpu=accumulate_on_cpu, - ) - - return output - - -__all__ = [ - "pairwise_inference_on_images", - "pairwise_inference_on_embeddings", -] diff --git a/oml/interfaces/datasets.py b/oml/interfaces/datasets.py index e5c3009e8..e5a967443 100644 --- a/oml/interfaces/datasets.py +++ b/oml/interfaces/datasets.py @@ -5,12 +5,28 @@ from torch import LongTensor from torch.utils.data import Dataset -from oml.const import INDEX_KEY, LABELS_KEY, PAIR_1ST_KEY, PAIR_2ND_KEY, TColor +from oml.const import INPUT_TENSORS_KEY_1, INPUT_TENSORS_KEY_2, LABELS_KEY, TColor -class IBaseDataset(Dataset): - input_tensors_key: str +class IIndexedDataset(Dataset, ABC): index_key: str + + def __getitem__(self, item: int) -> Dict[str, Any]: + """ + + Args: + item: Idx of the sample + + Returns: + Dictionary having the following key: + ``self.index_key: int = item`` + + """ + raise NotImplementedError() + + +class IBaseDataset(IIndexedDataset, ABC): + input_tensors_key: str extra_data: Dict[str, Any] def __getitem__(self, item: int) -> Dict[str, Any]: @@ -45,6 +61,8 @@ def __getitem__(self, item: int) -> Dict[str, Any]: Returns: Dictionary including the following keys: + ``self.input_tensors_key`` + ``self.index_key: int = item`` ``self.labels_key`` """ @@ -78,15 +96,14 @@ class IQueryGalleryLabeledDataset(IQueryGalleryDataset, ILabeledDataset, ABC): """ -class IPairsDataset(Dataset, ABC): +class IPairDataset(IIndexedDataset): """ This is an interface for the datasets which return pair of something. """ - pairs_1st_key: str = PAIR_1ST_KEY - pairs_2nd_key: str = PAIR_2ND_KEY - index_key: str = INDEX_KEY + input_tensors_key_1: str = INPUT_TENSORS_KEY_1 + input_tensors_key_2: str = INPUT_TENSORS_KEY_2 @abstractmethod def __getitem__(self, item: int) -> Dict[str, Any]: @@ -97,8 +114,8 @@ def __getitem__(self, item: int) -> Dict[str, Any]: Returns: Dictionary with the following keys: - ``self.pairs_1st_key`` - ``self.pairs_2nd_key`` + ``self.input_tensors_key_1`` + ``self.input_tensors_key_2`` ``self.index_key`` """ @@ -116,10 +133,11 @@ def visualize(self, item: int, color: TColor) -> np.ndarray: __all__ = [ + "IIndexedDataset", "IBaseDataset", "ILabeledDataset", "IQueryGalleryLabeledDataset", "IQueryGalleryDataset", - "IPairsDataset", + "IPairDataset", "IVisualizableDataset", ] diff --git a/oml/interfaces/retrieval.py b/oml/interfaces/retrieval.py index c593e8c14..7422f6c35 100644 --- a/oml/interfaces/retrieval.py +++ b/oml/interfaces/retrieval.py @@ -1,56 +1,15 @@ -from typing import Any, Dict, List +from typing import Any -from torch import Tensor - -class IDistancesPostprocessor: +class IRetrievalPostprocessor: """ - This is a parent class for the classes which apply some postprocessing - after query-to-gallery distance matrix has been calculated. - For example, we may want to apply one of re-ranking techniques. + This is a base interface for the classes which somehow postprocess retrieval results. 
""" - def process(self, distances: Tensor, queries: Any, galleries: Any) -> Tensor: - """ - This method takes all the needed variables and returns - the modified matrix of distances, where some distances are - replaced with new ones. - - Args: - distances: Matrix with the shape of ``[Q, G]`` - queries: Queries in the amount of ``Q`` - galleries: Galleries in the amount of ``G`` - - Returns: - An updated distances matrix with the shape of ``[Q, G]`` - - """ - raise NotImplementedError() - - def process_by_dict(self, distances: Tensor, data: Dict[str, Any]) -> Tensor: - """ - This method is the analogue of ``process``, but data is passed as a dictionary, - so we need to use the corresponding keys, which also have to be obtainable by - ``needed_keys`` property. - - Args: - distances: Matrix with the shape of ``[Q, G]`` - data: Dictionary of data - - Returns: - An updated distances matrix with the shape of ``[Q, G]`` - - """ - raise NotImplementedError() - - @property - def needed_keys(self) -> List[str]: - """ - Returns: Keys that will be used to process data using ``process_by_dict`` - - """ + def process(self, *args, **kwargs) -> Any: # type: ignore + # todo 522: add actual signature later raise NotImplementedError() -__all__ = ["IDistancesPostprocessor"] +__all__ = ["IRetrievalPostprocessor"] diff --git a/oml/lightning/pipelines/predict.py b/oml/lightning/pipelines/predict.py index 8a4e8f0dc..63b170d86 100644 --- a/oml/lightning/pipelines/predict.py +++ b/oml/lightning/pipelines/predict.py @@ -34,6 +34,9 @@ def extractor_prediction_pipeline(cfg: TCfg) -> None: filenames = [list(Path(cfg["data_dir"]).glob(f"**/*.{ext}")) for ext in IMAGE_EXTENSIONS] filenames = list(itertools.chain(*filenames)) + if len(filenames) == 0: + raise RuntimeError(f"There are no images in the provided directory: {cfg['data_dir']}") + f_imread = get_im_reader_for_transforms(transforms) print("Let's check if there are broken images:") diff --git a/oml/lightning/pipelines/train_postprocessor.py b/oml/lightning/pipelines/train_postprocessor.py index cc302941f..e372627cb 100644 --- a/oml/lightning/pipelines/train_postprocessor.py +++ b/oml/lightning/pipelines/train_postprocessor.py @@ -3,16 +3,16 @@ from pprint import pprint from typing import Any, Dict, Tuple -import pandas as pd import pytorch_lightning as pl import torch from omegaconf import DictConfig from torch import device as tdevice from torch.utils.data import DataLoader -from oml.const import BBOXES_COLUMNS, EMBEDDINGS_KEY, TCfg +from oml.const import EMBEDDINGS_KEY, TCfg from oml.datasets.base import ImageLabeledDataset, ImageQueryGalleryLabeledDataset -from oml.inference.flat import inference_on_dataframe +from oml.datasets.images import get_retrieval_images_datasets +from oml.inference import inference, inference_cached from oml.interfaces.models import IPairwiseModel from oml.lightning.callbacks.metric import MetricValCallback, MetricValCallbackDDP from oml.lightning.modules.pairwise_postprocessing import ( @@ -33,8 +33,7 @@ from oml.registry.optimizers import get_optimizer_by_cfg from oml.registry.postprocessors import get_postprocessor_by_cfg from oml.registry.transforms import get_transforms_by_cfg -from oml.retrieval.postprocessors.pairwise import PairwiseImagesPostprocessor -from oml.transforms.images.torchvision import get_normalisation_resize_torch +from oml.retrieval.postprocessors.pairwise import PairwiseReranker from oml.utils.misc import dictconfig_to_dict, flatten_dict, set_global_seed @@ -56,46 +55,51 @@ def dict2str(dictionary: Dict[str, 
Any]) -> str: def get_loaders_with_embeddings(cfg: TCfg) -> Tuple[DataLoader, DataLoader]: - # todo: support bounding bboxes - df = pd.read_csv(Path(cfg["dataset_root"]) / cfg["dataframe_name"]) - assert not set(BBOXES_COLUMNS).intersection( - df.columns - ), "We've found bboxes in the dataframe, but they're not supported yet." - device = tdevice("cuda:0") if parse_engine_params_from_config(cfg)["accelerator"] == "gpu" else tdevice("cpu") extractor = get_extractor_by_cfg(cfg["extractor"]).to(device) - if cfg["embeddings_cache_dir"] is not None: - cache_file = Path(cfg["embeddings_cache_dir"]) / f"embeddings_{get_hash_of_extraction_stage_cfg(cfg)[:5]}.pkl" - else: - cache_file = None + transforms_extraction = get_transforms_by_cfg(cfg["transforms_extraction"]) - emb_train, emb_val, df_train, df_val = inference_on_dataframe( - extractor=extractor, - dataset_root=cfg["dataset_root"], - output_cache_path=cache_file, + train_extraction, val_extraction = get_retrieval_images_datasets( + dataset_root=Path(cfg["dataset_root"]), dataframe_name=cfg["dataframe_name"], - transforms=get_transforms_by_cfg(cfg["transforms_extraction"]), - num_workers=cfg["num_workers"], - batch_size=cfg["batch_size_inference"], - use_fp16=int(cfg.get("precision", 32)) == 16, + transforms_train=transforms_extraction, + transforms_val=transforms_extraction, ) + args = { + "model": extractor, + "num_workers": cfg["num_workers"], + "batch_size": cfg["batch_size_inference"], + "use_fp16": int(cfg.get("precision", 32)) == 16, + } + + if cfg["embeddings_cache_dir"] is not None: + hash_ = get_hash_of_extraction_stage_cfg(cfg)[:5] + dir_ = Path(cfg["embeddings_cache_dir"]) + emb_train = inference_cached(dataset=train_extraction, cache_path=str(dir_ / f"emb_train_{hash_}.pkl"), **args) + emb_val = inference_cached(dataset=val_extraction, cache_path=str(dir_ / f"emb_val_{hash_}.pkl"), **args) + else: + emb_train = inference(dataset=train_extraction, **args) + emb_val = inference(dataset=val_extraction, **args) + train_dataset = ImageLabeledDataset( - df=df_train, + dataset_root=cfg["dataset_root"], + df=train_extraction.df, transform=get_transforms_by_cfg(cfg["transforms_train"]), extra_data={EMBEDDINGS_KEY: emb_train}, ) valid_dataset = ImageQueryGalleryLabeledDataset( - df=df_val, - # we don't care about transforms, since the only goal of this dataset is to deliver embeddings - transform=get_normalisation_resize_torch(im_size=8), + dataset_root=cfg["dataset_root"], + df=val_extraction.df, + transform=transforms_extraction, extra_data={EMBEDDINGS_KEY: emb_val}, ) sampler = parse_sampler_from_config(cfg, dataset=train_dataset) - assert sampler is not None + assert sampler is not None, "We will be training on pairs, so, having sampler is obligatory." + loader_train = DataLoader(batch_sampler=sampler, dataset=train_dataset, num_workers=cfg["num_workers"]) loader_val = DataLoader( @@ -128,7 +132,7 @@ def postprocessor_training_pipeline(cfg: DictConfig) -> None: loader_train, loader_val = get_loaders_with_embeddings(cfg) postprocessor = None if not cfg.get("postprocessor", None) else get_postprocessor_by_cfg(cfg["postprocessor"]) - assert isinstance(postprocessor, PairwiseImagesPostprocessor), "We support only images processing in this pipeline." + assert isinstance(postprocessor, PairwiseReranker), "We support only images processing in this pipeline." 
assert isinstance(postprocessor.model, IPairwiseModel), f"You model must be a child of {IPairwiseModel.__name__}" criterion = torch.nn.BCEWithLogitsLoss() @@ -157,6 +161,7 @@ def postprocessor_training_pipeline(cfg: DictConfig) -> None: metrics_constructor = EmbeddingMetricsDDP if is_ddp else EmbeddingMetrics metrics_calc = metrics_constructor( + dataset=loader_val.dataset, embeddings_key=pl_module.embeddings_key, categories_key=loader_val.dataset.categories_key, labels_key=loader_val.dataset.labels_key, diff --git a/oml/lightning/pipelines/validate.py b/oml/lightning/pipelines/validate.py index 543522bf4..598c5e70e 100644 --- a/oml/lightning/pipelines/validate.py +++ b/oml/lightning/pipelines/validate.py @@ -65,10 +65,12 @@ def extractor_validation_pipeline(cfg: TCfg) -> Tuple[pl.Trainer, Dict[str, Any] postprocessor = None if not cfg.get("postprocessor", None) else get_postprocessor_by_cfg(cfg["postprocessor"]) # Note! We add the link to our extractor to a Lightning's Module, so it can recognize it and manipulate its devices - pl_model.model_link_ = getattr(postprocessor, "extractor", None) + if postprocessor is not None: + pl_model.model_link_ = postprocessor.model # type: ignore metrics_constructor = EmbeddingMetricsDDP if is_ddp else EmbeddingMetrics metrics_calc = metrics_constructor( + dataset=valid_dataset, embeddings_key=pl_model.embeddings_key, categories_key=valid_dataset.categories_key, labels_key=valid_dataset.labels_key, diff --git a/oml/metrics/embeddings.py b/oml/metrics/embeddings.py index b1d03c3a4..3454fc582 100644 --- a/oml/metrics/embeddings.py +++ b/oml/metrics/embeddings.py @@ -12,6 +12,7 @@ EMBEDDINGS_KEY, GRAY, GREEN, + INDEX_KEY, IS_GALLERY_KEY, IS_QUERY_KEY, LABELS_KEY, @@ -39,8 +40,9 @@ reduce_metrics, take_unreduced_metrics_by_mask, ) +from oml.interfaces.datasets import IQueryGalleryLabeledDataset from oml.interfaces.metrics import IMetricDDP, IMetricVisualisable -from oml.interfaces.retrieval import IDistancesPostprocessor +from oml.interfaces.retrieval import IRetrievalPostprocessor from oml.metrics.accumulation import Accumulator from oml.utils.images.images import get_img_with_bbox, square_pad from oml.utils.misc import flatten_dict @@ -69,6 +71,7 @@ class EmbeddingMetrics(IMetricVisualisable): def __init__( self, + dataset: Optional[IQueryGalleryLabeledDataset] = None, # todo 522: This argument will not be Optional soon. embeddings_key: str = EMBEDDINGS_KEY, labels_key: str = LABELS_KEY, is_query_key: str = IS_QUERY_KEY, @@ -81,7 +84,7 @@ def __init__( pcf_variance: Tuple[float, ...] = (0.5,), categories_key: Optional[str] = None, sequence_key: Optional[str] = None, - postprocessor: Optional[IDistancesPostprocessor] = None, + postprocessor: Optional[IRetrievalPostprocessor] = None, metrics_to_exclude_from_visualization: Iterable[str] = (), return_only_overall_category: bool = False, visualize_only_overall_category: bool = True, @@ -90,6 +93,7 @@ def __init__( """ Args: + dataset: Annotated dataset having query-gallery split. 
embeddings_key: Key to take the embeddings from the batches labels_key: Key to take the labels from the batches is_query_key: Key to take the information whether every batch sample belongs to the query @@ -117,6 +121,7 @@ def __init__( verbose: Set ``True`` if you want to print metrics """ + self.dataset = dataset self.embeddings_key = embeddings_key self.labels_key = labels_key self.is_query_key = is_query_key @@ -144,14 +149,13 @@ def __init__( self.verbose = verbose keys_to_accumulate = [self.embeddings_key, self.is_query_key, self.is_gallery_key, self.labels_key] + keys_to_accumulate += [INDEX_KEY] # todo 522: remove it after we make "indices" not optional in .update_data() if self.categories_key: keys_to_accumulate.append(self.categories_key) if self.sequence_key: keys_to_accumulate.append(self.sequence_key) if self.extra_keys: keys_to_accumulate.extend(list(extra_keys)) - if self.postprocessor: - keys_to_accumulate.extend(self.postprocessor.needed_keys) self.keys_to_accumulate = tuple(set(keys_to_accumulate)) self.acc = Accumulator(keys_to_accumulate=self.keys_to_accumulate) @@ -199,7 +203,11 @@ def _calc_matrices(self) -> None: validate_dataset(mask_gt=self.mask_gt, mask_to_ignore=mask_to_ignore) if self.postprocessor: - self.distance_matrix = self.postprocessor.process_by_dict(self.distance_matrix, data=self.acc.storage) + assert self.dataset, "You must pass dataset to init to make postprocessing." + # todo 522: remove this assert after "indices" become not optional + ii_aligned = list(range(len(self.dataset))) + assert ii_aligned == self.acc.storage[INDEX_KEY].tolist(), "The data is shuffled!" # type: ignore + self.distance_matrix = self.postprocessor.process(self.distance_matrix, dataset=self.dataset) def compute_metrics(self) -> TMetricsDict_ByLabels: # type: ignore if not self.acc.is_storage_full(): diff --git a/oml/models/meta/siamese.py b/oml/models/meta/siamese.py index 4c843f90e..b0f0b6d3d 100644 --- a/oml/models/meta/siamese.py +++ b/oml/models/meta/siamese.py @@ -18,16 +18,18 @@ class LinearTrivialDistanceSiamese(IPairwiseModel): """ - def __init__(self, feat_dim: int, identity_init: bool): + def __init__(self, feat_dim: int, identity_init: bool, output_bias: float = 0): """ Args: feat_dim: Expected size of each input. identity_init: If ``True``, models' weights initialised in a way when the model simply estimates L2 distance between the original embeddings. + output_bias: Value to add to the output. """ super(LinearTrivialDistanceSiamese, self).__init__() self.feat_dim = feat_dim + self.output_bias = output_bias self.proj = torch.nn.Linear(in_features=feat_dim, out_features=feat_dim, bias=False) @@ -46,7 +48,7 @@ def forward(self, x1: Tensor, x2: Tensor) -> Tensor: """ x1 = self.proj(x1) x2 = self.proj(x2) - y = elementwise_dist(x1, x2, p=2) + y = elementwise_dist(x1, x2, p=2) + self.output_bias return y def predict(self, x1: Tensor, x2: Tensor) -> Tensor: @@ -145,14 +147,16 @@ class TrivialDistanceSiamese(IPairwiseModel): pretrained_models: Dict[str, Any] = {} - def __init__(self, extractor: IExtractor) -> None: + def __init__(self, extractor: IExtractor, output_bias: float = 0) -> None: """ Args: extractor: Instance of ``IExtractor`` (e.g. ``ViTExtractor``) + output_bias: Value to add to the outputs. 
""" super(TrivialDistanceSiamese, self).__init__() self.extractor = extractor + self.output_bias = output_bias def forward(self, x1: Tensor, x2: Tensor) -> Tensor: """ @@ -166,7 +170,7 @@ def forward(self, x1: Tensor, x2: Tensor) -> Tensor: """ x1 = self.extractor(x1) x2 = self.extractor(x2) - return elementwise_dist(x1, x2, p=2) + return elementwise_dist(x1, x2, p=2) + self.output_bias def predict(self, x1: Tensor, x2: Tensor) -> Tensor: return self.forward(x1=x1, x2=x2) diff --git a/oml/registry/postprocessors.py b/oml/registry/postprocessors.py index 30e8ee57d..015971b77 100644 --- a/oml/registry/postprocessors.py +++ b/oml/registry/postprocessors.py @@ -1,33 +1,26 @@ from typing import Any, Dict from oml.const import TCfg -from oml.interfaces.retrieval import IDistancesPostprocessor +from oml.interfaces.retrieval import IRetrievalPostprocessor from oml.registry.models import get_pairwise_model_by_cfg -from oml.registry.transforms import get_transforms_by_cfg -from oml.retrieval.postprocessors.pairwise import ( - PairwiseEmbeddingsPostprocessor, - PairwiseImagesPostprocessor, -) +from oml.retrieval.postprocessors.pairwise import PairwiseReranker from oml.utils.misc import dictconfig_to_dict POSTPROCESSORS_REGISTRY = { - "pairwise_images": PairwiseImagesPostprocessor, - "pairwise_embeddings": PairwiseEmbeddingsPostprocessor, + "pairwise_reranker": PairwiseReranker, } -def get_postprocessor(name: str, **kwargs: Dict[str, Any]) -> IDistancesPostprocessor: +def get_postprocessor(name: str, **kwargs: Dict[str, Any]) -> IRetrievalPostprocessor: constructor = POSTPROCESSORS_REGISTRY[name] - if "transforms" in kwargs: - kwargs["transforms"] = get_transforms_by_cfg(kwargs["transforms"]) if "pairwise_model" in kwargs: - kwargs["pairwise_model"] = get_pairwise_model_by_cfg(kwargs["pairwise_model"]) + kwargs["pairwise_model"] = get_pairwise_model_by_cfg(kwargs["pairwise_model"]) # type: ignore - return constructor(**kwargs) + return constructor(**kwargs) # type: ignore -def get_postprocessor_by_cfg(cfg: TCfg) -> IDistancesPostprocessor: +def get_postprocessor_by_cfg(cfg: TCfg) -> IRetrievalPostprocessor: cfg = dictconfig_to_dict(cfg) postprocessor = get_postprocessor(cfg["name"], **cfg["args"]) return postprocessor diff --git a/oml/retrieval/postprocessors/pairwise.py b/oml/retrieval/postprocessors/pairwise.py index 39ee69447..9274dc21c 100644 --- a/oml/retrieval/postprocessors/pairwise.py +++ b/oml/retrieval/postprocessors/pairwise.py @@ -1,100 +1,20 @@ -import itertools -from abc import ABC -from pathlib import Path -from typing import Any, Dict, List +from typing import Tuple -import numpy as np import torch from torch import Tensor -from oml.const import EMBEDDINGS_KEY, IS_GALLERY_KEY, IS_QUERY_KEY, PATHS_KEY -from oml.inference.pairs import ( - pairwise_inference_on_embeddings, - pairwise_inference_on_images, -) +from oml.inference.abstract import pairwise_inference +from oml.interfaces.datasets import IQueryGalleryDataset from oml.interfaces.models import IPairwiseModel -from oml.interfaces.retrieval import IDistancesPostprocessor -from oml.transforms.images.utils import TTransforms -from oml.utils.misc_torch import assign_2d - - -class PairwisePostprocessor(IDistancesPostprocessor, ABC): - """ - This postprocessor allows us to re-estimate the distances between queries and ``top-n`` galleries - closest to them. It creates pairs of queries and galleries and feeds them to a pairwise model. 
- - """ - - top_n: int - verbose: bool = False - - def process(self, distances: Tensor, queries: Any, galleries: Any) -> Tensor: - """ - Args: - distances: Matrix with the shape of ``[Q, G]`` - queries: Queries in the amount of ``Q`` - galleries: Galleries in the amount of ``G`` - - Returns: - Distance matrix with the shape of ``[Q, G]``, - where ``top_n`` minimal values in each row have been updated by the pairwise model, - other distances are shifted by a margin to keep the relative order. - - """ - n_queries = len(queries) - n_galleries = len(galleries) - - assert list(distances.shape) == [n_queries, n_galleries] - - # 1. Adjust top_n with respect to the actual gallery size and find top-n pairs - top_n = min(self.top_n, n_galleries) - ii_top = torch.topk(distances, k=top_n, largest=False)[1] - - # 2. Create (n_queries * top_n) pairs of each query and related galleries and re-estimate distances for them - if self.verbose: - print("\nPostprocessor's inference has been started...") - distances_upd = self.inference(queries=queries, galleries=galleries, ii_top=ii_top, top_n=top_n) - distances_upd = distances_upd.to(distances.device).to(distances.dtype) - - # 3. Update distances for top-n galleries - # The idea is that we somehow permute top-n galleries, but rest of the galleries - # we keep in the end of the list as before permutation. - # To do so, we add an offset to these galleries' distances (which haven't participated in the permutation) - if top_n < n_galleries: - # Here we use the fact that distances not participating in permutation start with top_n + 1 position - min_in_old_distances = torch.topk(distances, k=top_n + 1, largest=False)[0][:, -1] - max_in_new_distances = distances_upd.max(dim=1)[0] - offset = max_in_new_distances - min_in_old_distances + 1e-5 # we also need some eps if max == min - distances += offset.unsqueeze(-1) - else: - # Pairwise postprocessor has been applied to all possible pairs, so, there are no rest distances. - # Thus, we don't need to care about order and offset at all. - pass - - distances = assign_2d(x=distances, indices=ii_top, new_values=distances_upd) - - assert list(distances.shape) == [n_queries, n_galleries] - - return distances - - def inference(self, queries: Any, galleries: Any, ii_top: Tensor, top_n: int) -> Tensor: - """ - Depends on the exact types of queries/galleries this method may be implemented differently. - - Args: - queries: Queries in the amount of ``Q`` - galleries: Galleries in the amount of ``G`` - ii_top: Indices of the closest galleries with the shape of ``[Q, top_n]`` - top_n: Number of the closest galleries to re-rank - - Returns: - An updated distance matrix with the shape of ``[Q, G]`` - - """ - raise NotImplementedError() +from oml.interfaces.retrieval import IRetrievalPostprocessor +from oml.utils.misc_torch import ( + assign_2d, + cat_two_sorted_tensors_and_keep_it_sorted, + take_2d, +) -class PairwiseEmbeddingsPostprocessor(PairwisePostprocessor): +class PairwiseReranker(IRetrievalPostprocessor): def __init__( self, top_n: int, @@ -103,155 +23,152 @@ def __init__( batch_size: int, verbose: bool = False, use_fp16: bool = False, - is_query_key: str = IS_QUERY_KEY, - is_gallery_key: str = IS_GALLERY_KEY, - embeddings_key: str = EMBEDDINGS_KEY, ): """ + Args: top_n: Model will be applied to the ``num_queries * top_n`` pairs formed by each query and ``top_n`` most relevant galleries. 
- pairwise_model: Model which is able to take two embeddings as inputs + pairwise_model: Model which is able to take two items as inputs and estimate the *distance* (not in a strictly mathematical sense) between them. num_workers: Number of workers in DataLoader batch_size: Batch size that will be used in DataLoader verbose: Set ``True`` if you want to see progress bar for an inference use_fp16: Set ``True`` if you want to use half precision - is_query_key: Key to access a binary mask indicates queries in case of using ``process_by_dict`` - is_gallery_key: Key to access a binary mask indicates galleries in case of using ``process_by_dict`` - embeddings_key: Key to access embeddings in case of using ``process_by_dict`` """ - assert top_n > 1, "Number of galleries for each query to process has to be greater than 1." + assert top_n > 1, "The number of retrieved results for each query to process has to be greater than 1." self.top_n = top_n self.model = pairwise_model + self.num_workers = num_workers self.batch_size = batch_size self.verbose = verbose self.use_fp16 = use_fp16 - self.is_query_key = is_query_key - self.is_gallery_key = is_gallery_key - self.embeddings_key = embeddings_key - - def inference(self, queries: Tensor, galleries: Tensor, ii_top: Tensor, top_n: int) -> Tensor: + def process(self, distances: Tensor, dataset: IQueryGalleryDataset) -> Tensor: # type: ignore """ Args: - queries: Queries representations with the shape of ``[Q, *]`` - galleries: Galleries representations with the shape of ``[G, *]`` - ii_top: Indices of the closest galleries with the shape of ``[Q, top_n]`` - top_n: Number of the closest galleries to re-rank + distances: Where ``distances[i, j]`` is a distance between i-th query and j-th gallery. + dataset: Dataset having query-gallery split. Returns: - Updated distance matrix with the shape of ``[Q, G]`` + Distances, where ``distances[i, j]`` is a distance between i-th query and j-th gallery, + but the distances to the first ``top_n`` galleries have been updated INPLACE. """ - n_queries = len(queries) - queries = queries.repeat_interleave(top_n, dim=0) - galleries = galleries[ii_top.view(-1)] - distances_upd = pairwise_inference_on_embeddings( - model=self.model, - embeddings1=queries, - embeddings2=galleries, - num_workers=self.num_workers, - batch_size=self.batch_size, - verbose=self.verbose, - use_fp16=self.use_fp16, + # todo 522: + # This function and the code below are only needed during the migration period. + # We will directly use `process_neigh` later on. + # So, the code below is just a format adapter: + # 1) it takes the top (dists + ii) of the big distance matrix, + # 2) passes this top to the `process_neigh()` + # 3) puts the processed outputs on their places in the big distance matrix + + assert distances.shape == (len(dataset.get_query_ids()), len(dataset.get_gallery_ids())) + + # we need this "+10" to activate rescaling if needed (so we have both: new and old distances in process_neigh).
+ # anyway, this code is temporary + distances_top, ii_retrieved_top = torch.topk( + distances, k=min(self.top_n + 10, distances.shape[1]), largest=False ) - distances_upd = distances_upd.view(n_queries, top_n) - return distances_upd - - def process_by_dict(self, distances: Tensor, data: Dict[str, Any]) -> Tensor: - queries = data[self.embeddings_key][data[self.is_query_key]] - galleries = data[self.embeddings_key][data[self.is_gallery_key]] - return self.process(distances=distances, queries=queries, galleries=galleries) + distances_top_upd, ii_retrieved_upd = self.process_neigh( + retrieved_ids=ii_retrieved_top, distances=distances_top, dataset=dataset + ) + distances = assign_2d(x=distances, indices=ii_retrieved_upd, new_values=distances_top_upd) - @property - def needed_keys(self) -> List[str]: - return [self.is_query_key, self.is_gallery_key, self.embeddings_key] + assert distances.shape == (len(dataset.get_query_ids()), len(dataset.get_gallery_ids())) + return distances -class PairwiseImagesPostprocessor(PairwisePostprocessor): - def __init__( - self, - top_n: int, - pairwise_model: IPairwiseModel, - transforms: TTransforms, - num_workers: int = 0, - batch_size: int = 128, - verbose: bool = True, - use_fp16: bool = False, - is_query_key: str = IS_QUERY_KEY, - is_gallery_key: str = IS_GALLERY_KEY, - paths_key: str = PATHS_KEY, - ): + def process_neigh( + self, retrieved_ids: Tensor, distances: Tensor, dataset: IQueryGalleryDataset + ) -> Tuple[Tensor, Tensor]: """ + Args: - top_n: Model will be applied to the ``num_queries * top_n`` pairs formed by each query - and its ``top_n`` most relevant galleries. - pairwise_model: Model which is able to take two images as inputs - and estimate the *distance* (not in a strictly mathematical sense) between them. - transforms: Transforms that will be applied to an image - num_workers: Number of workers in DataLoader - batch_size: Batch size that will be used in DataLoader - verbose: Set ``True`` if you want to see progress bar for an inference - use_fp16: Set ``True`` if you want to use half precision - is_query_key: Key to access a binary mask indicates queries in case of using ``process_by_dict`` - is_gallery_key: Key to access a binary mask indicates galleries in case of using ``process_by_dict`` - paths_key: Key to access paths to images in case of using ``process_by_dict`` + retrieved_ids: Ids of galleries closest to every query with the shape of ``[n_query, n_retrieved]`` sorted + by their distances. + distances: The corresponding distances (in sorted order). + dataset: Dataset having query/gallery split. - """ - assert top_n > 1, "Number of galleries for each query to process has to be greater than 1." + Returns: + After model is applied to the ``top_n`` retrieved items, the updated ids and distances are returned. + Thus, you can expect permutation among first ``top_n`` ids and distances, but the rest remains untouched. - self.top_n = top_n - self.model = pairwise_model - self.image_transforms = transforms - self.num_workers = num_workers - self.batch_size = batch_size - self.verbose = verbose - self.use_fp16 = use_fp16 + **Example 1** (for one query): - self.is_query_key = is_query_key - self.is_gallery_key = is_gallery_key - self.paths_key = paths_key + .. 
code-block:: python - def inference(self, queries: List[Path], galleries: List[Path], ii_top: Tensor, top_n: int) -> Tensor: - """ - Args: - queries: Paths to queries with the length of ``Q`` - galleries: Paths to galleries with the length of ``G`` - ii_top: Indices of the closest galleries with the shape of ``[Q, top_n]`` - top_n: Number of the closest galleries to re-rank + retrieved_ids = [3, 2, 1, 0, 4 ] + distances = [0.1, 0.2, 0.5, 0.6, 0.7] - Returns: - Updated distance matrix with the shape of ``[Q, G]`` + # Let's say a postprocessor has been applied to the + # first 3 elements and the new distances are: [0.4, 0.2, 0.3] + + # In this case, the updated values will be: + retrieved_ids = [2, 1, 3, 0, 4 ] + distances: = [0.2, 0.3, 0.4, 0.6, 0.7] + + **Example 2** (for one query): + + .. code-block:: python + + # Note, the new distances to the top_n items produced by the pairwise model + # may be rescaled to keep the distances order. Here is an example: + original_distances = [0.1, 0.2, 0.3, 0.5, 0.6] + top_n = 3 + + # Imagine, the postprocessor didn't change the order of the first 3 items + # (it's just a convenient example, the general logic remains the same), + # however the new values have a bigger scale: + distances_upd = [1, 2, 5, 0.5, 0.6] + + # Thus, we need to downscale the first 3 distances, so they are lower than 0.5: + scale = 5 / 0.5 = 0.1 + # Finally, let's apply the found scale to the top 3 distances: + distances_upd_scaled = [0.1, 0.2, 0.5, 0.5, 0.6] + + # Note, if new and old distances are already sorted, we don't apply any scaling. """ - n_queries = len(queries) - queries = list(itertools.chain.from_iterable(itertools.repeat(x, top_n) for x in queries)) - galleries = [galleries[i] for i in ii_top.view(-1)] - distances_upd = pairwise_inference_on_images( + assert retrieved_ids.shape == distances.shape + assert len(retrieved_ids) == len(dataset.get_query_ids()) + assert retrieved_ids.shape[1] <= len(dataset.get_gallery_ids()) + + top_n = min(self.top_n, distances.shape[1]) + + # let's list pairs of (query_i, gallery_j) we need to process + ids_q = dataset.get_query_ids().unsqueeze(-1).repeat_interleave(top_n) + ii_g = dataset.get_gallery_ids().unsqueeze(-1) + ids_g = ii_g[retrieved_ids[:, :top_n]].flatten() + assert len(ids_q) == len(ids_g) + pairs = list(zip(ids_q.tolist(), ids_g.tolist())) + + distances_top = pairwise_inference( model=self.model, - paths1=queries, - paths2=galleries, - transform=self.image_transforms, + base_dataset=dataset, + pair_ids=pairs, num_workers=self.num_workers, batch_size=self.batch_size, verbose=self.verbose, use_fp16=self.use_fp16, ) - distances_upd = distances_upd.view(n_queries, top_n) - return distances_upd + distances_top = distances_top.view(distances.shape[0], top_n) + + distances_upd, ii_rerank = distances_top.sort() + retrieved_ids_upd = take_2d(retrieved_ids, ii_rerank) + + # Stack with the unprocessed values outside the first top_n items + if top_n < distances.shape[1]: + distances_upd = cat_two_sorted_tensors_and_keep_it_sorted(distances_upd, distances[:, top_n:]) + retrieved_ids_upd = torch.concat([retrieved_ids_upd, retrieved_ids[:, top_n:]], dim=1).long() - def process_by_dict(self, distances: Tensor, data: Dict[str, Any]) -> Tensor: - queries = np.array(data[self.paths_key])[data[self.is_query_key]] - galleries = np.array(data[self.paths_key])[data[self.is_gallery_key]] - return self.process(distances=distances, queries=queries, galleries=galleries) + assert distances_upd.shape == distances.shape + assert 
retrieved_ids_upd.shape == retrieved_ids.shape - @property - def needed_keys(self) -> List[str]: - return [self.is_query_key, self.is_gallery_key, self.paths_key] + return distances_upd, retrieved_ids_upd -__all__ = ["PairwisePostprocessor", "PairwiseEmbeddingsPostprocessor", "PairwiseImagesPostprocessor"] +__all__ = ["PairwiseReranker"] diff --git a/oml/retrieval/retrieval_results.py b/oml/retrieval/retrieval_results.py index cfbe633db..ef38419b1 100644 --- a/oml/retrieval/retrieval_results.py +++ b/oml/retrieval/retrieval_results.py @@ -122,11 +122,10 @@ def visualize( verbose: Set ``True`` to allow prints. """ - if not isinstance(dataset, (IVisualizableDataset, IQueryGalleryDataset)): - raise TypeError( - f"Dataset has to support {IVisualizableDataset.__name__} and " - f"{IQueryGalleryDataset.__name__} interfaces. Got {type(dataset)}." - ) + if not isinstance(dataset, IVisualizableDataset): + raise TypeError(f"Dataset has to support {IVisualizableDataset.__name__}. Got {type(dataset)}.") + if not isinstance(dataset, IQueryGalleryDataset): + raise TypeError(f"Dataset has to support {IQueryGalleryDataset.__name__}. Got {type(dataset)}.") if verbose: print(f"Visualizing {n_galleries_to_show} for the following query ids: {query_ids}.") diff --git a/oml/utils/download_mock_dataset.py b/oml/utils/download_mock_dataset.py index eb193495e..b4a5ddfa4 100644 --- a/oml/utils/download_mock_dataset.py +++ b/oml/utils/download_mock_dataset.py @@ -17,7 +17,10 @@ def get_argparser() -> ArgumentParser: def download_mock_dataset( - dataset_root: Union[str, Path], check_md5: bool = True, df_name: str = "df.csv" + dataset_root: Union[str, Path] = MOCK_DATASET_PATH, + check_md5: bool = True, + df_name: str = "df.csv", + global_paths: bool = False, ) -> Tuple[pd.DataFrame, pd.DataFrame]: """ Function to download mock dataset which is already prepared in the required format. @@ -26,6 +29,7 @@ def download_mock_dataset( dataset_root: Path to save the dataset check_md5: Set ``True`` to check md5sum df_name: Name of csv file for which output DataFrames will be returned + global_paths: Set ``True`` to concat paths and ``dataset_root`` Returns: Dataframes for the training and validation stages @@ -48,6 +52,9 @@ def download_mock_dataset( df = pd.read_csv(Path(dataset_root) / df_name) + if global_paths: + df["path"] = df["path"].apply(lambda x: str(Path(dataset_root) / x)) + df_train = df[df["split"] == "train"].reset_index(drop=True) df_val = df[df["split"] == "validation"].reset_index(drop=True) diff --git a/oml/utils/misc_torch.py b/oml/utils/misc_torch.py index 8d12f2d6f..b055aaf8e 100644 --- a/oml/utils/misc_torch.py +++ b/oml/utils/misc_torch.py @@ -57,6 +57,30 @@ def assign_2d(x: Tensor, indices: Tensor, new_values: Tensor) -> Tensor: return x +def cat_two_sorted_tensors_and_keep_it_sorted(x1: Tensor, x2: Tensor, eps: float = 1e-6) -> Tensor: + """ + Args: + x1: Sorted tensor with the shape of ``[N, M]`` + x2: Sorted tensor with the shape of ``[N, P]`` + eps: Eps to have a gap between the last x1 and the first x2 + + Returns: + Concatenation of two sorted tensors. + The first tensor may be rescaled if needed to keep the order sorted.
+ + """ + assert eps >= 0 + assert x1.shape[0] == x2.shape[0] + + scale = (x2[:, 0] / x1[:, -1]).view(-1, 1).type_as(x1) + need_scaling = x1[:, -1] > x2[:, 0] + x1[need_scaling] = x1[need_scaling] * scale[need_scaling] - eps + + x = torch.concatenate([x1, x2], dim=1).float() + + return x + + def elementwise_dist(x1: Tensor, x2: Tensor, p: int = 2) -> Tensor: """ Args: @@ -455,6 +479,7 @@ def _check_dimensions(self, n_components: int) -> None: __all__ = [ "elementwise_dist", + "cat_two_sorted_tensors_and_keep_it_sorted", "pairwise_dist", "OnlineCalc", "AvgOnline", diff --git a/pipelines/postprocessing/pairwise_postprocessing/postprocessor_train.yaml b/pipelines/postprocessing/pairwise_postprocessing/postprocessor_train.yaml index e57305d61..220023636 100644 --- a/pipelines/postprocessing/pairwise_postprocessing/postprocessor_train.yaml +++ b/pipelines/postprocessing/pairwise_postprocessing/postprocessor_train.yaml @@ -84,15 +84,10 @@ transforms_train: batch_size_inference: 128 postprocessor: - name: pairwise_images + name: pairwise_reranker args: top_n: 5 pairwise_model: ${pairwise_model} - transforms: - name: norm_resize_hypvit_torch - args: - im_size: 224 - crop_size: 224 num_workers: ${num_workers} batch_size: ${batch_size_inference} verbose: True diff --git a/pipelines/postprocessing/pairwise_postprocessing/postprocessor_validate.yaml b/pipelines/postprocessing/pairwise_postprocessing/postprocessor_validate.yaml index 0ed09cbe7..749a4d0a8 100644 --- a/pipelines/postprocessing/pairwise_postprocessing/postprocessor_validate.yaml +++ b/pipelines/postprocessing/pairwise_postprocessing/postprocessor_validate.yaml @@ -26,7 +26,7 @@ extractor: weights: ${extractor_weights} postprocessor: - name: pairwise_images + name: pairwise_reranker args: top_n: 5 pairwise_model: @@ -42,11 +42,6 @@ postprocessor: normalise_features: False use_multi_scale: False weights: null - transforms: - name: norm_resize_hypvit_torch - args: - im_size: 224 - crop_size: 224 num_workers: ${num_workers} batch_size: ${bs_val} verbose: True diff --git a/pipelines/postprocessing/pairwise_postprocessing/visualisation.ipynb b/pipelines/postprocessing/pairwise_postprocessing/visualisation.ipynb index d0fbe1d27..92ae65cc8 100644 --- a/pipelines/postprocessing/pairwise_postprocessing/visualisation.ipynb +++ b/pipelines/postprocessing/pairwise_postprocessing/visualisation.ipynb @@ -85,7 +85,7 @@ "source": [ "cfg_p = cfg + f\"\"\"\n", " postprocessor:\n", - " name: pairwise_images\n", + " name: pairwise_reranker\n", " args:\n", " top_n: 5\n", " pairwise_model:\n", @@ -100,11 +100,6 @@ " normalise_features: False\n", " use_multi_scale: False\n", " weights: null\n", - " transforms:\n", - " name: norm_resize_hypvit_torch\n", - " args:\n", - " im_size: 224\n", - " crop_size: 224\n", " num_workers: 10\n", " batch_size: 128\n", " verbose: True\n", diff --git a/tests/test_oml/test_ddp/test_ddp_inference.py b/tests/test_oml/test_ddp/test_ddp_inference.py index 505205055..aafde4b3f 100644 --- a/tests/test_oml/test_ddp/test_ddp_inference.py +++ b/tests/test_oml/test_ddp/test_ddp_inference.py @@ -1,3 +1,4 @@ +from pathlib import Path from typing import List import pytest @@ -5,7 +6,8 @@ from torchvision.models import resnet18 from oml.const import MOCK_DATASET_PATH -from oml.inference.flat import inference_on_images +from oml.datasets import ImageBaseDataset +from oml.inference import inference from oml.transforms.images.torchvision import get_normalisation_resize_torch from oml.utils.download_mock_dataset import download_mock_dataset from 
tests.test_oml.test_ddp.utils import init_ddp, run_in_ddp @@ -34,13 +36,11 @@ def run_with_handling_duplicates(rank: int, world_size: int, device: str, paths: args = { "model": model, - "paths": paths, - "transform": transform, + "dataset": ImageBaseDataset(paths=[Path(x) for x in paths], transform=transform), "num_workers": 0, "verbose": True, - "f_imread": None, "batch_size": batch_size, } - output = inference_on_images(**args) + output = inference(**args) assert len(paths) == len(output), (len(paths), len(output)) diff --git a/tests/test_oml/test_metrics/test_embedding_metrics.py b/tests/test_oml/test_metrics/test_embedding_metrics.py index 89a2c3bc4..68e64a57f 100644 --- a/tests/test_oml/test_metrics/test_embedding_metrics.py +++ b/tests/test_oml/test_metrics/test_embedding_metrics.py @@ -3,9 +3,11 @@ from functools import partial from typing import Any, Tuple +import numpy as np import pytest import torch from torch import Tensor +from torch.utils.data import DataLoader from oml.const import ( CATEGORIES_KEY, @@ -18,16 +20,17 @@ ) from oml.metrics.embeddings import EmbeddingMetrics from oml.models.meta.siamese import LinearTrivialDistanceSiamese -from oml.retrieval.postprocessors.pairwise import PairwiseEmbeddingsPostprocessor +from oml.retrieval.postprocessors.pairwise import PairwiseReranker from oml.utils.misc import compare_dicts_recursively, one_hot +from tests.test_integrations.utils import EmbeddingsQueryGalleryLabeledDataset FEAT_DIM = 8 oh = partial(one_hot, dim=FEAT_DIM) -def get_trivial_postprocessor(top_n: int) -> PairwiseEmbeddingsPostprocessor: +def get_trivial_postprocessor(top_n: int) -> PairwiseReranker: model = LinearTrivialDistanceSiamese(feat_dim=FEAT_DIM, identity_init=True) - processor = PairwiseEmbeddingsPostprocessor(pairwise_model=model, top_n=top_n, num_workers=0, batch_size=64) + processor = PairwiseReranker(pairwise_model=model, top_n=top_n, num_workers=0, batch_size=64) return processor @@ -46,22 +49,13 @@ def perfect_case() -> Any: Thus, we expect all of the metrics equals to 1. 
""" - - batch1 = { - EMBEDDINGS_KEY: torch.stack([oh(0), oh(1), oh(1)]), - LABELS_KEY: torch.tensor([0, 1, 1]), - IS_QUERY_KEY: torch.tensor([True, True, True]), - IS_GALLERY_KEY: torch.tensor([False, False, False]), - CATEGORIES_KEY: ["cat", "dog", "dog"], - } - - batch2 = { - EMBEDDINGS_KEY: torch.stack([oh(0), oh(1), oh(1)]), - LABELS_KEY: torch.tensor([0, 1, 1]), - IS_QUERY_KEY: torch.tensor([False, False, False]), - IS_GALLERY_KEY: torch.tensor([True, True, True]), - CATEGORIES_KEY: ["cat", "dog", "dog"], - } + dataset = EmbeddingsQueryGalleryLabeledDataset( + embeddings=torch.stack([oh(0), oh(1), oh(1), oh(0), oh(1), oh(1)]).float(), + labels=torch.tensor([0, 1, 1, 0, 1, 1]).long(), + is_query=torch.tensor([True, True, True, False, False, False]).bool(), + is_gallery=torch.tensor([False, False, False, True, True, True]).bool(), + categories=np.array(["cat", "dog", "dog", "cat", "dog", "dog"]), + ) k = 1 metrics = defaultdict(lambda: defaultdict(dict)) # type: ignore @@ -69,26 +63,18 @@ def perfect_case() -> Any: metrics["cat"]["cmc"][k] = 1.0 metrics["dog"]["cmc"][k] = 1.0 - return (batch1, batch2), (metrics, k) + return dataset, (metrics, k) @pytest.fixture() def imperfect_case() -> Any: - batch1 = { - EMBEDDINGS_KEY: torch.stack([oh(0), oh(1), oh(3)]), # 3d embedding pretends to be an error - LABELS_KEY: torch.tensor([0, 1, 1]), - IS_QUERY_KEY: torch.tensor([True, True, True]), - IS_GALLERY_KEY: torch.tensor([False, False, False]), - CATEGORIES_KEY: torch.tensor([10, 20, 20]), - } - - batch2 = { - EMBEDDINGS_KEY: torch.stack([oh(0), oh(1), oh(1)]), - LABELS_KEY: torch.tensor([0, 1, 1]), - IS_QUERY_KEY: torch.tensor([False, False, False]), - IS_GALLERY_KEY: torch.tensor([True, True, True]), - CATEGORIES_KEY: torch.tensor([10, 20, 20]), - } + dataset = EmbeddingsQueryGalleryLabeledDataset( + embeddings=torch.stack([oh(0), oh(1), oh(3), oh(0), oh(1), oh(1)]).float(), # 3d val pretends to be an error + labels=torch.tensor([0, 1, 1, 0, 1, 1]).long(), + is_query=torch.tensor([True, True, True, False, False, False]).bool(), + is_gallery=torch.tensor([False, False, False, True, True, True]).bool(), + categories=np.array([10, 20, 20, 10, 20, 20]), + ) k = 1 metrics = defaultdict(lambda: defaultdict(dict)) # type: ignore @@ -96,26 +82,18 @@ def imperfect_case() -> Any: metrics[10]["cmc"][k] = 1.0 metrics[20]["cmc"][k] = 0.5 - return (batch1, batch2), (metrics, k) + return dataset, (metrics, k) @pytest.fixture() def worst_case() -> Any: - batch1 = { - EMBEDDINGS_KEY: torch.stack([oh(1), oh(0), oh(0)]), # 3d embedding pretends to be an error - LABELS_KEY: torch.tensor([0, 1, 1]), - IS_QUERY_KEY: torch.tensor([True, True, True]), - IS_GALLERY_KEY: torch.tensor([False, False, False]), - CATEGORIES_KEY: torch.tensor([10, 20, 20]), - } - - batch2 = { - EMBEDDINGS_KEY: torch.stack([oh(0), oh(1), oh(1)]), - LABELS_KEY: torch.tensor([0, 1, 1]), - IS_QUERY_KEY: torch.tensor([False, False, False]), - IS_GALLERY_KEY: torch.tensor([True, True, True]), - CATEGORIES_KEY: torch.tensor([10, 20, 20]), - } + dataset = EmbeddingsQueryGalleryLabeledDataset( + embeddings=torch.stack([oh(1), oh(0), oh(0), oh(0), oh(1), oh(1)]).float(), # all are errors + labels=torch.tensor([0, 1, 1, 0, 1, 1]).long(), + is_query=torch.tensor([True, True, True, False, False, False]).bool(), + is_gallery=torch.tensor([False, False, False, True, True, True]).bool(), + categories=np.array([10, 20, 20, 10, 20, 20]), + ) k = 1 metrics = defaultdict(lambda: defaultdict(dict)) # type: ignore @@ -123,38 +101,32 @@ def worst_case() -> Any: 
metrics[10]["cmc"][k] = 0 metrics[20]["cmc"][k] = 0 - return (batch1, batch2), (metrics, k) + return dataset, (metrics, k) @pytest.fixture() -def case_for_distance_check() -> Any: - batch1 = { - EMBEDDINGS_KEY: torch.stack([oh(1) * 2, oh(1) * 3, oh(0)]), - LABELS_KEY: torch.tensor([0, 1, 1]), - IS_QUERY_KEY: torch.tensor([True, True, True]), - IS_GALLERY_KEY: torch.tensor([False, False, False]), - CATEGORIES_KEY: torch.tensor([10, 20, 20]), - } - - batch2 = { - EMBEDDINGS_KEY: torch.stack([oh(0), oh(1), oh(1)]), - LABELS_KEY: torch.tensor([0, 1, 1]), - IS_QUERY_KEY: torch.tensor([False, False, False]), - IS_GALLERY_KEY: torch.tensor([True, True, True]), - CATEGORIES_KEY: torch.tensor([10, 20, 20]), - } - ids_ranked_by_distance = [0, 2, 1] - return (batch1, batch2), ids_ranked_by_distance +def case_for_finding_worst_queries() -> Any: + dataset = EmbeddingsQueryGalleryLabeledDataset( + embeddings=torch.stack([oh(0), oh(1), oh(2), oh(0), oh(5), oh(5)]).float(), # last 2 are errors + labels=torch.tensor([0, 1, 2, 0, 1, 2]).long(), + is_query=torch.tensor([True, True, True, False, False, False]).bool(), + is_gallery=torch.tensor([False, False, False, True, True, True]).bool(), + categories=np.array([10, 20, 20, 10, 20, 20]), + ) + + worst_two_queries = {1, 2} + return dataset, worst_two_queries def run_retrieval_metrics(case) -> None: # type: ignore - (batch1, batch2), (gt_metrics, k) = case + dataset, (gt_metrics, k) = case top_k = (k,) - num_samples = len(batch1[LABELS_KEY]) + len(batch2[LABELS_KEY]) + num_samples = len(dataset) calc = EmbeddingMetrics( - embeddings_key=EMBEDDINGS_KEY, + dataset=dataset, + embeddings_key=dataset.input_tensors_key, labels_key=LABELS_KEY, is_query_key=IS_QUERY_KEY, is_gallery_key=IS_GALLERY_KEY, @@ -168,8 +140,9 @@ def run_retrieval_metrics(case) -> None: # type: ignore ) calc.setup(num_samples=num_samples) - calc.update_data(batch1) - calc.update_data(batch2) + + for batch in DataLoader(dataset, batch_size=4, shuffle=False): + calc.update_data(batch) metrics = calc.compute_metrics() @@ -182,15 +155,15 @@ def run_retrieval_metrics(case) -> None: # type: ignore assert calc.acc.collected_samples == num_samples # type: ignore -def run_across_epochs(case1, case2) -> None: # type: ignore - (batch11, batch12), (gt_metrics1, k1) = case1 - (batch21, batch22), (gt_metrics2, k2) = case2 - assert k1 == k2 +def run_across_epochs(case) -> None: # type: ignore + dataset, (gt_metrics, k) = case - top_k = (k1,) + top_k = (k,) + num_samples = len(dataset) calc = EmbeddingMetrics( - embeddings_key=EMBEDDINGS_KEY, + dataset=dataset, + embeddings_key=dataset.input_tensors_key, labels_key=LABELS_KEY, is_query_key=IS_QUERY_KEY, is_gallery_key=IS_GALLERY_KEY, @@ -203,32 +176,22 @@ def run_across_epochs(case1, case2) -> None: # type: ignore postprocessor=get_trivial_postprocessor(top_n=3), ) - def epoch_case(batch_a, batch_b, ground_truth_metrics) -> None: # type: ignore - num_samples = len(batch_a[LABELS_KEY]) + len(batch_b[LABELS_KEY]) - calc.setup(num_samples=num_samples) - calc.update_data(batch_a) - calc.update_data(batch_b) - metrics = calc.compute_metrics() - - compare_dicts_recursively(metrics, ground_truth_metrics) + metrics_all_epochs = [] - # the euclidean distance between any one-hots is always sqrt(2) or 0 - assert compare_tensors_as_sets(calc.distance_matrix, torch.tensor([0, math.sqrt(2)])) # type: ignore + for _ in range(2): # epochs + calc.setup(num_samples=num_samples) - assert (calc.mask_gt.unique() == torch.tensor([0, 1])).all() # type: ignore - assert 
calc.acc.collected_samples == num_samples + for batch in DataLoader(dataset, batch_size=2, num_workers=0, shuffle=False, drop_last=False): + calc.update_data(batch) - # 1st epoch - epoch_case(batch11, batch12, gt_metrics1) + metrics_all_epochs.append(calc.compute_metrics()) - # 2nd epoch - epoch_case(batch21, batch22, gt_metrics2) + assert compare_dicts_recursively(metrics_all_epochs[0], metrics_all_epochs[-1]) - # 3d epoch - epoch_case(batch11, batch12, gt_metrics1) + # the euclidean distance between any one-hots is always sqrt(2) or 0 + assert compare_tensors_as_sets(calc.distance_matrix, torch.tensor([0, math.sqrt(2)])) - # 4th epoch - epoch_case(batch21, batch22, gt_metrics2) + assert calc.acc.collected_samples == num_samples def test_perfect_case(perfect_case) -> None: # type: ignore @@ -243,37 +206,37 @@ def test_worst_case(worst_case) -> None: # type: ignore run_retrieval_metrics(worst_case) -def test_mixed_epochs(perfect_case, imperfect_case, worst_case): # type: ignore - cases = [perfect_case, imperfect_case, worst_case] - for case1 in cases: - for case2 in cases: - run_across_epochs(case1, case2) +def test_several_epochs(perfect_case, imperfect_case, worst_case): # type: ignore + run_across_epochs(perfect_case) + run_across_epochs(imperfect_case) + run_across_epochs(worst_case) -def test_worst_k(case_for_distance_check) -> None: # type: ignore - (batch1, batch2), gt_ids = case_for_distance_check +def test_worst_k(case_for_finding_worst_queries) -> None: # type: ignore + dataset, worst_query_ids = case_for_finding_worst_queries - num_samples = len(batch1[LABELS_KEY]) + len(batch2[LABELS_KEY]) + num_samples = len(dataset) calc = EmbeddingMetrics( - embeddings_key=EMBEDDINGS_KEY, + dataset=dataset, + embeddings_key=dataset.input_tensors_key, labels_key=LABELS_KEY, is_query_key=IS_QUERY_KEY, is_gallery_key=IS_GALLERY_KEY, categories_key=CATEGORIES_KEY, - cmc_top_k=(), + cmc_top_k=(1,), precision_top_k=(), - map_top_k=(2,), + map_top_k=(), fmr_vals=tuple(), postprocessor=get_trivial_postprocessor(top_n=1_000), ) calc.setup(num_samples=num_samples) - calc.update_data(batch1) - calc.update_data(batch2) + for batch in DataLoader(dataset, batch_size=4, shuffle=False): + calc.update_data(batch) calc.compute_metrics() - assert calc.get_worst_queries_ids(f"{OVERALL_CATEGORIES_KEY}/map/2", 3) == gt_ids + assert set(calc.get_worst_queries_ids(f"{OVERALL_CATEGORIES_KEY}/cmc/1", 2)) == worst_query_ids @pytest.mark.parametrize("extra_keys", [[], [PATHS_KEY], [PATHS_KEY, "a"], ["a"]]) diff --git a/tests/test_oml/test_metrics/test_embedding_visualizations.py b/tests/test_oml/test_metrics/test_embedding_visualizations.py index 2c883187c..ee2e90df7 100644 --- a/tests/test_oml/test_metrics/test_embedding_visualizations.py +++ b/tests/test_oml/test_metrics/test_embedding_visualizations.py @@ -6,6 +6,7 @@ from oml.const import ( CATEGORIES_KEY, EMBEDDINGS_KEY, + INDEX_KEY, IS_GALLERY_KEY, IS_QUERY_KEY, LABELS_KEY, @@ -32,6 +33,7 @@ def test_visualization() -> None: IS_GALLERY_KEY: torch.tensor([False, False, False]), CATEGORIES_KEY: torch.tensor([10, 20, 20]), PATHS_KEY: [cf / "temp.png", cf / "temp.png", cf / "temp.png"], + INDEX_KEY: torch.tensor([0, 1, 2]), } batch2 = { @@ -41,6 +43,7 @@ def test_visualization() -> None: IS_GALLERY_KEY: torch.tensor([True, True, True]), CATEGORIES_KEY: torch.tensor([10, 20, 20]), PATHS_KEY: [cf / "temp.png", cf / "temp.png", cf / "temp.png"], + INDEX_KEY: torch.tensor([3, 4, 5]), } calc = EmbeddingMetrics( diff --git 
a/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py b/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py index 2b1eaaf15..8a8d1a5c3 100644 --- a/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py +++ b/tests/test_oml/test_postprocessor/test_pairwise_embeddings.py @@ -1,4 +1,3 @@ -import math from functools import partial from random import randint, random from typing import Tuple @@ -7,95 +6,102 @@ import torch from torch import Tensor -from oml.functional.metrics import calc_distance_matrix from oml.functional.metrics import ( calc_retrieval_metrics_on_full as calc_retrieval_metrics, ) +from oml.interfaces.datasets import IQueryGalleryDataset, IQueryGalleryLabeledDataset from oml.interfaces.models import IPairwiseModel from oml.models.meta.siamese import LinearTrivialDistanceSiamese -from oml.retrieval.postprocessors.pairwise import PairwiseEmbeddingsPostprocessor -from oml.utils.misc import flatten_dict, one_hot +from oml.retrieval.postprocessors.pairwise import PairwiseReranker +from oml.utils.misc import flatten_dict, one_hot, set_global_seed from oml.utils.misc_torch import normalise, pairwise_dist +from tests.test_integrations.utils import ( + EmbeddingsQueryGalleryDataset, + EmbeddingsQueryGalleryLabeledDataset, +) FEAT_SIZE = 8 oh = partial(one_hot, dim=FEAT_SIZE) @pytest.fixture -def independent_query_gallery_case() -> Tuple[Tensor, Tensor, Tensor]: +def independent_query_gallery_case() -> Tuple[IQueryGalleryDataset, Tensor]: sz = 7 feat_dim = 12 - embeddings = torch.randn((sz, feat_dim)) - embeddings = normalise(embeddings) - is_query = torch.ones(sz).bool() is_query[: sz // 2] = False is_gallery = torch.ones(sz).bool() is_gallery[sz // 2 :] = False - return embeddings, is_query, is_gallery + embeddings = normalise(torch.randn((sz, feat_dim))).float() + + dataset = EmbeddingsQueryGalleryDataset(embeddings=embeddings, is_query=is_query, is_gallery=is_gallery) + + embeddings_inference = embeddings.clone() # pretend it's our inference results + + return dataset, embeddings_inference @pytest.fixture -def shared_query_gallery_case() -> Tuple[Tensor, Tensor, Tensor]: +def shared_query_gallery_case() -> Tuple[IQueryGalleryDataset, Tensor]: sz = 7 feat_dim = 4 - embeddings = torch.randn((sz, feat_dim)) - embeddings = normalise(embeddings) + embeddings = normalise(torch.randn((sz, feat_dim))).float() - is_query = torch.ones(sz).bool() - is_gallery = torch.ones(sz).bool() + dataset = EmbeddingsQueryGalleryDataset( + embeddings=embeddings, is_query=torch.ones(sz).bool(), is_gallery=torch.ones(sz).bool() + ) + + embeddings_inference = embeddings.clone() # pretend it's our inference results - return embeddings, is_query, is_gallery + return dataset, embeddings_inference @pytest.mark.long @pytest.mark.parametrize("top_n", [2, 5, 100]) +@pytest.mark.parametrize("pairwise_distances_bias", [0, -5, +5]) @pytest.mark.parametrize("fixture_name", ["independent_query_gallery_case", "shared_query_gallery_case"]) def test_trivial_processing_does_not_change_distances_order( - request: pytest.FixtureRequest, fixture_name: str, top_n: int + request: pytest.FixtureRequest, fixture_name: str, top_n: int, pairwise_distances_bias: float ) -> None: - embeddings, is_query, is_gallery = request.getfixturevalue(fixture_name) - embeddings_query = embeddings[is_query] - embeddings_gallery = embeddings[is_gallery] + dataset, embeddings = request.getfixturevalue(fixture_name) - distances = calc_distance_matrix(embeddings, is_query, is_gallery) + distances = 
pairwise_dist(x1=embeddings[dataset.get_query_ids()], x2=embeddings[dataset.get_gallery_ids()], p=2) - model = LinearTrivialDistanceSiamese(feat_dim=embeddings.shape[-1], identity_init=True) - processor = PairwiseEmbeddingsPostprocessor(pairwise_model=model, top_n=top_n, num_workers=0, batch_size=64) + model = LinearTrivialDistanceSiamese(embeddings.shape[-1], output_bias=pairwise_distances_bias, identity_init=True) + processor = PairwiseReranker(pairwise_model=model, top_n=top_n, num_workers=0, batch_size=64) - distances_processed = processor.process( - queries=embeddings_query, - galleries=embeddings_gallery, - distances=distances.clone(), - ) + distances_processed = processor.process(distances=distances.clone(), dataset=dataset) - order = distances.argsort() - order_processed = distances_processed.argsort() + assert (distances_processed.argsort() == distances.argsort()).all() - assert (order == order_processed).all(), (order, order_processed) + if pairwise_distances_bias == 0: + assert torch.allclose(distances, distances_processed) + else: + assert not torch.allclose(distances, distances_processed) - if top_n <= is_gallery.sum(): - min_orig_distances = torch.topk(distances, k=top_n, largest=False).values - min_processed_distances = torch.topk(distances_processed, k=top_n, largest=False).values - assert torch.allclose(min_orig_distances, min_processed_distances) +def perfect_case() -> Tuple[IQueryGalleryLabeledDataset, Tensor]: + embeddings = torch.stack([oh(1), oh(2), oh(3), oh(1), oh(2), oh(1), oh(2), oh(3)]).float() -def perfect_case() -> Tuple[Tensor, Tensor, Tensor, Tensor]: - query_labels = torch.tensor([1, 2, 3]).long() - query_embeddings = torch.stack([oh(1), oh(2), oh(3)]) + dataset = EmbeddingsQueryGalleryLabeledDataset( + embeddings=embeddings, + labels=torch.tensor([1, 2, 3, 1, 2, 1, 2, 3]).long(), + is_query=torch.tensor([1, 1, 1, 1, 0, 0, 0, 0]).bool(), + is_gallery=torch.tensor([0, 0, 0, 0, 1, 1, 1, 1]).bool(), + ) - gallery_labels = torch.tensor([1, 2, 1, 2, 3]).long() - gallery_embeddings = torch.stack([oh(1), oh(2), oh(1), oh(2), oh(3)]) + embeddings_inference = embeddings.clone() - return query_embeddings, gallery_embeddings, query_labels, gallery_labels + return dataset, embeddings_inference @pytest.mark.long -def test_trivial_processing_fixes_broken_perfect_case() -> None: +@pytest.mark.parametrize("pairwise_distances_bias", [0, -100, +100]) +def test_trivial_processing_fixes_broken_perfect_case(pairwise_distances_bias: float) -> None: """ The idea of the test is the following: @@ -108,9 +114,12 @@ def test_trivial_processing_fixes_broken_perfect_case() -> None: n_repetitions = 20 for _ in range(n_repetitions): - query_embeddings, gallery_embeddings, query_labels, gallery_labels = perfect_case() - distances = pairwise_dist(query_embeddings, gallery_embeddings) - mask_gt = query_labels.unsqueeze(-1) == gallery_labels + dataset, embeddings = perfect_case() + distances = pairwise_dist(embeddings[dataset.get_query_ids()], embeddings[dataset.get_gallery_ids()], p=2) + + labels_q = torch.tensor(dataset.get_labels()[dataset.get_query_ids()]) + labels_g = torch.tensor(dataset.get_labels()[dataset.get_gallery_ids()]) + mask_gt = labels_q.unsqueeze(-1) == labels_g nq, ng = distances.shape @@ -130,9 +139,11 @@ def test_trivial_processing_fixes_broken_perfect_case() -> None: metrics = flatten_dict(calc_retrieval_metrics(distances=distances, **args)) # Metrics after broken distances have been fixed - model = LinearTrivialDistanceSiamese(feat_dim=gallery_embeddings.shape[-1], 
identity_init=True) - processor = PairwiseEmbeddingsPostprocessor(pairwise_model=model, top_n=top_n, batch_size=16, num_workers=0) - distances_upd = processor.process(distances, query_embeddings, gallery_embeddings) + model = LinearTrivialDistanceSiamese( + feat_dim=embeddings.shape[-1], identity_init=True, output_bias=pairwise_distances_bias + ) + processor = PairwiseReranker(pairwise_model=model, top_n=top_n, batch_size=16, num_workers=0) + distances_upd = processor.process(distances, dataset) metrics_upd = flatten_dict(calc_retrieval_metrics(distances=distances_upd, **args)) for key in metrics.keys(): @@ -141,46 +152,6 @@ def test_trivial_processing_fixes_broken_perfect_case() -> None: assert metric_upd >= metric, (key, metric, metric_upd) -class DummyPairwise(IPairwiseModel): - def __init__(self, distances_to_return: Tensor): - super(DummyPairwise, self).__init__() - self.distances_to_return = distances_to_return - self.parameter = torch.nn.Linear(1, 1) - - def forward(self, x1: Tensor, x2: Tensor) -> Tensor: - return self.distances_to_return - - def predict(self, x1: Tensor, x2: Tensor) -> Tensor: - return self.distances_to_return - - -@pytest.mark.long -def test_trivial_processing_fixes_broken_perfect_case_2() -> None: - """ - The idea of the test is similar to "test_trivial_processing_fixes_broken_perfect_case", - but this time we check the exact metrics values. - - """ - distances = torch.tensor([[0.8, 0.3, 0.2, 0.4, 0.5]]) - mask_gt = torch.tensor([[1, 1, 0, 1, 0]]).bool() - - args = {"mask_gt": mask_gt, "precision_top_k": (1, 3)} - - precisions = calc_retrieval_metrics(distances=distances, **args)["precision"] - assert math.isclose(precisions[1], 0) - assert math.isclose(precisions[3], 2 / 3, abs_tol=1e-5) - - # Now let's fix the error with dummy pairwise model - model = DummyPairwise(distances_to_return=torch.tensor([3.5, 2.5])) - processor = PairwiseEmbeddingsPostprocessor(pairwise_model=model, top_n=2, batch_size=128, num_workers=0) - distances_upd = processor.process( - distances=distances, queries=torch.randn((1, FEAT_SIZE)), galleries=torch.randn((5, FEAT_SIZE)) - ) - precisions_upd = calc_retrieval_metrics(distances=distances_upd, **args)["precision"] - assert math.isclose(precisions_upd[1], 1) - assert math.isclose(precisions_upd[3], 2 / 3, abs_tol=1e-5) - - class RandomPairwise(IPairwiseModel): def __init__(self): # type: ignore super(RandomPairwise, self).__init__() @@ -199,13 +170,19 @@ def test_processing_not_changing_non_sensitive_metrics(top_n: int) -> None: # The idea of the test is that postprocessing of first n elements # cannot change cmc@n and precision@n - # Let's construct some random input - query_embeddings_perfect, gallery_embeddings_perfect, query_labels, gallery_labels = perfect_case() - query_embeddings = torch.rand_like(query_embeddings_perfect) - gallery_embeddings = torch.rand_like(gallery_embeddings_perfect) - mask_gt = query_labels.unsqueeze(-1) == gallery_labels + set_global_seed(42) - distances = pairwise_dist(query_embeddings, gallery_embeddings) + # let's get some random inputs + dataset, embeddings = perfect_case() + embeddings = torch.randn_like(embeddings).float() + + top_n = min(top_n, embeddings.shape[1]) + + distances = pairwise_dist(embeddings[dataset.get_query_ids()], embeddings[dataset.get_gallery_ids()], p=2) + + labels_q = torch.tensor(dataset.get_labels()[dataset.get_query_ids()]) + labels_g = torch.tensor(dataset.get_labels()[dataset.get_gallery_ids()]) + mask_gt = labels_q.unsqueeze(-1) == labels_g args = { "cmc_top_k": (top_n,), 
@@ -215,11 +192,17 @@ def test_processing_not_changing_non_sensitive_metrics(top_n: int) -> None: } metrics_before = calc_retrieval_metrics(distances=distances, **args) + ii_closest_before = torch.argsort(distances) model = RandomPairwise() - processor = PairwiseEmbeddingsPostprocessor(pairwise_model=model, top_n=top_n, batch_size=4, num_workers=0) - distances_upd = processor.process(distances=distances, queries=query_embeddings, galleries=gallery_embeddings) + processor = PairwiseReranker(pairwise_model=model, top_n=top_n, batch_size=4, num_workers=0) + distances_upd = processor.process(distances=distances, dataset=dataset) metrics_after = calc_retrieval_metrics(distances=distances_upd, **args) + ii_closest_after = torch.argsort(distances_upd) assert metrics_before == metrics_after + + # also check that we only re-ranked the first top_n items + assert (ii_closest_before[:, :top_n] != ii_closest_after[:, :top_n]).any() + assert (ii_closest_before[:, top_n:] == ii_closest_after[:, top_n:]).all() diff --git a/tests/test_oml/test_postprocessor/test_pairwise_images.py b/tests/test_oml/test_postprocessor/test_pairwise_images.py index cc1477929..ba37b666f 100644 --- a/tests/test_oml/test_postprocessor/test_pairwise_images.py +++ b/tests/test_oml/test_postprocessor/test_pairwise_images.py @@ -1,73 +1,57 @@ from typing import Tuple -import numpy as np import pytest import torch from torch import Tensor, nn from oml.const import MOCK_DATASET_PATH -from oml.inference.flat import inference_on_images +from oml.datasets.images import ImageQueryGalleryLabeledDataset +from oml.inference import inference +from oml.interfaces.datasets import IQueryGalleryDataset from oml.models.meta.siamese import TrivialDistanceSiamese from oml.models.resnet.extractor import ResnetExtractor -from oml.retrieval.postprocessors.pairwise import PairwiseImagesPostprocessor +from oml.retrieval.postprocessors.pairwise import PairwiseReranker from oml.transforms.images.torchvision import get_normalisation_resize_torch from oml.transforms.images.utils import TTransforms from oml.utils.download_mock_dataset import download_mock_dataset from oml.utils.misc_torch import pairwise_dist -def get_validation_results(model: nn.Module, transforms: TTransforms) -> Tuple[Tensor, Tensor, Tensor]: +def get_validation_results(model: nn.Module, transforms: TTransforms) -> Tuple[Tensor, IQueryGalleryDataset]: _, df_val = download_mock_dataset(MOCK_DATASET_PATH) - is_query = np.array(df_val["is_query"]).astype(bool) - is_gallery = np.array(df_val["is_gallery"]).astype(bool) - paths = np.array(df_val["path"].apply(lambda x: MOCK_DATASET_PATH / x)) - queries = paths[is_query] - galleries = paths[is_gallery] + dataset = ImageQueryGalleryLabeledDataset(df=df_val, transform=transforms, dataset_root=MOCK_DATASET_PATH) - embeddings = inference_on_images( - model=model, - paths=paths.tolist(), - transform=transforms, - num_workers=0, - batch_size=4, - verbose=False, - use_fp16=True, - ) + embeddings = inference(model, dataset, batch_size=4) - distances = pairwise_dist(x1=embeddings[is_query], x2=embeddings[is_gallery], p=2) + distances = pairwise_dist(x1=embeddings[dataset.get_query_ids()], x2=embeddings[dataset.get_gallery_ids()], p=2) - return distances, queries, galleries + return distances, dataset @pytest.mark.long @pytest.mark.parametrize("top_n", [2, 5, 100]) -def test_trivial_processing_does_not_change_distances_order(top_n: int) -> None: +@pytest.mark.parametrize("pairwise_distances_bias", [0, -100, +100]) +def 
test_trivial_processing_does_not_change_distances_order(top_n: int, pairwise_distances_bias: float) -> None: extractor = ResnetExtractor(weights=None, arch="resnet18", normalise_features=True, gem_p=None, remove_fc=True) - - pairwise_model = TrivialDistanceSiamese(extractor) + pairwise_model = TrivialDistanceSiamese(extractor, output_bias=pairwise_distances_bias) transforms = get_normalisation_resize_torch(im_size=32) + distances, dataset = get_validation_results(model=extractor, transforms=transforms) - distances, queries, galleries = get_validation_results(model=extractor, transforms=transforms) - - postprocessor = PairwiseImagesPostprocessor( + postprocessor = PairwiseReranker( top_n=top_n, pairwise_model=pairwise_model, - transforms=transforms, num_workers=0, batch_size=4, verbose=False, use_fp16=True, ) - distances_processed = postprocessor.process(distances=distances.clone(), queries=queries, galleries=galleries) - - order = distances.argsort() - order_processed = distances_processed.argsort() + distances_processed = postprocessor.process(distances=distances.clone(), dataset=dataset) - assert (order == order_processed).all(), (order, order_processed) + assert (distances_processed.argsort() == distances.argsort()).all() - if top_n <= len(galleries): - min_orig_distances = torch.topk(distances, k=top_n, largest=False).values - min_processed_distances = torch.topk(distances_processed, k=top_n, largest=False).values - assert torch.allclose(min_orig_distances, min_processed_distances) + if pairwise_distances_bias == 0: + assert torch.allclose(distances_processed, distances) + else: + assert not torch.allclose(distances_processed, distances) diff --git a/tests/test_oml/test_retrieval_results/test_retrieval_results.py b/tests/test_oml/test_retrieval_results/test_retrieval_results.py index 701475299..c22c1666d 100644 --- a/tests/test_oml/test_retrieval_results/test_retrieval_results.py +++ b/tests/test_oml/test_retrieval_results/test_retrieval_results.py @@ -3,56 +3,92 @@ import matplotlib.pyplot as plt import pytest import torch +from torch import nn -from oml.const import IS_QUERY_COLUMN, LABELS_COLUMN, MOCK_DATASET_PATH, PATHS_COLUMN +from oml.const import LABELS_COLUMN, MOCK_DATASET_PATH, PATHS_COLUMN from oml.datasets.images import ( ImageQueryGalleryDataset, ImageQueryGalleryLabeledDataset, ) -from oml.inference.flat import inference_on_images +from oml.inference import inference +from oml.interfaces.datasets import IVisualizableDataset from oml.models import ResnetExtractor from oml.retrieval.retrieval_results import RetrievalResults -from oml.transforms.images.torchvision import get_normalisation_torch from oml.utils.download_mock_dataset import download_mock_dataset +from tests.test_integrations.utils import ( + EmbeddingsQueryGalleryDataset, + EmbeddingsQueryGalleryLabeledDataset, +) -@pytest.mark.parametrize("with_gt_labels", [False, True]) -@pytest.mark.parametrize("df_name", ["df.csv", "df_with_bboxes.csv", "df_with_sequence.csv"]) -def test_retrieval_results_om_images(with_gt_labels: bool, df_name: str) -> None: - # todo 522: add test on Embeddings after we merge unified inference +def get_model_and_datasets_images(with_gt_labels): # type: ignore + datasets = [] - _, df_val = download_mock_dataset(dataset_root=MOCK_DATASET_PATH, df_name=df_name) - df_val[PATHS_COLUMN] = df_val[PATHS_COLUMN].apply(lambda x: Path(MOCK_DATASET_PATH) / x) + for df_name in ["df.csv", "df_with_bboxes.csv", "df_with_sequence.csv"]: + _, df_val = download_mock_dataset(global_paths=True, 
+        df_val[PATHS_COLUMN] = df_val[PATHS_COLUMN].apply(lambda x: Path(MOCK_DATASET_PATH) / x)

-    n_query = df_val[IS_QUERY_COLUMN].sum()
+        if with_gt_labels:
+            dataset = ImageQueryGalleryLabeledDataset(df_val)
+        else:
+            del df_val[LABELS_COLUMN]
+            dataset = ImageQueryGalleryDataset(df_val)

-    if with_gt_labels:
-        dataset = ImageQueryGalleryLabeledDataset(df_val)
-    else:
-        del df_val[LABELS_COLUMN]
-        dataset = ImageQueryGalleryDataset(df_val)
+        datasets.append(dataset)

     model = ResnetExtractor(weights=None, arch="resnet18", gem_p=None, remove_fc=True, normalise_features=False)

-    embeddings = inference_on_images(
-        model=model,
-        paths=df_val[PATHS_COLUMN].tolist(),
-        transform=get_normalisation_torch(),
-        num_workers=0,
-        batch_size=4,
-    ).float()
-    top_n = 2
-    rr = RetrievalResults.compute_from_embeddings(embeddings=embeddings, dataset=dataset, n_items_to_retrieve=top_n)
+    return datasets, model
+

-    assert rr.distances.shape == (n_query, top_n)
-    assert rr.retrieved_ids.shape == (n_query, top_n)
-    assert torch.allclose(rr.distances.clone().sort()[0], rr.distances)
+def get_model_and_datasets_embeddings(with_gt_labels):  # type: ignore
+    embeddings = torch.randn((6, 4)).float()
+    is_query = torch.tensor([1, 1, 1, 0, 0, 0]).bool()
+    is_gallery = torch.tensor([0, 0, 0, 1, 1, 1]).bool()

     if with_gt_labels:
-        assert rr.gt_ids is not None
+        labels = torch.tensor([0, 1, 0, 1, 0, 1]).long()
+        dataset = EmbeddingsQueryGalleryLabeledDataset(
+            embeddings=embeddings, labels=labels, is_query=is_query, is_gallery=is_gallery
+        )
+    else:
+        dataset = EmbeddingsQueryGalleryDataset(embeddings=embeddings, is_query=is_query, is_gallery=is_gallery)
+
+    model = nn.Linear(4, 1)
+
+    return [dataset], model
+
+
+@pytest.mark.parametrize("with_gt_labels", [False, True])
+@pytest.mark.parametrize("data_getter", [get_model_and_datasets_embeddings, get_model_and_datasets_images])
+def test_retrieval_results_om_images(with_gt_labels, data_getter) -> None:  # type: ignore
+    datasets, model = data_getter(with_gt_labels=with_gt_labels)
+
+    for dataset in datasets:
+
+        n_query = len(dataset.get_query_ids())
+
+        embeddings = inference(model=model, dataset=dataset, num_workers=0, batch_size=4).float()
+
+        top_n = 2
+        rr = RetrievalResults.compute_from_embeddings(embeddings=embeddings, dataset=dataset, n_items_to_retrieve=top_n)
+
+        assert rr.distances.shape == (n_query, top_n)
+        assert rr.retrieved_ids.shape == (n_query, top_n)
+        assert torch.allclose(rr.distances.clone().sort()[0], rr.distances)
+
+        if with_gt_labels:
+            assert rr.gt_ids is not None

-    fig = rr.visualize(query_ids=[0, 3], dataset=dataset, n_galleries_to_show=3)
-    fig.show()
-    plt.close(fig=fig)
+        error_expected = not isinstance(dataset, IVisualizableDataset)
+        if error_expected:
+            with pytest.raises(TypeError):
+                fig = rr.visualize(query_ids=[0, 3], dataset=dataset, n_galleries_to_show=3)
+                fig.show()
+                plt.close(fig=fig)
+        else:
+            fig = rr.visualize(query_ids=[0, 3], dataset=dataset, n_galleries_to_show=3)
+            fig.show()
+            plt.close(fig=fig)

     assert True
diff --git a/tests/test_oml/test_utils/test_misc_torch.py b/tests/test_oml/test_utils/test_misc_torch.py
index 116e03d2a..7b9c4d06b 100644
--- a/tests/test_oml/test_utils/test_misc_torch.py
+++ b/tests/test_oml/test_utils/test_misc_torch.py
@@ -8,6 +8,7 @@
     PCA,
     TData,
     assign_2d,
+    cat_two_sorted_tensors_and_keep_it_sorted,
     elementwise_dist,
     take_2d,
     unique_by_ids,
@@ -24,6 +25,41 @@ def test_elementwise_dist() -> None:
     assert torch.isclose(val_torch, torch.tensor(val_custom)).all()

+@pytest.mark.parametrize(
+    "x1,x2,e,expected",
+    [
+        (
+            # x1
+            torch.tensor([[10, 20, 30], [40, 50, 60]]).float(),
+            # x2
+            torch.tensor([[0.3, 0.4, 0.5], [0.6, 0.8, 0.9]]).float(),
+            # e
+            0.001,
+            # expected: rescaling is needed
+            torch.tensor(
+                [
+                    [0.1 - 0.001, 0.2 - 0.001, 0.3 - 0.001, 0.3, 0.4, 0.5],
+                    [0.4 - 0.001, 0.5 - 0.001, 0.6 - 0.001, 0.6, 0.8, 0.9],
+                ]
+            ).float(),
+        ),
+        (
+            # x1
+            torch.tensor([[-10, -5], [-20, -8]]).float(),
+            # x2
+            torch.tensor([[0.3, 0.4, 0.5], [0.6, 0.8, 0.9]]).float(),
+            # e
+            0.001,
+            # expected: rescaling is not needed, we just concat
+            torch.tensor([[-10, -5, 0.3, 0.4, 0.5], [-20, -8, 0.6, 0.8, 0.9]]).float(),
+        ),
+    ],
+)
+def test_concat_two_sorted_tensors_with_rescaling(x1, x2, e, expected):  # type: ignore
+    out = cat_two_sorted_tensors_and_keep_it_sorted(x1, x2, eps=e)
+    assert torch.isclose(expected, out).all()
+
+
 # fmt: off
 def test_take_2d() -> None:
     x = torch.tensor([
diff --git a/tests/test_runs/test_pipelines/configs/train_postprocessor.yaml b/tests/test_runs/test_pipelines/configs/train_postprocessor.yaml
index 22dbabd68..c8a1e34f0 100644
--- a/tests/test_runs/test_pipelines/configs/train_postprocessor.yaml
+++ b/tests/test_runs/test_pipelines/configs/train_postprocessor.yaml
@@ -1,7 +1,7 @@
 postfix: "postprocessing"

 seed: 42
-precision: 16
+precision: 32
 accelerator: cpu
 devices: 2
 find_unused_parameters: False
@@ -78,15 +78,10 @@ transforms_train:
 batch_size_inference: 128

 postprocessor:
-  name: pairwise_images
+  name: pairwise_reranker
   args:
     top_n: 5
     pairwise_model: ${pairwise_model}
-    transforms:
-      name: norm_resize_hypvit_torch
-      args:
-        im_size: 32
-        crop_size: 32
     num_workers: 0
     batch_size: ${batch_size_inference}
     verbose: True
diff --git a/tests/test_runs/test_pipelines/configs/validate.yaml b/tests/test_runs/test_pipelines/configs/validate.yaml
index 5096f57a6..1e9d53f11 100644
--- a/tests/test_runs/test_pipelines/configs/validate.yaml
+++ b/tests/test_runs/test_pipelines/configs/validate.yaml
@@ -16,7 +16,7 @@ num_workers: 0
 bs_val: 2

 postprocessor:
-  name: pairwise_images
+  name: pairwise_reranker
   args:
     top_n: 3
     pairwise_model:
@@ -30,10 +30,6 @@ postprocessor:
       remove_fc: True
       normalise_features: False
       weights: resnet50_moco_v2
-    transforms:
-      name: norm_resize_torch
-      args:
-        im_size: 64
     num_workers: 0
     batch_size: 4
     verbose: True
diff --git a/tests/test_runs/test_pipelines/predict.py b/tests/test_runs/test_pipelines/predict.py
index f10c0e9a4..7c5f3b43e 100644
--- a/tests/test_runs/test_pipelines/predict.py
+++ b/tests/test_runs/test_pipelines/predict.py
@@ -11,7 +11,7 @@ def main_hydra(cfg: DictConfig) -> None:
     cfg = dictconfig_to_dict(cfg)
     download_mock_dataset(MOCK_DATASET_PATH)
-    cfg["data_dir"] = MOCK_DATASET_PATH
+    cfg["data_dir"] = str(MOCK_DATASET_PATH)
     extractor_prediction_pipeline(cfg)
diff --git a/tests/test_runs/test_pipelines/train.py b/tests/test_runs/test_pipelines/train.py
index 2505e4c54..3b16fefe0 100644
--- a/tests/test_runs/test_pipelines/train.py
+++ b/tests/test_runs/test_pipelines/train.py
@@ -11,7 +11,7 @@ def main_hydra(cfg: DictConfig) -> None:
     cfg = dictconfig_to_dict(cfg)
     download_mock_dataset(MOCK_DATASET_PATH)
-    cfg["dataset_root"] = MOCK_DATASET_PATH
+    cfg["dataset_root"] = str(MOCK_DATASET_PATH)
     extractor_training_pipeline(cfg)
diff --git a/tests/test_runs/test_pipelines/train_arcface_with_categories.py b/tests/test_runs/test_pipelines/train_arcface_with_categories.py
index 65f30807e..74177ba3e 100644
--- a/tests/test_runs/test_pipelines/train_arcface_with_categories.py
+++ b/tests/test_runs/test_pipelines/train_arcface_with_categories.py
@@ -11,7 +11,7 @@ def main_hydra(cfg: DictConfig) -> None:
     cfg = dictconfig_to_dict(cfg)
     download_mock_dataset(MOCK_DATASET_PATH)
-    cfg["dataset_root"] = MOCK_DATASET_PATH
+    cfg["dataset_root"] = str(MOCK_DATASET_PATH)
     extractor_training_pipeline(cfg)
diff --git a/tests/test_runs/test_pipelines/train_postprocessor.py b/tests/test_runs/test_pipelines/train_postprocessor.py
index 6ac23cfeb..f01a6a516 100644
--- a/tests/test_runs/test_pipelines/train_postprocessor.py
+++ b/tests/test_runs/test_pipelines/train_postprocessor.py
@@ -11,7 +11,7 @@ def main_hydra(cfg: DictConfig) -> None:
     cfg = dictconfig_to_dict(cfg)
     download_mock_dataset(MOCK_DATASET_PATH)
-    cfg["dataset_root"] = MOCK_DATASET_PATH
+    cfg["dataset_root"] = str(MOCK_DATASET_PATH)
     postprocessor_training_pipeline(cfg)
diff --git a/tests/test_runs/test_pipelines/train_with_categories.py b/tests/test_runs/test_pipelines/train_with_categories.py
index 90328ac9e..db967f38f 100644
--- a/tests/test_runs/test_pipelines/train_with_categories.py
+++ b/tests/test_runs/test_pipelines/train_with_categories.py
@@ -11,7 +11,7 @@ def main_hydra(cfg: DictConfig) -> None:
     cfg = dictconfig_to_dict(cfg)
     download_mock_dataset(MOCK_DATASET_PATH)
-    cfg["dataset_root"] = MOCK_DATASET_PATH
+    cfg["dataset_root"] = str(MOCK_DATASET_PATH)
     extractor_training_pipeline(cfg)
diff --git a/tests/test_runs/test_pipelines/train_with_sequence.py b/tests/test_runs/test_pipelines/train_with_sequence.py
index ac7d1597f..3ecaa5701 100644
--- a/tests/test_runs/test_pipelines/train_with_sequence.py
+++ b/tests/test_runs/test_pipelines/train_with_sequence.py
@@ -11,7 +11,7 @@ def main_hydra(cfg: DictConfig) -> None:
    cfg = dictconfig_to_dict(cfg)
     download_mock_dataset(MOCK_DATASET_PATH)
-    cfg["dataset_root"] = MOCK_DATASET_PATH
+    cfg["dataset_root"] = str(MOCK_DATASET_PATH)
     extractor_training_pipeline(cfg)
diff --git a/tests/test_runs/test_pipelines/validate.py b/tests/test_runs/test_pipelines/validate.py
index 1226f9a97..c8890e6a4 100644
--- a/tests/test_runs/test_pipelines/validate.py
+++ b/tests/test_runs/test_pipelines/validate.py
@@ -11,7 +11,7 @@ def main_hydra(cfg: DictConfig) -> None:
     cfg = dictconfig_to_dict(cfg)
     download_mock_dataset(MOCK_DATASET_PATH)
-    cfg["dataset_root"] = MOCK_DATASET_PATH
+    cfg["dataset_root"] = str(MOCK_DATASET_PATH)
     extractor_validation_pipeline(cfg)
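For reference, and not as part of the patch: the renamed `PairwiseReranker` is exercised in the tests above roughly as sketched below. This is a minimal sketch assembled only from calls that appear in this diff (mock dataset, `ImageQueryGalleryLabeledDataset`, `inference`, `pairwise_dist`), so argument defaults and exact signatures should be double-checked against the library.

```python
# Minimal usage sketch of the new reranking API, assembled from the calls used in the tests above.
# Assumption: mirrors the OML 3.x API shown in this diff; it is not part of the patch itself.
from oml.const import MOCK_DATASET_PATH
from oml.datasets.images import ImageQueryGalleryLabeledDataset
from oml.inference import inference
from oml.models.meta.siamese import TrivialDistanceSiamese
from oml.models.resnet.extractor import ResnetExtractor
from oml.retrieval.postprocessors.pairwise import PairwiseReranker
from oml.transforms.images.torchvision import get_normalisation_resize_torch
from oml.utils.download_mock_dataset import download_mock_dataset
from oml.utils.misc_torch import pairwise_dist

# build a tiny query/gallery dataset from the mock data
_, df_val = download_mock_dataset(MOCK_DATASET_PATH)
transforms = get_normalisation_resize_torch(im_size=32)
dataset = ImageQueryGalleryLabeledDataset(df=df_val, transform=transforms, dataset_root=MOCK_DATASET_PATH)

# embed all items and compute query-to-gallery distances
extractor = ResnetExtractor(weights=None, arch="resnet18", normalise_features=True, gem_p=None, remove_fc=True)
embeddings = inference(extractor, dataset, batch_size=4)
distances = pairwise_dist(x1=embeddings[dataset.get_query_ids()], x2=embeddings[dataset.get_gallery_ids()], p=2)

# re-rank: only the top_n closest galleries per query are re-scored by the pairwise model
reranker = PairwiseReranker(
    top_n=3,
    pairwise_model=TrivialDistanceSiamese(extractor, output_bias=0),
    num_workers=0,
    batch_size=4,
    verbose=False,
    use_fp16=True,
)
distances_reranked = reranker.process(distances=distances, dataset=dataset)
```

Note that, unlike the removed `PairwiseImagesPostprocessor`, the reranker takes the dataset instead of raw query/gallery paths and no longer needs its own `transforms`, which is why the `transforms` blocks disappear from the configs above.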