Skip to content

Commit 0482148

Browse files
authored
Changing SQL conversion to always use pretty=True (#248)
This drastically increases the readability of the generated SQL text when debugging, especially when an error occurs during runtime execution of the query and its cause is unclear unless you look at the generated SQL text. Since this changed a lot of tests, all of the SQL generation tests were also moved to SQL files in a dedicated folder; those files are either read from or written to depending on the same environment variable used by the planner tests.
1 parent f9b2dce commit 0482148

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+796
-84
lines changed

pydough/sqlglot/execute_relational.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ def convert_relation_to_sql(
2626
relational: RelationalRoot,
2727
dialect: SQLGlotDialect,
2828
bindings: SqlGlotTransformBindings,
29-
pretty_print_sql: bool = False,
3029
) -> str:
3130
"""
3231
Convert the given relational tree to a SQL string using the given dialect.
@@ -44,7 +43,7 @@ def convert_relation_to_sql(
4443
glot_expr: SQLGlotExpression = SQLGlotRelationalVisitor(
4544
dialect, bindings
4645
).relational_to_sqlglot(relational)
47-
return glot_expr.sql(dialect, pretty=pretty_print_sql)
46+
return glot_expr.sql(dialect, pretty=True)
4847

4948

5049
def convert_dialect_to_sqlglot(dialect: DatabaseDialect) -> SQLGlotDialect:
@@ -88,12 +87,7 @@ def execute_df(
8887
The result of the query as a Pandas DataFrame
8988
"""
9089
sqlglot_dialect: SQLGlotDialect = convert_dialect_to_sqlglot(ctx.dialect)
91-
pretty_print_sql: bool = False
92-
if display_sql:
93-
pretty_print_sql = True
94-
sql: str = convert_relation_to_sql(
95-
relational, sqlglot_dialect, bindings, pretty_print_sql
96-
)
90+
sql: str = convert_relation_to_sql(relational, sqlglot_dialect, bindings)
9791
if display_sql:
9892
pyd_logger = get_logger(__name__)
9993
pyd_logger.info(f"SQL query:\n {sql}")

tests/conftest.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,12 +152,25 @@ def impl(file_name: str) -> str:
152152
return impl
153153

154154

155+
@pytest.fixture(scope="session")
156+
def get_sql_test_filename() -> Callable[[str], str]:
157+
"""
158+
A function that takes in a file name and returns the path to that file
159+
from within the directory of SQL text testing refsol files.
160+
"""
161+
162+
def impl(file_name: str) -> str:
163+
return f"{os.path.dirname(__file__)}/test_sql_refsols/{file_name}.sql"
164+
165+
return impl
166+
167+
155168
@pytest.fixture
156-
def update_plan_tests() -> bool:
169+
def update_tests() -> bool:
157170
"""
158-
If True, planner tests should update the refsol file instead of verifying
159-
that the test matches the file. If False, the refsol file is used to check
160-
the answer.
171+
If True, planner/sql tests should update the refsol file instead of
172+
verifying that the test matches the file. If False, the refsol file is used
173+
to check the answer.
161174
162175
This is controlled by an environment variable `PYDOUGH_UPDATE_TESTS`.
163176
"""

tests/test_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,7 @@ def test_pipeline_until_relational(
655655
get_sample_graph: graph_fetcher,
656656
default_config: PyDoughConfigs,
657657
get_plan_test_filename: Callable[[str], str],
658-
update_plan_tests: bool,
658+
update_tests: bool,
659659
) -> None:
660660
"""
661661
Tests that a PyDough unqualified node can be correctly translated to its
@@ -674,7 +674,7 @@ def test_pipeline_until_relational(
674674
qualified, PyDoughCollectionQDAG
675675
), "Expected qualified answer to be a collection, not an expression"
676676
relational: RelationalRoot = convert_ast_to_relational(qualified, default_config)
677-
if update_plan_tests:
677+
if update_tests:
678678
with open(file_path, "w") as f:
679679
f.write(relational.to_tree_string() + "\n")
680680
else:

tests/test_pydough_to_sql.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,42 +25,41 @@
2525

2626

2727
@pytest.mark.parametrize(
28-
"pydough_code, expected_sql",
29-
# Note: All of these tests are for simple code because the
30-
# exact SQL generation for inner expressions is currently
31-
# non-deterministic.
28+
"pydough_code, test_name",
3229
[
3330
pytest.param(
3431
simple_scan,
35-
"SELECT o_orderkey AS key FROM tpch.ORDERS",
32+
"simple_scan",
3633
id="simple_scan",
3734
),
3835
pytest.param(
3936
simple_filter,
40-
"SELECT o_orderkey, o_totalprice FROM (SELECT o_orderkey AS o_orderkey, o_totalprice AS o_totalprice FROM tpch.ORDERS) WHERE o_totalprice < 1000.0",
37+
"simple_filter",
4138
id="simple_filter",
4239
),
4340
pytest.param(
4441
rank_a,
45-
"SELECT ROW_NUMBER() OVER (ORDER BY acctbal DESC NULLS FIRST) AS rank FROM (SELECT c_acctbal AS acctbal FROM tpch.CUSTOMER)",
42+
"rank_a",
4643
id="rank_a",
4744
),
4845
pytest.param(
4946
rank_b,
50-
" SELECT RANK() OVER (ORDER BY order_priority NULLS LAST) AS rank FROM (SELECT o_orderpriority AS order_priority FROM tpch.ORDERS)",
47+
"rank_b",
5148
id="rank_b",
5249
),
5350
pytest.param(
5451
rank_c,
55-
"SELECT order_date, DENSE_RANK() OVER (ORDER BY order_date NULLS LAST) AS rank FROM (SELECT o_orderdate AS order_date FROM tpch.ORDERS)",
52+
"rank_c",
5653
id="rank_c",
5754
),
5855
],
5956
)
60-
def test_pydough_to_sql(
57+
def test_pydough_to_sql_tpch(
6158
pydough_code: Callable[[], UnqualifiedNode],
62-
expected_sql: str,
59+
test_name: str,
6360
get_sample_graph: graph_fetcher,
61+
get_sql_test_filename: Callable[[str], str],
62+
update_tests: bool,
6463
) -> None:
6564
"""
6665
Tests that a PyDough unqualified node can be correctly translated to its
@@ -69,26 +68,36 @@ def test_pydough_to_sql(
6968
graph: GraphMetadata = get_sample_graph("TPCH")
7069
root: UnqualifiedNode = init_pydough_context(graph)(pydough_code)()
7170
actual_sql: str = to_sql(root, metadata=graph).strip()
72-
expected_sql = expected_sql.strip()
73-
assert actual_sql == expected_sql
71+
file_path: str = get_sql_test_filename(test_name)
72+
if update_tests:
73+
with open(file_path, "w") as f:
74+
f.write(actual_sql + "\n")
75+
else:
76+
with open(file_path) as f:
77+
expected_relational_string: str = f.read()
78+
assert (
79+
actual_sql == expected_relational_string.strip()
80+
), "Mismatch between tree generated SQL text and expected SQL text"
7481

7582

7683
@pytest.mark.parametrize(
77-
"pydough_code,expected_sql,graph_name",
84+
"pydough_code,test_name,graph_name",
7885
[
7986
pytest.param(
8087
hour_minute_day,
81-
"""SELECT transaction_id, _expr0, _expr1, _expr2 FROM (SELECT transaction_id AS ordering_0, _expr0, _expr1, _expr2, transaction_id FROM (SELECT _expr0, _expr1, _expr2, symbol, transaction_id FROM (SELECT EXTRACT(HOUR FROM date_time) AS _expr0, EXTRACT(MINUTE FROM date_time) AS _expr1, EXTRACT(SECOND FROM date_time) AS _expr2, ticker_id, transaction_id FROM (SELECT sbTxDateTime AS date_time, sbTxId AS transaction_id, sbTxTickerId AS ticker_id FROM main.sbTransaction)) LEFT JOIN (SELECT sbTickerId AS _id, sbTickerSymbol AS symbol FROM main.sbTicker) ON ticker_id = _id) WHERE symbol IN ('AAPL', 'GOOGL', 'NFLX')) ORDER BY ordering_0""",
88+
"hour_minute_day",
8289
"Broker",
8390
id="hour_minute_day",
8491
),
8592
],
8693
)
8794
def test_pydough_to_sql_defog(
8895
pydough_code: Callable[[], UnqualifiedNode],
89-
expected_sql: str,
96+
test_name: str,
9097
graph_name: str,
9198
defog_graphs: graph_fetcher,
99+
get_sql_test_filename: Callable[[str], str],
100+
update_tests: bool,
92101
) -> None:
93102
"""
94103
Tests that a PyDough unqualified node can be correctly translated to its
@@ -97,5 +106,13 @@ def test_pydough_to_sql_defog(
97106
graph: GraphMetadata = defog_graphs(graph_name)
98107
root: UnqualifiedNode = init_pydough_context(graph)(pydough_code)()
99108
actual_sql: str = to_sql(root, metadata=graph).strip()
100-
expected_sql = expected_sql.strip()
101-
assert actual_sql == expected_sql
109+
file_path: str = get_sql_test_filename(test_name)
110+
if update_tests:
111+
with open(file_path, "w") as f:
112+
f.write(actual_sql + "\n")
113+
else:
114+
with open(file_path) as f:
115+
expected_relational_string: str = f.read()
116+
assert (
117+
actual_sql == expected_relational_string.strip()
118+
), "Mismatch between tree generated SQL text and expected SQL text"

tests/test_qdag_conversion.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2200,7 +2200,7 @@ def test_ast_to_relational(
22002200
tpch_node_builder: AstNodeBuilder,
22012201
default_config: PyDoughConfigs,
22022202
get_plan_test_filename: Callable[[str], str],
2203-
update_plan_tests: bool,
2203+
update_tests: bool,
22042204
) -> None:
22052205
"""
22062206
Tests whether the QDAG nodes are correctly translated into Relational nodes
@@ -2210,7 +2210,7 @@ def test_ast_to_relational(
22102210
file_path: str = get_plan_test_filename(file_name)
22112211
collection: PyDoughCollectionQDAG = calc_pipeline.build(tpch_node_builder)
22122212
relational = convert_ast_to_relational(collection, default_config)
2213-
if update_plan_tests:
2213+
if update_tests:
22142214
with open(file_path, "w") as f:
22152215
f.write(relational.to_tree_string() + "\n")
22162216
else:
@@ -2336,7 +2336,7 @@ def test_ast_to_relational_alternative_aggregation_configs(
23362336
tpch_node_builder: AstNodeBuilder,
23372337
default_config: PyDoughConfigs,
23382338
get_plan_test_filename: Callable[[str], str],
2339-
update_plan_tests: bool,
2339+
update_tests: bool,
23402340
) -> None:
23412341
"""
23422342
Same as `test_ast_to_relational` but with various alternative aggregation
@@ -2350,7 +2350,7 @@ def test_ast_to_relational_alternative_aggregation_configs(
23502350
default_config.avg_default_zero = True
23512351
collection: PyDoughCollectionQDAG = calc_pipeline.build(tpch_node_builder)
23522352
relational = convert_ast_to_relational(collection, default_config)
2353-
if update_plan_tests:
2353+
if update_tests:
23542354
with open(file_path, "w") as f:
23552355
f.write(relational.to_tree_string() + "\n")
23562356
else:

0 commit comments

Comments
 (0)