From 16f158f60099725099579113e58933fb8583549d Mon Sep 17 00:00:00 2001 From: Young Date: Thu, 6 Feb 2025 14:20:34 +0000 Subject: [PATCH] fix: prioritize api_key over tenant_id for Azure AD token provider --- .../client_initalization_utils.py | 5 ++-- .../get_azure_ad_token_provider.py | 23 ++++++++++--------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/litellm/router_utils/client_initalization_utils.py b/litellm/router_utils/client_initalization_utils.py index 8a11edce8f1b..486455c9ee17 100644 --- a/litellm/router_utils/client_initalization_utils.py +++ b/litellm/router_utils/client_initalization_utils.py @@ -192,7 +192,8 @@ def set_client( # noqa: PLR0915 organization = get_secret_str(organization_env_name) litellm_params["organization"] = organization azure_ad_token_provider: Optional[Callable[[], str]] = None - if litellm_params.get("tenant_id"): + # If we have api_key, then we have higher priority + if not api_key and litellm_params.get("tenant_id"): verbose_router_logger.debug( "Using Azure AD Token Provider for Azure Auth" ) @@ -223,7 +224,7 @@ def set_client( # noqa: PLR0915 if azure_ad_token.startswith("oidc/"): azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) elif ( - azure_ad_token_provider is None + not api_key and azure_ad_token_provider is None and litellm.enable_azure_ad_token_refresh is True ): try: diff --git a/litellm/secret_managers/get_azure_ad_token_provider.py b/litellm/secret_managers/get_azure_ad_token_provider.py index 82e725ee8712..2e243cf970bc 100644 --- a/litellm/secret_managers/get_azure_ad_token_provider.py +++ b/litellm/secret_managers/get_azure_ad_token_provider.py @@ -1,5 +1,6 @@ import os from typing import Callable +from litellm._logging import verbose_logger def get_azure_ad_token_provider() -> Callable[[], str]: @@ -14,20 +15,20 @@ def get_azure_ad_token_provider() -> Callable[[], str]: Returns: Callable that returns a temporary authentication token. """ - from azure.identity import ClientSecretCredential, get_bearer_token_provider + from azure.identity import get_bearer_token_provider + import azure.identity as identity + azure_scope = os.environ.get("AZURE_SCOPE", "https://cognitiveservices.azure.com/.default") + cred = os.environ.get("AZURE_CREDENTIAL", "ClientSecretCredential") - try: - credential = ClientSecretCredential( + cred_cls = getattr(identity, cred) + # ClientSecretCredential, DefaultAzureCredential, AzureCliCredential + if cred == "ClientSecretCredential": + credential = cred_cls( client_id=os.environ["AZURE_CLIENT_ID"], client_secret=os.environ["AZURE_CLIENT_SECRET"], tenant_id=os.environ["AZURE_TENANT_ID"], ) - except KeyError as e: - raise ValueError( - "Missing environment variable required by Azure AD workflow." - ) from e + else: + credential = cred_cls() - return get_bearer_token_provider( - credential, - "https://cognitiveservices.azure.com/.default", - ) + return get_bearer_token_provider(credential, azure_scope)