diff --git a/pydough/conversion/hybrid_decorrelater.py b/pydough/conversion/hybrid_decorrelater.py index 72f6aad1..a1c36407 100644 --- a/pydough/conversion/hybrid_decorrelater.py +++ b/pydough/conversion/hybrid_decorrelater.py @@ -34,7 +34,7 @@ class Decorrelater: def make_decorrelate_parent( self, hybrid: HybridTree, child_idx: int, required_steps: int - ) -> HybridTree: + ) -> tuple[HybridTree, int]: """ Creates a snapshot of the ancestry of the hybrid tree that contains a correlated child, without any of its children, its descendants, or @@ -50,10 +50,12 @@ def make_decorrelate_parent( derivable. Returns: - A snapshot of `hybrid` and its ancestry in the hybrid tree, without - without any of its children or pipeline operators that occur during - or after the derivation of the correlated child, or without any of - its descendants. + A tuple where the first entry is a snapshot of `hybrid` and its + ancestry in the hybrid tree, without any of its children or + pipeline operators that occur during or after the derivation of the + correlated child, or without any of its descendants. The second + entry is the number of ancestor layers that should be skipped due + to the PARTITION edge case. """ if isinstance(hybrid.pipeline[0], HybridPartition) and child_idx == 0: # Special case: if the correlated child is the data argument of a @@ -61,10 +63,14 @@ def make_decorrelate_parent( # parent of the level containing the partition operation. In this # case, all of the parent's children & pipeline operators should be # included in the snapshot. - assert hybrid.parent is not None - return self.make_decorrelate_parent( + if hybrid.parent is None: + raise ValueError( + "Malformed hybrid tree: partition data input to a partition node cannot contain a correlated reference to the partition node." 
+ ) + result = self.make_decorrelate_parent( hybrid.parent, len(hybrid.parent.children), len(hybrid.pipeline) ) + return result[0], result[1] + 1 # Temporarily detach the successor of the current level, then create a # deep copy of the current level (which will include its ancestors), # then reattach the successor back to the original. This ensures that @@ -78,7 +84,7 @@ def make_decorrelate_parent( # that is has to. new_hybrid._children = new_hybrid._children[:child_idx] new_hybrid._pipeline = new_hybrid._pipeline[: required_steps + 1] - return new_hybrid + return new_hybrid, 0 def remove_correl_refs( self, expr: HybridExpr, parent: HybridTree, child_height: int @@ -221,6 +227,7 @@ def decorrelate_child( new_parent: HybridTree, child: HybridConnection, is_aggregate: bool, + skipped_levels: int, ) -> None: """ Runs the logic to de-correlate a child of a hybrid tree that contains @@ -230,6 +237,17 @@ def decorrelate_child( child can now replace correlated references with BACK references that point to terms in its newly expanded ancestry, and the original hybrid tree can now join onto this child using its uniqueness keys. + + Args: + `old_parent`: The correlated ancestor hybrid tree that the correlated + references should point to when they are targeted for removal. + `new_parent`: The ancestor of `level` that removal should stop at. + `child`: The child of the hybrid tree that contains the correlated + nodes to be removed. + `is_aggregate`: Whether the child is being aggregated with regards + to its parent. + `skipped_levels`: The number of ancestor layers that should be + ignored when deriving backshifts of join/agg keys. """ # First, find the height of the child subtree & its top-most level. 
child_root: HybridTree = child.subtree @@ -245,28 +263,41 @@ def decorrelate_child( new_join_keys: list[tuple[HybridExpr, HybridExpr]] = [] additional_levels: int = 0 current_level: HybridTree | None = old_parent + new_agg_keys: list[HybridExpr] = [] while current_level is not None: - for unique_key in current_level.pipeline[0].unique_exprs: + skip_join: bool = ( + isinstance(current_level.pipeline[0], HybridPartition) + and child is current_level.children[0] + ) + for unique_key in sorted(current_level.pipeline[0].unique_exprs, key=str): lhs_key: HybridExpr | None = unique_key.shift_back(additional_levels) rhs_key: HybridExpr | None = unique_key.shift_back( - additional_levels + child_height + additional_levels + child_height - skipped_levels ) assert lhs_key is not None and rhs_key is not None - new_join_keys.append((lhs_key, rhs_key)) + if not skip_join: + new_join_keys.append((lhs_key, rhs_key)) + new_agg_keys.append(rhs_key) current_level = current_level.parent additional_levels += 1 child.subtree.join_keys = new_join_keys - # If aggregating, do the same with the aggregation keys. + # If aggregating, update the aggregation keys accordingly. if is_aggregate: - new_agg_keys: list[HybridExpr] = [] - assert child.subtree.join_keys is not None - for _, rhs_key in child.subtree.join_keys: - new_agg_keys.append(rhs_key) child.subtree.agg_keys = new_agg_keys def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: """ - TODO + The recursive procedure to remove unwanted correlated references from + the entire hybrid tree, called from the bottom and working upwards + to the top layer, and having each layer also de-correlate its children. + + Args: + `hybrid`: The hybrid tree to remove correlated references from. + + Returns: + The hybrid tree with all invalid correlated references removed as the + tree structure is re-written to allow them to be replaced with BACK + references. The transformation is also done in-place. 
""" # Recursively decorrelate the ancestors of the current level of the # hybrid tree. @@ -282,9 +313,6 @@ def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: for idx, child in enumerate(hybrid.children): if idx not in hybrid.correlated_children: continue - new_parent: HybridTree = self.make_decorrelate_parent( - hybrid, idx, hybrid.children[idx].required_steps - ) match child.connection_type: case ( ConnectionType.SINGULAR @@ -292,8 +320,15 @@ def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: | ConnectionType.AGGREGATION | ConnectionType.AGGREGATION_ONLY_MATCH ): + new_parent, skipped_levels = self.make_decorrelate_parent( + hybrid, idx, hybrid.children[idx].required_steps + ) self.decorrelate_child( - hybrid, new_parent, child, child.connection_type.is_aggregation + hybrid, + new_parent, + child, + child.connection_type.is_aggregation, + skipped_levels, ) case ConnectionType.NDISTINCT | ConnectionType.NDISTINCT_ONLY_MATCH: raise NotImplementedError( diff --git a/pydough/conversion/hybrid_tree.py b/pydough/conversion/hybrid_tree.py index 956e88f4..16b0bca5 100644 --- a/pydough/conversion/hybrid_tree.py +++ b/pydough/conversion/hybrid_tree.py @@ -494,17 +494,26 @@ def __init__( for name, expr in predecessor.terms.items(): terms[name] = HybridRefExpr(name, expr.typ) renamings.update(predecessor.renamings) + new_renamings: dict[str, str] = {} for name, expr in new_expressions.items(): if name in terms and terms[name] == expr: continue expr = expr.apply_renamings(predecessor.renamings) used_name: str = name idx: int = 0 - while used_name in terms or used_name in renamings: + while ( + used_name in terms + or used_name in renamings + or used_name in new_renamings + ): used_name = f"{name}_{idx}" idx += 1 terms[used_name] = expr - renamings[name] = used_name + new_renamings[name] = used_name + renamings.update(new_renamings) + for old_name, new_name in new_renamings.items(): + expr = new_expressions.pop(old_name) + 
new_expressions[new_name] = expr super().__init__(terms, renamings, orderings, predecessor.unique_exprs) self.calc = Calc self.new_expressions = new_expressions @@ -520,7 +529,10 @@ class HybridFilter(HybridOperation): def __init__(self, predecessor: HybridOperation, condition: HybridExpr): super().__init__( - predecessor.terms, {}, predecessor.orderings, predecessor.unique_exprs + predecessor.terms, + predecessor.renamings, + predecessor.orderings, + predecessor.unique_exprs, ) self.predecessor: HybridOperation = predecessor self.condition: HybridExpr = condition @@ -566,7 +578,10 @@ def __init__( records_to_keep: int, ): super().__init__( - predecessor.terms, {}, predecessor.orderings, predecessor.unique_exprs + predecessor.terms, + predecessor.renamings, + predecessor.orderings, + predecessor.unique_exprs, ) self.predecessor: HybridOperation = predecessor self.records_to_keep: int = records_to_keep @@ -908,13 +923,13 @@ def __repr__(self): lines.append(" -> ".join(repr(operation) for operation in self.pipeline)) prefix = " " if self.successor is None else "↓" for idx, child in enumerate(self.children): - lines.append(f"{prefix} child #{idx}:") + lines.append(f"{prefix} child #{idx} ({child.connection_type.name}):") if child.subtree.agg_keys is not None: - lines.append( - f"{prefix} aggregate: {child.subtree.agg_keys} -> {child.aggs}:" - ) + lines.append(f"{prefix} aggregate: {child.subtree.agg_keys}") + if len(child.aggs): + lines.append(f"{prefix} aggs: {child.aggs}:") if child.subtree.join_keys is not None: - lines.append(f"{prefix} join: {child.subtree.join_keys}:") + lines.append(f"{prefix} join: {child.subtree.join_keys}") for line in repr(child.subtree).splitlines(): lines.append(f"{prefix} {line}") return "\n".join(lines) @@ -1964,6 +1979,24 @@ def make_hybrid_tree( rhs_expr.name, 0, rhs_expr.typ ) join_key_exprs.append((lhs_expr, rhs_expr)) + + case PartitionBy(): + partition = HybridPartition() + successor_hybrid = HybridTree(partition) + 
self.populate_children( + successor_hybrid, node.child_access, child_ref_mapping + ) + partition_child_idx = child_ref_mapping[0] + for key_name in node.calc_terms: + key = node.get_expr(key_name) + expr = self.make_hybrid_expr( + successor_hybrid, key, child_ref_mapping, False + ) + partition.add_key(key_name, expr) + key_exprs.append(HybridRefExpr(key_name, expr.typ)) + successor_hybrid.children[ + partition_child_idx + ].subtree.agg_keys = key_exprs case _: raise NotImplementedError( f"{node.__class__.__name__} (child is {node.child_access.__class__.__name__})" diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index 77c10c86..c48132a2 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -885,11 +885,26 @@ def rel_translation( if isinstance(operation.collection, TableCollection): result = self.build_simple_table_scan(operation) if context is not None: + # If the collection access is the child of something + # else, join it onto that something else. Use the + # uniqueness keys of the ancestor, which should also be + # present in the collection (e.g. joining a partition + # onto the original data using the partition keys). + assert preceding_hybrid is not None + join_keys: list[tuple[HybridExpr, HybridExpr]] = [] + for unique_column in sorted( + preceding_hybrid[0].pipeline[0].unique_exprs, key=str + ): + if unique_column not in result.expressions: + raise ValueError( + f"Cannot connect parent context to child {operation.collection} because {unique_column} is not in the child's expressions." + ) + join_keys.append((unique_column, unique_column)) result = self.join_outputs( context, result, JoinType.INNER, - [], + join_keys, None, ) else: @@ -913,7 +928,8 @@ def rel_translation( assert context is not None, "Malformed HybridTree pattern." 
result = self.translate_filter(operation, context) case HybridPartition(): - assert context is not None, "Malformed HybridTree pattern." + if context is None: + context = TranslationOutput(EmptySingleton(), {}) result = self.translate_partition( operation, context, hybrid, pipeline_idx ) diff --git a/pydough/relational/relational_nodes/column_pruner.py b/pydough/relational/relational_nodes/column_pruner.py index 47cda027..ef8e7cf2 100644 --- a/pydough/relational/relational_nodes/column_pruner.py +++ b/pydough/relational/relational_nodes/column_pruner.py @@ -11,7 +11,8 @@ from .abstract_node import RelationalNode from .aggregate import Aggregate -from .join import Join +from .empty_singleton import EmptySingleton +from .join import Join, JoinType from .project import Project from .relational_expression_dispatcher import RelationalExpressionDispatcher from .relational_root import RelationalRoot @@ -141,7 +142,23 @@ def _prune_node_columns( # Determine the new node. output = new_node.copy(inputs=new_inputs) - return self._prune_identity_project(output), correl_refs + output = self._prune_identity_project(output) + # Special case: replace empty aggregation with VALUES () if possible. + if ( + isinstance(output, Aggregate) + and len(output.keys) == 0 + and len(output.aggregations) == 0 + ): + return EmptySingleton(), correl_refs + # Special case: replace join where LHS is VALUES () with the RHS if + # possible. 
+ if ( + isinstance(output, Join) + and isinstance(output.inputs[0], EmptySingleton) + and output.join_types in ([JoinType.INNER], [JoinType.LEFT]) + ): + return output.inputs[1], correl_refs + return output, correl_refs def prune_unused_columns(self, root: RelationalRoot) -> RelationalRoot: """ diff --git a/pydough/sqlglot/sqlglot_relational_visitor.py b/pydough/sqlglot/sqlglot_relational_visitor.py index 0b53f04d..ca9cbe2c 100644 --- a/pydough/sqlglot/sqlglot_relational_visitor.py +++ b/pydough/sqlglot/sqlglot_relational_visitor.py @@ -431,6 +431,7 @@ def visit_filter(self, filter: Filter) -> None: # TODO: (gh #151) Refactor a simpler way to check dependent expressions. if ( "group" in input_expr.args + or "distinct" in input_expr.args or "where" in input_expr.args or "qualify" in input_expr.args or "order" in input_expr.args @@ -462,6 +463,7 @@ def visit_aggregate(self, aggregate: Aggregate) -> None: query: Select if ( "group" in input_expr.args + or "distinct" in input_expr.args or "qualify" in input_expr.args or "order" in input_expr.args or "limit" in input_expr.args @@ -472,7 +474,10 @@ def visit_aggregate(self, aggregate: Aggregate) -> None: select_cols, input_expr, find_identifiers_in_list(select_cols) ) if keys: - query = query.group_by(*keys) + if aggregations: + query = query.group_by(*keys) + else: + query = query.distinct() self._stack.append(query) def visit_limit(self, limit: Limit) -> None: @@ -511,7 +516,7 @@ def visit_limit(self, limit: Limit) -> None: self._stack.append(query) def visit_empty_singleton(self, singleton: EmptySingleton) -> None: - self._stack.append(Select().from_(values([()]))) + self._stack.append(Select().select(SQLGlotStar()).from_(values([()]))) def visit_root(self, root: RelationalRoot) -> None: self.visit_inputs(root) diff --git a/pydough/unqualified/qualification.py b/pydough/unqualified/qualification.py index ba33907d..0a532d78 100644 --- a/pydough/unqualified/qualification.py +++ b/pydough/unqualified/qualification.py @@ 
-676,7 +676,12 @@ def qualify_partition( partition: PartitionBy = self.builder.build_partition( qualified_parent, qualified_child, child_name ) - return partition.with_keys(child_references) + partition = partition.with_keys(child_references) + # Special case: if accessing as a child, wrap in a + # ChildOperatorChildAccess term. + if isinstance(unqualified_parent, UnqualifiedRoot) and is_child: + return ChildOperatorChildAccess(partition) + return partition def qualify_collection( self, diff --git a/tests/correlated_pydough_functions.py b/tests/correlated_pydough_functions.py index e3b73053..3d8c4eb7 100644 --- a/tests/correlated_pydough_functions.py +++ b/tests/correlated_pydough_functions.py @@ -300,3 +300,46 @@ def correl_20(): ) instances = selected_orders.lines.supplier.WHERE(is_domestic) return TPCH(n=COUNT(instances)) + + +def correl_21(): + # Correlated back reference example #21: partition edge case. + # Count how many part sizes have an above-average number of parts + # of that size. + # (This is a correlated aggregation access) + sizes = PARTITION(Parts, name="p", by=size)(n_parts=COUNT(p)) + return TPCH(avg_n_parts=AVG(sizes.n_parts))( + n_sizes=COUNT(sizes.WHERE(n_parts > BACK(1).avg_n_parts)) + ) + + +def correl_22(): + # Correlated back reference example #22: partition edge case. + # Finds the top 5 containers with the most part types + # where the average retail price of parts of that container + # & part type is above the global average retail price. + # (This is a correlated aggregation access) + ct_combos = PARTITION(Parts, name="p", by=(container, part_type))( + avg_price=AVG(p.retail_price) + ) + return ( + TPCH(global_avg_price=AVG(Parts.retail_price)) + .PARTITION( + ct_combos.WHERE(avg_price > BACK(1).global_avg_price), + name="ct", + by=container, + )(container, n_types=COUNT(ct)) + .TOP_K(5, (n_types.DESC(), container.ASC())) + ) + + +def correl_23(): + # Correlated back reference example #23: partition edge case. 
+ # Counts how many part sizes have an above-average number of combinations + # of part types/containers. + # (This is a correlated aggregation access) + combos = PARTITION(Parts, name="p", by=(size, part_type, container)) + sizes = PARTITION(combos, name="c", by=size)(n_combos=COUNT(c)) + return TPCH(avg_n_combo=AVG(sizes.n_combos))( + n_sizes=COUNT(sizes.WHERE(n_combos > BACK(1).avg_n_combo)), + ) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 1ec38017..f862efa2 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -33,6 +33,9 @@ correl_18, correl_19, correl_20, + correl_21, + correl_22, + correl_23, ) from simple_pydough_functions import ( agg_partition, @@ -985,6 +988,41 @@ ), id="correl_20", ), + pytest.param( + ( + correl_21, + "correl_21", + lambda: pd.DataFrame({"n_sizes": [30]}), + ), + id="correl_21", + ), + pytest.param( + ( + correl_22, + "correl_22", + lambda: pd.DataFrame( + { + "container": [ + "JUMBO DRUM", + "JUMBO PKG", + "MED DRUM", + "SM BAG", + "LG PKG", + ], + "n_types": [89, 86, 81, 81, 80], + } + ), + ), + id="correl_22", + ), + pytest.param( + ( + correl_23, + "correl_23", + lambda: pd.DataFrame({"n_sizes": [23]}), + ), + id="correl_23", + ), ], ) def pydough_pipeline_test_data( diff --git a/tests/test_plan_refsols/correl_15.txt b/tests/test_plan_refsols/correl_15.txt index 960b0c69..b71be80d 100644 --- a/tests/test_plan_refsols/correl_15.txt +++ b/tests/test_plan_refsols/correl_15.txt @@ -1,25 +1,22 @@ ROOT(columns=[('n', n)], orderings=[]) PROJECT(columns={'n': agg_1}) - JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1}) - AGGREGATE(keys={}, aggregations={}) - SCAN(table=tpch.PART, columns={'brand': p_brand}) - AGGREGATE(keys={}, aggregations={'agg_1': COUNT()}) - FILTER(condition=True:bool, columns={'account_balance': account_balance}) - JOIN(conditions=[t0.key == t1.supplier_key], types=['semi'], columns={'account_balance': t0.account_balance}, correl_name='corr4') - 
PROJECT(columns={'account_balance': account_balance, 'avg_price': avg_price, 'avg_price_3': agg_0, 'key': key}) - JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'agg_0': t1.agg_0, 'avg_price': t0.avg_price, 'key': t0.key}) - FILTER(condition=nation_key == 19:int64, columns={'account_balance': account_balance, 'avg_price': avg_price, 'key': key}) - JOIN(conditions=[True:bool], types=['inner'], columns={'account_balance': t1.account_balance, 'avg_price': t0.avg_price, 'key': t1.key, 'nation_key': t1.nation_key}) - PROJECT(columns={'avg_price': agg_0}) - AGGREGATE(keys={}, aggregations={'agg_0': AVG(retail_price)}) - SCAN(table=tpch.PART, columns={'retail_price': p_retailprice}) - SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) - AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': AVG(retail_price)}) - JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'retail_price': t1.retail_price, 'supplier_key': t0.supplier_key}) - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) - SCAN(table=tpch.PART, columns={'key': p_partkey, 'retail_price': p_retailprice}) - FILTER(condition=True:bool, columns={'supplier_key': supplier_key}) - JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr3') - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) - FILTER(condition=container == 'LG DRUM':string & retail_price < corr3.supplycost * 1.5:float64 & retail_price < corr4.avg_price_3 & retail_price < corr4.avg_price * 0.85:float64, columns={'key': key}) - SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'retail_price': p_retailprice}) + AGGREGATE(keys={}, aggregations={'agg_1': COUNT()}) + FILTER(condition=True:bool, columns={'account_balance': 
account_balance}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['semi'], columns={'account_balance': t0.account_balance}, correl_name='corr4') + PROJECT(columns={'account_balance': account_balance, 'avg_price': avg_price, 'avg_price_3': agg_0, 'key': key}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'agg_0': t1.agg_0, 'avg_price': t0.avg_price, 'key': t0.key}) + FILTER(condition=nation_key == 19:int64, columns={'account_balance': account_balance, 'avg_price': avg_price, 'key': key}) + JOIN(conditions=[True:bool], types=['inner'], columns={'account_balance': t1.account_balance, 'avg_price': t0.avg_price, 'key': t1.key, 'nation_key': t1.nation_key}) + PROJECT(columns={'avg_price': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'retail_price': p_retailprice}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': AVG(retail_price)}) + JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'retail_price': t1.retail_price, 'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.PART, columns={'key': p_partkey, 'retail_price': p_retailprice}) + FILTER(condition=True:bool, columns={'supplier_key': supplier_key}) + JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr3') + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) + FILTER(condition=container == 'LG DRUM':string & retail_price < corr3.supplycost * 1.5:float64 & retail_price < corr4.avg_price_3 & retail_price < corr4.avg_price * 0.85:float64, columns={'key': key}) + SCAN(table=tpch.PART, columns={'container': p_container, 'key': 
p_partkey, 'retail_price': p_retailprice}) diff --git a/tests/test_plan_refsols/correl_18.txt b/tests/test_plan_refsols/correl_18.txt index db0b5291..74fab0da 100644 --- a/tests/test_plan_refsols/correl_18.txt +++ b/tests/test_plan_refsols/correl_18.txt @@ -10,10 +10,10 @@ ROOT(columns=[('n', n)], orderings=[]) AGGREGATE(keys={'customer_key': customer_key, 'order_date': order_date}, aggregations={'agg_2': COUNT()}) FILTER(condition=total_price_3 >= 0.5:float64 * total_price, columns={'customer_key': customer_key, 'order_date': order_date}) FILTER(condition=YEAR(order_date_2) == 1993:int64, columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': total_price, 'total_price_3': total_price_3}) - JOIN(conditions=[True:bool], types=['inner'], columns={'customer_key': t0.customer_key, 'order_date': t0.order_date, 'order_date_2': t1.order_date, 'total_price': t0.total_price, 'total_price_3': t1.total_price}) + JOIN(conditions=[t0.customer_key == t1.customer_key & t0.order_date == t1.order_date], types=['inner'], columns={'customer_key': t0.customer_key, 'order_date': t0.order_date, 'order_date_2': t1.order_date, 'total_price': t0.total_price, 'total_price_3': t1.total_price}) PROJECT(columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': DEFAULT_TO(agg_1, 0:int64)}) FILTER(condition=DEFAULT_TO(agg_0, 0:int64) > 1:int64, columns={'agg_1': agg_1, 'customer_key': customer_key, 'order_date': order_date}) AGGREGATE(keys={'customer_key': customer_key, 'order_date': order_date}, aggregations={'agg_0': COUNT(), 'agg_1': SUM(total_price)}) FILTER(condition=YEAR(order_date) == 1993:int64, columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': total_price}) SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'order_date': o_orderdate, 'total_price': o_totalprice}) - SCAN(table=tpch.ORDERS, columns={'order_date': o_orderdate, 'total_price': o_totalprice}) + SCAN(table=tpch.ORDERS, 
columns={'customer_key': o_custkey, 'order_date': o_orderdate, 'total_price': o_totalprice}) diff --git a/tests/test_plan_refsols/correl_19.txt b/tests/test_plan_refsols/correl_19.txt index a65ac794..ddaa017a 100644 --- a/tests/test_plan_refsols/correl_19.txt +++ b/tests/test_plan_refsols/correl_19.txt @@ -1,16 +1,15 @@ -ROOT(columns=[('name', name_14), ('n_super_cust', n_super_cust)], orderings=[(ordering_1):desc_last]) - PROJECT(columns={'n_super_cust': n_super_cust, 'name_14': name_3, 'ordering_1': ordering_1}) - LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'n_super_cust': n_super_cust, 'name_3': name_3, 'ordering_1': ordering_1}, orderings=[(ordering_1):desc_last]) - PROJECT(columns={'n_super_cust': n_super_cust, 'name_3': name_3, 'ordering_1': n_super_cust}) - PROJECT(columns={'n_super_cust': DEFAULT_TO(agg_0, 0:int64), 'name_3': name_3}) - JOIN(conditions=[t0.key_2 == t1.key_5 & t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name_3': t0.name_3}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name_3': t1.name}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) - AGGREGATE(keys={'key': key, 'key_5': key_5}, aggregations={'agg_0': COUNT()}) - FILTER(condition=acctbal > account_balance, columns={'key': key, 'key_5': key_5}) - JOIN(conditions=[t0.key_5 == t1.nation_key], types=['inner'], columns={'account_balance': t0.account_balance, 'acctbal': t1.acctbal, 'key': t0.key, 'key_5': t0.key_5}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'account_balance': t0.account_balance, 'key': t0.key, 'key_5': t1.key}) - SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': 
c_nationkey}) +ROOT(columns=[('name', name_0), ('n_super_cust', n_super_cust)], orderings=[(ordering_1):desc_last]) + LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'n_super_cust': n_super_cust, 'name_0': name_0, 'ordering_1': ordering_1}, orderings=[(ordering_1):desc_last]) + PROJECT(columns={'n_super_cust': n_super_cust, 'name_0': name_0, 'ordering_1': n_super_cust}) + PROJECT(columns={'n_super_cust': DEFAULT_TO(agg_0, 0:int64), 'name_0': name}) + JOIN(conditions=[t0.key_2 == t1.key_5 & t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name': t0.name}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey}) + AGGREGATE(keys={'key': key, 'key_5': key_5}, aggregations={'agg_0': COUNT()}) + FILTER(condition=acctbal > account_balance, columns={'key': key, 'key_5': key_5}) + JOIN(conditions=[t0.key_5 == t1.nation_key], types=['inner'], columns={'account_balance': t0.account_balance, 'acctbal': t1.acctbal, 'key': t0.key, 'key_5': t0.key_5}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'account_balance': t0.account_balance, 'key': t0.key, 'key_5': t1.key}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': c_nationkey}) diff --git a/tests/test_plan_refsols/correl_20.txt b/tests/test_plan_refsols/correl_20.txt index 670d0880..66ed3e35 100644 --- a/tests/test_plan_refsols/correl_20.txt +++ b/tests/test_plan_refsols/correl_20.txt @@ -2,7 +2,7 @@ ROOT(columns=[('n', n)], orderings=[]) PROJECT(columns={'n': agg_0}) AGGREGATE(keys={}, aggregations={'agg_0': COUNT()}) FILTER(condition=domestic, columns={'account_balance': 
account_balance}) - JOIN(conditions=[t0.key_9 == t1.key_21 & t0.order_key == t1.order_key & t0.line_number == t1.line_number & t0.key_5 == t1.key_17 & t0.key_2 == t1.key_14 & t0.key == t1.key], types=['left'], columns={'account_balance': t0.account_balance, 'domestic': t1.domestic}) + JOIN(conditions=[t0.key_9 == t1.key_21 & t0.line_number == t1.line_number & t0.order_key == t1.order_key & t0.key_5 == t1.key_17 & t0.key_2 == t1.key_14 & t0.key == t1.key], types=['left'], columns={'account_balance': t0.account_balance, 'domestic': t1.domestic}) JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'account_balance': t1.account_balance, 'key': t0.key, 'key_2': t0.key_2, 'key_5': t0.key_5, 'key_9': t1.key, 'line_number': t0.line_number, 'order_key': t0.order_key}) JOIN(conditions=[t0.key_5 == t1.order_key], types=['inner'], columns={'key': t0.key, 'key_2': t0.key_2, 'key_5': t0.key_5, 'line_number': t1.line_number, 'order_key': t1.order_key, 'supplier_key': t1.supplier_key}) FILTER(condition=YEAR(order_date) == 1998:int64 & MONTH(order_date) == 6:int64, columns={'key': key, 'key_2': key_2, 'key_5': key_5}) diff --git a/tests/test_plan_refsols/correl_21.txt b/tests/test_plan_refsols/correl_21.txt new file mode 100644 index 00000000..24c781d1 --- /dev/null +++ b/tests/test_plan_refsols/correl_21.txt @@ -0,0 +1,13 @@ +ROOT(columns=[('n_sizes', n_sizes)], orderings=[]) + PROJECT(columns={'n_sizes': agg_1}) + AGGREGATE(keys={}, aggregations={'agg_1': COUNT()}) + FILTER(condition=n_parts > avg_n_parts, columns={'agg_0': agg_0}) + PROJECT(columns={'agg_0': agg_0, 'avg_n_parts': avg_n_parts, 'n_parts': DEFAULT_TO(agg_0, 0:int64)}) + JOIN(conditions=[True:bool], types=['left'], columns={'agg_0': t1.agg_0, 'avg_n_parts': t0.avg_n_parts}) + PROJECT(columns={'avg_n_parts': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(n_parts)}) + PROJECT(columns={'n_parts': DEFAULT_TO(agg_0, 0:int64)}) + AGGREGATE(keys={'size': size}, aggregations={'agg_0': COUNT()}) + 
SCAN(table=tpch.PART, columns={'size': p_size}) + AGGREGATE(keys={'size': size}, aggregations={'agg_0': COUNT()}) + SCAN(table=tpch.PART, columns={'size': p_size}) diff --git a/tests/test_plan_refsols/correl_22.txt b/tests/test_plan_refsols/correl_22.txt new file mode 100644 index 00000000..4c5be854 --- /dev/null +++ b/tests/test_plan_refsols/correl_22.txt @@ -0,0 +1,13 @@ +ROOT(columns=[('container', container), ('n_types', n_types)], orderings=[(ordering_2):desc_last, (ordering_3):asc_first]) + LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'container': container, 'n_types': n_types, 'ordering_2': ordering_2, 'ordering_3': ordering_3}, orderings=[(ordering_2):desc_last, (ordering_3):asc_first]) + PROJECT(columns={'container': container, 'n_types': n_types, 'ordering_2': n_types, 'ordering_3': container}) + PROJECT(columns={'container': container, 'n_types': DEFAULT_TO(agg_1, 0:int64)}) + AGGREGATE(keys={'container': container}, aggregations={'agg_1': COUNT()}) + FILTER(condition=avg_price > global_avg_price, columns={'container': container}) + PROJECT(columns={'avg_price': agg_0, 'container': container, 'global_avg_price': global_avg_price}) + JOIN(conditions=[True:bool], types=['left'], columns={'agg_0': t1.agg_0, 'container': t1.container, 'global_avg_price': t0.global_avg_price}) + PROJECT(columns={'global_avg_price': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'retail_price': p_retailprice}) + AGGREGATE(keys={'container': container, 'part_type': part_type}, aggregations={'agg_0': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'container': p_container, 'part_type': p_type, 'retail_price': p_retailprice}) diff --git a/tests/test_plan_refsols/correl_23.txt b/tests/test_plan_refsols/correl_23.txt new file mode 100644 index 00000000..5bbadbc9 --- /dev/null +++ b/tests/test_plan_refsols/correl_23.txt @@ -0,0 +1,15 @@ +ROOT(columns=[('n_sizes', n_sizes)], orderings=[]) + 
PROJECT(columns={'n_sizes': agg_1}) + AGGREGATE(keys={}, aggregations={'agg_1': COUNT()}) + FILTER(condition=n_combos > avg_n_combo, columns={'agg_0': agg_0}) + PROJECT(columns={'agg_0': agg_0, 'avg_n_combo': avg_n_combo, 'n_combos': DEFAULT_TO(agg_0, 0:int64)}) + JOIN(conditions=[True:bool], types=['left'], columns={'agg_0': t1.agg_0, 'avg_n_combo': t0.avg_n_combo}) + PROJECT(columns={'avg_n_combo': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(n_combos)}) + PROJECT(columns={'n_combos': DEFAULT_TO(agg_0, 0:int64)}) + AGGREGATE(keys={'size': size}, aggregations={'agg_0': COUNT()}) + AGGREGATE(keys={'container': container, 'part_type': part_type, 'size': size}, aggregations={}) + SCAN(table=tpch.PART, columns={'container': p_container, 'part_type': p_type, 'size': p_size}) + AGGREGATE(keys={'size': size}, aggregations={'agg_0': COUNT()}) + AGGREGATE(keys={'container': container, 'part_type': part_type, 'size': size}, aggregations={}) + SCAN(table=tpch.PART, columns={'container': p_container, 'part_type': p_type, 'size': p_size}) diff --git a/tests/test_plan_refsols/join_regions_nations_calc_override.txt b/tests/test_plan_refsols/join_regions_nations_calc_override.txt index 97156430..e1cce36c 100644 --- a/tests/test_plan_refsols/join_regions_nations_calc_override.txt +++ b/tests/test_plan_refsols/join_regions_nations_calc_override.txt @@ -1,6 +1,6 @@ -ROOT(columns=[('key', key_0_9), ('name', name_11), ('phone', phone), ('mktsegment', mktsegment)], orderings=[]) - PROJECT(columns={'key_0_9': key_0_9, 'mktsegment': mktsegment, 'name_11': name_10, 'phone': phone}) - PROJECT(columns={'key_0_9': -3:int64, 'mktsegment': mktsegment, 'name_10': name_7, 'phone': phone}) +ROOT(columns=[('key', key_0_10), ('name', name_11), ('phone', phone), ('mktsegment', mktsegment)], orderings=[]) + PROJECT(columns={'key_0_10': key_0_10, 'mktsegment': mktsegment, 'name_11': name_9, 'phone': phone}) + PROJECT(columns={'key_0_10': -3:int64, 'mktsegment': mktsegment, 'name_9': name_7, 
'phone': phone}) JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'mktsegment': t1.mktsegment, 'name_7': t1.name, 'phone': t1.phone}) JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': t1.key}) SCAN(table=tpch.REGION, columns={'key': r_regionkey}) diff --git a/tests/test_plan_refsols/tpch_q22.txt b/tests/test_plan_refsols/tpch_q22.txt index 1bd521eb..01b9f860 100644 --- a/tests/test_plan_refsols/tpch_q22.txt +++ b/tests/test_plan_refsols/tpch_q22.txt @@ -1,18 +1,18 @@ -ROOT(columns=[('CNTRY_CODE', CNTRY_CODE), ('NUM_CUSTS', NUM_CUSTS), ('TOTACCTBAL', TOTACCTBAL)], orderings=[]) - PROJECT(columns={'CNTRY_CODE': cntry_code, 'NUM_CUSTS': DEFAULT_TO(agg_1, 0:int64), 'TOTACCTBAL': DEFAULT_TO(agg_2, 0:int64)}) - JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1, 'agg_2': t1.agg_2, 'cntry_code': t1.cntry_code}, correl_name='corr1') - PROJECT(columns={'avg_balance': agg_0}) - AGGREGATE(keys={}, aggregations={'agg_0': AVG(acctbal)}) - FILTER(condition=acctbal > 0.0:float64, columns={'acctbal': acctbal}) - FILTER(condition=ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17']:array[unknown]) & True:bool, columns={'acctbal': acctbal}) - JOIN(conditions=[t0.key == t1.customer_key], types=['anti'], columns={'acctbal': t0.acctbal, 'cntry_code': t0.cntry_code}) - PROJECT(columns={'acctbal': acctbal, 'cntry_code': SLICE(phone, None:unknown, 2:int64, None:unknown), 'key': key}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'key': c_custkey, 'phone': c_phone}) - SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey}) +ROOT(columns=[('CNTRY_CODE', CNTRY_CODE), ('NUM_CUSTS', NUM_CUSTS), ('TOTACCTBAL', TOTACCTBAL)], orderings=[(ordering_3):asc_first]) + PROJECT(columns={'CNTRY_CODE': CNTRY_CODE, 'NUM_CUSTS': NUM_CUSTS, 'TOTACCTBAL': TOTACCTBAL, 'ordering_3': CNTRY_CODE}) + PROJECT(columns={'CNTRY_CODE': cntry_code, 'NUM_CUSTS': DEFAULT_TO(agg_1, 0:int64), 'TOTACCTBAL': DEFAULT_TO(agg_2, 
0:int64)}) AGGREGATE(keys={'cntry_code': cntry_code}, aggregations={'agg_1': COUNT(), 'agg_2': SUM(acctbal)}) - FILTER(condition=acctbal > corr1.avg_balance, columns={'acctbal': acctbal, 'cntry_code': cntry_code}) - FILTER(condition=ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17']:array[unknown]) & True:bool, columns={'acctbal': acctbal, 'cntry_code': cntry_code}) - JOIN(conditions=[t0.key == t1.customer_key], types=['anti'], columns={'acctbal': t0.acctbal, 'cntry_code': t0.cntry_code}) - PROJECT(columns={'acctbal': acctbal, 'cntry_code': SLICE(phone, None:unknown, 2:int64, None:unknown), 'key': key}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'key': c_custkey, 'phone': c_phone}) + FILTER(condition=acctbal > avg_balance & DEFAULT_TO(agg_0, 0:int64) == 0:int64, columns={'acctbal': acctbal, 'cntry_code': cntry_code}) + JOIN(conditions=[t0.key == t1.customer_key], types=['left'], columns={'acctbal': t0.acctbal, 'agg_0': t1.agg_0, 'avg_balance': t0.avg_balance, 'cntry_code': t0.cntry_code}) + FILTER(condition=ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17']:array[unknown]), columns={'acctbal': acctbal, 'avg_balance': avg_balance, 'cntry_code': cntry_code, 'key': key}) + PROJECT(columns={'acctbal': acctbal, 'avg_balance': avg_balance, 'cntry_code': SLICE(phone, None:unknown, 2:int64, None:unknown), 'key': key}) + JOIN(conditions=[True:bool], types=['inner'], columns={'acctbal': t1.acctbal, 'avg_balance': t0.avg_balance, 'key': t1.key, 'phone': t1.phone}) + PROJECT(columns={'avg_balance': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(acctbal)}) + FILTER(condition=acctbal > 0.0:float64, columns={'acctbal': acctbal}) + FILTER(condition=ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17']:array[unknown]), columns={'acctbal': acctbal}) + PROJECT(columns={'acctbal': acctbal, 'cntry_code': SLICE(phone, None:unknown, 2:int64, None:unknown)}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'phone': c_phone}) + 
SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'key': c_custkey, 'phone': c_phone}) + AGGREGATE(keys={'customer_key': customer_key}, aggregations={'agg_0': COUNT()}) SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey}) diff --git a/tests/test_qualification.py b/tests/test_qualification.py index 6e25259d..eab7c54b 100644 --- a/tests/test_qualification.py +++ b/tests/test_qualification.py @@ -68,6 +68,26 @@ def pydough_impl_misc_02(root: UnqualifiedNode) -> UnqualifiedNode: ) +def pydough_impl_misc_03(root: UnqualifiedNode) -> UnqualifiedNode: + """ + Creates an UnqualifiedNode for the following PyDough snippet: + ``` + sizes = PARTITION(Parts, name="p", by=size)(n_parts=COUNT(p)) + TPCH( + avg_n_parts=AVG(sizes.n_parts) + )( + n_parts=COUNT(sizes.WHERE(n_parts > BACK(1).avg_n_parts)) + ) + ``` + """ + sizes = root.PARTITION(root.Parts, name="p", by=root.size)( + n_parts=root.COUNT(root.p) + ) + return root.TPCH(avg_n_parts=root.AVG(sizes.n_parts))( + n_parts=root.COUNT(sizes.WHERE(root.n_parts > root.BACK(1).avg_n_parts)) + ) + + def pydough_impl_tpch_q1(root: UnqualifiedNode) -> UnqualifiedNode: """ Creates an UnqualifiedNode for TPC-H query 1. 
@@ -612,6 +632,30 @@ def pydough_impl_tpch_q22(root: UnqualifiedNode) -> UnqualifiedNode: """, id="misc_02", ), + pytest.param( + pydough_impl_misc_03, + """ +┌─── TPCH +├─┬─ Calc[avg_n_parts=AVG($1.n_parts)] +│ └─┬─ AccessChild +│ ├─┬─ Partition[name='p', by=size] +│ │ └─┬─ AccessChild +│ │ └─── TableCollection[Parts] +│ └─┬─ Calc[n_parts=COUNT($1)] +│ └─┬─ AccessChild +│ └─── PartitionChild[p] +└─┬─ Calc[n_parts=COUNT($1)] + └─┬─ AccessChild + ├─┬─ Partition[name='p', by=size] + │ └─┬─ AccessChild + │ └─── TableCollection[Parts] + ├─┬─ Calc[n_parts=COUNT($1)] + │ └─┬─ AccessChild + │ └─── PartitionChild[p] + └─── Where[n_parts > BACK(1).avg_n_parts] +""", + id="misc_03", + ), pytest.param( pydough_impl_tpch_q1, """ diff --git a/tests/test_relational_nodes_to_sqlglot.py b/tests/test_relational_nodes_to_sqlglot.py index 7859c88f..6c3c04cd 100644 --- a/tests/test_relational_nodes_to_sqlglot.py +++ b/tests/test_relational_nodes_to_sqlglot.py @@ -152,6 +152,8 @@ def mkglot(expressions: list[Expression], _from: Expression, **kwargs) -> Select query = query.where(kwargs.pop("where")) if "group_by" in kwargs: query = query.group_by(*kwargs.pop("group_by")) + if kwargs.pop("distinct", False): + query = query.distinct() if "qualify" in kwargs: query = query.qualify(kwargs.pop("qualify")) if "order_by" in kwargs: @@ -627,14 +629,15 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: Aggregate( input=build_simple_scan(), keys={ + "a": make_relational_column_reference("a"), "b": make_relational_column_reference("b"), }, aggregations={}, ), mkglot( - expressions=[Ident(this="b")], + expressions=[Ident(this="a"), Ident(this="b")], _from=GlotFrom(Table(this=Ident(this="table"))), - group_by=[Ident(this="b")], + distinct=True, ), id="simple_distinct", ), @@ -718,7 +721,7 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: mkglot( expressions=[Ident(this="b")], where=mkglot_func(EQ, [Ident(this="a"), mk_literal(1, False)]), - 
group_by=[Ident(this="b")], + distinct=True, _from=GlotFrom( mkglot( expressions=[Ident(this="a"), Ident(this="b")], @@ -794,7 +797,7 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: ), mkglot( expressions=[Ident(this="b")], - group_by=[Ident(this="b")], + distinct=True, _from=GlotFrom( mkglot( expressions=[Ident(this="a"), Ident(this="b")], @@ -829,7 +832,7 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: mkglot( expressions=[Ident(this="b")], _from=GlotFrom(Table(this=Ident(this="table"))), - group_by=[Ident(this="b")], + distinct=True, order_by=[Ident(this="b").desc(nulls_first=False)], limit=mk_literal(10, False), ), @@ -865,8 +868,8 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: _from=GlotFrom( mkglot( expressions=[Ident(this="b")], - group_by=[Ident(this="b")], _from=GlotFrom(Table(this=Ident(this="table"))), + distinct=True, ) ), ), diff --git a/tests/test_sql_refsols/simple_distinct.sql b/tests/test_sql_refsols/simple_distinct.sql index 696d95c0..38ad7566 100644 --- a/tests/test_sql_refsols/simple_distinct.sql +++ b/tests/test_sql_refsols/simple_distinct.sql @@ -1,5 +1,3 @@ -SELECT +SELECT DISTINCT b FROM table -GROUP BY - b diff --git a/tests/test_unqualified_node.py b/tests/test_unqualified_node.py index 386266b2..34c7921c 100644 --- a/tests/test_unqualified_node.py +++ b/tests/test_unqualified_node.py @@ -418,7 +418,7 @@ def test_unqualified_to_string( ), pytest.param( impl_tpch_q22, - "?.TPCH(avg_balance=AVG(?.Customers(cntry_code=SLICE(?.phone, None, 2, None)).WHERE((ISIN(?.cntry_code, ['13', '31', '23', '29', '30', '18', '17']) & HASNOT(?.orders))).WHERE((?.acctbal > 0.0)).acctbal)).PARTITION(?.Customers(cntry_code=SLICE(?.phone, None, 2, None)).WHERE((ISIN(?.cntry_code, ['13', '31', '23', '29', '30', '18', '17']) & HASNOT(?.orders))).WHERE((?.acctbal > BACK(1).avg_balance)), name='custs', by=(?.cntry_code))(CNTRY_CODE=?.cntry_code, NUM_CUSTS=COUNT(?.custs), 
TOTACCTBAL=SUM(?.custs.acctbal))", + "?.TPCH(avg_balance=AVG(?.Customers(cntry_code=SLICE(?.phone, None, 2, None)).WHERE(ISIN(?.cntry_code, ['13', '31', '23', '29', '30', '18', '17'])).WHERE((?.acctbal > 0.0)).acctbal)).PARTITION(?.Customers(cntry_code=SLICE(?.phone, None, 2, None)).WHERE(ISIN(?.cntry_code, ['13', '31', '23', '29', '30', '18', '17'])).WHERE(((?.acctbal > BACK(1).avg_balance) & (COUNT(?.orders) == 0))), name='custs', by=(?.cntry_code))(CNTRY_CODE=?.cntry_code, NUM_CUSTS=COUNT(?.custs), TOTACCTBAL=SUM(?.custs.acctbal)).ORDER_BY(?.CNTRY_CODE.ASC(na_pos='first'))", id="tpch_q22", ), pytest.param( diff --git a/tests/tpch_outputs.py b/tests/tpch_outputs.py index 80948acb..82fbe5cb 100644 --- a/tests/tpch_outputs.py +++ b/tests/tpch_outputs.py @@ -715,7 +715,7 @@ def tpch_q22_output() -> pd.DataFrame: This query needs manual rewriting to run efficiently in SQLite by avoiding the correlated join. """ - columns = ["CNTRYCODE", "NUMCUST", "TOTACCTBAL"] + columns = ["CNTRY_CODE", "NUM_CUSTS", "TOTACCTBAL"] data = [ ("13", 888, 6737713.99), ("17", 861, 6460573.72), diff --git a/tests/tpch_test_functions.py b/tests/tpch_test_functions.py index 21cfdaa8..c1df2cf8 100644 --- a/tests/tpch_test_functions.py +++ b/tests/tpch_test_functions.py @@ -494,16 +494,20 @@ def impl_tpch_q22(): PyDough implementation of TPCH Q22. 
""" selected_customers = Customers(cntry_code=phone[:2]).WHERE( - ISIN(cntry_code, ("13", "31", "23", "29", "30", "18", "17")) & HASNOT(orders) + ISIN(cntry_code, ("13", "31", "23", "29", "30", "18", "17")) ) - return TPCH( - avg_balance=AVG(selected_customers.WHERE(acctbal > 0.0).acctbal) - ).PARTITION( - selected_customers.WHERE(acctbal > BACK(1).avg_balance), - name="custs", - by=cntry_code, - )( - CNTRY_CODE=cntry_code, - NUM_CUSTS=COUNT(custs), - TOTACCTBAL=SUM(custs.acctbal), + return ( + TPCH(avg_balance=AVG(selected_customers.WHERE(acctbal > 0.0).acctbal)) + .PARTITION( + selected_customers.WHERE( + (acctbal > BACK(1).avg_balance) & (COUNT(orders) == 0) + ), + name="custs", + by=cntry_code, + )( + CNTRY_CODE=cntry_code, + NUM_CUSTS=COUNT(custs), + TOTACCTBAL=SUM(custs.acctbal), + ) + .ORDER_BY(CNTRY_CODE.ASC()) )