Skip to content

Commit 91a71b1

Browse files
authored
fix: fix Optimum types + add py.typed (#2018)
* fix: fix Optimum types + add py.typed * check types with 3.10 * with 3.13
1 parent 6ca5751 commit 91a71b1

File tree

6 files changed

+23
-35
lines changed

6 files changed

+23
-35
lines changed

.github/workflows/optimum.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,11 @@ jobs:
4949
- name: Install Hatch
5050
run: pip install --upgrade hatch
5151

52-
# TODO: Once this integration is properly typed, use hatch run test:types
53-
# https://github.com/deepset-ai/haystack-core-integrations/issues/1771
5452
- name: Lint
55-
if: matrix.python-version == '3.9' && runner.os == 'Linux'
56-
run: hatch run fmt-check && hatch run lint:typing
53+
# we check types with python 3.13 because with 3.9, the installation of some type stubs fails
54+
# due to incompatibilities
55+
if: matrix.python-version == '3.13' && runner.os == 'Linux'
56+
run: hatch run fmt-check && hatch run test:types
5757

5858
- name: Generate docs
5959
if: matrix.python-version == '3.9' && runner.os == 'Linux'

integrations/optimum/pyproject.toml

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -79,18 +79,17 @@ integration = 'pytest -m "integration" {args:tests}'
7979
all = 'pytest {args:tests}'
8080
cov-retry = 'all --cov=haystack_integrations --reruns 3 --reruns-delay 30 -x'
8181

82-
types = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
82+
types = "mypy -p haystack_integrations.components.embedders.optimum {args}"
8383

84-
# TODO: remove lint environment once this integration is properly typed
85-
# test environment should be used instead
86-
# https://github.com/deepset-ai/haystack-core-integrations/issues/1771
87-
[tool.hatch.envs.lint]
88-
installer = "uv"
89-
detached = true
90-
dependencies = ["pip", "mypy>=1.0.0", "ruff>=0.0.243"]
84+
[tool.mypy]
85+
install_types = true
86+
non_interactive = true
87+
check_untyped_defs = true
88+
disallow_incomplete_defs = true
9189

92-
[tool.hatch.envs.lint.scripts]
93-
typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
90+
[[tool.mypy.overrides]]
91+
module = "optimum.*"
92+
ignore_missing_imports = true
9493

9594
[tool.hatch.metadata]
9695
allow-direct-references = true
@@ -173,26 +172,9 @@ omit = ["*/tests/*", "*/__init__.py"]
173172
show_missing = true
174173
exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"]
175174

176-
177-
[[tool.mypy.overrides]]
178-
module = [
179-
"haystack.*",
180-
"haystack_integrations.*",
181-
"pytest.*",
182-
"numpy.*",
183-
"optimum.*",
184-
"torch.*",
185-
"transformers.*",
186-
"huggingface_hub.*",
187-
"sentence_transformers.*",
188-
]
189-
ignore_missing_imports = true
190-
191175
[tool.pytest.ini_options]
192176
addopts = ["--strict-markers", "-vv"]
193177
markers = [
194178
"integration: integration tests",
195-
"unit: unit tests",
196-
"embedders: embedders tests",
197179
]
198180
log_cli = true

integrations/optimum/src/haystack_integrations/components/embedders/optimum/_backend.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import json
33
from dataclasses import dataclass
44
from pathlib import Path
5-
from typing import Any, Dict, List, Optional, Tuple, Union
5+
from typing import Any, Dict, List, Optional, Tuple, Union, overload
66

77
import numpy as np
88
import torch
@@ -111,7 +111,7 @@ def __init__(self, params: _EmbedderParams):
111111
self.params = params
112112
self.model = None
113113
self.tokenizer = None
114-
self.pooling_layer = None
114+
self.pooling_layer: Optional[SentenceTransformerPoolingLayer] = None
115115

116116
def warm_up(self):
117117
assert self.params.model_kwargs
@@ -188,6 +188,12 @@ def pool_embeddings(self, model_output: torch.Tensor, attention_mask: torch.Tens
188188
pooled_outputs = self.pooling_layer.forward(features)
189189
return pooled_outputs["sentence_embedding"]
190190

191+
@overload
192+
def embed_texts(self, texts_to_embed: str) -> List[float]: ...
193+
194+
@overload
195+
def embed_texts(self, texts_to_embed: List[str]) -> List[List[float]]: ...
196+
191197
def embed_texts(
192198
self,
193199
texts_to_embed: Union[str, List[str]],

integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_document_embedder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
191191
return texts_to_embed
192192

193193
@component.output_types(documents=List[Document])
194-
def run(self, documents: List[Document]):
194+
def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
195195
"""
196196
Embed a list of Documents.
197197
The embedding of each Document is stored in the `embedding` field of the Document.

integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_text_embedder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "OptimumTextEmbedder":
154154
return default_from_dict(cls, data)
155155

156156
@component.output_types(embedding=List[float])
157-
def run(self, text: str):
157+
def run(self, text: str) -> Dict[str, List[float]]:
158158
"""
159159
Embed a string.
160160

integrations/optimum/src/haystack_integrations/components/embedders/py.typed

Whitespace-only changes.

0 commit comments

Comments
 (0)