Skip to content

fix(azure): remove unnecessary model parameter and require azure deployment #2123

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
22 changes: 12 additions & 10 deletions src/openai/lib/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,11 +216,7 @@ def __init__(
raise ValueError(
"Must provide one of the `base_url` or `azure_endpoint` arguments, or the `AZURE_OPENAI_ENDPOINT` environment variable"
)

if azure_deployment is not None:
base_url = f"{azure_endpoint.rstrip('/')}/openai/deployments/{azure_deployment}"
else:
base_url = f"{azure_endpoint.rstrip('/')}/openai"
base_url = f"{azure_endpoint.rstrip('/')}/openai"
else:
if azure_endpoint is not None:
raise ValueError("base_url and azure_endpoint are mutually exclusive")
Expand All @@ -229,6 +225,7 @@ def __init__(
# define a sentinel value to avoid any typing issues
api_key = API_KEY_SENTINEL

self._azure_deployment = azure_deployment
super().__init__(
api_key=api_key,
organization=organization,
Expand Down Expand Up @@ -337,10 +334,14 @@ def _configure_realtime(self, model: str, extra_query: Query) -> tuple[httpx.URL
"api-version": self._api_version,
"deployment": self._azure_deployment or model,
}
if self._azure_deployment:
query["deployment"] = self._azure_deployment
else:
query["deployment"] = model
if self.api_key != "<missing API key>":
auth_headers = {"api-key": self.api_key}
else:
token = self._get_azure_ad_token()
token = await self._get_azure_ad_token()
if token:
auth_headers = {"Authorization": f"Bearer {token}"}

Expand Down Expand Up @@ -491,10 +492,7 @@ def __init__(
"Must provide one of the `base_url` or `azure_endpoint` arguments, or the `AZURE_OPENAI_ENDPOINT` environment variable"
)

if azure_deployment is not None:
base_url = f"{azure_endpoint.rstrip('/')}/openai/deployments/{azure_deployment}"
else:
base_url = f"{azure_endpoint.rstrip('/')}/openai"
base_url = f"{azure_endpoint.rstrip('/')}/openai"
else:
if azure_endpoint is not None:
raise ValueError("base_url and azure_endpoint are mutually exclusive")
Expand Down Expand Up @@ -613,6 +611,10 @@ async def _configure_realtime(self, model: str, extra_query: Query) -> tuple[htt
"api-version": self._api_version,
"deployment": self._azure_deployment or model,
}
if self._azure_deployment:
query["deployment"] = self._azure_deployment
else:
query["deployment"] = model
if self.api_key != "<missing API key>":
auth_headers = {"api-key": self.api_key}
else:
Expand Down
34 changes: 16 additions & 18 deletions src/openai/resources/beta/realtime/realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,15 +362,14 @@ async def __aenter__(self) -> AsyncRealtimeConnection:
extra_query = self.__extra_query
auth_headers = self.__client.auth_headers
if is_async_azure_client(self.__client):
url, auth_headers = await self.__client._configure_realtime(self.__model, extra_query)
else:
url = self._prepare_url().copy_with(
params={
**self.__client.base_url.params,
"model": self.__model,
**extra_query,
},
)
extra_query, auth_headers = await self.__client._configure_realtime(self.__model, extra_query)

url = self._prepare_url().copy_with(
params={
"model": self.__model,
**extra_query,
},
)
log.debug("Connecting to %s", url)
if self.__websocket_connection_options:
log.debug("Connection options: %s", self.__websocket_connection_options)
Expand Down Expand Up @@ -546,15 +545,14 @@ def __enter__(self) -> RealtimeConnection:
extra_query = self.__extra_query
auth_headers = self.__client.auth_headers
if is_azure_client(self.__client):
url, auth_headers = self.__client._configure_realtime(self.__model, extra_query)
else:
url = self._prepare_url().copy_with(
params={
**self.__client.base_url.params,
"model": self.__model,
**extra_query,
},
)
extra_query, auth_headers = self.__client._configure_realtime(self.__model, extra_query)

url = self._prepare_url().copy_with(
params={
"model": self.__model,
**extra_query,
},
)
log.debug("Connecting to %s", url)
if self.__websocket_connection_options:
log.debug("Connection options: %s", self.__websocket_connection_options)
Expand Down