# Training with a distance, reducer, miner, and loss from PyTorch Metric Learning (PML)
import torch
from tqdm import tqdm
from oml.datasets.base import DatasetWithLabels
from oml.models import ViTExtractor
from oml.samplers.balance import BalanceSampler
from oml.utils.download_mock_dataset import download_mock_dataset
from pytorch_metric_learning import losses, distances, reducers, miners
# Mock dataset helper; only the training dataframe is used in this example.
df_train, _ = download_mock_dataset(global_paths=True)

extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train()
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)

train_dataset = DatasetWithLabels(df_train)

# PML specific.
# The miner and the loss share one distance and one margin, so the loss is
# computed under the same criterion that was used to mine the triplets.
# The distance and reducer must be passed to the loss explicitly; otherwise
# the loss falls back to its own defaults and `reducer` is silently unused.
margin = 0.2
distance = distances.LpDistance(p=2)
reducer = reducers.ThresholdReducer(low=0)
criterion = losses.TripletMarginLoss(margin=margin, distance=distance, reducer=reducer)
miner = miners.TripletMarginMiner(margin=margin, distance=distance, type_of_triplets="all")

# Batches are produced by BalanceSampler (n_labels / n_instances control the
# batch composition), so the DataLoader takes it as `batch_sampler`.
sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=sampler)

for batch in tqdm(train_loader):
    embeddings = extractor(batch["input_tensors"])
    labels = batch["labels"]
    indices_tuple = miner(embeddings, labels)  # PML specific
    loss = criterion(embeddings, labels, indices_tuple)  # PML specific
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()