From 58d6e6ebab169f0c50cbf8ad4625d38caaa9ebd0 Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Tue, 12 Nov 2024 21:04:43 -0800 Subject: [PATCH] Update `SqlTableAliasSimplifier` to handle CTEs (#1516) Similar to the updates for the other optimizers in https://github.com/dbt-labs/metricflow/pull/1503 and https://github.com/dbt-labs/metricflow/pull/1504, this PR implements the CTE case for the `SqlTableAliasSimplifier` --- .../sql/optimizer/table_alias_simplifier.py | 13 +- ..._table_alias_no_simplification__result.txt | 41 +++ ...est_table_alias_simplification__result.txt | 85 +++++++ ...lification__after_alias_simplification.sql | 2 +- ...ification__before_alias_simplification.sql | 2 +- .../test_cte_table_alias_simplifier.py | 235 ++++++++++++++++++ .../optimizer/test_table_alias_simplifier.py | 2 +- 7 files changed, 374 insertions(+), 6 deletions(-) create mode 100644 tests_metricflow/snapshots/test_cte_table_alias_simplifier.py/str/test_table_alias_no_simplification__result.txt create mode 100644 tests_metricflow/snapshots/test_cte_table_alias_simplifier.py/str/test_table_alias_simplification__result.txt create mode 100644 tests_metricflow/sql/optimizer/test_cte_table_alias_simplifier.py diff --git a/metricflow/sql/optimizer/table_alias_simplifier.py b/metricflow/sql/optimizer/table_alias_simplifier.py index 646bdd0f05..b503060fda 100644 --- a/metricflow/sql/optimizer/table_alias_simplifier.py +++ b/metricflow/sql/optimizer/table_alias_simplifier.py @@ -26,11 +26,11 @@ class SqlTableAliasSimplifierVisitor(SqlQueryPlanNodeVisitor[SqlQueryPlanNode]): @override def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode: - raise NotImplementedError + return node.with_new_select(node.select_statement.accept(self)) def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102 - # If there is only a single parent, no table aliases are required since there's no ambiguity. - should_simplify_table_aliases = len(node.parent_nodes) <= 1 + # If there is only a single source in the SELECT, no table aliases are required since there's no ambiguity. + should_simplify_table_aliases = len(node.join_descs) == 0 if should_simplify_table_aliases: return SqlSelectStatementNode.create( @@ -41,6 +41,10 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP ), from_source=node.from_source.accept(self), from_source_alias=node.from_source_alias, + cte_sources=tuple( + cte_source.with_new_select(cte_source.select_statement.accept(self)) + for cte_source in node.cte_sources + ), group_bys=tuple( SqlSelectColumn(expr=x.expr.rewrite(should_render_table_alias=False), column_alias=x.column_alias) for x in node.group_bys @@ -59,6 +63,9 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP select_columns=node.select_columns, from_source=node.from_source.accept(self), from_source_alias=node.from_source_alias, + cte_sources=tuple( + cte_source.with_new_select(cte_source.select_statement.accept(self)) for cte_source in node.cte_sources + ), join_descs=tuple( SqlJoinDescription( right_source=x.right_source.accept(self), diff --git a/tests_metricflow/snapshots/test_cte_table_alias_simplifier.py/str/test_table_alias_no_simplification__result.txt b/tests_metricflow/snapshots/test_cte_table_alias_simplifier.py/str/test_table_alias_no_simplification__result.txt new file mode 100644 index 0000000000..b9433df715 --- /dev/null +++ b/tests_metricflow/snapshots/test_cte_table_alias_simplifier.py/str/test_table_alias_no_simplification__result.txt @@ -0,0 +1,41 @@ +test_name: test_table_alias_no_simplification +test_filename: test_cte_table_alias_simplifier.py +docstring: + Tests that table aliases in the SELECT statement of a CTE are not removed when required. +--- +optimizer: + SqlTableAliasSimplifier + +sql_before_optimizing: + -- Top-level SELECT + WITH cte_source_0 AS ( + -- CTE source 0 + SELECT + from_source_alias.col_0 AS cte_source_0__col_0 + FROM test_schema.test_table_0 from_source_alias + INNER JOIN + test_schema.test_table_1 right_source_alias + ON + from_source_alias.col_0 = right_source_alias.col_0 + ) + + SELECT + cte_source_0_alias.cte_source_0__col_0 AS top_level__col_0 + FROM cte_source_0 cte_source_0_alias + +sql_after_optimizing: + -- Top-level SELECT + WITH cte_source_0 AS ( + -- CTE source 0 + SELECT + from_source_alias.col_0 AS cte_source_0__col_0 + FROM test_schema.test_table_0 from_source_alias + INNER JOIN + test_schema.test_table_1 right_source_alias + ON + from_source_alias.col_0 = right_source_alias.col_0 + ) + + SELECT + cte_source_0__col_0 AS top_level__col_0 + FROM cte_source_0 cte_source_0_alias diff --git a/tests_metricflow/snapshots/test_cte_table_alias_simplifier.py/str/test_table_alias_simplification__result.txt b/tests_metricflow/snapshots/test_cte_table_alias_simplifier.py/str/test_table_alias_simplification__result.txt new file mode 100644 index 0000000000..8ff80188c6 --- /dev/null +++ b/tests_metricflow/snapshots/test_cte_table_alias_simplifier.py/str/test_table_alias_simplification__result.txt @@ -0,0 +1,85 @@ +test_name: test_table_alias_simplification +test_filename: test_cte_table_alias_simplifier.py +docstring: + Tests that table aliases in the SELECT statement of a CTE are removed when not needed. +--- +optimizer: + SqlTableAliasSimplifier + +sql_before_optimizing: + -- Top-level SELECT + WITH cte_source_0 AS ( + -- CTE source 0 + SELECT + test_table_alias.col_0 AS cte_source_0__col_0 + , test_table_alias.col_1 AS cte_source_0__col_1 + FROM ( + -- CTE source 0 sub-query + SELECT + test_table_alias.col_0 AS cte_source_0_subquery__col_0 + , test_table_alias.col_0 AS cte_source_0_subquery__col_1 + FROM test_schema.test_table test_table_alias + ) cte_source_0_subquery + ) + + , cte_source_1 AS ( + -- CTE source 1 + SELECT + test_table_alias.col_0 AS cte_source_1__col_0 + , test_table_alias.col_1 AS cte_source_1__col_1 + FROM ( + -- CTE source 1 sub-query + SELECT + cte_source_0_alias.cte_source_0__col_0 AS cte_source_1_subquery__col_0 + , cte_source_0_alias.cte_source_0__col_0 AS cte_source_1_subquery__col_1 + FROM cte_source_0 cte_source_0_alias + ) cte_source_1_subquery + ) + + SELECT + cte_source_0_alias.cte_source_0__col_0 AS top_level__col_0 + , right_source_alias.right_source__col_1 AS top_level__col_1 + FROM cte_source_0 cte_source_0_alias + INNER JOIN + cte_source_1 right_source_alias + ON + cte_source_0_alias.cte_source_0__col_1 = right_source_alias.right_source__col_1 + +sql_after_optimizing: + -- Top-level SELECT + WITH cte_source_0 AS ( + -- CTE source 0 + SELECT + col_0 AS cte_source_0__col_0 + , col_1 AS cte_source_0__col_1 + FROM ( + -- CTE source 0 sub-query + SELECT + col_0 AS cte_source_0_subquery__col_0 + , col_0 AS cte_source_0_subquery__col_1 + FROM test_schema.test_table test_table_alias + ) cte_source_0_subquery + ) + + , cte_source_1 AS ( + -- CTE source 1 + SELECT + col_0 AS cte_source_1__col_0 + , col_1 AS cte_source_1__col_1 + FROM ( + -- CTE source 1 sub-query + SELECT + cte_source_0__col_0 AS cte_source_1_subquery__col_0 + , cte_source_0__col_0 AS cte_source_1_subquery__col_1 + FROM cte_source_0 cte_source_0_alias + ) cte_source_1_subquery + ) + + SELECT + cte_source_0_alias.cte_source_0__col_0 AS top_level__col_0 + , right_source_alias.right_source__col_1 AS top_level__col_1 + FROM cte_source_0 cte_source_0_alias + INNER JOIN + cte_source_1 right_source_alias + ON + cte_source_0_alias.cte_source_0__col_1 = right_source_alias.right_source__col_1 diff --git a/tests_metricflow/snapshots/test_table_alias_simplifier.py/SqlQueryPlan/test_table_alias_simplification__after_alias_simplification.sql b/tests_metricflow/snapshots/test_table_alias_simplifier.py/SqlQueryPlan/test_table_alias_simplification__after_alias_simplification.sql index ef3a220e2b..c6dc990b8a 100644 --- a/tests_metricflow/snapshots/test_table_alias_simplifier.py/SqlQueryPlan/test_table_alias_simplification__after_alias_simplification.sql +++ b/tests_metricflow/snapshots/test_table_alias_simplifier.py/SqlQueryPlan/test_table_alias_simplification__after_alias_simplification.sql @@ -1,7 +1,7 @@ test_name: test_table_alias_simplification test_filename: test_table_alias_simplifier.py docstring: - Tests a case where no pruning should occur. + Tests that table aliases are removed when not needed. --- -- test0 SELECT diff --git a/tests_metricflow/snapshots/test_table_alias_simplifier.py/SqlQueryPlan/test_table_alias_simplification__before_alias_simplification.sql b/tests_metricflow/snapshots/test_table_alias_simplifier.py/SqlQueryPlan/test_table_alias_simplification__before_alias_simplification.sql index c7fcac1e8b..307de0c52d 100644 --- a/tests_metricflow/snapshots/test_table_alias_simplifier.py/SqlQueryPlan/test_table_alias_simplification__before_alias_simplification.sql +++ b/tests_metricflow/snapshots/test_table_alias_simplifier.py/SqlQueryPlan/test_table_alias_simplification__before_alias_simplification.sql @@ -1,7 +1,7 @@ test_name: test_table_alias_simplification test_filename: test_table_alias_simplifier.py docstring: - Tests a case where no pruning should occur. + Tests that table aliases are removed when not needed. --- -- test0 SELECT diff --git a/tests_metricflow/sql/optimizer/test_cte_table_alias_simplifier.py b/tests_metricflow/sql/optimizer/test_cte_table_alias_simplifier.py new file mode 100644 index 0000000000..f2794f5ecd --- /dev/null +++ b/tests_metricflow/sql/optimizer/test_cte_table_alias_simplifier.py @@ -0,0 +1,235 @@ +from __future__ import annotations + +import pytest +from _pytest.fixtures import FixtureRequest +from metricflow_semantics.sql.sql_join_type import SqlJoinType +from metricflow_semantics.sql.sql_table import SqlTable +from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration + +from metricflow.sql.optimizer.table_alias_simplifier import SqlTableAliasSimplifier +from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer +from metricflow.sql.sql_exprs import ( + SqlColumnReference, + SqlColumnReferenceExpression, + SqlComparison, + SqlComparisonExpression, +) +from metricflow.sql.sql_plan import ( + SqlCteNode, + SqlJoinDescription, + SqlSelectColumn, + SqlSelectStatementNode, + SqlTableNode, +) +from tests_metricflow.sql.optimizer.check_optimizer import assert_optimizer_result_snapshot_equal + + +@pytest.fixture +def sql_plan_renderer() -> DefaultSqlQueryPlanRenderer: # noqa: D103 + return DefaultSqlQueryPlanRenderer() + + +def test_table_alias_simplification( + request: FixtureRequest, + mf_test_configuration: MetricFlowTestConfiguration, + sql_plan_renderer: DefaultSqlQueryPlanRenderer, +) -> None: + """Tests that table aliases in the SELECT statement of a CTE are removed when not needed.""" + select_statement = SqlSelectStatementNode.create( + description="Top-level SELECT", + select_columns=( + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="cte_source_0_alias", column_name="cte_source_0__col_0") + ), + column_alias="top_level__col_0", + ), + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="right_source_alias", column_name="right_source__col_1") + ), + column_alias="top_level__col_1", + ), + ), + from_source=SqlTableNode.create(sql_table=SqlTable(schema_name=None, table_name="cte_source_0")), + from_source_alias="cte_source_0_alias", + join_descs=( + SqlJoinDescription( + right_source=SqlTableNode.create(sql_table=SqlTable(schema_name=None, table_name="cte_source_1")), + right_source_alias="right_source_alias", + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="cte_source_0_alias", column_name="cte_source_0__col_1") + ), + comparison=SqlComparison.EQUALS, + right_expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="right_source_alias", column_name="right_source__col_1") + ), + ), + join_type=SqlJoinType.INNER, + ), + ), + cte_sources=( + SqlCteNode.create( + cte_alias="cte_source_0", + select_statement=SqlSelectStatementNode.create( + description="CTE source 0", + select_columns=( + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0") + ), + column_alias="cte_source_0__col_0", + ), + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_1") + ), + column_alias="cte_source_0__col_1", + ), + ), + from_source=SqlSelectStatementNode.create( + description="CTE source 0 sub-query", + select_columns=( + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0") + ), + column_alias="cte_source_0_subquery__col_0", + ), + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0") + ), + column_alias="cte_source_0_subquery__col_1", + ), + ), + from_source=SqlTableNode.create( + sql_table=SqlTable(schema_name="test_schema", table_name="test_table") + ), + from_source_alias="test_table_alias", + ), + from_source_alias="cte_source_0_subquery", + ), + ), + SqlCteNode.create( + cte_alias="cte_source_1", + select_statement=SqlSelectStatementNode.create( + description="CTE source 1", + select_columns=( + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0") + ), + column_alias="cte_source_1__col_0", + ), + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_1") + ), + column_alias="cte_source_1__col_1", + ), + ), + from_source=SqlSelectStatementNode.create( + description="CTE source 1 sub-query", + select_columns=( + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference( + table_alias="cte_source_0_alias", column_name="cte_source_0__col_0" + ) + ), + column_alias="cte_source_1_subquery__col_0", + ), + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference( + table_alias="cte_source_0_alias", column_name="cte_source_0__col_0" + ) + ), + column_alias="cte_source_1_subquery__col_1", + ), + ), + from_source=SqlTableNode.create( + sql_table=SqlTable(schema_name=None, table_name="cte_source_0") + ), + from_source_alias="cte_source_0_alias", + ), + from_source_alias="cte_source_1_subquery", + ), + ), + ), + ) + assert_optimizer_result_snapshot_equal( + request=request, + mf_test_configuration=mf_test_configuration, + optimizer=SqlTableAliasSimplifier(), + sql_plan_renderer=sql_plan_renderer, + select_statement=select_statement, + ) + + +def test_table_alias_no_simplification( + request: FixtureRequest, + mf_test_configuration: MetricFlowTestConfiguration, + sql_plan_renderer: DefaultSqlQueryPlanRenderer, +) -> None: + """Tests that table aliases in the SELECT statement of a CTE are not removed when required.""" + select_statement = SqlSelectStatementNode.create( + description="Top-level SELECT", + select_columns=( + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="cte_source_0_alias", column_name="cte_source_0__col_0") + ), + column_alias="top_level__col_0", + ), + ), + from_source=SqlTableNode.create(sql_table=SqlTable(schema_name=None, table_name="cte_source_0")), + from_source_alias="cte_source_0_alias", + cte_sources=( + SqlCteNode.create( + cte_alias="cte_source_0", + select_statement=SqlSelectStatementNode.create( + description="CTE source 0", + select_columns=( + SqlSelectColumn( + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="from_source_alias", column_name="col_0") + ), + column_alias="cte_source_0__col_0", + ), + ), + from_source=SqlTableNode.create( + sql_table=SqlTable(schema_name="test_schema", table_name="test_table_0") + ), + from_source_alias="from_source_alias", + join_descs=( + SqlJoinDescription( + right_source=SqlTableNode.create( + sql_table=SqlTable(schema_name="test_schema", table_name="test_table_1") + ), + right_source_alias="right_source_alias", + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="from_source_alias", column_name="col_0") + ), + comparison=SqlComparison.EQUALS, + right_expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="right_source_alias", column_name="col_0") + ), + ), + join_type=SqlJoinType.INNER, + ), + ), + ), + ), + ), + ) + assert_optimizer_result_snapshot_equal( + request=request, + mf_test_configuration=mf_test_configuration, + optimizer=SqlTableAliasSimplifier(), + sql_plan_renderer=sql_plan_renderer, + select_statement=select_statement, + ) diff --git a/tests_metricflow/sql/optimizer/test_table_alias_simplifier.py b/tests_metricflow/sql/optimizer/test_table_alias_simplifier.py index ae54998bf5..1615b3795b 100644 --- a/tests_metricflow/sql/optimizer/test_table_alias_simplifier.py +++ b/tests_metricflow/sql/optimizer/test_table_alias_simplifier.py @@ -145,7 +145,7 @@ def test_table_alias_simplification( mf_test_configuration: MetricFlowTestConfiguration, base_select_statement: SqlSelectStatementNode, ) -> None: - """Tests a case where no pruning should occur.""" + """Tests that table aliases are removed when not needed.""" assert_default_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration,