Commit eaecf7d

hh-space-invader, generall, and joein authored
Multi gpu support (#358)
* feat: Added multi gpu support for text embedding
* feat: Add support for multi-gpu for special text models
* fix: Fix lazy_load to load the model to child processes when parallel is not none
* feat: Added lazy_load and multi-gpu to colbert
* feat: Add lazy_load and multi gpu to image models
* feat: Support lazy_load and multi-gpu to sparse models (except BM25)
* fix: Fixed BM25 not working
* refactor: Remove redundant GPUParallelProcessor
* refactor: Refactor _embed_*_parallel
* feat: Add cuda argument; refactor: Refactor how worker assign device
* fix: Fix if providers and cuda are None
* fix: Fix providers and cuda are none
* WIP: Multi gpu support review (#361)
  * WIP: review
  * wip: review
  * refactor: refactor images
  * refactor: refactor sparse
  * refactor: refactor late interaction
  * add model loading
  * add tests
  * fix: uncomment models in tests
  * fix: fix variable declaration order
  * fix: fix device id assignment
  * tests: add multi gpu tests
  * fix: fix device id assignment for sparse embeddings
  * tests: update multi gpu tests

  ---------

  Co-authored-by: George Panchuk <[email protected]>
* refactor: remove redundant declarations
* fix: rollback redundant changes
* fix: remove num workers device ids dep, fix type hint
* fix: fix post process for sparse models
* fix: remove redundant model loading
* new: add lazy load and new gpu support to cross encoders
* fix: add rerankers to multi gpu tests
* fix: unlock multilingual test
* fix: fix gpu test with cross encoder

---------

Co-authored-by: Andrey Vasnetsov <[email protected]>
Co-authored-by: George Panchuk <[email protected]>
1 parent 58b5a8e commit eaecf7d

29 files changed (+823, −152 lines)
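Taken together, the changes add three new constructor arguments (cuda, device_ids, lazy_load) across the embedding classes. A minimal, hedged usage sketch follows; the text-embedding path is not among the files shown below, but per the commit message it takes the same arguments, and the model name, device count, and batch sizes here are illustrative:

# Hedged sketch of the API this commit adds. Assumes fastembed at this commit
# with onnxruntime-gpu installed and two CUDA devices visible.
from fastembed import TextEmbedding

model = TextEmbedding(
    model_name="BAAI/bge-small-en-v1.5",
    cuda=True,          # use CUDAExecutionProvider; mutually exclusive with `providers`
    device_ids=[0, 1],  # one worker per listed GPU during parallel encoding
    lazy_load=True,     # defer ONNX session creation to the worker processes
)

documents = ["fastembed is fast", "fastembed now scales across GPUs"]
# parallel > 1 spawns workers; each worker loads the model on its assigned device
embeddings = list(model.embed(documents, batch_size=64, parallel=2))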

fastembed/common/onnx_model.py

Lines changed: 16 additions & 4 deletions

@@ -50,19 +50,28 @@ def _preprocess_onnx_input(
         """
         return onnx_input
 
-    def load_onnx_model(
+    def _load_onnx_model(
         self,
         model_dir: Path,
         model_file: str,
         threads: Optional[int],
         providers: Optional[Sequence[OnnxProvider]] = None,
+        cuda: bool = False,
+        device_id: Optional[int] = None,
     ) -> None:
         model_path = model_dir / model_file
         # List of Execution Providers: https://onnxruntime.ai/docs/execution-providers
 
-        onnx_providers = (
-            ["CPUExecutionProvider"] if providers is None else list(providers)
-        )
+        if providers is not None:
+            onnx_providers = list(providers)
+        elif cuda:
+            if device_id is None:
+                onnx_providers = ["CUDAExecutionProvider"]
+            else:
+                onnx_providers = [("CUDAExecutionProvider", {"device_id": device_id})]
+        else:
+            onnx_providers = ["CPUExecutionProvider"]
+
         available_providers = ort.get_available_providers()
         requested_provider_names = []
         for provider in onnx_providers:

@@ -94,6 +103,9 @@ def load_onnx_model(
             RuntimeWarning,
         )
 
+    def load_onnx_model(self) -> None:
+        raise NotImplementedError("Subclasses must implement this method")
+
     def onnx_embed(self, *args, **kwargs) -> OnnxOutputContext:
         raise NotImplementedError("Subclasses must implement this method")
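The heart of this change is the new provider-resolution order: an explicit `providers` list wins, then `cuda` (optionally pinned to a `device_id`), then the CPU fallback. A standalone sketch of that branching, runnable without onnxruntime (illustrative names, not fastembed code):

# Standalone sketch of the provider resolution in _load_onnx_model above.
from typing import Any, Optional, Sequence

def resolve_providers(
    providers: Optional[Sequence[Any]],
    cuda: bool = False,
    device_id: Optional[int] = None,
) -> list:
    if providers is not None:
        return list(providers)  # explicit providers always win
    if cuda:
        if device_id is None:
            return ["CUDAExecutionProvider"]
        # pin the session to one GPU via onnxruntime provider options
        return [("CUDAExecutionProvider", {"device_id": device_id})]
    return ["CPUExecutionProvider"]

assert resolve_providers(None) == ["CPUExecutionProvider"]
assert resolve_providers(None, cuda=True, device_id=1) == [
    ("CUDAExecutionProvider", {"device_id": 1})
]
assert resolve_providers(["CPUExecutionProvider"], cuda=True) == ["CPUExecutionProvider"]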

fastembed/image/image_embedding.py

Lines changed: 7 additions & 5 deletions

@@ -45,21 +45,23 @@ def __init__(
         cache_dir: Optional[str] = None,
         threads: Optional[int] = None,
         providers: Optional[Sequence[OnnxProvider]] = None,
+        cuda: bool = False,
+        device_ids: Optional[List[int]] = None,
+        lazy_load: bool = False,
         **kwargs,
     ):
         super().__init__(model_name, cache_dir, threads, **kwargs)
-
         for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY:
             supported_models = EMBEDDING_MODEL_TYPE.list_supported_models()
-            if any(
-                model_name.lower() == model["model"].lower()
-                for model in supported_models
-            ):
+            if any(model_name.lower() == model["model"].lower() for model in supported_models):
                 self.model = EMBEDDING_MODEL_TYPE(
                     model_name,
                     cache_dir,
                     threads=threads,
                     providers=providers,
+                    cuda=cuda,
+                    device_ids=device_ids,
+                    lazy_load=lazy_load,
                     **kwargs,
                 )
                 return
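A short, hedged usage sketch of this pass-through: ImageEmbedding only forwards the new arguments to whichever registry class matches the model name. The model name below is one of fastembed's supported image models at the time of writing; the image paths are placeholders:

# Hedged sketch: ImageEmbedding forwards cuda/device_ids/lazy_load as-is.
from fastembed import ImageEmbedding

model = ImageEmbedding(
    model_name="Qdrant/clip-ViT-B-32-vision",
    cuda=True,
    device_ids=[0],
    lazy_load=True,
)
# Serial path (parallel=None): with lazy_load, the model loads on first use
# in this process rather than at construction time.
embeddings = list(model.embed(["img/cat.jpg", "img/dog.jpg"]))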

fastembed/image/image_embedding_base.py

Lines changed: 1 addition & 1 deletion

@@ -30,7 +30,7 @@ def embed(
         Embeds a list of images into a list of embeddings.
 
         Args:
-            images - The list of image paths to preprocess and embed.
+            images: The list of image paths to preprocess and embed.
             batch_size: Batch size for encoding
             parallel:
                 If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.

fastembed/image/onnx_embedding.py

Lines changed: 53 additions & 8 deletions

@@ -59,6 +59,10 @@ def __init__(
         cache_dir: Optional[str] = None,
         threads: Optional[int] = None,
         providers: Optional[Sequence[OnnxProvider]] = None,
+        cuda: bool = False,
+        device_ids: Optional[List[int]] = None,
+        lazy_load: bool = False,
+        device_id: Optional[int] = None,
         **kwargs,
     ):
         """

@@ -68,24 +72,56 @@ def __init__(
                 Can be set using the `FASTEMBED_CACHE_PATH` env variable.
                 Defaults to `fastembed_cache` in the system's temp directory.
             threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
+            providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use.
+                Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
+            cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`.
+                Defaults to False.
+            device_ids (Optional[List[int]], optional): The list of device ids to use for data parallel processing in
+                workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
+            lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
+                Should be set to True when using multi-gpu and parallel encoding. Defaults to False.
+            device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
 
         Raises:
             ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
         """
 
         super().__init__(model_name, cache_dir, threads, **kwargs)
-
-        model_description = self._get_model_description(model_name)
+        self.providers = providers
+        self.lazy_load = lazy_load
+
+        # List of device ids that can be used for data parallel processing in workers
+        self.device_ids = device_ids
+        self.cuda = cuda
+
+        # This device_id will be used if we need to load the model in the current process
+        if device_id is not None:
+            self.device_id = device_id
+        elif self.device_ids is not None:
+            self.device_id = self.device_ids[0]
+        else:
+            self.device_id = None
+
+        self.model_description = self._get_model_description(model_name)
         self.cache_dir = define_cache_dir(cache_dir)
         self._model_dir = self.download_model(
-            model_description, self.cache_dir, local_files_only=self._local_files_only
+            self.model_description, self.cache_dir, local_files_only=self._local_files_only
         )
 
-        self.load_onnx_model(
+        if not self.lazy_load:
+            self.load_onnx_model()
+
+    def load_onnx_model(self) -> None:
+        """
+        Load the onnx model.
+        """
+        self._load_onnx_model(
             model_dir=self._model_dir,
-            model_file=model_description["model_file"],
-            threads=threads,
-            providers=providers,
+            model_file=self.model_description["model_file"],
+            threads=self.threads,
+            providers=self.providers,
+            cuda=self.cuda,
+            device_id=self.device_id,
         )
 
     @classmethod

@@ -120,12 +156,16 @@ def embed(
         Returns:
             List of embeddings, one per document
         """
+
         yield from self._embed_images(
             model_name=self.model_name,
             cache_dir=str(self.cache_dir),
             images=images,
             batch_size=batch_size,
             parallel=parallel,
+            providers=self.providers,
+            cuda=self.cuda,
+            device_ids=self.device_ids,
             **kwargs,
         )

@@ -148,4 +188,9 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.nd
 
 class OnnxImageEmbeddingWorker(ImageEmbeddingWorker):
     def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> OnnxImageEmbedding:
-        return OnnxImageEmbedding(model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs)
+        return OnnxImageEmbedding(
+            model_name=model_name,
+            cache_dir=cache_dir,
+            threads=1,
+            **kwargs,
+        )
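The device_id fallback in __init__ above decides which GPU the current process would use if it loads the model itself. A standalone sketch of just that rule (illustrative names, not fastembed code):

# Standalone sketch of the device-id fallback: an explicit device_id (passed
# to a worker process) wins, then the first entry of device_ids (the
# main-process default), otherwise None (CPU or provider-chosen device).
from typing import List, Optional

def pick_device_id(
    device_id: Optional[int], device_ids: Optional[List[int]]
) -> Optional[int]:
    if device_id is not None:
        return device_id
    if device_ids is not None:
        return device_ids[0]
    return None

assert pick_device_id(None, [2, 3]) == 2  # main process defaults to first listed GPU
assert pick_device_id(1, [2, 3]) == 1     # worker-assigned id takes precedence
assert pick_device_id(None, None) is None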

fastembed/image/onnx_image_model.py

Lines changed: 33 additions & 15 deletions

@@ -36,21 +36,28 @@ def _preprocess_onnx_input(
         """
         return onnx_input
 
-    def load_onnx_model(
+    def _load_onnx_model(
         self,
         model_dir: Path,
         model_file: str,
         threads: Optional[int],
         providers: Optional[Sequence[OnnxProvider]] = None,
+        cuda: bool = False,
+        device_id: Optional[int] = None,
     ) -> None:
-        super().load_onnx_model(
+        super()._load_onnx_model(
             model_dir=model_dir,
             model_file=model_file,
             threads=threads,
             providers=providers,
+            cuda=cuda,
+            device_id=device_id,
         )
         self.processor = load_preprocessor(model_dir=model_dir)
 
+    def load_onnx_model(self) -> None:
+        raise NotImplementedError("Subclasses must implement this method")
+
     def _build_onnx_input(self, encoded: np.ndarray) -> Dict[str, np.ndarray]:
         return {node.name: encoded for node in self.model.get_inputs()}

@@ -74,33 +81,44 @@ def _embed_images(
         images: ImageInput,
         batch_size: int = 256,
         parallel: Optional[int] = None,
+        providers: Optional[Sequence[OnnxProvider]] = None,
+        cuda: bool = False,
+        device_ids: Optional[List[int]] = None,
         **kwargs,
     ) -> Iterable[T]:
         is_small = False
 
-        if (
-            isinstance(images, str)
-            or isinstance(images, Path)
-            or (isinstance(images, Image.Image))
-        ):
+        if isinstance(images, (str, Path, Image.Image)):
             images = [images]
             is_small = True
 
-        if isinstance(images, list):
-            if len(images) < batch_size:
-                is_small = True
-
-        if parallel == 0:
-            parallel = os.cpu_count()
+        if isinstance(images, list) and len(images) < batch_size:
+            is_small = True
 
         if parallel is None or is_small:
+            if not hasattr(self, "model") or self.model is None:
+                self.load_onnx_model()
+
             for batch in iter_batch(images, batch_size):
                 yield from self._post_process_onnx_output(self.onnx_embed(batch))
         else:
+            if parallel == 0:
+                parallel = os.cpu_count()
+
             start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
-            params = {"model_name": model_name, "cache_dir": cache_dir, **kwargs}
+            params = {
+                "model_name": model_name,
+                "cache_dir": cache_dir,
+                "providers": providers,
+                **kwargs,
+            }
+
             pool = ParallelWorkerPool(
-                parallel, self._get_worker_class(), start_method=start_method
+                parallel,
+                self._get_worker_class(),
+                cuda=cuda,
+                device_ids=device_ids,
+                start_method=start_method,
             )
             for batch in pool.ordered_map(iter_batch(images, batch_size), **params):
                 yield from self._post_process_onnx_output(batch)
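The reworked _embed_images now decides between in-process encoding and a worker pool before touching the model, which is what lets lazy_load defer session creation to the right process. A standalone sketch of the dispatch predicate (illustrative, and assuming the input length is known up front):

# Standalone sketch of the serial-vs-pool decision above: small inputs or
# parallel=None stay in the current process (loading the model on demand when
# lazy_load deferred it); otherwise a worker pool is spawned, with parallel=0
# later expanded to os.cpu_count() workers.
from typing import Optional

def uses_worker_pool(n_images: int, batch_size: int, parallel: Optional[int]) -> bool:
    is_small = n_images < batch_size
    return parallel is not None and not is_small

assert uses_worker_pool(10_000, 256, parallel=2)        # large input, explicit workers
assert uses_worker_pool(10_000, 256, parallel=0)        # 0 expands to all CPUs
assert not uses_worker_pool(10, 256, parallel=2)        # small input stays serial
assert not uses_worker_pool(10_000, 256, parallel=None) # no parallelism requested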

fastembed/late_interaction/colbert.py

Lines changed: 54 additions & 8 deletions

@@ -117,6 +117,10 @@ def __init__(
         cache_dir: Optional[str] = None,
         threads: Optional[int] = None,
         providers: Optional[Sequence[OnnxProvider]] = None,
+        cuda: bool = False,
+        device_ids: Optional[List[int]] = None,
+        lazy_load: bool = False,
+        device_id: Optional[int] = None,
         **kwargs,
     ):
         """

@@ -126,29 +130,60 @@ def __init__(
                 Can be set using the `FASTEMBED_CACHE_PATH` env variable.
                 Defaults to `fastembed_cache` in the system's temp directory.
             threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
+            providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use.
+                Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
+            cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`.
+                Defaults to False.
+            device_ids (Optional[List[int]], optional): The list of device ids to use for data parallel processing in
+                workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
+            lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
+                Should be set to True when using multi-gpu and parallel encoding. Defaults to False.
+            device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
 
         Raises:
             ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
         """
 
         super().__init__(model_name, cache_dir, threads, **kwargs)
+        self.providers = providers
+        self.lazy_load = lazy_load
+
+        # List of device ids that can be used for data parallel processing in workers
+        self.device_ids = device_ids
+        self.cuda = cuda
+
+        # This device_id will be used if we need to load the model in the current process
+        if device_id is not None:
+            self.device_id = device_id
+        elif self.device_ids is not None:
+            self.device_id = self.device_ids[0]
+        else:
+            self.device_id = None
 
-        model_description = self._get_model_description(model_name)
+        self.model_description = self._get_model_description(model_name)
         self.cache_dir = define_cache_dir(cache_dir)
 
         self._model_dir = self.download_model(
-            model_description, self.cache_dir, local_files_only=self._local_files_only
+            self.model_description, self.cache_dir, local_files_only=self._local_files_only
         )
+        self.mask_token_id = None
+        self.pad_token_id = None
+        self.skip_list = set()
+
+        if not self.lazy_load:
+            self.load_onnx_model()
 
-        self.load_onnx_model(
+    def load_onnx_model(self) -> None:
+        self._load_onnx_model(
             model_dir=self._model_dir,
-            model_file=model_description["model_file"],
-            threads=threads,
-            providers=providers,
+            model_file=self.model_description["model_file"],
+            threads=self.threads,
+            providers=self.providers,
+            cuda=self.cuda,
+            device_id=self.device_id,
         )
         self.mask_token_id = self.special_token_to_id["[MASK]"]
         self.pad_token_id = self.tokenizer.padding["pad_id"]
-
         self.skip_list = {
             self.tokenizer.encode(symbol, add_special_tokens=False).ids[0]
             for symbol in string.punctuation

@@ -182,13 +217,19 @@ def embed(
             documents=documents,
             batch_size=batch_size,
             parallel=parallel,
+            providers=self.providers,
+            cuda=self.cuda,
+            device_ids=self.device_ids,
             **kwargs,
         )
 
     def query_embed(self, query: Union[str, List[str]], **kwargs) -> Iterable[np.ndarray]:
         if isinstance(query, str):
             query = [query]
 
+        if not hasattr(self, "model") or self.model is None:
+            self.load_onnx_model()
+
         for text in query:
             yield from self._post_process_onnx_output(
                 self.onnx_embed([text], is_doc=False), is_doc=False

@@ -201,4 +242,9 @@ def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
 
 class ColbertEmbeddingWorker(TextEmbeddingWorker):
     def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> Colbert:
-        return Colbert(model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs)
+        return Colbert(
+            model_name=model_name,
+            cache_dir=cache_dir,
+            threads=1,
+            **kwargs,
+        )
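A hedged usage sketch of the ColBERT flow after this change: document encoding fans out to GPU workers while query_embed loads the model on demand in the current process. The LateInteractionTextEmbedding wrapper and the model name are assumptions not confirmed by this diff:

# Hedged sketch of the lazy ColBERT flow; wrapper class and model name are
# assumptions about the surrounding library, not part of this commit.
from fastembed import LateInteractionTextEmbedding

model = LateInteractionTextEmbedding(
    model_name="colbert-ir/colbertv2.0",
    cuda=True,
    device_ids=[0, 1],
    lazy_load=True,  # workers build their own sessions; main process stays light
)

# Documents fan out to one worker per device id...
doc_embeddings = list(model.embed(["first passage", "second passage"], parallel=2))
# ...while query_embed loads the model on demand in the current process.
query_embedding = next(model.query_embed("which passage mentions fastembed?"))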
