Skip to content

Commit e7e9205

Browse files
authored
DCGAN_pytorch (pclubiitk#18)
* Create README.md * Update README.md * Update README.md * Update README.md * Update README.md * Create temp * Add files via upload * Create models.py * Create utils.py * Create main.py * Update utils.py * Update main.py * Update main.py * Update main.py * Delete temp * Delete generated img.png * Update main.py * Update main.py * Update main.py * Update README.md * Update README.md * name changed to assets * Update README.md
1 parent 6f51021 commit e7e9205

File tree

8 files changed

+337
-0
lines changed

8 files changed

+337
-0
lines changed
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# DCGAN implementation in pytorch on MNIST
2+
3+
## Usage
4+
```bash
5+
$ python3 main.py --num_epochs 10
6+
```
7+
> **_NOTE:_** on Colab Notebook use following command:
8+
```python
9+
!git clone link-to-repo
10+
%run main.py --num_epochs 10
11+
```
12+
13+
## Help log
14+
```
15+
usage: main.py [-h] [--num_epochs NUM_EPOCHS]
16+
[--batch_size BATCH_SIZE]
17+
[--channels_noise CHANNELS_NOISE]
18+
[--lr_g LR_G][--lr_d LR_D]
19+
[--beta1 BETA1]
20+
21+
optional arguments:
22+
-h, --help show this help message and exit
23+
--num_epochs NUM_EPOCHS no. of epochs : default=10
24+
--batch_size BATCH_SIZE batch size : default=128
25+
--channels_noise CHANNELS_NOISE size of noise vector : default=100
26+
--lr_g LR_G learning rate generator : default=0.0002
27+
--lr_d LR_D learning rate discriminator : default=0.0002
28+
--beta1 BETA1 bet1 value for adam optimizer
29+
30+
```
31+
32+
## contributed by :
33+
* [Nakul Jindal](https://github.com/nakul-jindal)
34+
35+
## References
36+
37+
* **Title**: UNSUPERVISED REPRESENTATION LEARNING WITH DEEP CONVOLUTIONAL GENERATIVE ADVERSARIAL NETWORKS
38+
* **Authors**: Alec Radford, Luke Metz, Soumith Chintala
39+
* **Link**: https://arxiv.org/pdf/1511.06434.pdf
40+
* **Tags**: Neural Network, Generative Networks, GANs
41+
* **Year**: 2015
42+
43+
44+
# summary
45+
46+
## Introduction
47+
* Deep Convolution Generative Adversarial Networks (DCGANs) belong to a set of algorithms called generative models, which are widely used for unupervised learning tasks which aim to learn the underlying structure of the given data.
48+
49+
* Simple GANs allow you to generate new unseen data that mimic the actual given real data. However, GANs pose problems in training and require carefullly tuned hyperparameters.
50+
51+
* DCGAN aims to solve this problem by explicitly using convolutional and convolutional-transpose layers in the discriminator and generator, respectively.
52+
53+
* DCGANs basically convert the laplacian pyramid technique (many pairs of G and D to progressively upscale an image) to a single pair of G and D.
54+
55+
## Generator
56+
57+
* The generator `G` is designed to map the latent space vector `z` (random noise) to data-space (images same as training images)
58+
* involves a series of transpose Conv2d layers, each with BatchNorm2d and relu activation.
59+
* The output of the generator is fed through a tanh function to return it to the input data range of `[-1,1]`.
60+
61+
## Discriminator
62+
63+
* The discriminator `D` is a binary classification network that takes an image as input and outputs a scalar probability that the input image is real or fake.
64+
* `D` involves a series of Conv2d, BatchNorm2d, and LeakyReLU layers.
65+
* outputs the final probability through a Sigmoid activation function.
66+
67+
68+
> The DCGAN paper mentions it is a good practice to use strided convolution rather than pooling to downsample because it lets the network learn its own pooling function. Also batch norm and leaky relu functions promote healthy gradient flow which is critical for the learning process of both `G` and `D`.
69+
70+
71+
## Model Architecture
72+
73+
![architecture](assets/architecture.png)
74+
75+
## Network Design of DCGAN:
76+
* Replace all pooling layers with strided convolutions for the downsampling
77+
* Remove all fully connected layers.
78+
* Use transposed convolutions for upsampling.
79+
* Use Batch Normalization after every layer except after the output layer of the generator and the input layer of the discriminator.
80+
* Use ReLU non-linearity for each layer in the generator except for output layer use tanh.
81+
* Use Leaky-ReLU non-linearity for each layer of the disciminator excpet for output layer use sigmoid.
82+
83+
## Hyperparameters for this Implementation
84+
Hyperparameters are chosen as given in the paper.
85+
* mini-batch size: 128
86+
* learning rate: 0.0002
87+
* momentum term beta1: 0.5
88+
* slope of leak of LeakyReLU: 0.2
89+
* For the optimizer Adam (with beta2 = 0.999) has been used instead of SGD as described in the paper.
90+
91+
## MNIST vs Generated images
92+
93+
<table align='center'>
94+
<tr align='center'>
95+
<td> MNIST </td>
96+
<td> DCGAN after 10 epochs </td>
97+
</tr>
98+
<tr>
99+
<td><img src = 'assets/raw_MNIST.png'>
100+
<td><img src = 'assets/MNIST_DCGAN_10.png'>
101+
</tr>
102+
</table>
103+
104+
## Training loss
105+
106+
![Loss](assets/loss.png)
107+
108+
## contributions of the research paper
109+
110+
* proposes and evaluates Deep Convolutional GANs (DCGAN) which are a set of constraints on the architectural topology of Convolutional
111+
GANs that make them stable to train in most settings.
112+
113+
* use of trained discriminators for image classification tasks, showing competitive performance with other unsupervised algorithms.
114+
115+
* visualize the filters learnt by GANs and empirically show that specific filters have learned to draw specific objects.
116+
117+
* show that the generators have interesting vector arithmetic properties allowing for easy manipulation of many semantic qualities of generated samples.
118+
119+
## Conclusion of research paper
120+
121+
This paper shows how convolutional layers can be used with GANs and provides a series of additional architectural guidelines for doing this. The paper also discusses topics such as Visualizing GAN features, Latent space interpolation, using discriminator features to train classifiers, and evaluating results. The paper contains many examples of images generated by final and intermediate layers of the network.
122+
123+
#### key observations
124+
* Images in the latent space do not show sharp transitions indicating that network did not memorize images.
125+
* DCGAN can learn an interesting hierarchy of features.
126+
* Networks seems to have some success in disentangling image representation from object representation.
127+
* Vector arithmetic can be performed on the Z vectors corresponding to the face samples to get results like `smiling woman - normal woman + normal man = smiling man` visually.
Loading
Loading
22.1 KB
Loading
Loading
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import torch
2+
import torchvision
3+
import torch.nn as nn # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
4+
import torch.optim as optim # For all Optimization algorithms, SGD, Adam, etc.
5+
import torchvision.datasets as datasets # Has standard datasets we can import in a nice way
6+
import torchvision.transforms as transforms # Transformations we can perform on our dataset
7+
from torch.utils.data import DataLoader # Gives easier dataset managment and creates mini batches
8+
9+
import argparse
10+
parser = argparse.ArgumentParser()
11+
parser.add_argument('--num_epochs', type=int, default=10 , help="no. of epochs : default=10")
12+
parser.add_argument('--batch_size', type=int, default=128, help="batch size : default=128")
13+
parser.add_argument('--channels_noise', type=int, default=100, help="size of noise vector : default=100")
14+
parser.add_argument('--lr_g', type=float, default=0.0002, help="learning rate generator : default=0.0002")
15+
parser.add_argument('--lr_d', type=float, default=0.0002, help="learning rate discriminator : default=0.0002")
16+
parser.add_argument('--beta1', type=float, default=0.5, help="bet1 value for adam optimizer" )
17+
args = parser.parse_args()
18+
19+
lr_g = args.lr_g
20+
lr_d = args.lr_d
21+
beta1 = args.beta1
22+
batch_size = args.batch_size
23+
channels_noise = args.channels_noise
24+
num_epochs = args.num_epochs
25+
26+
image_size = 64
27+
features_d = 128
28+
features_g = 128
29+
30+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31+
32+
my_transforms = transforms.Compose([
33+
transforms.Resize(image_size),
34+
transforms.ToTensor(),
35+
transforms.Normalize((0.5,),(0.5,)),
36+
])
37+
38+
dataset = datasets.MNIST(root='dataset/', train=True, transform=my_transforms, download=True)
39+
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
40+
41+
from models import Generator , Discriminator , weights_init
42+
netD = Discriminator(channels_img, features_d).to(device)
43+
netG = Generator(channels_noise, channels_img, features_g).to(device)
44+
netG=netG.apply(weights_init)
45+
netD=netD.apply(weights_init)
46+
47+
optimizerD = optim.Adam(netD.parameters(), lr=lr_d, betas=(beta1, 0.999) )
48+
optimizerG = optim.Adam(netG.parameters(), lr=lr_g, betas=(beta1, 0.999) )
49+
50+
criterion = nn.BCELoss()
51+
52+
real_label = 1
53+
fake_label = 0
54+
55+
img_list = []
56+
G_losses = []
57+
D_losses = []
58+
59+
for epoch in range(num_epochs):
60+
for batch_idx, (data, targets) in enumerate(dataloader):
61+
62+
data = data.to(device)
63+
batch_size = data.shape[0]
64+
65+
# Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
66+
netD.zero_grad()
67+
label = (torch.ones(batch_size)*0.9).to(device)
68+
output = netD(data).view(-1)
69+
lossD_real = criterion(output, label)
70+
D_x = output.mean().item()
71+
72+
noise = torch.randn(batch_size, channels_noise, 1, 1).to(device)
73+
fake = netG(noise)
74+
label = (torch.ones(batch_size)*0.1).to(device)
75+
output = netD(fake.detach()).view(-1)
76+
lossD_fake = criterion(output, label)
77+
78+
lossD = lossD_real + lossD_fake
79+
lossD.backward()
80+
optimizerD.step()
81+
82+
# Train Generator: max log(D(G(z)))
83+
netG.zero_grad()
84+
label = torch.ones(batch_size).to(device)
85+
output = netD(fake).reshape(-1)
86+
lossG = criterion(output, label)
87+
lossG.backward()
88+
optimizerG.step()
89+
D_G_x = output.mean().item()
90+
91+
92+
if batch_idx % 100 == 0:
93+
# Print losses ocassionally
94+
print(f'Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(dataloader)} Loss D: {lossD:.4f} loss G: {lossG:.4f} D(x): {D_x:.4f} D(G(z)): {D_G_x:.4f} ')
95+
G_losses.append(lossG.item())
96+
D_losses.append(lossD.item())
97+
98+
# Check how the generator is doing by saving G's output on fixed_noise
99+
with torch.no_grad():
100+
fake = netG(fixed_noise).detach().cpu()
101+
img_list.append(torchvision.utils.make_grid(fake, padding=2, normalize=True))
102+
103+
104+
105+
from utils import compare_img , plot_loss , animation
106+
plot_loss(G_losses,D_losses) # visualise losses vs iterations
107+
compare_img(data,fake) # compare generated imgs with real mnist images
108+
animation(img_list) # visualise generated images on a fixed noise at intervals
109+
110+
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
# weight initialisation with mean=0 and stddev=0.02
5+
6+
def weights_init(m):
7+
classname = m.__class__.__name__
8+
if classname.find('Conv') != -1:
9+
nn.init.normal_(m.weight.data, 0.0, 0.02)
10+
elif classname.find('BatchNorm') != -1:
11+
nn.init.normal_(m.weight.data, 1.0, 0.02)
12+
nn.init.constant_(m.bias.data, 0)
13+
14+
class Generator(nn.Module):
15+
def __init__(self, channels_noise, channels_img, features_g):
16+
super(Generator, self).__init__()
17+
18+
self.net = nn.Sequential(
19+
20+
nn.ConvTranspose2d(channels_noise, features_g*8, kernel_size=4, stride=1, padding=0 , bias = False ),
21+
nn.BatchNorm2d(features_g*8),
22+
nn.ReLU(True),
23+
24+
nn.ConvTranspose2d(features_g*8, features_g*4, kernel_size=4, stride=2, padding=1 , bias = False ),
25+
nn.BatchNorm2d(features_g*4),
26+
nn.ReLU(True),
27+
28+
nn.ConvTranspose2d(features_g*4, features_g*2, kernel_size=4, stride=2, padding=1 , bias = False),
29+
nn.BatchNorm2d(features_g*2),
30+
nn.ReLU(True),
31+
32+
nn.ConvTranspose2d(features_g*2, features_g, kernel_size=4, stride=2, padding=1 , bias = False),
33+
nn.BatchNorm2d(features_g),
34+
nn.ReLU(True),
35+
36+
nn.ConvTranspose2d(features_g, channels_img, kernel_size=4, stride=2, padding=1 , bias = False ),
37+
nn.Tanh()
38+
)
39+
40+
def forward(self, x):
41+
return self.net(x)
42+
43+
class Discriminator(nn.Module):
44+
def __init__(self, channels_img, features_d):
45+
super(Discriminator, self).__init__()
46+
self.net = nn.Sequential(
47+
48+
nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1 , bias = False ),
49+
nn.LeakyReLU(0.2, inplace=True),
50+
51+
nn.Conv2d(features_d, features_d*2, kernel_size=4, stride=2, padding=1 , bias = False ),
52+
nn.BatchNorm2d(features_d*2),
53+
nn.LeakyReLU(0.2, inplace=True),
54+
55+
nn.Conv2d(features_d*2, features_d*4, kernel_size=4, stride=2, padding=1 , bias = False ),
56+
nn.BatchNorm2d(features_d*4),
57+
nn.LeakyReLU(0.2, inplace=True),
58+
59+
nn.Conv2d(features_d*4, features_d*8, kernel_size=4, stride=2, padding=1 , bias = False ),
60+
nn.BatchNorm2d(features_d*8),
61+
nn.LeakyReLU(0.2, inplace=True),
62+
63+
nn.Conv2d(features_d*8, 1, kernel_size=4, stride=2, padding=0 , bias = False ),
64+
nn.Sigmoid()
65+
)
66+
67+
def forward(self, x):
68+
return self.net(x)
69+
70+
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import numpy as np
2+
import torchvision
3+
import matplotlib.pyplot as plt
4+
import matplotlib.animation as animation
5+
from IPython.display import HTML
6+
7+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
8+
9+
def animation(img_list):
10+
fig = plt.figure(figsize=(8,8))
11+
plt.axis("off")
12+
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
13+
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
14+
HTML(ani.to_jshtml())
15+
16+
def compare_img(data,fake):
17+
img_grid_real = torchvision.utils.make_grid(data[:32], normalize=True).cpu()
18+
img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True).cpu()
19+
plt.imshow(img_grid_fake.permute(1, 2, 0))
20+
plt.imshow(img_grid_real.permute(1, 2, 0))
21+
22+
def plot_loss(G_losses,D_losses):
23+
plt.figure(figsize=(10,5))
24+
plt.title("Generator and Discriminator Loss During Training")
25+
plt.plot(G_losses,label="G")
26+
plt.plot(D_losses,label="D")
27+
plt.xlabel("iterations")
28+
plt.ylabel("Loss")
29+
plt.legend()
30+
plt.show()

0 commit comments

Comments
 (0)