import os
import argparse

import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy

from MNIST_model import make_generator, make_discriminator
from utils import plot_loss, generate_and_save_images

parser = argparse.ArgumentParser()

parser.add_argument('--EPOCHS', type=int, default=50, help="Number of epochs (default: 50)")
parser.add_argument('--noise_dim', type=int, default=100, help="Dimension of the noise vector (default: 100)")
parser.add_argument('--BATCH_SIZE', type=int, default=128, help="Batch size (default: 128)")
parser.add_argument('--num_examples_to_generate', type=int, default=16, help="Number of images shown after each epoch (default: 16)")
parser.add_argument('--lr_gen', type=float, default=0.0002, help="Learning rate for the generator optimizer (default: 0.0002)")
parser.add_argument('--lr_disc', type=float, default=0.0002, help="Learning rate for the discriminator optimizer (default: 0.0002)")
parser.add_argument('--outdir', type=str, default='.', help="Directory in which to store outputs")

args = parser.parse_args()
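
# Example invocation (assuming this file is saved as train.py):
#   python train.py --EPOCHS 100 --BATCH_SIZE 64 --outdir ./results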

# Load the MNIST dataset (the labels are not needed for GAN training)
(train_images, _), (_, _) = tf.keras.datasets.mnist.load_data()

BUFFER_SIZE = 60000
BATCH_SIZE = args.BATCH_SIZE
EPOCHS = args.EPOCHS
noise_dim = args.noise_dim
num_examples_to_generate = args.num_examples_to_generate
lr_gen = args.lr_gen
lr_disc = args.lr_disc

# Fixed noise vector used to visualise generator progress across epochs
seed = tf.random.normal([num_examples_to_generate, noise_dim])

# Prepare the dataset: reshape to (N, 28, 28, 1) and scale pixels to [-1, 1]
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
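# Optionally, .prefetch(tf.data.AUTOTUNE) can be appended to the pipeline above
# to overlap data preparation with training (tf.data.AUTOTUNE requires TF >= 2.4;
# older versions expose it as tf.data.experimental.AUTOTUNE).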

# Build the generator and discriminator
generator = make_generator(noise_dim)
discriminator = make_discriminator()

# Define the generator and discriminator losses
cross_entropy = BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):
    # Real images should be classified as 1, generated images as 0
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return real_loss + fake_loss

def generator_loss(fake_output):
    # The generator succeeds when the discriminator labels its output as real (1)
    return cross_entropy(tf.ones_like(fake_output), fake_output)
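
# Note: these are the standard non-saturating GAN losses. The discriminator
# minimises -[log D(x) + log(1 - D(G(z)))], and the generator minimises
# -log D(G(z)) rather than log(1 - D(G(z))), which gives stronger gradients
# early in training when the discriminator easily rejects generated samples.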

# Define the optimizers
generator_optimizer = Adam(learning_rate=lr_gen)
discriminator_optimizer = Adam(learning_rate=lr_disc)

# Set up checkpointing
checkpoint_dir = os.path.join(args.outdir, "training_checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)
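
# To resume training later, the latest checkpoint can be restored with:
#   checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))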

# Define the training step. @tf.function compiles it into a TensorFlow graph
# for faster execution.
@tf.function
def train_step(images):
    # Match the noise batch to the actual batch size so the final
    # (possibly smaller) batch is handled correctly
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    # Update the generator and discriminator simultaneously from separate tapes
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    return gen_loss, disc_loss

# Training loop: iterate over the dataset, track the mean loss of each epoch,
# save sample images after every epoch, and checkpoint periodically
def train(dataset, epochs):
    gen_loss_history = []
    disc_loss_history = []

    for epoch in range(epochs):
        gen_loss_list = []
        disc_loss_list = []

        for image_batch in dataset:
            gen_loss_batch, disc_loss_batch = train_step(image_batch)
            gen_loss_list.append(gen_loss_batch)
            disc_loss_list.append(disc_loss_batch)

        gen_loss = sum(gen_loss_list) / len(gen_loss_list)
        disc_loss = sum(disc_loss_list) / len(disc_loss_list)
        gen_loss_history.append(float(gen_loss))
        disc_loss_history.append(float(disc_loss))

        print(f'Epoch {epoch + 1}, gen loss={float(gen_loss):.4f}, disc loss={float(disc_loss):.4f}')

        generate_and_save_images(generator, epoch + 1, seed)

        # Save a checkpoint every 15 epochs
        if (epoch + 1) % 15 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

    return gen_loss_history, disc_loss_history

# Train the model
gen_losses, disc_losses = train(train_dataset, EPOCHS)

# Plot the generator and discriminator losses
plot_loss(gen_losses, disc_losses)