Skip to content

Commit

Permalink
Specify memory arena shrinkage via run options, not session options
Browse files Browse the repository at this point in the history
  • Loading branch information
hh-space-invader committed Mar 4, 2025
1 parent b82e4d0 commit c212c1f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
2 changes: 1 addition & 1 deletion fastembed/common/onnx_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _load_onnx_model(

so = ort.SessionOptions()
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
-so.add_session_config_entry("memory.enable_memory_arena_shrinkage", "1")
+# so.add_session_config_entry("memory.enable_memory_arena_shrinkage", "1")

if threads is not None:
so.intra_op_num_threads = threads
Expand Down
5 changes: 4 additions & 1 deletion fastembed/text/onnx_text_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Iterable, Optional, Sequence, Type, Union

import numpy as np
+import onnxruntime as ort
from numpy.typing import NDArray
from tokenizers import Encoding, Tokenizer

Expand Down Expand Up @@ -82,7 +83,9 @@ def onnx_embed(
)
onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs)

-model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) # type: ignore[union-attr]
+run_options = ort.RunOptions()
+run_options.add_config_entry("memory.enable_memory_arena_shrinkage", "1")
+model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input, run_options) # type: ignore[union-attr]
return OnnxOutputContext(
model_output=model_output[0],
attention_mask=onnx_input.get("attention_mask", attention_mask),
Expand Down

0 comments on commit c212c1f

Please sign in to comment.