Skip to content

Commit 3fe1e58

Browse files
committed
chore: Optim code
1 parent 8ccbaa5 commit 3fe1e58

File tree

7 files changed

+170
-115
lines changed

7 files changed

+170
-115
lines changed

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -157,4 +157,4 @@ long1.jpg
157157

158158
.DS_Store
159159
*.npy
160-
/lineless_table_rec/output/
160+
outputs/

demo_lineless.py

+19-9
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,29 @@
11
# -*- encoding: utf-8 -*-
22
# @Author: SWHL
33
# @Contact: [email protected]
4-
from pathlib import Path
4+
import os
55

66
from lineless_table_rec import LinelessTableRecognition
7+
from lineless_table_rec.utils_table_recover import (
8+
format_html,
9+
plot_rec_box,
10+
plot_rec_box_with_logic_info,
11+
)
712

8-
engine = LinelessTableRecognition()
9-
13+
output_dir = "outputs"
1014
img_path = "tests/test_files/lineless_table_recognition.jpg"
11-
table_str, elapse = engine(img_path)
15+
table_rec = LinelessTableRecognition()
16+
17+
html, elasp, polygons, logic_points, ocr_res = table_rec(img_path)
18+
print(f"cost: {elasp:.5f}")
1219

13-
print(table_str)
14-
print(elapse)
20+
complete_html = format_html(html)
21+
os.makedirs(os.path.dirname(f"{output_dir}/table.html"), exist_ok=True)
1522

16-
with open(f"{Path(img_path).stem}.html", "w", encoding="utf-8") as f:
17-
f.write(table_str)
23+
with open(f"{output_dir}/table.html", "w", encoding="utf-8") as file:
24+
file.write(complete_html)
1825

19-
print("ok")
26+
plot_rec_box_with_logic_info(
27+
img_path, f"{output_dir}/table_rec_box.jpg", logic_points, polygons
28+
)
29+
plot_rec_box(img_path, f"{output_dir}/ocr_box.jpg", ocr_res)

lineless_table_rec/main.py

+71-62
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
# -*- encoding: utf-8 -*-
22
# @Author: SWHL
33
# @Contact: [email protected]
4-
import argparse
54
import logging
6-
import os
75
import time
86
import traceback
97
from pathlib import Path
@@ -13,13 +11,16 @@
1311
import numpy as np
1412
from rapidocr_onnxruntime import RapidOCR
1513

16-
from lineless_table_process import DetProcess, get_affine_transform_upper_left
17-
from utils import InputType, LoadImage, OrtInferSession
18-
from utils_table_recover import (
14+
from .process import DetProcess, get_affine_transform_upper_left
15+
from .utils import InputType, LoadImage, OrtInferSession
16+
from .utils_table_recover import (
17+
box_4_2_poly_to_box_4_1,
18+
filter_duplicated_box,
19+
gather_ocr_list_by_row,
1920
get_rotate_crop_image,
21+
match_ocr_cell,
2022
plot_html_table,
21-
sorted_ocr_boxes, box_4_2_poly_to_box_4_1, match_ocr_cell,
22-
filter_duplicated_box, gather_ocr_list_by_row, plot_rec_box_with_logic_info, plot_rec_box, format_html,
23+
sorted_ocr_boxes,
2324
)
2425

2526
cur_dir = Path(__file__).resolve().parent
@@ -29,9 +30,9 @@
2930

3031
class LinelessTableRecognition:
3132
def __init__(
32-
self,
33-
detect_model_path: Union[str, Path] = detect_model_path,
34-
process_model_path: Union[str, Path] = process_model_path,
33+
self,
34+
detect_model_path: Union[str, Path] = detect_model_path,
35+
process_model_path: Union[str, Path] = process_model_path,
3536
):
3637
self.mean = np.array([0.408, 0.447, 0.470], dtype=np.float32).reshape(1, 1, 3)
3738
self.std = np.array([0.289, 0.274, 0.278], dtype=np.float32).reshape(1, 1, 3)
@@ -61,36 +62,56 @@ def __call__(self, content: InputType):
6162
# 转换为中间格式,修正识别框坐标,将物理识别框,逻辑识别框,ocr识别框整合为dict,方便后续处理
6263
t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points)
6364
# 拆分包含和重叠的识别框
64-
deleted_idx_set = filter_duplicated_box([table_box_ocr['t_box'] for table_box_ocr in t_rec_ocr_list])
65-
t_rec_ocr_list = [t_rec_ocr_list[i] for i in range(len(t_rec_ocr_list)) if i not in deleted_idx_set]
65+
deleted_idx_set = filter_duplicated_box(
66+
[table_box_ocr["t_box"] for table_box_ocr in t_rec_ocr_list]
67+
)
68+
t_rec_ocr_list = [
69+
t_rec_ocr_list[i]
70+
for i in range(len(t_rec_ocr_list))
71+
if i not in deleted_idx_set
72+
]
6673
# 生成行列对应的二维表格, 合并同行同列识别框中的ocr识别框
6774
t_rec_ocr_list, grid = self.handle_overlap_row_col(t_rec_ocr_list)
6875
# todo 根据grid 及 not_match_orc_boxes,尝试将ocr识别填入单行单列中
6976
# 将同一个识别框中的ocr结果排序并同行合并
7077
t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list)
7178
# 渲染为html
72-
logi_points = [t_box_ocr['t_logic_box'] for t_box_ocr in t_rec_ocr_list]
79+
logi_points = [t_box_ocr["t_logic_box"] for t_box_ocr in t_rec_ocr_list]
7380
cell_box_det_map = {
74-
i: [ocr_box_and_text[1] for ocr_box_and_text in t_box_ocr['t_ocr_res']]
81+
i: [ocr_box_and_text[1] for ocr_box_and_text in t_box_ocr["t_ocr_res"]]
7582
for i, t_box_ocr in enumerate(t_rec_ocr_list)
7683
}
7784
table_str = plot_html_table(logi_points, cell_box_det_map)
7885

7986
# 输出可视化排序,用于验证结果,生产版本可以去掉
80-
_, idx_list = sorted_ocr_boxes([t_box_ocr['t_box'] for t_box_ocr in t_rec_ocr_list])
87+
_, idx_list = sorted_ocr_boxes(
88+
[t_box_ocr["t_box"] for t_box_ocr in t_rec_ocr_list]
89+
)
8190
t_rec_ocr_list = [t_rec_ocr_list[i] for i in idx_list]
82-
sorted_polygons = [t_box_ocr['t_box'] for t_box_ocr in t_rec_ocr_list]
83-
sorted_logi_points = [t_box_ocr['t_logic_box'] for t_box_ocr in t_rec_ocr_list]
91+
sorted_polygons = [t_box_ocr["t_box"] for t_box_ocr in t_rec_ocr_list]
92+
sorted_logi_points = [
93+
t_box_ocr["t_logic_box"] for t_box_ocr in t_rec_ocr_list
94+
]
8495
ocr_boxes_res = [box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_res]
8596
sorted_ocr_boxes_res, _ = sorted_ocr_boxes(ocr_boxes_res)
8697
table_elapse = time.perf_counter() - ss
87-
return table_str, table_elapse, sorted_polygons, sorted_logi_points, sorted_ocr_boxes_res
98+
return (
99+
table_str,
100+
table_elapse,
101+
sorted_polygons,
102+
sorted_logi_points,
103+
sorted_ocr_boxes_res,
104+
)
88105
except Exception:
89106
logging.warning(traceback.format_exc())
90107
return "", 0.0, None, None, None
91108

92-
def transform_res(self, cell_box_det_map: dict[int, List[any]], polygons: np.ndarray,
93-
logi_points: list[np.ndarray]) -> list[dict[str, any]]:
109+
def transform_res(
110+
self,
111+
cell_box_det_map: dict[int, List[any]],
112+
polygons: np.ndarray,
113+
logi_points: list[np.ndarray],
114+
) -> list[dict[str, any]]:
94115
res = []
95116
for i in range(len(polygons)):
96117
ocr_res_list = cell_box_det_map.get(i)
@@ -102,11 +123,14 @@ def transform_res(self, cell_box_det_map: dict[int, List[any]], polygons: np.nda
102123
ymax = max([ocr_box[0][2][1] for ocr_box in ocr_res_list])
103124
dict_res = {
104125
# xmin,ymin,xmax,ymax
105-
't_box': [xmin, ymin, xmax, ymax],
126+
"t_box": [xmin, ymin, xmax, ymax],
106127
# row_start,row_end,col_start,col_end
107-
't_logic_box': logi_points[i].tolist(),
128+
"t_logic_box": logi_points[i].tolist(),
108129
# [[xmin,xmax,ymin,ymax], text]
109-
't_ocr_res': [[box_4_2_poly_to_box_4_1(ocr_det[0]), ocr_det[1]] for ocr_det in ocr_res_list]
130+
"t_ocr_res": [
131+
[box_4_2_poly_to_box_4_1(ocr_det[0]), ocr_det[1]]
132+
for ocr_det in ocr_res_list
133+
],
110134
}
111135
res.append(dict_res)
112136
return res
@@ -156,24 +180,30 @@ def infer(self, input_content: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]:
156180

157181
def sort_and_gather_ocr_res(self, res):
158182
for i, dict_res in enumerate(res):
159-
dict_res['t_ocr_res'] = gather_ocr_list_by_row(dict_res['t_ocr_res'])
160-
_, sorted_idx = sorted_ocr_boxes([ocr_det[0] for ocr_det in dict_res['t_ocr_res']])
161-
dict_res['t_ocr_res'] = [dict_res['t_ocr_res'][i] for i in sorted_idx]
183+
dict_res["t_ocr_res"] = gather_ocr_list_by_row(dict_res["t_ocr_res"])
184+
_, sorted_idx = sorted_ocr_boxes(
185+
[ocr_det[0] for ocr_det in dict_res["t_ocr_res"]]
186+
)
187+
dict_res["t_ocr_res"] = [dict_res["t_ocr_res"][i] for i in sorted_idx]
162188
return res
163189

164190
def handle_overlap_row_col(self, res):
165191
max_row, max_col = 0, 0
166192
for dict_res in res:
167-
max_row = max(max_row, dict_res['t_logic_box'][1] + 1) # 加1是因为结束下标是包含在内的
168-
max_col = max(max_col, dict_res['t_logic_box'][3] + 1) # 加1是因为结束下标是包含在内的
193+
max_row = max(
194+
max_row, dict_res["t_logic_box"][1] + 1
195+
) # 加1是因为结束下标是包含在内的
196+
max_col = max(
197+
max_col, dict_res["t_logic_box"][3] + 1
198+
) # 加1是因为结束下标是包含在内的
169199
# 创建一个二维数组来存储 sorted_logi_points 中的元素
170200
grid = [[None] * max_col for _ in range(max_row)]
171201
# 将 sorted_logi_points 中的元素填充到 grid 中
172202
deleted_idx = set()
173203
for i, dict_res in enumerate(res):
174204
if i in deleted_idx:
175205
continue
176-
row_start, row_end, col_start, col_end = dict_res['t_logic_box']
206+
row_start, row_end, col_start, col_end = dict_res["t_logic_box"]
177207
for row in range(row_start, row_end + 1):
178208
if i in deleted_idx:
179209
continue
@@ -184,15 +214,16 @@ def handle_overlap_row_col(self, res):
184214
if not exist_dict_res:
185215
grid[row][col] = dict_res
186216
continue
187-
if exist_dict_res['t_logic_box'] == dict_res['t_logic_box']:
188-
exist_dict_res['t_ocr_res'].extend(dict_res['t_ocr_res'])
217+
if exist_dict_res["t_logic_box"] == dict_res["t_logic_box"]:
218+
exist_dict_res["t_ocr_res"].extend(dict_res["t_ocr_res"])
189219
deleted_idx.add(i)
190220
# 修正识别框坐标
191-
exist_dict_res['t_box'] = [min(exist_dict_res['t_box'][0], dict_res['t_box'][0]),
192-
min(exist_dict_res['t_box'][1], dict_res['t_box'][1]),
193-
max(exist_dict_res['t_box'][2], dict_res['t_box'][2]),
194-
max(exist_dict_res['t_box'][3], dict_res['t_box'][3]),
195-
]
221+
exist_dict_res["t_box"] = [
222+
min(exist_dict_res["t_box"][0], dict_res["t_box"][0]),
223+
min(exist_dict_res["t_box"][1], dict_res["t_box"][1]),
224+
max(exist_dict_res["t_box"][2], dict_res["t_box"][2]),
225+
max(exist_dict_res["t_box"][3], dict_res["t_box"][3]),
226+
]
196227
continue
197228

198229
# 去掉重叠框
@@ -217,10 +248,10 @@ def filter_logi_points(slct_logi: np.ndarray) -> list[np.ndarray]:
217248
return slct_logi[0].astype(np.int32)
218249

219250
def re_rec(
220-
self,
221-
img: np.ndarray,
222-
sorted_polygons: np.ndarray,
223-
cell_box_map: Dict[int, List[str]],
251+
self,
252+
img: np.ndarray,
253+
sorted_polygons: np.ndarray,
254+
cell_box_map: Dict[int, List[str]],
224255
) -> Dict[int, List[any]]:
225256
"""找到poly对应为空的框,尝试将poly框直接送到识别中"""
226257
#
@@ -237,25 +268,3 @@ def re_rec(
237268
scores = [rec[1] for rec in rec_res]
238269
cell_box_map[i] = [[box, "".join(text), min(scores)]]
239270
return cell_box_map
240-
241-
242-
def main():
243-
parser = argparse.ArgumentParser()
244-
parser.add_argument("-img", "--img_path", type=str, required=True)
245-
parser.add_argument( "--output_dir", default= "./output", type=str)
246-
args = parser.parse_args()
247-
# args.img_path = '../images/image (78).png'
248-
table_rec = LinelessTableRecognition()
249-
html, elasp, polygons, logic_points, ocr_res = table_rec(args.img_path)
250-
print(f"cost: {elasp:.5f}")
251-
complete_html = format_html(html)
252-
os.makedirs(os.path.dirname(f'{args.output_dir}/table.html'), exist_ok=True)
253-
with open(f'{args.output_dir}/table.html', 'w', encoding='utf-8') as file:
254-
file.write(complete_html)
255-
plot_rec_box_with_logic_info(args.img_path, f'{args.output_dir}/table_rec_box.jpg', logic_points, polygons)
256-
plot_rec_box(args.img_path, f'{args.output_dir}/ocr_box.jpg', ocr_res)
257-
258-
259-
260-
if __name__ == "__main__":
261-
main()

0 commit comments

Comments
 (0)