13
13
from fastembed .common .onnx_model import EmbeddingWorker , OnnxModel , OnnxOutputContext , T
14
14
from fastembed .common .preprocessor_utils import load_tokenizer , load_preprocessor
15
15
from fastembed .common .types import NumpyArray
16
- from fastembed .common .utils import iter_batch
16
+ from fastembed .common .utils import iter_batch , is_cuda_enabled
17
17
from fastembed .image .transform .operators import Compose
18
18
from fastembed .parallel_processor import ParallelWorkerPool
19
19
@@ -89,7 +89,6 @@ def onnx_embed_text(
89
89
documents : list [str ],
90
90
** kwargs : Any ,
91
91
) -> OnnxOutputContext :
92
- device_id = kwargs .pop ("device_id" , 0 )
93
92
encoded = self .tokenize (documents , ** kwargs )
94
93
input_ids = np .array ([e .ids for e in encoded ])
95
94
attention_mask = np .array ([e .attention_mask for e in encoded ]) # type: ignore[union-attr]
@@ -105,12 +104,18 @@ def onnx_embed_text(
105
104
)
106
105
107
106
onnx_input = self ._preprocess_onnx_text_input (onnx_input , ** kwargs )
108
- device_id = device_id if isinstance ( device_id , int ) else 0
107
+
109
108
run_options = ort .RunOptions ()
110
- run_options .add_run_config_entry (
111
- "memory.enable_memory_arena_shrinkage" , f"gpu:{ device_id } "
112
- )
113
- model_output = self .model .run (self .ONNX_OUTPUT_NAMES , onnx_input ) # type: ignore[union-attr]
109
+ providers = kwargs .get ("providers" , None )
110
+ cuda = kwargs .get ("cuda" , False )
111
+ if is_cuda_enabled (cuda , providers ):
112
+ device_id = kwargs .get ("device_id" , None )
113
+ device_id = str (device_id if isinstance (device_id , int ) else 0 )
114
+ run_options .add_run_config_entry (
115
+ "memory.enable_memory_arena_shrinkage" , f"gpu:{ device_id } "
116
+ )
117
+
118
+ model_output = self .model .run (self .ONNX_OUTPUT_NAMES , onnx_input , run_options ) # type: ignore[union-attr]
114
119
return OnnxOutputContext (
115
120
model_output = model_output [0 ],
116
121
attention_mask = onnx_input .get ("attention_mask" , attention_mask ),
@@ -167,7 +172,6 @@ def _embed_documents(
167
172
yield from self ._post_process_onnx_text_output (batch ) # type: ignore
168
173
169
174
def onnx_embed_image (self , images : list [ImageInput ], ** kwargs : Any ) -> OnnxOutputContext :
170
- device_id = kwargs .pop ("device_id" , 0 )
171
175
with contextlib .ExitStack ():
172
176
image_files = [
173
177
Image .open (image ) if not isinstance (image , Image .Image ) else image
@@ -177,12 +181,18 @@ def onnx_embed_image(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutpu
177
181
encoded = np .array (self .processor (image_files ))
178
182
onnx_input = {"pixel_values" : encoded }
179
183
onnx_input = self ._preprocess_onnx_image_input (onnx_input , ** kwargs )
180
- device_id = device_id if isinstance ( device_id , int ) else 0
184
+
181
185
run_options = ort .RunOptions ()
182
- run_options .add_run_config_entry (
183
- "memory.enable_memory_arena_shrinkage" , f"gpu:{ device_id } "
184
- )
185
- model_output = self .model .run (None , onnx_input ) # type: ignore[union-attr]
186
+ providers = kwargs .get ("providers" , None )
187
+ cuda = kwargs .get ("cuda" , False )
188
+ if is_cuda_enabled (cuda , providers ):
189
+ device_id = kwargs .get ("device_id" , None )
190
+ device_id = str (device_id if isinstance (device_id , int ) else 0 )
191
+ run_options .add_run_config_entry (
192
+ "memory.enable_memory_arena_shrinkage" , f"gpu:{ device_id } "
193
+ )
194
+
195
+ model_output = self .model .run (None , onnx_input , run_options ) # type: ignore[union-attr]
186
196
embeddings = model_output [0 ].reshape (len (images ), - 1 )
187
197
return OnnxOutputContext (model_output = embeddings )
188
198
0 commit comments