
Commit b3c08af

Add VanillaGAN tensorflow files (pclubiitk#17)
* Add Gan folder,VanillaGAN-tensorflow
* Create README.md
* Update README.md
* Update files and compress gif
* changes applied

Co-authored-by: ashishpm <[email protected]>
1 parent 08a0b56 commit b3c08af

12 files changed: +310 -0 lines changed

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
# TensorFlow Implementation of VanillaGAN on MNIST Dataset

### Usage
```bash
$ python3 main.py --epochs 50 --batch_size 128 --outdir "." --encoding_dims 100
```
NOTE: In a Colab notebook, use the following commands instead:
```bash
!git clone link-to-repo
%run main.py --epochs 50 --batch_size 128 --outdir "." --encoding_dims 100
```

## Help Log
```
usage: main.py [-h] [--epochs EPOCHS] [--batch_size BATCH_SIZE] --outdir
               OUTDIR [--learning_rate LEARNING_RATE] [--beta_1 BETA_1]
               --encoding_dims ENCODING_DIMS

optional arguments:
  -h, --help            show this help message and exit
  --epochs EPOCHS
  --batch_size BATCH_SIZE
  --outdir OUTDIR
  --learning_rate LEARNING_RATE
  --beta_1 BETA_1
  --encoding_dims ENCODING_DIMS
```
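
The help log above corresponds to an `argparse` parser along the following lines. This is a minimal sketch that mirrors the flag names and defaults used in `main.py`. The `build_parser` helper name is hypothetical, and the explicit argument list passed to `parse_args` stands in for a real command line.

```python
import argparse


def build_parser():
    # Hypothetical helper; flag names and defaults mirror main.py.
    parser = argparse.ArgumentParser(prog="main.py")
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--outdir", type=str, required=True)
    parser.add_argument("--learning_rate", type=float, default=0.0002)
    parser.add_argument("--beta_1", type=float, default=0.5)
    parser.add_argument("--encoding_dims", type=int, required=True)
    return parser


# Passing a list instead of reading sys.argv makes the example easy to try.
args = build_parser().parse_args(["--outdir", ".", "--encoding_dims", "100"])
print(args.epochs, args.batch_size, args.outdir, args.encoding_dims)
```

Because `--outdir` and `--encoding_dims` are marked as required, leaving either one out makes `argparse` print a usage error and exit.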

### Contributed by:
* [Ashish Murali](https://github.com/ashishmurali)

# References :

* **Title**: Generative Adversarial Networks
* **Authors**: Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, Yoshua Bengio
* **Link**: http://arxiv.org/abs/1406.2661
* **Tags**: Neural Network, GAN, generative models, unsupervised learning
* **Year**: 2014

# Summary

* What are GANs
  * GANs are based on adversarial training.
  * Adversarial training is a basic technique for training generative models (here, primarily models that create new images).
  * In adversarial training, one model (G, the generator) generates things (e.g. images). Another model (D, the discriminator) sees real things (e.g. real images) as well as fake things (e.g. images from G) and has to learn to tell the two apart.
  * Neural networks are models that can be trained in an adversarial way (and are the only models discussed here).

* Basic architecture of GANs
  * G is a simple neural net (e.g. just one fully connected hidden layer). It takes a vector as input (e.g. 100 dimensions) and produces an image as output.
  * D is a simple neural net (e.g. just one fully connected hidden layer). It takes an image as input and produces a quality rating as output (0-1, so sigmoid).
  * You need a training set of things to be generated, e.g. images of human faces.
  * Let the batch size be B.
  * G is trained the following way:
    * Create B vectors of 100 random values each, e.g. sampled uniformly from [-1, +1]. (The number of values per vector depends on the chosen input size of G.)
    * Feed forward the vectors through G to create new images.
    * Feed forward the images through D to create ratings.
    * Use a cross entropy loss on these ratings. D should rate all of these (fake) images as label=0, so G is trained against the opposite target, label=1: if D gives them label=1, the error will be low (G did a good job).
    * Perform a backward pass of the errors through D (without training D). That generates gradients/errors per image and pixel.
    * Perform a backward pass of these errors through G to train G.
  * D is trained the following way:
    * Create B/2 images using G (again, B/2 random vectors, feed forward through G).
    * Choose B/2 images from the training set. Real images get label=1.
    * Merge the fake and real images into one batch. Fake images get label=0.
    * Feed forward the batch through D.
    * Measure the error using cross entropy.
    * Perform a backward pass with the error through D.
  * Train G for one batch, then D for one (or more) batches. Sometimes D can be too slow to catch up with G; then you need more iterations of D per batch of G.

* Results
  * Good-looking images of MNIST digits and human faces. (Grayscale, rather homogeneous datasets.)
  * Not-so-good-looking images of CIFAR-10. (Color, rather heterogeneous datasets.)

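
The label conventions in the training steps above can be made concrete with a few lines of plain Python. This is a minimal sketch, not the code from this repository: `bce` is a hypothetical helper implementing standard binary cross-entropy, and the ratings are made-up numbers standing in for D's sigmoid outputs. D is scored against label 1 for real images and label 0 for fake ones, while G is scored on the same fake images against label 1.

```python
import math


def bce(ratings, label):
    """Mean binary cross-entropy of D's ratings against one shared label."""
    eps = 1e-7  # keeps log() finite if a rating is exactly 0 or 1
    total = 0.0
    for r in ratings:
        r = min(max(r, eps), 1.0 - eps)
        total += -(label * math.log(r) + (1 - label) * math.log(1.0 - r))
    return total / len(ratings)


# Made-up ratings from D (1 = "looks real", 0 = "looks fake").
ratings_on_real = [0.9, 0.8, 0.95]
ratings_on_fake = [0.2, 0.1, 0.3]

# D's objective: real images should score 1, fake images should score 0.
d_loss = 0.5 * (bce(ratings_on_real, 1) + bce(ratings_on_fake, 0))

# G's objective: the same fake images, but scored against label 1,
# so G's loss drops only when D is fooled.
g_loss = bce(ratings_on_fake, 1)

print(f"D loss: {d_loss:.3f}, G loss: {g_loss:.3f}")
```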

-------------------------
# Our implementation :

* We have implemented the GAN model with the following architectures:

* Generator Architecture

![Generator](https://github.com/ashishmurali/model-zoo/blob/master/generative_models/VanillaGAN_TensorFlow/assets/generator_architecture.png)

* Discriminator Architecture

![Discriminator](https://github.com/ashishmurali/model-zoo/blob/master/generative_models/VanillaGAN_TensorFlow/assets/discriminator_architecture.png)

# Results of our implementation :

* The following GIF shows how the digits generated by our model improve over 400 epochs of training

![gif](https://github.com/ashishmurali/model-zoo/blob/master/generative_models/VanillaGAN_TensorFlow/assets/gan.gif)

* The image generated by our model after the first epoch

![epoch1](https://github.com/ashishmurali/model-zoo/blob/master/generative_models/VanillaGAN_TensorFlow/assets/gan_image%201.png)

* The image generated by our model after the 400th epoch

![epoch400](https://github.com/ashishmurali/model-zoo/blob/master/generative_models/VanillaGAN_TensorFlow/assets/gan_image%20400.png)

* The Generator loss for our model

![gloss](https://github.com/ashishmurali/model-zoo/blob/master/generative_models/VanillaGAN_TensorFlow/assets/generator_loss.png)

* The Discriminator loss for our model

![dloss](https://github.com/ashishmurali/model-zoo/blob/master/generative_models/VanillaGAN_TensorFlow/assets/discriminator_loss.png)

### Sources:
* [Papers](https://github.com/aleju/papers/blob/master/neural-nets/Generative_Adversarial_Networks.md)
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
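import numpy as np
from tensorflow.keras.datasets import mnist


def load_data():
    """Return the MNIST training images as flat float32 vectors in [-1, 1]."""
    # Only the training images are used; labels and the test split are discarded.
    (x_train, _), (_, _) = mnist.load_data()
    # Scale pixel values from [0, 255] to [-1, 1].
    x_train = (x_train.astype(np.float32) - 127.5) / 127.5
    # Flatten each 28x28 image into a 784-dimensional vector.
    x_train = x_train.reshape(-1, 784)
    return x_train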
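
The scaling in `load_data` maps 8-bit pixel values from [0, 255] to [-1, 1]. The plain-Python sketch below shows that mapping and its inverse on single values. Both function names are hypothetical and are not part of this commit.

```python
def to_model_range(pixel):
    """Map an 8-bit pixel value in [0, 255] to [-1, 1], as load_data does."""
    return (float(pixel) - 127.5) / 127.5


def to_pixel_range(value):
    """Inverse mapping: [-1, 1] back to an integer pixel in [0, 255]."""
    return int(round(value * 127.5 + 127.5))


for p in (0, 128, 255):
    scaled = to_model_range(p)
    print(p, "->", scaled, "->", to_pixel_range(scaled))
```

If the generator's output layer is a `tanh` (an assumption here, since `models.py` is not shown in this excerpt), its outputs fall in the same range, and a function like `to_pixel_range` converts them back for plotting.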
Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
import os
import argparse
import numpy as np

from dataloader import load_data
from utils import plot_generated_images, make_gif
from models import create_generator, create_gan, create_discriminator


def run_from_ipython():
    """Return True when running inside IPython/Jupyter (e.g. Colab)."""
    try:
        __IPYTHON__
        return True
    except NameError:
        return False


ipython = run_from_ipython()

if ipython:
    from IPython import display

parser = argparse.ArgumentParser()

parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--batch_size', type=int, default=128)
# A default is never used for a required flag, so none is given.
parser.add_argument('--outdir', type=str, required=True)
parser.add_argument('--learning_rate', type=float, default=0.0002)
parser.add_argument('--beta_1', type=float, default=0.5)
parser.add_argument('--encoding_dims', type=int, required=True)

args = parser.parse_args()

epochs = args.epochs
batch_size = args.batch_size
outdir = args.outdir
learning_rate = args.learning_rate
beta_1 = args.beta_1
encoding_dims = args.encoding_dims

# Create the output directory if it does not exist yet.
os.makedirs(outdir, exist_ok=True)


def training(epochs, batch_size):
    X_train = load_data()
    batch_count = X_train.shape[0] // batch_size

    generator = create_generator(learning_rate, beta_1, encoding_dims)
    discriminator = create_discriminator(learning_rate, beta_1)
    gan = create_gan(discriminator, generator, encoding_dims)

    # Targets for the discriminator: 1 for real images, 0 for generated ones.
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    # Noise vectors drawn once and passed to the plotting helper every epoch.
    seed = np.random.normal(0, 1, [25, encoding_dims])

    for e in range(1, epochs + 1):
        print("Epoch %d" % e)
        for _ in range(batch_count):
            # Generate a batch of fake images from random noise.
            noise = np.random.normal(0, 1, [batch_size, encoding_dims])
            generated_images = generator.predict(noise)

            # Draw a random batch of real images (sampled with replacement).
            image_batch = X_train[np.random.randint(low=0, high=X_train.shape[0], size=batch_size)]

            # Train the discriminator on the real and fake batches separately.
            discriminator.trainable = True
            d_loss_real = discriminator.train_on_batch(image_batch, valid)
            d_loss_fake = discriminator.train_on_batch(generated_images, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            noise = np.random.normal(0, 1, [batch_size, encoding_dims])

            # Train the generator through the combined model; the "valid"
            # targets reward images that the discriminator rates as real.
            discriminator.trainable = False
            g_loss = gan.train_on_batch(noise, valid)

        print("%d [D loss: %f] [G loss: %f]" % (e, d_loss, g_loss))
        if ipython:
            display.clear_output(wait=True)
        plot_generated_images(e, generator, seed, outdir)
    generator.save(os.path.join(outdir, 'gan_model'))


training(epochs, batch_size)

make_gif(outdir)
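
One detail of the training loop above: each real batch is drawn with `np.random.randint`, that is, with replacement, so an "epoch" here is a fixed number of random batches rather than one full pass over the data. The standard-library sketch below estimates how much of the training set one such epoch touches. It is an illustration only: the fixed seed and the variable names are assumptions, and `random.randrange` stands in for the NumPy call.

```python
import random

num_images = 60000   # size of the MNIST training set
batch_size = 128     # default in main.py
batch_count = num_images // batch_size  # batches per epoch

rng = random.Random(0)  # fixed seed so the run is repeatable
seen = set()
for _ in range(batch_count):
    # Same idea as np.random.randint(low=0, high=num_images, size=batch_size).
    batch_indices = [rng.randrange(num_images) for _ in range(batch_size)]
    seen.update(batch_indices)

coverage = len(seen) / num_images
print(f"{batch_count} batches, {coverage:.1%} of images seen at least once")
```

With roughly as many draws as there are images, the expected coverage is close to 1 - 1/e, or about 63%. Shuffling the indices once per epoch and slicing consecutive batches would visit every image instead, at the cost of changing the sampling scheme used in this commit.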
