Skip to content

Commit 8ccbaa5

Browse files
authored
Merge pull request #21 from Joker1212/lineless
feature: optimize lineless table rec
2 parents a7dfe47 + 9da519c commit 8ccbaa5

File tree

3 files changed

+547
-138
lines changed

3 files changed

+547
-138
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,4 +156,5 @@ long1.jpg
156156
*.pdmodel
157157

158158
.DS_Store
159-
*.npy
159+
*.npy
160+
/lineless_table_rec/output/

lineless_table_rec/main.py

Lines changed: 145 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# @Contact: [email protected]
44
import argparse
55
import logging
6+
import os
67
import time
78
import traceback
89
from pathlib import Path
@@ -12,13 +13,13 @@
1213
import numpy as np
1314
from rapidocr_onnxruntime import RapidOCR
1415

15-
from .lineless_table_process import DetProcess, get_affine_transform_upper_left
16-
from .utils import InputType, LoadImage, OrtInferSession
17-
from .utils_table_recover import (
16+
from lineless_table_process import DetProcess, get_affine_transform_upper_left
17+
from utils import InputType, LoadImage, OrtInferSession
18+
from utils_table_recover import (
1819
get_rotate_crop_image,
19-
match_ocr_cell,
2020
plot_html_table,
21-
sorted_boxes,
21+
sorted_ocr_boxes, box_4_2_poly_to_box_4_1, match_ocr_cell,
22+
filter_duplicated_box, gather_ocr_list_by_row, plot_rec_box_with_logic_info, plot_rec_box, format_html,
2223
)
2324

2425
cur_dir = Path(__file__).resolve().parent
@@ -28,9 +29,9 @@
2829

2930
class LinelessTableRecognition:
3031
def __init__(
31-
self,
32-
detect_model_path: Union[str, Path] = detect_model_path,
33-
process_model_path: Union[str, Path] = process_model_path,
32+
self,
33+
detect_model_path: Union[str, Path] = detect_model_path,
34+
process_model_path: Union[str, Path] = process_model_path,
3435
):
3536
self.mean = np.array([0.408, 0.447, 0.470], dtype=np.float32).reshape(1, 1, 3)
3637
self.std = np.array([0.289, 0.274, 0.278], dtype=np.float32).reshape(1, 1, 3)
@@ -45,32 +46,70 @@ def __init__(
4546
self.det_process = DetProcess()
4647
self.ocr = RapidOCR()
4748

48-
def __call__(self, content: InputType) -> str:
49+
def __call__(self, content: InputType):
4950
ss = time.perf_counter()
5051
img = self.load_img(content)
51-
5252
ocr_res, _ = self.ocr(img)
53-
5453
input_info = self.preprocess(img)
5554
try:
5655
polygons, slct_logi = self.infer(input_info)
5756
logi_points = self.filter_logi_points(slct_logi)
57+
# ocr 结果匹配
58+
cell_box_det_map, no_match_ocr_det = match_ocr_cell(ocr_res, polygons)
59+
# 如果有识别框没有ocr结果,直接进行rec补充
60+
cell_box_det_map = self.re_rec(img, polygons, cell_box_det_map)
61+
# 转换为中间格式,修正识别框坐标,将物理识别框,逻辑识别框,ocr识别框整合为dict,方便后续处理
62+
t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points)
63+
# 拆分包含和重叠的识别框
64+
deleted_idx_set = filter_duplicated_box([table_box_ocr['t_box'] for table_box_ocr in t_rec_ocr_list])
65+
t_rec_ocr_list = [t_rec_ocr_list[i] for i in range(len(t_rec_ocr_list)) if i not in deleted_idx_set]
66+
# 生成行列对应的二维表格, 合并同行同列识别框中的的ocr识别框
67+
t_rec_ocr_list, grid = self.handle_overlap_row_col(t_rec_ocr_list)
68+
# todo 根据grid 及 not_match_orc_boxes,尝试将ocr识别填入单行单列中
69+
# 将同一个识别框中的ocr结果排序并同行合并
70+
t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list)
71+
# 渲染为html
72+
logi_points = [t_box_ocr['t_logic_box'] for t_box_ocr in t_rec_ocr_list]
73+
cell_box_det_map = {
74+
i: [ocr_box_and_text[1] for ocr_box_and_text in t_box_ocr['t_ocr_res']]
75+
for i, t_box_ocr in enumerate(t_rec_ocr_list)
76+
}
77+
table_str = plot_html_table(logi_points, cell_box_det_map)
5878

59-
sorted_polygons = sorted_boxes(polygons)
60-
61-
cell_box_map = match_ocr_cell(sorted_polygons, ocr_res)
62-
cell_box_map = self.re_rec(img, sorted_polygons, cell_box_map)
63-
64-
logi_points = self.sort_logi_by_polygons(
65-
sorted_polygons, polygons, logi_points
66-
)
67-
68-
table_str = plot_html_table(logi_points, cell_box_map)
79+
# 输出可视化排序,用于验证结果,生产版本可以去掉
80+
_, idx_list = sorted_ocr_boxes([t_box_ocr['t_box'] for t_box_ocr in t_rec_ocr_list])
81+
t_rec_ocr_list = [t_rec_ocr_list[i] for i in idx_list]
82+
sorted_polygons = [t_box_ocr['t_box'] for t_box_ocr in t_rec_ocr_list]
83+
sorted_logi_points = [t_box_ocr['t_logic_box'] for t_box_ocr in t_rec_ocr_list]
84+
ocr_boxes_res = [box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_res]
85+
sorted_ocr_boxes_res, _ = sorted_ocr_boxes(ocr_boxes_res)
6986
table_elapse = time.perf_counter() - ss
70-
return table_str, table_elapse
87+
return table_str, table_elapse, sorted_polygons, sorted_logi_points, sorted_ocr_boxes_res
7188
except Exception:
7289
logging.warning(traceback.format_exc())
73-
return "", 0.0
90+
return "", 0.0, None, None, None
91+
92+
def transform_res(self, cell_box_det_map: dict[int, List[any]], polygons: np.ndarray,
93+
logi_points: list[np.ndarray]) -> list[dict[str, any]]:
94+
res = []
95+
for i in range(len(polygons)):
96+
ocr_res_list = cell_box_det_map.get(i)
97+
if not ocr_res_list:
98+
continue
99+
xmin = min([ocr_box[0][0][0] for ocr_box in ocr_res_list])
100+
ymin = min([ocr_box[0][0][1] for ocr_box in ocr_res_list])
101+
xmax = max([ocr_box[0][2][0] for ocr_box in ocr_res_list])
102+
ymax = max([ocr_box[0][2][1] for ocr_box in ocr_res_list])
103+
dict_res = {
104+
# xmin,xmax,ymin,ymax
105+
't_box': [xmin, ymin, xmax, ymax],
106+
# row_start,row_end,col_start,col_end
107+
't_logic_box': logi_points[i].tolist(),
108+
# [[xmin,xmax,ymin,ymax], text]
109+
't_ocr_res': [[box_4_2_poly_to_box_4_1(ocr_det[0]), ocr_det[1]] for ocr_det in ocr_res_list]
110+
}
111+
res.append(dict_res)
112+
return res
74113

75114
def preprocess(self, img: np.ndarray) -> Dict[str, Any]:
76115
height, width = img.shape[:2]
@@ -115,52 +154,107 @@ def infer(self, input_content: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]:
115154
)
116155
return slct_output_dets, slct_logi
117156

118-
def filter_logi_points(self, slct_logi: np.ndarray) -> Dict[str, Any]:
157+
def sort_and_gather_ocr_res(self, res):
158+
for i, dict_res in enumerate(res):
159+
dict_res['t_ocr_res'] = gather_ocr_list_by_row(dict_res['t_ocr_res'])
160+
_, sorted_idx = sorted_ocr_boxes([ocr_det[0] for ocr_det in dict_res['t_ocr_res']])
161+
dict_res['t_ocr_res'] = [dict_res['t_ocr_res'][i] for i in sorted_idx]
162+
return res
163+
164+
def handle_overlap_row_col(self, res):
165+
max_row, max_col = 0, 0
166+
for dict_res in res:
167+
max_row = max(max_row, dict_res['t_logic_box'][1] + 1) # 加1是因为结束下标是包含在内的
168+
max_col = max(max_col, dict_res['t_logic_box'][3] + 1) # 加1是因为结束下标是包含在内的
169+
# 创建一个二维数组来存储 sorted_logi_points 中的元素
170+
grid = [[None] * max_col for _ in range(max_row)]
171+
# 将 sorted_logi_points 中的元素填充到 grid 中
172+
deleted_idx = set()
173+
for i, dict_res in enumerate(res):
174+
if i in deleted_idx:
175+
continue
176+
row_start, row_end, col_start, col_end = dict_res['t_logic_box']
177+
for row in range(row_start, row_end + 1):
178+
if i in deleted_idx:
179+
continue
180+
for col in range(col_start, col_end + 1):
181+
if i in deleted_idx:
182+
continue
183+
exist_dict_res = grid[row][col]
184+
if not exist_dict_res:
185+
grid[row][col] = dict_res
186+
continue
187+
if exist_dict_res['t_logic_box'] == dict_res['t_logic_box']:
188+
exist_dict_res['t_ocr_res'].extend(dict_res['t_ocr_res'])
189+
deleted_idx.add(i)
190+
# 修正识别框坐标
191+
exist_dict_res['t_box'] = [min(exist_dict_res['t_box'][0], dict_res['t_box'][0]),
192+
min(exist_dict_res['t_box'][1], dict_res['t_box'][1]),
193+
max(exist_dict_res['t_box'][2], dict_res['t_box'][2]),
194+
max(exist_dict_res['t_box'][3], dict_res['t_box'][3]),
195+
]
196+
continue
197+
198+
# 去掉重叠框
199+
res = [res[i] for i in range(len(res)) if i not in deleted_idx]
200+
return res, grid
201+
202+
@staticmethod
203+
def filter_logi_points(slct_logi: np.ndarray) -> list[np.ndarray]:
204+
for logic_points in slct_logi[0]:
205+
# 修正坐标接近导致的r_e > r_s 或 c_e > c_s
206+
if abs(logic_points[0] - logic_points[1]) < 0.2:
207+
row = (logic_points[0] + logic_points[1]) / 2
208+
logic_points[0] = row
209+
logic_points[1] = row
210+
if abs(logic_points[2] - logic_points[3]) < 0.2:
211+
col = (logic_points[2] + logic_points[3]) / 2
212+
logic_points[2] = col
213+
logic_points[3] = col
119214
logi_floor = np.floor(slct_logi)
120215
dev = slct_logi - logi_floor
121216
slct_logi = np.where(dev > 0.5, logi_floor + 1, logi_floor)
122-
return slct_logi[0]
123-
124-
@staticmethod
125-
def sort_logi_by_polygons(
126-
sorted_polygons: np.ndarray, polygons: np.ndarray, logi_points: np.ndarray
127-
) -> np.ndarray:
128-
sorted_idx = []
129-
for v in sorted_polygons:
130-
loc_idx = np.argwhere(v[0, 0] == polygons[:, 0, 0]).squeeze()
131-
sorted_idx.append(int(loc_idx))
132-
logi_points = logi_points[sorted_idx]
133-
return logi_points
217+
return slct_logi[0].astype(np.int32)
134218

135219
def re_rec(
136-
self,
137-
img: np.ndarray,
138-
sorted_polygons: np.ndarray,
139-
cell_box_map: Dict[int, List[str]],
140-
) -> Dict[int, List[str]]:
220+
self,
221+
img: np.ndarray,
222+
sorted_polygons: np.ndarray,
223+
cell_box_map: Dict[int, List[str]],
224+
) -> Dict[int, List[any]]:
141225
"""找到poly对应为空的框,尝试将直接将poly框直接送到识别中"""
142-
for k, v in cell_box_map.items():
143-
if v[0]:
226+
#
227+
for i in range(sorted_polygons.shape[0]):
228+
if cell_box_map.get(i):
144229
continue
145-
146-
crop_img = get_rotate_crop_image(img, sorted_polygons[k])
230+
crop_img = get_rotate_crop_image(img, sorted_polygons[i])
147231
pad_img = cv2.copyMakeBorder(
148-
crop_img, 2, 2, 100, 100, cv2.BORDER_CONSTANT, value=(255, 255, 255)
232+
crop_img, 5, 5, 100, 100, cv2.BORDER_CONSTANT, value=(255, 255, 255)
149233
)
150234
rec_res, _ = self.ocr(pad_img, use_det=False, use_cls=True, use_rec=True)
151-
cell_box_map[k] = [rec_res[0][0]]
235+
box = sorted_polygons[i]
236+
text = [rec[0] for rec in rec_res]
237+
scores = [rec[1] for rec in rec_res]
238+
cell_box_map[i] = [[box, "".join(text), min(scores)]]
152239
return cell_box_map
153240

154241

155242
def main():
156243
parser = argparse.ArgumentParser()
157244
parser.add_argument("-img", "--img_path", type=str, required=True)
245+
parser.add_argument( "--output_dir", default= "./output", type=str)
158246
args = parser.parse_args()
159-
247+
# args.img_path = '../images/image (78).png'
160248
table_rec = LinelessTableRecognition()
161-
table_str, elapse = table_rec(args.img_path)
162-
print(table_str)
163-
print(f"cost: {elapse:.5f}")
249+
html, elasp, polygons, logic_points, ocr_res = table_rec(args.img_path)
250+
print(f"cost: {elasp:.5f}")
251+
complete_html = format_html(html)
252+
os.makedirs(os.path.dirname(f'{args.output_dir}/table.html'), exist_ok=True)
253+
with open(f'{args.output_dir}/table.html', 'w', encoding='utf-8') as file:
254+
file.write(complete_html)
255+
plot_rec_box_with_logic_info(args.img_path, f'{args.output_dir}/table_rec_box.jpg', logic_points, polygons)
256+
plot_rec_box(args.img_path, f'{args.output_dir}/ocr_box.jpg', ocr_res)
257+
164258

165259

166260
if __name__ == "__main__":

0 commit comments

Comments
 (0)