Skip to content

Commit d328ffc

Browse files
committed
Update LoadImage
1 parent 50eef11 commit d328ffc

File tree

2 files changed

+33
-25
lines changed

2 files changed

+33
-25
lines changed

lineless_table_rec/main.py

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -6,14 +6,14 @@
66
import time
77
import traceback
88
from pathlib import Path
9-
from typing import Any, Dict, List, Tuple
9+
from typing import Any, Dict, List, Tuple, Union
1010

1111
import cv2
1212
import numpy as np
1313
from rapidocr_onnxruntime import RapidOCR
1414

1515
from .lineless_table_process import DetProcess, get_affine_transform_upper_left
16-
from .utils import LoadImage, OrtInferSession
16+
from .utils import InputType, LoadImage, OrtInferSession
1717
from .utils_table_recover import (
1818
get_rotate_crop_image,
1919
match_ocr_cell,
@@ -29,6 +29,8 @@
2929
class LinelessTableRecognition:
3030
def __init__(
3131
self,
32+
detect_model_path: Union[str, Path] = detect_model_path,
33+
process_model_path: Union[str, Path] = process_model_path,
3234
):
3335
self.mean = np.array([0.408, 0.447, 0.470], dtype=np.float32).reshape(1, 1, 3)
3436
self.std = np.array([0.289, 0.274, 0.278], dtype=np.float32).reshape(1, 1, 3)
@@ -43,7 +45,7 @@ def __init__(
4345
self.det_process = DetProcess()
4446
self.ocr = RapidOCR()
4547

46-
def __call__(self, content: Dict[str, Any]) -> str:
48+
def __call__(self, content: InputType) -> str:
4749
ss = time.perf_counter()
4850
img = self.load_img(content)
4951

@@ -92,8 +94,8 @@ def preprocess(self, img: np.ndarray) -> Dict[str, Any]:
9294
}
9395
return {"img": images, "meta": meta}
9496

95-
def infer(self, input: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]:
96-
hm, st, wh, ax, cr, reg = self.det_session([input["img"]])
97+
def infer(self, input_content: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]:
98+
hm, st, wh, ax, cr, reg = self.det_session([input_content["img"]])
9799
output = {
98100
"hm": hm,
99101
"st": st,
@@ -103,7 +105,7 @@ def infer(self, input: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]:
103105
"reg": reg,
104106
}
105107
slct_logi_feat, slct_dets_feat, slct_output_dets = self.det_process(
106-
output, input["meta"]
108+
output, input_content["meta"]
107109
)
108110

109111
slct_output_dets = slct_output_dets.reshape(-1, 4, 2)

lineless_table_rec/utils.py

Lines changed: 25 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
from PIL import Image, UnidentifiedImageError
1111

1212
root_dir = Path(__file__).resolve().parent
13-
InputType = Union[str, np.ndarray, bytes, Path]
13+
InputType = Union[str, np.ndarray, bytes, Path, Image.Image]
1414

1515

1616
class OrtInferSession:
@@ -91,8 +91,9 @@ def __call__(self, img: InputType) -> np.ndarray:
9191
f"The img type {type(img)} does not in {InputType.__args__}"
9292
)
9393

94+
origin_img_type = type(img)
9495
img = self.load_img(img)
95-
img = self.convert_img(img)
96+
img = self.convert_img(img, origin_img_type)
9697
return img
9798

9899
def load_img(self, img: InputType) -> np.ndarray:
@@ -111,9 +112,12 @@ def load_img(self, img: InputType) -> np.ndarray:
111112
if isinstance(img, np.ndarray):
112113
return img
113114

115+
if isinstance(img, Image.Image):
116+
return np.array(img)
117+
114118
raise LoadImageError(f"{type(img)} is not supported!")
115119

116-
def convert_img(self, img: np.ndarray):
120+
def convert_img(self, img: np.ndarray, origin_img_type):
117121
if img.ndim == 2:
118122
return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
119123

@@ -125,31 +129,20 @@ def convert_img(self, img: np.ndarray):
125129
if channel == 2:
126130
return self.cvt_two_to_three(img)
127131

132+
if channel == 3:
133+
if issubclass(origin_img_type, (str, Path, bytes, Image.Image)):
134+
return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
135+
return img
136+
128137
if channel == 4:
129138
return self.cvt_four_to_three(img)
130139

131-
if channel == 3:
132-
return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
133-
134140
raise LoadImageError(
135141
f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
136142
)
137143

138144
raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")
139145

140-
@staticmethod
141-
def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
142-
"""RGBA → BGR"""
143-
r, g, b, a = cv2.split(img)
144-
new_img = cv2.merge((b, g, r))
145-
146-
not_a = cv2.bitwise_not(a)
147-
not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
148-
149-
new_img = cv2.bitwise_and(new_img, new_img, mask=a)
150-
new_img = cv2.add(new_img, not_a)
151-
return new_img
152-
153146
@staticmethod
154147
def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
155148
"""gray + alpha → BGR"""
@@ -164,6 +157,19 @@ def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
164157
new_img = cv2.add(new_img, not_a)
165158
return new_img
166159

160+
@staticmethod
161+
def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
162+
"""RGBA → BGR"""
163+
r, g, b, a = cv2.split(img)
164+
new_img = cv2.merge((b, g, r))
165+
166+
not_a = cv2.bitwise_not(a)
167+
not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
168+
169+
new_img = cv2.bitwise_and(new_img, new_img, mask=a)
170+
new_img = cv2.add(new_img, not_a)
171+
return new_img
172+
167173
@staticmethod
168174
def verify_exist(file_path: Union[str, Path]):
169175
if not Path(file_path).exists():

0 commit comments

Comments
 (0)