1
+ #-*- coding: utf-8 -*-
2
+ from __future__ import division
3
+ import os
4
+ import time
5
+ import tensorflow as tf
6
+ import numpy as np
7
+
8
+ from ops import *
9
+ from utils import *
10
+
11
+ class ACGAN (object ):
12
+ model_name = "ACGAN" # name for checkpoint
13
+
14
+ def __init__ (self , sess , epoch , batch_size , z_dim , dataset_name , checkpoint_dir , result_dir , log_dir ):
15
+ self .sess = sess
16
+ self .dataset_name = dataset_name
17
+ self .checkpoint_dir = checkpoint_dir
18
+ self .result_dir = result_dir
19
+ self .log_dir = log_dir
20
+ self .epoch = epoch
21
+ self .batch_size = batch_size
22
+
23
+ if dataset_name == 'mnist' or dataset_name == 'fashion-mnist' :
24
+ # parameters
25
+ self .input_height = 28
26
+ self .input_width = 28
27
+ self .output_height = 28
28
+ self .output_width = 28
29
+
30
+ self .z_dim = z_dim # dimension of noise-vector
31
+ self .y_dim = 10 # dimension of code-vector (label)
32
+ self .c_dim = 1
33
+
34
+ # train
35
+ self .learning_rate = 0.0002
36
+ self .beta1 = 0.5
37
+
38
+ # test
39
+ self .sample_num = 64 # number of generated images to be saved
40
+
41
+ # code
42
+ self .len_discrete_code = 10 # categorical distribution (i.e. label)
43
+ self .len_continuous_code = 2 # gaussian distribution (e.g. rotation, thickness)
44
+
45
+ # load mnist
46
+ self .data_X , self .data_y = load_mnist (self .dataset_name )
47
+
48
+ # get number of batches for a single epoch
49
+ self .num_batches = len (self .data_X ) // self .batch_size
50
+ else :
51
+ raise NotImplementedError
52
+
53
+ def classifier (self , x , is_training = True , reuse = False ):
54
+ # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
55
+ # Architecture : (64)5c2s-(128)5c2s_BL-FC1024_BL-FC128_BL-FC12S’
56
+ # All layers except the last two layers are shared by discriminator
57
+ with tf .variable_scope ("classifier" , reuse = reuse ):
58
+
59
+ net = lrelu (bn (linear (x , 128 , scope = 'c_fc1' ), is_training = is_training , scope = 'c_bn1' ))
60
+ out_logit = linear (net , self .y_dim , scope = 'c_fc2' )
61
+ out = tf .nn .softmax (out_logit )
62
+
63
+ return out , out_logit
64
+
65
+ def discriminator (self , x , is_training = True , reuse = False ):
66
+ # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
67
+ # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
68
+ with tf .variable_scope ("discriminator" , reuse = reuse ):
69
+
70
+ net = lrelu (conv2d (x , 64 , 4 , 4 , 2 , 2 , name = 'd_conv1' ))
71
+ net = lrelu (bn (conv2d (net , 128 , 4 , 4 , 2 , 2 , name = 'd_conv2' ), is_training = is_training , scope = 'd_bn2' ))
72
+ net = tf .reshape (net , [self .batch_size , - 1 ])
73
+ net = lrelu (bn (linear (net , 1024 , scope = 'd_fc3' ), is_training = is_training , scope = 'd_bn3' ))
74
+ out_logit = linear (net , 1 , scope = 'd_fc4' )
75
+ out = tf .nn .sigmoid (out_logit )
76
+
77
+ return out , out_logit , net
78
+
79
+ def generator (self , z , y , is_training = True , reuse = False ):
80
+ # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
81
+ # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
82
+ with tf .variable_scope ("generator" , reuse = reuse ):
83
+
84
+ # merge noise and code
85
+ z = concat ([z , y ], 1 )
86
+
87
+ net = tf .nn .relu (bn (linear (z , 1024 , scope = 'g_fc1' ), is_training = is_training , scope = 'g_bn1' ))
88
+ net = tf .nn .relu (bn (linear (net , 128 * 7 * 7 , scope = 'g_fc2' ), is_training = is_training , scope = 'g_bn2' ))
89
+ net = tf .reshape (net , [self .batch_size , 7 , 7 , 128 ])
90
+ net = tf .nn .relu (
91
+ bn (deconv2d (net , [self .batch_size , 14 , 14 , 64 ], 4 , 4 , 2 , 2 , name = 'g_dc3' ), is_training = is_training ,
92
+ scope = 'g_bn3' ))
93
+
94
+ out = tf .nn .sigmoid (deconv2d (net , [self .batch_size , 28 , 28 , 1 ], 4 , 4 , 2 , 2 , name = 'g_dc4' ))
95
+
96
+ return out
97
+
98
+ def build_model (self ):
99
+ # some parameters
100
+ image_dims = [self .input_height , self .input_width , self .c_dim ]
101
+ bs = self .batch_size
102
+
103
+ """ Graph Input """
104
+ # images
105
+ self .inputs = tf .placeholder (tf .float32 , [bs ] + image_dims , name = 'real_images' )
106
+
107
+ # labels
108
+ self .y = tf .placeholder (tf .float32 , [bs , self .y_dim ], name = 'y' )
109
+
110
+ # noises
111
+ self .z = tf .placeholder (tf .float32 , [bs , self .z_dim ], name = 'z' )
112
+
113
+ """ Loss Function """
114
+ ## 1. GAN Loss
115
+ # output of D for real images
116
+ D_real , D_real_logits , input4classifier_real = self .discriminator (self .inputs , is_training = True , reuse = False )
117
+
118
+ # output of D for fake images
119
+ G = self .generator (self .z , self .y , is_training = True , reuse = False )
120
+ D_fake , D_fake_logits , input4classifier_fake = self .discriminator (G , is_training = True , reuse = True )
121
+
122
+ # get loss for discriminator
123
+ d_loss_real = tf .reduce_mean (
124
+ tf .nn .sigmoid_cross_entropy_with_logits (logits = D_real_logits , labels = tf .ones_like (D_real )))
125
+ d_loss_fake = tf .reduce_mean (
126
+ tf .nn .sigmoid_cross_entropy_with_logits (logits = D_fake_logits , labels = tf .zeros_like (D_fake )))
127
+
128
+ self .d_loss = d_loss_real + d_loss_fake
129
+
130
+ # get loss for generator
131
+ self .g_loss = tf .reduce_mean (
132
+ tf .nn .sigmoid_cross_entropy_with_logits (logits = D_fake_logits , labels = tf .ones_like (D_fake )))
133
+
134
+ ## 2. Information Loss
135
+ code_fake , code_logit_fake = self .classifier (input4classifier_fake , is_training = True , reuse = False )
136
+ code_real , code_logit_real = self .classifier (input4classifier_real , is_training = True , reuse = True )
137
+
138
+ # For real samples
139
+ q_real_loss = tf .reduce_mean (tf .nn .softmax_cross_entropy_with_logits (logits = code_logit_real , labels = self .y ))
140
+
141
+ # For fake samples
142
+ q_fake_loss = tf .reduce_mean (tf .nn .softmax_cross_entropy_with_logits (logits = code_logit_fake , labels = self .y ))
143
+
144
+ # get information loss
145
+ self .q_loss = q_fake_loss + q_real_loss
146
+
147
+ """ Training """
148
+ # divide trainable variables into a group for D and a group for G
149
+ t_vars = tf .trainable_variables ()
150
+ d_vars = [var for var in t_vars if 'd_' in var .name ]
151
+ g_vars = [var for var in t_vars if 'g_' in var .name ]
152
+ q_vars = [var for var in t_vars if ('d_' in var .name ) or ('c_' in var .name ) or ('g_' in var .name )]
153
+
154
+ # optimizers
155
+ with tf .control_dependencies (tf .get_collection (tf .GraphKeys .UPDATE_OPS )):
156
+ self .d_optim = tf .train .AdamOptimizer (self .learning_rate , beta1 = self .beta1 ) \
157
+ .minimize (self .d_loss , var_list = d_vars )
158
+ self .g_optim = tf .train .AdamOptimizer (self .learning_rate * 5 , beta1 = self .beta1 ) \
159
+ .minimize (self .g_loss , var_list = g_vars )
160
+ self .q_optim = tf .train .AdamOptimizer (self .learning_rate * 5 , beta1 = self .beta1 ) \
161
+ .minimize (self .q_loss , var_list = q_vars )
162
+
163
+ """" Testing """
164
+ # for test
165
+ self .fake_images = self .generator (self .z , self .y , is_training = False , reuse = True )
166
+
167
+ """ Summary """
168
+ d_loss_real_sum = tf .summary .scalar ("d_loss_real" , d_loss_real )
169
+ d_loss_fake_sum = tf .summary .scalar ("d_loss_fake" , d_loss_fake )
170
+ d_loss_sum = tf .summary .scalar ("d_loss" , self .d_loss )
171
+ g_loss_sum = tf .summary .scalar ("g_loss" , self .g_loss )
172
+
173
+ q_loss_sum = tf .summary .scalar ("g_loss" , self .q_loss )
174
+ q_real_sum = tf .summary .scalar ("q_real_loss" , q_real_loss )
175
+ q_fake_sum = tf .summary .scalar ("q_fake_loss" , q_fake_loss )
176
+
177
+ # final summary operations
178
+ self .g_sum = tf .summary .merge ([d_loss_fake_sum , g_loss_sum ])
179
+ self .d_sum = tf .summary .merge ([d_loss_real_sum , d_loss_sum ])
180
+ self .q_sum = tf .summary .merge ([q_loss_sum , q_real_sum , q_fake_sum ])
181
+
182
+ def train (self ):
183
+
184
+ # initialize all variables
185
+ tf .global_variables_initializer ().run ()
186
+
187
+ # graph inputs for visualize training results
188
+ self .sample_z = np .random .uniform (- 1 , 1 , size = (self .batch_size , self .z_dim ))
189
+ self .test_codes = self .data_y [0 :self .batch_size ]
190
+
191
+ # saver to save model
192
+ self .saver = tf .train .Saver ()
193
+
194
+ # summary writer
195
+ self .writer = tf .summary .FileWriter (self .log_dir + '/' + self .model_name , self .sess .graph )
196
+
197
+ # restore check-point if it exits
198
+ could_load , checkpoint_counter = self .load (self .checkpoint_dir )
199
+ if could_load :
200
+ start_epoch = (int )(checkpoint_counter / self .num_batches )
201
+ start_batch_id = checkpoint_counter - start_epoch * self .num_batches
202
+ counter = checkpoint_counter
203
+ print (" [*] Load SUCCESS" )
204
+ else :
205
+ start_epoch = 0
206
+ start_batch_id = 0
207
+ counter = 1
208
+ print (" [!] Load failed..." )
209
+
210
+ # loop for epoch
211
+ start_time = time .time ()
212
+ for epoch in range (start_epoch , self .epoch ):
213
+
214
+ # get batch data
215
+ for idx in range (start_batch_id , self .num_batches ):
216
+ batch_images = self .data_X [idx * self .batch_size :(idx + 1 )* self .batch_size ]
217
+ batch_codes = self .data_y [idx * self .batch_size :(idx + 1 ) * self .batch_size ]
218
+
219
+ batch_z = np .random .uniform (- 1 , 1 , [self .batch_size , self .z_dim ]).astype (np .float32 )
220
+
221
+ # update D network
222
+ _ , summary_str , d_loss = self .sess .run ([self .d_optim , self .d_sum , self .d_loss ],
223
+ feed_dict = {self .inputs : batch_images , self .y : batch_codes ,
224
+ self .z : batch_z })
225
+ self .writer .add_summary (summary_str , counter )
226
+
227
+ # update G & Q network
228
+ _ , summary_str_g , g_loss , _ , summary_str_q , q_loss = self .sess .run (
229
+ [self .g_optim , self .g_sum , self .g_loss , self .q_optim , self .q_sum , self .q_loss ],
230
+ feed_dict = {self .z : batch_z , self .y : batch_codes , self .inputs : batch_images })
231
+ self .writer .add_summary (summary_str_g , counter )
232
+ self .writer .add_summary (summary_str_q , counter )
233
+
234
+ # display training status
235
+ counter += 1
236
+ print ("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
237
+ % (epoch , idx , self .num_batches , time .time () - start_time , d_loss , g_loss ))
238
+
239
+ # save training results for every 300 steps
240
+ if np .mod (counter , 300 ) == 0 :
241
+ samples = self .sess .run (self .fake_images ,
242
+ feed_dict = {self .z : self .sample_z , self .y : self .test_codes })
243
+ tot_num_samples = min (self .sample_num , self .batch_size )
244
+ manifold_h = int (np .floor (np .sqrt (tot_num_samples )))
245
+ manifold_w = int (np .floor (np .sqrt (tot_num_samples )))
246
+ save_images (samples [:manifold_h * manifold_w , :, :, :], [manifold_h , manifold_w ], './' + check_folder (
247
+ self .result_dir + '/' + self .model_dir ) + '/' + self .model_name + '_train_{:02d}_{:04d}.png' .format (
248
+ epoch , idx ))
249
+
250
+ # After an epoch, start_batch_id is set to zero
251
+ # non-zero value is only for the first epoch after loading pre-trained model
252
+ start_batch_id = 0
253
+
254
+ # save model
255
+ self .save (self .checkpoint_dir , counter )
256
+
257
+ # show temporal results
258
+ self .visualize_results (epoch )
259
+
260
+ # save model for final step
261
+ self .save (self .checkpoint_dir , counter )
262
+
263
+ def visualize_results (self , epoch ):
264
+ tot_num_samples = min (self .sample_num , self .batch_size )
265
+ image_frame_dim = int (np .floor (np .sqrt (tot_num_samples )))
266
+ z_sample = np .random .uniform (- 1 , 1 , size = (self .batch_size , self .z_dim ))
267
+
268
+ """ random noise, random discrete code, fixed continuous code """
269
+ y = np .random .choice (self .len_discrete_code , self .batch_size )
270
+ y_one_hot = np .zeros ((self .batch_size , self .y_dim ))
271
+ y_one_hot [np .arange (self .batch_size ), y ] = 1
272
+
273
+ samples = self .sess .run (self .fake_images , feed_dict = {self .z : z_sample , self .y : y_one_hot })
274
+
275
+ save_images (samples [:image_frame_dim * image_frame_dim ,:,:,:], [image_frame_dim , image_frame_dim ],
276
+ check_folder (self .result_dir + '/' + self .model_dir ) + '/' + self .model_name + '_epoch%03d' % epoch + '_test_all_classes.png' )
277
+
278
+ """ specified condition, random noise """
279
+ n_styles = 10 # must be less than or equal to self.batch_size
280
+
281
+ np .random .seed ()
282
+ si = np .random .choice (self .batch_size , n_styles )
283
+
284
+ for l in range (self .len_discrete_code ):
285
+ y = np .zeros (self .batch_size , dtype = np .int64 ) + l
286
+ y_one_hot = np .zeros ((self .batch_size , self .y_dim ))
287
+ y_one_hot [np .arange (self .batch_size ), y ] = 1
288
+
289
+ samples = self .sess .run (self .fake_images , feed_dict = {self .z : z_sample , self .y : y_one_hot })
290
+ save_images (samples [:image_frame_dim * image_frame_dim ,:,:,:], [image_frame_dim , image_frame_dim ],
291
+ check_folder (self .result_dir + '/' + self .model_dir ) + '/' + self .model_name + '_epoch%03d' % epoch + '_test_class_%d.png' % l )
292
+
293
+ samples = samples [si , :, :, :]
294
+
295
+ if l == 0 :
296
+ all_samples = samples
297
+ else :
298
+ all_samples = np .concatenate ((all_samples , samples ), axis = 0 )
299
+
300
+ """ save merged images to check style-consistency """
301
+ canvas = np .zeros_like (all_samples )
302
+ for s in range (n_styles ):
303
+ for c in range (self .len_discrete_code ):
304
+ canvas [s * self .len_discrete_code + c , :, :, :] = all_samples [c * n_styles + s , :, :, :]
305
+
306
+ save_images (canvas , [n_styles , self .len_discrete_code ],
307
+ check_folder (self .result_dir + '/' + self .model_dir ) + '/' + self .model_name + '_epoch%03d' % epoch + '_test_all_classes_style_by_style.png' )
308
+
309
+ @property
310
+ def model_dir (self ):
311
+ return "{}_{}_{}_{}" .format (
312
+ self .model_name , self .dataset_name ,
313
+ self .batch_size , self .z_dim )
314
+
315
+ def save (self , checkpoint_dir , step ):
316
+ checkpoint_dir = os .path .join (checkpoint_dir , self .model_dir , self .model_name )
317
+
318
+ if not os .path .exists (checkpoint_dir ):
319
+ os .makedirs (checkpoint_dir )
320
+
321
+ self .saver .save (self .sess ,os .path .join (checkpoint_dir , self .model_name + '.model' ), global_step = step )
322
+
323
+ def load (self , checkpoint_dir ):
324
+ import re
325
+ print (" [*] Reading checkpoints..." )
326
+ checkpoint_dir = os .path .join (checkpoint_dir , self .model_dir , self .model_name )
327
+
328
+ ckpt = tf .train .get_checkpoint_state (checkpoint_dir )
329
+ if ckpt and ckpt .model_checkpoint_path :
330
+ ckpt_name = os .path .basename (ckpt .model_checkpoint_path )
331
+ self .saver .restore (self .sess , os .path .join (checkpoint_dir , ckpt_name ))
332
+ counter = int (next (re .finditer ("(\d+)(?!.*\d)" ,ckpt_name )).group (0 ))
333
+ print (" [*] Success to read {}" .format (ckpt_name ))
334
+ return True , counter
335
+ else :
336
+ print (" [*] Failed to find a checkpoint" )
337
+ return False , 0
0 commit comments