|
| 1 | +# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]> |
| 2 | +# |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | + |
| 5 | +import json |
| 6 | +import warnings |
| 7 | +from typing import Any, Dict, List |
| 8 | + |
| 9 | +from haystack import component, default_from_dict, default_to_dict, logging |
| 10 | +from haystack.dataclasses.chat_message import ChatMessage, ToolCall |
| 11 | +from haystack.dataclasses.tool import Tool, ToolInvocationError, _check_duplicate_tool_names, deserialize_tools_inplace |
| 12 | + |
| 13 | +logger = logging.getLogger(__name__) |
| 14 | + |
| 15 | +_TOOL_INVOCATION_FAILURE = "Tool invocation failed with error: {error}." |
| 16 | +_TOOL_NOT_FOUND = "Tool {tool_name} not found in the list of tools. Available tools are: {available_tools}." |
| 17 | +_TOOL_RESULT_CONVERSION_FAILURE = ( |
| 18 | + "Failed to convert tool result to string using '{conversion_function}'. Error: {error}." |
| 19 | +) |
| 20 | + |
| 21 | + |
| 22 | +class ToolNotFoundException(Exception): |
| 23 | + """ |
| 24 | + Exception raised when a tool is not found in the list of available tools. |
| 25 | + """ |
| 26 | + |
| 27 | + pass |
| 28 | + |
| 29 | + |
| 30 | +class StringConversionError(Exception): |
| 31 | + """ |
| 32 | + Exception raised when the conversion of a tool result to a string fails. |
| 33 | + """ |
| 34 | + |
| 35 | + pass |
| 36 | + |
| 37 | + |
| 38 | +@component |
| 39 | +class ToolInvoker: |
| 40 | + """ |
| 41 | + Invokes tools based on prepared tool calls and returns the results as a list of ChatMessage objects. |
| 42 | +
|
| 43 | + At initialization, the ToolInvoker component is provided with a list of available tools. |
| 44 | + At runtime, the component processes a list of ChatMessage object containing tool calls |
| 45 | + and invokes the corresponding tools. |
| 46 | + The results of the tool invocations are returned as a list of ChatMessage objects with tool role. |
| 47 | +
|
| 48 | + Usage example: |
| 49 | + ```python |
| 50 | + from haystack.dataclasses import ChatMessage, ToolCall, Tool |
| 51 | + from haystack.components.tools import ToolInvoker |
| 52 | +
|
| 53 | + # Tool definition |
| 54 | + def dummy_weather_function(city: str): |
| 55 | + return f"The weather in {city} is 20 degrees." |
| 56 | +
|
| 57 | + parameters = {"type": "object", |
| 58 | + "properties": {"city": {"type": "string"}}, |
| 59 | + "required": ["city"]} |
| 60 | +
|
| 61 | + tool = Tool(name="weather_tool", |
| 62 | + description="A tool to get the weather", |
| 63 | + function=dummy_weather_function, |
| 64 | + parameters=parameters) |
| 65 | +
|
| 66 | + # Usually, the ChatMessage with tool_calls is generated by a Language Model |
| 67 | + # Here, we create it manually for demonstration purposes |
| 68 | + tool_call = ToolCall( |
| 69 | + tool_name="weather_tool", |
| 70 | + arguments={"city": "Berlin"} |
| 71 | + ) |
| 72 | + message = ChatMessage.from_assistant(tool_calls=[tool_call]) |
| 73 | +
|
| 74 | + # ToolInvoker initialization and run |
| 75 | + invoker = ToolInvoker(tools=[tool]) |
| 76 | + result = invoker.run(messages=[message]) |
| 77 | +
|
| 78 | + print(result) |
| 79 | + ``` |
| 80 | +
|
| 81 | + ``` |
| 82 | + >> { |
| 83 | + >> 'tool_messages': [ |
| 84 | + >> ChatMessage( |
| 85 | + >> _role=<ChatRole.TOOL: 'tool'>, |
| 86 | + >> _content=[ |
| 87 | + >> ToolCallResult( |
| 88 | + >> result='"The weather in Berlin is 20 degrees."', |
| 89 | + >> origin=ToolCall( |
| 90 | + >> tool_name='weather_tool', |
| 91 | + >> arguments={'city': 'Berlin'}, |
| 92 | + >> id=None |
| 93 | + >> ) |
| 94 | + >> ) |
| 95 | + >> ], |
| 96 | + >> _meta={} |
| 97 | + >> ) |
| 98 | + >> ] |
| 99 | + >> } |
| 100 | + ``` |
| 101 | + """ |
| 102 | + |
| 103 | + def __init__(self, tools: List[Tool], raise_on_failure: bool = True, convert_result_to_json_string: bool = False): |
| 104 | + """ |
| 105 | + Initialize the ToolInvoker component. |
| 106 | +
|
| 107 | + :param tools: |
| 108 | + A list of tools that can be invoked. |
| 109 | + :param raise_on_failure: |
| 110 | + If True, the component will raise an exception in case of errors |
| 111 | + (tool not found, tool invocation errors, tool result conversion errors). |
| 112 | + If False, the component will return a ChatMessage object with `error=True` |
| 113 | + and a description of the error in `result`. |
| 114 | + :param convert_result_to_json_string: |
| 115 | + If True, the tool invocation result will be converted to a string using `json.dumps`. |
| 116 | + If False, the tool invocation result will be converted to a string using `str`. |
| 117 | +
|
| 118 | + :raises ValueError: |
| 119 | + If no tools are provided or if duplicate tool names are found. |
| 120 | + """ |
| 121 | + |
| 122 | + msg = "The `ToolInvoker` component is experimental and its API may change in the future." |
| 123 | + warnings.warn(msg) |
| 124 | + |
| 125 | + if not tools: |
| 126 | + raise ValueError("ToolInvoker requires at least one tool to be provided.") |
| 127 | + _check_duplicate_tool_names(tools) |
| 128 | + |
| 129 | + self.tools = tools |
| 130 | + self._tools_with_names = dict(zip([tool.name for tool in tools], tools)) |
| 131 | + self.raise_on_failure = raise_on_failure |
| 132 | + self.convert_result_to_json_string = convert_result_to_json_string |
| 133 | + |
| 134 | + def _prepare_tool_result_message(self, result: Any, tool_call: ToolCall) -> ChatMessage: |
| 135 | + """ |
| 136 | + Prepares a ChatMessage with the result of a tool invocation. |
| 137 | +
|
| 138 | + :param result: |
| 139 | + The tool result. |
| 140 | + :returns: |
| 141 | + A ChatMessage object containing the tool result as a string. |
| 142 | +
|
| 143 | + :raises |
| 144 | + StringConversionError: If the conversion of the tool result to a string fails |
| 145 | + and `raise_on_failure` is True. |
| 146 | + """ |
| 147 | + error = False |
| 148 | + |
| 149 | + if self.convert_result_to_json_string: |
| 150 | + try: |
| 151 | + # We disable ensure_ascii so special chars like emojis are not converted |
| 152 | + tool_result_str = json.dumps(result, ensure_ascii=False) |
| 153 | + except Exception as e: |
| 154 | + if self.raise_on_failure: |
| 155 | + raise StringConversionError("Failed to convert tool result to string using `json.dumps`") from e |
| 156 | + tool_result_str = _TOOL_RESULT_CONVERSION_FAILURE.format(error=e, conversion_function="json.dumps") |
| 157 | + error = True |
| 158 | + return ChatMessage.from_tool(tool_result=tool_result_str, error=error, origin=tool_call) |
| 159 | + |
| 160 | + try: |
| 161 | + tool_result_str = str(result) |
| 162 | + except Exception as e: |
| 163 | + if self.raise_on_failure: |
| 164 | + raise StringConversionError("Failed to convert tool result to string using `str`") from e |
| 165 | + tool_result_str = _TOOL_RESULT_CONVERSION_FAILURE.format(error=e, conversion_function="str") |
| 166 | + error = True |
| 167 | + return ChatMessage.from_tool(tool_result=tool_result_str, error=error, origin=tool_call) |
| 168 | + |
| 169 | + @component.output_types(tool_messages=List[ChatMessage]) |
| 170 | + def run(self, messages: List[ChatMessage]) -> Dict[str, Any]: |
| 171 | + """ |
| 172 | + Processes ChatMessage objects containing tool calls and invokes the corresponding tools, if available. |
| 173 | +
|
| 174 | + :param messages: |
| 175 | + A list of ChatMessage objects. |
| 176 | + :returns: |
| 177 | + A dictionary with the key `tool_messages` containing a list of ChatMessage objects with tool role. |
| 178 | + Each ChatMessage objects wraps the result of a tool invocation. |
| 179 | +
|
| 180 | + :raises ToolNotFoundException: |
| 181 | + If the tool is not found in the list of available tools and `raise_on_failure` is True. |
| 182 | + :raises ToolInvocationError: |
| 183 | + If the tool invocation fails and `raise_on_failure` is True. |
| 184 | + :raises StringConversionError: |
| 185 | + If the conversion of the tool result to a string fails and `raise_on_failure` is True. |
| 186 | + """ |
| 187 | + tool_messages = [] |
| 188 | + |
| 189 | + for message in messages: |
| 190 | + tool_calls = message.tool_calls |
| 191 | + if not tool_calls: |
| 192 | + continue |
| 193 | + |
| 194 | + for tool_call in tool_calls: |
| 195 | + tool_name = tool_call.tool_name |
| 196 | + tool_arguments = tool_call.arguments |
| 197 | + |
| 198 | + if not tool_name in self._tools_with_names: |
| 199 | + msg = _TOOL_NOT_FOUND.format(tool_name=tool_name, available_tools=self._tools_with_names.keys()) |
| 200 | + if self.raise_on_failure: |
| 201 | + raise ToolNotFoundException(msg) |
| 202 | + tool_messages.append(ChatMessage.from_tool(tool_result=msg, origin=tool_call, error=True)) |
| 203 | + continue |
| 204 | + |
| 205 | + tool_to_invoke = self._tools_with_names[tool_name] |
| 206 | + try: |
| 207 | + tool_result = tool_to_invoke.invoke(**tool_arguments) |
| 208 | + except ToolInvocationError as e: |
| 209 | + if self.raise_on_failure: |
| 210 | + raise e |
| 211 | + msg = _TOOL_INVOCATION_FAILURE.format(error=e) |
| 212 | + tool_messages.append(ChatMessage.from_tool(tool_result=msg, origin=tool_call, error=True)) |
| 213 | + continue |
| 214 | + |
| 215 | + tool_message = self._prepare_tool_result_message(tool_result, tool_call) |
| 216 | + tool_messages.append(tool_message) |
| 217 | + |
| 218 | + return {"tool_messages": tool_messages} |
| 219 | + |
| 220 | + def to_dict(self) -> Dict[str, Any]: |
| 221 | + """ |
| 222 | + Serializes the component to a dictionary. |
| 223 | +
|
| 224 | + :returns: |
| 225 | + Dictionary with serialized data. |
| 226 | + """ |
| 227 | + serialized_tools = [tool.to_dict() for tool in self.tools] |
| 228 | + return default_to_dict( |
| 229 | + self, |
| 230 | + tools=serialized_tools, |
| 231 | + raise_on_failure=self.raise_on_failure, |
| 232 | + convert_result_to_json_string=self.convert_result_to_json_string, |
| 233 | + ) |
| 234 | + |
| 235 | + @classmethod |
| 236 | + def from_dict(cls, data: Dict[str, Any]) -> "ToolInvoker": |
| 237 | + """ |
| 238 | + Deserializes the component from a dictionary. |
| 239 | +
|
| 240 | + :param data: |
| 241 | + The dictionary to deserialize from. |
| 242 | + :returns: |
| 243 | + The deserialized component. |
| 244 | + """ |
| 245 | + deserialize_tools_inplace(data["init_parameters"], key="tools") |
| 246 | + return default_from_dict(cls, data) |
0 commit comments