|
5 | 5 | import time
|
6 | 6 | import traceback
|
7 | 7 | from pathlib import Path
|
8 |
| -from typing import Any, Dict, List, Tuple, Union |
| 8 | +from typing import Any, Dict, List, Tuple, Union, Optional |
9 | 9 |
|
10 | 10 | import cv2
|
11 | 11 | import numpy as np
|
@@ -47,16 +47,25 @@ def __init__(
|
47 | 47 | self.det_process = DetProcess()
|
48 | 48 | self.ocr = RapidOCR()
|
49 | 49 |
|
50 |
| - def __call__(self, content: InputType): |
| 50 | + def __call__( |
| 51 | + self, |
| 52 | + content: InputType, |
| 53 | + ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None, |
| 54 | + ): |
51 | 55 | ss = time.perf_counter()
|
52 | 56 | img = self.load_img(content)
|
53 |
| - ocr_res, _ = self.ocr(img) |
| 57 | + if self.ocr is None and ocr_result is None: |
| 58 | + raise ValueError( |
| 59 | + "One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed." |
| 60 | + ) |
| 61 | + if ocr_result is None: |
| 62 | + ocr_result, _ = self.ocr(img) |
54 | 63 | input_info = self.preprocess(img)
|
55 | 64 | try:
|
56 | 65 | polygons, slct_logi = self.infer(input_info)
|
57 | 66 | logi_points = self.filter_logi_points(slct_logi)
|
58 | 67 | # ocr 结果匹配
|
59 |
| - cell_box_det_map, no_match_ocr_det = match_ocr_cell(ocr_res, polygons) |
| 68 | + cell_box_det_map, no_match_ocr_det = match_ocr_cell(ocr_result, polygons) |
60 | 69 | # 如果有识别框没有ocr结果,直接进行rec补充
|
61 | 70 | cell_box_det_map = self.re_rec(img, polygons, cell_box_det_map)
|
62 | 71 | # 转换为中间格式,修正识别框坐标,将物理识别框,逻辑识别框,ocr识别框整合为dict,方便后续处理
|
@@ -92,7 +101,9 @@ def __call__(self, content: InputType):
|
92 | 101 | sorted_logi_points = [
|
93 | 102 | t_box_ocr["t_logic_box"] for t_box_ocr in t_rec_ocr_list
|
94 | 103 | ]
|
95 |
| - ocr_boxes_res = [box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_res] |
| 104 | + ocr_boxes_res = [ |
| 105 | + box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_result |
| 106 | + ] |
96 | 107 | sorted_ocr_boxes_res, _ = sorted_ocr_boxes(ocr_boxes_res)
|
97 | 108 | table_elapse = time.perf_counter() - ss
|
98 | 109 | return (
|
|
0 commit comments