Skip to content

Commit c3d0643

Browse files
authored
feat: AzureOpenAIChatGenerator - support for tools (#8757)
* feat: AzureOpenAIChatGenerator - support for tools * release note * feedback
1 parent f96839e commit c3d0643

File tree

3 files changed

+146
-8
lines changed

3 files changed

+146
-8
lines changed

Diff for: haystack/components/generators/chat/azure.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import os
6-
from typing import Any, Callable, Dict, Optional
6+
from typing import Any, Callable, Dict, List, Optional
77

88
# pylint: disable=import-error
99
from openai.lib.azure import AzureOpenAI
1010

1111
from haystack import component, default_from_dict, default_to_dict, logging
1212
from haystack.components.generators.chat import OpenAIChatGenerator
1313
from haystack.dataclasses import StreamingChunk
14+
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
1415
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
1516

1617
logger = logging.getLogger(__name__)
@@ -75,6 +76,8 @@ def __init__( # pylint: disable=too-many-positional-arguments
7576
max_retries: Optional[int] = None,
7677
generation_kwargs: Optional[Dict[str, Any]] = None,
7778
default_headers: Optional[Dict[str, str]] = None,
79+
tools: Optional[List[Tool]] = None,
80+
tools_strict: bool = False,
7881
):
7982
"""
8083
Initialize the Azure OpenAI Chat Generator component.
@@ -112,6 +115,11 @@ def __init__( # pylint: disable=too-many-positional-arguments
112115
- `logit_bias`: Adds a logit bias to specific tokens. The keys of the dictionary are tokens, and the
113116
values are the bias to add to that token.
114117
:param default_headers: Default headers to use for the AzureOpenAI client.
118+
:param tools:
119+
A list of tools for which the model can prepare calls.
120+
:param tools_strict:
121+
Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
122+
the schema provided in the `parameters` field of the tool definition, but this may increase latency.
115123
"""
116124
# We intentionally do not call super().__init__ here because we only need to instantiate the client to interact
117125
# with the API.
@@ -142,10 +150,9 @@ def __init__( # pylint: disable=too-many-positional-arguments
142150
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
143151
self.default_headers = default_headers or {}
144152

145-
# This ChatGenerator does not yet supports tools. The following workaround ensures that we do not
146-
# get an error when invoking the run method of the parent class (OpenAIChatGenerator).
147-
self.tools = None
148-
self.tools_strict = False
153+
_check_duplicate_tool_names(tools)
154+
self.tools = tools
155+
self.tools_strict = tools_strict
149156

150157
self.client = AzureOpenAI(
151158
api_version=api_version,
@@ -180,6 +187,8 @@ def to_dict(self) -> Dict[str, Any]:
180187
api_key=self.api_key.to_dict() if self.api_key is not None else None,
181188
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
182189
default_headers=self.default_headers,
190+
tools=[tool.to_dict() for tool in self.tools] if self.tools else None,
191+
tools_strict=self.tools_strict,
183192
)
184193

185194
@classmethod
@@ -192,6 +201,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureOpenAIChatGenerator":
192201
The deserialized component instance.
193202
"""
194203
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_ad_token"])
204+
deserialize_tools_inplace(data["init_parameters"], key="tools")
195205
init_params = data.get("init_parameters", {})
196206
serialized_callback_handler = init_params.get("streaming_callback")
197207
if serialized_callback_handler:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
features:
3+
- |
4+
Add support for Tools in the Azure OpenAI Chat Generator.

Diff for: test/components/generators/chat/test_azure.py

+127-3
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,25 @@
99
from haystack import Pipeline
1010
from haystack.components.generators.chat import AzureOpenAIChatGenerator
1111
from haystack.components.generators.utils import print_streaming_chunk
12-
from haystack.dataclasses import ChatMessage
12+
from haystack.dataclasses import ChatMessage, ToolCall
13+
from haystack.tools.tool import Tool
1314
from haystack.utils.auth import Secret
1415

1516

16-
class TestOpenAIChatGenerator:
17+
@pytest.fixture
18+
def tools():
19+
tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
20+
tool = Tool(
21+
name="weather",
22+
description="useful to determine the weather in a given location",
23+
parameters=tool_parameters,
24+
function=lambda x: x,
25+
)
26+
27+
return [tool]
28+
29+
30+
class TestAzureOpenAIChatGenerator:
1731
def test_init_default(self, monkeypatch):
1832
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
1933
component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint")
@@ -28,17 +42,21 @@ def test_init_fail_wo_api_key(self, monkeypatch):
2842
with pytest.raises(OpenAIError):
2943
AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint")
3044

31-
def test_init_with_parameters(self):
45+
def test_init_with_parameters(self, tools):
3246
component = AzureOpenAIChatGenerator(
3347
api_key=Secret.from_token("test-api-key"),
3448
azure_endpoint="some-non-existing-endpoint",
3549
streaming_callback=print_streaming_chunk,
3650
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
51+
tools=tools,
52+
tools_strict=True,
3753
)
3854
assert component.client.api_key == "test-api-key"
3955
assert component.azure_deployment == "gpt-4o-mini"
4056
assert component.streaming_callback is print_streaming_chunk
4157
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
58+
assert component.tools == tools
59+
assert component.tools_strict
4260

4361
def test_to_dict_default(self, monkeypatch):
4462
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
@@ -58,6 +76,8 @@ def test_to_dict_default(self, monkeypatch):
5876
"timeout": 30.0,
5977
"max_retries": 5,
6078
"default_headers": {},
79+
"tools": None,
80+
"tools_strict": False,
6181
},
6282
}
6383

@@ -85,15 +105,94 @@ def test_to_dict_with_parameters(self, monkeypatch):
85105
"timeout": 2.5,
86106
"max_retries": 10,
87107
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
108+
"tools": None,
109+
"tools_strict": False,
88110
"default_headers": {},
89111
},
90112
}
91113

114+
def test_from_dict(self, monkeypatch):
115+
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
116+
monkeypatch.setenv("AZURE_OPENAI_AD_TOKEN", "test-ad-token")
117+
data = {
118+
"type": "haystack.components.generators.chat.azure.AzureOpenAIChatGenerator",
119+
"init_parameters": {
120+
"api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
121+
"azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
122+
"api_version": "2023-05-15",
123+
"azure_endpoint": "some-non-existing-endpoint",
124+
"azure_deployment": "gpt-4o-mini",
125+
"organization": None,
126+
"streaming_callback": None,
127+
"generation_kwargs": {},
128+
"timeout": 30.0,
129+
"max_retries": 5,
130+
"default_headers": {},
131+
"tools": [
132+
{
133+
"type": "haystack.tools.tool.Tool",
134+
"data": {
135+
"description": "description",
136+
"function": "builtins.print",
137+
"name": "name",
138+
"parameters": {"x": {"type": "string"}},
139+
},
140+
}
141+
],
142+
"tools_strict": False,
143+
},
144+
}
145+
146+
generator = AzureOpenAIChatGenerator.from_dict(data)
147+
assert isinstance(generator, AzureOpenAIChatGenerator)
148+
149+
assert generator.api_key == Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False)
150+
assert generator.azure_ad_token == Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False)
151+
assert generator.api_version == "2023-05-15"
152+
assert generator.azure_endpoint == "some-non-existing-endpoint"
153+
assert generator.azure_deployment == "gpt-4o-mini"
154+
assert generator.organization is None
155+
assert generator.streaming_callback is None
156+
assert generator.generation_kwargs == {}
157+
assert generator.timeout == 30.0
158+
assert generator.max_retries == 5
159+
assert generator.default_headers == {}
160+
assert generator.tools == [
161+
Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print)
162+
]
163+
assert generator.tools_strict == False
164+
92165
def test_pipeline_serialization_deserialization(self, tmp_path, monkeypatch):
93166
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
94167
generator = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint")
95168
p = Pipeline()
96169
p.add_component(instance=generator, name="generator")
170+
171+
assert p.to_dict() == {
172+
"metadata": {},
173+
"max_runs_per_component": 100,
174+
"components": {
175+
"generator": {
176+
"type": "haystack.components.generators.chat.azure.AzureOpenAIChatGenerator",
177+
"init_parameters": {
178+
"azure_endpoint": "some-non-existing-endpoint",
179+
"azure_deployment": "gpt-4o-mini",
180+
"organization": None,
181+
"api_version": "2023-05-15",
182+
"streaming_callback": None,
183+
"generation_kwargs": {},
184+
"timeout": 30.0,
185+
"max_retries": 5,
186+
"api_key": {"type": "env_var", "env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False},
187+
"azure_ad_token": {"type": "env_var", "env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False},
188+
"default_headers": {},
189+
"tools": None,
190+
"tools_strict": False,
191+
},
192+
}
193+
},
194+
"connections": [],
195+
}
97196
p_str = p.dumps()
98197
q = Pipeline.loads(p_str)
99198
assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization w/ AzureOpenAIChatGenerator failed."
@@ -117,4 +216,29 @@ def test_live_run(self):
117216
assert "gpt-4o-mini" in message.meta["model"]
118217
assert message.meta["finish_reason"] == "stop"
119218

219+
@pytest.mark.integration
220+
@pytest.mark.skipif(
221+
not os.environ.get("AZURE_OPENAI_API_KEY", None) or not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
222+
reason=(
223+
"Please export env variables called AZURE_OPENAI_API_KEY containing "
224+
"the Azure OpenAI key, AZURE_OPENAI_ENDPOINT containing "
225+
"the Azure OpenAI endpoint URL to run this test."
226+
),
227+
)
228+
def test_live_run_with_tools(self, tools):
229+
chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
230+
component = AzureOpenAIChatGenerator(organization="HaystackCI", tools=tools)
231+
results = component.run(chat_messages)
232+
assert len(results["replies"]) == 1
233+
message = results["replies"][0]
234+
235+
assert not message.texts
236+
assert not message.text
237+
assert message.tool_calls
238+
tool_call = message.tool_call
239+
assert isinstance(tool_call, ToolCall)
240+
assert tool_call.tool_name == "weather"
241+
assert tool_call.arguments == {"city": "Paris"}
242+
assert message.meta["finish_reason"] == "tool_calls"
243+
120244
# additional tests intentionally omitted as they are covered by test_openai.py

0 commit comments

Comments
 (0)