Skip to content

Commit cb61f1a

Browse files
committed
Clean up code; add loading of configuration from a .yaml file
1 parent 3d61aaf commit cb61f1a

14 files changed

+659
-706
lines changed

Dockerfile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ ENV PATH /root/miniconda3/envs/strda/bin:$PATH
2828

2929
# install dependencies
3030
RUN pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
31-
RUN pip install opencv-python==4.4.0.46 Pillow==7.2.0 opencv-python-headless==4.5.1.48 lmdb tqdm nltk
31+
RUN pip install opencv-python==4.4.0.46 Pillow==7.2.0 opencv-python-headless==4.5.1.48 lmdb tqdm nltk six pyyaml
3232

3333
RUN apt-get update
3434
RUN apt-get install -y ffmpeg libsm6 libxext6
3535

3636
# get repository
37-
WORKDIR /home
37+
WORKDIR /home

config/default.yaml

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Data Processing
2+
batch_max_length: 25 # maximum-label-length
3+
imgH: 32 # the height of the input image
4+
imgW: 100 # the width of the input image
5+
character: "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~" # character label
6+
7+
# Model Architecture
8+
num_fiducial: 20 # number of fiducial points of TPS-STN
9+
input_channel: 3 # the number of input channel of Feature extractor
10+
output_channel: 512 # the number of output channel of Feature extractor
11+
hidden_size: 256 # the size of the LSTM hidden state
12+
13+
# Optimizer
14+
lr: 0.001 # learning rate, 0.001 for Adam
15+
weight_decay: 0.01 # weight decay, 0.01 for Adam
16+
17+
# Experiment
18+
manual_seed: 111 # for random seed setting
19+
20+
# Training
21+
grad_clip: 5 # gradient clipping value
22+
workers: 4 # number of data loading workers
23+
24+
# HDGE
25+
decay_epoch: 100 # epoch from which to start lr decay
26+
load_height: 48
27+
load_width: 160
28+
crop_height: 32
29+
crop_width: 100
30+
lamda: 10
31+
idt_coef: 0.5
32+
ngf: 64 # number of generator filters in the first conv layer
33+
ndf: 64 # number of discriminator filters in the first conv layer
34+
norm: "instance" # instance normalization or batch normalization

requirements.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
six
12
lmdb
23
tqdm
34
nltk
5+
pyyaml
46
pillow
5-
opencv-python
7+
opencv-python

source/HDGE.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import itertools
3+
34
import numpy as np
45
from tqdm import tqdm
56

@@ -8,11 +9,11 @@
89
from torch.autograd import Variable
910
from torch.utils.data import Subset
1011

11-
import utils.utils_HDGE as utils
12-
1312
from .ops import set_grad
1413
from .dataset import AlignCollateHDGE, hierarchical_dataset
1514

15+
import utils.utils_HDGE as utils
16+
1617
from modules.generators import define_Gen
1718
from modules.discriminators import define_Dis
1819

@@ -47,7 +48,7 @@ def __init__(self,args):
4748
os.makedirs(args.checkpoint_dir)
4849

4950
try:
50-
ckpt = utils.load_checkpoint('%s/latest.ckpt' % (args.checkpoint_dir))
51+
ckpt = utils.load_checkpoint('%s/HDGE_gen_dis.ckpt' % (args.checkpoint_dir))
5152
self.start_epoch = ckpt['epoch']
5253
self.Da.load_state_dict(ckpt['Da'])
5354
self.Db.load_state_dict(ckpt['Db'])
@@ -73,7 +74,7 @@ def train(self,args):
7374
source_data,
7475
batch_size=args.batch_size,
7576
shuffle=True,
76-
num_workers=args.num_workers,
77+
num_workers=args.workers,
7778
collate_fn=myAlignCollate,
7879
pin_memory=False,
7980
drop_last=True,
@@ -82,7 +83,7 @@ def train(self,args):
8283
target_data_adjust,
8384
batch_size=args.batch_size,
8485
shuffle=True,
85-
num_workers=args.num_workers,
86+
num_workers=args.workers,
8687
collate_fn=myAlignCollate,
8788
pin_memory=False,
8889
drop_last=True,
@@ -182,8 +183,8 @@ def train(self,args):
182183
b_dis_loss.backward()
183184
self.d_optimizer.step()
184185

185-
print("\nEpoch: (%3d/%3d) | Gen Loss: %0.4f | Dis Loss: %0.4f" %
186-
(epoch, args.epochs, gen_loss,a_dis_loss+b_dis_loss))
186+
print("\nEpoch: (%3d/%3d) | Gen Loss: %0.4f | Dis Loss: %0.4f\n" %
187+
(epoch + 1, args.epochs, gen_loss,a_dis_loss+b_dis_loss))
187188

188189
# override the latest checkpoint
189190
utils.save_checkpoint({'epoch': epoch + 1,

source/dataset.py

Lines changed: 33 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import os
22
import sys
33
import six
4-
import PIL
54
import lmdb
5+
6+
import PIL
67
from PIL import Image
78

89
import torch
@@ -15,13 +16,13 @@
1516
_STD_IMAGENET = torch.tensor([0.229, 0.224, 0.225])
1617

1718

18-
def get_dataloader(opt, dataset, batch_size, shuffle = False, mode = "label"):
19+
def get_dataloader(args, dataset, batch_size, shuffle = False, mode = "label"):
1920
"""
2021
Get dataloader for each dataset
2122
2223
Parameters
2324
----------
24-
opt: argparse.ArgumentParser().parse_args()
25+
args: argparse.ArgumentParser().parse_args()
2526
dataset: torch.utils.data.Dataset
2627
batch_size: int
2728
shuffle: boolean
@@ -32,23 +33,23 @@ def get_dataloader(opt, dataset, batch_size, shuffle = False, mode = "label"):
3233
"""
3334

3435
if mode == "raw":
35-
myAlignCollate = AlignCollateRaw(opt)
36+
myAlignCollate = AlignCollateRaw(args)
3637
else:
37-
myAlignCollate = AlignCollate(opt, mode)
38+
myAlignCollate = AlignCollate(args, mode)
3839

3940
data_loader = DataLoader(
4041
dataset,
4142
batch_size=batch_size,
4243
shuffle=shuffle,
43-
num_workers=opt.workers,
44+
num_workers=args.workers,
4445
collate_fn=myAlignCollate,
4546
pin_memory=False,
4647
drop_last=False,
4748
)
4849
return data_loader
4950

5051

51-
def hierarchical_dataset(root, opt, mode="label", drop_data=[]):
52+
def hierarchical_dataset(root, args, mode="label", drop_data=[]):
5253
""" select_data='/' contains all sub-directory of root directory """
5354
dataset_list = []
5455
dataset_log = f"dataset_root: {root}\t dataset:"
@@ -72,10 +73,10 @@ def hierarchical_dataset(root, opt, mode="label", drop_data=[]):
7273
for dirpath in listdir:
7374
if mode == "raw":
7475
# load data without label
75-
dataset = LmdbDataset_raw(dirpath, opt)
76+
dataset = LmdbDataset_raw(dirpath, args)
7677
else:
7778
# load data with label
78-
dataset = LmdbDataset(dirpath, opt)
79+
dataset = LmdbDataset(dirpath, args)
7980
sub_dataset_log = f"sub-directory:\t/{os.path.relpath(dirpath, root)}\t num samples: {len(dataset)}"
8081
print(sub_dataset_log)
8182
dataset_log += f"{sub_dataset_log}\n"
@@ -113,15 +114,15 @@ def __getitem__(self, index):
113114

114115
class AlignCollate(object):
115116
""" Transform data to the same format """
116-
def __init__(self, opt, mode = "label"):
117-
self.opt = opt
117+
def __init__(self, args, mode = "label"):
118+
self.args = args
118119
# resize image
119120
if (mode == "adapt" or mode == "supervised"):
120121
self.transform = Rand_augment()
121122
else:
122123
self.transform = torchvision.transforms.Compose([])
123124

124-
self.resize = ResizeNormalize(opt)
125+
self.resize = ResizeNormalize(args)
125126
print("Use Text_augment", self.transform)
126127

127128
def __call__(self, batch):
@@ -135,10 +136,10 @@ def __call__(self, batch):
135136

136137
class AlignCollateRaw(object):
137138
""" Transform data to the same format """
138-
def __init__(self, opt):
139-
self.opt = opt
139+
def __init__(self, args):
140+
self.args = args
140141
# resize image
141-
self.transform = ResizeNormalize(opt)
142+
self.transform = ResizeNormalize(args)
142143

143144
def __call__(self, batch):
144145
images = batch
@@ -151,20 +152,20 @@ def __call__(self, batch):
151152

152153
class AlignCollateHDGE(object):
153154
""" Transform data to the same format """
154-
def __init__(self, opt, infer=False):
155-
self.opt = opt
155+
def __init__(self, args, infer=False):
156+
self.args = args
156157

157158
# for transforming the input image
158159
if infer == False:
159160
transform = torchvision.transforms.Compose(
160161
[torchvision.transforms.RandomHorizontalFlip(),
161-
torchvision.transforms.Resize((opt.load_height,opt.load_width)),
162-
torchvision.transforms.RandomCrop((opt.crop_height,opt.crop_width)),
162+
torchvision.transforms.Resize((args.load_height,args.load_width)),
163+
torchvision.transforms.RandomCrop((args.crop_height,args.crop_width)),
163164
torchvision.transforms.ToTensor(),
164165
torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
165166
else:
166167
transform = torchvision.transforms.Compose(
167-
[torchvision.transforms.Resize((opt.crop_height,opt.crop_width)),
168+
[torchvision.transforms.Resize((args.crop_height,args.crop_width)),
168169
torchvision.transforms.ToTensor(),
169170
torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
170171

@@ -181,10 +182,10 @@ def __call__(self, batch):
181182

182183
class LmdbDataset(Dataset):
183184
""" Load data from Lmdb file with label """
184-
def __init__(self, root, opt):
185+
def __init__(self, root, args):
185186

186187
self.root = root
187-
self.opt = opt
188+
self.args = args
188189
self.env = lmdb.open(
189190
root,
190191
max_readers=32,
@@ -207,7 +208,7 @@ def __init__(self, root, opt):
207208

208209
# length filtering
209210
length_of_label = len(label)
210-
if length_of_label > opt.batch_max_length:
211+
if length_of_label > args.batch_max_length:
211212
continue
212213

213214
self.filtered_index_list.append(index)
@@ -236,18 +237,18 @@ def __getitem__(self, index):
236237
except IOError:
237238
print(f"Corrupted image for {index}")
238239
# make dummy image and dummy label for corrupted image.
239-
img = PIL.Image.new("RGB", (self.opt.imgW, self.opt.imgH))
240+
img = PIL.Image.new("RGB", (self.args.imgW, self.args.imgH))
240241
label = "[dummy_label]"
241242

242243
return (img, label)
243244

244245

245246
class LmdbDataset_raw(Dataset):
246247
""" Load data from Lmdb file without label """
247-
def __init__(self, root, opt):
248+
def __init__(self, root, args):
248249

249250
self.root = root
250-
self.opt = opt
251+
self.args = args
251252
self.env = lmdb.open(
252253
root,
253254
max_readers=32,
@@ -284,27 +285,21 @@ def __getitem__(self, index):
284285
except IOError:
285286
print(f"Corrupted image for {img_key}")
286287
# make dummy image for corrupted image.
287-
img = PIL.Image.new("RGB", (self.opt.imgW, self.opt.imgH))
288+
img = PIL.Image.new("RGB", (self.args.imgW, self.args.imgH))
288289

289290
return img
290291

291292

292293
class ResizeNormalize(object):
293294

294-
def __init__(self, opt):
295-
self.opt = opt
295+
def __init__(self, args):
296+
self.args = args
296297
_transforms = []
297298

298-
_transforms.append(
299-
torchvision.transforms.Resize((self.opt.imgH, self.opt.imgW),
300-
interpolation=torchvision.transforms.InterpolationMode.BICUBIC))
299+
_transforms.append(torchvision.transforms.Resize((self.args.imgH, self.args.imgW),
300+
interpolation=torchvision.transforms.InterpolationMode.BICUBIC))
301301
_transforms.append(torchvision.transforms.ToTensor())
302-
if self.opt.use_IMAGENET_norm:
303-
_transforms.append(torchvision.transforms.Normalize(mean=_MEAN_IMAGENET,
304-
std=_STD_IMAGENET))
305-
else:
306-
_transforms.append(torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5],
307-
std=[0.5, 0.5, 0.5]))
302+
_transforms.append(torchvision.transforms.Normalize(mean=_MEAN_IMAGENET, std=_STD_IMAGENET))
308303
self._transforms = torchvision.transforms.Compose(_transforms)
309304

310305
def __call__(self, image):

0 commit comments

Comments
 (0)