feat: support Triton with Nvidia embedders #1098

Draft · wants to merge 1 commit into main
5 changes: 4 additions & 1 deletion integrations/nvidia/pyproject.toml
@@ -48,7 +48,9 @@ dependencies = [
"pytest-rerunfailures",
"haystack-pydoc-tools",
"requests_mock",
"tritonclient[http,grpc]",
]

[tool.hatch.envs.default.scripts]
test = "pytest {args:tests}"
test-cov = "coverage run -m pytest {args:tests}"
Expand All @@ -70,7 +72,7 @@ style = [
"ruff check {args:. --exclude tests/}",
"black --check --diff {args:.}",
]
fmt = ["black {args:.}", "ruff --fix {args:. --exclude tests/}", "style"]
fmt = ["black {args:.}", "ruff check --fix {args:. --exclude tests/}", "style"]
all = ["style", "typing"]

[tool.black]
@@ -160,6 +162,7 @@ module = [
"numpy.*",
"requests_mock.*",
"pydantic.*",
"tritonclient.*",
]
ignore_missing_imports = true

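Note that this adds `tritonclient` to the Hatch default (test) environment only, so users who opt into the Triton backends need it installed at runtime as well. A hedged sketch of the usual guard for such an optional import; the PR's actual handling in `triton_backend.py` is not shown in this diff:

```python
# Sketch only: how an optional tritonclient import is commonly guarded.
# The error message and structure are assumptions, not the PR's code.
try:
    import tritonclient.grpc  # noqa: F401
    import tritonclient.http  # noqa: F401
except ImportError as exc:
    msg = "tritonclient is not installed. Run: pip install 'tritonclient[http,grpc]'"
    raise ImportError(msg) from exc
```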
integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py
@@ -1,11 +1,11 @@
import warnings
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union

from haystack import Document, component, default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace
from tqdm import tqdm

-from haystack_integrations.utils.nvidia import NimBackend, is_hosted, url_validation
+from haystack_integrations.utils.nvidia import NimBackend, TritonBackend, is_hosted, url_validation

from .truncate import EmbeddingTruncateMode

@@ -16,7 +16,7 @@
class NvidiaDocumentEmbedder:
"""
A component for embedding documents using embedding models provided by
-    [NVIDIA NIMs](https://ai.nvidia.com).
+    [NVIDIA NIMs](https://ai.nvidia.com) or [NVIDIA Triton](https://developer.nvidia.com/triton-inference-server).

Usage example:
```python
@@ -44,6 +44,8 @@ def __init__(
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
truncate: Optional[Union[EmbeddingTruncateMode, str]] = None,
+        backend: Literal["nim", "triton-http", "triton-grpc"] = "nim",
+        timeout: Optional[float] = None,
):
"""
        Create a NvidiaDocumentEmbedder component.
@@ -73,19 +75,30 @@ def __init__(
:param truncate:
            Specifies how inputs longer than the maximum token length should be truncated.
            If None, the behavior is model-dependent; see the official documentation for more information.
+        :param backend:
+            The backend to use for the component. Currently supported are "nim", "triton-http", and "triton-grpc".
+            Default is "nim".
+        :param timeout:
+            Timeout for the request in seconds. If not set, it defaults to the `NVIDIA_TIMEOUT` environment
+            variable or, failing that, 60 seconds.
"""

self.api_key = api_key
self.model = model
-        self.api_url = url_validation(api_url, _DEFAULT_API_URL, ["v1/embeddings"])
+        self.api_url = url_validation(api_url, _DEFAULT_API_URL, ["v1/embeddings"]) if backend == "nim" else api_url
self.prefix = prefix
self.suffix = suffix
self.batch_size = batch_size
self.progress_bar = progress_bar
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator
+        self._backend = backend
+        self.timeout = timeout

if isinstance(truncate, str):
+            if self._backend != "nim":
+                error_message = "Truncation is only supported with the nim backend."
+                raise ValueError(error_message)
truncate = EmbeddingTruncateMode.from_str(truncate)
self.truncate = truncate

@@ -121,15 +134,25 @@ def warm_up(self):
if self._initialized:
return

model_kwargs = {"input_type": "passage"}
if self.truncate is not None:
model_kwargs["truncate"] = str(self.truncate)
self.backend = NimBackend(
self.model,
api_url=self.api_url,
api_key=self.api_key,
model_kwargs=model_kwargs,
)
+        if self._backend == "nim":
+            model_kwargs = {"input_type": "passage"}
+            if self.truncate is not None:
+                model_kwargs["truncate"] = str(self.truncate)
+            self.backend = NimBackend(
+                self.model,
+                api_url=self.api_url,
+                api_key=self.api_key,
+                model_kwargs=model_kwargs,
+                timeout=self.timeout,
+            )
+        else:
+            self.backend = TritonBackend(
+                model=self.model,
+                api_url=self.api_url,
+                api_key=self.api_key,
+                protocol="http" if self._backend == "triton-http" else "grpc",
+                timeout=self.timeout,
+            )

self._initialized = True

@@ -155,6 +178,7 @@ def to_dict(self) -> Dict[str, Any]:
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
truncate=str(self.truncate) if self.truncate is not None else None,
+            backend=self._backend,
)

@classmethod
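Taken together, the changes above let a Triton-served embedding model back this component. A minimal usage sketch under stated assumptions: the model name and endpoint are placeholders, `truncate` is left unset because it now raises a `ValueError` for non-nim backends, and depending on how `TritonBackend` resolves `api_key`, the `NVIDIA_API_KEY` environment variable may still need to be set:

```python
from haystack import Document
from haystack_integrations.components.embedders.nvidia import NvidiaDocumentEmbedder

embedder = NvidiaDocumentEmbedder(
    model="my_embedding_model",  # placeholder: model name as registered in Triton
    api_url="localhost:8001",    # placeholder: Triton gRPC endpoint; skips v1/embeddings URL validation
    backend="triton-grpc",
    timeout=30.0,
)
embedder.warm_up()
result = embedder.run(documents=[Document(content="Triton can serve embedding models.")])
print(len(result["documents"][0].embedding))
```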
integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py
@@ -1,10 +1,10 @@
import warnings
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union

from haystack import component, default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace

-from haystack_integrations.utils.nvidia import NimBackend, is_hosted, url_validation
+from haystack_integrations.utils.nvidia import NimBackend, TritonBackend, is_hosted, url_validation

from .truncate import EmbeddingTruncateMode

@@ -15,7 +15,7 @@
class NvidiaTextEmbedder:
"""
A component for embedding strings using embedding models provided by
-    [NVIDIA NIMs](https://ai.nvidia.com).
+    [NVIDIA NIMs](https://ai.nvidia.com) or [NVIDIA Triton](https://developer.nvidia.com/triton-inference-server).

For models that differentiate between query and document inputs,
this component embeds the input string as a query.
@@ -41,6 +41,8 @@ def __init__(
prefix: str = "",
suffix: str = "",
truncate: Optional[Union[EmbeddingTruncateMode, str]] = None,
+        backend: Literal["nim", "triton-http", "triton-grpc"] = "nim",
+        timeout: Optional[float] = None,
):
"""
Create a NvidiaTextEmbedder component.
@@ -61,15 +63,26 @@
:param truncate:
            Specifies how inputs longer than the maximum token length should be truncated.
            If None, the behavior is model-dependent; see the official documentation for more information.
+        :param backend:
+            The backend to use for the component. Currently supported are "nim", "triton-http", and "triton-grpc".
+            Default is "nim".
+        :param timeout:
+            Timeout for the request in seconds. If not set, it defaults to the `NVIDIA_TIMEOUT` environment
+            variable or, failing that, 60 seconds.
"""

self.api_key = api_key
self.model = model
-        self.api_url = url_validation(api_url, _DEFAULT_API_URL, ["v1/embeddings"])
+        self.api_url = url_validation(api_url, _DEFAULT_API_URL, ["v1/embeddings"]) if backend == "nim" else api_url
self.prefix = prefix
self.suffix = suffix
+        self._backend = backend
+        self.timeout = timeout

if isinstance(truncate, str):
+            if self._backend != "nim":
+                error_message = "Truncation is only supported with the nim backend."
+                raise ValueError(error_message)
truncate = EmbeddingTruncateMode.from_str(truncate)
self.truncate = truncate

@@ -105,15 +118,25 @@ def warm_up(self):
if self._initialized:
return

model_kwargs = {"input_type": "query"}
if self.truncate is not None:
model_kwargs["truncate"] = str(self.truncate)
self.backend = NimBackend(
self.model,
api_url=self.api_url,
api_key=self.api_key,
model_kwargs=model_kwargs,
)
+        if self._backend == "nim":
+            model_kwargs = {"input_type": "query"}
+            if self.truncate is not None:
+                model_kwargs["truncate"] = str(self.truncate)
+            self.backend = NimBackend(
+                self.model,
+                api_url=self.api_url,
+                api_key=self.api_key,
+                model_kwargs=model_kwargs,
+                timeout=self.timeout,
+            )
+        else:
+            self.backend = TritonBackend(
+                model=self.model,
+                api_url=self.api_url,
+                api_key=self.api_key,
+                protocol="http" if self._backend == "triton-http" else "grpc",
+                timeout=self.timeout,
+            )

self._initialized = True

@@ -135,6 +158,7 @@ def to_dict(self) -> Dict[str, Any]:
prefix=self.prefix,
suffix=self.suffix,
truncate=str(self.truncate) if self.truncate is not None else None,
+            backend=self._backend,
)

@classmethod
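Because `to_dict()` now persists `backend`, a serialized pipeline keeps its backend choice across save and load. A quick round-trip sketch; model name and URL are placeholders:

```python
from haystack_integrations.components.embedders.nvidia import NvidiaTextEmbedder

embedder = NvidiaTextEmbedder(
    model="my_embedding_model",  # placeholder Triton model name
    api_url="localhost:8000",    # placeholder Triton HTTP endpoint
    backend="triton-http",
)

data = embedder.to_dict()
assert data["init_parameters"]["backend"] == "triton-http"
restored = NvidiaTextEmbedder.from_dict(data)  # backend choice survives the round trip
```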
integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py
@@ -1,4 +1,5 @@
-from .nim_backend import Model, NimBackend
-from .utils import is_hosted, url_validation
+from .nim_backend import NimBackend
+from .triton_backend import TritonBackend
+from .utils import Model, is_hosted, url_validation

__all__ = ["NimBackend", "Model", "is_hosted", "url_validation"]
__all__ = ["NimBackend", "TritonBackend", "Model", "is_hosted", "url_validation"]
integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py
@@ -1,27 +1,10 @@
-from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

import requests
from haystack import Document
from haystack.utils import Secret

-REQUEST_TIMEOUT = 60
-
-
-@dataclass
-class Model:
-    """
-    Model information.
-
-    id: unique identifier for the model, passed as model parameter for requests
-    aliases: list of aliases for the model
-    base_model: root model for the model
-    All aliases are deprecated and will trigger a warning when used.
-    """
-
-    id: str
-    aliases: Optional[List[str]] = field(default_factory=list)
-    base_model: Optional[str] = None
+from haystack_integrations.utils.nvidia.utils import REQUEST_TIMEOUT, Model


class NimBackend:
@@ -31,6 +14,7 @@ def __init__(
api_url: str,
api_key: Optional[Secret] = Secret.from_env_var("NVIDIA_API_KEY"),
model_kwargs: Optional[Dict[str, Any]] = None,
+        timeout: Optional[float] = None,
):
headers = {
"Content-Type": "application/json",
@@ -47,6 +31,10 @@ def __init__(
self.api_url = api_url
self.model_kwargs = model_kwargs or {}

+        if timeout is None:
+            timeout = REQUEST_TIMEOUT
+        self.timeout = timeout

def embed(self, texts: List[str]) -> Tuple[List[List[float]], Dict[str, Any]]:
url = f"{self.api_url}/embeddings"

@@ -57,7 +45,7 @@ def embed(self, texts: List[str]) -> Tuple[List[List[float]], Dict[str, Any]]:
"input": texts,
**self.model_kwargs,
},
-            timeout=REQUEST_TIMEOUT,
+            timeout=self.timeout,
)
res.raise_for_status()

@@ -85,7 +73,7 @@ def generate(self, prompt: str) -> Tuple[List[str], List[Dict[str, Any]]]:
],
**self.model_kwargs,
},
-            timeout=REQUEST_TIMEOUT,
+            timeout=self.timeout,
)
res.raise_for_status()

@@ -120,7 +108,7 @@ def models(self) -> List[Model]:

res = self.session.get(
url,
-            timeout=REQUEST_TIMEOUT,
+            timeout=self.timeout,
)
res.raise_for_status()

@@ -147,7 +135,7 @@ def rank(
"passages": [{"text": doc.content} for doc in documents],
**self.model_kwargs,
},
-            timeout=REQUEST_TIMEOUT,
+            timeout=self.timeout,
)
res.raise_for_status()

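`REQUEST_TIMEOUT` and `Model` now live in `utils.py`, which this view does not include. For the docstrings' claim that the timeout falls back to the `NVIDIA_TIMEOUT` environment variable and then to 60 seconds, the constant would have to be resolved roughly as below; this is an assumption about `utils.py`, not code from the PR:

```python
# Assumed resolution of the shared default timeout in utils.py:
# read NVIDIA_TIMEOUT from the environment, else fall back to 60 seconds.
import os

REQUEST_TIMEOUT = float(os.environ.get("NVIDIA_TIMEOUT", "60.0"))
```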