From 051100b0d1355cc0baf3b030093631305635435e Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 17 Feb 2025 15:17:48 -0500 Subject: [PATCH 1/7] Pulling in minor child change --- pydough/unqualified/qualification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydough/unqualified/qualification.py b/pydough/unqualified/qualification.py index 216c4011..ba33907d 100644 --- a/pydough/unqualified/qualification.py +++ b/pydough/unqualified/qualification.py @@ -371,7 +371,7 @@ def qualify_access( for _ in range(levels): if ancestor.ancestor_context is None: raise PyDoughUnqualifiedException( - f"Cannot back reference {levels} above {unqualified_parent}" + f"Cannot back reference {levels} above {context}" ) ancestor = ancestor.ancestor_context # Identify whether the access is an expression or a collection From 37dfd5c6eacf49acb57b43f6336a7d8a2966ba6d Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Mon, 17 Feb 2025 15:26:33 -0500 Subject: [PATCH 2/7] Support uses of BACK that cause correlated references: Hybrid & Relational support for Correlate nodes (#232) Child of #269 that is part of addressing #141. This PR deals with the hybrid & relational conversion, including the creation of new types of hybrid/relational nodes to express correlation, and also defines all of the correlation unit tests. This means that all correlated queries should produce a Relational plan, but will not be possible to convert to SQLGlot until the next PR. Subsequent PRs will deal with cases where a the relational plan needs to be de-correlated. 
--- pydough/conversion/hybrid_tree.py | 457 +++++++++++------- pydough/conversion/relational_converter.py | 60 ++- pydough/relational/README.md | 8 + pydough/relational/__init__.py | 2 + .../relational_expressions/README.md | 19 + .../relational_expressions/__init__.py | 4 + .../abstract_expression.py | 3 + .../column_reference_finder.py | 3 + .../column_reference_input_name_modifier.py | 3 + .../column_reference_input_name_remover.py | 3 + .../correlated_reference.py | 65 +++ .../correlated_reference_finder.py | 48 ++ .../literal_expression.py | 4 - .../relational_expression_shuttle.py | 11 + .../relational_expression_visitor.py | 9 + .../relational_nodes/column_pruner.py | 61 ++- pydough/relational/relational_nodes/join.py | 18 +- .../sqlglot_relational_expression_visitor.py | 6 + tests/correlated_pydough_functions.py | 302 ++++++++++++ tests/simple_pydough_functions.py | 3 + tests/test_pipeline.py | 361 +++++++++++++- tests/test_plan_refsols/correl_1.txt | 8 + tests/test_plan_refsols/correl_10.txt | 8 + tests/test_plan_refsols/correl_11.txt | 9 + tests/test_plan_refsols/correl_12.txt | 13 + tests/test_plan_refsols/correl_13.txt | 18 + tests/test_plan_refsols/correl_14.txt | 18 + tests/test_plan_refsols/correl_15.txt | 22 + tests/test_plan_refsols/correl_16.txt | 14 + tests/test_plan_refsols/correl_17.txt | 9 + tests/test_plan_refsols/correl_18.txt | 14 + tests/test_plan_refsols/correl_19.txt | 12 + tests/test_plan_refsols/correl_2.txt | 12 + tests/test_plan_refsols/correl_20.txt | 17 + tests/test_plan_refsols/correl_3.txt | 11 + tests/test_plan_refsols/correl_4.txt | 11 + tests/test_plan_refsols/correl_5.txt | 13 + tests/test_plan_refsols/correl_6.txt | 8 + tests/test_plan_refsols/correl_7.txt | 7 + tests/test_plan_refsols/correl_8.txt | 7 + tests/test_plan_refsols/correl_9.txt | 8 + tests/test_plan_refsols/tpch_q21.txt | 21 + tests/test_plan_refsols/tpch_q22.txt | 18 + tests/test_plan_refsols/tpch_q5.txt | 21 + tests/tpch_outputs.py | 2 +- 45 files changed, 
1546 insertions(+), 205 deletions(-) create mode 100644 pydough/relational/relational_expressions/correlated_reference.py create mode 100644 pydough/relational/relational_expressions/correlated_reference_finder.py create mode 100644 tests/correlated_pydough_functions.py create mode 100644 tests/test_plan_refsols/correl_1.txt create mode 100644 tests/test_plan_refsols/correl_10.txt create mode 100644 tests/test_plan_refsols/correl_11.txt create mode 100644 tests/test_plan_refsols/correl_12.txt create mode 100644 tests/test_plan_refsols/correl_13.txt create mode 100644 tests/test_plan_refsols/correl_14.txt create mode 100644 tests/test_plan_refsols/correl_15.txt create mode 100644 tests/test_plan_refsols/correl_16.txt create mode 100644 tests/test_plan_refsols/correl_17.txt create mode 100644 tests/test_plan_refsols/correl_18.txt create mode 100644 tests/test_plan_refsols/correl_19.txt create mode 100644 tests/test_plan_refsols/correl_2.txt create mode 100644 tests/test_plan_refsols/correl_20.txt create mode 100644 tests/test_plan_refsols/correl_3.txt create mode 100644 tests/test_plan_refsols/correl_4.txt create mode 100644 tests/test_plan_refsols/correl_5.txt create mode 100644 tests/test_plan_refsols/correl_6.txt create mode 100644 tests/test_plan_refsols/correl_7.txt create mode 100644 tests/test_plan_refsols/correl_8.txt create mode 100644 tests/test_plan_refsols/correl_9.txt create mode 100644 tests/test_plan_refsols/tpch_q21.txt create mode 100644 tests/test_plan_refsols/tpch_q22.txt create mode 100644 tests/test_plan_refsols/tpch_q5.txt diff --git a/pydough/conversion/hybrid_tree.py b/pydough/conversion/hybrid_tree.py index 0a9fc64e..79a03062 100644 --- a/pydough/conversion/hybrid_tree.py +++ b/pydough/conversion/hybrid_tree.py @@ -229,6 +229,27 @@ def shift_back(self, levels: int) -> HybridExpr | None: return HybridBackRefExpr(self.name, self.back_idx + levels, self.typ) +class HybridCorrelExpr(HybridExpr): + """ + Class for HybridExpr terms that are 
expressions from a parent hybrid tree + rather than an ancestor, which requires a correlated reference. + """ + + def __init__(self, hybrid: "HybridTree", expr: HybridExpr): + super().__init__(expr.typ) + self.hybrid = hybrid + self.expr: HybridExpr = expr + + def __repr__(self): + return f"CORREL({self.expr})" + + def apply_renamings(self, renamings: dict[str, str]) -> "HybridExpr": + return self + + def shift_back(self, levels: int) -> HybridExpr | None: + return self + + class HybridLiteralExpr(HybridExpr): """ Class for HybridExpr terms that are literals. @@ -1046,6 +1067,10 @@ def __init__(self, configs: PyDoughConfigs): self.configs = configs # An index used for creating fake column names for aliases self.alias_counter: int = 0 + # A stack where each element is a hybrid tree being derived + # as as subtree of the previous element, and the current tree is + # being derived as the subtree of the last element. + self.stack: list[HybridTree] = [] @staticmethod def get_join_keys( @@ -1230,6 +1255,7 @@ def populate_children( accordingly so expressions using the child indices know what hybrid connection index to use. """ + self.stack.append(hybrid) for child_idx, child in enumerate(child_operator.children): # Build the hybrid tree for the child. Before doing so, reset the # alias counter to 0 to ensure that identical subtrees are named @@ -1269,108 +1295,7 @@ def populate_children( for con_typ in reference_types: connection_type = connection_type.reconcile_connection_types(con_typ) child_idx_mapping[child_idx] = hybrid.add_child(subtree, connection_type) - - def make_hybrid_agg_expr( - self, - hybrid: HybridTree, - expr: PyDoughExpressionQDAG, - child_ref_mapping: dict[int, int], - ) -> tuple[HybridExpr, int | None]: - """ - Converts a QDAG expression into a HybridExpr specifically with the - intent of making it the input to an aggregation call. 
Returns the - converted function argument, as well as an index indicating what child - subtree the aggregation's arguments belong to. NOTE: the HybridExpr is - phrased relative to the child subtree, rather than relative to `hybrid` - itself. - - Args: - `hybrid`: the hybrid tree that should be used to derive the - translation of `expr`, as it is the context in which the `expr` - will live. - `expr`: the QDAG expression to be converted. - `child_ref_mapping`: mapping of indices used by child references - in the original expressions to the index of the child hybrid tree - relative to the current level. - - Returns: - The HybridExpr node corresponding to `expr`, as well as the index - of the child it belongs to (e.g. which subtree does this - aggregation need to be done on top of). - """ - hybrid_result: HybridExpr - # This value starts out as None since we do not know the child index - # that `expr` correspond to yet. It may still be None at the end, since - # it is possible that `expr` does not correspond to any child index. - child_idx: int | None = None - match expr: - case PartitionKey(): - return self.make_hybrid_agg_expr(hybrid, expr.expr, child_ref_mapping) - case Literal(): - # Literals are kept as-is. - hybrid_result = HybridLiteralExpr(expr) - case ChildReferenceExpression(): - # Child references become regular references because the - # expression is phrased as if we were inside the child rather - # than the parent. 
- child_idx = child_ref_mapping[expr.child_idx] - child_connection = hybrid.children[child_idx] - expr_name = child_connection.subtree.pipeline[-1].renamings.get( - expr.term_name, expr.term_name - ) - hybrid_result = HybridRefExpr(expr_name, expr.pydough_type) - case ExpressionFunctionCall(): - if expr.operator.is_aggregation: - raise NotImplementedError( - "PyDough does not yet support calling aggregations inside of aggregations" - ) - # Every argument must be translated in the same manner as a - # regular function argument, except that the child index it - # corresponds to must be reconciled with the child index value - # accumulated so far. - args: list[HybridExpr] = [] - for arg in expr.args: - if not isinstance(arg, PyDoughExpressionQDAG): - raise NotImplementedError( - f"TODO: support converting {arg.__class__.__name__} as a function argument" - ) - hybrid_arg, hybrid_child_index = self.make_hybrid_agg_expr( - hybrid, arg, child_ref_mapping - ) - if hybrid_child_index is not None: - if child_idx is None: - # In this case, the argument is the first one seen that - # has an index, so that index is chosen. - child_idx = hybrid_child_index - elif hybrid_child_index != child_idx: - # In this case, multiple arguments correspond to - # different children, which cannot be handled yet - # because it means it is impossible to push the agg - # call into a single HybridConnection node. 
- raise NotImplementedError( - "Unsupported case: multiple child indices referenced by aggregation arguments" - ) - args.append(hybrid_arg) - hybrid_result = HybridFunctionExpr( - expr.operator, args, expr.pydough_type - ) - case BackReferenceExpression(): - raise NotImplementedError( - "PyDough does yet support aggregations whose arguments mix between subcollection data of the current context and fields of an ancestor of the current context" - ) - case Reference(): - raise NotImplementedError( - "PyDough does yet support aggregations whose arguments mix between subcollection data of the current context and fields of the context itself" - ) - case WindowCall(): - raise NotImplementedError( - "PyDough does yet support aggregations whose arguments mix between subcollection data of the current context and window functions" - ) - case _: - raise NotImplementedError( - f"TODO: support converting {expr.__class__.__name__} in aggregations" - ) - return hybrid_result, child_idx + self.stack.pop() def postprocess_agg_output( self, agg_call: HybridFunctionExpr, agg_ref: HybridExpr, joins_can_nullify: bool @@ -1533,11 +1458,195 @@ def handle_has_hasnot( # has / hasnot condition is now known to be true. return HybridLiteralExpr(Literal(True, BooleanType())) + def convert_agg_arg(self, expr: HybridExpr, child_indices: set[int]) -> HybridExpr: + """ + Translates a hybrid expression that is an argument to an aggregation + (or a subexpression of such an argument) into a form that is expressed + from the perspective of the child subtree that is being aggregated. + + Args: + `expr`: the expression to be converted. + `child_indices`: a set that is mutated to contain the indices of + any children that are referenced by `expr`. + + Returns: + The translated expression. + + Raises: + NotImplementedError if `expr` is an expression that cannot be used + inside of an aggregation call. 
+ """ + match expr: + case HybridLiteralExpr(): + return expr + case HybridChildRefExpr(): + # Child references become regular references because the + # expression is phrased as if we were inside the child rather + # than the parent. + child_indices.add(expr.child_idx) + return HybridRefExpr(expr.name, expr.typ) + case HybridFunctionExpr(): + return HybridFunctionExpr( + expr.operator, + [self.convert_agg_arg(arg, child_indices) for arg in expr.args], + expr.typ, + ) + case HybridBackRefExpr(): + raise NotImplementedError( + "PyDough does yet support aggregations whose arguments mix between subcollection data of the current context and fields of an ancestor of the current context" + ) + case HybridRefExpr(): + raise NotImplementedError( + "PyDough does yet support aggregations whose arguments mix between subcollection data of the current context and fields of the context itself" + ) + case HybridWindowExpr(): + raise NotImplementedError( + "PyDough does yet support aggregations whose arguments mix between subcollection data of the current context and window functions" + ) + case _: + raise NotImplementedError( + f"TODO: support converting {expr.__class__.__name__} in aggregations" + ) + + def make_agg_call( + self, + hybrid: HybridTree, + expr: ExpressionFunctionCall, + args: list[HybridExpr], + ) -> HybridExpr: + """ + For aggregate function calls, their arguments are translated in a + manner that identifies what child subtree they correspond too, by + index, and translates them relative to the subtree. Then, the + aggregation calls are placed into the `aggs` mapping of the + corresponding child connection, and the aggregation call becomes a + child reference (referring to the aggs list), since after translation, + an aggregated child subtree only has the grouping keys and the + aggregation calls as opposed to its other terms. + + Args: + `hybrid`: the hybrid tree that should be used to derive the + translation of the aggregation call. 
+ `expr`: the aggregation function QDAG expression to be converted. + `args`: the converted arguments to the aggregation call. + """ + child_indices: set[int] = set() + converted_args: list[HybridExpr] = [ + self.convert_agg_arg(arg, child_indices) for arg in args + ] + if len(child_indices) != 1: + raise ValueError( + f"Expected aggregation call to contain references to exactly one child collection, but found {len(child_indices)} in {expr}" + ) + hybrid_call: HybridFunctionExpr = HybridFunctionExpr( + expr.operator, converted_args, expr.pydough_type + ) + # Identify the child connection that the aggregation call is pushed + # into. + child_idx: int = child_indices.pop() + child_connection: HybridConnection = hybrid.children[child_idx] + # Generate a unique name for the agg call to push into the child + # connection. + agg_name: str = self.get_agg_name(child_connection) + child_connection.aggs[agg_name] = hybrid_call + result_ref: HybridExpr = HybridChildRefExpr( + agg_name, child_idx, expr.pydough_type + ) + joins_can_nullify: bool = not isinstance(hybrid.pipeline[0], HybridRoot) + return self.postprocess_agg_output(hybrid_call, result_ref, joins_can_nullify) + + def make_hybrid_correl_expr( + self, + back_expr: BackReferenceExpression, + collection: PyDoughCollectionQDAG, + steps_taken_so_far: int, + ) -> HybridCorrelExpr: + """ + Converts a BACK reference into a correlated reference when the number + of BACK levels exceeds the height of the current subtree. + + Args: + `back_expr`: the original BACK reference to be converted. + `collection`: the collection at the top of the current subtree, + before we have run out of BACK levels to step up out of. + `steps_taken_so_far`: the number of steps already taken to step + up from the BACK node. This is needed so we know how many steps + still need to be taken upward once we have stepped out of the child + subtree back into the parent subtree. 
+ """ + if len(self.stack) == 0: + raise ValueError("Back reference steps too far back") + # Identify the parent subtree that the BACK reference is stepping back + # into, out of the child. + parent_tree: HybridTree = self.stack.pop() + remaining_steps_back: int = back_expr.back_levels - steps_taken_so_far - 1 + parent_result: HybridExpr + # Special case: stepping out of the data argument of PARTITION back + # into its ancestor. For example: + # TPCH(x=...).PARTITION(data.WHERE(y > BACK(1).x), ...) + if len(parent_tree.pipeline) == 1 and isinstance( + parent_tree.pipeline[0], HybridPartition + ): + assert parent_tree.parent is not None + # Treat the partition's parent as the conext for the back + # to step into, as opposed to the partition itself (so the back + # levels are consistent) + self.stack.append(parent_tree.parent) + parent_result = self.make_hybrid_correl_expr( + back_expr, collection, steps_taken_so_far + ) + self.stack.pop() + self.stack.append(parent_tree) + # Then, postprocess the output to account for the fact that a + # BACK level got skipped due to the change in subtree. + match parent_result.expr: + case HybridRefExpr(): + parent_result = HybridBackRefExpr( + parent_result.expr.name, 1, parent_result.typ + ) + case HybridBackRefExpr(): + parent_result = HybridBackRefExpr( + parent_result.expr.name, + parent_result.expr.back_idx + 1, + parent_result.typ, + ) + case _: + raise ValueError( + f"Malformed expression for correlated reference: {parent_result}" + ) + elif remaining_steps_back == 0: + # If there are no more steps back to be made, then the correlated + # reference is to a reference from the current context. 
+ if back_expr.term_name not in parent_tree.pipeline[-1].terms: + raise ValueError( + f"Back reference to {back_expr.term_name} not found in parent" + ) + parent_name: str = parent_tree.pipeline[-1].renamings.get( + back_expr.term_name, back_expr.term_name + ) + parent_result = HybridRefExpr(parent_name, back_expr.pydough_type) + else: + # Otherwise, a back reference needs to be made from the current + # collection a number of steps back based on how many steps still + # need to be taken, and it must be recursively converted to a + # hybrid expression that gets wrapped in a correlated reference. + new_expr: PyDoughExpressionQDAG = BackReferenceExpression( + collection, back_expr.term_name, remaining_steps_back + ) + parent_result = self.make_hybrid_expr(parent_tree, new_expr, {}, False) + # Restore parent_tree back onto the stack, since evaluating `back_expr` + # does not change the program's current placement in the sutbtrees. + self.stack.append(parent_tree) + # Create the correlated reference to the expression with regards to + # the parent tree, which could also be a correlated expression. + return HybridCorrelExpr(parent_tree, parent_result) + def make_hybrid_expr( self, hybrid: HybridTree, expr: PyDoughExpressionQDAG, child_ref_mapping: dict[int, int], + inside_agg: bool, ) -> HybridExpr: """ Converts a QDAG expression into a HybridExpr. @@ -1550,6 +1659,8 @@ def make_hybrid_expr( `child_ref_mapping`: mapping of indices used by child references in the original expressions to the index of the child hybrid tree relative to the current level. + `inside_agg`: True if `expr` is beign derived is inside of an + aggregation call, False otherwise. 
Returns: The HybridExpr node corresponding to `expr` @@ -1561,7 +1672,9 @@ def make_hybrid_expr( ancestor_tree: HybridTree match expr: case PartitionKey(): - return self.make_hybrid_expr(hybrid, expr.expr, child_ref_mapping) + return self.make_hybrid_expr( + hybrid, expr.expr, child_ref_mapping, inside_agg + ) case Literal(): return HybridLiteralExpr(expr) case ColumnProperty(): @@ -1581,20 +1694,22 @@ def make_hybrid_expr( case BackReferenceExpression(): # A reference to an expression from an ancestor becomes a # reference to one of the terms of a parent level of the hybrid - # tree. This does not yet support cases where the back - # reference steps outside of a child subtree and back into its - # parent subtree, since that breaks the independence between - # the parent and child. + # tree. If the BACK goes far enough that it must step outside + # a child subtree into the parent, a correlated reference is + # created. ancestor_tree = hybrid back_idx: int = 0 true_steps_back: int = 0 # Keep stepping backward until `expr.back_levels` non-hidden # steps have been taken (to ignore steps that are part of a # compound). + collection: PyDoughCollectionQDAG = expr.collection while true_steps_back < expr.back_levels: + assert collection.ancestor_context is not None + collection = collection.ancestor_context if ancestor_tree.parent is None: - raise NotImplementedError( - "TODO: (gh #141) support BACK references that step from a child subtree back into a parent context." + return self.make_hybrid_correl_expr( + expr, collection, true_steps_back ) ancestor_tree = ancestor_tree.parent back_idx += true_steps_back @@ -1609,80 +1724,51 @@ def make_hybrid_expr( expr.term_name, expr.term_name ) return HybridRefExpr(expr_name, expr.pydough_type) - case ExpressionFunctionCall() if not expr.operator.is_aggregation: - # For non-aggregate function calls, translate their arguments - # normally and build the function call. 
Does not support any + case ExpressionFunctionCall(): + if expr.operator.is_aggregation and inside_agg: + raise NotImplementedError( + "PyDough does not yet support calling aggregations inside of aggregations" + ) + # Do special casing for operators that an have collection + # arguments. + # TODO: (gh #148) handle collection-level NDISTINCT + if ( + expr.operator == pydop.COUNT + and len(expr.args) == 1 + and isinstance(expr.args[0], PyDoughCollectionQDAG) + ): + return self.handle_collection_count(hybrid, expr, child_ref_mapping) + elif expr.operator in (pydop.HAS, pydop.HASNOT): + return self.handle_has_hasnot(hybrid, expr, child_ref_mapping) + elif any( + not isinstance(arg, PyDoughExpressionQDAG) for arg in expr.args + ): + raise NotImplementedError( + f"PyDough does not yet support non-expression arguments for aggregation function {expr.operator}" + ) + # For normal operators, translate their expression arguments + # normally. If it is a non-aggregation, build the function + # call. If it is an aggregation, transform accordingly. # such function that takes in a collection, as none currently # exist that are not aggregations. + expr.operator.is_aggregation for arg in expr.args: if not isinstance(arg, PyDoughExpressionQDAG): raise NotImplementedError( - "PyDough does not yet support converting collections as function arguments to a non-aggregation function" + f"PyDough does not yet support non-expression arguments for function {expr.operator}" ) - args.append(self.make_hybrid_expr(hybrid, arg, child_ref_mapping)) - return HybridFunctionExpr(expr.operator, args, expr.pydough_type) - case ExpressionFunctionCall() if expr.operator.is_aggregation: - # For aggregate function calls, their arguments are translated in - # a manner that identifies what child subtree they correspond too, - # by index, and translates them relative to the subtree. 
Then, the - # aggregation calls are placed into the `aggs` mapping of the - # corresponding child connection, and the aggregation call becomes - # a child reference (referring to the aggs list), since after - # translation, an aggregated child subtree only has the grouping - # keys & the aggregation calls as opposed to its other terms. - child_idx: int | None = None - arg_child_idx: int | None = None - for arg in expr.args: - if isinstance(arg, PyDoughExpressionQDAG): - hybrid_arg, arg_child_idx = self.make_hybrid_agg_expr( - hybrid, arg, child_ref_mapping + args.append( + self.make_hybrid_expr( + hybrid, + arg, + child_ref_mapping, + inside_agg or expr.operator.is_aggregation, ) - else: - if not isinstance(arg, ChildReferenceCollection): - raise NotImplementedError("Cannot process argument") - # TODO: (gh #148) handle collection-level NDISTINCT - if expr.operator == pydop.COUNT: - return self.handle_collection_count( - hybrid, expr, child_ref_mapping - ) - elif expr.operator in (pydop.HAS, pydop.HASNOT): - return self.handle_has_hasnot( - hybrid, expr, child_ref_mapping - ) - else: - raise NotImplementedError( - f"PyDough does not yet support collection arguments for aggregation function {expr.operator}" - ) - # Accumulate the `arg_child_idx` value from the argument across - # all function arguments, ensuring that at the end there is - # exactly one child subtree that the agg call corresponds to. 
- if arg_child_idx is not None: - if child_idx is None: - child_idx = arg_child_idx - elif arg_child_idx != child_idx: - raise NotImplementedError( - "Unsupported case: multiple child indices referenced by aggregation arguments" - ) - args.append(hybrid_arg) - if child_idx is None: - raise NotImplementedError( - "Unsupported case: no child indices referenced by aggregation arguments" ) - hybrid_call: HybridFunctionExpr = HybridFunctionExpr( - expr.operator, args, expr.pydough_type - ) - child_connection = hybrid.children[child_idx] - # Generate a unique name for the agg call to push into the child - # connection. - agg_name: str = self.get_agg_name(child_connection) - child_connection.aggs[agg_name] = hybrid_call - result_ref: HybridExpr = HybridChildRefExpr( - agg_name, child_idx, expr.pydough_type - ) - joins_can_nullify: bool = not isinstance(hybrid.pipeline[0], HybridRoot) - return self.postprocess_agg_output( - hybrid_call, result_ref, joins_can_nullify - ) + if expr.operator.is_aggregation: + return self.make_agg_call(hybrid, expr, args) + else: + return HybridFunctionExpr(expr.operator, args, expr.pydough_type) case WindowCall(): partition_args: list[HybridExpr] = [] order_args: list[HybridCollation] = [] @@ -1700,7 +1786,7 @@ def make_hybrid_expr( partition_args.append(shifted_arg) for arg in expr.collation_args: hybrid_arg = self.make_hybrid_expr( - hybrid, arg.expr, child_ref_mapping + hybrid, arg.expr, child_ref_mapping, inside_agg ) order_args.append(HybridCollation(hybrid_arg, arg.asc, arg.na_last)) return HybridWindowExpr( @@ -1739,7 +1825,9 @@ def process_hybrid_collations( hybrid_orderings: list[HybridCollation] = [] for collation in collations: name = self.get_ordering_name(hybrid) - expr = self.make_hybrid_expr(hybrid, collation.expr, child_ref_mapping) + expr = self.make_hybrid_expr( + hybrid, collation.expr, child_ref_mapping, False + ) new_expressions[name] = expr new_collation: HybridCollation = HybridCollation( HybridRefExpr(name, 
collation.expr.pydough_type), @@ -1750,7 +1838,7 @@ def process_hybrid_collations( return new_expressions, hybrid_orderings def make_hybrid_tree( - self, node: PyDoughCollectionQDAG, parent: HybridTree | None = None + self, node: PyDoughCollectionQDAG, parent: HybridTree | None ) -> HybridTree: """ Converts a collection QDAG into the HybridTree format. @@ -1792,7 +1880,7 @@ def make_hybrid_tree( new_expressions: dict[str, HybridExpr] = {} for name in sorted(node.calc_terms): expr = self.make_hybrid_expr( - hybrid, node.get_expr(name), child_ref_mapping + hybrid, node.get_expr(name), child_ref_mapping, False ) new_expressions[name] = expr hybrid.pipeline.append( @@ -1806,7 +1894,9 @@ def make_hybrid_tree( case Where(): hybrid = self.make_hybrid_tree(node.preceding_context, parent) self.populate_children(hybrid, node, child_ref_mapping) - expr = self.make_hybrid_expr(hybrid, node.condition, child_ref_mapping) + expr = self.make_hybrid_expr( + hybrid, node.condition, child_ref_mapping, False + ) hybrid.pipeline.append(HybridFilter(hybrid.pipeline[-1], expr)) return hybrid case PartitionBy(): @@ -1819,7 +1909,7 @@ def make_hybrid_tree( for key_name in node.calc_terms: key = node.get_expr(key_name) expr = self.make_hybrid_expr( - successor_hybrid, key, child_ref_mapping + successor_hybrid, key, child_ref_mapping, False ) partition.add_key(key_name, expr) key_exprs.append(HybridRefExpr(key_name, expr.typ)) @@ -1862,13 +1952,16 @@ def make_hybrid_tree( successor_hybrid = self.make_hybrid_tree( node.child_access.child_access, parent ) - partition_by = node.child_access.ancestor_context + partition_by = ( + node.child_access.ancestor_context.starting_predecessor + ) assert isinstance(partition_by, PartitionBy) for key in partition_by.keys: rhs_expr: HybridExpr = self.make_hybrid_expr( successor_hybrid, Reference(node.child_access, key.expr.term_name), child_ref_mapping, + False, ) assert isinstance(rhs_expr, HybridRefExpr) lhs_expr: HybridExpr = HybridChildRefExpr( diff 
--git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index 55de7714..494d7fbd 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -27,6 +27,7 @@ CallExpression, ColumnPruner, ColumnReference, + CorrelatedReference, EmptySingleton, ExpressionSortInfo, Filter, @@ -52,6 +53,7 @@ HybridCollectionAccess, HybridColumnExpr, HybridConnection, + HybridCorrelExpr, HybridExpr, HybridFilter, HybridFunctionExpr, @@ -90,11 +92,20 @@ class TranslationOutput: value of that expression. """ + correlated_name: str | None = None + """ + The name that can be used to refer to the relational output in correlated + references. + """ + class RelTranslation: def __init__(self): # An index used for creating fake column names self.dummy_idx = 1 + # A stack of contexts used to point to ancestors for correlated + # references. + self.stack: list[TranslationOutput] = [] def make_null_column(self, relation: RelationalNode) -> ColumnReference: """ @@ -145,6 +156,24 @@ def get_column_name( new_name = f"{name}_{self.dummy_idx}" return new_name + def get_correlated_name(self, context: TranslationOutput) -> str: + """ + Finds the name used to refer to a context for correlated variable + access. If the context does not have a correlated name, a new one is + generated for it. + + Args: + `context`: the context containing the relational subtree being + referrenced in a correlated variable access. + + Returns: + The name used to refer to the context in a correlated reference. 
+ """ + if context.correlated_name is None: + context.correlated_name = f"corr{self.dummy_idx}" + self.dummy_idx += 1 + return context.correlated_name + def translate_expression( self, expr: HybridExpr, context: TranslationOutput | None ) -> RelationalExpression: @@ -199,8 +228,32 @@ def translate_expression( order_inputs, expr.kwargs, ) + case HybridCorrelExpr(): + # Convert correlated expressions by converting the expression + # they point to in the context of the top of the stack, then + # wrapping the result in a correlated reference. + ancestor_context: TranslationOutput = self.stack.pop() + ancestor_expr: RelationalExpression = self.translate_expression( + expr.expr, ancestor_context + ) + self.stack.append(ancestor_context) + match ancestor_expr: + case ColumnReference(): + return CorrelatedReference( + ancestor_expr.name, + self.get_correlated_name(ancestor_context), + expr.typ, + ) + case CorrelatedReference(): + return ancestor_expr + case _: + raise ValueError( + f"Unsupported expression to reference in a correlated reference: {ancestor_expr}" + ) case _: - raise NotImplementedError(expr.__class__.__name__) + raise NotImplementedError( + f"TODO: support relational conversion on {expr.__class__.__name__}" + ) def join_outputs( self, @@ -257,6 +310,7 @@ def join_outputs( [LiteralExpression(True, BooleanType())], [join_type], join_columns, + correl_name=lhs_result.correlated_name, ) input_aliases: list[str | None] = out_rel.default_input_aliases @@ -397,9 +451,11 @@ def handle_children( """ for child_idx, child in enumerate(hybrid.children): if child.required_steps == pipeline_idx: + self.stack.append(context) child_output = self.rel_translation( child, child.subtree, len(child.subtree.pipeline) - 1 ) + self.stack.pop() assert child.subtree.join_keys is not None join_keys: list[tuple[HybridExpr, HybridExpr]] = child.subtree.join_keys agg_keys: list[HybridExpr] @@ -945,7 +1001,7 @@ def convert_ast_to_relational( # Convert the QDAG node to the hybrid form, 
then invoke the relational # conversion procedure. The first element in the returned list is the # final rel node. - hybrid: HybridTree = HybridTranslator(configs).make_hybrid_tree(node) + hybrid: HybridTree = HybridTranslator(configs).make_hybrid_tree(node, None) renamings: dict[str, str] = hybrid.pipeline[-1].renamings output: TranslationOutput = translator.rel_translation( None, hybrid, len(hybrid.pipeline) - 1 diff --git a/pydough/relational/README.md b/pydough/relational/README.md index d0259995..57939f8b 100644 --- a/pydough/relational/README.md +++ b/pydough/relational/README.md @@ -18,6 +18,8 @@ The relational_expressions submodule provides functionality to define and manage - `ExpressionSortInfo`: The representation of ordering for an expression within a relational node. - `RelationalExpressionVisitor`: The basic Visitor pattern to perform operations across the expression components of a relational tree. - `ColumnReferenceFinder`: Finds all unique column references in a relational expression. +- `CorrelatedReference`: The expression implementation for accessing a correlated column reference in a relational node. +- `CorrelatedReferenceFinder`: Finds all unique correlated references in a relational expression. - `RelationalExpressionShuttle`: Specialized form of the visitor pattern that returns a relational expression. - `ColumnReferenceInputNameModifier`: Shuttle implementation designed to update all uses of a column reference's input name to a new input name based on a dictionary. 
@@ -33,6 +35,7 @@ from pydough.relational.relational_expressions import ( ExpressionSortInfo, ColumnReferenceFinder, ColumnReferenceInputNameModifier, + CorrelatedReferenceFinder, WindowCallExpression, ) from pydough.pydough_operators import ADD, RANKING @@ -64,6 +67,11 @@ unique_column_refs = finder.get_column_references() # Modify the input name of column references in the call expression modifier = ColumnReferenceInputNameModifier({"old_input_name": "new_input_name"}) modified_call_expr = call_expr.accept_shuttle(modifier) + +# Find all unique correlated references in the call expression +correlated_finder = CorrelatedReferenceFinder() +call_expr.accept(correlated_finder) +unique_correlated_refs = correlated_finder.get_correlated_references() ``` ## [Relational Nodes](relational_nodes/README.md) diff --git a/pydough/relational/__init__.py b/pydough/relational/__init__.py index 75591d4f..b954a2f3 100644 --- a/pydough/relational/__init__.py +++ b/pydough/relational/__init__.py @@ -6,6 +6,7 @@ "ColumnReferenceFinder", "ColumnReferenceInputNameModifier", "ColumnReferenceInputNameRemover", + "CorrelatedReference", "EmptySingleton", "ExpressionSortInfo", "Filter", @@ -30,6 +31,7 @@ ColumnReferenceFinder, ColumnReferenceInputNameModifier, ColumnReferenceInputNameRemover, + CorrelatedReference, ExpressionSortInfo, LiteralExpression, RelationalExpression, diff --git a/pydough/relational/relational_expressions/README.md b/pydough/relational/relational_expressions/README.md index dcca7290..c11c8f44 100644 --- a/pydough/relational/relational_expressions/README.md +++ b/pydough/relational/relational_expressions/README.md @@ -45,6 +45,14 @@ The relational_expressions module provides functionality to define and manage va - `ColumnReferenceFinder`: Finds all unique column references in a relational expression. 
+### [correlated_reference.py](correlated_reference.py) + +- `CorrelatedReference`: The expression implementation for accessing a correlated column reference in a relational node. + +### [correlated_reference_finder.py](correlated_reference_finder.py) + +- `CorrelatedReferenceFinder`: Finds all unique correlated references in a relational expression. + ### [relational_expression_shuttle.py](relational_expression_shuttle.py) - `RelationalExpressionShuttle`: Specialized form of the visitor pattern that returns a relational expression. This is used to handle the common case where we need to modify a type of input. @@ -69,6 +77,8 @@ from pydough.relational.relational_expressions import ( ExpressionSortInfo, ColumnReferenceFinder, ColumnReferenceInputNameModifier, + CorrelatedReference, + CorrelatedReferenceFinder, ) from pydough.pydough_operators import ADD from pydough.types import Int64Type @@ -82,6 +92,10 @@ literal_expr = LiteralExpression(10, Int64Type()) # Create a call expression for addition call_expr = CallExpression(ADD, Int64Type(), [column_ref, literal_expr]) +# Create a correlated reference to column `column_name` in the first input to +# an ancestor join of `corr1` +correlated_ref = CorrelatedReference("column_name", "corr1", Int64Type()) + # Create an expression sort info sort_info = ExpressionSortInfo(call_expr, ascending=True, nulls_first=False) @@ -96,4 +110,9 @@ unique_column_refs = finder.get_column_references() # Modify the input name of column references in the call expression modifier = ColumnReferenceInputNameModifier({"old_input_name": "new_input_name"}) modified_call_expr = call_expr.accept_shuttle(modifier) + +# Find all unique correlated references in the call expression +correlated_finder = CorrelatedReferenceFinder() +call_expr.accept(correlated_finder) +unique_correlated_refs = correlated_finder.get_correlated_references() ``` diff --git a/pydough/relational/relational_expressions/__init__.py 
b/pydough/relational/relational_expressions/__init__.py index 68838524..3eb8fc33 100644 --- a/pydough/relational/relational_expressions/__init__.py +++ b/pydough/relational/relational_expressions/__init__.py @@ -9,6 +9,8 @@ "ColumnReferenceFinder", "ColumnReferenceInputNameModifier", "ColumnReferenceInputNameRemover", + "CorrelatedReference", + "CorrelatedReferenceFinder", "ExpressionSortInfo", "LiteralExpression", "RelationalExpression", @@ -21,6 +23,8 @@ from .column_reference_finder import ColumnReferenceFinder from .column_reference_input_name_modifier import ColumnReferenceInputNameModifier from .column_reference_input_name_remover import ColumnReferenceInputNameRemover +from .correlated_reference import CorrelatedReference +from .correlated_reference_finder import CorrelatedReferenceFinder from .expression_sort_info import ExpressionSortInfo from .literal_expression import LiteralExpression from .relational_expression_visitor import RelationalExpressionVisitor diff --git a/pydough/relational/relational_expressions/abstract_expression.py b/pydough/relational/relational_expressions/abstract_expression.py index 7236ad0b..ac3b6d33 100644 --- a/pydough/relational/relational_expressions/abstract_expression.py +++ b/pydough/relational/relational_expressions/abstract_expression.py @@ -81,6 +81,9 @@ def to_string(self, compact: bool = False) -> str: def __repr__(self) -> str: return self.to_string() + def __hash__(self) -> int: + return hash(self.to_string()) + @abstractmethod def accept(self, visitor: RelationalExpressionVisitor) -> None: """ diff --git a/pydough/relational/relational_expressions/column_reference_finder.py b/pydough/relational/relational_expressions/column_reference_finder.py index 7de5bc88..d8c7ba61 100644 --- a/pydough/relational/relational_expressions/column_reference_finder.py +++ b/pydough/relational/relational_expressions/column_reference_finder.py @@ -42,3 +42,6 @@ def visit_literal_expression(self, literal_expression: LiteralExpression) -> 
Non def visit_column_reference(self, column_reference: ColumnReference) -> None: self._column_references.add(column_reference) + + def visit_correlated_reference(self, correlated_reference) -> None: + pass diff --git a/pydough/relational/relational_expressions/column_reference_input_name_modifier.py b/pydough/relational/relational_expressions/column_reference_input_name_modifier.py index 0a750462..fd738ad7 100644 --- a/pydough/relational/relational_expressions/column_reference_input_name_modifier.py +++ b/pydough/relational/relational_expressions/column_reference_input_name_modifier.py @@ -45,3 +45,6 @@ def visit_column_reference(self, column_reference) -> RelationalExpression: raise ValueError( f"Input name {column_reference.input_name} not found in the input name map." ) + + def visit_correlated_reference(self, correlated_reference) -> RelationalExpression: + return correlated_reference diff --git a/pydough/relational/relational_expressions/column_reference_input_name_remover.py b/pydough/relational/relational_expressions/column_reference_input_name_remover.py index 26633764..0de2e1d9 100644 --- a/pydough/relational/relational_expressions/column_reference_input_name_remover.py +++ b/pydough/relational/relational_expressions/column_reference_input_name_remover.py @@ -37,3 +37,6 @@ def visit_column_reference(self, column_reference) -> RelationalExpression: column_reference.data_type, None, ) + + def visit_correlated_reference(self, correlated_reference) -> RelationalExpression: + return correlated_reference diff --git a/pydough/relational/relational_expressions/correlated_reference.py b/pydough/relational/relational_expressions/correlated_reference.py new file mode 100644 index 00000000..e6be2790 --- /dev/null +++ b/pydough/relational/relational_expressions/correlated_reference.py @@ -0,0 +1,65 @@ +""" +The representation of a correlated column access for use in a relational tree. 
+The correl name should be the `correl_name` property of a join ancestor of the +tree, and the name should match one of the column names of the first input to +that join, which is the column that the correlated reference refers to. +""" + +__all__ = ["CorrelatedReference"] + +from pydough.types import PyDoughType + +from .abstract_expression import RelationalExpression +from .relational_expression_shuttle import RelationalExpressionShuttle +from .relational_expression_visitor import RelationalExpressionVisitor + + +class CorrelatedReference(RelationalExpression): + """ + The Expression implementation for accessing a correlated column reference + in a relational node. + """ + + def __init__(self, name: str, correl_name: str, data_type: PyDoughType) -> None: + super().__init__(data_type) + self._name: str = name + self._correl_name: str = correl_name + + def __hash__(self) -> int: + return hash((self.name, self.correl_name, self.data_type)) + + @property + def name(self) -> str: + """ + The name of the column. + """ + return self._name + + @property + def correl_name(self) -> str: + """ + The name of the correlation that the reference points to. 
+ """ + return self._correl_name + + def to_string(self, compact: bool = False) -> str: + if compact: + return f"{self.correl_name}.{self.name}" + else: + return f"CorrelatedReference(name={self.name}, correl_name={self.correl_name}, type={self.data_type})" + + def equals(self, other: object) -> bool: + return ( + isinstance(other, CorrelatedReference) + and (self.name == other.name) + and (self.correl_name == other.correl_name) + and super().equals(other) + ) + + def accept(self, visitor: RelationalExpressionVisitor) -> None: + visitor.visit_correlated_reference(self) + + def accept_shuttle( + self, shuttle: RelationalExpressionShuttle + ) -> RelationalExpression: + return shuttle.visit_correlated_reference(self) diff --git a/pydough/relational/relational_expressions/correlated_reference_finder.py b/pydough/relational/relational_expressions/correlated_reference_finder.py new file mode 100644 index 00000000..20ae0dc2 --- /dev/null +++ b/pydough/relational/relational_expressions/correlated_reference_finder.py @@ -0,0 +1,48 @@ +""" +Find all unique column references in a relational expression. +""" + +from .call_expression import CallExpression +from .column_reference import ColumnReference +from .correlated_reference import CorrelatedReference +from .literal_expression import LiteralExpression +from .relational_expression_visitor import RelationalExpressionVisitor +from .window_call_expression import WindowCallExpression + +__all__ = ["CorrelatedReferenceFinder"] + + +class CorrelatedReferenceFinder(RelationalExpressionVisitor): + """ + Find all unique correlated references in a relational expression. 
+ """ + + def __init__(self) -> None: + self._correlated_references: set[CorrelatedReference] = set() + + def reset(self) -> None: + self._correlated_references = set() + + def get_correlated_references(self) -> set[CorrelatedReference]: + return self._correlated_references + + def visit_call_expression(self, call_expression: CallExpression) -> None: + for arg in call_expression.inputs: + arg.accept(self) + + def visit_window_expression(self, window_expression: WindowCallExpression) -> None: + for arg in window_expression.inputs: + arg.accept(self) + for partition_arg in window_expression.partition_inputs: + partition_arg.accept(self) + for order_arg in window_expression.order_inputs: + order_arg.expr.accept(self) + + def visit_literal_expression(self, literal_expression: LiteralExpression) -> None: + pass + + def visit_column_reference(self, column_reference: ColumnReference) -> None: + pass + + def visit_correlated_reference(self, correlated_reference) -> None: + self._correlated_references.add(correlated_reference) diff --git a/pydough/relational/relational_expressions/literal_expression.py b/pydough/relational/relational_expressions/literal_expression.py index 3b0eb803..1edc9981 100644 --- a/pydough/relational/relational_expressions/literal_expression.py +++ b/pydough/relational/relational_expressions/literal_expression.py @@ -29,10 +29,6 @@ def __init__(self, value: Any, data_type: PyDoughType): super().__init__(data_type) self._value: Any = value - def __hash__(self) -> int: - # Note: This will break if the value isn't hashable. 
- return hash((self.value, self.data_type)) - @property def value(self) -> object: """ diff --git a/pydough/relational/relational_expressions/relational_expression_shuttle.py b/pydough/relational/relational_expressions/relational_expression_shuttle.py index 354a4b8e..badc0b1d 100644 --- a/pydough/relational/relational_expressions/relational_expression_shuttle.py +++ b/pydough/relational/relational_expressions/relational_expression_shuttle.py @@ -75,3 +75,14 @@ def visit_column_reference(self, column_reference): Returns: RelationalExpression: The new node resulting from visiting this node. """ + + @abstractmethod + def visit_correlated_reference(self, correlated_reference): + """ + Visit a CorrelatedReference node. + + Args: + correlated_reference (CorrelatedReference): The correlated reference node to visit. + Returns: + RelationalExpression: The new node resulting from visiting this node. + """ diff --git a/pydough/relational/relational_expressions/relational_expression_visitor.py b/pydough/relational/relational_expressions/relational_expression_visitor.py index 39746b1a..0873662a 100644 --- a/pydough/relational/relational_expressions/relational_expression_visitor.py +++ b/pydough/relational/relational_expressions/relational_expression_visitor.py @@ -61,3 +61,12 @@ def visit_column_reference(self, column_reference) -> None: Args: column_reference (ColumnReference): The column reference node to visit. """ + + @abstractmethod + def visit_correlated_reference(self, correlated_reference) -> None: + """ + Visit a CorrelatedReference node. + + Args: + correlated_reference (CorrelatedReference): The correlated reference node to visit. 
+ """ diff --git a/pydough/relational/relational_nodes/column_pruner.py b/pydough/relational/relational_nodes/column_pruner.py index 91f7f476..47cda027 100644 --- a/pydough/relational/relational_nodes/column_pruner.py +++ b/pydough/relational/relational_nodes/column_pruner.py @@ -5,10 +5,13 @@ from pydough.relational.relational_expressions import ( ColumnReference, ColumnReferenceFinder, + CorrelatedReference, + CorrelatedReferenceFinder, ) from .abstract_node import RelationalNode from .aggregate import Aggregate +from .join import Join from .project import Project from .relational_expression_dispatcher import RelationalExpressionDispatcher from .relational_root import RelationalRoot @@ -19,11 +22,15 @@ class ColumnPruner: def __init__(self) -> None: self._column_finder: ColumnReferenceFinder = ColumnReferenceFinder() + self._correl_finder: CorrelatedReferenceFinder = CorrelatedReferenceFinder() # Note: We set recurse=False so we only check the expressions in the # current node. - self._dispatcher = RelationalExpressionDispatcher( + self._finder_dispatcher = RelationalExpressionDispatcher( self._column_finder, recurse=False ) + self._correl_dispatcher = RelationalExpressionDispatcher( + self._correl_finder, recurse=False + ) def _prune_identity_project(self, node: RelationalNode) -> RelationalNode: """ @@ -43,7 +50,7 @@ def _prune_identity_project(self, node: RelationalNode) -> RelationalNode: def _prune_node_columns( self, node: RelationalNode, kept_columns: set[str] - ) -> RelationalNode: + ) -> tuple[RelationalNode, set[CorrelatedReference]]: """ Prune the columns for a subtree starting at this node. @@ -68,14 +75,17 @@ def _prune_node_columns( for name, expr in node.columns.items() if name in kept_columns or name in required_columns } + # Update the columns. new_node = node.copy(columns=columns) - self._dispatcher.reset() - # Visit the current identifiers. - new_node.accept(self._dispatcher) + + # Find all the identifiers referenced by the the current node. 
+ self._finder_dispatcher.reset() + new_node.accept(self._finder_dispatcher) found_identifiers: set[ColumnReference] = ( self._column_finder.get_column_references() ) + # If the node is an aggregate but doesn't use any of the inputs # (e.g. a COUNT(*)), arbitrarily mark one of them as used. # TODO: (gh #196) optimize this functionality so it doesn't keep an @@ -88,19 +98,50 @@ def _prune_node_columns( node.input.columns[arbitrary_column_name].data_type, ) ) + # Determine which identifiers to pass to each input. new_inputs: list[RelationalNode] = [] # Note: The ColumnPruner should only be run when all input names are # still present in the columns. - for i, default_input_name in enumerate(new_node.default_input_aliases): + # Iterate over the inputs in reverse order so that the source of + # correlated data is pruned last, since it will need to account for + # any correlated references in the later inputs. + correl_refs: set[CorrelatedReference] = set() + for i, default_input_name in reversed( + list(enumerate(new_node.default_input_aliases)) + ): s: set[str] = set() + input_node: RelationalNode = node.inputs[i] for identifier in found_identifiers: if identifier.input_name == default_input_name: s.add(identifier.name) - new_inputs.append(self._prune_node_columns(node.inputs[i], s)) + if ( + isinstance(new_node, Join) + and i == 0 + and new_node.correl_name is not None + ): + for correl_ref in correl_refs: + if correl_ref.correl_name == new_node.correl_name: + s.add(correl_ref.name) + new_input_node, new_correl_refs = self._prune_node_columns(input_node, s) + new_inputs.append(new_input_node) + if i == len(node.inputs) - 1: + correl_refs = new_correl_refs + else: + correl_refs.update(new_correl_refs) + new_inputs.reverse() + + # Find all the correlated references in the new node. 
+ self._correl_dispatcher.reset() + new_node.accept(self._correl_dispatcher) + found_correl_refs: set[CorrelatedReference] = ( + self._correl_finder.get_correlated_references() + ) + correl_refs.update(found_correl_refs) + # Determine the new node. output = new_node.copy(inputs=new_inputs) - return self._prune_identity_project(output) + return self._prune_identity_project(output), correl_refs def prune_unused_columns(self, root: RelationalRoot) -> RelationalRoot: """ @@ -112,8 +153,6 @@ def prune_unused_columns(self, root: RelationalRoot) -> RelationalRoot: Returns: RelationalRoot: The root after updating all inputs. """ - new_root: RelationalNode = self._prune_node_columns( - root, set(root.columns.keys()) - ) + new_root, _ = self._prune_node_columns(root, set(root.columns.keys())) assert isinstance(new_root, RelationalRoot), "Expected a root node." return new_root diff --git a/pydough/relational/relational_nodes/join.py b/pydough/relational/relational_nodes/join.py index f1629689..f7759571 100644 --- a/pydough/relational/relational_nodes/join.py +++ b/pydough/relational/relational_nodes/join.py @@ -49,6 +49,7 @@ def __init__( conditions: list[RelationalExpression], join_types: list[JoinType], columns: MutableMapping[str, RelationalExpression], + correl_name: str | None = None, ) -> None: super().__init__(columns) num_inputs = len(inputs) @@ -65,6 +66,15 @@ def __init__( ), "Join condition must be a boolean type" self._conditions: list[RelationalExpression] = conditions self._join_types: list[JoinType] = join_types + self._correl_name: str | None = correl_name + + @property + def correl_name(self) -> str | None: + """ + The name used to refer to the first join input when subsequent inputs + have correlated references. 
+ """ + return self._correl_name @property def conditions(self) -> list[RelationalExpression]: @@ -101,6 +111,7 @@ def node_equals(self, other: RelationalNode) -> bool: isinstance(other, Join) and self.conditions == other.conditions and self.join_types == other.join_types + and self.correl_name == other.correl_name and all( self.inputs[i].node_equals(other.inputs[i]) for i in range(len(self.inputs)) @@ -109,7 +120,10 @@ def node_equals(self, other: RelationalNode) -> bool: def to_string(self, compact: bool = False) -> str: conditions: list[str] = [cond.to_string(compact) for cond in self.conditions] - return f"JOIN(conditions=[{', '.join(conditions)}], types={[t.value for t in self.join_types]}, columns={self.make_column_string(self.columns, compact)})" + correl_suffix = ( + "" if self.correl_name is None else f", correl_name={self.correl_name!r}" + ) + return f"JOIN(conditions=[{', '.join(conditions)}], types={[t.value for t in self.join_types]}, columns={self.make_column_string(self.columns, compact)}{correl_suffix})" def accept(self, visitor: RelationalVisitor) -> None: visitor.visit_join(self) @@ -119,4 +133,4 @@ def node_copy( columns: MutableMapping[str, RelationalExpression], inputs: MutableSequence[RelationalNode], ) -> RelationalNode: - return Join(inputs, self.conditions, self.join_types, columns) + return Join(inputs, self.conditions, self.join_types, columns, self.correl_name) diff --git a/pydough/sqlglot/sqlglot_relational_expression_visitor.py b/pydough/sqlglot/sqlglot_relational_expression_visitor.py index 4ca111ba..017ebe60 100644 --- a/pydough/sqlglot/sqlglot_relational_expression_visitor.py +++ b/pydough/sqlglot/sqlglot_relational_expression_visitor.py @@ -13,6 +13,7 @@ from pydough.relational import ( CallExpression, ColumnReference, + CorrelatedReference, LiteralExpression, RelationalExpression, RelationalExpressionVisitor, @@ -133,6 +134,11 @@ def visit_literal_expression(self, literal_expression: LiteralExpression) -> Non ) 
self._stack.append(literal) + def visit_correlated_reference( + self, correlated_reference: CorrelatedReference + ) -> None: + raise NotImplementedError("TODO") + @staticmethod def generate_column_reference_identifier( column_reference: ColumnReference, diff --git a/tests/correlated_pydough_functions.py b/tests/correlated_pydough_functions.py new file mode 100644 index 00000000..e3b73053 --- /dev/null +++ b/tests/correlated_pydough_functions.py @@ -0,0 +1,302 @@ +""" +Variant of `simple_pydough_functions.py` for functions testing edge cases in +correlation & de-correlation handling. +""" + +# ruff: noqa +# mypy: ignore-errors +# ruff & mypy should not try to typecheck or verify any of this + + +def correl_1(): + # Correlated back reference example #1: simple 1-step correlated reference + # For each region, count how many of its its nations start with the same + # letter as the region. This is a true correlated join doing an aggregated + # access without requiring the RHS be present. + return Regions( + name, n_prefix_nations=COUNT(nations.WHERE(name[:1] == BACK(1).name[:1])) + ).ORDER_BY(name.ASC()) + + +def correl_2(): + # Correlated back reference example #2: simple 2-step correlated reference + # For each region's nations, count how many customers have a comment + # starting with the same letter as the region. Exclude regions that start + # with the letter a. This is a true correlated join doing an aggregated + # access without requiring the RHS be present. + selected_custs = customers.WHERE(comment[:1] == LOWER(BACK(2).name[:1])) + return ( + Regions.WHERE(~STARTSWITH(name, "A")) + .nations( + name, + n_selected_custs=COUNT(selected_custs), + ) + .ORDER_BY(name.ASC()) + ) + + +def correl_3(): + # Correlated back reference example #3: double-layer correlated reference + # For every every region, count how many of its nations have a customer + # whose comment starts with the same 2 letter as the region. 
This is a true + # correlated join doing an aggregated access without requiring the RHS be + # present. + selected_custs = customers.WHERE(comment[:2] == LOWER(BACK(2).name[:2])) + return Regions(name, n_nations=COUNT(nations.WHERE(HAS(selected_custs)))).ORDER_BY( + name.ASC() + ) + + +def correl_4(): + # Correlated back reference example #4: 2-step correlated HASNOT + # Find every nation that does not have a customer whose account balance is + # within $5 of the smallest known account balance globally. + # (This is a correlated ANTI-join) + selected_customers = customers.WHERE(acctbal <= (BACK(2).smallest_bal + 5.0)) + return ( + TPCH( + smallest_bal=MIN(Customers.acctbal), + ) + .Nations(name) + .WHERE(HASNOT(selected_customers)) + .ORDER_BY(name.ASC()) + ) + + +def correl_5(): + # Correlated back reference example #5: 2-step correlated HAS + # Find every region that has at least 1 supplier whose account balance is + # within $4 of the smallest known account balance globally. + # (This is a correlated SEMI-join) + selected_suppliers = nations.suppliers.WHERE( + account_balance <= (BACK(3).smallest_bal + 4.0) + ) + return ( + TPCH( + smallest_bal=MIN(Suppliers.account_balance), + ) + .Regions(name) + .WHERE(HAS(selected_suppliers)) + .ORDER_BY(name.ASC()) + ) + + +def correl_6(): + # Correlated back reference example #6: simple 1-step correlated reference + # For each region, count how many of its its nations start with the same + # letter as the region, but only keep regions with at least one such nation. + # This is a true correlated join doing an aggregated access that does NOT + # require that records without the RHS be kept. 
+ selected_nations = nations.WHERE(name[:1] == BACK(1).name[:1]) + return Regions.WHERE(HAS(selected_nations))( + name, n_prefix_nations=COUNT(selected_nations) + ) + + +def correl_7(): + # Correlated back reference example #7: deleted correlated reference + # For each region, count how many of its nations start with the same + # letter as the region, but only keep regions without at least one such + # nation. The true correlated join is trumped by the correlated ANTI-join. + selected_nations = nations.WHERE(name[:1] == BACK(1).name[:1]) + return Regions.WHERE(HASNOT(selected_nations))( + name, n_prefix_nations=COUNT(selected_nations) + ) + + +def correl_8(): + # Correlated back reference example #8: non-agg correlated reference + # For each nation, fetch the name of its region, but filter the region + # so it only keeps it if it starts with the same letter as the nation + # (otherwise, returns NULL). This is a true correlated join doing an + # access without aggregation without requiring the RHS be + # present. + aug_region = region.WHERE(name[:1] == BACK(1).name[:1]) + return Nations(name, rname=aug_region.name).ORDER_BY(name.ASC()) + + +def correl_9(): + # Correlated back reference example #9: non-agg correlated reference + # For each nation, fetch the name of its region, but filter the region + # so it only keeps it if it starts with the same letter as the nation + # (otherwise, omit the nation). This is a true correlated join doing an + # access that also requires the RHS records be present. 
+ aug_region = region.WHERE(name[:1] == BACK(1).name[:1]) + return Nations.WHERE(HAS(aug_region))(name, rname=aug_region.name).ORDER_BY( + name.ASC() + ) + + +def correl_10(): + # Correlated back reference example #10: deleted correlated reference + # For each nation, fetch the name of its region, but filter the region + # so it only keeps it if it starts with the same letter as the nation + # (otherwise, returns NULL), and also filter the nations to only keep + # records where the region is NULL. The true correlated join is trumped by + # the correlated ANTI-join. + aug_region = region.WHERE(name[:1] == BACK(1).name[:1]) + return Nations.WHERE(HASNOT(aug_region))(name, rname=aug_region.name).ORDER_BY( + name.ASC() + ) + + +def correl_11(): + # Correlated back reference example #11: backref out of partition child. + # Which part brands have at least 1 part that is more than 40% above the + # average retail price for all parts from that brand. + # (This is a correlated SEMI-join) + brands = PARTITION(Parts, name="p", by=brand)(avg_price=AVG(p.retail_price)) + outlier_parts = p.WHERE(retail_price > 1.4 * BACK(1).avg_price) + selected_brands = brands.WHERE(HAS(outlier_parts)) + return selected_brands(brand).ORDER_BY(brand.ASC()) + + +def correl_12(): + # Correlated back reference example #12: backref out of partition child. + # Which part brands have at least 1 part that is above the average retail + # price for parts of that brand, below the average retail price for all + # parts, and has a size below 3. 
+ # (This is a correlated SEMI-join) + global_info = TPCH(avg_price=AVG(Parts.retail_price)) + brands = global_info.PARTITION(Parts, name="p", by=brand)( + avg_price=AVG(p.retail_price) + ) + selected_parts = p.WHERE( + (retail_price > BACK(1).avg_price) + & (retail_price < BACK(2).avg_price) + & (size < 3) + ) + selected_brands = brands.WHERE(HAS(selected_parts)) + return selected_brands(brand).ORDER_BY(brand.ASC()) + + +def correl_13(): + # Correlated back reference example #13: multiple correlation. + # Count how many suppliers sell at least one part where the retail price + # is less than a 50% markup over the supply cost. Only considers suppliers + # from nations #1/#2/#3, and small parts. + # (This is a correlated SEMI-joins) + selected_part = part.WHERE( + STARTSWITH(container, "SM") & (retail_price < (BACK(1).supplycost * 1.5)) + ) + selected_supply_records = supply_records.WHERE(HAS(selected_part)) + supplier_info = Suppliers.WHERE(nation_key <= 3)( + avg_price=AVG(supply_records.part.retail_price) + ) + selected_suppliers = supplier_info.WHERE(COUNT(selected_supply_records) > 0) + return TPCH(n=COUNT(selected_suppliers)) + + +def correl_14(): + # Correlated back reference example #14: multiple correlation. + # Count how many suppliers sell at least one part where the retail price + # is less than a 50% markup over the supply cost, and the retail price of + # the part is below the average for all parts from the supplier. Only + # considers suppliers from nations #19, and LG DRUM parts. 
+ # (This is multiple correlated SEMI-joins) + selected_part = part.WHERE( + (container == "LG DRUM") + & (retail_price < (BACK(1).supplycost * 1.5)) + & (retail_price < BACK(2).avg_price) + ) + selected_supply_records = supply_records.WHERE(HAS(selected_part)) + supplier_info = Suppliers.WHERE(nation_key == 19)( + avg_price=AVG(supply_records.part.retail_price) + ) + selected_suppliers = supplier_info.WHERE(HAS(selected_supply_records)) + return TPCH(n=COUNT(selected_suppliers)) + + +def correl_15(): + # Correlated back reference example #15: multiple correlation. + # Count how many suppliers sell at least one part where the retail price + # is less than a 50% markup over the supply cost, and the retail price of + # the part is below the 85% of the average of the retail price for all + # parts globally and below the average for all parts from the supplier. + # Only considers suppliers from nations #19, and LG DRUM parts. + # (This is multiple correlated SEMI-joins & a correlated aggregate) + selected_part = part.WHERE( + (container == "LG DRUM") + & (retail_price < (BACK(1).supplycost * 1.5)) + & (retail_price < BACK(2).avg_price) + & (retail_price < BACK(3).avg_price * 0.85) + ) + selected_supply_records = supply_records.WHERE(HAS(selected_part)) + supplier_info = Suppliers.WHERE(nation_key == 19)( + avg_price=AVG(supply_records.part.retail_price) + ) + selected_suppliers = supplier_info.WHERE(HAS(selected_supply_records)) + global_info = TPCH(avg_price=AVG(Parts.retail_price)) + return global_info(n=COUNT(selected_suppliers)) + + +def correl_16(): + # Correlated back reference example #16: hybrid tree order of operations. + # Count how many european suppliers have the exact same percentile value + # of account balance (relative to all other suppliers) as at least one + # customer's percentile value of account balance relative to all other + # customers. Percentile should be measured down to increments of 0.01%. 
+ # (This is a correlated SEMI-joins) + selected_customers = nation(rname=region.name).customers.WHERE( + (PERCENTILE(by=(acctbal.ASC(), key.ASC()), n_buckets=10000) == BACK(2).tile) + & (BACK(1).rname == "EUROPE") + ) + supplier_info = Suppliers( + tile=PERCENTILE(by=(account_balance.ASC(), key.ASC()), n_buckets=10000) + ) + selected_suppliers = supplier_info.WHERE(HAS(selected_customers)) + return TPCH(n=COUNT(selected_suppliers)) + + +def correl_17(): + # Correlated back reference example #17: hybrid tree order of operations. + # An extremely roundabout way of getting each region_name-nation_name + # pair as a string. + # (This is a correlated singular/semi access) + region_info = region(fname=JOIN_STRINGS("-", LOWER(name), BACK(1).lname)) + nation_info = Nations(lname=LOWER(name)).WHERE(HAS(region_info)) + return nation_info(fullname=region_info.fname).ORDER_BY(fullname.ASC()) + + +def correl_18(): + # Correlated back reference example #18: partition decorrelation edge case. + # Count how many orders corresponded to at least half of the total price + # spent by the ordering customer in a single day, but only if the customer + # ordered multiple orders in on that day. Only considers orders made in + # 1993. + # (This is a correlated aggregation access) + cust_date_groups = PARTITION( + Orders.WHERE(YEAR(order_date) == 1993), + name="o", + by=(customer_key, order_date), + ) + selected_groups = cust_date_groups.WHERE(COUNT(o) > 1)( + total_price=SUM(o.total_price), + )(n_above_avg=COUNT(o.WHERE(total_price >= 0.5 * BACK(1).total_price))) + return TPCH(n=SUM(selected_groups.n_above_avg)) + + +def correl_19(): + # Correlated back reference example #19: cardinality edge case. + # For every supplier, count how many customers in the same nation have a + # higher account balance than that supplier. Pick the 5 suppliers with the + # largest such count. 
+ # (This is a correlated aggregation access) + super_cust = customers.WHERE(acctbal > BACK(2).account_balance) + return Suppliers.nation(name=BACK(1).name, n_super_cust=COUNT(super_cust)).TOP_K( + 5, n_super_cust.DESC() + ) + + +def correl_20(): + # Correlated back reference example #20: multiple ancestor uniqueness keys. + # Count the instances where a nation's suppliers shipped a part to a + # customer in the same nation, only counting instances where the order was + # made in June of 1998. + # (This is a correlated singular/semi access) + is_domestic = nation(domestic=name == BACK(5).name).domestic + selected_orders = Nations.customers.orders.WHERE( + (YEAR(order_date) == 1998) & (MONTH(order_date) == 6) + ) + instances = selected_orders.lines.supplier.WHERE(is_domestic) + return TPCH(n=COUNT(instances)) diff --git a/tests/simple_pydough_functions.py b/tests/simple_pydough_functions.py index f96e7061..df322958 100644 --- a/tests/simple_pydough_functions.py +++ b/tests/simple_pydough_functions.py @@ -1,3 +1,6 @@ +""" +Various functions containing PyDough code snippets for testing purposes. 
+""" # ruff: noqa # mypy: ignore-errors # ruff & mypy should not try to typecheck or verify any of this diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 34475607..1ec38017 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -12,6 +12,28 @@ bad_slice_3, bad_slice_4, ) +from correlated_pydough_functions import ( + correl_1, + correl_2, + correl_3, + correl_4, + correl_5, + correl_6, + correl_7, + correl_8, + correl_9, + correl_10, + correl_11, + correl_12, + correl_13, + correl_14, + correl_15, + correl_16, + correl_17, + correl_18, + correl_19, + correl_20, +) from simple_pydough_functions import ( agg_partition, double_partition, @@ -140,7 +162,6 @@ tpch_q5_output, ), id="tpch_q5", - marks=pytest.mark.skip("TODO: support correlated back references"), ), pytest.param( ( @@ -269,7 +290,6 @@ tpch_q21_output, ), id="tpch_q21", - marks=pytest.mark.skip("TODO: support correlated back references"), ), pytest.param( ( @@ -278,7 +298,6 @@ tpch_q22_output, ), id="tpch_q22", - marks=pytest.mark.skip("TODO: support correlated back references"), ), pytest.param( ( @@ -630,6 +649,342 @@ ), id="triple_partition", ), + pytest.param( + ( + correl_1, + "correl_1", + lambda: pd.DataFrame( + { + "name": ["AFRICA", "AMERICA", "ASIA", "EUROPE", "MIDDLE EAST"], + "n_prefix_nations": [1, 1, 0, 0, 0], + } + ), + ), + id="correl_1", + ), + pytest.param( + ( + correl_2, + "correl_2", + lambda: pd.DataFrame( + { + "name": [ + "EGYPT", + "FRANCE", + "GERMANY", + "IRAN", + "IRAQ", + "JORDAN", + "ROMANIA", + "RUSSIA", + "SAUDI ARABIA", + "UNITED KINGDOM", + ], + "n_selected_custs": [ + 19, + 593, + 595, + 15, + 21, + 9, + 588, + 620, + 19, + 585, + ], + } + ), + ), + id="correl_2", + ), + pytest.param( + ( + correl_3, + "correl_3", + lambda: pd.DataFrame( + { + "name": ["AFRICA", "AMERICA", "ASIA", "EUROPE", "MIDDLE EAST"], + "n_nations": [5, 5, 5, 0, 2], + } + ), + ), + id="correl_3", + ), + pytest.param( + ( + correl_4, + "correl_4", + lambda: pd.DataFrame( + 
{ + "name": ["ARGENTINA", "KENYA", "UNITED KINGDOM"], + } + ), + ), + id="correl_4", + ), + pytest.param( + ( + correl_5, + "correl_5", + lambda: pd.DataFrame( + { + "name": ["AFRICA", "ASIA", "MIDDLE EAST"], + } + ), + ), + id="correl_5", + ), + pytest.param( + ( + correl_6, + "correl_6", + lambda: pd.DataFrame( + { + "name": ["AFRICA", "AMERICA"], + "n_prefix_nations": [1, 1], + } + ), + ), + id="correl_6", + ), + pytest.param( + ( + correl_7, + "correl_7", + lambda: pd.DataFrame( + { + "name": ["ASIA", "EUROPE", "MIDDLE EAST"], + "n_prefix_nations": [0] * 3, + } + ), + ), + id="correl_7", + ), + pytest.param( + ( + correl_8, + "correl_8", + lambda: pd.DataFrame( + { + "name": [ + "ALGERIA", + "ARGENTINA", + "BRAZIL", + "CANADA", + "CHINA", + "EGYPT", + "ETHIOPIA", + "FRANCE", + "GERMANY", + "INDIA", + "INDONESIA", + "IRAN", + "IRAQ", + "JAPAN", + "JORDAN", + "KENYA", + "MOROCCO", + "MOZAMBIQUE", + "PERU", + "ROMANIA", + "RUSSIA", + "SAUDI ARABIA", + "UNITED KINGDOM", + "UNITED STATES", + "VIETNAM", + ], + "rname": ["AFRICA", "AMERICA"] + [None] * 23, + } + ), + ), + id="correl_8", + ), + pytest.param( + ( + correl_9, + "correl_9", + lambda: pd.DataFrame( + { + "name": [ + "ALGERIA", + "ARGENTINA", + ], + "rname": ["AFRICA", "AMERICA"], + } + ), + ), + id="correl_9", + ), + pytest.param( + ( + correl_10, + "correl_10", + lambda: pd.DataFrame( + { + "name": [ + "BRAZIL", + "CANADA", + "CHINA", + "EGYPT", + "ETHIOPIA", + "FRANCE", + "GERMANY", + "INDIA", + "INDONESIA", + "IRAN", + "IRAQ", + "JAPAN", + "JORDAN", + "KENYA", + "MOROCCO", + "MOZAMBIQUE", + "PERU", + "ROMANIA", + "RUSSIA", + "SAUDI ARABIA", + "UNITED KINGDOM", + "UNITED STATES", + "VIETNAM", + ], + "rname": [None] * 23, + } + ), + ), + id="correl_10", + ), + pytest.param( + ( + correl_11, + "correl_11", + lambda: pd.DataFrame( + {"brand": ["Brand#33", "Brand#43", "Brand#45", "Brand#55"]} + ), + ), + id="correl_11", + ), + pytest.param( + ( + correl_12, + "correl_12", + lambda: pd.DataFrame( + { + 
"brand": [ + "Brand#14", + "Brand#31", + "Brand#33", + "Brand#43", + "Brand#55", + ] + } + ), + ), + id="correl_12", + ), + pytest.param( + ( + correl_13, + "correl_13", + lambda: pd.DataFrame({"n": [1129]}), + ), + id="correl_13", + ), + pytest.param( + ( + correl_14, + "correl_14", + lambda: pd.DataFrame({"n": [66]}), + ), + id="correl_14", + ), + pytest.param( + ( + correl_15, + "correl_15", + lambda: pd.DataFrame({"n": [61]}), + ), + id="correl_15", + ), + pytest.param( + ( + correl_16, + "correl_16", + lambda: pd.DataFrame({"n": [929]}), + ), + id="correl_16", + ), + pytest.param( + ( + correl_17, + "correl_17", + lambda: pd.DataFrame( + { + "fullname": [ + "africa-algeria", + "africa-ethiopia", + "africa-kenya", + "africa-morocco", + "africa-mozambique", + "america-argentina", + "america-brazil", + "america-canada", + "america-peru", + "america-united states", + "asia-china", + "asia-india", + "asia-indonesia", + "asia-japan", + "asia-vietnam", + "europe-france", + "europe-germany", + "europe-romania", + "europe-russia", + "europe-united kingdom", + "middle east-egypt", + "middle east-iran", + "middle east-iraq", + "middle east-jordan", + "middle east-saudi arabia", + ] + } + ), + ), + id="correl_17", + ), + pytest.param( + ( + correl_18, + "correl_18", + lambda: pd.DataFrame({"n": [697]}), + ), + id="correl_18", + ), + pytest.param( + ( + correl_19, + "correl_19", + lambda: pd.DataFrame( + { + "name": [ + "Supplier#000003934", + "Supplier#000003887", + "Supplier#000002628", + "Supplier#000008722", + "Supplier#000007971", + ], + "n_super_cust": [6160, 6142, 6129, 6127, 6117], + } + ), + ), + id="correl_19", + ), + pytest.param( + ( + correl_20, + "correl_20", + lambda: pd.DataFrame({"n": [3002]}), + ), + id="correl_20", + ), ], ) def pydough_pipeline_test_data( diff --git a/tests/test_plan_refsols/correl_1.txt b/tests/test_plan_refsols/correl_1.txt new file mode 100644 index 00000000..c5956c80 --- /dev/null +++ b/tests/test_plan_refsols/correl_1.txt @@ -0,0 
+1,8 @@ +ROOT(columns=[('name', name), ('n_prefix_nations', n_prefix_nations)], orderings=[(ordering_1):asc_first]) + PROJECT(columns={'n_prefix_nations': n_prefix_nations, 'name': name, 'ordering_1': name}) + PROJECT(columns={'n_prefix_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name}) + JOIN(conditions=[t0.key == t1.region_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr1') + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + AGGREGATE(keys={'region_key': region_key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'region_key': region_key}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) diff --git a/tests/test_plan_refsols/correl_10.txt b/tests/test_plan_refsols/correl_10.txt new file mode 100644 index 00000000..e762421d --- /dev/null +++ b/tests/test_plan_refsols/correl_10.txt @@ -0,0 +1,8 @@ +ROOT(columns=[('name', name), ('rname', rname)], orderings=[(ordering_0):asc_first]) + PROJECT(columns={'name': name, 'ordering_0': name, 'rname': rname}) + PROJECT(columns={'name': name, 'rname': NULL_2}) + FILTER(condition=True:bool, columns={'NULL_2': NULL_2, 'name': name}) + JOIN(conditions=[t0.region_key == t1.key], types=['anti'], columns={'NULL_2': None:unknown, 'name': t0.name}, correl_name='corr1') + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) + FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'key': key}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/correl_11.txt b/tests/test_plan_refsols/correl_11.txt new file mode 100644 index 00000000..7e48bb88 --- /dev/null +++ b/tests/test_plan_refsols/correl_11.txt @@ -0,0 +1,9 @@ +ROOT(columns=[('brand', brand)], 
orderings=[(ordering_1):asc_first]) + PROJECT(columns={'brand': brand, 'ordering_1': brand}) + FILTER(condition=True:bool, columns={'brand': brand}) + JOIN(conditions=[t0.brand == t1.brand], types=['semi'], columns={'brand': t0.brand}, correl_name='corr1') + PROJECT(columns={'avg_price': agg_0, 'brand': brand}) + AGGREGATE(keys={'brand': brand}, aggregations={'agg_0': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'brand': p_brand, 'retail_price': p_retailprice}) + FILTER(condition=retail_price > 1.4:float64 * corr1.avg_price, columns={'brand': brand}) + SCAN(table=tpch.PART, columns={'brand': p_brand, 'retail_price': p_retailprice}) diff --git a/tests/test_plan_refsols/correl_12.txt b/tests/test_plan_refsols/correl_12.txt new file mode 100644 index 00000000..60626462 --- /dev/null +++ b/tests/test_plan_refsols/correl_12.txt @@ -0,0 +1,13 @@ +ROOT(columns=[('brand', brand)], orderings=[(ordering_2):asc_first]) + PROJECT(columns={'brand': brand, 'ordering_2': brand}) + FILTER(condition=True:bool, columns={'brand': brand}) + JOIN(conditions=[t0.brand == t1.brand], types=['semi'], columns={'brand': t0.brand}, correl_name='corr2') + PROJECT(columns={'avg_price': avg_price, 'avg_price_2': agg_1, 'brand': brand}) + JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1, 'avg_price': t0.avg_price, 'brand': t1.brand}) + PROJECT(columns={'avg_price': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'retail_price': p_retailprice}) + AGGREGATE(keys={'brand': brand}, aggregations={'agg_1': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'brand': p_brand, 'retail_price': p_retailprice}) + FILTER(condition=retail_price > corr2.avg_price_2 & retail_price < corr2.avg_price & size < 3:int64, columns={'brand': brand}) + SCAN(table=tpch.PART, columns={'brand': p_brand, 'retail_price': p_retailprice, 'size': p_size}) diff --git a/tests/test_plan_refsols/correl_13.txt 
b/tests/test_plan_refsols/correl_13.txt new file mode 100644 index 00000000..bc779fbd --- /dev/null +++ b/tests/test_plan_refsols/correl_13.txt @@ -0,0 +1,18 @@ +ROOT(columns=[('n', n)], orderings=[]) + PROJECT(columns={'n': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': COUNT()}) + FILTER(condition=DEFAULT_TO(agg_1, 0:int64) > 0:int64, columns={'account_balance': account_balance}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'agg_1': t1.agg_1}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'key': t0.key}) + FILTER(condition=nation_key <= 3:int64, columns={'account_balance': account_balance, 'key': key}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={}) + JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.PART, columns={'key': p_partkey}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_1': COUNT()}) + FILTER(condition=True:bool, columns={'supplier_key': supplier_key}) + JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr2') + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) + FILTER(condition=STARTSWITH(container, 'SM':string) & retail_price < corr2.supplycost * 1.5:float64, columns={'key': key}) + SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'retail_price': p_retailprice}) diff --git a/tests/test_plan_refsols/correl_14.txt b/tests/test_plan_refsols/correl_14.txt new file mode 100644 index 00000000..909904c1 --- /dev/null +++ b/tests/test_plan_refsols/correl_14.txt @@ 
-0,0 +1,18 @@ +ROOT(columns=[('n', n)], orderings=[]) + PROJECT(columns={'n': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': COUNT()}) + FILTER(condition=True:bool, columns={'account_balance': account_balance}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['semi'], columns={'account_balance': t0.account_balance}, correl_name='corr3') + PROJECT(columns={'account_balance': account_balance, 'avg_price': agg_0, 'key': key}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'agg_0': t1.agg_0, 'key': t0.key}) + FILTER(condition=nation_key == 19:int64, columns={'account_balance': account_balance, 'key': key}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': AVG(retail_price)}) + JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'retail_price': t1.retail_price, 'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.PART, columns={'key': p_partkey, 'retail_price': p_retailprice}) + FILTER(condition=True:bool, columns={'supplier_key': supplier_key}) + JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr2') + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) + FILTER(condition=container == 'LG DRUM':string & retail_price < corr2.supplycost * 1.5:float64 & retail_price < corr3.avg_price, columns={'key': key}) + SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'retail_price': p_retailprice}) diff --git a/tests/test_plan_refsols/correl_15.txt b/tests/test_plan_refsols/correl_15.txt new file mode 100644 index 00000000..0dc3f6a6 --- /dev/null +++ b/tests/test_plan_refsols/correl_15.txt @@ -0,0 +1,22 @@ 
+ROOT(columns=[('n', n)], orderings=[]) + PROJECT(columns={'n': agg_1}) + JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1}, correl_name='corr4') + PROJECT(columns={'avg_price': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'retail_price': p_retailprice}) + AGGREGATE(keys={}, aggregations={'agg_1': COUNT()}) + FILTER(condition=True:bool, columns={'account_balance': account_balance}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['semi'], columns={'account_balance': t0.account_balance}, correl_name='corr3') + PROJECT(columns={'account_balance': account_balance, 'avg_price': agg_0, 'key': key}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'agg_0': t1.agg_0, 'key': t0.key}) + FILTER(condition=nation_key == 19:int64, columns={'account_balance': account_balance, 'key': key}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': AVG(retail_price)}) + JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'retail_price': t1.retail_price, 'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.PART, columns={'key': p_partkey, 'retail_price': p_retailprice}) + FILTER(condition=True:bool, columns={'supplier_key': supplier_key}) + JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr2') + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) + FILTER(condition=container == 'LG DRUM':string & retail_price < corr2.supplycost * 1.5:float64 & retail_price < corr3.avg_price & retail_price < corr4.avg_price * 0.85:float64, columns={'key': key}) + SCAN(table=tpch.PART, 
columns={'container': p_container, 'key': p_partkey, 'retail_price': p_retailprice}) diff --git a/tests/test_plan_refsols/correl_16.txt b/tests/test_plan_refsols/correl_16.txt new file mode 100644 index 00000000..f8c2e4eb --- /dev/null +++ b/tests/test_plan_refsols/correl_16.txt @@ -0,0 +1,14 @@ +ROOT(columns=[('n', n)], orderings=[]) + PROJECT(columns={'n': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': COUNT()}) + FILTER(condition=True:bool, columns={'account_balance': account_balance}) + JOIN(conditions=[t0.nation_key == t1.key], types=['semi'], columns={'account_balance': t0.account_balance}, correl_name='corr7') + PROJECT(columns={'account_balance': account_balance, 'nation_key': nation_key, 'tile': PERCENTILE(args=[], partition=[], order=[(account_balance):asc_last, (key):asc_last], n_buckets=10000)}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) + FILTER(condition=PERCENTILE(args=[], partition=[], order=[(acctbal):asc_last, (key_5):asc_last], n_buckets=10000) == corr7.tile & rname == 'EUROPE':string, columns={'key': key}) + JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'acctbal': t1.acctbal, 'key': t0.key, 'key_5': t1.key, 'rname': t0.rname}) + PROJECT(columns={'key': key, 'rname': name_3}) + JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'key': t0.key, 'name_3': t1.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'key': c_custkey, 'nation_key': c_nationkey}) diff --git a/tests/test_plan_refsols/correl_17.txt b/tests/test_plan_refsols/correl_17.txt new file mode 100644 index 00000000..4e532f3c --- /dev/null +++ b/tests/test_plan_refsols/correl_17.txt @@ -0,0 +1,9 @@ +ROOT(columns=[('fullname', fullname)], orderings=[(ordering_0):asc_first]) + PROJECT(columns={'fullname': 
fullname, 'ordering_0': fullname}) + PROJECT(columns={'fullname': fname}) + FILTER(condition=True:bool, columns={'fname': fname}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'fname': t1.fname}, correl_name='corr1') + PROJECT(columns={'lname': LOWER(name), 'region_key': region_key}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) + PROJECT(columns={'fname': JOIN_STRINGS('-':string, LOWER(name), corr1.lname), 'key': key}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/correl_18.txt b/tests/test_plan_refsols/correl_18.txt new file mode 100644 index 00000000..ab36ffd4 --- /dev/null +++ b/tests/test_plan_refsols/correl_18.txt @@ -0,0 +1,14 @@ +ROOT(columns=[('n', n)], orderings=[]) + PROJECT(columns={'n': DEFAULT_TO(agg_0, 0:int64)}) + AGGREGATE(keys={}, aggregations={'agg_0': SUM(n_above_avg)}) + PROJECT(columns={'n_above_avg': DEFAULT_TO(agg_2, 0:int64)}) + JOIN(conditions=[t0.customer_key == t1.customer_key & t0.order_date == t1.order_date], types=['left'], columns={'agg_2': t1.agg_2}, correl_name='corr1') + PROJECT(columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': DEFAULT_TO(agg_1, 0:int64)}) + FILTER(condition=DEFAULT_TO(agg_0, 0:int64) > 1:int64, columns={'agg_1': agg_1, 'customer_key': customer_key, 'order_date': order_date}) + AGGREGATE(keys={'customer_key': customer_key, 'order_date': order_date}, aggregations={'agg_0': COUNT(), 'agg_1': SUM(total_price)}) + FILTER(condition=YEAR(order_date) == 1993:int64, columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': total_price}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'order_date': o_orderdate, 'total_price': o_totalprice}) + AGGREGATE(keys={'customer_key': customer_key, 'order_date': order_date}, aggregations={'agg_2': COUNT()}) + FILTER(condition=total_price >= 0.5:float64 * corr1.total_price, columns={'customer_key': 
customer_key, 'order_date': order_date}) + FILTER(condition=YEAR(order_date) == 1993:int64, columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': total_price}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'order_date': o_orderdate, 'total_price': o_totalprice}) diff --git a/tests/test_plan_refsols/correl_19.txt b/tests/test_plan_refsols/correl_19.txt new file mode 100644 index 00000000..a273084b --- /dev/null +++ b/tests/test_plan_refsols/correl_19.txt @@ -0,0 +1,12 @@ +ROOT(columns=[('name', name_7), ('n_super_cust', n_super_cust)], orderings=[(ordering_1):desc_last]) + PROJECT(columns={'n_super_cust': n_super_cust, 'name_7': name_3, 'ordering_1': ordering_1}) + LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'n_super_cust': n_super_cust, 'name_3': name_3, 'ordering_1': ordering_1}, orderings=[(ordering_1):desc_last]) + PROJECT(columns={'n_super_cust': n_super_cust, 'name_3': name_3, 'ordering_1': n_super_cust}) + PROJECT(columns={'n_super_cust': DEFAULT_TO(agg_0, 0:int64), 'name_3': name_3}) + JOIN(conditions=[t0.key_2 == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'name_3': t0.name_3}, correl_name='corr4') + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'account_balance': t0.account_balance, 'key_2': t1.key, 'name_3': t1.name}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=acctbal > corr4.account_balance, columns={'nation_key': nation_key}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': c_nationkey}) diff --git a/tests/test_plan_refsols/correl_2.txt b/tests/test_plan_refsols/correl_2.txt new file mode 100644 index 00000000..529b06fd --- /dev/null +++ b/tests/test_plan_refsols/correl_2.txt @@ -0,0 +1,12 @@ +ROOT(columns=[('name', 
name_7), ('n_selected_custs', n_selected_custs)], orderings=[(ordering_1):asc_first]) + PROJECT(columns={'n_selected_custs': n_selected_custs, 'name_7': name_6, 'ordering_1': ordering_1}) + PROJECT(columns={'n_selected_custs': n_selected_custs, 'name_6': name_6, 'ordering_1': name_6}) + PROJECT(columns={'n_selected_custs': DEFAULT_TO(agg_0, 0:int64), 'name_6': name_3}) + JOIN(conditions=[t0.key_2 == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'name_3': t0.name_3}, correl_name='corr4') + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name, 'name_3': t1.name}) + FILTER(condition=NOT(STARTSWITH(name, 'A':string)), columns={'key': key, 'name': name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=SLICE(comment, None:unknown, 1:int64, None:unknown) == LOWER(SLICE(corr4.name, None:unknown, 1:int64, None:unknown)), columns={'nation_key': nation_key}) + SCAN(table=tpch.CUSTOMER, columns={'comment': c_comment, 'nation_key': c_nationkey}) diff --git a/tests/test_plan_refsols/correl_20.txt b/tests/test_plan_refsols/correl_20.txt new file mode 100644 index 00000000..c1d388c4 --- /dev/null +++ b/tests/test_plan_refsols/correl_20.txt @@ -0,0 +1,17 @@ +ROOT(columns=[('n', n)], orderings=[]) + PROJECT(columns={'n': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': COUNT()}) + FILTER(condition=domestic, columns={'account_balance': account_balance}) + JOIN(conditions=[t0.nation_key_11 == t1.key], types=['left'], columns={'account_balance': t0.account_balance, 'domestic': t1.domestic}, correl_name='corr13') + JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'account_balance': t1.account_balance, 'name': t0.name, 'nation_key_11': t1.nation_key}) + JOIN(conditions=[t0.key_5 == 
t1.order_key], types=['inner'], columns={'name': t0.name, 'supplier_key': t1.supplier_key}) + FILTER(condition=YEAR(order_date) == 1998:int64 & MONTH(order_date) == 6:int64, columns={'key_5': key_5, 'name': name}) + JOIN(conditions=[t0.key_2 == t1.customer_key], types=['inner'], columns={'key_5': t1.key, 'name': t0.name, 'order_date': t1.order_date}) + JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate}) + SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'supplier_key': l_suppkey}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) + PROJECT(columns={'domestic': name == corr13.name, 'key': key}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) diff --git a/tests/test_plan_refsols/correl_3.txt b/tests/test_plan_refsols/correl_3.txt new file mode 100644 index 00000000..57d2dfdf --- /dev/null +++ b/tests/test_plan_refsols/correl_3.txt @@ -0,0 +1,11 @@ +ROOT(columns=[('name', name), ('n_nations', n_nations)], orderings=[(ordering_1):asc_first]) + PROJECT(columns={'n_nations': n_nations, 'name': name, 'ordering_1': name}) + PROJECT(columns={'n_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name}) + JOIN(conditions=[t0.key == t1.region_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr1') + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + AGGREGATE(keys={'region_key': region_key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=True:bool, columns={'region_key': region_key}) + JOIN(conditions=[t0.key == t1.nation_key], types=['semi'], columns={'region_key': t0.region_key}) + SCAN(table=tpch.NATION, 
columns={'key': n_nationkey, 'region_key': n_regionkey}) + FILTER(condition=SLICE(comment, None:unknown, 2:int64, None:unknown) == LOWER(SLICE(corr1.name, None:unknown, 2:int64, None:unknown)), columns={'nation_key': nation_key}) + SCAN(table=tpch.CUSTOMER, columns={'comment': c_comment, 'nation_key': c_nationkey}) diff --git a/tests/test_plan_refsols/correl_4.txt b/tests/test_plan_refsols/correl_4.txt new file mode 100644 index 00000000..38feea33 --- /dev/null +++ b/tests/test_plan_refsols/correl_4.txt @@ -0,0 +1,11 @@ +ROOT(columns=[('name', name)], orderings=[(ordering_1):asc_first]) + PROJECT(columns={'name': name, 'ordering_1': name}) + FILTER(condition=True:bool, columns={'name': name}) + JOIN(conditions=[t0.key == t1.nation_key], types=['anti'], columns={'name': t0.name}, correl_name='corr1') + JOIN(conditions=[True:bool], types=['inner'], columns={'key': t1.key, 'name': t1.name, 'smallest_bal': t0.smallest_bal}) + PROJECT(columns={'smallest_bal': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': MIN(acctbal)}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + FILTER(condition=acctbal <= corr1.smallest_bal + 5.0:float64, columns={'nation_key': nation_key}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': c_nationkey}) diff --git a/tests/test_plan_refsols/correl_5.txt b/tests/test_plan_refsols/correl_5.txt new file mode 100644 index 00000000..73d793f0 --- /dev/null +++ b/tests/test_plan_refsols/correl_5.txt @@ -0,0 +1,13 @@ +ROOT(columns=[('name', name)], orderings=[(ordering_1):asc_first]) + PROJECT(columns={'name': name, 'ordering_1': name}) + FILTER(condition=True:bool, columns={'name': name}) + JOIN(conditions=[t0.key == t1.region_key], types=['semi'], columns={'name': t0.name}, correl_name='corr4') + JOIN(conditions=[True:bool], types=['inner'], columns={'key': t1.key, 'name': t1.name, 'smallest_bal': t0.smallest_bal}) + 
PROJECT(columns={'smallest_bal': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': MIN(account_balance)}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + FILTER(condition=account_balance <= corr4.smallest_bal + 4.0:float64, columns={'region_key': region_key}) + JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'account_balance': t1.account_balance, 'region_key': t0.region_key}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'nation_key': s_nationkey}) diff --git a/tests/test_plan_refsols/correl_6.txt b/tests/test_plan_refsols/correl_6.txt new file mode 100644 index 00000000..0a85a6fa --- /dev/null +++ b/tests/test_plan_refsols/correl_6.txt @@ -0,0 +1,8 @@ +ROOT(columns=[('name', name), ('n_prefix_nations', n_prefix_nations)], orderings=[]) + PROJECT(columns={'n_prefix_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name}) + FILTER(condition=True:bool, columns={'agg_0': agg_0, 'name': name}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr1') + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + AGGREGATE(keys={'region_key': region_key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'region_key': region_key}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) diff --git a/tests/test_plan_refsols/correl_7.txt b/tests/test_plan_refsols/correl_7.txt new file mode 100644 index 00000000..94a129af --- /dev/null +++ b/tests/test_plan_refsols/correl_7.txt @@ -0,0 +1,7 @@ +ROOT(columns=[('name', name), ('n_prefix_nations', n_prefix_nations)], orderings=[]) + PROJECT(columns={'n_prefix_nations': DEFAULT_TO(NULL_2, 
0:int64), 'name': name}) + FILTER(condition=True:bool, columns={'NULL_2': NULL_2, 'name': name}) + JOIN(conditions=[t0.key == t1.region_key], types=['anti'], columns={'NULL_2': None:unknown, 'name': t0.name}, correl_name='corr1') + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'region_key': region_key}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) diff --git a/tests/test_plan_refsols/correl_8.txt b/tests/test_plan_refsols/correl_8.txt new file mode 100644 index 00000000..87bcc66e --- /dev/null +++ b/tests/test_plan_refsols/correl_8.txt @@ -0,0 +1,7 @@ +ROOT(columns=[('name', name), ('rname', rname)], orderings=[(ordering_0):asc_first]) + PROJECT(columns={'name': name, 'ordering_0': name, 'rname': rname}) + PROJECT(columns={'name': name, 'rname': name_4}) + JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'name': t0.name, 'name_4': t1.name}, correl_name='corr1') + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) + FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'key': key, 'name': name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/correl_9.txt b/tests/test_plan_refsols/correl_9.txt new file mode 100644 index 00000000..6a7a6c13 --- /dev/null +++ b/tests/test_plan_refsols/correl_9.txt @@ -0,0 +1,8 @@ +ROOT(columns=[('name', name), ('rname', rname)], orderings=[(ordering_0):asc_first]) + PROJECT(columns={'name': name, 'ordering_0': name, 'rname': rname}) + PROJECT(columns={'name': name, 'rname': name_4}) + FILTER(condition=True:bool, columns={'name': name, 'name_4': name_4}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'name': t0.name, 'name_4': t1.name}, 
correl_name='corr1') + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) + FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'key': key, 'name': name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/tpch_q21.txt b/tests/test_plan_refsols/tpch_q21.txt new file mode 100644 index 00000000..fdf99273 --- /dev/null +++ b/tests/test_plan_refsols/tpch_q21.txt @@ -0,0 +1,21 @@ +ROOT(columns=[('S_NAME', S_NAME), ('NUMWAIT', NUMWAIT)], orderings=[(ordering_1):desc_last, (ordering_2):asc_first]) + LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'NUMWAIT': NUMWAIT, 'S_NAME': S_NAME, 'ordering_1': ordering_1, 'ordering_2': ordering_2}, orderings=[(ordering_1):desc_last, (ordering_2):asc_first]) + PROJECT(columns={'NUMWAIT': NUMWAIT, 'S_NAME': S_NAME, 'ordering_1': NUMWAIT, 'ordering_2': S_NAME}) + PROJECT(columns={'NUMWAIT': DEFAULT_TO(agg_0, 0:int64), 'S_NAME': name}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}) + FILTER(condition=name_3 == 'SAUDI ARABIA':string, columns={'key': key, 'name': name}) + JOIN(conditions=[t0.nation_key == t1.key], types=['left'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=order_status == 'F':string & True:bool & True:bool, columns={'supplier_key': supplier_key}) + JOIN(conditions=[t0.key == t1.order_key], types=['anti'], columns={'order_status': t0.order_status, 'supplier_key': t0.supplier_key}, correl_name='corr6') + JOIN(conditions=[t0.key == t1.order_key], types=['semi'], columns={'key': t0.key, 'order_status': 
t0.order_status, 'supplier_key': t0.supplier_key}, correl_name='corr5') + JOIN(conditions=[t0.order_key == t1.key], types=['inner'], columns={'key': t1.key, 'order_status': t1.order_status, 'supplier_key': t0.supplier_key}) + FILTER(condition=receipt_date > commit_date, columns={'order_key': order_key, 'supplier_key': supplier_key}) + SCAN(table=tpch.LINEITEM, columns={'commit_date': l_commitdate, 'order_key': l_orderkey, 'receipt_date': l_receiptdate, 'supplier_key': l_suppkey}) + SCAN(table=tpch.ORDERS, columns={'key': o_orderkey, 'order_status': o_orderstatus}) + FILTER(condition=supplier_key != corr5.supplier_key, columns={'order_key': order_key}) + SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'supplier_key': l_suppkey}) + FILTER(condition=supplier_key != corr6.supplier_key & receipt_date > commit_date, columns={'order_key': order_key}) + SCAN(table=tpch.LINEITEM, columns={'commit_date': l_commitdate, 'order_key': l_orderkey, 'receipt_date': l_receiptdate, 'supplier_key': l_suppkey}) diff --git a/tests/test_plan_refsols/tpch_q22.txt b/tests/test_plan_refsols/tpch_q22.txt new file mode 100644 index 00000000..1bd521eb --- /dev/null +++ b/tests/test_plan_refsols/tpch_q22.txt @@ -0,0 +1,18 @@ +ROOT(columns=[('CNTRY_CODE', CNTRY_CODE), ('NUM_CUSTS', NUM_CUSTS), ('TOTACCTBAL', TOTACCTBAL)], orderings=[]) + PROJECT(columns={'CNTRY_CODE': cntry_code, 'NUM_CUSTS': DEFAULT_TO(agg_1, 0:int64), 'TOTACCTBAL': DEFAULT_TO(agg_2, 0:int64)}) + JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1, 'agg_2': t1.agg_2, 'cntry_code': t1.cntry_code}, correl_name='corr1') + PROJECT(columns={'avg_balance': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(acctbal)}) + FILTER(condition=acctbal > 0.0:float64, columns={'acctbal': acctbal}) + FILTER(condition=ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17']:array[unknown]) & True:bool, columns={'acctbal': acctbal}) + JOIN(conditions=[t0.key == t1.customer_key], types=['anti'], 
columns={'acctbal': t0.acctbal, 'cntry_code': t0.cntry_code}) + PROJECT(columns={'acctbal': acctbal, 'cntry_code': SLICE(phone, None:unknown, 2:int64, None:unknown), 'key': key}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'key': c_custkey, 'phone': c_phone}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey}) + AGGREGATE(keys={'cntry_code': cntry_code}, aggregations={'agg_1': COUNT(), 'agg_2': SUM(acctbal)}) + FILTER(condition=acctbal > corr1.avg_balance, columns={'acctbal': acctbal, 'cntry_code': cntry_code}) + FILTER(condition=ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17']:array[unknown]) & True:bool, columns={'acctbal': acctbal, 'cntry_code': cntry_code}) + JOIN(conditions=[t0.key == t1.customer_key], types=['anti'], columns={'acctbal': t0.acctbal, 'cntry_code': t0.cntry_code}) + PROJECT(columns={'acctbal': acctbal, 'cntry_code': SLICE(phone, None:unknown, 2:int64, None:unknown), 'key': key}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'key': c_custkey, 'phone': c_phone}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey}) diff --git a/tests/test_plan_refsols/tpch_q5.txt b/tests/test_plan_refsols/tpch_q5.txt new file mode 100644 index 00000000..fbd35207 --- /dev/null +++ b/tests/test_plan_refsols/tpch_q5.txt @@ -0,0 +1,21 @@ +ROOT(columns=[('N_NAME', N_NAME), ('REVENUE', REVENUE)], orderings=[(ordering_1):desc_last]) + PROJECT(columns={'N_NAME': N_NAME, 'REVENUE': REVENUE, 'ordering_1': REVENUE}) + PROJECT(columns={'N_NAME': name, 'REVENUE': DEFAULT_TO(agg_0, 0:int64)}) + JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr10') + FILTER(condition=name_3 == 'ASIA':string, columns={'key': key, 'name': name}) + JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + 
SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': SUM(value)}) + PROJECT(columns={'nation_key': nation_key, 'value': extended_price * 1:int64 - discount}) + FILTER(condition=name_9 == corr10.name, columns={'discount': discount, 'extended_price': extended_price, 'nation_key': nation_key}) + JOIN(conditions=[t0.supplier_key == t1.key], types=['left'], columns={'discount': t0.discount, 'extended_price': t0.extended_price, 'name_9': t1.name_9, 'nation_key': t0.nation_key}) + JOIN(conditions=[t0.key_5 == t1.order_key], types=['inner'], columns={'discount': t1.discount, 'extended_price': t1.extended_price, 'nation_key': t0.nation_key, 'supplier_key': t1.supplier_key}) + FILTER(condition=order_date >= datetime.date(1994, 1, 1):date & order_date < datetime.date(1995, 1, 1):date, columns={'key_5': key_5, 'nation_key': nation_key}) + JOIN(conditions=[t0.key == t1.customer_key], types=['inner'], columns={'key_5': t1.key, 'nation_key': t0.nation_key, 'order_date': t1.order_date}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate}) + SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'order_key': l_orderkey, 'supplier_key': l_suppkey}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_9': t1.name}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) diff --git a/tests/tpch_outputs.py b/tests/tpch_outputs.py index c875d630..80948acb 100644 --- a/tests/tpch_outputs.py +++ b/tests/tpch_outputs.py @@ -692,7 +692,7 @@ def tpch_q21_output() -> pd.DataFrame: Expected output for TPC-H query 21. Note: This is truncated to the first 10 rows. 
""" - columns = ["S_NAME", "NUM_WAIT"] + columns = ["S_NAME", "NUMWAIT"] data = [ ("Supplier#000002829", 20), ("Supplier#000005808", 18), From 0a28b20b71930ee04ae24f19fcbc8ad9da0bad31 Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Mon, 17 Feb 2025 15:28:19 -0500 Subject: [PATCH 3/7] Support uses of BACK that cause correlated references: SQLGlot conversion (#234) Child PR of #269 that is part of addressing #141. Adds the SQLGlot conversion step. With these changes, all SQLGlot correlation queries are functional except for: - TPC-H queries: 5 & 22 - Correl queries: 1, 2, 3, 6, 8, 9, 15, 17, 18, 19, 20 These remainders need to be handled via the decorrelation cases in the comments of #141. Specifically: - Singular: Correl 8 - Aggregation: TPC-H Q5, TPC-H Q22, Correl 1, Correl 2, Correl 3, Correl 18, Correl 19 - Semi-Singular: Correl 9, Correl 17, Correl 20 - Semi-Aggregation: Correl 6 --- .../sqlglot_relational_expression_visitor.py | 15 ++++--- pydough/sqlglot/sqlglot_relational_visitor.py | 41 +++++++++++++++---- 2 files changed, 41 insertions(+), 15 deletions(-) diff --git a/pydough/sqlglot/sqlglot_relational_expression_visitor.py b/pydough/sqlglot/sqlglot_relational_expression_visitor.py index 017ebe60..e15e5007 100644 --- a/pydough/sqlglot/sqlglot_relational_expression_visitor.py +++ b/pydough/sqlglot/sqlglot_relational_expression_visitor.py @@ -33,13 +33,17 @@ class SQLGlotRelationalExpressionVisitor(RelationalExpressionVisitor): """ def __init__( - self, dialect: SQLGlotDialect, bindings: SqlGlotTransformBindings + self, + dialect: SQLGlotDialect, + bindings: SqlGlotTransformBindings, + correlated_names: dict[str, str], ) -> None: # Keep a stack of SQLGlot expressions so we can build up # intermediate results. 
self._stack: list[SQLGlotExpression] = [] self._dialect: SQLGlotDialect = dialect self._bindings: SqlGlotTransformBindings = bindings + self._correlated_names: dict[str, str] = correlated_names def reset(self) -> None: """ @@ -137,20 +141,19 @@ def visit_literal_expression(self, literal_expression: LiteralExpression) -> Non def visit_correlated_reference( self, correlated_reference: CorrelatedReference ) -> None: - raise NotImplementedError("TODO") + full_name: str = f"{self._correlated_names[correlated_reference.correl_name]}.{correlated_reference.name}" + self._stack.append(Identifier(this=full_name)) @staticmethod - def generate_column_reference_identifier( + def make_sqlglot_column( column_reference: ColumnReference, ) -> Identifier: """ Generate an identifier for a column reference. This is split into a separate static method to ensure consistency across multiple visitors. - Args: column_reference (ColumnReference): The column reference to generate an identifier for. - Returns: Identifier: The output identifier. 
""" @@ -161,7 +164,7 @@ def generate_column_reference_identifier( return Identifier(this=full_name) def visit_column_reference(self, column_reference: ColumnReference) -> None: - self._stack.append(self.generate_column_reference_identifier(column_reference)) + self._stack.append(self.make_sqlglot_column(column_reference)) def relational_to_sqlglot( self, expr: RelationalExpression, output_name: str | None = None diff --git a/pydough/sqlglot/sqlglot_relational_visitor.py b/pydough/sqlglot/sqlglot_relational_visitor.py index ae619893..0b53f04d 100644 --- a/pydough/sqlglot/sqlglot_relational_visitor.py +++ b/pydough/sqlglot/sqlglot_relational_visitor.py @@ -9,6 +9,7 @@ from sqlglot.dialects import Dialect as SQLGlotDialect from sqlglot.expressions import Alias as SQLGlotAlias +from sqlglot.expressions import Column as SQLGlotColumn from sqlglot.expressions import Expression as SQLGlotExpression from sqlglot.expressions import Identifier, Select, Subquery, values from sqlglot.expressions import Literal as SQLGlotLiteral @@ -20,6 +21,7 @@ ColumnReference, ColumnReferenceInputNameModifier, ColumnReferenceInputNameRemover, + CorrelatedReference, EmptySingleton, ExpressionSortInfo, Filter, @@ -54,8 +56,11 @@ def __init__( # Keep a stack of SQLGlot expressions so we can build up # intermediate results. 
self._stack: list[Select] = [] + self._correlated_names: dict[str, str] = {} self._expr_visitor: SQLGlotRelationalExpressionVisitor = ( - SQLGlotRelationalExpressionVisitor(dialect, bindings) + SQLGlotRelationalExpressionVisitor( + dialect, bindings, self._correlated_names + ) ) self._alias_modifier: ColumnReferenceInputNameModifier = ( ColumnReferenceInputNameModifier() @@ -94,7 +99,7 @@ def _is_mergeable_column(expr: SQLGlotExpression) -> bool: if isinstance(expr, SQLGlotAlias): return SQLGlotRelationalVisitor._is_mergeable_column(expr.this) else: - return isinstance(expr, (SQLGlotLiteral, Identifier)) + return isinstance(expr, (SQLGlotLiteral, Identifier, SQLGlotColumn)) @staticmethod def _try_merge_columns( @@ -154,11 +159,22 @@ def _try_merge_columns( # If the new column is a literal, we can just add it to the old # columns. modified_old_columns.append(set_glot_alias(new_column, new_name)) - else: + elif isinstance(new_column, Identifier): expr = set_glot_alias(old_column_map[new_column.this], new_name) modified_old_columns.append(expr) if isinstance(expr, Identifier): seen_cols.add(expr) + elif isinstance(new_column, SQLGlotColumn): + expr = set_glot_alias( + old_column_map[new_column.this.this], new_name + ) + modified_old_columns.append(expr) + if isinstance(expr, Identifier): + seen_cols.add(expr) + else: + raise ValueError( + f"Unsupported expression type for column merging: {new_column.__class__.__name__}" + ) # Check that there are no missing dependencies in the old columns. 
if old_column_deps - seen_cols: return new_columns, old_columns @@ -301,7 +317,7 @@ def contains_window(self, exp: RelationalExpression) -> bool: match exp: case CallExpression(): return any(self.contains_window(arg) for arg in exp.inputs) - case ColumnReference() | LiteralExpression(): + case ColumnReference() | LiteralExpression() | CorrelatedReference(): return False case WindowCallExpression(): return True @@ -327,6 +343,12 @@ def visit_scan(self, scan: Scan) -> None: self._stack.append(query) def visit_join(self, join: Join) -> None: + alias_map: dict[str | None, str] = {} + if join.correl_name is not None: + input_name = join.default_input_aliases[0] + alias = self._generate_table_alias() + alias_map[input_name] = alias + self._correlated_names[join.correl_name] = alias self.visit_inputs(join) inputs: list[Select] = [self._stack.pop() for _ in range(len(join.inputs))] inputs.reverse() @@ -337,11 +359,12 @@ def visit_join(self, join: Join) -> None: seen_names[column] += 1 # Only keep duplicate names. kept_names = {key for key, value in seen_names.items() if value > 1} - alias_map = { - join.default_input_aliases[i]: self._generate_table_alias() - for i in range(len(join.inputs)) - if kept_names.intersection(join.inputs[i].columns.keys()) - } + for i in range(len(join.inputs)): + input_name = join.default_input_aliases[i] + if input_name not in alias_map and kept_names.intersection( + join.inputs[i].columns.keys() + ): + alias_map[input_name] = self._generate_table_alias() self._alias_remover.set_kept_names(kept_names) self._alias_modifier.set_map(alias_map) columns = { From b304c2977b41727c2100d379c4b223ce5ec939f0 Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Tue, 18 Feb 2025 15:39:02 -0500 Subject: [PATCH 4/7] Support uses of BACK that cause correlated references: setup decorrelation handling (#251) Child PR of #269 that is part of addressing #141. 
Adds the handling of de-correlation as a post-processing step of the hybrid tree before relational conversion starts. Handles the singular & aggregate cases as described in the issue, and handles the semi-singular and semi-aggregate cases in the same manner as their non-semi equivalents (for now), saving the optimized variants for another day. These remaining still need to be addressed in the next follow up: - TPC-H Q22: error during relational conversion that likely means a bug happened during hybrid conversion or hybrid decorrelation - Correl 15: error during decorrelation causing over-excessive column pruning - Correl 18: tbd bug in decorrelation - Correl 19: bug in name resolution handling that causes the wrong term to be used when certain property names collide --- pydough/conversion/hybrid_decorrelater.py | 330 +++++++++++++++++++++ pydough/conversion/hybrid_tree.py | 38 ++- pydough/conversion/relational_converter.py | 10 +- pydough/pydough_operators/base_operator.py | 12 + pydough/types/struct_type.py | 2 +- tests/test_plan_refsols/correl_1.txt | 10 +- tests/test_plan_refsols/correl_15.txt | 25 +- tests/test_plan_refsols/correl_17.txt | 12 +- tests/test_plan_refsols/correl_18.txt | 23 +- tests/test_plan_refsols/correl_19.txt | 20 +- tests/test_plan_refsols/correl_2.txt | 25 +- tests/test_plan_refsols/correl_20.txt | 33 ++- tests/test_plan_refsols/correl_3.txt | 14 +- tests/test_plan_refsols/correl_6.txt | 10 +- tests/test_plan_refsols/correl_8.txt | 12 +- tests/test_plan_refsols/correl_9.txt | 14 +- tests/test_plan_refsols/tpch_q5.txt | 25 +- 17 files changed, 500 insertions(+), 115 deletions(-) create mode 100644 pydough/conversion/hybrid_decorrelater.py diff --git a/pydough/conversion/hybrid_decorrelater.py b/pydough/conversion/hybrid_decorrelater.py new file mode 100644 index 00000000..72f6aad1 --- /dev/null +++ b/pydough/conversion/hybrid_decorrelater.py @@ -0,0 +1,330 @@ +""" +Logic for applying de-correlation to hybrid trees before relational 
conversion +if the correlate is not a semi/anti join. +""" + +__all__ = ["run_hybrid_decorrelation"] + + +import copy + +from .hybrid_tree import ( + ConnectionType, + HybridBackRefExpr, + HybridCalc, + HybridChildRefExpr, + HybridColumnExpr, + HybridConnection, + HybridCorrelExpr, + HybridExpr, + HybridFilter, + HybridFunctionExpr, + HybridLiteralExpr, + HybridPartition, + HybridRefExpr, + HybridTree, + HybridWindowExpr, +) + + +class Decorrelater: + """ + Class that encapsulates the logic used for de-correlation of hybrid trees. + """ + + def make_decorrelate_parent( + self, hybrid: HybridTree, child_idx: int, required_steps: int + ) -> HybridTree: + """ + Creates a snapshot of the ancestry of the hybrid tree that contains + a correlated child, without any of its children, its descendants, or + any pipeline operators that do not need to be there. + + Args: + `hybrid`: The hybrid tree to create a snapshot of in order to aid + in the de-correlation of a correlated child. + `child_idx`: The index of the correlated child of hybrid that the + snapshot is being created to aid in the de-correlation of. + `required_steps`: The index of the last pipeline operator that + needs to be included in the snapshot in order for the child to be + derivable. + + Returns: + A snapshot of `hybrid` and its ancestry in the hybrid tree, without + any of its children or pipeline operators that occur during + or after the derivation of the correlated child, or without any of + its descendants. + """ + if isinstance(hybrid.pipeline[0], HybridPartition) and child_idx == 0: + # Special case: if the correlated child is the data argument of a + # partition operation, then the parent to snapshot is actually the + # parent of the level containing the partition operation. In this + # case, all of the parent's children & pipeline operators should be + # included in the snapshot.
+ assert hybrid.parent is not None + return self.make_decorrelate_parent( + hybrid.parent, len(hybrid.parent.children), len(hybrid.pipeline) + ) + # Temporarily detach the successor of the current level, then create a + # deep copy of the current level (which will include its ancestors), + # then reattach the successor back to the original. This ensures that + # the descendants of the current level are not included when providing + # the parent to the correlated child as its new ancestor. + successor: HybridTree | None = hybrid.successor + hybrid._successor = None + new_hybrid: HybridTree = copy.deepcopy(hybrid) + hybrid._successor = successor + # Ensure the new parent only includes the children & pipeline operators + # that it has to. + new_hybrid._children = new_hybrid._children[:child_idx] + new_hybrid._pipeline = new_hybrid._pipeline[: required_steps + 1] + return new_hybrid + + def remove_correl_refs( + self, expr: HybridExpr, parent: HybridTree, child_height: int + ) -> HybridExpr: + """ + Recursively & destructively removes correlated references within a + hybrid expression if they point to a specific correlated ancestor + hybrid tree, and replaces them with corresponding BACK references. + + Args: + `expr`: The hybrid expression to remove correlated references from. + `parent`: The correlated ancestor hybrid tree that the correlated + references should point to when they are targeted for removal. + `child_height`: The height of the correlated child within the + hybrid tree that the correlated reference points to. This is + the number of BACK indices to shift by when replacing the + correlated reference with a BACK reference. + + Returns: + The hybrid expression with all correlated references to `parent` + replaced with corresponding BACK references. The replacement also + happens in-place. + """ + match expr: + case HybridCorrelExpr(): + # If the correlated reference points to the parent, then + # replace it with a BACK reference.
Otherwise, recursively + # transform its input expression in case it contains another + # correlated reference. + if expr.hybrid is parent: + result: HybridExpr | None = expr.expr.shift_back(child_height) + assert result is not None + return result + else: + expr.expr = self.remove_correl_refs(expr.expr, parent, child_height) + return expr + case HybridFunctionExpr(): + # For regular functions, recursively transform all of their + # arguments. + for idx, arg in enumerate(expr.args): + expr.args[idx] = self.remove_correl_refs(arg, parent, child_height) + return expr + case HybridWindowExpr(): + # For window functions, recursively transform all of their + # arguments, partition keys, and order keys. + for idx, arg in enumerate(expr.args): + expr.args[idx] = self.remove_correl_refs(arg, parent, child_height) + for idx, arg in enumerate(expr.partition_args): + expr.partition_args[idx] = self.remove_correl_refs( + arg, parent, child_height + ) + for order_arg in expr.order_args: + order_arg.expr = self.remove_correl_refs( + order_arg.expr, parent, child_height + ) + return expr + case ( + HybridBackRefExpr() + | HybridRefExpr() + | HybridChildRefExpr() + | HybridLiteralExpr() + | HybridColumnExpr() + ): + # All other expression types do not require any transformation + # to de-correlate since they cannot contain correlations. + return expr + case _: + raise NotImplementedError( + f"Unsupported expression type: {expr.__class__.__name__}." + ) + + def correl_ref_purge( + self, + level: HybridTree | None, + old_parent: HybridTree, + new_parent: HybridTree, + child_height: int, + ) -> None: + """ + The recursive procedure to remove correlated references from the + expressions of a hybrid tree or any of its ancestors or children if + they refer to a specific correlated ancestor that is being removed. + + Args: + `level`: The current level of the hybrid tree to remove correlated + references from. 
+ `old_parent`: The correlated ancestor hybrid tree that the correlated + references should point to when they are targeted for removal. + `new_parent`: The ancestor of `level` that removal should stop at + because it is the transposed snapshot of `old_parent`, and + therefore it & its ancestors cannot contain any more correlated + references that would be targeted for removal. + `child_height`: The height of the correlated child within the + hybrid tree that the correlated reference points to. This is + the number of BACK indices to shift by when replacing the + correlated reference with a BACK reference. + """ + while level is not None and level is not new_parent: + # First, recursively remove any targeted correlated references from + # the children of the current level. + for child in level.children: + self.correl_ref_purge( + child.subtree, old_parent, new_parent, child_height + ) + # Then, remove any correlated references from the pipeline + # operators of the current level. Usually this just means + # transforming the terms/orderings/unique keys of the operation, + # but specific operation types will require special casing if they + # have additional expressions stored in other fields that need to be + # transformed.
+ for operation in level.pipeline: + for name, expr in operation.terms.items(): + operation.terms[name] = self.remove_correl_refs( + expr, old_parent, child_height + ) + for ordering in operation.orderings: + ordering.expr = self.remove_correl_refs( + ordering.expr, old_parent, child_height + ) + for idx, expr in enumerate(operation.unique_exprs): + operation.unique_exprs[idx] = self.remove_correl_refs( + expr, old_parent, child_height + ) + if isinstance(operation, HybridCalc): + for str, expr in operation.new_expressions.items(): + operation.new_expressions[str] = self.remove_correl_refs( + expr, old_parent, child_height + ) + if isinstance(operation, HybridFilter): + operation.condition = self.remove_correl_refs( + operation.condition, old_parent, child_height + ) + # Repeat the process on the ancestor until either loop guard + # condition is no longer True. + level = level.parent + + def decorrelate_child( + self, + old_parent: HybridTree, + new_parent: HybridTree, + child: HybridConnection, + is_aggregate: bool, + ) -> None: + """ + Runs the logic to de-correlate a child of a hybrid tree that contains + a correlated reference. This involves linking the child to a new parent + as its ancestor, the parent being a snapshot of the original hybrid + tree that contained the correlated child as a child. The transformed + child can now replace correlated references with BACK references that + point to terms in its newly expanded ancestry, and the original hybrid + tree can now join onto this child using its uniqueness keys. + """ + # First, find the height of the child subtree & its top-most level. + child_root: HybridTree = child.subtree + child_height: int = 1 + while child_root.parent is not None: + child_height += 1 + child_root = child_root.parent + # Link the top level of the child subtree to the new parent. + new_parent.add_successor(child_root) + # Replace any correlated references to the original parent with BACK references. 
+ self.correl_ref_purge(child.subtree, old_parent, new_parent, child_height) + # Update the join keys to join on the unique keys of all the ancestors. + new_join_keys: list[tuple[HybridExpr, HybridExpr]] = [] + additional_levels: int = 0 + current_level: HybridTree | None = old_parent + while current_level is not None: + for unique_key in current_level.pipeline[0].unique_exprs: + lhs_key: HybridExpr | None = unique_key.shift_back(additional_levels) + rhs_key: HybridExpr | None = unique_key.shift_back( + additional_levels + child_height + ) + assert lhs_key is not None and rhs_key is not None + new_join_keys.append((lhs_key, rhs_key)) + current_level = current_level.parent + additional_levels += 1 + child.subtree.join_keys = new_join_keys + # If aggregating, do the same with the aggregation keys. + if is_aggregate: + new_agg_keys: list[HybridExpr] = [] + assert child.subtree.join_keys is not None + for _, rhs_key in child.subtree.join_keys: + new_agg_keys.append(rhs_key) + child.subtree.agg_keys = new_agg_keys + + def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: + """ + TODO + """ + # Recursively decorrelate the ancestors of the current level of the + # hybrid tree. + if hybrid.parent is not None: + hybrid._parent = self.decorrelate_hybrid_tree(hybrid.parent) + hybrid._parent._successor = hybrid + # Iterate across all the children and recursively decorrelate them. + for child in hybrid.children: + child.subtree = self.decorrelate_hybrid_tree(child.subtree) + # Iterate across all the children, identify any that are correlated, + # and transform any of the correlated ones that require decorrelation + # due to the type of connection. 
+ for idx, child in enumerate(hybrid.children): + if idx not in hybrid.correlated_children: + continue + new_parent: HybridTree = self.make_decorrelate_parent( + hybrid, idx, hybrid.children[idx].required_steps + ) + match child.connection_type: + case ( + ConnectionType.SINGULAR + | ConnectionType.SINGULAR_ONLY_MATCH + | ConnectionType.AGGREGATION + | ConnectionType.AGGREGATION_ONLY_MATCH + ): + self.decorrelate_child( + hybrid, new_parent, child, child.connection_type.is_aggregation + ) + case ConnectionType.NDISTINCT | ConnectionType.NDISTINCT_ONLY_MATCH: + raise NotImplementedError( + f"PyDough does not yet support correlated references with the {child.connection_type.name} pattern." + ) + case ( + ConnectionType.SEMI + | ConnectionType.ANTI + | ConnectionType.NO_MATCH_SINGULAR + | ConnectionType.NO_MATCH_AGGREGATION + | ConnectionType.NO_MATCH_NDISTINCT + ): + # These patterns do not require decorrelation since they + # are supported via correlated SEMI/ANTI joins. + continue + return hybrid + + +def run_hybrid_decorrelation(hybrid: HybridTree) -> HybridTree: + """ + Invokes the procedure to remove correlated references from a hybrid tree + before relational conversion if those correlated references are invalid + (e.g. not from a semi/anti join). + + Args: + `hybrid`: The hybrid tree to remove correlated references from. + + Returns: + The hybrid tree with all invalid correlated references removed as the + tree structure is re-written to allow them to be replaced with BACK + references. The transformation is also done in-place. 
+ """ + decorr: Decorrelater = Decorrelater() + return decorr.decorrelate_hybrid_tree(hybrid) diff --git a/pydough/conversion/hybrid_tree.py b/pydough/conversion/hybrid_tree.py index 79a03062..956e88f4 100644 --- a/pydough/conversion/hybrid_tree.py +++ b/pydough/conversion/hybrid_tree.py @@ -184,6 +184,8 @@ def apply_renamings(self, renamings: dict[str, str]) -> "HybridExpr": return self def shift_back(self, levels: int) -> HybridExpr | None: + if levels == 0: + return self return HybridBackRefExpr(self.name, levels, self.typ) @@ -895,6 +897,7 @@ def __init__( self._is_connection_root: bool = is_connection_root self._agg_keys: list[HybridExpr] | None = None self._join_keys: list[tuple[HybridExpr, HybridExpr]] | None = None + self._correlated_children: set[int] = set() if isinstance(root_operation, HybridPartition): self._join_keys = [] @@ -935,6 +938,14 @@ def children(self) -> list[HybridConnection]: """ return self._children + @property + def correlated_children(self) -> set[int]: + """ + The set of indices of children that contain correlated references to + the current hybrid tree. + """ + return self._correlated_children + @property def successor(self) -> Optional["HybridTree"]: """ @@ -1584,9 +1595,10 @@ def make_hybrid_correl_expr( # Special case: stepping out of the data argument of PARTITION back # into its ancestor. For example: # TPCH(x=...).PARTITION(data.WHERE(y > BACK(1).x), ...) 
- if len(parent_tree.pipeline) == 1 and isinstance( + partition_edge_case: bool = len(parent_tree.pipeline) == 1 and isinstance( parent_tree.pipeline[0], HybridPartition - ): + ) + if partition_edge_case: assert parent_tree.parent is not None # Treat the partition's parent as the conext for the back # to step into, as opposed to the partition itself (so the back @@ -1594,26 +1606,8 @@ def make_hybrid_correl_expr( self.stack.append(parent_tree.parent) parent_result = self.make_hybrid_correl_expr( back_expr, collection, steps_taken_so_far - ) + ).expr self.stack.pop() - self.stack.append(parent_tree) - # Then, postprocess the output to account for the fact that a - # BACK level got skipped due to the change in subtree. - match parent_result.expr: - case HybridRefExpr(): - parent_result = HybridBackRefExpr( - parent_result.expr.name, 1, parent_result.typ - ) - case HybridBackRefExpr(): - parent_result = HybridBackRefExpr( - parent_result.expr.name, - parent_result.expr.back_idx + 1, - parent_result.typ, - ) - case _: - raise ValueError( - f"Malformed expression for correlated reference: {parent_result}" - ) elif remaining_steps_back == 0: # If there are no more steps back to be made, then the correlated # reference is to a reference from the current context. @@ -1634,6 +1628,8 @@ def make_hybrid_correl_expr( collection, back_expr.term_name, remaining_steps_back ) parent_result = self.make_hybrid_expr(parent_tree, new_expr, {}, False) + if not isinstance(parent_result, HybridCorrelExpr): + parent_tree.correlated_children.add(len(parent_tree.children)) # Restore parent_tree back onto the stack, since evaluating `back_expr` # does not change the program's current placement in the sutbtrees. 
self.stack.append(parent_tree) diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index 494d7fbd..77c10c86 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -44,6 +44,7 @@ ) from pydough.types import BooleanType, Int64Type, UnknownType +from .hybrid_decorrelater import run_hybrid_decorrelation from .hybrid_tree import ( ConnectionType, HybridBackRefExpr, @@ -648,7 +649,7 @@ def translate_partition( Returns: The TranslationOutput payload containing access to the aggregated - child corresponding tot he partition data. + child corresponding to the partition data. """ expressions: dict[HybridExpr, ColumnReference] = {} # Account for the fact that the PARTITION is stepping down a level, @@ -998,10 +999,11 @@ def convert_ast_to_relational( final_terms: set[str] = node.calc_terms node = translator.preprocess_root(node) - # Convert the QDAG node to the hybrid form, then invoke the relational - # conversion procedure. The first element in the returned list is the - # final rel node. + # Convert the QDAG node to the hybrid form, decorrelate it, then invoke + # the relational conversion procedure. The first element in the returned + # list is the final rel node. hybrid: HybridTree = HybridTranslator(configs).make_hybrid_tree(node, None) + run_hybrid_decorrelation(hybrid) renamings: dict[str, str] = hybrid.pipeline[-1].renamings output: TranslationOutput = translator.rel_translation( None, hybrid, len(hybrid.pipeline) - 1 diff --git a/pydough/pydough_operators/base_operator.py b/pydough/pydough_operators/base_operator.py index fe12629e..82ccfacf 100644 --- a/pydough/pydough_operators/base_operator.py +++ b/pydough/pydough_operators/base_operator.py @@ -80,3 +80,15 @@ def to_string(self, arg_strings: list[str]) -> str: Returns: The string representation of the operator called on its arguments. 
""" + + @abstractmethod + def equals(self, other: object) -> bool: + """ + Returns whether this operator is equal to another operator. + """ + + def __eq__(self, other: object) -> bool: + return self.equals(other) + + def __hash__(self) -> int: + return hash(repr(self)) diff --git a/pydough/types/struct_type.py b/pydough/types/struct_type.py index 94b3222a..7b3ef680 100644 --- a/pydough/types/struct_type.py +++ b/pydough/types/struct_type.py @@ -109,7 +109,7 @@ def parse_struct_body( except PyDoughTypeException: pass - # Otherwise, iterate across all commas int he right hand side + # Otherwise, iterate across all commas in the right hand side # that are candidate splitting locations between a PyDough # type and a suffix that is a valid list of fields. if field_type is None: diff --git a/tests/test_plan_refsols/correl_1.txt b/tests/test_plan_refsols/correl_1.txt index c5956c80..bcc6d73d 100644 --- a/tests/test_plan_refsols/correl_1.txt +++ b/tests/test_plan_refsols/correl_1.txt @@ -1,8 +1,10 @@ ROOT(columns=[('name', name), ('n_prefix_nations', n_prefix_nations)], orderings=[(ordering_1):asc_first]) PROJECT(columns={'n_prefix_nations': n_prefix_nations, 'name': name, 'ordering_1': name}) PROJECT(columns={'n_prefix_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name}) - JOIN(conditions=[t0.key == t1.region_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr1') + JOIN(conditions=[t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}) SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - AGGREGATE(keys={'region_key': region_key}, aggregations={'agg_0': COUNT()}) - FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'region_key': region_key}) - SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) + AGGREGATE(keys={'key': key}, aggregations={'agg_0': COUNT()}) + 
FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(name, None:unknown, 1:int64, None:unknown), columns={'key': key}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) diff --git a/tests/test_plan_refsols/correl_15.txt b/tests/test_plan_refsols/correl_15.txt index 0dc3f6a6..960b0c69 100644 --- a/tests/test_plan_refsols/correl_15.txt +++ b/tests/test_plan_refsols/correl_15.txt @@ -1,22 +1,25 @@ ROOT(columns=[('n', n)], orderings=[]) PROJECT(columns={'n': agg_1}) - JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1}, correl_name='corr4') - PROJECT(columns={'avg_price': agg_0}) - AGGREGATE(keys={}, aggregations={'agg_0': AVG(retail_price)}) - SCAN(table=tpch.PART, columns={'retail_price': p_retailprice}) + JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1}) + AGGREGATE(keys={}, aggregations={}) + SCAN(table=tpch.PART, columns={'brand': p_brand}) AGGREGATE(keys={}, aggregations={'agg_1': COUNT()}) FILTER(condition=True:bool, columns={'account_balance': account_balance}) - JOIN(conditions=[t0.key == t1.supplier_key], types=['semi'], columns={'account_balance': t0.account_balance}, correl_name='corr3') - PROJECT(columns={'account_balance': account_balance, 'avg_price': agg_0, 'key': key}) - JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'agg_0': t1.agg_0, 'key': t0.key}) - FILTER(condition=nation_key == 19:int64, columns={'account_balance': account_balance, 'key': key}) - SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['semi'], columns={'account_balance': t0.account_balance}, correl_name='corr4') + 
PROJECT(columns={'account_balance': account_balance, 'avg_price': avg_price, 'avg_price_3': agg_0, 'key': key}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'agg_0': t1.agg_0, 'avg_price': t0.avg_price, 'key': t0.key}) + FILTER(condition=nation_key == 19:int64, columns={'account_balance': account_balance, 'avg_price': avg_price, 'key': key}) + JOIN(conditions=[True:bool], types=['inner'], columns={'account_balance': t1.account_balance, 'avg_price': t0.avg_price, 'key': t1.key, 'nation_key': t1.nation_key}) + PROJECT(columns={'avg_price': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'retail_price': p_retailprice}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': AVG(retail_price)}) JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'retail_price': t1.retail_price, 'supplier_key': t0.supplier_key}) SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) SCAN(table=tpch.PART, columns={'key': p_partkey, 'retail_price': p_retailprice}) FILTER(condition=True:bool, columns={'supplier_key': supplier_key}) - JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr2') + JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr3') SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) - FILTER(condition=container == 'LG DRUM':string & retail_price < corr2.supplycost * 1.5:float64 & retail_price < corr3.avg_price & retail_price < corr4.avg_price * 0.85:float64, columns={'key': key}) + FILTER(condition=container == 'LG DRUM':string & retail_price < corr3.supplycost * 1.5:float64 & retail_price < 
corr4.avg_price_3 & retail_price < corr4.avg_price * 0.85:float64, columns={'key': key}) SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'retail_price': p_retailprice}) diff --git a/tests/test_plan_refsols/correl_17.txt b/tests/test_plan_refsols/correl_17.txt index 4e532f3c..aad7c616 100644 --- a/tests/test_plan_refsols/correl_17.txt +++ b/tests/test_plan_refsols/correl_17.txt @@ -2,8 +2,10 @@ ROOT(columns=[('fullname', fullname)], orderings=[(ordering_0):asc_first]) PROJECT(columns={'fullname': fullname, 'ordering_0': fullname}) PROJECT(columns={'fullname': fname}) FILTER(condition=True:bool, columns={'fname': fname}) - JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'fname': t1.fname}, correl_name='corr1') - PROJECT(columns={'lname': LOWER(name), 'region_key': region_key}) - SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) - PROJECT(columns={'fname': JOIN_STRINGS('-':string, LOWER(name), corr1.lname), 'key': key}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + JOIN(conditions=[t0.key == t1.key], types=['inner'], columns={'fname': t1.fname}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey}) + PROJECT(columns={'fname': JOIN_STRINGS('-':string, LOWER(name_3), lname), 'key': key}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'lname': t0.lname, 'name_3': t1.name}) + PROJECT(columns={'key': key, 'lname': LOWER(name), 'region_key': region_key}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/correl_18.txt b/tests/test_plan_refsols/correl_18.txt index ab36ffd4..db0b5291 100644 --- a/tests/test_plan_refsols/correl_18.txt +++ b/tests/test_plan_refsols/correl_18.txt @@ -2,13 +2,18 @@ ROOT(columns=[('n', n)], orderings=[]) PROJECT(columns={'n': DEFAULT_TO(agg_0, 
0:int64)}) AGGREGATE(keys={}, aggregations={'agg_0': SUM(n_above_avg)}) PROJECT(columns={'n_above_avg': DEFAULT_TO(agg_2, 0:int64)}) - JOIN(conditions=[t0.customer_key == t1.customer_key & t0.order_date == t1.order_date], types=['left'], columns={'agg_2': t1.agg_2}, correl_name='corr1') - PROJECT(columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': DEFAULT_TO(agg_1, 0:int64)}) - FILTER(condition=DEFAULT_TO(agg_0, 0:int64) > 1:int64, columns={'agg_1': agg_1, 'customer_key': customer_key, 'order_date': order_date}) - AGGREGATE(keys={'customer_key': customer_key, 'order_date': order_date}, aggregations={'agg_0': COUNT(), 'agg_1': SUM(total_price)}) - FILTER(condition=YEAR(order_date) == 1993:int64, columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': total_price}) - SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'order_date': o_orderdate, 'total_price': o_totalprice}) + JOIN(conditions=[t0.customer_key == t1.customer_key & t0.order_date == t1.order_date], types=['left'], columns={'agg_2': t1.agg_2}) + FILTER(condition=DEFAULT_TO(agg_0, 0:int64) > 1:int64, columns={'customer_key': customer_key, 'order_date': order_date}) + AGGREGATE(keys={'customer_key': customer_key, 'order_date': order_date}, aggregations={'agg_0': COUNT()}) + FILTER(condition=YEAR(order_date) == 1993:int64, columns={'customer_key': customer_key, 'order_date': order_date}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'order_date': o_orderdate}) AGGREGATE(keys={'customer_key': customer_key, 'order_date': order_date}, aggregations={'agg_2': COUNT()}) - FILTER(condition=total_price >= 0.5:float64 * corr1.total_price, columns={'customer_key': customer_key, 'order_date': order_date}) - FILTER(condition=YEAR(order_date) == 1993:int64, columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': total_price}) - SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'order_date': o_orderdate, 
'total_price': o_totalprice}) + FILTER(condition=total_price_3 >= 0.5:float64 * total_price, columns={'customer_key': customer_key, 'order_date': order_date}) + FILTER(condition=YEAR(order_date_2) == 1993:int64, columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': total_price, 'total_price_3': total_price_3}) + JOIN(conditions=[True:bool], types=['inner'], columns={'customer_key': t0.customer_key, 'order_date': t0.order_date, 'order_date_2': t1.order_date, 'total_price': t0.total_price, 'total_price_3': t1.total_price}) + PROJECT(columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': DEFAULT_TO(agg_1, 0:int64)}) + FILTER(condition=DEFAULT_TO(agg_0, 0:int64) > 1:int64, columns={'agg_1': agg_1, 'customer_key': customer_key, 'order_date': order_date}) + AGGREGATE(keys={'customer_key': customer_key, 'order_date': order_date}, aggregations={'agg_0': COUNT(), 'agg_1': SUM(total_price)}) + FILTER(condition=YEAR(order_date) == 1993:int64, columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': total_price}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'order_date': o_orderdate, 'total_price': o_totalprice}) + SCAN(table=tpch.ORDERS, columns={'order_date': o_orderdate, 'total_price': o_totalprice}) diff --git a/tests/test_plan_refsols/correl_19.txt b/tests/test_plan_refsols/correl_19.txt index a273084b..a65ac794 100644 --- a/tests/test_plan_refsols/correl_19.txt +++ b/tests/test_plan_refsols/correl_19.txt @@ -1,12 +1,16 @@ -ROOT(columns=[('name', name_7), ('n_super_cust', n_super_cust)], orderings=[(ordering_1):desc_last]) - PROJECT(columns={'n_super_cust': n_super_cust, 'name_7': name_3, 'ordering_1': ordering_1}) +ROOT(columns=[('name', name_14), ('n_super_cust', n_super_cust)], orderings=[(ordering_1):desc_last]) + PROJECT(columns={'n_super_cust': n_super_cust, 'name_14': name_3, 'ordering_1': ordering_1}) LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'n_super_cust': 
n_super_cust, 'name_3': name_3, 'ordering_1': ordering_1}, orderings=[(ordering_1):desc_last]) PROJECT(columns={'n_super_cust': n_super_cust, 'name_3': name_3, 'ordering_1': n_super_cust}) PROJECT(columns={'n_super_cust': DEFAULT_TO(agg_0, 0:int64), 'name_3': name_3}) - JOIN(conditions=[t0.key_2 == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'name_3': t0.name_3}, correl_name='corr4') - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'account_balance': t0.account_balance, 'key_2': t1.key, 'name_3': t1.name}) - SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'nation_key': s_nationkey}) + JOIN(conditions=[t0.key_2 == t1.key_5 & t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name_3': t0.name_3}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name_3': t1.name}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': COUNT()}) - FILTER(condition=acctbal > corr4.account_balance, columns={'nation_key': nation_key}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': c_nationkey}) + AGGREGATE(keys={'key': key, 'key_5': key_5}, aggregations={'agg_0': COUNT()}) + FILTER(condition=acctbal > account_balance, columns={'key': key, 'key_5': key_5}) + JOIN(conditions=[t0.key_5 == t1.nation_key], types=['inner'], columns={'account_balance': t0.account_balance, 'acctbal': t1.acctbal, 'key': t0.key, 'key_5': t0.key_5}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'account_balance': t0.account_balance, 'key': t0.key, 'key_5': t1.key}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': 
c_acctbal, 'nation_key': c_nationkey}) diff --git a/tests/test_plan_refsols/correl_2.txt b/tests/test_plan_refsols/correl_2.txt index 529b06fd..b5d64b7c 100644 --- a/tests/test_plan_refsols/correl_2.txt +++ b/tests/test_plan_refsols/correl_2.txt @@ -1,12 +1,17 @@ -ROOT(columns=[('name', name_7), ('n_selected_custs', n_selected_custs)], orderings=[(ordering_1):asc_first]) - PROJECT(columns={'n_selected_custs': n_selected_custs, 'name_7': name_6, 'ordering_1': ordering_1}) - PROJECT(columns={'n_selected_custs': n_selected_custs, 'name_6': name_6, 'ordering_1': name_6}) - PROJECT(columns={'n_selected_custs': DEFAULT_TO(agg_0, 0:int64), 'name_6': name_3}) - JOIN(conditions=[t0.key_2 == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'name_3': t0.name_3}, correl_name='corr4') - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name, 'name_3': t1.name}) - FILTER(condition=NOT(STARTSWITH(name, 'A':string)), columns={'key': key, 'name': name}) +ROOT(columns=[('name', name_12), ('n_selected_custs', n_selected_custs)], orderings=[(ordering_1):asc_first]) + PROJECT(columns={'n_selected_custs': n_selected_custs, 'name_12': name_11, 'ordering_1': ordering_1}) + PROJECT(columns={'n_selected_custs': n_selected_custs, 'name_11': name_11, 'ordering_1': name_11}) + PROJECT(columns={'n_selected_custs': DEFAULT_TO(agg_0, 0:int64), 'name_11': name_3}) + JOIN(conditions=[t0.key_2 == t1.key_5 & t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name_3': t0.name_3}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name_3': t1.name}) + FILTER(condition=NOT(STARTSWITH(name, 'A':string)), columns={'key': key}) SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': COUNT()}) - 
FILTER(condition=SLICE(comment, None:unknown, 1:int64, None:unknown) == LOWER(SLICE(corr4.name, None:unknown, 1:int64, None:unknown)), columns={'nation_key': nation_key}) - SCAN(table=tpch.CUSTOMER, columns={'comment': c_comment, 'nation_key': c_nationkey}) + AGGREGATE(keys={'key': key, 'key_5': key_5}, aggregations={'agg_0': COUNT()}) + FILTER(condition=SLICE(comment_7, None:unknown, 1:int64, None:unknown) == LOWER(SLICE(name, None:unknown, 1:int64, None:unknown)), columns={'key': key, 'key_5': key_5}) + JOIN(conditions=[t0.key_5 == t1.nation_key], types=['inner'], columns={'comment_7': t1.comment, 'key': t0.key, 'key_5': t0.key_5, 'name': t0.name}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_5': t1.key, 'name': t0.name}) + FILTER(condition=NOT(STARTSWITH(name, 'A':string)), columns={'key': key, 'name': name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.CUSTOMER, columns={'comment': c_comment, 'nation_key': c_nationkey}) diff --git a/tests/test_plan_refsols/correl_20.txt b/tests/test_plan_refsols/correl_20.txt index c1d388c4..670d0880 100644 --- a/tests/test_plan_refsols/correl_20.txt +++ b/tests/test_plan_refsols/correl_20.txt @@ -2,16 +2,27 @@ ROOT(columns=[('n', n)], orderings=[]) PROJECT(columns={'n': agg_0}) AGGREGATE(keys={}, aggregations={'agg_0': COUNT()}) FILTER(condition=domestic, columns={'account_balance': account_balance}) - JOIN(conditions=[t0.nation_key_11 == t1.key], types=['left'], columns={'account_balance': t0.account_balance, 'domestic': t1.domestic}, correl_name='corr13') - JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'account_balance': t1.account_balance, 'name': t0.name, 'nation_key_11': t1.nation_key}) - JOIN(conditions=[t0.key_5 == t1.order_key], types=['inner'], columns={'name': t0.name, 'supplier_key': t1.supplier_key}) - 
FILTER(condition=YEAR(order_date) == 1998:int64 & MONTH(order_date) == 6:int64, columns={'key_5': key_5, 'name': name}) - JOIN(conditions=[t0.key_2 == t1.customer_key], types=['inner'], columns={'key_5': t1.key, 'name': t0.name, 'order_date': t1.order_date}) - JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + JOIN(conditions=[t0.key_9 == t1.key_21 & t0.order_key == t1.order_key & t0.line_number == t1.line_number & t0.key_5 == t1.key_17 & t0.key_2 == t1.key_14 & t0.key == t1.key], types=['left'], columns={'account_balance': t0.account_balance, 'domestic': t1.domestic}) + JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'account_balance': t1.account_balance, 'key': t0.key, 'key_2': t0.key_2, 'key_5': t0.key_5, 'key_9': t1.key, 'line_number': t0.line_number, 'order_key': t0.order_key}) + JOIN(conditions=[t0.key_5 == t1.order_key], types=['inner'], columns={'key': t0.key, 'key_2': t0.key_2, 'key_5': t0.key_5, 'line_number': t1.line_number, 'order_key': t1.order_key, 'supplier_key': t1.supplier_key}) + FILTER(condition=YEAR(order_date) == 1998:int64 & MONTH(order_date) == 6:int64, columns={'key': key, 'key_2': key_2, 'key_5': key_5}) + JOIN(conditions=[t0.key_2 == t1.customer_key], types=['inner'], columns={'key': t0.key, 'key_2': t0.key_2, 'key_5': t1.key, 'order_date': t1.order_date}) + JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey}) SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate}) - SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'supplier_key': l_suppkey}) - SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': 
s_nationkey}) - PROJECT(columns={'domestic': name == corr13.name, 'key': key}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + SCAN(table=tpch.LINEITEM, columns={'line_number': l_linenumber, 'order_key': l_orderkey, 'supplier_key': l_suppkey}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey}) + PROJECT(columns={'domestic': name_27 == name, 'key': key, 'key_14': key_14, 'key_17': key_17, 'key_21': key_21, 'line_number': line_number, 'order_key': order_key}) + JOIN(conditions=[t0.nation_key_23 == t1.key], types=['inner'], columns={'key': t0.key, 'key_14': t0.key_14, 'key_17': t0.key_17, 'key_21': t0.key_21, 'line_number': t0.line_number, 'name': t0.name, 'name_27': t1.name, 'order_key': t0.order_key}) + JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'key': t0.key, 'key_14': t0.key_14, 'key_17': t0.key_17, 'key_21': t1.key, 'line_number': t0.line_number, 'name': t0.name, 'nation_key_23': t1.nation_key, 'order_key': t0.order_key}) + JOIN(conditions=[t0.key_17 == t1.order_key], types=['inner'], columns={'key': t0.key, 'key_14': t0.key_14, 'key_17': t0.key_17, 'line_number': t1.line_number, 'name': t0.name, 'order_key': t1.order_key, 'supplier_key': t1.supplier_key}) + FILTER(condition=YEAR(order_date) == 1998:int64 & MONTH(order_date) == 6:int64, columns={'key': key, 'key_14': key_14, 'key_17': key_17, 'name': name}) + JOIN(conditions=[t0.key_14 == t1.customer_key], types=['inner'], columns={'key': t0.key, 'key_14': t0.key_14, 'key_17': t1.key, 'name': t0.name, 'order_date': t1.order_date}) + JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'key': t0.key, 'key_14': t1.key, 'name': t0.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate}) + 
SCAN(table=tpch.LINEITEM, columns={'line_number': l_linenumber, 'order_key': l_orderkey, 'supplier_key': l_suppkey}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) diff --git a/tests/test_plan_refsols/correl_3.txt b/tests/test_plan_refsols/correl_3.txt index 57d2dfdf..2bbb01bc 100644 --- a/tests/test_plan_refsols/correl_3.txt +++ b/tests/test_plan_refsols/correl_3.txt @@ -1,11 +1,13 @@ ROOT(columns=[('name', name), ('n_nations', n_nations)], orderings=[(ordering_1):asc_first]) PROJECT(columns={'n_nations': n_nations, 'name': name, 'ordering_1': name}) PROJECT(columns={'n_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name}) - JOIN(conditions=[t0.key == t1.region_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr1') + JOIN(conditions=[t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}) SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - AGGREGATE(keys={'region_key': region_key}, aggregations={'agg_0': COUNT()}) - FILTER(condition=True:bool, columns={'region_key': region_key}) - JOIN(conditions=[t0.key == t1.nation_key], types=['semi'], columns={'region_key': t0.region_key}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - FILTER(condition=SLICE(comment, None:unknown, 2:int64, None:unknown) == LOWER(SLICE(corr1.name, None:unknown, 2:int64, None:unknown)), columns={'nation_key': nation_key}) + AGGREGATE(keys={'key': key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=True:bool, columns={'key': key}) + JOIN(conditions=[t0.key_2 == t1.nation_key], types=['semi'], columns={'key': t0.key}, correl_name='corr4') + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name': t0.name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'key': 
n_nationkey, 'region_key': n_regionkey}) + FILTER(condition=SLICE(comment, None:unknown, 2:int64, None:unknown) == LOWER(SLICE(corr4.name, None:unknown, 2:int64, None:unknown)), columns={'nation_key': nation_key}) SCAN(table=tpch.CUSTOMER, columns={'comment': c_comment, 'nation_key': c_nationkey}) diff --git a/tests/test_plan_refsols/correl_6.txt b/tests/test_plan_refsols/correl_6.txt index 0a85a6fa..a6829877 100644 --- a/tests/test_plan_refsols/correl_6.txt +++ b/tests/test_plan_refsols/correl_6.txt @@ -1,8 +1,10 @@ ROOT(columns=[('name', name), ('n_prefix_nations', n_prefix_nations)], orderings=[]) PROJECT(columns={'n_prefix_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name}) FILTER(condition=True:bool, columns={'agg_0': agg_0, 'name': name}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr1') + JOIN(conditions=[t0.key == t1.key], types=['inner'], columns={'agg_0': t1.agg_0, 'name': t0.name}) SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - AGGREGATE(keys={'region_key': region_key}, aggregations={'agg_0': COUNT()}) - FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'region_key': region_key}) - SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) + AGGREGATE(keys={'key': key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(name, None:unknown, 1:int64, None:unknown), columns={'key': key}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) diff --git a/tests/test_plan_refsols/correl_8.txt b/tests/test_plan_refsols/correl_8.txt index 87bcc66e..8da228f0 100644 --- 
a/tests/test_plan_refsols/correl_8.txt +++ b/tests/test_plan_refsols/correl_8.txt @@ -1,7 +1,9 @@ ROOT(columns=[('name', name), ('rname', rname)], orderings=[(ordering_0):asc_first]) PROJECT(columns={'name': name, 'ordering_0': name, 'rname': rname}) - PROJECT(columns={'name': name, 'rname': name_4}) - JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'name': t0.name, 'name_4': t1.name}, correl_name='corr1') - SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) - FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'key': key, 'name': name}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + PROJECT(columns={'name': name, 'rname': name_3}) + JOIN(conditions=[t0.key == t1.key], types=['left'], columns={'name': t0.name, 'name_3': t1.name_3}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(name, None:unknown, 1:int64, None:unknown), columns={'key': key, 'name_3': name_3}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/correl_9.txt b/tests/test_plan_refsols/correl_9.txt index 6a7a6c13..449cceea 100644 --- a/tests/test_plan_refsols/correl_9.txt +++ b/tests/test_plan_refsols/correl_9.txt @@ -1,8 +1,10 @@ ROOT(columns=[('name', name), ('rname', rname)], orderings=[(ordering_0):asc_first]) PROJECT(columns={'name': name, 'ordering_0': name, 'rname': rname}) - PROJECT(columns={'name': name, 'rname': name_4}) - FILTER(condition=True:bool, columns={'name': name, 'name_4': name_4}) - JOIN(conditions=[t0.region_key == t1.key], types=['inner'], 
columns={'name': t0.name, 'name_4': t1.name}, correl_name='corr1') - SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) - FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'key': key, 'name': name}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + PROJECT(columns={'name': name, 'rname': name_3}) + FILTER(condition=True:bool, columns={'name': name, 'name_3': name_3}) + JOIN(conditions=[t0.key == t1.key], types=['inner'], columns={'name': t0.name, 'name_3': t1.name_3}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(name, None:unknown, 1:int64, None:unknown), columns={'key': key, 'name_3': name_3}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/tpch_q5.txt b/tests/test_plan_refsols/tpch_q5.txt index fbd35207..7cb9925b 100644 --- a/tests/test_plan_refsols/tpch_q5.txt +++ b/tests/test_plan_refsols/tpch_q5.txt @@ -1,21 +1,26 @@ ROOT(columns=[('N_NAME', N_NAME), ('REVENUE', REVENUE)], orderings=[(ordering_1):desc_last]) PROJECT(columns={'N_NAME': N_NAME, 'REVENUE': REVENUE, 'ordering_1': REVENUE}) PROJECT(columns={'N_NAME': name, 'REVENUE': DEFAULT_TO(agg_0, 0:int64)}) - JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr10') + JOIN(conditions=[t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}) FILTER(condition=name_3 == 'ASIA':string, columns={'key': key, 'name': name}) JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'key': t0.key, 'name': 
t0.name, 'name_3': t1.name}) SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': SUM(value)}) - PROJECT(columns={'nation_key': nation_key, 'value': extended_price * 1:int64 - discount}) - FILTER(condition=name_9 == corr10.name, columns={'discount': discount, 'extended_price': extended_price, 'nation_key': nation_key}) - JOIN(conditions=[t0.supplier_key == t1.key], types=['left'], columns={'discount': t0.discount, 'extended_price': t0.extended_price, 'name_9': t1.name_9, 'nation_key': t0.nation_key}) - JOIN(conditions=[t0.key_5 == t1.order_key], types=['inner'], columns={'discount': t1.discount, 'extended_price': t1.extended_price, 'nation_key': t0.nation_key, 'supplier_key': t1.supplier_key}) - FILTER(condition=order_date >= datetime.date(1994, 1, 1):date & order_date < datetime.date(1995, 1, 1):date, columns={'key_5': key_5, 'nation_key': nation_key}) - JOIN(conditions=[t0.key == t1.customer_key], types=['inner'], columns={'key_5': t1.key, 'nation_key': t0.nation_key, 'order_date': t1.order_date}) - SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) + AGGREGATE(keys={'key': key}, aggregations={'agg_0': SUM(value)}) + PROJECT(columns={'key': key, 'value': extended_price * 1:int64 - discount}) + FILTER(condition=name_15 == name, columns={'discount': discount, 'extended_price': extended_price, 'key': key}) + JOIN(conditions=[t0.supplier_key == t1.key], types=['left'], columns={'discount': t0.discount, 'extended_price': t0.extended_price, 'key': t0.key, 'name': t0.name, 'name_15': t1.name_15}) + JOIN(conditions=[t0.key_11 == t1.order_key], types=['inner'], columns={'discount': t1.discount, 'extended_price': t1.extended_price, 'key': t0.key, 'name': t0.name, 'supplier_key': t1.supplier_key}) + FILTER(condition=order_date >= datetime.date(1994, 1, 1):date & 
order_date < datetime.date(1995, 1, 1):date, columns={'key': key, 'key_11': key_11, 'name': name}) + JOIN(conditions=[t0.key_8 == t1.customer_key], types=['inner'], columns={'key': t0.key, 'key_11': t1.key, 'name': t0.name, 'order_date': t1.order_date}) + JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'key': t0.key, 'key_8': t1.key, 'name': t0.name}) + FILTER(condition=name_6 == 'ASIA':string, columns={'key': key, 'name': name}) + JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'key': t0.key, 'name': t0.name, 'name_6': t1.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate}) SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'order_key': l_orderkey, 'supplier_key': l_suppkey}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_9': t1.name}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_15': t1.name}) SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) From 7e3de91ecdedcf274dbf8f8692d921f769279c54 Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Tue, 18 Feb 2025 17:44:26 -0500 Subject: [PATCH 5/7] Support uses of BACK that cause correlated references: fix remaining decorrelation edge cases (#254) Child PR of #269 that is part of addressing #141. Handles certain edge cases, such as: - When the de-correlation causes an aggregate to prune all of its outputs, creating an empty SELECT clause. 
- Fixing a bug in how decorrelation of a PARTITION child results in a cartesian join instead of joining on the partition keys. - Fixing a bug in renaming of hybrid children causing the wrong name overload to be chosen. - Fixing nondeterministic join key ordering during decorrelation. - Adjusting how BACK shifts for join/agg keys are handled so the levels take into account levels that were skipped when snapshotting the parent. - Adjusting the SQLGlot conversion for aggregations so if there is no aggfunc, it instead does `SELECT DISTINCT` - Added 3 more correlation tests for very particular edge cases involving `PARTITION`, and also made some adjustments to the qualification/hybrid conversion of `PARTITION` to account for cases where a `PARTITION` node is the root of a child operator child access. --- pydough/conversion/hybrid_decorrelater.py | 77 ++++++++++++++----- pydough/conversion/hybrid_tree.py | 51 +++++++++--- pydough/conversion/relational_converter.py | 20 ++++- .../relational_nodes/column_pruner.py | 21 ++++- pydough/sqlglot/sqlglot_relational_visitor.py | 9 ++- pydough/unqualified/qualification.py | 7 +- tests/correlated_pydough_functions.py | 43 +++++++++++ tests/test_pipeline.py | 38 +++++++++ tests/test_plan_refsols/correl_15.txt | 43 +++++------ tests/test_plan_refsols/correl_18.txt | 4 +- tests/test_plan_refsols/correl_19.txt | 31 ++++---- tests/test_plan_refsols/correl_20.txt | 2 +- tests/test_plan_refsols/correl_21.txt | 13 ++++ tests/test_plan_refsols/correl_22.txt | 13 ++++ tests/test_plan_refsols/correl_23.txt | 15 ++++ .../join_regions_nations_calc_override.txt | 6 +- tests/test_plan_refsols/tpch_q22.txt | 32 ++++---- tests/test_qualification.py | 44 +++++++++++ tests/test_relational_nodes_to_sqlglot.py | 15 ++-- tests/test_sql_refsols/simple_distinct.sql | 4 +- tests/test_unqualified_node.py | 2 +- tests/tpch_outputs.py | 2 +- tests/tpch_test_functions.py | 26 ++++--- 23 files changed, 398 insertions(+), 120 deletions(-) create mode 100644 
tests/test_plan_refsols/correl_21.txt create mode 100644 tests/test_plan_refsols/correl_22.txt create mode 100644 tests/test_plan_refsols/correl_23.txt diff --git a/pydough/conversion/hybrid_decorrelater.py b/pydough/conversion/hybrid_decorrelater.py index 72f6aad1..a1c36407 100644 --- a/pydough/conversion/hybrid_decorrelater.py +++ b/pydough/conversion/hybrid_decorrelater.py @@ -34,7 +34,7 @@ class Decorrelater: def make_decorrelate_parent( self, hybrid: HybridTree, child_idx: int, required_steps: int - ) -> HybridTree: + ) -> tuple[HybridTree, int]: """ Creates a snapshot of the ancestry of the hybrid tree that contains a correlated child, without any of its children, its descendants, or @@ -50,10 +50,12 @@ def make_decorrelate_parent( derivable. Returns: - A snapshot of `hybrid` and its ancestry in the hybrid tree, without - without any of its children or pipeline operators that occur during - or after the derivation of the correlated child, or without any of - its descendants. + A tuple where the first entry is a snapshot of `hybrid` and its + ancestry in the hybrid tree, without without any of its children or + pipeline operators that occur during or after the derivation of the + correlated child, or without any of its descendants. The second + entry is the number of ancestor layers that should be skipped due + to the PARTITION edge case. """ if isinstance(hybrid.pipeline[0], HybridPartition) and child_idx == 0: # Special case: if the correlated child is the data argument of a @@ -61,10 +63,14 @@ def make_decorrelate_parent( # parent of the level containing the partition operation. In this # case, all of the parent's children & pipeline operators should be # included in the snapshot. - assert hybrid.parent is not None - return self.make_decorrelate_parent( + if hybrid.parent is None: + raise ValueError( + "Malformed hybrid tree: partition data input to a partition node cannot contain a correlated reference to the partition node." 
+ ) + result = self.make_decorrelate_parent( hybrid.parent, len(hybrid.parent.children), len(hybrid.pipeline) ) + return result[0], result[1] + 1 # Temporarily detach the successor of the current level, then create a # deep copy of the current level (which will include its ancestors), # then reattach the successor back to the original. This ensures that @@ -78,7 +84,7 @@ def make_decorrelate_parent( # that is has to. new_hybrid._children = new_hybrid._children[:child_idx] new_hybrid._pipeline = new_hybrid._pipeline[: required_steps + 1] - return new_hybrid + return new_hybrid, 0 def remove_correl_refs( self, expr: HybridExpr, parent: HybridTree, child_height: int @@ -221,6 +227,7 @@ def decorrelate_child( new_parent: HybridTree, child: HybridConnection, is_aggregate: bool, + skipped_levels: int, ) -> None: """ Runs the logic to de-correlate a child of a hybrid tree that contains @@ -230,6 +237,17 @@ def decorrelate_child( child can now replace correlated references with BACK references that point to terms in its newly expanded ancestry, and the original hybrid tree can now join onto this child using its uniqueness keys. + + Args: + `old_parent`: The correlated ancestor hybrid tree that the correlated + references should point to when they are targeted for removal. + `new_parent`: The ancestor of `level` that removal should stop at. + `child`: The child of the hybrid tree that contains the correlated + nodes to be removed. + `is_aggregate`: Whether the child is being aggregated with regards + to its parent. + `skipped_levels`: The number of ancestor layers that should be + ignored when deriving backshifts of join/agg keys. """ # First, find the height of the child subtree & its top-most level. 
child_root: HybridTree = child.subtree @@ -245,28 +263,41 @@ def decorrelate_child( new_join_keys: list[tuple[HybridExpr, HybridExpr]] = [] additional_levels: int = 0 current_level: HybridTree | None = old_parent + new_agg_keys: list[HybridExpr] = [] while current_level is not None: - for unique_key in current_level.pipeline[0].unique_exprs: + skip_join: bool = ( + isinstance(current_level.pipeline[0], HybridPartition) + and child is current_level.children[0] + ) + for unique_key in sorted(current_level.pipeline[0].unique_exprs, key=str): lhs_key: HybridExpr | None = unique_key.shift_back(additional_levels) rhs_key: HybridExpr | None = unique_key.shift_back( - additional_levels + child_height + additional_levels + child_height - skipped_levels ) assert lhs_key is not None and rhs_key is not None - new_join_keys.append((lhs_key, rhs_key)) + if not skip_join: + new_join_keys.append((lhs_key, rhs_key)) + new_agg_keys.append(rhs_key) current_level = current_level.parent additional_levels += 1 child.subtree.join_keys = new_join_keys - # If aggregating, do the same with the aggregation keys. + # If aggregating, update the aggregation keys accordingly. if is_aggregate: - new_agg_keys: list[HybridExpr] = [] - assert child.subtree.join_keys is not None - for _, rhs_key in child.subtree.join_keys: - new_agg_keys.append(rhs_key) child.subtree.agg_keys = new_agg_keys def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: """ - TODO + The recursive procedure to remove unwanted correlated references from + the entire hybrid tree, called from the bottom and working upwards + to the top layer, and having each layer also de-correlate its children. + + Args: + `hybrid`: The hybrid tree to remove correlated references from. + + Returns: + The hybrid tree with all invalid correlated references removed as the + tree structure is re-written to allow them to be replaced with BACK + references. The transformation is also done in-place. 
""" # Recursively decorrelate the ancestors of the current level of the # hybrid tree. @@ -282,9 +313,6 @@ def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: for idx, child in enumerate(hybrid.children): if idx not in hybrid.correlated_children: continue - new_parent: HybridTree = self.make_decorrelate_parent( - hybrid, idx, hybrid.children[idx].required_steps - ) match child.connection_type: case ( ConnectionType.SINGULAR @@ -292,8 +320,15 @@ def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: | ConnectionType.AGGREGATION | ConnectionType.AGGREGATION_ONLY_MATCH ): + new_parent, skipped_levels = self.make_decorrelate_parent( + hybrid, idx, hybrid.children[idx].required_steps + ) self.decorrelate_child( - hybrid, new_parent, child, child.connection_type.is_aggregation + hybrid, + new_parent, + child, + child.connection_type.is_aggregation, + skipped_levels, ) case ConnectionType.NDISTINCT | ConnectionType.NDISTINCT_ONLY_MATCH: raise NotImplementedError( diff --git a/pydough/conversion/hybrid_tree.py b/pydough/conversion/hybrid_tree.py index 956e88f4..16b0bca5 100644 --- a/pydough/conversion/hybrid_tree.py +++ b/pydough/conversion/hybrid_tree.py @@ -494,17 +494,26 @@ def __init__( for name, expr in predecessor.terms.items(): terms[name] = HybridRefExpr(name, expr.typ) renamings.update(predecessor.renamings) + new_renamings: dict[str, str] = {} for name, expr in new_expressions.items(): if name in terms and terms[name] == expr: continue expr = expr.apply_renamings(predecessor.renamings) used_name: str = name idx: int = 0 - while used_name in terms or used_name in renamings: + while ( + used_name in terms + or used_name in renamings + or used_name in new_renamings + ): used_name = f"{name}_{idx}" idx += 1 terms[used_name] = expr - renamings[name] = used_name + new_renamings[name] = used_name + renamings.update(new_renamings) + for old_name, new_name in new_renamings.items(): + expr = new_expressions.pop(old_name) + 
new_expressions[new_name] = expr super().__init__(terms, renamings, orderings, predecessor.unique_exprs) self.calc = Calc self.new_expressions = new_expressions @@ -520,7 +529,10 @@ class HybridFilter(HybridOperation): def __init__(self, predecessor: HybridOperation, condition: HybridExpr): super().__init__( - predecessor.terms, {}, predecessor.orderings, predecessor.unique_exprs + predecessor.terms, + predecessor.renamings, + predecessor.orderings, + predecessor.unique_exprs, ) self.predecessor: HybridOperation = predecessor self.condition: HybridExpr = condition @@ -566,7 +578,10 @@ def __init__( records_to_keep: int, ): super().__init__( - predecessor.terms, {}, predecessor.orderings, predecessor.unique_exprs + predecessor.terms, + predecessor.renamings, + predecessor.orderings, + predecessor.unique_exprs, ) self.predecessor: HybridOperation = predecessor self.records_to_keep: int = records_to_keep @@ -908,13 +923,13 @@ def __repr__(self): lines.append(" -> ".join(repr(operation) for operation in self.pipeline)) prefix = " " if self.successor is None else "↓" for idx, child in enumerate(self.children): - lines.append(f"{prefix} child #{idx}:") + lines.append(f"{prefix} child #{idx} ({child.connection_type.name}):") if child.subtree.agg_keys is not None: - lines.append( - f"{prefix} aggregate: {child.subtree.agg_keys} -> {child.aggs}:" - ) + lines.append(f"{prefix} aggregate: {child.subtree.agg_keys}") + if len(child.aggs): + lines.append(f"{prefix} aggs: {child.aggs}:") if child.subtree.join_keys is not None: - lines.append(f"{prefix} join: {child.subtree.join_keys}:") + lines.append(f"{prefix} join: {child.subtree.join_keys}") for line in repr(child.subtree).splitlines(): lines.append(f"{prefix} {line}") return "\n".join(lines) @@ -1964,6 +1979,24 @@ def make_hybrid_tree( rhs_expr.name, 0, rhs_expr.typ ) join_key_exprs.append((lhs_expr, rhs_expr)) + + case PartitionBy(): + partition = HybridPartition() + successor_hybrid = HybridTree(partition) + 
self.populate_children( + successor_hybrid, node.child_access, child_ref_mapping + ) + partition_child_idx = child_ref_mapping[0] + for key_name in node.calc_terms: + key = node.get_expr(key_name) + expr = self.make_hybrid_expr( + successor_hybrid, key, child_ref_mapping, False + ) + partition.add_key(key_name, expr) + key_exprs.append(HybridRefExpr(key_name, expr.typ)) + successor_hybrid.children[ + partition_child_idx + ].subtree.agg_keys = key_exprs case _: raise NotImplementedError( f"{node.__class__.__name__} (child is {node.child_access.__class__.__name__})" diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index 77c10c86..c48132a2 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -885,11 +885,26 @@ def rel_translation( if isinstance(operation.collection, TableCollection): result = self.build_simple_table_scan(operation) if context is not None: + # If the collection access is the child of something + # else, join it onto that something else. Use the + # uniqueness keys of the ancestor, which should also be + # present in the collection (e.g. joining a partition + # onto the original data using the partition keys). + assert preceding_hybrid is not None + join_keys: list[tuple[HybridExpr, HybridExpr]] = [] + for unique_column in sorted( + preceding_hybrid[0].pipeline[0].unique_exprs, key=str + ): + if unique_column not in result.expressions: + raise ValueError( + f"Cannot connect parent context to child {operation.collection} because {unique_column} is not in the child's expressions." + ) + join_keys.append((unique_column, unique_column)) result = self.join_outputs( context, result, JoinType.INNER, - [], + join_keys, None, ) else: @@ -913,7 +928,8 @@ def rel_translation( assert context is not None, "Malformed HybridTree pattern." 
result = self.translate_filter(operation, context) case HybridPartition(): - assert context is not None, "Malformed HybridTree pattern." + if context is None: + context = TranslationOutput(EmptySingleton(), {}) result = self.translate_partition( operation, context, hybrid, pipeline_idx ) diff --git a/pydough/relational/relational_nodes/column_pruner.py b/pydough/relational/relational_nodes/column_pruner.py index 47cda027..ef8e7cf2 100644 --- a/pydough/relational/relational_nodes/column_pruner.py +++ b/pydough/relational/relational_nodes/column_pruner.py @@ -11,7 +11,8 @@ from .abstract_node import RelationalNode from .aggregate import Aggregate -from .join import Join +from .empty_singleton import EmptySingleton +from .join import Join, JoinType from .project import Project from .relational_expression_dispatcher import RelationalExpressionDispatcher from .relational_root import RelationalRoot @@ -141,7 +142,23 @@ def _prune_node_columns( # Determine the new node. output = new_node.copy(inputs=new_inputs) - return self._prune_identity_project(output), correl_refs + output = self._prune_identity_project(output) + # Special case: replace empty aggregation with VALUES () if possible. + if ( + isinstance(output, Aggregate) + and len(output.keys) == 0 + and len(output.aggregations) == 0 + ): + return EmptySingleton(), correl_refs + # Special case: replace join where LHS is VALUES () with the RHS if + # possible. 
+ if ( + isinstance(output, Join) + and isinstance(output.inputs[0], EmptySingleton) + and output.join_types in ([JoinType.INNER], [JoinType.LEFT]) + ): + return output.inputs[1], correl_refs + return output, correl_refs def prune_unused_columns(self, root: RelationalRoot) -> RelationalRoot: """ diff --git a/pydough/sqlglot/sqlglot_relational_visitor.py b/pydough/sqlglot/sqlglot_relational_visitor.py index 0b53f04d..ca9cbe2c 100644 --- a/pydough/sqlglot/sqlglot_relational_visitor.py +++ b/pydough/sqlglot/sqlglot_relational_visitor.py @@ -431,6 +431,7 @@ def visit_filter(self, filter: Filter) -> None: # TODO: (gh #151) Refactor a simpler way to check dependent expressions. if ( "group" in input_expr.args + or "distinct" in input_expr.args or "where" in input_expr.args or "qualify" in input_expr.args or "order" in input_expr.args @@ -462,6 +463,7 @@ def visit_aggregate(self, aggregate: Aggregate) -> None: query: Select if ( "group" in input_expr.args + or "distinct" in input_expr.args or "qualify" in input_expr.args or "order" in input_expr.args or "limit" in input_expr.args @@ -472,7 +474,10 @@ def visit_aggregate(self, aggregate: Aggregate) -> None: select_cols, input_expr, find_identifiers_in_list(select_cols) ) if keys: - query = query.group_by(*keys) + if aggregations: + query = query.group_by(*keys) + else: + query = query.distinct() self._stack.append(query) def visit_limit(self, limit: Limit) -> None: @@ -511,7 +516,7 @@ def visit_limit(self, limit: Limit) -> None: self._stack.append(query) def visit_empty_singleton(self, singleton: EmptySingleton) -> None: - self._stack.append(Select().from_(values([()]))) + self._stack.append(Select().select(SQLGlotStar()).from_(values([()]))) def visit_root(self, root: RelationalRoot) -> None: self.visit_inputs(root) diff --git a/pydough/unqualified/qualification.py b/pydough/unqualified/qualification.py index ba33907d..0a532d78 100644 --- a/pydough/unqualified/qualification.py +++ b/pydough/unqualified/qualification.py @@ 
-676,7 +676,12 @@ def qualify_partition( partition: PartitionBy = self.builder.build_partition( qualified_parent, qualified_child, child_name ) - return partition.with_keys(child_references) + partition = partition.with_keys(child_references) + # Special case: if accessing as a child, wrap in a + # ChildOperatorChildAccess term. + if isinstance(unqualified_parent, UnqualifiedRoot) and is_child: + return ChildOperatorChildAccess(partition) + return partition def qualify_collection( self, diff --git a/tests/correlated_pydough_functions.py b/tests/correlated_pydough_functions.py index e3b73053..3d8c4eb7 100644 --- a/tests/correlated_pydough_functions.py +++ b/tests/correlated_pydough_functions.py @@ -300,3 +300,46 @@ def correl_20(): ) instances = selected_orders.lines.supplier.WHERE(is_domestic) return TPCH(n=COUNT(instances)) + + +def correl_21(): + # Correlated back reference example #21: partition edge case. + # Count how many part sizes have an above-average number of parts + # of that size. + # (This is a correlated aggregation access) + sizes = PARTITION(Parts, name="p", by=size)(n_parts=COUNT(p)) + return TPCH(avg_n_parts=AVG(sizes.n_parts))( + n_sizes=COUNT(sizes.WHERE(n_parts > BACK(1).avg_n_parts)) + ) + + +def correl_22(): + # Correlated back reference example #22: partition edge case. + # Finds the top 5 part sizes with the most container types + # where the average retail price of parts of that container type + # & part type is above the global average retail price. + # (This is a correlated aggregation access) + ct_combos = PARTITION(Parts, name="p", by=(container, part_type))( + avg_price=AVG(p.retail_price) + ) + return ( + TPCH(global_avg_price=AVG(Parts.retail_price)) + .PARTITION( + ct_combos.WHERE(avg_price > BACK(1).global_avg_price), + name="ct", + by=container, + )(container, n_types=COUNT(ct)) + .TOP_K(5, (n_types.DESC(), container.ASC())) + ) + + +def correl_23(): + # Correlated back reference example #23: partition edge case. 
+ # Counts how many part sizes have an above-average number of combinations + # of part types/containers. + # (This is a correlated aggregation access) + combos = PARTITION(Parts, name="p", by=(size, part_type, container)) + sizes = PARTITION(combos, name="c", by=size)(n_combos=COUNT(c)) + return TPCH(avg_n_combo=AVG(sizes.n_combos))( + n_sizes=COUNT(sizes.WHERE(n_combos > BACK(1).avg_n_combo)), + ) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 1ec38017..f862efa2 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -33,6 +33,9 @@ correl_18, correl_19, correl_20, + correl_21, + correl_22, + correl_23, ) from simple_pydough_functions import ( agg_partition, @@ -985,6 +988,41 @@ ), id="correl_20", ), + pytest.param( + ( + correl_21, + "correl_21", + lambda: pd.DataFrame({"n_sizes": [30]}), + ), + id="correl_21", + ), + pytest.param( + ( + correl_22, + "correl_22", + lambda: pd.DataFrame( + { + "container": [ + "JUMBO DRUM", + "JUMBO PKG", + "MED DRUM", + "SM BAG", + "LG PKG", + ], + "n_types": [89, 86, 81, 81, 80], + } + ), + ), + id="correl_22", + ), + pytest.param( + ( + correl_23, + "correl_23", + lambda: pd.DataFrame({"n_sizes": [23]}), + ), + id="correl_23", + ), ], ) def pydough_pipeline_test_data( diff --git a/tests/test_plan_refsols/correl_15.txt b/tests/test_plan_refsols/correl_15.txt index 960b0c69..b71be80d 100644 --- a/tests/test_plan_refsols/correl_15.txt +++ b/tests/test_plan_refsols/correl_15.txt @@ -1,25 +1,22 @@ ROOT(columns=[('n', n)], orderings=[]) PROJECT(columns={'n': agg_1}) - JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1}) - AGGREGATE(keys={}, aggregations={}) - SCAN(table=tpch.PART, columns={'brand': p_brand}) - AGGREGATE(keys={}, aggregations={'agg_1': COUNT()}) - FILTER(condition=True:bool, columns={'account_balance': account_balance}) - JOIN(conditions=[t0.key == t1.supplier_key], types=['semi'], columns={'account_balance': t0.account_balance}, correl_name='corr4') - 
PROJECT(columns={'account_balance': account_balance, 'avg_price': avg_price, 'avg_price_3': agg_0, 'key': key}) - JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'agg_0': t1.agg_0, 'avg_price': t0.avg_price, 'key': t0.key}) - FILTER(condition=nation_key == 19:int64, columns={'account_balance': account_balance, 'avg_price': avg_price, 'key': key}) - JOIN(conditions=[True:bool], types=['inner'], columns={'account_balance': t1.account_balance, 'avg_price': t0.avg_price, 'key': t1.key, 'nation_key': t1.nation_key}) - PROJECT(columns={'avg_price': agg_0}) - AGGREGATE(keys={}, aggregations={'agg_0': AVG(retail_price)}) - SCAN(table=tpch.PART, columns={'retail_price': p_retailprice}) - SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) - AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': AVG(retail_price)}) - JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'retail_price': t1.retail_price, 'supplier_key': t0.supplier_key}) - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) - SCAN(table=tpch.PART, columns={'key': p_partkey, 'retail_price': p_retailprice}) - FILTER(condition=True:bool, columns={'supplier_key': supplier_key}) - JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr3') - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) - FILTER(condition=container == 'LG DRUM':string & retail_price < corr3.supplycost * 1.5:float64 & retail_price < corr4.avg_price_3 & retail_price < corr4.avg_price * 0.85:float64, columns={'key': key}) - SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'retail_price': p_retailprice}) + AGGREGATE(keys={}, aggregations={'agg_1': COUNT()}) + FILTER(condition=True:bool, columns={'account_balance': 
account_balance}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['semi'], columns={'account_balance': t0.account_balance}, correl_name='corr4') + PROJECT(columns={'account_balance': account_balance, 'avg_price': avg_price, 'avg_price_3': agg_0, 'key': key}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'agg_0': t1.agg_0, 'avg_price': t0.avg_price, 'key': t0.key}) + FILTER(condition=nation_key == 19:int64, columns={'account_balance': account_balance, 'avg_price': avg_price, 'key': key}) + JOIN(conditions=[True:bool], types=['inner'], columns={'account_balance': t1.account_balance, 'avg_price': t0.avg_price, 'key': t1.key, 'nation_key': t1.nation_key}) + PROJECT(columns={'avg_price': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'retail_price': p_retailprice}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': AVG(retail_price)}) + JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'retail_price': t1.retail_price, 'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.PART, columns={'key': p_partkey, 'retail_price': p_retailprice}) + FILTER(condition=True:bool, columns={'supplier_key': supplier_key}) + JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr3') + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) + FILTER(condition=container == 'LG DRUM':string & retail_price < corr3.supplycost * 1.5:float64 & retail_price < corr4.avg_price_3 & retail_price < corr4.avg_price * 0.85:float64, columns={'key': key}) + SCAN(table=tpch.PART, columns={'container': p_container, 'key': 
p_partkey, 'retail_price': p_retailprice}) diff --git a/tests/test_plan_refsols/correl_18.txt b/tests/test_plan_refsols/correl_18.txt index db0b5291..74fab0da 100644 --- a/tests/test_plan_refsols/correl_18.txt +++ b/tests/test_plan_refsols/correl_18.txt @@ -10,10 +10,10 @@ ROOT(columns=[('n', n)], orderings=[]) AGGREGATE(keys={'customer_key': customer_key, 'order_date': order_date}, aggregations={'agg_2': COUNT()}) FILTER(condition=total_price_3 >= 0.5:float64 * total_price, columns={'customer_key': customer_key, 'order_date': order_date}) FILTER(condition=YEAR(order_date_2) == 1993:int64, columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': total_price, 'total_price_3': total_price_3}) - JOIN(conditions=[True:bool], types=['inner'], columns={'customer_key': t0.customer_key, 'order_date': t0.order_date, 'order_date_2': t1.order_date, 'total_price': t0.total_price, 'total_price_3': t1.total_price}) + JOIN(conditions=[t0.customer_key == t1.customer_key & t0.order_date == t1.order_date], types=['inner'], columns={'customer_key': t0.customer_key, 'order_date': t0.order_date, 'order_date_2': t1.order_date, 'total_price': t0.total_price, 'total_price_3': t1.total_price}) PROJECT(columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': DEFAULT_TO(agg_1, 0:int64)}) FILTER(condition=DEFAULT_TO(agg_0, 0:int64) > 1:int64, columns={'agg_1': agg_1, 'customer_key': customer_key, 'order_date': order_date}) AGGREGATE(keys={'customer_key': customer_key, 'order_date': order_date}, aggregations={'agg_0': COUNT(), 'agg_1': SUM(total_price)}) FILTER(condition=YEAR(order_date) == 1993:int64, columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': total_price}) SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'order_date': o_orderdate, 'total_price': o_totalprice}) - SCAN(table=tpch.ORDERS, columns={'order_date': o_orderdate, 'total_price': o_totalprice}) + SCAN(table=tpch.ORDERS, 
columns={'customer_key': o_custkey, 'order_date': o_orderdate, 'total_price': o_totalprice}) diff --git a/tests/test_plan_refsols/correl_19.txt b/tests/test_plan_refsols/correl_19.txt index a65ac794..ddaa017a 100644 --- a/tests/test_plan_refsols/correl_19.txt +++ b/tests/test_plan_refsols/correl_19.txt @@ -1,16 +1,15 @@ -ROOT(columns=[('name', name_14), ('n_super_cust', n_super_cust)], orderings=[(ordering_1):desc_last]) - PROJECT(columns={'n_super_cust': n_super_cust, 'name_14': name_3, 'ordering_1': ordering_1}) - LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'n_super_cust': n_super_cust, 'name_3': name_3, 'ordering_1': ordering_1}, orderings=[(ordering_1):desc_last]) - PROJECT(columns={'n_super_cust': n_super_cust, 'name_3': name_3, 'ordering_1': n_super_cust}) - PROJECT(columns={'n_super_cust': DEFAULT_TO(agg_0, 0:int64), 'name_3': name_3}) - JOIN(conditions=[t0.key_2 == t1.key_5 & t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name_3': t0.name_3}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name_3': t1.name}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) - AGGREGATE(keys={'key': key, 'key_5': key_5}, aggregations={'agg_0': COUNT()}) - FILTER(condition=acctbal > account_balance, columns={'key': key, 'key_5': key_5}) - JOIN(conditions=[t0.key_5 == t1.nation_key], types=['inner'], columns={'account_balance': t0.account_balance, 'acctbal': t1.acctbal, 'key': t0.key, 'key_5': t0.key_5}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'account_balance': t0.account_balance, 'key': t0.key, 'key_5': t1.key}) - SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': 
c_nationkey}) +ROOT(columns=[('name', name_0), ('n_super_cust', n_super_cust)], orderings=[(ordering_1):desc_last]) + LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'n_super_cust': n_super_cust, 'name_0': name_0, 'ordering_1': ordering_1}, orderings=[(ordering_1):desc_last]) + PROJECT(columns={'n_super_cust': n_super_cust, 'name_0': name_0, 'ordering_1': n_super_cust}) + PROJECT(columns={'n_super_cust': DEFAULT_TO(agg_0, 0:int64), 'name_0': name}) + JOIN(conditions=[t0.key_2 == t1.key_5 & t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name': t0.name}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey}) + AGGREGATE(keys={'key': key, 'key_5': key_5}, aggregations={'agg_0': COUNT()}) + FILTER(condition=acctbal > account_balance, columns={'key': key, 'key_5': key_5}) + JOIN(conditions=[t0.key_5 == t1.nation_key], types=['inner'], columns={'account_balance': t0.account_balance, 'acctbal': t1.acctbal, 'key': t0.key, 'key_5': t0.key_5}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'account_balance': t0.account_balance, 'key': t0.key, 'key_5': t1.key}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': c_nationkey}) diff --git a/tests/test_plan_refsols/correl_20.txt b/tests/test_plan_refsols/correl_20.txt index 670d0880..66ed3e35 100644 --- a/tests/test_plan_refsols/correl_20.txt +++ b/tests/test_plan_refsols/correl_20.txt @@ -2,7 +2,7 @@ ROOT(columns=[('n', n)], orderings=[]) PROJECT(columns={'n': agg_0}) AGGREGATE(keys={}, aggregations={'agg_0': COUNT()}) FILTER(condition=domestic, columns={'account_balance': 
account_balance}) - JOIN(conditions=[t0.key_9 == t1.key_21 & t0.order_key == t1.order_key & t0.line_number == t1.line_number & t0.key_5 == t1.key_17 & t0.key_2 == t1.key_14 & t0.key == t1.key], types=['left'], columns={'account_balance': t0.account_balance, 'domestic': t1.domestic}) + JOIN(conditions=[t0.key_9 == t1.key_21 & t0.line_number == t1.line_number & t0.order_key == t1.order_key & t0.key_5 == t1.key_17 & t0.key_2 == t1.key_14 & t0.key == t1.key], types=['left'], columns={'account_balance': t0.account_balance, 'domestic': t1.domestic}) JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'account_balance': t1.account_balance, 'key': t0.key, 'key_2': t0.key_2, 'key_5': t0.key_5, 'key_9': t1.key, 'line_number': t0.line_number, 'order_key': t0.order_key}) JOIN(conditions=[t0.key_5 == t1.order_key], types=['inner'], columns={'key': t0.key, 'key_2': t0.key_2, 'key_5': t0.key_5, 'line_number': t1.line_number, 'order_key': t1.order_key, 'supplier_key': t1.supplier_key}) FILTER(condition=YEAR(order_date) == 1998:int64 & MONTH(order_date) == 6:int64, columns={'key': key, 'key_2': key_2, 'key_5': key_5}) diff --git a/tests/test_plan_refsols/correl_21.txt b/tests/test_plan_refsols/correl_21.txt new file mode 100644 index 00000000..24c781d1 --- /dev/null +++ b/tests/test_plan_refsols/correl_21.txt @@ -0,0 +1,13 @@ +ROOT(columns=[('n_sizes', n_sizes)], orderings=[]) + PROJECT(columns={'n_sizes': agg_1}) + AGGREGATE(keys={}, aggregations={'agg_1': COUNT()}) + FILTER(condition=n_parts > avg_n_parts, columns={'agg_0': agg_0}) + PROJECT(columns={'agg_0': agg_0, 'avg_n_parts': avg_n_parts, 'n_parts': DEFAULT_TO(agg_0, 0:int64)}) + JOIN(conditions=[True:bool], types=['left'], columns={'agg_0': t1.agg_0, 'avg_n_parts': t0.avg_n_parts}) + PROJECT(columns={'avg_n_parts': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(n_parts)}) + PROJECT(columns={'n_parts': DEFAULT_TO(agg_0, 0:int64)}) + AGGREGATE(keys={'size': size}, aggregations={'agg_0': COUNT()}) + 
SCAN(table=tpch.PART, columns={'size': p_size}) + AGGREGATE(keys={'size': size}, aggregations={'agg_0': COUNT()}) + SCAN(table=tpch.PART, columns={'size': p_size}) diff --git a/tests/test_plan_refsols/correl_22.txt b/tests/test_plan_refsols/correl_22.txt new file mode 100644 index 00000000..4c5be854 --- /dev/null +++ b/tests/test_plan_refsols/correl_22.txt @@ -0,0 +1,13 @@ +ROOT(columns=[('container', container), ('n_types', n_types)], orderings=[(ordering_2):desc_last, (ordering_3):asc_first]) + LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'container': container, 'n_types': n_types, 'ordering_2': ordering_2, 'ordering_3': ordering_3}, orderings=[(ordering_2):desc_last, (ordering_3):asc_first]) + PROJECT(columns={'container': container, 'n_types': n_types, 'ordering_2': n_types, 'ordering_3': container}) + PROJECT(columns={'container': container, 'n_types': DEFAULT_TO(agg_1, 0:int64)}) + AGGREGATE(keys={'container': container}, aggregations={'agg_1': COUNT()}) + FILTER(condition=avg_price > global_avg_price, columns={'container': container}) + PROJECT(columns={'avg_price': agg_0, 'container': container, 'global_avg_price': global_avg_price}) + JOIN(conditions=[True:bool], types=['left'], columns={'agg_0': t1.agg_0, 'container': t1.container, 'global_avg_price': t0.global_avg_price}) + PROJECT(columns={'global_avg_price': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'retail_price': p_retailprice}) + AGGREGATE(keys={'container': container, 'part_type': part_type}, aggregations={'agg_0': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'container': p_container, 'part_type': p_type, 'retail_price': p_retailprice}) diff --git a/tests/test_plan_refsols/correl_23.txt b/tests/test_plan_refsols/correl_23.txt new file mode 100644 index 00000000..5bbadbc9 --- /dev/null +++ b/tests/test_plan_refsols/correl_23.txt @@ -0,0 +1,15 @@ +ROOT(columns=[('n_sizes', n_sizes)], orderings=[]) + 
PROJECT(columns={'n_sizes': agg_1}) + AGGREGATE(keys={}, aggregations={'agg_1': COUNT()}) + FILTER(condition=n_combos > avg_n_combo, columns={'agg_0': agg_0}) + PROJECT(columns={'agg_0': agg_0, 'avg_n_combo': avg_n_combo, 'n_combos': DEFAULT_TO(agg_0, 0:int64)}) + JOIN(conditions=[True:bool], types=['left'], columns={'agg_0': t1.agg_0, 'avg_n_combo': t0.avg_n_combo}) + PROJECT(columns={'avg_n_combo': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(n_combos)}) + PROJECT(columns={'n_combos': DEFAULT_TO(agg_0, 0:int64)}) + AGGREGATE(keys={'size': size}, aggregations={'agg_0': COUNT()}) + AGGREGATE(keys={'container': container, 'part_type': part_type, 'size': size}, aggregations={}) + SCAN(table=tpch.PART, columns={'container': p_container, 'part_type': p_type, 'size': p_size}) + AGGREGATE(keys={'size': size}, aggregations={'agg_0': COUNT()}) + AGGREGATE(keys={'container': container, 'part_type': part_type, 'size': size}, aggregations={}) + SCAN(table=tpch.PART, columns={'container': p_container, 'part_type': p_type, 'size': p_size}) diff --git a/tests/test_plan_refsols/join_regions_nations_calc_override.txt b/tests/test_plan_refsols/join_regions_nations_calc_override.txt index 97156430..e1cce36c 100644 --- a/tests/test_plan_refsols/join_regions_nations_calc_override.txt +++ b/tests/test_plan_refsols/join_regions_nations_calc_override.txt @@ -1,6 +1,6 @@ -ROOT(columns=[('key', key_0_9), ('name', name_11), ('phone', phone), ('mktsegment', mktsegment)], orderings=[]) - PROJECT(columns={'key_0_9': key_0_9, 'mktsegment': mktsegment, 'name_11': name_10, 'phone': phone}) - PROJECT(columns={'key_0_9': -3:int64, 'mktsegment': mktsegment, 'name_10': name_7, 'phone': phone}) +ROOT(columns=[('key', key_0_10), ('name', name_11), ('phone', phone), ('mktsegment', mktsegment)], orderings=[]) + PROJECT(columns={'key_0_10': key_0_10, 'mktsegment': mktsegment, 'name_11': name_9, 'phone': phone}) + PROJECT(columns={'key_0_10': -3:int64, 'mktsegment': mktsegment, 'name_9': name_7, 
'phone': phone}) JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'mktsegment': t1.mktsegment, 'name_7': t1.name, 'phone': t1.phone}) JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': t1.key}) SCAN(table=tpch.REGION, columns={'key': r_regionkey}) diff --git a/tests/test_plan_refsols/tpch_q22.txt b/tests/test_plan_refsols/tpch_q22.txt index 1bd521eb..01b9f860 100644 --- a/tests/test_plan_refsols/tpch_q22.txt +++ b/tests/test_plan_refsols/tpch_q22.txt @@ -1,18 +1,18 @@ -ROOT(columns=[('CNTRY_CODE', CNTRY_CODE), ('NUM_CUSTS', NUM_CUSTS), ('TOTACCTBAL', TOTACCTBAL)], orderings=[]) - PROJECT(columns={'CNTRY_CODE': cntry_code, 'NUM_CUSTS': DEFAULT_TO(agg_1, 0:int64), 'TOTACCTBAL': DEFAULT_TO(agg_2, 0:int64)}) - JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1, 'agg_2': t1.agg_2, 'cntry_code': t1.cntry_code}, correl_name='corr1') - PROJECT(columns={'avg_balance': agg_0}) - AGGREGATE(keys={}, aggregations={'agg_0': AVG(acctbal)}) - FILTER(condition=acctbal > 0.0:float64, columns={'acctbal': acctbal}) - FILTER(condition=ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17']:array[unknown]) & True:bool, columns={'acctbal': acctbal}) - JOIN(conditions=[t0.key == t1.customer_key], types=['anti'], columns={'acctbal': t0.acctbal, 'cntry_code': t0.cntry_code}) - PROJECT(columns={'acctbal': acctbal, 'cntry_code': SLICE(phone, None:unknown, 2:int64, None:unknown), 'key': key}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'key': c_custkey, 'phone': c_phone}) - SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey}) +ROOT(columns=[('CNTRY_CODE', CNTRY_CODE), ('NUM_CUSTS', NUM_CUSTS), ('TOTACCTBAL', TOTACCTBAL)], orderings=[(ordering_3):asc_first]) + PROJECT(columns={'CNTRY_CODE': CNTRY_CODE, 'NUM_CUSTS': NUM_CUSTS, 'TOTACCTBAL': TOTACCTBAL, 'ordering_3': CNTRY_CODE}) + PROJECT(columns={'CNTRY_CODE': cntry_code, 'NUM_CUSTS': DEFAULT_TO(agg_1, 0:int64), 'TOTACCTBAL': DEFAULT_TO(agg_2, 
0:int64)}) AGGREGATE(keys={'cntry_code': cntry_code}, aggregations={'agg_1': COUNT(), 'agg_2': SUM(acctbal)}) - FILTER(condition=acctbal > corr1.avg_balance, columns={'acctbal': acctbal, 'cntry_code': cntry_code}) - FILTER(condition=ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17']:array[unknown]) & True:bool, columns={'acctbal': acctbal, 'cntry_code': cntry_code}) - JOIN(conditions=[t0.key == t1.customer_key], types=['anti'], columns={'acctbal': t0.acctbal, 'cntry_code': t0.cntry_code}) - PROJECT(columns={'acctbal': acctbal, 'cntry_code': SLICE(phone, None:unknown, 2:int64, None:unknown), 'key': key}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'key': c_custkey, 'phone': c_phone}) + FILTER(condition=acctbal > avg_balance & DEFAULT_TO(agg_0, 0:int64) == 0:int64, columns={'acctbal': acctbal, 'cntry_code': cntry_code}) + JOIN(conditions=[t0.key == t1.customer_key], types=['left'], columns={'acctbal': t0.acctbal, 'agg_0': t1.agg_0, 'avg_balance': t0.avg_balance, 'cntry_code': t0.cntry_code}) + FILTER(condition=ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17']:array[unknown]), columns={'acctbal': acctbal, 'avg_balance': avg_balance, 'cntry_code': cntry_code, 'key': key}) + PROJECT(columns={'acctbal': acctbal, 'avg_balance': avg_balance, 'cntry_code': SLICE(phone, None:unknown, 2:int64, None:unknown), 'key': key}) + JOIN(conditions=[True:bool], types=['inner'], columns={'acctbal': t1.acctbal, 'avg_balance': t0.avg_balance, 'key': t1.key, 'phone': t1.phone}) + PROJECT(columns={'avg_balance': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(acctbal)}) + FILTER(condition=acctbal > 0.0:float64, columns={'acctbal': acctbal}) + FILTER(condition=ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17']:array[unknown]), columns={'acctbal': acctbal}) + PROJECT(columns={'acctbal': acctbal, 'cntry_code': SLICE(phone, None:unknown, 2:int64, None:unknown)}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'phone': c_phone}) + 
SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'key': c_custkey, 'phone': c_phone}) + AGGREGATE(keys={'customer_key': customer_key}, aggregations={'agg_0': COUNT()}) SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey}) diff --git a/tests/test_qualification.py b/tests/test_qualification.py index 6e25259d..eab7c54b 100644 --- a/tests/test_qualification.py +++ b/tests/test_qualification.py @@ -68,6 +68,26 @@ def pydough_impl_misc_02(root: UnqualifiedNode) -> UnqualifiedNode: ) +def pydough_impl_misc_03(root: UnqualifiedNode) -> UnqualifiedNode: + """ + Creates an UnqualifiedNode for the following PyDough snippet: + ``` + sizes = PARTITION(Parts, name="p", by=size)(n_parts=COUNT(p)) + TPCH( + avg_n_parts=AVG(sizes.n_parts) + )( + n_parts=COUNT(sizes.WHERE(n_parts > BACK(1).avg_n_parts)) + ) + ``` + """ + sizes = root.PARTITION(root.Parts, name="p", by=root.size)( + n_parts=root.COUNT(root.p) + ) + return root.TPCH(avg_n_parts=root.AVG(sizes.n_parts))( + n_parts=root.COUNT(sizes.WHERE(root.n_parts > root.BACK(1).avg_n_parts)) + ) + + def pydough_impl_tpch_q1(root: UnqualifiedNode) -> UnqualifiedNode: """ Creates an UnqualifiedNode for TPC-H query 1. 
@@ -612,6 +632,30 @@ def pydough_impl_tpch_q22(root: UnqualifiedNode) -> UnqualifiedNode: """, id="misc_02", ), + pytest.param( + pydough_impl_misc_03, + """ +┌─── TPCH +├─┬─ Calc[avg_n_parts=AVG($1.n_parts)] +│ └─┬─ AccessChild +│ ├─┬─ Partition[name='p', by=size] +│ │ └─┬─ AccessChild +│ │ └─── TableCollection[Parts] +│ └─┬─ Calc[n_parts=COUNT($1)] +│ └─┬─ AccessChild +│ └─── PartitionChild[p] +└─┬─ Calc[n_parts=COUNT($1)] + └─┬─ AccessChild + ├─┬─ Partition[name='p', by=size] + │ └─┬─ AccessChild + │ └─── TableCollection[Parts] + ├─┬─ Calc[n_parts=COUNT($1)] + │ └─┬─ AccessChild + │ └─── PartitionChild[p] + └─── Where[n_parts > BACK(1).avg_n_parts] +""", + id="misc_03", + ), pytest.param( pydough_impl_tpch_q1, """ diff --git a/tests/test_relational_nodes_to_sqlglot.py b/tests/test_relational_nodes_to_sqlglot.py index 7859c88f..6c3c04cd 100644 --- a/tests/test_relational_nodes_to_sqlglot.py +++ b/tests/test_relational_nodes_to_sqlglot.py @@ -152,6 +152,8 @@ def mkglot(expressions: list[Expression], _from: Expression, **kwargs) -> Select query = query.where(kwargs.pop("where")) if "group_by" in kwargs: query = query.group_by(*kwargs.pop("group_by")) + if kwargs.pop("distinct", False): + query = query.distinct() if "qualify" in kwargs: query = query.qualify(kwargs.pop("qualify")) if "order_by" in kwargs: @@ -627,14 +629,15 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: Aggregate( input=build_simple_scan(), keys={ + "a": make_relational_column_reference("a"), "b": make_relational_column_reference("b"), }, aggregations={}, ), mkglot( - expressions=[Ident(this="b")], + expressions=[Ident(this="a"), Ident(this="b")], _from=GlotFrom(Table(this=Ident(this="table"))), - group_by=[Ident(this="b")], + distinct=True, ), id="simple_distinct", ), @@ -718,7 +721,7 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: mkglot( expressions=[Ident(this="b")], where=mkglot_func(EQ, [Ident(this="a"), mk_literal(1, False)]), - 
group_by=[Ident(this="b")], + distinct=True, _from=GlotFrom( mkglot( expressions=[Ident(this="a"), Ident(this="b")], @@ -794,7 +797,7 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: ), mkglot( expressions=[Ident(this="b")], - group_by=[Ident(this="b")], + distinct=True, _from=GlotFrom( mkglot( expressions=[Ident(this="a"), Ident(this="b")], @@ -829,7 +832,7 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: mkglot( expressions=[Ident(this="b")], _from=GlotFrom(Table(this=Ident(this="table"))), - group_by=[Ident(this="b")], + distinct=True, order_by=[Ident(this="b").desc(nulls_first=False)], limit=mk_literal(10, False), ), @@ -865,8 +868,8 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: _from=GlotFrom( mkglot( expressions=[Ident(this="b")], - group_by=[Ident(this="b")], _from=GlotFrom(Table(this=Ident(this="table"))), + distinct=True, ) ), ), diff --git a/tests/test_sql_refsols/simple_distinct.sql b/tests/test_sql_refsols/simple_distinct.sql index 696d95c0..38ad7566 100644 --- a/tests/test_sql_refsols/simple_distinct.sql +++ b/tests/test_sql_refsols/simple_distinct.sql @@ -1,5 +1,3 @@ -SELECT +SELECT DISTINCT b FROM table -GROUP BY - b diff --git a/tests/test_unqualified_node.py b/tests/test_unqualified_node.py index 386266b2..34c7921c 100644 --- a/tests/test_unqualified_node.py +++ b/tests/test_unqualified_node.py @@ -418,7 +418,7 @@ def test_unqualified_to_string( ), pytest.param( impl_tpch_q22, - "?.TPCH(avg_balance=AVG(?.Customers(cntry_code=SLICE(?.phone, None, 2, None)).WHERE((ISIN(?.cntry_code, ['13', '31', '23', '29', '30', '18', '17']) & HASNOT(?.orders))).WHERE((?.acctbal > 0.0)).acctbal)).PARTITION(?.Customers(cntry_code=SLICE(?.phone, None, 2, None)).WHERE((ISIN(?.cntry_code, ['13', '31', '23', '29', '30', '18', '17']) & HASNOT(?.orders))).WHERE((?.acctbal > BACK(1).avg_balance)), name='custs', by=(?.cntry_code))(CNTRY_CODE=?.cntry_code, NUM_CUSTS=COUNT(?.custs), 
TOTACCTBAL=SUM(?.custs.acctbal))", + "?.TPCH(avg_balance=AVG(?.Customers(cntry_code=SLICE(?.phone, None, 2, None)).WHERE(ISIN(?.cntry_code, ['13', '31', '23', '29', '30', '18', '17'])).WHERE((?.acctbal > 0.0)).acctbal)).PARTITION(?.Customers(cntry_code=SLICE(?.phone, None, 2, None)).WHERE(ISIN(?.cntry_code, ['13', '31', '23', '29', '30', '18', '17'])).WHERE(((?.acctbal > BACK(1).avg_balance) & (COUNT(?.orders) == 0))), name='custs', by=(?.cntry_code))(CNTRY_CODE=?.cntry_code, NUM_CUSTS=COUNT(?.custs), TOTACCTBAL=SUM(?.custs.acctbal)).ORDER_BY(?.CNTRY_CODE.ASC(na_pos='first'))", id="tpch_q22", ), pytest.param( diff --git a/tests/tpch_outputs.py b/tests/tpch_outputs.py index 80948acb..82fbe5cb 100644 --- a/tests/tpch_outputs.py +++ b/tests/tpch_outputs.py @@ -715,7 +715,7 @@ def tpch_q22_output() -> pd.DataFrame: This query needs manual rewriting to run efficiently in SQLite by avoiding the correlated join. """ - columns = ["CNTRYCODE", "NUMCUST", "TOTACCTBAL"] + columns = ["CNTRY_CODE", "NUM_CUSTS", "TOTACCTBAL"] data = [ ("13", 888, 6737713.99), ("17", 861, 6460573.72), diff --git a/tests/tpch_test_functions.py b/tests/tpch_test_functions.py index 21cfdaa8..c1df2cf8 100644 --- a/tests/tpch_test_functions.py +++ b/tests/tpch_test_functions.py @@ -494,16 +494,20 @@ def impl_tpch_q22(): PyDough implementation of TPCH Q22. 
""" selected_customers = Customers(cntry_code=phone[:2]).WHERE( - ISIN(cntry_code, ("13", "31", "23", "29", "30", "18", "17")) & HASNOT(orders) + ISIN(cntry_code, ("13", "31", "23", "29", "30", "18", "17")) ) - return TPCH( - avg_balance=AVG(selected_customers.WHERE(acctbal > 0.0).acctbal) - ).PARTITION( - selected_customers.WHERE(acctbal > BACK(1).avg_balance), - name="custs", - by=cntry_code, - )( - CNTRY_CODE=cntry_code, - NUM_CUSTS=COUNT(custs), - TOTACCTBAL=SUM(custs.acctbal), + return ( + TPCH(avg_balance=AVG(selected_customers.WHERE(acctbal > 0.0).acctbal)) + .PARTITION( + selected_customers.WHERE( + (acctbal > BACK(1).avg_balance) & (COUNT(orders) == 0) + ), + name="custs", + by=cntry_code, + )( + CNTRY_CODE=cntry_code, + NUM_CUSTS=COUNT(custs), + TOTACCTBAL=SUM(custs.acctbal), + ) + .ORDER_BY(CNTRY_CODE.ASC()) ) From 6a932b6bd812ddc221a0c0d7765c4449a9323b8e Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Tue, 18 Feb 2025 22:00:45 -0500 Subject: [PATCH 6/7] Fix bug with accessing multiple chained PARTITION children (#272) **Goal**: fix the bug causing PyDough such as the following to fail: ```py data = Tickers(symbol).TOP_K(5, by=symbol.ASC()) grps_a = PARTITION(data, name="child_3", by=(currency, exchange, ticker_type)) grps_b = PARTITION(grps_a, name="child_2", by=(currency, exchange)) grps_c = PARTITION(grps_b, name="child_1", by=exchange) result = grps_c.child_1.child_2.child_3 ``` The error is an `IndexError: list index out of range` in the following section of hybrid conversion: ```py case PartitionChild(): hybrid = self.make_hybrid_tree(node.ancestor_context, parent) successor_hybrid = HybridTree( HybridPartitionChild(hybrid.children[0].subtree) ) ... ``` Reason: if `node.ancestor_context` is **also** a partition child, the hybrid child subtree handling/invocation is not being done correctly. 
The correct way is to recursively step into `hybrid.pipeline[0].subtree` until `hybrid.pipeline[0]` is no longer a partition child. Once this happens, we can access `children[0].subtree`, because we have then found the original data being partitioned. --- pydough/conversion/hybrid_tree.py | 11 ++++++--- tests/simple_pydough_functions.py | 38 +++++++++++++++++++++++++++++++ tests/test_pipeline.py | 36 +++++++++++++++++++++++++++++ 3 files changed, 82 insertions(+), 3 deletions(-) diff --git a/pydough/conversion/hybrid_tree.py b/pydough/conversion/hybrid_tree.py index 16b0bca5..8e7d4a49 100644 --- a/pydough/conversion/hybrid_tree.py +++ b/pydough/conversion/hybrid_tree.py @@ -1880,9 +1880,14 @@ def make_hybrid_tree( return successor_hybrid case PartitionChild(): hybrid = self.make_hybrid_tree(node.ancestor_context, parent) - successor_hybrid = HybridTree( - HybridPartitionChild(hybrid.children[0].subtree) - ) + # Identify the original data being partitioned, which may + # require stepping in multiple times if the partition is + # nested inside another partition. + src_tree: HybridTree = hybrid + while isinstance(src_tree.pipeline[0], HybridPartitionChild): + src_tree = src_tree.pipeline[0].subtree + subtree: HybridTree = src_tree.children[0].subtree + successor_hybrid = HybridTree(HybridPartitionChild(subtree)) hybrid.add_successor(successor_hybrid) return successor_hybrid case Calc(): diff --git a/tests/simple_pydough_functions.py b/tests/simple_pydough_functions.py index df322958..e80d9249 100644 --- a/tests/simple_pydough_functions.py +++ b/tests/simple_pydough_functions.py @@ -243,6 +243,44 @@ def agg_partition(): return TPCH(best_year=MAX(yearly_data.n_orders)) +def multi_partition_access_1(): + # A use of multiple PARTITION and stepping into partition children that is + # a no-op. 
+ data = Tickers(symbol).TOP_K(5, by=symbol.ASC()) + grps_a = PARTITION(data, name="child_3", by=(currency, exchange, ticker_type)) + grps_b = PARTITION(grps_a, name="child_2", by=(currency, exchange)) + grps_c = PARTITION(grps_b, name="child_1", by=exchange) + return grps_c.child_1.child_2.child_3 + + +def multi_partition_access_2(): + # Identify transactions that are below the average number of shares for + # transactions of the same combinations of (customer, stock, type), or + # the same combination of (customer, stock), or the same customer. + grps_a = PARTITION( + Transactions, name="child_3", by=(customer_id, ticker_id, transaction_type) + )(avg_shares_a=AVG(child_3.shares)) + grps_b = PARTITION(grps_a, name="child_2", by=(customer_id, ticker_id))( + avg_shares_b=AVG(child_2.child_3.shares) + ) + grps_c = PARTITION(grps_b, name="child_1", by=customer_id)( + avg_shares_c=AVG(child_1.child_2.child_3.shares) + ) + return grps_c.child_1.child_2.child_3.WHERE( + (shares < BACK(1).avg_shares_a) + & (shares < BACK(2).avg_shares_b) + & (shares < BACK(3).avg_shares_c) + )( + transaction_id, + customer.name, + ticker.symbol, + transaction_type, + BACK(1).avg_shares_a, + BACK(2).avg_shares_b, + BACK(3).avg_shares_c, + ).ORDER_BY(transaction_id.ASC()) + + def double_partition(): # Doing a partition aggregation on the output of a partition aggregation year_month_data = PARTITION( diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index f862efa2..6faf383f 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -44,6 +44,8 @@ function_sampler, hour_minute_day, minutes_seconds_datediff, + multi_partition_access_1, + multi_partition_access_2, percentile_customers_per_region, percentile_nations, rank_nations_by_region, @@ -1142,6 +1144,40 @@ def test_pipeline_e2e_errors( @pytest.fixture( params=[ + pytest.param( + ( + multi_partition_access_1, + "Broker", + lambda: pd.DataFrame( + {"symbol": ["AAPL", "AMZN", "BRK.B", "FB", "GOOG"]} + ), + ), + 
id="multi_partition_access_1", + ), + pytest.param( + ( + multi_partition_access_2, + "Broker", + lambda: pd.DataFrame( + { + "transaction_id": [f"TX{i:03}" for i in (22, 24, 25, 27, 56)], + "name": [ + "Jane Smith", + "Samantha Lee", + "Michael Chen", + "David Kim", + "Jane Smith", + ], + "symbol": ["MSFT", "TSLA", "GOOGL", "BRK.B", "FB"], + "transaction_type": ["sell", "sell", "buy", "buy", "sell"], + "avg_shares_a": [56.66667, 55.0, 4.0, 55.5, 47.5], + "avg_shares_b": [50.0, 41.66667, 3.33333, 37.33333, 47.5], + "avg_shares_c": [50.625, 46.25, 40.0, 37.33333, 50.625], + } + ), + ), + id="multi_partition_access_2", + ), pytest.param( ( hour_minute_day, From 269fe28395206cd6ec3e7504ad9028aa0939f86d Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Tue, 18 Feb 2025 22:02:11 -0500 Subject: [PATCH 7/7] [RUN CI]