From 748faffe023f76f6e199feff4828fe547227ea6a Mon Sep 17 00:00:00 2001 From: logerzerox Date: Mon, 24 Feb 2025 17:21:36 +0800 Subject: [PATCH] Add API key token generation endpoint --- py/core/main/api/v3/users_router.py | 22 ++++++++++++++++++++++ py/core/main/services/auth_service.py | 4 ++++ py/core/providers/auth/r2r_auth.py | 16 ++++++++++++++++ 3 files changed, 42 insertions(+) diff --git a/py/core/main/api/v3/users_router.py b/py/core/main/api/v3/users_router.py index d1181813d..4b7b22994 100644 --- a/py/core/main/api/v3/users_router.py +++ b/py/core/main/api/v3/users_router.py @@ -52,6 +52,28 @@ def __init__( self.github_redirect_uri = os.environ.get("GITHUB_REDIRECT_URI") def _setup_routes(self): + @self.router.post( + "/users/generate-tokens", + response_model=WrappedTokenResponse, + openapi_extra={ + "x-codeSamples": [ + { + "lang": "cURL", + "source": "curl -X POST https://api.example.com/v3/users/generate-tokens -H 'x-api-key: YOUR_API_KEY'" + } + ] + } + ) + @self.base_endpoint + async def generate_tokens_via_api_key( + auth_user=Depends(self.providers.auth.auth_wrapper(public=True)) + ) -> WrappedTokenResponse: + """Generate new access and refresh tokens using API key authentication.""" + result = await self.services.auth.generate_tokens_via_api_key( + user_id=auth_user.id + ) + return result + @self.router.post( "/users", # dependencies=[Depends(self.rate_limit_dependency)], diff --git a/py/core/main/services/auth_service.py b/py/core/main/services/auth_service.py index 980b574b6..9bfe7c3d3 100644 --- a/py/core/main/services/auth_service.py +++ b/py/core/main/services/auth_service.py @@ -330,3 +330,7 @@ async def list_user_api_keys(self, user_id: UUID) -> list[dict]: dict: Contains the list of API keys """ return await self.providers.auth.list_user_api_keys(user_id) + + async def generate_tokens_via_api_key(self, user_id: UUID) -> dict[str, Token]: + """Expose the provider method through the service layer.""" + return await self.providers.auth.generate_tokens_via_api_key(user_id) diff --git a/py/core/providers/auth/r2r_auth.py b/py/core/providers/auth/r2r_auth.py index 762884ce3..6c079c01b 100644 --- a/py/core/providers/auth/r2r_auth.py +++ b/py/core/providers/auth/r2r_auth.py @@ -699,3 +699,19 @@ async def oauth_callback_handler( "access_token": Token(token=access_token, token_type="access"), "refresh_token": Token(token=refresh_token, token_type="refresh"), } + + async def generate_tokens_via_api_key(self, user_id: UUID) -> dict[str, Token]: + """Generate new tokens for API key authenticated users.""" + user = await self.database_provider.users_handler.get_user_by_id(user_id) + + access_token = self.create_access_token( + data={"sub": normalize_email(user.email)} + ) + refresh_token = self.create_refresh_token( + data={"sub": normalize_email(user.email)} + ) + + return { + "access_token": Token(token=access_token, token_type="access"), + "refresh_token": Token(token=refresh_token, token_type="refresh"), + }