# Training with distance, reducer, miner and loss from PML

```python
import torch
from tqdm import tqdm

from oml.datasets.base import DatasetWithLabels
from oml.models import ViTExtractor
from oml.samplers.balance import BalanceSampler
from oml.utils.download_mock_dataset import download_mock_dataset

from pytorch_metric_learning import losses, distances, reducers, miners

df_train, _ = download_mock_dataset(global_paths=True)

extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train()
optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6)

train_dataset = DatasetWithLabels(df_train)

# PML specific: the distance and reducer are passed into the loss,
# and the miner selects which triplets the loss is computed on
distance = distances.LpDistance(p=2)
reducer = reducers.ThresholdReducer(low=0)
criterion = losses.TripletMarginLoss(margin=0.2, distance=distance, reducer=reducer)
miner = miners.TripletMarginMiner(margin=0.2, distance=distance, type_of_triplets="all")

sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=sampler)

for batch in tqdm(train_loader):
    embeddings = extractor(batch["input_tensors"])
    loss = criterion(embeddings, batch["labels"], miner(embeddings, batch["labels"]))  # PML specific
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```
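
Only the lines marked "PML specific" tie the example to the triplet setup; the OML extractor, dataset, sampler, and training loop do not change if other PML components are swapped in. Below is a minimal sketch (not part of the original example) that replaces them with PML's `MultiSimilarityLoss` and `MultiSimilarityMiner`; the hyperparameter values are PML's defaults, not tuned settings.

```python
# A minimal sketch: swapping in a different PML distance / reducer / loss / miner combination.
# Everything outside this snippet stays as in the example above.
from pytorch_metric_learning import distances, losses, miners, reducers

distance = distances.CosineSimilarity()    # similarity instead of L2 distance
reducer = reducers.AvgNonZeroReducer()     # average only the non-zero loss elements
criterion = losses.MultiSimilarityLoss(alpha=2, beta=50, base=0.5, distance=distance, reducer=reducer)
miner = miners.MultiSimilarityMiner(epsilon=0.1, distance=distance)

# The call in the training loop is unchanged:
# loss = criterion(embeddings, batch["labels"], miner(embeddings, batch["labels"]))
```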

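After training, a quick sanity check on retrieval quality can be done with plain PyTorch. The sketch below is not part of the original example and continues directly from the training script: it embeds the data with the fine-tuned extractor and computes a naive recall@1 via nearest-neighbour search. For brevity it reuses `train_dataset`; in practice you would embed the validation split that `download_mock_dataset` also returns.

```python
# A hedged evaluation sketch: naive recall@1 over the embedded dataset.
extractor.eval()
eval_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16)

all_embeddings, all_labels = [], []
with torch.no_grad():
    for batch in eval_loader:
        all_embeddings.append(extractor(batch["input_tensors"]))
        all_labels.append(batch["labels"])

embeddings = torch.cat(all_embeddings)
labels = torch.cat(all_labels)

# Pairwise L2 distances; mask the diagonal so an item cannot retrieve itself.
dists = torch.cdist(embeddings, embeddings)
dists.fill_diagonal_(float("inf"))
nearest = dists.argmin(dim=1)
recall_at_1 = (labels[nearest] == labels).float().mean().item()
print(f"recall@1: {recall_at_1:.3f}")
```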