diff --git a/reorder.py b/reorder.py index fa01a04..46738c8 100755 --- a/reorder.py +++ b/reorder.py @@ -20,9 +20,8 @@ def get_dump_sql(database: str, schema: str, table: str) -> str: return result.stdout.decode() -def get_columns(database: str, schema: str, table: str) -> Tuple[List[str], List[str]]: +def get_columns(sql_text: str, schema: str, table: str) -> Tuple[List[str], List[str]]: """Get columns for a table""" - sql_text = get_dump_sql(database, schema, table) table_re = re.compile( fr"(?P
(?:\n|.)+CREATE TABLE {schema}.{table}\s+\(\n)(?P(?:\n|.)+?)(?P\);(?:\n|.)+)"
     )
@@ -42,38 +41,42 @@ def get_columns(database: str, schema: str, table: str) -> Tuple[List[str], List
 
 
 def get_migration_sql(
-    database: str, schema: str, table: str, columns: List[str], extras: List[str]
+    sql_text: str, database: str, schema: str, table: str, columns: List[str], extras: List[str], use_input_file: bool
 ) -> str:
     """Get SQL command to migrate a source table into the target table"""
-    sql_text = get_dump_sql(database, schema, table)
     table_re = re.compile(
         fr"(?P
(?:\n|.)+)(?PCREATE TABLE {schema}\.{table}\s+\(\n(?:\n|.)+?\);)(?P(?:\n|.)+)"
     )
     match = table_re.search(sql_text)
 
-    fk_disable = "\n".join(
-        [
-            cleandoc(
-                f"""
-            ALTER TABLE {fk['schema']}.{fk['local_table']}
-            DROP CONSTRAINT {fk['constraint']};
-        """
-            )
-            for fk in get_foreign_keys(database, schema, table)
-        ]
-    )
-    fk_enable = "\n".join(
-        [
-            cleandoc(
-                f"""
-            ALTER TABLE {fk['schema']}.{fk['local_table']}
-            ADD CONSTRAINT {fk['constraint']} FOREIGN KEY ({fk['local_column']})
-            REFERENCES {fk['schema']}.{fk['foreign_table']} ({fk['foreign_column']});
-        """
-            )
-            for fk in get_foreign_keys(database, schema, table)
-        ]
-    )
+    if use_input_file:
+        fk_disable = "\n".join(
+            [
+                cleandoc(
+                    f"""
+                ALTER TABLE {fk['schema']}.{fk['local_table']}
+                DROP CONSTRAINT {fk['constraint']};
+            """
+                )
+                for fk in get_foreign_keys(database, schema, table)
+            ]
+        )
+        fk_enable = "\n".join(
+            [
+                cleandoc(
+                    f"""
+                ALTER TABLE {fk['schema']}.{fk['local_table']}
+                ADD CONSTRAINT {fk['constraint']} FOREIGN KEY ({fk['local_column']})
+                REFERENCES {fk['schema']}.{fk['foreign_table']} ({fk['foreign_column']});
+            """
+                )
+                for fk in get_foreign_keys(database, schema, table)
+            ]
+        )
+    else:
+        fk_disable = "-- NO FK MANAGEMENT WHEN READING FROM FILE --"
+        fk_enable = "-- NO FK MANAGEMENT WHEN READING FROM FILE --"
+
     extra_features = "\n".join(
         [
             f"ALTER TABLE {schema}.{table} ADD {row};"
@@ -148,7 +151,7 @@ def sort_input_columns(
 def reorder_columns(
     target_start: List[str],
     target_end: List[str],
-    target_exclude: List[str],
+    target_exclude: Tuple[str],
     columns: List[str],
 ) -> List[str]:
     """Given a lost of columns and several target lists, return a sorted list of columns"""
@@ -178,14 +181,16 @@ def printcols(cols: List[str], header: Optional[str] = None) -> None:
 @click.option("--database", "-d", help="The name of the database.")
 @click.option("--schema", "-n", default="public", help="The schema of the target table.")
 @click.option("--migrate", "-m", is_flag=True, help="Output full migration sql.")
-@click.option("--file", "-f", "output_file", type=click.File("w"), help="Write output into a file.")
+@click.option("--input-file", "-i", "input_file", type=click.File("r"), help="Write output into a file.")
+@click.option("--output-file", "-o", "output_file", type=click.File("w"), help="Write output into a file.")
 @click.argument("table")
 @click.argument("columns", nargs=-1)
 def main(
     migrate: bool,
-    exclude,
+    exclude: Tuple[str],
     database: str,
-    schema,
+    schema: str,
+    input_file,
     output_file,
     table: str,
     columns: Tuple[str],
@@ -198,14 +203,22 @@ def main(
     and the last column will be placed at the end of the table. When entered as
     "... col1 col2 col3" all three columns will be placed at the end of the table.
     """
-    cols, extras = get_columns(database, schema, table)
+
+    use_input_file = input_file is None
+    sql_text = get_dump_sql(database, schema, table) if use_input_file else input_file.read()
+    cols, extras = get_columns(sql_text, schema, table)
 
     if len(columns):
         target_start, target_end = sort_input_columns(list(columns))
-        cols = reorder_columns(target_start, target_end, list(exclude), cols)
+        cols = reorder_columns(
+            target_start,
+            target_end,
+            exclude,
+            cols
+        )
 
         if migrate:
-            query = get_migration_sql(database, schema, table, cols, extras)
+            query = get_migration_sql(sql_text, database, schema, table, cols, extras, use_input_file=use_input_file)
 
             if output_file is not None:
                 output_file.write(query)