Skip to content

Commit 924d8a6

Browse files
committed
publish firsy
1 parent f621d25 commit 924d8a6

14 files changed

+3787
-0
lines changed

Diff for: ACGAN.py

+337
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,337 @@
1+
#-*- coding: utf-8 -*-
2+
from __future__ import division
3+
import os
4+
import time
5+
import tensorflow as tf
6+
import numpy as np
7+
8+
from ops import *
9+
from utils import *
10+
11+
class ACGAN(object):
    """Auxiliary-classifier GAN (AC-GAN) built with the TensorFlow 1.x graph API.

    The model has three sub-networks, each in its own variable scope: a
    label-conditioned ``generator``, a ``discriminator`` that scores images
    as real/fake, and a ``classifier`` that predicts the class label from the
    discriminator's last hidden features.

    ``build_model()`` must be called once before ``train()``, because
    ``train()`` uses the placeholders, losses and optimizers it creates.
    """

    model_name = "ACGAN"     # name for checkpoint

    def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir):
        """Store the configuration and load the training data.

        Args:
            sess: TensorFlow session used for every ``run`` call.
            epoch: number of epochs ``train()`` iterates up to.
            batch_size: fixed batch size baked into the graph's placeholders.
            z_dim: dimension of the noise vector fed to the generator.
            dataset_name: ``'mnist'`` or ``'fashion-mnist'``.
            checkpoint_dir: root directory for checkpoints.
            result_dir: root directory for generated sample images.
            log_dir: root directory for summary logs.

        Raises:
            NotImplementedError: if ``dataset_name`` is not supported.
        """
        self.sess = sess
        self.dataset_name = dataset_name
        self.checkpoint_dir = checkpoint_dir
        self.result_dir = result_dir
        self.log_dir = log_dir
        self.epoch = epoch
        self.batch_size = batch_size

        if dataset_name not in ('mnist', 'fashion-mnist'):
            raise NotImplementedError(
                "unsupported dataset: {!r} (expected 'mnist' or 'fashion-mnist')".format(dataset_name))

        # image geometry
        self.input_height = 28
        self.input_width = 28
        self.output_height = 28
        self.output_width = 28

        self.z_dim = z_dim         # dimension of noise-vector
        self.y_dim = 10            # dimension of code-vector (label)
        self.c_dim = 1             # number of image channels

        # train
        self.learning_rate = 0.0002
        self.beta1 = 0.5

        # test
        self.sample_num = 64       # number of generated images to be saved

        # code
        self.len_discrete_code = 10    # categorical distribution (i.e. label)
        self.len_continuous_code = 2   # gaussian distribution (e.g. rotation, thickness)

        # load the dataset
        # NOTE(review): presumably data_X is (N, 28, 28, 1) and data_y is a
        # one-hot (N, 10) array, since batches of them are fed straight into
        # the placeholders created in build_model -- confirm in utils.load_mnist.
        self.data_X, self.data_y = load_mnist(self.dataset_name)

        # number of full batches in a single epoch
        self.num_batches = len(self.data_X) // self.batch_size

    def classifier(self, x, is_training=True, reuse=False):
        """Predict the class label from discriminator features.

        Layers as built below: FC128 -> batch norm -> leaky ReLU -> FC(y_dim),
        followed by a softmax. In build_model, ``x`` is the hidden-feature
        tensor returned as the third output of ``discriminator``.

        Returns:
            (softmax probabilities, raw logits)
        """
        with tf.variable_scope("classifier", reuse=reuse):
            hidden = linear(x, 128, scope='c_fc1')
            hidden = lrelu(bn(hidden, is_training=is_training, scope='c_bn1'))
            out_logit = linear(hidden, self.y_dim, scope='c_fc2')
            out = tf.nn.softmax(out_logit)

            return out, out_logit

    def discriminator(self, x, is_training=True, reuse=False):
        """Score images as real or fake.

        Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
        Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S

        Returns:
            (sigmoid output, raw logit, FC1024 hidden features). The hidden
            features are what build_model passes to ``classifier``.
        """
        with tf.variable_scope("discriminator", reuse=reuse):
            net = lrelu(conv2d(x, 64, 4, 4, 2, 2, name='d_conv1'))
            net = conv2d(net, 128, 4, 4, 2, 2, name='d_conv2')
            net = lrelu(bn(net, is_training=is_training, scope='d_bn2'))
            # flatten per sample; the batch dimension is fixed to batch_size
            net = tf.reshape(net, [self.batch_size, -1])
            net = linear(net, 1024, scope='d_fc3')
            net = lrelu(bn(net, is_training=is_training, scope='d_bn3'))
            out_logit = linear(net, 1, scope='d_fc4')
            out = tf.nn.sigmoid(out_logit)

            return out, out_logit, net

    def generator(self, z, y, is_training=True, reuse=False):
        """Generate images from noise ``z`` conditioned on label code ``y``.

        Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
        Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S

        Returns:
            Sigmoid-activated tensor of shape [batch_size, 28, 28, 1].
        """
        with tf.variable_scope("generator", reuse=reuse):
            # merge noise and code
            z = concat([z, y], 1)

            net = linear(z, 1024, scope='g_fc1')
            net = tf.nn.relu(bn(net, is_training=is_training, scope='g_bn1'))
            net = linear(net, 128 * 7 * 7, scope='g_fc2')
            net = tf.nn.relu(bn(net, is_training=is_training, scope='g_bn2'))
            net = tf.reshape(net, [self.batch_size, 7, 7, 128])
            net = deconv2d(net, [self.batch_size, 14, 14, 64], 4, 4, 2, 2, name='g_dc3')
            net = tf.nn.relu(bn(net, is_training=is_training, scope='g_bn3'))

            out = tf.nn.sigmoid(deconv2d(net, [self.batch_size, 28, 28, 1], 4, 4, 2, 2, name='g_dc4'))

            return out

    def build_model(self):
        """Create placeholders, losses, optimizers, the test op and summaries.

        Sets ``self.inputs``, ``self.y``, ``self.z``, ``self.d_loss``,
        ``self.g_loss``, ``self.q_loss``, ``self.d_optim``, ``self.g_optim``,
        ``self.q_optim``, ``self.fake_images`` and the merged summaries
        ``self.d_sum``, ``self.g_sum``, ``self.q_sum``.
        """
        # some parameters
        image_dims = [self.input_height, self.input_width, self.c_dim]
        bs = self.batch_size

        # ---- Graph Input ----
        # images
        self.inputs = tf.placeholder(tf.float32, [bs] + image_dims, name='real_images')
        # labels
        self.y = tf.placeholder(tf.float32, [bs, self.y_dim], name='y')
        # noises
        self.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z')

        # ---- Loss Function ----
        ## 1. GAN Loss
        # output of D for real images
        D_real, D_real_logits, input4classifier_real = self.discriminator(
            self.inputs, is_training=True, reuse=False)

        # output of D for fake images
        G = self.generator(self.z, self.y, is_training=True, reuse=False)
        D_fake, D_fake_logits, input4classifier_fake = self.discriminator(
            G, is_training=True, reuse=True)

        # discriminator: real images -> 1, generated images -> 0
        d_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real_logits, labels=tf.ones_like(D_real)))
        d_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.zeros_like(D_fake)))
        self.d_loss = d_loss_real + d_loss_fake

        # generator: wants D to output 1 for generated images
        self.g_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.ones_like(D_fake)))

        ## 2. Information Loss
        # the first call creates the classifier variables, the second reuses them
        code_fake, code_logit_fake = self.classifier(input4classifier_fake, is_training=True, reuse=False)
        code_real, code_logit_real = self.classifier(input4classifier_real, is_training=True, reuse=True)

        # For real samples
        q_real_loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=code_logit_real, labels=self.y))

        # For fake samples
        q_fake_loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=code_logit_fake, labels=self.y))

        # get information loss
        self.q_loss = q_fake_loss + q_real_loss

        # ---- Training ----
        # divide trainable variables into groups by the layer-name prefixes
        # ('d_', 'g_', 'c_') used in the three networks
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'd_' in var.name]
        g_vars = [var for var in t_vars if 'g_' in var.name]
        # the classification loss updates D, G and the classifier together
        q_vars = [var for var in t_vars if ('d_' in var.name) or ('c_' in var.name) or ('g_' in var.name)]

        # optimizers; run the ops collected in UPDATE_OPS before each step
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1) \
                .minimize(self.d_loss, var_list=d_vars)
            self.g_optim = tf.train.AdamOptimizer(self.learning_rate * 5, beta1=self.beta1) \
                .minimize(self.g_loss, var_list=g_vars)
            self.q_optim = tf.train.AdamOptimizer(self.learning_rate * 5, beta1=self.beta1) \
                .minimize(self.q_loss, var_list=q_vars)

        # ---- Testing ----
        # generator in inference mode, sharing the training variables
        self.fake_images = self.generator(self.z, self.y, is_training=False, reuse=True)

        # ---- Summary ----
        d_loss_real_sum = tf.summary.scalar("d_loss_real", d_loss_real)
        d_loss_fake_sum = tf.summary.scalar("d_loss_fake", d_loss_fake)
        d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
        g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)

        # tag fixed: this summary was previously registered as "g_loss" too
        q_loss_sum = tf.summary.scalar("q_loss", self.q_loss)
        q_real_sum = tf.summary.scalar("q_real_loss", q_real_loss)
        q_fake_sum = tf.summary.scalar("q_fake_loss", q_fake_loss)

        # final summary operations
        self.g_sum = tf.summary.merge([d_loss_fake_sum, g_loss_sum])
        self.d_sum = tf.summary.merge([d_loss_real_sum, d_loss_sum])
        self.q_sum = tf.summary.merge([q_loss_sum, q_real_sum, q_fake_sum])

    def train(self):
        """Run the training loop, resuming from the latest checkpoint if any.

        Each step updates D, then G and the classification loss together.
        Sample images are written every 300 steps; a checkpoint and the
        visualizations from ``visualize_results`` are written every epoch.
        """
        # initialize all variables through the model's own session, so no
        # implicit default session is required
        self.sess.run(tf.global_variables_initializer())

        # fixed graph inputs for visualizing training progress
        self.sample_z = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))
        self.test_codes = self.data_y[0:self.batch_size]

        # saver to save model (must exist before self.load is called)
        self.saver = tf.train.Saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)

        # restore check-point if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch = checkpoint_counter // self.num_batches
            start_batch_id = checkpoint_counter - start_epoch * self.num_batches
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        for epoch in range(start_epoch, self.epoch):

            # get batch data
            for idx in range(start_batch_id, self.num_batches):
                lo = idx * self.batch_size
                hi = lo + self.batch_size
                batch_images = self.data_X[lo:hi]
                batch_codes = self.data_y[lo:hi]
                batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]).astype(np.float32)
                feed = {self.inputs: batch_images, self.y: batch_codes, self.z: batch_z}

                # update D network
                _, summary_str, d_loss = self.sess.run(
                    [self.d_optim, self.d_sum, self.d_loss], feed_dict=feed)
                self.writer.add_summary(summary_str, counter)

                # update G & Q network
                _, summary_str_g, g_loss, _, summary_str_q, _ = self.sess.run(
                    [self.g_optim, self.g_sum, self.g_loss, self.q_optim, self.q_sum, self.q_loss],
                    feed_dict=feed)
                self.writer.add_summary(summary_str_g, counter)
                self.writer.add_summary(summary_str_q, counter)

                # display training status
                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                      % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss))

                # save training results for every 300 steps
                if np.mod(counter, 300) == 0:
                    samples = self.sess.run(self.fake_images,
                                            feed_dict={self.z: self.sample_z, self.y: self.test_codes})
                    tot_num_samples = min(self.sample_num, self.batch_size)
                    # largest square grid that fits the available samples
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                    # use the directory from check_folder as-is; the former './'
                    # prefix broke absolute result_dir values
                    # NOTE(review): assumes check_folder creates the directory
                    # and returns its path -- confirm in utils.
                    sample_dir = check_folder(self.result_dir + '/' + self.model_dir)
                    save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w],
                                sample_dir + '/' + self.model_name
                                + '_train_{:02d}_{:04d}.png'.format(epoch, idx))

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model
            self.save(self.checkpoint_dir, counter)

            # show temporal results
            self.visualize_results(epoch)

        # save model for final step
        self.save(self.checkpoint_dir, counter)

    def visualize_results(self, epoch):
        """Write sample grids for ``epoch`` into the result directory.

        Produces one grid with random labels, one grid per class, and one
        merged grid in which each row uses one shared noise vector across all
        classes (to check style consistency).
        """
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
        n_shown = image_frame_dim * image_frame_dim
        grid = [image_frame_dim, image_frame_dim]
        # the same noise batch is reused for every class below
        z_sample = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))

        # ---- random noise, random discrete code ----
        y = np.random.choice(self.len_discrete_code, self.batch_size)
        y_one_hot = np.zeros((self.batch_size, self.y_dim))
        y_one_hot[np.arange(self.batch_size), y] = 1

        samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample, self.y: y_one_hot})

        save_images(samples[:n_shown, :, :, :], grid,
                    check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name
                    + '_epoch%03d' % epoch + '_test_all_classes.png')

        # ---- specified condition, random noise ----
        n_styles = 10  # must be less than or equal to self.batch_size

        # NOTE(review): this re-seeds NumPy's global RNG and so overrides any
        # seed set by the caller; kept to preserve existing behavior.
        np.random.seed()
        si = np.random.choice(self.batch_size, n_styles)

        per_class = []
        for l in range(self.len_discrete_code):
            y = np.zeros(self.batch_size, dtype=np.int64) + l
            y_one_hot = np.zeros((self.batch_size, self.y_dim))
            y_one_hot[np.arange(self.batch_size), y] = 1

            samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample, self.y: y_one_hot})
            save_images(samples[:n_shown, :, :, :], grid,
                        check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name
                        + '_epoch%03d' % epoch + '_test_class_%d.png' % l)

            # keep the same n_styles noise indices for every class
            per_class.append(samples[si, :, :, :])

        # class-major layout: index = class * n_styles + style
        all_samples = np.concatenate(per_class, axis=0)

        # ---- save merged images to check style-consistency ----
        # reorder to style-major so each grid row is one style across classes
        canvas = np.zeros_like(all_samples)
        for s in range(n_styles):
            for c in range(self.len_discrete_code):
                canvas[s * self.len_discrete_code + c, :, :, :] = all_samples[c * n_styles + s, :, :, :]

        save_images(canvas, [n_styles, self.len_discrete_code],
                    check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name
                    + '_epoch%03d' % epoch + '_test_all_classes_style_by_style.png')

    @property
    def model_dir(self):
        """Directory name derived from model name, dataset, batch size and z_dim."""
        return "{}_{}_{}_{}".format(
            self.model_name, self.dataset_name,
            self.batch_size, self.z_dim)

    def save(self, checkpoint_dir, step):
        """Save a checkpoint for ``step`` under ``checkpoint_dir/model_dir/model_name``.

        Requires ``self.saver``, which ``train()`` creates.
        """
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint, if one exists.

        Requires ``self.saver``, which ``train()`` creates.

        Returns:
            (True, step) when a checkpoint was restored, where ``step`` is the
            last run of digits in the checkpoint file name; (False, 0) otherwise.
        """
        import re
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            # last run of digits in the name, i.e. the global step appended by save()
            counter = int(re.search(r"(\d+)(?!.*\d)", ckpt_name).group(0))
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0

0 commit comments

Comments
 (0)