Commit 66dc400

Add InfoGAN, TensorFlow implementation (pclubiitk#7)
* Add Infogan implemented in TensorFlow
* Add README.md, remove/rename files
* Update README.md
* Remove GIFs, add images instead
1 parent 4687027 commit 66dc400

17 files changed (+588, -0 lines changed)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# TensorFlow Implementation of InfoGAN
2+
## Usage
3+
```bash
4+
$ python3 main.py --dataset CIFAR10 --noise_dim 64
5+
```
6+
> **_NOTE:_** In a Colab notebook, use the following commands:
7+
```python
8+
!git clone link-to-repo
9+
%run main.py --dataset CIFAR10 --noise_dim 64
10+
```
11+
12+
## Help Log
13+
```
14+
usage: main.py [-h] [--dataset DATASET] [--epochs EPOCHS]
15+
[--noise_dim NOISE_DIM] [--continuous_weight CONTINUOUS_WEIGHT]
16+
[--batch_size BATCH_SIZE] [--outdir OUTDIR]
17+
18+
optional arguments:
19+
-h, --help show this help message and exit
20+
--dataset DATASET Name of dataset: MNIST (default) or CIFAR10
21+
--epochs EPOCHS No of epochs: default 50 for MNIST, 150 for CIFAR10
22+
--noise_dim NOISE_DIM
23+
No of latent Noise variables, default 62 for MNIST, 64
24+
for CIFAR10
25+
--continuous_weight CONTINUOUS_WEIGHT
26+
Weight given to continuous Latent codes in loss
27+
calculation, default 0.5 for MNIST, 1 for CIFAR10
28+
--batch_size BATCH_SIZE
29+
Batch size, default 256
30+
--outdir OUTDIR Directory in which to store data, don't put '/' at the
31+
end!
32+
```
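The per-dataset defaults listed in the help text have to be filled in after parsing, since they depend on `--dataset`. The sketch below shows one possible way to do that. It assumes `None` as the "not set" sentinel, and the `DEFAULTS` table and the `build_parser` / `resolve_args` helpers are hypothetical names, not part of this repository; only the numeric values come from the help log.

```python
import argparse

# Values quoted from the help log; the table itself is an illustrative assumption.
DEFAULTS = {
    "MNIST": {"epochs": 50, "noise_dim": 62, "continuous_weight": 0.5},
    "CIFAR10": {"epochs": 150, "noise_dim": 64, "continuous_weight": 1.0},
}

def build_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="MNIST", choices=sorted(DEFAULTS))
    # None means "not given on the command line".
    parser.add_argument("--epochs", type=int, default=None)
    parser.add_argument("--noise_dim", type=int, default=None)
    parser.add_argument("--continuous_weight", type=float, default=None)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--outdir", type=str, default=".")
    return parser

def resolve_args(argv):
    """Parse argv and fill unset options with the defaults of the chosen dataset."""
    args = build_parser().parse_args(argv)
    for name, value in DEFAULTS[args.dataset].items():
        if getattr(args, name) is None:
            setattr(args, name, value)
    return args

args = resolve_args(["--dataset", "CIFAR10", "--noise_dim", "64"])
print(args.epochs, args.noise_dim, args.continuous_weight)
```

With a `None` sentinel, an explicit `--continuous_weight 0` is kept as given rather than being treated as "unset".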
33+
34+
## Contributed by:
35+
* [Atharv Singh Patlan](https://github.com/AthaSSiN)
36+
37+
## References
38+
39+
* **Title**: InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets
40+
* **Authors**: Xi Chen, Yan Duan, Rein Houthooft, John Schulman, Ilya Sutskever, Pieter Abbeel
41+
* **Link**: https://arxiv.org/pdf/1606.03657.pdf
42+
* **Tags**: Neural Network, Generative Networks, GANs
43+
* **Year**: 2016
44+
45+
# Summary
46+
47+
## Introduction
48+
49+
Generative adversarial nets were recently introduced as a novel way to train a generative model.
50+
They consist of two ‘adversarial’ models: a generative model G that captures the data distribution, and a discriminative model D that estimates the probability that a sample came from the training data rather than G.
51+
52+
However, the GAN described above, often called a vanilla GAN, gives us no control over what it generates: we cannot ask it for an image that meets our specifications, such as a particular class. To get that control, we need to structure the noise fed to the GAN and define a training scheme in which the GAN learns to classify an image as belonging to a given class, in addition to determining whether it is real or fake.
53+
54+
Enter InfoGAN!
55+
56+
## InfoGAN
57+
58+
The idea is to provide a latent code that has meaningful and consistent effects on the output. For instance, consider the MNIST dataset, which has 10 digit classes. It would be helpful to exploit this structure so that each digit is associated with a particular value of the code. This can be done by assigning part of the input to a 10-state discrete variable. The hope is that if you keep the code fixed and randomly change the noise, you get variations of the same digit.
59+
60+
The way InfoGAN approaches this problem is by splitting the Generator input into two parts: the traditional noise vector and a new “latent code” vector. The codes are then made meaningful by maximizing the __Mutual Information__ between the code and the generator output.
61+
62+
![Eqn1](https://miro.medium.com/max/552/1*rSZXfx4_xcC-5z4LirNDRQ.png)
63+
64+
Here *V(D,G)* is the standard vanilla GAN loss and *I(c;G(z,c))* is the mutual information term, with Lambda acting as a regularization constant (the mutual information term can be seen as a regularizer).
65+
66+
However, calculating *I(c;G(z,c))* requires sampling from the posterior distribution of the latent codes, which is usually intractable. We therefore replace it with a lower bound, calculated by approximating the posterior with an auxiliary distribution *Q(c|x)* and using the reparameterization trick.
67+
68+
![Eqn2](https://miro.medium.com/max/552/1*NTYmbgNBT9RzhdLl71-koA.png)
69+
70+
Where
71+
![Eqn3](https://miro.medium.com/max/552/1*92L-ml_k7iQcPIWcvT7TIw.png)
72+
73+
Hence the final form of the loss function becomes:
74+
![Eqn4](https://miro.medium.com/max/552/1*W2G0DFBQUa52Piy1snYVjQ.png)
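For readers who cannot load the equation images, the objective can be written out as follows. This is transcribed from the InfoGAN paper cited in the References rather than from the images themselves, so the notation may differ slightly from what the images show. *H(c)* denotes the entropy of the latent code prior and *L_I(G,Q)* the variational lower bound.

```latex
% Information-regularized minimax game
\min_{G}\max_{D}\; V_I(D,G) = V(D,G) - \lambda\, I\bigl(c;\, G(z,c)\bigr)

% Variational lower bound on the mutual information
I\bigl(c;\, G(z,c)\bigr) \;\ge\;
  \mathbb{E}_{c \sim P(c),\; x \sim G(z,c)}\bigl[\log Q(c \mid x)\bigr] + H(c)
  \;=\; L_I(G,Q)

% Final objective, with the bound substituted for the mutual information
\min_{G,Q}\max_{D}\; V_{\text{InfoGAN}}(D,G,Q) = V(D,G) - \lambda\, L_I(G,Q)
```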
75+
76+
Thus, the problem basically reduces to the following process:
77+
1. Sample a value for the latent code c from a prior of your choice
78+
2. Sample a value for the noise z from a prior of your choice
79+
3. Generate x = *G(c,z)*
80+
4. Calculate *Q(c|x=G(c,z))*
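The four steps above can be sketched in NumPy as a Monte Carlo estimate of the lower bound. This is only an illustration: the random linear maps `W_g` and `W_q` stand in for the real generator and auxiliary network, the standard normal noise prior is an assumption, and all names are hypothetical rather than taken from `main.py`.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, noise_dim, n_classes, x_dim = 8, 4, 10, 16

# Steps 1-2: sample the latent code c (uniform categorical prior) and the noise z.
c = rng.integers(0, n_classes, size=batch)
c_onehot = np.eye(n_classes)[c]
z = rng.normal(size=(batch, noise_dim))

# Step 3: x = G(c, z). A fixed random linear map stands in for the generator.
W_g = rng.normal(size=(noise_dim + n_classes, x_dim))
x = np.tanh(np.concatenate([z, c_onehot], axis=1) @ W_g)

# Step 4: Q(c | x). A second random linear map plus a log-softmax stands in for Q.
W_q = rng.normal(size=(x_dim, n_classes))
logits = x @ W_q
logits -= logits.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
log_q = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))

# Monte Carlo estimate of E[log Q(c|x)]; adding H(c) = log(10) gives the lower bound.
expected_log_q = log_q[np.arange(batch), c].mean()
lower_bound = expected_log_q + np.log(n_classes)
print(lower_bound)
```

Maximizing `expected_log_q` is the same as minimizing a categorical cross-entropy between the sampled code and the prediction of Q, which is why a cross-entropy loss appears in the training code.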
81+
82+
## Implementation
83+
84+
In the implementation, the generator input consists of a user-defined number of noise variables, 10 categorical latent codes (in the hope that each corresponds to a class of the dataset), and 2 continuous latent codes drawn uniformly from -1 to 1 (in the hope that they correspond to other features of the dataset).
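A minimal sketch of how such an input could be assembled is shown below. The helper `sample_latent_points` is a hypothetical stand-in for `utils.generate_latent_points`, whose source is not shown here. The three return values mirror how `main.py` unpacks the result, but the standard normal noise and the column order `[noise, categorical, continuous]` are assumptions.

```python
import numpy as np

def sample_latent_points(n, noise_dim, categorical_dim=10, continuous_dim=2, rng=None):
    """Illustrative stand-in for utils.generate_latent_points (not the repository's code)."""
    rng = np.random.default_rng() if rng is None else rng
    noise = rng.normal(size=(n, noise_dim))                  # assumed standard normal noise
    labels = rng.integers(0, categorical_dim, size=n)        # one class index per sample
    categorical = np.eye(categorical_dim)[labels]            # one-hot categorical codes
    continuous = rng.uniform(-1.0, 1.0, size=(n, continuous_dim))
    generator_input = np.concatenate([noise, categorical, continuous], axis=1)
    return (generator_input.astype("float32"),
            categorical.astype("float32"),
            continuous.astype("float32"))

gen_in, cat, cont = sample_latent_points(256, 62, rng=np.random.default_rng(0))
print(gen_in.shape, cat.shape, cont.shape)
```

With the MNIST defaults (62 noise variables, 10 categorical and 2 continuous codes), each generator input row has 74 entries under this layout.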
85+
86+
![Model](https://miro.medium.com/max/1104/1*dXLgTV8lNiTInvxomgZSAg.png)
87+
88+
We use the following default configuration:
89+
- Binary CE to calculate the loss in real and fake samples detection
90+
- Categorical CE to calculate the loss in categorical classification
91+
- Ordinary Least Squares (OLS) to calculate the loss in continuous variable detection. (The continuous variables are uniform in the input, but the architecture predicts them in the form of a Gaussian distribution. I first tried outputting the mean and log variance of the predictions and calculating the loss using the reparameterization trick, but after working through the math I realized that it all boils down to calculating the OLS loss of the predicted values.)
92+
- Lambda = 1. The weight given to the loss of the continuous codes can be varied, however (we used 0.5 for MNIST and 1 for CIFAR10)
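The loss terms listed above can be sketched in plain NumPy as follows. The function mirrors the structure of `discriminator_loss` in `main.py`, including the output layout (column 0 is the real/fake logit, the next 10 columns are categorical logits, the rest are continuous predictions) and the factor of 2 in the squared error. The helper names are hypothetical, and this is an illustration rather than a replacement for the Keras losses used in training.

```python
import numpy as np

def bce_from_logits(targets, logits):
    # Numerically stable binary cross-entropy on logits:
    # max(x, 0) - x * t + log(1 + exp(-|x|)), averaged over the batch.
    return np.mean(np.maximum(logits, 0) - logits * targets
                   + np.log1p(np.exp(-np.abs(logits))))

def cce_from_logits(onehot, logits):
    # Categorical cross-entropy via a stable log-softmax.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_softmax = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return np.mean(-(onehot * log_softmax).sum(axis=1))

def info_losses(real_out, fake_out, cat_in, cont_in,
                categorical_dim=10, continuous_weight=0.5):
    real_loss = bce_from_logits(np.ones(len(real_out)), real_out[:, 0])
    fake_loss = bce_from_logits(np.zeros(len(fake_out)), fake_out[:, 0])
    cat_loss = cce_from_logits(cat_in, fake_out[:, 1:1 + categorical_dim])
    cont_loss = np.mean((2 * (fake_out[:, 1 + categorical_dim:] - cont_in)) ** 2)
    total = real_loss + fake_loss + continuous_weight * cont_loss + cat_loss
    return total, real_loss, fake_loss, cat_loss, cont_loss

# An untrained network that outputs all zeros is maximally uncertain.
out = np.zeros((4, 13))
cat_in = np.eye(10)[[0, 1, 2, 3]]
cont_in = np.zeros((4, 2))
total, real_loss, fake_loss, cat_loss, cont_loss = info_losses(out, out, cat_in, cont_in)
print(total, real_loss, fake_loss, cat_loss, cont_loss)
```

For the all-zero example, the real and fake terms each equal log 2 and the categorical term equals log 10, which is a quick sanity check on the formulas.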
93+
94+
# Results
95+
96+
## On MNIST Dataset
97+
98+
Results after training for 50 epochs:
99+
![mnistFinal](https://github.com/AthaSSiN/model-zoo/blob/master/generative_models/InfoGAN_TensorFlow/assets/ReadmeImages/mnistfinal.png)
100+
101+
> **_NOTE:_** In this graph, the orange plot corresponds to the discriminator loss, blue to the generator loss, green to the loss of the continuous variables and gray to the loss of the categorical variables.
102+
103+
104+
Loss:
105+
![mnistloss](https://github.com/AthaSSiN/model-zoo/blob/master/generative_models/InfoGAN_TensorFlow/assets/ReadmeImages/mnistloss.png)
106+
107+
Plot of Real and Fake detection accuracies:
108+
![mnistreal](https://github.com/AthaSSiN/model-zoo/blob/master/generative_models/InfoGAN_TensorFlow/assets/ReadmeImages/mnistrealaccuracy.png)
109+
![mnistfake](https://github.com/AthaSSiN/model-zoo/blob/master/generative_models/InfoGAN_TensorFlow/assets/ReadmeImages/mnistfakeaccuracy.png)
110+
111+
Here is the final image generated by the generator for a randomly generated noise and label, with one continuous code being varied along the rows.
112+
113+
In this one, the tilt in the images seems to change as we move left to right:
114+
![mnisttilt](https://github.com/AthaSSiN/model-zoo/blob/master/generative_models/InfoGAN_TensorFlow/assets/ReadmeImages/mnisttilt.png)
115+
116+
While in this, the thickness of the digits seems to change:
117+
![mnistthick](https://github.com/AthaSSiN/model-zoo/blob/master/generative_models/InfoGAN_TensorFlow/assets/ReadmeImages/mnistthick.png)
118+
119+
Note: In some cases, the digit itself also changes while the continuous codes are varied. I think this is because there are many possible characteristics that the uniform codes can correspond to. It is quite possible that they do not apply only to thickness or tilt, but also to curviness, the number of lines in a digit and so on, so that digits which look similar to each other can be generated by the same categorical code.
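The sweeps shown above keep the noise and the categorical code fixed while one continuous code moves across its range. The sketch below builds such a batch of generator inputs. `sweep_continuous_code` is a hypothetical helper, not the repository's `generate_varying_outputs`, and the column order `[noise, categorical, continuous]` is an assumption about the generator input.

```python
import numpy as np

def sweep_continuous_code(noise_dim, label, code_index, steps=10,
                          categorical_dim=10, continuous_dim=2, rng=None):
    """Build generator inputs that differ only in one continuous code (illustrative)."""
    rng = np.random.default_rng() if rng is None else rng
    # One noise vector, repeated so every column of the figure shares it.
    noise = np.repeat(rng.normal(size=(1, noise_dim)), steps, axis=0)
    # One fixed class, encoded as a one-hot vector and repeated.
    categorical = np.repeat(np.eye(categorical_dim)[[label]], steps, axis=0)
    # All continuous codes stay at 0 except the one being swept from -1 to 1.
    continuous = np.zeros((steps, continuous_dim))
    continuous[:, code_index] = np.linspace(-1.0, 1.0, steps)
    return np.concatenate([noise, categorical, continuous], axis=1).astype("float32")

row = sweep_continuous_code(noise_dim=62, label=3, code_index=0,
                            rng=np.random.default_rng(0))
print(row.shape)
```

Feeding `row` to a trained generator would give one row of images for a single digit class. If the digit identity changes along the row, the swept code has entangled with class information, as described in the note above.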
120+
121+
## On CIFAR10 Dataset
122+
123+
> **_NOTE_**: The paper does not include an implementation for the CIFAR10 dataset, so there was no reference configuration to follow, and hence the results aren't very good.
124+
125+
Results after training for 137 epochs
126+
127+
![cifargif](https://github.com/AthaSSiN/model-zoo/blob/master/generative_models/InfoGAN_TensorFlow/assets/ReadmeImages/CIFARfinal.png)
128+
129+
> **_NOTE:_** In this graph, the blue plot corresponds to the generator loss and orange to the discriminator loss.
130+
131+
Here is the loss graph
132+
![cifarloss](https://github.com/AthaSSiN/model-zoo/blob/master/generative_models/InfoGAN_TensorFlow/assets/ReadmeImages/CIFARloss.png)
133+
134+
Plot of Real and Fake detection accuracies:
135+
![CIFARreal](https://github.com/AthaSSiN/model-zoo/blob/master/generative_models/InfoGAN_TensorFlow/assets/ReadmeImages/CIFARrealaccuracy.png)
136+
![CIFARfake](https://github.com/AthaSSiN/model-zoo/blob/master/generative_models/InfoGAN_TensorFlow/assets/ReadmeImages/CIFARfakeaccuracy.png)
137+
138+
Here is the final image generated by the generator for a randomly generated noise and label.
139+
140+
In this one, the background color varies as we move left to right:
141+
![cifarbg](https://github.com/AthaSSiN/model-zoo/blob/master/generative_models/InfoGAN_TensorFlow/assets/ReadmeImages/CIFARbackground.png)
142+
143+
While in this, the foreground color/size varies:
144+
![cifarfg](https://github.com/AthaSSiN/model-zoo/blob/master/generative_models/InfoGAN_TensorFlow/assets/ReadmeImages/CIFARforeground.png)
145+
146+
It seems the continuous latent codes are working fine, but the categorical codes are not able to represent the different classes very well, so there is room for a lot of experiments!
147+
148+
# Sources
149+
150+
- [InfoGAN — Generative Adversarial Networks Part III](https://towardsdatascience.com/infogan-generative-adversarial-networks-part-iii-380c0c6712cd)
151+
Template on which the code was built:
152+
- [DCGAN on TensorFlow tutorials](https://www.tensorflow.org/tutorials/generative/dcgan)
153+
@@ -0,0 +1,224 @@
1+
import tensorflow as tf
2+
import glob
3+
import matplotlib.pyplot as plt
4+
import os
5+
import time
6+
import datetime
7+
import argparse
8+
from tensorflow.keras import layers
9+
10+
print(tf.__version__)
11+
from utils import run_from_ipython, generate_latent_points, generate_and_save_images, save_gif, generate_varying_outputs
12+
13+
parser = argparse.ArgumentParser()
14+
ipython = run_from_ipython()
15+
16+
if ipython:
17+
from IPython import display
18+
19+
parser.add_argument('--dataset', type = str, default = "MNIST", help = "Name of dataset: MNIST (default) or CIFAR10")
20+
parser.add_argument('--epochs', type = int, default = 0, help = "No of epochs: default 50 for MNIST, 150 for CIFAR10")
21+
parser.add_argument('--noise_dim', type = int, default = 0, help = "No of latent Noise variables, default 62 for MNIST, 64 for CIFAR10")
22+
parser.add_argument('--continuous_weight', type = float, default = 0.0, help = "Weight given to continuous Latent codes in loss calculation, default 0.5 for MNIST, 1 for CIFAR10")
23+
parser.add_argument('--batch_size', type = int, default = 256, help = "Batch size, default 256")
24+
parser.add_argument('--outdir', type = str, default = '.', help = "Directory in which to store data, don't put '/' at the end!")
25+
26+
args = parser.parse_args()
27+
28+
if args.dataset == "MNIST":
29+
from model_MNIST import make_generator_model, make_discriminator_model
30+
(train_images, _), (_, _) = tf.keras.datasets.mnist.load_data()
31+
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
32+
if args.epochs == 0 :
33+
args.epochs = 50
34+
if args.noise_dim == 0 :
35+
args.noise_dim = 62
36+
if args.continuous_weight == 0.0:
37+
args.continuous_weight = 0.5
38+
39+
else :
40+
from model_CIFAR10 import make_generator_model, make_discriminator_model
41+
(train_images, _), (_, _) = tf.keras.datasets.cifar10.load_data()
42+
train_images = train_images.reshape(train_images.shape[0], 32, 32, 3).astype('float32')
43+
if args.epochs == 0 :
44+
args.epochs = 150
45+
if args.noise_dim == 0 :
46+
args.noise_dim = 64
47+
if args.continuous_weight == 0.0:
48+
args.continuous_weight = 1
49+
50+
if not os.path.exists(f"{args.outdir}/assets/{args.dataset}"):
51+
os.makedirs(f"{args.outdir}/assets/{args.dataset}")
52+
53+
#normalizing the images
54+
train_images = (train_images - 127.5) / 127.5
55+
56+
##### DEFINE GLOBAL VARIABLES AND OBJECTS ######
57+
BUFFER_SIZE = 600000
58+
BATCH_SIZE = args.batch_size
59+
epochs = args.epochs
60+
noise_dim = args.noise_dim
61+
continuous_dim = 2
62+
categorical_dim = 10
63+
num_examples_to_generate = 100
64+
continuous_weight = args.continuous_weight
65+
seed, _, _ = generate_latent_points(num_examples_to_generate, noise_dim, categorical_dim, continuous_dim) # A constant sample of latent points so as to create images
66+
67+
# Define Generator
68+
generator = make_generator_model(noise_dim)
69+
print("\nGenerator : ")
70+
print(generator.summary())
71+
discriminator = make_discriminator_model()
72+
print("\nDiscriminator : ")
73+
print(discriminator.summary())
74+
75+
print("Dataset : ", args.dataset)
76+
###########################################
77+
78+
# Converting data to tf Dataset
79+
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
80+
81+
# defining losses
82+
binary_cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
83+
categorical_cross_entropy = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
84+
85+
#defining optimizers
86+
generator_optimizer = tf.keras.optimizers.Adam(1e-3, beta_1=0.5 )
87+
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
88+
89+
#defining storage points for checkpoints
90+
checkpoint_dir = f'{args.outdir}/training_checkpoints'
91+
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
92+
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer, discriminator_optimizer =discriminator_optimizer,generator=generator,discriminator=discriminator)
93+
94+
#defining loss metrics for Plotting purposes with tensorboard
95+
discriminator_loss_metric = tf.keras.metrics.Mean('discriminator_loss', dtype=tf.float32)
96+
# Accuracy on logits: a logit above 0 counts as a "real" prediction
discriminator_real_accuracy_metric = tf.keras.metrics.BinaryAccuracy('discriminator_real_accuracy', threshold=0.0)
97+
discriminator_fake_accuracy_metric = tf.keras.metrics.BinaryAccuracy('discriminator_fake_accuracy', threshold=0.0)
98+
generator_loss_metric = tf.keras.metrics.Mean('generator_loss', dtype=tf.float32)
99+
categorical_loss_metric = tf.keras.metrics.Mean('categorical_loss', dtype=tf.float32)
100+
continuous_loss_metric = tf.keras.metrics.Mean('continuous_loss', dtype=tf.float32)
101+
102+
# Save points for metrics
103+
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
104+
base = f"{args.outdir}/logs/gradientTape/{current_time}"
105+
disc_log_dir = base + '/discriminator'
106+
gen_log_dir = base + '/generator'
107+
cont_log_dir = base + '/cont'
108+
cat_log_dir = base + '/cat'
109+
110+
# Create summary writers
111+
disc_summary_writer = tf.summary.create_file_writer(disc_log_dir)
112+
gen_summary_writer = tf.summary.create_file_writer(gen_log_dir)
113+
cat_summary_writer = tf.summary.create_file_writer(cat_log_dir)
114+
cont_summary_writer = tf.summary.create_file_writer(cont_log_dir)
115+
116+
##################################
117+
# A train step to train the model on a minibatch
118+
119+
def train_step(images):
120+
noise, categorical_input, continuous_input = generate_latent_points(BATCH_SIZE, noise_dim, categorical_dim, continuous_dim)
121+
122+
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
123+
generated_images = generator(noise, training=True)
124+
125+
real_output = discriminator(images, training=True)
126+
fake_output = discriminator(generated_images, training=True)
127+
128+
disc_loss, real_loss, fake_loss, categorical_loss, continuous_loss = discriminator_loss(real_output, fake_output, categorical_input, continuous_input)
129+
gen_loss = generator_loss(fake_output, categorical_loss, continuous_loss)
130+
131+
discriminator_loss_metric(disc_loss)
132+
generator_loss_metric(gen_loss)
133+
discriminator_real_accuracy_metric(tf.ones_like(real_output[:,0]), real_output[:,0])
134+
discriminator_fake_accuracy_metric(tf.zeros_like(fake_output[:,0]), fake_output[:,0])
135+
categorical_loss_metric(categorical_loss)
136+
continuous_loss_metric(continuous_loss)
137+
138+
print(f"Losses - Disc : [{disc_loss}], Gen : [{gen_loss}], \n categorical loss : {categorical_loss}, continuous loss : {continuous_loss}")
139+
140+
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
141+
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
142+
143+
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
144+
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
145+
146+
####################################
147+
148+
def discriminator_loss(real_output, fake_output, categorical_input, continuous_input):
149+
real_loss = binary_cross_entropy(tf.ones_like(real_output[:,0]), real_output[:,0])
150+
fake_loss = binary_cross_entropy(tf.zeros_like(fake_output[:,0]), fake_output[:,0])
151+
152+
categorical_output = fake_output[:,1:1 + categorical_dim]
153+
continuous_output = fake_output[:, 1+categorical_dim : ]
154+
155+
categorical_loss = categorical_cross_entropy(categorical_input, categorical_output)
156+
continuous_loss = tf.reduce_mean((2*(continuous_output - continuous_input))**2)
157+
158+
total_loss = real_loss + fake_loss + continuous_weight*continuous_loss + categorical_loss
159+
return total_loss, real_loss, fake_loss, categorical_loss, continuous_loss
160+
161+
#####################################
162+
163+
def generator_loss(fake_output, categorical_loss, continuous_loss):
164+
gen_loss = binary_cross_entropy(tf.ones_like(fake_output[:,0]), fake_output[:,0])
165+
return gen_loss + continuous_weight*continuous_loss + categorical_loss
166+
167+
#####################################
168+
169+
def main():
170+
# begin the training loop
171+
172+
for epoch in range(epochs):
173+
start = time.time()
174+
print(f"EPOCH : {epoch+1}")
175+
for image_batch in train_dataset:
176+
train_step(image_batch)
177+
# Produce images for the GIF
178+
if ipython:
179+
display.clear_output(wait=True)
180+
generate_and_save_images(generator, epoch + 1, seed, outdir = args.outdir, dataset = args.dataset)
181+
182+
# Save the model every 15 epochs
183+
if (epoch + 1) % 15 == 0:
184+
checkpoint.save(file_prefix = checkpoint_prefix)
185+
186+
# writing to summary writers
187+
with disc_summary_writer.as_default():
188+
tf.summary.scalar('Loss', discriminator_loss_metric.result(), step = epoch)
189+
tf.summary.scalar('Real Accuracy', discriminator_real_accuracy_metric.result(), step = epoch)
190+
tf.summary.scalar('Fake Accuracy', discriminator_fake_accuracy_metric.result(), step = epoch)
191+
192+
with cat_summary_writer.as_default():
193+
tf.summary.scalar('Loss', categorical_loss_metric.result(), step = epoch)
194+
195+
with cont_summary_writer.as_default():
196+
tf.summary.scalar('Loss', continuous_loss_metric.result(), step = epoch)
197+
198+
with gen_summary_writer.as_default():
199+
tf.summary.scalar('Loss', generator_loss_metric.result(), step = epoch)
200+
201+
print('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
202+
print(f'Epoch results: Discriminator Loss: {discriminator_loss_metric.result()}, Real Accuracy: {discriminator_real_accuracy_metric.result()}, Fake Accuracy: {discriminator_fake_accuracy_metric.result()}')
203+
print(f' Generator Loss: {generator_loss_metric.result()}')
204+
205+
discriminator_loss_metric.reset_states()
206+
discriminator_real_accuracy_metric.reset_states()
207+
discriminator_fake_accuracy_metric.reset_states()
208+
generator_loss_metric.reset_states()
209+
categorical_loss_metric.reset_states()
210+
continuous_loss_metric.reset_states()
211+
212+
# Generate after the final epoch
213+
if ipython:
214+
display.clear_output(wait=True)
215+
generate_and_save_images(generator, epochs, seed, outdir = args.outdir, dataset = args.dataset)
216+
217+
save_gif(args.outdir, args.dataset)
218+
219+
# For producing outputs with constant noise and varying continuous and categorical latent codes
220+
221+
generate_varying_outputs(generator, num_examples_to_generate, noise_dim, args.dataset, args.outdir)
222+
223+
if __name__ == '__main__':
224+
main()
