14
14
cuda .get_device (0 ).use ()
15
15
16
16
17
- def downsampling (array ):
18
- d2 = F .average_pooling_2d (array , 3 , 2 , 1 )
19
- d4 = F .average_pooling_2d (d2 , 3 , 2 , 1 )
20
-
21
- return d2 , d4
22
-
23
-
24
17
class GauGANLossFunction :
25
18
def __init__ (self ):
26
19
pass
27
20
28
21
@staticmethod
29
22
def content_loss (y , t ):
30
- return F .mean_absolute_error (y , t )
23
+ return 10.0 * F .mean_absolute_error (y , t )
31
24
32
25
@staticmethod
33
- def dis_hinge_loss (discriminator , y , t ):
34
- y_dis , _ = discriminator (y )
35
- t_dis , _ = discriminator (t )
26
+ def dis_loss (discriminator , y , t ):
27
+ y_adv_list , _ = discriminator (y )
28
+ t_adv_list , _ = discriminator (t )
29
+
30
+ sum_loss = 0
36
31
37
- return F .mean (F .relu (1. - t_dis )) + F .mean (F .relu (1. + y_dis ))
32
+ for y_adv , t_adv in zip (y_adv_list , t_adv_list ):
33
+ loss = F .mean (F .relu (1. - t_adv )) + F .mean (F .relu (1. + y_adv ))
34
+ sum_loss += loss
35
+
36
+ return sum_loss
38
37
39
38
@staticmethod
40
- def gen_hinge_loss (discriminator , y , t ):
41
- y_dis , y_feats = discriminator (y )
39
+ def gen_loss (discriminator , y , t ):
40
+ y_dis_list , y_feats = discriminator (y )
42
41
_ , t_feats = discriminator (t )
43
42
44
43
sum_loss = 0
45
- for yf , tf in zip (y_feats , t_feats ):
46
- _ , ch , height , width = yf .shape
47
- sum_loss += 10.0 * F .mean_absolute_error (yf , tf ) / (ch * height * width )
48
-
49
- return - F .mean (y_dis ) + sum_loss
50
44
45
+ # adversarial loss
46
+ for y_dis in y_dis_list :
47
+ loss = - F .mean (y_dis )
48
+ sum_loss += loss
49
+
50
+ # feature matching loss
51
+ for yf_list , tf_list in zip (y_feats , t_feats ):
52
+ for yf , tf in zip (yf_list , tf_list ):
53
+ _ , ch , height , width = yf .shape
54
+ sum_loss += 10.0 * F .mean_absolute_error (yf , tf ) / (ch * height * width )
55
+
56
+ return sum_loss
57
+
58
+
59
+ def train (epochs ,
60
+ iterations ,
61
+ batchsize ,
62
+ validsize ,
63
+ outdir ,
64
+ modeldir ,
65
+ data_path ,
66
+ extension ,
67
+ img_size ,
68
+ latent_dim ,
69
+ learning_rate ,
70
+ beta1 ,
71
+ beta2 ,
72
+ enable ):
51
73
52
- def train (epochs , iterations , batchsize , validsize , path , outdir ,
53
- con_weight , kl_weight , enable ):
54
74
# Dataset Definition
55
- dataloader = DataLoader (path )
75
+ dataloader = DataLoader (data_path , extension , img_size , latent_dim )
56
76
print (dataloader )
57
- color_valid , line_valid , _ , _ = dataloader (validsize , mode = "valid" )
77
+ color_valid , line_valid = dataloader (validsize , mode = "valid" )
58
78
noise_valid = dataloader .noise_generator (validsize )
59
79
60
80
# Model Definition
@@ -65,19 +85,11 @@ def train(epochs, iterations, batchsize, validsize, path, outdir,
65
85
66
86
generator = Generator ()
67
87
generator .to_gpu ()
68
- gen_opt = set_optimizer (generator )
88
+ gen_opt = set_optimizer (generator , learning_rate , beta1 , beta2 )
69
89
70
90
discriminator = Discriminator ()
71
91
discriminator .to_gpu ()
72
- dis_opt = set_optimizer (discriminator )
73
-
74
- discriminator_d2 = Discriminator ()
75
- discriminator_d2 .to_gpu ()
76
- dis2_opt = set_optimizer (discriminator_d2 )
77
-
78
- discriminator_d4 = Discriminator ()
79
- discriminator_d4 .to_gpu ()
80
- dis4_opt = set_optimizer (discriminator_d4 )
92
+ dis_opt = set_optimizer (discriminator , learning_rate , beta1 , beta2 )
81
93
82
94
# Loss Funtion Definition
83
95
lossfunc = GauGANLossFunction ()
@@ -86,95 +98,64 @@ def train(epochs, iterations, batchsize, validsize, path, outdir,
86
98
evaluator = Evaluaton ()
87
99
88
100
for epoch in range (epochs ):
89
- sum_loss = 0
101
+ sum_dis_loss = 0
102
+ sum_gen_loss = 0
90
103
for batch in range (0 , iterations , batchsize ):
91
- color , line , _ , _ = dataloader (batchsize )
92
-
93
- color_d2 , color_d4 = downsampling (color )
94
- line_d2 , line_d4 = downsampling (line )
104
+ color , line = dataloader (batchsize )
95
105
z = dataloader .noise_generator (batchsize )
96
106
107
+ # Discriminator update
97
108
if enable :
98
109
mu , sigma = encoder (color )
99
110
z = F .gaussian (mu , sigma )
100
111
y = generator (z , line )
101
- y_d2 , y_d4 = downsampling (y )
102
112
103
113
y .unchain_backward ()
104
- y_d2 .unchain_backward ()
105
- y_d4 .unchain_backward ()
106
114
107
- loss = lossfunc .dis_hinge_loss (
115
+ dis_loss = lossfunc .dis_loss (
108
116
discriminator ,
109
117
F .concat ([y , line ]),
110
118
F .concat ([color , line ])
111
119
)
112
- loss += lossfunc .dis_hinge_loss (
113
- discriminator_d2 ,
114
- F .concat ([y_d2 , line_d2 ]),
115
- F .concat ([color_d2 , line_d2 ])
116
- )
117
- loss += lossfunc .dis_hinge_loss (
118
- discriminator_d4 ,
119
- F .concat ([y_d4 , line_d4 ]),
120
- F .concat ([color_d4 , line_d4 ])
121
- )
122
120
123
121
discriminator .cleargrads ()
124
- discriminator_d2 .cleargrads ()
125
- discriminator_d4 .cleargrads ()
126
- loss .backward ()
122
+ dis_loss .backward ()
127
123
dis_opt .update ()
128
- dis2_opt . update ()
129
- dis4_opt . update ()
130
- loss . unchain_backward ()
124
+ dis_loss . unchain_backward ()
125
+
126
+ sum_dis_loss += dis_loss . data
131
127
128
+ # Generator update
132
129
z = dataloader .noise_generator (batchsize )
133
130
134
131
if enable :
135
132
mu , sigma = encoder (color )
136
133
z = F .gaussian (mu , sigma )
137
134
y = generator (z , line )
138
- y_d2 , y_d4 = downsampling (y )
139
135
140
- loss = lossfunc .gen_hinge_loss (
136
+ gen_loss = lossfunc .gen_loss (
141
137
discriminator ,
142
138
F .concat ([y , line ]),
143
139
F .concat ([color , line ])
144
140
)
145
- loss += lossfunc .gen_hinge_loss (
146
- discriminator_d2 ,
147
- F .concat ([y_d2 , line_d2 ]),
148
- F .concat ([color_d2 , line_d2 ])
149
- )
150
- loss += lossfunc .gen_hinge_loss (
151
- discriminator_d4 ,
152
- F .concat ([y_d4 , line_d4 ]),
153
- F .concat ([color_d4 , line_d4 ])
154
- )
155
- loss += con_weight * lossfunc .content_loss (y , color )
156
- loss += con_weight * lossfunc .content_loss (y_d2 , color_d2 )
157
- loss += con_weight * lossfunc .content_loss (y_d4 , color_d4 )
141
+ gen_loss += lossfunc .content_loss (y , color )
158
142
159
143
if enable :
160
- loss += kl_weight * F .gaussian_kl_divergence (mu , sigma ) / batchsize
144
+ gen_loss += 0.05 * F .gaussian_kl_divergence (mu , sigma ) / batchsize
161
145
162
146
generator .cleargrads ()
163
147
if enable :
164
148
encoder .cleargrads ()
165
- loss .backward ()
149
+ gen_loss .backward ()
166
150
gen_opt .update ()
167
151
if enable :
168
152
enc_opt .update ()
169
- loss .unchain_backward ()
153
+ gen_loss .unchain_backward ()
170
154
171
- sum_loss += loss .data
155
+ sum_gen_loss += gen_loss .data
172
156
173
157
if batch == 0 :
174
- serializers .save_npz (f"{ outdir } /generator.model" , generator )
175
- serializers .save_npz (f"{ outdir } /discriminator_0.model" , discriminator )
176
- serializers .save_npz (f"{ outdir } /discriminator_2.model" , discriminator_d2 )
177
- serializers .save_npz (f"{ outdir } /discriminator_4.model" , discriminator_d4 )
158
+ serializers .save_npz (f"{ modeldir } /generator_{ epoch } .model" , generator )
178
159
179
160
with chainer .using_config ("train" , False ):
180
161
y = generator (noise_valid , line_valid )
@@ -183,25 +164,35 @@ def train(epochs, iterations, batchsize, validsize, path, outdir,
183
164
cr = color_valid .data .get ()
184
165
185
166
evaluator (y , cr , sr , outdir , epoch , validsize = validsize )
186
-
187
- print (f"epoch: { epoch } " )
188
- print (f"loss: { sum_loss / iterations } " )
167
+
168
+ print (f"epoch: { epoch } " )
169
+ print (f"dis loss: { sum_dis_loss / iterations } gen loss: { sum_gen_loss / iterations } " )
189
170
190
171
191
172
if __name__ == "__main__" :
192
173
parser = argparse .ArgumentParser (description = "GauGAN" )
193
174
parser .add_argument ('--e' , type = int , default = 1000 , help = "the number of epochs" )
194
- parser .add_argument ('--i' , type = int , default = 10000 , help = "the number of iterations" )
175
+ parser .add_argument ('--i' , type = int , default = 2000 , help = "the number of iterations" )
195
176
parser .add_argument ('--b' , type = int , default = 16 , help = "batch size" )
196
- parser .add_argument ('--v' , type = int , default = 3 , help = "valid size" )
197
- parser .add_argument ('--w' , type = float , default = 10.0 , help = "the weight of content loss" )
198
- parser .add_argument ('--kl' , type = float , default = 0.05 , help = "the weight of kl divergence loss" )
177
+ parser .add_argument ('--v' , type = int , default = 12 , help = "valid size" )
178
+ parser .add_argument ('--outdir' , type = Path , default = 'outdir' , help = "output directory" )
179
+ parser .add_argument ('--modeldir' , type = Path , default = 'modeldir' , help = "model output directory" )
180
+ parser .add_argument ('--ext' , type = str , default = ".jpg" , help = "extension of training images" )
181
+ parser .add_argument ('--size' , type = int , default = 224 , help = "the size of training images" )
182
+ parser .add_argument ('--dim' , type = int , default = 256 , help = "dimensions of latent space" )
183
+ parser .add_argument ('--lr' , type = float , default = 0.0002 , help = "learning rate of Adam" )
184
+ parser .add_argument ('--b1' , type = float , default = 0.0 , help = "beta1 of Adam" )
185
+ parser .add_argument ('--b2' , type = float , default = 0.999 , help = "beta2 of Adam" )
186
+ parser .add_argument ('--data_path' , type = Path , help = "path which contains training data" )
199
187
parser .add_argument ('--encoder' , action = "store_true" , help = "enable image encoder" )
200
188
201
189
args = parser .parse_args ()
202
190
203
- dataset_path = Path ('./Dataset/danbooru-images/' )
204
- outdir = Path ('./outdir' )
191
+ outdir = args .outdir
205
192
outdir .mkdir (exist_ok = True )
206
193
207
- train (args .e , args .i , args .b , args .v , dataset_path , outdir , args .w , args .kl , args .encoder )
194
+ modeldir = args .modeldir
195
+ modeldir .mkdir (exist_ok = True )
196
+
197
+ train (args .e , args .i , args .b , args .v , outdir , modeldir , args .data_path ,
198
+ args .ext , args .size , args .dim , args .lr , args .b1 , args .b2 , args .encoder )
0 commit comments