diff --git a/metricflow/sql/optimizer/column_pruning/required_column_aliases.py b/metricflow/sql/optimizer/column_pruning/required_column_aliases.py index 03c952778..c0e9b315f 100644 --- a/metricflow/sql/optimizer/column_pruning/required_column_aliases.py +++ b/metricflow/sql/optimizer/column_pruning/required_column_aliases.py @@ -234,26 +234,24 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None: column_aliases=aliases_required_in_parent, ) - # For all string columns, assume that they are needed from all sources since we don't have a table alias - # in SqlStringExpression.used_columns + # Find instances where a column alias is referenced without a table alias. The two cases are: + # * String expressions like`col_0 + col_1` where a string is used instead of the corresponding SQL object. + # * `SqlColumnAliasReferenceExpression` - e.g. `SELECT col_0` instead of `SELECT table_0.col_0` + column_aliases_to_retain: Set[str] = set() for string_expr in exprs_used_in_this_node.string_exprs: if string_expr.used_columns: - for column_alias in string_expr.used_columns: - for node_to_retain_columns in (node.from_source,) + tuple( - join_desc.right_source for join_desc in node.join_descs - ): - self._current_required_column_alias_mapping.add_alias(node_to_retain_columns, column_alias) - - # Same with unqualified column references - it's hard to tell which source it came from, so it's safest to say - # it's required from all parents. - # An unqualified column reference expression is like `SELECT col_0` whereas a qualified column reference - # expression is like `SELECT table_0.col_0`. + column_aliases_to_retain.update(column_alias for column_alias in string_expr.used_columns) + for unqualified_column_reference_expr in exprs_used_in_this_node.column_alias_reference_exprs: - column_alias = unqualified_column_reference_expr.column_alias - for node_to_retain_columns in (node.from_source,) + tuple( - join_desc.right_source for join_desc in node.join_descs - ): - self._current_required_column_alias_mapping.add_alias(node_to_retain_columns, column_alias) + column_aliases_to_retain.add(unqualified_column_reference_expr.column_alias) + + # Assume those column aliases are needed from all sources as it may not be possible to know which source it + # comes from based on the SQL (e.g. if a query reads from two tables, you would need to know the table schema + # to know which table the column resides) + for node_to_retain_columns in (node.from_source,) + tuple( + join_desc.right_source for join_desc in node.join_descs + ): + self._current_required_column_alias_mapping.add_aliases(node_to_retain_columns, column_aliases_to_retain) # Visit recursively. self._visit_parents(node)