diff --git a/doc/README.md b/doc/README.md
index 8d001847e..df5b552e9 100644
--- a/doc/README.md
+++ b/doc/README.md
@@ -125,7 +125,7 @@ attribute to use, e.g. `address.formatted` will access the attribute value
 attributes:
   mail:
     openid: [email]
-    saml: [mail, emailAdress, email]
+    saml: [mail, emailAddress, email]
   address:
     openid: [address.formatted]
     saml: [postaladdress]
@@ -140,7 +140,7 @@ attributes (in the proxy backend) <-> internal <-> returned attributes (from the
 * Any plugin using the `openid` profile will use the attribute value from
   `email` delivered from the target provider as the value for `mail`.
 * Any plugin using the `saml` profile will use the attribute value from `mail`,
-  `emailAdress` and `email` depending on which attributes are delivered by the
+  `emailAddress` and `email` depending on which attributes are delivered by the
   target provider as the value for `mail`.
 * Any plugin using the `openid` profile will use the attribute value under the
   key `formatted` in the `address` attribute delivered by the target provider.
@@ -266,7 +266,7 @@ provider.
 2. The **SAMLMirrorFrontend** module mirrors each target provider as a separate entity in the SAML
 metadata. In this proxy this is handled with dynamic entity id's, encoding the target provider.
 This allows external discovery services to present the mirrored providers transparently, as separate entities
-in its UI. The following flow diagram shows the communcation:
+in its UI. The following flow diagram shows the communication:
 
 `SP -> optional discovery service -> selected proxy SAML entity -> target IdP`
 
@@ -311,7 +311,7 @@ config:
 
 #### Policy
 
-Some settings related to how a SAML response is formed can be overriden on a per-instance or a per-SP
+Some settings related to how a SAML response is formed can be overridden on a per-instance or a per-SP
 basis. This example summarizes the most common settings (hopefully self-explanatory) with their defaults:
 
 ```yaml
@@ -328,7 +328,7 @@ config:
 ```
 
 Overrides per SP entityID is possible by using the entityID as a key instead of the "default" key
-in the yaml structure. The most specific key takes presedence. If no policy overrides are provided
+in the yaml structure. The most specific key takes precedence. If no policy overrides are provided
 the defaults above are used.
 
 ### SAML2 Backend
@@ -397,7 +397,7 @@ the user will have to always select a target provider when a discovery service
 is configured. If the parameter is set to `True` (and `ForceAuthn` is not set),
 the proxy will remember and reuse the selected target provider for the duration
 that the state cookie is valid. If `ForceAuthn` is set, then the
-`use_memorized_idp_when_force_authn` configuration option can overide
+`use_memorized_idp_when_force_authn` configuration option can override
 this property and still reuse the selected target provider. The default
 behaviour is `False`.
@@ -803,12 +803,12 @@ Backends and Frontends act like adapters, while micro-services act like plugins
 and all of them can be developed by anyone and shared with everyone.
 
 Other people that have been working with the SaToSa proxy, have built
-extentions mainly in the form of additional micro-services that are shared to
+extensions mainly in the form of additional micro-services that are shared to
 be used by anyone.
 
-- SUNET maintains a small collection of extentions that focus around the SWAMID
+- SUNET maintains a small collection of extensions that focus around the SWAMID
   policies.
-  The extentions are licensed under the Apache2.0 license.
+  The extensions are licensed under the Apache2.0 license.
   You can find the code using the following URL:
   - https://github.com/SUNET/swamid-satosa/
 
@@ -828,16 +828,16 @@ be used by anyone.
   - https://github.com/italia/Satosa-Saml2Spid
 
 - DAASI International have been a long-time user of this software and have made
-  their extentions available.
-  The extentions are licensed under the Apache2.0 license.
+  their extensions available.
+  The extensions are licensed under the Apache2.0 license.
   You can find the code using the following URL:
   - https://gitlab.daasi.de/didmos2/didmos2-auth/-/tree/master/src/didmos_oidc/satosa/micro_services
-  The extentions include:
+  The extensions include:
   - SCIM attribute store to fetch attributes via SCIM API (instead of LDAP)
-  - Authoritzation module for blocking services if necessary group memberships or
+  - Authorization module for blocking services if necessary group memberships or
     attributes are missing in the identity (for service providers that do not
     evaluate attributes themselves)
   - Backend chooser with Django UI for letting the user choose between any
diff --git a/src/satosa/attribute_mapping.py b/src/satosa/attribute_mapping.py
index e8729561c..38ad9ea70 100644
--- a/src/satosa/attribute_mapping.py
+++ b/src/satosa/attribute_mapping.py
@@ -1,21 +1,22 @@
 import logging
 from collections import defaultdict
 from itertools import chain
+from typing import Any, Mapping, Optional, Union
 
 from mako.template import Template
 
 logger = logging.getLogger(__name__)
 
 
-def scope(s):
+def scope(s: str) -> str:
     """
     Mako filter: used to extract scope from attribute
     :param s: string to extract scope from (filtered string in mako template)
     :return: the scope
     """
-    if '@' not in s:
+    if "@" not in s:
         raise ValueError("Unscoped string")
-    (local_part, _, domain_part) = s.partition('@')
+    (local_part, _, domain_part) = s.partition("@")
     return domain_part
 
 
@@ -24,9 +25,8 @@ class AttributeMapper(object):
     Converts between internal and external data format
     """
 
-    def __init__(self, internal_attributes):
+    def __init__(self, internal_attributes: dict[str, dict[str, dict[str, list[str]]]]):
         """
-        :type internal_attributes: dict[str, dict[str, dict[str, str]]]
         :param internal_attributes: A map of how to convert the attributes
         (dict[internal_name, dict[attribute_profile, external_name]])
         """
@@ -35,21 +35,16 @@ def __init__(self, internal_attributes):
         self.from_internal_attributes = internal_attributes["attributes"]
         self.template_attributes = internal_attributes.get("template_attributes", None)
 
-        self.to_internal_attributes = defaultdict(dict)
+        self.to_internal_attributes: dict[str, Any] = defaultdict(dict)
         for internal_attribute_name, mappings in self.from_internal_attributes.items():
             for profile, external_attribute_names in mappings.items():
                 for external_attribute_name in external_attribute_names:
                     self.to_internal_attributes[profile][external_attribute_name] = internal_attribute_name
 
-    def to_internal_filter(self, attribute_profile, external_attribute_names):
+    def to_internal_filter(self, attribute_profile: str, external_attribute_names: list[str]) -> list[str]:
         """
         Converts attribute names from external "type" to internal
 
-        :type attribute_profile: str
-        :type external_attribute_names: list[str]
-        :type case_insensitive: bool
-        :rtype: list[str]
-
         :param attribute_profile: From which external type to convert (ex: oidc, saml, ...)
        :param external_attribute_names: A list of attribute names
         :param case_insensitive: Create a case insensitive filter
@@ -63,7 +58,7 @@ def to_internal_filter(self, attribute_profile, external_attribute_names):
             # no attributes since the given profile is not configured
             return []
 
-        internal_attribute_names = set()  # use set to ensure only unique values
+        internal_attribute_names: set[str] = set()  # use set to ensure only unique values
         for external_attribute_name in external_attribute_names:
             try:
                 internal_attribute_name = profile_mapping[external_attribute_name]
@@ -73,14 +68,10 @@ def to_internal_filter(self, attribute_profile, external_attribute_names):
 
         return list(internal_attribute_names)
 
-    def to_internal(self, attribute_profile, external_dict):
+    def to_internal(self, attribute_profile: str, external_dict: Mapping[str, list[str]]) -> dict[str, list[str]]:
         """
         Converts the external data from "type" to internal
 
-        :type attribute_profile: str
-        :type external_dict: dict[str, str]
-        :rtype: dict[str, str]
-
         :param attribute_profile: From which external type to convert (ex: oidc, saml, ...)
         :param external_dict: Attributes in the external format
         :return: Attributes in the internal format
@@ -97,8 +88,7 @@ def to_internal(self, attribute_profile, external_dict):
                 continue
 
             external_attribute_name = mapping[attribute_profile]
-            attribute_values = self._collate_attribute_values_by_priority_order(external_attribute_name,
-                                                                                external_dict)
+            attribute_values = self._collate_attribute_values_by_priority_order(external_attribute_name, external_dict)
             if attribute_values:  # Only insert key if it has some values
                 logline = "backend attribute {external} mapped to {internal} ({value})".format(
                     external=external_attribute_name, internal=internal_attribute_name, value=attribute_values
@@ -106,15 +96,15 @@ def to_internal(self, attribute_profile, external_dict):
                 logger.debug(logline)
                 internal_dict[internal_attribute_name] = attribute_values
             else:
-                logline = "skipped backend attribute {}: no value found".format(
-                    external_attribute_name
-                )
+                logline = "skipped backend attribute {}: no value found".format(external_attribute_name)
                 logger.debug(logline)
 
         internal_dict = self._handle_template_attributes(attribute_profile, internal_dict)
         return internal_dict
 
-    def _collate_attribute_values_by_priority_order(self, attribute_names, data):
-        result = []
+    def _collate_attribute_values_by_priority_order(
+        self, attribute_names: list[str], data: Mapping[str, list[str]]
+    ) -> list[str]:
+        result: list[str] = []
         for attr_name in attribute_names:
             attr_val = self._get_nested_attribute_value(attr_name, data)
@@ -125,14 +115,19 @@ def _collate_attribute_values_by_priority_order(self, attribute_names, data):
 
         return result
 
-    def _render_attribute_template(self, template, data):
+    def _render_attribute_template(self, template: str, data: Mapping[str, list[str]]) -> list[str]:
         t = Template(template, cache_enabled=True, imports=["from satosa.attribute_mapping import scope"])
         try:
-            return t.render(**data).split(self.multivalue_separator)
+            _rendered = t.render(**data)
+            if not isinstance(_rendered, str):
+                raise TypeError("Rendered data is not a string")
+            return _rendered.split(self.multivalue_separator)
         except (NameError, TypeError):
             return []
 
-    def _handle_template_attributes(self, attribute_profile, internal_dict):
+    def _handle_template_attributes(
+        self, attribute_profile: str, internal_dict: dict[str, list[str]]
+    ) -> dict[str, list[str]]:
         if not self.template_attributes:
             return internal_dict
 
@@ -143,26 +138,27 @@ def _handle_template_attributes(self, attribute_profile, internal_dict):
            external_attribute_name = mapping[attribute_profile]
             templates = [t for t in external_attribute_name if "$" in t]  # these looks like templates...
-            template_attribute_values = [self._render_attribute_template(template, internal_dict) for template in
-                                         templates]
-            flattened_attribute_values = list(chain.from_iterable(template_attribute_values))
-            attribute_values = flattened_attribute_values or internal_dict.get(internal_attribute_name, None)
+            template_attribute_values = [
+                self._render_attribute_template(template, internal_dict) for template in templates
+            ]
+            flattened_attribute_values: list[str] = list(chain.from_iterable(template_attribute_values))
+            attribute_values = flattened_attribute_values or internal_dict.get(internal_attribute_name)
             if attribute_values:  # only insert key if it has some values
                 internal_dict[internal_attribute_name] = attribute_values
 
         return internal_dict
 
-    def _get_nested_attribute_value(self, nested_key, data):
+    def _get_nested_attribute_value(self, nested_key: str, data: Mapping[str, Any]) -> Optional[Any]:
         keys = nested_key.split(self.separator)
 
         d = data
         for key in keys:
-            d = d.get(key)
+            d = d.get(key)  # type: ignore[assignment]
             if d is None:
                 return None
         return d
 
-    def _create_nested_attribute_value(self, nested_attribute_names, value):
+    def _create_nested_attribute_value(self, nested_attribute_names: list[str], value: Any) -> dict[str, Any]:
         if len(nested_attribute_names) == 1:
             # we've reached the inner-most attribute name, set value here
             return {nested_attribute_names[0]: value}
@@ -171,26 +167,22 @@ def _create_nested_attribute_value(self, nested_attribute_names, value):
         child_dict = self._create_nested_attribute_value(nested_attribute_names[1:], value)
         return {nested_attribute_names[0]: child_dict}
 
-    def from_internal(self, attribute_profile, internal_dict):
+    def from_internal(
+        self, attribute_profile: str, internal_dict: dict[str, list[str]]
+    ) -> dict[str, Union[list[str], dict[str, list[str]]]]:
         """
         Converts the internal data to "type"
 
-        :type attribute_profile: str
-        :type internal_dict: dict[str, str]
-        :rtype: dict[str, str]
-
         :param attribute_profile: To which external type to convert (ex: oidc, saml, ...)
        :param internal_dict: attributes to map
         :return: attribute values and names in the specified "profile"
         """
-        external_dict = {}
+        external_dict: dict[str, Union[list[str], dict[str, list[str]]]] = {}
         for internal_attribute_name in internal_dict:
             try:
                 attribute_mapping = self.from_internal_attributes[internal_attribute_name]
             except KeyError:
-                logline = "no attribute mapping found for the internal attribute {}".format(
-                    internal_attribute_name
-                )
+                logline = "no attribute mapping found for the internal attribute {}".format(internal_attribute_name)
                 logger.debug(logline)
                 continue
@@ -206,14 +198,17 @@ def from_internal(self, attribute_profile, internal_dict):
                 # select the first attribute name
                 external_attribute_name = external_attribute_names[0]
                 logline = "frontend attribute {external} mapped from {internal} ({value})".format(
-                    external=external_attribute_name, internal=internal_attribute_name, value=internal_dict[internal_attribute_name]
+                    external=external_attribute_name,
+                    internal=internal_attribute_name,
+                    value=internal_dict[internal_attribute_name],
                 )
                 logger.debug(logline)
 
             if self.separator in external_attribute_name:
                 nested_attribute_names = external_attribute_name.split(self.separator)
-                nested_dict = self._create_nested_attribute_value(nested_attribute_names[1:],
-                                                                  internal_dict[internal_attribute_name])
+                nested_dict = self._create_nested_attribute_value(
+                    nested_attribute_names[1:], internal_dict[internal_attribute_name]
+                )
                 external_dict[nested_attribute_names[0]] = nested_dict
             else:
                 external_dict[external_attribute_name] = internal_dict[internal_attribute_name]
diff --git a/src/satosa/base.py b/src/satosa/base.py
index 404104920..a686f316e 100644
--- a/src/satosa/base.py
+++ b/src/satosa/base.py
@@ -39,32 +39,37 @@ def __init__(self, config):
         self.config = config
 
         logger.info("Loading backend modules...")
-        backends = load_backends(self.config, self._auth_resp_callback_func,
-                                 self.config["INTERNAL_ATTRIBUTES"])
+        backends = load_backends(self.config, self._auth_resp_callback_func, self.config["INTERNAL_ATTRIBUTES"])
         logger.info("Loading frontend modules...")
-        frontends = load_frontends(self.config, self._auth_req_callback_func,
-                                   self.config["INTERNAL_ATTRIBUTES"])
+        frontends = load_frontends(self.config, self._auth_req_callback_func, self.config["INTERNAL_ATTRIBUTES"])
 
         self.response_micro_services = []
         self.request_micro_services = []
         logger.info("Loading micro services...")
         if "MICRO_SERVICES" in self.config:
-            self.request_micro_services.extend(load_request_microservices(
-                self.config.get("CUSTOM_PLUGIN_MODULE_PATHS"),
-                self.config["MICRO_SERVICES"],
-                self.config["INTERNAL_ATTRIBUTES"],
-                self.config["BASE"]))
+            self.request_micro_services.extend(
+                load_request_microservices(
+                    self.config.get("CUSTOM_PLUGIN_MODULE_PATHS"),
+                    self.config["MICRO_SERVICES"],
+                    self.config["INTERNAL_ATTRIBUTES"],
+                    self.config["BASE"],
+                )
+            )
             self._link_micro_services(self.request_micro_services, self._auth_req_finish)
 
             self.response_micro_services.extend(
-                load_response_microservices(self.config.get("CUSTOM_PLUGIN_MODULE_PATHS"),
-                                            self.config["MICRO_SERVICES"],
-                                            self.config["INTERNAL_ATTRIBUTES"],
-                                            self.config["BASE"]))
+                load_response_microservices(
+                    self.config.get("CUSTOM_PLUGIN_MODULE_PATHS"),
+                    self.config["MICRO_SERVICES"],
+                    self.config["INTERNAL_ATTRIBUTES"],
+                    self.config["BASE"],
+                )
+            )
             self._link_micro_services(self.response_micro_services, self._auth_resp_finish)
 
-        self.module_router = ModuleRouter(frontends, backends,
-                                          self.request_micro_services + self.response_micro_services)
+        self.module_router = ModuleRouter(
+            frontends, backends, self.request_micro_services + self.response_micro_services
+        )
 
     def _link_micro_services(self, micro_services, finisher):
         if not micro_services:
@@ -138,14 +143,13 @@ def _auth_resp_callback_func(self, context, internal_response):
         # If configured construct the user id from attribute values.
         if "user_id_from_attrs" in self.config["INTERNAL_ATTRIBUTES"]:
             subject_id = [
-                "".join(internal_response.attributes[attr]) for attr in
-                self.config["INTERNAL_ATTRIBUTES"]["user_id_from_attrs"]
+                "".join(internal_response.attributes[attr])
+                for attr in self.config["INTERNAL_ATTRIBUTES"]["user_id_from_attrs"]
             ]
             internal_response.subject_id = "".join(subject_id)
 
         if self.response_micro_services:
-            return self.response_micro_services[0].process(
-                context, internal_response)
+            return self.response_micro_services[0].process(context, internal_response)
 
         return self._auth_resp_finish(context, internal_response)
 
@@ -180,9 +184,7 @@ def _run_bound_endpoint(self, context, spec):
         except SATOSAAuthenticationError as error:
             error.error_id = uuid.uuid4().urn
             state = json.dumps(error.state.state_dict, indent=4)
-            msg = "ERROR_ID [{err_id}]\nSTATE:\n{state}".format(
-                err_id=error.error_id, state=state
-            )
+            msg = "ERROR_ID [{err_id}]\nSTATE:\n{state}".format(err_id=error.error_id, state=state)
             logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg)
             logger.error(logline, exc_info=True)
             return self._handle_satosa_authentication_error(error)
@@ -219,8 +221,9 @@ def _save_state(self, resp, context):
         :param context: Session context
         """
-        cookie = state_to_cookie(context.state, self.config["COOKIE_STATE_NAME"], "/",
-                                 self.config["STATE_ENCRYPTION_KEY"])
+        cookie = state_to_cookie(
+            context.state, self.config["COOKIE_STATE_NAME"], "/", self.config["STATE_ENCRYPTION_KEY"]
+        )
         resp.headers.append(tuple(cookie.output().split(": ", 1)))
 
     def run(self, context):
@@ -259,16 +262,14 @@ def run(self, context):
 
 
 class SAMLBaseModule(object):
-    KEY_ENTITYID_ENDPOINT = 'entityid_endpoint'
-    KEY_ENABLE_METADATA_RELOAD = 'enable_metadata_reload'
-    KEY_ATTRIBUTE_PROFILE = 'attribute_profile'
-    KEY_ACR_MAPPING = 'acr_mapping'
-    VALUE_ATTRIBUTE_PROFILE_DEFAULT = 'saml'
+    KEY_ENTITYID_ENDPOINT = "entityid_endpoint"
+    KEY_ENABLE_METADATA_RELOAD = "enable_metadata_reload"
+    KEY_ATTRIBUTE_PROFILE = "attribute_profile"
+    KEY_ACR_MAPPING = "acr_mapping"
+    VALUE_ATTRIBUTE_PROFILE_DEFAULT = "saml"
 
     def init_config(self, config):
-        self.attribute_profile = config.get(
-            self.KEY_ATTRIBUTE_PROFILE,
-            self.VALUE_ATTRIBUTE_PROFILE_DEFAULT)
+        self.attribute_profile = config.get(self.KEY_ATTRIBUTE_PROFILE, self.VALUE_ATTRIBUTE_PROFILE_DEFAULT)
         self.acr_mapping = config.get(self.KEY_ACR_MAPPING)
         return config
 
@@ -287,13 +288,13 @@ def enable_metadata_reload(self):
 
 
 class SAMLEIDASBaseModule(SAMLBaseModule):
-    VALUE_ATTRIBUTE_PROFILE_DEFAULT = 'eidas'
+    VALUE_ATTRIBUTE_PROFILE_DEFAULT = "eidas"
 
     def init_config(self, config):
         config = super().init_config(config)
 
         spec_eidas = {
-            'entityid_endpoint': True,
+            "entityid_endpoint": True,
         }
 
         return util.check_set_dict_defaults(config, spec_eidas)
diff --git a/src/satosa/context.py b/src/satosa/context.py
index 1cf140586..0b13868c8 100644
--- a/src/satosa/context.py
+++ b/src/satosa/context.py
@@ -1,3 +1,4 @@
+from typing import Any, Optional
 from warnings import warn as _warn
 
 from satosa.exception import SATOSAError
@@ -7,6 +8,7 @@ class SATOSABadContextError(SATOSAError):
     """
     Raise this exception if validating the Context and failing.
    """
+
     pass
 
 
@@ -14,16 +16,17 @@ class Context(object):
     """
     Holds methods for sharing proxy data through the current request
     """
-    KEY_METADATA_STORE = 'metadata_store'
-    KEY_TARGET_ENTITYID = 'target_entity_id'
-    KEY_FORCE_AUTHN = 'force_authn'
-    KEY_MEMORIZED_IDP = 'memorized_idp'
-    KEY_AUTHN_CONTEXT_CLASS_REF = 'authn_context_class_ref'
-    KEY_TARGET_AUTHN_CONTEXT_CLASS_REF = 'target_authn_context_class_ref'
-
-    def __init__(self):
-        self._path = None
-        self.request = None
+
+    KEY_METADATA_STORE = "metadata_store"
+    KEY_TARGET_ENTITYID = "target_entity_id"
+    KEY_FORCE_AUTHN = "force_authn"
+    KEY_MEMORIZED_IDP = "memorized_idp"
+    KEY_AUTHN_CONTEXT_CLASS_REF = "authn_context_class_ref"
+    KEY_TARGET_AUTHN_CONTEXT_CLASS_REF = "target_authn_context_class_ref"
+
+    def __init__(self) -> None:
+        self._path: Optional[str] = None
+        self.request: dict[str, Any] = {}
         self.request_uri = None
         self.request_method = None
         self.qs_params = None
@@ -35,8 +38,8 @@ def __init__(self):
         self.target_frontend = None
         self.target_micro_service = None
         # This dict is a data carrier between frontend and backend modules.
-        self.internal_data = {}
-        self.state = None
+        self.internal_data: dict[str, Any] = {}
+        self.state: dict[str, Any] = {}
 
     @property
     def KEY_BACKEND_METADATA_STORE(self):
@@ -47,18 +50,16 @@ def KEY_BACKEND_METADATA_STORE(self):
         return Context.KEY_METADATA_STORE
 
     @property
-    def path(self):
+    def path(self) -> Optional[str]:
         """
         Get the path
 
-        :rtype: str
-
         :return: context path
         """
         return self._path
 
     @path.setter
-    def path(self, p):
+    def path(self, p: str) -> None:
         """
         Inserts a path to the context.
         This path is striped by the base_url, so for example:
@@ -72,26 +73,26 @@ def path(self, p):
         """
         if not p:
             raise ValueError("path can't be set to None")
-        elif p.startswith('/'):
+        elif p.startswith("/"):
             raise ValueError("path can't start with '/'")
 
         self._path = p
 
-    def target_entity_id_from_path(self):
+    def target_entity_id_from_path(self) -> Optional[str]:
+        if not self.path:
+            return None
         target_entity_id = self.path.split("/")[1]
         return target_entity_id
 
-    def decorate(self, key, value):
+    def decorate(self, key: str, value: Any) -> "Context":
         """
         Add information to the context
         """
-
         self.internal_data[key] = value
         return self
 
-    def get_decoration(self, key):
+    def get_decoration(self, key: str) -> Any:
         """
         Retrieve information from the context
         """
-
         value = self.internal_data.get(key)
         return value
diff --git a/src/satosa/internal.py b/src/satosa/internal.py
index 24de31890..8f49ea180 100644
--- a/src/satosa/internal.py
+++ b/src/satosa/internal.py
@@ -1,20 +1,21 @@
 """Internal data representation for SAML/OAuth/OpenID connect."""
+from __future__ import annotations
 
-
+from typing import Any, Mapping, NewType, Optional, Type, TypeVar
 import warnings as _warnings
 from collections import UserDict
 
 
+TDatafySubclass = TypeVar("TDatafySubclass", bound="_Datafy")
+
+
 class _Datafy(UserDict):
-    _DEPRECATED_TO_NEW_MEMBERS = {}
+    _DEPRECATED_TO_NEW_MEMBERS: Mapping[str, str] = {}
 
     def _get_new_key(self, old_key):
         new_key = self.__class__._DEPRECATED_TO_NEW_MEMBERS.get(old_key, old_key)
         is_key_deprecated = old_key != new_key
         if is_key_deprecated:
-            msg = "'{old_key}' is deprecated; use '{new_key}' instead.".format(
-                old_key=old_key, new_key=new_key
-            )
+            msg = "'{old_key}' is deprecated; use '{new_key}' instead.".format(old_key=old_key, new_key=new_key)
             _warnings.warn(msg, DeprecationWarning)
         return new_key
@@ -40,38 +41,26 @@ def __getattr__(self, key):
         try:
             value = self.__getitem__(key)
        except KeyError as e:
-            msg = "'{type}' object has no attribute '{attr}'".format(
-                type=type(self), attr=key
-            )
+            msg = "'{type}' object has no attribute '{attr}'".format(type=type(self), attr=key)
             raise AttributeError(msg) from e
         return value
 
-    def to_dict(self):
+    def to_dict(self) -> dict[str, Any]:
         """
         Converts an object to a dict
 
-        :rtype: dict[str, str]
         :return: A dict representation of the object
         """
         data = {
             key: value
             for key, value_obj in self.items()
-            for value in [
-                value_obj.to_dict() if hasattr(value_obj, "to_dict") else value_obj
-            ]
+            for value in [value_obj.to_dict() if hasattr(value_obj, "to_dict") else value_obj]
         }
-        data.update(
-            {
-                key: data.get(value)
-                for key, value in self.__class__._DEPRECATED_TO_NEW_MEMBERS.items()
-            }
-        )
+        data.update({key: data.get(value) for key, value in self.__class__._DEPRECATED_TO_NEW_MEMBERS.items()})
         return data
 
     @classmethod
-    def from_dict(cls, data):
+    def from_dict(cls: type[TDatafySubclass], data: dict[str, Any]) -> TDatafySubclass:
         """
-        :type data: dict[str, str]
-        :rtype: satosa.internal.AuthenticationInformation
-
         :param data: A dict representation of an object
         :return: An object
         """
@@ -86,20 +75,16 @@ class AuthenticationInformation(_Datafy):
 
     def __init__(
         self,
-        auth_class_ref=None,
-        timestamp=None,
-        issuer=None,
-        authority=None,
+        auth_class_ref: Optional[str] = None,
+        timestamp: Optional[str] = None,
+        issuer: Optional[str] = None,
+        authority: Optional[Any] = None,
         *args,
         **kwargs,
     ):
         """
         Initiate the data carrier
 
-        :type auth_class_ref: str
-        :type timestamp: str
-        :type issuer: str
-
         :param auth_class_ref: What method that was used for the authentication
         :param timestamp: Time when the authentication was done
         :param issuer: Where the authentication was done
@@ -118,12 +103,12 @@ class InternalData(_Datafy):
 
     def __init__(
         self,
-        auth_info=None,
-        requester=None,
-        requester_name=None,
-        subject_id=None,
-        subject_type=None,
-        attributes=None,
+        auth_info: Optional[AuthenticationInformation] = None,
+        requester: Optional[str] = None,
+        requester_name: Optional[list[Mapping[str, Any]]] = None,
+        subject_id: Optional[str] = None,
+        subject_type: Optional[str] = None,
+        attributes: Optional[dict[str, Any]] = None,
         *args,
         **kwargs,
     ):
@@ -134,13 +119,6 @@ def __init__(
         :param subject_id:
         :param subject_type:
         :param attributes:
-
-        :type auth_info: AuthenticationInformation
-        :type requester: str
-        :type requester_name:
-        :type subject_id: str
-        :type subject_type: str
-        :type attributes: dict
         """
         super().__init__(self, *args, **kwargs)
         self.auth_info = (
@@ -149,11 +127,7 @@ def __init__(
             else AuthenticationInformation(**(auth_info or {}))
         )
         self.requester = requester
-        self.requester_name = (
-            requester_name
-            if requester_name is not None
-            else [{"text": requester, "lang": "en"}]
-        )
+        self.requester_name = requester_name if requester_name is not None else [{"text": requester, "lang": "en"}]
         self.subject_id = subject_id
         self.subject_type = subject_type
         self.attributes = attributes if attributes is not None else {}
diff --git a/src/satosa/micro_services/base.py b/src/satosa/micro_services/base.py
index 084cbea76..e19173b75 100644
--- a/src/satosa/micro_services/base.py
+++ b/src/satosa/micro_services/base.py
@@ -2,37 +2,45 @@
 Micro service for SATOSA
 """
 import logging
+from typing import Any, Callable, Optional, Union
 
+import satosa.context
+import satosa.internal
+import satosa.response
 
 logger = logging.getLogger(__name__)
 
 
+ProcessReturnType = Union[satosa.internal.InternalData, satosa.response.Response]
+MicroServiceCallSignature = Callable[[satosa.context.Context, satosa.internal.InternalData], ProcessReturnType]
+CallbackReturnType = satosa.response.Response
+CallbackCallSignature = Callable[[satosa.context.Context, Any], CallbackReturnType]
+
+
 class MicroService(object):
     """
     Abstract class for micro services
     """
 
-    def __init__(self, name, base_url, **kwargs):
+    def __init__(self, name: str, base_url: str, **kwargs: Any):
         self.name = name
         self.base_url = base_url
-        self.next = None
+        self.next: Optional[MicroServiceCallSignature] = None
 
-    def process(self, context, data):
+    def process(self, context: satosa.context.Context, data: satosa.internal.InternalData) -> ProcessReturnType:
         """
         This is where the micro service should modify the request / response.
         Subclasses must call this method (or in another way make sure the
         `next` callable is called).
 
-        :type context: satosa.context.Context
-        :type data: satosa.internal.InternalData
-        :rtype: satosa.internal.InternalData
-
         :param context: The current context
         :param data: Data to be modified
         :return: Modified data
         """
+        if not self.next:
+            raise RuntimeError("No next micro service")
         return self.next(context, data)
 
-    def register_endpoints(self):
+    def register_endpoints(self) -> list[tuple[str, CallbackCallSignature]]:
         """
         URL mapping of additional endpoints this micro service needs to register for
         callbacks.
@@ -41,11 +49,8 @@ def register_endpoints(self):
             ("^/callback1$", self.callback),
         ]
 
-
-        :rtype List[Tuple[str, Callable[[satosa.context.Context, Any], satosa.response.Response]]]
-
         :return: A list with functions and args bound to a specific endpoint url,
-                 [(regexp, Callable[[satosa.context.Context], satosa.response.Response]), ...]
+                 [(regexp, CallbackCallSignature), ...]
         """
         return []
 
@@ -54,6 +59,7 @@ class ResponseMicroService(MicroService):
     """
     Base class for response micro services
    """
+
     pass
 
 
@@ -61,4 +67,5 @@ class RequestMicroService(MicroService):
     """
     Base class for request micro services
     """
+
     pass
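
Reviewer note (not part of the patch): the changes above only add type annotations and reformat existing code, so runtime behaviour should be unchanged. The sketch below is a minimal, illustrative exercise of the `AttributeMapper` conversion path touched by this diff, using the example mapping from `doc/README.md`. It assumes SATOSA (and its `mako` dependency) is installed; the expected outputs in the comments follow from the code as shown in the diff.

```python
# Illustrative only: round-trip attributes through AttributeMapper using the
# README's example mapping. Requires SATOSA to be importable.
from satosa.attribute_mapping import AttributeMapper

internal_attributes = {
    "attributes": {
        "mail": {
            "openid": ["email"],
            "saml": ["mail", "emailAddress", "email"],
        },
        "address": {
            "openid": ["address.formatted"],
            "saml": ["postaladdress"],
        },
    }
}

mapper = AttributeMapper(internal_attributes)

# Backend response in the "saml" profile -> internal representation.
# "emailAddress" is the first configured name with a value, so it wins.
internal = mapper.to_internal("saml", {"emailAddress": ["user@example.com"]})
print(internal)  # {'mail': ['user@example.com']}

# Internal representation -> frontend response in the "openid" profile.
# The "address.formatted" mapping produces a nested dict via
# _create_nested_attribute_value, as seen in the diff above.
external = mapper.from_internal(
    "openid", {"mail": ["user@example.com"], "address": ["10 Downing Street"]}
)
print(external)
# {'email': ['user@example.com'], 'address': {'formatted': ['10 Downing Street']}}
```

Along the same lines, a minimal response micro-service against the `MicroService` contract annotated above might look as follows. The class name `HelloWorld` and the `greeting` config key are hypothetical, chosen only for illustration; the `config`-as-first-argument pattern mirrors the bundled micro-services.

```python
# Illustrative sketch only: a response micro-service that adds one attribute
# and then hands control to the next element in the chain.
import satosa.context
import satosa.internal
from satosa.micro_services.base import ResponseMicroService


class HelloWorld(ResponseMicroService):
    """Add a static attribute to every response passing through the proxy."""

    def __init__(self, config, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # hypothetical config key; defaults keep the sketch self-contained
        self.greeting = config.get("greeting", "hello")

    def process(self, context: satosa.context.Context, data: satosa.internal.InternalData):
        data.attributes["greeting"] = [self.greeting]
        # delegate to the next micro-service (or the finisher) in the chain
        return super().process(context, data)
```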