10
10
from PIL import Image , UnidentifiedImageError
11
11
12
12
root_dir = Path (__file__ ).resolve ().parent
13
- InputType = Union [str , np .ndarray , bytes , Path ]
13
+ InputType = Union [str , np .ndarray , bytes , Path , Image . Image ]
14
14
15
15
16
16
class OrtInferSession :
@@ -91,8 +91,9 @@ def __call__(self, img: InputType) -> np.ndarray:
91
91
f"The img type { type (img )} does not in { InputType .__args__ } "
92
92
)
93
93
94
+ origin_img_type = type (img )
94
95
img = self .load_img (img )
95
- img = self .convert_img (img )
96
+ img = self .convert_img (img , origin_img_type )
96
97
return img
97
98
98
99
def load_img (self , img : InputType ) -> np .ndarray :
@@ -111,9 +112,12 @@ def load_img(self, img: InputType) -> np.ndarray:
111
112
if isinstance (img , np .ndarray ):
112
113
return img
113
114
115
+ if isinstance (img , Image .Image ):
116
+ return np .array (img )
117
+
114
118
raise LoadImageError (f"{ type (img )} is not supported!" )
115
119
116
- def convert_img (self , img : np .ndarray ):
120
+ def convert_img (self , img : np .ndarray , origin_img_type ):
117
121
if img .ndim == 2 :
118
122
return cv2 .cvtColor (img , cv2 .COLOR_GRAY2BGR )
119
123
@@ -125,31 +129,20 @@ def convert_img(self, img: np.ndarray):
125
129
if channel == 2 :
126
130
return self .cvt_two_to_three (img )
127
131
132
+ if channel == 3 :
133
+ if issubclass (origin_img_type , (str , Path , bytes , Image .Image )):
134
+ return cv2 .cvtColor (img , cv2 .COLOR_RGB2BGR )
135
+ return img
136
+
128
137
if channel == 4 :
129
138
return self .cvt_four_to_three (img )
130
139
131
- if channel == 3 :
132
- return cv2 .cvtColor (img , cv2 .COLOR_RGB2BGR )
133
-
134
140
raise LoadImageError (
135
141
f"The channel({ channel } ) of the img is not in [1, 2, 3, 4]"
136
142
)
137
143
138
144
raise LoadImageError (f"The ndim({ img .ndim } ) of the img is not in [2, 3]" )
139
145
140
- @staticmethod
141
- def cvt_four_to_three (img : np .ndarray ) -> np .ndarray :
142
- """RGBA → BGR"""
143
- r , g , b , a = cv2 .split (img )
144
- new_img = cv2 .merge ((b , g , r ))
145
-
146
- not_a = cv2 .bitwise_not (a )
147
- not_a = cv2 .cvtColor (not_a , cv2 .COLOR_GRAY2BGR )
148
-
149
- new_img = cv2 .bitwise_and (new_img , new_img , mask = a )
150
- new_img = cv2 .add (new_img , not_a )
151
- return new_img
152
-
153
146
@staticmethod
154
147
def cvt_two_to_three (img : np .ndarray ) -> np .ndarray :
155
148
"""gray + alpha → BGR"""
@@ -164,6 +157,19 @@ def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
164
157
new_img = cv2 .add (new_img , not_a )
165
158
return new_img
166
159
160
+ @staticmethod
161
+ def cvt_four_to_three (img : np .ndarray ) -> np .ndarray :
162
+ """RGBA → BGR"""
163
+ r , g , b , a = cv2 .split (img )
164
+ new_img = cv2 .merge ((b , g , r ))
165
+
166
+ not_a = cv2 .bitwise_not (a )
167
+ not_a = cv2 .cvtColor (not_a , cv2 .COLOR_GRAY2BGR )
168
+
169
+ new_img = cv2 .bitwise_and (new_img , new_img , mask = a )
170
+ new_img = cv2 .add (new_img , not_a )
171
+ return new_img
172
+
167
173
@staticmethod
168
174
def verify_exist (file_path : Union [str , Path ]):
169
175
if not Path (file_path ).exists ():
0 commit comments