43
43
)
44
44
45
45
46
def default_tool_parser(text: str) -> Optional[List[ToolCall]]:
    """
    Default implementation for parsing tool calls from model output text.

    Uses DEFAULT_TOOL_PATTERN to extract the first tool call found in the text. The tool name is read
    from capture group 1 (falling back to group 3) and the JSON-encoded arguments from capture group 2
    (falling back to group 4).

    :param text: The text to parse for tool calls.
    :returns: A list containing a single ToolCall if a valid tool call is found, None otherwise.
        None is returned when the pattern does not match, when the name or arguments are missing,
        when the arguments are not valid JSON, or when they do not decode to a JSON object.
    """
    try:
        match = re.search(DEFAULT_TOOL_PATTERN, text, re.DOTALL)
    except re.error:
        logger.warning("Invalid regex pattern for tool parsing: {pattern}", pattern=DEFAULT_TOOL_PATTERN)
        return None

    if not match:
        return None

    name = match.group(1) or match.group(3)
    args_str = match.group(2) or match.group(4)

    # Both alternatives of the pattern may have left a group empty; json.loads(None) would raise an
    # uncaught TypeError, and a ToolCall without a name is unusable, so bail out early.
    if not name or not args_str:
        return None

    try:
        arguments = json.loads(args_str)
    except json.JSONDecodeError:
        logger.warning("Failed to parse tool call arguments: {args_str}", args_str=args_str)
        return None

    # Tool arguments are passed to the tool as keyword arguments, so only a JSON object is acceptable;
    # syntactically valid JSON such as a list or a bare string is rejected here.
    if not isinstance(arguments, dict):
        logger.warning("Tool call arguments must be a JSON object: {args_str}", args_str=args_str)
        return None

    return [ToolCall(tool_name=name, arguments=arguments)]
73
+
74
+
46
75
@component
47
76
class HuggingFaceLocalChatGenerator :
48
77
"""
@@ -93,7 +122,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
93
122
stop_words : Optional [List [str ]] = None ,
94
123
streaming_callback : Optional [Callable [[StreamingChunk ], None ]] = None ,
95
124
tools : Optional [List [Tool ]] = None ,
96
- tool_pattern : Optional [Union [ str , Callable [[str ], Optional [List [ToolCall ] ]]]] = None ,
125
+ tool_parsing_function : Optional [Callable [[str ], Optional [List [ToolCall ]]]] = None ,
97
126
):
98
127
"""
99
128
Initializes the HuggingFaceLocalChatGenerator component.
@@ -133,11 +162,9 @@ def __init__( # pylint: disable=too-many-positional-arguments
133
162
In these cases, make sure your prompt has no stop words.
134
163
:param streaming_callback: An optional callable for handling streaming responses.
135
164
:param tools: A list of tools for which the model can prepare calls.
136
- :param tool_pattern:
137
- A pattern or callable to parse tool calls from model output.
138
- If a string, it will be used as a regex pattern to extract ToolCall object.
139
- If a callable, it should take a string and return a ToolCall object or None.
140
- If None, a default pattern will be used.
165
+ :param tool_parsing_function:
166
+ A callable that takes a string and returns a list of ToolCall objects or None.
167
+ If None, the default_tool_parser will be used which extracts tool calls using a predefined pattern.
141
168
"""
142
169
torch_and_transformers_import .check ()
143
170
@@ -188,7 +215,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
188
215
generation_kwargs ["stop_sequences" ] = generation_kwargs .get ("stop_sequences" , [])
189
216
generation_kwargs ["stop_sequences" ].extend (stop_words or [])
190
217
191
- self .tool_pattern = tool_pattern or DEFAULT_TOOL_PATTERN
218
+ self .tool_parsing_function = tool_parsing_function or default_tool_parser
192
219
self .huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
193
220
self .generation_kwargs = generation_kwargs
194
221
self .chat_template = chat_template
@@ -228,6 +255,7 @@ def to_dict(self) -> Dict[str, Any]:
228
255
token = self .token .to_dict () if self .token else None ,
229
256
chat_template = self .chat_template ,
230
257
tools = serialized_tools ,
258
+ tool_parsing_function = serialize_callable (self .tool_parsing_function ),
231
259
)
232
260
233
261
huggingface_pipeline_kwargs = serialization_dict ["init_parameters" ]["huggingface_pipeline_kwargs" ]
@@ -254,6 +282,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceLocalChatGenerator":
254
282
if serialized_callback_handler :
255
283
data ["init_parameters" ]["streaming_callback" ] = deserialize_callable (serialized_callback_handler )
256
284
285
+ tool_parsing_function = init_params .get ("tool_parsing_function" )
286
+ if tool_parsing_function :
287
+ init_params ["tool_parsing_function" ] = deserialize_callable (tool_parsing_function )
288
+
257
289
huggingface_pipeline_kwargs = init_params .get ("huggingface_pipeline_kwargs" , {})
258
290
deserialize_hf_model_kwargs (huggingface_pipeline_kwargs )
259
291
return default_from_dict (cls , data )
@@ -371,7 +403,7 @@ def create_message( # pylint: disable=too-many-positional-arguments
371
403
prompt_token_count = len (tokenizer .encode (prompt , add_special_tokens = False ))
372
404
total_tokens = prompt_token_count + completion_tokens
373
405
374
- tool_calls = self ._parse_tool_call (text ) if parse_tool_calls else None
406
+ tool_calls = self .tool_parsing_function (text ) if parse_tool_calls else None
375
407
376
408
# Determine finish reason based on context
377
409
if completion_tokens >= generation_kwargs .get ("max_new_tokens" , sys .maxsize ):
@@ -392,7 +424,8 @@ def create_message( # pylint: disable=too-many-positional-arguments
392
424
},
393
425
}
394
426
395
- return ChatMessage .from_assistant (tool_calls = tool_calls , text = text , meta = meta )
427
+ # If tool calls are detected, don't include the text content since it contains the raw tool call format
428
+ return ChatMessage .from_assistant (tool_calls = tool_calls , text = None if tool_calls else text , meta = meta )
396
429
397
430
def _validate_stop_words (self , stop_words : Optional [List [str ]]) -> Optional [List [str ]]:
398
431
"""
@@ -410,36 +443,3 @@ def _validate_stop_words(self, stop_words: Optional[List[str]]) -> Optional[List
410
443
return None
411
444
412
445
return list (set (stop_words or []))
413
-
414
- def _parse_tool_call (self , text : str ) -> Optional [List [ToolCall ]]:
415
- """
416
- Parse a tool call from model output text.
417
-
418
- :param text: The text to parse for tool calls.
419
- :returns: A ToolCall object if a valid tool call is found, None otherwise.
420
- """
421
- # if the tool pattern is a callable, call it with the text and return the result
422
- if callable (self .tool_pattern ):
423
- return self .tool_pattern (text )
424
-
425
- # if the tool pattern is a regex pattern, search for it in the text
426
- try :
427
- match = re .search (self .tool_pattern , text , re .DOTALL )
428
- except re .error :
429
- logger .warning ("Invalid regex pattern for tool parsing: {pattern}" , pattern = self .tool_pattern )
430
- return None
431
-
432
- if not match :
433
- return None
434
-
435
- # seem like most models are not producing tool ids, so we omit them
436
- # and just use the tool name and arguments
437
- name = match .group (1 ) or match .group (3 )
438
- args_str = match .group (2 ) or match .group (4 )
439
-
440
- try :
441
- arguments = json .loads (args_str )
442
- return [ToolCall (tool_name = name , arguments = arguments )]
443
- except json .JSONDecodeError :
444
- logger .warning ("Failed to parse tool call arguments: {args_str}" , args_str = args_str )
445
- return None
0 commit comments