Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 106 additions & 3 deletions src/sempy_labs/_authentication.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import Literal
from typing import Literal, Optional
from sempy.fabric._token_provider import TokenProvider
from azure.identity import ClientSecretCredential
from sempy._utils._log import log
from contextlib import contextmanager
import contextvars


class ServicePrincipalTokenProvider(TokenProvider):
Expand Down Expand Up @@ -41,6 +44,10 @@ def from_aad_application_key_authentication(
tenant_id=tenant_id, client_id=client_id, client_secret=client_secret
)

cls.tenant_id = tenant_id
cls.client_id = client_id
cls.client_secret = client_secret

return cls(credential)

@classmethod
Expand Down Expand Up @@ -89,20 +96,116 @@ def from_azure_key_vault(
tenant_id=tenant_id, client_id=client_id, client_secret=client_secret
)

cls.tenant_id = tenant_id
cls.client_id = client_id
cls.client_secret = client_secret

return cls(credential)

def __call__(self, audience: Literal["pbi", "storage"] = "pbi") -> str:
def __call__(
self,
audience: Literal[
"pbi", "storage", "azure", "graph", "asazure", "keyvault"
] = "pbi",
region: Optional[str] = None,
) -> str:
"""
Parameters
----------
audience : Literal["pbi", "storage"] = "pbi") -> str
audience : Literal["pbi", "storage", "azure", "graph", "asazure", "keyvault"] = "pbi") -> str
Literal if it's for PBI/Fabric API call or OneLake/Storage Account call.
region : str, default=None
The region of the Azure Analysis Services. For example: 'westus2'.
"""
if audience == "pbi":
return self.credential.get_token(
"https://analysis.windows.net/powerbi/api/.default"
).token
elif audience == "storage":
return self.credential.get_token("https://storage.azure.com/.default").token
elif audience == "azure":
return self.credential.get_token(
"https://management.azure.com/.default"
).token
elif audience == "graph":
return self.credential.get_token(
"https://graph.microsoft.com/.default"
).token
elif audience == "asazure":
return self.credential.get_token(
f"https://{region}.asazure.windows.net/.default"
).token
elif audience == "keyvault":
return self.credential.get_token("https://vault.azure.net/.default").token
else:
raise NotImplementedError


def _get_headers(
token_provider: str,
audience: Literal[
"pbi", "storage", "azure", "graph", "asazure", "keyvault"
] = "azure",
):
"""
Generates headers for an API request.
"""

token = token_provider(audience=audience)

headers = {"Authorization": f"Bearer {token}"}

if audience == "graph":
headers["ConsistencyLevel"] = "eventual"
else:
headers["Content-Type"] = "application/json"

return headers


token_provider = contextvars.ContextVar("token_provider", default=None)


@log
@contextmanager
def service_principal_authentication(
key_vault_uri: str,
key_vault_tenant_id: str,
key_vault_client_id: str,
key_vault_client_secret: str,
):
"""
Establishes an authentication via Service Principal.

Parameters
----------
key_vault_uri : str
Azure Key Vault URI.
key_vault_tenant_id : str
Name of the secret in the Key Vault with the Fabric Tenant ID.
key_vault_client_id : str
Name of the secret in the Key Vault with the Service Principal Client ID.
key_vault_client_secret : str
Name of the secret in the Key Vault with the Service Principal Client Secret.
"""

# Save the prior state
prior_token = token_provider.get()

# Set the new token_provider in a thread-safe manner
token_provider.set(
ServicePrincipalTokenProvider.from_azure_key_vault(
key_vault_uri=key_vault_uri,
key_vault_tenant_id=key_vault_tenant_id,
key_vault_client_id=key_vault_client_id,
key_vault_client_secret=key_vault_client_secret,
)
)
try:
yield
finally:
# Restore the prior state
if prior_token is None:
token_provider.set(None)
else:
token_provider.set(prior_token)
Loading