Skip to content

Add back the SharedTokenCacheCredential to handle token which is cached by the InteractiveBrowserCredential #603

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
59 changes: 55 additions & 4 deletions azure-quantum/azure/quantum/_authentication/_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
InteractiveBrowserCredential,
DeviceCodeCredential,
_internal as AzureIdentityInternals,
TokenCachePersistenceOptions,
SharedTokenCacheCredential,
_persistent_cache as AzureIdentityPersistentCache
)
from ._chained import _ChainedTokenCredential
from ._token import _TokenFileCredential
Expand Down Expand Up @@ -84,22 +87,70 @@ def _authority_or_default(self, authority: str, arm_endpoint: str):
return ConnectionConstants.DOGFOOD_AUTHORITY
return ConnectionConstants.AUTHORITY

def _initialize_credentials(self):
def _get_cache_options(self) -> Optional[TokenCachePersistenceOptions]:
"""
Returns a valid TokenCachePersistenceOptions
if the AzureIdentity Persistent Cache is accessible.
Returns None otherwise.
"""
cache_options = TokenCachePersistenceOptions(
allow_unencrypted_storage=True,
name="AzureQuantumSDK"
)
try:
# pylint: disable=protected-access
cache = AzureIdentityPersistentCache._load_persistent_cache(cache_options)
try:
_LOGGER.error(
'Using Azure.Identity Token Cache at %s. ',
cache._persistence.get_location()
)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry if I am missing something here, but should this be an error log level?

and what do we really guard here in that try/catch clause? cache._persistence.get_location() ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. I was using error log level for local testing. I updated it to info.

I guarded the cache._persistence.get_location() because:

  1. We are calling it just for tracing / info logging purposes
  2. _persistence is a private attribute that can change in the future
  3. get_location() could raise an expected exception in a particular situation that we are not aware of
  4. The cache implementation is different for each OS and would be hard for us to guarantee it always work
  5. Finally, I thought it's not worth the risk of this info logging to break our auth :-)

But feel free to modify the code as you see better.
Like we could try to get the location, but even if fails, it we could log the rest of the info ("Using Azure.Identity Token Cache.").

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made a small improvement based on my own comment, but feel free to adjust it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that makes a lot of sense now, thanks :)

except: # pylint: disable=bare-except
pass
return cache_options
except Exception as ex: # pylint: disable=broad-except
_LOGGER.warning(
'Error trying to access Azure.Identity Token Cache at %s. '
'Raised unexpected exception:\n%s',
self.__class__._get_cache_options.__qualname__,
ex,
exc_info=_LOGGER.isEnabledFor(logging.DEBUG),
)
return None

def _initialize_credentials(self) -> None:
self._discover_tenant_id_(
arm_endpoint=self.arm_endpoint,
subscription_id=self.subscription_id)
cache_options = self._get_cache_options()
credentials = []
credentials.append(_TokenFileCredential())
credentials.append(EnvironmentCredential())
if self.client_id:
credentials.append(ManagedIdentityCredential(client_id=self.client_id))
if self.authority and self.tenant_id:
credentials.append(VisualStudioCodeCredential(authority=self.authority, tenant_id=self.tenant_id))
credentials.append(VisualStudioCodeCredential(
authority=self.authority,
tenant_id=self.tenant_id))
credentials.append(AzureCliCredential(tenant_id=self.tenant_id))
credentials.append(AzurePowerShellCredential(tenant_id=self.tenant_id))
credentials.append(InteractiveBrowserCredential(authority=self.authority, tenant_id=self.tenant_id))
# The SharedTokenCacheCredential is used when the token cache
# is available to attempt loading a token stored in the cache
# by the InteractiveBrowserCredential.
if cache_options:
credentials.append(SharedTokenCacheCredential(
authority=self.authority,
Copy link
Contributor

@kikomiss kikomiss Apr 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we might also need to pass tenant_id=self.tenant_id to here. It will allow the filtering part to pick-up accounts only the discovered tenant

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Contributor Author

@ArthurKamalov ArthurKamalov Apr 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I mentioned in this comment #603 (comment), on Linux I have a mismatch of a current tenant id and id from the cache, which leads to exception in the SharedTokenCacheCredential. Still trying to find the cause, still not sure if that's only on my machine or not.

cache_persistence_options=cache_options))
credentials.append(
InteractiveBrowserCredential(
authority=self.authority,
tenant_id=self.tenant_id,
cache_persistence_options=cache_options))
if self.client_id:
credentials.append(DeviceCodeCredential(authority=self.authority, client_id=self.client_id, tenant_id=self.tenant_id))
credentials.append(DeviceCodeCredential(
authority=self.authority,
client_id=self.client_id,
tenant_id=self.tenant_id))
self.credentials = credentials

def get_token(self, *scopes: str, **kwargs) -> AccessToken:
Expand Down
Loading