1
1
# -*- encoding: utf-8 -*-
2
2
# @Author: SWHL
3
3
4
- import argparse
5
4
import logging
6
- import os
7
5
import time
8
6
import traceback
9
7
from pathlib import Path
13
11
import numpy as np
14
12
from rapidocr_onnxruntime import RapidOCR
15
13
16
- from lineless_table_process import DetProcess , get_affine_transform_upper_left
17
- from utils import InputType , LoadImage , OrtInferSession
18
- from utils_table_recover import (
14
+ from .process import DetProcess , get_affine_transform_upper_left
15
+ from .utils import InputType , LoadImage , OrtInferSession
16
+ from .utils_table_recover import (
17
+ box_4_2_poly_to_box_4_1 ,
18
+ filter_duplicated_box ,
19
+ gather_ocr_list_by_row ,
19
20
get_rotate_crop_image ,
21
+ match_ocr_cell ,
20
22
plot_html_table ,
21
- sorted_ocr_boxes , box_4_2_poly_to_box_4_1 , match_ocr_cell ,
22
- filter_duplicated_box , gather_ocr_list_by_row , plot_rec_box_with_logic_info , plot_rec_box , format_html ,
23
+ sorted_ocr_boxes ,
23
24
)
24
25
25
26
cur_dir = Path (__file__ ).resolve ().parent
29
30
30
31
class LinelessTableRecognition :
31
32
def __init__ (
32
- self ,
33
- detect_model_path : Union [str , Path ] = detect_model_path ,
34
- process_model_path : Union [str , Path ] = process_model_path ,
33
+ self ,
34
+ detect_model_path : Union [str , Path ] = detect_model_path ,
35
+ process_model_path : Union [str , Path ] = process_model_path ,
35
36
):
36
37
self .mean = np .array ([0.408 , 0.447 , 0.470 ], dtype = np .float32 ).reshape (1 , 1 , 3 )
37
38
self .std = np .array ([0.289 , 0.274 , 0.278 ], dtype = np .float32 ).reshape (1 , 1 , 3 )
@@ -61,36 +62,56 @@ def __call__(self, content: InputType):
61
62
# 转换为中间格式,修正识别框坐标,将物理识别框,逻辑识别框,ocr识别框整合为dict,方便后续处理
62
63
t_rec_ocr_list = self .transform_res (cell_box_det_map , polygons , logi_points )
63
64
# 拆分包含和重叠的识别框
64
- deleted_idx_set = filter_duplicated_box ([table_box_ocr ['t_box' ] for table_box_ocr in t_rec_ocr_list ])
65
- t_rec_ocr_list = [t_rec_ocr_list [i ] for i in range (len (t_rec_ocr_list )) if i not in deleted_idx_set ]
65
+ deleted_idx_set = filter_duplicated_box (
66
+ [table_box_ocr ["t_box" ] for table_box_ocr in t_rec_ocr_list ]
67
+ )
68
+ t_rec_ocr_list = [
69
+ t_rec_ocr_list [i ]
70
+ for i in range (len (t_rec_ocr_list ))
71
+ if i not in deleted_idx_set
72
+ ]
66
73
# 生成行列对应的二维表格, 合并同行同列识别框中的的ocr识别框
67
74
t_rec_ocr_list , grid = self .handle_overlap_row_col (t_rec_ocr_list )
68
75
# todo 根据grid 及 not_match_orc_boxes,尝试将ocr识别填入单行单列中
69
76
# 将同一个识别框中的ocr结果排序并同行合并
70
77
t_rec_ocr_list = self .sort_and_gather_ocr_res (t_rec_ocr_list )
71
78
# 渲染为html
72
- logi_points = [t_box_ocr [' t_logic_box' ] for t_box_ocr in t_rec_ocr_list ]
79
+ logi_points = [t_box_ocr [" t_logic_box" ] for t_box_ocr in t_rec_ocr_list ]
73
80
cell_box_det_map = {
74
- i : [ocr_box_and_text [1 ] for ocr_box_and_text in t_box_ocr [' t_ocr_res' ]]
81
+ i : [ocr_box_and_text [1 ] for ocr_box_and_text in t_box_ocr [" t_ocr_res" ]]
75
82
for i , t_box_ocr in enumerate (t_rec_ocr_list )
76
83
}
77
84
table_str = plot_html_table (logi_points , cell_box_det_map )
78
85
79
86
# 输出可视化排序,用于验证结果,生产版本可以去掉
80
- _ , idx_list = sorted_ocr_boxes ([t_box_ocr ['t_box' ] for t_box_ocr in t_rec_ocr_list ])
87
+ _ , idx_list = sorted_ocr_boxes (
88
+ [t_box_ocr ["t_box" ] for t_box_ocr in t_rec_ocr_list ]
89
+ )
81
90
t_rec_ocr_list = [t_rec_ocr_list [i ] for i in idx_list ]
82
- sorted_polygons = [t_box_ocr ['t_box' ] for t_box_ocr in t_rec_ocr_list ]
83
- sorted_logi_points = [t_box_ocr ['t_logic_box' ] for t_box_ocr in t_rec_ocr_list ]
91
+ sorted_polygons = [t_box_ocr ["t_box" ] for t_box_ocr in t_rec_ocr_list ]
92
+ sorted_logi_points = [
93
+ t_box_ocr ["t_logic_box" ] for t_box_ocr in t_rec_ocr_list
94
+ ]
84
95
ocr_boxes_res = [box_4_2_poly_to_box_4_1 (ori_ocr [0 ]) for ori_ocr in ocr_res ]
85
96
sorted_ocr_boxes_res , _ = sorted_ocr_boxes (ocr_boxes_res )
86
97
table_elapse = time .perf_counter () - ss
87
- return table_str , table_elapse , sorted_polygons , sorted_logi_points , sorted_ocr_boxes_res
98
+ return (
99
+ table_str ,
100
+ table_elapse ,
101
+ sorted_polygons ,
102
+ sorted_logi_points ,
103
+ sorted_ocr_boxes_res ,
104
+ )
88
105
except Exception :
89
106
logging .warning (traceback .format_exc ())
90
107
return "" , 0.0 , None , None , None
91
108
92
- def transform_res (self , cell_box_det_map : dict [int , List [any ]], polygons : np .ndarray ,
93
- logi_points : list [np .ndarray ]) -> list [dict [str , any ]]:
109
+ def transform_res (
110
+ self ,
111
+ cell_box_det_map : dict [int , List [any ]],
112
+ polygons : np .ndarray ,
113
+ logi_points : list [np .ndarray ],
114
+ ) -> list [dict [str , any ]]:
94
115
res = []
95
116
for i in range (len (polygons )):
96
117
ocr_res_list = cell_box_det_map .get (i )
@@ -102,11 +123,14 @@ def transform_res(self, cell_box_det_map: dict[int, List[any]], polygons: np.nda
102
123
ymax = max ([ocr_box [0 ][2 ][1 ] for ocr_box in ocr_res_list ])
103
124
dict_res = {
104
125
# xmin,xmax,ymin,ymax
105
- ' t_box' : [xmin , ymin , xmax , ymax ],
126
+ " t_box" : [xmin , ymin , xmax , ymax ],
106
127
# row_start,row_end,col_start,col_end
107
- ' t_logic_box' : logi_points [i ].tolist (),
128
+ " t_logic_box" : logi_points [i ].tolist (),
108
129
# [[xmin,xmax,ymin,ymax], text]
109
- 't_ocr_res' : [[box_4_2_poly_to_box_4_1 (ocr_det [0 ]), ocr_det [1 ]] for ocr_det in ocr_res_list ]
130
+ "t_ocr_res" : [
131
+ [box_4_2_poly_to_box_4_1 (ocr_det [0 ]), ocr_det [1 ]]
132
+ for ocr_det in ocr_res_list
133
+ ],
110
134
}
111
135
res .append (dict_res )
112
136
return res
@@ -156,24 +180,30 @@ def infer(self, input_content: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]:
156
180
157
181
def sort_and_gather_ocr_res (self , res ):
158
182
for i , dict_res in enumerate (res ):
159
- dict_res ['t_ocr_res' ] = gather_ocr_list_by_row (dict_res ['t_ocr_res' ])
160
- _ , sorted_idx = sorted_ocr_boxes ([ocr_det [0 ] for ocr_det in dict_res ['t_ocr_res' ]])
161
- dict_res ['t_ocr_res' ] = [dict_res ['t_ocr_res' ][i ] for i in sorted_idx ]
183
+ dict_res ["t_ocr_res" ] = gather_ocr_list_by_row (dict_res ["t_ocr_res" ])
184
+ _ , sorted_idx = sorted_ocr_boxes (
185
+ [ocr_det [0 ] for ocr_det in dict_res ["t_ocr_res" ]]
186
+ )
187
+ dict_res ["t_ocr_res" ] = [dict_res ["t_ocr_res" ][i ] for i in sorted_idx ]
162
188
return res
163
189
164
190
def handle_overlap_row_col (self , res ):
165
191
max_row , max_col = 0 , 0
166
192
for dict_res in res :
167
- max_row = max (max_row , dict_res ['t_logic_box' ][1 ] + 1 ) # 加1是因为结束下标是包含在内的
168
- max_col = max (max_col , dict_res ['t_logic_box' ][3 ] + 1 ) # 加1是因为结束下标是包含在内的
193
+ max_row = max (
194
+ max_row , dict_res ["t_logic_box" ][1 ] + 1
195
+ ) # 加1是因为结束下标是包含在内的
196
+ max_col = max (
197
+ max_col , dict_res ["t_logic_box" ][3 ] + 1
198
+ ) # 加1是因为结束下标是包含在内的
169
199
# 创建一个二维数组来存储 sorted_logi_points 中的元素
170
200
grid = [[None ] * max_col for _ in range (max_row )]
171
201
# 将 sorted_logi_points 中的元素填充到 grid 中
172
202
deleted_idx = set ()
173
203
for i , dict_res in enumerate (res ):
174
204
if i in deleted_idx :
175
205
continue
176
- row_start , row_end , col_start , col_end = dict_res [' t_logic_box' ]
206
+ row_start , row_end , col_start , col_end = dict_res [" t_logic_box" ]
177
207
for row in range (row_start , row_end + 1 ):
178
208
if i in deleted_idx :
179
209
continue
@@ -184,15 +214,16 @@ def handle_overlap_row_col(self, res):
184
214
if not exist_dict_res :
185
215
grid [row ][col ] = dict_res
186
216
continue
187
- if exist_dict_res [' t_logic_box' ] == dict_res [' t_logic_box' ]:
188
- exist_dict_res [' t_ocr_res' ].extend (dict_res [' t_ocr_res' ])
217
+ if exist_dict_res [" t_logic_box" ] == dict_res [" t_logic_box" ]:
218
+ exist_dict_res [" t_ocr_res" ].extend (dict_res [" t_ocr_res" ])
189
219
deleted_idx .add (i )
190
220
# 修正识别框坐标
191
- exist_dict_res ['t_box' ] = [min (exist_dict_res ['t_box' ][0 ], dict_res ['t_box' ][0 ]),
192
- min (exist_dict_res ['t_box' ][1 ], dict_res ['t_box' ][1 ]),
193
- max (exist_dict_res ['t_box' ][2 ], dict_res ['t_box' ][2 ]),
194
- max (exist_dict_res ['t_box' ][3 ], dict_res ['t_box' ][3 ]),
195
- ]
221
+ exist_dict_res ["t_box" ] = [
222
+ min (exist_dict_res ["t_box" ][0 ], dict_res ["t_box" ][0 ]),
223
+ min (exist_dict_res ["t_box" ][1 ], dict_res ["t_box" ][1 ]),
224
+ max (exist_dict_res ["t_box" ][2 ], dict_res ["t_box" ][2 ]),
225
+ max (exist_dict_res ["t_box" ][3 ], dict_res ["t_box" ][3 ]),
226
+ ]
196
227
continue
197
228
198
229
# 去掉重叠框
@@ -217,10 +248,10 @@ def filter_logi_points(slct_logi: np.ndarray) -> list[np.ndarray]:
217
248
return slct_logi [0 ].astype (np .int32 )
218
249
219
250
def re_rec (
220
- self ,
221
- img : np .ndarray ,
222
- sorted_polygons : np .ndarray ,
223
- cell_box_map : Dict [int , List [str ]],
251
+ self ,
252
+ img : np .ndarray ,
253
+ sorted_polygons : np .ndarray ,
254
+ cell_box_map : Dict [int , List [str ]],
224
255
) -> Dict [int , List [any ]]:
225
256
"""找到poly对应为空的框,尝试将直接将poly框直接送到识别中"""
226
257
#
@@ -237,25 +268,3 @@ def re_rec(
237
268
scores = [rec [1 ] for rec in rec_res ]
238
269
cell_box_map [i ] = [[box , "" .join (text ), min (scores )]]
239
270
return cell_box_map
240
-
241
-
242
- def main ():
243
- parser = argparse .ArgumentParser ()
244
- parser .add_argument ("-img" , "--img_path" , type = str , required = True )
245
- parser .add_argument ( "--output_dir" , default = "./output" , type = str )
246
- args = parser .parse_args ()
247
- # args.img_path = '../images/image (78).png'
248
- table_rec = LinelessTableRecognition ()
249
- html , elasp , polygons , logic_points , ocr_res = table_rec (args .img_path )
250
- print (f"cost: { elasp :.5f} " )
251
- complete_html = format_html (html )
252
- os .makedirs (os .path .dirname (f'{ args .output_dir } /table.html' ), exist_ok = True )
253
- with open (f'{ args .output_dir } /table.html' , 'w' , encoding = 'utf-8' ) as file :
254
- file .write (complete_html )
255
- plot_rec_box_with_logic_info (args .img_path , f'{ args .output_dir } /table_rec_box.jpg' , logic_points , polygons )
256
- plot_rec_box (args .img_path , f'{ args .output_dir } /ocr_box.jpg' , ocr_res )
257
-
258
-
259
-
260
- if __name__ == "__main__" :
261
- main ()
0 commit comments