Skip to content

Commit 22b689a

Browse files
update
1 parent f8d0b69 commit 22b689a

File tree

6 files changed

+543
-0
lines changed

6 files changed

+543
-0
lines changed

Diff for: pix2pix/dataset.py

+203
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
import numpy as np
2+
import random
3+
import cv2 as cv
4+
import copy
5+
import chainer
6+
import chainer.functions as F
7+
8+
from xdog import xdog_process
9+
from chainer import cuda
10+
from pathlib import Path
11+
from PIL import Image
12+
13+
# CuPy module handle taken from chainer.cuda; DatasetLoader._variable uses it
# to turn the collected image lists into float32 arrays.
xp = cuda.cupy
14+
# Import-time side effect: make CUDA device 0 the current device.
cuda.get_device(0).use()
15+
16+
17+
class DatasetLoader:
    """Serve (line art, color) image batches for pix2pix colorization.

    Color images are discovered recursively under ``data_path``.  For every
    sample the matching line art comes from one of three randomly chosen
    sources: XDoG extraction from the color image itself, or a file with the
    same name under ``sketch_path`` or ``digi_path``.

    The first 95% of the discovered paths form the training split and the
    remaining 5% the validation split.
    """

    def __init__(self,
                 data_path: Path,
                 sketch_path: Path,
                 digi_path: Path,
                 extension='.jpg',
                 train_size=128,
                 valid_size=512):
        """Index the dataset and split it into training and validation lists.

        Args:
            data_path: Root directory searched recursively for color images.
            sketch_path: Directory holding pencil-style line art, looked up by
                the color image's file name.
            digi_path: Directory holding digital-style line art, looked up by
                the color image's file name.
            extension: File suffix used to select color images.
            train_size: Side length of the square crops returned by ``train``.
            valid_size: Size passed to ``_prepare_pair`` by ``valid``.
        """
        self.data_path = data_path
        self.sketch_path = sketch_path
        # Misspelled name kept as an alias so existing code reading
        # ``loader.skecth_path`` keeps working.
        self.skecth_path = sketch_path
        self.digi_path = digi_path
        self.extension = extension
        self.train_size = train_size
        self.valid_size = valid_size

        # Not referenced by any method of this class; kept because it is part
        # of the instance's public attributes.
        self.interpolations = (
            cv.INTER_LINEAR,
            cv.INTER_AREA,
            cv.INTER_NEAREST,
            cv.INTER_CUBIC,
            cv.INTER_LANCZOS4
        )

        self.pathlist = list(self.data_path.glob(f"**/*{extension}"))
        self.train_list, self.val_list = self._train_val_split(self.pathlist)
        self.train_len = len(self.train_list)

    def __str__(self):
        return f"dataset path: {self.data_path} train data: {self.train_len}"

    # Initialization method
    def _train_val_split(self, pathlist: list, train_ratio=0.95):
        """Split ``pathlist`` into a leading train part and a trailing rest.

        Args:
            pathlist: Sequence of paths to split (order is preserved).
            train_ratio: Fraction of entries assigned to the first part.

        Returns:
            Tuple ``(train, validation)`` of list slices.
        """
        # Use the argument rather than ``self.pathlist`` so the helper really
        # splits what it is given.
        split_point = int(len(pathlist) * train_ratio)

        return pathlist[:split_point], pathlist[split_point:]

    # Line art preparation method
    @staticmethod
    def _add_intensity(img, intensity=1.7):
        """Apply the power curve ``255 ** (1 - intensity) * img ** intensity``.

        The constant keeps 0 and 255 fixed; with ``intensity > 1`` every value
        in between is pushed towards 0.
        """
        const = 255.0 ** (1.0 - intensity)

        return const * (img ** intensity)

    @staticmethod
    def _morphology(img):
        """Randomly dilate or erode ``img`` once with a 5x5 kernel."""
        # OpenCV expects the structuring element itself, not its size: a bare
        # ``(5, 5)`` tuple is read as a two-element kernel, so build the
        # intended 5x5 block explicitly.
        kernel = np.ones((5, 5), np.uint8)
        method = np.random.choice(["dilate", "erode"])

        if method == "dilate":
            return cv.dilate(img, kernel, iterations=1)
        return cv.erode(img, kernel, iterations=1)

    @staticmethod
    def _color_variant(img, max_value=30):
        """Set every value below 200 to one random value in ``[0, max_value]``.

        The array is modified in place and also returned.
        """
        color = np.random.randint(max_value + 1)
        img[img < 200] = color

        return img

    def _detail_preprocess(self, img):
        """Apply each line-art augmentation independently with probability 1/2."""
        # Draw all three switches up front so the random stream is consumed in
        # a fixed order regardless of which augmentations end up running.
        intensity = np.random.randint(2)
        morphology = np.random.randint(2)
        color_variance = np.random.randint(2)

        if intensity:
            img = self._add_intensity(img)
        if morphology:
            img = self._morphology(img)
        if color_variance:
            img = self._color_variant(img)

        return img

    @staticmethod
    def _imread(path):
        """Read an image with OpenCV, failing loudly if it cannot be loaded.

        Raises:
            FileNotFoundError: If OpenCV could not read ``path``.
        """
        img = cv.imread(str(path))
        if img is None:
            # cv.imread reports a missing or undecodable file by returning
            # None instead of raising; surface that here with the path.
            raise FileNotFoundError(f"could not read image: {path}")

        return img

    def _xdog_preprocess(self, path):
        """Extract line art from the color image at ``path`` via XDoG.

        The result of ``xdog_process`` is scaled by 255 and replicated to
        three channels.
        """
        # NOTE(review): assumes xdog_process returns a 2-D array in [0, 1] --
        # confirm against the xdog module.
        img = xdog_process(str(path))
        img = (img * 255.0).reshape(img.shape[0], img.shape[1], 1)

        return np.tile(img, (1, 1, 3))

    def _pencil_preprocess(self, path):
        """Load the same-named line art from ``sketch_path``."""
        return self._imread(self.sketch_path / Path(path.name))

    def _digital_preprocess(self, path):
        """Load the same-named line art from ``digi_path``."""
        return self._imread(self.digi_path / Path(path.name))

    def _preprocess(self, path):
        """Produce augmented line art for ``path`` from a random source."""
        method = np.random.choice(["xdog", "pencil", "digital"])

        if method == "xdog":
            img = self._xdog_preprocess(path)
        elif method == "pencil":
            img = self._pencil_preprocess(path)
        else:
            img = self._digital_preprocess(path)

        return self._detail_preprocess(img)

    # Preprocess method
    @staticmethod
    def _random_crop(line, color, size, min_scale=288, max_scale=768):
        """Resize both images to one random square, then take a shared crop.

        Args:
            line: Line-art image.
            color: Color image.
            size: Side length of the square crop.
            min_scale: Smallest side length of the intermediate resize.
            max_scale: Exclusive upper bound of the intermediate resize.

        Returns:
            Tuple ``(line_crop, color_crop)`` cut at identical offsets.
        """
        # Never resize below the crop size, otherwise no valid offset exists.
        low = max(min_scale, size)
        high = max(max_scale, low + 1)
        scale = np.random.randint(low, high)
        line = cv.resize(line, (scale, scale))
        color = cv.resize(color, (scale, scale))

        height, width = line.shape[0], line.shape[1]
        # ``+ 1`` makes the last valid offset (crop flush with the bottom or
        # right edge) reachable; randint's upper bound is exclusive.
        top = np.random.randint(height - size + 1)
        left = np.random.randint(width - size + 1)

        line = line[top: top + size, left: left + size]
        color = color[top: top + size, left: left + size]

        return line, color

    @staticmethod
    def _coordinate(image):
        """Reverse the channel axis, move it first and scale to [-1, 1].

        ``image`` is indexed as (height, width, channel); the result is
        (channel, height, width) with values mapped by ``(x - 127.5) / 127.5``.
        """
        image = image[:, :, ::-1]
        image = image.transpose(2, 0, 1)

        return (image - 127.5) / 127.5

    @staticmethod
    def _variable(image_list):
        """Stack ``image_list`` into a float32 ``xp`` array wrapped as a Variable."""
        return chainer.as_variable(xp.array(image_list).astype(xp.float32))

    def _prepare_pair(self, image_path, size, mode="train"):
        """Load one (line, color) pair ready for batching.

        Args:
            image_path: Path of the color image.
            size: Crop side length; only used when ``mode == "train"``.
            mode: ``"train"`` applies the random resize-and-crop; any other
                value leaves both images at their loaded size.

        Returns:
            Tuple ``(line, color)`` of arrays produced by ``_coordinate``.

        Raises:
            FileNotFoundError: If the color image or the selected line-art
                file cannot be read.
        """
        color = self._imread(image_path)
        line = self._preprocess(image_path)

        if mode == "train":
            line, color = self._random_crop(line, color, size=size)

        color = self._coordinate(color)
        line = self._coordinate(line)

        return (line, color)

    def train(self, batchsize):
        """Return a batch of ``batchsize`` randomly drawn training pairs.

        Samples are drawn with replacement from the training split.

        Returns:
            Tuple ``(line, color)`` of Variables.
        """
        color_box = []
        line_box = []

        for _ in range(batchsize):
            image_path = self.train_list[np.random.randint(self.train_len)]
            line, color = self._prepare_pair(
                image_path, size=self.train_size, mode="train")

            color_box.append(color)
            line_box.append(line)

        return (self._variable(line_box), self._variable(color_box))

    def valid(self, validsize):
        """Return the first ``validsize`` validation pairs, uncropped.

        Returns:
            Tuple ``(line, color)`` of Variables.

        Raises:
            IndexError: If ``validsize`` exceeds the validation split.
        """
        if validsize > len(self.val_list):
            raise IndexError(
                f"validsize={validsize} exceeds the {len(self.val_list)} "
                "available validation images"
            )

        color_box = []
        line_box = []

        for image_path in self.val_list[:max(validsize, 0)]:
            line, color = self._prepare_pair(
                image_path, size=self.valid_size, mode="valid")

            color_box.append(color)
            line_box.append(line)

        return (self._variable(line_box), self._variable(color_box))

Diff for: pix2pix/model.py

+98
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import numpy as np
2+
import chainer
3+
import chainer.functions as F
4+
import chainer.links as L
5+
6+
from chainer import cuda, initializers, Chain
7+
8+
# CuPy module handle taken from chainer.cuda; not referenced by the model
# classes defined below.
xp = cuda.cupy
9+
10+
11+
class CBR(Chain):
    """Convolution -> batch normalization -> activation building block.

    When ``up`` is true the input is first enlarged by 2x2 unpooling, which
    makes the block usable as a decoder stage.
    """

    def __init__(self,
                 in_ch,
                 out_ch,
                 kernel,
                 stride,
                 padding,
                 up=False,
                 activ=F.relu):
        """Create the convolution and batch-normalization links.

        Args:
            in_ch: Number of input channels of the convolution.
            out_ch: Number of output channels of the convolution.
            kernel: Convolution kernel size.
            stride: Convolution stride.
            padding: Convolution padding.
            up: Whether to unpool the input by a factor of 2 before convolving.
            activ: Activation applied after batch normalization.
        """
        super().__init__()
        self.up = up
        self.activ = activ

        with self.init_scope():
            self.c0 = L.Convolution2D(
                in_ch,
                out_ch,
                ksize=kernel,
                stride=stride,
                pad=padding,
                initialW=initializers.GlorotUniform(),
            )
            self.bn0 = L.BatchNormalization(out_ch)

    def __call__(self, x):
        """Run the block on ``x`` and return the activated feature map."""
        feat = F.unpooling_2d(x, 2, 2, 0, cover_all=False) if self.up else x
        feat = self.c0(feat)
        feat = self.bn0(feat)

        return self.activ(feat)
37+
38+
39+
class UNet(Chain):
    """U-Net style generator built from CBR blocks.

    One stride-1 stem (``e0``) is followed by five stride-2 encoder stages
    (``e1``-``e5``) and five unpooling decoder stages (``d0``-``d4``).  The
    outputs of ``e4``..``e1`` are concatenated onto the decoder path as skip
    connections, and a final 1x1 convolution plus ``tanh`` yields a
    3-channel output.
    """

    def __init__(self, base=64):
        """Build the network; ``base`` is the channel count of the stem."""
        super().__init__()
        init_w = initializers.GlorotUniform()
        lrelu = F.leaky_relu

        with self.init_scope():
            # Encoder: channel widths grow while stride-2 convolutions shrink
            # the feature maps.
            self.e0 = CBR(3, base, 3, 1, 1)
            self.e1 = CBR(base, base * 2, 4, 2, 1, activ=lrelu)
            self.e2 = CBR(base * 2, base * 4, 4, 2, 1, activ=lrelu)
            self.e3 = CBR(base * 4, base * 8, 4, 2, 1, activ=lrelu)
            self.e4 = CBR(base * 8, base * 8, 4, 2, 1, activ=lrelu)
            self.e5 = CBR(base * 8, base * 16, 4, 2, 1, activ=lrelu)

            # Decoder: from d1 on, the input width is the previous decoder
            # output plus the concatenated encoder skip.
            self.d0 = CBR(base * 16, base * 8, 3, 1, 1, up=True, activ=lrelu)
            self.d1 = CBR(base * 16, base * 8, 3, 1, 1, up=True, activ=lrelu)
            self.d2 = CBR(base * 16, base * 4, 3, 1, 1, up=True, activ=lrelu)
            self.d3 = CBR(base * 8, base * 2, 3, 1, 1, up=True, activ=lrelu)
            self.d4 = CBR(base * 4, base, 3, 1, 1, up=True, activ=lrelu)
            self.out = L.Convolution2D(base, 3, 1, 1, 0, initialW=init_w)

    def __call__(self, x):
        """Map the input batch ``x`` to a ``tanh``-bounded 3-channel output."""
        feat = self.e0(x)

        # Walk down the encoder, remembering every stage output for the skips.
        skips = []
        for encoder in (self.e1, self.e2, self.e3, self.e4, self.e5):
            feat = encoder(feat)
            skips.append(feat)

        # The deepest feature map feeds the first decoder stage directly; the
        # remaining ones are consumed deepest-first as skip connections.
        feat = self.d0(skips.pop())
        for decoder in (self.d1, self.d2, self.d3, self.d4):
            feat = decoder(F.concat([feat, skips.pop()]))

        return F.tanh(self.out(feat))
75+
76+
77+
class Discriminator(Chain):
    """Convolutional discriminator: five stride-2 CBR stages and a 1x1 head.

    The head outputs a single-channel map without a final activation.
    """

    def __init__(self, base=64):
        """Build the network; ``base`` is the width of the first stage."""
        super().__init__()
        init_w = initializers.GlorotUniform()

        with self.init_scope():
            self.cbr0 = CBR(3, base, 4, 2, 1)
            self.cbr1 = CBR(base, base * 2, 4, 2, 1)
            self.cbr2 = CBR(base * 2, base * 4, 4, 2, 1)
            self.cbr3 = CBR(base * 4, base * 8, 4, 2, 1)
            self.cbr4 = CBR(base * 8, base * 16, 4, 2, 1)
            self.cout = L.Convolution2D(base * 16, 1, 1, 1, 0, initialW=init_w)

    def __call__(self, x):
        """Return the raw single-channel score map for the input batch ``x``."""
        feat = x
        for stage in (self.cbr0, self.cbr1, self.cbr2, self.cbr3, self.cbr4):
            feat = stage(feat)

        return self.cout(feat)

0 commit comments

Comments
 (0)