Skip to content

Commit 7dcbf25

Browse files
authored
feat: add Tool Invoker component (#8664)
* port toolinvoker * release note
1 parent c192488 commit 7dcbf25

File tree

6 files changed

+509
-0
lines changed

6 files changed

+509
-0
lines changed

docs/pydoc/config/tools_api.yml

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
loaders:
2+
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
3+
search_path: [../../../haystack/components/tools]
4+
modules: ["tool_invoker"]
5+
ignore_when_discovered: ["__init__"]
6+
processors:
7+
- type: filter
8+
expression:
9+
documented_only: true
10+
do_not_filter_modules: false
11+
skip_empty_modules: true
12+
- type: smart
13+
- type: crossref
14+
renderer:
15+
type: haystack_pydoc_tools.renderers.ReadmeCoreRenderer
16+
excerpt: Components related to Tool Calling.
17+
category_slug: haystack-api
18+
title: Tools
19+
slug: tools-api
20+
order: 152
21+
markdown:
22+
descriptive_class_title: false
23+
classdef_code_block: false
24+
descriptive_module_title: true
25+
add_method_class_prefix: true
26+
add_member_class_prefix: false
27+
filename: tools_api.md

haystack/components/tools/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from haystack.components.tools.tool_invoker import ToolInvoker
6+
7+
_all_ = ["ToolInvoker"]
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import json
6+
import warnings
7+
from typing import Any, Dict, List
8+
9+
from haystack import component, default_from_dict, default_to_dict, logging
10+
from haystack.dataclasses.chat_message import ChatMessage, ToolCall
11+
from haystack.dataclasses.tool import Tool, ToolInvocationError, _check_duplicate_tool_names, deserialize_tools_inplace
12+
13+
logger = logging.getLogger(__name__)
14+
15+
_TOOL_INVOCATION_FAILURE = "Tool invocation failed with error: {error}."
16+
_TOOL_NOT_FOUND = "Tool {tool_name} not found in the list of tools. Available tools are: {available_tools}."
17+
_TOOL_RESULT_CONVERSION_FAILURE = (
18+
"Failed to convert tool result to string using '{conversion_function}'. Error: {error}."
19+
)
20+
21+
22+
class ToolNotFoundException(Exception):
23+
"""
24+
Exception raised when a tool is not found in the list of available tools.
25+
"""
26+
27+
pass
28+
29+
30+
class StringConversionError(Exception):
31+
"""
32+
Exception raised when the conversion of a tool result to a string fails.
33+
"""
34+
35+
pass
36+
37+
38+
@component
39+
class ToolInvoker:
40+
"""
41+
Invokes tools based on prepared tool calls and returns the results as a list of ChatMessage objects.
42+
43+
At initialization, the ToolInvoker component is provided with a list of available tools.
44+
At runtime, the component processes a list of ChatMessage object containing tool calls
45+
and invokes the corresponding tools.
46+
The results of the tool invocations are returned as a list of ChatMessage objects with tool role.
47+
48+
Usage example:
49+
```python
50+
from haystack.dataclasses import ChatMessage, ToolCall, Tool
51+
from haystack.components.tools import ToolInvoker
52+
53+
# Tool definition
54+
def dummy_weather_function(city: str):
55+
return f"The weather in {city} is 20 degrees."
56+
57+
parameters = {"type": "object",
58+
"properties": {"city": {"type": "string"}},
59+
"required": ["city"]}
60+
61+
tool = Tool(name="weather_tool",
62+
description="A tool to get the weather",
63+
function=dummy_weather_function,
64+
parameters=parameters)
65+
66+
# Usually, the ChatMessage with tool_calls is generated by a Language Model
67+
# Here, we create it manually for demonstration purposes
68+
tool_call = ToolCall(
69+
tool_name="weather_tool",
70+
arguments={"city": "Berlin"}
71+
)
72+
message = ChatMessage.from_assistant(tool_calls=[tool_call])
73+
74+
# ToolInvoker initialization and run
75+
invoker = ToolInvoker(tools=[tool])
76+
result = invoker.run(messages=[message])
77+
78+
print(result)
79+
```
80+
81+
```
82+
>> {
83+
>> 'tool_messages': [
84+
>> ChatMessage(
85+
>> _role=<ChatRole.TOOL: 'tool'>,
86+
>> _content=[
87+
>> ToolCallResult(
88+
>> result='"The weather in Berlin is 20 degrees."',
89+
>> origin=ToolCall(
90+
>> tool_name='weather_tool',
91+
>> arguments={'city': 'Berlin'},
92+
>> id=None
93+
>> )
94+
>> )
95+
>> ],
96+
>> _meta={}
97+
>> )
98+
>> ]
99+
>> }
100+
```
101+
"""
102+
103+
def __init__(self, tools: List[Tool], raise_on_failure: bool = True, convert_result_to_json_string: bool = False):
104+
"""
105+
Initialize the ToolInvoker component.
106+
107+
:param tools:
108+
A list of tools that can be invoked.
109+
:param raise_on_failure:
110+
If True, the component will raise an exception in case of errors
111+
(tool not found, tool invocation errors, tool result conversion errors).
112+
If False, the component will return a ChatMessage object with `error=True`
113+
and a description of the error in `result`.
114+
:param convert_result_to_json_string:
115+
If True, the tool invocation result will be converted to a string using `json.dumps`.
116+
If False, the tool invocation result will be converted to a string using `str`.
117+
118+
:raises ValueError:
119+
If no tools are provided or if duplicate tool names are found.
120+
"""
121+
122+
msg = "The `ToolInvoker` component is experimental and its API may change in the future."
123+
warnings.warn(msg)
124+
125+
if not tools:
126+
raise ValueError("ToolInvoker requires at least one tool to be provided.")
127+
_check_duplicate_tool_names(tools)
128+
129+
self.tools = tools
130+
self._tools_with_names = dict(zip([tool.name for tool in tools], tools))
131+
self.raise_on_failure = raise_on_failure
132+
self.convert_result_to_json_string = convert_result_to_json_string
133+
134+
def _prepare_tool_result_message(self, result: Any, tool_call: ToolCall) -> ChatMessage:
135+
"""
136+
Prepares a ChatMessage with the result of a tool invocation.
137+
138+
:param result:
139+
The tool result.
140+
:returns:
141+
A ChatMessage object containing the tool result as a string.
142+
143+
:raises
144+
StringConversionError: If the conversion of the tool result to a string fails
145+
and `raise_on_failure` is True.
146+
"""
147+
error = False
148+
149+
if self.convert_result_to_json_string:
150+
try:
151+
# We disable ensure_ascii so special chars like emojis are not converted
152+
tool_result_str = json.dumps(result, ensure_ascii=False)
153+
except Exception as e:
154+
if self.raise_on_failure:
155+
raise StringConversionError("Failed to convert tool result to string using `json.dumps`") from e
156+
tool_result_str = _TOOL_RESULT_CONVERSION_FAILURE.format(error=e, conversion_function="json.dumps")
157+
error = True
158+
return ChatMessage.from_tool(tool_result=tool_result_str, error=error, origin=tool_call)
159+
160+
try:
161+
tool_result_str = str(result)
162+
except Exception as e:
163+
if self.raise_on_failure:
164+
raise StringConversionError("Failed to convert tool result to string using `str`") from e
165+
tool_result_str = _TOOL_RESULT_CONVERSION_FAILURE.format(error=e, conversion_function="str")
166+
error = True
167+
return ChatMessage.from_tool(tool_result=tool_result_str, error=error, origin=tool_call)
168+
169+
@component.output_types(tool_messages=List[ChatMessage])
170+
def run(self, messages: List[ChatMessage]) -> Dict[str, Any]:
171+
"""
172+
Processes ChatMessage objects containing tool calls and invokes the corresponding tools, if available.
173+
174+
:param messages:
175+
A list of ChatMessage objects.
176+
:returns:
177+
A dictionary with the key `tool_messages` containing a list of ChatMessage objects with tool role.
178+
Each ChatMessage objects wraps the result of a tool invocation.
179+
180+
:raises ToolNotFoundException:
181+
If the tool is not found in the list of available tools and `raise_on_failure` is True.
182+
:raises ToolInvocationError:
183+
If the tool invocation fails and `raise_on_failure` is True.
184+
:raises StringConversionError:
185+
If the conversion of the tool result to a string fails and `raise_on_failure` is True.
186+
"""
187+
tool_messages = []
188+
189+
for message in messages:
190+
tool_calls = message.tool_calls
191+
if not tool_calls:
192+
continue
193+
194+
for tool_call in tool_calls:
195+
tool_name = tool_call.tool_name
196+
tool_arguments = tool_call.arguments
197+
198+
if not tool_name in self._tools_with_names:
199+
msg = _TOOL_NOT_FOUND.format(tool_name=tool_name, available_tools=self._tools_with_names.keys())
200+
if self.raise_on_failure:
201+
raise ToolNotFoundException(msg)
202+
tool_messages.append(ChatMessage.from_tool(tool_result=msg, origin=tool_call, error=True))
203+
continue
204+
205+
tool_to_invoke = self._tools_with_names[tool_name]
206+
try:
207+
tool_result = tool_to_invoke.invoke(**tool_arguments)
208+
except ToolInvocationError as e:
209+
if self.raise_on_failure:
210+
raise e
211+
msg = _TOOL_INVOCATION_FAILURE.format(error=e)
212+
tool_messages.append(ChatMessage.from_tool(tool_result=msg, origin=tool_call, error=True))
213+
continue
214+
215+
tool_message = self._prepare_tool_result_message(tool_result, tool_call)
216+
tool_messages.append(tool_message)
217+
218+
return {"tool_messages": tool_messages}
219+
220+
def to_dict(self) -> Dict[str, Any]:
221+
"""
222+
Serializes the component to a dictionary.
223+
224+
:returns:
225+
Dictionary with serialized data.
226+
"""
227+
serialized_tools = [tool.to_dict() for tool in self.tools]
228+
return default_to_dict(
229+
self,
230+
tools=serialized_tools,
231+
raise_on_failure=self.raise_on_failure,
232+
convert_result_to_json_string=self.convert_result_to_json_string,
233+
)
234+
235+
@classmethod
236+
def from_dict(cls, data: Dict[str, Any]) -> "ToolInvoker":
237+
"""
238+
Deserializes the component from a dictionary.
239+
240+
:param data:
241+
The dictionary to deserialize from.
242+
:returns:
243+
The deserialized component.
244+
"""
245+
deserialize_tools_inplace(data["init_parameters"], key="tools")
246+
return default_from_dict(cls, data)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
features:
3+
- |
4+
Add a new experimental component `ToolInvoker`.
5+
This component invokes tools based on tool calls prepared by Language Models and returns the results as a list of
6+
ChatMessage objects with tool role.

test/components/tools/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0

0 commit comments

Comments
 (0)