Skip to content

Commit de50d2f

Browse files
committed
PR feedback
1 parent 89d8b76 commit de50d2f

File tree

2 files changed

+76
-58
lines changed

2 files changed

+76
-58
lines changed

haystack/components/generators/chat/hugging_face_local.py

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,35 @@
4343
)
4444

4545

46+
def default_tool_parser(text: str) -> Optional[List[ToolCall]]:
    """
    Default implementation for parsing tool calls from model output text.

    Uses DEFAULT_TOOL_PATTERN to extract tool calls.

    :param text: The text to parse for tool calls.
    :returns: A list containing a single ToolCall if a valid tool call is found, None otherwise.
    """
    try:
        found = re.search(DEFAULT_TOOL_PATTERN, text, re.DOTALL)
    except re.error:
        logger.warning("Invalid regex pattern for tool parsing: {pattern}", pattern=DEFAULT_TOOL_PATTERN)
        return None

    if found is None:
        return None

    # The pattern carries two alternations; use whichever side captured.
    tool_name = found.group(1) or found.group(3)
    raw_args = found.group(2) or found.group(4)

    try:
        parsed_args = json.loads(raw_args)
    except json.JSONDecodeError:
        logger.warning("Failed to parse tool call arguments: {args_str}", args_str=raw_args)
        return None

    return [ToolCall(tool_name=tool_name, arguments=parsed_args)]
73+
74+
4675
@component
4776
class HuggingFaceLocalChatGenerator:
4877
"""
@@ -93,7 +122,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
93122
stop_words: Optional[List[str]] = None,
94123
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
95124
tools: Optional[List[Tool]] = None,
96-
tool_pattern: Optional[Union[str, Callable[[str], Optional[List[ToolCall]]]]] = None,
125+
tool_parsing_function: Optional[Callable[[str], Optional[List[ToolCall]]]] = None,
97126
):
98127
"""
99128
Initializes the HuggingFaceLocalChatGenerator component.
@@ -133,11 +162,9 @@ def __init__( # pylint: disable=too-many-positional-arguments
133162
In these cases, make sure your prompt has no stop words.
134163
:param streaming_callback: An optional callable for handling streaming responses.
135164
:param tools: A list of tools for which the model can prepare calls.
136-
:param tool_pattern:
137-
A pattern or callable to parse tool calls from model output.
138-
If a string, it will be used as a regex pattern to extract ToolCall object.
139-
If a callable, it should take a string and return a ToolCall object or None.
140-
If None, a default pattern will be used.
165+
:param tool_parsing_function:
166+
A callable that takes a string and returns a list of ToolCall objects or None.
167+
If None, the default_tool_parser will be used which extracts tool calls using a predefined pattern.
141168
"""
142169
torch_and_transformers_import.check()
143170

@@ -188,7 +215,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
188215
generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
189216
generation_kwargs["stop_sequences"].extend(stop_words or [])
190217

191-
self.tool_pattern = tool_pattern or DEFAULT_TOOL_PATTERN
218+
self.tool_parsing_function = tool_parsing_function or default_tool_parser
192219
self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
193220
self.generation_kwargs = generation_kwargs
194221
self.chat_template = chat_template
@@ -228,6 +255,7 @@ def to_dict(self) -> Dict[str, Any]:
228255
token=self.token.to_dict() if self.token else None,
229256
chat_template=self.chat_template,
230257
tools=serialized_tools,
258+
tool_parsing_function=serialize_callable(self.tool_parsing_function),
231259
)
232260

233261
huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
@@ -254,6 +282,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceLocalChatGenerator":
254282
if serialized_callback_handler:
255283
data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
256284

285+
tool_parsing_function = init_params.get("tool_parsing_function")
286+
if tool_parsing_function:
287+
init_params["tool_parsing_function"] = deserialize_callable(tool_parsing_function)
288+
257289
huggingface_pipeline_kwargs = init_params.get("huggingface_pipeline_kwargs", {})
258290
deserialize_hf_model_kwargs(huggingface_pipeline_kwargs)
259291
return default_from_dict(cls, data)
@@ -371,7 +403,7 @@ def create_message( # pylint: disable=too-many-positional-arguments
371403
prompt_token_count = len(tokenizer.encode(prompt, add_special_tokens=False))
372404
total_tokens = prompt_token_count + completion_tokens
373405

374-
tool_calls = self._parse_tool_call(text) if parse_tool_calls else None
406+
tool_calls = self.tool_parsing_function(text) if parse_tool_calls else None
375407

376408
# Determine finish reason based on context
377409
if completion_tokens >= generation_kwargs.get("max_new_tokens", sys.maxsize):
@@ -392,7 +424,8 @@ def create_message( # pylint: disable=too-many-positional-arguments
392424
},
393425
}
394426

395-
return ChatMessage.from_assistant(tool_calls=tool_calls, text=text, meta=meta)
427+
# If tool calls are detected, don't include the text content since it contains the raw tool call format
428+
return ChatMessage.from_assistant(tool_calls=tool_calls, text=None if tool_calls else text, meta=meta)
396429

397430
def _validate_stop_words(self, stop_words: Optional[List[str]]) -> Optional[List[str]]:
398431
"""
@@ -410,36 +443,3 @@ def _validate_stop_words(self, stop_words: Optional[List[str]]) -> Optional[List
410443
return None
411444

412445
return list(set(stop_words or []))
413-
414-
def _parse_tool_call(self, text: str) -> Optional[List[ToolCall]]:
415-
"""
416-
Parse a tool call from model output text.
417-
418-
:param text: The text to parse for tool calls.
419-
:returns: A ToolCall object if a valid tool call is found, None otherwise.
420-
"""
421-
# if the tool pattern is a callable, call it with the text and return the result
422-
if callable(self.tool_pattern):
423-
return self.tool_pattern(text)
424-
425-
# if the tool pattern is a regex pattern, search for it in the text
426-
try:
427-
match = re.search(self.tool_pattern, text, re.DOTALL)
428-
except re.error:
429-
logger.warning("Invalid regex pattern for tool parsing: {pattern}", pattern=self.tool_pattern)
430-
return None
431-
432-
if not match:
433-
return None
434-
435-
# seem like most models are not producing tool ids, so we omit them
436-
# and just use the tool name and arguments
437-
name = match.group(1) or match.group(3)
438-
args_str = match.group(2) or match.group(4)
439-
440-
try:
441-
arguments = json.loads(args_str)
442-
return [ToolCall(tool_name=name, arguments=arguments)]
443-
except json.JSONDecodeError:
444-
logger.warning("Failed to parse tool call arguments: {args_str}", args_str=args_str)
445-
return None

test/components/generators/chat/test_hugging_face_local.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44
from unittest.mock import Mock, patch
5+
from typing import Optional, List
56

67
from haystack.dataclasses.streaming_chunk import StreamingChunk
78
import pytest
@@ -68,6 +69,11 @@ def tools():
6869
return [tool]
6970

7071

72+
def custom_tool_parser(text: str) -> Optional[List[ToolCall]]:
    """Test implementation of a custom tool parser."""
    call = ToolCall(tool_name="weather", arguments={"city": "Berlin"})
    return [call]
75+
76+
7177
class TestHuggingFaceLocalChatGenerator:
7278
def test_initialize_with_valid_model_and_generation_parameters(self, model_info_mock):
7379
model = "HuggingFaceH4/zephyr-7b-alpha"
@@ -433,26 +439,38 @@ def test_run_with_tools_and_tool_response(self, model_info_mock, tools):
433439
assert "22°C" in message.text
434440
assert message.meta["finish_reason"] == "stop"
435441

436-
def test_run_with_invalid_tool_pattern(self, model_info_mock, tools):
442+
def test_run_with_custom_tool_parser(self, model_info_mock, tools):
443+
"""Test that a custom tool parsing function works correctly."""
437444
generator = HuggingFaceLocalChatGenerator(
438-
model="meta-llama/Llama-2-13b-chat-hf",
439-
tools=tools,
440-
tool_pattern=r"invalid[pattern", # Invalid regex pattern
445+
model="meta-llama/Llama-2-13b-chat-hf", tools=tools, tool_parsing_function=custom_tool_parser
441446
)
447+
generator.pipeline = Mock(return_value=[{"generated_text": "Let me check the weather for you"}])
448+
generator.pipeline.tokenizer = Mock()
449+
generator.pipeline.tokenizer.encode.return_value = [1, 2, 3]
450+
generator.pipeline.tokenizer.pad_token_id = 1
442451

443-
# Mock pipeline and tokenizer
444-
mock_pipeline = Mock(return_value=[{"generated_text": '{"name": "weather", "arguments": {"city": "Paris"}}'}])
445-
mock_tokenizer = Mock(spec=PreTrainedTokenizer)
446-
mock_tokenizer.encode.return_value = ["some", "tokens"]
447-
mock_tokenizer.pad_token_id = 100
448-
mock_tokenizer.apply_chat_template.return_value = "test prompt"
449-
mock_pipeline.tokenizer = mock_tokenizer
450-
generator.pipeline = mock_pipeline
452+
messages = [ChatMessage.from_user("What's the weather like in Berlin?")]
453+
results = generator.run(messages=messages)
451454

452-
messages = [ChatMessage.from_user("What's the weather in Paris?")]
455+
assert len(results["replies"]) == 1
456+
assert len(results["replies"][0].tool_calls) == 1
457+
assert results["replies"][0].tool_calls[0].tool_name == "weather"
458+
assert results["replies"][0].tool_calls[0].arguments == {"city": "Berlin"}
459+
460+
def test_default_tool_parser(self, model_info_mock, tools):
461+
"""Test that the default tool parser works correctly with valid tool call format."""
462+
generator = HuggingFaceLocalChatGenerator(model="meta-llama/Llama-2-13b-chat-hf", tools=tools)
463+
generator.pipeline = Mock(
464+
return_value=[{"generated_text": '{"name": "weather", "arguments": {"city": "Berlin"}}'}]
465+
)
466+
generator.pipeline.tokenizer = Mock()
467+
generator.pipeline.tokenizer.encode.return_value = [1, 2, 3]
468+
generator.pipeline.tokenizer.pad_token_id = 1
469+
470+
messages = [ChatMessage.from_user("What's the weather like in Berlin?")]
453471
results = generator.run(messages=messages)
454472

455473
assert len(results["replies"]) == 1
456-
message = results["replies"][0]
457-
assert not message.tool_calls # No tool calls due to invalid pattern
458-
assert message.meta["finish_reason"] == "stop"
474+
assert len(results["replies"][0].tool_calls) == 1
475+
assert results["replies"][0].tool_calls[0].tool_name == "weather"
476+
assert results["replies"][0].tool_calls[0].arguments == {"city": "Berlin"}

0 commit comments

Comments
 (0)