Feature/weighted multiple datasets #3

Open · wants to merge 6 commits into base: main
119 changes: 64 additions & 55 deletions miyagi_trainer/dataloaders.py
@@ -1,25 +1,63 @@
import os
import inspect
from collections import defaultdict, Counter
from collections import Counter

import torch
import torchvision
import torchvision.transforms as transforms

from datasets import CUSTOM_DATASETS


def _get_pytorch_dataloders(dataset, batch_size, num_workers, balanced_weights = False):
def _get_pytorch_dataloders(
dataset, batch_size, num_workers, balanced_weights=False,
multiple_datasets_temperature=0.0):

if balanced_weights:
Owner:
I know it was me who did that, but I forgot: if I set balanced_weights, the samples are balanced by class size, is that it? And what you implemented is a way to give more or less importance to other sources of data (datasets), right?

Additionally, you cannot combine both, right?

Maybe we should rename the balanced_weights parameter to something more descriptive.
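
For context, the balanced_weights branch below amounts to inverse-class-frequency sampling: each sample is weighted by 1 / (size of its class), so the WeightedRandomSampler draws all classes in roughly equal proportion. A minimal standalone sketch with hypothetical class counts and labels (illustration only, not part of this PR):

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Three imbalanced classes with hypothetical counts, one label per sample.
class_sample_count = torch.tensor([900, 90, 10])
targets = torch.tensor([0] * 900 + [1] * 90 + [2] * 10)

# Inverse class frequency: rarer classes get proportionally larger weights.
weight = 1.0 / class_sample_count.float()
samples_weight = weight[targets]  # per-sample weight, indexed by label

sampler = WeightedRandomSampler(
    samples_weight, num_samples=len(samples_weight), replacement=True)
# A DataLoader built with this sampler sees each class about equally often.
```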

class_sample_count = torch.tensor([*dataset.class_sample_count.values()])

weight = 1. / class_sample_count.float()
samples_weight = torch.tensor([weight[t] for t in dataset.all_targets])

# Create sampler, dataset, loader
sampler = torch.utils.data.WeightedRandomSampler(samples_weight, len(samples_weight))
loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, num_workers=num_workers,
pin_memory=True, sampler=sampler)

sampler = torch.utils.data.WeightedRandomSampler(
samples_weight, len(samples_weight), replacement=True)
loader = torch.utils.data.DataLoader(
dataset=dataset, batch_size=batch_size, num_workers=num_workers,
pin_memory=True, sampler=sampler)

elif multiple_datasets_temperature:
samples_weight_list = []
num_datasets = len(dataset.dataset_idx)
datasets_sizes = [sum(dataset.ds_class_sample_count[idx].values())
for idx in range(num_datasets)]
t = multiple_datasets_temperature # temperature
sizes_tensor = torch.tensor(datasets_sizes).float()
weights = torch.softmax(sizes_tensor/(max(datasets_sizes)*t), dim=0)
dataset_weights = weights.numpy(force=True)

for idx, ds_idx in enumerate(dataset.dataset_idx):
class_sample_count = torch.tensor([*dataset.ds_class_sample_count[idx].values()])
weight = (class_sample_count/torch.sum(class_sample_count)).float()
# degenerate case, when we have only one class
if len(weight) == 1:
weight = torch.tensor([0.5]).float()
weight *= dataset_weights[idx]
if idx < num_datasets-1:
next_ds_idx = dataset.dataset_idx[idx+1]
else:
next_ds_idx = len(dataset.all_targets)
samples_weight_list.extend(
[weight[t] for t in dataset.all_targets[ds_idx:next_ds_idx]])

samples_weight = torch.tensor(samples_weight_list)
# Create sampler, dataset, loader
sampler = torch.utils.data.WeightedRandomSampler(
samples_weight, len(samples_weight), replacement=True)
loader = torch.utils.data.DataLoader(
dataset=dataset, batch_size=batch_size, num_workers=num_workers,
pin_memory=True, sampler=sampler)

else:
loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True,
num_workers=num_workers, pin_memory=True)
@@ -50,33 +88,6 @@ def _get_image_folder_dataset(dataset_name, split, transform):

return dataset

def DEPRECATED_get_ffcv_dataloaders(root_dir, dataset_name, resize_size, batch_size, num_workers):
# Random resized crop
decoder = RandomResizedCropRGBImageDecoder((resize_size, resize_size))

# TODO: pipelines should be different for train, val, normally.
# TODO: augmentation should be done by another lib and equal for testing purposes.
# Data decoding and augmentation
image_pipeline = [decoder, ToTensor(), ToTorchImage(), ToDevice(0)]
label_pipeline = [IntDecoder(), ToTensor(), ToDevice(0)]

# Pipeline for each data field
pipelines = {
'image': image_pipeline,
'label': label_pipeline
}

train_loader = Loader(os.path.join(root_dir, f"{dataset_name}_train.beton"),
batch_size=batch_size, num_workers=num_workers,
order=OrderOption.RANDOM, pipelines=pipelines, os_cache=True)

val_loader = Loader(os.path.join(root_dir, f"{dataset_name}_val.beton"),
batch_size=batch_size, num_workers=num_workers,
order=OrderOption.RANDOM, pipelines=pipelines, os_cache=True)


return train_loader, val_loader


class DatasetJoin(torch.utils.data.ConcatDataset):

@@ -87,7 +98,9 @@ def __init__(self, imagefolder_dataset_list):
def join_classes(self):
join_class_to_idx = None
class_sample_count = Counter()
self.ds_class_sample_count = []
self.all_targets = []
self.dataset_idx = []
for ds in self.datasets:
if join_class_to_idx is None:
join_class_to_idx = ds.class_to_idx
@@ -106,6 +119,8 @@ def join_classes(self):
this_ds_counts = Counter(ds.targets)
# TODO: this class_sample_count is not taking into account the target_mapping
class_sample_count.update(this_ds_counts)
self.ds_class_sample_count.append(this_ds_counts)
self.dataset_idx.append(len(self.all_targets))
self.all_targets.extend(ds.targets)

# TODO: for this to work the order should be the same when I'm weighting samples
@@ -119,10 +134,11 @@ def join_classes(self):


def get_dataset_loaders(dataset_names,
transforms,
batch_size = 32,
num_workers = 4,
balanced_weights = False):
transforms,
batch_size = 32,
num_workers = 4,
balanced_weights = False,
multiple_datasets_temperature = 0.2):

"""
Expecting dataset_names and transforms to be dict with "train" and "val" keys
@@ -132,37 +148,30 @@
data_loaders = {}
for s in splits:
split_datasets = []
for ith_ds, ds_name in enumerate(dataset_names[s]):
for ds_name in dataset_names[s]:
if ds_name in dir(torchvision.datasets):
this_dataset = _get_pytorch_dataset(ds_name, s, transforms[s])
elif ds_name in CUSTOM_DATASETS.keys():
this_dataset = _get_image_folder_dataset(ds_name, s, transforms[s])
else:
raise ValueError(f'Invalid dataset: {ds_name}')

split_datasets.append(this_dataset)

# TODO: https://stackoverflow.com/questions/71173583/concat-datasets-in-pytorch
combined_datasets[s] = DatasetJoin(split_datasets)
data_loaders[s] = _get_pytorch_dataloders(combined_datasets[s], batch_size, num_workers, balanced_weights)
if s == 'train':
data_loaders[s] = _get_pytorch_dataloders(
combined_datasets[s], batch_size, num_workers,
balanced_weights,
multiple_datasets_temperature)
else:
data_loaders[s] = _get_pytorch_dataloders(
combined_datasets[s], batch_size, num_workers)

return data_loaders["train"], data_loaders["val"]


# def DEPRECATED_get_pytorch_default_transform(resize_size):

# def_transform = transforms.Compose([
# transforms.ToTensor(),
# transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0)==1 else x)
# ])

# if resize_size is not None:
# def_transform = transforms.Compose([
# transforms.Resize((resize_size,resize_size)),
# def_transform
# ])

# return def_transform


if __name__ == "__main__":
dataset_names = {
"train": ["liveness_simple", "flash_ds"],
2 changes: 2 additions & 0 deletions miyagi_trainer/optimizers.py
@@ -11,5 +11,7 @@ def get_optimizer(model, optimizer_arg, weight_decay = 1e-4):
optimizer = torch.optim.Adam(model.parameters(), lr=0.05, weight_decay=weight_decay)
elif optimizer_arg == "sgd":
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=weight_decay)
else:
raise ValueError(f"Invalid optimizer_arg: {optimizer_arg}")

return optimizer
2 changes: 1 addition & 1 deletion miyagi_trainer/schedulers.py
@@ -3,7 +3,7 @@
def get_scheduler(optimizer, args):
return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
T_0=10,
T_mult=2,
T_mult=args.t_mult,
eta_min=0.01,
last_epoch=-1)

32 changes: 12 additions & 20 deletions miyagi_trainer/train.py
@@ -3,11 +3,8 @@
import time
import datetime
import copy
from collections import deque

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
import wandb

@@ -71,8 +68,6 @@ def train_model(model,
dataset_sizes = {x: len(dataloaders[x].dataset) for x in phases}
num_epochs = n_epochs

start = time.time()


for epoch in range(num_epochs):
start_epoch = time.time()
@@ -95,9 +90,6 @@
running_labels = torch.Tensor()
running_outputs = torch.Tensor()

#wrong_epoch_images = deque(maxlen=32)
#wrong_epoch_attr = deque(maxlen=32)

# Iterate over data.
for batch_idx, (inputs, labels) in enumerate(tqdm(dataloaders[phase])):
if metric_eer:
@@ -129,11 +121,6 @@
if metric_eer:
running_outputs = torch.cat((running_outputs, outputs.detach().cpu()))

#if phase == "train":
# wrong_epoch_images.extend([x for x in inputs[preds!=labels]])
#if track_images:
# wrong_epoch_attr.extend([(labels[i], preds[i])\
# for i in (preds!=labels).nonzero().flatten()])

if phase == 'train':
scheduler.step()
@@ -160,8 +147,7 @@
if save_curr_model:
model_folder = wandb.run.name if track_experiment else \
datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
if not os.path.exists(model_folder):
os.mkdir(model_folder)
os.makedirs(model_folder, exist_ok=True)
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
@@ -224,10 +210,11 @@ def train(args):
}

train_loader, val_loader = dataloaders.get_dataset_loaders(in_datasets_names,
transformers,
int(args.batch_size),
int(args.num_dataloader_workers),
args.balanced_weights)
transformers,
int(args.batch_size),
int(args.num_dataloader_workers),
args.balanced_weights,
args.multiple_datasets_temperature)

model = models.get_model(args.backbone, len(train_loader.dataset.classes),
not args.no_transfer_learning, args.freeze_all_but_last)
@@ -258,11 +245,15 @@ def train(args):

parser.add_argument("--no_transfer_learning", action=argparse.BooleanOptionalAction)
parser.add_argument("--freeze_all_but_last", action=argparse.BooleanOptionalAction)
parser.add_argument("--weights", type=str)

# {phase} datasets are expected to have {phase}-named folders inside them
parser.add_argument("--train_datasets", action='store', type=str, nargs="+", required=True)
parser.add_argument("--val_datasets", action='store', type=str, nargs="+", required=True)
parser.add_argument("--balanced_weights", action=argparse.BooleanOptionalAction)
parser.add_argument("--multiple_datasets_temperature", type=float, required=False,
Owner:
I didn't get that. Isn't the temperature supposed to be a value defined per dataset? Is there any reference for this type of weighting?

It would be awesome if you explained these options a bit in the README.md as well.

Collaborator Author:

Honestly, it was a hacky way to balance multiple datasets and not a very good one. I think it is a nice feature to have, but I would reimplement it with a clearer interface and output. Probably something like passing a list of datasets, [ds1, ds2, ds3], and sampling weights for them, [0.2, 0.1, 0.7], which would sample 20% from ds1, 10% from ds2, and 70% from ds3. Sounds better, right?

help="Dataset path contains multiple datasets that will be combined, each one "
"having a weight given by a softmax of the datasets size with this temperature.")

parser.add_argument("--resize_size", default=None)
parser.add_argument("--num_dataloader_workers", default=8) # recomends to be 4 x #GPU
@@ -278,11 +269,12 @@
parser.add_argument("--wandb_sweep_activated", action=argparse.BooleanOptionalAction)

parser.add_argument("--augmentation", type=str, default="simple",
choices=["noaug", "simple", "rand-m9-n3-mstd0.5", "rand-mstd1-w0", "random_erase"])
choices=["noaug", "simple", "rand-m9-n3-mstd0.5", "rand-mstd1-w0", "random_erase"])

# options for optimizers
parser.add_argument("--optimizer", default="sgd") # possible adam, adamp and sgd
parser.add_argument("--weight_decay", type=float, default=1e-4)
parser.add_argument("--t_mult", type=int, default=2)

# options for model saving
parser.add_argument("--save_best_model", action=argparse.BooleanOptionalAction)