Skip to content

Commit a6a1e82

Browse files
fix: fix Nvidia types + add py.typed (#1970)
* draft
* client enum
* fix workflow
* improvements
* Apply suggestions from code review

Co-authored-by: Julian Risch <[email protected]>
* fixes
* model_type optional + NimBackend tests
* tests

---------

Co-authored-by: Julian Risch <[email protected]>
1 parent 8e03f3e commit a6a1e82

File tree

18 files changed

+238
-159
lines changed

18 files changed

+238
-159
lines changed

.github/workflows/nvidia.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,9 @@ jobs:
5151
- name: Install Hatch
5252
run: pip install --upgrade hatch
5353

54-
# TODO: Once this integration is properly typed, use hatch run test:types
55-
# https://github.com/deepset-ai/haystack-core-integrations/issues/1771
5654
- name: Lint
5755
if: matrix.python-version == '3.9' && runner.os == 'Linux'
58-
run: hatch run fmt-check && hatch run lint:typing
56+
run: hatch run fmt-check && hatch run test:types
5957

6058
- name: Generate docs
6159
if: matrix.python-version == '3.9' && runner.os == 'Linux'

integrations/nvidia/pyproject.toml

Lines changed: 9 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -68,18 +68,17 @@ integration = 'pytest -m "integration" {args:tests}'
6868
all = 'pytest {args:tests}'
6969
cov-retry = 'all --cov=haystack_integrations --reruns 3 --reruns-delay 30 -x'
7070

71-
types = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
71+
types = """mypy -p haystack_integrations.components.embedders.nvidia \
72+
-p haystack_integrations.components.generators.nvidia \
73+
-p haystack_integrations.components.rankers.nvidia \
74+
-p haystack_integrations.utils.nvidia {args}"""
7275

73-
# TODO: remove lint environment once this integration is properly typed
74-
# test environment should be used instead
75-
# https://github.com/deepset-ai/haystack-core-integrations/issues/1771
76-
[tool.hatch.envs.lint]
77-
installer = "uv"
78-
detached = true
79-
dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"]
76+
[tool.mypy]
77+
install_types = true
78+
non_interactive = true
79+
check_untyped_defs = true
80+
disallow_incomplete_defs = true
8081

81-
[tool.hatch.envs.lint.scripts]
82-
typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
8382

8483
[tool.black]
8584
target-version = ["py38"]
@@ -160,27 +159,9 @@ omit = ["*/tests/*", "*/__init__.py"]
160159
show_missing = true
161160
exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"]
162161

163-
164-
[[tool.mypy.overrides]]
165-
module = [
166-
"nvidia.*",
167-
"haystack.*",
168-
"haystack_integrations.*",
169-
"pytest.*",
170-
"numpy.*",
171-
"requests_mock.*",
172-
"openai.*",
173-
"pydantic.*",
174-
]
175-
ignore_missing_imports = true
176-
177162
[tool.pytest.ini_options]
178163
addopts = "--strict-markers"
179164
markers = [
180165
"integration: integration tests",
181-
"unit: unit tests",
182-
"embedders: embedders tests",
183-
"generators: generators tests",
184-
"chat_generators: chat_generators tests",
185166
]
186167
log_cli = true

integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from tqdm import tqdm
1212

1313
from haystack_integrations.components.embedders.nvidia.truncate import EmbeddingTruncateMode
14-
from haystack_integrations.utils.nvidia import DEFAULT_API_URL, Model, NimBackend, url_validation
14+
from haystack_integrations.utils.nvidia import DEFAULT_API_URL, Client, Model, NimBackend, url_validation
1515

1616
logger = logging.getLogger(__name__)
1717

@@ -122,7 +122,9 @@ def default_model(self):
122122
UserWarning,
123123
stacklevel=2,
124124
)
125-
self.model = self.backend.model = name
125+
self.model = name
126+
if self.backend:
127+
self.backend.model = name
126128
else:
127129
error_message = "No locally hosted model was found."
128130
raise ValueError(error_message)
@@ -143,7 +145,7 @@ def warm_up(self):
143145
api_url=self.api_url,
144146
api_key=self.api_key,
145147
model_kwargs=model_kwargs,
146-
client=self.__class__.__name__,
148+
client=Client.NVIDIA_DOCUMENT_EMBEDDER,
147149
timeout=self.timeout,
148150
)
149151
if not self.model and self.backend.model:
@@ -232,7 +234,7 @@ def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List
232234
return all_embeddings, {"usage": {"prompt_tokens": usage_prompt_tokens, "total_tokens": usage_total_tokens}}
233235

234236
@component.output_types(documents=List[Document], meta=Dict[str, Any])
235-
def run(self, documents: List[Document]):
237+
def run(self, documents: List[Document]) -> Dict[str, Union[List[Document], Dict[str, Any]]]:
236238
"""
237239
Embed a list of Documents.
238240

integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from haystack.utils import Secret, deserialize_secrets_inplace
1111

1212
from haystack_integrations.components.embedders.nvidia.truncate import EmbeddingTruncateMode
13-
from haystack_integrations.utils.nvidia import DEFAULT_API_URL, Model, NimBackend, url_validation
13+
from haystack_integrations.utils.nvidia import DEFAULT_API_URL, Client, Model, NimBackend, url_validation
1414

1515
logger = logging.getLogger(__name__)
1616

@@ -112,7 +112,9 @@ def default_model(self):
112112
UserWarning,
113113
stacklevel=2,
114114
)
115-
self.model = self.backend.model = name
115+
self.model = name
116+
if self.backend:
117+
self.backend.model = name
116118
else:
117119
error_message = "No locally hosted model was found."
118120
raise ValueError(error_message)
@@ -134,7 +136,7 @@ def warm_up(self):
134136
api_key=self.api_key,
135137
model_kwargs=model_kwargs,
136138
timeout=self.timeout,
137-
client=self.__class__.__name__,
139+
client=Client.NVIDIA_TEXT_EMBEDDER,
138140
)
139141
self._initialized = True
140142

@@ -185,7 +187,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "NvidiaTextEmbedder":
185187
return default_from_dict(cls, data)
186188

187189
@component.output_types(embedding=List[float], meta=Dict[str, Any])
188-
def run(self, text: str):
190+
def run(self, text: str) -> Dict[str, Union[List[float], Dict[str, Any]]]:
189191
"""
190192
Embed a string.
191193

integrations/nvidia/src/haystack_integrations/components/embedders/py.typed

Whitespace-only changes.

integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44

55
import os
66
import warnings
7-
from typing import Any, Dict, List, Optional
7+
from typing import Any, Dict, List, Optional, Union
88

99
from haystack import component, default_from_dict, default_to_dict
1010
from haystack.utils.auth import Secret, deserialize_secrets_inplace
1111

12-
from haystack_integrations.utils.nvidia import DEFAULT_API_URL, Model, NimBackend, is_hosted, url_validation
12+
from haystack_integrations.utils.nvidia import DEFAULT_API_URL, Client, Model, NimBackend, is_hosted, url_validation
1313

1414

1515
@component
@@ -104,7 +104,9 @@ def default_model(self):
104104
UserWarning,
105105
stacklevel=2,
106106
)
107-
self._model = self.backend.model = name
107+
self._model = name
108+
if self.backend:
109+
self.backend.model = name
108110
else:
109111
error_message = "No locally hosted model was found."
110112
raise ValueError(error_message)
@@ -123,7 +125,7 @@ def warm_up(self):
123125
api_key=self._api_key,
124126
model_kwargs=self._model_arguments,
125127
timeout=self.timeout,
126-
client=self.__class__.__name__,
128+
client=Client.NVIDIA_GENERATOR,
127129
)
128130

129131
if not self.is_hosted and not self._model:
@@ -169,7 +171,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "NvidiaGenerator":
169171
return default_from_dict(cls, data)
170172

171173
@component.output_types(replies=List[str], meta=List[Dict[str, Any]])
172-
def run(self, prompt: str):
174+
def run(self, prompt: str) -> Dict[str, Union[List[str], List[Dict[str, Any]]]]:
173175
"""
174176
Queries the model with the provided prompt.
175177

integrations/nvidia/src/haystack_integrations/components/generators/py.typed

Whitespace-only changes.

integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/py.typed

Whitespace-only changes.

integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from haystack.utils import Secret, deserialize_secrets_inplace
1111

1212
from haystack_integrations.components.rankers.nvidia.truncate import RankerTruncateMode
13-
from haystack_integrations.utils.nvidia import DEFAULT_API_URL, NimBackend, is_hosted, url_validation
13+
from haystack_integrations.utils.nvidia import DEFAULT_API_URL, Client, NimBackend, is_hosted, url_validation
1414

1515
logger = logging.getLogger(__name__)
1616

@@ -162,7 +162,7 @@ def warm_up(self):
162162
:raises ValueError: If the API key is required for hosted NVIDIA NIMs.
163163
"""
164164
if not self._initialized:
165-
model_kwargs = {}
165+
model_kwargs: Dict[str, Any] = {}
166166
if self.truncate is not None:
167167
model_kwargs.update(truncate=str(self.truncate))
168168
self.backend = NimBackend(
@@ -172,9 +172,9 @@ def warm_up(self):
172172
api_key=self.api_key,
173173
model_kwargs=model_kwargs,
174174
timeout=self.timeout,
175-
client=self.__class__.__name__,
175+
client=Client.NVIDIA_RANKER,
176176
)
177-
if not self.is_hosted and not self._model:
177+
if not self.is_hosted and not self.model:
178178
if self.backend.model:
179179
self.model = self.backend.model
180180
self._initialized = True

integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
from .client import Client
56
from .models import DEFAULT_API_URL, Model
67
from .nim_backend import NimBackend
78
from .utils import is_hosted, url_validation
89

9-
__all__ = ["DEFAULT_API_URL", "Model", "NimBackend", "is_hosted", "url_validation", "validate_hosted_model"]
10+
# NOTE(review): "validate_hosted_model" was listed in __all__ but is not imported
# anywhere in this module's visible imports, so `from haystack_integrations.utils.nvidia import *`
# would raise AttributeError on that name. Dropped it from the public API —
# confirm no caller relies on re-exporting it from this package.
__all__ = ["DEFAULT_API_URL", "Client", "Model", "NimBackend", "is_hosted", "url_validation"]
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from enum import Enum
2+
3+
4+
class Client(Enum):
    """
    Client to use for NVIDIA NIMs.

    The enum value is the client's class name (e.g. ``"NvidiaGenerator"``),
    which is sent to the backend to identify the calling component.
    """

    NVIDIA_GENERATOR = "NvidiaGenerator"
    NVIDIA_TEXT_EMBEDDER = "NvidiaTextEmbedder"
    NVIDIA_DOCUMENT_EMBEDDER = "NvidiaDocumentEmbedder"
    NVIDIA_RANKER = "NvidiaRanker"

    def __str__(self) -> str:
        """Convert a Client enum to a string."""
        return self.value

    @staticmethod
    def from_str(string: str) -> "Client":
        """Convert a string to a Client enum.

        :param string: the client name, e.g. ``"NvidiaGenerator"``.
        :returns: the matching ``Client`` member.
        :raises ValueError: if ``string`` does not match any known client.
        """
        try:
            # Enum value lookup replaces the hand-rolled value->member dict:
            # Client("NvidiaGenerator") resolves directly to the member.
            return Client(string)
        except ValueError:
            supported = [e.value for e in Client]
            msg = f"Unknown client '{string}' to use for NVIDIA NIMs. Supported modes are: {supported}"
            # `from None` keeps the traceback clean of the internal lookup failure.
            raise ValueError(msg) from None

0 commit comments

Comments
 (0)