From 15bd4c222ae8bbfc1fdef10d9f4d19a76b9ae886 Mon Sep 17 00:00:00 2001 From: Vlad Zagvozdkin Date: Mon, 24 Jun 2024 22:20:08 +0500 Subject: [PATCH] refactor!: custom flow into configurable oauth --- README.md | 355 +++++++++++++++--- pyproject.toml | 2 +- .../claims_validator.py | 179 +++++++++ synapse_token_authenticator/config.py | 173 ++++++++- .../token_authenticator.py | 263 +++++++++---- synapse_token_authenticator/utils.py | 51 ++- tests/__init__.py | 71 +++- tests/test_custom.py | 142 ------- tests/test_oauth.py | 236 ++++++++++++ tests/test_sta_utils.py | 39 ++ tests/test_validators.py | 133 +++++++ 11 files changed, 1352 insertions(+), 292 deletions(-) create mode 100644 synapse_token_authenticator/claims_validator.py delete mode 100644 tests/test_custom.py create mode 100644 tests/test_oauth.py create mode 100644 tests/test_sta_utils.py create mode 100644 tests/test_validators.py diff --git a/README.md b/README.md index ff0661d..3fc7f63 100644 --- a/README.md +++ b/README.md @@ -9,12 +9,32 @@ Synapse Token Authenticator is a synapse auth provider which allows for token au **Table of Contents** -- [Installation](#installation) -- [Configuration](#configuration) -- [Usage](#usage) -- [Testing](#testing) -- [Releasing](#releasing) -- [License](#license) +* [Installation](#installation) +* [Configuration](#configuration) + * [OAuthConfig](#oauthconfig) + * [JwtValidationConfig](#jwtvalidationconfig) + * [IntrospectionValidationConfig](#introspectionvalidationconfig) + * [NotifyOnRegistration](#notifyonregistration) + * [Path](#path) + * [BasicAuth](#basicauth) + * [BearerAuth](#bearerauth) + * [HttpAuth](#httpauth) + * [Validator](#validator) + * [Exist](#exist) + * [Not](#not) + * [Equal](#equal) + * [MatchesRegex](#matchesregex) + * [AnyOf](#anyof) + * [AllOf](#allof) + * [In](#in) + * [ListAllOf](#listallof) + * [ListAnyOf](#listanyof) +* [Usage](#usage) + * [JWT Authentication](#jwt-authentication) + * [OIDC Authentication](#oidc-authentication) +* [Testing](#testing) +* [Releasing](#releasing) +* [License](#license) ## Installation @@ -49,23 +69,286 @@ oidc: allowed_client_ids: ['2897827328738@project_name'] # Allow registration of new users, defaults to false (optional) allow_registration: false -custom_flow: - # provide only one of secret, keyfile - secret: symetrical secret - keyfile: path to asymetrical keyfile - - # Algorithm of the tokens, defaults to RS256 (optional) - algorithm: RS256 - # Require tokens to have an expiry set, defaults to true (optional) - require_expiry: true - # This endpoint will be called when new user is registered - # with `{"token": }` as its request body - notify_on_registration_uri: http://example.com/notify - # Bearer auth token for `notify_on_registration_uri` call (optional) - notification_access_token: 'my$3cr37' +oauth: + # see OAuthConfig section ``` It is recommended to have `require_expiry` set to `true` (default). As for `allow_registration`, it depends on usecase: If you only want to be able to log in *existing* users, leave it at `false` (default). If nonexistant users should be simply registered upon hitting the login endpoint, set it to `true`. +### OAuthConfig +| Parameter | Type | +|----------------------------|------------------------------------------------------------------------------| +| `jwt_validation` | [`JwtValidationConfig`](#JwtValidationConfig) (optional) | +| `introspection_validation` | [`IntrospectionValidationConfig`](#IntrospectionValidationConfig) (optional) | +| `username_type` | One of `'fq_uid'`, `'localpart'`, `'user_id'` (optional) | +| `notify_on_registration` | [`NotifyOnRegistration`](#NotifyOnRegistration) (optional) | +| `expose_metadata_resource` | Any (optional) | +| `registration_enabled` | Bool (defaults to `false`) | + +At least one of `jwt_validation` or `introspection_validation` must be defined. + +`username_type` specifies the role of `identifier.user`: +- `'fq_uid'` — must be fully qualified username, e.g. `@alice:example.test` +- `'localpart'` — must be localpart, e.g. `alice` +- `'user_id'` — could be localpart or fully qualified username +- `null` — the username is ignored, it will be source from the token or introspection response + +If `notify_on_registration` is set then `notify_on_registration.url` will be called when a new user is registered with this body: +```json +{ + "localpart": "alice", + "fully_qualified_uid": "@alice:example.test", + "displayname": "Alice", +}, +``` + +`expose_metadata_resource` must be an object with `name` field. The object will be exposed at `/_famedly/login/{expose_metadata_resource.name}`. + +`jwt_validation` and `introspection_validation` contain a bunch of `*_path` optional fields. Each of these, if specified will be used to source either localpart, user id, or fully qualified user id from jwt claims and introspection response. They values are going to be compared for equality, if they differ, authentication would fail. Be careful with these, as it is possible to configure in such a way that authentication would always fail, or, if `username_type` is `null`, no user id data can be sourced, thus also leading to failure. + + +### JwtValidationConfig +[RFC 7519 - JSON Web Token (JWT)](https://datatracker.ietf.org/doc/html/rfc7519) +| Parameter | Type | +|--------------------|-----------------------------------------------------------| +| `validator` | [`Validator`](#Validator) (defaults to [`Exist`](#Exist)) | +| `require_expiry` | Bool (defaults to `false`) | +| `localpart_path` | [`Path`](#Path) (optional) | +| `user_id_path` | [`Path`](#Path) (optional) | +| `fq_uid_path` | [`Path`](#Path) (optional) | +| `displayname_path` | [`Path`](#Path) (optional) | +| `required_scopes` | Space separated string or a list of strings (optional) | +| `jwk_set` | [JWKSet](https://datatracker.ietf.org/doc/html/rfc7517#section-5) or [JWK](https://datatracker.ietf.org/doc/html/rfc7517#section-4) (optional) | +| `jwk_file` | String (optional) | + +Either `jwk_set` or `jwk_file` must be specified + + +### IntrospectionValidationConfig +[RFC 7662 - OAuth 2.0 Token Introspection](https://datatracker.ietf.org/doc/html/rfc7662) +| Parameter | Type | +|--------------------|-----------------------------------------------------------| +| `endpoint` | String | +| `validator` | [`Validator`](#Validator) (defaults to [`Exist`](#Exist)) | +| `auth` | [`HttpAuth`](#HttpAuth) (optional) | +| `localpart_path` | [`Path`](#Path) (optional) | +| `user_id_path` | [`Path`](#Path) (optional) | +| `fq_uid_path` | [`Path`](#Path) (optional) | +| `displayname_path` | [`Path`](#Path) (optional) | +| `required_scopes` | Space separated string or a list of strings (optional) | + +Keep in mind, that default validator will always pass. According to the [spec](https://datatracker.ietf.org/doc/html/rfc7662), you probably want at least +```yaml +type: in +path: 'active' +validator: + type: equal + value: true +``` +or +```yaml +['in', 'active', ['equal', true]] +``` + +### NotifyOnRegistration: +| Parameter | Type | +|----------------------|------------------------------------| +| `url` | String | +| `auth` | [`HttpAuth`](#HttpAuth) (optional) | +| `interrupt_on_error` | bool (defaults to `true`) | + +### Path +A path is either a string or a list of strings. A path is used to get a value inside a nested dictionary/object. + +#### Examples +- `'foo'` is an existing path in `{'foo': 3}`, resulting in value `3` +- `['foo']` is an existing path in `{'foo': 3}`, resulting in value `3` +- `['foo', 'bar']` is an existing path in `{'foo': {'bar': 3}}`, resulting in value `3` + +### BasicAuth +| Parameter | Type | +|------------|--------| +| `username` | String | +| `password` | String | + +### BearerAuth +| Parameter | Type | +|-----------|--------| +| `token` | String | + +### HttpAuth +Authentication options, always optional +| Parameter | Type | +|-----------|-------------------------| +| `type` | `'basic'` \| `'bearer'` | + +Possible options: [`BasicAuth`](#BasicAuth), [`BearerAuth`](#BearerAuth), + +### Validator +A validator is any of these types: + [`Exist`](#Exist), + [`Not`](#Not), + [`Equal`](#Equal), + [`MatchesRegex`](#MatchesRegex), + [`AnyOf`](#AnyOf), + [`AllOf`](#AllOf), + [`In`](#In), + [`ListAnyOf`](#ListAnyOf), + [`ListAllOf`](#ListAllOf) + +Each validator has `type` field + +### Exist +Validator that always returns true + +#### Examples +```yaml +{'type': 'exist'} +``` +or +```yaml +['exist'] +``` + +### Not +Validator that inverses the result of the inner validator + +| Parameter | Type | +|-------------|---------------------------| +| `validator` | [`Validator`](#Validator) | + +#### Examples +```yaml +{'type': 'not', 'validator': 'exist'} +``` +or +```yaml +['not', 'exist'] +``` + +### Equal +Validator that checks for equality with a constant + +| Parameter | Type | +|-----------|-------| +| `value` | `Any` | + +#### Examples +```yaml +{'type': 'equal', 'value': 3} +``` +or +```yaml +['equal', 3] +``` + +### MatchesRegex +Validator that checks if a value is string and matches given regex + +| Parameter | Type | Description | +|--------------------------------------------|--------|-----------------------------| +| `regex` | `str` | Python regex syntax | +| `full_match` (optional, `true` by default) | `bool` | Full match or partial match | + +#### Examples +```yaml +{'type': 'regex', 'regex': 'hello.'} +``` +or +```yaml +['regex', 'hello.', false] +``` + +### AnyOf +Validator that checks if **any** of the inner validators passes + + +| Parameter | Type | +|--------------|-----------------------------------| +| `validators` | List of [`Validator`](#Validator) | + +#### Examples +```yaml +type: any_of +validators: + - ['in', 'foo', ['equal', 3]] + - ['in', 'bar' ['exist']] +``` +or +```yaml +['any_of', [['in', 'bar' ['exist']], ['in', 'foo', ['equal', 3]]]] +``` + +### AllOf +Validator that checks if **all** of the inner validators pass + +| Parameter | Type | +|--------------|-----------------------------------| +| `validators` | List of [`Validator`](#Validator) | + +#### Examples +```yaml +type: all_of +validators: + - ['exist'] + - ['in', 'foo', ['equal', 3]] +``` +or +```yaml +['all_of', [['exist'], ['in', 'foo', ['equal', 3]]]] +``` + +### In +Validator that modifies the context for the inner validator, *going inside* a dict key. +If the validated object is not a dict, or doesn't have specivied `path`, validation failes. Validator + +| Parameter | Type | +|-------------|---------------------------------------------------------------------| +| `path` | [`Path`](#Path) | +| `validator` | [`Validator`](#Validator) (optional, defaults to [`Exist`](#Exist)) | + +#### Examples +```yaml +['in', ['foo', 'bar'], ['equal', 3]] +``` + +### ListAllOf +Validator that checks if the value is list and **all** of its elements satisfy specified validator. + +| Parameter | Type | +|-------------|---------------------------| +| `validator` | [`Validator`](#Validator) | + +#### Examples +```yaml +type: list_all_of +validator: + type: regex + regex: 'ab..' +``` +or +```yaml +['list_all_of', ['regex', 'ab..']] +``` + +### ListAnyOf +Validator that checks if the value is list and if **any** of its elements satisfy specified validator. + +| Parameter | Type | +|-------------|---------------------------| +| `validator` | [`Validator`](#Validator) | + +#### Examples +```yaml +type: list_all_of +validator: + type: equal + value: 3 +``` +or +```yaml +['list_any_of', ['equal', 3]] +``` + ## Usage ### JWT Authentication @@ -110,38 +393,6 @@ Next, the client needs to use these tokens and construct a payload to the login } ``` -### Custom flow - -This is similar to jwt flow except few additinal claims are checked: -- `name` claim must be present -- `urn:messaging:matrix:localpart` claim must be equal to user name -- `urn:messaging:matrix:mxid` claim must be valid mxid with localpart matching `urn:messaging:matrix:localpart` claim and domain name matching this homeserver domain - -```jsonc -{ - "type": "com.famedly.login.token.custom", - "identifier": { - "type": "m.id.user", - "user": "d2773fdb-91b5-4e77-9367-d4bd121afc48" // localpart, same as `urn:messaging:matrix:localpart` in JWT - }, - "token": "" -} -``` - -An example of a JWT payload: -```jsonc -{ - "iss": "https://auth.example.com", - "sub": "8fd1ec9b-c054-4de0-bbd0-90d40ce9200e", - "exp": 1701432906, - "urn:messaging:matrix:mxid": "@d2773fdb-91b5-4e77-9367-d4bd121afc48:homserver.matrix.de", - "urn:messaging:matrix:localpart": "d2773fdb-91b5-4e77-9367-d4bd121afc48", - "name": "Alice Bob" -} -``` - -Additionally, when a new user is registered, a POST json request is made with `{"token": }` as its request body. The handler of the request must return any json due to some implementation details (synapse's `BaseHttpClient` poor interface) - ## Testing The tests uses twisted's testing framework trial, with the development diff --git a/pyproject.toml b/pyproject.toml index fee716f..c6d9772 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ dependencies = [ "ruff", ] [tool.hatch.envs.default.scripts] -cov = "pytest --cov-report=term-missing --cov-config=pyproject.toml --cov=synapse_token_authenticator --cov=tests" +cov = "pytest --cov-report=term-missing --cov-config=pyproject.toml --cov=synapse_token_authenticator --cov=tests {args}" format = "black ." lint = "ruff check ." diff --git a/synapse_token_authenticator/claims_validator.py b/synapse_token_authenticator/claims_validator.py new file mode 100644 index 0000000..30054f1 --- /dev/null +++ b/synapse_token_authenticator/claims_validator.py @@ -0,0 +1,179 @@ +from dataclasses import dataclass +from typing import List, Optional, Any, TypeAlias, Union +from synapse_token_authenticator.utils import get_path_in_dict +import re + +Validator: TypeAlias = Union[ + "Exist", + "Not", + "Equal", + "MatchesRegex", + "AnyOf", + "AllOf", + "In", + "ListAnyOf", + "ListAllOf", +] + + +def parse_validator(d: dict) -> Validator: + if isinstance(d, dict): + type = d.pop("type") + match type: + case "exist": + return Exist(**d) + case "not": + return Not(**d) + case "equal": + return Equal(**d) + case "regex": + return MatchesRegex(**d) + case "any_of": + return AnyOf(**d) + case "all_of": + return AllOf(**d) + case "in": + return In(**d) + case "list_any_of": + return ListAnyOf(**d) + case "list_all_of": + return ListAllOf(**d) + case t: + raise Exception(f"Unknown validator type {t}") + elif isinstance(d, list): + type = d.pop(0) + match type: + case "exist": + return Exist(*d) + case "not": + return Not(*d) + case "equal": + return Equal(*d) + case "regex": + return MatchesRegex(*d) + case "any_of": + return AnyOf(*d) + case "all_of": + return AllOf(*d) + case "in": + return In(*d) + case "list_any_of": + return ListAnyOf(*d) + case "list_all_of": + return ListAllOf(*d) + case t: + raise Exception(f"Unknown validator type {t}") + else: + raise Exception("Validator parsing failed, expected list or dict") + + +@dataclass +class Exist: + def validate(self, x: Any) -> bool: + return True + + +@dataclass +class Not: + validator: Validator + + def __post_init__(self): + self.validator = parse_validator(self.validator) + + def validate(self, x: Any) -> bool: + return not self.validator.validate(x) + + +@dataclass +class Equal: + value: Any + + def validate(self, x: Any) -> bool: + return x == self.value + + +@dataclass +class MatchesRegex: + regex: str + full_match: bool | None = True + + def __post_init__(self): + self.regex_prog = re.compile(self.regex) + + def validate(self, s: Any) -> bool: + if not isinstance(s, str): + return False + if self.full_match: + return bool(self.regex_prog.fullmatch(s)) + else: + return bool(self.regex_prog.search(s)) + + +@dataclass +class AnyOf: + validators: List[Validator] + + def __post_init__(self): + self.validators = list(map(lambda v: parse_validator(v), self.validators)) + + def validate(self, x: Any) -> bool: + return any(v.validate(x) for v in self.validators) + + +@dataclass +class AllOf: + validators: List[Validator] + + def __post_init__(self): + self.validators = list(map(lambda v: parse_validator(v), self.validators)) + + def validate(self, x: Any) -> bool: + return all(v.validate(x) for v in self.validators) + + +@dataclass +class In: + path: str | List[str] + validator: Optional[Validator] = None + + def __post_init__(self): + if not self.path: + raise Exception("Path list is empty") + if self.validator: + self.validator = parse_validator(self.validator) + + def validate(self, x: Any) -> bool: + if not isinstance(x, dict): + return False + val = get_path_in_dict(self.path, x) + return ( + (self.validator.validate(val) if self.validator else True) if val else False + ) + + +@dataclass +class ListAllOf: + validator: Validator + + def __post_init__(self): + if self.validator: + self.validator = parse_validator(self.validator) + + def validate(self, list_: Any) -> bool: + if not isinstance(list_, list): + return False + return all(self.validator.validate(x) for x in list_) + + +@dataclass +class ListAnyOf: + validator: Validator + + def __post_init__(self): + if self.validator: + self.validator = parse_validator(self.validator) + + def validate(self, list_: Any) -> bool: + if not isinstance(list_, list): + return False + return any(self.validator.validate(x) for x in list_) diff --git a/synapse_token_authenticator/config.py b/synapse_token_authenticator/config.py index cb54860..8dd108b 100644 --- a/synapse_token_authenticator/config.py +++ b/synapse_token_authenticator/config.py @@ -1,4 +1,13 @@ import os +from dataclasses import dataclass, field +from typing import List, Literal, Union, TypeAlias, Any +from jwcrypto.jwk import JWK, JWKSet +from synapse_token_authenticator.claims_validator import ( + parse_validator, + Validator, + Exist, +) +from synapse_token_authenticator.utils import bearer_auth, basic_auth class TokenAuthenticatorConfig: @@ -46,24 +55,99 @@ def __init__(self, other: dict): self.oidc = OIDCConfig(oidc) - if custom_flow := other.get("custom_flow"): - - class CustomFlowConfig: - def __init__(self, other: dict): - self.secret: str | None = other.get("secret") - self.keyfile: str | None = other.get("keyfile") - - self.algorithm: str = other.get("algorithm", "RS256") - self.require_expiry: bool = other.get("require_expiry", True) - self.notify_on_registration_uri: str = other.get( - "notify_on_registration_uri" - ) - self.notification_access_token: str | None = other.get( - "notification_access_token", None - ) - - self.custom_flow = CustomFlowConfig(custom_flow) - verify_jwt_based_cfg(self.custom_flow) + if config := other.get("oauth"): + + Path: TypeAlias = Union[str, List[str]] + + @dataclass + class JwtValidationConfig: + validator: Validator = field(default_factory=Exist) + require_expiry: bool = False + localpart_path: Path | None = None + user_id_path: Path | None = None + fq_uid_path: Path | None = None + displayname_path: Path | None = None + required_scopes: str | List[str] | None = None + jwk_set: JWKSet | JWK | None = None + jwk_file: str | None = None + + def __post_init__(self): + if not isinstance(self.validator, Exist): + self.validator = parse_validator(self.validator) + + if self.jwk_set and ("keys" in self.jwk_set): + self.jwk_set = JWKSet(**self.jwk_set) + elif self.jwk_set: + self.jwk_set = JWK(**self.jwk_set) + elif self.jwk_file: + with open(self.jwk_file) as f: + self.jwk_set = JWK.from_pem(f.read()) + else: + raise Exception("No JWK") + + @dataclass + class IntrospectionValidationConfig: + endpoint: str + validator: Validator = field(default_factory=Exist) + auth: HttpAuth = field(default_factory=NoAuth) + localpart_path: Path | None = None + user_id_path: Path | None = None + fq_uid_path: Path | None = None + displayname_path: Path | None = None + required_scopes: str | List[str] | None = None + + def __post_init__(self): + if not isinstance(self.validator, Exist): + self.validator = parse_validator(self.validator) + + if not isinstance(self.auth, NoAuth): + self.auth = parse_auth(self.auth) + + @dataclass + class NotifyOnRegistration: + url: str + auth: HttpAuth = field(default_factory=NoAuth) + interrupt_on_error: bool = True + + def __post_init__(self): + if not isinstance(self.auth, NoAuth): + self.auth = parse_auth(self.auth) + + @dataclass + class OAuthConfig: + jwt_validation: JwtValidationConfig | None = None + introspection_validation: IntrospectionValidationConfig | None = None + username_type: Literal["fq_uid", "localpart", "user_id"] | None = None + notify_on_registration: NotifyOnRegistration | None = None + expose_metadata_resource: Any = None + registration_enabled: bool = False + + def __post_init__(self): + if self.notify_on_registration: + self.notify_on_registration = NotifyOnRegistration( + **self.notify_on_registration + ) + if self.jwt_validation: + self.jwt_validation = JwtValidationConfig( + **(self.jwt_validation) + ) + if self.introspection_validation: + self.introspection_validation = IntrospectionValidationConfig( + **self.introspection_validation + ) + if not (self.jwt_validation or self.introspection_validation): + raise Exception( + "Neither jwt_validation nor introspection_validation was specified" + ) + if self.username_type not in [ + "fq_uid", + "localpart", + "user_id", + None, + ]: + raise Exception(f"Unknown username_type {self.username_type}") + + self.oauth = OAuthConfig(**config) def verify_jwt_based_cfg(cfg): @@ -88,3 +172,56 @@ def verify_jwt_based_cfg(cfg): "EdDSA", ]: raise Exception(f"Unknown algorithm: '{cfg.algorithm}'") + + +@dataclass +class NoAuth: + def header_map(self): + return {} + + +@dataclass +class BasicAuth: + username: str + password: str + + def header_map(self): + return basic_auth(self.username, self.password) + + +@dataclass +class BearerAuth: + token: str + + def header_map(self): + return bearer_auth(self.token) + + +HttpAuth: TypeAlias = Union[BasicAuth, BearerAuth, NoAuth] + + +def parse_auth(d: dict) -> HttpAuth: + if isinstance(d, dict): + type = d.pop("type") + match type: + case None: + return NoAuth() + case "basic": + return BasicAuth(**d) + case "bearer": + return BearerAuth(**d) + case t: + raise Exception(f"Unknown HttpAuth type {t}") + elif isinstance(d, list): + type = d.pop(0) + match type: + case None: + return NoAuth() + case "basic": + return BasicAuth(*d) + case "bearer": + return BearerAuth(*d) + case t: + raise Exception(f"Unknown HttpAuth type {t}") + else: + raise Exception("HttpAuth parsing failed, expected list or dict") diff --git a/synapse_token_authenticator/token_authenticator.py b/synapse_token_authenticator/token_authenticator.py index d59843d..5b416b2 100644 --- a/synapse_token_authenticator/token_authenticator.py +++ b/synapse_token_authenticator/token_authenticator.py @@ -29,7 +29,15 @@ from twisted.web import resource from synapse_token_authenticator.config import TokenAuthenticatorConfig -from synapse_token_authenticator.utils import get_oidp_metadata, basic_auth +from synapse_token_authenticator.utils import ( + get_oidp_metadata, + basic_auth, + validate_scopes, + all_list_elems_are_equal_return_the_elem, + get_path_in_dict, + if_not_none, + MetadataResource, +) logger = logging.getLogger(__name__) @@ -67,20 +75,15 @@ def __init__(self, config: dict, account_handler: ModuleApi): self.LoginMetadataResource(oidc), ) - if (custom_flow := getattr(self.config, "custom_flow", None)) is not None: - if custom_flow.secret: - k = { - "k": base64.urlsafe_b64encode( - custom_flow.secret.encode("utf-8") - ).decode("utf-8"), - "kty": "oct", - } - self.custom_flow_key = jwk.JWK(**k) - else: - with open(custom_flow.keyfile) as f: - self.key = jwk.JWK.from_pem(f.read()) - auth_checkers[("com.famedly.login.token.custom", ("token",))] = ( - self.check_custom_flow + if (cfg := getattr(self.config, "oauth", None)) is not None: + if cfg.expose_metadata_resource: + self.api.register_web_resource( + f"/_famedly/login/{cfg.expose_metadata_resource.name}", + MetadataResource(cfg.expose_metadata_resource), + ) + + auth_checkers[("com.famedly.login.token.oauth", ("token",))] = ( + self.check_oauth ) self.api.register_password_auth_provider_callbacks(auth_checkers=auth_checkers) @@ -123,7 +126,7 @@ async def check_jwt_auth( return None token = login_dict["token"] - check_claims = {} + check_claims: dict = {} if self.config.jwt.require_expiry: check_claims["exp"] = None try: @@ -167,7 +170,7 @@ async def check_jwt_auth( if user_id.domain != self.api.server_name: logger.info("user_id isn't for our homeserver") - return + return None if user_id_str != token_user_id_str: logger.info("Non-matching user") @@ -287,7 +290,7 @@ async def check_oidc_auth( logger.info("All done and valid, logging in!") return (user_id_str, None) - async def check_custom_flow( + async def check_oauth( self, username: str, login_type: str, login_dict: "synapse.module_api.JsonDict" ) -> Optional[ tuple[ @@ -295,8 +298,9 @@ async def check_custom_flow( Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]], ] ]: + config = self.config.oauth logger.info("Receiving auth request") - if login_type != "com.famedly.login.token.custom": + if login_type != "com.famedly.login.token.oauth": logger.info("Wrong login type") return None if "token" not in login_dict: @@ -306,70 +310,193 @@ async def check_custom_flow( client = self.api._hs.get_proxied_http_client() - check_claims = {} + jwt_claims = {} + + if config.jwt_validation is not None: + check_claims: dict = {} + if config.jwt_validation.require_expiry: + check_claims["exp"] = None + try: + token = jwt.JWT( + jwt=token, + key=config.jwt_validation.jwk_set, + check_claims=check_claims, + ) + except ValueError as e: + logger.info("Unrecognized token %s", e) + return None + except JWException as e: + logger.info("Invalid token %s", e) + return None + + jwt_claims = json_decode(token.claims) - user_id_str = self.api.get_qualified_user_id(username) - check_claims["urn:messaging:matrix:localpart"] = username - check_claims["urn:messaging:matrix:mxid"] = user_id_str + if config.jwt_validation.required_scopes: + provided_scope = jwt_claims.get("scope") + if not isinstance(provided_scope, str): + logger.info("Token missing scope claim") + return None + + if not validate_scopes( + config.jwt_validation.required_scopes, provided_scope + ): + logger.info("Token scope validation failed") + return None + + if not config.jwt_validation.validator.validate(jwt_claims): + logger.info("Token claims validation failed") + return None + + introspection_claims = {} + + if config.introspection_validation is not None: + try: + introspection_claims = await client.post_urlencoded_get_json( + config.introspection_validation.endpoint, + {"token": token}, + headers=config.introspection_validation.auth.header_map(), + ) + except HttpResponseException as e: + if e.code == 401: + logger.info("Introspection auth failed") + return None + else: + raise e + + if config.introspection_validation.required_scopes: + provided_scope = introspection_claims.get("scope") + if not isinstance(provided_scope, str): + logger.info("Token missing scope claim") + return None + + if not validate_scopes( + config.introspection_validation.required_scopes, provided_scope + ): + logger.info("Token scope validation failed") + return None + + if not config.introspection_validation.validator.validate( + introspection_claims + ): + logger.info("Introspection response validation failed for a token") + return None + + # getting localpart and fully qualified user_id, validate all sources for equality + + def get_from_set(set_): + return if_not_none(lambda path: get_path_in_dict(path, set_)) + + username_type = config.username_type - if self.config.custom_flow.require_expiry: - check_claims["exp"] = None try: - token = jwt.JWT( - jwt=token, - key=self.custom_flow_key, - check_claims=check_claims, - algs=[self.config.custom_flow.algorithm], + get_localpart_mb = if_not_none(lambda x: x.localpart_path) + + localpart = all_list_elems_are_equal_return_the_elem( + [ + get_from_set(jwt_claims)(get_localpart_mb(config.jwt_validation)), + get_from_set(introspection_claims)( + get_localpart_mb(config.introspection_validation) + ), + username if username_type == "localpart" else None, + ( + UserID.from_string(username).localpart + if username_type == "fq_uid" + else None + ), + ( + UserID.from_string( + self.api.get_qualified_user_id(username) + ).localpart + if username_type == "user_id" + else None + ), + ] ) - except ValueError as e: - logger.info("Unrecognized token %s", e) - return None - except JWException as e: - logger.info("Invalid token %s", e) - return None - payload = json_decode(token.claims) - sub = payload["sub"] - if not isinstance(sub, str): - logger.info("user_id isn't a string") - return None - if "name" not in payload: - logger.info("No name claim in payload") + get_fq_uid_mb = if_not_none(lambda x: x.fq_uid_path) + + fully_qualified_uid = all_list_elems_are_equal_return_the_elem( + [ + get_from_set(jwt_claims)(get_fq_uid_mb(config.jwt_validation)), + get_from_set(introspection_claims)( + get_fq_uid_mb(config.introspection_validation) + ), + username if username_type == "fq_uid" else None, + ( + self.api.get_qualified_user_id(username) + if username_type == "user_id" + else None + ), + ( + self.api.get_qualified_user_id(username) + if username_type == "localpart" + else None + ), + ] + ) + except Exception as e: + logger.info(e) return None - user_id = UserID.from_string(user_id_str) - user_exists = await self.api.check_user_exists(user_id_str) + if localpart is None and fully_qualified_uid is None: + logger.info("No user id was provided") + return None - if not user_exists: - logger.info("User doesn't exist, registering them...") - await self.api.register_user(user_id.localpart) + if localpart is None: + localpart = UserID.from_string(fully_qualified_uid).localpart - # notification_access_token - headers = {} - if self.config.custom_flow.notification_access_token is not None: - headers = { - b"Authorization": [ - b"Bearer " + self.config.custom_flow.notification_access_token - ] - } + if fully_qualified_uid is None: + fully_qualified_uid = self.api.get_qualified_user_id(localpart) - await client.post_json_get_json( - self.config.custom_flow.notify_on_registration_uri, - {"token": login_dict["token"]}, - headers=headers, + try: + get_displayname_mb = if_not_none(lambda x: x.displayname_path) + displayname = all_list_elems_are_equal_return_the_elem( + [ + get_from_set(jwt_claims)(get_displayname_mb(config.jwt_validation)), + get_from_set(introspection_claims)( + get_displayname_mb(config.introspection_validation) + ), + ] ) + except Exception as e: + logger.info(e) + return None - logger.info("Registered user %s (%s)", user_id, payload["name"]) + user_exists = await self.api.check_user_exists(fully_qualified_uid) - await self.api._hs.get_profile_handler().set_displayname( - requester=synapse.types.create_requester(user_id), - target_user=user_id, - by_admin=True, - new_displayname=payload["name"], - ) + if not user_exists and config.registration_enabled: + logger.info("User doesn't exist, registering them...") + await self.api.register_user(localpart) + + if config.notify_on_registration: + try: + await client.post_json_get_json( + config.notify_on_registration.url, + { + "localpart": localpart, + "fully_qualified_uid": fully_qualified_uid, + "displayname": displayname, + }, + headers=config.notify_on_registration.auth.header_map(), + ) + except HttpResponseException as e: + logger.info(e) + if config.notify_on_registration.interrupt_on_error: + return None + + logger.info("Registered user %s (%s)", localpart, displayname) + + if displayname: + user_id = UserID.from_string(fully_qualified_uid) + await self.api._hs.get_profile_handler().set_displayname( + requester=synapse.types.create_requester(user_id), + target_user=user_id, + by_admin=True, + new_displayname=displayname, + ) logger.info("All done and valid, logging in!") - return (user_id_str, None) + return (fully_qualified_uid, None) @staticmethod def parse_config(config: dict): diff --git a/synapse_token_authenticator/utils.py b/synapse_token_authenticator/utils.py index 1dbfcc8..71c4bc2 100644 --- a/synapse_token_authenticator/utils.py +++ b/synapse_token_authenticator/utils.py @@ -1,5 +1,8 @@ from base64 import b64encode from urllib.parse import urljoin +from typing import List, Optional, Any +import json +from twisted.web import resource class OpenIDProviderMetadata: @@ -23,8 +26,54 @@ async def get_oidp_metadata(issuer, client) -> OpenIDProviderMetadata: return OpenIDProviderMetadata(issuer, config) -def basic_auth(username: str, password: str) -> dict[bytes, bytes]: +def basic_auth(username: str, password: str) -> dict[bytes, list[bytes]]: authorization = b64encode( b":".join((username.encode("utf8"), password.encode("utf8"))) ) return {b"Authorization": [b"Basic " + authorization]} + + +def bearer_auth(token: str) -> dict[bytes, list[bytes]]: + return {b"Authorization": [b"Bearer " + token.encode("utf8")]} + + +def if_not_none(f): + return lambda x: (f(x) if x is not None else None) + + +def all_list_elems_are_equal_return_the_elem(list_): + filtered_list = list(filter(lambda x: x is not None, list_)) + if len(filtered_list) == 0: + return None + val = filtered_list[0] + if not all(i == val for i in filtered_list): + raise Exception(f"Elements in {filtered_list} are not equal") + return val + + +def get_path_in_dict(path: str | List[str], d: Any) -> Optional[Any]: + if isinstance(path, str): + path = [path] + r = d + for p in path: + if not isinstance(r, dict): + return None + r = r.get(p) + return r + + +def validate_scopes(required_scopes: str | List[str], provided_scopes: str) -> bool: + if isinstance(required_scopes, str): + required_scopes = required_scopes.split() + provided_scopes_list = provided_scopes.split() + return all(scope in provided_scopes_list for scope in required_scopes) + + +class MetadataResource(resource.Resource): + def __init__(self, resource: object): + self.resource = resource + + def render_GET(self, request): + request.setHeader(b"content-type", b"application/json") + request.setHeader(b"access-control-allow-origin", b"*") + return json.dumps(self.resource).encode("utf-8") diff --git a/tests/__init__.py b/tests/__init__.py index d3b7c60..233bf33 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -16,6 +16,7 @@ import base64 import logging import time +import json from typing import Any from unittest.mock import AsyncMock, Mock, patch from urllib.parse import parse_qs @@ -109,11 +110,6 @@ def default_config(self) -> dict[str, Any]: "module": "synapse_token_authenticator.TokenAuthenticator", "config": { "jwt": {"secret": "foxies"}, - "custom_flow": { - "algorithm": "HS512", - "secret": "foxies", - "notify_on_registration_uri": "http://example.test", - }, "oidc": { "issuer": "https://idp.example.test", "client_id": "1111@project", @@ -121,20 +117,35 @@ def default_config(self) -> dict[str, Any]: "project_id": "231872387283", "organization_id": "2283783782778", }, + "oauth": { + "jwt_validation": { + "validator": ["exist"], + "require_expiry": True, + "localpart_path": "urn:messaging:matrix:localpart", + "fq_uid_path": "urn:messaging:matrix:mxid", + "required_scopes": "foo bar", + "jwk_set": get_jwk(), + }, + "username_type": "user_id", + "registration_enabled": True, + }, }, } ] return conf +def get_jwk(secret="foxies"): + return jwk.JWK( + k=base64.urlsafe_b64encode(secret.encode("utf-8")).decode("utf-8"), + kty="oct", + ) + + def get_jwt_token( username, exp_in=None, secret="foxies", algorithm="HS512", admin=None, claims=None ): - k = { - "k": base64.urlsafe_b64encode(secret.encode("utf-8")).decode("utf-8"), - "kty": "oct", - } - key = jwk.JWK(**k) + key = get_jwk(secret) if claims is None: claims = {} claims["sub"] = username @@ -206,3 +217,43 @@ def mock_idp_post(uri, data_raw, **kwargs): ) else: return Response(code=404) + + +def mock_for_oauth(method, uri, data=None, **extrargs): + match (method, uri): + case ("POST", "http://idp.test/introspect"): + data = parse_qs(data.decode()) + match data: + case {"token": token}: + pass + case _: + logger.error(f"Bad introspect request: {data}") + return Response(code=400) + return Response.json( + payload={ + "active": True, + "localpart": "alice", + "scope": "bar foo", + "name": "Alice", + } + ) + case ("POST", "http://iop.test/notify"): + data = json.loads(data) + match data: + case { + "localpart": localpart, + "fully_qualified_uid": fully_qualified_uid, + "displayname": displayname, + }: + assert data == { + "localpart": "alice", + "fully_qualified_uid": "@alice:example.test", + "displayname": "Alice", + } + case _: + logger.error(f"Bad notify request: {data}") + return Response(code=400) + return Response.json(payload=None) + case (m, u): + logger.error(f"Unknown request {m} {u}") + return Response(code=404) diff --git a/tests/test_custom.py b/tests/test_custom.py deleted file mode 100644 index a2d8866..0000000 --- a/tests/test_custom.py +++ /dev/null @@ -1,142 +0,0 @@ -# Copyright (C) 2024 Famedly -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . - -from unittest import mock - -import tests.unittest as synapsetest - -from . import ModuleApiTestCase, get_jwt_token - -default_claims = { - "urn:messaging:matrix:localpart": "alice", - "urn:messaging:matrix:mxid": "@alice:example.test", - "name": "Alice", -} - - -class CustomFlowTests(ModuleApiTestCase): - async def test_wrong_login_type(self): - token = get_jwt_token("aliceid", claims=default_claims) - result = await self.hs.mockmod.check_custom_flow( - "alice", "com.famedly.login.token", {"token": token} - ) - self.assertEqual(result, None) - - async def test_missing_token(self): - result = await self.hs.mockmod.check_custom_flow( - "alice", "com.famedly.login.token.custom", {} - ) - self.assertEqual(result, None) - - async def test_invalid_token(self): - result = await self.hs.mockmod.check_custom_flow( - "alice", "com.famedly.login.token.custom", {"token": "invalid"} - ) - self.assertEqual(result, None) - - async def test_token_wrong_secret(self): - token = get_jwt_token("aliceid", secret="wrong secret", claims=default_claims) - result = await self.hs.mockmod.check_custom_flow( - "alice", "com.famedly.login.token.custom", {"token": token} - ) - self.assertEqual(result, None) - - async def test_token_wrong_alg(self): - token = get_jwt_token("aliceid", algorithm="HS256", claims=default_claims) - result = await self.hs.mockmod.check_custom_flow( - "alice", "com.famedly.login.token.custom", {"token": token} - ) - self.assertEqual(result, None) - - async def test_token_expired(self): - token = get_jwt_token("aliceid", exp_in=-60, claims=default_claims) - result = await self.hs.mockmod.check_custom_flow( - "alice", "com.famedly.login.token.custom", {"token": token} - ) - self.assertEqual(result, None) - - async def test_token_no_expiry(self): - token = get_jwt_token("aliceid", exp_in=-1, claims=default_claims) - result = await self.hs.mockmod.check_custom_flow( - "alice", "com.famedly.login.token.custom", {"token": token} - ) - self.assertEqual(result, None) - - async def test_token_bad_localpart(self): - claims = default_claims.copy() - claims["urn:messaging:matrix:localpart"] = "bobby" - token = get_jwt_token("aliceid", claims=claims) - result = await self.hs.mockmod.check_custom_flow( - "alice", "com.famedly.login.token.custom", {"token": token} - ) - self.assertEqual(result, None) - - async def test_token_bad_mxid(self): - claims = default_claims.copy() - claims["urn:messaging:matrix:mxid"] = "@bobby:example.test" - token = get_jwt_token("aliceid", claims=claims) - result = await self.hs.mockmod.check_custom_flow( - "alice", "com.famedly.login.token.custom", {"token": token} - ) - self.assertEqual(result, None) - - async def test_token_claims_username_mismatch(self): - token = get_jwt_token("aliceid", claims=default_claims) - result = await self.hs.mockmod.check_custom_flow( - "bobby", "com.famedly.login.token.custom", {"token": token} - ) - self.assertEqual(result, None) - - @synapsetest.override_config( - { - "modules": [ - { - "module": "synapse_token_authenticator.TokenAuthenticator", - "config": { - "custom_flow": { - "secret": "foxies", - "require_expiry": False, - "algorithm": "HS512", - "notify_on_registration_uri": "http://example.test", - } - }, - } - ] - } - ) - async def test_token_no_expiry_with_config(self, *args): - token = get_jwt_token("aliceid", exp_in=-1, claims=default_claims) - result = await self.hs.mockmod.check_custom_flow( - "alice", "com.famedly.login.token.custom", {"token": token} - ) - self.assertEqual(result[0], "@alice:example.test") - - async def test_valid_login(self): - token = get_jwt_token("aliceid", claims=default_claims) - result = await self.hs.mockmod.check_custom_flow( - "alice", "com.famedly.login.token.custom", {"token": token} - ) - self.assertEqual(result[0], "@alice:example.test") - - @mock.patch("synapse.module_api.ModuleApi.check_user_exists", return_value=False) - @mock.patch( - "synapse.http.client.SimpleHttpClient.post_json_get_json", return_value={} - ) - async def test_valid_login_register(self, *args): - token = get_jwt_token("aliceid", claims=default_claims) - result = await self.hs.mockmod.check_custom_flow( - "alice", "com.famedly.login.token.custom", {"token": token} - ) - self.assertEqual(result[0], "@alice:example.test") diff --git a/tests/test_oauth.py b/tests/test_oauth.py new file mode 100644 index 0000000..e387311 --- /dev/null +++ b/tests/test_oauth.py @@ -0,0 +1,236 @@ +# Copyright (C) 2024 Famedly +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from unittest import mock + +import tests.unittest as synapsetest + +from . import ModuleApiTestCase, get_jwt_token, get_jwk, mock_for_oauth +from copy import deepcopy + +default_claims = { + "urn:messaging:matrix:localpart": "alice", + "urn:messaging:matrix:mxid": "@alice:example.test", + "name": "Alice", + "scope": "bar foo", +} + + +class CustomFlowTests(ModuleApiTestCase): + async def test_wrong_login_type(self): + token = get_jwt_token("aliceid", claims=default_claims) + result = await self.hs.mockmod.check_oauth( + "alice", "com.famedly.login.token", {"token": token} + ) + self.assertEqual(result, None) + + async def test_missing_token(self): + result = await self.hs.mockmod.check_oauth( + "alice", "com.famedly.login.token.oauth", {} + ) + self.assertEqual(result, None) + + async def test_invalid_token(self): + result = await self.hs.mockmod.check_oauth( + "alice", "com.famedly.login.token.oauth", {"token": "invalid"} + ) + self.assertEqual(result, None) + + async def test_token_wrong_secret(self): + token = get_jwt_token("aliceid", secret="wrong secret", claims=default_claims) + result = await self.hs.mockmod.check_oauth( + "alice", "com.famedly.login.token.oauth", {"token": token} + ) + self.assertEqual(result, None) + + async def test_token_expired(self): + token = get_jwt_token("aliceid", exp_in=-60, claims=default_claims) + result = await self.hs.mockmod.check_oauth( + "alice", "com.famedly.login.token.oauth", {"token": token} + ) + self.assertEqual(result, None) + + async def test_token_no_expiry(self): + token = get_jwt_token("aliceid", exp_in=-1, claims=default_claims) + result = await self.hs.mockmod.check_oauth( + "alice", "com.famedly.login.token.oauth", {"token": token} + ) + self.assertEqual(result, None) + + async def test_token_bad_localpart(self): + claims = default_claims.copy() + claims["urn:messaging:matrix:localpart"] = "bobby" + token = get_jwt_token("aliceid", claims=claims) + result = await self.hs.mockmod.check_oauth( + "alice", "com.famedly.login.token.oauth", {"token": token} + ) + self.assertEqual(result, None) + + async def test_token_bad_mxid(self): + claims = default_claims.copy() + claims["urn:messaging:matrix:mxid"] = "@bobby:example.test" + token = get_jwt_token("aliceid", claims=claims) + result = await self.hs.mockmod.check_oauth( + "alice", "com.famedly.login.token.oauth", {"token": token} + ) + self.assertEqual(result, None) + + async def test_token_claims_username_mismatch(self): + token = get_jwt_token("aliceid", claims=default_claims) + result = await self.hs.mockmod.check_oauth( + "bobby", "com.famedly.login.token.oauth", {"token": token} + ) + self.assertEqual(result, None) + + @synapsetest.override_config( + { + "modules": [ + { + "module": "synapse_token_authenticator.TokenAuthenticator", + "config": { + "oauth": { + "jwt_validation": { + "validator": ["exist"], + "require_expiry": False, + "jwk_set": get_jwk(), + }, + "username_type": "user_id", + }, + }, + } + ] + } + ) + async def test_token_no_expiry_with_config(self, *args): + token = get_jwt_token("aliceid", exp_in=-1, claims=default_claims) + result = await self.hs.mockmod.check_oauth( + "alice", "com.famedly.login.token.oauth", {"token": token} + ) + self.assertEqual(result[0], "@alice:example.test") + + async def test_valid_login(self): + token = get_jwt_token("aliceid", claims=default_claims) + result = await self.hs.mockmod.check_oauth( + "alice", "com.famedly.login.token.oauth", {"token": token} + ) + self.assertEqual(result[0], "@alice:example.test") + + @mock.patch("synapse.module_api.ModuleApi.check_user_exists", return_value=False) + @mock.patch( + "synapse.http.client.SimpleHttpClient.post_json_get_json", return_value={} + ) + async def test_valid_login_register(self, *args): + token = get_jwt_token("aliceid", claims=default_claims) + result = await self.hs.mockmod.check_oauth( + "alice", "com.famedly.login.token.oauth", {"token": token} + ) + self.assertEqual(result[0], "@alice:example.test") + + async def test_invalid_scope(self): + claims = default_claims.copy() + claims["scope"] = "foo" + token = get_jwt_token("aliceid", claims=claims) + result = await self.hs.mockmod.check_oauth( + "alice", "com.famedly.login.token.oauth", {"token": token} + ) + self.assertEqual(result, None) + + config_for_introspection = { + "modules": [ + { + "module": "synapse_token_authenticator.TokenAuthenticator", + "config": { + "oauth": { + "introspection_validation": { + "endpoint": "http://idp.test/introspect", + "validator": ["in", "active", ["equal", True]], + "localpart_path": "localpart", + "displayname_path": "name", + "required_scopes": "foo bar", + }, + "username_type": "user_id", + "notify_on_registration": {"url": "http://iop.test/notify"}, + "registration_enabled": True, + }, + }, + } + ] + } + + @synapsetest.override_config(config_for_introspection) + @mock.patch( + "synapse.http.client.SimpleHttpClient.request", side_effect=mock_for_oauth + ) + @mock.patch("synapse.module_api.ModuleApi.check_user_exists", return_value=False) + async def test_valid_login_introspection(self, *args): + token = get_jwt_token("aliceid", claims=default_claims) + result = await self.hs.mockmod.check_oauth( + "alice", "com.famedly.login.token.oauth", {"token": token} + ) + self.assertEqual(result[0], "@alice:example.test") + + config_for_introspection_bad_notify_url = deepcopy(config_for_introspection) + config_for_introspection_bad_notify_url["modules"][0]["config"]["oauth"][ + "notify_on_registration" + ]["url"] = "http://bad-iop.test/notify" + + @synapsetest.override_config(config_for_introspection_bad_notify_url) + @mock.patch( + "synapse.http.client.SimpleHttpClient.request", side_effect=mock_for_oauth + ) + @mock.patch("synapse.module_api.ModuleApi.check_user_exists", return_value=False) + async def test_login_introspection_notify_fails(self, *args): + token = get_jwt_token("aliceid", claims=default_claims) + result = await self.hs.mockmod.check_oauth( + "alice", "com.famedly.login.token.oauth", {"token": token} + ) + self.assertEqual(result, None) + + config_for_introspection_bad_notify_url_but_ok = deepcopy( + config_for_introspection_bad_notify_url + ) + config_for_introspection_bad_notify_url_but_ok["modules"][0]["config"]["oauth"][ + "notify_on_registration" + ]["interrupt_on_error"] = False + + @synapsetest.override_config(config_for_introspection_bad_notify_url_but_ok) + @mock.patch( + "synapse.http.client.SimpleHttpClient.request", side_effect=mock_for_oauth + ) + @mock.patch("synapse.module_api.ModuleApi.check_user_exists", return_value=False) + async def test_login_introspection_notify_fails_but_ok(self, *args): + token = get_jwt_token("aliceid", claims=default_claims) + result = await self.hs.mockmod.check_oauth( + "alice", "com.famedly.login.token.oauth", {"token": token} + ) + self.assertEqual(result[0], "@alice:example.test") + + config_for_introspection_more_required_scopes = deepcopy(config_for_introspection) + config_for_introspection_more_required_scopes["modules"][0]["config"]["oauth"][ + "introspection_validation" + ]["required_scopes"] = ["foo", "bar", "baz"] + + @synapsetest.override_config(config_for_introspection_more_required_scopes) + @mock.patch( + "synapse.http.client.SimpleHttpClient.request", side_effect=mock_for_oauth + ) + async def test_login_introspection_invalid_scope(self, *args): + claims = default_claims.copy() + claims["scope"] = "foo" + token = get_jwt_token("aliceid", claims=claims) + result = await self.hs.mockmod.check_oauth( + "alice", "com.famedly.login.token.oauth", {"token": token} + ) + self.assertEqual(result, None) diff --git a/tests/test_sta_utils.py b/tests/test_sta_utils.py new file mode 100644 index 0000000..1aaefa8 --- /dev/null +++ b/tests/test_sta_utils.py @@ -0,0 +1,39 @@ +from synapse_token_authenticator.utils import get_path_in_dict, validate_scopes, if_not_none, all_list_elems_are_equal_return_the_elem + +def test_get_path_in_dict(): + assert get_path_in_dict('foo', {'foo': 3}) == 3 + assert get_path_in_dict('foo', {'loo': 3}) == None + assert get_path_in_dict('foo', [3, 4]) == None + assert get_path_in_dict('foo', {'foo': None}) == None + assert get_path_in_dict('foo', {'foo': False}) == False + assert get_path_in_dict(['foo'], {'foo': 3}) == 3 + assert get_path_in_dict(['foo', 'loo'], {'foo': {'loo': 3}}) == 3 + assert get_path_in_dict(['foo', 'loo', 'boo'], {'foo': {'loo': {'boo': 3}}}) == 3 + assert get_path_in_dict(['foo', 'loo'], {'foo': {'loo': {'boo': 3}}}) == {'boo': 3} + assert get_path_in_dict([], {'foo': 3}) == {'foo': 3} + assert get_path_in_dict(['foo', 'loo'], {'foo': {'boo': 3}}) == None + +def test_validate_scopes(): + assert validate_scopes("foo boo", "boo foo") + assert validate_scopes(["foo", "boo"], "boo foo") + assert not validate_scopes("foo boo", "foo") + assert not validate_scopes(["foo", "boo"], "foo") + assert validate_scopes("foo boo", "boo foo loo") + +def test_if_not_none(): + assert if_not_none(lambda x: x + 1)(3) == 4 + assert if_not_none(lambda x: x + 1)(None) == None + +def test_all_list_elems_are_equal_return_the_elem(): + assert all_list_elems_are_equal_return_the_elem([None, None]) == None + assert all_list_elems_are_equal_return_the_elem([]) == None + assert all_list_elems_are_equal_return_the_elem([3, None]) == 3 + assert all_list_elems_are_equal_return_the_elem([None, 3]) == 3 + assert all_list_elems_are_equal_return_the_elem([3, 3]) == 3 + assert all_list_elems_are_equal_return_the_elem([3]) == 3 + try: + all_list_elems_are_equal_return_the_elem([3, 4]) + assert False + except: + assert True + diff --git a/tests/test_validators.py b/tests/test_validators.py new file mode 100644 index 0000000..745f3d8 --- /dev/null +++ b/tests/test_validators.py @@ -0,0 +1,133 @@ +from synapse_token_authenticator.claims_validator import parse_validator + + +def test_validator_exists(): + assert parse_validator(["exist"]).validate(None) + + +def test_validator_in(): + assert parse_validator(["in", "foo"]).validate({"foo": 3}) + assert not parse_validator(["in", "foo"]).validate({"loo": 3}) + assert parse_validator(["in", "foo", ["equal", 3]]).validate({"foo": 3}) + assert not parse_validator(["in", "foo", ["equal", 3]]).validate({"foo": 4}) + + +def test_validator_not(): + assert not parse_validator(["not", ["in", "foo"]]).validate({"foo": 3}) + assert parse_validator(["not", ["in", "foo"]]).validate({"loo": 3}) + assert not parse_validator(["not", ["exist"]]).validate(None) + + +def test_validator_equal(): + assert parse_validator(["equal", 3]).validate(3) + assert not parse_validator(["equal", 3]).validate(4) + assert parse_validator(["equal", {"hi": 3}]).validate({"hi": 3}) + assert not parse_validator(["equal", {"hi": 3}]).validate({"hi": 4}) + + +def test_validator_regex(): + txt = "The rain in Spain" + regexp = "The.*Spain" + assert parse_validator(["regex", regexp]).validate(txt) + assert parse_validator(["regex", regexp, False]).validate("smth" + txt + "smth") + assert not parse_validator(["regex", regexp]).validate("bad string") + + +def test_validator_all_of(): + assert parse_validator(["all_of", [["in", "foo"], ["in", "loo"]]]).validate( + {"foo": 3, "loo": 4} + ) + assert not parse_validator(["all_of", [["in", "foo"], ["in", "loo"]]]).validate( + {"foo": 3, "boo": 4} + ) + assert parse_validator(["all_of", []]).validate([]) + + +def test_validator_any_of(): + assert parse_validator(["any_of", [["in", "foo"], ["in", "loo"]]]).validate( + {"foo": 3, "loo": 4} + ) + assert parse_validator(["any_of", [["in", "foo"], ["in", "loo"]]]).validate( + {"foo": 3} + ) + assert not parse_validator(["any_of", [["in", "foo"], ["in", "loo"]]]).validate( + {"boo": 3} + ) + assert not parse_validator(["any_of", []]).validate({}) + + +def test_validator_list_all_of(): + assert parse_validator(["list_all_of", ["in", "foo"]]).validate( + [{"foo": 3}, {"foo": 4}] + ) + assert parse_validator(["list_all_of", ["in", "foo"]]).validate([]) + assert not parse_validator(["list_all_of", ["in", "foo"]]).validate( + [{"foo": 3}, {"loo": 4}] + ) + + +def test_validator_list_all_of(): + assert parse_validator(["list_any_of", ["in", "foo"]]).validate( + [{"foo": 3}, {"foo": 4}] + ) + assert not parse_validator(["list_any_of", ["in", "foo"]]).validate([]) + assert parse_validator(["list_any_of", ["in", "foo"]]).validate( + [{"foo": 3}, {"loo": 4}] + ) + + +def test_validator_full(): + required_claims = { + "type": "all_of", + "validators": [ + { + "type": "in", + "path": "foo", + "validator": { + "type": "regex", + "regex": "hell", + "full_match": False, + }, + }, + { + "type": "in", + "path": "bar", + "validator": { + "type": "equal", + "value": "hi", + }, + }, + { + "type": "in", + "path": ["baz", "laz", "loo"], + "validator": {"type": "equal", "value": 3}, + }, + ], + } + + assert parse_validator(required_claims).validate(jwt_claims) + + +def test_validator_short(): + required_claims_short = [ + "all_of", + [ + ["in", "foo", ["regex", "hell", False]], + ["in", "bar", ["equal", "hi"]], + ["in", ["baz", "laz", "loo"], ["equal", 3]], + ], + ] + + assert parse_validator(required_claims_short).validate(jwt_claims) + + +jwt_claims = { + "foo": "hello", + "bar": "hi", + "baz": { + "laz": { + "loo": 3, + "goo": 4, + } + }, +}