[Issue #1270] Modify DB session logic to allow for multiple schemas (#1520)

## Summary
Fixes #1270

### Time to review: __10 mins__

## Changes proposed
Modify our SQLAlchemy logic to allow for multiple schemas to be set up.
This includes:
* Setting the schema explicitly in a base class that all SQLAlchemy
models inherit from (see the sketch below)
* Setting up a schema translate map (only meaningfully needed for local
tests) to allow renaming a schema
* A new (LOCAL ONLY) script for creating the `api` schema
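
A minimal sketch (hypothetical names, not the repo's exact code) of what the first two bullets mean in SQLAlchemy terms: a base class that pins every model to the `api` schema, and an engine-level `schema_translate_map` that can rewrite that name at query time:

```python
from sqlalchemy import MetaData, create_engine
from sqlalchemy.orm import DeclarativeBase

from src.constants.schema import Schemas


class ApiSchemaBase(DeclarativeBase):
    # Models inheriting from this base are pinned to the "api" schema
    # rather than whatever schema the connection's search_path implies.
    metadata = MetaData(schema=Schemas.API)


# Local tests can rename the schema without touching any model code:
test_engine = create_engine(
    "postgresql+psycopg://user:pass@localhost/app",  # hypothetical URL
    execution_options={"schema_translate_map": {"api": "test_api"}},
)
```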

## Context for reviewers
**Non-locally, this change doesn't actually change anything yet;
locally it does, by making local development more similar to
non-local.**

This does not actually set up any new schemas; every table we create
still lives in a single schema, the `api` schema.

This change looks far larger than it actually is. Before, all of our
tables had their schema set implicitly by the `DB_SCHEMA` environment
variable. Locally this value was set to `public` and non-locally it was
set to `api`. These changes make it so locally it also uses `api`;
however, for that to work, the Alembic migrations need to name `api`
explicitly (in case we add more schemas later). There is a flag in the
Alembic configuration (`include_schemas`) that tells it to generate
migrations with explicit schemas, but we had it disabled. I enabled it
so future migrations _just work_. But to make everything work locally,
I had to manually fix all of the past migrations to use the `api`
schema.
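
For illustration, fixing a past migration amounts to naming the schema explicitly; a simplified before/after (the real migration diffs appear further down):

```python
import sqlalchemy as sa
from alembic import op

# Before: the table landed in whatever schema search_path resolved to
op.create_table(
    "topportunity",
    sa.Column("opportunity_id", sa.Integer(), nullable=False),
    sa.PrimaryKeyConstraint("opportunity_id"),
)

# After: the target schema is explicit, so adding more schemas later is safe
op.create_table(
    "topportunity",
    sa.Column("opportunity_id", sa.Integer(), nullable=False),
    sa.PrimaryKeyConstraint("opportunity_id"),
    schema="api",
)
```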

Non-locally the schema was already `api`, so changing already-run
migrations doesn't matter; they ran as if that value had been set all
along.
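
The rename hook is the new `schema_prefix_override` field on `PostgresDBConfig`, which feeds the translate map. A rough sketch of how a local test harness might use it (it assumes the usual `DB_*` environment variables are set, and the prefix value is made up):

```python
from src.adapters.db import PostgresDBConfig

# Prefix every known schema so a test run gets its own isolated tables.
config = PostgresDBConfig(schema_prefix_override="test_abc123_")

# Maps the "api" schema to "test_abc123_api"; with no override, each
# schema maps to itself and the translate map is a no-op.
translate_map = config.get_schema_translate_map()
```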

## Additional information
This change requires you to run `make db-recreate` locally in order to
use the updated schemas.

To test this, I ran the database migrations one step at a time, fixing
issues as they came up. I then ran the down migrations and verified
that they correctly undid the up migrations. Finally, I ran a few of
our local scripts to make sure everything still worked and found no
issues.

---------

Co-authored-by: nava-platform-bot <[email protected]>
chouinar and nava-platform-bot authored Apr 2, 2024
1 parent a59f0a2 commit 85acd72
Showing 34 changed files with 399 additions and 174 deletions.
5 changes: 4 additions & 1 deletion api/Makefile
@@ -120,7 +120,7 @@ check: format-check lint db-check-migrations test

# Docker starts the image for the DB but it's not quite
# ready to accept connections so we add a brief wait script
init-db: start-db db-migrate
init-db: start-db setup-postgres-db db-migrate

start-db:
docker-compose up --detach grants-db
@@ -176,6 +176,9 @@ create-erds: # Create ERD diagrams for our DB schema
$(PY_RUN_CMD) create-erds
mv bin/*.png ../documentation/api/database/erds

setup-postgres-db: ## Does any initial setup necessary for our local database to work
$(PY_RUN_CMD) setup-postgres-db


##################################################
# Testing
1 change: 0 additions & 1 deletion api/local.env
@@ -51,7 +51,6 @@ POSTGRES_PASSWORD=secret123
DB_HOST=grants-db
DB_NAME=app
DB_USER=app
DB_SCHEMA=public
DB_PASSWORD=secret123
DB_SSL_MODE=allow

2 changes: 2 additions & 0 deletions api/pyproject.toml
@@ -54,6 +54,8 @@ db-migrate-down = "src.db.migrations.run:down"
db-migrate-down-all = "src.db.migrations.run:downall"
db-seed-local = "tests.lib.seed_local_db:seed_local_db"
create-erds = "bin.create_erds:main"
setup-postgres-db = "src.db.migrations.setup_local_postgres_db:setup_local_postgres_db"


[tool.black]
line-length = 100
3 changes: 2 additions & 1 deletion api/src/adapters/db/__init__.py
@@ -25,8 +25,9 @@
# Re-export for convenience
from src.adapters.db.client import Connection, DBClient, Session
from src.adapters.db.clients.postgres_client import PostgresDBClient
from src.adapters.db.clients.postgres_config import PostgresDBConfig

# Do not import flask_db here, because this module is not dependent on any specific framework.
# Code can choose to use this module on its own or with the flask_db module depending on needs.

__all__ = ["Connection", "DBClient", "Session", "PostgresDBClient"]
__all__ = ["Connection", "DBClient", "Session", "PostgresDBClient", "PostgresDBConfig"]
2 changes: 1 addition & 1 deletion api/src/adapters/db/clients/postgres_client.py
@@ -47,6 +47,7 @@ def get_conn() -> Any:
"postgresql+psycopg://",
pool=conn_pool,
hide_parameters=db_config.hide_sql_parameter_logs,
execution_options={"schema_translate_map": db_config.get_schema_translate_map()},
# TODO: Don't think we need this as we aren't using JSON columns, but keeping for reference
# json_serializer=lambda o: json.dumps(o, default=pydantic.json.pydantic_encoder),
)
@@ -94,7 +95,6 @@ def get_connection_parameters(db_config: PostgresDBConfig) -> dict[str, Any]:
user=db_config.username,
password=password,
port=db_config.port,
options=f"-c search_path={db_config.db_schema}",
connect_timeout=10,
sslmode=db_config.ssl_mode,
**connect_args,
12 changes: 10 additions & 2 deletions api/src/adapters/db/clients/postgres_config.py
@@ -3,6 +3,7 @@

from pydantic import Field

from src.constants.schema import Schemas
from src.util.env_config import PydanticBaseEnvConfig

logger = logging.getLogger(__name__)
@@ -15,11 +16,19 @@ class PostgresDBConfig(PydanticBaseEnvConfig):
name: str = Field(alias="DB_NAME")
username: str = Field(alias="DB_USER")
password: Optional[str] = Field(None, alias="DB_PASSWORD")
db_schema: str = Field("public", alias="DB_SCHEMA")
port: int = Field(5432, alias="DB_PORT")
hide_sql_parameter_logs: bool = Field(True, alias="HIDE_SQL_PARAMETER_LOGS")
ssl_mode: str = Field("require", alias="DB_SSL_MODE")

schema_prefix_override: str | None = Field(None)

def get_schema_translate_map(self) -> dict[str, str]:
prefix = ""
if self.schema_prefix_override is not None:
prefix = self.schema_prefix_override

return {schema: f"{prefix}{schema}" for schema in Schemas}


def get_db_config() -> PostgresDBConfig:
db_config = PostgresDBConfig()
@@ -31,7 +40,6 @@ def get_db_config() -> PostgresDBConfig:
"dbname": db_config.name,
"username": db_config.username,
"password": "***" if db_config.password is not None else None,
"db_schema": db_config.db_schema,
"port": db_config.port,
"hide_sql_parameter_logs": db_config.hide_sql_parameter_logs,
},
5 changes: 5 additions & 0 deletions api/src/constants/schema.py
@@ -0,0 +1,5 @@
from enum import StrEnum


class Schemas(StrEnum):
API = "api"
21 changes: 12 additions & 9 deletions api/src/data_migration/copy_oracle_data.py
@@ -4,6 +4,7 @@

import src.adapters.db as db
import src.adapters.db.flask_db as flask_db
from src.constants.schema import Schemas
from src.data_migration.data_migration_blueprint import data_migration_blueprint

logger = logging.getLogger(__name__)
@@ -20,10 +21,10 @@ class SqlCommands:
#################################

OPPORTUNITY_DELETE_QUERY = """
delete from transfer_topportunity
delete from {}.transfer_topportunity
"""
OPPORTUNITY_INSERT_QUERY = """
insert into transfer_topportunity
insert into {}.transfer_topportunity
select
opportunity_id,
oppnumber,
@@ -40,7 +41,7 @@ class SqlCommands:
last_upd_date,
creator_id,
created_date
from foreign_topportunity
from {}.foreign_topportunity
where is_draft = 'N'
"""

@@ -54,18 +55,20 @@ def copy_oracle_data(db_session: db.Session) -> None:

try:
with db_session.begin():
_run_copy_commands(db_session)
_run_copy_commands(db_session, Schemas.API)
except Exception:
logger.exception("Failed to run copy-oracle-data command")
raise

logger.info("Successfully ran copy-oracle-data")


def _run_copy_commands(db_session: db.Session) -> None:
def _run_copy_commands(db_session: db.Session, api_schema: str) -> None:
logger.info("Running copy commands for TOPPORTUNITY")

db_session.execute(text(SqlCommands.OPPORTUNITY_DELETE_QUERY))
db_session.execute(text(SqlCommands.OPPORTUNITY_INSERT_QUERY))
count = db_session.scalar(text("SELECT count(*) from transfer_topportunity"))
logger.info(f"Loaded {count} records into transfer_topportunity")
db_session.execute(text(SqlCommands.OPPORTUNITY_DELETE_QUERY.format(api_schema)))
db_session.execute(text(SqlCommands.OPPORTUNITY_INSERT_QUERY.format(api_schema, api_schema)))
count = db_session.scalar(
text(f"SELECT count(*) from {api_schema}.transfer_topportunity") # nosec
)
logger.info(f"Loaded {count} records into {api_schema}.transfer_topportunity")
15 changes: 12 additions & 3 deletions api/src/data_migration/setup_foreign_tables.py
@@ -6,6 +6,7 @@

import src.adapters.db as db
import src.adapters.db.flask_db as flask_db
from src.constants.schema import Schemas
from src.data_migration.data_migration_blueprint import data_migration_blueprint
from src.util.env_config import PydanticBaseEnvConfig

@@ -14,6 +15,7 @@

class ForeignTableConfig(PydanticBaseEnvConfig):
is_local_foreign_table: bool = Field(False)
schema_name: str = Field(Schemas.API)


@dataclass
@@ -62,7 +64,7 @@ def setup_foreign_tables(db_session: db.Session) -> None:
logger.info("Successfully ran setup-foreign-tables")


def build_sql(table_name: str, columns: list[Column], is_local: bool) -> str:
def build_sql(table_name: str, columns: list[Column], is_local: bool, schema_name: str) -> str:
"""
Build the SQL for creating a possibly foreign data table. If running
with is_local, it instead creates a regular table.
@@ -111,10 +113,17 @@ def build_sql(table_name: str, columns: list[Column], is_local: bool) -> str:
# We don't want the config at the end if we're running locally so unset it
create_command_suffix = ""

return f"{create_table_command} foreign_{table_name.lower()} ({','.join(column_sql_parts)}){create_command_suffix}"
return f"{create_table_command} {schema_name}.foreign_{table_name.lower()} ({','.join(column_sql_parts)}){create_command_suffix}"


def _run_create_table_commands(db_session: db.Session, config: ForeignTableConfig) -> None:
db_session.execute(
text(build_sql("TOPPORTUNITY", OPPORTUNITY_COLUMNS, config.is_local_foreign_table))
text(
build_sql(
"TOPPORTUNITY",
OPPORTUNITY_COLUMNS,
config.is_local_foreign_table,
config.schema_name,
)
)
)
8 changes: 7 additions & 1 deletion api/src/db/migrations/env.py
@@ -6,6 +6,7 @@

import src.adapters.db as db
import src.logging
from src.constants.schema import Schemas
from src.db.models import metadata

from src.adapters.db.type_decorators.postgres_type_decorators import LookupColumn # isort:skip
@@ -36,6 +37,10 @@ def include_object(
reflected: bool,
compare_to: Any,
) -> bool:
# We don't want alembic to try and drop its own table
if name == "alembic_version":
return False

if type_ == "schema" and getattr(object, "schema", None) is not None:
return False
if type_ == "table" and name is not None and name.startswith("foreign_"):
@@ -69,10 +74,11 @@ def run_migrations_online() -> None:
context.configure(
connection=connection,
target_metadata=target_metadata,
include_schemas=False,
include_schemas=True,
include_object=include_object,
compare_type=True,
render_item=render_item,
version_table_schema=Schemas.API,
)
with context.begin_transaction():
context.run_migrations()
27 changes: 27 additions & 0 deletions api/src/db/migrations/setup_local_postgres_db.py
@@ -0,0 +1,27 @@
import logging

from sqlalchemy import text

import src.adapters.db as db
import src.logging
from src.adapters.db import PostgresDBClient
from src.constants.schema import Schemas
from src.util.local import error_if_not_local

logger = logging.getLogger(__name__)


def setup_local_postgres_db() -> None:
with src.logging.init(__package__):
error_if_not_local()

db_client = PostgresDBClient()

with db_client.get_connection() as conn, conn.begin():
for schema in Schemas:
_create_schema(conn, schema)


def _create_schema(conn: db.Connection, schema_name: str) -> None:
logger.info("Creating schema %s if it does not already exist", schema_name)
conn.execute(text(f"CREATE SCHEMA IF NOT EXISTS {schema_name}"))
(Alembic migration file, name not shown in this view)
@@ -38,11 +38,12 @@ def upgrade():
nullable=False,
),
sa.PrimaryKeyConstraint("opportunity_id", name=op.f("topportunity_pkey")),
schema="api",
)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("topportunity")
op.drop_table("topportunity", schema="api")
# ### end Alembic commands ###
(Alembic migration file, name not shown in this view)
@@ -39,8 +39,9 @@ def upgrade():
nullable=False,
),
sa.PrimaryKeyConstraint("opportunity_id", name=op.f("opportunity_pkey")),
schema="api",
)
op.drop_table("topportunity")
op.drop_table("topportunity", schema="api")
# ### end Alembic commands ###


@@ -69,6 +70,7 @@ def downgrade():
nullable=False,
),
sa.PrimaryKeyConstraint("opportunity_id", name="topportunity_pkey"),
schema="api",
)
op.drop_table("opportunity")
op.drop_table("opportunity", schema="api")
# ### end Alembic commands ###
(Alembic migration file, name not shown in this view)
@@ -17,30 +17,45 @@

def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("opportunity", sa.Column("category_explanation", sa.Text(), nullable=True))
op.add_column("opportunity", sa.Column("revision_number", sa.Integer(), nullable=True))
op.add_column("opportunity", sa.Column("modified_comments", sa.Text(), nullable=True))
op.add_column("opportunity", sa.Column("publisher_user_id", sa.Integer(), nullable=True))
op.add_column("opportunity", sa.Column("publisher_profile_id", sa.Integer(), nullable=True))
op.create_index(op.f("opportunity_category_idx"), "opportunity", ["category"], unique=False)
op.create_index(op.f("opportunity_is_draft_idx"), "opportunity", ["is_draft"], unique=False)
op.add_column(
"opportunity", sa.Column("category_explanation", sa.Text(), nullable=True), schema="api"
)
op.add_column(
"opportunity", sa.Column("revision_number", sa.Integer(), nullable=True), schema="api"
)
op.add_column(
"opportunity", sa.Column("modified_comments", sa.Text(), nullable=True), schema="api"
)
op.add_column(
"opportunity", sa.Column("publisher_user_id", sa.Integer(), nullable=True), schema="api"
)
op.add_column(
"opportunity", sa.Column("publisher_profile_id", sa.Integer(), nullable=True), schema="api"
)
op.create_index(
op.f("opportunity_category_idx"), "opportunity", ["category"], unique=False, schema="api"
)
op.create_index(
op.f("opportunity_is_draft_idx"), "opportunity", ["is_draft"], unique=False, schema="api"
)
op.create_index(
op.f("opportunity_opportunity_title_idx"),
"opportunity",
["opportunity_title"],
unique=False,
schema="api",
)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("opportunity_opportunity_title_idx"), table_name="opportunity")
op.drop_index(op.f("opportunity_is_draft_idx"), table_name="opportunity")
op.drop_index(op.f("opportunity_category_idx"), table_name="opportunity")
op.drop_column("opportunity", "publisher_profile_id")
op.drop_column("opportunity", "publisher_user_id")
op.drop_column("opportunity", "modified_comments")
op.drop_column("opportunity", "revision_number")
op.drop_column("opportunity", "category_explanation")
op.drop_index(op.f("opportunity_opportunity_title_idx"), table_name="opportunity", schema="api")
op.drop_index(op.f("opportunity_is_draft_idx"), table_name="opportunity", schema="api")
op.drop_index(op.f("opportunity_category_idx"), table_name="opportunity", schema="api")
op.drop_column("opportunity", "publisher_profile_id", schema="api")
op.drop_column("opportunity", "publisher_user_id", schema="api")
op.drop_column("opportunity", "modified_comments", schema="api")
op.drop_column("opportunity", "revision_number", schema="api")
op.drop_column("opportunity", "category_explanation", schema="api")
# ### end Alembic commands ###