Skip to content

Commit f71d586

Browse files
committed
/* PR_START p--sql 03 */ Fix handling of some expressions when column-pruning CTEs.
1 parent 0d8411f commit f71d586

File tree

2 files changed

+116
-0
lines changed

2 files changed

+116
-0
lines changed

metricflow/sql/optimizer/column_pruning/required_column_aliases.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,13 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None:
252252
join_desc.right_source for join_desc in node.join_descs
253253
):
254254
self._current_required_column_alias_mapping.add_aliases(node_to_retain_columns, column_aliases_to_retain)
255+
sql_table_node = node_to_retain_columns.as_sql_table_node
256+
if sql_table_node is not None and sql_table_node.sql_table.schema_name is None:
257+
self._map_required_column_aliases_in_potential_cte(
258+
cte_alias_mapping=cte_alias_mapping,
259+
table_name=sql_table_node.sql_table.table_name,
260+
column_aliases=column_aliases_to_retain,
261+
)
255262

256263
# Visit recursively.
257264
self._visit_parents(node)

tests_metricflow/sql/optimizer/test_cte_column_pruner.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
SqlColumnReferenceExpression,
1111
SqlComparison,
1212
SqlComparisonExpression,
13+
SqlStringExpression,
1314
)
1415
from metricflow_semantics.sql.sql_join_type import SqlJoinType
1516
from metricflow_semantics.sql.sql_table import SqlTable
@@ -464,3 +465,111 @@ def test_common_cte_aliases_in_nested_query(
464465
"""
465466
),
466467
)
468+
469+
470+
def test_string_expression(
471+
request: FixtureRequest,
472+
mf_test_configuration: MetricFlowTestConfiguration,
473+
column_pruner: SqlColumnPrunerOptimizer,
474+
sql_plan_renderer: DefaultSqlPlanRenderer,
475+
) -> None:
476+
"""Test a string expression that references a column in the cte."""
477+
select_statement = SqlSelectStatementNode.create(
478+
description="Top-level SELECT",
479+
select_columns=(
480+
SqlSelectColumn(
481+
expr=SqlStringExpression.create(sql_expr="cte_source_0__col_0", used_columns=("cte_source_0__col_0",)),
482+
column_alias="top_level__col_0",
483+
),
484+
),
485+
from_source=SqlTableNode.create(sql_table=SqlTable(schema_name=None, table_name="cte_source_0")),
486+
from_source_alias="cte_source_0_alias",
487+
cte_sources=(
488+
SqlCteNode.create(
489+
cte_alias="cte_source_0",
490+
select_statement=SqlSelectStatementNode.create(
491+
description="CTE source 0",
492+
select_columns=(
493+
SqlSelectColumn(
494+
expr=SqlColumnReferenceExpression.create(
495+
col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0")
496+
),
497+
column_alias="cte_source_0__col_0",
498+
),
499+
SqlSelectColumn(
500+
expr=SqlColumnReferenceExpression.create(
501+
col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0")
502+
),
503+
column_alias="cte_source_0__col_1",
504+
),
505+
),
506+
from_source=SqlTableNode.create(
507+
sql_table=SqlTable(schema_name="test_schema", table_name="test_table")
508+
),
509+
from_source_alias="test_table_alias",
510+
),
511+
),
512+
),
513+
)
514+
assert_optimizer_result_snapshot_equal(
515+
request=request,
516+
mf_test_configuration=mf_test_configuration,
517+
optimizer=column_pruner,
518+
sql_plan_renderer=sql_plan_renderer,
519+
select_statement=select_statement,
520+
expectation_description="`cte_source_0__col_01` should be retained in the CTE.",
521+
)
522+
523+
524+
def test_column_reference_expression(
525+
request: FixtureRequest,
526+
mf_test_configuration: MetricFlowTestConfiguration,
527+
column_pruner: SqlColumnPrunerOptimizer,
528+
sql_plan_renderer: DefaultSqlPlanRenderer,
529+
) -> None:
530+
"""Test a column reference expression that does not specify a table alias."""
531+
select_statement = SqlSelectStatementNode.create(
532+
description="Top-level SELECT",
533+
select_columns=(
534+
SqlSelectColumn(
535+
expr=SqlStringExpression.create(sql_expr="cte_source_0__col_0", used_columns=("cte_source_0__col_0",)),
536+
column_alias="top_level__col_0",
537+
),
538+
),
539+
from_source=SqlTableNode.create(sql_table=SqlTable(schema_name=None, table_name="cte_source_0")),
540+
from_source_alias="cte_source_0_alias",
541+
cte_sources=(
542+
SqlCteNode.create(
543+
cte_alias="cte_source_0",
544+
select_statement=SqlSelectStatementNode.create(
545+
description="CTE source 0",
546+
select_columns=(
547+
SqlSelectColumn(
548+
expr=SqlColumnReferenceExpression.create(
549+
col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0")
550+
),
551+
column_alias="cte_source_0__col_0",
552+
),
553+
SqlSelectColumn(
554+
expr=SqlColumnReferenceExpression.create(
555+
col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0")
556+
),
557+
column_alias="cte_source_0__col_1",
558+
),
559+
),
560+
from_source=SqlTableNode.create(
561+
sql_table=SqlTable(schema_name="test_schema", table_name="test_table")
562+
),
563+
from_source_alias="test_table_alias",
564+
),
565+
),
566+
),
567+
)
568+
assert_optimizer_result_snapshot_equal(
569+
request=request,
570+
mf_test_configuration=mf_test_configuration,
571+
optimizer=column_pruner,
572+
sql_plan_renderer=sql_plan_renderer,
573+
select_statement=select_statement,
574+
expectation_description="`cte_source_0__col_01` should be retained in the CTE.",
575+
)

0 commit comments

Comments
 (0)