|
1 | 1 | import traceback |
2 | 2 | import uuid |
| 3 | +from collections.abc import Callable, Coroutine |
| 4 | +from typing import Any |
3 | 5 |
|
4 | | -from fastapi import FastAPI |
5 | 6 | from fastapi.encoders import jsonable_encoder |
6 | 7 | from fastapi.exceptions import RequestValidationError |
| 8 | +from fastapi.routing import APIRoute |
7 | 9 | from httpx import HTTPStatusError |
8 | 10 | from starlette import status |
9 | 11 | from starlette.requests import Request |
10 | | -from starlette.responses import JSONResponse |
| 12 | +from starlette.responses import JSONResponse, Response |
11 | 13 |
|
12 | 14 | from common.exceptions import ( |
13 | 15 | ApplicationException, |
|
16 | 18 | ExceptionSeverity, |
17 | 19 | MissingPrivilegeException, |
18 | 20 | NotFoundException, |
| 21 | + UnauthorizedException, |
19 | 22 | ValidationException, |
20 | 23 | ) |
21 | 24 | from common.logger import logger |
22 | 25 |
|
23 | 26 |
|
24 | | -def add_exception_handlers(app: FastAPI) -> None: |
25 | | - # Handle custom exceptions |
26 | | - app.add_exception_handler(BadRequestException, generic_exception_handler) |
27 | | - app.add_exception_handler(ValidationException, generic_exception_handler) |
28 | | - app.add_exception_handler(NotFoundException, generic_exception_handler) |
29 | | - app.add_exception_handler(MissingPrivilegeException, generic_exception_handler) |
30 | | - |
31 | | - # Override built-in default handler |
32 | | - app.add_exception_handler(RequestValidationError, validation_exception_handler) # type: ignore |
33 | | - app.add_exception_handler(HTTPStatusError, http_exception_handler) |
34 | | - |
35 | | - # Fallback exception handler for all unexpected exceptions |
36 | | - app.add_exception_handler(Exception, fall_back_exception_handler) |
37 | | - |
38 | | - |
39 | 27 | def fall_back_exception_handler(request: Request, exc: Exception) -> JSONResponse: |
40 | 28 | error_id = uuid.uuid4() |
41 | 29 | traceback_string = " ".join(traceback.format_tb(tb=exc.__traceback__)) |
| 30 | + print(traceback_string) |
42 | 31 | logger.error( |
43 | 32 | f"Unexpected unhandled exception ({error_id}): {exc}", |
44 | 33 | extra={"custom_dimensions": {"Error ID": error_id, "Traceback": traceback_string}}, |
@@ -98,3 +87,33 @@ def http_exception_handler(request: Request, exc: HTTPStatusError) -> JSONRespon |
98 | 87 | debug=exc.response, |
99 | 88 | ) |
100 | 89 | ) |
| 90 | + |
| 91 | + |
| 92 | +class ExceptionHandlingRoute(APIRoute): |
| 93 | + """APIRoute class for handling exceptions.""" |
| 94 | + |
| 95 | + def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]: |
| 96 | + """Intercept response and return correct exception response.""" |
| 97 | + original_route_handler = super().get_route_handler() |
| 98 | + |
| 99 | + async def custom_route_handler(request: Request) -> Response: |
| 100 | + try: |
| 101 | + return await original_route_handler(request) |
| 102 | + except BadRequestException as e: |
| 103 | + return generic_exception_handler(request, e) |
| 104 | + except ValidationException as e: |
| 105 | + return generic_exception_handler(request, e) |
| 106 | + except NotFoundException as e: |
| 107 | + return generic_exception_handler(request, e) |
| 108 | + except MissingPrivilegeException as e: |
| 109 | + return generic_exception_handler(request, e) |
| 110 | + except RequestValidationError as e: |
| 111 | + return validation_exception_handler(request, e) |
| 112 | + except HTTPStatusError as e: |
| 113 | + return http_exception_handler(request, e) |
| 114 | + except UnauthorizedException as e: |
| 115 | + return generic_exception_handler(request, e) |
| 116 | + except Exception as e: |
| 117 | + return fall_back_exception_handler(request, e) |
| 118 | + |
| 119 | + return custom_route_handler |
0 commit comments