Skip to content

Commit 89d8b76

Browse files
committed
Add tools serde and tests
1 parent 7cbcbd6 commit 89d8b76

File tree

2 files changed

+38
-8
lines changed

2 files changed

+38
-8
lines changed

Diff for: haystack/components/generators/chat/hugging_face_local.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from haystack import component, default_from_dict, default_to_dict, logging
1111
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
1212
from haystack.lazy_imports import LazyImport
13-
from haystack.tools import Tool, _check_duplicate_tool_names
13+
from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
1414
from haystack.utils import (
1515
ComponentDevice,
1616
Secret,
@@ -219,13 +219,15 @@ def to_dict(self) -> Dict[str, Any]:
219219
Dictionary with serialized data.
220220
"""
221221
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
222+
serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
222223
serialization_dict = default_to_dict(
223224
self,
224225
huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
225226
generation_kwargs=self.generation_kwargs,
226227
streaming_callback=callback_name,
227228
token=self.token.to_dict() if self.token else None,
228229
chat_template=self.chat_template,
230+
tools=serialized_tools,
229231
)
230232

231233
huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
@@ -246,6 +248,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceLocalChatGenerator":
246248
"""
247249
torch_and_transformers_import.check() # leave this, cls method
248250
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
251+
deserialize_tools_inplace(data["init_parameters"], key="tools")
249252
init_params = data.get("init_parameters", {})
250253
serialized_callback_handler = init_params.get("streaming_callback")
251254
if serialized_callback_handler:

Diff for: test/components/generators/chat/test_hugging_face_local.py

+34-7
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ def streaming_callback_handler(x):
1919
return x
2020

2121

22+
def get_weather(city: str) -> str:
23+
"""Get the weather for a given city."""
24+
return f"Weather data for {city}"
25+
26+
2227
@pytest.fixture
2328
def chat_messages():
2429
return [
@@ -57,8 +62,9 @@ def tools():
5762
name="weather",
5863
description="useful to determine the weather in a given location",
5964
parameters=tool_parameters,
60-
function=lambda x: x,
65+
function=get_weather,
6166
)
67+
6268
return [tool]
6369

6470

@@ -151,14 +157,15 @@ def test_init_invalid_task(self):
151157
with pytest.raises(ValueError, match="is not supported."):
152158
HuggingFaceLocalChatGenerator(task="text-classification")
153159

154-
def test_to_dict(self, model_info_mock):
160+
def test_to_dict(self, model_info_mock, tools):
155161
generator = HuggingFaceLocalChatGenerator(
156162
model="NousResearch/Llama-2-7b-chat-hf",
157163
token=Secret.from_env_var("ENV_VAR", strict=False),
158164
generation_kwargs={"n": 5},
159165
stop_words=["stop", "words"],
160-
streaming_callback=streaming_callback_handler,
166+
streaming_callback=None,
161167
chat_template="irrelevant",
168+
tools=tools,
162169
)
163170

164171
# Call the to_dict method
@@ -170,16 +177,28 @@ def test_to_dict(self, model_info_mock):
170177
assert init_params["huggingface_pipeline_kwargs"]["model"] == "NousResearch/Llama-2-7b-chat-hf"
171178
assert "token" not in init_params["huggingface_pipeline_kwargs"]
172179
assert init_params["generation_kwargs"] == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]}
173-
assert init_params["streaming_callback"] == "chat.test_hugging_face_local.streaming_callback_handler"
180+
assert init_params["streaming_callback"] is None
174181
assert init_params["chat_template"] == "irrelevant"
182+
assert init_params["tools"] == [
183+
{
184+
"type": "haystack.tools.tool.Tool",
185+
"data": {
186+
"name": "weather",
187+
"description": "useful to determine the weather in a given location",
188+
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
189+
"function": "chat.test_hugging_face_local.get_weather",
190+
},
191+
}
192+
]
175193

176-
def test_from_dict(self, model_info_mock):
194+
def test_from_dict(self, model_info_mock, tools):
177195
generator = HuggingFaceLocalChatGenerator(
178196
model="NousResearch/Llama-2-7b-chat-hf",
179197
generation_kwargs={"n": 5},
180198
stop_words=["stop", "words"],
181-
streaming_callback=streaming_callback_handler,
199+
streaming_callback=None,
182200
chat_template="irrelevant",
201+
tools=tools,
183202
)
184203
# Call the to_dict method
185204
result = generator.to_dict()
@@ -188,8 +207,16 @@ def test_from_dict(self, model_info_mock):
188207

189208
assert generator_2.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
190209
assert generator_2.generation_kwargs == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]}
191-
assert generator_2.streaming_callback is streaming_callback_handler
210+
assert generator_2.streaming_callback is None
192211
assert generator_2.chat_template == "irrelevant"
212+
assert len(generator_2.tools) == 1
213+
assert generator_2.tools[0].name == "weather"
214+
assert generator_2.tools[0].description == "useful to determine the weather in a given location"
215+
assert generator_2.tools[0].parameters == {
216+
"type": "object",
217+
"properties": {"city": {"type": "string"}},
218+
"required": ["city"],
219+
}
193220

194221
@patch("haystack.components.generators.chat.hugging_face_local.pipeline")
195222
def test_warm_up(self, pipeline_mock, monkeypatch):

0 commit comments

Comments (0)