Skip to content

Commit 3f15f38

Browse files
authored
refactor: move Tool to a separate package; refactor serde (#8690)
* move tool to separate package; refactor serde * release note * rm unused import
1 parent 28ad78c commit 3f15f38

File tree

14 files changed

+127
-56
lines changed

14 files changed

+127
-56
lines changed

docs/pydoc/config/data_classess_api.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ loaders:
22
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
33
search_path: [../../../haystack/dataclasses]
44
modules:
5-
["answer", "byte_stream", "chat_message", "document", "streaming_chunk", "sparse_embedding", "tool"]
5+
["answer", "byte_stream", "chat_message", "document", "streaming_chunk", "sparse_embedding",]
66
ignore_when_discovered: ["__init__"]
77
processors:
88
- type: filter
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
loaders:
2+
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
3+
search_path: [../../../haystack/components/tools]
4+
modules: ["tool_invoker"]
5+
ignore_when_discovered: ["__init__"]
6+
processors:
7+
- type: filter
8+
expression:
9+
documented_only: true
10+
do_not_filter_modules: false
11+
skip_empty_modules: true
12+
- type: smart
13+
- type: crossref
14+
renderer:
15+
type: haystack_pydoc_tools.renderers.ReadmeCoreRenderer
16+
excerpt: Components related to Tool Calling.
17+
category_slug: haystack-api
18+
title: Tool Components
19+
slug: tool-components-api
20+
order: 152
21+
markdown:
22+
descriptive_class_title: false
23+
classdef_code_block: false
24+
descriptive_module_title: true
25+
add_method_class_prefix: true
26+
add_member_class_prefix: false
27+
filename: tool_components_api.md

docs/pydoc/config/tools_api.yml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
loaders:
22
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
3-
search_path: [../../../haystack/components/tools]
4-
modules: ["tool_invoker"]
3+
search_path: [../../../haystack/tools]
4+
modules:
5+
["tool"]
56
ignore_when_discovered: ["__init__"]
67
processors:
78
- type: filter
@@ -13,11 +14,11 @@ processors:
1314
- type: crossref
1415
renderer:
1516
type: haystack_pydoc_tools.renderers.ReadmeCoreRenderer
16-
excerpt: Components related to Tool Calling.
17+
excerpt: Unified abstractions to represent tools across the framework.
1718
category_slug: haystack-api
1819
title: Tools
1920
slug: tools-api
20-
order: 152
21+
order: 151
2122
markdown:
2223
descriptive_class_title: false
2324
classdef_code_block: false

haystack/components/generators/chat/hugging_face_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
from haystack import component, default_from_dict, default_to_dict, logging
88
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
9-
from haystack.dataclasses.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
109
from haystack.lazy_imports import LazyImport
10+
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
1111
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
1212
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format
1313
from haystack.utils.url_validation import is_valid_http_url

haystack/components/generators/chat/openai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from haystack import component, default_from_dict, default_to_dict, logging
1515
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
16-
from haystack.dataclasses.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
16+
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
1717
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
1818

1919
logger = logging.getLogger(__name__)

haystack/components/tools/tool_invoker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from haystack import component, default_from_dict, default_to_dict, logging
1010
from haystack.dataclasses.chat_message import ChatMessage, ToolCall
11-
from haystack.dataclasses.tool import Tool, ToolInvocationError, _check_duplicate_tool_names, deserialize_tools_inplace
11+
from haystack.tools.tool import Tool, ToolInvocationError, _check_duplicate_tool_names, deserialize_tools_inplace
1212

1313
logger = logging.getLogger(__name__)
1414

haystack/dataclasses/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from haystack.dataclasses.document import Document
99
from haystack.dataclasses.sparse_embedding import SparseEmbedding
1010
from haystack.dataclasses.streaming_chunk import StreamingChunk
11-
from haystack.dataclasses.tool import Tool
1211

1312
__all__ = [
1413
"Document",
@@ -23,5 +22,4 @@
2322
"TextContent",
2423
"StreamingChunk",
2524
"SparseEmbedding",
26-
"Tool",
2725
]

haystack/tools/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
6+
7+
__all__ = ["Tool", "_check_duplicate_tool_names", "deserialize_tools_inplace"]

haystack/dataclasses/tool.py renamed to haystack/tools/tool.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from pydantic import create_model
1010

11+
from haystack.core.serialization import generate_qualified_class_name, import_class_by_name
1112
from haystack.lazy_imports import LazyImport
1213
from haystack.utils import deserialize_callable, serialize_callable
1314

@@ -89,9 +90,9 @@ def to_dict(self) -> Dict[str, Any]:
8990
Dictionary with serialized data.
9091
"""
9192

92-
serialized = asdict(self)
93-
serialized["function"] = serialize_callable(self.function)
94-
return serialized
93+
data = asdict(self)
94+
data["function"] = serialize_callable(self.function)
95+
return {"type": generate_qualified_class_name(type(self)), "data": data}
9596

9697
@classmethod
9798
def from_dict(cls, data: Dict[str, Any]) -> "Tool":
@@ -103,8 +104,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "Tool":
103104
:returns:
104105
Deserialized Tool.
105106
"""
106-
data["function"] = deserialize_callable(data["function"])
107-
return cls(**data)
107+
init_parameters = data["data"]
108+
init_parameters["function"] = deserialize_callable(init_parameters["function"])
109+
return cls(**init_parameters)
108110

109111
@classmethod
110112
def from_function(cls, function: Callable, name: Optional[str] = None, description: Optional[str] = None) -> "Tool":
@@ -253,6 +255,12 @@ def deserialize_tools_inplace(data: Dict[str, Any], key: str = "tools"):
253255
for tool in serialized_tools:
254256
if not isinstance(tool, dict):
255257
raise TypeError(f"Serialized tool '{tool}' is not a dictionary")
256-
deserialized_tools.append(Tool.from_dict(tool))
258+
259+
# different classes are allowed: Tool, ComponentTool, etc.
260+
tool_class = import_class_by_name(tool["type"])
261+
if not issubclass(tool_class, Tool):
262+
raise TypeError(f"Class '{tool_class}' is not a subclass of Tool")
263+
264+
deserialized_tools.append(tool_class.from_dict(tool))
257265

258266
data[key] = deserialized_tools
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
enhancements:
3+
- |
4+
Move `Tool` to a new dedicated `tools` package.
5+
Refactor `Tool` serialization and deserialization to make it more flexible and include type information.

test/components/generators/chat/test_hugging_face_api.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
from huggingface_hub.utils import RepositoryNotFoundError
2424

2525
from haystack.components.generators.chat.hugging_face_api import HuggingFaceAPIChatGenerator
26-
from haystack.dataclasses import ChatMessage, Tool, ToolCall
26+
from haystack.tools import Tool
27+
from haystack.dataclasses import ChatMessage, ToolCall
2728

2829

2930
@pytest.fixture
@@ -217,10 +218,13 @@ def test_to_dict(self, mock_check_valid_model):
217218
assert init_params["streaming_callback"] is None
218219
assert init_params["tools"] == [
219220
{
220-
"description": "description",
221-
"function": "builtins.print",
222-
"name": "name",
223-
"parameters": {"x": {"type": "string"}},
221+
"type": "haystack.tools.tool.Tool",
222+
"data": {
223+
"description": "description",
224+
"function": "builtins.print",
225+
"name": "name",
226+
"parameters": {"x": {"type": "string"}},
227+
},
224228
}
225229
]
226230

@@ -276,10 +280,13 @@ def test_serde_in_pipeline(self, mock_check_valid_model):
276280
"streaming_callback": None,
277281
"tools": [
278282
{
279-
"name": "name",
280-
"description": "description",
281-
"parameters": {"x": {"type": "string"}},
282-
"function": "builtins.print",
283+
"type": "haystack.tools.tool.Tool",
284+
"data": {
285+
"name": "name",
286+
"description": "description",
287+
"parameters": {"x": {"type": "string"}},
288+
"function": "builtins.print",
289+
},
283290
}
284291
],
285292
},

test/components/generators/chat/test_openai.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
from haystack.components.generators.utils import print_streaming_chunk
1919
from haystack.dataclasses import StreamingChunk
2020
from haystack.utils.auth import Secret
21-
from haystack.dataclasses import ChatMessage, Tool, ToolCall
21+
from haystack.dataclasses import ChatMessage, ToolCall
22+
from haystack.tools import Tool
2223
from haystack.components.generators.chat.openai import OpenAIChatGenerator
2324

2425

@@ -200,10 +201,13 @@ def test_to_dict_with_parameters(self, monkeypatch):
200201
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
201202
"tools": [
202203
{
203-
"description": "description",
204-
"function": "builtins.print",
205-
"name": "name",
206-
"parameters": {"x": {"type": "string"}},
204+
"type": "haystack.tools.tool.Tool",
205+
"data": {
206+
"description": "description",
207+
"function": "builtins.print",
208+
"name": "name",
209+
"parameters": {"x": {"type": "string"}},
210+
},
207211
}
208212
],
209213
"tools_strict": True,
@@ -224,10 +228,13 @@ def test_from_dict(self, monkeypatch):
224228
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
225229
"tools": [
226230
{
227-
"description": "description",
228-
"function": "builtins.print",
229-
"name": "name",
230-
"parameters": {"x": {"type": "string"}},
231+
"type": "haystack.tools.tool.Tool",
232+
"data": {
233+
"description": "description",
234+
"function": "builtins.print",
235+
"name": "name",
236+
"parameters": {"x": {"type": "string"}},
237+
},
231238
}
232239
],
233240
"tools_strict": True,

test/components/tools/test_tool_invoker.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from haystack import Pipeline
55

66
from haystack.dataclasses import ChatMessage, ToolCall, ToolCallResult, ChatRole
7-
from haystack.dataclasses.tool import Tool, ToolInvocationError
7+
from haystack.tools.tool import Tool, ToolInvocationError
88
from haystack.components.tools.tool_invoker import ToolInvoker, ToolNotFoundException, StringConversionError
99
from haystack.components.generators.chat.openai import OpenAIChatGenerator
1010

@@ -238,14 +238,17 @@ def test_serde_in_pipeline(self, invoker, monkeypatch):
238238
"init_parameters": {
239239
"tools": [
240240
{
241-
"name": "weather_tool",
242-
"description": "Provides weather information for a given location.",
243-
"parameters": {
244-
"type": "object",
245-
"properties": {"location": {"type": "string"}},
246-
"required": ["location"],
241+
"type": "haystack.tools.tool.Tool",
242+
"data": {
243+
"name": "weather_tool",
244+
"description": "Provides weather information for a given location.",
245+
"parameters": {
246+
"type": "object",
247+
"properties": {"location": {"type": "string"}},
248+
"required": ["location"],
249+
},
250+
"function": "tools.test_tool_invoker.weather_function",
247251
},
248-
"function": "tools.test_tool_invoker.weather_function",
249252
}
250253
],
251254
"raise_on_failure": True,

test/dataclasses/test_tool.py renamed to test/tools/test_tool.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
from typing import Literal, Optional
66

77
import pytest
8-
9-
from haystack.dataclasses.tool import (
8+
from haystack.tools.tool import (
109
SchemaGenerationError,
1110
Tool,
1211
ToolInvocationError,
@@ -78,18 +77,24 @@ def test_to_dict(self):
7877
)
7978

8079
assert tool.to_dict() == {
81-
"name": "weather",
82-
"description": "Get weather report",
83-
"parameters": parameters,
84-
"function": "test_tool.get_weather_report",
80+
"type": "haystack.tools.tool.Tool",
81+
"data": {
82+
"name": "weather",
83+
"description": "Get weather report",
84+
"parameters": parameters,
85+
"function": "test_tool.get_weather_report",
86+
},
8587
}
8688

8789
def test_from_dict(self):
8890
tool_dict = {
89-
"name": "weather",
90-
"description": "Get weather report",
91-
"parameters": parameters,
92-
"function": "test_tool.get_weather_report",
91+
"type": "haystack.tools.tool.Tool",
92+
"data": {
93+
"name": "weather",
94+
"description": "Get weather report",
95+
"parameters": parameters,
96+
"function": "test_tool.get_weather_report",
97+
},
9398
}
9499

95100
tool = Tool.from_dict(tool_dict)
@@ -179,14 +184,12 @@ def function_with_annotations(
179184

180185
def test_deserialize_tools_inplace():
181186
tool = Tool(name="weather", description="Get weather report", parameters=parameters, function=get_weather_report)
182-
serialized_tool = tool.to_dict()
183-
print(serialized_tool)
184187

185-
data = {"tools": [serialized_tool.copy()]}
188+
data = {"tools": [tool.to_dict()]}
186189
deserialize_tools_inplace(data)
187190
assert data["tools"] == [tool]
188191

189-
data = {"mytools": [serialized_tool.copy()]}
192+
data = {"mytools": [tool.to_dict()]}
190193
deserialize_tools_inplace(data, key="mytools")
191194
assert data["mytools"] == [tool]
192195

@@ -212,6 +215,11 @@ def test_deserialize_tools_inplace_failures():
212215
with pytest.raises(TypeError):
213216
deserialize_tools_inplace(data)
214217

218+
# not a subclass of Tool
219+
data = {"tools": [{"type": "haystack.dataclasses.ChatMessage", "data": {"irrelevant": "irrelevant"}}]}
220+
with pytest.raises(TypeError):
221+
deserialize_tools_inplace(data)
222+
215223

216224
def test_remove_title_from_schema():
217225
complex_schema = {

0 commit comments

Comments
 (0)