
Commit 3135741

initial commit
1 parent ba640e1 commit 3135741

File tree

7 files changed, +852 -0 lines changed

.gitignore

+5
@@ -127,3 +127,8 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+
+# etc
+*.png
+*.jpg
+*.pth

create_dataset.py

+111
@@ -0,0 +1,111 @@
import os
import cv2
import glob
import random
import progressbar

import numpy as np

import matplotlib.pyplot as plt

rand_color = lambda : (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
rand_pos = lambda a, b: (random.randint(a, b-1), random.randint(a, b-1))

target_size = 256
imgs_per_back = 30

backs = glob.glob('./dataset/backs/*.png')
fonts = glob.glob('./dataset/font_mask/*.png')

os.makedirs('./dataset/train/I', exist_ok=True)
os.makedirs('./dataset/train/Itegt', exist_ok=True)
os.makedirs('./dataset/train/Mm', exist_ok=True)
os.makedirs('./dataset/train/Msgt', exist_ok=True)

os.makedirs('./dataset/val/I', exist_ok=True)
os.makedirs('./dataset/val/Itegt', exist_ok=True)
os.makedirs('./dataset/val/Mm', exist_ok=True)
os.makedirs('./dataset/val/Msgt', exist_ok=True)

t_idx = len(os.listdir('./dataset/train/I'))
v_idx = len(os.listdir('./dataset/val/I'))

bar = progressbar.ProgressBar(maxval=len(backs)*imgs_per_back)
bar.start()
for back in backs:
    back_img = cv2.imread(back)
    bh, bw, _ = back_img.shape
    if bh < target_size or bw < target_size:
        back_img = cv2.resize(back_img, (target_size, target_size), interpolation=cv2.INTER_CUBIC)
        bh, bw, _ = back_img.shape

    for bi in range(imgs_per_back):
        # crop a random target_size patch of the background as the text-free ground truth
        sx, sy = random.randint(0, bw-target_size), random.randint(0, bh-target_size)

        Itegt = back_img[sy:sy+target_size, sx:sx+target_size, :].copy()
        I = Itegt.copy()
        Mm = np.zeros_like(I)
        Msgt = np.zeros_like(I)

        hist = []
        for font in random.sample(fonts, random.randint(2, 4)):
            font_img = cv2.imread(font)
            mask_img = np.ones_like(font_img, dtype=np.uint8)*255

            height, width, _ = font_img.shape

            angle = random.randint(-30, +30)
            fs = random.randint(90, 120)
            ratio = fs / height - 0.2

            matrix = cv2.getRotationMatrix2D((width/2, height/2), angle, ratio)
            font_rot = cv2.warpAffine(font_img, matrix, (width, height), flags=cv2.INTER_CUBIC)
            mask_rot = cv2.warpAffine(mask_img, matrix, (width, height), flags=cv2.INTER_CUBIC)

            h, w, _ = font_rot.shape

            font_in_I = np.zeros_like(I)
            mask_in_I = np.zeros_like(I)

            # pick a placement that does not collide with previously placed glyphs,
            # gradually relaxing the constraint on every retry
            allow = 0
            while True:
                sx, sy = rand_pos(0, target_size-w)

                done = True
                for sx_, sy_ in hist:
                    if (sx_ - sx)**2 + (sy_ - sy)**2 < (fs * ratio)**2 - allow:
                        done = False
                        break
                allow += 5

                if done:
                    hist.append([sx, sy])
                    break

            font_in_I[sy:sy+h, sx:sx+w, :] = font_rot
            mask_in_I[sy:sy+h, sx:sx+w, :] = mask_rot

            font_in_I[font_in_I > 30] = 255
            mask_in_I[mask_in_I > 30] = 255

            # erase the glyph region from I, then paint the strokes with a random color
            I = cv2.bitwise_and(I, 255-font_in_I)
            I = cv2.bitwise_or(I, (font_in_I // 255 * rand_color()).astype(np.uint8))

            Mm = cv2.bitwise_or(Mm, mask_in_I)
            Msgt = cv2.bitwise_or(Msgt, font_in_I)

        # 80/20 train/val split per background image
        if bi < imgs_per_back*0.8:
            cv2.imwrite(f'dataset/train/I/{t_idx}.png', I)
            cv2.imwrite(f'dataset/train/Itegt/{t_idx}.png', Itegt)
            cv2.imwrite(f'dataset/train/Mm/{t_idx}.png', Mm)
            cv2.imwrite(f'dataset/train/Msgt/{t_idx}.png', Msgt)
            t_idx += 1
        else:
            cv2.imwrite(f'dataset/val/I/{v_idx}.png', I)
            cv2.imwrite(f'dataset/val/Itegt/{v_idx}.png', Itegt)
            cv2.imwrite(f'dataset/val/Mm/{v_idx}.png', Mm)
            cv2.imwrite(f'dataset/val/Msgt/{v_idx}.png', Msgt)
            v_idx += 1

        bar.update(t_idx + v_idx)
bar.finish()
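A note on inputs and outputs (inferred from the glob patterns and makedirs calls above): the script reads background images from ./dataset/backs/*.png and glyph masks from ./dataset/font_mask/*.png, and writes synthetic samples into dataset/{train,val}/{I,Itegt,Mm,Msgt}. A minimal, hypothetical pre-flight check before running it might look like this (not part of the commit):

# Hypothetical pre-flight check (not part of this commit): confirm the source folders
# that create_dataset.py globs actually contain images before kicking off generation.
import glob

backs = glob.glob('./dataset/backs/*.png')
fonts = glob.glob('./dataset/font_mask/*.png')
print(f"{len(backs)} background images, {len(fonts)} font masks")
assert backs and fonts, "populate ./dataset/backs and ./dataset/font_mask first"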

dataset.py

+75
@@ -0,0 +1,75 @@
import os, cv2
import numpy as np

import torch
from torch.utils.data import Dataset

def mat_to_tensor(mat):
    # HWC -> CHW
    mat = mat.transpose((2, 0, 1))
    tensor = torch.Tensor(mat)
    return tensor

def tensor_to_mat(tensor):
    # NCHW -> NHWC
    mat = tensor.detach().cpu().numpy()
    mat = mat.transpose((0, 2, 3, 1))
    return mat

def preprocess_image(img, target_shape: tuple):
    img = cv2.resize(img, target_shape, interpolation=cv2.INTER_CUBIC).astype(np.float32)
    img = img / 255.
    if len(img.shape) == 2:
        img = img.reshape(*img.shape, 1)

    return img

def postprocess_image(img):
    # img = img * 255
    img = (img - img.min()) / (img.max() - img.min()) * 255
    return img.astype(np.uint8)

class CustomDataset(Dataset):
    def __init__(self,
                 data_dir,
                 set_name="train",
                 target_size=(256, 256)):

        super().__init__()

        self.root_dir = os.path.join(data_dir, set_name)
        self.target_size = target_size

        self.I_dir = os.path.join(self.root_dir, "I")
        self.Itegt_dir = os.path.join(self.root_dir, "Itegt")
        self.Mm_dir = os.path.join(self.root_dir, "Mm")
        self.Msgt_dir = os.path.join(self.root_dir, "Msgt")

        self.datas = os.listdir(self.I_dir)

    def __len__(self):
        return len(self.datas)

    def __getitem__(self, idx):
        img_name = self.datas[idx]

        I = cv2.imread(os.path.join(self.I_dir, img_name))
        Itegt = cv2.imread(os.path.join(self.Itegt_dir, img_name))
        Mm = cv2.imread(os.path.join(self.Mm_dir, img_name), cv2.IMREAD_GRAYSCALE)
        Msgt = cv2.imread(os.path.join(self.Msgt_dir, img_name), cv2.IMREAD_GRAYSCALE)

        I = mat_to_tensor(preprocess_image(I, self.target_size))
        Itegt = mat_to_tensor(preprocess_image(Itegt, self.target_size))
        Mm = mat_to_tensor(preprocess_image(Mm, self.target_size))
        Msgt = mat_to_tensor(preprocess_image(Msgt, self.target_size))

        return I, Itegt, Mm, Msgt


if __name__ == "__main__":
    ds = CustomDataset('dataset', 'train')

    I, Itegt, Mm, Ms = ds[0]
    print(f"Dataset length : {len(ds)}")
    print(f"I shape : {I.shape}")
    print(f"Itegt shape : {Itegt.shape}")
    print(f"Mm shape : {Mm.shape}")
    print(f"Ms shape : {Ms.shape}")

losses.py

+16
@@ -0,0 +1,16 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

def TSDLoss(Mgt, Ms, Ms_, r=10):
    return torch.mean(torch.abs(Ms-Mgt) + r * torch.abs(Ms_-Mgt))

def TRGLoss(Mm, Ms, Ms_, Itegt, Ite, Ite_, rm=5, rs=5, rr=10):

    Mw = torch.ones_like(Mm) + rm * Mm + rs * Ms
    Mw_ = torch.ones_like(Mm) + rm * Mm + rs * Ms_

    Ltrg = torch.mean(torch.abs(torch.mul(Ite, Mw) - torch.mul(Itegt, Mw)) + \
                      rr * torch.abs(torch.mul(Ite_, Mw_) - torch.mul(Itegt, Mw_)))

    return Ltrg
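To make the signatures concrete, here is a hedged call sketch with dummy tensors; the interpretation of each argument (stroke masks Msgt/Ms/Ms_, region mask Mm, coarse and refined erased images Ite/Ite_) is an assumption based on the variable names used elsewhere in this commit:

# Hypothetical call sketch (not part of this commit): exercise both losses on random tensors.
import torch

from losses import TSDLoss, TRGLoss

B, H, W = 2, 256, 256
Msgt = torch.rand(B, 1, H, W)                               # ground-truth text-stroke mask
Ms, Ms_ = torch.rand(B, 1, H, W), torch.rand(B, 1, H, W)    # predicted / refined stroke masks
Mm = torch.rand(B, 1, H, W)                                 # text-region mask
Itegt = torch.rand(B, 3, H, W)                              # text-erased ground truth
Ite, Ite_ = torch.rand(B, 3, H, W), torch.rand(B, 3, H, W)  # coarse / refined erased outputs

print(TSDLoss(Msgt, Ms, Ms_))                  # L1 on both stroke predictions, refined term weighted by r
print(TRGLoss(Mm, Ms, Ms_, Itegt, Ite, Ite_))  # L1 on erased images, re-weighted by the mask map Mw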

modules.py

+139
@@ -0,0 +1,139 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

# dis_conv
# (https://github.com/JiahuiYu/generative_inpainting/blob/3a5324373ba52c68c79587ca183bc10b9e57b783/inpaint_ops.py#L84)
class _dis_conv(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size=5, stride=2, padding=2):
        super().__init__()

        self._conv = nn.Sequential(
            nn.utils.spectral_norm(
                nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
            ),
            nn.LeakyReLU(inplace=True)
        )

        # weight initialization
        def weight_init(m):
            if isinstance(m, nn.Conv2d):
                # nn.utils.spectral_norm(m.weight)
                nn.init.zeros_(m.bias)

        self.apply(weight_init)

    def forward(self, x):
        return self._conv(x)

# weights are fixed to one, bias to zero
class _one_conv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=5, stride=2, padding=2):
        super().__init__()

        self._conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        )

        # weight initialization
        def weight_init(m):
            if isinstance(m, nn.Conv2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
                m.weight.requires_grad = False
                m.bias.requires_grad = False

        self.apply(weight_init)

    def forward(self, x):
        return self._conv(x)

class _double_conv2d(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, mid_channels=None):
        super().__init__()

        if not mid_channels:
            mid_channels = out_channels

        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),

            nn.Conv2d(mid_channels, out_channels, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        # weight initialization
        def weight_init(m):
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight, gain=nn.init.calculate_gain('relu'))
                nn.init.zeros_(m.bias)

        self.apply(weight_init)

    def forward(self, x):
        return self.double_conv(x)


class _down_conv2d(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size):

        super().__init__()

        self.seq_model = nn.Sequential(
            nn.MaxPool2d(2),
            _double_conv2d(in_channels, out_channels)
        )

    def forward(self, x):
        return self.seq_model(x)


class _up_conv2d(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size):

        super().__init__()

        self.conv_t = nn.ConvTranspose2d(in_channels, in_channels//2, 2, 2)
        self.conv = _double_conv2d(in_channels, out_channels)

    # x1 : input, x2 : matching down_conv2d output
    def forward(self, x1, x2):
        x1 = self.conv_t(x1)

        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class _final_conv2d(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size):

        super().__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, 1, 1)

    def forward(self, x):
        return self.conv(x)
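These blocks follow the usual U-Net pattern: _down_conv2d halves the spatial size, _up_conv2d upsamples and concatenates the matching encoder feature, and _final_conv2d is a 1x1 projection. A hedged composition sketch, purely illustrative and not the network defined by this commit:

# Hypothetical composition sketch (not part of this commit): wire the blocks U-Net style.
import torch
import torch.nn as nn

from modules import _double_conv2d, _down_conv2d, _up_conv2d, _final_conv2d

class TinyUNet(nn.Module):
    def __init__(self, in_ch=3, out_ch=1):
        super().__init__()
        self.inc = _double_conv2d(in_ch, 32)
        self.down1 = _down_conv2d(32, 64, 3)
        self.down2 = _down_conv2d(64, 128, 3)
        self.up1 = _up_conv2d(128, 64, 3)   # ConvTranspose2d halves channels; the skip concat restores them
        self.up2 = _up_conv2d(64, 32, 3)
        self.outc = _final_conv2d(32, out_ch, 1)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        y = self.up1(x3, x2)
        y = self.up2(y, x1)
        return torch.sigmoid(self.outc(y))

print(TinyUNet()(torch.rand(1, 3, 256, 256)).shape)  # torch.Size([1, 1, 256, 256])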
