diff --git a/metricflow/sql/optimizer/column_pruning/column_pruner.py b/metricflow/sql/optimizer/column_pruning/column_pruner.py index 7e9f62c4a..ec7bc3f00 100644 --- a/metricflow/sql/optimizer/column_pruning/column_pruner.py +++ b/metricflow/sql/optimizer/column_pruning/column_pruner.py @@ -32,10 +32,10 @@ def __init__( self, required_alias_mapping: NodeToColumnAliasMapping, ) -> None: - """Constructor. + """Initializer. Args: - required_alias_mapping: Describes columns aliases that should be kept / not pruned for each node. + required_alias_mapping: Describes column aliases that should be retained for each node. """ self._required_alias_mapping = required_alias_mapping @@ -46,19 +46,25 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanNo required_column_aliases = self._required_alias_mapping.get_aliases(node) if required_column_aliases is None: logger.error( - f"Did not find {node.node_id=} in the required alias mapping. Returning the non-pruned version " - f"as it should be valid SQL, but this is a bug and should be investigated." + LazyFormat( + "Did not find the given node in the required alias mapping. Returning the original version " + "as it should be valid SQL, but this is a bug and should be investigated.", + node_id=node.node_id, + ) ) return node if len(required_column_aliases) == 0: logger.error( - f"Got no required column aliases for {node}. Returning the non-pruned version as it should be valid " - f"SQL, but this is a bug and should be investigated." + LazyFormat( + "Got no required column aliases for the given node. 
Returning the original version as it should be valid " + "SQL, but this is a bug and should be investigated.", + node_id=node.node_id, + ) ) return node - pruned_select_columns = tuple( + retained_select_columns = tuple( select_column for select_column in node.select_columns if select_column.column_alias in required_column_aliases @@ -66,7 +72,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanNo return SqlSelectStatementNode.create( description=node.description, - select_columns=pruned_select_columns, + select_columns=retained_select_columns, from_source=node.from_source.accept(self), from_source_alias=node.from_source_alias, cte_sources=tuple( diff --git a/metricflow/sql/optimizer/column_pruning/required_column_aliases.py b/metricflow/sql/optimizer/column_pruning/required_column_aliases.py index 78a3b9411..03c952778 100644 --- a/metricflow/sql/optimizer/column_pruning/required_column_aliases.py +++ b/metricflow/sql/optimizer/column_pruning/required_column_aliases.py @@ -121,10 +121,10 @@ def _visit_parents(self, node: SqlPlanNode) -> None: parent_node.accept(self) return - def _tag_potential_cte_node( + def _map_required_column_aliases_in_potential_cte( self, cte_alias_mapping: SqlCteAliasMapping, table_name: str, column_aliases: Set[str] ) -> None: - """A reference to a SQL table might be a CTE. If so, tag the appropriate aliases in the CTEs.""" + """A reference to a SQL table might be a CTE. 
If so, map the required column aliases in the CTEs.""" cte_node = cte_alias_mapping.get_cte_node_for_alias(table_name) if cte_node is not None: @@ -215,7 +215,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None: ) from_source_as_sql_table_node = node.from_source.as_sql_table_node if from_source_as_sql_table_node is not None: - self._tag_potential_cte_node( + self._map_required_column_aliases_in_potential_cte( cte_alias_mapping=cte_alias_mapping, table_name=from_source_as_sql_table_node.sql_table.table_name, column_aliases=aliases_required_in_parent, @@ -228,7 +228,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None: ) right_source_as_sql_table_node = join_desc.right_source.as_sql_table_node if right_source_as_sql_table_node is not None: - self._tag_potential_cte_node( + self._map_required_column_aliases_in_potential_cte( cte_alias_mapping=cte_alias_mapping, table_name=right_source_as_sql_table_node.sql_table.table_name, column_aliases=aliases_required_in_parent, @@ -239,10 +239,10 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None: for string_expr in exprs_used_in_this_node.string_exprs: if string_expr.used_columns: for column_alias in string_expr.used_columns: - for node_to_retain_all_columns in (node.from_source,) + tuple( + for node_to_retain_columns in (node.from_source,) + tuple( join_desc.right_source for join_desc in node.join_descs ): - self._current_required_column_alias_mapping.add_alias(node_to_retain_all_columns, column_alias) + self._current_required_column_alias_mapping.add_alias(node_to_retain_columns, column_alias) # Same with unqualified column references - it's hard to tell which source it came from, so it's safest to say # it's required from all parents. @@ -250,10 +250,10 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None: # expression is like `SELECT table_0.col_0`. 
for unqualified_column_reference_expr in exprs_used_in_this_node.column_alias_reference_exprs: column_alias = unqualified_column_reference_expr.column_alias - for node_to_retain_all_columns in (node.from_source,) + tuple( + for node_to_retain_columns in (node.from_source,) + tuple( join_desc.right_source for join_desc in node.join_descs ): - self._current_required_column_alias_mapping.add_alias(node_to_retain_all_columns, column_alias) + self._current_required_column_alias_mapping.add_alias(node_to_retain_columns, column_alias) # Visit recursively. self._visit_parents(node)