diff --git a/nemoguardrails/embeddings/providers/fastembed.py b/nemoguardrails/embeddings/providers/fastembed.py index e4956626c7..a124f783fd 100644 --- a/nemoguardrails/embeddings/providers/fastembed.py +++ b/nemoguardrails/embeddings/providers/fastembed.py @@ -19,6 +19,13 @@ from .base import EmbeddingModel +def _is_missing_onnx_model_error(ex: Exception) -> bool: + message = str(ex) + return "model.onnx" in message and ( + "Could not find model.onnx in" in message or "NO_SUCHFILE" in message or "File doesn't exist" in message + ) + + def get_executor(): from . import embeddings_executor @@ -51,14 +58,16 @@ def __init__(self, embedding_model: str, **kwargs): try: self.model = Embedding(embedding_model, **kwargs) - except ValueError as ex: + except Exception as ex: # Sometimes the cached model in the temporary folder gets removed, # but the folder still exists, which causes an error. In this case, # we fall back to an explicit cache directory. - if "Could not find model.onnx in" in str(ex): - self.model = Embedding(embedding_model, cache_dir=".cache", **kwargs) + if _is_missing_onnx_model_error(ex): + fallback_kwargs = dict(kwargs) + fallback_kwargs["cache_dir"] = ".cache" + self.model = Embedding(embedding_model, **fallback_kwargs) else: - raise ex + raise # Get the embedding dimension of the model self.embedding_size = len(list(self.model.embed("test"))[0].tolist()) diff --git a/tests/test_embeddings_fastembed.py b/tests/test_embeddings_fastembed.py index a7ae4f2708..bf6f8a9f54 100644 --- a/tests/test_embeddings_fastembed.py +++ b/tests/test_embeddings_fastembed.py @@ -16,12 +16,54 @@ import os import pytest +from onnxruntime.capi.onnxruntime_pybind11_state import NoSuchFile from nemoguardrails.embeddings.providers.fastembed import FastEmbedEmbeddingModel LIVE_TEST_MODE = os.environ.get("LIVE_TEST") +class _FakeEmbeddingVector: + def tolist(self): + return [0.1, 0.2] + + +def test_recovers_from_missing_onnxruntime_model_cache(monkeypatch): + calls = [] + + class FakeTextEmbedding: + def __init__(self, model_name, **kwargs): + calls.append((model_name, kwargs)) + if len(calls) == 1: + raise NoSuchFile( + "[ONNXRuntimeError] : 3 : NO_SUCHFILE : Load model from /tmp/model.onnx failed. File doesn't exist" + ) + + def embed(self, documents): + return [_FakeEmbeddingVector()] + + monkeypatch.setattr("fastembed.TextEmbedding", FakeTextEmbedding) + + model = FastEmbedEmbeddingModel("all-MiniLM-L6-v2") + + assert model.embedding_size == 2 + assert calls == [ + ("sentence-transformers/all-MiniLM-L6-v2", {}), + ("sentence-transformers/all-MiniLM-L6-v2", {"cache_dir": ".cache"}), + ] + + +def test_reraises_unrelated_fastembed_errors(monkeypatch): + class FakeTextEmbedding: + def __init__(self, model_name, **kwargs): + raise ValueError("unrelated failure") + + monkeypatch.setattr("fastembed.TextEmbedding", FakeTextEmbedding) + + with pytest.raises(ValueError, match="unrelated failure"): + FastEmbedEmbeddingModel("all-MiniLM-L6-v2") + + @pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.") def test_sync_embeddings(): model = FastEmbedEmbeddingModel("all-MiniLM-L6-v2")