Skip to content

Commit

Permalink
Changing SQL conversion to always use pretty=True (#248)
Browse files Browse the repository at this point in the history
This drastically improves the readability of the generated SQL text when debugging, especially when an error occurs during runtime execution of the query and its cause is unclear without inspecting the generated SQL text. Since this changed a lot of tests, all of the SQL generation tests were also moved to SQL files in a dedicated folder; these files are either read or written depending on the same environment variable used by the planner tests.
  • Loading branch information
knassre-bodo authored Feb 6, 2025
1 parent f9b2dce commit 0482148
Show file tree
Hide file tree
Showing 48 changed files with 796 additions and 84 deletions.
10 changes: 2 additions & 8 deletions pydough/sqlglot/execute_relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def convert_relation_to_sql(
relational: RelationalRoot,
dialect: SQLGlotDialect,
bindings: SqlGlotTransformBindings,
pretty_print_sql: bool = False,
) -> str:
"""
Convert the given relational tree to a SQL string using the given dialect.
Expand All @@ -44,7 +43,7 @@ def convert_relation_to_sql(
glot_expr: SQLGlotExpression = SQLGlotRelationalVisitor(
dialect, bindings
).relational_to_sqlglot(relational)
return glot_expr.sql(dialect, pretty=pretty_print_sql)
return glot_expr.sql(dialect, pretty=True)


def convert_dialect_to_sqlglot(dialect: DatabaseDialect) -> SQLGlotDialect:
Expand Down Expand Up @@ -88,12 +87,7 @@ def execute_df(
The result of the query as a Pandas DataFrame
"""
sqlglot_dialect: SQLGlotDialect = convert_dialect_to_sqlglot(ctx.dialect)
pretty_print_sql: bool = False
if display_sql:
pretty_print_sql = True
sql: str = convert_relation_to_sql(
relational, sqlglot_dialect, bindings, pretty_print_sql
)
sql: str = convert_relation_to_sql(relational, sqlglot_dialect, bindings)
if display_sql:
pyd_logger = get_logger(__name__)
pyd_logger.info(f"SQL query:\n {sql}")
Expand Down
21 changes: 17 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,25 @@ def impl(file_name: str) -> str:
return impl


@pytest.fixture(scope="session")
def get_sql_test_filename() -> Callable[[str], str]:
    """
    Session-scoped fixture producing a helper that maps a SQL test name to
    the path of its refsol file. The path points inside the
    `test_sql_refsols` directory that sits next to this file, and the
    `.sql` extension is appended to the given name.
    """

    def resolve_refsol_path(file_name: str) -> str:
        # The directory is derived from this module's own location, so the
        # resulting path does not depend on the current working directory.
        return f"{os.path.dirname(__file__)}/test_sql_refsols/{file_name}.sql"

    return resolve_refsol_path


@pytest.fixture
def update_plan_tests() -> bool:
def update_tests() -> bool:
"""
If True, planner tests should update the refsol file instead of verifying
that the test matches the file. If False, the refsol file is used to check
the answer.
If True, planner/sql tests should update the refsol file instead of
verifying that the test matches the file. If False, the refsol file is used
to check the answer.
This is controlled by an environment variable `PYDOUGH_UPDATE_TESTS`.
"""
Expand Down
4 changes: 2 additions & 2 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ def test_pipeline_until_relational(
get_sample_graph: graph_fetcher,
default_config: PyDoughConfigs,
get_plan_test_filename: Callable[[str], str],
update_plan_tests: bool,
update_tests: bool,
) -> None:
"""
Tests that a PyDough unqualified node can be correctly translated to its
Expand All @@ -674,7 +674,7 @@ def test_pipeline_until_relational(
qualified, PyDoughCollectionQDAG
), "Expected qualified answer to be a collection, not an expression"
relational: RelationalRoot = convert_ast_to_relational(qualified, default_config)
if update_plan_tests:
if update_tests:
with open(file_path, "w") as f:
f.write(relational.to_tree_string() + "\n")
else:
Expand Down
53 changes: 35 additions & 18 deletions tests/test_pydough_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,42 +25,41 @@


@pytest.mark.parametrize(
"pydough_code, expected_sql",
# Note: All of these tests are for simple code because the
# exact SQL generation for inner expressions is currently
# non-deterministic.
"pydough_code, test_name",
[
pytest.param(
simple_scan,
"SELECT o_orderkey AS key FROM tpch.ORDERS",
"simple_scan",
id="simple_scan",
),
pytest.param(
simple_filter,
"SELECT o_orderkey, o_totalprice FROM (SELECT o_orderkey AS o_orderkey, o_totalprice AS o_totalprice FROM tpch.ORDERS) WHERE o_totalprice < 1000.0",
"simple_filter",
id="simple_filter",
),
pytest.param(
rank_a,
"SELECT ROW_NUMBER() OVER (ORDER BY acctbal DESC NULLS FIRST) AS rank FROM (SELECT c_acctbal AS acctbal FROM tpch.CUSTOMER)",
"rank_a",
id="rank_a",
),
pytest.param(
rank_b,
" SELECT RANK() OVER (ORDER BY order_priority NULLS LAST) AS rank FROM (SELECT o_orderpriority AS order_priority FROM tpch.ORDERS)",
"rank_b",
id="rank_b",
),
pytest.param(
rank_c,
"SELECT order_date, DENSE_RANK() OVER (ORDER BY order_date NULLS LAST) AS rank FROM (SELECT o_orderdate AS order_date FROM tpch.ORDERS)",
"rank_c",
id="rank_c",
),
],
)
def test_pydough_to_sql(
def test_pydough_to_sql_tpch(
pydough_code: Callable[[], UnqualifiedNode],
expected_sql: str,
test_name: str,
get_sample_graph: graph_fetcher,
get_sql_test_filename: Callable[[str], str],
update_tests: bool,
) -> None:
"""
Tests that a PyDough unqualified node can be correctly translated to its
Expand All @@ -69,26 +68,36 @@ def test_pydough_to_sql(
graph: GraphMetadata = get_sample_graph("TPCH")
root: UnqualifiedNode = init_pydough_context(graph)(pydough_code)()
actual_sql: str = to_sql(root, metadata=graph).strip()
expected_sql = expected_sql.strip()
assert actual_sql == expected_sql
file_path: str = get_sql_test_filename(test_name)
if update_tests:
with open(file_path, "w") as f:
f.write(actual_sql + "\n")
else:
with open(file_path) as f:
expected_relational_string: str = f.read()
assert (
actual_sql == expected_relational_string.strip()
), "Mismatch between tree generated SQL text and expected SQL text"


@pytest.mark.parametrize(
"pydough_code,expected_sql,graph_name",
"pydough_code,test_name,graph_name",
[
pytest.param(
hour_minute_day,
"""SELECT transaction_id, _expr0, _expr1, _expr2 FROM (SELECT transaction_id AS ordering_0, _expr0, _expr1, _expr2, transaction_id FROM (SELECT _expr0, _expr1, _expr2, symbol, transaction_id FROM (SELECT EXTRACT(HOUR FROM date_time) AS _expr0, EXTRACT(MINUTE FROM date_time) AS _expr1, EXTRACT(SECOND FROM date_time) AS _expr2, ticker_id, transaction_id FROM (SELECT sbTxDateTime AS date_time, sbTxId AS transaction_id, sbTxTickerId AS ticker_id FROM main.sbTransaction)) LEFT JOIN (SELECT sbTickerId AS _id, sbTickerSymbol AS symbol FROM main.sbTicker) ON ticker_id = _id) WHERE symbol IN ('AAPL', 'GOOGL', 'NFLX')) ORDER BY ordering_0""",
"hour_minute_day",
"Broker",
id="hour_minute_day",
),
],
)
def test_pydough_to_sql_defog(
pydough_code: Callable[[], UnqualifiedNode],
expected_sql: str,
test_name: str,
graph_name: str,
defog_graphs: graph_fetcher,
get_sql_test_filename: Callable[[str], str],
update_tests: bool,
) -> None:
"""
Tests that a PyDough unqualified node can be correctly translated to its
Expand All @@ -97,5 +106,13 @@ def test_pydough_to_sql_defog(
graph: GraphMetadata = defog_graphs(graph_name)
root: UnqualifiedNode = init_pydough_context(graph)(pydough_code)()
actual_sql: str = to_sql(root, metadata=graph).strip()
expected_sql = expected_sql.strip()
assert actual_sql == expected_sql
file_path: str = get_sql_test_filename(test_name)
if update_tests:
with open(file_path, "w") as f:
f.write(actual_sql + "\n")
else:
with open(file_path) as f:
expected_relational_string: str = f.read()
assert (
actual_sql == expected_relational_string.strip()
), "Mismatch between tree generated SQL text and expected SQL text"
8 changes: 4 additions & 4 deletions tests/test_qdag_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2200,7 +2200,7 @@ def test_ast_to_relational(
tpch_node_builder: AstNodeBuilder,
default_config: PyDoughConfigs,
get_plan_test_filename: Callable[[str], str],
update_plan_tests: bool,
update_tests: bool,
) -> None:
"""
Tests whether the QDAG nodes are correctly translated into Relational nodes
Expand All @@ -2210,7 +2210,7 @@ def test_ast_to_relational(
file_path: str = get_plan_test_filename(file_name)
collection: PyDoughCollectionQDAG = calc_pipeline.build(tpch_node_builder)
relational = convert_ast_to_relational(collection, default_config)
if update_plan_tests:
if update_tests:
with open(file_path, "w") as f:
f.write(relational.to_tree_string() + "\n")
else:
Expand Down Expand Up @@ -2336,7 +2336,7 @@ def test_ast_to_relational_alternative_aggregation_configs(
tpch_node_builder: AstNodeBuilder,
default_config: PyDoughConfigs,
get_plan_test_filename: Callable[[str], str],
update_plan_tests: bool,
update_tests: bool,
) -> None:
"""
Same as `test_ast_to_relational` but with various alternative aggregation
Expand All @@ -2350,7 +2350,7 @@ def test_ast_to_relational_alternative_aggregation_configs(
default_config.avg_default_zero = True
collection: PyDoughCollectionQDAG = calc_pipeline.build(tpch_node_builder)
relational = convert_ast_to_relational(collection, default_config)
if update_plan_tests:
if update_tests:
with open(file_path, "w") as f:
f.write(relational.to_tree_string() + "\n")
else:
Expand Down
Loading

0 comments on commit 0482148

Please sign in to comment.