diff --git a/README.md b/README.md index 3a02f5c58..68a425e4e 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ Tensorflow implementation of [Deep Convolutional Generative Adversarial Networks - [Tensorflow 0.12.1](https://github.com/tensorflow/tensorflow/tree/r0.12) - [SciPy](http://www.scipy.org/install.html) - [pillow](https://github.com/python-pillow/Pillow) +- [tqdm](https://pypi.org/project/tqdm/) - (Optional) [moviepy](https://github.com/Zulko/moviepy) (for visualization) - (Optional) [Align&Cropped Images.zip](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) : Large-scale CelebFaces Dataset @@ -29,26 +30,34 @@ Tensorflow implementation of [Deep Convolutional Generative Adversarial Networks First, download dataset with: - $ python download.py --datasets mnist celebA + $ python download.py mnist celebA To train a model with downloaded dataset: - $ python main.py --dataset mnist --input_height=28 --output_height=28 --c_dim=1 --is_train - $ python main.py --dataset celebA --input_height=108 --is_train --is_crop True + $ python main.py --dataset mnist --input_height=28 --output_height=28 --train + $ python main.py --dataset celebA --input_height=108 --train --crop To test with an existing model: - $ python main.py --dataset mnist --input_height=28 --output_height=28 --c_dim=1 - $ python main.py --dataset celebA --input_height=108 --is_crop True + $ python main.py --dataset mnist --input_height=28 --output_height=28 + $ python main.py --dataset celebA --input_height=108 --crop Or, you can use your own dataset (without central crop) by: $ mkdir data/DATASET_NAME ... add images to data/DATASET_NAME ... - $ python main.py --dataset DATASET_NAME --is_train + $ python main.py --dataset DATASET_NAME --train $ python main.py --dataset DATASET_NAME $ # example - $ python main.py --dataset=eyes --input_fname_pattern="*_cropped.png" --c_dim=1 --is_train + $ python main.py --dataset=eyes --input_fname_pattern="*_cropped.png" --train + +If your dataset is located in a different root directory: + + $ python main.py --dataset DATASET_NAME --data_dir DATASET_ROOT_DIR --train + $ python main.py --dataset DATASET_NAME --data_dir DATASET_ROOT_DIR + $ # example + $ python main.py --dataset=eyes --data_dir ../datasets/ --input_fname_pattern="*_cropped.png" --train + ## Results @@ -100,6 +109,13 @@ Details of the histogram of true and fake result of discriminator (with custom d ![d__hist](assets/d__hist.png) +## Related works + +- [BEGAN-tensorflow](https://github.com/carpedm20/BEGAN-tensorflow) +- [DiscoGAN-pytorch](https://github.com/carpedm20/DiscoGAN-pytorch) +- [simulated-unsupervised-tensorflow](https://github.com/carpedm20/simulated-unsupervised-tensorflow) + + ## Author Taehoon Kim / [@carpedm20](http://carpedm20.github.io/) diff --git a/assets/test_2016-01-27 15:07:47.png b/assets/test_2016-01-27 15_07_47.png similarity index 100% rename from assets/test_2016-01-27 15:07:47.png rename to assets/test_2016-01-27 15_07_47.png diff --git a/assets/test_2016-01-27 15:08:45.png b/assets/test_2016-01-27 15_08_45.png similarity index 100% rename from assets/test_2016-01-27 15:08:45.png rename to assets/test_2016-01-27 15_08_45.png diff --git a/assets/test_2016-01-27 15:08:54.png b/assets/test_2016-01-27 15_08_54.png similarity index 100% rename from assets/test_2016-01-27 15:08:54.png rename to assets/test_2016-01-27 15_08_54.png diff --git a/assets/test_2016-01-27 15:08:57.png b/assets/test_2016-01-27 15_08_57.png similarity index 100% rename from assets/test_2016-01-27 15:08:57.png rename to 
assets/test_2016-01-27 15_08_57.png diff --git a/assets/test_2016-01-27 15:09:00.png b/assets/test_2016-01-27 15_09_00.png similarity index 100% rename from assets/test_2016-01-27 15:09:00.png rename to assets/test_2016-01-27 15_09_00.png diff --git a/assets/test_2016-01-27 15:09:04.png b/assets/test_2016-01-27 15_09_04.png similarity index 100% rename from assets/test_2016-01-27 15:09:04.png rename to assets/test_2016-01-27 15_09_04.png diff --git a/assets/test_2016-01-27 15:09:46.png b/assets/test_2016-01-27 15_09_46.png similarity index 100% rename from assets/test_2016-01-27 15:09:46.png rename to assets/test_2016-01-27 15_09_46.png diff --git a/assets/test_2016-01-27 15:09:50.png b/assets/test_2016-01-27 15_09_50.png similarity index 100% rename from assets/test_2016-01-27 15:09:50.png rename to assets/test_2016-01-27 15_09_50.png diff --git a/download.py b/download.py index cb5397f5f..cdd416ea0 100644 --- a/download.py +++ b/download.py @@ -172,7 +172,7 @@ def prepare_data_dir(path = './data'): args = parser.parse_args() prepare_data_dir() - if 'celebA' in args.datasets: + if any(name in args.datasets for name in ['CelebA', 'celebA']): download_celeb_a('./data') if 'lsun' in args.datasets: download_lsun('./data') diff --git a/main.py b/main.py index 71d310134..b2ea7c451 100644 --- a/main.py +++ b/main.py @@ -1,9 +1,10 @@ import os import scipy.misc import numpy as np +import json from model import DCGAN -from utils import pp, visualize, to_json, show_all_variables +from utils import pp, visualize, to_json, show_all_variables, expand_path, timestamp import tensorflow as tf @@ -11,34 +12,64 @@ flags.DEFINE_integer("epoch", 25, "Epoch to train [25]") flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]") flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]") -flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]") +flags.DEFINE_float("train_size", np.inf, "The size of train images [np.inf]") flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]") flags.DEFINE_integer("input_height", 108, "The size of image to use (will be center cropped). [108]") flags.DEFINE_integer("input_width", None, "The size of image to use (will be center cropped). If None, same value as input_height [None]") flags.DEFINE_integer("output_height", 64, "The size of the output images to produce [64]") flags.DEFINE_integer("output_width", None, "The size of the output images to produce. If None, same value as output_height [None]") -flags.DEFINE_integer("c_dim", 3, "Dimension of image color. [3]") flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]") flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]") -flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]") -flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]") -flags.DEFINE_boolean("is_train", False, "True for training, False for testing [False]") -flags.DEFINE_boolean("is_crop", False, "True for training, False for testing [False]") +flags.DEFINE_string("data_dir", "./data", "path to datasets [e.g. $HOME/data]") +flags.DEFINE_string("out_dir", "./out", "Root directory for outputs [e.g. $HOME/out]") +flags.DEFINE_string("out_name", "", "Folder (under out_dir) for all outputs. 
Generated automatically if left blank []") +flags.DEFINE_string("checkpoint_dir", "checkpoint", "Folder (under out_dir/out_name) to save checkpoints [checkpoint]") +flags.DEFINE_string("sample_dir", "samples", "Folder (under out_dir/out_name) to save samples [samples]") +flags.DEFINE_boolean("train", False, "True for training, False for testing [False]") +flags.DEFINE_boolean("crop", False, "True for center-cropping input images, False otherwise [False]") flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]") +flags.DEFINE_boolean("export", False, "True for exporting with new batch size") +flags.DEFINE_boolean("freeze", False, "True for exporting a frozen graph with new batch size") +flags.DEFINE_integer("max_to_keep", 1, "maximum number of checkpoints to keep") +flags.DEFINE_integer("sample_freq", 200, "sample every this many iterations") +flags.DEFINE_integer("ckpt_freq", 200, "save checkpoint every this many iterations") +flags.DEFINE_integer("z_dim", 100, "dimensions of z") +flags.DEFINE_string("z_dist", "uniform_signed", "'normal01', 'uniform_unsigned' or 'uniform_signed'") +flags.DEFINE_boolean("G_img_sum", False, "Save generator image summaries in log") +#flags.DEFINE_integer("generate_test_images", 100, "Number of images to generate during test. [100]") FLAGS = flags.FLAGS def main(_): pp.pprint(flags.FLAGS.__flags) + + # expand user name and environment variables + FLAGS.data_dir = expand_path(FLAGS.data_dir) + FLAGS.out_dir = expand_path(FLAGS.out_dir) + FLAGS.out_name = expand_path(FLAGS.out_name) + FLAGS.checkpoint_dir = expand_path(FLAGS.checkpoint_dir) + FLAGS.sample_dir = expand_path(FLAGS.sample_dir) - if FLAGS.input_width is None: - FLAGS.input_width = FLAGS.input_height - if FLAGS.output_width is None: - FLAGS.output_width = FLAGS.output_height + if FLAGS.output_height is None: FLAGS.output_height = FLAGS.input_height + if FLAGS.input_width is None: FLAGS.input_width = FLAGS.input_height + if FLAGS.output_width is None: FLAGS.output_width = FLAGS.output_height - if not os.path.exists(FLAGS.checkpoint_dir): - os.makedirs(FLAGS.checkpoint_dir) - if not os.path.exists(FLAGS.sample_dir): - os.makedirs(FLAGS.sample_dir) + # output folders + if FLAGS.out_name == "": + FLAGS.out_name = '{} - {} - {}'.format(timestamp(), FLAGS.data_dir.split('/')[-1], FLAGS.dataset) # last folder of the data path + if FLAGS.train: + FLAGS.out_name += ' - x{}.z{}.{}.y{}.b{}'.format(FLAGS.input_width, FLAGS.z_dim, FLAGS.z_dist, FLAGS.output_width, FLAGS.batch_size) + + FLAGS.out_dir = os.path.join(FLAGS.out_dir, FLAGS.out_name) + FLAGS.checkpoint_dir = os.path.join(FLAGS.out_dir, FLAGS.checkpoint_dir) + FLAGS.sample_dir = os.path.join(FLAGS.out_dir, FLAGS.sample_dir) + + if not os.path.exists(FLAGS.checkpoint_dir): os.makedirs(FLAGS.checkpoint_dir) + if not os.path.exists(FLAGS.sample_dir): os.makedirs(FLAGS.sample_dir) + + with open(os.path.join(FLAGS.out_dir, 'FLAGS.json'), 'w') as f: + flags_dict = {k:FLAGS[k].value for k in FLAGS} + json.dump(flags_dict, f, indent=4, sort_keys=True, ensure_ascii=False) + #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333) run_config = tf.ConfigProto() @@ -55,12 +86,15 @@ def main(_): batch_size=FLAGS.batch_size, sample_num=FLAGS.batch_size, y_dim=10, - c_dim=1, + z_dim=FLAGS.z_dim, dataset_name=FLAGS.dataset, input_fname_pattern=FLAGS.input_fname_pattern, - is_crop=FLAGS.is_crop, + crop=FLAGS.crop, checkpoint_dir=FLAGS.checkpoint_dir, - sample_dir=FLAGS.sample_dir) + sample_dir=FLAGS.sample_dir, + data_dir=FLAGS.data_dir, + 
out_dir=FLAGS.out_dir, + max_to_keep=FLAGS.max_to_keep) else: dcgan = DCGAN( sess, @@ -70,20 +104,25 @@ def main(_): output_height=FLAGS.output_height, batch_size=FLAGS.batch_size, sample_num=FLAGS.batch_size, - c_dim=FLAGS.c_dim, + z_dim=FLAGS.z_dim, dataset_name=FLAGS.dataset, input_fname_pattern=FLAGS.input_fname_pattern, - is_crop=FLAGS.is_crop, + crop=FLAGS.crop, checkpoint_dir=FLAGS.checkpoint_dir, - sample_dir=FLAGS.sample_dir) + sample_dir=FLAGS.sample_dir, + data_dir=FLAGS.data_dir, + out_dir=FLAGS.out_dir, + max_to_keep=FLAGS.max_to_keep) show_all_variables() - if FLAGS.is_train: + + if FLAGS.train: dcgan.train(FLAGS) else: - if not dcgan.load(FLAGS.checkpoint_dir): - raise Exception("[!] Train a model first, then run test mode") - + load_success, load_counter = dcgan.load(FLAGS.checkpoint_dir) + if not load_success: + raise Exception("Checkpoint not found in " + FLAGS.checkpoint_dir) + # to_json("./web/js/layers.js", [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0], # [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1], @@ -92,8 +131,17 @@ def main(_): # [dcgan.h4_w, dcgan.h4_b, None]) # Below is codes for visualization - OPTION = 1 - visualize(sess, dcgan, FLAGS, OPTION) + if FLAGS.export: + export_dir = os.path.join(FLAGS.checkpoint_dir, 'export_b'+str(FLAGS.batch_size)) + dcgan.save(export_dir, load_counter, ckpt=True, frozen=False) + + if FLAGS.freeze: + export_dir = os.path.join(FLAGS.checkpoint_dir, 'frozen_b'+str(FLAGS.batch_size)) + dcgan.save(export_dir, load_counter, ckpt=False, frozen=True) + + if FLAGS.visualize: + OPTION = 1 + visualize(sess, dcgan, FLAGS, OPTION, FLAGS.sample_dir) if __name__ == '__main__': tf.app.run() diff --git a/model.py b/model.py index 969cc3fb9..e528ddb5c 100644 --- a/model.py +++ b/model.py @@ -1,4 +1,5 @@ from __future__ import division +from __future__ import print_function import os import time import math @@ -13,12 +14,19 @@ def conv_out_size_same(size, stride): return int(math.ceil(float(size) / float(stride))) +def gen_random(mode, size): + if mode=='normal01': return np.random.normal(0,1,size=size) + if mode=='uniform_signed': return np.random.uniform(-1,1,size=size) + if mode=='uniform_unsigned': return np.random.uniform(0,1,size=size) + + class DCGAN(object): - def __init__(self, sess, input_height=108, input_width=108, is_crop=True, + def __init__(self, sess, input_height=108, input_width=108, crop=True, batch_size=64, sample_num = 64, output_height=64, output_width=64, y_dim=None, z_dim=100, gf_dim=64, df_dim=64, gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default', - input_fname_pattern='*.jpg', checkpoint_dir=None, sample_dir=None): + max_to_keep=1, + input_fname_pattern='*.jpg', checkpoint_dir='ckpts', sample_dir='samples', out_dir='./out', data_dir='./data'): """ Args: @@ -33,8 +41,7 @@ def __init__(self, sess, input_height=108, input_width=108, is_crop=True, c_dim: (optional) Dimension of image color. For grayscale input, set to 1. 
[3] """ self.sess = sess - self.is_crop = is_crop - self.is_grayscale = (c_dim == 1) + self.crop = crop self.batch_size = batch_size self.sample_num = sample_num @@ -53,8 +60,6 @@ def __init__(self, sess, input_height=108, input_width=108, is_crop=True, self.gfc_dim = gfc_dim self.dfc_dim = dfc_dim - self.c_dim = c_dim - # batch normalization : deals with poor initialization helps gradient flow self.d_bn1 = batch_norm(name='d_bn1') self.d_bn2 = batch_norm(name='d_bn2') @@ -72,44 +77,57 @@ def __init__(self, sess, input_height=108, input_width=108, is_crop=True, self.dataset_name = dataset_name self.input_fname_pattern = input_fname_pattern self.checkpoint_dir = checkpoint_dir + self.data_dir = data_dir + self.out_dir = out_dir + self.max_to_keep = max_to_keep + + if self.dataset_name == 'mnist': + self.data_X, self.data_y = self.load_mnist() + self.c_dim = self.data_X[0].shape[-1] + else: + data_path = os.path.join(self.data_dir, self.dataset_name, self.input_fname_pattern) + self.data = glob(data_path) + if len(self.data) == 0: + raise Exception("[!] No data found in '" + data_path + "'") + np.random.shuffle(self.data) + imreadImg = imread(self.data[0]) + if len(imreadImg.shape) >= 3: # check if image is non-grayscale by its channel number + self.c_dim = imreadImg.shape[-1] + else: + self.c_dim = 1 + + if len(self.data) < self.batch_size: + raise Exception("[!] Entire dataset size is less than the configured batch_size") + + self.grayscale = (self.c_dim == 1) + + self.build_model() def build_model(self): if self.y_dim: - self.y= tf.placeholder(tf.float32, [self.batch_size, self.y_dim], name='y') + self.y = tf.placeholder(tf.float32, [self.batch_size, self.y_dim], name='y') + else: + self.y = None - if self.is_crop: + if self.crop: image_dims = [self.output_height, self.output_width, self.c_dim] else: image_dims = [self.input_height, self.input_width, self.c_dim] self.inputs = tf.placeholder( tf.float32, [self.batch_size] + image_dims, name='real_images') - self.sample_inputs = tf.placeholder( - tf.float32, [self.sample_num] + image_dims, name='sample_inputs') inputs = self.inputs - sample_inputs = self.sample_inputs self.z = tf.placeholder( tf.float32, [None, self.z_dim], name='z') self.z_sum = histogram_summary("z", self.z) - if self.y_dim: - self.G = self.generator(self.z, self.y) - self.D, self.D_logits = \ - self.discriminator(inputs, self.y, reuse=False) - - self.sampler = self.sampler(self.z, self.y) - self.D_, self.D_logits_ = \ - self.discriminator(self.G, self.y, reuse=True) - else: - self.G = self.generator(self.z) - self.D, self.D_logits = self.discriminator(inputs) - - self.sampler = self.sampler(self.z) - self.D_, self.D_logits_ = self.discriminator(self.G, reuse=True) - + self.G = self.generator(self.z, self.y) + self.D, self.D_logits = self.discriminator(inputs, self.y, reuse=False) + self.sampler = self.sampler(self.z, self.y) + self.D_, self.D_logits_ = self.discriminator(self.G, self.y, reuse=True) + self.d_sum = histogram_summary("d", self.D) self.d__sum = histogram_summary("d_", self.D_) self.G_sum = image_summary("G", self.G) @@ -140,16 +158,9 @@ def sigmoid_cross_entropy_with_logits(x, y): self.d_vars = [var for var in t_vars if 'd_' in var.name] self.g_vars = [var for var in t_vars if 'g_' in var.name] - self.saver = tf.train.Saver() + self.saver = tf.train.Saver(max_to_keep=self.max_to_keep) def train(self, config): - """Train DCGAN""" - if config.dataset == 'mnist': - data_X, data_y = self.load_mnist() - else: - data = 
glob(os.path.join("./data", config.dataset, self.input_fname_pattern)) - #np.random.shuffle(data) - d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \ .minimize(self.d_loss, var_list=self.d_vars) g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \ @@ -159,28 +170,30 @@ def train(self, config): except: tf.initialize_all_variables().run() - self.g_sum = merge_summary([self.z_sum, self.d__sum, - self.G_sum, self.d_loss_fake_sum, self.g_loss_sum]) + if config.G_img_sum: + self.g_sum = merge_summary([self.z_sum, self.d__sum, self.G_sum, self.d_loss_fake_sum, self.g_loss_sum]) + else: + self.g_sum = merge_summary([self.z_sum, self.d__sum, self.d_loss_fake_sum, self.g_loss_sum]) self.d_sum = merge_summary( [self.z_sum, self.d_sum, self.d_loss_real_sum, self.d_loss_sum]) - self.writer = SummaryWriter("./logs", self.sess.graph) + self.writer = SummaryWriter(os.path.join(self.out_dir, "logs"), self.sess.graph) - sample_z = np.random.uniform(-1, 1, size=(self.sample_num , self.z_dim)) + sample_z = gen_random(config.z_dist, size=(self.sample_num , self.z_dim)) if config.dataset == 'mnist': - sample_inputs = data_X[0:self.sample_num] - sample_labels = data_y[0:self.sample_num] + sample_inputs = self.data_X[0:self.sample_num] + sample_labels = self.data_y[0:self.sample_num] else: - sample_files = data[0:self.sample_num] + sample_files = self.data[0:self.sample_num] sample = [ get_image(sample_file, input_height=self.input_height, input_width=self.input_width, resize_height=self.output_height, resize_width=self.output_width, - is_crop=self.is_crop, - is_grayscale=self.is_grayscale) for sample_file in sample_files] - if (self.is_grayscale): + crop=self.crop, + grayscale=self.grayscale) for sample_file in sample_files] + if (self.grayscale): sample_inputs = np.array(sample).astype(np.float32)[:, :, :, None] else: sample_inputs = np.array(sample).astype(np.float32) @@ -196,32 +209,33 @@ def train(self, config): for epoch in xrange(config.epoch): if config.dataset == 'mnist': - batch_idxs = min(len(data_X), config.train_size) // config.batch_size + batch_idxs = min(len(self.data_X), config.train_size) // config.batch_size else: - data = glob(os.path.join( - "./data", config.dataset, self.input_fname_pattern)) - batch_idxs = min(len(data), config.train_size) // config.batch_size + self.data = glob(os.path.join( + config.data_dir, config.dataset, self.input_fname_pattern)) + np.random.shuffle(self.data) + batch_idxs = min(len(self.data), config.train_size) // config.batch_size - for idx in xrange(0, batch_idxs): + for idx in xrange(0, int(batch_idxs)): if config.dataset == 'mnist': - batch_images = data_X[idx*config.batch_size:(idx+1)*config.batch_size] - batch_labels = data_y[idx*config.batch_size:(idx+1)*config.batch_size] + batch_images = self.data_X[idx*config.batch_size:(idx+1)*config.batch_size] + batch_labels = self.data_y[idx*config.batch_size:(idx+1)*config.batch_size] else: - batch_files = data[idx*config.batch_size:(idx+1)*config.batch_size] + batch_files = self.data[idx*config.batch_size:(idx+1)*config.batch_size] batch = [ get_image(batch_file, input_height=self.input_height, input_width=self.input_width, resize_height=self.output_height, resize_width=self.output_width, - is_crop=self.is_crop, - is_grayscale=self.is_grayscale) for batch_file in batch_files] - if (self.is_grayscale): + crop=self.crop, + grayscale=self.grayscale) for batch_file in batch_files] + if self.grayscale: batch_images = np.array(batch).astype(np.float32)[:, :, :, None] 
else: batch_images = np.array(batch).astype(np.float32) - batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim]) \ + batch_z = gen_random(config.z_dist, size=[config.batch_size, self.z_dim]) \ .astype(np.float32) if config.dataset == 'mnist': @@ -279,12 +293,11 @@ def train(self, config): errD_real = self.d_loss_real.eval({ self.inputs: batch_images }) errG = self.g_loss.eval({self.z: batch_z}) - counter += 1 - print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \ - % (epoch, idx, batch_idxs, + print("[%8d] Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \ + % (counter, epoch, config.epoch, idx, batch_idxs, time.time() - start_time, errD_fake+errD_real, errG)) - if np.mod(counter, 100) == 1: + if np.mod(counter, config.sample_freq) == 0: if config.dataset == 'mnist': samples, d_loss, g_loss = self.sess.run( [self.sampler, self.d_loss, self.g_loss], feed_dict={ self.z: sample_z, self.inputs: sample_inputs, self.y:sample_labels, } ) - manifold_h = int(np.ceil(np.sqrt(samples.shape[0]))) - manifold_w = int(np.floor(np.sqrt(samples.shape[0]))) - save_images(samples, [manifold_h, manifold_w], - './{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, idx)) + save_images(samples, image_manifold_size(samples.shape[0]), + './{}/train_{:08d}.png'.format(config.sample_dir, counter)) print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss)) else: try: samples, d_loss, g_loss = self.sess.run( [self.sampler, self.d_loss, self.g_loss], feed_dict={ self.z: sample_z, self.inputs: sample_inputs, }, ) - manifold_h = int(np.ceil(np.sqrt(samples.shape[0]))) - manifold_w = int(np.floor(np.sqrt(samples.shape[0]))) - save_images(samples, [manifold_h, manifold_w], - './{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, idx)) + save_images(samples, image_manifold_size(samples.shape[0]), + './{}/train_{:08d}.png'.format(config.sample_dir, counter)) print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss)) except: print("one pic error!...") - if np.mod(counter, 500) == 2: + if np.mod(counter, config.ckpt_freq) == 0: self.save(config.checkpoint_dir, counter) - + + counter += 1 + def discriminator(self, image, y=None, reuse=False): with tf.variable_scope("discriminator") as scope: if reuse: scope.reuse_variables() @@ -329,7 +340,7 @@ def discriminator(self, image, y=None, reuse=False): h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim*2, name='d_h1_conv'))) h2 = lrelu(self.d_bn2(conv2d(h1, self.df_dim*4, name='d_h2_conv'))) h3 = lrelu(self.d_bn3(conv2d(h2, self.df_dim*8, name='d_h3_conv'))) - h4 = linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'd_h3_lin') + h4 = linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'd_h4_lin') return tf.nn.sigmoid(h4), h4 else: @@ -462,7 +473,7 @@ def sampler(self, z, y=None): return tf.nn.sigmoid(deconv2d(h2, [self.batch_size, s_h, s_w, self.c_dim], name='g_h3')) def load_mnist(self): - data_dir = os.path.join("./data", self.dataset_name) + data_dir = os.path.join(self.data_dir, self.dataset_name) fd = open(os.path.join(data_dir,'train-images-idx3-ubyte')) loaded = np.fromfile(file=fd,dtype=np.uint8) @@ -503,28 +514,39 @@ def model_dir(self): return "{}_{}_{}_{}".format( self.dataset_name, self.batch_size, self.output_height, self.output_width) - - def save(self, checkpoint_dir, step): - model_name = "DCGAN.model" - checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) + def save(self, checkpoint_dir, step, filename='model', ckpt=True, frozen=False): + # model_name = "DCGAN.model" + # checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) + + filename += '.b' + str(self.batch_size) if not 
os.path.exists(checkpoint_dir): os.makedirs(checkpoint_dir) - self.saver.save(self.sess, - os.path.join(checkpoint_dir, model_name), - global_step=step) + if ckpt: + self.saver.save(self.sess, + os.path.join(checkpoint_dir, filename), + global_step=step) + + if frozen: + tf.train.write_graph( + tf.graph_util.convert_variables_to_constants(self.sess, self.sess.graph_def, ["generator_1/Tanh"]), + checkpoint_dir, + '{}-{:06d}_frz.pb'.format(filename, step), + as_text=False) def load(self, checkpoint_dir): - import re - print(" [*] Reading checkpoints...") - checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) + #import re + print(" [*] Reading checkpoints...", checkpoint_dir) + # checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) + # print(" ->", checkpoint_dir) ckpt = tf.train.get_checkpoint_state(checkpoint_dir) if ckpt and ckpt.model_checkpoint_path: ckpt_name = os.path.basename(ckpt.model_checkpoint_path) self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) - counter = int(next(re.finditer("(\d+)(?!.*\d)",ckpt_name)).group(0)) + #counter = int(next(re.finditer("(\d+)(?!.*\d)",ckpt_name)).group(0)) + counter = int(ckpt_name.split('-')[-1]) print(" [*] Success to read {}".format(ckpt_name)) return True, counter else: diff --git a/ops.py b/ops.py index e4a2847bc..e65ef1711 100644 --- a/ops.py +++ b/ops.py @@ -94,8 +94,13 @@ def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w= shape = input_.get_shape().as_list() with tf.variable_scope(scope or "Linear"): - matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, + try: + matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, tf.random_normal_initializer(stddev=stddev)) + except ValueError as err: + msg = "NOTE: Usually, this is due to an issue with the image dimensions. Did you correctly set '--crop' or '--input_height' or '--output_height'?" 
+ err.args = err.args + (msg,) + raise bias = tf.get_variable("bias", [output_size], initializer=tf.constant_initializer(bias_start)) if with_w: diff --git a/utils.py b/utils.py index 30f3b1003..3c45c8c25 100644 --- a/utils.py +++ b/utils.py @@ -7,9 +7,14 @@ import random import pprint import scipy.misc +import cv2 import numpy as np +import os +import time +import datetime from time import gmtime, strftime from six.moves import xrange +from PIL import Image import tensorflow as tf import tensorflow.contrib.slim as slim @@ -18,25 +23,38 @@ get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1]) + +def expand_path(path): + return os.path.expanduser(os.path.expandvars(path)) + +def timestamp(s='%Y%m%d.%H%M%S', ts=None): + if not ts: ts = time.time() + st = datetime.datetime.fromtimestamp(ts).strftime(s) + return st + def show_all_variables(): model_vars = tf.trainable_variables() slim.model_analyzer.analyze_vars(model_vars, print_info=True) def get_image(image_path, input_height, input_width, resize_height=64, resize_width=64, - is_crop=True, is_grayscale=False): - image = imread(image_path, is_grayscale) + crop=True, grayscale=False): + image = imread(image_path, grayscale) return transform(image, input_height, input_width, - resize_height, resize_width, is_crop) + resize_height, resize_width, crop) def save_images(images, size, image_path): return imsave(inverse_transform(images), size, image_path) -def imread(path, is_grayscale = False): - if (is_grayscale): +def imread(path, grayscale = False): + if (grayscale): return scipy.misc.imread(path, flatten = True).astype(np.float) else: - return scipy.misc.imread(path).astype(np.float) + # Reference: https://github.com/carpedm20/DCGAN-tensorflow/issues/162#issuecomment-315519747 + img_bgr = cv2.imread(path) + # Reference: https://stackoverflow.com/a/15074748/ + img_rgb = img_bgr[..., ::-1] + return img_rgb.astype(np.float) def merge_images(images, size): return inverse_transform(images) @@ -63,7 +81,8 @@ def merge(images, size): 'must have dimensions: HxW or HxWx3 or HxWx4') def imsave(images, size, path): - return scipy.misc.imsave(path, merge(images, size)) + image = np.squeeze(merge(images, size)) + return scipy.misc.imsave(path, image) def center_crop(x, crop_h, crop_w, resize_h=64, resize_w=64): @@ -72,18 +91,18 @@ def center_crop(x, crop_h, crop_w, h, w = x.shape[:2] j = int(round((h - crop_h)/2.)) i = int(round((w - crop_w)/2.)) - return scipy.misc.imresize( - x[j:j+crop_h, i:i+crop_w], [resize_h, resize_w]) + im = Image.fromarray(np.uint8(x[j:j+crop_h, i:i+crop_w])) + return np.array(im.resize([resize_h, resize_w], Image.BILINEAR)) def transform(image, input_height, input_width, - resize_height=64, resize_width=64, is_crop=True): - if is_crop: + resize_height=64, resize_width=64, crop=True): + if crop: cropped_image = center_crop( image, input_height, input_width, resize_height, resize_width) else: - cropped_image = scipy.misc.imresize(image, [resize_height, resize_width]) + im = Image.fromarray(np.uint8(image)) + cropped_image = np.array(im.resize([resize_height, resize_width], Image.BILINEAR)) return np.array(cropped_image)/127.5 - 1. def inverse_transform(images): return (images+1.)/2. 
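For reference, the patched utils.py helpers follow a fixed value-range contract: imread returns RGB floats in [0, 255] (cv2 loads BGR, hence the `[..., ::-1]` channel reversal), transform squashes pixels into the generator's tanh range [-1, 1], and inverse_transform maps samples back to [0, 1] for imsave. A minimal standalone sketch of that round trip (NumPy only; the helper names here are illustrative, not part of the patch):

```python
import numpy as np

def to_tanh_range(pixels):
    # what transform() returns: 8-bit pixel values scaled to [-1, 1]
    return np.array(pixels) / 127.5 - 1.

def to_unit_range(samples):
    # what inverse_transform() returns: tanh outputs scaled to [0, 1]
    return (samples + 1.) / 2.

bgr = np.random.randint(0, 256, size=(64, 64, 3)).astype(np.float64)
rgb = bgr[..., ::-1]           # cv2.imread() gives BGR; reverse the channel axis
x = to_tanh_range(rgb)
assert x.min() >= -1. and x.max() <= 1.
assert np.allclose(to_unit_range(x) * 255., rgb)   # lossless round trip
```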
@@ -168,17 +187,17 @@ def make_frame(t): clip = mpy.VideoClip(make_frame, duration=duration) clip.write_gif(fname, fps = len(images) / duration) -def visualize(sess, dcgan, config, option): +def visualize(sess, dcgan, config, option, sample_dir='samples'): image_frame_dim = int(math.ceil(config.batch_size**.5)) if option == 0: z_sample = np.random.uniform(-0.5, 0.5, size=(config.batch_size, dcgan.z_dim)) samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}) - save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_%s.png' % strftime("%Y%m%d%H%M%S", gmtime())) + save_images(samples, [image_frame_dim, image_frame_dim], os.path.join(sample_dir, 'test_%s.png' % strftime("%Y%m%d%H%M%S", gmtime() ))) elif option == 1: values = np.arange(0, 1, 1./config.batch_size) - for idx in xrange(100): + for idx in xrange(dcgan.z_dim): print(" [*] %d" % idx) - z_sample = np.zeros([config.batch_size, dcgan.z_dim]) + z_sample = np.random.uniform(-1, 1, size=(config.batch_size , dcgan.z_dim)) for kdx, z in enumerate(z_sample): z[idx] = values[kdx] @@ -191,10 +210,10 @@ def visualize(sess, dcgan, config, option): else: samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}) - save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_arange_%s.png' % (idx)) + save_images(samples, [image_frame_dim, image_frame_dim], os.path.join(sample_dir, 'test_arange_%s.png' % (idx))) elif option == 2: values = np.arange(0, 1, 1./config.batch_size) - for idx in [random.randint(0, 99) for _ in xrange(100)]: + for idx in [random.randint(0, dcgan.z_dim - 1) for _ in xrange(dcgan.z_dim)]: print(" [*] %d" % idx) z = np.random.uniform(-0.2, 0.2, size=(dcgan.z_dim)) z_sample = np.tile(z, (config.batch_size, 1)) @@ -214,29 +233,36 @@ def visualize(sess, dcgan, config, option): try: make_gif(samples, './samples/test_gif_%s.gif' % (idx)) except: - save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_%s.png' % strftime("%Y%m%d%H%M%S", gmtime())) + save_images(samples, [image_frame_dim, image_frame_dim], os.path.join(sample_dir, 'test_%s.png' % strftime("%Y%m%d%H%M%S", gmtime() ))) elif option == 3: values = np.arange(0, 1, 1./config.batch_size) - for idx in xrange(100): + for idx in xrange(dcgan.z_dim): print(" [*] %d" % idx) z_sample = np.zeros([config.batch_size, dcgan.z_dim]) for kdx, z in enumerate(z_sample): z[idx] = values[kdx] samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}) - make_gif(samples, './samples/test_gif_%s.gif' % (idx)) + make_gif(samples, os.path.join(sample_dir, 'test_gif_%s.gif' % (idx))) elif option == 4: image_set = [] values = np.arange(0, 1, 1./config.batch_size) - for idx in xrange(100): + for idx in xrange(dcgan.z_dim): print(" [*] %d" % idx) z_sample = np.zeros([config.batch_size, dcgan.z_dim]) for kdx, z in enumerate(z_sample): z[idx] = values[kdx] image_set.append(sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})) - make_gif(image_set[-1], './samples/test_gif_%s.gif' % (idx)) + make_gif(image_set[-1], os.path.join(sample_dir, 'test_gif_%s.gif' % (idx))) new_image_set = [merge(np.array([images[idx] for images in image_set]), [10, 10]) \ for idx in range(64) + range(63, -1, -1)] make_gif(new_image_set, './samples/test_gif_merged.gif', duration=8) + + +def image_manifold_size(num_images): + manifold_h = int(np.floor(np.sqrt(num_images))) + manifold_w = int(np.ceil(np.sqrt(num_images))) + assert manifold_h * manifold_w == num_images + return manifold_h, manifold_w diff --git a/web/index.html b/web/index.html index 
4d1026153..6ea2fe615 100644 --- a/web/index.html +++ b/web/index.html @@ -34,8 +34,7 @@ @@ -225,7 +224,7 @@ @@ -405,7 +404,7 @@
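With this patch, `--freeze` exports the generator as a frozen GraphDef whose output node is the "generator_1/Tanh" named in save() above. A sketch of reloading such a file for inference; the .pb path is hypothetical (save() writes '<filename>-<step>_frz.pb' under a 'frozen_b<batch_size>' folder), and the defaults z_dim=100, batch_size=64, and z_dist='uniform_signed' are assumed:

```python
import numpy as np
import tensorflow as tf

# Hypothetical export location; substitute the file --freeze actually wrote.
PB_PATH = './out/MY_RUN/checkpoint/frozen_b64/model.b64-001000_frz.pb'

graph_def = tf.GraphDef()
with tf.gfile.GFile(PB_PATH, 'rb') as f:
    graph_def.ParseFromString(f.read())

graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name='')

z = graph.get_tensor_by_name('z:0')                 # placeholder named 'z' in model.py
G = graph.get_tensor_by_name('generator_1/Tanh:0')  # node frozen by save(..., frozen=True)

with tf.Session(graph=graph) as sess:
    # mirrors gen_random('uniform_signed', ...), the default --z_dist
    z_batch = np.random.uniform(-1, 1, size=(64, 100)).astype(np.float32)
    samples = sess.run(G, feed_dict={z: z_batch})   # shape (64, H, W, c_dim), values in [-1, 1]
```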