-
Notifications
You must be signed in to change notification settings - Fork 86
/
Copy pathmain.py
218 lines (178 loc) · 8.04 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
import argparse
import logging as log
import os
import time
from math import ceil, floor
from tensorboardX import SummaryWriter
import torch
import torch.distributed as dist
import torch.optim as optim
import torch.utils.data.distributed
from torch.multiprocessing import Process
from torch.autograd import Variable
from dataloading.dataloaders import get_loader
from model.model import VSRNet
from model.clr import cyclic_learning_rate
from nvidia.fp16 import FP16_Optimizer
from nvidia.fp16util import network_to_half
from nvidia.distributed import DistributedDataParallel
def _build_parser():
    """Return the command-line parser for the distributed VSRNet trainer.

    Options are registered in the order below, which is also the order in
    which ``--help`` lists them.
    """
    cli = argparse.ArgumentParser()

    # Reproducibility and input data.
    cli.add_argument('--seed', type=int, default=1, metavar='S',
                     help='random seed')
    cli.add_argument('--root', type=str, default='.',
                     help='input data root folder')
    cli.add_argument('--frames', type=int, default=3,
                     help='num frames in input sequence')
    cli.add_argument('--is_cropped', action='store_true',
                     help='crop input frames?')
    cli.add_argument('--crop_size', type=int, nargs='+', default=[256, 256],
                     help='[height, width] for input crop')
    cli.add_argument('--batchsize', type=int, default=1,
                     help='per rank batch size')
    cli.add_argument('--loader', type=str, default='NVVL',
                     help='dataloader: pytorch or NVVL')

    # Distributed process-group setup.
    cli.add_argument('--rank', type=int, default=0,
                     help='pytorch distributed rank')
    cli.add_argument('--world_size', default=2, type=int, metavar='N',
                     help='num processes for pytorch distributed')
    cli.add_argument('--ip', default='localhost', type=str,
                     help='IP address for distributed init.')

    # Training length, precision and checkpointing.
    cli.add_argument('--max_iter', type=int, default=1000,
                     help='num training iters')
    cli.add_argument('--fp16', action='store_true',
                     help='train in fp16?')
    cli.add_argument('--checkpoint_dir', type=str, default='.',
                     help='where to save checkpoints')

    # Optimizer / cyclic learning-rate schedule.
    cli.add_argument('--min_lr', type=float, default=0.000001,
                     help='min learning rate for cyclic learning rate')
    cli.add_argument('--max_lr', type=float, default=0.00001,
                     help='max learning rate for cyclic learning rate')
    cli.add_argument('--weight_decay', type=float, default=0.0004,
                     help='ADAM weight decay')
    cli.add_argument('--flownet_path', type=str,
                     default='flownet2-pytorch/networks/FlowNet2-SD_checkpoint.pth.tar',
                     help='FlowNetSD weights path')

    # Diagnostics.
    cli.add_argument('--image_freq', type=int, default=100,
                     help='num iterations between image dumps to Tensorboard ')
    cli.add_argument('--timing', action='store_true',
                     help="Time data loading and model training (default: False)")
    return cli


parser = _build_parser()
def _prepare_inputs(inputs, args):
    """Turn one loader batch into the tensor fed to the model.

    With ``args.loader == 'NVVL'`` the batch is a mapping and its ``'input'``
    entry is used as-is (no device copy is made here). Otherwise the batch is
    copied to the current CUDA device with ``non_blocking=True``. With
    ``args.fp16`` the result is cast to half precision; the cast is a no-op for
    a tensor that is already fp16, so it is safe for either loader.
    """
    if args.loader == 'NVVL':
        inputs = inputs['input']
    else:
        inputs = inputs.cuda(non_blocking=True)
    if args.fp16:
        inputs = inputs.half()
    return inputs


def _validate(model, val_loader, val_batches, args):
    """Run one validation pass and return ``(mean_loss, mean_psnr)``.

    The model is switched to eval mode and called as
    ``model(inputs, batch_index, None)``, which is expected to return a
    ``(loss, psnr)`` pair of scalar tensors. Returns ``None`` if ``val_loader``
    yields no batches, so the caller can skip reporting instead of dividing by
    zero.
    """
    model.eval()
    total_loss = 0.0
    total_psnr = 0.0
    num_batches = 0
    # Validation never calls backward(), so skip building autograd graphs;
    # this keeps GPU memory from growing with the size of the validation set.
    with torch.no_grad():
        for i, inputs in enumerate(val_loader):
            inputs = _prepare_inputs(inputs, args)
            log.info('Validation it %d of %d' % (i + 1, val_batches))
            loss, psnr = model(Variable(inputs), i, None)
            total_loss += loss.item()
            total_psnr += psnr.item()
            num_batches += 1
    if num_batches == 0:
        return None
    # Average over the number of batches actually processed (not the last
    # 0-based index, which is one too small).
    return total_loss / num_batches, total_psnr / num_batches


def main(args):
    """Train VSRNet on one distributed rank, validating after every epoch.

    Rank 0 logs at INFO level and writes TensorBoard scalars through a
    ``SummaryWriter``; other ranks log at WARNING level and have no writer.
    Each rank joins an NCCL process group at ``tcp://<args.ip>:3567``, builds
    its loaders, model and optimizer, then alternates training epochs and
    validation passes until ``total_iter * args.world_size`` reaches
    ``args.max_iter``. Because that check is made once per epoch, training
    can run past ``max_iter`` by up to one epoch.
    """
    if args.rank == 0:
        log.basicConfig(level=log.INFO)
        writer = SummaryWriter()
        writer.add_text('config', str(args))
    else:
        log.basicConfig(level=log.WARNING)
        writer = None

    # Map the global rank onto a GPU that exists on this machine.
    # (rank % world_size is always just rank, which selects nonexistent
    # devices on multi-node jobs.)
    torch.cuda.set_device(args.rank % torch.cuda.device_count())
    torch.manual_seed(args.seed + args.rank)
    torch.cuda.manual_seed(args.seed + args.rank)
    torch.backends.cudnn.benchmark = True

    log.info('Initializing process group')
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://' + args.ip + ':3567',
        world_size=args.world_size,
        rank=args.rank)
    log.info('Process group initialized')

    log.info("Initializing dataloader...")
    train_loader, train_batches, val_loader, val_batches, sampler = get_loader(args)
    samples_per_epoch = train_batches * args.batchsize
    log.info('Dataloader initialized')

    model = VSRNet(args.frames, args.flownet_path, args.fp16)
    if args.fp16:
        network_to_half(model)
    model.cuda()
    model.train()
    # The FlowNetSD sub-network is frozen; only the remaining parameters
    # are handed to the optimizer.
    for param in model.FlowNetSD_network.parameters():
        param.requires_grad = False
    model_params = [p for p in model.parameters() if p.requires_grad]

    # Base lr is 1 because LambdaLR sets lr = base_lr * lr_lambda(step), so the
    # value returned by clr_lambda becomes the learning rate directly
    # (presumably cycling between min_lr and max_lr — confirm in model.clr).
    optimizer = optim.Adam(model_params, lr=1, weight_decay=args.weight_decay)
    stepsize = 2 * train_batches
    clr_lambda = cyclic_learning_rate(args.min_lr, args.max_lr, stepsize)
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[clr_lambda])
    if args.fp16:
        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)

    model = DistributedDataParallel(model)

    # BEGIN TRAINING
    total_iter = 0
    # max_iter counts iterations summed over all ranks.
    while total_iter * args.world_size < args.max_iter:
        epoch = floor(total_iter / train_batches)
        if args.loader == 'pytorch':
            sampler.set_epoch(epoch)

        model.train()
        total_epoch_loss = 0.0
        sample_timer = 0.0
        data_timer = 0.0
        compute_timer = 0.0
        iter_start = time.perf_counter()

        # TRAINING EPOCH LOOP
        for i, inputs in enumerate(train_loader):
            inputs = _prepare_inputs(inputs, args)

            if args.timing:
                # Synchronize so the timestamp reflects finished data transfer.
                torch.cuda.synchronize()
                data_end = time.perf_counter()

            optimizer.zero_grad()
            im_out = total_iter % args.image_freq == 0
            loss = model(Variable(inputs), i, writer, im_out)
            total_epoch_loss += loss.item()

            if args.fp16:
                # FP16_Optimizer applies loss scaling inside backward().
                optimizer.backward(loss)
            else:
                loss.backward()
            optimizer.step()
            scheduler.step()

            if args.rank == 0:
                if args.timing:
                    torch.cuda.synchronize()
                    iter_end = time.perf_counter()
                    sample_timer += (iter_end - iter_start)
                    data_timer += (data_end - iter_start)
                    compute_timer += (iter_end - data_end)
                    torch.cuda.synchronize()
                    iter_start = time.perf_counter()
                writer.add_scalar('learning_rate', scheduler.get_lr()[0], total_iter)
                writer.add_scalar('train_loss', loss.item(), total_iter)
            log.info('Rank %d, Epoch %d, Iteration %d of %d, loss %.5f' %
                     (dist.get_rank(), epoch, i + 1, train_batches, loss.item()))
            total_iter += 1

        if args.rank == 0:
            if args.timing:
                sample_timer_avg = sample_timer / samples_per_epoch
                writer.add_scalar('sample_time', sample_timer_avg, total_iter)
                data_timer_avg = data_timer / samples_per_epoch
                writer.add_scalar('sample_data_time', data_timer_avg, total_iter)
                compute_timer_avg = compute_timer / samples_per_epoch
                writer.add_scalar('sample_compute_time', compute_timer_avg, total_iter)
            epoch_loss_avg = total_epoch_loss / train_batches
            log.info('Rank %d, epoch %d: %.5f' % (dist.get_rank(), epoch, epoch_loss_avg))

        val_result = _validate(model, val_loader, val_batches, args)
        if val_result is None:
            log.warning('Rank %d: validation loader yielded no batches; '
                        'skipping validation metrics' % dist.get_rank())
        else:
            loss, psnr = val_result
            if args.rank == 0:
                writer.add_scalar('val_loss', loss, total_iter)
                writer.add_scalar('val_psnr', psnr, total_iter)
            log.info('Rank %d validation loss %.5f' % (dist.get_rank(), loss))
            log.info('Rank %d validation psnr %.5f' % (dist.get_rank(), psnr))
# Script entry point: parse the command line and start training this rank.
if __name__ == '__main__':
    cli_args = parser.parse_args()
    main(cli_args)