Skip to content

Commit

Permalink
Add a user model to flask.g with convenience properties
Browse files Browse the repository at this point in the history
Signed-off-by: Aurélien Bompard <[email protected]>
  • Loading branch information
abompard committed Jun 11, 2024
1 parent 7bc785d commit 5bf8808
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 14 deletions.
24 changes: 21 additions & 3 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ Install the extension with `pip`::

Integration
-----------
To integrate Flask-OpenID into your application you need to create an
instance of the :class:`OpenID` object first::
To integrate Flask-OIDC into your application you need to create an
instance of the :class:`OpenIDConnect` object first::

from flask_oidc import OpenIDConnect
oidc = OpenIDConnect(app)
Expand All @@ -47,9 +47,15 @@ Using this library is very simple: you can use
:data:`~flask_oidc.OpenIDConnect.user_loggedin` to determine whether a user is currently
logged in using OpenID Connect.

If the user is logged in, you an use ``session["oidc_auth_profile"]`` to get
If the user is logged in, you can use ``session["oidc_auth_profile"]`` to get
information about the currently logged in user.

A :class:`~flask_oidc.model.User` object is also provided on ``g.oidc_user``, see its API
documentation to discover which convenient properties are available.
You can set the ``OIDC_USER_CLASS`` configuration variable to the python path of a class if
you want to user your own class. It needs to accept the extension instance as only
constructor argument.

You can decorate any view function with :meth:`~flask_oidc.OpenIDConnect.require_login`
to redirect anonymous users to the OIDC provider.

Expand All @@ -70,6 +76,14 @@ A very basic example client::
def login():
return 'Welcome %s' % session["oidc_auth_profile"].get('email')

@app.route('/alt')
def alternative():
# This uses the user instance at g.oidc_user instead
if g.oidc_user.logged_in:
return 'Welcome %s' % g.oidc_user.profile.get('email')
else
return 'Not logged in'


Resource server
===============
Expand Down Expand Up @@ -207,6 +221,10 @@ This is a list of all settings supported in the current release.
the token_introspection_uri. Valid values are 'client_secret_post',
'client_secret_basic', or 'bearer'. Defaults to 'client_secret_post'.

OIDC_USER_CLASS
The python path to a custom :class:`~flask_oidc.model.User` model.
It needs to accept the extension instance as only constructor argument.


Other docs
==========
Expand Down
11 changes: 11 additions & 0 deletions flask_oidc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
IntrospectTokenValidator as BaseIntrospectTokenValidator,
)
from flask import abort, current_app, g, redirect, request, session, url_for
from werkzeug.utils import import_string

from .views import auth_routes, legacy_oidc_callback

Expand Down Expand Up @@ -162,6 +163,14 @@ def init_app(self, app, prefix=None):
if app.config["OIDC_CALLBACK_ROUTE"]:
app.route(app.config["OIDC_CALLBACK_ROUTE"])(legacy_oidc_callback)

# User model
app.config.setdefault("OIDC_USER_CLASS", "flask_oidc.model.User")
if app.config["OIDC_USER_CLASS"]:
app.extensions["_oidc_user_class"] = import_string(
app.config["OIDC_USER_CLASS"]
)

# Flask hooks
app.before_request(self._before_request)

def load_secrets(self, app):
Expand All @@ -175,6 +184,8 @@ def load_secrets(self, app):

def _before_request(self):
g._oidc_auth = self.oauth.oidc
if current_app.extensions.get("_oidc_user_class"):
g.oidc_user = current_app.extensions["_oidc_user_class"](self)
if not current_app.config["OIDC_RESOURCE_SERVER_ONLY"]:
return self.check_token_expiry()

Expand Down
56 changes: 56 additions & 0 deletions flask_oidc/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# SPDX-FileCopyrightText: 2023 Aurélien Bompard <[email protected]>
#
# SPDX-License-Identifier: BSD-2-Clause


from flask import current_app, session


class User:
"""A representation of an OIDC-based user.
Arguments:
ext (OpenIDConnect): the extension instance
"""

def __init__(self, ext):
self._ext = ext

@property
def logged_in(self):
"""Return ``True`` if the user is logged in, ``False`` otherwise."""
return session.get("oidc_auth_token") is not None

@property
def access_token(self):
"""The user's OIDC access token."""
return self._ext.get_access_token()

@property
def refresh_token(self):
"""The user's OIDC refresh token."""
return self._ext.get_refresh_token()

@property
def profile(self):
"""The user's OIDC profile, if any.
Raises:
RuntimeError: when ``OIDC_USER_INFO_ENABLED`` is ``False`` in the application's
configuration.
"""
if not current_app.config["OIDC_USER_INFO_ENABLED"]:
raise RuntimeError(
"User info is disabled in configuration (OIDC_USER_INFO_ENABLED)"
)
return session.get("oidc_auth_profile", {})

@property
def name(self):
"""The user's nickname."""
return self.profile.get("nickname")

@property
def groups(self):
"""The list of group names the user belongs to."""
return self.profile.get("groups", [])
17 changes: 6 additions & 11 deletions tests/test_flask_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from flask_oidc import OpenIDConnect

from .app import oidc as oidc_ext
from .utils import set_token


def callback_url_for(response):
Expand All @@ -32,12 +33,6 @@ def callback_url_for(response):
return f"{query['redirect_uri'][0]}?state={query['state'][0]}&code=mock_auth_code"


def _set_token(client, token):
with client.session_transaction() as session:
session["oidc_auth_token"] = token
session["oidc_auth_profile"] = {"nickname": "dummy"}


def test_signin(test_app, client, mocked_responses, dummy_token):
"""Happy path authentication test."""
mocked_responses.post("https://test/openidc/Token", json=dummy_token)
Expand Down Expand Up @@ -105,7 +100,7 @@ def test_logout_redirect_loop(make_test_app, dummy_token, mocked_responses):
"https://test/openidc/Token", json={"error": "dummy"}, status=401
)
dummy_token["expires_at"] = int(time.time())
_set_token(client, dummy_token)
set_token(client, dummy_token)

resp = client.get("/logout?reason=expired")
assert resp.location == "http://localhost/subpath/"
Expand All @@ -118,7 +113,7 @@ def test_expired_token(client, dummy_token, mocked_responses):
refresh_call = mocked_responses.post("https://test/openidc/Token", json=new_token)

dummy_token["expires_at"] = int(time.time())
_set_token(client, dummy_token)
set_token(client, dummy_token)

resp = client.get("/")

Expand Down Expand Up @@ -146,7 +141,7 @@ def test_expired_token_cant_renew(client, dummy_token, mocked_responses):
)

dummy_token["expires_at"] = int(time.time())
_set_token(client, dummy_token)
set_token(client, dummy_token)

resp = client.get("/")

Expand All @@ -162,7 +157,7 @@ def test_expired_token_cant_renew(client, dummy_token, mocked_responses):
def test_expired_token_no_refresh_token(client, dummy_token):
del dummy_token["refresh_token"]
dummy_token["expires_at"] = int(time.time())
_set_token(client, dummy_token)
set_token(client, dummy_token)

resp = client.get("/")

Expand All @@ -175,7 +170,7 @@ def test_expired_token_no_refresh_token(client, dummy_token):


def test_bad_token(client):
_set_token(client, "bad_token")
set_token(client, "bad_token")
resp = client.get("/")
assert resp.status_code == 500
assert "Internal Server Error" in resp.get_data(as_text=True)
Expand Down
86 changes: 86 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# SPDX-FileCopyrightText: 2014-2015 Erica Ehrhardt
# SPDX-FileCopyrightText: 2016-2022 Patrick Uiterwijk <[email protected]>
# SPDX-FileCopyrightText: 2023 Aurélien Bompard <[email protected]>
#
# SPDX-License-Identifier: BSD-2-Clause

import flask
import pytest

from .utils import set_token


def test_user_in_view(client, dummy_token):
set_token(client, dummy_token)

resp = client.get("/")
assert resp.status_code == 200

assert hasattr(flask.g, "oidc_user")
user = flask.g.oidc_user

assert user.logged_in
assert user.access_token == dummy_token["access_token"]
assert user.refresh_token == dummy_token["refresh_token"]
assert user.profile == {"nickname": "dummy"}
assert user.name == "dummy"


def test_user_with_groups(client, dummy_token):
set_token(client, dummy_token, {"groups": ["dummy_group"]})

resp = client.get("/")
assert resp.status_code == 200

assert hasattr(flask.g, "oidc_user")
user = flask.g.oidc_user
assert user.profile == {"nickname": "dummy", "groups": ["dummy_group"]}
assert user.groups == ["dummy_group"]


def test_user_no_user_info(client, test_app, dummy_token):
set_token(client, dummy_token)
test_app.config["OIDC_USER_INFO_ENABLED"] = False

resp = client.get("/")
assert resp.status_code == 200

assert hasattr(flask.g, "oidc_user")
user = flask.g.oidc_user

with pytest.raises(RuntimeError):
user.profile

with pytest.raises(RuntimeError):
user.name

with pytest.raises(RuntimeError):
user.groups


def test_user_no_user(make_test_app, dummy_token):
test_app = make_test_app({"OIDC_USER_CLASS": None})
client = test_app.test_client()
with client:
set_token(client, dummy_token)
resp = client.get("/")

assert resp.status_code == 200
assert not hasattr(flask.g, "oidc_user")


class OtherUser:
def __init__(self, ext):
pass


def test_user_special_user_class(make_test_app, dummy_token):
test_app = make_test_app({"OIDC_USER_CLASS": f"{__name__}.OtherUser"})
client = test_app.test_client()
with client:
set_token(client, dummy_token)
resp = client.get("/")

assert resp.status_code == 200
assert hasattr(flask.g, "oidc_user")
assert isinstance(flask.g.oidc_user, OtherUser)
11 changes: 11 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# SPDX-FileCopyrightText: 2023 Aurélien Bompard <[email protected]>
#
# SPDX-License-Identifier: BSD-2-Clause


def set_token(client, token, profile=None):
_profile = {"nickname": "dummy"}
_profile.update(profile or {})
with client.session_transaction() as session:
session["oidc_auth_token"] = token
session["oidc_auth_profile"] = _profile

0 comments on commit 5bf8808

Please sign in to comment.