diff --git a/pydough/conversion/hybrid_decorrelater.py b/pydough/conversion/hybrid_decorrelater.py index 72f6aad15..a1c364077 100644 --- a/pydough/conversion/hybrid_decorrelater.py +++ b/pydough/conversion/hybrid_decorrelater.py @@ -34,7 +34,7 @@ class Decorrelater: def make_decorrelate_parent( self, hybrid: HybridTree, child_idx: int, required_steps: int - ) -> HybridTree: + ) -> tuple[HybridTree, int]: """ Creates a snapshot of the ancestry of the hybrid tree that contains a correlated child, without any of its children, its descendants, or @@ -50,10 +50,12 @@ def make_decorrelate_parent( derivable. Returns: - A snapshot of `hybrid` and its ancestry in the hybrid tree, without - without any of its children or pipeline operators that occur during - or after the derivation of the correlated child, or without any of - its descendants. + A tuple where the first entry is a snapshot of `hybrid` and its + ancestry in the hybrid tree, without without any of its children or + pipeline operators that occur during or after the derivation of the + correlated child, or without any of its descendants. The second + entry is the number of ancestor layers that should be skipped due + to the PARTITION edge case. """ if isinstance(hybrid.pipeline[0], HybridPartition) and child_idx == 0: # Special case: if the correlated child is the data argument of a @@ -61,10 +63,14 @@ def make_decorrelate_parent( # parent of the level containing the partition operation. In this # case, all of the parent's children & pipeline operators should be # included in the snapshot. - assert hybrid.parent is not None - return self.make_decorrelate_parent( + if hybrid.parent is None: + raise ValueError( + "Malformed hybrid tree: partition data input to a partition node cannot contain a correlated reference to the partition node." + ) + result = self.make_decorrelate_parent( hybrid.parent, len(hybrid.parent.children), len(hybrid.pipeline) ) + return result[0], result[1] + 1 # Temporarily detach the successor of the current level, then create a # deep copy of the current level (which will include its ancestors), # then reattach the successor back to the original. This ensures that @@ -78,7 +84,7 @@ def make_decorrelate_parent( # that is has to. new_hybrid._children = new_hybrid._children[:child_idx] new_hybrid._pipeline = new_hybrid._pipeline[: required_steps + 1] - return new_hybrid + return new_hybrid, 0 def remove_correl_refs( self, expr: HybridExpr, parent: HybridTree, child_height: int @@ -221,6 +227,7 @@ def decorrelate_child( new_parent: HybridTree, child: HybridConnection, is_aggregate: bool, + skipped_levels: int, ) -> None: """ Runs the logic to de-correlate a child of a hybrid tree that contains @@ -230,6 +237,17 @@ def decorrelate_child( child can now replace correlated references with BACK references that point to terms in its newly expanded ancestry, and the original hybrid tree can now join onto this child using its uniqueness keys. + + Args: + `old_parent`: The correlated ancestor hybrid tree that the correlated + references should point to when they are targeted for removal. + `new_parent`: The ancestor of `level` that removal should stop at. + `child`: The child of the hybrid tree that contains the correlated + nodes to be removed. + `is_aggregate`: Whether the child is being aggregated with regards + to its parent. + `skipped_levels`: The number of ancestor layers that should be + ignored when deriving backshifts of join/agg keys. """ # First, find the height of the child subtree & its top-most level. child_root: HybridTree = child.subtree @@ -245,28 +263,41 @@ def decorrelate_child( new_join_keys: list[tuple[HybridExpr, HybridExpr]] = [] additional_levels: int = 0 current_level: HybridTree | None = old_parent + new_agg_keys: list[HybridExpr] = [] while current_level is not None: - for unique_key in current_level.pipeline[0].unique_exprs: + skip_join: bool = ( + isinstance(current_level.pipeline[0], HybridPartition) + and child is current_level.children[0] + ) + for unique_key in sorted(current_level.pipeline[0].unique_exprs, key=str): lhs_key: HybridExpr | None = unique_key.shift_back(additional_levels) rhs_key: HybridExpr | None = unique_key.shift_back( - additional_levels + child_height + additional_levels + child_height - skipped_levels ) assert lhs_key is not None and rhs_key is not None - new_join_keys.append((lhs_key, rhs_key)) + if not skip_join: + new_join_keys.append((lhs_key, rhs_key)) + new_agg_keys.append(rhs_key) current_level = current_level.parent additional_levels += 1 child.subtree.join_keys = new_join_keys - # If aggregating, do the same with the aggregation keys. + # If aggregating, update the aggregation keys accordingly. if is_aggregate: - new_agg_keys: list[HybridExpr] = [] - assert child.subtree.join_keys is not None - for _, rhs_key in child.subtree.join_keys: - new_agg_keys.append(rhs_key) child.subtree.agg_keys = new_agg_keys def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: """ - TODO + The recursive procedure to remove unwanted correlated references from + the entire hybrid tree, called from the bottom and working upwards + to the top layer, and having each layer also de-correlate its children. + + Args: + `hybrid`: The hybrid tree to remove correlated references from. + + Returns: + The hybrid tree with all invalid correlated references removed as the + tree structure is re-written to allow them to be replaced with BACK + references. The transformation is also done in-place. """ # Recursively decorrelate the ancestors of the current level of the # hybrid tree. @@ -282,9 +313,6 @@ def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: for idx, child in enumerate(hybrid.children): if idx not in hybrid.correlated_children: continue - new_parent: HybridTree = self.make_decorrelate_parent( - hybrid, idx, hybrid.children[idx].required_steps - ) match child.connection_type: case ( ConnectionType.SINGULAR @@ -292,8 +320,15 @@ def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: | ConnectionType.AGGREGATION | ConnectionType.AGGREGATION_ONLY_MATCH ): + new_parent, skipped_levels = self.make_decorrelate_parent( + hybrid, idx, hybrid.children[idx].required_steps + ) self.decorrelate_child( - hybrid, new_parent, child, child.connection_type.is_aggregation + hybrid, + new_parent, + child, + child.connection_type.is_aggregation, + skipped_levels, ) case ConnectionType.NDISTINCT | ConnectionType.NDISTINCT_ONLY_MATCH: raise NotImplementedError( diff --git a/pydough/conversion/hybrid_tree.py b/pydough/conversion/hybrid_tree.py index 956e88f48..16b0bca5a 100644 --- a/pydough/conversion/hybrid_tree.py +++ b/pydough/conversion/hybrid_tree.py @@ -494,17 +494,26 @@ def __init__( for name, expr in predecessor.terms.items(): terms[name] = HybridRefExpr(name, expr.typ) renamings.update(predecessor.renamings) + new_renamings: dict[str, str] = {} for name, expr in new_expressions.items(): if name in terms and terms[name] == expr: continue expr = expr.apply_renamings(predecessor.renamings) used_name: str = name idx: int = 0 - while used_name in terms or used_name in renamings: + while ( + used_name in terms + or used_name in renamings + or used_name in new_renamings + ): used_name = f"{name}_{idx}" idx += 1 terms[used_name] = expr - renamings[name] = used_name + new_renamings[name] = used_name + renamings.update(new_renamings) + for old_name, new_name in new_renamings.items(): + expr = new_expressions.pop(old_name) + new_expressions[new_name] = expr super().__init__(terms, renamings, orderings, predecessor.unique_exprs) self.calc = Calc self.new_expressions = new_expressions @@ -520,7 +529,10 @@ class HybridFilter(HybridOperation): def __init__(self, predecessor: HybridOperation, condition: HybridExpr): super().__init__( - predecessor.terms, {}, predecessor.orderings, predecessor.unique_exprs + predecessor.terms, + predecessor.renamings, + predecessor.orderings, + predecessor.unique_exprs, ) self.predecessor: HybridOperation = predecessor self.condition: HybridExpr = condition @@ -566,7 +578,10 @@ def __init__( records_to_keep: int, ): super().__init__( - predecessor.terms, {}, predecessor.orderings, predecessor.unique_exprs + predecessor.terms, + predecessor.renamings, + predecessor.orderings, + predecessor.unique_exprs, ) self.predecessor: HybridOperation = predecessor self.records_to_keep: int = records_to_keep @@ -908,13 +923,13 @@ def __repr__(self): lines.append(" -> ".join(repr(operation) for operation in self.pipeline)) prefix = " " if self.successor is None else "↓" for idx, child in enumerate(self.children): - lines.append(f"{prefix} child #{idx}:") + lines.append(f"{prefix} child #{idx} ({child.connection_type.name}):") if child.subtree.agg_keys is not None: - lines.append( - f"{prefix} aggregate: {child.subtree.agg_keys} -> {child.aggs}:" - ) + lines.append(f"{prefix} aggregate: {child.subtree.agg_keys}") + if len(child.aggs): + lines.append(f"{prefix} aggs: {child.aggs}:") if child.subtree.join_keys is not None: - lines.append(f"{prefix} join: {child.subtree.join_keys}:") + lines.append(f"{prefix} join: {child.subtree.join_keys}") for line in repr(child.subtree).splitlines(): lines.append(f"{prefix} {line}") return "\n".join(lines) @@ -1964,6 +1979,24 @@ def make_hybrid_tree( rhs_expr.name, 0, rhs_expr.typ ) join_key_exprs.append((lhs_expr, rhs_expr)) + + case PartitionBy(): + partition = HybridPartition() + successor_hybrid = HybridTree(partition) + self.populate_children( + successor_hybrid, node.child_access, child_ref_mapping + ) + partition_child_idx = child_ref_mapping[0] + for key_name in node.calc_terms: + key = node.get_expr(key_name) + expr = self.make_hybrid_expr( + successor_hybrid, key, child_ref_mapping, False + ) + partition.add_key(key_name, expr) + key_exprs.append(HybridRefExpr(key_name, expr.typ)) + successor_hybrid.children[ + partition_child_idx + ].subtree.agg_keys = key_exprs case _: raise NotImplementedError( f"{node.__class__.__name__} (child is {node.child_access.__class__.__name__})" diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index 77c10c861..c48132a2a 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -885,11 +885,26 @@ def rel_translation( if isinstance(operation.collection, TableCollection): result = self.build_simple_table_scan(operation) if context is not None: + # If the collection access is the child of something + # else, join it onto that something else. Use the + # uniqueness keys of the ancestor, which should also be + # present in the collection (e.g. joining a partition + # onto the original data using the partition keys). + assert preceding_hybrid is not None + join_keys: list[tuple[HybridExpr, HybridExpr]] = [] + for unique_column in sorted( + preceding_hybrid[0].pipeline[0].unique_exprs, key=str + ): + if unique_column not in result.expressions: + raise ValueError( + f"Cannot connect parent context to child {operation.collection} because {unique_column} is not in the child's expressions." + ) + join_keys.append((unique_column, unique_column)) result = self.join_outputs( context, result, JoinType.INNER, - [], + join_keys, None, ) else: @@ -913,7 +928,8 @@ def rel_translation( assert context is not None, "Malformed HybridTree pattern." result = self.translate_filter(operation, context) case HybridPartition(): - assert context is not None, "Malformed HybridTree pattern." + if context is None: + context = TranslationOutput(EmptySingleton(), {}) result = self.translate_partition( operation, context, hybrid, pipeline_idx ) diff --git a/pydough/relational/relational_nodes/column_pruner.py b/pydough/relational/relational_nodes/column_pruner.py index 47cda027c..ef8e7cf25 100644 --- a/pydough/relational/relational_nodes/column_pruner.py +++ b/pydough/relational/relational_nodes/column_pruner.py @@ -11,7 +11,8 @@ from .abstract_node import RelationalNode from .aggregate import Aggregate -from .join import Join +from .empty_singleton import EmptySingleton +from .join import Join, JoinType from .project import Project from .relational_expression_dispatcher import RelationalExpressionDispatcher from .relational_root import RelationalRoot @@ -141,7 +142,23 @@ def _prune_node_columns( # Determine the new node. output = new_node.copy(inputs=new_inputs) - return self._prune_identity_project(output), correl_refs + output = self._prune_identity_project(output) + # Special case: replace empty aggregation with VALUES () if possible. + if ( + isinstance(output, Aggregate) + and len(output.keys) == 0 + and len(output.aggregations) == 0 + ): + return EmptySingleton(), correl_refs + # Special case: replace join where LHS is VALUES () with the RHS if + # possible. + if ( + isinstance(output, Join) + and isinstance(output.inputs[0], EmptySingleton) + and output.join_types in ([JoinType.INNER], [JoinType.LEFT]) + ): + return output.inputs[1], correl_refs + return output, correl_refs def prune_unused_columns(self, root: RelationalRoot) -> RelationalRoot: """ diff --git a/pydough/sqlglot/sqlglot_relational_visitor.py b/pydough/sqlglot/sqlglot_relational_visitor.py index 0b53f04d4..ca9cbe2c0 100644 --- a/pydough/sqlglot/sqlglot_relational_visitor.py +++ b/pydough/sqlglot/sqlglot_relational_visitor.py @@ -431,6 +431,7 @@ def visit_filter(self, filter: Filter) -> None: # TODO: (gh #151) Refactor a simpler way to check dependent expressions. if ( "group" in input_expr.args + or "distinct" in input_expr.args or "where" in input_expr.args or "qualify" in input_expr.args or "order" in input_expr.args @@ -462,6 +463,7 @@ def visit_aggregate(self, aggregate: Aggregate) -> None: query: Select if ( "group" in input_expr.args + or "distinct" in input_expr.args or "qualify" in input_expr.args or "order" in input_expr.args or "limit" in input_expr.args @@ -472,7 +474,10 @@ def visit_aggregate(self, aggregate: Aggregate) -> None: select_cols, input_expr, find_identifiers_in_list(select_cols) ) if keys: - query = query.group_by(*keys) + if aggregations: + query = query.group_by(*keys) + else: + query = query.distinct() self._stack.append(query) def visit_limit(self, limit: Limit) -> None: @@ -511,7 +516,7 @@ def visit_limit(self, limit: Limit) -> None: self._stack.append(query) def visit_empty_singleton(self, singleton: EmptySingleton) -> None: - self._stack.append(Select().from_(values([()]))) + self._stack.append(Select().select(SQLGlotStar()).from_(values([()]))) def visit_root(self, root: RelationalRoot) -> None: self.visit_inputs(root) diff --git a/pydough/unqualified/qualification.py b/pydough/unqualified/qualification.py index ba33907db..0a532d783 100644 --- a/pydough/unqualified/qualification.py +++ b/pydough/unqualified/qualification.py @@ -676,7 +676,12 @@ def qualify_partition( partition: PartitionBy = self.builder.build_partition( qualified_parent, qualified_child, child_name ) - return partition.with_keys(child_references) + partition = partition.with_keys(child_references) + # Special case: if accessing as a child, wrap in a + # ChildOperatorChildAccess term. + if isinstance(unqualified_parent, UnqualifiedRoot) and is_child: + return ChildOperatorChildAccess(partition) + return partition def qualify_collection( self, diff --git a/tests/correlated_pydough_functions.py b/tests/correlated_pydough_functions.py index e3b73053f..3d8c4eb72 100644 --- a/tests/correlated_pydough_functions.py +++ b/tests/correlated_pydough_functions.py @@ -300,3 +300,46 @@ def correl_20(): ) instances = selected_orders.lines.supplier.WHERE(is_domestic) return TPCH(n=COUNT(instances)) + + +def correl_21(): + # Correlated back reference example #21: partition edge case. + # Count how many part sizes have an above-average number of parts + # of that size. + # (This is a correlated aggregation access) + sizes = PARTITION(Parts, name="p", by=size)(n_parts=COUNT(p)) + return TPCH(avg_n_parts=AVG(sizes.n_parts))( + n_sizes=COUNT(sizes.WHERE(n_parts > BACK(1).avg_n_parts)) + ) + + +def correl_22(): + # Correlated back reference example #22: partition edge case. + # Finds the top 5 part sizes with the most container types + # where the average retail price of parts of that container type + # & part type is above the global average retail price. + # (This is a correlated aggregation access) + ct_combos = PARTITION(Parts, name="p", by=(container, part_type))( + avg_price=AVG(p.retail_price) + ) + return ( + TPCH(global_avg_price=AVG(Parts.retail_price)) + .PARTITION( + ct_combos.WHERE(avg_price > BACK(1).global_avg_price), + name="ct", + by=container, + )(container, n_types=COUNT(ct)) + .TOP_K(5, (n_types.DESC(), container.ASC())) + ) + + +def correl_23(): + # Correlated back reference example #23: partition edge case. + # Counts how many part sizes have an above-average number of combinations + # of part types/containers. + # (This is a correlated aggregation access) + combos = PARTITION(Parts, name="p", by=(size, part_type, container)) + sizes = PARTITION(combos, name="c", by=size)(n_combos=COUNT(c)) + return TPCH(avg_n_combo=AVG(sizes.n_combos))( + n_sizes=COUNT(sizes.WHERE(n_combos > BACK(1).avg_n_combo)), + ) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 1ec380172..f862efa29 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -33,6 +33,9 @@ correl_18, correl_19, correl_20, + correl_21, + correl_22, + correl_23, ) from simple_pydough_functions import ( agg_partition, @@ -985,6 +988,41 @@ ), id="correl_20", ), + pytest.param( + ( + correl_21, + "correl_21", + lambda: pd.DataFrame({"n_sizes": [30]}), + ), + id="correl_21", + ), + pytest.param( + ( + correl_22, + "correl_22", + lambda: pd.DataFrame( + { + "container": [ + "JUMBO DRUM", + "JUMBO PKG", + "MED DRUM", + "SM BAG", + "LG PKG", + ], + "n_types": [89, 86, 81, 81, 80], + } + ), + ), + id="correl_22", + ), + pytest.param( + ( + correl_23, + "correl_23", + lambda: pd.DataFrame({"n_sizes": [23]}), + ), + id="correl_23", + ), ], ) def pydough_pipeline_test_data( diff --git a/tests/test_plan_refsols/correl_15.txt b/tests/test_plan_refsols/correl_15.txt index 960b0c691..b71be80d9 100644 --- a/tests/test_plan_refsols/correl_15.txt +++ b/tests/test_plan_refsols/correl_15.txt @@ -1,25 +1,22 @@ ROOT(columns=[('n', n)], orderings=[]) PROJECT(columns={'n': agg_1}) - JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1}) - AGGREGATE(keys={}, aggregations={}) - SCAN(table=tpch.PART, columns={'brand': p_brand}) - AGGREGATE(keys={}, aggregations={'agg_1': COUNT()}) - FILTER(condition=True:bool, columns={'account_balance': account_balance}) - JOIN(conditions=[t0.key == t1.supplier_key], types=['semi'], columns={'account_balance': t0.account_balance}, correl_name='corr4') - PROJECT(columns={'account_balance': account_balance, 'avg_price': avg_price, 'avg_price_3': agg_0, 'key': key}) - JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'agg_0': t1.agg_0, 'avg_price': t0.avg_price, 'key': t0.key}) - FILTER(condition=nation_key == 19:int64, columns={'account_balance': account_balance, 'avg_price': avg_price, 'key': key}) - JOIN(conditions=[True:bool], types=['inner'], columns={'account_balance': t1.account_balance, 'avg_price': t0.avg_price, 'key': t1.key, 'nation_key': t1.nation_key}) - PROJECT(columns={'avg_price': agg_0}) - AGGREGATE(keys={}, aggregations={'agg_0': AVG(retail_price)}) - SCAN(table=tpch.PART, columns={'retail_price': p_retailprice}) - SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) - AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': AVG(retail_price)}) - JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'retail_price': t1.retail_price, 'supplier_key': t0.supplier_key}) - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) - SCAN(table=tpch.PART, columns={'key': p_partkey, 'retail_price': p_retailprice}) - FILTER(condition=True:bool, columns={'supplier_key': supplier_key}) - JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr3') - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) - FILTER(condition=container == 'LG DRUM':string & retail_price < corr3.supplycost * 1.5:float64 & retail_price < corr4.avg_price_3 & retail_price < corr4.avg_price * 0.85:float64, columns={'key': key}) - SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'retail_price': p_retailprice}) + AGGREGATE(keys={}, aggregations={'agg_1': COUNT()}) + FILTER(condition=True:bool, columns={'account_balance': account_balance}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['semi'], columns={'account_balance': t0.account_balance}, correl_name='corr4') + PROJECT(columns={'account_balance': account_balance, 'avg_price': avg_price, 'avg_price_3': agg_0, 'key': key}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'agg_0': t1.agg_0, 'avg_price': t0.avg_price, 'key': t0.key}) + FILTER(condition=nation_key == 19:int64, columns={'account_balance': account_balance, 'avg_price': avg_price, 'key': key}) + JOIN(conditions=[True:bool], types=['inner'], columns={'account_balance': t1.account_balance, 'avg_price': t0.avg_price, 'key': t1.key, 'nation_key': t1.nation_key}) + PROJECT(columns={'avg_price': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'retail_price': p_retailprice}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': AVG(retail_price)}) + JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'retail_price': t1.retail_price, 'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.PART, columns={'key': p_partkey, 'retail_price': p_retailprice}) + FILTER(condition=True:bool, columns={'supplier_key': supplier_key}) + JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr3') + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) + FILTER(condition=container == 'LG DRUM':string & retail_price < corr3.supplycost * 1.5:float64 & retail_price < corr4.avg_price_3 & retail_price < corr4.avg_price * 0.85:float64, columns={'key': key}) + SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'retail_price': p_retailprice}) diff --git a/tests/test_plan_refsols/correl_18.txt b/tests/test_plan_refsols/correl_18.txt index db0b52912..74fab0daa 100644 --- a/tests/test_plan_refsols/correl_18.txt +++ b/tests/test_plan_refsols/correl_18.txt @@ -10,10 +10,10 @@ ROOT(columns=[('n', n)], orderings=[]) AGGREGATE(keys={'customer_key': customer_key, 'order_date': order_date}, aggregations={'agg_2': COUNT()}) FILTER(condition=total_price_3 >= 0.5:float64 * total_price, columns={'customer_key': customer_key, 'order_date': order_date}) FILTER(condition=YEAR(order_date_2) == 1993:int64, columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': total_price, 'total_price_3': total_price_3}) - JOIN(conditions=[True:bool], types=['inner'], columns={'customer_key': t0.customer_key, 'order_date': t0.order_date, 'order_date_2': t1.order_date, 'total_price': t0.total_price, 'total_price_3': t1.total_price}) + JOIN(conditions=[t0.customer_key == t1.customer_key & t0.order_date == t1.order_date], types=['inner'], columns={'customer_key': t0.customer_key, 'order_date': t0.order_date, 'order_date_2': t1.order_date, 'total_price': t0.total_price, 'total_price_3': t1.total_price}) PROJECT(columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': DEFAULT_TO(agg_1, 0:int64)}) FILTER(condition=DEFAULT_TO(agg_0, 0:int64) > 1:int64, columns={'agg_1': agg_1, 'customer_key': customer_key, 'order_date': order_date}) AGGREGATE(keys={'customer_key': customer_key, 'order_date': order_date}, aggregations={'agg_0': COUNT(), 'agg_1': SUM(total_price)}) FILTER(condition=YEAR(order_date) == 1993:int64, columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': total_price}) SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'order_date': o_orderdate, 'total_price': o_totalprice}) - SCAN(table=tpch.ORDERS, columns={'order_date': o_orderdate, 'total_price': o_totalprice}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'order_date': o_orderdate, 'total_price': o_totalprice}) diff --git a/tests/test_plan_refsols/correl_19.txt b/tests/test_plan_refsols/correl_19.txt index a65ac794e..ddaa017a0 100644 --- a/tests/test_plan_refsols/correl_19.txt +++ b/tests/test_plan_refsols/correl_19.txt @@ -1,16 +1,15 @@ -ROOT(columns=[('name', name_14), ('n_super_cust', n_super_cust)], orderings=[(ordering_1):desc_last]) - PROJECT(columns={'n_super_cust': n_super_cust, 'name_14': name_3, 'ordering_1': ordering_1}) - LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'n_super_cust': n_super_cust, 'name_3': name_3, 'ordering_1': ordering_1}, orderings=[(ordering_1):desc_last]) - PROJECT(columns={'n_super_cust': n_super_cust, 'name_3': name_3, 'ordering_1': n_super_cust}) - PROJECT(columns={'n_super_cust': DEFAULT_TO(agg_0, 0:int64), 'name_3': name_3}) - JOIN(conditions=[t0.key_2 == t1.key_5 & t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name_3': t0.name_3}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name_3': t1.name}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) - AGGREGATE(keys={'key': key, 'key_5': key_5}, aggregations={'agg_0': COUNT()}) - FILTER(condition=acctbal > account_balance, columns={'key': key, 'key_5': key_5}) - JOIN(conditions=[t0.key_5 == t1.nation_key], types=['inner'], columns={'account_balance': t0.account_balance, 'acctbal': t1.acctbal, 'key': t0.key, 'key_5': t0.key_5}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'account_balance': t0.account_balance, 'key': t0.key, 'key_5': t1.key}) - SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': c_nationkey}) +ROOT(columns=[('name', name_0), ('n_super_cust', n_super_cust)], orderings=[(ordering_1):desc_last]) + LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'n_super_cust': n_super_cust, 'name_0': name_0, 'ordering_1': ordering_1}, orderings=[(ordering_1):desc_last]) + PROJECT(columns={'n_super_cust': n_super_cust, 'name_0': name_0, 'ordering_1': n_super_cust}) + PROJECT(columns={'n_super_cust': DEFAULT_TO(agg_0, 0:int64), 'name_0': name}) + JOIN(conditions=[t0.key_2 == t1.key_5 & t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name': t0.name}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey}) + AGGREGATE(keys={'key': key, 'key_5': key_5}, aggregations={'agg_0': COUNT()}) + FILTER(condition=acctbal > account_balance, columns={'key': key, 'key_5': key_5}) + JOIN(conditions=[t0.key_5 == t1.nation_key], types=['inner'], columns={'account_balance': t0.account_balance, 'acctbal': t1.acctbal, 'key': t0.key, 'key_5': t0.key_5}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'account_balance': t0.account_balance, 'key': t0.key, 'key_5': t1.key}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': c_nationkey}) diff --git a/tests/test_plan_refsols/correl_20.txt b/tests/test_plan_refsols/correl_20.txt index 670d08803..66ed3e350 100644 --- a/tests/test_plan_refsols/correl_20.txt +++ b/tests/test_plan_refsols/correl_20.txt @@ -2,7 +2,7 @@ ROOT(columns=[('n', n)], orderings=[]) PROJECT(columns={'n': agg_0}) AGGREGATE(keys={}, aggregations={'agg_0': COUNT()}) FILTER(condition=domestic, columns={'account_balance': account_balance}) - JOIN(conditions=[t0.key_9 == t1.key_21 & t0.order_key == t1.order_key & t0.line_number == t1.line_number & t0.key_5 == t1.key_17 & t0.key_2 == t1.key_14 & t0.key == t1.key], types=['left'], columns={'account_balance': t0.account_balance, 'domestic': t1.domestic}) + JOIN(conditions=[t0.key_9 == t1.key_21 & t0.line_number == t1.line_number & t0.order_key == t1.order_key & t0.key_5 == t1.key_17 & t0.key_2 == t1.key_14 & t0.key == t1.key], types=['left'], columns={'account_balance': t0.account_balance, 'domestic': t1.domestic}) JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'account_balance': t1.account_balance, 'key': t0.key, 'key_2': t0.key_2, 'key_5': t0.key_5, 'key_9': t1.key, 'line_number': t0.line_number, 'order_key': t0.order_key}) JOIN(conditions=[t0.key_5 == t1.order_key], types=['inner'], columns={'key': t0.key, 'key_2': t0.key_2, 'key_5': t0.key_5, 'line_number': t1.line_number, 'order_key': t1.order_key, 'supplier_key': t1.supplier_key}) FILTER(condition=YEAR(order_date) == 1998:int64 & MONTH(order_date) == 6:int64, columns={'key': key, 'key_2': key_2, 'key_5': key_5}) diff --git a/tests/test_plan_refsols/correl_21.txt b/tests/test_plan_refsols/correl_21.txt new file mode 100644 index 000000000..24c781d1b --- /dev/null +++ b/tests/test_plan_refsols/correl_21.txt @@ -0,0 +1,13 @@ +ROOT(columns=[('n_sizes', n_sizes)], orderings=[]) + PROJECT(columns={'n_sizes': agg_1}) + AGGREGATE(keys={}, aggregations={'agg_1': COUNT()}) + FILTER(condition=n_parts > avg_n_parts, columns={'agg_0': agg_0}) + PROJECT(columns={'agg_0': agg_0, 'avg_n_parts': avg_n_parts, 'n_parts': DEFAULT_TO(agg_0, 0:int64)}) + JOIN(conditions=[True:bool], types=['left'], columns={'agg_0': t1.agg_0, 'avg_n_parts': t0.avg_n_parts}) + PROJECT(columns={'avg_n_parts': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(n_parts)}) + PROJECT(columns={'n_parts': DEFAULT_TO(agg_0, 0:int64)}) + AGGREGATE(keys={'size': size}, aggregations={'agg_0': COUNT()}) + SCAN(table=tpch.PART, columns={'size': p_size}) + AGGREGATE(keys={'size': size}, aggregations={'agg_0': COUNT()}) + SCAN(table=tpch.PART, columns={'size': p_size}) diff --git a/tests/test_plan_refsols/correl_22.txt b/tests/test_plan_refsols/correl_22.txt new file mode 100644 index 000000000..4c5be8549 --- /dev/null +++ b/tests/test_plan_refsols/correl_22.txt @@ -0,0 +1,13 @@ +ROOT(columns=[('container', container), ('n_types', n_types)], orderings=[(ordering_2):desc_last, (ordering_3):asc_first]) + LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'container': container, 'n_types': n_types, 'ordering_2': ordering_2, 'ordering_3': ordering_3}, orderings=[(ordering_2):desc_last, (ordering_3):asc_first]) + PROJECT(columns={'container': container, 'n_types': n_types, 'ordering_2': n_types, 'ordering_3': container}) + PROJECT(columns={'container': container, 'n_types': DEFAULT_TO(agg_1, 0:int64)}) + AGGREGATE(keys={'container': container}, aggregations={'agg_1': COUNT()}) + FILTER(condition=avg_price > global_avg_price, columns={'container': container}) + PROJECT(columns={'avg_price': agg_0, 'container': container, 'global_avg_price': global_avg_price}) + JOIN(conditions=[True:bool], types=['left'], columns={'agg_0': t1.agg_0, 'container': t1.container, 'global_avg_price': t0.global_avg_price}) + PROJECT(columns={'global_avg_price': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'retail_price': p_retailprice}) + AGGREGATE(keys={'container': container, 'part_type': part_type}, aggregations={'agg_0': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'container': p_container, 'part_type': p_type, 'retail_price': p_retailprice}) diff --git a/tests/test_plan_refsols/correl_23.txt b/tests/test_plan_refsols/correl_23.txt new file mode 100644 index 000000000..5bbadbc9d --- /dev/null +++ b/tests/test_plan_refsols/correl_23.txt @@ -0,0 +1,15 @@ +ROOT(columns=[('n_sizes', n_sizes)], orderings=[]) + PROJECT(columns={'n_sizes': agg_1}) + AGGREGATE(keys={}, aggregations={'agg_1': COUNT()}) + FILTER(condition=n_combos > avg_n_combo, columns={'agg_0': agg_0}) + PROJECT(columns={'agg_0': agg_0, 'avg_n_combo': avg_n_combo, 'n_combos': DEFAULT_TO(agg_0, 0:int64)}) + JOIN(conditions=[True:bool], types=['left'], columns={'agg_0': t1.agg_0, 'avg_n_combo': t0.avg_n_combo}) + PROJECT(columns={'avg_n_combo': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(n_combos)}) + PROJECT(columns={'n_combos': DEFAULT_TO(agg_0, 0:int64)}) + AGGREGATE(keys={'size': size}, aggregations={'agg_0': COUNT()}) + AGGREGATE(keys={'container': container, 'part_type': part_type, 'size': size}, aggregations={}) + SCAN(table=tpch.PART, columns={'container': p_container, 'part_type': p_type, 'size': p_size}) + AGGREGATE(keys={'size': size}, aggregations={'agg_0': COUNT()}) + AGGREGATE(keys={'container': container, 'part_type': part_type, 'size': size}, aggregations={}) + SCAN(table=tpch.PART, columns={'container': p_container, 'part_type': p_type, 'size': p_size}) diff --git a/tests/test_plan_refsols/join_regions_nations_calc_override.txt b/tests/test_plan_refsols/join_regions_nations_calc_override.txt index 971564305..e1cce36c6 100644 --- a/tests/test_plan_refsols/join_regions_nations_calc_override.txt +++ b/tests/test_plan_refsols/join_regions_nations_calc_override.txt @@ -1,6 +1,6 @@ -ROOT(columns=[('key', key_0_9), ('name', name_11), ('phone', phone), ('mktsegment', mktsegment)], orderings=[]) - PROJECT(columns={'key_0_9': key_0_9, 'mktsegment': mktsegment, 'name_11': name_10, 'phone': phone}) - PROJECT(columns={'key_0_9': -3:int64, 'mktsegment': mktsegment, 'name_10': name_7, 'phone': phone}) +ROOT(columns=[('key', key_0_10), ('name', name_11), ('phone', phone), ('mktsegment', mktsegment)], orderings=[]) + PROJECT(columns={'key_0_10': key_0_10, 'mktsegment': mktsegment, 'name_11': name_9, 'phone': phone}) + PROJECT(columns={'key_0_10': -3:int64, 'mktsegment': mktsegment, 'name_9': name_7, 'phone': phone}) JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'mktsegment': t1.mktsegment, 'name_7': t1.name, 'phone': t1.phone}) JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': t1.key}) SCAN(table=tpch.REGION, columns={'key': r_regionkey}) diff --git a/tests/test_plan_refsols/tpch_q22.txt b/tests/test_plan_refsols/tpch_q22.txt index 1bd521eb4..01b9f8605 100644 --- a/tests/test_plan_refsols/tpch_q22.txt +++ b/tests/test_plan_refsols/tpch_q22.txt @@ -1,18 +1,18 @@ -ROOT(columns=[('CNTRY_CODE', CNTRY_CODE), ('NUM_CUSTS', NUM_CUSTS), ('TOTACCTBAL', TOTACCTBAL)], orderings=[]) - PROJECT(columns={'CNTRY_CODE': cntry_code, 'NUM_CUSTS': DEFAULT_TO(agg_1, 0:int64), 'TOTACCTBAL': DEFAULT_TO(agg_2, 0:int64)}) - JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1, 'agg_2': t1.agg_2, 'cntry_code': t1.cntry_code}, correl_name='corr1') - PROJECT(columns={'avg_balance': agg_0}) - AGGREGATE(keys={}, aggregations={'agg_0': AVG(acctbal)}) - FILTER(condition=acctbal > 0.0:float64, columns={'acctbal': acctbal}) - FILTER(condition=ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17']:array[unknown]) & True:bool, columns={'acctbal': acctbal}) - JOIN(conditions=[t0.key == t1.customer_key], types=['anti'], columns={'acctbal': t0.acctbal, 'cntry_code': t0.cntry_code}) - PROJECT(columns={'acctbal': acctbal, 'cntry_code': SLICE(phone, None:unknown, 2:int64, None:unknown), 'key': key}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'key': c_custkey, 'phone': c_phone}) - SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey}) +ROOT(columns=[('CNTRY_CODE', CNTRY_CODE), ('NUM_CUSTS', NUM_CUSTS), ('TOTACCTBAL', TOTACCTBAL)], orderings=[(ordering_3):asc_first]) + PROJECT(columns={'CNTRY_CODE': CNTRY_CODE, 'NUM_CUSTS': NUM_CUSTS, 'TOTACCTBAL': TOTACCTBAL, 'ordering_3': CNTRY_CODE}) + PROJECT(columns={'CNTRY_CODE': cntry_code, 'NUM_CUSTS': DEFAULT_TO(agg_1, 0:int64), 'TOTACCTBAL': DEFAULT_TO(agg_2, 0:int64)}) AGGREGATE(keys={'cntry_code': cntry_code}, aggregations={'agg_1': COUNT(), 'agg_2': SUM(acctbal)}) - FILTER(condition=acctbal > corr1.avg_balance, columns={'acctbal': acctbal, 'cntry_code': cntry_code}) - FILTER(condition=ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17']:array[unknown]) & True:bool, columns={'acctbal': acctbal, 'cntry_code': cntry_code}) - JOIN(conditions=[t0.key == t1.customer_key], types=['anti'], columns={'acctbal': t0.acctbal, 'cntry_code': t0.cntry_code}) - PROJECT(columns={'acctbal': acctbal, 'cntry_code': SLICE(phone, None:unknown, 2:int64, None:unknown), 'key': key}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'key': c_custkey, 'phone': c_phone}) + FILTER(condition=acctbal > avg_balance & DEFAULT_TO(agg_0, 0:int64) == 0:int64, columns={'acctbal': acctbal, 'cntry_code': cntry_code}) + JOIN(conditions=[t0.key == t1.customer_key], types=['left'], columns={'acctbal': t0.acctbal, 'agg_0': t1.agg_0, 'avg_balance': t0.avg_balance, 'cntry_code': t0.cntry_code}) + FILTER(condition=ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17']:array[unknown]), columns={'acctbal': acctbal, 'avg_balance': avg_balance, 'cntry_code': cntry_code, 'key': key}) + PROJECT(columns={'acctbal': acctbal, 'avg_balance': avg_balance, 'cntry_code': SLICE(phone, None:unknown, 2:int64, None:unknown), 'key': key}) + JOIN(conditions=[True:bool], types=['inner'], columns={'acctbal': t1.acctbal, 'avg_balance': t0.avg_balance, 'key': t1.key, 'phone': t1.phone}) + PROJECT(columns={'avg_balance': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(acctbal)}) + FILTER(condition=acctbal > 0.0:float64, columns={'acctbal': acctbal}) + FILTER(condition=ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17']:array[unknown]), columns={'acctbal': acctbal}) + PROJECT(columns={'acctbal': acctbal, 'cntry_code': SLICE(phone, None:unknown, 2:int64, None:unknown)}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'phone': c_phone}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'key': c_custkey, 'phone': c_phone}) + AGGREGATE(keys={'customer_key': customer_key}, aggregations={'agg_0': COUNT()}) SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey}) diff --git a/tests/test_qualification.py b/tests/test_qualification.py index 6e25259d5..eab7c54b5 100644 --- a/tests/test_qualification.py +++ b/tests/test_qualification.py @@ -68,6 +68,26 @@ def pydough_impl_misc_02(root: UnqualifiedNode) -> UnqualifiedNode: ) +def pydough_impl_misc_03(root: UnqualifiedNode) -> UnqualifiedNode: + """ + Creates an UnqualifiedNode for the following PyDough snippet: + ``` + sizes = PARTITION(Parts, name="p", by=size)(n_parts=COUNT(p)) + TPCH( + avg_n_parts=AVG(sizes.n_parts) + )( + n_parts=COUNT(sizes.WHERE(n_parts > BACK(1).avg_n_parts)) + ) + ``` + """ + sizes = root.PARTITION(root.Parts, name="p", by=root.size)( + n_parts=root.COUNT(root.p) + ) + return root.TPCH(avg_n_parts=root.AVG(sizes.n_parts))( + n_parts=root.COUNT(sizes.WHERE(root.n_parts > root.BACK(1).avg_n_parts)) + ) + + def pydough_impl_tpch_q1(root: UnqualifiedNode) -> UnqualifiedNode: """ Creates an UnqualifiedNode for TPC-H query 1. @@ -612,6 +632,30 @@ def pydough_impl_tpch_q22(root: UnqualifiedNode) -> UnqualifiedNode: """, id="misc_02", ), + pytest.param( + pydough_impl_misc_03, + """ +┌─── TPCH +├─┬─ Calc[avg_n_parts=AVG($1.n_parts)] +│ └─┬─ AccessChild +│ ├─┬─ Partition[name='p', by=size] +│ │ └─┬─ AccessChild +│ │ └─── TableCollection[Parts] +│ └─┬─ Calc[n_parts=COUNT($1)] +│ └─┬─ AccessChild +│ └─── PartitionChild[p] +└─┬─ Calc[n_parts=COUNT($1)] + └─┬─ AccessChild + ├─┬─ Partition[name='p', by=size] + │ └─┬─ AccessChild + │ └─── TableCollection[Parts] + ├─┬─ Calc[n_parts=COUNT($1)] + │ └─┬─ AccessChild + │ └─── PartitionChild[p] + └─── Where[n_parts > BACK(1).avg_n_parts] +""", + id="misc_03", + ), pytest.param( pydough_impl_tpch_q1, """ diff --git a/tests/test_relational_nodes_to_sqlglot.py b/tests/test_relational_nodes_to_sqlglot.py index 7859c88fb..6c3c04cde 100644 --- a/tests/test_relational_nodes_to_sqlglot.py +++ b/tests/test_relational_nodes_to_sqlglot.py @@ -152,6 +152,8 @@ def mkglot(expressions: list[Expression], _from: Expression, **kwargs) -> Select query = query.where(kwargs.pop("where")) if "group_by" in kwargs: query = query.group_by(*kwargs.pop("group_by")) + if kwargs.pop("distinct", False): + query = query.distinct() if "qualify" in kwargs: query = query.qualify(kwargs.pop("qualify")) if "order_by" in kwargs: @@ -627,14 +629,15 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: Aggregate( input=build_simple_scan(), keys={ + "a": make_relational_column_reference("a"), "b": make_relational_column_reference("b"), }, aggregations={}, ), mkglot( - expressions=[Ident(this="b")], + expressions=[Ident(this="a"), Ident(this="b")], _from=GlotFrom(Table(this=Ident(this="table"))), - group_by=[Ident(this="b")], + distinct=True, ), id="simple_distinct", ), @@ -718,7 +721,7 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: mkglot( expressions=[Ident(this="b")], where=mkglot_func(EQ, [Ident(this="a"), mk_literal(1, False)]), - group_by=[Ident(this="b")], + distinct=True, _from=GlotFrom( mkglot( expressions=[Ident(this="a"), Ident(this="b")], @@ -794,7 +797,7 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: ), mkglot( expressions=[Ident(this="b")], - group_by=[Ident(this="b")], + distinct=True, _from=GlotFrom( mkglot( expressions=[Ident(this="a"), Ident(this="b")], @@ -829,7 +832,7 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: mkglot( expressions=[Ident(this="b")], _from=GlotFrom(Table(this=Ident(this="table"))), - group_by=[Ident(this="b")], + distinct=True, order_by=[Ident(this="b").desc(nulls_first=False)], limit=mk_literal(10, False), ), @@ -865,8 +868,8 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: _from=GlotFrom( mkglot( expressions=[Ident(this="b")], - group_by=[Ident(this="b")], _from=GlotFrom(Table(this=Ident(this="table"))), + distinct=True, ) ), ), diff --git a/tests/test_sql_refsols/simple_distinct.sql b/tests/test_sql_refsols/simple_distinct.sql index 696d95c0c..38ad75661 100644 --- a/tests/test_sql_refsols/simple_distinct.sql +++ b/tests/test_sql_refsols/simple_distinct.sql @@ -1,5 +1,3 @@ -SELECT +SELECT DISTINCT b FROM table -GROUP BY - b diff --git a/tests/test_unqualified_node.py b/tests/test_unqualified_node.py index 386266b23..34c7921c0 100644 --- a/tests/test_unqualified_node.py +++ b/tests/test_unqualified_node.py @@ -418,7 +418,7 @@ def test_unqualified_to_string( ), pytest.param( impl_tpch_q22, - "?.TPCH(avg_balance=AVG(?.Customers(cntry_code=SLICE(?.phone, None, 2, None)).WHERE((ISIN(?.cntry_code, ['13', '31', '23', '29', '30', '18', '17']) & HASNOT(?.orders))).WHERE((?.acctbal > 0.0)).acctbal)).PARTITION(?.Customers(cntry_code=SLICE(?.phone, None, 2, None)).WHERE((ISIN(?.cntry_code, ['13', '31', '23', '29', '30', '18', '17']) & HASNOT(?.orders))).WHERE((?.acctbal > BACK(1).avg_balance)), name='custs', by=(?.cntry_code))(CNTRY_CODE=?.cntry_code, NUM_CUSTS=COUNT(?.custs), TOTACCTBAL=SUM(?.custs.acctbal))", + "?.TPCH(avg_balance=AVG(?.Customers(cntry_code=SLICE(?.phone, None, 2, None)).WHERE(ISIN(?.cntry_code, ['13', '31', '23', '29', '30', '18', '17'])).WHERE((?.acctbal > 0.0)).acctbal)).PARTITION(?.Customers(cntry_code=SLICE(?.phone, None, 2, None)).WHERE(ISIN(?.cntry_code, ['13', '31', '23', '29', '30', '18', '17'])).WHERE(((?.acctbal > BACK(1).avg_balance) & (COUNT(?.orders) == 0))), name='custs', by=(?.cntry_code))(CNTRY_CODE=?.cntry_code, NUM_CUSTS=COUNT(?.custs), TOTACCTBAL=SUM(?.custs.acctbal)).ORDER_BY(?.CNTRY_CODE.ASC(na_pos='first'))", id="tpch_q22", ), pytest.param( diff --git a/tests/tpch_outputs.py b/tests/tpch_outputs.py index 80948acb1..82fbe5cb7 100644 --- a/tests/tpch_outputs.py +++ b/tests/tpch_outputs.py @@ -715,7 +715,7 @@ def tpch_q22_output() -> pd.DataFrame: This query needs manual rewriting to run efficiently in SQLite by avoiding the correlated join. """ - columns = ["CNTRYCODE", "NUMCUST", "TOTACCTBAL"] + columns = ["CNTRY_CODE", "NUM_CUSTS", "TOTACCTBAL"] data = [ ("13", 888, 6737713.99), ("17", 861, 6460573.72), diff --git a/tests/tpch_test_functions.py b/tests/tpch_test_functions.py index 21cfdaa89..c1df2cf87 100644 --- a/tests/tpch_test_functions.py +++ b/tests/tpch_test_functions.py @@ -494,16 +494,20 @@ def impl_tpch_q22(): PyDough implementation of TPCH Q22. """ selected_customers = Customers(cntry_code=phone[:2]).WHERE( - ISIN(cntry_code, ("13", "31", "23", "29", "30", "18", "17")) & HASNOT(orders) + ISIN(cntry_code, ("13", "31", "23", "29", "30", "18", "17")) ) - return TPCH( - avg_balance=AVG(selected_customers.WHERE(acctbal > 0.0).acctbal) - ).PARTITION( - selected_customers.WHERE(acctbal > BACK(1).avg_balance), - name="custs", - by=cntry_code, - )( - CNTRY_CODE=cntry_code, - NUM_CUSTS=COUNT(custs), - TOTACCTBAL=SUM(custs.acctbal), + return ( + TPCH(avg_balance=AVG(selected_customers.WHERE(acctbal > 0.0).acctbal)) + .PARTITION( + selected_customers.WHERE( + (acctbal > BACK(1).avg_balance) & (COUNT(orders) == 0) + ), + name="custs", + by=cntry_code, + )( + CNTRY_CODE=cntry_code, + NUM_CUSTS=COUNT(custs), + TOTACCTBAL=SUM(custs.acctbal), + ) + .ORDER_BY(CNTRY_CODE.ASC()) )