Skip to content

Commit 7110ea0

Browse files
Implement the registration possible (GH-11)
2 parents c50220a + b876221 commit 7110ea0

File tree

16 files changed

+305
-64
lines changed

16 files changed

+305
-64
lines changed

Diff for: .github/workflows/tests.yml

+4-4
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ jobs:
3232
env: py311-fastapi84
3333

3434
- python: "3.7"
35-
env: py37-fastapi99
35+
env: py37-fastapi100
3636
- python: "3.9"
37-
env: py39-fastapi99
37+
env: py39-fastapi100
3838
- python: "3.10"
39-
env: py310-fastapi99
39+
env: py310-fastapi100
4040
- python: "3.11"
41-
env: py311-fastapi99
41+
env: py311-fastapi100
4242

4343
steps:
4444
- uses: actions/checkout@v2

Diff for: README.md

+6-2
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@ the [social-core](https://github.com/python-social-auth/social-core) authenticat
1313

1414
- Use multiple OAuth2 providers at the same time
1515
* There need to be provided a way to configure the OAuth2 for multiple providers
16-
- Token -> user data, user data -> token easy conversion
1716
- Customizable OAuth2 routes
18-
- Registration support
1917

2018
## Installation
2119

@@ -43,12 +41,14 @@ middleware configuration is declared with the `OAuth2Config` and `OAuth2Client`
4341
- `client_secret` - The OAuth2 client secret for the particular provider.
4442
- `redirect_uri` - The OAuth2 redirect URI to redirect to after success. Defaults to the base URL.
4543
- `scope` - The OAuth2 scope for the particular provider. Defaults to `[]`.
44+
- `claims` - Claims mapping for the certain provider.
4645

4746
It is also important to mention that for the configured clients of the auth providers, the authorization URLs are
4847
accessible by the `/oauth2/{provider}/auth` path where the `provider` variable represents the exact value of the auth
4948
provider backend `name` attribute.
5049

5150
```python
51+
from fastapi_oauth2.claims import Claims
5252
from fastapi_oauth2.client import OAuth2Client
5353
from fastapi_oauth2.config import OAuth2Config
5454
from social_core.backends.github import GithubOAuth2
@@ -65,6 +65,10 @@ oauth2_config = OAuth2Config(
6565
client_secret=os.getenv("OAUTH2_CLIENT_SECRET"),
6666
redirect_uri="https://pysnippet.org/",
6767
scope=["user:email"],
68+
claims=Claims(
69+
picture="avatar_url",
70+
identity=lambda user: "%s:%s" % (user.get("provider"), user.get("id")),
71+
),
6872
),
6973
]
7074
)

Diff for: examples/demonstration/config.py

+5
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from dotenv import load_dotenv
44
from social_core.backends.github import GithubOAuth2
55

6+
from fastapi_oauth2.claims import Claims
67
from fastapi_oauth2.client import OAuth2Client
78
from fastapi_oauth2.config import OAuth2Config
89

@@ -20,6 +21,10 @@
2021
client_secret=os.getenv("OAUTH2_CLIENT_SECRET"),
2122
# redirect_uri="http://127.0.0.1:8000/",
2223
scope=["user:email"],
24+
claims=Claims(
25+
picture="avatar_url",
26+
identity=lambda user: "%s:%s" % (user.get("provider"), user.get("id")),
27+
),
2328
),
2429
]
2530
)

Diff for: examples/demonstration/database.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from sqlalchemy import create_engine
2+
from sqlalchemy.ext.declarative import declarative_base
3+
from sqlalchemy.orm import sessionmaker
4+
5+
engine = create_engine(
6+
"sqlite:///./database.sqlite",
7+
connect_args={
8+
"check_same_thread": False,
9+
},
10+
)
11+
12+
Base = declarative_base()
13+
SessionLocal = sessionmaker(bind=engine, autoflush=False)
14+
15+
16+
def get_db():
17+
db = SessionLocal()
18+
try:
19+
yield db
20+
finally:
21+
db.close()

Diff for: examples/demonstration/main.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,39 @@
11
from fastapi import APIRouter
22
from fastapi import FastAPI
3+
from sqlalchemy.orm import Session
34

45
from config import oauth2_config
6+
from database import Base
7+
from database import engine
8+
from database import get_db
9+
from fastapi_oauth2.middleware import Auth
510
from fastapi_oauth2.middleware import OAuth2Middleware
11+
from fastapi_oauth2.middleware import User
612
from fastapi_oauth2.router import router as oauth2_router
13+
from models import User as UserModel
714
from router import router as app_router
815

16+
Base.metadata.create_all(bind=engine)
17+
918
router = APIRouter()
1019

20+
21+
async def on_auth(auth: Auth, user: User):
22+
# perform a check for user existence in
23+
# the database and create if not exists
24+
db: Session = next(get_db())
25+
query = db.query(UserModel)
26+
if user.identity and not query.filter_by(identity=user.identity).first():
27+
UserModel(**{
28+
"identity": user.get("identity"),
29+
"username": user.get("username"),
30+
"image": user.get("image"),
31+
"email": user.get("email"),
32+
"name": user.get("name"),
33+
}).save(db)
34+
35+
1136
app = FastAPI()
1237
app.include_router(app_router)
1338
app.include_router(oauth2_router)
14-
app.add_middleware(OAuth2Middleware, config=oauth2_config)
39+
app.add_middleware(OAuth2Middleware, config=oauth2_config, callback=on_auth)

Diff for: examples/demonstration/models.py

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from sqlalchemy import Column
2+
from sqlalchemy import Integer
3+
from sqlalchemy import String
4+
from sqlalchemy.orm import Session
5+
6+
from database import Base
7+
8+
9+
class BaseModel(Base):
10+
__abstract__ = True
11+
12+
def save(self, db: Session):
13+
db.add(self)
14+
db.commit()
15+
db.refresh(self)
16+
return self
17+
18+
19+
class User(BaseModel):
20+
__tablename__ = "users"
21+
22+
id = Column(Integer, primary_key=True, index=True)
23+
username = Column(String)
24+
email = Column(String)
25+
name = Column(String)
26+
image = Column(String)
27+
identity = Column(String, unique=True) # provider_name:user_id

Diff for: examples/demonstration/router.py

+39-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
import json
22

3+
from fastapi import APIRouter
34
from fastapi import Depends
45
from fastapi import Request
5-
from fastapi import APIRouter
66
from fastapi.responses import HTMLResponse
77
from fastapi.templating import Jinja2Templates
8+
from sqlalchemy.orm import Session
9+
from starlette.responses import RedirectResponse
810

11+
from database import get_db
912
from fastapi_oauth2.security import OAuth2
13+
from models import User
1014

1115
oauth2 = OAuth2()
1216
router = APIRouter()
@@ -18,6 +22,39 @@ async def root(request: Request):
1822
return templates.TemplateResponse("index.html", {"request": request, "user": request.user, "json": json})
1923

2024

25+
@router.get("/auth")
26+
def sim_auth(request: Request):
27+
access_token = request.auth.jwt_create({
28+
"id": 1,
29+
"identity": "demo:1",
30+
"image": None,
31+
"display_name": "John Doe",
32+
"email": "[email protected]",
33+
"username": "JohnDoe",
34+
"exp": 3689609839,
35+
})
36+
response = RedirectResponse("/")
37+
response.set_cookie(
38+
"Authorization",
39+
value=f"Bearer {access_token}",
40+
max_age=request.auth.expires,
41+
expires=request.auth.expires,
42+
httponly=request.auth.http,
43+
)
44+
return response
45+
46+
2147
@router.get("/user")
22-
def user(request: Request, _: str = Depends(oauth2)):
48+
def user_get(request: Request, _: str = Depends(oauth2)):
2349
return request.user
50+
51+
52+
@router.get("/users")
53+
def users_get(request: Request, db: Session = Depends(get_db), _: str = Depends(oauth2)):
54+
return db.query(User).all()
55+
56+
57+
@router.post("/users")
58+
async def users_post(request: Request, db: Session = Depends(get_db), _: str = Depends(oauth2)):
59+
data = await request.json()
60+
return User(**data).save(db)

Diff for: examples/demonstration/templates/index.html

+9-2
Large diffs are not rendered by default.

Diff for: src/fastapi_oauth2/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.0.0-alpha"
1+
__version__ = "1.0.0-alpha.1"

Diff for: src/fastapi_oauth2/claims.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from typing import Any
2+
from typing import Callable
3+
from typing import Union
4+
5+
6+
class Claims(dict):
7+
"""Claims configuration for a single provider."""
8+
9+
display_name: Union[str, Callable[[dict], Any]]
10+
identity: Union[str, Callable[[dict], Any]]
11+
picture: Union[str, Callable[[dict], Any]]
12+
email: Union[str, Callable[[dict], Any]]
13+
14+
def __init__(self, seq=None, **kwargs) -> None:
15+
super().__init__(seq or {}, **kwargs)
16+
self["display_name"] = kwargs.get("display_name", self.get("display_name", "name"))
17+
self["identity"] = kwargs.get("identity", self.get("identity", "sub"))
18+
self["picture"] = kwargs.get("picture", self.get("picture", "picture"))
19+
self["email"] = kwargs.get("email", self.get("email", "email"))

Diff for: src/fastapi_oauth2/client.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
from typing import Optional
22
from typing import Sequence
33
from typing import Type
4+
from typing import Union
45

56
from social_core.backends.oauth import BaseOAuth2
67

8+
from .claims import Claims
9+
710

811
class OAuth2Client:
12+
"""OAuth2 client configuration for a single provider."""
13+
914
backend: Type[BaseOAuth2]
1015
client_id: str
1116
client_secret: str
1217
redirect_uri: Optional[str]
1318
scope: Optional[Sequence[str]]
19+
claims: Optional[Union[Claims, dict]]
1420

1521
def __init__(
1622
self,
@@ -20,9 +26,11 @@ def __init__(
2026
client_secret: str,
2127
redirect_uri: Optional[str] = None,
2228
scope: Optional[Sequence[str]] = None,
23-
):
29+
claims: Optional[Union[Claims, dict]] = None,
30+
) -> None:
2431
self.backend = backend
2532
self.client_id = client_id
2633
self.client_secret = client_secret
2734
self.redirect_uri = redirect_uri
2835
self.scope = scope or []
36+
self.claims = Claims(claims)

Diff for: src/fastapi_oauth2/config.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77

88
class OAuth2Config:
9+
"""Configuration class of the authentication middleware."""
10+
911
allow_http: bool
1012
jwt_secret: str
1113
jwt_expires: int
@@ -20,7 +22,7 @@ def __init__(
2022
jwt_expires: Union[int, str] = 900,
2123
jwt_algorithm: str = "HS256",
2224
clients: List[OAuth2Client] = None,
23-
):
25+
) -> None:
2426
if allow_http:
2527
os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1"
2628
self.allow_http = allow_http

0 commit comments

Comments
 (0)