
Commit 54b9d60

Add SRGAN PyTorch Implementation (pclubiitk#35)

* Add SRGAN PyTorch Implementation
* Formatted README to reduce clutter
* Update README code block
* Fix typo
* Add new directory
* Change directory

1 parent 69e3c40 · commit 54b9d60

14 files changed: +411 additions, −0 deletions
README.md (new file, +105 lines)
# PyTorch Implementation of SRGAN

## Usage

```bash
$ python3 main.py
```

The following arguments are available:

```
usage: main.py [-h] [--root_dir ROOT_DIR] [--num_workers NUM_WORKERS] [--batch_size BATCH_SIZE] [--num_epochs NUM_EPOCHS] [--lr LR]
               [--pre_num_epochs PRE_NUM_EPOCHS] [--outdir OUTDIR] [--load_checkpoint LOAD_CHECKPOINT] [--b B]

optional arguments:
  -h, --help            show this help message and exit
  --root_dir ROOT_DIR   path to dataset
  --num_workers NUM_WORKERS
                        number of data loading workers
  --batch_size BATCH_SIZE
                        input batch size
  --num_epochs NUM_EPOCHS
                        number of epochs to train for
  --lr LR               learning rate
  --pre_num_epochs PRE_NUM_EPOCHS
                        number of pre-training epochs
  --outdir OUTDIR       directory to output model checkpoints
  --load_checkpoint LOAD_CHECKPOINT
                        Pass 1 to load checkpoint
  --b B                 number of residual blocks in generator
```
## Contributed by:
[Naman Gupta](https://github.com/namangup)

## References

* **Title**: Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
* **Authors**: Christian Ledig et al.
* **Link**: https://arxiv.org/abs/1609.04802
* **Tags**: Super Resolution, Generative Adversarial Networks
* **Year**: 2017
## What's New?

Super-resolution has received substantial attention from the computer vision research community and has a wide range of applications. The optimization target of supervised SR algorithms is commonly the minimization of the mean squared error (MSE) between the recovered HR image and the ground truth. This is convenient, as minimizing MSE also maximizes the peak signal-to-noise ratio (PSNR), a common measure used to evaluate and compare SR algorithms. However, the ability of MSE (and PSNR) to capture perceptually relevant differences, such as high texture detail, is very limited, since they are defined on pixel-wise image differences. Hence, to capture those details, SRGAN defines a novel perceptual loss that uses high-level feature maps of the VGG network, combined with a discriminator that encourages solutions which are perceptually hard to distinguish from the HR reference images.
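To make the distinction concrete, here is a minimal sketch (illustrative names only, not the repository's exact API) contrasting pixel-space MSE with a feature-space perceptual loss:

```python
import torch.nn as nn

mse = nn.MSELoss()

def pixel_mse_loss(sr, hr):
    # Optimizes PSNR directly, but tends to over-smooth fine texture.
    return mse(sr, hr)

def vgg_content_loss(phi, sr, hr):
    # phi is a frozen, pretrained feature extractor (e.g. a truncated VGG19);
    # distances in its feature space track perceived quality better than pixels do.
    return mse(phi(sr), phi(hr))
```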
## Architecture

![model](assets/model.jpg)

## Loss Functions

We have the following loss functions:

**Perceptual Loss**\
![perceptual loss](assets/perceptual_loss.png)\
**Pixel-wise MSE Loss**\
![MSE loss](assets/MSE_loss.png)\
**Content Loss**\
![content loss](assets/content_loss.png)\
**Adversarial Loss**\
![adversarial loss](assets/adversarial_loss.png)
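In code, these terms are combined into the generator objective with fixed weights, exactly as in main.py below; the function wrapper here is only an illustrative condensation:

```python
import torch.nn as nn

content_criterion = nn.MSELoss()
adversarial_criterion = nn.BCELoss()

def generator_loss(hr_fake, hr_real, fake_features, real_features, d_fake, ones):
    """Pixel-wise MSE + 0.006 * VGG feature MSE + 0.001 * adversarial BCE."""
    errG_mse = content_criterion(hr_fake, hr_real)              # pixel-wise MSE loss
    errG_vgg = content_criterion(fake_features, real_features)  # content (VGG) loss
    errG_adv = adversarial_criterion(d_fake, ones)              # adversarial loss
    return errG_mse + 0.006 * errG_vgg + 0.001 * errG_adv
```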
## Implementation

Following the paper, the SRResNet (generator) is first pre-trained on the MSE loss, after which the generator and discriminator are trained adversarially in alternation (k = 1).
For the perceptual loss, VGG(5,4) is used by default.
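The FeatureExtractor itself lives in models.py, which is not part of this diff; a plausible sketch of a VGG(5,4) extractor (features after the 4th convolution, before the 5th max-pooling) could look like this:

```python
import torch.nn as nn
from torchvision import models

class VGG54(nn.Module):
    """Hypothetical stand-in for models.FeatureExtractor(vgg, 5, 4)."""
    def __init__(self):
        super().__init__()
        vgg = models.vgg19(pretrained=True)
        # features[:36] ends with the ReLU after conv5_4 in torchvision's VGG-19
        self.features = nn.Sequential(*list(vgg.features.children())[:36]).eval()
        for p in self.features.parameters():
            p.requires_grad = False  # keep the extractor frozen during training

    def forward(self, x):
        return self.features(x)
```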
The dataset consists of ~40k images randomly sampled from the ImageNet dataset; 96\*96 patches are cropped randomly from each image.
These 96\*96 patches are bicubically downsampled to 24\*24 and fed into the generator, which in turn generates super-resolution images of size 96\*96 (see dataloader.py below).
## Results

After pre-training the generator for 100 epochs and adversarial training for 200 epochs, the following results are obtained.

> ***NOTE***: Go to the assets folder to view full-sized images. They have been resized here for better readability.

> *x4 refers to an image upscaled four times.*

**Low Resolution (Original)**
<p float="left">
<img src="assets/lr.png" width="280" style="margin:10px">
</p>

**x4 Bicubic Interpolation, High Resolution (Original)**
<p float="left">
<img src="assets/lr_bicubic.png" width="400" style="margin:10px">
<img src="assets/hr.png" width="400" style="margin:10px">
</p>

**x4 *Online Image Enhancer*, x4 SRGAN**
<p>
<img src="assets/lr_letsenhance.png" width="400" style="margin:10px">
<img src="assets/sr.png" width="400" style="margin:10px">
</p>

> I used [letsenhance.io](https://letsenhance.io/), which claims to use a "*Powerful AI to increase image resolution without quality loss*".

The SRGAN-generated image clearly retains more detail and produces a better result.

The current model could achieve far better results given more data and training epochs: I used ~40k images, whereas the authors used 350k images and trained for 10^5 update steps.
(Binary asset files under assets/ omitted.)
dataloader.py (new file, +36 lines)
from torch.utils.data import Dataset
from PIL import Image
import os
from torchvision import transforms

def clean_dataset(dir):
    """Remove images that are not in the RGB colour space."""
    for img in os.listdir(dir):
        path = os.path.join(dir, img)
        im = Image.open(path)
        if im.mode != 'RGB':
            im.close()  # release the file handle before deleting
            os.remove(path)

class TrainDataset(Dataset):

    def __init__(self, dir):
        super().__init__()
        clean_dataset(dir)
        self.img = [os.path.join(dir, x) for x in os.listdir(dir)]
        # HR pipeline: random 96x96 crop of the source image
        self.hr = transforms.Compose([
            transforms.RandomCrop(96, pad_if_needed=True),
            transforms.ToTensor(),
        ])
        # LR pipeline: bicubic downsampling of the HR crop to 24x24
        self.lr = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(24, interpolation=Image.BICUBIC),
            transforms.ToTensor()
        ])

    def __getitem__(self, index):
        hr_image = self.hr(Image.open(self.img[index]))
        lr_image = self.lr(hr_image)
        return lr_image, hr_image

    def __len__(self):
        return len(self.img)
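A minimal usage sketch, mirroring how main.py wires the dataset up (the './data' path is a placeholder):

```python
from torch.utils.data import DataLoader
from dataloader import TrainDataset

dataset = TrainDataset('./data')  # directory of RGB training images
loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=2)
lr, hr = next(iter(loader))       # lr: (B, 3, 24, 24), hr: (B, 3, 96, 96)
```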
main.py (new file, +158 lines)
1+
import torch
2+
import torch.nn as nn
3+
import torch.optim as optim
4+
from torchvision import models, utils
5+
from torch.utils.data import DataLoader
6+
import time
7+
from dataloader import TrainDataset
8+
from models import FeatureExtractor, Generator, Discriminator
9+
from torchsummary import summary
10+
import argparse
11+
import os
12+
import matplotlib.pyplot as plt
13+
from torch.utils.tensorboard import SummaryWriter
14+
15+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16+
17+
parser = argparse.ArgumentParser()
18+
parser.add_argument('--root_dir', default='./', help='path to dataset')
19+
parser.add_argument('--num_workers', type=int, default=2, help='number of data loading workers')
20+
parser.add_argument('--batch_size', type=int, default=128, help='input batch size')
21+
parser.add_argument('--num_epochs', type=int, default=200, help='number of epochs to train for')
22+
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate')
23+
parser.add_argument('--pre_num_epochs', type=int, default=100, help='number of pre-training epochs')
24+
parser.add_argument('--outdir', default='./', help='directory to output model checkpoints')
25+
parser.add_argument('--load_checkpoint', default=0, type=int, help='Pass 1 to load checkpoint')
26+
parser.add_argument('--b', default=16, type=int, help='number of residual blocks in generator')
27+
args = parser.parse_args()
28+
29+
# Load data
30+
dataset = TrainDataset(args.root_dir)
31+
dataloader = DataLoader(dataset, args.batch_size, True, num_workers=args.num_workers)
32+
# Initialize models
33+
vgg = models.vgg19(pretrained=True)
34+
feature_extractor = FeatureExtractor(vgg, 5, 4)
35+
if torch.cuda.device_count() > 1:
36+
feature_extractor = nn.DataParallel(feature_extractor)
37+
feature_extractor = feature_extractor.to(device)
38+
39+
disc = Discriminator()
40+
if torch.cuda.device_count() > 1:
41+
disc = nn.DataParallel(disc)
42+
disc = disc.to(device)
43+
if args.load_checkpoint == 1 and os.path.exists('disc.pt'):
44+
disc.load_state_dict(torch.load('disc.pt'))
45+
print(disc)
46+
47+
gen = Generator(args.b)
48+
if torch.cuda.device_count() > 1:
49+
gen = nn.DataParallel(gen)
50+
gen = gen.to(device)
51+
if args.load_checkpoint == 1 and os.path.exists('gen.pt'):
52+
gen.load_state_dict(torch.load('gen.pt'))
53+
print(gen)
54+
55+
content_criterion = nn.MSELoss()
56+
adversarial_criterion = nn.BCELoss()
57+
optimG = optim.Adam(gen.parameters(), args.lr)
58+
schedulerG1 = optim.lr_scheduler.MultiStepLR(optimG, [100], 0.1)
59+
schedulerG2 = optim.lr_scheduler.MultiStepLR(optimG, [100], 0.1)
60+
optimD = optim.Adam(disc.parameters(), args.lr)
61+
schedulerD = optim.lr_scheduler.MultiStepLR(optimD, [100], 0.1)
62+
writer = SummaryWriter()
63+
# Generator pre-training
64+
start_time = time.time()
65+
iters = 0
66+
for epoch in range(args.pre_num_epochs):
67+
68+
for i, data in enumerate(dataloader, 0):
69+
70+
lr, hr_real = data
71+
hr_real = hr_real.to(device)
72+
lr = lr.to(device)
73+
74+
batch_size = hr_real.size()[0]
75+
hr_fake = gen(lr)
76+
77+
gen.zero_grad()
78+
gen_content_loss = content_criterion(hr_fake, hr_real)
79+
gen_content_loss.backward()
80+
optimG.step()
81+
82+
if i == 0:
83+
print(f'[{epoch}/{args.pre_num_epochs}][{i}/{len(dataloader)}] Gen_MSE: {gen_content_loss.item()}')
84+
iters += 1
85+
86+
torch.save(gen.state_dict(), f'{args.outdir}gen.pt')
87+
schedulerG1.step()
88+
print(f'Time Elapsed: {(time.time()-start_time): .2f}')
89+
90+
# Adversarial Training
91+
G_losses = []
92+
D_losses = []
93+
iters = 0
94+
optimG = optim.Adam(gen.parameters(), args.lr)
95+
for epoch in range(args.num_epochs):
96+
97+
for i, data in enumerate(dataloader):
98+
iters += 1
99+
lr, hr_real = data
100+
batch_size = hr_real.size()[0]
101+
hr_real = hr_real.to(device)
102+
lr = lr.to(device)
103+
hr_fake = gen(lr)
104+
105+
# Label Smoothing (Salimans et. al. 2016)
106+
target_real = torch.rand(batch_size, 1, device=device)*0.85+0.3
107+
target_fake = torch.rand(batch_size, 1, device=device)*0.15
108+
109+
# Discriminator
110+
disc.zero_grad()
111+
D_x = disc(hr_real)
112+
D_G_z1 = disc(hr_fake.detach())
113+
errD_real = adversarial_criterion(D_x, target_real)
114+
errD_fake = adversarial_criterion(D_G_z1, target_fake)
115+
errD = errD_real + errD_fake
116+
D_x = D_x.view(-1).mean().item()
117+
D_G_z1 = D_G_z1.view(-1).mean().item()
118+
errD.backward()
119+
optimD.step()
120+
121+
# Generator
122+
gen.zero_grad()
123+
real_features = feature_extractor(hr_real)
124+
fake_features = feature_extractor(hr_fake)
125+
ones = torch.ones(batch_size, 1, device=device)
126+
127+
errG_mse = content_criterion(hr_fake, hr_real)
128+
errG_vgg = content_criterion(fake_features, real_features)
129+
D_G_z2 = disc(hr_fake)
130+
errG_adv = adversarial_criterion(D_G_z2, ones)
131+
errG = errG_mse + 0.006*errG_vgg + 0.001*errG_adv
132+
D_G_z2 = D_G_z2.view(-1).mean().item()
133+
errG.backward()
134+
optimG.step()
135+
if i == 0:
136+
print(f'[{epoch}/{args.num_epochs}][{i}/{len(dataloader)}] errD: {errD.item():.4f}'
137+
f' errG: {errG.item():.4f} ({errG_mse.item():.4f}/{0.006*errG_vgg.item():.4f}/{0.001*errG_adv.item():.4f})'
138+
f' D(HR): {D_x :.4f} D(G(LR1)): {D_G_z1:.4f} D(G(LR2)): {D_G_z2:.4f}')
139+
140+
G_losses.append(errG.item())
141+
D_losses.append(errD.item())
142+
143+
torch.save(gen.state_dict(), f'{args.outdir}gen.pt')
144+
torch.save(disc.state_dict(), f'{args.outdir}disc.pt')
145+
print(f'Time Elapsed: {(time.time()-start_time): .2f}')
146+
schedulerD.step()
147+
schedulerG2.step()
148+
149+
print(f'Finished Training {args.num_epochs} epochs')
150+
151+
plt.figure(figsize=(10,5))
152+
plt.title("Generator and Discriminator Loss During Training")
153+
plt.plot(G_losses,label="G")
154+
plt.plot(D_losses,label="D")
155+
plt.xlabel("Iterations")
156+
plt.ylabel("Loss")
157+
plt.legend()
158+
plt.show()
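For completeness, a hypothetical inference sketch (not part of this commit) that loads the saved checkpoint and super-resolves a single image; 'input.png' is a placeholder, and if the model was trained under nn.DataParallel the state-dict keys carry a 'module.' prefix that must be stripped first:

```python
import torch
from PIL import Image
from torchvision import transforms, utils
from models import Generator

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gen = Generator(16).to(device)  # 16 residual blocks, the --b default
gen.load_state_dict(torch.load('gen.pt', map_location=device))
gen.eval()

# Load an LR image and add a batch dimension
lr = transforms.ToTensor()(Image.open('input.png').convert('RGB')).unsqueeze(0).to(device)
with torch.no_grad():
    sr = gen(lr)  # x4 upscaled output
utils.save_image(sr, 'sr.png')
```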
