5
5
from typing import Any , Callable , Dict , Iterable , List , Optional , Union
6
6
7
7
from haystack import component , default_from_dict , default_to_dict , logging
8
- from haystack .dataclasses import ChatMessage , StreamingChunk
8
+ from haystack .dataclasses import ChatMessage , StreamingChunk , ToolCall
9
+ from haystack .dataclasses .tool import Tool , _check_duplicate_tool_names , deserialize_tools_inplace
9
10
from haystack .lazy_imports import LazyImport
10
11
from haystack .utils import Secret , deserialize_callable , deserialize_secrets_inplace , serialize_callable
11
- from haystack .utils .hf import HFGenerationAPIType , HFModelType , check_valid_model
12
+ from haystack .utils .hf import HFGenerationAPIType , HFModelType , check_valid_model , convert_message_to_hf_format
12
13
from haystack .utils .url_validation import is_valid_http_url
13
14
14
- with LazyImport (message = "Run 'pip install \" huggingface_hub[inference]>=0.23.0\" '" ) as huggingface_hub_import :
15
- from huggingface_hub import ChatCompletionOutput , ChatCompletionStreamOutput , InferenceClient
15
+ with LazyImport (message = "Run 'pip install \" huggingface_hub[inference]>=0.27.0\" '" ) as huggingface_hub_import :
16
+ from huggingface_hub import (
17
+ ChatCompletionInputTool ,
18
+ ChatCompletionOutput ,
19
+ ChatCompletionStreamOutput ,
20
+ InferenceClient ,
21
+ )
16
22
17
23
18
24
logger = logging .getLogger (__name__ )
19
25
20
26
21
- def _convert_message_to_hfapi_format (message : ChatMessage ) -> Dict [str , str ]:
22
- """
23
- Convert a message to the format expected by Hugging Face APIs.
24
-
25
- :returns: A dictionary with the following keys:
26
- - `role`
27
- - `content`
28
- """
29
- return {"role" : message .role .value , "content" : message .text or "" }
30
-
31
-
32
27
@component
33
28
class HuggingFaceAPIChatGenerator :
34
29
"""
@@ -107,6 +102,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
107
102
generation_kwargs : Optional [Dict [str , Any ]] = None ,
108
103
stop_words : Optional [List [str ]] = None ,
109
104
streaming_callback : Optional [Callable [[StreamingChunk ], None ]] = None ,
105
+ tools : Optional [List [Tool ]] = None ,
110
106
):
111
107
"""
112
108
Initialize the HuggingFaceAPIChatGenerator instance.
@@ -121,14 +117,22 @@ def __init__( # pylint: disable=too-many-positional-arguments
121
117
- `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
122
118
- `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
123
119
`TEXT_GENERATION_INFERENCE`.
124
- :param token: The Hugging Face token to use as HTTP bearer authorization.
120
+ :param token:
121
+ The Hugging Face token to use as HTTP bearer authorization.
125
122
Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
126
123
:param generation_kwargs:
127
124
A dictionary with keyword arguments to customize text generation.
128
125
Some examples: `max_tokens`, `temperature`, `top_p`.
129
126
For details, see [Hugging Face chat_completion documentation](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion).
130
- :param stop_words: An optional list of strings representing the stop words.
131
- :param streaming_callback: An optional callable for handling streaming responses.
127
+ :param stop_words:
128
+ An optional list of strings representing the stop words.
129
+ :param streaming_callback:
130
+ An optional callable for handling streaming responses.
131
+ :param tools:
132
+ A list of tools for which the model can prepare calls.
133
+ The chosen model should support tool/function calling, according to the model card.
134
+ Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience
135
+ unexpected behavior.
132
136
"""
133
137
134
138
huggingface_hub_import .check ()
@@ -159,6 +163,11 @@ def __init__( # pylint: disable=too-many-positional-arguments
159
163
msg = f"Unknown api_type { api_type } "
160
164
raise ValueError (msg )
161
165
166
+ if tools :
167
+ if streaming_callback is not None :
168
+ raise ValueError ("Using tools and streaming at the same time is not supported. Please choose one." )
169
+ _check_duplicate_tool_names (tools )
170
+
162
171
# handle generation kwargs setup
163
172
generation_kwargs = generation_kwargs .copy () if generation_kwargs else {}
164
173
generation_kwargs ["stop" ] = generation_kwargs .get ("stop" , [])
@@ -171,6 +180,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
171
180
self .generation_kwargs = generation_kwargs
172
181
self .streaming_callback = streaming_callback
173
182
self ._client = InferenceClient (model_or_url , token = token .resolve_value () if token else None )
183
+ self .tools = tools
174
184
175
185
def to_dict (self ) -> Dict [str , Any ]:
176
186
"""
@@ -180,13 +190,15 @@ def to_dict(self) -> Dict[str, Any]:
180
190
A dictionary containing the serialized component.
181
191
"""
182
192
callback_name = serialize_callable (self .streaming_callback ) if self .streaming_callback else None
193
+ serialized_tools = [tool .to_dict () for tool in self .tools ] if self .tools else None
183
194
return default_to_dict (
184
195
self ,
185
196
api_type = str (self .api_type ),
186
197
api_params = self .api_params ,
187
198
token = self .token .to_dict () if self .token else None ,
188
199
generation_kwargs = self .generation_kwargs ,
189
200
streaming_callback = callback_name ,
201
+ tools = serialized_tools ,
190
202
)
191
203
192
204
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIChatGenerator":
    """
    Deserialize this component from a dictionary.
    """
    init_params = data["init_parameters"]
    # Secrets and tools are stored in serialized form; restore them in place first.
    deserialize_secrets_inplace(init_params, keys=["token"])
    deserialize_tools_inplace(init_params, key="tools")
    # The streaming callback, if any, was serialized as an import path string.
    callback_name = init_params.get("streaming_callback")
    if callback_name:
        init_params["streaming_callback"] = deserialize_callable(callback_name)
    return default_from_dict(cls, data)
203
216
204
217
@component.output_types(replies=List[ChatMessage])
def run(
    self,
    messages: List[ChatMessage],
    generation_kwargs: Optional[Dict[str, Any]] = None,
    tools: Optional[List[Tool]] = None,
):
    """
    Invoke the text generation inference based on the provided messages and generation parameters.

    :param messages:
        A list of ChatMessage objects representing the input messages.
    :param generation_kwargs:
        Additional keyword arguments for text generation.
    :param tools:
        A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
        during component initialization.
    :returns: A dictionary with the following keys:
        - `replies`: A list containing the generated responses as ChatMessage objects.
    """
    # Kwargs passed at run time take precedence over the defaults configured at init time.
    merged_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

    hf_messages = [convert_message_to_hf_format(m) for m in messages]

    # Tools given to run() override the ones configured in __init__.
    selected_tools = tools or self.tools
    if selected_tools:
        # Tool calls cannot be assembled from a token stream, so the two features are mutually exclusive.
        if self.streaming_callback:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(selected_tools)

    if self.streaming_callback:
        return self._run_streaming(hf_messages, merged_kwargs)

    # Translate Haystack Tool objects into the OpenAI-style function schema the HF client expects.
    hf_tools = None
    if selected_tools:
        hf_tools = [{"type": "function", "function": {**t.tool_spec}} for t in selected_tools]

    return self._run_non_streaming(hf_messages, merged_kwargs, hf_tools)
224
257
225
258
def _run_streaming (self , messages : List [Dict [str , str ]], generation_kwargs : Dict [str , Any ]):
226
259
api_output : Iterable [ChatCompletionStreamOutput ] = self ._client .chat_completion (
@@ -229,11 +262,17 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict
229
262
230
263
generated_text = ""
231
264
232
- for chunk in api_output : # pylint: disable=not-an-iterable
233
- text = chunk .choices [0 ].delta .content
265
+ for chunk in api_output :
266
+ # n is unused, so the API always returns only one choice
267
+ # the argument is probably allowed for compatibility with OpenAI
268
+ # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
269
+ choice = chunk .choices [0 ]
270
+
271
+ text = choice .delta .content
234
272
if text :
235
273
generated_text += text
236
- finish_reason = chunk .choices [0 ].finish_reason
274
+
275
+ finish_reason = choice .finish_reason
237
276
238
277
meta = {}
239
278
if finish_reason :
@@ -242,33 +281,56 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict
242
281
stream_chunk = StreamingChunk (text , meta )
243
282
self .streaming_callback (stream_chunk ) # type: ignore # streaming_callback is not None (verified in the run method)
244
283
245
- message = ChatMessage .from_assistant (generated_text )
246
- message .meta .update (
284
+ meta .update (
247
285
{
248
286
"model" : self ._client .model ,
249
287
"finish_reason" : finish_reason ,
250
288
"index" : 0 ,
251
289
"usage" : {"prompt_tokens" : 0 , "completion_tokens" : 0 }, # not available in streaming
252
290
}
253
291
)
292
+
293
+ message = ChatMessage .from_assistant (text = generated_text , meta = meta )
294
+
254
295
return {"replies" : [message ]}
255
296
256
297
def _run_non_streaming(
    self,
    messages: List[Dict[str, str]],
    generation_kwargs: Dict[str, Any],
    tools: Optional[List["ChatCompletionInputTool"]] = None,
) -> Dict[str, List[ChatMessage]]:
    """
    Perform a single (non-streaming) chat-completion call and convert the result into a ChatMessage.

    :param messages: Messages already converted to the Hugging Face wire format.
    :param generation_kwargs: Keyword arguments forwarded to `InferenceClient.chat_completion`.
    :param tools: Optional tool definitions in the HF/OpenAI function-calling schema.
    :returns: A dictionary with a `replies` list containing at most one assistant ChatMessage.
    """
    import json  # local import: only needed to decode string-encoded tool arguments

    api_chat_output: ChatCompletionOutput = self._client.chat_completion(
        messages=messages, tools=tools, **generation_kwargs
    )

    if len(api_chat_output.choices) == 0:
        return {"replies": []}

    # n is unused, so the API always returns only one choice
    # the argument is probably allowed for compatibility with OpenAI
    # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
    choice = api_chat_output.choices[0]

    text = choice.message.content
    tool_calls = []

    if hfapi_tool_calls := choice.message.tool_calls:
        for hfapi_tc in hfapi_tool_calls:
            arguments = hfapi_tc.function.arguments
            # Some backends return the tool arguments as a JSON-encoded string instead of a dict,
            # while ToolCall expects a dict: decode when needed and skip calls that are not valid JSON.
            if isinstance(arguments, str):
                try:
                    arguments = json.loads(arguments)
                except json.JSONDecodeError:
                    logger.warning(
                        "Skipping tool call for tool '{tool_name}': arguments are not valid JSON.",
                        tool_name=hfapi_tc.function.name,
                    )
                    continue
            tool_calls.append(
                ToolCall(tool_name=hfapi_tc.function.name, arguments=arguments, id=hfapi_tc.id)
            )

    meta = {"model": self._client.model, "finish_reason": choice.finish_reason, "index": choice.index}

    # Usage may be absent in the API response; fall back to zeroed counters.
    usage = {"prompt_tokens": 0, "completion_tokens": 0}
    if api_chat_output.usage:
        usage = {
            "prompt_tokens": api_chat_output.usage.prompt_tokens,
            "completion_tokens": api_chat_output.usage.completion_tokens,
        }
    meta["usage"] = usage

    message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
    return {"replies": [message]}
0 commit comments