Skip to content

Commit 6e5b0ad

Browse files
authored Feb 14, 2020
update
1 parent 22b689a commit 6e5b0ad

File tree

6 files changed

+160
-171
lines changed

6 files changed

+160
-171
lines changed
 

‎GauGAN/dataset.py

+30-67
Original file line numberDiff line numberDiff line change
@@ -4,33 +4,49 @@
44
import copy
55
import chainer
66

7-
from xdog import line_process
7+
from xdog import xdog_process
88
from chainer import cuda
99

1010
xp = cuda.cupy
1111
cuda.get_device(0).use()
1212

1313

1414
class DataLoader:
15-
def __init__(self, path):
15+
def __init__(self,
16+
path,
17+
extension='.jpg',
18+
img_size=224,
19+
latent_dim=256):
20+
1621
self.path = path
17-
self.pathlist = list(self.path.glob('**/*.jpg'))
22+
self.pathlist = list(self.path.glob(f"**/*{extension}"))
1823
self.train, self.valid = self._split(self.pathlist)
1924
self.train_len = len(self.train)
2025
self.valid_len = len(self.valid)
2126

27+
self.size = img_size
28+
self.latent_dim = latent_dim
29+
30+
self.interpolations = (
31+
cv.INTER_LINEAR,
32+
cv.INTER_AREA,
33+
cv.INTER_NEAREST,
34+
cv.INTER_CUBIC,
35+
cv.INTER_LANCZOS4
36+
)
37+
2238
def __str__(self):
2339
return f"dataset path: {self.path} train data: {self.train_len}"
2440

2541
def _split(self, pathlist: list):
26-
split_point = int(len(self.pathlist) * 0.9)
42+
split_point = int(len(self.pathlist) * 0.95)
2743
x_train = self.pathlist[:split_point]
2844
x_test = self.pathlist[split_point:]
2945

3046
return x_train, x_test
3147

3248
@staticmethod
33-
def _random_crop(line, color, size=224):
49+
def _random_crop(line, color, size):
3450
height, width = line.shape[0], line.shape[1]
3551
rnd0 = np.random.randint(height - size - 1)
3652
rnd1 = np.random.randint(width - size - 1)
@@ -52,76 +68,27 @@ def _coordinate(image):
5268
def _variable(image_list):
5369
return chainer.as_variable(xp.array(image_list).astype(xp.float32))
5470

55-
@staticmethod
56-
def noise_generator(batchsize):
57-
noise = xp.random.normal(size=(batchsize, 256)).astype(xp.float32)
71+
def noise_generator(self, batchsize):
72+
noise = xp.random.normal(size=(batchsize, self.latent_dim)).astype(xp.float32)
5873

5974
return chainer.as_variable(noise)
6075

61-
@staticmethod
62-
def _making_mask(mask, color, size=224):
63-
choice = np.random.choice(['width', 'height', 'diag'])
64-
65-
if choice == 'width':
66-
rnd_height = np.random.randint(4, 8)
67-
rnd_width = np.random.randint(4, 64)
68-
69-
rnd1 = np.random.randint(size - rnd_height)
70-
rnd2 = np.random.randint(size - rnd_width)
71-
mask[rnd1:rnd1+rnd_height, rnd2:rnd2+rnd_width] = color[rnd1:rnd1+rnd_height, rnd2:rnd2+rnd_width]
72-
73-
elif choice == 'height':
74-
rnd_height = np.random.randint(4, 64)
75-
rnd_width = np.random.randint(4, 8)
76-
77-
rnd1 = np.random.randint(size - rnd_height)
78-
rnd2 = np.random.randint(size - rnd_width)
79-
mask[rnd1:rnd1+rnd_height, rnd2:rnd2+rnd_width] = color[rnd1:rnd1+rnd_height, rnd2:rnd2+rnd_width]
80-
81-
elif choice == 'diag':
82-
rnd_height = np.random.randint(4, 8)
83-
rnd_width = np.random.randint(4, 64)
84-
85-
rnd1 = np.random.randint(size - rnd_height - rnd_width - 1)
86-
rnd2 = np.random.randint(size - rnd_width)
87-
88-
for index in range(rnd_width):
89-
mask[rnd1 + index : rnd1 + rnd_height + index, rnd2 + index] = color[rnd1 + index: rnd1 + rnd_height + index, rnd2 + index]
90-
91-
return mask
92-
93-
def _prepare_pair(self, image_path, size=224, repeat=16):
94-
interpolations = (
95-
cv.INTER_LINEAR,
96-
cv.INTER_AREA,
97-
cv.INTER_NEAREST,
98-
cv.INTER_CUBIC,
99-
cv.INTER_LANCZOS4
100-
)
101-
interpolation = random.choice(interpolations)
76+
def _prepare_pair(self, image_path, size, repeat=16):
77+
interpolation = random.choice(self.interpolations)
10278

10379
color = cv.imread(str(image_path))
104-
line = line_process(str(image_path))
80+
line = xdog_process(str(image_path))
10581

10682
line, color = self._random_crop(line, color, size=size)
107-
mask = copy.copy(line)
108-
109-
for _ in range(repeat):
110-
mask = self._making_mask(mask, color, size=size)
111-
mask_ds = cv.resize(mask, (int(size/2), int(size/2)), interpolation=interpolation)
11283

11384
color = self._coordinate(color)
11485
line = self._coordinate(line)
115-
mask = self._coordinate(mask)
116-
mask_ds = self._coordinate(mask_ds)
11786

118-
return (color, line, mask, mask_ds)
87+
return (color, line)
11988

120-
def __call__(self, batchsize, mode='train', size=224):
89+
def __call__(self, batchsize, mode='train'):
12190
color_box = []
12291
line_box = []
123-
mask_box = []
124-
mask_ds_box = []
12592

12693
for _ in range(batchsize):
12794
if mode == 'train':
@@ -133,16 +100,12 @@ def __call__(self, batchsize, mode='train', size=224):
133100
else:
134101
raise AttributeError
135102

136-
color, line, mask, mask_ds = self._prepare_pair(image_path, size=size)
103+
color, line = self._prepare_pair(image_path, size=self.size)
137104

138105
color_box.append(color)
139106
line_box.append(line)
140-
mask_box.append(mask)
141-
mask_ds_box.append(mask_ds)
142107

143108
color = self._variable(color_box)
144109
line = self._variable(line_box)
145-
mask = self._variable(mask_box)
146-
mask_ds = self._variable(mask_ds_box)
147110

148-
return (color, line, mask, mask_ds)
111+
return (color, line)

‎GauGAN/evaluation.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,21 @@ def __call__(self, y, t, x, outdir, epoch, validsize=3):
1818
pylab.rcParams['figure.figsize'] = (16.0, 16.0)
1919
pylab.clf()
2020

21+
wid = int(validsize/2)
22+
2123
for index in range(validsize):
2224
tmp = self._coordinate(x[index])
23-
pylab.subplot(validsize, validsize, validsize * index + 1)
25+
pylab.subplot(wid, wid, 3 * index + 1)
2426
pylab.imshow(tmp)
2527
pylab.axis('off')
2628
pylab.savefig(f"{outdir}/visualize_{epoch}.png")
2729
tmp = self._coordinate(t[index])
28-
pylab.subplot(validsize, validsize, validsize * index + 2)
30+
pylab.subplot(wid, wid, 3 * index + 2)
2931
pylab.imshow(tmp)
3032
pylab.axis('off')
3133
pylab.savefig(f"{outdir}/visualize_{epoch}.png")
3234
tmp = self._coordinate(y[index])
33-
pylab.subplot(validsize, validsize, validsize * index + 3)
35+
pylab.subplot(wid, wid, 3 * index + 3)
3436
pylab.imshow(tmp)
3537
pylab.axis('off')
3638
pylab.savefig(f"{outdir}/visualize_{epoch}.png")

‎GauGAN/model.py

+27-3
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,10 @@ def __call__(self, x, c):
121121
return F.tanh(h)
122122

123123

124-
class Discriminator(Chain):
124+
class DiscriminatorBlock(Chain):
125125
def __init__(self, base=64):
126126
w = initializers.GlorotUniform()
127-
super(Discriminator, self).__init__()
127+
super(DiscriminatorBlock, self).__init__()
128128
with self.init_scope():
129129
self.c0 = SNConvolution2D(6, base, 4, 2, 1, initialW=w)
130130
self.c1 = SNConvolution2D(base, base*2, 4, 2, 1, initialW=w)
@@ -148,6 +148,30 @@ def __call__(self, x):
148148
return h, [h1, h2, h3, h4]
149149

150150

151+
class Discriminator(Chain):
152+
def __init__(self, base=64):
153+
super(Discriminator, self).__init__()
154+
discriminators = chainer.ChainList()
155+
for _ in range(3):
156+
discriminators.add_link(DiscriminatorBlock())
157+
with self.init_scope():
158+
self.dis = discriminators
159+
160+
def __call__(self, x):
161+
adv_list = []
162+
feat_list = []
163+
164+
for index in range(3):
165+
h, h_list = self.dis[index](x)
166+
167+
adv_list.append(h)
168+
feat_list.append(h_list)
169+
170+
x = F.average_pooling_2d(x, 3, 2, 1)
171+
172+
return adv_list, feat_list
173+
174+
151175
class Prior(chainer.Link):
152176

153177
def __init__(self):
@@ -157,4 +181,4 @@ def __init__(self):
157181
self.scale = xp.ones(256, xp.float32)
158182

159183
def __call__(self):
160-
return D.Normal(self.loc, scale=self.scale)
184+
return D.Normal(self.loc, scale=self.scale)

‎GauGAN/sn.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from chainer.links.connection.linear import Linear
99
import chainer.functions as F
1010

11+
1112
def _l2normalize(v, eps=1e-12):
1213
norm = cuda.reduce('T x', 'T out',
1314
'x * x', 'a + b', 'out = sqrt(a)', 0,
@@ -19,6 +20,7 @@ def _l2normalize(v, eps=1e-12):
1920

2021
return div(v, norm(v), eps)
2122

23+
2224
def max_singular_value(W, u=None, Ip=1):
2325
"""
2426
Apply power iteration for the weight parameter
@@ -37,6 +39,7 @@ def max_singular_value(W, u=None, Ip=1):
3739

3840
return sigma, _u, _v
3941

42+
4043
class SNConvolution2D(Convolution2D):
4144
"""Two-dimensional convolutional layer with spectral normalization.
4245
This link wraps the :func:`~chainer.functions.convolution_2d` function and
@@ -129,6 +132,7 @@ def __call__(self, x):
129132
return convolution_2d.convolution_2d(
130133
x, self.W_bar, self.b, self.stride, self.pad)
131134

135+
132136
class SNLinear(Linear):
133137
"""Linear layer with Spectral Normalization.
134138
Args:
@@ -204,4 +208,4 @@ def __call__(self, x):
204208
"""
205209
if self.W.data is None:
206210
self._initialize_params(x.size // x.shape[0])
207-
return linear.linear(x, self.W_bar, self.b)
211+
return linear.linear(x, self.W_bar, self.b)

‎GauGAN/train.py

+82-91
Original file line numberDiff line numberDiff line change
@@ -14,47 +14,67 @@
1414
cuda.get_device(0).use()
1515

1616

17-
def downsampling(array):
18-
d2 = F.average_pooling_2d(array, 3, 2, 1)
19-
d4 = F.average_pooling_2d(d2, 3, 2, 1)
20-
21-
return d2, d4
22-
23-
2417
class GauGANLossFunction:
2518
def __init__(self):
2619
pass
2720

2821
@staticmethod
2922
def content_loss(y, t):
30-
return F.mean_absolute_error(y, t)
23+
return 10.0 * F.mean_absolute_error(y, t)
3124

3225
@staticmethod
33-
def dis_hinge_loss(discriminator, y, t):
34-
y_dis, _ = discriminator(y)
35-
t_dis, _ = discriminator(t)
26+
def dis_loss(discriminator, y, t):
27+
y_adv_list, _ = discriminator(y)
28+
t_adv_list, _ = discriminator(t)
29+
30+
sum_loss = 0
3631

37-
return F.mean(F.relu(1. - t_dis)) + F.mean(F.relu(1. + y_dis))
32+
for y_adv, t_adv in zip(y_adv_list, t_adv_list):
33+
loss = F.mean(F.relu(1. - t_adv)) + F.mean(F.relu(1. + y_adv))
34+
sum_loss += loss
35+
36+
return sum_loss
3837

3938
@staticmethod
40-
def gen_hinge_loss(discriminator, y, t):
41-
y_dis, y_feats = discriminator(y)
39+
def gen_loss(discriminator, y, t):
40+
y_dis_list, y_feats = discriminator(y)
4241
_, t_feats = discriminator(t)
4342

4443
sum_loss = 0
45-
for yf, tf in zip(y_feats, t_feats):
46-
_, ch, height, width = yf.shape
47-
sum_loss += 10.0 * F.mean_absolute_error(yf, tf) / (ch * height * width)
48-
49-
return -F.mean(y_dis) + sum_loss
5044

45+
# adversarial loss
46+
for y_dis in y_dis_list:
47+
loss = -F.mean(y_dis)
48+
sum_loss += loss
49+
50+
# feature matching loss
51+
for yf_list, tf_list in zip(y_feats, t_feats):
52+
for yf, tf in zip(yf_list, tf_list):
53+
_, ch, height, width = yf.shape
54+
sum_loss += 10.0 * F.mean_absolute_error(yf, tf) / (ch * height * width)
55+
56+
return sum_loss
57+
58+
59+
def train(epochs,
60+
iterations,
61+
batchsize,
62+
validsize,
63+
outdir,
64+
modeldir,
65+
data_path,
66+
extension,
67+
img_size,
68+
latent_dim,
69+
learning_rate,
70+
beta1,
71+
beta2,
72+
enable):
5173

52-
def train(epochs, iterations, batchsize, validsize, path, outdir,
53-
con_weight, kl_weight, enable):
5474
# Dataset Definition
55-
dataloader = DataLoader(path)
75+
dataloader = DataLoader(data_path, extension, img_size, latent_dim)
5676
print(dataloader)
57-
color_valid, line_valid, _, _ = dataloader(validsize, mode="valid")
77+
color_valid, line_valid = dataloader(validsize, mode="valid")
5878
noise_valid = dataloader.noise_generator(validsize)
5979

6080
# Model Definition
@@ -65,19 +85,11 @@ def train(epochs, iterations, batchsize, validsize, path, outdir,
6585

6686
generator = Generator()
6787
generator.to_gpu()
68-
gen_opt = set_optimizer(generator)
88+
gen_opt = set_optimizer(generator, learning_rate, beta1, beta2)
6989

7090
discriminator = Discriminator()
7191
discriminator.to_gpu()
72-
dis_opt = set_optimizer(discriminator)
73-
74-
discriminator_d2 = Discriminator()
75-
discriminator_d2.to_gpu()
76-
dis2_opt = set_optimizer(discriminator_d2)
77-
78-
discriminator_d4 = Discriminator()
79-
discriminator_d4.to_gpu()
80-
dis4_opt = set_optimizer(discriminator_d4)
92+
dis_opt = set_optimizer(discriminator, learning_rate, beta1, beta2)
8193

8294
# Loss Funtion Definition
8395
lossfunc = GauGANLossFunction()
@@ -86,95 +98,64 @@ def train(epochs, iterations, batchsize, validsize, path, outdir,
8698
evaluator = Evaluaton()
8799

88100
for epoch in range(epochs):
89-
sum_loss = 0
101+
sum_dis_loss = 0
102+
sum_gen_loss = 0
90103
for batch in range(0, iterations, batchsize):
91-
color, line, _, _ = dataloader(batchsize)
92-
93-
color_d2, color_d4 = downsampling(color)
94-
line_d2, line_d4 = downsampling(line)
104+
color, line = dataloader(batchsize)
95105
z = dataloader.noise_generator(batchsize)
96106

107+
# Discriminator update
97108
if enable:
98109
mu, sigma = encoder(color)
99110
z = F.gaussian(mu, sigma)
100111
y = generator(z, line)
101-
y_d2, y_d4 = downsampling(y)
102112

103113
y.unchain_backward()
104-
y_d2.unchain_backward()
105-
y_d4.unchain_backward()
106114

107-
loss = lossfunc.dis_hinge_loss(
115+
dis_loss = lossfunc.dis_loss(
108116
discriminator,
109117
F.concat([y, line]),
110118
F.concat([color, line])
111119
)
112-
loss += lossfunc.dis_hinge_loss(
113-
discriminator_d2,
114-
F.concat([y_d2, line_d2]),
115-
F.concat([color_d2, line_d2])
116-
)
117-
loss += lossfunc.dis_hinge_loss(
118-
discriminator_d4,
119-
F.concat([y_d4, line_d4]),
120-
F.concat([color_d4, line_d4])
121-
)
122120

123121
discriminator.cleargrads()
124-
discriminator_d2.cleargrads()
125-
discriminator_d4.cleargrads()
126-
loss.backward()
122+
dis_loss.backward()
127123
dis_opt.update()
128-
dis2_opt.update()
129-
dis4_opt.update()
130-
loss.unchain_backward()
124+
dis_loss.unchain_backward()
125+
126+
sum_dis_loss += dis_loss.data
131127

128+
# Generator update
132129
z = dataloader.noise_generator(batchsize)
133130

134131
if enable:
135132
mu, sigma = encoder(color)
136133
z = F.gaussian(mu, sigma)
137134
y = generator(z, line)
138-
y_d2, y_d4 = downsampling(y)
139135

140-
loss = lossfunc.gen_hinge_loss(
136+
gen_loss = lossfunc.gen_loss(
141137
discriminator,
142138
F.concat([y, line]),
143139
F.concat([color, line])
144140
)
145-
loss += lossfunc.gen_hinge_loss(
146-
discriminator_d2,
147-
F.concat([y_d2, line_d2]),
148-
F.concat([color_d2, line_d2])
149-
)
150-
loss += lossfunc.gen_hinge_loss(
151-
discriminator_d4,
152-
F.concat([y_d4, line_d4]),
153-
F.concat([color_d4, line_d4])
154-
)
155-
loss += con_weight * lossfunc.content_loss(y, color)
156-
loss += con_weight * lossfunc.content_loss(y_d2, color_d2)
157-
loss += con_weight * lossfunc.content_loss(y_d4, color_d4)
141+
gen_loss += lossfunc.content_loss(y, color)
158142

159143
if enable:
160-
loss += kl_weight * F.gaussian_kl_divergence(mu, sigma) / batchsize
144+
gen_loss += 0.05 * F.gaussian_kl_divergence(mu, sigma) / batchsize
161145

162146
generator.cleargrads()
163147
if enable:
164148
encoder.cleargrads()
165-
loss.backward()
149+
gen_loss.backward()
166150
gen_opt.update()
167151
if enable:
168152
enc_opt.update()
169-
loss.unchain_backward()
153+
gen_loss.unchain_backward()
170154

171-
sum_loss += loss.data
155+
sum_gen_loss += gen_loss.data
172156

173157
if batch == 0:
174-
serializers.save_npz(f"{outdir}/generator.model", generator)
175-
serializers.save_npz(f"{outdir}/discriminator_0.model", discriminator)
176-
serializers.save_npz(f"{outdir}/discriminator_2.model", discriminator_d2)
177-
serializers.save_npz(f"{outdir}/discriminator_4.model", discriminator_d4)
158+
serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)
178159

179160
with chainer.using_config("train", False):
180161
y = generator(noise_valid, line_valid)
@@ -183,25 +164,35 @@ def train(epochs, iterations, batchsize, validsize, path, outdir,
183164
cr = color_valid.data.get()
184165

185166
evaluator(y, cr, sr, outdir, epoch, validsize=validsize)
186-
187-
print(f"epoch: {epoch}")
188-
print(f"loss: {sum_loss / iterations}")
167+
168+
print(f"epoch: {epoch}")
169+
print(f"dis loss: {sum_dis_loss / iterations} gen loss: {sum_gen_loss / iterations}")
189170

190171

191172
if __name__ == "__main__":
192173
parser = argparse.ArgumentParser(description="GauGAN")
193174
parser.add_argument('--e', type=int, default=1000, help="the number of epochs")
194-
parser.add_argument('--i', type=int, default=10000, help="the number of iterations")
175+
parser.add_argument('--i', type=int, default=2000, help="the number of iterations")
195176
parser.add_argument('--b', type=int, default=16, help="batch size")
196-
parser.add_argument('--v', type=int, default=3, help="valid size")
197-
parser.add_argument('--w', type=float, default=10.0, help="the weight of content loss")
198-
parser.add_argument('--kl', type=float, default=0.05, help="the weight of kl divergence loss")
177+
parser.add_argument('--v', type=int, default=12, help="valid size")
178+
parser.add_argument('--outdir', type=Path, default='outdir', help="output directory")
179+
parser.add_argument('--modeldir', type=Path, default='modeldir', help="model output directory")
180+
parser.add_argument('--ext', type=str, default=".jpg", help="extension of training images")
181+
parser.add_argument('--size', type=int, default=224, help="the size of training images")
182+
parser.add_argument('--dim', type=int, default=256, help="dimensions of latent space")
183+
parser.add_argument('--lr', type=float, default=0.0002, help="learning rate of Adam")
184+
parser.add_argument('--b1', type=float, default=0.0, help="beta1 of Adam")
185+
parser.add_argument('--b2', type=float, default=0.999, help="beta2 of Adam")
186+
parser.add_argument('--data_path', type=Path, help="path which contains training data")
199187
parser.add_argument('--encoder', action="store_true", help="enable image encoder")
200188

201189
args = parser.parse_args()
202190

203-
dataset_path = Path('./Dataset/danbooru-images/')
204-
outdir = Path('./outdir')
191+
outdir = args.outdir
205192
outdir.mkdir(exist_ok=True)
206193

207-
train(args.e, args.i, args.b, args.v, dataset_path, outdir, args.w, args.kl, args.encoder)
194+
modeldir = args.modeldir
195+
modeldir.mkdir(exist_ok=True)
196+
197+
train(args.e, args.i, args.b, args.v, outdir, modeldir, args.data_path,
198+
args.ext, args.size, args.dim, args.lr, args.b1, args.b2, args.encoder)

‎GauGAN/xdog.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,43 @@
11
import numpy as np
22
import cv2 as cv
33

4+
45
def sharpImage(img, sigma, k_sigma, p):
56
sigma_large = sigma * k_sigma
6-
G_small = cv.GaussianBlur(img,(0, 0), sigma)
7-
G_large = cv.GaussianBlur(img,(0, 0), sigma_large)
7+
G_small = cv.GaussianBlur(img, (0, 0), sigma)
8+
G_large = cv.GaussianBlur(img, (0, 0), sigma_large)
89
S = (1+p) * G_small - p * G_large
910

1011
return S
1112

13+
1214
def softThreshold(SI, epsilon, phi):
1315
T = np.zeros(SI.shape)
1416
SI_bright = SI >= epsilon
1517
SI_dark = SI < epsilon
1618
T[SI_bright] = 1.0
17-
T[SI_dark] = 1.0 + np.tanh( phi * (SI[SI_dark] - epsilon))
19+
T[SI_dark] = 1.0 + np.tanh(phi * (SI[SI_dark] - epsilon))
1820

1921
return T
2022

23+
2124
def xdog(img, sigma, k_sigma, p, epsilon, phi):
2225
S = sharpImage(img, sigma, k_sigma, p)
2326
SI = np.multiply(img, S)
2427
T = softThreshold(SI, epsilon, phi)
2528

2629
return T
2730

28-
def line_process(filename):
31+
32+
def xdog_process(filename):
2933
img = cv.imread(filename)
3034
img = cv.cvtColor(img, cv.COLOR_RGB2GRAY)
3135
img = img / 255.0
3236
sigma = np.random.choice([0.3, 0.4, 0.5])
33-
img = xdog(img, sigma, 4.5, 19,0.01, 10^9)
37+
img = xdog(img, sigma, 4.5, 19, 0.01, 10^9)
38+
img[img < 0.9] = 0.0
3439
img = img * 255
3540
img = img.reshape(img.shape[0], img.shape[1], 1)
36-
img = np.tile(img, (1,1,3))
41+
img = np.tile(img, (1, 1, 3))
3742

3843
return img

0 commit comments

Comments
 (0)
Please sign in to comment.