Skip to content

Commit 1c016a2

Browse files
fix: Fix multi gpu settings
1 parent 4037e14 commit 1c016a2

File tree

5 files changed

+73
-34
lines changed

5 files changed

+73
-34
lines changed

fastembed/common/utils.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
import unicodedata
66
from pathlib import Path
77
from itertools import islice
8-
from typing import Iterable, Optional, TypeVar
8+
from typing import Iterable, Optional, TypeVar, Sequence
99

1010
import numpy as np
1111
from numpy.typing import NDArray
1212

13-
from fastembed.common.types import NumpyArray
13+
from fastembed.common.types import NumpyArray, OnnxProvider
1414

1515
T = TypeVar("T")
1616

@@ -67,3 +67,18 @@ def get_all_punctuation() -> set[str]:
6767

6868
def remove_non_alphanumeric(text: str) -> str:
6969
return re.sub(r"[^\w\s]", " ", text, flags=re.UNICODE)
70+
71+
72+
def is_cuda_enabled(cuda: bool, providers: "Optional[Sequence[OnnxProvider]]") -> bool:
    """
    Check if CUDA is enabled based on the `cuda` and `providers` parameters.

    Args:
        cuda: Explicit CUDA flag. When truthy, CUDA is considered enabled
            regardless of `providers`.
        providers: Execution providers to inspect. Each entry may be a provider
            name or a `(name, options)` tuple. A bare string is accepted as well.

    Returns:
        True if `cuda` is set, or if any provider name contains
        "CUDAExecutionProvider"; False otherwise (including when `providers`
        is empty, None, or not a str/list/tuple).
    """
    if cuda:
        return True
    if not providers:
        return False
    if isinstance(providers, str):
        return "CUDAExecutionProvider" in providers
    if not isinstance(providers, (list, tuple)):
        return False

    for provider in providers:
        # Providers can be given as `(name, options)`; the name is the first
        # element, so unwrap it before checking. Empty tuples are skipped.
        if isinstance(provider, tuple) and provider:
            provider = provider[0]
        if isinstance(provider, str) and "CUDAExecutionProvider" in provider:
            return True
    return False

fastembed/image/onnx_image_model.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from fastembed.common import ImageInput, OnnxProvider
1414
from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel, OnnxOutputContext, T
1515
from fastembed.common.preprocessor_utils import load_preprocessor
16-
from fastembed.common.utils import iter_batch
16+
from fastembed.common.utils import iter_batch, is_cuda_enabled
1717
from fastembed.parallel_processor import ParallelWorkerPool
1818

1919
# Holds type of the embedding result
@@ -66,7 +66,6 @@ def _build_onnx_input(self, encoded: NumpyArray) -> dict[str, NumpyArray]:
6666
return {input_name: encoded}
6767

6868
def onnx_embed(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputContext:
69-
device_id = kwargs.pop("device_id", 0)
7069
with contextlib.ExitStack():
7170
image_files = [
7271
Image.open(image) if not isinstance(image, Image.Image) else image
@@ -76,12 +75,18 @@ def onnx_embed(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputConte
7675
encoded = np.array(self.processor(image_files))
7776
onnx_input = self._build_onnx_input(encoded)
7877
onnx_input = self._preprocess_onnx_input(onnx_input)
79-
device_id = device_id if isinstance(device_id, int) else 0
78+
8079
run_options = ort.RunOptions()
81-
run_options.add_run_config_entry(
82-
"memory.enable_memory_arena_shrinkage", f"gpu:{device_id}"
83-
)
84-
model_output = self.model.run(None, onnx_input) # type: ignore[union-attr]
80+
providers = kwargs.get("providers", None)
81+
cuda = kwargs.get("cuda", False)
82+
if is_cuda_enabled(cuda, providers):
83+
device_id = kwargs.get("device_id", None)
84+
device_id = str(device_id if isinstance(device_id, int) else 0)
85+
run_options.add_run_config_entry(
86+
"memory.enable_memory_arena_shrinkage", f"gpu:{device_id}"
87+
)
88+
89+
model_output = self.model.run(None, onnx_input, run_options) # type: ignore[union-attr]
8590
embeddings = model_output[0].reshape(len(images), -1)
8691
return OnnxOutputContext(model_output=embeddings)
8792

fastembed/late_interaction_multimodal/onnx_multimodal_model.py

+23-13
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel, OnnxOutputContext, T
1414
from fastembed.common.preprocessor_utils import load_tokenizer, load_preprocessor
1515
from fastembed.common.types import NumpyArray
16-
from fastembed.common.utils import iter_batch
16+
from fastembed.common.utils import iter_batch, is_cuda_enabled
1717
from fastembed.image.transform.operators import Compose
1818
from fastembed.parallel_processor import ParallelWorkerPool
1919

@@ -89,7 +89,6 @@ def onnx_embed_text(
8989
documents: list[str],
9090
**kwargs: Any,
9191
) -> OnnxOutputContext:
92-
device_id = kwargs.pop("device_id", 0)
9392
encoded = self.tokenize(documents, **kwargs)
9493
input_ids = np.array([e.ids for e in encoded])
9594
attention_mask = np.array([e.attention_mask for e in encoded]) # type: ignore[union-attr]
@@ -105,12 +104,18 @@ def onnx_embed_text(
105104
)
106105

107106
onnx_input = self._preprocess_onnx_text_input(onnx_input, **kwargs)
108-
device_id = device_id if isinstance(device_id, int) else 0
107+
109108
run_options = ort.RunOptions()
110-
run_options.add_run_config_entry(
111-
"memory.enable_memory_arena_shrinkage", f"gpu:{device_id}"
112-
)
113-
model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) # type: ignore[union-attr]
109+
providers = kwargs.get("providers", None)
110+
cuda = kwargs.get("cuda", False)
111+
if is_cuda_enabled(cuda, providers):
112+
device_id = kwargs.get("device_id", None)
113+
device_id = str(device_id if isinstance(device_id, int) else 0)
114+
run_options.add_run_config_entry(
115+
"memory.enable_memory_arena_shrinkage", f"gpu:{device_id}"
116+
)
117+
118+
model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input, run_options) # type: ignore[union-attr]
114119
return OnnxOutputContext(
115120
model_output=model_output[0],
116121
attention_mask=onnx_input.get("attention_mask", attention_mask),
@@ -167,7 +172,6 @@ def _embed_documents(
167172
yield from self._post_process_onnx_text_output(batch) # type: ignore
168173

169174
def onnx_embed_image(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputContext:
170-
device_id = kwargs.pop("device_id", 0)
171175
with contextlib.ExitStack():
172176
image_files = [
173177
Image.open(image) if not isinstance(image, Image.Image) else image
@@ -177,12 +181,18 @@ def onnx_embed_image(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutpu
177181
encoded = np.array(self.processor(image_files))
178182
onnx_input = {"pixel_values": encoded}
179183
onnx_input = self._preprocess_onnx_image_input(onnx_input, **kwargs)
180-
device_id = device_id if isinstance(device_id, int) else 0
184+
181185
run_options = ort.RunOptions()
182-
run_options.add_run_config_entry(
183-
"memory.enable_memory_arena_shrinkage", f"gpu:{device_id}"
184-
)
185-
model_output = self.model.run(None, onnx_input) # type: ignore[union-attr]
186+
providers = kwargs.get("providers", None)
187+
cuda = kwargs.get("cuda", False)
188+
if is_cuda_enabled(cuda, providers):
189+
device_id = kwargs.get("device_id", None)
190+
device_id = str(device_id if isinstance(device_id, int) else 0)
191+
run_options.add_run_config_entry(
192+
"memory.enable_memory_arena_shrinkage", f"gpu:{device_id}"
193+
)
194+
195+
model_output = self.model.run(None, onnx_input, run_options) # type: ignore[union-attr]
186196
embeddings = model_output[0].reshape(len(images), -1)
187197
return OnnxOutputContext(model_output=embeddings)
188198

fastembed/rerank/cross_encoder/onnx_text_model.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
)
1616
from fastembed.common.types import NumpyArray
1717
from fastembed.common.preprocessor_utils import load_tokenizer
18-
from fastembed.common.utils import iter_batch
18+
from fastembed.common.utils import iter_batch, is_cuda_enabled
1919
from fastembed.parallel_processor import ParallelWorkerPool
2020

2121

@@ -69,15 +69,20 @@ def onnx_embed(self, query: str, documents: list[str], **kwargs: Any) -> OnnxOut
6969
return self.onnx_embed_pairs(pairs, **kwargs)
7070

7171
def onnx_embed_pairs(self, pairs: list[tuple[str, str]], **kwargs: Any) -> OnnxOutputContext:
72-
device_id = kwargs.pop("device_id", 0)
7372
tokenized_input = self.tokenize(pairs, **kwargs)
7473
inputs = self._build_onnx_input(tokenized_input)
7574
onnx_input = self._preprocess_onnx_input(inputs, **kwargs)
76-
device_id = device_id if isinstance(device_id, int) else 0
75+
7776
run_options = ort.RunOptions()
78-
run_options.add_run_config_entry(
79-
"memory.enable_memory_arena_shrinkage", f"gpu:{device_id}"
80-
)
77+
providers = kwargs.get("providers", None)
78+
cuda = kwargs.get("cuda", False)
79+
if is_cuda_enabled(cuda, providers):
80+
device_id = kwargs.get("device_id", None)
81+
device_id = str(device_id if isinstance(device_id, int) else 0)
82+
run_options.add_run_config_entry(
83+
"memory.enable_memory_arena_shrinkage", f"gpu:{device_id}"
84+
)
85+
8186
outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input, run_options) # type: ignore[union-attr]
8287
relevant_output = outputs[0]
8388
scores: NumpyArray = relevant_output[:, 0]

fastembed/text/onnx_text_model.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from fastembed.common.types import NumpyArray, OnnxProvider
1212
from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel, OnnxOutputContext, T
1313
from fastembed.common.preprocessor_utils import load_tokenizer
14-
from fastembed.common.utils import iter_batch
14+
from fastembed.common.utils import iter_batch, is_cuda_enabled
1515
from fastembed.parallel_processor import ParallelWorkerPool
1616

1717

@@ -68,7 +68,6 @@ def onnx_embed(
6868
documents: list[str],
6969
**kwargs: Any,
7070
) -> OnnxOutputContext:
71-
device_id = kwargs.pop("device_id", 0)
7271
encoded = self.tokenize(documents, **kwargs)
7372
input_ids = np.array([e.ids for e in encoded])
7473
attention_mask = np.array([e.attention_mask for e in encoded])
@@ -84,11 +83,16 @@ def onnx_embed(
8483
)
8584
onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs)
8685

87-
device_id = device_id if isinstance(device_id, int) else 0
8886
run_options = ort.RunOptions()
89-
run_options.add_run_config_entry(
90-
"memory.enable_memory_arena_shrinkage", f"gpu:{device_id}"
91-
)
87+
providers = kwargs.get("providers", None)
88+
cuda = kwargs.get("cuda", False)
89+
if is_cuda_enabled(cuda, providers):
90+
device_id = kwargs.get("device_id", None)
91+
device_id = str(device_id if isinstance(device_id, int) else 0)
92+
run_options.add_run_config_entry(
93+
"memory.enable_memory_arena_shrinkage", f"gpu:{device_id}"
94+
)
95+
9296
model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input, run_options) # type: ignore[union-attr]
9397
return OnnxOutputContext(
9498
model_output=model_output[0],

0 commit comments

Comments
 (0)