diff --git a/src/riskmatrix/__init__.py b/src/riskmatrix/__init__.py index 2e36f92..a981841 100644 --- a/src/riskmatrix/__init__.py +++ b/src/riskmatrix/__init__.py @@ -2,6 +2,9 @@ from pyramid.config import Configurator from pyramid_beaker import session_factory_from_settings from typing import Any +from email.headerregistry import Address +from pyramid.settings import asbool +from .mail import PostmarkMailer from riskmatrix.flash import MessageQueue from riskmatrix.i18n import LocaleNegotiator @@ -24,6 +27,19 @@ def includeme(config: Configurator) -> None: settings = config.registry.settings + default_sender = settings.get( + 'email.default_sender', + 'riskmatrix@seantis.ch' + ) + token = settings.get('mail.postmark_token', '') + stream = settings.get('mail.postmark_stream', 'development') + blackhole = asbool(settings.get('mail.postmark_blackhole', False)) + config.registry.registerUtility(PostmarkMailer( + Address(addr_spec=default_sender), + token, + stream, + blackhole=blackhole + )) config.include('pyramid_beaker') config.include('pyramid_chameleon') config.include('pyramid_layout') @@ -67,8 +83,12 @@ def main( environment=sentry_environment, integrations=[PyramidIntegration(), SqlalchemyIntegration()], traces_sample_rate=1.0, - profiles_sample_rate=0.25, + profiles_sample_rate=1.0, + enable_tracing=True, + send_default_pii=True ) + print("configured sentry") + print(sentry_dsn) with Configurator(settings=settings, root_factory=root_factory) as config: includeme(config) diff --git a/src/riskmatrix/mail/__init__.py b/src/riskmatrix/mail/__init__.py new file mode 100644 index 0000000..4b44c34 --- /dev/null +++ b/src/riskmatrix/mail/__init__.py @@ -0,0 +1,15 @@ +from .exceptions import InactiveRecipient +from .exceptions import MailConnectionError +from .exceptions import MailError +from .interfaces import IMailer +from .mailer import PostmarkMailer +from .types import MailState + +__all__ = ( + 'IMailer', + 'InactiveRecipient', + 'MailConnectionError', + 'MailError', + 'MailState', + 'PostmarkMailer', +) \ No newline at end of file diff --git a/src/riskmatrix/mail/exceptions.py b/src/riskmatrix/mail/exceptions.py new file mode 100644 index 0000000..363353c --- /dev/null +++ b/src/riskmatrix/mail/exceptions.py @@ -0,0 +1,10 @@ +class MailError(Exception): + pass + + +class MailConnectionError(MailError, ConnectionError): + pass + + +class InactiveRecipient(MailError): + pass \ No newline at end of file diff --git a/src/riskmatrix/mail/interfaces.py b/src/riskmatrix/mail/interfaces.py new file mode 100644 index 0000000..8d1c5d8 --- /dev/null +++ b/src/riskmatrix/mail/interfaces.py @@ -0,0 +1,91 @@ +from typing import Any, TYPE_CHECKING +from zope.interface import Interface + +if TYPE_CHECKING: + from collections.abc import Sequence + from email.headerregistry import Address + from .types import MailState + from .types import MailParams + from .types import TemplateMailParams + from ..certificate.interfaces import ITemplate + from ..models import Organization + from ..types import JSONObject + MailID = Any + + +class IMailer(Interface): # pragma: no cover + + # NOTE: We would like to say that kwargs is OptionalMailParams + # however there is no way in mypy to express that yet. + def send(sender: 'Address | None', + receivers: 'Address | Sequence[Address]', + subject: str, + content: str, + **kwargs: Any) -> 'MailID': + """ + Send a single email. + + Returns a message uuid. + """ + pass + + def bulk_send(mails: list['MailParams'] + ) -> list['MailID | MailState']: + """ + Send multiple emails. "mails" is a list of dicts containing + the arguments to an individual send call. + + Returns a list of message uuids and their success/failure states + in the same order as the sending list. + """ + pass + + # NOTE: We would like to say that kwargs is OptionalTemplateMailParams + # however there is no way in mypy to express that yet. + def send_template(sender: 'Address | None', + receivers: 'Address | Sequence[Address]', + template: str, + data: 'JSONObject', + **kwargs: Any) -> 'MailID': + """ + Send a single email using a template using its id/name. + "data" contains the template specific data. + + Returns a message uuid. + """ + pass + + def bulk_send_template(mails: list['TemplateMailParams'], + default_template: str | None = None, + ) -> list['MailID | MailState']: + """ + Send multiple template emails using the same template. + + Returns a list of message uuids. If a message failed to be sent + the uuid will be replaced by a MailState value. + """ + pass + + def template_exists(alias: str) -> bool: + """ + Returns whether a template by the given alias exists. + """ + pass + + def create_or_update_template( + template: 'ITemplate', + organization: 'Organization | None' = None, + ) -> list[str]: + """ + Creates or updates a mailer template based on a certificate template. + + Returns a list of errors. If the list is empty, it was successful. + """ + pass + + def delete_template(template: 'ITemplate') -> list[str]: + """ + Deletes a mailer template based on a certificate template. + + Returns a list of errors. If the list is empty, it was successful. + """ \ No newline at end of file diff --git a/src/riskmatrix/mail/mailer.py b/src/riskmatrix/mail/mailer.py new file mode 100644 index 0000000..f3deba9 --- /dev/null +++ b/src/riskmatrix/mail/mailer.py @@ -0,0 +1,625 @@ +import base64 +import io +import json +import re +import requests + +from email.headerregistry import Address +from markupsafe import Markup +from string import ascii_letters +from string import digits +from zope.interface import implementer + +from .exceptions import InactiveRecipient +from .exceptions import MailConnectionError +from .exceptions import MailError +from .interfaces import IMailer +from .types import MailState + +from typing import cast, overload, Any, ClassVar, TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Sequence + from requests import Response + from .types import MailAttachment + from .types import MailParams + from .types import TemplateMailParams + from ..certificate.interfaces import ITemplate + from ..models import Organization + from ..types import JSON, JSONArray, JSONObject + MailID = str + AnyMailParams = MailParams | TemplateMailParams + + +domain_regex = re.compile(r'@[A-Za-z0-9][A-Za-z0-9.-]*[.][A-Za-z]{2,10}') +plus_regex = re.compile(r'(?@,:;.]') +alphanumeric = ascii_letters + digits +qp_prefix = '=?utf-8?q?' +qp_suffix = '?=' +QP_PREFIX_LENGTH = len(qp_prefix) +QP_SUFFIX_LENGTH = len(qp_suffix) +QP_MAX_WORD_LENGTH = 75 +QP_CONTENT_LENGTH = QP_MAX_WORD_LENGTH - QP_PREFIX_LENGTH - QP_SUFFIX_LENGTH + + +def needs_header_encode(name: str) -> bool: + # NOTE: Backslash escaping is forbidden in Postmark API + if '"' in name: + return True + try: + # NOTE: Technically there's some ASCII characters that + # should be illegal altogether such as \n, \r, \0 + # This should already be caught by the use of Address + # though, which makes sure each part only contains + # legal characters. + name.encode('ascii') + except UnicodeEncodeError: + return True + return False + + +def qp_encode_display_name(name: str) -> str: + words: list[str] = [] + current_word: list[str] = [] + + def finish_word() -> None: + nonlocal current_word + content = ''.join(current_word) + words.append(f'{qp_prefix}{content}{qp_suffix}') + current_word = [] + + for character in name: + if character == ' ': + # special case for header encoding + characters = ['_'] + elif character in alphanumeric: + # no need to encode this character + characters = [character] + else: + # QP encode the character + characters = list( + ''.join(f'={c:02X}' for c in character.encode('utf-8')) + ) + + if len(current_word) + len(characters) > QP_CONTENT_LENGTH: + finish_word() + + current_word.extend(characters) + + finish_word() + if len(words) == 1: + # We can omit the enclosing double quotes + return words[0] + + # NOTE: The enclosing double quotes are necessary so that spaces + # as word separators can be parsed correctly. + return f'"{" ".join(words)}"' + + +def format_single_address(address: Address) -> str: + # NOTE: Format the address according to Postmark API rules: + # Quoted printable encoded words can be at most 75 + # characters and we're not allowed to use backslash + # escaping for double quotes. + name = address.display_name + if not name: + return address.addr_spec + + if not needs_header_encode(name): + if specials_regex.search(name): + # simple quoting works here, since we disallow + # backslash escaping double quotes. + name = f'"{name}"' + return f'{name} <{address.addr_spec}>' + + name = qp_encode_display_name(name) + return f'{name} <{address.addr_spec}>' + + +def format_address(addresses: 'Address | Sequence[Address]') -> str: + if isinstance(addresses, Address): + addresses = [addresses] + return ', '.join(format_single_address(a) for a in addresses) + + +@implementer(IMailer) +class PostmarkMailer: + api_url: ClassVar[str] = 'https://api.postmarkapp.com' + default_sender: Address + server_token: str + stream: str + blackhole: bool + + def __init__(self, + default_sender: Address, + server_token: str, + stream: str, + blackhole: bool = False) -> None: + + self.default_sender = default_sender + self.server_token = server_token + self.stream = stream + self.blackhole = blackhole + + def request_headers(self) -> dict[str, str]: + return {'X-Postmark-Server-Token': self.server_token} + + def prepare_message(self, params: 'AnyMailParams') -> 'JSONObject': + receivers = format_address(params['receivers']) + # Strip plus addressing, so it can be used regardless of provider + # support to disambiguate multiple participants with the same + # e-mail address. + # TODO: We could maybe do this in format_single_address? + receivers = plus_regex.sub('', receivers) + if self.blackhole: + receivers = domain_regex.sub( + '@blackhole.postmarkapp.com', receivers + ) + + reply_to = params.get('sender') + if reply_to is not None and reply_to.display_name: + # Mix display name from Reply-To into From + sender = format_single_address(Address( + display_name=reply_to.display_name, + username=self.default_sender.username, + domain=self.default_sender.domain, + )) + else: + sender = format_single_address(self.default_sender) + + message: 'JSONObject' = { + 'From': sender, + 'To': receivers, + 'MessageStream': self.stream, + 'TrackOpens': True + } + if reply_to is not None: + message['ReplyTo'] = format_single_address(reply_to) + if 'content' in params: + params = cast('MailParams', params) + message['TextBody'] = params['content'] + message['Subject'] = params['subject'] + elif 'template' in params: + params = cast('TemplateMailParams', params) + # NOTE: I'm not sure i like this, but it's better than having + # to pass mailer.stream around so much, which shouldn't + # necessarily be a standard attribute of IMailer... + message['TemplateAlias'] = f"{self.stream}-{params['template']}" + message['TemplateModel'] = params['data'] + + if 'subject' in params: + message['Subject'] = params['subject'] + + if 'tag' in params: + message['Tag'] = params['tag'] + + if 'attachments' in params: + message['Attachments'] = self.prepare_attachments( + params['attachments'] + ) + return message + + def prepare_attachments(self, attachments: list['MailAttachment'] + ) -> 'JSONArray': + result: 'JSONArray' = [] + for attachment in attachments: + content = base64.b64encode(attachment['content']) + payload: 'JSONObject' = { + 'Name': attachment['filename'], + 'Content': content.decode('ascii'), + 'ContentType': attachment['content_type'], + } + if 'content_id' in attachment: + payload['ContentID'] = 'cid:' + attachment['content_id'] + result.append(payload) + return result + + def get_response_data(self, response: 'Response') -> 'JSON': + try: + data: 'JSON' = response.json() + if not response.ok: + if ( + isinstance(data, dict) + and isinstance((msg := data.get('Message')), str) + ): + raise MailError(msg) + else: + raise MailError('Unknown error') + return data + except (ValueError, KeyError): + raise MailError('Malformed response from Postmark API') from None + + def _raw_send(self, api_path: str, params: 'AnyMailParams') -> 'MailID': + send_url = self.api_url + api_path + send_data = self.prepare_message(params) + headers = self.request_headers() + try: + response = requests.post( + send_url, + json=send_data, + headers=headers, + timeout=(5, 30) + ) + except ConnectionError: + raise MailConnectionError( + 'Failed to connect to Postmark API' + ) from None + data: 'JSON' = self.get_response_data(response) + if not isinstance(data, dict): + raise MailError('Invalid API data.') + + if data['ErrorCode'] == 406: + raise InactiveRecipient() + elif data['ErrorCode'] != 0: + raise MailError(data['Message']) + return data['MessageID'] # type:ignore[no-any-return] + + def send(self, + sender: Address | None, + receivers: 'Address | Sequence[Address]', + subject: str, + content: str, + *, + tag: str | None = None, + attachments: list['MailAttachment'] | None = None, + **kwargs: Any) -> 'MailID': + + params: 'MailParams' = { + 'receivers': receivers, + 'subject': subject, + 'content': content + } + if sender: + params['sender'] = sender + if tag: + params['tag'] = tag + if attachments: + params['attachments'] = attachments + + return self._raw_send('/email', params) + + def send_template(self, + sender: Address | None, + receivers: 'Address | Sequence[Address]', + template: str, + data: 'JSONObject', + *, + subject: str | None = None, + tag: str | None = None, + attachments: list['MailAttachment'] | None = None, + **kwargs: Any) -> 'MailID': + + params: 'TemplateMailParams' = { + 'receivers': receivers, + 'template': template, + 'data': data + } + if sender: + params['sender'] = sender + if subject: + params['subject'] = subject + if tag: + params['tag'] = tag + if attachments: + params['attachments'] = attachments + + return self._raw_send('/email/withTemplate', params) + + # NOTE: For now we disallow mixing MailParams for the purpose of + # preventing users from specifying a template for MailParams. + # Technically this should not cause any problem however, so + # we could allow it in the future if we really needed it. + @overload + def _raw_bulk_send(self, + api_path: str, + mails: 'Sequence[MailParams]', + *, + preamble: bytes = b'{"messages":[', + postamble: bytes = b']}' + ) -> list['MailID | MailState']: ... + @overload # noqa: E301 + def _raw_bulk_send(self, + api_path: str, + mails: 'Sequence[TemplateMailParams]', + template: str | None = None, + *, + preamble: bytes = b'{"messages":[', + postamble: bytes = b']}' + ) -> list['MailID | MailState']: ... + def _raw_bulk_send(self, # noqa: E301 + api_path: str, + mails: 'Sequence[AnyMailParams]', + template: str | None = None, + *, + # NOTE: annoyingly Postmark either has a wrapper + # around the array or not depending on the + # api call so we need to handle this using + # these arguments + preamble: bytes = b'{"Messages": [', + postamble: bytes = b']}' + ) -> list['MailID | MailState']: + + messages: 'JSONArray' = [] + for mail in mails: + if template: + mail = cast('TemplateMailParams', mail) + # NOTE: This modifies the original dict, which could + # be a source for errors, but it's also faster... + mail.setdefault('template', template) + messages.append(self.prepare_message(mail)) + + bulk_url = self.api_url + api_path + headers = self.request_headers() + # We generate the payload ourselves so we set the headers manually + headers['Accept'] = 'application/json' + headers['Content-Type'] = 'application/json' + BATCH_LIMIT = 500 + # NOTE: The API specifies MB, so let's not chance it + # by assuming they meant MiB and just go with + # lower size limit. + SIZE_LIMIT = 50_000_000 # 50MB + # NOTE: We use a buffer to be a bit more memory efficient + # we don't initialize the buffer, so tell gives us + # the exact size of the buffer. + buffer = io.BytesIO() + buffer.write(preamble) + num_included = 0 + result: list['MailID | MailState'] = [] + + def finish_batch() -> None: + nonlocal buffer + nonlocal num_included + + buffer.write(postamble) + + # if the batch is empty we just skip it + if num_included > 0: + assert num_included <= BATCH_LIMIT + assert buffer.tell() <= SIZE_LIMIT + + try: + response = requests.post( + bulk_url, + data=buffer.getvalue(), + headers=headers, + timeout=(5, 60) + ) + data = self.get_response_data(response) + if not isinstance(data, list) or len(data) != num_included: + # TODO: should probably log this as a warning + raise MailError('Invalid API data.') + + for message in data: + if ( + not isinstance(message, dict) + or 'ErrorCode' not in message + or 'MessageID' not in message + ): + # TODO: should probably log this as a warning + result.append(MailState.failed) + continue + + error_code = message['ErrorCode'] + if error_code == 406: + result.append(MailState.inactive_recipient) + elif error_code != 0: + result.append(MailState.failed) + else: + # if we don't get an ID we don't want to fail hard + # so we just pretend the mail has been delivered + result.append(message['MessageID']) + + except ConnectionError: + # we'll treat these as a temporary failures + result.extend([MailState.temporary_failure]*num_included) + except MailError: + # we'll treat these as more permanent failures for now + result.extend([MailState.failed]*num_included) + + # prepare vars for next batch + buffer.close() + buffer = io.BytesIO() + buffer.write(preamble) + num_included = 0 + + for message in messages: + payload = json.dumps(message).encode('utf-8') + if buffer.tell() + len(payload) + len(postamble) >= SIZE_LIMIT: + finish_batch() + + if num_included: + buffer.write(b',') + + buffer.write(payload) + num_included += 1 + + if num_included == BATCH_LIMIT: + finish_batch() + + # finish final partially full batch + finish_batch() + return result + + def bulk_send(self, mails: list['MailParams'] + ) -> list['MailID | MailState']: + return self._raw_bulk_send('/email/batch', mails, + preamble=b'[', postamble=b']') + + def bulk_send_template(self, + mails: list['TemplateMailParams'], + default_template: str | None = None, + ) -> list['MailID | MailState']: + return self._raw_bulk_send( + '/email/batchWithTemplates', mails, default_template + ) + + def get_message_details(self, message_id: 'MailID') -> 'JSON': + details_url = f'{self.api_url}/messages/outbound/{message_id}/details' + headers = self.request_headers() + headers['Accept'] = 'application/json' + try: + response = requests.get( + details_url, headers=headers, timeout=(5, 10) + ) + except ConnectionError: + raise MailConnectionError( + 'Failed to connect to Postmark API' + ) from None + return self.get_response_data(response) + + def get_message_state(self, message_id: 'MailID') -> MailState: + # NOTE: With multiple recipients for a single mail Postmark will report + # a status for each recipient individually. So if we ever specify + # multiple recipients, this method will be wrong. + details = self.get_message_details(message_id) + if not isinstance(details, dict): + raise MailError('Invalid API data.') + state = max( + ( + message_event_type_to_mail_state.get(event['Type'], -1) + for event in details['MessageEvents'] + ), + default=-1 + ) + if state < 0: + return MailState.submitted + elif state == 0: + return MailState.bounced + return cast('MailState', state) + + def validate_template(self, template_data: dict[str, str]) -> list[str]: + validate_url = self.api_url + '/templates/validate' + try: + response = requests.post( + validate_url, + json=template_data, + headers=self.request_headers(), + timeout=(5, 10) + ) + if not response.ok: + return ['Template failed to validate.'] + + data = response.json() + if data['AllContentIsValid']: + return [] + + return [ + # FIXME: we should try to map these back to the original markup + ( + f'{location} line {error["Line"]}:' + f'{error["CharacterPosition"]}: {error["Message"]}' + ) + for location in ('Subject', 'HtmlBody', 'TextBody') + for error in data[location]['ValidationErrors'] + ] + except ConnectionError: + raise MailConnectionError( + 'Failed to connect to Postmark API' + ) from None + + def template_exists(self, alias: str) -> bool: + template_url = self.api_url + '/templates/' + alias + try: + headers = self.request_headers() + headers['Accept'] = 'application/json' + response = requests.get( + template_url, + headers=headers, + timeout=(5, 10) + ) + return response.ok + except ConnectionError: + raise MailConnectionError( + 'Failed to connect to Postmark API' + ) from None + + def create_or_update_template( + self, + template: 'ITemplate', + organization: 'Organization | None' = None, + ) -> list[str]: + + alias = f'{self.stream}-{template.id}' + if organization is None: + name = f'[{self.stream}] Shared: {template.name}' + else: + name = f'[{self.stream}] {organization.name}: {template.name}' + if len(name) > 100: + name = name[:97] + '...' + html_content: str = markdown_to_html(template.email_content) + plain_content = markdown_to_plaintext(template.email_content) + + # replace logos with appropriate placeholders. + html_content = html_content.replace( + '{{organization.logo}}', + Markup('{{organization}}') + ) + plain_content = plain_content.replace( + '{{organization.logo}}', + '{{organization}}' + ) + template_data = { + 'Name': name, + 'Alias': alias, + 'Subject': template.email_subject, + 'HtmlBody': html_content, + 'TextBody': plain_content, + 'TemplateType': 'Standard', + 'LayoutTemplate': 'basic', + } + errors = self.validate_template(template_data) + if errors: + return errors + + if self.template_exists(alias): + action = 'update' + method = 'put' + template_url = self.api_url + '/templates/' + alias + else: + action = 'create' + method = 'post' + template_url = self.api_url + '/templates' + + try: + response = requests.request( + method, + template_url, + json=template_data, + headers=self.request_headers(), + timeout=(5, 10) + ) + if not response.ok: + return [f'Failed to {action} template.'] + return [] + except ConnectionError: + # Let's not force people to catch an exception + return ['Failed to connect to Postmark API.'] + + def delete_template(self, template: 'ITemplate') -> list[str]: + alias = f'{self.stream}-{template.id}' + if not self.template_exists(alias): + return [] + + try: + headers = self.request_headers() + headers['Accept'] = 'application/json' + response = requests.delete( + self.api_url + '/templates/' + alias, + headers=headers, + timeout=(5, 10) + ) + if not response.ok: + return ['Failed to delete template.'] + return [] + except ConnectionError: + # Let's not force people to catch an exception + return ['Failed to connect to Postmark API.'] + + +message_event_type_to_mail_state = { + 'Delivered': MailState.delivered, + 'Transient': 0, # Needs to have lower priority than hard bounce + 'Bounced': MailState.failed, # This is a hard bounce + 'Opened': MailState.read, +} \ No newline at end of file diff --git a/src/riskmatrix/mail/types.py b/src/riskmatrix/mail/types.py new file mode 100644 index 0000000..f021005 --- /dev/null +++ b/src/riskmatrix/mail/types.py @@ -0,0 +1,67 @@ +import enum + + +from typing import TypedDict, TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from email.headerregistry import Address + + from ..types import JSONObject + + +class MailState(enum.IntEnum): + not_queued = 0 + queued = 10 + temporary_failure = 15 + failed = 20 + inactive_recipient = 30 + bounced = 40 + submitted = 50 + delivered = 60 + read = 70 + + +class _BaseMailAttachment(TypedDict): + filename: str + content: bytes + content_type: str + + +class MailAttachment(_BaseMailAttachment, total=False): + content_id: str + + +class RequiredMailParams(TypedDict): + receivers: 'Address | Sequence[Address]' + subject: str + content: str + + +class OptionalMailParams(TypedDict, total=False): + sender: 'Address' + tag: str + attachments: list[MailAttachment] + + +class MailParams(RequiredMailParams, OptionalMailParams): + pass + + +class RequiredTemplateMailParams(TypedDict): + receivers: 'Address | Sequence[Address]' + template: str + data: 'JSONObject' + + +class OptionalTemplateMailParams(TypedDict, total=False): + sender: 'Address' + subject: str + tag: str + attachments: list[MailAttachment] + + +class TemplateMailParams( + RequiredTemplateMailParams, + OptionalTemplateMailParams +): + pass \ No newline at end of file diff --git a/src/riskmatrix/models/__init__.py b/src/riskmatrix/models/__init__.py index da997fe..af61c68 100644 --- a/src/riskmatrix/models/__init__.py +++ b/src/riskmatrix/models/__init__.py @@ -13,6 +13,7 @@ from .risk_catalog import RiskCatalog from .risk_category import RiskCategory from .user import User +from .password_change_token import PasswordChangeToken from typing import TYPE_CHECKING @@ -61,5 +62,6 @@ def includeme(config: 'Configurator') -> None: 'RiskCatalog', 'RiskCategory', 'RiskMatrixAssessment', - 'User' + 'User', + 'PasswordChangeToken' ) diff --git a/src/riskmatrix/models/password_change_token.py b/src/riskmatrix/models/password_change_token.py new file mode 100644 index 0000000..54254a0 --- /dev/null +++ b/src/riskmatrix/models/password_change_token.py @@ -0,0 +1,101 @@ +import secrets +import uuid +from datetime import datetime +from datetime import timedelta +from sqlalchemy import ForeignKey +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import relationship +from sqlalchemy.orm import Mapped + +from ..orm.meta import Base +from ..orm.meta import DateTimeNoTz +from ..orm.meta import UUIDStr +from ..orm.meta import UUIDStrPK +from ..security_policy import PasswordException + + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from .user import User + + +class PasswordChangeToken(Base): + + __tablename__ = 'password_change_tokens' + + id: Mapped[UUIDStrPK] + user_id: Mapped[UUIDStr] = mapped_column( + ForeignKey('user.id', ondelete='CASCADE'), + index=True + ) + time_requested: Mapped[DateTimeNoTz] + time_consumed: Mapped[DateTimeNoTz | None] + time_expired: Mapped[DateTimeNoTz | None] + token: Mapped[str] + ip_address: Mapped[str] + + user: Mapped['User'] = relationship(lazy='joined') + + def __init__( + self, + user: 'User', + ip_address: str, + time_requested: datetime | None = None + ): + + if not user.email: + raise PasswordException( + 'Cannot request password change without an email.' + ) + self.id = str(uuid.uuid4()) + self.user = user + self.ip_address = ip_address + + if time_requested is None: + time_requested = datetime.utcnow() + self.time_requested = time_requested.replace(tzinfo=None) + self.time_consumed = None + self.token = secrets.token_urlsafe() + + def consume(self, email: str) -> None: + if self.consumed: + msg = f'Token "{self.token}" has already been used' + raise PasswordException(msg) + + user_email = self.user.email + assert user_email + + if email and email.lower() != user_email.lower(): + raise PasswordException(f'Invalid email for token "{self.token}"') + + # Initial password set link does not expire + if self.user.password and self.expired: + raise PasswordException(f'Token "{self.token}" has expired') + + time_consumed = datetime.utcnow() + time_consumed = time_consumed.replace(tzinfo=None) + self.time_consumed = time_consumed + + @property + def consumed(self) -> bool: + if self.time_consumed is not None: + return True + return False + + def expire(self, expired: datetime | None = None) -> None: + if not self.expired: + if not expired: + expired = datetime.utcnow() + expired = expired.replace(tzinfo=None) + self.time_expired = expired + + @property + def expired(self) -> bool: + expired = self.time_expired + if expired and expired < datetime.utcnow(): + return True + + expiring_time = self.time_requested + timedelta(hours=48) + if datetime.utcnow() > expiring_time: + return True + return False \ No newline at end of file diff --git a/src/riskmatrix/orm/meta.py b/src/riskmatrix/orm/meta.py index 2e183a3..66f9ef8 100644 --- a/src/riskmatrix/orm/meta.py +++ b/src/riskmatrix/orm/meta.py @@ -9,6 +9,7 @@ from sqlalchemy.orm import registry from sqlalchemy.orm import DeclarativeBase from sqlalchemy.schema import MetaData +from sqlalchemy import DateTime from .json_type import JSONObject from .utcdatetime_type import UTCDateTime @@ -46,6 +47,7 @@ primary_key=True )] BigInt = Annotated[int, 'BigInt'] +DateTimeNoTz = Annotated[datetime, 'DateTimeNoTz'] UUIDStr = Annotated[str, 'UUIDStr'] str_32 = Annotated[str, 32] str_64 = Annotated[str, 64] @@ -63,6 +65,7 @@ class Base(DeclarativeBase): datetime: UTCDateTime, BigInt: BigInteger().with_variant(Integer(), 'sqlite'), str_32: String(length=32), + DateTimeNoTz: DateTime(timezone=False), str_64: String(length=64), str_128: String(length=128), str_256: String(length=256), diff --git a/src/riskmatrix/scripts/upgrade.py b/src/riskmatrix/scripts/upgrade.py index 011568d..59ab59d 100644 --- a/src/riskmatrix/scripts/upgrade.py +++ b/src/riskmatrix/scripts/upgrade.py @@ -211,3 +211,6 @@ def main() -> None: print('Dry run') upgrade(args) + +if __name__ == '__main__': + main() diff --git a/src/riskmatrix/views/__init__.py b/src/riskmatrix/views/__init__.py index aff1b9b..6ecccb4 100644 --- a/src/riskmatrix/views/__init__.py +++ b/src/riskmatrix/views/__init__.py @@ -6,6 +6,7 @@ from riskmatrix.route_factories import risk_assessment_factory from riskmatrix.route_factories import risk_catalog_factory +from .password_change import password_change_view from .asset import assets_view from .asset import delete_asset_view from .asset import edit_asset_view @@ -24,6 +25,7 @@ from .risk_assessment import generate_risk_matrix_view from .risk_assessment import set_impact_view from .risk_assessment import set_likelihood_view +from .password_retrieval import password_retrieval_view from .risk_catalog import delete_risk_catalog_view from .risk_catalog import edit_risk_catalog_view from .risk_catalog import risk_catalog_view @@ -55,6 +57,24 @@ def includeme(config: 'Configurator') -> None: permission=NO_PERMISSION_REQUIRED ) + config.add_route('password_retrieval', '/password_retrieval') + config.add_view( + password_retrieval_view, + route_name='password_retrieval', + renderer='templates/password_retrieval.pt', + require_csrf=False, + permission=NO_PERMISSION_REQUIRED + ) + + config.add_route('password_change', '/password_change') + config.add_view( + password_change_view, + route_name='password_change', + renderer='templates/password_change.pt', + require_csrf=False, + permission=NO_PERMISSION_REQUIRED + ) + config.add_route('logout', '/logout') config.add_view(logout_view, route_name='logout') diff --git a/src/riskmatrix/views/password_change.py b/src/riskmatrix/views/password_change.py new file mode 100644 index 0000000..f0cc5cb --- /dev/null +++ b/src/riskmatrix/views/password_change.py @@ -0,0 +1,104 @@ +import logging +from pyramid.httpexceptions import HTTPFound +from wtforms import Form +from wtforms import PasswordField +from wtforms import StringField +from wtforms.validators import InputRequired + +from ..models import PasswordChangeToken +from ..security_policy import PasswordException +from ..wtform.validators import password_validator + + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from pyramid.interfaces import IRequest + from ..types import RenderDataOrRedirect + +logger = logging.getLogger('certificateclaim.auth') + + +class PasswordChangeForm(Form): + + email = StringField( + label='Email', + validators=[ + InputRequired(), + ] + ) + password = PasswordField( + label='New Password', + validators=[ + InputRequired(), + password_validator + ] + ) + password_confirmation = PasswordField( + label='Confirm New Password', + validators=[ + InputRequired() + ] + ) + + +def get_token( + token: str, + request: 'IRequest' +) -> PasswordChangeToken | None: + + session = request.dbsession + query = session.query(PasswordChangeToken) + query = query.filter(PasswordChangeToken.token == token) + return query.first() + + +def password_change_view(request: 'IRequest') -> 'RenderDataOrRedirect': + + form = PasswordChangeForm(formdata=request.POST) + if 'email' in request.POST: + if form.validate(): + assert form.email.data is not None + assert form.password.data is not None + token = request.GET.get('token', '') + email = form.email.data.lower() + password = form.password.data + try: + token_obj = get_token(token, request) + if not token_obj: + raise PasswordException(f'Token "{token}" not found') + token_obj.consume(email) + token_obj.user.set_password(password) + request.messages.add('Password changed', 'success') + return HTTPFound(request.route_url('login')) + except PasswordException as e: + msg = f'Password change: {str(e)}' + logger.warning(msg) + request.messages.add('Invalid Request', 'error') + else: + msg = ( + 'There was a problem with your submission. Errors have ' + 'been highlighted below.' + ) + request.messages.add(msg, 'error') + + else: + msg = ( + 'Password must have minimal length of 8 characters, contain one ' + 'upper case letter, one lower case letter, one digit and one ' + 'special character.' + ) + request.messages.add(msg, 'info') + + token = request.GET.get('token', '') + token_obj = get_token(token, request) + valid = True + if token_obj and (token_obj.expired or token_obj.consumed): + valid = False + msg = 'This password reset link has expired.' + request.messages.clear() + request.messages.add(msg, 'error') + + return { + 'form': form, + 'valid': valid, + } \ No newline at end of file diff --git a/src/riskmatrix/views/password_retrieval.py b/src/riskmatrix/views/password_retrieval.py new file mode 100644 index 0000000..191bf18 --- /dev/null +++ b/src/riskmatrix/views/password_retrieval.py @@ -0,0 +1,99 @@ +import logging +from email.headerregistry import Address +from pyramid.httpexceptions import HTTPFound +from wtforms import Form +from wtforms import StringField +from wtforms.validators import InputRequired + +from ..mail import IMailer +from ..models import PasswordChangeToken +from ..models import User +from ..security_policy import PasswordException +from ..wtform.validators import email_validator + + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from pyramid.interfaces import IRequest + from sqlalchemy.orm import Session + from ..types import RenderDataOrRedirect + +logger = logging.getLogger('certificateclaim.auth') + + +class PasswordRetrievalForm(Form): + + email = StringField( + label='Email', + validators=[ + InputRequired(), + email_validator + ] + ) + + +def expire_all_tokens(user: User, session: 'Session') -> None: + query = session.query(PasswordChangeToken) + query = query.filter(PasswordChangeToken.user_id == user.id) + query = query.filter(PasswordChangeToken.time_expired.is_(None)) + for token in query: + token.expire() + + +def mail_retrieval(email: str, request: 'IRequest') -> None: + # NOTE: This will probably get caught by email_validator + # but lets just be safe for now... + if '\x00' in email: + raise PasswordException(f'Invalid email "{email}"') + + session = request.dbsession + query = session.query(User) + query = query.filter(User.email.ilike(email)) + user = query.first() + if not user: + raise PasswordException(f'User "{email}" not found') + + expire_all_tokens(user, session) + ip_address = getattr(request, 'client_addr', '') + token_obj = PasswordChangeToken(user, ip_address) + session.add(token_obj) + session.flush() + + mailer = request.registry.getUtility(IMailer) + mailer.send_template( + sender=None, # This mail doesn't need a reply-to + receivers=Address(user.fullname, addr_spec=user.email), + template='password-reset', + data={ + 'name': user.fullname, + 'action_url': request.route_url( + 'password_change', + _query={'token': token_obj.token} + ) + }, + tag='password-reset', + ) + + +def password_retrieval_view(request: 'IRequest') -> 'RenderDataOrRedirect': + form = PasswordRetrievalForm(formdata=request.POST) + if 'email' in request.POST and form.validate(): + try: + assert form.email.data is not None + email = form.email.data.lower() + mail_retrieval(email, request) + logger.info(f'Password retrieval mail sent to "{email}"') + except PasswordException as e: + logger.warning( + f'[{request.client_addr}] password retrieval: {str(e)}' + ) + + request.messages.add( + 'An email has been sent to the requested account with further ' + 'information. If you do not receive an email then please ' + 'confirm you have entered the correct email address.', + 'success' + ) + return HTTPFound(location=request.route_url('login')) + + return {'form': form} \ No newline at end of file diff --git a/src/riskmatrix/views/risk.py b/src/riskmatrix/views/risk.py index 8a879ff..7c5daaa 100644 --- a/src/riskmatrix/views/risk.py +++ b/src/riskmatrix/views/risk.py @@ -112,10 +112,11 @@ def __init__( } ) - self.category.choices = category_choices( - organization_id, - session - ) + # FIXME: currently disabled due to complexity given another dimension + # self.category.choices = category_choices( + # organization_id, + # session + #) name = StringField( @@ -133,13 +134,14 @@ def __init__( ) ) - category = SelectField( - label=_('Category'), - choices=[('', '')], - validators=( - validators.Optional(), - ) - ) + # FIXME: currently disabled due to complexity given another dimension + #category = SelectField( + # label=_('Category'), + # choices=[('', '')], + # validators=( + # validators.Optional(), + # ), + #) def validate_name(self, field: 'Field') -> None: session = self.meta.dbsession @@ -188,7 +190,7 @@ class RisksTable(AJAXDataTable[Risk]): } name = DataColumn(_('Name')) - category = DataColumn(_('Category', ), class_name='visually-hidden') + #category = DataColumn(_('Category', ), class_name='visually-hidden') description = DataColumn(_('Description')) def __init__(self, catalog: 'RiskCatalog', request: 'IRequest') -> None: @@ -278,19 +280,27 @@ def edit_risk_view( organization_id = context.organization_id else: risk = None - if request.json: - risk = Risk(name=request.json["name"], description=request.json["description"], catalog=context) - request.dbsession.add(risk) - request.dbsession.flush() - request.dbsession.refresh(risk) - response = Response(status=201) + try: + if request.json: + risk = Risk(name=request.json["name"], description=request.json["description"], catalog=context) + request.dbsession.add(risk) + request.dbsession.flush() + request.dbsession.refresh(risk) + response = Response(status=201) - return response + return response + except: + pass organization_id = context.id catalog = context form = RiskMetaForm(context, request) target_url = request.route_url('risks', id=organization_id) - if request.method == 'POST' and (not request.json and form.validate()): + try: + t = request.json + has_json = t is not None + except: + has_json = False + if request.method == 'POST' and (not has_json and form.validate()): if risk is None: risk = Risk(name=form.name.data or '', catalog=catalog) request.dbsession.add(risk) diff --git a/src/riskmatrix/views/templates/login.pt b/src/riskmatrix/views/templates/login.pt index 9df3ef7..fc02abb 100644 --- a/src/riskmatrix/views/templates/login.pt +++ b/src/riskmatrix/views/templates/login.pt @@ -8,7 +8,7 @@
-
+
diff --git a/src/riskmatrix/views/templates/password_change.pt b/src/riskmatrix/views/templates/password_change.pt new file mode 100644 index 0000000..dd2fb4b --- /dev/null +++ b/src/riskmatrix/views/templates/password_change.pt @@ -0,0 +1,29 @@ + + + + +
+ +
+
+
+

Change Password

+
+ ${field.label(class_='form-label sr-only')} + ${field( + class_='form-control is-invalid' if field.errors else 'form-control', + placeholder=field.label.text, + )} +
${error}
+
+
+ +
+
+
+
+
+ +
+ +
\ No newline at end of file diff --git a/src/riskmatrix/views/templates/password_retrieval.pt b/src/riskmatrix/views/templates/password_retrieval.pt new file mode 100644 index 0000000..fd17e8c --- /dev/null +++ b/src/riskmatrix/views/templates/password_retrieval.pt @@ -0,0 +1,29 @@ + + + + +
+ +
+
+
+

Password Retrieval

+
+ ${field.label(class_='form-label sr-only')} + ${field( + class_='form-control is-invalid' if field.errors else 'form-control', + placeholder=field.label.text, + )} +
${error}
+
+
+ +
+
+
+
+
+ +
+ +
\ No newline at end of file diff --git a/src/riskmatrix/wtform/validators.py b/src/riskmatrix/wtform/validators.py index 5b4a468..f7871ec 100644 --- a/src/riskmatrix/wtform/validators.py +++ b/src/riskmatrix/wtform/validators.py @@ -16,6 +16,10 @@ r'^(?=.{8,})(?=.*[a-z])(?=.*[A-Z])(?=.*[\d])(?=.*[\W]).*$' ) +email_regex = re.compile( + r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$' +) + def password_validator(form: 'Form', field: 'Field') -> None: password = form['password'].data @@ -35,6 +39,9 @@ def password_validator(form: 'Form', field: 'Field') -> None: ) raise ValidationError(msg) +def email_validator(form: 'Form', field: 'Field') -> None: + if not email_regex.match(field.data): + raise ValidationError('Not a valid email.') class Immutable: """