Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for uses of BACK that cause correlated references #269

Merged
merged 7 commits into from
Feb 19, 2025
365 changes: 365 additions & 0 deletions pydough/conversion/hybrid_decorrelater.py

Large diffs are not rendered by default.

515 changes: 321 additions & 194 deletions pydough/conversion/hybrid_tree.py

Large diffs are not rendered by default.

90 changes: 82 additions & 8 deletions pydough/conversion/relational_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
CallExpression,
ColumnPruner,
ColumnReference,
CorrelatedReference,
EmptySingleton,
ExpressionSortInfo,
Filter,
Expand All @@ -43,6 +44,7 @@
)
from pydough.types import BooleanType, Int64Type, UnknownType

from .hybrid_decorrelater import run_hybrid_decorrelation
from .hybrid_tree import (
ConnectionType,
HybridBackRefExpr,
Expand All @@ -52,6 +54,7 @@
HybridCollectionAccess,
HybridColumnExpr,
HybridConnection,
HybridCorrelExpr,
HybridExpr,
HybridFilter,
HybridFunctionExpr,
Expand Down Expand Up @@ -90,11 +93,20 @@ class TranslationOutput:
value of that expression.
"""

correlated_name: str | None = None
"""
The name that can be used to refer to the relational output in correlated
references.
"""


class RelTranslation:
def __init__(self):
# An index used for creating fake column names
self.dummy_idx = 1
# A stack of contexts used to point to ancestors for correlated
# references.
self.stack: list[TranslationOutput] = []

def make_null_column(self, relation: RelationalNode) -> ColumnReference:
"""
Expand Down Expand Up @@ -145,6 +157,24 @@ def get_column_name(
new_name = f"{name}_{self.dummy_idx}"
return new_name

def get_correlated_name(self, context: TranslationOutput) -> str:
"""
Finds the name used to refer to a context for correlated variable
access. If the context does not have a correlated name, a new one is
generated for it.

Args:
`context`: the context containing the relational subtree being
referrenced in a correlated variable access.

Returns:
The name used to refer to the context in a correlated reference.
"""
if context.correlated_name is None:
context.correlated_name = f"corr{self.dummy_idx}"
self.dummy_idx += 1
return context.correlated_name

def translate_expression(
self, expr: HybridExpr, context: TranslationOutput | None
) -> RelationalExpression:
Expand Down Expand Up @@ -199,8 +229,32 @@ def translate_expression(
order_inputs,
expr.kwargs,
)
case HybridCorrelExpr():
# Convert correlated expressions by converting the expression
# they point to in the context of the top of the stack, then
# wrapping the result in a correlated reference.
ancestor_context: TranslationOutput = self.stack.pop()
ancestor_expr: RelationalExpression = self.translate_expression(
expr.expr, ancestor_context
)
self.stack.append(ancestor_context)
match ancestor_expr:
case ColumnReference():
return CorrelatedReference(
ancestor_expr.name,
self.get_correlated_name(ancestor_context),
expr.typ,
)
case CorrelatedReference():
return ancestor_expr
case _:
raise ValueError(
f"Unsupported expression to reference in a correlated reference: {ancestor_expr}"
)
case _:
raise NotImplementedError(expr.__class__.__name__)
raise NotImplementedError(
f"TODO: support relational conversion on {expr.__class__.__name__}"
)

def join_outputs(
self,
Expand Down Expand Up @@ -257,6 +311,7 @@ def join_outputs(
[LiteralExpression(True, BooleanType())],
[join_type],
join_columns,
correl_name=lhs_result.correlated_name,
)
input_aliases: list[str | None] = out_rel.default_input_aliases

Expand Down Expand Up @@ -397,9 +452,11 @@ def handle_children(
"""
for child_idx, child in enumerate(hybrid.children):
if child.required_steps == pipeline_idx:
self.stack.append(context)
child_output = self.rel_translation(
child, child.subtree, len(child.subtree.pipeline) - 1
)
self.stack.pop()
assert child.subtree.join_keys is not None
join_keys: list[tuple[HybridExpr, HybridExpr]] = child.subtree.join_keys
agg_keys: list[HybridExpr]
Expand Down Expand Up @@ -592,7 +649,7 @@ def translate_partition(

Returns:
The TranslationOutput payload containing access to the aggregated
child corresponding tot he partition data.
child corresponding to the partition data.
"""
expressions: dict[HybridExpr, ColumnReference] = {}
# Account for the fact that the PARTITION is stepping down a level,
Expand Down Expand Up @@ -828,11 +885,26 @@ def rel_translation(
if isinstance(operation.collection, TableCollection):
result = self.build_simple_table_scan(operation)
if context is not None:
# If the collection access is the child of something
# else, join it onto that something else. Use the
# uniqueness keys of the ancestor, which should also be
# present in the collection (e.g. joining a partition
# onto the original data using the partition keys).
assert preceding_hybrid is not None
join_keys: list[tuple[HybridExpr, HybridExpr]] = []
for unique_column in sorted(
preceding_hybrid[0].pipeline[0].unique_exprs, key=str
):
if unique_column not in result.expressions:
raise ValueError(
f"Cannot connect parent context to child {operation.collection} because {unique_column} is not in the child's expressions."
)
join_keys.append((unique_column, unique_column))
result = self.join_outputs(
context,
result,
JoinType.INNER,
[],
join_keys,
None,
)
else:
Expand All @@ -856,7 +928,8 @@ def rel_translation(
assert context is not None, "Malformed HybridTree pattern."
result = self.translate_filter(operation, context)
case HybridPartition():
assert context is not None, "Malformed HybridTree pattern."
if context is None:
context = TranslationOutput(EmptySingleton(), {})
result = self.translate_partition(
operation, context, hybrid, pipeline_idx
)
Expand Down Expand Up @@ -942,10 +1015,11 @@ def convert_ast_to_relational(
final_terms: set[str] = node.calc_terms
node = translator.preprocess_root(node)

# Convert the QDAG node to the hybrid form, then invoke the relational
# conversion procedure. The first element in the returned list is the
# final rel node.
hybrid: HybridTree = HybridTranslator(configs).make_hybrid_tree(node)
# Convert the QDAG node to the hybrid form, decorrelate it, then invoke
# the relational conversion procedure. The first element in the returned
# list is the final rel node.
hybrid: HybridTree = HybridTranslator(configs).make_hybrid_tree(node, None)
run_hybrid_decorrelation(hybrid)
renamings: dict[str, str] = hybrid.pipeline[-1].renamings
output: TranslationOutput = translator.rel_translation(
None, hybrid, len(hybrid.pipeline) - 1
Expand Down
12 changes: 12 additions & 0 deletions pydough/pydough_operators/base_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,15 @@ def to_string(self, arg_strings: list[str]) -> str:
Returns:
The string representation of the operator called on its arguments.
"""

@abstractmethod
def equals(self, other: object) -> bool:
"""
Returns whether this operator is equal to another operator.
"""

def __eq__(self, other: object) -> bool:
return self.equals(other)

def __hash__(self) -> int:
return hash(repr(self))
8 changes: 8 additions & 0 deletions pydough/relational/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ The relational_expressions submodule provides functionality to define and manage
- `ExpressionSortInfo`: The representation of ordering for an expression within a relational node.
- `RelationalExpressionVisitor`: The basic Visitor pattern to perform operations across the expression components of a relational tree.
- `ColumnReferenceFinder`: Finds all unique column references in a relational expression.
- `CorrelatedReference`: The expression implementation for accessing a correlated column reference in a relational node.
- `CorrelatedReferenceFinder`: Finds all unique correlated references in a relational expression.
- `RelationalExpressionShuttle`: Specialized form of the visitor pattern that returns a relational expression.
- `ColumnReferenceInputNameModifier`: Shuttle implementation designed to update all uses of a column reference's input name to a new input name based on a dictionary.

Expand All @@ -33,6 +35,7 @@ from pydough.relational.relational_expressions import (
ExpressionSortInfo,
ColumnReferenceFinder,
ColumnReferenceInputNameModifier,
CorrelatedReferenceFinder,
WindowCallExpression,
)
from pydough.pydough_operators import ADD, RANKING
Expand Down Expand Up @@ -64,6 +67,11 @@ unique_column_refs = finder.get_column_references()
# Modify the input name of column references in the call expression
modifier = ColumnReferenceInputNameModifier({"old_input_name": "new_input_name"})
modified_call_expr = call_expr.accept_shuttle(modifier)

# Find all unique correlated references in the call expression
correlated_finder = CorrelatedReferenceFinder()
call_expr.accept(correlated_finder)
unique_correlated_refs = correlated_finder.get_correlated_references()
```

## [Relational Nodes](relational_nodes/README.md)
Expand Down
2 changes: 2 additions & 0 deletions pydough/relational/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"ColumnReferenceFinder",
"ColumnReferenceInputNameModifier",
"ColumnReferenceInputNameRemover",
"CorrelatedReference",
"EmptySingleton",
"ExpressionSortInfo",
"Filter",
Expand All @@ -30,6 +31,7 @@
ColumnReferenceFinder,
ColumnReferenceInputNameModifier,
ColumnReferenceInputNameRemover,
CorrelatedReference,
ExpressionSortInfo,
LiteralExpression,
RelationalExpression,
Expand Down
19 changes: 19 additions & 0 deletions pydough/relational/relational_expressions/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ The relational_expressions module provides functionality to define and manage va

- `ColumnReferenceFinder`: Finds all unique column references in a relational expression.

### [correlated_reference.py](correlated_reference.py)

- `CorrelatedReference`: The expression implementation for accessing a correlated column reference in a relational node.

### [correlated_reference_finder.py](correlated_reference_finder.py)

- `CorrelatedReferenceFinder`: Finds all unique correlated references in a relational expression.

### [relational_expression_shuttle.py](relational_expression_shuttle.py)

- `RelationalExpressionShuttle`: Specialized form of the visitor pattern that returns a relational expression. This is used to handle the common case where we need to modify a type of input.
Expand All @@ -69,6 +77,8 @@ from pydough.relational.relational_expressions import (
ExpressionSortInfo,
ColumnReferenceFinder,
ColumnReferenceInputNameModifier,
CorrelatedReference,
CorrelatedReferenceFinder,
)
from pydough.pydough_operators import ADD
from pydough.types import Int64Type
Expand All @@ -82,6 +92,10 @@ literal_expr = LiteralExpression(10, Int64Type())
# Create a call expression for addition
call_expr = CallExpression(ADD, Int64Type(), [column_ref, literal_expr])

# Create a correlated reference to column `column_name` in the first input to
# an ancestor join of `corr1`
correlated_ref = CorrelatedReference("column_name", "corr1", Int64Type())

# Create an expression sort info
sort_info = ExpressionSortInfo(call_expr, ascending=True, nulls_first=False)

Expand All @@ -96,4 +110,9 @@ unique_column_refs = finder.get_column_references()
# Modify the input name of column references in the call expression
modifier = ColumnReferenceInputNameModifier({"old_input_name": "new_input_name"})
modified_call_expr = call_expr.accept_shuttle(modifier)

# Find all unique correlated references in the call expression
correlated_finder = CorrelatedReferenceFinder()
call_expr.accept(correlated_finder)
unique_correlated_refs = correlated_finder.get_correlated_references()
```
4 changes: 4 additions & 0 deletions pydough/relational/relational_expressions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
"ColumnReferenceFinder",
"ColumnReferenceInputNameModifier",
"ColumnReferenceInputNameRemover",
"CorrelatedReference",
"CorrelatedReferenceFinder",
"ExpressionSortInfo",
"LiteralExpression",
"RelationalExpression",
Expand All @@ -21,6 +23,8 @@
from .column_reference_finder import ColumnReferenceFinder
from .column_reference_input_name_modifier import ColumnReferenceInputNameModifier
from .column_reference_input_name_remover import ColumnReferenceInputNameRemover
from .correlated_reference import CorrelatedReference
from .correlated_reference_finder import CorrelatedReferenceFinder
from .expression_sort_info import ExpressionSortInfo
from .literal_expression import LiteralExpression
from .relational_expression_visitor import RelationalExpressionVisitor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ def to_string(self, compact: bool = False) -> str:
def __repr__(self) -> str:
return self.to_string()

def __hash__(self) -> int:
return hash(self.to_string())

@abstractmethod
def accept(self, visitor: RelationalExpressionVisitor) -> None:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,6 @@ def visit_literal_expression(self, literal_expression: LiteralExpression) -> Non

def visit_column_reference(self, column_reference: ColumnReference) -> None:
self._column_references.add(column_reference)

def visit_correlated_reference(self, correlated_reference) -> None:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,6 @@ def visit_column_reference(self, column_reference) -> RelationalExpression:
raise ValueError(
f"Input name {column_reference.input_name} not found in the input name map."
)

def visit_correlated_reference(self, correlated_reference) -> RelationalExpression:
return correlated_reference
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,6 @@ def visit_column_reference(self, column_reference) -> RelationalExpression:
column_reference.data_type,
None,
)

def visit_correlated_reference(self, correlated_reference) -> RelationalExpression:
return correlated_reference
Loading