From deeb9140ddce75a85fa0899ae8295cbd9fed36d9 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 13 Jan 2025 11:59:57 -0500 Subject: [PATCH 001/112] Starting function list documentation --- documentation/functions.md | 144 +++++++++++++++++++++++- pydough/unqualified/unqualified_node.py | 5 + 2 files changed, 148 insertions(+), 1 deletion(-) diff --git a/documentation/functions.md b/documentation/functions.md index f59d750c..6f47492e 100644 --- a/documentation/functions.md +++ b/documentation/functions.md @@ -1,3 +1,145 @@ # PyDough Functions List -TODO (gh #199): fill out this file with a list of every function / function-like operator supported in PyDough, including their behaviors and restrictions. +Below is the list of every function/operator currently supported in PyDough as a builtin. + +## Binary Operators + +Below is each binary operator currently supported in PyDough. + +### Arithmetic + +Numerical expression values can be: +- Added together with the `+` operator +- Subtracted from one another with the `-` operator +- Multiplied by one another with the `*` operator +- divided by one another with the `/` operator (note: the behavior when the denominator is `0` depends on the database being used to evaluate the expression) + +```py +Lineitems(value = (extended_price * (1 - discount) + 1.0) / part.retail_price) +``` + +### Comparisons + +Expression values can be compared to one another with the standard comparison operators `<=`, `<`, `==`, `!=`, `>` and `>=`: + +```py +Customers( + in_debt = acctbal < 0, + at_most_12_orders = COUNT(orders) <= 12, + is_european = nation.region.name == "EUROPE", + non_german = nation.name != "GERMANY", + non_empty_acct = acctbal > 0, + at_least_5_orders = COUNT(orders) >= 5, +) +``` + +> [!WARNING] +> Do **NOT** use use chained inequalities like `a <= b <= c`, as this can cause undefined incorrect behavior in PyDough. Instead, use expressions like `(a <= b) & (b <= c)`. 
+ +### Logical + +Multiple boolean expression values can be logically combined with `&`, `|` and `~` being used as logical AND, OR and NOT, respectively: + +```py +is_asian = nation.region.name == "ASIA" +is_european = nation.region.name == "EUROPE" +in_debt = acctbal < 0 +Customers( + is_eurasian = is_asian | is_european, + is_not_eurasian = ~(is_asian | is_european), + is_european_in_debt = is_european & in_debt +) +``` + +> [!WARNING] +> Do **NOT** use the builtin Python syntax `and`, `or`, or `not` on PyDough node. Using these instead of `&`, `|` or `~` can result in undefined incorrect results. + +## Unary Operators + +Below is each unary operator currently supported in PyDough. + +### Negation + +A numerical expression can have its sign flipped by prefixing it with the `-` operator: + +```py +Lineitems(lost_value = extended_price * (-discount)) +``` + +## Other Operators + +Below are all other operators currently supported in PyDough that use other syntax besides function calls: + +### Slicing + +A string expression can have a substring extracted with Python string slicing syntax `s[a:b:c]`: + +```py +Customers( + country_code = phone[:3], + name_without_first_char = name[1:] +) +``` + +> [!WARNING] +> PyDough currently only supports combinations of `string[start:stop:step]` where `step` is either 1 or missing, and where both `start` and `stop` are either non-negative values or are missing. + +## String Functions + +Below is each function currently supported in PyDough that operates on strings. 
+ +### LOWER + +Calling `LOWER` on a string converts its characters to lowercase: + +```py +Customers(lowercase_name = LOWER(name)) +``` + +### UPPER + +Calling `UPPER` on a string converts its characters to uppercase: + +```py +Customers(uppercase_name = UPPER(name)) +``` + +### STARTSWITH + +The `STARTSWITH` function returns whether its first argument begins with its second argument as a string prefix: + +```py +Parts(begins_with_yellow = STARTSWITH(name, "yellow")) +``` + +### ENDSWITH + +The `ENDSWITH` function returns whether its first argument ends with its second argument as a string suffix: + +```py +Parts(ends_with_chocolate = ENDSWITH(name, "chocolate")) +``` + +### CONTAINS + +The `CONTAINS` function returns whether its first argument contains with its second argument as a substring: + +```py +Parts(is_green = CONTAINS(name, "green")) +``` + +### LIKE + +The `LIKE` function returns whether the first argument matches the SQL pattern text of the second argument, where `_` is a 1 character wildcard and `%` is an 0+ character wildcard. + +```py +Orders(is_special_request = LIKE(comment, "%special%requests%")) +``` + +Below are some examples of how to interpret these patterns: +- `"a_c"` returns True for any 3-letter string where the first character is `"a"` and the third is `"c"`. +- `"_q__"` returns True for any 4-letter string where the second character is `"q"`. +- `"%_s"` returns True for any 2+-letter string where the last character is `"s"`. +- `"a%z"` returns True for any string that starts with `"a"` and ends with `"z"`. +- `"%a%z%"` returns True for any string that contains an `"a"`, and also contains a `"z"` at some later point in the string. +- `"_e%"` returns True for any string where the second character is `"e"`. 
diff --git a/pydough/unqualified/unqualified_node.py b/pydough/unqualified/unqualified_node.py index c292a45c..cfe6f25e 100644 --- a/pydough/unqualified/unqualified_node.py +++ b/pydough/unqualified/unqualified_node.py @@ -131,6 +131,11 @@ def __getitem__(self, key): f"Cannot index into PyDough object {self} with {key!r}" ) + def __bool__(self): + raise PyDoughUnqualifiedException( + "PyDough code cannot be treated as a boolean. If you intend to do a logical operation, use `|`, `&` or `~` instead of `or`, `and` and `not`." + ) + def __add__(self, other: object): other_unqualified: UnqualifiedNode = self.coerce_to_unqualified(other) return UnqualifiedBinaryOperation("+", self, other_unqualified) From 5d6c5137c2cbacdb7bdb2be417eb4ddc9bb7653c Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 13 Jan 2025 12:30:31 -0500 Subject: [PATCH 002/112] Adding datetime functions and bad boolean tests --- documentation/functions.md | 28 ++++++++++++++++++++++++++++ tests/bad_pydough_functions.py | 17 +++++++++++++++++ tests/test_unqualified_node.py | 18 ++++++++++++++++++ 3 files changed, 63 insertions(+) diff --git a/documentation/functions.md b/documentation/functions.md index 6f47492e..d0dd5e1e 100644 --- a/documentation/functions.md +++ b/documentation/functions.md @@ -143,3 +143,31 @@ Below are some examples of how to interpret these patterns: - `"a%z"` returns True for any string that starts with `"a"` and ends with `"z"`. - `"%a%z%"` returns True for any string that contains an `"a"`, and also contains a `"z"` at some later point in the string. - `"_e%"` returns True for any string where the second character is `"e"`. + +## Datetime Functions + +Below is each function currently supported in PyDough that operates on date/time/timestamp values. 
+ +### YEAR + +Calling `YEAR` on a date/timestamp extracts the year it belongs to: + +```py +Orders.WHERE(YEAR(order_date) == 1995) +``` + +### MONTH + +Calling `MONTH` on a date/timestamp extracts the month of the year it belongs to: + +```py +Orders(is_summer = (MONTH(order_date) >= 6) & (MONTH(order_date) <= 8)) +``` + +### DAY + +Calling `DAY` on a date/timestamp extracts the day of the month it belongs to: + +```py +Orders(is_first_of_month = DAY(order_date) == 1) +``` \ No newline at end of file diff --git a/tests/bad_pydough_functions.py b/tests/bad_pydough_functions.py index 6940f654..bfd01f38 100644 --- a/tests/bad_pydough_functions.py +++ b/tests/bad_pydough_functions.py @@ -3,6 +3,23 @@ # ruff & mypy should not try to typecheck or verify any of this +def bad_bool_1(): + # Using `or` + return Customer( + is_eurasian=(nation.region.name == "EUROPE") or (nation.region.name == "ASIA") + ) + + +def bad_bool_2(): + # Using `and` + return Parts.WHERE((size == 38) and CONTAINS(name, "green")) + + +def bad_bool_3(): + # Using `not` + return Parts.WHERE(not STARTSWITH(size, "LG")) + + def bad_window_1(): # Missing `by` return Orders(RANKING()) diff --git a/tests/test_unqualified_node.py b/tests/test_unqualified_node.py index ce799b62..2cbf7a2c 100644 --- a/tests/test_unqualified_node.py +++ b/tests/test_unqualified_node.py @@ -9,6 +9,9 @@ import pytest from bad_pydough_functions import ( + bad_bool_1, + bad_bool_2, + bad_bool_3, bad_window_1, bad_window_2, bad_window_3, @@ -426,6 +429,21 @@ def test_init_pydough_context( @pytest.mark.parametrize( "func, error_msg", [ + pytest.param( + bad_bool_1, + "PyDough code cannot be treated as a boolean. If you intend to do a logical operation, use `|`, `&` or `~` instead of `or`, `and` and `not`.", + id="bad_bool_1", + ), + pytest.param( + bad_bool_2, + "PyDough code cannot be treated as a boolean. 
If you intend to do a logical operation, use `|`, `&` or `~` instead of `or`, `and` and `not`.", + id="bad_bool_2", + ), + pytest.param( + bad_bool_3, + "PyDough code cannot be treated as a boolean. If you intend to do a logical operation, use `|`, `&` or `~` instead of `or`, `and` and `not`.", + id="bad_bool_3", + ), pytest.param( bad_window_1, "The `by` argument to `RANKING` must be provided", From 8b8c0985a1692b5653b647fee9c820a1374dfdac Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 13 Jan 2025 13:28:46 -0500 Subject: [PATCH 003/112] Adding remaining functions including agg/window --- documentation/functions.md | 214 +++++++++++++++++- .../expression_operators/README.md | 2 +- 2 files changed, 214 insertions(+), 2 deletions(-) diff --git a/documentation/functions.md b/documentation/functions.md index d0dd5e1e..a056719e 100644 --- a/documentation/functions.md +++ b/documentation/functions.md @@ -144,6 +144,18 @@ Below are some examples of how to interpret these patterns: - `"%a%z%"` returns True for any string that contains an `"a"`, and also contains a `"z"` at some later point in the string. - `"_e%"` returns True for any string where the second character is `"e"`. +### JOIN_STRINGS + +The `JOIN_STRINGS` function combines all of its string arguments by concatenating every argument after the first argument, using the first argument as a delimiter between each of the following arguments (like the `.join` method in Python): + +```py +Regions.nations.customers( + fully_qualified_name = JOIN_STRINGS("-", BACK(2).name, BACK(1).name, name) +) +``` + +For instance, `JOIN_STRINGS("; ", "Alpha", "Beta", "Gamma)` returns `"Alpha; Beta; Gamma"`. + ## Datetime Functions Below is each function currently supported in PyDough that operates on date/time/timestamp values. 
@@ -170,4 +182,204 @@ Calling `DAY` on a date/timestamp extracts the day of the month it belongs to: ```py Orders(is_first_of_month = DAY(order_date) == 1) -``` \ No newline at end of file +``` + +## Conditional Functions + +Below is each function currently supported in PyDough that handles conditional logic. + +### IFF + +The `IFF` function cases on the True/False value of its first argument. If it is True, it returns the second argument, otherwise it returns the third argument. In this way, the PyDough code `IFF(a, b, c)` is semantically the same as the SQL expression `CASE WHEN a THEN b ELSE c END`. + +```py +qty_from_germany = IFF(supplier.nation.name == "GERMANY", quantity, 0) +Customers( + total_quantity_shipped_from_germany = SUM(lines(q=qty_from_germany).q) +) +``` + +### DEFAULT_TO + +The `DEFAULT_TO` function returns the first of its arguments that is non-null (e.g. the same as the `COALESCE` function in SQL): + +```py +Lineitems(adj_tax = DEFAULT_TO(tax, 0)) +``` + +### PRESENT + +The `PRESENT` function returns whether its argument is non-null (e.g. the same as `IS NOT NULL` in SQL): + +```py +Lineitems(has_tax = PRESENT(tax)) +``` + +### ABSENT + +The `ABSENT` function returns whether its argument is non-null (e.g. the same as `IS NULL` in SQL): + +```py +Lineitems(no_tax = ABSENT(tax)) +``` + +### KEEP_IF + +The `KEEP_IF` function returns the first function if the second arguments is True, otherwise it returns a null value. In other words, `KEEP_IF(a, b)` is equivalent to the SQL expression `CASE WHEN b THEN a END`. + +```py +TPCH(avg_non_debt_balance = AVG(Customers(no_debt_bal = KEEP_IF(acctbal, acctbal > 0)).no_debt_bal)) +``` + +### MONOTONIC + +The `MONOTONIC` function returns whether all of its arguments are in ascending order (e.g. 
`MONOTONIC(a, b, c, d)` is equivalent to `(a <= b) & (b <= c) & (c <= d)`): + +```py +Lineitems.WHERE(MONOTONIC(10, quantity, 20) & MONOTONIC(5, part.size, 13)) +``` + +## Numerical Functions + +Below is each numerical function currently supported in PyDough. + +### ABS + +The `ABS` function returns the absolute value of its input. + +```py +Customers(acct_magnitude = ABS(acctbal)) +``` + +### ROUND + +The `ROUND` function rounds its first argument to the precision of its second argument. The rounding rules used depend on the database's round function. + +```py +Parts(rounded_price = ROUND(retail_price, 1)) +``` + +## Aggregation Functions + +Normally, functions in PyDough maintain the cardinality of their inputs. Aggregation functions instead take in an argument that can be plural and aggregates it into a singular value with regards to the current context. Below is each function currently supported in PyDough that can aggregate plural values into a singular value. + +### SUM + +The `SUM` function returns the sum of the plural set of numerical values it is called on. + +```py +Nations(total_consumer_wealth = SUM(customers.acctbal)) +``` + +### MIN + +The `MIN` function returns the smallest value from the set of numerical values it is called on. + +```py +Suppliers(cheapest_part_supplied = MIN(supply_records.supply_cost)) +``` + +### MAX + +The `MAX` function returns the largest value from the set of numerical values it is called on. + +```py +Suppliers(most_expensive_part_supplied = MIN(supply_records.supply_cost)) +``` + +### COUNT + +The `COUNT` function returns how many non-null records exist on the set of plural values it is called on. + +```py +Customers(num_taxed_purchases = COUNT(orders.lines.tax)) +``` + +The `COUNT` function can also be called on a sub-collection, in which case it will return how many records from that sub-collection exist. 
+ +```py +Nations(num_customers_in_debt = COUNT(customers.WHERE(acctbal < 0))) +``` + +### NDISTINCT + +The `NDISTINCT` function returns how many distinct values of its argument exist. + +```py +Customers(num_unique_parts_purchased = NDISTINCT(orders.lines.parts.key)) +``` + +### HAS + +The `HAS` function is called on a sub-collection and returns True if at least one record of the sub-collection exists. In other words, `HAS(x)` is equivalent to `COUNT(x) > 0`. + +```py +Parts.WHERE(HAS(supply_records.supplier.WHERE(nation.name == "GERMANY"))) +``` + +### HASNOT + +The `HASNOT` function is called on a sub-collection and returns True if no records of the sub-collection exist. In other words, `HASNOT(x)` is equivalent to `COUNT(x) == 0`. + +```py +Customers.WHERE(HASNOT(orders)) +``` + +## Window Functions + +Window functions are special functions that return a value for each record in the current context that depends on other records in the same context. A common example of this is ordering all values within the current context to return a value that depends on the current record's ordinal position relative to all the other records in the context. + +Window functions in PyDough have an optional `levels` argument. If this argument is not provided, it means that the window function applies to all records of the current collection without any boundaries between records. If it is provided, it should be a value that can be used as an argument to `BACK`, and in that case it means that the window function should be used on records of the current collection grouped by that particular ancestor. 
+ +For example, if using the `RANKING` window function, consider the following examples: + +```py +# (no levels) rank every customer relative to all other customers +Regions.nations.customers(r=RANKING(...)) + +# (levels=1) rank every customer relative to other customers in the same nation +Regions.nations.customers(r=RANKING(..., levels=1)) + +# (levels=2) rank every customer relative to other customers in the same region +Regions.nations.customers(r=RANKING(..., levels=2)) + +# (levels=3) rank every customer relative to other customers in the same nation +Regions.nations.customers(r=RANKING(..., levels=3)) +``` + +Below is each window function currently supported in PyDough. + +### RANKING + +The `RANKING` function returns ordinal position of the current record when all records in the current context are sorted by certain ordering keys. The arguments: + +- `by`: 1+ collation values, either as a single expression or an iterable of expressions, used to order the records of the current context. +- `levels`: same `levels` argument as all other window functions. +- `allow_ties`: optional argument (default False) specifying to allow values that are tied according to the `by` expressions to have the same rank value. If False, tied values have different rank values where ties are broken arbitrarily. +- `dense`: optional argument (default False) specifying that if `allow_ties` is True and a tie is found, should the next value after hte ties be the current ranking value plus 1, as opposed to jumping to a higher value based on the number of ties that were there. For example, with the values `[a, a, b, b, b, c]`, the values with `dense=True` would be `[1, 1, 2, 2, 2, 3]`, but with `dense=False` they would be `[1, 1, 3, 3, 3, 6]`. 
+ +```py +# Rank customers per-nation by their account balance +# (highest = rank #1, no ties) +Nations.customers(r = RANKING(by=acctbal.DESC(), levels=1)) + +# For every customer, finds their most recent order +# (ties allowed) +Customers.orders.WHERE(RANKING(by=order_date.DESC(), levels=1, allow_ties=True) == 1) +``` + +### PERCENTILE + +The `PERCENTILE` function returns what index the current record belongs to if all records in the current context are ordered then split into evenly sized buckets. The arguments: + +- `by`: 1+ collation values, either as a single expression or an iterable of expressions, used to order the records of the current context. +- `levels`: same `levels` argument as all other window functions. +- `n_buckets`: optional argument (default 100) specifying the number of buckets to use. The first values according to the sort order are assigned bucket `1`, and the last values are assigned bucket `n_buckets`. + +```py +# Keep the top 0.1% of customers with the highest account balances. +Customers.WHERE(PERCENTILE(by=acctbal.ASC(), n_buckets=1000) == 1000) + +# For every region, find the top 5% of customers with the highest account balances. +Regions.nations.customers.WHERE(PERCENTILE(by=acctbal.ASC(), levels=2) > 95) +``` diff --git a/pydough/pydough_operators/expression_operators/README.md b/pydough/pydough_operators/expression_operators/README.md index d419288c..4e0014b4 100644 --- a/pydough/pydough_operators/expression_operators/README.md +++ b/pydough/pydough_operators/expression_operators/README.md @@ -73,7 +73,7 @@ These functions must be called on singular data as a function. - `STARTSWITH`: returns whether the first argument string starts with the second argument string. - `ENDSWITH`: returns whether the first argument string ends with the second argument string. - `CONTAINS`: returns whether the first argument string contains the second argument string. 
-- `LIKES`: returns whether the first argument matches the SQL pattern text of the second argument, where `_` is a 1 character wildcard and `%` is an 0+ character wildcard. +- `LIKE`: returns whether the first argument matches the SQL pattern text of the second argument, where `_` is a 1 character wildcard and `%` is an 0+ character wildcard. - `JOIN_STRINGS`: equivalent to the Python string join method, where the first argument is used as a delimiter to concatenate the remaining arguments. ##### Datetime Functions From cb83f1ff31541b49b9d362221884da4a2badc5c8 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 13 Jan 2025 13:30:12 -0500 Subject: [PATCH 004/112] Adding toc --- documentation/functions.md | 42 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/documentation/functions.md b/documentation/functions.md index a056719e..1e15ff5f 100644 --- a/documentation/functions.md +++ b/documentation/functions.md @@ -2,6 +2,48 @@ Below is the list of every function/operator currently supported in PyDough as a builtin. 
+- [Binary Operators](#binary-operators) + * [Arithmetic](#arithmetic) + * [Comparisons](#comparisons) + * [Logical](#logical) +- [Unary Operators](#unary-operators) + * [Negation](#negation) +- [Other Operators](#other-operators) + * [Slicing](#slicing) +- [String Functions](#string-functions) + * [LOWER](#lower) + * [UPPER](#upper) + * [STARTSWITH](#startswith) + * [ENDSWITH](#endswith) + * [CONTAINS](#contains) + * [LIKE](#like) + * [JOIN_STRINGS](#join_strings) +- [Datetime Functions](#datetime-functions) + * [YEAR](#year) + * [MONTH](#month) + * [DAY](#day) +- [Conditional Functions](#conditional-functions) + * [IFF](#iff) + * [DEFAULT_TO](#default_to) + * [PRESENT](#present) + * [ABSENT](#absent) + * [KEEP_IF](#keep_if) + * [MONOTONIC](#monotonic) +- [Numerical Functions](#numerical-functions) + * [ABS](#abs) + * [ROUND](#round) +- [Aggregation Functions](#aggregation-functions) + * [SUM](#sum) + * [MIN](#min) + * [MAX](#max) + * [COUNT](#count) + * [NDISTINCT](#ndistinct) + * [HAS](#has) + * [HASNOT](#hasnot) +- [Window Functions](#window-functions) + * [RANKING](#ranking) + * [PERCENTILE](#percentile) + ## Binary Operators Below is each binary operator currently supported in PyDough. From ef39e6e64a78fa8f150df19267043808eb7dfa91 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 13 Jan 2025 13:31:38 -0500 Subject: [PATCH 005/112] Adding toc --- documentation/functions.md | 45 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/documentation/functions.md b/documentation/functions.md index 1e15ff5f..df9d8359 100644 --- a/documentation/functions.md +++ b/documentation/functions.md @@ -2,6 +2,8 @@ Below is the list of every function/operator currently supported in PyDough as a builtin. 
+ + - [Binary Operators](#binary-operators) * [Arithmetic](#arithmetic) * [Comparisons](#comparisons) @@ -44,10 +46,14 @@ Below is the list of every function/operator currently supported in PyDough as a * [RANKING](#ranking) * [PERCENTILE](#percentile) + + + ## Binary Operators Below is each binary operator currently supported in PyDough. + ### Arithmetic Numerical expression values can be: @@ -60,6 +66,7 @@ Numerical expression values can be: Lineitems(value = (extended_price * (1 - discount) + 1.0) / part.retail_price) ``` + ### Comparisons Expression values can be compared to one another with the standard comparison operators `<=`, `<`, `==`, `!=`, `>` and `>=`: @@ -78,6 +85,7 @@ Customers( > [!WARNING] > Do **NOT** use use chained inequalities like `a <= b <= c`, as this can cause undefined incorrect behavior in PyDough. Instead, use expressions like `(a <= b) & (b <= c)`. + ### Logical Multiple boolean expression values can be logically combined with `&`, `|` and `~` being used as logical AND, OR and NOT, respectively: @@ -96,10 +104,12 @@ Customers( > [!WARNING] > Do **NOT** use the builtin Python syntax `and`, `or`, or `not` on PyDough node. Using these instead of `&`, `|` or `~` can result in undefined incorrect results. + ## Unary Operators Below is each unary operator currently supported in PyDough. 
+ ### Negation A numerical expression can have its sign flipped by prefixing it with the `-` operator: @@ -108,10 +118,12 @@ A numerical expression can have its sign flipped by prefixing it with the `-` op Lineitems(lost_value = extended_price * (-discount)) ``` + ## Other Operators Below are all other operators currently supported in PyDough that use other syntax besides function calls: + ### Slicing A string expression can have a substring extracted with Python string slicing syntax `s[a:b:c]`: @@ -126,10 +138,12 @@ Customers( > [!WARNING] > PyDough currently only supports combinations of `string[start:stop:step]` where `step` is either 1 or missing, and where both `start` and `stop` are either non-negative values or are missing. + ## String Functions Below is each function currently supported in PyDough that operates on strings. + ### LOWER Calling `LOWER` on a string converts its characters to lowercase: @@ -138,6 +152,7 @@ Calling `LOWER` on a string converts its characters to lowercase: Customers(lowercase_name = LOWER(name)) ``` + ### UPPER Calling `UPPER` on a string converts its characters to uppercase: @@ -146,6 +161,7 @@ Calling `UPPER` on a string converts its characters to uppercase: Customers(uppercase_name = UPPER(name)) ``` + ### STARTSWITH The `STARTSWITH` function returns whether its first argument begins with its second argument as a string prefix: @@ -154,6 +170,7 @@ The `STARTSWITH` function returns whether its first argument begins with its sec Parts(begins_with_yellow = STARTSWITH(name, "yellow")) ``` + ### ENDSWITH The `ENDSWITH` function returns whether its first argument ends with its second argument as a string suffix: @@ -162,6 +179,7 @@ The `ENDSWITH` function returns whether its first argument ends with its second Parts(ends_with_chocolate = ENDSWITH(name, "chocolate")) ``` + ### CONTAINS The `CONTAINS` function returns whether its first argument contains with its second argument as a substring: @@ -170,6 +188,7 @@ The `CONTAINS` 
function returns whether its first argument contains with its sec Parts(is_green = CONTAINS(name, "green")) ``` + ### LIKE The `LIKE` function returns whether the first argument matches the SQL pattern text of the second argument, where `_` is a 1 character wildcard and `%` is an 0+ character wildcard. @@ -186,6 +205,7 @@ Below are some examples of how to interpret these patterns: - `"%a%z%"` returns True for any string that contains an `"a"`, and also contains a `"z"` at some later point in the string. - `"_e%"` returns True for any string where the second character is `"e"`. + ### JOIN_STRINGS The `JOIN_STRINGS` function combines all of its string arguments by concatenating every argument after the first argument, using the first argument as a delimiter between each of the following arguments (like the `.join` method in Python): @@ -198,10 +218,12 @@ Regions.nations.customers( For instance, `JOIN_STRINGS("; ", "Alpha", "Beta", "Gamma)` returns `"Alpha; Beta; Gamma"`. + ## Datetime Functions Below is each function currently supported in PyDough that operates on date/time/timestamp values. + ### YEAR Calling `YEAR` on a date/timestamp extracts the year it belongs to: @@ -210,6 +232,7 @@ Calling `YEAR` on a date/timestamp extracts the year it belongs to: Orders.WHERE(YEAR(order_date) == 1995) ``` + ### MONTH Calling `MONTH` on a date/timestamp extracts the month of the year it belongs to: @@ -218,6 +241,7 @@ Calling `MONTH` on a date/timestamp extracts the month of the year it belongs to Orders(is_summer = (MONTH(order_date) >= 6) & (MONTH(order_date) <= 8)) ``` + ### DAY Calling `DAY` on a date/timestamp extracts the day of the month it belongs to: @@ -226,10 +250,12 @@ Calling `DAY` on a date/timestamp extracts the day of the month it belongs to: Orders(is_first_of_month = DAY(order_date) == 1) ``` + ## Conditional Functions Below is each function currently supported in PyDough that handles conditional logic. 
+ ### IFF The `IFF` function cases on the True/False value of its first argument. If it is True, it returns the second argument, otherwise it returns the third argument. In this way, the PyDough code `IFF(a, b, c)` is semantically the same as the SQL expression `CASE WHEN a THEN b ELSE c END`. @@ -241,6 +267,7 @@ Customers( ) ``` + ### DEFAULT_TO The `DEFAULT_TO` function returns the first of its arguments that is non-null (e.g. the same as the `COALESCE` function in SQL): @@ -249,6 +276,7 @@ The `DEFAULT_TO` function returns the first of its arguments that is non-null (e Lineitems(adj_tax = DEFAULT_TO(tax, 0)) ``` + ### PRESENT The `PRESENT` function returns whether its argument is non-null (e.g. the same as `IS NOT NULL` in SQL): @@ -257,6 +285,7 @@ The `PRESENT` function returns whether its argument is non-null (e.g. the same a Lineitems(has_tax = PRESENT(tax)) ``` + ### ABSENT The `ABSENT` function returns whether its argument is non-null (e.g. the same as `IS NULL` in SQL): @@ -265,6 +294,7 @@ The `ABSENT` function returns whether its argument is non-null (e.g. the same as Lineitems(no_tax = ABSENT(tax)) ``` + ### KEEP_IF The `KEEP_IF` function returns the first function if the second arguments is True, otherwise it returns a null value. In other words, `KEEP_IF(a, b)` is equivalent to the SQL expression `CASE WHEN b THEN a END`. @@ -273,6 +303,7 @@ The `KEEP_IF` function returns the first function if the second arguments is Tru TPCH(avg_non_debt_balance = AVG(Customers(no_debt_bal = KEEP_IF(acctbal, acctbal > 0)).no_debt_bal)) ``` + ### MONOTONIC The `MONOTONIC` function returns whether all of its arguments are in ascending order (e.g. 
`MONOTONIC(a, b, c, d)` is equivalent to `(a <= b) & (b <= c) & (c <= d)`): @@ -281,10 +312,12 @@ The `MONOTONIC` function returns whether all of its arguments are in ascending o Lineitems.WHERE(MONOTONIC(10, quantity, 20) & MONOTONIC(5, part.size, 13)) ``` + ## Numerical Functions Below is each numerical function currently supported in PyDough. + ### ABS The `ABS` function returns the absolute value of its input. @@ -293,6 +326,7 @@ The `ABS` function returns the absolute value of its input. Customers(acct_magnitude = ABS(acctbal)) ``` + ### ROUND The `ROUND` function rounds its first argument to the precision of its second argument. The rounding rules used depend on the database's round function. @@ -301,10 +335,12 @@ The `ROUND` function rounds its first argument to the precision of its second ar Parts(rounded_price = ROUND(retail_price, 1)) ``` + ## Aggregation Functions Normally, functions in PyDough maintain the cardinality of their inputs. Aggregation functions instead take in an argument that can be plural and aggregates it into a singular value with regards to the current context. Below is each function currently supported in PyDough that can aggregate plural values into a singular value. + ### SUM The `SUM` function returns the sum of the plural set of numerical values it is called on. @@ -313,6 +349,7 @@ The `SUM` function returns the sum of the plural set of numerical values it is c Nations(total_consumer_wealth = SUM(customers.acctbal)) ``` + ### MIN The `MIN` function returns the smallest value from the set of numerical values it is called on. @@ -321,6 +358,7 @@ The `MIN` function returns the smallest value from the set of numerical values i Suppliers(cheapest_part_supplied = MIN(supply_records.supply_cost)) ``` + ### MAX The `MAX` function returns the largest value from the set of numerical values it is called on. 
@@ -329,6 +367,7 @@ The `MAX` function returns the largest value from the set of numerical values it Suppliers(most_expensive_part_supplied = MAX(supply_records.supply_cost)) ``` + ### COUNT The `COUNT` function returns how many non-null records exist on the set of plural values it is called on. @@ -343,6 +382,7 @@ The `COUNT` function can also be called on a sub-collection, in which case it wi Nations(num_customers_in_debt = COUNT(customers.WHERE(acctbal < 0))) ``` + ### NDISTINCT The `NDISTINCT` function returns how many distinct values of its argument exist. @@ -351,6 +391,7 @@ The `NDISTINCT` function returns how many distinct values of its argument exist. Customers(num_unique_parts_purchased = NDISTINCT(orders.lines.parts.key)) ``` + ### HAS The `HAS` function is called on a sub-collection and returns True if at least one record of the sub-collection exists. In other words, `HAS(x)` is equivalent to `COUNT(x) > 0`. @@ -359,6 +400,7 @@ The `HAS` function is called on a sub-collection and returns True if at least on Parts.WHERE(HAS(supply_records.supplier.WHERE(nation.name == "GERMANY"))) ``` + ### HASNOT The `HASNOT` function is called on a sub-collection and returns True if no records of the sub-collection exist. In other words, `HASNOT(x)` is equivalent to `COUNT(x) == 0`. @@ -367,6 +409,7 @@ The `HASNOT` function is called on a sub-collection and returns True if no recor Customers.WHERE(HASNOT(orders)) ``` + ## Window Functions Window functions are special functions that return a value for each record in the current context that depends on other records in the same context. A common example of this is ordering all values within the current context to return a value that depends on the current record's ordinal position relative to all the other records in the context. @@ -391,6 +434,7 @@ Regions.nations.customers(r=RANKING(..., levels=3)) Below is each window function currently supported in PyDough.
+ ### RANKING The `RANKING` function returns ordinal position of the current record when all records in the current context are sorted by certain ordering keys. The arguments: @@ -410,6 +454,7 @@ Nations.customers(r = RANKING(by=acctbal.DESC(), levels=1)) Customers.orders.WHERE(RANKING(by=order_date.DESC(), levels=1, allow_ties=True) == 1) ``` + ### PERCENTILE The `PERCENTILE` function returns what index the current record belongs to if all records in the current context are ordered then split into evenly sized buckets. The arguments: From 3c795430900736f42536ff46583d38d0ddaa0e48 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 13 Jan 2025 13:32:59 -0500 Subject: [PATCH 006/112] Fixing typo [RUN CI] --- documentation/functions.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/documentation/functions.md b/documentation/functions.md index df9d8359..bf77f3a8 100644 --- a/documentation/functions.md +++ b/documentation/functions.md @@ -428,7 +428,7 @@ Regions.nations.customers(r=RANKING(..., levels=1)) # (levels=2) rank every customer relative to other customers in the same region Regions.nations.customers(r=RANKING(..., levels=2)) -# (levels=3) rank every customer relative to other customers in the same nation +# (levels=3) rank every customer relative to all other customers Regions.nations.customers(r=RANKING(..., levels=3)) ``` From 3db13d14b8bafe275ea4c4ea16cddd6c42b7d52a Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 13 Jan 2025 14:40:16 -0500 Subject: [PATCH 007/112] Started DSL documentation --- documentation/dsl.md | 147 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 146 insertions(+), 1 deletion(-) diff --git a/documentation/dsl.md b/documentation/dsl.md index 4eaf5a01..4c410f22 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -1,4 +1,149 @@ # PyDough DSL Spec -TODO (gh #198): fill out this file with an extensive and exhaustive list of all features/operations in PyDough that are necessary for a user 
to understand what is/isn't supported in the basic DSL before any extensions, besides the specific list of supported functions. +This page describes the specification of the PyDough DSL. The specification includes rules of how PyDough code should be structured and the semantics that are used when evaluating PyDough code. Not every feature in the spec is implemented in PyDough as of this time. + +## Example Graph + +The examples in this document use a metadata graph (named `GRAPH`) with the following collections: +- `People`: records of every known person. Scalar properties: `first_name`, `middle_name`, `last_name`, `ssn`, `birth_date`, `email`, `current_address_id`. +- `Addresses`: records of every known address. Scalar properties: `address_id`, `street_number`, `street_name`, `apartment`, `zip_code`, `city`, `state`. +- `Packages`: records of every known package. Scalar properties: `package_id`, `customer_ssn`, `shipping_address_id`, `billing_address_id`, `order_date`, `arrival_date`. + +There are also the following sub-collection relationships: +- `People.packages`: every package ordered by each person (reverse is `Packages.customer`). There can be 0, 1 or multiple packages ordered by a single person, but each package has exactly one person who ordered it. +- `People.current_address`: the current address of each person, if one exists (reverse is `Addresses.current_occupants`). Each person has at most 1 current address (which can be missing), but each address can have 0, 1, or multiple people currently occupying it. +- `Packages.shipping_address`: the address that the package is shipped to (reverse is `Addresses.packages_shipped`). Every package has exactly one shipping address, but each address can have 0, 1 or multiple packages shipped to it. +- `Packages.billing_address`: the address that the package is billed to (reverse is `Addresses.packages_billed`). 
Every package has exactly one billing address, but each address can have 0, 1 or multiple packages billed to it. + ## Collections The simplest PyDough code is scanning an entire collection. This is done by providing the name of the collection in the metadata. However, if that name is already used as a variable, then PyDough will not know to replace the name with the corresponding PyDough object. Good Example #1: obtains every record of the `People` collection. Every scalar property of `People` (`first_name`, `middle_name`, `last_name`, `ssn`, `birth_date`, `email`, `current_address_id`) is automatically included in the output. ```py %%pydough People ``` Good Example #2: obtains every record of the `Addresses` collection. The `GRAPH.` prefix is optional and implied when the term is a collection name in the graph. Every scalar property of `Addresses` (`address_id`, `street_number`, `street_name`, `apartment`, `zip_code`, `city`, `state`) is automatically included in the output. ```py %%pydough GRAPH.Addresses ``` Bad Example #1: obtains every record of the `Products` collection (there is no `Products` collection). ```py %%pydough Products ``` Bad Example #2: obtains every record of the `Addresses` collection (but the name `Addresses` has been reassigned to a variable). ```py %%pydough Addresses = 42 Addresses ``` Bad Example #3: obtains every record of the `Addresses` collection (but the graph name `HELLO` is the wrong graph name for this example). ```py %%pydough HELLO.Addresses ``` ### Sub-Collections The next step in PyDough after accessing a collection is accessing any of its sub-collections. The syntax `collection.subcollection` steps into every record of `subcollection` for each record of `collection`.
This can result in changes of cardinality if records of `collection` can have multiple records of `subcollection`, and can result in duplicate records in the output if records of `subcollection` can be sourced from different records of `collection`. Good Example #1: for every person, obtains their current address. Every scalar property of `Addresses` (`address_id`, `street_number`, `street_name`, `apartment`, `zip_code`, `city`, `state`) is automatically included in the output. A record from `Addresses` can be included multiple times if multiple different `People` records have it as their current address, or it could be missing entirely if no person has it as their current address. ```py %%pydough People.current_address ``` Good Example #2: for every package, obtains the person who ordered it. The `GRAPH.` prefix is optional and implied when the term is a collection name in the graph. Every scalar property of `People` (`first_name`, `middle_name`, `last_name`, `ssn`, `birth_date`, `email`, `current_address_id`) is automatically included in the output. A record from `People` can be included multiple times if multiple packages were ordered by that person, or it could be missing entirely if that person is not the customer who ordered any package. ```py %%pydough GRAPH.Packages.customer ``` Good Example #3: for every address, obtains all packages that someone who lives at that address has ordered. Every scalar property of `Packages` (`package_id`, `customer_ssn`, `shipping_address_id`, `billing_address_id`, `order_date`, `arrival_date`). Every record from `Packages` should be included at most once since every current occupant has a single address it maps back to, and every package has a single customer it maps back to. ```py %%pydough Addresses.current_occupants.packages ``` Bad Example #1: for every address, obtains all people who used to live there.
This is invalid because the `Addresses` collection does not have a `former_occupants` property. + +```py +%%pydough +Addresses.former_occupants +``` + +### CALC + +TODO + +### Contextless Expressions + +TODO + +### BACK + +TODO + +## Collection Operators + +TODO + +### WHERE + +TODO + +### ORDER_BY + +TODO + +### TOP_K + +TODO + +### PARTITION + +TODO + +### SINGULAR + +TODO + +### BEST + +TODO + +### NEXT / PREV + +TODO + +## Induced Properties + +TODO + +### Induced Scalar Properties + +TODO + +### Induced Subcollection Properties + +TODO + +### Induced Arbitrary Joins + +TODO From 8dc17b3974e90fa79bb127958c1051e0ddd9f6e0 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 13 Jan 2025 16:06:54 -0500 Subject: [PATCH 008/112] Adding calc, contextless, and back --- documentation/dsl.md | 306 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 301 insertions(+), 5 deletions(-) diff --git a/documentation/dsl.md b/documentation/dsl.md index 4c410f22..e7008642 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -7,7 +7,7 @@ This page describes the specification of the PyDough DSL. The specification incl The examples in this document use a metadata graph (named `GRAPH`) with the following collections: - `People`: records of every known person. Scalar properties: `first_name`, `middle_name`, `last_name`, `ssn`, `birth_date`, `email`, `current_address_id`. - `Addresses`: records of every known address. Scalar properties: `address_id`, `street_number`, `street_name`, `apartment`, `zip_code`, `city`, `state`. -- `Packages`: records of every known package. Scalar properties: `package_id`, `customer_ssn`, `shipping_address_id`, `billing_address_id`, `order_date`, `arrival_date`. +- `Packages`: records of every known package. Scalar properties: `package_id`, `customer_ssn`, `shipping_address_id`, `billing_address_id`, `order_date`, `arrival_date`, `package_cost`. 
There are also the following sub-collection relationships: - `People.packages`: every package ordered by each person (reverse is `Packages.customer`). There can be 0, 1 or multiple packages ordered by a single person, but each package has exactly one person who ordered it. @@ -73,7 +73,7 @@ Good Example #2: for every package, obtains the person who shipped it address. T GRAPH.Packages.customer ``` -Good Example #3: for every address, obtains all packages that someone who lives at that address has ordered. Every scalar property of `Packages` (`package_id`, `customer_ssn`, `shipping_address_id`, `billing_address_id`, `order_date`, `arrival_date`). Every record from `Packages` should be included at most once since every current occupant has a single address it maps back to, and every package has a single customer it maps back to. +Good Example #3: for every address, obtains all packages that someone who lives at that address has ordered. Every scalar property of `Packages` (`package_id`, `customer_ssn`, `shipping_address_id`, `billing_address_id`, `order_date`, `arrival_date`, `package_cost`). Every record from `Packages` should be included at most once since every current occupant has a single address it maps back to, and every package has a single customer it maps back to. ```py %%pydough @@ -89,15 +89,311 @@ Addresses.former_occupants ### CALC -TODO +The examples so far just show selecting all properties from records of a collection. Most of the time, an analytical question will only want a subset of the properties, may want to rename them, and may want to derive new properties via calculated expressions. The way to do this with a CALC term, which is done by following a PyDough collection with parenthesis containing the expressions that should be included. + +These expressions can be positional arguments or keyword arguments. Keyword arguments use the name of the keyword as the name of the output expression. 
Positional arguments use the name of the expression, if one exists, otherwise an arbitrary name is chosen. + +The value of one of these terms in a CALC must be expressions that are singular with regards to the current context. That can mean: +- Referencing one of the scalar properties of the current collection. +- Creating a literal. +- Referencing a singular expression of a sub-collection of the current collection that is singular with regards to the current collection. +- Calling a non-aggregation function on more singular expressions. +- Calling an aggregation function on a plural expression. + +Once a CALC term is created, all terms of the current collection still exist even if they weren't part of the CALC and can still be referenced, they just will not be part of the final answer. If there are multiple CALC terms, the last one is used to determine what expressions are part of the final answer, so earlier CALCs can be used to derive intermediary expressions. If a CALC includes a term with the same name as an existing property of the collection, the existing name is overridden to include the new term. + +A CALC can also be done on the graph itself to create a collection with 1 row and columns corresponding to the properties inside the CALC. This is useful when aggregating an entire collection globally instead of with regards to a parent collection. + +Good Example #1: For every person, fetches just their first name & last name. + +```py +%%pydough +People(first_name, last_name) +``` + +Good Example #2: For every package, fetches the package id, the first & last name of the person who ordered it, and the state that it was shipped to. Also includes a field named `secret_key` that is always equal to the string `"alphabet soup"`. 
+ +```py +%%pydough +Packages( + package_id, + first_name=customer.first_name, + last_name=customer.last_name, + shipping_state=shipping_address.state, + secret_key="alphabet soup", +) +``` + +Good Example #3: For every person, finds their full name (without the middle name) and counts how many packages they purchased. + +```py +%%pydough +People( + name=JOIN_STRINGS("", first_name, last_name), + n_packages_ordered=COUNT(packages), +) +``` + +Good Example #4: For every person, finds their full name including the middle name if one exists, as well as their email. Notice that two CALCs are present, but only the terms from the second one are part of the answer. + +```py +%%pydough +People( + has_middle_name=PRESENT(middle_name) + full_name_with_middle=JOIN_STRINGS(" ", first_name, middle_name, last_name), + full_name_without_middle=JOIN_STRINGS(" ", first_name, last_name), +)( + full_name=IFF(has_middle_name, full_name_with_middle, full_name_without_middle), + email=email, +) +``` + +Good Example #4: For every person, finds the year from the most recent package they purchased, and from the first package they ever purchased. + +```py +%%pydough +People( + most_recent_package_year=YEAR(MAX(packages.order_date)), + first_ever_package_year=YEAR(MIN(packages.order_date)), +) +``` + +Good Example #5: Count how many people, packages, and addresses are known in the system. + +```py +%%pydough +GRAPH( + n_people=COUNT(People), + n_packages=COUNT(Packages), + n_addresses=COUNT(Addresses), +) +``` + +Good Example #6: For each package, lists the package id and whether the package was shipped to the current address of the person who ordered it. + +```py +%%pydough +Packages( + package_id, + shipped_to_curr_addr=shipping_address.address_id == customer.current_address.address_id +) +``` + +Bad Example #1: For each person, lists their first name, last name, and phone number. This is invalid because `People` does not have a property named `phone_number`. 
+ +```py +%%pydough +People(first_name, last_name, phone_number) +``` + +Bad Example #2: For each person, lists their combined first & last name followed by their email. This is invalid because a positional argument is included after a keyword argument. + +```py +%%pydough +People( + full_name=JOIN_STRINGS(" ", first_name, last_name), + email +) +``` + +Bad Example #3: For each person, lists the address_id of packages they have ordered. This is invalid because `packages` is a plural property of `People`, so its properties cannot be included in a calc term of `People` unless aggregated. + +```py +%%pydough +People(packages.address_id) +``` + +Bad Example #4: For each person, lists their first/last name followed by the concatenated city/state name of their current address. This is invalid because `current_address` is a plural property of `People`, so its properties cannot be included in a calc term of `People` unless aggregated. + +```py +%%pydough +People( + first_name, + last_name, + location=JOIN_STRINGS(", ", current_address.city, current_address.state), +) +``` + +Bad Example #5: For each address, finds whether the state name starts with `"C"`. This is invalid because it calls the builtin Python `.startswith` string method, which is not supported in PyDough (should have instead used a defined PyDough behavior, like the `STARTSWITH` function). + +```py +%%pydough +Addresses(is_c_state=state.startswith("c")) +``` + +Bad Example #6: For each address, finds the state bird of the state it is in. This is invalid because the `state` property of each record of `Addresses` is a scalar expression, not a subcolleciton, so it does not have any properties that can be accessed with `.` syntax. + +```py +%%pydough +Addresses(state_bird=state.bird) +``` + +Bad Example #7: For each current occupant of each address, lists their first name, last name, and city/state they live in. 
This is invalid because `city` and `state` are not properties of the current collection (`People`, accessed via `current_occupants` of each record of `Addresses`). + +```py +%%pydough +Addresses.current_occupants(first_name, last_name, city, state) +``` + +Bad Example #8: For each person include their ssn and current address. This is invalid because a collection cannot be a CALC term, and `current_address` is a sub-collection property of `People`. Instead, properties of `current_address` can be accessed. + +```py +%%pydough +People(ssn, current_address) +``` + ### Contextless Expressions -TODO +PyDough allows defining snippets of PyDough code out of context that do not make sense until they are later placed within a context. This can be done by writing a contextless expression, binding it to a variable as if it were any other Python expression, then later using it inside of PyDough code. This should always have the same effect as if the PyDough code was written fully in-context, but allows re-using common snippets. + +Good Example #1: Same as good example #4 from the CALC section, but written with contextless expressions. 
+ +```py +%%pydough +has_middle_name = PRESENT(middle_name) +full_name_with_middle = JOIN_STRINGS(" ", first_name, middle_name, last_name), +full_name_without_middle = JOIN_STRINGS(" ", first_name, last_name), +People( + full_name=IFF(has_middle_name, full_name_with_middle, full_name_without_middle), + email=email, +) +``` + +Good Example #2: for every person, finds the total value of all packages they ordered in February of any year, as well as the number of all such packages, the largest value of any such package, and the percentage of those packages that were specifically on valentine's day + +```py +%%pydough +is_february = MONTH(order_date) == 2 +february_value = KEEP_IF(package_cost, is_february) +aug_packages = packages( + is_february=is_february, + february_value=february_value, + is_valentines_day=is_february & (DAY(order_date) == 14) +) +n_feb_packages = SUM(aug_packages.is_february) +People( + ssn, + total_february_value=SUM(aug_packages.february_value), + n_february_packages=n_feb_packages, + most_expensive_february_package=MAX(aug_packages.february_value), + pct_valentine=n_feb_packages / SUM(aug_packages.is_valentines_day) +) +``` + +Bad Example #1: Just a contextless expression for a collection without the necessary context for it to make sense. + +```py +%%pydough +current_addresses(city, state) +``` + +Bad Example #2: Just a contextless expression for a scalar expression that has not been placed into a collection for it to make sense. + +```py +%%pydough +LOWER(current_occupants.first_name) +``` + +Bad Example #3: A contextless expression that does not make sense when placed into its context (`People` does not have a property named `package_cost`, so substituting it when `value` is referenced does not make sense). 
+ +```py +%%pydough +value = package_cost +People(x=ssn + value) +``` ### BACK -TODO +Part of the benefit of doing `collection.subcollection` accesses is that properties from the ancestor collection can be accessed from the current collection. This is done via a `BACK` call. Accessing properties from `BACK(n)` can be done to access properties from the n-th ancestor of the current collection. The simplest recommended way to do this is to just access a scalar property of an ancestor in order to include it in the final answer. + +Good Example #1: For every address' current occupants, lists their first name last name, and the city/state of the current address they belong to. + +```py +%%pydough +Addresses.current_occupants( + first_name, + last_name, + current_city=BACK(1).city, + current_state=BACK(1).state, +) +``` + +Good Example #2: Count the total number of cases where a package is shipped to the current address of the customer who ordered it. + +```py +%%pydough +package_info = Addresses.current_occupants.packages( + is_shipped_to_current_addr=shipping_address.address_id == BACK(2).address_id +) +GRAPH(n_cases=SUM(package_info.is_shipped_to_current_addr)) +``` + +Good Example #3: Indicate whether a package is above the average cost for all packages ordered by that customer. + +```py +%%pydough +Customers( + avg_package_cost=AVG(packages.cost) +).packages( + is_above_avg=cost > BACK(1).avg_package_cost +) +``` + +Good Example #4: For every customer, indicate what percentage of all packages billed to their current address were purchased by that same customer. + +```py +%%pydough +aug_packages = packages( + include=IFF(billing_address.address_id == BACK(2).address_id, 1, 0) +) +Addresses( + n_packages=COUNT(packages_billed_to) +).current_occupants( + ssn, + pct=100.0 * SUM(aug_packages.include) / BACK(1).n_packages +) +``` + +Bad Example #1: The `GRAPH` does not have any ancestors, so `BACK(1)` is invalid. 
+ +```py +%%pydough +GRAPH(x=BACK(1).foo) +``` + +Bad Example #2: The 1st ancestor of `People` is `GRAPH` which does not have a term named `bar`. + +```py +%%pydough +People(y=BACK(1).bar) +``` + +Bad Example #3: The 1st ancestor of `People` is `GRAPH` which does not have an ancestor, so there can be no 2nd ancestor of `People`. + +```py +%%pydough +People(z=BACK(2).fizz) +``` + +Bad Example #4: The 1st ancestor of `current_address` is `People` which does not have a term named `phone`. + +```py +%%pydough +People.current_address(a=BACK(1).phone) +``` + +Bad Example #5: Even though `cust_info` has defined `avg_package_cost`, the final expression `Customers.packages(...)` does not have `cust_info` as an ancestor, so it cannot access `BACK(1).avg_package_cost` since its 1st ancestor (`Customers`) does not have any term named `avg_package_cost`. + +```py +%%pydough +cust_info = Customers( + avg_package_cost=AVG(packages.cost) +) +Customers.packages( + is_above_avg=cost > BACK(1).avg_package_cost +) +``` ## Collection Operators From 044c2175abc42ac8f9d5558a9e6681c676773aa7 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 13 Jan 2025 16:07:46 -0500 Subject: [PATCH 009/112] Added TOC and TODOs --- documentation/dsl.md | 51 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 47 insertions(+), 4 deletions(-) diff --git a/documentation/dsl.md b/documentation/dsl.md index e7008642..e2781271 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -1,7 +1,33 @@ + + +- [PyDough DSL Spec](#pydough-dsl-spec) + * [Example Graph](#example-graph) + * [Collections](#collections) + + [Sub-Collections](#sub-collections) + + [CALC](#calc) + + [Contextless Expressions](#contextless-expressions) + + [BACK](#back) + * [Collection Operators](#collection-operators) + + [WHERE](#where) + + [ORDER_BY](#order_by) + + [TOP_K](#top_k) + + [PARTITION](#partition) + + [SINGULAR](#singular) + + [BEST](#best) + + [NEXT / PREV](#next-prev) + * [Induced 
Properties](#induced-properties) + + [Induced Scalar Properties](#induced-scalar-properties) + + [Induced Subcollection Properties](#induced-subcollection-properties) + + [Induced Arbitrary Joins](#induced-arbitrary-joins) + + + + # PyDough DSL Spec This page describes the specification of the PyDough DSL. The specification includes rules of how PyDough code should be structured and the semantics that are used when evaluating PyDough code. Not every feature in the spec is implemented in PyDough as of this time. + ## Example Graph The examples in this document use a metadata graph (named `GRAPH`) with the following collections: @@ -15,6 +41,7 @@ There are also the following sub-collection relationships: - `Packages.shipping_address`: the address that the package is shipped to (reverse is `Addresses.packages_shipped`). Every package has exactly one shipping address, but each address can have 0, 1 or multiple packages shipped to it. - `Packages.billing_address`: the address that the package is billed to (reverse is `Addresses.packages_billed`). Every package has exactly one billing address, but each address can have 0, 1 or multiple packages billed to it. + ## Collections The simplest PyDough code is scanning an entire collection. This is done by providing the name of the collection in the metadata. However, if that name is already used as a variable, then PyDough will not know to replace the name with the corresponding PyDough object. @@ -55,6 +82,7 @@ Bad Example #3: obtains every record of the `Addresses` collection (but the grap HELLO.Addresses ``` + ### Sub-Collections The next step in PyDough after accessing a collection is accessing any of its sub-collections. The syntax `collection.subcollection` steps into every record of `subcollection` for each record of `collection`. 
This can result in changes of cardinality if records of `collection` can have multiple records of `subcollection`, and can result in duplicate records in the output if records of `subcollection` can be sourced from different records of `collection`. @@ -87,6 +115,7 @@ Bad Example #1: for every address, obtains all people who used to live there. Th Addresses.former_occupants ``` + ### CALC The examples so far just show selecting all properties from records of a collection. Most of the time, an analytical question will only want a subset of the properties, may want to rename them, and may want to derive new properties via calculated expressions. The way to do this with a CALC term, which is done by following a PyDough collection with parenthesis containing the expressions that should be included. @@ -243,6 +272,7 @@ People(ssn, current_address) ``` + ### Contextless Expressions PyDough allows defining snippets of PyDough code out of context that do not make sense until they are later placed within a context. This can be done by writing a contextless expression, binding it to a variable as if it were any other Python expression, then later using it inside of PyDough code. This should always have the same effect as if the PyDough code was written fully in-context, but allows re-using common snippets. @@ -303,6 +333,7 @@ value = package_cost People(x=ssn + value) ``` + ### BACK Part of the benefit of doing `collection.subcollection` accesses is that properties from the ancestor collection can be accessed from the current collection. This is done via a `BACK` call. Accessing properties from `BACK(n)` can be done to access properties from the n-th ancestor of the current collection. The simplest recommended way to do this is to just access a scalar property of an ancestor in order to include it in the final answer. 
@@ -395,51 +426,63 @@ Customers.packages( ) ``` + ## Collection Operators TODO + ### WHERE TODO + ### ORDER_BY TODO + ### TOP_K TODO + ### PARTITION TODO + ### SINGULAR TODO + ### BEST TODO + ### NEXT / PREV TODO + ## Induced Properties -TODO +This section of the PyDough spec has not yet been defined. + ### Induced Scalar Properties -TODO +This section of the PyDough spec has not yet been defined. + ### Induced Subcollection Properties -TODO +This section of the PyDough spec has not yet been defined. + ### Induced Arbitrary Joins -TODO +This section of the PyDough spec has not yet been defined. From 0252cee6cd26469ec302899e1bb773466e5b2d28 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 13 Jan 2025 16:08:20 -0500 Subject: [PATCH 010/112] Added TOC and TODOs --- documentation/dsl.md | 46 +++++++++++++++++++++----------------------- 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/documentation/dsl.md b/documentation/dsl.md index e2781271..08791c21 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -1,31 +1,29 @@ - +# PyDough DSL Spec -- [PyDough DSL Spec](#pydough-dsl-spec) - * [Example Graph](#example-graph) - * [Collections](#collections) - + [Sub-Collections](#sub-collections) - + [CALC](#calc) - + [Contextless Expressions](#contextless-expressions) - + [BACK](#back) - * [Collection Operators](#collection-operators) - + [WHERE](#where) - + [ORDER_BY](#order_by) - + [TOP_K](#top_k) - + [PARTITION](#partition) - + [SINGULAR](#singular) - + [BEST](#best) - + [NEXT / PREV](#next-prev) - * [Induced Properties](#induced-properties) - + [Induced Scalar Properties](#induced-scalar-properties) - + [Induced Subcollection Properties](#induced-subcollection-properties) - + [Induced Arbitrary Joins](#induced-arbitrary-joins) +This page describes the specification of the PyDough DSL. The specification includes rules of how PyDough code should be structured and the semantics that are used when evaluating PyDough code. 
Not every feature in the spec is implemented in PyDough as of this time. - + - -# PyDough DSL Spec +- [Example Graph](#example-graph) +- [Collections](#collections) + * [Sub-Collections](#sub-collections) + * [CALC](#calc) + * [Contextless Expressions](#contextless-expressions) + * [BACK](#back) +- [Collection Operators](#collection-operators) + * [WHERE](#where) + * [ORDER_BY](#order_by) + * [TOP_K](#top_k) + * [PARTITION](#partition) + * [SINGULAR](#singular) + * [BEST](#best) + * [NEXT / PREV](#next-prev) +- [Induced Properties](#induced-properties) + * [Induced Scalar Properties](#induced-scalar-properties) + * [Induced Subcollection Properties](#induced-subcollection-properties) + * [Induced Arbitrary Joins](#induced-arbitrary-joins) -This page describes the specification of the PyDough DSL. The specification includes rules of how PyDough code should be structured and the semantics that are used when evaluating PyDough code. Not every feature in the spec is implemented in PyDough as of this time. + ## Example Graph From 98c722b6bc7901388d0b3b180db29234d7cf436b Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 13 Jan 2025 16:12:16 -0500 Subject: [PATCH 011/112] Changing highlighting --- documentation/dsl.md | 76 ++++++++++++++++++++++---------------------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/documentation/dsl.md b/documentation/dsl.md index 08791c21..b9dd3fea 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -44,28 +44,28 @@ There are also the following sub-collection relationships: The simplest PyDough code is scanning an entire collection. This is done by providing the name of the collection in the metadata. However, if that name is already used as a variable, then PyDough will not know to replace the name with the corresponding PyDough object. -Good Example #1: obtains every record of the `People` collection. 
Every scalar property of `People` (`first_name`, `middle_name`, `last_name`, `ssn`, `birth_date`, `email`, `current_address_id`) is automatically included in the output. +**Good Example #1**: obtains every record of the `People` collection. Every scalar property of `People` (`first_name`, `middle_name`, `last_name`, `ssn`, `birth_date`, `email`, `current_address_id`) is automatically included in the output. ```py %%pydough People ``` -Good Example #2: obtains every record of the `Addresses` collection. The `GRAPH.` prefix is optional and implied when the term is a collection name in the graph. Every scalar property of `Addresses` (`address_id`, `street_number`, `street_name`, `apartment`, `zip_code`, `city`, `state`) is automatically included in the output. +**Good Example #2**: obtains every record of the `Addresses` collection. The `GRAPH.` prefix is optional and implied when the term is a collection name in the graph. Every scalar property of `Addresses` (`address_id`, `street_number`, `street_name`, `apartment`, `zip_code`, `city`, `state`) is automatically included in the output. ```py %%pydough GRAPH.Addresses ``` -Bad Example #1: obtains every record of the `Products` collection (there is no `Products` collection). +**Bad Example #1**: obtains every record of the `Products` collection (there is no `Products` collection). ```py %%pydough Addresses ``` -Bad Example #2: obtains every record of the `Addresses` collection (but the name `Addresses` has been reassigned to a variable). +**Bad Example #2**: obtains every record of the `Addresses` collection (but the name `Addresses` has been reassigned to a variable). ```py %%pydough @@ -73,7 +73,7 @@ Addresses = 42 Addresses ``` -Bad Example #3: obtains every record of the `Addresses` collection (but the graph name `HELLO` is the wrong graph name for this example). +**Bad Example #3**: obtains every record of the `Addresses` collection (but the graph name `HELLO` is the wrong graph name for this example). 
```py %%pydough @@ -85,28 +85,28 @@ HELLO.Addresses The next step in PyDough after accessing a collection is accessing any of its sub-collections. The syntax `collection.subcollection` steps into every record of `subcollection` for each record of `collection`. This can result in changes of cardinality if records of `collection` can have multiple records of `subcollection`, and can result in duplicate records in the output if records of `subcollection` can be sourced from different records of `collection`. -Good Example #1: for every person, obtains their current address. Every scalar property of `Addresses` (`address_id`, `street_number`, `street_name`, `apartment`, `zip_code`, `city`, `state`) is automatically included in the output. A record from `Addresses` can be included multiple times if multiple different `People` records have it as their current address, or it could be missing entirely if no person has it as their current address. +**Good Example #1**: for every person, obtains their current address. Every scalar property of `Addresses` (`address_id`, `street_number`, `street_name`, `apartment`, `zip_code`, `city`, `state`) is automatically included in the output. A record from `Addresses` can be included multiple times if multiple different `People` records have it as their current address, or it could be missing entirely if no person has it as their current address. ```py %%pydough People.current_addresses ``` -Good Example #2: for every package, obtains the person who shipped it address. The `GRAPH.` prefix is optional and implied when the term is a collection name in the graph. Every scalar property of `People` (`first_name`, `middle_name`, `last_name`, `ssn`, `birth_date`, `email`, `current_address_id`) is automatically included in the output. A record from `People` can be included multiple times if multiple packages were ordered by that person, or it could be missing entirely if that person is not the customer who ordered any package. 
**Good Example #2**: for every package, obtains the person who ordered it.
```py %%pydough @@ -131,14 +131,14 @@ Once a CALC term is created, all terms of the current collection still exist eve A CALC can also be done on the graph itself to create a collection with 1 row and columns corresponding to the properties inside the CALC. This is useful when aggregating an entire collection globally instead of with regards to a parent collection. -Good Example #1: For every person, fetches just their first name & last name. +**Good Example #1**: For every person, fetches just their first name & last name. ```py %%pydough People(first_name, last_name) ``` -Good Example #2: For every package, fetches the package id, the first & last name of the person who ordered it, and the state that it was shipped to. Also includes a field named `secret_key` that is always equal to the string `"alphabet soup"`. +**Good Example #2**: For every package, fetches the package id, the first & last name of the person who ordered it, and the state that it was shipped to. Also includes a field named `secret_key` that is always equal to the string `"alphabet soup"`. ```py %%pydough @@ -151,7 +151,7 @@ Packages( ) ``` -Good Example #3: For every person, finds their full name (without the middle name) and counts how many packages they purchased. +**Good Example #3**: For every person, finds their full name (without the middle name) and counts how many packages they purchased. ```py %%pydough @@ -161,7 +161,7 @@ People( ) ``` -Good Example #4: For every person, finds their full name including the middle name if one exists, as well as their email. Notice that two CALCs are present, but only the terms from the second one are part of the answer. +**Good Example #4**: For every person, finds their full name including the middle name if one exists, as well as their email. Notice that two CALCs are present, but only the terms from the second one are part of the answer. 
```py %%pydough @@ -175,7 +175,7 @@ People( ) ``` -Good Example #4: For every person, finds the year from the most recent package they purchased, and from the first package they ever purchased. +**Good Example #4**: For every person, finds the year from the most recent package they purchased, and from the first package they ever purchased. ```py %%pydough @@ -185,7 +185,7 @@ People( ) ``` -Good Example #5: Count how many people, packages, and addresses are known in the system. +**Good Example #5**: Count how many people, packages, and addresses are known in the system. ```py %%pydough @@ -196,7 +196,7 @@ GRAPH( ) ``` -Good Example #6: For each package, lists the package id and whether the package was shipped to the current address of the person who ordered it. +**Good Example #6**: For each package, lists the package id and whether the package was shipped to the current address of the person who ordered it. ```py %%pydough @@ -206,14 +206,14 @@ Packages( ) ``` -Bad Example #1: For each person, lists their first name, last name, and phone number. This is invalid because `People` does not have a property named `phone_number`. +**Bad Example #1**: For each person, lists their first name, last name, and phone number. This is invalid because `People` does not have a property named `phone_number`. ```py %%pydough People(first_name, last_name, phone_number) ``` -Bad Example #2: For each person, lists their combined first & last name followed by their email. This is invalid because a positional argument is included after a keyword argument. +**Bad Example #2**: For each person, lists their combined first & last name followed by their email. This is invalid because a positional argument is included after a keyword argument. ```py %%pydough @@ -223,14 +223,14 @@ People( ) ``` -Bad Example #3: For each person, lists the address_id of packages they have ordered. 
This is invalid because `packages` is a plural property of `People`, so its properties cannot be included in a calc term of `People` unless aggregated. +**Bad Example #3**: For each person, lists the address_id of packages they have ordered. This is invalid because `packages` is a plural property of `People`, so its properties cannot be included in a calc term of `People` unless aggregated. ```py %%pydough People(packages.address_id) ``` -Bad Example #4: For each person, lists their first/last name followed by the concatenated city/state name of their current address. This is invalid because `current_address` is a plural property of `People`, so its properties cannot be included in a calc term of `People` unless aggregated. +**Bad Example #4**: For each person, lists their first/last name followed by the concatenated city/state name of their current address. This is invalid because `current_address` is a plural property of `People`, so its properties cannot be included in a calc term of `People` unless aggregated. ```py %%pydough @@ -241,28 +241,28 @@ People( ) ``` -Bad Example #5: For each address, finds whether the state name starts with `"C"`. This is invalid because it calls the builtin Python `.startswith` string method, which is not supported in PyDough (should have instead used a defined PyDough behavior, like the `STARTSWITH` function). +**Bad Example #5**: For each address, finds whether the state name starts with `"C"`. This is invalid because it calls the builtin Python `.startswith` string method, which is not supported in PyDough (should have instead used a defined PyDough behavior, like the `STARTSWITH` function). ```py %%pydough Addresses(is_c_state=state.startswith("c")) ``` -Bad Example #6: For each address, finds the state bird of the state it is in. This is invalid because the `state` property of each record of `Addresses` is a scalar expression, not a subcolleciton, so it does not have any properties that can be accessed with `.` syntax. 
+**Bad Example #6**: For each address, finds the state bird of the state it is in. This is invalid because the `state` property of each record of `Addresses` is a scalar expression, not a subcolleciton, so it does not have any properties that can be accessed with `.` syntax. ```py %%pydough Addresses(state_bird=state.bird) ``` -Bad Example #7: For each current occupant of each address, lists their first name, last name, and city/state they live in. This is invalid because `city` and `state` are not properties of the current collection (`People`, accessed via `current_occupants` of each record of `Addresses`). +**Bad Example #7**: For each current occupant of each address, lists their first name, last name, and city/state they live in. This is invalid because `city` and `state` are not properties of the current collection (`People`, accessed via `current_occupants` of each record of `Addresses`). ```py %%pydough Addresses.current_occupants(first_name, last_name, city, state) ``` -Bad Example #8: For each person include their ssn and current address. This is invalid because a collection cannot be a CALC term, and `current_address` is a sub-collection property of `People`. Instead, properties of `current_address` can be accessed. +**Bad Example #8**: For each person include their ssn and current address. This is invalid because a collection cannot be a CALC term, and `current_address` is a sub-collection property of `People`. Instead, properties of `current_address` can be accessed. ```py %%pydough @@ -275,7 +275,7 @@ People(ssn, current_address) PyDough allows defining snippets of PyDough code out of context that do not make sense until they are later placed within a context. This can be done by writing a contextless expression, binding it to a variable as if it were any other Python expression, then later using it inside of PyDough code. This should always have the same effect as if the PyDough code was written fully in-context, but allows re-using common snippets. 
-Good Example #1: Same as good example #4 from the CALC section, but written with contextless expressions. +**Good Example #1**: Same as good example #4 from the CALC section, but written with contextless expressions. ```py %%pydough @@ -288,7 +288,7 @@ People( ) ``` -Good Example #2: for every person, finds the total value of all packages they ordered in February of any year, as well as the number of all such packages, the largest value of any such package, and the percentage of those packages that were specifically on valentine's day +**Good Example #2**: for every person, finds the total value of all packages they ordered in February of any year, as well as the number of all such packages, the largest value of any such package, and the percentage of those packages that were specifically on valentine's day ```py %%pydough @@ -309,21 +309,21 @@ People( ) ``` -Bad Example #1: Just a contextless expression for a collection without the necessary context for it to make sense. +**Bad Example #1**: Just a contextless expression for a collection without the necessary context for it to make sense. ```py %%pydough current_addresses(city, state) ``` -Bad Example #2: Just a contextless expression for a scalar expression that has not been placed into a collection for it to make sense. +**Bad Example #2**: Just a contextless expression for a scalar expression that has not been placed into a collection for it to make sense. ```py %%pydough LOWER(current_occupants.first_name) ``` -Bad Example #3: A contextless expression that does not make sense when placed into its context (`People` does not have a property named `package_cost`, so substituting it when `value` is referenced does not make sense). +**Bad Example #3**: A contextless expression that does not make sense when placed into its context (`People` does not have a property named `package_cost`, so substituting it when `value` is referenced does not make sense). 
```py %%pydough @@ -336,7 +336,7 @@ People(x=ssn + value) Part of the benefit of doing `collection.subcollection` accesses is that properties from the ancestor collection can be accessed from the current collection. This is done via a `BACK` call. Accessing properties from `BACK(n)` can be done to access properties from the n-th ancestor of the current collection. The simplest recommended way to do this is to just access a scalar property of an ancestor in order to include it in the final answer. -Good Example #1: For every address' current occupants, lists their first name last name, and the city/state of the current address they belong to. +**Good Example #1**: For every address' current occupants, lists their first name last name, and the city/state of the current address they belong to. ```py %%pydough @@ -348,7 +348,7 @@ Addresses.current_occupants( ) ``` -Good Example #2: Count the total number of cases where a package is shipped to the current address of the customer who ordered it. +**Good Example #2**: Count the total number of cases where a package is shipped to the current address of the customer who ordered it. ```py %%pydough @@ -358,7 +358,7 @@ package_info = Addresses.current_occupants.packages( GRAPH(n_cases=SUM(package_info.is_shipped_to_current_addr)) ``` -Good Example #3: Indicate whether a package is above the average cost for all packages ordered by that customer. +**Good Example #3**: Indicate whether a package is above the average cost for all packages ordered by that customer. ```py %%pydough @@ -369,7 +369,7 @@ Customers( ) ``` -Good Example #4: For every customer, indicate what percentage of all packages billed to their current address were purchased by that same customer. +**Good Example #4**: For every customer, indicate what percentage of all packages billed to their current address were purchased by that same customer. 
```py %%pydough @@ -384,35 +384,35 @@ Addresses( ) ``` -Bad Example #1: The `GRAPH` does not have any ancestors, so `BACK(1)` is invalid. +**Bad Example #1**: The `GRAPH` does not have any ancestors, so `BACK(1)` is invalid. ```py %%pydough GRAPH(x=BACK(1).foo) ``` -Bad Example #2: The 1st ancestor of `People` is `GRAPH` which does not have a term named `bar`. +**Bad Example #2**: The 1st ancestor of `People` is `GRAPH` which does not have a term named `bar`. ```py %%pydough People(y=BACK(1).bar) ``` -Bad Example #3: The 1st ancestor of `People` is `GRAPH` which does not have an ancestor, so there can be no 2nd ancestor of `People`. +**Bad Example #3**: The 1st ancestor of `People` is `GRAPH` which does not have an ancestor, so there can be no 2nd ancestor of `People`. ```py %%pydough People(z=BACK(2).fizz) ``` -Bad Example #4: The 1st ancestor of `current_address` is `People` which does not have a term named `phone`. +**Bad Example #4**: The 1st ancestor of `current_address` is `People` which does not have a term named `phone`. ```py %%pydough People.current_address(a=BACK(1).phone) ``` -Bad Example #5: Even though `cust_info` has defined `avg_package_cost`, the final expression `Customers.packages(...)` does not have `cust_info` as an ancestor, so it cannot access `BACK(1).avg_package_cost` since its 1st ancestor (`Customers`) does not have any term named `avg_package_cost`. +**Bad Example #5**: Even though `cust_info` has defined `avg_package_cost`, the final expression `Customers.packages(...)` does not have `cust_info` as an ancestor, so it cannot access `BACK(1).avg_package_cost` since its 1st ancestor (`Customers`) does not have any term named `avg_package_cost`. 
```py %%pydough From b03a5eca6c74d46d8dd746d14d00fef37fa99d73 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 13 Jan 2025 16:21:46 -0500 Subject: [PATCH 012/112] Addded more examples --- documentation/dsl.md | 56 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 53 insertions(+), 3 deletions(-) diff --git a/documentation/dsl.md b/documentation/dsl.md index b9dd3fea..522165ee 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -58,6 +58,13 @@ People GRAPH.Addresses ``` +**Good Example #3**: obtains every record of the `Packages` collection. Every scalar property of `Packages` (`package_id`, `customer_ssn`, `shipping_address_id`, `billing_address_id`, `order_date`, `arrival_date`, `package_cost`) is automatically included in the output. + +```py +%%pydough +Packages +``` + **Bad Example #1**: obtains every record of the `Products` collection (there is no `Products` collection). ```py @@ -80,6 +87,14 @@ Addresses HELLO.Addresses ``` +**Bad Example #4**: obtains every record of the `People` collection (but the name `People` has been reassigned to a variable). + +```py +%%pydough +People = "not a collection" +People +``` + ### Sub-Collections @@ -106,6 +121,13 @@ GRAPH.Packages.customer Addresses.current_occupants.packages ``` +**Good Example #4**: for every person, obtains all packages they have ordered. Every scalar property of `Packages` (`package_id`, `customer_ssn`, `shipping_address_id`, `billing_address_id`, `order_date`, `arrival_date`, `package_cost`). Every record from `Packages` should be included at most once since every package has a single customer it maps back to. + +```py +%%pydough +People.packages +``` + **Bad Example #1**: for every address, obtains all people who used to live there. This is invalid because the `Addresses` collection does not have a `former_occupants` property. 
```py @@ -113,6 +135,13 @@ Addresses.current_occupants.packages Addresses.former_occupants ``` +**Bad Example #2**: for every package, obtains all addresses it was shipped to. This is invalid because the `Packages` collection does not have a `shipping_addresses` property (it does have a `shipping_address` property). + +```py +%%pydough +Packages.shipping_addresses +``` + ### CALC @@ -175,7 +204,7 @@ People( ) ``` -**Good Example #4**: For every person, finds the year from the most recent package they purchased, and from the first package they ever purchased. +**Good Example #5**: For every person, finds the year from the most recent package they purchased, and from the first package they ever purchased. ```py %%pydough @@ -185,7 +214,7 @@ People( ) ``` -**Good Example #5**: Count how many people, packages, and addresses are known in the system. +**Good Example #6**: Count how many people, packages, and addresses are known in the system. ```py %%pydough @@ -196,7 +225,7 @@ GRAPH( ) ``` -**Good Example #6**: For each package, lists the package id and whether the package was shipped to the current address of the person who ordered it. +**Good Example #7**: For each package, lists the package id and whether the package was shipped to the current address of the person who ordered it. ```py %%pydough @@ -269,6 +298,12 @@ Addresses.current_occupants(first_name, last_name, city, state) People(ssn, current_address) ``` +**Bad Example #9**: For each person, lists their first name, last name, and the sum of the package costs. This is invalid because `SUM` is an aggregation function and cannot be used in a CALC term without specifying the sub-collection it should be applied to. 
+ +```py +%%pydough +People(first_name, last_name, total_cost=SUM(package_cost)) +``` ### Contextless Expressions @@ -331,6 +366,14 @@ value = package_cost People(x=ssn + value) ``` +**Bad Example #4**: A contextless expression that does not make sense when placed into its context (`People` does not have a property named `order_date`, so substituting it when `is_february` is referenced does not make sense). + +```py +%%pydough +is_february = MONTH(order_date) == 2 +People(february=is_february) +``` + ### BACK @@ -424,6 +467,13 @@ Customers.packages( ) ``` +**Bad Example #6**: The 1st ancestor of `current_occupants` is `Addresses` which does not have a term named `phone`. + +```py +%%pydough +Addresses.current_occupants(a=BACK(1).phone) +``` + ## Collection Operators From b569bfb7566256545cb8af924bef655dce211151 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Tue, 14 Jan 2025 10:33:59 -0500 Subject: [PATCH 013/112] Added WHERE, ORDER_BY, and TOP_K documentation --- documentation/dsl.md | 267 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 263 insertions(+), 4 deletions(-) diff --git a/documentation/dsl.md b/documentation/dsl.md index 522165ee..5d818362 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -477,22 +477,281 @@ Addresses.current_occupants(a=BACK(1).phone) ## Collection Operators -TODO +So far all of the examples shown have been about accessing collections/sub-collections and deriving expression in terms of the current context, child contexts, and ancestor context. PyDough has other operations that access/create/augment collections. ### WHERE -TODO +A core PyDough operation is the ability to filter the records of a collection. This is done by appending a PyDough collection with `.WHERE(cond)` where `cond` is any expression that could have been placed in a `CALC` term and should have a True/False value. Every record where `cond` evaluates to True will be preserved, and the rest will be dropped from the answer. 
The terms in the collection are unchanged by the `WHERE` clause, since the only change is which records are kept/dropped. + +**Good Example #1**: For every person who has a middle name and and email that ends with `"gmail.com"`, fetches their first name and last name. + +```py +%%pydough +People.WHERE(PRESENT(middle_name) & ENDSWITH(email, "gmail.com"))(first_name, last_name) +``` + +**Good Example #2**: For every package where the package cost is greater than 100, fetches the package id and the state it was shipped to. + +```py +%%pydough +Packages.WHERE(package_cost > 100)(package_id, shipping_state=shipping_address.state) +``` + +**Good Example #3**: For every person who has ordered more than 5 packages, fetches their first name, last name, and email. + +```py +%%pydough +People(first_name, last_name, email).WHERE(COUNT(packages) > 5) +``` + +**Good Example #4**: Finds every person whose most recent order was shipped in the year 2023, and lists all properties of that person. + +```py +%%pydough +People.WHERE(YEAR(MAX(packages.order_date)) == 2023) +``` + +**Good Example #4**: Counts how many packages were ordered in January of 2018. + +```py +%%pydough +packages_jan_2018 = Packages.WHERE( + (YEAR(order_date) == 2018) & (MONTH(order_date) == 1) +) +GRAPH(n_jan_2018=COUNT(selected_packages)) +``` + +**Bad Example #1**: For every person, fetches their first name and last name only if they have a phone number. This is invalid because `People` does not have a property named `phone_number`. + +```py +%%pydough +People.WHERE(PRESENT(phone_number))(first_name, last_name) +``` + +**Bad Example #2**: For every package, fetches the package id only if the package cost is greater than 100 and the shipping state is Texas. This is invalid because `shipping_state` is not a property of `Packages`. 
+ +```py +%%pydough +Packages.WHERE((package_cost > 100) & (shipping_state == "TX"))(package_id) +``` + +**Bad Example #3**: For every package, fetches the package id only if the package cost is greater than 100 and the shipping state is Texas. This is invalid because `and` is used instead of `&`. + +```py +%%pydough +Packages.WHERE((package_cost > 100) and (shipping_address.state == "TX"))(package_id) +``` + +**Bad Example #4**: Obtain every person whose packages were shipped in the month of June. This is invalid because `packages` is a plural property of `People`, so `MONTH(packages.order_date) == 6` is a plural expression with regards to `People` that cannot be used as a filtering condition. + +```py +%%pydough +People.WHERE(MONTH(packages.order_date) == 6) +``` ### ORDER_BY -TODO +Another operation that can be done onto PyDough collections is sorting them. This is done by appending a collection with `.ORDER_BY(...)` which will order the collection by the collation terms between the parenthesis. The collation terms must be 1+ expressions that can be inside of a CALC term (singular expressions with regards to the current context), each decorated with information making it usable as a collation. + +An expression becomes a collation expression when it is appended with `.ASC()` (indicating that the expression should be used to sort in ascending order) or `.DESC()` (indicating that the expression should be used to sort in descending order). Both `.ASC()` and `.DESC()` take in an optional argument `na_pos` indicating where to place null values. This keyword argument can be either `"first"` or `"last"`, and the default is `"first"` for `.ASC()` and `"last"` for `.DESC()`. The way the sorting works is that it orders by hte first collation term provided, and in cases of ties it moves on to the second collation term, and if there are ties in that it moves on to the third, and so on until there are no more terms to sort by, at which point the ties are broken arbitrarily. 
+ +If there are multiple `ORDER_BY` terms, the last one is the one that takes precedence. The terms in the collection are unchanged by the `ORDER_BY` clause, since the only change is the order of the records. + +> [!WARNING] +> In the current version of PyDough, the behavior when the expressions inside an `ORDER_BY` clause are not collation expressions with `.ASC()` or `.DESC()` is undefined/unsupported. + +**Good Example #1**: Orders every person alphabetically by last name, then first name, then middle name (people with no middle name going last). + +```py +%%pydough +People.ORDER_BY(last_name.ASC(), first_name.ASC(), middle_name.ASC(na_pos="last")) +``` + +**Good Example #2**: For every person lists their ssn & how many packages they have ordered, and orders them from highest number of orders to lowest, breaking ties in favor of whoever is oldest. + +```py +%%pydough +People( + ssn, n_packages=COUNT(packages).DESC() +).ORDER_BY( + n_packages.DESC(), birth_date.ASC() +) +``` + +**Good Example #3**: Finds every address that has at least 1 person living in it and sorts them highest-to-lowest by number of occupants, with ties broken by address id in ascending order. + +```py +%%pydough +Addresses.WHERE( + HAS(current_occupants) +).ORDER_BY( + COUNT(current_occupants).DESC(), address_id.ASC() +) +``` + +**Good Example #4**: Sorts every person alphabetically by the state they live in, then the city they live in, then by their ssn. People without a current address should go last. + +```py +%%pydough +People.ORDER_BY( + current_address.state.ASC(na_pos="last"), + current_address.city.ASC(na_pos="last"), + ssn.ASC(), +) +``` + +**Good Example #5**: Same as good example #4, but written so it only includes people who are current occupants of an address in Ohio. 
+ +```py +%%pydough +Addresses.WHERE( + state == "OHIO" +).current_occupants.ORDER_BY( + BACK(1).state.ASC(), + BACK(1).city.ASC(), + ssn.ASC(), +) +``` + +**Bad Example #1**: Sorts each person by their account balance in descending order. This is invalid because the `People` collection does not have an `account_balance` property. + +```py +%%pydough +People.ORDER_BY(account_balance.DESC()) +``` + +**Bad Example #2**: Sorts each address by the birth date date of the people who live there. This is invalid because `current_occupants` is a plural property of `Addresses`, so `current_occupants.birth_date` is plural and cannot be used as an ordering term unless aggregated. + +```py +%%pydough +Addresses.ORDER_BY(current_occupants.ASC()) +``` + +**Bad Example #3**: Same as good example #5, but incorrect because `BACK(2)` is used, and `BACK(2)` refers to the 2nd ancestor of `current_occupants` which is `GRAPH`, which does not have any properties named `state` or `city`. + +```py +%%pydough +Addresses.WHERE( + state == "OHIO" +).current_occupants.ORDER_BY( + BACK(2).state.ASC(), + BACK(2).city.ASC(), + ssn.ASC(), +) +``` + +**Bad Example #4**: Same as bad example #3, but incorrect because `BACK(3)` is used, and `BACK(3)` refers to the 3rd ancestor of `current_occupants` which does not exist because the 2nd ancestor is `GRAPH`, which does not have any ancestors. + +```py +%%pydough +Addresses.WHERE( + state == "OHIO" +).current_occupants.ORDER_BY( + BACK(2).state.ASC(), + BACK(2).city.ASC(), + ssn.ASC(), +) +``` + +**Bad Example #5**: Sorts every person by their first name. This is invalid because no `.ASC()` or `.DESC()` term is provided. + +```py +%%pydough +People.ORDER_BY(first_name) +``` + +**Bad Example #6**: Sorts every person. This is invalid because no collation terms are provided. + +```py +%%pydough +People.ORDER_BY() +``` ### TOP_K -TODO +A similar operation to `ORDER_BY` is `TOP_K`. 
The `TOP_K` operation also sorts a collection, but then uses the ordered results in order to pick the first `k`, values, where `k` is a provided constant. + +The syntax for this is `.TOP_K(k, by=...)` where `k` is a positive integer and the `by` clause is either a single collation term (as seen in `ORDER_BY`) or an iterable of collation terms (e.g. a list or tuple). The same restrictions as `ORDER_BY` apply to `TOP_K` regarding their collation terms. + +The terms in the collection are unchanged by the `TOP_K` clause, since the only change is the order of the records and which ones are kept/dropped. + +**Good Example #1**: Finds the 10 people who have ordered the most packages, including their first/last name, birth date, and the number of packages. If there is a tie, break it by the lowest ssn. + +```py +%%pydough +People( + first_name, + last_name, + birth_date, + n_packages=COUNT(packages) +).TOP_K(10, by=(n_packages.DESC(), ssn.ASC())) +``` + +**Good Example #2**: Finds the 5 most recently shipped packages, with ties broken arbitrarily. + +```py +%%pydough +Packages.TOP_K(5, by=order_date.DESC()) +``` + +**Good Example #3**: Finds the 100 addresses that have most recently had packages either shipped or billed to them, breaking ties arbitrarily. + +```py +%%pydough +default_date = datetime.date(1970, 1, 1) +most_recent_ship = DEFAULT_TO(MAX(packages_shipped.order_date), default_date) +most_recent_bill = DEFAULT_TO(MAX(packages_billed.order_date), default_date) +most_recent_package = IFF(most_recent_ship < most_recent_bill, most_recent_ship, most_recent_bill) +Addresses.TOP_K(10, by=most_recent_package.DESC()) +``` + +**Bad Example #1**: Finds the 5 people with the lowest GPA. This is invalid because the `People` collection does not have a `gpa` property. + +```py +%%pydough +People.TOP_K(5, by=gpa.ASC()) +``` + +**Bad Example #2**: Finds the 25 addresses with the earliest packages billed to them, by arrival date. 
This is invalid because `packages_billed` is a plural property of `Addresses`, so `packages_billed.arrival_date` cannot be used as a collation expression for `Addresses`. + +```py +%%pydough +Addresses.packages_billed(25, by=gpa.packages_billed.arrival_date()) +``` + +**Bad Example #3**: Finds the top 100 people currently living in the city of San Francisco. This is invalid because the `by` clause is absent. + +```py +%%pydough +People.WHERE( + current_address.city == "San Francisco" +).TOP_K(100) +``` + +**Bad Example #4**: Finds the top packages by highest value. This is invalid because there is no `k` value. + +```py +%%pydough +Packages.TOP_K(by=package_cost.DESC()) +``` + +**Bad Example #5**: Finds the top 300 addresses. This is invalid because the `by` clause is empty + +```py +%%pydough +Addresses.TOP_K(300, by=()) +``` + +**Bad Example #6**: Finds the 1000 people by birth date. This is invalid because the collation term does not have `.ASC()` or `.DESC()`. + +```py +%%pydough +People.TOP_K(1000, by=birth_date) +``` + ### PARTITION From 7607feacaac410eb6722a52341a6f92d17d945c6 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Tue, 14 Jan 2025 10:34:32 -0500 Subject: [PATCH 014/112] Fixing typo --- documentation/dsl.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/documentation/dsl.md b/documentation/dsl.md index 5d818362..4ed3c1dd 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -512,7 +512,7 @@ People(first_name, last_name, email).WHERE(COUNT(packages) > 5) People.WHERE(YEAR(MAX(packages.order_date)) == 2023) ``` -**Good Example #4**: Counts how many packages were ordered in January of 2018. +**Good Example #5**: Counts how many packages were ordered in January of 2018. 
```py %%pydough From 215735b2646ac70e230b3e35b9975c72d2630945 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Tue, 14 Jan 2025 10:41:35 -0500 Subject: [PATCH 015/112] Added extra example --- documentation/dsl.md | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/documentation/dsl.md b/documentation/dsl.md index 4ed3c1dd..76675e57 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -708,6 +708,17 @@ most_recent_package = IFF(most_recent_ship < most_recent_bill, most_recent_ship, Addresses.TOP_K(10, by=most_recent_package.DESC()) ``` +**Good Example #4**: Finds the top 3 people who have spent the most money on packages, including their first/last name, and the total cost of all of their packages. + +```py +%%pydough +People( + first_name, + last_name, + total_package_cost=SUM(packages.package_cost) +).TOP_K(3, by=total_package_cost.DESC()) +``` + **Bad Example #1**: Finds the 5 people with the lowest GPA. This is invalid because the `People` collection does not have a `gpa` property. ```py @@ -752,7 +763,6 @@ Addresses.TOP_K(300, by=()) People.TOP_K(1000, by=birth_date) ``` - ### PARTITION From 0b8e1fb4066080a3b1d3b0ff69da59dd90ae7274 Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Tue, 14 Jan 2025 10:41:54 -0500 Subject: [PATCH 016/112] Update pydough/unqualified/unqualified_node.py Co-authored-by: Hadia Ahmed --- pydough/unqualified/unqualified_node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydough/unqualified/unqualified_node.py b/pydough/unqualified/unqualified_node.py index cfe6f25e..2a55a3a2 100644 --- a/pydough/unqualified/unqualified_node.py +++ b/pydough/unqualified/unqualified_node.py @@ -133,7 +133,7 @@ def __getitem__(self, key): def __bool__(self): raise PyDoughUnqualifiedException( - "PyDough code cannot be treated as a boolean. If you intend to do a logical operation, use `|`, `&` or `~` instead of `or`, `and` and `not`." 
+ "PyDough code cannot be treated as a boolean. If you intend to do a logical operation, use `|`, `&` and `~` instead of `or`, `and` and `not`." ) def __add__(self, other: object): From a5190f8583792291e4c4563693b3e86dad274857 Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Tue, 14 Jan 2025 10:42:16 -0500 Subject: [PATCH 017/112] Update documentation/functions.md Revision 2 Co-authored-by: Hadia Ahmed --- documentation/functions.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/documentation/functions.md b/documentation/functions.md index bf77f3a8..6c3d0984 100644 --- a/documentation/functions.md +++ b/documentation/functions.md @@ -83,7 +83,7 @@ Customers( ``` > [!WARNING] -> Do **NOT** use use chained inequalities like `a <= b <= c`, as this can cause undefined incorrect behavior in PyDough. Instead, use expressions like `(a <= b) & (b <= c)`. +> Chained inequalities, like `a <= b <= c`, can cause undefined/incorrect behavior in PyDough. Instead, use expressions like `(a <= b) & (b <= c)`. 
### Logical From a0d43f3e376727c51508126aba037c4c6060bcfe Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Tue, 14 Jan 2025 10:42:30 -0500 Subject: [PATCH 018/112] Update documentation/functions.md Revision 3 Co-authored-by: Hadia Ahmed --- documentation/functions.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/documentation/functions.md b/documentation/functions.md index 6c3d0984..a45b3ae6 100644 --- a/documentation/functions.md +++ b/documentation/functions.md @@ -69,7 +69,7 @@ Lineitems(value = (extended_price * (1 - discount) + 1.0) / part.retail_price) ### Comparisons -Expression values can be compared to one another with the standard comparison operators `<=`, `<`, `==`, `!=`, `>` and `>=`: +Expression values can be compared using standard comparison operators: `<=`, `<`, `==`, `!=`, `>` and `>=`: ```py Customers( From 24c1e4a8d7c89d65c61a60e916aea2617250d8fa Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Tue, 14 Jan 2025 10:42:40 -0500 Subject: [PATCH 019/112] Update documentation/functions.md Revision 4 Co-authored-by: Hadia Ahmed --- documentation/functions.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/documentation/functions.md b/documentation/functions.md index a45b3ae6..ae39f668 100644 --- a/documentation/functions.md +++ b/documentation/functions.md @@ -112,7 +112,7 @@ Below is each unary operator currently supported in PyDough. 
### Negation -A numerical expression can have its sign flipped by prefixing it with the `-` operator: +A numerical expression's sign can be flipped by prefixing it with the `-` operator: ```py Lineitems(lost_value = extended_price * (-discount)) From 8009e05f5f39001bf3b17d0ca0e4e208813dc5b5 Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Tue, 14 Jan 2025 10:42:55 -0500 Subject: [PATCH 020/112] Update documentation/functions.md Revision 5 Co-authored-by: Hadia Ahmed --- documentation/functions.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/documentation/functions.md b/documentation/functions.md index ae39f668..3aa3185a 100644 --- a/documentation/functions.md +++ b/documentation/functions.md @@ -136,7 +136,7 @@ Customers( ``` > [!WARNING] -> PyDough currently only supports combinations of `string[start:stop:step]` where `step` is either 1 or missing, and where both `start` and `stop` are either non-negative values or are missing. +> PyDough currently only supports combinations of `string[start:stop:step]` where `step` is either 1 or omitted, and both `start` and `stop` are either non-negative values or omitted. 
## String Functions From 5d87b59ba16ee106ab93ee4f80035e6ea570aa58 Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Tue, 14 Jan 2025 10:43:08 -0500 Subject: [PATCH 021/112] Update documentation/functions.md Revision 6 Co-authored-by: Hadia Ahmed --- documentation/functions.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/documentation/functions.md b/documentation/functions.md index 3aa3185a..6e86fce2 100644 --- a/documentation/functions.md +++ b/documentation/functions.md @@ -164,7 +164,7 @@ Customers(uppercase_name = UPPER(name)) ### STARTSWITH -The `STARTSWITH` function returns whether its first argument begins with its second argument as a string prefix: +The `STARTSWITH` function checks if its first argument begins with its second argument as a string prefix: ```py Parts(begins_with_yellow = STARTSWITH(name, "yellow")) From a3ee53698b7cbf9916f890450303f505077210e7 Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Tue, 14 Jan 2025 10:43:24 -0500 Subject: [PATCH 022/112] Update documentation/functions.md Revision 7 Co-authored-by: Hadia Ahmed --- documentation/functions.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/documentation/functions.md b/documentation/functions.md index 6e86fce2..a50807f1 100644 --- a/documentation/functions.md +++ b/documentation/functions.md @@ -173,7 +173,7 @@ Parts(begins_with_yellow = STARTSWITH(name, "yellow")) ### ENDSWITH -The `ENDSWITH` function returns whether its first argument ends with its second argument as a string suffix: +The `ENDSWITH` function checks if its first argument ends with its second argument as a string suffix: ```py Parts(ends_with_chocolate = ENDSWITH(name, "chocolate")) From 9945752b49f1639a1146b23d58e87b01e08318b6 Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Tue, 14 Jan 2025 10:43:41 -0500 Subject: [PATCH 023/112] 
Update documentation/functions.md Revision 8 Co-authored-by: Hadia Ahmed --- documentation/functions.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/documentation/functions.md b/documentation/functions.md index a50807f1..704837a4 100644 --- a/documentation/functions.md +++ b/documentation/functions.md @@ -182,7 +182,7 @@ Parts(ends_with_chocolate = ENDSWITH(name, "chocolate")) ### CONTAINS -The `CONTAINS` function returns whether its first argument contains with its second argument as a substring: +The `CONTAINS` function checks if its first argument contains its second argument as a substring: ```py Parts(is_green = CONTAINS(name, "green")) From 4c6aa798fbc002f51f83f38a32f1a7422066d40d Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Tue, 14 Jan 2025 10:43:56 -0500 Subject: [PATCH 024/112] Update documentation/functions.md Revision 9 Co-authored-by: Hadia Ahmed --- documentation/functions.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/documentation/functions.md b/documentation/functions.md index 704837a4..75937a11 100644 --- a/documentation/functions.md +++ b/documentation/functions.md @@ -191,7 +191,7 @@ Parts(is_green = CONTAINS(name, "green")) ### LIKE -The `LIKE` function returns whether the first argument matches the SQL pattern text of the second argument, where `_` is a 1 character wildcard and `%` is an 0+ character wildcard. +The `LIKE` function checks if the first argument matches the SQL pattern text of the second argument, where `_` is a 1 character wildcard and `%` is an 0+ character wildcard. 
```py Orders(is_special_request = LIKE(comment, "%special%requests%")) From e60dcc3087b4b6eb645585c7b050e2a784606a1d Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Tue, 14 Jan 2025 10:44:13 -0500 Subject: [PATCH 025/112] Update documentation/functions.md Revision 10 Co-authored-by: Hadia Ahmed --- documentation/functions.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/documentation/functions.md b/documentation/functions.md index 75937a11..500ce153 100644 --- a/documentation/functions.md +++ b/documentation/functions.md @@ -208,7 +208,7 @@ Below are some examples of how to interpret these patterns: ### JOIN_STRINGS -The `JOIN_STRINGS` function combines all of its string arguments by concatenating every argument after the first argument, using the first argument as a delimiter between each of the following arguments (like the `.join` method in Python): +The `JOIN_STRINGS` function concatenates all its string arguments, using the first argument as a delimiter between each of the following arguments (like the `.join` method in Python): ```py Regions.nations.customers( From a23c395befeb3fe6ac3da47a5f7474ef97a0127f Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Tue, 14 Jan 2025 10:45:07 -0500 Subject: [PATCH 026/112] Update documentation/functions.md Revision 11 Co-authored-by: Hadia Ahmed --- documentation/functions.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/documentation/functions.md b/documentation/functions.md index 500ce153..75282533 100644 --- a/documentation/functions.md +++ b/documentation/functions.md @@ -394,7 +394,7 @@ Customers(num_unique_parts_purchased = NDISTINCT(orders.lines.parts.key)) ### HAS -The `HAS` function is called on a sub-collection and returns True if at least one record of the sub-collection exists. In other words, `HAS(x)` is equivalent to `COUNT(x) > 0`. 
+The `HAS` function is called on a sub-collection and returns `True` if at least one record of the sub-collection exists. In other words, `HAS(x)` is equivalent to `COUNT(x) > 0`. ```py Parts.WHERE(HAS(supply_records.supplier.WHERE(nation.name == "GERMANY"))) From 2515683bb87ef76d83f6b8695896b40ecfefa801 Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Tue, 14 Jan 2025 10:45:31 -0500 Subject: [PATCH 027/112] Update documentation/functions.md Revision 12 Co-authored-by: Hadia Ahmed --- documentation/functions.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/documentation/functions.md b/documentation/functions.md index 75282533..fa489017 100644 --- a/documentation/functions.md +++ b/documentation/functions.md @@ -403,7 +403,7 @@ Parts.WHERE(HAS(supply_records.supplier.WHERE(nation.name == "GERMANY"))) ### HASNOT -The `HASNOT` function is called on a sub-collection and returns True if no records of the sub-collection exist. In other words, `HASNOT(x)` is equivalent to `COUNT(x) == 0`. +The `HASNOT` function is called on a sub-collection and returns `True` if no records of the sub-collection exist. In other words, `HASNOT(x)` is equivalent to `COUNT(x) == 0`. 
```py Customers.WHERE(HASNOT(orders)) From d34c64c8ed0bbea6c1eea1cb0769ca16ea424b9a Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Tue, 14 Jan 2025 10:45:44 -0500 Subject: [PATCH 028/112] Update documentation/functions.md Revision 13 Co-authored-by: Hadia Ahmed --- documentation/functions.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/documentation/functions.md b/documentation/functions.md index fa489017..373eda42 100644 --- a/documentation/functions.md +++ b/documentation/functions.md @@ -442,7 +442,7 @@ The `RANKING` function returns ordinal position of the current record when all r - `by`: 1+ collation values, either as a single expression or an iterable of expressions, used to order the records of the current context. - `levels`: same `levels` argument as all other window functions. - `allow_ties`: optional argument (default False) specifying to allow values that are tied according to the `by` expressions to have the same rank value. If False, tied values have different rank values where ties are broken arbitrarily. -- `dense`: optional argument (default False) specifying that if `allow_ties` is True and a tie is found, should the next value after hte ties be the current ranking value plus 1, as opposed to jumping to a higher value based on the number of ties that were there. For example, with the values `[a, a, b, b, b, c]`, the values with `dense=True` would be `[1, 1, 2, 2, 2, 3]`, but with `dense=False` they would be `[1, 1, 3, 3, 3, 6]`. +- `dense`: optional argument (default False) specifying that if `allow_ties` is True and a tie is found, should the next value after the ties be the current ranking value plus 1, as opposed to jumping to a higher value based on the number of ties that were there. For example, with the values `[a, a, b, b, b, c]`, the values with `dense=True` would be `[1, 1, 2, 2, 2, 3]`, but with `dense=False` they would be `[1, 1, 3, 3, 3, 6]`. 
```py # Rank customers per-nation by their account balance From cd59dc25bb7d9188422bf6de756d855b3ef8bf50 Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Tue, 14 Jan 2025 10:46:02 -0500 Subject: [PATCH 029/112] Update documentation/functions.md Revision 14 Co-authored-by: Hadia Ahmed --- documentation/functions.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/documentation/functions.md b/documentation/functions.md index 373eda42..49d6d973 100644 --- a/documentation/functions.md +++ b/documentation/functions.md @@ -279,7 +279,7 @@ Lineitems(adj_tax = DEFAULT_TO(tax, 0)) ### PRESENT -The `PRESENT` function returns whether its argument is non-null (e.g. the same as `IS NOT NULL` in SQL): +The `PRESENT` function checks if its argument is non-null (e.g. the same as `IS NOT NULL` in SQL): ```py Lineitems(has_tax = PRESENT(tax)) From 92778e3e92b118d0733440f33142a6df610e7fd3 Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Tue, 14 Jan 2025 10:46:17 -0500 Subject: [PATCH 030/112] Update documentation/functions.md Revision 15 Co-authored-by: Hadia Ahmed --- documentation/functions.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/documentation/functions.md b/documentation/functions.md index 49d6d973..327af014 100644 --- a/documentation/functions.md +++ b/documentation/functions.md @@ -288,7 +288,7 @@ Lineitems(has_tax = PRESENT(tax)) ### ABSENT -The `ABSENT` function returns whether its argument is non-null (e.g. the same as `IS NULL` in SQL): +The `ABSENT` function checks if its argument is null (e.g. 
the same as `IS NULL` in SQL): ```py Lineitems(no_tax = ABSENT(tax)) From 6003ed03f9e66ae24e9dc4468ea75107fde30ac9 Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Tue, 14 Jan 2025 10:46:32 -0500 Subject: [PATCH 031/112] Update documentation/functions.md Revision 16 Co-authored-by: Hadia Ahmed --- documentation/functions.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/documentation/functions.md b/documentation/functions.md index 327af014..57c56e10 100644 --- a/documentation/functions.md +++ b/documentation/functions.md @@ -306,7 +306,7 @@ TPCH(avg_non_debt_balance = AVG(Customers(no_debt_bal = KEEP_IF(acctbal, acctbal ### MONOTONIC -The `MONOTONIC` function returns whether all of its arguments are in ascending order (e.g. `MONOTONIC(a, b, c, d)` is equivalent to `(a <= b) & (b <= c) & (c <= d)`): +The `MONOTONIC` function checks if all of its arguments are in ascending order (e.g. `MONOTONIC(a, b, c, d)` is equivalent to `(a <= b) & (b <= c) & (c <= d)`): ```py Lineitems.WHERE(MONOTONIC(10, quantity, 20) & MONOTONIC(5, part.size, 13)) From fb86e3a7f3fd91391c16595eb54e0a08deb2e4d4 Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Tue, 14 Jan 2025 10:46:53 -0500 Subject: [PATCH 032/112] Update documentation/functions.md Revision 17 Co-authored-by: Hadia Ahmed --- documentation/functions.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/documentation/functions.md b/documentation/functions.md index 57c56e10..98015bd0 100644 --- a/documentation/functions.md +++ b/documentation/functions.md @@ -364,7 +364,7 @@ Suppliers(cheapest_part_supplied = MIN(supply_records.supply_cost)) The `MAX` function returns the largest value from the set of numerical values it is called on. 
```py -Suppliers(most_expensive_part_supplied = MIN(supply_records.supply_cost)) +Suppliers(most_expensive_part_supplied = MAX(supply_records.supply_cost)) ``` From ca6b653a1222d710f4220cf6354c73edab278aa2 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Tue, 14 Jan 2025 10:52:02 -0500 Subject: [PATCH 033/112] Updating arithmetic documentaiton and LIKE link --- documentation/functions.md | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/documentation/functions.md b/documentation/functions.md index 98015bd0..8d54c8dc 100644 --- a/documentation/functions.md +++ b/documentation/functions.md @@ -56,11 +56,10 @@ Below is each binary operator currently supported in PyDough. ### Arithmetic -Numerical expression values can be: -- Added together with the `+` operator -- Subtracted from one another with the `-` operator -- Multiplied by one another with the `*` operator -- divided by one another with the `/` operator (note: the behavior when the denominator is `0` depends on the database being used to evaluate the expression) +Supported mathematical operations: addition (`+`), subtraction (`-`), multiplication (`*`), division (`/`). + +> [!NOTE] +> The behavior when the denominator is `0` depends on the database being used to evaluate the expression. ```py Lineitems(value = (extended_price * (1 - discount) + 1.0) / part.retail_price) @@ -197,13 +196,7 @@ The `LIKE` function checks if the first argument matches the SQL pattern text of Orders(is_special_request = LIKE(comment, "%special%requests%")) ``` -Below are some examples of how to interpret these patterns: -- `"a_c"` returns True for any 3-letter string where the first character is `"a"` and the third is `"c"`. -- `"_q__"` returns True for any 4-letter string where the second character is `"q"`. -- `"%_s"` returns True for any 2+-letter string where the last character is `"s"`. -- `"a%z"` returns True for any string that starts with `"a"` and ends with `"z"`. 
-- `"%a%z%"` returns True for any string that contains an `"a"`, and also contains a `"z"` at some later point in the string. -- `"_e%"` returns True for any string where the second character is `"e"`. +[This link](https://www.w3schools.com/sql/sql_like.asp) explains how these SQL pattern strings work and provides some examples. ### JOIN_STRINGS From 22dfb322703f45a5d7a8e6b30b5c6a6777962214 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Tue, 14 Jan 2025 10:52:40 -0500 Subject: [PATCH 034/112] Updating numerical operator warning --- documentation/functions.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/documentation/functions.md b/documentation/functions.md index 8d54c8dc..31386553 100644 --- a/documentation/functions.md +++ b/documentation/functions.md @@ -58,13 +58,13 @@ Below is each binary operator currently supported in PyDough. Supported mathematical operations: addition (`+`), subtraction (`-`), multiplication (`*`), division (`/`). -> [!NOTE] -> The behavior when the denominator is `0` depends on the database being used to evaluate the expression. - ```py Lineitems(value = (extended_price * (1 - discount) + 1.0) / part.retail_price) ``` +> [!WARNING] +> The behavior when the denominator is `0` depends on the database being used to evaluate the expression. 
+ ### Comparisons From e550ae45c31c2904d1daa051a063450983db5f49 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Tue, 14 Jan 2025 11:11:26 -0500 Subject: [PATCH 035/112] Added function list checking test and 3 missing functions --- documentation/functions.md | 30 ++++++++++++++++++++++++++++++ tests/test_documentation.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+) create mode 100644 tests/test_documentation.py diff --git a/documentation/functions.md b/documentation/functions.md index 31386553..42005cef 100644 --- a/documentation/functions.md +++ b/documentation/functions.md @@ -15,6 +15,7 @@ Below is the list of every function/operator currently supported in PyDough as a - [String Functions](#string-functions) * [LOWER](#lower) * [UPPER](#upper) + * [LENGTH](#length) * [STARTSWITH](#startswith) * [ENDSWITH](#endswith) * [CONTAINS](#contains) @@ -26,6 +27,7 @@ Below is the list of every function/operator currently supported in PyDough as a * [DAY](#day) - [Conditional Functions](#conditional-functions) * [IFF](#iff) + * [ISIN](#isin) * [DEFAULT_TO](#default_to) * [PRESENT](#present) * [ABSENT](#absent) @@ -36,6 +38,7 @@ Below is the list of every function/operator currently supported in PyDough as a * [ROUND](#round) - [Aggregation Functions](#aggregation-functions) * [SUM](#sum) + * [AVG](#avg) * [MIN](#min) * [MAX](#max) * [COUNT](#count) @@ -160,6 +163,15 @@ Calling `UPPER` on a string converts its characters to uppercase: Customers(uppercase_name = UPPER(name)) ``` + +### LENGTH + +Calling `length` on a string returns the number of characters it contains: + +```py +Suppliers(n_chars_in_comment = LENGTH(comment)) +``` + ### STARTSWITH @@ -260,6 +272,15 @@ Customers( ) ``` + +### ISIN + +The `ISIN` function takes in an expression and an iterable of literals and returns whether the expression is a member of provided literals. 
+ +```py +Parts.WHERE(ISIN(size, (10, 11, 17, 19, 45))) +``` + ### DEFAULT_TO @@ -342,6 +363,15 @@ The `SUM` function returns the sum of the plural set of numerical values it is c Nations(total_consumer_wealth = SUM(customers.acctbal)) ``` + +### AVG + +The `AVG` function takes the average of the plural set of numerical values it is called on. + +```py +Parts(average_shipment_size = AVG(lines.quantity)) +``` + ### MIN diff --git a/tests/test_documentation.py b/tests/test_documentation.py new file mode 100644 index 00000000..2ec2a64d --- /dev/null +++ b/tests/test_documentation.py @@ -0,0 +1,36 @@ +""" +Verifies that various documentation files are up to date. +""" + +import pydough.pydough_operators as pydop + + +def test_function_list(): + """ + Tests that every function in the operator registry is also part of the + `functions.md` file, unless it is a special name that should not be + mentioned or it is a binary operator. + + Note: this test should only be run from the root directory of the project. 
+ """ + # Identify every function that should be documented + special_names = {"NOT", "SLICE"} + function_names: set[str] = set() + for function_name, operator in pydop.builtin_registered_operators().items(): + if not ( + isinstance(operator, pydop.BinaryOperator) or function_name in special_names + ): + function_names.add(function_name) + # Identify every section header in the function documentation + headers: set[str] = set() + with open("documentation/functions.md") as f: + for line in f.readlines(): + if line.startswith("#"): + headers.add(line.strip("#").strip()) + # Remove any function name that is in the headers, and fail if there are + # any that remain + function_names.difference_update(headers) + if function_names: + raise Exception( + "The following functions are not documented: " + ", ".join(function_names) + ) From f7af1a9346f5b7d29f51fd97ca098e33ba0203c4 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Tue, 14 Jan 2025 11:19:21 -0500 Subject: [PATCH 036/112] Updated some explanations --- documentation/functions.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/documentation/functions.md b/documentation/functions.md index 42005cef..478a48cb 100644 --- a/documentation/functions.md +++ b/documentation/functions.md @@ -352,7 +352,9 @@ Parts(rounded_price = ROUND(retail_price, 1)) ## Aggregation Functions -Normally, functions in PyDough maintain the cardinality of their inputs. Aggregation functions instead take in an argument that can be plural and aggregates it into a singular value with regards to the current context. Below is each function currently supported in PyDough that can aggregate plural values into a singular value. +When terms of a plural sub-collection are accessed, those terms are plural with regards to the current collection. 
For example, if each nation in `Nations` has multiple `customers`, and each customer has a single `acctbal`, then `customers.acctbal` is plural with regards to `Nations` and cannot be used in any calculations when the current context is `Nations`. The exception to this is when `customers.acctbal` is made singular with regards to `Nations` by aggregating it.
+
+Aggregation functions are a special set of functions that, when called on their inputs, convert them from plural to singular. Below is each aggregation function currently supported in PyDough.
 
 ### SUM
 
@@ -435,9 +437,9 @@ Customers.WHERE(HASNOT(orders))
 
 ## Window Functions
 
-Window functions are special functions that return a value for each record in the current context that depends on other records in the same context. A common example of this is ordering all values within the current context to return a value that depends on the current record's ordinal position relative to all the other records in the context.
+Window functions are special functions whose output depends on other records in the same context. A common example of this is finding the ranking of each record if all of the records were to be sorted.
 
-Window functions in PyDough have an optional `levels` argument. If this argument is not provided, it means that the window function applies to all records of the current collection without any boundaries between records. If it is provided, it should be a value that can be used as an argument to `BACK`, and in that case it means that the window function should be used on records of the current collection grouped by that particular ancestor.
+Window functions in PyDough have an optional `levels` argument. If this argument is omitted, it means that the window function applies to all records of the current collection (e.g. rank all customers). 
If it is provided, it should be a value that can be used as an argument to `BACK`, and in that case it means that the set of values used by the window function should be per-record of the corresponding ancestor (e.g. rank all customers within each nation).
 
 For example, if using the `RANKING` window function, consider the following examples:
 

From 6a66c5f1c150560ac8359cf86e931af67c30f8b4 Mon Sep 17 00:00:00 2001
From: knassre-bodo
Date: Tue, 14 Jan 2025 12:09:56 -0500
Subject: [PATCH 037/112] Started PARTITION docs, still need to add a few more bad examples

---
 documentation/dsl.md | 203 ++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 202 insertions(+), 1 deletion(-)

diff --git a/documentation/dsl.md b/documentation/dsl.md
index 76675e57..e5065970 100644
--- a/documentation/dsl.md
+++ b/documentation/dsl.md
@@ -766,7 +766,208 @@ People.TOP_K(1000, by=birth_date)
 
 ### PARTITION
 
-TODO
+The `PARTITION` operation is used to create a new collection by partitioning the records of another collection based on 1+ partitioning terms. Every unique combination of values of those partitioning terms corresponds to a single record in the new collection. The terms of the new collection are the partitioning terms, and a single sub-collection mapping back to the bucketed terms of the original data.
+
+The syntax for this is `PARTITION(data, name="...", by=...)`. The `data` argument is the PyDough collection that is to be partitioned. The `name` argument is a string indicating the name that is to be used when accessing the partitioned data, and the `by` argument is either a single partitioning key, or an iterable of 1+ partitioning keys.
+
+> [!WARNING]
+> PyDough currently only supports using references to scalar expressions from the `data` collection itself as partition keys, not an ancestor term, or a term from a child collection, or the result of a function call.
+
+If the partitioned data is accessed, its original ancestry is lost. 
Instead, it inherits the ancestry of the `PARTITION` clause, so `BACK(1)` is the `PARTITION` clause, and `BACK(2)` is the ancestor of the partition clause (the default is the entire graph, just like collections). + +The ancestry of the `PARTITION` clause can be changed by prepending it with another collection, separated by a dot. However, this is currently only supported in PyDough when the collection before the dot is just an augmented version of the graph context, as opposed to another collection (e.g. `GRAPH(x=42).PARTITION(...)` is supported, but `People.PARTITION(...)` is not). + +**Good Example #1**: Finds every unique state. + +```py +%%pydough +PARTITION(Addresses, name="addrs", by=state)(state) +``` + +**Good Example #2**: For every state, counts how many addresses are in that state. + +```py +%%pydough +PARTITION(Addresses, name="addrs", by=state)( + state, + n_addr=COUNT(addrs) +) +``` + +**Good Example #3**: For every city/state, counts how many people live in that city/state. + +```py +%%pydough +PARTITION(Addresses, name="addrs", by=(city, state))( + state, + city, + n_people=COUNT(addrs.current_occupants) +) +``` + +**Good Example #4**: Finds the top 5 years with the most people born in that year who have yahoo email accounts, listing the year and the number of people. + +```py +%%pydough +yahoo_people = People( + birth_year=YEAR(birth_date) +).WHERE(ENDSWITH(email, "@yahoo.com")) +PARTITION(yahoo_people, name="yah_ppl", by=birth_year)( + birth_year, + n_people=COUNT(yah_ppl) +).TOP_K(5, by=n_people.DESC()) +``` + +**Good Example #4**: For every year/month, finds all packages that were below the average cost of all packages ordered in that year/month. 
+
+```py
+%%pydough
+package_info = Packages(order_year=YEAR(order_date), order_month=MONTH(order_date))
+PARTITION(package_info, name="packs", by=(order_year, order_month))(
+    avg_package_cost=AVG(packs.package_cost)
+).packs.WHERE(
+    package_cost < BACK(1).avg_package_cost
+)
+```
+
+**Good Example #5**: For every customer, finds the percentage of all orders made by current occupants of that city/state made by that specific customer. Includes the first/last name of the person, the city/state they live in, and the percentage.
+
+```py
+%%pydough
+PARTITION(Addresses, name="addrs", by=(city, state))(
+    total_packages=COUNT(addrs.current_occupants.packages)
+).addrs.current_occupants(
+    first_name,
+    last_name,
+    city=BACK(1).city,
+    state=BACK(1).state,
+    pct_of_packages=100.0 * COUNT(packages) / BACK(2).total_packages,
+)
+```
+
+**Good Example #6**: Identifies which states' current occupants account for at least 1% of all packages purchased. Lists the state and the percentage.
+
+```py
+%%pydough
+GRAPH(
+    total_packages=COUNT(Packages)
+).PARTITION(Addresses, name="addrs", by=state)(
+    state,
+    pct_of_packages=100.0 * COUNT(addrs.current_occupants.packages) / BACK(1).total_packages
+).WHERE(pct_of_packages >= 1.0)
+```
+
+**Good Example #7**: Identifies which months of the year have numbers of packages shipped in that month that are above the average for all months.
+
+```py
+%%pydough
+pack_info = Packages(order_month=MONTH(order_date))
+month_info = PARTITION(pack_info, name="packs", by=order_month)(
+    n_packages=COUNT(packs)
+)
+GRAPH(
+    avg_packages_per_month=AVG(month_info.n_packages)
+).PARTITION(pack_info, name="packs", by=order_month)(
+    order_month,
+).WHERE(COUNT(packs) > BACK(1).avg_packages_per_month)
+```
+
+**Good Example #8**: Finds the 10 most frequent combinations of the state that the person lives in and the first letter of that person's name. 
+
+```py
+%%pydough
+people_info = Addresses.current_occupants(
+    state=BACK(1).state,
+    first_letter=first_name[:1],
+)
+PARTITION(people_info, name="ppl", by=(state, first_letter))(
+    state,
+    first_letter,
+    n_people=COUNT(ppl),
+).TOP_K(10, by=n_people.DESC())
+```
+
+**Good Example #9**: Same as good example #8, but written differently so it will include people without a current address (their state is listed as `"N/A"`).
+
+```py
+%%pydough
+people_info = People(
+    state=DEFAULT_TO(current_address.state, "N/A"),
+    first_letter=first_name[:1],
+)
+PARTITION(people_info, name="ppl", by=(state, first_letter))(
+    state,
+    first_letter,
+    n_people=COUNT(ppl),
+).TOP_K(10, by=n_people.DESC())
+```
+
+**Bad Example #1**: Partitions a collection `Products` that does not exist in the graph.
+
+```py
+%%pydough
+PARTITION(Products, name="p", by=product_type)
+```
+
+**Bad Example #2**: Does not provide a valid `name` when partitioning `Addresses` by the state.
+
+```py
+%%pydough
+PARTITION(Addresses, by=state)
+```
+
+**Bad Example #3**: Does not provide a `by` argument to partition `People`.
+
+```py
+%%pydough
+PARTITION(People, name="ppl")
+```
+
+**Bad Example #4**: Counts how many packages were ordered in each year. Invalid because `YEAR(order_date)` is not allowed to be used as a partition term (it must be placed in a CALC so it is accessible as a named reference).
+
+```py
+%%pydough
+PARTITION(Packages, name="packs", by=YEAR(order_date))(
+    n_packages=COUNT(packages)
+)
+```
+
+**Bad Example #5**: Counts how many people live in each state. Invalid because `current_address.state` is not allowed to be used as a partition term (it must be placed in a CALC so it is accessible as a named reference). 
+ +```py +%%pydough +PARTITION(People, name="ppl", by=current_address.state)( + n_packages=COUNT(packages) +) +``` + +**Bad Example #6**: Invalid version of good example #8 that did not use a CALC to get rid of the `BACK` or `first_name[:1]`, which cannot be used as partition terms. + +```py +%%pydough +PARTITION(Addresses.current_occupants, name="ppl", by=(BACK(1).state, first_name[:1]))( + BACK(1).state, + first_name[:1], + n_people=COUNT(ppl), +).TOP_K(10, by=n_people.DESC()) +``` + +**Bad Example #7**: Partitions people by their birth year to find the number of people born in each year. Invalid because the `email` property is referenced, which is not one of the properties accessible by the partition. + +```py +%%pydough +PARTITION(People(birth_year=YEAR(birth_date)), name="ppl", by=birth_year)( + birth_year, + email, + n_people=COUNT(ppl) +) +``` + + ### SINGULAR From 19cfd47d9544cd14f00957408ccd9c94b165c8cd Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Tue, 14 Jan 2025 14:53:14 -0500 Subject: [PATCH 038/112] Added expressions, more bad partition examples, and NEXT/PREV --- documentation/dsl.md | 287 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 276 insertions(+), 11 deletions(-) diff --git a/documentation/dsl.md b/documentation/dsl.md index e5065970..e1d7f180 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -10,6 +10,7 @@ This page describes the specification of the PyDough DSL. The specification incl * [CALC](#calc) * [Contextless Expressions](#contextless-expressions) * [BACK](#back) + * [Expressions](#expressions) - [Collection Operators](#collection-operators) * [WHERE](#where) * [ORDER_BY](#order_by) @@ -474,6 +475,81 @@ Customers.packages( Addresses.current_occupants(a=BACK(1).phone) ``` + +### Expressions + +So far, many different kinds of expressions have been noted in the examples for CALC, BACK, and contextless expressions. 
The following are examples & explanations of the various types of valid expressions:
+
+```py
+# Referencing scalar properties of the current collection
+People(
+    first_name,
+    last_name
+)
+
+# Referencing scalar properties of a singular sub-collection
+People(
+    current_state=current_address.state,
+    current_city=current_address.city,
+)
+
+# Referencing scalar properties of an ancestor collection
+Addresses.current_occupants.packages(
+    customer_email=BACK(1).email,
+    customer_zip_code=BACK(2).zip_code,
+)
+
+# Invoking normal functions/operations on other singular data
+Customers(
+    lowered_name=LOWER(name),
+    normalized_birth_month=MONTH(birth_date) - 1,
+    lives_in_c_state=STARTSWITH(current_address.state, "C"),
+)
+
+# Supported Python literals:
+# - integers
+# - floats
+# - strings
+# - booleans
+# - None
+# - decimal.Decimal
+# - pandas.Timestamp
+# - datetime.date
+# - lists/tuples of literals
+from pandas import Timestamp
+from datetime import date
+from decimal import Decimal
+Customers(
+    a=0,
+    b=3.14,
+    c="hello world",
+    d=True,
+    e=None,
+    f=Decimal("2.718281828"),
+    g=Timestamp("now"),
+    h=date(2024, 1, 1),
+    i=[1, 2, 4, 8, 10],
+    j=("SMALL", "LARGE"),
+)
+
+# Invoking aggregation functions on plural data
+Customers(
+    n_packages=COUNT(packages),
+    home_has_had_packages_billed=HAS(current_address.billed_packages),
+    avg_package_cost=AVG(packages.package_cost),
+    n_states_shipped_to=NDISTINCT(packages.shipping_address.state),
+    most_recent_package_ordered=MAX(packages.order_date),
+)
+
+# Invoking window functions on singular data
+Customers(
+    cust_ranking=RANKING(by=COUNT(packages).DESC()),
+    cust_percentile=PERCENTILE(by=COUNT(packages).DESC()),
+)
+```
+
+See [the list of PyDough functions](functions.md) to see all of the builtin functions & operators that can be called in PyDough. 
+ ## Collection Operators @@ -615,6 +691,13 @@ Addresses.WHERE( ) ``` +**Good Example #6**: Finds all people who are in the top 1% of customers according to number of packages ordered. + +```py +%%pydough +People.WHERE(PERCENTILE(by=COUNT(packages).ASC()) == 100) +``` + **Bad Example #1**: Sorts each person by their account balance in descending order. This is invalid because the `People` collection does not have an `account_balance` property. ```py @@ -845,7 +928,7 @@ PARTITION(Addresses, name="addrs", by=(city, state))( ) ``` -**Good Example #6**: Identifies which states' current occupants account for at least 1% of all packages purchased. Lists the state and the percentage. +**Good Example #6**: Identifies the states whose current occupants account for at least 1% of all packages purchased. Lists the state and the percentage. ```py %%pydough @@ -963,44 +1046,226 @@ PARTITION(People(birth_year=YEAR(birth_date)), name="ppl", by=birth_year)( ) ``` - +**Bad Example #7**: For each person & year, counts how many times that person ordered a packaged in that year. This is invalid because doing `.PARTITION` after `People` is unsupported, since `People` is not a graph-level collection like `GRAPH(...)`. + +```py +%%pydough +People.PARTITION(packages(year=YEAR(order_date)), name="p", by=year)( + ssn=BACK(1).ssn, + year=year, + n_packs=COUNT(p) +) +``` + +**Bad Example #8**: Partitions each address' current occupants by their birth year to get the number of people per birth year. This is invalid because the example includes a field `BACK(2).bar` which does not exist because the first ancestor of the partition is `GRAPH`, which does not have a second ancestor. 
+
+```py
+%%pydough
+people_info = Addresses.current_occupants(birth_year=YEAR(birth_date))
+GRAPH.PARTITION(people_info, name="p", by=birth_year)(
+    birth_year,
+    n_people=COUNT(p),
+    foo=BACK(2).bar,
+)
+```
+
+**Bad Example #9**: Partitions each address' current occupants by their birth year and filters to only include people born in years where at least 10000 people were born, then gets more information of people from those years. This is invalid because after accessing `.ppl`, the term `BACK(1).state` is used. This is not valid because even though the data that `.ppl` refers to (`people_info`) has access to `BACK(1).state`, that ancestry information was lost after partitioning `people_info`. Instead, `BACK(1)` refers to the `PARTITION` clause, which does not have a `state` field.
+
+```py
+%%pydough
+people_info = Addresses.current_occupants(birth_year=YEAR(birth_date))
+GRAPH.PARTITION(people_info, name="ppl", by=birth_year).WHERE(
+    COUNT(ppl) >= 10000
+).ppl(
+    first_name,
+    last_name,
+    state=BACK(1).state,
+)
+```
 
 ### SINGULAR
 
-TODO
+> [!IMPORTANT]
+> This feature has not yet been implemented in PyDough
+
+Certain PyDough operations, such as specific filters, can cause plural data to become singular. In this case, PyDough will still ban the plural data from being treated as singular unless the `.SINGULAR()` modifier is used to tell PyDough that the data should be treated as singular. It is very important that this only be used if the user is certain that the data will be singular, since otherwise it can result in undefined behavior when the PyDough code is executed.
+
+**Good Example #1**: Accesses the package cost of the most recent package ordered by each person. This is valid because even though `.packages` is plural, the filter done on it will ensure that there is only one record for each record of `People`, so `.SINGULAR()` is valid. 
+ +```py +%%pydough +most_recent_package = packages.WHERE( + RANKING(by=order_date.DESC(), levels=1) == 1 +).SINGULAR() +People( + ssn, + first_name, + middle_name, + last_name, + most_recent_package_cost=most_recent_package.package_cost +) +``` + +**Good Example #2**: Accesses the email of the current occupant of each address that has the name `"John Smith"` (no middle name). This is valid if it is safe to assume that each address only has one current occupant named `"John Smith"` without a middle name. + +```py +%%pydough +js = current_occupants.WHERE( + (first_name == "John") & + (last_name == "Smith") & + ABSENT(middle_name) +).SINGULAR() +Addresses( + address_id, + john_smith_email=DEFAULT_TO(js.email, "NO JOHN SMITH LIVING HERE") +) +``` ### BEST +> [!IMPORTANT] +> This feature has not yet been implemented in PyDough + TODO ### NEXT / PREV +> [!IMPORTANT] +> This feature has not yet been implemented in PyDough + +In PyDough, it is also possible to access data from other records in the same collection that occur before or after the current record, when all the records are sorted. Similar to how `BACK(n)` can be used as a collection to access terms from an ancestor context, `PREV(n, by=...)` can be used to access terms from another record of the same context, specifically the record obtained by ordering by the `by` terms then looking for the record `n` entries before the current record. Similarly, `NEXT(n, ...)` is the same as `PREV(-n, ...)`. + +The arguments to `NEXT` and `PREV` are as follows: +- `n` (optional): how many records before/after the current record to look. The default value is 1. +- `by` (required): the collation terms used to sort the data. Must be either a single collation term, or an iterable of 1+ collation terms. +- `levels` (optional): same as window functions such as `RANKING` or `PERCENTILE`, documented in the [functions list](functions.md). 
+
+If the entry `n` records before/after the current entry does not exist, then accessing anything from it returns null. Anything that can be done to the current context can also be done to the `PREV`/`NEXT` call (e.g. aggregating data from a plural sub-collection).
+
+**Good Example #1**: For each package, finds whether it was ordered by the same customer as the most recently ordered package before it.
+
+```py
+%%pydough
+Packages(
+    package_id,
+    same_customer_as_prev_package=customer_ssn == PREV(by=order_date.ASC()).customer_ssn
+)
+```
+
+**Good Example #2**: Finds the average number of hours between every package ordered by every customer.
+
+```py
+%%pydough
+prev_package = PREV(by=order_date.ASC(), levels=1)
+package_deltas = packages(
+    hour_difference=DATEDIFF('hours', order_date, prev_package.order_date)
+)
+Customers(
+    ssn,
+    avg_hours_between_purchases=AVG(package_deltas.hour_difference)
+)
+```
+
+**Good Example #3**: Finds out for each customer whether, if they were sorted by number of packages ordered, they live in the same state as any of the 3 people below them on the list. 
+ +```py +%%pydough +first_after = NEXT(1, by=COUNT(packages).DESC()) +second_after = NEXT(2, by=COUNT(packages).DESC()) +third_after = NEXT(3, by=COUNT(packages).DESC()) +Customers( + ssn, + same_state_as_order_neighbors=( + DEFAULT_TO(current_address.state == first_after.current_address.state, False) | + DEFAULT_TO(current_address.state == second_after.current_address.state, False) | + DEFAULT_TO(current_address.state == third_after.current_address.state, False) + ) +) +``` + +**Bad Example #1**: TODO: add bad example where `by` is missing + +```py +%%pydough +TODO +``` + +**Bad Example #2**: TODO: add bad example where `by` is empty + +```py +%%pydough +TODO +``` + +**Bad Example #3**: TODO: add bad example where `by` is not a collation + +```py +%%pydough +TODO +``` + +**Bad Example #4**: TODO: add bad example where `n` is malformed + +```py +%%pydough +TODO +``` + +**Bad Example #5**: TODO: add bad example where `NEXT` is used as-is without accessing + +```py +%%pydough +TODO +``` + +**Bad Example #6**: TODO: add bad example where a term is accessed from `PREV` that does not + +```py +%%pydough TODO +``` + +**Bad Example #7**: TODO: add bad example where `PREV` is access with `.` syntax. + +```py +%%pydough +TODO +``` + +**Bad Example #8**: TODO: add bad example where `levels` goes too far up. + +```py +%%pydough +TODO +``` + +**Bad Example #9**: TODO: add bad example where an aggfunc is called directly on `NEXT`. + +```py +%%pydough +TODO +``` ## Induced Properties -This section of the PyDough spec has not yet been defined. +This section of the PyDough specification has not yet been defined. ### Induced Scalar Properties -This section of the PyDough spec has not yet been defined. +This section of the PyDough specification has not yet been defined. ### Induced Subcollection Properties -This section of the PyDough spec has not yet been defined. +This section of the PyDough specification has not yet been defined. 
### Induced Arbitrary Joins -This section of the PyDough spec has not yet been defined. +This section of the PyDough specification has not yet been defined. From acba2c6c8caf8b063b91ff588a24cb3ab0183939 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Tue, 14 Jan 2025 15:38:01 -0500 Subject: [PATCH 039/112] Started BEST documentation, still need to do bad next/prev/best examples --- documentation/dsl.md | 152 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 143 insertions(+), 9 deletions(-) diff --git a/documentation/dsl.md b/documentation/dsl.md index e1d7f180..00ab2a1b 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -17,8 +17,8 @@ This page describes the specification of the PyDough DSL. The specification incl * [TOP_K](#top_k) * [PARTITION](#partition) * [SINGULAR](#singular) - * [BEST](#best) * [NEXT / PREV](#next-prev) + * [BEST](#best) - [Induced Properties](#induced-properties) * [Induced Scalar Properties](#induced-scalar-properties) * [Induced Subcollection Properties](#induced-subcollection-properties) @@ -1122,14 +1122,6 @@ Addresses( ) ``` - -### BEST - -> [!IMPORTANT] -> This feature has not yet been implemented in PyDough - -TODO - ### NEXT / PREV @@ -1249,6 +1241,148 @@ TODO TODO ``` + +### BEST + +> [!IMPORTANT] +> This feature has not yet been implemented in PyDough + +PyDough supports identifying a specific record from a sub-collection that is optimal with regards to some metric, per-record of the current collection. This is done by using `BEST` instead of directly accessing the sub-collection. The first argument to `BEST` is the sub-collection to be accessed, and the second is a `by` argument used to find the optimal record of the sub-collection. The rules for the `by` argument are the same as `PREV`, `NEXT`, `TOP_K`, etc.: it must be either a single collation term, or an iterable of 1+ collation terms. 
+
+A call to `BEST` can either be done with `.` syntax, to step from a parent collection to a child collection, or can be a freestanding accessor used inside of a collection operator, just like `BACK`, `PREV` or `NEXT`. For example, both `Parent.BEST(child, by=...)` and `Parent(x=BEST(child, by=...).y)` are allowed.
+
+The original ancestry of the sub-collection is intact. So, if doing `A.BEST(b.c.d, by=...)`, `BACK(1)` refers to `c`, `BACK(2)` refers to `b` and `BACK(3)` refers to `A`.
+
+Additional keyword arguments can be supplied to `BEST` that change its behavior:
+- `allow_ties` (default=False): if True, changes the behavior to keep all records of the sub-collection that share the optimal values of the collation terms. If `allow_ties` is True, the `BEST` clause is no longer singular.
+- `n_best` (default=1): if an integer greater than 1, changes the behavior to keep the top `n_best` values of the sub-collection for each record of the parent collection (fewer if `n_best` records of the sub-collection do not exist). If `n_best` is greater than 1, the `BEST` clause is no longer singular. NOTE: `n_best` cannot be greater than 1 at the same time that `allow_ties` is True.
+
+**Good Example #1**: Finds the package id & zip code the package was shipped to for every package that was the first-ever purchase for the customer.
+
+```py
+%%pydough
+Customers.BEST(packages, by=order_date.ASC())(
+    package_id,
+    shipping_address.zip_code
+)
+```
+
+**Good Example #2**: For each customer, lists their ssn and the cost of the most recent package they have purchased.
+
+```py
+%%pydough
+Customers(
+    ssn,
+    most_recent_cost=BEST(packages, by=order_date.DESC()).package_cost
+)
+```
+
+**Good Example #3**: Finds the address in the state of New York with the most occupants, ties broken by address id. Note: the `GRAPH.` prefix is optional in this case, since it is implied if there is no prefix to the `BEST` call. 
+ +```py +%%pydough +addr_info = Addresses.WHERE( + state == "NY" +)(address_id, n_occupants=COUNT(current_occupants)) +GRAPH.BEST(addr_info, by=(n_occupants.DESC(), address_id.ASC())) +``` + +**Good Example #4**: For each customer, finds the number of people currently living in the address that they most recently shipped a package to. + +```py +%%pydough +most_recent_package = BEST(packages, by=order_date.DESC()) +Customers( + ssn, + n_occ_most_recent_addr=COUNT(most_recent_package.shipping_address.current_occupants) +) +``` + +**Good Example #5**: For each address that has occupants, lists out the first/last name of the person living in that address who has ordered the most packages, breaking ties in favor of the person with the smaller social security number. Also includes the city/state of the address, the number of people who live there, and the number of packages that person ordered. + +```py +%%pydough +Addresses.WHERE(HAS(current_occupants))( + n_occupants=COUNT(current_occupants) +).BEST( + current_occupants(n_orders=COUNT(packages)), + by=(n_orders.DESC(), ssn.ASC()) +)( + first_name, + last_name, + n_orders, + n_living_in_same_addr=BACK(1).n_occupants, + city=BACK(1).city, + state=BACK(1).state, +) +``` + +**Good Example #6**: For each person, finds the total value of the 5 most recent packages they ordered. + +```py +%%pydough +five_most_recent=BEST(packages, by=order_date.DESC(), n_best=5) +People( + ssn, + value_most_recent_5=SUM(five_most_recent.package_cost) +) +``` + +**Good Example #7**: TODO: example with multiple back levels. 
+ +```py +%%pydough + +``` + +**Bad Example #1**: TODO: bad sub-colleciton argument to `BEST` + +```py +%%pydough +``` + +**Bad Example #2**: TODO: `by` argument is missing + +```py +%%pydough +``` + +**Bad Example #3**: TODO: `by` argument is not a collation + +```py +%%pydough +``` + +**Bad Example #4**: TODO: `by` argument is empty + +```py +%%pydough +``` + +**Bad Example #5**: TODO: bad combination of `n_best` and `allow_ties` + +```py +%%pydough +``` + +**Bad Example #6**: TODO: treating as singular when `n_best` is greater than 1 + +```py +%%pydough +``` + +**Bad Example #7**: TODO: treating as singular when `allow_ties` is True + +```py +%%pydough +``` + +**Bad Example #8**: TODO: incorrect usage of `BACK` + +```py +%%pydough +``` + ## Induced Properties From 3f0b703c777386c1581bcb3f57ac67e49ac0faab Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Thu, 16 Jan 2025 12:40:34 -0500 Subject: [PATCH 040/112] [RUN CI] From 0f5df29a2bd8c149ff6777b741b44dd45b860640 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Fri, 17 Jan 2025 11:58:43 -0500 Subject: [PATCH 041/112] Added bad next/prev examples --- documentation/dsl.md | 46 +++++++++++++++++++++++++++++--------------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/documentation/dsl.md b/documentation/dsl.md index 00ab2a1b..f0e59098 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -1178,60 +1178,74 @@ Customers( ) ``` -**Bad Example #1**: TODO: add bad example where `by` is missing +**Bad Example #1**: Finds the number of hours between each package and the previous package. This is invalid because the `by` argument is missing ```py %%pydough -TODO +Packages( + hour_difference=DATEDIFF('hours', order_date, PREV().order_date) +) ``` -**Bad Example #2**: TODO: add bad example where `by` is empty +**Bad Example #2**: Finds the number of hours between each package and the next package. This is invalid because the `by` argument is empty. 
```py %%pydough -TODO +Packages( + hour_difference=DATEDIFF('hours', order_date, NEXT(by=()).order_date) +) ``` -**Bad Example #3**: TODO: add bad example where `by` is not a collation +**Bad Example #3**: Finds the number of hours between each package and the 5th-previous package. This is invalid because the `by` argument is not a collation. ```py %%pydough -TODO +Packages( + hour_difference=DATEDIFF('hours', order_date, PREV(5, by=order_date).order_date) +) ``` -**Bad Example #4**: TODO: add bad example where `n` is malformed +**Bad Example #4**: Finds the number of hours between each package and a subsequent package. This is invalid because the `n` argument is not an integer. ```py %%pydough -TODO +Packages( + hour_difference=DATEDIFF('hours', order_date, NEXT("ten", by=order_date.ASC()).order_date) +) ``` -**Bad Example #5**: TODO: add bad example where `NEXT` is used as-is without accessing +**Bad Example #5**: Invalid usage of `PREV` that is used as-is without accessing any of its fields. ```py %%pydough -TODO +Packages( + hour_difference=DATEDIFF('hours', order_date, PREV(1, by=order_date.ASC())) +) ``` -**Bad Example #6**: TODO: add bad example where a term is accessed from `PREV` that does not +**Bad Example #6**: Finds the number of hours between each package and the previous package. This invalid because a property `.odate` is accessed that does not exist in the collection, therefore it doesn't exist in `PREV` either. ```py %%pydough -TODO +Packages( + hour_difference=DATEDIFF('hours', order_date, PREV(1, by=order_date.ASC()).odate) +) ``` -**Bad Example #7**: TODO: add bad example where `PREV` is access with `.` syntax. +**Bad Example #7**: Invalid use of `PREV` that is invoked with `.` syntax, like a subcollection. ```py %%pydough -TODO +Packages.PREV(order_date.ASC()) ``` -**Bad Example #8**: TODO: add bad example where `levels` goes too far up. 
+**Bad Example #8**: Finds the number of hours between each package and the previous package ordered by the customer. This invalid because the `levels` value is too large, since only 2 ancestor levels exist in `Customers.packages` (the graph, and `Customers`): ```py %%pydough -TODO +Customers.packages( + hour_difference=DATEDIFF('hours', order_date, PREV(1, by=order_date.ASC(), levels=5).order_date) +) ``` **Bad Example #9**: TODO: add bad example where an aggfunc is called directly on `NEXT`. From 9752b9285acaa7ccd0299ebfeddf6454a05caa83 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Fri, 17 Jan 2025 11:59:01 -0500 Subject: [PATCH 042/112] Added bad next/prev examples --- documentation/dsl.md | 7 ------- 1 file changed, 7 deletions(-) diff --git a/documentation/dsl.md b/documentation/dsl.md index f0e59098..2457d989 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -1248,13 +1248,6 @@ Customers.packages( ) ``` -**Bad Example #9**: TODO: add bad example where an aggfunc is called directly on `NEXT`. 
- -```py -%%pydough -TODO -``` - ### BEST From 223c1b5076cad160a6d4985377d81c5c20a3d4e3 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Tue, 21 Jan 2025 14:53:39 -0500 Subject: [PATCH 043/112] Adding examples and fixing 911 bugs to AST/hybrid handling [RUN CI] --- demos/notebooks/what_if.ipynb | 2 +- documentation/dsl.md | 198 ++++++++++++++++++- pydough/conversion/hybrid_tree.py | 2 + pydough/unqualified/unqualified_transform.py | 25 ++- tests/simple_pydough_functions.py | 71 +++++++ tests/test_unqualified_node.py | 53 +++++ 6 files changed, 338 insertions(+), 13 deletions(-) diff --git a/demos/notebooks/what_if.ipynb b/demos/notebooks/what_if.ipynb index 3332a15e..c3b03b9a 100644 --- a/demos/notebooks/what_if.ipynb +++ b/demos/notebooks/what_if.ipynb @@ -241,7 +241,7 @@ "source": [ "%%pydough\n", "\n", - "order_total_price = orders(order_revenue=total_price)\n", + "order_total_price = orders(order_revenue=total_revenue)\n", "pydough.to_df(order_total_price)" ] }, diff --git a/documentation/dsl.md b/documentation/dsl.md index 2457d989..a097aee2 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -23,6 +23,12 @@ This page describes the specification of the PyDough DSL. The specification incl * [Induced Scalar Properties](#induced-scalar-properties) * [Induced Subcollection Properties](#induced-subcollection-properties) * [Induced Arbitrary Joins](#induced-arbitrary-joins) +- [Larger Examples](#larger-examples) + * [Example 1: Highest Residency Density States](#example-1) + * [Example 2: TODO](#example-2) + * [Example 3: TODO](#example-3) + * [Example 4: TODO](#example-4) + * [Example 5: TODO](#example-5) @@ -1328,66 +1334,94 @@ Addresses.WHERE(HAS(current_occupants))( ```py %%pydough -five_most_recent=BEST(packages, by=order_date.DESC(), n_best=5) +five_most_recent = BEST(packages, by=order_date.DESC(), n_best=5) People( ssn, value_most_recent_5=SUM(five_most_recent.package_cost) ) ``` -**Good Example #7**: TODO: example with multiple back levels. 
+**Good Example #7**: For each address, finds the package most recently ordered by one of the current occupants of that address, including the email of the occupant who ordered it and the address' id. Notice that `BACK(1)` refers to `current_occupants` and `BACK(2)` refers to `Addresses` as if the packages were accessed as `Addresses.current_occupants.packages` instead of using `BEST`. ```py %%pydough +most_recent_package = BEST(current_occupants.packages, by=order_date.DESC()) +Addresses.most_recent_package( + address_id=BACK(2).address_id, + cust_email=BACK(1).email, + package_id=package_id, + order_date=order_date, +) +``` +**Bad Example #1**: For each person finds their best email. This is invalid because `email` is not a sub-collection of `People` (it is a scalar attribute, so there is only 1 `email` per-person). + +```py +%%pydough +People(first_name, BEST(email, by=birth_date.DESC())) ``` -**Bad Example #1**: TODO: bad sub-colleciton argument to `BEST` +**Bad Example #2**: For each person finds their best package. This is invalid because the `by` argument is missing. ```py %%pydough +People.BEST(packages) ``` -**Bad Example #2**: TODO: `by` argument is missing +**Bad Example #3**: For each person finds their best package. This is invalid because the: `by` argument is not a collation ```py %%pydough +People.BEST(packages, by=order_date) ``` -**Bad Example #3**: TODO: `by` argument is not a collation +**Bad Example #4**: For each person finds their best package. This is invalid because the `by` argument is empty ```py %%pydough +People.BEST(packages, by=()) ``` -**Bad Example #4**: TODO: `by` argument is empty +**Bad Example #5**: For each person finds the 5 most recent packages they have ordered, allowing ties. This is invalid because `n_best` is greater than 1 at the same time that `allow_ties` is True. 
```py %%pydough +People.BEST(packages, by=order_date.DESC(), n_best=5, allow_ties=True) ``` -**Bad Example #5**: TODO: bad combination of `n_best` and `allow_ties` +**Bad Example #6**: For each person, finds the package cost of their 10 most recent packages. This is invalid because `n_best` is greater than 1, which means that the `BEST` clause is non-singular so its terms cannot be accessed in the calc without aggregating. ```py %%pydough +best_packages = BEST(packages, by=order_date.DESC(), n_best=10) +People(first_name, best_cost=best_packages.package_cost) ``` -**Bad Example #6**: TODO: treating as singular when `n_best` is greater than 1 +**Bad Example #7**: For each person, finds the package cost of their most expensive package(s), allowing ties. This is invalid because `allow_ties` is True, which means that the `BEST` clause is non-singular so its terms cannot be accessed in the calc without aggregating. ```py %%pydough +best_packages = BEST(packages, by=package_cost.DESC(), allow_ties=True) +People(first_name, best_cost=best_packages.package_cost) ``` -**Bad Example #7**: TODO: treating as singular when `allow_ties` is True +**Bad Example #8**: For each address, finds the package most recently ordered by one of the current occupants of that address, including the address id of the address. This is invalid because `BACK(1)` refers to `current_occupants`, which does not have a field called `address_id`. ```py %%pydough +most_recent_package = BEST(current_occupants.packages, by=order_date.DESC()) +Addresses.most_recent_package( + address_id=BACK(1).address_id, + package_id=package_id, + order_date=order_date, +) ``` -**Bad Example #8**: TODO: incorrect usage of `BACK` +**Bad Example #9**: For each address finds the oldest occupant. This is invalid because the `BEST` clause is placed in the calc without accessing any of its attributes. 
```py %%pydough +Addresses(address_id, oldest_occupant=BEST(current_occupants, by=birth_date.ASC())) ``` @@ -1410,3 +1444,147 @@ This section of the PyDough specification has not yet been defined. This section of the PyDough specification has not yet been defined. + +## Larger Examples + +The rest of the document are examples of questions asked about the data in the people/addresses/packages graph and the corresponding PyDough code, which uses several of the features described in this document. + + +### Example 1: Highest Residency Density States + +**Question**: Find the 5 states with the highest average number of occupants per address. + +**Answer**: +```py +%%pydough +addr_info = Addresses(n_occupants=COUNT(current_occupants)) +result = PARTITION( + addr_info, + name="addrs", + by=state +)( + state, + average_occupants=AVG(addrs.n_occupants) +).TOP_K( + 5, + by=average_occupants.DESC() +) +``` + + +### Example 2: Yearly Trans-Coastal Shipments + +**Question**: For every calendar year, what percentage of all packages are from a customer living in the west coast to an address on the east coast? Only include packages that have already arrived, and order by the year. 
+ +**Answer**: +```py +%%pydough +west_coast_states = ("CA", "OR", "WA", "AK") +east_coast_states = ("FL", "GA", "SC", "NC", "VA", "MD", "DE", "NJ", "NY", "CT", "RI", "MA", "NH", "MA") +from_west_coast = ISIN(customer.current_address.state, west_coast_states) +to_east_coast = ISIN(shipping_address.state, east_coast_states) +package_info = Packages.WHERE( + PRESENT(arrival_date) +)( + is_trans_coastal=from_west_coast & to_east_coast, + year=YEAR(order_date), +) +result = PARTITION( + package_info, + name="packs", + by=year, +)( + year, + pct_trans_coastal=100.0 * SUM(packs.is_trans_coastal) / COUNT(packs), +).ORDER_BY( + year.ASC() +) +``` + + +### Example 3: Email of Oldest Non-Customer Resident + +**Question**: For every city/state, find the email of the oldest resident of that city/state who has never ordered a package (break ties in favor of the lower social security number). Also include the zip code of that occupant. Order alphabetically by state, followed by city. + +**Answer**: +```py +%%pydough +cities = PARTITION( + Addresses, + name="addrs", + by=(city, state) +).BEST( + addrs.current_occupants.WHERE(HASNOT(packages)), + by=(birth_date.ASC(), ssn.ASC()), +)( + state=BACK(2).state, + city=BACK(2).city, + email=email, + zip_code=BACK(1).zip_code, +).ORDER_BY( + state.ASC(), + city.ASC(), +) +``` + + +### Example 4: Outlier Packages Per Month Of 2017 + +**Question**: For every month of the year 2017, identify the percentage of packages ordered in that month that are at least 10x the average value of all packages ordered in 2017. Order the results by month. 
+ +**Answer**: +```py +%%pydough +is_2017 = YEAR(order_date) == 2017 +package_info = GRAPH( + avg_package_cost=AVG(Packages.WHERE(is_2017).package_cost) +).Packages.WHERE( + is_2017 +)( + month=MONTH(order_date), + is_10x_avg=package_cost >= (10.0 * BACK(1).avg_package_cost) +) +result = PARTITION( + package_info, + name="packs", + by=month +)( + month, + pct_outliers=100.0 * SUM(packs.is_10x_avg) / COUNT(packs) +).ORDER_BY( + month.ASC() +) +``` + + +### Example 5: Regression Prediction Of Packages Quantity + +**Question**: Using linear regression of the number of packages ordered per-year, what is the predicted number of packages for the next three years? + +**Answer**: +```py +%%pydough +yearly_data = PARTITION( + Packages(year=YEAR(order_date)), + name="packs", + by=year, +)( + year, + n_orders = COUNT(packs), +) +global_info = GRAPH( + avg_x = AVG(yearly_data.year), + avg_y = AVG(yearly_data.n_orders), +) +dx = n_orders - BACK(1).avg_x +dy = year - BACK(1).avg_y +regression_data = yearly_data(value=(dx * dy) / (dx * dx)) +slope = SUM(regression_data.value) +# Could also write as `last_year = packs.WHERE(RANKING(by=year.DESC()) == 1).SINGULAR()` +last_year = BEST(packs, by=year.DESC()) +results = {} +for n in range(1, 4): + results[f"year_{n}"] = last_year.year+n + results[f"year_{n}_prediction"] = last_year.n_orders+n*slope +result = global_info(**results) +``` diff --git a/pydough/conversion/hybrid_tree.py b/pydough/conversion/hybrid_tree.py index 15f3bb18..428834dd 100644 --- a/pydough/conversion/hybrid_tree.py +++ b/pydough/conversion/hybrid_tree.py @@ -874,6 +874,8 @@ def __init__( self._is_connection_root: bool = is_connection_root self._agg_keys: list[HybridExpr] | None = None self._join_keys: list[tuple[HybridExpr, HybridExpr]] | None = None + if isinstance(root_operation, HybridPartition): + self._join_keys = [] def __repr__(self): lines = [] diff --git a/pydough/unqualified/unqualified_transform.py b/pydough/unqualified/unqualified_transform.py index 
3d65c109..7be19ded 100644 --- a/pydough/unqualified/unqualified_transform.py +++ b/pydough/unqualified/unqualified_transform.py @@ -35,8 +35,8 @@ def visit_Module(self, node): def visit_Assign(self, node): for target in node.targets: - assert isinstance(target, ast.Name) - self._known_names.add(target.id) + if isinstance(target, ast.Name): + self._known_names.add(target.id) return self.generic_visit(node) def create_root_def(self) -> list[ast.AST]: @@ -84,6 +84,27 @@ def visit_FunctionDef(self, node): answer: ast.AST = self.generic_visit(result) return answer + def visit_expression(self, node) -> ast.expr: + result = self.generic_visit(node) + assert isinstance(result, ast.expr) + return result + + def visit_statement(self, node) -> ast.stmt: + result = self.generic_visit(node) + assert isinstance(result, ast.stmt) + return result + + def visit_For(self, node): + if isinstance(node.target, ast.Name): + self._known_names.add(node.target.id) + return ast.For( # type: ignore + target=node.target, + iter=self.visit_expression(node.iter), + body=[self.visit_statement(elem) for elem in node.body], + orelse=[self.visit_statement(elem) for elem in node.orelse], + type_comment=node.type_comment, + ) + def visit_Name(self, node): unrecognized_var: bool = False if node.id not in self._known_names: diff --git a/tests/simple_pydough_functions.py b/tests/simple_pydough_functions.py index 3146bfae..fcfa3186 100644 --- a/tests/simple_pydough_functions.py +++ b/tests/simple_pydough_functions.py @@ -121,3 +121,74 @@ def function_sampler(): .WHERE(MONOTONIC(0.0, acctbal, 100.0)) .TOP_K(10, by=address.ASC()) ) + + +def loop_generated_terms(): + terms = {"name": name} + for i in range(3): + terms[f"interval_{i}"] = COUNT( + customers.WHERE(MONOTONIC(i * 1000, acctbal, (i + 1) * 1000)) + ) + return Nations(**terms) + + +def function_defined_terms(): + def interval_n(n): + return COUNT(customers.WHERE(MONOTONIC(n * 1000, acctbal, (n + 1) * 1000))) + + return Nations( + name, + 
interval_7=interval_n(7), + interval_4=interval_n(4), + interval_13=interval_n(13), + ) + + +def dict_comp_terms(): + terms = {"name": name} + terms.update( + { + f"interval_{i}": COUNT( + customers.WHERE(MONOTONIC(i * 1000, acctbal, (i + 1) * 1000)) + ) + for i in range(3) + } + ) + return Nations(**terms) + + +def list_comp_terms(): + terms = [name] + terms.extend( + [ + COUNT(customers.WHERE(MONOTONIC(i * 1000, acctbal, (i + 1) * 1000))) + for i in range(3) + ] + ) + return Nations(**terms) + + +def set_comp_terms(): + terms = [name] + terms.extend( + set( + { + COUNT(customers.WHERE(MONOTONIC(i * 1000, acctbal, (i + 1) * 1000))) + for i in range(3) + } + ) + ) + return Nations(**terms) + + +def generator_comp_terms(): + terms = {"name": name} + for term, value in ( + ( + f"interval_{i}", + COUNT(customers.WHERE(MONOTONIC(i * 1000, acctbal, (i + 1) * 1000))), + ) + for i in range(3) + ): + terms[term] = value + return Nations(**terms) diff --git a/tests/test_unqualified_node.py b/tests/test_unqualified_node.py index 2cbf7a2c..129301c0 100644 --- a/tests/test_unqualified_node.py +++ b/tests/test_unqualified_node.py @@ -20,6 +20,14 @@ bad_window_6, bad_window_7, ) +from simple_pydough_functions import ( + dict_comp_terms, + function_defined_terms, + generator_comp_terms, + list_comp_terms, + loop_generated_terms, + set_comp_terms, +) from test_utils import graph_fetcher from tpch_test_functions import ( impl_tpch_q1, @@ -403,6 +411,51 @@ def test_unqualified_to_string( "?.TPCH(avg_balance=AVG(?.Customers(cntry_code=SLICE(?.phone, None, 2, None)).WHERE((ISIN(?.cntry_code, ['13', '31', '23', '29', '30', '18', '17']) & HASNOT(?.orders))).WHERE((?.acctbal > 0.0)).acctbal)).PARTITION(?.Customers(cntry_code=SLICE(?.phone, None, 2, None)).WHERE((ISIN(?.cntry_code, ['13', '31', '23', '29', '30', '18', '17']) & HASNOT(?.orders))).WHERE((?.acctbal > BACK(1).avg_balance)), name='custs', by=(?.cntry_code))(CNTRY_CODE=?.cntry_code, NUM_CUSTS=COUNT(?.custs), 
TOTACCTBAL=SUM(?.custs.acctbal))", id="tpch_q22", ), + pytest.param( + loop_generated_terms, + "?.Nations(name=?.name, interval_0=COUNT(?.customers.WHERE(MONOTONIC(0, ?.acctbal, 1000))), interval_1=COUNT(?.customers.WHERE(MONOTONIC(1000, ?.acctbal, 2000))), interval_2=COUNT(?.customers.WHERE(MONOTONIC(2000, ?.acctbal, 3000))))", + id="loop_generated_terms", + ), + pytest.param( + function_defined_terms, + "", + id="function_defined_terms", + marks=pytest.mark.skip( + "TODO: (gh #222) ensure PyDough code is compatible with full Python syntax " + ), + ), + pytest.param( + dict_comp_terms, + "", + id="dict_comp_terms", + marks=pytest.mark.skip( + "TODO: (gh #222) ensure PyDough code is compatible with full Python syntax " + ), + ), + pytest.param( + list_comp_terms, + "", + id="list_comp_terms", + marks=pytest.mark.skip( + "TODO: (gh #222) ensure PyDough code is compatible with full Python syntax " + ), + ), + pytest.param( + set_comp_terms, + "", + id="set_comp_terms", + marks=pytest.mark.skip( + "TODO: (gh #222) ensure PyDough code is compatible with full Python syntax " + ), + ), + pytest.param( + generator_comp_terms, + "", + id="generator_comp_terms", + marks=pytest.mark.skip( + "TODO: (gh #222) ensure PyDough code is compatible with full Python syntax " + ), + ), ], ) def test_init_pydough_context( From 7b142eb2ec52be3a87468dbaae0cb26b61ea88a0 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Tue, 21 Jan 2025 14:55:36 -0500 Subject: [PATCH 044/112] Updated TOC --- documentation/dsl.md | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/documentation/dsl.md b/documentation/dsl.md index a097aee2..efcf6a91 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -24,11 +24,11 @@ This page describes the specification of the PyDough DSL. 
The specification incl * [Induced Subcollection Properties](#induced-subcollection-properties) * [Induced Arbitrary Joins](#induced-arbitrary-joins) - [Larger Examples](#larger-examples) - * [Example 1: Highest Residency Density States](#example-1) - * [Example 2: TODO](#example-2) - * [Example 3: TODO](#example-3) - * [Example 4: TODO](#example-4) - * [Example 5: TODO](#example-5) + * [Example 1: Highest Residency Density States](#example-1-highest-residency-density-states) + * [Example 2: Yearly Trans-Coastal Shipments](#example-2-yearly-trans-coastal-shipments) + * [Example 3: Email of Oldest Non-Customer Resident](#example-3-email-of-oldest-non-customer-resident) + * [Example 4: Outlier Packages Per Month Of 2017](#example-4-outlier-packages-per-month-of-2017) + * [Example 5: Regression Prediction Of Packages Quantity](#example-5-regression-prediction-of-packages-quantity) @@ -1449,7 +1449,7 @@ This section of the PyDough specification has not yet been defined. The rest of the document are examples of questions asked about the data in the people/addresses/packages graph and the corresponding PyDough code, which uses several of the features described in this document. - + ### Example 1: Highest Residency Density States **Question**: Find the 5 states with the highest average number of occupants per address. @@ -1471,7 +1471,7 @@ result = PARTITION( ) ``` - + ### Example 2: Yearly Trans-Coastal Shipments **Question**: For every calendar year, what percentage of all packages are from a customer living in the west coast to an address on the east coast? Only include packages that have already arrived, and order by the year. @@ -1501,7 +1501,7 @@ result = PARTITION( ) ``` - + ### Example 3: Email of Oldest Non-Customer Resident **Question**: For every city/state, find the email of the oldest resident of that city/state who has never ordered a package (break ties in favor of the lower social security number). Also include the zip code of that occupant. 
Order alphabetically by state, followed by city. @@ -1527,7 +1527,7 @@ cities = PARTITION( ) ``` - + ### Example 4: Outlier Packages Per Month Of 2017 **Question**: For every month of the year 2017, identify the percentage of packages ordered in that month that are at least 10x the average value of all packages ordered in 2017. Order the results by month. @@ -1556,7 +1556,7 @@ result = PARTITION( ) ``` - + ### Example 5: Regression Prediction Of Packages Quantity **Question**: Using linear regression of the number of packages ordered per-year, what is the predicted number of packages for the next three years? From df7e1b0bef84e9247c32599d0351bc7d0b391cdb Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Tue, 21 Jan 2025 15:17:46 -0500 Subject: [PATCH 045/112] Adding 911 bugfix for partition and corresponding tests [RUN CI] --- pydough/sqlglot/sqlglot_relational_visitor.py | 13 +---- tests/simple_pydough_functions.py | 43 +++++++++++++++++ tests/test_pipeline.py | 48 +++++++++++++++++++ 3 files changed, 92 insertions(+), 12 deletions(-) diff --git a/pydough/sqlglot/sqlglot_relational_visitor.py b/pydough/sqlglot/sqlglot_relational_visitor.py index 7d52af10..31a6c5c7 100644 --- a/pydough/sqlglot/sqlglot_relational_visitor.py +++ b/pydough/sqlglot/sqlglot_relational_visitor.py @@ -426,18 +426,7 @@ def visit_aggregate(self, aggregate: Aggregate) -> None: for alias, col in aggregate.aggregations.items() ] select_cols = keys + aggregations - query: Select - if ( - "group_by" in input_expr.args - or "qualify" in input_expr.args - or "order" in input_expr.args - or "limit" in input_expr.args - ): - query = self._build_subquery(input_expr, select_cols) - else: - query = self._merge_selects( - select_cols, input_expr, find_identifiers_in_list(select_cols) - ) + query: Select = self._build_subquery(input_expr, select_cols) if keys: query = query.group_by(*keys) self._stack.append(query) diff --git a/tests/simple_pydough_functions.py b/tests/simple_pydough_functions.py index 
fcfa3186..2f8ad21f 100644 --- a/tests/simple_pydough_functions.py +++ b/tests/simple_pydough_functions.py @@ -110,6 +110,7 @@ def regional_suppliers_percentile(): def function_sampler(): + # Examples of using different functions return ( Regions.nations.customers( a=JOIN_STRINGS("-", BACK(2).name, BACK(1).name, name[16:]), @@ -124,6 +125,7 @@ def function_sampler(): def loop_generated_terms(): + # Using a loop & dictionary to generate PyDough calc terms terms = {"name": name} for i in range(3): terms[f"interval_{i}"] = COUNT( @@ -133,6 +135,7 @@ def loop_generated_terms(): def function_defined_terms(): + # Using a regular function to generate PyDough calc terms def interval_n(n): return COUNT(customers.WHERE(MONOTONIC(n * 1000, acctbal, (n + 1) * 1000))) @@ -144,7 +147,22 @@ def interval_n(n): ) +def lambda_defined_terms(): + # Using a lambda function to generate PyDough calc terms + interval_n = lambda n: COUNT( + customers.WHERE(MONOTONIC(n * 1000, acctbal, (n + 1) * 1000)) + ) + + return Nations( + name, + interval_7=interval_n(7), + interval_4=interval_n(4), + interval_13=interval_n(13), + ) + + def dict_comp_terms(): + # Using a dictionary comprehension to generate PyDough calc terms terms = {"name": name} terms.update( { @@ -158,6 +176,7 @@ def dict_comp_terms(): def list_comp_terms(): + # Using a list comprehension to generate PyDough calc terms terms = [name] terms.extend( [ @@ -169,6 +188,7 @@ def list_comp_terms(): def set_comp_terms(): + # Using a set comprehension to generate PyDough calc terms terms = [name] terms.extend( set( @@ -182,6 +202,7 @@ def set_comp_terms(): def generator_comp_terms(): + # Using a generator comprehension to generate PyDough calc terms terms = {"name": name} for term, value in ( ( @@ -192,3 +213,25 @@ def generator_comp_terms(): ): terms[term] = value return Nations(**terms) + + +def agg_partition(): + # Doing a global aggregation on the output of a partition aggregation + yearly_data = PARTITION(Orders(year=YEAR(order_date)), 
name="orders", by=year)( + n_orders=COUNT(orders) + ) + return TPCH(best_year=MAX(yearly_data.n_orders)) + + +def double_partition(): + # Doing a partition aggregation on the output of a partition aggregation + year_month_data = PARTITION( + Orders(year=YEAR(order_date), month=MONTH(order_date)), + name="orders", + by=(year, month), + )(n_orders=COUNT(orders)) + return PARTITION( + year_month_data, + name="months", + by=year, + )(year, best_month=MAX(months.n_orders)) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 6277982d..12faa6c2 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -13,6 +13,8 @@ bad_slice_4, ) from simple_pydough_functions import ( + agg_partition, + double_partition, function_sampler, percentile_customers_per_region, percentile_nations, @@ -970,6 +972,52 @@ ), id="function_sampler", ), + pytest.param( + ( + agg_partition, + """ +ROOT(columns=[('best_year', best_year)], orderings=[]) + PROJECT(columns={'best_year': agg_1}) + AGGREGATE(keys={}, aggregations={'agg_1': MAX(n_orders)}) + PROJECT(columns={'n_orders': DEFAULT_TO(agg_0, 0:int64)}) + AGGREGATE(keys={'year': year}, aggregations={'agg_0': COUNT()}) + PROJECT(columns={'year': YEAR(order_date)}) + SCAN(table=tpch.ORDERS, columns={'order_date': o_orderdate}) + """, + lambda: pd.DataFrame( + { + "best_year": [228637], + } + ), + ), + id="agg_partition", + ), + pytest.param( + ( + double_partition, + """ +ROOT(columns=[('year', year), ('best_month', best_month)], orderings=[]) + PROJECT(columns={'best_month': agg_2, 'year': year}) + JOIN(conditions=[t0.year == t1.year], types=['left'], columns={'agg_2': t1.agg_2, 'year': t0.year}) + AGGREGATE(keys={'year': year}, aggregations={}) + AGGREGATE(keys={'month': month, 'year': year}, aggregations={}) + PROJECT(columns={'month': MONTH(order_date), 'year': YEAR(order_date)}) + SCAN(table=tpch.ORDERS, columns={'order_date': o_orderdate}) + AGGREGATE(keys={'year': year}, aggregations={'agg_2': MAX(n_orders)}) + 
PROJECT(columns={'n_orders': DEFAULT_TO(agg_1, 0:int64), 'year': year}) + AGGREGATE(keys={'month': month, 'year': year}, aggregations={'agg_1': COUNT()}) + PROJECT(columns={'month': MONTH(order_date), 'year': YEAR(order_date)}) + SCAN(table=tpch.ORDERS, columns={'order_date': o_orderdate}) + """, + lambda: pd.DataFrame( + { + "year": [1992, 1993, 1994, 1995, 1996, 1997, 1998], + "best_month": [19439, 19319, 19546, 19502, 19724, 19519, 19462], + } + ), + ), + id="double_partition", + ), ], ) def pydough_pipeline_test_data( From 3f167a4852dd427655942608f529bc3de2890a3c Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Tue, 21 Jan 2025 15:39:53 -0500 Subject: [PATCH 046/112] Fixing mkglot examples [RUN CI] --- tests/test_relational_nodes_to_sqlglot.py | 34 ++++++++++++++++++----- tests/test_relational_to_sql.py | 2 +- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/tests/test_relational_nodes_to_sqlglot.py b/tests/test_relational_nodes_to_sqlglot.py index 1075ea6e..8778b8d5 100644 --- a/tests/test_relational_nodes_to_sqlglot.py +++ b/tests/test_relational_nodes_to_sqlglot.py @@ -633,7 +633,12 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: ), mkglot( expressions=[Ident(this="b")], - _from=GlotFrom(Table(this=Ident(this="table"))), + _from=GlotFrom( + mkglot( + expressions=[Ident(this="a"), Ident(this="b")], + _from=GlotFrom(Table(this=Ident(this="table"))), + ) + ), group_by=[Ident(this="b")], ), id="simple_distinct", @@ -717,14 +722,19 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: ), mkglot( expressions=[Ident(this="b")], - where=mkglot_func(EQ, [Ident(this="a"), mk_literal(1, False)]), - group_by=[Ident(this="b")], _from=GlotFrom( mkglot( - expressions=[Ident(this="a"), Ident(this="b")], - _from=GlotFrom(Table(this=Ident(this="table"))), + expressions=[Ident(this="b")], + _from=GlotFrom( + mkglot( + expressions=[Ident(this="a"), Ident(this="b")], + 
_from=GlotFrom(Table(this=Ident(this="table"))), + ) + ), + where=mkglot_func(EQ, [Ident(this="a"), mk_literal(1, False)]), ) ), + group_by=[Ident(this="b")], ), id="filter_before_aggregate", ), @@ -828,7 +838,12 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: ), mkglot( expressions=[Ident(this="b")], - _from=GlotFrom(Table(this=Ident(this="table"))), + _from=GlotFrom( + mkglot( + expressions=[Ident(this="a"), Ident(this="b")], + _from=GlotFrom(Table(this=Ident(this="table"))), + ) + ), group_by=[Ident(this="b")], order_by=[Ident(this="b").desc(nulls_first=False)], limit=mk_literal(10, False), @@ -866,7 +881,12 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: mkglot( expressions=[Ident(this="b")], group_by=[Ident(this="b")], - _from=GlotFrom(Table(this=Ident(this="table"))), + _from=GlotFrom( + mkglot( + expressions=[Ident(this="a"), Ident(this="b")], + _from=GlotFrom(Table(this=Ident(this="table"))), + ) + ), ) ), ), diff --git a/tests/test_relational_to_sql.py b/tests/test_relational_to_sql.py index 8c00d76e..6d71c7db 100644 --- a/tests/test_relational_to_sql.py +++ b/tests/test_relational_to_sql.py @@ -323,7 +323,7 @@ def sqlite_dialect() -> SQLiteDialect: aggregations={}, ), ), - "SELECT b FROM table GROUP BY b", + "SELECT b FROM (SELECT a, b FROM table) GROUP BY b", id="simple_distinct", ), pytest.param( From 09f2bfe5f85164232391c4ac2de7418fa18b80c5 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Tue, 21 Jan 2025 15:55:21 -0500 Subject: [PATCH 047/112] Added extra DSL example comments --- documentation/dsl.md | 97 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 78 insertions(+), 19 deletions(-) diff --git a/documentation/dsl.md b/documentation/dsl.md index efcf6a91..bffb65f7 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -1457,18 +1457,22 @@ The rest of the document are examples of questions asked about the data in the p **Answer**: ```py %%pydough +# For each 
address, identify how many current occupants it has addr_info = Addresses(n_occupants=COUNT(current_occupants)) -result = PARTITION( + +# Partition the addresses by the state, and for each state calculate the +# average value of `n_occupants` for all addresses in that state +states = PARTITION( addr_info, name="addrs", by=state )( state, average_occupants=AVG(addrs.n_occupants) -).TOP_K( - 5, - by=average_occupants.DESC() ) + +# Obtain the top-5 states with the highest average +result = states.TOP_K(5, by=average_occupants.DESC()) ``` @@ -1479,26 +1483,36 @@ result = PARTITION( **Answer**: ```py %%pydough +# Contextless expression: identifies if a package comes from the west coast west_coast_states = ("CA", "OR", "WA", "AK") -east_coast_states = ("FL", "GA", "SC", "NC", "VA", "MD", "DE", "NJ", "NY", "CT", "RI", "MA", "NH", "MA") from_west_coast = ISIN(customer.current_address.state, west_coast_states) + +# Contextless expression: identifies if a pcakge is shipped to the east coast +east_coast_states = ("FL", "GA", "SC", "NC", "VA", "MD", "DE", "NJ", "NY", "CT", "RI", "MA", "NH", "MA") to_east_coast = ISIN(shipping_address.state, east_coast_states) + +# Filter packages to only include ones that have arrived, and derive additional +# terms for if they are trans-coastal + the year they were ordered package_info = Packages.WHERE( PRESENT(arrival_date) )( is_trans_coastal=from_west_coast & to_east_coast, year=YEAR(order_date), ) -result = PARTITION( + +# Partition the packages by the order year & count how many have a True value +# for is_trans_coastal, vs the total number in that year +year_info = PARTITION( package_info, name="packs", by=year, )( year, pct_trans_coastal=100.0 * SUM(packs.is_trans_coastal) / COUNT(packs), -).ORDER_BY( - year.ASC() ) + +# Output the results ordered by year +result = year_info.ORDER_BY(year.ASC()) ``` @@ -1509,11 +1523,17 @@ result = PARTITION( **Answer**: ```py %%pydough + +# Partition every address by the city/state cities = PARTITION( 
Addresses, name="addrs", by=(city, state) -).BEST( +) + +# For each city, find the oldest occupant out of any address in that city +# and include the desired information about that occupant. +oldest_occupants = cities.BEST( addrs.current_occupants.WHERE(HASNOT(packages)), by=(birth_date.ASC(), ssn.ASC()), )( @@ -1521,7 +1541,10 @@ cities = PARTITION( city=BACK(2).city, email=email, zip_code=BACK(1).zip_code, -).ORDER_BY( +) + +# Sort the output by state, followed by city +result = oldest_occupants.ORDER_BY( state.ASC(), city.ASC(), ) @@ -1535,25 +1558,40 @@ cities = PARTITION( **Answer**: ```py %%pydough +# Contextless expression: identifies is a package was ordered in 2017 is_2017 = YEAR(order_date) == 2017 -package_info = GRAPH( + +# Identify the average package cost of all packages ordered in 2017 +global_info = GRAPH( avg_package_cost=AVG(Packages.WHERE(is_2017).package_cost) -).Packages.WHERE( - is_2017 -)( +) + +# Identify all packages ordered in 2017, but where BACK(1) is global_info +# instead of GRAPH, so we have access to global_info's terms. +selected_package = global_info.Packages.WHERE(is_2017) + +# For each such package, identify the month it was ordered, and add a term to +# indicate if the cost of the package is at least 10x the average for all such +# packages. +packages = selected_packages( month=MONTH(order_date), is_10x_avg=package_cost >= (10.0 * BACK(1).avg_package_cost) ) -result = PARTITION( + +# Partition the packages by the month they were ordered, and for each month +# calculate the ratio between the number of packages where is_10x_avg is True +# versus all packages ordered that month, multiplied by 100 to get a percentage. 
+months = PARTITION( package_info, name="packs", by=month )( month, pct_outliers=100.0 * SUM(packs.is_10x_avg) / COUNT(packs) -).ORDER_BY( - month.ASC() ) + +# Order the output by month +result = months.ORDER_BY(month.ASC()) ``` @@ -1561,9 +1599,12 @@ result = PARTITION( **Question**: Using linear regression of the number of packages ordered per-year, what is the predicted number of packages for the next three years? +Note: uses the formula [discussed here](https://medium.com/swlh/linear-regression-in-sql-is-it-possible-b9cc787d622f) to identify the slope via linear regression. + **Answer**: ```py %%pydough +# Identify every year & how many packages were ordered that year yearly_data = PARTITION( Packages(year=YEAR(order_date)), name="packs", @@ -1572,19 +1613,37 @@ yearly_data = PARTITION( year, n_orders = COUNT(packs), ) + +# Obtain the global average of the year (x-coordinate) and +# n_orders (y-coordinate). These correspond to `x-bar` and `y-bar`. global_info = GRAPH( avg_x = AVG(yearly_data.year), avg_y = AVG(yearly_data.n_orders), ) + +# Contextless expression: corresponds to `x - x-bar` with regards to yearly_data +# inside of global_info dx = n_orders - BACK(1).avg_x + +# Contextless expression: corresponds to `y - y-bar` with regards to yearly_data +# inside of global_info dy = year - BACK(1).avg_y + +# Contextless expression: derive the slope with regards to global_info regression_data = yearly_data(value=(dx * dy) / (dx * dx)) slope = SUM(regression_data.value) + +# Identify the (chronologically) last record from yearly_data. # Could also write as `last_year = packs.WHERE(RANKING(by=year.DESC()) == 1).SINGULAR()` last_year = BEST(packs, by=year.DESC()) + +# Use a loop to derive a pair of terms for each of the 3 next years: +# 1. The year itself +# 2. The predicted number of orders (should be the last year's orders + slope * number of years) +# This is allowed since calcs can operate via keyword arguments. 
results = {} for n in range(1, 4): - results[f"year_{n}"] = last_year.year+n - results[f"year_{n}_prediction"] = last_year.n_orders+n*slope + results[f"year_{n}"] = last_year.year + n + results[f"year_{n}_prediction"] = last_year.n_orders + (n * slope) result = global_info(**results) ``` From 30536e8d379266bf9e4d0a2cda1935518b9c7296 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Tue, 21 Jan 2025 15:55:29 -0500 Subject: [PATCH 048/112] Added extra DSL example comments --- documentation/dsl.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/documentation/dsl.md b/documentation/dsl.md index bffb65f7..e1a530ec 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -1640,7 +1640,8 @@ last_year = BEST(packs, by=year.DESC()) # Use a loop to derive a pair of terms for each of the 3 next years: # 1. The year itself # 2. The predicted number of orders (should be the last year's orders + slope * number of years) -# This is allowed since calcs can operate via keyword arguments. +# This is allowed since calcs can operate via keyword arguments, whether real +# or passed in via a dictionary with ** syntax. 
results = {} for n in range(1, 4): results[f"year_{n}"] = last_year.year + n From 261f8699ea5a32d97aacea2b4203408e21502a68 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Wed, 22 Jan 2025 10:59:12 -0500 Subject: [PATCH 049/112] Updating alias counter --- pydough/conversion/hybrid_tree.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/pydough/conversion/hybrid_tree.py b/pydough/conversion/hybrid_tree.py index 15f3bb18..0a9fc64e 100644 --- a/pydough/conversion/hybrid_tree.py +++ b/pydough/conversion/hybrid_tree.py @@ -874,6 +874,8 @@ def __init__( self._is_connection_root: bool = is_connection_root self._agg_keys: list[HybridExpr] | None = None self._join_keys: list[tuple[HybridExpr, HybridExpr]] | None = None + if isinstance(root_operation, HybridPartition): + self._join_keys = [] def __repr__(self): lines = [] @@ -1229,7 +1231,16 @@ def populate_children( connection index to use. """ for child_idx, child in enumerate(child_operator.children): + # Build the hybrid tree for the child. Before doing so, reset the + # alias counter to 0 to ensure that identical subtrees are named + # in the same manner. Afterwards, reset the alias counter to its + # value within this context. + snapshot: int = self.alias_counter + self.alias_counter = 0 subtree: HybridTree = self.make_hybrid_tree(child, hybrid) + self.alias_counter = snapshot + # Infer how the child is used by the current operation based on + # the expressions that the operator uses. reference_types: set[ConnectionType] = set() match child_operator: case Where(): @@ -1246,6 +1257,10 @@ def populate_children( self.identify_connection_types(expr, child_idx, reference_types) case PartitionBy(): reference_types.add(ConnectionType.AGGREGATION) + # Combine the various references to the child to identify the type + # of connection and add the child. 
If it already exists, the index
+        # of the existing child will be used instead, but the connection
+        # type will be updated to reflect the new invocation of the child.
         if len(reference_types) == 0:
             raise ValueError(
                 f"Bad call to populate_children: child {child_idx} of {child_operator} is never used"

From 3794707e787f7e7e1ad41a225d5a1284b20faad7 Mon Sep 17 00:00:00 2001
From: knassre-bodo
Date: Wed, 22 Jan 2025 12:31:08 -0500
Subject: [PATCH 050/112] Adding triple partition test

---
 tests/simple_pydough_functions.py |  35 ++++++++
 tests/test_pipeline.py            | 133 ++++++++++++++++++++----------
 2 files changed, 123 insertions(+), 45 deletions(-)

diff --git a/tests/simple_pydough_functions.py b/tests/simple_pydough_functions.py
index 2f8ad21f..df9c02fd 100644
--- a/tests/simple_pydough_functions.py
+++ b/tests/simple_pydough_functions.py
@@ -235,3 +235,38 @@ def double_partition():
         name="months",
         by=year,
     )(year, best_month=MAX(months.n_orders))
+
+
+def triple_partition():
+    # Doing three layers of partitioned aggregation. Goal of the question:
+    # for each region, calculate the average percentage of purchases made from
+    # suppliers in that region belonging to the most common part type shipped
+    # from the supplier region to the customer region, averaging across all
+    # customer regions. Only considers lineitems from June of 1992 where the
+    # container is small.
+ line_info = ( + Parts.WHERE( + STARTSWITH(container, "SM"), + ) + .lines.WHERE((MONTH(ship_date) == 6) & (YEAR(ship_date) == 1992))( + supp_region=supplier.nation.region.name, + ) + .order( + supp_region=BACK(1).supp_region, + part_type=BACK(2).part_type, + cust_region=customer.nation.region.name, + ) + ) + rrt_combos = PARTITION( + line_info, name="lines", by=(supp_region, cust_region, part_type) + )(n_instances=COUNT(lines)) + rr_combos = PARTITION(rrt_combos, name="part_types", by=(supp_region, cust_region))( + percentage=100.0 * MAX(part_types.n_instances) / SUM(part_types.n_instances) + ) + return PARTITION( + rr_combos, + name="cust_regions", + by=supp_region, + )(supp_region, avg_percentage=AVG(cust_regions.percentage)).ORDER_BY( + supp_region.ASC() + ) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 12faa6c2..7b883c5b 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -27,6 +27,7 @@ regional_suppliers_percentile, simple_filter_top_five, simple_scan_top_five, + triple_partition, ) from test_utils import ( graph_fetcher, @@ -377,25 +378,17 @@ ( impl_tpch_q13, """ -ROOT(columns=[('C_COUNT', C_COUNT), ('CUSTDIST', CUSTDIST)], orderings=[(ordering_3):desc_last, (ordering_4):desc_last]) - LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'CUSTDIST': CUSTDIST, 'C_COUNT': C_COUNT, 'ordering_3': ordering_3, 'ordering_4': ordering_4}, orderings=[(ordering_3):desc_last, (ordering_4):desc_last]) - PROJECT(columns={'CUSTDIST': CUSTDIST, 'C_COUNT': C_COUNT, 'ordering_3': CUSTDIST, 'ordering_4': C_COUNT}) - PROJECT(columns={'CUSTDIST': DEFAULT_TO(agg_2, 0:int64), 'C_COUNT': num_non_special_orders}) - JOIN(conditions=[t0.num_non_special_orders == t1.num_non_special_orders], types=['left'], columns={'agg_2': t1.agg_2, 'num_non_special_orders': t0.num_non_special_orders}) - AGGREGATE(keys={'num_non_special_orders': num_non_special_orders}, aggregations={}) - PROJECT(columns={'num_non_special_orders': DEFAULT_TO(agg_0, 0:int64)}) 
- JOIN(conditions=[t0.key == t1.customer_key], types=['left'], columns={'agg_0': t1.agg_0}) - SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey}) - AGGREGATE(keys={'customer_key': customer_key}, aggregations={'agg_0': COUNT()}) - FILTER(condition=NOT(LIKE(comment, '%special%requests%':string)), columns={'customer_key': customer_key}) - SCAN(table=tpch.ORDERS, columns={'comment': o_comment, 'customer_key': o_custkey}) - AGGREGATE(keys={'num_non_special_orders': num_non_special_orders}, aggregations={'agg_2': COUNT()}) - PROJECT(columns={'num_non_special_orders': DEFAULT_TO(agg_1, 0:int64)}) - JOIN(conditions=[t0.key == t1.customer_key], types=['left'], columns={'agg_1': t1.agg_1}) - SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey}) - AGGREGATE(keys={'customer_key': customer_key}, aggregations={'agg_1': COUNT()}) - FILTER(condition=NOT(LIKE(comment, '%special%requests%':string)), columns={'customer_key': customer_key}) - SCAN(table=tpch.ORDERS, columns={'comment': o_comment, 'customer_key': o_custkey}) +ROOT(columns=[('C_COUNT', C_COUNT), ('CUSTDIST', CUSTDIST)], orderings=[(ordering_1):desc_last, (ordering_2):desc_last]) + LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'CUSTDIST': CUSTDIST, 'C_COUNT': C_COUNT, 'ordering_1': ordering_1, 'ordering_2': ordering_2}, orderings=[(ordering_1):desc_last, (ordering_2):desc_last]) + PROJECT(columns={'CUSTDIST': CUSTDIST, 'C_COUNT': C_COUNT, 'ordering_1': CUSTDIST, 'ordering_2': C_COUNT}) + PROJECT(columns={'CUSTDIST': DEFAULT_TO(agg_0, 0:int64), 'C_COUNT': num_non_special_orders}) + AGGREGATE(keys={'num_non_special_orders': num_non_special_orders}, aggregations={'agg_0': COUNT()}) + PROJECT(columns={'num_non_special_orders': DEFAULT_TO(agg_0, 0:int64)}) + JOIN(conditions=[t0.key == t1.customer_key], types=['left'], columns={'agg_0': t1.agg_0}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey}) + AGGREGATE(keys={'customer_key': customer_key}, aggregations={'agg_0': COUNT()}) + 
FILTER(condition=NOT(LIKE(comment, '%special%requests%':string)), columns={'customer_key': customer_key}) + SCAN(table=tpch.ORDERS, columns={'comment': o_comment, 'customer_key': o_custkey}) """, tpch_q13_output, ), @@ -422,14 +415,14 @@ ( impl_tpch_q15, """ -ROOT(columns=[('S_SUPPKEY', S_SUPPKEY), ('S_NAME', S_NAME), ('S_ADDRESS', S_ADDRESS), ('S_PHONE', S_PHONE), ('TOTAL_REVENUE', TOTAL_REVENUE)], orderings=[(ordering_3):asc_first]) - PROJECT(columns={'S_ADDRESS': S_ADDRESS, 'S_NAME': S_NAME, 'S_PHONE': S_PHONE, 'S_SUPPKEY': S_SUPPKEY, 'TOTAL_REVENUE': TOTAL_REVENUE, 'ordering_3': S_SUPPKEY}) +ROOT(columns=[('S_SUPPKEY', S_SUPPKEY), ('S_NAME', S_NAME), ('S_ADDRESS', S_ADDRESS), ('S_PHONE', S_PHONE), ('TOTAL_REVENUE', TOTAL_REVENUE)], orderings=[(ordering_2):asc_first]) + PROJECT(columns={'S_ADDRESS': S_ADDRESS, 'S_NAME': S_NAME, 'S_PHONE': S_PHONE, 'S_SUPPKEY': S_SUPPKEY, 'TOTAL_REVENUE': TOTAL_REVENUE, 'ordering_2': S_SUPPKEY}) FILTER(condition=TOTAL_REVENUE == max_revenue, columns={'S_ADDRESS': S_ADDRESS, 'S_NAME': S_NAME, 'S_PHONE': S_PHONE, 'S_SUPPKEY': S_SUPPKEY, 'TOTAL_REVENUE': TOTAL_REVENUE}) - PROJECT(columns={'S_ADDRESS': address, 'S_NAME': name, 'S_PHONE': phone, 'S_SUPPKEY': key, 'TOTAL_REVENUE': DEFAULT_TO(agg_2, 0:int64), 'max_revenue': max_revenue}) - JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'address': t0.address, 'agg_2': t1.agg_2, 'key': t0.key, 'max_revenue': t0.max_revenue, 'name': t0.name, 'phone': t0.phone}) + PROJECT(columns={'S_ADDRESS': address, 'S_NAME': name, 'S_PHONE': phone, 'S_SUPPKEY': key, 'TOTAL_REVENUE': DEFAULT_TO(agg_1, 0:int64), 'max_revenue': max_revenue}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'address': t0.address, 'agg_1': t1.agg_1, 'key': t0.key, 'max_revenue': t0.max_revenue, 'name': t0.name, 'phone': t0.phone}) JOIN(conditions=[True:bool], types=['inner'], columns={'address': t1.address, 'key': t1.key, 'max_revenue': t0.max_revenue, 'name': t1.name, 'phone': 
t1.phone}) - PROJECT(columns={'max_revenue': agg_1}) - AGGREGATE(keys={}, aggregations={'agg_1': MAX(total_revenue)}) + PROJECT(columns={'max_revenue': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': MAX(total_revenue)}) PROJECT(columns={'total_revenue': DEFAULT_TO(agg_0, 0:int64)}) JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'agg_0': t1.agg_0}) SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey}) @@ -437,7 +430,7 @@ FILTER(condition=ship_date >= datetime.date(1996, 1, 1):date & ship_date < datetime.date(1996, 4, 1):date, columns={'discount': discount, 'extended_price': extended_price, 'supplier_key': supplier_key}) SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'ship_date': l_shipdate, 'supplier_key': l_suppkey}) SCAN(table=tpch.SUPPLIER, columns={'address': s_address, 'key': s_suppkey, 'name': s_name, 'phone': s_phone}) - AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_2': SUM(extended_price * 1:int64 - discount)}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_1': SUM(extended_price * 1:int64 - discount)}) FILTER(condition=ship_date >= datetime.date(1996, 1, 1):date & ship_date < datetime.date(1996, 4, 1):date, columns={'discount': discount, 'extended_price': extended_price, 'supplier_key': supplier_key}) SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'ship_date': l_shipdate, 'supplier_key': l_suppkey}) """, @@ -472,8 +465,8 @@ impl_tpch_q17, """ ROOT(columns=[('AVG_YEARLY', AVG_YEARLY)], orderings=[]) - PROJECT(columns={'AVG_YEARLY': DEFAULT_TO(agg_1, 0:int64) / 7.0:float64}) - AGGREGATE(keys={}, aggregations={'agg_1': SUM(extended_price)}) + PROJECT(columns={'AVG_YEARLY': DEFAULT_TO(agg_0, 0:int64) / 7.0:float64}) + AGGREGATE(keys={}, aggregations={'agg_0': SUM(extended_price)}) FILTER(condition=quantity < 0.2:float64 * avg_quantity, columns={'extended_price': extended_price}) 
JOIN(conditions=[t0.key == t1.part_key], types=['inner'], columns={'avg_quantity': t0.avg_quantity, 'extended_price': t1.extended_price, 'quantity': t1.quantity}) PROJECT(columns={'avg_quantity': agg_0, 'key': key}) @@ -528,16 +521,16 @@ ( impl_tpch_q20, """ -ROOT(columns=[('S_NAME', S_NAME), ('S_ADDRESS', S_ADDRESS)], orderings=[(ordering_2):asc_first]) - LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'S_ADDRESS': S_ADDRESS, 'S_NAME': S_NAME, 'ordering_2': ordering_2}, orderings=[(ordering_2):asc_first]) - PROJECT(columns={'S_ADDRESS': S_ADDRESS, 'S_NAME': S_NAME, 'ordering_2': S_NAME}) - FILTER(condition=name_3 == 'CANADA':string & DEFAULT_TO(agg_1, 0:int64) > 0:int64, columns={'S_ADDRESS': S_ADDRESS, 'S_NAME': S_NAME}) - JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'S_ADDRESS': t0.S_ADDRESS, 'S_NAME': t0.S_NAME, 'agg_1': t1.agg_1, 'name_3': t0.name_3}) +ROOT(columns=[('S_NAME', S_NAME), ('S_ADDRESS', S_ADDRESS)], orderings=[(ordering_1):asc_first]) + LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'S_ADDRESS': S_ADDRESS, 'S_NAME': S_NAME, 'ordering_1': ordering_1}, orderings=[(ordering_1):asc_first]) + PROJECT(columns={'S_ADDRESS': S_ADDRESS, 'S_NAME': S_NAME, 'ordering_1': S_NAME}) + FILTER(condition=name_3 == 'CANADA':string & DEFAULT_TO(agg_0, 0:int64) > 0:int64, columns={'S_ADDRESS': S_ADDRESS, 'S_NAME': S_NAME}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'S_ADDRESS': t0.S_ADDRESS, 'S_NAME': t0.S_NAME, 'agg_0': t1.agg_0, 'name_3': t0.name_3}) JOIN(conditions=[t0.nation_key == t1.key], types=['left'], columns={'S_ADDRESS': t0.S_ADDRESS, 'S_NAME': t0.S_NAME, 'key': t0.key, 'name_3': t1.name}) PROJECT(columns={'S_ADDRESS': address, 'S_NAME': name, 'key': key, 'nation_key': nation_key}) SCAN(table=tpch.SUPPLIER, columns={'address': s_address, 'key': s_suppkey, 'name': s_name, 'nation_key': s_nationkey}) SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) - 
AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_1': COUNT()}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': COUNT()}) FILTER(condition=STARTSWITH(name, 'forest':string) & availqty > DEFAULT_TO(agg_0, 0:int64) * 0.5:float64, columns={'supplier_key': supplier_key}) JOIN(conditions=[t0.key == t1.part_key], types=['left'], columns={'agg_0': t1.agg_0, 'availqty': t0.availqty, 'name': t0.name, 'supplier_key': t0.supplier_key}) JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'availqty': t0.availqty, 'key': t1.key, 'name': t1.name, 'supplier_key': t0.supplier_key}) @@ -977,8 +970,8 @@ agg_partition, """ ROOT(columns=[('best_year', best_year)], orderings=[]) - PROJECT(columns={'best_year': agg_1}) - AGGREGATE(keys={}, aggregations={'agg_1': MAX(n_orders)}) + PROJECT(columns={'best_year': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': MAX(n_orders)}) PROJECT(columns={'n_orders': DEFAULT_TO(agg_0, 0:int64)}) AGGREGATE(keys={'year': year}, aggregations={'agg_0': COUNT()}) PROJECT(columns={'year': YEAR(order_date)}) @@ -997,17 +990,12 @@ double_partition, """ ROOT(columns=[('year', year), ('best_month', best_month)], orderings=[]) - PROJECT(columns={'best_month': agg_2, 'year': year}) - JOIN(conditions=[t0.year == t1.year], types=['left'], columns={'agg_2': t1.agg_2, 'year': t0.year}) - AGGREGATE(keys={'year': year}, aggregations={}) - AGGREGATE(keys={'month': month, 'year': year}, aggregations={}) + PROJECT(columns={'best_month': agg_0, 'year': year}) + AGGREGATE(keys={'year': year}, aggregations={'agg_0': MAX(n_orders)}) + PROJECT(columns={'n_orders': DEFAULT_TO(agg_0, 0:int64), 'year': year}) + AGGREGATE(keys={'month': month, 'year': year}, aggregations={'agg_0': COUNT()}) PROJECT(columns={'month': MONTH(order_date), 'year': YEAR(order_date)}) SCAN(table=tpch.ORDERS, columns={'order_date': o_orderdate}) - AGGREGATE(keys={'year': year}, aggregations={'agg_2': MAX(n_orders)}) - PROJECT(columns={'n_orders': 
DEFAULT_TO(agg_1, 0:int64), 'year': year}) - AGGREGATE(keys={'month': month, 'year': year}, aggregations={'agg_1': COUNT()}) - PROJECT(columns={'month': MONTH(order_date), 'year': YEAR(order_date)}) - SCAN(table=tpch.ORDERS, columns={'order_date': o_orderdate}) """, lambda: pd.DataFrame( { @@ -1018,6 +1006,61 @@ ), id="double_partition", ), + pytest.param( + ( + triple_partition, + """ +ROOT(columns=[('supp_region', supp_region), ('avg_percentage', avg_percentage)], orderings=[(ordering_1):asc_first]) + PROJECT(columns={'avg_percentage': avg_percentage, 'ordering_1': supp_region, 'supp_region': supp_region}) + PROJECT(columns={'avg_percentage': agg_0, 'supp_region': supp_region}) + AGGREGATE(keys={'supp_region': supp_region}, aggregations={'agg_0': AVG(percentage)}) + PROJECT(columns={'percentage': 100.0:float64 * agg_0 / DEFAULT_TO(agg_1, 0:int64), 'supp_region': supp_region}) + AGGREGATE(keys={'cust_region': cust_region, 'supp_region': supp_region}, aggregations={'agg_0': MAX(n_instances), 'agg_1': SUM(n_instances)}) + PROJECT(columns={'cust_region': cust_region, 'n_instances': DEFAULT_TO(agg_0, 0:int64), 'supp_region': supp_region}) + AGGREGATE(keys={'cust_region': cust_region, 'part_type': part_type, 'supp_region': supp_region}, aggregations={'agg_0': COUNT()}) + PROJECT(columns={'cust_region': name_15, 'part_type': part_type, 'supp_region': supp_region}) + JOIN(conditions=[t0.customer_key == t1.key], types=['left'], columns={'name_15': t1.name_15, 'part_type': t0.part_type, 'supp_region': t0.supp_region}) + JOIN(conditions=[t0.order_key == t1.key], types=['inner'], columns={'customer_key': t1.customer_key, 'part_type': t0.part_type, 'supp_region': t0.supp_region}) + PROJECT(columns={'order_key': order_key, 'part_type': part_type, 'supp_region': name_7}) + JOIN(conditions=[t0.supplier_key == t1.key], types=['left'], columns={'name_7': t1.name_7, 'order_key': t0.order_key, 'part_type': t0.part_type}) + FILTER(condition=MONTH(ship_date) == 6:int64 & 
YEAR(ship_date) == 1992:int64, columns={'order_key': order_key, 'part_type': part_type, 'supplier_key': supplier_key}) + JOIN(conditions=[t0.key == t1.part_key], types=['inner'], columns={'order_key': t1.order_key, 'part_type': t0.part_type, 'ship_date': t1.ship_date, 'supplier_key': t1.supplier_key}) + FILTER(condition=STARTSWITH(container, 'SM':string), columns={'key': key, 'part_type': part_type}) + SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'part_type': p_type}) + SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'part_key': l_partkey, 'ship_date': l_shipdate, 'supplier_key': l_suppkey}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_7': t1.name}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'region_key': t1.region_key}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_15': t1.name}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'region_key': t1.region_key}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + """, + lambda: pd.DataFrame( + { + "supp_region": [ + "AFRICA", + "AMERICA", + "ASIA", + "EUROPE", + "MIDDLE EAST", + ], + "avg_percentage": [ + 1.8038152, + 1.9968418, + 1.6850716, + 1.7673618, + 1.7373118, + ], + } + ), + ), + id="triple_partition", + ), ], ) def pydough_pipeline_test_data( From c04f34d3414932df1a1d76635a57b00371fafe15 Mon Sep 17 
00:00:00 2001 From: knassre-bodo Date: Wed, 22 Jan 2025 13:17:05 -0500 Subject: [PATCH 051/112] Adjiusting triple_partition test [RUN CI] --- tests/simple_pydough_functions.py | 2 +- tests/test_pipeline.py | 29 +++++++++++++++-------------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/tests/simple_pydough_functions.py b/tests/simple_pydough_functions.py index df9c02fd..7ec5ba0e 100644 --- a/tests/simple_pydough_functions.py +++ b/tests/simple_pydough_functions.py @@ -251,7 +251,7 @@ def triple_partition(): .lines.WHERE((MONTH(ship_date) == 6) & (YEAR(ship_date) == 1992))( supp_region=supplier.nation.region.name, ) - .order( + .order.WHERE(YEAR(order_date) == 1992)( supp_region=BACK(1).supp_region, part_type=BACK(2).part_type, cust_region=customer.nation.region.name, diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 7b883c5b..41605af3 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1020,20 +1020,21 @@ AGGREGATE(keys={'cust_region': cust_region, 'part_type': part_type, 'supp_region': supp_region}, aggregations={'agg_0': COUNT()}) PROJECT(columns={'cust_region': name_15, 'part_type': part_type, 'supp_region': supp_region}) JOIN(conditions=[t0.customer_key == t1.key], types=['left'], columns={'name_15': t1.name_15, 'part_type': t0.part_type, 'supp_region': t0.supp_region}) - JOIN(conditions=[t0.order_key == t1.key], types=['inner'], columns={'customer_key': t1.customer_key, 'part_type': t0.part_type, 'supp_region': t0.supp_region}) - PROJECT(columns={'order_key': order_key, 'part_type': part_type, 'supp_region': name_7}) - JOIN(conditions=[t0.supplier_key == t1.key], types=['left'], columns={'name_7': t1.name_7, 'order_key': t0.order_key, 'part_type': t0.part_type}) - FILTER(condition=MONTH(ship_date) == 6:int64 & YEAR(ship_date) == 1992:int64, columns={'order_key': order_key, 'part_type': part_type, 'supplier_key': supplier_key}) - JOIN(conditions=[t0.key == t1.part_key], types=['inner'], columns={'order_key': 
t1.order_key, 'part_type': t0.part_type, 'ship_date': t1.ship_date, 'supplier_key': t1.supplier_key}) - FILTER(condition=STARTSWITH(container, 'SM':string), columns={'key': key, 'part_type': part_type}) - SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'part_type': p_type}) - SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'part_key': l_partkey, 'ship_date': l_shipdate, 'supplier_key': l_suppkey}) - JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_7': t1.name}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'region_key': t1.region_key}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey}) + FILTER(condition=YEAR(order_date) == 1992:int64, columns={'customer_key': customer_key, 'part_type': part_type, 'supp_region': supp_region}) + JOIN(conditions=[t0.order_key == t1.key], types=['inner'], columns={'customer_key': t1.customer_key, 'order_date': t1.order_date, 'part_type': t0.part_type, 'supp_region': t0.supp_region}) + PROJECT(columns={'order_key': order_key, 'part_type': part_type, 'supp_region': name_7}) + JOIN(conditions=[t0.supplier_key == t1.key], types=['left'], columns={'name_7': t1.name_7, 'order_key': t0.order_key, 'part_type': t0.part_type}) + FILTER(condition=MONTH(ship_date) == 6:int64 & YEAR(ship_date) == 1992:int64, columns={'order_key': order_key, 'part_type': part_type, 'supplier_key': supplier_key}) + JOIN(conditions=[t0.key == t1.part_key], types=['inner'], columns={'order_key': t1.order_key, 'part_type': t0.part_type, 'ship_date': t1.ship_date, 'supplier_key': t1.supplier_key}) + FILTER(condition=STARTSWITH(container, 'SM':string), columns={'key': key, 'part_type': 
part_type}) + SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'part_type': p_type}) + SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'part_key': l_partkey, 'ship_date': l_shipdate, 'supplier_key': l_suppkey}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_7': t1.name}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'region_key': t1.region_key}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate}) JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_15': t1.name}) JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'region_key': t1.region_key}) SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) From 39ce2aca292f98f538f4ad6fde883d44be6f6964 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Wed, 22 Jan 2025 13:36:08 -0500 Subject: [PATCH 052/112] Fixing unit test [RUN CI] --- tests/test_qdag_conversion.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_qdag_conversion.py b/tests/test_qdag_conversion.py index f2b54a09..ac126436 100644 --- a/tests/test_qdag_conversion.py +++ b/tests/test_qdag_conversion.py @@ -2076,10 +2076,10 @@ ), """ ROOT(columns=[('name', name), ('n_top_suppliers', n_top_suppliers)], orderings=[]) - PROJECT(columns={'n_top_suppliers': DEFAULT_TO(agg_1, 0:int64), 'name': name}) - JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_1': t1.agg_1, 'name': t0.name}) + PROJECT(columns={'n_top_suppliers': DEFAULT_TO(agg_0, 0:int64), 'name': name}) + JOIN(conditions=[t0.key == t1.nation_key], 
types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}) SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_1': COUNT(key)}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': COUNT(key)}) LIMIT(limit=Literal(value=100, type=Int64Type()), columns={'key': key, 'nation_key': nation_key}, orderings=[(ordering_0):asc_last]) PROJECT(columns={'key': key, 'nation_key': nation_key, 'ordering_0': account_balance}) SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) From 28065c046496cc293c63b3fb8c11a6e7be1db0ad Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Wed, 22 Jan 2025 18:00:52 -0500 Subject: [PATCH 053/112] Update documentation/dsl.md Co-authored-by: Hadia Ahmed --- documentation/dsl.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/documentation/dsl.md b/documentation/dsl.md index e1a530ec..2c02ab9b 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -105,7 +105,7 @@ People ### Sub-Collections -The next step in PyDough after accessing a collection is accessing any of its sub-collections. The syntax `collection.subcollection` steps into every record of `subcollection` for each record of `collection`. This can result in changes of cardinality if records of `collection` can have multiple records of `subcollection`, and can result in duplicate records in the output if records of `subcollection` can be sourced from different records of `collection`. +The next step in PyDough after accessing a collection is to access its sub-collections. Using the syntax `collection.subcollection`, you can traverse into every record of `subcollection` for each record in `collection`. This operation may change the cardinality if records of `collection` have multiple associated records in `subcollection`. 
Additionally, duplicate records may appear in the output if records in `subcollection` are linked to multiple records in `collection`. **Good Example #1**: for every person, obtains their current address. Every scalar property of `Addresses` (`address_id`, `street_number`, `street_name`, `apartment`, `zip_code`, `city`, `state`) is automatically included in the output. A record from `Addresses` can be included multiple times if multiple different `People` records have it as their current address, or it could be missing entirely if no person has it as their current address. From d88192fcaf1bffb62063e020d0e8576b2c07c880 Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Wed, 22 Jan 2025 18:01:13 -0500 Subject: [PATCH 054/112] Update documentation/dsl.md Co-authored-by: Hadia Ahmed --- documentation/dsl.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/documentation/dsl.md b/documentation/dsl.md index 2c02ab9b..f6ba3fd5 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -114,7 +114,7 @@ The next step in PyDough after accessing a collection is to access its sub-colle People.current_addresses ``` -**Good Example #2**: for every package, obtains the person who shipped it address. The `GRAPH.` prefix is optional and implied when the term is a collection name in the graph. Every scalar property of `People` (`first_name`, `middle_name`, `last_name`, `ssn`, `birth_date`, `email`, `current_address_id`) is automatically included in the output. A record from `People` can be included multiple times if multiple packages were ordered by that person, or it could be missing entirely if that person is not the customer who ordered any package. +**Good Example #2**: For every package, get the person who shipped it. The `GRAPH.` prefix is optional and implied when the term is a collection name in the graph. 
Every scalar property of `People` (`first_name`, `middle_name`, `last_name`, `ssn`, `birth_date`, `email`, `current_address_id`) is automatically included in the output. A record from `People` can be included multiple times if multiple packages were ordered by that person, or it could be missing entirely if that person is not the customer who ordered any package. ```py %%pydough From fed91e56ce8d821bf5a954f59be5b3694d8cc385 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Thu, 23 Jan 2025 10:03:19 -0500 Subject: [PATCH 055/112] Added extra WHERE examples --- documentation/dsl.md | 215 ++++++++++++++++++++++--------------- documentation/functions.md | 2 +- 2 files changed, 130 insertions(+), 87 deletions(-) diff --git a/documentation/dsl.md b/documentation/dsl.md index f6ba3fd5..fe3e81ad 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -76,7 +76,7 @@ Packages ```py %%pydough -Addresses +Products ``` **Bad Example #2**: obtains every record of the `Addresses` collection (but the name `Addresses` has been reassigned to a variable). @@ -94,14 +94,6 @@ Addresses HELLO.Addresses ``` -**Bad Example #4**: obtains every record of the `People` collection (but the name `People` has been reassigned to a variable). - -```py -%%pydough -People = "not a collection" -People -``` - ### Sub-Collections @@ -187,7 +179,7 @@ Packages( ) ``` -**Good Example #3**: For every person, finds their full name (without the middle name) and counts how many packages they purchased. +**Good Example #3**: For every person, find their full name (without the middle name) and count how many packages they purchased. ```py %%pydough @@ -197,7 +189,7 @@ People( ) ``` -**Good Example #4**: For every person, finds their full name including the middle name if one exists, as well as their email. Notice that two CALCs are present, but only the terms from the second one are part of the answer. 
+**Good Example #4**: For every person, find their full name including the middle name if one exists, as well as their email. Notice that two CALCs are present, but only the terms from the second one are part of the answer. ```py %%pydough @@ -211,7 +203,7 @@ People( ) ``` -**Good Example #5**: For every person, finds the year from the most recent package they purchased, and from the first package they ever purchased. +**Good Example #5**: For every person, find the year from the most recent package they purchased, and from the first package they ever purchased. ```py %%pydough @@ -232,7 +224,7 @@ GRAPH( ) ``` -**Good Example #7**: For each package, lists the package id and whether the package was shipped to the current address of the person who ordered it. +**Good Example #7**: For each package, list the package id and whether the package was shipped to the current address of the person who ordered it. ```py %%pydough @@ -242,14 +234,14 @@ Packages( ) ``` -**Bad Example #1**: For each person, lists their first name, last name, and phone number. This is invalid because `People` does not have a property named `phone_number`. +**Bad Example #1**: For each person, list their first name, last name, and phone number. This is invalid because `People` does not have a property named `phone_number`. ```py %%pydough People(first_name, last_name, phone_number) ``` -**Bad Example #2**: For each person, lists their combined first & last name followed by their email. This is invalid because a positional argument is included after a keyword argument. +**Bad Example #2**: For each person, list their combined first & last name followed by their email. This is invalid because a positional argument is included after a keyword argument. ```py %%pydough @@ -259,14 +251,14 @@ People( ) ``` -**Bad Example #3**: For each person, lists the address_id of packages they have ordered. 
This is invalid because `packages` is a plural property of `People`, so its properties cannot be included in a calc term of `People` unless aggregated. +**Bad Example #3**: For each person, list the address_id of packages they have ordered. This is invalid because `packages` is a plural property of `People`, so its properties cannot be included in a calc term of `People` unless aggregated. ```py %%pydough People(packages.address_id) ``` -**Bad Example #4**: For each person, lists their first/last name followed by the concatenated city/state name of their current address. This is invalid because `current_address` is a plural property of `People`, so its properties cannot be included in a calc term of `People` unless aggregated. +**Bad Example #4**: For each person, list their first/last name followed by the concatenated city/state name of their current address. This is invalid because `current_address` is a plural property of `People`, so its properties cannot be included in a calc term of `People` unless aggregated. ```py %%pydough @@ -277,21 +269,21 @@ People( ) ``` -**Bad Example #5**: For each address, finds whether the state name starts with `"C"`. This is invalid because it calls the builtin Python `.startswith` string method, which is not supported in PyDough (should have instead used a defined PyDough behavior, like the `STARTSWITH` function). +**Bad Example #5**: For each address, find whether the state name starts with `"C"`. This is invalid because it calls the builtin Python `.startswith` string method, which is not supported in PyDough (should have instead used a defined PyDough behavior, like the `STARTSWITH` function). ```py %%pydough Addresses(is_c_state=state.startswith("c")) ``` -**Bad Example #6**: For each address, finds the state bird of the state it is in. 
This is invalid because the `state` property of each record of `Addresses` is a scalar expression, not a subcolleciton, so it does not have any properties that can be accessed with `.` syntax. +**Bad Example #6**: For each address, find the state bird of the state it is in. This is invalid because the `state` property of each record of `Addresses` is a scalar expression, not a subcollection, so it does not have any properties that can be accessed with `.` syntax. ```py %%pydough Addresses(state_bird=state.bird) ``` -**Bad Example #7**: For each current occupant of each address, lists their first name, last name, and city/state they live in. This is invalid because `city` and `state` are not properties of the current collection (`People`, accessed via `current_occupants` of each record of `Addresses`). +**Bad Example #7**: For each current occupant of each address, list their first name, last name, and city/state they live in. This is invalid because `city` and `state` are not properties of the current collection (`People`, accessed via `current_occupants` of each record of `Addresses`). ```py %%pydough Addresses.current_occupants(first_name, last_name, city, state) @@ -305,7 +297,7 @@ Addresses.current_occupants(first_name, last_name, city, state) People(ssn, current_address) ``` -**Bad Example #9**: For each person, lists their first name, last name, and the sum of the package costs. This is invalid because `SUM` is an aggregation function and cannot be used in a CALC term without specifying the sub-collection it should be applied to. 
```py %%pydough @@ -330,7 +322,7 @@ People( ) ``` -**Good Example #2**: for every person, finds the total value of all packages they ordered in February of any year, as well as the number of all such packages, the largest value of any such package, and the percentage of those packages that were specifically on valentine's day +**Good Example #2**: for every person, find the total value of all packages they ordered in February of any year, as well as the number of all such packages, the largest value of any such package, and the percentage of those packages that were specifically on Valentine's day ```py %%pydough @@ -386,7 +378,7 @@ People(february=is_february) Part of the benefit of doing `collection.subcollection` accesses is that properties from the ancestor collection can be accessed from the current collection. This is done via a `BACK` call. Accessing properties from `BACK(n)` can be done to access properties from the n-th ancestor of the current collection. The simplest recommended way to do this is to just access a scalar property of an ancestor in order to include it in the final answer. -**Good Example #1**: For every address' current occupants, lists their first name last name, and the city/state of the current address they belong to. +**Good Example #1**: For each address's current occupants, list their first name last name, and the city/state of the current address they belong to. ```py %%pydough @@ -474,13 +466,6 @@ Customers.packages( ) ``` -**Bad Example #6**: The 1st ancestor of `current_occupants` is `Addresses` which does not have a term named `phone`. - -```py -%%pydough -Addresses.current_occupants(a=BACK(1).phone) -``` - ### Expressions @@ -587,14 +572,14 @@ Packages.WHERE(package_cost > 100)(package_id, shipping_state=shipping_address.s People(first_name, last_name, email).WHERE(COUNT(packages) > 5) ``` -**Good Example #4**: Finds every person whose most recent order was shipped in the year 2023, and lists all properties of that person. 
+**Good Example #4**: Find every person whose most recent order was shipped in the year 2023, and list all properties of that person. ```py %%pydough People.WHERE(YEAR(MAX(packages.order_date)) == 2023) ``` -**Good Example #5**: Counts how many packages were ordered in January of 2018. +**Good Example #5**: Count how many packages were ordered in January of 2018. [See here](functions.md#logical) for more details on the valid/invalid use of logical operations in Python. ```py %%pydough packages_jan_2018 = Packages.WHERE( @@ -604,6 +589,50 @@ packages_jan_2018 = Packages.WHERE( GRAPH(n_jan_2018=COUNT(selected_packages)) ``` +**Good Example #6**: Counts how many people don't have a first name that starts with A or B. [See here](functions.md#logical) for more details on the valid/invalid use of logical operations in Python. + +```py +%%pydough +selected_people = People.WHERE( + ~STARTSWITH(first_name, "A") & ~STARTSWITH(first_name, "B") +) +GRAPH(n_people=COUNT(selected_people)) +``` + +**Good Example #7**: Counts how many people have a gmail or yahoo account. [See here](functions.md#logical) for more details on the valid/invalid use of logical operations in Python. + +```py +%%pydough +gmail_or_yahoo = People.WHERE( + ENDSWITH(email, "@gmail.com") | ENDSWITH(email, "@yahoo.com") +) +GRAPH(n_gmail_or_yahoo=COUNT(gmail_or_yahoo)) +``` + +**Good Example #8**: Counts how many people were born in the 1980s. [See here](functions.md#comparisons) for more details on the valid/invalid use of comparisons in Python. + +```py +%%pydough +eighties_babies = People.WHERE( + (1980 <= YEAR(birth_date)) & (YEAR(birth_date) < 1990) +) +GRAPH(n_eighties_babies=COUNT(eighties_babies)) +``` + +**Good Example #9**: Find every person who has sent a package to Idaho. + +```py +%%pydough +People.WHERE(HAS(packages.WHERE(shipping_address.state == "ID"))) +``` + +**Good Example #10**: Find every person who did not order a package in 2024. 
+ +```py +%%pydough +People.WHERE(HASNOT(packages.WHERE(YEAR(order_date) == 2024))) +``` + **Bad Example #1**: For every person, fetches their first name and last name only if they have a phone number. This is invalid because `People` does not have a property named `phone_number`. ```py @@ -611,21 +640,35 @@ GRAPH(n_jan_2018=COUNT(selected_packages)) People.WHERE(PRESENT(phone_number))(first_name, last_name) ``` -**Bad Example #2**: For every package, fetches the package id only if the package cost is greater than 100 and the shipping state is Texas. This is invalid because `shipping_state` is not a property of `Packages`. +**Bad Example #2**: For every package, fetches the package id only if the package cost is greater than 100 and the shipping state is Texas. This is invalid because `and` is used instead of `&`. [See here](functions.md#logical) for more details on the valid/invalid use of logical operations in Python. ```py %%pydough -Packages.WHERE((package_cost > 100) & (shipping_state == "TX"))(package_id) +Packages.WHERE((package_cost > 100) and (shipping_address.state == "TX"))(package_id) ``` -**Bad Example #3**: For every package, fetches the package id only if the package cost is greater than 100 and the shipping state is Texas. This is invalid because `and` is used instead of `&`. +**Bad Example #3**: For every package, fetches the package id only if the package is either being shipped from Pennsylvania or to Pennsylvania. This is invalid because `or` is used instead of `|`. [See here](functions.md#logical) for more details on the valid/invalid use of logical operations in Python. ```py %%pydough -Packages.WHERE((package_cost > 100) and (shipping_address.state == "TX"))(package_id) +Packages.WHERE((customer.current_address.state == "PA") or (shipping_address.state == "PA"))(package_id) +``` + +**Bad Example #4**: For every package, fetches the package id only if the customer's first name does not start with a J. 
This is invalid because `not` is used instead of `~`. [See here](functions.md#logical) for more details on the valid/invalid use of logical operations in Python. + +```py +%%pydough +Packages.WHERE(not STARTSWITH(customer.first_name, "J"))(package_id) +``` + +**Bad Example #5**: For every package, fetches the package id only if the package was ordered between February and May. [See here](functions.md#comparisons) for more details on the valid/invalid use of comparisons in Python. + +```py +%%pydough +Packages.WHERE(2 <= MONTH(arrival_date) <= 5)(package_id) ``` -**Bad Example #4**: Obtain every person whose packages were shipped in the month of June. This is invalid because `packages` is a plural property of `People`, so `MONTH(packages.order_date) == 6` is a plural expression with regards to `People` that cannot be used as a filtering condition. +**Bad Example #6**: Obtain every person whose packages were shipped in the month of June. This is invalid because `packages` is a plural property of `People`, so `MONTH(packages.order_date) == 6` is a plural expression with regards to `People` that cannot be used as a filtering condition. ```py %%pydough @@ -651,7 +694,7 @@ If there are multiple `ORDER_BY` terms, the last one is the one that takes prece People.ORDER_BY(last_name.ASC(), first_name.ASC(), middle_name.ASC(na_pos="last")) ``` -**Good Example #2**: For every person lists their ssn & how many packages they have ordered, and orders them from highest number of orders to lowest, breaking ties in favor of whoever is oldest. +**Good Example #2**: For every person list their ssn & how many packages they have ordered, and orders them from highest number of orders to lowest, breaking ties in favor of whoever is oldest. ```py %%pydough @@ -662,7 +705,7 @@ People( ) ``` -**Good Example #3**: Finds every address that has at least 1 person living in it and sorts them highest-to-lowest by number of occupants, with ties broken by address id in ascending order. 
+**Good Example #3**: Find every address that has at least 1 person living in it and sorts them highest-to-lowest by number of occupants, with ties broken by address id in ascending order. ```py %%pydough @@ -697,7 +740,7 @@ Addresses.WHERE( ) ``` -**Good Example #6**: Finds all people who are in the top 1% of customers according to number of packages ordered. +**Good Example #6**: Find all people who are in the top 1% of customers according to number of packages ordered. ```py %%pydough @@ -767,7 +810,7 @@ The syntax for this is `.TOP_K(k, by=...)` where `k` is a positive integer and t The terms in the collection are unchanged by the `TOP_K` clause, since the only change is the order of the records and which ones are kept/dropped. -**Good Example #1**: Finds the 10 people who have ordered the most packages, including their first/last name, birth date, and the number of packages. If there is a tie, break it by the lowest ssn. +**Good Example #1**: Find the 10 people who have ordered the most packages, including their first/last name, birth date, and the number of packages. If there is a tie, break it by the lowest ssn. ```py %%pydough @@ -779,14 +822,14 @@ People( ).TOP_K(10, by=(n_packages.DESC(), ssn.ASC())) ``` -**Good Example #2**: Finds the 5 most recently shipped packages, with ties broken arbitrarily. +**Good Example #2**: Find the 5 most recently shipped packages, with ties broken arbitrarily. ```py %%pydough Packages.TOP_K(5, by=order_date.DESC()) ``` -**Good Example #3**: Finds the 100 addresses that have most recently had packages either shipped or billed to them, breaking ties arbitrarily. +**Good Example #3**: Find the 100 addresses that have most recently had packages either shipped or billed to them, breaking ties arbitrarily. 
```py %%pydough @@ -797,7 +840,7 @@ most_recent_package = IFF(most_recent_ship < most_recent_bill, most_recent_ship, Addresses.TOP_K(10, by=most_recent_package.DESC()) ``` -**Good Example #4**: Finds the top 3 people who have spent the most money on packages, including their first/last name, and the total cost of all of their packages. +**Good Example #4**: Find the top 3 people who have spent the most money on packages, including their first/last name, and the total cost of all of their packages. ```py %%pydough @@ -808,21 +851,21 @@ People( ).TOP_K(3, by=total_package_cost.DESC()) ``` -**Bad Example #1**: Finds the 5 people with the lowest GPA. This is invalid because the `People` collection does not have a `gpa` property. +**Bad Example #1**: Find the 5 people with the lowest GPA. This is invalid because the `People` collection does not have a `gpa` property. ```py %%pydough People.TOP_K(5, by=gpa.ASC()) ``` -**Bad Example #2**: Finds the 25 addresses with the earliest packages billed to them, by arrival date. This is invalid because `packages_billed` is a plural property of `Addresses`, so `packages_billed.arrival_date` cannot be used as a collation expression for `Addresses`. +**Bad Example #2**: Find the 25 addresses with the earliest packages billed to them, by arrival date. This is invalid because `packages_billed` is a plural property of `Addresses`, so `packages_billed.arrival_date` cannot be used as a collation expression for `Addresses`. ```py %%pydough Addresses.packages_billed(25, by=gpa.packages_billed.arrival_date()) ``` -**Bad Example #3**: Finds the top 100 people currently living in the city of San Francisco. This is invalid because the `by` clause is absent. +**Bad Example #3**: Find the top 100 people currently living in the city of San Francisco. This is invalid because the `by` clause is absent. ```py %%pydough @@ -831,21 +874,21 @@ People.WHERE( ).TOP_K(100) ``` -**Bad Example #4**: Finds the top packages by highest value. 
This is invalid because there is no `k` value. +**Bad Example #4**: Find the top packages by highest value. This is invalid because there is no `k` value. ```py %%pydough Packages.TOP_K(by=package_cost.DESC()) ``` -**Bad Example #5**: Finds the top 300 addresses. This is invalid because the `by` clause is empty +**Bad Example #5**: Find the top 300 addresses. This is invalid because the `by` clause is empty ```py %%pydough Addresses.TOP_K(300, by=()) ``` -**Bad Example #6**: Finds the 1000 people by birth date. This is invalid because the collation term does not have `.ASC()` or `.DESC()`. +**Bad Example #6**: Find the 1000 people by birth date. This is invalid because the collation term does not have `.ASC()` or `.DESC()`. ```py %%pydough @@ -866,14 +909,14 @@ If the partitioned data is accessed, its original ancestry is lost. Instead, it The ancestry of the `PARTITION` clause can be changed by prepending it with another collection, separated by a dot. However, this is currently only supported in PyDough when the collection before the dot is just an augmented version of the graph context, as opposed to another collection (e.g. `GRAPH(x=42).PARTITION(...)` is supported, but `People.PARTITION(...)` is not). -**Good Example #1**: Finds every unique state. +**Good Example #1**: Find every unique state. ```py %%pydough PARTITION(Addresses, name="addrs", by=state)(state) ``` -**Good Example #2**: For every state, counts how many addresses are in that state. +**Good Example #2**: For every state, count how many addresses are in that state. ```py %%pydough @@ -883,7 +926,7 @@ PARTITION(Addresses, name="addrs", by=state)( ) ``` -**Good Example #3**: For every city/state, counts how many people live in that city/state. +**Good Example #3**: For every city/state, count how many people live in that city/state. 
```py %%pydough @@ -894,7 +937,7 @@ PARTITION(Addresses, name="addrs", by=(city, state))( ) ``` -**Good Example #4**: Finds the top 5 years with the most people born in that year who have yahoo email accounts, listing the year and the number of people. +**Good Example #4**: Find the top 5 years with the most people born in that year who have yahoo email accounts, listing the year and the number of people. ```py %%pydough @@ -907,7 +950,7 @@ PARTITION(yahoo_people, name="yah_ppl", by=birth_year)( ).TOP_K(5, by=n_people.DESC()) ``` -**Good Example #4**: For every year/month, finds all packages that were below the average cost of all packages ordered in that year/month. +**Good Example #4**: For every year/month, find all packages that were below the average cost of all packages ordered in that year/month. ```py %%pydough @@ -919,7 +962,7 @@ PARTITION(package_info, name="packs", by=(order_year, order_month))( ) ``` -**Good Example #5**: For every customer, finds the percentage of all orders made by current occupants of that city/state made by that specific customer. Includes the first/last name of the person, the city/state they live in, and the percentage. +**Good Example #5**: For every customer, find the percentage of all orders made by current occupants of that city/state made by that specific customer. Includes the first/last name of the person, the city/state they live in, and the percentage. ```py %%pydough @@ -934,7 +977,7 @@ PARTITION(Addresses, name="addrs", by=(city, state))( ) ``` -**Good Example #6**: Identifies the states whose current occupants account for at least 1% of all packages purchased. Lists the state and the percentage. +**Good Example #6**: Identifies the states whose current occupants account for at least 1% of all packages purchased. List the state and the percentage. 
```py %%pydough @@ -961,7 +1004,7 @@ GRAPH( ).WHERE(COUNT(packs) > BACK(1).avg_packages_per_month) ``` -**Good Example #8**: Finds the 10 most frequent combinations of the state that the person lives in and the first letter of that person's name. +**Good Example #8**: Find the 10 most frequent combinations of the state that the person lives in and the first letter of that person's name. ```py %%pydough @@ -1012,7 +1055,7 @@ PARTITION(Addresses, by=state) PARTITION(People, name="ppl") ``` -**Bad Example #4**: Counts how many packages were ordered in each year. Invalid because `YEAR(order_date)` is not allowed ot be used as a partition term (it must be placed in a CALC so it is accessible as a named reference). +**Bad Example #4**: Count how many packages were ordered in each year. Invalid because `YEAR(order_date)` is not allowed to be used as a partition term (it must be placed in a CALC so it is accessible as a named reference). ```py %%pydough PARTITION(Packages, name="packs", by=YEAR(order_date))( n_packages=COUNT(packages) ) ``` -**Bad Example #5**: Counts how many people live in each state. Invalid because `current_address.state` is not allowed to be used as a partition term (it must be placed in a CALC so it is accessible as a named reference). +**Bad Example #5**: Count how many people live in each state. Invalid because `current_address.state` is not allowed to be used as a partition term (it must be placed in a CALC so it is accessible as a named reference). ```py %%pydough @@ -1052,7 +1095,7 @@ PARTITION(People(birth_year=YEAR(birth_date)), name="ppl", by=birth_year)( ) ``` -**Bad Example #7**: For each person & year, counts how many times that person ordered a packaged in that year. This is invalid because doing `.PARTITION` after `People` is unsupported, since `People` is not a graph-level collection like `GRAPH(...)`. +**Bad Example #7**: For each person & year, count how many times that person ordered a package in that year. 
This is invalid because doing `.PARTITION` after `People` is unsupported, since `People` is not a graph-level collection like `GRAPH(...)`. ```py %%pydough @@ -1143,7 +1186,7 @@ The arguments to `NEXT` and `PREV` are as follows: If the entry `n` records before/after the current entry does not exist, then accessing anything from it returns null. Anything that can be done to the current context can also be done to the `PREV`/`NEXT` call (e.g. aggregating data from a plural sub-collection). -**Good Example #1**: For each package, finds whether it was ordered by the same customer as the most recently ordered package before it. +**Good Example #1**: For each package, find whether it was ordered by the same customer as the most recently ordered package before it. ```py %%pydough @@ -1153,7 +1196,7 @@ Packages( ) ``` -**Good Example #2**: Finds the average number of hours between every package ordered by every customer. +**Good Example #2**: Find the average number of hours between every package ordered by every customer. ```py %%pydough @@ -1167,7 +1210,7 @@ Customers( ) ``` -**Good Example #3**: Finds out for each customer whether, if they were sorted by number of packages ordered, whether they live in the same state as any of the 3 people below them on the list. +**Good Example #3**: Find out for each customer whether, if they were sorted by number of packages ordered, whether they live in the same state as any of the 3 people below them on the list. ```py %%pydough @@ -1184,7 +1227,7 @@ Customers( ) ``` -**Bad Example #1**: Finds the number of hours between each package and the previous package. This is invalid because the `by` argument is missing +**Bad Example #1**: Find the number of hours between each package and the previous package. This is invalid because the `by` argument is missing ```py %%pydough @@ -1193,7 +1236,7 @@ Packages( ) ``` -**Bad Example #2**: Finds the number of hours between each package and the next package. 
This is invalid because the `by` argument is empty. +**Bad Example #2**: Find the number of hours between each package and the next package. This is invalid because the `by` argument is empty. ```py %%pydough @@ -1202,7 +1245,7 @@ Packages( ) ``` -**Bad Example #3**: Finds the number of hours between each package and the 5th-previous package. This is invalid because the `by` argument is not a collation. +**Bad Example #3**: Find the number of hours between each package and the 5th-previous package. This is invalid because the `by` argument is not a collation. ```py %%pydough @@ -1211,7 +1254,7 @@ Packages( ) ``` -**Bad Example #4**: Finds the number of hours between each package and a subsequent package. This is invalid because the `n` argument is not an integer. +**Bad Example #4**: Find the number of hours between each package and a subsequent package. This is invalid because the `n` argument is not an integer. ```py %%pydough @@ -1229,7 +1272,7 @@ Packages( ) ``` -**Bad Example #6**: Finds the number of hours between each package and the previous package. This invalid because a property `.odate` is accessed that does not exist in the collection, therefore it doesn't exist in `PREV` either. +**Bad Example #6**: Find the number of hours between each package and the previous package. This is invalid because a property `.odate` is accessed that does not exist in the collection, therefore it doesn't exist in `PREV` either. ```py %%pydough @@ -1245,7 +1288,7 @@ Packages( Packages.PREV(order_date.ASC()) ``` -**Bad Example #8**: Finds the number of hours between each package and the previous package ordered by the customer. 
This invalid because the `levels` value is too large, since only 2 ancestor levels exist in `Customers.packages` (the graph, and `Customers`): +**Bad Example #8**: Find the number of hours between each package and the previous package ordered by the customer. This is invalid because the `levels` value is too large, since only 2 ancestor levels exist in `Customers.packages` (the graph, and `Customers`): ```py %%pydough @@ -1270,7 +1313,7 @@ Additional keyword arguments can be supplied to `BEST` that change its behavior: - `allow_ties` (default=False): if True, changes the behavior to keep all records of the sub-collection that share the optimal values of the collation terms. If `allow_ties` is True, the `BEST` clause is no longer singular. - `n_best=True`(defaults=1): if an integer greater than 1, changes the behavior to keep the top `n_best` values of the sub-collection for each record of the parent collection (fewer if `n_best` records of the sub-collection do not exist). If `n_best` is greater than 1, the `BEST` clause is no longer singular. NOTE: `n_best` cannot be greater than 1 at the same time that `allow_ties` is True. -**Good Example #1**: Finds the package id & zip code the package was shipped to for every package that was the first-ever purchase for the customer. +**Good Example #1**: Find the package id & zip code the package was shipped to for every package that was the first-ever purchase for the customer. ```py %%pydough Customers.BEST(packages, by=order_date.ASC())( package_id, zip_code=shipping_address.zip_code ) ``` -**Good Example #2**: For each customer, lists their ssn and the cost of the most recent package they have purchased. +**Good Example #2**: For each customer, list their ssn and the cost of the most recent package they have purchased. ```py %%pydough Customers( ssn, most_recent_package_cost=BEST(packages, by=order_date.DESC()).package_cost ) ``` -**Good Example #3**: Finds the address in the state of New York with the most occupants, ties broken by address id. 
Note: the `GRAPH.` prefix is optional in this case, since it is implied if there is no prefix to the `BEST` call. ```py %%pydough @@ -1300,7 +1343,7 @@ addr_info = Addresses.WHERE( GRAPH.BEST(addr_info, by=(n_occupants.DESC(), address_id.ASC())) ``` -**Good Example #4**: For each customer, finds the number of people currently living in the address that they most recently shipped a package to. +**Good Example #4**: For each customer, find the number of people currently living in the address that they most recently shipped a package to. ```py %%pydough @@ -1311,7 +1354,7 @@ Customers( ) ``` -**Good Example #5**: For each address that has occupants, lists out the first/last name of the person living in that address who has ordered the most packages, breaking ties in favor of the person with the smaller social security number. Also includes the city/state of the address, the number of people who live there, and the number of packages that person ordered. +**Good Example #5**: For each address that has occupants, list out the first/last name of the person living in that address who has ordered the most packages, breaking ties in favor of the person with the smaller social security number. Also includes the city/state of the address, the number of people who live there, and the number of packages that person ordered. ```py %%pydough @@ -1330,7 +1373,7 @@ Addresses.WHERE(HAS(current_occupants))( ) ``` -**Good Example #6**: For each person, finds the total value of the 5 most recent packages they ordered. +**Good Example #6**: For each person, find the total value of the 5 most recent packages they ordered. ```py %%pydough @@ -1341,7 +1384,7 @@ People( ) ``` -**Good Example #7**: For each address, finds the package most recently ordered by one of the current occupants of that address, including the email of the occupant who ordered it and the address' id. 
Notice that `BACK(1)` refers to `current_occupants` and `BACK(2)` refers to `Addresses` as if the packages were accessed as `Addresses.current_occupants.packages` instead of using `BEST`. +**Good Example #7**: For each address, find the package most recently ordered by one of the current occupants of that address, including the email of the occupant who ordered it and the address' id. Notice that `BACK(1)` refers to `current_occupants` and `BACK(2)` refers to `Addresses` as if the packages were accessed as `Addresses.current_occupants.packages` instead of using `BEST`. ```py %%pydough @@ -1354,42 +1397,42 @@ Addresses.most_recent_package( ) ``` -**Bad Example #1**: For each person finds their best email. This is invalid because `email` is not a sub-collection of `People` (it is a scalar attribute, so there is only 1 `email` per-person). +**Bad Example #1**: For each person find their best email. This is invalid because `email` is not a sub-collection of `People` (it is a scalar attribute, so there is only 1 `email` per-person). ```py %%pydough People(first_name, BEST(email, by=birth_date.DESC())) ``` -**Bad Example #2**: For each person finds their best package. This is invalid because the `by` argument is missing. +**Bad Example #2**: For each person find their best package. This is invalid because the `by` argument is missing. ```py %%pydough People.BEST(packages) ``` -**Bad Example #3**: For each person finds their best package. This is invalid because the: `by` argument is not a collation +**Bad Example #3**: For each person find their best package. This is invalid because the `by` argument is not a collation ```py %%pydough People.BEST(packages, by=order_date) ``` -**Bad Example #4**: For each person finds their best package. 
This is invalid because the `by` argument is empty ```py %%pydough People.BEST(packages, by=()) ``` -**Bad Example #5**: For each person finds the 5 most recent packages they have ordered, allowing ties. This is invalid because `n_best` is greater than 1 at the same time that `allow_ties` is True. +**Bad Example #5**: For each person find the 5 most recent packages they have ordered, allowing ties. This is invalid because `n_best` is greater than 1 at the same time that `allow_ties` is True. ```py %%pydough People.BEST(packages, by=order_date.DESC(), n_best=5, allow_ties=True) ``` -**Bad Example #6**: For each person, finds the package cost of their 10 most recent packages. This is invalid because `n_best` is greater than 1, which means that the `BEST` clause is non-singular so its terms cannot be accessed in the calc without aggregating. +**Bad Example #6**: For each person, find the package cost of their 10 most recent packages. This is invalid because `n_best` is greater than 1, which means that the `BEST` clause is non-singular so its terms cannot be accessed in the calc without aggregating. ```py %%pydough @@ -1397,7 +1440,7 @@ best_packages = BEST(packages, by=order_date.DESC(), n_best=10) People(first_name, best_cost=best_packages.package_cost) ``` -**Bad Example #7**: For each person, finds the package cost of their most expensive package(s), allowing ties. This is invalid because `allow_ties` is True, which means that the `BEST` clause is non-singular so its terms cannot be accessed in the calc without aggregating. +**Bad Example #7**: For each person, find the package cost of their most expensive package(s), allowing ties. This is invalid because `allow_ties` is True, which means that the `BEST` clause is non-singular so its terms cannot be accessed in the calc without aggregating. 
```py %%pydough @@ -1405,7 +1448,7 @@ best_packages = BEST(packages, by=package_cost.DESC(), allow_ties=True) People(first_name, best_cost=best_packages.package_cost) ``` -**Bad Example #8**: For each address, finds the package most recently ordered by one of the current occupants of that address, including the address id of the address. This is invalid because `BACK(1)` refers to `current_occupants`, which does not have a field called `address_id`. +**Bad Example #8**: For each address, find the package most recently ordered by one of the current occupants of that address, including the address id of the address. This is invalid because `BACK(1)` refers to `current_occupants`, which does not have a field called `address_id`. ```py %%pydough @@ -1417,7 +1460,7 @@ Addresses.most_recent_package( ) ``` -**Bad Example #9**: For each address finds the oldest occupant. This is invalid because the `BEST` clause is placed in the calc without accessing any of its attributes. +**Bad Example #9**: For each address find the oldest occupant. This is invalid because the `BEST` clause is placed in the calc without accessing any of its attributes. ```py %%pydough diff --git a/documentation/functions.md b/documentation/functions.md index 478a48cb..3b921690 100644 --- a/documentation/functions.md +++ b/documentation/functions.md @@ -85,7 +85,7 @@ Customers( ``` > [!WARNING] -> Chained inequalities, like `a <= b <= c`, can cause undefined/incorrect behavior in PyDough. Instead, use expressions like `(a <= b) & (b <= c)`. +> Chained inequalities, like `a <= b <= c`, can cause undefined/incorrect behavior in PyDough. Instead, use expressions like `(a <= b) & (b <= c)`, or the [MONOTONIC](#monotonic) function. 
### Logical From 7d9492dfcae7664f462f3defa66013c8428877f5 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Thu, 23 Jan 2025 10:15:55 -0500 Subject: [PATCH 056/112] Revisions --- documentation/dsl.md | 18 +++++----- pydough/sqlglot/sqlglot_relational_visitor.py | 15 ++++++-- tests/test_relational_nodes_to_sqlglot.py | 34 ++++--------------- tests/test_relational_to_sql.py | 2 +- 4 files changed, 30 insertions(+), 39 deletions(-) diff --git a/documentation/dsl.md b/documentation/dsl.md index fe3e81ad..1e7c82c1 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -687,14 +687,14 @@ If there are multiple `ORDER_BY` terms, the last one is the one that takes prece > [!WARNING] > In the current version of PyDough, the behavior when the expressions inside an `ORDER_BY` clause are not collation expressions with `.ASC()` or `.DESC()` is undefined/unsupported. -**Good Example #1**: Orders every person alphabetically by last name, then first name, then middle name (people with no middle name going last). +**Good Example #1**: Order every person alphabetically by last name, then first name, then middle name (people with no middle name going last). ```py %%pydough People.ORDER_BY(last_name.ASC(), first_name.ASC(), middle_name.ASC(na_pos="last")) ``` -**Good Example #2**: For every person list their ssn & how many packages they have ordered, and orders them from highest number of orders to lowest, breaking ties in favor of whoever is oldest. +**Good Example #2**: For every person list their SSN & how many packages they have ordered, and order them from highest number of orders to lowest, breaking ties in favor of whoever is oldest. ```py %%pydough @@ -705,7 +705,7 @@ People( ) ``` -**Good Example #3**: Find every address that has at least 1 person living in it and sorts them highest-to-lowest by number of occupants, with ties broken by address id in ascending order. 
+**Good Example #3**: Find every address that has at least 1 person living in it and sort them highest-to-lowest by number of occupants, with ties broken by address id in ascending order. ```py %%pydough @@ -716,7 +716,7 @@ Addresses.WHERE( ) ``` -**Good Example #4**: Sorts every person alphabetically by the state they live in, then the city they live in, then by their ssn. People without a current address should go last. +**Good Example #4**: Sort every person alphabetically by the state they live in, then the city they live in, then by their ssn. People without a current address should go last. ```py %%pydough @@ -747,14 +747,14 @@ Addresses.WHERE( People.WHERE(PERCENTILE(by=COUNT(packages).ASC()) == 100) ``` -**Bad Example #1**: Sorts each person by their account balance in descending order. This is invalid because the `People` collection does not have an `account_balance` property. +**Bad Example #1**: Sort each person by their account balance in descending order. This is invalid because the `People` collection does not have an `account_balance` property. ```py %%pydough People.ORDER_BY(account_balance.DESC()) ``` -**Bad Example #2**: Sorts each address by the birth date date of the people who live there. This is invalid because `current_occupants` is a plural property of `Addresses`, so `current_occupants.birth_date` is plural and cannot be used as an ordering term unless aggregated. +**Bad Example #2**: Sort each address by the birth date of the people who live there. This is invalid because `current_occupants` is a plural property of `Addresses`, so `current_occupants.birth_date` is plural and cannot be used as an ordering term unless aggregated. ```py %%pydough @@ -787,14 +787,14 @@ Addresses.WHERE( ) ``` -**Bad Example #5**: Sorts every person by their first name. This is invalid because no `.ASC()` or `.DESC()` term is provided. +**Bad Example #5**: Sort every person by their first name. This is invalid because no `.ASC()` or `.DESC()` term is provided. 
```py %%pydough People.ORDER_BY(first_name) ``` -**Bad Example #6**: Sorts every person. This is invalid because no collation terms are provided. +**Bad Example #6**: Sort every person. This is invalid because no collation terms are provided. ```py %%pydough @@ -851,7 +851,7 @@ People( ).TOP_K(3, by=total_package_cost.DESC()) ``` -**Bad Example #1**: Find the 5 people with the lowest GPA. This is invalid because the `People` collection does not have a `gpa` property. +**Bad Example #1**: Find the 5 people with the lowest GPAs. This is invalid because the `People` collection does not have a `gpa` property. ```py %%pydough diff --git a/pydough/sqlglot/sqlglot_relational_visitor.py b/pydough/sqlglot/sqlglot_relational_visitor.py index 31a6c5c7..5f1e17ae 100644 --- a/pydough/sqlglot/sqlglot_relational_visitor.py +++ b/pydough/sqlglot/sqlglot_relational_visitor.py @@ -397,7 +397,7 @@ def visit_filter(self, filter: Filter) -> None: else: # TODO: (gh #151) Refactor a simpler way to check dependent expressions. 
if ( - "group_by" in input_expr.args + "group" in input_expr.args or "where" in input_expr.args or "qualify" in input_expr.args or "order" in input_expr.args @@ -426,7 +426,18 @@ def visit_aggregate(self, aggregate: Aggregate) -> None: for alias, col in aggregate.aggregations.items() ] select_cols = keys + aggregations - query: Select = self._build_subquery(input_expr, select_cols) + query: Select + if ( + "group" in input_expr.args + or "qualify" in input_expr.args + or "order" in input_expr.args + or "limit" in input_expr.args + ): + query = self._build_subquery(input_expr, select_cols) + else: + query = self._merge_selects( + select_cols, input_expr, find_identifiers_in_list(select_cols) + ) if keys: query = query.group_by(*keys) self._stack.append(query) diff --git a/tests/test_relational_nodes_to_sqlglot.py b/tests/test_relational_nodes_to_sqlglot.py index 8778b8d5..1075ea6e 100644 --- a/tests/test_relational_nodes_to_sqlglot.py +++ b/tests/test_relational_nodes_to_sqlglot.py @@ -633,12 +633,7 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: ), mkglot( expressions=[Ident(this="b")], - _from=GlotFrom( - mkglot( - expressions=[Ident(this="a"), Ident(this="b")], - _from=GlotFrom(Table(this=Ident(this="table"))), - ) - ), + _from=GlotFrom(Table(this=Ident(this="table"))), group_by=[Ident(this="b")], ), id="simple_distinct", @@ -722,19 +717,14 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: ), mkglot( expressions=[Ident(this="b")], + where=mkglot_func(EQ, [Ident(this="a"), mk_literal(1, False)]), + group_by=[Ident(this="b")], _from=GlotFrom( mkglot( - expressions=[Ident(this="b")], - _from=GlotFrom( - mkglot( - expressions=[Ident(this="a"), Ident(this="b")], - _from=GlotFrom(Table(this=Ident(this="table"))), - ) - ), - where=mkglot_func(EQ, [Ident(this="a"), mk_literal(1, False)]), + expressions=[Ident(this="a"), Ident(this="b")], + _from=GlotFrom(Table(this=Ident(this="table"))), ) ), - 
group_by=[Ident(this="b")], ), id="filter_before_aggregate", ), @@ -838,12 +828,7 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: ), mkglot( expressions=[Ident(this="b")], - _from=GlotFrom( - mkglot( - expressions=[Ident(this="a"), Ident(this="b")], - _from=GlotFrom(Table(this=Ident(this="table"))), - ) - ), + _from=GlotFrom(Table(this=Ident(this="table"))), group_by=[Ident(this="b")], order_by=[Ident(this="b").desc(nulls_first=False)], limit=mk_literal(10, False), @@ -881,12 +866,7 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: mkglot( expressions=[Ident(this="b")], group_by=[Ident(this="b")], - _from=GlotFrom( - mkglot( - expressions=[Ident(this="a"), Ident(this="b")], - _from=GlotFrom(Table(this=Ident(this="table"))), - ) - ), + _from=GlotFrom(Table(this=Ident(this="table"))), ) ), ), diff --git a/tests/test_relational_to_sql.py b/tests/test_relational_to_sql.py index 6d71c7db..8c00d76e 100644 --- a/tests/test_relational_to_sql.py +++ b/tests/test_relational_to_sql.py @@ -323,7 +323,7 @@ def sqlite_dialect() -> SQLiteDialect: aggregations={}, ), ), - "SELECT b FROM (SELECT a, b FROM table) GROUP BY b", + "SELECT b FROM table GROUP BY b", id="simple_distinct", ), pytest.param( From 2ef58e3ff557061441f21a0c26ca385219e536bd Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Thu, 23 Jan 2025 10:16:18 -0500 Subject: [PATCH 057/112] Update documentation/dsl.md Co-authored-by: Hadia Ahmed --- documentation/dsl.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/documentation/dsl.md b/documentation/dsl.md index 1e7c82c1..e8cc2bee 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -113,7 +113,7 @@ People.current_addresses GRAPH.Packages.customer ``` -**Good Example #3**: for every address, obtains all packages that someone who lives at that address has ordered. 
Every scalar property of `Packages` (`package_id`, `customer_ssn`, `shipping_address_id`, `billing_address_id`, `order_date`, `arrival_date`, `package_cost`). Every record from `Packages` should be included at most once since every current occupant has a single address it maps back to, and every package has a single customer it maps back to. +**Good Example #3**: For every address, get all packages that someone who lives at that address has ordered. Every scalar property of `Packages` (`package_id`, `customer_ssn`, `shipping_address_id`, `billing_address_id`, `order_date`, `arrival_date`, `package_cost`). Every record from `Packages` should be included at most once since every current occupant has a single address it maps back to, and every package has a single customer it maps back to. ```py %%pydough From b904f3a7844c543fee3ac3deedc26400345dc230 Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Thu, 23 Jan 2025 10:16:34 -0500 Subject: [PATCH 058/112] Update documentation/dsl.md Co-authored-by: Hadia Ahmed --- documentation/dsl.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/documentation/dsl.md b/documentation/dsl.md index e8cc2bee..d1bf1095 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -120,7 +120,7 @@ GRAPH.Packages.customer Addresses.current_occupants.packages ``` -**Good Example #4**: for every person, obtains all packages they have ordered. Every scalar property of `Packages` (`package_id`, `customer_ssn`, `shipping_address_id`, `billing_address_id`, `order_date`, `arrival_date`, `package_cost`). Every record from `Packages` should be included at most once since every package has a single customer it maps back to. +**Good Example #4**: For every person, get all packages they have ordered. Every scalar property of `Packages` (`package_id`, `customer_ssn`, `shipping_address_id`, `billing_address_id`, `order_date`, `arrival_date`, `package_cost`). 
Every record from `Packages` should be included at most once since every package has a single customer it maps back to. ```py %%pydough From 7f69536bfe635ec4a2e9ac93dcaabfcef00b9ea1 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Thu, 23 Jan 2025 10:17:34 -0500 Subject: [PATCH 059/112] Updating capitalization --- documentation/dsl.md | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/documentation/dsl.md b/documentation/dsl.md index 1e7c82c1..082d6417 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -51,35 +51,35 @@ There are also the following sub-collection relationships: The simplest PyDough code is scanning an entire collection. This is done by providing the name of the collection in the metadata. However, if that name is already used as a variable, then PyDough will not know to replace the name with the corresponding PyDough object. -**Good Example #1**: obtains every record of the `People` collection. Every scalar property of `People` (`first_name`, `middle_name`, `last_name`, `ssn`, `birth_date`, `email`, `current_address_id`) is automatically included in the output. +**Good Example #1**: Obtains every record of the `People` collection. Every scalar property of `People` (`first_name`, `middle_name`, `last_name`, `ssn`, `birth_date`, `email`, `current_address_id`) is automatically included in the output. ```py %%pydough People ``` -**Good Example #2**: obtains every record of the `Addresses` collection. The `GRAPH.` prefix is optional and implied when the term is a collection name in the graph. Every scalar property of `Addresses` (`address_id`, `street_number`, `street_name`, `apartment`, `zip_code`, `city`, `state`) is automatically included in the output. +**Good Example #2**: Obtains every record of the `Addresses` collection. The `GRAPH.` prefix is optional and implied when the term is a collection name in the graph. 
Every scalar property of `Addresses` (`address_id`, `street_number`, `street_name`, `apartment`, `zip_code`, `city`, `state`) is automatically included in the output. ```py %%pydough GRAPH.Addresses ``` -**Good Example #3**: obtains every record of the `Packages` collection. Every scalar property of `Packages` (`package_id`, `customer_ssn`, `shipping_address_id`, `billing_address_id`, `order_date`, `arrival_date`, `package_cost`) is automatically included in the output. +**Good Example #3**: Obtains every record of the `Packages` collection. Every scalar property of `Packages` (`package_id`, `customer_ssn`, `shipping_address_id`, `billing_address_id`, `order_date`, `arrival_date`, `package_cost`) is automatically included in the output. ```py %%pydough Packages ``` -**Bad Example #1**: obtains every record of the `Products` collection (there is no `Products` collection). +**Bad Example #1**: Obtains every record of the `Products` collection (there is no `Products` collection). ```py %%pydough Products ``` -**Bad Example #2**: obtains every record of the `Addresses` collection (but the name `Addresses` has been reassigned to a variable). +**Bad Example #2**: Obtains every record of the `Addresses` collection (but the name `Addresses` has been reassigned to a variable). ```py %%pydough @@ -87,7 +87,7 @@ Addresses = 42 Addresses ``` -**Bad Example #3**: obtains every record of the `Addresses` collection (but the graph name `HELLO` is the wrong graph name for this example). +**Bad Example #3**: Obtains every record of the `Addresses` collection (but the graph name `HELLO` is the wrong graph name for this example). ```py %%pydough @@ -99,7 +99,7 @@ HELLO.Addresses The next step in PyDough after accessing a collection is to access its sub-collections. Using the syntax `collection.subcollection`, you can traverse into every record of `subcollection` for each record in `collection`. 
This operation may change the cardinality if records of `collection` have multiple associated records in `subcollection`. Additionally, duplicate records may appear in the output if records in `subcollection` are linked to multiple records in `collection`. -**Good Example #1**: for every person, obtains their current address. Every scalar property of `Addresses` (`address_id`, `street_number`, `street_name`, `apartment`, `zip_code`, `city`, `state`) is automatically included in the output. A record from `Addresses` can be included multiple times if multiple different `People` records have it as their current address, or it could be missing entirely if no person has it as their current address. +**Good Example #1**: For every person, obtains their current address. Every scalar property of `Addresses` (`address_id`, `street_number`, `street_name`, `apartment`, `zip_code`, `city`, `state`) is automatically included in the output. A record from `Addresses` can be included multiple times if multiple different `People` records have it as their current address, or it could be missing entirely if no person has it as their current address. ```py %%pydough @@ -113,28 +113,28 @@ People.current_addresses GRAPH.Packages.customer ``` -**Good Example #3**: for every address, obtains all packages that someone who lives at that address has ordered. Every scalar property of `Packages` (`package_id`, `customer_ssn`, `shipping_address_id`, `billing_address_id`, `order_date`, `arrival_date`, `package_cost`). Every record from `Packages` should be included at most once since every current occupant has a single address it maps back to, and every package has a single customer it maps back to. +**Good Example #3**: For every address, obtains all packages that someone who lives at that address has ordered. Every scalar property of `Packages` (`package_id`, `customer_ssn`, `shipping_address_id`, `billing_address_id`, `order_date`, `arrival_date`, `package_cost`). 
Every record from `Packages` should be included at most once since every current occupant has a single address it maps back to, and every package has a single customer it maps back to. ```py %%pydough Addresses.current_occupants.packages ``` -**Good Example #4**: for every person, obtains all packages they have ordered. Every scalar property of `Packages` (`package_id`, `customer_ssn`, `shipping_address_id`, `billing_address_id`, `order_date`, `arrival_date`, `package_cost`). Every record from `Packages` should be included at most once since every package has a single customer it maps back to. +**Good Example #4**: For every person, obtains all packages they have ordered. Every scalar property of `Packages` (`package_id`, `customer_ssn`, `shipping_address_id`, `billing_address_id`, `order_date`, `arrival_date`, `package_cost`). Every record from `Packages` should be included at most once since every package has a single customer it maps back to. ```py %%pydough People.packages ``` -**Bad Example #1**: for every address, obtains all people who used to live there. This is invalid because the `Addresses` collection does not have a `former_occupants` property. +**Bad Example #1**: For every address, obtains all people who used to live there. This is invalid because the `Addresses` collection does not have a `former_occupants` property. ```py %%pydough Addresses.former_occupants ``` -**Bad Example #2**: for every package, obtains all addresses it was shipped to. This is invalid because the `Packages` collection does not have a `shipping_addresses` property (it does have a `shipping_address` property). +**Bad Example #2**: For every package, obtains all addresses it was shipped to. This is invalid because the `Packages` collection does not have a `shipping_addresses` property (it does have a `shipping_address` property). 
```py %%pydough @@ -322,7 +322,7 @@ People( ) ``` -**Good Example #2**: for every person, find the total value of all packages they ordered in February of any year, as well as the number of all such packages, the largest value of any such package, and the percentage of those packages that were specifically on Valentine's day +**Good Example #2**: For every person, find the total value of all packages they ordered in February of any year, as well as the number of all such packages, the largest value of any such package, and the percentage of those packages that were specifically on Valentine's day ```py %%pydough From d1ffdb5600e2eb0c2216e3920c0d0b194bc4501b Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Thu, 23 Jan 2025 10:22:13 -0500 Subject: [PATCH 060/112] Apply suggestions from code review Co-authored-by: Hadia Ahmed --- documentation/dsl.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/documentation/dsl.md b/documentation/dsl.md index 03d06176..41afd585 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -144,7 +144,7 @@ Packages.shipping_addresses ### CALC -The examples so far just show selecting all properties from records of a collection. Most of the time, an analytical question will only want a subset of the properties, may want to rename them, and may want to derive new properties via calculated expressions. The way to do this with a CALC term, which is done by following a PyDough collection with parenthesis containing the expressions that should be included. +The examples so far just show selecting all properties from records of a collection. Most of the time, an analytical question will only want a subset of the properties, may want to rename them, and may want to derive new properties via calculated expressions. 
The way to do this is with a `CALC` term, which is done by following a PyDough collection with parenthesis containing the expressions that should be included. These expressions can be positional arguments or keyword arguments. Keyword arguments use the name of the keyword as the name of the output expression. Positional arguments use the name of the expression, if one exists, otherwise an arbitrary name is chosen. @@ -159,14 +159,14 @@ Once a CALC term is created, all terms of the current collection still exist eve A CALC can also be done on the graph itself to create a collection with 1 row and columns corresponding to the properties inside the CALC. This is useful when aggregating an entire collection globally instead of with regards to a parent collection. -**Good Example #1**: For every person, fetches just their first name & last name. +**Good Example #1**: For every person, fetch just their first name and last name. ```py %%pydough People(first_name, last_name) ``` -**Good Example #2**: For every package, fetches the package id, the first & last name of the person who ordered it, and the state that it was shipped to. Also includes a field named `secret_key` that is always equal to the string `"alphabet soup"`. +**Good Example #2**: For every package, fetch the package id, the first and last name of the person who ordered it, and the state that it was shipped to. Also, include a field named `secret_key` that is always equal to the string `"alphabet soup"`. ```py %%pydough @@ -440,7 +440,7 @@ GRAPH(x=BACK(1).foo) People(y=BACK(1).bar) ``` -**Bad Example #3**: The 1st ancestor of `People` is `GRAPH` which does not have an ancestor, so there can be no 2nd ancestor of `People`. +**Bad Example #3**: The 1st ancestor of `People` is `GRAPH` which does not have an ancestor. Therefore, `People` cannot have a 2nd ancestor. 
```py %%pydough @@ -680,7 +680,7 @@ People.WHERE(MONTH(packages.order_date) == 6) Another operation that can be done onto PyDough collections is sorting them. This is done by appending a collection with `.ORDER_BY(...)` which will order the collection by the collation terms between the parenthesis. The collation terms must be 1+ expressions that can be inside of a CALC term (singular expressions with regards to the current context), each decorated with information making it usable as a collation. -An expression becomes a collation expression when it is appended with `.ASC()` (indicating that the expression should be used to sort in ascending order) or `.DESC()` (indicating that the expression should be used to sort in descending order). Both `.ASC()` and `.DESC()` take in an optional argument `na_pos` indicating where to place null values. This keyword argument can be either `"first"` or `"last"`, and the default is `"first"` for `.ASC()` and `"last"` for `.DESC()`. The way the sorting works is that it orders by hte first collation term provided, and in cases of ties it moves on to the second collation term, and if there are ties in that it moves on to the third, and so on until there are no more terms to sort by, at which point the ties are broken arbitrarily. +An expression becomes a collation expression when it is appended with `.ASC()` (indicating that the expression should be used to sort in ascending order) or `.DESC()` (indicating that the expression should be used to sort in descending order). Both `.ASC()` and `.DESC()` take in an optional argument `na_pos` indicating where to place null values. This keyword argument can be either `"first"` or `"last"`, and the default is `"first"` for `.ASC()` and `"last"` for `.DESC()`. 
The way the sorting works is that it orders by the first collation term provided, and in cases of ties it moves on to the second collation term, and if there are ties in that it moves on to the third, and so on until there are no more terms to sort by, at which point the ties are broken arbitrarily. If there are multiple `ORDER_BY` terms, the last one is the one that takes precedence. The terms in the collection are unchanged by the `ORDER_BY` clause, since the only change is the order of the records. @@ -1118,7 +1118,7 @@ GRAPH.PARTITION(people_info, name="p", by=birth_year)( ) ``` -**Bad Example #9**: Partitions each address' current occupants by their birth year and filters to only include people born in years where at least 10000 people were born, then gets more information of people from those years. This is invalid because after accessing `.ppl`, the term `BACK(1).state` is used. This is not valid because even though the data that `.ppl` refers to (`people_info`) has access to `BACK(1).state`, that ancestry information was lost after partitioning `people_info`. Instead, `BACK(1)` refers to the `PARTITION` clause, which does not have a `state` field. +**Bad Example #9**: Partitions the current occupants of each address by their birth year and filters to include only those born in years with at least 10,000 births. It then gets more information about people from those years. This query is invalid because, after accessing `.ppl`, the term `BACK(1).state` is used. This is not valid because, although the data that `.ppl` refers to (`people_info`) originally had access to `BACK(1).state`, that ancestry information was lost after partitioning `people_info`. Instead, `BACK(1)` now refers to the `PARTITION` clause, which does not have a state field. 
```py %%pydough From 91e2a749caf2568268bd1a9d9bff2afe69877ffa Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Thu, 23 Jan 2025 10:24:07 -0500 Subject: [PATCH 061/112] Extra revisions --- documentation/dsl.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/documentation/dsl.md b/documentation/dsl.md index 41afd585..733b566d 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -203,7 +203,7 @@ People( ) ``` -**Good Example #5**: For every person, find the year from the most recent package they purchased, and from the first package they ever purchased. +**Good Example #5**: For every person, find the year of the most recent package they purchased and the year of their first package purchase. ```py %%pydough From 8b9db1e96e0a33f67e93ab7e96767039189ccf49 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Thu, 23 Jan 2025 10:27:00 -0500 Subject: [PATCH 062/112] More plural fixes --- documentation/dsl.md | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/documentation/dsl.md b/documentation/dsl.md index 733b566d..57965054 100644 --- a/documentation/dsl.md +++ b/documentation/dsl.md @@ -51,35 +51,35 @@ There are also the following sub-collection relationships: The simplest PyDough code is scanning an entire collection. This is done by providing the name of the collection in the metadata. However, if that name is already used as a variable, then PyDough will not know to replace the name with the corresponding PyDough object. -**Good Example #1**: Obtains every record of the `People` collection. Every scalar property of `People` (`first_name`, `middle_name`, `last_name`, `ssn`, `birth_date`, `email`, `current_address_id`) is automatically included in the output. +**Good Example #1**: Obtain every record of the `People` collection. 
Every scalar property of `People` (`first_name`, `middle_name`, `last_name`, `ssn`, `birth_date`, `email`, `current_address_id`) is automatically included in the output. ```py %%pydough People ``` -**Good Example #2**: Obtains every record of the `Addresses` collection. The `GRAPH.` prefix is optional and implied when the term is a collection name in the graph. Every scalar property of `Addresses` (`address_id`, `street_number`, `street_name`, `apartment`, `zip_code`, `city`, `state`) is automatically included in the output. +**Good Example #2**: Obtain every record of the `Addresses` collection. The `GRAPH.` prefix is optional and implied when the term is a collection name in the graph. Every scalar property of `Addresses` (`address_id`, `street_number`, `street_name`, `apartment`, `zip_code`, `city`, `state`) is automatically included in the output. ```py %%pydough GRAPH.Addresses ``` -**Good Example #3**: Obtains every record of the `Packages` collection. Every scalar property of `Packages` (`package_id`, `customer_ssn`, `shipping_address_id`, `billing_address_id`, `order_date`, `arrival_date`, `package_cost`) is automatically included in the output. +**Good Example #3**: Obtain every record of the `Packages` collection. Every scalar property of `Packages` (`package_id`, `customer_ssn`, `shipping_address_id`, `billing_address_id`, `order_date`, `arrival_date`, `package_cost`) is automatically included in the output. ```py %%pydough Packages ``` -**Bad Example #1**: Obtains every record of the `Products` collection (there is no `Products` collection). +**Bad Example #1**: Obtain every record of the `Products` collection (there is no `Products` collection). ```py %%pydough Products ``` -**Bad Example #2**: Obtains every record of the `Addresses` collection (but the name `Addresses` has been reassigned to a variable). +**Bad Example #2**: Obtain every record of the `Addresses` collection (but the name `Addresses` has been reassigned to a variable). 
```py
 %%pydough
@@ -87,7 +87,7 @@ Addresses = 42
 Addresses
 ```
 
-**Bad Example #3**: Obtains every record of the `Addresses` collection (but the graph name `HELLO` is the wrong graph name for this example).
+**Bad Example #3**: Obtain every record of the `Addresses` collection (but the graph name `HELLO` is the wrong graph name for this example).
 
 ```py
 %%pydough
@@ -589,7 +589,7 @@ packages_jan_2018 = Packages.WHERE(
 GRAPH(n_jan_2018=COUNT(selected_packages))
 ```
 
-**Good Example #6**: Counts how many people have don't have a first or last name that starts with A. [See here](functions.md#logical) for more details on the valid/invalid use of logical operations in Python.
+**Good Example #6**: Count how many people don't have a first or last name that starts with A. [See here](functions.md#logical) for more details on the valid/invalid use of logical operations in Python.
 
 ```py
 %%pydough
@@ -599,7 +599,7 @@ selected_people = People.WHERE(
 GRAPH(n_people=COUNT(selected_people))
 ```
 
-**Good Example #7**: Counts how many people have a gmail or yahoo account. [See here](functions.md#logical) for more details on the valid/invalid use of logical operations in Python.
+**Good Example #7**: Count how many people have a gmail or yahoo account. [See here](functions.md#logical) for more details on the valid/invalid use of logical operations in Python.
 
 ```py
 %%pydough
@@ -609,7 +609,7 @@ gmail_or_yahoo = People.WHERE(
 GRAPH(n_gmail_or_yahoo=COUNT(gmail_or_yahoo))
 ```
 
-**Good Example #8**: Counts how many people were born in the 1980s. [See here](functions.md#comparisons) for more details on the valid/invalid use of comparisons in Python.
+**Good Example #8**: Count how many people were born in the 1980s. [See here](functions.md#comparisons) for more details on the valid/invalid use of comparisons in Python. 
```py %%pydough @@ -977,7 +977,7 @@ PARTITION(Addresses, name="addrs", by=(city, state))( ) ``` -**Good Example #6**: Identifies the states whose current occupants account for at least 1% of all packages purchased. List the state and the percentage. +**Good Example #6**: Identify the states whose current occupants account for at least 1% of all packages purchased. List the state and the percentage. ```py %%pydough @@ -989,7 +989,7 @@ GRAPH( ).WHERE(pct_of_packages >= 1.0) ``` -**Good Example #7**: Identifies which months of the year have numbers of packages shipped in that month that are above the average for all months. +**Good Example #7**: Identify which months of the year have numbers of packages shipped in that month that are above the average for all months. ```py %%pydough @@ -1034,7 +1034,7 @@ PARTITION(people_info, name="ppl", by=(state, first_letter))( ).TOP_K(10, by=n_people.DESC()) ``` -**Bad Example #1**: Partitions a collection `Products` that does not exist in the graph. +**Bad Example #1**: Partition a collection `Products` that does not exist in the graph. ```py %%pydough @@ -1084,7 +1084,7 @@ PARTITION(Addresses.current_occupants, name="ppl", by=(BACK(1).state, first_name ).TOP_K(10, by=n_people.DESC()) ``` -**Bad Example #7**: Partitions people by their birth year to find the number of people born in each year. Invalid because the `email` property is referenced, which is not one of the properties accessible by the partition. +**Bad Example #7**: Partition people by their birth year to find the number of people born in each year. Invalid because the `email` property is referenced, which is not one of the properties accessible by the partition. ```py %%pydough @@ -1106,7 +1106,7 @@ People.PARTITION(packages(year=YEAR(order_date)), name="p", by=year)( ) ``` -**Bad Example #8**: Partitions each address' current occupants by their birth year to get the number of people per birth year. 
This is invalid because the example includes a field `BACK(2).bar` which does not exist because the first ancestor of the partition is `GRAPH`, which does not have a second ancestor. +**Bad Example #8**: Partition each address' current occupants by their birth year to get the number of people per birth year. This is invalid because the example includes a field `BACK(2).bar` which does not exist because the first ancestor of the partition is `GRAPH`, which does not have a second ancestor. ```py %%pydough @@ -1118,7 +1118,7 @@ GRAPH.PARTITION(people_info, name="p", by=birth_year)( ) ``` -**Bad Example #9**: Partitions the current occupants of each address by their birth year and filters to include only those born in years with at least 10,000 births. It then gets more information about people from those years. This query is invalid because, after accessing `.ppl`, the term `BACK(1).state` is used. This is not valid because, although the data that `.ppl` refers to (`people_info`) originally had access to `BACK(1).state`, that ancestry information was lost after partitioning `people_info`. Instead, `BACK(1)` now refers to the `PARTITION` clause, which does not have a state field. +**Bad Example #9**: Partition the current occupants of each address by their birth year and filters to include only those born in years with at least 10,000 births. It then gets more information about people from those years. This query is invalid because, after accessing `.ppl`, the term `BACK(1).state` is used. This is not valid because, although the data that `.ppl` refers to (`people_info`) originally had access to `BACK(1).state`, that ancestry information was lost after partitioning `people_info`. Instead, `BACK(1)` now refers to the `PARTITION` clause, which does not have a state field. ```py %%pydough @@ -1140,7 +1140,7 @@ GRAPH.PARTITION(people_info, name="ppl", by=birth_year).WHERE( Certain PyDough operations, such as specific filters, can cause plural data to become singular. 
In this case, PyDough will still ban the plural data from being treated as singular unless the `.SINGULAR()` modifier is used to tell PyDough that the data should be treated as singular. It is very important that this only be used if the user is certain that the data will be singular, since otherwise it can result in undefined behavior when the PyDough code is executed. -**Good Example #1**: Accesses the package cost of the most recent package ordered by each person. This is valid because even though `.packages` is plural, the filter done on it will ensure that there is only one record for each record of `People`, so `.SINGULAR()` is valid. +**Good Example #1**: Access the package cost of the most recent package ordered by each person. This is valid because even though `.packages` is plural, the filter done on it will ensure that there is only one record for each record of `People`, so `.SINGULAR()` is valid. ```py %%pydough @@ -1156,7 +1156,7 @@ People( ) ``` -**Good Example #2**: Accesses the email of the current occupant of each address that has the name `"John Smith"` (no middle name). This is valid if it is safe to assume that each address only has one current occupant named `"John Smith"` without a middle name. +**Good Example #2**: Access the email of the current occupant of each address that has the name `"John Smith"` (no middle name). This is valid if it is safe to assume that each address only has one current occupant named `"John Smith"` without a middle name. 
```py %%pydough From e8b0cdbbd7b3c3c797c94f4a2201bfe81498a523 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Thu, 23 Jan 2025 13:06:20 -0500 Subject: [PATCH 063/112] Refactor agg call handling --- pydough/conversion/hybrid_tree.py | 324 ++++++++++++++---------------- tests/test_pipeline.py | 3 - 2 files changed, 150 insertions(+), 177 deletions(-) diff --git a/pydough/conversion/hybrid_tree.py b/pydough/conversion/hybrid_tree.py index 0a9fc64e..ef158814 100644 --- a/pydough/conversion/hybrid_tree.py +++ b/pydough/conversion/hybrid_tree.py @@ -1270,108 +1270,6 @@ def populate_children( connection_type = connection_type.reconcile_connection_types(con_typ) child_idx_mapping[child_idx] = hybrid.add_child(subtree, connection_type) - def make_hybrid_agg_expr( - self, - hybrid: HybridTree, - expr: PyDoughExpressionQDAG, - child_ref_mapping: dict[int, int], - ) -> tuple[HybridExpr, int | None]: - """ - Converts a QDAG expression into a HybridExpr specifically with the - intent of making it the input to an aggregation call. Returns the - converted function argument, as well as an index indicating what child - subtree the aggregation's arguments belong to. NOTE: the HybridExpr is - phrased relative to the child subtree, rather than relative to `hybrid` - itself. - - Args: - `hybrid`: the hybrid tree that should be used to derive the - translation of `expr`, as it is the context in which the `expr` - will live. - `expr`: the QDAG expression to be converted. - `child_ref_mapping`: mapping of indices used by child references - in the original expressions to the index of the child hybrid tree - relative to the current level. - - Returns: - The HybridExpr node corresponding to `expr`, as well as the index - of the child it belongs to (e.g. which subtree does this - aggregation need to be done on top of). - """ - hybrid_result: HybridExpr - # This value starts out as None since we do not know the child index - # that `expr` correspond to yet. 
It may still be None at the end, since - # it is possible that `expr` does not correspond to any child index. - child_idx: int | None = None - match expr: - case PartitionKey(): - return self.make_hybrid_agg_expr(hybrid, expr.expr, child_ref_mapping) - case Literal(): - # Literals are kept as-is. - hybrid_result = HybridLiteralExpr(expr) - case ChildReferenceExpression(): - # Child references become regular references because the - # expression is phrased as if we were inside the child rather - # than the parent. - child_idx = child_ref_mapping[expr.child_idx] - child_connection = hybrid.children[child_idx] - expr_name = child_connection.subtree.pipeline[-1].renamings.get( - expr.term_name, expr.term_name - ) - hybrid_result = HybridRefExpr(expr_name, expr.pydough_type) - case ExpressionFunctionCall(): - if expr.operator.is_aggregation: - raise NotImplementedError( - "PyDough does not yet support calling aggregations inside of aggregations" - ) - # Every argument must be translated in the same manner as a - # regular function argument, except that the child index it - # corresponds to must be reconciled with the child index value - # accumulated so far. - args: list[HybridExpr] = [] - for arg in expr.args: - if not isinstance(arg, PyDoughExpressionQDAG): - raise NotImplementedError( - f"TODO: support converting {arg.__class__.__name__} as a function argument" - ) - hybrid_arg, hybrid_child_index = self.make_hybrid_agg_expr( - hybrid, arg, child_ref_mapping - ) - if hybrid_child_index is not None: - if child_idx is None: - # In this case, the argument is the first one seen that - # has an index, so that index is chosen. - child_idx = hybrid_child_index - elif hybrid_child_index != child_idx: - # In this case, multiple arguments correspond to - # different children, which cannot be handled yet - # because it means it is impossible to push the agg - # call into a single HybridConnection node. 
- raise NotImplementedError( - "Unsupported case: multiple child indices referenced by aggregation arguments" - ) - args.append(hybrid_arg) - hybrid_result = HybridFunctionExpr( - expr.operator, args, expr.pydough_type - ) - case BackReferenceExpression(): - raise NotImplementedError( - "PyDough does yet support aggregations whose arguments mix between subcollection data of the current context and fields of an ancestor of the current context" - ) - case Reference(): - raise NotImplementedError( - "PyDough does yet support aggregations whose arguments mix between subcollection data of the current context and fields of the context itself" - ) - case WindowCall(): - raise NotImplementedError( - "PyDough does yet support aggregations whose arguments mix between subcollection data of the current context and window functions" - ) - case _: - raise NotImplementedError( - f"TODO: support converting {expr.__class__.__name__} in aggregations" - ) - return hybrid_result, child_idx - def postprocess_agg_output( self, agg_call: HybridFunctionExpr, agg_ref: HybridExpr, joins_can_nullify: bool ) -> HybridExpr: @@ -1533,11 +1431,109 @@ def handle_has_hasnot( # has / hasnot condition is now known to be true. return HybridLiteralExpr(Literal(True, BooleanType())) + def convert_agg_arg(self, expr: HybridExpr, child_indices: set[int]) -> HybridExpr: + """ + Translates a hybrid expression that is an argument to an aggregation + (or a subexpression of such an argument) into a form that is expressed + from the perspective of the child subtree that is being aggregated. + + Args: + `expr`: the expression to be converted. + `child_indices`: a set that is mutated to contain the indices of + any children that are referenced by `expr`. + + Returns: + The translated expression. + + Raises: + NotImplementedError if `expr` is an expression that cannot be used + inside of an aggregation call. 
+ """ + match expr: + case HybridLiteralExpr(): + return expr + case HybridChildRefExpr(): + # Child references become regular references because the + # expression is phrased as if we were inside the child rather + # than the parent. + child_indices.add(expr.child_idx) + return HybridRefExpr(expr.name, expr.typ) + case HybridFunctionExpr(): + return HybridFunctionExpr( + expr.operator, + [self.convert_agg_arg(arg, child_indices) for arg in expr.args], + expr.typ, + ) + case HybridBackRefExpr(): + raise NotImplementedError( + "PyDough does yet support aggregations whose arguments mix between subcollection data of the current context and fields of an ancestor of the current context" + ) + case HybridRefExpr(): + raise NotImplementedError( + "PyDough does yet support aggregations whose arguments mix between subcollection data of the current context and fields of the context itself" + ) + case HybridWindowExpr(): + raise NotImplementedError( + "PyDough does yet support aggregations whose arguments mix between subcollection data of the current context and window functions" + ) + case _: + raise NotImplementedError( + f"TODO: support converting {expr.__class__.__name__} in aggregations" + ) + + def make_agg_call( + self, + hybrid: HybridTree, + expr: ExpressionFunctionCall, + args: list[HybridExpr], + ) -> HybridExpr: + """ + For aggregate function calls, their arguments are translated in a + manner that identifies what child subtree they correspond too, by + index, and translates them relative to the subtree. Then, the + aggregation calls are placed into the `aggs` mapping of the + corresponding child connection, and the aggregation call becomes a + child reference (referring to the aggs list), since after translation, + an aggregated child subtree only has the grouping keys and the + aggregation calls as opposed to its other terms. + + Args: + `hybrid`: the hybrid tree that should be used to derive the + translation of the aggregation call. 
+ `expr`: the aggregation function QDAG expression to be converted. + `args`: the converted arguments to the aggregation call. + """ + child_indices: set[int] = set() + converted_args: list[HybridExpr] = [ + self.convert_agg_arg(arg, child_indices) for arg in args + ] + if len(child_indices) != 1: + raise ValueError( + f"Expected aggregation call to contain references to exactly one child collection, but found {len(child_indices)} in {expr}" + ) + hybrid_call: HybridFunctionExpr = HybridFunctionExpr( + expr.operator, converted_args, expr.pydough_type + ) + # Identify the child connection that the aggregation call is pushed + # into. + child_idx: int = child_indices.pop() + child_connection = hybrid.children[child_idx] + # Generate a unique name for the agg call to push into the child + # connection. + agg_name: str = self.get_agg_name(child_connection) + child_connection.aggs[agg_name] = hybrid_call + result_ref: HybridExpr = HybridChildRefExpr( + agg_name, child_idx, expr.pydough_type + ) + joins_can_nullify: bool = not isinstance(hybrid.pipeline[0], HybridRoot) + return self.postprocess_agg_output(hybrid_call, result_ref, joins_can_nullify) + def make_hybrid_expr( self, hybrid: HybridTree, expr: PyDoughExpressionQDAG, child_ref_mapping: dict[int, int], + inside_agg: bool, ) -> HybridExpr: """ Converts a QDAG expression into a HybridExpr. @@ -1550,6 +1546,8 @@ def make_hybrid_expr( `child_ref_mapping`: mapping of indices used by child references in the original expressions to the index of the child hybrid tree relative to the current level. + `inside_agg`: True if `expr` is being derived inside of an + aggregation call, False otherwise.
Returns: The HybridExpr node corresponding to `expr` @@ -1561,7 +1559,9 @@ def make_hybrid_expr( ancestor_tree: HybridTree match expr: case PartitionKey(): - return self.make_hybrid_expr(hybrid, expr.expr, child_ref_mapping) + return self.make_hybrid_expr( + hybrid, expr.expr, child_ref_mapping, inside_agg + ) case Literal(): return HybridLiteralExpr(expr) case ColumnProperty(): @@ -1609,80 +1609,51 @@ def make_hybrid_expr( expr.term_name, expr.term_name ) return HybridRefExpr(expr_name, expr.pydough_type) - case ExpressionFunctionCall() if not expr.operator.is_aggregation: - # For non-aggregate function calls, translate their arguments - # normally and build the function call. Does not support any + case ExpressionFunctionCall(): + if expr.operator.is_aggregation and inside_agg: + raise NotImplementedError( + "PyDough does not yet support calling aggregations inside of aggregations" + ) + # Do special casing for operators that an have collection + # arguments. + # TODO: (gh #148) handle collection-level NDISTINCT + if ( + expr.operator == pydop.COUNT + and len(expr.args) == 1 + and isinstance(expr.args[0], PyDoughCollectionQDAG) + ): + return self.handle_collection_count(hybrid, expr, child_ref_mapping) + elif expr.operator in (pydop.HAS, pydop.HASNOT): + return self.handle_has_hasnot(hybrid, expr, child_ref_mapping) + elif any( + not isinstance(arg, PyDoughExpressionQDAG) for arg in expr.args + ): + raise NotImplementedError( + f"PyDough does not yet support non-expression arguments for aggregation function {expr.operator}" + ) + # For normal operators, translate their expression arguments + # normally. If it is a non-aggregation, build the function + # call. If it is an aggregation, transform accordingly. # such function that takes in a collection, as none currently # exist that are not aggregations. 
+ expr.operator.is_aggregation for arg in expr.args: if not isinstance(arg, PyDoughExpressionQDAG): raise NotImplementedError( - "PyDough does not yet support converting collections as function arguments to a non-aggregation function" + f"PyDough does not yet support non-expression arguments for function {expr.operator}" ) - args.append(self.make_hybrid_expr(hybrid, arg, child_ref_mapping)) - return HybridFunctionExpr(expr.operator, args, expr.pydough_type) - case ExpressionFunctionCall() if expr.operator.is_aggregation: - # For aggregate function calls, their arguments are translated in - # a manner that identifies what child subtree they correspond too, - # by index, and translates them relative to the subtree. Then, the - # aggregation calls are placed into the `aggs` mapping of the - # corresponding child connection, and the aggregation call becomes - # a child reference (referring to the aggs list), since after - # translation, an aggregated child subtree only has the grouping - # keys & the aggregation calls as opposed to its other terms. 
- child_idx: int | None = None - arg_child_idx: int | None = None - for arg in expr.args: - if isinstance(arg, PyDoughExpressionQDAG): - hybrid_arg, arg_child_idx = self.make_hybrid_agg_expr( - hybrid, arg, child_ref_mapping + args.append( + self.make_hybrid_expr( + hybrid, + arg, + child_ref_mapping, + inside_agg or expr.operator.is_aggregation, ) - else: - if not isinstance(arg, ChildReferenceCollection): - raise NotImplementedError("Cannot process argument") - # TODO: (gh #148) handle collection-level NDISTINCT - if expr.operator == pydop.COUNT: - return self.handle_collection_count( - hybrid, expr, child_ref_mapping - ) - elif expr.operator in (pydop.HAS, pydop.HASNOT): - return self.handle_has_hasnot( - hybrid, expr, child_ref_mapping - ) - else: - raise NotImplementedError( - f"PyDough does not yet support collection arguments for aggregation function {expr.operator}" - ) - # Accumulate the `arg_child_idx` value from the argument across - # all function arguments, ensuring that at the end there is - # exactly one child subtree that the agg call corresponds to. - if arg_child_idx is not None: - if child_idx is None: - child_idx = arg_child_idx - elif arg_child_idx != child_idx: - raise NotImplementedError( - "Unsupported case: multiple child indices referenced by aggregation arguments" - ) - args.append(hybrid_arg) - if child_idx is None: - raise NotImplementedError( - "Unsupported case: no child indices referenced by aggregation arguments" ) - hybrid_call: HybridFunctionExpr = HybridFunctionExpr( - expr.operator, args, expr.pydough_type - ) - child_connection = hybrid.children[child_idx] - # Generate a unique name for the agg call to push into the child - # connection. 
- agg_name: str = self.get_agg_name(child_connection) - child_connection.aggs[agg_name] = hybrid_call - result_ref: HybridExpr = HybridChildRefExpr( - agg_name, child_idx, expr.pydough_type - ) - joins_can_nullify: bool = not isinstance(hybrid.pipeline[0], HybridRoot) - return self.postprocess_agg_output( - hybrid_call, result_ref, joins_can_nullify - ) + if expr.operator.is_aggregation: + return self.make_agg_call(hybrid, expr, args) + else: + return HybridFunctionExpr(expr.operator, args, expr.pydough_type) case WindowCall(): partition_args: list[HybridExpr] = [] order_args: list[HybridCollation] = [] @@ -1700,7 +1671,7 @@ def make_hybrid_expr( partition_args.append(shifted_arg) for arg in expr.collation_args: hybrid_arg = self.make_hybrid_expr( - hybrid, arg.expr, child_ref_mapping + hybrid, arg.expr, child_ref_mapping, inside_agg ) order_args.append(HybridCollation(hybrid_arg, arg.asc, arg.na_last)) return HybridWindowExpr( @@ -1739,7 +1710,9 @@ def process_hybrid_collations( hybrid_orderings: list[HybridCollation] = [] for collation in collations: name = self.get_ordering_name(hybrid) - expr = self.make_hybrid_expr(hybrid, collation.expr, child_ref_mapping) + expr = self.make_hybrid_expr( + hybrid, collation.expr, child_ref_mapping, False + ) new_expressions[name] = expr new_collation: HybridCollation = HybridCollation( HybridRefExpr(name, collation.expr.pydough_type), @@ -1792,7 +1765,7 @@ def make_hybrid_tree( new_expressions: dict[str, HybridExpr] = {} for name in sorted(node.calc_terms): expr = self.make_hybrid_expr( - hybrid, node.get_expr(name), child_ref_mapping + hybrid, node.get_expr(name), child_ref_mapping, False ) new_expressions[name] = expr hybrid.pipeline.append( @@ -1806,7 +1779,9 @@ def make_hybrid_tree( case Where(): hybrid = self.make_hybrid_tree(node.preceding_context, parent) self.populate_children(hybrid, node, child_ref_mapping) - expr = self.make_hybrid_expr(hybrid, node.condition, child_ref_mapping) + expr = self.make_hybrid_expr( + 
hybrid, node.condition, child_ref_mapping, False + ) hybrid.pipeline.append(HybridFilter(hybrid.pipeline[-1], expr)) return hybrid case PartitionBy(): @@ -1819,7 +1794,7 @@ def make_hybrid_tree( for key_name in node.calc_terms: key = node.get_expr(key_name) expr = self.make_hybrid_expr( - successor_hybrid, key, child_ref_mapping + successor_hybrid, key, child_ref_mapping, False ) partition.add_key(key_name, expr) key_exprs.append(HybridRefExpr(key_name, expr.typ)) @@ -1869,6 +1844,7 @@ def make_hybrid_tree( successor_hybrid, Reference(node.child_access, key.expr.term_name), child_ref_mapping, + False, ) assert isinstance(rhs_expr, HybridRefExpr) lhs_expr: HybridExpr = HybridChildRefExpr( diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 41605af3..f7a85399 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -197,7 +197,6 @@ tpch_q5_output, ), id="tpch_q5", - marks=pytest.mark.skip("TODO: support correlated back references"), ), pytest.param( ( @@ -552,7 +551,6 @@ tpch_q21_output, ), id="tpch_q21", - marks=pytest.mark.skip("TODO: support correlated back references"), ), pytest.param( ( @@ -562,7 +560,6 @@ tpch_q22_output, ), id="tpch_q22", - marks=pytest.mark.skip("TODO: support correlated back references"), ), pytest.param( ( From a5d8c1fceebe93caa1cb5bdb2e92adc25cb989da Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Fri, 24 Jan 2025 11:31:15 -0500 Subject: [PATCH 064/112] WIP progress on correlated references --- pydough/conversion/hybrid_tree.py | 66 ++++++++++++++++++- pydough/conversion/relational_converter.py | 51 +++++++++++++- .../sqlglot_relational_expression_visitor.py | 19 +++--- pydough/sqlglot/sqlglot_relational_visitor.py | 16 ++++- 4 files changed, 136 insertions(+), 16 deletions(-) diff --git a/pydough/conversion/hybrid_tree.py b/pydough/conversion/hybrid_tree.py index ef158814..b7fe4610 100644 --- a/pydough/conversion/hybrid_tree.py +++ b/pydough/conversion/hybrid_tree.py @@ -229,6 +229,27 @@ def shift_back(self, 
levels: int) -> HybridExpr | None: return HybridBackRefExpr(self.name, self.back_idx + levels, self.typ) +class HybridCorrelExpr(HybridExpr): + """ + Class for HybridExpr terms that are expressions from a parent hybrid tree + rather than an ancestor, which requires a correlated reference. + """ + + def __init__(self, hybrid: "HybridTree", expr: HybridExpr): + super().__init__(expr.typ) + self.hybrid = hybrid + self.expr: HybridExpr = expr + + def __repr__(self): + return f"CORREL({self.expr})" + + def apply_renamings(self, renamings: dict[str, str]) -> "HybridExpr": + return self + + def shift_back(self, levels: int) -> HybridExpr | None: + return self + + class HybridLiteralExpr(HybridExpr): """ Class for HybridExpr terms that are literals. @@ -1046,6 +1067,10 @@ def __init__(self, configs: PyDoughConfigs): self.configs = configs # An index used for creating fake column names for aliases self.alias_counter: int = 0 + # A stack where each element is a hybrid tree being derived + # as a subtree of the previous element, and the current tree is + # being derived as the subtree of the last element. + self.stack: list[HybridTree] = [] @staticmethod def get_join_keys( @@ -1230,6 +1255,7 @@ def populate_children( accordingly so expressions using the child indices know what hybrid connection index to use. """ + self.stack.append(hybrid) for child_idx, child in enumerate(child_operator.children): # Build the hybrid tree for the child.
Before doing so, reset the # alias counter to 0 to ensure that identical subtrees are named @@ -1269,6 +1295,7 @@ def populate_children( for con_typ in reference_types: connection_type = connection_type.reconcile_connection_types(con_typ) child_idx_mapping[child_idx] = hybrid.add_child(subtree, connection_type) + self.stack.pop() def postprocess_agg_output( self, agg_call: HybridFunctionExpr, agg_ref: HybridExpr, joins_can_nullify: bool @@ -1591,11 +1618,44 @@ def make_hybrid_expr( # Keep stepping backward until `expr.back_levels` non-hidden # steps have been taken (to ignore steps that are part of a # compound). + collection: PyDoughCollectionQDAG = expr.collection while true_steps_back < expr.back_levels: + assert collection.ancestor_context is not None + collection = collection.ancestor_context if ancestor_tree.parent is None: - raise NotImplementedError( - "TODO: (gh #141) support BACK references that step from a child subtree back into a parent context." + if len(self.stack) == 0: + raise ValueError("Back reference steps too far back") + parent_tree = self.stack.pop() + remaining_steps_back: int = ( + expr.back_levels - true_steps_back - 1 ) + # TODO: deal with case where the ancestor is PARTITION + # if len(parent_tree.pipeline) == 1 and isinstance(parent_tree.pipeline[0], HybridPartition): + # remaining_steps_back += 1 + parent_result: HybridExpr + if remaining_steps_back == 0: + if expr.term_name not in parent_tree.pipeline[-1].terms: + raise ValueError( + f"Back reference to {expr.term_name} not found in parent" + ) + parent_name: str = parent_tree.pipeline[-1].renamings.get( + expr.term_name, expr.term_name + ) + parent_result = HybridRefExpr( + parent_name, expr.pydough_type + ) + else: + new_expr: PyDoughExpressionQDAG = BackReferenceExpression( + collection, expr.term_name, remaining_steps_back + ) + parent_result = self.make_hybrid_expr( + parent_tree, new_expr, {}, False + ) + self.stack.append(parent_tree) + return HybridCorrelExpr(parent_tree, 
parent_result) + # raise NotImplementedError( + # "TODO: (gh #141) support BACK references that step from a child subtree back into a parent context." + # ) ancestor_tree = ancestor_tree.parent back_idx += true_steps_back if not ancestor_tree.is_hidden_level: @@ -1723,7 +1783,7 @@ def process_hybrid_collations( return new_expressions, hybrid_orderings def make_hybrid_tree( - self, node: PyDoughCollectionQDAG, parent: HybridTree | None = None + self, node: PyDoughCollectionQDAG, parent: HybridTree | None ) -> HybridTree: """ Converts a collection QDAG into the HybridTree format. diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index 96ddcdc9..e82a513b 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -52,6 +52,7 @@ HybridCollectionAccess, HybridColumnExpr, HybridConnection, + HybridCorrelExpr, HybridExpr, HybridFilter, HybridFunctionExpr, @@ -78,13 +79,32 @@ class TranslationOutput: """ relation: Relational + """ + The relational tree describing the way to compute the answer for the + logic originally in the hybrid tree. + """ + expressions: dict[HybridExpr, ColumnReference] + """ + A mapping of each expression that was accessible in the hybrid tree to the + corresponding column reference in the relational tree that contains the + value of that expression. + """ + + correlated_name: str | None = None + """ + The name that can be used to refer to the relational output in correlated + references. + """ class RelTranslation: def __init__(self): # An index used for creating fake column names self.dummy_idx = 1 + # A stack of contexts used to point to ancestors for correlated + # references. 
+ self.stack: list[TranslationOutput] = [] def make_null_column(self, relation: Relational) -> ColumnReference: """ @@ -114,6 +134,24 @@ def make_null_column(self, relation: Relational) -> ColumnReference: relation.columns[name] = LiteralExpression(None, UnknownType()) return ColumnReference(name, UnknownType()) + def get_correlated_name(self, context: TranslationOutput) -> str: + """ + Finds the name used to refer to a context for correlated variable + access. If the context does not have a correlated name, a new one is + generated for it. + + Args: + `context`: the context containing the relational subtree being + referenced in a correlated variable access. + + Returns: + The name used to refer to the context in a correlated reference. + """ + if context.correlated_name is None: + context.correlated_name = f"corr{self.dummy_idx}" + self.dummy_idx += 1 + return context.correlated_name + def translate_expression( self, expr: HybridExpr, context: TranslationOutput | None ) -> RelationalExpression: @@ -168,6 +206,15 @@ def translate_expression( order_inputs, expr.kwargs, ) + case HybridCorrelExpr(): + ancestor_context: TranslationOutput = self.stack.pop() + ancestor_expr: RelationalExpression = self.translate_expression( + expr.expr, ancestor_context + ) + self.stack.append(ancestor_context) + return CorrelatedReference( + ancestor_expr, self.get_correlated_name(ancestor_context) + ) case _: raise NotImplementedError(expr.__class__.__name__) @@ -366,9 +413,11 @@ def handle_children( """ for child_idx, child in enumerate(hybrid.children): if child.required_steps == pipeline_idx: + self.stack.append(context) child_output = self.rel_translation( child, child.subtree, len(child.subtree.pipeline) - 1 ) + self.stack.pop() assert child.subtree.join_keys is not None join_keys: list[tuple[HybridExpr, HybridExpr]] = child.subtree.join_keys agg_keys: list[HybridExpr] @@ -905,7 +954,7 @@ def convert_ast_to_relational( # Convert the QDAG node to the hybrid form, then invoke the
relational # conversion procedure. The first element in the returned list is the # final rel node. - hybrid: HybridTree = HybridTranslator(configs).make_hybrid_tree(node) + hybrid: HybridTree = HybridTranslator(configs).make_hybrid_tree(node, None) renamings: dict[str, str] = hybrid.pipeline[-1].renamings output: TranslationOutput = translator.rel_translation( None, hybrid, len(hybrid.pipeline) - 1 diff --git a/pydough/sqlglot/sqlglot_relational_expression_visitor.py b/pydough/sqlglot/sqlglot_relational_expression_visitor.py index 4ca111ba..589f2220 100644 --- a/pydough/sqlglot/sqlglot_relational_expression_visitor.py +++ b/pydough/sqlglot/sqlglot_relational_expression_visitor.py @@ -7,8 +7,8 @@ import sqlglot.expressions as sqlglot_expressions from sqlglot.dialects import Dialect as SQLGlotDialect +from sqlglot.expressions import Column, Identifier from sqlglot.expressions import Expression as SQLGlotExpression -from sqlglot.expressions import Identifier from pydough.relational import ( CallExpression, @@ -134,11 +134,11 @@ def visit_literal_expression(self, literal_expression: LiteralExpression) -> Non self._stack.append(literal) @staticmethod - def generate_column_reference_identifier( + def make_sqlglot_column( column_reference: ColumnReference, - ) -> Identifier: + ) -> Column: """ - Generate an identifier for a column reference. This is split into a + Convert a column reference to a SQLGlot column. This is split into a separate static method to ensure consistency across multiple visitors. Args: @@ -146,16 +146,15 @@ def generate_column_reference_identifier( an identifier for. Returns: - Identifier: The output identifier. + Identifier: The output column reference containing an identifier. 
""" + result: SQLGlotExpression = Column(this=Identifier(this=column_reference.name)) if column_reference.input_name is not None: - full_name = f"{column_reference.input_name}.{column_reference.name}" - else: - full_name = column_reference.name - return Identifier(this=full_name) + result.set("table", Identifier(this=column_reference.input_name)) + return result def visit_column_reference(self, column_reference: ColumnReference) -> None: - self._stack.append(self.generate_column_reference_identifier(column_reference)) + self._stack.append(self.make_sqlglot_column(column_reference)) def relational_to_sqlglot( self, expr: RelationalExpression, output_name: str | None = None diff --git a/pydough/sqlglot/sqlglot_relational_visitor.py b/pydough/sqlglot/sqlglot_relational_visitor.py index 5f1e17ae..c3c3c196 100644 --- a/pydough/sqlglot/sqlglot_relational_visitor.py +++ b/pydough/sqlglot/sqlglot_relational_visitor.py @@ -9,6 +9,7 @@ from sqlglot.dialects import Dialect as SQLGlotDialect from sqlglot.expressions import Alias as SQLGlotAlias +from sqlglot.expressions import Column as SQLGlotColumn from sqlglot.expressions import Expression as SQLGlotExpression from sqlglot.expressions import Identifier, Select, Subquery, values from sqlglot.expressions import Literal as SQLGlotLiteral @@ -94,7 +95,7 @@ def _is_mergeable_column(expr: SQLGlotExpression) -> bool: if isinstance(expr, SQLGlotAlias): return SQLGlotRelationalVisitor._is_mergeable_column(expr.this) else: - return isinstance(expr, (SQLGlotLiteral, Identifier)) + return isinstance(expr, (SQLGlotLiteral, Identifier, SQLGlotColumn)) @staticmethod def _try_merge_columns( @@ -154,11 +155,22 @@ def _try_merge_columns( # If the new column is a literal, we can just add it to the old # columns. 
modified_old_columns.append(set_glot_alias(new_column, new_name)) - else: + elif isinstance(new_column, Identifier): expr = set_glot_alias(old_column_map[new_column.this], new_name) modified_old_columns.append(expr) if isinstance(expr, Identifier): seen_cols.add(expr) + elif isinstance(new_column, SQLGlotColumn): + expr = set_glot_alias( + old_column_map[new_column.this.this], new_name + ) + modified_old_columns.append(expr) + if isinstance(expr, Identifier): + seen_cols.add(expr) + else: + raise ValueError( + f"Unsupported expression type for column merging: {new_column.__class__.__name__}" + ) # Check that there are no missing dependencies in the old columns. if old_column_deps - seen_cols: return new_columns, old_columns From 836adca4b694157892a94c2f7201fcaa1ffc44b5 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 27 Jan 2025 10:15:38 -0500 Subject: [PATCH 065/112] Initially working implementaiton of relational handling, still need to deal with SQL --- pydough/conversion/relational_converter.py | 7 ++- pydough/relational/__init__.py | 2 + .../relational_expressions/__init__.py | 2 + .../column_reference_finder.py | 3 + .../column_reference_input_name_modifier.py | 3 + .../column_reference_input_name_remover.py | 3 + .../correlated_reference.py | 61 +++++++++++++++++++ .../relational_expression_shuttle.py | 11 ++++ .../relational_expression_visitor.py | 9 +++ .../relational_nodes/column_pruner.py | 25 ++++++-- pydough/relational/relational_nodes/join.py | 18 +++++- .../sqlglot_relational_expression_visitor.py | 8 +++ tests/test_pipeline.py | 42 +++++++++++++ 13 files changed, 186 insertions(+), 8 deletions(-) create mode 100644 pydough/relational/relational_expressions/correlated_reference.py diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index e82a513b..e743dde3 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -27,6 +27,7 @@ CallExpression, 
ColumnPruner, ColumnReference, + CorrelatedReference, EmptySingleton, ExpressionSortInfo, Filter, @@ -211,9 +212,12 @@ def translate_expression( ancestor_expr: RelationalExpression = self.translate_expression( expr.expr, ancestor_context ) + assert isinstance(ancestor_expr, ColumnReference) self.stack.append(ancestor_context) return CorrelatedReference( - ancestor_expr, self.get_correlated_name(ancestor_context) + ancestor_expr.name, + self.get_correlated_name(ancestor_context), + expr.typ, ) case _: raise NotImplementedError(expr.__class__.__name__) @@ -273,6 +277,7 @@ def join_outputs( [LiteralExpression(True, BooleanType())], [join_type], join_columns, + correl_name=self.get_correlated_name(lhs_result), ) input_aliases: list[str | None] = out_rel.default_input_aliases diff --git a/pydough/relational/__init__.py b/pydough/relational/__init__.py index b79fe424..9c54dcd4 100644 --- a/pydough/relational/__init__.py +++ b/pydough/relational/__init__.py @@ -6,6 +6,7 @@ "ColumnReferenceFinder", "ColumnReferenceInputNameModifier", "ColumnReferenceInputNameRemover", + "CorrelatedReference", "EmptySingleton", "ExpressionSortInfo", "Filter", @@ -30,6 +31,7 @@ ColumnReferenceFinder, ColumnReferenceInputNameModifier, ColumnReferenceInputNameRemover, + CorrelatedReference, ExpressionSortInfo, LiteralExpression, RelationalExpression, diff --git a/pydough/relational/relational_expressions/__init__.py b/pydough/relational/relational_expressions/__init__.py index 68838524..e8600faf 100644 --- a/pydough/relational/relational_expressions/__init__.py +++ b/pydough/relational/relational_expressions/__init__.py @@ -9,6 +9,7 @@ "ColumnReferenceFinder", "ColumnReferenceInputNameModifier", "ColumnReferenceInputNameRemover", + "CorrelatedReference", "ExpressionSortInfo", "LiteralExpression", "RelationalExpression", @@ -21,6 +22,7 @@ from .column_reference_finder import ColumnReferenceFinder from .column_reference_input_name_modifier import ColumnReferenceInputNameModifier from 
.column_reference_input_name_remover import ColumnReferenceInputNameRemover +from .correlated_reference import CorrelatedReference from .expression_sort_info import ExpressionSortInfo from .literal_expression import LiteralExpression from .relational_expression_visitor import RelationalExpressionVisitor diff --git a/pydough/relational/relational_expressions/column_reference_finder.py b/pydough/relational/relational_expressions/column_reference_finder.py index 7de5bc88..d8c7ba61 100644 --- a/pydough/relational/relational_expressions/column_reference_finder.py +++ b/pydough/relational/relational_expressions/column_reference_finder.py @@ -42,3 +42,6 @@ def visit_literal_expression(self, literal_expression: LiteralExpression) -> Non def visit_column_reference(self, column_reference: ColumnReference) -> None: self._column_references.add(column_reference) + + def visit_correlated_reference(self, correlated_reference) -> None: + pass diff --git a/pydough/relational/relational_expressions/column_reference_input_name_modifier.py b/pydough/relational/relational_expressions/column_reference_input_name_modifier.py index 0a750462..fd738ad7 100644 --- a/pydough/relational/relational_expressions/column_reference_input_name_modifier.py +++ b/pydough/relational/relational_expressions/column_reference_input_name_modifier.py @@ -45,3 +45,6 @@ def visit_column_reference(self, column_reference) -> RelationalExpression: raise ValueError( f"Input name {column_reference.input_name} not found in the input name map." 
) + + def visit_correlated_reference(self, correlated_reference) -> RelationalExpression: + return correlated_reference diff --git a/pydough/relational/relational_expressions/column_reference_input_name_remover.py b/pydough/relational/relational_expressions/column_reference_input_name_remover.py index 26633764..0de2e1d9 100644 --- a/pydough/relational/relational_expressions/column_reference_input_name_remover.py +++ b/pydough/relational/relational_expressions/column_reference_input_name_remover.py @@ -37,3 +37,6 @@ def visit_column_reference(self, column_reference) -> RelationalExpression: column_reference.data_type, None, ) + + def visit_correlated_reference(self, correlated_reference) -> RelationalExpression: + return correlated_reference diff --git a/pydough/relational/relational_expressions/correlated_reference.py b/pydough/relational/relational_expressions/correlated_reference.py new file mode 100644 index 00000000..0746a5bc --- /dev/null +++ b/pydough/relational/relational_expressions/correlated_reference.py @@ -0,0 +1,61 @@ +""" +TODO +""" + +__all__ = ["CorrelatedReference"] + +from pydough.types import PyDoughType + +from .abstract_expression import RelationalExpression +from .relational_expression_shuttle import RelationalExpressionShuttle +from .relational_expression_visitor import RelationalExpressionVisitor + + +class CorrelatedReference(RelationalExpression): + """ + TODO + """ + + def __init__(self, name: str, correl_name: str, data_type: PyDoughType) -> None: + super().__init__(data_type) + self._name: str = name + self._correl_name: str = correl_name + + def __hash__(self) -> int: + return hash((self.name, self.correl_name, self.data_type)) + + @property + def name(self) -> str: + """ + The name of the column. + """ + return self._name + + @property + def correl_name(self) -> str: + """ + The name of the correlation that the reference points to. 
+ """ + return self._correl_name + + def to_string(self, compact: bool = False) -> str: + if compact: + return f"{self.correl_name}.{self.name}" + else: + return f"CorrelatedReference(name={self.name}, correl_name={self.correl_name}, type={self.data_type})" + + def equals(self, other: object) -> bool: + return ( + isinstance(other, CorrelatedReference) + and (self.name == other.name) + and (self.correl_name == other.correl_name) + and super().equals(other) + ) + + def accept(self, visitor: RelationalExpressionVisitor) -> None: + visitor.visit_correlated_reference(self) + + def accept_shuttle( + self, shuttle: RelationalExpressionShuttle + ) -> RelationalExpression: + return shuttle.visit_correlated_reference(self) diff --git a/pydough/relational/relational_expressions/relational_expression_shuttle.py b/pydough/relational/relational_expressions/relational_expression_shuttle.py index 354a4b8e..badc0b1d 100644 --- a/pydough/relational/relational_expressions/relational_expression_shuttle.py +++ b/pydough/relational/relational_expressions/relational_expression_shuttle.py @@ -75,3 +75,14 @@ def visit_column_reference(self, column_reference): Returns: RelationalExpression: The new node resulting from visiting this node. """ + + @abstractmethod + def visit_correlated_reference(self, correlated_reference): + """ + Visit a CorrelatedReference node. + + Args: + correlated_reference (CorrelatedReference): The correlated reference node to visit. + Returns: + RelationalExpression: The new node resulting from visiting this node. 
+ """ diff --git a/pydough/relational/relational_expressions/relational_expression_visitor.py b/pydough/relational/relational_expressions/relational_expression_visitor.py index 39746b1a..0873662a 100644 --- a/pydough/relational/relational_expressions/relational_expression_visitor.py +++ b/pydough/relational/relational_expressions/relational_expression_visitor.py @@ -61,3 +61,12 @@ def visit_column_reference(self, column_reference) -> None: Args: column_reference (ColumnReference): The column reference node to visit. """ + + @abstractmethod + def visit_correlated_reference(self, correlated_reference) -> None: + """ + Visit a CorrelatedReference node. + + Args: + correlated_reference (CorrelatedReference): The correlated reference node to visit. + """ diff --git a/pydough/relational/relational_nodes/column_pruner.py b/pydough/relational/relational_nodes/column_pruner.py index 8eee34c5..d99a265e 100644 --- a/pydough/relational/relational_nodes/column_pruner.py +++ b/pydough/relational/relational_nodes/column_pruner.py @@ -5,10 +5,12 @@ from pydough.relational.relational_expressions import ( ColumnReference, ColumnReferenceFinder, + CorrelatedReference, ) from .abstract_node import Relational from .aggregate import Aggregate +from .join import Join from .project import Project from .relational_expression_dispatcher import RelationalExpressionDispatcher from .relational_root import RelationalRoot @@ -43,7 +45,7 @@ def _prune_identity_project(self, node: Relational) -> Relational: def _prune_node_columns( self, node: Relational, kept_columns: set[str] - ) -> Relational: + ) -> tuple[Relational, set[CorrelatedReference]]: """ Prune the columns for a subtree starting at this node. @@ -92,15 +94,28 @@ def _prune_node_columns( new_inputs: list[Relational] = [] # Note: The ColumnPruner should only be run when all input names are # still present in the columns. 
- for i, default_input_name in enumerate(new_node.default_input_aliases): + # Iterate over the inputs in reverse order so that the source of + # correlated data is pruned last. + correl_refs: set[CorrelatedReference] = set() + for i, default_input_name in reversed( + list(enumerate(new_node.default_input_aliases)) + ): s: set[str] = set() + input_node: Relational = node.inputs[i] for identifier in found_identifiers: if identifier.input_name == default_input_name: s.add(identifier.name) - new_inputs.append(self._prune_node_columns(node.inputs[i], s)) + if isinstance(input_node, Join) and i == 0: + for correl_ref in correl_refs: + if correl_ref.correl_name == input_node.correl_name: + s.add(correl_ref.name) + new_input_node, new_correl_refs = self._prune_node_columns(input_node, s) + correl_refs.update(new_correl_refs) + new_inputs.append(new_input_node) + new_inputs.reverse() # Determine the new node. output = new_node.copy(inputs=new_inputs) - return self._prune_identity_project(output) + return self._prune_identity_project(output), correl_refs def prune_unused_columns(self, root: RelationalRoot) -> RelationalRoot: """ @@ -112,6 +127,6 @@ def prune_unused_columns(self, root: RelationalRoot) -> RelationalRoot: Returns: RelationalRoot: The root after updating all inputs. """ - new_root: Relational = self._prune_node_columns(root, set(root.columns.keys())) + new_root, _ = self._prune_node_columns(root, set(root.columns.keys())) assert isinstance(new_root, RelationalRoot), "Expected a root node." 
return new_root diff --git a/pydough/relational/relational_nodes/join.py b/pydough/relational/relational_nodes/join.py index 3114be15..1568e6ba 100644 --- a/pydough/relational/relational_nodes/join.py +++ b/pydough/relational/relational_nodes/join.py @@ -49,6 +49,7 @@ def __init__( conditions: list[RelationalExpression], join_types: list[JoinType], columns: MutableMapping[str, RelationalExpression], + correl_name: str | None = None, ) -> None: super().__init__(columns) num_inputs = len(inputs) @@ -65,6 +66,15 @@ def __init__( ), "Join condition must be a boolean type" self._conditions: list[RelationalExpression] = conditions self._join_types: list[JoinType] = join_types + self._correl_name: str | None = correl_name + + @property + def correl_name(self) -> str | None: + """ + The name used to refer to the first join input when subsequent inputs + have correlated references. + """ + return self._correl_name @property def conditions(self) -> list[RelationalExpression]: @@ -101,6 +111,7 @@ def node_equals(self, other: Relational) -> bool: isinstance(other, Join) and self.conditions == other.conditions and self.join_types == other.join_types + and self.correl_name == other.correl_name and all( self.inputs[i].node_equals(other.inputs[i]) for i in range(len(self.inputs)) @@ -109,7 +120,10 @@ def node_equals(self, other: Relational) -> bool: def to_string(self, compact: bool = False) -> str: conditions: list[str] = [cond.to_string(compact) for cond in self.conditions] - return f"JOIN(conditions=[{', '.join(conditions)}], types={[t.value for t in self.join_types]}, columns={self.make_column_string(self.columns, compact)})" + correl_suffix = ( + "" if self.correl_name is None else f", correl_name={self.correl_name!r}" + ) + return f"JOIN(conditions=[{', '.join(conditions)}], types={[t.value for t in self.join_types]}, columns={self.make_column_string(self.columns, compact)}{correl_suffix})" def accept(self, visitor: RelationalVisitor) -> None: visitor.visit_join(self) @@ 
-119,4 +133,4 @@ def node_copy( columns: MutableMapping[str, RelationalExpression], inputs: MutableSequence[Relational], ) -> Relational: - return Join(inputs, self.conditions, self.join_types, columns) + return Join(inputs, self.conditions, self.join_types, columns, self.correl_name) diff --git a/pydough/sqlglot/sqlglot_relational_expression_visitor.py b/pydough/sqlglot/sqlglot_relational_expression_visitor.py index 589f2220..24e0ea81 100644 --- a/pydough/sqlglot/sqlglot_relational_expression_visitor.py +++ b/pydough/sqlglot/sqlglot_relational_expression_visitor.py @@ -13,6 +13,7 @@ from pydough.relational import ( CallExpression, ColumnReference, + CorrelatedReference, LiteralExpression, RelationalExpression, RelationalExpressionVisitor, @@ -133,6 +134,13 @@ def visit_literal_expression(self, literal_expression: LiteralExpression) -> Non ) self._stack.append(literal) + def visit_correlated_reference( + self, correlated_reference: CorrelatedReference + ) -> None: + raise NotImplementedError( + "TODO: support SQL conversion for correlated references" + ) + @staticmethod def make_sqlglot_column( column_reference: ColumnReference, diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index f7a85399..04363519 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -193,6 +193,27 @@ ( impl_tpch_q5, """ +ROOT(columns=[('N_NAME', N_NAME), ('REVENUE', REVENUE)], orderings=[(ordering_1):desc_last]) + PROJECT(columns={'N_NAME': N_NAME, 'REVENUE': REVENUE, 'ordering_1': REVENUE}) + PROJECT(columns={'N_NAME': name, 'REVENUE': DEFAULT_TO(agg_0, 0:int64)}) + JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr15') + FILTER(condition=name_4 == 'ASIA':string, columns={'key': key, 'name': name}) + JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'key': t0.key, 'name': t0.name, 'name_4': t1.name}, correl_name='corr1') + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 
'name': n_name, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': SUM(value)}) + PROJECT(columns={'nation_key': nation_key, 'value': extended_price * 1:int64 - discount}) + FILTER(condition=name_13 == corr15.name, columns={'discount': discount, 'extended_price': extended_price, 'nation_key': nation_key}) + JOIN(conditions=[t0.supplier_key == t1.key], types=['left'], columns={'discount': t0.discount, 'extended_price': t0.extended_price, 'name_13': t1.name_13, 'nation_key': t0.nation_key}, correl_name='corr14') + JOIN(conditions=[t0.key_7 == t1.order_key], types=['inner'], columns={'discount': t1.discount, 'extended_price': t1.extended_price, 'nation_key': t0.nation_key, 'supplier_key': t1.supplier_key}, correl_name='corr8') + FILTER(condition=order_date >= datetime.date(1994, 1, 1):date & order_date < datetime.date(1995, 1, 1):date, columns={'key_7': key_7, 'nation_key': nation_key}) + JOIN(conditions=[t0.key == t1.customer_key], types=['inner'], columns={'key_7': t1.key, 'nation_key': t0.nation_key, 'order_date': t1.order_date}, correl_name='corr5') + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate}) + SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'order_key': l_orderkey, 'supplier_key': l_suppkey}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_13': t1.name}, correl_name='corr10') + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) """, tpch_q5_output, ), @@ -547,6 +568,27 @@ ( impl_tpch_q21, """ +ROOT(columns=[('S_NAME', S_NAME), ('NUMWAIT', NUMWAIT)], orderings=[(ordering_1):desc_last, (ordering_2):asc_first]) + 
LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'NUMWAIT': NUMWAIT, 'S_NAME': S_NAME, 'ordering_1': ordering_1, 'ordering_2': ordering_2}, orderings=[(ordering_1):desc_last, (ordering_2):asc_first]) + PROJECT(columns={'NUMWAIT': NUMWAIT, 'S_NAME': S_NAME, 'ordering_1': NUMWAIT, 'ordering_2': S_NAME}) + PROJECT(columns={'NUMWAIT': DEFAULT_TO(agg_0, 0:int64), 'S_NAME': name}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr9') + FILTER(condition=name_4 == 'SAUDI ARABIA':string, columns={'key': key, 'name': name}) + JOIN(conditions=[t0.nation_key == t1.key], types=['left'], columns={'key': t0.key, 'name': t0.name, 'name_4': t1.name}, correl_name='corr1') + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=order_status == 'F':string & True:bool & True:bool, columns={'supplier_key': supplier_key}) + JOIN(conditions=[t0.key == t1.order_key], types=['anti'], columns={'order_status': t0.order_status, 'supplier_key': t0.supplier_key}, correl_name='corr8') + JOIN(conditions=[t0.key == t1.order_key], types=['semi'], columns={'key': t0.key, 'order_status': t0.order_status, 'supplier_key': t0.supplier_key}, correl_name='corr7') + JOIN(conditions=[t0.order_key == t1.key], types=['inner'], columns={'key': t1.key, 'order_status': t1.order_status, 'supplier_key': t0.supplier_key}, correl_name='corr5') + FILTER(condition=receipt_date > commit_date, columns={'order_key': order_key, 'supplier_key': supplier_key}) + SCAN(table=tpch.LINEITEM, columns={'commit_date': l_commitdate, 'order_key': l_orderkey, 'receipt_date': l_receiptdate, 'supplier_key': l_suppkey}) + SCAN(table=tpch.ORDERS, columns={'key': o_orderkey, 'order_status': o_orderstatus}) + FILTER(condition=supplier_key != 
corr7.supplier_key, columns={'order_key': order_key}) + SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'supplier_key': l_suppkey}) + FILTER(condition=supplier_key != corr8.supplier_key & receipt_date > commit_date, columns={'order_key': order_key}) + SCAN(table=tpch.LINEITEM, columns={'commit_date': l_commitdate, 'order_key': l_orderkey, 'receipt_date': l_receiptdate, 'supplier_key': l_suppkey}) """, tpch_q21_output, ), From 49f46d8c52f2ad05fd47f2b29a780f48bdd84898 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 27 Jan 2025 10:22:27 -0500 Subject: [PATCH 066/112] Rolling back SQLGlot changes and ensuring correl names are only used when needed --- pydough/conversion/relational_converter.py | 2 +- .../sqlglot_relational_expression_visitor.py | 38 ++++++++++++++----- tests/test_pipeline.py | 34 ++++++++--------- 3 files changed, 47 insertions(+), 27 deletions(-) diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index e743dde3..766a6252 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -277,7 +277,7 @@ def join_outputs( [LiteralExpression(True, BooleanType())], [join_type], join_columns, - correl_name=self.get_correlated_name(lhs_result), + correl_name=lhs_result.correlated_name, ) input_aliases: list[str | None] = out_rel.default_input_aliases diff --git a/pydough/sqlglot/sqlglot_relational_expression_visitor.py b/pydough/sqlglot/sqlglot_relational_expression_visitor.py index 24e0ea81..cdff2a13 100644 --- a/pydough/sqlglot/sqlglot_relational_expression_visitor.py +++ b/pydough/sqlglot/sqlglot_relational_expression_visitor.py @@ -7,8 +7,8 @@ import sqlglot.expressions as sqlglot_expressions from sqlglot.dialects import Dialect as SQLGlotDialect -from sqlglot.expressions import Column, Identifier from sqlglot.expressions import Expression as SQLGlotExpression +from sqlglot.expressions import Identifier from pydough.relational import ( 
CallExpression, @@ -141,25 +141,45 @@ def visit_correlated_reference( "TODO: support SQL conversion for correlated references" ) + # TODO: implement the column-based version of make_sqlglot_column, with table sources + # @staticmethod + # def make_sqlglot_column( + # column_reference: ColumnReference, + # ) -> Column: + # """ + # Convert a column reference to a SQLGlot column. This is split into a + # separate static method to ensure consistency across multiple visitors. + + # Args: + # column_reference (ColumnReference): The column reference to generate + # an identifier for. + + # Returns: + # Identifier: The output column reference containing an identifier. + # """ + # result: SQLGlotExpression = Column(this=Identifier(this=column_reference.name)) + # if column_reference.input_name is not None: + # result.set("table", Identifier(this=column_reference.input_name)) + # return result + @staticmethod def make_sqlglot_column( column_reference: ColumnReference, - ) -> Column: + ) -> Identifier: """ - Convert a column reference to a SQLGlot column. This is split into a + Generate an identifier for a column reference. This is split into a separate static method to ensure consistency across multiple visitors. - Args: column_reference (ColumnReference): The column reference to generate an identifier for. - Returns: - Identifier: The output column reference containing an identifier. + Identifier: The output identifier. 
""" - result: SQLGlotExpression = Column(this=Identifier(this=column_reference.name)) if column_reference.input_name is not None: - result.set("table", Identifier(this=column_reference.input_name)) - return result + full_name = f"{column_reference.input_name}.{column_reference.name}" + else: + full_name = column_reference.name + return Identifier(this=full_name) def visit_column_reference(self, column_reference: ColumnReference) -> None: self._stack.append(self.make_sqlglot_column(column_reference)) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 04363519..6b526a0e 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -196,22 +196,22 @@ ROOT(columns=[('N_NAME', N_NAME), ('REVENUE', REVENUE)], orderings=[(ordering_1):desc_last]) PROJECT(columns={'N_NAME': N_NAME, 'REVENUE': REVENUE, 'ordering_1': REVENUE}) PROJECT(columns={'N_NAME': name, 'REVENUE': DEFAULT_TO(agg_0, 0:int64)}) - JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr15') - FILTER(condition=name_4 == 'ASIA':string, columns={'key': key, 'name': name}) - JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'key': t0.key, 'name': t0.name, 'name_4': t1.name}, correl_name='corr1') + JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr10') + FILTER(condition=name_3 == 'ASIA':string, columns={'key': key, 'name': name}) + JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name}) SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': SUM(value)}) PROJECT(columns={'nation_key': nation_key, 'value': extended_price * 1:int64 - discount}) - FILTER(condition=name_13 == corr15.name, 
columns={'discount': discount, 'extended_price': extended_price, 'nation_key': nation_key}) - JOIN(conditions=[t0.supplier_key == t1.key], types=['left'], columns={'discount': t0.discount, 'extended_price': t0.extended_price, 'name_13': t1.name_13, 'nation_key': t0.nation_key}, correl_name='corr14') - JOIN(conditions=[t0.key_7 == t1.order_key], types=['inner'], columns={'discount': t1.discount, 'extended_price': t1.extended_price, 'nation_key': t0.nation_key, 'supplier_key': t1.supplier_key}, correl_name='corr8') - FILTER(condition=order_date >= datetime.date(1994, 1, 1):date & order_date < datetime.date(1995, 1, 1):date, columns={'key_7': key_7, 'nation_key': nation_key}) - JOIN(conditions=[t0.key == t1.customer_key], types=['inner'], columns={'key_7': t1.key, 'nation_key': t0.nation_key, 'order_date': t1.order_date}, correl_name='corr5') + FILTER(condition=name_9 == corr10.name, columns={'discount': discount, 'extended_price': extended_price, 'nation_key': nation_key}) + JOIN(conditions=[t0.supplier_key == t1.key], types=['left'], columns={'discount': t0.discount, 'extended_price': t0.extended_price, 'name_9': t1.name_9, 'nation_key': t0.nation_key}) + JOIN(conditions=[t0.key_5 == t1.order_key], types=['inner'], columns={'discount': t1.discount, 'extended_price': t1.extended_price, 'nation_key': t0.nation_key, 'supplier_key': t1.supplier_key}) + FILTER(condition=order_date >= datetime.date(1994, 1, 1):date & order_date < datetime.date(1995, 1, 1):date, columns={'key_5': key_5, 'nation_key': nation_key}) + JOIN(conditions=[t0.key == t1.customer_key], types=['inner'], columns={'key_5': t1.key, 'nation_key': t0.nation_key, 'order_date': t1.order_date}) SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate}) SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'order_key': l_orderkey, 
'supplier_key': l_suppkey}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_13': t1.name}, correl_name='corr10') + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_9': t1.name}) SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) """, @@ -572,22 +572,22 @@ LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'NUMWAIT': NUMWAIT, 'S_NAME': S_NAME, 'ordering_1': ordering_1, 'ordering_2': ordering_2}, orderings=[(ordering_1):desc_last, (ordering_2):asc_first]) PROJECT(columns={'NUMWAIT': NUMWAIT, 'S_NAME': S_NAME, 'ordering_1': NUMWAIT, 'ordering_2': S_NAME}) PROJECT(columns={'NUMWAIT': DEFAULT_TO(agg_0, 0:int64), 'S_NAME': name}) - JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr9') - FILTER(condition=name_4 == 'SAUDI ARABIA':string, columns={'key': key, 'name': name}) - JOIN(conditions=[t0.nation_key == t1.key], types=['left'], columns={'key': t0.key, 'name': t0.name, 'name_4': t1.name}, correl_name='corr1') + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}) + FILTER(condition=name_3 == 'SAUDI ARABIA':string, columns={'key': key, 'name': name}) + JOIN(conditions=[t0.nation_key == t1.key], types=['left'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name}) SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name, 'nation_key': s_nationkey}) SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': COUNT()}) FILTER(condition=order_status == 'F':string & True:bool & True:bool, columns={'supplier_key': supplier_key}) - JOIN(conditions=[t0.key == t1.order_key], types=['anti'], columns={'order_status': t0.order_status, 'supplier_key': t0.supplier_key}, 
correl_name='corr8') - JOIN(conditions=[t0.key == t1.order_key], types=['semi'], columns={'key': t0.key, 'order_status': t0.order_status, 'supplier_key': t0.supplier_key}, correl_name='corr7') - JOIN(conditions=[t0.order_key == t1.key], types=['inner'], columns={'key': t1.key, 'order_status': t1.order_status, 'supplier_key': t0.supplier_key}, correl_name='corr5') + JOIN(conditions=[t0.key == t1.order_key], types=['anti'], columns={'order_status': t0.order_status, 'supplier_key': t0.supplier_key}, correl_name='corr6') + JOIN(conditions=[t0.key == t1.order_key], types=['semi'], columns={'key': t0.key, 'order_status': t0.order_status, 'supplier_key': t0.supplier_key}, correl_name='corr5') + JOIN(conditions=[t0.order_key == t1.key], types=['inner'], columns={'key': t1.key, 'order_status': t1.order_status, 'supplier_key': t0.supplier_key}) FILTER(condition=receipt_date > commit_date, columns={'order_key': order_key, 'supplier_key': supplier_key}) SCAN(table=tpch.LINEITEM, columns={'commit_date': l_commitdate, 'order_key': l_orderkey, 'receipt_date': l_receiptdate, 'supplier_key': l_suppkey}) SCAN(table=tpch.ORDERS, columns={'key': o_orderkey, 'order_status': o_orderstatus}) - FILTER(condition=supplier_key != corr7.supplier_key, columns={'order_key': order_key}) + FILTER(condition=supplier_key != corr5.supplier_key, columns={'order_key': order_key}) SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'supplier_key': l_suppkey}) - FILTER(condition=supplier_key != corr8.supplier_key & receipt_date > commit_date, columns={'order_key': order_key}) + FILTER(condition=supplier_key != corr6.supplier_key & receipt_date > commit_date, columns={'order_key': order_key}) SCAN(table=tpch.LINEITEM, columns={'commit_date': l_commitdate, 'order_key': l_orderkey, 'receipt_date': l_receiptdate, 'supplier_key': l_suppkey}) """, tpch_q21_output, From 81d130eae882f87228344b5d6d164fe08d6b9f30 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 27 Jan 2025 12:28:15 -0500 Subject: 
[PATCH 067/112] Added SQLGlot support for correlated references; working only with exists/not-exists because of sqlite correlated query constraints --- pydough/conversion/hybrid_tree.py | 97 ++++++++++++------- .../relational_expressions/__init__.py | 2 + .../correlated_reference_finder.py | 48 +++++++++ .../relational_nodes/column_pruner.py | 42 ++++++-- .../sqlglot_relational_expression_visitor.py | 11 ++- pydough/sqlglot/sqlglot_relational_visitor.py | 25 +++-- pydough/unqualified/qualification.py | 2 +- tests/test_pipeline.py | 18 ++++ tests/tpch_outputs.py | 2 +- 9 files changed, 190 insertions(+), 57 deletions(-) create mode 100644 pydough/relational/relational_expressions/correlated_reference_finder.py diff --git a/pydough/conversion/hybrid_tree.py b/pydough/conversion/hybrid_tree.py index b7fe4610..732fbb1c 100644 --- a/pydough/conversion/hybrid_tree.py +++ b/pydough/conversion/hybrid_tree.py @@ -1555,6 +1555,62 @@ def make_agg_call( joins_can_nullify: bool = not isinstance(hybrid.pipeline[0], HybridRoot) return self.postprocess_agg_output(hybrid_call, result_ref, joins_can_nullify) + def make_hybrid_correl_expr( + self, + back_expr: BackReferenceExpression, + collection: PyDoughCollectionQDAG, + steps_taken_so_far: int, + ) -> HybridCorrelExpr: + """ + TODO + """ + if len(self.stack) == 0: + raise ValueError("Back reference steps too far back") + parent_tree = self.stack.pop() + remaining_steps_back: int = back_expr.back_levels - steps_taken_so_far - 1 + parent_result: HybridExpr + if len(parent_tree.pipeline) == 1 and isinstance( + parent_tree.pipeline[0], HybridPartition + ): + assert parent_tree.parent is not None + self.stack.append(parent_tree.parent) + parent_result = self.make_hybrid_correl_expr( + back_expr, collection, steps_taken_so_far + ) + self.stack.pop() + self.stack.append(parent_tree) + match parent_result.expr: + case HybridRefExpr(): + parent_result = HybridBackRefExpr( + parent_result.expr.name, 1, parent_result.typ + ) + case 
HybridBackRefExpr(): + parent_result = HybridBackRefExpr( + parent_result.expr.name, + parent_result.expr.back_idx + 1, + parent_result.typ, + ) + case _: + raise ValueError( + f"Malformed expression for correlated reference: {parent_result}" + ) + elif remaining_steps_back == 0: + if back_expr.term_name not in parent_tree.pipeline[-1].terms: + raise ValueError( + f"Back reference to {back_expr.term_name} not found in parent" + ) + parent_name: str = parent_tree.pipeline[-1].renamings.get( + back_expr.term_name, back_expr.term_name + ) + parent_result = HybridRefExpr(parent_name, back_expr.pydough_type) + else: + new_expr: PyDoughExpressionQDAG = BackReferenceExpression( + collection, back_expr.term_name, remaining_steps_back + ) + parent_result = self.make_hybrid_expr(parent_tree, new_expr, {}, False) + self.stack.append(parent_tree) + return HybridCorrelExpr(parent_tree, parent_result) + def make_hybrid_expr( self, hybrid: HybridTree, @@ -1608,10 +1664,9 @@ def make_hybrid_expr( case BackReferenceExpression(): # A reference to an expression from an ancestor becomes a # reference to one of the terms of a parent level of the hybrid - # tree. This does not yet support cases where the back - # reference steps outside of a child subtree and back into its - # parent subtree, since that breaks the independence between - # the parent and child. + # tree. If the BACK goes far enough that it must step outside + # a child subtree into the parent, a correlated reference is + # created. 
ancestor_tree = hybrid back_idx: int = 0 true_steps_back: int = 0 @@ -1623,39 +1678,9 @@ def make_hybrid_expr( assert collection.ancestor_context is not None collection = collection.ancestor_context if ancestor_tree.parent is None: - if len(self.stack) == 0: - raise ValueError("Back reference steps too far back") - parent_tree = self.stack.pop() - remaining_steps_back: int = ( - expr.back_levels - true_steps_back - 1 + return self.make_hybrid_correl_expr( + expr, collection, true_steps_back ) - # TODO: deal with case where the ancestor is PARTITION - # if len(parent_tree.pipeline) == 1 and isinstance(parent_tree.pipeline[0], HybridPartition): - # remaining_steps_back += 1 - parent_result: HybridExpr - if remaining_steps_back == 0: - if expr.term_name not in parent_tree.pipeline[-1].terms: - raise ValueError( - f"Back reference to {expr.term_name} not found in parent" - ) - parent_name: str = parent_tree.pipeline[-1].renamings.get( - expr.term_name, expr.term_name - ) - parent_result = HybridRefExpr( - parent_name, expr.pydough_type - ) - else: - new_expr: PyDoughExpressionQDAG = BackReferenceExpression( - collection, expr.term_name, remaining_steps_back - ) - parent_result = self.make_hybrid_expr( - parent_tree, new_expr, {}, False - ) - self.stack.append(parent_tree) - return HybridCorrelExpr(parent_tree, parent_result) - # raise NotImplementedError( - # "TODO: (gh #141) support BACK references that step from a child subtree back into a parent context." 
- # ) ancestor_tree = ancestor_tree.parent back_idx += true_steps_back if not ancestor_tree.is_hidden_level: diff --git a/pydough/relational/relational_expressions/__init__.py b/pydough/relational/relational_expressions/__init__.py index e8600faf..3eb8fc33 100644 --- a/pydough/relational/relational_expressions/__init__.py +++ b/pydough/relational/relational_expressions/__init__.py @@ -10,6 +10,7 @@ "ColumnReferenceInputNameModifier", "ColumnReferenceInputNameRemover", "CorrelatedReference", + "CorrelatedReferenceFinder", "ExpressionSortInfo", "LiteralExpression", "RelationalExpression", @@ -23,6 +24,7 @@ from .column_reference_input_name_modifier import ColumnReferenceInputNameModifier from .column_reference_input_name_remover import ColumnReferenceInputNameRemover from .correlated_reference import CorrelatedReference +from .correlated_reference_finder import CorrelatedReferenceFinder from .expression_sort_info import ExpressionSortInfo from .literal_expression import LiteralExpression from .relational_expression_visitor import RelationalExpressionVisitor diff --git a/pydough/relational/relational_expressions/correlated_reference_finder.py b/pydough/relational/relational_expressions/correlated_reference_finder.py new file mode 100644 index 00000000..20ae0dc2 --- /dev/null +++ b/pydough/relational/relational_expressions/correlated_reference_finder.py @@ -0,0 +1,48 @@ +""" +Find all unique column references in a relational expression. +""" + +from .call_expression import CallExpression +from .column_reference import ColumnReference +from .correlated_reference import CorrelatedReference +from .literal_expression import LiteralExpression +from .relational_expression_visitor import RelationalExpressionVisitor +from .window_call_expression import WindowCallExpression + +__all__ = ["CorrelatedReferenceFinder"] + + +class CorrelatedReferenceFinder(RelationalExpressionVisitor): + """ + Find all unique correlated references in a relational expression. 
+ """ + + def __init__(self) -> None: + self._correlated_references: set[CorrelatedReference] = set() + + def reset(self) -> None: + self._correlated_references = set() + + def get_correlated_references(self) -> set[CorrelatedReference]: + return self._correlated_references + + def visit_call_expression(self, call_expression: CallExpression) -> None: + for arg in call_expression.inputs: + arg.accept(self) + + def visit_window_expression(self, window_expression: WindowCallExpression) -> None: + for arg in window_expression.inputs: + arg.accept(self) + for partition_arg in window_expression.partition_inputs: + partition_arg.accept(self) + for order_arg in window_expression.order_inputs: + order_arg.expr.accept(self) + + def visit_literal_expression(self, literal_expression: LiteralExpression) -> None: + pass + + def visit_column_reference(self, column_reference: ColumnReference) -> None: + pass + + def visit_correlated_reference(self, correlated_reference) -> None: + self._correlated_references.add(correlated_reference) diff --git a/pydough/relational/relational_nodes/column_pruner.py b/pydough/relational/relational_nodes/column_pruner.py index d99a265e..849bdf2b 100644 --- a/pydough/relational/relational_nodes/column_pruner.py +++ b/pydough/relational/relational_nodes/column_pruner.py @@ -6,6 +6,7 @@ ColumnReference, ColumnReferenceFinder, CorrelatedReference, + CorrelatedReferenceFinder, ) from .abstract_node import Relational @@ -21,11 +22,15 @@ class ColumnPruner: def __init__(self) -> None: self._column_finder: ColumnReferenceFinder = ColumnReferenceFinder() + self._correl_finder: CorrelatedReferenceFinder = CorrelatedReferenceFinder() # Note: We set recurse=False so we only check the expressions in the # current node. 
- self._dispatcher = RelationalExpressionDispatcher( + self._finder_dispatcher = RelationalExpressionDispatcher( self._column_finder, recurse=False ) + self._correl_dispatcher = RelationalExpressionDispatcher( + self._correl_finder, recurse=False + ) def _prune_identity_project(self, node: Relational) -> Relational: """ @@ -70,14 +75,17 @@ def _prune_node_columns( for name, expr in node.columns.items() if name in kept_columns or name in required_columns } + # Update the columns. new_node = node.copy(columns=columns) - self._dispatcher.reset() - # Visit the current identifiers. - new_node.accept(self._dispatcher) + + # Find all the identifiers referenced by the the current node. + self._finder_dispatcher.reset() + new_node.accept(self._finder_dispatcher) found_identifiers: set[ColumnReference] = ( self._column_finder.get_column_references() ) + # If the node is an aggregate but doesn't use any of the inputs # (e.g. a COUNT(*)), arbitrarily mark one of them as used. # TODO: (gh #196) optimize this functionality so it doesn't keep an @@ -90,12 +98,14 @@ def _prune_node_columns( node.input.columns[arbitrary_column_name].data_type, ) ) + # Determine which identifiers to pass to each input. new_inputs: list[Relational] = [] # Note: The ColumnPruner should only be run when all input names are # still present in the columns. # Iterate over the inputs in reverse order so that the source of - # correlated data is pruned last. + # correlated data is pruned last, since it will need to account for + # any correlated references in the later inputs. 
correl_refs: set[CorrelatedReference] = set() for i, default_input_name in reversed( list(enumerate(new_node.default_input_aliases)) @@ -105,14 +115,30 @@ def _prune_node_columns( for identifier in found_identifiers: if identifier.input_name == default_input_name: s.add(identifier.name) - if isinstance(input_node, Join) and i == 0: + if ( + isinstance(new_node, Join) + and i == 0 + and new_node.correl_name is not None + ): for correl_ref in correl_refs: - if correl_ref.correl_name == input_node.correl_name: + if correl_ref.correl_name == new_node.correl_name: s.add(correl_ref.name) new_input_node, new_correl_refs = self._prune_node_columns(input_node, s) - correl_refs.update(new_correl_refs) new_inputs.append(new_input_node) + if i == len(node.inputs) - 1: + correl_refs = new_correl_refs + else: + correl_refs.update(new_correl_refs) new_inputs.reverse() + + # Find all the correlated references in the new node. + self._correl_dispatcher.reset() + new_node.accept(self._correl_dispatcher) + found_correl_refs: set[CorrelatedReference] = ( + self._correl_finder.get_correlated_references() + ) + correl_refs.update(found_correl_refs) + # Determine the new node. output = new_node.copy(inputs=new_inputs) return self._prune_identity_project(output), correl_refs diff --git a/pydough/sqlglot/sqlglot_relational_expression_visitor.py b/pydough/sqlglot/sqlglot_relational_expression_visitor.py index cdff2a13..8c03edd5 100644 --- a/pydough/sqlglot/sqlglot_relational_expression_visitor.py +++ b/pydough/sqlglot/sqlglot_relational_expression_visitor.py @@ -33,13 +33,17 @@ class SQLGlotRelationalExpressionVisitor(RelationalExpressionVisitor): """ def __init__( - self, dialect: SQLGlotDialect, bindings: SqlGlotTransformBindings + self, + dialect: SQLGlotDialect, + bindings: SqlGlotTransformBindings, + correlated_names: dict[str, str], ) -> None: # Keep a stack of SQLGlot expressions so we can build up # intermediate results. 
self._stack: list[SQLGlotExpression] = [] self._dialect: SQLGlotDialect = dialect self._bindings: SqlGlotTransformBindings = bindings + self._correlated_names: dict[str, str] = correlated_names def reset(self) -> None: """ @@ -137,9 +141,8 @@ def visit_literal_expression(self, literal_expression: LiteralExpression) -> Non def visit_correlated_reference( self, correlated_reference: CorrelatedReference ) -> None: - raise NotImplementedError( - "TODO: support SQL conversion for correlated references" - ) + full_name: str = f"{self._correlated_names[correlated_reference.correl_name]}.{correlated_reference.name}" + self._stack.append(Identifier(this=full_name)) # TODO: implement the column-based version of make_sqlglot_column, with table sources # @staticmethod diff --git a/pydough/sqlglot/sqlglot_relational_visitor.py b/pydough/sqlglot/sqlglot_relational_visitor.py index c3c3c196..327bc3ef 100644 --- a/pydough/sqlglot/sqlglot_relational_visitor.py +++ b/pydough/sqlglot/sqlglot_relational_visitor.py @@ -21,6 +21,7 @@ ColumnReference, ColumnReferenceInputNameModifier, ColumnReferenceInputNameRemover, + CorrelatedReference, EmptySingleton, ExpressionSortInfo, Filter, @@ -55,8 +56,11 @@ def __init__( # Keep a stack of SQLGlot expressions so we can build up # intermediate results. 
self._stack: list[Select] = [] + self._correlated_names: dict[str, str] = {} self._expr_visitor: SQLGlotRelationalExpressionVisitor = ( - SQLGlotRelationalExpressionVisitor(dialect, bindings) + SQLGlotRelationalExpressionVisitor( + dialect, bindings, self._correlated_names + ) ) self._alias_modifier: ColumnReferenceInputNameModifier = ( ColumnReferenceInputNameModifier() @@ -303,7 +307,7 @@ def contains_window(self, exp: RelationalExpression) -> bool: match exp: case CallExpression(): return any(self.contains_window(arg) for arg in exp.inputs) - case ColumnReference() | LiteralExpression(): + case ColumnReference() | LiteralExpression() | CorrelatedReference(): return False case WindowCallExpression(): return True @@ -329,6 +333,12 @@ def visit_scan(self, scan: Scan) -> None: self._stack.append(query) def visit_join(self, join: Join) -> None: + alias_map: dict[str | None, str] = {} + if join.correl_name is not None: + input_name = join.default_input_aliases[0] + alias = self._generate_table_alias() + alias_map[input_name] = alias + self._correlated_names[join.correl_name] = alias self.visit_inputs(join) inputs: list[Select] = [self._stack.pop() for _ in range(len(join.inputs))] inputs.reverse() @@ -339,11 +349,12 @@ def visit_join(self, join: Join) -> None: seen_names[column] += 1 # Only keep duplicate names. 
kept_names = {key for key, value in seen_names.items() if value > 1} - alias_map = { - join.default_input_aliases[i]: self._generate_table_alias() - for i in range(len(join.inputs)) - if kept_names.intersection(join.inputs[i].columns.keys()) - } + for i in range(len(join.inputs)): + input_name = join.default_input_aliases[i] + if input_name not in alias_map and kept_names.intersection( + join.inputs[i].columns.keys() + ): + alias_map[input_name] = self._generate_table_alias() self._alias_remover.set_kept_names(kept_names) self._alias_modifier.set_map(alias_map) columns = { diff --git a/pydough/unqualified/qualification.py b/pydough/unqualified/qualification.py index 216c4011..ba33907d 100644 --- a/pydough/unqualified/qualification.py +++ b/pydough/unqualified/qualification.py @@ -371,7 +371,7 @@ def qualify_access( for _ in range(levels): if ancestor.ancestor_context is None: raise PyDoughUnqualifiedException( - f"Cannot back reference {levels} above {unqualified_parent}" + f"Cannot back reference {levels} above {context}" ) ancestor = ancestor.ancestor_context # Identify whether the access is an expression or a collection diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 6b526a0e..006b58eb 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -598,6 +598,24 @@ ( impl_tpch_q22, """ +ROOT(columns=[('CNTRY_CODE', CNTRY_CODE), ('NUM_CUSTS', NUM_CUSTS), ('TOTACCTBAL', TOTACCTBAL)], orderings=[]) + PROJECT(columns={'CNTRY_CODE': cntry_code, 'NUM_CUSTS': DEFAULT_TO(agg_1, 0:int64), 'TOTACCTBAL': DEFAULT_TO(agg_2, 0:int64)}) + JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1, 'agg_2': t1.agg_2, 'cntry_code': t1.cntry_code}, correl_name='corr1') + PROJECT(columns={'avg_balance': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(acctbal)}) + FILTER(condition=acctbal > 0.0:float64, columns={'acctbal': acctbal}) + FILTER(condition=ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17']:array[unknown]) & 
True:bool, columns={'acctbal': acctbal}) + JOIN(conditions=[t0.key == t1.customer_key], types=['anti'], columns={'acctbal': t0.acctbal, 'cntry_code': t0.cntry_code}) + PROJECT(columns={'acctbal': acctbal, 'cntry_code': SLICE(phone, None:unknown, 2:int64, None:unknown), 'key': key}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'key': c_custkey, 'phone': c_phone}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey}) + AGGREGATE(keys={'cntry_code': cntry_code}, aggregations={'agg_1': COUNT(), 'agg_2': SUM(acctbal)}) + FILTER(condition=acctbal > corr1.avg_balance, columns={'acctbal': acctbal, 'cntry_code': cntry_code}) + FILTER(condition=ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17']:array[unknown]) & True:bool, columns={'acctbal': acctbal, 'cntry_code': cntry_code}) + JOIN(conditions=[t0.key == t1.customer_key], types=['anti'], columns={'acctbal': t0.acctbal, 'cntry_code': t0.cntry_code}) + PROJECT(columns={'acctbal': acctbal, 'cntry_code': SLICE(phone, None:unknown, 2:int64, None:unknown), 'key': key}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'key': c_custkey, 'phone': c_phone}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey}) """, tpch_q22_output, ), diff --git a/tests/tpch_outputs.py b/tests/tpch_outputs.py index c875d630..80948acb 100644 --- a/tests/tpch_outputs.py +++ b/tests/tpch_outputs.py @@ -692,7 +692,7 @@ def tpch_q21_output() -> pd.DataFrame: Expected output for TPC-H query 21. Note: This is truncated to the first 10 rows. 
""" - columns = ["S_NAME", "NUM_WAIT"] + columns = ["S_NAME", "NUMWAIT"] data = [ ("Supplier#000002829", 20), ("Supplier#000005808", 18), From 60fea67fcd1e6d27ef553528cdd79cab14f40262 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 27 Jan 2025 15:36:52 -0500 Subject: [PATCH 068/112] Adding additional tests, have a plan for how to deal with the 4 major unsupported cases (singular, agg, and both combined with SEMI) --- pydough/conversion/relational_converter.py | 19 +- tests/simple_pydough_functions.py | 126 +++++++++ tests/test_pipeline.py | 288 +++++++++++++++++++++ 3 files changed, 427 insertions(+), 6 deletions(-) diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index 766a6252..c725971f 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -212,13 +212,20 @@ def translate_expression( ancestor_expr: RelationalExpression = self.translate_expression( expr.expr, ancestor_context ) - assert isinstance(ancestor_expr, ColumnReference) self.stack.append(ancestor_context) - return CorrelatedReference( - ancestor_expr.name, - self.get_correlated_name(ancestor_context), - expr.typ, - ) + match ancestor_expr: + case ColumnReference(): + return CorrelatedReference( + ancestor_expr.name, + self.get_correlated_name(ancestor_context), + expr.typ, + ) + case CorrelatedReference(): + return ancestor_expr + case _: + raise ValueError( + f"Unsupported expression to reference in a correlated reference: {ancestor_expr}" + ) case _: raise NotImplementedError(expr.__class__.__name__) diff --git a/tests/simple_pydough_functions.py b/tests/simple_pydough_functions.py index 7ec5ba0e..e9f06919 100644 --- a/tests/simple_pydough_functions.py +++ b/tests/simple_pydough_functions.py @@ -270,3 +270,129 @@ def triple_partition(): )(supp_region, avg_percentage=AVG(cust_regions.percentage)).ORDER_BY( supp_region.ASC() ) + + +def correl_1(): + # Correlated back reference example #1: 
simple 1-step correlated reference + # For each region, count how many of its its nations start with the same + # letter as the region. This is a true correlated join doing an aggregated + # access without requiring the RHS be present. + return Regions( + name, n_prefix_nations=COUNT(nations.WHERE(name[:1] == BACK(1).name[:1])) + ) + + +def correl_2(): + # Correlated back reference example #2: simple 2-step correlated reference + # For each region's nations, count how many customers have a comment + # starting with the same letter as the region. Exclude regions that start + # with the letter a. This is a true correlated join doing an aggregated + # access without requiring the RHS be present. + selected_custs = customers.WHERE(comment[:1] == LOWER(BACK(2).name[:1])) + return Regions.WHERE(~STARTSWITH(name, "A")).nations( + name, + n_selected_custs=COUNT(selected_custs), + ) + + +def correl_3(): + # Correlated back reference example #3: double-layer correlated reference + # For every every region, count how many of its nations have a customer + # whose comment starts with the same letter as the region. This is a true + # correlated join doing an aggregated access without requiring the RHS be + # present. + selected_custs = customers.WHERE(comment[:1] == LOWER(BACK(2).name[:1])) + return Regions(name, n_nations=COUNT(nations.WHERE(HAS(selected_custs)))) + + +def correl_4(): + # Correlated back reference example #4: 2-step correlated HASNOT + # Find every nation that does not have a customer whose account balance is + # within $5 of the smallest known account balance globally. 
+ # (This is a correlated ANTI-join) + selected_customers = customers.WHERE(acctbal <= (BACK(2).smallest_bal + 5.0)) + return ( + TPCH( + smallest_bal=MIN(Customers.acctbal), + ) + .Nations(name) + .WHERE(HASNOT(selected_customers)) + .ORDER_BY(name.ASC()) + ) + + +def correl_5(): + # Correlated back reference example #5: 2-step correlated HAS + # Find every region that has at least 1 supplier whose account balance is + # within $4 of the smallest known account balance globally + # (This is a correlated SEMI-join) + selected_suppliers = nations.suppliers.WHERE( + account_balance <= (BACK(3).smallest_bal + 4.0) + ) + return ( + TPCH( + smallest_bal=MIN(Suppliers.account_balance), + ) + .Regions(name) + .WHERE(HAS(selected_suppliers)) + .ORDER_BY(name.ASC()) + ) + + +def correl_6(): + # Correlated back reference example #6: simple 1-step correlated reference + # For each region, count how many of its its nations start with the same + # letter as the region, but only keep regions with at least one such nation. + # This is a true correlated join doing an aggregated access that does NOT + # require that records without the RHS be kept. + selected_nations = nations.WHERE(name[:1] == BACK(1).name[:1]) + return Regions.WHERE(HAS(selected_nations))( + name, n_prefix_nations=COUNT(selected_nations) + ) + + +def correl_7(): + # Correlated back reference example #6: deleted correlated reference + # For each region, count how many of its its nations start with the same + # letter as the region, but only keep regions without at least one such + # nation. The true correlated join is trumped by the correlated ANTI-join. 
+ selected_nations = nations.WHERE(name[:1] == BACK(1).name[:1]) + return Regions.WHERE(HASNOT(selected_nations))( + name, n_prefix_nations=COUNT(selected_nations) + ) + + +def correl_8(): + # Correlated back reference example #8: non-agg correlated reference + # For each nation, fetch the name of its region, but filter the reigon + # so it only keeps it if it starts with the same letter as the nation + # (otherwise, returns NULL). This is a true correlated join doing an + # access without aggregation without requiring the RHS be + # present. + aug_region = region.WHERE(name[:1] == BACK(1).name[:1]) + return Nations(name, rname=aug_region.name).ORDER_BY(name.ASC()) + + +def correl_9(): + # Correlated back reference example #9: non-agg correlated reference + # For each nation, fetch the name of its region, but filter the reigon + # so it only keeps it if it starts with the same letter as the nation + # (otherwise, omit the nation). This is a true correlated join doing an + # access that also requires the RHS records be present. + aug_region = region.WHERE(name[:1] == BACK(1).name[:1]) + return Nations.WHERE(HAS(aug_region))(name, rname=aug_region.name).ORDER_BY( + name.ASC() + ) + + +def correl_10(): + # Correlated back reference example #10: deleted correlated reference + # For each nation, fetch the name of its region, but filter the reigon + # so it only keeps it if it starts with the same letter as the nation + # (otherwise, returns NULL), and also filter the nations to only keep + # records where the region is NULL. The true correlated join is trumped by + # the correlated ANTI-join. 
+ aug_region = region.WHERE(name[:1] == BACK(1).name[:1]) + return Nations.WHERE(HASNOT(aug_region))(name, rname=aug_region.name).ORDER_BY( + name.ASC() + ) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 006b58eb..3f693cac 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -14,6 +14,16 @@ ) from simple_pydough_functions import ( agg_partition, + correl_1, + correl_2, + correl_3, + correl_4, + correl_5, + correl_6, + correl_7, + correl_8, + correl_9, + correl_10, double_partition, function_sampler, percentile_customers_per_region, @@ -1119,6 +1129,284 @@ ), id="triple_partition", ), + pytest.param( + ( + correl_1, + """ +ROOT(columns=[('name', name), ('n_prefix_nations', n_prefix_nations)], orderings=[]) + PROJECT(columns={'n_prefix_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name}) + JOIN(conditions=[t0.key == t1.region_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr1') + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + AGGREGATE(keys={'region_key': region_key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'region_key': region_key}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) +""", + lambda: pd.DataFrame( + { + "name": ["AFRICA" "AMERICA" "MIDDLE EAST" "EUROPE" "ASIA"], + "n_prefix_nations": [1, 1, 0, 0, 0], + } + ), + ), + id="correl_1", + ), + pytest.param( + ( + correl_2, + """ +ROOT(columns=[('name', name), ('n_selected_custs', n_selected_custs)], orderings=[]) + PROJECT(columns={'n_selected_custs': DEFAULT_TO(agg_0, 0:int64), 'name': name_3}) + JOIN(conditions=[t0.key_2 == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'name_3': t0.name_3}, correl_name='corr4') + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name, 'name_3': t1.name}) + 
FILTER(condition=NOT(STARTSWITH(name, 'A':string)), columns={'key': key, 'name': name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=SLICE(comment, None:unknown, 1:int64, None:unknown) == LOWER(SLICE(corr4.name, None:unknown, 1:int64, None:unknown)), columns={'nation_key': nation_key}) + SCAN(table=tpch.CUSTOMER, columns={'comment': c_comment, 'nation_key': c_nationkey}) +""", + lambda: pd.DataFrame( + { + "name": ["A"] * 5, + } + ), + ), + id="correl_2", + ), + pytest.param( + ( + correl_3, + """ +ROOT(columns=[('name', name), ('n_nations', n_nations)], orderings=[]) + PROJECT(columns={'n_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name}) + JOIN(conditions=[t0.key == t1.region_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr1') + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + AGGREGATE(keys={'region_key': region_key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=True:bool, columns={'region_key': region_key}) + JOIN(conditions=[t0.key == t1.nation_key], types=['semi'], columns={'region_key': t0.region_key}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + FILTER(condition=SLICE(comment, None:unknown, 1:int64, None:unknown) == LOWER(SLICE(corr1.name, None:unknown, 1:int64, None:unknown)), columns={'nation_key': nation_key}) + SCAN(table=tpch.CUSTOMER, columns={'comment': c_comment, 'nation_key': c_nationkey}) +""", + lambda: pd.DataFrame( + { + "name": ["A"] * 5, + } + ), + ), + id="correl_3", + ), + pytest.param( + ( + correl_4, + """ +ROOT(columns=[('name', name)], orderings=[(ordering_1):asc_first]) + PROJECT(columns={'name': name, 'ordering_1': name}) + FILTER(condition=True:bool, columns={'name': name}) + JOIN(conditions=[t0.key == 
t1.nation_key], types=['anti'], columns={'name': t0.name}, correl_name='corr1') + JOIN(conditions=[True:bool], types=['inner'], columns={'key': t1.key, 'name': t1.name, 'smallest_bal': t0.smallest_bal}) + PROJECT(columns={'smallest_bal': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': MIN(acctbal)}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + FILTER(condition=acctbal <= corr1.smallest_bal + 5.0:float64, columns={'nation_key': nation_key}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': c_nationkey}) +""", + lambda: pd.DataFrame( + { + "name": ["ARGENTINA", "KENYA", "UNITED KINGDOM"], + } + ), + ), + id="correl_4", + ), + pytest.param( + ( + correl_5, + """ +ROOT(columns=[('name', name)], orderings=[(ordering_1):asc_first]) + PROJECT(columns={'name': name, 'ordering_1': name}) + FILTER(condition=True:bool, columns={'name': name}) + JOIN(conditions=[t0.key == t1.region_key], types=['semi'], columns={'name': t0.name}, correl_name='corr4') + JOIN(conditions=[True:bool], types=['inner'], columns={'key': t1.key, 'name': t1.name, 'smallest_bal': t0.smallest_bal}) + PROJECT(columns={'smallest_bal': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': MIN(account_balance)}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + FILTER(condition=account_balance <= corr4.smallest_bal + 4.0:float64, columns={'region_key': region_key}) + JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'account_balance': t1.account_balance, 'region_key': t0.region_key}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'nation_key': s_nationkey}) +""", + lambda: pd.DataFrame( + { + "name": ["AFRICA", "ASIA", "MIDDLE EAST"], + } + ), + ), + id="correl_5", + ), + pytest.param( 
+ ( + correl_6, + """ +ROOT(columns=[('name', name), ('n_prefix_nations', n_prefix_nations)], orderings=[]) + PROJECT(columns={'n_prefix_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name}) + FILTER(condition=True:bool, columns={'agg_0': agg_0, 'name': name}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr1') + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + AGGREGATE(keys={'region_key': region_key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'region_key': region_key}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) +""", + lambda: pd.DataFrame( + { + "name": ["AFRICA", "AMERICA"], + "n_prefix_nations": [1, 1], + } + ), + ), + id="correl_6", + ), + pytest.param( + ( + correl_7, + """ +ROOT(columns=[('name', name), ('n_prefix_nations', n_prefix_nations)], orderings=[]) + PROJECT(columns={'n_prefix_nations': DEFAULT_TO(NULL_2, 0:int64), 'name': name}) + FILTER(condition=True:bool, columns={'NULL_2': NULL_2, 'name': name}) + JOIN(conditions=[t0.key == t1.region_key], types=['anti'], columns={'NULL_2': None:unknown, 'name': t0.name}, correl_name='corr1') + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'region_key': region_key}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) +""", + lambda: pd.DataFrame( + { + "name": ["ASIA", "EUROPE", "MIDDLE EAST"], + "n_prefix_nations": [0] * 3, + } + ), + ), + id="correl_7", + ), + pytest.param( + ( + correl_8, + """ +ROOT(columns=[('name', name), ('rname', rname)], orderings=[(ordering_0):asc_first]) + PROJECT(columns={'name': name, 'ordering_0': name, 'rname': rname}) + 
PROJECT(columns={'name': name, 'rname': name_4}) + JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'name': t0.name, 'name_4': t1.name}, correl_name='corr1') + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) + FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'key': key, 'name': name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) +""", + lambda: pd.DataFrame( + { + "name": [ + "ALGERIA", + "ARGENTINA", + "BRAZIL", + "CANADA", + "CHINA", + "EGYPT", + "ETHIOPIA", + "FRANCE", + "GERMANY", + "INDIA", + "INDONESIA", + "IRAN", + "IRAQ", + "JAPAN", + "JORDAN", + "KENYA", + "MOROCCO", + "MOZAMBIQUE", + "PERU", + "ROMANIA", + "RUSSIA", + "SAUDI ARABIA", + "UNITED KINGDOM", + "UNITED STATES", + "VIETNAM", + ], + "rname": ["AFRICA", "AMERICA"] + [None] * 23, + } + ), + ), + id="correl_8", + ), + pytest.param( + ( + correl_9, + """ +ROOT(columns=[('name', name), ('rname', rname)], orderings=[(ordering_0):asc_first]) + PROJECT(columns={'name': name, 'ordering_0': name, 'rname': rname}) + PROJECT(columns={'name': name, 'rname': name_4}) + FILTER(condition=True:bool, columns={'name': name, 'name_4': name_4}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'name': t0.name, 'name_4': t1.name}, correl_name='corr1') + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) + FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'key': key, 'name': name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) +""", + lambda: pd.DataFrame( + { + "name": [ + "ALGERIA", + "ARGENTINA", + ], + "rname": ["AFRICA", "AMERICA"], + } + ), + ), + id="correl_9", + ), + pytest.param( + ( + correl_10, + """ +ROOT(columns=[('name', name), ('rname', rname)], orderings=[(ordering_0):asc_first]) + 
PROJECT(columns={'name': name, 'ordering_0': name, 'rname': rname}) + PROJECT(columns={'name': name, 'rname': NULL_2}) + FILTER(condition=True:bool, columns={'NULL_2': NULL_2, 'name': name}) + JOIN(conditions=[t0.region_key == t1.key], types=['anti'], columns={'NULL_2': None:unknown, 'name': t0.name}, correl_name='corr1') + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) + FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'key': key}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) +""", + lambda: pd.DataFrame( + { + "name": [ + "BRAZIL", + "CANADA", + "CHINA", + "EGYPT", + "ETHIOPIA", + "FRANCE", + "GERMANY", + "INDIA", + "INDONESIA", + "IRAN", + "IRAQ", + "JAPAN", + "JORDAN", + "KENYA", + "MOROCCO", + "MOZAMBIQUE", + "PERU", + "ROMANIA", + "RUSSIA", + "SAUDI ARABIA", + "UNITED KINGDOM", + "UNITED STATES", + "VIETNAM", + ], + "rname": [None] * 23, + } + ), + ), + id="correl_10", + ), ], ) def pydough_pipeline_test_data( From 9980a2c09a6d64e78fd581897da88d1e6a1f5264 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Tue, 28 Jan 2025 11:56:31 -0500 Subject: [PATCH 069/112] Pulling down changes --- pydough/conversion/hybrid_tree.py | 423 ++++++++++-------- pydough/conversion/relational_converter.py | 63 ++- pydough/relational/__init__.py | 2 + .../relational_expressions/__init__.py | 4 + .../column_reference_finder.py | 3 + .../column_reference_input_name_modifier.py | 3 + .../column_reference_input_name_remover.py | 3 + .../correlated_reference.py | 61 +++ .../correlated_reference_finder.py | 48 ++ .../relational_expression_shuttle.py | 11 + .../relational_expression_visitor.py | 9 + .../relational_nodes/column_pruner.py | 59 ++- pydough/relational/relational_nodes/join.py | 18 +- .../sqlglot_relational_expression_visitor.py | 40 +- pydough/sqlglot/sqlglot_relational_visitor.py | 41 +- pydough/unqualified/qualification.py | 2 +- 
tests/simple_pydough_functions.py | 126 ++++++ tests/test_pipeline.py | 351 ++++++++++++++- tests/tpch_outputs.py | 2 +- 19 files changed, 1057 insertions(+), 212 deletions(-) create mode 100644 pydough/relational/relational_expressions/correlated_reference.py create mode 100644 pydough/relational/relational_expressions/correlated_reference_finder.py diff --git a/pydough/conversion/hybrid_tree.py b/pydough/conversion/hybrid_tree.py index 0a9fc64e..732fbb1c 100644 --- a/pydough/conversion/hybrid_tree.py +++ b/pydough/conversion/hybrid_tree.py @@ -229,6 +229,27 @@ def shift_back(self, levels: int) -> HybridExpr | None: return HybridBackRefExpr(self.name, self.back_idx + levels, self.typ) +class HybridCorrelExpr(HybridExpr): + """ + Class for HybridExpr terms that are expressions from a parent hybrid tree + rather than an ancestor, which requires a correlated reference. + """ + + def __init__(self, hybrid: "HybridTree", expr: HybridExpr): + super().__init__(expr.typ) + self.hybrid = hybrid + self.expr: HybridExpr = expr + + def __repr__(self): + return f"CORREL({self.expr})" + + def apply_renamings(self, renamings: dict[str, str]) -> "HybridExpr": + return self + + def shift_back(self, levels: int) -> HybridExpr | None: + return self + + class HybridLiteralExpr(HybridExpr): """ Class for HybridExpr terms that are literals. @@ -1046,6 +1067,10 @@ def __init__(self, configs: PyDoughConfigs): self.configs = configs # An index used for creating fake column names for aliases self.alias_counter: int = 0 + # A stack where each element is a hybrid tree being derived + # as as subtree of the previous element, and the current tree is + # being derived as the subtree of the last element. + self.stack: list[HybridTree] = [] @staticmethod def get_join_keys( @@ -1230,6 +1255,7 @@ def populate_children( accordingly so expressions using the child indices know what hybrid connection index to use. 
""" + self.stack.append(hybrid) for child_idx, child in enumerate(child_operator.children): # Build the hybrid tree for the child. Before doing so, reset the # alias counter to 0 to ensure that identical subtrees are named @@ -1269,108 +1295,7 @@ def populate_children( for con_typ in reference_types: connection_type = connection_type.reconcile_connection_types(con_typ) child_idx_mapping[child_idx] = hybrid.add_child(subtree, connection_type) - - def make_hybrid_agg_expr( - self, - hybrid: HybridTree, - expr: PyDoughExpressionQDAG, - child_ref_mapping: dict[int, int], - ) -> tuple[HybridExpr, int | None]: - """ - Converts a QDAG expression into a HybridExpr specifically with the - intent of making it the input to an aggregation call. Returns the - converted function argument, as well as an index indicating what child - subtree the aggregation's arguments belong to. NOTE: the HybridExpr is - phrased relative to the child subtree, rather than relative to `hybrid` - itself. - - Args: - `hybrid`: the hybrid tree that should be used to derive the - translation of `expr`, as it is the context in which the `expr` - will live. - `expr`: the QDAG expression to be converted. - `child_ref_mapping`: mapping of indices used by child references - in the original expressions to the index of the child hybrid tree - relative to the current level. - - Returns: - The HybridExpr node corresponding to `expr`, as well as the index - of the child it belongs to (e.g. which subtree does this - aggregation need to be done on top of). - """ - hybrid_result: HybridExpr - # This value starts out as None since we do not know the child index - # that `expr` correspond to yet. It may still be None at the end, since - # it is possible that `expr` does not correspond to any child index. - child_idx: int | None = None - match expr: - case PartitionKey(): - return self.make_hybrid_agg_expr(hybrid, expr.expr, child_ref_mapping) - case Literal(): - # Literals are kept as-is. 
- hybrid_result = HybridLiteralExpr(expr) - case ChildReferenceExpression(): - # Child references become regular references because the - # expression is phrased as if we were inside the child rather - # than the parent. - child_idx = child_ref_mapping[expr.child_idx] - child_connection = hybrid.children[child_idx] - expr_name = child_connection.subtree.pipeline[-1].renamings.get( - expr.term_name, expr.term_name - ) - hybrid_result = HybridRefExpr(expr_name, expr.pydough_type) - case ExpressionFunctionCall(): - if expr.operator.is_aggregation: - raise NotImplementedError( - "PyDough does not yet support calling aggregations inside of aggregations" - ) - # Every argument must be translated in the same manner as a - # regular function argument, except that the child index it - # corresponds to must be reconciled with the child index value - # accumulated so far. - args: list[HybridExpr] = [] - for arg in expr.args: - if not isinstance(arg, PyDoughExpressionQDAG): - raise NotImplementedError( - f"TODO: support converting {arg.__class__.__name__} as a function argument" - ) - hybrid_arg, hybrid_child_index = self.make_hybrid_agg_expr( - hybrid, arg, child_ref_mapping - ) - if hybrid_child_index is not None: - if child_idx is None: - # In this case, the argument is the first one seen that - # has an index, so that index is chosen. - child_idx = hybrid_child_index - elif hybrid_child_index != child_idx: - # In this case, multiple arguments correspond to - # different children, which cannot be handled yet - # because it means it is impossible to push the agg - # call into a single HybridConnection node. 
- raise NotImplementedError( - "Unsupported case: multiple child indices referenced by aggregation arguments" - ) - args.append(hybrid_arg) - hybrid_result = HybridFunctionExpr( - expr.operator, args, expr.pydough_type - ) - case BackReferenceExpression(): - raise NotImplementedError( - "PyDough does yet support aggregations whose arguments mix between subcollection data of the current context and fields of an ancestor of the current context" - ) - case Reference(): - raise NotImplementedError( - "PyDough does yet support aggregations whose arguments mix between subcollection data of the current context and fields of the context itself" - ) - case WindowCall(): - raise NotImplementedError( - "PyDough does yet support aggregations whose arguments mix between subcollection data of the current context and window functions" - ) - case _: - raise NotImplementedError( - f"TODO: support converting {expr.__class__.__name__} in aggregations" - ) - return hybrid_result, child_idx + self.stack.pop() def postprocess_agg_output( self, agg_call: HybridFunctionExpr, agg_ref: HybridExpr, joins_can_nullify: bool @@ -1533,11 +1458,165 @@ def handle_has_hasnot( # has / hasnot condition is now known to be true. return HybridLiteralExpr(Literal(True, BooleanType())) + def convert_agg_arg(self, expr: HybridExpr, child_indices: set[int]) -> HybridExpr: + """ + Translates a hybrid expression that is an argument to an aggregation + (or a subexpression of such an argument) into a form that is expressed + from the perspective of the child subtree that is being aggregated. + + Args: + `expr`: the expression to be converted. + `child_indices`: a set that is mutated to contain the indices of + any children that are referenced by `expr`. + + Returns: + The translated expression. + + Raises: + NotImplementedError if `expr` is an expression that cannot be used + inside of an aggregation call. 
+ """ + match expr: + case HybridLiteralExpr(): + return expr + case HybridChildRefExpr(): + # Child references become regular references because the + # expression is phrased as if we were inside the child rather + # than the parent. + child_indices.add(expr.child_idx) + return HybridRefExpr(expr.name, expr.typ) + case HybridFunctionExpr(): + return HybridFunctionExpr( + expr.operator, + [self.convert_agg_arg(arg, child_indices) for arg in expr.args], + expr.typ, + ) + case HybridBackRefExpr(): + raise NotImplementedError( + "PyDough does yet support aggregations whose arguments mix between subcollection data of the current context and fields of an ancestor of the current context" + ) + case HybridRefExpr(): + raise NotImplementedError( + "PyDough does yet support aggregations whose arguments mix between subcollection data of the current context and fields of the context itself" + ) + case HybridWindowExpr(): + raise NotImplementedError( + "PyDough does yet support aggregations whose arguments mix between subcollection data of the current context and window functions" + ) + case _: + raise NotImplementedError( + f"TODO: support converting {expr.__class__.__name__} in aggregations" + ) + + def make_agg_call( + self, + hybrid: HybridTree, + expr: ExpressionFunctionCall, + args: list[HybridExpr], + ) -> HybridExpr: + """ + For aggregate function calls, their arguments are translated in a + manner that identifies what child subtree they correspond too, by + index, and translates them relative to the subtree. Then, the + aggregation calls are placed into the `aggs` mapping of the + corresponding child connection, and the aggregation call becomes a + child reference (referring to the aggs list), since after translation, + an aggregated child subtree only has the grouping keys and the + aggregation calls as opposed to its other terms. + + Args: + `hybrid`: the hybrid tree that should be used to derive the + translation of the aggregation call. 
+ `expr`: the aggregation function QDAG expression to be converted. + `args`: the converted arguments to the aggregation call. + """ + child_indices: set[int] = set() + converted_args: list[HybridExpr] = [ + self.convert_agg_arg(arg, child_indices) for arg in args + ] + if len(child_indices) != 1: + raise ValueError( + f"Expected aggregation call to contain references to exactly one child collection, but found {len(child_indices)} in {expr}" + ) + hybrid_call: HybridFunctionExpr = HybridFunctionExpr( + expr.operator, converted_args, expr.pydough_type + ) + # Identify the child connection that the aggregation call is pushed + # into. + child_idx: int = child_indices.pop() + child_connection = hybrid.children[child_idx] + # Generate a unique name for the agg call to push into the child + # connection. + agg_name: str = self.get_agg_name(child_connection) + child_connection.aggs[agg_name] = hybrid_call + result_ref: HybridExpr = HybridChildRefExpr( + agg_name, child_idx, expr.pydough_type + ) + joins_can_nullify: bool = not isinstance(hybrid.pipeline[0], HybridRoot) + return self.postprocess_agg_output(hybrid_call, result_ref, joins_can_nullify) + + def make_hybrid_correl_expr( + self, + back_expr: BackReferenceExpression, + collection: PyDoughCollectionQDAG, + steps_taken_so_far: int, + ) -> HybridCorrelExpr: + """ + TODO + """ + if len(self.stack) == 0: + raise ValueError("Back reference steps too far back") + parent_tree = self.stack.pop() + remaining_steps_back: int = back_expr.back_levels - steps_taken_so_far - 1 + parent_result: HybridExpr + if len(parent_tree.pipeline) == 1 and isinstance( + parent_tree.pipeline[0], HybridPartition + ): + assert parent_tree.parent is not None + self.stack.append(parent_tree.parent) + parent_result = self.make_hybrid_correl_expr( + back_expr, collection, steps_taken_so_far + ) + self.stack.pop() + self.stack.append(parent_tree) + match parent_result.expr: + case HybridRefExpr(): + parent_result = HybridBackRefExpr( + 
parent_result.expr.name, 1, parent_result.typ + ) + case HybridBackRefExpr(): + parent_result = HybridBackRefExpr( + parent_result.expr.name, + parent_result.expr.back_idx + 1, + parent_result.typ, + ) + case _: + raise ValueError( + f"Malformed expression for correlated reference: {parent_result}" + ) + elif remaining_steps_back == 0: + if back_expr.term_name not in parent_tree.pipeline[-1].terms: + raise ValueError( + f"Back reference to {back_expr.term_name} not found in parent" + ) + parent_name: str = parent_tree.pipeline[-1].renamings.get( + back_expr.term_name, back_expr.term_name + ) + parent_result = HybridRefExpr(parent_name, back_expr.pydough_type) + else: + new_expr: PyDoughExpressionQDAG = BackReferenceExpression( + collection, back_expr.term_name, remaining_steps_back + ) + parent_result = self.make_hybrid_expr(parent_tree, new_expr, {}, False) + self.stack.append(parent_tree) + return HybridCorrelExpr(parent_tree, parent_result) + def make_hybrid_expr( self, hybrid: HybridTree, expr: PyDoughExpressionQDAG, child_ref_mapping: dict[int, int], + inside_agg: bool, ) -> HybridExpr: """ Converts a QDAG expression into a HybridExpr. @@ -1550,6 +1629,8 @@ def make_hybrid_expr( `child_ref_mapping`: mapping of indices used by child references in the original expressions to the index of the child hybrid tree relative to the current level. + `inside_agg`: True if `expr` is beign derived is inside of an + aggregation call, False otherwise. 
Returns: The HybridExpr node corresponding to `expr` @@ -1561,7 +1642,9 @@ def make_hybrid_expr( ancestor_tree: HybridTree match expr: case PartitionKey(): - return self.make_hybrid_expr(hybrid, expr.expr, child_ref_mapping) + return self.make_hybrid_expr( + hybrid, expr.expr, child_ref_mapping, inside_agg + ) case Literal(): return HybridLiteralExpr(expr) case ColumnProperty(): @@ -1581,20 +1664,22 @@ def make_hybrid_expr( case BackReferenceExpression(): # A reference to an expression from an ancestor becomes a # reference to one of the terms of a parent level of the hybrid - # tree. This does not yet support cases where the back - # reference steps outside of a child subtree and back into its - # parent subtree, since that breaks the independence between - # the parent and child. + # tree. If the BACK goes far enough that it must step outside + # a child subtree into the parent, a correlated reference is + # created. ancestor_tree = hybrid back_idx: int = 0 true_steps_back: int = 0 # Keep stepping backward until `expr.back_levels` non-hidden # steps have been taken (to ignore steps that are part of a # compound). + collection: PyDoughCollectionQDAG = expr.collection while true_steps_back < expr.back_levels: + assert collection.ancestor_context is not None + collection = collection.ancestor_context if ancestor_tree.parent is None: - raise NotImplementedError( - "TODO: (gh #141) support BACK references that step from a child subtree back into a parent context." + return self.make_hybrid_correl_expr( + expr, collection, true_steps_back ) ancestor_tree = ancestor_tree.parent back_idx += true_steps_back @@ -1609,80 +1694,51 @@ def make_hybrid_expr( expr.term_name, expr.term_name ) return HybridRefExpr(expr_name, expr.pydough_type) - case ExpressionFunctionCall() if not expr.operator.is_aggregation: - # For non-aggregate function calls, translate their arguments - # normally and build the function call. 
Does not support any + case ExpressionFunctionCall(): + if expr.operator.is_aggregation and inside_agg: + raise NotImplementedError( + "PyDough does not yet support calling aggregations inside of aggregations" + ) + # Do special casing for operators that an have collection + # arguments. + # TODO: (gh #148) handle collection-level NDISTINCT + if ( + expr.operator == pydop.COUNT + and len(expr.args) == 1 + and isinstance(expr.args[0], PyDoughCollectionQDAG) + ): + return self.handle_collection_count(hybrid, expr, child_ref_mapping) + elif expr.operator in (pydop.HAS, pydop.HASNOT): + return self.handle_has_hasnot(hybrid, expr, child_ref_mapping) + elif any( + not isinstance(arg, PyDoughExpressionQDAG) for arg in expr.args + ): + raise NotImplementedError( + f"PyDough does not yet support non-expression arguments for aggregation function {expr.operator}" + ) + # For normal operators, translate their expression arguments + # normally. If it is a non-aggregation, build the function + # call. If it is an aggregation, transform accordingly. # such function that takes in a collection, as none currently # exist that are not aggregations. + expr.operator.is_aggregation for arg in expr.args: if not isinstance(arg, PyDoughExpressionQDAG): raise NotImplementedError( - "PyDough does not yet support converting collections as function arguments to a non-aggregation function" + f"PyDough does not yet support non-expression arguments for function {expr.operator}" ) - args.append(self.make_hybrid_expr(hybrid, arg, child_ref_mapping)) - return HybridFunctionExpr(expr.operator, args, expr.pydough_type) - case ExpressionFunctionCall() if expr.operator.is_aggregation: - # For aggregate function calls, their arguments are translated in - # a manner that identifies what child subtree they correspond too, - # by index, and translates them relative to the subtree. 
Then, the - # aggregation calls are placed into the `aggs` mapping of the - # corresponding child connection, and the aggregation call becomes - # a child reference (referring to the aggs list), since after - # translation, an aggregated child subtree only has the grouping - # keys & the aggregation calls as opposed to its other terms. - child_idx: int | None = None - arg_child_idx: int | None = None - for arg in expr.args: - if isinstance(arg, PyDoughExpressionQDAG): - hybrid_arg, arg_child_idx = self.make_hybrid_agg_expr( - hybrid, arg, child_ref_mapping + args.append( + self.make_hybrid_expr( + hybrid, + arg, + child_ref_mapping, + inside_agg or expr.operator.is_aggregation, ) - else: - if not isinstance(arg, ChildReferenceCollection): - raise NotImplementedError("Cannot process argument") - # TODO: (gh #148) handle collection-level NDISTINCT - if expr.operator == pydop.COUNT: - return self.handle_collection_count( - hybrid, expr, child_ref_mapping - ) - elif expr.operator in (pydop.HAS, pydop.HASNOT): - return self.handle_has_hasnot( - hybrid, expr, child_ref_mapping - ) - else: - raise NotImplementedError( - f"PyDough does not yet support collection arguments for aggregation function {expr.operator}" - ) - # Accumulate the `arg_child_idx` value from the argument across - # all function arguments, ensuring that at the end there is - # exactly one child subtree that the agg call corresponds to. 
- if arg_child_idx is not None: - if child_idx is None: - child_idx = arg_child_idx - elif arg_child_idx != child_idx: - raise NotImplementedError( - "Unsupported case: multiple child indices referenced by aggregation arguments" - ) - args.append(hybrid_arg) - if child_idx is None: - raise NotImplementedError( - "Unsupported case: no child indices referenced by aggregation arguments" ) - hybrid_call: HybridFunctionExpr = HybridFunctionExpr( - expr.operator, args, expr.pydough_type - ) - child_connection = hybrid.children[child_idx] - # Generate a unique name for the agg call to push into the child - # connection. - agg_name: str = self.get_agg_name(child_connection) - child_connection.aggs[agg_name] = hybrid_call - result_ref: HybridExpr = HybridChildRefExpr( - agg_name, child_idx, expr.pydough_type - ) - joins_can_nullify: bool = not isinstance(hybrid.pipeline[0], HybridRoot) - return self.postprocess_agg_output( - hybrid_call, result_ref, joins_can_nullify - ) + if expr.operator.is_aggregation: + return self.make_agg_call(hybrid, expr, args) + else: + return HybridFunctionExpr(expr.operator, args, expr.pydough_type) case WindowCall(): partition_args: list[HybridExpr] = [] order_args: list[HybridCollation] = [] @@ -1700,7 +1756,7 @@ def make_hybrid_expr( partition_args.append(shifted_arg) for arg in expr.collation_args: hybrid_arg = self.make_hybrid_expr( - hybrid, arg.expr, child_ref_mapping + hybrid, arg.expr, child_ref_mapping, inside_agg ) order_args.append(HybridCollation(hybrid_arg, arg.asc, arg.na_last)) return HybridWindowExpr( @@ -1739,7 +1795,9 @@ def process_hybrid_collations( hybrid_orderings: list[HybridCollation] = [] for collation in collations: name = self.get_ordering_name(hybrid) - expr = self.make_hybrid_expr(hybrid, collation.expr, child_ref_mapping) + expr = self.make_hybrid_expr( + hybrid, collation.expr, child_ref_mapping, False + ) new_expressions[name] = expr new_collation: HybridCollation = HybridCollation( HybridRefExpr(name, 
collation.expr.pydough_type), @@ -1750,7 +1808,7 @@ def process_hybrid_collations( return new_expressions, hybrid_orderings def make_hybrid_tree( - self, node: PyDoughCollectionQDAG, parent: HybridTree | None = None + self, node: PyDoughCollectionQDAG, parent: HybridTree | None ) -> HybridTree: """ Converts a collection QDAG into the HybridTree format. @@ -1792,7 +1850,7 @@ def make_hybrid_tree( new_expressions: dict[str, HybridExpr] = {} for name in sorted(node.calc_terms): expr = self.make_hybrid_expr( - hybrid, node.get_expr(name), child_ref_mapping + hybrid, node.get_expr(name), child_ref_mapping, False ) new_expressions[name] = expr hybrid.pipeline.append( @@ -1806,7 +1864,9 @@ def make_hybrid_tree( case Where(): hybrid = self.make_hybrid_tree(node.preceding_context, parent) self.populate_children(hybrid, node, child_ref_mapping) - expr = self.make_hybrid_expr(hybrid, node.condition, child_ref_mapping) + expr = self.make_hybrid_expr( + hybrid, node.condition, child_ref_mapping, False + ) hybrid.pipeline.append(HybridFilter(hybrid.pipeline[-1], expr)) return hybrid case PartitionBy(): @@ -1819,7 +1879,7 @@ def make_hybrid_tree( for key_name in node.calc_terms: key = node.get_expr(key_name) expr = self.make_hybrid_expr( - successor_hybrid, key, child_ref_mapping + successor_hybrid, key, child_ref_mapping, False ) partition.add_key(key_name, expr) key_exprs.append(HybridRefExpr(key_name, expr.typ)) @@ -1869,6 +1929,7 @@ def make_hybrid_tree( successor_hybrid, Reference(node.child_access, key.expr.term_name), child_ref_mapping, + False, ) assert isinstance(rhs_expr, HybridRefExpr) lhs_expr: HybridExpr = HybridChildRefExpr( diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index 96ddcdc9..c725971f 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -27,6 +27,7 @@ CallExpression, ColumnPruner, ColumnReference, + CorrelatedReference, EmptySingleton, 
ExpressionSortInfo, Filter, @@ -52,6 +53,7 @@ HybridCollectionAccess, HybridColumnExpr, HybridConnection, + HybridCorrelExpr, HybridExpr, HybridFilter, HybridFunctionExpr, @@ -78,13 +80,32 @@ class TranslationOutput: """ relation: Relational + """ + The relational tree describing the way to compute the answer for the + logic originally in the hybrid tree. + """ + expressions: dict[HybridExpr, ColumnReference] + """ + A mapping of each expression that was accessible in the hybrid tree to the + corresponding column reference in the relational tree that contains the + value of that expression. + """ + + correlated_name: str | None = None + """ + The name that can be used to refer to the relational output in correlated + references. + """ class RelTranslation: def __init__(self): # An index used for creating fake column names self.dummy_idx = 1 + # A stack of contexts used to point to ancestors for correlated + # references. + self.stack: list[TranslationOutput] = [] def make_null_column(self, relation: Relational) -> ColumnReference: """ @@ -114,6 +135,24 @@ def make_null_column(self, relation: Relational) -> ColumnReference: relation.columns[name] = LiteralExpression(None, UnknownType()) return ColumnReference(name, UnknownType()) + def get_correlated_name(self, context: TranslationOutput) -> str: + """ + Finds the name used to refer to a context for correlated variable + access. If the context does not have a correlated name, a new one is + generated for it. + + Args: + `context`: the context containing the relational subtree being + referrenced in a correlated variable access. + + Returns: + The name used to refer to the context in a correlated reference. 
+ """ + if context.correlated_name is None: + context.correlated_name = f"corr{self.dummy_idx}" + self.dummy_idx += 1 + return context.correlated_name + def translate_expression( self, expr: HybridExpr, context: TranslationOutput | None ) -> RelationalExpression: @@ -168,6 +207,25 @@ def translate_expression( order_inputs, expr.kwargs, ) + case HybridCorrelExpr(): + ancestor_context: TranslationOutput = self.stack.pop() + ancestor_expr: RelationalExpression = self.translate_expression( + expr.expr, ancestor_context + ) + self.stack.append(ancestor_context) + match ancestor_expr: + case ColumnReference(): + return CorrelatedReference( + ancestor_expr.name, + self.get_correlated_name(ancestor_context), + expr.typ, + ) + case CorrelatedReference(): + return ancestor_expr + case _: + raise ValueError( + f"Unsupported expression to reference in a correlated reference: {ancestor_expr}" + ) case _: raise NotImplementedError(expr.__class__.__name__) @@ -226,6 +284,7 @@ def join_outputs( [LiteralExpression(True, BooleanType())], [join_type], join_columns, + correl_name=lhs_result.correlated_name, ) input_aliases: list[str | None] = out_rel.default_input_aliases @@ -366,9 +425,11 @@ def handle_children( """ for child_idx, child in enumerate(hybrid.children): if child.required_steps == pipeline_idx: + self.stack.append(context) child_output = self.rel_translation( child, child.subtree, len(child.subtree.pipeline) - 1 ) + self.stack.pop() assert child.subtree.join_keys is not None join_keys: list[tuple[HybridExpr, HybridExpr]] = child.subtree.join_keys agg_keys: list[HybridExpr] @@ -905,7 +966,7 @@ def convert_ast_to_relational( # Convert the QDAG node to the hybrid form, then invoke the relational # conversion procedure. The first element in the returned list is the # final rel node. 
- hybrid: HybridTree = HybridTranslator(configs).make_hybrid_tree(node) + hybrid: HybridTree = HybridTranslator(configs).make_hybrid_tree(node, None) renamings: dict[str, str] = hybrid.pipeline[-1].renamings output: TranslationOutput = translator.rel_translation( None, hybrid, len(hybrid.pipeline) - 1 diff --git a/pydough/relational/__init__.py b/pydough/relational/__init__.py index b79fe424..9c54dcd4 100644 --- a/pydough/relational/__init__.py +++ b/pydough/relational/__init__.py @@ -6,6 +6,7 @@ "ColumnReferenceFinder", "ColumnReferenceInputNameModifier", "ColumnReferenceInputNameRemover", + "CorrelatedReference", "EmptySingleton", "ExpressionSortInfo", "Filter", @@ -30,6 +31,7 @@ ColumnReferenceFinder, ColumnReferenceInputNameModifier, ColumnReferenceInputNameRemover, + CorrelatedReference, ExpressionSortInfo, LiteralExpression, RelationalExpression, diff --git a/pydough/relational/relational_expressions/__init__.py b/pydough/relational/relational_expressions/__init__.py index 68838524..3eb8fc33 100644 --- a/pydough/relational/relational_expressions/__init__.py +++ b/pydough/relational/relational_expressions/__init__.py @@ -9,6 +9,8 @@ "ColumnReferenceFinder", "ColumnReferenceInputNameModifier", "ColumnReferenceInputNameRemover", + "CorrelatedReference", + "CorrelatedReferenceFinder", "ExpressionSortInfo", "LiteralExpression", "RelationalExpression", @@ -21,6 +23,8 @@ from .column_reference_finder import ColumnReferenceFinder from .column_reference_input_name_modifier import ColumnReferenceInputNameModifier from .column_reference_input_name_remover import ColumnReferenceInputNameRemover +from .correlated_reference import CorrelatedReference +from .correlated_reference_finder import CorrelatedReferenceFinder from .expression_sort_info import ExpressionSortInfo from .literal_expression import LiteralExpression from .relational_expression_visitor import RelationalExpressionVisitor diff --git a/pydough/relational/relational_expressions/column_reference_finder.py 
b/pydough/relational/relational_expressions/column_reference_finder.py index 7de5bc88..d8c7ba61 100644 --- a/pydough/relational/relational_expressions/column_reference_finder.py +++ b/pydough/relational/relational_expressions/column_reference_finder.py @@ -42,3 +42,6 @@ def visit_literal_expression(self, literal_expression: LiteralExpression) -> Non def visit_column_reference(self, column_reference: ColumnReference) -> None: self._column_references.add(column_reference) + + def visit_correlated_reference(self, correlated_reference) -> None: + pass diff --git a/pydough/relational/relational_expressions/column_reference_input_name_modifier.py b/pydough/relational/relational_expressions/column_reference_input_name_modifier.py index 0a750462..fd738ad7 100644 --- a/pydough/relational/relational_expressions/column_reference_input_name_modifier.py +++ b/pydough/relational/relational_expressions/column_reference_input_name_modifier.py @@ -45,3 +45,6 @@ def visit_column_reference(self, column_reference) -> RelationalExpression: raise ValueError( f"Input name {column_reference.input_name} not found in the input name map." 
) + + def visit_correlated_reference(self, correlated_reference) -> RelationalExpression: + return correlated_reference diff --git a/pydough/relational/relational_expressions/column_reference_input_name_remover.py b/pydough/relational/relational_expressions/column_reference_input_name_remover.py index 26633764..0de2e1d9 100644 --- a/pydough/relational/relational_expressions/column_reference_input_name_remover.py +++ b/pydough/relational/relational_expressions/column_reference_input_name_remover.py @@ -37,3 +37,6 @@ def visit_column_reference(self, column_reference) -> RelationalExpression: column_reference.data_type, None, ) + + def visit_correlated_reference(self, correlated_reference) -> RelationalExpression: + return correlated_reference diff --git a/pydough/relational/relational_expressions/correlated_reference.py b/pydough/relational/relational_expressions/correlated_reference.py new file mode 100644 index 00000000..0746a5bc --- /dev/null +++ b/pydough/relational/relational_expressions/correlated_reference.py @@ -0,0 +1,61 @@ +""" +TODO +""" + +__all__ = ["CorrelatedReference"] + +from pydough.types import PyDoughType + +from .abstract_expression import RelationalExpression +from .relational_expression_shuttle import RelationalExpressionShuttle +from .relational_expression_visitor import RelationalExpressionVisitor + + +class CorrelatedReference(RelationalExpression): + """ + TODO + """ + + def __init__(self, name: str, correl_name: str, data_type: PyDoughType) -> None: + super().__init__(data_type) + self._name: str = name + self._correl_name: str = correl_name + + def __hash__(self) -> int: + return hash((self.name, self.correl_name, self.data_type)) + + @property + def name(self) -> str: + """ + The name of the column. + """ + return self._name + + @property + def correl_name(self) -> str: + """ + The name of the correlation that the reference points to. 
+ """ + return self._correl_name + + def to_string(self, compact: bool = False) -> str: + if compact: + return f"{self.correl_name}.{self.name}" + else: + return f"CorrelatedReference(name={self.name}, correl_name={self.correl_name}, type={self.data_type})" + + def equals(self, other: object) -> bool: + return ( + isinstance(other, CorrelatedReference) + and (self.name == other.name) + and (self.correl_name == other.correl_name) + and super().equals(other) + ) + + def accept(self, visitor: RelationalExpressionVisitor) -> None: + visitor.visit_correlated_reference(self) + + def accept_shuttle( + self, shuttle: RelationalExpressionShuttle + ) -> RelationalExpression: + return shuttle.visit_correlated_reference(self) diff --git a/pydough/relational/relational_expressions/correlated_reference_finder.py b/pydough/relational/relational_expressions/correlated_reference_finder.py new file mode 100644 index 00000000..20ae0dc2 --- /dev/null +++ b/pydough/relational/relational_expressions/correlated_reference_finder.py @@ -0,0 +1,48 @@ +""" +Find all unique column references in a relational expression. +""" + +from .call_expression import CallExpression +from .column_reference import ColumnReference +from .correlated_reference import CorrelatedReference +from .literal_expression import LiteralExpression +from .relational_expression_visitor import RelationalExpressionVisitor +from .window_call_expression import WindowCallExpression + +__all__ = ["CorrelatedReferenceFinder"] + + +class CorrelatedReferenceFinder(RelationalExpressionVisitor): + """ + Find all unique correlated references in a relational expression. 
+ """ + + def __init__(self) -> None: + self._correlated_references: set[CorrelatedReference] = set() + + def reset(self) -> None: + self._correlated_references = set() + + def get_correlated_references(self) -> set[CorrelatedReference]: + return self._correlated_references + + def visit_call_expression(self, call_expression: CallExpression) -> None: + for arg in call_expression.inputs: + arg.accept(self) + + def visit_window_expression(self, window_expression: WindowCallExpression) -> None: + for arg in window_expression.inputs: + arg.accept(self) + for partition_arg in window_expression.partition_inputs: + partition_arg.accept(self) + for order_arg in window_expression.order_inputs: + order_arg.expr.accept(self) + + def visit_literal_expression(self, literal_expression: LiteralExpression) -> None: + pass + + def visit_column_reference(self, column_reference: ColumnReference) -> None: + pass + + def visit_correlated_reference(self, correlated_reference) -> None: + self._correlated_references.add(correlated_reference) diff --git a/pydough/relational/relational_expressions/relational_expression_shuttle.py b/pydough/relational/relational_expressions/relational_expression_shuttle.py index 354a4b8e..badc0b1d 100644 --- a/pydough/relational/relational_expressions/relational_expression_shuttle.py +++ b/pydough/relational/relational_expressions/relational_expression_shuttle.py @@ -75,3 +75,14 @@ def visit_column_reference(self, column_reference): Returns: RelationalExpression: The new node resulting from visiting this node. """ + + @abstractmethod + def visit_correlated_reference(self, correlated_reference): + """ + Visit a CorrelatedReference node. + + Args: + correlated_reference (CorrelatedReference): The correlated reference node to visit. + Returns: + RelationalExpression: The new node resulting from visiting this node. 
+ """ diff --git a/pydough/relational/relational_expressions/relational_expression_visitor.py b/pydough/relational/relational_expressions/relational_expression_visitor.py index 39746b1a..0873662a 100644 --- a/pydough/relational/relational_expressions/relational_expression_visitor.py +++ b/pydough/relational/relational_expressions/relational_expression_visitor.py @@ -61,3 +61,12 @@ def visit_column_reference(self, column_reference) -> None: Args: column_reference (ColumnReference): The column reference node to visit. """ + + @abstractmethod + def visit_correlated_reference(self, correlated_reference) -> None: + """ + Visit a CorrelatedReference node. + + Args: + correlated_reference (CorrelatedReference): The correlated reference node to visit. + """ diff --git a/pydough/relational/relational_nodes/column_pruner.py b/pydough/relational/relational_nodes/column_pruner.py index 8eee34c5..849bdf2b 100644 --- a/pydough/relational/relational_nodes/column_pruner.py +++ b/pydough/relational/relational_nodes/column_pruner.py @@ -5,10 +5,13 @@ from pydough.relational.relational_expressions import ( ColumnReference, ColumnReferenceFinder, + CorrelatedReference, + CorrelatedReferenceFinder, ) from .abstract_node import Relational from .aggregate import Aggregate +from .join import Join from .project import Project from .relational_expression_dispatcher import RelationalExpressionDispatcher from .relational_root import RelationalRoot @@ -19,11 +22,15 @@ class ColumnPruner: def __init__(self) -> None: self._column_finder: ColumnReferenceFinder = ColumnReferenceFinder() + self._correl_finder: CorrelatedReferenceFinder = CorrelatedReferenceFinder() # Note: We set recurse=False so we only check the expressions in the # current node. 
- self._dispatcher = RelationalExpressionDispatcher( + self._finder_dispatcher = RelationalExpressionDispatcher( self._column_finder, recurse=False ) + self._correl_dispatcher = RelationalExpressionDispatcher( + self._correl_finder, recurse=False + ) def _prune_identity_project(self, node: Relational) -> Relational: """ @@ -43,7 +50,7 @@ def _prune_identity_project(self, node: Relational) -> Relational: def _prune_node_columns( self, node: Relational, kept_columns: set[str] - ) -> Relational: + ) -> tuple[Relational, set[CorrelatedReference]]: """ Prune the columns for a subtree starting at this node. @@ -68,14 +75,17 @@ def _prune_node_columns( for name, expr in node.columns.items() if name in kept_columns or name in required_columns } + # Update the columns. new_node = node.copy(columns=columns) - self._dispatcher.reset() - # Visit the current identifiers. - new_node.accept(self._dispatcher) + + # Find all the identifiers referenced by the the current node. + self._finder_dispatcher.reset() + new_node.accept(self._finder_dispatcher) found_identifiers: set[ColumnReference] = ( self._column_finder.get_column_references() ) + # If the node is an aggregate but doesn't use any of the inputs # (e.g. a COUNT(*)), arbitrarily mark one of them as used. # TODO: (gh #196) optimize this functionality so it doesn't keep an @@ -88,19 +98,50 @@ def _prune_node_columns( node.input.columns[arbitrary_column_name].data_type, ) ) + # Determine which identifiers to pass to each input. new_inputs: list[Relational] = [] # Note: The ColumnPruner should only be run when all input names are # still present in the columns. - for i, default_input_name in enumerate(new_node.default_input_aliases): + # Iterate over the inputs in reverse order so that the source of + # correlated data is pruned last, since it will need to account for + # any correlated references in the later inputs. 
+ correl_refs: set[CorrelatedReference] = set() + for i, default_input_name in reversed( + list(enumerate(new_node.default_input_aliases)) + ): s: set[str] = set() + input_node: Relational = node.inputs[i] for identifier in found_identifiers: if identifier.input_name == default_input_name: s.add(identifier.name) - new_inputs.append(self._prune_node_columns(node.inputs[i], s)) + if ( + isinstance(new_node, Join) + and i == 0 + and new_node.correl_name is not None + ): + for correl_ref in correl_refs: + if correl_ref.correl_name == new_node.correl_name: + s.add(correl_ref.name) + new_input_node, new_correl_refs = self._prune_node_columns(input_node, s) + new_inputs.append(new_input_node) + if i == len(node.inputs) - 1: + correl_refs = new_correl_refs + else: + correl_refs.update(new_correl_refs) + new_inputs.reverse() + + # Find all the correlated references in the new node. + self._correl_dispatcher.reset() + new_node.accept(self._correl_dispatcher) + found_correl_refs: set[CorrelatedReference] = ( + self._correl_finder.get_correlated_references() + ) + correl_refs.update(found_correl_refs) + # Determine the new node. output = new_node.copy(inputs=new_inputs) - return self._prune_identity_project(output) + return self._prune_identity_project(output), correl_refs def prune_unused_columns(self, root: RelationalRoot) -> RelationalRoot: """ @@ -112,6 +153,6 @@ def prune_unused_columns(self, root: RelationalRoot) -> RelationalRoot: Returns: RelationalRoot: The root after updating all inputs. """ - new_root: Relational = self._prune_node_columns(root, set(root.columns.keys())) + new_root, _ = self._prune_node_columns(root, set(root.columns.keys())) assert isinstance(new_root, RelationalRoot), "Expected a root node." 
return new_root diff --git a/pydough/relational/relational_nodes/join.py b/pydough/relational/relational_nodes/join.py index 3114be15..1568e6ba 100644 --- a/pydough/relational/relational_nodes/join.py +++ b/pydough/relational/relational_nodes/join.py @@ -49,6 +49,7 @@ def __init__( conditions: list[RelationalExpression], join_types: list[JoinType], columns: MutableMapping[str, RelationalExpression], + correl_name: str | None = None, ) -> None: super().__init__(columns) num_inputs = len(inputs) @@ -65,6 +66,15 @@ def __init__( ), "Join condition must be a boolean type" self._conditions: list[RelationalExpression] = conditions self._join_types: list[JoinType] = join_types + self._correl_name: str | None = correl_name + + @property + def correl_name(self) -> str | None: + """ + The name used to refer to the first join input when subsequent inputs + have correlated references. + """ + return self._correl_name @property def conditions(self) -> list[RelationalExpression]: @@ -101,6 +111,7 @@ def node_equals(self, other: Relational) -> bool: isinstance(other, Join) and self.conditions == other.conditions and self.join_types == other.join_types + and self.correl_name == other.correl_name and all( self.inputs[i].node_equals(other.inputs[i]) for i in range(len(self.inputs)) @@ -109,7 +120,10 @@ def node_equals(self, other: Relational) -> bool: def to_string(self, compact: bool = False) -> str: conditions: list[str] = [cond.to_string(compact) for cond in self.conditions] - return f"JOIN(conditions=[{', '.join(conditions)}], types={[t.value for t in self.join_types]}, columns={self.make_column_string(self.columns, compact)})" + correl_suffix = ( + "" if self.correl_name is None else f", correl_name={self.correl_name!r}" + ) + return f"JOIN(conditions=[{', '.join(conditions)}], types={[t.value for t in self.join_types]}, columns={self.make_column_string(self.columns, compact)}{correl_suffix})" def accept(self, visitor: RelationalVisitor) -> None: visitor.visit_join(self) @@ 
-119,4 +133,4 @@ def node_copy( columns: MutableMapping[str, RelationalExpression], inputs: MutableSequence[Relational], ) -> Relational: - return Join(inputs, self.conditions, self.join_types, columns) + return Join(inputs, self.conditions, self.join_types, columns, self.correl_name) diff --git a/pydough/sqlglot/sqlglot_relational_expression_visitor.py b/pydough/sqlglot/sqlglot_relational_expression_visitor.py index 4ca111ba..8c03edd5 100644 --- a/pydough/sqlglot/sqlglot_relational_expression_visitor.py +++ b/pydough/sqlglot/sqlglot_relational_expression_visitor.py @@ -13,6 +13,7 @@ from pydough.relational import ( CallExpression, ColumnReference, + CorrelatedReference, LiteralExpression, RelationalExpression, RelationalExpressionVisitor, @@ -32,13 +33,17 @@ class SQLGlotRelationalExpressionVisitor(RelationalExpressionVisitor): """ def __init__( - self, dialect: SQLGlotDialect, bindings: SqlGlotTransformBindings + self, + dialect: SQLGlotDialect, + bindings: SqlGlotTransformBindings, + correlated_names: dict[str, str], ) -> None: # Keep a stack of SQLGlot expressions so we can build up # intermediate results. self._stack: list[SQLGlotExpression] = [] self._dialect: SQLGlotDialect = dialect self._bindings: SqlGlotTransformBindings = bindings + self._correlated_names: dict[str, str] = correlated_names def reset(self) -> None: """ @@ -133,18 +138,43 @@ def visit_literal_expression(self, literal_expression: LiteralExpression) -> Non ) self._stack.append(literal) + def visit_correlated_reference( + self, correlated_reference: CorrelatedReference + ) -> None: + full_name: str = f"{self._correlated_names[correlated_reference.correl_name]}.{correlated_reference.name}" + self._stack.append(Identifier(this=full_name)) + + # TODO: implement the column-based version of make_sqlglot_column, with table sources + # @staticmethod + # def make_sqlglot_column( + # column_reference: ColumnReference, + # ) -> Column: + # """ + # Convert a column reference to a SQLGlot column. 
This is split into a + # separate static method to ensure consistency across multiple visitors. + + # Args: + # column_reference (ColumnReference): The column reference to generate + # an identifier for. + + # Returns: + # Identifier: The output column reference containing an identifier. + # """ + # result: SQLGlotExpression = Column(this=Identifier(this=column_reference.name)) + # if column_reference.input_name is not None: + # result.set("table", Identifier(this=column_reference.input_name)) + # return result + @staticmethod - def generate_column_reference_identifier( + def make_sqlglot_column( column_reference: ColumnReference, ) -> Identifier: """ Generate an identifier for a column reference. This is split into a separate static method to ensure consistency across multiple visitors. - Args: column_reference (ColumnReference): The column reference to generate an identifier for. - Returns: Identifier: The output identifier. """ @@ -155,7 +185,7 @@ def generate_column_reference_identifier( return Identifier(this=full_name) def visit_column_reference(self, column_reference: ColumnReference) -> None: - self._stack.append(self.generate_column_reference_identifier(column_reference)) + self._stack.append(self.make_sqlglot_column(column_reference)) def relational_to_sqlglot( self, expr: RelationalExpression, output_name: str | None = None diff --git a/pydough/sqlglot/sqlglot_relational_visitor.py b/pydough/sqlglot/sqlglot_relational_visitor.py index 5f1e17ae..327bc3ef 100644 --- a/pydough/sqlglot/sqlglot_relational_visitor.py +++ b/pydough/sqlglot/sqlglot_relational_visitor.py @@ -9,6 +9,7 @@ from sqlglot.dialects import Dialect as SQLGlotDialect from sqlglot.expressions import Alias as SQLGlotAlias +from sqlglot.expressions import Column as SQLGlotColumn from sqlglot.expressions import Expression as SQLGlotExpression from sqlglot.expressions import Identifier, Select, Subquery, values from sqlglot.expressions import Literal as SQLGlotLiteral @@ -20,6 +21,7 @@ 
ColumnReference, ColumnReferenceInputNameModifier, ColumnReferenceInputNameRemover, + CorrelatedReference, EmptySingleton, ExpressionSortInfo, Filter, @@ -54,8 +56,11 @@ def __init__( # Keep a stack of SQLGlot expressions so we can build up # intermediate results. self._stack: list[Select] = [] + self._correlated_names: dict[str, str] = {} self._expr_visitor: SQLGlotRelationalExpressionVisitor = ( - SQLGlotRelationalExpressionVisitor(dialect, bindings) + SQLGlotRelationalExpressionVisitor( + dialect, bindings, self._correlated_names + ) ) self._alias_modifier: ColumnReferenceInputNameModifier = ( ColumnReferenceInputNameModifier() @@ -94,7 +99,7 @@ def _is_mergeable_column(expr: SQLGlotExpression) -> bool: if isinstance(expr, SQLGlotAlias): return SQLGlotRelationalVisitor._is_mergeable_column(expr.this) else: - return isinstance(expr, (SQLGlotLiteral, Identifier)) + return isinstance(expr, (SQLGlotLiteral, Identifier, SQLGlotColumn)) @staticmethod def _try_merge_columns( @@ -154,11 +159,22 @@ def _try_merge_columns( # If the new column is a literal, we can just add it to the old # columns. modified_old_columns.append(set_glot_alias(new_column, new_name)) - else: + elif isinstance(new_column, Identifier): expr = set_glot_alias(old_column_map[new_column.this], new_name) modified_old_columns.append(expr) if isinstance(expr, Identifier): seen_cols.add(expr) + elif isinstance(new_column, SQLGlotColumn): + expr = set_glot_alias( + old_column_map[new_column.this.this], new_name + ) + modified_old_columns.append(expr) + if isinstance(expr, Identifier): + seen_cols.add(expr) + else: + raise ValueError( + f"Unsupported expression type for column merging: {new_column.__class__.__name__}" + ) # Check that there are no missing dependencies in the old columns. 
if old_column_deps - seen_cols: return new_columns, old_columns @@ -291,7 +307,7 @@ def contains_window(self, exp: RelationalExpression) -> bool: match exp: case CallExpression(): return any(self.contains_window(arg) for arg in exp.inputs) - case ColumnReference() | LiteralExpression(): + case ColumnReference() | LiteralExpression() | CorrelatedReference(): return False case WindowCallExpression(): return True @@ -317,6 +333,12 @@ def visit_scan(self, scan: Scan) -> None: self._stack.append(query) def visit_join(self, join: Join) -> None: + alias_map: dict[str | None, str] = {} + if join.correl_name is not None: + input_name = join.default_input_aliases[0] + alias = self._generate_table_alias() + alias_map[input_name] = alias + self._correlated_names[join.correl_name] = alias self.visit_inputs(join) inputs: list[Select] = [self._stack.pop() for _ in range(len(join.inputs))] inputs.reverse() @@ -327,11 +349,12 @@ def visit_join(self, join: Join) -> None: seen_names[column] += 1 # Only keep duplicate names. 
kept_names = {key for key, value in seen_names.items() if value > 1} - alias_map = { - join.default_input_aliases[i]: self._generate_table_alias() - for i in range(len(join.inputs)) - if kept_names.intersection(join.inputs[i].columns.keys()) - } + for i in range(len(join.inputs)): + input_name = join.default_input_aliases[i] + if input_name not in alias_map and kept_names.intersection( + join.inputs[i].columns.keys() + ): + alias_map[input_name] = self._generate_table_alias() self._alias_remover.set_kept_names(kept_names) self._alias_modifier.set_map(alias_map) columns = { diff --git a/pydough/unqualified/qualification.py b/pydough/unqualified/qualification.py index 216c4011..ba33907d 100644 --- a/pydough/unqualified/qualification.py +++ b/pydough/unqualified/qualification.py @@ -371,7 +371,7 @@ def qualify_access( for _ in range(levels): if ancestor.ancestor_context is None: raise PyDoughUnqualifiedException( - f"Cannot back reference {levels} above {unqualified_parent}" + f"Cannot back reference {levels} above {context}" ) ancestor = ancestor.ancestor_context # Identify whether the access is an expression or a collection diff --git a/tests/simple_pydough_functions.py b/tests/simple_pydough_functions.py index 7ec5ba0e..e9f06919 100644 --- a/tests/simple_pydough_functions.py +++ b/tests/simple_pydough_functions.py @@ -270,3 +270,129 @@ def triple_partition(): )(supp_region, avg_percentage=AVG(cust_regions.percentage)).ORDER_BY( supp_region.ASC() ) + + +def correl_1(): + # Correlated back reference example #1: simple 1-step correlated reference + # For each region, count how many of its its nations start with the same + # letter as the region. This is a true correlated join doing an aggregated + # access without requiring the RHS be present. 
+ return Regions( + name, n_prefix_nations=COUNT(nations.WHERE(name[:1] == BACK(1).name[:1])) + ) + + +def correl_2(): + # Correlated back reference example #2: simple 2-step correlated reference + # For each region's nations, count how many customers have a comment + # starting with the same letter as the region. Exclude regions that start + # with the letter a. This is a true correlated join doing an aggregated + # access without requiring the RHS be present. + selected_custs = customers.WHERE(comment[:1] == LOWER(BACK(2).name[:1])) + return Regions.WHERE(~STARTSWITH(name, "A")).nations( + name, + n_selected_custs=COUNT(selected_custs), + ) + + +def correl_3(): + # Correlated back reference example #3: double-layer correlated reference + # For every every region, count how many of its nations have a customer + # whose comment starts with the same letter as the region. This is a true + # correlated join doing an aggregated access without requiring the RHS be + # present. + selected_custs = customers.WHERE(comment[:1] == LOWER(BACK(2).name[:1])) + return Regions(name, n_nations=COUNT(nations.WHERE(HAS(selected_custs)))) + + +def correl_4(): + # Correlated back reference example #4: 2-step correlated HASNOT + # Find every nation that does not have a customer whose account balance is + # within $5 of the smallest known account balance globally. 
+ # (This is a correlated ANTI-join) + selected_customers = customers.WHERE(acctbal <= (BACK(2).smallest_bal + 5.0)) + return ( + TPCH( + smallest_bal=MIN(Customers.acctbal), + ) + .Nations(name) + .WHERE(HASNOT(selected_customers)) + .ORDER_BY(name.ASC()) + ) + + +def correl_5(): + # Correlated back reference example #5: 2-step correlated HAS + # Find every region that has at least 1 supplier whose account balance is + # within $4 of the smallest known account balance globally + # (This is a correlated SEMI-join) + selected_suppliers = nations.suppliers.WHERE( + account_balance <= (BACK(3).smallest_bal + 4.0) + ) + return ( + TPCH( + smallest_bal=MIN(Suppliers.account_balance), + ) + .Regions(name) + .WHERE(HAS(selected_suppliers)) + .ORDER_BY(name.ASC()) + ) + + +def correl_6(): + # Correlated back reference example #6: simple 1-step correlated reference + # For each region, count how many of its its nations start with the same + # letter as the region, but only keep regions with at least one such nation. + # This is a true correlated join doing an aggregated access that does NOT + # require that records without the RHS be kept. + selected_nations = nations.WHERE(name[:1] == BACK(1).name[:1]) + return Regions.WHERE(HAS(selected_nations))( + name, n_prefix_nations=COUNT(selected_nations) + ) + + +def correl_7(): + # Correlated back reference example #6: deleted correlated reference + # For each region, count how many of its its nations start with the same + # letter as the region, but only keep regions without at least one such + # nation. The true correlated join is trumped by the correlated ANTI-join. 
+ selected_nations = nations.WHERE(name[:1] == BACK(1).name[:1]) + return Regions.WHERE(HASNOT(selected_nations))( + name, n_prefix_nations=COUNT(selected_nations) + ) + + +def correl_8(): + # Correlated back reference example #8: non-agg correlated reference + # For each nation, fetch the name of its region, but filter the reigon + # so it only keeps it if it starts with the same letter as the nation + # (otherwise, returns NULL). This is a true correlated join doing an + # access without aggregation without requiring the RHS be + # present. + aug_region = region.WHERE(name[:1] == BACK(1).name[:1]) + return Nations(name, rname=aug_region.name).ORDER_BY(name.ASC()) + + +def correl_9(): + # Correlated back reference example #9: non-agg correlated reference + # For each nation, fetch the name of its region, but filter the reigon + # so it only keeps it if it starts with the same letter as the nation + # (otherwise, omit the nation). This is a true correlated join doing an + # access that also requires the RHS records be present. + aug_region = region.WHERE(name[:1] == BACK(1).name[:1]) + return Nations.WHERE(HAS(aug_region))(name, rname=aug_region.name).ORDER_BY( + name.ASC() + ) + + +def correl_10(): + # Correlated back reference example #10: deleted correlated reference + # For each nation, fetch the name of its region, but filter the reigon + # so it only keeps it if it starts with the same letter as the nation + # (otherwise, returns NULL), and also filter the nations to only keep + # records where the region is NULL. The true correlated join is trumped by + # the correlated ANTI-join. 
+ aug_region = region.WHERE(name[:1] == BACK(1).name[:1]) + return Nations.WHERE(HASNOT(aug_region))(name, rname=aug_region.name).ORDER_BY( + name.ASC() + ) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 41605af3..3f693cac 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -14,6 +14,16 @@ ) from simple_pydough_functions import ( agg_partition, + correl_1, + correl_2, + correl_3, + correl_4, + correl_5, + correl_6, + correl_7, + correl_8, + correl_9, + correl_10, double_partition, function_sampler, percentile_customers_per_region, @@ -193,11 +203,31 @@ ( impl_tpch_q5, """ +ROOT(columns=[('N_NAME', N_NAME), ('REVENUE', REVENUE)], orderings=[(ordering_1):desc_last]) + PROJECT(columns={'N_NAME': N_NAME, 'REVENUE': REVENUE, 'ordering_1': REVENUE}) + PROJECT(columns={'N_NAME': name, 'REVENUE': DEFAULT_TO(agg_0, 0:int64)}) + JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr10') + FILTER(condition=name_3 == 'ASIA':string, columns={'key': key, 'name': name}) + JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': SUM(value)}) + PROJECT(columns={'nation_key': nation_key, 'value': extended_price * 1:int64 - discount}) + FILTER(condition=name_9 == corr10.name, columns={'discount': discount, 'extended_price': extended_price, 'nation_key': nation_key}) + JOIN(conditions=[t0.supplier_key == t1.key], types=['left'], columns={'discount': t0.discount, 'extended_price': t0.extended_price, 'name_9': t1.name_9, 'nation_key': t0.nation_key}) + JOIN(conditions=[t0.key_5 == t1.order_key], types=['inner'], columns={'discount': t1.discount, 'extended_price': t1.extended_price, 
'nation_key': t0.nation_key, 'supplier_key': t1.supplier_key}) + FILTER(condition=order_date >= datetime.date(1994, 1, 1):date & order_date < datetime.date(1995, 1, 1):date, columns={'key_5': key_5, 'nation_key': nation_key}) + JOIN(conditions=[t0.key == t1.customer_key], types=['inner'], columns={'key_5': t1.key, 'nation_key': t0.nation_key, 'order_date': t1.order_date}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate}) + SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'order_key': l_orderkey, 'supplier_key': l_suppkey}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_9': t1.name}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) """, tpch_q5_output, ), id="tpch_q5", - marks=pytest.mark.skip("TODO: support correlated back references"), ), pytest.param( ( @@ -548,21 +578,58 @@ ( impl_tpch_q21, """ +ROOT(columns=[('S_NAME', S_NAME), ('NUMWAIT', NUMWAIT)], orderings=[(ordering_1):desc_last, (ordering_2):asc_first]) + LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'NUMWAIT': NUMWAIT, 'S_NAME': S_NAME, 'ordering_1': ordering_1, 'ordering_2': ordering_2}, orderings=[(ordering_1):desc_last, (ordering_2):asc_first]) + PROJECT(columns={'NUMWAIT': NUMWAIT, 'S_NAME': S_NAME, 'ordering_1': NUMWAIT, 'ordering_2': S_NAME}) + PROJECT(columns={'NUMWAIT': DEFAULT_TO(agg_0, 0:int64), 'S_NAME': name}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}) + FILTER(condition=name_3 == 'SAUDI ARABIA':string, columns={'key': key, 'name': name}) + JOIN(conditions=[t0.nation_key == t1.key], types=['left'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name}) + 
SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=order_status == 'F':string & True:bool & True:bool, columns={'supplier_key': supplier_key}) + JOIN(conditions=[t0.key == t1.order_key], types=['anti'], columns={'order_status': t0.order_status, 'supplier_key': t0.supplier_key}, correl_name='corr6') + JOIN(conditions=[t0.key == t1.order_key], types=['semi'], columns={'key': t0.key, 'order_status': t0.order_status, 'supplier_key': t0.supplier_key}, correl_name='corr5') + JOIN(conditions=[t0.order_key == t1.key], types=['inner'], columns={'key': t1.key, 'order_status': t1.order_status, 'supplier_key': t0.supplier_key}) + FILTER(condition=receipt_date > commit_date, columns={'order_key': order_key, 'supplier_key': supplier_key}) + SCAN(table=tpch.LINEITEM, columns={'commit_date': l_commitdate, 'order_key': l_orderkey, 'receipt_date': l_receiptdate, 'supplier_key': l_suppkey}) + SCAN(table=tpch.ORDERS, columns={'key': o_orderkey, 'order_status': o_orderstatus}) + FILTER(condition=supplier_key != corr5.supplier_key, columns={'order_key': order_key}) + SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'supplier_key': l_suppkey}) + FILTER(condition=supplier_key != corr6.supplier_key & receipt_date > commit_date, columns={'order_key': order_key}) + SCAN(table=tpch.LINEITEM, columns={'commit_date': l_commitdate, 'order_key': l_orderkey, 'receipt_date': l_receiptdate, 'supplier_key': l_suppkey}) """, tpch_q21_output, ), id="tpch_q21", - marks=pytest.mark.skip("TODO: support correlated back references"), ), pytest.param( ( impl_tpch_q22, """ +ROOT(columns=[('CNTRY_CODE', CNTRY_CODE), ('NUM_CUSTS', NUM_CUSTS), ('TOTACCTBAL', TOTACCTBAL)], orderings=[]) + PROJECT(columns={'CNTRY_CODE': cntry_code, 'NUM_CUSTS': DEFAULT_TO(agg_1, 0:int64), 'TOTACCTBAL': 
DEFAULT_TO(agg_2, 0:int64)}) + JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1, 'agg_2': t1.agg_2, 'cntry_code': t1.cntry_code}, correl_name='corr1') + PROJECT(columns={'avg_balance': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(acctbal)}) + FILTER(condition=acctbal > 0.0:float64, columns={'acctbal': acctbal}) + FILTER(condition=ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17']:array[unknown]) & True:bool, columns={'acctbal': acctbal}) + JOIN(conditions=[t0.key == t1.customer_key], types=['anti'], columns={'acctbal': t0.acctbal, 'cntry_code': t0.cntry_code}) + PROJECT(columns={'acctbal': acctbal, 'cntry_code': SLICE(phone, None:unknown, 2:int64, None:unknown), 'key': key}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'key': c_custkey, 'phone': c_phone}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey}) + AGGREGATE(keys={'cntry_code': cntry_code}, aggregations={'agg_1': COUNT(), 'agg_2': SUM(acctbal)}) + FILTER(condition=acctbal > corr1.avg_balance, columns={'acctbal': acctbal, 'cntry_code': cntry_code}) + FILTER(condition=ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17']:array[unknown]) & True:bool, columns={'acctbal': acctbal, 'cntry_code': cntry_code}) + JOIN(conditions=[t0.key == t1.customer_key], types=['anti'], columns={'acctbal': t0.acctbal, 'cntry_code': t0.cntry_code}) + PROJECT(columns={'acctbal': acctbal, 'cntry_code': SLICE(phone, None:unknown, 2:int64, None:unknown), 'key': key}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'key': c_custkey, 'phone': c_phone}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey}) """, tpch_q22_output, ), id="tpch_q22", - marks=pytest.mark.skip("TODO: support correlated back references"), ), pytest.param( ( @@ -1062,6 +1129,284 @@ ), id="triple_partition", ), + pytest.param( + ( + correl_1, + """ +ROOT(columns=[('name', name), ('n_prefix_nations', n_prefix_nations)], orderings=[]) + 
PROJECT(columns={'n_prefix_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name}) + JOIN(conditions=[t0.key == t1.region_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr1') + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + AGGREGATE(keys={'region_key': region_key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'region_key': region_key}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) +""", + lambda: pd.DataFrame( + { + "name": ["AFRICA" "AMERICA" "MIDDLE EAST" "EUROPE" "ASIA"], + "n_prefix_nations": [1, 1, 0, 0, 0], + } + ), + ), + id="correl_1", + ), + pytest.param( + ( + correl_2, + """ +ROOT(columns=[('name', name), ('n_selected_custs', n_selected_custs)], orderings=[]) + PROJECT(columns={'n_selected_custs': DEFAULT_TO(agg_0, 0:int64), 'name': name_3}) + JOIN(conditions=[t0.key_2 == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'name_3': t0.name_3}, correl_name='corr4') + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name, 'name_3': t1.name}) + FILTER(condition=NOT(STARTSWITH(name, 'A':string)), columns={'key': key, 'name': name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=SLICE(comment, None:unknown, 1:int64, None:unknown) == LOWER(SLICE(corr4.name, None:unknown, 1:int64, None:unknown)), columns={'nation_key': nation_key}) + SCAN(table=tpch.CUSTOMER, columns={'comment': c_comment, 'nation_key': c_nationkey}) +""", + lambda: pd.DataFrame( + { + "name": ["A"] * 5, + } + ), + ), + id="correl_2", + ), + pytest.param( + ( + correl_3, + """ +ROOT(columns=[('name', name), 
('n_nations', n_nations)], orderings=[]) + PROJECT(columns={'n_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name}) + JOIN(conditions=[t0.key == t1.region_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr1') + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + AGGREGATE(keys={'region_key': region_key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=True:bool, columns={'region_key': region_key}) + JOIN(conditions=[t0.key == t1.nation_key], types=['semi'], columns={'region_key': t0.region_key}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + FILTER(condition=SLICE(comment, None:unknown, 1:int64, None:unknown) == LOWER(SLICE(corr1.name, None:unknown, 1:int64, None:unknown)), columns={'nation_key': nation_key}) + SCAN(table=tpch.CUSTOMER, columns={'comment': c_comment, 'nation_key': c_nationkey}) +""", + lambda: pd.DataFrame( + { + "name": ["A"] * 5, + } + ), + ), + id="correl_3", + ), + pytest.param( + ( + correl_4, + """ +ROOT(columns=[('name', name)], orderings=[(ordering_1):asc_first]) + PROJECT(columns={'name': name, 'ordering_1': name}) + FILTER(condition=True:bool, columns={'name': name}) + JOIN(conditions=[t0.key == t1.nation_key], types=['anti'], columns={'name': t0.name}, correl_name='corr1') + JOIN(conditions=[True:bool], types=['inner'], columns={'key': t1.key, 'name': t1.name, 'smallest_bal': t0.smallest_bal}) + PROJECT(columns={'smallest_bal': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': MIN(acctbal)}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + FILTER(condition=acctbal <= corr1.smallest_bal + 5.0:float64, columns={'nation_key': nation_key}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': c_nationkey}) +""", + lambda: pd.DataFrame( + { + "name": ["ARGENTINA", "KENYA", "UNITED KINGDOM"], + } + ), + ), + id="correl_4", + ), + pytest.param( 
+ ( + correl_5, + """ +ROOT(columns=[('name', name)], orderings=[(ordering_1):asc_first]) + PROJECT(columns={'name': name, 'ordering_1': name}) + FILTER(condition=True:bool, columns={'name': name}) + JOIN(conditions=[t0.key == t1.region_key], types=['semi'], columns={'name': t0.name}, correl_name='corr4') + JOIN(conditions=[True:bool], types=['inner'], columns={'key': t1.key, 'name': t1.name, 'smallest_bal': t0.smallest_bal}) + PROJECT(columns={'smallest_bal': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': MIN(account_balance)}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + FILTER(condition=account_balance <= corr4.smallest_bal + 4.0:float64, columns={'region_key': region_key}) + JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'account_balance': t1.account_balance, 'region_key': t0.region_key}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'nation_key': s_nationkey}) +""", + lambda: pd.DataFrame( + { + "name": ["AFRICA", "ASIA", "MIDDLE EAST"], + } + ), + ), + id="correl_5", + ), + pytest.param( + ( + correl_6, + """ +ROOT(columns=[('name', name), ('n_prefix_nations', n_prefix_nations)], orderings=[]) + PROJECT(columns={'n_prefix_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name}) + FILTER(condition=True:bool, columns={'agg_0': agg_0, 'name': name}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr1') + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + AGGREGATE(keys={'region_key': region_key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'region_key': region_key}) + SCAN(table=tpch.NATION, columns={'name': n_name, 
'region_key': n_regionkey}) +""", + lambda: pd.DataFrame( + { + "name": ["AFRICA", "AMERICA"], + "n_prefix_nations": [1, 1], + } + ), + ), + id="correl_6", + ), + pytest.param( + ( + correl_7, + """ +ROOT(columns=[('name', name), ('n_prefix_nations', n_prefix_nations)], orderings=[]) + PROJECT(columns={'n_prefix_nations': DEFAULT_TO(NULL_2, 0:int64), 'name': name}) + FILTER(condition=True:bool, columns={'NULL_2': NULL_2, 'name': name}) + JOIN(conditions=[t0.key == t1.region_key], types=['anti'], columns={'NULL_2': None:unknown, 'name': t0.name}, correl_name='corr1') + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'region_key': region_key}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) +""", + lambda: pd.DataFrame( + { + "name": ["ASIA", "EUROPE", "MIDDLE EAST"], + "n_prefix_nations": [0] * 3, + } + ), + ), + id="correl_7", + ), + pytest.param( + ( + correl_8, + """ +ROOT(columns=[('name', name), ('rname', rname)], orderings=[(ordering_0):asc_first]) + PROJECT(columns={'name': name, 'ordering_0': name, 'rname': rname}) + PROJECT(columns={'name': name, 'rname': name_4}) + JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'name': t0.name, 'name_4': t1.name}, correl_name='corr1') + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) + FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'key': key, 'name': name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) +""", + lambda: pd.DataFrame( + { + "name": [ + "ALGERIA", + "ARGENTINA", + "BRAZIL", + "CANADA", + "CHINA", + "EGYPT", + "ETHIOPIA", + "FRANCE", + "GERMANY", + "INDIA", + "INDONESIA", + "IRAN", + "IRAQ", + "JAPAN", + "JORDAN", + "KENYA", + "MOROCCO", + "MOZAMBIQUE", + "PERU", + 
"ROMANIA", + "RUSSIA", + "SAUDI ARABIA", + "UNITED KINGDOM", + "UNITED STATES", + "VIETNAM", + ], + "rname": ["AFRICA", "AMERICA"] + [None] * 23, + } + ), + ), + id="correl_8", + ), + pytest.param( + ( + correl_9, + """ +ROOT(columns=[('name', name), ('rname', rname)], orderings=[(ordering_0):asc_first]) + PROJECT(columns={'name': name, 'ordering_0': name, 'rname': rname}) + PROJECT(columns={'name': name, 'rname': name_4}) + FILTER(condition=True:bool, columns={'name': name, 'name_4': name_4}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'name': t0.name, 'name_4': t1.name}, correl_name='corr1') + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) + FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'key': key, 'name': name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) +""", + lambda: pd.DataFrame( + { + "name": [ + "ALGERIA", + "ARGENTINA", + ], + "rname": ["AFRICA", "AMERICA"], + } + ), + ), + id="correl_9", + ), + pytest.param( + ( + correl_10, + """ +ROOT(columns=[('name', name), ('rname', rname)], orderings=[(ordering_0):asc_first]) + PROJECT(columns={'name': name, 'ordering_0': name, 'rname': rname}) + PROJECT(columns={'name': name, 'rname': NULL_2}) + FILTER(condition=True:bool, columns={'NULL_2': NULL_2, 'name': name}) + JOIN(conditions=[t0.region_key == t1.key], types=['anti'], columns={'NULL_2': None:unknown, 'name': t0.name}, correl_name='corr1') + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) + FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'key': key}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) +""", + lambda: pd.DataFrame( + { + "name": [ + "BRAZIL", + "CANADA", + "CHINA", + "EGYPT", + "ETHIOPIA", + "FRANCE", + "GERMANY", + "INDIA", + "INDONESIA", 
+ "IRAN", + "IRAQ", + "JAPAN", + "JORDAN", + "KENYA", + "MOROCCO", + "MOZAMBIQUE", + "PERU", + "ROMANIA", + "RUSSIA", + "SAUDI ARABIA", + "UNITED KINGDOM", + "UNITED STATES", + "VIETNAM", + ], + "rname": [None] * 23, + } + ), + ), + id="correl_10", + ), ], ) def pydough_pipeline_test_data( diff --git a/tests/tpch_outputs.py b/tests/tpch_outputs.py index c875d630..80948acb 100644 --- a/tests/tpch_outputs.py +++ b/tests/tpch_outputs.py @@ -692,7 +692,7 @@ def tpch_q21_output() -> pd.DataFrame: Expected output for TPC-H query 21. Note: This is truncated to the first 10 rows. """ - columns = ["S_NAME", "NUM_WAIT"] + columns = ["S_NAME", "NUMWAIT"] data = [ ("Supplier#000002829", 20), ("Supplier#000005808", 18), From 4c54d571fa5da07baf4f432ddfa5e503f86d724e Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Tue, 28 Jan 2025 11:59:03 -0500 Subject: [PATCH 070/112] Pulling out SQLGlot changes into followup PR --- .../sqlglot_relational_expression_visitor.py | 36 +++------------- pydough/sqlglot/sqlglot_relational_visitor.py | 41 ++++--------------- 2 files changed, 15 insertions(+), 62 deletions(-) diff --git a/pydough/sqlglot/sqlglot_relational_expression_visitor.py b/pydough/sqlglot/sqlglot_relational_expression_visitor.py index 8c03edd5..017ebe60 100644 --- a/pydough/sqlglot/sqlglot_relational_expression_visitor.py +++ b/pydough/sqlglot/sqlglot_relational_expression_visitor.py @@ -33,17 +33,13 @@ class SQLGlotRelationalExpressionVisitor(RelationalExpressionVisitor): """ def __init__( - self, - dialect: SQLGlotDialect, - bindings: SqlGlotTransformBindings, - correlated_names: dict[str, str], + self, dialect: SQLGlotDialect, bindings: SqlGlotTransformBindings ) -> None: # Keep a stack of SQLGlot expressions so we can build up # intermediate results. 
self._stack: list[SQLGlotExpression] = [] self._dialect: SQLGlotDialect = dialect self._bindings: SqlGlotTransformBindings = bindings - self._correlated_names: dict[str, str] = correlated_names def reset(self) -> None: """ @@ -141,40 +137,20 @@ def visit_literal_expression(self, literal_expression: LiteralExpression) -> Non def visit_correlated_reference( self, correlated_reference: CorrelatedReference ) -> None: - full_name: str = f"{self._correlated_names[correlated_reference.correl_name]}.{correlated_reference.name}" - self._stack.append(Identifier(this=full_name)) - - # TODO: implement the column-based version of make_sqlglot_column, with table sources - # @staticmethod - # def make_sqlglot_column( - # column_reference: ColumnReference, - # ) -> Column: - # """ - # Convert a column reference to a SQLGlot column. This is split into a - # separate static method to ensure consistency across multiple visitors. - - # Args: - # column_reference (ColumnReference): The column reference to generate - # an identifier for. - - # Returns: - # Identifier: The output column reference containing an identifier. - # """ - # result: SQLGlotExpression = Column(this=Identifier(this=column_reference.name)) - # if column_reference.input_name is not None: - # result.set("table", Identifier(this=column_reference.input_name)) - # return result + raise NotImplementedError("TODO") @staticmethod - def make_sqlglot_column( + def generate_column_reference_identifier( column_reference: ColumnReference, ) -> Identifier: """ Generate an identifier for a column reference. This is split into a separate static method to ensure consistency across multiple visitors. + Args: column_reference (ColumnReference): The column reference to generate an identifier for. + Returns: Identifier: The output identifier. 
""" @@ -185,7 +161,7 @@ def make_sqlglot_column( return Identifier(this=full_name) def visit_column_reference(self, column_reference: ColumnReference) -> None: - self._stack.append(self.make_sqlglot_column(column_reference)) + self._stack.append(self.generate_column_reference_identifier(column_reference)) def relational_to_sqlglot( self, expr: RelationalExpression, output_name: str | None = None diff --git a/pydough/sqlglot/sqlglot_relational_visitor.py b/pydough/sqlglot/sqlglot_relational_visitor.py index 327bc3ef..5f1e17ae 100644 --- a/pydough/sqlglot/sqlglot_relational_visitor.py +++ b/pydough/sqlglot/sqlglot_relational_visitor.py @@ -9,7 +9,6 @@ from sqlglot.dialects import Dialect as SQLGlotDialect from sqlglot.expressions import Alias as SQLGlotAlias -from sqlglot.expressions import Column as SQLGlotColumn from sqlglot.expressions import Expression as SQLGlotExpression from sqlglot.expressions import Identifier, Select, Subquery, values from sqlglot.expressions import Literal as SQLGlotLiteral @@ -21,7 +20,6 @@ ColumnReference, ColumnReferenceInputNameModifier, ColumnReferenceInputNameRemover, - CorrelatedReference, EmptySingleton, ExpressionSortInfo, Filter, @@ -56,11 +54,8 @@ def __init__( # Keep a stack of SQLGlot expressions so we can build up # intermediate results. 
self._stack: list[Select] = [] - self._correlated_names: dict[str, str] = {} self._expr_visitor: SQLGlotRelationalExpressionVisitor = ( - SQLGlotRelationalExpressionVisitor( - dialect, bindings, self._correlated_names - ) + SQLGlotRelationalExpressionVisitor(dialect, bindings) ) self._alias_modifier: ColumnReferenceInputNameModifier = ( ColumnReferenceInputNameModifier() @@ -99,7 +94,7 @@ def _is_mergeable_column(expr: SQLGlotExpression) -> bool: if isinstance(expr, SQLGlotAlias): return SQLGlotRelationalVisitor._is_mergeable_column(expr.this) else: - return isinstance(expr, (SQLGlotLiteral, Identifier, SQLGlotColumn)) + return isinstance(expr, (SQLGlotLiteral, Identifier)) @staticmethod def _try_merge_columns( @@ -159,22 +154,11 @@ def _try_merge_columns( # If the new column is a literal, we can just add it to the old # columns. modified_old_columns.append(set_glot_alias(new_column, new_name)) - elif isinstance(new_column, Identifier): + else: expr = set_glot_alias(old_column_map[new_column.this], new_name) modified_old_columns.append(expr) if isinstance(expr, Identifier): seen_cols.add(expr) - elif isinstance(new_column, SQLGlotColumn): - expr = set_glot_alias( - old_column_map[new_column.this.this], new_name - ) - modified_old_columns.append(expr) - if isinstance(expr, Identifier): - seen_cols.add(expr) - else: - raise ValueError( - f"Unsupported expression type for column merging: {new_column.__class__.__name__}" - ) # Check that there are no missing dependencies in the old columns. 
if old_column_deps - seen_cols: return new_columns, old_columns @@ -307,7 +291,7 @@ def contains_window(self, exp: RelationalExpression) -> bool: match exp: case CallExpression(): return any(self.contains_window(arg) for arg in exp.inputs) - case ColumnReference() | LiteralExpression() | CorrelatedReference(): + case ColumnReference() | LiteralExpression(): return False case WindowCallExpression(): return True @@ -333,12 +317,6 @@ def visit_scan(self, scan: Scan) -> None: self._stack.append(query) def visit_join(self, join: Join) -> None: - alias_map: dict[str | None, str] = {} - if join.correl_name is not None: - input_name = join.default_input_aliases[0] - alias = self._generate_table_alias() - alias_map[input_name] = alias - self._correlated_names[join.correl_name] = alias self.visit_inputs(join) inputs: list[Select] = [self._stack.pop() for _ in range(len(join.inputs))] inputs.reverse() @@ -349,12 +327,11 @@ def visit_join(self, join: Join) -> None: seen_names[column] += 1 # Only keep duplicate names. 
kept_names = {key for key, value in seen_names.items() if value > 1} - for i in range(len(join.inputs)): - input_name = join.default_input_aliases[i] - if input_name not in alias_map and kept_names.intersection( - join.inputs[i].columns.keys() - ): - alias_map[input_name] = self._generate_table_alias() + alias_map = { + join.default_input_aliases[i]: self._generate_table_alias() + for i in range(len(join.inputs)) + if kept_names.intersection(join.inputs[i].columns.keys()) + } self._alias_remover.set_kept_names(kept_names) self._alias_modifier.set_map(alias_map) columns = { From b519cd0b923ff7a0be3d19d98481ecc43a796c3e Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Thu, 30 Jan 2025 14:04:13 -0500 Subject: [PATCH 071/112] Added two more correlated backref edge cases --- pydough/conversion/hybrid_tree.py | 4 ++- tests/simple_pydough_functions.py | 32 +++++++++++++++++- tests/test_pipeline.py | 56 +++++++++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 2 deletions(-) diff --git a/pydough/conversion/hybrid_tree.py b/pydough/conversion/hybrid_tree.py index 732fbb1c..e49ea82e 100644 --- a/pydough/conversion/hybrid_tree.py +++ b/pydough/conversion/hybrid_tree.py @@ -1922,7 +1922,9 @@ def make_hybrid_tree( successor_hybrid = self.make_hybrid_tree( node.child_access.child_access, parent ) - partition_by = node.child_access.ancestor_context + partition_by = ( + node.child_access.ancestor_context.starting_predecessor + ) assert isinstance(partition_by, PartitionBy) for key in partition_by.keys: rhs_expr: HybridExpr = self.make_hybrid_expr( diff --git a/tests/simple_pydough_functions.py b/tests/simple_pydough_functions.py index e9f06919..f7516464 100644 --- a/tests/simple_pydough_functions.py +++ b/tests/simple_pydough_functions.py @@ -324,7 +324,7 @@ def correl_4(): def correl_5(): # Correlated back reference example #5: 2-step correlated HAS # Find every region that has at least 1 supplier whose account balance is - # within $4 of the smallest known account 
balance globally + # within $4 of the smallest known account balance globally. # (This is a correlated SEMI-join) selected_suppliers = nations.suppliers.WHERE( account_balance <= (BACK(3).smallest_bal + 4.0) @@ -396,3 +396,33 @@ def correl_10(): return Nations.WHERE(HASNOT(aug_region))(name, rname=aug_region.name).ORDER_BY( name.ASC() ) + + +def correl_11(): + # Correlated back reference example #11: backref out of partition child. + # Which part brands have at least 1 part that more than 40% above the + # average retail price for all parts from that brand. + # (This is a correlated SEMI-join) + brands = PARTITION(Parts, name="p", by=brand)(avg_price=AVG(p.retail_price)) + outlier_parts = p.WHERE(retail_price > 1.4 * BACK(1).avg_price) + selected_brands = brands.WHERE(HAS(outlier_parts)) + return selected_brands(brand).ORDER_BY(brand.ASC()) + + +def correl_12(): + # Correlated back reference example #12: backref out of partition child. + # Which part brands have at least 1 part that is above the average retail + # price for parts of that brand, below the average retail price for all + # parts, and has a size below 3. 
+ # (This is a correlated SEMI-join) + global_info = TPCH(avg_price=AVG(Parts.retail_price)) + brands = global_info.PARTITION(Parts, name="p", by=brand)( + avg_price=AVG(p.retail_price) + ) + selected_parts = p.WHERE( + (retail_price > BACK(1).avg_price) + & (retail_price < BACK(2).avg_price) + & (size < 3) + ) + selected_brands = brands.WHERE(HAS(selected_parts)) + return selected_brands(brand).ORDER_BY(brand.ASC()) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 3f693cac..d16652c0 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -24,6 +24,8 @@ correl_8, correl_9, correl_10, + correl_11, + correl_12, double_partition, function_sampler, percentile_customers_per_region, @@ -1407,6 +1409,60 @@ ), id="correl_10", ), + pytest.param( + ( + correl_11, + """ +ROOT(columns=[('brand', brand)], orderings=[(ordering_1):asc_first]) + PROJECT(columns={'brand': brand, 'ordering_1': brand}) + FILTER(condition=True:bool, columns={'brand': brand}) + JOIN(conditions=[t0.brand == t1.brand], types=['semi'], columns={'brand': t0.brand}, correl_name='corr1') + PROJECT(columns={'avg_price': agg_0, 'brand': brand}) + AGGREGATE(keys={'brand': brand}, aggregations={'agg_0': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'brand': p_brand, 'retail_price': p_retailprice}) + FILTER(condition=retail_price > 1.4:float64 * corr1.avg_price, columns={'brand': brand}) + SCAN(table=tpch.PART, columns={'brand': p_brand, 'retail_price': p_retailprice}) +""", + lambda: pd.DataFrame( + {"brand": ["Brand#33", "Brand#43", "Brand#45", "Brand#55"]} + ), + ), + id="correl_11", + ), + pytest.param( + ( + correl_12, + # IMPORTANT: FIX BUG WITH THE TWO `avg_price` + """ +ROOT(columns=[('brand', brand)], orderings=[(ordering_2):asc_first]) + PROJECT(columns={'brand': brand, 'ordering_2': brand}) + FILTER(condition=True:bool, columns={'brand': brand}) + JOIN(conditions=[t0.brand == t1.brand], types=['semi'], columns={'brand': t0.brand}, correl_name='corr1') + 
PROJECT(columns={'avg_price': agg_1, 'brand': brand}) + JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1, 'brand': t1.brand}) + AGGREGATE(keys={}, aggregations={}) + SCAN(table=tpch.PART, columns={'brand': p_brand}) + AGGREGATE(keys={'brand': brand}, aggregations={'agg_1': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'brand': p_brand, 'retail_price': p_retailprice}) + FILTER(condition=retail_price > corr1.avg_price & retail_price < corr1.avg_price & size < 3:int64, columns={'brand': brand}) + SCAN(table=tpch.PART, columns={'brand': p_brand, 'retail_price': p_retailprice, 'size': p_size}) +""", + lambda: pd.DataFrame( + { + "brand": [ + "Brand#14", + "Brand#31", + "Brand#33", + "Brand#33", + "Brand#43", + "Brand#43", + "Brand#55", + ] + } + ), + ), + id="correl_12", + ), ], ) def pydough_pipeline_test_data( From 4033936ddaa68910386d9cc543edf8df6b057bb4 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Thu, 30 Jan 2025 23:25:10 -0500 Subject: [PATCH 072/112] Fixing backref name conflict bug --- pydough/conversion/relational_converter.py | 14 +++++++++++ tests/test_pipeline.py | 28 ++++++++++++---------- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index c725971f..b540aeb3 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -135,6 +135,18 @@ def make_null_column(self, relation: Relational) -> ColumnReference: relation.columns[name] = LiteralExpression(None, UnknownType()) return ColumnReference(name, UnknownType()) + def get_column_name( + self, name: str, existing_names: dict[str, RelationalExpression] + ) -> str: + """ + TODO + """ + new_name: str = name + while new_name in existing_names: + self.dummy_idx += 1 + new_name = f"{name}_{self.dummy_idx}" + return new_name + def get_correlated_name(self, context: TranslationOutput) -> str: """ Finds the name used to refer to a 
context for correlated variable @@ -742,6 +754,8 @@ def translate_calc( rel_expr: RelationalExpression = self.translate_expression( hybrid_expr, context ) + if name in proj_columns and proj_columns[name] != rel_expr: + name = self.get_column_name(name, proj_columns) proj_columns[name] = rel_expr out_columns[ref_expr] = ColumnReference(name, rel_expr.data_type) out_rel: Project = Project(context.relation, proj_columns) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index d16652c0..24491f9c 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1432,20 +1432,22 @@ pytest.param( ( correl_12, - # IMPORTANT: FIX BUG WITH THE TWO `avg_price` """ -ROOT(columns=[('brand', brand)], orderings=[(ordering_2):asc_first]) - PROJECT(columns={'brand': brand, 'ordering_2': brand}) - FILTER(condition=True:bool, columns={'brand': brand}) - JOIN(conditions=[t0.brand == t1.brand], types=['semi'], columns={'brand': t0.brand}, correl_name='corr1') - PROJECT(columns={'avg_price': agg_1, 'brand': brand}) - JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1, 'brand': t1.brand}) - AGGREGATE(keys={}, aggregations={}) - SCAN(table=tpch.PART, columns={'brand': p_brand}) - AGGREGATE(keys={'brand': brand}, aggregations={'agg_1': AVG(retail_price)}) - SCAN(table=tpch.PART, columns={'brand': p_brand, 'retail_price': p_retailprice}) - FILTER(condition=retail_price > corr1.avg_price & retail_price < corr1.avg_price & size < 3:int64, columns={'brand': brand}) - SCAN(table=tpch.PART, columns={'brand': p_brand, 'retail_price': p_retailprice, 'size': p_size}) +ROOT(columns=[('brand', brand_5)], orderings=[(ordering_2):asc_first]) + PROJECT(columns={'brand_5': brand_4, 'ordering_2': ordering_2}) + PROJECT(columns={'brand_4': brand_4, 'ordering_2': brand_4}) + PROJECT(columns={'brand_4': brand}) + FILTER(condition=True:bool, columns={'brand': brand}) + JOIN(conditions=[t0.brand == t1.brand], types=['semi'], columns={'brand': t0.brand}, correl_name='corr2') + 
PROJECT(columns={'avg_price': avg_price, 'avg_price_2': agg_1, 'brand': brand}) + JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1, 'avg_price': t0.avg_price, 'brand': t1.brand}) + PROJECT(columns={'avg_price': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'retail_price': p_retailprice}) + AGGREGATE(keys={'brand': brand}, aggregations={'agg_1': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'brand': p_brand, 'retail_price': p_retailprice}) + FILTER(condition=retail_price > corr2.avg_price_2 & retail_price < corr2.avg_price & size < 3:int64, columns={'brand': brand}) + SCAN(table=tpch.PART, columns={'brand': p_brand, 'retail_price': p_retailprice, 'size': p_size}) """, lambda: pd.DataFrame( { From 8fcc6092959589790de448a61eba823eeb14e394 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Thu, 30 Jan 2025 23:28:39 -0500 Subject: [PATCH 073/112] Confirmed all correl queries except 1, 2, 3, 6, 8 and 9 are working --- tests/test_pipeline.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 24491f9c..21a40c9c 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1455,8 +1455,6 @@ "Brand#14", "Brand#31", "Brand#33", - "Brand#33", - "Brand#43", "Brand#43", "Brand#55", ] From 8a8e49f2c38870d09200faba27c9df3b79ee73d6 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Fri, 31 Jan 2025 14:00:30 -0500 Subject: [PATCH 074/112] Added multi-correlate example --- tests/simple_pydough_functions.py | 19 ++++++++++++++++++ tests/test_pipeline.py | 33 +++++++++++++++++++++++++++++-- 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/tests/simple_pydough_functions.py b/tests/simple_pydough_functions.py index f7516464..648b41d2 100644 --- a/tests/simple_pydough_functions.py +++ b/tests/simple_pydough_functions.py @@ -426,3 +426,22 @@ def correl_12(): ) selected_brands = brands.WHERE(HAS(selected_parts)) return 
selected_brands(brand).ORDER_BY(brand.ASC()) + + +def correl_13(): + # Correlated back reference example #13: multiple correlation. + # Count how many suppliers sell at least one part where the retail price + # is less than a 50% markup over the supply cost, and the retail price of + # the part is below the average of the retail price for all parts globally + # and the average for all parts from the supplier. + # (This is multiple correlated SEMI-joins) + selected_part = part.WHERE( + (retail_price < (BACK(1).supplycost * 1.5)) + & (retail_price < BACK(2).avg_price) + & (retail_price < BACK(3).avg_price) + ) + selected_supply_records = supply_records.WHERE(HAS(selected_part)) + supplier_info = Suppliers(avg_price=AVG(supply_records.part.retail_price)) + selected_suppliers = supplier_info.WHERE(HAS(selected_supply_records)) + global_info = TPCH(avg_price=AVG(Parts.retail_price)) + return global_info(n=COUNT(selected_suppliers)) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 24491f9c..a34dd7d7 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -26,6 +26,7 @@ correl_10, correl_11, correl_12, + correl_13, double_partition, function_sampler, percentile_customers_per_region, @@ -1455,8 +1456,6 @@ "Brand#14", "Brand#31", "Brand#33", - "Brand#33", - "Brand#43", "Brand#43", "Brand#55", ] @@ -1465,6 +1464,36 @@ ), id="correl_12", ), + pytest.param( + ( + correl_13, + """ +ROOT(columns=[('n', n)], orderings=[]) + PROJECT(columns={'n': agg_1}) + JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1}, correl_name='corr4') + PROJECT(columns={'avg_price': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'retail_price': p_retailprice}) + AGGREGATE(keys={}, aggregations={'agg_1': COUNT()}) + FILTER(condition=True:bool, columns={'account_balance': account_balance}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['semi'], columns={'account_balance': 
t0.account_balance}, correl_name='corr3') + PROJECT(columns={'account_balance': account_balance, 'avg_price': agg_0, 'key': key}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'agg_0': t1.agg_0, 'key': t0.key}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': AVG(retail_price)}) + JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'retail_price': t1.retail_price, 'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.PART, columns={'key': p_partkey, 'retail_price': p_retailprice}) + FILTER(condition=True:bool, columns={'supplier_key': supplier_key}) + JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr2') + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) + FILTER(condition=retail_price < corr2.supplycost * 1.5:float64 & retail_price < corr3.avg_price & retail_price < corr4.avg_price, columns={'key': key}) + SCAN(table=tpch.PART, columns={'key': p_partkey, 'retail_price': p_retailprice}) +""", + lambda: pd.DataFrame({"n": [516]}), + ), + id="correl_13", + ), ], ) def pydough_pipeline_test_data( From 9427ccec5bd0cef1857e6e81f6f06bde4c9b4c6a Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 3 Feb 2025 10:28:28 -0500 Subject: [PATCH 075/112] Finished adding/refining complex correlation tests --- tests/simple_pydough_functions.py | 48 ++++++++++++++++++--- tests/test_pipeline.py | 71 ++++++++++++++++++++++++++++--- 2 files changed, 109 insertions(+), 10 deletions(-) diff --git a/tests/simple_pydough_functions.py b/tests/simple_pydough_functions.py index 648b41d2..bc9d20f7 100644 --- a/tests/simple_pydough_functions.py +++ b/tests/simple_pydough_functions.py @@ 
-431,17 +431,55 @@ def correl_12(): def correl_13(): # Correlated back reference example #13: multiple correlation. # Count how many suppliers sell at least one part where the retail price + # is less than a 50% markup over the supply cost. Only considers suppliers + # from nations #1/#2/#3, and small parts. + # (This is a correlated SEMI-joins) + selected_part = part.WHERE(STARTSWITH(container, "SM")).WHERE( + retail_price < (BACK(1).supplycost * 1.5) + ) + selected_supply_records = supply_records.WHERE(HAS(selected_part)) + supplier_info = Suppliers.WHERE(nation_key <= 3)( + avg_price=AVG(supply_records.part.retail_price) + ) + selected_suppliers = supplier_info.WHERE(COUNT(selected_supply_records) > 0) + return TPCH(n=COUNT(selected_suppliers)) + + +def correl_14(): + # Correlated back reference example #14: multiple correlation. + # Count how many suppliers sell at least one part where the retail price # is less than a 50% markup over the supply cost, and the retail price of - # the part is below the average of the retail price for all parts globally - # and the average for all parts from the supplier. + # the part is below the average for all parts from the supplier. Only + # considers suppliers from nations #19, and LG DRUM parts. # (This is multiple correlated SEMI-joins) - selected_part = part.WHERE( + selected_part = part.WHERE(container == "LG DRUM").WHERE( + (retail_price < (BACK(1).supplycost * 1.5)) & (retail_price < BACK(2).avg_price) + ) + selected_supply_records = supply_records.WHERE(HAS(selected_part)) + supplier_info = Suppliers.WHERE(nation_key == 19)( + avg_price=AVG(supply_records.part.retail_price) + ) + selected_suppliers = supplier_info.WHERE(HAS(selected_supply_records)) + return TPCH(n=COUNT(selected_suppliers)) + + +def correl_15(): + # Correlated back reference example #15: multiple correlation. 
+ # Count how many suppliers sell at least one part where the retail price + # is less than a 50% markup over the supply cost, and the retail price of + # the part is below the 90% of the average of the retail price for all + # parts globally and below the average for all parts from the supplier. + # Only considers suppliers from nations #19, and LG DRUM parts. + # (This is multiple correlated SEMI-joins & a correlated aggregate) + selected_part = part.WHERE(container == "LG DRUM").WHERE( (retail_price < (BACK(1).supplycost * 1.5)) & (retail_price < BACK(2).avg_price) - & (retail_price < BACK(3).avg_price) + & (retail_price < BACK(3).avg_price * 0.9) ) selected_supply_records = supply_records.WHERE(HAS(selected_part)) - supplier_info = Suppliers(avg_price=AVG(supply_records.part.retail_price)) + supplier_info = Suppliers.WHERE(nation_key == 19)( + avg_price=AVG(supply_records.part.retail_price) + ) selected_suppliers = supplier_info.WHERE(HAS(selected_supply_records)) global_info = TPCH(avg_price=AVG(Parts.retail_price)) return global_info(n=COUNT(selected_suppliers)) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index a34dd7d7..2e098a45 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -27,6 +27,8 @@ correl_11, correl_12, correl_13, + correl_14, + correl_15, double_partition, function_sampler, percentile_customers_per_region, @@ -1468,6 +1470,62 @@ ( correl_13, """ +ROOT(columns=[('n', n)], orderings=[]) + PROJECT(columns={'n': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': COUNT()}) + FILTER(condition=DEFAULT_TO(agg_1, 0:int64) > 0:int64, columns={'account_balance': account_balance}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'agg_1': t1.agg_1}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'key': t0.key}) + FILTER(condition=nation_key <= 3:int64, columns={'account_balance': account_balance, 
'key': key}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={}) + JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.PART, columns={'key': p_partkey}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_1': COUNT()}) + FILTER(condition=True:bool, columns={'supplier_key': supplier_key}) + JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr2') + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) + FILTER(condition=retail_price < corr2.supplycost * 1.5:float64, columns={'key': key}) + FILTER(condition=STARTSWITH(container, 'SM':string), columns={'key': key, 'retail_price': retail_price}) + SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'retail_price': p_retailprice}) +""", + lambda: pd.DataFrame({"n": [1129]}), + ), + id="correl_13", + ), + pytest.param( + ( + correl_14, + """ +ROOT(columns=[('n', n)], orderings=[]) + PROJECT(columns={'n': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': COUNT()}) + FILTER(condition=True:bool, columns={'account_balance': account_balance}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['semi'], columns={'account_balance': t0.account_balance}, correl_name='corr3') + PROJECT(columns={'account_balance': account_balance, 'avg_price': agg_0, 'key': key}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'agg_0': t1.agg_0, 'key': t0.key}) + FILTER(condition=nation_key == 19:int64, columns={'account_balance': account_balance, 'key': key}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': 
s_suppkey, 'nation_key': s_nationkey}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': AVG(retail_price)}) + JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'retail_price': t1.retail_price, 'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.PART, columns={'key': p_partkey, 'retail_price': p_retailprice}) + FILTER(condition=True:bool, columns={'supplier_key': supplier_key}) + JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr2') + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) + FILTER(condition=retail_price < corr2.supplycost * 1.5:float64 & retail_price < corr3.avg_price, columns={'key': key}) + FILTER(condition=container == 'LG DRUM':string, columns={'key': key, 'retail_price': retail_price}) + SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'retail_price': p_retailprice}) +""", + lambda: pd.DataFrame({"n": [66]}), + ), + id="correl_14", + ), + pytest.param( + ( + correl_15, + """ ROOT(columns=[('n', n)], orderings=[]) PROJECT(columns={'n': agg_1}) JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1}, correl_name='corr4') @@ -1479,7 +1537,8 @@ JOIN(conditions=[t0.key == t1.supplier_key], types=['semi'], columns={'account_balance': t0.account_balance}, correl_name='corr3') PROJECT(columns={'account_balance': account_balance, 'avg_price': agg_0, 'key': key}) JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'agg_0': t1.agg_0, 'key': t0.key}) - SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey}) + FILTER(condition=nation_key <= 3:int64, columns={'account_balance': account_balance, 'key': key}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 
'key': s_suppkey, 'nation_key': s_nationkey}) AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': AVG(retail_price)}) JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'retail_price': t1.retail_price, 'supplier_key': t0.supplier_key}) SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) @@ -1487,12 +1546,13 @@ FILTER(condition=True:bool, columns={'supplier_key': supplier_key}) JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr2') SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) - FILTER(condition=retail_price < corr2.supplycost * 1.5:float64 & retail_price < corr3.avg_price & retail_price < corr4.avg_price, columns={'key': key}) - SCAN(table=tpch.PART, columns={'key': p_partkey, 'retail_price': p_retailprice}) + FILTER(condition=retail_price < corr2.supplycost * 1.5:float64 & retail_price < corr3.avg_price & retail_price < corr4.avg_price * 0.9:float64, columns={'key': key}) + FILTER(condition=STARTSWITH(container, 'SM':string), columns={'key': key, 'retail_price': retail_price}) + SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'retail_price': p_retailprice}) """, - lambda: pd.DataFrame({"n": [516]}), + lambda: pd.DataFrame({"n": [212]}), ), - id="correl_13", + id="correl_15", ), ], ) @@ -1520,6 +1580,7 @@ def test_pipeline_until_relational( ], get_sample_graph: graph_fetcher, default_config: PyDoughConfigs, + sqlite_bindings, ) -> None: """ Tests that a PyDough unqualified node can be correctly translated to its From 397f40d762168c28ef81e7ffa8e7854bb65412f0 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 3 Feb 2025 10:29:14 -0500 Subject: [PATCH 076/112] Pulling up testing changes --- tests/simple_pydough_functions.py | 48 ++++++++++++++++++--- tests/test_pipeline.py | 71 ++++++++++++++++++++++++++++--- 2 files changed, 109 
insertions(+), 10 deletions(-) diff --git a/tests/simple_pydough_functions.py b/tests/simple_pydough_functions.py index 648b41d2..bc9d20f7 100644 --- a/tests/simple_pydough_functions.py +++ b/tests/simple_pydough_functions.py @@ -431,17 +431,55 @@ def correl_12(): def correl_13(): # Correlated back reference example #13: multiple correlation. # Count how many suppliers sell at least one part where the retail price + # is less than a 50% markup over the supply cost. Only considers suppliers + # from nations #1/#2/#3, and small parts. + # (This is a correlated SEMI-joins) + selected_part = part.WHERE(STARTSWITH(container, "SM")).WHERE( + retail_price < (BACK(1).supplycost * 1.5) + ) + selected_supply_records = supply_records.WHERE(HAS(selected_part)) + supplier_info = Suppliers.WHERE(nation_key <= 3)( + avg_price=AVG(supply_records.part.retail_price) + ) + selected_suppliers = supplier_info.WHERE(COUNT(selected_supply_records) > 0) + return TPCH(n=COUNT(selected_suppliers)) + + +def correl_14(): + # Correlated back reference example #14: multiple correlation. + # Count how many suppliers sell at least one part where the retail price # is less than a 50% markup over the supply cost, and the retail price of - # the part is below the average of the retail price for all parts globally - # and the average for all parts from the supplier. + # the part is below the average for all parts from the supplier. Only + # considers suppliers from nations #19, and LG DRUM parts. 
# (This is multiple correlated SEMI-joins) - selected_part = part.WHERE( + selected_part = part.WHERE(container == "LG DRUM").WHERE( + (retail_price < (BACK(1).supplycost * 1.5)) & (retail_price < BACK(2).avg_price) + ) + selected_supply_records = supply_records.WHERE(HAS(selected_part)) + supplier_info = Suppliers.WHERE(nation_key == 19)( + avg_price=AVG(supply_records.part.retail_price) + ) + selected_suppliers = supplier_info.WHERE(HAS(selected_supply_records)) + return TPCH(n=COUNT(selected_suppliers)) + + +def correl_15(): + # Correlated back reference example #15: multiple correlation. + # Count how many suppliers sell at least one part where the retail price + # is less than a 50% markup over the supply cost, and the retail price of + # the part is below the 90% of the average of the retail price for all + # parts globally and below the average for all parts from the supplier. + # Only considers suppliers from nations #19, and LG DRUM parts. + # (This is multiple correlated SEMI-joins & a correlated aggregate) + selected_part = part.WHERE(container == "LG DRUM").WHERE( (retail_price < (BACK(1).supplycost * 1.5)) & (retail_price < BACK(2).avg_price) - & (retail_price < BACK(3).avg_price) + & (retail_price < BACK(3).avg_price * 0.9) ) selected_supply_records = supply_records.WHERE(HAS(selected_part)) - supplier_info = Suppliers(avg_price=AVG(supply_records.part.retail_price)) + supplier_info = Suppliers.WHERE(nation_key == 19)( + avg_price=AVG(supply_records.part.retail_price) + ) selected_suppliers = supplier_info.WHERE(HAS(selected_supply_records)) global_info = TPCH(avg_price=AVG(Parts.retail_price)) return global_info(n=COUNT(selected_suppliers)) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index a34dd7d7..2e098a45 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -27,6 +27,8 @@ correl_11, correl_12, correl_13, + correl_14, + correl_15, double_partition, function_sampler, percentile_customers_per_region, @@ -1468,6 
+1470,62 @@ ( correl_13, """ +ROOT(columns=[('n', n)], orderings=[]) + PROJECT(columns={'n': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': COUNT()}) + FILTER(condition=DEFAULT_TO(agg_1, 0:int64) > 0:int64, columns={'account_balance': account_balance}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'agg_1': t1.agg_1}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'key': t0.key}) + FILTER(condition=nation_key <= 3:int64, columns={'account_balance': account_balance, 'key': key}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={}) + JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.PART, columns={'key': p_partkey}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_1': COUNT()}) + FILTER(condition=True:bool, columns={'supplier_key': supplier_key}) + JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr2') + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) + FILTER(condition=retail_price < corr2.supplycost * 1.5:float64, columns={'key': key}) + FILTER(condition=STARTSWITH(container, 'SM':string), columns={'key': key, 'retail_price': retail_price}) + SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'retail_price': p_retailprice}) +""", + lambda: pd.DataFrame({"n": [1129]}), + ), + id="correl_13", + ), + pytest.param( + ( + correl_14, + """ +ROOT(columns=[('n', n)], orderings=[]) + PROJECT(columns={'n': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': COUNT()}) + 
FILTER(condition=True:bool, columns={'account_balance': account_balance}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['semi'], columns={'account_balance': t0.account_balance}, correl_name='corr3') + PROJECT(columns={'account_balance': account_balance, 'avg_price': agg_0, 'key': key}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'agg_0': t1.agg_0, 'key': t0.key}) + FILTER(condition=nation_key == 19:int64, columns={'account_balance': account_balance, 'key': key}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': AVG(retail_price)}) + JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'retail_price': t1.retail_price, 'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.PART, columns={'key': p_partkey, 'retail_price': p_retailprice}) + FILTER(condition=True:bool, columns={'supplier_key': supplier_key}) + JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr2') + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) + FILTER(condition=retail_price < corr2.supplycost * 1.5:float64 & retail_price < corr3.avg_price, columns={'key': key}) + FILTER(condition=container == 'LG DRUM':string, columns={'key': key, 'retail_price': retail_price}) + SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'retail_price': p_retailprice}) +""", + lambda: pd.DataFrame({"n": [66]}), + ), + id="correl_14", + ), + pytest.param( + ( + correl_15, + """ ROOT(columns=[('n', n)], orderings=[]) PROJECT(columns={'n': agg_1}) JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1}, correl_name='corr4') @@ -1479,7 +1537,8 @@ 
JOIN(conditions=[t0.key == t1.supplier_key], types=['semi'], columns={'account_balance': t0.account_balance}, correl_name='corr3') PROJECT(columns={'account_balance': account_balance, 'avg_price': agg_0, 'key': key}) JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'agg_0': t1.agg_0, 'key': t0.key}) - SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey}) + FILTER(condition=nation_key <= 3:int64, columns={'account_balance': account_balance, 'key': key}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': AVG(retail_price)}) JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'retail_price': t1.retail_price, 'supplier_key': t0.supplier_key}) SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) @@ -1487,12 +1546,13 @@ FILTER(condition=True:bool, columns={'supplier_key': supplier_key}) JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr2') SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) - FILTER(condition=retail_price < corr2.supplycost * 1.5:float64 & retail_price < corr3.avg_price & retail_price < corr4.avg_price, columns={'key': key}) - SCAN(table=tpch.PART, columns={'key': p_partkey, 'retail_price': p_retailprice}) + FILTER(condition=retail_price < corr2.supplycost * 1.5:float64 & retail_price < corr3.avg_price & retail_price < corr4.avg_price * 0.9:float64, columns={'key': key}) + FILTER(condition=STARTSWITH(container, 'SM':string), columns={'key': key, 'retail_price': retail_price}) + SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'retail_price': p_retailprice}) """, - lambda: pd.DataFrame({"n": [516]}), + lambda: pd.DataFrame({"n": 
[212]}), ), - id="correl_13", + id="correl_15", ), ], ) @@ -1520,6 +1580,7 @@ def test_pipeline_until_relational( ], get_sample_graph: graph_fetcher, default_config: PyDoughConfigs, + sqlite_bindings, ) -> None: """ Tests that a PyDough unqualified node can be correctly translated to its From 719eec50ef2142d084ed892bf3aee17102822bba Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 3 Feb 2025 11:12:45 -0500 Subject: [PATCH 077/112] Adding two more correl tests --- tests/simple_pydough_functions.py | 28 ++++++++++++++++++++ tests/test_pipeline.py | 43 +++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+) diff --git a/tests/simple_pydough_functions.py b/tests/simple_pydough_functions.py index bc9d20f7..87604d8d 100644 --- a/tests/simple_pydough_functions.py +++ b/tests/simple_pydough_functions.py @@ -483,3 +483,31 @@ def correl_15(): selected_suppliers = supplier_info.WHERE(HAS(selected_supply_records)) global_info = TPCH(avg_price=AVG(Parts.retail_price)) return global_info(n=COUNT(selected_suppliers)) + + +def correl_16(): + # Correlated back reference example #15: hybrid tree order of operations. + # Count how many european suppliers have the exact same percentile value + # of account balance (relative to all other suppliers) as at least one + # customer's percentile value of account balance relative to all other + # customers. Percentile should be measured down to increments of 0.01%. + # (This is a correlated SEMI-joins) + selected_customers = nation(rname=region.name).customers.WHERE( + (PERCENTILE(by=acctbal.ASC(), n_buckets=10000) == BACK(2).tile) + & (BACK(1).rname == "EUROPE") + ) + supplier_info = Suppliers( + tile=PERCENTILE(by=account_balance.ASC(), n_buckets=10000) + ) + selected_suppliers = supplier_info.WHERE(HAS(selected_customers)) + return TPCH(n=COUNT(selected_suppliers)) + + +def correl_17(): + # Correlated back reference example #15: hybrid tree order of operations. 
+ # An extremely roundabout way of getting each region_name-nation_name + # pair as a string. + # (This is a correlated singular/semi access) + region_info = region(fname=JOIN_STRINGS("-", LOWER(name), BACK(1).lname)) + nation_info = Nations(lname=LOWER(name)).WHERE(HAS(region_info)) + return nation_info(fullname=region_info.fname).ORDER_BY(fullname.ASC()) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 2e098a45..48425f5b 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -29,6 +29,8 @@ correl_13, correl_14, correl_15, + correl_16, + correl_17, double_partition, function_sampler, percentile_customers_per_region, @@ -1554,6 +1556,47 @@ ), id="correl_15", ), + pytest.param( + ( + correl_16, + """ +ROOT(columns=[('n', n)], orderings=[]) + PROJECT(columns={'n': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': COUNT()}) + FILTER(condition=True:bool, columns={'account_balance': account_balance}) + JOIN(conditions=[t0.nation_key == t1.key], types=['semi'], columns={'account_balance': t0.account_balance}, correl_name='corr7') + PROJECT(columns={'account_balance': account_balance, 'nation_key': nation_key, 'tile': PERCENTILE(args=[], partition=[], order=[(account_balance):asc_last], n_buckets=10000)}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'nation_key': s_nationkey}) + FILTER(condition=PERCENTILE(args=[], partition=[], order=[(acctbal):asc_last], n_buckets=10000) == corr7.tile & rname == 'EUROPE':string, columns={'key': key}) + JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'acctbal': t1.acctbal, 'key': t0.key, 'rname': t0.rname}) + PROJECT(columns={'key': key, 'rname': name_3}) + JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'key': t0.key, 'name_3': t1.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': 
c_acctbal, 'nation_key': c_nationkey}) +""", + lambda: pd.DataFrame({"n": [925]}), + ), + id="correl_16", + ), + pytest.param( + ( + correl_17, + """ +ROOT(columns=[('fullname', fullname)], orderings=[(ordering_0):asc_first]) + PROJECT(columns={'fullname': fullname, 'ordering_0': fullname}) + PROJECT(columns={'fullname': fname}) + FILTER(condition=True:bool, columns={'fname': fname}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'fname': t1.fname}, correl_name='corr1') + PROJECT(columns={'lname': LOWER(name), 'region_key': region_key}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) + PROJECT(columns={'fname': JOIN_STRINGS('-':string, LOWER(name), corr1.lname), 'key': key}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) +""", + lambda: pd.DataFrame({"fullname": [925]}), + ), + id="correl_17", + ), ], ) def pydough_pipeline_test_data( From 52c2af55be757d5de0be32e840f2eaded80aa41b Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 3 Feb 2025 12:36:56 -0500 Subject: [PATCH 078/112] Fixing correl 16 --- tests/simple_pydough_functions.py | 4 ++-- tests/test_pipeline.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/simple_pydough_functions.py b/tests/simple_pydough_functions.py index 87604d8d..f45b5cf4 100644 --- a/tests/simple_pydough_functions.py +++ b/tests/simple_pydough_functions.py @@ -493,11 +493,11 @@ def correl_16(): # customers. Percentile should be measured down to increments of 0.01%. 
# (This is a correlated SEMI-joins) selected_customers = nation(rname=region.name).customers.WHERE( - (PERCENTILE(by=acctbal.ASC(), n_buckets=10000) == BACK(2).tile) + (PERCENTILE(by=(acctbal.ASC(), key.ASC()), n_buckets=10000) == BACK(2).tile) & (BACK(1).rname == "EUROPE") ) supplier_info = Suppliers( - tile=PERCENTILE(by=account_balance.ASC(), n_buckets=10000) + tile=PERCENTILE(by=(account_balance.ASC(), key.ASC()), n_buckets=10000) ) selected_suppliers = supplier_info.WHERE(HAS(selected_customers)) return TPCH(n=COUNT(selected_suppliers)) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 48425f5b..c67ed1a3 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1575,7 +1575,7 @@ SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': c_nationkey}) """, - lambda: pd.DataFrame({"n": [925]}), + lambda: pd.DataFrame({"n": [929]}), ), id="correl_16", ), From 19b362881f32315ca1b61bcc0f6cd8820d08e18b Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 3 Feb 2025 12:37:20 -0500 Subject: [PATCH 079/112] Fixing correlated test #16 --- tests/simple_pydough_functions.py | 4 ++-- tests/test_pipeline.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/simple_pydough_functions.py b/tests/simple_pydough_functions.py index 87604d8d..f45b5cf4 100644 --- a/tests/simple_pydough_functions.py +++ b/tests/simple_pydough_functions.py @@ -493,11 +493,11 @@ def correl_16(): # customers. Percentile should be measured down to increments of 0.01%. 
# (This is a correlated SEMI-joins) selected_customers = nation(rname=region.name).customers.WHERE( - (PERCENTILE(by=acctbal.ASC(), n_buckets=10000) == BACK(2).tile) + (PERCENTILE(by=(acctbal.ASC(), key.ASC()), n_buckets=10000) == BACK(2).tile) & (BACK(1).rname == "EUROPE") ) supplier_info = Suppliers( - tile=PERCENTILE(by=account_balance.ASC(), n_buckets=10000) + tile=PERCENTILE(by=(account_balance.ASC(), key.ASC()), n_buckets=10000) ) selected_suppliers = supplier_info.WHERE(HAS(selected_customers)) return TPCH(n=COUNT(selected_suppliers)) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 48425f5b..c67ed1a3 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1575,7 +1575,7 @@ SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': c_nationkey}) """, - lambda: pd.DataFrame({"n": [925]}), + lambda: pd.DataFrame({"n": [929]}), ), id="correl_16", ), From b6c9b8509eebb8c69a691fe179598a1f2fd106e2 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 3 Feb 2025 13:53:51 -0500 Subject: [PATCH 080/112] WIP handling renamings --- pydough/conversion/relational_converter.py | 26 ++++- tests/test_pipeline.py | 109 +++++++++++---------- tests/test_qdag_conversion.py | 42 ++++---- 3 files changed, 104 insertions(+), 73 deletions(-) diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index 96ddcdc9..981f74b7 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -78,7 +78,17 @@ class TranslationOutput: """ relation: Relational + """ + The relational tree describing the way to compute the answer for the + logic originally in the hybrid tree. 
+ """ + expressions: dict[HybridExpr, ColumnReference] + """ + A mapping of each expression that was accessible in the hybrid tree to the + corresponding column reference in the relational tree that contains the + value of that expression. + """ class RelTranslation: @@ -114,6 +124,18 @@ def make_null_column(self, relation: Relational) -> ColumnReference: relation.columns[name] = LiteralExpression(None, UnknownType()) return ColumnReference(name, UnknownType()) + def get_column_name( + self, name: str, existing_names: dict[str, RelationalExpression] + ) -> str: + """ + TODO + """ + new_name: str = name + while new_name in existing_names: + self.dummy_idx += 1 + new_name = f"{name}_{self.dummy_idx}" + return new_name + def translate_expression( self, expr: HybridExpr, context: TranslationOutput | None ) -> RelationalExpression: @@ -681,6 +703,8 @@ def translate_calc( rel_expr: RelationalExpression = self.translate_expression( hybrid_expr, context ) + if name in proj_columns and proj_columns[name] != rel_expr: + name = self.get_column_name(name, proj_columns) proj_columns[name] = rel_expr out_columns[ref_expr] = ColumnReference(name, rel_expr.data_type) out_rel: Project = Project(context.relation, proj_columns) @@ -905,7 +929,7 @@ def convert_ast_to_relational( # Convert the QDAG node to the hybrid form, then invoke the relational # conversion procedure. The first element in the returned list is the # final rel node. 
- hybrid: HybridTree = HybridTranslator(configs).make_hybrid_tree(node) + hybrid: HybridTree = HybridTranslator(configs).make_hybrid_tree(node, None) renamings: dict[str, str] = hybrid.pipeline[-1].renamings output: TranslationOutput = translator.rel_translation( None, hybrid, len(hybrid.pipeline) - 1 diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 41605af3..b4da1df3 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -649,16 +649,17 @@ ( rank_nations_per_region_by_customers, """ -ROOT(columns=[('name', name), ('rank', rank)], orderings=[(ordering_1):asc_first]) - LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'name': name, 'ordering_1': ordering_1, 'rank': rank}, orderings=[(ordering_1):asc_first]) - PROJECT(columns={'name': name, 'ordering_1': rank, 'rank': rank}) - PROJECT(columns={'name': name_3, 'rank': RANKING(args=[], partition=[key], order=[(DEFAULT_TO(agg_0, 0:int64)):desc_first])}) - JOIN(conditions=[t0.key_2 == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key, 'name_3': t0.name_3}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name_3': t1.name}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': COUNT()}) - SCAN(table=tpch.CUSTOMER, columns={'nation_key': c_nationkey}) +ROOT(columns=[('name', name_6), ('rank', rank)], orderings=[(ordering_1):asc_first]) + PROJECT(columns={'name_6': name_5, 'ordering_1': ordering_1, 'rank': rank}) + LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'name_5': name_5, 'ordering_1': ordering_1, 'rank': rank}, orderings=[(ordering_1):asc_first]) + PROJECT(columns={'name_5': name_5, 'ordering_1': rank, 'rank': rank}) + PROJECT(columns={'name_5': name_3, 'rank': RANKING(args=[], partition=[key], order=[(DEFAULT_TO(agg_0, 
0:int64)):desc_first])}) + JOIN(conditions=[t0.key_2 == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key, 'name_3': t0.name_3}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name_3': t1.name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': COUNT()}) + SCAN(table=tpch.CUSTOMER, columns={'nation_key': c_nationkey}) """, lambda: pd.DataFrame( { @@ -673,19 +674,20 @@ ( rank_parts_per_supplier_region_by_size, """ -ROOT(columns=[('key', key), ('region', region), ('rank', rank)], orderings=[(ordering_0):asc_first]) - LIMIT(limit=Literal(value=15, type=Int64Type()), columns={'key': key, 'ordering_0': ordering_0, 'rank': rank, 'region': region}, orderings=[(ordering_0):asc_first]) - PROJECT(columns={'key': key, 'ordering_0': key, 'rank': rank, 'region': region}) - PROJECT(columns={'key': key_9, 'rank': RANKING(args=[], partition=[key], order=[(size):desc_first, (container):desc_first, (part_type):desc_first], allow_ties=True, dense=True), 'region': name}) - JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'container': t1.container, 'key': t0.key, 'key_9': t1.key, 'name': t0.name, 'part_type': t1.part_type, 'size': t1.size}) - JOIN(conditions=[t0.key_5 == t1.supplier_key], types=['inner'], columns={'key': t0.key, 'name': t0.name, 'part_key': t1.part_key}) - JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'key': t0.key, 'key_5': t1.key, 'name': t0.name}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name': t0.name}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 
'nation_key': s_nationkey}) - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) - SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'part_type': p_type, 'size': p_size}) +ROOT(columns=[('key', key_13), ('region', region), ('rank', rank)], orderings=[(ordering_0):asc_first]) + PROJECT(columns={'key_13': key_12, 'ordering_0': ordering_0, 'rank': rank, 'region': region}) + LIMIT(limit=Literal(value=15, type=Int64Type()), columns={'key_12': key_12, 'ordering_0': ordering_0, 'rank': rank, 'region': region}, orderings=[(ordering_0):asc_first]) + PROJECT(columns={'key_12': key_12, 'ordering_0': key_12, 'rank': rank, 'region': region}) + PROJECT(columns={'key_12': key_9, 'rank': RANKING(args=[], partition=[key], order=[(size):desc_first, (container):desc_first, (part_type):desc_first], allow_ties=True, dense=True), 'region': name}) + JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'container': t1.container, 'key': t0.key, 'key_9': t1.key, 'name': t0.name, 'part_type': t1.part_type, 'size': t1.size}) + JOIN(conditions=[t0.key_5 == t1.supplier_key], types=['inner'], columns={'key': t0.key, 'name': t0.name, 'part_key': t1.part_key}) + JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'key': t0.key, 'key_5': t1.key, 'name': t0.name}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name': t0.name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'part_type': p_type, 'size': p_size}) """, lambda: pd.DataFrame( { @@ -779,15 +781,16 @@ ( rank_with_filters_c, """ 
-ROOT(columns=[('size', size), ('name', name)], orderings=[]) - FILTER(condition=RANKING(args=[], partition=[size], order=[(retail_price):desc_first]) == 1:int64, columns={'name': name, 'size': size}) - PROJECT(columns={'name': name, 'retail_price': retail_price, 'size': size_1}) - JOIN(conditions=[t0.size == t1.size], types=['inner'], columns={'name': t1.name, 'retail_price': t1.retail_price, 'size_1': t1.size}) - LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'size': size}, orderings=[(ordering_0):desc_last]) - PROJECT(columns={'ordering_0': size, 'size': size}) - AGGREGATE(keys={'size': size}, aggregations={}) - SCAN(table=tpch.PART, columns={'size': p_size}) - SCAN(table=tpch.PART, columns={'name': p_name, 'retail_price': p_retailprice, 'size': p_size}) +ROOT(columns=[('size', size_4), ('name', name)], orderings=[]) + PROJECT(columns={'name': name, 'size_4': size_3}) + FILTER(condition=RANKING(args=[], partition=[size], order=[(retail_price):desc_first]) == 1:int64, columns={'name': name, 'size_3': size_3}) + PROJECT(columns={'name': name, 'retail_price': retail_price, 'size': size, 'size_3': size_1}) + JOIN(conditions=[t0.size == t1.size], types=['inner'], columns={'name': t1.name, 'retail_price': t1.retail_price, 'size': t0.size, 'size_1': t1.size}) + LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'size': size}, orderings=[(ordering_0):desc_last]) + PROJECT(columns={'ordering_0': size, 'size': size}) + AGGREGATE(keys={'size': size}, aggregations={}) + SCAN(table=tpch.PART, columns={'size': p_size}) + SCAN(table=tpch.PART, columns={'name': p_name, 'retail_price': p_retailprice, 'size': p_size}) """, lambda: pd.DataFrame( { @@ -851,15 +854,16 @@ ( percentile_customers_per_region, """ -ROOT(columns=[('name', name)], orderings=[(ordering_0):asc_first]) - PROJECT(columns={'name': name, 'ordering_0': name}) - FILTER(condition=PERCENTILE(args=[], partition=[key], order=[(acctbal):asc_last]) == 95:int64 & ENDSWITH(phone, '00':string), 
columns={'name': name}) - PROJECT(columns={'acctbal': acctbal, 'key': key, 'name': name_6, 'phone': phone}) - JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'acctbal': t1.acctbal, 'key': t0.key, 'name_6': t1.name, 'phone': t1.phone}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'name': c_name, 'nation_key': c_nationkey, 'phone': c_phone}) +ROOT(columns=[('name', name_9)], orderings=[(ordering_0):asc_first]) + PROJECT(columns={'name_9': name_8, 'ordering_0': ordering_0}) + PROJECT(columns={'name_8': name_8, 'ordering_0': name_8}) + FILTER(condition=PERCENTILE(args=[], partition=[key], order=[(acctbal):asc_last]) == 95:int64 & ENDSWITH(phone, '00':string), columns={'name_8': name_8}) + PROJECT(columns={'acctbal': acctbal, 'key': key, 'name_8': name_6, 'phone': phone}) + JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'acctbal': t1.acctbal, 'key': t0.key, 'name_6': t1.name, 'phone': t1.phone}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'name': c_name, 'nation_key': c_nationkey, 'phone': c_phone}) """, lambda: pd.DataFrame( { @@ -884,17 +888,18 @@ ( regional_suppliers_percentile, """ -ROOT(columns=[('name', name)], orderings=[]) - FILTER(condition=True:bool & PERCENTILE(args=[], partition=[key], order=[(DEFAULT_TO(agg_0, 0:int64)):asc_last, (name):asc_last], n_buckets=1000) == 1000:int64, columns={'name': name}) - JOIN(conditions=[t0.key_5 == t1.supplier_key], types=['inner'], columns={'agg_0': t1.agg_0, 'key': 
t0.key, 'name': t0.name}) - PROJECT(columns={'key': key, 'key_5': key_5, 'name': name_6}) - JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'key': t0.key, 'key_5': t1.key, 'name_6': t1.name}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name, 'nation_key': s_nationkey}) - AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': COUNT()}) - SCAN(table=tpch.PARTSUPP, columns={'supplier_key': ps_suppkey}) +ROOT(columns=[('name', name_9)], orderings=[]) + PROJECT(columns={'name_9': name_8}) + FILTER(condition=True:bool & PERCENTILE(args=[], partition=[key], order=[(DEFAULT_TO(agg_0, 0:int64)):asc_last, (name_8):asc_last], n_buckets=1000) == 1000:int64, columns={'name_8': name_8}) + JOIN(conditions=[t0.key_5 == t1.supplier_key], types=['inner'], columns={'agg_0': t1.agg_0, 'key': t0.key, 'name_8': t0.name_8}) + PROJECT(columns={'key': key, 'key_5': key_5, 'name_8': name_6}) + JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'key': t0.key, 'key_5': t1.key, 'name_6': t1.name}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name, 'nation_key': s_nationkey}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': COUNT()}) + SCAN(table=tpch.PARTSUPP, columns={'supplier_key': ps_suppkey}) """, lambda: pd.DataFrame( { diff --git a/tests/test_qdag_conversion.py b/tests/test_qdag_conversion.py index ac126436..4611a408 100644 --- a/tests/test_qdag_conversion.py +++ 
b/tests/test_qdag_conversion.py @@ -89,8 +89,8 @@ ( TableCollectionInfo("Regions") ** SubCollectionInfo("nations"), """ -ROOT(columns=[('key', key), ('name', name), ('region_key', region_key), ('comment', comment)], orderings=[]) - PROJECT(columns={'comment': comment_1, 'key': key_2, 'name': name_3, 'region_key': region_key}) +ROOT(columns=[('key', key_6), ('name', name_7), ('region_key', region_key), ('comment', comment_5)], orderings=[]) + PROJECT(columns={'comment_5': comment_1, 'key_6': key_2, 'name_7': name_3, 'region_key': region_key}) JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'comment_1': t1.comment, 'key_2': t1.key, 'name_3': t1.name, 'region_key': t1.region_key}) SCAN(table=tpch.REGION, columns={'key': r_regionkey}) SCAN(table=tpch.NATION, columns={'comment': n_comment, 'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) @@ -104,8 +104,8 @@ ** SubCollectionInfo("nations") ** SubCollectionInfo("customers"), """ -ROOT(columns=[('key', key), ('name', name), ('address', address), ('nation_key', nation_key), ('phone', phone), ('acctbal', acctbal), ('mktsegment', mktsegment), ('comment', comment)], orderings=[]) - PROJECT(columns={'acctbal': acctbal, 'address': address, 'comment': comment_4, 'key': key_5, 'mktsegment': mktsegment, 'name': name_6, 'nation_key': nation_key, 'phone': phone}) +ROOT(columns=[('key', key_9), ('name', name_10), ('address', address), ('nation_key', nation_key), ('phone', phone), ('acctbal', acctbal), ('mktsegment', mktsegment), ('comment', comment_8)], orderings=[]) + PROJECT(columns={'acctbal': acctbal, 'address': address, 'comment_8': comment_4, 'key_9': key_5, 'mktsegment': mktsegment, 'name_10': name_6, 'nation_key': nation_key, 'phone': phone}) JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'acctbal': t1.acctbal, 'address': t1.address, 'comment_4': t1.comment, 'key_5': t1.key, 'mktsegment': t1.mktsegment, 'name_6': t1.name, 'nation_key': t1.nation_key, 'phone': 
t1.phone}) JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': t1.key}) SCAN(table=tpch.REGION, columns={'key': r_regionkey}) @@ -242,13 +242,14 @@ mktsegment=ReferenceInfo("mktsegment"), ), """ -ROOT(columns=[('key', key_0), ('name', name), ('phone', phone), ('mktsegment', mktsegment)], orderings=[]) - PROJECT(columns={'key_0': -3:int64, 'mktsegment': mktsegment, 'name': name_6, 'phone': phone}) - JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'mktsegment': t1.mktsegment, 'name_6': t1.name, 'phone': t1.phone}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': t1.key}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - SCAN(table=tpch.CUSTOMER, columns={'mktsegment': c_mktsegment, 'name': c_name, 'nation_key': c_nationkey, 'phone': c_phone}) +ROOT(columns=[('key', key_0_9), ('name', name_11), ('phone', phone), ('mktsegment', mktsegment)], orderings=[]) + PROJECT(columns={'key_0_9': key_0_9, 'mktsegment': mktsegment, 'name_11': name_10, 'phone': phone}) + PROJECT(columns={'key_0_9': -3:int64, 'mktsegment': mktsegment, 'name_10': name_7, 'phone': phone}) + JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'mktsegment': t1.mktsegment, 'name_7': t1.name, 'phone': t1.phone}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': t1.key}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.CUSTOMER, columns={'mktsegment': c_mktsegment, 'name': c_name, 'nation_key': c_nationkey, 'phone': c_phone}) """, ), id="join_regions_nations_calc_override", @@ -700,15 +701,16 @@ ), ), """ -ROOT(columns=[('part_key', part_key), ('supplier_key', supplier_key), ('order_key', order_key), ('order_quantity_ratio', order_quantity_ratio)], orderings=[]) - 
PROJECT(columns={'order_key': order_key_2, 'order_quantity_ratio': quantity / total_quantity, 'part_key': part_key, 'supplier_key': supplier_key}) - JOIN(conditions=[t0.key == t1.order_key], types=['inner'], columns={'order_key_2': t1.order_key, 'part_key': t1.part_key, 'quantity': t1.quantity, 'supplier_key': t1.supplier_key, 'total_quantity': t0.total_quantity}) - PROJECT(columns={'key': key, 'total_quantity': DEFAULT_TO(agg_0, 0:int64)}) - JOIN(conditions=[t0.key == t1.order_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key}) - SCAN(table=tpch.ORDERS, columns={'key': o_orderkey}) - AGGREGATE(keys={'order_key': order_key}, aggregations={'agg_0': SUM(quantity)}) - SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'quantity': l_quantity}) - SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'part_key': l_partkey, 'quantity': l_quantity, 'supplier_key': l_suppkey}) +ROOT(columns=[('part_key', part_key), ('supplier_key', supplier_key), ('order_key', order_key_5), ('order_quantity_ratio', order_quantity_ratio)], orderings=[]) + PROJECT(columns={'order_key_5': order_key_4, 'order_quantity_ratio': order_quantity_ratio, 'part_key': part_key, 'supplier_key': supplier_key}) + PROJECT(columns={'order_key_4': order_key_2, 'order_quantity_ratio': quantity / total_quantity, 'part_key': part_key, 'supplier_key': supplier_key}) + JOIN(conditions=[t0.key == t1.order_key], types=['inner'], columns={'order_key_2': t1.order_key, 'part_key': t1.part_key, 'quantity': t1.quantity, 'supplier_key': t1.supplier_key, 'total_quantity': t0.total_quantity}) + PROJECT(columns={'key': key, 'total_quantity': DEFAULT_TO(agg_0, 0:int64)}) + JOIN(conditions=[t0.key == t1.order_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key}) + SCAN(table=tpch.ORDERS, columns={'key': o_orderkey}) + AGGREGATE(keys={'order_key': order_key}, aggregations={'agg_0': SUM(quantity)}) + SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'quantity': l_quantity}) + 
SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'part_key': l_partkey, 'quantity': l_quantity, 'supplier_key': l_suppkey}) """, ), id="aggregate_then_backref", From afd241bfc70274087a72c1121d6f6ce45eb630db Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 3 Feb 2025 14:10:28 -0500 Subject: [PATCH 081/112] Converted qdag conversion tests to be plan file based --- tests/conftest.py | 27 +- .../access_partition_child_after_filter.txt | 7 + .../access_partition_child_backref_calc.txt | 7 + ..._partition_child_filter_backref_filter.txt | 8 + tests/test_plan_refsols/agg_max_ranking.txt | 7 + .../agg_orders_by_year_month_basic.txt | 5 + .../agg_orders_by_year_month_just_europe.txt | 16 + .../agg_orders_by_year_month_vs_europe.txt | 16 + .../agg_parts_by_type_backref_global.txt | 9 + .../agg_parts_by_type_simple.txt | 4 + tests/test_plan_refsols/aggregate_anti.txt | 9 + .../aggregate_mixed_levels_simple.txt | 9 + .../aggregate_on_function_call.txt | 6 + tests/test_plan_refsols/aggregate_semi.txt | 10 + .../aggregate_then_backref.txt | 9 + tests/test_plan_refsols/anti_aggregate.txt | 9 + .../anti_aggregate_alternate.txt | 9 + tests/test_plan_refsols/anti_singular.txt | 7 + tests/test_plan_refsols/asian_nations.txt | 5 + ...count_at_most_100_suppliers_per_nation.txt | 8 + .../count_cust_supplier_nation_combos.txt | 17 + ...multiple_subcollections_alongside_aggs.txt | 9 + .../count_single_subcollection.txt | 6 + .../customers_sum_line_price.txt | 8 + .../global_aggfunc_backref.txt | 7 + tests/test_plan_refsols/global_aggfuncs.txt | 4 + .../global_aggfuncs_multiple_children.txt | 10 + .../test_plan_refsols/global_calc_backref.txt | 6 + .../global_calc_multiple.txt | 4 + .../test_plan_refsols/global_calc_simple.txt | 3 + .../join_asia_region_nations.txt | 6 + tests/test_plan_refsols/join_order_by.txt | 6 + .../join_order_by_back_reference.txt | 6 + .../join_order_by_pruned_back_reference.txt | 6 + .../test_plan_refsols/join_region_nations.txt | 5 + 
.../join_region_nations_customers.txt | 7 + .../join_regions_nations_calc_override.txt | 7 + tests/test_plan_refsols/join_topk.txt | 7 + .../lineitem_regional_shipments.txt | 20 + .../lineitem_regional_shipments2.txt | 20 + .../lineitem_regional_shipments3.txt | 20 + ...lineitems_access_cust_supplier_nations.txt | 15 + .../lines_german_supplier_economy_part.txt | 13 + .../lines_shipping_vs_customer_region.txt | 20 + .../mostly_positive_accounts_per_nation1.txt | 10 + .../mostly_positive_accounts_per_nation2.txt | 11 + .../mostly_positive_accounts_per_nation3.txt | 11 + .../test_plan_refsols/multiple_has_hasnot.txt | 24 + ...ple_simple_aggregations_multiple_calcs.txt | 11 + ...ltiple_simple_aggregations_single_calc.txt | 9 + .../nation_name_contains_region_name.txt | 6 + .../nations_access_region.txt | 5 + .../nations_order_by_num_suppliers.txt | 6 + .../nations_region_order_by_name.txt | 5 + .../nations_sum_line_price.txt | 10 + .../num_positive_accounts_per_nation.txt | 10 + .../order_by_before_join.txt | 5 + .../test_plan_refsols/order_by_expression.txt | 4 + .../ordered_asian_nations.txt | 6 + .../ordering_name_overload.txt | 5 + .../orders_sum_line_price.txt | 6 + .../orders_sum_vs_count_line_price.txt | 6 + tests/test_plan_refsols/rank_customers.txt | 3 + .../rank_customers_per_nation.txt | 5 + .../rank_customers_per_region.txt | 7 + .../region_nations_backref.txt | 5 + .../regions_sum_line_price.txt | 12 + tests/test_plan_refsols/replace_order_by.txt | 6 + tests/test_plan_refsols/scan_calc.txt | 3 + tests/test_plan_refsols/scan_calc_calc.txt | 4 + .../scan_customer_call_functions.txt | 3 + tests/test_plan_refsols/scan_nations.txt | 2 + tests/test_plan_refsols/scan_regions.txt | 2 + tests/test_plan_refsols/semi_aggregate.txt | 10 + tests/test_plan_refsols/semi_singular.txt | 7 + tests/test_plan_refsols/simple_anti_1.txt | 5 + tests/test_plan_refsols/simple_anti_2.txt | 8 + tests/test_plan_refsols/simple_order_by.txt | 3 + 
tests/test_plan_refsols/simple_semi_1.txt | 5 + tests/test_plan_refsols/simple_semi_2.txt | 8 + tests/test_plan_refsols/simple_topk.txt | 4 + tests/test_plan_refsols/singular_anti.txt | 7 + tests/test_plan_refsols/singular_semi.txt | 7 + ...top_5_nations_balance_by_num_suppliers.txt | 8 + .../top_5_nations_by_num_supplierss.txt | 7 + tests/test_plan_refsols/topk_order_by.txt | 4 + .../test_plan_refsols/topk_order_by_calc.txt | 5 + .../topk_replace_order_by.txt | 4 + .../topk_root_different_order_by.txt | 5 + .../various_aggfuncs_global.txt | 4 + .../various_aggfuncs_simple.txt | 6 + tests/test_qdag_conversion.py | 1012 ++--------------- 92 files changed, 847 insertions(+), 893 deletions(-) create mode 100644 tests/test_plan_refsols/access_partition_child_after_filter.txt create mode 100644 tests/test_plan_refsols/access_partition_child_backref_calc.txt create mode 100644 tests/test_plan_refsols/access_partition_child_filter_backref_filter.txt create mode 100644 tests/test_plan_refsols/agg_max_ranking.txt create mode 100644 tests/test_plan_refsols/agg_orders_by_year_month_basic.txt create mode 100644 tests/test_plan_refsols/agg_orders_by_year_month_just_europe.txt create mode 100644 tests/test_plan_refsols/agg_orders_by_year_month_vs_europe.txt create mode 100644 tests/test_plan_refsols/agg_parts_by_type_backref_global.txt create mode 100644 tests/test_plan_refsols/agg_parts_by_type_simple.txt create mode 100644 tests/test_plan_refsols/aggregate_anti.txt create mode 100644 tests/test_plan_refsols/aggregate_mixed_levels_simple.txt create mode 100644 tests/test_plan_refsols/aggregate_on_function_call.txt create mode 100644 tests/test_plan_refsols/aggregate_semi.txt create mode 100644 tests/test_plan_refsols/aggregate_then_backref.txt create mode 100644 tests/test_plan_refsols/anti_aggregate.txt create mode 100644 tests/test_plan_refsols/anti_aggregate_alternate.txt create mode 100644 tests/test_plan_refsols/anti_singular.txt create mode 100644 
tests/test_plan_refsols/asian_nations.txt create mode 100644 tests/test_plan_refsols/count_at_most_100_suppliers_per_nation.txt create mode 100644 tests/test_plan_refsols/count_cust_supplier_nation_combos.txt create mode 100644 tests/test_plan_refsols/count_multiple_subcollections_alongside_aggs.txt create mode 100644 tests/test_plan_refsols/count_single_subcollection.txt create mode 100644 tests/test_plan_refsols/customers_sum_line_price.txt create mode 100644 tests/test_plan_refsols/global_aggfunc_backref.txt create mode 100644 tests/test_plan_refsols/global_aggfuncs.txt create mode 100644 tests/test_plan_refsols/global_aggfuncs_multiple_children.txt create mode 100644 tests/test_plan_refsols/global_calc_backref.txt create mode 100644 tests/test_plan_refsols/global_calc_multiple.txt create mode 100644 tests/test_plan_refsols/global_calc_simple.txt create mode 100644 tests/test_plan_refsols/join_asia_region_nations.txt create mode 100644 tests/test_plan_refsols/join_order_by.txt create mode 100644 tests/test_plan_refsols/join_order_by_back_reference.txt create mode 100644 tests/test_plan_refsols/join_order_by_pruned_back_reference.txt create mode 100644 tests/test_plan_refsols/join_region_nations.txt create mode 100644 tests/test_plan_refsols/join_region_nations_customers.txt create mode 100644 tests/test_plan_refsols/join_regions_nations_calc_override.txt create mode 100644 tests/test_plan_refsols/join_topk.txt create mode 100644 tests/test_plan_refsols/lineitem_regional_shipments.txt create mode 100644 tests/test_plan_refsols/lineitem_regional_shipments2.txt create mode 100644 tests/test_plan_refsols/lineitem_regional_shipments3.txt create mode 100644 tests/test_plan_refsols/lineitems_access_cust_supplier_nations.txt create mode 100644 tests/test_plan_refsols/lines_german_supplier_economy_part.txt create mode 100644 tests/test_plan_refsols/lines_shipping_vs_customer_region.txt create mode 100644 tests/test_plan_refsols/mostly_positive_accounts_per_nation1.txt 
create mode 100644 tests/test_plan_refsols/mostly_positive_accounts_per_nation2.txt create mode 100644 tests/test_plan_refsols/mostly_positive_accounts_per_nation3.txt create mode 100644 tests/test_plan_refsols/multiple_has_hasnot.txt create mode 100644 tests/test_plan_refsols/multiple_simple_aggregations_multiple_calcs.txt create mode 100644 tests/test_plan_refsols/multiple_simple_aggregations_single_calc.txt create mode 100644 tests/test_plan_refsols/nation_name_contains_region_name.txt create mode 100644 tests/test_plan_refsols/nations_access_region.txt create mode 100644 tests/test_plan_refsols/nations_order_by_num_suppliers.txt create mode 100644 tests/test_plan_refsols/nations_region_order_by_name.txt create mode 100644 tests/test_plan_refsols/nations_sum_line_price.txt create mode 100644 tests/test_plan_refsols/num_positive_accounts_per_nation.txt create mode 100644 tests/test_plan_refsols/order_by_before_join.txt create mode 100644 tests/test_plan_refsols/order_by_expression.txt create mode 100644 tests/test_plan_refsols/ordered_asian_nations.txt create mode 100644 tests/test_plan_refsols/ordering_name_overload.txt create mode 100644 tests/test_plan_refsols/orders_sum_line_price.txt create mode 100644 tests/test_plan_refsols/orders_sum_vs_count_line_price.txt create mode 100644 tests/test_plan_refsols/rank_customers.txt create mode 100644 tests/test_plan_refsols/rank_customers_per_nation.txt create mode 100644 tests/test_plan_refsols/rank_customers_per_region.txt create mode 100644 tests/test_plan_refsols/region_nations_backref.txt create mode 100644 tests/test_plan_refsols/regions_sum_line_price.txt create mode 100644 tests/test_plan_refsols/replace_order_by.txt create mode 100644 tests/test_plan_refsols/scan_calc.txt create mode 100644 tests/test_plan_refsols/scan_calc_calc.txt create mode 100644 tests/test_plan_refsols/scan_customer_call_functions.txt create mode 100644 tests/test_plan_refsols/scan_nations.txt create mode 100644 
tests/test_plan_refsols/scan_regions.txt create mode 100644 tests/test_plan_refsols/semi_aggregate.txt create mode 100644 tests/test_plan_refsols/semi_singular.txt create mode 100644 tests/test_plan_refsols/simple_anti_1.txt create mode 100644 tests/test_plan_refsols/simple_anti_2.txt create mode 100644 tests/test_plan_refsols/simple_order_by.txt create mode 100644 tests/test_plan_refsols/simple_semi_1.txt create mode 100644 tests/test_plan_refsols/simple_semi_2.txt create mode 100644 tests/test_plan_refsols/simple_topk.txt create mode 100644 tests/test_plan_refsols/singular_anti.txt create mode 100644 tests/test_plan_refsols/singular_semi.txt create mode 100644 tests/test_plan_refsols/top_5_nations_balance_by_num_suppliers.txt create mode 100644 tests/test_plan_refsols/top_5_nations_by_num_supplierss.txt create mode 100644 tests/test_plan_refsols/topk_order_by.txt create mode 100644 tests/test_plan_refsols/topk_order_by_calc.txt create mode 100644 tests/test_plan_refsols/topk_replace_order_by.txt create mode 100644 tests/test_plan_refsols/topk_root_different_order_by.txt create mode 100644 tests/test_plan_refsols/various_aggfuncs_global.txt create mode 100644 tests/test_plan_refsols/various_aggfuncs_simple.txt diff --git a/tests/conftest.py b/tests/conftest.py index 0bdc62d3..4012c12b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,7 @@ import json import os import sqlite3 -from collections.abc import MutableMapping +from collections.abc import Callable, MutableMapping import pytest from test_utils import graph_fetcher, map_over_dict_values, noun_fetcher @@ -139,6 +139,31 @@ def tpch_node_builder(get_sample_graph) -> AstNodeBuilder: return AstNodeBuilder(get_sample_graph("TPCH")) +@pytest.fixture(scope="session") +def get_plan_test_filename() -> Callable[[str], str]: + """ + A function that takes in a file name and returns the path to that file + from within the directory of plan testing refsol files. 
+ """ + + def impl(file_name: str) -> str: + return f"{os.path.dirname(__file__)}/test_plan_refsols/{file_name}.txt" + + return impl + + +@pytest.fixture +def update_plan_tests() -> bool: + """ + If True, planner tests should update the refsol file instead of verifying + that the test matches the file. If False, the refsol file is used to check + the answer. + + This is controlled by an environment variable `PYDOUGH_UPDATE_TESTS`. + """ + return os.getenv("PYDOUGH_UPDATE_TESTS", "0") == "1" + + @pytest.fixture( params=[ pytest.param(operator, id=operator.binop.name) diff --git a/tests/test_plan_refsols/access_partition_child_after_filter.txt b/tests/test_plan_refsols/access_partition_child_after_filter.txt new file mode 100644 index 00000000..9851a23b --- /dev/null +++ b/tests/test_plan_refsols/access_partition_child_after_filter.txt @@ -0,0 +1,7 @@ +ROOT(columns=[('part_name', part_name), ('part_type', part_type), ('retail_price', retail_price)], orderings=[]) + PROJECT(columns={'part_name': name, 'part_type': part_type_1, 'retail_price': retail_price}) + JOIN(conditions=[t0.part_type == t1.part_type], types=['inner'], columns={'name': t1.name, 'part_type_1': t1.part_type, 'retail_price': t1.retail_price}) + FILTER(condition=agg_0 > 27.5:float64, columns={'part_type': part_type}) + AGGREGATE(keys={'part_type': part_type}, aggregations={'agg_0': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'part_type': p_type, 'retail_price': p_retailprice}) + SCAN(table=tpch.PART, columns={'name': p_name, 'part_type': p_type, 'retail_price': p_retailprice}) diff --git a/tests/test_plan_refsols/access_partition_child_backref_calc.txt b/tests/test_plan_refsols/access_partition_child_backref_calc.txt new file mode 100644 index 00000000..dd0e9955 --- /dev/null +++ b/tests/test_plan_refsols/access_partition_child_backref_calc.txt @@ -0,0 +1,7 @@ +ROOT(columns=[('part_name', part_name), ('part_type', part_type), ('retail_price_versus_avg', retail_price_versus_avg)], 
orderings=[]) + PROJECT(columns={'part_name': name, 'part_type': part_type_1, 'retail_price_versus_avg': retail_price - avg_price}) + JOIN(conditions=[t0.part_type == t1.part_type], types=['inner'], columns={'avg_price': t0.avg_price, 'name': t1.name, 'part_type_1': t1.part_type, 'retail_price': t1.retail_price}) + PROJECT(columns={'avg_price': agg_0, 'part_type': part_type}) + AGGREGATE(keys={'part_type': part_type}, aggregations={'agg_0': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'part_type': p_type, 'retail_price': p_retailprice}) + SCAN(table=tpch.PART, columns={'name': p_name, 'part_type': p_type, 'retail_price': p_retailprice}) diff --git a/tests/test_plan_refsols/access_partition_child_filter_backref_filter.txt b/tests/test_plan_refsols/access_partition_child_filter_backref_filter.txt new file mode 100644 index 00000000..c87b0bc9 --- /dev/null +++ b/tests/test_plan_refsols/access_partition_child_filter_backref_filter.txt @@ -0,0 +1,8 @@ +ROOT(columns=[('part_name', part_name), ('part_type', part_type), ('retail_price', retail_price)], orderings=[]) + FILTER(condition=retail_price < avg_price, columns={'part_name': part_name, 'part_type': part_type, 'retail_price': retail_price}) + PROJECT(columns={'avg_price': avg_price, 'part_name': name, 'part_type': part_type_1, 'retail_price': retail_price}) + JOIN(conditions=[t0.part_type == t1.part_type], types=['inner'], columns={'avg_price': t0.avg_price, 'name': t1.name, 'part_type_1': t1.part_type, 'retail_price': t1.retail_price}) + PROJECT(columns={'avg_price': agg_0, 'part_type': part_type}) + AGGREGATE(keys={'part_type': part_type}, aggregations={'agg_0': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'part_type': p_type, 'retail_price': p_retailprice}) + SCAN(table=tpch.PART, columns={'name': p_name, 'part_type': p_type, 'retail_price': p_retailprice}) diff --git a/tests/test_plan_refsols/agg_max_ranking.txt b/tests/test_plan_refsols/agg_max_ranking.txt new file mode 100644 index 
00000000..4e62ae0c --- /dev/null +++ b/tests/test_plan_refsols/agg_max_ranking.txt @@ -0,0 +1,7 @@ +ROOT(columns=[('nation_name', nation_name), ('highest_rank', highest_rank)], orderings=[]) + PROJECT(columns={'highest_rank': agg_0, 'nation_name': name}) + JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': MAX(cust_rank)}) + PROJECT(columns={'cust_rank': RANKING(args=[], partition=[], order=[(acctbal):desc_first], allow_ties=True), 'nation_key': nation_key}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': c_nationkey}) diff --git a/tests/test_plan_refsols/agg_orders_by_year_month_basic.txt b/tests/test_plan_refsols/agg_orders_by_year_month_basic.txt new file mode 100644 index 00000000..9266a7d6 --- /dev/null +++ b/tests/test_plan_refsols/agg_orders_by_year_month_basic.txt @@ -0,0 +1,5 @@ +ROOT(columns=[('year', year), ('month', month), ('total_orders', total_orders)], orderings=[]) + PROJECT(columns={'month': month, 'total_orders': DEFAULT_TO(agg_0, 0:int64), 'year': year}) + AGGREGATE(keys={'month': month, 'year': year}, aggregations={'agg_0': COUNT()}) + PROJECT(columns={'month': MONTH(order_date), 'year': YEAR(order_date)}) + SCAN(table=tpch.ORDERS, columns={'order_date': o_orderdate}) diff --git a/tests/test_plan_refsols/agg_orders_by_year_month_just_europe.txt b/tests/test_plan_refsols/agg_orders_by_year_month_just_europe.txt new file mode 100644 index 00000000..c5a743f7 --- /dev/null +++ b/tests/test_plan_refsols/agg_orders_by_year_month_just_europe.txt @@ -0,0 +1,16 @@ +ROOT(columns=[('year', year), ('month', month), ('num_european_orders', num_european_orders)], orderings=[]) + PROJECT(columns={'month': month, 'num_european_orders': DEFAULT_TO(agg_0, 0:int64), 'year': year}) + JOIN(conditions=[t0.year == t1.year & t0.month == t1.month], 
types=['left'], columns={'agg_0': t1.agg_0, 'month': t0.month, 'year': t0.year}) + AGGREGATE(keys={'month': month, 'year': year}, aggregations={}) + PROJECT(columns={'month': MONTH(order_date), 'year': YEAR(order_date)}) + SCAN(table=tpch.ORDERS, columns={'order_date': o_orderdate}) + AGGREGATE(keys={'month': month, 'year': year}, aggregations={'agg_0': COUNT()}) + FILTER(condition=name_6 == 'ASIA':string, columns={'month': month, 'year': year}) + JOIN(conditions=[t0.customer_key == t1.key], types=['left'], columns={'month': t0.month, 'name_6': t1.name_6, 'year': t0.year}) + PROJECT(columns={'customer_key': customer_key, 'month': MONTH(order_date), 'year': YEAR(order_date)}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'order_date': o_orderdate}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_6': t1.name}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'region_key': t1.region_key}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/agg_orders_by_year_month_vs_europe.txt b/tests/test_plan_refsols/agg_orders_by_year_month_vs_europe.txt new file mode 100644 index 00000000..49176dfe --- /dev/null +++ b/tests/test_plan_refsols/agg_orders_by_year_month_vs_europe.txt @@ -0,0 +1,16 @@ +ROOT(columns=[('year', year), ('month', month), ('num_european_orders', num_european_orders), ('total_orders', total_orders)], orderings=[]) + PROJECT(columns={'month': month, 'num_european_orders': DEFAULT_TO(agg_0, 0:int64), 'total_orders': DEFAULT_TO(agg_1, 0:int64), 'year': year}) + JOIN(conditions=[t0.year == t1.year & t0.month == t1.month], types=['left'], columns={'agg_0': t0.agg_0, 'agg_1': t1.agg_1, 'month': t0.month, 'year': t0.year}) + AGGREGATE(keys={'month': 
month, 'year': year}, aggregations={'agg_0': COUNT()}) + PROJECT(columns={'month': MONTH(order_date), 'year': YEAR(order_date)}) + SCAN(table=tpch.ORDERS, columns={'order_date': o_orderdate}) + AGGREGATE(keys={'month': month, 'year': year}, aggregations={'agg_1': COUNT()}) + FILTER(condition=name_6 == 'ASIA':string, columns={'month': month, 'year': year}) + JOIN(conditions=[t0.customer_key == t1.key], types=['left'], columns={'month': t0.month, 'name_6': t1.name_6, 'year': t0.year}) + PROJECT(columns={'customer_key': customer_key, 'month': MONTH(order_date), 'year': YEAR(order_date)}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'order_date': o_orderdate}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_6': t1.name}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'region_key': t1.region_key}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/agg_parts_by_type_backref_global.txt b/tests/test_plan_refsols/agg_parts_by_type_backref_global.txt new file mode 100644 index 00000000..bb378572 --- /dev/null +++ b/tests/test_plan_refsols/agg_parts_by_type_backref_global.txt @@ -0,0 +1,9 @@ +ROOT(columns=[('part_type', part_type), ('percentage_of_parts', percentage_of_parts), ('avg_price', avg_price)], orderings=[]) + FILTER(condition=avg_price >= global_avg_price, columns={'avg_price': avg_price, 'part_type': part_type, 'percentage_of_parts': percentage_of_parts}) + PROJECT(columns={'avg_price': agg_2, 'global_avg_price': global_avg_price, 'part_type': part_type, 'percentage_of_parts': DEFAULT_TO(agg_3, 0:int64) / total_num_parts}) + JOIN(conditions=[True:bool], types=['left'], columns={'agg_2': t1.agg_2, 'agg_3': t1.agg_3, 'global_avg_price': 
t0.global_avg_price, 'part_type': t1.part_type, 'total_num_parts': t0.total_num_parts}) + PROJECT(columns={'global_avg_price': agg_0, 'total_num_parts': agg_1}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(retail_price), 'agg_1': COUNT()}) + SCAN(table=tpch.PART, columns={'retail_price': p_retailprice}) + AGGREGATE(keys={'part_type': part_type}, aggregations={'agg_2': AVG(retail_price), 'agg_3': COUNT()}) + SCAN(table=tpch.PART, columns={'part_type': p_type, 'retail_price': p_retailprice}) diff --git a/tests/test_plan_refsols/agg_parts_by_type_simple.txt b/tests/test_plan_refsols/agg_parts_by_type_simple.txt new file mode 100644 index 00000000..08e3b39d --- /dev/null +++ b/tests/test_plan_refsols/agg_parts_by_type_simple.txt @@ -0,0 +1,4 @@ +ROOT(columns=[('part_type', part_type), ('num_parts', num_parts), ('avg_price', avg_price)], orderings=[]) + PROJECT(columns={'avg_price': agg_0, 'num_parts': DEFAULT_TO(agg_1, 0:int64), 'part_type': part_type}) + AGGREGATE(keys={'part_type': part_type}, aggregations={'agg_0': AVG(retail_price), 'agg_1': COUNT()}) + SCAN(table=tpch.PART, columns={'part_type': p_type, 'retail_price': p_retailprice}) diff --git a/tests/test_plan_refsols/aggregate_anti.txt b/tests/test_plan_refsols/aggregate_anti.txt new file mode 100644 index 00000000..6ea25a0a --- /dev/null +++ b/tests/test_plan_refsols/aggregate_anti.txt @@ -0,0 +1,9 @@ +ROOT(columns=[('name', name), ('num_10parts', num_10parts), ('avg_price_of_10parts', avg_price_of_10parts), ('sum_price_of_10parts', sum_price_of_10parts)], orderings=[]) + FILTER(condition=True:bool, columns={'avg_price_of_10parts': avg_price_of_10parts, 'name': name, 'num_10parts': num_10parts, 'sum_price_of_10parts': sum_price_of_10parts}) + PROJECT(columns={'avg_price_of_10parts': NULL_2, 'name': name, 'num_10parts': DEFAULT_TO(NULL_2, 0:int64), 'sum_price_of_10parts': DEFAULT_TO(NULL_2, 0:int64)}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['anti'], columns={'NULL_2': None:unknown, 'name': 
t0.name}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name}) + FILTER(condition=size == 10:int64, columns={'supplier_key': supplier_key}) + JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'size': t1.size, 'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.PART, columns={'key': p_partkey, 'size': p_size}) diff --git a/tests/test_plan_refsols/aggregate_mixed_levels_simple.txt b/tests/test_plan_refsols/aggregate_mixed_levels_simple.txt new file mode 100644 index 00000000..8c723d7c --- /dev/null +++ b/tests/test_plan_refsols/aggregate_mixed_levels_simple.txt @@ -0,0 +1,9 @@ +ROOT(columns=[('order_key', order_key), ('max_ratio', max_ratio)], orderings=[]) + PROJECT(columns={'max_ratio': agg_0, 'order_key': key}) + JOIN(conditions=[t0.key == t1.order_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key}) + SCAN(table=tpch.ORDERS, columns={'key': o_orderkey}) + AGGREGATE(keys={'order_key': order_key}, aggregations={'agg_0': MAX(ratio)}) + PROJECT(columns={'order_key': order_key, 'ratio': quantity / availqty}) + JOIN(conditions=[t0.part_key == t1.part_key & t0.supplier_key == t1.supplier_key], types=['inner'], columns={'availqty': t1.availqty, 'order_key': t0.order_key, 'quantity': t0.quantity}) + SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'part_key': l_partkey, 'quantity': l_quantity, 'supplier_key': l_suppkey}) + SCAN(table=tpch.PARTSUPP, columns={'availqty': ps_availqty, 'part_key': ps_partkey, 'supplier_key': ps_suppkey}) diff --git a/tests/test_plan_refsols/aggregate_on_function_call.txt b/tests/test_plan_refsols/aggregate_on_function_call.txt new file mode 100644 index 00000000..ab067485 --- /dev/null +++ b/tests/test_plan_refsols/aggregate_on_function_call.txt @@ -0,0 +1,6 @@ +ROOT(columns=[('nation_name', nation_name), ('avg_consumer_value', avg_consumer_value)], orderings=[]) + 
PROJECT(columns={'avg_consumer_value': agg_0, 'nation_name': key}) + JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': MAX(IFF(acctbal < 0.0:float64, 0.0:float64, acctbal))}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': c_nationkey}) diff --git a/tests/test_plan_refsols/aggregate_semi.txt b/tests/test_plan_refsols/aggregate_semi.txt new file mode 100644 index 00000000..370980ff --- /dev/null +++ b/tests/test_plan_refsols/aggregate_semi.txt @@ -0,0 +1,10 @@ +ROOT(columns=[('name', name), ('num_10parts', num_10parts), ('avg_price_of_10parts', avg_price_of_10parts), ('sum_price_of_10parts', sum_price_of_10parts)], orderings=[]) + FILTER(condition=True:bool, columns={'avg_price_of_10parts': avg_price_of_10parts, 'name': name, 'num_10parts': num_10parts, 'sum_price_of_10parts': sum_price_of_10parts}) + PROJECT(columns={'avg_price_of_10parts': agg_0, 'name': name, 'num_10parts': DEFAULT_TO(agg_1, 0:int64), 'sum_price_of_10parts': DEFAULT_TO(agg_2, 0:int64)}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['inner'], columns={'agg_0': t1.agg_0, 'agg_1': t1.agg_1, 'agg_2': t1.agg_2, 'name': t0.name}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': AVG(retail_price), 'agg_1': COUNT(), 'agg_2': SUM(retail_price)}) + FILTER(condition=size == 10:int64, columns={'retail_price': retail_price, 'supplier_key': supplier_key}) + JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'retail_price': t1.retail_price, 'size': t1.size, 'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.PART, columns={'key': p_partkey, 'retail_price': p_retailprice, 'size': p_size}) diff --git 
a/tests/test_plan_refsols/aggregate_then_backref.txt b/tests/test_plan_refsols/aggregate_then_backref.txt new file mode 100644 index 00000000..9b0a60f3 --- /dev/null +++ b/tests/test_plan_refsols/aggregate_then_backref.txt @@ -0,0 +1,9 @@ +ROOT(columns=[('part_key', part_key), ('supplier_key', supplier_key), ('order_key', order_key), ('order_quantity_ratio', order_quantity_ratio)], orderings=[]) + PROJECT(columns={'order_key': order_key_2, 'order_quantity_ratio': quantity / total_quantity, 'part_key': part_key, 'supplier_key': supplier_key}) + JOIN(conditions=[t0.key == t1.order_key], types=['inner'], columns={'order_key_2': t1.order_key, 'part_key': t1.part_key, 'quantity': t1.quantity, 'supplier_key': t1.supplier_key, 'total_quantity': t0.total_quantity}) + PROJECT(columns={'key': key, 'total_quantity': DEFAULT_TO(agg_0, 0:int64)}) + JOIN(conditions=[t0.key == t1.order_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key}) + SCAN(table=tpch.ORDERS, columns={'key': o_orderkey}) + AGGREGATE(keys={'order_key': order_key}, aggregations={'agg_0': SUM(quantity)}) + SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'quantity': l_quantity}) + SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'part_key': l_partkey, 'quantity': l_quantity, 'supplier_key': l_suppkey}) diff --git a/tests/test_plan_refsols/anti_aggregate.txt b/tests/test_plan_refsols/anti_aggregate.txt new file mode 100644 index 00000000..f88fd192 --- /dev/null +++ b/tests/test_plan_refsols/anti_aggregate.txt @@ -0,0 +1,9 @@ +ROOT(columns=[('name', name), ('num_10parts', num_10parts), ('avg_price_of_10parts', avg_price_of_10parts), ('sum_price_of_10parts', sum_price_of_10parts)], orderings=[]) + PROJECT(columns={'avg_price_of_10parts': NULL_2, 'name': name, 'num_10parts': DEFAULT_TO(NULL_2, 0:int64), 'sum_price_of_10parts': DEFAULT_TO(NULL_2, 0:int64)}) + FILTER(condition=True:bool, columns={'NULL_2': NULL_2, 'name': name}) + JOIN(conditions=[t0.key == t1.supplier_key], 
types=['anti'], columns={'NULL_2': None:unknown, 'name': t0.name}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name}) + FILTER(condition=size == 10:int64, columns={'supplier_key': supplier_key}) + JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'size': t1.size, 'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.PART, columns={'key': p_partkey, 'size': p_size}) diff --git a/tests/test_plan_refsols/anti_aggregate_alternate.txt b/tests/test_plan_refsols/anti_aggregate_alternate.txt new file mode 100644 index 00000000..ee7c3bdb --- /dev/null +++ b/tests/test_plan_refsols/anti_aggregate_alternate.txt @@ -0,0 +1,9 @@ +ROOT(columns=[('name', name), ('num_10parts', num_10parts), ('avg_price_of_10parts', avg_price_of_10parts), ('sum_price_of_10parts', sum_price_of_10parts)], orderings=[]) + PROJECT(columns={'avg_price_of_10parts': DEFAULT_TO(NULL_2, 0:int64), 'name': name, 'num_10parts': DEFAULT_TO(NULL_2, 0:int64), 'sum_price_of_10parts': NULL_2}) + FILTER(condition=True:bool, columns={'NULL_2': NULL_2, 'name': name}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['anti'], columns={'NULL_2': None:unknown, 'name': t0.name}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name}) + FILTER(condition=size == 10:int64, columns={'supplier_key': supplier_key}) + JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'size': t1.size, 'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.PART, columns={'key': p_partkey, 'size': p_size}) diff --git a/tests/test_plan_refsols/anti_singular.txt b/tests/test_plan_refsols/anti_singular.txt new file mode 100644 index 00000000..2e0def0f --- /dev/null +++ b/tests/test_plan_refsols/anti_singular.txt @@ -0,0 +1,7 @@ +ROOT(columns=[('name', name), ('region_name', region_name)], orderings=[]) + 
PROJECT(columns={'name': name, 'region_name': NULL_1}) + FILTER(condition=True:bool, columns={'NULL_1': NULL_1, 'name': name}) + JOIN(conditions=[t0.region_key == t1.key], types=['anti'], columns={'NULL_1': None:unknown, 'name': t0.name}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) + FILTER(condition=name != 'ASIA':string, columns={'key': key}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/asian_nations.txt b/tests/test_plan_refsols/asian_nations.txt new file mode 100644 index 00000000..0a4cf4fa --- /dev/null +++ b/tests/test_plan_refsols/asian_nations.txt @@ -0,0 +1,5 @@ +ROOT(columns=[('key', key), ('name', name), ('region_key', region_key), ('comment', comment)], orderings=[]) + FILTER(condition=name_3 == 'ASIA':string, columns={'comment': comment, 'key': key, 'name': name, 'region_key': region_key}) + JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'comment': t0.comment, 'key': t0.key, 'name': t0.name, 'name_3': t1.name, 'region_key': t0.region_key}) + SCAN(table=tpch.NATION, columns={'comment': n_comment, 'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/count_at_most_100_suppliers_per_nation.txt b/tests/test_plan_refsols/count_at_most_100_suppliers_per_nation.txt new file mode 100644 index 00000000..ceebad24 --- /dev/null +++ b/tests/test_plan_refsols/count_at_most_100_suppliers_per_nation.txt @@ -0,0 +1,8 @@ +ROOT(columns=[('name', name), ('n_top_suppliers', n_top_suppliers)], orderings=[]) + PROJECT(columns={'n_top_suppliers': DEFAULT_TO(agg_0, 0:int64), 'name': name}) + JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': 
COUNT(key)}) + LIMIT(limit=Literal(value=100, type=Int64Type()), columns={'key': key, 'nation_key': nation_key}, orderings=[(ordering_0):asc_last]) + PROJECT(columns={'key': key, 'nation_key': nation_key, 'ordering_0': account_balance}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) diff --git a/tests/test_plan_refsols/count_cust_supplier_nation_combos.txt b/tests/test_plan_refsols/count_cust_supplier_nation_combos.txt new file mode 100644 index 00000000..2f68470b --- /dev/null +++ b/tests/test_plan_refsols/count_cust_supplier_nation_combos.txt @@ -0,0 +1,17 @@ +ROOT(columns=[('year', year), ('customer_nation', customer_nation), ('supplier_nation', supplier_nation), ('num_occurrences', num_occurrences), ('total_value', total_value)], orderings=[]) + PROJECT(columns={'customer_nation': customer_nation, 'num_occurrences': DEFAULT_TO(agg_0, 0:int64), 'supplier_nation': supplier_nation, 'total_value': DEFAULT_TO(agg_1, 0:int64), 'year': year}) + AGGREGATE(keys={'customer_nation': customer_nation, 'supplier_nation': supplier_nation, 'year': year}, aggregations={'agg_0': COUNT(), 'agg_1': SUM(value)}) + PROJECT(columns={'customer_nation': name, 'supplier_nation': name_18, 'value': extended_price, 'year': YEAR(order_date)}) + JOIN(conditions=[t0.nation_key_14 == t1.key], types=['inner'], columns={'extended_price': t0.extended_price, 'name': t0.name, 'name_18': t1.name, 'order_date': t0.order_date}) + JOIN(conditions=[t0.supplier_key_9 == t1.key], types=['inner'], columns={'extended_price': t0.extended_price, 'name': t0.name, 'nation_key_14': t1.nation_key, 'order_date': t0.order_date}) + JOIN(conditions=[t0.part_key == t1.part_key & t0.supplier_key == t1.supplier_key], types=['inner'], columns={'extended_price': t0.extended_price, 'name': t0.name, 'order_date': t0.order_date, 'supplier_key_9': t1.supplier_key}) + JOIN(conditions=[t0.key_5 == t1.order_key], types=['inner'], columns={'extended_price': 
t1.extended_price, 'name': t0.name, 'order_date': t0.order_date, 'part_key': t1.part_key, 'supplier_key': t1.supplier_key}) + JOIN(conditions=[t0.key_2 == t1.customer_key], types=['inner'], columns={'key_5': t1.key, 'name': t0.name, 'order_date': t1.order_date}) + JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate}) + SCAN(table=tpch.LINEITEM, columns={'extended_price': l_extendedprice, 'order_key': l_orderkey, 'part_key': l_partkey, 'supplier_key': l_suppkey}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) diff --git a/tests/test_plan_refsols/count_multiple_subcollections_alongside_aggs.txt b/tests/test_plan_refsols/count_multiple_subcollections_alongside_aggs.txt new file mode 100644 index 00000000..3bd1e1c1 --- /dev/null +++ b/tests/test_plan_refsols/count_multiple_subcollections_alongside_aggs.txt @@ -0,0 +1,9 @@ +ROOT(columns=[('nation_name', nation_name), ('num_customers', num_customers), ('num_suppliers', num_suppliers), ('customer_to_supplier_wealth_ratio', customer_to_supplier_wealth_ratio)], orderings=[]) + PROJECT(columns={'customer_to_supplier_wealth_ratio': DEFAULT_TO(agg_0, 0:int64) / DEFAULT_TO(agg_1, 0:int64), 'nation_name': key, 'num_customers': DEFAULT_TO(agg_2, 0:int64), 'num_suppliers': DEFAULT_TO(agg_3, 0:int64)}) + JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t0.agg_0, 'agg_1': t1.agg_1, 'agg_2': t0.agg_2, 'agg_3': t1.agg_3, 'key': t0.key}) + JOIN(conditions=[t0.key == t1.nation_key], types=['left'], 
columns={'agg_0': t1.agg_0, 'agg_2': t1.agg_2, 'key': t0.key}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': SUM(acctbal), 'agg_2': COUNT()}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': c_nationkey}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_1': SUM(account_balance), 'agg_3': COUNT()}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'nation_key': s_nationkey}) diff --git a/tests/test_plan_refsols/count_single_subcollection.txt b/tests/test_plan_refsols/count_single_subcollection.txt new file mode 100644 index 00000000..d6bb50ca --- /dev/null +++ b/tests/test_plan_refsols/count_single_subcollection.txt @@ -0,0 +1,6 @@ +ROOT(columns=[('nation_name', nation_name), ('num_customers', num_customers)], orderings=[]) + PROJECT(columns={'nation_name': key, 'num_customers': DEFAULT_TO(agg_0, 0:int64)}) + JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': COUNT()}) + SCAN(table=tpch.CUSTOMER, columns={'nation_key': c_nationkey}) diff --git a/tests/test_plan_refsols/customers_sum_line_price.txt b/tests/test_plan_refsols/customers_sum_line_price.txt new file mode 100644 index 00000000..c065d6d2 --- /dev/null +++ b/tests/test_plan_refsols/customers_sum_line_price.txt @@ -0,0 +1,8 @@ +ROOT(columns=[('okey', okey), ('lsum', lsum)], orderings=[]) + PROJECT(columns={'lsum': DEFAULT_TO(agg_0, 0:int64), 'okey': key}) + JOIN(conditions=[t0.key == t1.customer_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey}) + AGGREGATE(keys={'customer_key': customer_key}, aggregations={'agg_0': SUM(extended_price)}) + JOIN(conditions=[t0.key == t1.order_key], types=['inner'], columns={'customer_key': 
t0.customer_key, 'extended_price': t1.extended_price}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey}) + SCAN(table=tpch.LINEITEM, columns={'extended_price': l_extendedprice, 'order_key': l_orderkey}) diff --git a/tests/test_plan_refsols/global_aggfunc_backref.txt b/tests/test_plan_refsols/global_aggfunc_backref.txt new file mode 100644 index 00000000..0cddc8a7 --- /dev/null +++ b/tests/test_plan_refsols/global_aggfunc_backref.txt @@ -0,0 +1,7 @@ +ROOT(columns=[('part_name', part_name), ('is_above_avg', is_above_avg)], orderings=[]) + PROJECT(columns={'is_above_avg': retail_price > avg_price, 'part_name': name}) + JOIN(conditions=[True:bool], types=['inner'], columns={'avg_price': t0.avg_price, 'name': t1.name, 'retail_price': t1.retail_price}) + PROJECT(columns={'avg_price': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'retail_price': p_retailprice}) + SCAN(table=tpch.PART, columns={'name': p_name, 'retail_price': p_retailprice}) diff --git a/tests/test_plan_refsols/global_aggfuncs.txt b/tests/test_plan_refsols/global_aggfuncs.txt new file mode 100644 index 00000000..c67d4468 --- /dev/null +++ b/tests/test_plan_refsols/global_aggfuncs.txt @@ -0,0 +1,4 @@ +ROOT(columns=[('total_bal', total_bal), ('num_bal', num_bal), ('avg_bal', avg_bal), ('min_bal', min_bal), ('max_bal', max_bal), ('num_cust', num_cust)], orderings=[]) + PROJECT(columns={'avg_bal': agg_0, 'max_bal': agg_1, 'min_bal': agg_2, 'num_bal': agg_3, 'num_cust': agg_4, 'total_bal': DEFAULT_TO(agg_5, 0:int64)}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(acctbal), 'agg_1': MAX(acctbal), 'agg_2': MIN(acctbal), 'agg_3': COUNT(acctbal), 'agg_4': COUNT(), 'agg_5': SUM(acctbal)}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal}) diff --git a/tests/test_plan_refsols/global_aggfuncs_multiple_children.txt b/tests/test_plan_refsols/global_aggfuncs_multiple_children.txt new file mode 100644 index 
00000000..3017fc84 --- /dev/null +++ b/tests/test_plan_refsols/global_aggfuncs_multiple_children.txt @@ -0,0 +1,10 @@ +ROOT(columns=[('num_cust', num_cust), ('num_supp', num_supp), ('num_part', num_part)], orderings=[]) + PROJECT(columns={'num_cust': agg_0, 'num_part': agg_1, 'num_supp': agg_2}) + JOIN(conditions=[True:bool], types=['left'], columns={'agg_0': t0.agg_0, 'agg_1': t1.agg_1, 'agg_2': t0.agg_2}) + JOIN(conditions=[True:bool], types=['left'], columns={'agg_0': t0.agg_0, 'agg_2': t1.agg_2}) + AGGREGATE(keys={}, aggregations={'agg_0': COUNT()}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal}) + AGGREGATE(keys={}, aggregations={'agg_2': COUNT()}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal}) + AGGREGATE(keys={}, aggregations={'agg_1': COUNT()}) + SCAN(table=tpch.PART, columns={'brand': p_brand}) diff --git a/tests/test_plan_refsols/global_calc_backref.txt b/tests/test_plan_refsols/global_calc_backref.txt new file mode 100644 index 00000000..22e3900a --- /dev/null +++ b/tests/test_plan_refsols/global_calc_backref.txt @@ -0,0 +1,6 @@ +ROOT(columns=[('part_name', part_name), ('is_above_cutoff', is_above_cutoff), ('is_nickel', is_nickel)], orderings=[]) + PROJECT(columns={'is_above_cutoff': retail_price > a, 'is_nickel': CONTAINS(part_type, b), 'part_name': name}) + JOIN(conditions=[True:bool], types=['inner'], columns={'a': t0.a, 'b': t0.b, 'name': t1.name, 'part_type': t1.part_type, 'retail_price': t1.retail_price}) + PROJECT(columns={'a': 28.15:int64, 'b': 'NICKEL':string}) + EMPTYSINGLETON() + SCAN(table=tpch.PART, columns={'name': p_name, 'part_type': p_type, 'retail_price': p_retailprice}) diff --git a/tests/test_plan_refsols/global_calc_multiple.txt b/tests/test_plan_refsols/global_calc_multiple.txt new file mode 100644 index 00000000..432b36ac --- /dev/null +++ b/tests/test_plan_refsols/global_calc_multiple.txt @@ -0,0 +1,4 @@ +ROOT(columns=[('a', a), ('b', b), ('c', c), ('d', d)], orderings=[]) + PROJECT(columns={'a': 
a, 'b': b, 'c': 3.14:float64, 'd': True:bool}) + PROJECT(columns={'a': 0:int64, 'b': 'X':string}) + EMPTYSINGLETON() diff --git a/tests/test_plan_refsols/global_calc_simple.txt b/tests/test_plan_refsols/global_calc_simple.txt new file mode 100644 index 00000000..ea1d190d --- /dev/null +++ b/tests/test_plan_refsols/global_calc_simple.txt @@ -0,0 +1,3 @@ +ROOT(columns=[('a', a), ('b', b), ('c', c), ('d', d)], orderings=[]) + PROJECT(columns={'a': 0:int64, 'b': 'X':string, 'c': 3.14:float64, 'd': True:bool}) + EMPTYSINGLETON() diff --git a/tests/test_plan_refsols/join_asia_region_nations.txt b/tests/test_plan_refsols/join_asia_region_nations.txt new file mode 100644 index 00000000..728a048c --- /dev/null +++ b/tests/test_plan_refsols/join_asia_region_nations.txt @@ -0,0 +1,6 @@ +ROOT(columns=[('key', key), ('name', name), ('region_key', region_key), ('comment', comment)], orderings=[]) + PROJECT(columns={'comment': comment_1, 'key': key_2, 'name': name_3, 'region_key': region_key}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'comment_1': t1.comment, 'key_2': t1.key, 'name_3': t1.name, 'region_key': t1.region_key}) + FILTER(condition=name == 'ASIA':string, columns={'key': key}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'comment': n_comment, 'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) diff --git a/tests/test_plan_refsols/join_order_by.txt b/tests/test_plan_refsols/join_order_by.txt new file mode 100644 index 00000000..cccd558d --- /dev/null +++ b/tests/test_plan_refsols/join_order_by.txt @@ -0,0 +1,6 @@ +ROOT(columns=[('region_name', region_name), ('nation_name', nation_name)], orderings=[(ordering_0):desc_last]) + PROJECT(columns={'nation_name': name_3, 'ordering_0': ordering_0, 'region_name': name}) + PROJECT(columns={'name': name, 'name_3': name_3, 'ordering_0': name_3}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'name': t0.name, 
'name_3': t1.name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) diff --git a/tests/test_plan_refsols/join_order_by_back_reference.txt b/tests/test_plan_refsols/join_order_by_back_reference.txt new file mode 100644 index 00000000..4c78747c --- /dev/null +++ b/tests/test_plan_refsols/join_order_by_back_reference.txt @@ -0,0 +1,6 @@ +ROOT(columns=[('region_name', region_name), ('nation_name', nation_name)], orderings=[(ordering_0):desc_last]) + PROJECT(columns={'nation_name': nation_name, 'ordering_0': name, 'region_name': region_name}) + PROJECT(columns={'name': name, 'nation_name': name_3, 'region_name': name}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) diff --git a/tests/test_plan_refsols/join_order_by_pruned_back_reference.txt b/tests/test_plan_refsols/join_order_by_pruned_back_reference.txt new file mode 100644 index 00000000..58b2a894 --- /dev/null +++ b/tests/test_plan_refsols/join_order_by_pruned_back_reference.txt @@ -0,0 +1,6 @@ +ROOT(columns=[('nation_name', nation_name)], orderings=[(ordering_0):desc_last]) + PROJECT(columns={'nation_name': nation_name, 'ordering_0': name}) + PROJECT(columns={'name': name, 'nation_name': name_3}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) diff --git a/tests/test_plan_refsols/join_region_nations.txt b/tests/test_plan_refsols/join_region_nations.txt new file mode 100644 index 00000000..e3d18998 --- /dev/null +++ b/tests/test_plan_refsols/join_region_nations.txt @@ -0,0 +1,5 @@ 
+ROOT(columns=[('key', key), ('name', name), ('region_key', region_key), ('comment', comment)], orderings=[]) + PROJECT(columns={'comment': comment_1, 'key': key_2, 'name': name_3, 'region_key': region_key}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'comment_1': t1.comment, 'key_2': t1.key, 'name_3': t1.name, 'region_key': t1.region_key}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey}) + SCAN(table=tpch.NATION, columns={'comment': n_comment, 'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) diff --git a/tests/test_plan_refsols/join_region_nations_customers.txt b/tests/test_plan_refsols/join_region_nations_customers.txt new file mode 100644 index 00000000..dd4770f7 --- /dev/null +++ b/tests/test_plan_refsols/join_region_nations_customers.txt @@ -0,0 +1,7 @@ +ROOT(columns=[('key', key), ('name', name), ('address', address), ('nation_key', nation_key), ('phone', phone), ('acctbal', acctbal), ('mktsegment', mktsegment), ('comment', comment)], orderings=[]) + PROJECT(columns={'acctbal': acctbal, 'address': address, 'comment': comment_4, 'key': key_5, 'mktsegment': mktsegment, 'name': name_6, 'nation_key': nation_key, 'phone': phone}) + JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'acctbal': t1.acctbal, 'address': t1.address, 'comment_4': t1.comment, 'key_5': t1.key, 'mktsegment': t1.mktsegment, 'name_6': t1.name, 'nation_key': t1.nation_key, 'phone': t1.phone}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': t1.key}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'address': c_address, 'comment': c_comment, 'key': c_custkey, 'mktsegment': c_mktsegment, 'name': c_name, 'nation_key': c_nationkey, 'phone': c_phone}) diff --git a/tests/test_plan_refsols/join_regions_nations_calc_override.txt 
b/tests/test_plan_refsols/join_regions_nations_calc_override.txt new file mode 100644 index 00000000..c133df59 --- /dev/null +++ b/tests/test_plan_refsols/join_regions_nations_calc_override.txt @@ -0,0 +1,7 @@ +ROOT(columns=[('key', key_0), ('name', name), ('phone', phone), ('mktsegment', mktsegment)], orderings=[]) + PROJECT(columns={'key_0': -3:int64, 'mktsegment': mktsegment, 'name': name_6, 'phone': phone}) + JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'mktsegment': t1.mktsegment, 'name_6': t1.name, 'phone': t1.phone}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': t1.key}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.CUSTOMER, columns={'mktsegment': c_mktsegment, 'name': c_name, 'nation_key': c_nationkey, 'phone': c_phone}) diff --git a/tests/test_plan_refsols/join_topk.txt b/tests/test_plan_refsols/join_topk.txt new file mode 100644 index 00000000..061f9c85 --- /dev/null +++ b/tests/test_plan_refsols/join_topk.txt @@ -0,0 +1,7 @@ +ROOT(columns=[('region_name', region_name), ('nation_name', nation_name)], orderings=[(ordering_0):asc_last]) + PROJECT(columns={'nation_name': name_3, 'ordering_0': ordering_0, 'region_name': name}) + LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'name': name, 'name_3': name_3, 'ordering_0': ordering_0}, orderings=[(ordering_0):asc_last]) + PROJECT(columns={'name': name, 'name_3': name_3, 'ordering_0': name_3}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) diff --git a/tests/test_plan_refsols/lineitem_regional_shipments.txt b/tests/test_plan_refsols/lineitem_regional_shipments.txt new file mode 100644 index 00000000..8706500f --- /dev/null 
+++ b/tests/test_plan_refsols/lineitem_regional_shipments.txt @@ -0,0 +1,20 @@ +ROOT(columns=[('rname', rname), ('price', price)], orderings=[]) + PROJECT(columns={'price': extended_price, 'rname': name}) + FILTER(condition=name == name_16, columns={'extended_price': extended_price, 'name': name}) + JOIN(conditions=[t0.part_key == t1.part_key & t0.supplier_key == t1.supplier_key], types=['left'], columns={'extended_price': t0.extended_price, 'name': t0.name, 'name_16': t1.name_16}) + JOIN(conditions=[t0.key_8 == t1.order_key], types=['inner'], columns={'extended_price': t1.extended_price, 'name': t0.name, 'part_key': t1.part_key, 'supplier_key': t1.supplier_key}) + JOIN(conditions=[t0.key_5 == t1.customer_key], types=['inner'], columns={'key_8': t1.key, 'name': t0.name}) + JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'key_5': t1.key, 'name': t0.name}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey}) + SCAN(table=tpch.LINEITEM, columns={'extended_price': l_extendedprice, 'order_key': l_orderkey, 'part_key': l_partkey, 'supplier_key': l_suppkey}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'name_16': t1.name, 'part_key': t0.part_key, 'supplier_key': t0.supplier_key}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'part_key': t0.part_key, 'region_key': t1.region_key, 'supplier_key': t0.supplier_key}) + JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'nation_key': t1.nation_key, 'part_key': t0.part_key, 'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 
'supplier_key': ps_suppkey}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/lineitem_regional_shipments2.txt b/tests/test_plan_refsols/lineitem_regional_shipments2.txt new file mode 100644 index 00000000..9d028765 --- /dev/null +++ b/tests/test_plan_refsols/lineitem_regional_shipments2.txt @@ -0,0 +1,20 @@ +ROOT(columns=[('rname', rname), ('price', price)], orderings=[]) + PROJECT(columns={'price': extended_price, 'rname': name_8}) + FILTER(condition=name_8 == name_15, columns={'extended_price': extended_price, 'name_8': name_8}) + JOIN(conditions=[t0.part_key == t1.part_key & t0.supplier_key == t1.supplier_key], types=['left'], columns={'extended_price': t0.extended_price, 'name_15': t1.name_15, 'name_8': t0.name_8}) + JOIN(conditions=[t0.order_key == t1.key], types=['left'], columns={'extended_price': t0.extended_price, 'name_8': t1.name_8, 'part_key': t0.part_key, 'supplier_key': t0.supplier_key}) + SCAN(table=tpch.LINEITEM, columns={'extended_price': l_extendedprice, 'order_key': l_orderkey, 'part_key': l_partkey, 'supplier_key': l_suppkey}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_8': t1.name}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'region_key': t1.region_key}) + JOIN(conditions=[t0.customer_key == t1.key], types=['inner'], columns={'key': t0.key, 'nation_key': t1.nation_key}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + JOIN(conditions=[t0.region_key == t1.key], 
types=['inner'], columns={'name_15': t1.name, 'part_key': t0.part_key, 'supplier_key': t0.supplier_key}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'part_key': t0.part_key, 'region_key': t1.region_key, 'supplier_key': t0.supplier_key}) + JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'nation_key': t1.nation_key, 'part_key': t0.part_key, 'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/lineitem_regional_shipments3.txt b/tests/test_plan_refsols/lineitem_regional_shipments3.txt new file mode 100644 index 00000000..f06cd9e3 --- /dev/null +++ b/tests/test_plan_refsols/lineitem_regional_shipments3.txt @@ -0,0 +1,20 @@ +ROOT(columns=[('key', key), ('name', name), ('comment', comment)], orderings=[]) + PROJECT(columns={'comment': comment_31, 'key': key_32, 'name': name_33}) + FILTER(condition=name_33 == name, columns={'comment_31': comment_31, 'key_32': key_32, 'name_33': name_33}) + JOIN(conditions=[t0.region_key_30 == t1.key], types=['inner'], columns={'comment_31': t1.comment, 'key_32': t1.key, 'name': t0.name, 'name_33': t1.name}) + JOIN(conditions=[t0.nation_key_25 == t1.key], types=['inner'], columns={'name': t0.name, 'region_key_30': t1.region_key}) + JOIN(conditions=[t0.customer_key_12 == t1.key], types=['inner'], columns={'name': t0.name, 'nation_key_25': t1.nation_key}) + JOIN(conditions=[t0.order_key == t1.key], types=['inner'], columns={'customer_key_12': t1.customer_key, 'name': t0.name}) + JOIN(conditions=[t0.key_8 == t1.order_key], types=['inner'], columns={'name': t0.name, 'order_key': t1.order_key}) + JOIN(conditions=[t0.key_5 == t1.customer_key], 
types=['inner'], columns={'key_8': t1.key, 'name': t0.name}) + JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'key_5': t1.key, 'name': t0.name}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey}) + SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'comment': r_comment, 'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/lineitems_access_cust_supplier_nations.txt b/tests/test_plan_refsols/lineitems_access_cust_supplier_nations.txt new file mode 100644 index 00000000..a7db6ab8 --- /dev/null +++ b/tests/test_plan_refsols/lineitems_access_cust_supplier_nations.txt @@ -0,0 +1,15 @@ +ROOT(columns=[('ship_year', ship_year), ('supplier_nation', supplier_nation), ('customer_nation', customer_nation), ('value', value)], orderings=[]) + PROJECT(columns={'customer_nation': name_9, 'ship_year': YEAR(ship_date), 'supplier_nation': name_4, 'value': extended_price * 1.0:float64 - discount}) + JOIN(conditions=[t0.order_key == t1.key], types=['left'], columns={'discount': t0.discount, 'extended_price': t0.extended_price, 'name_4': t0.name_4, 'name_9': t1.name_9, 'ship_date': t0.ship_date}) + JOIN(conditions=[t0.part_key == t1.part_key & t0.supplier_key == t1.supplier_key], types=['left'], columns={'discount': t0.discount, 'extended_price': t0.extended_price, 'name_4': t1.name_4, 
'order_key': t0.order_key, 'ship_date': t0.ship_date}) + SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'order_key': l_orderkey, 'part_key': l_partkey, 'ship_date': l_shipdate, 'supplier_key': l_suppkey}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'name_4': t1.name, 'part_key': t0.part_key, 'supplier_key': t0.supplier_key}) + JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'nation_key': t1.nation_key, 'part_key': t0.part_key, 'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_9': t1.name}) + JOIN(conditions=[t0.customer_key == t1.key], types=['inner'], columns={'key': t0.key, 'nation_key': t1.nation_key}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) diff --git a/tests/test_plan_refsols/lines_german_supplier_economy_part.txt b/tests/test_plan_refsols/lines_german_supplier_economy_part.txt new file mode 100644 index 00000000..f55f2351 --- /dev/null +++ b/tests/test_plan_refsols/lines_german_supplier_economy_part.txt @@ -0,0 +1,13 @@ +ROOT(columns=[('order_key', order_key), ('ship_date', ship_date), ('extended_price', extended_price)], orderings=[]) + FILTER(condition=name_4 == 'GERMANY':string & STARTSWITH(part_type, 'ECONOMY':string), columns={'extended_price': extended_price, 'order_key': order_key, 'ship_date': ship_date}) + JOIN(conditions=[t0.part_key == t1.part_key & t0.supplier_key == t1.supplier_key], types=['left'], columns={'extended_price': 
t0.extended_price, 'name_4': t0.name_4, 'order_key': t0.order_key, 'part_type': t1.part_type, 'ship_date': t0.ship_date}) + JOIN(conditions=[t0.part_key == t1.part_key & t0.supplier_key == t1.supplier_key], types=['left'], columns={'extended_price': t0.extended_price, 'name_4': t1.name_4, 'order_key': t0.order_key, 'part_key': t0.part_key, 'ship_date': t0.ship_date, 'supplier_key': t0.supplier_key}) + SCAN(table=tpch.LINEITEM, columns={'extended_price': l_extendedprice, 'order_key': l_orderkey, 'part_key': l_partkey, 'ship_date': l_shipdate, 'supplier_key': l_suppkey}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'name_4': t1.name, 'part_key': t0.part_key, 'supplier_key': t0.supplier_key}) + JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'nation_key': t1.nation_key, 'part_key': t0.part_key, 'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'part_key': t0.part_key, 'part_type': t1.part_type, 'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.PART, columns={'key': p_partkey, 'part_type': p_type}) diff --git a/tests/test_plan_refsols/lines_shipping_vs_customer_region.txt b/tests/test_plan_refsols/lines_shipping_vs_customer_region.txt new file mode 100644 index 00000000..643e3473 --- /dev/null +++ b/tests/test_plan_refsols/lines_shipping_vs_customer_region.txt @@ -0,0 +1,20 @@ +ROOT(columns=[('order_year', order_year), ('customer_region', customer_region), ('customer_nation', customer_nation), ('supplier_region', supplier_region), ('nation_name', nation_name)], orderings=[]) + PROJECT(columns={'customer_nation': name_3, 'customer_region': 
name, 'nation_name': nation_name, 'order_year': YEAR(order_date), 'supplier_region': name_16}) + JOIN(conditions=[t0.part_key == t1.part_key & t0.supplier_key == t1.supplier_key], types=['left'], columns={'name': t0.name, 'name_16': t1.name_16, 'name_3': t0.name_3, 'nation_name': t1.nation_name, 'order_date': t0.order_date}) + JOIN(conditions=[t0.key_8 == t1.order_key], types=['inner'], columns={'name': t0.name, 'name_3': t0.name_3, 'order_date': t0.order_date, 'part_key': t1.part_key, 'supplier_key': t1.supplier_key}) + JOIN(conditions=[t0.key_5 == t1.customer_key], types=['inner'], columns={'key_8': t1.key, 'name': t0.name, 'name_3': t0.name_3, 'order_date': t1.order_date}) + JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'key_5': t1.key, 'name': t0.name, 'name_3': t0.name_3}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate}) + SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'part_key': l_partkey, 'supplier_key': l_suppkey}) + PROJECT(columns={'name_16': name_16, 'nation_name': name_13, 'part_key': part_key, 'supplier_key': supplier_key}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'name_13': t0.name_13, 'name_16': t1.name, 'part_key': t0.part_key, 'supplier_key': t0.supplier_key}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'name_13': t1.name, 'part_key': t0.part_key, 'region_key': t1.region_key, 'supplier_key': t0.supplier_key}) + JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'nation_key': t1.nation_key, 'part_key': 
t0.part_key, 'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/mostly_positive_accounts_per_nation1.txt b/tests/test_plan_refsols/mostly_positive_accounts_per_nation1.txt new file mode 100644 index 00000000..9bb76c35 --- /dev/null +++ b/tests/test_plan_refsols/mostly_positive_accounts_per_nation1.txt @@ -0,0 +1,10 @@ +ROOT(columns=[('name', name)], orderings=[]) + FILTER(condition=DEFAULT_TO(agg_0, 0:int64) > 0.5:float64 * DEFAULT_TO(agg_1, 0:int64), columns={'name': name}) + JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t0.agg_0, 'agg_1': t1.agg_1, 'name': t0.name}) + JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key, 'name': t0.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': COUNT(key)}) + FILTER(condition=account_balance > 0.0:float64, columns={'key': key, 'nation_key': nation_key}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_1': COUNT(key)}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) diff --git a/tests/test_plan_refsols/mostly_positive_accounts_per_nation2.txt b/tests/test_plan_refsols/mostly_positive_accounts_per_nation2.txt new file mode 100644 index 00000000..ca15fc25 --- /dev/null +++ b/tests/test_plan_refsols/mostly_positive_accounts_per_nation2.txt @@ -0,0 +1,11 @@ +ROOT(columns=[('name', name), ('suppliers_in_black', suppliers_in_black), 
('total_suppliers', total_suppliers)], orderings=[]) + FILTER(condition=DEFAULT_TO(agg_2, 0:int64) > 0.5:float64 * DEFAULT_TO(agg_3, 0:int64), columns={'name': name, 'suppliers_in_black': suppliers_in_black, 'total_suppliers': total_suppliers}) + PROJECT(columns={'agg_2': agg_2, 'agg_3': agg_3, 'name': name, 'suppliers_in_black': DEFAULT_TO(agg_0, 0:int64), 'total_suppliers': DEFAULT_TO(agg_1, 0:int64)}) + JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t0.agg_0, 'agg_1': t1.agg_1, 'agg_2': t0.agg_2, 'agg_3': t1.agg_3, 'name': t0.name}) + JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'agg_2': t1.agg_2, 'key': t0.key, 'name': t0.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': COUNT(key), 'agg_2': COUNT(key)}) + FILTER(condition=account_balance > 0.0:float64, columns={'key': key, 'nation_key': nation_key}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_1': COUNT(key), 'agg_3': COUNT(key)}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) diff --git a/tests/test_plan_refsols/mostly_positive_accounts_per_nation3.txt b/tests/test_plan_refsols/mostly_positive_accounts_per_nation3.txt new file mode 100644 index 00000000..cf731a44 --- /dev/null +++ b/tests/test_plan_refsols/mostly_positive_accounts_per_nation3.txt @@ -0,0 +1,11 @@ +ROOT(columns=[('name', name), ('suppliers_in_black', suppliers_in_black), ('total_suppliers', total_suppliers)], orderings=[]) + FILTER(condition=suppliers_in_black > 0.5:float64 * total_suppliers, columns={'name': name, 'suppliers_in_black': suppliers_in_black, 'total_suppliers': total_suppliers}) + PROJECT(columns={'name': name, 'suppliers_in_black': DEFAULT_TO(agg_0, 0:int64), 'total_suppliers': DEFAULT_TO(agg_1, 
0:int64)}) + JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t0.agg_0, 'agg_1': t1.agg_1, 'name': t0.name}) + JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key, 'name': t0.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': COUNT(key)}) + FILTER(condition=account_balance > 0.0:float64, columns={'key': key, 'nation_key': nation_key}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_1': COUNT(key)}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) diff --git a/tests/test_plan_refsols/multiple_has_hasnot.txt b/tests/test_plan_refsols/multiple_has_hasnot.txt new file mode 100644 index 00000000..65e53dd4 --- /dev/null +++ b/tests/test_plan_refsols/multiple_has_hasnot.txt @@ -0,0 +1,24 @@ +ROOT(columns=[('name', name)], orderings=[]) + FILTER(condition=True:bool & True:bool & True:bool, columns={'name': name}) + JOIN(conditions=[t0.key == t1.part_key], types=['semi'], columns={'name': t0.name}) + JOIN(conditions=[t0.key == t1.part_key], types=['anti'], columns={'key': t0.key, 'name': t0.name}) + JOIN(conditions=[t0.key == t1.part_key], types=['semi'], columns={'key': t0.key, 'name': t0.name}) + SCAN(table=tpch.PART, columns={'key': p_partkey, 'name': p_name}) + FILTER(condition=name_4 == 'GERMANY':string, columns={'part_key': part_key}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'name_4': t1.name, 'part_key': t0.part_key}) + JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'nation_key': t1.nation_key, 'part_key': t0.part_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': 
s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + FILTER(condition=name_8 == 'FRANCE':string, columns={'part_key': part_key}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'name_8': t1.name, 'part_key': t0.part_key}) + JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'nation_key': t1.nation_key, 'part_key': t0.part_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + FILTER(condition=name_12 == 'ARGENTINA':string, columns={'part_key': part_key}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'name_12': t1.name, 'part_key': t0.part_key}) + JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'nation_key': t1.nation_key, 'part_key': t0.part_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) diff --git a/tests/test_plan_refsols/multiple_simple_aggregations_multiple_calcs.txt b/tests/test_plan_refsols/multiple_simple_aggregations_multiple_calcs.txt new file mode 100644 index 00000000..8431aba3 --- /dev/null +++ b/tests/test_plan_refsols/multiple_simple_aggregations_multiple_calcs.txt @@ -0,0 +1,11 @@ +ROOT(columns=[('nation_name', nation_name_0), ('total_consumer_value', total_consumer_value), ('total_supplier_value', total_supplier_value), ('avg_consumer_value', avg_consumer_value), ('avg_supplier_value', avg_supplier_value), ('best_consumer_value', best_consumer_value), ('best_supplier_value', best_supplier_value)], orderings=[]) + PROJECT(columns={'avg_consumer_value': avg_consumer_value, 'avg_supplier_value': avg_supplier_value, 
'best_consumer_value': agg_4, 'best_supplier_value': agg_5, 'nation_name_0': key, 'total_consumer_value': total_consumer_value, 'total_supplier_value': total_supplier_value}) + PROJECT(columns={'agg_4': agg_4, 'agg_5': agg_5, 'avg_consumer_value': avg_consumer_value, 'avg_supplier_value': agg_2, 'key': key, 'total_consumer_value': total_consumer_value, 'total_supplier_value': DEFAULT_TO(agg_3, 0:int64)}) + JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_2': t1.agg_2, 'agg_3': t1.agg_3, 'agg_4': t0.agg_4, 'agg_5': t1.agg_5, 'avg_consumer_value': t0.avg_consumer_value, 'key': t0.key, 'total_consumer_value': t0.total_consumer_value}) + PROJECT(columns={'agg_4': agg_4, 'avg_consumer_value': agg_0, 'key': key, 'total_consumer_value': DEFAULT_TO(agg_1, 0:int64)}) + JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'agg_1': t1.agg_1, 'agg_4': t1.agg_4, 'key': t0.key}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': AVG(acctbal), 'agg_1': SUM(acctbal), 'agg_4': MAX(acctbal)}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': c_nationkey}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_2': AVG(account_balance), 'agg_3': SUM(account_balance), 'agg_5': MAX(account_balance)}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'nation_key': s_nationkey}) diff --git a/tests/test_plan_refsols/multiple_simple_aggregations_single_calc.txt b/tests/test_plan_refsols/multiple_simple_aggregations_single_calc.txt new file mode 100644 index 00000000..c93f153a --- /dev/null +++ b/tests/test_plan_refsols/multiple_simple_aggregations_single_calc.txt @@ -0,0 +1,9 @@ +ROOT(columns=[('nation_name', nation_name), ('consumer_value', consumer_value), ('producer_value', producer_value)], orderings=[]) + PROJECT(columns={'consumer_value': DEFAULT_TO(agg_0, 0:int64), 'nation_name': key, 'producer_value': 
DEFAULT_TO(agg_1, 0:int64)}) + JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t0.agg_0, 'agg_1': t1.agg_1, 'key': t0.key}) + JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': SUM(acctbal)}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': c_nationkey}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_1': SUM(account_balance)}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'nation_key': s_nationkey}) diff --git a/tests/test_plan_refsols/nation_name_contains_region_name.txt b/tests/test_plan_refsols/nation_name_contains_region_name.txt new file mode 100644 index 00000000..49891184 --- /dev/null +++ b/tests/test_plan_refsols/nation_name_contains_region_name.txt @@ -0,0 +1,6 @@ +ROOT(columns=[('key', key), ('name', name), ('region_key', region_key), ('comment', comment)], orderings=[]) + PROJECT(columns={'comment': comment_1, 'key': key_2, 'name': name_3, 'region_key': region_key}) + FILTER(condition=CONTAINS(name_3, name), columns={'comment_1': comment_1, 'key_2': key_2, 'name_3': name_3, 'region_key': region_key}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'comment_1': t1.comment, 'key_2': t1.key, 'name': t0.name, 'name_3': t1.name, 'region_key': t1.region_key}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'comment': n_comment, 'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) diff --git a/tests/test_plan_refsols/nations_access_region.txt b/tests/test_plan_refsols/nations_access_region.txt new file mode 100644 index 00000000..27f04db0 --- /dev/null +++ b/tests/test_plan_refsols/nations_access_region.txt @@ -0,0 +1,5 @@ +ROOT(columns=[('nation_name', nation_name), ('region_name', region_name)], 
orderings=[]) + PROJECT(columns={'nation_name': name, 'region_name': name_3}) + JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/nations_order_by_num_suppliers.txt b/tests/test_plan_refsols/nations_order_by_num_suppliers.txt new file mode 100644 index 00000000..d4485a6e --- /dev/null +++ b/tests/test_plan_refsols/nations_order_by_num_suppliers.txt @@ -0,0 +1,6 @@ +ROOT(columns=[('key', key), ('name', name), ('region_key', region_key), ('comment', comment)], orderings=[(ordering_0):asc_last]) + PROJECT(columns={'comment': comment, 'key': key, 'name': name, 'ordering_0': DEFAULT_TO(agg_1, 0:int64), 'region_key': region_key}) + JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_1': t1.agg_1, 'comment': t0.comment, 'key': t0.key, 'name': t0.name, 'region_key': t0.region_key}) + SCAN(table=tpch.NATION, columns={'comment': n_comment, 'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_1': COUNT(key)}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) diff --git a/tests/test_plan_refsols/nations_region_order_by_name.txt b/tests/test_plan_refsols/nations_region_order_by_name.txt new file mode 100644 index 00000000..9ce22933 --- /dev/null +++ b/tests/test_plan_refsols/nations_region_order_by_name.txt @@ -0,0 +1,5 @@ +ROOT(columns=[('key', key), ('name', name), ('region_key', region_key), ('comment', comment)], orderings=[(ordering_0):asc_last, (ordering_1):asc_last]) + PROJECT(columns={'comment': comment, 'key': key, 'name': name, 'ordering_0': name, 'ordering_1': name_3, 'region_key': region_key}) + JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'comment': t0.comment, 'key': t0.key, 
'name': t0.name, 'name_3': t1.name, 'region_key': t0.region_key}) + SCAN(table=tpch.NATION, columns={'comment': n_comment, 'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/nations_sum_line_price.txt b/tests/test_plan_refsols/nations_sum_line_price.txt new file mode 100644 index 00000000..cb32fb58 --- /dev/null +++ b/tests/test_plan_refsols/nations_sum_line_price.txt @@ -0,0 +1,10 @@ +ROOT(columns=[('okey', okey), ('lsum', lsum)], orderings=[]) + PROJECT(columns={'lsum': DEFAULT_TO(agg_0, 0:int64), 'okey': key}) + JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': SUM(extended_price)}) + JOIN(conditions=[t0.key_2 == t1.order_key], types=['inner'], columns={'extended_price': t1.extended_price, 'nation_key': t0.nation_key}) + JOIN(conditions=[t0.key == t1.customer_key], types=['inner'], columns={'key_2': t1.key, 'nation_key': t0.nation_key}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey}) + SCAN(table=tpch.LINEITEM, columns={'extended_price': l_extendedprice, 'order_key': l_orderkey}) diff --git a/tests/test_plan_refsols/num_positive_accounts_per_nation.txt b/tests/test_plan_refsols/num_positive_accounts_per_nation.txt new file mode 100644 index 00000000..d23aeab5 --- /dev/null +++ b/tests/test_plan_refsols/num_positive_accounts_per_nation.txt @@ -0,0 +1,10 @@ +ROOT(columns=[('name', name), ('suppliers_in_black', suppliers_in_black), ('total_suppliers', total_suppliers)], orderings=[]) + PROJECT(columns={'name': name, 'suppliers_in_black': DEFAULT_TO(agg_0, 0:int64), 'total_suppliers': DEFAULT_TO(agg_1, 0:int64)}) + JOIN(conditions=[t0.key == t1.nation_key], 
types=['left'], columns={'agg_0': t0.agg_0, 'agg_1': t1.agg_1, 'name': t0.name}) + JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key, 'name': t0.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': COUNT(key)}) + FILTER(condition=account_balance > 0.0:float64, columns={'key': key, 'nation_key': nation_key}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_1': COUNT(key)}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) diff --git a/tests/test_plan_refsols/order_by_before_join.txt b/tests/test_plan_refsols/order_by_before_join.txt new file mode 100644 index 00000000..e3d18998 --- /dev/null +++ b/tests/test_plan_refsols/order_by_before_join.txt @@ -0,0 +1,5 @@ +ROOT(columns=[('key', key), ('name', name), ('region_key', region_key), ('comment', comment)], orderings=[]) + PROJECT(columns={'comment': comment_1, 'key': key_2, 'name': name_3, 'region_key': region_key}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'comment_1': t1.comment, 'key_2': t1.key, 'name_3': t1.name, 'region_key': t1.region_key}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey}) + SCAN(table=tpch.NATION, columns={'comment': n_comment, 'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) diff --git a/tests/test_plan_refsols/order_by_expression.txt b/tests/test_plan_refsols/order_by_expression.txt new file mode 100644 index 00000000..60ef0a01 --- /dev/null +++ b/tests/test_plan_refsols/order_by_expression.txt @@ -0,0 +1,4 @@ +ROOT(columns=[('key', key), ('name', name), ('comment', comment)], orderings=[(ordering_1):asc_first]) + LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'comment': comment, 'key': key, 'name': name, 'ordering_1': ordering_1}, 
orderings=[(ordering_1):asc_first]) + PROJECT(columns={'comment': comment, 'key': key, 'name': name, 'ordering_1': LENGTH(name)}) + SCAN(table=tpch.REGION, columns={'comment': r_comment, 'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/ordered_asian_nations.txt b/tests/test_plan_refsols/ordered_asian_nations.txt new file mode 100644 index 00000000..bb5ef871 --- /dev/null +++ b/tests/test_plan_refsols/ordered_asian_nations.txt @@ -0,0 +1,6 @@ +ROOT(columns=[('key', key), ('name', name), ('region_key', region_key), ('comment', comment)], orderings=[(ordering_0):asc_last]) + FILTER(condition=name_3 == 'ASIA':string, columns={'comment': comment, 'key': key, 'name': name, 'ordering_0': ordering_0, 'region_key': region_key}) + JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'comment': t0.comment, 'key': t0.key, 'name': t0.name, 'name_3': t1.name, 'ordering_0': t0.ordering_0, 'region_key': t0.region_key}) + PROJECT(columns={'comment': comment, 'key': key, 'name': name, 'ordering_0': name, 'region_key': region_key}) + SCAN(table=tpch.NATION, columns={'comment': n_comment, 'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/ordering_name_overload.txt b/tests/test_plan_refsols/ordering_name_overload.txt new file mode 100644 index 00000000..7cd90db5 --- /dev/null +++ b/tests/test_plan_refsols/ordering_name_overload.txt @@ -0,0 +1,5 @@ +ROOT(columns=[('ordering_0', ordering_0_0), ('ordering_1', ordering_1_0), ('ordering_2', ordering_2_0), ('ordering_3', ordering_3_0), ('ordering_4', ordering_4_0), ('ordering_5', ordering_5_0), ('ordering_6', ordering_6), ('ordering_7', ordering_7), ('ordering_8', ordering_8)], orderings=[(ordering_3):asc_last, (ordering_4):desc_last, (ordering_5):asc_first]) + PROJECT(columns={'ordering_0_0': ordering_2, 'ordering_1_0': ordering_0, 'ordering_2_0': ordering_1, 'ordering_3': 
ordering_3, 'ordering_3_0': ordering_2, 'ordering_4': ordering_4, 'ordering_4_0': ordering_1, 'ordering_5': ordering_5, 'ordering_5_0': ordering_0, 'ordering_6': LOWER(name), 'ordering_7': ABS(key), 'ordering_8': LENGTH(comment)}) + PROJECT(columns={'comment': comment, 'key': key, 'name': name, 'ordering_0': ordering_0, 'ordering_1': ordering_1, 'ordering_2': ordering_2, 'ordering_3': LOWER(name), 'ordering_4': ABS(key), 'ordering_5': LENGTH(comment)}) + PROJECT(columns={'comment': comment, 'key': key, 'name': name, 'ordering_0': name, 'ordering_1': key, 'ordering_2': comment}) + SCAN(table=tpch.NATION, columns={'comment': n_comment, 'key': n_nationkey, 'name': n_name}) diff --git a/tests/test_plan_refsols/orders_sum_line_price.txt b/tests/test_plan_refsols/orders_sum_line_price.txt new file mode 100644 index 00000000..3e45eeb7 --- /dev/null +++ b/tests/test_plan_refsols/orders_sum_line_price.txt @@ -0,0 +1,6 @@ +ROOT(columns=[('okey', okey), ('lsum', lsum)], orderings=[]) + PROJECT(columns={'lsum': DEFAULT_TO(agg_0, 0:int64), 'okey': key}) + JOIN(conditions=[t0.key == t1.order_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key}) + SCAN(table=tpch.ORDERS, columns={'key': o_orderkey}) + AGGREGATE(keys={'order_key': order_key}, aggregations={'agg_0': SUM(extended_price)}) + SCAN(table=tpch.LINEITEM, columns={'extended_price': l_extendedprice, 'order_key': l_orderkey}) diff --git a/tests/test_plan_refsols/orders_sum_vs_count_line_price.txt b/tests/test_plan_refsols/orders_sum_vs_count_line_price.txt new file mode 100644 index 00000000..568339ab --- /dev/null +++ b/tests/test_plan_refsols/orders_sum_vs_count_line_price.txt @@ -0,0 +1,6 @@ +ROOT(columns=[('okey', okey), ('lavg', lavg)], orderings=[]) + PROJECT(columns={'lavg': DEFAULT_TO(agg_0, 0:int64) / DEFAULT_TO(agg_1, 0:int64), 'okey': key}) + JOIN(conditions=[t0.key == t1.order_key], types=['left'], columns={'agg_0': t1.agg_0, 'agg_1': t1.agg_1, 'key': t0.key}) + SCAN(table=tpch.ORDERS, 
columns={'key': o_orderkey}) + AGGREGATE(keys={'order_key': order_key}, aggregations={'agg_0': SUM(extended_price), 'agg_1': COUNT(extended_price)}) + SCAN(table=tpch.LINEITEM, columns={'extended_price': l_extendedprice, 'order_key': l_orderkey}) diff --git a/tests/test_plan_refsols/rank_customers.txt b/tests/test_plan_refsols/rank_customers.txt new file mode 100644 index 00000000..1c7f3870 --- /dev/null +++ b/tests/test_plan_refsols/rank_customers.txt @@ -0,0 +1,3 @@ +ROOT(columns=[('name', name), ('cust_rank', cust_rank)], orderings=[]) + PROJECT(columns={'cust_rank': RANKING(args=[], partition=[], order=[(acctbal):desc_first]), 'name': name}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'name': c_name}) diff --git a/tests/test_plan_refsols/rank_customers_per_nation.txt b/tests/test_plan_refsols/rank_customers_per_nation.txt new file mode 100644 index 00000000..9a3b11b3 --- /dev/null +++ b/tests/test_plan_refsols/rank_customers_per_nation.txt @@ -0,0 +1,5 @@ +ROOT(columns=[('nation_name', nation_name), ('name', name), ('cust_rank', cust_rank)], orderings=[]) + PROJECT(columns={'cust_rank': RANKING(args=[], partition=[key], order=[(acctbal):desc_first], allow_ties=True), 'name': name_3, 'nation_name': name}) + JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'acctbal': t1.acctbal, 'key': t0.key, 'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'name': c_name, 'nation_key': c_nationkey}) diff --git a/tests/test_plan_refsols/rank_customers_per_region.txt b/tests/test_plan_refsols/rank_customers_per_region.txt new file mode 100644 index 00000000..376ebb8b --- /dev/null +++ b/tests/test_plan_refsols/rank_customers_per_region.txt @@ -0,0 +1,7 @@ +ROOT(columns=[('nation_name', nation_name), ('name', name), ('cust_rank', cust_rank)], orderings=[]) + PROJECT(columns={'cust_rank': RANKING(args=[], partition=[key], 
order=[(acctbal):desc_first], allow_ties=True, dense=True), 'name': name_6, 'nation_name': name_3}) + JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'acctbal': t1.acctbal, 'key': t0.key, 'name_3': t0.name_3, 'name_6': t1.name}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name_3': t1.name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'name': c_name, 'nation_key': c_nationkey}) diff --git a/tests/test_plan_refsols/region_nations_backref.txt b/tests/test_plan_refsols/region_nations_backref.txt new file mode 100644 index 00000000..03f35148 --- /dev/null +++ b/tests/test_plan_refsols/region_nations_backref.txt @@ -0,0 +1,5 @@ +ROOT(columns=[('region_name', region_name), ('nation_name', nation_name)], orderings=[]) + PROJECT(columns={'nation_name': name_3, 'region_name': name}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) diff --git a/tests/test_plan_refsols/regions_sum_line_price.txt b/tests/test_plan_refsols/regions_sum_line_price.txt new file mode 100644 index 00000000..36197758 --- /dev/null +++ b/tests/test_plan_refsols/regions_sum_line_price.txt @@ -0,0 +1,12 @@ +ROOT(columns=[('okey', okey), ('lsum', lsum)], orderings=[]) + PROJECT(columns={'lsum': DEFAULT_TO(agg_0, 0:int64), 'okey': key}) + JOIN(conditions=[t0.key == t1.region_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey}) + AGGREGATE(keys={'region_key': region_key}, aggregations={'agg_0': SUM(extended_price)}) + JOIN(conditions=[t0.key_5 == t1.order_key], types=['inner'], 
columns={'extended_price': t1.extended_price, 'region_key': t0.region_key}) + JOIN(conditions=[t0.key_2 == t1.customer_key], types=['inner'], columns={'key_5': t1.key, 'region_key': t0.region_key}) + JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'key_2': t1.key, 'region_key': t0.region_key}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey}) + SCAN(table=tpch.LINEITEM, columns={'extended_price': l_extendedprice, 'order_key': l_orderkey}) diff --git a/tests/test_plan_refsols/replace_order_by.txt b/tests/test_plan_refsols/replace_order_by.txt new file mode 100644 index 00000000..2d5ffed8 --- /dev/null +++ b/tests/test_plan_refsols/replace_order_by.txt @@ -0,0 +1,6 @@ +ROOT(columns=[('region_name', region_name), ('nation_name', nation_name)], orderings=[(ordering_1):desc_last]) + PROJECT(columns={'nation_name': nation_name, 'ordering_1': region_name, 'region_name': region_name}) + PROJECT(columns={'nation_name': name_3, 'region_name': name}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) diff --git a/tests/test_plan_refsols/scan_calc.txt b/tests/test_plan_refsols/scan_calc.txt new file mode 100644 index 00000000..01c1ad22 --- /dev/null +++ b/tests/test_plan_refsols/scan_calc.txt @@ -0,0 +1,3 @@ +ROOT(columns=[('region_name', region_name), ('magic_word', magic_word)], orderings=[]) + PROJECT(columns={'magic_word': 'foo':string, 'region_name': name}) + SCAN(table=tpch.REGION, columns={'name': r_name}) diff --git a/tests/test_plan_refsols/scan_calc_calc.txt b/tests/test_plan_refsols/scan_calc_calc.txt new file mode 100644 index 00000000..f6094a35 
--- /dev/null +++ b/tests/test_plan_refsols/scan_calc_calc.txt @@ -0,0 +1,4 @@ +ROOT(columns=[('fizz', fizz), ('buzz', buzz)], orderings=[]) + PROJECT(columns={'buzz': key, 'fizz': name_0}) + PROJECT(columns={'key': key, 'name_0': 'foo':string}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey}) diff --git a/tests/test_plan_refsols/scan_customer_call_functions.txt b/tests/test_plan_refsols/scan_customer_call_functions.txt new file mode 100644 index 00000000..56c3fc4b --- /dev/null +++ b/tests/test_plan_refsols/scan_customer_call_functions.txt @@ -0,0 +1,3 @@ +ROOT(columns=[('name', name_0), ('country_code', country_code), ('adjusted_account_balance', adjusted_account_balance), ('is_named_john', is_named_john)], orderings=[]) + PROJECT(columns={'adjusted_account_balance': IFF(acctbal < 0:int64, 0:int64, acctbal), 'country_code': SLICE(phone, 0:int64, 3:int64, 1:int64), 'is_named_john': LOWER(name) < 'john':string, 'name_0': LOWER(name)}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'name': c_name, 'phone': c_phone}) diff --git a/tests/test_plan_refsols/scan_nations.txt b/tests/test_plan_refsols/scan_nations.txt new file mode 100644 index 00000000..6e48c299 --- /dev/null +++ b/tests/test_plan_refsols/scan_nations.txt @@ -0,0 +1,2 @@ +ROOT(columns=[('key', key), ('name', name), ('region_key', region_key), ('comment', comment)], orderings=[]) + SCAN(table=tpch.NATION, columns={'comment': n_comment, 'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) diff --git a/tests/test_plan_refsols/scan_regions.txt b/tests/test_plan_refsols/scan_regions.txt new file mode 100644 index 00000000..5061e9d3 --- /dev/null +++ b/tests/test_plan_refsols/scan_regions.txt @@ -0,0 +1,2 @@ +ROOT(columns=[('key', key), ('name', name), ('comment', comment)], orderings=[]) + SCAN(table=tpch.REGION, columns={'comment': r_comment, 'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/semi_aggregate.txt b/tests/test_plan_refsols/semi_aggregate.txt 
new file mode 100644 index 00000000..4f557dfe --- /dev/null +++ b/tests/test_plan_refsols/semi_aggregate.txt @@ -0,0 +1,10 @@ +ROOT(columns=[('name', name), ('num_10parts', num_10parts), ('avg_price_of_10parts', avg_price_of_10parts), ('sum_price_of_10parts', sum_price_of_10parts)], orderings=[]) + PROJECT(columns={'avg_price_of_10parts': agg_0, 'name': name, 'num_10parts': DEFAULT_TO(agg_1, 0:int64), 'sum_price_of_10parts': DEFAULT_TO(agg_2, 0:int64)}) + FILTER(condition=True:bool, columns={'agg_0': agg_0, 'agg_1': agg_1, 'agg_2': agg_2, 'name': name}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['inner'], columns={'agg_0': t1.agg_0, 'agg_1': t1.agg_1, 'agg_2': t1.agg_2, 'name': t0.name}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': AVG(retail_price), 'agg_1': COUNT(), 'agg_2': SUM(retail_price)}) + FILTER(condition=size == 10:int64, columns={'retail_price': retail_price, 'supplier_key': supplier_key}) + JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'retail_price': t1.retail_price, 'size': t1.size, 'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.PART, columns={'key': p_partkey, 'retail_price': p_retailprice, 'size': p_size}) diff --git a/tests/test_plan_refsols/semi_singular.txt b/tests/test_plan_refsols/semi_singular.txt new file mode 100644 index 00000000..2cf0ef14 --- /dev/null +++ b/tests/test_plan_refsols/semi_singular.txt @@ -0,0 +1,7 @@ +ROOT(columns=[('name', name), ('region_name', region_name)], orderings=[]) + PROJECT(columns={'name': name, 'region_name': name_3}) + FILTER(condition=True:bool, columns={'name': name, 'name_3': name_3}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) + 
FILTER(condition=name != 'ASIA':string, columns={'key': key, 'name': name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/simple_anti_1.txt b/tests/test_plan_refsols/simple_anti_1.txt new file mode 100644 index 00000000..906635f6 --- /dev/null +++ b/tests/test_plan_refsols/simple_anti_1.txt @@ -0,0 +1,5 @@ +ROOT(columns=[('name', name)], orderings=[]) + FILTER(condition=True:bool, columns={'name': name}) + JOIN(conditions=[t0.key == t1.customer_key], types=['anti'], columns={'name': t0.name}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'name': c_name}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey}) diff --git a/tests/test_plan_refsols/simple_anti_2.txt b/tests/test_plan_refsols/simple_anti_2.txt new file mode 100644 index 00000000..14624316 --- /dev/null +++ b/tests/test_plan_refsols/simple_anti_2.txt @@ -0,0 +1,8 @@ +ROOT(columns=[('name', name)], orderings=[]) + FILTER(condition=True:bool, columns={'name': name}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['anti'], columns={'name': t0.name}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name}) + FILTER(condition=size < 10:int64, columns={'supplier_key': supplier_key}) + JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'size': t1.size, 'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.PART, columns={'key': p_partkey, 'size': p_size}) diff --git a/tests/test_plan_refsols/simple_order_by.txt b/tests/test_plan_refsols/simple_order_by.txt new file mode 100644 index 00000000..d7847177 --- /dev/null +++ b/tests/test_plan_refsols/simple_order_by.txt @@ -0,0 +1,3 @@ +ROOT(columns=[('key', key), ('name', name), ('comment', comment)], orderings=[(ordering_0):asc_last]) + PROJECT(columns={'comment': comment, 'key': key, 'name': name, 'ordering_0': name}) + SCAN(table=tpch.REGION, columns={'comment': 
r_comment, 'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/simple_semi_1.txt b/tests/test_plan_refsols/simple_semi_1.txt new file mode 100644 index 00000000..e76d6ee0 --- /dev/null +++ b/tests/test_plan_refsols/simple_semi_1.txt @@ -0,0 +1,5 @@ +ROOT(columns=[('name', name)], orderings=[]) + FILTER(condition=True:bool, columns={'name': name}) + JOIN(conditions=[t0.key == t1.customer_key], types=['semi'], columns={'name': t0.name}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'name': c_name}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey}) diff --git a/tests/test_plan_refsols/simple_semi_2.txt b/tests/test_plan_refsols/simple_semi_2.txt new file mode 100644 index 00000000..dacbfea1 --- /dev/null +++ b/tests/test_plan_refsols/simple_semi_2.txt @@ -0,0 +1,8 @@ +ROOT(columns=[('name', name)], orderings=[]) + FILTER(condition=True:bool, columns={'name': name}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['semi'], columns={'name': t0.name}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name}) + FILTER(condition=size < 10:int64, columns={'supplier_key': supplier_key}) + JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'size': t1.size, 'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.PART, columns={'key': p_partkey, 'size': p_size}) diff --git a/tests/test_plan_refsols/simple_topk.txt b/tests/test_plan_refsols/simple_topk.txt new file mode 100644 index 00000000..63e60244 --- /dev/null +++ b/tests/test_plan_refsols/simple_topk.txt @@ -0,0 +1,4 @@ +ROOT(columns=[('key', key), ('name', name), ('comment', comment)], orderings=[(ordering_0):asc_last]) + LIMIT(limit=Literal(value=2, type=Int64Type()), columns={'comment': comment, 'key': key, 'name': name, 'ordering_0': ordering_0}, orderings=[(ordering_0):asc_last]) + PROJECT(columns={'comment': comment, 'key': key, 'name': name, 
'ordering_0': name}) + SCAN(table=tpch.REGION, columns={'comment': r_comment, 'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/singular_anti.txt b/tests/test_plan_refsols/singular_anti.txt new file mode 100644 index 00000000..2beb3726 --- /dev/null +++ b/tests/test_plan_refsols/singular_anti.txt @@ -0,0 +1,7 @@ +ROOT(columns=[('name', name), ('region_name', region_name)], orderings=[]) + FILTER(condition=True:bool, columns={'name': name, 'region_name': region_name}) + PROJECT(columns={'name': name, 'region_name': NULL_1}) + JOIN(conditions=[t0.region_key == t1.key], types=['anti'], columns={'NULL_1': None:unknown, 'name': t0.name}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) + FILTER(condition=name != 'ASIA':string, columns={'key': key}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/singular_semi.txt b/tests/test_plan_refsols/singular_semi.txt new file mode 100644 index 00000000..d09452bf --- /dev/null +++ b/tests/test_plan_refsols/singular_semi.txt @@ -0,0 +1,7 @@ +ROOT(columns=[('name', name), ('region_name', region_name)], orderings=[]) + FILTER(condition=True:bool, columns={'name': name, 'region_name': region_name}) + PROJECT(columns={'name': name, 'region_name': name_3}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) + FILTER(condition=name != 'ASIA':string, columns={'key': key, 'name': name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/top_5_nations_balance_by_num_suppliers.txt b/tests/test_plan_refsols/top_5_nations_balance_by_num_suppliers.txt new file mode 100644 index 00000000..aee0c07d --- /dev/null +++ b/tests/test_plan_refsols/top_5_nations_balance_by_num_suppliers.txt @@ -0,0 +1,8 @@ +ROOT(columns=[('name', name), ('total_bal', 
total_bal)], orderings=[(ordering_0):asc_last]) + PROJECT(columns={'name': name, 'ordering_0': ordering_0, 'total_bal': DEFAULT_TO(agg_2, 0:int64)}) + LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'agg_2': agg_2, 'name': name, 'ordering_0': ordering_0}, orderings=[(ordering_0):asc_last]) + PROJECT(columns={'agg_2': agg_2, 'name': name, 'ordering_0': DEFAULT_TO(agg_1, 0:int64)}) + JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_1': t1.agg_1, 'agg_2': t1.agg_2, 'name': t0.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_1': COUNT(), 'agg_2': SUM(account_balance)}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'nation_key': s_nationkey}) diff --git a/tests/test_plan_refsols/top_5_nations_by_num_supplierss.txt b/tests/test_plan_refsols/top_5_nations_by_num_supplierss.txt new file mode 100644 index 00000000..c16be8c6 --- /dev/null +++ b/tests/test_plan_refsols/top_5_nations_by_num_supplierss.txt @@ -0,0 +1,7 @@ +ROOT(columns=[('key', key), ('name', name), ('region_key', region_key), ('comment', comment)], orderings=[(ordering_0):asc_last]) + LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'comment': comment, 'key': key, 'name': name, 'ordering_0': ordering_0, 'region_key': region_key}, orderings=[(ordering_0):asc_last]) + PROJECT(columns={'comment': comment, 'key': key, 'name': name, 'ordering_0': DEFAULT_TO(agg_1, 0:int64), 'region_key': region_key}) + JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_1': t1.agg_1, 'comment': t0.comment, 'key': t0.key, 'name': t0.name, 'region_key': t0.region_key}) + SCAN(table=tpch.NATION, columns={'comment': n_comment, 'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_1': COUNT(key)}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) diff --git 
a/tests/test_plan_refsols/topk_order_by.txt b/tests/test_plan_refsols/topk_order_by.txt new file mode 100644 index 00000000..748685b7 --- /dev/null +++ b/tests/test_plan_refsols/topk_order_by.txt @@ -0,0 +1,4 @@ +ROOT(columns=[('key', key), ('name', name), ('comment', comment)], orderings=[(ordering_1):asc_last]) + LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'comment': comment, 'key': key, 'name': name, 'ordering_1': ordering_1}, orderings=[(ordering_1):asc_last]) + PROJECT(columns={'comment': comment, 'key': key, 'name': name, 'ordering_1': name}) + SCAN(table=tpch.REGION, columns={'comment': r_comment, 'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/topk_order_by_calc.txt b/tests/test_plan_refsols/topk_order_by_calc.txt new file mode 100644 index 00000000..a8924b80 --- /dev/null +++ b/tests/test_plan_refsols/topk_order_by_calc.txt @@ -0,0 +1,5 @@ +ROOT(columns=[('region_name', region_name), ('name_length', name_length)], orderings=[(ordering_1):asc_last]) + PROJECT(columns={'name_length': LENGTH(name), 'ordering_1': ordering_1, 'region_name': name}) + LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'name': name, 'ordering_1': ordering_1}, orderings=[(ordering_1):asc_last]) + PROJECT(columns={'name': name, 'ordering_1': name}) + SCAN(table=tpch.REGION, columns={'name': r_name}) diff --git a/tests/test_plan_refsols/topk_replace_order_by.txt b/tests/test_plan_refsols/topk_replace_order_by.txt new file mode 100644 index 00000000..14127dea --- /dev/null +++ b/tests/test_plan_refsols/topk_replace_order_by.txt @@ -0,0 +1,4 @@ +ROOT(columns=[('key', key), ('name', name), ('comment', comment)], orderings=[(ordering_2):desc_first]) + LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'comment': comment, 'key': key, 'name': name, 'ordering_2': ordering_2}, orderings=[(ordering_2):desc_first]) + PROJECT(columns={'comment': comment, 'key': key, 'name': name, 'ordering_2': name}) + SCAN(table=tpch.REGION, 
columns={'comment': r_comment, 'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/topk_root_different_order_by.txt b/tests/test_plan_refsols/topk_root_different_order_by.txt new file mode 100644 index 00000000..b67a95a4 --- /dev/null +++ b/tests/test_plan_refsols/topk_root_different_order_by.txt @@ -0,0 +1,5 @@ +ROOT(columns=[('key', key), ('name', name), ('comment', comment)], orderings=[(ordering_2):desc_first]) + PROJECT(columns={'comment': comment, 'key': key, 'name': name, 'ordering_2': name}) + LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'comment': comment, 'key': key, 'name': name}, orderings=[(ordering_1):asc_first]) + PROJECT(columns={'comment': comment, 'key': key, 'name': name, 'ordering_1': name}) + SCAN(table=tpch.REGION, columns={'comment': r_comment, 'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/various_aggfuncs_global.txt b/tests/test_plan_refsols/various_aggfuncs_global.txt new file mode 100644 index 00000000..1f4f6c04 --- /dev/null +++ b/tests/test_plan_refsols/various_aggfuncs_global.txt @@ -0,0 +1,4 @@ +ROOT(columns=[('total_bal', total_bal), ('num_bal', num_bal), ('avg_bal', avg_bal), ('min_bal', min_bal), ('max_bal', max_bal), ('num_cust', num_cust)], orderings=[]) + PROJECT(columns={'avg_bal': DEFAULT_TO(agg_0, 0:int64), 'max_bal': agg_1, 'min_bal': agg_2, 'num_bal': agg_3, 'num_cust': agg_4, 'total_bal': agg_5}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(acctbal), 'agg_1': MAX(acctbal), 'agg_2': MIN(acctbal), 'agg_3': COUNT(acctbal), 'agg_4': COUNT(), 'agg_5': SUM(acctbal)}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal}) diff --git a/tests/test_plan_refsols/various_aggfuncs_simple.txt b/tests/test_plan_refsols/various_aggfuncs_simple.txt new file mode 100644 index 00000000..7ecb6930 --- /dev/null +++ b/tests/test_plan_refsols/various_aggfuncs_simple.txt @@ -0,0 +1,6 @@ +ROOT(columns=[('nation_name', nation_name), ('total_bal', total_bal), ('num_bal', num_bal), 
('avg_bal', avg_bal), ('min_bal', min_bal), ('max_bal', max_bal), ('num_cust', num_cust)], orderings=[]) + PROJECT(columns={'avg_bal': DEFAULT_TO(agg_0, 0:int64), 'max_bal': agg_1, 'min_bal': agg_2, 'nation_name': name, 'num_bal': DEFAULT_TO(agg_3, 0:int64), 'num_cust': DEFAULT_TO(agg_4, 0:int64), 'total_bal': agg_5}) + JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'agg_1': t1.agg_1, 'agg_2': t1.agg_2, 'agg_3': t1.agg_3, 'agg_4': t1.agg_4, 'agg_5': t1.agg_5, 'name': t0.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': AVG(acctbal), 'agg_1': MAX(acctbal), 'agg_2': MIN(acctbal), 'agg_3': COUNT(acctbal), 'agg_4': COUNT(), 'agg_5': SUM(acctbal)}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': c_nationkey}) diff --git a/tests/test_qdag_conversion.py b/tests/test_qdag_conversion.py index ac126436..27fc6f8a 100644 --- a/tests/test_qdag_conversion.py +++ b/tests/test_qdag_conversion.py @@ -3,6 +3,8 @@ relational tree. 
""" +from collections.abc import Callable + import pytest from test_utils import ( BackReferenceExpressionInfo, @@ -38,20 +40,14 @@ pytest.param( ( TableCollectionInfo("Regions"), - """ -ROOT(columns=[('key', key), ('name', name), ('comment', comment)], orderings=[]) - SCAN(table=tpch.REGION, columns={'comment': r_comment, 'key': r_regionkey, 'name': r_name}) -""", + "scan_regions", ), id="scan_regions", ), pytest.param( ( TableCollectionInfo("Nations"), - """ -ROOT(columns=[('key', key), ('name', name), ('region_key', region_key), ('comment', comment)], orderings=[]) - SCAN(table=tpch.NATION, columns={'comment': n_comment, 'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) -""", + "scan_nations", ), id="scan_nations", ), @@ -63,11 +59,7 @@ region_name=ReferenceInfo("name"), magic_word=LiteralInfo("foo", StringType()), ), - """ -ROOT(columns=[('region_name', region_name), ('magic_word', magic_word)], orderings=[]) - PROJECT(columns={'magic_word': 'foo':string, 'region_name': name}) - SCAN(table=tpch.REGION, columns={'name': r_name}) -""", + "scan_calc", ), id="scan_calc", ), @@ -76,25 +68,14 @@ TableCollectionInfo("Regions") ** CalcInfo([], name=LiteralInfo("foo", StringType())) ** CalcInfo([], fizz=ReferenceInfo("name"), buzz=ReferenceInfo("key")), - """ -ROOT(columns=[('fizz', fizz), ('buzz', buzz)], orderings=[]) - PROJECT(columns={'buzz': key, 'fizz': name_0}) - PROJECT(columns={'key': key, 'name_0': 'foo':string}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey}) -""", + "scan_calc_calc", ), id="scan_calc_calc", ), pytest.param( ( TableCollectionInfo("Regions") ** SubCollectionInfo("nations"), - """ -ROOT(columns=[('key', key), ('name', name), ('region_key', region_key), ('comment', comment)], orderings=[]) - PROJECT(columns={'comment': comment_1, 'key': key_2, 'name': name_3, 'region_key': region_key}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'comment_1': t1.comment, 'key_2': t1.key, 'name_3': t1.name, 
'region_key': t1.region_key}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey}) - SCAN(table=tpch.NATION, columns={'comment': n_comment, 'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) -""", + "join_region_nations", ), id="join_region_nations", ), @@ -103,15 +84,7 @@ TableCollectionInfo("Regions") ** SubCollectionInfo("nations") ** SubCollectionInfo("customers"), - """ -ROOT(columns=[('key', key), ('name', name), ('address', address), ('nation_key', nation_key), ('phone', phone), ('acctbal', acctbal), ('mktsegment', mktsegment), ('comment', comment)], orderings=[]) - PROJECT(columns={'acctbal': acctbal, 'address': address, 'comment': comment_4, 'key': key_5, 'mktsegment': mktsegment, 'name': name_6, 'nation_key': nation_key, 'phone': phone}) - JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'acctbal': t1.acctbal, 'address': t1.address, 'comment_4': t1.comment, 'key_5': t1.key, 'mktsegment': t1.mktsegment, 'name_6': t1.name, 'nation_key': t1.nation_key, 'phone': t1.phone}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': t1.key}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'address': c_address, 'comment': c_comment, 'key': c_custkey, 'mktsegment': c_mktsegment, 'name': c_name, 'nation_key': c_nationkey, 'phone': c_phone}) -""", + "join_region_nations_customers", ), id="join_region_nations_customers", ), @@ -152,11 +125,7 @@ ], ), ), - """ -ROOT(columns=[('name', name_0), ('country_code', country_code), ('adjusted_account_balance', adjusted_account_balance), ('is_named_john', is_named_john)], orderings=[]) - PROJECT(columns={'adjusted_account_balance': IFF(acctbal < 0:int64, 0:int64, acctbal), 'country_code': SLICE(phone, 0:int64, 3:int64, 1:int64), 'is_named_john': LOWER(name) < 'john':string, 'name_0': LOWER(name)}) - 
SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'name': c_name, 'phone': c_phone}) -""", + "scan_customer_call_functions", ), id="scan_customer_call_functions", ), @@ -168,13 +137,7 @@ nation_name=ReferenceInfo("name"), region_name=ChildReferenceExpressionInfo("name", 0), ), - """ -ROOT(columns=[('nation_name', nation_name), ('region_name', region_name)], orderings=[]) - PROJECT(columns={'nation_name': name, 'region_name': name_3}) - JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'name': t0.name, 'name_3': t1.name}) - SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) -""", + "nations_access_region", ), id="nations_access_region", ), @@ -207,23 +170,7 @@ ], ), ), - """ -ROOT(columns=[('ship_year', ship_year), ('supplier_nation', supplier_nation), ('customer_nation', customer_nation), ('value', value)], orderings=[]) - PROJECT(columns={'customer_nation': name_9, 'ship_year': YEAR(ship_date), 'supplier_nation': name_4, 'value': extended_price * 1.0:float64 - discount}) - JOIN(conditions=[t0.order_key == t1.key], types=['left'], columns={'discount': t0.discount, 'extended_price': t0.extended_price, 'name_4': t0.name_4, 'name_9': t1.name_9, 'ship_date': t0.ship_date}) - JOIN(conditions=[t0.part_key == t1.part_key & t0.supplier_key == t1.supplier_key], types=['left'], columns={'discount': t0.discount, 'extended_price': t0.extended_price, 'name_4': t1.name_4, 'order_key': t0.order_key, 'ship_date': t0.ship_date}) - SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'order_key': l_orderkey, 'part_key': l_partkey, 'ship_date': l_shipdate, 'supplier_key': l_suppkey}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'name_4': t1.name, 'part_key': t0.part_key, 'supplier_key': t0.supplier_key}) - JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'nation_key': 
t1.nation_key, 'part_key': t0.part_key, 'supplier_key': t0.supplier_key}) - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_9': t1.name}) - JOIN(conditions=[t0.customer_key == t1.key], types=['inner'], columns={'key': t0.key, 'nation_key': t1.nation_key}) - SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey}) - SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) -""", + "lineitems_access_cust_supplier_nations", ), id="lineitems_access_cust_supplier_nations", ), @@ -241,15 +188,7 @@ phone=ReferenceInfo("phone"), mktsegment=ReferenceInfo("mktsegment"), ), - """ -ROOT(columns=[('key', key_0), ('name', name), ('phone', phone), ('mktsegment', mktsegment)], orderings=[]) - PROJECT(columns={'key_0': -3:int64, 'mktsegment': mktsegment, 'name': name_6, 'phone': phone}) - JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'mktsegment': t1.mktsegment, 'name_6': t1.name, 'phone': t1.phone}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': t1.key}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - SCAN(table=tpch.CUSTOMER, columns={'mktsegment': c_mktsegment, 'name': c_name, 'nation_key': c_nationkey, 'phone': c_phone}) -""", + "join_regions_nations_calc_override", ), id="join_regions_nations_calc_override", ), @@ -262,13 +201,7 @@ region_name=BackReferenceExpressionInfo("name", 1), nation_name=ReferenceInfo("name"), ), - """ -ROOT(columns=[('region_name', region_name), ('nation_name', nation_name)], orderings=[]) - 
PROJECT(columns={'nation_name': name_3, 'region_name': name}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'name': t0.name, 'name_3': t1.name}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) -""", + "region_nations_backref", ), id="region_nations_backref", ), @@ -297,28 +230,7 @@ supplier_region=ChildReferenceExpressionInfo("name", 0), nation_name=ChildReferenceExpressionInfo("nation_name", 0), ), - """ -ROOT(columns=[('order_year', order_year), ('customer_region', customer_region), ('customer_nation', customer_nation), ('supplier_region', supplier_region), ('nation_name', nation_name)], orderings=[]) - PROJECT(columns={'customer_nation': name_3, 'customer_region': name, 'nation_name': nation_name, 'order_year': YEAR(order_date), 'supplier_region': name_16}) - JOIN(conditions=[t0.part_key == t1.part_key & t0.supplier_key == t1.supplier_key], types=['left'], columns={'name': t0.name, 'name_16': t1.name_16, 'name_3': t0.name_3, 'nation_name': t1.nation_name, 'order_date': t0.order_date}) - JOIN(conditions=[t0.key_8 == t1.order_key], types=['inner'], columns={'name': t0.name, 'name_3': t0.name_3, 'order_date': t0.order_date, 'part_key': t1.part_key, 'supplier_key': t1.supplier_key}) - JOIN(conditions=[t0.key_5 == t1.customer_key], types=['inner'], columns={'key_8': t1.key, 'name': t0.name, 'name_3': t0.name_3, 'order_date': t1.order_date}) - JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'key_5': t1.key, 'name': t0.name, 'name_3': t0.name_3}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name, 'name_3': t1.name}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) - SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) - 
SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate}) - SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'part_key': l_partkey, 'supplier_key': l_suppkey}) - PROJECT(columns={'name_16': name_16, 'nation_name': name_13, 'part_key': part_key, 'supplier_key': supplier_key}) - JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'name_13': t0.name_13, 'name_16': t1.name, 'part_key': t0.part_key, 'supplier_key': t0.supplier_key}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'name_13': t1.name, 'part_key': t0.part_key, 'region_key': t1.region_key, 'supplier_key': t0.supplier_key}) - JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'nation_key': t1.nation_key, 'part_key': t0.part_key, 'supplier_key': t0.supplier_key}) - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) -""", + "lines_shipping_vs_customer_region", ), id="lines_shipping_vs_customer_region", ), @@ -332,14 +244,7 @@ "SUM", [ChildReferenceExpressionInfo("extended_price", 0)] ), ), - """ -ROOT(columns=[('okey', okey), ('lsum', lsum)], orderings=[]) - PROJECT(columns={'lsum': DEFAULT_TO(agg_0, 0:int64), 'okey': key}) - JOIN(conditions=[t0.key == t1.order_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key}) - SCAN(table=tpch.ORDERS, columns={'key': o_orderkey}) - AGGREGATE(keys={'order_key': order_key}, aggregations={'agg_0': SUM(extended_price)}) - SCAN(table=tpch.LINEITEM, columns={'extended_price': l_extendedprice, 'order_key': l_orderkey}) -""", + "orders_sum_line_price", ), id="orders_sum_line_price", ), @@ -353,16 +258,7 @@ "SUM", [ChildReferenceExpressionInfo("extended_price", 0)] 
), ), - """ -ROOT(columns=[('okey', okey), ('lsum', lsum)], orderings=[]) - PROJECT(columns={'lsum': DEFAULT_TO(agg_0, 0:int64), 'okey': key}) - JOIN(conditions=[t0.key == t1.customer_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key}) - SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey}) - AGGREGATE(keys={'customer_key': customer_key}, aggregations={'agg_0': SUM(extended_price)}) - JOIN(conditions=[t0.key == t1.order_key], types=['inner'], columns={'customer_key': t0.customer_key, 'extended_price': t1.extended_price}) - SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey}) - SCAN(table=tpch.LINEITEM, columns={'extended_price': l_extendedprice, 'order_key': l_orderkey}) -""", + "customers_sum_line_price", ), id="customers_sum_line_price", ), @@ -380,18 +276,7 @@ "SUM", [ChildReferenceExpressionInfo("extended_price", 0)] ), ), - """ -ROOT(columns=[('okey', okey), ('lsum', lsum)], orderings=[]) - PROJECT(columns={'lsum': DEFAULT_TO(agg_0, 0:int64), 'okey': key}) - JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': SUM(extended_price)}) - JOIN(conditions=[t0.key_2 == t1.order_key], types=['inner'], columns={'extended_price': t1.extended_price, 'nation_key': t0.nation_key}) - JOIN(conditions=[t0.key == t1.customer_key], types=['inner'], columns={'key_2': t1.key, 'nation_key': t0.nation_key}) - SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) - SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey}) - SCAN(table=tpch.LINEITEM, columns={'extended_price': l_extendedprice, 'order_key': l_orderkey}) -""", + "nations_sum_line_price", ), id="nations_sum_line_price", ), @@ -410,20 +295,7 @@ "SUM", [ChildReferenceExpressionInfo("extended_price", 0)] ), ), - """ -ROOT(columns=[('okey', okey), ('lsum', lsum)], 
orderings=[]) - PROJECT(columns={'lsum': DEFAULT_TO(agg_0, 0:int64), 'okey': key}) - JOIN(conditions=[t0.key == t1.region_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey}) - AGGREGATE(keys={'region_key': region_key}, aggregations={'agg_0': SUM(extended_price)}) - JOIN(conditions=[t0.key_5 == t1.order_key], types=['inner'], columns={'extended_price': t1.extended_price, 'region_key': t0.region_key}) - JOIN(conditions=[t0.key_2 == t1.customer_key], types=['inner'], columns={'key_5': t1.key, 'region_key': t0.region_key}) - JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'key_2': t1.key, 'region_key': t0.region_key}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) - SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey}) - SCAN(table=tpch.LINEITEM, columns={'extended_price': l_extendedprice, 'order_key': l_orderkey}) -""", + "regions_sum_line_price", ), id="regions_sum_line_price", ), @@ -447,14 +319,7 @@ ], ), ), - """ -ROOT(columns=[('okey', okey), ('lavg', lavg)], orderings=[]) - PROJECT(columns={'lavg': DEFAULT_TO(agg_0, 0:int64) / DEFAULT_TO(agg_1, 0:int64), 'okey': key}) - JOIN(conditions=[t0.key == t1.order_key], types=['left'], columns={'agg_0': t1.agg_0, 'agg_1': t1.agg_1, 'key': t0.key}) - SCAN(table=tpch.ORDERS, columns={'key': o_orderkey}) - AGGREGATE(keys={'order_key': order_key}, aggregations={'agg_0': SUM(extended_price), 'agg_1': COUNT(extended_price)}) - SCAN(table=tpch.LINEITEM, columns={'extended_price': l_extendedprice, 'order_key': l_orderkey}) -""", + "orders_sum_vs_count_line_price", ), id="orders_sum_vs_count_line_price", ), @@ -474,17 +339,7 @@ "SUM", [ChildReferenceExpressionInfo("account_balance", 1)] ), ), - """ -ROOT(columns=[('nation_name', nation_name), ('consumer_value', consumer_value), ('producer_value', 
producer_value)], orderings=[]) - PROJECT(columns={'consumer_value': DEFAULT_TO(agg_0, 0:int64), 'nation_name': key, 'producer_value': DEFAULT_TO(agg_1, 0:int64)}) - JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t0.agg_0, 'agg_1': t1.agg_1, 'key': t0.key}) - JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': SUM(acctbal)}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': c_nationkey}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_1': SUM(account_balance)}) - SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'nation_key': s_nationkey}) -""", + "multiple_simple_aggregations_single_calc", ), id="multiple_simple_aggregations_single_calc", ), @@ -524,19 +379,7 @@ "MAX", [ChildReferenceExpressionInfo("account_balance", 0)] ), ), - """ -ROOT(columns=[('nation_name', nation_name_0), ('total_consumer_value', total_consumer_value), ('total_supplier_value', total_supplier_value), ('avg_consumer_value', avg_consumer_value), ('avg_supplier_value', avg_supplier_value), ('best_consumer_value', best_consumer_value), ('best_supplier_value', best_supplier_value)], orderings=[]) - PROJECT(columns={'avg_consumer_value': avg_consumer_value, 'avg_supplier_value': avg_supplier_value, 'best_consumer_value': agg_4, 'best_supplier_value': agg_5, 'nation_name_0': key, 'total_consumer_value': total_consumer_value, 'total_supplier_value': total_supplier_value}) - PROJECT(columns={'agg_4': agg_4, 'agg_5': agg_5, 'avg_consumer_value': avg_consumer_value, 'avg_supplier_value': agg_2, 'key': key, 'total_consumer_value': total_consumer_value, 'total_supplier_value': DEFAULT_TO(agg_3, 0:int64)}) - JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_2': t1.agg_2, 'agg_3': t1.agg_3, 'agg_4': t0.agg_4, 'agg_5': t1.agg_5, 
'avg_consumer_value': t0.avg_consumer_value, 'key': t0.key, 'total_consumer_value': t0.total_consumer_value}) - PROJECT(columns={'agg_4': agg_4, 'avg_consumer_value': agg_0, 'key': key, 'total_consumer_value': DEFAULT_TO(agg_1, 0:int64)}) - JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'agg_1': t1.agg_1, 'agg_4': t1.agg_4, 'key': t0.key}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': AVG(acctbal), 'agg_1': SUM(acctbal), 'agg_4': MAX(acctbal)}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': c_nationkey}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_2': AVG(account_balance), 'agg_3': SUM(account_balance), 'agg_5': MAX(account_balance)}) - SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'nation_key': s_nationkey}) -""", + "multiple_simple_aggregations_multiple_calcs", ), id="multiple_simple_aggregations_multiple_calcs", ), @@ -550,14 +393,7 @@ "COUNT", [ChildReferenceCollectionInfo(0)] ), ), - """ -ROOT(columns=[('nation_name', nation_name), ('num_customers', num_customers)], orderings=[]) - PROJECT(columns={'nation_name': key, 'num_customers': DEFAULT_TO(agg_0, 0:int64)}) - JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': COUNT()}) - SCAN(table=tpch.CUSTOMER, columns={'nation_key': c_nationkey}) -""", + "count_single_subcollection", ), id="count_single_subcollection", ), @@ -586,17 +422,7 @@ ], ), ), - """ -ROOT(columns=[('nation_name', nation_name), ('num_customers', num_customers), ('num_suppliers', num_suppliers), ('customer_to_supplier_wealth_ratio', customer_to_supplier_wealth_ratio)], orderings=[]) - PROJECT(columns={'customer_to_supplier_wealth_ratio': DEFAULT_TO(agg_0, 0:int64) / DEFAULT_TO(agg_1, 0:int64), 
'nation_name': key, 'num_customers': DEFAULT_TO(agg_2, 0:int64), 'num_suppliers': DEFAULT_TO(agg_3, 0:int64)}) - JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t0.agg_0, 'agg_1': t1.agg_1, 'agg_2': t0.agg_2, 'agg_3': t1.agg_3, 'key': t0.key}) - JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'agg_2': t1.agg_2, 'key': t0.key}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': SUM(acctbal), 'agg_2': COUNT()}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': c_nationkey}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_1': SUM(account_balance), 'agg_3': COUNT()}) - SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'nation_key': s_nationkey}) -""", + "count_multiple_subcollections_alongside_aggs", ), id="count_multiple_subcollections_alongside_aggs", ), @@ -628,14 +454,7 @@ ], ), ), - """ -ROOT(columns=[('nation_name', nation_name), ('avg_consumer_value', avg_consumer_value)], orderings=[]) - PROJECT(columns={'avg_consumer_value': agg_0, 'nation_name': key}) - JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': MAX(IFF(acctbal < 0.0:float64, 0.0:float64, acctbal))}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': c_nationkey}) -""", + "aggregate_on_function_call", ), id="aggregate_on_function_call", ), @@ -662,17 +481,7 @@ "MAX", [ChildReferenceExpressionInfo("ratio", 0)] ), ), - """ -ROOT(columns=[('order_key', order_key), ('max_ratio', max_ratio)], orderings=[]) - PROJECT(columns={'max_ratio': agg_0, 'order_key': key}) - JOIN(conditions=[t0.key == t1.order_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key}) - SCAN(table=tpch.ORDERS, columns={'key': o_orderkey}) - 
AGGREGATE(keys={'order_key': order_key}, aggregations={'agg_0': MAX(ratio)}) - PROJECT(columns={'order_key': order_key, 'ratio': quantity / availqty}) - JOIN(conditions=[t0.part_key == t1.part_key & t0.supplier_key == t1.supplier_key], types=['inner'], columns={'availqty': t1.availqty, 'order_key': t0.order_key, 'quantity': t0.quantity}) - SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'part_key': l_partkey, 'quantity': l_quantity, 'supplier_key': l_suppkey}) - SCAN(table=tpch.PARTSUPP, columns={'availqty': ps_availqty, 'part_key': ps_partkey, 'supplier_key': ps_suppkey}) -""", + "aggregate_mixed_levels_simple", ), id="aggregate_mixed_levels_simple", ), @@ -699,17 +508,7 @@ ], ), ), - """ -ROOT(columns=[('part_key', part_key), ('supplier_key', supplier_key), ('order_key', order_key), ('order_quantity_ratio', order_quantity_ratio)], orderings=[]) - PROJECT(columns={'order_key': order_key_2, 'order_quantity_ratio': quantity / total_quantity, 'part_key': part_key, 'supplier_key': supplier_key}) - JOIN(conditions=[t0.key == t1.order_key], types=['inner'], columns={'order_key_2': t1.order_key, 'part_key': t1.part_key, 'quantity': t1.quantity, 'supplier_key': t1.supplier_key, 'total_quantity': t0.total_quantity}) - PROJECT(columns={'key': key, 'total_quantity': DEFAULT_TO(agg_0, 0:int64)}) - JOIN(conditions=[t0.key == t1.order_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key}) - SCAN(table=tpch.ORDERS, columns={'key': o_orderkey}) - AGGREGATE(keys={'order_key': order_key}, aggregations={'agg_0': SUM(quantity)}) - SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'quantity': l_quantity}) - SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'part_key': l_partkey, 'quantity': l_quantity, 'supplier_key': l_suppkey}) -""", + "aggregate_then_backref", ), id="aggregate_then_backref", ), @@ -722,11 +521,7 @@ c=LiteralInfo(3.14, Float64Type()), d=LiteralInfo(True, BooleanType()), ), - """ -ROOT(columns=[('a', a), ('b', b), ('c', c), 
('d', d)], orderings=[]) - PROJECT(columns={'a': 0:int64, 'b': 'X':string, 'c': 3.14:float64, 'd': True:bool}) - EMPTYSINGLETON() -""", + "global_calc_simple", ), id="global_calc_simple", ), @@ -744,12 +539,7 @@ c=LiteralInfo(3.14, Float64Type()), d=LiteralInfo(True, BooleanType()), ), - """ -ROOT(columns=[('a', a), ('b', b), ('c', c), ('d', d)], orderings=[]) - PROJECT(columns={'a': a, 'b': b, 'c': 3.14:float64, 'd': True:bool}) - PROJECT(columns={'a': 0:int64, 'b': 'X':string}) - EMPTYSINGLETON() -""", + "global_calc_multiple", ), id="global_calc_multiple", ), @@ -774,12 +564,7 @@ ), num_cust=FunctionInfo("COUNT", [ChildReferenceCollectionInfo(0)]), ), - """ -ROOT(columns=[('total_bal', total_bal), ('num_bal', num_bal), ('avg_bal', avg_bal), ('min_bal', min_bal), ('max_bal', max_bal), ('num_cust', num_cust)], orderings=[]) - PROJECT(columns={'avg_bal': agg_0, 'max_bal': agg_1, 'min_bal': agg_2, 'num_bal': agg_3, 'num_cust': agg_4, 'total_bal': DEFAULT_TO(agg_5, 0:int64)}) - AGGREGATE(keys={}, aggregations={'agg_0': AVG(acctbal), 'agg_1': MAX(acctbal), 'agg_2': MIN(acctbal), 'agg_3': COUNT(acctbal), 'agg_4': COUNT(), 'agg_5': SUM(acctbal)}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal}) -""", + "global_aggfuncs", ), id="global_aggfuncs", ), @@ -795,18 +580,7 @@ num_supp=FunctionInfo("COUNT", [ChildReferenceCollectionInfo(1)]), num_part=FunctionInfo("COUNT", [ChildReferenceCollectionInfo(2)]), ), - """ -ROOT(columns=[('num_cust', num_cust), ('num_supp', num_supp), ('num_part', num_part)], orderings=[]) - PROJECT(columns={'num_cust': agg_0, 'num_part': agg_1, 'num_supp': agg_2}) - JOIN(conditions=[True:bool], types=['left'], columns={'agg_0': t0.agg_0, 'agg_1': t1.agg_1, 'agg_2': t0.agg_2}) - JOIN(conditions=[True:bool], types=['left'], columns={'agg_0': t0.agg_0, 'agg_2': t1.agg_2}) - AGGREGATE(keys={}, aggregations={'agg_0': COUNT()}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal}) - AGGREGATE(keys={}, aggregations={'agg_2': COUNT()}) - 
SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal}) - AGGREGATE(keys={}, aggregations={'agg_1': COUNT()}) - SCAN(table=tpch.PART, columns={'brand': p_brand}) -""", + "global_aggfuncs_multiple_children", ), id="global_aggfuncs_multiple_children", ), @@ -836,14 +610,7 @@ ], ), ), - """ -ROOT(columns=[('part_name', part_name), ('is_above_cutoff', is_above_cutoff), ('is_nickel', is_nickel)], orderings=[]) - PROJECT(columns={'is_above_cutoff': retail_price > a, 'is_nickel': CONTAINS(part_type, b), 'part_name': name}) - JOIN(conditions=[True:bool], types=['inner'], columns={'a': t0.a, 'b': t0.b, 'name': t1.name, 'part_type': t1.part_type, 'retail_price': t1.retail_price}) - PROJECT(columns={'a': 28.15:int64, 'b': 'NICKEL':string}) - EMPTYSINGLETON() - SCAN(table=tpch.PART, columns={'name': p_name, 'part_type': p_type, 'retail_price': p_retailprice}) -""", + "global_calc_backref", ), id="global_calc_backref", ), @@ -867,15 +634,7 @@ ], ), ), - """ -ROOT(columns=[('part_name', part_name), ('is_above_avg', is_above_avg)], orderings=[]) - PROJECT(columns={'is_above_avg': retail_price > avg_price, 'part_name': name}) - JOIN(conditions=[True:bool], types=['inner'], columns={'avg_price': t0.avg_price, 'name': t1.name, 'retail_price': t1.retail_price}) - PROJECT(columns={'avg_price': agg_0}) - AGGREGATE(keys={}, aggregations={'agg_0': AVG(retail_price)}) - SCAN(table=tpch.PART, columns={'retail_price': p_retailprice}) - SCAN(table=tpch.PART, columns={'name': p_name, 'retail_price': p_retailprice}) -""", + "global_aggfunc_backref", ), id="global_aggfunc_backref", ), @@ -898,8 +657,7 @@ ], ), ), - """ -""", + "aggregate_mixed_levels_advanced", ), id="aggregate_mixed_levels_advanced", marks=pytest.mark.skip("TODO"), @@ -919,12 +677,7 @@ "AVG", [ChildReferenceExpressionInfo("retail_price", 0)] ), ), - """ -ROOT(columns=[('part_type', part_type), ('num_parts', num_parts), ('avg_price', avg_price)], orderings=[]) - PROJECT(columns={'avg_price': agg_0, 'num_parts': 
DEFAULT_TO(agg_1, 0:int64), 'part_type': part_type}) - AGGREGATE(keys={'part_type': part_type}, aggregations={'agg_0': AVG(retail_price), 'agg_1': COUNT()}) - SCAN(table=tpch.PART, columns={'part_type': p_type, 'retail_price': p_retailprice}) -""", + "agg_parts_by_type_simple", ), id="agg_parts_by_type_simple", ), @@ -951,13 +704,7 @@ "COUNT", [ChildReferenceCollectionInfo(0)] ), ), - """ -ROOT(columns=[('year', year), ('month', month), ('total_orders', total_orders)], orderings=[]) - PROJECT(columns={'month': month, 'total_orders': DEFAULT_TO(agg_0, 0:int64), 'year': year}) - AGGREGATE(keys={'month': month, 'year': year}, aggregations={'agg_0': COUNT()}) - PROJECT(columns={'month': MONTH(order_date), 'year': YEAR(order_date)}) - SCAN(table=tpch.ORDERS, columns={'order_date': o_orderdate}) -""", + "agg_orders_by_year_month_basic", ), id="agg_orders_by_year_month_basic", ), @@ -1004,24 +751,7 @@ "COUNT", [ChildReferenceCollectionInfo(1)] ), ), - """ -ROOT(columns=[('year', year), ('month', month), ('num_european_orders', num_european_orders), ('total_orders', total_orders)], orderings=[]) - PROJECT(columns={'month': month, 'num_european_orders': DEFAULT_TO(agg_0, 0:int64), 'total_orders': DEFAULT_TO(agg_1, 0:int64), 'year': year}) - JOIN(conditions=[t0.year == t1.year & t0.month == t1.month], types=['left'], columns={'agg_0': t0.agg_0, 'agg_1': t1.agg_1, 'month': t0.month, 'year': t0.year}) - AGGREGATE(keys={'month': month, 'year': year}, aggregations={'agg_0': COUNT()}) - PROJECT(columns={'month': MONTH(order_date), 'year': YEAR(order_date)}) - SCAN(table=tpch.ORDERS, columns={'order_date': o_orderdate}) - AGGREGATE(keys={'month': month, 'year': year}, aggregations={'agg_1': COUNT()}) - FILTER(condition=name_6 == 'ASIA':string, columns={'month': month, 'year': year}) - JOIN(conditions=[t0.customer_key == t1.key], types=['left'], columns={'month': t0.month, 'name_6': t1.name_6, 'year': t0.year}) - PROJECT(columns={'customer_key': customer_key, 'month': 
MONTH(order_date), 'year': YEAR(order_date)}) - SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'order_date': o_orderdate}) - JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_6': t1.name}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'region_key': t1.region_key}) - SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) -""", + "agg_orders_by_year_month_vs_europe", ), id="agg_orders_by_year_month_vs_europe", ), @@ -1064,24 +794,7 @@ "COUNT", [ChildReferenceCollectionInfo(0)] ), ), - """ -ROOT(columns=[('year', year), ('month', month), ('num_european_orders', num_european_orders)], orderings=[]) - PROJECT(columns={'month': month, 'num_european_orders': DEFAULT_TO(agg_0, 0:int64), 'year': year}) - JOIN(conditions=[t0.year == t1.year & t0.month == t1.month], types=['left'], columns={'agg_0': t1.agg_0, 'month': t0.month, 'year': t0.year}) - AGGREGATE(keys={'month': month, 'year': year}, aggregations={}) - PROJECT(columns={'month': MONTH(order_date), 'year': YEAR(order_date)}) - SCAN(table=tpch.ORDERS, columns={'order_date': o_orderdate}) - AGGREGATE(keys={'month': month, 'year': year}, aggregations={'agg_0': COUNT()}) - FILTER(condition=name_6 == 'ASIA':string, columns={'month': month, 'year': year}) - JOIN(conditions=[t0.customer_key == t1.key], types=['left'], columns={'month': t0.month, 'name_6': t1.name_6, 'year': t0.year}) - PROJECT(columns={'customer_key': customer_key, 'month': MONTH(order_date), 'year': YEAR(order_date)}) - SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'order_date': o_orderdate}) - JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_6': t1.name}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': 
t0.key, 'region_key': t1.region_key}) - SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) -""", + "agg_orders_by_year_month_just_europe", ), id="agg_orders_by_year_month_just_europe", ), @@ -1123,25 +836,7 @@ "SUM", [ChildReferenceExpressionInfo("value", 0)] ), ), - """ -ROOT(columns=[('year', year), ('customer_nation', customer_nation), ('supplier_nation', supplier_nation), ('num_occurrences', num_occurrences), ('total_value', total_value)], orderings=[]) - PROJECT(columns={'customer_nation': customer_nation, 'num_occurrences': DEFAULT_TO(agg_0, 0:int64), 'supplier_nation': supplier_nation, 'total_value': DEFAULT_TO(agg_1, 0:int64), 'year': year}) - AGGREGATE(keys={'customer_nation': customer_nation, 'supplier_nation': supplier_nation, 'year': year}, aggregations={'agg_0': COUNT(), 'agg_1': SUM(value)}) - PROJECT(columns={'customer_nation': name, 'supplier_nation': name_18, 'value': extended_price, 'year': YEAR(order_date)}) - JOIN(conditions=[t0.nation_key_14 == t1.key], types=['inner'], columns={'extended_price': t0.extended_price, 'name': t0.name, 'name_18': t1.name, 'order_date': t0.order_date}) - JOIN(conditions=[t0.supplier_key_9 == t1.key], types=['inner'], columns={'extended_price': t0.extended_price, 'name': t0.name, 'nation_key_14': t1.nation_key, 'order_date': t0.order_date}) - JOIN(conditions=[t0.part_key == t1.part_key & t0.supplier_key == t1.supplier_key], types=['inner'], columns={'extended_price': t0.extended_price, 'name': t0.name, 'order_date': t0.order_date, 'supplier_key_9': t1.supplier_key}) - JOIN(conditions=[t0.key_5 == t1.order_key], types=['inner'], columns={'extended_price': t1.extended_price, 'name': t0.name, 'order_date': t0.order_date, 'part_key': t1.part_key, 'supplier_key': t1.supplier_key}) - JOIN(conditions=[t0.key_2 == t1.customer_key], 
types=['inner'], columns={'key_5': t1.key, 'name': t0.name, 'order_date': t1.order_date}) - JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) - SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) - SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate}) - SCAN(table=tpch.LINEITEM, columns={'extended_price': l_extendedprice, 'order_key': l_orderkey, 'part_key': l_partkey, 'supplier_key': l_suppkey}) - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) -""", + "count_cust_supplier_nation_combos", ), id="count_cust_supplier_nation_combos", ), @@ -1185,17 +880,7 @@ ], ), ), - """ -ROOT(columns=[('part_type', part_type), ('percentage_of_parts', percentage_of_parts), ('avg_price', avg_price)], orderings=[]) - FILTER(condition=avg_price >= global_avg_price, columns={'avg_price': avg_price, 'part_type': part_type, 'percentage_of_parts': percentage_of_parts}) - PROJECT(columns={'avg_price': agg_2, 'global_avg_price': global_avg_price, 'part_type': part_type, 'percentage_of_parts': DEFAULT_TO(agg_3, 0:int64) / total_num_parts}) - JOIN(conditions=[True:bool], types=['left'], columns={'agg_2': t1.agg_2, 'agg_3': t1.agg_3, 'global_avg_price': t0.global_avg_price, 'part_type': t1.part_type, 'total_num_parts': t0.total_num_parts}) - PROJECT(columns={'global_avg_price': agg_0, 'total_num_parts': agg_1}) - AGGREGATE(keys={}, aggregations={'agg_0': AVG(retail_price), 'agg_1': COUNT()}) - SCAN(table=tpch.PART, columns={'retail_price': p_retailprice}) - AGGREGATE(keys={'part_type': part_type}, aggregations={'agg_2': AVG(retail_price), 'agg_3': COUNT()}) - SCAN(table=tpch.PART, 
columns={'part_type': p_type, 'retail_price': p_retailprice}) -""", + "agg_parts_by_type_backref_global", ), id="agg_parts_by_type_backref_global", ), @@ -1225,15 +910,7 @@ part_type=ReferenceInfo("part_type"), retail_price=ReferenceInfo("retail_price"), ), - """ -ROOT(columns=[('part_name', part_name), ('part_type', part_type), ('retail_price', retail_price)], orderings=[]) - PROJECT(columns={'part_name': name, 'part_type': part_type_1, 'retail_price': retail_price}) - JOIN(conditions=[t0.part_type == t1.part_type], types=['inner'], columns={'name': t1.name, 'part_type_1': t1.part_type, 'retail_price': t1.retail_price}) - FILTER(condition=agg_0 > 27.5:float64, columns={'part_type': part_type}) - AGGREGATE(keys={'part_type': part_type}, aggregations={'agg_0': AVG(retail_price)}) - SCAN(table=tpch.PART, columns={'part_type': p_type, 'retail_price': p_retailprice}) - SCAN(table=tpch.PART, columns={'name': p_name, 'part_type': p_type, 'retail_price': p_retailprice}) -""", + "access_partition_child_after_filter", ), id="access_partition_child_after_filter", ), @@ -1263,15 +940,7 @@ ], ), ), - """ -ROOT(columns=[('part_name', part_name), ('part_type', part_type), ('retail_price_versus_avg', retail_price_versus_avg)], orderings=[]) - PROJECT(columns={'part_name': name, 'part_type': part_type_1, 'retail_price_versus_avg': retail_price - avg_price}) - JOIN(conditions=[t0.part_type == t1.part_type], types=['inner'], columns={'avg_price': t0.avg_price, 'name': t1.name, 'part_type_1': t1.part_type, 'retail_price': t1.retail_price}) - PROJECT(columns={'avg_price': agg_0, 'part_type': part_type}) - AGGREGATE(keys={'part_type': part_type}, aggregations={'agg_0': AVG(retail_price)}) - SCAN(table=tpch.PART, columns={'part_type': p_type, 'retail_price': p_retailprice}) - SCAN(table=tpch.PART, columns={'name': p_name, 'part_type': p_type, 'retail_price': p_retailprice}) -""", + "access_partition_child_backref_calc", ), id="access_partition_child_backref_calc", ), @@ -1305,16 +974,7 
@@ ], ), ), - """ -ROOT(columns=[('part_name', part_name), ('part_type', part_type), ('retail_price', retail_price)], orderings=[]) - FILTER(condition=retail_price < avg_price, columns={'part_name': part_name, 'part_type': part_type, 'retail_price': retail_price}) - PROJECT(columns={'avg_price': avg_price, 'part_name': name, 'part_type': part_type_1, 'retail_price': retail_price}) - JOIN(conditions=[t0.part_type == t1.part_type], types=['inner'], columns={'avg_price': t0.avg_price, 'name': t1.name, 'part_type_1': t1.part_type, 'retail_price': t1.retail_price}) - PROJECT(columns={'avg_price': agg_0, 'part_type': part_type}) - AGGREGATE(keys={'part_type': part_type}, aggregations={'agg_0': AVG(retail_price)}) - SCAN(table=tpch.PART, columns={'part_type': p_type, 'retail_price': p_retailprice}) - SCAN(table=tpch.PART, columns={'name': p_name, 'part_type': p_type, 'retail_price': p_retailprice}) -""", + "access_partition_child_filter_backref_filter", ), id="access_partition_child_filter_backref_filter", ), @@ -1329,14 +989,7 @@ ), ) ** SubCollectionInfo("nations"), - """\ -ROOT(columns=[('key', key), ('name', name), ('region_key', region_key), ('comment', comment)], orderings=[]) - PROJECT(columns={'comment': comment_1, 'key': key_2, 'name': name_3, 'region_key': region_key}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'comment_1': t1.comment, 'key_2': t1.key, 'name_3': t1.name, 'region_key': t1.region_key}) - FILTER(condition=name == 'ASIA':string, columns={'key': key}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - SCAN(table=tpch.NATION, columns={'comment': n_comment, 'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) -""", + "join_asia_region_nations", ), id="join_asia_region_nations", ), @@ -1353,13 +1006,7 @@ ], ), ), - """ -ROOT(columns=[('key', key), ('name', name), ('region_key', region_key), ('comment', comment)], orderings=[]) - FILTER(condition=name_3 == 'ASIA':string, columns={'comment': 
comment, 'key': key, 'name': name, 'region_key': region_key}) - JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'comment': t0.comment, 'key': t0.key, 'name': t0.name, 'name_3': t1.name, 'region_key': t0.region_key}) - SCAN(table=tpch.NATION, columns={'comment': n_comment, 'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) -""", + "asian_nations", ), id="asian_nations", ), @@ -1400,21 +1047,7 @@ ship_date=ReferenceInfo("ship_date"), extended_price=ReferenceInfo("extended_price"), ), - """ -ROOT(columns=[('order_key', order_key), ('ship_date', ship_date), ('extended_price', extended_price)], orderings=[]) - FILTER(condition=name_4 == 'GERMANY':string & STARTSWITH(part_type, 'ECONOMY':string), columns={'extended_price': extended_price, 'order_key': order_key, 'ship_date': ship_date}) - JOIN(conditions=[t0.part_key == t1.part_key & t0.supplier_key == t1.supplier_key], types=['left'], columns={'extended_price': t0.extended_price, 'name_4': t0.name_4, 'order_key': t0.order_key, 'part_type': t1.part_type, 'ship_date': t0.ship_date}) - JOIN(conditions=[t0.part_key == t1.part_key & t0.supplier_key == t1.supplier_key], types=['left'], columns={'extended_price': t0.extended_price, 'name_4': t1.name_4, 'order_key': t0.order_key, 'part_key': t0.part_key, 'ship_date': t0.ship_date, 'supplier_key': t0.supplier_key}) - SCAN(table=tpch.LINEITEM, columns={'extended_price': l_extendedprice, 'order_key': l_orderkey, 'part_key': l_partkey, 'ship_date': l_shipdate, 'supplier_key': l_suppkey}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'name_4': t1.name, 'part_key': t0.part_key, 'supplier_key': t0.supplier_key}) - JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'nation_key': t1.nation_key, 'part_key': t0.part_key, 'supplier_key': t0.supplier_key}) - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': 
ps_suppkey}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) - JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'part_key': t0.part_key, 'part_type': t1.part_type, 'supplier_key': t0.supplier_key}) - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) - SCAN(table=tpch.PART, columns={'key': p_partkey, 'part_type': p_type}) -""", + "lines_german_supplier_economy_part", ), id="lines_german_supplier_economy_part", ), @@ -1429,14 +1062,7 @@ [ReferenceInfo("name"), BackReferenceExpressionInfo("name", 1)], ), ), - """ -ROOT(columns=[('key', key), ('name', name), ('region_key', region_key), ('comment', comment)], orderings=[]) - PROJECT(columns={'comment': comment_1, 'key': key_2, 'name': name_3, 'region_key': region_key}) - FILTER(condition=CONTAINS(name_3, name), columns={'comment_1': comment_1, 'key_2': key_2, 'name_3': name_3, 'region_key': region_key}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'comment_1': t1.comment, 'key_2': t1.key, 'name': t0.name, 'name_3': t1.name, 'region_key': t1.region_key}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - SCAN(table=tpch.NATION, columns={'comment': n_comment, 'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) -""", + "nation_name_contains_region_name", ), id="nation_name_contains_region_name", ), @@ -1467,28 +1093,7 @@ rname=BackReferenceExpressionInfo("name", 4), price=ReferenceInfo("extended_price"), ), - """ -ROOT(columns=[('rname', rname), ('price', price)], orderings=[]) - PROJECT(columns={'price': extended_price, 'rname': name}) - FILTER(condition=name == name_16, columns={'extended_price': extended_price, 'name': name}) - JOIN(conditions=[t0.part_key == t1.part_key & t0.supplier_key == t1.supplier_key], types=['left'], columns={'extended_price': t0.extended_price, 'name': t0.name, 'name_16': 
t1.name_16}) - JOIN(conditions=[t0.key_8 == t1.order_key], types=['inner'], columns={'extended_price': t1.extended_price, 'name': t0.name, 'part_key': t1.part_key, 'supplier_key': t1.supplier_key}) - JOIN(conditions=[t0.key_5 == t1.customer_key], types=['inner'], columns={'key_8': t1.key, 'name': t0.name}) - JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'key_5': t1.key, 'name': t0.name}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) - SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey}) - SCAN(table=tpch.LINEITEM, columns={'extended_price': l_extendedprice, 'order_key': l_orderkey, 'part_key': l_partkey, 'supplier_key': l_suppkey}) - JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'name_16': t1.name, 'part_key': t0.part_key, 'supplier_key': t0.supplier_key}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'part_key': t0.part_key, 'region_key': t1.region_key, 'supplier_key': t0.supplier_key}) - JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'nation_key': t1.nation_key, 'part_key': t0.part_key, 'supplier_key': t0.supplier_key}) - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) -""", + "lineitem_regional_shipments", ), id="lineitem_regional_shipments", ), @@ -1524,28 +1129,7 @@ rname=ChildReferenceExpressionInfo("name", 0), price=ReferenceInfo("extended_price"), ), - """ 
-ROOT(columns=[('rname', rname), ('price', price)], orderings=[]) - PROJECT(columns={'price': extended_price, 'rname': name_8}) - FILTER(condition=name_8 == name_15, columns={'extended_price': extended_price, 'name_8': name_8}) - JOIN(conditions=[t0.part_key == t1.part_key & t0.supplier_key == t1.supplier_key], types=['left'], columns={'extended_price': t0.extended_price, 'name_15': t1.name_15, 'name_8': t0.name_8}) - JOIN(conditions=[t0.order_key == t1.key], types=['left'], columns={'extended_price': t0.extended_price, 'name_8': t1.name_8, 'part_key': t0.part_key, 'supplier_key': t0.supplier_key}) - SCAN(table=tpch.LINEITEM, columns={'extended_price': l_extendedprice, 'order_key': l_orderkey, 'part_key': l_partkey, 'supplier_key': l_suppkey}) - JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_8': t1.name}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'region_key': t1.region_key}) - JOIN(conditions=[t0.customer_key == t1.key], types=['inner'], columns={'key': t0.key, 'nation_key': t1.nation_key}) - SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey}) - SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'name_15': t1.name, 'part_key': t0.part_key, 'supplier_key': t0.supplier_key}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'part_key': t0.part_key, 'region_key': t1.region_key, 'supplier_key': t0.supplier_key}) - JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'nation_key': t1.nation_key, 'part_key': t0.part_key, 'supplier_key': t0.supplier_key}) - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) - 
SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) -""", + "lineitem_regional_shipments2", ), id="lineitem_regional_shipments2", ), @@ -1570,28 +1154,7 @@ ], ), ), - """ -ROOT(columns=[('key', key), ('name', name), ('comment', comment)], orderings=[]) - PROJECT(columns={'comment': comment_31, 'key': key_32, 'name': name_33}) - FILTER(condition=name_33 == name, columns={'comment_31': comment_31, 'key_32': key_32, 'name_33': name_33}) - JOIN(conditions=[t0.region_key_30 == t1.key], types=['inner'], columns={'comment_31': t1.comment, 'key_32': t1.key, 'name': t0.name, 'name_33': t1.name}) - JOIN(conditions=[t0.nation_key_25 == t1.key], types=['inner'], columns={'name': t0.name, 'region_key_30': t1.region_key}) - JOIN(conditions=[t0.customer_key_12 == t1.key], types=['inner'], columns={'name': t0.name, 'nation_key_25': t1.nation_key}) - JOIN(conditions=[t0.order_key == t1.key], types=['inner'], columns={'customer_key_12': t1.customer_key, 'name': t0.name}) - JOIN(conditions=[t0.key_8 == t1.order_key], types=['inner'], columns={'name': t0.name, 'order_key': t1.order_key}) - JOIN(conditions=[t0.key_5 == t1.customer_key], types=['inner'], columns={'key_8': t1.key, 'name': t0.name}) - JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'key_5': t1.key, 'name': t0.name}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) - SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey}) - SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey}) - 
SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey}) - SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - SCAN(table=tpch.REGION, columns={'comment': r_comment, 'key': r_regionkey, 'name': r_name}) -""", + "lineitem_regional_shipments3", ), id="lineitem_regional_shipments3", ), @@ -1621,18 +1184,7 @@ "COUNT", [ChildReferenceExpressionInfo("key", 1)] ), ), - """ -ROOT(columns=[('name', name), ('suppliers_in_black', suppliers_in_black), ('total_suppliers', total_suppliers)], orderings=[]) - PROJECT(columns={'name': name, 'suppliers_in_black': DEFAULT_TO(agg_0, 0:int64), 'total_suppliers': DEFAULT_TO(agg_1, 0:int64)}) - JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t0.agg_0, 'agg_1': t1.agg_1, 'name': t0.name}) - JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key, 'name': t0.name}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': COUNT(key)}) - FILTER(condition=account_balance > 0.0:float64, columns={'key': key, 'nation_key': nation_key}) - SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_1': COUNT(key)}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) -""", + "num_positive_accounts_per_nation", ), id="num_positive_accounts_per_nation", ), @@ -1674,18 +1226,7 @@ ], ), ), - """ -ROOT(columns=[('name', name)], orderings=[]) - FILTER(condition=DEFAULT_TO(agg_0, 0:int64) > 0.5:float64 * DEFAULT_TO(agg_1, 0:int64), columns={'name': name}) - JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t0.agg_0, 'agg_1': t1.agg_1, 'name': t0.name}) - JOIN(conditions=[t0.key == t1.nation_key], 
types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key, 'name': t0.name}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': COUNT(key)}) - FILTER(condition=account_balance > 0.0:float64, columns={'key': key, 'nation_key': nation_key}) - SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_1': COUNT(key)}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) -""", + "mostly_positive_accounts_per_nation1", ), id="mostly_positive_accounts_per_nation1", ), @@ -1749,19 +1290,7 @@ ], ), ), - """ -ROOT(columns=[('name', name), ('suppliers_in_black', suppliers_in_black), ('total_suppliers', total_suppliers)], orderings=[]) - FILTER(condition=DEFAULT_TO(agg_2, 0:int64) > 0.5:float64 * DEFAULT_TO(agg_3, 0:int64), columns={'name': name, 'suppliers_in_black': suppliers_in_black, 'total_suppliers': total_suppliers}) - PROJECT(columns={'agg_2': agg_2, 'agg_3': agg_3, 'name': name, 'suppliers_in_black': DEFAULT_TO(agg_0, 0:int64), 'total_suppliers': DEFAULT_TO(agg_1, 0:int64)}) - JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t0.agg_0, 'agg_1': t1.agg_1, 'agg_2': t0.agg_2, 'agg_3': t1.agg_3, 'name': t0.name}) - JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'agg_2': t1.agg_2, 'key': t0.key, 'name': t0.name}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': COUNT(key), 'agg_2': COUNT(key)}) - FILTER(condition=account_balance > 0.0:float64, columns={'key': key, 'nation_key': nation_key}) - SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_1': COUNT(key), 
'agg_3': COUNT(key)}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) -""", + "mostly_positive_accounts_per_nation2", ), id="mostly_positive_accounts_per_nation2", ), @@ -1807,19 +1336,7 @@ ], ), ), - """ -ROOT(columns=[('name', name), ('suppliers_in_black', suppliers_in_black), ('total_suppliers', total_suppliers)], orderings=[]) - FILTER(condition=suppliers_in_black > 0.5:float64 * total_suppliers, columns={'name': name, 'suppliers_in_black': suppliers_in_black, 'total_suppliers': total_suppliers}) - PROJECT(columns={'name': name, 'suppliers_in_black': DEFAULT_TO(agg_0, 0:int64), 'total_suppliers': DEFAULT_TO(agg_1, 0:int64)}) - JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t0.agg_0, 'agg_1': t1.agg_1, 'name': t0.name}) - JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key, 'name': t0.name}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': COUNT(key)}) - FILTER(condition=account_balance > 0.0:float64, columns={'key': key, 'nation_key': nation_key}) - SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_1': COUNT(key)}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) -""", + "mostly_positive_accounts_per_nation3", ), id="mostly_positive_accounts_per_nation3", ), @@ -1827,12 +1344,7 @@ ( TableCollectionInfo("Regions") ** TopKInfo([], 2, (ReferenceInfo("name"), True, True)), - """ -ROOT(columns=[('key', key), ('name', name), ('comment', comment)], orderings=[(ordering_0):asc_last]) - LIMIT(limit=Literal(value=2, type=Int64Type()), columns={'comment': comment, 'key': key, 'name': name, 'ordering_0': ordering_0}, orderings=[(ordering_0):asc_last]) - PROJECT(columns={'comment': comment, 'key': key, 'name': name, 
'ordering_0': name}) - SCAN(table=tpch.REGION, columns={'comment': r_comment, 'key': r_regionkey, 'name': r_name}) -""", + "simple_topk", ), id="simple_topk", ), @@ -1846,15 +1358,7 @@ region_name=BackReferenceExpressionInfo("name", 1), nation_name=ReferenceInfo("name"), ), - """ -ROOT(columns=[('region_name', region_name), ('nation_name', nation_name)], orderings=[(ordering_0):asc_last]) - PROJECT(columns={'nation_name': name_3, 'ordering_0': ordering_0, 'region_name': name}) - LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'name': name, 'name_3': name_3, 'ordering_0': ordering_0}, orderings=[(ordering_0):asc_last]) - PROJECT(columns={'name': name, 'name_3': name_3, 'ordering_0': name_3}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'name': t0.name, 'name_3': t1.name}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) -""", + "join_topk", ), id="join_topk", ), @@ -1862,11 +1366,7 @@ ( TableCollectionInfo("Regions") ** OrderInfo([], (ReferenceInfo("name"), True, True)), - """ -ROOT(columns=[('key', key), ('name', name), ('comment', comment)], orderings=[(ordering_0):asc_last]) - PROJECT(columns={'comment': comment, 'key': key, 'name': name, 'ordering_0': name}) - SCAN(table=tpch.REGION, columns={'comment': r_comment, 'key': r_regionkey, 'name': r_name}) -""", + "simple_order_by", ), id="simple_order_by", ), @@ -1880,14 +1380,7 @@ region_name=BackReferenceExpressionInfo("name", 1), nation_name=ReferenceInfo("name"), ), - """ -ROOT(columns=[('region_name', region_name), ('nation_name', nation_name)], orderings=[(ordering_0):desc_last]) - PROJECT(columns={'nation_name': name_3, 'ordering_0': ordering_0, 'region_name': name}) - PROJECT(columns={'name': name, 'name_3': name_3, 'ordering_0': name_3}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'name': t0.name, 'name_3': t1.name}) - SCAN(table=tpch.REGION, 
columns={'key': r_regionkey, 'name': r_name}) - SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) -""", + "join_order_by", ), id="join_order_by", ), @@ -1902,14 +1395,7 @@ nation_name=ReferenceInfo("name"), ) ** OrderInfo([], (ReferenceInfo("region_name"), False, True)), - """ -ROOT(columns=[('region_name', region_name), ('nation_name', nation_name)], orderings=[(ordering_1):desc_last]) - PROJECT(columns={'nation_name': nation_name, 'ordering_1': region_name, 'region_name': region_name}) - PROJECT(columns={'nation_name': name_3, 'region_name': name}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'name': t0.name, 'name_3': t1.name}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) -""", + "replace_order_by", ), id="replace_order_by", ), @@ -1918,12 +1404,7 @@ TableCollectionInfo("Regions") ** OrderInfo([], (ReferenceInfo("name"), True, True)) ** TopKInfo([], 10, (ReferenceInfo("name"), True, True)), - """ -ROOT(columns=[('key', key), ('name', name), ('comment', comment)], orderings=[(ordering_1):asc_last]) - LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'comment': comment, 'key': key, 'name': name, 'ordering_1': ordering_1}, orderings=[(ordering_1):asc_last]) - PROJECT(columns={'comment': comment, 'key': key, 'name': name, 'ordering_1': name}) - SCAN(table=tpch.REGION, columns={'comment': r_comment, 'key': r_regionkey, 'name': r_name}) -""", + "topk_order_by", ), id="topk_order_by", ), @@ -1937,13 +1418,7 @@ region_name=ReferenceInfo("name"), name_length=FunctionInfo("LENGTH", [ReferenceInfo("name")]), ), - """ -ROOT(columns=[('region_name', region_name), ('name_length', name_length)], orderings=[(ordering_1):asc_last]) - PROJECT(columns={'name_length': LENGTH(name), 'ordering_1': ordering_1, 'region_name': name}) - LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'name': name, 'ordering_1': 
ordering_1}, orderings=[(ordering_1):asc_last]) - PROJECT(columns={'name': name, 'ordering_1': name}) - SCAN(table=tpch.REGION, columns={'name': r_name}) -""", + "topk_order_by_calc", ), id="topk_order_by_calc", ), @@ -1953,12 +1428,7 @@ ** OrderInfo([], (ReferenceInfo("name"), True, True)) ** OrderInfo([], (ReferenceInfo("name"), False, False)) ** TopKInfo([], 10, (ReferenceInfo("name"), False, False)), - """ -ROOT(columns=[('key', key), ('name', name), ('comment', comment)], orderings=[(ordering_2):desc_first]) - LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'comment': comment, 'key': key, 'name': name, 'ordering_2': ordering_2}, orderings=[(ordering_2):desc_first]) - PROJECT(columns={'comment': comment, 'key': key, 'name': name, 'ordering_2': name}) - SCAN(table=tpch.REGION, columns={'comment': r_comment, 'key': r_regionkey, 'name': r_name}) -""", + "topk_replace_order_by", ), # Note: This tests is less useful because the rewrite has already # occurred for TopK. @@ -1970,13 +1440,7 @@ ** OrderInfo([], (ReferenceInfo("name"), True, False)) ** TopKInfo([], 10, (ReferenceInfo("name"), True, False)) ** OrderInfo([], (ReferenceInfo("name"), False, False)), - """ -ROOT(columns=[('key', key), ('name', name), ('comment', comment)], orderings=[(ordering_2):desc_first]) - PROJECT(columns={'comment': comment, 'key': key, 'name': name, 'ordering_2': name}) - LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'comment': comment, 'key': key, 'name': name}, orderings=[(ordering_1):asc_first]) - PROJECT(columns={'comment': comment, 'key': key, 'name': name, 'ordering_1': name}) - SCAN(table=tpch.REGION, columns={'comment': r_comment, 'key': r_regionkey, 'name': r_name}) -""", + "topk_root_different_order_by", ), id="topk_root_different_order_by", ), @@ -1991,12 +1455,7 @@ 10, (FunctionInfo("LENGTH", [ReferenceInfo("name")]), True, False), ), - """ -ROOT(columns=[('key', key), ('name', name), ('comment', comment)], orderings=[(ordering_1):asc_first]) - 
LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'comment': comment, 'key': key, 'name': name, 'ordering_1': ordering_1}, orderings=[(ordering_1):asc_first]) - PROJECT(columns={'comment': comment, 'key': key, 'name': name, 'ordering_1': LENGTH(name)}) - SCAN(table=tpch.REGION, columns={'comment': r_comment, 'key': r_regionkey, 'name': r_name}) -""", + "order_by_expression", ), id="order_by_expression", ), @@ -2005,13 +1464,7 @@ TableCollectionInfo("Regions") ** OrderInfo([], (ReferenceInfo("name"), True, False)) ** SubCollectionInfo("nations"), - """ -ROOT(columns=[('key', key), ('name', name), ('region_key', region_key), ('comment', comment)], orderings=[]) - PROJECT(columns={'comment': comment_1, 'key': key_2, 'name': name_3, 'region_key': region_key}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'comment_1': t1.comment, 'key_2': t1.key, 'name_3': t1.name, 'region_key': t1.region_key}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey}) - SCAN(table=tpch.NATION, columns={'comment': n_comment, 'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) -""", + "order_by_before_join", ), # Note: This behavior may change in the future. 
id="order_by_before_join", @@ -2030,14 +1483,7 @@ ], ), ), - """ -ROOT(columns=[('key', key), ('name', name), ('region_key', region_key), ('comment', comment)], orderings=[(ordering_0):asc_last]) - FILTER(condition=name_3 == 'ASIA':string, columns={'comment': comment, 'key': key, 'name': name, 'ordering_0': ordering_0, 'region_key': region_key}) - JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'comment': t0.comment, 'key': t0.key, 'name': t0.name, 'name_3': t1.name, 'ordering_0': t0.ordering_0, 'region_key': t0.region_key}) - PROJECT(columns={'comment': comment, 'key': key, 'name': name, 'ordering_0': name, 'region_key': region_key}) - SCAN(table=tpch.NATION, columns={'comment': n_comment, 'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) -""", + "ordered_asian_nations", ), id="ordered_asian_nations", ), @@ -2049,13 +1495,7 @@ (ReferenceInfo("name"), True, True), (ChildReferenceExpressionInfo("name", 0), True, True), ), - """ -ROOT(columns=[('key', key), ('name', name), ('region_key', region_key), ('comment', comment)], orderings=[(ordering_0):asc_last, (ordering_1):asc_last]) - PROJECT(columns={'comment': comment, 'key': key, 'name': name, 'ordering_0': name, 'ordering_1': name_3, 'region_key': region_key}) - JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'comment': t0.comment, 'key': t0.key, 'name': t0.name, 'name_3': t1.name, 'region_key': t0.region_key}) - SCAN(table=tpch.NATION, columns={'comment': n_comment, 'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) -""", + "nations_region_order_by_name", ), id="nations_region_order_by_name", ), @@ -2074,16 +1514,7 @@ "COUNT", [ChildReferenceExpressionInfo("key", 0)] ), ), - """ -ROOT(columns=[('name', name), ('n_top_suppliers', n_top_suppliers)], orderings=[]) - PROJECT(columns={'n_top_suppliers': 
DEFAULT_TO(agg_0, 0:int64), 'name': name}) - JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': COUNT(key)}) - LIMIT(limit=Literal(value=100, type=Int64Type()), columns={'key': key, 'nation_key': nation_key}, orderings=[(ordering_0):asc_last]) - PROJECT(columns={'key': key, 'nation_key': nation_key, 'ordering_0': account_balance}) - SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) -""", + "count_at_most_100_suppliers_per_nation", ), id="count_at_most_100_suppliers_per_nation", ), @@ -2098,14 +1529,7 @@ True, ), ), - """ -ROOT(columns=[('key', key), ('name', name), ('region_key', region_key), ('comment', comment)], orderings=[(ordering_0):asc_last]) - PROJECT(columns={'comment': comment, 'key': key, 'name': name, 'ordering_0': DEFAULT_TO(agg_1, 0:int64), 'region_key': region_key}) - JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_1': t1.agg_1, 'comment': t0.comment, 'key': t0.key, 'name': t0.name, 'region_key': t0.region_key}) - SCAN(table=tpch.NATION, columns={'comment': n_comment, 'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_1': COUNT(key)}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) -""", + "nations_order_by_num_suppliers", ), id="nations_order_by_num_suppliers", ), @@ -2121,15 +1545,7 @@ True, ), ), - """ -ROOT(columns=[('key', key), ('name', name), ('region_key', region_key), ('comment', comment)], orderings=[(ordering_0):asc_last]) - LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'comment': comment, 'key': key, 'name': name, 'ordering_0': ordering_0, 'region_key': region_key}, orderings=[(ordering_0):asc_last]) - PROJECT(columns={'comment': comment, 
'key': key, 'name': name, 'ordering_0': DEFAULT_TO(agg_1, 0:int64), 'region_key': region_key}) - JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_1': t1.agg_1, 'comment': t0.comment, 'key': t0.key, 'name': t0.name, 'region_key': t0.region_key}) - SCAN(table=tpch.NATION, columns={'comment': n_comment, 'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_1': COUNT(key)}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) -""", + "top_5_nations_by_num_supplierss", ), id="top_5_nations_by_num_suppliers", ), @@ -2152,16 +1568,7 @@ "SUM", [ChildReferenceExpressionInfo("account_balance", 0)] ), ), - """ -ROOT(columns=[('name', name), ('total_bal', total_bal)], orderings=[(ordering_0):asc_last]) - PROJECT(columns={'name': name, 'ordering_0': ordering_0, 'total_bal': DEFAULT_TO(agg_2, 0:int64)}) - LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'agg_2': agg_2, 'name': name, 'ordering_0': ordering_0}, orderings=[(ordering_0):asc_last]) - PROJECT(columns={'agg_2': agg_2, 'name': name, 'ordering_0': DEFAULT_TO(agg_1, 0:int64)}) - JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_1': t1.agg_1, 'agg_2': t1.agg_2, 'name': t0.name}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_1': COUNT(), 'agg_2': SUM(account_balance)}) - SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'nation_key': s_nationkey}) -""", + "top_5_nations_balance_by_num_suppliers", ), id="top_5_nations_balance_by_num_suppliers", ), @@ -2175,14 +1582,7 @@ nation_name=ReferenceInfo("name"), ) ** OrderInfo([], (BackReferenceExpressionInfo("name", 1), False, True)), - """ -ROOT(columns=[('region_name', region_name), ('nation_name', nation_name)], orderings=[(ordering_0):desc_last]) - PROJECT(columns={'nation_name': nation_name, 'ordering_0': name, 
'region_name': region_name}) - PROJECT(columns={'name': name, 'nation_name': name_3, 'region_name': name}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'name': t0.name, 'name_3': t1.name}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) -""", + "join_order_by_back_reference", ), id="join_order_by_back_reference", ), @@ -2195,14 +1595,7 @@ nation_name=ReferenceInfo("name"), ) ** OrderInfo([], (BackReferenceExpressionInfo("name", 1), False, True)), - """ -ROOT(columns=[('nation_name', nation_name)], orderings=[(ordering_0):desc_last]) - PROJECT(columns={'nation_name': nation_name, 'ordering_0': name}) - PROJECT(columns={'name': name, 'nation_name': name_3}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'name': t0.name, 'name_3': t1.name}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) -""", + "join_order_by_pruned_back_reference", ), id="join_order_by_pruned_back_reference", ), @@ -2233,13 +1626,7 @@ ordering_7=FunctionInfo("ABS", [ReferenceInfo("key")]), ordering_8=FunctionInfo("LENGTH", [ReferenceInfo("comment")]), ), - """ -ROOT(columns=[('ordering_0', ordering_0_0), ('ordering_1', ordering_1_0), ('ordering_2', ordering_2_0), ('ordering_3', ordering_3_0), ('ordering_4', ordering_4_0), ('ordering_5', ordering_5_0), ('ordering_6', ordering_6), ('ordering_7', ordering_7), ('ordering_8', ordering_8)], orderings=[(ordering_3):asc_last, (ordering_4):desc_last, (ordering_5):asc_first]) - PROJECT(columns={'ordering_0_0': ordering_2, 'ordering_1_0': ordering_0, 'ordering_2_0': ordering_1, 'ordering_3': ordering_3, 'ordering_3_0': ordering_2, 'ordering_4': ordering_4, 'ordering_4_0': ordering_1, 'ordering_5': ordering_5, 'ordering_5_0': ordering_0, 'ordering_6': LOWER(name), 'ordering_7': ABS(key), 'ordering_8': 
LENGTH(comment)}) - PROJECT(columns={'comment': comment, 'key': key, 'name': name, 'ordering_0': ordering_0, 'ordering_1': ordering_1, 'ordering_2': ordering_2, 'ordering_3': LOWER(name), 'ordering_4': ABS(key), 'ordering_5': LENGTH(comment)}) - PROJECT(columns={'comment': comment, 'key': key, 'name': name, 'ordering_0': name, 'ordering_1': key, 'ordering_2': comment}) - SCAN(table=tpch.NATION, columns={'comment': n_comment, 'key': n_nationkey, 'name': n_name}) -""", + "ordering_name_overload", ), id="ordering_name_overload", ), @@ -2254,13 +1641,7 @@ [], name=ReferenceInfo("name"), ), - """ -ROOT(columns=[('name', name)], orderings=[]) - FILTER(condition=True:bool, columns={'name': name}) - JOIN(conditions=[t0.key == t1.customer_key], types=['semi'], columns={'name': t0.name}) - SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'name': c_name}) - SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey}) -""", + "simple_semi_1", ), id="simple_semi_1", ), @@ -2285,16 +1666,7 @@ [], name=ReferenceInfo("name"), ), - """ -ROOT(columns=[('name', name)], orderings=[]) - FILTER(condition=True:bool, columns={'name': name}) - JOIN(conditions=[t0.key == t1.supplier_key], types=['semi'], columns={'name': t0.name}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name}) - FILTER(condition=size < 10:int64, columns={'supplier_key': supplier_key}) - JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'size': t1.size, 'supplier_key': t0.supplier_key}) - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) - SCAN(table=tpch.PART, columns={'key': p_partkey, 'size': p_size}) -""", + "simple_semi_2", ), id="simple_semi_2", ), @@ -2309,13 +1681,7 @@ [], name=ReferenceInfo("name"), ), - """ -ROOT(columns=[('name', name)], orderings=[]) - FILTER(condition=True:bool, columns={'name': name}) - JOIN(conditions=[t0.key == t1.customer_key], types=['anti'], columns={'name': t0.name}) - SCAN(table=tpch.CUSTOMER, 
columns={'key': c_custkey, 'name': c_name}) - SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey}) -""", + "simple_anti_1", ), id="simple_anti_1", ), @@ -2340,16 +1706,7 @@ [], name=ReferenceInfo("name"), ), - """ -ROOT(columns=[('name', name)], orderings=[]) - FILTER(condition=True:bool, columns={'name': name}) - JOIN(conditions=[t0.key == t1.supplier_key], types=['anti'], columns={'name': t0.name}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name}) - FILTER(condition=size < 10:int64, columns={'supplier_key': supplier_key}) - JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'size': t1.size, 'supplier_key': t0.supplier_key}) - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) - SCAN(table=tpch.PART, columns={'key': p_partkey, 'size': p_size}) -""", + "simple_anti_2", ), id="simple_anti_2", ), @@ -2389,15 +1746,7 @@ name=ReferenceInfo("name"), region_name=ChildReferenceExpressionInfo("name", 0), ), - """ -ROOT(columns=[('name', name), ('region_name', region_name)], orderings=[]) - PROJECT(columns={'name': name, 'region_name': name_3}) - FILTER(condition=True:bool, columns={'name': name, 'name_3': name_3}) - JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'name': t0.name, 'name_3': t1.name}) - SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) - FILTER(condition=name != 'ASIA':string, columns={'key': key, 'name': name}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) -""", + "semi_singular", ), id="semi_singular", ), @@ -2437,15 +1786,7 @@ ], FunctionInfo("HAS", [ChildReferenceCollectionInfo(0)]), ), - """ -ROOT(columns=[('name', name), ('region_name', region_name)], orderings=[]) - FILTER(condition=True:bool, columns={'name': name, 'region_name': region_name}) - PROJECT(columns={'name': name, 'region_name': name_3}) - JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'name': t0.name, 'name_3': 
t1.name}) - SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) - FILTER(condition=name != 'ASIA':string, columns={'key': key, 'name': name}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) -""", + "singular_semi", ), id="singular_semi", ), @@ -2489,18 +1830,7 @@ "SUM", [ChildReferenceExpressionInfo("retail_price", 0)] ), ), - """ -ROOT(columns=[('name', name), ('num_10parts', num_10parts), ('avg_price_of_10parts', avg_price_of_10parts), ('sum_price_of_10parts', sum_price_of_10parts)], orderings=[]) - PROJECT(columns={'avg_price_of_10parts': agg_0, 'name': name, 'num_10parts': DEFAULT_TO(agg_1, 0:int64), 'sum_price_of_10parts': DEFAULT_TO(agg_2, 0:int64)}) - FILTER(condition=True:bool, columns={'agg_0': agg_0, 'agg_1': agg_1, 'agg_2': agg_2, 'name': name}) - JOIN(conditions=[t0.key == t1.supplier_key], types=['inner'], columns={'agg_0': t1.agg_0, 'agg_1': t1.agg_1, 'agg_2': t1.agg_2, 'name': t0.name}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name}) - AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': AVG(retail_price), 'agg_1': COUNT(), 'agg_2': SUM(retail_price)}) - FILTER(condition=size == 10:int64, columns={'retail_price': retail_price, 'supplier_key': supplier_key}) - JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'retail_price': t1.retail_price, 'size': t1.size, 'supplier_key': t0.supplier_key}) - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) - SCAN(table=tpch.PART, columns={'key': p_partkey, 'retail_price': p_retailprice, 'size': p_size}) -""", + "semi_aggregate", ), id="semi_aggregate", ), @@ -2544,18 +1874,7 @@ ], FunctionInfo("HAS", [ChildReferenceCollectionInfo(0)]), ), - """ -ROOT(columns=[('name', name), ('num_10parts', num_10parts), ('avg_price_of_10parts', avg_price_of_10parts), ('sum_price_of_10parts', sum_price_of_10parts)], orderings=[]) - FILTER(condition=True:bool, 
columns={'avg_price_of_10parts': avg_price_of_10parts, 'name': name, 'num_10parts': num_10parts, 'sum_price_of_10parts': sum_price_of_10parts}) - PROJECT(columns={'avg_price_of_10parts': agg_0, 'name': name, 'num_10parts': DEFAULT_TO(agg_1, 0:int64), 'sum_price_of_10parts': DEFAULT_TO(agg_2, 0:int64)}) - JOIN(conditions=[t0.key == t1.supplier_key], types=['inner'], columns={'agg_0': t1.agg_0, 'agg_1': t1.agg_1, 'agg_2': t1.agg_2, 'name': t0.name}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name}) - AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': AVG(retail_price), 'agg_1': COUNT(), 'agg_2': SUM(retail_price)}) - FILTER(condition=size == 10:int64, columns={'retail_price': retail_price, 'supplier_key': supplier_key}) - JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'retail_price': t1.retail_price, 'size': t1.size, 'supplier_key': t0.supplier_key}) - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) - SCAN(table=tpch.PART, columns={'key': p_partkey, 'retail_price': p_retailprice, 'size': p_size}) -""", + "aggregate_semi", ), id="aggregate_semi", ), @@ -2595,15 +1914,7 @@ name=ReferenceInfo("name"), region_name=ChildReferenceExpressionInfo("name", 0), ), - """ -ROOT(columns=[('name', name), ('region_name', region_name)], orderings=[]) - PROJECT(columns={'name': name, 'region_name': NULL_1}) - FILTER(condition=True:bool, columns={'NULL_1': NULL_1, 'name': name}) - JOIN(conditions=[t0.region_key == t1.key], types=['anti'], columns={'NULL_1': None:unknown, 'name': t0.name}) - SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) - FILTER(condition=name != 'ASIA':string, columns={'key': key}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) -""", + "anti_singular", ), id="anti_singular", ), @@ -2643,15 +1954,7 @@ ], FunctionInfo("HASNOT", [ChildReferenceCollectionInfo(0)]), ), - """ -ROOT(columns=[('name', name), 
('region_name', region_name)], orderings=[]) - FILTER(condition=True:bool, columns={'name': name, 'region_name': region_name}) - PROJECT(columns={'name': name, 'region_name': NULL_1}) - JOIN(conditions=[t0.region_key == t1.key], types=['anti'], columns={'NULL_1': None:unknown, 'name': t0.name}) - SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) - FILTER(condition=name != 'ASIA':string, columns={'key': key}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) -""", + "singular_anti", ), id="singular_anti", ), @@ -2695,17 +1998,7 @@ "SUM", [ChildReferenceExpressionInfo("retail_price", 0)] ), ), - """ -ROOT(columns=[('name', name), ('num_10parts', num_10parts), ('avg_price_of_10parts', avg_price_of_10parts), ('sum_price_of_10parts', sum_price_of_10parts)], orderings=[]) - PROJECT(columns={'avg_price_of_10parts': NULL_2, 'name': name, 'num_10parts': DEFAULT_TO(NULL_2, 0:int64), 'sum_price_of_10parts': DEFAULT_TO(NULL_2, 0:int64)}) - FILTER(condition=True:bool, columns={'NULL_2': NULL_2, 'name': name}) - JOIN(conditions=[t0.key == t1.supplier_key], types=['anti'], columns={'NULL_2': None:unknown, 'name': t0.name}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name}) - FILTER(condition=size == 10:int64, columns={'supplier_key': supplier_key}) - JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'size': t1.size, 'supplier_key': t0.supplier_key}) - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) - SCAN(table=tpch.PART, columns={'key': p_partkey, 'size': p_size}) -""", + "anti_aggregate", ), id="anti_aggregate", ), @@ -2749,17 +2042,7 @@ ], FunctionInfo("HASNOT", [ChildReferenceCollectionInfo(0)]), ), - """ -ROOT(columns=[('name', name), ('num_10parts', num_10parts), ('avg_price_of_10parts', avg_price_of_10parts), ('sum_price_of_10parts', sum_price_of_10parts)], orderings=[]) - FILTER(condition=True:bool, columns={'avg_price_of_10parts': 
avg_price_of_10parts, 'name': name, 'num_10parts': num_10parts, 'sum_price_of_10parts': sum_price_of_10parts}) - PROJECT(columns={'avg_price_of_10parts': NULL_2, 'name': name, 'num_10parts': DEFAULT_TO(NULL_2, 0:int64), 'sum_price_of_10parts': DEFAULT_TO(NULL_2, 0:int64)}) - JOIN(conditions=[t0.key == t1.supplier_key], types=['anti'], columns={'NULL_2': None:unknown, 'name': t0.name}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name}) - FILTER(condition=size == 10:int64, columns={'supplier_key': supplier_key}) - JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'size': t1.size, 'supplier_key': t0.supplier_key}) - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) - SCAN(table=tpch.PART, columns={'key': p_partkey, 'size': p_size}) -""", + "aggregate_anti", ), id="aggregate_anti", ), @@ -2818,32 +2101,7 @@ ), ) ** CalcInfo([], name=ReferenceInfo("name")), - """ -ROOT(columns=[('name', name)], orderings=[]) - FILTER(condition=True:bool & True:bool & True:bool, columns={'name': name}) - JOIN(conditions=[t0.key == t1.part_key], types=['semi'], columns={'name': t0.name}) - JOIN(conditions=[t0.key == t1.part_key], types=['anti'], columns={'key': t0.key, 'name': t0.name}) - JOIN(conditions=[t0.key == t1.part_key], types=['semi'], columns={'key': t0.key, 'name': t0.name}) - SCAN(table=tpch.PART, columns={'key': p_partkey, 'name': p_name}) - FILTER(condition=name_4 == 'GERMANY':string, columns={'part_key': part_key}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'name_4': t1.name, 'part_key': t0.part_key}) - JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'nation_key': t1.nation_key, 'part_key': t0.part_key}) - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) - 
FILTER(condition=name_8 == 'FRANCE':string, columns={'part_key': part_key}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'name_8': t1.name, 'part_key': t0.part_key}) - JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'nation_key': t1.nation_key, 'part_key': t0.part_key}) - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) - FILTER(condition=name_12 == 'ARGENTINA':string, columns={'part_key': part_key}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'name_12': t1.name, 'part_key': t0.part_key}) - JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'nation_key': t1.nation_key, 'part_key': t0.part_key}) - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) -""", + "multiple_has_hasnot", ), id="multiple_has_hasnot", ), @@ -2858,11 +2116,7 @@ (ReferenceInfo("acctbal"), False, True), ), ), - """ -ROOT(columns=[('name', name), ('cust_rank', cust_rank)], orderings=[]) - PROJECT(columns={'cust_rank': RANKING(args=[], partition=[], order=[(acctbal):desc_first]), 'name': name}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'name': c_name}) -""", + "rank_customers", ), id="rank_customers", ), @@ -2881,13 +2135,7 @@ allow_ties=True, ), ), - """ -ROOT(columns=[('nation_name', nation_name), ('name', name), ('cust_rank', cust_rank)], orderings=[]) - PROJECT(columns={'cust_rank': RANKING(args=[], partition=[key], order=[(acctbal):desc_first], allow_ties=True), 'name': name_3, 'nation_name': name}) - JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'acctbal': t1.acctbal, 'key': 
t0.key, 'name': t0.name, 'name_3': t1.name}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'name': c_name, 'nation_key': c_nationkey}) -""", + "rank_customers_per_nation", ), id="rank_customers_per_nation", ), @@ -2908,15 +2156,7 @@ dense=True, ), ), - """ -ROOT(columns=[('nation_name', nation_name), ('name', name), ('cust_rank', cust_rank)], orderings=[]) - PROJECT(columns={'cust_rank': RANKING(args=[], partition=[key], order=[(acctbal):desc_first], allow_ties=True, dense=True), 'name': name_6, 'nation_name': name_3}) - JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'acctbal': t1.acctbal, 'key': t0.key, 'name_3': t0.name_3, 'name_6': t1.name}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name_3': t1.name}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'name': c_name, 'nation_key': c_nationkey}) -""", + "rank_customers_per_region", ), id="rank_customers_per_region", ), @@ -2940,15 +2180,7 @@ "MAX", [ChildReferenceExpressionInfo("cust_rank", 0)] ), ), - """ -ROOT(columns=[('nation_name', nation_name), ('highest_rank', highest_rank)], orderings=[]) - PROJECT(columns={'highest_rank': agg_0, 'nation_name': name}) - JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': MAX(cust_rank)}) - PROJECT(columns={'cust_rank': RANKING(args=[], partition=[], order=[(acctbal):desc_first], allow_ties=True), 'nation_key': nation_key}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': c_nationkey}) -""", + "agg_max_ranking", ), 
id="agg_max_ranking", ), @@ -2967,17 +2199,26 @@ def test_ast_to_relational( relational_test_data: tuple[CollectionTestInfo, str], tpch_node_builder: AstNodeBuilder, default_config: PyDoughConfigs, + get_plan_test_filename: Callable[[str], str], + update_plan_tests: bool, ) -> None: """ Tests whether the QDAG nodes are correctly translated into Relational nodes with the expected string representation. """ - calc_pipeline, expected_relational_string = relational_test_data + calc_pipeline, file_name = relational_test_data + file_path: str = get_plan_test_filename(file_name) collection: PyDoughCollectionQDAG = calc_pipeline.build(tpch_node_builder) relational = convert_ast_to_relational(collection, default_config) - assert ( - relational.to_tree_string() == expected_relational_string.strip() - ), "Mismatch between full string representation of output Relational node versus expected string" + if update_plan_tests: + with open(file_path, "w") as f: + f.write(relational.to_tree_string() + "\n") + else: + with open(file_path) as f: + expected_relational_string: str = f.read() + assert ( + relational.to_tree_string() == expected_relational_string.strip() + ), "Mismatch between full string representation of output Relational node versus expected string" @pytest.fixture( @@ -3005,14 +2246,7 @@ def test_ast_to_relational( ), num_cust=FunctionInfo("COUNT", [ChildReferenceCollectionInfo(0)]), ), - """ -ROOT(columns=[('nation_name', nation_name), ('total_bal', total_bal), ('num_bal', num_bal), ('avg_bal', avg_bal), ('min_bal', min_bal), ('max_bal', max_bal), ('num_cust', num_cust)], orderings=[]) - PROJECT(columns={'avg_bal': DEFAULT_TO(agg_0, 0:int64), 'max_bal': agg_1, 'min_bal': agg_2, 'nation_name': name, 'num_bal': DEFAULT_TO(agg_3, 0:int64), 'num_cust': DEFAULT_TO(agg_4, 0:int64), 'total_bal': agg_5}) - JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'agg_1': t1.agg_1, 'agg_2': t1.agg_2, 'agg_3': t1.agg_3, 'agg_4': t1.agg_4, 'agg_5': 
t1.agg_5, 'name': t0.name}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': AVG(acctbal), 'agg_1': MAX(acctbal), 'agg_2': MIN(acctbal), 'agg_3': COUNT(acctbal), 'agg_4': COUNT(), 'agg_5': SUM(acctbal)}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': c_nationkey}) -""", + "various_aggfuncs_simple", ), id="various_aggfuncs_simple", ), @@ -3037,12 +2271,7 @@ def test_ast_to_relational( ), num_cust=FunctionInfo("COUNT", [ChildReferenceCollectionInfo(0)]), ), - """ -ROOT(columns=[('total_bal', total_bal), ('num_bal', num_bal), ('avg_bal', avg_bal), ('min_bal', min_bal), ('max_bal', max_bal), ('num_cust', num_cust)], orderings=[]) - PROJECT(columns={'avg_bal': DEFAULT_TO(agg_0, 0:int64), 'max_bal': agg_1, 'min_bal': agg_2, 'num_bal': agg_3, 'num_cust': agg_4, 'total_bal': agg_5}) - AGGREGATE(keys={}, aggregations={'agg_0': AVG(acctbal), 'agg_1': MAX(acctbal), 'agg_2': MIN(acctbal), 'agg_3': COUNT(acctbal), 'agg_4': COUNT(), 'agg_5': SUM(acctbal)}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal}) -""", + "various_aggfuncs_global", ), id="various_aggfuncs_global", ), @@ -3086,19 +2315,9 @@ def test_ast_to_relational( "SUM", [ChildReferenceExpressionInfo("retail_price", 0)] ), ), - """ -ROOT(columns=[('name', name), ('num_10parts', num_10parts), ('avg_price_of_10parts', avg_price_of_10parts), ('sum_price_of_10parts', sum_price_of_10parts)], orderings=[]) - PROJECT(columns={'avg_price_of_10parts': DEFAULT_TO(NULL_2, 0:int64), 'name': name, 'num_10parts': DEFAULT_TO(NULL_2, 0:int64), 'sum_price_of_10parts': NULL_2}) - FILTER(condition=True:bool, columns={'NULL_2': NULL_2, 'name': name}) - JOIN(conditions=[t0.key == t1.supplier_key], types=['anti'], columns={'NULL_2': None:unknown, 'name': t0.name}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name}) - FILTER(condition=size == 10:int64, columns={'supplier_key': supplier_key}) - 
JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'size': t1.size, 'supplier_key': t0.supplier_key}) - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) - SCAN(table=tpch.PART, columns={'key': p_partkey, 'size': p_size}) -""", + "anti_aggregate_alternate", ), - id="anti_aggregate", + id="anti_aggregate_alternate", ), ], ) @@ -3115,6 +2334,8 @@ def test_ast_to_relational_alternative_aggregation_configs( relational_alternative_config_test_data: tuple[CollectionTestInfo, str], tpch_node_builder: AstNodeBuilder, default_config: PyDoughConfigs, + get_plan_test_filename: Callable[[str], str], + update_plan_tests: bool, ) -> None: """ Same as `test_ast_to_relational` but with various alternative aggregation @@ -3122,11 +2343,18 @@ def test_ast_to_relational_alternative_aggregation_configs( - `SUM` defaulting to zero is disabled. - `COUNT` defaulting to zero is disabled. """ - calc_pipeline, expected_relational_string = relational_alternative_config_test_data + calc_pipeline, file_name = relational_alternative_config_test_data + file_path: str = get_plan_test_filename(file_name) default_config.sum_default_zero = False default_config.avg_default_zero = True collection: PyDoughCollectionQDAG = calc_pipeline.build(tpch_node_builder) relational = convert_ast_to_relational(collection, default_config) - assert ( - relational.to_tree_string() == expected_relational_string.strip() - ), "Mismatch between full string representation of output Relational node versus expected string" + if update_plan_tests: + with open(file_path, "w") as f: + f.write(relational.to_tree_string() + "\n") + else: + with open(file_path) as f: + expected_relational_string: str = f.read() + assert ( + relational.to_tree_string() == expected_relational_string.strip() + ), "Mismatch between full string representation of output Relational node versus expected string" From ceff1058895ed92daf040aa5078db8db9906c38c Mon Sep 17 00:00:00 2001 From: knassre-bodo 
Date: Mon, 3 Feb 2025 14:18:51 -0500 Subject: [PATCH 082/112] Converting pipeline files to new format --- tests/README.md | 6 + tests/test_pipeline.py | 531 ++---------------- tests/test_plan_refsols/agg_partition.txt | 7 + tests/test_plan_refsols/double_partition.txt | 7 + tests/test_plan_refsols/function_sampler.txt | 10 + .../percentile_customers_per_region.txt | 9 + .../test_plan_refsols/percentile_nations.txt | 3 + .../rank_nations_by_region.txt | 5 + .../rank_nations_per_region_by_customers.txt | 10 + ...rank_parts_per_supplier_region_by_size.txt | 13 + .../test_plan_refsols/rank_with_filters_a.txt | 5 + .../test_plan_refsols/rank_with_filters_b.txt | 5 + .../test_plan_refsols/rank_with_filters_c.txt | 9 + .../regional_suppliers_percentile.txt | 11 + .../simple_filter_top_five.txt | 5 + .../simple_scan_top_five.txt | 4 + tests/test_plan_refsols/tpch_q1.txt | 6 + tests/test_plan_refsols/tpch_q10.txt | 15 + tests/test_plan_refsols/tpch_q11.txt | 23 + tests/test_plan_refsols/tpch_q12.txt | 9 + tests/test_plan_refsols/tpch_q13.txt | 11 + tests/test_plan_refsols/tpch_q14.txt | 8 + tests/test_plan_refsols/tpch_q15.txt | 18 + tests/test_plan_refsols/tpch_q16.txt | 13 + tests/test_plan_refsols/tpch_q17.txt | 12 + tests/test_plan_refsols/tpch_q18.txt | 11 + tests/test_plan_refsols/tpch_q19.txt | 7 + tests/test_plan_refsols/tpch_q2.txt | 31 + tests/test_plan_refsols/tpch_q20.txt | 18 + tests/test_plan_refsols/tpch_q3.txt | 12 + tests/test_plan_refsols/tpch_q4.txt | 9 + tests/test_plan_refsols/tpch_q6.txt | 6 + tests/test_plan_refsols/tpch_q7.txt | 17 + tests/test_plan_refsols/tpch_q8.txt | 24 + tests/test_plan_refsols/tpch_q9.txt | 18 + tests/test_plan_refsols/triple_partition.txt | 30 + tests/test_qdag_conversion.py | 9 +- 37 files changed, 464 insertions(+), 483 deletions(-) create mode 100644 tests/test_plan_refsols/agg_partition.txt create mode 100644 tests/test_plan_refsols/double_partition.txt create mode 100644 tests/test_plan_refsols/function_sampler.txt create 
mode 100644 tests/test_plan_refsols/percentile_customers_per_region.txt create mode 100644 tests/test_plan_refsols/percentile_nations.txt create mode 100644 tests/test_plan_refsols/rank_nations_by_region.txt create mode 100644 tests/test_plan_refsols/rank_nations_per_region_by_customers.txt create mode 100644 tests/test_plan_refsols/rank_parts_per_supplier_region_by_size.txt create mode 100644 tests/test_plan_refsols/rank_with_filters_a.txt create mode 100644 tests/test_plan_refsols/rank_with_filters_b.txt create mode 100644 tests/test_plan_refsols/rank_with_filters_c.txt create mode 100644 tests/test_plan_refsols/regional_suppliers_percentile.txt create mode 100644 tests/test_plan_refsols/simple_filter_top_five.txt create mode 100644 tests/test_plan_refsols/simple_scan_top_five.txt create mode 100644 tests/test_plan_refsols/tpch_q1.txt create mode 100644 tests/test_plan_refsols/tpch_q10.txt create mode 100644 tests/test_plan_refsols/tpch_q11.txt create mode 100644 tests/test_plan_refsols/tpch_q12.txt create mode 100644 tests/test_plan_refsols/tpch_q13.txt create mode 100644 tests/test_plan_refsols/tpch_q14.txt create mode 100644 tests/test_plan_refsols/tpch_q15.txt create mode 100644 tests/test_plan_refsols/tpch_q16.txt create mode 100644 tests/test_plan_refsols/tpch_q17.txt create mode 100644 tests/test_plan_refsols/tpch_q18.txt create mode 100644 tests/test_plan_refsols/tpch_q19.txt create mode 100644 tests/test_plan_refsols/tpch_q2.txt create mode 100644 tests/test_plan_refsols/tpch_q20.txt create mode 100644 tests/test_plan_refsols/tpch_q3.txt create mode 100644 tests/test_plan_refsols/tpch_q4.txt create mode 100644 tests/test_plan_refsols/tpch_q6.txt create mode 100644 tests/test_plan_refsols/tpch_q7.txt create mode 100644 tests/test_plan_refsols/tpch_q8.txt create mode 100644 tests/test_plan_refsols/tpch_q9.txt create mode 100644 tests/test_plan_refsols/triple_partition.txt diff --git a/tests/README.md b/tests/README.md index f4803a4e..4fc8baff 100644 --- 
a/tests/README.md +++ b/tests/README.md @@ -2,6 +2,12 @@ This module contains the tests for PyDough, including unit tests, integration tests, and utility functions used in testing. +## Planner Tests + +Many of the tests in files such as `test_qdag_conversion.py` and `test_pipeline.py` test the relational plan string generated by a PyDough query. The refsols for these tests are stored in the `test_plan_refsols` directory as text files. When these tests are run normally, the refsol is extracted from the files. However, if the environment variable `PYDOUGH_UPDATE_TESTS` is set to 1, then instead of checking against these files, the generated output is manually written to them. + +Make sure to set `PYDOUGH_UPDATE_TESTS` to 1 when you need to update these tests, or need to create a new one, then immediately unset it once your tests have been updated so you don't forget and think your tests are passing when they are actually failing and overriding the correct answers with nonsense. + ## TestInfo Classes The `TestInfo` classes are used to specify information about a QDAG (Qualified Directed Acyclic Graph) node before it can be created. These classes help in building and testing QDAG nodes for unit tests. 
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 41605af3..430c7797 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -100,13 +100,7 @@ pytest.param( ( impl_tpch_q1, - """ -ROOT(columns=[('L_RETURNFLAG', L_RETURNFLAG), ('L_LINESTATUS', L_LINESTATUS), ('SUM_QTY', SUM_QTY), ('SUM_BASE_PRICE', SUM_BASE_PRICE), ('SUM_DISC_PRICE', SUM_DISC_PRICE), ('SUM_CHARGE', SUM_CHARGE), ('AVG_QTY', AVG_QTY), ('AVG_PRICE', AVG_PRICE), ('AVG_DISC', AVG_DISC), ('COUNT_ORDER', COUNT_ORDER)], orderings=[(ordering_8):asc_first, (ordering_9):asc_first]) - PROJECT(columns={'AVG_DISC': AVG_DISC, 'AVG_PRICE': AVG_PRICE, 'AVG_QTY': AVG_QTY, 'COUNT_ORDER': COUNT_ORDER, 'L_LINESTATUS': L_LINESTATUS, 'L_RETURNFLAG': L_RETURNFLAG, 'SUM_BASE_PRICE': SUM_BASE_PRICE, 'SUM_CHARGE': SUM_CHARGE, 'SUM_DISC_PRICE': SUM_DISC_PRICE, 'SUM_QTY': SUM_QTY, 'ordering_8': L_RETURNFLAG, 'ordering_9': L_LINESTATUS}) - PROJECT(columns={'AVG_DISC': agg_0, 'AVG_PRICE': agg_1, 'AVG_QTY': agg_2, 'COUNT_ORDER': DEFAULT_TO(agg_3, 0:int64), 'L_LINESTATUS': status, 'L_RETURNFLAG': return_flag, 'SUM_BASE_PRICE': DEFAULT_TO(agg_4, 0:int64), 'SUM_CHARGE': DEFAULT_TO(agg_5, 0:int64), 'SUM_DISC_PRICE': DEFAULT_TO(agg_6, 0:int64), 'SUM_QTY': DEFAULT_TO(agg_7, 0:int64)}) - AGGREGATE(keys={'return_flag': return_flag, 'status': status}, aggregations={'agg_0': AVG(discount), 'agg_1': AVG(extended_price), 'agg_2': AVG(quantity), 'agg_3': COUNT(), 'agg_4': SUM(extended_price), 'agg_5': SUM(extended_price * 1:int64 - discount * 1:int64 + tax), 'agg_6': SUM(extended_price * 1:int64 - discount), 'agg_7': SUM(quantity)}) - FILTER(condition=ship_date <= datetime.date(1998, 12, 1):date, columns={'discount': discount, 'extended_price': extended_price, 'quantity': quantity, 'return_flag': return_flag, 'status': status, 'tax': tax}) - SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'quantity': l_quantity, 'return_flag': l_returnflag, 'ship_date': l_shipdate, 
'status': l_linestatus, 'tax': l_tax})""", + "tpch_q1", tpch_q1_output, ), id="tpch_q1", @@ -114,38 +108,7 @@ pytest.param( ( impl_tpch_q2, - """ -ROOT(columns=[('S_ACCTBAL', S_ACCTBAL), ('S_NAME', S_NAME), ('N_NAME', N_NAME), ('P_PARTKEY', P_PARTKEY), ('P_MFGR', P_MFGR), ('S_ADDRESS', S_ADDRESS), ('S_PHONE', S_PHONE), ('S_COMMENT', S_COMMENT)], orderings=[(ordering_1):desc_last, (ordering_2):asc_first, (ordering_3):asc_first, (ordering_4):asc_first]) - LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'N_NAME': N_NAME, 'P_MFGR': P_MFGR, 'P_PARTKEY': P_PARTKEY, 'S_ACCTBAL': S_ACCTBAL, 'S_ADDRESS': S_ADDRESS, 'S_COMMENT': S_COMMENT, 'S_NAME': S_NAME, 'S_PHONE': S_PHONE, 'ordering_1': ordering_1, 'ordering_2': ordering_2, 'ordering_3': ordering_3, 'ordering_4': ordering_4}, orderings=[(ordering_1):desc_last, (ordering_2):asc_first, (ordering_3):asc_first, (ordering_4):asc_first]) - PROJECT(columns={'N_NAME': N_NAME, 'P_MFGR': P_MFGR, 'P_PARTKEY': P_PARTKEY, 'S_ACCTBAL': S_ACCTBAL, 'S_ADDRESS': S_ADDRESS, 'S_COMMENT': S_COMMENT, 'S_NAME': S_NAME, 'S_PHONE': S_PHONE, 'ordering_1': S_ACCTBAL, 'ordering_2': N_NAME, 'ordering_3': S_NAME, 'ordering_4': P_PARTKEY}) - PROJECT(columns={'N_NAME': n_name, 'P_MFGR': manufacturer, 'P_PARTKEY': key_19, 'S_ACCTBAL': s_acctbal, 'S_ADDRESS': s_address, 'S_COMMENT': s_comment, 'S_NAME': s_name, 'S_PHONE': s_phone}) - FILTER(condition=supplycost_21 == best_cost & ENDSWITH(part_type, 'BRASS':string) & size == 15:int64, columns={'key_19': key_19, 'manufacturer': manufacturer, 'n_name': n_name, 's_acctbal': s_acctbal, 's_address': s_address, 's_comment': s_comment, 's_name': s_name, 's_phone': s_phone}) - JOIN(conditions=[t0.key_9 == t1.key_19], types=['inner'], columns={'best_cost': t0.best_cost, 'key_19': t1.key_19, 'manufacturer': t1.manufacturer, 'n_name': t1.n_name, 'part_type': t1.part_type, 's_acctbal': t1.s_acctbal, 's_address': t1.s_address, 's_comment': t1.s_comment, 's_name': t1.s_name, 's_phone': t1.s_phone, 'size': 
t1.size, 'supplycost_21': t1.supplycost}) - PROJECT(columns={'best_cost': agg_0, 'key_9': key_9}) - AGGREGATE(keys={'key_9': key_9}, aggregations={'agg_0': MIN(supplycost)}) - FILTER(condition=ENDSWITH(part_type, 'BRASS':string) & size == 15:int64, columns={'key_9': key_9, 'supplycost': supplycost}) - JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'key_9': t1.key, 'part_type': t1.part_type, 'size': t1.size, 'supplycost': t0.supplycost}) - JOIN(conditions=[t0.key_5 == t1.supplier_key], types=['inner'], columns={'part_key': t1.part_key, 'supplycost': t1.supplycost}) - JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'key_5': t1.key}) - FILTER(condition=name_3 == 'EUROPE':string, columns={'key': key}) - JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'key': t0.key, 'name_3': t1.name}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) - SCAN(table=tpch.PART, columns={'key': p_partkey, 'part_type': p_type, 'size': p_size}) - FILTER(condition=ENDSWITH(part_type, 'BRASS':string) & size == 15:int64, columns={'key_19': key_19, 'manufacturer': manufacturer, 'n_name': n_name, 'part_type': part_type, 's_acctbal': s_acctbal, 's_address': s_address, 's_comment': s_comment, 's_name': s_name, 's_phone': s_phone, 'size': size, 'supplycost': supplycost}) - PROJECT(columns={'key_19': key_19, 'manufacturer': manufacturer, 'n_name': name, 'part_type': part_type, 's_acctbal': account_balance, 's_address': address, 's_comment': comment_14, 's_name': name_16, 's_phone': phone, 'size': size, 'supplycost': supplycost}) - JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'account_balance': t0.account_balance, 
'address': t0.address, 'comment_14': t0.comment_14, 'key_19': t1.key, 'manufacturer': t1.manufacturer, 'name': t0.name, 'name_16': t0.name_16, 'part_type': t1.part_type, 'phone': t0.phone, 'size': t1.size, 'supplycost': t0.supplycost}) - JOIN(conditions=[t0.key_15 == t1.supplier_key], types=['inner'], columns={'account_balance': t0.account_balance, 'address': t0.address, 'comment_14': t0.comment_14, 'name': t0.name, 'name_16': t0.name_16, 'part_key': t1.part_key, 'phone': t0.phone, 'supplycost': t1.supplycost}) - JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'account_balance': t1.account_balance, 'address': t1.address, 'comment_14': t1.comment, 'key_15': t1.key, 'name': t0.name, 'name_16': t1.name, 'phone': t1.phone}) - FILTER(condition=name_13 == 'EUROPE':string, columns={'key': key, 'name': name}) - JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'key': t0.key, 'name': t0.name, 'name_13': t1.name}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'address': s_address, 'comment': s_comment, 'key': s_suppkey, 'name': s_name, 'nation_key': s_nationkey, 'phone': s_phone}) - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) - SCAN(table=tpch.PART, columns={'key': p_partkey, 'manufacturer': p_mfgr, 'part_type': p_type, 'size': p_size})""", + "tpch_q2", tpch_q2_output, ), id="tpch_q2", @@ -153,20 +116,7 @@ pytest.param( ( impl_tpch_q3, - """ -ROOT(columns=[('L_ORDERKEY', L_ORDERKEY), ('REVENUE', REVENUE), ('O_ORDERDATE', O_ORDERDATE), ('O_SHIPPRIORITY', O_SHIPPRIORITY)], orderings=[(ordering_1):desc_last, (ordering_2):asc_first, (ordering_3):asc_first]) - LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'L_ORDERKEY': L_ORDERKEY, 'O_ORDERDATE': O_ORDERDATE, 
'O_SHIPPRIORITY': O_SHIPPRIORITY, 'REVENUE': REVENUE, 'ordering_1': ordering_1, 'ordering_2': ordering_2, 'ordering_3': ordering_3}, orderings=[(ordering_1):desc_last, (ordering_2):asc_first, (ordering_3):asc_first]) - PROJECT(columns={'L_ORDERKEY': L_ORDERKEY, 'O_ORDERDATE': O_ORDERDATE, 'O_SHIPPRIORITY': O_SHIPPRIORITY, 'REVENUE': REVENUE, 'ordering_1': REVENUE, 'ordering_2': O_ORDERDATE, 'ordering_3': L_ORDERKEY}) - PROJECT(columns={'L_ORDERKEY': order_key, 'O_ORDERDATE': order_date, 'O_SHIPPRIORITY': ship_priority, 'REVENUE': DEFAULT_TO(agg_0, 0:int64)}) - AGGREGATE(keys={'order_date': order_date, 'order_key': order_key, 'ship_priority': ship_priority}, aggregations={'agg_0': SUM(extended_price * 1:int64 - discount)}) - FILTER(condition=ship_date > datetime.date(1995, 3, 15):date, columns={'discount': discount, 'extended_price': extended_price, 'order_date': order_date, 'order_key': order_key, 'ship_priority': ship_priority}) - JOIN(conditions=[t0.key == t1.order_key], types=['inner'], columns={'discount': t1.discount, 'extended_price': t1.extended_price, 'order_date': t0.order_date, 'order_key': t1.order_key, 'ship_date': t1.ship_date, 'ship_priority': t0.ship_priority}) - FILTER(condition=mktsegment == 'BUILDING':string & order_date < datetime.date(1995, 3, 15):date, columns={'key': key, 'order_date': order_date, 'ship_priority': ship_priority}) - JOIN(conditions=[t0.customer_key == t1.key], types=['left'], columns={'key': t0.key, 'mktsegment': t1.mktsegment, 'order_date': t0.order_date, 'ship_priority': t0.ship_priority}) - SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate, 'ship_priority': o_shippriority}) - SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'mktsegment': c_mktsegment}) - SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'order_key': l_orderkey, 'ship_date': l_shipdate}) -""", + "tpch_q3", tpch_q3_output, ), id="tpch_q3", @@ -174,17 +124,7 
@@ pytest.param( ( impl_tpch_q4, - """ -ROOT(columns=[('O_ORDERPRIORITY', O_ORDERPRIORITY), ('ORDER_COUNT', ORDER_COUNT)], orderings=[(ordering_1):asc_first]) - PROJECT(columns={'ORDER_COUNT': ORDER_COUNT, 'O_ORDERPRIORITY': O_ORDERPRIORITY, 'ordering_1': O_ORDERPRIORITY}) - PROJECT(columns={'ORDER_COUNT': DEFAULT_TO(agg_0, 0:int64), 'O_ORDERPRIORITY': order_priority}) - AGGREGATE(keys={'order_priority': order_priority}, aggregations={'agg_0': COUNT()}) - FILTER(condition=order_date >= datetime.date(1993, 7, 1):date & order_date < datetime.date(1993, 10, 1):date & True:bool, columns={'order_priority': order_priority}) - JOIN(conditions=[t0.key == t1.order_key], types=['semi'], columns={'order_date': t0.order_date, 'order_priority': t0.order_priority}) - SCAN(table=tpch.ORDERS, columns={'key': o_orderkey, 'order_date': o_orderdate, 'order_priority': o_orderpriority}) - FILTER(condition=commit_date < receipt_date, columns={'order_key': order_key}) - SCAN(table=tpch.LINEITEM, columns={'commit_date': l_commitdate, 'order_key': l_orderkey, 'receipt_date': l_receiptdate}) -""", + "tpch_q4", tpch_q4_output, ), id="tpch_q4", @@ -192,8 +132,7 @@ pytest.param( ( impl_tpch_q5, - """ -""", + "tpch_q5", tpch_q5_output, ), id="tpch_q5", @@ -202,14 +141,7 @@ pytest.param( ( impl_tpch_q6, - """ -ROOT(columns=[('REVENUE', REVENUE)], orderings=[]) - PROJECT(columns={'REVENUE': DEFAULT_TO(agg_0, 0:int64)}) - AGGREGATE(keys={}, aggregations={'agg_0': SUM(amt)}) - PROJECT(columns={'amt': extended_price * discount}) - FILTER(condition=ship_date >= datetime.date(1994, 1, 1):date & ship_date < datetime.date(1995, 1, 1):date & discount >= 0.05:float64 & discount <= 0.07:float64 & quantity < 24:int64, columns={'discount': discount, 'extended_price': extended_price}) - SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'quantity': l_quantity, 'ship_date': l_shipdate}) -""", + "tpch_q6", tpch_q6_output, ), id="tpch_q6", @@ -217,25 +149,7 @@ 
pytest.param( ( impl_tpch_q7, - """ -ROOT(columns=[('SUPP_NATION', SUPP_NATION), ('CUST_NATION', CUST_NATION), ('L_YEAR', L_YEAR), ('REVENUE', REVENUE)], orderings=[(ordering_1):asc_first, (ordering_2):asc_first, (ordering_3):asc_first]) - PROJECT(columns={'CUST_NATION': CUST_NATION, 'L_YEAR': L_YEAR, 'REVENUE': REVENUE, 'SUPP_NATION': SUPP_NATION, 'ordering_1': SUPP_NATION, 'ordering_2': CUST_NATION, 'ordering_3': L_YEAR}) - PROJECT(columns={'CUST_NATION': cust_nation, 'L_YEAR': l_year, 'REVENUE': DEFAULT_TO(agg_0, 0:int64), 'SUPP_NATION': supp_nation}) - AGGREGATE(keys={'cust_nation': cust_nation, 'l_year': l_year, 'supp_nation': supp_nation}, aggregations={'agg_0': SUM(volume)}) - FILTER(condition=ship_date >= datetime.date(1995, 1, 1):date & ship_date <= datetime.date(1996, 12, 31):date & supp_nation == 'FRANCE':string & cust_nation == 'GERMANY':string | supp_nation == 'GERMANY':string & cust_nation == 'FRANCE':string, columns={'cust_nation': cust_nation, 'l_year': l_year, 'supp_nation': supp_nation, 'volume': volume}) - PROJECT(columns={'cust_nation': name_8, 'l_year': YEAR(ship_date), 'ship_date': ship_date, 'supp_nation': name_3, 'volume': extended_price * 1:int64 - discount}) - JOIN(conditions=[t0.order_key == t1.key], types=['left'], columns={'discount': t0.discount, 'extended_price': t0.extended_price, 'name_3': t0.name_3, 'name_8': t1.name_8, 'ship_date': t0.ship_date}) - JOIN(conditions=[t0.supplier_key == t1.key], types=['left'], columns={'discount': t0.discount, 'extended_price': t0.extended_price, 'name_3': t1.name_3, 'order_key': t0.order_key, 'ship_date': t0.ship_date}) - SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'order_key': l_orderkey, 'ship_date': l_shipdate, 'supplier_key': l_suppkey}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_3': t1.name}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) - 
SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_8': t1.name}) - JOIN(conditions=[t0.customer_key == t1.key], types=['inner'], columns={'key': t0.key, 'nation_key': t1.nation_key}) - SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey}) - SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) -""", + "tpch_q7", tpch_q7_output, ), id="tpch_q7", @@ -243,32 +157,7 @@ pytest.param( ( impl_tpch_q8, - """ -ROOT(columns=[('O_YEAR', O_YEAR), ('MKT_SHARE', MKT_SHARE)], orderings=[]) - PROJECT(columns={'MKT_SHARE': DEFAULT_TO(agg_0, 0:int64) / DEFAULT_TO(agg_1, 0:int64), 'O_YEAR': o_year}) - AGGREGATE(keys={'o_year': o_year}, aggregations={'agg_0': SUM(brazil_volume), 'agg_1': SUM(volume)}) - FILTER(condition=order_date >= datetime.date(1995, 1, 1):date & order_date <= datetime.date(1996, 12, 31):date & name_18 == 'AMERICA':string, columns={'brazil_volume': brazil_volume, 'o_year': o_year, 'volume': volume}) - JOIN(conditions=[t0.customer_key == t1.key], types=['left'], columns={'brazil_volume': t0.brazil_volume, 'name_18': t1.name_18, 'o_year': t0.o_year, 'order_date': t0.order_date, 'volume': t0.volume}) - PROJECT(columns={'brazil_volume': IFF(name == 'BRAZIL':string, volume, 0:int64), 'customer_key': customer_key, 'o_year': YEAR(order_date), 'order_date': order_date, 'volume': volume}) - JOIN(conditions=[t0.order_key == t1.key], types=['inner'], columns={'customer_key': t1.customer_key, 'name': t0.name, 'order_date': t1.order_date, 'volume': t0.volume}) - PROJECT(columns={'name': name, 'order_key': order_key, 'volume': extended_price * 1:int64 - discount}) - JOIN(conditions=[t0.part_key == t1.part_key & t0.supplier_key == t1.supplier_key], types=['inner'], columns={'discount': t1.discount, 'extended_price': t1.extended_price, 'name': 
t0.name, 'order_key': t1.order_key}) - FILTER(condition=part_type == 'ECONOMY ANODIZED STEEL':string, columns={'name': name, 'part_key': part_key, 'supplier_key': supplier_key}) - JOIN(conditions=[t0.part_key == t1.key], types=['left'], columns={'name': t0.name, 'part_key': t0.part_key, 'part_type': t1.part_type, 'supplier_key': t0.supplier_key}) - JOIN(conditions=[t0.key_2 == t1.supplier_key], types=['inner'], columns={'name': t0.name, 'part_key': t1.part_key, 'supplier_key': t1.supplier_key}) - JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) - SCAN(table=tpch.PART, columns={'key': p_partkey, 'part_type': p_type}) - SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'order_key': l_orderkey, 'part_key': l_partkey, 'supplier_key': l_suppkey}) - SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate}) - JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_18': t1.name}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'region_key': t1.region_key}) - SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) -""", + "tpch_q8", tpch_q8_output, ), id="tpch_q8", @@ -276,26 +165,7 @@ pytest.param( ( impl_tpch_q9, - """ -ROOT(columns=[('NATION', NATION), ('O_YEAR', O_YEAR), ('AMOUNT', AMOUNT)], orderings=[(ordering_1):asc_first, (ordering_2):desc_last]) - LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'AMOUNT': AMOUNT, 'NATION': 
NATION, 'O_YEAR': O_YEAR, 'ordering_1': ordering_1, 'ordering_2': ordering_2}, orderings=[(ordering_1):asc_first, (ordering_2):desc_last]) - PROJECT(columns={'AMOUNT': AMOUNT, 'NATION': NATION, 'O_YEAR': O_YEAR, 'ordering_1': NATION, 'ordering_2': O_YEAR}) - PROJECT(columns={'AMOUNT': DEFAULT_TO(agg_0, 0:int64), 'NATION': nation, 'O_YEAR': o_year}) - AGGREGATE(keys={'nation': nation, 'o_year': o_year}, aggregations={'agg_0': SUM(value)}) - PROJECT(columns={'nation': name, 'o_year': YEAR(order_date), 'value': extended_price * 1:int64 - discount - supplycost * quantity}) - JOIN(conditions=[t0.order_key == t1.key], types=['left'], columns={'discount': t0.discount, 'extended_price': t0.extended_price, 'name': t0.name, 'order_date': t1.order_date, 'quantity': t0.quantity, 'supplycost': t0.supplycost}) - JOIN(conditions=[t0.part_key == t1.part_key & t0.supplier_key == t1.supplier_key], types=['inner'], columns={'discount': t1.discount, 'extended_price': t1.extended_price, 'name': t0.name, 'order_key': t1.order_key, 'quantity': t1.quantity, 'supplycost': t0.supplycost}) - FILTER(condition=CONTAINS(name_7, 'green':string), columns={'name': name, 'part_key': part_key, 'supplier_key': supplier_key, 'supplycost': supplycost}) - JOIN(conditions=[t0.part_key == t1.key], types=['left'], columns={'name': t0.name, 'name_7': t1.name, 'part_key': t0.part_key, 'supplier_key': t0.supplier_key, 'supplycost': t0.supplycost}) - JOIN(conditions=[t0.key_2 == t1.supplier_key], types=['inner'], columns={'name': t0.name, 'part_key': t1.part_key, 'supplier_key': t1.supplier_key, 'supplycost': t1.supplycost}) - JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) - 
SCAN(table=tpch.PART, columns={'key': p_partkey, 'name': p_name}) - SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'order_key': l_orderkey, 'part_key': l_partkey, 'quantity': l_quantity, 'supplier_key': l_suppkey}) - SCAN(table=tpch.ORDERS, columns={'key': o_orderkey, 'order_date': o_orderdate}) -""", + "tpch_q9", tpch_q9_output, ), id="tpch_q9", @@ -303,23 +173,7 @@ pytest.param( ( impl_tpch_q10, - """ -ROOT(columns=[('C_CUSTKEY', C_CUSTKEY), ('C_NAME', C_NAME), ('REVENUE', REVENUE), ('C_ACCTBAL', C_ACCTBAL), ('N_NAME', N_NAME), ('C_ADDRESS', C_ADDRESS), ('C_PHONE', C_PHONE), ('C_COMMENT', C_COMMENT)], orderings=[(ordering_1):desc_last, (ordering_2):asc_first]) - LIMIT(limit=Literal(value=20, type=Int64Type()), columns={'C_ACCTBAL': C_ACCTBAL, 'C_ADDRESS': C_ADDRESS, 'C_COMMENT': C_COMMENT, 'C_CUSTKEY': C_CUSTKEY, 'C_NAME': C_NAME, 'C_PHONE': C_PHONE, 'N_NAME': N_NAME, 'REVENUE': REVENUE, 'ordering_1': ordering_1, 'ordering_2': ordering_2}, orderings=[(ordering_1):desc_last, (ordering_2):asc_first]) - PROJECT(columns={'C_ACCTBAL': C_ACCTBAL, 'C_ADDRESS': C_ADDRESS, 'C_COMMENT': C_COMMENT, 'C_CUSTKEY': C_CUSTKEY, 'C_NAME': C_NAME, 'C_PHONE': C_PHONE, 'N_NAME': N_NAME, 'REVENUE': REVENUE, 'ordering_1': REVENUE, 'ordering_2': C_CUSTKEY}) - PROJECT(columns={'C_ACCTBAL': acctbal, 'C_ADDRESS': address, 'C_COMMENT': comment, 'C_CUSTKEY': key, 'C_NAME': name, 'C_PHONE': phone, 'N_NAME': name_4, 'REVENUE': DEFAULT_TO(agg_0, 0:int64)}) - JOIN(conditions=[t0.nation_key == t1.key], types=['left'], columns={'acctbal': t0.acctbal, 'address': t0.address, 'agg_0': t0.agg_0, 'comment': t0.comment, 'key': t0.key, 'name': t0.name, 'name_4': t1.name, 'phone': t0.phone}) - JOIN(conditions=[t0.key == t1.customer_key], types=['left'], columns={'acctbal': t0.acctbal, 'address': t0.address, 'agg_0': t1.agg_0, 'comment': t0.comment, 'key': t0.key, 'name': t0.name, 'nation_key': t0.nation_key, 'phone': t0.phone}) - SCAN(table=tpch.CUSTOMER, 
columns={'acctbal': c_acctbal, 'address': c_address, 'comment': c_comment, 'key': c_custkey, 'name': c_name, 'nation_key': c_nationkey, 'phone': c_phone}) - AGGREGATE(keys={'customer_key': customer_key}, aggregations={'agg_0': SUM(amt)}) - PROJECT(columns={'amt': extended_price * 1:int64 - discount, 'customer_key': customer_key}) - FILTER(condition=return_flag == 'R':string, columns={'customer_key': customer_key, 'discount': discount, 'extended_price': extended_price}) - JOIN(conditions=[t0.key == t1.order_key], types=['inner'], columns={'customer_key': t0.customer_key, 'discount': t1.discount, 'extended_price': t1.extended_price, 'return_flag': t1.return_flag}) - FILTER(condition=order_date >= datetime.date(1993, 10, 1):date & order_date < datetime.date(1994, 1, 1):date, columns={'customer_key': customer_key, 'key': key}) - SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate}) - SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'order_key': l_orderkey, 'return_flag': l_returnflag}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) -""", + "tpch_q10", tpch_q10_output, ), id="tpch_q10", @@ -327,31 +181,7 @@ pytest.param( ( impl_tpch_q11, - """ -ROOT(columns=[('PS_PARTKEY', PS_PARTKEY), ('VALUE', VALUE)], orderings=[(ordering_2):desc_last]) - LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'PS_PARTKEY': PS_PARTKEY, 'VALUE': VALUE, 'ordering_2': ordering_2}, orderings=[(ordering_2):desc_last]) - PROJECT(columns={'PS_PARTKEY': PS_PARTKEY, 'VALUE': VALUE, 'ordering_2': VALUE}) - FILTER(condition=VALUE > min_market_share, columns={'PS_PARTKEY': PS_PARTKEY, 'VALUE': VALUE}) - PROJECT(columns={'PS_PARTKEY': part_key, 'VALUE': DEFAULT_TO(agg_1, 0:int64), 'min_market_share': min_market_share}) - JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1, 'min_market_share': t0.min_market_share, 'part_key': t1.part_key}) - 
PROJECT(columns={'min_market_share': DEFAULT_TO(agg_0, 0:int64) * 0.0001:float64}) - AGGREGATE(keys={}, aggregations={'agg_0': SUM(metric)}) - PROJECT(columns={'metric': supplycost * availqty}) - FILTER(condition=name_3 == 'GERMANY':string, columns={'availqty': availqty, 'supplycost': supplycost}) - JOIN(conditions=[t0.supplier_key == t1.key], types=['left'], columns={'availqty': t0.availqty, 'name_3': t1.name_3, 'supplycost': t0.supplycost}) - SCAN(table=tpch.PARTSUPP, columns={'availqty': ps_availqty, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_3': t1.name}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) - AGGREGATE(keys={'part_key': part_key}, aggregations={'agg_1': SUM(metric)}) - PROJECT(columns={'metric': supplycost * availqty, 'part_key': part_key}) - FILTER(condition=name_6 == 'GERMANY':string, columns={'availqty': availqty, 'part_key': part_key, 'supplycost': supplycost}) - JOIN(conditions=[t0.supplier_key == t1.key], types=['left'], columns={'availqty': t0.availqty, 'name_6': t1.name_6, 'part_key': t0.part_key, 'supplycost': t0.supplycost}) - SCAN(table=tpch.PARTSUPP, columns={'availqty': ps_availqty, 'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_6': t1.name}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) -""", + "tpch_q11", tpch_q11_output, ), id="tpch_q11", @@ -359,17 +189,7 @@ pytest.param( ( impl_tpch_q12, - """ -ROOT(columns=[('L_SHIPMODE', L_SHIPMODE), ('HIGH_LINE_COUNT', HIGH_LINE_COUNT), ('LOW_LINE_COUNT', LOW_LINE_COUNT)], orderings=[(ordering_2):asc_first]) - PROJECT(columns={'HIGH_LINE_COUNT': 
HIGH_LINE_COUNT, 'LOW_LINE_COUNT': LOW_LINE_COUNT, 'L_SHIPMODE': L_SHIPMODE, 'ordering_2': L_SHIPMODE}) - PROJECT(columns={'HIGH_LINE_COUNT': DEFAULT_TO(agg_0, 0:int64), 'LOW_LINE_COUNT': DEFAULT_TO(agg_1, 0:int64), 'L_SHIPMODE': ship_mode}) - AGGREGATE(keys={'ship_mode': ship_mode}, aggregations={'agg_0': SUM(is_high_priority), 'agg_1': SUM(NOT(is_high_priority))}) - PROJECT(columns={'is_high_priority': order_priority == '1-URGENT':string | order_priority == '2-HIGH':string, 'ship_mode': ship_mode}) - JOIN(conditions=[t0.order_key == t1.key], types=['left'], columns={'order_priority': t1.order_priority, 'ship_mode': t0.ship_mode}) - FILTER(condition=ship_mode == 'MAIL':string | ship_mode == 'SHIP':string & ship_date < commit_date & commit_date < receipt_date & receipt_date >= datetime.date(1994, 1, 1):date & receipt_date < datetime.date(1995, 1, 1):date, columns={'order_key': order_key, 'ship_mode': ship_mode}) - SCAN(table=tpch.LINEITEM, columns={'commit_date': l_commitdate, 'order_key': l_orderkey, 'receipt_date': l_receiptdate, 'ship_date': l_shipdate, 'ship_mode': l_shipmode}) - SCAN(table=tpch.ORDERS, columns={'key': o_orderkey, 'order_priority': o_orderpriority}) -""", + "tpch_q12", tpch_q12_output, ), id="tpch_q12", @@ -377,19 +197,7 @@ pytest.param( ( impl_tpch_q13, - """ -ROOT(columns=[('C_COUNT', C_COUNT), ('CUSTDIST', CUSTDIST)], orderings=[(ordering_1):desc_last, (ordering_2):desc_last]) - LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'CUSTDIST': CUSTDIST, 'C_COUNT': C_COUNT, 'ordering_1': ordering_1, 'ordering_2': ordering_2}, orderings=[(ordering_1):desc_last, (ordering_2):desc_last]) - PROJECT(columns={'CUSTDIST': CUSTDIST, 'C_COUNT': C_COUNT, 'ordering_1': CUSTDIST, 'ordering_2': C_COUNT}) - PROJECT(columns={'CUSTDIST': DEFAULT_TO(agg_0, 0:int64), 'C_COUNT': num_non_special_orders}) - AGGREGATE(keys={'num_non_special_orders': num_non_special_orders}, aggregations={'agg_0': COUNT()}) - PROJECT(columns={'num_non_special_orders': 
DEFAULT_TO(agg_0, 0:int64)}) - JOIN(conditions=[t0.key == t1.customer_key], types=['left'], columns={'agg_0': t1.agg_0}) - SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey}) - AGGREGATE(keys={'customer_key': customer_key}, aggregations={'agg_0': COUNT()}) - FILTER(condition=NOT(LIKE(comment, '%special%requests%':string)), columns={'customer_key': customer_key}) - SCAN(table=tpch.ORDERS, columns={'comment': o_comment, 'customer_key': o_custkey}) -""", + "tpch_q13", tpch_q13_output, ), id="tpch_q13", @@ -397,16 +205,7 @@ pytest.param( ( impl_tpch_q14, - """ -ROOT(columns=[('PROMO_REVENUE', PROMO_REVENUE)], orderings=[]) - PROJECT(columns={'PROMO_REVENUE': 100.0:float64 * DEFAULT_TO(agg_0, 0:int64) / DEFAULT_TO(agg_1, 0:int64)}) - AGGREGATE(keys={}, aggregations={'agg_0': SUM(promo_value), 'agg_1': SUM(value)}) - PROJECT(columns={'promo_value': IFF(STARTSWITH(part_type, 'PROMO':string), extended_price * 1:int64 - discount, 0:int64), 'value': extended_price * 1:int64 - discount}) - JOIN(conditions=[t0.part_key == t1.key], types=['left'], columns={'discount': t0.discount, 'extended_price': t0.extended_price, 'part_type': t1.part_type}) - FILTER(condition=ship_date >= datetime.date(1995, 9, 1):date & ship_date < datetime.date(1995, 10, 1):date, columns={'discount': discount, 'extended_price': extended_price, 'part_key': part_key}) - SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'part_key': l_partkey, 'ship_date': l_shipdate}) - SCAN(table=tpch.PART, columns={'key': p_partkey, 'part_type': p_type}) -""", + "tpch_q14", tpch_q14_output, ), id="tpch_q14", @@ -414,26 +213,7 @@ pytest.param( ( impl_tpch_q15, - """ -ROOT(columns=[('S_SUPPKEY', S_SUPPKEY), ('S_NAME', S_NAME), ('S_ADDRESS', S_ADDRESS), ('S_PHONE', S_PHONE), ('TOTAL_REVENUE', TOTAL_REVENUE)], orderings=[(ordering_2):asc_first]) - PROJECT(columns={'S_ADDRESS': S_ADDRESS, 'S_NAME': S_NAME, 'S_PHONE': S_PHONE, 'S_SUPPKEY': S_SUPPKEY, 'TOTAL_REVENUE': TOTAL_REVENUE, 
'ordering_2': S_SUPPKEY}) - FILTER(condition=TOTAL_REVENUE == max_revenue, columns={'S_ADDRESS': S_ADDRESS, 'S_NAME': S_NAME, 'S_PHONE': S_PHONE, 'S_SUPPKEY': S_SUPPKEY, 'TOTAL_REVENUE': TOTAL_REVENUE}) - PROJECT(columns={'S_ADDRESS': address, 'S_NAME': name, 'S_PHONE': phone, 'S_SUPPKEY': key, 'TOTAL_REVENUE': DEFAULT_TO(agg_1, 0:int64), 'max_revenue': max_revenue}) - JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'address': t0.address, 'agg_1': t1.agg_1, 'key': t0.key, 'max_revenue': t0.max_revenue, 'name': t0.name, 'phone': t0.phone}) - JOIN(conditions=[True:bool], types=['inner'], columns={'address': t1.address, 'key': t1.key, 'max_revenue': t0.max_revenue, 'name': t1.name, 'phone': t1.phone}) - PROJECT(columns={'max_revenue': agg_0}) - AGGREGATE(keys={}, aggregations={'agg_0': MAX(total_revenue)}) - PROJECT(columns={'total_revenue': DEFAULT_TO(agg_0, 0:int64)}) - JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'agg_0': t1.agg_0}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey}) - AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': SUM(extended_price * 1:int64 - discount)}) - FILTER(condition=ship_date >= datetime.date(1996, 1, 1):date & ship_date < datetime.date(1996, 4, 1):date, columns={'discount': discount, 'extended_price': extended_price, 'supplier_key': supplier_key}) - SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'ship_date': l_shipdate, 'supplier_key': l_suppkey}) - SCAN(table=tpch.SUPPLIER, columns={'address': s_address, 'key': s_suppkey, 'name': s_name, 'phone': s_phone}) - AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_1': SUM(extended_price * 1:int64 - discount)}) - FILTER(condition=ship_date >= datetime.date(1996, 1, 1):date & ship_date < datetime.date(1996, 4, 1):date, columns={'discount': discount, 'extended_price': extended_price, 'supplier_key': supplier_key}) - SCAN(table=tpch.LINEITEM, columns={'discount': 
l_discount, 'extended_price': l_extendedprice, 'ship_date': l_shipdate, 'supplier_key': l_suppkey}) -""", + "tpch_q15", tpch_q15_output, ), id="tpch_q15", @@ -441,21 +221,7 @@ pytest.param( ( impl_tpch_q16, - """ -ROOT(columns=[('P_BRAND', P_BRAND), ('P_TYPE', P_TYPE), ('P_SIZE', P_SIZE), ('SUPPLIER_COUNT', SUPPLIER_COUNT)], orderings=[(ordering_1):desc_last, (ordering_2):asc_first, (ordering_3):asc_first, (ordering_4):asc_first]) - LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'P_BRAND': P_BRAND, 'P_SIZE': P_SIZE, 'P_TYPE': P_TYPE, 'SUPPLIER_COUNT': SUPPLIER_COUNT, 'ordering_1': ordering_1, 'ordering_2': ordering_2, 'ordering_3': ordering_3, 'ordering_4': ordering_4}, orderings=[(ordering_1):desc_last, (ordering_2):asc_first, (ordering_3):asc_first, (ordering_4):asc_first]) - PROJECT(columns={'P_BRAND': P_BRAND, 'P_SIZE': P_SIZE, 'P_TYPE': P_TYPE, 'SUPPLIER_COUNT': SUPPLIER_COUNT, 'ordering_1': SUPPLIER_COUNT, 'ordering_2': P_BRAND, 'ordering_3': P_TYPE, 'ordering_4': P_SIZE}) - PROJECT(columns={'P_BRAND': p_brand, 'P_SIZE': p_size, 'P_TYPE': p_type, 'SUPPLIER_COUNT': agg_0}) - AGGREGATE(keys={'p_brand': p_brand, 'p_size': p_size, 'p_type': p_type}, aggregations={'agg_0': NDISTINCT(supplier_key)}) - FILTER(condition=NOT(LIKE(comment_2, '%Customer%Complaints%':string)), columns={'p_brand': p_brand, 'p_size': p_size, 'p_type': p_type, 'supplier_key': supplier_key}) - JOIN(conditions=[t0.supplier_key == t1.key], types=['left'], columns={'comment_2': t1.comment, 'p_brand': t0.p_brand, 'p_size': t0.p_size, 'p_type': t0.p_type, 'supplier_key': t0.supplier_key}) - PROJECT(columns={'p_brand': brand, 'p_size': size, 'p_type': part_type, 'supplier_key': supplier_key}) - JOIN(conditions=[t0.key == t1.part_key], types=['inner'], columns={'brand': t0.brand, 'part_type': t0.part_type, 'size': t0.size, 'supplier_key': t1.supplier_key}) - FILTER(condition=brand != 'BRAND#45':string & NOT(STARTSWITH(part_type, 'MEDIUM POLISHED%':string)) & ISIN(size, [49, 14, 23, 45, 
19, 3, 36, 9]:array[unknown]), columns={'brand': brand, 'key': key, 'part_type': part_type, 'size': size}) - SCAN(table=tpch.PART, columns={'brand': p_brand, 'key': p_partkey, 'part_type': p_type, 'size': p_size}) - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) - SCAN(table=tpch.SUPPLIER, columns={'comment': s_comment, 'key': s_suppkey}) -""", + "tpch_q16", tpch_q16_output, ), id="tpch_q16", @@ -463,20 +229,7 @@ pytest.param( ( impl_tpch_q17, - """ -ROOT(columns=[('AVG_YEARLY', AVG_YEARLY)], orderings=[]) - PROJECT(columns={'AVG_YEARLY': DEFAULT_TO(agg_0, 0:int64) / 7.0:float64}) - AGGREGATE(keys={}, aggregations={'agg_0': SUM(extended_price)}) - FILTER(condition=quantity < 0.2:float64 * avg_quantity, columns={'extended_price': extended_price}) - JOIN(conditions=[t0.key == t1.part_key], types=['inner'], columns={'avg_quantity': t0.avg_quantity, 'extended_price': t1.extended_price, 'quantity': t1.quantity}) - PROJECT(columns={'avg_quantity': agg_0, 'key': key}) - JOIN(conditions=[t0.key == t1.part_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key}) - FILTER(condition=brand == 'Brand#23':string & container == 'MED BOX':string, columns={'key': key}) - SCAN(table=tpch.PART, columns={'brand': p_brand, 'container': p_container, 'key': p_partkey}) - AGGREGATE(keys={'part_key': part_key}, aggregations={'agg_0': AVG(quantity)}) - SCAN(table=tpch.LINEITEM, columns={'part_key': l_partkey, 'quantity': l_quantity}) - SCAN(table=tpch.LINEITEM, columns={'extended_price': l_extendedprice, 'part_key': l_partkey, 'quantity': l_quantity}) -""", + "tpch_q17", tpch_q17_output, ), id="tpch_q17", @@ -484,19 +237,7 @@ pytest.param( ( impl_tpch_q18, - """ -ROOT(columns=[('C_NAME', C_NAME), ('C_CUSTKEY', C_CUSTKEY), ('O_ORDERKEY', O_ORDERKEY), ('O_ORDERDATE', O_ORDERDATE), ('O_TOTALPRICE', O_TOTALPRICE), ('TOTAL_QUANTITY', TOTAL_QUANTITY)], orderings=[(ordering_1):desc_last, (ordering_2):asc_first]) - LIMIT(limit=Literal(value=10, 
type=Int64Type()), columns={'C_CUSTKEY': C_CUSTKEY, 'C_NAME': C_NAME, 'O_ORDERDATE': O_ORDERDATE, 'O_ORDERKEY': O_ORDERKEY, 'O_TOTALPRICE': O_TOTALPRICE, 'TOTAL_QUANTITY': TOTAL_QUANTITY, 'ordering_1': ordering_1, 'ordering_2': ordering_2}, orderings=[(ordering_1):desc_last, (ordering_2):asc_first]) - PROJECT(columns={'C_CUSTKEY': C_CUSTKEY, 'C_NAME': C_NAME, 'O_ORDERDATE': O_ORDERDATE, 'O_ORDERKEY': O_ORDERKEY, 'O_TOTALPRICE': O_TOTALPRICE, 'TOTAL_QUANTITY': TOTAL_QUANTITY, 'ordering_1': O_TOTALPRICE, 'ordering_2': O_ORDERDATE}) - FILTER(condition=TOTAL_QUANTITY > 300:int64, columns={'C_CUSTKEY': C_CUSTKEY, 'C_NAME': C_NAME, 'O_ORDERDATE': O_ORDERDATE, 'O_ORDERKEY': O_ORDERKEY, 'O_TOTALPRICE': O_TOTALPRICE, 'TOTAL_QUANTITY': TOTAL_QUANTITY}) - PROJECT(columns={'C_CUSTKEY': key_2, 'C_NAME': name, 'O_ORDERDATE': order_date, 'O_ORDERKEY': key, 'O_TOTALPRICE': total_price, 'TOTAL_QUANTITY': DEFAULT_TO(agg_0, 0:int64)}) - JOIN(conditions=[t0.key == t1.order_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key, 'key_2': t0.key_2, 'name': t0.name, 'order_date': t0.order_date, 'total_price': t0.total_price}) - JOIN(conditions=[t0.customer_key == t1.key], types=['left'], columns={'key': t0.key, 'key_2': t1.key, 'name': t1.name, 'order_date': t0.order_date, 'total_price': t0.total_price}) - SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate, 'total_price': o_totalprice}) - SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'name': c_name}) - AGGREGATE(keys={'order_key': order_key}, aggregations={'agg_0': SUM(quantity)}) - SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'quantity': l_quantity}) -""", + "tpch_q18", tpch_q18_output, ), id="tpch_q18", @@ -504,15 +245,7 @@ pytest.param( ( impl_tpch_q19, - """ -ROOT(columns=[('REVENUE', REVENUE)], orderings=[]) - PROJECT(columns={'REVENUE': DEFAULT_TO(agg_0, 0:int64)}) - AGGREGATE(keys={}, aggregations={'agg_0': SUM(extended_price * 1:int64 - 
discount)}) - FILTER(condition=ISIN(ship_mode, ['AIR', 'AIR REG']:array[unknown]) & ship_instruct == 'DELIVER IN PERSON':string & size >= 1:int64 & size <= 5:int64 & quantity >= 1:int64 & quantity <= 11:int64 & ISIN(container, ['SM CASE', 'SM BOX', 'SM PACK', 'SM PKG']:array[unknown]) & brand == 'Brand#12':string | size <= 10:int64 & quantity >= 10:int64 & quantity <= 20:int64 & ISIN(container, ['MED BAG', 'MED BOX', 'MED PACK', 'MED PKG']:array[unknown]) & brand == 'Brand#23':string | size <= 15:int64 & quantity >= 20:int64 & quantity <= 30:int64 & ISIN(container, ['LG CASE', 'LG BOX', 'LG PACK', 'LG PKG']:array[unknown]) & brand == 'Brand#34':string, columns={'discount': discount, 'extended_price': extended_price}) - JOIN(conditions=[t0.part_key == t1.key], types=['left'], columns={'brand': t1.brand, 'container': t1.container, 'discount': t0.discount, 'extended_price': t0.extended_price, 'quantity': t0.quantity, 'ship_instruct': t0.ship_instruct, 'ship_mode': t0.ship_mode, 'size': t1.size}) - SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'part_key': l_partkey, 'quantity': l_quantity, 'ship_instruct': l_shipinstruct, 'ship_mode': l_shipmode}) - SCAN(table=tpch.PART, columns={'brand': p_brand, 'container': p_container, 'key': p_partkey, 'size': p_size}) -""", + "tpch_q19", tpch_q19_output, ), id="tpch_q19", @@ -520,26 +253,7 @@ pytest.param( ( impl_tpch_q20, - """ -ROOT(columns=[('S_NAME', S_NAME), ('S_ADDRESS', S_ADDRESS)], orderings=[(ordering_1):asc_first]) - LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'S_ADDRESS': S_ADDRESS, 'S_NAME': S_NAME, 'ordering_1': ordering_1}, orderings=[(ordering_1):asc_first]) - PROJECT(columns={'S_ADDRESS': S_ADDRESS, 'S_NAME': S_NAME, 'ordering_1': S_NAME}) - FILTER(condition=name_3 == 'CANADA':string & DEFAULT_TO(agg_0, 0:int64) > 0:int64, columns={'S_ADDRESS': S_ADDRESS, 'S_NAME': S_NAME}) - JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], 
columns={'S_ADDRESS': t0.S_ADDRESS, 'S_NAME': t0.S_NAME, 'agg_0': t1.agg_0, 'name_3': t0.name_3}) - JOIN(conditions=[t0.nation_key == t1.key], types=['left'], columns={'S_ADDRESS': t0.S_ADDRESS, 'S_NAME': t0.S_NAME, 'key': t0.key, 'name_3': t1.name}) - PROJECT(columns={'S_ADDRESS': address, 'S_NAME': name, 'key': key, 'nation_key': nation_key}) - SCAN(table=tpch.SUPPLIER, columns={'address': s_address, 'key': s_suppkey, 'name': s_name, 'nation_key': s_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) - AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': COUNT()}) - FILTER(condition=STARTSWITH(name, 'forest':string) & availqty > DEFAULT_TO(agg_0, 0:int64) * 0.5:float64, columns={'supplier_key': supplier_key}) - JOIN(conditions=[t0.key == t1.part_key], types=['left'], columns={'agg_0': t1.agg_0, 'availqty': t0.availqty, 'name': t0.name, 'supplier_key': t0.supplier_key}) - JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'availqty': t0.availqty, 'key': t1.key, 'name': t1.name, 'supplier_key': t0.supplier_key}) - SCAN(table=tpch.PARTSUPP, columns={'availqty': ps_availqty, 'part_key': ps_partkey, 'supplier_key': ps_suppkey}) - SCAN(table=tpch.PART, columns={'key': p_partkey, 'name': p_name}) - AGGREGATE(keys={'part_key': part_key}, aggregations={'agg_0': SUM(quantity)}) - FILTER(condition=ship_date >= datetime.date(1994, 1, 1):date & ship_date < datetime.date(1995, 1, 1):date, columns={'part_key': part_key, 'quantity': quantity}) - SCAN(table=tpch.LINEITEM, columns={'part_key': l_partkey, 'quantity': l_quantity, 'ship_date': l_shipdate}) -""", + "tpch_q20", tpch_q20_output, ), id="tpch_q20", @@ -547,8 +261,7 @@ pytest.param( ( impl_tpch_q21, - """ -""", + "tpch_q21", tpch_q21_output, ), id="tpch_q21", @@ -557,8 +270,7 @@ pytest.param( ( impl_tpch_q22, - """ -""", + "tpch_q22", tpch_q22_output, ), id="tpch_q22", @@ -567,12 +279,7 @@ pytest.param( ( simple_scan_top_five, - """ -ROOT(columns=[('key', 
key)], orderings=[(ordering_0):asc_first]) - LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'key': key, 'ordering_0': ordering_0}, orderings=[(ordering_0):asc_first]) - PROJECT(columns={'key': key, 'ordering_0': key}) - SCAN(table=tpch.ORDERS, columns={'key': o_orderkey}) -""", + "simple_scan_top_five", lambda: pd.DataFrame( { "key": [1, 2, 3, 4, 5], @@ -584,13 +291,7 @@ pytest.param( ( simple_filter_top_five, - """ -ROOT(columns=[('key', key), ('total_price', total_price)], orderings=[(ordering_0):desc_last]) - LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'key': key, 'ordering_0': ordering_0, 'total_price': total_price}, orderings=[(ordering_0):desc_last]) - PROJECT(columns={'key': key, 'ordering_0': key, 'total_price': total_price}) - FILTER(condition=total_price < 1000.0:float64, columns={'key': key, 'total_price': total_price}) - SCAN(table=tpch.ORDERS, columns={'key': o_orderkey, 'total_price': o_totalprice}) -""", + "simple_filter_top_five", lambda: pd.DataFrame( { "key": [5989315, 5935174, 5881093, 5876066, 5866437], @@ -603,13 +304,7 @@ pytest.param( ( rank_nations_by_region, - """ -ROOT(columns=[('name', name), ('rank', rank)], orderings=[]) - PROJECT(columns={'name': name, 'rank': RANKING(args=[], partition=[], order=[(name_3):asc_last], allow_ties=True)}) - JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'name': t0.name, 'name_3': t1.name}) - SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) -""", + "rank_nations_by_region", lambda: pd.DataFrame( { "name": [ @@ -648,18 +343,7 @@ pytest.param( ( rank_nations_per_region_by_customers, - """ -ROOT(columns=[('name', name), ('rank', rank)], orderings=[(ordering_1):asc_first]) - LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'name': name, 'ordering_1': ordering_1, 'rank': rank}, orderings=[(ordering_1):asc_first]) - PROJECT(columns={'name': name, 'ordering_1': 
rank, 'rank': rank}) - PROJECT(columns={'name': name_3, 'rank': RANKING(args=[], partition=[key], order=[(DEFAULT_TO(agg_0, 0:int64)):desc_first])}) - JOIN(conditions=[t0.key_2 == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key, 'name_3': t0.name_3}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name_3': t1.name}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': COUNT()}) - SCAN(table=tpch.CUSTOMER, columns={'nation_key': c_nationkey}) -""", + "rank_nations_per_region_by_customers", lambda: pd.DataFrame( { "name": ["KENYA", "CANADA", "INDONESIA", "FRANCE", "JORDAN"], @@ -672,21 +356,7 @@ pytest.param( ( rank_parts_per_supplier_region_by_size, - """ -ROOT(columns=[('key', key), ('region', region), ('rank', rank)], orderings=[(ordering_0):asc_first]) - LIMIT(limit=Literal(value=15, type=Int64Type()), columns={'key': key, 'ordering_0': ordering_0, 'rank': rank, 'region': region}, orderings=[(ordering_0):asc_first]) - PROJECT(columns={'key': key, 'ordering_0': key, 'rank': rank, 'region': region}) - PROJECT(columns={'key': key_9, 'rank': RANKING(args=[], partition=[key], order=[(size):desc_first, (container):desc_first, (part_type):desc_first], allow_ties=True, dense=True), 'region': name}) - JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'container': t1.container, 'key': t0.key, 'key_9': t1.key, 'name': t0.name, 'part_type': t1.part_type, 'size': t1.size}) - JOIN(conditions=[t0.key_5 == t1.supplier_key], types=['inner'], columns={'key': t0.key, 'name': t0.name, 'part_key': t1.part_key}) - JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'key': t0.key, 'key_5': t1.key, 'name': t0.name}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 
'key_2': t1.key, 'name': t0.name}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) - SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) - SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'part_type': p_type, 'size': p_size}) -""", + "rank_parts_per_supplier_region_by_size", lambda: pd.DataFrame( { "key": [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4], @@ -732,13 +402,7 @@ pytest.param( ( rank_with_filters_a, - """ -ROOT(columns=[('n', n), ('r', r)], orderings=[]) - FILTER(condition=r <= 30:int64, columns={'n': n, 'r': r}) - FILTER(condition=ENDSWITH(name, '0':string), columns={'n': n, 'r': r}) - PROJECT(columns={'n': name, 'name': name, 'r': RANKING(args=[], partition=[], order=[(acctbal):desc_first])}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'name': c_name}) - """, + "rank_with_filters_a", lambda: pd.DataFrame( { "n": [ @@ -755,13 +419,7 @@ pytest.param( ( rank_with_filters_b, - """ -ROOT(columns=[('n', n), ('r', r)], orderings=[]) - FILTER(condition=ENDSWITH(name, '0':string), columns={'n': n, 'r': r}) - FILTER(condition=r <= 30:int64, columns={'n': n, 'name': name, 'r': r}) - PROJECT(columns={'n': name, 'name': name, 'r': RANKING(args=[], partition=[], order=[(acctbal):desc_first])}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'name': c_name}) - """, + "rank_with_filters_b", lambda: pd.DataFrame( { "n": [ @@ -778,17 +436,7 @@ pytest.param( ( rank_with_filters_c, - """ -ROOT(columns=[('size', size), ('name', name)], orderings=[]) - FILTER(condition=RANKING(args=[], partition=[size], order=[(retail_price):desc_first]) == 1:int64, columns={'name': name, 'size': size}) - PROJECT(columns={'name': name, 'retail_price': retail_price, 'size': size_1}) - JOIN(conditions=[t0.size == 
t1.size], types=['inner'], columns={'name': t1.name, 'retail_price': t1.retail_price, 'size_1': t1.size}) - LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'size': size}, orderings=[(ordering_0):desc_last]) - PROJECT(columns={'ordering_0': size, 'size': size}) - AGGREGATE(keys={'size': size}, aggregations={}) - SCAN(table=tpch.PART, columns={'size': p_size}) - SCAN(table=tpch.PART, columns={'name': p_name, 'retail_price': p_retailprice, 'size': p_size}) - """, + "rank_with_filters_c", lambda: pd.DataFrame( { "size": [46, 47, 48, 49, 50], @@ -807,11 +455,7 @@ pytest.param( ( percentile_nations, - """ -ROOT(columns=[('name', name), ('p', p)], orderings=[]) - PROJECT(columns={'name': name, 'p': PERCENTILE(args=[], partition=[], order=[(name):asc_last], n_buckets=5)}) - SCAN(table=tpch.NATION, columns={'name': n_name}) - """, + "percentile_nations", lambda: pd.DataFrame( { "name": [ @@ -850,17 +494,7 @@ pytest.param( ( percentile_customers_per_region, - """ -ROOT(columns=[('name', name)], orderings=[(ordering_0):asc_first]) - PROJECT(columns={'name': name, 'ordering_0': name}) - FILTER(condition=PERCENTILE(args=[], partition=[key], order=[(acctbal):asc_last]) == 95:int64 & ENDSWITH(phone, '00':string), columns={'name': name}) - PROJECT(columns={'acctbal': acctbal, 'key': key, 'name': name_6, 'phone': phone}) - JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'acctbal': t1.acctbal, 'key': t0.key, 'name_6': t1.name, 'phone': t1.phone}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'name': c_name, 'nation_key': c_nationkey, 'phone': c_phone}) - """, + "percentile_customers_per_region", lambda: pd.DataFrame( { "name": [ @@ -883,19 +517,7 @@ pytest.param( ( regional_suppliers_percentile, - 
""" -ROOT(columns=[('name', name)], orderings=[]) - FILTER(condition=True:bool & PERCENTILE(args=[], partition=[key], order=[(DEFAULT_TO(agg_0, 0:int64)):asc_last, (name):asc_last], n_buckets=1000) == 1000:int64, columns={'name': name}) - JOIN(conditions=[t0.key_5 == t1.supplier_key], types=['inner'], columns={'agg_0': t1.agg_0, 'key': t0.key, 'name': t0.name}) - PROJECT(columns={'key': key, 'key_5': key_5, 'name': name_6}) - JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'key': t0.key, 'key_5': t1.key, 'name_6': t1.name}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name, 'nation_key': s_nationkey}) - AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': COUNT()}) - SCAN(table=tpch.PARTSUPP, columns={'supplier_key': ps_suppkey}) - """, + "regional_suppliers_percentile", lambda: pd.DataFrame( { "name": [ @@ -916,18 +538,7 @@ pytest.param( ( function_sampler, - """ -ROOT(columns=[('a', a), ('b', b), ('c', c), ('d', d), ('e', e)], orderings=[(ordering_0):asc_first]) - LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'a': a, 'b': b, 'c': c, 'd': d, 'e': e, 'ordering_0': ordering_0}, orderings=[(ordering_0):asc_first]) - PROJECT(columns={'a': a, 'b': b, 'c': c, 'd': d, 'e': e, 'ordering_0': address}) - FILTER(condition=MONOTONIC(0.0:float64, acctbal, 100.0:float64), columns={'a': a, 'address': address, 'b': b, 'c': c, 'd': d, 'e': e}) - PROJECT(columns={'a': JOIN_STRINGS('-':string, name, name_3, SLICE(name_6, 16:int64, None:unknown, None:unknown)), 'acctbal': acctbal, 'address': address, 'b': ROUND(acctbal, 1:int64), 'c': KEEP_IF(name_6, SLICE(phone, None:unknown, 1:int64, None:unknown) == '3':string), 'd': PRESENT(KEEP_IF(name_6, SLICE(phone, 1:int64, 2:int64, 
None:unknown) == '1':string)), 'e': ABSENT(KEEP_IF(name_6, SLICE(phone, 14:int64, None:unknown, None:unknown) == '7':string))}) - JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'acctbal': t1.acctbal, 'address': t1.address, 'name': t0.name, 'name_3': t0.name_3, 'name_6': t1.name, 'phone': t1.phone}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name, 'name_3': t1.name}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) - SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'address': c_address, 'name': c_name, 'nation_key': c_nationkey, 'phone': c_phone}) - """, + "function_sampler", lambda: pd.DataFrame( { "a": [ @@ -968,15 +579,7 @@ pytest.param( ( agg_partition, - """ -ROOT(columns=[('best_year', best_year)], orderings=[]) - PROJECT(columns={'best_year': agg_0}) - AGGREGATE(keys={}, aggregations={'agg_0': MAX(n_orders)}) - PROJECT(columns={'n_orders': DEFAULT_TO(agg_0, 0:int64)}) - AGGREGATE(keys={'year': year}, aggregations={'agg_0': COUNT()}) - PROJECT(columns={'year': YEAR(order_date)}) - SCAN(table=tpch.ORDERS, columns={'order_date': o_orderdate}) - """, + "agg_partition", lambda: pd.DataFrame( { "best_year": [228637], @@ -988,15 +591,7 @@ pytest.param( ( double_partition, - """ -ROOT(columns=[('year', year), ('best_month', best_month)], orderings=[]) - PROJECT(columns={'best_month': agg_0, 'year': year}) - AGGREGATE(keys={'year': year}, aggregations={'agg_0': MAX(n_orders)}) - PROJECT(columns={'n_orders': DEFAULT_TO(agg_0, 0:int64), 'year': year}) - AGGREGATE(keys={'month': month, 'year': year}, aggregations={'agg_0': COUNT()}) - PROJECT(columns={'month': MONTH(order_date), 'year': YEAR(order_date)}) - SCAN(table=tpch.ORDERS, columns={'order_date': o_orderdate}) - """, + "double_partition", lambda: pd.DataFrame( { "year": [1992, 1993, 1994, 1995, 1996, 1997, 
1998], @@ -1009,38 +604,7 @@ pytest.param( ( triple_partition, - """ -ROOT(columns=[('supp_region', supp_region), ('avg_percentage', avg_percentage)], orderings=[(ordering_1):asc_first]) - PROJECT(columns={'avg_percentage': avg_percentage, 'ordering_1': supp_region, 'supp_region': supp_region}) - PROJECT(columns={'avg_percentage': agg_0, 'supp_region': supp_region}) - AGGREGATE(keys={'supp_region': supp_region}, aggregations={'agg_0': AVG(percentage)}) - PROJECT(columns={'percentage': 100.0:float64 * agg_0 / DEFAULT_TO(agg_1, 0:int64), 'supp_region': supp_region}) - AGGREGATE(keys={'cust_region': cust_region, 'supp_region': supp_region}, aggregations={'agg_0': MAX(n_instances), 'agg_1': SUM(n_instances)}) - PROJECT(columns={'cust_region': cust_region, 'n_instances': DEFAULT_TO(agg_0, 0:int64), 'supp_region': supp_region}) - AGGREGATE(keys={'cust_region': cust_region, 'part_type': part_type, 'supp_region': supp_region}, aggregations={'agg_0': COUNT()}) - PROJECT(columns={'cust_region': name_15, 'part_type': part_type, 'supp_region': supp_region}) - JOIN(conditions=[t0.customer_key == t1.key], types=['left'], columns={'name_15': t1.name_15, 'part_type': t0.part_type, 'supp_region': t0.supp_region}) - FILTER(condition=YEAR(order_date) == 1992:int64, columns={'customer_key': customer_key, 'part_type': part_type, 'supp_region': supp_region}) - JOIN(conditions=[t0.order_key == t1.key], types=['inner'], columns={'customer_key': t1.customer_key, 'order_date': t1.order_date, 'part_type': t0.part_type, 'supp_region': t0.supp_region}) - PROJECT(columns={'order_key': order_key, 'part_type': part_type, 'supp_region': name_7}) - JOIN(conditions=[t0.supplier_key == t1.key], types=['left'], columns={'name_7': t1.name_7, 'order_key': t0.order_key, 'part_type': t0.part_type}) - FILTER(condition=MONTH(ship_date) == 6:int64 & YEAR(ship_date) == 1992:int64, columns={'order_key': order_key, 'part_type': part_type, 'supplier_key': supplier_key}) - JOIN(conditions=[t0.key == t1.part_key], 
types=['inner'], columns={'order_key': t1.order_key, 'part_type': t0.part_type, 'ship_date': t1.ship_date, 'supplier_key': t1.supplier_key}) - FILTER(condition=STARTSWITH(container, 'SM':string), columns={'key': key, 'part_type': part_type}) - SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'part_type': p_type}) - SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'part_key': l_partkey, 'ship_date': l_shipdate, 'supplier_key': l_suppkey}) - JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_7': t1.name}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'region_key': t1.region_key}) - SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate}) - JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_15': t1.name}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'region_key': t1.region_key}) - SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - """, + "triple_partition", lambda: pd.DataFrame( { "supp_region": [ @@ -1074,8 +638,8 @@ def pydough_pipeline_test_data( arguments: 1. `unqualified_impl`: a function that takes in an unqualified root and creates the unqualified node for the TPCH query. - 2. `relational_str`: the string representation of the relational plan - produced for the TPCH query. + 2. `file_name`: the name of the file containing the expected relational + plan. 3. 
`answer_impl`: a function that takes in nothing and returns the answer to a TPCH query as a Pandas DataFrame. """ @@ -1088,6 +652,8 @@ def test_pipeline_until_relational( ], get_sample_graph: graph_fetcher, default_config: PyDoughConfigs, + get_plan_test_filename: Callable[[str], str], + update_plan_tests: bool, ) -> None: """ Tests that a PyDough unqualified node can be correctly translated to its @@ -1096,7 +662,8 @@ def test_pipeline_until_relational( # Run the query through the stages from unqualified node to qualified node # to relational tree, and confirm the tree string matches the expected # structure. - unqualified_impl, relational_string, _ = pydough_pipeline_test_data + unqualified_impl, file_name, _ = pydough_pipeline_test_data + file_path: str = get_plan_test_filename(file_name) graph: GraphMetadata = get_sample_graph("TPCH") UnqualifiedRoot(graph) unqualified: UnqualifiedNode = init_pydough_context(graph)(unqualified_impl)() @@ -1105,9 +672,15 @@ def test_pipeline_until_relational( qualified, PyDoughCollectionQDAG ), "Expected qualified answer to be a collection, not an expression" relational: RelationalRoot = convert_ast_to_relational(qualified, default_config) - assert ( - relational.to_tree_string() == relational_string.strip() - ), "Mismatch between tree string representation of relational node and expected Relational tree string" + if update_plan_tests: + with open(file_path, "w") as f: + f.write(relational.to_tree_string() + "\n") + else: + with open(file_path) as f: + expected_relational_string: str = f.read() + assert ( + relational.to_tree_string() == expected_relational_string.strip() + ), "Mismatch between tree string representation of relational node and expected Relational tree string" @pytest.mark.execute diff --git a/tests/test_plan_refsols/agg_partition.txt b/tests/test_plan_refsols/agg_partition.txt new file mode 100644 index 00000000..0b8ffdc3 --- /dev/null +++ b/tests/test_plan_refsols/agg_partition.txt @@ -0,0 +1,7 @@ 
+ROOT(columns=[('best_year', best_year)], orderings=[]) + PROJECT(columns={'best_year': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': MAX(n_orders)}) + PROJECT(columns={'n_orders': DEFAULT_TO(agg_0, 0:int64)}) + AGGREGATE(keys={'year': year}, aggregations={'agg_0': COUNT()}) + PROJECT(columns={'year': YEAR(order_date)}) + SCAN(table=tpch.ORDERS, columns={'order_date': o_orderdate}) diff --git a/tests/test_plan_refsols/double_partition.txt b/tests/test_plan_refsols/double_partition.txt new file mode 100644 index 00000000..9a351bfb --- /dev/null +++ b/tests/test_plan_refsols/double_partition.txt @@ -0,0 +1,7 @@ +ROOT(columns=[('year', year), ('best_month', best_month)], orderings=[]) + PROJECT(columns={'best_month': agg_0, 'year': year}) + AGGREGATE(keys={'year': year}, aggregations={'agg_0': MAX(n_orders)}) + PROJECT(columns={'n_orders': DEFAULT_TO(agg_0, 0:int64), 'year': year}) + AGGREGATE(keys={'month': month, 'year': year}, aggregations={'agg_0': COUNT()}) + PROJECT(columns={'month': MONTH(order_date), 'year': YEAR(order_date)}) + SCAN(table=tpch.ORDERS, columns={'order_date': o_orderdate}) diff --git a/tests/test_plan_refsols/function_sampler.txt b/tests/test_plan_refsols/function_sampler.txt new file mode 100644 index 00000000..86507720 --- /dev/null +++ b/tests/test_plan_refsols/function_sampler.txt @@ -0,0 +1,10 @@ +ROOT(columns=[('a', a), ('b', b), ('c', c), ('d', d), ('e', e)], orderings=[(ordering_0):asc_first]) + LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'a': a, 'b': b, 'c': c, 'd': d, 'e': e, 'ordering_0': ordering_0}, orderings=[(ordering_0):asc_first]) + PROJECT(columns={'a': a, 'b': b, 'c': c, 'd': d, 'e': e, 'ordering_0': address}) + FILTER(condition=MONOTONIC(0.0:float64, acctbal, 100.0:float64), columns={'a': a, 'address': address, 'b': b, 'c': c, 'd': d, 'e': e}) + PROJECT(columns={'a': JOIN_STRINGS('-':string, name, name_3, SLICE(name_6, 16:int64, None:unknown, None:unknown)), 'acctbal': acctbal, 'address': address, 'b': 
ROUND(acctbal, 1:int64), 'c': KEEP_IF(name_6, SLICE(phone, None:unknown, 1:int64, None:unknown) == '3':string), 'd': PRESENT(KEEP_IF(name_6, SLICE(phone, 1:int64, 2:int64, None:unknown) == '1':string)), 'e': ABSENT(KEEP_IF(name_6, SLICE(phone, 14:int64, None:unknown, None:unknown) == '7':string))}) + JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'acctbal': t1.acctbal, 'address': t1.address, 'name': t0.name, 'name_3': t0.name_3, 'name_6': t1.name, 'phone': t1.phone}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'address': c_address, 'name': c_name, 'nation_key': c_nationkey, 'phone': c_phone}) diff --git a/tests/test_plan_refsols/percentile_customers_per_region.txt b/tests/test_plan_refsols/percentile_customers_per_region.txt new file mode 100644 index 00000000..e6265ee2 --- /dev/null +++ b/tests/test_plan_refsols/percentile_customers_per_region.txt @@ -0,0 +1,9 @@ +ROOT(columns=[('name', name)], orderings=[(ordering_0):asc_first]) + PROJECT(columns={'name': name, 'ordering_0': name}) + FILTER(condition=PERCENTILE(args=[], partition=[key], order=[(acctbal):asc_last]) == 95:int64 & ENDSWITH(phone, '00':string), columns={'name': name}) + PROJECT(columns={'acctbal': acctbal, 'key': key, 'name': name_6, 'phone': phone}) + JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'acctbal': t1.acctbal, 'key': t0.key, 'name_6': t1.name, 'phone': t1.phone}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.CUSTOMER, 
columns={'acctbal': c_acctbal, 'name': c_name, 'nation_key': c_nationkey, 'phone': c_phone}) diff --git a/tests/test_plan_refsols/percentile_nations.txt b/tests/test_plan_refsols/percentile_nations.txt new file mode 100644 index 00000000..19a91388 --- /dev/null +++ b/tests/test_plan_refsols/percentile_nations.txt @@ -0,0 +1,3 @@ +ROOT(columns=[('name', name), ('p', p)], orderings=[]) + PROJECT(columns={'name': name, 'p': PERCENTILE(args=[], partition=[], order=[(name):asc_last], n_buckets=5)}) + SCAN(table=tpch.NATION, columns={'name': n_name}) diff --git a/tests/test_plan_refsols/rank_nations_by_region.txt b/tests/test_plan_refsols/rank_nations_by_region.txt new file mode 100644 index 00000000..1b1bc038 --- /dev/null +++ b/tests/test_plan_refsols/rank_nations_by_region.txt @@ -0,0 +1,5 @@ +ROOT(columns=[('name', name), ('rank', rank)], orderings=[]) + PROJECT(columns={'name': name, 'rank': RANKING(args=[], partition=[], order=[(name_3):asc_last], allow_ties=True)}) + JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/rank_nations_per_region_by_customers.txt b/tests/test_plan_refsols/rank_nations_per_region_by_customers.txt new file mode 100644 index 00000000..b4a8cea3 --- /dev/null +++ b/tests/test_plan_refsols/rank_nations_per_region_by_customers.txt @@ -0,0 +1,10 @@ +ROOT(columns=[('name', name), ('rank', rank)], orderings=[(ordering_1):asc_first]) + LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'name': name, 'ordering_1': ordering_1, 'rank': rank}, orderings=[(ordering_1):asc_first]) + PROJECT(columns={'name': name, 'ordering_1': rank, 'rank': rank}) + PROJECT(columns={'name': name_3, 'rank': RANKING(args=[], partition=[key], order=[(DEFAULT_TO(agg_0, 0:int64)):desc_first])}) + JOIN(conditions=[t0.key_2 == 
t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key, 'name_3': t0.name_3}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name_3': t1.name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': COUNT()}) + SCAN(table=tpch.CUSTOMER, columns={'nation_key': c_nationkey}) diff --git a/tests/test_plan_refsols/rank_parts_per_supplier_region_by_size.txt b/tests/test_plan_refsols/rank_parts_per_supplier_region_by_size.txt new file mode 100644 index 00000000..ffec8e5d --- /dev/null +++ b/tests/test_plan_refsols/rank_parts_per_supplier_region_by_size.txt @@ -0,0 +1,13 @@ +ROOT(columns=[('key', key), ('region', region), ('rank', rank)], orderings=[(ordering_0):asc_first]) + LIMIT(limit=Literal(value=15, type=Int64Type()), columns={'key': key, 'ordering_0': ordering_0, 'rank': rank, 'region': region}, orderings=[(ordering_0):asc_first]) + PROJECT(columns={'key': key, 'ordering_0': key, 'rank': rank, 'region': region}) + PROJECT(columns={'key': key_9, 'rank': RANKING(args=[], partition=[key], order=[(size):desc_first, (container):desc_first, (part_type):desc_first], allow_ties=True, dense=True), 'region': name}) + JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'container': t1.container, 'key': t0.key, 'key_9': t1.key, 'name': t0.name, 'part_type': t1.part_type, 'size': t1.size}) + JOIN(conditions=[t0.key_5 == t1.supplier_key], types=['inner'], columns={'key': t0.key, 'name': t0.name, 'part_key': t1.part_key}) + JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'key': t0.key, 'key_5': t1.key, 'name': t0.name}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name': t0.name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': 
r_name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'part_type': p_type, 'size': p_size}) diff --git a/tests/test_plan_refsols/rank_with_filters_a.txt b/tests/test_plan_refsols/rank_with_filters_a.txt new file mode 100644 index 00000000..9e1cd1f7 --- /dev/null +++ b/tests/test_plan_refsols/rank_with_filters_a.txt @@ -0,0 +1,5 @@ +ROOT(columns=[('n', n), ('r', r)], orderings=[]) + FILTER(condition=r <= 30:int64, columns={'n': n, 'r': r}) + FILTER(condition=ENDSWITH(name, '0':string), columns={'n': n, 'r': r}) + PROJECT(columns={'n': name, 'name': name, 'r': RANKING(args=[], partition=[], order=[(acctbal):desc_first])}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'name': c_name}) diff --git a/tests/test_plan_refsols/rank_with_filters_b.txt b/tests/test_plan_refsols/rank_with_filters_b.txt new file mode 100644 index 00000000..30d6edc0 --- /dev/null +++ b/tests/test_plan_refsols/rank_with_filters_b.txt @@ -0,0 +1,5 @@ +ROOT(columns=[('n', n), ('r', r)], orderings=[]) + FILTER(condition=ENDSWITH(name, '0':string), columns={'n': n, 'r': r}) + FILTER(condition=r <= 30:int64, columns={'n': n, 'name': name, 'r': r}) + PROJECT(columns={'n': name, 'name': name, 'r': RANKING(args=[], partition=[], order=[(acctbal):desc_first])}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'name': c_name}) diff --git a/tests/test_plan_refsols/rank_with_filters_c.txt b/tests/test_plan_refsols/rank_with_filters_c.txt new file mode 100644 index 00000000..8f072927 --- /dev/null +++ b/tests/test_plan_refsols/rank_with_filters_c.txt @@ -0,0 +1,9 @@ +ROOT(columns=[('size', size), ('name', name)], orderings=[]) + FILTER(condition=RANKING(args=[], partition=[size], 
order=[(retail_price):desc_first]) == 1:int64, columns={'name': name, 'size': size}) + PROJECT(columns={'name': name, 'retail_price': retail_price, 'size': size_1}) + JOIN(conditions=[t0.size == t1.size], types=['inner'], columns={'name': t1.name, 'retail_price': t1.retail_price, 'size_1': t1.size}) + LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'size': size}, orderings=[(ordering_0):desc_last]) + PROJECT(columns={'ordering_0': size, 'size': size}) + AGGREGATE(keys={'size': size}, aggregations={}) + SCAN(table=tpch.PART, columns={'size': p_size}) + SCAN(table=tpch.PART, columns={'name': p_name, 'retail_price': p_retailprice, 'size': p_size}) diff --git a/tests/test_plan_refsols/regional_suppliers_percentile.txt b/tests/test_plan_refsols/regional_suppliers_percentile.txt new file mode 100644 index 00000000..f9f22338 --- /dev/null +++ b/tests/test_plan_refsols/regional_suppliers_percentile.txt @@ -0,0 +1,11 @@ +ROOT(columns=[('name', name)], orderings=[]) + FILTER(condition=True:bool & PERCENTILE(args=[], partition=[key], order=[(DEFAULT_TO(agg_0, 0:int64)):asc_last, (name):asc_last], n_buckets=1000) == 1000:int64, columns={'name': name}) + JOIN(conditions=[t0.key_5 == t1.supplier_key], types=['inner'], columns={'agg_0': t1.agg_0, 'key': t0.key, 'name': t0.name}) + PROJECT(columns={'key': key, 'key_5': key_5, 'name': name_6}) + JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'key': t0.key, 'key_5': t1.key, 'name_6': t1.name}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name, 'nation_key': s_nationkey}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': COUNT()}) + SCAN(table=tpch.PARTSUPP, columns={'supplier_key': ps_suppkey}) diff --git 
a/tests/test_plan_refsols/simple_filter_top_five.txt b/tests/test_plan_refsols/simple_filter_top_five.txt new file mode 100644 index 00000000..9049674a --- /dev/null +++ b/tests/test_plan_refsols/simple_filter_top_five.txt @@ -0,0 +1,5 @@ +ROOT(columns=[('key', key), ('total_price', total_price)], orderings=[(ordering_0):desc_last]) + LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'key': key, 'ordering_0': ordering_0, 'total_price': total_price}, orderings=[(ordering_0):desc_last]) + PROJECT(columns={'key': key, 'ordering_0': key, 'total_price': total_price}) + FILTER(condition=total_price < 1000.0:float64, columns={'key': key, 'total_price': total_price}) + SCAN(table=tpch.ORDERS, columns={'key': o_orderkey, 'total_price': o_totalprice}) diff --git a/tests/test_plan_refsols/simple_scan_top_five.txt b/tests/test_plan_refsols/simple_scan_top_five.txt new file mode 100644 index 00000000..dd524cb1 --- /dev/null +++ b/tests/test_plan_refsols/simple_scan_top_five.txt @@ -0,0 +1,4 @@ +ROOT(columns=[('key', key)], orderings=[(ordering_0):asc_first]) + LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'key': key, 'ordering_0': ordering_0}, orderings=[(ordering_0):asc_first]) + PROJECT(columns={'key': key, 'ordering_0': key}) + SCAN(table=tpch.ORDERS, columns={'key': o_orderkey}) diff --git a/tests/test_plan_refsols/tpch_q1.txt b/tests/test_plan_refsols/tpch_q1.txt new file mode 100644 index 00000000..a881e5b7 --- /dev/null +++ b/tests/test_plan_refsols/tpch_q1.txt @@ -0,0 +1,6 @@ +ROOT(columns=[('L_RETURNFLAG', L_RETURNFLAG), ('L_LINESTATUS', L_LINESTATUS), ('SUM_QTY', SUM_QTY), ('SUM_BASE_PRICE', SUM_BASE_PRICE), ('SUM_DISC_PRICE', SUM_DISC_PRICE), ('SUM_CHARGE', SUM_CHARGE), ('AVG_QTY', AVG_QTY), ('AVG_PRICE', AVG_PRICE), ('AVG_DISC', AVG_DISC), ('COUNT_ORDER', COUNT_ORDER)], orderings=[(ordering_8):asc_first, (ordering_9):asc_first]) + PROJECT(columns={'AVG_DISC': AVG_DISC, 'AVG_PRICE': AVG_PRICE, 'AVG_QTY': AVG_QTY, 'COUNT_ORDER': COUNT_ORDER, 
'L_LINESTATUS': L_LINESTATUS, 'L_RETURNFLAG': L_RETURNFLAG, 'SUM_BASE_PRICE': SUM_BASE_PRICE, 'SUM_CHARGE': SUM_CHARGE, 'SUM_DISC_PRICE': SUM_DISC_PRICE, 'SUM_QTY': SUM_QTY, 'ordering_8': L_RETURNFLAG, 'ordering_9': L_LINESTATUS}) + PROJECT(columns={'AVG_DISC': agg_0, 'AVG_PRICE': agg_1, 'AVG_QTY': agg_2, 'COUNT_ORDER': DEFAULT_TO(agg_3, 0:int64), 'L_LINESTATUS': status, 'L_RETURNFLAG': return_flag, 'SUM_BASE_PRICE': DEFAULT_TO(agg_4, 0:int64), 'SUM_CHARGE': DEFAULT_TO(agg_5, 0:int64), 'SUM_DISC_PRICE': DEFAULT_TO(agg_6, 0:int64), 'SUM_QTY': DEFAULT_TO(agg_7, 0:int64)}) + AGGREGATE(keys={'return_flag': return_flag, 'status': status}, aggregations={'agg_0': AVG(discount), 'agg_1': AVG(extended_price), 'agg_2': AVG(quantity), 'agg_3': COUNT(), 'agg_4': SUM(extended_price), 'agg_5': SUM(extended_price * 1:int64 - discount * 1:int64 + tax), 'agg_6': SUM(extended_price * 1:int64 - discount), 'agg_7': SUM(quantity)}) + FILTER(condition=ship_date <= datetime.date(1998, 12, 1):date, columns={'discount': discount, 'extended_price': extended_price, 'quantity': quantity, 'return_flag': return_flag, 'status': status, 'tax': tax}) + SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'quantity': l_quantity, 'return_flag': l_returnflag, 'ship_date': l_shipdate, 'status': l_linestatus, 'tax': l_tax}) diff --git a/tests/test_plan_refsols/tpch_q10.txt b/tests/test_plan_refsols/tpch_q10.txt new file mode 100644 index 00000000..da802660 --- /dev/null +++ b/tests/test_plan_refsols/tpch_q10.txt @@ -0,0 +1,15 @@ +ROOT(columns=[('C_CUSTKEY', C_CUSTKEY), ('C_NAME', C_NAME), ('REVENUE', REVENUE), ('C_ACCTBAL', C_ACCTBAL), ('N_NAME', N_NAME), ('C_ADDRESS', C_ADDRESS), ('C_PHONE', C_PHONE), ('C_COMMENT', C_COMMENT)], orderings=[(ordering_1):desc_last, (ordering_2):asc_first]) + LIMIT(limit=Literal(value=20, type=Int64Type()), columns={'C_ACCTBAL': C_ACCTBAL, 'C_ADDRESS': C_ADDRESS, 'C_COMMENT': C_COMMENT, 'C_CUSTKEY': C_CUSTKEY, 'C_NAME': C_NAME, 
'C_PHONE': C_PHONE, 'N_NAME': N_NAME, 'REVENUE': REVENUE, 'ordering_1': ordering_1, 'ordering_2': ordering_2}, orderings=[(ordering_1):desc_last, (ordering_2):asc_first]) + PROJECT(columns={'C_ACCTBAL': C_ACCTBAL, 'C_ADDRESS': C_ADDRESS, 'C_COMMENT': C_COMMENT, 'C_CUSTKEY': C_CUSTKEY, 'C_NAME': C_NAME, 'C_PHONE': C_PHONE, 'N_NAME': N_NAME, 'REVENUE': REVENUE, 'ordering_1': REVENUE, 'ordering_2': C_CUSTKEY}) + PROJECT(columns={'C_ACCTBAL': acctbal, 'C_ADDRESS': address, 'C_COMMENT': comment, 'C_CUSTKEY': key, 'C_NAME': name, 'C_PHONE': phone, 'N_NAME': name_4, 'REVENUE': DEFAULT_TO(agg_0, 0:int64)}) + JOIN(conditions=[t0.nation_key == t1.key], types=['left'], columns={'acctbal': t0.acctbal, 'address': t0.address, 'agg_0': t0.agg_0, 'comment': t0.comment, 'key': t0.key, 'name': t0.name, 'name_4': t1.name, 'phone': t0.phone}) + JOIN(conditions=[t0.key == t1.customer_key], types=['left'], columns={'acctbal': t0.acctbal, 'address': t0.address, 'agg_0': t1.agg_0, 'comment': t0.comment, 'key': t0.key, 'name': t0.name, 'nation_key': t0.nation_key, 'phone': t0.phone}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'address': c_address, 'comment': c_comment, 'key': c_custkey, 'name': c_name, 'nation_key': c_nationkey, 'phone': c_phone}) + AGGREGATE(keys={'customer_key': customer_key}, aggregations={'agg_0': SUM(amt)}) + PROJECT(columns={'amt': extended_price * 1:int64 - discount, 'customer_key': customer_key}) + FILTER(condition=return_flag == 'R':string, columns={'customer_key': customer_key, 'discount': discount, 'extended_price': extended_price}) + JOIN(conditions=[t0.key == t1.order_key], types=['inner'], columns={'customer_key': t0.customer_key, 'discount': t1.discount, 'extended_price': t1.extended_price, 'return_flag': t1.return_flag}) + FILTER(condition=order_date >= datetime.date(1993, 10, 1):date & order_date < datetime.date(1994, 1, 1):date, columns={'customer_key': customer_key, 'key': key}) + SCAN(table=tpch.ORDERS, columns={'customer_key': 
o_custkey, 'key': o_orderkey, 'order_date': o_orderdate}) + SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'order_key': l_orderkey, 'return_flag': l_returnflag}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) diff --git a/tests/test_plan_refsols/tpch_q11.txt b/tests/test_plan_refsols/tpch_q11.txt new file mode 100644 index 00000000..68f294eb --- /dev/null +++ b/tests/test_plan_refsols/tpch_q11.txt @@ -0,0 +1,23 @@ +ROOT(columns=[('PS_PARTKEY', PS_PARTKEY), ('VALUE', VALUE)], orderings=[(ordering_2):desc_last]) + LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'PS_PARTKEY': PS_PARTKEY, 'VALUE': VALUE, 'ordering_2': ordering_2}, orderings=[(ordering_2):desc_last]) + PROJECT(columns={'PS_PARTKEY': PS_PARTKEY, 'VALUE': VALUE, 'ordering_2': VALUE}) + FILTER(condition=VALUE > min_market_share, columns={'PS_PARTKEY': PS_PARTKEY, 'VALUE': VALUE}) + PROJECT(columns={'PS_PARTKEY': part_key, 'VALUE': DEFAULT_TO(agg_1, 0:int64), 'min_market_share': min_market_share}) + JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1, 'min_market_share': t0.min_market_share, 'part_key': t1.part_key}) + PROJECT(columns={'min_market_share': DEFAULT_TO(agg_0, 0:int64) * 0.0001:float64}) + AGGREGATE(keys={}, aggregations={'agg_0': SUM(metric)}) + PROJECT(columns={'metric': supplycost * availqty}) + FILTER(condition=name_3 == 'GERMANY':string, columns={'availqty': availqty, 'supplycost': supplycost}) + JOIN(conditions=[t0.supplier_key == t1.key], types=['left'], columns={'availqty': t0.availqty, 'name_3': t1.name_3, 'supplycost': t0.supplycost}) + SCAN(table=tpch.PARTSUPP, columns={'availqty': ps_availqty, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_3': t1.name}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': 
n_nationkey, 'name': n_name}) + AGGREGATE(keys={'part_key': part_key}, aggregations={'agg_1': SUM(metric)}) + PROJECT(columns={'metric': supplycost * availqty, 'part_key': part_key}) + FILTER(condition=name_6 == 'GERMANY':string, columns={'availqty': availqty, 'part_key': part_key, 'supplycost': supplycost}) + JOIN(conditions=[t0.supplier_key == t1.key], types=['left'], columns={'availqty': t0.availqty, 'name_6': t1.name_6, 'part_key': t0.part_key, 'supplycost': t0.supplycost}) + SCAN(table=tpch.PARTSUPP, columns={'availqty': ps_availqty, 'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_6': t1.name}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) diff --git a/tests/test_plan_refsols/tpch_q12.txt b/tests/test_plan_refsols/tpch_q12.txt new file mode 100644 index 00000000..2594c04e --- /dev/null +++ b/tests/test_plan_refsols/tpch_q12.txt @@ -0,0 +1,9 @@ +ROOT(columns=[('L_SHIPMODE', L_SHIPMODE), ('HIGH_LINE_COUNT', HIGH_LINE_COUNT), ('LOW_LINE_COUNT', LOW_LINE_COUNT)], orderings=[(ordering_2):asc_first]) + PROJECT(columns={'HIGH_LINE_COUNT': HIGH_LINE_COUNT, 'LOW_LINE_COUNT': LOW_LINE_COUNT, 'L_SHIPMODE': L_SHIPMODE, 'ordering_2': L_SHIPMODE}) + PROJECT(columns={'HIGH_LINE_COUNT': DEFAULT_TO(agg_0, 0:int64), 'LOW_LINE_COUNT': DEFAULT_TO(agg_1, 0:int64), 'L_SHIPMODE': ship_mode}) + AGGREGATE(keys={'ship_mode': ship_mode}, aggregations={'agg_0': SUM(is_high_priority), 'agg_1': SUM(NOT(is_high_priority))}) + PROJECT(columns={'is_high_priority': order_priority == '1-URGENT':string | order_priority == '2-HIGH':string, 'ship_mode': ship_mode}) + JOIN(conditions=[t0.order_key == t1.key], types=['left'], columns={'order_priority': t1.order_priority, 'ship_mode': t0.ship_mode}) + FILTER(condition=ship_mode == 'MAIL':string | ship_mode == 
'SHIP':string & ship_date < commit_date & commit_date < receipt_date & receipt_date >= datetime.date(1994, 1, 1):date & receipt_date < datetime.date(1995, 1, 1):date, columns={'order_key': order_key, 'ship_mode': ship_mode}) + SCAN(table=tpch.LINEITEM, columns={'commit_date': l_commitdate, 'order_key': l_orderkey, 'receipt_date': l_receiptdate, 'ship_date': l_shipdate, 'ship_mode': l_shipmode}) + SCAN(table=tpch.ORDERS, columns={'key': o_orderkey, 'order_priority': o_orderpriority}) diff --git a/tests/test_plan_refsols/tpch_q13.txt b/tests/test_plan_refsols/tpch_q13.txt new file mode 100644 index 00000000..9b0accb9 --- /dev/null +++ b/tests/test_plan_refsols/tpch_q13.txt @@ -0,0 +1,11 @@ +ROOT(columns=[('C_COUNT', C_COUNT), ('CUSTDIST', CUSTDIST)], orderings=[(ordering_1):desc_last, (ordering_2):desc_last]) + LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'CUSTDIST': CUSTDIST, 'C_COUNT': C_COUNT, 'ordering_1': ordering_1, 'ordering_2': ordering_2}, orderings=[(ordering_1):desc_last, (ordering_2):desc_last]) + PROJECT(columns={'CUSTDIST': CUSTDIST, 'C_COUNT': C_COUNT, 'ordering_1': CUSTDIST, 'ordering_2': C_COUNT}) + PROJECT(columns={'CUSTDIST': DEFAULT_TO(agg_0, 0:int64), 'C_COUNT': num_non_special_orders}) + AGGREGATE(keys={'num_non_special_orders': num_non_special_orders}, aggregations={'agg_0': COUNT()}) + PROJECT(columns={'num_non_special_orders': DEFAULT_TO(agg_0, 0:int64)}) + JOIN(conditions=[t0.key == t1.customer_key], types=['left'], columns={'agg_0': t1.agg_0}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey}) + AGGREGATE(keys={'customer_key': customer_key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=NOT(LIKE(comment, '%special%requests%':string)), columns={'customer_key': customer_key}) + SCAN(table=tpch.ORDERS, columns={'comment': o_comment, 'customer_key': o_custkey}) diff --git a/tests/test_plan_refsols/tpch_q14.txt b/tests/test_plan_refsols/tpch_q14.txt new file mode 100644 index 00000000..f282efed --- /dev/null +++ 
b/tests/test_plan_refsols/tpch_q14.txt @@ -0,0 +1,8 @@ +ROOT(columns=[('PROMO_REVENUE', PROMO_REVENUE)], orderings=[]) + PROJECT(columns={'PROMO_REVENUE': 100.0:float64 * DEFAULT_TO(agg_0, 0:int64) / DEFAULT_TO(agg_1, 0:int64)}) + AGGREGATE(keys={}, aggregations={'agg_0': SUM(promo_value), 'agg_1': SUM(value)}) + PROJECT(columns={'promo_value': IFF(STARTSWITH(part_type, 'PROMO':string), extended_price * 1:int64 - discount, 0:int64), 'value': extended_price * 1:int64 - discount}) + JOIN(conditions=[t0.part_key == t1.key], types=['left'], columns={'discount': t0.discount, 'extended_price': t0.extended_price, 'part_type': t1.part_type}) + FILTER(condition=ship_date >= datetime.date(1995, 9, 1):date & ship_date < datetime.date(1995, 10, 1):date, columns={'discount': discount, 'extended_price': extended_price, 'part_key': part_key}) + SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'part_key': l_partkey, 'ship_date': l_shipdate}) + SCAN(table=tpch.PART, columns={'key': p_partkey, 'part_type': p_type}) diff --git a/tests/test_plan_refsols/tpch_q15.txt b/tests/test_plan_refsols/tpch_q15.txt new file mode 100644 index 00000000..0adfd50f --- /dev/null +++ b/tests/test_plan_refsols/tpch_q15.txt @@ -0,0 +1,18 @@ +ROOT(columns=[('S_SUPPKEY', S_SUPPKEY), ('S_NAME', S_NAME), ('S_ADDRESS', S_ADDRESS), ('S_PHONE', S_PHONE), ('TOTAL_REVENUE', TOTAL_REVENUE)], orderings=[(ordering_2):asc_first]) + PROJECT(columns={'S_ADDRESS': S_ADDRESS, 'S_NAME': S_NAME, 'S_PHONE': S_PHONE, 'S_SUPPKEY': S_SUPPKEY, 'TOTAL_REVENUE': TOTAL_REVENUE, 'ordering_2': S_SUPPKEY}) + FILTER(condition=TOTAL_REVENUE == max_revenue, columns={'S_ADDRESS': S_ADDRESS, 'S_NAME': S_NAME, 'S_PHONE': S_PHONE, 'S_SUPPKEY': S_SUPPKEY, 'TOTAL_REVENUE': TOTAL_REVENUE}) + PROJECT(columns={'S_ADDRESS': address, 'S_NAME': name, 'S_PHONE': phone, 'S_SUPPKEY': key, 'TOTAL_REVENUE': DEFAULT_TO(agg_1, 0:int64), 'max_revenue': max_revenue}) + JOIN(conditions=[t0.key == 
t1.supplier_key], types=['left'], columns={'address': t0.address, 'agg_1': t1.agg_1, 'key': t0.key, 'max_revenue': t0.max_revenue, 'name': t0.name, 'phone': t0.phone}) + JOIN(conditions=[True:bool], types=['inner'], columns={'address': t1.address, 'key': t1.key, 'max_revenue': t0.max_revenue, 'name': t1.name, 'phone': t1.phone}) + PROJECT(columns={'max_revenue': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': MAX(total_revenue)}) + PROJECT(columns={'total_revenue': DEFAULT_TO(agg_0, 0:int64)}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'agg_0': t1.agg_0}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': SUM(extended_price * 1:int64 - discount)}) + FILTER(condition=ship_date >= datetime.date(1996, 1, 1):date & ship_date < datetime.date(1996, 4, 1):date, columns={'discount': discount, 'extended_price': extended_price, 'supplier_key': supplier_key}) + SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'ship_date': l_shipdate, 'supplier_key': l_suppkey}) + SCAN(table=tpch.SUPPLIER, columns={'address': s_address, 'key': s_suppkey, 'name': s_name, 'phone': s_phone}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_1': SUM(extended_price * 1:int64 - discount)}) + FILTER(condition=ship_date >= datetime.date(1996, 1, 1):date & ship_date < datetime.date(1996, 4, 1):date, columns={'discount': discount, 'extended_price': extended_price, 'supplier_key': supplier_key}) + SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'ship_date': l_shipdate, 'supplier_key': l_suppkey}) diff --git a/tests/test_plan_refsols/tpch_q16.txt b/tests/test_plan_refsols/tpch_q16.txt new file mode 100644 index 00000000..094e075a --- /dev/null +++ b/tests/test_plan_refsols/tpch_q16.txt @@ -0,0 +1,13 @@ +ROOT(columns=[('P_BRAND', P_BRAND), ('P_TYPE', P_TYPE), ('P_SIZE', P_SIZE), 
('SUPPLIER_COUNT', SUPPLIER_COUNT)], orderings=[(ordering_1):desc_last, (ordering_2):asc_first, (ordering_3):asc_first, (ordering_4):asc_first]) + LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'P_BRAND': P_BRAND, 'P_SIZE': P_SIZE, 'P_TYPE': P_TYPE, 'SUPPLIER_COUNT': SUPPLIER_COUNT, 'ordering_1': ordering_1, 'ordering_2': ordering_2, 'ordering_3': ordering_3, 'ordering_4': ordering_4}, orderings=[(ordering_1):desc_last, (ordering_2):asc_first, (ordering_3):asc_first, (ordering_4):asc_first]) + PROJECT(columns={'P_BRAND': P_BRAND, 'P_SIZE': P_SIZE, 'P_TYPE': P_TYPE, 'SUPPLIER_COUNT': SUPPLIER_COUNT, 'ordering_1': SUPPLIER_COUNT, 'ordering_2': P_BRAND, 'ordering_3': P_TYPE, 'ordering_4': P_SIZE}) + PROJECT(columns={'P_BRAND': p_brand, 'P_SIZE': p_size, 'P_TYPE': p_type, 'SUPPLIER_COUNT': agg_0}) + AGGREGATE(keys={'p_brand': p_brand, 'p_size': p_size, 'p_type': p_type}, aggregations={'agg_0': NDISTINCT(supplier_key)}) + FILTER(condition=NOT(LIKE(comment_2, '%Customer%Complaints%':string)), columns={'p_brand': p_brand, 'p_size': p_size, 'p_type': p_type, 'supplier_key': supplier_key}) + JOIN(conditions=[t0.supplier_key == t1.key], types=['left'], columns={'comment_2': t1.comment, 'p_brand': t0.p_brand, 'p_size': t0.p_size, 'p_type': t0.p_type, 'supplier_key': t0.supplier_key}) + PROJECT(columns={'p_brand': brand, 'p_size': size, 'p_type': part_type, 'supplier_key': supplier_key}) + JOIN(conditions=[t0.key == t1.part_key], types=['inner'], columns={'brand': t0.brand, 'part_type': t0.part_type, 'size': t0.size, 'supplier_key': t1.supplier_key}) + FILTER(condition=brand != 'BRAND#45':string & NOT(STARTSWITH(part_type, 'MEDIUM POLISHED%':string)) & ISIN(size, [49, 14, 23, 45, 19, 3, 36, 9]:array[unknown]), columns={'brand': brand, 'key': key, 'part_type': part_type, 'size': size}) + SCAN(table=tpch.PART, columns={'brand': p_brand, 'key': p_partkey, 'part_type': p_type, 'size': p_size}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': 
ps_suppkey}) + SCAN(table=tpch.SUPPLIER, columns={'comment': s_comment, 'key': s_suppkey}) diff --git a/tests/test_plan_refsols/tpch_q17.txt b/tests/test_plan_refsols/tpch_q17.txt new file mode 100644 index 00000000..09e258eb --- /dev/null +++ b/tests/test_plan_refsols/tpch_q17.txt @@ -0,0 +1,12 @@ +ROOT(columns=[('AVG_YEARLY', AVG_YEARLY)], orderings=[]) + PROJECT(columns={'AVG_YEARLY': DEFAULT_TO(agg_0, 0:int64) / 7.0:float64}) + AGGREGATE(keys={}, aggregations={'agg_0': SUM(extended_price)}) + FILTER(condition=quantity < 0.2:float64 * avg_quantity, columns={'extended_price': extended_price}) + JOIN(conditions=[t0.key == t1.part_key], types=['inner'], columns={'avg_quantity': t0.avg_quantity, 'extended_price': t1.extended_price, 'quantity': t1.quantity}) + PROJECT(columns={'avg_quantity': agg_0, 'key': key}) + JOIN(conditions=[t0.key == t1.part_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key}) + FILTER(condition=brand == 'Brand#23':string & container == 'MED BOX':string, columns={'key': key}) + SCAN(table=tpch.PART, columns={'brand': p_brand, 'container': p_container, 'key': p_partkey}) + AGGREGATE(keys={'part_key': part_key}, aggregations={'agg_0': AVG(quantity)}) + SCAN(table=tpch.LINEITEM, columns={'part_key': l_partkey, 'quantity': l_quantity}) + SCAN(table=tpch.LINEITEM, columns={'extended_price': l_extendedprice, 'part_key': l_partkey, 'quantity': l_quantity}) diff --git a/tests/test_plan_refsols/tpch_q18.txt b/tests/test_plan_refsols/tpch_q18.txt new file mode 100644 index 00000000..6728b401 --- /dev/null +++ b/tests/test_plan_refsols/tpch_q18.txt @@ -0,0 +1,11 @@ +ROOT(columns=[('C_NAME', C_NAME), ('C_CUSTKEY', C_CUSTKEY), ('O_ORDERKEY', O_ORDERKEY), ('O_ORDERDATE', O_ORDERDATE), ('O_TOTALPRICE', O_TOTALPRICE), ('TOTAL_QUANTITY', TOTAL_QUANTITY)], orderings=[(ordering_1):desc_last, (ordering_2):asc_first]) + LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'C_CUSTKEY': C_CUSTKEY, 'C_NAME': C_NAME, 'O_ORDERDATE': O_ORDERDATE, 
'O_ORDERKEY': O_ORDERKEY, 'O_TOTALPRICE': O_TOTALPRICE, 'TOTAL_QUANTITY': TOTAL_QUANTITY, 'ordering_1': ordering_1, 'ordering_2': ordering_2}, orderings=[(ordering_1):desc_last, (ordering_2):asc_first]) + PROJECT(columns={'C_CUSTKEY': C_CUSTKEY, 'C_NAME': C_NAME, 'O_ORDERDATE': O_ORDERDATE, 'O_ORDERKEY': O_ORDERKEY, 'O_TOTALPRICE': O_TOTALPRICE, 'TOTAL_QUANTITY': TOTAL_QUANTITY, 'ordering_1': O_TOTALPRICE, 'ordering_2': O_ORDERDATE}) + FILTER(condition=TOTAL_QUANTITY > 300:int64, columns={'C_CUSTKEY': C_CUSTKEY, 'C_NAME': C_NAME, 'O_ORDERDATE': O_ORDERDATE, 'O_ORDERKEY': O_ORDERKEY, 'O_TOTALPRICE': O_TOTALPRICE, 'TOTAL_QUANTITY': TOTAL_QUANTITY}) + PROJECT(columns={'C_CUSTKEY': key_2, 'C_NAME': name, 'O_ORDERDATE': order_date, 'O_ORDERKEY': key, 'O_TOTALPRICE': total_price, 'TOTAL_QUANTITY': DEFAULT_TO(agg_0, 0:int64)}) + JOIN(conditions=[t0.key == t1.order_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key, 'key_2': t0.key_2, 'name': t0.name, 'order_date': t0.order_date, 'total_price': t0.total_price}) + JOIN(conditions=[t0.customer_key == t1.key], types=['left'], columns={'key': t0.key, 'key_2': t1.key, 'name': t1.name, 'order_date': t0.order_date, 'total_price': t0.total_price}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate, 'total_price': o_totalprice}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'name': c_name}) + AGGREGATE(keys={'order_key': order_key}, aggregations={'agg_0': SUM(quantity)}) + SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'quantity': l_quantity}) diff --git a/tests/test_plan_refsols/tpch_q19.txt b/tests/test_plan_refsols/tpch_q19.txt new file mode 100644 index 00000000..f8761aee --- /dev/null +++ b/tests/test_plan_refsols/tpch_q19.txt @@ -0,0 +1,7 @@ +ROOT(columns=[('REVENUE', REVENUE)], orderings=[]) + PROJECT(columns={'REVENUE': DEFAULT_TO(agg_0, 0:int64)}) + AGGREGATE(keys={}, aggregations={'agg_0': SUM(extended_price * 1:int64 - discount)}) 
+ FILTER(condition=ISIN(ship_mode, ['AIR', 'AIR REG']:array[unknown]) & ship_instruct == 'DELIVER IN PERSON':string & size >= 1:int64 & size <= 5:int64 & quantity >= 1:int64 & quantity <= 11:int64 & ISIN(container, ['SM CASE', 'SM BOX', 'SM PACK', 'SM PKG']:array[unknown]) & brand == 'Brand#12':string | size <= 10:int64 & quantity >= 10:int64 & quantity <= 20:int64 & ISIN(container, ['MED BAG', 'MED BOX', 'MED PACK', 'MED PKG']:array[unknown]) & brand == 'Brand#23':string | size <= 15:int64 & quantity >= 20:int64 & quantity <= 30:int64 & ISIN(container, ['LG CASE', 'LG BOX', 'LG PACK', 'LG PKG']:array[unknown]) & brand == 'Brand#34':string, columns={'discount': discount, 'extended_price': extended_price}) + JOIN(conditions=[t0.part_key == t1.key], types=['left'], columns={'brand': t1.brand, 'container': t1.container, 'discount': t0.discount, 'extended_price': t0.extended_price, 'quantity': t0.quantity, 'ship_instruct': t0.ship_instruct, 'ship_mode': t0.ship_mode, 'size': t1.size}) + SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'part_key': l_partkey, 'quantity': l_quantity, 'ship_instruct': l_shipinstruct, 'ship_mode': l_shipmode}) + SCAN(table=tpch.PART, columns={'brand': p_brand, 'container': p_container, 'key': p_partkey, 'size': p_size}) diff --git a/tests/test_plan_refsols/tpch_q2.txt b/tests/test_plan_refsols/tpch_q2.txt new file mode 100644 index 00000000..461d2e30 --- /dev/null +++ b/tests/test_plan_refsols/tpch_q2.txt @@ -0,0 +1,31 @@ +ROOT(columns=[('S_ACCTBAL', S_ACCTBAL), ('S_NAME', S_NAME), ('N_NAME', N_NAME), ('P_PARTKEY', P_PARTKEY), ('P_MFGR', P_MFGR), ('S_ADDRESS', S_ADDRESS), ('S_PHONE', S_PHONE), ('S_COMMENT', S_COMMENT)], orderings=[(ordering_1):desc_last, (ordering_2):asc_first, (ordering_3):asc_first, (ordering_4):asc_first]) + LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'N_NAME': N_NAME, 'P_MFGR': P_MFGR, 'P_PARTKEY': P_PARTKEY, 'S_ACCTBAL': S_ACCTBAL, 'S_ADDRESS': S_ADDRESS, 
'S_COMMENT': S_COMMENT, 'S_NAME': S_NAME, 'S_PHONE': S_PHONE, 'ordering_1': ordering_1, 'ordering_2': ordering_2, 'ordering_3': ordering_3, 'ordering_4': ordering_4}, orderings=[(ordering_1):desc_last, (ordering_2):asc_first, (ordering_3):asc_first, (ordering_4):asc_first]) + PROJECT(columns={'N_NAME': N_NAME, 'P_MFGR': P_MFGR, 'P_PARTKEY': P_PARTKEY, 'S_ACCTBAL': S_ACCTBAL, 'S_ADDRESS': S_ADDRESS, 'S_COMMENT': S_COMMENT, 'S_NAME': S_NAME, 'S_PHONE': S_PHONE, 'ordering_1': S_ACCTBAL, 'ordering_2': N_NAME, 'ordering_3': S_NAME, 'ordering_4': P_PARTKEY}) + PROJECT(columns={'N_NAME': n_name, 'P_MFGR': manufacturer, 'P_PARTKEY': key_19, 'S_ACCTBAL': s_acctbal, 'S_ADDRESS': s_address, 'S_COMMENT': s_comment, 'S_NAME': s_name, 'S_PHONE': s_phone}) + FILTER(condition=supplycost_21 == best_cost & ENDSWITH(part_type, 'BRASS':string) & size == 15:int64, columns={'key_19': key_19, 'manufacturer': manufacturer, 'n_name': n_name, 's_acctbal': s_acctbal, 's_address': s_address, 's_comment': s_comment, 's_name': s_name, 's_phone': s_phone}) + JOIN(conditions=[t0.key_9 == t1.key_19], types=['inner'], columns={'best_cost': t0.best_cost, 'key_19': t1.key_19, 'manufacturer': t1.manufacturer, 'n_name': t1.n_name, 'part_type': t1.part_type, 's_acctbal': t1.s_acctbal, 's_address': t1.s_address, 's_comment': t1.s_comment, 's_name': t1.s_name, 's_phone': t1.s_phone, 'size': t1.size, 'supplycost_21': t1.supplycost}) + PROJECT(columns={'best_cost': agg_0, 'key_9': key_9}) + AGGREGATE(keys={'key_9': key_9}, aggregations={'agg_0': MIN(supplycost)}) + FILTER(condition=ENDSWITH(part_type, 'BRASS':string) & size == 15:int64, columns={'key_9': key_9, 'supplycost': supplycost}) + JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'key_9': t1.key, 'part_type': t1.part_type, 'size': t1.size, 'supplycost': t0.supplycost}) + JOIN(conditions=[t0.key_5 == t1.supplier_key], types=['inner'], columns={'part_key': t1.part_key, 'supplycost': t1.supplycost}) + JOIN(conditions=[t0.key == 
t1.nation_key], types=['inner'], columns={'key_5': t1.key}) + FILTER(condition=name_3 == 'EUROPE':string, columns={'key': key}) + JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'key': t0.key, 'name_3': t1.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) + SCAN(table=tpch.PART, columns={'key': p_partkey, 'part_type': p_type, 'size': p_size}) + FILTER(condition=ENDSWITH(part_type, 'BRASS':string) & size == 15:int64, columns={'key_19': key_19, 'manufacturer': manufacturer, 'n_name': n_name, 'part_type': part_type, 's_acctbal': s_acctbal, 's_address': s_address, 's_comment': s_comment, 's_name': s_name, 's_phone': s_phone, 'size': size, 'supplycost': supplycost}) + PROJECT(columns={'key_19': key_19, 'manufacturer': manufacturer, 'n_name': name, 'part_type': part_type, 's_acctbal': account_balance, 's_address': address, 's_comment': comment_14, 's_name': name_16, 's_phone': phone, 'size': size, 'supplycost': supplycost}) + JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'account_balance': t0.account_balance, 'address': t0.address, 'comment_14': t0.comment_14, 'key_19': t1.key, 'manufacturer': t1.manufacturer, 'name': t0.name, 'name_16': t0.name_16, 'part_type': t1.part_type, 'phone': t0.phone, 'size': t1.size, 'supplycost': t0.supplycost}) + JOIN(conditions=[t0.key_15 == t1.supplier_key], types=['inner'], columns={'account_balance': t0.account_balance, 'address': t0.address, 'comment_14': t0.comment_14, 'name': t0.name, 'name_16': t0.name_16, 'part_key': t1.part_key, 'phone': t0.phone, 'supplycost': t1.supplycost}) + JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'account_balance': 
t1.account_balance, 'address': t1.address, 'comment_14': t1.comment, 'key_15': t1.key, 'name': t0.name, 'name_16': t1.name, 'phone': t1.phone}) + FILTER(condition=name_13 == 'EUROPE':string, columns={'key': key, 'name': name}) + JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'key': t0.key, 'name': t0.name, 'name_13': t1.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'address': s_address, 'comment': s_comment, 'key': s_suppkey, 'name': s_name, 'nation_key': s_nationkey, 'phone': s_phone}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) + SCAN(table=tpch.PART, columns={'key': p_partkey, 'manufacturer': p_mfgr, 'part_type': p_type, 'size': p_size}) diff --git a/tests/test_plan_refsols/tpch_q20.txt b/tests/test_plan_refsols/tpch_q20.txt new file mode 100644 index 00000000..92c4420c --- /dev/null +++ b/tests/test_plan_refsols/tpch_q20.txt @@ -0,0 +1,18 @@ +ROOT(columns=[('S_NAME', S_NAME), ('S_ADDRESS', S_ADDRESS)], orderings=[(ordering_1):asc_first]) + LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'S_ADDRESS': S_ADDRESS, 'S_NAME': S_NAME, 'ordering_1': ordering_1}, orderings=[(ordering_1):asc_first]) + PROJECT(columns={'S_ADDRESS': S_ADDRESS, 'S_NAME': S_NAME, 'ordering_1': S_NAME}) + FILTER(condition=name_3 == 'CANADA':string & DEFAULT_TO(agg_0, 0:int64) > 0:int64, columns={'S_ADDRESS': S_ADDRESS, 'S_NAME': S_NAME}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'S_ADDRESS': t0.S_ADDRESS, 'S_NAME': t0.S_NAME, 'agg_0': t1.agg_0, 'name_3': t0.name_3}) + JOIN(conditions=[t0.nation_key == t1.key], types=['left'], columns={'S_ADDRESS': t0.S_ADDRESS, 'S_NAME': t0.S_NAME, 'key': t0.key, 'name_3': t1.name}) + PROJECT(columns={'S_ADDRESS': address, 
'S_NAME': name, 'key': key, 'nation_key': nation_key}) + SCAN(table=tpch.SUPPLIER, columns={'address': s_address, 'key': s_suppkey, 'name': s_name, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=STARTSWITH(name, 'forest':string) & availqty > DEFAULT_TO(agg_0, 0:int64) * 0.5:float64, columns={'supplier_key': supplier_key}) + JOIN(conditions=[t0.key == t1.part_key], types=['left'], columns={'agg_0': t1.agg_0, 'availqty': t0.availqty, 'name': t0.name, 'supplier_key': t0.supplier_key}) + JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'availqty': t0.availqty, 'key': t1.key, 'name': t1.name, 'supplier_key': t0.supplier_key}) + SCAN(table=tpch.PARTSUPP, columns={'availqty': ps_availqty, 'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.PART, columns={'key': p_partkey, 'name': p_name}) + AGGREGATE(keys={'part_key': part_key}, aggregations={'agg_0': SUM(quantity)}) + FILTER(condition=ship_date >= datetime.date(1994, 1, 1):date & ship_date < datetime.date(1995, 1, 1):date, columns={'part_key': part_key, 'quantity': quantity}) + SCAN(table=tpch.LINEITEM, columns={'part_key': l_partkey, 'quantity': l_quantity, 'ship_date': l_shipdate}) diff --git a/tests/test_plan_refsols/tpch_q3.txt b/tests/test_plan_refsols/tpch_q3.txt new file mode 100644 index 00000000..607c49fb --- /dev/null +++ b/tests/test_plan_refsols/tpch_q3.txt @@ -0,0 +1,12 @@ +ROOT(columns=[('L_ORDERKEY', L_ORDERKEY), ('REVENUE', REVENUE), ('O_ORDERDATE', O_ORDERDATE), ('O_SHIPPRIORITY', O_SHIPPRIORITY)], orderings=[(ordering_1):desc_last, (ordering_2):asc_first, (ordering_3):asc_first]) + LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'L_ORDERKEY': L_ORDERKEY, 'O_ORDERDATE': O_ORDERDATE, 'O_SHIPPRIORITY': O_SHIPPRIORITY, 'REVENUE': REVENUE, 'ordering_1': ordering_1, 'ordering_2': ordering_2, 'ordering_3': 
ordering_3}, orderings=[(ordering_1):desc_last, (ordering_2):asc_first, (ordering_3):asc_first]) + PROJECT(columns={'L_ORDERKEY': L_ORDERKEY, 'O_ORDERDATE': O_ORDERDATE, 'O_SHIPPRIORITY': O_SHIPPRIORITY, 'REVENUE': REVENUE, 'ordering_1': REVENUE, 'ordering_2': O_ORDERDATE, 'ordering_3': L_ORDERKEY}) + PROJECT(columns={'L_ORDERKEY': order_key, 'O_ORDERDATE': order_date, 'O_SHIPPRIORITY': ship_priority, 'REVENUE': DEFAULT_TO(agg_0, 0:int64)}) + AGGREGATE(keys={'order_date': order_date, 'order_key': order_key, 'ship_priority': ship_priority}, aggregations={'agg_0': SUM(extended_price * 1:int64 - discount)}) + FILTER(condition=ship_date > datetime.date(1995, 3, 15):date, columns={'discount': discount, 'extended_price': extended_price, 'order_date': order_date, 'order_key': order_key, 'ship_priority': ship_priority}) + JOIN(conditions=[t0.key == t1.order_key], types=['inner'], columns={'discount': t1.discount, 'extended_price': t1.extended_price, 'order_date': t0.order_date, 'order_key': t1.order_key, 'ship_date': t1.ship_date, 'ship_priority': t0.ship_priority}) + FILTER(condition=mktsegment == 'BUILDING':string & order_date < datetime.date(1995, 3, 15):date, columns={'key': key, 'order_date': order_date, 'ship_priority': ship_priority}) + JOIN(conditions=[t0.customer_key == t1.key], types=['left'], columns={'key': t0.key, 'mktsegment': t1.mktsegment, 'order_date': t0.order_date, 'ship_priority': t0.ship_priority}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate, 'ship_priority': o_shippriority}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'mktsegment': c_mktsegment}) + SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'order_key': l_orderkey, 'ship_date': l_shipdate}) diff --git a/tests/test_plan_refsols/tpch_q4.txt b/tests/test_plan_refsols/tpch_q4.txt new file mode 100644 index 00000000..8f1dffe0 --- /dev/null +++ b/tests/test_plan_refsols/tpch_q4.txt @@ 
-0,0 +1,9 @@ +ROOT(columns=[('O_ORDERPRIORITY', O_ORDERPRIORITY), ('ORDER_COUNT', ORDER_COUNT)], orderings=[(ordering_1):asc_first]) + PROJECT(columns={'ORDER_COUNT': ORDER_COUNT, 'O_ORDERPRIORITY': O_ORDERPRIORITY, 'ordering_1': O_ORDERPRIORITY}) + PROJECT(columns={'ORDER_COUNT': DEFAULT_TO(agg_0, 0:int64), 'O_ORDERPRIORITY': order_priority}) + AGGREGATE(keys={'order_priority': order_priority}, aggregations={'agg_0': COUNT()}) + FILTER(condition=order_date >= datetime.date(1993, 7, 1):date & order_date < datetime.date(1993, 10, 1):date & True:bool, columns={'order_priority': order_priority}) + JOIN(conditions=[t0.key == t1.order_key], types=['semi'], columns={'order_date': t0.order_date, 'order_priority': t0.order_priority}) + SCAN(table=tpch.ORDERS, columns={'key': o_orderkey, 'order_date': o_orderdate, 'order_priority': o_orderpriority}) + FILTER(condition=commit_date < receipt_date, columns={'order_key': order_key}) + SCAN(table=tpch.LINEITEM, columns={'commit_date': l_commitdate, 'order_key': l_orderkey, 'receipt_date': l_receiptdate}) diff --git a/tests/test_plan_refsols/tpch_q6.txt b/tests/test_plan_refsols/tpch_q6.txt new file mode 100644 index 00000000..844de11b --- /dev/null +++ b/tests/test_plan_refsols/tpch_q6.txt @@ -0,0 +1,6 @@ +ROOT(columns=[('REVENUE', REVENUE)], orderings=[]) + PROJECT(columns={'REVENUE': DEFAULT_TO(agg_0, 0:int64)}) + AGGREGATE(keys={}, aggregations={'agg_0': SUM(amt)}) + PROJECT(columns={'amt': extended_price * discount}) + FILTER(condition=ship_date >= datetime.date(1994, 1, 1):date & ship_date < datetime.date(1995, 1, 1):date & discount >= 0.05:float64 & discount <= 0.07:float64 & quantity < 24:int64, columns={'discount': discount, 'extended_price': extended_price}) + SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'quantity': l_quantity, 'ship_date': l_shipdate}) diff --git a/tests/test_plan_refsols/tpch_q7.txt b/tests/test_plan_refsols/tpch_q7.txt new file mode 100644 index 
00000000..cd104627 --- /dev/null +++ b/tests/test_plan_refsols/tpch_q7.txt @@ -0,0 +1,17 @@ +ROOT(columns=[('SUPP_NATION', SUPP_NATION), ('CUST_NATION', CUST_NATION), ('L_YEAR', L_YEAR), ('REVENUE', REVENUE)], orderings=[(ordering_1):asc_first, (ordering_2):asc_first, (ordering_3):asc_first]) + PROJECT(columns={'CUST_NATION': CUST_NATION, 'L_YEAR': L_YEAR, 'REVENUE': REVENUE, 'SUPP_NATION': SUPP_NATION, 'ordering_1': SUPP_NATION, 'ordering_2': CUST_NATION, 'ordering_3': L_YEAR}) + PROJECT(columns={'CUST_NATION': cust_nation, 'L_YEAR': l_year, 'REVENUE': DEFAULT_TO(agg_0, 0:int64), 'SUPP_NATION': supp_nation}) + AGGREGATE(keys={'cust_nation': cust_nation, 'l_year': l_year, 'supp_nation': supp_nation}, aggregations={'agg_0': SUM(volume)}) + FILTER(condition=ship_date >= datetime.date(1995, 1, 1):date & ship_date <= datetime.date(1996, 12, 31):date & supp_nation == 'FRANCE':string & cust_nation == 'GERMANY':string | supp_nation == 'GERMANY':string & cust_nation == 'FRANCE':string, columns={'cust_nation': cust_nation, 'l_year': l_year, 'supp_nation': supp_nation, 'volume': volume}) + PROJECT(columns={'cust_nation': name_8, 'l_year': YEAR(ship_date), 'ship_date': ship_date, 'supp_nation': name_3, 'volume': extended_price * 1:int64 - discount}) + JOIN(conditions=[t0.order_key == t1.key], types=['left'], columns={'discount': t0.discount, 'extended_price': t0.extended_price, 'name_3': t0.name_3, 'name_8': t1.name_8, 'ship_date': t0.ship_date}) + JOIN(conditions=[t0.supplier_key == t1.key], types=['left'], columns={'discount': t0.discount, 'extended_price': t0.extended_price, 'name_3': t1.name_3, 'order_key': t0.order_key, 'ship_date': t0.ship_date}) + SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'order_key': l_orderkey, 'ship_date': l_shipdate, 'supplier_key': l_suppkey}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_3': t1.name}) + SCAN(table=tpch.SUPPLIER, columns={'key': 
s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_8': t1.name}) + JOIN(conditions=[t0.customer_key == t1.key], types=['inner'], columns={'key': t0.key, 'nation_key': t1.nation_key}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) diff --git a/tests/test_plan_refsols/tpch_q8.txt b/tests/test_plan_refsols/tpch_q8.txt new file mode 100644 index 00000000..7ae110a9 --- /dev/null +++ b/tests/test_plan_refsols/tpch_q8.txt @@ -0,0 +1,24 @@ +ROOT(columns=[('O_YEAR', O_YEAR), ('MKT_SHARE', MKT_SHARE)], orderings=[]) + PROJECT(columns={'MKT_SHARE': DEFAULT_TO(agg_0, 0:int64) / DEFAULT_TO(agg_1, 0:int64), 'O_YEAR': o_year}) + AGGREGATE(keys={'o_year': o_year}, aggregations={'agg_0': SUM(brazil_volume), 'agg_1': SUM(volume)}) + FILTER(condition=order_date >= datetime.date(1995, 1, 1):date & order_date <= datetime.date(1996, 12, 31):date & name_18 == 'AMERICA':string, columns={'brazil_volume': brazil_volume, 'o_year': o_year, 'volume': volume}) + JOIN(conditions=[t0.customer_key == t1.key], types=['left'], columns={'brazil_volume': t0.brazil_volume, 'name_18': t1.name_18, 'o_year': t0.o_year, 'order_date': t0.order_date, 'volume': t0.volume}) + PROJECT(columns={'brazil_volume': IFF(name == 'BRAZIL':string, volume, 0:int64), 'customer_key': customer_key, 'o_year': YEAR(order_date), 'order_date': order_date, 'volume': volume}) + JOIN(conditions=[t0.order_key == t1.key], types=['inner'], columns={'customer_key': t1.customer_key, 'name': t0.name, 'order_date': t1.order_date, 'volume': t0.volume}) + PROJECT(columns={'name': name, 'order_key': order_key, 'volume': extended_price * 1:int64 - discount}) + JOIN(conditions=[t0.part_key == t1.part_key & 
t0.supplier_key == t1.supplier_key], types=['inner'], columns={'discount': t1.discount, 'extended_price': t1.extended_price, 'name': t0.name, 'order_key': t1.order_key}) + FILTER(condition=part_type == 'ECONOMY ANODIZED STEEL':string, columns={'name': name, 'part_key': part_key, 'supplier_key': supplier_key}) + JOIN(conditions=[t0.part_key == t1.key], types=['left'], columns={'name': t0.name, 'part_key': t0.part_key, 'part_type': t1.part_type, 'supplier_key': t0.supplier_key}) + JOIN(conditions=[t0.key_2 == t1.supplier_key], types=['inner'], columns={'name': t0.name, 'part_key': t1.part_key, 'supplier_key': t1.supplier_key}) + JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) + SCAN(table=tpch.PART, columns={'key': p_partkey, 'part_type': p_type}) + SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'order_key': l_orderkey, 'part_key': l_partkey, 'supplier_key': l_suppkey}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_18': t1.name}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'region_key': t1.region_key}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/tpch_q9.txt b/tests/test_plan_refsols/tpch_q9.txt new file mode 100644 index 00000000..73b07439 --- /dev/null +++ b/tests/test_plan_refsols/tpch_q9.txt @@ -0,0 
+1,18 @@ +ROOT(columns=[('NATION', NATION), ('O_YEAR', O_YEAR), ('AMOUNT', AMOUNT)], orderings=[(ordering_1):asc_first, (ordering_2):desc_last]) + LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'AMOUNT': AMOUNT, 'NATION': NATION, 'O_YEAR': O_YEAR, 'ordering_1': ordering_1, 'ordering_2': ordering_2}, orderings=[(ordering_1):asc_first, (ordering_2):desc_last]) + PROJECT(columns={'AMOUNT': AMOUNT, 'NATION': NATION, 'O_YEAR': O_YEAR, 'ordering_1': NATION, 'ordering_2': O_YEAR}) + PROJECT(columns={'AMOUNT': DEFAULT_TO(agg_0, 0:int64), 'NATION': nation, 'O_YEAR': o_year}) + AGGREGATE(keys={'nation': nation, 'o_year': o_year}, aggregations={'agg_0': SUM(value)}) + PROJECT(columns={'nation': name, 'o_year': YEAR(order_date), 'value': extended_price * 1:int64 - discount - supplycost * quantity}) + JOIN(conditions=[t0.order_key == t1.key], types=['left'], columns={'discount': t0.discount, 'extended_price': t0.extended_price, 'name': t0.name, 'order_date': t1.order_date, 'quantity': t0.quantity, 'supplycost': t0.supplycost}) + JOIN(conditions=[t0.part_key == t1.part_key & t0.supplier_key == t1.supplier_key], types=['inner'], columns={'discount': t1.discount, 'extended_price': t1.extended_price, 'name': t0.name, 'order_key': t1.order_key, 'quantity': t1.quantity, 'supplycost': t0.supplycost}) + FILTER(condition=CONTAINS(name_7, 'green':string), columns={'name': name, 'part_key': part_key, 'supplier_key': supplier_key, 'supplycost': supplycost}) + JOIN(conditions=[t0.part_key == t1.key], types=['left'], columns={'name': t0.name, 'name_7': t1.name, 'part_key': t0.part_key, 'supplier_key': t0.supplier_key, 'supplycost': t0.supplycost}) + JOIN(conditions=[t0.key_2 == t1.supplier_key], types=['inner'], columns={'name': t0.name, 'part_key': t1.part_key, 'supplier_key': t1.supplier_key, 'supplycost': t1.supplycost}) + JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name}) + SCAN(table=tpch.NATION, columns={'key': 
n_nationkey, 'name': n_name}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) + SCAN(table=tpch.PART, columns={'key': p_partkey, 'name': p_name}) + SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'order_key': l_orderkey, 'part_key': l_partkey, 'quantity': l_quantity, 'supplier_key': l_suppkey}) + SCAN(table=tpch.ORDERS, columns={'key': o_orderkey, 'order_date': o_orderdate}) diff --git a/tests/test_plan_refsols/triple_partition.txt b/tests/test_plan_refsols/triple_partition.txt new file mode 100644 index 00000000..d078683d --- /dev/null +++ b/tests/test_plan_refsols/triple_partition.txt @@ -0,0 +1,30 @@ +ROOT(columns=[('supp_region', supp_region), ('avg_percentage', avg_percentage)], orderings=[(ordering_1):asc_first]) + PROJECT(columns={'avg_percentage': avg_percentage, 'ordering_1': supp_region, 'supp_region': supp_region}) + PROJECT(columns={'avg_percentage': agg_0, 'supp_region': supp_region}) + AGGREGATE(keys={'supp_region': supp_region}, aggregations={'agg_0': AVG(percentage)}) + PROJECT(columns={'percentage': 100.0:float64 * agg_0 / DEFAULT_TO(agg_1, 0:int64), 'supp_region': supp_region}) + AGGREGATE(keys={'cust_region': cust_region, 'supp_region': supp_region}, aggregations={'agg_0': MAX(n_instances), 'agg_1': SUM(n_instances)}) + PROJECT(columns={'cust_region': cust_region, 'n_instances': DEFAULT_TO(agg_0, 0:int64), 'supp_region': supp_region}) + AGGREGATE(keys={'cust_region': cust_region, 'part_type': part_type, 'supp_region': supp_region}, aggregations={'agg_0': COUNT()}) + PROJECT(columns={'cust_region': name_15, 'part_type': part_type, 'supp_region': supp_region}) + JOIN(conditions=[t0.customer_key == t1.key], types=['left'], columns={'name_15': t1.name_15, 'part_type': t0.part_type, 'supp_region': t0.supp_region}) + FILTER(condition=YEAR(order_date) == 
1992:int64, columns={'customer_key': customer_key, 'part_type': part_type, 'supp_region': supp_region}) + JOIN(conditions=[t0.order_key == t1.key], types=['inner'], columns={'customer_key': t1.customer_key, 'order_date': t1.order_date, 'part_type': t0.part_type, 'supp_region': t0.supp_region}) + PROJECT(columns={'order_key': order_key, 'part_type': part_type, 'supp_region': name_7}) + JOIN(conditions=[t0.supplier_key == t1.key], types=['left'], columns={'name_7': t1.name_7, 'order_key': t0.order_key, 'part_type': t0.part_type}) + FILTER(condition=MONTH(ship_date) == 6:int64 & YEAR(ship_date) == 1992:int64, columns={'order_key': order_key, 'part_type': part_type, 'supplier_key': supplier_key}) + JOIN(conditions=[t0.key == t1.part_key], types=['inner'], columns={'order_key': t1.order_key, 'part_type': t0.part_type, 'ship_date': t1.ship_date, 'supplier_key': t1.supplier_key}) + FILTER(condition=STARTSWITH(container, 'SM':string), columns={'key': key, 'part_type': part_type}) + SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'part_type': p_type}) + SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'part_key': l_partkey, 'ship_date': l_shipdate, 'supplier_key': l_suppkey}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_7': t1.name}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'region_key': t1.region_key}) + SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_15': t1.name}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'region_key': 
t1.region_key}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_qdag_conversion.py b/tests/test_qdag_conversion.py index 27fc6f8a..8997ada5 100644 --- a/tests/test_qdag_conversion.py +++ b/tests/test_qdag_conversion.py @@ -2189,8 +2189,8 @@ def relational_test_data(request) -> tuple[CollectionTestInfo, str]: """ Input data for `test_ast_to_relational`. Parameters are the info to build - the input QDAG nodes, and the expected output string after converting to a - relational tree. + the input QDAG nodes, and the name of the file containing the expected + output string after converting to a relational tree. """ return request.param @@ -2324,8 +2324,9 @@ def test_ast_to_relational( def relational_alternative_config_test_data(request) -> tuple[CollectionTestInfo, str]: """ Input data for `test_ast_to_relational_alternative_aggregation_configs`. - Parameters are the info to build the input QDAG nodes, and the expected - output string after converting to a relational tree. + Parameters are the info to build the input QDAG nodes, and the name of the + file containing the expected output string after converting to a relational + tree. 
""" return request.param From 64311bcb3f932655212666346e3952d591145440 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 3 Feb 2025 14:25:20 -0500 Subject: [PATCH 083/112] [RUN CI] From 8d6878f60c2f3f00617cd75bf9488e77bafd30c4 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 3 Feb 2025 14:32:36 -0500 Subject: [PATCH 084/112] Adding comments [RUN CI] --- pydough/conversion/relational_converter.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index 981f74b7..f597fdc6 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -128,7 +128,16 @@ def get_column_name( self, name: str, existing_names: dict[str, RelationalExpression] ) -> str: """ - TODO + Replaces a name for a new column with another name if the name is + already being used. + + Args: + `name`: the name of the column to be replaced. + `existing_names`: the dictionary of existing column names that + are already being used in the current relational tree. + + Returns: + A string based on `name` that is not part of `existing_names`. """ new_name: str = name while new_name in existing_names: @@ -703,6 +712,8 @@ def translate_calc( rel_expr: RelationalExpression = self.translate_expression( hybrid_expr, context ) + # Ensure the name of the new column is not already being used. If + # it is, choose a new name. if name in proj_columns and proj_columns[name] != rel_expr: name = self.get_column_name(name, proj_columns) proj_columns[name] = rel_expr @@ -929,7 +940,7 @@ def convert_ast_to_relational( # Convert the QDAG node to the hybrid form, then invoke the relational # conversion procedure. The first element in the returned list is the # final rel node. 
- hybrid: HybridTree = HybridTranslator(configs).make_hybrid_tree(node, None) + hybrid: HybridTree = HybridTranslator(configs).make_hybrid_tree(node) renamings: dict[str, str] = hybrid.pipeline[-1].renamings output: TranslationOutput = translator.rel_translation( None, hybrid, len(hybrid.pipeline) - 1 From 6949cdd67cc953a3c89c942e5255c451714b71db Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Tue, 4 Feb 2025 11:37:46 -0500 Subject: [PATCH 085/112] Revisions [RUN CI] --- pydough/conversion/relational_converter.py | 45 ++++++++++--------- pydough/relational/__init__.py | 4 +- .../relational/relational_nodes/__init__.py | 4 +- .../relational_nodes/abstract_node.py | 18 ++++---- .../relational/relational_nodes/aggregate.py | 10 ++--- .../relational_nodes/column_pruner.py | 14 +++--- .../relational_nodes/empty_singleton.py | 12 ++--- pydough/relational/relational_nodes/filter.py | 10 ++--- pydough/relational/relational_nodes/join.py | 14 +++--- pydough/relational/relational_nodes/limit.py | 10 ++--- .../relational/relational_nodes/project.py | 10 ++--- .../relational_expression_dispatcher.py | 4 +- .../relational_nodes/relational_root.py | 10 ++--- pydough/relational/relational_nodes/scan.py | 12 ++--- .../relational_nodes/single_relational.py | 14 +++--- tests/test_relational.py | 20 +++++---- tests/test_relational_execution.py | 12 ++--- tests/test_relational_nodes_to_sqlglot.py | 4 +- 18 files changed, 119 insertions(+), 108 deletions(-) diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index f597fdc6..55de7714 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -35,8 +35,8 @@ Limit, LiteralExpression, Project, - Relational, RelationalExpression, + RelationalNode, RelationalRoot, Scan, WindowCallExpression, @@ -77,7 +77,7 @@ class TranslationOutput: access any HybridExpr's equivalent expression in the Relational node. 
""" - relation: Relational + relational_node: RelationalNode """ The relational tree describing the way to compute the answer for the logic originally in the hybrid tree. @@ -96,7 +96,7 @@ def __init__(self): # An index used for creating fake column names self.dummy_idx = 1 - def make_null_column(self, relation: Relational) -> ColumnReference: + def make_null_column(self, relation: RelationalNode) -> ColumnReference: """ Inserts a new column into the relation whose value is NULL. If such a column already exists, it is used. @@ -237,7 +237,7 @@ def join_outputs( # Special case: if the lhs is an EmptySingleton, just return the RHS, # decorated if needed. - if isinstance(lhs_result.relation, EmptySingleton): + if isinstance(lhs_result.relational_node, EmptySingleton): if child_idx is None: return rhs_result else: @@ -247,13 +247,13 @@ def join_outputs( expr.name, child_idx, expr.typ ) out_columns[child_ref] = col_ref - return TranslationOutput(rhs_result.relation, out_columns) + return TranslationOutput(rhs_result.relational_node, out_columns) # Create the join node so we know what aliases it uses, but leave # the condition as always-True and the output columns empty for now. # The condition & output columns will be filled in later. 
out_rel: Join = Join( - [lhs_result.relation, rhs_result.relation], + [lhs_result.relational_node, rhs_result.relational_node], [LiteralExpression(True, BooleanType())], [join_type], join_columns, @@ -373,7 +373,7 @@ def apply_aggregations( aggregations[agg_name] = CallExpression( agg_func.operator, agg_func.typ, args ) - out_rel: Relational = Aggregate(context.relation, keys, aggregations) + out_rel: RelationalNode = Aggregate(context.relational_node, keys, aggregations) return TranslationOutput(out_rel, out_columns) def handle_children( @@ -442,7 +442,7 @@ def handle_children( ) # Map every child_idx reference from child_output to null null_column: ColumnReference = self.make_null_column( - context.relation + context.relational_node ) for expr in child_output.expressions: if isinstance(expr, HybridRefExpr): @@ -601,7 +601,9 @@ def translate_partition( shifted_expr: HybridExpr | None = expr.shift_back(1) if shifted_expr is not None: expressions[shifted_expr] = ref - result: TranslationOutput = TranslationOutput(context.relation, expressions) + result: TranslationOutput = TranslationOutput( + context.relational_node, expressions + ) result = self.handle_children(result, hybrid, pipeline_idx) # Pull every aggregation key into the current context since it is now # accessible as a normal ref instead of a child ref. @@ -632,13 +634,13 @@ def translate_filter( """ # Keep all existing columns. 
kept_columns: dict[str, RelationalExpression] = { - name: ColumnReference(name, context.relation.columns[name].data_type) - for name in context.relation.columns + name: ColumnReference(name, context.relational_node.columns[name].data_type) + for name in context.relational_node.columns } condition: RelationalExpression = self.translate_expression( node.condition, context ) - out_rel: Filter = Filter(context.relation, condition, kept_columns) + out_rel: Filter = Filter(context.relational_node, condition, kept_columns) return TranslationOutput(out_rel, context.expressions) def translate_limit( @@ -662,8 +664,8 @@ def translate_limit( """ # Keep all existing columns. kept_columns: dict[str, RelationalExpression] = { - name: ColumnReference(name, context.relation.columns[name].data_type) - for name in context.relation.columns + name: ColumnReference(name, context.relational_node.columns[name].data_type) + for name in context.relational_node.columns } limit_expr: LiteralExpression = LiteralExpression( node.records_to_keep, Int64Type() @@ -671,7 +673,9 @@ def translate_limit( orderings: list[ExpressionSortInfo] = make_relational_ordering( node.orderings, context.expressions ) - out_rel: Limit = Limit(context.relation, limit_expr, kept_columns, orderings) + out_rel: Limit = Limit( + context.relational_node, limit_expr, kept_columns, orderings + ) return TranslationOutput(out_rel, context.expressions) def translate_calc( @@ -697,9 +701,9 @@ def translate_calc( proj_columns: dict[str, RelationalExpression] = {} out_columns: dict[HybridExpr, ColumnReference] = {} # Propagate all of the existing columns. 
- for name in context.relation.columns: + for name in context.relational_node.columns: proj_columns[name] = ColumnReference( - name, context.relation.columns[name].data_type + name, context.relational_node.columns[name].data_type ) for expr in context.expressions: out_columns[expr] = context.expressions[expr].with_input(None) @@ -713,12 +717,13 @@ def translate_calc( hybrid_expr, context ) # Ensure the name of the new column is not already being used. If - # it is, choose a new name. + # it is, choose a new name. The new name will be the original name + # with a numerical index appended to it. if name in proj_columns and proj_columns[name] != rel_expr: name = self.get_column_name(name, proj_columns) proj_columns[name] = rel_expr out_columns[ref_expr] = ColumnReference(name, rel_expr.data_type) - out_rel: Project = Project(context.relation, proj_columns) + out_rel: Project = Project(context.relational_node, proj_columns) return TranslationOutput(out_rel, out_columns) def translate_partition_child( @@ -964,6 +969,6 @@ def convert_ast_to_relational( if hybrid_orderings: orderings = make_relational_ordering(hybrid_orderings, output.expressions) unpruned_result: RelationalRoot = RelationalRoot( - output.relation, ordered_columns, orderings + output.relational_node, ordered_columns, orderings ) return ColumnPruner().prune_unused_columns(unpruned_result) diff --git a/pydough/relational/__init__.py b/pydough/relational/__init__.py index b79fe424..75591d4f 100644 --- a/pydough/relational/__init__.py +++ b/pydough/relational/__init__.py @@ -14,10 +14,10 @@ "Limit", "LiteralExpression", "Project", - "Relational", "RelationalExpression", "RelationalExpressionDispatcher", "RelationalExpressionVisitor", + "RelationalNode", "RelationalRoot", "RelationalVisitor", "Scan", @@ -45,8 +45,8 @@ JoinType, Limit, Project, - Relational, RelationalExpressionDispatcher, + RelationalNode, RelationalRoot, RelationalVisitor, Scan, diff --git a/pydough/relational/relational_nodes/__init__.py 
b/pydough/relational/relational_nodes/__init__.py index 08fa2d4c..96e5a474 100644 --- a/pydough/relational/relational_nodes/__init__.py +++ b/pydough/relational/relational_nodes/__init__.py @@ -12,13 +12,13 @@ "JoinType", "Limit", "Project", - "Relational", "RelationalExpressionDispatcher", + "RelationalNode", "RelationalRoot", "RelationalVisitor", "Scan", ] -from .abstract_node import Relational +from .abstract_node import RelationalNode from .aggregate import Aggregate from .column_pruner import ColumnPruner from .empty_singleton import EmptySingleton diff --git a/pydough/relational/relational_nodes/abstract_node.py b/pydough/relational/relational_nodes/abstract_node.py index 29e36de9..61c4e0ef 100644 --- a/pydough/relational/relational_nodes/abstract_node.py +++ b/pydough/relational/relational_nodes/abstract_node.py @@ -14,7 +14,7 @@ from .relational_visitor import RelationalVisitor -class Relational(ABC): +class RelationalNode(ABC): """ The base class for any relational node. This interface defines the basic structure of all relational nodes in the PyDough system. @@ -25,7 +25,7 @@ def __init__(self, columns: MutableMapping[str, RelationalExpression]) -> None: @property @abstractmethod - def inputs(self) -> MutableSequence["Relational"]: + def inputs(self) -> MutableSequence["RelationalNode"]: """ Returns any inputs to the current relational expression. @@ -60,7 +60,7 @@ def columns(self) -> MutableMapping[str, RelationalExpression]: return self._columns @abstractmethod - def node_equals(self, other: "Relational") -> bool: + def node_equals(self, other: "RelationalNode") -> bool: """ Determine if two relational nodes are exactly identical, excluding column generic column details shared by every @@ -74,7 +74,7 @@ def node_equals(self, other: "Relational") -> bool: bool: Are the two relational nodes equal. 
""" - def equals(self, other: "Relational") -> bool: + def equals(self, other: "RelationalNode") -> bool: """ Determine if two relational nodes are exactly identical, including column ordering. @@ -88,7 +88,7 @@ def equals(self, other: "Relational") -> bool: return self.node_equals(other) and self.columns == other.columns def __eq__(self, other: Any) -> bool: - return isinstance(other, Relational) and self.equals(other) + return isinstance(other, RelationalNode) and self.equals(other) def make_column_string( self, columns: MutableMapping[str, Any], compact: bool @@ -149,8 +149,8 @@ def accept(self, visitor: RelationalVisitor) -> None: def node_copy( self, columns: MutableMapping[str, RelationalExpression], - inputs: MutableSequence["Relational"], - ) -> "Relational": + inputs: MutableSequence["RelationalNode"], + ) -> "RelationalNode": """ Copy the given relational node with the provided columns and/or inputs. This copy maintains any additional properties of the @@ -169,8 +169,8 @@ def node_copy( def copy( self, columns: MutableMapping[str, RelationalExpression] | None = None, - inputs: MutableSequence["Relational"] | None = None, - ) -> "Relational": + inputs: MutableSequence["RelationalNode"] | None = None, + ) -> "RelationalNode": """ Copy the given relational node with the provided columns and/or inputs. 
This copy maintains any additional properties of the diff --git a/pydough/relational/relational_nodes/aggregate.py b/pydough/relational/relational_nodes/aggregate.py index 18151991..d2a2046e 100644 --- a/pydough/relational/relational_nodes/aggregate.py +++ b/pydough/relational/relational_nodes/aggregate.py @@ -12,7 +12,7 @@ RelationalExpression, ) -from .abstract_node import Relational +from .abstract_node import RelationalNode from .relational_visitor import RelationalVisitor from .single_relational import SingleRelational @@ -26,7 +26,7 @@ class Aggregate(SingleRelational): def __init__( self, - input: Relational, + input: RelationalNode, keys: MutableMapping[str, ColumnReference], aggregations: MutableMapping[str, CallExpression], ) -> None: @@ -55,7 +55,7 @@ def aggregations(self) -> MutableMapping[str, CallExpression]: """ return self._aggregations - def node_equals(self, other: Relational) -> bool: + def node_equals(self, other: RelationalNode) -> bool: return ( isinstance(other, Aggregate) and self.keys == other.keys @@ -72,8 +72,8 @@ def accept(self, visitor: RelationalVisitor) -> None: def node_copy( self, columns: MutableMapping[str, RelationalExpression], - inputs: MutableSequence[Relational], - ) -> Relational: + inputs: MutableSequence[RelationalNode], + ) -> RelationalNode: assert len(inputs) == 1, "Aggregate node should have exactly one input" # Aggregate nodes don't cleanly map to the existing columns API. 
# We still fulfill it as much as possible by mapping all column diff --git a/pydough/relational/relational_nodes/column_pruner.py b/pydough/relational/relational_nodes/column_pruner.py index 8eee34c5..91f7f476 100644 --- a/pydough/relational/relational_nodes/column_pruner.py +++ b/pydough/relational/relational_nodes/column_pruner.py @@ -7,7 +7,7 @@ ColumnReferenceFinder, ) -from .abstract_node import Relational +from .abstract_node import RelationalNode from .aggregate import Aggregate from .project import Project from .relational_expression_dispatcher import RelationalExpressionDispatcher @@ -25,7 +25,7 @@ def __init__(self) -> None: self._column_finder, recurse=False ) - def _prune_identity_project(self, node: Relational) -> Relational: + def _prune_identity_project(self, node: RelationalNode) -> RelationalNode: """ Remove a projection and return the input if it is an identity projection. @@ -42,8 +42,8 @@ def _prune_identity_project(self, node: Relational) -> Relational: return node def _prune_node_columns( - self, node: Relational, kept_columns: set[str] - ) -> Relational: + self, node: RelationalNode, kept_columns: set[str] + ) -> RelationalNode: """ Prune the columns for a subtree starting at this node. @@ -89,7 +89,7 @@ def _prune_node_columns( ) ) # Determine which identifiers to pass to each input. - new_inputs: list[Relational] = [] + new_inputs: list[RelationalNode] = [] # Note: The ColumnPruner should only be run when all input names are # still present in the columns. for i, default_input_name in enumerate(new_node.default_input_aliases): @@ -112,6 +112,8 @@ def prune_unused_columns(self, root: RelationalRoot) -> RelationalRoot: Returns: RelationalRoot: The root after updating all inputs. """ - new_root: Relational = self._prune_node_columns(root, set(root.columns.keys())) + new_root: RelationalNode = self._prune_node_columns( + root, set(root.columns.keys()) + ) assert isinstance(new_root, RelationalRoot), "Expected a root node." 
return new_root diff --git a/pydough/relational/relational_nodes/empty_singleton.py b/pydough/relational/relational_nodes/empty_singleton.py index 1aad1a23..6dbe714f 100644 --- a/pydough/relational/relational_nodes/empty_singleton.py +++ b/pydough/relational/relational_nodes/empty_singleton.py @@ -12,11 +12,11 @@ RelationalExpression, ) -from .abstract_node import Relational +from .abstract_node import RelationalNode from .relational_visitor import RelationalVisitor -class EmptySingleton(Relational): +class EmptySingleton(RelationalNode): """ A node in the relational tree representing a constant table with 1 row and 0 columns, for use in cases such as `SELECT 42 as A from (VALUES())` @@ -26,10 +26,10 @@ def __init__(self) -> None: super().__init__({}) @property - def inputs(self) -> MutableSequence[Relational]: + def inputs(self) -> MutableSequence[RelationalNode]: return [] - def node_equals(self, other: Relational) -> bool: + def node_equals(self, other: RelationalNode) -> bool: return isinstance(other, EmptySingleton) def to_string(self, compact: bool = False) -> str: @@ -41,8 +41,8 @@ def accept(self, visitor: RelationalVisitor) -> None: def node_copy( self, columns: MutableMapping[str, RelationalExpression], - inputs: MutableSequence[Relational], - ) -> Relational: + inputs: MutableSequence[RelationalNode], + ) -> RelationalNode: assert len(columns) == 0, "EmptySingleton has no columns" assert len(inputs) == 0, "EmptySingleton has no inputs" return EmptySingleton() diff --git a/pydough/relational/relational_nodes/filter.py b/pydough/relational/relational_nodes/filter.py index a8fdb7a8..d8bd6036 100644 --- a/pydough/relational/relational_nodes/filter.py +++ b/pydough/relational/relational_nodes/filter.py @@ -9,7 +9,7 @@ from pydough.relational.relational_expressions import RelationalExpression from pydough.types.boolean_type import BooleanType -from .abstract_node import Relational +from .abstract_node import RelationalNode from .relational_visitor import 
RelationalVisitor from .single_relational import SingleRelational @@ -22,7 +22,7 @@ class Filter(SingleRelational): def __init__( self, - input: Relational, + input: RelationalNode, condition: RelationalExpression, columns: MutableMapping[str, RelationalExpression], ) -> None: @@ -39,7 +39,7 @@ def condition(self) -> RelationalExpression: """ return self._condition - def node_equals(self, other: Relational) -> bool: + def node_equals(self, other: RelationalNode) -> bool: return ( isinstance(other, Filter) and self.condition == other.condition @@ -55,7 +55,7 @@ def accept(self, visitor: RelationalVisitor) -> None: def node_copy( self, columns: MutableMapping[str, RelationalExpression], - inputs: MutableSequence[Relational], - ) -> Relational: + inputs: MutableSequence[RelationalNode], + ) -> RelationalNode: assert len(inputs) == 1, "Filter node should have exactly one input" return Filter(inputs[0], self.condition, columns) diff --git a/pydough/relational/relational_nodes/join.py b/pydough/relational/relational_nodes/join.py index 3114be15..f1629689 100644 --- a/pydough/relational/relational_nodes/join.py +++ b/pydough/relational/relational_nodes/join.py @@ -9,7 +9,7 @@ from pydough.relational.relational_expressions import RelationalExpression from pydough.types.boolean_type import BooleanType -from .abstract_node import Relational +from .abstract_node import RelationalNode from .relational_visitor import RelationalVisitor @@ -22,7 +22,7 @@ class JoinType(Enum): SEMI = "semi" -class Join(Relational): +class Join(RelationalNode): """ Relational representation of all join operations. 
This single node can represent multiple joins at once, similar to a multi-join @@ -45,7 +45,7 @@ class Join(Relational): def __init__( self, - inputs: MutableSequence[Relational], + inputs: MutableSequence[RelationalNode], conditions: list[RelationalExpression], join_types: list[JoinType], columns: MutableMapping[str, RelationalExpression], @@ -81,7 +81,7 @@ def join_types(self) -> list[JoinType]: return self._join_types @property - def inputs(self) -> MutableSequence[Relational]: + def inputs(self) -> MutableSequence[RelationalNode]: return self._inputs @property @@ -96,7 +96,7 @@ def default_input_aliases(self) -> list[str | None]: """ return [f"t{i}" for i in range(len(self.inputs))] - def node_equals(self, other: Relational) -> bool: + def node_equals(self, other: RelationalNode) -> bool: return ( isinstance(other, Join) and self.conditions == other.conditions @@ -117,6 +117,6 @@ def accept(self, visitor: RelationalVisitor) -> None: def node_copy( self, columns: MutableMapping[str, RelationalExpression], - inputs: MutableSequence[Relational], - ) -> Relational: + inputs: MutableSequence[RelationalNode], + ) -> RelationalNode: return Join(inputs, self.conditions, self.join_types, columns) diff --git a/pydough/relational/relational_nodes/limit.py b/pydough/relational/relational_nodes/limit.py index e6a86f96..2cc03ed4 100644 --- a/pydough/relational/relational_nodes/limit.py +++ b/pydough/relational/relational_nodes/limit.py @@ -12,7 +12,7 @@ ) from pydough.types.integer_types import IntegerType -from .abstract_node import Relational +from .abstract_node import RelationalNode from .relational_visitor import RelationalVisitor from .single_relational import SingleRelational @@ -26,7 +26,7 @@ class Limit(SingleRelational): def __init__( self, - input: Relational, + input: RelationalNode, limit: RelationalExpression, columns: MutableMapping[str, RelationalExpression], orderings: MutableSequence[ExpressionSortInfo] | None = None, @@ -57,7 +57,7 @@ def orderings(self) 
-> MutableSequence[ExpressionSortInfo]: """ return self._orderings - def node_equals(self, other: Relational) -> bool: + def node_equals(self, other: RelationalNode) -> bool: return ( isinstance(other, Limit) and self.limit == other.limit @@ -77,7 +77,7 @@ def accept(self, visitor: RelationalVisitor) -> None: def node_copy( self, columns: MutableMapping[str, RelationalExpression], - inputs: MutableSequence[Relational], - ) -> Relational: + inputs: MutableSequence[RelationalNode], + ) -> RelationalNode: assert len(inputs) == 1, "Limit node should have exactly one input" return Limit(inputs[0], self.limit, columns, self.orderings) diff --git a/pydough/relational/relational_nodes/project.py b/pydough/relational/relational_nodes/project.py index 25fbbf1c..c2413dff 100644 --- a/pydough/relational/relational_nodes/project.py +++ b/pydough/relational/relational_nodes/project.py @@ -13,7 +13,7 @@ RelationalExpression, ) -from .abstract_node import Relational +from .abstract_node import RelationalNode from .relational_visitor import RelationalVisitor from .single_relational import SingleRelational @@ -27,12 +27,12 @@ class Project(SingleRelational): def __init__( self, - input: Relational, + input: RelationalNode, columns: MutableMapping[str, RelationalExpression], ) -> None: super().__init__(input, columns) - def node_equals(self, other: Relational) -> bool: + def node_equals(self, other: RelationalNode) -> bool: return isinstance(other, Project) and super().node_equals(other) def to_string(self, compact: bool = False) -> str: @@ -56,7 +56,7 @@ def is_identity(self) -> bool: def node_copy( self, columns: MutableMapping[str, RelationalExpression], - inputs: MutableSequence[Relational], - ) -> Relational: + inputs: MutableSequence[RelationalNode], + ) -> RelationalNode: assert len(inputs) == 1, "Project node should have exactly one input" return Project(inputs[0], columns) diff --git a/pydough/relational/relational_nodes/relational_expression_dispatcher.py 
b/pydough/relational/relational_nodes/relational_expression_dispatcher.py index 6a581fca..fb6fb6ff 100644 --- a/pydough/relational/relational_nodes/relational_expression_dispatcher.py +++ b/pydough/relational/relational_nodes/relational_expression_dispatcher.py @@ -7,7 +7,7 @@ RelationalExpressionVisitor, ) -from .abstract_node import Relational +from .abstract_node import RelationalNode from .aggregate import Aggregate from .empty_singleton import EmptySingleton from .filter import Filter @@ -38,7 +38,7 @@ def reset(self) -> None: def get_expr_visitor(self) -> RelationalExpressionVisitor: return self._expr_visitor - def visit_common(self, node: Relational) -> None: + def visit_common(self, node: RelationalNode) -> None: """ Applies a visit common to each node. """ diff --git a/pydough/relational/relational_nodes/relational_root.py b/pydough/relational/relational_nodes/relational_root.py index 17569d05..c4c31285 100644 --- a/pydough/relational/relational_nodes/relational_root.py +++ b/pydough/relational/relational_nodes/relational_root.py @@ -11,7 +11,7 @@ RelationalExpression, ) -from .abstract_node import Relational +from .abstract_node import RelationalNode from .relational_visitor import RelationalVisitor from .single_relational import SingleRelational @@ -25,7 +25,7 @@ class RelationalRoot(SingleRelational): def __init__( self, - input: Relational, + input: RelationalNode, ordered_columns: MutableSequence[tuple[str, RelationalExpression]], orderings: MutableSequence[ExpressionSortInfo] | None = None, ) -> None: @@ -56,7 +56,7 @@ def orderings(self) -> MutableSequence[ExpressionSortInfo]: """ return self._orderings - def node_equals(self, other: Relational) -> bool: + def node_equals(self, other: RelationalNode) -> bool: return ( isinstance(other, RelationalRoot) and self.ordered_columns == other.ordered_columns @@ -82,8 +82,8 @@ def accept(self, visitor: RelationalVisitor) -> None: def node_copy( self, columns: MutableMapping[str, RelationalExpression], - 
inputs: MutableSequence[Relational], - ) -> Relational: + inputs: MutableSequence[RelationalNode], + ) -> RelationalNode: assert len(inputs) == 1, "Root node should have exactly one input" assert columns == self.columns, "Root columns should not be modified" return RelationalRoot(inputs[0], self.ordered_columns, self.orderings) diff --git a/pydough/relational/relational_nodes/scan.py b/pydough/relational/relational_nodes/scan.py index 1ee0fe17..d41e2979 100644 --- a/pydough/relational/relational_nodes/scan.py +++ b/pydough/relational/relational_nodes/scan.py @@ -11,11 +11,11 @@ class for more specific implementations. RelationalExpression, ) -from .abstract_node import Relational +from .abstract_node import RelationalNode from .relational_visitor import RelationalVisitor -class Scan(Relational): +class Scan(RelationalNode): """ The Scan node in the relational tree. Right now these refer to tables stored within a provided database connection with is assumed to be singular @@ -29,11 +29,11 @@ def __init__( self.table_name: str = table_name @property - def inputs(self) -> MutableSequence[Relational]: + def inputs(self) -> MutableSequence[RelationalNode]: # A scan is required to be the leaf node of the relational tree. 
return [] - def node_equals(self, other: Relational) -> bool: + def node_equals(self, other: RelationalNode) -> bool: return isinstance(other, Scan) and self.table_name == other.table_name def accept(self, visitor: RelationalVisitor) -> None: @@ -45,7 +45,7 @@ def to_string(self, compact=False) -> str: def node_copy( self, columns: MutableMapping[str, RelationalExpression], - inputs: MutableSequence[Relational], - ) -> Relational: + inputs: MutableSequence[RelationalNode], + ) -> RelationalNode: assert not inputs, "Scan node should have 0 inputs" return Scan(self.table_name, columns) diff --git a/pydough/relational/relational_nodes/single_relational.py b/pydough/relational/relational_nodes/single_relational.py index 8efdaa5a..b41b1849 100644 --- a/pydough/relational/relational_nodes/single_relational.py +++ b/pydough/relational/relational_nodes/single_relational.py @@ -7,31 +7,31 @@ from pydough.relational.relational_expressions import RelationalExpression -from .abstract_node import Relational +from .abstract_node import RelationalNode -class SingleRelational(Relational): +class SingleRelational(RelationalNode): """ Base abstract class for relational nodes that have a single input. """ def __init__( self, - input: Relational, + input: RelationalNode, columns: MutableMapping[str, RelationalExpression], ) -> None: super().__init__(columns) - self._input: Relational = input + self._input: RelationalNode = input @property - def inputs(self) -> MutableSequence[Relational]: + def inputs(self) -> MutableSequence[RelationalNode]: return [self._input] @property - def input(self) -> Relational: + def input(self) -> RelationalNode: return self._input - def node_equals(self, other: Relational) -> bool: + def node_equals(self, other: RelationalNode) -> bool: """ Determine if two relational nodes are exactly identical, excluding column ordering. 
This should be extended to avoid diff --git a/tests/test_relational.py b/tests/test_relational.py index 42f3fe49..41380e2c 100644 --- a/tests/test_relational.py +++ b/tests/test_relational.py @@ -22,7 +22,7 @@ Limit, LiteralExpression, Project, - Relational, + RelationalNode, RelationalRoot, Scan, ) @@ -136,7 +136,9 @@ def test_scan_to_string(scan_node: Scan, output: str) -> None: ), ], ) -def test_scan_equals(first_scan: Scan, second_scan: Relational, output: bool) -> None: +def test_scan_equals( + first_scan: Scan, second_scan: RelationalNode, output: bool +) -> None: """ Tests the equality functionality for the Scan node. """ @@ -278,7 +280,7 @@ def test_empty_singleton_equals() -> None: ], ) def test_project_equals( - first_project: Project, second_project: Relational, output: bool + first_project: Project, second_project: RelationalNode, output: bool ) -> None: """ Tests the equality functionality for the Project node. @@ -487,7 +489,7 @@ def test_limit_to_string(limit: Limit, output: str) -> None: ], ) def test_limit_equals( - first_limit: Limit, second_limit: Relational, output: bool + first_limit: Limit, second_limit: RelationalNode, output: bool ) -> None: """ Tests the equality functionality for the Limit node. @@ -763,7 +765,7 @@ def test_aggregate_to_string(agg: Aggregate, output: str) -> None: ], ) def test_aggregate_equals( - first_agg: Aggregate, second_agg: Relational, output: bool + first_agg: Aggregate, second_agg: RelationalNode, output: bool ) -> None: """ Tests the equality functionality for the Aggregate node. @@ -968,7 +970,7 @@ def test_filter_to_string(filter: Filter, output: str) -> None: ], ) def test_filter_equals( - first_filter: Filter, second_filter: Relational, output: bool + first_filter: Filter, second_filter: RelationalNode, output: bool ) -> None: """ Tests the equality functionality for the Filter node. 
@@ -1211,7 +1213,7 @@ def test_root_to_string(root: RelationalRoot, output: str) -> None: ], ) def test_root_equals( - first_root: RelationalRoot, second_root: Relational, output: bool + first_root: RelationalRoot, second_root: RelationalNode, output: bool ) -> None: """ Tests the equality functionality for the Root node. @@ -1916,7 +1918,9 @@ def test_join_to_string(join: Join, output: str) -> None: ), ], ) -def test_join_equals(first_join: Join, second_join: Relational, output: bool) -> None: +def test_join_equals( + first_join: Join, second_join: RelationalNode, output: bool +) -> None: """ Tests the equality functionality for the Join node. """ diff --git a/tests/test_relational_execution.py b/tests/test_relational_execution.py index 8f6688bf..7928598c 100644 --- a/tests/test_relational_execution.py +++ b/tests/test_relational_execution.py @@ -21,7 +21,7 @@ CallExpression, Join, JoinType, - Relational, + RelationalNode, RelationalRoot, Scan, ) @@ -39,14 +39,14 @@ def test_person_total_salary( Tests a simple join and aggregate to compute the total salary for each person in the PEOPLE table. """ - people: Relational = Scan( + people: RelationalNode = Scan( table_name="PEOPLE", columns={ "person_id": make_relational_column_reference("person_id"), "name": make_relational_column_reference("name"), }, ) - jobs: Relational = Aggregate( + jobs: RelationalNode = Aggregate( keys={"person_id": make_relational_column_reference("person_id")}, aggregations={ "total_salary": CallExpression( @@ -112,7 +112,7 @@ def test_person_jobs_multi_join( represent multiple joins. It should be noted that this may not be optimal way to represent this query, but it is a valid way to represent it. 
""" - people: Relational = Scan( + people: RelationalNode = Scan( table_name="PEOPLE", columns={ "person_id": make_relational_column_reference("person_id"), @@ -120,7 +120,7 @@ def test_person_jobs_multi_join( }, ) # Select each person's highest salary - jobs: Relational = Aggregate( + jobs: RelationalNode = Aggregate( keys={"person_id": make_relational_column_reference("person_id")}, aggregations={ "max_salary": CallExpression( @@ -138,7 +138,7 @@ def test_person_jobs_multi_join( ), ) # Select the average salary across all jobs ever recorded - average_salary: Relational = Aggregate( + average_salary: RelationalNode = Aggregate( keys={}, aggregations={ "average_salary": CallExpression( diff --git a/tests/test_relational_nodes_to_sqlglot.py b/tests/test_relational_nodes_to_sqlglot.py index 1075ea6e..d04113fb 100644 --- a/tests/test_relational_nodes_to_sqlglot.py +++ b/tests/test_relational_nodes_to_sqlglot.py @@ -60,7 +60,7 @@ Limit, LiteralExpression, Project, - Relational, + RelationalNode, RelationalRoot, Scan, WindowCallExpression, @@ -1577,7 +1577,7 @@ def mkglot_func(op: type[Expression], args: list[Expression]) -> Expression: ) def test_node_to_sqlglot( sqlglot_relational_visitor: SQLGlotRelationalVisitor, - node: Relational, + node: RelationalNode, sqlglot_expr: Expression, ) -> None: """ From 45fea5bb5f6926961786057441d5b673e62e6edf Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Thu, 6 Feb 2025 00:55:07 -0500 Subject: [PATCH 086/112] Resolving conflicts --- pydough/logger/logger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydough/logger/logger.py b/pydough/logger/logger.py index 9f70460d..a642ee66 100644 --- a/pydough/logger/logger.py +++ b/pydough/logger/logger.py @@ -26,7 +26,7 @@ def get_logger( `logging.Logger` : Configured logger instance. 
""" logger: logging.Logger = logging.getLogger(name) - level_env: str | None = os.getenv("PYDOUGH_LOG_LEVEL") + level_env: str = os.getenv("PYDOUGH_LOG_LEVEL") level: int if level_env is not None: From 7bb353fb35484b021de15847915b82d9aefb4eb6 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Thu, 6 Feb 2025 00:55:14 -0500 Subject: [PATCH 087/112] Resolving conflicts --- pydough/logger/logger.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/pydough/logger/logger.py b/pydough/logger/logger.py index a642ee66..f4280bf0 100644 --- a/pydough/logger/logger.py +++ b/pydough/logger/logger.py @@ -30,26 +30,13 @@ def get_logger( level: int if level_env is not None: - assert isinstance( - level_env, str - ), f"expected environment variable 'PYDOUGH_LOG_LEVEL' to be a string, found {level_env.__class__.__name__}" + assert isinstance(level_env, str), f"expected environment variable 'PYDOUGH_LOG_LEVEL' to be a string, found {level_env.__class__.__name__}" allowed_levels: list[str] = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] - assert ( - level_env in allowed_levels - ), f"expected environment variable 'PYDOUGH_LOG_LEVEL' to be one of {', '.join(allowed_levels)}, found {level_env}" + assert level_env in allowed_levels, f"expected environment variable 'PYDOUGH_LOG_LEVEL' to be one of {', '.join(allowed_levels)}, found {level_env}" # Convert string level (e.g., "DEBUG", "INFO") to a logging constant level = getattr(logging, level_env.upper(), default_level) else: - assert ( - default_level - in [ - logging.DEBUG, - logging.INFO, - logging.WARNING, - logging.ERROR, - logging.CRITICAL, - ] - ), f"expected arguement default_value to be one of logging.DEBUG,logging.INFO,logging.WARNING,logging.ERROR,logging.CRITICAL, found {default_level}" + assert default_level in [logging.DEBUG,logging.INFO,logging.WARNING,logging.ERROR,logging.CRITICAL], f"expected arguement default_value to be one of 
logging.DEBUG,logging.INFO,logging.WARNING,logging.ERROR,logging.CRITICAL, found {default_level}" level = default_level # Create default console handler From 0cf4c11d966cef8f3c8ad48709e5db73fc2949c9 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Thu, 6 Feb 2025 01:22:36 -0500 Subject: [PATCH 088/112] Adding documentation --- pydough/logger/logger.py | 21 +++++++++++++++---- pydough/relational/README.md | 8 +++++++ .../relational_expressions/README.md | 19 +++++++++++++++++ .../correlated_reference.py | 8 +++++-- tests/simple_pydough_functions.py | 4 ++-- 5 files changed, 52 insertions(+), 8 deletions(-) diff --git a/pydough/logger/logger.py b/pydough/logger/logger.py index f4280bf0..9f70460d 100644 --- a/pydough/logger/logger.py +++ b/pydough/logger/logger.py @@ -26,17 +26,30 @@ def get_logger( `logging.Logger` : Configured logger instance. """ logger: logging.Logger = logging.getLogger(name) - level_env: str = os.getenv("PYDOUGH_LOG_LEVEL") + level_env: str | None = os.getenv("PYDOUGH_LOG_LEVEL") level: int if level_env is not None: - assert isinstance(level_env, str), f"expected environment variable 'PYDOUGH_LOG_LEVEL' to be a string, found {level_env.__class__.__name__}" + assert isinstance( + level_env, str + ), f"expected environment variable 'PYDOUGH_LOG_LEVEL' to be a string, found {level_env.__class__.__name__}" allowed_levels: list[str] = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] - assert level_env in allowed_levels, f"expected environment variable 'PYDOUGH_LOG_LEVEL' to be one of {', '.join(allowed_levels)}, found {level_env}" + assert ( + level_env in allowed_levels + ), f"expected environment variable 'PYDOUGH_LOG_LEVEL' to be one of {', '.join(allowed_levels)}, found {level_env}" # Convert string level (e.g., "DEBUG", "INFO") to a logging constant level = getattr(logging, level_env.upper(), default_level) else: - assert default_level in [logging.DEBUG,logging.INFO,logging.WARNING,logging.ERROR,logging.CRITICAL], f"expected arguement 
default_value to be one of logging.DEBUG,logging.INFO,logging.WARNING,logging.ERROR,logging.CRITICAL, found {default_level}" + assert ( + default_level + in [ + logging.DEBUG, + logging.INFO, + logging.WARNING, + logging.ERROR, + logging.CRITICAL, + ] + ), f"expected arguement default_value to be one of logging.DEBUG,logging.INFO,logging.WARNING,logging.ERROR,logging.CRITICAL, found {default_level}" level = default_level # Create default console handler diff --git a/pydough/relational/README.md b/pydough/relational/README.md index d0259995..57939f8b 100644 --- a/pydough/relational/README.md +++ b/pydough/relational/README.md @@ -18,6 +18,8 @@ The relational_expressions submodule provides functionality to define and manage - `ExpressionSortInfo`: The representation of ordering for an expression within a relational node. - `RelationalExpressionVisitor`: The basic Visitor pattern to perform operations across the expression components of a relational tree. - `ColumnReferenceFinder`: Finds all unique column references in a relational expression. +- `CorrelatedReference`: The expression implementation for accessing a correlated column reference in a relational node. +- `CorrelatedReferenceFinder`: Finds all unique correlated references in a relational expression. - `RelationalExpressionShuttle`: Specialized form of the visitor pattern that returns a relational expression. - `ColumnReferenceInputNameModifier`: Shuttle implementation designed to update all uses of a column reference's input name to a new input name based on a dictionary. 
@@ -33,6 +35,7 @@ from pydough.relational.relational_expressions import ( ExpressionSortInfo, ColumnReferenceFinder, ColumnReferenceInputNameModifier, + CorrelatedReferenceFinder, WindowCallExpression, ) from pydough.pydough_operators import ADD, RANKING @@ -64,6 +67,11 @@ unique_column_refs = finder.get_column_references() # Modify the input name of column references in the call expression modifier = ColumnReferenceInputNameModifier({"old_input_name": "new_input_name"}) modified_call_expr = call_expr.accept_shuttle(modifier) + +# Find all unique correlated references in the call expression +correlated_finder = CorrelatedReferenceFinder() +call_expr.accept(correlated_finder) +unique_correlated_refs = correlated_finder.get_correlated_references() ``` ## [Relational Nodes](relational_nodes/README.md) diff --git a/pydough/relational/relational_expressions/README.md b/pydough/relational/relational_expressions/README.md index dcca7290..c11c8f44 100644 --- a/pydough/relational/relational_expressions/README.md +++ b/pydough/relational/relational_expressions/README.md @@ -45,6 +45,14 @@ The relational_expressions module provides functionality to define and manage va - `ColumnReferenceFinder`: Finds all unique column references in a relational expression. +### [correlated_reference.py](correlated_reference.py) + +- `CorrelatedReference`: The expression implementation for accessing a correlated column reference in a relational node. + +### [correlated_reference_finder.py](correlated_reference_finder.py) + +- `CorrelatedReferenceFinder`: Finds all unique correlated references in a relational expression. + ### [relational_expression_shuttle.py](relational_expression_shuttle.py) - `RelationalExpressionShuttle`: Specialized form of the visitor pattern that returns a relational expression. This is used to handle the common case where we need to modify a type of input. 
@@ -69,6 +77,8 @@ from pydough.relational.relational_expressions import ( ExpressionSortInfo, ColumnReferenceFinder, ColumnReferenceInputNameModifier, + CorrelatedReference, + CorrelatedReferenceFinder, ) from pydough.pydough_operators import ADD from pydough.types import Int64Type @@ -82,6 +92,10 @@ literal_expr = LiteralExpression(10, Int64Type()) # Create a call expression for addition call_expr = CallExpression(ADD, Int64Type(), [column_ref, literal_expr]) +# Create a correlated reference to column `column_name` in the first input to +# an ancestor join of `corr1` +correlated_ref = CorrelatedReference("column_name", "corr1", Int64Type()) + # Create an expression sort info sort_info = ExpressionSortInfo(call_expr, ascending=True, nulls_first=False) @@ -96,4 +110,9 @@ unique_column_refs = finder.get_column_references() # Modify the input name of column references in the call expression modifier = ColumnReferenceInputNameModifier({"old_input_name": "new_input_name"}) modified_call_expr = call_expr.accept_shuttle(modifier) + +# Find all unique correlated references in the call expression +correlated_finder = CorrelatedReferenceFinder() +call_expr.accept(correlated_finder) +unique_correlated_refs = correlated_finder.get_correlated_references() ``` diff --git a/pydough/relational/relational_expressions/correlated_reference.py b/pydough/relational/relational_expressions/correlated_reference.py index 0746a5bc..e6be2790 100644 --- a/pydough/relational/relational_expressions/correlated_reference.py +++ b/pydough/relational/relational_expressions/correlated_reference.py @@ -1,5 +1,8 @@ """ -TODO +The representation of a correlated column access for use in a relational tree. +The correl name should be the `correl_name` property of a join ancestor of the +tree, and the name should match one of the column names of the first input to +that join, which is the column that the correlated reference refers to. 
""" __all__ = ["CorrelatedReference"] @@ -13,7 +16,8 @@ class CorrelatedReference(RelationalExpression): """ - TODO + The Expression implementation for accessing a correlated column reference + in a relational node. """ def __init__(self, name: str, correl_name: str, data_type: PyDoughType) -> None: diff --git a/tests/simple_pydough_functions.py b/tests/simple_pydough_functions.py index b0f7647f..e9230383 100644 --- a/tests/simple_pydough_functions.py +++ b/tests/simple_pydough_functions.py @@ -486,7 +486,7 @@ def correl_15(): def correl_16(): - # Correlated back reference example #15: hybrid tree order of operations. + # Correlated back reference example #16: hybrid tree order of operations. # Count how many european suppliers have the exact same percentile value # of account balance (relative to all other suppliers) as at least one # customer's percentile value of account balance relative to all other @@ -504,7 +504,7 @@ def correl_16(): def correl_17(): - # Correlated back reference example #15: hybrid tree order of operations. + # Correlated back reference example #17: hybrid tree order of operations. # An extremely roundabout way of getting each region_name-nation_name # pair as a string. # (This is a correlated singular/semi access) From 3a1516e72a7ae2f8fd3c1fd0130b17abfc05e6bf Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Thu, 6 Feb 2025 01:48:47 -0500 Subject: [PATCH 089/112] Adding more comments --- pydough/conversion/hybrid_tree.py | 32 ++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/pydough/conversion/hybrid_tree.py b/pydough/conversion/hybrid_tree.py index e49ea82e..e10a6d94 100644 --- a/pydough/conversion/hybrid_tree.py +++ b/pydough/conversion/hybrid_tree.py @@ -1562,23 +1562,43 @@ def make_hybrid_correl_expr( steps_taken_so_far: int, ) -> HybridCorrelExpr: """ - TODO + Converts a BACK reference into a correlated reference when the number + of BACK levels exceeds the height of the current subtree. 
+ + Args: + `back_expr`: the original BACK reference to be converted. + `collection`: the collection at the top of the current subtree, + before we have run out of BACK levels to step up out of. + `steps_taken_so_far`: the number of steps already taken to step + up from the BACK node. This is needed so we know how many steps + still need to be taken upward once we have stepped out of the child + subtree back into the parent subtree. """ if len(self.stack) == 0: raise ValueError("Back reference steps too far back") + # Identify the parent subtree that the BACK reference is stepping back + # into, out of the child. parent_tree = self.stack.pop() remaining_steps_back: int = back_expr.back_levels - steps_taken_so_far - 1 parent_result: HybridExpr + # Special case: stepping out of the data argument of PARTITION back + # into its ancestor. For example: + # TPCH(x=...).PARTITION(data.WHERE(y > BACK(1).x), ...) if len(parent_tree.pipeline) == 1 and isinstance( parent_tree.pipeline[0], HybridPartition ): assert parent_tree.parent is not None + # Treat the partition's parent as the context for the back + # to step into, as opposed to the partition itself (so the back + # levels are consistent) self.stack.append(parent_tree.parent) parent_result = self.make_hybrid_correl_expr( back_expr, collection, steps_taken_so_far ) self.stack.pop() self.stack.append(parent_tree) + # Then, postprocess the output to account for the fact that a + # BACK level got skipped due to the change in subtree. match parent_result.expr: case HybridRefExpr(): parent_result = HybridBackRefExpr( @@ -1595,6 +1615,8 @@ def make_hybrid_correl_expr( f"Malformed expression for correlated reference: {parent_result}" ) elif remaining_steps_back == 0: + # If there are no more steps back to be made, then the correlated + # reference is to a reference from the current context. 
if back_expr.term_name not in parent_tree.pipeline[-1].terms: raise ValueError( f"Back reference to {back_expr.term_name} not found in parent" @@ -1604,11 +1626,19 @@ def make_hybrid_correl_expr( ) parent_result = HybridRefExpr(parent_name, back_expr.pydough_type) else: + # Otherwise, a back reference needs to be made from the current + # collection a number of steps back based on how many steps still + # need to be taken, and it must be recursively converted to a + # hybrid expression that gets wrapped in a correlated reference. new_expr: PyDoughExpressionQDAG = BackReferenceExpression( collection, back_expr.term_name, remaining_steps_back ) parent_result = self.make_hybrid_expr(parent_tree, new_expr, {}, False) + # Restore parent_tree back onto the stack, since evaluating `back_expr` + # does not change the program's current placement in the sutbtrees. self.stack.append(parent_tree) + # Create the correlated reference to the expression with regards to + # the parent tree, which could also be a correlated expression. return HybridCorrelExpr(parent_tree, parent_result) def make_hybrid_expr( From 746eed595d8e9e6e10d0d146c24b35c62798b43b Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Thu, 6 Feb 2025 01:55:01 -0500 Subject: [PATCH 090/112] Added more comments [RUN CI] --- pydough/conversion/relational_converter.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index b121c3e3..637225ec 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -229,6 +229,9 @@ def translate_expression( expr.kwargs, ) case HybridCorrelExpr(): + # Convert correlated expressions by converting the expression + # they point to in the context of the top of the stack, then + # wrapping the result in a correlated reference. 
ancestor_context: TranslationOutput = self.stack.pop() ancestor_expr: RelationalExpression = self.translate_expression( expr.expr, ancestor_context From 17116d48993b1899077c81ba92b12a703f88ce10 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Thu, 6 Feb 2025 16:28:00 -0500 Subject: [PATCH 091/112] Initial handling of decorrelation setup --- pydough/conversion/hybrid_tree.py | 11 +++++++++++ pydough/conversion/relational_converter.py | 8 +++++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/pydough/conversion/hybrid_tree.py b/pydough/conversion/hybrid_tree.py index e10a6d94..348536b4 100644 --- a/pydough/conversion/hybrid_tree.py +++ b/pydough/conversion/hybrid_tree.py @@ -895,6 +895,7 @@ def __init__( self._is_connection_root: bool = is_connection_root self._agg_keys: list[HybridExpr] | None = None self._join_keys: list[tuple[HybridExpr, HybridExpr]] | None = None + self._correlated_children: set[int] = set() if isinstance(root_operation, HybridPartition): self._join_keys = [] @@ -935,6 +936,14 @@ def children(self) -> list[HybridConnection]: """ return self._children + @property + def correlated_children(self) -> set[int]: + """ + The set of indices of children that contain correlated references to + the current hybrid tree. + """ + return self._correlated_children + @property def successor(self) -> Optional["HybridTree"]: """ @@ -1634,6 +1643,8 @@ def make_hybrid_correl_expr( collection, back_expr.term_name, remaining_steps_back ) parent_result = self.make_hybrid_expr(parent_tree, new_expr, {}, False) + if not isinstance(parent_result, HybridCorrelExpr): + parent_tree.correlated_children.add(len(parent_tree.children)) # Restore parent_tree back onto the stack, since evaluating `back_expr` # does not change the program's current placement in the sutbtrees. 
self.stack.append(parent_tree) diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index 637225ec..6ae8a303 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -44,6 +44,7 @@ ) from pydough.types import BooleanType, Int64Type, UnknownType +from .hybrid_decorrelater import decorrelate_hybrid from .hybrid_tree import ( ConnectionType, HybridBackRefExpr, @@ -996,10 +997,11 @@ def convert_ast_to_relational( final_terms: set[str] = node.calc_terms node = translator.preprocess_root(node) - # Convert the QDAG node to the hybrid form, then invoke the relational - # conversion procedure. The first element in the returned list is the - # final rel node. + # Convert the QDAG node to the hybrid form, decorrelate it, then invoke + # the relational conversion procedure. The first element in the returned + # list is the final rel node. hybrid: HybridTree = HybridTranslator(configs).make_hybrid_tree(node, None) + decorrelate_hybrid(hybrid) renamings: dict[str, str] = hybrid.pipeline[-1].renamings output: TranslationOutput = translator.rel_translation( None, hybrid, len(hybrid.pipeline) - 1 From 9ab9bedd83099d8fb502f2c21ea6d9e890789a56 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Thu, 6 Feb 2025 16:31:38 -0500 Subject: [PATCH 092/112] Adding decorrelater file --- pydough/conversion/hybrid_decorrelater.py | 64 +++++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 pydough/conversion/hybrid_decorrelater.py diff --git a/pydough/conversion/hybrid_decorrelater.py b/pydough/conversion/hybrid_decorrelater.py new file mode 100644 index 00000000..8303f433 --- /dev/null +++ b/pydough/conversion/hybrid_decorrelater.py @@ -0,0 +1,64 @@ +""" +Logic for applying decorrelation to hybrid trees before relational conversion +if the correlate is not a semi/anti join. 
+""" + +__all__ = ["decorrelate_hybrid"] + + +from .hybrid_tree import ( + ConnectionType, + HybridTree, +) + + +class Decorrelater: + """ + TODO + """ + + def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: + """ + TODO + """ + # Recursively decorrelate the ancestors of the current level of the + # hybrid tree. + if hybrid.parent is not None: + hybrid._parent = self.decorrelate_hybrid_tree(hybrid.parent) + # Iterate across all the children and transform any that require + # decorrelation due to the type of connection. + for idx, child in enumerate(hybrid.children): + if idx not in hybrid.correlated_children: + continue + match child.connection_type: + case ( + ConnectionType.SINGULAR + | ConnectionType.AGGREGATION + | ConnectionType.SINGULAR_ONLY_MATCH + | ConnectionType.AGGREGATION_ONLY_MATCH + | ConnectionType.NDISTINCT + | ConnectionType.NDISTINCT_ONLY_MATCH + ): + raise NotImplementedError( + f"PyDough does not yet support correlated references with the {child.connection_type.name} pattern." + ) + case ( + ConnectionType.SEMI + | ConnectionType.ANTI + | ConnectionType.NO_MATCH_SINGULAR + | ConnectionType.NO_MATCH_AGGREGATION + | ConnectionType.NO_MATCH_NDISTINCT + ): + continue + # Iterate across all the children and decorrelate them. 
+ for idx, child in enumerate(hybrid.children): + hybrid.children[idx].subtree = self.decorrelate_hybrid_tree(child.subtree) + return hybrid + + +def decorrelate_hybrid(hybrid: HybridTree) -> HybridTree: + """ + TODO + """ + decorr: Decorrelater = Decorrelater() + return decorr.decorrelate_hybrid_tree(hybrid) From a8f6535b69030ce4052e255bd379c927f26df2bf Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Thu, 6 Feb 2025 16:33:35 -0500 Subject: [PATCH 093/112] Renaming and added comment --- pydough/conversion/hybrid_decorrelater.py | 11 +++++++---- pydough/conversion/relational_converter.py | 4 ++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/pydough/conversion/hybrid_decorrelater.py b/pydough/conversion/hybrid_decorrelater.py index 8303f433..2cf1e643 100644 --- a/pydough/conversion/hybrid_decorrelater.py +++ b/pydough/conversion/hybrid_decorrelater.py @@ -3,7 +3,7 @@ if the correlate is not a semi/anti join. """ -__all__ = ["decorrelate_hybrid"] +__all__ = ["run_hybrid_decorrelation"] from .hybrid_tree import ( @@ -25,8 +25,9 @@ def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: # hybrid tree. if hybrid.parent is not None: hybrid._parent = self.decorrelate_hybrid_tree(hybrid.parent) - # Iterate across all the children and transform any that require - # decorrelation due to the type of connection. + # Iterate across all the children, identify any that are correlated, + # and transform any of the correlated ones that require decorrelation + # due to the type of connection. for idx, child in enumerate(hybrid.children): if idx not in hybrid.correlated_children: continue @@ -49,6 +50,8 @@ def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: | ConnectionType.NO_MATCH_AGGREGATION | ConnectionType.NO_MATCH_NDISTINCT ): + # These patterns do not require decorrelation since they + # are supported via correlated SEMI/ANTI joins. continue # Iterate across all the children and decorrelate them. 
for idx, child in enumerate(hybrid.children): @@ -56,7 +59,7 @@ def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: return hybrid -def decorrelate_hybrid(hybrid: HybridTree) -> HybridTree: +def run_hybrid_decorrelation(hybrid: HybridTree) -> HybridTree: """ TODO """ diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index 6ae8a303..4c69487e 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -44,7 +44,7 @@ ) from pydough.types import BooleanType, Int64Type, UnknownType -from .hybrid_decorrelater import decorrelate_hybrid +from .hybrid_decorrelater import run_hybrid_decorrelation from .hybrid_tree import ( ConnectionType, HybridBackRefExpr, @@ -1001,7 +1001,7 @@ # the relational conversion procedure. The first element in the returned # list is the final rel node. hybrid: HybridTree = HybridTranslator(configs).make_hybrid_tree(node, None) - decorrelate_hybrid(hybrid) + run_hybrid_decorrelation(hybrid) renamings: dict[str, str] = hybrid.pipeline[-1].renamings output: TranslationOutput = translator.rel_translation( None, hybrid, len(hybrid.pipeline) - 1 From 8ff51578da1ab5367a318c1ee756b58a788785b6 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Fri, 7 Feb 2025 12:18:13 -0500 Subject: [PATCH 094/112] Implemented singular decorrelation handling --- pydough/conversion/hybrid_decorrelater.py | 122 ++++++++++++++++++++-- tests/test_plan_refsols/correl_17.txt | 12 ++- tests/test_plan_refsols/correl_8.txt | 10 +- tests/test_plan_refsols/correl_9.txt | 12 ++- 4 files changed, 136 insertions(+), 20 deletions(-) diff --git a/pydough/conversion/hybrid_decorrelater.py b/pydough/conversion/hybrid_decorrelater.py index 2cf1e643..4ba678d7 100644 --- a/pydough/conversion/hybrid_decorrelater.py +++ b/pydough/conversion/hybrid_decorrelater.py @@ -6,9 +6,23 @@ __all__ = ["run_hybrid_decorrelation"] +import copy + from .hybrid_tree
import ( ConnectionType, + HybridBackRefExpr, + HybridCalc, + HybridChildRefExpr, + HybridColumnExpr, + HybridConnection, + HybridCorrelExpr, + HybridExpr, + HybridFilter, + HybridFunctionExpr, + HybridLiteralExpr, + HybridRefExpr, HybridTree, + HybridWindowExpr, ) @@ -17,6 +31,101 @@ class Decorrelater: TODO """ + def make_decorrelate_parent(self, hybrid: HybridTree, child_idx: int) -> HybridTree: + """ + TODO + """ + successor: HybridTree | None = hybrid.successor + hybrid._successor = None + new_hybrid: HybridTree = copy.deepcopy(hybrid) + hybrid._successor = successor + new_hybrid._children = new_hybrid._children[:child_idx] + new_hybrid._pipeline = new_hybrid._pipeline[ + : hybrid.children[child_idx].required_steps + 1 + ] + return new_hybrid + + def remove_correl_refs( + self, expr: HybridExpr, parent: HybridTree, child_height: int + ) -> HybridExpr: + """ + TODO + """ + match expr: + case HybridCorrelExpr(): + result: HybridExpr | None = expr.expr.shift_back(child_height) + assert result is not None + return result + case HybridFunctionExpr(): + for idx, arg in enumerate(expr.args): + expr.args[idx] = self.remove_correl_refs(arg, parent, child_height) + return expr + case HybridWindowExpr(): + for idx, arg in enumerate(expr.args): + expr.args[idx] = self.remove_correl_refs(arg, parent, child_height) + for idx, arg in enumerate(expr.partition_args): + expr.partition_args[idx] = self.remove_correl_refs( + arg, parent, child_height + ) + for order_arg in expr.order_args: + order_arg.expr = self.remove_correl_refs( + order_arg.expr, parent, child_height + ) + return expr + case ( + HybridBackRefExpr() + | HybridRefExpr() + | HybridChildRefExpr() + | HybridLiteralExpr() + | HybridColumnExpr() + ): + return expr + case _: + raise NotImplementedError( + f"Unsupported expression type: {expr.__class__.__name__}." 
+ ) + + def decorrelate_singular( + self, old_parent: HybridTree, new_parent: HybridTree, child: HybridConnection + ) -> None: + """ + TODO + """ + # First, find the height of the child subtree & its top-most level. + child_root: HybridTree = child.subtree + child_height: int = 1 + while child_root.parent is not None: + child_root = child_root.parent + child_height += 1 + # Link the top level of the child subtree to the new parent. + new_parent.add_successor(child_root) + # Replace any correlated references to the original parent with BACK references. + level: HybridTree = child.subtree + while level.parent is not None and level is not new_parent: + for operation in level.pipeline: + for name, expr in operation.terms.items(): + operation.terms[name] = self.remove_correl_refs( + expr, old_parent, child_height + ) + for ordering in operation.orderings: + ordering.expr = self.remove_correl_refs( + ordering.expr, old_parent, child_height + ) + for idx, expr in enumerate(operation.unique_exprs): + operation.unique_exprs[idx] = self.remove_correl_refs( + expr, old_parent, child_height + ) + if isinstance(operation, HybridCalc): + for str, expr in operation.new_expressions.items(): + operation.new_expressions[str] = self.remove_correl_refs( + expr, old_parent, child_height + ) + if isinstance(operation, HybridFilter): + operation.condition = self.remove_correl_refs( + operation.condition, old_parent, child_height + ) + level = level.parent + def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: """ TODO @@ -25,17 +134,21 @@ def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: # hybrid tree. if hybrid.parent is not None: hybrid._parent = self.decorrelate_hybrid_tree(hybrid.parent) + # Iterate across all the children and recursively decorrelate them. 
+ for child in hybrid.children: + child.subtree = self.decorrelate_hybrid_tree(child.subtree) # Iterate across all the children, identify any that are correlated, # and transform any of the correlated ones that require decorrelation # due to the type of connection. for idx, child in enumerate(hybrid.children): if idx not in hybrid.correlated_children: continue + new_parent: HybridTree = self.make_decorrelate_parent(hybrid, idx) match child.connection_type: + case ConnectionType.SINGULAR | ConnectionType.SINGULAR_ONLY_MATCH: + self.decorrelate_singular(hybrid, new_parent, child) case ( - ConnectionType.SINGULAR - | ConnectionType.AGGREGATION - | ConnectionType.SINGULAR_ONLY_MATCH + ConnectionType.AGGREGATION | ConnectionType.AGGREGATION_ONLY_MATCH | ConnectionType.NDISTINCT | ConnectionType.NDISTINCT_ONLY_MATCH @@ -53,9 +166,6 @@ def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: # These patterns do not require decorrelation since they # are supported via correlated SEMI/ANTI joins. continue - # Iterate across all the children and decorrelate them. 
- for idx, child in enumerate(hybrid.children): - hybrid.children[idx].subtree = self.decorrelate_hybrid_tree(child.subtree) return hybrid diff --git a/tests/test_plan_refsols/correl_17.txt b/tests/test_plan_refsols/correl_17.txt index 4e532f3c..9b802ed6 100644 --- a/tests/test_plan_refsols/correl_17.txt +++ b/tests/test_plan_refsols/correl_17.txt @@ -2,8 +2,10 @@ ROOT(columns=[('fullname', fullname)], orderings=[(ordering_0):asc_first]) PROJECT(columns={'fullname': fullname, 'ordering_0': fullname}) PROJECT(columns={'fullname': fname}) FILTER(condition=True:bool, columns={'fname': fname}) - JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'fname': t1.fname}, correl_name='corr1') - PROJECT(columns={'lname': LOWER(name), 'region_key': region_key}) - SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) - PROJECT(columns={'fname': JOIN_STRINGS('-':string, LOWER(name), corr1.lname), 'key': key}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + JOIN(conditions=[t0.region_key == t1.key_2], types=['inner'], columns={'fname': t1.fname}) + SCAN(table=tpch.NATION, columns={'region_key': n_regionkey}) + PROJECT(columns={'fname': JOIN_STRINGS('-':string, LOWER(name_3), lname), 'key_2': key_2}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key_2': t1.key, 'lname': t0.lname, 'name_3': t1.name}) + PROJECT(columns={'lname': LOWER(name), 'region_key': region_key}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/correl_8.txt b/tests/test_plan_refsols/correl_8.txt index 87bcc66e..cf9d5a88 100644 --- a/tests/test_plan_refsols/correl_8.txt +++ b/tests/test_plan_refsols/correl_8.txt @@ -1,7 +1,9 @@ ROOT(columns=[('name', name), ('rname', rname)], orderings=[(ordering_0):asc_first]) PROJECT(columns={'name': name, 'ordering_0': name, 'rname': rname}) - 
PROJECT(columns={'name': name, 'rname': name_4}) - JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'name': t0.name, 'name_4': t1.name}, correl_name='corr1') + PROJECT(columns={'name': name, 'rname': name_3}) + JOIN(conditions=[t0.region_key == t1.key_2], types=['left'], columns={'name': t0.name, 'name_3': t1.name_3}) SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) - FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'key': key, 'name': name}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(name, None:unknown, 1:int64, None:unknown), columns={'key_2': key_2, 'name_3': name_3}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/correl_9.txt b/tests/test_plan_refsols/correl_9.txt index 6a7a6c13..58564f54 100644 --- a/tests/test_plan_refsols/correl_9.txt +++ b/tests/test_plan_refsols/correl_9.txt @@ -1,8 +1,10 @@ ROOT(columns=[('name', name), ('rname', rname)], orderings=[(ordering_0):asc_first]) PROJECT(columns={'name': name, 'ordering_0': name, 'rname': rname}) - PROJECT(columns={'name': name, 'rname': name_4}) - FILTER(condition=True:bool, columns={'name': name, 'name_4': name_4}) - JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'name': t0.name, 'name_4': t1.name}, correl_name='corr1') + PROJECT(columns={'name': name, 'rname': name_3}) + FILTER(condition=True:bool, columns={'name': name, 'name_3': name_3}) + JOIN(conditions=[t0.region_key == t1.key_2], types=['inner'], columns={'name': t0.name, 'name_3': t1.name_3}) SCAN(table=tpch.NATION, columns={'name': 
n_name, 'region_key': n_regionkey}) - FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'key': key, 'name': name}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(name, None:unknown, 1:int64, None:unknown), columns={'key_2': key_2, 'name_3': name_3}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) From 9191ec46506bf127b2ea5258acebde5ba807d97a Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Fri, 7 Feb 2025 12:41:10 -0500 Subject: [PATCH 095/112] Fixing aggregation for singular case --- pydough/conversion/hybrid_decorrelater.py | 15 ++++++++++ pydough/conversion/hybrid_tree.py | 2 ++ pydough/pydough_operators/base_operator.py | 12 ++++++++ pydough/sqlglot/transform_bindings.py | 1 + tests/test_pipeline.py | 32 +++++++++++++++++++++- 5 files changed, 61 insertions(+), 1 deletion(-) diff --git a/pydough/conversion/hybrid_decorrelater.py b/pydough/conversion/hybrid_decorrelater.py index 4ba678d7..689f1d5d 100644 --- a/pydough/conversion/hybrid_decorrelater.py +++ b/pydough/conversion/hybrid_decorrelater.py @@ -125,6 +125,21 @@ def decorrelate_singular( operation.condition, old_parent, child_height ) level = level.parent + # Update the join keys to join on the unique keys of all the ancestors. 
+ new_join_keys: list[tuple[HybridExpr, HybridExpr]] = [] + additional_levels: int = 0 + current_level: HybridTree | None = old_parent + while current_level is not None: + for unique_key in current_level.pipeline[0].unique_exprs: + lhs_key: HybridExpr | None = unique_key.shift_back(additional_levels) + rhs_key: HybridExpr | None = unique_key.shift_back( + additional_levels + child_height + ) + assert lhs_key is not None and rhs_key is not None + new_join_keys.append((lhs_key, rhs_key)) + current_level = current_level.parent + additional_levels += 1 + child.subtree.join_keys = new_join_keys def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: """ diff --git a/pydough/conversion/hybrid_tree.py b/pydough/conversion/hybrid_tree.py index 348536b4..5a5bd3df 100644 --- a/pydough/conversion/hybrid_tree.py +++ b/pydough/conversion/hybrid_tree.py @@ -184,6 +184,8 @@ def apply_renamings(self, renamings: dict[str, str]) -> "HybridExpr": return self def shift_back(self, levels: int) -> HybridExpr | None: + if levels == 0: + return self return HybridBackRefExpr(self.name, levels, self.typ) diff --git a/pydough/pydough_operators/base_operator.py b/pydough/pydough_operators/base_operator.py index fe12629e..82ccfacf 100644 --- a/pydough/pydough_operators/base_operator.py +++ b/pydough/pydough_operators/base_operator.py @@ -80,3 +80,15 @@ def to_string(self, arg_strings: list[str]) -> str: Returns: The string representation of the operator called on its arguments. """ + + @abstractmethod + def equals(self, other: object) -> bool: + """ + Returns whether this operator is equal to another operator. 
+ """ + + def __eq__(self, other: object) -> bool: + return self.equals(other) + + def __hash__(self) -> int: + return hash(repr(self)) diff --git a/pydough/sqlglot/transform_bindings.py b/pydough/sqlglot/transform_bindings.py index 8f925183..86a3ecdb 100644 --- a/pydough/sqlglot/transform_bindings.py +++ b/pydough/sqlglot/transform_bindings.py @@ -625,6 +625,7 @@ def call( """ if operator not in self.bindings: # TODO: (gh #169) add support for UDFs + breakpoint() raise ValueError(f"Unsupported function {operator}") binding: transform_binding = self.bindings[operator] return binding(raw_args, sql_glot_args) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index db188ef3..c1c56e83 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -885,7 +885,37 @@ ( correl_17, "correl_17", - lambda: pd.DataFrame({"fullname": [925]}), + lambda: pd.DataFrame( + { + "fullname": [ + "africa-algeria", + "africa-ethiopia", + "africa-kenya", + "africa-morocco", + "africa-mozambique", + "america-argentina", + "america-brazil", + "america-canada", + "america-peru", + "america-united states", + "asia-china", + "asia-india", + "asia-indonesia", + "asia-japan", + "asia-vietnam", + "europe-france", + "europe-germany", + "europe-romania", + "europe-russia", + "europe-united kingdom", + "middle east-egypt", + "middle east-iran", + "middle east-iraq", + "middle east-jordan", + "middle east-saudi arabia", + ] + } + ), ), id="correl_17", ), From 9bc3564c545aeec1a0e0d14baaa5458c58a7d3be Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Fri, 7 Feb 2025 16:45:17 -0500 Subject: [PATCH 096/112] WIP handling edge cases and aggregation; need to fix tpch q5/q22, and correl queries 1/2/3/8/15 --- pydough/conversion/hybrid_decorrelater.py | 42 +++++++++++++++------- pydough/conversion/hybrid_tree.py | 39 ++++++++++---------- pydough/conversion/relational_converter.py | 4 +++ pydough/sqlglot/transform_bindings.py | 1 - 4 files changed, 54 insertions(+), 32 deletions(-) diff --git 
a/pydough/conversion/hybrid_decorrelater.py b/pydough/conversion/hybrid_decorrelater.py index 689f1d5d..3581721a 100644 --- a/pydough/conversion/hybrid_decorrelater.py +++ b/pydough/conversion/hybrid_decorrelater.py @@ -20,6 +20,7 @@ HybridFilter, HybridFunctionExpr, HybridLiteralExpr, + HybridPartition, HybridRefExpr, HybridTree, HybridWindowExpr, @@ -31,18 +32,23 @@ class Decorrelater: TODO """ - def make_decorrelate_parent(self, hybrid: HybridTree, child_idx: int) -> HybridTree: + def make_decorrelate_parent( + self, hybrid: HybridTree, child_idx: int, required_steps: int + ) -> HybridTree: """ TODO """ + if isinstance(hybrid.pipeline[0], HybridPartition) and child_idx == 0: + assert hybrid.parent is not None + return self.make_decorrelate_parent( + hybrid.parent, len(hybrid.parent.children), required_steps + ) successor: HybridTree | None = hybrid.successor hybrid._successor = None new_hybrid: HybridTree = copy.deepcopy(hybrid) hybrid._successor = successor new_hybrid._children = new_hybrid._children[:child_idx] - new_hybrid._pipeline = new_hybrid._pipeline[ - : hybrid.children[child_idx].required_steps + 1 - ] + new_hybrid._pipeline = new_hybrid._pipeline[: required_steps + 1] return new_hybrid def remove_correl_refs( @@ -95,8 +101,8 @@ def decorrelate_singular( child_root: HybridTree = child.subtree child_height: int = 1 while child_root.parent is not None: - child_root = child_root.parent child_height += 1 + child_root = child_root.parent # Link the top level of the child subtree to the new parent. new_parent.add_successor(child_root) # Replace any correlated references to the original parent with BACK references. 
@@ -141,6 +147,19 @@ def decorrelate_singular( additional_levels += 1 child.subtree.join_keys = new_join_keys + def decorrelate_aggregate( + self, old_parent: HybridTree, new_parent: HybridTree, child: HybridConnection + ) -> None: + """ + TODO + """ + self.decorrelate_singular(old_parent, new_parent, child) + new_agg_keys: list[HybridExpr] = [] + assert child.subtree.join_keys is not None + for _, rhs_key in child.subtree.join_keys: + new_agg_keys.append(rhs_key) + child.subtree.agg_keys = new_agg_keys + def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: """ TODO @@ -158,16 +177,15 @@ def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: for idx, child in enumerate(hybrid.children): if idx not in hybrid.correlated_children: continue - new_parent: HybridTree = self.make_decorrelate_parent(hybrid, idx) + new_parent: HybridTree = self.make_decorrelate_parent( + hybrid, idx, hybrid.children[idx].required_steps + 1 + ) match child.connection_type: case ConnectionType.SINGULAR | ConnectionType.SINGULAR_ONLY_MATCH: self.decorrelate_singular(hybrid, new_parent, child) - case ( - ConnectionType.AGGREGATION - | ConnectionType.AGGREGATION_ONLY_MATCH - | ConnectionType.NDISTINCT - | ConnectionType.NDISTINCT_ONLY_MATCH - ): + case ConnectionType.AGGREGATION | ConnectionType.AGGREGATION_ONLY_MATCH: + self.decorrelate_aggregate(hybrid, new_parent, child) + case ConnectionType.NDISTINCT | ConnectionType.NDISTINCT_ONLY_MATCH: raise NotImplementedError( f"PyDough does not yet support correlated references with the {child.connection_type.name} pattern." ) diff --git a/pydough/conversion/hybrid_tree.py b/pydough/conversion/hybrid_tree.py index 5a5bd3df..bc0f76cd 100644 --- a/pydough/conversion/hybrid_tree.py +++ b/pydough/conversion/hybrid_tree.py @@ -1595,9 +1595,10 @@ def make_hybrid_correl_expr( # Special case: stepping out of the data argument of PARTITION back # into its ancestor. 
For example: # TPCH(x=...).PARTITION(data.WHERE(y > BACK(1).x), ...) - if len(parent_tree.pipeline) == 1 and isinstance( + partition_edge_case: bool = len(parent_tree.pipeline) == 1 and isinstance( parent_tree.pipeline[0], HybridPartition - ): + ) + if partition_edge_case: assert parent_tree.parent is not None # Treat the partition's parent as the context for the back # to step into, as opposed to the partition itself (so the back @@ -1605,26 +1606,26 @@ self.stack.append(parent_tree.parent) parent_result = self.make_hybrid_correl_expr( back_expr, collection, steps_taken_so_far - ) + ).expr self.stack.pop() - self.stack.append(parent_tree) + # self.stack.append(parent_tree) # Then, postprocess the output to account for the fact that a # BACK level got skipped due to the change in subtree. - match parent_result.expr: - case HybridRefExpr(): - parent_result = HybridBackRefExpr( - parent_result.expr.name, 1, parent_result.typ - ) - case HybridBackRefExpr(): - parent_result = HybridBackRefExpr( - parent_result.expr.name, - parent_result.expr.back_idx + 1, - parent_result.typ, - ) - case _: - raise ValueError( - f"Malformed expression for correlated reference: {parent_result}" - ) + # match parent_result.expr: + # case HybridRefExpr(): + # parent_result = HybridBackRefExpr( + # parent_result.expr.name, 1, parent_result.typ + # ) + # case HybridBackRefExpr(): + # parent_result = HybridBackRefExpr( + # parent_result.expr.name, + # parent_result.expr.back_idx + 1, + # parent_result.typ, + # ) + # case _: + # raise ValueError( + # f"Malformed expression for correlated reference: {parent_result}" + # ) elif remaining_steps_back == 0: # If there are no more steps back to be made, then the correlated # reference is to a reference from the current context.
diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index 4c69487e..1fd3ca5c 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -1001,7 +1001,11 @@ def convert_ast_to_relational( # the relational conversion procedure. The first element in the returned # list is the final rel node. hybrid: HybridTree = HybridTranslator(configs).make_hybrid_tree(node, None) + print() + print(hybrid) + print("DECOR") run_hybrid_decorrelation(hybrid) + print(hybrid) renamings: dict[str, str] = hybrid.pipeline[-1].renamings output: TranslationOutput = translator.rel_translation( None, hybrid, len(hybrid.pipeline) - 1 diff --git a/pydough/sqlglot/transform_bindings.py b/pydough/sqlglot/transform_bindings.py index 86a3ecdb..8f925183 100644 --- a/pydough/sqlglot/transform_bindings.py +++ b/pydough/sqlglot/transform_bindings.py @@ -625,7 +625,6 @@ def call( """ if operator not in self.bindings: # TODO: (gh #169) add support for UDFs - breakpoint() raise ValueError(f"Unsupported function {operator}") binding: transform_binding = self.bindings[operator] return binding(raw_args, sql_glot_args) From 67a342e31921b4d4b65e50d9ea524e5928a085cb Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Fri, 7 Feb 2025 16:46:29 -0500 Subject: [PATCH 097/112] Updating plans for newly working correl queries 6/9/17 --- tests/test_plan_refsols/correl_17.txt | 13 +++++++------ tests/test_plan_refsols/correl_6.txt | 11 +++++++---- tests/test_plan_refsols/correl_9.txt | 11 ++++++----- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/tests/test_plan_refsols/correl_17.txt b/tests/test_plan_refsols/correl_17.txt index 9b802ed6..8ed9fe94 100644 --- a/tests/test_plan_refsols/correl_17.txt +++ b/tests/test_plan_refsols/correl_17.txt @@ -2,10 +2,11 @@ ROOT(columns=[('fullname', fullname)], orderings=[(ordering_0):asc_first]) PROJECT(columns={'fullname': fullname, 'ordering_0': fullname}) 
PROJECT(columns={'fullname': fname}) FILTER(condition=True:bool, columns={'fname': fname}) - JOIN(conditions=[t0.region_key == t1.key_2], types=['inner'], columns={'fname': t1.fname}) - SCAN(table=tpch.NATION, columns={'region_key': n_regionkey}) - PROJECT(columns={'fname': JOIN_STRINGS('-':string, LOWER(name_3), lname), 'key_2': key_2}) - JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key_2': t1.key, 'lname': t0.lname, 'name_3': t1.name}) - PROJECT(columns={'lname': LOWER(name), 'region_key': region_key}) - SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) + JOIN(conditions=[t0.key == t1.key], types=['inner'], columns={'fname': t1.fname}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey}) + PROJECT(columns={'fname': JOIN_STRINGS('-':string, LOWER(name_3), lname), 'key': key}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'lname': t0.lname, 'name_3': t1.name}) + FILTER(condition=True:bool, columns={'key': key, 'lname': lname, 'region_key': region_key}) + PROJECT(columns={'key': key, 'lname': LOWER(name), 'region_key': region_key}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/correl_6.txt b/tests/test_plan_refsols/correl_6.txt index 0a85a6fa..18acd5f9 100644 --- a/tests/test_plan_refsols/correl_6.txt +++ b/tests/test_plan_refsols/correl_6.txt @@ -1,8 +1,11 @@ ROOT(columns=[('name', name), ('n_prefix_nations', n_prefix_nations)], orderings=[]) PROJECT(columns={'n_prefix_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name}) FILTER(condition=True:bool, columns={'agg_0': agg_0, 'name': name}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr1') + JOIN(conditions=[t0.key == t1.key], types=['inner'], columns={'agg_0': t1.agg_0, 'name': t0.name}) 
SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - AGGREGATE(keys={'region_key': region_key}, aggregations={'agg_0': COUNT()}) - FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'region_key': region_key}) - SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) + AGGREGATE(keys={'key': key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(name, None:unknown, 1:int64, None:unknown), columns={'key': key}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name}) + FILTER(condition=True:bool, columns={'key': key, 'name': name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) diff --git a/tests/test_plan_refsols/correl_9.txt b/tests/test_plan_refsols/correl_9.txt index 58564f54..4ee14bcc 100644 --- a/tests/test_plan_refsols/correl_9.txt +++ b/tests/test_plan_refsols/correl_9.txt @@ -2,9 +2,10 @@ ROOT(columns=[('name', name), ('rname', rname)], orderings=[(ordering_0):asc_fir PROJECT(columns={'name': name, 'ordering_0': name, 'rname': rname}) PROJECT(columns={'name': name, 'rname': name_3}) FILTER(condition=True:bool, columns={'name': name, 'name_3': name_3}) - JOIN(conditions=[t0.region_key == t1.key_2], types=['inner'], columns={'name': t0.name, 'name_3': t1.name_3}) - SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) - FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(name, None:unknown, 1:int64, None:unknown), columns={'key_2': key_2, 'name_3': name_3}) - JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name, 'name_3': t1.name}) - SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) + 
JOIN(conditions=[t0.key == t1.key], types=['inner'], columns={'name': t0.name, 'name_3': t1.name_3}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(name, None:unknown, 1:int64, None:unknown), columns={'key': key, 'name_3': name_3}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name}) + FILTER(condition=True:bool, columns={'key': key, 'name': name, 'region_key': region_key}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) From bb24b304c85d06b157e66abd13f70c9bcda26e38 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Sat, 8 Feb 2025 22:54:02 -0500 Subject: [PATCH 098/112] Bugfixes to decorrelation, compressing helper functions --- pydough/conversion/hybrid_decorrelater.py | 47 ++++++++++++----------- tests/test_plan_refsols/correl_1.txt | 10 +++-- tests/test_plan_refsols/correl_15.txt | 15 +++++--- tests/test_plan_refsols/correl_17.txt | 5 +-- tests/test_plan_refsols/correl_2.txt | 23 ++++++----- tests/test_plan_refsols/correl_3.txt | 14 ++++--- tests/test_plan_refsols/correl_6.txt | 3 +- tests/test_plan_refsols/correl_8.txt | 10 ++--- tests/test_plan_refsols/correl_9.txt | 3 +- tests/test_plan_refsols/tpch_q5.txt | 25 +++++++----- 10 files changed, 85 insertions(+), 70 deletions(-) diff --git a/pydough/conversion/hybrid_decorrelater.py b/pydough/conversion/hybrid_decorrelater.py index 3581721a..1363f503 100644 --- a/pydough/conversion/hybrid_decorrelater.py +++ b/pydough/conversion/hybrid_decorrelater.py @@ -41,14 +41,13 @@ def make_decorrelate_parent( if isinstance(hybrid.pipeline[0], HybridPartition) and child_idx == 0: assert hybrid.parent is not None return self.make_decorrelate_parent( - hybrid.parent, len(hybrid.parent.children), required_steps + hybrid.parent, 
len(hybrid.parent.children), len(hybrid.pipeline) ) - successor: HybridTree | None = hybrid.successor hybrid._successor = None new_hybrid: HybridTree = copy.deepcopy(hybrid) - hybrid._successor = successor new_hybrid._children = new_hybrid._children[:child_idx] new_hybrid._pipeline = new_hybrid._pipeline[: required_steps + 1] + # breakpoint() return new_hybrid def remove_correl_refs( @@ -91,8 +90,12 @@ def remove_correl_refs( f"Unsupported expression type: {expr.__class__.__name__}." ) - def decorrelate_singular( - self, old_parent: HybridTree, new_parent: HybridTree, child: HybridConnection + def decorrelate_child( + self, + old_parent: HybridTree, + new_parent: HybridTree, + child: HybridConnection, + is_aggregate: bool, ) -> None: """ TODO @@ -146,19 +149,12 @@ def decorrelate_singular( current_level = current_level.parent additional_levels += 1 child.subtree.join_keys = new_join_keys - - def decorrelate_aggregate( - self, old_parent: HybridTree, new_parent: HybridTree, child: HybridConnection - ) -> None: - """ - TODO - """ - self.decorrelate_singular(old_parent, new_parent, child) - new_agg_keys: list[HybridExpr] = [] - assert child.subtree.join_keys is not None - for _, rhs_key in child.subtree.join_keys: - new_agg_keys.append(rhs_key) - child.subtree.agg_keys = new_agg_keys + if is_aggregate: + new_agg_keys: list[HybridExpr] = [] + assert child.subtree.join_keys is not None + for _, rhs_key in child.subtree.join_keys: + new_agg_keys.append(rhs_key) + child.subtree.agg_keys = new_agg_keys def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: """ @@ -178,13 +174,18 @@ def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: if idx not in hybrid.correlated_children: continue new_parent: HybridTree = self.make_decorrelate_parent( - hybrid, idx, hybrid.children[idx].required_steps + 1 + hybrid, idx, hybrid.children[idx].required_steps ) match child.connection_type: - case ConnectionType.SINGULAR | ConnectionType.SINGULAR_ONLY_MATCH: - 
self.decorrelate_singular(hybrid, new_parent, child) - case ConnectionType.AGGREGATION | ConnectionType.AGGREGATION_ONLY_MATCH: - self.decorrelate_aggregate(hybrid, new_parent, child) + case ( + ConnectionType.SINGULAR + | ConnectionType.SINGULAR_ONLY_MATCH + | ConnectionType.AGGREGATION + | ConnectionType.AGGREGATION_ONLY_MATCH + ): + self.decorrelate_child( + hybrid, new_parent, child, child.connection_type.is_aggregation + ) case ConnectionType.NDISTINCT | ConnectionType.NDISTINCT_ONLY_MATCH: raise NotImplementedError( f"PyDough does not yet support correlated references with the {child.connection_type.name} pattern." diff --git a/tests/test_plan_refsols/correl_1.txt b/tests/test_plan_refsols/correl_1.txt index 135f5242..9ffac13b 100644 --- a/tests/test_plan_refsols/correl_1.txt +++ b/tests/test_plan_refsols/correl_1.txt @@ -1,7 +1,9 @@ ROOT(columns=[('name', name), ('n_prefix_nations', n_prefix_nations)], orderings=[]) PROJECT(columns={'n_prefix_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name}) - JOIN(conditions=[t0.key == t1.region_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr1') + JOIN(conditions=[t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}) SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - AGGREGATE(keys={'region_key': region_key}, aggregations={'agg_0': COUNT()}) - FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'region_key': region_key}) - SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) + AGGREGATE(keys={'key': key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(name, None:unknown, 1:int64, None:unknown), columns={'key': key}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.REGION, 
columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) diff --git a/tests/test_plan_refsols/correl_15.txt b/tests/test_plan_refsols/correl_15.txt index 82d8f929..287fab45 100644 --- a/tests/test_plan_refsols/correl_15.txt +++ b/tests/test_plan_refsols/correl_15.txt @@ -1,23 +1,26 @@ ROOT(columns=[('n', n)], orderings=[]) PROJECT(columns={'n': agg_1}) - JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1}, correl_name='corr4') + JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1}, correl_name='corr5') PROJECT(columns={'avg_price': agg_0}) AGGREGATE(keys={}, aggregations={'agg_0': AVG(retail_price)}) SCAN(table=tpch.PART, columns={'retail_price': p_retailprice}) AGGREGATE(keys={}, aggregations={'agg_1': COUNT()}) FILTER(condition=True:bool, columns={'account_balance': account_balance}) - JOIN(conditions=[t0.key == t1.supplier_key], types=['semi'], columns={'account_balance': t0.account_balance}, correl_name='corr3') - PROJECT(columns={'account_balance': account_balance, 'avg_price': agg_0, 'key': key}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['semi'], columns={'account_balance': t0.account_balance}, correl_name='corr4') + PROJECT(columns={'account_balance': account_balance, 'avg_price_3': agg_0, 'key': key}) JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'agg_0': t1.agg_0, 'key': t0.key}) FILTER(condition=nation_key == 19:int64, columns={'account_balance': account_balance, 'key': key}) - SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) + JOIN(conditions=[True:bool], types=['inner'], columns={'account_balance': t1.account_balance, 'key': t1.key, 'nation_key': t1.nation_key}) + AGGREGATE(keys={}, aggregations={}) + SCAN(table=tpch.PART, columns={'brand': p_brand}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': 
s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': AVG(retail_price)}) JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'retail_price': t1.retail_price, 'supplier_key': t0.supplier_key}) SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey}) SCAN(table=tpch.PART, columns={'key': p_partkey, 'retail_price': p_retailprice}) FILTER(condition=True:bool, columns={'supplier_key': supplier_key}) - JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr2') + JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr3') SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) - FILTER(condition=retail_price < corr2.supplycost * 1.5:float64 & retail_price < corr3.avg_price & retail_price < corr4.avg_price * 0.9:float64, columns={'key': key}) + FILTER(condition=retail_price < corr3.supplycost * 1.5:float64 & retail_price < corr4.avg_price_3 & retail_price < corr5.avg_price * 0.9:float64, columns={'key': key}) FILTER(condition=container == 'LG DRUM':string, columns={'key': key, 'retail_price': retail_price}) SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'retail_price': p_retailprice}) diff --git a/tests/test_plan_refsols/correl_17.txt b/tests/test_plan_refsols/correl_17.txt index 8ed9fe94..aad7c616 100644 --- a/tests/test_plan_refsols/correl_17.txt +++ b/tests/test_plan_refsols/correl_17.txt @@ -6,7 +6,6 @@ ROOT(columns=[('fullname', fullname)], orderings=[(ordering_0):asc_first]) SCAN(table=tpch.NATION, columns={'key': n_nationkey}) PROJECT(columns={'fname': JOIN_STRINGS('-':string, LOWER(name_3), lname), 'key': key}) JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'lname': t0.lname, 'name_3': t1.name}) - 
FILTER(condition=True:bool, columns={'key': key, 'lname': lname, 'region_key': region_key}) - PROJECT(columns={'key': key, 'lname': LOWER(name), 'region_key': region_key}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + PROJECT(columns={'key': key, 'lname': LOWER(name), 'region_key': region_key}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/correl_2.txt b/tests/test_plan_refsols/correl_2.txt index d62b8784..0579c518 100644 --- a/tests/test_plan_refsols/correl_2.txt +++ b/tests/test_plan_refsols/correl_2.txt @@ -1,11 +1,16 @@ -ROOT(columns=[('name', name_7), ('n_selected_custs', n_selected_custs)], orderings=[]) - PROJECT(columns={'n_selected_custs': n_selected_custs, 'name_7': name_6}) - PROJECT(columns={'n_selected_custs': DEFAULT_TO(agg_0, 0:int64), 'name_6': name_3}) - JOIN(conditions=[t0.key_2 == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'name_3': t0.name_3}, correl_name='corr4') - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name, 'name_3': t1.name}) - FILTER(condition=NOT(STARTSWITH(name, 'A':string)), columns={'key': key, 'name': name}) +ROOT(columns=[('name', name_12), ('n_selected_custs', n_selected_custs)], orderings=[]) + PROJECT(columns={'n_selected_custs': n_selected_custs, 'name_12': name_11}) + PROJECT(columns={'n_selected_custs': DEFAULT_TO(agg_0, 0:int64), 'name_11': name_3}) + JOIN(conditions=[t0.key_2 == t1.key_5 & t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name_3': t0.name_3}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name_3': t1.name}) + FILTER(condition=NOT(STARTSWITH(name, 'A':string)), columns={'key': key}) SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) 
SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': COUNT()}) - FILTER(condition=SLICE(comment, None:unknown, 1:int64, None:unknown) == LOWER(SLICE(corr4.name, None:unknown, 1:int64, None:unknown)), columns={'nation_key': nation_key}) - SCAN(table=tpch.CUSTOMER, columns={'comment': c_comment, 'nation_key': c_nationkey}) + AGGREGATE(keys={'key': key, 'key_5': key_5}, aggregations={'agg_0': COUNT()}) + FILTER(condition=SLICE(comment_7, None:unknown, 1:int64, None:unknown) == LOWER(SLICE(name, None:unknown, 1:int64, None:unknown)), columns={'key': key, 'key_5': key_5}) + JOIN(conditions=[t0.key_5 == t1.nation_key], types=['inner'], columns={'comment_7': t1.comment, 'key': t0.key, 'key_5': t0.key_5, 'name': t0.name}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_5': t1.key, 'name': t0.name}) + FILTER(condition=NOT(STARTSWITH(name, 'A':string)), columns={'key': key, 'name': name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.CUSTOMER, columns={'comment': c_comment, 'nation_key': c_nationkey}) diff --git a/tests/test_plan_refsols/correl_3.txt b/tests/test_plan_refsols/correl_3.txt index 8e285cb4..70e5d01e 100644 --- a/tests/test_plan_refsols/correl_3.txt +++ b/tests/test_plan_refsols/correl_3.txt @@ -1,10 +1,12 @@ ROOT(columns=[('name', name), ('n_nations', n_nations)], orderings=[]) PROJECT(columns={'n_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name}) - JOIN(conditions=[t0.key == t1.region_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr1') + JOIN(conditions=[t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr4') SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - 
AGGREGATE(keys={'region_key': region_key}, aggregations={'agg_0': COUNT()}) - FILTER(condition=True:bool, columns={'region_key': region_key}) - JOIN(conditions=[t0.key == t1.nation_key], types=['semi'], columns={'region_key': t0.region_key}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - FILTER(condition=SLICE(comment, None:unknown, 1:int64, None:unknown) == LOWER(SLICE(corr1.name, None:unknown, 1:int64, None:unknown)), columns={'nation_key': nation_key}) + AGGREGATE(keys={'key': key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=True:bool, columns={'key': key}) + JOIN(conditions=[t0.key_2 == t1.nation_key], types=['semi'], columns={'key': t0.key}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + FILTER(condition=SLICE(comment, None:unknown, 1:int64, None:unknown) == LOWER(SLICE(corr4.name, None:unknown, 1:int64, None:unknown)), columns={'nation_key': nation_key}) SCAN(table=tpch.CUSTOMER, columns={'comment': c_comment, 'nation_key': c_nationkey}) diff --git a/tests/test_plan_refsols/correl_6.txt b/tests/test_plan_refsols/correl_6.txt index 18acd5f9..a6829877 100644 --- a/tests/test_plan_refsols/correl_6.txt +++ b/tests/test_plan_refsols/correl_6.txt @@ -6,6 +6,5 @@ ROOT(columns=[('name', name), ('n_prefix_nations', n_prefix_nations)], orderings AGGREGATE(keys={'key': key}, aggregations={'agg_0': COUNT()}) FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(name, None:unknown, 1:int64, None:unknown), columns={'key': key}) JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name}) - FILTER(condition=True:bool, columns={'key': key, 'name': name}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.REGION, 
columns={'key': r_regionkey, 'name': r_name}) SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) diff --git a/tests/test_plan_refsols/correl_8.txt b/tests/test_plan_refsols/correl_8.txt index cf9d5a88..8da228f0 100644 --- a/tests/test_plan_refsols/correl_8.txt +++ b/tests/test_plan_refsols/correl_8.txt @@ -1,9 +1,9 @@ ROOT(columns=[('name', name), ('rname', rname)], orderings=[(ordering_0):asc_first]) PROJECT(columns={'name': name, 'ordering_0': name, 'rname': rname}) PROJECT(columns={'name': name, 'rname': name_3}) - JOIN(conditions=[t0.region_key == t1.key_2], types=['left'], columns={'name': t0.name, 'name_3': t1.name_3}) - SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) - FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(name, None:unknown, 1:int64, None:unknown), columns={'key_2': key_2, 'name_3': name_3}) - JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name, 'name_3': t1.name}) - SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) + JOIN(conditions=[t0.key == t1.key], types=['left'], columns={'name': t0.name, 'name_3': t1.name_3}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(name, None:unknown, 1:int64, None:unknown), columns={'key': key, 'name_3': name_3}) + JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/correl_9.txt b/tests/test_plan_refsols/correl_9.txt index 4ee14bcc..449cceea 100644 --- a/tests/test_plan_refsols/correl_9.txt +++ b/tests/test_plan_refsols/correl_9.txt @@ -6,6 +6,5 @@ ROOT(columns=[('name', name), 
('rname', rname)], orderings=[(ordering_0):asc_fir SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(name, None:unknown, 1:int64, None:unknown), columns={'key': key, 'name_3': name_3}) JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name}) - FILTER(condition=True:bool, columns={'key': key, 'name': name, 'region_key': region_key}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) diff --git a/tests/test_plan_refsols/tpch_q5.txt b/tests/test_plan_refsols/tpch_q5.txt index fbd35207..7cb9925b 100644 --- a/tests/test_plan_refsols/tpch_q5.txt +++ b/tests/test_plan_refsols/tpch_q5.txt @@ -1,21 +1,26 @@ ROOT(columns=[('N_NAME', N_NAME), ('REVENUE', REVENUE)], orderings=[(ordering_1):desc_last]) PROJECT(columns={'N_NAME': N_NAME, 'REVENUE': REVENUE, 'ordering_1': REVENUE}) PROJECT(columns={'N_NAME': name, 'REVENUE': DEFAULT_TO(agg_0, 0:int64)}) - JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr10') + JOIN(conditions=[t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}) FILTER(condition=name_3 == 'ASIA':string, columns={'key': key, 'name': name}) JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name}) SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': SUM(value)}) - PROJECT(columns={'nation_key': nation_key, 'value': extended_price * 1:int64 - discount}) - 
FILTER(condition=name_9 == corr10.name, columns={'discount': discount, 'extended_price': extended_price, 'nation_key': nation_key}) - JOIN(conditions=[t0.supplier_key == t1.key], types=['left'], columns={'discount': t0.discount, 'extended_price': t0.extended_price, 'name_9': t1.name_9, 'nation_key': t0.nation_key}) - JOIN(conditions=[t0.key_5 == t1.order_key], types=['inner'], columns={'discount': t1.discount, 'extended_price': t1.extended_price, 'nation_key': t0.nation_key, 'supplier_key': t1.supplier_key}) - FILTER(condition=order_date >= datetime.date(1994, 1, 1):date & order_date < datetime.date(1995, 1, 1):date, columns={'key_5': key_5, 'nation_key': nation_key}) - JOIN(conditions=[t0.key == t1.customer_key], types=['inner'], columns={'key_5': t1.key, 'nation_key': t0.nation_key, 'order_date': t1.order_date}) - SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) + AGGREGATE(keys={'key': key}, aggregations={'agg_0': SUM(value)}) + PROJECT(columns={'key': key, 'value': extended_price * 1:int64 - discount}) + FILTER(condition=name_15 == name, columns={'discount': discount, 'extended_price': extended_price, 'key': key}) + JOIN(conditions=[t0.supplier_key == t1.key], types=['left'], columns={'discount': t0.discount, 'extended_price': t0.extended_price, 'key': t0.key, 'name': t0.name, 'name_15': t1.name_15}) + JOIN(conditions=[t0.key_11 == t1.order_key], types=['inner'], columns={'discount': t1.discount, 'extended_price': t1.extended_price, 'key': t0.key, 'name': t0.name, 'supplier_key': t1.supplier_key}) + FILTER(condition=order_date >= datetime.date(1994, 1, 1):date & order_date < datetime.date(1995, 1, 1):date, columns={'key': key, 'key_11': key_11, 'name': name}) + JOIN(conditions=[t0.key_8 == t1.customer_key], types=['inner'], columns={'key': t0.key, 'key_11': t1.key, 'name': t0.name, 'order_date': t1.order_date}) + JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'key': t0.key, 'key_8': t1.key, 'name': t0.name}) 
+ FILTER(condition=name_6 == 'ASIA':string, columns={'key': key, 'name': name}) + JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'key': t0.key, 'name': t0.name, 'name_6': t1.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate}) SCAN(table=tpch.LINEITEM, columns={'discount': l_discount, 'extended_price': l_extendedprice, 'order_key': l_orderkey, 'supplier_key': l_suppkey}) - JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_9': t1.name}) + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_15': t1.name}) SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey}) SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) From 8f83903e1cc0aedf1c816f8cc0752c386d207ed4 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Sat, 8 Feb 2025 23:23:14 -0500 Subject: [PATCH 099/112] Filling vlaue for correl tests 1/2/3 --- pydough/conversion/hybrid_decorrelater.py | 41 ++++++++++++++++------- tests/simple_pydough_functions.py | 20 +++++++---- tests/test_pipeline.py | 30 +++++++++++++++-- 3 files changed, 68 insertions(+), 23 deletions(-) diff --git a/pydough/conversion/hybrid_decorrelater.py b/pydough/conversion/hybrid_decorrelater.py index 1363f503..41b0441f 100644 --- a/pydough/conversion/hybrid_decorrelater.py +++ b/pydough/conversion/hybrid_decorrelater.py @@ -90,27 +90,21 @@ def remove_correl_refs( f"Unsupported expression type: {expr.__class__.__name__}." 
) - def decorrelate_child( + def correl_ref_purge( self, + level: HybridTree, old_parent: HybridTree, new_parent: HybridTree, - child: HybridConnection, - is_aggregate: bool, + child_height: int, ) -> None: """ TODO """ - # First, find the height of the child subtree & its top-most level. - child_root: HybridTree = child.subtree - child_height: int = 1 - while child_root.parent is not None: - child_height += 1 - child_root = child_root.parent - # Link the top level of the child subtree to the new parent. - new_parent.add_successor(child_root) - # Replace any correlated references to the original parent with BACK references. - level: HybridTree = child.subtree while level.parent is not None and level is not new_parent: + for child in level.children: + self.correl_ref_purge( + child.subtree, old_parent, new_parent, child_height + ) for operation in level.pipeline: for name, expr in operation.terms.items(): operation.terms[name] = self.remove_correl_refs( @@ -134,6 +128,27 @@ def decorrelate_child( operation.condition, old_parent, child_height ) level = level.parent + + def decorrelate_child( + self, + old_parent: HybridTree, + new_parent: HybridTree, + child: HybridConnection, + is_aggregate: bool, + ) -> None: + """ + TODO + """ + # First, find the height of the child subtree & its top-most level. + child_root: HybridTree = child.subtree + child_height: int = 1 + while child_root.parent is not None: + child_height += 1 + child_root = child_root.parent + # Link the top level of the child subtree to the new parent. + new_parent.add_successor(child_root) + # Replace any correlated references to the original parent with BACK references. + self.correl_ref_purge(child.subtree, old_parent, new_parent, child_height) # Update the join keys to join on the unique keys of all the ancestors. 
new_join_keys: list[tuple[HybridExpr, HybridExpr]] = [] additional_levels: int = 0 diff --git a/tests/simple_pydough_functions.py b/tests/simple_pydough_functions.py index e9230383..443e27a8 100644 --- a/tests/simple_pydough_functions.py +++ b/tests/simple_pydough_functions.py @@ -279,7 +279,7 @@ def correl_1(): # access without requiring the RHS be present. return Regions( name, n_prefix_nations=COUNT(nations.WHERE(name[:1] == BACK(1).name[:1])) - ) + ).ORDER_BY(name.ASC()) def correl_2(): @@ -289,20 +289,26 @@ def correl_2(): # with the letter a. This is a true correlated join doing an aggregated # access without requiring the RHS be present. selected_custs = customers.WHERE(comment[:1] == LOWER(BACK(2).name[:1])) - return Regions.WHERE(~STARTSWITH(name, "A")).nations( - name, - n_selected_custs=COUNT(selected_custs), + return ( + Regions.WHERE(~STARTSWITH(name, "A")) + .nations( + name, + n_selected_custs=COUNT(selected_custs), + ) + .ORDER_BY(name.ASC()) ) def correl_3(): # Correlated back reference example #3: double-layer correlated reference # For every every region, count how many of its nations have a customer - # whose comment starts with the same letter as the region. This is a true + # whose comment starts with the same 2 letter as the region. This is a true # correlated join doing an aggregated access without requiring the RHS be # present. 
- selected_custs = customers.WHERE(comment[:1] == LOWER(BACK(2).name[:1])) - return Regions(name, n_nations=COUNT(nations.WHERE(HAS(selected_custs)))) + selected_custs = customers.WHERE(comment[:2] == LOWER(BACK(2).name[:2])) + return Regions(name, n_nations=COUNT(nations.WHERE(HAS(selected_custs)))).ORDER_BY( + name.ASC() + ) def correl_4(): diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index c1c56e83..82de65e3 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -648,7 +648,7 @@ "correl_1", lambda: pd.DataFrame( { - "name": ["AFRICA" "AMERICA" "MIDDLE EAST" "EUROPE" "ASIA"], + "name": ["AFRICA", "AMERICA", "ASIA", "EUROPE", "MIDDLE EAST"], "n_prefix_nations": [1, 1, 0, 0, 0], } ), @@ -661,7 +661,30 @@ "correl_2", lambda: pd.DataFrame( { - "name": ["A"] * 5, + "name": [ + "EGYPT", + "FRANCE", + "GERMANY", + "IRAN", + "IRAQ", + "JORDAN", + "ROMANIA", + "RUSSIA", + "SAUDI ARABIA", + "UNITED KINGDOM", + ], + "n_selected_custs": [ + 19, + 593, + 595, + 15, + 21, + 9, + 588, + 620, + 19, + 585, + ], } ), ), @@ -673,7 +696,8 @@ "correl_3", lambda: pd.DataFrame( { - "name": ["A"] * 5, + "name": ["AFRICA" "AMERICA" "ASIA" "EUROPE" "MIDDLE EAST"], + "n_nations": [5, 5, 5, 0, 2], } ), ), From b0964433027bc510d440d3caad591c28619ac19c Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Sat, 8 Feb 2025 23:24:01 -0500 Subject: [PATCH 100/112] Pulling up testing completion changes --- tests/simple_pydough_functions.py | 20 ++++++---- tests/test_pipeline.py | 62 +++++++++++++++++++++++++++++-- 2 files changed, 71 insertions(+), 11 deletions(-) diff --git a/tests/simple_pydough_functions.py b/tests/simple_pydough_functions.py index e9230383..443e27a8 100644 --- a/tests/simple_pydough_functions.py +++ b/tests/simple_pydough_functions.py @@ -279,7 +279,7 @@ def correl_1(): # access without requiring the RHS be present. 
return Regions( name, n_prefix_nations=COUNT(nations.WHERE(name[:1] == BACK(1).name[:1])) - ) + ).ORDER_BY(name.ASC()) def correl_2(): @@ -289,20 +289,26 @@ def correl_2(): # with the letter a. This is a true correlated join doing an aggregated # access without requiring the RHS be present. selected_custs = customers.WHERE(comment[:1] == LOWER(BACK(2).name[:1])) - return Regions.WHERE(~STARTSWITH(name, "A")).nations( - name, - n_selected_custs=COUNT(selected_custs), + return ( + Regions.WHERE(~STARTSWITH(name, "A")) + .nations( + name, + n_selected_custs=COUNT(selected_custs), + ) + .ORDER_BY(name.ASC()) ) def correl_3(): # Correlated back reference example #3: double-layer correlated reference # For every every region, count how many of its nations have a customer - # whose comment starts with the same letter as the region. This is a true + # whose comment starts with the same 2 letter as the region. This is a true # correlated join doing an aggregated access without requiring the RHS be # present. 
- selected_custs = customers.WHERE(comment[:1] == LOWER(BACK(2).name[:1])) - return Regions(name, n_nations=COUNT(nations.WHERE(HAS(selected_custs)))) + selected_custs = customers.WHERE(comment[:2] == LOWER(BACK(2).name[:2])) + return Regions(name, n_nations=COUNT(nations.WHERE(HAS(selected_custs)))).ORDER_BY( + name.ASC() + ) def correl_4(): diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index db188ef3..82de65e3 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -648,7 +648,7 @@ "correl_1", lambda: pd.DataFrame( { - "name": ["AFRICA" "AMERICA" "MIDDLE EAST" "EUROPE" "ASIA"], + "name": ["AFRICA", "AMERICA", "ASIA", "EUROPE", "MIDDLE EAST"], "n_prefix_nations": [1, 1, 0, 0, 0], } ), @@ -661,7 +661,30 @@ "correl_2", lambda: pd.DataFrame( { - "name": ["A"] * 5, + "name": [ + "EGYPT", + "FRANCE", + "GERMANY", + "IRAN", + "IRAQ", + "JORDAN", + "ROMANIA", + "RUSSIA", + "SAUDI ARABIA", + "UNITED KINGDOM", + ], + "n_selected_custs": [ + 19, + 593, + 595, + 15, + 21, + 9, + 588, + 620, + 19, + 585, + ], } ), ), @@ -673,7 +696,8 @@ "correl_3", lambda: pd.DataFrame( { - "name": ["A"] * 5, + "name": ["AFRICA" "AMERICA" "ASIA" "EUROPE" "MIDDLE EAST"], + "n_nations": [5, 5, 5, 0, 2], } ), ), @@ -885,7 +909,37 @@ ( correl_17, "correl_17", - lambda: pd.DataFrame({"fullname": [925]}), + lambda: pd.DataFrame( + { + "fullname": [ + "africa-algeria", + "africa-ethiopia", + "africa-kenya", + "africa-morocco", + "africa-mozambique", + "america-argentina", + "america-brazil", + "america-canada", + "america-peru", + "america-united states", + "asia-china", + "asia-india", + "asia-indonesia", + "asia-japan", + "asia-vietnam", + "europe-france", + "europe-germany", + "europe-romania", + "europe-russia", + "europe-united kingdom", + "middle east-egypt", + "middle east-iran", + "middle east-iraq", + "middle east-jordan", + "middle east-saudi arabia", + ] + } + ), ), id="correl_17", ), From e9f2606a21ee191e5fd77289f6cd879a1c6a61a0 Mon Sep 17 00:00:00 2001 From: 
knassre-bodo Date: Sat, 8 Feb 2025 23:24:30 -0500 Subject: [PATCH 101/112] Updating refsols --- tests/test_plan_refsols/correl_1.txt | 15 ++++++++------- tests/test_plan_refsols/correl_2.txt | 23 ++++++++++++----------- tests/test_plan_refsols/correl_3.txt | 21 +++++++++++---------- 3 files changed, 31 insertions(+), 28 deletions(-) diff --git a/tests/test_plan_refsols/correl_1.txt b/tests/test_plan_refsols/correl_1.txt index 135f5242..c5956c80 100644 --- a/tests/test_plan_refsols/correl_1.txt +++ b/tests/test_plan_refsols/correl_1.txt @@ -1,7 +1,8 @@ -ROOT(columns=[('name', name), ('n_prefix_nations', n_prefix_nations)], orderings=[]) - PROJECT(columns={'n_prefix_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name}) - JOIN(conditions=[t0.key == t1.region_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr1') - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - AGGREGATE(keys={'region_key': region_key}, aggregations={'agg_0': COUNT()}) - FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'region_key': region_key}) - SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) +ROOT(columns=[('name', name), ('n_prefix_nations', n_prefix_nations)], orderings=[(ordering_1):asc_first]) + PROJECT(columns={'n_prefix_nations': n_prefix_nations, 'name': name, 'ordering_1': name}) + PROJECT(columns={'n_prefix_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name}) + JOIN(conditions=[t0.key == t1.region_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr1') + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + AGGREGATE(keys={'region_key': region_key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'region_key': region_key}) + SCAN(table=tpch.NATION, 
columns={'name': n_name, 'region_key': n_regionkey}) diff --git a/tests/test_plan_refsols/correl_2.txt b/tests/test_plan_refsols/correl_2.txt index d62b8784..529b06fd 100644 --- a/tests/test_plan_refsols/correl_2.txt +++ b/tests/test_plan_refsols/correl_2.txt @@ -1,11 +1,12 @@ -ROOT(columns=[('name', name_7), ('n_selected_custs', n_selected_custs)], orderings=[]) - PROJECT(columns={'n_selected_custs': n_selected_custs, 'name_7': name_6}) - PROJECT(columns={'n_selected_custs': DEFAULT_TO(agg_0, 0:int64), 'name_6': name_3}) - JOIN(conditions=[t0.key_2 == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'name_3': t0.name_3}, correl_name='corr4') - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name, 'name_3': t1.name}) - FILTER(condition=NOT(STARTSWITH(name, 'A':string)), columns={'key': key, 'name': name}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) - AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': COUNT()}) - FILTER(condition=SLICE(comment, None:unknown, 1:int64, None:unknown) == LOWER(SLICE(corr4.name, None:unknown, 1:int64, None:unknown)), columns={'nation_key': nation_key}) - SCAN(table=tpch.CUSTOMER, columns={'comment': c_comment, 'nation_key': c_nationkey}) +ROOT(columns=[('name', name_7), ('n_selected_custs', n_selected_custs)], orderings=[(ordering_1):asc_first]) + PROJECT(columns={'n_selected_custs': n_selected_custs, 'name_7': name_6, 'ordering_1': ordering_1}) + PROJECT(columns={'n_selected_custs': n_selected_custs, 'name_6': name_6, 'ordering_1': name_6}) + PROJECT(columns={'n_selected_custs': DEFAULT_TO(agg_0, 0:int64), 'name_6': name_3}) + JOIN(conditions=[t0.key_2 == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'name_3': t0.name_3}, correl_name='corr4') + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': 
t1.key, 'name': t0.name, 'name_3': t1.name}) + FILTER(condition=NOT(STARTSWITH(name, 'A':string)), columns={'key': key, 'name': name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=SLICE(comment, None:unknown, 1:int64, None:unknown) == LOWER(SLICE(corr4.name, None:unknown, 1:int64, None:unknown)), columns={'nation_key': nation_key}) + SCAN(table=tpch.CUSTOMER, columns={'comment': c_comment, 'nation_key': c_nationkey}) diff --git a/tests/test_plan_refsols/correl_3.txt b/tests/test_plan_refsols/correl_3.txt index 8e285cb4..57d2dfdf 100644 --- a/tests/test_plan_refsols/correl_3.txt +++ b/tests/test_plan_refsols/correl_3.txt @@ -1,10 +1,11 @@ -ROOT(columns=[('name', name), ('n_nations', n_nations)], orderings=[]) - PROJECT(columns={'n_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name}) - JOIN(conditions=[t0.key == t1.region_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr1') - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - AGGREGATE(keys={'region_key': region_key}, aggregations={'agg_0': COUNT()}) - FILTER(condition=True:bool, columns={'region_key': region_key}) - JOIN(conditions=[t0.key == t1.nation_key], types=['semi'], columns={'region_key': t0.region_key}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - FILTER(condition=SLICE(comment, None:unknown, 1:int64, None:unknown) == LOWER(SLICE(corr1.name, None:unknown, 1:int64, None:unknown)), columns={'nation_key': nation_key}) - SCAN(table=tpch.CUSTOMER, columns={'comment': c_comment, 'nation_key': c_nationkey}) +ROOT(columns=[('name', name), ('n_nations', n_nations)], orderings=[(ordering_1):asc_first]) + PROJECT(columns={'n_nations': n_nations, 'name': name, 'ordering_1': name}) + 
PROJECT(columns={'n_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name}) + JOIN(conditions=[t0.key == t1.region_key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr1') + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + AGGREGATE(keys={'region_key': region_key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=True:bool, columns={'region_key': region_key}) + JOIN(conditions=[t0.key == t1.nation_key], types=['semi'], columns={'region_key': t0.region_key}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + FILTER(condition=SLICE(comment, None:unknown, 2:int64, None:unknown) == LOWER(SLICE(corr1.name, None:unknown, 2:int64, None:unknown)), columns={'nation_key': nation_key}) + SCAN(table=tpch.CUSTOMER, columns={'comment': c_comment, 'nation_key': c_nationkey}) From 0a00ceab126c3c40b59e19c1bae60d6b7fb78ca4 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 10 Feb 2025 11:12:25 -0500 Subject: [PATCH 102/112] Resolving issues for correl #3 --- pydough/conversion/hybrid_decorrelater.py | 14 ++++++---- tests/test_pipeline.py | 2 +- tests/test_plan_refsols/correl_1.txt | 19 ++++++------- tests/test_plan_refsols/correl_15.txt | 22 +++++++-------- tests/test_plan_refsols/correl_2.txt | 33 ++++++++++++----------- tests/test_plan_refsols/correl_3.txt | 25 ++++++++--------- 6 files changed, 61 insertions(+), 54 deletions(-) diff --git a/pydough/conversion/hybrid_decorrelater.py b/pydough/conversion/hybrid_decorrelater.py index 41b0441f..48a1329d 100644 --- a/pydough/conversion/hybrid_decorrelater.py +++ b/pydough/conversion/hybrid_decorrelater.py @@ -58,9 +58,13 @@ def remove_correl_refs( """ match expr: case HybridCorrelExpr(): - result: HybridExpr | None = expr.expr.shift_back(child_height) - assert result is not None - return result + if expr.hybrid is parent: + result: HybridExpr | None = expr.expr.shift_back(child_height) + assert result is not None + return result + else: + 
expr.expr = self.remove_correl_refs(expr.expr, parent, child_height) + return expr case HybridFunctionExpr(): for idx, arg in enumerate(expr.args): expr.args[idx] = self.remove_correl_refs(arg, parent, child_height) @@ -92,7 +96,7 @@ def remove_correl_refs( def correl_ref_purge( self, - level: HybridTree, + level: HybridTree | None, old_parent: HybridTree, new_parent: HybridTree, child_height: int, @@ -100,7 +104,7 @@ def correl_ref_purge( """ TODO """ - while level.parent is not None and level is not new_parent: + while level is not None and level is not new_parent: for child in level.children: self.correl_ref_purge( child.subtree, old_parent, new_parent, child_height diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 82de65e3..d09672f8 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -696,7 +696,7 @@ "correl_3", lambda: pd.DataFrame( { - "name": ["AFRICA" "AMERICA" "ASIA" "EUROPE" "MIDDLE EAST"], + "name": ["AFRICA", "AMERICA", "ASIA", "EUROPE", "MIDDLE EAST"], "n_nations": [5, 5, 5, 0, 2], } ), diff --git a/tests/test_plan_refsols/correl_1.txt b/tests/test_plan_refsols/correl_1.txt index 9ffac13b..bcc6d73d 100644 --- a/tests/test_plan_refsols/correl_1.txt +++ b/tests/test_plan_refsols/correl_1.txt @@ -1,9 +1,10 @@ -ROOT(columns=[('name', name), ('n_prefix_nations', n_prefix_nations)], orderings=[]) - PROJECT(columns={'n_prefix_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name}) - JOIN(conditions=[t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - AGGREGATE(keys={'key': key}, aggregations={'agg_0': COUNT()}) - FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(name, None:unknown, 1:int64, None:unknown), columns={'key': key}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 
'name': r_name}) - SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) +ROOT(columns=[('name', name), ('n_prefix_nations', n_prefix_nations)], orderings=[(ordering_1):asc_first]) + PROJECT(columns={'n_prefix_nations': n_prefix_nations, 'name': name, 'ordering_1': name}) + PROJECT(columns={'n_prefix_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name}) + JOIN(conditions=[t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + AGGREGATE(keys={'key': key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(name, None:unknown, 1:int64, None:unknown), columns={'key': key}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey}) diff --git a/tests/test_plan_refsols/correl_15.txt b/tests/test_plan_refsols/correl_15.txt index 287fab45..0669fd5d 100644 --- a/tests/test_plan_refsols/correl_15.txt +++ b/tests/test_plan_refsols/correl_15.txt @@ -1,18 +1,18 @@ ROOT(columns=[('n', n)], orderings=[]) PROJECT(columns={'n': agg_1}) - JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1}, correl_name='corr5') - PROJECT(columns={'avg_price': agg_0}) - AGGREGATE(keys={}, aggregations={'agg_0': AVG(retail_price)}) - SCAN(table=tpch.PART, columns={'retail_price': p_retailprice}) + JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1}) + AGGREGATE(keys={}, aggregations={}) + SCAN(table=tpch.PART, columns={'brand': p_brand}) AGGREGATE(keys={}, aggregations={'agg_1': COUNT()}) FILTER(condition=True:bool, columns={'account_balance': account_balance}) JOIN(conditions=[t0.key == t1.supplier_key], types=['semi'], columns={'account_balance': t0.account_balance}, 
correl_name='corr4') - PROJECT(columns={'account_balance': account_balance, 'avg_price_3': agg_0, 'key': key}) - JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'agg_0': t1.agg_0, 'key': t0.key}) - FILTER(condition=nation_key == 19:int64, columns={'account_balance': account_balance, 'key': key}) - JOIN(conditions=[True:bool], types=['inner'], columns={'account_balance': t1.account_balance, 'key': t1.key, 'nation_key': t1.nation_key}) - AGGREGATE(keys={}, aggregations={}) - SCAN(table=tpch.PART, columns={'brand': p_brand}) + PROJECT(columns={'account_balance': account_balance, 'avg_price': avg_price, 'avg_price_3': agg_0, 'key': key}) + JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'agg_0': t1.agg_0, 'avg_price': t0.avg_price, 'key': t0.key}) + FILTER(condition=nation_key == 19:int64, columns={'account_balance': account_balance, 'avg_price': avg_price, 'key': key}) + JOIN(conditions=[True:bool], types=['inner'], columns={'account_balance': t1.account_balance, 'avg_price': t0.avg_price, 'key': t1.key, 'nation_key': t1.nation_key}) + PROJECT(columns={'avg_price': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': AVG(retail_price)}) + SCAN(table=tpch.PART, columns={'retail_price': p_retailprice}) SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': AVG(retail_price)}) JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'retail_price': t1.retail_price, 'supplier_key': t0.supplier_key}) @@ -21,6 +21,6 @@ ROOT(columns=[('n', n)], orderings=[]) FILTER(condition=True:bool, columns={'supplier_key': supplier_key}) JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr3') SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': 
ps_suppkey, 'supplycost': ps_supplycost}) - FILTER(condition=retail_price < corr3.supplycost * 1.5:float64 & retail_price < corr4.avg_price_3 & retail_price < corr5.avg_price * 0.9:float64, columns={'key': key}) + FILTER(condition=retail_price < corr3.supplycost * 1.5:float64 & retail_price < corr4.avg_price_3 & retail_price < corr4.avg_price * 0.9:float64, columns={'key': key}) FILTER(condition=container == 'LG DRUM':string, columns={'key': key, 'retail_price': retail_price}) SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'retail_price': p_retailprice}) diff --git a/tests/test_plan_refsols/correl_2.txt b/tests/test_plan_refsols/correl_2.txt index 0579c518..b5d64b7c 100644 --- a/tests/test_plan_refsols/correl_2.txt +++ b/tests/test_plan_refsols/correl_2.txt @@ -1,16 +1,17 @@ -ROOT(columns=[('name', name_12), ('n_selected_custs', n_selected_custs)], orderings=[]) - PROJECT(columns={'n_selected_custs': n_selected_custs, 'name_12': name_11}) - PROJECT(columns={'n_selected_custs': DEFAULT_TO(agg_0, 0:int64), 'name_11': name_3}) - JOIN(conditions=[t0.key_2 == t1.key_5 & t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name_3': t0.name_3}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name_3': t1.name}) - FILTER(condition=NOT(STARTSWITH(name, 'A':string)), columns={'key': key}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) - AGGREGATE(keys={'key': key, 'key_5': key_5}, aggregations={'agg_0': COUNT()}) - FILTER(condition=SLICE(comment_7, None:unknown, 1:int64, None:unknown) == LOWER(SLICE(name, None:unknown, 1:int64, None:unknown)), columns={'key': key, 'key_5': key_5}) - JOIN(conditions=[t0.key_5 == t1.nation_key], types=['inner'], columns={'comment_7': t1.comment, 'key': t0.key, 'key_5': t0.key_5, 'name': t0.name}) - JOIN(conditions=[t0.key == 
t1.region_key], types=['inner'], columns={'key': t0.key, 'key_5': t1.key, 'name': t0.name}) - FILTER(condition=NOT(STARTSWITH(name, 'A':string)), columns={'key': key, 'name': name}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - SCAN(table=tpch.CUSTOMER, columns={'comment': c_comment, 'nation_key': c_nationkey}) +ROOT(columns=[('name', name_12), ('n_selected_custs', n_selected_custs)], orderings=[(ordering_1):asc_first]) + PROJECT(columns={'n_selected_custs': n_selected_custs, 'name_12': name_11, 'ordering_1': ordering_1}) + PROJECT(columns={'n_selected_custs': n_selected_custs, 'name_11': name_11, 'ordering_1': name_11}) + PROJECT(columns={'n_selected_custs': DEFAULT_TO(agg_0, 0:int64), 'name_11': name_3}) + JOIN(conditions=[t0.key_2 == t1.key_5 & t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name_3': t0.name_3}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name_3': t1.name}) + FILTER(condition=NOT(STARTSWITH(name, 'A':string)), columns={'key': key}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey}) + AGGREGATE(keys={'key': key, 'key_5': key_5}, aggregations={'agg_0': COUNT()}) + FILTER(condition=SLICE(comment_7, None:unknown, 1:int64, None:unknown) == LOWER(SLICE(name, None:unknown, 1:int64, None:unknown)), columns={'key': key, 'key_5': key_5}) + JOIN(conditions=[t0.key_5 == t1.nation_key], types=['inner'], columns={'comment_7': t1.comment, 'key': t0.key, 'key_5': t0.key_5, 'name': t0.name}) + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_5': t1.key, 'name': t0.name}) + FILTER(condition=NOT(STARTSWITH(name, 'A':string)), columns={'key': key, 'name': name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': 
r_name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + SCAN(table=tpch.CUSTOMER, columns={'comment': c_comment, 'nation_key': c_nationkey}) diff --git a/tests/test_plan_refsols/correl_3.txt b/tests/test_plan_refsols/correl_3.txt index 70e5d01e..2bbb01bc 100644 --- a/tests/test_plan_refsols/correl_3.txt +++ b/tests/test_plan_refsols/correl_3.txt @@ -1,12 +1,13 @@ -ROOT(columns=[('name', name), ('n_nations', n_nations)], orderings=[]) - PROJECT(columns={'n_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name}) - JOIN(conditions=[t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}, correl_name='corr4') - SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) - AGGREGATE(keys={'key': key}, aggregations={'agg_0': COUNT()}) - FILTER(condition=True:bool, columns={'key': key}) - JOIN(conditions=[t0.key_2 == t1.nation_key], types=['semi'], columns={'key': t0.key}) - JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key}) - SCAN(table=tpch.REGION, columns={'key': r_regionkey}) - SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) - FILTER(condition=SLICE(comment, None:unknown, 1:int64, None:unknown) == LOWER(SLICE(corr4.name, None:unknown, 1:int64, None:unknown)), columns={'nation_key': nation_key}) - SCAN(table=tpch.CUSTOMER, columns={'comment': c_comment, 'nation_key': c_nationkey}) +ROOT(columns=[('name', name), ('n_nations', n_nations)], orderings=[(ordering_1):asc_first]) + PROJECT(columns={'n_nations': n_nations, 'name': name, 'ordering_1': name}) + PROJECT(columns={'n_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name}) + JOIN(conditions=[t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + AGGREGATE(keys={'key': key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=True:bool, columns={'key': key}) + 
JOIN(conditions=[t0.key_2 == t1.nation_key], types=['semi'], columns={'key': t0.key}, correl_name='corr4') + JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name': t0.name}) + SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey}) + FILTER(condition=SLICE(comment, None:unknown, 2:int64, None:unknown) == LOWER(SLICE(corr4.name, None:unknown, 2:int64, None:unknown)), columns={'nation_key': nation_key}) + SCAN(table=tpch.CUSTOMER, columns={'comment': c_comment, 'nation_key': c_nationkey}) From d1e1520d980aa726a1520d63df9cdd202dcb923c Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 10 Feb 2025 11:14:16 -0500 Subject: [PATCH 103/112] Fixing correl #3 test output --- tests/test_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 82de65e3..d09672f8 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -696,7 +696,7 @@ "correl_3", lambda: pd.DataFrame( { - "name": ["AFRICA" "AMERICA" "ASIA" "EUROPE" "MIDDLE EAST"], + "name": ["AFRICA", "AMERICA", "ASIA", "EUROPE", "MIDDLE EAST"], "n_nations": [5, 5, 5, 0, 2], } ), From fd37df11271d03372d37fc5650e06318718c8286 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 10 Feb 2025 11:16:41 -0500 Subject: [PATCH 104/112] Moving correlated tests to their own file --- tests/correlated_pydough_functions.py | 255 ++++++++++++++++++++++++++ tests/simple_pydough_functions.py | 250 +------------------------ tests/test_pipeline.py | 6 +- 3 files changed, 262 insertions(+), 249 deletions(-) create mode 100644 tests/correlated_pydough_functions.py diff --git a/tests/correlated_pydough_functions.py b/tests/correlated_pydough_functions.py new file mode 100644 index 00000000..1e161871 --- /dev/null +++ b/tests/correlated_pydough_functions.py @@ -0,0 +1,255 @@ +""" +Variant of 
`simple_pydough_functions.py` for functions testing edge cases in +correlation & de-correlation handling. +""" + +# ruff: noqa +# mypy: ignore-errors +# ruff & mypy should not try to typecheck or verify any of this + + +def correl_1(): + # Correlated back reference example #1: simple 1-step correlated reference + # For each region, count how many of its its nations start with the same + # letter as the region. This is a true correlated join doing an aggregated + # access without requiring the RHS be present. + return Regions( + name, n_prefix_nations=COUNT(nations.WHERE(name[:1] == BACK(1).name[:1])) + ).ORDER_BY(name.ASC()) + + +def correl_2(): + # Correlated back reference example #2: simple 2-step correlated reference + # For each region's nations, count how many customers have a comment + # starting with the same letter as the region. Exclude regions that start + # with the letter a. This is a true correlated join doing an aggregated + # access without requiring the RHS be present. + selected_custs = customers.WHERE(comment[:1] == LOWER(BACK(2).name[:1])) + return ( + Regions.WHERE(~STARTSWITH(name, "A")) + .nations( + name, + n_selected_custs=COUNT(selected_custs), + ) + .ORDER_BY(name.ASC()) + ) + + +def correl_3(): + # Correlated back reference example #3: double-layer correlated reference + # For every every region, count how many of its nations have a customer + # whose comment starts with the same 2 letter as the region. This is a true + # correlated join doing an aggregated access without requiring the RHS be + # present. + selected_custs = customers.WHERE(comment[:2] == LOWER(BACK(2).name[:2])) + return Regions(name, n_nations=COUNT(nations.WHERE(HAS(selected_custs)))).ORDER_BY( + name.ASC() + ) + + +def correl_4(): + # Correlated back reference example #4: 2-step correlated HASNOT + # Find every nation that does not have a customer whose account balance is + # within $5 of the smallest known account balance globally. 
+ # (This is a correlated ANTI-join) + selected_customers = customers.WHERE(acctbal <= (BACK(2).smallest_bal + 5.0)) + return ( + TPCH( + smallest_bal=MIN(Customers.acctbal), + ) + .Nations(name) + .WHERE(HASNOT(selected_customers)) + .ORDER_BY(name.ASC()) + ) + + +def correl_5(): + # Correlated back reference example #5: 2-step correlated HAS + # Find every region that has at least 1 supplier whose account balance is + # within $4 of the smallest known account balance globally. + # (This is a correlated SEMI-join) + selected_suppliers = nations.suppliers.WHERE( + account_balance <= (BACK(3).smallest_bal + 4.0) + ) + return ( + TPCH( + smallest_bal=MIN(Suppliers.account_balance), + ) + .Regions(name) + .WHERE(HAS(selected_suppliers)) + .ORDER_BY(name.ASC()) + ) + + +def correl_6(): + # Correlated back reference example #6: simple 1-step correlated reference + # For each region, count how many of its its nations start with the same + # letter as the region, but only keep regions with at least one such nation. + # This is a true correlated join doing an aggregated access that does NOT + # require that records without the RHS be kept. + selected_nations = nations.WHERE(name[:1] == BACK(1).name[:1]) + return Regions.WHERE(HAS(selected_nations))( + name, n_prefix_nations=COUNT(selected_nations) + ) + + +def correl_7(): + # Correlated back reference example #6: deleted correlated reference + # For each region, count how many of its its nations start with the same + # letter as the region, but only keep regions without at least one such + # nation. The true correlated join is trumped by the correlated ANTI-join. 
+ selected_nations = nations.WHERE(name[:1] == BACK(1).name[:1]) + return Regions.WHERE(HASNOT(selected_nations))( + name, n_prefix_nations=COUNT(selected_nations) + ) + + +def correl_8(): + # Correlated back reference example #8: non-agg correlated reference + # For each nation, fetch the name of its region, but filter the reigon + # so it only keeps it if it starts with the same letter as the nation + # (otherwise, returns NULL). This is a true correlated join doing an + # access without aggregation without requiring the RHS be + # present. + aug_region = region.WHERE(name[:1] == BACK(1).name[:1]) + return Nations(name, rname=aug_region.name).ORDER_BY(name.ASC()) + + +def correl_9(): + # Correlated back reference example #9: non-agg correlated reference + # For each nation, fetch the name of its region, but filter the reigon + # so it only keeps it if it starts with the same letter as the nation + # (otherwise, omit the nation). This is a true correlated join doing an + # access that also requires the RHS records be present. + aug_region = region.WHERE(name[:1] == BACK(1).name[:1]) + return Nations.WHERE(HAS(aug_region))(name, rname=aug_region.name).ORDER_BY( + name.ASC() + ) + + +def correl_10(): + # Correlated back reference example #10: deleted correlated reference + # For each nation, fetch the name of its region, but filter the reigon + # so it only keeps it if it starts with the same letter as the nation + # (otherwise, returns NULL), and also filter the nations to only keep + # records where the region is NULL. The true correlated join is trumped by + # the correlated ANTI-join. + aug_region = region.WHERE(name[:1] == BACK(1).name[:1]) + return Nations.WHERE(HASNOT(aug_region))(name, rname=aug_region.name).ORDER_BY( + name.ASC() + ) + + +def correl_11(): + # Correlated back reference example #11: backref out of partition child. + # Which part brands have at least 1 part that more than 40% above the + # average retail price for all parts from that brand. 
+ # (This is a correlated SEMI-join) + brands = PARTITION(Parts, name="p", by=brand)(avg_price=AVG(p.retail_price)) + outlier_parts = p.WHERE(retail_price > 1.4 * BACK(1).avg_price) + selected_brands = brands.WHERE(HAS(outlier_parts)) + return selected_brands(brand).ORDER_BY(brand.ASC()) + + +def correl_12(): + # Correlated back reference example #12: backref out of partition child. + # Which part brands have at least 1 part that is above the average retail + # price for parts of that brand, below the average retail price for all + # parts, and has a size below 3. + # (This is a correlated SEMI-join) + global_info = TPCH(avg_price=AVG(Parts.retail_price)) + brands = global_info.PARTITION(Parts, name="p", by=brand)( + avg_price=AVG(p.retail_price) + ) + selected_parts = p.WHERE( + (retail_price > BACK(1).avg_price) + & (retail_price < BACK(2).avg_price) + & (size < 3) + ) + selected_brands = brands.WHERE(HAS(selected_parts)) + return selected_brands(brand).ORDER_BY(brand.ASC()) + + +def correl_13(): + # Correlated back reference example #13: multiple correlation. + # Count how many suppliers sell at least one part where the retail price + # is less than a 50% markup over the supply cost. Only considers suppliers + # from nations #1/#2/#3, and small parts. + # (This is a correlated SEMI-joins) + selected_part = part.WHERE(STARTSWITH(container, "SM")).WHERE( + retail_price < (BACK(1).supplycost * 1.5) + ) + selected_supply_records = supply_records.WHERE(HAS(selected_part)) + supplier_info = Suppliers.WHERE(nation_key <= 3)( + avg_price=AVG(supply_records.part.retail_price) + ) + selected_suppliers = supplier_info.WHERE(COUNT(selected_supply_records) > 0) + return TPCH(n=COUNT(selected_suppliers)) + + +def correl_14(): + # Correlated back reference example #14: multiple correlation. 
+ # Count how many suppliers sell at least one part where the retail price + # is less than a 50% markup over the supply cost, and the retail price of + # the part is below the average for all parts from the supplier. Only + # considers suppliers from nations #19, and LG DRUM parts. + # (This is multiple correlated SEMI-joins) + selected_part = part.WHERE(container == "LG DRUM").WHERE( + (retail_price < (BACK(1).supplycost * 1.5)) & (retail_price < BACK(2).avg_price) + ) + selected_supply_records = supply_records.WHERE(HAS(selected_part)) + supplier_info = Suppliers.WHERE(nation_key == 19)( + avg_price=AVG(supply_records.part.retail_price) + ) + selected_suppliers = supplier_info.WHERE(HAS(selected_supply_records)) + return TPCH(n=COUNT(selected_suppliers)) + + +def correl_15(): + # Correlated back reference example #15: multiple correlation. + # Count how many suppliers sell at least one part where the retail price + # is less than a 50% markup over the supply cost, and the retail price of + # the part is below the 90% of the average of the retail price for all + # parts globally and below the average for all parts from the supplier. + # Only considers suppliers from nations #19, and LG DRUM parts. + # (This is multiple correlated SEMI-joins & a correlated aggregate) + selected_part = part.WHERE(container == "LG DRUM").WHERE( + (retail_price < (BACK(1).supplycost * 1.5)) + & (retail_price < BACK(2).avg_price) + & (retail_price < BACK(3).avg_price * 0.9) + ) + selected_supply_records = supply_records.WHERE(HAS(selected_part)) + supplier_info = Suppliers.WHERE(nation_key == 19)( + avg_price=AVG(supply_records.part.retail_price) + ) + selected_suppliers = supplier_info.WHERE(HAS(selected_supply_records)) + global_info = TPCH(avg_price=AVG(Parts.retail_price)) + return global_info(n=COUNT(selected_suppliers)) + + +def correl_16(): + # Correlated back reference example #16: hybrid tree order of operations. 
+ # Count how many european suppliers have the exact same percentile value + # of account balance (relative to all other suppliers) as at least one + # customer's percentile value of account balance relative to all other + # customers. Percentile should be measured down to increments of 0.01%. + # (This is a correlated SEMI-joins) + selected_customers = nation(rname=region.name).customers.WHERE( + (PERCENTILE(by=(acctbal.ASC(), key.ASC()), n_buckets=10000) == BACK(2).tile) + & (BACK(1).rname == "EUROPE") + ) + supplier_info = Suppliers( + tile=PERCENTILE(by=(account_balance.ASC(), key.ASC()), n_buckets=10000) + ) + selected_suppliers = supplier_info.WHERE(HAS(selected_customers)) + return TPCH(n=COUNT(selected_suppliers)) + + +def correl_17(): + # Correlated back reference example #17: hybrid tree order of operations. + # An extremely roundabout way of getting each region_name-nation_name + # pair as a string. + # (This is a correlated singular/semi access) + region_info = region(fname=JOIN_STRINGS("-", LOWER(name), BACK(1).lname)) + nation_info = Nations(lname=LOWER(name)).WHERE(HAS(region_info)) + return nation_info(fullname=region_info.fname).ORDER_BY(fullname.ASC()) diff --git a/tests/simple_pydough_functions.py b/tests/simple_pydough_functions.py index 443e27a8..f74ca874 100644 --- a/tests/simple_pydough_functions.py +++ b/tests/simple_pydough_functions.py @@ -1,3 +1,6 @@ +""" +Various functions containing PyDough code snippets for testing purposes. +""" # ruff: noqa # mypy: ignore-errors # ruff & mypy should not try to typecheck or verify any of this @@ -272,253 +275,6 @@ def triple_partition(): ) -def correl_1(): - # Correlated back reference example #1: simple 1-step correlated reference - # For each region, count how many of its its nations start with the same - # letter as the region. This is a true correlated join doing an aggregated - # access without requiring the RHS be present. 
- return Regions( - name, n_prefix_nations=COUNT(nations.WHERE(name[:1] == BACK(1).name[:1])) - ).ORDER_BY(name.ASC()) - - -def correl_2(): - # Correlated back reference example #2: simple 2-step correlated reference - # For each region's nations, count how many customers have a comment - # starting with the same letter as the region. Exclude regions that start - # with the letter a. This is a true correlated join doing an aggregated - # access without requiring the RHS be present. - selected_custs = customers.WHERE(comment[:1] == LOWER(BACK(2).name[:1])) - return ( - Regions.WHERE(~STARTSWITH(name, "A")) - .nations( - name, - n_selected_custs=COUNT(selected_custs), - ) - .ORDER_BY(name.ASC()) - ) - - -def correl_3(): - # Correlated back reference example #3: double-layer correlated reference - # For every every region, count how many of its nations have a customer - # whose comment starts with the same 2 letter as the region. This is a true - # correlated join doing an aggregated access without requiring the RHS be - # present. - selected_custs = customers.WHERE(comment[:2] == LOWER(BACK(2).name[:2])) - return Regions(name, n_nations=COUNT(nations.WHERE(HAS(selected_custs)))).ORDER_BY( - name.ASC() - ) - - -def correl_4(): - # Correlated back reference example #4: 2-step correlated HASNOT - # Find every nation that does not have a customer whose account balance is - # within $5 of the smallest known account balance globally. - # (This is a correlated ANTI-join) - selected_customers = customers.WHERE(acctbal <= (BACK(2).smallest_bal + 5.0)) - return ( - TPCH( - smallest_bal=MIN(Customers.acctbal), - ) - .Nations(name) - .WHERE(HASNOT(selected_customers)) - .ORDER_BY(name.ASC()) - ) - - -def correl_5(): - # Correlated back reference example #5: 2-step correlated HAS - # Find every region that has at least 1 supplier whose account balance is - # within $4 of the smallest known account balance globally. 
- # (This is a correlated SEMI-join) - selected_suppliers = nations.suppliers.WHERE( - account_balance <= (BACK(3).smallest_bal + 4.0) - ) - return ( - TPCH( - smallest_bal=MIN(Suppliers.account_balance), - ) - .Regions(name) - .WHERE(HAS(selected_suppliers)) - .ORDER_BY(name.ASC()) - ) - - -def correl_6(): - # Correlated back reference example #6: simple 1-step correlated reference - # For each region, count how many of its its nations start with the same - # letter as the region, but only keep regions with at least one such nation. - # This is a true correlated join doing an aggregated access that does NOT - # require that records without the RHS be kept. - selected_nations = nations.WHERE(name[:1] == BACK(1).name[:1]) - return Regions.WHERE(HAS(selected_nations))( - name, n_prefix_nations=COUNT(selected_nations) - ) - - -def correl_7(): - # Correlated back reference example #6: deleted correlated reference - # For each region, count how many of its its nations start with the same - # letter as the region, but only keep regions without at least one such - # nation. The true correlated join is trumped by the correlated ANTI-join. - selected_nations = nations.WHERE(name[:1] == BACK(1).name[:1]) - return Regions.WHERE(HASNOT(selected_nations))( - name, n_prefix_nations=COUNT(selected_nations) - ) - - -def correl_8(): - # Correlated back reference example #8: non-agg correlated reference - # For each nation, fetch the name of its region, but filter the reigon - # so it only keeps it if it starts with the same letter as the nation - # (otherwise, returns NULL). This is a true correlated join doing an - # access without aggregation without requiring the RHS be - # present. 
- aug_region = region.WHERE(name[:1] == BACK(1).name[:1]) - return Nations(name, rname=aug_region.name).ORDER_BY(name.ASC()) - - -def correl_9(): - # Correlated back reference example #9: non-agg correlated reference - # For each nation, fetch the name of its region, but filter the reigon - # so it only keeps it if it starts with the same letter as the nation - # (otherwise, omit the nation). This is a true correlated join doing an - # access that also requires the RHS records be present. - aug_region = region.WHERE(name[:1] == BACK(1).name[:1]) - return Nations.WHERE(HAS(aug_region))(name, rname=aug_region.name).ORDER_BY( - name.ASC() - ) - - -def correl_10(): - # Correlated back reference example #10: deleted correlated reference - # For each nation, fetch the name of its region, but filter the reigon - # so it only keeps it if it starts with the same letter as the nation - # (otherwise, returns NULL), and also filter the nations to only keep - # records where the region is NULL. The true correlated join is trumped by - # the correlated ANTI-join. - aug_region = region.WHERE(name[:1] == BACK(1).name[:1]) - return Nations.WHERE(HASNOT(aug_region))(name, rname=aug_region.name).ORDER_BY( - name.ASC() - ) - - -def correl_11(): - # Correlated back reference example #11: backref out of partition child. - # Which part brands have at least 1 part that more than 40% above the - # average retail price for all parts from that brand. - # (This is a correlated SEMI-join) - brands = PARTITION(Parts, name="p", by=brand)(avg_price=AVG(p.retail_price)) - outlier_parts = p.WHERE(retail_price > 1.4 * BACK(1).avg_price) - selected_brands = brands.WHERE(HAS(outlier_parts)) - return selected_brands(brand).ORDER_BY(brand.ASC()) - - -def correl_12(): - # Correlated back reference example #12: backref out of partition child. 
- # Which part brands have at least 1 part that is above the average retail - # price for parts of that brand, below the average retail price for all - # parts, and has a size below 3. - # (This is a correlated SEMI-join) - global_info = TPCH(avg_price=AVG(Parts.retail_price)) - brands = global_info.PARTITION(Parts, name="p", by=brand)( - avg_price=AVG(p.retail_price) - ) - selected_parts = p.WHERE( - (retail_price > BACK(1).avg_price) - & (retail_price < BACK(2).avg_price) - & (size < 3) - ) - selected_brands = brands.WHERE(HAS(selected_parts)) - return selected_brands(brand).ORDER_BY(brand.ASC()) - - -def correl_13(): - # Correlated back reference example #13: multiple correlation. - # Count how many suppliers sell at least one part where the retail price - # is less than a 50% markup over the supply cost. Only considers suppliers - # from nations #1/#2/#3, and small parts. - # (This is a correlated SEMI-joins) - selected_part = part.WHERE(STARTSWITH(container, "SM")).WHERE( - retail_price < (BACK(1).supplycost * 1.5) - ) - selected_supply_records = supply_records.WHERE(HAS(selected_part)) - supplier_info = Suppliers.WHERE(nation_key <= 3)( - avg_price=AVG(supply_records.part.retail_price) - ) - selected_suppliers = supplier_info.WHERE(COUNT(selected_supply_records) > 0) - return TPCH(n=COUNT(selected_suppliers)) - - -def correl_14(): - # Correlated back reference example #14: multiple correlation. - # Count how many suppliers sell at least one part where the retail price - # is less than a 50% markup over the supply cost, and the retail price of - # the part is below the average for all parts from the supplier. Only - # considers suppliers from nations #19, and LG DRUM parts. 
- # (This is multiple correlated SEMI-joins) - selected_part = part.WHERE(container == "LG DRUM").WHERE( - (retail_price < (BACK(1).supplycost * 1.5)) & (retail_price < BACK(2).avg_price) - ) - selected_supply_records = supply_records.WHERE(HAS(selected_part)) - supplier_info = Suppliers.WHERE(nation_key == 19)( - avg_price=AVG(supply_records.part.retail_price) - ) - selected_suppliers = supplier_info.WHERE(HAS(selected_supply_records)) - return TPCH(n=COUNT(selected_suppliers)) - - -def correl_15(): - # Correlated back reference example #15: multiple correlation. - # Count how many suppliers sell at least one part where the retail price - # is less than a 50% markup over the supply cost, and the retail price of - # the part is below the 90% of the average of the retail price for all - # parts globally and below the average for all parts from the supplier. - # Only considers suppliers from nations #19, and LG DRUM parts. - # (This is multiple correlated SEMI-joins & a correlated aggregate) - selected_part = part.WHERE(container == "LG DRUM").WHERE( - (retail_price < (BACK(1).supplycost * 1.5)) - & (retail_price < BACK(2).avg_price) - & (retail_price < BACK(3).avg_price * 0.9) - ) - selected_supply_records = supply_records.WHERE(HAS(selected_part)) - supplier_info = Suppliers.WHERE(nation_key == 19)( - avg_price=AVG(supply_records.part.retail_price) - ) - selected_suppliers = supplier_info.WHERE(HAS(selected_supply_records)) - global_info = TPCH(avg_price=AVG(Parts.retail_price)) - return global_info(n=COUNT(selected_suppliers)) - - -def correl_16(): - # Correlated back reference example #16: hybrid tree order of operations. - # Count how many european suppliers have the exact same percentile value - # of account balance (relative to all other suppliers) as at least one - # customer's percentile value of account balance relative to all other - # customers. Percentile should be measured down to increments of 0.01%. 
- # (This is a correlated SEMI-joins) - selected_customers = nation(rname=region.name).customers.WHERE( - (PERCENTILE(by=(acctbal.ASC(), key.ASC()), n_buckets=10000) == BACK(2).tile) - & (BACK(1).rname == "EUROPE") - ) - supplier_info = Suppliers( - tile=PERCENTILE(by=(account_balance.ASC(), key.ASC()), n_buckets=10000) - ) - selected_suppliers = supplier_info.WHERE(HAS(selected_customers)) - return TPCH(n=COUNT(selected_suppliers)) - - -def correl_17(): - # Correlated back reference example #17: hybrid tree order of operations. - # An extremely roundabout way of getting each region_name-nation_name - # pair as a string. - # (This is a correlated singular/semi access) - region_info = region(fname=JOIN_STRINGS("-", LOWER(name), BACK(1).lname)) - nation_info = Nations(lname=LOWER(name)).WHERE(HAS(region_info)) - return nation_info(fullname=region_info.fname).ORDER_BY(fullname.ASC()) - - def hour_minute_day(): """ Return the transaction IDs with the hour, minute, and second extracted from diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index d09672f8..f8c33e59 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -12,8 +12,7 @@ bad_slice_3, bad_slice_4, ) -from simple_pydough_functions import ( - agg_partition, +from correlated_pydough_functions import ( correl_1, correl_2, correl_3, @@ -31,6 +30,9 @@ correl_15, correl_16, correl_17, +) +from simple_pydough_functions import ( + agg_partition, double_partition, exponentiation, function_sampler, From 2cfe999bfed0ecab95b194bdaf2ff54802d6546e Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 10 Feb 2025 11:43:40 -0500 Subject: [PATCH 105/112] Added documentation --- pydough/conversion/hybrid_decorrelater.py | 116 ++++++++++++++++++++-- 1 file changed, 108 insertions(+), 8 deletions(-) diff --git a/pydough/conversion/hybrid_decorrelater.py b/pydough/conversion/hybrid_decorrelater.py index 48a1329d..a4cd5f4a 100644 --- a/pydough/conversion/hybrid_decorrelater.py +++ 
b/pydough/conversion/hybrid_decorrelater.py @@ -1,5 +1,5 @@ """ -Logic for applying decorrelation to hybrid trees before relational conversion +Logic for applying de-correlation to hybrid trees before relational conversion if the correlate is not a semi/anti join. """ @@ -29,35 +29,85 @@ class Decorrelater: """ - TODO + Class that encapsulates the logic used for de-correlation of hybrid trees. """ def make_decorrelate_parent( self, hybrid: HybridTree, child_idx: int, required_steps: int ) -> HybridTree: """ - TODO + Creates a snapshot of the ancestry of the hybrid tree that contains + a correlated child, without any of its children, its descendants, or + any pipeline operators that do not need to be there. + + Args: + `hybrid`: The hybrid tree to create a snapshot of in order to aid + in the de-correlation of a correlated child. + `child_idx`: The index of the correlated child of hybrid that the + snapshot is being created to aid in the de-correlation of. + `required_steps`: The index of the last pipeline operator that + needs to be included in the snapshot in order for the child to be + derivable. + + Returns: + A snapshot of `hybrid` and its ancestry in the hybrid tree, without + without any of its children or pipeline operators that occur during + or after the derivation of the correlated child, or without any of + its descendants. """ if isinstance(hybrid.pipeline[0], HybridPartition) and child_idx == 0: + # Special case: if the correlated child is the data argument of a + # partition operation, then the parent to snapshot is actually the + # parent of the level containing the partition operation. In this + # case, all of the parent's children & pipeline operators should be + # included in the snapshot. 
assert hybrid.parent is not None return self.make_decorrelate_parent( hybrid.parent, len(hybrid.parent.children), len(hybrid.pipeline) ) + # Temporarily detach the successor of the current level, then create a + # deep copy of the current level (which will include its ancestors), + # then reattach the successor back to the original. This ensures that + # the descendants of the current level are not included when providing + # the parent to the correlated child as its new ancestor. + successor: HybridTree | None = hybrid.successor hybrid._successor = None new_hybrid: HybridTree = copy.deepcopy(hybrid) + hybrid._successor = successor + # Ensure the new parent only includes the children & pipeline operators + # that is has to. new_hybrid._children = new_hybrid._children[:child_idx] new_hybrid._pipeline = new_hybrid._pipeline[: required_steps + 1] - # breakpoint() return new_hybrid def remove_correl_refs( self, expr: HybridExpr, parent: HybridTree, child_height: int ) -> HybridExpr: """ - TODO + Recursively & destructively removes correlated references within a + hybrid expression if they point to a specific correlated ancestor + hybrid tree, and replaces them with corresponding BACK references. + + Args: + `expr`: The hybrid expression to remove correlated references from. + `parent`: The correlated ancestor hybrid tree that the correlated + references should point to when they are targeted for removal. + `child_height`: The height of the correlated child within the + hybrid tree that the correlated references is point to. This is + the number of BACK indices to shift by when replacing the + correlated reference with a BACK reference. + + Returns: + The hybrid expression with all correlated references to `parent` + replaced with corresponding BACK references. The replacement also + happens in-place. """ match expr: case HybridCorrelExpr(): + # If the correlated reference points to the parent, then + # replace it with a BACK reference. 
Otherwise, recursively + # transform its input expression in case it contains another + # correlated reference. if expr.hybrid is parent: result: HybridExpr | None = expr.expr.shift_back(child_height) assert result is not None @@ -66,10 +116,14 @@ def remove_correl_refs( expr.expr = self.remove_correl_refs(expr.expr, parent, child_height) return expr case HybridFunctionExpr(): + # For regular functions, recursively transform all of their + # arguments. for idx, arg in enumerate(expr.args): expr.args[idx] = self.remove_correl_refs(arg, parent, child_height) return expr case HybridWindowExpr(): + # For window functions, recursively transform all of their + # arguments, partition keys, and order keys. for idx, arg in enumerate(expr.args): expr.args[idx] = self.remove_correl_refs(arg, parent, child_height) for idx, arg in enumerate(expr.partition_args): @@ -88,6 +142,8 @@ def remove_correl_refs( | HybridLiteralExpr() | HybridColumnExpr() ): + # All other expression types do not require any transformation + # to de-correlate since they cannot contain correlations. return expr case _: raise NotImplementedError( @@ -102,13 +158,37 @@ def correl_ref_purge( child_height: int, ) -> None: """ - TODO + The recursive procedure to remove correlated references from the + expressions of a hybrid tree or any of its ancestors or children if + they refer to a specific correlated ancestor that is being removed. + + Args: + `level`: The current level of the hybrid tree to remove correlated + references from. + `old_parent`: The correlated ancestor hybrid tree that the correlated + references should point to when they are targeted for removal. + `new_parent`: The ancestor of `level` that removal should stop at + because it is the transposed snapshot of `old_parent`, and + therefore it & its ancestors cannot contain any more correlated + references that would be targeted for removal. 
+ `child_height`: The height of the correlated child within the + hybrid tree that the correlated references is point to. This is + the number of BACK indices to shift by when replacing the + correlated reference with a BACK """ while level is not None and level is not new_parent: + # First, recursively remove any targeted correlated references from + # the children of the current level. for child in level.children: self.correl_ref_purge( child.subtree, old_parent, new_parent, child_height ) + # Then, remove any correlated references from the pipeline + # operators of the current level. Usually this just means + # transforming the terms/orderings/unique keys of the operation, + # but specific operation types will require special casing if they + # have additional expressions stored in other field that need to be + # transformed. for operation in level.pipeline: for name, expr in operation.terms.items(): operation.terms[name] = self.remove_correl_refs( @@ -131,6 +211,8 @@ def correl_ref_purge( operation.condition = self.remove_correl_refs( operation.condition, old_parent, child_height ) + # Repeat the process on the ancestor until either loop guard + # condition is no longer True. level = level.parent def decorrelate_child( @@ -141,7 +223,13 @@ def decorrelate_child( is_aggregate: bool, ) -> None: """ - TODO + Runs the logic to de-correlate a child of a hybrid tree that contains + a correlated reference. This involves linking the child to a new parent + as its ancestor, the parent being a snapshot of the original hybrid + tree that contained the correlated child as a child. The transformed + child can now replace correlated references with BACK references that + point to terms in its newly expanded ancestry, and the original hybrid + tree cna now join onto this child using its uniqueness keys. """ # First, find the height of the child subtree & its top-most level. 
child_root: HybridTree = child.subtree @@ -168,6 +256,7 @@ def decorrelate_child( current_level = current_level.parent additional_levels += 1 child.subtree.join_keys = new_join_keys + # If aggregating, do the same with the aggregation keys. if is_aggregate: new_agg_keys: list[HybridExpr] = [] assert child.subtree.join_keys is not None @@ -183,6 +272,7 @@ def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: # hybrid tree. if hybrid.parent is not None: hybrid._parent = self.decorrelate_hybrid_tree(hybrid.parent) + hybrid._parent._successor = hybrid # Iterate across all the children and recursively decorrelate them. for child in hybrid.children: child.subtree = self.decorrelate_hybrid_tree(child.subtree) @@ -224,7 +314,17 @@ def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: def run_hybrid_decorrelation(hybrid: HybridTree) -> HybridTree: """ - TODO + Invokes the procedure to remove correlated references from a hybrid tree + before relational conversion if those correlated references are invalid + (e.g. not from a semi/anti join). + + Args: + `hybrid`: The hybrid tree to remove correlated references from. + + Returns: + The hybrid tree with all invalid correlated references removed as the + tree structure is re-written to allow them to be replaced with BACK + references. The transformation is also done in-place. 
""" decorr: Decorrelater = Decorrelater() return decorr.decorrelate_hybrid_tree(hybrid) From 5a55ac52fbb0d22a012de48a70102319450dc4ed Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 10 Feb 2025 13:37:36 -0500 Subject: [PATCH 106/112] Pulling up downstream test changes --- tests/correlated_pydough_functions.py | 63 +++++++++++++++++++++++---- tests/test_pipeline.py | 40 ++++++++++++++++- tests/test_plan_refsols/correl_13.txt | 5 +-- tests/test_plan_refsols/correl_14.txt | 5 +-- tests/test_plan_refsols/correl_15.txt | 5 +-- 5 files changed, 100 insertions(+), 18 deletions(-) diff --git a/tests/correlated_pydough_functions.py b/tests/correlated_pydough_functions.py index 1e161871..e3b73053 100644 --- a/tests/correlated_pydough_functions.py +++ b/tests/correlated_pydough_functions.py @@ -176,8 +176,8 @@ def correl_13(): # is less than a 50% markup over the supply cost. Only considers suppliers # from nations #1/#2/#3, and small parts. # (This is a correlated SEMI-joins) - selected_part = part.WHERE(STARTSWITH(container, "SM")).WHERE( - retail_price < (BACK(1).supplycost * 1.5) + selected_part = part.WHERE( + STARTSWITH(container, "SM") & (retail_price < (BACK(1).supplycost * 1.5)) ) selected_supply_records = supply_records.WHERE(HAS(selected_part)) supplier_info = Suppliers.WHERE(nation_key <= 3)( @@ -194,8 +194,10 @@ def correl_14(): # the part is below the average for all parts from the supplier. Only # considers suppliers from nations #19, and LG DRUM parts. 
# (This is multiple correlated SEMI-joins) - selected_part = part.WHERE(container == "LG DRUM").WHERE( - (retail_price < (BACK(1).supplycost * 1.5)) & (retail_price < BACK(2).avg_price) + selected_part = part.WHERE( + (container == "LG DRUM") + & (retail_price < (BACK(1).supplycost * 1.5)) + & (retail_price < BACK(2).avg_price) ) selected_supply_records = supply_records.WHERE(HAS(selected_part)) supplier_info = Suppliers.WHERE(nation_key == 19)( @@ -209,14 +211,15 @@ def correl_15(): # Correlated back reference example #15: multiple correlation. # Count how many suppliers sell at least one part where the retail price # is less than a 50% markup over the supply cost, and the retail price of - # the part is below the 90% of the average of the retail price for all + # the part is below the 85% of the average of the retail price for all # parts globally and below the average for all parts from the supplier. # Only considers suppliers from nations #19, and LG DRUM parts. # (This is multiple correlated SEMI-joins & a correlated aggregate) - selected_part = part.WHERE(container == "LG DRUM").WHERE( - (retail_price < (BACK(1).supplycost * 1.5)) + selected_part = part.WHERE( + (container == "LG DRUM") + & (retail_price < (BACK(1).supplycost * 1.5)) & (retail_price < BACK(2).avg_price) - & (retail_price < BACK(3).avg_price * 0.9) + & (retail_price < BACK(3).avg_price * 0.85) ) selected_supply_records = supply_records.WHERE(HAS(selected_part)) supplier_info = Suppliers.WHERE(nation_key == 19)( @@ -253,3 +256,47 @@ def correl_17(): region_info = region(fname=JOIN_STRINGS("-", LOWER(name), BACK(1).lname)) nation_info = Nations(lname=LOWER(name)).WHERE(HAS(region_info)) return nation_info(fullname=region_info.fname).ORDER_BY(fullname.ASC()) + + +def correl_18(): + # Correlated back reference example #18: partition decorrelation edge case. 
+ # Count how many orders corresponded to at least half of the total price + # spent by the ordering customer in a single day, but only if the customer + # ordered multiple orders in on that day. Only considers orders made in + # 1993. + # (This is a correlated aggregation access) + cust_date_groups = PARTITION( + Orders.WHERE(YEAR(order_date) == 1993), + name="o", + by=(customer_key, order_date), + ) + selected_groups = cust_date_groups.WHERE(COUNT(o) > 1)( + total_price=SUM(o.total_price), + )(n_above_avg=COUNT(o.WHERE(total_price >= 0.5 * BACK(1).total_price))) + return TPCH(n=SUM(selected_groups.n_above_avg)) + + +def correl_19(): + # Correlated back reference example #19: cardinality edge case. + # For every supplier, count how many customers in the same nation have a + # higher account balance than that supplier. Pick the 5 suppliers with the + # largest such count. + # (This is a correlated aggregation access) + super_cust = customers.WHERE(acctbal > BACK(2).account_balance) + return Suppliers.nation(name=BACK(1).name, n_super_cust=COUNT(super_cust)).TOP_K( + 5, n_super_cust.DESC() + ) + + +def correl_20(): + # Correlated back reference example #20: multiple ancestor uniqueness keys. + # Count the instances where a nation's suppliers shipped a part to a + # customer in the same nation, only counting instances where the order was + # made in June of 1998. 
+ # (This is a correlated singular/semi access) + is_domestic = nation(domestic=name == BACK(5).name).domestic + selected_orders = Nations.customers.orders.WHERE( + (YEAR(order_date) == 1998) & (MONTH(order_date) == 6) + ) + instances = selected_orders.lines.supplier.WHERE(is_domestic) + return TPCH(n=COUNT(instances)) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index f8c33e59..3324d3f6 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -30,6 +30,9 @@ correl_15, correl_16, correl_17, + correl_18, + correl_19, + correl_20, ) from simple_pydough_functions import ( agg_partition, @@ -895,7 +898,7 @@ ( correl_15, "correl_15", - lambda: pd.DataFrame({"n": [212]}), + lambda: pd.DataFrame({"n": [61]}), ), id="correl_15", ), @@ -945,6 +948,41 @@ ), id="correl_17", ), + pytest.param( + ( + correl_18, + "correl_18", + lambda: pd.DataFrame({"n": [697]}), + ), + id="correl_18", + ), + pytest.param( + ( + correl_19, + "correl_19", + lambda: pd.DataFrame( + { + "name": [ + "Supplier#000003934", + "Supplier#000003887", + "Supplier#000002628", + "Supplier#000008722", + "Supplier#000007971", + ], + "n_super_cust": [6160, 6142, 6129, 6127, 6117], + } + ), + ), + id="correl_19", + ), + pytest.param( + ( + correl_20, + "correl_20", + lambda: pd.DataFrame({"n": [3002]}), + ), + id="correl_20", + ), ], ) def pydough_pipeline_test_data( diff --git a/tests/test_plan_refsols/correl_13.txt b/tests/test_plan_refsols/correl_13.txt index 1c838136..bc779fbd 100644 --- a/tests/test_plan_refsols/correl_13.txt +++ b/tests/test_plan_refsols/correl_13.txt @@ -14,6 +14,5 @@ ROOT(columns=[('n', n)], orderings=[]) FILTER(condition=True:bool, columns={'supplier_key': supplier_key}) JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr2') SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) - FILTER(condition=retail_price < corr2.supplycost * 
1.5:float64, columns={'key': key}) - FILTER(condition=STARTSWITH(container, 'SM':string), columns={'key': key, 'retail_price': retail_price}) - SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'retail_price': p_retailprice}) + FILTER(condition=STARTSWITH(container, 'SM':string) & retail_price < corr2.supplycost * 1.5:float64, columns={'key': key}) + SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'retail_price': p_retailprice}) diff --git a/tests/test_plan_refsols/correl_14.txt b/tests/test_plan_refsols/correl_14.txt index 544545e3..909904c1 100644 --- a/tests/test_plan_refsols/correl_14.txt +++ b/tests/test_plan_refsols/correl_14.txt @@ -14,6 +14,5 @@ ROOT(columns=[('n', n)], orderings=[]) FILTER(condition=True:bool, columns={'supplier_key': supplier_key}) JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr2') SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) - FILTER(condition=retail_price < corr2.supplycost * 1.5:float64 & retail_price < corr3.avg_price, columns={'key': key}) - FILTER(condition=container == 'LG DRUM':string, columns={'key': key, 'retail_price': retail_price}) - SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'retail_price': p_retailprice}) + FILTER(condition=container == 'LG DRUM':string & retail_price < corr2.supplycost * 1.5:float64 & retail_price < corr3.avg_price, columns={'key': key}) + SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'retail_price': p_retailprice}) diff --git a/tests/test_plan_refsols/correl_15.txt b/tests/test_plan_refsols/correl_15.txt index 82d8f929..0dc3f6a6 100644 --- a/tests/test_plan_refsols/correl_15.txt +++ b/tests/test_plan_refsols/correl_15.txt @@ -18,6 +18,5 @@ ROOT(columns=[('n', n)], orderings=[]) FILTER(condition=True:bool, columns={'supplier_key': supplier_key}) 
JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr2') SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost}) - FILTER(condition=retail_price < corr2.supplycost * 1.5:float64 & retail_price < corr3.avg_price & retail_price < corr4.avg_price * 0.9:float64, columns={'key': key}) - FILTER(condition=container == 'LG DRUM':string, columns={'key': key, 'retail_price': retail_price}) - SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'retail_price': p_retailprice}) + FILTER(condition=container == 'LG DRUM':string & retail_price < corr2.supplycost * 1.5:float64 & retail_price < corr3.avg_price & retail_price < corr4.avg_price * 0.85:float64, columns={'key': key}) + SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'retail_price': p_retailprice}) From ad61d27172e7bd40bad52d8d6be918341f18185f Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 10 Feb 2025 13:37:41 -0500 Subject: [PATCH 107/112] Pulling up downstream test changes --- tests/test_plan_refsols/correl_18.txt | 14 ++++++++++++++ tests/test_plan_refsols/correl_19.txt | 12 ++++++++++++ tests/test_plan_refsols/correl_20.txt | 17 +++++++++++++++++ 3 files changed, 43 insertions(+) create mode 100644 tests/test_plan_refsols/correl_18.txt create mode 100644 tests/test_plan_refsols/correl_19.txt create mode 100644 tests/test_plan_refsols/correl_20.txt diff --git a/tests/test_plan_refsols/correl_18.txt b/tests/test_plan_refsols/correl_18.txt new file mode 100644 index 00000000..ab36ffd4 --- /dev/null +++ b/tests/test_plan_refsols/correl_18.txt @@ -0,0 +1,14 @@ +ROOT(columns=[('n', n)], orderings=[]) + PROJECT(columns={'n': DEFAULT_TO(agg_0, 0:int64)}) + AGGREGATE(keys={}, aggregations={'agg_0': SUM(n_above_avg)}) + PROJECT(columns={'n_above_avg': DEFAULT_TO(agg_2, 0:int64)}) + JOIN(conditions=[t0.customer_key == t1.customer_key & t0.order_date 
== t1.order_date], types=['left'], columns={'agg_2': t1.agg_2}, correl_name='corr1') + PROJECT(columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': DEFAULT_TO(agg_1, 0:int64)}) + FILTER(condition=DEFAULT_TO(agg_0, 0:int64) > 1:int64, columns={'agg_1': agg_1, 'customer_key': customer_key, 'order_date': order_date}) + AGGREGATE(keys={'customer_key': customer_key, 'order_date': order_date}, aggregations={'agg_0': COUNT(), 'agg_1': SUM(total_price)}) + FILTER(condition=YEAR(order_date) == 1993:int64, columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': total_price}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'order_date': o_orderdate, 'total_price': o_totalprice}) + AGGREGATE(keys={'customer_key': customer_key, 'order_date': order_date}, aggregations={'agg_2': COUNT()}) + FILTER(condition=total_price >= 0.5:float64 * corr1.total_price, columns={'customer_key': customer_key, 'order_date': order_date}) + FILTER(condition=YEAR(order_date) == 1993:int64, columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': total_price}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'order_date': o_orderdate, 'total_price': o_totalprice}) diff --git a/tests/test_plan_refsols/correl_19.txt b/tests/test_plan_refsols/correl_19.txt new file mode 100644 index 00000000..a273084b --- /dev/null +++ b/tests/test_plan_refsols/correl_19.txt @@ -0,0 +1,12 @@ +ROOT(columns=[('name', name_7), ('n_super_cust', n_super_cust)], orderings=[(ordering_1):desc_last]) + PROJECT(columns={'n_super_cust': n_super_cust, 'name_7': name_3, 'ordering_1': ordering_1}) + LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'n_super_cust': n_super_cust, 'name_3': name_3, 'ordering_1': ordering_1}, orderings=[(ordering_1):desc_last]) + PROJECT(columns={'n_super_cust': n_super_cust, 'name_3': name_3, 'ordering_1': n_super_cust}) + PROJECT(columns={'n_super_cust': DEFAULT_TO(agg_0, 0:int64), 'name_3': 
name_3}) + JOIN(conditions=[t0.key_2 == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'name_3': t0.name_3}, correl_name='corr4') + JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'account_balance': t0.account_balance, 'key_2': t1.key, 'name_3': t1.name}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'nation_key': s_nationkey}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) + AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': COUNT()}) + FILTER(condition=acctbal > corr4.account_balance, columns={'nation_key': nation_key}) + SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'nation_key': c_nationkey}) diff --git a/tests/test_plan_refsols/correl_20.txt b/tests/test_plan_refsols/correl_20.txt new file mode 100644 index 00000000..c1d388c4 --- /dev/null +++ b/tests/test_plan_refsols/correl_20.txt @@ -0,0 +1,17 @@ +ROOT(columns=[('n', n)], orderings=[]) + PROJECT(columns={'n': agg_0}) + AGGREGATE(keys={}, aggregations={'agg_0': COUNT()}) + FILTER(condition=domestic, columns={'account_balance': account_balance}) + JOIN(conditions=[t0.nation_key_11 == t1.key], types=['left'], columns={'account_balance': t0.account_balance, 'domestic': t1.domestic}, correl_name='corr13') + JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'account_balance': t1.account_balance, 'name': t0.name, 'nation_key_11': t1.nation_key}) + JOIN(conditions=[t0.key_5 == t1.order_key], types=['inner'], columns={'name': t0.name, 'supplier_key': t1.supplier_key}) + FILTER(condition=YEAR(order_date) == 1998:int64 & MONTH(order_date) == 6:int64, columns={'key_5': key_5, 'name': name}) + JOIN(conditions=[t0.key_2 == t1.customer_key], types=['inner'], columns={'key_5': t1.key, 'name': t0.name, 'order_date': t1.order_date}) + JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': 
n_name}) + SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey}) + SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate}) + SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'supplier_key': l_suppkey}) + SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey}) + PROJECT(columns={'domestic': name == corr13.name, 'key': key}) + SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name}) From 7ef04760cf813324e1fbb2718920819051dc4d7b Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 10 Feb 2025 15:04:31 -0500 Subject: [PATCH 108/112] Cleanup --- pydough/conversion/relational_converter.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index 9f6b38b3..4c69487e 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -1001,11 +1001,7 @@ def convert_ast_to_relational( # the relational conversion procedure. The first element in the returned # list is the final rel node. 
hybrid: HybridTree = HybridTranslator(configs).make_hybrid_tree(node, None) - # print() - # print(hybrid) - # print("DECOR") run_hybrid_decorrelation(hybrid) - # print(hybrid) renamings: dict[str, str] = hybrid.pipeline[-1].renamings output: TranslationOutput = translator.rel_translation( None, hybrid, len(hybrid.pipeline) - 1 From c83b67e93548a57be187d42535c8b1138a6f304b Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 10 Feb 2025 17:08:23 -0500 Subject: [PATCH 109/112] Removing dead code --- pydough/conversion/hybrid_tree.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/pydough/conversion/hybrid_tree.py b/pydough/conversion/hybrid_tree.py index bc0f76cd..fae2842c 100644 --- a/pydough/conversion/hybrid_tree.py +++ b/pydough/conversion/hybrid_tree.py @@ -1608,24 +1608,6 @@ def make_hybrid_correl_expr( back_expr, collection, steps_taken_so_far ).expr self.stack.pop() - # self.stack.append(parent_tree) - # Then, postprocess the output to account for the fact that a - # BACK level got skipped due to the change in subtree. - # match parent_result.expr: - # case HybridRefExpr(): - # parent_result = HybridBackRefExpr( - # parent_result.expr.name, 1, parent_result.typ - # ) - # case HybridBackRefExpr(): - # parent_result = HybridBackRefExpr( - # parent_result.expr.name, - # parent_result.expr.back_idx + 1, - # parent_result.typ, - # ) - # case _: - # raise ValueError( - # f"Malformed expression for correlated reference: {parent_result}" - # ) elif remaining_steps_back == 0: # If there are no more steps back to be made, then the correlated # reference is to a reference from the current context. 
From fdbf90923e2872946df0457f879e8aaef1086092 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Sun, 16 Feb 2025 21:22:28 -0500 Subject: [PATCH 110/112] Revisions [RUN CI] --- pydough/conversion/relational_converter.py | 4 +++- .../relational/relational_expressions/abstract_expression.py | 3 +++ .../relational/relational_expressions/literal_expression.py | 4 ---- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index 637225ec..494d7fbd 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -251,7 +251,9 @@ def translate_expression( f"Unsupported expression to reference in a correlated reference: {ancestor_expr}" ) case _: - raise NotImplementedError(expr.__class__.__name__) + raise NotImplementedError( + f"TODO: support relational conversion on {expr.__class__.__name__}" + ) def join_outputs( self, diff --git a/pydough/relational/relational_expressions/abstract_expression.py b/pydough/relational/relational_expressions/abstract_expression.py index 7236ad0b..ac3b6d33 100644 --- a/pydough/relational/relational_expressions/abstract_expression.py +++ b/pydough/relational/relational_expressions/abstract_expression.py @@ -81,6 +81,9 @@ def to_string(self, compact: bool = False) -> str: def __repr__(self) -> str: return self.to_string() + def __hash__(self) -> int: + return hash(self.to_string()) + @abstractmethod def accept(self, visitor: RelationalExpressionVisitor) -> None: """ diff --git a/pydough/relational/relational_expressions/literal_expression.py b/pydough/relational/relational_expressions/literal_expression.py index 3b0eb803..1edc9981 100644 --- a/pydough/relational/relational_expressions/literal_expression.py +++ b/pydough/relational/relational_expressions/literal_expression.py @@ -29,10 +29,6 @@ def __init__(self, value: Any, data_type: PyDoughType): super().__init__(data_type) self._value: Any = value - def 
__hash__(self) -> int: - # Note: This will break if the value isn't hashable. - return hash((self.value, self.data_type)) - @property def value(self) -> object: """ From a5531e3ca22e9c5b08d152618722af52d2e699a1 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 17 Feb 2025 15:26:02 -0500 Subject: [PATCH 111/112] Revisions --- pydough/conversion/hybrid_tree.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pydough/conversion/hybrid_tree.py b/pydough/conversion/hybrid_tree.py index e10a6d94..79a03062 100644 --- a/pydough/conversion/hybrid_tree.py +++ b/pydough/conversion/hybrid_tree.py @@ -1544,7 +1544,7 @@ def make_agg_call( # Identify the child connection that the aggregation call is pushed # into. child_idx: int = child_indices.pop() - child_connection = hybrid.children[child_idx] + child_connection: HybridConnection = hybrid.children[child_idx] # Generate a unique name for the agg call to push into the child # connection. agg_name: str = self.get_agg_name(child_connection) @@ -1571,14 +1571,14 @@ def make_hybrid_correl_expr( before we have run out of BACK levels to step up out of. `steps_taken_so_far`: the number of steps already taken to step up from the BACK node. This is needed so we know how many steps - still need to be taken upward once we have stepped out of hte child + still need to be taken upward once we have stepped out of the child subtree back into the parent subtree. """ if len(self.stack) == 0: raise ValueError("Back reference steps too far back") # Identify the parent subtree that the BACK reference is stepping back # into, out of the child. 
- parent_tree = self.stack.pop() + parent_tree: HybridTree = self.stack.pop() remaining_steps_back: int = back_expr.back_levels - steps_taken_so_far - 1 parent_result: HybridExpr # Special case: stepping out of the data argument of PARTITION back From 2105bc8c813f2cf53dbdf85eaa05c4ee853d75f8 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Tue, 18 Feb 2025 15:37:51 -0500 Subject: [PATCH 112/112] Fixing typos --- pydough/conversion/hybrid_decorrelater.py | 2 +- pydough/conversion/relational_converter.py | 2 +- pydough/types/struct_type.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pydough/conversion/hybrid_decorrelater.py b/pydough/conversion/hybrid_decorrelater.py index a4cd5f4a..72f6aad1 100644 --- a/pydough/conversion/hybrid_decorrelater.py +++ b/pydough/conversion/hybrid_decorrelater.py @@ -229,7 +229,7 @@ def decorrelate_child( tree that contained the correlated child as a child. The transformed child can now replace correlated references with BACK references that point to terms in its newly expanded ancestry, and the original hybrid - tree cna now join onto this child using its uniqueness keys. + tree can now join onto this child using its uniqueness keys. """ # First, find the height of the child subtree & its top-most level. child_root: HybridTree = child.subtree diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index fdbbe9c8..77c10c86 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -649,7 +649,7 @@ def translate_partition( Returns: The TranslationOutput payload containing access to the aggregated - child corresponding tot he partition data. + child corresponding to the partition data. 
""" expressions: dict[HybridExpr, ColumnReference] = {} # Account for the fact that the PARTITION is stepping down a level, diff --git a/pydough/types/struct_type.py b/pydough/types/struct_type.py index 94b3222a..7b3ef680 100644 --- a/pydough/types/struct_type.py +++ b/pydough/types/struct_type.py @@ -109,7 +109,7 @@ def parse_struct_body( except PyDoughTypeException: pass - # Otherwise, iterate across all commas int he right hand side + # Otherwise, iterate across all commas in the right hand side # that are candidate splitting locations between a PyDough # type and a suffix that is a valid list of fields. if field_type is None: