Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

attempt input file parameter #2

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 47 additions & 34 deletions reorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ def get_dump_sql(database: str, schema: str, table: str) -> str:
return result.stdout.decode()


def get_columns(database: str, schema: str, table: str) -> Tuple[List[str], List[str]]:
def get_columns(sql_text: str, schema: str, table: str) -> Tuple[List[str], List[str]]:
"""Get columns for a table"""
sql_text = get_dump_sql(database, schema, table)
table_re = re.compile(
fr"(?P<pre>(?:\n|.)+CREATE TABLE {schema}.{table}\s+\(\n)(?P<rows>(?:\n|.)+?)(?P<post>\);(?:\n|.)+)"
)
Expand All @@ -42,38 +41,42 @@ def get_columns(database: str, schema: str, table: str) -> Tuple[List[str], List


def get_migration_sql(
database: str, schema: str, table: str, columns: List[str], extras: List[str]
sql_text: str, database: str, schema: str, table: str, columns: List[str], extras: List[str], use_input_file: bool
) -> str:
"""Get SQL command to migrate a source table into the target table"""
sql_text = get_dump_sql(database, schema, table)
table_re = re.compile(
fr"(?P<pre>(?:\n|.)+)(?P<table>CREATE TABLE {schema}\.{table}\s+\(\n(?:\n|.)+?\);)(?P<post>(?:\n|.)+)"
)
match = table_re.search(sql_text)

fk_disable = "\n".join(
[
cleandoc(
f"""
ALTER TABLE {fk['schema']}.{fk['local_table']}
DROP CONSTRAINT {fk['constraint']};
"""
)
for fk in get_foreign_keys(database, schema, table)
]
)
fk_enable = "\n".join(
[
cleandoc(
f"""
ALTER TABLE {fk['schema']}.{fk['local_table']}
ADD CONSTRAINT {fk['constraint']} FOREIGN KEY ({fk['local_column']})
REFERENCES {fk['schema']}.{fk['foreign_table']} ({fk['foreign_column']});
"""
)
for fk in get_foreign_keys(database, schema, table)
]
)
if use_input_file:
fk_disable = "\n".join(
[
cleandoc(
f"""
ALTER TABLE {fk['schema']}.{fk['local_table']}
DROP CONSTRAINT {fk['constraint']};
"""
)
for fk in get_foreign_keys(database, schema, table)
]
)
fk_enable = "\n".join(
[
cleandoc(
f"""
ALTER TABLE {fk['schema']}.{fk['local_table']}
ADD CONSTRAINT {fk['constraint']} FOREIGN KEY ({fk['local_column']})
REFERENCES {fk['schema']}.{fk['foreign_table']} ({fk['foreign_column']});
"""
)
for fk in get_foreign_keys(database, schema, table)
]
)
else:
fk_disable = "-- NO FK MANAGEMENT WHEN READING FROM FILE --"
fk_enable = "-- NO FK MANAGEMENT WHEN READING FROM FILE --"

extra_features = "\n".join(
[
f"ALTER TABLE {schema}.{table} ADD {row};"
Expand Down Expand Up @@ -148,7 +151,7 @@ def sort_input_columns(
def reorder_columns(
target_start: List[str],
target_end: List[str],
target_exclude: List[str],
target_exclude: Tuple[str],
columns: List[str],
) -> List[str]:
"""Given a lost of columns and several target lists, return a sorted list of columns"""
Expand Down Expand Up @@ -178,14 +181,16 @@ def printcols(cols: List[str], header: Optional[str] = None) -> None:
@click.option("--database", "-d", help="The name of the database.")
@click.option("--schema", "-n", default="public", help="The schema of the target table.")
@click.option("--migrate", "-m", is_flag=True, help="Output full migration sql.")
@click.option("--file", "-f", "output_file", type=click.File("w"), help="Write output into a file.")
@click.option("--input-file", "-i", "input_file", type=click.File("r"), help="Write output into a file.")
@click.option("--output-file", "-o", "output_file", type=click.File("w"), help="Write output into a file.")
@click.argument("table")
@click.argument("columns", nargs=-1)
def main(
migrate: bool,
exclude,
exclude: Tuple[str],
database: str,
schema,
schema: str,
input_file,
output_file,
table: str,
columns: Tuple[str],
Expand All @@ -198,14 +203,22 @@ def main(
and the last column will be placed at the end of the table. When entered as
"... col1 col2 col3" all three columns will be placed at the end of the table.
"""
cols, extras = get_columns(database, schema, table)

use_input_file = input_file is None
sql_text = get_dump_sql(database, schema, table) if use_input_file else input_file.read()
cols, extras = get_columns(sql_text, schema, table)

if len(columns):
target_start, target_end = sort_input_columns(list(columns))
cols = reorder_columns(target_start, target_end, list(exclude), cols)
cols = reorder_columns(
target_start,
target_end,
exclude,
cols
)

if migrate:
query = get_migration_sql(database, schema, table, cols, extras)
query = get_migration_sql(sql_text, database, schema, table, cols, extras, use_input_file=use_input_file)

if output_file is not None:
output_file.write(query)
Expand Down