Skip to content

Commit 6cda006

Browse files
committed
feat: add paddle cls for table cls
1 parent a16afd5 commit 6cda006

File tree

7 files changed

+184
-15
lines changed

7 files changed

+184
-15
lines changed

demo_table_cls.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# -*- encoding: utf-8 -*-
22
from table_cls import TableCls
33

4-
table_cls = TableCls()
5-
img_path = "tests/test_files/table_cls/lineless_table.png"
6-
cls_str, elapse = table_cls(img_path)
7-
print(cls_str)
8-
print(elapse)
4+
if __name__ == "__main__":
5+
table_cls = TableCls(model_type="yolox")
6+
img_path = "tests/test_files/table_cls/lineless_table_2.png"
7+
cls_str, elapse = table_cls(img_path)
8+
print(cls_str)
9+
print(elapse)

lineless_table_rec/main.py

+1
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def __call__(
8484
need_ocr = True
8585
if kwargs:
8686
rec_again = kwargs.get("rec_again", True)
87+
need_ocr = kwargs.get("need_ocr", True)
8788
img = self.load_img(content)
8889
try:
8990
polygons, logi_points = self.table_structure(img)

table_cls/main.py

+89-10
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,42 @@
11
import time
2+
from enum import Enum
23
from pathlib import Path
4+
from typing import Union, Dict
35

46
import cv2
57
import numpy as np
68
from PIL import Image
79

8-
from .utils import InputType, LoadImage, OrtInferSession, resize_and_center_crop
10+
from .utils.download_model import DownloadModel
11+
from .utils.utils import InputType, LoadImage, OrtInferSession, resize_and_center_crop
912

10-
cur_dir = Path(__file__).resolve().parent
11-
q_cls_model_path = cur_dir / "models" / "table_cls.onnx"
12-
yolo_cls_model_path = cur_dir / "models" / "yolo_cls.onnx"
13-
yolo_cls_x_model_path = cur_dir / "models" / "yolo_cls_x.onnx"
13+
14+
class ModelType(Enum):
15+
YOLO_CLS_X = "yolox"
16+
YOLO_CLS = "yolo"
17+
PADDLE_CLS = "paddle"
18+
Q_CLS = "q"
19+
20+
21+
ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/"
22+
KEY_TO_MODEL_URL = {
23+
ModelType.YOLO_CLS_X.value: f"{ROOT_URL}/table_cls/yolo_cls_x.onnx",
24+
ModelType.YOLO_CLS.value: f"{ROOT_URL}/table_cls/yolo_cls.onnx",
25+
ModelType.PADDLE_CLS.value: f"{ROOT_URL}/table_cls/paddle_cls.onnx",
26+
ModelType.Q_CLS.value: f"{ROOT_URL}/table_cls/q_cls.onnx",
27+
}
1428

1529

1630
class TableCls:
17-
def __init__(self, model_type="yolo", model_path=yolo_cls_model_path):
18-
if model_type == "yolo":
31+
def __init__(self, model_type=ModelType.YOLO_CLS.value, model_path=None):
32+
model_path = self.get_model_path(model_type, model_path)
33+
if model_type == ModelType.YOLO_CLS.value:
34+
self.table_engine = YoloCls(model_path)
35+
elif model_type == ModelType.YOLO_CLS_X.value:
1936
self.table_engine = YoloCls(model_path)
20-
elif model_type == "yolox":
21-
self.table_engine = YoloCls(yolo_cls_x_model_path)
37+
elif model_type == ModelType.PADDLE_CLS.value:
38+
self.table_engine = PaddleCls(model_path)
2239
else:
23-
model_path = q_cls_model_path
2440
self.table_engine = QanythingCls(model_path)
2541
self.load_img = LoadImage()
2642

@@ -32,6 +48,69 @@ def __call__(self, content: InputType):
3248
table_elapse = time.perf_counter() - ss
3349
return predict_cla, table_elapse
3450

51+
@staticmethod
52+
def get_model_path(
53+
model_type: str, model_path: Union[str, Path, None]
54+
) -> Union[str, Dict[str, str]]:
55+
if model_path is not None:
56+
return model_path
57+
58+
model_url = KEY_TO_MODEL_URL.get(model_type, None)
59+
if isinstance(model_url, str):
60+
model_path = DownloadModel.download(model_url)
61+
return model_path
62+
63+
if isinstance(model_url, dict):
64+
model_paths = {}
65+
for k, url in model_url.items():
66+
model_paths[k] = DownloadModel.download(
67+
url, save_model_name=f"{model_type}_{Path(url).name}"
68+
)
69+
return model_paths
70+
71+
raise ValueError(f"Model URL: {type(model_url)} is not between str and dict.")
72+
73+
74+
class PaddleCls:
75+
def __init__(self, model_path):
76+
self.table_cls = OrtInferSession(model_path)
77+
self.inp_h = 224
78+
self.inp_w = 224
79+
self.resize_short = 256
80+
self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
81+
self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
82+
self.cls = {0: "wired", 1: "wireless"}
83+
84+
def preprocess(self, img):
85+
# short resize
86+
img_h, img_w = img.shape[:2]
87+
percent = float(self.resize_short) / min(img_w, img_h)
88+
w = int(round(img_w * percent))
89+
h = int(round(img_h * percent))
90+
img = cv2.resize(img, dsize=(w, h), interpolation=cv2.INTER_LANCZOS4)
91+
# center crop
92+
img_h, img_w = img.shape[:2]
93+
w_start = (img_w - self.inp_w) // 2
94+
h_start = (img_h - self.inp_h) // 2
95+
w_end = w_start + self.inp_w
96+
h_end = h_start + self.inp_h
97+
img = img[h_start:h_end, w_start:w_end, :]
98+
# normalize
99+
img = np.array(img, dtype=np.float32) / 255.0
100+
img -= self.mean
101+
img /= self.std
102+
# HWC to CHW
103+
img = img.transpose(2, 0, 1)
104+
# Add batch dimension, only one image
105+
img = np.expand_dims(img, axis=0)
106+
return img
107+
108+
def __call__(self, img):
109+
pred_output = self.table_cls(img)[0]
110+
pred_idxs = list(np.argmax(pred_output, axis=1))
111+
predict_cla = max(set(pred_idxs), key=pred_idxs.count)
112+
return self.cls[predict_cla]
113+
35114

36115
class QanythingCls:
37116
def __init__(self, model_path):

table_cls/utils/__init__.py

Whitespace-only changes.

table_cls/utils/download_model.py

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import io
2+
from pathlib import Path
3+
from typing import Optional, Union
4+
5+
import requests
6+
from tqdm import tqdm
7+
8+
from .logger import get_logger
9+
10+
logger = get_logger("DownloadModel")
11+
12+
PROJECT_DIR = Path(__file__).resolve().parent.parent
13+
DEFAULT_MODEL_DIR = PROJECT_DIR / "models"
14+
15+
16+
class DownloadModel:
17+
@classmethod
18+
def download(
19+
cls,
20+
model_full_url: Union[str, Path],
21+
save_dir: Union[str, Path, None] = None,
22+
save_model_name: Optional[str] = None,
23+
) -> str:
24+
if save_dir is None:
25+
save_dir = DEFAULT_MODEL_DIR
26+
27+
save_dir.mkdir(parents=True, exist_ok=True)
28+
29+
if save_model_name is None:
30+
save_model_name = Path(model_full_url).name
31+
32+
save_file_path = save_dir / save_model_name
33+
if save_file_path.exists():
34+
logger.debug("%s already exists", save_file_path)
35+
return str(save_file_path)
36+
37+
try:
38+
logger.info("Download %s to %s", model_full_url, save_dir)
39+
file = cls.download_as_bytes_with_progress(model_full_url, save_model_name)
40+
cls.save_file(save_file_path, file)
41+
except Exception as exc:
42+
raise DownloadModelError from exc
43+
return str(save_file_path)
44+
45+
@staticmethod
46+
def download_as_bytes_with_progress(
47+
url: Union[str, Path], name: Optional[str] = None
48+
) -> bytes:
49+
resp = requests.get(str(url), stream=True, allow_redirects=True, timeout=180)
50+
total = int(resp.headers.get("content-length", 0))
51+
bio = io.BytesIO()
52+
with tqdm(
53+
desc=name, total=total, unit="b", unit_scale=True, unit_divisor=1024
54+
) as pbar:
55+
for chunk in resp.iter_content(chunk_size=65536):
56+
pbar.update(len(chunk))
57+
bio.write(chunk)
58+
return bio.getvalue()
59+
60+
@staticmethod
61+
def save_file(save_path: Union[str, Path], file: bytes):
62+
with open(save_path, "wb") as f:
63+
f.write(file)
64+
65+
66+
class DownloadModelError(Exception):
67+
pass

table_cls/utils/logger.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# -*- encoding: utf-8 -*-
2+
# @Author: Jocker1212
3+
# @Contact: [email protected]
4+
import logging
5+
from functools import lru_cache
6+
7+
8+
@lru_cache(maxsize=32)
9+
def get_logger(name: str) -> logging.Logger:
10+
logger = logging.getLogger(name)
11+
logger.setLevel(logging.DEBUG)
12+
13+
fmt = "%(asctime)s - %(name)s - %(levelname)s: %(message)s"
14+
format_str = logging.Formatter(fmt)
15+
16+
sh = logging.StreamHandler()
17+
sh.setLevel(logging.DEBUG)
18+
19+
logger.addHandler(sh)
20+
sh.setFormatter(format_str)
21+
return logger
File renamed without changes.

0 commit comments

Comments
 (0)