Skip to content

Commit 422011a

Browse files
committed
init
0 parents  commit 422011a

8 files changed

+1127
-0
lines changed

README.md

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# AdaTask
2+
[AdaTask: A Task-Aware Adaptive Learning Rate Approach to Multi-Task Learning](https://arxiv.org/abs/2211.15055)
3+
4+
In this paper we propose a Task-wise Adaptive Learning Rate Method, named AdaTask, to use task-specific accumulative gradients when adjusting the learning rate of each parameter.
5+
6+
## DataSet
7+
- Download [CityScapes](https://www.dropbox.com/sh/gaw6vh6qusoyms6/AADwWi0Tp3E3M4B2xzeGlsEna?dl=0) dataset and put it in the dataset directory.
8+
9+
10+
## Train and Evaluate Method
11+
12+
```
13+
python3 main_cityscapes.py --method=adam
14+
```
15+
16+
```
17+
python3 main_cityscapes.py --method=adam_with_adatask
18+
```
19+
20+
21+
22+
## Reference
23+
24+
Please cite our paper if you use this code.
25+
26+
```
27+
@inproceedings{adatask_aaai2023,
28+
title={AdaTask: A Task-aware Adaptive Learning Rate Approach to Multi-task Learning},
29+
author={{Yang, Enneng and Pan, Junwei and Wang, Ximei and Yu, Haibin and Shen, Li and Chen, Xihua and Xiao, Lei and Jiang, Jie and Guo, Guibing},
30+
booktitle={AAAI},
31+
year={2023}
32+
}
33+
34+
```
35+

adatask.py

+114
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import math
2+
import torch
3+
from torch.optim.optimizer import Optimizer
4+
from typing import List, Union
5+
6+
class Adam_with_AdaTask(Optimizer):
7+
r"""
8+
Implements Adam with AdaTask algorithm.
9+
"""
10+
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, args=None, device='cpu', n_tasks=3, task_weight=[1, 1]):
11+
if not 0.0 <= lr:
12+
raise ValueError("Invalid learning rate: {}".format(lr))
13+
if not 0.0 <= eps:
14+
raise ValueError("Invalid epsilon value: {}".format(eps))
15+
if not 0.0 <= betas[0] < 1.0:
16+
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
17+
if not 0.0 <= betas[1] < 1.0:
18+
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
19+
if not 0.0 <= weight_decay:
20+
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
21+
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
22+
super(Adam_with_AdaTask, self).__init__(params, defaults)
23+
24+
self.n_tasks = n_tasks
25+
self.device = device
26+
self.betas = betas
27+
self.eps = eps
28+
self.task_weight = torch.Tensor(task_weight).to(device)
29+
30+
def zero_grad_modules(self, modules_parameters):
31+
for p in modules_parameters:
32+
if p.grad is not None:
33+
p.grad.detach_()
34+
p.grad.zero_()
35+
36+
def backward_and_step(self,
37+
losses: torch.Tensor,
38+
shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
39+
task_specific_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None,
40+
last_shared_parameters: Union[List[torch.nn.parameter.Parameter], torch.Tensor] = None, ):
41+
42+
shared_grads = []
43+
if shared_parameters is not None:
44+
for i in range(len(losses)):
45+
self.zero_grad_modules(shared_parameters)
46+
(self.task_weight[i] * losses[i]).backward(retain_graph=True)
47+
grad = [p.grad.detach().clone() if (p.requires_grad is True and p.grad is not None) else None for p in shared_parameters]
48+
shared_grads.append(grad)
49+
50+
if task_specific_parameters is not None:
51+
self.zero_grad_modules(task_specific_parameters)
52+
(self.task_weight*losses).sum().backward()
53+
task_specific_grads = [p.grad.detach().clone() if (p.requires_grad is True and p.grad is not None) else None for p in task_specific_parameters]
54+
55+
return self.step(shared_parameters, task_specific_parameters, shared_grads, task_specific_grads)
56+
57+
@torch.no_grad()
58+
def step(self, shared_parameters, task_specific_parameters, shared_grads, task_specific_grads):
59+
# lr
60+
for group in self.param_groups:
61+
step_lr = group['lr']
62+
63+
# shared param
64+
for pi in range(len(shared_parameters)):
65+
p = shared_parameters[pi]
66+
state = self.state[p]
67+
# State initialization
68+
if len(state) == 0:
69+
state['step'] = 0
70+
for t in range(self.n_tasks):
71+
# Exponential moving average of gradient values
72+
state['exp_avg_'+str(t)] = torch.zeros_like(p, memory_format=torch.preserve_format)
73+
# Exponential moving average of squared gradient values
74+
state['exp_avg_sq_'+str(t)] = torch.zeros_like(p, memory_format=torch.preserve_format)
75+
76+
state['step'] += 1
77+
beta1, beta2 = self.betas
78+
bias_correction1 = 1 - beta1 ** state['step']
79+
bias_correction2 = 1 - beta2 ** state['step']
80+
81+
for t in range(self.n_tasks):
82+
grad = shared_grads[t][pi]
83+
exp_avg = state['exp_avg_' + str(t)]
84+
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
85+
exp_avg_sq = state['exp_avg_sq_' + str(t)]
86+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
87+
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(self.eps)
88+
step_size = step_lr / bias_correction1
89+
p.addcdiv_(exp_avg, denom, value=-step_size)
90+
91+
# task specific param
92+
for pi in range(len(task_specific_parameters)):
93+
p = task_specific_parameters[pi]
94+
state = self.state[p]
95+
# State initialization
96+
if len(state) == 0:
97+
state['step'] = 0
98+
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
99+
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
100+
101+
state['step'] += 1
102+
beta1, beta2 = self.betas
103+
bias_correction1 = 1 - beta1 ** state['step']
104+
bias_correction2 = 1 - beta2 ** state['step']
105+
106+
grad = task_specific_grads[pi]
107+
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
108+
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
109+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
110+
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(self.eps)
111+
step_size = step_lr / bias_correction1
112+
p.addcdiv_(exp_avg, denom, value=-step_size)
113+
114+
return None

common.py

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import argparse
2+
import logging
3+
import random
4+
from collections import defaultdict
5+
from pathlib import Path
6+
import numpy as np
7+
import torch
8+
9+
def str_to_list(string):
10+
return [float(s) for s in string.split(",")]
11+
12+
def str_or_float(value):
13+
try:
14+
return float(value)
15+
except:
16+
return value
17+
18+
def str2bool(v):
19+
if isinstance(v, bool):
20+
return v
21+
if v.lower() in ("yes", "true", "t", "y", "1"):
22+
return True
23+
elif v.lower() in ("no", "false", "f", "n", "0"):
24+
return False
25+
else:
26+
raise argparse.ArgumentTypeError("Boolean value expected.")
27+
28+
common_parser = argparse.ArgumentParser(add_help=False)
29+
common_parser.add_argument("--data-path", type=Path, help="path to data")
30+
common_parser.add_argument("--log_path", type=Path, help="path to log")
31+
common_parser.add_argument("--n-epochs", type=int, default=200)
32+
common_parser.add_argument("--n_task", type=int, default=2)
33+
common_parser.add_argument("--batch-size", type=int, default=120, help="batch size")
34+
common_parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
35+
common_parser.add_argument("--method-params-lr", type=float, default=0.025, help="lr for weight method params. If None, set to args.lr. For uncertainty weighting",)
36+
common_parser.add_argument("--gpu", type=int, default=0, help="gpu device ID")
37+
common_parser.add_argument("--seed", type=int, default=42, help="seed value")
38+
39+
def count_parameters(model):
40+
return sum(p.numel() for p in model.parameters() if p.requires_grad)
41+
42+
def set_logger():
43+
logging.basicConfig(
44+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
45+
level=logging.INFO,)
46+
47+
def set_seed(seed):
48+
"""for reproducibility
49+
:param seed:
50+
:return:
51+
"""
52+
np.random.seed(seed)
53+
random.seed(seed)
54+
55+
torch.manual_seed(seed)
56+
if torch.cuda.is_available():
57+
torch.cuda.manual_seed(seed)
58+
torch.cuda.manual_seed_all(seed)
59+
60+
torch.backends.cudnn.enabled = True
61+
torch.backends.cudnn.benchmark = False
62+
torch.backends.cudnn.deterministic = True
63+
64+
def get_device(no_cuda=False, gpus="0"):
65+
return torch.device(
66+
f"cuda:{gpus}" if torch.cuda.is_available() and not no_cuda else "cpu"
67+
)
68+

data.py

+110
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import fnmatch
2+
import os
3+
import random
4+
import numpy as np
5+
import torch
6+
import torch.nn.functional as F
7+
from torch.utils.data.dataset import Dataset
8+
9+
class RandomScaleCrop(object):
10+
"""
11+
Credit to Jialong Wu from https://github.com/lorenmt/mtan/issues/34.
12+
"""
13+
14+
def __init__(self, scale=[1.0, 1.2, 1.5]):
15+
self.scale = scale
16+
17+
def __call__(self, img, label, depth, normal):
18+
height, width = img.shape[-2:]
19+
sc = self.scale[random.randint(0, len(self.scale) - 1)]
20+
h, w = int(height / sc), int(width / sc)
21+
i = random.randint(0, height - h)
22+
j = random.randint(0, width - w)
23+
img_ = F.interpolate(
24+
img[None, :, i : i + h, j : j + w],
25+
size=(height, width),
26+
mode="bilinear",
27+
align_corners=True,
28+
).squeeze(0)
29+
label_ = (
30+
F.interpolate(
31+
label[None, None, i : i + h, j : j + w],
32+
size=(height, width),
33+
mode="nearest",
34+
)
35+
.squeeze(0)
36+
.squeeze(0)
37+
)
38+
depth_ = F.interpolate(
39+
depth[None, :, i : i + h, j : j + w], size=(height, width), mode="nearest"
40+
).squeeze(0)
41+
normal_ = F.interpolate(
42+
normal[None, :, i : i + h, j : j + w],
43+
size=(height, width),
44+
mode="bilinear",
45+
align_corners=True,
46+
).squeeze(0)
47+
return img_, label_, depth_ / sc, normal_
48+
49+
class RandomScaleCropCityScapes(object):
50+
"""
51+
Credit to Jialong Wu from https://github.com/lorenmt/mtan/issues/34.
52+
"""
53+
def __init__(self, scale=[1.0, 1.2, 1.5]):
54+
self.scale = scale
55+
56+
def __call__(self, img, label, depth):
57+
height, width = img.shape[-2:]
58+
sc = self.scale[random.randint(0, len(self.scale) - 1)]
59+
h, w = int(height / sc), int(width / sc)
60+
i = random.randint(0, height - h)
61+
j = random.randint(0, width - w)
62+
img_ = F.interpolate(img[None, :, i:i + h, j:j + w], size=(height, width), mode='bilinear', align_corners=True).squeeze(0)
63+
label_ = F.interpolate(label[None, None, i:i + h, j:j + w], size=(height, width), mode='nearest').squeeze(0).squeeze(0)
64+
depth_ = F.interpolate(depth[None, :, i:i + h, j:j + w], size=(height, width), mode='nearest').squeeze(0)
65+
return img_, label_, depth_ / sc
66+
67+
class CityScapes(Dataset):
68+
"""
69+
We could further improve the performance with the data augmentation of NYUv2 defined in:
70+
[1] PAD-Net: Multi-Tasks Guided Prediction-and-Distillation Network for Simultaneous Depth Estimation and Scene Parsing
71+
[2] Pattern affinitive propagation across depth, surface normal and semantic segmentation
72+
[3] Mti-net: Multiscale task interaction networks for multi-task learning
73+
74+
1. Random scale in a selected raio 1.0, 1.2, and 1.5.
75+
2. Random horizontal flip.
76+
77+
Please note that: all baselines and MTAN did NOT apply data augmentation in the original paper.
78+
"""
79+
def __init__(self, root, train=True, augmentation=False):
80+
self.train = train
81+
self.root = os.path.expanduser(root)
82+
self.augmentation = augmentation
83+
84+
# read the data file
85+
if train:
86+
self.data_path = root + '/train'
87+
else:
88+
self.data_path = root + '/val'
89+
90+
# calculate data length
91+
self.data_len = len(fnmatch.filter(os.listdir(self.data_path + '/image'), '*.npy'))
92+
93+
def __getitem__(self, index):
94+
# load data from the pre-processed npy files
95+
image = torch.from_numpy(np.moveaxis(np.load(self.data_path + '/image/{:d}.npy'.format(index)), -1, 0))
96+
semantic = torch.from_numpy(np.load(self.data_path + '/label_7/{:d}.npy'.format(index)))
97+
depth = torch.from_numpy(np.moveaxis(np.load(self.data_path + '/depth/{:d}.npy'.format(index)), -1, 0))
98+
99+
# apply data augmentation if required
100+
if self.augmentation:
101+
image, semantic, depth = RandomScaleCropCityScapes()(image, semantic, depth)
102+
if torch.rand(1) < 0.5:
103+
image = torch.flip(image, dims=[2])
104+
semantic = torch.flip(semantic, dims=[1])
105+
depth = torch.flip(depth, dims=[2])
106+
107+
return image.float(), semantic.float(), depth.float()
108+
109+
def __len__(self):
110+
return self.data_len

0 commit comments

Comments
 (0)