diff --git a/fastapi_azure_auth/openid_config.py b/fastapi_azure_auth/openid_config.py index 470fc4a..da9933c 100644 --- a/fastapi_azure_auth/openid_config.py +++ b/fastapi_azure_auth/openid_config.py @@ -1,4 +1,5 @@ import logging +from asyncio import Lock from datetime import datetime, timedelta from typing import TYPE_CHECKING, Any, Dict, List, Optional @@ -11,6 +12,8 @@ log = logging.getLogger('fastapi_azure_auth') +refresh_lock: Lock = Lock() + class OpenIdConfig: def __init__( @@ -35,24 +38,25 @@ async def load_config(self) -> None: """ Loads config from the Intility openid-config endpoint if it's over 24 hours old (or don't exist) """ - refresh_time = datetime.now() - timedelta(hours=24) - if not self._config_timestamp or self._config_timestamp < refresh_time: - try: - log.debug('Loading Azure Entra ID OpenID configuration.') - await self._load_openid_config() - self._config_timestamp = datetime.now() - except Exception as error: - log.exception('Unable to fetch OpenID configuration from Azure Entra ID. Error: %s', error) - # We can't fetch an up to date openid-config, so authentication will not work. - if self._config_timestamp: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail='Connection to Azure Entra ID is down. Unable to fetch provider configuration', - headers={'WWW-Authenticate': 'Bearer'}, - ) from error - - else: - raise RuntimeError(f'Unable to fetch provider information. {error}') from error + async with refresh_lock: + refresh_time = datetime.now() - timedelta(hours=24) + if not self._config_timestamp or self._config_timestamp < refresh_time: + try: + log.debug('Loading Azure Entra ID OpenID configuration.') + await self._load_openid_config() + self._config_timestamp = datetime.now() + except Exception as error: + log.exception('Unable to fetch OpenID configuration from Azure Entra ID. Error: %s', error) + # We can't fetch an up to date openid-config, so authentication will not work. + if self._config_timestamp: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail='Connection to Azure Entra ID is down. Unable to fetch provider configuration', + headers={'WWW-Authenticate': 'Bearer'}, + ) from error + + else: + raise RuntimeError(f'Unable to fetch provider information. {error}') from error log.info('fastapi-azure-auth loaded settings from Azure Entra ID.') log.info('authorization endpoint: %s', self.authorization_endpoint) diff --git a/tests/test_provider_config.py b/tests/test_provider_config.py index 757f0c5..d619f23 100644 --- a/tests/test_provider_config.py +++ b/tests/test_provider_config.py @@ -1,11 +1,14 @@ +import asyncio from datetime import datetime, timedelta +import httpx import pytest +import respx from asgi_lifespan import LifespanManager from demo_project.api.dependencies import azure_scheme from demo_project.main import app from httpx import AsyncClient -from tests.utils import build_access_token, build_openid_keys, openid_configuration +from tests.utils import build_access_token, build_openid_keys, keys_url, openid_config_url, openid_configuration from fastapi_azure_auth.openid_config import OpenIdConfig @@ -64,3 +67,28 @@ async def test_custom_config_id(respx_mock): ) await openid_config.load_config() assert len(openid_config.signing_keys) == 2 + + +async def test_concurrent_refresh_requests(): + """Test that concurrent refreshes are handled correctly""" + with respx.mock(assert_all_called=True) as mock: + + async def slow_config_response(*args, **kwargs): + await asyncio.sleep(0.2) + return httpx.Response(200, json=openid_configuration()) + + async def slow_keys_response(*args, **kwargs): + await asyncio.sleep(0.2) + return httpx.Response(200, json=build_openid_keys()) + + config_route = mock.get(openid_config_url()).mock(side_effect=slow_config_response) + keys_route = mock.get(keys_url()).mock(side_effect=slow_keys_response) + + azure_scheme.openid_config._config_timestamp = None + + tasks = [azure_scheme.openid_config.load_config() for _ in range(5)] + await asyncio.gather(*tasks) + + assert len(config_route.calls) == 1, "Config endpoint called multiple times" + assert len(keys_route.calls) == 1, "Keys endpoint called multiple times" + assert len(azure_scheme.openid_config.signing_keys) == 2