Skip to content

Commit 08dd4d8

Browse files
authored
feat: AnthropicChatGenerator - add Toolset support (#1787)
* AnthropicChatGenerator - add Toolset support * Use new serialization method for tools * Update haystack dep to 2.13.1 which includes Toolset * Small update
1 parent ea80f71 commit 08dd4d8

File tree

3 files changed

+78
-24
lines changed

3 files changed

+78
-24
lines changed

integrations/anthropic/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ classifiers = [
2323
"Programming Language :: Python :: Implementation :: CPython",
2424
"Programming Language :: Python :: Implementation :: PyPy",
2525
]
26-
dependencies = ["haystack-ai>=2.11.0", "anthropic>=0.47.0"]
26+
dependencies = ["haystack-ai>=2.13.1", "anthropic>=0.47.0"]
2727

2828
[project.urls]
2929
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/anthropic#readme"

integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple
2+
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union
33

44
from haystack import component, default_from_dict, default_to_dict, logging
55
from haystack.dataclasses import (
@@ -12,15 +12,15 @@
1212
ToolCallResult,
1313
select_streaming_callback,
1414
)
15-
from haystack.tools import Tool, _check_duplicate_tool_names
15+
from haystack.tools import (
16+
Tool,
17+
Toolset,
18+
_check_duplicate_tool_names,
19+
deserialize_tools_or_toolset_inplace,
20+
serialize_tools_or_toolset,
21+
)
1622
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
1723

18-
# Compatibility with Haystack 2.12.0 and 2.13.0 - remove after 2.13.0 is released
19-
try:
20-
from haystack.tools import deserialize_tools_or_toolset_inplace
21-
except ImportError:
22-
from haystack.tools import deserialize_tools_inplace as deserialize_tools_or_toolset_inplace
23-
2424
from anthropic import Anthropic, AsyncAnthropic
2525

2626
logger = logging.getLogger(__name__)
@@ -186,7 +186,7 @@ def __init__(
186186
streaming_callback: Optional[StreamingCallbackT] = None,
187187
generation_kwargs: Optional[Dict[str, Any]] = None,
188188
ignore_tools_thinking_messages: bool = True,
189-
tools: Optional[List[Tool]] = None,
189+
tools: Optional[Union[List[Tool], Toolset]] = None,
190190
):
191191
"""
192192
Creates an instance of AnthropicChatGenerator.
@@ -213,10 +213,10 @@ def __init__(
213213
`ignore_tools_thinking_messages` is `True`, the generator will drop so-called thinking messages when tool
214214
use is detected. See the Anthropic [tools](https://docs.anthropic.com/en/docs/tool-use#chain-of-thought-tool-use)
215215
for more details.
216-
:param tools: A list of Tool objects that the model can use. Each tool should have a unique name.
216+
:param tools: A list of Tool objects or a Toolset that the model can use. Each tool should have a unique name.
217217
218218
"""
219-
_check_duplicate_tool_names(tools)
219+
_check_duplicate_tool_names(list(tools or [])) # handles Toolset as well
220220

221221
self.api_key = api_key
222222
self.model = model
@@ -241,15 +241,14 @@ def to_dict(self) -> Dict[str, Any]:
241241
The serialized component as a dictionary.
242242
"""
243243
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
244-
serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
245244
return default_to_dict(
246245
self,
247246
model=self.model,
248247
streaming_callback=callback_name,
249248
generation_kwargs=self.generation_kwargs,
250249
api_key=self.api_key.to_dict(),
251250
ignore_tools_thinking_messages=self.ignore_tools_thinking_messages,
252-
tools=serialized_tools,
251+
tools=serialize_tools_or_toolset(self.tools),
253252
)
254253

255254
@classmethod
@@ -402,14 +401,15 @@ def _prepare_request_params(
402401
self,
403402
messages: List[ChatMessage],
404403
generation_kwargs: Optional[Dict[str, Any]] = None,
405-
tools: Optional[List[Tool]] = None,
404+
tools: Optional[Union[List[Tool], Toolset]] = None,
406405
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], Dict[str, Any], List[Dict[str, Any]]]:
407406
"""
408407
Prepare the parameters for the Anthropic API request.
409408
410409
:param messages: A list of ChatMessage instances representing the input messages.
411410
:param generation_kwargs: Optional arguments to pass to the Anthropic generation endpoint.
412-
:param tools: A list of tools for which the model can prepare calls.
411+
:param tools: A list of Tool objects or a Toolset that the model can use. Each tool should
412+
have a unique name.
413413
:returns: A tuple containing:
414414
- system_messages: List of system messages in Anthropic format
415415
- non_system_messages: List of non-system messages in Anthropic format
@@ -448,7 +448,8 @@ def _prepare_request_params(
448448

449449
# tools management
450450
tools = tools or self.tools
451-
_check_duplicate_tool_names(tools)
451+
tools = list(tools) if isinstance(tools, Toolset) else tools
452+
_check_duplicate_tool_names(tools) # handles Toolset as well
452453
anthropic_tools = (
453454
[
454455
{
@@ -550,16 +551,16 @@ def run(
550551
messages: List[ChatMessage],
551552
streaming_callback: Optional[StreamingCallbackT] = None,
552553
generation_kwargs: Optional[Dict[str, Any]] = None,
553-
tools: Optional[List[Tool]] = None,
554+
tools: Optional[Union[List[Tool], Toolset]] = None,
554555
):
555556
"""
556557
Invokes the Anthropic API with the given messages and generation kwargs.
557558
558559
:param messages: A list of ChatMessage instances representing the input messages.
559560
:param streaming_callback: A callback function that is called when a new token is received from the stream.
560561
:param generation_kwargs: Optional arguments to pass to the Anthropic generation endpoint.
561-
:param tools: A list of tools for which the model can prepare calls. If set, it will override
562-
the `tools` parameter set during component initialization.
562+
:param tools: A list of Tool objects or a Toolset that the model can use. Each tool should
563+
have a unique name. If set, it will override the `tools` parameter set during component initialization.
563564
:returns: A dictionary with the following keys:
564565
- `replies`: The responses from the model
565566
"""
@@ -591,16 +592,16 @@ async def run_async(
591592
messages: List[ChatMessage],
592593
streaming_callback: Optional[StreamingCallbackT] = None,
593594
generation_kwargs: Optional[Dict[str, Any]] = None,
594-
tools: Optional[List[Tool]] = None,
595+
tools: Optional[Union[List[Tool], Toolset]] = None,
595596
):
596597
"""
597598
Async version of the run method. Invokes the Anthropic API with the given messages and generation kwargs.
598599
599600
:param messages: A list of ChatMessage instances representing the input messages.
600601
:param streaming_callback: A callback function that is called when a new token is received from the stream.
601602
:param generation_kwargs: Optional arguments to pass to the Anthropic generation endpoint.
602-
:param tools: A list of tools for which the model can prepare calls. If set, it will override
603-
the `tools` parameter set during component initialization.
603+
:param tools: A list of Tool objects or a Toolset that the model can use. Each tool should
604+
have a unique name. If set, it will override the `tools` parameter set during component initialization.
604605
:returns: A dictionary with the following keys:
605606
- `replies`: The responses from the model
606607
"""

integrations/anthropic/tests/test_chat_generator.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from haystack import Pipeline
2222
from haystack.components.generators.utils import print_streaming_chunk
2323
from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk, ToolCall
24-
from haystack.tools import Tool
24+
from haystack.tools import Tool, Toolset
2525
from haystack.utils.auth import Secret
2626

2727
from haystack_integrations.components.generators.anthropic.chat.chat_generator import (
@@ -876,6 +876,59 @@ def test_live_run_with_tools(self, tools):
876876
assert len(final_message.text) > 0
877877
assert "paris" in final_message.text.lower()
878878

879+
@pytest.mark.skipif(
880+
not os.environ.get("ANTHROPIC_API_KEY", None),
881+
reason="Export an env var called ANTHROPIC_API_KEY containing the Anthropic API key to run this test.",
882+
)
883+
@pytest.mark.integration
884+
def test_live_run_with_toolset(self):
885+
"""
886+
Integration test that the AnthropicChatGenerator component can run with a Toolset.
887+
"""
888+
889+
def weather_function(city: str) -> str:
890+
"""Get weather information for a city."""
891+
weather_data = {"Paris": "22°C, sunny", "London": "15°C, rainy", "Tokyo": "18°C, cloudy"}
892+
return weather_data.get(city, "Weather data not available")
893+
894+
def echo_function(text: str) -> str:
895+
"""Echo a text."""
896+
return text
897+
898+
# Create tools
899+
weather_tool = Tool(
900+
name="weather",
901+
description="Get weather information for a city",
902+
parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
903+
function=weather_function,
904+
)
905+
906+
echo_tool = Tool(
907+
name="echo",
908+
description="Echo a text",
909+
parameters={"type": "object", "properties": {"text": {"type": "string"}}, "required": ["text"]},
910+
function=echo_function,
911+
)
912+
913+
# Create Toolset
914+
toolset = Toolset([weather_tool, echo_tool])
915+
916+
# Test with weather query
917+
initial_messages = [ChatMessage.from_user("What's the weather like in Tokyo?")]
918+
component = AnthropicChatGenerator(tools=toolset)
919+
results = component.run(messages=initial_messages)
920+
921+
assert len(results["replies"]) == 1
922+
message = results["replies"][0]
923+
924+
assert message.tool_calls
925+
tool_call = message.tool_call
926+
assert isinstance(tool_call, ToolCall)
927+
assert tool_call.id is not None
928+
assert tool_call.tool_name == "weather"
929+
assert tool_call.arguments == {"city": "Tokyo"}
930+
assert message.meta["finish_reason"] == "tool_use"
931+
879932
@pytest.mark.skipif(
880933
not os.environ.get("ANTHROPIC_API_KEY", None),
881934
reason="Export an env var called ANTHROPIC_API_KEY containing the Anthropic API key to run this test.",

0 commit comments

Comments
 (0)