Skip to content

Commit 96b4a1d

Browse files
anakin87 and dfokina authored
feat: Tool dataclass - unified abstraction to represent tools (#8652)
* draft * del HF token in tests * adaptations * progress * fix type * import sorting * more control on deserialization * release note * improvements * support name field * fix chatpromptbuilder test * port Tool from experimental * release note * docs upd * Update tool.py --------- Co-authored-by: Daria Fokina <[email protected]>
1 parent ea36026 commit 96b4a1d

File tree

6 files changed

+561
-2
lines changed

6 files changed

+561
-2
lines changed

docs/pydoc/config/data_classess_api.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ loaders:
22
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
33
search_path: [../../../haystack/dataclasses]
44
modules:
5-
["answer", "byte_stream", "chat_message", "document", "streaming_chunk", "sparse_embedding"]
5+
["answer", "byte_stream", "chat_message", "document", "streaming_chunk", "sparse_embedding", "tool"]
66
ignore_when_discovered: ["__init__"]
77
processors:
88
- type: filter

haystack/dataclasses/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from haystack.dataclasses.document import Document
99
from haystack.dataclasses.sparse_embedding import SparseEmbedding
1010
from haystack.dataclasses.streaming_chunk import StreamingChunk
11+
from haystack.dataclasses.tool import Tool
1112

1213
__all__ = [
1314
"Document",
@@ -22,4 +23,5 @@
2223
"TextContent",
2324
"StreamingChunk",
2425
"SparseEmbedding",
26+
"Tool",
2527
]

haystack/dataclasses/tool.py

Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import inspect
6+
from dataclasses import asdict, dataclass
7+
from typing import Any, Callable, Dict, Optional
8+
9+
from pydantic import create_model
10+
11+
from haystack.lazy_imports import LazyImport
12+
from haystack.utils import deserialize_callable, serialize_callable
13+
14+
with LazyImport(message="Run 'pip install jsonschema'") as jsonschema_import:
15+
from jsonschema import Draft202012Validator
16+
from jsonschema.exceptions import SchemaError
17+
18+
19+
class ToolInvocationError(Exception):
    """
    Signals that calling the function wrapped by a Tool raised an error.
    """
25+
26+
27+
class SchemaGenerationError(Exception):
    """
    Signals that a JSON schema could not be generated automatically.
    """
33+
34+
35+
@dataclass
class Tool:
    """
    Data class representing a Tool that Language Models can prepare a call for.

    Accurate definitions of the textual attributes such as `name` and `description`
    are important for the Language Model to correctly prepare the call.

    :param name:
        Name of the Tool.
    :param description:
        Description of the Tool.
    :param parameters:
        A JSON schema defining the parameters expected by the Tool.
    :param function:
        The function that will be invoked when the Tool is called.
    """

    name: str
    description: str
    parameters: Dict[str, Any]
    function: Callable

    def __post_init__(self):
        """
        Validate that `parameters` is a valid JSON schema.

        :raises ValueError:
            If `parameters` does not define a valid JSON schema.
        """
        # jsonschema is imported lazily at module level; fail early if it is not available
        jsonschema_import.check()
        # Check that the parameters define a valid JSON schema
        try:
            Draft202012Validator.check_schema(self.parameters)
        except SchemaError as e:
            raise ValueError("The provided parameters do not define a valid JSON schema") from e

    @property
    def tool_spec(self) -> Dict[str, Any]:
        """
        Return the Tool specification to be used by the Language Model.

        :returns:
            Dictionary with the `name`, `description` and `parameters` of the Tool.
        """
        return {"name": self.name, "description": self.description, "parameters": self.parameters}

    def invoke(self, **kwargs) -> Any:
        """
        Invoke the Tool with the provided keyword arguments.

        :param kwargs:
            Keyword arguments forwarded unchanged to the wrapped function.
        :returns:
            Whatever the wrapped function returns.
        :raises ToolInvocationError:
            If the wrapped function raises any exception; the original exception is chained as the cause.
        """

        try:
            result = self.function(**kwargs)
        except Exception as e:
            raise ToolInvocationError(f"Failed to invoke Tool `{self.name}` with parameters {kwargs}") from e
        return result

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the Tool to a dictionary.

        :returns:
            Dictionary with serialized data.
        """

        serialized = asdict(self)
        # a callable is not serializable as-is: replace it with its serialized form
        serialized["function"] = serialize_callable(self.function)
        return serialized

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Tool":
        """
        Deserializes the Tool from a dictionary.

        The input dictionary is left untouched, so the same serialized data can be deserialized more than once.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized Tool.
        """
        # work on a shallow copy: overwriting "function" in the caller's dictionary would
        # make a second deserialization of the same data fail
        init_parameters = dict(data)
        init_parameters["function"] = deserialize_callable(init_parameters["function"])
        return cls(**init_parameters)

    @classmethod
    def from_function(cls, function: Callable, name: Optional[str] = None, description: Optional[str] = None) -> "Tool":
        """
        Create a Tool instance from a function.

        ### Usage example

        ```python
        from typing import Annotated, Literal
        from haystack.dataclasses import Tool

        def get_weather(
            city: Annotated[str, "the city for which to get the weather"] = "Munich",
            unit: Annotated[Literal["Celsius", "Fahrenheit"], "the unit for the temperature"] = "Celsius"):
            '''A simple function to get the current weather for a location.'''
            return f"Weather report for {city}: 20 {unit}, sunny"

        tool = Tool.from_function(get_weather)

        print(tool)
        >>> Tool(name='get_weather', description='A simple function to get the current weather for a location.',
        >>> parameters={
        >>>     'type': 'object',
        >>>     'properties': {
        >>>         'city': {'type': 'string', 'description': 'the city for which to get the weather', 'default': 'Munich'},
        >>>         'unit': {
        >>>             'type': 'string',
        >>>             'enum': ['Celsius', 'Fahrenheit'],
        >>>             'description': 'the unit for the temperature',
        >>>             'default': 'Celsius',
        >>>         },
        >>>     }
        >>> },
        >>> function=<function get_weather at 0x7f7b3a8a9b80>)
        ```

        :param function:
            The function to be converted into a Tool.
            The function must include type hints for all parameters.
            If a parameter is annotated using `typing.Annotated`, its metadata will be used as parameter description.
        :param name:
            The name of the Tool. If not provided, the name of the function will be used.
        :param description:
            The description of the Tool. If not provided, the docstring of the function will be used.
            To intentionally leave the description empty, pass an empty string.

        :returns:
            The Tool created from the function.

        :raises ValueError:
            If any parameter of the function lacks a type hint.
        :raises SchemaGenerationError:
            If there is an error generating the JSON schema for the Tool.
        """

        # `is not None` (rather than truthiness) lets callers pass "" to force an empty description
        tool_description = description if description is not None else (function.__doc__ or "")

        signature = inspect.signature(function)

        # collect fields (types and defaults) and descriptions from function parameters
        fields: Dict[str, Any] = {}
        descriptions = {}

        for param_name, param in signature.parameters.items():
            if param.annotation is param.empty:
                raise ValueError(f"Function '{function.__name__}': parameter '{param_name}' does not have a type hint.")

            # if the parameter has no default value, Pydantic requires an Ellipsis (...)
            # to explicitly indicate that the parameter is required
            default = param.default if param.default is not param.empty else ...
            fields[param_name] = (param.annotation, default)

            # `typing.Annotated` types expose their extra arguments as `__metadata__`;
            # the first one is used as the parameter description
            if hasattr(param.annotation, "__metadata__"):
                descriptions[param_name] = param.annotation.__metadata__[0]

        # create Pydantic model and generate JSON schema
        try:
            model = create_model(function.__name__, **fields)
            schema = model.model_json_schema()
        except Exception as e:
            raise SchemaGenerationError(f"Failed to create JSON schema for function '{function.__name__}'") from e

        # we don't want to include title keywords in the schema, as they contain redundant information
        # there is no programmatic way to prevent Pydantic from adding them, so we remove them later
        # see https://github.com/pydantic/pydantic/discussions/8504
        _remove_title_from_schema(schema)

        # add parameters descriptions to the schema
        properties = schema.get("properties", {})
        for param_name, param_description in descriptions.items():
            if param_name in properties:
                properties[param_name]["description"] = param_description

        # build through `cls` (like `from_dict` does) so that subclasses get an instance of their own type
        return cls(name=name or function.__name__, description=tool_description, parameters=schema, function=function)
202+
203+
204+
def _remove_title_from_schema(schema: Dict[str, Any]):
205+
"""
206+
Remove the 'title' keyword from JSON schema and contained property schemas.
207+
208+
:param schema:
209+
The JSON schema to remove the 'title' keyword from.
210+
"""
211+
schema.pop("title", None)
212+
213+
for property_schema in schema["properties"].values():
214+
for key in list(property_schema.keys()):
215+
if key == "title":
216+
del property_schema[key]
217+
218+
219+
def deserialize_tools_inplace(data: Dict[str, Any], key: str = "tools"):
    """
    Replace the serialized Tools stored in a dictionary with Tool instances.

    The dictionary is modified in place. Nothing happens if `key` is missing or if its value is None.

    :param data:
        The dictionary with the serialized data.
    :param key:
        The key in the dictionary where the Tools are stored.

    :raises TypeError:
        If the value stored under `key` is not a list, or if one of its items is not a dictionary.
    """
    if key not in data:
        return

    raw_tools = data[key]
    if raw_tools is None:
        return

    if not isinstance(raw_tools, list):
        raise TypeError(f"The value of '{key}' is not a list")

    # items are validated and converted one by one, in order: a malformed item stops the
    # processing before `data[key]` is reassigned
    tools = []
    for entry in raw_tools:
        if not isinstance(entry, dict):
            raise TypeError(f"Serialized tool '{entry}' is not a dictionary")
        tools.append(Tool.from_dict(entry))

    data[key] = tools

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ dependencies = [
4747
"tenacity!=8.4.0",
4848
"lazy-imports",
4949
"openai>=1.56.1",
50+
"pydantic",
5051
"Jinja2",
5152
"posthog", # telemetry
5253
"pyyaml",
@@ -113,7 +114,7 @@ extra-dependencies = [
113114
"jsonref", # OpenAPIServiceConnector, OpenAPIServiceToFunctions
114115
"openapi3",
115116

116-
# Validation
117+
# JsonSchemaValidator, Tool
117118
"jsonschema",
118119

119120
# Tracing
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
---
2+
highlights: >
3+
We are introducing the `Tool` dataclass: a simple and unified abstraction to represent tools throughout the framework.
4+
By building on this abstraction, we will enable support for tools in Chat Generators,
5+
providing a consistent experience across models.
6+
features:
7+
- |
8+
Added a new `Tool` dataclass to represent a tool for which Language Models can prepare calls.

0 commit comments

Comments
 (0)