import tensorflow as tf
import glob
import matplotlib.pyplot as plt
import os
import time
import datetime
import argparse
from tensorflow.keras import layers

print(tf.__version__)
from utils import run_from_ipython, generate_latent_points, generate_and_save_images, save_gif, generate_varying_outputs
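# NOTE: run_from_ipython, generate_latent_points, generate_and_save_images,
# save_gif and generate_varying_outputs are project-local helpers from utils.py,
# not part of TensorFlow.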

parser = argparse.ArgumentParser()
ipython = run_from_ipython()

if ipython:
    from IPython import display

parser.add_argument('--dataset', type=str, default="MNIST", help="Name of dataset: MNIST (default) or CIFAR10")
parser.add_argument('--epochs', type=int, default=0, help="Number of epochs: default 50 for MNIST, 150 for CIFAR10")
parser.add_argument('--noise_dim', type=int, default=0, help="Number of latent noise variables: default 62 for MNIST, 64 for CIFAR10")
parser.add_argument('--continuous_weight', type=float, default=0.0, help="Weight given to the continuous latent codes in the loss: default 0.5 for MNIST, 1.0 for CIFAR10")
parser.add_argument('--batch_size', type=int, default=256, help="Batch size, default 256")
parser.add_argument('--outdir', type=str, default='.', help="Directory in which to store data, don't put '/' at the end!")

args = parser.parse_args()

if args.dataset == "MNIST":
    from model_MNIST import make_generator_model, make_discriminator_model
    (train_images, _), (_, _) = tf.keras.datasets.mnist.load_data()
    train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
    # Fill in the MNIST defaults for any argument left at its sentinel value
    if args.epochs == 0:
        args.epochs = 50
    if args.noise_dim == 0:
        args.noise_dim = 62
    if args.continuous_weight == 0.0:
        args.continuous_weight = 0.5

else:
    from model_CIFAR10 import make_generator_model, make_discriminator_model
    (train_images, _), (_, _) = tf.keras.datasets.cifar10.load_data()
    # CIFAR10 images are already (N, 32, 32, 3); the reshape is a no-op kept for symmetry
    train_images = train_images.reshape(train_images.shape[0], 32, 32, 3).astype('float32')
    # Fill in the CIFAR10 defaults for any argument left at its sentinel value
    if args.epochs == 0:
        args.epochs = 150
    if args.noise_dim == 0:
        args.noise_dim = 64
    if args.continuous_weight == 0.0:
        args.continuous_weight = 1.0

if not os.path.exists(f"{args.outdir}/assets/{args.dataset}"):
    os.makedirs(f"{args.outdir}/assets/{args.dataset}")

# Normalize the images to [-1, 1]
train_images = (train_images - 127.5) / 127.5

##### DEFINE GLOBAL VARIABLES AND OBJECTS ######
BUFFER_SIZE = 60000  # shuffle buffer; at least as large as either training set
BATCH_SIZE = args.batch_size
epochs = args.epochs
noise_dim = args.noise_dim
continuous_dim = 2
categorical_dim = 10
num_examples_to_generate = 100
continuous_weight = args.continuous_weight
seed, _, _ = generate_latent_points(num_examples_to_generate, noise_dim, categorical_dim, continuous_dim)  # a fixed sample of latent points, used to render comparable images across epochs
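# NOTE (assumption): generate_latent_points is taken here to return the full
# generator input as its first value -- the noise variables with the one-hot
# categorical code (categorical_dim values) and the continuous codes
# (continuous_dim values) already concatenated -- plus the raw categorical and
# continuous codes, which the loss functions below compare against the
# discriminator's auxiliary outputs.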

# Define the generator and discriminator
generator = make_generator_model(noise_dim)
print("\nGenerator : ")
generator.summary()  # Model.summary() prints directly and returns None
discriminator = make_discriminator_model()
print("\nDiscriminator : ")
discriminator.summary()

print("Dataset : ", args.dataset)
###########################################

# Convert the data to a tf.data pipeline; drop_remainder keeps every batch at
# exactly BATCH_SIZE, matching the number of latent points drawn per train step
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)

# Define the losses (both applied to raw logits)
binary_cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
categorical_cross_entropy = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

# Define the optimizers
generator_optimizer = tf.keras.optimizers.Adam(1e-3, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

# Define storage points for checkpoints
checkpoint_dir = f'{args.outdir}/training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

# Loss and accuracy metrics, for plotting with TensorBoard. The accuracy metrics
# use threshold=0.0 because the discriminator's real/fake output is a raw logit
# (logit > 0 corresponds to probability > 0.5).
discriminator_loss_metric = tf.keras.metrics.Mean('discriminator_loss', dtype=tf.float32)
discriminator_real_accuracy_metric = tf.keras.metrics.BinaryAccuracy('discriminator_real_accuracy', threshold=0.0)
discriminator_fake_accuracy_metric = tf.keras.metrics.BinaryAccuracy('discriminator_fake_accuracy', threshold=0.0)
generator_loss_metric = tf.keras.metrics.Mean('generator_loss', dtype=tf.float32)
categorical_loss_metric = tf.keras.metrics.Mean('categorical_loss', dtype=tf.float32)
continuous_loss_metric = tf.keras.metrics.Mean('continuous_loss', dtype=tf.float32)

# Log directories for the metrics
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
base = f"{args.outdir}/logs/gradientTape/{current_time}"
disc_log_dir = base + '/discriminator'
gen_log_dir = base + '/generator'
cont_log_dir = base + '/cont'
cat_log_dir = base + '/cat'

# Create summary writers (each writer paired with its own log directory)
disc_summary_writer = tf.summary.create_file_writer(disc_log_dir)
gen_summary_writer = tf.summary.create_file_writer(gen_log_dir)
cat_summary_writer = tf.summary.create_file_writer(cat_log_dir)
cont_summary_writer = tf.summary.create_file_writer(cont_log_dir)

##################################
# One training step on a single minibatch

def train_step(images):
    noise, categorical_input, continuous_input = generate_latent_points(BATCH_SIZE, noise_dim, categorical_dim, continuous_dim)

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        disc_loss, real_loss, fake_loss, categorical_loss, continuous_loss = discriminator_loss(real_output, fake_output, categorical_input, continuous_input)
        gen_loss = generator_loss(fake_output, categorical_loss, continuous_loss)

    # Update the running metrics (outside the tape: no gradients are needed here)
    discriminator_loss_metric(disc_loss)
    generator_loss_metric(gen_loss)
    discriminator_real_accuracy_metric(tf.ones_like(real_output[:, 0]), real_output[:, 0])
    discriminator_fake_accuracy_metric(tf.zeros_like(fake_output[:, 0]), fake_output[:, 0])
    categorical_loss_metric(categorical_loss)
    continuous_loss_metric(continuous_loss)

    print(f"Losses - Disc : [{disc_loss}], Gen : [{gen_loss}], \n categorical loss : {categorical_loss}, continuous loss : {continuous_loss}")

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
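
# NOTE: TF2 GAN examples often wrap train_step in @tf.function to compile it;
# that is not done here -- with @tf.function, the per-batch print() above would
# only run when the function is traced rather than on every step.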

####################################

def discriminator_loss(real_output, fake_output, categorical_input, continuous_input):
    # Column 0 of the discriminator output is the real/fake logit
    real_loss = binary_cross_entropy(tf.ones_like(real_output[:, 0]), real_output[:, 0])
    fake_loss = binary_cross_entropy(tf.zeros_like(fake_output[:, 0]), fake_output[:, 0])

    # The remaining columns are the auxiliary (Q-network) heads:
    # categorical logits first, then the continuous code predictions
    categorical_output = fake_output[:, 1:1 + categorical_dim]
    continuous_output = fake_output[:, 1 + categorical_dim:]

    categorical_loss = categorical_cross_entropy(categorical_input, categorical_output)
    # MSE on the continuous codes; the factor of 2 inside the square scales the loss by 4
    continuous_loss = tf.reduce_mean((2 * (continuous_output - continuous_input)) ** 2)

    total_loss = real_loss + fake_loss + continuous_weight * continuous_loss + categorical_loss
    return total_loss, real_loss, fake_loss, categorical_loss, continuous_loss

#####################################

def generator_loss(fake_output, categorical_loss, continuous_loss):
    # Non-saturating GAN loss plus the same weighted auxiliary code losses
    gen_loss = binary_cross_entropy(tf.ones_like(fake_output[:, 0]), fake_output[:, 0])
    return gen_loss + continuous_weight * continuous_loss + categorical_loss
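
# Both loss functions share the categorical and continuous code terms: this is
# the InfoGAN objective, in which the auxiliary heads provide a variational
# lower bound on the mutual information between the latent codes and the
# generated images, so the codes stay recoverable from the output.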

#####################################

def main():
    # Begin the training loop

    for epoch in range(epochs):
        start = time.time()
        print(f"EPOCH : {epoch+1}")
        for image_batch in train_dataset:
            train_step(image_batch)

        # Produce images for the GIF
        if ipython:
            display.clear_output(wait=True)
        generate_and_save_images(generator, epoch + 1, seed, outdir=args.outdir, dataset=args.dataset)

        # Save the model every 15 epochs
        if (epoch + 1) % 15 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

        # Write the epoch's metrics to the summary writers
        with disc_summary_writer.as_default():
            tf.summary.scalar('Loss', discriminator_loss_metric.result(), step=epoch)
            tf.summary.scalar('Real Accuracy', discriminator_real_accuracy_metric.result(), step=epoch)
            tf.summary.scalar('Fake Accuracy', discriminator_fake_accuracy_metric.result(), step=epoch)

        with cat_summary_writer.as_default():
            tf.summary.scalar('Loss', categorical_loss_metric.result(), step=epoch)

        with cont_summary_writer.as_default():
            tf.summary.scalar('Loss', continuous_loss_metric.result(), step=epoch)

        with gen_summary_writer.as_default():
            tf.summary.scalar('Loss', generator_loss_metric.result(), step=epoch)

        print(f'Time for epoch {epoch + 1} is {time.time()-start} sec')
        print(f'Epoch results: Discriminator Loss: {discriminator_loss_metric.result()}, Real Accuracy: {discriminator_real_accuracy_metric.result()}, Fake Accuracy: {discriminator_fake_accuracy_metric.result()}')
        print(f'               Generator Loss: {generator_loss_metric.result()}')

        # Reset the metrics for the next epoch
        discriminator_loss_metric.reset_states()
        discriminator_real_accuracy_metric.reset_states()
        discriminator_fake_accuracy_metric.reset_states()
        generator_loss_metric.reset_states()
        categorical_loss_metric.reset_states()
        continuous_loss_metric.reset_states()

    # Generate after the final epoch
    if ipython:
        display.clear_output(wait=True)
    generate_and_save_images(generator, epochs, seed, outdir=args.outdir, dataset=args.dataset)

    save_gif(args.outdir, args.dataset)

    # Produce outputs with constant noise and varying continuous and categorical latent codes
    generate_varying_outputs(generator, num_examples_to_generate, noise_dim, args.dataset, args.outdir)

if __name__ == '__main__':
    main()
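
# Example usage (the filename train.py is hypothetical -- use whatever this
# script is saved as):
#   python train.py --dataset MNIST --epochs 50 --outdir results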