Skip to content

Commit 952e2a6

Browse files
authored
fix: fix Sagemaker types + add py.typed (#2023)
* fix: fix Sagemaker types + add py.typed * improve type hint
1 parent fc367f6 commit 952e2a6

File tree

4 files changed

+20
-31
lines changed

4 files changed

+20
-31
lines changed

.github/workflows/amazon_sagemaker.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,9 @@ jobs:
4949
- name: Install Hatch
5050
run: pip install --upgrade hatch
5151

52-
# TODO: Once this integration is properly typed, use hatch run test:types
53-
# https://github.com/deepset-ai/haystack-core-integrations/issues/1771
5452
- name: Lint
5553
if: matrix.python-version == '3.9' && runner.os == 'Linux'
56-
run: hatch run fmt-check && hatch run lint:typing
54+
run: hatch run fmt-check && hatch run test:types
5755

5856
- name: Generate docs
5957
if: matrix.python-version == '3.9' && runner.os == 'Linux'

integrations/amazon_sagemaker/pyproject.toml

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,20 @@ unit = 'pytest -m "not integration" {args:tests}'
6767
integration = 'pytest -m "integration" {args:tests}'
6868
all = 'pytest {args:tests}'
6969
cov-retry = 'all --cov=haystack_integrations --reruns 3 --reruns-delay 30 -x'
70-
types = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
70+
types = "mypy -p haystack_integrations.components.generators.amazon_sagemaker {args}"
7171

72-
# TODO: remove lint environment once this integration is properly typed
73-
# test environment should be used instead
74-
# https://github.com/deepset-ai/haystack-core-integrations/issues/1771
75-
[tool.hatch.envs.lint]
76-
installer = "uv"
77-
detached = true
78-
dependencies = ["pip", "mypy>=1.0.0", "ruff>=0.0.243"]
79-
[tool.hatch.envs.lint.scripts]
80-
typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
72+
[tool.mypy]
73+
install_types = true
74+
non_interactive = true
75+
check_untyped_defs = true
76+
disallow_incomplete_defs = true
77+
78+
[[tool.mypy.overrides]]
79+
module = [
80+
"botocore.*",
81+
"boto3.*",
82+
]
83+
ignore_missing_imports = true
8184

8285
[tool.ruff]
8386
target-version = "py38"
@@ -154,24 +157,10 @@ omit = ["*/tests/*", "*/__init__.py"]
154157
show_missing = true
155158
exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"]
156159

157-
[[tool.mypy.overrides]]
158-
module = [
159-
"botocore.*",
160-
"boto3.*",
161-
"haystack.*",
162-
"haystack_integrations.*",
163-
"pytest.*",
164-
"numpy.*",
165-
]
166-
ignore_missing_imports = true
167-
168160

169161
[tool.pytest.ini_options]
170162
addopts = "--strict-markers"
171163
markers = [
172-
"unit: unit tests",
173164
"integration: integration tests",
174-
"embedders: embedders tests",
175-
"generators: generators tests",
176165
]
177166
log_cli = true

integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import Any, ClassVar, Dict, List, Optional
2+
from typing import Any, ClassVar, Dict, List, Optional, Union
33

44
import boto3
55
import requests
@@ -142,7 +142,7 @@ def to_dict(self) -> Dict[str, Any]:
142142
)
143143

144144
@classmethod
145-
def from_dict(cls, data) -> "SagemakerGenerator":
145+
def from_dict(cls, data: Dict[str, Any]) -> "SagemakerGenerator":
146146
"""
147147
Deserializes the component from a dictionary.
148148
@@ -164,7 +164,7 @@ def _get_aws_session(
164164
aws_session_token: Optional[str] = None,
165165
aws_region_name: Optional[str] = None,
166166
aws_profile_name: Optional[str] = None,
167-
):
167+
) -> boto3.Session:
168168
"""
169169
Creates an AWS Session with the given parameters.
170170
@@ -192,7 +192,9 @@ def _get_aws_session(
192192
raise AWSConfigurationError(msg) from e
193193

194194
@component.output_types(replies=List[str], meta=List[Dict[str, Any]])
195-
def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
195+
def run(
196+
self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None
197+
) -> Dict[str, Union[List[str], List[Dict[str, Any]]]]:
196198
"""
197199
Invoke the text generation inference based on the provided prompt and generation parameters.
198200

integrations/amazon_sagemaker/src/haystack_integrations/components/generators/py.typed

Whitespace-only changes.

0 commit comments

Comments
 (0)