import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torch.utils.data import DataLoader
import time
from dataloader import TrainDataset
from models import FeatureExtractor, Generator, Discriminator
import argparse
import os
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

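# Training runs on the GPU when one is available; when several GPUs are visible,
# the models below are wrapped in nn.DataParallel so each batch is split across them.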
parser = argparse.ArgumentParser()
parser.add_argument('--root_dir', default='./', help='path to dataset')
parser.add_argument('--num_workers', type=int, default=2, help='number of data loading workers')
parser.add_argument('--batch_size', type=int, default=128, help='input batch size')
parser.add_argument('--num_epochs', type=int, default=200, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate')
parser.add_argument('--pre_num_epochs', type=int, default=100, help='number of pre-training epochs')
parser.add_argument('--outdir', default='./', help='directory to output model checkpoints')
parser.add_argument('--load_checkpoint', default=0, type=int, help='pass 1 to resume from saved checkpoints')
parser.add_argument('--b', default=16, type=int, help='number of residual blocks in generator')
args = parser.parse_args()

# Load data
dataset = TrainDataset(args.root_dir)
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

# Initialize models
vgg = models.vgg19(pretrained=True)
feature_extractor = FeatureExtractor(vgg, 5, 4)
if torch.cuda.device_count() > 1:
    feature_extractor = nn.DataParallel(feature_extractor)
feature_extractor = feature_extractor.to(device)
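# FeatureExtractor(vgg, 5, 4) presumably truncates the pretrained VGG19 at the 4th
# convolution of the 5th block (the "VGG54" feature map used for the perceptual loss
# in the SRGAN paper, Ledig et al. 2017); the exact indexing depends on the
# FeatureExtractor definition in models.py.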

disc = Discriminator()
if torch.cuda.device_count() > 1:
    disc = nn.DataParallel(disc)
disc = disc.to(device)
if args.load_checkpoint == 1 and os.path.exists(os.path.join(args.outdir, 'disc.pt')):
    disc.load_state_dict(torch.load(os.path.join(args.outdir, 'disc.pt'), map_location=device))
print(disc)

gen = Generator(args.b)
if torch.cuda.device_count() > 1:
    gen = nn.DataParallel(gen)
gen = gen.to(device)
if args.load_checkpoint == 1 and os.path.exists(os.path.join(args.outdir, 'gen.pt')):
    gen.load_state_dict(torch.load(os.path.join(args.outdir, 'gen.pt'), map_location=device))
print(gen)

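# Loss setup: a pixel-wise MSE criterion serves both as the image-space content loss
# and as the VGG feature-space loss, while a BCE criterion drives the adversarial game
# (the Discriminator is therefore expected to end in a sigmoid).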
content_criterion = nn.MSELoss()
adversarial_criterion = nn.BCELoss()
optimG = optim.Adam(gen.parameters(), lr=args.lr)
schedulerG1 = optim.lr_scheduler.MultiStepLR(optimG, milestones=[100], gamma=0.1)
optimD = optim.Adam(disc.parameters(), lr=args.lr)
schedulerD = optim.lr_scheduler.MultiStepLR(optimD, milestones=[100], gamma=0.1)
writer = SummaryWriter()

# Generator pre-training
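# Stage 1: train the generator alone with pixel-wise MSE (an SRResNet-style
# initialization, as in the SRGAN paper) so that adversarial training starts
# from a reasonable super-resolution model rather than from random weights.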
start_time = time.time()
iters = 0
for epoch in range(args.pre_num_epochs):

    for i, data in enumerate(dataloader, 0):

        lr, hr_real = data
        hr_real = hr_real.to(device)
        lr = lr.to(device)

        batch_size = hr_real.size()[0]
        hr_fake = gen(lr)

        gen.zero_grad()
        gen_content_loss = content_criterion(hr_fake, hr_real)
        gen_content_loss.backward()
        optimG.step()

        if i == 0:
            print(f'[{epoch}/{args.pre_num_epochs}][{i}/{len(dataloader)}] Gen_MSE: {gen_content_loss.item():.4f}')
        writer.add_scalar('Pretrain/MSE', gen_content_loss.item(), iters)  # log the MSE curve to TensorBoard
        iters += 1

    torch.save(gen.state_dict(), os.path.join(args.outdir, 'gen.pt'))
    schedulerG1.step()
    print(f'Time Elapsed: {(time.time()-start_time):.2f}')

# Adversarial Training
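# Stage 2: joint adversarial training. The generator minimizes a perceptual loss
# made of the pixel-wise MSE, a VGG feature-space MSE (weight 0.006) and an
# adversarial term (weight 0.001), while the discriminator is trained with
# smoothed real/fake labels (Salimans et al. 2016) instead of hard 0/1 targets.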
G_losses = []
D_losses = []
iters = 0
optimG = optim.Adam(gen.parameters(), lr=args.lr)
# The scheduler must be built on the freshly created optimizer, otherwise it would
# keep decaying the (now discarded) pre-training optimizer instead of this one.
schedulerG2 = optim.lr_scheduler.MultiStepLR(optimG, milestones=[100], gamma=0.1)
for epoch in range(args.num_epochs):

    for i, data in enumerate(dataloader):
        iters += 1
        lr, hr_real = data
        batch_size = hr_real.size()[0]
        hr_real = hr_real.to(device)
        lr = lr.to(device)
        hr_fake = gen(lr)

        # Label smoothing (Salimans et al. 2016): noisy targets in [0.7, 1.0] for real
        # and [0.0, 0.15] for fake images, kept inside [0, 1] as BCELoss expects.
        target_real = torch.rand(batch_size, 1, device=device)*0.3 + 0.7
        target_fake = torch.rand(batch_size, 1, device=device)*0.15

        # Discriminator update: real images should score high, generated ones low.
        disc.zero_grad()
        D_x = disc(hr_real)
        D_G_z1 = disc(hr_fake.detach())
        errD_real = adversarial_criterion(D_x, target_real)
        errD_fake = adversarial_criterion(D_G_z1, target_fake)
        errD = errD_real + errD_fake
        D_x = D_x.view(-1).mean().item()
        D_G_z1 = D_G_z1.view(-1).mean().item()
        errD.backward()
        optimD.step()

        # Generator update: perceptual loss = MSE + 0.006*VGG feature loss + 0.001*adversarial loss
        gen.zero_grad()
        real_features = feature_extractor(hr_real).detach()  # no gradients needed through the target features
        fake_features = feature_extractor(hr_fake)
        ones = torch.ones(batch_size, 1, device=device)

        errG_mse = content_criterion(hr_fake, hr_real)
        errG_vgg = content_criterion(fake_features, real_features)
        D_G_z2 = disc(hr_fake)
        errG_adv = adversarial_criterion(D_G_z2, ones)
        errG = errG_mse + 0.006*errG_vgg + 0.001*errG_adv
        D_G_z2 = D_G_z2.view(-1).mean().item()
        errG.backward()
        optimG.step()

        if i == 0:
            print(f'[{epoch}/{args.num_epochs}][{i}/{len(dataloader)}] errD: {errD.item():.4f}'
                  f' errG: {errG.item():.4f} ({errG_mse.item():.4f}/{0.006*errG_vgg.item():.4f}/{0.001*errG_adv.item():.4f})'
                  f' D(HR): {D_x:.4f} D(G(LR1)): {D_G_z1:.4f} D(G(LR2)): {D_G_z2:.4f}')

        writer.add_scalar('Loss/D', errD.item(), iters)  # log both losses to TensorBoard
        writer.add_scalar('Loss/G', errG.item(), iters)
        G_losses.append(errG.item())
        D_losses.append(errD.item())

    torch.save(gen.state_dict(), os.path.join(args.outdir, 'gen.pt'))
    torch.save(disc.state_dict(), os.path.join(args.outdir, 'disc.pt'))
    print(f'Time Elapsed: {(time.time()-start_time):.2f}')
    schedulerD.step()
    schedulerG2.step()

print(f'Finished Training {args.num_epochs} epochs')
writer.close()

plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()