9
9
from urllib .parse import urljoin
10
10
11
11
import httpx
12
+ from oauthlib .oauth2 import OAuth2Error
12
13
from oauthlib .oauth2 import WebApplicationClient
13
- from oauthlib .oauth2 .rfc6749 .errors import CustomOAuth2Error
14
14
from social_core .backends .oauth import BaseOAuth2
15
+ from social_core .exceptions import AuthException
15
16
from social_core .strategy import BaseStrategy
16
- from starlette .exceptions import HTTPException
17
17
from starlette .requests import Request
18
18
from starlette .responses import RedirectResponse
19
19
20
20
from .claims import Claims
21
21
from .client import OAuth2Client
22
-
23
-
24
- class OAuth2LoginError (HTTPException ):
25
- """Raised when any login-related error occurs."""
22
+ from .exceptions import OAuth2AuthenticationError
23
+ from .exceptions import OAuth2InvalidRequestError
26
24
27
25
28
26
class OAuth2Strategy (BaseStrategy ):
@@ -56,6 +54,7 @@ class OAuth2Core:
56
54
_oauth_client : Optional [WebApplicationClient ] = None
57
55
_authorization_endpoint : str = None
58
56
_token_endpoint : str = None
57
+ _state : str = None
59
58
60
59
def __init__ (self , client : OAuth2Client ) -> None :
61
60
self .client_id = client .client_id
@@ -83,6 +82,8 @@ def authorization_url(self, request: Request) -> str:
83
82
oauth2_query_params = dict (state = state , scope = self .scope , redirect_uri = redirect_uri )
84
83
oauth2_query_params .update (request .query_params )
85
84
85
+ self ._state = oauth2_query_params .get ("state" )
86
+
86
87
return str (self ._oauth_client .prepare_request_uri (
87
88
self ._authorization_endpoint ,
88
89
** oauth2_query_params ,
@@ -93,9 +94,11 @@ def authorization_redirect(self, request: Request) -> RedirectResponse:
93
94
94
95
async def token_data (self , request : Request , ** httpx_client_args ) -> dict :
95
96
if not request .query_params .get ("code" ):
96
- raise OAuth2LoginError (400 , "'code' parameter was not found in callback request" )
97
+ raise OAuth2InvalidRequestError (400 , "'code' parameter was not found in callback request" )
97
98
if not request .query_params .get ("state" ):
98
- raise OAuth2LoginError (400 , "'state' parameter was not found in callback request" )
99
+ raise OAuth2InvalidRequestError (400 , "'state' parameter was not found in callback request" )
100
+ if request .query_params .get ("state" ) != self ._state :
101
+ raise OAuth2InvalidRequestError (400 , "'state' parameter does not match" )
99
102
100
103
redirect_uri = self .get_redirect_uri (request )
101
104
scheme = "http" if request .auth .http else "https"
@@ -112,12 +115,14 @@ async def token_data(self, request: Request, **httpx_client_args) -> dict:
112
115
headers .update ({"Accept" : "application/json" })
113
116
auth = httpx .BasicAuth (self .client_id , self .client_secret )
114
117
async with httpx .AsyncClient (auth = auth , ** httpx_client_args ) as session :
115
- response = await session .post (token_url , headers = headers , content = content )
116
118
try :
119
+ response = await session .post (token_url , headers = headers , content = content )
117
120
self ._oauth_client .parse_request_body_response (json .dumps (response .json ()))
118
121
return self .standardize (self .backend .user_data (self .access_token ))
119
- except (CustomOAuth2Error , Exception ) as e :
120
- raise OAuth2LoginError (400 , str (e ))
122
+ except (OAuth2Error , httpx .HTTPError ) as e :
123
+ raise OAuth2InvalidRequestError (400 , str (e ))
124
+ except (AuthException , Exception ) as e :
125
+ raise OAuth2AuthenticationError (401 , str (e ))
121
126
122
127
async def token_redirect (self , request : Request , ** kwargs ) -> RedirectResponse :
123
128
access_token = request .auth .jwt_create (await self .token_data (request , ** kwargs ))
0 commit comments