|
| 1 | +import logging |
| 2 | +import uuid |
| 3 | +from datetime import date |
| 4 | +from enum import StrEnum |
| 5 | + |
| 6 | +import click |
| 7 | +from sqlalchemy import select |
| 8 | + |
| 9 | +import src.adapters.db.flask_db as flask_db |
| 10 | +import src.util.datetime_util as datetime_util |
| 11 | +from src.adapters import db |
| 12 | +from src.constants.lookup_constants import UserType |
| 13 | +from src.db.models import staging |
| 14 | +from src.db.models.user_models import ( |
| 15 | + Agency, |
| 16 | + AgencyUser, |
| 17 | + AgencyUserRole, |
| 18 | + LegacyCertificate, |
| 19 | + Role, |
| 20 | + User, |
| 21 | +) |
| 22 | +from src.task.ecs_background_task import ecs_background_task |
| 23 | +from src.task.task import Task |
| 24 | +from src.task.task_blueprint import task_blueprint |
| 25 | + |
| 26 | +logger = logging.getLogger(__name__) |
| 27 | + |
| 28 | +FUTURE_DATE = date(2050, 1, 1) |
| 29 | + |
| 30 | + |
| 31 | +class SetupCertUserTaskStatus(StrEnum): |
| 32 | + INVALID_ROLE_IDS = "Invalid role ids" |
| 33 | + AGENCY_NOT_FOUND = "Agency not found" |
| 34 | + LEGACY_CERTIFICATE_ALREADY_EXISTS = "LegacyCertificate already exists" |
| 35 | + SUCCESS = "Success" |
| 36 | + TCERTIFICATE_IS_EXPIRED = "Tcertificate is expired" |
| 37 | + TCERTIFICATE_NOT_FOUND = "Tcertificate not found" |
| 38 | + TCERTIFICATE_IS_MISSING_SERIAL_NUMBER = "Tcertificate is missing serial number" |
| 39 | + |
| 40 | + |
| 41 | +@task_blueprint.cli.command("setup-cert-user", help="Setup the LegacyCertificate and User") |
| 42 | +@click.option("--tcertficates-id", "-t", help="tcertificates_id on Staging Tcertificate") |
| 43 | +@click.option("--role-ids", "-t", help="role_id of role that needs to be added", multiple=True) |
| 44 | +@flask_db.with_db_session() |
| 45 | +@ecs_background_task(task_name="setup-cert-user") |
| 46 | +def setup_cert_user(db_session: db.Session, cert_id: str, role_ids: list[str]) -> None: |
| 47 | + SetupCertUserTask(db_session, cert_id, role_ids).run_task() |
| 48 | + |
| 49 | + |
| 50 | +class SetupCertUserTask(Task): |
| 51 | + |
| 52 | + def __init__(self, db_session: db.Session, tcertificates_id: str, role_ids: list[str]): |
| 53 | + super().__init__(db_session) |
| 54 | + self.tcertificates_id = tcertificates_id |
| 55 | + self.role_ids = role_ids |
| 56 | + |
| 57 | + def run_task(self) -> None: |
| 58 | + with self.db_session.begin(): |
| 59 | + self.setup_cert() |
| 60 | + |
| 61 | + def setup_cert(self) -> SetupCertUserTaskStatus: |
| 62 | + logger.info("setup cert user start") |
| 63 | + roles = self.get_roles() |
| 64 | + if roles is None: |
| 65 | + return SetupCertUserTaskStatus.INVALID_ROLE_IDS |
| 66 | + |
| 67 | + tcertificate = self.get_tcertificate() |
| 68 | + if tcertificate is None: |
| 69 | + logger.warning("Tcertificate not found") |
| 70 | + return SetupCertUserTaskStatus.TCERTIFICATE_NOT_FOUND |
| 71 | + if not tcertificate.serial_num: |
| 72 | + logger.warning("Tcertificate is missing serial number") |
| 73 | + return SetupCertUserTaskStatus.TCERTIFICATE_IS_MISSING_SERIAL_NUMBER |
| 74 | + valid_expiration_date = tcertificate.expirationdate or FUTURE_DATE |
| 75 | + if valid_expiration_date <= datetime_util.get_now_us_eastern_date(): |
| 76 | + logger.warning("Cert is expired") |
| 77 | + return SetupCertUserTaskStatus.TCERTIFICATE_IS_EXPIRED |
| 78 | + if self.is_existing_certificate(tcertificate): |
| 79 | + logger.warning("LegacyCertificate already exists") |
| 80 | + return SetupCertUserTaskStatus.LEGACY_CERTIFICATE_ALREADY_EXISTS |
| 81 | + |
| 82 | + agency, related_agencies = self.get_agencies(tcertificate) |
| 83 | + if agency is None: |
| 84 | + return SetupCertUserTaskStatus.AGENCY_NOT_FOUND |
| 85 | + |
| 86 | + else: |
| 87 | + self.process_cert_user( |
| 88 | + roles, tcertificate, agency, related_agencies, valid_expiration_date |
| 89 | + ) |
| 90 | + logger.info("setup cert user complete") |
| 91 | + return SetupCertUserTaskStatus.SUCCESS |
| 92 | + |
| 93 | + def process_cert_user( |
| 94 | + self, |
| 95 | + roles: list[Role], |
| 96 | + tcertificate: staging.certificates.Tcertificates, |
| 97 | + agency: Agency, |
| 98 | + related_agencies: list[Agency], |
| 99 | + valid_expiration_date: date, |
| 100 | + ) -> None: |
| 101 | + all_agencies = related_agencies + [agency] |
| 102 | + user = self.create_user_with_agency_roles(all_agencies, roles) |
| 103 | + legacy_certificate = LegacyCertificate( |
| 104 | + legacy_certificate_id=uuid.uuid4(), |
| 105 | + agency=agency, |
| 106 | + cert_id=tcertificate.currentcertid, |
| 107 | + expiration_date=valid_expiration_date, |
| 108 | + serial_number=tcertificate.serial_num, |
| 109 | + user=user, |
| 110 | + ) |
| 111 | + self.db_session.add(legacy_certificate) |
| 112 | + |
| 113 | + logger.info( |
| 114 | + "Created legacy certificate", |
| 115 | + extra={ |
| 116 | + "legacy_certificate_id": legacy_certificate.legacy_certificate_id, |
| 117 | + "user_id": user.user_id, |
| 118 | + "agency_code": agency.agency_code, |
| 119 | + }, |
| 120 | + ) |
| 121 | + |
| 122 | + def create_user_with_agency_roles(self, agencies: list[Agency], roles: list[Role]) -> User: |
| 123 | + user = User(user_id=uuid.uuid4(), user_type=UserType.LEGACY_CERTIFICATE) |
| 124 | + self.db_session.add(user) |
| 125 | + |
| 126 | + log_extra = {"user_id": user.user_id} |
| 127 | + logger.info("Created legacy cert user", extra=log_extra) |
| 128 | + |
| 129 | + for agency in agencies: |
| 130 | + agency_user = AgencyUser(user=user, agency=agency) |
| 131 | + self.db_session.add(agency_user) |
| 132 | + |
| 133 | + agency_roles = [AgencyUserRole(agency_user=agency_user, role=r) for r in roles] |
| 134 | + |
| 135 | + self.db_session.add_all(agency_roles) |
| 136 | + |
| 137 | + logger.info( |
| 138 | + "Added user to agency", |
| 139 | + extra=log_extra |
| 140 | + | {"agency_code": agency.agency_code, "role_ids": [r.role_id for r in roles]}, |
| 141 | + ) |
| 142 | + |
| 143 | + return user |
| 144 | + |
| 145 | + def get_roles(self) -> list[Role] | None: |
| 146 | + roles = list( |
| 147 | + self.db_session.scalars( |
| 148 | + select(Role).where(Role.role_id.in_([uuid.UUID(r) for r in self.role_ids])) |
| 149 | + ).all() |
| 150 | + ) |
| 151 | + if len(self.role_ids) != len(roles): |
| 152 | + log_extra = {"found_role_ids": [r.role_id for r in roles] if roles else []} |
| 153 | + logger.warning("Invalid role ids", extra=log_extra) |
| 154 | + return None |
| 155 | + return roles |
| 156 | + |
| 157 | + def get_tcertificate(self) -> staging.certificates.Tcertificates | None: |
| 158 | + return self.db_session.scalars( |
| 159 | + select(staging.certificates.Tcertificates).where( |
| 160 | + staging.certificates.Tcertificates.tcertificates_id |
| 161 | + == uuid.UUID(self.tcertificates_id) |
| 162 | + ) |
| 163 | + ).one_or_none() |
| 164 | + |
| 165 | + def get_agencies( |
| 166 | + self, tcertificate: staging.certificates.Tcertificates |
| 167 | + ) -> tuple[Agency | None, list[Agency]]: |
| 168 | + agency = self.db_session.scalar( |
| 169 | + select(Agency).where(Agency.agency_code == tcertificate.agencyid) |
| 170 | + ) |
| 171 | + if not agency: |
| 172 | + logger.warning("Agency not found") |
| 173 | + return (None, []) |
| 174 | + agencies: tuple[Agency | None, list[Agency]] = (agency, []) |
| 175 | + """ |
| 176 | + If the tcertificate agency has is_multilevel marked as True then: |
| 177 | + 1. fetch every agency that starts with the same prefix as the agency: |
| 178 | + SELECT * FROM agency WHERE agency_code LIKE '{agency.agency_code}-%' |
| 179 | + 2. add an AgencyUser and AgencyUserRole for every subagency |
| 180 | + this is to mimic the grants.gov behavior |
| 181 | + """ |
| 182 | + if agency.is_multilevel_agency: |
| 183 | + search_pattern = f"{agency.agency_code}-%" |
| 184 | + agency_query_results = list( |
| 185 | + self.db_session.scalars( |
| 186 | + select(Agency).where(Agency.agency_code.like(search_pattern)) |
| 187 | + ).all() |
| 188 | + ) |
| 189 | + agencies[1].extend(agency_query_results) |
| 190 | + return agencies |
| 191 | + |
| 192 | + def is_existing_certificate(self, tcertificate: staging.certificates.Tcertificates) -> bool: |
| 193 | + existing_tcertificate = self.db_session.scalars( |
| 194 | + select(LegacyCertificate.legacy_certificate_id).where( |
| 195 | + LegacyCertificate.cert_id == tcertificate.currentcertid |
| 196 | + ) |
| 197 | + ).one_or_none() |
| 198 | + return existing_tcertificate is not None |
0 commit comments