
Commit 2bc58d2

feat: support for tools in HuggingFaceAPIChatGenerator (#8661)
* message conversion function
* hfapi w tools
* right test file + hf_hub version
* release note
* feedback
1 parent c306bee commit 2bc58d2

File tree: 11 files changed, +509 −84 lines changed


haystack/components/embedders/hugging_face_api_document_embedder.py

+1-1
@@ -14,7 +14,7 @@
 from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model
 from haystack.utils.url_validation import is_valid_http_url
 
-with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import:
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
     from huggingface_hub import InferenceClient
 
 logger = logging.getLogger(__name__)

haystack/components/embedders/hugging_face_api_text_embedder.py

+1-1
@@ -11,7 +11,7 @@
 from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model
 from haystack.utils.url_validation import is_valid_http_url
 
-with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import:
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
     from huggingface_hub import InferenceClient
 
 logger = logging.getLogger(__name__)

haystack/components/generators/chat/hugging_face_api.py

+107-45
@@ -5,30 +5,25 @@
 from typing import Any, Callable, Dict, Iterable, List, Optional, Union
 
 from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.dataclasses import ChatMessage, StreamingChunk
+from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
+from haystack.dataclasses.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
 from haystack.lazy_imports import LazyImport
 from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
-from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model
+from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format
 from haystack.utils.url_validation import is_valid_http_url
 
-with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.23.0\"'") as huggingface_hub_import:
-    from huggingface_hub import ChatCompletionOutput, ChatCompletionStreamOutput, InferenceClient
+with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
+    from huggingface_hub import (
+        ChatCompletionInputTool,
+        ChatCompletionOutput,
+        ChatCompletionStreamOutput,
+        InferenceClient,
+    )
 
 
 logger = logging.getLogger(__name__)
 
 
-def _convert_message_to_hfapi_format(message: ChatMessage) -> Dict[str, str]:
-    """
-    Convert a message to the format expected by Hugging Face APIs.
-
-    :returns: A dictionary with the following keys:
-        - `role`
-        - `content`
-    """
-    return {"role": message.role.value, "content": message.text or ""}
-
-
 @component
 class HuggingFaceAPIChatGenerator:
     """
@@ -107,6 +102,7 @@ def __init__(  # pylint: disable=too-many-positional-arguments
         generation_kwargs: Optional[Dict[str, Any]] = None,
         stop_words: Optional[List[str]] = None,
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        tools: Optional[List[Tool]] = None,
     ):
         """
         Initialize the HuggingFaceAPIChatGenerator instance.
@@ -121,14 +117,22 @@ def __init__(  # pylint: disable=too-many-positional-arguments
             - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
             - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
             `TEXT_GENERATION_INFERENCE`.
-        :param token: The Hugging Face token to use as HTTP bearer authorization.
+        :param token:
+            The Hugging Face token to use as HTTP bearer authorization.
             Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
         :param generation_kwargs:
             A dictionary with keyword arguments to customize text generation.
             Some examples: `max_tokens`, `temperature`, `top_p`.
             For details, see [Hugging Face chat_completion documentation](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion).
-        :param stop_words: An optional list of strings representing the stop words.
-        :param streaming_callback: An optional callable for handling streaming responses.
+        :param stop_words:
+            An optional list of strings representing the stop words.
+        :param streaming_callback:
+            An optional callable for handling streaming responses.
+        :param tools:
+            A list of tools for which the model can prepare calls.
+            The chosen model should support tool/function calling, according to the model card.
+            Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience
+            unexpected behavior.
         """
 
         huggingface_hub_import.check()
@@ -159,6 +163,11 @@ def __init__(  # pylint: disable=too-many-positional-arguments
             msg = f"Unknown api_type {api_type}"
             raise ValueError(msg)
 
+        if tools:
+            if streaming_callback is not None:
+                raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
+            _check_duplicate_tool_names(tools)
+
         # handle generation kwargs setup
         generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
         generation_kwargs["stop"] = generation_kwargs.get("stop", [])
@@ -171,6 +180,7 @@ def __init__(  # pylint: disable=too-many-positional-arguments
         self.generation_kwargs = generation_kwargs
         self.streaming_callback = streaming_callback
         self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)
+        self.tools = tools
 
     def to_dict(self) -> Dict[str, Any]:
         """
@@ -180,13 +190,15 @@ def to_dict(self) -> Dict[str, Any]:
             A dictionary containing the serialized component.
         """
         callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
+        serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
         return default_to_dict(
             self,
             api_type=str(self.api_type),
             api_params=self.api_params,
             token=self.token.to_dict() if self.token else None,
             generation_kwargs=self.generation_kwargs,
             streaming_callback=callback_name,
+            tools=serialized_tools,
         )
 
     @classmethod
@@ -195,32 +207,53 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIChatGenerator":
         Deserialize this component from a dictionary.
         """
         deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
+        deserialize_tools_inplace(data["init_parameters"], key="tools")
         init_params = data.get("init_parameters", {})
         serialized_callback_handler = init_params.get("streaming_callback")
         if serialized_callback_handler:
             data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
         return default_from_dict(cls, data)
 
     @component.output_types(replies=List[ChatMessage])
-    def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
+    def run(
+        self,
+        messages: List[ChatMessage],
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+        tools: Optional[List[Tool]] = None,
+    ):
         """
         Invoke the text generation inference based on the provided messages and generation parameters.
 
-        :param messages: A list of ChatMessage objects representing the input messages.
-        :param generation_kwargs: Additional keyword arguments for text generation.
+        :param messages:
+            A list of ChatMessage objects representing the input messages.
+        :param generation_kwargs:
+            Additional keyword arguments for text generation.
+        :param tools:
+            A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
+            during component initialization.
         :returns: A dictionary with the following keys:
             - `replies`: A list containing the generated responses as ChatMessage objects.
         """
 
         # update generation kwargs by merging with the default ones
         generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
 
-        formatted_messages = [_convert_message_to_hfapi_format(message) for message in messages]
+        formatted_messages = [convert_message_to_hf_format(message) for message in messages]
+
+        tools = tools or self.tools
+        if tools:
+            if self.streaming_callback:
+                raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
+            _check_duplicate_tool_names(tools)
 
         if self.streaming_callback:
             return self._run_streaming(formatted_messages, generation_kwargs)
 
-        return self._run_non_streaming(formatted_messages, generation_kwargs)
+        hf_tools = None
+        if tools:
+            hf_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools]
+
+        return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)
 
     def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]):
         api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion(
@@ -229,11 +262,17 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict
 
         generated_text = ""
 
-        for chunk in api_output:  # pylint: disable=not-an-iterable
-            text = chunk.choices[0].delta.content
+        for chunk in api_output:
+            # n is unused, so the API always returns only one choice
+            # the argument is probably allowed for compatibility with OpenAI
+            # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
+            choice = chunk.choices[0]
+
+            text = choice.delta.content
             if text:
                 generated_text += text
-            finish_reason = chunk.choices[0].finish_reason
+
+            finish_reason = choice.finish_reason
 
             meta = {}
             if finish_reason:
@@ -242,33 +281,56 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict
             stream_chunk = StreamingChunk(text, meta)
             self.streaming_callback(stream_chunk)  # type: ignore # streaming_callback is not None (verified in the run method)
 
-        message = ChatMessage.from_assistant(generated_text)
-        message.meta.update(
+        meta.update(
             {
                 "model": self._client.model,
                 "finish_reason": finish_reason,
                 "index": 0,
                 "usage": {"prompt_tokens": 0, "completion_tokens": 0},  # not available in streaming
             }
         )
+
+        message = ChatMessage.from_assistant(text=generated_text, meta=meta)
+
         return {"replies": [message]}
 
     def _run_non_streaming(
-        self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]
+        self,
+        messages: List[Dict[str, str]],
+        generation_kwargs: Dict[str, Any],
+        tools: Optional[List["ChatCompletionInputTool"]] = None,
     ) -> Dict[str, List[ChatMessage]]:
-        chat_messages: List[ChatMessage] = []
-
-        api_chat_output: ChatCompletionOutput = self._client.chat_completion(messages, **generation_kwargs)
-        for choice in api_chat_output.choices:
-            message = ChatMessage.from_assistant(choice.message.content)
-            message.meta.update(
-                {
-                    "model": self._client.model,
-                    "finish_reason": choice.finish_reason,
-                    "index": choice.index,
-                    "usage": api_chat_output.usage or {"prompt_tokens": 0, "completion_tokens": 0},
-                }
-            )
-            chat_messages.append(message)
-
-        return {"replies": chat_messages}
+        api_chat_output: ChatCompletionOutput = self._client.chat_completion(
+            messages=messages, tools=tools, **generation_kwargs
+        )
+
+        if len(api_chat_output.choices) == 0:
+            return {"replies": []}
+
+        # n is unused, so the API always returns only one choice
+        # the argument is probably allowed for compatibility with OpenAI
+        # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
+        choice = api_chat_output.choices[0]
+
+        text = choice.message.content
+        tool_calls = []
+
+        if hfapi_tool_calls := choice.message.tool_calls:
+            for hfapi_tc in hfapi_tool_calls:
+                tool_call = ToolCall(
+                    tool_name=hfapi_tc.function.name, arguments=hfapi_tc.function.arguments, id=hfapi_tc.id
+                )
+                tool_calls.append(tool_call)
+
+        meta = {"model": self._client.model, "finish_reason": choice.finish_reason, "index": choice.index}
+
+        usage = {"prompt_tokens": 0, "completion_tokens": 0}
+        if api_chat_output.usage:
+            usage = {
+                "prompt_tokens": api_chat_output.usage.prompt_tokens,
+                "completion_tokens": api_chat_output.usage.completion_tokens,
+            }
+        meta["usage"] = usage
+
+        message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
+        return {"replies": [message]}

haystack/components/generators/hugging_face_api.py

+1-1
@@ -12,7 +12,7 @@
 from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model
 from haystack.utils.url_validation import is_valid_http_url
 
-with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import:
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
     from huggingface_hub import (
         InferenceClient,
         TextGenerationOutput,

haystack/dataclasses/tool.py

+14-1
@@ -4,7 +4,7 @@
 
 import inspect
 from dataclasses import asdict, dataclass
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 from pydantic import create_model
 
@@ -216,6 +216,19 @@ def _remove_title_from_schema(schema: Dict[str, Any]):
                 del property_schema[key]
 
 
+def _check_duplicate_tool_names(tools: List[Tool]) -> None:
+    """
+    Check for duplicate tool names and raises a ValueError if they are found.
+
+    :param tools: The list of tools to check.
+    :raises ValueError: If duplicate tool names are found.
+    """
+    tool_names = [tool.name for tool in tools]
+    duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1}
+    if duplicate_tool_names:
+        raise ValueError(f"Duplicate tool names found: {duplicate_tool_names}")
+
+
 def deserialize_tools_inplace(data: Dict[str, Any], key: str = "tools"):
     """
     Deserialize Tools in a dictionary inplace.
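
The new helper guards against colliding tool names. A small sketch of the failure mode it catches; the tool definitions here are illustrative:

    from haystack.dataclasses.tool import Tool, _check_duplicate_tool_names

    schema = {"type": "object", "properties": {"a": {"type": "integer"}}}
    t1 = Tool(name="add", description="first", parameters=schema, function=lambda a: a)
    t2 = Tool(name="add", description="second", parameters=schema, function=lambda a: -a)

    _check_duplicate_tool_names([t1, t2])
    # ValueError: Duplicate tool names found: {'add'}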

haystack/utils/hf.py

+40-2
@@ -8,15 +8,15 @@
 from typing import Any, Callable, Dict, List, Optional, Union
 
 from haystack import logging
-from haystack.dataclasses import StreamingChunk
+from haystack.dataclasses import ChatMessage, StreamingChunk
 from haystack.lazy_imports import LazyImport
 from haystack.utils.auth import Secret
 from haystack.utils.device import ComponentDevice
 
 with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as torch_import:
     import torch
 
-with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import:
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
     from huggingface_hub import HfApi, InferenceClient, model_info
     from huggingface_hub.utils import RepositoryNotFoundError
 
@@ -270,6 +270,44 @@ def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepte
         )
 
 
+def convert_message_to_hf_format(message: ChatMessage) -> Dict[str, Any]:
+    """
+    Convert a message to the format expected by Hugging Face.
+    """
+    text_contents = message.texts
+    tool_calls = message.tool_calls
+    tool_call_results = message.tool_call_results
+
+    if not text_contents and not tool_calls and not tool_call_results:
+        raise ValueError("A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`.")
+    if len(text_contents) + len(tool_call_results) > 1:
+        raise ValueError("A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`.")
+
+    # HF always expects a content field, even if it is empty
+    hf_msg: Dict[str, Any] = {"role": message._role.value, "content": ""}
+
+    if tool_call_results:
+        result = tool_call_results[0]
+        hf_msg["content"] = result.result
+        if tc_id := result.origin.id:
+            hf_msg["tool_call_id"] = tc_id
+        # HF does not provide a way to communicate errors in tool invocations, so we ignore the error field
+        return hf_msg
+
+    if text_contents:
+        hf_msg["content"] = text_contents[0]
+    if tool_calls:
+        hf_tool_calls = []
+        for tc in tool_calls:
+            hf_tool_call = {"type": "function", "function": {"name": tc.tool_name, "arguments": tc.arguments}}
+            if tc.id is not None:
+                hf_tool_call["id"] = tc.id
+            hf_tool_calls.append(hf_tool_call)
+        hf_msg["tool_calls"] = hf_tool_calls
+
+    return hf_msg
+
+
 with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transformers_import:
     from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, StoppingCriteria, TextStreamer
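
For reference, a quick sketch of what `convert_message_to_hf_format` produces for an assistant message carrying a tool call; the tool name, arguments, and id are illustrative:

    from haystack.dataclasses import ChatMessage, ToolCall
    from haystack.utils.hf import convert_message_to_hf_format

    msg = ChatMessage.from_assistant(
        tool_calls=[ToolCall(tool_name="add", arguments={"a": 2, "b": 3}, id="call_0")]
    )
    print(convert_message_to_hf_format(msg))
    # {'role': 'assistant', 'content': '',
    #  'tool_calls': [{'type': 'function',
    #                  'function': {'name': 'add', 'arguments': {'a': 2, 'b': 3}},
    #                  'id': 'call_0'}]}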

pyproject.toml

+1-1
@@ -85,7 +85,7 @@ extra-dependencies = [
     "numpy>=2",  # Haystack is compatible both with numpy 1.x and 2.x, but we test with 2.x
 
     "transformers[torch,sentencepiece]==4.44.2",  # ExtractiveReader, TransformersSimilarityRanker, LocalWhisperTranscriber, HFGenerators...
-    "huggingface_hub>=0.23.0",  # Hugging Face API Generators and Embedders
+    "huggingface_hub>=0.27.0",  # Hugging Face API Generators and Embedders
     "sentence-transformers>=3.0.0",  # SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder
     "langdetect",  # TextLanguageRouter and DocumentLanguageClassifier
    "openai-whisper>=20231106",  # LocalWhisperTranscriber
(new file: release note; the file path was lost in extraction)

@@ -0,0 +1,4 @@
+---
+features:
+  - |
+    Add support for Tools in the Hugging Face API Chat Generator.
