Skip to content

Commit

Permalink
Assorted renames in column pruner classes (#1677)
Browse files Browse the repository at this point in the history
This PR renames a few variables and methods in the column pruner classes
in preparation for #1679.
  • Loading branch information
plypaul authored Feb 21, 2025
1 parent 366bb5d commit 73c2e42
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 16 deletions.
22 changes: 14 additions & 8 deletions metricflow/sql/optimizer/column_pruning/column_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ def __init__(
self,
required_alias_mapping: NodeToColumnAliasMapping,
) -> None:
"""Constructor.
"""Initializer.
Args:
required_alias_mapping: Describes columns aliases that should be kept / not pruned for each node.
required_alias_mapping: Describes column aliases that should be retained for each node.
"""
self._required_alias_mapping = required_alias_mapping

Expand All @@ -46,27 +46,33 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanNo
required_column_aliases = self._required_alias_mapping.get_aliases(node)
if required_column_aliases is None:
logger.error(
f"Did not find {node.node_id=} in the required alias mapping. Returning the non-pruned version "
f"as it should be valid SQL, but this is a bug and should be investigated."
LazyFormat(
"Did not find the given node in the required alias mapping. Returning the original version "
"as it should be valid SQL, but this is a bug and should be investigated.",
node_id=node.node_id,
)
)
return node

if len(required_column_aliases) == 0:
logger.error(
f"Got no required column aliases for {node}. Returning the non-pruned version as it should be valid "
f"SQL, but this is a bug and should be investigated."
LazyFormat(
"Got no required column aliases the given node. Returning the original version as it should be valid "
"SQL, but this is a bug and should be investigated.",
node_id=node.node_id,
)
)
return node

pruned_select_columns = tuple(
retained_select_columns = tuple(
select_column
for select_column in node.select_columns
if select_column.column_alias in required_column_aliases
)

return SqlSelectStatementNode.create(
description=node.description,
select_columns=pruned_select_columns,
select_columns=retained_select_columns,
from_source=node.from_source.accept(self),
from_source_alias=node.from_source_alias,
cte_sources=tuple(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,10 @@ def _visit_parents(self, node: SqlPlanNode) -> None:
parent_node.accept(self)
return

def _tag_potential_cte_node(
def _map_required_column_aliases_in_potential_cte(
self, cte_alias_mapping: SqlCteAliasMapping, table_name: str, column_aliases: Set[str]
) -> None:
"""A reference to a SQL table might be a CTE. If so, tag the appropriate aliases in the CTEs."""
"""A reference to a SQL table might be a CTE. If so, map the required column aliases in the CTEs."""
cte_node = cte_alias_mapping.get_cte_node_for_alias(table_name)

if cte_node is not None:
Expand Down Expand Up @@ -215,7 +215,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None:
)
from_source_as_sql_table_node = node.from_source.as_sql_table_node
if from_source_as_sql_table_node is not None:
self._tag_potential_cte_node(
self._map_required_column_aliases_in_potential_cte(
cte_alias_mapping=cte_alias_mapping,
table_name=from_source_as_sql_table_node.sql_table.table_name,
column_aliases=aliases_required_in_parent,
Expand All @@ -228,7 +228,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None:
)
right_source_as_sql_table_node = join_desc.right_source.as_sql_table_node
if right_source_as_sql_table_node is not None:
self._tag_potential_cte_node(
self._map_required_column_aliases_in_potential_cte(
cte_alias_mapping=cte_alias_mapping,
table_name=right_source_as_sql_table_node.sql_table.table_name,
column_aliases=aliases_required_in_parent,
Expand All @@ -239,21 +239,21 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None:
for string_expr in exprs_used_in_this_node.string_exprs:
if string_expr.used_columns:
for column_alias in string_expr.used_columns:
for node_to_retain_all_columns in (node.from_source,) + tuple(
for node_to_retain_columns in (node.from_source,) + tuple(
join_desc.right_source for join_desc in node.join_descs
):
self._current_required_column_alias_mapping.add_alias(node_to_retain_all_columns, column_alias)
self._current_required_column_alias_mapping.add_alias(node_to_retain_columns, column_alias)

# Same with unqualified column references - it's hard to tell which source it came from, so it's safest to say
# it's required from all parents.
# An unqualified column reference expression is like `SELECT col_0` whereas a qualified column reference
# expression is like `SELECT table_0.col_0`.
for unqualified_column_reference_expr in exprs_used_in_this_node.column_alias_reference_exprs:
column_alias = unqualified_column_reference_expr.column_alias
for node_to_retain_all_columns in (node.from_source,) + tuple(
for node_to_retain_columns in (node.from_source,) + tuple(
join_desc.right_source for join_desc in node.join_descs
):
self._current_required_column_alias_mapping.add_alias(node_to_retain_all_columns, column_alias)
self._current_required_column_alias_mapping.add_alias(node_to_retain_columns, column_alias)

# Visit recursively.
self._visit_parents(node)
Expand Down

0 comments on commit 73c2e42

Please sign in to comment.