-
Notifications
You must be signed in to change notification settings - Fork 95
/
Copy path_default.py
224 lines (208 loc) · 9.4 KB
/
_default.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import sys
import logging
import re
from typing import Optional
import urllib3
from azure.core.credentials import AccessToken
from azure.identity import (
AzurePowerShellCredential,
EnvironmentCredential,
ManagedIdentityCredential,
AzureCliCredential,
VisualStudioCodeCredential,
InteractiveBrowserCredential,
DeviceCodeCredential,
_internal as AzureIdentityInternals,
TokenCachePersistenceOptions,
SharedTokenCacheCredential,
_persistent_cache as AzureIdentityPersistentCache
)
from azure.quantum._constants import ConnectionConstants
from ._chained import _ChainedTokenCredential
from ._token import _TokenFileCredential
_LOGGER = logging.getLogger(__name__)

# HTTP response header whose bearer challenge is parsed for the tenant id.
WWW_AUTHENTICATE_HEADER_NAME = "WWW-Authenticate"

# Verbose-mode pattern for a challenge of the form
#   Bearer authorization_uri="https://<authority>/<tenant-guid>"
# capturing the authority host as "authority" and the GUID as "tenant_id".
# Whitespace/newlines inside the pattern are ignored because of re.VERBOSE.
_WWW_AUTHENTICATE_PATTERN = r"""
^
Bearer\sauthorization_uri="
https://(?P<authority>[^/]*)/
(?P<tenant_id>[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})
"
"""
WWW_AUTHENTICATE_REGEX = re.compile(
    _WWW_AUTHENTICATE_PATTERN,
    re.IGNORECASE | re.VERBOSE,
)
class _DefaultAzureCredential(_ChainedTokenCredential):
    """
    Based on Azure.Identity.DefaultAzureCredential from:
    https://github.com/Azure/azure-sdk-for-python/blob/master/sdk/identity/azure-identity/azure/identity/_credentials/default.py

    The three key differences are:
    1) Inherit from _ChainedTokenCredential, which has
       more aggressive error handling than ChainedTokenCredential.
    2) Instantiate the internal credentials the first time get_token is called,
       so that we can discover the tenant_id if it was not passed by the user
       (we don't want to make that network call in the constructor).
       We automatically identify the user's tenant_id for a given subscription
       so that users with MSA accounts don't need to pass it.
       This is a mitigation for bug https://github.com/Azure/azure-sdk-for-python/issues/18975
       We need the following parameters to enable auto-detection of tenant_id:
       - subscription_id
       - arm_endpoint (e.g. the production url "https://management.azure.com/")
    3) Add custom _TokenFileCredential as the first method to attempt,
       which will look for a local access token.
    """

    # Per-attempt timeout (seconds) for the unauthenticated ARM request used
    # for tenant discovery. Without it urllib3 waits indefinitely on an
    # unresponsive network and get_token() would hang; on failure we fall
    # back to the MSA tenant anyway.
    _TENANT_DISCOVERY_TIMEOUT_SECONDS = 10.0

    def __init__(
        self,
        arm_endpoint: str,
        subscription_id: str,
        client_id: Optional[str] = None,
        tenant_id: Optional[str] = None,
        authority: Optional[str] = None,
    ) -> None:
        """
        :param arm_endpoint: ARM endpoint used for tenant discovery. Required.
        :param subscription_id: subscription whose tenant is discovered. Required.
        :param client_id: optional client id; enables the managed identity
            and device code credentials.
        :param tenant_id: optional tenant id; when given, discovery is skipped.
        :param authority: optional authority host; when omitted it is derived
            from arm_endpoint.
        :raises ValueError: if arm_endpoint or subscription_id is None.
        """
        if arm_endpoint is None:
            raise ValueError("arm_endpoint is mandatory parameter")
        if subscription_id is None:
            raise ValueError("subscription_id is mandatory parameter")
        self.authority = self._authority_or_default(
            authority=authority,
            arm_endpoint=arm_endpoint)
        self.tenant_id = tenant_id
        self.subscription_id = subscription_id
        self.arm_endpoint = arm_endpoint
        self.client_id = client_id
        # Credentials are created lazily on the first call to get_token.
        super().__init__()

    def _authority_or_default(self, authority: Optional[str], arm_endpoint: str) -> str:
        """
        Return the normalized explicit authority if one was given; otherwise
        the dogfood authority for the dogfood ARM endpoint, else the default
        production authority.
        """
        if authority:
            return AzureIdentityInternals.normalize_authority(authority)
        if arm_endpoint == ConnectionConstants.ARM_DOGFOOD_ENDPOINT:
            return ConnectionConstants.DOGFOOD_AUTHORITY
        return ConnectionConstants.AUTHORITY

    def _get_cache_options(self) -> Optional[TokenCachePersistenceOptions]:
        """
        Returns a valid TokenCachePersistenceOptions
        if the AzureIdentity Persistent Cache is accessible.
        Returns None otherwise (after logging a warning, plus a libsecret
        diagnostic on Linux).
        """
        cache_options = TokenCachePersistenceOptions(
            allow_unencrypted_storage=False,
            name="AzureQuantumSDK"
        )
        try:
            # pylint: disable=protected-access
            cache = AzureIdentityPersistentCache._load_persistent_cache(cache_options)
        except Exception as ex:  # pylint: disable=broad-except
            # Check if the cache issue on Linux is due to libsecret not
            # functioning, to give the user more actionable information.
            if sys.platform.startswith("linux"):
                try:
                    # pylint: disable=import-outside-toplevel
                    from msal_extensions.libsecret import trial_run
                    trial_run()
                except Exception as libsecret_ex:  # pylint: disable=broad-except
                    _LOGGER.warning(
                        "libsecret dependencies are not installed or are unusable.\n"
                        "Please install the necessary dependencies as instructed in "
                        "https://github.com/AzureAD/microsoft-authentication-extensions-for-python/wiki/Encryption-on-Linux\n"  # pylint: disable=line-too-long
                        "Exception:\n%s",
                        libsecret_ex,
                        exc_info=_LOGGER.isEnabledFor(logging.DEBUG),
                    )
            _LOGGER.warning(
                'Error trying to access Azure.Identity Token Cache. '
                "Raised unexpected exception:\n%s",
                ex,
                exc_info=_LOGGER.isEnabledFor(logging.DEBUG),
            )
            return None

        try:
            # Try to get the location of the cache for tracing purposes;
            # this relies on a private attribute, so tolerate any failure.
            _LOGGER.info(
                "Using Azure.Identity Token Cache at %s.",
                cache._persistence.get_location()  # pylint: disable=protected-access
            )
        except Exception:  # pylint: disable=broad-except
            _LOGGER.info("Using Azure.Identity Token Cache.")
        return cache_options

    def _initialize_credentials(self) -> None:
        """
        Discover the tenant (if needed), probe the persistent token cache and
        build the ordered credential chain stored in self.credentials.
        Order matters: earlier credentials are attempted first.
        """
        self._discover_tenant_id_(
            arm_endpoint=self.arm_endpoint,
            subscription_id=self.subscription_id)
        cache_options = self._get_cache_options()

        credentials = [
            _TokenFileCredential(),
            EnvironmentCredential(),
        ]
        if self.client_id:
            credentials.append(ManagedIdentityCredential(client_id=self.client_id))
        if self.authority and self.tenant_id:
            credentials.append(VisualStudioCodeCredential(
                authority=self.authority,
                tenant_id=self.tenant_id))
        credentials.append(AzureCliCredential(tenant_id=self.tenant_id))
        credentials.append(AzurePowerShellCredential(tenant_id=self.tenant_id))
        # The SharedTokenCacheCredential is used when the token cache
        # is available to attempt loading a token stored in the cache
        # by the InteractiveBrowserCredential.
        if cache_options:
            credentials.append(SharedTokenCacheCredential(
                authority=self.authority,
                tenant_id=self.tenant_id,
                cache_persistence_options=cache_options))
        credentials.append(
            InteractiveBrowserCredential(
                authority=self.authority,
                tenant_id=self.tenant_id,
                cache_persistence_options=cache_options))
        if self.client_id:
            credentials.append(DeviceCodeCredential(
                authority=self.authority,
                client_id=self.client_id,
                tenant_id=self.tenant_id))
        self.credentials = credentials

    def get_token(self, *scopes: str, **kwargs) -> AccessToken:
        """
        Request an access token for `scopes`.

        This method is called automatically by Azure SDK clients.

        :param str scopes: desired scopes for the access token.
            This method requires at least one scope.
        :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed.
            The exception has a `message` attribute listing each authentication
            attempt and its error message.
        """
        # Lazily build the credential chain; this is where tenant discovery
        # (a network call) happens, which the constructor deliberately avoids.
        if not self.credentials:
            self._initialize_credentials()
        return super().get_token(*scopes, **kwargs)

    def _discover_tenant_id_(self, arm_endpoint: str, subscription_id: str) -> None:
        """
        If the tenant_id was not given, try to obtain it
        by calling the management endpoint for the subscription_id,
        or by applying default values.

        The request is sent without credentials; if the response carries a
        WWW-Authenticate bearer challenge, the tenant GUID is parsed from it.
        Any error is logged and self.tenant_id falls back to
        ConnectionConstants.MSA_TENANT_ID.
        """
        if self.tenant_id:
            return
        try:
            url = (
                f"{arm_endpoint.rstrip('/')}/subscriptions/"
                f"{subscription_id}?api-version=2018-01-01"
                "&discover-tenant-id"  # used by the test recording infrastructure
            )
            # Use the pool as a context manager so its connections are
            # released instead of leaking one PoolManager per discovery.
            with urllib3.PoolManager() as http:
                response = http.request(
                    method="GET",
                    url=url,
                    timeout=self._TENANT_DISCOVERY_TIMEOUT_SECONDS,
                )
            www_authenticate = response.headers.get(WWW_AUTHENTICATE_HEADER_NAME)
            if www_authenticate:
                match = WWW_AUTHENTICATE_REGEX.search(www_authenticate)
                if match:
                    self.tenant_id = match.group("tenant_id")
        except Exception as ex:  # pylint: disable=broad-exception-caught
            _LOGGER.error(ex)
        # Apply the default when discovery did not yield a tenant.
        self.tenant_id = self.tenant_id or ConnectionConstants.MSA_TENANT_ID