Skip to content

Commit

Permalink
Add signals to hook into the login and logout process
Browse files Browse the repository at this point in the history
Signed-off-by: Aurélien Bompard <[email protected]>
  • Loading branch information
abompard committed Jun 11, 2024
1 parent 5bf8808 commit 5841c9c
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 8 deletions.
7 changes: 7 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,13 @@ This is a list of all settings supported in the current release.
It needs to accept the extension instance as only constructor argument.


Signals
=======

Some signals are available to hook into the login and logout process, see the
:mod:`~flask_oidc.signals` module documentation for details.


Other docs
==========

Expand Down
31 changes: 31 additions & 0 deletions flask_oidc/signals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# SPDX-FileCopyrightText: 2023 Aurélien Bompard <[email protected]>
#
# SPDX-License-Identifier: BSD-2-Clause

"""
This module contains signals that can be connected to to hook into the login and logout process.
See the `Flask documentation on signals <https://flask.palletsprojects.com/en/2.3.x/signals/>`_
to learn how to connect to these.
"""


from blinker import Namespace

# This namespace is only for signals provided by flask-oidc.
_signals = Namespace()

before_login_redirect = _signals.signal("before-login-redirect")
"""Emitted before the user is redirected to the identity provider."""

before_authorize = _signals.signal("before-authorize")
"""Emitted when the user is redirected back from the identity provider."""

after_authorize = _signals.signal("after-authorize")
"""Emitted when the user is authenticated."""

before_logout = _signals.signal("before-logout")
"""Emitted before the user is logged out."""

after_logout = _signals.signal("after-logout")
"""Emitted after the user is logged out."""
17 changes: 17 additions & 0 deletions flask_oidc/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@
url_for,
)

from .signals import (
after_authorize,
after_logout,
before_authorize,
before_login_redirect,
before_logout,
)

logger = logging.getLogger(__name__)

auth_routes = Blueprint("oidc_auth", __name__)
Expand All @@ -37,11 +45,17 @@ def login_view():
else:
redirect_uri = url_for("oidc_auth.authorize", _external=True)
session["next"] = request.args.get("next", request.root_url)
before_login_redirect.send(
g._oidc_auth,
redirect_uri=redirect_uri,
next=session["next"],
)
return g._oidc_auth.authorize_redirect(redirect_uri)


@auth_routes.route("/authorize", endpoint="authorize")
def authorize_view():
before_authorize.send(g._oidc_auth)
try:
token = g._oidc_auth.authorize_access_token()
except OAuthError as e:
Expand All @@ -57,6 +71,7 @@ def authorize_view():
del session["next"]
except KeyError:
return_to = request.root_url
after_authorize.send(g._oidc_auth, token=token, return_to=return_to)
return redirect(return_to)


Expand All @@ -75,6 +90,7 @@ def logout_view():
.. versionadded:: 1.0
"""
before_logout.send(g._oidc_auth)
session.pop("oidc_auth_token", None)
session.pop("oidc_auth_profile", None)
g.oidc_id_token = None
Expand All @@ -84,6 +100,7 @@ def logout_view():
else:
flash("You were successfully logged out.")
return_to = request.args.get("next", request.root_url)
after_logout.send(g._oidc_auth, reason=reason, return_to=return_to)
return redirect(return_to)


Expand Down
5 changes: 2 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ python = "^3.8"
flask = "^2.0.0 || ^3.0.0"
authlib = "^1.2.0"
requests = "^2.24.0"
blinker = "^1.8.2"

[tool.poetry.group.dev.dependencies]
black = ">=22.6.0"
Expand Down
70 changes: 65 additions & 5 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,78 @@
#
# SPDX-License-Identifier: BSD-2-Clause

import sys
import time
from collections import defaultdict

import flask
import pytest

from flask_oidc.signals import (
after_authorize,
after_logout,
before_authorize,
before_logout,
)

def test_authorize_error(client):
HAS_MULTIPLE_CONTEXT_MANAGERS = sys.hexversion >= 0x030900F0 # 3.9.0


@pytest.fixture()
def sent_signals():
if not HAS_MULTIPLE_CONTEXT_MANAGERS:
yield {}
return

sent = defaultdict(list)

def record_signal(signal):
def record(sender, **kwargs):
sent[signal].append(kwargs)

return signal.connected_to(record)

with (
record_signal(before_authorize),
record_signal(after_authorize),
record_signal(before_logout),
record_signal(after_logout),
):
yield sent


def test_authorize_error(client, sent_signals):
resp = client.get(
"http://localhost/authorize?error=dummy_error&error_description=Dummy+Error"
)
assert resp.status_code == 401
assert "<p>dummy_error: Dummy Error</p>" in resp.get_data(as_text=True)
# Model
assert flask.g.oidc_user.logged_in is False
# Signals
if HAS_MULTIPLE_CONTEXT_MANAGERS:
assert len(sent_signals[before_authorize]) == 1
assert len(sent_signals[after_authorize]) == 0


def test_authorize_no_return_url(client, mocked_responses, dummy_token):
def test_authorize_no_return_url(client, mocked_responses, dummy_token, sent_signals):
mocked_responses.post("https://test/openidc/Token", json=dummy_token)
mocked_responses.get("https://test/openidc/UserInfo", json={"nickname": "dummy"})
with client.session_transaction() as session:
session["_state_oidc_dummy_state"] = {"data": {}}
resp = client.get("/authorize?state=dummy_state&code=dummy_code")
assert resp.status_code == 302
assert resp.location == "http://localhost/"
# Signals
if HAS_MULTIPLE_CONTEXT_MANAGERS:
assert len(sent_signals[before_authorize]) == 1
assert len(sent_signals[after_authorize]) == 1
assert sent_signals[after_authorize][0]["token"] == dummy_token


def test_authorize_no_user_info(test_app, client, mocked_responses, dummy_token):
def test_authorize_no_user_info(
test_app, client, mocked_responses, dummy_token, sent_signals
):
test_app.config["OIDC_USER_INFO_ENABLED"] = False
mocked_responses.post("https://test/openidc/Token", json=dummy_token)
with client.session_transaction() as session:
Expand All @@ -37,9 +84,14 @@ def test_authorize_no_user_info(test_app, client, mocked_responses, dummy_token)
assert resp.status_code == 302
assert "oidc_auth_token" in flask.session
assert "oidc_auth_profile" not in flask.session
# Signals
if HAS_MULTIPLE_CONTEXT_MANAGERS:
assert len(sent_signals[before_authorize]) == 1
assert len(sent_signals[after_authorize]) == 1
assert sent_signals[after_authorize][0]["token"] == dummy_token


def test_logout(client, dummy_token):
def test_logout(client, dummy_token, sent_signals):
with client.session_transaction() as session:
session["oidc_auth_token"] = dummy_token
session["oidc_auth_profile"] = {"nickname": "dummy"}
Expand All @@ -51,9 +103,13 @@ def test_logout(client, dummy_token):
flashes = flask.get_flashed_messages()
assert len(flashes) == 1
assert flashes[0] == "You were successfully logged out."
# Signals
if HAS_MULTIPLE_CONTEXT_MANAGERS:
assert len(sent_signals[before_logout]) == 1
assert len(sent_signals[after_logout]) == 1


def test_logout_expired(client, dummy_token):
def test_logout_expired(client, dummy_token, sent_signals):
dummy_token["expires_at"] = int(time.time())
with client.session_transaction() as session:
session["oidc_auth_token"] = dummy_token
Expand All @@ -65,6 +121,10 @@ def test_logout_expired(client, dummy_token):
flashes = flask.get_flashed_messages()
assert len(flashes) == 1
assert flashes[0] == "Your session expired, please reconnect."
# Signals
if HAS_MULTIPLE_CONTEXT_MANAGERS:
assert len(sent_signals[before_logout]) == 1
assert len(sent_signals[after_logout]) == 1


def test_oidc_callback_route(test_app, client, dummy_token):
Expand Down

0 comments on commit 5841c9c

Please sign in to comment.