Skip to content

Commit eaf69c9

Browse files
committed
Fix hatch run test:types
1 parent 9f88135 commit eaf69c9

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import json
77
import os
88
from datetime import datetime, timezone
9-
from typing import Any
9+
from typing import Any, Dict, List, Optional, Union
1010

1111
from haystack import component, default_from_dict, default_to_dict, logging
1212
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
@@ -177,7 +177,9 @@ def from_dict(cls, data: dict[str, Any]) -> WatsonxChatGenerator:
177177
return default_from_dict(cls, data)
178178

179179
@component.output_types(replies=list[ChatMessage])
180-
def run(self, messages: list[ChatMessage], generation_kwargs: dict[str, Any] | None = None, stream: bool = False):
180+
def run(
181+
self, messages: list[ChatMessage], generation_kwargs: dict[str, Any] | None = None, stream: bool = False
182+
) -> Dict[str, List[ChatMessage]]:
181183
"""
182184
Generate chat completions synchronously.
183185
@@ -198,7 +200,7 @@ def run(self, messages: list[ChatMessage], generation_kwargs: dict[str, Any] | N
198200
@component.output_types(replies=list[ChatMessage])
199201
async def run_async(
200202
self, messages: list[ChatMessage], generation_kwargs: dict[str, Any] | None = None, stream: bool = False
201-
):
203+
) -> Dict[str, List[ChatMessage]]:
202204
"""
203205
Generate chat completions asynchronously.
204206
@@ -222,6 +224,7 @@ def _prepare_api_call(
222224
merged_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
223225

224226
watsonx_messages = []
227+
content: Union[Optional[str], Dict[str, Any]]
225228
for msg in messages:
226229
if msg.is_from("user"):
227230
content = msg.text
@@ -347,7 +350,10 @@ def _process_response(self, response: dict[str, Any]) -> dict[str, Any]:
347350
ToolCall(id=tc.get("id"), tool_name=tc["function"]["name"], arguments=arguments)
348351
)
349352
except json.JSONDecodeError:
350-
logger.warning("Failed to parse tool call arguments: %s", tc["function"]["arguments"])
353+
logger.warning(
354+
"Failed to parse tool call arguments: {tool_call_args}",
355+
tool_call_args=tc["function"]["arguments"]
356+
)
351357

352358
return {
353359
"replies": [

integrations/watsonx/src/haystack_integrations/components/generators/watsonx/generator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44
from __future__ import annotations
55

6-
from typing import Any
6+
from typing import Any, Dict, List, Union
77

88
from haystack import component, default_from_dict, default_to_dict, logging
99
from haystack.dataclasses import StreamingChunk
@@ -181,7 +181,7 @@ def run(
181181
generation_kwargs: dict[str, Any] | None = None,
182182
guardrails: bool = False,
183183
stream: bool = False,
184-
):
184+
) -> Dict[str, Union[List[str], List[Dict[str, Any]], List[StreamingChunk]]]:
185185
"""
186186
Generate text using the watsonx.ai model.
187187

0 commit comments

Comments
 (0)