
Commit 1f25794

chore: fix Hugging Face components for mypy 1.15.0 (#8822)
* chore: fix Hugging Face components for mypy 1.15.0
* small fixes
* fix test
* rm print
* use cast and be more permissive
1 parent e7c6d14 commit 1f25794

File tree

4 files changed: +26 -13 lines changed


haystack/components/generators/chat/hugging_face_api.py

Lines changed: 18 additions & 7 deletions
@@ -15,6 +15,7 @@
 
 with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
     from huggingface_hub import (
+        ChatCompletionInputFunctionDefinition,
         ChatCompletionInputTool,
         ChatCompletionOutput,
         ChatCompletionStreamOutput,
@@ -255,8 +256,15 @@ def run(
 
         hf_tools = None
         if tools:
-            hf_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools]
-
+            hf_tools = [
+                ChatCompletionInputTool(
+                    function=ChatCompletionInputFunctionDefinition(
+                        name=tool.name, description=tool.description, arguments=tool.parameters
+                    ),
+                    type="function",
+                )
+                for tool in tools
+            ]
         return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)
 
     def _run_streaming(
@@ -278,13 +286,12 @@ def _run_streaming(
             # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
             choice = chunk.choices[0]
 
-            text = choice.delta.content
-            if text:
-                generated_text += text
+            text = choice.delta.content or ""
+            generated_text += text
 
             finish_reason = choice.finish_reason
 
-            meta = {}
+            meta: Dict[str, Any] = {}
             if finish_reason:
                 meta["finish_reason"] = finish_reason
 
@@ -336,7 +343,11 @@ def _run_non_streaming(
             )
             tool_calls.append(tool_call)
 
-        meta = {"model": self._client.model, "finish_reason": choice.finish_reason, "index": choice.index}
+        meta: Dict[str, Any] = {
+            "model": self._client.model,
+            "finish_reason": choice.finish_reason,
+            "index": choice.index,
+        }
 
         usage = {"prompt_tokens": 0, "completion_tokens": 0}
         if api_chat_output.usage:
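The two `meta` changes above address the same kind of mypy complaint: from `meta = {}` mypy cannot infer a value type at all, and from a dict literal it infers a type from the values it sees, so later heterogeneous entries or call sites can fail to type-check. A minimal sketch of the annotated pattern, with illustrative names rather than the Haystack code:

from typing import Any, Dict

def build_meta(finish_reason: str, index: int) -> Dict[str, Any]:
    # Without the annotation, mypy reports something like:
    #   Need type annotation for "meta" (hint: "meta: Dict[<type>, <type>] = ...")
    # because an empty dict literal gives it nothing to infer from.
    meta: Dict[str, Any] = {}
    meta["finish_reason"] = finish_reason  # str value
    meta["index"] = index                  # int value -- fine under Dict[str, Any]
    return meta

The explicit `Dict[str, Any]` trades checking precision for flexibility, which matches the "be more permissive" note in the commit body.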

haystack/components/generators/chat/hugging_face_local.py

Lines changed: 1 addition & 1 deletion
@@ -145,7 +145,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
         elif isinstance(huggingface_pipeline_kwargs["model"], str):
             task = model_info(
                 huggingface_pipeline_kwargs["model"], token=huggingface_pipeline_kwargs["token"]
-            ).pipeline_tag
+            ).pipeline_tag  # type: ignore[assignment] # we'll check below if task is in supported tasks
 
         if task not in PIPELINE_SUPPORTED_TASKS:
             raise ValueError(
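The ignore here is scoped to a single mypy error code. `pipeline_tag` on `huggingface_hub`'s model info is optional, so assigning it to a variable mypy has already typed as `str` plausibly trips the `assignment` check; the runtime membership test against `PIPELINE_SUPPORTED_TASKS` rejects a `None` anyway. A hedged sketch of the same pattern, with a stand-in for the upstream attribute:

from typing import Optional

def pipeline_tag_of(model: str) -> Optional[str]:
    # Stand-in for model_info(model).pipeline_tag, which may be None.
    return "text-generation" if model else None

task: str = "text2text-generation"
# mypy: Incompatible types in assignment (expression has type "Optional[str]",
# variable has type "str") [assignment] -- suppressed for this line only:
task = pipeline_tag_of("some-model")  # type: ignore[assignment]
assert task in {"text-generation", "text2text-generation"}  # runtime check catches None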

haystack/components/generators/hugging_face_api.py

Lines changed: 5 additions & 4 deletions
@@ -4,7 +4,7 @@
 
 from dataclasses import asdict
 from datetime import datetime
-from typing import Any, Callable, Dict, Iterable, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Union, cast
 
 from haystack import component, default_from_dict, default_to_dict, logging
 from haystack.dataclasses import StreamingChunk
@@ -17,8 +17,8 @@
     from huggingface_hub import (
         InferenceClient,
         TextGenerationOutput,
-        TextGenerationOutputToken,
         TextGenerationStreamOutput,
+        TextGenerationStreamOutputToken,
     )
 
 
@@ -212,7 +212,8 @@ def run(
         if streaming_callback is not None:
             return self._stream_and_build_response(hf_output, streaming_callback)
 
-        return self._build_non_streaming_response(hf_output)
+        # mypy doesn't know that hf_output is a TextGenerationOutput, so we cast it
+        return self._build_non_streaming_response(cast(TextGenerationOutput, hf_output))
 
     def _stream_and_build_response(
         self, hf_output: Iterable["TextGenerationStreamOutput"], streaming_callback: Callable[[StreamingChunk], None]
@@ -221,7 +222,7 @@ def _stream_and_build_response(
         first_chunk_time = None
 
         for chunk in hf_output:
-            token: TextGenerationOutputToken = chunk.token
+            token: TextGenerationStreamOutputToken = chunk.token
             if token.special:
                 continue
 
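The `cast` above narrows a union that mypy cannot resolve on its own: the client call returns either a full generation output or an iterable of stream chunks depending on a `stream` flag it only sees at runtime, and by this point the streaming branch has already returned. A self-contained sketch with stand-in classes, not the real `huggingface_hub` types:

from typing import Iterable, Union, cast

class Output:
    generated_text = "hello"

class StreamChunk:
    token = "h"

def generate(stream: bool) -> Union[Output, Iterable[StreamChunk]]:
    # Stand-in for a client call whose return type depends on `stream`.
    return iter([StreamChunk()]) if stream else Output()

out = generate(stream=False)
# We know stream=False yields an Output, but mypy still sees the Union.
# cast() is a no-op at runtime; it only narrows the type for the checker.
print(cast(Output, out).generated_text)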

test/components/generators/test_hugging_face_api.py

Lines changed: 2 additions & 1 deletion
@@ -7,6 +7,7 @@
 
 import pytest
 from huggingface_hub import (
+    TextGenerationOutput,
     TextGenerationOutputToken,
     TextGenerationStreamOutput,
     TextGenerationStreamOutputStreamDetails,
@@ -30,7 +31,7 @@ def mock_check_valid_model():
 @pytest.fixture
 def mock_text_generation():
     with patch("huggingface_hub.InferenceClient.text_generation", autospec=True) as mock_text_generation:
-        mock_response = Mock()
+        mock_response = Mock(spec=TextGenerationOutput)
         mock_response.generated_text = "I'm fine, thanks."
         details = Mock()
         details.finish_reason = MagicMock(field1="value")
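Giving the mock a `spec` makes the fixture stricter without changing what it asserts: accessing an attribute the spec'd class does not define raises `AttributeError`, and the mock passes `isinstance` checks, which keeps it compatible with the `cast`-based typing introduced above. A small demonstration with a stand-in class rather than the real `TextGenerationOutput`:

from unittest.mock import Mock

class Response:
    generated_text = ""

mock_response = Mock(spec=Response)
mock_response.generated_text = "I'm fine, thanks."

assert isinstance(mock_response, Response)  # spec'd mocks satisfy isinstance

try:
    mock_response.not_an_attribute  # not defined on Response
except AttributeError:
    print("spec rejects attributes the real class lacks")

Note that `spec` restricts attribute access only; `spec_set` would additionally restrict attribute assignment.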
