Skip to content

Commit

Permalink
new: Shrink empty arena for multi gpu settings
Browse files Browse the repository at this point in the history
  • Loading branch information
hh-space-invader committed Mar 4, 2025
1 parent 5c46b17 commit 758d339
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 16 deletions.
4 changes: 3 additions & 1 deletion fastembed/common/onnx_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,5 +138,7 @@ def __init__(
def start(cls, model_name: str, cache_dir: str, **kwargs: Any) -> "EmbeddingWorker[T]":
return cls(model_name=model_name, cache_dir=cache_dir, **kwargs)

def process(
    self, items: Iterable[tuple[int, Any]], **kwargs: Any
) -> Iterable[tuple[int, Any]]:
    """Process an iterable of ``(index, batch)`` pairs.

    This base implementation is a placeholder and always raises.

    Args:
        items: Iterable of ``(index, batch)`` tuples.
        **kwargs: Extra keyword arguments accepted for subclasses to use.

    Raises:
        NotImplementedError: Always; subclasses must provide the implementation.
    """
    raise NotImplementedError("Subclasses must implement this method")
13 changes: 11 additions & 2 deletions fastembed/image/onnx_image_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
from PIL import Image
import onnxruntime as ort

from fastembed.image.transform.operators import Compose
from fastembed.common.types import NumpyArray
Expand Down Expand Up @@ -65,6 +66,7 @@ def _build_onnx_input(self, encoded: NumpyArray) -> dict[str, NumpyArray]:
return {input_name: encoded}

def onnx_embed(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputContext:
device_id = kwargs.pop("device_id", 0)
with contextlib.ExitStack():
image_files = [
Image.open(image) if not isinstance(image, Image.Image) else image
Expand All @@ -74,6 +76,11 @@ def onnx_embed(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputConte
encoded = np.array(self.processor(image_files))
onnx_input = self._build_onnx_input(encoded)
onnx_input = self._preprocess_onnx_input(onnx_input)
# Fall back to device 0 when the supplied device_id is not an int.
device_id = device_id if isinstance(device_id, int) else 0
# Configure arena shrinkage for the GPU identified by device_id.
run_options = ort.RunOptions()
run_options.add_run_config_entry(
    "memory.enable_memory_arena_shrinkage", f"gpu:{device_id}"
)
# run_options has to be handed to run(); otherwise the entry above is never applied.
model_output = self.model.run(None, onnx_input, run_options)  # type: ignore[union-attr]
embeddings = model_output[0].reshape(len(images), -1)
return OnnxOutputContext(model_output=embeddings)
Expand Down Expand Up @@ -129,7 +136,9 @@ def _embed_images(


class ImageEmbeddingWorker(EmbeddingWorker[T]):
    """Worker that embeds batches of images through ``self.model.onnx_embed``."""

    def process(
        self, items: Iterable[tuple[int, Any]], **kwargs: Any
    ) -> Iterable[tuple[int, Any]]:
        """Embed every batch once and yield it with its index.

        Args:
            items: Iterable of ``(index, batch)`` tuples.
            **kwargs: Forwarded unchanged to ``self.model.onnx_embed``
                (for example ``device_id``).

        Yields:
            ``(index, embeddings)`` for each incoming batch, in input order.
        """
        for idx, batch in items:
            # Single call per batch: a second, kwargs-less call would repeat the
            # inference and drop the caller's device_id.
            embeddings = self.model.onnx_embed(batch, **kwargs)
            yield idx, embeddings
25 changes: 21 additions & 4 deletions fastembed/late_interaction_multimodal/onnx_multimodal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
from PIL import Image
import onnxruntime as ort
from tokenizers import Encoding, Tokenizer

from fastembed.common import OnnxProvider, ImageInput
Expand Down Expand Up @@ -88,6 +89,7 @@ def onnx_embed_text(
documents: list[str],
**kwargs: Any,
) -> OnnxOutputContext:
device_id = kwargs.pop("device_id", 0)
encoded = self.tokenize(documents, **kwargs)
input_ids = np.array([e.ids for e in encoded])
attention_mask = np.array([e.attention_mask for e in encoded]) # type: ignore[union-attr]
Expand All @@ -103,6 +105,11 @@ def onnx_embed_text(
)

onnx_input = self._preprocess_onnx_text_input(onnx_input, **kwargs)
# Fall back to device 0 when the supplied device_id is not an int.
device_id = device_id if isinstance(device_id, int) else 0
# Configure arena shrinkage for the GPU identified by device_id.
run_options = ort.RunOptions()
run_options.add_run_config_entry(
    "memory.enable_memory_arena_shrinkage", f"gpu:{device_id}"
)
# run_options has to be handed to run(); otherwise the entry above is never applied.
model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input, run_options)  # type: ignore[union-attr]
return OnnxOutputContext(
model_output=model_output[0],
Expand Down Expand Up @@ -160,6 +167,7 @@ def _embed_documents(
yield from self._post_process_onnx_text_output(batch) # type: ignore

def onnx_embed_image(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputContext:
device_id = kwargs.pop("device_id", 0)
with contextlib.ExitStack():
image_files = [
Image.open(image) if not isinstance(image, Image.Image) else image
Expand All @@ -169,6 +177,11 @@ def onnx_embed_image(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutpu
encoded = np.array(self.processor(image_files))
onnx_input = {"pixel_values": encoded}
onnx_input = self._preprocess_onnx_image_input(onnx_input, **kwargs)
# Fall back to device 0 when the supplied device_id is not an int.
device_id = device_id if isinstance(device_id, int) else 0
# Configure arena shrinkage for the GPU identified by device_id.
run_options = ort.RunOptions()
run_options.add_run_config_entry(
    "memory.enable_memory_arena_shrinkage", f"gpu:{device_id}"
)
# run_options has to be handed to run(); otherwise the entry above is never applied.
model_output = self.model.run(None, onnx_input, run_options)  # type: ignore[union-attr]
embeddings = model_output[0].reshape(len(images), -1)
return OnnxOutputContext(model_output=embeddings)
Expand Down Expand Up @@ -241,9 +254,11 @@ def init_embedding(
) -> OnnxMultimodalModel:
raise NotImplementedError()

def process(
    self, items: Iterable[tuple[int, Any]], **kwargs: Any
) -> Iterable[tuple[int, Any]]:
    """Embed every text batch once and yield it with its index.

    Args:
        items: Iterable of ``(index, batch)`` tuples.
        **kwargs: Forwarded unchanged to ``self.model.onnx_embed_text``
            (for example ``device_id``).

    Yields:
        ``(index, onnx_output)`` for each incoming batch, in input order.
    """
    for idx, batch in items:
        # Single call per batch; kwargs carry the caller's device_id.
        onnx_output = self.model.onnx_embed_text(batch, **kwargs)
        yield idx, onnx_output


Expand All @@ -265,7 +280,9 @@ def init_embedding(
) -> OnnxMultimodalModel:
raise NotImplementedError()

def process(
    self, items: Iterable[tuple[int, Any]], **kwargs: Any
) -> Iterable[tuple[int, Any]]:
    """Embed every image batch once and yield it with its index.

    Args:
        items: Iterable of ``(index, batch)`` tuples.
        **kwargs: Forwarded unchanged to ``self.model.onnx_embed_image``
            (for example ``device_id``).

    Yields:
        ``(index, embeddings)`` for each incoming batch, in input order.
    """
    for idx, batch in items:
        # Single call per batch; kwargs carry the caller's device_id.
        embeddings = self.model.onnx_embed_image(batch, **kwargs)
        yield idx, embeddings
6 changes: 4 additions & 2 deletions fastembed/parallel_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ class Worker:
def start(cls, *args: Any, **kwargs: Any) -> "Worker":
raise NotImplementedError()

def process(
    self, items: Iterable[tuple[int, Any]], **kwargs: Any
) -> Iterable[tuple[int, Any]]:
    """Process an iterable of ``(index, item)`` pairs.

    This base implementation is a placeholder and always raises.

    Args:
        items: Iterable of ``(index, item)`` tuples.
        **kwargs: Extra keyword arguments accepted for subclasses to use.

    Raises:
        NotImplementedError: Always; concrete workers must override this method.
    """
    raise NotImplementedError()


Expand Down Expand Up @@ -63,7 +65,7 @@ def input_queue_iterable() -> Iterable[Any]:
break
yield item

for processed_item in worker.process(input_queue_iterable()):
for processed_item in worker.process(input_queue_iterable(), **kwargs):
output_queue.put(processed_item)
except Exception as e: # pylint: disable=broad-except
logging.exception(e)
Expand Down
15 changes: 12 additions & 3 deletions fastembed/rerank/cross_encoder/onnx_text_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Iterable, Optional, Sequence, Type

import numpy as np
import onnxruntime as ort
from tokenizers import Encoding

from fastembed.common.onnx_model import (
Expand Down Expand Up @@ -68,10 +69,16 @@ def onnx_embed(self, query: str, documents: list[str], **kwargs: Any) -> OnnxOut
return self.onnx_embed_pairs(pairs, **kwargs)

def onnx_embed_pairs(self, pairs: list[tuple[str, str]], **kwargs: Any) -> OnnxOutputContext:
    """Score text pairs with the ONNX model in a single inference run.

    Args:
        pairs: Text pairs handed to ``self.tokenize``.
        **kwargs: ``device_id`` is consumed here (default 0); the rest is
            forwarded to ``self.tokenize`` and ``self._preprocess_onnx_input``.

    Returns:
        An ``OnnxOutputContext`` whose ``model_output`` is column 0 of the
        first model output.
    """
    device_id = kwargs.pop("device_id", 0)
    tokenized_input = self.tokenize(pairs, **kwargs)
    inputs = self._build_onnx_input(tokenized_input)
    onnx_input = self._preprocess_onnx_input(inputs, **kwargs)
    # Fall back to device 0 when the supplied device_id is not an int.
    device_id = device_id if isinstance(device_id, int) else 0
    # Configure arena shrinkage for the GPU identified by device_id.
    run_options = ort.RunOptions()
    run_options.add_run_config_entry(
        "memory.enable_memory_arena_shrinkage", f"gpu:{device_id}"
    )
    # Exactly one run, with run_options attached; an extra option-less run
    # would repeat the inference and discard its result.
    outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input, run_options)  # type: ignore[union-attr]
    relevant_output = outputs[0]
    scores: NumpyArray = relevant_output[:, 0]
    return OnnxOutputContext(model_output=scores)
Expand Down Expand Up @@ -163,7 +170,9 @@ def init_embedding(
) -> OnnxCrossEncoderModel:
raise NotImplementedError()

def process(
    self, items: Iterable[tuple[int, Any]], **kwargs: Any
) -> Iterable[tuple[int, Any]]:
    """Score every batch of pairs once and yield it with its index.

    Args:
        items: Iterable of ``(index, batch)`` tuples.
        **kwargs: Forwarded unchanged to ``self.model.onnx_embed_pairs``
            (for example ``device_id``).

    Yields:
        ``(index, onnx_output)`` for each incoming batch, in input order.
    """
    for idx, batch in items:
        # Single call per batch; kwargs carry the caller's device_id.
        onnx_output = self.model.onnx_embed_pairs(batch, **kwargs)
        yield idx, onnx_output
2 changes: 1 addition & 1 deletion fastembed/sparse/bm25.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def start(cls, model_name: str, cache_dir: str, **kwargs: Any) -> "Bm25Worker":
return cls(model_name=model_name, cache_dir=cache_dir, **kwargs)

def process(
self, items: Iterable[tuple[int, Any]]
self, items: Iterable[tuple[int, Any]], **kwargs: Any
) -> Iterable[tuple[int, list[SparseEmbedding]]]:
for idx, batch in items:
onnx_output = self.model.raw_embed(batch)
Expand Down
12 changes: 9 additions & 3 deletions fastembed/text/onnx_text_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def onnx_embed(
documents: list[str],
**kwargs: Any,
) -> OnnxOutputContext:
device_id = kwargs.pop("device_id", 0)
encoded = self.tokenize(documents, **kwargs)
input_ids = np.array([e.ids for e in encoded])
attention_mask = np.array([e.attention_mask for e in encoded])
Expand All @@ -83,8 +84,11 @@ def onnx_embed(
)
onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs)

device_id = device_id if isinstance(device_id, int) else 0
run_options = ort.RunOptions()
run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "gpu:0")
run_options.add_run_config_entry(
"memory.enable_memory_arena_shrinkage", f"gpu:{device_id}"
)
model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input, run_options) # type: ignore[union-attr]
return OnnxOutputContext(
model_output=model_output[0],
Expand Down Expand Up @@ -143,7 +147,9 @@ def _embed_documents(


class TextEmbeddingWorker(EmbeddingWorker[T]):
    """Worker that embeds batches of documents through ``self.model.onnx_embed``."""

    def process(
        self, items: Iterable[tuple[int, Any]], **kwargs: Any
    ) -> Iterable[tuple[int, OnnxOutputContext]]:
        """Embed every batch once and yield it with its index.

        Args:
            items: Iterable of ``(index, batch)`` tuples.
            **kwargs: Forwarded unchanged to ``self.model.onnx_embed``
                (for example ``device_id``).

        Yields:
            ``(index, onnx_output)`` for each incoming batch, in input order.
        """
        for idx, batch in items:
            # Single call per batch: a second, kwargs-less call would repeat the
            # inference and drop the caller's device_id.
            onnx_output = self.model.onnx_embed(batch, **kwargs)
            yield idx, onnx_output

0 comments on commit 758d339

Please sign in to comment.