Skip to content

Commit dad3a9d

Browse files
committed
feat: lineless table rec support ocr result param
1 parent 1db4655 commit dad3a9d

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

lineless_table_rec/main.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import time
66
import traceback
77
from pathlib import Path
8-
from typing import Any, Dict, List, Tuple, Union
8+
from typing import Any, Dict, List, Tuple, Union, Optional
99

1010
import cv2
1111
import numpy as np
@@ -47,16 +47,25 @@ def __init__(
4747
self.det_process = DetProcess()
4848
self.ocr = RapidOCR()
4949

50-
def __call__(self, content: InputType):
50+
def __call__(
51+
self,
52+
content: InputType,
53+
ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None,
54+
):
5155
ss = time.perf_counter()
5256
img = self.load_img(content)
53-
ocr_res, _ = self.ocr(img)
57+
if self.ocr is None and ocr_result is None:
58+
raise ValueError(
59+
"One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed."
60+
)
61+
if ocr_result is None:
62+
ocr_result, _ = self.ocr(img)
5463
input_info = self.preprocess(img)
5564
try:
5665
polygons, slct_logi = self.infer(input_info)
5766
logi_points = self.filter_logi_points(slct_logi)
5867
# ocr 结果匹配
59-
cell_box_det_map, no_match_ocr_det = match_ocr_cell(ocr_res, polygons)
68+
cell_box_det_map, no_match_ocr_det = match_ocr_cell(ocr_result, polygons)
6069
# 如果有识别框没有ocr结果,直接进行rec补充
6170
cell_box_det_map = self.re_rec(img, polygons, cell_box_det_map)
6271
# 转换为中间格式,修正识别框坐标,将物理识别框,逻辑识别框,ocr识别框整合为dict,方便后续处理
@@ -92,7 +101,9 @@ def __call__(self, content: InputType):
92101
sorted_logi_points = [
93102
t_box_ocr["t_logic_box"] for t_box_ocr in t_rec_ocr_list
94103
]
95-
ocr_boxes_res = [box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_res]
104+
ocr_boxes_res = [
105+
box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_result
106+
]
96107
sorted_ocr_boxes_res, _ = sorted_ocr_boxes(ocr_boxes_res)
97108
table_elapse = time.perf_counter() - ss
98109
return (

0 commit comments

Comments
 (0)