3
3
4
4
import argparse
5
5
import logging
6
+ import os
6
7
import time
7
8
import traceback
8
9
from pathlib import Path
12
13
import numpy as np
13
14
from rapidocr_onnxruntime import RapidOCR
14
15
15
- from . lineless_table_process import DetProcess , get_affine_transform_upper_left
16
- from . utils import InputType , LoadImage , OrtInferSession
17
- from . utils_table_recover import (
16
+ from lineless_table_process import DetProcess , get_affine_transform_upper_left
17
+ from utils import InputType , LoadImage , OrtInferSession
18
+ from utils_table_recover import (
18
19
get_rotate_crop_image ,
19
- match_ocr_cell ,
20
20
plot_html_table ,
21
- sorted_boxes ,
21
+ sorted_ocr_boxes , box_4_2_poly_to_box_4_1 , match_ocr_cell ,
22
+ filter_duplicated_box , gather_ocr_list_by_row , plot_rec_box_with_logic_info , plot_rec_box , format_html ,
22
23
)
23
24
24
25
cur_dir = Path (__file__ ).resolve ().parent
28
29
29
30
class LinelessTableRecognition :
30
31
def __init__ (
31
- self ,
32
- detect_model_path : Union [str , Path ] = detect_model_path ,
33
- process_model_path : Union [str , Path ] = process_model_path ,
32
+ self ,
33
+ detect_model_path : Union [str , Path ] = detect_model_path ,
34
+ process_model_path : Union [str , Path ] = process_model_path ,
34
35
):
35
36
self .mean = np .array ([0.408 , 0.447 , 0.470 ], dtype = np .float32 ).reshape (1 , 1 , 3 )
36
37
self .std = np .array ([0.289 , 0.274 , 0.278 ], dtype = np .float32 ).reshape (1 , 1 , 3 )
@@ -45,32 +46,70 @@ def __init__(
45
46
self .det_process = DetProcess ()
46
47
self .ocr = RapidOCR ()
47
48
48
def __call__(self, content: InputType):
    """Recognise the table in ``content`` and render it as an HTML string.

    Steps: load the image, run OCR, infer the cell polygons and logical
    coordinates, attach OCR results to cells, re-recognise cells left
    without text, normalise everything into per-cell dicts, drop
    duplicated boxes, merge cells sharing a logical position, order the
    text inside each cell and finally build the HTML.

    Returns:
        ``(table_html, elapsed_seconds, sorted_polygons,
        sorted_logi_points, sorted_ocr_boxes)``. If anything raises inside
        the structure/recovery stage the traceback is logged as a warning
        and ``("", 0.0, None, None, None)`` is returned instead.
    """
    start = time.perf_counter()
    img = self.load_img(content)
    ocr_res, _ = self.ocr(img)
    input_info = self.preprocess(img)
    try:
        polygons, slct_logi = self.infer(input_info)
        logi_points = self.filter_logi_points(slct_logi)

        # Match OCR results to the detected cells.
        cell_box_det_map, _unmatched_ocr = match_ocr_cell(ocr_res, polygons)
        # Cells that received no OCR result are sent to recognition directly.
        cell_box_det_map = self.re_rec(img, polygons, cell_box_det_map)
        # Convert to the intermediate format: one dict per cell holding the
        # corrected physical box, the logical box and the OCR entries.
        cells = self.transform_res(cell_box_det_map, polygons, logi_points)

        # Drop boxes reported as contained/overlapping duplicates.
        deleted_idx_set = filter_duplicated_box([cell['t_box'] for cell in cells])
        cells = [
            cell for idx, cell in enumerate(cells) if idx not in deleted_idx_set
        ]

        # Build the row/column grid and merge the OCR entries of cells that
        # share the same logical position.
        cells, _grid = self.handle_overlap_row_col(cells)
        # TODO: use the grid and the unmatched OCR boxes to try filling OCR
        # results into single-row / single-column cells.

        # Sort the OCR entries inside every cell and merge those on one row.
        cells = self.sort_and_gather_ocr_res(cells)

        # Render as HTML.
        logi_points = [cell['t_logic_box'] for cell in cells]
        cell_box_det_map = {}
        for idx, cell in enumerate(cells):
            cell_box_det_map[idx] = [entry[1] for entry in cell['t_ocr_res']]
        table_str = plot_html_table(logi_points, cell_box_det_map)

        # Sorted outputs for visualisation; used to verify the result and
        # can be removed in a production build.
        _, idx_list = sorted_ocr_boxes([cell['t_box'] for cell in cells])
        cells = [cells[idx] for idx in idx_list]
        sorted_polygons = [cell['t_box'] for cell in cells]
        sorted_logi_points = [cell['t_logic_box'] for cell in cells]
        ocr_boxes_res = [box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_res]
        sorted_ocr_boxes_res, _ = sorted_ocr_boxes(ocr_boxes_res)

        table_elapse = time.perf_counter() - start
        return (
            table_str,
            table_elapse,
            sorted_polygons,
            sorted_logi_points,
            sorted_ocr_boxes_res,
        )
    except Exception:
        logging.warning(traceback.format_exc())
        return "", 0.0, None, None, None
91
+
92
def transform_res(
    self,
    cell_box_det_map: Dict[int, List[Any]],
    polygons: np.ndarray,
    logi_points: List[np.ndarray],
) -> List[Dict[str, Any]]:
    """Merge physical boxes, logical boxes and OCR results into per-cell dicts.

    Args:
        cell_box_det_map: maps a polygon index to the OCR entries assigned
            to that cell; each entry is indexed as ``entry[0]`` (polygon)
            and ``entry[1]`` (text).
        polygons: detected cell polygons; only their count is used here.
        logi_points: logical coordinates, indexed like ``polygons``.

    Returns:
        One dict per cell that has at least one OCR entry, with keys
        ``'t_box'``, ``'t_logic_box'`` and ``'t_ocr_res'``. Cells without
        OCR entries are skipped.
    """
    res = []
    for i in range(len(polygons)):
        ocr_res_list = cell_box_det_map.get(i)
        if not ocr_res_list:
            continue
        # The cell box is re-derived from its OCR boxes: point 0 of each
        # OCR polygon supplies the minima, point 2 the maxima
        # (presumably top-left / bottom-right -- confirm against the
        # OCR polygon point order).
        first_points = [ocr_det[0][0] for ocr_det in ocr_res_list]
        third_points = [ocr_det[0][2] for ocr_det in ocr_res_list]
        xmin = min(point[0] for point in first_points)
        ymin = min(point[1] for point in first_points)
        xmax = max(point[0] for point in third_points)
        ymax = max(point[1] for point in third_points)
        dict_res = {
            # xmin, ymin, xmax, ymax
            't_box': [xmin, ymin, xmax, ymax],
            # row_start, row_end, col_start, col_end
            't_logic_box': logi_points[i].tolist(),
            # [[converted OCR box, text], ...]
            't_ocr_res': [
                [box_4_2_poly_to_box_4_1(ocr_det[0]), ocr_det[1]]
                for ocr_det in ocr_res_list
            ],
        }
        res.append(dict_res)
    return res
74
113
75
114
def preprocess (self , img : np .ndarray ) -> Dict [str , Any ]:
76
115
height , width = img .shape [:2 ]
@@ -115,52 +154,107 @@ def infer(self, input_content: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]:
115
154
)
116
155
return slct_output_dets , slct_logi
117
156
118
- def filter_logi_points (self , slct_logi : np .ndarray ) -> Dict [str , Any ]:
157
def sort_and_gather_ocr_res(self, res):
    """Regroup and reorder the OCR entries of every cell, in place.

    For each cell dict the ``'t_ocr_res'`` list is first passed through
    ``gather_ocr_list_by_row`` and then reordered by the index order that
    ``sorted_ocr_boxes`` returns for the entries' boxes.

    Returns:
        The same ``res`` list, with each dict's ``'t_ocr_res'`` replaced.
    """
    for cell in res:
        cell['t_ocr_res'] = gathered = gather_ocr_list_by_row(cell['t_ocr_res'])
        boxes = [entry[0] for entry in gathered]
        _, order = sorted_ocr_boxes(boxes)
        cell['t_ocr_res'] = [gathered[pos] for pos in order]
    return res
163
+
164
def handle_overlap_row_col(self, res):
    """Lay the cells out on a row/column grid and merge exact duplicates.

    Each cell dict occupies every grid slot covered by its
    ``'t_logic_box'`` (``[row_start, row_end, col_start, col_end]``, both
    ends inclusive). A slot keeps the first cell that claimed it. When a
    later cell meets an occupant with an identical logical box, its OCR
    entries are appended to the occupant, the occupant's ``'t_box'`` is
    grown to the union of both boxes, and the later cell is dropped.

    Args:
        res: list of cell dicts with keys ``'t_box'``, ``'t_logic_box'``
            and ``'t_ocr_res'``. Surviving dicts are modified in place.

    Returns:
        ``(cells, grid)`` where ``cells`` is ``res`` without the merged
        duplicates and ``grid[row][col]`` is the owning cell dict or
        ``None``.
    """
    max_row, max_col = 0, 0
    for cell in res:
        # +1 because the end indices are inclusive.
        max_row = max(max_row, cell['t_logic_box'][1] + 1)
        max_col = max(max_col, cell['t_logic_box'][3] + 1)

    grid = [[None] * max_col for _ in range(max_row)]
    merged_idx = set()
    for idx, cell in enumerate(res):
        row_start, row_end, col_start, col_end = cell['t_logic_box']
        slots = (
            (row, col)
            for row in range(row_start, row_end + 1)
            for col in range(col_start, col_end + 1)
        )
        for row, col in slots:
            occupant = grid[row][col]
            if occupant is None:
                grid[row][col] = cell
                continue
            if occupant['t_logic_box'] != cell['t_logic_box']:
                # Partial overlap with a different cell: the slot stays
                # with whoever claimed it first.
                continue
            # Same logical position: fold this cell into the occupant.
            occupant['t_ocr_res'].extend(cell['t_ocr_res'])
            occupant['t_box'] = [
                min(occupant['t_box'][0], cell['t_box'][0]),
                min(occupant['t_box'][1], cell['t_box'][1]),
                max(occupant['t_box'][2], cell['t_box'][2]),
                max(occupant['t_box'][3], cell['t_box'][3]),
            ]
            merged_idx.add(idx)
            # Nothing left to do for a cell that has been merged away.
            break

    # Drop the cells that were folded into another one.
    res = [cell for idx, cell in enumerate(res) if idx not in merged_idx]
    return res, grid
201
+
202
@staticmethod
def filter_logi_points(slct_logi: np.ndarray) -> np.ndarray:
    """Round predicted logical coordinates to integer row/column indices.

    Only ``slct_logi[0]`` is used; each of its rows is read as
    ``[row_start, row_end, col_start, col_end]``. When a start/end pair
    differs by less than 0.2 both values are first replaced by their mean,
    so that two nearly equal predictions cannot round to different
    indices. Values are then rounded up when the fractional part is
    greater than 0.5 and down otherwise (exactly .5 rounds down).

    The input array is left untouched.

    Returns:
        ``np.int32`` array with the same shape as ``slct_logi[0]``.
    """
    # np.array copies, so the caller's buffer is never modified.
    logi = np.array(slct_logi[0])
    if logi.size == 0:
        return logi.astype(np.int32)

    for start, end in ((0, 1), (2, 3)):
        close = np.abs(logi[:, start] - logi[:, end]) < 0.2
        mean = (logi[close, start] + logi[close, end]) / 2
        logi[close, start] = mean
        logi[close, end] = mean

    logi_floor = np.floor(logi)
    dev = logi - logi_floor
    rounded = np.where(dev > 0.5, logi_floor + 1, logi_floor)
    return rounded.astype(np.int32)
134
218
135
219
def re_rec(
    self,
    img: np.ndarray,
    sorted_polygons: np.ndarray,
    cell_box_map: Dict[int, List[str]],
) -> Dict[int, List[Any]]:
    """Run recognition directly on cells that have no OCR result yet.

    For every polygon index whose entry in ``cell_box_map`` is missing or
    empty, the polygon is cropped from ``img``, padded with a white
    border and passed to the OCR engine with detection disabled. The
    recognised texts are concatenated and stored as a single entry
    ``[[polygon, text, min_score]]``.

    Cells for which recognition returns nothing are left without an
    entry instead of raising, so one unreadable cell does not abort the
    whole table.

    Args:
        img: source image.
        sorted_polygons: cell polygons, indexed by the keys of
            ``cell_box_map``.
        cell_box_map: existing cell-index -> OCR entries mapping; updated
            in place.

    Returns:
        The same ``cell_box_map`` object.
    """
    for i in range(sorted_polygons.shape[0]):
        if cell_box_map.get(i):
            continue
        box = sorted_polygons[i]
        crop_img = get_rotate_crop_image(img, box)
        # White padding around the crop before it is handed to the recogniser.
        pad_img = cv2.copyMakeBorder(
            crop_img, 5, 5, 100, 100, cv2.BORDER_CONSTANT, value=(255, 255, 255)
        )
        rec_res, _ = self.ocr(pad_img, use_det=False, use_cls=True, use_rec=True)
        if not rec_res:
            # NOTE(review): the engine is assumed to return None or an
            # empty list when nothing is recognised -- confirm against
            # the RapidOCR version in use.
            continue
        text = "".join(rec[0] for rec in rec_res)
        score = min(rec[1] for rec in rec_res)
        cell_box_map[i] = [[box, text, score]]
    return cell_box_map
153
240
154
241
155
242
def main():
    """Command-line entry point.

    Runs table recognition on ``--img_path``, prints the elapsed time,
    writes the formatted HTML to ``<output_dir>/table.html`` and, when
    recognition succeeded, saves two debug images with the cell boxes and
    the OCR boxes drawn on the input image.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-img", "--img_path", type=str, required=True)
    parser.add_argument("--output_dir", default="./output", type=str)
    args = parser.parse_args()

    table_rec = LinelessTableRecognition()
    html, elapse, polygons, logic_points, ocr_res = table_rec(args.img_path)
    print(f"cost: {elapse:.5f}")

    html_path = f'{args.output_dir}/table.html'
    os.makedirs(os.path.dirname(html_path), exist_ok=True)
    with open(html_path, 'w', encoding='utf-8') as file:
        file.write(format_html(html))

    # On failure the recogniser returns None for all box outputs; there is
    # nothing to draw then, so the plot helpers are not called.
    if polygons is None or logic_points is None or ocr_res is None:
        logging.warning("table recognition failed; debug images were not written")
        return

    plot_rec_box_with_logic_info(
        args.img_path, f'{args.output_dir}/table_rec_box.jpg', logic_points, polygons
    )
    plot_rec_box(args.img_path, f'{args.output_dir}/ocr_box.jpg', ocr_res)
257
+
164
258
165
259
166
260
if __name__ == "__main__" :
0 commit comments