-
Notifications
You must be signed in to change notification settings - Fork 135
/
Copy pathlsvt_converter.py
executable file
·98 lines (81 loc) · 3.3 KB
/
lsvt_converter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#!/usr/bin/env python3
import argparse
import os
import os.path as osp
import re
from functools import partial
import mmcv
import numpy as np
from mmocr.utils.fileio import list_to_file
from PIL import Image
def parse_args():
    """Parse the command-line arguments of the LSVT converter.

    Returns:
        argparse.Namespace: Namespace with ``root_path`` (str) and
        ``n_proc`` (int, 1 when omitted on the command line).
    """
    parser = argparse.ArgumentParser(
        description='Generate training set of LSVT '
        'by cropping box image.')
    parser.add_argument('root_path', help='Root dir path of LSVT')
    # argparse only honours ``default`` on a positional argument when the
    # argument is optional (``nargs='?'``); without it a value is mandatory
    # and the default of 1 can never take effect.
    parser.add_argument(
        'n_proc',
        nargs='?',
        default=1,
        type=int,
        help='Number of processes to run')
    args = parser.parse_args()
    return args
def process_img(args, src_image_root, dst_image_root):
    """Crop the usable text instances of one image and build their labels.

    Args:
        args (tuple): ``(img_idx, img_info, anns)``. ``img_idx`` is used in
            the names of the cropped files, ``img_info`` is the file stem of
            the source image, and ``anns`` is an iterable of dicts providing
            the keys ``'transcription'``, ``'illegibility'`` and
            ``'points'``.
        src_image_root (str): Directory that contains the
            ``train_full_images_0`` and ``train_full_images_1`` folders.
        dst_image_root (str): Directory the cropped images are written to.

    Returns:
        list[str]: One entry per saved crop, formatted as
        ``'<basename of dst_image_root>/<crop file name> <transcription>'``.
    """
    # The three values arrive packed in a single tuple so that the function
    # can be mapped over a list of tasks by a worker pool.
    img_idx, img_info, anns = args

    blacklist = ['LOFTINESS*']
    whitelist = ['#Find YOUR Fun#', 'Story #', '*0#']
    # Compiled once per image instead of once per annotation.
    cjk_pattern = re.compile(r'[\u4e00-\u9fff]+')

    # The image lives in one of two folders; fall back to the second one
    # when it cannot be opened from the first.
    try:
        src_img = Image.open(
            osp.join(src_image_root,
                     'train_full_images_0/{}.jpg'.format(img_info)))
    except IOError:
        src_img = Image.open(
            osp.join(src_image_root,
                     'train_full_images_1/{}.jpg'.format(img_info)))

    labels = []
    try:
        for ann_idx, ann in enumerate(anns):
            text_label = ann['transcription']

            # Skip instances that are flagged illegible, contain characters
            # in the range U+4E00-U+9FFF, are blacklisted, or contain '#'
            # without being explicitly whitelisted.
            if (ann['illegibility']
                    or cjk_pattern.search(text_label)
                    or text_label in blacklist
                    or ('#' in text_label and text_label not in whitelist)):
                continue

            # Axis-aligned bounding box around the annotated (x, y) points.
            points = np.asarray(ann['points'])
            x1, y1 = points.min(axis=0)
            x2, y2 = points.max(axis=0)

            dst_img = src_img.crop((x1, y1, x2, y2))
            dst_img_name = f'img_{img_idx}_{ann_idx}.jpg'
            dst_img_path = osp.join(dst_image_root, dst_img_name)
            # Preserve JPEG quality by re-using the source image's
            # quantization tables when encoding the crop.
            dst_img.save(dst_img_path, qtables=src_img.quantization)
            labels.append(f'{osp.basename(dst_image_root)}/{dst_img_name}'
                          f' {text_label}')
    finally:
        # Release the source file even when cropping or saving fails.
        src_img.close()
    return labels
def convert_lsvt(root_path,
                 dst_image_path,
                 dst_label_filename,
                 annotation_filename,
                 img_start_idx=0,
                 nproc=1):
    """Crop all annotated images below ``root_path`` and write a label file.

    Args:
        root_path (str): Dataset root. The annotation file and the source
            images are read from it and all outputs are written below it.
        dst_image_path (str): Directory, relative to ``root_path``, that
            receives the cropped images (created if missing).
        dst_label_filename (str): Label file, relative to ``root_path``.
        annotation_filename (str): Annotation file, relative to
            ``root_path``, loaded with ``mmcv.load``.
        img_start_idx (int): Offset added to each image's position in the
            annotation file when naming the cropped files.
        nproc (int): Forwarded to ``mmcv.track_parallel_progress``.

    Returns:
        int: Number of entries in the annotation file.

    Raises:
        FileNotFoundError: If the annotation file does not exist.
    """
    annotation_path = osp.join(root_path, annotation_filename)
    if not osp.exists(annotation_path):
        # FileNotFoundError is more specific than a bare Exception and is
        # still caught by callers that handle Exception.
        raise FileNotFoundError(
            f'{annotation_path} not exists, please check and try again.')

    src_image_root = root_path

    # outputs
    dst_label_file = osp.join(root_path, dst_label_filename)
    dst_image_root = osp.join(root_path, dst_image_path)
    os.makedirs(dst_image_root, exist_ok=True)

    # NOTE(review): the annotation is iterated with ``.items()``, so it is
    # presumably a mapping of image name -> list of annotations; confirm
    # against the dataset's label file.
    annotation = mmcv.load(annotation_path)

    process_img_with_path = partial(
        process_img,
        src_image_root=src_image_root,
        dst_image_root=dst_image_root)

    # One task per image: (global image index, image name, annotations).
    tasks = [(img_idx, img_info, anns)
             for img_idx, (img_info, anns) in enumerate(
                 annotation.items(), start=img_start_idx)]

    # keep_order=True keeps the per-image results aligned with ``tasks``.
    labels_list = mmcv.track_parallel_progress(
        process_img_with_path, tasks, keep_order=True, nproc=nproc)

    # Flatten the per-image label lists into a single list.
    final_labels = [label for label_list in labels_list for label in label_list]
    list_to_file(dst_label_file, final_labels)
    return len(annotation)
def main():
    """Script entry point: convert the LSVT training set.

    Reads the dataset root and the process count from the command line and
    runs the conversion with the fixed training-set file names.
    """
    cli = parse_args()
    job = dict(
        root_path=cli.root_path,
        dst_image_path='image_train',
        dst_label_filename='train_label.txt',
        annotation_filename='train_full_labels.json',
        nproc=cli.n_proc,
    )
    print('Processing training set...')
    convert_lsvt(**job)
    print('Finish')
# Run the conversion only when the file is executed as a script, so that
# importing the module has no side effects.
if __name__ == '__main__':
    main()