Skip to content

Commit 685d12b

Browse files
committed
feat(drivers-azure): bump api_version in Azure Prompt/Embedding Drivers
1 parent 6d7b2b7 commit 685d12b

File tree

5 files changed

+43
-13
lines changed

5 files changed

+43
-13
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
4343
- Improved mime type detection in `FileManagerTool`.
4444
- Improve `SqlDriver.get_table_schema` speed.
4545
- Cache `SqlDriver.get_table_schema` results.
46+
- Updated Azure Drivers to use the latest Azure OpenAI API version, `2024-10-21`.
4647

4748
### Deprecated
4849

griptape/drivers/embedding/azure_openai_embedding_driver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class AzureOpenAiEmbeddingDriver(OpenAiEmbeddingDriver):
3636
default=None,
3737
metadata={"serializable": False},
3838
)
39-
api_version: str = field(default="2023-05-15", kw_only=True, metadata={"serializable": True})
39+
api_version: str = field(default="2024-10-21", kw_only=True, metadata={"serializable": True})
4040
tokenizer: OpenAiTokenizer = field(
4141
default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True),
4242
kw_only=True,

griptape/drivers/prompt/azure_openai_chat_prompt_driver.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class AzureOpenAiChatPromptDriver(OpenAiChatPromptDriver):
3737
default=None,
3838
metadata={"serializable": False},
3939
)
40-
api_version: str = field(default="2023-05-15", kw_only=True, metadata={"serializable": True})
40+
api_version: str = field(default="2024-10-21", kw_only=True, metadata={"serializable": True})
4141
_client: openai.AzureOpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
4242

4343
@lazy_property()
@@ -54,13 +54,11 @@ def client(self) -> openai.AzureOpenAI:
5454

5555
def _base_params(self, prompt_stack: PromptStack) -> dict:
5656
params = super()._base_params(prompt_stack)
57-
# TODO: Add `seed` parameter once Azure supports it.
58-
if "seed" in params:
57+
if self.api_version < "2024-02-01" and "seed" in params:
5958
del params["seed"]
60-
# TODO: Add `stream_options` parameter once Azure supports it.
61-
if "stream_options" in params:
62-
del params["stream_options"]
63-
# TODO: Add `parallel_tool_calls` parameter once Azure supports it.
64-
if "parallel_tool_calls" in params:
65-
del params["parallel_tool_calls"]
59+
if self.api_version < "2024-10-21":
60+
if "stream_options" in params:
61+
del params["stream_options"]
62+
if "parallel_tool_calls" in params:
63+
del params["parallel_tool_calls"]
6664
return params

tests/unit/configs/drivers/test_azure_openai_drivers_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def test_to_dict(self, config):
2626
"model": "gpt-4o",
2727
"azure_deployment": "gpt-4o",
2828
"azure_endpoint": "http://localhost:8080",
29-
"api_version": "2023-05-15",
29+
"api_version": "2024-10-21",
3030
"organization": None,
3131
"parallel_tool_calls": True,
3232
"reasoning_effort": "medium",
@@ -47,7 +47,7 @@ def test_to_dict(self, config):
4747
"embedding_driver": {
4848
"base_url": None,
4949
"model": "text-embedding-3-small",
50-
"api_version": "2023-05-15",
50+
"api_version": "2024-10-21",
5151
"azure_deployment": "text-embedding-3-small",
5252
"azure_endpoint": "http://localhost:8080",
5353
"organization": None,
@@ -70,7 +70,7 @@ def test_to_dict(self, config):
7070
"embedding_driver": {
7171
"base_url": None,
7272
"model": "text-embedding-3-small",
73-
"api_version": "2023-05-15",
73+
"api_version": "2024-10-21",
7474
"azure_deployment": "text-embedding-3-small",
7575
"azure_endpoint": "http://localhost:8080",
7676
"organization": None,

tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,15 @@ def test_init(self):
6868

6969
@pytest.mark.parametrize("use_native_tools", [True, False])
7070
@pytest.mark.parametrize("structured_output_strategy", ["native", "tool"])
71+
@pytest.mark.parametrize("api_version", ["2023-05-15", "2024-02-01", "2024-06-01", "2024-10-21"])
7172
def test_try_run(
7273
self,
7374
mock_chat_completion_create,
7475
prompt_stack,
7576
messages,
7677
use_native_tools,
7778
structured_output_strategy,
79+
api_version,
7880
):
7981
# Given
8082
driver = AzureOpenAiChatPromptDriver(
@@ -84,6 +86,7 @@ def test_try_run(
8486
use_native_tools=use_native_tools,
8587
structured_output_strategy=structured_output_strategy,
8688
extra_params={"foo": "bar"},
89+
api_version=api_version,
8790
)
8891

8992
# When
@@ -95,6 +98,16 @@ def test_try_run(
9598
temperature=driver.temperature,
9699
user=driver.user,
97100
messages=messages,
101+
**{
102+
"seed": driver.seed,
103+
}
104+
if driver.api_version >= "2024-02-01"
105+
else {},
106+
**{
107+
"parallel_tool_calls": driver.parallel_tool_calls,
108+
}
109+
if driver.api_version >= "2024-10-21" and prompt_stack.tools and driver.use_native_tools
110+
else {},
98111
**{
99112
"tools": self.OPENAI_TOOLS,
100113
"tool_choice": "required" if structured_output_strategy == "tool" else driver.tool_choice,
@@ -125,13 +138,15 @@ def test_try_run(
125138

126139
@pytest.mark.parametrize("use_native_tools", [True, False])
127140
@pytest.mark.parametrize("structured_output_strategy", ["native", "tool"])
141+
@pytest.mark.parametrize("api_version", ["2023-05-15", "2024-02-01", "2024-06-01", "2024-10-21"])
128142
def test_try_stream_run(
129143
self,
130144
mock_chat_completion_stream_create,
131145
prompt_stack,
132146
messages,
133147
use_native_tools,
134148
structured_output_strategy,
149+
api_version,
135150
):
136151
# Given
137152
driver = AzureOpenAiChatPromptDriver(
@@ -142,6 +157,7 @@ def test_try_stream_run(
142157
use_native_tools=use_native_tools,
143158
structured_output_strategy=structured_output_strategy,
144159
extra_params={"foo": "bar"},
160+
api_version=api_version,
145161
)
146162

147163
# When
@@ -161,6 +177,16 @@ def test_try_stream_run(
161177
}
162178
if use_native_tools
163179
else {},
180+
**{
181+
"seed": driver.seed,
182+
}
183+
if driver.api_version >= "2024-02-01"
184+
else {},
185+
**{
186+
"parallel_tool_calls": driver.parallel_tool_calls,
187+
}
188+
if driver.api_version >= "2024-10-21" and prompt_stack.tools and driver.use_native_tools
189+
else {},
164190
**{
165191
"response_format": {
166192
"type": "json_schema",
@@ -174,6 +200,11 @@ def test_try_stream_run(
174200
if structured_output_strategy == "native"
175201
else {},
176202
foo="bar",
203+
**{
204+
"stream_options": {"include_usage": True},
205+
}
206+
if driver.api_version >= "2024-10-21"
207+
else {},
177208
)
178209

179210
assert isinstance(event.content, TextDeltaMessageContent)

0 commit comments

Comments
 (0)