diff --git a/features/mgmt_api/environment.py b/features/mgmt_api/environment.py index 88bd8e110..9149483e9 100644 --- a/features/mgmt_api/environment.py +++ b/features/mgmt_api/environment.py @@ -54,6 +54,7 @@ async def before_scenario_async(context, scenario): "MGMT_API_ENABLED": True, "AUTH_SERVER_SHARED_SECRET": "test-secret", "CACHE_TYPE": "null", + "ASYNC_AUTH_CLASS": "newsroom.mgmt_api.auth:JWTTokenAuth", } context.app = get_app(config=config) diff --git a/features/mgmt_api/mgmt_api_topics.feature b/features/mgmt_api/mgmt_api_topics.feature index 51afcc674..3696dfb21 100644 --- a/features/mgmt_api/mgmt_api_topics.feature +++ b/features/mgmt_api/mgmt_api_topics.feature @@ -131,7 +131,7 @@ Feature: Management API - Topics "navigation": ["619277ef8bbbbfac6034aab7"] } """ - Then we get response code 500 + Then we get response code 400 When we post to this "/navigations" """ diff --git a/features/mgmt_api/mgmt_api_users.feature b/features/mgmt_api/mgmt_api_users.feature index 9551eeaeb..46f83a950 100644 --- a/features/mgmt_api/mgmt_api_users.feature +++ b/features/mgmt_api/mgmt_api_users.feature @@ -192,9 +192,6 @@ Feature: Management API - Users """ Then we get error 400 - """ - {"code": 400, "message": "Locale is not in configured list of locales."} - """ When we post to this "/users" """ diff --git a/newsroom/mgmt_api/auth/jwt.py b/newsroom/mgmt_api/auth/jwt.py index 5e5ea2774..150f5dfbe 100644 --- a/newsroom/mgmt_api/auth/jwt.py +++ b/newsroom/mgmt_api/auth/jwt.py @@ -1,13 +1,18 @@ -from superdesk.core.auth.token_auth import TokenAuthorization -from superdesk.core.types import Request -from superdesk.errors import SuperdeskApiError +import logging +import time +from typing import List, Optional + from authlib.jose import jwt from authlib.jose.errors import BadSignatureError, ExpiredTokenError, DecodeError -from superdesk.core import get_app_config -from time import time -import logging + from newsroom.auth.utils import get_current_request +from superdesk.core import get_app_config +from superdesk.core.types import Request +from superdesk.core.auth.token_auth import TokenAuthorization +from superdesk.errors import SuperdeskApiError +from superdesk.core.auth.rules import endpoint_intrinsic_auth_rule + logger = logging.getLogger(__name__) @@ -16,22 +21,37 @@ class JWTTokenAuth(TokenAuthorization): Implements Async JWT authentication by extending the new async TokenAuthorization. """ - def get_token_from_request(self, request: Request) -> str | None: + def get_default_auth_rules(self) -> List: + """ + Returns the default authentication rules. + + :return: A list of authentication rules. + """ + return [endpoint_intrinsic_auth_rule] + + def get_token_from_request(self, request: Request) -> Optional[str]: """ - Extracts the token from `Authorization` header. + Extracts the token from the `Authorization` header. + + :param request: The request object containing headers. + :return: The extracted token or None if not found. """ auth = (request.get_header("Authorization") or "").strip() if auth.lower().startswith(("token", "bearer", "basic")): return auth.split(" ")[1] if " " in auth else None - return auth if auth else None + return auth or None - def authenticate(self, request: Request = None): + def check_auth(self, request: Optional[Request] = None) -> dict: """ - Validates the JWT token and authenticates the user. + Validates the JWT token and returns the decoded payload. + + :param request: The request object. Defaults to the current request if not provided. + :return: The decoded JWT payload as a dictionary. + :raises SuperdeskApiError: If the token is missing, invalid, or expired. """ - if request is None: - request = get_current_request() + request = request or get_current_request() token = self.get_token_from_request(request) + if not token: logger.warning("Missing Authorization token") raise SuperdeskApiError.unauthorizedError() @@ -43,38 +63,59 @@ def authenticate(self, request: Request = None): try: decoded_jwt = jwt.decode(token, key=secret) - decoded_jwt.validate_exp(now=int(time()), leeway=0) + decoded_jwt.validate_exp(now=int(time.time()), leeway=0) + return decoded_jwt except (BadSignatureError, ExpiredTokenError, DecodeError) as e: logger.error(f"JWT authentication failed: {e}") raise SuperdeskApiError.unauthorizedError() + async def authenticate(self, request: Optional[Request] = None) -> None: + """ + Asynchronously authenticates the request by validating the JWT token. + + :param request: The request object. Defaults to the current request if not provided. + :raises SuperdeskApiError: If authentication fails. + """ + decoded_jwt = self.check_auth(request) self.start_session(request, decoded_jwt) - def start_session(self, request: Request, token_data: dict): + def start_session(self, request: Request, token_data: dict) -> None: """ - Starts a session by storing token data. + Starts a session by storing token data in the request storage. + + :param request: The request object. + :param token_data: The decoded JWT payload. """ request.storage.request.set("auth_token", token_data) request.storage.request.set("user_id", token_data.get("client_id")) - def get_current_user(self, request: Request): + def get_current_user(self, request: Request) -> Optional[str]: """ - Retrieves the current user from the session. + Retrieves the current user ID from the session. + + :param request: The request object. + :return: The user ID if available, otherwise None. """ return request.storage.request.get("user_id") - def authorized(self, allowed_roles, resource, method) -> bool: + def authorized(self, allowed_roles: List[str], resource: str, method: str) -> bool: """ Checks if the request is authorized by validating the token. + + :param allowed_roles: A list of roles allowed to access the resource. + :param resource: The resource being accessed. + :param method: The HTTP method of the request. + :return: True if authorized, False otherwise. """ request = get_current_request() token = self.get_token_from_request(request) + if not token: logger.warning("No token found in request") - return False + raise SuperdeskApiError.unauthorizedError() try: - self.authenticate(request) + self.check_auth(request) return True except SuperdeskApiError: return False # Return False instead of raising an error diff --git a/newsroom/mgmt_api/companies.py b/newsroom/mgmt_api/companies.py index 6e56bc93d..aa25a750a 100644 --- a/newsroom/mgmt_api/companies.py +++ b/newsroom/mgmt_api/companies.py @@ -43,7 +43,7 @@ async def on_delete(self, doc: CompanyResource): data_class=CompanyResource, service=CPCompaniesService, mongo=MongoResourceConfig(prefix=MONGO_PREFIX), - rest_endpoints=RestEndpointConfig(auth=False), + rest_endpoints=RestEndpointConfig(), ) module = Module( diff --git a/newsroom/mgmt_api/companies_products.py b/newsroom/mgmt_api/companies_products.py index 50ef18956..4aa6156ce 100644 --- a/newsroom/mgmt_api/companies_products.py +++ b/newsroom/mgmt_api/companies_products.py @@ -40,9 +40,7 @@ async def get_company(self) -> CompanyResource | None: return await CompanyResource.get_service().find_by_id(self.company_id) -@company_products_endpoints.endpoint( - "/api/companies//products", methods=["POST"], auth=False -) +@company_products_endpoints.endpoint("/api/companies//products", methods=["POST"]) async def update_company_products(args: CompanyProductRouteArguments, params: None, request: Request) -> Response: company = await args.get_company() if not company: @@ -78,9 +76,7 @@ async def update_company_products(args: CompanyProductRouteArguments, params: No return Response({"updated_product_ids": ids}, 201) -@company_products_endpoints.endpoint( - "/api/companies//products", methods=["GET"], auth=False -) +@company_products_endpoints.endpoint("/api/companies//products", methods=["GET"]) async def get_company_products_endpoint(args: CompanyProductRouteArguments, params: None, request: Request) -> Response: company = await args.get_company() if not company: diff --git a/newsroom/mgmt_api/navigations.py b/newsroom/mgmt_api/navigations.py index 523cbbeeb..8dfbefe09 100644 --- a/newsroom/mgmt_api/navigations.py +++ b/newsroom/mgmt_api/navigations.py @@ -11,7 +11,7 @@ data_class=NavigationModel, service=NavigationsService, mongo=MongoResourceConfig(prefix=MONGO_PREFIX), - rest_endpoints=RestEndpointConfig(auth=False), + rest_endpoints=RestEndpointConfig(), ) module = Module( diff --git a/newsroom/mgmt_api/products.py b/newsroom/mgmt_api/products.py index 9108c7e85..244d0d93b 100644 --- a/newsroom/mgmt_api/products.py +++ b/newsroom/mgmt_api/products.py @@ -11,7 +11,7 @@ data_class=ProductResourceModel, service=ProductsService, mongo=MongoResourceConfig(prefix=MONGO_PREFIX), - rest_endpoints=RestEndpointConfig(auth=False), + rest_endpoints=RestEndpointConfig(), uses_etag=False, ) diff --git a/newsroom/mgmt_api/topics.py b/newsroom/mgmt_api/topics.py index ca5ea55fe..c275c2fa3 100644 --- a/newsroom/mgmt_api/topics.py +++ b/newsroom/mgmt_api/topics.py @@ -40,7 +40,7 @@ async def on_update(self, updates: Dict[str, Any], original: TopicResourceModel) data_class=TopicResourceModel, service=GlobalTopicsService, mongo=MongoResourceConfig(prefix=MONGO_PREFIX), - rest_endpoints=RestEndpointConfig(auth=False), + rest_endpoints=RestEndpointConfig(), ) folders_resource_config = ResourceConfig( @@ -48,7 +48,7 @@ async def on_update(self, updates: Dict[str, Any], original: TopicResourceModel) data_class=TopicFolderResourceModel, service=FolderResourceService, mongo=MongoResourceConfig(prefix=MONGO_PREFIX), - rest_endpoints=RestEndpointConfig(auth=False), + rest_endpoints=RestEndpointConfig(), ) module = Module( diff --git a/newsroom/mgmt_api/users.py b/newsroom/mgmt_api/users.py index 1485fe57c..50d5a077a 100644 --- a/newsroom/mgmt_api/users.py +++ b/newsroom/mgmt_api/users.py @@ -58,7 +58,7 @@ async def find(self, req: Request): data_class=UserResourceModel, service=CPUsersService, mongo=MongoResourceConfig(prefix=MONGO_PREFIX), - rest_endpoints=RestEndpointConfig(auth=False), + rest_endpoints=RestEndpointConfig(), ) module = Module(