diff --git a/pydough/conversion/hybrid_decorrelater.py b/pydough/conversion/hybrid_decorrelater.py new file mode 100644 index 00000000..a1c36407 --- /dev/null +++ b/pydough/conversion/hybrid_decorrelater.py @@ -0,0 +1,365 @@ +""" +Logic for applying de-correlation to hybrid trees before relational conversion +if the correlate is not a semi/anti join. +""" + +__all__ = ["run_hybrid_decorrelation"] + + +import copy + +from .hybrid_tree import ( + ConnectionType, + HybridBackRefExpr, + HybridCalc, + HybridChildRefExpr, + HybridColumnExpr, + HybridConnection, + HybridCorrelExpr, + HybridExpr, + HybridFilter, + HybridFunctionExpr, + HybridLiteralExpr, + HybridPartition, + HybridRefExpr, + HybridTree, + HybridWindowExpr, +) + + +class Decorrelater: + """ + Class that encapsulates the logic used for de-correlation of hybrid trees. + """ + + def make_decorrelate_parent( + self, hybrid: HybridTree, child_idx: int, required_steps: int + ) -> tuple[HybridTree, int]: + """ + Creates a snapshot of the ancestry of the hybrid tree that contains + a correlated child, without any of its children, its descendants, or + any pipeline operators that do not need to be there. + + Args: + `hybrid`: The hybrid tree to create a snapshot of in order to aid + in the de-correlation of a correlated child. + `child_idx`: The index of the correlated child of hybrid that the + snapshot is being created to aid in the de-correlation of. + `required_steps`: The index of the last pipeline operator that + needs to be included in the snapshot in order for the child to be + derivable. + + Returns: + A tuple where the first entry is a snapshot of `hybrid` and its + ancestry in the hybrid tree, without any of its children or + pipeline operators that occur during or after the derivation of the + correlated child, or without any of its descendants. The second + entry is the number of ancestor layers that should be skipped due + to the PARTITION edge case. 
+ """ + if isinstance(hybrid.pipeline[0], HybridPartition) and child_idx == 0: + # Special case: if the correlated child is the data argument of a + # partition operation, then the parent to snapshot is actually the + # parent of the level containing the partition operation. In this + # case, all of the parent's children & pipeline operators should be + # included in the snapshot. + if hybrid.parent is None: + raise ValueError( + "Malformed hybrid tree: partition data input to a partition node cannot contain a correlated reference to the partition node." + ) + result = self.make_decorrelate_parent( + hybrid.parent, len(hybrid.parent.children), len(hybrid.pipeline) + ) + return result[0], result[1] + 1 + # Temporarily detach the successor of the current level, then create a + # deep copy of the current level (which will include its ancestors), + # then reattach the successor back to the original. This ensures that + # the descendants of the current level are not included when providing + # the parent to the correlated child as its new ancestor. + successor: HybridTree | None = hybrid.successor + hybrid._successor = None + new_hybrid: HybridTree = copy.deepcopy(hybrid) + hybrid._successor = successor + # Ensure the new parent only includes the children & pipeline operators + # that is has to. + new_hybrid._children = new_hybrid._children[:child_idx] + new_hybrid._pipeline = new_hybrid._pipeline[: required_steps + 1] + return new_hybrid, 0 + + def remove_correl_refs( + self, expr: HybridExpr, parent: HybridTree, child_height: int + ) -> HybridExpr: + """ + Recursively & destructively removes correlated references within a + hybrid expression if they point to a specific correlated ancestor + hybrid tree, and replaces them with corresponding BACK references. + + Args: + `expr`: The hybrid expression to remove correlated references from. + `parent`: The correlated ancestor hybrid tree that the correlated + references should point to when they are targeted for removal. 
+ `child_height`: The height of the correlated child within the + hybrid tree that the correlated references point to. This is + the number of BACK indices to shift by when replacing the + correlated reference with a BACK reference. + + Returns: + The hybrid expression with all correlated references to `parent` + replaced with corresponding BACK references. The replacement also + happens in-place. + """ + match expr: + case HybridCorrelExpr(): + # If the correlated reference points to the parent, then + # replace it with a BACK reference. Otherwise, recursively + # transform its input expression in case it contains another + # correlated reference. + if expr.hybrid is parent: + result: HybridExpr | None = expr.expr.shift_back(child_height) + assert result is not None + return result + else: + expr.expr = self.remove_correl_refs(expr.expr, parent, child_height) + return expr + case HybridFunctionExpr(): + # For regular functions, recursively transform all of their + # arguments. + for idx, arg in enumerate(expr.args): + expr.args[idx] = self.remove_correl_refs(arg, parent, child_height) + return expr + case HybridWindowExpr(): + # For window functions, recursively transform all of their + # arguments, partition keys, and order keys. + for idx, arg in enumerate(expr.args): + expr.args[idx] = self.remove_correl_refs(arg, parent, child_height) + for idx, arg in enumerate(expr.partition_args): + expr.partition_args[idx] = self.remove_correl_refs( + arg, parent, child_height + ) + for order_arg in expr.order_args: + order_arg.expr = self.remove_correl_refs( + order_arg.expr, parent, child_height + ) + return expr + case ( + HybridBackRefExpr() + | HybridRefExpr() + | HybridChildRefExpr() + | HybridLiteralExpr() + | HybridColumnExpr() + ): + # All other expression types do not require any transformation + # to de-correlate since they cannot contain correlations. 
+ return expr + case _: + raise NotImplementedError( + f"Unsupported expression type: {expr.__class__.__name__}." + ) + + def correl_ref_purge( + self, + level: HybridTree | None, + old_parent: HybridTree, + new_parent: HybridTree, + child_height: int, + ) -> None: + """ + The recursive procedure to remove correlated references from the + expressions of a hybrid tree or any of its ancestors or children if + they refer to a specific correlated ancestor that is being removed. + + Args: + `level`: The current level of the hybrid tree to remove correlated + references from. + `old_parent`: The correlated ancestor hybrid tree that the correlated + references should point to when they are targeted for removal. + `new_parent`: The ancestor of `level` that removal should stop at + because it is the transposed snapshot of `old_parent`, and + therefore it & its ancestors cannot contain any more correlated + references that would be targeted for removal. + `child_height`: The height of the correlated child within the + hybrid tree that the correlated references point to. This is + the number of BACK indices to shift by when replacing the + correlated reference with a BACK reference. + """ + while level is not None and level is not new_parent: + # First, recursively remove any targeted correlated references from + # the children of the current level. + for child in level.children: + self.correl_ref_purge( + child.subtree, old_parent, new_parent, child_height + ) + # Then, remove any correlated references from the pipeline + # operators of the current level. Usually this just means + # transforming the terms/orderings/unique keys of the operation, + # but specific operation types will require special casing if they + # have additional expressions stored in other fields that need to be + # transformed. 
+ for operation in level.pipeline: + for name, expr in operation.terms.items(): + operation.terms[name] = self.remove_correl_refs( + expr, old_parent, child_height + ) + for ordering in operation.orderings: + ordering.expr = self.remove_correl_refs( + ordering.expr, old_parent, child_height + ) + for idx, expr in enumerate(operation.unique_exprs): + operation.unique_exprs[idx] = self.remove_correl_refs( + expr, old_parent, child_height + ) + if isinstance(operation, HybridCalc): + for str, expr in operation.new_expressions.items(): + operation.new_expressions[str] = self.remove_correl_refs( + expr, old_parent, child_height + ) + if isinstance(operation, HybridFilter): + operation.condition = self.remove_correl_refs( + operation.condition, old_parent, child_height + ) + # Repeat the process on the ancestor until either loop guard + # condition is no longer True. + level = level.parent + + def decorrelate_child( + self, + old_parent: HybridTree, + new_parent: HybridTree, + child: HybridConnection, + is_aggregate: bool, + skipped_levels: int, + ) -> None: + """ + Runs the logic to de-correlate a child of a hybrid tree that contains + a correlated reference. This involves linking the child to a new parent + as its ancestor, the parent being a snapshot of the original hybrid + tree that contained the correlated child as a child. The transformed + child can now replace correlated references with BACK references that + point to terms in its newly expanded ancestry, and the original hybrid + tree can now join onto this child using its uniqueness keys. + + Args: + `old_parent`: The correlated ancestor hybrid tree that the correlated + references should point to when they are targeted for removal. + `new_parent`: The ancestor of `level` that removal should stop at. + `child`: The child of the hybrid tree that contains the correlated + nodes to be removed. + `is_aggregate`: Whether the child is being aggregated with regards + to its parent. 
+ `skipped_levels`: The number of ancestor layers that should be + ignored when deriving backshifts of join/agg keys. + """ + # First, find the height of the child subtree & its top-most level. + child_root: HybridTree = child.subtree + child_height: int = 1 + while child_root.parent is not None: + child_height += 1 + child_root = child_root.parent + # Link the top level of the child subtree to the new parent. + new_parent.add_successor(child_root) + # Replace any correlated references to the original parent with BACK references. + self.correl_ref_purge(child.subtree, old_parent, new_parent, child_height) + # Update the join keys to join on the unique keys of all the ancestors. + new_join_keys: list[tuple[HybridExpr, HybridExpr]] = [] + additional_levels: int = 0 + current_level: HybridTree | None = old_parent + new_agg_keys: list[HybridExpr] = [] + while current_level is not None: + skip_join: bool = ( + isinstance(current_level.pipeline[0], HybridPartition) + and child is current_level.children[0] + ) + for unique_key in sorted(current_level.pipeline[0].unique_exprs, key=str): + lhs_key: HybridExpr | None = unique_key.shift_back(additional_levels) + rhs_key: HybridExpr | None = unique_key.shift_back( + additional_levels + child_height - skipped_levels + ) + assert lhs_key is not None and rhs_key is not None + if not skip_join: + new_join_keys.append((lhs_key, rhs_key)) + new_agg_keys.append(rhs_key) + current_level = current_level.parent + additional_levels += 1 + child.subtree.join_keys = new_join_keys + # If aggregating, update the aggregation keys accordingly. + if is_aggregate: + child.subtree.agg_keys = new_agg_keys + + def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: + """ + The recursive procedure to remove unwanted correlated references from + the entire hybrid tree, called from the bottom and working upwards + to the top layer, and having each layer also de-correlate its children. 
+ + Args: + `hybrid`: The hybrid tree to remove correlated references from. + + Returns: + The hybrid tree with all invalid correlated references removed as the + tree structure is re-written to allow them to be replaced with BACK + references. The transformation is also done in-place. + """ + # Recursively decorrelate the ancestors of the current level of the + # hybrid tree. + if hybrid.parent is not None: + hybrid._parent = self.decorrelate_hybrid_tree(hybrid.parent) + hybrid._parent._successor = hybrid + # Iterate across all the children and recursively decorrelate them. + for child in hybrid.children: + child.subtree = self.decorrelate_hybrid_tree(child.subtree) + # Iterate across all the children, identify any that are correlated, + # and transform any of the correlated ones that require decorrelation + # due to the type of connection. + for idx, child in enumerate(hybrid.children): + if idx not in hybrid.correlated_children: + continue + match child.connection_type: + case ( + ConnectionType.SINGULAR + | ConnectionType.SINGULAR_ONLY_MATCH + | ConnectionType.AGGREGATION + | ConnectionType.AGGREGATION_ONLY_MATCH + ): + new_parent, skipped_levels = self.make_decorrelate_parent( + hybrid, idx, hybrid.children[idx].required_steps + ) + self.decorrelate_child( + hybrid, + new_parent, + child, + child.connection_type.is_aggregation, + skipped_levels, + ) + case ConnectionType.NDISTINCT | ConnectionType.NDISTINCT_ONLY_MATCH: + raise NotImplementedError( + f"PyDough does not yet support correlated references with the {child.connection_type.name} pattern." + ) + case ( + ConnectionType.SEMI + | ConnectionType.ANTI + | ConnectionType.NO_MATCH_SINGULAR + | ConnectionType.NO_MATCH_AGGREGATION + | ConnectionType.NO_MATCH_NDISTINCT + ): + # These patterns do not require decorrelation since they + # are supported via correlated SEMI/ANTI joins. 
+ continue + return hybrid + + +def run_hybrid_decorrelation(hybrid: HybridTree) -> HybridTree: + """ + Invokes the procedure to remove correlated references from a hybrid tree + before relational conversion if those correlated references are invalid + (e.g. not from a semi/anti join). + + Args: + `hybrid`: The hybrid tree to remove correlated references from. + + Returns: + The hybrid tree with all invalid correlated references removed as the + tree structure is re-written to allow them to be replaced with BACK + references. The transformation is also done in-place. + """ + decorr: Decorrelater = Decorrelater() + return decorr.decorrelate_hybrid_tree(hybrid) diff --git a/pydough/conversion/hybrid_tree.py b/pydough/conversion/hybrid_tree.py index 0a9fc64e..8e7d4a49 100644 --- a/pydough/conversion/hybrid_tree.py +++ b/pydough/conversion/hybrid_tree.py @@ -184,6 +184,8 @@ def apply_renamings(self, renamings: dict[str, str]) -> "HybridExpr": return self def shift_back(self, levels: int) -> HybridExpr | None: + if levels == 0: + return self return HybridBackRefExpr(self.name, levels, self.typ) @@ -229,6 +231,27 @@ def shift_back(self, levels: int) -> HybridExpr | None: return HybridBackRefExpr(self.name, self.back_idx + levels, self.typ) +class HybridCorrelExpr(HybridExpr): + """ + Class for HybridExpr terms that are expressions from a parent hybrid tree + rather than an ancestor, which requires a correlated reference. + """ + + def __init__(self, hybrid: "HybridTree", expr: HybridExpr): + super().__init__(expr.typ) + self.hybrid = hybrid + self.expr: HybridExpr = expr + + def __repr__(self): + return f"CORREL({self.expr})" + + def apply_renamings(self, renamings: dict[str, str]) -> "HybridExpr": + return self + + def shift_back(self, levels: int) -> HybridExpr | None: + return self + + class HybridLiteralExpr(HybridExpr): """ Class for HybridExpr terms that are literals. 
@@ -471,17 +494,26 @@ def __init__( for name, expr in predecessor.terms.items(): terms[name] = HybridRefExpr(name, expr.typ) renamings.update(predecessor.renamings) + new_renamings: dict[str, str] = {} for name, expr in new_expressions.items(): if name in terms and terms[name] == expr: continue expr = expr.apply_renamings(predecessor.renamings) used_name: str = name idx: int = 0 - while used_name in terms or used_name in renamings: + while ( + used_name in terms + or used_name in renamings + or used_name in new_renamings + ): used_name = f"{name}_{idx}" idx += 1 terms[used_name] = expr - renamings[name] = used_name + new_renamings[name] = used_name + renamings.update(new_renamings) + for old_name, new_name in new_renamings.items(): + expr = new_expressions.pop(old_name) + new_expressions[new_name] = expr super().__init__(terms, renamings, orderings, predecessor.unique_exprs) self.calc = Calc self.new_expressions = new_expressions @@ -497,7 +529,10 @@ class HybridFilter(HybridOperation): def __init__(self, predecessor: HybridOperation, condition: HybridExpr): super().__init__( - predecessor.terms, {}, predecessor.orderings, predecessor.unique_exprs + predecessor.terms, + predecessor.renamings, + predecessor.orderings, + predecessor.unique_exprs, ) self.predecessor: HybridOperation = predecessor self.condition: HybridExpr = condition @@ -543,7 +578,10 @@ def __init__( records_to_keep: int, ): super().__init__( - predecessor.terms, {}, predecessor.orderings, predecessor.unique_exprs + predecessor.terms, + predecessor.renamings, + predecessor.orderings, + predecessor.unique_exprs, ) self.predecessor: HybridOperation = predecessor self.records_to_keep: int = records_to_keep @@ -874,6 +912,7 @@ def __init__( self._is_connection_root: bool = is_connection_root self._agg_keys: list[HybridExpr] | None = None self._join_keys: list[tuple[HybridExpr, HybridExpr]] | None = None + self._correlated_children: set[int] = set() if isinstance(root_operation, HybridPartition): 
self._join_keys = [] @@ -884,13 +923,13 @@ def __repr__(self): lines.append(" -> ".join(repr(operation) for operation in self.pipeline)) prefix = " " if self.successor is None else "↓" for idx, child in enumerate(self.children): - lines.append(f"{prefix} child #{idx}:") + lines.append(f"{prefix} child #{idx} ({child.connection_type.name}):") if child.subtree.agg_keys is not None: - lines.append( - f"{prefix} aggregate: {child.subtree.agg_keys} -> {child.aggs}:" - ) + lines.append(f"{prefix} aggregate: {child.subtree.agg_keys}") + if len(child.aggs): + lines.append(f"{prefix} aggs: {child.aggs}:") if child.subtree.join_keys is not None: - lines.append(f"{prefix} join: {child.subtree.join_keys}:") + lines.append(f"{prefix} join: {child.subtree.join_keys}") for line in repr(child.subtree).splitlines(): lines.append(f"{prefix} {line}") return "\n".join(lines) @@ -914,6 +953,14 @@ def children(self) -> list[HybridConnection]: """ return self._children + @property + def correlated_children(self) -> set[int]: + """ + The set of indices of children that contain correlated references to + the current hybrid tree. + """ + return self._correlated_children + @property def successor(self) -> Optional["HybridTree"]: """ @@ -1046,6 +1093,10 @@ def __init__(self, configs: PyDoughConfigs): self.configs = configs # An index used for creating fake column names for aliases self.alias_counter: int = 0 + # A stack where each element is a hybrid tree being derived + # as as subtree of the previous element, and the current tree is + # being derived as the subtree of the last element. + self.stack: list[HybridTree] = [] @staticmethod def get_join_keys( @@ -1230,6 +1281,7 @@ def populate_children( accordingly so expressions using the child indices know what hybrid connection index to use. """ + self.stack.append(hybrid) for child_idx, child in enumerate(child_operator.children): # Build the hybrid tree for the child. 
Before doing so, reset the # alias counter to 0 to ensure that identical subtrees are named @@ -1269,108 +1321,7 @@ def populate_children( for con_typ in reference_types: connection_type = connection_type.reconcile_connection_types(con_typ) child_idx_mapping[child_idx] = hybrid.add_child(subtree, connection_type) - - def make_hybrid_agg_expr( - self, - hybrid: HybridTree, - expr: PyDoughExpressionQDAG, - child_ref_mapping: dict[int, int], - ) -> tuple[HybridExpr, int | None]: - """ - Converts a QDAG expression into a HybridExpr specifically with the - intent of making it the input to an aggregation call. Returns the - converted function argument, as well as an index indicating what child - subtree the aggregation's arguments belong to. NOTE: the HybridExpr is - phrased relative to the child subtree, rather than relative to `hybrid` - itself. - - Args: - `hybrid`: the hybrid tree that should be used to derive the - translation of `expr`, as it is the context in which the `expr` - will live. - `expr`: the QDAG expression to be converted. - `child_ref_mapping`: mapping of indices used by child references - in the original expressions to the index of the child hybrid tree - relative to the current level. - - Returns: - The HybridExpr node corresponding to `expr`, as well as the index - of the child it belongs to (e.g. which subtree does this - aggregation need to be done on top of). - """ - hybrid_result: HybridExpr - # This value starts out as None since we do not know the child index - # that `expr` correspond to yet. It may still be None at the end, since - # it is possible that `expr` does not correspond to any child index. - child_idx: int | None = None - match expr: - case PartitionKey(): - return self.make_hybrid_agg_expr(hybrid, expr.expr, child_ref_mapping) - case Literal(): - # Literals are kept as-is. 
- hybrid_result = HybridLiteralExpr(expr) - case ChildReferenceExpression(): - # Child references become regular references because the - # expression is phrased as if we were inside the child rather - # than the parent. - child_idx = child_ref_mapping[expr.child_idx] - child_connection = hybrid.children[child_idx] - expr_name = child_connection.subtree.pipeline[-1].renamings.get( - expr.term_name, expr.term_name - ) - hybrid_result = HybridRefExpr(expr_name, expr.pydough_type) - case ExpressionFunctionCall(): - if expr.operator.is_aggregation: - raise NotImplementedError( - "PyDough does not yet support calling aggregations inside of aggregations" - ) - # Every argument must be translated in the same manner as a - # regular function argument, except that the child index it - # corresponds to must be reconciled with the child index value - # accumulated so far. - args: list[HybridExpr] = [] - for arg in expr.args: - if not isinstance(arg, PyDoughExpressionQDAG): - raise NotImplementedError( - f"TODO: support converting {arg.__class__.__name__} as a function argument" - ) - hybrid_arg, hybrid_child_index = self.make_hybrid_agg_expr( - hybrid, arg, child_ref_mapping - ) - if hybrid_child_index is not None: - if child_idx is None: - # In this case, the argument is the first one seen that - # has an index, so that index is chosen. - child_idx = hybrid_child_index - elif hybrid_child_index != child_idx: - # In this case, multiple arguments correspond to - # different children, which cannot be handled yet - # because it means it is impossible to push the agg - # call into a single HybridConnection node. 
- raise NotImplementedError( - "Unsupported case: multiple child indices referenced by aggregation arguments" - ) - args.append(hybrid_arg) - hybrid_result = HybridFunctionExpr( - expr.operator, args, expr.pydough_type - ) - case BackReferenceExpression(): - raise NotImplementedError( - "PyDough does yet support aggregations whose arguments mix between subcollection data of the current context and fields of an ancestor of the current context" - ) - case Reference(): - raise NotImplementedError( - "PyDough does yet support aggregations whose arguments mix between subcollection data of the current context and fields of the context itself" - ) - case WindowCall(): - raise NotImplementedError( - "PyDough does yet support aggregations whose arguments mix between subcollection data of the current context and window functions" - ) - case _: - raise NotImplementedError( - f"TODO: support converting {expr.__class__.__name__} in aggregations" - ) - return hybrid_result, child_idx + self.stack.pop() def postprocess_agg_output( self, agg_call: HybridFunctionExpr, agg_ref: HybridExpr, joins_can_nullify: bool @@ -1533,11 +1484,180 @@ def handle_has_hasnot( # has / hasnot condition is now known to be true. return HybridLiteralExpr(Literal(True, BooleanType())) + def convert_agg_arg(self, expr: HybridExpr, child_indices: set[int]) -> HybridExpr: + """ + Translates a hybrid expression that is an argument to an aggregation + (or a subexpression of such an argument) into a form that is expressed + from the perspective of the child subtree that is being aggregated. + + Args: + `expr`: the expression to be converted. + `child_indices`: a set that is mutated to contain the indices of + any children that are referenced by `expr`. + + Returns: + The translated expression. + + Raises: + NotImplementedError if `expr` is an expression that cannot be used + inside of an aggregation call. 
+ """ + match expr: + case HybridLiteralExpr(): + return expr + case HybridChildRefExpr(): + # Child references become regular references because the + # expression is phrased as if we were inside the child rather + # than the parent. + child_indices.add(expr.child_idx) + return HybridRefExpr(expr.name, expr.typ) + case HybridFunctionExpr(): + return HybridFunctionExpr( + expr.operator, + [self.convert_agg_arg(arg, child_indices) for arg in expr.args], + expr.typ, + ) + case HybridBackRefExpr(): + raise NotImplementedError( + "PyDough does yet support aggregations whose arguments mix between subcollection data of the current context and fields of an ancestor of the current context" + ) + case HybridRefExpr(): + raise NotImplementedError( + "PyDough does yet support aggregations whose arguments mix between subcollection data of the current context and fields of the context itself" + ) + case HybridWindowExpr(): + raise NotImplementedError( + "PyDough does yet support aggregations whose arguments mix between subcollection data of the current context and window functions" + ) + case _: + raise NotImplementedError( + f"TODO: support converting {expr.__class__.__name__} in aggregations" + ) + + def make_agg_call( + self, + hybrid: HybridTree, + expr: ExpressionFunctionCall, + args: list[HybridExpr], + ) -> HybridExpr: + """ + For aggregate function calls, their arguments are translated in a + manner that identifies what child subtree they correspond too, by + index, and translates them relative to the subtree. Then, the + aggregation calls are placed into the `aggs` mapping of the + corresponding child connection, and the aggregation call becomes a + child reference (referring to the aggs list), since after translation, + an aggregated child subtree only has the grouping keys and the + aggregation calls as opposed to its other terms. + + Args: + `hybrid`: the hybrid tree that should be used to derive the + translation of the aggregation call. 
+ `expr`: the aggregation function QDAG expression to be converted. + `args`: the converted arguments to the aggregation call. + """ + child_indices: set[int] = set() + converted_args: list[HybridExpr] = [ + self.convert_agg_arg(arg, child_indices) for arg in args + ] + if len(child_indices) != 1: + raise ValueError( + f"Expected aggregation call to contain references to exactly one child collection, but found {len(child_indices)} in {expr}" + ) + hybrid_call: HybridFunctionExpr = HybridFunctionExpr( + expr.operator, converted_args, expr.pydough_type + ) + # Identify the child connection that the aggregation call is pushed + # into. + child_idx: int = child_indices.pop() + child_connection: HybridConnection = hybrid.children[child_idx] + # Generate a unique name for the agg call to push into the child + # connection. + agg_name: str = self.get_agg_name(child_connection) + child_connection.aggs[agg_name] = hybrid_call + result_ref: HybridExpr = HybridChildRefExpr( + agg_name, child_idx, expr.pydough_type + ) + joins_can_nullify: bool = not isinstance(hybrid.pipeline[0], HybridRoot) + return self.postprocess_agg_output(hybrid_call, result_ref, joins_can_nullify) + + def make_hybrid_correl_expr( + self, + back_expr: BackReferenceExpression, + collection: PyDoughCollectionQDAG, + steps_taken_so_far: int, + ) -> HybridCorrelExpr: + """ + Converts a BACK reference into a correlated reference when the number + of BACK levels exceeds the height of the current subtree. + + Args: + `back_expr`: the original BACK reference to be converted. + `collection`: the collection at the top of the current subtree, + before we have run out of BACK levels to step up out of. + `steps_taken_so_far`: the number of steps already taken to step + up from the BACK node. This is needed so we know how many steps + still need to be taken upward once we have stepped out of the child + subtree back into the parent subtree. 
+ """ + if len(self.stack) == 0: + raise ValueError("Back reference steps too far back") + # Identify the parent subtree that the BACK reference is stepping back + # into, out of the child. + parent_tree: HybridTree = self.stack.pop() + remaining_steps_back: int = back_expr.back_levels - steps_taken_so_far - 1 + parent_result: HybridExpr + # Special case: stepping out of the data argument of PARTITION back + # into its ancestor. For example: + # TPCH(x=...).PARTITION(data.WHERE(y > BACK(1).x), ...) + partition_edge_case: bool = len(parent_tree.pipeline) == 1 and isinstance( + parent_tree.pipeline[0], HybridPartition + ) + if partition_edge_case: + assert parent_tree.parent is not None + # Treat the partition's parent as the conext for the back + # to step into, as opposed to the partition itself (so the back + # levels are consistent) + self.stack.append(parent_tree.parent) + parent_result = self.make_hybrid_correl_expr( + back_expr, collection, steps_taken_so_far + ).expr + self.stack.pop() + elif remaining_steps_back == 0: + # If there are no more steps back to be made, then the correlated + # reference is to a reference from the current context. + if back_expr.term_name not in parent_tree.pipeline[-1].terms: + raise ValueError( + f"Back reference to {back_expr.term_name} not found in parent" + ) + parent_name: str = parent_tree.pipeline[-1].renamings.get( + back_expr.term_name, back_expr.term_name + ) + parent_result = HybridRefExpr(parent_name, back_expr.pydough_type) + else: + # Otherwise, a back reference needs to be made from the current + # collection a number of steps back based on how many steps still + # need to be taken, and it must be recursively converted to a + # hybrid expression that gets wrapped in a correlated reference. 
+ new_expr: PyDoughExpressionQDAG = BackReferenceExpression( + collection, back_expr.term_name, remaining_steps_back + ) + parent_result = self.make_hybrid_expr(parent_tree, new_expr, {}, False) + if not isinstance(parent_result, HybridCorrelExpr): + parent_tree.correlated_children.add(len(parent_tree.children)) + # Restore parent_tree back onto the stack, since evaluating `back_expr` + # does not change the program's current placement in the subtrees. + self.stack.append(parent_tree) + # Create the correlated reference to the expression with regards to + # the parent tree, which could also be a correlated expression. + return HybridCorrelExpr(parent_tree, parent_result) + def make_hybrid_expr( self, hybrid: HybridTree, expr: PyDoughExpressionQDAG, child_ref_mapping: dict[int, int], + inside_agg: bool, ) -> HybridExpr: """ Converts a QDAG expression into a HybridExpr. @@ -1550,6 +1670,8 @@ def make_hybrid_expr( `child_ref_mapping`: mapping of indices used by child references in the original expressions to the index of the child hybrid tree relative to the current level. + `inside_agg`: True if `expr` is being derived inside of an + aggregation call, False otherwise. Returns: The HybridExpr node corresponding to `expr` @@ -1561,7 +1683,9 @@ ancestor_tree: HybridTree match expr: case PartitionKey(): - return self.make_hybrid_expr(hybrid, expr.expr, child_ref_mapping) + return self.make_hybrid_expr( + hybrid, expr.expr, child_ref_mapping, inside_agg + ) case Literal(): return HybridLiteralExpr(expr) case ColumnProperty(): @@ -1581,20 +1705,22 @@ case BackReferenceExpression(): # A reference to an expression from an ancestor becomes a # reference to one of the terms of a parent level of the hybrid - # tree. This does not yet support cases where the back - # reference steps outside of a child subtree and back into its - # parent subtree, since that breaks the independence between - # the parent and child. + tree. 
If the BACK goes far enough that it must step outside + # a child subtree into the parent, a correlated reference is + # created. ancestor_tree = hybrid back_idx: int = 0 true_steps_back: int = 0 # Keep stepping backward until `expr.back_levels` non-hidden # steps have been taken (to ignore steps that are part of a # compound). + collection: PyDoughCollectionQDAG = expr.collection while true_steps_back < expr.back_levels: + assert collection.ancestor_context is not None + collection = collection.ancestor_context if ancestor_tree.parent is None: - raise NotImplementedError( - "TODO: (gh #141) support BACK references that step from a child subtree back into a parent context." + return self.make_hybrid_correl_expr( + expr, collection, true_steps_back ) ancestor_tree = ancestor_tree.parent back_idx += true_steps_back @@ -1609,80 +1735,51 @@ def make_hybrid_expr( expr.term_name, expr.term_name ) return HybridRefExpr(expr_name, expr.pydough_type) - case ExpressionFunctionCall() if not expr.operator.is_aggregation: - # For non-aggregate function calls, translate their arguments - # normally and build the function call. Does not support any + case ExpressionFunctionCall(): + if expr.operator.is_aggregation and inside_agg: + raise NotImplementedError( + "PyDough does not yet support calling aggregations inside of aggregations" + ) + # Do special casing for operators that can have collection + # arguments. 
+ # TODO: (gh #148) handle collection-level NDISTINCT + if ( + expr.operator == pydop.COUNT + and len(expr.args) == 1 + and isinstance(expr.args[0], PyDoughCollectionQDAG) + ): + return self.handle_collection_count(hybrid, expr, child_ref_mapping) + elif expr.operator in (pydop.HAS, pydop.HASNOT): + return self.handle_has_hasnot(hybrid, expr, child_ref_mapping) + elif any( + not isinstance(arg, PyDoughExpressionQDAG) for arg in expr.args + ): + raise NotImplementedError( + f"PyDough does not yet support non-expression arguments for aggregation function {expr.operator}" + ) + # For normal operators, translate their expression arguments + # normally. If it is a non-aggregation, build the function + # call. If it is an aggregation, transform accordingly. # such function that takes in a collection, as none currently # exist that are not aggregations. + expr.operator.is_aggregation for arg in expr.args: if not isinstance(arg, PyDoughExpressionQDAG): raise NotImplementedError( - "PyDough does not yet support converting collections as function arguments to a non-aggregation function" + f"PyDough does not yet support non-expression arguments for function {expr.operator}" ) - args.append(self.make_hybrid_expr(hybrid, arg, child_ref_mapping)) - return HybridFunctionExpr(expr.operator, args, expr.pydough_type) - case ExpressionFunctionCall() if expr.operator.is_aggregation: - # For aggregate function calls, their arguments are translated in - # a manner that identifies what child subtree they correspond too, - # by index, and translates them relative to the subtree. Then, the - # aggregation calls are placed into the `aggs` mapping of the - # corresponding child connection, and the aggregation call becomes - # a child reference (referring to the aggs list), since after - # translation, an aggregated child subtree only has the grouping - # keys & the aggregation calls as opposed to its other terms. 
- child_idx: int | None = None - arg_child_idx: int | None = None - for arg in expr.args: - if isinstance(arg, PyDoughExpressionQDAG): - hybrid_arg, arg_child_idx = self.make_hybrid_agg_expr( - hybrid, arg, child_ref_mapping + args.append( + self.make_hybrid_expr( + hybrid, + arg, + child_ref_mapping, + inside_agg or expr.operator.is_aggregation, ) - else: - if not isinstance(arg, ChildReferenceCollection): - raise NotImplementedError("Cannot process argument") - # TODO: (gh #148) handle collection-level NDISTINCT - if expr.operator == pydop.COUNT: - return self.handle_collection_count( - hybrid, expr, child_ref_mapping - ) - elif expr.operator in (pydop.HAS, pydop.HASNOT): - return self.handle_has_hasnot( - hybrid, expr, child_ref_mapping - ) - else: - raise NotImplementedError( - f"PyDough does not yet support collection arguments for aggregation function {expr.operator}" - ) - # Accumulate the `arg_child_idx` value from the argument across - # all function arguments, ensuring that at the end there is - # exactly one child subtree that the agg call corresponds to. - if arg_child_idx is not None: - if child_idx is None: - child_idx = arg_child_idx - elif arg_child_idx != child_idx: - raise NotImplementedError( - "Unsupported case: multiple child indices referenced by aggregation arguments" - ) - args.append(hybrid_arg) - if child_idx is None: - raise NotImplementedError( - "Unsupported case: no child indices referenced by aggregation arguments" ) - hybrid_call: HybridFunctionExpr = HybridFunctionExpr( - expr.operator, args, expr.pydough_type - ) - child_connection = hybrid.children[child_idx] - # Generate a unique name for the agg call to push into the child - # connection. 
- agg_name: str = self.get_agg_name(child_connection) - child_connection.aggs[agg_name] = hybrid_call - result_ref: HybridExpr = HybridChildRefExpr( - agg_name, child_idx, expr.pydough_type - ) - joins_can_nullify: bool = not isinstance(hybrid.pipeline[0], HybridRoot) - return self.postprocess_agg_output( - hybrid_call, result_ref, joins_can_nullify - ) + if expr.operator.is_aggregation: + return self.make_agg_call(hybrid, expr, args) + else: + return HybridFunctionExpr(expr.operator, args, expr.pydough_type) case WindowCall(): partition_args: list[HybridExpr] = [] order_args: list[HybridCollation] = [] @@ -1700,7 +1797,7 @@ def make_hybrid_expr( partition_args.append(shifted_arg) for arg in expr.collation_args: hybrid_arg = self.make_hybrid_expr( - hybrid, arg.expr, child_ref_mapping + hybrid, arg.expr, child_ref_mapping, inside_agg ) order_args.append(HybridCollation(hybrid_arg, arg.asc, arg.na_last)) return HybridWindowExpr( @@ -1739,7 +1836,9 @@ def process_hybrid_collations( hybrid_orderings: list[HybridCollation] = [] for collation in collations: name = self.get_ordering_name(hybrid) - expr = self.make_hybrid_expr(hybrid, collation.expr, child_ref_mapping) + expr = self.make_hybrid_expr( + hybrid, collation.expr, child_ref_mapping, False + ) new_expressions[name] = expr new_collation: HybridCollation = HybridCollation( HybridRefExpr(name, collation.expr.pydough_type), @@ -1750,7 +1849,7 @@ def process_hybrid_collations( return new_expressions, hybrid_orderings def make_hybrid_tree( - self, node: PyDoughCollectionQDAG, parent: HybridTree | None = None + self, node: PyDoughCollectionQDAG, parent: HybridTree | None ) -> HybridTree: """ Converts a collection QDAG into the HybridTree format. 
@@ -1781,9 +1880,14 @@ def make_hybrid_tree( return successor_hybrid case PartitionChild(): hybrid = self.make_hybrid_tree(node.ancestor_context, parent) - successor_hybrid = HybridTree( - HybridPartitionChild(hybrid.children[0].subtree) - ) + # Identify the original data being partitioned, which may + # require stepping in multiple times if the partition is + # nested inside another partition. + src_tree: HybridTree = hybrid + while isinstance(src_tree.pipeline[0], HybridPartitionChild): + src_tree = src_tree.pipeline[0].subtree + subtree: HybridTree = src_tree.children[0].subtree + successor_hybrid = HybridTree(HybridPartitionChild(subtree)) hybrid.add_successor(successor_hybrid) return successor_hybrid case Calc(): @@ -1792,7 +1896,7 @@ def make_hybrid_tree( new_expressions: dict[str, HybridExpr] = {} for name in sorted(node.calc_terms): expr = self.make_hybrid_expr( - hybrid, node.get_expr(name), child_ref_mapping + hybrid, node.get_expr(name), child_ref_mapping, False ) new_expressions[name] = expr hybrid.pipeline.append( @@ -1806,7 +1910,9 @@ def make_hybrid_tree( case Where(): hybrid = self.make_hybrid_tree(node.preceding_context, parent) self.populate_children(hybrid, node, child_ref_mapping) - expr = self.make_hybrid_expr(hybrid, node.condition, child_ref_mapping) + expr = self.make_hybrid_expr( + hybrid, node.condition, child_ref_mapping, False + ) hybrid.pipeline.append(HybridFilter(hybrid.pipeline[-1], expr)) return hybrid case PartitionBy(): @@ -1819,7 +1925,7 @@ def make_hybrid_tree( for key_name in node.calc_terms: key = node.get_expr(key_name) expr = self.make_hybrid_expr( - successor_hybrid, key, child_ref_mapping + successor_hybrid, key, child_ref_mapping, False ) partition.add_key(key_name, expr) key_exprs.append(HybridRefExpr(key_name, expr.typ)) @@ -1862,19 +1968,40 @@ def make_hybrid_tree( successor_hybrid = self.make_hybrid_tree( node.child_access.child_access, parent ) - partition_by = node.child_access.ancestor_context + partition_by = ( + 
node.child_access.ancestor_context.starting_predecessor + ) assert isinstance(partition_by, PartitionBy) for key in partition_by.keys: rhs_expr: HybridExpr = self.make_hybrid_expr( successor_hybrid, Reference(node.child_access, key.expr.term_name), child_ref_mapping, + False, ) assert isinstance(rhs_expr, HybridRefExpr) lhs_expr: HybridExpr = HybridChildRefExpr( rhs_expr.name, 0, rhs_expr.typ ) join_key_exprs.append((lhs_expr, rhs_expr)) + + case PartitionBy(): + partition = HybridPartition() + successor_hybrid = HybridTree(partition) + self.populate_children( + successor_hybrid, node.child_access, child_ref_mapping + ) + partition_child_idx = child_ref_mapping[0] + for key_name in node.calc_terms: + key = node.get_expr(key_name) + expr = self.make_hybrid_expr( + successor_hybrid, key, child_ref_mapping, False + ) + partition.add_key(key_name, expr) + key_exprs.append(HybridRefExpr(key_name, expr.typ)) + successor_hybrid.children[ + partition_child_idx + ].subtree.agg_keys = key_exprs case _: raise NotImplementedError( f"{node.__class__.__name__} (child is {node.child_access.__class__.__name__})" diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index 55de7714..c48132a2 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -27,6 +27,7 @@ CallExpression, ColumnPruner, ColumnReference, + CorrelatedReference, EmptySingleton, ExpressionSortInfo, Filter, @@ -43,6 +44,7 @@ ) from pydough.types import BooleanType, Int64Type, UnknownType +from .hybrid_decorrelater import run_hybrid_decorrelation from .hybrid_tree import ( ConnectionType, HybridBackRefExpr, @@ -52,6 +54,7 @@ HybridCollectionAccess, HybridColumnExpr, HybridConnection, + HybridCorrelExpr, HybridExpr, HybridFilter, HybridFunctionExpr, @@ -90,11 +93,20 @@ class TranslationOutput: value of that expression. 
""" + correlated_name: str | None = None + """ + The name that can be used to refer to the relational output in correlated + references. + """ + class RelTranslation: def __init__(self): # An index used for creating fake column names self.dummy_idx = 1 + # A stack of contexts used to point to ancestors for correlated + # references. + self.stack: list[TranslationOutput] = [] def make_null_column(self, relation: RelationalNode) -> ColumnReference: """ @@ -145,6 +157,24 @@ def get_column_name( new_name = f"{name}_{self.dummy_idx}" return new_name + def get_correlated_name(self, context: TranslationOutput) -> str: + """ + Finds the name used to refer to a context for correlated variable + access. If the context does not have a correlated name, a new one is + generated for it. + + Args: + `context`: the context containing the relational subtree being + referrenced in a correlated variable access. + + Returns: + The name used to refer to the context in a correlated reference. + """ + if context.correlated_name is None: + context.correlated_name = f"corr{self.dummy_idx}" + self.dummy_idx += 1 + return context.correlated_name + def translate_expression( self, expr: HybridExpr, context: TranslationOutput | None ) -> RelationalExpression: @@ -199,8 +229,32 @@ def translate_expression( order_inputs, expr.kwargs, ) + case HybridCorrelExpr(): + # Convert correlated expressions by converting the expression + # they point to in the context of the top of the stack, then + # wrapping the result in a correlated reference. 
+ ancestor_context: TranslationOutput = self.stack.pop() + ancestor_expr: RelationalExpression = self.translate_expression( + expr.expr, ancestor_context + ) + self.stack.append(ancestor_context) + match ancestor_expr: + case ColumnReference(): + return CorrelatedReference( + ancestor_expr.name, + self.get_correlated_name(ancestor_context), + expr.typ, + ) + case CorrelatedReference(): + return ancestor_expr + case _: + raise ValueError( + f"Unsupported expression to reference in a correlated reference: {ancestor_expr}" + ) case _: - raise NotImplementedError(expr.__class__.__name__) + raise NotImplementedError( + f"TODO: support relational conversion on {expr.__class__.__name__}" + ) def join_outputs( self, @@ -257,6 +311,7 @@ def join_outputs( [LiteralExpression(True, BooleanType())], [join_type], join_columns, + correl_name=lhs_result.correlated_name, ) input_aliases: list[str | None] = out_rel.default_input_aliases @@ -397,9 +452,11 @@ def handle_children( """ for child_idx, child in enumerate(hybrid.children): if child.required_steps == pipeline_idx: + self.stack.append(context) child_output = self.rel_translation( child, child.subtree, len(child.subtree.pipeline) - 1 ) + self.stack.pop() assert child.subtree.join_keys is not None join_keys: list[tuple[HybridExpr, HybridExpr]] = child.subtree.join_keys agg_keys: list[HybridExpr] @@ -592,7 +649,7 @@ def translate_partition( Returns: The TranslationOutput payload containing access to the aggregated - child corresponding tot he partition data. + child corresponding to the partition data. """ expressions: dict[HybridExpr, ColumnReference] = {} # Account for the fact that the PARTITION is stepping down a level, @@ -828,11 +885,26 @@ def rel_translation( if isinstance(operation.collection, TableCollection): result = self.build_simple_table_scan(operation) if context is not None: + # If the collection access is the child of something + # else, join it onto that something else. 
Use the + # uniqueness keys of the ancestor, which should also be + # present in the collection (e.g. joining a partition + # onto the original data using the partition keys). + assert preceding_hybrid is not None + join_keys: list[tuple[HybridExpr, HybridExpr]] = [] + for unique_column in sorted( + preceding_hybrid[0].pipeline[0].unique_exprs, key=str + ): + if unique_column not in result.expressions: + raise ValueError( + f"Cannot connect parent context to child {operation.collection} because {unique_column} is not in the child's expressions." + ) + join_keys.append((unique_column, unique_column)) result = self.join_outputs( context, result, JoinType.INNER, - [], + join_keys, None, ) else: @@ -856,7 +928,8 @@ def rel_translation( assert context is not None, "Malformed HybridTree pattern." result = self.translate_filter(operation, context) case HybridPartition(): - assert context is not None, "Malformed HybridTree pattern." + if context is None: + context = TranslationOutput(EmptySingleton(), {}) result = self.translate_partition( operation, context, hybrid, pipeline_idx ) @@ -942,10 +1015,11 @@ def convert_ast_to_relational( final_terms: set[str] = node.calc_terms node = translator.preprocess_root(node) - # Convert the QDAG node to the hybrid form, then invoke the relational - # conversion procedure. The first element in the returned list is the - # final rel node. - hybrid: HybridTree = HybridTranslator(configs).make_hybrid_tree(node) + # Convert the QDAG node to the hybrid form, decorrelate it, then invoke + # the relational conversion procedure. The first element in the returned + # list is the final rel node. 
+ hybrid: HybridTree = HybridTranslator(configs).make_hybrid_tree(node, None) + run_hybrid_decorrelation(hybrid) renamings: dict[str, str] = hybrid.pipeline[-1].renamings output: TranslationOutput = translator.rel_translation( None, hybrid, len(hybrid.pipeline) - 1 diff --git a/pydough/pydough_operators/base_operator.py b/pydough/pydough_operators/base_operator.py index fe12629e..82ccfacf 100644 --- a/pydough/pydough_operators/base_operator.py +++ b/pydough/pydough_operators/base_operator.py @@ -80,3 +80,15 @@ def to_string(self, arg_strings: list[str]) -> str: Returns: The string representation of the operator called on its arguments. """ + + @abstractmethod + def equals(self, other: object) -> bool: + """ + Returns whether this operator is equal to another operator. + """ + + def __eq__(self, other: object) -> bool: + return self.equals(other) + + def __hash__(self) -> int: + return hash(repr(self)) diff --git a/pydough/relational/README.md b/pydough/relational/README.md index d0259995..57939f8b 100644 --- a/pydough/relational/README.md +++ b/pydough/relational/README.md @@ -18,6 +18,8 @@ The relational_expressions submodule provides functionality to define and manage - `ExpressionSortInfo`: The representation of ordering for an expression within a relational node. - `RelationalExpressionVisitor`: The basic Visitor pattern to perform operations across the expression components of a relational tree. - `ColumnReferenceFinder`: Finds all unique column references in a relational expression. +- `CorrelatedReference`: The expression implementation for accessing a correlated column reference in a relational node. +- `CorrelatedReferenceFinder`: Finds all unique correlated references in a relational expression. - `RelationalExpressionShuttle`: Specialized form of the visitor pattern that returns a relational expression. 
- `ColumnReferenceInputNameModifier`: Shuttle implementation designed to update all uses of a column reference's input name to a new input name based on a dictionary. @@ -33,6 +35,7 @@ from pydough.relational.relational_expressions import ( ExpressionSortInfo, ColumnReferenceFinder, ColumnReferenceInputNameModifier, + CorrelatedReferenceFinder, WindowCallExpression, ) from pydough.pydough_operators import ADD, RANKING @@ -64,6 +67,11 @@ unique_column_refs = finder.get_column_references() # Modify the input name of column references in the call expression modifier = ColumnReferenceInputNameModifier({"old_input_name": "new_input_name"}) modified_call_expr = call_expr.accept_shuttle(modifier) + +# Find all unique correlated references in the call expression +correlated_finder = CorrelatedReferenceFinder() +call_expr.accept(correlated_finder) +unique_correlated_refs = correlated_finder.get_correlated_references() ``` ## [Relational Nodes](relational_nodes/README.md) diff --git a/pydough/relational/__init__.py b/pydough/relational/__init__.py index 75591d4f..b954a2f3 100644 --- a/pydough/relational/__init__.py +++ b/pydough/relational/__init__.py @@ -6,6 +6,7 @@ "ColumnReferenceFinder", "ColumnReferenceInputNameModifier", "ColumnReferenceInputNameRemover", + "CorrelatedReference", "EmptySingleton", "ExpressionSortInfo", "Filter", @@ -30,6 +31,7 @@ ColumnReferenceFinder, ColumnReferenceInputNameModifier, ColumnReferenceInputNameRemover, + CorrelatedReference, ExpressionSortInfo, LiteralExpression, RelationalExpression, diff --git a/pydough/relational/relational_expressions/README.md b/pydough/relational/relational_expressions/README.md index dcca7290..c11c8f44 100644 --- a/pydough/relational/relational_expressions/README.md +++ b/pydough/relational/relational_expressions/README.md @@ -45,6 +45,14 @@ The relational_expressions module provides functionality to define and manage va - `ColumnReferenceFinder`: Finds all unique column references in a relational expression. 
+### [correlated_reference.py](correlated_reference.py) + +- `CorrelatedReference`: The expression implementation for accessing a correlated column reference in a relational node. + +### [correlated_reference_finder.py](correlated_reference_finder.py) + +- `CorrelatedReferenceFinder`: Finds all unique correlated references in a relational expression. + ### [relational_expression_shuttle.py](relational_expression_shuttle.py) - `RelationalExpressionShuttle`: Specialized form of the visitor pattern that returns a relational expression. This is used to handle the common case where we need to modify a type of input. @@ -69,6 +77,8 @@ from pydough.relational.relational_expressions import ( ExpressionSortInfo, ColumnReferenceFinder, ColumnReferenceInputNameModifier, + CorrelatedReference, + CorrelatedReferenceFinder, ) from pydough.pydough_operators import ADD from pydough.types import Int64Type @@ -82,6 +92,10 @@ literal_expr = LiteralExpression(10, Int64Type()) # Create a call expression for addition call_expr = CallExpression(ADD, Int64Type(), [column_ref, literal_expr]) +# Create a correlated reference to column `column_name` in the first input to +# an ancestor join of `corr1` +correlated_ref = CorrelatedReference("column_name", "corr1", Int64Type()) + # Create an expression sort info sort_info = ExpressionSortInfo(call_expr, ascending=True, nulls_first=False) @@ -96,4 +110,9 @@ unique_column_refs = finder.get_column_references() # Modify the input name of column references in the call expression modifier = ColumnReferenceInputNameModifier({"old_input_name": "new_input_name"}) modified_call_expr = call_expr.accept_shuttle(modifier) + +# Find all unique correlated references in the call expression +correlated_finder = CorrelatedReferenceFinder() +call_expr.accept(correlated_finder) +unique_correlated_refs = correlated_finder.get_correlated_references() ``` diff --git a/pydough/relational/relational_expressions/__init__.py 
b/pydough/relational/relational_expressions/__init__.py index 68838524..3eb8fc33 100644 --- a/pydough/relational/relational_expressions/__init__.py +++ b/pydough/relational/relational_expressions/__init__.py @@ -9,6 +9,8 @@ "ColumnReferenceFinder", "ColumnReferenceInputNameModifier", "ColumnReferenceInputNameRemover", + "CorrelatedReference", + "CorrelatedReferenceFinder", "ExpressionSortInfo", "LiteralExpression", "RelationalExpression", @@ -21,6 +23,8 @@ from .column_reference_finder import ColumnReferenceFinder from .column_reference_input_name_modifier import ColumnReferenceInputNameModifier from .column_reference_input_name_remover import ColumnReferenceInputNameRemover +from .correlated_reference import CorrelatedReference +from .correlated_reference_finder import CorrelatedReferenceFinder from .expression_sort_info import ExpressionSortInfo from .literal_expression import LiteralExpression from .relational_expression_visitor import RelationalExpressionVisitor diff --git a/pydough/relational/relational_expressions/abstract_expression.py b/pydough/relational/relational_expressions/abstract_expression.py index 7236ad0b..ac3b6d33 100644 --- a/pydough/relational/relational_expressions/abstract_expression.py +++ b/pydough/relational/relational_expressions/abstract_expression.py @@ -81,6 +81,9 @@ def to_string(self, compact: bool = False) -> str: def __repr__(self) -> str: return self.to_string() + def __hash__(self) -> int: + return hash(self.to_string()) + @abstractmethod def accept(self, visitor: RelationalExpressionVisitor) -> None: """ diff --git a/pydough/relational/relational_expressions/column_reference_finder.py b/pydough/relational/relational_expressions/column_reference_finder.py index 7de5bc88..d8c7ba61 100644 --- a/pydough/relational/relational_expressions/column_reference_finder.py +++ b/pydough/relational/relational_expressions/column_reference_finder.py @@ -42,3 +42,6 @@ def visit_literal_expression(self, literal_expression: LiteralExpression) -> 
Non def visit_column_reference(self, column_reference: ColumnReference) -> None: self._column_references.add(column_reference) + + def visit_correlated_reference(self, correlated_reference) -> None: + pass diff --git a/pydough/relational/relational_expressions/column_reference_input_name_modifier.py b/pydough/relational/relational_expressions/column_reference_input_name_modifier.py index 0a750462..fd738ad7 100644 --- a/pydough/relational/relational_expressions/column_reference_input_name_modifier.py +++ b/pydough/relational/relational_expressions/column_reference_input_name_modifier.py @@ -45,3 +45,6 @@ def visit_column_reference(self, column_reference) -> RelationalExpression: raise ValueError( f"Input name {column_reference.input_name} not found in the input name map." ) + + def visit_correlated_reference(self, correlated_reference) -> RelationalExpression: + return correlated_reference diff --git a/pydough/relational/relational_expressions/column_reference_input_name_remover.py b/pydough/relational/relational_expressions/column_reference_input_name_remover.py index 26633764..0de2e1d9 100644 --- a/pydough/relational/relational_expressions/column_reference_input_name_remover.py +++ b/pydough/relational/relational_expressions/column_reference_input_name_remover.py @@ -37,3 +37,6 @@ def visit_column_reference(self, column_reference) -> RelationalExpression: column_reference.data_type, None, ) + + def visit_correlated_reference(self, correlated_reference) -> RelationalExpression: + return correlated_reference diff --git a/pydough/relational/relational_expressions/correlated_reference.py b/pydough/relational/relational_expressions/correlated_reference.py new file mode 100644 index 00000000..e6be2790 --- /dev/null +++ b/pydough/relational/relational_expressions/correlated_reference.py @@ -0,0 +1,65 @@ +""" +The representation of a correlated column access for use in a relational tree. 
+The correl name should be the `correl_name` property of a join ancestor of the +tree, and the name should match one of the column names of the first input to +that join, which is the column that the correlated reference refers to. +""" + +__all__ = ["CorrelatedReference"] + +from pydough.types import PyDoughType + +from .abstract_expression import RelationalExpression +from .relational_expression_shuttle import RelationalExpressionShuttle +from .relational_expression_visitor import RelationalExpressionVisitor + + +class CorrelatedReference(RelationalExpression): + """ + The Expression implementation for accessing a correlated column reference + in a relational node. + """ + + def __init__(self, name: str, correl_name: str, data_type: PyDoughType) -> None: + super().__init__(data_type) + self._name: str = name + self._correl_name: str = correl_name + + def __hash__(self) -> int: + return hash((self.name, self.correl_name, self.data_type)) + + @property + def name(self) -> str: + """ + The name of the column. + """ + return self._name + + @property + def correl_name(self) -> str: + """ + The name of the correlation that the reference points to. 
+ """ + return self._correl_name + + def to_string(self, compact: bool = False) -> str: + if compact: + return f"{self.correl_name}.{self.name}" + else: + return f"CorrelatedReference(name={self.name}, correl_name={self.correl_name}, type={self.data_type})" + + def equals(self, other: object) -> bool: + return ( + isinstance(other, CorrelatedReference) + and (self.name == other.name) + and (self.correl_name == other.correl_name) + and super().equals(other) + ) + + def accept(self, visitor: RelationalExpressionVisitor) -> None: + visitor.visit_correlated_reference(self) + + def accept_shuttle( + self, shuttle: RelationalExpressionShuttle + ) -> RelationalExpression: + return shuttle.visit_correlated_reference(self) diff --git a/pydough/relational/relational_expressions/correlated_reference_finder.py b/pydough/relational/relational_expressions/correlated_reference_finder.py new file mode 100644 index 00000000..20ae0dc2 --- /dev/null +++ b/pydough/relational/relational_expressions/correlated_reference_finder.py @@ -0,0 +1,48 @@ +""" +Find all unique column references in a relational expression. +""" + +from .call_expression import CallExpression +from .column_reference import ColumnReference +from .correlated_reference import CorrelatedReference +from .literal_expression import LiteralExpression +from .relational_expression_visitor import RelationalExpressionVisitor +from .window_call_expression import WindowCallExpression + +__all__ = ["CorrelatedReferenceFinder"] + + +class CorrelatedReferenceFinder(RelationalExpressionVisitor): + """ + Find all unique correlated references in a relational expression. 
+ """ + + def __init__(self) -> None: + self._correlated_references: set[CorrelatedReference] = set() + + def reset(self) -> None: + self._correlated_references = set() + + def get_correlated_references(self) -> set[CorrelatedReference]: + return self._correlated_references + + def visit_call_expression(self, call_expression: CallExpression) -> None: + for arg in call_expression.inputs: + arg.accept(self) + + def visit_window_expression(self, window_expression: WindowCallExpression) -> None: + for arg in window_expression.inputs: + arg.accept(self) + for partition_arg in window_expression.partition_inputs: + partition_arg.accept(self) + for order_arg in window_expression.order_inputs: + order_arg.expr.accept(self) + + def visit_literal_expression(self, literal_expression: LiteralExpression) -> None: + pass + + def visit_column_reference(self, column_reference: ColumnReference) -> None: + pass + + def visit_correlated_reference(self, correlated_reference) -> None: + self._correlated_references.add(correlated_reference) diff --git a/pydough/relational/relational_expressions/literal_expression.py b/pydough/relational/relational_expressions/literal_expression.py index 3b0eb803..1edc9981 100644 --- a/pydough/relational/relational_expressions/literal_expression.py +++ b/pydough/relational/relational_expressions/literal_expression.py @@ -29,10 +29,6 @@ def __init__(self, value: Any, data_type: PyDoughType): super().__init__(data_type) self._value: Any = value - def __hash__(self) -> int: - # Note: This will break if the value isn't hashable. 
- return hash((self.value, self.data_type)) - @property def value(self) -> object: """ diff --git a/pydough/relational/relational_expressions/relational_expression_shuttle.py b/pydough/relational/relational_expressions/relational_expression_shuttle.py index 354a4b8e..badc0b1d 100644 --- a/pydough/relational/relational_expressions/relational_expression_shuttle.py +++ b/pydough/relational/relational_expressions/relational_expression_shuttle.py @@ -75,3 +75,14 @@ def visit_column_reference(self, column_reference): Returns: RelationalExpression: The new node resulting from visiting this node. """ + + @abstractmethod + def visit_correlated_reference(self, correlated_reference): + """ + Visit a CorrelatedReference node. + + Args: + correlated_reference (CorrelatedReference): The correlated reference node to visit. + Returns: + RelationalExpression: The new node resulting from visiting this node. + """ diff --git a/pydough/relational/relational_expressions/relational_expression_visitor.py b/pydough/relational/relational_expressions/relational_expression_visitor.py index 39746b1a..0873662a 100644 --- a/pydough/relational/relational_expressions/relational_expression_visitor.py +++ b/pydough/relational/relational_expressions/relational_expression_visitor.py @@ -61,3 +61,12 @@ def visit_column_reference(self, column_reference) -> None: Args: column_reference (ColumnReference): The column reference node to visit. """ + + @abstractmethod + def visit_correlated_reference(self, correlated_reference) -> None: + """ + Visit a CorrelatedReference node. + + Args: + correlated_reference (CorrelatedReference): The correlated reference node to visit. 
+ """ diff --git a/pydough/relational/relational_nodes/column_pruner.py b/pydough/relational/relational_nodes/column_pruner.py index 91f7f476..ef8e7cf2 100644 --- a/pydough/relational/relational_nodes/column_pruner.py +++ b/pydough/relational/relational_nodes/column_pruner.py @@ -5,10 +5,14 @@ from pydough.relational.relational_expressions import ( ColumnReference, ColumnReferenceFinder, + CorrelatedReference, + CorrelatedReferenceFinder, ) from .abstract_node import RelationalNode from .aggregate import Aggregate +from .empty_singleton import EmptySingleton +from .join import Join, JoinType from .project import Project from .relational_expression_dispatcher import RelationalExpressionDispatcher from .relational_root import RelationalRoot @@ -19,11 +23,15 @@ class ColumnPruner: def __init__(self) -> None: self._column_finder: ColumnReferenceFinder = ColumnReferenceFinder() + self._correl_finder: CorrelatedReferenceFinder = CorrelatedReferenceFinder() # Note: We set recurse=False so we only check the expressions in the # current node. - self._dispatcher = RelationalExpressionDispatcher( + self._finder_dispatcher = RelationalExpressionDispatcher( self._column_finder, recurse=False ) + self._correl_dispatcher = RelationalExpressionDispatcher( + self._correl_finder, recurse=False + ) def _prune_identity_project(self, node: RelationalNode) -> RelationalNode: """ @@ -43,7 +51,7 @@ def _prune_identity_project(self, node: RelationalNode) -> RelationalNode: def _prune_node_columns( self, node: RelationalNode, kept_columns: set[str] - ) -> RelationalNode: + ) -> tuple[RelationalNode, set[CorrelatedReference]]: """ Prune the columns for a subtree starting at this node. @@ -68,14 +76,17 @@ def _prune_node_columns( for name, expr in node.columns.items() if name in kept_columns or name in required_columns } + # Update the columns. new_node = node.copy(columns=columns) - self._dispatcher.reset() - # Visit the current identifiers. 
- new_node.accept(self._dispatcher) + + # Find all the identifiers referenced by the the current node. + self._finder_dispatcher.reset() + new_node.accept(self._finder_dispatcher) found_identifiers: set[ColumnReference] = ( self._column_finder.get_column_references() ) + # If the node is an aggregate but doesn't use any of the inputs # (e.g. a COUNT(*)), arbitrarily mark one of them as used. # TODO: (gh #196) optimize this functionality so it doesn't keep an @@ -88,19 +99,66 @@ def _prune_node_columns( node.input.columns[arbitrary_column_name].data_type, ) ) + # Determine which identifiers to pass to each input. new_inputs: list[RelationalNode] = [] # Note: The ColumnPruner should only be run when all input names are # still present in the columns. - for i, default_input_name in enumerate(new_node.default_input_aliases): + # Iterate over the inputs in reverse order so that the source of + # correlated data is pruned last, since it will need to account for + # any correlated references in the later inputs. + correl_refs: set[CorrelatedReference] = set() + for i, default_input_name in reversed( + list(enumerate(new_node.default_input_aliases)) + ): s: set[str] = set() + input_node: RelationalNode = node.inputs[i] for identifier in found_identifiers: if identifier.input_name == default_input_name: s.add(identifier.name) - new_inputs.append(self._prune_node_columns(node.inputs[i], s)) + if ( + isinstance(new_node, Join) + and i == 0 + and new_node.correl_name is not None + ): + for correl_ref in correl_refs: + if correl_ref.correl_name == new_node.correl_name: + s.add(correl_ref.name) + new_input_node, new_correl_refs = self._prune_node_columns(input_node, s) + new_inputs.append(new_input_node) + if i == len(node.inputs) - 1: + correl_refs = new_correl_refs + else: + correl_refs.update(new_correl_refs) + new_inputs.reverse() + + # Find all the correlated references in the new node. 
+ self._correl_dispatcher.reset() + new_node.accept(self._correl_dispatcher) + found_correl_refs: set[CorrelatedReference] = ( + self._correl_finder.get_correlated_references() + ) + correl_refs.update(found_correl_refs) + # Determine the new node. output = new_node.copy(inputs=new_inputs) - return self._prune_identity_project(output) + output = self._prune_identity_project(output) + # Special case: replace empty aggregation with VALUES () if possible. + if ( + isinstance(output, Aggregate) + and len(output.keys) == 0 + and len(output.aggregations) == 0 + ): + return EmptySingleton(), correl_refs + # Special case: replace join where LHS is VALUES () with the RHS if + # possible. + if ( + isinstance(output, Join) + and isinstance(output.inputs[0], EmptySingleton) + and output.join_types in ([JoinType.INNER], [JoinType.LEFT]) + ): + return output.inputs[1], correl_refs + return output, correl_refs def prune_unused_columns(self, root: RelationalRoot) -> RelationalRoot: """ @@ -112,8 +170,6 @@ def prune_unused_columns(self, root: RelationalRoot) -> RelationalRoot: Returns: RelationalRoot: The root after updating all inputs. """ - new_root: RelationalNode = self._prune_node_columns( - root, set(root.columns.keys()) - ) + new_root, _ = self._prune_node_columns(root, set(root.columns.keys())) assert isinstance(new_root, RelationalRoot), "Expected a root node." 
return new_root diff --git a/pydough/relational/relational_nodes/join.py b/pydough/relational/relational_nodes/join.py index f1629689..f7759571 100644 --- a/pydough/relational/relational_nodes/join.py +++ b/pydough/relational/relational_nodes/join.py @@ -49,6 +49,7 @@ def __init__( conditions: list[RelationalExpression], join_types: list[JoinType], columns: MutableMapping[str, RelationalExpression], + correl_name: str | None = None, ) -> None: super().__init__(columns) num_inputs = len(inputs) @@ -65,6 +66,15 @@ def __init__( ), "Join condition must be a boolean type" self._conditions: list[RelationalExpression] = conditions self._join_types: list[JoinType] = join_types + self._correl_name: str | None = correl_name + + @property + def correl_name(self) -> str | None: + """ + The name used to refer to the first join input when subsequent inputs + have correlated references. + """ + return self._correl_name @property def conditions(self) -> list[RelationalExpression]: @@ -101,6 +111,7 @@ def node_equals(self, other: RelationalNode) -> bool: isinstance(other, Join) and self.conditions == other.conditions and self.join_types == other.join_types + and self.correl_name == other.correl_name and all( self.inputs[i].node_equals(other.inputs[i]) for i in range(len(self.inputs)) @@ -109,7 +120,10 @@ def node_equals(self, other: RelationalNode) -> bool: def to_string(self, compact: bool = False) -> str: conditions: list[str] = [cond.to_string(compact) for cond in self.conditions] - return f"JOIN(conditions=[{', '.join(conditions)}], types={[t.value for t in self.join_types]}, columns={self.make_column_string(self.columns, compact)})" + correl_suffix = ( + "" if self.correl_name is None else f", correl_name={self.correl_name!r}" + ) + return f"JOIN(conditions=[{', '.join(conditions)}], types={[t.value for t in self.join_types]}, columns={self.make_column_string(self.columns, compact)}{correl_suffix})" def accept(self, visitor: RelationalVisitor) -> None: 
visitor.visit_join(self) @@ -119,4 +133,4 @@ def node_copy( columns: MutableMapping[str, RelationalExpression], inputs: MutableSequence[RelationalNode], ) -> RelationalNode: - return Join(inputs, self.conditions, self.join_types, columns) + return Join(inputs, self.conditions, self.join_types, columns, self.correl_name) diff --git a/pydough/sqlglot/sqlglot_relational_expression_visitor.py b/pydough/sqlglot/sqlglot_relational_expression_visitor.py index 4ca111ba..e15e5007 100644 --- a/pydough/sqlglot/sqlglot_relational_expression_visitor.py +++ b/pydough/sqlglot/sqlglot_relational_expression_visitor.py @@ -13,6 +13,7 @@ from pydough.relational import ( CallExpression, ColumnReference, + CorrelatedReference, LiteralExpression, RelationalExpression, RelationalExpressionVisitor, @@ -32,13 +33,17 @@ class SQLGlotRelationalExpressionVisitor(RelationalExpressionVisitor): """ def __init__( - self, dialect: SQLGlotDialect, bindings: SqlGlotTransformBindings + self, + dialect: SQLGlotDialect, + bindings: SqlGlotTransformBindings, + correlated_names: dict[str, str], ) -> None: # Keep a stack of SQLGlot expressions so we can build up # intermediate results. self._stack: list[SQLGlotExpression] = [] self._dialect: SQLGlotDialect = dialect self._bindings: SqlGlotTransformBindings = bindings + self._correlated_names: dict[str, str] = correlated_names def reset(self) -> None: """ @@ -133,18 +138,22 @@ def visit_literal_expression(self, literal_expression: LiteralExpression) -> Non ) self._stack.append(literal) + def visit_correlated_reference( + self, correlated_reference: CorrelatedReference + ) -> None: + full_name: str = f"{self._correlated_names[correlated_reference.correl_name]}.{correlated_reference.name}" + self._stack.append(Identifier(this=full_name)) + @staticmethod - def generate_column_reference_identifier( + def make_sqlglot_column( column_reference: ColumnReference, ) -> Identifier: """ Generate an identifier for a column reference. 
This is split into a separate static method to ensure consistency across multiple visitors. - Args: column_reference (ColumnReference): The column reference to generate an identifier for. - Returns: Identifier: The output identifier. """ @@ -155,7 +164,7 @@ def generate_column_reference_identifier( return Identifier(this=full_name) def visit_column_reference(self, column_reference: ColumnReference) -> None: - self._stack.append(self.generate_column_reference_identifier(column_reference)) + self._stack.append(self.make_sqlglot_column(column_reference)) def relational_to_sqlglot( self, expr: RelationalExpression, output_name: str | None = None diff --git a/pydough/sqlglot/sqlglot_relational_visitor.py b/pydough/sqlglot/sqlglot_relational_visitor.py index ae619893..ca9cbe2c 100644 --- a/pydough/sqlglot/sqlglot_relational_visitor.py +++ b/pydough/sqlglot/sqlglot_relational_visitor.py @@ -9,6 +9,7 @@ from sqlglot.dialects import Dialect as SQLGlotDialect from sqlglot.expressions import Alias as SQLGlotAlias +from sqlglot.expressions import Column as SQLGlotColumn from sqlglot.expressions import Expression as SQLGlotExpression from sqlglot.expressions import Identifier, Select, Subquery, values from sqlglot.expressions import Literal as SQLGlotLiteral @@ -20,6 +21,7 @@ ColumnReference, ColumnReferenceInputNameModifier, ColumnReferenceInputNameRemover, + CorrelatedReference, EmptySingleton, ExpressionSortInfo, Filter, @@ -54,8 +56,11 @@ def __init__( # Keep a stack of SQLGlot expressions so we can build up # intermediate results. 
self._stack: list[Select] = [] + self._correlated_names: dict[str, str] = {} self._expr_visitor: SQLGlotRelationalExpressionVisitor = ( - SQLGlotRelationalExpressionVisitor(dialect, bindings) + SQLGlotRelationalExpressionVisitor( + dialect, bindings, self._correlated_names + ) ) self._alias_modifier: ColumnReferenceInputNameModifier = ( ColumnReferenceInputNameModifier() @@ -94,7 +99,7 @@ def _is_mergeable_column(expr: SQLGlotExpression) -> bool: if isinstance(expr, SQLGlotAlias): return SQLGlotRelationalVisitor._is_mergeable_column(expr.this) else: - return isinstance(expr, (SQLGlotLiteral, Identifier)) + return isinstance(expr, (SQLGlotLiteral, Identifier, SQLGlotColumn)) @staticmethod def _try_merge_columns( @@ -154,11 +159,22 @@ def _try_merge_columns( # If the new column is a literal, we can just add it to the old # columns. modified_old_columns.append(set_glot_alias(new_column, new_name)) - else: + elif isinstance(new_column, Identifier): expr = set_glot_alias(old_column_map[new_column.this], new_name) modified_old_columns.append(expr) if isinstance(expr, Identifier): seen_cols.add(expr) + elif isinstance(new_column, SQLGlotColumn): + expr = set_glot_alias( + old_column_map[new_column.this.this], new_name + ) + modified_old_columns.append(expr) + if isinstance(expr, Identifier): + seen_cols.add(expr) + else: + raise ValueError( + f"Unsupported expression type for column merging: {new_column.__class__.__name__}" + ) # Check that there are no missing dependencies in the old columns. 
if old_column_deps - seen_cols: return new_columns, old_columns @@ -301,7 +317,7 @@ def contains_window(self, exp: RelationalExpression) -> bool: match exp: case CallExpression(): return any(self.contains_window(arg) for arg in exp.inputs) - case ColumnReference() | LiteralExpression(): + case ColumnReference() | LiteralExpression() | CorrelatedReference(): return False case WindowCallExpression(): return True @@ -327,6 +343,12 @@ def visit_scan(self, scan: Scan) -> None: self._stack.append(query) def visit_join(self, join: Join) -> None: + alias_map: dict[str | None, str] = {} + if join.correl_name is not None: + input_name = join.default_input_aliases[0] + alias = self._generate_table_alias() + alias_map[input_name] = alias + self._correlated_names[join.correl_name] = alias self.visit_inputs(join) inputs: list[Select] = [self._stack.pop() for _ in range(len(join.inputs))] inputs.reverse() @@ -337,11 +359,12 @@ def visit_join(self, join: Join) -> None: seen_names[column] += 1 # Only keep duplicate names. kept_names = {key for key, value in seen_names.items() if value > 1} - alias_map = { - join.default_input_aliases[i]: self._generate_table_alias() - for i in range(len(join.inputs)) - if kept_names.intersection(join.inputs[i].columns.keys()) - } + for i in range(len(join.inputs)): + input_name = join.default_input_aliases[i] + if input_name not in alias_map and kept_names.intersection( + join.inputs[i].columns.keys() + ): + alias_map[input_name] = self._generate_table_alias() self._alias_remover.set_kept_names(kept_names) self._alias_modifier.set_map(alias_map) columns = { @@ -408,6 +431,7 @@ def visit_filter(self, filter: Filter) -> None: # TODO: (gh #151) Refactor a simpler way to check dependent expressions. 
if ( "group" in input_expr.args + or "distinct" in input_expr.args or "where" in input_expr.args or "qualify" in input_expr.args or "order" in input_expr.args @@ -439,6 +463,7 @@ def visit_aggregate(self, aggregate: Aggregate) -> None: query: Select if ( "group" in input_expr.args + or "distinct" in input_expr.args or "qualify" in input_expr.args or "order" in input_expr.args or "limit" in input_expr.args @@ -449,7 +474,10 @@ def visit_aggregate(self, aggregate: Aggregate) -> None: select_cols, input_expr, find_identifiers_in_list(select_cols) ) if keys: - query = query.group_by(*keys) + if aggregations: + query = query.group_by(*keys) + else: + query = query.distinct() self._stack.append(query) def visit_limit(self, limit: Limit) -> None: @@ -488,7 +516,7 @@ def visit_limit(self, limit: Limit) -> None: self._stack.append(query) def visit_empty_singleton(self, singleton: EmptySingleton) -> None: - self._stack.append(Select().from_(values([()]))) + self._stack.append(Select().select(SQLGlotStar()).from_(values([()]))) def visit_root(self, root: RelationalRoot) -> None: self.visit_inputs(root) diff --git a/pydough/types/struct_type.py b/pydough/types/struct_type.py index 94b3222a..7b3ef680 100644 --- a/pydough/types/struct_type.py +++ b/pydough/types/struct_type.py @@ -109,7 +109,7 @@ def parse_struct_body( except PyDoughTypeException: pass - # Otherwise, iterate across all commas int he right hand side + # Otherwise, iterate across all commas in the right hand side # that are candidate splitting locations between a PyDough # type and a suffix that is a valid list of fields. 
if field_type is None: diff --git a/pydough/unqualified/qualification.py b/pydough/unqualified/qualification.py index 216c4011..0a532d78 100644 --- a/pydough/unqualified/qualification.py +++ b/pydough/unqualified/qualification.py @@ -371,7 +371,7 @@ def qualify_access( for _ in range(levels): if ancestor.ancestor_context is None: raise PyDoughUnqualifiedException( - f"Cannot back reference {levels} above {unqualified_parent}" + f"Cannot back reference {levels} above {context}" ) ancestor = ancestor.ancestor_context # Identify whether the access is an expression or a collection @@ -676,7 +676,12 @@ def qualify_partition( partition: PartitionBy = self.builder.build_partition( qualified_parent, qualified_child, child_name ) - return partition.with_keys(child_references) + partition = partition.with_keys(child_references) + # Special case: if accessing as a child, wrap in a + # ChildOperatorChildAccess term. + if isinstance(unqualified_parent, UnqualifiedRoot) and is_child: + return ChildOperatorChildAccess(partition) + return partition def qualify_collection( self, diff --git a/tests/correlated_pydough_functions.py b/tests/correlated_pydough_functions.py new file mode 100644 index 00000000..3d8c4eb7 --- /dev/null +++ b/tests/correlated_pydough_functions.py @@ -0,0 +1,345 @@ +""" +Variant of `simple_pydough_functions.py` for functions testing edge cases in +correlation & de-correlation handling. +""" + +# ruff: noqa +# mypy: ignore-errors +# ruff & mypy should not try to typecheck or verify any of this + + +def correl_1(): + # Correlated back reference example #1: simple 1-step correlated reference + # For each region, count how many of its its nations start with the same + # letter as the region. This is a true correlated join doing an aggregated + # access without requiring the RHS be present. 
+    return Regions(
+        name, n_prefix_nations=COUNT(nations.WHERE(name[:1] == BACK(1).name[:1]))
+    ).ORDER_BY(name.ASC())
+
+
+def correl_2():
+    # Correlated back reference example #2: simple 2-step correlated reference
+    # For each region's nations, count how many customers have a comment
+    # starting with the same letter as the region. Exclude regions that start
+    # with the letter a. This is a true correlated join doing an aggregated
+    # access without requiring the RHS be present.
+    selected_custs = customers.WHERE(comment[:1] == LOWER(BACK(2).name[:1]))
+    return (
+        Regions.WHERE(~STARTSWITH(name, "A"))
+        .nations(
+            name,
+            n_selected_custs=COUNT(selected_custs),
+        )
+        .ORDER_BY(name.ASC())
+    )
+
+
+def correl_3():
+    # Correlated back reference example #3: double-layer correlated reference
+    # For every region, count how many of its nations have a customer
+    # whose comment starts with the same 2 letters as the region. This is a true
+    # correlated join doing an aggregated access without requiring the RHS be
+    # present.
+    selected_custs = customers.WHERE(comment[:2] == LOWER(BACK(2).name[:2]))
+    return Regions(name, n_nations=COUNT(nations.WHERE(HAS(selected_custs)))).ORDER_BY(
+        name.ASC()
+    )
+
+
+def correl_4():
+    # Correlated back reference example #4: 2-step correlated HASNOT
+    # Find every nation that does not have a customer whose account balance is
+    # within $5 of the smallest known account balance globally.
+    # (This is a correlated ANTI-join)
+    selected_customers = customers.WHERE(acctbal <= (BACK(2).smallest_bal + 5.0))
+    return (
+        TPCH(
+            smallest_bal=MIN(Customers.acctbal),
+        )
+        .Nations(name)
+        .WHERE(HASNOT(selected_customers))
+        .ORDER_BY(name.ASC())
+    )
+
+
+def correl_5():
+    # Correlated back reference example #5: 2-step correlated HAS
+    # Find every region that has at least 1 supplier whose account balance is
+    # within $4 of the smallest known account balance globally.
+    # (This is a correlated SEMI-join)
+    selected_suppliers = nations.suppliers.WHERE(
+        account_balance <= (BACK(3).smallest_bal + 4.0)
+    )
+    return (
+        TPCH(
+            smallest_bal=MIN(Suppliers.account_balance),
+        )
+        .Regions(name)
+        .WHERE(HAS(selected_suppliers))
+        .ORDER_BY(name.ASC())
+    )
+
+
+def correl_6():
+    # Correlated back reference example #6: simple 1-step correlated reference
+    # For each region, count how many of its nations start with the same
+    # letter as the region, but only keep regions with at least one such nation.
+    # This is a true correlated join doing an aggregated access that does NOT
+    # require that records without the RHS be kept.
+    selected_nations = nations.WHERE(name[:1] == BACK(1).name[:1])
+    return Regions.WHERE(HAS(selected_nations))(
+        name, n_prefix_nations=COUNT(selected_nations)
+    )
+
+
+def correl_7():
+    # Correlated back reference example #7: deleted correlated reference
+    # For each region, count how many of its nations start with the same
+    # letter as the region, but only keep regions without at least one such
+    # nation. The true correlated join is trumped by the correlated ANTI-join.
+    selected_nations = nations.WHERE(name[:1] == BACK(1).name[:1])
+    return Regions.WHERE(HASNOT(selected_nations))(
+        name, n_prefix_nations=COUNT(selected_nations)
+    )
+
+
+def correl_8():
+    # Correlated back reference example #8: non-agg correlated reference
+    # For each nation, fetch the name of its region, but filter the region
+    # so it only keeps it if it starts with the same letter as the nation
+    # (otherwise, returns NULL). This is a true correlated join doing an
+    # access without aggregation without requiring the RHS be
+    # present.
+    aug_region = region.WHERE(name[:1] == BACK(1).name[:1])
+    return Nations(name, rname=aug_region.name).ORDER_BY(name.ASC())
+
+
+def correl_9():
+    # Correlated back reference example #9: non-agg correlated reference
+    # For each nation, fetch the name of its region, but filter the region
+    # so it only keeps it if it starts with the same letter as the nation
+    # (otherwise, omit the nation). This is a true correlated join doing an
+    # access that also requires the RHS records be present.
+    aug_region = region.WHERE(name[:1] == BACK(1).name[:1])
+    return Nations.WHERE(HAS(aug_region))(name, rname=aug_region.name).ORDER_BY(
+        name.ASC()
+    )
+
+
+def correl_10():
+    # Correlated back reference example #10: deleted correlated reference
+    # For each nation, fetch the name of its region, but filter the region
+    # so it only keeps it if it starts with the same letter as the nation
+    # (otherwise, returns NULL), and also filter the nations to only keep
+    # records where the region is NULL. The true correlated join is trumped by
+    # the correlated ANTI-join.
+    aug_region = region.WHERE(name[:1] == BACK(1).name[:1])
+    return Nations.WHERE(HASNOT(aug_region))(name, rname=aug_region.name).ORDER_BY(
+        name.ASC()
+    )
+
+
+def correl_11():
+    # Correlated back reference example #11: backref out of partition child.
+    # Which part brands have at least 1 part that is more than 40% above the
+    # average retail price for all parts from that brand.
+    # (This is a correlated SEMI-join)
+    brands = PARTITION(Parts, name="p", by=brand)(avg_price=AVG(p.retail_price))
+    outlier_parts = p.WHERE(retail_price > 1.4 * BACK(1).avg_price)
+    selected_brands = brands.WHERE(HAS(outlier_parts))
+    return selected_brands(brand).ORDER_BY(brand.ASC())
+
+
+def correl_12():
+    # Correlated back reference example #12: backref out of partition child.
+ # Which part brands have at least 1 part that is above the average retail + # price for parts of that brand, below the average retail price for all + # parts, and has a size below 3. + # (This is a correlated SEMI-join) + global_info = TPCH(avg_price=AVG(Parts.retail_price)) + brands = global_info.PARTITION(Parts, name="p", by=brand)( + avg_price=AVG(p.retail_price) + ) + selected_parts = p.WHERE( + (retail_price > BACK(1).avg_price) + & (retail_price < BACK(2).avg_price) + & (size < 3) + ) + selected_brands = brands.WHERE(HAS(selected_parts)) + return selected_brands(brand).ORDER_BY(brand.ASC()) + + +def correl_13(): + # Correlated back reference example #13: multiple correlation. + # Count how many suppliers sell at least one part where the retail price + # is less than a 50% markup over the supply cost. Only considers suppliers + # from nations #1/#2/#3, and small parts. + # (This is a correlated SEMI-joins) + selected_part = part.WHERE( + STARTSWITH(container, "SM") & (retail_price < (BACK(1).supplycost * 1.5)) + ) + selected_supply_records = supply_records.WHERE(HAS(selected_part)) + supplier_info = Suppliers.WHERE(nation_key <= 3)( + avg_price=AVG(supply_records.part.retail_price) + ) + selected_suppliers = supplier_info.WHERE(COUNT(selected_supply_records) > 0) + return TPCH(n=COUNT(selected_suppliers)) + + +def correl_14(): + # Correlated back reference example #14: multiple correlation. + # Count how many suppliers sell at least one part where the retail price + # is less than a 50% markup over the supply cost, and the retail price of + # the part is below the average for all parts from the supplier. Only + # considers suppliers from nations #19, and LG DRUM parts. 
+ # (This is multiple correlated SEMI-joins) + selected_part = part.WHERE( + (container == "LG DRUM") + & (retail_price < (BACK(1).supplycost * 1.5)) + & (retail_price < BACK(2).avg_price) + ) + selected_supply_records = supply_records.WHERE(HAS(selected_part)) + supplier_info = Suppliers.WHERE(nation_key == 19)( + avg_price=AVG(supply_records.part.retail_price) + ) + selected_suppliers = supplier_info.WHERE(HAS(selected_supply_records)) + return TPCH(n=COUNT(selected_suppliers)) + + +def correl_15(): + # Correlated back reference example #15: multiple correlation. + # Count how many suppliers sell at least one part where the retail price + # is less than a 50% markup over the supply cost, and the retail price of + # the part is below the 85% of the average of the retail price for all + # parts globally and below the average for all parts from the supplier. + # Only considers suppliers from nations #19, and LG DRUM parts. + # (This is multiple correlated SEMI-joins & a correlated aggregate) + selected_part = part.WHERE( + (container == "LG DRUM") + & (retail_price < (BACK(1).supplycost * 1.5)) + & (retail_price < BACK(2).avg_price) + & (retail_price < BACK(3).avg_price * 0.85) + ) + selected_supply_records = supply_records.WHERE(HAS(selected_part)) + supplier_info = Suppliers.WHERE(nation_key == 19)( + avg_price=AVG(supply_records.part.retail_price) + ) + selected_suppliers = supplier_info.WHERE(HAS(selected_supply_records)) + global_info = TPCH(avg_price=AVG(Parts.retail_price)) + return global_info(n=COUNT(selected_suppliers)) + + +def correl_16(): + # Correlated back reference example #16: hybrid tree order of operations. + # Count how many european suppliers have the exact same percentile value + # of account balance (relative to all other suppliers) as at least one + # customer's percentile value of account balance relative to all other + # customers. Percentile should be measured down to increments of 0.01%. 
+    # (This is a correlated SEMI-join)
+    selected_customers = nation(rname=region.name).customers.WHERE(
+        (PERCENTILE(by=(acctbal.ASC(), key.ASC()), n_buckets=10000) == BACK(2).tile)
+        & (BACK(1).rname == "EUROPE")
+    )
+    supplier_info = Suppliers(
+        tile=PERCENTILE(by=(account_balance.ASC(), key.ASC()), n_buckets=10000)
+    )
+    selected_suppliers = supplier_info.WHERE(HAS(selected_customers))
+    return TPCH(n=COUNT(selected_suppliers))
+
+
+def correl_17():
+    # Correlated back reference example #17: hybrid tree order of operations.
+    # An extremely roundabout way of getting each region_name-nation_name
+    # pair as a string.
+    # (This is a correlated singular/semi access)
+    region_info = region(fname=JOIN_STRINGS("-", LOWER(name), BACK(1).lname))
+    nation_info = Nations(lname=LOWER(name)).WHERE(HAS(region_info))
+    return nation_info(fullname=region_info.fname).ORDER_BY(fullname.ASC())
+
+
+def correl_18():
+    # Correlated back reference example #18: partition decorrelation edge case.
+    # Count how many orders corresponded to at least half of the total price
+    # spent by the ordering customer in a single day, but only if the customer
+    # ordered multiple orders on that day. Only considers orders made in
+    # 1993.
+    # (This is a correlated aggregation access)
+    cust_date_groups = PARTITION(
+        Orders.WHERE(YEAR(order_date) == 1993),
+        name="o",
+        by=(customer_key, order_date),
+    )
+    selected_groups = cust_date_groups.WHERE(COUNT(o) > 1)(
+        total_price=SUM(o.total_price),
+    )(n_above_avg=COUNT(o.WHERE(total_price >= 0.5 * BACK(1).total_price)))
+    return TPCH(n=SUM(selected_groups.n_above_avg))
+
+
+def correl_19():
+    # Correlated back reference example #19: cardinality edge case.
+    # For every supplier, count how many customers in the same nation have a
+    # higher account balance than that supplier. Pick the 5 suppliers with the
+    # largest such count.
+ # (This is a correlated aggregation access) + super_cust = customers.WHERE(acctbal > BACK(2).account_balance) + return Suppliers.nation(name=BACK(1).name, n_super_cust=COUNT(super_cust)).TOP_K( + 5, n_super_cust.DESC() + ) + + +def correl_20(): + # Correlated back reference example #20: multiple ancestor uniqueness keys. + # Count the instances where a nation's suppliers shipped a part to a + # customer in the same nation, only counting instances where the order was + # made in June of 1998. + # (This is a correlated singular/semi access) + is_domestic = nation(domestic=name == BACK(5).name).domestic + selected_orders = Nations.customers.orders.WHERE( + (YEAR(order_date) == 1998) & (MONTH(order_date) == 6) + ) + instances = selected_orders.lines.supplier.WHERE(is_domestic) + return TPCH(n=COUNT(instances)) + + +def correl_21(): + # Correlated back reference example #21: partition edge case. + # Count how many part sizes have an above-average number of parts + # of that size. + # (This is a correlated aggregation access) + sizes = PARTITION(Parts, name="p", by=size)(n_parts=COUNT(p)) + return TPCH(avg_n_parts=AVG(sizes.n_parts))( + n_sizes=COUNT(sizes.WHERE(n_parts > BACK(1).avg_n_parts)) + ) + + +def correl_22(): + # Correlated back reference example #22: partition edge case. + # Finds the top 5 part sizes with the most container types + # where the average retail price of parts of that container type + # & part type is above the global average retail price. + # (This is a correlated aggregation access) + ct_combos = PARTITION(Parts, name="p", by=(container, part_type))( + avg_price=AVG(p.retail_price) + ) + return ( + TPCH(global_avg_price=AVG(Parts.retail_price)) + .PARTITION( + ct_combos.WHERE(avg_price > BACK(1).global_avg_price), + name="ct", + by=container, + )(container, n_types=COUNT(ct)) + .TOP_K(5, (n_types.DESC(), container.ASC())) + ) + + +def correl_23(): + # Correlated back reference example #23: partition edge case. 
+ # Counts how many part sizes have an above-average number of combinations + # of part types/containers. + # (This is a correlated aggregation access) + combos = PARTITION(Parts, name="p", by=(size, part_type, container)) + sizes = PARTITION(combos, name="c", by=size)(n_combos=COUNT(c)) + return TPCH(avg_n_combo=AVG(sizes.n_combos))( + n_sizes=COUNT(sizes.WHERE(n_combos > BACK(1).avg_n_combo)), + ) diff --git a/tests/simple_pydough_functions.py b/tests/simple_pydough_functions.py index f96e7061..e80d9249 100644 --- a/tests/simple_pydough_functions.py +++ b/tests/simple_pydough_functions.py @@ -1,3 +1,6 @@ +""" +Various functions containing PyDough code snippets for testing purposes. +""" # ruff: noqa # mypy: ignore-errors # ruff & mypy should not try to typecheck or verify any of this @@ -240,6 +243,44 @@ def agg_partition(): return TPCH(best_year=MAX(yearly_data.n_orders)) +def multi_partition_access_1(): + # A use of multiple PARTITION and stepping into partition children that is + # a no-op. + data = Tickers(symbol).TOP_K(5, by=symbol.ASC()) + grps_a = PARTITION(data, name="child_3", by=(currency, exchange, ticker_type)) + grps_b = PARTITION(grps_a, name="child_2", by=(currency, exchange)) + grps_c = PARTITION(grps_b, name="child_1", by=exchange) + return grps_c.child_1.child_2.child_3 + + +def multi_partition_access_2(): + # Identify transactions that are below the average number of shares for + # transactions of the same combinations of (customer, stock, type), or + # the same combination of (customer, stock), or the same customer. 
+ grps_a = PARTITION( + Transactions, name="child_3", by=(customer_id, ticker_id, transaction_type) + )(avg_shares_a=AVG(child_3.shares)) + grps_b = PARTITION(grps_a, name="child_2", by=(customer_id, ticker_id))( + avg_shares_b=AVG(child_2.child_3.shares) + ) + grps_c = PARTITION(grps_b, name="child_1", by=customer_id)( + avg_shares_c=AVG(child_1.child_2.child_3.shares) + ) + return grps_c.child_1.child_2.child_3.WHERE( + (shares < BACK(1).avg_shares_a) + & (shares < BACK(2).avg_shares_b) + & (shares < BACK(3).avg_shares_c) + )( + transaction_id, + customer.name, + ticker.symbol, + transaction_type, + BACK(1).avg_shares_a, + BACK(2).avg_shares_b, + BACK(3).avg_shares_c, + ).ORDER_BY(transaction_id.ASC()) + + def double_partition(): # Doing a partition aggregation on the output of a partition aggregation year_month_data = PARTITION( diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 34475607..6faf383f 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -12,6 +12,31 @@ bad_slice_3, bad_slice_4, ) +from correlated_pydough_functions import ( + correl_1, + correl_2, + correl_3, + correl_4, + correl_5, + correl_6, + correl_7, + correl_8, + correl_9, + correl_10, + correl_11, + correl_12, + correl_13, + correl_14, + correl_15, + correl_16, + correl_17, + correl_18, + correl_19, + correl_20, + correl_21, + correl_22, + correl_23, +) from simple_pydough_functions import ( agg_partition, double_partition, @@ -19,6 +44,8 @@ function_sampler, hour_minute_day, minutes_seconds_datediff, + multi_partition_access_1, + multi_partition_access_2, percentile_customers_per_region, percentile_nations, rank_nations_by_region, @@ -140,7 +167,6 @@ tpch_q5_output, ), id="tpch_q5", - marks=pytest.mark.skip("TODO: support correlated back references"), ), pytest.param( ( @@ -269,7 +295,6 @@ tpch_q21_output, ), id="tpch_q21", - marks=pytest.mark.skip("TODO: support correlated back references"), ), pytest.param( ( @@ -278,7 +303,6 @@ tpch_q22_output, ), 
id="tpch_q22", - marks=pytest.mark.skip("TODO: support correlated back references"), ), pytest.param( ( @@ -630,6 +654,377 @@ ), id="triple_partition", ), + pytest.param( + ( + correl_1, + "correl_1", + lambda: pd.DataFrame( + { + "name": ["AFRICA", "AMERICA", "ASIA", "EUROPE", "MIDDLE EAST"], + "n_prefix_nations": [1, 1, 0, 0, 0], + } + ), + ), + id="correl_1", + ), + pytest.param( + ( + correl_2, + "correl_2", + lambda: pd.DataFrame( + { + "name": [ + "EGYPT", + "FRANCE", + "GERMANY", + "IRAN", + "IRAQ", + "JORDAN", + "ROMANIA", + "RUSSIA", + "SAUDI ARABIA", + "UNITED KINGDOM", + ], + "n_selected_custs": [ + 19, + 593, + 595, + 15, + 21, + 9, + 588, + 620, + 19, + 585, + ], + } + ), + ), + id="correl_2", + ), + pytest.param( + ( + correl_3, + "correl_3", + lambda: pd.DataFrame( + { + "name": ["AFRICA", "AMERICA", "ASIA", "EUROPE", "MIDDLE EAST"], + "n_nations": [5, 5, 5, 0, 2], + } + ), + ), + id="correl_3", + ), + pytest.param( + ( + correl_4, + "correl_4", + lambda: pd.DataFrame( + { + "name": ["ARGENTINA", "KENYA", "UNITED KINGDOM"], + } + ), + ), + id="correl_4", + ), + pytest.param( + ( + correl_5, + "correl_5", + lambda: pd.DataFrame( + { + "name": ["AFRICA", "ASIA", "MIDDLE EAST"], + } + ), + ), + id="correl_5", + ), + pytest.param( + ( + correl_6, + "correl_6", + lambda: pd.DataFrame( + { + "name": ["AFRICA", "AMERICA"], + "n_prefix_nations": [1, 1], + } + ), + ), + id="correl_6", + ), + pytest.param( + ( + correl_7, + "correl_7", + lambda: pd.DataFrame( + { + "name": ["ASIA", "EUROPE", "MIDDLE EAST"], + "n_prefix_nations": [0] * 3, + } + ), + ), + id="correl_7", + ), + pytest.param( + ( + correl_8, + "correl_8", + lambda: pd.DataFrame( + { + "name": [ + "ALGERIA", + "ARGENTINA", + "BRAZIL", + "CANADA", + "CHINA", + "EGYPT", + "ETHIOPIA", + "FRANCE", + "GERMANY", + "INDIA", + "INDONESIA", + "IRAN", + "IRAQ", + "JAPAN", + "JORDAN", + "KENYA", + "MOROCCO", + "MOZAMBIQUE", + "PERU", + "ROMANIA", + "RUSSIA", + "SAUDI ARABIA", + "UNITED KINGDOM", + "UNITED 
STATES", + "VIETNAM", + ], + "rname": ["AFRICA", "AMERICA"] + [None] * 23, + } + ), + ), + id="correl_8", + ), + pytest.param( + ( + correl_9, + "correl_9", + lambda: pd.DataFrame( + { + "name": [ + "ALGERIA", + "ARGENTINA", + ], + "rname": ["AFRICA", "AMERICA"], + } + ), + ), + id="correl_9", + ), + pytest.param( + ( + correl_10, + "correl_10", + lambda: pd.DataFrame( + { + "name": [ + "BRAZIL", + "CANADA", + "CHINA", + "EGYPT", + "ETHIOPIA", + "FRANCE", + "GERMANY", + "INDIA", + "INDONESIA", + "IRAN", + "IRAQ", + "JAPAN", + "JORDAN", + "KENYA", + "MOROCCO", + "MOZAMBIQUE", + "PERU", + "ROMANIA", + "RUSSIA", + "SAUDI ARABIA", + "UNITED KINGDOM", + "UNITED STATES", + "VIETNAM", + ], + "rname": [None] * 23, + } + ), + ), + id="correl_10", + ), + pytest.param( + ( + correl_11, + "correl_11", + lambda: pd.DataFrame( + {"brand": ["Brand#33", "Brand#43", "Brand#45", "Brand#55"]} + ), + ), + id="correl_11", + ), + pytest.param( + ( + correl_12, + "correl_12", + lambda: pd.DataFrame( + { + "brand": [ + "Brand#14", + "Brand#31", + "Brand#33", + "Brand#43", + "Brand#55", + ] + } + ), + ), + id="correl_12", + ), + pytest.param( + ( + correl_13, + "correl_13", + lambda: pd.DataFrame({"n": [1129]}), + ), + id="correl_13", + ), + pytest.param( + ( + correl_14, + "correl_14", + lambda: pd.DataFrame({"n": [66]}), + ), + id="correl_14", + ), + pytest.param( + ( + correl_15, + "correl_15", + lambda: pd.DataFrame({"n": [61]}), + ), + id="correl_15", + ), + pytest.param( + ( + correl_16, + "correl_16", + lambda: pd.DataFrame({"n": [929]}), + ), + id="correl_16", + ), + pytest.param( + ( + correl_17, + "correl_17", + lambda: pd.DataFrame( + { + "fullname": [ + "africa-algeria", + "africa-ethiopia", + "africa-kenya", + "africa-morocco", + "africa-mozambique", + "america-argentina", + "america-brazil", + "america-canada", + "america-peru", + "america-united states", + "asia-china", + "asia-india", + "asia-indonesia", + "asia-japan", + "asia-vietnam", + "europe-france", + 
"europe-germany", + "europe-romania", + "europe-russia", + "europe-united kingdom", + "middle east-egypt", + "middle east-iran", + "middle east-iraq", + "middle east-jordan", + "middle east-saudi arabia", + ] + } + ), + ), + id="correl_17", + ), + pytest.param( + ( + correl_18, + "correl_18", + lambda: pd.DataFrame({"n": [697]}), + ), + id="correl_18", + ), + pytest.param( + ( + correl_19, + "correl_19", + lambda: pd.DataFrame( + { + "name": [ + "Supplier#000003934", + "Supplier#000003887", + "Supplier#000002628", + "Supplier#000008722", + "Supplier#000007971", + ], + "n_super_cust": [6160, 6142, 6129, 6127, 6117], + } + ), + ), + id="correl_19", + ), + pytest.param( + ( + correl_20, + "correl_20", + lambda: pd.DataFrame({"n": [3002]}), + ), + id="correl_20", + ), + pytest.param( + ( + correl_21, + "correl_21", + lambda: pd.DataFrame({"n_sizes": [30]}), + ), + id="correl_21", + ), + pytest.param( + ( + correl_22, + "correl_22", + lambda: pd.DataFrame( + { + "container": [ + "JUMBO DRUM", + "JUMBO PKG", + "MED DRUM", + "SM BAG", + "LG PKG", + ], + "n_types": [89, 86, 81, 81, 80], + } + ), + ), + id="correl_22", + ), + pytest.param( + ( + correl_23, + "correl_23", + lambda: pd.DataFrame({"n_sizes": [23]}), + ), + id="correl_23", + ), ], ) def pydough_pipeline_test_data( @@ -749,6 +1144,40 @@ def test_pipeline_e2e_errors( @pytest.fixture( params=[ + pytest.param( + ( + multi_partition_access_1, + "Broker", + lambda: pd.DataFrame( + {"symbol": ["AAPL", "AMZN", "BRK.B", "FB", "GOOG"]} + ), + ), + id="multi_partition_access_1", + ), + pytest.param( + ( + multi_partition_access_2, + "Broker", + lambda: pd.DataFrame( + { + "transaction_id": [f"TX{i:03}" for i in (22, 24, 25, 27, 56)], + "name": [ + "Jane Smith", + "Samantha Lee", + "Michael Chen", + "David Kim", + "Jane Smith", + ], + "symbol": ["MSFT", "TSLA", "GOOGL", "BRK.B", "FB"], + "transaction_type": ["sell", "sell", "buy", "buy", "sell"], + "avg_shares_a": [56.66667, 55.0, 4.0, 55.5, 47.5], + "avg_shares_b": [50.0, 
41.66667, 3.33333, 37.33333, 47.5], + "avg_shares_c": [50.625, 46.25, 40.0, 37.33333, 50.625], + } + ), + ), + id="multi_partition_access_2", + ), pytest.param( ( hour_minute_day, diff --git a/tests/test_plan_refsols/correl_1.txt b/tests/test_plan_refsols/correl_1.txt new file mode 100644 index 00000000..bcc6d73d --- /dev/null +++ b/tests/test_plan_refsols/correl_1.txt @@ -0,0 +1,10 @@ +ROOT(columns=[('name', name), ('n_prefix_nations', n_prefix_nations)], orderings=[(ordering_1):asc_first]) + PROJECT(columns={'n_prefix_nations': n_prefix_nations, 'name': name, 'ordering_1': name}) + PROJECT(columns={'n_prefix_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name}) + JOIN(conditions=[t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + AGGREGATE(keys={'key': key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(name, None:unknown, 1:int64, None:unknown), columns={'key': key}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) diff --git a/tests/test_plan_refsols/correl_10.txt b/tests/test_plan_refsols/correl_10.txt new file mode 100644 index 00000000..e762421d --- /dev/null +++ b/tests/test_plan_refsols/correl_10.txt @@ -0,0 +1,8 @@ +ROOT(columns=[('name', name), ('rname', rname)], orderings=[(ordering_0):asc_first]) + PROJECT(columns={'name': name, 'ordering_0': name, 'rname': rname}) + PROJECT(columns={'name': name, 'rname': NULL_2}) + FILTER(condition=True:bool, columns={'NULL_2': NULL_2, 'name': name}) + JOIN(conditions=[t0.region_key == t1.key], types=['anti'], columns={'NULL_2': None:unknown, 'name': t0.name}, correl_name='corr1') + SCAN(table=tpch.NATION, columns={'name': n_name, 
'region_key': n_regionkey}) + FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'key': key}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/correl_11.txt b/tests/test_plan_refsols/correl_11.txt new file mode 100644 index 00000000..7e48bb88 --- /dev/null +++ b/tests/test_plan_refsols/correl_11.txt @@ -0,0 +1,9 @@ +ROOT(columns=[('brand', brand)], orderings=[(ordering_1):asc_first]) + PROJECT(columns={'brand': brand, 'ordering_1': brand}) + FILTER(condition=True:bool, columns={'brand': brand}) + JOIN(conditions=[t0.brand == t1.brand], types=['semi'], columns={'brand': t0.brand}, correl_name='corr1') + PROJECT(columns={'avg_price': agg_0, 'brand': brand}) + AGGREGATE(keys={'brand': brand}, aggregations={'agg_0': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'brand': p_brand, 'retail_price': p_retailprice}) + FILTER(condition=retail_price > 1.4:float64 * corr1.avg_price, columns={'brand': brand}) + SCAN(table=tpch.PART, columns={'brand': p_brand, 'retail_price': p_retailprice}) diff --git a/tests/test_plan_refsols/correl_12.txt b/tests/test_plan_refsols/correl_12.txt new file mode 100644 index 00000000..60626462 --- /dev/null +++ b/tests/test_plan_refsols/correl_12.txt @@ -0,0 +1,13 @@ +ROOT(columns=[('brand', brand)], orderings=[(ordering_2):asc_first]) + PROJECT(columns={'brand': brand, 'ordering_2': brand}) + FILTER(condition=True:bool, columns={'brand': brand}) + JOIN(conditions=[t0.brand == t1.brand], types=['semi'], columns={'brand': t0.brand}, correl_name='corr2') + PROJECT(columns={'avg_price': avg_price, 'avg_price_2': agg_1, 'brand': brand}) + JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1, 'avg_price': t0.avg_price, 'brand': t1.brand}) + PROJECT(columns={'avg_price': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(retail_price)}) + SCAN(table=tpch.PART, 
columns={'retail_price': p_retailprice}) + AGGREGATE(keys={'brand': brand}, aggregations={'agg_1': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'brand': p_brand, 'retail_price': p_retailprice}) + FILTER(condition=retail_price > corr2.avg_price_2 & retail_price < corr2.avg_price & size < 3:int64, columns={'brand': brand}) + SCAN(table=tpch.PART, columns={'brand': p_brand, 'retail_price': p_retailprice, 'size': p_size}) diff --git a/tests/test_plan_refsols/correl_13.txt b/tests/test_plan_refsols/correl_13.txt new file mode 100644 index 00000000..bc779fbd --- /dev/null +++ b/tests/test_plan_refsols/correl_13.txt @@ -0,0 +1,18 @@ +ROOT(columns=[('n', n)], orderings=[]) + PROJECT(columns={'n': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': COUNT()}) + FILTER(condition=DEFAULT_TO(agg_1, 0:int64) > 0:int64, columns={'account_balance': account_balance}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'agg_1': t1.agg_1}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'key': t0.key}) + FILTER(condition=nation_key <= 3:int64, columns={'account_balance': account_balance, 'key': key}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={}) + JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.PART, columns={'key': p_partkey}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_1': COUNT()}) + FILTER(condition=True:bool, columns={'supplier_key': supplier_key}) + JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr2') + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 
'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) + FILTER(condition=STARTSWITH(container, 'SM':string) & retail_price < corr2.supplycost * 1.5:float64, columns={'key': key}) + SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'retail_price': p_retailprice}) diff --git a/tests/test_plan_refsols/correl_14.txt b/tests/test_plan_refsols/correl_14.txt new file mode 100644 index 00000000..909904c1 --- /dev/null +++ b/tests/test_plan_refsols/correl_14.txt @@ -0,0 +1,18 @@ +ROOT(columns=[('n', n)], orderings=[]) + PROJECT(columns={'n': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': COUNT()}) + FILTER(condition=True:bool, columns={'account_balance': account_balance}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['semi'], columns={'account_balance': t0.account_balance}, correl_name='corr3') + PROJECT(columns={'account_balance': account_balance, 'avg_price': agg_0, 'key': key}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'agg_0': t1.agg_0, 'key': t0.key}) + FILTER(condition=nation_key == 19:int64, columns={'account_balance': account_balance, 'key': key}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': AVG(retail_price)}) + JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'retail_price': t1.retail_price, 'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.PART, columns={'key': p_partkey, 'retail_price': p_retailprice}) + FILTER(condition=True:bool, columns={'supplier_key': supplier_key}) + JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr2') + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) 
+ FILTER(condition=container == 'LG DRUM':string & retail_price < corr2.supplycost * 1.5:float64 & retail_price < corr3.avg_price, columns={'key': key}) + SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'retail_price': p_retailprice}) diff --git a/tests/test_plan_refsols/correl_15.txt b/tests/test_plan_refsols/correl_15.txt new file mode 100644 index 00000000..b71be80d --- /dev/null +++ b/tests/test_plan_refsols/correl_15.txt @@ -0,0 +1,22 @@ +ROOT(columns=[('n', n)], orderings=[]) + PROJECT(columns={'n': agg_1}) + AGGREGATE(keys={}, aggregations={'agg_1': COUNT()}) + FILTER(condition=True:bool, columns={'account_balance': account_balance}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['semi'], columns={'account_balance': t0.account_balance}, correl_name='corr4') + PROJECT(columns={'account_balance': account_balance, 'avg_price': avg_price, 'avg_price_3': agg_0, 'key': key}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'agg_0': t1.agg_0, 'avg_price': t0.avg_price, 'key': t0.key}) + FILTER(condition=nation_key == 19:int64, columns={'account_balance': account_balance, 'avg_price': avg_price, 'key': key}) + JOIN(conditions=[True:bool], types=['inner'], columns={'account_balance': t1.account_balance, 'avg_price': t0.avg_price, 'key': t1.key, 'nation_key': t1.nation_key}) + PROJECT(columns={'avg_price': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'retail_price': p_retailprice}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': AVG(retail_price)}) + JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'retail_price': t1.retail_price, 'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + 
SCAN(table=tpch.PART, columns={'key': p_partkey, 'retail_price': p_retailprice}) + FILTER(condition=True:bool, columns={'supplier_key': supplier_key}) + JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr3') + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) + FILTER(condition=container == 'LG DRUM':string & retail_price < corr3.supplycost * 1.5:float64 & retail_price < corr4.avg_price_3 & retail_price < corr4.avg_price * 0.85:float64, columns={'key': key}) + SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'retail_price': p_retailprice}) diff --git a/tests/test_plan_refsols/correl_16.txt b/tests/test_plan_refsols/correl_16.txt new file mode 100644 index 00000000..f8c2e4eb --- /dev/null +++ b/tests/test_plan_refsols/correl_16.txt @@ -0,0 +1,14 @@ +ROOT(columns=[('n', n)], orderings=[]) + PROJECT(columns={'n': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': COUNT()}) + FILTER(condition=True:bool, columns={'account_balance': account_balance}) + JOIN(conditions=[t0.nation_key == t1.key], types=['semi'], columns={'account_balance': t0.account_balance}, correl_name='corr7') + PROJECT(columns={'account_balance': account_balance, 'nation_key': nation_key, 'tile': PERCENTILE(args=[], partition=[], order=[(account_balance):asc_last, (key):asc_last], n_buckets=10000)}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) + FILTER(condition=PERCENTILE(args=[], partition=[], order=[(acctbal):asc_last, (key_5):asc_last], n_buckets=10000) == corr7.tile & rname == 'EUROPE':string, columns={'key': key}) + JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'acctbal': t1.acctbal, 'key': t0.key, 'key_5': t1.key, 'rname': t0.rname}) + PROJECT(columns={'key': key, 'rname': name_3}) + JOIN(conditions=[t0.region_key == t1.key], types=['left'], 
columns={'key': t0.key, 'name_3': t1.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'key': c_custkey, 'nation_key': c_nationkey}) diff --git a/tests/test_plan_refsols/correl_17.txt b/tests/test_plan_refsols/correl_17.txt new file mode 100644 index 00000000..aad7c616 --- /dev/null +++ b/tests/test_plan_refsols/correl_17.txt @@ -0,0 +1,11 @@ +ROOT(columns=[('fullname', fullname)], orderings=[(ordering_0):asc_first]) + PROJECT(columns={'fullname': fullname, 'ordering_0': fullname}) + PROJECT(columns={'fullname': fname}) + FILTER(condition=True:bool, columns={'fname': fname}) + JOIN(conditions=[t0.key == t1.key], types=['inner'], columns={'fname': t1.fname}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey}) + PROJECT(columns={'fname': JOIN_STRINGS('-':string, LOWER(name_3), lname), 'key': key}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'lname': t0.lname, 'name_3': t1.name}) + PROJECT(columns={'key': key, 'lname': LOWER(name), 'region_key': region_key}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/correl_18.txt b/tests/test_plan_refsols/correl_18.txt new file mode 100644 index 00000000..74fab0da --- /dev/null +++ b/tests/test_plan_refsols/correl_18.txt @@ -0,0 +1,19 @@ +ROOT(columns=[('n', n)], orderings=[]) + PROJECT(columns={'n': DEFAULT_TO(agg_0, 0:int64)}) + AGGREGATE(keys={}, aggregations={'agg_0': SUM(n_above_avg)}) + PROJECT(columns={'n_above_avg': DEFAULT_TO(agg_2, 0:int64)}) + JOIN(conditions=[t0.customer_key == t1.customer_key & t0.order_date == t1.order_date], types=['left'], columns={'agg_2': t1.agg_2}) + FILTER(condition=DEFAULT_TO(agg_0, 0:int64) > 1:int64, 
columns={'customer_key': customer_key, 'order_date': order_date}) + AGGREGATE(keys={'customer_key': customer_key, 'order_date': order_date}, aggregations={'agg_0': COUNT()}) + FILTER(condition=YEAR(order_date) == 1993:int64, columns={'customer_key': customer_key, 'order_date': order_date}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'order_date': o_orderdate}) + AGGREGATE(keys={'customer_key': customer_key, 'order_date': order_date}, aggregations={'agg_2': COUNT()}) + FILTER(condition=total_price_3 >= 0.5:float64 * total_price, columns={'customer_key': customer_key, 'order_date': order_date}) + FILTER(condition=YEAR(order_date_2) == 1993:int64, columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': total_price, 'total_price_3': total_price_3}) + JOIN(conditions=[t0.customer_key == t1.customer_key & t0.order_date == t1.order_date], types=['inner'], columns={'customer_key': t0.customer_key, 'order_date': t0.order_date, 'order_date_2': t1.order_date, 'total_price': t0.total_price, 'total_price_3': t1.total_price}) + PROJECT(columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': DEFAULT_TO(agg_1, 0:int64)}) + FILTER(condition=DEFAULT_TO(agg_0, 0:int64) > 1:int64, columns={'agg_1': agg_1, 'customer_key': customer_key, 'order_date': order_date}) + AGGREGATE(keys={'customer_key': customer_key, 'order_date': order_date}, aggregations={'agg_0': COUNT(), 'agg_1': SUM(total_price)}) + FILTER(condition=YEAR(order_date) == 1993:int64, columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': total_price}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'order_date': o_orderdate, 'total_price': o_totalprice}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'order_date': o_orderdate, 'total_price': o_totalprice}) diff --git a/tests/test_plan_refsols/correl_19.txt b/tests/test_plan_refsols/correl_19.txt new file mode 100644 index 00000000..ddaa017a --- /dev/null 
+++ b/tests/test_plan_refsols/correl_19.txt @@ -0,0 +1,15 @@ +ROOT(columns=[('name', name_0), ('n_super_cust', n_super_cust)], orderings=[(ordering_1):desc_last]) + LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'n_super_cust': n_super_cust, 'name_0': name_0, 'ordering_1': ordering_1}, orderings=[(ordering_1):desc_last]) + PROJECT(columns={'n_super_cust': n_super_cust, 'name_0': name_0, 'ordering_1': n_super_cust}) + PROJECT(columns={'n_super_cust': DEFAULT_TO(agg_0, 0:int64), 'name_0': name}) + JOIN(conditions=[t0.key_2 == t1.key_5 & t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name': t0.name}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey}) + AGGREGATE(keys={'key': key, 'key_5': key_5}, aggregations={'agg_0': COUNT()}) + FILTER(condition=acctbal > account_balance, columns={'key': key, 'key_5': key_5}) + JOIN(conditions=[t0.key_5 == t1.nation_key], types=['inner'], columns={'account_balance': t0.account_balance, 'acctbal': t1.acctbal, 'key': t0.key, 'key_5': t0.key_5}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'account_balance': t0.account_balance, 'key': t0.key, 'key_5': t1.key}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': c_nationkey}) diff --git a/tests/test_plan_refsols/correl_2.txt b/tests/test_plan_refsols/correl_2.txt new file mode 100644 index 00000000..b5d64b7c --- /dev/null +++ b/tests/test_plan_refsols/correl_2.txt @@ -0,0 +1,17 @@ +ROOT(columns=[('name', name_12), ('n_selected_custs', n_selected_custs)], orderings=[(ordering_1):asc_first]) + PROJECT(columns={'n_selected_custs': 
n_selected_custs, 'name_12': name_11, 'ordering_1': ordering_1}) + PROJECT(columns={'n_selected_custs': n_selected_custs, 'name_11': name_11, 'ordering_1': name_11}) + PROJECT(columns={'n_selected_custs': DEFAULT_TO(agg_0, 0:int64), 'name_11': name_3}) + JOIN(conditions=[t0.key_2 == t1.key_5 & t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name_3': t0.name_3}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name_3': t1.name}) + FILTER(condition=NOT(STARTSWITH(name, 'A':string)), columns={'key': key}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + AGGREGATE(keys={'key': key, 'key_5': key_5}, aggregations={'agg_0': COUNT()}) + FILTER(condition=SLICE(comment_7, None:unknown, 1:int64, None:unknown) == LOWER(SLICE(name, None:unknown, 1:int64, None:unknown)), columns={'key': key, 'key_5': key_5}) + JOIN(conditions=[t0.key_5 == t1.nation_key], types=['inner'], columns={'comment_7': t1.comment, 'key': t0.key, 'key_5': t0.key_5, 'name': t0.name}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_5': t1.key, 'name': t0.name}) + FILTER(condition=NOT(STARTSWITH(name, 'A':string)), columns={'key': key, 'name': name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.CUSTOMER, columns={'comment': c_comment, 'nation_key': c_nationkey}) diff --git a/tests/test_plan_refsols/correl_20.txt b/tests/test_plan_refsols/correl_20.txt new file mode 100644 index 00000000..66ed3e35 --- /dev/null +++ b/tests/test_plan_refsols/correl_20.txt @@ -0,0 +1,28 @@ +ROOT(columns=[('n', n)], orderings=[]) + PROJECT(columns={'n': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': COUNT()}) + FILTER(condition=domestic, columns={'account_balance': 
account_balance}) + JOIN(conditions=[t0.key_9 == t1.key_21 & t0.line_number == t1.line_number & t0.order_key == t1.order_key & t0.key_5 == t1.key_17 & t0.key_2 == t1.key_14 & t0.key == t1.key], types=['left'], columns={'account_balance': t0.account_balance, 'domestic': t1.domestic}) + JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'account_balance': t1.account_balance, 'key': t0.key, 'key_2': t0.key_2, 'key_5': t0.key_5, 'key_9': t1.key, 'line_number': t0.line_number, 'order_key': t0.order_key}) + JOIN(conditions=[t0.key_5 == t1.order_key], types=['inner'], columns={'key': t0.key, 'key_2': t0.key_2, 'key_5': t0.key_5, 'line_number': t1.line_number, 'order_key': t1.order_key, 'supplier_key': t1.supplier_key}) + FILTER(condition=YEAR(order_date) == 1998:int64 & MONTH(order_date) == 6:int64, columns={'key': key, 'key_2': key_2, 'key_5': key_5}) + JOIN(conditions=[t0.key_2 == t1.customer_key], types=['inner'], columns={'key': t0.key, 'key_2': t0.key_2, 'key_5': t1.key, 'order_date': t1.order_date}) + JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate}) + SCAN(table=tpch.LINEITEM, columns={'line_number': l_linenumber, 'order_key': l_orderkey, 'supplier_key': l_suppkey}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey}) + PROJECT(columns={'domestic': name_27 == name, 'key': key, 'key_14': key_14, 'key_17': key_17, 'key_21': key_21, 'line_number': line_number, 'order_key': order_key}) + JOIN(conditions=[t0.nation_key_23 == t1.key], types=['inner'], columns={'key': t0.key, 'key_14': t0.key_14, 'key_17': t0.key_17, 'key_21': t0.key_21, 'line_number': t0.line_number, 'name': t0.name, 'name_27': t1.name, 'order_key': 
t0.order_key}) + JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'key': t0.key, 'key_14': t0.key_14, 'key_17': t0.key_17, 'key_21': t1.key, 'line_number': t0.line_number, 'name': t0.name, 'nation_key_23': t1.nation_key, 'order_key': t0.order_key}) + JOIN(conditions=[t0.key_17 == t1.order_key], types=['inner'], columns={'key': t0.key, 'key_14': t0.key_14, 'key_17': t0.key_17, 'line_number': t1.line_number, 'name': t0.name, 'order_key': t1.order_key, 'supplier_key': t1.supplier_key}) + FILTER(condition=YEAR(order_date) == 1998:int64 & MONTH(order_date) == 6:int64, columns={'key': key, 'key_14': key_14, 'key_17': key_17, 'name': name}) + JOIN(conditions=[t0.key_14 == t1.customer_key], types=['inner'], columns={'key': t0.key, 'key_14': t0.key_14, 'key_17': t1.key, 'name': t0.name, 'order_date': t1.order_date}) + JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'key': t0.key, 'key_14': t1.key, 'name': t0.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate}) + SCAN(table=tpch.LINEITEM, columns={'line_number': l_linenumber, 'order_key': l_orderkey, 'supplier_key': l_suppkey}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) diff --git a/tests/test_plan_refsols/correl_21.txt b/tests/test_plan_refsols/correl_21.txt new file mode 100644 index 00000000..24c781d1 --- /dev/null +++ b/tests/test_plan_refsols/correl_21.txt @@ -0,0 +1,13 @@ +ROOT(columns=[('n_sizes', n_sizes)], orderings=[]) + PROJECT(columns={'n_sizes': agg_1}) + AGGREGATE(keys={}, aggregations={'agg_1': COUNT()}) + FILTER(condition=n_parts > avg_n_parts, columns={'agg_0': agg_0}) + PROJECT(columns={'agg_0': agg_0, 'avg_n_parts': avg_n_parts, 'n_parts': 
DEFAULT_TO(agg_0, 0:int64)}) + JOIN(conditions=[True:bool], types=['left'], columns={'agg_0': t1.agg_0, 'avg_n_parts': t0.avg_n_parts}) + PROJECT(columns={'avg_n_parts': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(n_parts)}) + PROJECT(columns={'n_parts': DEFAULT_TO(agg_0, 0:int64)}) + AGGREGATE(keys={'size': size}, aggregations={'agg_0': COUNT()}) + SCAN(table=tpch.PART, columns={'size': p_size}) + AGGREGATE(keys={'size': size}, aggregations={'agg_0': COUNT()}) + SCAN(table=tpch.PART, columns={'size': p_size}) diff --git a/tests/test_plan_refsols/correl_22.txt b/tests/test_plan_refsols/correl_22.txt new file mode 100644 index 00000000..4c5be854 --- /dev/null +++ b/tests/test_plan_refsols/correl_22.txt @@ -0,0 +1,13 @@ +ROOT(columns=[('container', container), ('n_types', n_types)], orderings=[(ordering_2):desc_last, (ordering_3):asc_first]) + LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'container': container, 'n_types': n_types, 'ordering_2': ordering_2, 'ordering_3': ordering_3}, orderings=[(ordering_2):desc_last, (ordering_3):asc_first]) + PROJECT(columns={'container': container, 'n_types': n_types, 'ordering_2': n_types, 'ordering_3': container}) + PROJECT(columns={'container': container, 'n_types': DEFAULT_TO(agg_1, 0:int64)}) + AGGREGATE(keys={'container': container}, aggregations={'agg_1': COUNT()}) + FILTER(condition=avg_price > global_avg_price, columns={'container': container}) + PROJECT(columns={'avg_price': agg_0, 'container': container, 'global_avg_price': global_avg_price}) + JOIN(conditions=[True:bool], types=['left'], columns={'agg_0': t1.agg_0, 'container': t1.container, 'global_avg_price': t0.global_avg_price}) + PROJECT(columns={'global_avg_price': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'retail_price': p_retailprice}) + AGGREGATE(keys={'container': container, 'part_type': part_type}, aggregations={'agg_0': AVG(retail_price)}) + SCAN(table=tpch.PART, 
columns={'container': p_container, 'part_type': p_type, 'retail_price': p_retailprice}) diff --git a/tests/test_plan_refsols/correl_23.txt b/tests/test_plan_refsols/correl_23.txt new file mode 100644 index 00000000..5bbadbc9 --- /dev/null +++ b/tests/test_plan_refsols/correl_23.txt @@ -0,0 +1,15 @@ +ROOT(columns=[('n_sizes', n_sizes)], orderings=[]) + PROJECT(columns={'n_sizes': agg_1}) + AGGREGATE(keys={}, aggregations={'agg_1': COUNT()}) + FILTER(condition=n_combos > avg_n_combo, columns={'agg_0': agg_0}) + PROJECT(columns={'agg_0': agg_0, 'avg_n_combo': avg_n_combo, 'n_combos': DEFAULT_TO(agg_0, 0:int64)}) + JOIN(conditions=[True:bool], types=['left'], columns={'agg_0': t1.agg_0, 'avg_n_combo': t0.avg_n_combo}) + PROJECT(columns={'avg_n_combo': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(n_combos)}) + PROJECT(columns={'n_combos': DEFAULT_TO(agg_0, 0:int64)}) + AGGREGATE(keys={'size': size}, aggregations={'agg_0': COUNT()}) + AGGREGATE(keys={'container': container, 'part_type': part_type, 'size': size}, aggregations={}) + SCAN(table=tpch.PART, columns={'container': p_container, 'part_type': p_type, 'size': p_size}) + AGGREGATE(keys={'size': size}, aggregations={'agg_0': COUNT()}) + AGGREGATE(keys={'container': container, 'part_type': part_type, 'size': size}, aggregations={}) + SCAN(table=tpch.PART, columns={'container': p_container, 'part_type': p_type, 'size': p_size}) diff --git a/tests/test_plan_refsols/correl_3.txt b/tests/test_plan_refsols/correl_3.txt new file mode 100644 index 00000000..2bbb01bc --- /dev/null +++ b/tests/test_plan_refsols/correl_3.txt @@ -0,0 +1,13 @@ +ROOT(columns=[('name', name), ('n_nations', n_nations)], orderings=[(ordering_1):asc_first]) + PROJECT(columns={'n_nations': n_nations, 'name': name, 'ordering_1': name}) + PROJECT(columns={'n_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name}) + JOIN(conditions=[t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}) + SCAN(table=tpch.REGION, 
columns={'key': r_regionkey, 'name': r_name}) + AGGREGATE(keys={'key': key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=True:bool, columns={'key': key}) + JOIN(conditions=[t0.key_2 == t1.nation_key], types=['semi'], columns={'key': t0.key}, correl_name='corr4') + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name': t0.name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + FILTER(condition=SLICE(comment, None:unknown, 2:int64, None:unknown) == LOWER(SLICE(corr4.name, None:unknown, 2:int64, None:unknown)), columns={'nation_key': nation_key}) + SCAN(table=tpch.CUSTOMER, columns={'comment': c_comment, 'nation_key': c_nationkey}) diff --git a/tests/test_plan_refsols/correl_4.txt b/tests/test_plan_refsols/correl_4.txt new file mode 100644 index 00000000..38feea33 --- /dev/null +++ b/tests/test_plan_refsols/correl_4.txt @@ -0,0 +1,11 @@ +ROOT(columns=[('name', name)], orderings=[(ordering_1):asc_first]) + PROJECT(columns={'name': name, 'ordering_1': name}) + FILTER(condition=True:bool, columns={'name': name}) + JOIN(conditions=[t0.key == t1.nation_key], types=['anti'], columns={'name': t0.name}, correl_name='corr1') + JOIN(conditions=[True:bool], types=['inner'], columns={'key': t1.key, 'name': t1.name, 'smallest_bal': t0.smallest_bal}) + PROJECT(columns={'smallest_bal': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': MIN(acctbal)}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + FILTER(condition=acctbal <= corr1.smallest_bal + 5.0:float64, columns={'nation_key': nation_key}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': c_nationkey}) diff --git a/tests/test_plan_refsols/correl_5.txt b/tests/test_plan_refsols/correl_5.txt new file mode 100644 index 00000000..73d793f0 --- /dev/null +++ 
b/tests/test_plan_refsols/correl_5.txt @@ -0,0 +1,13 @@ +ROOT(columns=[('name', name)], orderings=[(ordering_1):asc_first]) + PROJECT(columns={'name': name, 'ordering_1': name}) + FILTER(condition=True:bool, columns={'name': name}) + JOIN(conditions=[t0.key == t1.region_key], types=['semi'], columns={'name': t0.name}, correl_name='corr4') + JOIN(conditions=[True:bool], types=['inner'], columns={'key': t1.key, 'name': t1.name, 'smallest_bal': t0.smallest_bal}) + PROJECT(columns={'smallest_bal': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': MIN(account_balance)}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + FILTER(condition=account_balance <= corr4.smallest_bal + 4.0:float64, columns={'region_key': region_key}) + JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'account_balance': t1.account_balance, 'region_key': t0.region_key}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'nation_key': s_nationkey}) diff --git a/tests/test_plan_refsols/correl_6.txt b/tests/test_plan_refsols/correl_6.txt new file mode 100644 index 00000000..a6829877 --- /dev/null +++ b/tests/test_plan_refsols/correl_6.txt @@ -0,0 +1,10 @@ +ROOT(columns=[('name', name), ('n_prefix_nations', n_prefix_nations)], orderings=[]) + PROJECT(columns={'n_prefix_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name}) + FILTER(condition=True:bool, columns={'agg_0': agg_0, 'name': name}) + JOIN(conditions=[t0.key == t1.key], types=['inner'], columns={'agg_0': t1.agg_0, 'name': t0.name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + AGGREGATE(keys={'key': key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(name, None:unknown, 1:int64, None:unknown), columns={'key': key}) + JOIN(conditions=[t0.key == 
t1.region_key], types=['inner'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) diff --git a/tests/test_plan_refsols/correl_7.txt b/tests/test_plan_refsols/correl_7.txt new file mode 100644 index 00000000..94a129af --- /dev/null +++ b/tests/test_plan_refsols/correl_7.txt @@ -0,0 +1,7 @@ +ROOT(columns=[('name', name), ('n_prefix_nations', n_prefix_nations)], orderings=[]) + PROJECT(columns={'n_prefix_nations': DEFAULT_TO(NULL_2, 0:int64), 'name': name}) + FILTER(condition=True:bool, columns={'NULL_2': NULL_2, 'name': name}) + JOIN(conditions=[t0.key == t1.region_key], types=['anti'], columns={'NULL_2': None:unknown, 'name': t0.name}, correl_name='corr1') + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'region_key': region_key}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) diff --git a/tests/test_plan_refsols/correl_8.txt b/tests/test_plan_refsols/correl_8.txt new file mode 100644 index 00000000..8da228f0 --- /dev/null +++ b/tests/test_plan_refsols/correl_8.txt @@ -0,0 +1,9 @@ +ROOT(columns=[('name', name), ('rname', rname)], orderings=[(ordering_0):asc_first]) + PROJECT(columns={'name': name, 'ordering_0': name, 'rname': rname}) + PROJECT(columns={'name': name, 'rname': name_3}) + JOIN(conditions=[t0.key == t1.key], types=['left'], columns={'name': t0.name, 'name_3': t1.name_3}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(name, None:unknown, 1:int64, None:unknown), columns={'key': key, 'name_3': name_3}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name': t0.name, 
'name_3': t1.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/correl_9.txt b/tests/test_plan_refsols/correl_9.txt new file mode 100644 index 00000000..449cceea --- /dev/null +++ b/tests/test_plan_refsols/correl_9.txt @@ -0,0 +1,10 @@ +ROOT(columns=[('name', name), ('rname', rname)], orderings=[(ordering_0):asc_first]) + PROJECT(columns={'name': name, 'ordering_0': name, 'rname': rname}) + PROJECT(columns={'name': name, 'rname': name_3}) + FILTER(condition=True:bool, columns={'name': name, 'name_3': name_3}) + JOIN(conditions=[t0.key == t1.key], types=['inner'], columns={'name': t0.name, 'name_3': t1.name_3}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(name, None:unknown, 1:int64, None:unknown), columns={'key': key, 'name_3': name_3}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/join_regions_nations_calc_override.txt b/tests/test_plan_refsols/join_regions_nations_calc_override.txt index 97156430..e1cce36c 100644 --- a/tests/test_plan_refsols/join_regions_nations_calc_override.txt +++ b/tests/test_plan_refsols/join_regions_nations_calc_override.txt @@ -1,6 +1,6 @@ -ROOT(columns=[('key', key_0_9), ('name', name_11), ('phone', phone), ('mktsegment', mktsegment)], orderings=[]) - PROJECT(columns={'key_0_9': key_0_9, 'mktsegment': mktsegment, 'name_11': name_10, 'phone': phone}) - PROJECT(columns={'key_0_9': -3:int64, 'mktsegment': mktsegment, 'name_10': name_7, 'phone': phone}) +ROOT(columns=[('key', key_0_10), ('name', 
name_11), ('phone', phone), ('mktsegment', mktsegment)], orderings=[]) + PROJECT(columns={'key_0_10': key_0_10, 'mktsegment': mktsegment, 'name_11': name_9, 'phone': phone}) + PROJECT(columns={'key_0_10': -3:int64, 'mktsegment': mktsegment, 'name_9': name_7, 'phone': phone}) JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'mktsegment': t1.mktsegment, 'name_7': t1.name, 'phone': t1.phone}) JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': t1.key}) SCAN(table=tpch.REGION, columns={'key': r_regionkey}) diff --git a/tests/test_plan_refsols/tpch_q21.txt b/tests/test_plan_refsols/tpch_q21.txt new file mode 100644 index 00000000..fdf99273 --- /dev/null +++ b/tests/test_plan_refsols/tpch_q21.txt @@ -0,0 +1,21 @@ +ROOT(columns=[('S_NAME', S_NAME), ('NUMWAIT', NUMWAIT)], orderings=[(ordering_1):desc_last, (ordering_2):asc_first]) + LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'NUMWAIT': NUMWAIT, 'S_NAME': S_NAME, 'ordering_1': ordering_1, 'ordering_2': ordering_2}, orderings=[(ordering_1):desc_last, (ordering_2):asc_first]) + PROJECT(columns={'NUMWAIT': NUMWAIT, 'S_NAME': S_NAME, 'ordering_1': NUMWAIT, 'ordering_2': S_NAME}) + PROJECT(columns={'NUMWAIT': DEFAULT_TO(agg_0, 0:int64), 'S_NAME': name}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}) + FILTER(condition=name_3 == 'SAUDI ARABIA':string, columns={'key': key, 'name': name}) + JOIN(conditions=[t0.nation_key == t1.key], types=['left'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=order_status == 'F':string & True:bool & True:bool, columns={'supplier_key': supplier_key}) + JOIN(conditions=[t0.key == t1.order_key], 
types=['anti'], columns={'order_status': t0.order_status, 'supplier_key': t0.supplier_key}, correl_name='corr6') + JOIN(conditions=[t0.key == t1.order_key], types=['semi'], columns={'key': t0.key, 'order_status': t0.order_status, 'supplier_key': t0.supplier_key}, correl_name='corr5') + JOIN(conditions=[t0.order_key == t1.key], types=['inner'], columns={'key': t1.key, 'order_status': t1.order_status, 'supplier_key': t0.supplier_key}) + FILTER(condition=receipt_date > commit_date, columns={'order_key': order_key, 'supplier_key': supplier_key}) + SCAN(table=tpch.LINEITEM, columns={'commit_date': l_commitdate, 'order_key': l_orderkey, 'receipt_date': l_receiptdate, 'supplier_key': l_suppkey}) + SCAN(table=tpch.ORDERS, columns={'key': o_orderkey, 'order_status': o_orderstatus}) + FILTER(condition=supplier_key != corr5.supplier_key, columns={'order_key': order_key}) + SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'supplier_key': l_suppkey}) + FILTER(condition=supplier_key != corr6.supplier_key & receipt_date > commit_date, columns={'order_key': order_key}) + SCAN(table=tpch.LINEITEM, columns={'commit_date': l_commitdate, 'order_key': l_orderkey, 'receipt_date': l_receiptdate, 'supplier_key': l_suppkey}) diff --git a/tests/test_plan_refsols/tpch_q22.txt b/tests/test_plan_refsols/tpch_q22.txt new file mode 100644 index 00000000..01b9f860 --- /dev/null +++ b/tests/test_plan_refsols/tpch_q22.txt @@ -0,0 +1,18 @@ +ROOT(columns=[('CNTRY_CODE', CNTRY_CODE), ('NUM_CUSTS', NUM_CUSTS), ('TOTACCTBAL', TOTACCTBAL)], orderings=[(ordering_3):asc_first]) + PROJECT(columns={'CNTRY_CODE': CNTRY_CODE, 'NUM_CUSTS': NUM_CUSTS, 'TOTACCTBAL': TOTACCTBAL, 'ordering_3': CNTRY_CODE}) + PROJECT(columns={'CNTRY_CODE': cntry_code, 'NUM_CUSTS': DEFAULT_TO(agg_1, 0:int64), 'TOTACCTBAL': DEFAULT_TO(agg_2, 0:int64)}) + AGGREGATE(keys={'cntry_code': cntry_code}, aggregations={'agg_1': COUNT(), 'agg_2': SUM(acctbal)}) + FILTER(condition=acctbal > avg_balance & DEFAULT_TO(agg_0, 0:int64) == 
0:int64, columns={'acctbal': acctbal, 'cntry_code': cntry_code}) + JOIN(conditions=[t0.key == t1.customer_key], types=['left'], columns={'acctbal': t0.acctbal, 'agg_0': t1.agg_0, 'avg_balance': t0.avg_balance, 'cntry_code': t0.cntry_code}) + FILTER(condition=ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17']:array[unknown]), columns={'acctbal': acctbal, 'avg_balance': avg_balance, 'cntry_code': cntry_code, 'key': key}) + PROJECT(columns={'acctbal': acctbal, 'avg_balance': avg_balance, 'cntry_code': SLICE(phone, None:unknown, 2:int64, None:unknown), 'key': key}) + JOIN(conditions=[True:bool], types=['inner'], columns={'acctbal': t1.acctbal, 'avg_balance': t0.avg_balance, 'key': t1.key, 'phone': t1.phone}) + PROJECT(columns={'avg_balance': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(acctbal)}) + FILTER(condition=acctbal > 0.0:float64, columns={'acctbal': acctbal}) + FILTER(condition=ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17']:array[unknown]), columns={'acctbal': acctbal}) + PROJECT(columns={'acctbal': acctbal, 'cntry_code': SLICE(phone, None:unknown, 2:int64, None:unknown)}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'phone': c_phone}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'key': c_custkey, 'phone': c_phone}) + AGGREGATE(keys={'customer_key': customer_key}, aggregations={'agg_0': COUNT()}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey}) diff --git a/tests/test_plan_refsols/tpch_q5.txt b/tests/test_plan_refsols/tpch_q5.txt new file mode 100644 index 00000000..7cb9925b --- /dev/null +++ b/tests/test_plan_refsols/tpch_q5.txt @@ -0,0 +1,26 @@ +ROOT(columns=[('N_NAME', N_NAME), ('REVENUE', REVENUE)], orderings=[(ordering_1):desc_last]) + PROJECT(columns={'N_NAME': N_NAME, 'REVENUE': REVENUE, 'ordering_1': REVENUE}) + PROJECT(columns={'N_NAME': name, 'REVENUE': DEFAULT_TO(agg_0, 0:int64)}) + JOIN(conditions=[t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name': 
t0.name}) + FILTER(condition=name_3 == 'ASIA':string, columns={'key': key, 'name': name}) + JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + AGGREGATE(keys={'key': key}, aggregations={'agg_0': SUM(value)}) + PROJECT(columns={'key': key, 'value': extended_price * 1:int64 - discount}) + FILTER(condition=name_15 == name, columns={'discount': discount, 'extended_price': extended_price, 'key': key}) + JOIN(conditions=[t0.supplier_key == t1.key], types=['left'], columns={'discount': t0.discount, 'extended_price': t0.extended_price, 'key': t0.key, 'name': t0.name, 'name_15': t1.name_15}) + JOIN(conditions=[t0.key_11 == t1.order_key], types=['inner'], columns={'discount': t1.discount, 'extended_price': t1.extended_price, 'key': t0.key, 'name': t0.name, 'supplier_key': t1.supplier_key}) + FILTER(condition=order_date >= datetime.date(1994, 1, 1):date & order_date < datetime.date(1995, 1, 1):date, columns={'key': key, 'key_11': key_11, 'name': name}) + JOIN(conditions=[t0.key_8 == t1.customer_key], types=['inner'], columns={'key': t0.key, 'key_11': t1.key, 'name': t0.name, 'order_date': t1.order_date}) + JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'key': t0.key, 'key_8': t1.key, 'name': t0.name}) + FILTER(condition=name_6 == 'ASIA':string, columns={'key': key, 'name': name}) + JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'key': t0.key, 'name': t0.name, 'name_6': t1.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': 
o_orderkey, 'order_date': o_orderdate}) + SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'order_key': l_orderkey, 'supplier_key': l_suppkey}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_15': t1.name}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) diff --git a/tests/test_qualification.py b/tests/test_qualification.py index 6e25259d..eab7c54b 100644 --- a/tests/test_qualification.py +++ b/tests/test_qualification.py @@ -68,6 +68,26 @@ def pydough_impl_misc_02(root: UnqualifiedNode) -> UnqualifiedNode: ) +def pydough_impl_misc_03(root: UnqualifiedNode) -> UnqualifiedNode: + """ + Creates an UnqualifiedNode for the following PyDough snippet: + ``` + sizes = PARTITION(Parts, name="p", by=size)(n_parts=COUNT(p)) + TPCH( + avg_n_parts=AVG(sizes.n_parts) + )( + n_parts=COUNT(sizes.WHERE(n_parts > BACK(1).avg_n_parts)) + ) + ``` + """ + sizes = root.PARTITION(root.Parts, name="p", by=root.size)( + n_parts=root.COUNT(root.p) + ) + return root.TPCH(avg_n_parts=root.AVG(sizes.n_parts))( + n_parts=root.COUNT(sizes.WHERE(root.n_parts > root.BACK(1).avg_n_parts)) + ) + + def pydough_impl_tpch_q1(root: UnqualifiedNode) -> UnqualifiedNode: """ Creates an UnqualifiedNode for TPC-H query 1. 
@@ -612,6 +632,30 @@ def pydough_impl_tpch_q22(root: UnqualifiedNode) -> UnqualifiedNode: """, id="misc_02", ), + pytest.param( + pydough_impl_misc_03, + """ +┌─── TPCH +├─┬─ Calc[avg_n_parts=AVG($1.n_parts)] +│ └─┬─ AccessChild +│ ├─┬─ Partition[name='p', by=size] +│ │ └─┬─ AccessChild +│ │ └─── TableCollection[Parts] +│ └─┬─ Calc[n_parts=COUNT($1)] +│ └─┬─ AccessChild +│ └─── PartitionChild[p] +└─┬─ Calc[n_parts=COUNT($1)] + └─┬─ AccessChild + ├─┬─ Partition[name='p', by=size] + │ └─┬─ AccessChild + │ └─── TableCollection[Parts] + ├─┬─ Calc[n_parts=COUNT($1)] + │ └─┬─ AccessChild + │ └─── PartitionChild[p] + └─── Where[n_parts > BACK(1).avg_n_parts] +""", + id="misc_03", + ), pytest.param( pydough_impl_tpch_q1, """ diff --git a/tests/test_relational_nodes_to_sqlglot.py b/tests/test_relational_nodes_to_sqlglot.py index 7859c88f..6c3c04cd 100644 --- a/tests/test_relational_nodes_to_sqlglot.py +++ b/tests/test_relational_nodes_to_sqlglot.py @@ -152,6 +152,8 @@ def mkglot(expressions: list[Expression], _from: Expression, **kwargs) -> Select query = query.where(kwargs.pop("where")) if "group_by" in kwargs: query = query.group_by(*kwargs.pop("group_by")) + if kwargs.pop("distinct", False): + query = query.distinct() if "qualify" in kwargs: query = query.qualify(kwargs.pop("qualify")) if "order_by" in kwargs: @@ -627,14 +629,15 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: Aggregate( input=build_simple_scan(), keys={ + "a": make_relational_column_reference("a"), "b": make_relational_column_reference("b"), }, aggregations={}, ), mkglot( - expressions=[Ident(this="b")], + expressions=[Ident(this="a"), Ident(this="b")], _from=GlotFrom(Table(this=Ident(this="table"))), - group_by=[Ident(this="b")], + distinct=True, ), id="simple_distinct", ), @@ -718,7 +721,7 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: mkglot( expressions=[Ident(this="b")], where=mkglot_func(EQ, [Ident(this="a"), mk_literal(1, False)]), - 
group_by=[Ident(this="b")], + distinct=True, _from=GlotFrom( mkglot( expressions=[Ident(this="a"), Ident(this="b")], @@ -794,7 +797,7 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: ), mkglot( expressions=[Ident(this="b")], - group_by=[Ident(this="b")], + distinct=True, _from=GlotFrom( mkglot( expressions=[Ident(this="a"), Ident(this="b")], @@ -829,7 +832,7 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: mkglot( expressions=[Ident(this="b")], _from=GlotFrom(Table(this=Ident(this="table"))), - group_by=[Ident(this="b")], + distinct=True, order_by=[Ident(this="b").desc(nulls_first=False)], limit=mk_literal(10, False), ), @@ -865,8 +868,8 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: _from=GlotFrom( mkglot( expressions=[Ident(this="b")], - group_by=[Ident(this="b")], _from=GlotFrom(Table(this=Ident(this="table"))), + distinct=True, ) ), ), diff --git a/tests/test_sql_refsols/simple_distinct.sql b/tests/test_sql_refsols/simple_distinct.sql index 696d95c0..38ad7566 100644 --- a/tests/test_sql_refsols/simple_distinct.sql +++ b/tests/test_sql_refsols/simple_distinct.sql @@ -1,5 +1,3 @@ -SELECT +SELECT DISTINCT b FROM table -GROUP BY - b diff --git a/tests/test_unqualified_node.py b/tests/test_unqualified_node.py index 386266b2..34c7921c 100644 --- a/tests/test_unqualified_node.py +++ b/tests/test_unqualified_node.py @@ -418,7 +418,7 @@ def test_unqualified_to_string( ), pytest.param( impl_tpch_q22, - "?.TPCH(avg_balance=AVG(?.Customers(cntry_code=SLICE(?.phone, None, 2, None)).WHERE((ISIN(?.cntry_code, ['13', '31', '23', '29', '30', '18', '17']) & HASNOT(?.orders))).WHERE((?.acctbal > 0.0)).acctbal)).PARTITION(?.Customers(cntry_code=SLICE(?.phone, None, 2, None)).WHERE((ISIN(?.cntry_code, ['13', '31', '23', '29', '30', '18', '17']) & HASNOT(?.orders))).WHERE((?.acctbal > BACK(1).avg_balance)), name='custs', by=(?.cntry_code))(CNTRY_CODE=?.cntry_code, NUM_CUSTS=COUNT(?.custs), 
TOTACCTBAL=SUM(?.custs.acctbal))", + "?.TPCH(avg_balance=AVG(?.Customers(cntry_code=SLICE(?.phone, None, 2, None)).WHERE(ISIN(?.cntry_code, ['13', '31', '23', '29', '30', '18', '17'])).WHERE((?.acctbal > 0.0)).acctbal)).PARTITION(?.Customers(cntry_code=SLICE(?.phone, None, 2, None)).WHERE(ISIN(?.cntry_code, ['13', '31', '23', '29', '30', '18', '17'])).WHERE(((?.acctbal > BACK(1).avg_balance) & (COUNT(?.orders) == 0))), name='custs', by=(?.cntry_code))(CNTRY_CODE=?.cntry_code, NUM_CUSTS=COUNT(?.custs), TOTACCTBAL=SUM(?.custs.acctbal)).ORDER_BY(?.CNTRY_CODE.ASC(na_pos='first'))", id="tpch_q22", ), pytest.param( diff --git a/tests/tpch_outputs.py b/tests/tpch_outputs.py index c875d630..82fbe5cb 100644 --- a/tests/tpch_outputs.py +++ b/tests/tpch_outputs.py @@ -692,7 +692,7 @@ def tpch_q21_output() -> pd.DataFrame: Expected output for TPC-H query 21. Note: This is truncated to the first 10 rows. """ - columns = ["S_NAME", "NUM_WAIT"] + columns = ["S_NAME", "NUMWAIT"] data = [ ("Supplier#000002829", 20), ("Supplier#000005808", 18), @@ -715,7 +715,7 @@ def tpch_q22_output() -> pd.DataFrame: This query needs manual rewriting to run efficiently in SQLite by avoiding the correlated join. """ - columns = ["CNTRYCODE", "NUMCUST", "TOTACCTBAL"] + columns = ["CNTRY_CODE", "NUM_CUSTS", "TOTACCTBAL"] data = [ ("13", 888, 6737713.99), ("17", 861, 6460573.72), diff --git a/tests/tpch_test_functions.py b/tests/tpch_test_functions.py index 21cfdaa8..c1df2cf8 100644 --- a/tests/tpch_test_functions.py +++ b/tests/tpch_test_functions.py @@ -494,16 +494,20 @@ def impl_tpch_q22(): PyDough implementation of TPCH Q22. 
""" selected_customers = Customers(cntry_code=phone[:2]).WHERE( - ISIN(cntry_code, ("13", "31", "23", "29", "30", "18", "17")) & HASNOT(orders) + ISIN(cntry_code, ("13", "31", "23", "29", "30", "18", "17")) ) - return TPCH( - avg_balance=AVG(selected_customers.WHERE(acctbal > 0.0).acctbal) - ).PARTITION( - selected_customers.WHERE(acctbal > BACK(1).avg_balance), - name="custs", - by=cntry_code, - )( - CNTRY_CODE=cntry_code, - NUM_CUSTS=COUNT(custs), - TOTACCTBAL=SUM(custs.acctbal), + return ( + TPCH(avg_balance=AVG(selected_customers.WHERE(acctbal > 0.0).acctbal)) + .PARTITION( + selected_customers.WHERE( + (acctbal > BACK(1).avg_balance) & (COUNT(orders) == 0) + ), + name="custs", + by=cntry_code, + )( + CNTRY_CODE=cntry_code, + NUM_CUSTS=COUNT(custs), + TOTACCTBAL=SUM(custs.acctbal), + ) + .ORDER_BY(CNTRY_CODE.ASC()) )