Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assorted renames in column pruner classes #1677

Merged
merged 3 commits into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions metricflow/sql/optimizer/column_pruning/column_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ def __init__(
self,
required_alias_mapping: NodeToColumnAliasMapping,
) -> None:
"""Constructor.
"""Initializer.

Args:
required_alias_mapping: Describes columns aliases that should be kept / not pruned for each node.
required_alias_mapping: Describes columns aliases that should be retained for each node.
"""
self._required_alias_mapping = required_alias_mapping

Expand All @@ -46,27 +46,33 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanNo
required_column_aliases = self._required_alias_mapping.get_aliases(node)
if required_column_aliases is None:
logger.error(
f"Did not find {node.node_id=} in the required alias mapping. Returning the non-pruned version "
f"as it should be valid SQL, but this is a bug and should be investigated."
LazyFormat(
"Did not find the given node in the required alias mapping. Returning the original version "
"as it should be valid SQL, but this is a bug and should be investigated.",
node_id=node.node_id,
)
)
return node

if len(required_column_aliases) == 0:
logger.error(
f"Got no required column aliases for {node}. Returning the non-pruned version as it should be valid "
f"SQL, but this is a bug and should be investigated."
LazyFormat(
"Got no required column aliases the given node. Returning the original version as it should be valid "
"SQL, but this is a bug and should be investigated.",
node_id=node.node_id,
)
)
return node

pruned_select_columns = tuple(
retained_select_columns = tuple(
select_column
for select_column in node.select_columns
if select_column.column_alias in required_column_aliases
)

return SqlSelectStatementNode.create(
description=node.description,
select_columns=pruned_select_columns,
select_columns=retained_select_columns,
from_source=node.from_source.accept(self),
from_source_alias=node.from_source_alias,
cte_sources=tuple(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,10 @@ def _visit_parents(self, node: SqlPlanNode) -> None:
parent_node.accept(self)
return

def _tag_potential_cte_node(
def _map_required_column_aliases_in_potential_cte(
self, cte_alias_mapping: SqlCteAliasMapping, table_name: str, column_aliases: Set[str]
) -> None:
"""A reference to a SQL table might be a CTE. If so, tag the appropriate aliases in the CTEs."""
"""A reference to a SQL table might be a CTE. If so, map the required column aliases in the CTEs."""
cte_node = cte_alias_mapping.get_cte_node_for_alias(table_name)

if cte_node is not None:
Expand Down Expand Up @@ -215,7 +215,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None:
)
from_source_as_sql_table_node = node.from_source.as_sql_table_node
if from_source_as_sql_table_node is not None:
self._tag_potential_cte_node(
self._map_required_column_aliases_in_potential_cte(
cte_alias_mapping=cte_alias_mapping,
table_name=from_source_as_sql_table_node.sql_table.table_name,
column_aliases=aliases_required_in_parent,
Expand All @@ -228,7 +228,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None:
)
right_source_as_sql_table_node = join_desc.right_source.as_sql_table_node
if right_source_as_sql_table_node is not None:
self._tag_potential_cte_node(
self._map_required_column_aliases_in_potential_cte(
cte_alias_mapping=cte_alias_mapping,
table_name=right_source_as_sql_table_node.sql_table.table_name,
column_aliases=aliases_required_in_parent,
Expand All @@ -239,21 +239,21 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None:
for string_expr in exprs_used_in_this_node.string_exprs:
if string_expr.used_columns:
for column_alias in string_expr.used_columns:
for node_to_retain_all_columns in (node.from_source,) + tuple(
for node_to_retain_columns in (node.from_source,) + tuple(
join_desc.right_source for join_desc in node.join_descs
):
self._current_required_column_alias_mapping.add_alias(node_to_retain_all_columns, column_alias)
self._current_required_column_alias_mapping.add_alias(node_to_retain_columns, column_alias)

# Same with unqualified column references - it's hard to tell which source it came from, so it's safest to say
# it's required from all parents.
# An unqualified column reference expression is like `SELECT col_0` whereas a qualified column reference
# expression is like `SELECT table_0.col_0`.
for unqualified_column_reference_expr in exprs_used_in_this_node.column_alias_reference_exprs:
column_alias = unqualified_column_reference_expr.column_alias
for node_to_retain_all_columns in (node.from_source,) + tuple(
for node_to_retain_columns in (node.from_source,) + tuple(
join_desc.right_source for join_desc in node.join_descs
):
self._current_required_column_alias_mapping.add_alias(node_to_retain_all_columns, column_alias)
self._current_required_column_alias_mapping.add_alias(node_to_retain_columns, column_alias)

# Visit recursively.
self._visit_parents(node)
Expand Down
Loading