9
9
from haystack import Pipeline
10
10
from haystack .components .generators .chat import AzureOpenAIChatGenerator
11
11
from haystack .components .generators .utils import print_streaming_chunk
12
- from haystack .dataclasses import ChatMessage
12
+ from haystack .dataclasses import ChatMessage , ToolCall
13
+ from haystack .tools .tool import Tool
13
14
from haystack .utils .auth import Secret
14
15
15
16
16
- class TestOpenAIChatGenerator :
17
+ @pytest .fixture
18
+ def tools ():
19
+ tool_parameters = {"type" : "object" , "properties" : {"city" : {"type" : "string" }}, "required" : ["city" ]}
20
+ tool = Tool (
21
+ name = "weather" ,
22
+ description = "useful to determine the weather in a given location" ,
23
+ parameters = tool_parameters ,
24
+ function = lambda x : x ,
25
+ )
26
+
27
+ return [tool ]
28
+
29
+
30
+ class TestAzureOpenAIChatGenerator :
17
31
def test_init_default (self , monkeypatch ):
18
32
monkeypatch .setenv ("AZURE_OPENAI_API_KEY" , "test-api-key" )
19
33
component = AzureOpenAIChatGenerator (azure_endpoint = "some-non-existing-endpoint" )
@@ -28,17 +42,21 @@ def test_init_fail_wo_api_key(self, monkeypatch):
28
42
with pytest .raises (OpenAIError ):
29
43
AzureOpenAIChatGenerator (azure_endpoint = "some-non-existing-endpoint" )
30
44
31
- def test_init_with_parameters (self ):
45
+ def test_init_with_parameters (self , tools ):
32
46
component = AzureOpenAIChatGenerator (
33
47
api_key = Secret .from_token ("test-api-key" ),
34
48
azure_endpoint = "some-non-existing-endpoint" ,
35
49
streaming_callback = print_streaming_chunk ,
36
50
generation_kwargs = {"max_tokens" : 10 , "some_test_param" : "test-params" },
51
+ tools = tools ,
52
+ tools_strict = True ,
37
53
)
38
54
assert component .client .api_key == "test-api-key"
39
55
assert component .azure_deployment == "gpt-4o-mini"
40
56
assert component .streaming_callback is print_streaming_chunk
41
57
assert component .generation_kwargs == {"max_tokens" : 10 , "some_test_param" : "test-params" }
58
+ assert component .tools == tools
59
+ assert component .tools_strict
42
60
43
61
def test_to_dict_default (self , monkeypatch ):
44
62
monkeypatch .setenv ("AZURE_OPENAI_API_KEY" , "test-api-key" )
@@ -58,6 +76,8 @@ def test_to_dict_default(self, monkeypatch):
58
76
"timeout" : 30.0 ,
59
77
"max_retries" : 5 ,
60
78
"default_headers" : {},
79
+ "tools" : None ,
80
+ "tools_strict" : False ,
61
81
},
62
82
}
63
83
@@ -85,15 +105,94 @@ def test_to_dict_with_parameters(self, monkeypatch):
85
105
"timeout" : 2.5 ,
86
106
"max_retries" : 10 ,
87
107
"generation_kwargs" : {"max_tokens" : 10 , "some_test_param" : "test-params" },
108
+ "tools" : None ,
109
+ "tools_strict" : False ,
88
110
"default_headers" : {},
89
111
},
90
112
}
91
113
114
+ def test_from_dict (self , monkeypatch ):
115
+ monkeypatch .setenv ("AZURE_OPENAI_API_KEY" , "test-api-key" )
116
+ monkeypatch .setenv ("AZURE_OPENAI_AD_TOKEN" , "test-ad-token" )
117
+ data = {
118
+ "type" : "haystack.components.generators.chat.azure.AzureOpenAIChatGenerator" ,
119
+ "init_parameters" : {
120
+ "api_key" : {"env_vars" : ["AZURE_OPENAI_API_KEY" ], "strict" : False , "type" : "env_var" },
121
+ "azure_ad_token" : {"env_vars" : ["AZURE_OPENAI_AD_TOKEN" ], "strict" : False , "type" : "env_var" },
122
+ "api_version" : "2023-05-15" ,
123
+ "azure_endpoint" : "some-non-existing-endpoint" ,
124
+ "azure_deployment" : "gpt-4o-mini" ,
125
+ "organization" : None ,
126
+ "streaming_callback" : None ,
127
+ "generation_kwargs" : {},
128
+ "timeout" : 30.0 ,
129
+ "max_retries" : 5 ,
130
+ "default_headers" : {},
131
+ "tools" : [
132
+ {
133
+ "type" : "haystack.tools.tool.Tool" ,
134
+ "data" : {
135
+ "description" : "description" ,
136
+ "function" : "builtins.print" ,
137
+ "name" : "name" ,
138
+ "parameters" : {"x" : {"type" : "string" }},
139
+ },
140
+ }
141
+ ],
142
+ "tools_strict" : False ,
143
+ },
144
+ }
145
+
146
+ generator = AzureOpenAIChatGenerator .from_dict (data )
147
+ assert isinstance (generator , AzureOpenAIChatGenerator )
148
+
149
+ assert generator .api_key == Secret .from_env_var ("AZURE_OPENAI_API_KEY" , strict = False )
150
+ assert generator .azure_ad_token == Secret .from_env_var ("AZURE_OPENAI_AD_TOKEN" , strict = False )
151
+ assert generator .api_version == "2023-05-15"
152
+ assert generator .azure_endpoint == "some-non-existing-endpoint"
153
+ assert generator .azure_deployment == "gpt-4o-mini"
154
+ assert generator .organization is None
155
+ assert generator .streaming_callback is None
156
+ assert generator .generation_kwargs == {}
157
+ assert generator .timeout == 30.0
158
+ assert generator .max_retries == 5
159
+ assert generator .default_headers == {}
160
+ assert generator .tools == [
161
+ Tool (name = "name" , description = "description" , parameters = {"x" : {"type" : "string" }}, function = print )
162
+ ]
163
+ assert generator .tools_strict == False
164
+
92
165
def test_pipeline_serialization_deserialization (self , tmp_path , monkeypatch ):
93
166
monkeypatch .setenv ("AZURE_OPENAI_API_KEY" , "test-api-key" )
94
167
generator = AzureOpenAIChatGenerator (azure_endpoint = "some-non-existing-endpoint" )
95
168
p = Pipeline ()
96
169
p .add_component (instance = generator , name = "generator" )
170
+
171
+ assert p .to_dict () == {
172
+ "metadata" : {},
173
+ "max_runs_per_component" : 100 ,
174
+ "components" : {
175
+ "generator" : {
176
+ "type" : "haystack.components.generators.chat.azure.AzureOpenAIChatGenerator" ,
177
+ "init_parameters" : {
178
+ "azure_endpoint" : "some-non-existing-endpoint" ,
179
+ "azure_deployment" : "gpt-4o-mini" ,
180
+ "organization" : None ,
181
+ "api_version" : "2023-05-15" ,
182
+ "streaming_callback" : None ,
183
+ "generation_kwargs" : {},
184
+ "timeout" : 30.0 ,
185
+ "max_retries" : 5 ,
186
+ "api_key" : {"type" : "env_var" , "env_vars" : ["AZURE_OPENAI_API_KEY" ], "strict" : False },
187
+ "azure_ad_token" : {"type" : "env_var" , "env_vars" : ["AZURE_OPENAI_AD_TOKEN" ], "strict" : False },
188
+ "default_headers" : {},
189
+ "tools" : None ,
190
+ "tools_strict" : False ,
191
+ },
192
+ }
193
+ },
194
+ "connections" : [],
195
+ }
97
196
p_str = p .dumps ()
98
197
q = Pipeline .loads (p_str )
99
198
assert p .to_dict () == q .to_dict (), "Pipeline serialization/deserialization w/ AzureOpenAIChatGenerator failed."
@@ -117,4 +216,29 @@ def test_live_run(self):
117
216
assert "gpt-4o-mini" in message .meta ["model" ]
118
217
assert message .meta ["finish_reason" ] == "stop"
119
218
219
+ @pytest .mark .integration
220
+ @pytest .mark .skipif (
221
+ not os .environ .get ("AZURE_OPENAI_API_KEY" , None ) or not os .environ .get ("AZURE_OPENAI_ENDPOINT" , None ),
222
+ reason = (
223
+ "Please export env variables called AZURE_OPENAI_API_KEY containing "
224
+ "the Azure OpenAI key, AZURE_OPENAI_ENDPOINT containing "
225
+ "the Azure OpenAI endpoint URL to run this test."
226
+ ),
227
+ )
228
+ def test_live_run_with_tools (self , tools ):
229
+ chat_messages = [ChatMessage .from_user ("What's the weather like in Paris?" )]
230
+ component = AzureOpenAIChatGenerator (organization = "HaystackCI" , tools = tools )
231
+ results = component .run (chat_messages )
232
+ assert len (results ["replies" ]) == 1
233
+ message = results ["replies" ][0 ]
234
+
235
+ assert not message .texts
236
+ assert not message .text
237
+ assert message .tool_calls
238
+ tool_call = message .tool_call
239
+ assert isinstance (tool_call , ToolCall )
240
+ assert tool_call .tool_name == "weather"
241
+ assert tool_call .arguments == {"city" : "Paris" }
242
+ assert message .meta ["finish_reason" ] == "tool_calls"
243
+
120
244
# additional tests intentionally omitted as they are covered by test_openai.py
0 commit comments