From 5bf8808f115a1045b763f9d19579a5672b74e4e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Bompard?= Date: Tue, 11 Jun 2024 00:19:44 +0200 Subject: [PATCH] Add a user model to `flask.g` with convenience properties MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Aurélien Bompard --- docs/index.rst | 24 +++++++++-- flask_oidc/__init__.py | 11 +++++ flask_oidc/model.py | 56 ++++++++++++++++++++++++++ tests/test_flask_oidc.py | 17 +++----- tests/test_model.py | 86 ++++++++++++++++++++++++++++++++++++++++ tests/utils.py | 11 +++++ 6 files changed, 191 insertions(+), 14 deletions(-) create mode 100644 flask_oidc/model.py create mode 100644 tests/test_model.py create mode 100644 tests/utils.py diff --git a/docs/index.rst b/docs/index.rst index de220e7..c130778 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -33,8 +33,8 @@ Install the extension with `pip`:: Integration ----------- -To integrate Flask-OpenID into your application you need to create an -instance of the :class:`OpenID` object first:: +To integrate Flask-OIDC into your application you need to create an +instance of the :class:`OpenIDConnect` object first:: from flask_oidc import OpenIDConnect oidc = OpenIDConnect(app) @@ -47,9 +47,15 @@ Using this library is very simple: you can use :data:`~flask_oidc.OpenIDConnect.user_loggedin` to determine whether a user is currently logged in using OpenID Connect. -If the user is logged in, you an use ``session["oidc_auth_profile"]`` to get +If the user is logged in, you can use ``session["oidc_auth_profile"]`` to get information about the currently logged in user. +A :class:`~flask_oidc.model.User` object is also provided on ``g.oidc_user``, see its API +documentation to discover which convenient properties are available. +You can set the ``OIDC_USER_CLASS`` configuration variable to the python path of a class if +you want to user your own class. It needs to accept the extension instance as only +constructor argument. + You can decorate any view function with :meth:`~flask_oidc.OpenIDConnect.require_login` to redirect anonymous users to the OIDC provider. @@ -70,6 +76,14 @@ A very basic example client:: def login(): return 'Welcome %s' % session["oidc_auth_profile"].get('email') + @app.route('/alt') + def alternative(): + # This uses the user instance at g.oidc_user instead + if g.oidc_user.logged_in: + return 'Welcome %s' % g.oidc_user.profile.get('email') + else + return 'Not logged in' + Resource server =============== @@ -207,6 +221,10 @@ This is a list of all settings supported in the current release. the token_introspection_uri. Valid values are 'client_secret_post', 'client_secret_basic', or 'bearer'. Defaults to 'client_secret_post'. + OIDC_USER_CLASS + The python path to a custom :class:`~flask_oidc.model.User` model. + It needs to accept the extension instance as only constructor argument. + Other docs ========== diff --git a/flask_oidc/__init__.py b/flask_oidc/__init__.py index ec6094f..890bee4 100644 --- a/flask_oidc/__init__.py +++ b/flask_oidc/__init__.py @@ -19,6 +19,7 @@ IntrospectTokenValidator as BaseIntrospectTokenValidator, ) from flask import abort, current_app, g, redirect, request, session, url_for +from werkzeug.utils import import_string from .views import auth_routes, legacy_oidc_callback @@ -162,6 +163,14 @@ def init_app(self, app, prefix=None): if app.config["OIDC_CALLBACK_ROUTE"]: app.route(app.config["OIDC_CALLBACK_ROUTE"])(legacy_oidc_callback) + # User model + app.config.setdefault("OIDC_USER_CLASS", "flask_oidc.model.User") + if app.config["OIDC_USER_CLASS"]: + app.extensions["_oidc_user_class"] = import_string( + app.config["OIDC_USER_CLASS"] + ) + + # Flask hooks app.before_request(self._before_request) def load_secrets(self, app): @@ -175,6 +184,8 @@ def load_secrets(self, app): def _before_request(self): g._oidc_auth = self.oauth.oidc + if current_app.extensions.get("_oidc_user_class"): + g.oidc_user = current_app.extensions["_oidc_user_class"](self) if not current_app.config["OIDC_RESOURCE_SERVER_ONLY"]: return self.check_token_expiry() diff --git a/flask_oidc/model.py b/flask_oidc/model.py new file mode 100644 index 0000000..8d0da43 --- /dev/null +++ b/flask_oidc/model.py @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: 2023 Aurélien Bompard +# +# SPDX-License-Identifier: BSD-2-Clause + + +from flask import current_app, session + + +class User: + """A representation of an OIDC-based user. + + Arguments: + ext (OpenIDConnect): the extension instance + """ + + def __init__(self, ext): + self._ext = ext + + @property + def logged_in(self): + """Return ``True`` if the user is logged in, ``False`` otherwise.""" + return session.get("oidc_auth_token") is not None + + @property + def access_token(self): + """The user's OIDC access token.""" + return self._ext.get_access_token() + + @property + def refresh_token(self): + """The user's OIDC refresh token.""" + return self._ext.get_refresh_token() + + @property + def profile(self): + """The user's OIDC profile, if any. + + Raises: + RuntimeError: when ``OIDC_USER_INFO_ENABLED`` is ``False`` in the application's + configuration. + """ + if not current_app.config["OIDC_USER_INFO_ENABLED"]: + raise RuntimeError( + "User info is disabled in configuration (OIDC_USER_INFO_ENABLED)" + ) + return session.get("oidc_auth_profile", {}) + + @property + def name(self): + """The user's nickname.""" + return self.profile.get("nickname") + + @property + def groups(self): + """The list of group names the user belongs to.""" + return self.profile.get("groups", []) diff --git a/tests/test_flask_oidc.py b/tests/test_flask_oidc.py index c33011b..652f5b2 100644 --- a/tests/test_flask_oidc.py +++ b/tests/test_flask_oidc.py @@ -20,6 +20,7 @@ from flask_oidc import OpenIDConnect from .app import oidc as oidc_ext +from .utils import set_token def callback_url_for(response): @@ -32,12 +33,6 @@ def callback_url_for(response): return f"{query['redirect_uri'][0]}?state={query['state'][0]}&code=mock_auth_code" -def _set_token(client, token): - with client.session_transaction() as session: - session["oidc_auth_token"] = token - session["oidc_auth_profile"] = {"nickname": "dummy"} - - def test_signin(test_app, client, mocked_responses, dummy_token): """Happy path authentication test.""" mocked_responses.post("https://test/openidc/Token", json=dummy_token) @@ -105,7 +100,7 @@ def test_logout_redirect_loop(make_test_app, dummy_token, mocked_responses): "https://test/openidc/Token", json={"error": "dummy"}, status=401 ) dummy_token["expires_at"] = int(time.time()) - _set_token(client, dummy_token) + set_token(client, dummy_token) resp = client.get("/logout?reason=expired") assert resp.location == "http://localhost/subpath/" @@ -118,7 +113,7 @@ def test_expired_token(client, dummy_token, mocked_responses): refresh_call = mocked_responses.post("https://test/openidc/Token", json=new_token) dummy_token["expires_at"] = int(time.time()) - _set_token(client, dummy_token) + set_token(client, dummy_token) resp = client.get("/") @@ -146,7 +141,7 @@ def test_expired_token_cant_renew(client, dummy_token, mocked_responses): ) dummy_token["expires_at"] = int(time.time()) - _set_token(client, dummy_token) + set_token(client, dummy_token) resp = client.get("/") @@ -162,7 +157,7 @@ def test_expired_token_cant_renew(client, dummy_token, mocked_responses): def test_expired_token_no_refresh_token(client, dummy_token): del dummy_token["refresh_token"] dummy_token["expires_at"] = int(time.time()) - _set_token(client, dummy_token) + set_token(client, dummy_token) resp = client.get("/") @@ -175,7 +170,7 @@ def test_expired_token_no_refresh_token(client, dummy_token): def test_bad_token(client): - _set_token(client, "bad_token") + set_token(client, "bad_token") resp = client.get("/") assert resp.status_code == 500 assert "Internal Server Error" in resp.get_data(as_text=True) diff --git a/tests/test_model.py b/tests/test_model.py new file mode 100644 index 0000000..5f2af73 --- /dev/null +++ b/tests/test_model.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: 2014-2015 Erica Ehrhardt +# SPDX-FileCopyrightText: 2016-2022 Patrick Uiterwijk +# SPDX-FileCopyrightText: 2023 Aurélien Bompard +# +# SPDX-License-Identifier: BSD-2-Clause + +import flask +import pytest + +from .utils import set_token + + +def test_user_in_view(client, dummy_token): + set_token(client, dummy_token) + + resp = client.get("/") + assert resp.status_code == 200 + + assert hasattr(flask.g, "oidc_user") + user = flask.g.oidc_user + + assert user.logged_in + assert user.access_token == dummy_token["access_token"] + assert user.refresh_token == dummy_token["refresh_token"] + assert user.profile == {"nickname": "dummy"} + assert user.name == "dummy" + + +def test_user_with_groups(client, dummy_token): + set_token(client, dummy_token, {"groups": ["dummy_group"]}) + + resp = client.get("/") + assert resp.status_code == 200 + + assert hasattr(flask.g, "oidc_user") + user = flask.g.oidc_user + assert user.profile == {"nickname": "dummy", "groups": ["dummy_group"]} + assert user.groups == ["dummy_group"] + + +def test_user_no_user_info(client, test_app, dummy_token): + set_token(client, dummy_token) + test_app.config["OIDC_USER_INFO_ENABLED"] = False + + resp = client.get("/") + assert resp.status_code == 200 + + assert hasattr(flask.g, "oidc_user") + user = flask.g.oidc_user + + with pytest.raises(RuntimeError): + user.profile + + with pytest.raises(RuntimeError): + user.name + + with pytest.raises(RuntimeError): + user.groups + + +def test_user_no_user(make_test_app, dummy_token): + test_app = make_test_app({"OIDC_USER_CLASS": None}) + client = test_app.test_client() + with client: + set_token(client, dummy_token) + resp = client.get("/") + + assert resp.status_code == 200 + assert not hasattr(flask.g, "oidc_user") + + +class OtherUser: + def __init__(self, ext): + pass + + +def test_user_special_user_class(make_test_app, dummy_token): + test_app = make_test_app({"OIDC_USER_CLASS": f"{__name__}.OtherUser"}) + client = test_app.test_client() + with client: + set_token(client, dummy_token) + resp = client.get("/") + + assert resp.status_code == 200 + assert hasattr(flask.g, "oidc_user") + assert isinstance(flask.g.oidc_user, OtherUser) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..fe9bf83 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: 2023 Aurélien Bompard +# +# SPDX-License-Identifier: BSD-2-Clause + + +def set_token(client, token, profile=None): + _profile = {"nickname": "dummy"} + _profile.update(profile or {}) + with client.session_transaction() as session: + session["oidc_auth_token"] = token + session["oidc_auth_profile"] = _profile