Skip to content

Commit

Permalink
fix authentication issue
Browse files Browse the repository at this point in the history
  • Loading branch information
devketanpro committed Feb 20, 2025
1 parent 726fc44 commit c4cbf76
Show file tree
Hide file tree
Showing 10 changed files with 72 additions and 37 deletions.
1 change: 1 addition & 0 deletions features/mgmt_api/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ async def before_scenario_async(context, scenario):
"MGMT_API_ENABLED": True,
"AUTH_SERVER_SHARED_SECRET": "test-secret",
"CACHE_TYPE": "null",
"ASYNC_AUTH_CLASS": "newsroom.mgmt_api.auth:JWTTokenAuth",
}

context.app = get_app(config=config)
Expand Down
2 changes: 1 addition & 1 deletion features/mgmt_api/mgmt_api_topics.feature
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ Feature: Management API - Topics
"navigation": ["619277ef8bbbbfac6034aab7"]
}
"""
Then we get response code 500
Then we get response code 400

When we post to this "/navigations"
"""
Expand Down
3 changes: 0 additions & 3 deletions features/mgmt_api/mgmt_api_users.feature
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,6 @@ Feature: Management API - Users
"""

Then we get error 400
"""
{"code": 400, "message": "Locale is not in configured list of locales."}
"""

When we post to this "/users"
"""
Expand Down
83 changes: 62 additions & 21 deletions newsroom/mgmt_api/auth/jwt.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from superdesk.core.auth.token_auth import TokenAuthorization
from superdesk.core.types import Request
from superdesk.errors import SuperdeskApiError
import logging
import time
from typing import List, Optional

from authlib.jose import jwt
from authlib.jose.errors import BadSignatureError, ExpiredTokenError, DecodeError
from superdesk.core import get_app_config
from time import time
import logging

from newsroom.auth.utils import get_current_request

from superdesk.core import get_app_config
from superdesk.core.types import Request
from superdesk.core.auth.token_auth import TokenAuthorization
from superdesk.errors import SuperdeskApiError
from superdesk.core.auth.rules import endpoint_intrinsic_auth_rule

logger = logging.getLogger(__name__)


Expand All @@ -16,22 +21,37 @@ class JWTTokenAuth(TokenAuthorization):
Implements Async JWT authentication by extending the new async TokenAuthorization.
"""

def get_token_from_request(self, request: Request) -> str | None:
def get_default_auth_rules(self) -> List:
"""
Returns the default authentication rules.
:return: A list of authentication rules.
"""
return [endpoint_intrinsic_auth_rule]

def get_token_from_request(self, request: Request) -> Optional[str]:
"""
Extracts the token from `Authorization` header.
Extracts the token from the `Authorization` header.
:param request: The request object containing headers.
:return: The extracted token or None if not found.
"""
auth = (request.get_header("Authorization") or "").strip()
if auth.lower().startswith(("token", "bearer", "basic")):
return auth.split(" ")[1] if " " in auth else None
return auth if auth else None
return auth or None

def authenticate(self, request: Request = None):
def check_auth(self, request: Optional[Request] = None) -> dict:
"""
Validates the JWT token and authenticates the user.
Validates the JWT token and returns the decoded payload.
:param request: The request object. Defaults to the current request if not provided.
:return: The decoded JWT payload as a dictionary.
:raises SuperdeskApiError: If the token is missing, invalid, or expired.
"""
if request is None:
request = get_current_request()
request = request or get_current_request()
token = self.get_token_from_request(request)

if not token:
logger.warning("Missing Authorization token")
raise SuperdeskApiError.unauthorizedError()
Expand All @@ -43,38 +63,59 @@ def authenticate(self, request: Request = None):

try:
decoded_jwt = jwt.decode(token, key=secret)
decoded_jwt.validate_exp(now=int(time()), leeway=0)
decoded_jwt.validate_exp(now=int(time.time()), leeway=0)
return decoded_jwt
except (BadSignatureError, ExpiredTokenError, DecodeError) as e:
logger.error(f"JWT authentication failed: {e}")
raise SuperdeskApiError.unauthorizedError()

async def authenticate(self, request: Optional[Request] = None) -> None:
"""
Asynchronously authenticates the request by validating the JWT token.
:param request: The request object. Defaults to the current request if not provided.
:raises SuperdeskApiError: If authentication fails.
"""
decoded_jwt = self.check_auth(request)
self.start_session(request, decoded_jwt)

def start_session(self, request: Request, token_data: dict):
def start_session(self, request: Request, token_data: dict) -> None:
"""
Starts a session by storing token data.
Starts a session by storing token data in the request storage.
:param request: The request object.
:param token_data: The decoded JWT payload.
"""
request.storage.request.set("auth_token", token_data)
request.storage.request.set("user_id", token_data.get("client_id"))

def get_current_user(self, request: Request):
def get_current_user(self, request: Request) -> Optional[str]:
"""
Retrieves the current user from the session.
Retrieves the current user ID from the session.
:param request: The request object.
:return: The user ID if available, otherwise None.
"""
return request.storage.request.get("user_id")

def authorized(self, allowed_roles, resource, method) -> bool:
def authorized(self, allowed_roles: List[str], resource: str, method: str) -> bool:
"""
Checks if the request is authorized by validating the token.
:param allowed_roles: A list of roles allowed to access the resource.
:param resource: The resource being accessed.
:param method: The HTTP method of the request.
:return: True if authorized, False otherwise.
"""
request = get_current_request()
token = self.get_token_from_request(request)

if not token:
logger.warning("No token found in request")
return False
raise SuperdeskApiError.unauthorizedError()

try:
self.authenticate(request)
self.check_auth(request)
return True
except SuperdeskApiError:
return False # Return False instead of raising an error
2 changes: 1 addition & 1 deletion newsroom/mgmt_api/companies.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ async def on_delete(self, doc: CompanyResource):
data_class=CompanyResource,
service=CPCompaniesService,
mongo=MongoResourceConfig(prefix=MONGO_PREFIX),
rest_endpoints=RestEndpointConfig(auth=False),
rest_endpoints=RestEndpointConfig(),
)

module = Module(
Expand Down
8 changes: 2 additions & 6 deletions newsroom/mgmt_api/companies_products.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,7 @@ async def get_company(self) -> CompanyResource | None:
return await CompanyResource.get_service().find_by_id(self.company_id)


@company_products_endpoints.endpoint(
"/api/companies/<regex('[a-f0-9]{24}'):company_id>/products", methods=["POST"], auth=False
)
@company_products_endpoints.endpoint("/api/companies/<regex('[a-f0-9]{24}'):company_id>/products", methods=["POST"])
async def update_company_products(args: CompanyProductRouteArguments, params: None, request: Request) -> Response:
company = await args.get_company()
if not company:
Expand Down Expand Up @@ -78,9 +76,7 @@ async def update_company_products(args: CompanyProductRouteArguments, params: No
return Response({"updated_product_ids": ids}, 201)


@company_products_endpoints.endpoint(
"/api/companies/<regex('[a-f0-9]{24}'):company_id>/products", methods=["GET"], auth=False
)
@company_products_endpoints.endpoint("/api/companies/<regex('[a-f0-9]{24}'):company_id>/products", methods=["GET"])
async def get_company_products_endpoint(args: CompanyProductRouteArguments, params: None, request: Request) -> Response:
company = await args.get_company()
if not company:
Expand Down
2 changes: 1 addition & 1 deletion newsroom/mgmt_api/navigations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
data_class=NavigationModel,
service=NavigationsService,
mongo=MongoResourceConfig(prefix=MONGO_PREFIX),
rest_endpoints=RestEndpointConfig(auth=False),
rest_endpoints=RestEndpointConfig(),
)

module = Module(
Expand Down
2 changes: 1 addition & 1 deletion newsroom/mgmt_api/products.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
data_class=ProductResourceModel,
service=ProductsService,
mongo=MongoResourceConfig(prefix=MONGO_PREFIX),
rest_endpoints=RestEndpointConfig(auth=False),
rest_endpoints=RestEndpointConfig(),
uses_etag=False,
)

Expand Down
4 changes: 2 additions & 2 deletions newsroom/mgmt_api/topics.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ async def on_update(self, updates: Dict[str, Any], original: TopicResourceModel)
data_class=TopicResourceModel,
service=GlobalTopicsService,
mongo=MongoResourceConfig(prefix=MONGO_PREFIX),
rest_endpoints=RestEndpointConfig(auth=False),
rest_endpoints=RestEndpointConfig(),
)

folders_resource_config = ResourceConfig(
name="topic_folders",
data_class=TopicFolderResourceModel,
service=FolderResourceService,
mongo=MongoResourceConfig(prefix=MONGO_PREFIX),
rest_endpoints=RestEndpointConfig(auth=False),
rest_endpoints=RestEndpointConfig(),
)

module = Module(
Expand Down
2 changes: 1 addition & 1 deletion newsroom/mgmt_api/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ async def find(self, req: Request):
data_class=UserResourceModel,
service=CPUsersService,
mongo=MongoResourceConfig(prefix=MONGO_PREFIX),
rest_endpoints=RestEndpointConfig(auth=False),
rest_endpoints=RestEndpointConfig(),
)

module = Module(
Expand Down

0 comments on commit c4cbf76

Please sign in to comment.