diff --git a/README.md b/README.md index ac366bf..8812153 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![PyPI - Version](https://img.shields.io/pypi/v/synapse-token-authenticator.svg)](https://pypi.org/project/synapse-token-authenticator) [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/synapse-token-authenticator.svg)](https://pypi.org/project/synapse-token-authenticator) -Synapse Token Authenticator is a synapse auth provider which allows for token authentication (and optional registration) using JWTs (Json Web Tokens). +Synapse Token Authenticator is a synapse auth provider which allows for token authentication (and optional registration) using JWTs (Json Web Tokens) and OIDC. ----- @@ -25,20 +25,33 @@ pip install synapse-token-authenticator ## Configuration Here are the available configuration options: ```yaml -# provide only one of secret, keyfile -secret: symetrical secret -keyfile: path to asymetrical keyfile - -# Algorithm of the tokens, defaults to HS512 -#algorithm: HS512 -# Allow registration of new users using these tokens, defaults to false -#allow_registration: false -# Require tokens to have an expiry set, defaults to true -#require_expiry: true +jwt: + # provide only one of secret, keyfile + secret: symetrical secret + keyfile: path to asymetrical keyfile + + # Algorithm of the tokens, defaults to HS512 (optional) + algorithm: HS512 + # Allow registration of new users, defaults to false (optional) + allow_registration: false + # Require tokens to have an expiry set, defaults to true (optional) + require_expiry: true +oidc: + issuer: "https://idp.example.com" + client_id: "" + client_secret: "" + project_id: # TODO: improve docs + organization_id: # TODO: improve docs + # Limits access to specified clients. Allows any client if not set (optional) + allowed_client_ids: ['foo', 'bar'] # TODO: better examples + # Allow registration of new users, defaults to false (optional) + allow_registration: false ``` It is recommended to have `require_expiry` set to `true` (default). As for `allow_registration`, it depends on usecase: If you only want to be able to log in *existing* users, leave it at `false` (default). If nonexistant users should be simply registered upon hitting the login endpoint, set it to `true`. ## Usage + +### JWT Authentication First you have to generate a JWT with the correct claims. The `sub` claim is the localpart or full mxid of the user you want to log in as. Be sure that the algorithm and secret match those of the configuration. An example of the claims is as follows: ```json { @@ -59,6 +72,10 @@ Next you need to post this token to the `/login` endpoint of synapse. Be sure th } ``` +### OIDC Authentication + + + ## Testing The tests uses twisted's testing framework trial, with the development diff --git a/pyproject.toml b/pyproject.toml index 7fe67c1..86f48b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,11 +37,13 @@ dependencies = [ "pytest", "pytest-cov", "mock", - "matrix-synapse" + "matrix-synapse", + "ruff", ] [tool.hatch.envs.default.scripts] cov = "pytest --cov-report=term-missing --cov-config=pyproject.toml --cov=synapse_token_authenticator --cov=tests" format = "black ." +lint = "ruff check ." [tool.hatch.envs.ci.scripts] format = "black --check ." @@ -54,8 +56,4 @@ parallel = true omit = [] [tool.coverage.report] -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] diff --git a/synapse_token_authenticator/config.py b/synapse_token_authenticator/config.py new file mode 100644 index 0000000..4dfd545 --- /dev/null +++ b/synapse_token_authenticator/config.py @@ -0,0 +1,70 @@ +import os + + +class TokenAuthenticatorConfig: + """ + Parses and validates the provided config dictionary. + """ + + def __init__(self, other: dict): + if jwt := other.get("jwt"): + + class JwtConfig: + def __init__(self, other: dict): + self.secret: str | None = other.get("secret") + self.keyfile: str | None = other.get("keyfile") + + self.algorithm: str = other.get("algorithm", "HS512") + self.allow_registration: bool = other.get( + "allow_registration", False + ) + self.require_expiry: bool = other.get("require_expiry", True) + + self.jwt = JwtConfig(jwt) + self.verify_jwt() + + if oidc := other.get("oidc"): + + class OIDCConfig: + def __init__(self, other: dict): + try: + self.issuer: str = other["issuer"] + self.client_id: str = other["client_id"] + self.client_secret: str = other["client_secret"] + self.project_id: str = other["project_id"] + self.organization_id: str = other["organization_id"] + except KeyError as error: + raise Exception(f"Config option must be set: {error.args[0]}") + + self.allowed_client_ids: str | None = other.get( + "allowed_client_ids" + ) + + self.allow_registration: bool = other.get( + "allow_registration", False + ) + + self.oidc = OIDCConfig(oidc) + + def verify_jwt(self): + if self.jwt.secret is None and self.jwt.keyfile is None: + raise Exception("Missing secret or keyfile") + if self.jwt.keyfile is not None and not os.path.exists(self.jwt.keyfile): + raise Exception("Keyfile doesn't exist") + + if self.jwt.algorithm not in [ + "HS256", + "HS384", + "HS512", + "RS256", + "RS384", + "RS512", + "ES256", + "ES384", + "ES512", + "PS256", + "PS384", + "PS512", + "EdDSA", + ]: + raise Exception(f"Unknown algorithm: '{self.jwt.algorithm}'") diff --git a/synapse_token_authenticator/token_authenticator.py b/synapse_token_authenticator/token_authenticator.py index 68cdd52..f0025ba 100644 --- a/synapse_token_authenticator/token_authenticator.py +++ b/synapse_token_authenticator/token_authenticator.py @@ -18,13 +18,20 @@ import logging from jwcrypto import jwt, jwk from jwcrypto.common import JWException, json_decode -import os import base64 +import requests +from requests.auth import HTTPBasicAuth +from urllib.parse import urljoin import synapse from synapse.module_api import ModuleApi from synapse.types import UserID +from twisted.web import resource + +from synapse_token_authenticator.config import TokenAuthenticatorConfig +from synapse_token_authenticator.utils import OpenIDProviderMetadata + logger = logging.getLogger(__name__) @@ -34,26 +41,54 @@ class TokenAuthenticator(object): def __init__(self, config: dict, account_handler: ModuleApi): self.api = account_handler - self.config = config - if self.config.secret: - k = { - "k": base64.urlsafe_b64encode( - self.config.secret.encode("utf-8") - ).decode("utf-8"), - "kty": "oct", - } - self.key = jwk.JWK(**k) - else: - with open(self.config.keyfile, "r") as f: - self.key = jwk.JWK.from_pem(f.read()) - - self.api.register_password_auth_provider_callbacks( - auth_checkers={ - ("com.famedly.login.token", ("token",)): self.check_auth, - }, - ) + auth_checkers = {} + + self.config: TokenAuthenticatorConfig = config + if (jwt := getattr(self.config, "jwt", None)) is not None: + if jwt.secret: + k = { + "k": base64.urlsafe_b64encode(jwt.secret.encode("utf-8")).decode( + "utf-8" + ), + "kty": "oct", + } + self.key = jwk.JWK(**k) + else: + with open(jwt.keyfile, "r") as f: + self.key = jwk.JWK.from_pem(f.read()) + auth_checkers[("com.famedly.login.token", ("token",))] = self.check_jwt_auth + + if (oidc := getattr(self.config, "oidc", None)) is not None: + auth_checkers[ + ("com.famedly.login.token.oidc", ("token",)) + ] = self.check_oidc_auth + + self.api.register_web_resource( + "/_famedly/login/com.famedly.login.token.oidc", + self.LoginMetadataResource(oidc), + ) + + self.api.register_password_auth_provider_callbacks(auth_checkers=auth_checkers) + + class LoginMetadataResource(resource.Resource): + def __init__(self, oidc_config: object): + self.issuer = oidc_config.issuer + self.metadata_url = urljoin( + oidc_config.issuer, "/.well-known/openid-configuration" + ) + self.organization_id = oidc_config.organization_id + + def render_GET(self, request): + request.responseHeaders.addRawHeader(b"content-type", b"application/json") + return f""" + {{ + "issuer": "{self.issuer}", + "issuer-metadata": "{self.metadata_url}", + "organization-id": "{self.organization_id}" + }} + """ - async def check_auth( + async def check_jwt_auth( self, username: str, login_type: str, login_dict: "synapse.module_api.JsonDict" ) -> Optional[ Tuple[ @@ -71,7 +106,7 @@ async def check_auth( token = login_dict["token"] check_claims = {} - if self.config.require_expiry: + if self.config.jwt.require_expiry: check_claims["exp"] = None try: # OK, let's verify the token @@ -79,7 +114,7 @@ async def check_auth( jwt=token, key=self.key, check_claims=check_claims, - algs=[self.config.algorithm], + algs=[self.config.jwt.algorithm], ) except ValueError as e: logger.info("Unrecognized token", e) @@ -121,12 +156,12 @@ async def check_auth( return None user_exists = await self.api.check_user_exists(user_id_str) - if not user_exists and not self.config.allow_registration: + if not user_exists and not self.config.jwt.allow_registration: logger.info("User doesn't exist and registration is disabled") return None if not user_exists: - logger.info("User doesn't exist, registering it...") + logger.info("User doesn't exist, registering them...") await self.api.register_user( user_id.localpart, admin=payload.get("admin", False) ) @@ -145,37 +180,99 @@ async def check_auth( logger.info("All done and valid, logging in!") return (user_id_str, None) + async def check_oidc_auth( + self, username: str, login_type: str, login_dict: "synapse.module_api.JsonDict" + ) -> Optional[ + Tuple[ + str, + Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]], + ] + ]: + logger.info("Receiving auth request") + if login_type != "com.famedly.login.token.oidc": + logger.info("Wrong login type") + return None + if "token" not in login_dict: + logger.info("Missing token") + return None + token = login_dict["token"] + + oidc = self.config.oidc + oidc_metadata = OpenIDProviderMetadata(oidc.issuer) + + # Basic token validation + try: + jwt.JWT( + jwt=token, + key=oidc_metadata.jwks(), + algs=oidc_metadata.id_token_signing_alg_values_supported, + ) + except ValueError as e: + logger.info("Unrecognized token", e) + return None + except JWException as e: + logger.info("Invalid token", e) + return None + + # Further validation using token introspection + data = {"token": token, "token_type_hint": "access_token", "scope": "openid"} + auth = HTTPBasicAuth(oidc.client_id, oidc.client_secret) + response = requests.post( + oidc_metadata.introspection_endpoint, data=data, auth=auth + ) + response.raise_for_status() + introspection_resp = response.json() + + if not introspection_resp["active"]: + logger.info("User is not active") + return None + + if oidc.project_id not in introspection_resp["aud"]: + logger.info( + "Project ID is not part of the token's audience" + ) # TODO: more useful error message + return None + + if introspection_resp["iss"] != oidc_metadata.issuer: + logger.info(f"Token issuer does not match: {introspection_resp['iss']}") + return None + + if ( + oidc.allowed_client_ids is not None + and introspection_resp["client_id"] not in oidc.allowed_client_ids + ): + logger.info( + f"Client {introspection_resp['client_id']} is not in the list of allowed clients" + ) + return None + + # Checking if the user's localpart matches + headers = {"Authorization": f"Bearer {token}"} + response = requests.post(oidc_metadata.userinfo_endpoint, headers=headers) + response.raise_for_status() + userinfo = response.json() + + user_id_str = self.api.get_qualified_user_id(username) + user_id = UserID.from_string(user_id_str) + + if userinfo["localpart"] != user_id.localpart: + logger.info("The provided username is incorrect") + return None + + user_exists = await self.api.check_user_exists(user_id_str) + if not user_exists and not self.config.oidc.allow_registration: + logger.info("User doesn't exist and registration is disabled") + return None + + if not user_exists: + logger.info("User doesn't exist, registering it...") + await self.api.register_user(user_id.localpart) + + user_id_str = self.api.get_qualified_user_id(username) + + logger.info("All done and valid, logging in!") + return (user_id_str, None) + @staticmethod - def parse_config(config): - class _TokenAuthenticatorConfig(object): - pass - - _config = _TokenAuthenticatorConfig() - _config.secret = config.get("secret", False) - _config.keyfile = config.get("keyfile", False) - if not _config.secret and not _config.keyfile: - raise Exception("Missing secret or keyfile") - if _config.keyfile and not os.path.exists(_config.keyfile): - raise Exception("Keyfile doesn't exist") - - _config.algorithm = config.get("algorithm", "HS512") - if _config.algorithm not in [ - "HS256", - "HS384", - "HS512", - "RS256", - "RS384", - "RS512", - "ES256", - "ES384", - "ES512", - "PS256", - "PS384", - "PS512", - "EdDSA", - ]: - raise Exception("Unknown algorithm " + _config.algorithm) - - _config.allow_registration = config.get("allow_registration", False) - _config.require_expiry = config.get("require_expiry", True) - return _config + def parse_config(config: dict): + return TokenAuthenticatorConfig(config) diff --git a/synapse_token_authenticator/utils.py b/synapse_token_authenticator/utils.py new file mode 100644 index 0000000..2dbd556 --- /dev/null +++ b/synapse_token_authenticator/utils.py @@ -0,0 +1,31 @@ +import requests +from urllib.parse import urljoin +from jwcrypto.jwk import JWKSet + + +class OpenIDProviderMetadata: + """ + Wrapper around OpenID Provider Metadata values + """ + + def __init__(self, issuer: str): + response = requests.get(urljoin(issuer, "/.well-known/openid-configuration")) + response.raise_for_status() + + configuration = response.json() + + self.introspection_endpoint: str = configuration["introspection_endpoint"] + self.userinfo_endpoint: str = configuration["userinfo_endpoint"] + self.jwks_uri: str = configuration["jwks_uri"] + self.id_token_signing_alg_values_supported: list[str] = configuration[ + "id_token_signing_alg_values_supported" + ] + + def jwks(self) -> JWKSet: + """ + Signing keys used to validate signatures from the OpenID Provider + """ + response = requests.get(self.jwks_uri) + response.raise_for_status() + + return JWKSet.from_json(response.text) diff --git a/tests/__init__.py b/tests/__init__.py index 977f885..02eacc7 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -52,6 +52,13 @@ async def register_user( account_handler.set_user_admin.side_effect = set_user_admin account_handler.is_user_admin.side_effect = is_user_admin + # TODO: mock IDP: + # - /.well-known/openid-configuration + # - (only needs to contain: issuer, introspection_endpoint, userinfo_endpoint, jwks_uri, id_token_signing_alg_values_supported) + # - /oauth/v2/keys + # - /oauth/v2/introspect + # - /oidc/v1/userinfo + def get_qualified_user_id(*args): return ModuleApi.get_qualified_user_id(account_handler, *args) @@ -60,7 +67,19 @@ def get_qualified_user_id(*args): if config: config_parsed = TokenAuthenticator.parse_config(config) else: - config_parsed = TokenAuthenticator.parse_config({"secret": "foxies"}) + # TODO: add example oidc config + config_parsed = TokenAuthenticator.parse_config( + { + "jwt": {"secret": "foxies"}, + # "oidc": { + # "issuer": "https://idp.example.org", + # "client_id": "", + # "client_secret": "", + # "project_id": "", + # "organization_id": "" + # }, + } + ) return TokenAuthenticator(config_parsed, account_handler) diff --git a/tests/test_simple.py b/tests/test_jwt.py similarity index 80% rename from tests/test_simple.py rename to tests/test_jwt.py index 30cd235..f466aea 100644 --- a/tests/test_simple.py +++ b/tests/test_jwt.py @@ -18,21 +18,25 @@ from . import get_auth_provider, get_token -class SimpleTestCase(unittest.TestCase): +class JWTTests(unittest.TestCase): async def test_wrong_login_type(self): auth_provider = get_auth_provider() token = get_token("alice") - result = await auth_provider.check_auth("alice", "m.password", {"token": token}) + result = await auth_provider.check_jwt_auth( + "alice", "m.password", {"token": token} + ) self.assertEqual(result, None) async def test_missing_token(self): auth_provider = get_auth_provider() - result = await auth_provider.check_auth("alice", "com.famedly.login.token", {}) + result = await auth_provider.check_jwt_auth( + "alice", "com.famedly.login.token", {} + ) self.assertEqual(result, None) async def test_invalid_token(self): auth_provider = get_auth_provider() - result = await auth_provider.check_auth( + result = await auth_provider.check_jwt_auth( "alice", "com.famedly.login.token", {"token": "invalid"} ) self.assertEqual(result, None) @@ -40,7 +44,7 @@ async def test_invalid_token(self): async def test_token_wrong_secret(self): auth_provider = get_auth_provider() token = get_token("alice", secret="wrong secret") - result = await auth_provider.check_auth( + result = await auth_provider.check_jwt_auth( "alice", "com.famedly.login.token", {"token": token} ) self.assertEqual(result, None) @@ -48,7 +52,7 @@ async def test_token_wrong_secret(self): async def test_token_wrong_alg(self): auth_provider = get_auth_provider() token = get_token("alice", algorithm="HS256") - result = await auth_provider.check_auth( + result = await auth_provider.check_jwt_auth( "alice", "com.famedly.login.token", {"token": token} ) self.assertEqual(result, None) @@ -56,7 +60,7 @@ async def test_token_wrong_alg(self): async def test_token_expired(self): auth_provider = get_auth_provider() token = get_token("alice", exp_in=-60) - result = await auth_provider.check_auth( + result = await auth_provider.check_jwt_auth( "alice", "com.famedly.login.token", {"token": token} ) self.assertEqual(result, None) @@ -64,7 +68,7 @@ async def test_token_expired(self): async def test_token_no_expiry(self): auth_provider = get_auth_provider() token = get_token("alice", exp_in=-1) - result = await auth_provider.check_auth( + result = await auth_provider.check_jwt_auth( "alice", "com.famedly.login.token", {"token": token} ) self.assertEqual(result, None) @@ -72,12 +76,14 @@ async def test_token_no_expiry(self): async def test_token_no_expiry_with_config(self): auth_provider = get_auth_provider( config={ - "secret": "foxies", - "require_expiry": False, + "jwt": { + "secret": "foxies", + "require_expiry": False, + } } ) token = get_token("alice", exp_in=-1) - result = await auth_provider.check_auth( + result = await auth_provider.check_jwt_auth( "alice", "com.famedly.login.token", {"token": token} ) self.assertEqual(result[0], "@alice:example.org") @@ -85,7 +91,7 @@ async def test_token_no_expiry_with_config(self): async def test_valid_login(self): auth_provider = get_auth_provider() token = get_token("alice") - result = await auth_provider.check_auth( + result = await auth_provider.check_jwt_auth( "alice", "com.famedly.login.token", {"token": token} ) self.assertEqual(result[0], "@alice:example.org") @@ -93,7 +99,7 @@ async def test_valid_login(self): async def test_valid_login_no_register(self): auth_provider = get_auth_provider(user_exists=False) token = get_token("alice") - result = await auth_provider.check_auth( + result = await auth_provider.check_jwt_auth( "alice", "com.famedly.login.token", {"token": token} ) self.assertEqual(result, None) @@ -103,7 +109,7 @@ async def test_chatbox_login(self): token = get_token( "alice_5833eb34-7dbf-44a7-90cf-868c50922c06", claims={"type": "chatbox"} ) - result = await auth_provider.check_auth( + result = await auth_provider.check_jwt_auth( "alice_5833eb34-7dbf-44a7-90cf-868c50922c06", "com.famedly.login.token", {"token": token}, @@ -115,19 +121,21 @@ async def test_chatbox_login(self): async def test_chatbox_login_invalid_format(self): auth_provider = get_auth_provider(user_exists=False) token = get_token("alice", claims={"type": "chatbox"}) - result = await auth_provider.check_auth( + result = await auth_provider.check_jwt_auth( "alice", "com.famedly.login.token", {"token": token} ) self.assertEqual(result, None) async def test_valid_login_with_register(self): config = { - "secret": "foxies", - "allow_registration": True, + "jwt": { + "secret": "foxies", + "allow_registration": True, + }, } auth_provider = get_auth_provider(config=config, user_exists=False) token = get_token("alice") - result = await auth_provider.check_auth( + result = await auth_provider.check_jwt_auth( "alice", "com.famedly.login.token", {"token": token} ) self.assertEqual(result[0], "@alice:example.org") @@ -135,7 +143,7 @@ async def test_valid_login_with_register(self): async def test_valid_login_with_admin(self): auth_provider = get_auth_provider() token = get_token("alice", admin=True) - result = await auth_provider.check_auth( + result = await auth_provider.check_jwt_auth( "alice", "com.famedly.login.token", {"token": token} ) self.assertEqual(result[0], "@alice:example.org") diff --git a/tests/test_oidc.py b/tests/test_oidc.py new file mode 100644 index 0000000..7111223 --- /dev/null +++ b/tests/test_oidc.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2020 Famedly +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from twisted.trial import unittest +from . import get_auth_provider + + +class OIDCTests(unittest.TestCase): + async def test_wrong_login_type(self): + auth_provider = get_auth_provider() + # TODO: add a function that returns example login dict + # token = get_token("alice") + result = await auth_provider.check_oidc_auth("alice", "m.password", {}) + self.assertEqual(result, None) + + async def test_missing_token(self): + auth_provider = get_auth_provider() + result = await auth_provider.check_oidc_auth( + "alice", "com.famedly.login.token,oidc", {} + ) + self.assertEqual(result, None) + + async def test_invalid_token(self): + auth_provider = get_auth_provider() + result = await auth_provider.check_oidc_auth( + "alice", "com.famedly.login.token.oidc", {"token": "invalid"} + ) + self.assertEqual(result, None) + + async def test_token_wrong_alg(self): + auth_provider = get_auth_provider() + # TODO: add a function that returns example login dict + # token = get_token("alice", algorithm="HS256") + result = await auth_provider.check_oidc_auth( + "alice", "com.famedly.login.token.oidc", {} + ) + self.assertEqual(result, None) + + async def test_valid_login(self): + auth_provider = get_auth_provider() + # TODO: add a function that returns example login dict + # token = get_token("alice") + result = await auth_provider.check_oidc_auth( + "alice", "com.famedly.login.token.oidc", {} + ) + self.assertEqual(result[0], "@alice:example.org") + + async def test_valid_login_no_register(self): + auth_provider = get_auth_provider(user_exists=False) + # TODO: add a function that returns example login dict + # token = get_token("alice") + result = await auth_provider.check_oidc_auth( + "alice", "com.famedly.login.token.oidc", {} + ) + self.assertEqual(result, None) + + async def test_valid_login_with_register(self): + # TODO: example config + config = { + "oidc": { + "issuer": "https://idp.example.org", + "client_id": "", + "client_secret": "", + "project_id": "", + "organization_id": "", + "allow_registration": True, + }, + } + auth_provider = get_auth_provider(config=config, user_exists=False) + # TODO: add a function that returns example login dict + # token = get_token("alice") + result = await auth_provider.check_oidc_auth( + "alice", "com.famedly.login.token.oidc", {} + ) + self.assertEqual(result[0], "@alice:example.org")