Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Made inference modality agnostic in re-ranking and other parts of the repo #542

Merged
merged 27 commits into from
Apr 28, 2024
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 19 additions & 23 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -301,13 +301,12 @@ from oml.models import ViTExtractor
from oml.samplers.balance import BalanceSampler
from oml.utils.download_mock_dataset import download_mock_dataset

dataset_root = "mock_dataset/"
df_train, _ = download_mock_dataset(dataset_root)
df_train, _ = download_mock_dataset(global_paths=True)

extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train()
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)

train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root)
train_dataset = DatasetWithLabels(df_train)
criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner(), need_logs=True)
sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=sampler)
Expand Down Expand Up @@ -342,12 +341,11 @@ from oml.metrics.embeddings import EmbeddingMetrics
from oml.models import ViTExtractor
from oml.utils.download_mock_dataset import download_mock_dataset

dataset_root = "mock_dataset/"
_, df_val = download_mock_dataset(dataset_root)
_, df_val = download_mock_dataset(global_paths=True)

extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).eval()

val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root)
val_dataset = DatasetQueryGallery(df_val)

val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)
calculator = EmbeddingMetrics(extra_keys=("paths",))
Expand Down Expand Up @@ -401,21 +399,20 @@ from oml.lightning.pipelines.logging import (
WandBPipelineLogger,
)

dataset_root = "mock_dataset/"
df_train, df_val = download_mock_dataset(dataset_root)
df_train, df_val = download_mock_dataset(global_paths=True)

# model
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False)

# train
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)
train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root)
train_dataset = DatasetWithLabels(df_train)
criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner())
batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=3)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=batch_sampler)

# val
val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root)
val_dataset = DatasetQueryGallery(df_val)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)
metric_callback = MetricValCallback(metric=EmbeddingMetrics(extra_keys=[train_dataset.paths_key,]), log_images=True)

Expand Down Expand Up @@ -455,37 +452,36 @@ trainer.fit(pl_model, train_dataloaders=train_loader, val_dataloaders=val_loader
```python
import torch

from oml.const import MOCK_DATASET_PATH
from oml.inference.flat import inference_on_images
from oml.datasets import ImageQueryGalleryDataset
from oml.inference import inference
from oml.models import ViTExtractor
from oml.registry.transforms import get_transforms_for_pretrained
from oml.utils.download_mock_dataset import download_mock_dataset
from oml.utils.misc_torch import pairwise_dist

_, df_val = download_mock_dataset(MOCK_DATASET_PATH)
df_val["path"] = df_val["path"].apply(lambda x: MOCK_DATASET_PATH / x)
queries = df_val[df_val["is_query"]]["path"].tolist()
galleries = df_val[df_val["is_gallery"]]["path"].tolist()
_, df_test = download_mock_dataset(global_paths=True)
del df_test["label"] # we don't need gt labels for doing predictions

extractor = ViTExtractor.from_pretrained("vits16_dino")
transform, _ = get_transforms_for_pretrained("vits16_dino")

args = {"num_workers": 0, "batch_size": 8}
features_queries = inference_on_images(extractor, paths=queries, transform=transform, **args)
features_galleries = inference_on_images(extractor, paths=galleries, transform=transform, **args)
dataset = ImageQueryGalleryDataset(df_test, transform=transform)

embeddings = inference(extractor, dataset, batch_size=4, num_workers=0)
embeddings_query = embeddings[dataset.get_query_ids()]
embeddings_gallery = embeddings[dataset.get_gallery_ids()]

# Now we can explicitly build a pairwise matrix of distances, or save your RAM by using kNN
use_knn = False
top_k = 3

if use_knn:
from sklearn.neighbors import NearestNeighbors
knn = NearestNeighbors(algorithm="auto", p=2)
knn.fit(features_galleries)
dists, ii_closest = knn.kneighbors(features_queries, n_neighbors=top_k, return_distance=True)
knn = NearestNeighbors(algorithm="auto", p=2).fit(embeddings_query)
dists, ii_closest = knn.kneighbors(embeddings_gallery, n_neighbors=top_k, return_distance=True)

else:
dist_mat = pairwise_dist(x1=features_queries, x2=features_galleries)
dist_mat = pairwise_dist(x1=embeddings_query, x2=embeddings_gallery, p=2)
dists, ii_closest = torch.topk(dist_mat, dim=1, k=top_k, largest=False)

print(f"Top {top_k} items closest to queries are:\n {ii_closest}")
Expand Down
25 changes: 12 additions & 13 deletions docs/readme/examples_source/extractor/retrieval_usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,36 @@
```python
import torch

from oml.const import MOCK_DATASET_PATH
from oml.inference.flat import inference_on_images
from oml.datasets import ImageQueryGalleryDataset
from oml.inference import inference
from oml.models import ViTExtractor
from oml.registry.transforms import get_transforms_for_pretrained
from oml.utils.download_mock_dataset import download_mock_dataset
from oml.utils.misc_torch import pairwise_dist

_, df_val = download_mock_dataset(MOCK_DATASET_PATH)
df_val["path"] = df_val["path"].apply(lambda x: MOCK_DATASET_PATH / x)
queries = df_val[df_val["is_query"]]["path"].tolist()
galleries = df_val[df_val["is_gallery"]]["path"].tolist()
_, df_test = download_mock_dataset(global_paths=True)
del df_test["label"] # we don't need gt labels for doing predictions

extractor = ViTExtractor.from_pretrained("vits16_dino")
transform, _ = get_transforms_for_pretrained("vits16_dino")

args = {"num_workers": 0, "batch_size": 8}
features_queries = inference_on_images(extractor, paths=queries, transform=transform, **args)
features_galleries = inference_on_images(extractor, paths=galleries, transform=transform, **args)
dataset = ImageQueryGalleryDataset(df_test, transform=transform)

embeddings = inference(extractor, dataset, batch_size=4, num_workers=0)
embeddings_query = embeddings[dataset.get_query_ids()]
embeddings_gallery = embeddings[dataset.get_gallery_ids()]

# Now we can explicitly build a pairwise matrix of distances, or save your RAM by using kNN
use_knn = False
top_k = 3

if use_knn:
from sklearn.neighbors import NearestNeighbors
knn = NearestNeighbors(algorithm="auto", p=2)
knn.fit(features_galleries)
dists, ii_closest = knn.kneighbors(features_queries, n_neighbors=top_k, return_distance=True)
knn = NearestNeighbors(algorithm="auto", p=2).fit(embeddings_query)
dists, ii_closest = knn.kneighbors(embeddings_gallery, n_neighbors=top_k, return_distance=True)

else:
dist_mat = pairwise_dist(x1=features_queries, x2=features_galleries)
dist_mat = pairwise_dist(x1=embeddings_query, x2=embeddings_gallery, p=2)
dists, ii_closest = torch.topk(dist_mat, dim=1, k=top_k, largest=False)

print(f"Top {top_k} items closest to queries are:\n {ii_closest}")
Expand Down
5 changes: 2 additions & 3 deletions docs/readme/examples_source/extractor/train.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@ from oml.models import ViTExtractor
from oml.samplers.balance import BalanceSampler
from oml.utils.download_mock_dataset import download_mock_dataset

dataset_root = "mock_dataset/"
df_train, _ = download_mock_dataset(dataset_root)
df_train, _ = download_mock_dataset(global_paths=True)

extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train()
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)

train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root)
train_dataset = DatasetWithLabels(df_train)
criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner(), need_logs=True)
sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=sampler)
Expand Down
9 changes: 3 additions & 6 deletions docs/readme/examples_source/extractor/train_2loaders_val.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,18 @@ from oml.models import ViTExtractor
from oml.transforms.images.torchvision import get_normalisation_resize_torch
from oml.utils.download_mock_dataset import download_mock_dataset

dataset_root = "mock_dataset/"
_, df_val = download_mock_dataset(dataset_root)
_, df_val = download_mock_dataset(global_paths=True)

extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False)

# 1st validation dataset (big images)
val_dataset_1 = DatasetQueryGallery(df_val, dataset_root=dataset_root,
transform=get_normalisation_resize_torch(im_size=224))
val_dataset_1 = DatasetQueryGallery(df_val, transform=get_normalisation_resize_torch(im_size=224))
val_loader_1 = torch.utils.data.DataLoader(val_dataset_1, batch_size=4)
metric_callback_1 = MetricValCallback(metric=EmbeddingMetrics(extra_keys=[val_dataset_1.paths_key,]),
log_images=True, loader_idx=0)

# 2nd validation dataset (small images)
val_dataset_2 = DatasetQueryGallery(df_val, dataset_root=dataset_root,
transform=get_normalisation_resize_torch(im_size=48))
val_dataset_2 = DatasetQueryGallery(df_val, transform=get_normalisation_resize_torch(im_size=48))
val_loader_2 = torch.utils.data.DataLoader(val_dataset_2, batch_size=4)
metric_callback_2 = MetricValCallback(metric=EmbeddingMetrics(extra_keys=[val_dataset_2.paths_key,]),
log_images=True, loader_idx=1)
Expand Down
7 changes: 3 additions & 4 deletions docs/readme/examples_source/extractor/train_val_pl.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,20 @@ from oml.lightning.pipelines.logging import (
WandBPipelineLogger,
)

dataset_root = "mock_dataset/"
df_train, df_val = download_mock_dataset(dataset_root)
df_train, df_val = download_mock_dataset(global_paths=True)

# model
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False)

# train
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)
train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root)
train_dataset = DatasetWithLabels(df_train)
criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner())
batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=3)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=batch_sampler)

# val
val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root)
val_dataset = DatasetQueryGallery(df_val)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)
metric_callback = MetricValCallback(metric=EmbeddingMetrics(extra_keys=[train_dataset.paths_key,]), log_images=True)

Expand Down
7 changes: 3 additions & 4 deletions docs/readme/examples_source/extractor/train_val_pl_ddp.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,20 @@ from oml.samplers.balance import BalanceSampler
from oml.utils.download_mock_dataset import download_mock_dataset
from pytorch_lightning.strategies import DDPStrategy

dataset_root = "mock_dataset/"
df_train, df_val = download_mock_dataset(dataset_root)
df_train, df_val = download_mock_dataset(global_paths=True)

# model
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False)

# train
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)
train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root)
train_dataset = DatasetWithLabels(df_train)
criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner())
batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=3)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=batch_sampler)

# val
val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root)
val_dataset = DatasetQueryGallery(df_val)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)
metric_callback = MetricValCallbackDDP(metric=EmbeddingMetricsDDP()) # DDP specific

Expand Down
5 changes: 2 additions & 3 deletions docs/readme/examples_source/extractor/train_with_pml.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,12 @@ from oml.utils.download_mock_dataset import download_mock_dataset

from pytorch_metric_learning import losses, distances, reducers, miners

dataset_root = "mock_dataset/"
df_train, _ = download_mock_dataset(dataset_root)
df_train, _ = download_mock_dataset(global_paths=True)

extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train()
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)

train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root)
train_dataset = DatasetWithLabels(df_train)

# PML specific
# criterion = losses.TripletMarginLoss(margin=0.2, triplets_per_anchor="all")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,12 @@ from oml.utils.download_mock_dataset import download_mock_dataset

from pytorch_metric_learning import losses, distances, reducers, miners

dataset_root = "mock_dataset/"
df_train, _ = download_mock_dataset(dataset_root)
df_train, _ = download_mock_dataset(global_paths=True)

extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train()
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)

train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root)
train_dataset = DatasetWithLabels(df_train)

# PML specific
distance = distances.LpDistance(p=2)
Expand Down
5 changes: 2 additions & 3 deletions docs/readme/examples_source/extractor/val.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,11 @@ from oml.metrics.embeddings import EmbeddingMetrics
from oml.models import ViTExtractor
from oml.utils.download_mock_dataset import download_mock_dataset

dataset_root = "mock_dataset/"
_, df_val = download_mock_dataset(dataset_root)
_, df_val = download_mock_dataset(global_paths=True)

extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).eval()

val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root)
val_dataset = DatasetQueryGallery(df_val)

val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)
calculator = EmbeddingMetrics(extra_keys=("paths",))
Expand Down
5 changes: 2 additions & 3 deletions docs/readme/examples_source/extractor/val_with_sequence.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,11 @@ from oml.metrics.embeddings import EmbeddingMetrics
from oml.models import ViTExtractor
from oml.utils.download_mock_dataset import download_mock_dataset

dataset_root = "mock_dataset/"
_, df_val = download_mock_dataset(dataset_root, df_name="df_with_sequence.csv") # <- sequence info is in the file
_, df_val = download_mock_dataset(global_paths=True, df_name="df_with_sequence.csv") # <- sequence info is in the file

extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).eval()

val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root)
val_dataset = DatasetQueryGallery(df_val)

val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)
calculator = EmbeddingMetrics(extra_keys=("paths",), sequence_key=val_dataset.sequence_key)
Expand Down
42 changes: 19 additions & 23 deletions docs/readme/examples_source/postprocessing/predict.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,44 +5,40 @@
[comment]:postprocessor-pred-start
```python
import torch
from torch.utils.data import DataLoader

from oml.const import PATHS_COLUMN
from oml.datasets.base import DatasetQueryGallery
from oml.inference.flat import inference_on_dataframe
from oml.datasets import ImageQueryGalleryDataset
from oml.inference import inference
from oml.models import ConcatSiamese, ViTExtractor
from oml.registry.transforms import get_transforms_for_pretrained
from oml.retrieval.postprocessors.pairwise import PairwiseImagesPostprocessor
from oml.retrieval.postprocessors.pairwise import PairwiseReranker
from oml.utils.download_mock_dataset import download_mock_dataset
from oml.utils.misc_torch import pairwise_dist

dataset_root = "mock_dataset/"
download_mock_dataset(dataset_root)
_, df_test = download_mock_dataset(global_paths=True)
del df_test["label"] # we don't need gt labels for doing predictions

# 1. Let's use feature extractor to get predictions
extractor = ViTExtractor.from_pretrained("vits16_dino")
transforms, _ = get_transforms_for_pretrained("vits16_dino")

_, emb_val, _, df_val = inference_on_dataframe(dataset_root, "df.csv", extractor, transforms=transforms)
dataset = ImageQueryGalleryDataset(df_test, transform=transforms)

is_query = df_val["is_query"].astype('bool').values
distances = pairwise_dist(x1=emb_val[is_query], x2=emb_val[~is_query])
# 1. Let's get top 5 galleries closest to every query...
embeddings = inference(extractor, dataset, batch_size=4, num_workers=0)
embeddings_query = embeddings[dataset.get_query_ids()]
embeddings_gallery = embeddings[dataset.get_gallery_ids()]

print("\nOriginal predictions:\n", torch.topk(distances, dim=1, k=3, largest=False)[1])
distances = pairwise_dist(x1=embeddings_query, x2=embeddings_gallery, p=2)
ii_closest = torch.topk(distances, dim=1, k=5, largest=False)[1]

# 2. Let's initialise a random pairwise postprocessor to perform re-ranking
# 2. ... and let's re-rank first 3 of them
siamese = ConcatSiamese(extractor=extractor, mlp_hidden_dims=[100]) # Note! Replace it with your trained postprocessor
postprocessor = PairwiseImagesPostprocessor(top_n=3, pairwise_model=siamese, transforms=transforms)

dataset = DatasetQueryGallery(df_val, extra_data={"embeddings": emb_val}, transform=transforms)
loader = DataLoader(dataset, batch_size=4)

query_paths = df_val[PATHS_COLUMN][is_query].values
gallery_paths = df_val[PATHS_COLUMN][~is_query].values
distances_upd = postprocessor.process(distances=distances, queries=query_paths, galleries=gallery_paths)

print("\nPredictions after postprocessing:\n", torch.topk(distances_upd, dim=1, k=3, largest=False)[1])
postprocessor = PairwiseReranker(top_n=3, pairwise_model=siamese, batch_size=4, num_workers=0)
distances_upd = postprocessor.process(distances, dataset=dataset)
ii_closest_upd = torch.topk(distances_upd, dim=1, k=5, largest=False)[1]

# You may see the first 3 positions have changed, but the rest remain the same:
print("\nClosest galleries:\n", ii_closest)
print("\nClosest galleries updated:\n", ii_closest_upd)
```
[comment]:postprocessor-pred-end
</p>
Expand Down
Loading
Loading