5
5
6
6
from fastembed .common import OnnxProvider , ImageInput
7
7
from fastembed .common .onnx_model import OnnxOutputContext
8
+ from fastembed .common .types import NumpyArray
8
9
from fastembed .common .utils import define_cache_dir
9
10
from fastembed .late_interaction_multimodal .late_interaction_multimodal_embedding_base import (
10
11
LateInteractionMultimodalEmbeddingBase ,
@@ -96,7 +97,7 @@ def __init__(
96
97
self .device_id = None
97
98
98
99
self .model_description = self ._get_model_description (model_name )
99
- self .cache_dir = define_cache_dir (cache_dir )
100
+ self .cache_dir = str ( define_cache_dir (cache_dir ) )
100
101
101
102
self ._model_dir = self .download_model (
102
103
self .model_description ,
@@ -132,15 +133,15 @@ def load_onnx_model(self) -> None:
132
133
def _post_process_onnx_image_output (
133
134
self ,
134
135
output : OnnxOutputContext ,
135
- ) -> Iterable [np . ndarray ]:
136
+ ) -> Iterable [NumpyArray ]:
136
137
"""
137
138
Post-process the ONNX model output to convert it into a usable format.
138
139
139
140
Args:
140
141
output (OnnxOutputContext): The raw output from the ONNX model.
141
142
142
143
Returns:
143
- Iterable[np.ndarray ]: Post-processed output as NumPy arrays.
144
+ Iterable[NumpyArray ]: Post-processed output as NumPy arrays.
144
145
"""
145
146
return output .model_output .reshape (
146
147
output .model_output .shape [0 ], - 1 , self .model_description ["dim" ]
@@ -149,15 +150,15 @@ def _post_process_onnx_image_output(
149
150
def _post_process_onnx_text_output (
150
151
self ,
151
152
output : OnnxOutputContext ,
152
- ) -> Iterable [np . ndarray ]:
153
+ ) -> Iterable [NumpyArray ]:
153
154
"""
154
155
Post-process the ONNX model output to convert it into a usable format.
155
156
156
157
Args:
157
158
output (OnnxOutputContext): The raw output from the ONNX model.
158
159
159
160
Returns:
160
- Iterable[np.ndarray ]: Post-processed output as NumPy arrays.
161
+ Iterable[NumpyArray ]: Post-processed output as NumPy arrays.
161
162
"""
162
163
return output .model_output .astype (np .float32 )
163
164
@@ -172,30 +173,32 @@ def tokenize(self, documents: list[str], **_) -> list[Encoding]:
172
173
return encoded
173
174
174
175
def _preprocess_onnx_text_input (
175
- self , onnx_input : dict [str , np . ndarray ], ** kwargs
176
- ) -> dict [str , np . ndarray ]:
176
+ self , onnx_input : dict [str , NumpyArray ], ** kwargs
177
+ ) -> dict [str , NumpyArray ]:
177
178
onnx_input ["input_ids" ] = np .array (
178
179
[
179
180
self .QUERY_MARKER_TOKEN_ID + input_ids [2 :].tolist ()
180
181
for input_ids in onnx_input ["input_ids" ]
181
182
]
182
183
)
183
- empty_image_placeholder = np .zeros (self .IMAGE_PLACEHOLDER_SIZE , dtype = np .float32 )
184
+ empty_image_placeholder : NumpyArray = np .zeros (
185
+ self .IMAGE_PLACEHOLDER_SIZE , dtype = np .float32
186
+ )
184
187
onnx_input ["pixel_values" ] = np .array (
185
- [empty_image_placeholder for _ in onnx_input ["input_ids" ]]
188
+ [empty_image_placeholder for _ in onnx_input ["input_ids" ]],
186
189
)
187
190
return onnx_input
188
191
189
192
def _preprocess_onnx_image_input (
190
193
self , onnx_input : dict [str , np .ndarray ], ** kwargs
191
- ) -> dict [str , np . ndarray ]:
194
+ ) -> dict [str , NumpyArray ]:
192
195
"""
193
196
Add placeholders for text input when processing image data for ONNX.
194
197
Args:
195
- onnx_input (Dict[str, np.ndarray ]): Preprocessed image inputs.
198
+ onnx_input (Dict[str, NumpyArray ]): Preprocessed image inputs.
196
199
**kwargs: Additional arguments.
197
200
Returns:
198
- Dict[str, np.ndarray ]: ONNX input with text placeholders.
201
+ Dict[str, NumpyArray ]: ONNX input with text placeholders.
199
202
"""
200
203
201
204
onnx_input ["input_ids" ] = np .array (
@@ -212,7 +215,7 @@ def embed_text(
212
215
batch_size : int = 256 ,
213
216
parallel : Optional [int ] = None ,
214
217
** kwargs ,
215
- ) -> Iterable [np . ndarray ]:
218
+ ) -> Iterable [NumpyArray ]:
216
219
"""
217
220
Encode a list of documents into list of embeddings.
218
221
@@ -241,11 +244,11 @@ def embed_text(
241
244
242
245
def embed_image (
243
246
self ,
244
- images : ImageInput ,
247
+ images : Union [ ImageInput , Iterable [ ImageInput ]] ,
245
248
batch_size : int = 16 ,
246
249
parallel : Optional [int ] = None ,
247
250
** kwargs ,
248
- ) -> Iterable [np . ndarray ]:
251
+ ) -> Iterable [NumpyArray ]:
249
252
"""
250
253
Encode a list of images into list of embeddings.
251
254
0 commit comments