|
1 | 1 | # -*- encoding: utf-8 -*-
|
2 | 2 | # @Author: SWHL
|
3 | 3 |
|
| 4 | +import importlib |
4 | 5 | import logging
|
5 | 6 | import time
|
6 | 7 | import traceback
|
| 8 | +from dataclasses import dataclass, asdict |
| 9 | +from enum import Enum |
7 | 10 | from pathlib import Path
|
8 |
| -from typing import Any, Dict, List, Tuple, Union, Optional |
| 11 | +from typing import Dict, List, Union, Optional |
9 | 12 |
|
10 | 13 | import cv2
|
11 | 14 | import numpy as np
|
12 |
| -from rapidocr_onnxruntime import RapidOCR |
13 | 15 |
|
14 |
| -from .process import DetProcess, get_affine_transform_upper_left |
15 |
| -from .utils import InputType, LoadImage, OrtInferSession |
16 |
| -from .utils_table_recover import ( |
| 16 | +from .table_structure_lore import TSRLore |
| 17 | +from .utils.download_model import DownloadModel |
| 18 | +from .utils.utils import InputType, LoadImage |
| 19 | +from lineless_table_rec.utils.utils_table_recover import ( |
17 | 20 | box_4_2_poly_to_box_4_1,
|
18 | 21 | filter_duplicated_box,
|
19 | 22 | gather_ocr_list_by_row,
|
|
23 | 26 | sorted_ocr_boxes,
|
24 | 27 | )
|
25 | 28 |
|
26 |
| -cur_dir = Path(__file__).resolve().parent |
27 |
| -detect_model_path = cur_dir / "models" / "lore_detect.onnx" |
28 |
| -process_model_path = cur_dir / "models" / "lore_process.onnx" |
29 | 29 |
|
class ModelType(Enum):
    """Names of the table-structure models this module can be configured with.

    Each member's ``value`` is the plain string used as the registry key when
    looking up model download URLs.
    """

    LORE = "lore"
30 | 32 |
|
31 |
| -class LinelessTableRecognition: |
32 |
| - def __init__( |
33 |
| - self, |
34 |
| - detect_model_path: Union[str, Path] = detect_model_path, |
35 |
| - process_model_path: Union[str, Path] = process_model_path, |
36 |
| - ): |
37 |
| - self.mean = np.array([0.408, 0.447, 0.470], dtype=np.float32).reshape(1, 1, 3) |
38 |
| - self.std = np.array([0.289, 0.274, 0.278], dtype=np.float32).reshape(1, 1, 3) |
39 | 33 |
|
40 |
| - self.inp_h = 768 |
41 |
| - self.inp_w = 768 |
# Base location of the hosted model files. It already ends with "/", so the
# paths appended below must not start with another slash.
ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/"

# Registry of downloadable models: model-type string -> {model role: URL}.
# The previous f-strings added a second "/" after ROOT_URL, producing
# ".../master//lore/detect.onnx"; the duplicated separator is removed here.
KEY_TO_MODEL_URL = {
    ModelType.LORE.value: {
        "lore_detect": f"{ROOT_URL}lore/detect.onnx",
        "lore_process": f"{ROOT_URL}lore/process.onnx",
    },
}
| 41 | + |
| 42 | + |
@dataclass
class RapidTableInput:
    """Configuration handed to ``LinelessTableRecognition``.

    The recognizer converts this object to a plain dict with ``asdict`` and
    forwards it to the table-structure backend.
    """

    # Must be one of the keys of KEY_TO_MODEL_URL; the recognizer raises
    # ValueError otherwise.
    model_type: Optional[str] = ModelType.LORE.value
    # Explicit model location(s). When left as None the recognizer resolves
    # the location(s) from KEY_TO_MODEL_URL via its download helper.
    model_path: Optional[Union[str, Path, Dict[str, str]]] = None
    # NOTE(review): use_cuda/device are only forwarded to the backend here;
    # their exact effect is defined there - confirm before relying on them.
    use_cuda: bool = False
    device: str = "cpu"
| 49 | + |
42 | 50 |
|
43 |
| - self.det_session = OrtInferSession(detect_model_path) |
44 |
| - self.process_session = OrtInferSession(process_model_path) |
@dataclass
class RapidTableOutput:
    """Result returned by ``LinelessTableRecognition.__call__``.

    Field order matters: the recognizer builds instances positionally as
    ``(pred_html, cell_bboxes, logic_points, elapse)``.
    """

    # HTML rendering of the recognized table; the recognizer passes "" when
    # recognition fails or when OCR is skipped.
    pred_html: Optional[str] = None
    # Detected cell polygons (reshaped to (-1, 8) on the normal path).
    cell_bboxes: Optional[np.ndarray] = None
    # Logical coordinates predicted for each cell.
    logic_points: Optional[np.ndarray] = None
    # Elapsed time measured with time.perf_counter(); 0.0 on failure.
    elapse: Optional[float] = None
45 | 57 |
|
| 58 | + |
| 59 | +class LinelessTableRecognition: |
def __init__(self, config: RapidTableInput):
    """Validate the configuration and build the recognition components.

    Args:
        config: Recognizer settings. ``config.model_path`` is overwritten in
            place with the resolved model location(s) before the settings
            are converted to a dict for the structure model.

    Raises:
        ValueError: If ``config.model_type`` is not a key of
            ``KEY_TO_MODEL_URL``.
    """
    self.model_type = config.model_type
    if self.model_type not in KEY_TO_MODEL_URL:
        model_list = ",".join(KEY_TO_MODEL_URL)
        message = f"{self.model_type} is not supported. The currently supported models are {model_list}."
        raise ValueError(message)

    # Resolve (and, if necessary, download) the model files first so the
    # structure model receives concrete locations.
    resolved_path = self.get_model_path(config.model_type, config.model_path)
    config.model_path = resolved_path

    self.table_structure = TSRLore(asdict(config))
    self.load_img = LoadImage()

    # OCR is an optional dependency: without the package the recognizer is
    # still constructed, with ``self.ocr`` left as None.
    try:
        ocr_module = importlib.import_module("rapidocr_onnxruntime")
        self.ocr = ocr_module.RapidOCR()
    except ModuleNotFoundError:
        self.ocr = None
49 | 75 |
|
50 | 76 | def __call__(
|
51 | 77 | self,
|
52 | 78 | content: InputType,
|
53 | 79 | ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None,
|
54 |
| - **kwargs |
55 |
| - ): |
56 |
| - ss = time.perf_counter() |
| 80 | + **kwargs, |
| 81 | + ) -> RapidTableOutput: |
| 82 | + s = time.perf_counter() |
57 | 83 | rec_again = True
|
58 | 84 | need_ocr = True
|
59 | 85 | if kwargs:
|
60 | 86 | rec_again = kwargs.get("rec_again", True)
|
61 |
| - need_ocr = kwargs.get("need_ocr", True) |
62 | 87 | img = self.load_img(content)
|
63 |
| - input_info = self.preprocess(img) |
64 | 88 | try:
|
65 |
| - polygons, slct_logi = self.infer(input_info) |
66 |
| - logi_points = self.filter_logi_points(slct_logi) |
| 89 | + polygons, logi_points = self.table_structure(img) |
67 | 90 | if not need_ocr:
|
68 | 91 | sorted_polygons, idx_list = sorted_ocr_boxes(
|
69 | 92 | [box_4_2_poly_to_box_4_1(box) for box in polygons]
|
70 | 93 | )
|
71 |
| - return ( |
| 94 | + return RapidTableOutput( |
72 | 95 | "",
|
73 |
| - time.perf_counter() - ss, |
74 | 96 | sorted_polygons,
|
75 | 97 | logi_points[idx_list],
|
76 |
| - [], |
| 98 | + time.perf_counter() - s, |
77 | 99 | )
|
78 | 100 |
|
79 | 101 | if ocr_result is None and need_ocr:
|
@@ -103,32 +125,19 @@ def __call__(
|
103 | 125 | i: [ocr_box_and_text[1] for ocr_box_and_text in t_box_ocr["t_ocr_res"]]
|
104 | 126 | for i, t_box_ocr in enumerate(t_rec_ocr_list)
|
105 | 127 | }
|
106 |
| - table_str = plot_html_table(logi_points, cell_box_det_map) |
| 128 | + pred_html = plot_html_table(logi_points, cell_box_det_map) |
107 | 129 |
|
108 | 130 | # 输出可视化排序,用于验证结果,生产版本可以去掉
|
109 | 131 | _, idx_list = sorted_ocr_boxes(
|
110 | 132 | [t_box_ocr["t_box"] for t_box_ocr in t_rec_ocr_list]
|
111 | 133 | )
|
112 |
| - t_rec_ocr_list = [t_rec_ocr_list[i] for i in idx_list] |
113 |
| - sorted_polygons = [t_box_ocr["t_box"] for t_box_ocr in t_rec_ocr_list] |
114 |
| - sorted_logi_points = [ |
115 |
| - t_box_ocr["t_logic_box"] for t_box_ocr in t_rec_ocr_list |
116 |
| - ] |
117 |
| - ocr_boxes_res = [ |
118 |
| - box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_result |
119 |
| - ] |
120 |
| - sorted_ocr_boxes_res, _ = sorted_ocr_boxes(ocr_boxes_res) |
121 |
| - table_elapse = time.perf_counter() - ss |
122 |
| - return ( |
123 |
| - table_str, |
124 |
| - table_elapse, |
125 |
| - sorted_polygons, |
126 |
| - sorted_logi_points, |
127 |
| - sorted_ocr_boxes_res, |
128 |
| - ) |
| 134 | + polygons = polygons.reshape(-1, 8) |
| 135 | + logi_points = np.array(logi_points) |
| 136 | + elapse = time.perf_counter() - s |
129 | 137 | except Exception:
|
130 | 138 | logging.warning(traceback.format_exc())
|
131 |
| - return "", 0.0, None, None, None |
| 139 | + return RapidTableOutput("", None, None, 0.0) |
| 140 | + return RapidTableOutput(pred_html, polygons, logi_points, elapse) |
132 | 141 |
|
133 | 142 | def transform_res(
|
134 | 143 | self,
|
@@ -159,48 +168,27 @@ def transform_res(
|
159 | 168 | res.append(dict_res)
|
160 | 169 | return res
|
161 | 170 |
|
162 |
| - def preprocess(self, img: np.ndarray) -> Dict[str, Any]: |
163 |
| - height, width = img.shape[:2] |
164 |
| - resized_image = cv2.resize(img, (width, height)) |
165 |
| - |
166 |
| - c = np.array([0, 0], dtype=np.float32) |
167 |
| - s = max(height, width) * 1.0 |
168 |
| - trans_input = get_affine_transform_upper_left(c, s, [self.inp_w, self.inp_h]) |
169 |
| - |
170 |
| - inp_image = cv2.warpAffine( |
171 |
| - resized_image, trans_input, (self.inp_w, self.inp_h), flags=cv2.INTER_LINEAR |
172 |
| - ) |
173 |
| - inp_image = ((inp_image / 255.0 - self.mean) / self.std).astype(np.float32) |
174 |
| - |
175 |
| - images = inp_image.transpose(2, 0, 1).reshape(1, 3, self.inp_h, self.inp_w) |
176 |
| - meta = { |
177 |
| - "c": c, |
178 |
| - "s": s, |
179 |
| - "out_height": self.inp_h // 4, |
180 |
| - "out_width": self.inp_w // 4, |
181 |
| - } |
182 |
| - return {"img": images, "meta": meta} |
@staticmethod
def get_model_path(
    model_type: str, model_path: Union[str, Path, None]
) -> Union[str, Dict[str, str]]:
    """Return the model location(s) to use for ``model_type``.

    Args:
        model_type: Key looked up in ``KEY_TO_MODEL_URL`` when no explicit
            path is given.
        model_path: Caller-supplied location; returned unchanged when it is
            not None.

    Returns:
        ``model_path`` itself if provided. Otherwise the result of
        ``DownloadModel.download`` for a single URL, or a dict mapping each
        registry key to its download result when the registry entry is a
        dict (presumably local file paths - confirm against DownloadModel).

    Raises:
        ValueError: If the registry entry is neither a str nor a dict,
            which includes the case where ``model_type`` has no entry.
    """
    if model_path is not None:
        return model_path

    model_url = KEY_TO_MODEL_URL.get(model_type)

    if isinstance(model_url, dict):
        # One download per entry; the saved file name is prefixed with the
        # model type.
        return {
            key: DownloadModel.download(
                url, save_model_name=f"{model_type}_{Path(url).name}"
            )
            for key, url in model_url.items()
        }

    if isinstance(model_url, str):
        return DownloadModel.download(model_url)

    raise ValueError(f"Model URL: {type(model_url)} is not between str and dict.")
204 | 192 |
|
205 | 193 | def sort_and_gather_ocr_res(self, res):
|
206 | 194 | for i, dict_res in enumerate(res):
|
@@ -254,23 +242,6 @@ def handle_overlap_row_col(self, res):
|
254 | 242 | res = [res[i] for i in range(len(res)) if i not in deleted_idx]
|
255 | 243 | return res, grid
|
256 | 244 |
|
257 |
| - @staticmethod |
258 |
| - def filter_logi_points(slct_logi: np.ndarray) -> List[np.ndarray]: |
259 |
| - for logic_points in slct_logi[0]: |
260 |
| - # 修正坐标接近导致的r_e > r_s 或 c_e > c_s |
261 |
| - if abs(logic_points[0] - logic_points[1]) < 0.2: |
262 |
| - row = (logic_points[0] + logic_points[1]) / 2 |
263 |
| - logic_points[0] = row |
264 |
| - logic_points[1] = row |
265 |
| - if abs(logic_points[2] - logic_points[3]) < 0.2: |
266 |
| - col = (logic_points[2] + logic_points[3]) / 2 |
267 |
| - logic_points[2] = col |
268 |
| - logic_points[3] = col |
269 |
| - logi_floor = np.floor(slct_logi) |
270 |
| - dev = slct_logi - logi_floor |
271 |
| - slct_logi = np.where(dev > 0.5, logi_floor + 1, logi_floor) |
272 |
| - return slct_logi[0].astype(np.int32) |
273 |
| - |
274 | 245 | def re_rec(
|
275 | 246 | self,
|
276 | 247 | img: np.ndarray,
|
|
0 commit comments