Skip to content

Commit 3a82fae

Browse files
Implement the error handling (GH-28)
2 parents f3ae78f + b946f89 commit 3a82fae

File tree

6 files changed

+105
-23
lines changed

6 files changed

+105
-23
lines changed

docs/references/tutorials.md

+34-10
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,20 @@ scopes required for the API endpoint.
3939

4040
```mermaid
4141
flowchart TB
42-
subgraph level2["request (Starlette's Request object)"]
43-
direction TB
44-
subgraph level1["auth (Starlette's extended Auth Credentials)"]
42+
subgraph level2["request (Starlette's Request object)"]
4543
direction TB
46-
subgraph level0["provider (OAuth2 provider with client's credentials)"]
44+
subgraph level1["auth (Starlette's extended Auth Credentials)"]
4745
direction TB
48-
token["access_token (Access token for the specified scopes)"]
46+
subgraph level0["provider (OAuth2 provider with client's credentials)"]
47+
direction TB
48+
token["access_token (Access token for the specified scopes)"]
49+
end
4950
end
5051
end
51-
end
52-
style level2 fill:#00948680,color:#f6f6f7,stroke:#3c3c43;
53-
style level1 fill:#2b75a080,color:#f6f6f7,stroke:#3c3c43;
54-
style level0 fill:#5c837480,color:#f6f6f7,stroke:#3c3c43;
55-
style token fill:#44506980,color:#f6f6f7,stroke:#3c3c43;
52+
style level2 fill: #00948680, color: #f6f6f7, stroke: #3c3c43;
53+
style level1 fill: #2b75a080, color: #f6f6f7, stroke: #3c3c43;
54+
style level0 fill: #5c837480, color: #f6f6f7, stroke: #3c3c43;
55+
style token fill: #44506980, color: #f6f6f7, stroke: #3c3c43;
5656
```
5757

5858
:::
@@ -129,6 +129,30 @@ approach is useful when there missing mandatory attributes in `request.user` for
129129
database. You need to define a route for provisioning and provide it as `redirect_uri`, so
130130
the [user context](/integration/integration#user-context) will be available for usage.
131131

132+
## Error handling
133+
134+
The exceptions that possibly can occur when using the library are reraised as `HTTPException` with the appropriate
135+
status code and a message describing the actual error cause. So they can be handled in a natural way by following the
136+
FastAPI [docs](https://fastapi.tiangolo.com/tutorial/handling-errors/) on handling errors and using the exceptions from
137+
the `fastapi_oauth2.exceptions` module.
138+
139+
```python
140+
from fastapi_oauth2.exceptions import OAuth2AuthenticationError
141+
142+
@app.exception_handler(OAuth2AuthenticationError)
143+
async def error_handler(request: Request, exc: OAuth2AuthenticationError):
144+
return RedirectResponse(url="/login", status_code=303)
145+
```
146+
147+
The complete list of exceptions is the following.
148+
149+
- `OAuth2Error` - Base exception for all errors raised by the FastAPI OAuth2 library.
150+
- `OAuth2AuthenticationError` - An exception is raised when the authentication fails.
151+
- `OAuth2InvalidRequestError` - An exception is raised when the request is invalid.
152+
153+
The request is considered invalid when one of the mandatory parameters, such as `state` or `code` is missing or the
154+
request fails. And the errors that occur during the OAuth steps are considered authentication errors.
155+
132156
<style>
133157
.info, .details {
134158
border: 0;

examples/demonstration/main.py

+12
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
from fastapi import APIRouter
22
from fastapi import FastAPI
3+
from fastapi import Request
34
from fastapi.staticfiles import StaticFiles
45
from sqlalchemy.orm import Session
6+
from starlette.responses import RedirectResponse
57

68
from config import oauth2_config
79
from database import Base
810
from database import engine
911
from database import get_db
12+
from fastapi_oauth2.exceptions import OAuth2Error
1013
from fastapi_oauth2.middleware import Auth
1114
from fastapi_oauth2.middleware import OAuth2Middleware
1215
from fastapi_oauth2.middleware import User
@@ -37,6 +40,15 @@ async def on_auth(auth: Auth, user: User):
3740

3841

3942
app = FastAPI()
43+
44+
45+
# https://fastapi.tiangolo.com/tutorial/handling-errors/
46+
@app.exception_handler(OAuth2Error)
47+
async def error_handler(request: Request, e: OAuth2Error):
48+
print("An error occurred in OAuth2Middleware", e)
49+
return RedirectResponse(url="/", status_code=303)
50+
51+
4052
app.include_router(router_api)
4153
app.include_router(router_ssr)
4254
app.include_router(oauth2_router)

src/fastapi_oauth2/core.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,18 @@
99
from urllib.parse import urljoin
1010

1111
import httpx
12+
from oauthlib.oauth2 import OAuth2Error
1213
from oauthlib.oauth2 import WebApplicationClient
13-
from oauthlib.oauth2.rfc6749.errors import CustomOAuth2Error
1414
from social_core.backends.oauth import BaseOAuth2
15+
from social_core.exceptions import AuthException
1516
from social_core.strategy import BaseStrategy
16-
from starlette.exceptions import HTTPException
1717
from starlette.requests import Request
1818
from starlette.responses import RedirectResponse
1919

2020
from .claims import Claims
2121
from .client import OAuth2Client
22-
23-
24-
class OAuth2LoginError(HTTPException):
25-
"""Raised when any login-related error occurs."""
22+
from .exceptions import OAuth2AuthenticationError
23+
from .exceptions import OAuth2InvalidRequestError
2624

2725

2826
class OAuth2Strategy(BaseStrategy):
@@ -56,6 +54,7 @@ class OAuth2Core:
5654
_oauth_client: Optional[WebApplicationClient] = None
5755
_authorization_endpoint: str = None
5856
_token_endpoint: str = None
57+
_state: str = None
5958

6059
def __init__(self, client: OAuth2Client) -> None:
6160
self.client_id = client.client_id
@@ -83,6 +82,8 @@ def authorization_url(self, request: Request) -> str:
8382
oauth2_query_params = dict(state=state, scope=self.scope, redirect_uri=redirect_uri)
8483
oauth2_query_params.update(request.query_params)
8584

85+
self._state = oauth2_query_params.get("state")
86+
8687
return str(self._oauth_client.prepare_request_uri(
8788
self._authorization_endpoint,
8889
**oauth2_query_params,
@@ -93,9 +94,11 @@ def authorization_redirect(self, request: Request) -> RedirectResponse:
9394

9495
async def token_data(self, request: Request, **httpx_client_args) -> dict:
9596
if not request.query_params.get("code"):
96-
raise OAuth2LoginError(400, "'code' parameter was not found in callback request")
97+
raise OAuth2InvalidRequestError(400, "'code' parameter was not found in callback request")
9798
if not request.query_params.get("state"):
98-
raise OAuth2LoginError(400, "'state' parameter was not found in callback request")
99+
raise OAuth2InvalidRequestError(400, "'state' parameter was not found in callback request")
100+
if request.query_params.get("state") != self._state:
101+
raise OAuth2InvalidRequestError(400, "'state' parameter does not match")
99102

100103
redirect_uri = self.get_redirect_uri(request)
101104
scheme = "http" if request.auth.http else "https"
@@ -112,12 +115,14 @@ async def token_data(self, request: Request, **httpx_client_args) -> dict:
112115
headers.update({"Accept": "application/json"})
113116
auth = httpx.BasicAuth(self.client_id, self.client_secret)
114117
async with httpx.AsyncClient(auth=auth, **httpx_client_args) as session:
115-
response = await session.post(token_url, headers=headers, content=content)
116118
try:
119+
response = await session.post(token_url, headers=headers, content=content)
117120
self._oauth_client.parse_request_body_response(json.dumps(response.json()))
118121
return self.standardize(self.backend.user_data(self.access_token))
119-
except (CustomOAuth2Error, Exception) as e:
120-
raise OAuth2LoginError(400, str(e))
122+
except (OAuth2Error, httpx.HTTPError) as e:
123+
raise OAuth2InvalidRequestError(400, str(e))
124+
except (AuthException, Exception) as e:
125+
raise OAuth2AuthenticationError(401, str(e))
121126

122127
async def token_redirect(self, request: Request, **kwargs) -> RedirectResponse:
123128
access_token = request.auth.jwt_create(await self.token_data(request, **kwargs))

src/fastapi_oauth2/exceptions.py

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from starlette.exceptions import HTTPException
2+
3+
4+
class OAuth2Error(HTTPException):
5+
"""Base OAuth2 exception."""
6+
7+
8+
class OAuth2AuthenticationError(OAuth2Error):
9+
"""Raised when authentication fails."""
10+
11+
12+
class OAuth2InvalidRequestError(OAuth2Error):
13+
"""Raised when request is invalid."""

src/fastapi_oauth2/middleware.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010
from typing import Union
1111

1212
from fastapi.security.utils import get_authorization_scheme_param
13+
from jose.exceptions import JOSEError
1314
from jose.jwt import decode as jwt_decode
1415
from jose.jwt import encode as jwt_encode
1516
from starlette.authentication import AuthCredentials
1617
from starlette.authentication import AuthenticationBackend
1718
from starlette.authentication import BaseUser
1819
from starlette.middleware.authentication import AuthenticationMiddleware
1920
from starlette.requests import Request
21+
from starlette.responses import PlainTextResponse
2022
from starlette.types import ASGIApp
2123
from starlette.types import Receive
2224
from starlette.types import Scope
@@ -139,7 +141,14 @@ def __init__(
139141
config = OAuth2Config(**config)
140142
elif not isinstance(config, OAuth2Config):
141143
raise TypeError("config is not a valid type")
144+
self.default_application_middleware = app
142145
self.auth_middleware = AuthenticationMiddleware(app, backend=OAuth2Backend(config, callback), **kwargs)
143146

144147
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
145-
await self.auth_middleware(scope, receive, send)
148+
if scope["type"] == "http":
149+
try:
150+
return await self.auth_middleware(scope, receive, send)
151+
except (JOSEError, Exception) as e:
152+
middleware = PlainTextResponse(str(e), status_code=401)
153+
return await middleware(scope, receive, send)
154+
await self.default_application_middleware(scope, receive, send)

tests/test_oauth2.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
from urllib.parse import parse_qs
12
from urllib.parse import urlencode
3+
from urllib.parse import urlparse
24

35
import pytest
46
from httpx import AsyncClient
@@ -14,7 +16,11 @@ async def oauth2_workflow(get_app, idp=False, ssr=True, authorize_query="", toke
1416
response = await client.get("/oauth2/test/authorize" + authorize_query) # Get authorization endpoint
1517
authorization_endpoint = response.headers.get("location") if ssr else response.json().get("url")
1618
response = await client.get(authorization_endpoint) # Authorize
17-
response = await client.get(response.headers.get("location") + token_query) # Obtain token
19+
token_url = response.headers.get("location")
20+
query = {k: v[0] for k, v in parse_qs(urlparse(token_url).query).items()}
21+
query.update({k: v[0] for k, v in parse_qs(token_query).items()})
22+
token_url = "%s?%s" % (token_url.split("?")[0], urlencode(query))
23+
response = await client.get(token_url) # Obtain token
1824

1925
response = await client.get("/user", headers=dict(
2026
Authorization=jwt_encode(response.json(), "") # Set token
@@ -43,3 +49,16 @@ async def test_oauth2_pkce_workflow(get_app):
4349
tq = "&" + urlencode(dict(code_verifier=code_verifier))
4450
await oauth2_workflow(get_app, idp=True, authorize_query=aq, token_query=tq)
4551
await oauth2_workflow(get_app, idp=True, ssr=False, authorize_query=aq, token_query=tq, use_header=True)
52+
53+
54+
@pytest.mark.anyio
55+
async def test_oauth2_csrf_workflow(get_app):
56+
for aq, tq in [
57+
("?state=test_state", "&state=test_state"),
58+
("?state=test_state", "&state=test_wrong_state")
59+
]:
60+
try:
61+
await oauth2_workflow(get_app, idp=True, authorize_query=aq, token_query=tq)
62+
await oauth2_workflow(get_app, idp=True, ssr=False, authorize_query=aq, token_query=tq, use_header=True)
63+
except AssertionError:
64+
assert aq != tq

0 commit comments

Comments
 (0)