@@ -19,6 +19,11 @@ def streaming_callback_handler(x):
19
19
return x
20
20
21
21
22
+ def get_weather (city : str ) -> str :
23
+ """Get the weather for a given city."""
24
+ return f"Weather data for { city } "
25
+
26
+
22
27
@pytest .fixture
23
28
def chat_messages ():
24
29
return [
@@ -57,8 +62,9 @@ def tools():
57
62
name = "weather" ,
58
63
description = "useful to determine the weather in a given location" ,
59
64
parameters = tool_parameters ,
60
- function = lambda x : x ,
65
+ function = get_weather ,
61
66
)
67
+
62
68
return [tool ]
63
69
64
70
@@ -151,14 +157,15 @@ def test_init_invalid_task(self):
151
157
with pytest .raises (ValueError , match = "is not supported." ):
152
158
HuggingFaceLocalChatGenerator (task = "text-classification" )
153
159
154
- def test_to_dict (self , model_info_mock ):
160
+ def test_to_dict (self , model_info_mock , tools ):
155
161
generator = HuggingFaceLocalChatGenerator (
156
162
model = "NousResearch/Llama-2-7b-chat-hf" ,
157
163
token = Secret .from_env_var ("ENV_VAR" , strict = False ),
158
164
generation_kwargs = {"n" : 5 },
159
165
stop_words = ["stop" , "words" ],
160
- streaming_callback = streaming_callback_handler ,
166
+ streaming_callback = None ,
161
167
chat_template = "irrelevant" ,
168
+ tools = tools ,
162
169
)
163
170
164
171
# Call the to_dict method
@@ -170,16 +177,28 @@ def test_to_dict(self, model_info_mock):
170
177
assert init_params ["huggingface_pipeline_kwargs" ]["model" ] == "NousResearch/Llama-2-7b-chat-hf"
171
178
assert "token" not in init_params ["huggingface_pipeline_kwargs" ]
172
179
assert init_params ["generation_kwargs" ] == {"max_new_tokens" : 512 , "n" : 5 , "stop_sequences" : ["stop" , "words" ]}
173
- assert init_params ["streaming_callback" ] == "chat.test_hugging_face_local.streaming_callback_handler"
180
+ assert init_params ["streaming_callback" ] is None
174
181
assert init_params ["chat_template" ] == "irrelevant"
182
+ assert init_params ["tools" ] == [
183
+ {
184
+ "type" : "haystack.tools.tool.Tool" ,
185
+ "data" : {
186
+ "name" : "weather" ,
187
+ "description" : "useful to determine the weather in a given location" ,
188
+ "parameters" : {"type" : "object" , "properties" : {"city" : {"type" : "string" }}, "required" : ["city" ]},
189
+ "function" : "chat.test_hugging_face_local.get_weather" ,
190
+ },
191
+ }
192
+ ]
175
193
176
- def test_from_dict (self , model_info_mock ):
194
+ def test_from_dict (self , model_info_mock , tools ):
177
195
generator = HuggingFaceLocalChatGenerator (
178
196
model = "NousResearch/Llama-2-7b-chat-hf" ,
179
197
generation_kwargs = {"n" : 5 },
180
198
stop_words = ["stop" , "words" ],
181
- streaming_callback = streaming_callback_handler ,
199
+ streaming_callback = None ,
182
200
chat_template = "irrelevant" ,
201
+ tools = tools ,
183
202
)
184
203
# Call the to_dict method
185
204
result = generator .to_dict ()
@@ -188,8 +207,16 @@ def test_from_dict(self, model_info_mock):
188
207
189
208
assert generator_2 .token == Secret .from_env_var (["HF_API_TOKEN" , "HF_TOKEN" ], strict = False )
190
209
assert generator_2 .generation_kwargs == {"max_new_tokens" : 512 , "n" : 5 , "stop_sequences" : ["stop" , "words" ]}
191
- assert generator_2 .streaming_callback is streaming_callback_handler
210
+ assert generator_2 .streaming_callback is None
192
211
assert generator_2 .chat_template == "irrelevant"
212
+ assert len (generator_2 .tools ) == 1
213
+ assert generator_2 .tools [0 ].name == "weather"
214
+ assert generator_2 .tools [0 ].description == "useful to determine the weather in a given location"
215
+ assert generator_2 .tools [0 ].parameters == {
216
+ "type" : "object" ,
217
+ "properties" : {"city" : {"type" : "string" }},
218
+ "required" : ["city" ],
219
+ }
193
220
194
221
@patch ("haystack.components.generators.chat.hugging_face_local.pipeline" )
195
222
def test_warm_up (self , pipeline_mock , monkeypatch ):
0 commit comments