Skip to content

Commit a16afd5

Browse files
committed
feat: trans to rapidTable code style
1 parent f492e8f commit a16afd5

26 files changed

+2501
-1280
lines changed

demo_lineless.py

+33-18
Original file line number | Diff line number | Diff line change
@@ -3,30 +3,45 @@
33
# @Contact: [email protected]
44
from pathlib import Path
55

6+
from rapidocr_onnxruntime import RapidOCR
7+
68
from lineless_table_rec import LinelessTableRecognition
7-
from lineless_table_rec.utils_table_recover import (
8-
format_html,
9-
plot_rec_box,
10-
plot_rec_box_with_logic_info,
11-
)
9+
from lineless_table_rec.main import RapidTableInput
10+
from lineless_table_rec.utils.utils import VisTable
1211

1312
output_dir = Path("outputs")
1413
output_dir.mkdir(parents=True, exist_ok=True)
14+
input_args = RapidTableInput()
15+
table_engine = LinelessTableRecognition(input_args)
16+
ocr_engine = RapidOCR()
17+
viser = VisTable()
18+
19+
if __name__ == "__main__":
20+
img_path = "tests/test_files/lineless_table_recognition.jpg"
21+
22+
ocr_result, _ = ocr_engine(img_path)
23+
boxes, txts, scores = list(zip(*ocr_result))
1524

16-
img_path = "tests/test_files/lineless_table_recognition.jpg"
17-
table_rec = LinelessTableRecognition()
25+
# Table Rec
26+
table_results = table_engine(img_path)
27+
table_html_str, table_cell_bboxes = (
28+
table_results.pred_html,
29+
table_results.cell_bboxes,
30+
)
1831

19-
html, elasp, polygons, logic_points, ocr_res = table_rec(img_path)
20-
print(f"cost: {elasp:.5f}")
32+
# Save
33+
save_dir = Path("outputs")
34+
save_dir.mkdir(parents=True, exist_ok=True)
2135

22-
complete_html = format_html(html)
36+
save_html_path = f"outputs/{Path(img_path).stem}.html"
37+
save_drawed_path = f"outputs/{Path(img_path).stem}_table_vis{Path(img_path).suffix}"
38+
save_logic_path = (
39+
f"outputs/{Path(img_path).stem}_table_vis_logic{Path(img_path).suffix}"
40+
)
2341

24-
save_table_path = output_dir / "table.html"
25-
with open(save_table_path, "w", encoding="utf-8") as file:
26-
file.write(complete_html)
42+
# Visualize table rec result
43+
vis_imged = viser(
44+
img_path, table_results, save_html_path, save_drawed_path, save_logic_path
45+
)
2746

28-
plot_rec_box_with_logic_info(
29-
img_path, f"{output_dir}/table_rec_box.jpg", logic_points, polygons
30-
)
31-
plot_rec_box(img_path, f"{output_dir}/ocr_box.jpg", ocr_res)
32-
print(f"The results has been saved under {output_dir}")
47+
print(f"The results has been saved under {output_dir}")

demo_wired.py

+37-25
Original file line number | Diff line number | Diff line change
@@ -3,32 +3,44 @@
33
# @Contact: [email protected]
44
from pathlib import Path
55

6+
from rapidocr_onnxruntime import RapidOCR
7+
68
from wired_table_rec import WiredTableRecognition
7-
from wired_table_rec.utils_table_recover import (
8-
format_html,
9-
plot_rec_box,
10-
plot_rec_box_with_logic_info,
11-
)
9+
from wired_table_rec.main import RapidTableInput, ModelType
10+
from wired_table_rec.utils.utils import VisTable
1211

1312
output_dir = Path("outputs")
1413
output_dir.mkdir(parents=True, exist_ok=True)
15-
16-
table_rec = WiredTableRecognition()
17-
18-
img_path = "tests/test_files/wired/table1.png"
19-
html, elasp, polygons, logic_points, ocr_res = table_rec(img_path)
20-
21-
print(f"cost: {elasp:.5f}")
22-
23-
complete_html = format_html(html)
24-
25-
save_table_path = output_dir / "table.html"
26-
with open(save_table_path, "w", encoding="utf-8") as file:
27-
file.write(complete_html)
28-
29-
plot_rec_box_with_logic_info(
30-
img_path, f"{output_dir}/table_rec_box.jpg", logic_points, polygons
31-
)
32-
plot_rec_box(img_path, f"{output_dir}/ocr_box.jpg", ocr_res)
33-
34-
print(f"The results has been saved under {output_dir}")
14+
input_args = RapidTableInput(model_type=ModelType.CYCLE_CENTER_NET.value)
15+
table_engine = WiredTableRecognition(input_args)
16+
ocr_engine = RapidOCR()
17+
viser = VisTable()
18+
if __name__ == "__main__":
19+
img_path = "tests/test_files/wired/bad_case_1.png"
20+
21+
ocr_result, _ = ocr_engine(img_path)
22+
boxes, txts, scores = list(zip(*ocr_result))
23+
24+
# Table Rec
25+
table_results = table_engine(img_path)
26+
table_html_str, table_cell_bboxes = (
27+
table_results.pred_html,
28+
table_results.cell_bboxes,
29+
)
30+
31+
# Save
32+
save_dir = Path("outputs")
33+
save_dir.mkdir(parents=True, exist_ok=True)
34+
35+
save_html_path = f"outputs/{Path(img_path).stem}.html"
36+
save_drawed_path = f"outputs/{Path(img_path).stem}_table_vis{Path(img_path).suffix}"
37+
save_logic_path = (
38+
f"outputs/{Path(img_path).stem}_table_vis_logic{Path(img_path).suffix}"
39+
)
40+
41+
# Visualize table rec result
42+
vis_imged = viser(
43+
img_path, table_results, save_html_path, save_drawed_path, save_logic_path
44+
)
45+
46+
print(f"The results has been saved under {output_dir}")

lineless_table_rec/main.py

+78-107
Original file line number | Diff line number | Diff line change
@@ -1,19 +1,22 @@
11
# -*- encoding: utf-8 -*-
22
# @Author: SWHL
33
# @Contact: [email protected]
4+
import importlib
45
import logging
56
import time
67
import traceback
8+
from dataclasses import dataclass, asdict
9+
from enum import Enum
710
from pathlib import Path
8-
from typing import Any, Dict, List, Tuple, Union, Optional
11+
from typing import Dict, List, Union, Optional
912

1013
import cv2
1114
import numpy as np
12-
from rapidocr_onnxruntime import RapidOCR
1315

14-
from .process import DetProcess, get_affine_transform_upper_left
15-
from .utils import InputType, LoadImage, OrtInferSession
16-
from .utils_table_recover import (
16+
from .table_structure_lore import TSRLore
17+
from .utils.download_model import DownloadModel
18+
from .utils.utils import InputType, LoadImage
19+
from lineless_table_rec.utils.utils_table_recover import (
1720
box_4_2_poly_to_box_4_1,
1821
filter_duplicated_box,
1922
gather_ocr_list_by_row,
@@ -23,57 +26,76 @@
2326
sorted_ocr_boxes,
2427
)
2528

26-
cur_dir = Path(__file__).resolve().parent
27-
detect_model_path = cur_dir / "models" / "lore_detect.onnx"
28-
process_model_path = cur_dir / "models" / "lore_process.onnx"
2929

30+
class ModelType(Enum):
31+
LORE = "lore"
3032

31-
class LinelessTableRecognition:
32-
def __init__(
33-
self,
34-
detect_model_path: Union[str, Path] = detect_model_path,
35-
process_model_path: Union[str, Path] = process_model_path,
36-
):
37-
self.mean = np.array([0.408, 0.447, 0.470], dtype=np.float32).reshape(1, 1, 3)
38-
self.std = np.array([0.289, 0.274, 0.278], dtype=np.float32).reshape(1, 1, 3)
3933

40-
self.inp_h = 768
41-
self.inp_w = 768
34+
ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/"
35+
KEY_TO_MODEL_URL = {
36+
ModelType.LORE.value: {
37+
"lore_detect": f"{ROOT_URL}/lore/detect.onnx",
38+
"lore_process": f"{ROOT_URL}/lore/process.onnx",
39+
},
40+
}
41+
42+
43+
@dataclass
44+
class RapidTableInput:
45+
model_type: Optional[str] = ModelType.LORE.value
46+
model_path: Union[str, Path, None, Dict[str, str]] = None
47+
use_cuda: bool = False
48+
device: str = "cpu"
49+
4250

43-
self.det_session = OrtInferSession(detect_model_path)
44-
self.process_session = OrtInferSession(process_model_path)
51+
@dataclass
52+
class RapidTableOutput:
53+
pred_html: Optional[str] = None
54+
cell_bboxes: Optional[np.ndarray] = None
55+
logic_points: Optional[np.ndarray] = None
56+
elapse: Optional[float] = None
4557

58+
59+
class LinelessTableRecognition:
60+
def __init__(self, config: RapidTableInput):
61+
self.model_type = config.model_type
62+
if self.model_type not in KEY_TO_MODEL_URL:
63+
model_list = ",".join(KEY_TO_MODEL_URL)
64+
raise ValueError(
65+
f"{self.model_type} is not supported. The currently supported models are {model_list}."
66+
)
67+
68+
config.model_path = self.get_model_path(config.model_type, config.model_path)
69+
self.table_structure = TSRLore(asdict(config))
4670
self.load_img = LoadImage()
47-
self.det_process = DetProcess()
48-
self.ocr = RapidOCR()
71+
try:
72+
self.ocr = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
73+
except ModuleNotFoundError:
74+
self.ocr = None
4975

5076
def __call__(
5177
self,
5278
content: InputType,
5379
ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None,
54-
**kwargs
55-
):
56-
ss = time.perf_counter()
80+
**kwargs,
81+
) -> RapidTableOutput:
82+
s = time.perf_counter()
5783
rec_again = True
5884
need_ocr = True
5985
if kwargs:
6086
rec_again = kwargs.get("rec_again", True)
61-
need_ocr = kwargs.get("need_ocr", True)
6287
img = self.load_img(content)
63-
input_info = self.preprocess(img)
6488
try:
65-
polygons, slct_logi = self.infer(input_info)
66-
logi_points = self.filter_logi_points(slct_logi)
89+
polygons, logi_points = self.table_structure(img)
6790
if not need_ocr:
6891
sorted_polygons, idx_list = sorted_ocr_boxes(
6992
[box_4_2_poly_to_box_4_1(box) for box in polygons]
7093
)
71-
return (
94+
return RapidTableOutput(
7295
"",
73-
time.perf_counter() - ss,
7496
sorted_polygons,
7597
logi_points[idx_list],
76-
[],
98+
time.perf_counter() - s,
7799
)
78100

79101
if ocr_result is None and need_ocr:
@@ -103,32 +125,19 @@ def __call__(
103125
i: [ocr_box_and_text[1] for ocr_box_and_text in t_box_ocr["t_ocr_res"]]
104126
for i, t_box_ocr in enumerate(t_rec_ocr_list)
105127
}
106-
table_str = plot_html_table(logi_points, cell_box_det_map)
128+
pred_html = plot_html_table(logi_points, cell_box_det_map)
107129

108130
# 输出可视化排序,用于验证结果,生产版本可以去掉
109131
_, idx_list = sorted_ocr_boxes(
110132
[t_box_ocr["t_box"] for t_box_ocr in t_rec_ocr_list]
111133
)
112-
t_rec_ocr_list = [t_rec_ocr_list[i] for i in idx_list]
113-
sorted_polygons = [t_box_ocr["t_box"] for t_box_ocr in t_rec_ocr_list]
114-
sorted_logi_points = [
115-
t_box_ocr["t_logic_box"] for t_box_ocr in t_rec_ocr_list
116-
]
117-
ocr_boxes_res = [
118-
box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_result
119-
]
120-
sorted_ocr_boxes_res, _ = sorted_ocr_boxes(ocr_boxes_res)
121-
table_elapse = time.perf_counter() - ss
122-
return (
123-
table_str,
124-
table_elapse,
125-
sorted_polygons,
126-
sorted_logi_points,
127-
sorted_ocr_boxes_res,
128-
)
134+
polygons = polygons.reshape(-1, 8)
135+
logi_points = np.array(logi_points)
136+
elapse = time.perf_counter() - s
129137
except Exception:
130138
logging.warning(traceback.format_exc())
131-
return "", 0.0, None, None, None
139+
return RapidTableOutput("", None, None, 0.0)
140+
return RapidTableOutput(pred_html, polygons, logi_points, elapse)
132141

133142
def transform_res(
134143
self,
@@ -159,48 +168,27 @@ def transform_res(
159168
res.append(dict_res)
160169
return res
161170

162-
def preprocess(self, img: np.ndarray) -> Dict[str, Any]:
163-
height, width = img.shape[:2]
164-
resized_image = cv2.resize(img, (width, height))
165-
166-
c = np.array([0, 0], dtype=np.float32)
167-
s = max(height, width) * 1.0
168-
trans_input = get_affine_transform_upper_left(c, s, [self.inp_w, self.inp_h])
169-
170-
inp_image = cv2.warpAffine(
171-
resized_image, trans_input, (self.inp_w, self.inp_h), flags=cv2.INTER_LINEAR
172-
)
173-
inp_image = ((inp_image / 255.0 - self.mean) / self.std).astype(np.float32)
174-
175-
images = inp_image.transpose(2, 0, 1).reshape(1, 3, self.inp_h, self.inp_w)
176-
meta = {
177-
"c": c,
178-
"s": s,
179-
"out_height": self.inp_h // 4,
180-
"out_width": self.inp_w // 4,
181-
}
182-
return {"img": images, "meta": meta}
171+
@staticmethod
172+
def get_model_path(
173+
model_type: str, model_path: Union[str, Path, None]
174+
) -> Union[str, Dict[str, str]]:
175+
if model_path is not None:
176+
return model_path
183177

184-
def infer(self, input_content: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]:
185-
hm, st, wh, ax, cr, reg = self.det_session([input_content["img"]])
186-
output = {
187-
"hm": hm,
188-
"st": st,
189-
"wh": wh,
190-
"ax": ax,
191-
"cr": cr,
192-
"reg": reg,
193-
}
194-
slct_logi_feat, slct_dets_feat, slct_output_dets = self.det_process(
195-
output, input_content["meta"]
196-
)
178+
model_url = KEY_TO_MODEL_URL.get(model_type, None)
179+
if isinstance(model_url, str):
180+
model_path = DownloadModel.download(model_url)
181+
return model_path
197182

198-
slct_output_dets = slct_output_dets.reshape(-1, 4, 2)
183+
if isinstance(model_url, dict):
184+
model_paths = {}
185+
for k, url in model_url.items():
186+
model_paths[k] = DownloadModel.download(
187+
url, save_model_name=f"{model_type}_{Path(url).name}"
188+
)
189+
return model_paths
199190

200-
_, slct_logi = self.process_session(
201-
[slct_logi_feat, slct_dets_feat.astype(np.int64)]
202-
)
203-
return slct_output_dets, slct_logi
191+
raise ValueError(f"Model URL: {type(model_url)} is not between str and dict.")
204192

205193
def sort_and_gather_ocr_res(self, res):
206194
for i, dict_res in enumerate(res):
@@ -254,23 +242,6 @@ def handle_overlap_row_col(self, res):
254242
res = [res[i] for i in range(len(res)) if i not in deleted_idx]
255243
return res, grid
256244

257-
@staticmethod
258-
def filter_logi_points(slct_logi: np.ndarray) -> List[np.ndarray]:
259-
for logic_points in slct_logi[0]:
260-
# 修正坐标接近导致的r_e > r_s 或 c_e > c_s
261-
if abs(logic_points[0] - logic_points[1]) < 0.2:
262-
row = (logic_points[0] + logic_points[1]) / 2
263-
logic_points[0] = row
264-
logic_points[1] = row
265-
if abs(logic_points[2] - logic_points[3]) < 0.2:
266-
col = (logic_points[2] + logic_points[3]) / 2
267-
logic_points[2] = col
268-
logic_points[3] = col
269-
logi_floor = np.floor(slct_logi)
270-
dev = slct_logi - logi_floor
271-
slct_logi = np.where(dev > 0.5, logi_floor + 1, logi_floor)
272-
return slct_logi[0].astype(np.int32)
273-
274245
def re_rec(
275246
self,
276247
img: np.ndarray,

0 commit comments

Comments
 (0)