1
1
import time
2
+ from enum import Enum
2
3
from pathlib import Path
4
+ from typing import Union , Dict
3
5
4
6
import cv2
5
7
import numpy as np
6
8
from PIL import Image
7
9
8
- from .utils import InputType , LoadImage , OrtInferSession , resize_and_center_crop
10
+ from .utils .download_model import DownloadModel
11
+ from .utils .utils import InputType , LoadImage , OrtInferSession , resize_and_center_crop
9
12
10
- cur_dir = Path (__file__ ).resolve ().parent
11
- q_cls_model_path = cur_dir / "models" / "table_cls.onnx"
12
- yolo_cls_model_path = cur_dir / "models" / "yolo_cls.onnx"
13
- yolo_cls_x_model_path = cur_dir / "models" / "yolo_cls_x.onnx"
13
+
14
+ class ModelType (Enum ):
15
+ YOLO_CLS_X = "yolox"
16
+ YOLO_CLS = "yolo"
17
+ PADDLE_CLS = "paddle"
18
+ Q_CLS = "q"
19
+
20
+
21
+ ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/"
22
+ KEY_TO_MODEL_URL = {
23
+ ModelType .YOLO_CLS_X .value : f"{ ROOT_URL } /table_cls/yolo_cls_x.onnx" ,
24
+ ModelType .YOLO_CLS .value : f"{ ROOT_URL } /table_cls/yolo_cls.onnx" ,
25
+ ModelType .PADDLE_CLS .value : f"{ ROOT_URL } /table_cls/paddle_cls.onnx" ,
26
+ ModelType .Q_CLS .value : f"{ ROOT_URL } /table_cls/q_cls.onnx" ,
27
+ }
14
28
15
29
16
30
class TableCls :
17
- def __init__ (self , model_type = "yolo" , model_path = yolo_cls_model_path ):
18
- if model_type == "yolo" :
31
+ def __init__ (self , model_type = ModelType .YOLO_CLS .value , model_path = None ):
32
+ model_path = self .get_model_path (model_type , model_path )
33
+ if model_type == ModelType .YOLO_CLS .value :
34
+ self .table_engine = YoloCls (model_path )
35
+ elif model_type == ModelType .YOLO_CLS_X .value :
19
36
self .table_engine = YoloCls (model_path )
20
- elif model_type == "yolox" :
21
- self .table_engine = YoloCls ( yolo_cls_x_model_path )
37
+ elif model_type == ModelType . PADDLE_CLS . value :
38
+ self .table_engine = PaddleCls ( model_path )
22
39
else :
23
- model_path = q_cls_model_path
24
40
self .table_engine = QanythingCls (model_path )
25
41
self .load_img = LoadImage ()
26
42
@@ -32,6 +48,69 @@ def __call__(self, content: InputType):
32
48
table_elapse = time .perf_counter () - ss
33
49
return predict_cla , table_elapse
34
50
51
+ @staticmethod
52
+ def get_model_path (
53
+ model_type : str , model_path : Union [str , Path , None ]
54
+ ) -> Union [str , Dict [str , str ]]:
55
+ if model_path is not None :
56
+ return model_path
57
+
58
+ model_url = KEY_TO_MODEL_URL .get (model_type , None )
59
+ if isinstance (model_url , str ):
60
+ model_path = DownloadModel .download (model_url )
61
+ return model_path
62
+
63
+ if isinstance (model_url , dict ):
64
+ model_paths = {}
65
+ for k , url in model_url .items ():
66
+ model_paths [k ] = DownloadModel .download (
67
+ url , save_model_name = f"{ model_type } _{ Path (url ).name } "
68
+ )
69
+ return model_paths
70
+
71
+ raise ValueError (f"Model URL: { type (model_url )} is not between str and dict." )
72
+
73
+
74
+ class PaddleCls :
75
+ def __init__ (self , model_path ):
76
+ self .table_cls = OrtInferSession (model_path )
77
+ self .inp_h = 224
78
+ self .inp_w = 224
79
+ self .resize_short = 256
80
+ self .mean = np .array ([0.485 , 0.456 , 0.406 ], dtype = np .float32 )
81
+ self .std = np .array ([0.229 , 0.224 , 0.225 ], dtype = np .float32 )
82
+ self .cls = {0 : "wired" , 1 : "wireless" }
83
+
84
+ def preprocess (self , img ):
85
+ # short resize
86
+ img_h , img_w = img .shape [:2 ]
87
+ percent = float (self .resize_short ) / min (img_w , img_h )
88
+ w = int (round (img_w * percent ))
89
+ h = int (round (img_h * percent ))
90
+ img = cv2 .resize (img , dsize = (w , h ), interpolation = cv2 .INTER_LANCZOS4 )
91
+ # center crop
92
+ img_h , img_w = img .shape [:2 ]
93
+ w_start = (img_w - self .inp_w ) // 2
94
+ h_start = (img_h - self .inp_h ) // 2
95
+ w_end = w_start + self .inp_w
96
+ h_end = h_start + self .inp_h
97
+ img = img [h_start :h_end , w_start :w_end , :]
98
+ # normalize
99
+ img = np .array (img , dtype = np .float32 ) / 255.0
100
+ img -= self .mean
101
+ img /= self .std
102
+ # HWC to CHW
103
+ img = img .transpose (2 , 0 , 1 )
104
+ # Add batch dimension, only one image
105
+ img = np .expand_dims (img , axis = 0 )
106
+ return img
107
+
108
+ def __call__ (self , img ):
109
+ pred_output = self .table_cls (img )[0 ]
110
+ pred_idxs = list (np .argmax (pred_output , axis = 1 ))
111
+ predict_cla = max (set (pred_idxs ), key = pred_idxs .count )
112
+ return self .cls [predict_cla ]
113
+
35
114
36
115
class QanythingCls :
37
116
def __init__ (self , model_path ):
0 commit comments