diff --git a/.gitignore b/.gitignore
index 85cb80a1..991ab3df 100644
--- a/.gitignore
+++ b/.gitignore
@@ -132,6 +132,7 @@ venv/
ENV/
env.bak/
venv.bak/
+.vscode
# Spyder project settings
.spyderproject
diff --git a/README.md b/README.md
index 059d2444..e62794af 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,7 @@ Suppose I want to know for every person their name & the total income they've ma
The following PyDough snippet solves this problem:
```py
-result = People(
+result = People.CALCULATE(
name,
net_income = SUM(jobs.income_earned) - SUM(schools.tuition_paid)
)
diff --git a/demos/notebooks/1_introduction.ipynb b/demos/notebooks/1_introduction.ipynb
index 5f08c7de..76086cfe 100644
--- a/demos/notebooks/1_introduction.ipynb
+++ b/demos/notebooks/1_introduction.ipynb
@@ -101,7 +101,7 @@
"source": [
"%%pydough\n",
"\n",
- "nations(key, name)"
+ "nations.CALCULATE(nkey=key, nname=name)"
]
},
{
@@ -121,7 +121,7 @@
"source": [
"%%pydough\n",
"\n",
- "nation_keys = nations(key, name)"
+ "nation_keys = nations.CALCULATE(nkey=key, nname=name)"
]
},
{
@@ -149,7 +149,7 @@
"source": [
"%%pydough\n",
"\n",
- "lowest_customer_nations = nation_keys(key, name, cust_count=COUNT(customers)).TOP_K(2, by=cust_count.ASC())\n",
+ "lowest_customer_nations = nation_keys.CALCULATE(nkey, nname, cust_count=COUNT(customers)).TOP_K(2, by=cust_count.ASC())\n",
"lowest_customer_nations"
]
},
@@ -236,7 +236,9 @@
"id": "f52dfcfe-6e90-44b8-b9c4-7dc08a5b28ca",
"metadata": {},
"source": [
- "Finally, while building a statement from smaller components is best practice in Pydough, you can always evaluate the entire expression all at once within a PyDough cell, such as this example that loads the all asian nations in the dataset."
+    "Finally, while building a statement from smaller components is best practice in Pydough, you can always evaluate the entire expression all at once within a PyDough cell, such as this example that loads all the Asian nations in the dataset.\n",
+ "\n",
+ "We can use the optional `columns` argument to `to_sql` or `to_df` to specify which columns to include, or even what they should be renamed as."
]
},
{
@@ -248,7 +250,9 @@
"source": [
"%%pydough\n",
"\n",
- "pydough.to_df(nations.WHERE(region.name == \"ASIA\"))"
+ "asian_countries = nations.WHERE(region.name == \"ASIA\")\n",
+ "print(pydough.to_df(asian_countries, columns=[\"name\", \"key\"]))\n",
+ "pydough.to_df(asian_countries, columns={\"nation_name\": \"name\", \"id\": \"key\"})"
]
},
{
@@ -290,7 +294,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.12.7"
+ "version": "3.12.6"
}
},
"nbformat": 4,
diff --git a/demos/notebooks/2_pydough_operations.ipynb b/demos/notebooks/2_pydough_operations.ipynb
index 84c21ebc..6b39856e 100644
--- a/demos/notebooks/2_pydough_operations.ipynb
+++ b/demos/notebooks/2_pydough_operations.ipynb
@@ -84,9 +84,9 @@
"id": "a25a2965-4f88-4626-b326-caf931fdba9c",
"metadata": {},
"source": [
- "## Calc\n",
+ "## Calculate\n",
"\n",
- "The next important operation is the `CALC` operation, which is used by \"calling\" a collection as a function."
+    "The next important operation is the `CALCULATE` operation, which takes in a variable number of positional and/or keyword arguments."
]
},
{
@@ -98,7 +98,7 @@
"source": [
"%%pydough\n",
"\n",
- "pydough.to_sql(nations(key))"
+ "print(pydough.to_sql(nations.CALCULATE(key, nation_name=name)))"
]
},
{
@@ -106,10 +106,13 @@
"id": "f89da4ca-5493-493f-bfe3-41d8a5f5d2a1",
"metadata": {},
"source": [
- "Calc has a few purposes:\n",
+ "Calculate has a few purposes:\n",
"* Select which entries you want in the output.\n",
"* Define new fields by calling functions.\n",
- "* Allow operations to be evaluated for each entry in the outermost collection's \"context\"."
+ "* Allow operations to be evaluated for each entry in the outermost collection's \"context\".\n",
+ "* Define aliases for terms that get down-streamed to descendants ([see here](#down-streaming)).\n",
+ "\n",
+ "The terms of the last `CALCULATE` in the PyDough logic are the terms that are included in the result (unless the `columns` argument of `to_sql` or `to_df` is used)."
]
},
{
@@ -121,7 +124,7 @@
"source": [
"%%pydough\n",
"\n",
- "pydough.to_sql(nations(key + 1))"
+ "print(pydough.to_sql(nations.CALCULATE(adjusted_key = key + 1)))"
]
},
{
@@ -129,7 +132,7 @@
"id": "24031aa2-1df7-441d-b487-aa093b852504",
"metadata": {},
"source": [
- "Here the context is the \"nations\" at the root of the graph. This means that for each entry within nations, we compute the result. This has important implications for when we get to more complex expressions. For example, if we want to know how many nations we have stored in each region, we can do via CALC."
+    "Here the context is the \"nations\" at the root of the graph. This means that for each entry within nations, we compute the result. This has important implications for when we get to more complex expressions. For example, if we want to know how many nations we have stored in each region, we can do so via `CALCULATE`."
]
},
{
@@ -141,7 +144,7 @@
"source": [
"%%pydough\n",
"\n",
- "pydough.to_df(regions(name, nation_count=COUNT(nations)))"
+ "pydough.to_df(regions.CALCULATE(name, nation_count=COUNT(nations)))"
]
},
{
@@ -151,7 +154,7 @@
"source": [
"Internally, this process evaluates `COUNT(nations)` grouped on each region and then joining the result with the original `regions` table. Importantly, this outputs a \"scalar\" value for each region.\n",
"\n",
- "This shows a very important restriction of CALC, each final entry in a calc expression must be scalar with respect to a current context. For example, the expression `regions(region_name=name, nation_name=nations.name)` is not legal because region and nation is a one to many relationship, so there is not a single nation name for each region. \n",
+    "This shows a very important restriction of `CALCULATE`: each final entry in the operation must be scalar with respect to a current context. For example, the expression `regions.CALCULATE(region_name=name, nation_name=nations.name)` is not legal because the relationship between regions and nations is one-to-many, so there is not a single nation name for each region. \n",
"\n",
"**The cell below will result in an error because it violates this restriction.**"
]
@@ -165,7 +168,7 @@
"source": [
"%%pydough\n",
"\n",
- "pydough.to_df(regions(region_name=name, nation_name=nations.name))"
+ "pydough.to_df(regions.CALCULATE(region_name=name, nation_name=nations.name))"
]
},
{
@@ -185,7 +188,7 @@
"source": [
"%%pydough\n",
"\n",
- "pydough.to_df(nations(nation_name=name, region_name=region.name))"
+ "pydough.to_df(nations.CALCULATE(nation_name=name, region_name=region.name))"
]
},
{
@@ -216,29 +219,39 @@
"%%pydough\n",
"\n",
"# Numeric operations\n",
- "print(pydough.to_sql(nations(key + 1, key - 1, key * 1, key / 1)))\n",
+ "print(\"Q1\")\n",
+ "print(pydough.to_sql(nations.CALCULATE(key + 1, key - 1, key * 1, key / 1)))\n",
"\n",
"# Comparison operators\n",
- "print(pydough.to_sql(nations(key == 0, key < 0, key != 0, key >= 5)))\n",
+ "print(\"\\nQ2\")\n",
+ "print(pydough.to_sql(nations.CALCULATE(key == 0, key < 0, key != 0, key >= 5)))\n",
"\n",
"# String Operations\n",
- "print(pydough.to_sql(nations(LENGTH(name), UPPER(name), LOWER(name), STARTSWITH(name, \"A\"))))\n",
+ "print(\"\\nQ3\")\n",
+ "print(pydough.to_sql(nations.CALCULATE(LENGTH(name), UPPER(name), LOWER(name), STARTSWITH(name, \"A\"))))\n",
"\n",
"# Boolean operations\n",
- "print(pydough.to_sql(nations((key != 1) & (LENGTH(name) > 5)))) # Boolean AND\n",
- "print(pydough.to_sql(nations((key != 1) | (LENGTH(name) > 5)))) # Boolean OR\n",
- "print(pydough.to_sql(nations(~(LENGTH(name) > 5)))) # Boolean NOT \n",
- "print(pydough.to_sql(nations(ISIN(name, (\"KENYA\", \"JAPAN\"))))) # In\n",
+ "print(\"\\nQ4\")\n",
+ "print(pydough.to_sql(nations.CALCULATE((key != 1) & (LENGTH(name) > 5)))) # Boolean AND\n",
+ "print(\"\\nQ5\")\n",
+ "print(pydough.to_sql(nations.CALCULATE((key != 1) | (LENGTH(name) > 5)))) # Boolean OR\n",
+ "print(\"\\nQ6\")\n",
+ "print(pydough.to_sql(nations.CALCULATE(~(LENGTH(name) > 5)))) # Boolean NOT \n",
+ "print(\"\\nQ7\") \n",
+ "print(pydough.to_sql(nations.CALCULATE(ISIN(name, (\"KENYA\", \"JAPAN\"))))) # In\n",
"\n",
"# Datetime Operations\n",
"# Note: Since this is based on SQL lite the underlying date is a bit strange.\n",
- "print(pydough.to_sql(lines(YEAR(ship_date), MONTH(ship_date), DAY(ship_date),HOUR(ship_date),MINUTE(ship_date),SECOND(ship_date))))\n",
+ "print(\"\\nQ8\")\n",
+ "print(pydough.to_sql(lines.CALCULATE(YEAR(ship_date), MONTH(ship_date), DAY(ship_date),HOUR(ship_date),MINUTE(ship_date),SECOND(ship_date))))\n",
"\n",
"# Aggregation operations\n",
- "print(pydough.to_sql(TPCH(NDISTINCT(nations.comment), SUM(nations.key))))\n",
+ "print(\"\\nQ9\")\n",
+ "print(pydough.to_sql(TPCH.CALCULATE(NDISTINCT(nations.comment), SUM(nations.key))))\n",
"# Count can be used on a column for non-null entries or a collection\n",
"# for total entries.\n",
- "print(pydough.to_sql(TPCH(COUNT(nations), COUNT(nations.comment))))"
+ "print(\"\\nQ10\")\n",
+ "print(pydough.to_sql(TPCH.CALCULATE(COUNT(nations), COUNT(nations.comment))))"
]
},
{
@@ -260,9 +273,11 @@
"id": "b70993e8-3cd2-4c45-87e3-8e68f67b92a0",
"metadata": {},
"source": [
- "### BACK\n",
+ "### Down-Streaming\n",
+ "\n",
+ "Sometimes you need to load a value from a previous context to use at a later step in a PyDough statement. Any expression from an ancestor context that is placed in a `CALCULATE` is automatically made available to all descendants of that context. However, an error will occur if the name of the term defined in the ancestor collides with a name of a term or property of a descendant context, since PyDough will not know which one to use.\n",
"\n",
- "Sometimes you need to load a value from a previous context to use at a later step in a PyDough statement. That can be done using the `BACK` operation. This step moves back `k` steps to find the name you are searching for. This is useful to avoid repeating computation."
+ "Notice how in the example below, `region_name` is defined in a `CALCULATE` within the context of `regions`, so the calculate within the context of `nations` also has access to `region_name` (interpreted as \"the name of the region that this nation belongs to\")."
]
},
{
@@ -274,7 +289,7 @@
"source": [
"%%pydough\n",
"\n",
- "pydough.to_df(regions.nations(region_name=BACK(1).name, nation_name=name))"
+ "pydough.to_df(regions.CALCULATE(region_name=name).nations.CALCULATE(region_name, nation_name=name))"
]
},
{
@@ -282,7 +297,7 @@
"id": "6040a7c5-fc82-4e33-8b2b-a1b3ef394f71",
"metadata": {},
"source": [
- "Here is a more complex example showing intermediate values. Here we will first compute `total_value` and then reuse it via `BACK`."
+ "Here is a more complex example showing intermediate values. Here we will first compute `total_value` and then reuse it downstream."
]
},
{
@@ -294,7 +309,7 @@
"source": [
"%%pydough\n",
"\n",
- "nations_value = nations(name, total_value=SUM(suppliers.account_balance))\n",
+ "nations_value = nations.CALCULATE(nation_name=name, total_value=SUM(suppliers.account_balance))\n",
"pydough.to_df(nations_value)"
]
},
@@ -306,12 +321,12 @@
"outputs": [],
"source": [
"%%pydough\n",
- "suppliers_value = nations_value.suppliers(\n",
+ "suppliers_value = nations_value.suppliers.CALCULATE(\n",
" key,\n",
" name,\n",
- " nation_name=BACK(1).name,\n",
+ " nation_name,\n",
" account_balance=account_balance,\n",
- " percentage_of_national_value=100 * account_balance / BACK(1).total_value\n",
+ " percentage_of_national_value=100 * account_balance / total_value\n",
")\n",
"top_suppliers = suppliers_value.TOP_K(20, by=percentage_of_national_value.DESC())\n",
"pydough.to_df(top_suppliers)"
@@ -324,7 +339,7 @@
"source": [
"## WHERE\n",
"\n",
- "The WHERE operation by be used to filter unwanted entries in a context. For example, we can filter `nations` to only consider the `AMERICA` and `EUROPE` regions. A WHERE's context functions similarly to a calc except that it cannot be used to assign new properties. "
+    "The `WHERE` operation can be used to filter unwanted entries in a context. For example, we can filter `nations` to only consider the `AMERICA` and `EUROPE` regions. A WHERE's context functions similarly to a `CALCULATE` except that it cannot be used to assign new properties; it only contains a single positional argument: the predicate to filter on. "
]
},
{
@@ -367,7 +382,7 @@
"metadata": {},
"source": [
"The `by` argument requirements are:\n",
- "* Anything that can be an expression used in a `CALC` or a `WHERE` can be used a component of a `by`.\n",
+    "* Anything that can be an expression used in a `CALCULATE` or a `WHERE` can be used as a component of a `by`.\n",
"* The value in the `by` must end with either `.ASC()` or `.DESC()`\n",
"\n",
"You can also provide a tuple to by if you need to break ties. Consider this alternatives that instead selects the 20 parts with the largest size, starting with the smallest part id."
@@ -428,10 +443,10 @@
"source": [
"%%pydough\n",
"\n",
- "updated_nations = nations(key, name_length=LENGTH(name))\n",
+ "updated_nations = nations.CALCULATE(key, name_length=LENGTH(name))\n",
"grouped_nations = PARTITION(\n",
" updated_nations, name=\"n\", by=(name_length)\n",
- ")(\n",
+ ").CALCULATE(\n",
" name_length,\n",
" nation_count=COUNT(n.key)\n",
")\n",
@@ -446,7 +461,7 @@
"A couple important usage details:\n",
"* The `name` argument specifies the name of the subcollection access from the partitions to the original unpartitioned data.\n",
"* `keys` can be either be a single expression or a tuple of them, but it can only be references to expressions that already exist in the context of the data (e.g. `name`, not `LOWER(name)` or `region.name`)\n",
- "* `BACK` should be used to step back into the partition child without retaining the partitioning. An example is shown below where we select brass european parts but only with the minimum supply cost."
+    "* Terms defined from the context of the `PARTITION` can be down-streamed to its descendants. An example is shown below where we select brass parts of size 15, but only the ones whose retail price is below the average of all such parts."
]
},
{
@@ -459,8 +474,8 @@
"%%pydough\n",
"\n",
"selected_parts = parts.WHERE(ENDSWITH(part_type, \"BRASS\") & (size == 15))\n",
- "part_types = PARTITION(selected_parts, name=\"p\", by=part_type)(avg_price=AVG(p.retail_price))\n",
- "output = part_types.p.WHERE(retail_price < BACK(1).avg_price)\n",
+ "part_types = PARTITION(selected_parts, name=\"p\", by=part_type).CALCULATE(avg_price=AVG(p.retail_price))\n",
+ "output = part_types.p.WHERE(retail_price < avg_price)\n",
"pydough.to_df(output)"
]
},
@@ -532,7 +547,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.12.7"
+ "version": "3.12.6"
}
},
"nbformat": 4,
diff --git a/demos/notebooks/3_exploration.ipynb b/demos/notebooks/3_exploration.ipynb
index 753d91c0..590d8477 100644
--- a/demos/notebooks/3_exploration.ipynb
+++ b/demos/notebooks/3_exploration.ipynb
@@ -308,7 +308,7 @@
"\n",
"orders_1995 = customers.orders.WHERE(YEAR(order_date) == 1995)\n",
"\n",
- "asian_countries_info = asian_countries(country_name=LOWER(name), total_orders=COUNT(orders_1995))\n",
+ "asian_countries_info = asian_countries.CALCULATE(country_name=LOWER(name), total_orders=COUNT(orders_1995))\n",
"\n",
"top_asian_countries = asian_countries_info.TOP_K(3, by=total_orders.DESC())\n",
"\n",
@@ -408,7 +408,7 @@
"source": [
"Here, we learn that `customers.orders` invokes a child of the current context (`nations.WHERE(region.name == 'ASIA')`) by accessing the `customers` subcollection, then accessing its `orders` collection, then filtering it on the conedition `YEAR(order_date) == 1995`. \n",
"\n",
- "We also know that this resulting child is plural with regards to the context, meaning that `asian_countries(asian_countries.order_date)` would be illegal, but `asian_countries(MAX(asian_countries.order_date))` is legal.\n",
+ "We also know that this resulting child is plural with regards to the context, meaning that `asian_countries.CALCULATE(asian_countries.order_date)` would be illegal, but `asian_countries.CALCULATE(MAX(asian_countries.order_date))` is legal.\n",
"\n",
"More combinations of `pydough.explain` and `pydough.explain_terms` can be done to learn more about what each of these components does."
]
@@ -438,7 +438,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.12.7"
+ "version": "3.12.6"
}
},
"nbformat": 4,
diff --git a/demos/notebooks/4_tpch.ipynb b/demos/notebooks/4_tpch.ipynb
index eff56358..051b5592 100644
--- a/demos/notebooks/4_tpch.ipynb
+++ b/demos/notebooks/4_tpch.ipynb
@@ -101,7 +101,7 @@
"charge = disc_price * (1 + l.tax)\n",
"selected_lines = lines.WHERE((ship_date <= datetime.date(1998, 9, 2)))\n",
"partitioned_lines = PARTITION(selected_lines, name=\"l\", by=(return_flag, status))\n",
- "output = partitioned_lines(\n",
+ "output = partitioned_lines.CALCULATE(\n",
" L_RETURNFLAG=return_flag,\n",
" L_LINESTATUS=status,\n",
" SUM_QTY=SUM(l.quantity),\n",
@@ -170,9 +170,7 @@
"LIMIT 100;\n",
"```\n",
"\n",
- "Notice the use of a correlated subqueries to determine the `PS_SUPPLYCOST` criteria. In PyDough this is handled naturally by simply first peforming the partition, applying the filter, and then navigating back to `selected_parts` in question, which avoids fully stepping out of the query. \n",
- "\n",
- "**NOTE**: The `BACK` terms are included in `selected_parts` because once it is partitioned, it will be impossible to access the original ancestry of `part` since its ancestry is replaced with `part_groups`."
+    "Notice the use of correlated subqueries to determine the `PS_SUPPLYCOST` criteria. In PyDough this is handled naturally by simply first performing the partition, applying the filter, and then navigating back to `selected_parts` in question, which avoids fully stepping out of the query. "
]
},
{
@@ -185,26 +183,28 @@
"%%pydough\n",
"\n",
"selected_parts = (\n",
- " nations.WHERE(region.name == \"EUROPE\")\n",
- " .suppliers.supply_records.part(\n",
- " s_acctbal=BACK(2).account_balance,\n",
- " s_name=BACK(2).name,\n",
- " n_name=BACK(3).name,\n",
- " s_address=BACK(2).address,\n",
- " s_phone=BACK(2).phone,\n",
- " s_comment=BACK(2).comment,\n",
- " supplycost=BACK(1).supplycost,\n",
+ " nations.CALCULATE(n_name=name)\n",
+ " .WHERE(region.name == \"EUROPE\")\n",
+ " .suppliers.CALCULATE(\n",
+ " s_acctbal=account_balance,\n",
+ " s_name=name,\n",
+ " s_address=address,\n",
+ " s_phone=phone,\n",
+ " s_comment=comment,\n",
+ " )\n",
+ " .supply_records.CALCULATE(\n",
+ " supplycost=supplycost,\n",
+ " )\n",
+ " .part.WHERE(ENDSWITH(part_type, \"BRASS\") & (size == 15))\n",
" )\n",
- " .WHERE(ENDSWITH(part_type, \"BRASS\") & (size == 15))\n",
- ")\n",
- "part_groups = PARTITION(selected_parts, name=\"p\", by=key)(\n",
+ "part_groups = PARTITION(selected_parts, name=\"p\", by=key).CALCULATE(\n",
" best_cost=MIN(p.supplycost)\n",
")\n",
"output = part_groups.p.WHERE(\n",
- " (supplycost == BACK(1).best_cost)\n",
+ " (supplycost == best_cost)\n",
" & ENDSWITH(part_type, \"BRASS\")\n",
" & (size == 15)\n",
- ")(\n",
+ ").CALCULATE(\n",
" S_ACCTBAL=s_acctbal,\n",
" S_NAME=s_name,\n",
" N_NAME=n_name,\n",
@@ -270,16 +270,15 @@
"%%pydough\n",
"\n",
"cutoff_date = datetime.date(1995, 3, 15)\n",
- "selected_orders = orders.WHERE(\n",
+ "selected_orders = orders.CALCULATE(\n",
+ " order_date, ship_priority\n",
+ ").WHERE(\n",
" (customer.mktsegment == \"BUILDING\") & (order_date < cutoff_date)\n",
")\n",
- "selected_lines = selected_orders.lines.WHERE(ship_date > cutoff_date)(\n",
- " BACK(1).order_date,\n",
- " BACK(1).ship_priority,\n",
- ")\n",
+ "selected_lines = selected_orders.lines.WHERE(ship_date > cutoff_date)\n",
"output = PARTITION(\n",
" selected_lines, name=\"l\", by=(order_key, order_date, ship_priority)\n",
- ")(\n",
+ ").CALCULATE(\n",
" L_ORDERKEY=order_key,\n",
" REVENUE=SUM(l.extended_price * (1 - l.discount)),\n",
" O_ORDERDATE=order_date,\n",
@@ -340,13 +339,80 @@
" & (order_date < datetime.date(1993, 10, 1))\n",
" & HAS(selected_lines)\n",
")\n",
- "output = PARTITION(selected_orders, name=\"o\", by=order_priority)(\n",
+ "output = PARTITION(selected_orders, name=\"o\", by=order_priority).CALCULATE(\n",
" O_ORDERPRIORITY=order_priority,\n",
" ORDER_COUNT=COUNT(o),\n",
").ORDER_BY(O_ORDERPRIORITY.ASC())\n",
"pydough.to_df(output)"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "d6fee20d",
+ "metadata": {},
+ "source": [
+ "## Query 5\n",
+ "\n",
+ "This question seeks to learn about domestic revenue from suppliers in Asian countries. It does so by calculating, for each Asian country, the total revenue generated by suppliers in that nation shipping a part to a customer from the same nation, only considering shipments ordered in 1994. Revenue volume for all qualifying lineitems in a particular nation is defined as `sum(l_extendedprice * (1 - l_discount))`.\n",
+ "\n",
+ "\n",
+ "\n",
+ "Here is the corresponding SQL:\n",
+ "\n",
+ "```SQL\n",
+ "select\n",
+ " N_NAME,\n",
+ " sum(l_extendedprice * (1 - l_discount)) as REVENUE\n",
+ "from\n",
+ " customer,\n",
+ " orders,\n",
+ " lineitem,\n",
+ " supplier,\n",
+ " nation,\n",
+ " region\n",
+ "where\n",
+ " c_custkey = o_custkey\n",
+ " and l_orderkey = o_orderkey\n",
+ " and l_suppkey = s_suppkey\n",
+ " and c_nationkey = s_nationkey\n",
+ " and s_nationkey = n_nationkey\n",
+ " and n_regionkey = r_regionkey\n",
+ " and r_name = 'ASIA'\n",
+ " and o_orderdate >= '1994-01-01'\n",
+ " and o_orderdate < '1995-01-01'\n",
+ "group by\n",
+ " n_name\n",
+ "order by\n",
+ " revenue desc\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4f8eb6d0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%pydough\n",
+ "\n",
+ "selected_lines = (\n",
+ " customers.orders.WHERE(\n",
+ " (order_date >= datetime.date(1994, 1, 1))\n",
+ " & (order_date < datetime.date(1995, 1, 1))\n",
+ " )\n",
+ " .lines.WHERE(supplier.nation.name == nation_name)\n",
+ " .CALCULATE(value=extended_price * (1 - discount))\n",
+ ")\n",
+ "output = (\n",
+ " nations.CALCULATE(nation_name=name)\n",
+ " .WHERE(region.name == \"ASIA\")\n",
+ " .CALCULATE(N_NAME=name, REVENUE=SUM(selected_lines.value))\n",
+ " .ORDER_BY(REVENUE.DESC())\n",
+ ")\n",
+ "pydough.to_df(output)"
+ ]
+ },
{
"cell_type": "markdown",
"id": "c5c4d84b-b3c0-4b58-8635-b1d30ea2946a",
@@ -388,8 +454,8 @@
" & (0.05 <= discount)\n",
" & (discount <= 0.07)\n",
" & (quantity < 24)\n",
- ")(amt=extended_price * discount)\n",
- "output = TPCH(REVENUE=SUM(selected_lines.amt))\n",
+ ").CALCULATE(amt=extended_price * discount)\n",
+ "output = TPCH.CALCULATE(REVENUE=SUM(selected_lines.amt))\n",
"pydough.to_df(output)"
]
},
@@ -457,7 +523,7 @@
"source": [
"%%pydough\n",
"\n",
- "line_info = lines(\n",
+ "line_info = lines.CALCULATE(\n",
" supp_nation=supplier.nation.name,\n",
" cust_nation=order.customer.nation.name,\n",
" l_year=YEAR(ship_date),\n",
@@ -471,7 +537,7 @@
" )\n",
")\n",
"\n",
- "output = PARTITION(line_info, name=\"l\", by=(supp_nation, cust_nation, l_year))(\n",
+ "output = PARTITION(line_info, name=\"l\", by=(supp_nation, cust_nation, l_year)).CALCULATE(\n",
" SUPP_NATION=supp_nation,\n",
" CUST_NATION=cust_nation,\n",
" L_YEAR=l_year,\n",
@@ -545,23 +611,22 @@
"source": [
"%%pydough\n",
"\n",
- "selected_orders = orders.WHERE(\n",
- " (order_date >= datetime.date(1995, 1, 1))\n",
- " & (order_date <= datetime.date(1996, 12, 31))\n",
- " & (customer.nation.region.name == \"AMERICA\")\n",
- ")\n",
- "\n",
- "volume = extended_price * (1 - discount)\n",
- "\n",
- "volume_data = selected_orders.lines.WHERE(\n",
- " part.part_type == \"ECONOMY ANODIZED STEEL\"\n",
- ")(\n",
- " o_year=YEAR(BACK(1).order_date),\n",
- " volume=volume,\n",
- " brazil_volume=IFF(supplier.nation.name == \"BRAZIL\", volume, 0)\n",
+ "volume_data = (\n",
+ " nations.CALCULATE(nation_name=name)\n",
+ " .suppliers.supply_records.WHERE(part.part_type == \"ECONOMY ANODIZED STEEL\")\n",
+ " .lines.CALCULATE(volume=extended_price * (1 - discount))\n",
+ " .order.CALCULATE(\n",
+ " o_year=YEAR(order_date),\n",
+ " brazil_volume=IFF(nation_name == \"BRAZIL\", volume, 0),\n",
+ " )\n",
+ " .WHERE(\n",
+ " (order_date >= datetime.date(1995, 1, 1))\n",
+ " & (order_date <= datetime.date(1996, 12, 31))\n",
+ " & (customer.nation.region.name == \"AMERICA\")\n",
+ " )\n",
")\n",
"\n",
- "output = PARTITION(volume_data, name=\"v\", by=o_year)(\n",
+ "output = PARTITION(volume_data, name=\"v\", by=o_year).CALCULATE(\n",
" O_YEAR=o_year,\n",
" MKT_SHARE=SUM(v.brazil_volume) / SUM(v.volume),\n",
")\n",
@@ -626,17 +691,20 @@
"source": [
"%%pydough\n",
"\n",
- "selected_lines = nations.suppliers.supply_records.WHERE(\n",
- " CONTAINS(part.name, \"green\")\n",
- ").lines(\n",
- " nation=BACK(3).name,\n",
- " o_year=YEAR(order.order_date),\n",
- " value=extended_price * (1 - discount) - BACK(1).supplycost * quantity,\n",
+ "selected_lines = (\n",
+ " nations.CALCULATE(nation_name=name)\n",
+ " .suppliers.supply_records.CALCULATE(supplycost)\n",
+ " .WHERE(CONTAINS(part.name, \"green\"))\n",
+ " .lines.CALCULATE(\n",
+ " o_year=YEAR(order.order_date),\n",
+ " value=extended_price * (1 - discount) - supplycost * quantity,\n",
+ " )\n",
+ ")\n",
+ "output = (\n",
+ " PARTITION(selected_lines, name=\"l\", by=(nation_name, o_year))\n",
+ " .CALCULATE(NATION=nation_name, O_YEAR=o_year, AMOUNT=SUM(l.value))\n",
+ " .TOP_K(10, by=(NATION.ASC(), O_YEAR.DESC()))\n",
")\n",
- "\n",
- "output = PARTITION(selected_lines, name=\"l\", by=(nation, o_year))(\n",
- " NATION=nation, O_YEAR=o_year, AMOUNT=SUM(l.value)\n",
- ").ORDER_BY(NATION.ASC(), O_YEAR.DESC())\n",
"pydough.to_df(output)"
]
},
@@ -703,9 +771,9 @@
"selected_lines = orders.WHERE(\n",
" (order_date >= datetime.date(1993, 10, 1))\n",
" & (order_date < datetime.date(1994, 1, 1))\n",
- ").lines.WHERE(return_flag == \"R\")(amt=extended_price * (1 - discount))\n",
+ ").lines.WHERE(return_flag == \"R\").CALCULATE(amt=extended_price * (1 - discount))\n",
"\n",
- "output = customers(\n",
+ "output = customers.CALCULATE(\n",
" C_CUSTKEY=key,\n",
" C_NAME=name,\n",
" REVENUE=SUM(selected_lines.amt),\n",
@@ -774,12 +842,12 @@
"source": [
"%%pydough\n",
"is_german_supplier = supplier.nation.name == \"GERMANY\"\n",
- "selected_records = supply_records.WHERE(is_german_supplier)(metric=supplycost * availqty)\n",
- "output = TPCH(min_market_share=SUM(selected_records.metric) * 0.0001).PARTITION(\n",
+ "selected_records = supply_records.WHERE(is_german_supplier).CALCULATE(metric=supplycost * availqty)\n",
+ "output = TPCH.CALCULATE(min_market_share=SUM(selected_records.metric) * 0.0001).PARTITION(\n",
" selected_records, name=\"ps\", by=part_key\n",
- ")(\n",
+ ").CALCULATE(\n",
" PS_PARTKEY=part_key, VALUE=SUM(ps.metric)\n",
- ").WHERE(VALUE > BACK(1).min_market_share).ORDER_BY(VALUE.DESC())\n",
+ ").WHERE(VALUE > min_market_share).ORDER_BY(VALUE.DESC())\n",
"pydough.to_df(output)"
]
},
@@ -843,11 +911,11 @@
" & (commit_date < receipt_date)\n",
" & (receipt_date >= datetime.date(1994, 1, 1))\n",
" & (receipt_date < datetime.date(1995, 1, 1))\n",
- ")(\n",
+ ").CALCULATE(\n",
" is_high_priority=(order.order_priority == \"1-URGENT\")\n",
" | (order.order_priority == \"2-HIGH\"),\n",
")\n",
- "output = PARTITION(selected_lines, \"l\", by=ship_mode)(\n",
+ "output = PARTITION(selected_lines, \"l\", by=ship_mode).CALCULATE(\n",
" L_SHIPMODE=ship_mode,\n",
" HIGH_LINE_COUNT=SUM(l.is_high_priority),\n",
" LOW_LINE_COUNT=SUM(~(l.is_high_priority)),\n",
@@ -899,13 +967,13 @@
"source": [
"%%pydough\n",
"\n",
- "customer_info = customers(\n",
+ "customer_info = customers.CALCULATE(\n",
" key,\n",
" num_non_special_orders=COUNT(\n",
" orders.WHERE(~(LIKE(comment, \"%special%requests%\")))\n",
" ),\n",
")\n",
- "output = PARTITION(customer_info, name=\"custs\", by=num_non_special_orders)(\n",
+ "output = PARTITION(customer_info, name=\"custs\", by=num_non_special_orders).CALCULATE(\n",
" C_COUNT=num_non_special_orders, CUSTDIST=COUNT(custs)\n",
").ORDER_BY(CUSTDIST.DESC(), C_COUNT.DESC())\n",
"pydough.to_df(output)"
@@ -952,11 +1020,11 @@
"selected_lines = lines.WHERE(\n",
" (ship_date >= datetime.date(1995, 9, 1))\n",
" & (ship_date < datetime.date(1995, 10, 1))\n",
- ")(\n",
+ ").CALCULATE(\n",
" value=value,\n",
" promo_value=IFF(STARTSWITH(part.part_type, \"PROMO\"), value, 0),\n",
")\n",
- "output = TPCH(PROMO_REVENUE=100.0 * SUM(selected_lines.promo_value) / SUM(selected_lines.value))\n",
+ "output = TPCH.CALCULATE(PROMO_REVENUE=100.0 * SUM(selected_lines.promo_value) / SUM(selected_lines.value))\n",
"pydough.to_df(output)"
]
},
@@ -1022,15 +1090,15 @@
" & (ship_date < datetime.date(1996, 4, 1))\n",
")\n",
"total = SUM(selected_lines.extended_price * (1 - selected_lines.discount))\n",
- "output = TPCH(\n",
- " max_revenue=MAX(suppliers(total_revenue=total).total_revenue)\n",
- ").suppliers(\n",
+ "output = TPCH.CALCULATE(\n",
+ " max_revenue=MAX(suppliers.CALCULATE(total_revenue=total).total_revenue)\n",
+ ").suppliers.CALCULATE(\n",
" S_SUPPKEY=key,\n",
" S_NAME=name,\n",
" S_ADDRESS=address,\n",
" S_PHONE=phone,\n",
" TOTAL_REVENUE=total,\n",
- ").WHERE(TOTAL_REVENUE == BACK(1).max_revenue).ORDER_BY(S_SUPPKEY.ASC())\n",
+ ").WHERE(TOTAL_REVENUE == max_revenue).ORDER_BY(S_SUPPKEY.ASC())\n",
"pydough.to_df(output)"
]
},
@@ -1091,20 +1159,19 @@
"%%pydough\n",
"\n",
"selected_records = (\n",
- " parts.WHERE(\n",
+ " parts.CALCULATE(\n",
+ " p_brand=brand,\n",
+ " p_type=part_type,\n",
+ " p_size=size,\n",
+ " ).WHERE(\n",
" (brand != \"BRAND#45\")\n",
" & ~STARTSWITH(part_type, \"MEDIUM POLISHED%\")\n",
" & ISIN(size, [49, 14, 23, 45, 19, 3, 36, 9])\n",
" )\n",
- " .supply_records(\n",
- " p_brand=BACK(1).brand,\n",
- " p_type=BACK(1).part_type,\n",
- " p_size=BACK(1).size,\n",
- " ps_suppkey=supplier_key,\n",
- " )\n",
+ " .supply_records\n",
" .WHERE(~LIKE(supplier.comment, \"%Customer%Complaints%\"))\n",
")\n",
- "output = PARTITION(selected_records, name=\"ps\", by=(p_brand, p_type, p_size))(\n",
+ "output = PARTITION(selected_records, name=\"ps\", by=(p_brand, p_type, p_size)).CALCULATE(\n",
" P_BRAND=p_brand,\n",
" P_TYPE=p_type,\n",
" P_SIZE=p_size,\n",
@@ -1158,10 +1225,12 @@
"source": [
"%%pydough\n",
"\n",
- "selected_lines = parts.WHERE((brand == \"Brand#23\") & (container == \"MED BOX\"))(\n",
+ "selected_lines = parts.WHERE(\n",
+ " (brand == \"Brand#23\") & (container == \"MED BOX\")\n",
+ ").CALCULATE(\n",
" avg_quantity=AVG(lines.quantity)\n",
- ").lines.WHERE(quantity < 0.2 * BACK(1).avg_quantity)\n",
- "output = TPCH(AVG_YEARLY=SUM(selected_lines.extended_price) / 7.0)\n",
+ ").lines.WHERE(quantity < 0.2 * avg_quantity)\n",
+ "output = TPCH.CALCULATE(AVG_YEARLY=SUM(selected_lines.extended_price) / 7.0)\n",
"pydough.to_df(output)"
]
},
@@ -1224,7 +1293,7 @@
"source": [
"%%pydough\n",
"\n",
- "output = orders(\n",
+ "output = orders.CALCULATE(\n",
" C_NAME=customer.name,\n",
" C_CUSTKEY=customer.key,\n",
" O_ORDERKEY=key,\n",
@@ -1338,7 +1407,7 @@
" )\n",
" )\n",
")\n",
- "output = TPCH(\n",
+ "output = TPCH.CALCULATE(\n",
" REVENUE=SUM(selected_lines.extended_price * (1 - selected_lines.discount))\n",
")\n",
"pydough.to_df(output)"
@@ -1413,22 +1482,190 @@
" & (ship_date < datetime.date(1995, 1, 1))\n",
" ).quantity\n",
")\n",
- "selected_part_supplied = supply_records.part.WHERE(\n",
- " STARTSWITH(name, \"forest\") & (BACK(1).availqty > part_qty * 0.5)\n",
+ "selected_part_supplied = supply_records.CALCULATE(\n",
+ " availqty\n",
+ ").part.WHERE(\n",
+ " STARTSWITH(name, \"forest\") & (availqty > part_qty * 0.5)\n",
")\n",
- "output = suppliers(\n",
+ "output = suppliers.CALCULATE(\n",
" S_NAME=name,\n",
" S_ADDRESS=address,\n",
").WHERE((nation.name == \"CANADA\") & COUNT(selected_part_supplied) > 0).ORDER_BY(S_NAME.ASC())\n",
"pydough.to_df(output)"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "82c059e6",
+ "metadata": {},
+ "source": [
+ "## Query 21\n",
+ "\n",
+ "This query **identifies certain suppliers who were not able to ship required parts in a timely manner**. It does so by counting, for each Saudi Arabian supplier, how many times their product was part of a multi-supplier order (with current status of 'F') where they were the only supplier who failed to meet the committed delivery date.\n",
+ "\n",
+ "\n",
+ "Here is the corresponding SQL:\n",
+ "\n",
+ "```SQL\n",
+ "select\n",
+ " S_NAME,\n",
+ " count(*) as NUMWAIT\n",
+ "from\n",
+ " supplier,\n",
+ " lineitem l1,\n",
+ " orders,\n",
+ " nation\n",
+ "where\n",
+ " s_suppkey = l1.l_suppkey\n",
+ " and o_orderkey = l1.l_orderkey\n",
+ " and o_orderstatus = 'F'\n",
+ " and l1.l_receiptdate > l1.l_commitdate\n",
+ " and exists (\n",
+ " select\n",
+ " *\n",
+ " from\n",
+ " lineitem l2\n",
+ " where\n",
+ " l2.l_orderkey = l1.l_orderkey\n",
+ " and l2.l_suppkey <> l1.l_suppkey\n",
+ " )\n",
+ " and not exists (\n",
+ " select\n",
+ " *\n",
+ " from\n",
+ " lineitem l3\n",
+ " where\n",
+ " l3.l_orderkey = l1.l_orderkey\n",
+ " and l3.l_suppkey <> l1.l_suppkey\n",
+ " and l3.l_receiptdate > l3.l_commitdate\n",
+ " )\n",
+ " and s_nationkey = n_nationkey\n",
+ " and n_name = 'SAUDI ARABIA'\n",
+ "group by\n",
+ " s_name\n",
+ "order by\n",
+ " numwait desc,\n",
+ " s_name\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8e400982",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%pydough\n",
+ "\n",
+ "date_check = receipt_date > commit_date\n",
+ "selected_orders = lines.CALCULATE(original_key=supplier_key).WHERE(date_check).order\n",
+ "different_supplier = supplier_key != original_key\n",
+ "waiting_entries = selected_orders.WHERE(\n",
+ " (order_status == \"F\")\n",
+ " & HAS(lines.WHERE(different_supplier))\n",
+ " & HASNOT(lines.WHERE(different_supplier & date_check))\n",
+ ")\n",
+ "output = (\n",
+ " suppliers.WHERE(nation.name == \"SAUDI ARABIA\")\n",
+ " .CALCULATE(\n",
+ " S_NAME=name,\n",
+ " NUMWAIT=COUNT(waiting_entries),\n",
+ " )\n",
+ " .ORDER_BY(NUMWAIT.DESC(), S_NAME.ASC())\n",
+ ")\n",
+ "\n",
+ "pydough.to_df(output)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "be8bbff8",
+ "metadata": {},
+ "source": [
+ "## Query 22\n",
+ "\n",
+ "This query **identifies geographies where there are customers who may be likely to make a purchase**. It does so by breaking down how many customers, by country code of their phone number (only including customers from certain country codes), have not placed an order but have an account balance that is above the average for all positive account balances of such customers. Also includes the total balance for all such customers for each country code. The country codes to include are 13, 31, 23, 29, 30, 18 and 17.\n",
+ "\n",
+ "\n",
+ "Here is the corresponding SQL:\n",
+ "\n",
+ "```SQL\n",
+ "select\n",
+ " CNTRYCODE,\n",
+ " count(*) as NUMCUST,\n",
+ " sum(c_acctbal) as TOTALACCTBAL\n",
+ "from (\n",
+ " select\n",
+ " substring(c_phone from 1 for 2) as cntrycode,\n",
+ " c_acctbal\n",
+ " from\n",
+ " customer\n",
+ " where\n",
+ " substring(c_phone from 1 for 2) in\n",
+ " ('13', '31', '23', '29', '30', '18', '17')\n",
+ " and c_acctbal > (\n",
+ " select\n",
+ " avg(c_acctbal)\n",
+ " from\n",
+ " customer\n",
+ " where\n",
+ " c_acctbal > 0.00\n",
+ " and substring (c_phone from 1 for 2) in\n",
+ " ('13', '31', '23', '29', '30', '18', '17')\n",
+ " )\n",
+ " and not exists (\n",
+ " select\n",
+ " *\n",
+ " from\n",
+ " orders\n",
+ " where\n",
+ " o_custkey = c_custkey\n",
+ " )\n",
+ " ) as custsale\n",
+ "group by\n",
+ " cntrycode\n",
+ "order by\n",
+ " cntrycode\n",
+ "```"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
"id": "ed31951a-e76d-40f3-ad70-c911ba80c205",
"metadata": {},
"outputs": [],
+ "source": [
+ "%%pydough\n",
+ "\n",
+ "selected_customers = customers.CALCULATE(cntry_code=phone[:2]).WHERE(\n",
+ " ISIN(cntry_code, (\"13\", \"31\", \"23\", \"29\", \"30\", \"18\", \"17\"))\n",
+ ")\n",
+ "output = (\n",
+ " TPCH.CALCULATE(\n",
+ " global_avg_balance=AVG(selected_customers.WHERE(acctbal > 0.0).acctbal)\n",
+ " )\n",
+ " .PARTITION(\n",
+ " selected_customers.WHERE((acctbal > global_avg_balance) & (COUNT(orders) == 0)),\n",
+ " name=\"custs\",\n",
+ " by=cntry_code,\n",
+ " )\n",
+ " .CALCULATE(\n",
+ " CNTRY_CODE=cntry_code,\n",
+ " NUM_CUSTS=COUNT(custs),\n",
+ " TOTACCTBAL=SUM(custs.acctbal),\n",
+ " )\n",
+ ")\n",
+ "pydough.to_df(output)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9eca1213-52cf-45b3-b176-da87d95ec080",
+ "metadata": {},
+ "outputs": [],
"source": []
}
],
@@ -1448,7 +1685,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.12.7"
+ "version": "3.12.6"
}
},
"nbformat": 4,
diff --git a/demos/notebooks/5_what_if.ipynb b/demos/notebooks/5_what_if.ipynb
index c3b03b9a..15811785 100644
--- a/demos/notebooks/5_what_if.ipynb
+++ b/demos/notebooks/5_what_if.ipynb
@@ -114,7 +114,7 @@
"```python\n",
"revenue_def = extended_price*(1-discount)\n",
"orders(total_line_price=SUM(lines(line_price=revenue_def).line_price)).lines(\n",
- " revenue_ratio=revenue_def / BACK(1).total_price, \n",
+ " revenue_ratio=revenue_def / total_line_price, \n",
" order_key=order_key, \n",
" line_number=line_number\n",
").TOP_K(5, by=(revenue_ratio.ASC(), order_key.DESC(), line_number.DESC()))\n",
@@ -200,7 +200,7 @@
"source": [
"%%pydough\n",
"\n",
- "total_revenue = SUM(lines(line_renvenue=revenue_def).line_renvenue)\n",
+ "total_revenue = SUM(lines.CALCULATE(line_renvenue=revenue_def).line_renvenue)\n",
"total_revenue"
]
},
@@ -221,7 +221,7 @@
"source": [
"%%pydough\n",
"\n",
- "pydough.to_df(TPCH(total_line_revenue=total_revenue))"
+ "pydough.to_df(TPCH.CALCULATE(total_line_revenue=total_revenue))"
]
},
{
@@ -241,7 +241,7 @@
"source": [
"%%pydough\n",
"\n",
- "order_total_price = orders(order_revenue=total_revenue)\n",
+ "order_total_price = orders.CALCULATE(order_revenue=total_revenue)\n",
"pydough.to_df(order_total_price)"
]
},
@@ -267,7 +267,7 @@
"%%pydough\n",
"# Compute the sum of the first 5 line numbers, which can be known for testing.\n",
"top_five_lines = lines.TOP_K(5, by=(line_number.ASC(), order_key.ASC()))\n",
- "top_five_line_price = TPCH(total_line_revenue=SUM(top_five_lines(line_revenue=revenue_def).line_revenue))\n",
+ "top_five_line_price = TPCH.CALCULATE(total_line_revenue=SUM(top_five_lines.CALCULATE(line_revenue=revenue_def).line_revenue))\n",
"pydough.to_df(top_five_line_price)"
]
},
@@ -276,7 +276,7 @@
"id": "23e6c569-fc61-4338-ad28-ddbe20aa3b88",
"metadata": {},
"source": [
- "Now let's return to extending our question. Building able to compute order sums is great, but we care about results per line. As a result, now we can even extend our orders to an additional context within lines. We will once again define more defintions. Our ratio definition will now ask us to propagate our previous `total_revenue` that we computed and compare it to the result of `revenue_def`."
+    "Now let's return to extending our question. Being able to compute order sums is great, but we care about results per line. As a result, now we can even extend our orders to an additional context within lines. We will once again define more definitions. Our ratio definition will now ask us to propagate our previous `order_revenue` that we computed (down-streamed from an ancestor context) and compare it to the result of `revenue_def`."
]
},
{
@@ -288,7 +288,7 @@
"source": [
"%%pydough\n",
"\n",
- "ratio = revenue_def / BACK(1).order_revenue"
+ "ratio = revenue_def / order_revenue"
]
},
{
@@ -310,7 +310,7 @@
"source": [
"%%pydough\n",
"\n",
- "line_ratios = order_total_price.lines(revenue_ratio=ratio, order_key=order_key, line_number=line_number)\n",
+ "line_ratios = order_total_price.lines.CALCULATE(revenue_ratio=ratio, order_key=order_key, line_number=line_number)\n",
"lowest_ratios = line_ratios.TOP_K(5, by=(revenue_ratio.ASC(), order_key.DESC(), line_number.DESC()))"
]
},
@@ -375,14 +375,14 @@
"%%pydough\n",
"\n",
"total_lines = COUNT(lines)\n",
- "order_total_price = orders(order_revenue=total_revenue, line_count=total_lines)\n",
- "line_ratios = order_total_price.lines(\n",
+ "order_total_price = orders.CALCULATE(order_revenue=total_revenue, line_count=total_lines)\n",
+ "line_ratios = order_total_price.lines.CALCULATE(\n",
" revenue_ratio=ratio, \n",
- " line_count=BACK(1).line_count, \n",
+ " line_count=line_count, \n",
" order_key=order_key, \n",
" line_number=line_number\n",
")\n",
- "filtered_ratios = line_ratios.WHERE(line_count > 3)(revenue_ratio, order_key, line_number)"
+ "filtered_ratios = line_ratios.WHERE(line_count > 3).CALCULATE(revenue_ratio, order_key, line_number)"
]
},
{
@@ -435,7 +435,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.12.7"
+ "version": "3.12.6"
}
},
"nbformat": 4,
diff --git a/documentation/dsl.md b/documentation/dsl.md
index 57965054..35f04a63 100644
--- a/documentation/dsl.md
+++ b/documentation/dsl.md
@@ -7,10 +7,10 @@ This page describes the specification of the PyDough DSL. The specification incl
- [Example Graph](#example-graph)
- [Collections](#collections)
* [Sub-Collections](#sub-collections)
- * [CALC](#calc)
+ * [CALCULATE](#calculate)
* [Contextless Expressions](#contextless-expressions)
- * [BACK](#back)
* [Expressions](#expressions)
+ * [Down-Streaming](#down-streaming)
- [Collection Operators](#collection-operators)
* [WHERE](#where)
* [ORDER_BY](#order_by)
@@ -141,36 +141,38 @@ Addresses.former_occupants
Packages.shipping_addresses
```
-
-### CALC
+
+### CALCULATE
-The examples so far just show selecting all properties from records of a collection. Most of the time, an analytical question will only want a subset of the properties, may want to rename them, and may want to derive new properties via calculated expressions. The way to do this is with a `CALC` term, which is done by following a PyDough collection with parenthesis containing the expressions that should be included.
+The examples so far just show selecting all properties from records of a collection. Most of the time, an analytical question will only want a subset of the properties, and may want to derive new properties via calculated expressions. The way to do this is with a `CALCULATE` term. This method contains the expressions that should be derived by the `CALCULATE` operation.
These expressions can be positional arguments or keyword arguments. Keyword arguments use the name of the keyword as the name of the output expression. Positional arguments use the name of the expression, if one exists, otherwise an arbitrary name is chosen.
-The value of one of these terms in a CALC must be expressions that are singular with regards to the current context. That can mean:
+The value of one of these terms in a `CALCULATE` must be expressions that are singular with regards to the current context. That can mean:
- Referencing one of the scalar properties of the current collection.
- Creating a literal.
- Referencing a singular expression of a sub-collection of the current collection that is singular with regards to the current collection.
- Calling a non-aggregation function on more singular expressions.
- Calling an aggregation function on a plural expression.
-Once a CALC term is created, all terms of the current collection still exist even if they weren't part of the CALC and can still be referenced, they just will not be part of the final answer. If there are multiple CALC terms, the last one is used to determine what expressions are part of the final answer, so earlier CALCs can be used to derive intermediary expressions. If a CALC includes a term with the same name as an existing property of the collection, the existing name is overridden to include the new term.
+Once a `CALCULATE` clause is created, all terms of the current collection still exist even if they weren't part of the `CALCULATE` and can still be referenced, they just will not be part of the final answer. If there are multiple `CALCULATE` clauses, the last one is used to determine what expressions are part of the final answer, so earlier `CALCULATE` clauses can be used to derive intermediary expressions. If a `CALCULATE` includes a term with the same name as an existing property of the collection, the existing name is overridden to include the new term.
-A CALC can also be done on the graph itself to create a collection with 1 row and columns corresponding to the properties inside the CALC. This is useful when aggregating an entire collection globally instead of with regards to a parent collection.
+Importantly, when a term is defined in a `CALCULATE`, that definition does not take effect until after the `CALCULATE` completes. This means that if a term in a `CALCULATE` uses the definition of a term defined in the same `CALCULATE`, it will not work.
+
+A `CALCULATE` can also be done on the graph itself to create a collection with 1 row and columns corresponding to the properties inside the `CALCULATE`. This is useful when aggregating an entire collection globally instead of with regards to a parent collection.
**Good Example #1**: For every person, fetch just their first name and last name.
```py
%%pydough
-People(first_name, last_name)
+People.CALCULATE(first_name, last_name)
```
**Good Example #2**: For every package, fetch the package id, the first and last name of the person who ordered it, and the state that it was shipped to. Also, include a field named `secret_key` that is always equal to the string `"alphabet soup"`.
```py
%%pydough
-Packages(
+Packages.CALCULATE(
package_id,
first_name=customer.first_name,
last_name=customer.last_name,
@@ -183,7 +185,7 @@ Packages(
```py
%%pydough
-People(
+People.CALCULATE(
name=JOIN_STRINGS("", first_name, last_name),
n_packages_ordered=COUNT(packages),
)
@@ -193,11 +195,11 @@ People(
```py
%%pydough
-People(
+People.CALCULATE(
has_middle_name=PRESENT(middle_name)
full_name_with_middle=JOIN_STRINGS(" ", first_name, middle_name, last_name),
full_name_without_middle=JOIN_STRINGS(" ", first_name, last_name),
-)(
+).CALCULATE(
full_name=IFF(has_middle_name, full_name_with_middle, full_name_without_middle),
email=email,
)
@@ -207,7 +209,7 @@ People(
```py
%%pydough
-People(
+People.CALCULATE(
most_recent_package_year=YEAR(MAX(packages.order_date)),
first_ever_package_year=YEAR(MIN(packages.order_date)),
)
@@ -217,7 +219,7 @@ People(
```py
%%pydough
-GRAPH(
+GRAPH.CALCULATE(
n_people=COUNT(People),
n_packages=COUNT(Packages),
n_addresses=COUNT(Addresses),
@@ -228,7 +230,7 @@ GRAPH(
```py
%%pydough
-Packages(
+Packages.CALCULATE(
package_id,
shipped_to_curr_addr=shipping_address.address_id == customer.current_address.address_id
)
@@ -238,31 +240,31 @@ Packages(
```py
%%pydough
-People(first_name, last_name, phone_number)
+People.CALCULATE(first_name, last_name, phone_number)
```
**Bad Example #2**: For each person, list their combined first & last name followed by their email. This is invalid because a positional argument is included after a keyword argument.
```py
%%pydough
-People(
+People.CALCULATE(
full_name=JOIN_STRINGS(" ", first_name, last_name),
email
)
```
-**Bad Example #3**: For each person, list the address_id of packages they have ordered. This is invalid because `packages` is a plural property of `People`, so its properties cannot be included in a calc term of `People` unless aggregated.
+**Bad Example #3**: For each person, list the address_id of packages they have ordered. This is invalid because `packages` is a plural property of `People`, so its properties cannot be included in a `CALCULATE` term of `People` unless aggregated.
```py
%%pydough
-People(packages.address_id)
+People.CALCULATE(packages.address_id)
```
-**Bad Example #4**: For each person, list their first/last name followed by the concatenated city/state name of their current address. This is invalid because `current_address` is a plural property of `People`, so its properties cannot be included in a calc term of `People` unless aggregated.
+**Bad Example #4**: For each person, list their first/last name followed by the concatenated city/state name of their current address. This is invalid because `current_address` is a plural property of `People`, so its properties cannot be included in a `CALCULATE` term of `People` unless aggregated.
```py
%%pydough
-People(
+People.CALCULATE(
first_name,
last_name,
location=JOIN_STRINGS(", ", current_address.city, current_address.state),
@@ -273,35 +275,48 @@ People(
```py
%%pydough
-Addresses(is_c_state=state.startswith("c"))
+Addresses.CALCULATE(is_c_state=state.startswith("c"))
```
**Bad Example #6**: For each address, find the state bird of the state it is in. This is invalid because the `state` property of each record of `Addresses` is a scalar expression, not a subcolleciton, so it does not have any properties that can be accessed with `.` syntax.
```py
%%pydough
-Addresses(state_bird=state.bird)
+Addresses.CALCULATE(state_bird=state.bird)
```
**Bad Example #7**: For each current occupant of each address, list their first name, last name, and city/state they live in. This is invalid because `city` and `state` are not properties of the current collection (`People`, accessed via `current_occupants` of each record of `Addresses`).
```py
%%pydough
-Addresses.current_occupants(first_name, last_name, city, state)
+Addresses.current_occupants.CALCULATE(first_name, last_name, city, state)
```
-**Bad Example #8**: For each person include their ssn and current address. This is invalid because a collection cannot be a CALC term, and `current_address` is a sub-collection property of `People`. Instead, properties of `current_address` can be accessed.
+**Bad Example #8**: For each person include their ssn and current address. This is invalid because a collection cannot be a `CALCULATE` term, and `current_address` is a sub-collection property of `People`. Instead, properties of `current_address` can be accessed.
```py
%%pydough
-People(ssn, current_address)
+People.CALCULATE(ssn, current_address)
```
-**Bad Example #9**: For each person, list their first name, last name, and the sum of the package costs. This is invalid because `SUM` is an aggregation function and cannot be used in a CALC term without specifying the sub-collection it should be applied to.
+**Bad Example #9**: For each person, list their first name, last name, and the sum of the package costs. This is invalid because `SUM` is an aggregation function and cannot be used in a `CALCULATE` term without specifying the sub-collection it should be applied to.
```py
%%pydough
-People(first_name, last_name, total_cost=SUM(package_cost))
+People.CALCULATE(first_name, last_name, total_cost=SUM(package_cost))
+```
+
+**Bad Example #10**: For each person, list their first name, last name, and the ratio between the cost of all packages they have ordered and the number of packages they ordered. This is invalid because the `total_cost` and `n_packages` are used to define `ratio` in the same `CALCULATE` where they are defined.
+
+```py
+%%pydough
+People.CALCULATE(
+ first_name,
+ last_name,
+ total_cost=SUM(packages.package_cost),
+ n_packages=COUNT(packages),
+ ratio=total_cost/n_packages,
+)
```
@@ -309,14 +324,14 @@ People(first_name, last_name, total_cost=SUM(package_cost))
PyDough allows defining snippets of PyDough code out of context that do not make sense until they are later placed within a context. This can be done by writing a contextless expression, binding it to a variable as if it were any other Python expression, then later using it inside of PyDough code. This should always have the same effect as if the PyDough code was written fully in-context, but allows re-using common snippets.
-**Good Example #1**: Same as good example #4 from the CALC section, but written with contextless expressions.
+**Good Example #1**: Same as good example #4 from the `CALCULATE` section, but written with contextless expressions.
```py
%%pydough
has_middle_name = PRESENT(middle_name)
full_name_with_middle = JOIN_STRINGS(" ", first_name, middle_name, last_name),
full_name_without_middle = JOIN_STRINGS(" ", first_name, last_name),
-People(
+People.CALCULATE(
full_name=IFF(has_middle_name, full_name_with_middle, full_name_without_middle),
email=email,
)
@@ -328,13 +343,13 @@ People(
%%pydough
is_february = MONTH(order_date) == 2
february_value = KEEP_IF(package_cost, is_february)
-aug_packages = packages(
+aug_packages = packages.CALCULATE(
is_february=is_february,
february_value=february_value,
is_valentines_day=is_february & (DAY(order_date) == 14)
)
n_feb_packages = SUM(aug_packages.is_february)
-People(
+People.CALCULATE(
ssn,
total_february_value=SUM(aug_packages.february_value),
n_february_packages=n_feb_packages,
@@ -347,7 +362,7 @@ People(
```py
%%pydough
-current_addresses(city, state)
+current_addresses.CALCULATE(city, state)
```
**Bad Example #2**: Just a contextless expression for a scalar expression that has not been placed into a collection for it to make sense.
@@ -362,7 +377,7 @@ LOWER(current_occupants.first_name)
```py
%%pydough
value = package_cost
-People(x=ssn + value)
+People.CALCULATE(x=ssn + value)
```
**Bad Example #4**: A contextless expression that does not make sense when placed into its context (`People` does not have a property named `order_date`, so substituting it when `is_february` is referenced does not make sense).
@@ -370,23 +385,106 @@ People(x=ssn + value)
```py
%%pydough
is_february = MONTH(order_date) == 2
-People(february=is_february)
+People.CALCULATE(february=is_february)
+```
+
+
+### Expressions
+
+So far, many different kinds of expressions have been noted in the examples for `CALCULATE` and contextless expressions. The following are examples & explanations of the various types of valid expressions:
+
+```py
+# Referencing scalar properties of the current collection
+People.CALCULATE(
+ first_name,
+ last_name
+)
+
+# Referencing scalar properties of a singular sub-collection
+People.CALCULATE(
+    current_city=current_address.city,
+ current_state=current_address.state,
+)
+
+# Referencing properties from the CALCULATE of an ancestor collection
+# (see down-streaming for more details)
+Addresses.CALCULATE(zip_code).current_occupants.CALCULATE(email).packages.CALCULATE(
+ email, # <- refers to the `email` from `current_occupants`
+ zip_code, # <- refers to the `zip_code` from `Addresses`
+)
+
+# Invoking normal functions/operations on other singular data
+Customers.CALCULATE(
+ lowered_name=LOWER(name),
+ normalized_birth_month=MONTH(birth_date) - 1,
+ lives_in_c_state=STARTSWITH(current_address.state, "C"),
+)
+
+# Supported Python literals:
+# - integers
+# - floats
+# - strings
+# - booleans
+# - None
+# - decimal.Decimal
+# - pandas.Timestamp
+# - datetime.date
+# - lists/tuples of literals
+from pandas import Timestamp
+from datetime import date
+from decimal import Decimal
+Customers.CALCULATE(
+ a=0,
+ b=3.14,
+ c="hello world",
+ d=True,
+ e=None,
+ f=decimal.Decimal("2.718281828"),
+ g=Timestamp("now"),
+ h=date(2024, 1, 1),
+ i=[1, 2, 4, 8, 10],
+ j=("SMALL", "LARGE"),
+)
+
+# Invoking aggregation functions on plural data
+Customers.CALCULATE(
+ n_packages=COUNT(packages),
+ home_has_had_packages_billed=HAS(current_address.billed_packages),
+ avg_package_cost=AVG(packages.package_cost),
+ n_states_shipped_to=NDISTINCT(packages.shipping_address.state),
+ most_recent_package_ordered=MAX(packages.order_date),
+)
+
+# Invoking window functions on singular data
+Customers.CALCULATE(
+ cust_ranking=RANKING(by=COUNT(packages).DESC()),
+ cust_percentile=PERCENTILE(by=COUNT(packages).DESC()),
+)
```
-
-### BACK
+See [the list of PyDough functions](funcitons.md) to see all of the builtin functions & operators that can be called in PyDough.
+
+
+### Down-Streaming
+
+Whenever an expression is defined inside of a `CALCULATE` call, it is available to all descendants of the current context using the same name. However, to avoid ambiguity, this means that descendants cannot invoke or create any properties they have that share a name with one of these terms from an ancestor `CALCULATE`. As a result, it is best practice to avoid using names in `CALCULATE` that exist elsewhere in the collections being used.
+
+However, only names that have been placed in a `CALCULATE` are available to descendant terms; any other properties of the current context are not made available to its descendants.
+
+There is a key caveat to the name conflict rule: it is ok to create a term with a name conflict so long as it is a no-op assignment. For example, `collection.CALCULATE(x=a+b, y=a-b).subcollection.CALCULATE(x, y=y)` is legal even though `x` and `y` are defined in both the ancestor and descendant `CALCULATE` clauses because the definitions in the descendant just re-use the ones from the ancestor without changing any information or aliases.
-Part of the benefit of doing `collection.subcollection` accesses is that properties from the ancestor collection can be accessed from the current collection. This is done via a `BACK` call. Accessing properties from `BACK(n)` can be done to access properties from the n-th ancestor of the current collection. The simplest recommended way to do this is to just access a scalar property of an ancestor in order to include it in the final answer.
**Good Example #1**: For each address's current occupants, list their first name last name, and the city/state of the current address they belong to.
```py
%%pydough
-Addresses.current_occupants(
+Addresses.CALCULATE(
+ current_city=city, current_state=state
+).current_occupants.CALCULATE(
first_name,
last_name,
- current_city=BACK(1).city,
- current_state=BACK(1).state,
+ current_city,
+ current_state=current_state,
)
```
@@ -394,153 +492,92 @@ Addresses.current_occupants(
```py
%%pydough
-package_info = Addresses.current_occupants.packages(
- is_shipped_to_current_addr=shipping_address.address_id == BACK(2).address_id
+package_info = Addresses.CALCULATE(
+ first_address_id=address_id
+).current_occupants.packages.CALCULATE(
+ is_shipped_to_current_addr=shipping_address.address_id == first_address_id
)
-GRAPH(n_cases=SUM(package_info.is_shipped_to_current_addr))
+GRAPH.CALCULATE(n_cases=SUM(package_info.is_shipped_to_current_addr))
```
**Good Example #3**: Indicate whether a package is above the average cost for all packages ordered by that customer.
```py
%%pydough
-Customers(
+Customers.CALCULATE(
avg_package_cost=AVG(packages.cost)
-).packages(
- is_above_avg=cost > BACK(1).avg_package_cost
+).packages.CALCULATE(
+ is_above_avg=cost > avg_package_cost
)
```
-**Good Example #4**: For every customer, indicate what percentage of all packages billed to their current address were purchased by that same customer.
+**Good Example #4**: For every customer, indicate what percentage of all packages billed to their current address were purchased by that same customer (see [WHERE](#where) for more details).
```py
%%pydough
-aug_packages = packages(
- include=IFF(billing_address.address_id == BACK(2).address_id, 1, 0)
+packages_billed_home = packages.WHERE(
+ billing_address.address_id == original_address
)
-Addresses(
- n_packages=COUNT(packages_billed_to)
-).current_occupants(
+People.CALCULATE(
+ original_address=current_address.address_id,
+ n_packages=COUNT(current_address.packages_billed_to),
+).CALCULATE(
ssn,
- pct=100.0 * SUM(aug_packages.include) / BACK(1).n_packages
+ pct=100.0 * COUNT(packages_billed_home) / n_packages
)
```
-**Bad Example #1**: The `GRAPH` does not have any ancestors, so `BACK(1)` is invalid.
+**Bad Example #1**: `GRAPH` does not have a term named `foo`, nor does it have an ancestor, so there is no ancestor term that can be down-streamed.
```py
%%pydough
-GRAPH(x=BACK(1).foo)
+GRAPH.CALCULATE(x=foo)
```
-**Bad Example #2**: The 1st ancestor of `People` is `GRAPH` which does not have a term named `bar`.
+**Bad Example #2**: The only ancestor of `People` is `GRAPH` which does not have a term named `bar`.
```py
%%pydough
-People(y=BACK(1).bar)
+People.CALCULATE(y=bar)
```
-**Bad Example #3**: The 1st ancestor of `People` is `GRAPH` which does not have an ancestor. Therefore, `People` cannot have a 2nd ancestor.
+**Bad Example #3**: Even though `email` is a property of `People`, which is an ancestor of `packages`, it was not included in a `CALCULATE` of `People`, so it cannot be accessed by `packages`.
```py
%%pydough
-People(z=BACK(2).fizz)
+People.packages.CALCULATE(email)
```
-**Bad Example #4**: The 1st ancestor of `current_address` is `People` which does not have a term named `phone`.
+**Bad Example #4**: This time, `email` was placed in a `CALCULATE`, but it was given a different name `my_email` which means that `my_email` has to be used to access it, instead of `email`.
```py
%%pydough
-People.current_address(a=BACK(1).phone)
+People.CALCULATE(my_email=email).packages.CALCULATE(email)
```
-**Bad Example #5**: Even though `cust_info` has defined `avg_package_cost`, the final expression `Customers.packages(...)` does not have `cust_info` as an ancestor, so it cannot access `BACK(1).avg_package_cost` since its 1st ancestor (`Customers`) does not have any term named `avg_package_cost`.
+**Bad Example #5**: Even though `cust_info` has defined `avg_package_cost`, the final expression `Customers.packages.CALCULATE(...)` does not have `cust_info` as an ancestor, so it cannot access `avg_package_cost` since it is not part of its own ancestry.
```py
%%pydough
-cust_info = Customers(
+cust_info = Customers.CALCULATE(
avg_package_cost=AVG(packages.cost)
)
-Customers.packages(
- is_above_avg=cost > BACK(1).avg_package_cost
+Customers.packages.CALCULATE(
+ is_above_avg=cost > avg_package_cost
)
```
-
-### Expressions
-
-So far, many different kinds of expressions have been noted in the examples for CALC, BACK, and contextless expressions. The following are examples & explanations of the various types of valid expressions:
+**Bad Example #6**: The `CALCULATE` defines a term `zip_code`, which has a conflict since `zip_code` also exists as a property that is invoked by `current_address`. When `zip_code` is invoked, PyDough does not know whether it is referring to the `zip_code` property of `current_addresses` or the `zip_code` property defined in the `CALCULATE`.
```py
-# Referencing scalar properties of the current collection
-People(
- first_name,
- last_name
-)
-
-# Referencing scalar properties of a singular sub-collection
-People(
- current_state=current_address.state,
- current_state=current_address.state,
-)
-
-# Referencing scalar properties of an ancestor collection
-Addresses.current_occupants.packages(
- customer_email=BACK(1).email,
- customer_zip_code=BACK(2).zip_code,
-)
-
-# Invoking normal functions/operations on other singular data
-Customers(
- lowered_name=LOWER(name),
- normalized_birth_month=MONTH(birth_date) - 1,
- lives_in_c_state=STARTSWITH(current_address.state, "C"),
-)
-
-# Supported Python literals:
-# - integers
-# - floats
-# - strings
-# - booleans
-# - None
-# - decimal.Decimal
-# - pandas.Timestamp
-# - datetime.date
-# - lists/tuples of literals
-from pandas import Timestamp
-from datetime import date
-from decimal import Decimal
-Customers(
- a=0,
- b=3.14,
- c="hello world",
- d=True,
- e=None,
- f=decimal.Decimal("2.718281828"),
- g=Timestamp("now"),
- h=date(2024, 1, 1),
- i=[1, 2, 4, 8, 10],
- j=("SMALL", "LARGE"),
-)
-
-# Invoking aggregation functions on plural data
-Customers(
- n_packages=COUNT(packages),
- home_has_had_packages_billed=HAS(current_address.billed_packages),
- avg_package_cost=AVG(packages.package_cost),
- n_states_shipped_to=NDISTINCT(packages.shipping_address.state),
- most_recent_package_ordered=MAX(packages.order_date),
-)
-
-# Invoking window functions on singular
-Customers(
- cust_ranking=RANKING(by=COUNT(packages).DESC()),
- cust_percentile=PERCENTILE(by=COUNT(packages).DESC()),
+%%pydough
+People.CALCULATE(
+ zip_code=15213
+).current_address.CALCULATE(
+ is_chosen_zip_code = zip_code == 55555
)
```
-See [the list of PyDough functions](funcitons.md) to see all of the builtin functions & operators that can be called in PyDough.
-
## Collection Operators
@@ -549,27 +586,27 @@ So far all of the examples shown have been about accessing collections/sub-colle
### WHERE
-A core PyDough operation is the ability to filter the records of a collection. This is done by appending a PyDough collection with `.WHERE(cond)` where `cond` is any expression that could have been placed in a `CALC` term and should have a True/False value. Every record where `cond` evaluates to True will be preserved, and the rest will be dropped from the answer. The terms in the collection are unchanged by the `WHERE` clause, since the only change is which records are kept/dropped.
+A core PyDough operation is the ability to filter the records of a collection. This is done by appending a PyDough collection with `.WHERE(cond)` where `cond` is any expression that could have been placed in a `CALCULATE` term and should have a True/False value. Every record where `cond` evaluates to True will be preserved, and the rest will be dropped from the answer. The terms in the collection are unchanged by the `WHERE` clause, since the only change is which records are kept/dropped.
**Good Example #1**: For every person who has a middle name and and email that ends with `"gmail.com"`, fetches their first name and last name.
```py
%%pydough
-People.WHERE(PRESENT(middle_name) & ENDSWITH(email, "gmail.com"))(first_name, last_name)
+People.WHERE(PRESENT(middle_name) & ENDSWITH(email, "gmail.com")).CALCULATE(first_name, last_name)
```
**Good Example #2**: For every package where the package cost is greater than 100, fetches the package id and the state it was shipped to.
```py
%%pydough
-Packages.WHERE(package_cost > 100)(package_id, shipping_state=shipping_address.state)
+Packages.WHERE(package_cost > 100).CALCULATE(package_id, shipping_state=shipping_address.state)
```
**Good Example #3**: For every person who has ordered more than 5 packages, fetches their first name, last name, and email.
```py
%%pydough
-People(first_name, last_name, email).WHERE(COUNT(packages) > 5)
+People.CALCULATE(first_name, last_name, email).WHERE(COUNT(packages) > 5)
```
**Good Example #4**: Find every person whose most recent order was shipped in the year 2023, and list all properties of that person.
@@ -586,7 +623,7 @@ People.WHERE(YEAR(MAX(packages.order_date)) == 2023)
packages_jan_2018 = Packages.WHERE(
(YEAR(order_date) == 2018) & (MONTH(order_date) == 1)
)
-GRAPH(n_jan_2018=COUNT(selected_packages))
+GRAPH.CALCULATE(n_jan_2018=COUNT(selected_packages))
```
**Good Example #6**: Count how many people have don't have a first or last name that starts with A. [See here](functions.md#logical) for more details on the valid/invalid use of logical operations in Python.
@@ -596,7 +633,7 @@ GRAPH(n_jan_2018=COUNT(selected_packages))
selected_people = People.WHERE(
~STARTSWITH(first_name, "A") & ~STARTSWITH(first_name, "B")
)
-GRAPH(n_people=COUNT(selected_people))
+GRAPH.CALCULATE(n_people=COUNT(selected_people))
```
**Good Example #7**: Count how many people have a gmail or yahoo account. [See here](functions.md#logical) for more details on the valid/invalid use of logical operations in Python.
@@ -606,7 +643,7 @@ GRAPH(n_people=COUNT(selected_people))
gmail_or_yahoo = People.WHERE(
ENDSWITH(email, "@gmail.com") | ENDSWITH(email, "@yahoo.com")
)
-GRAPH(n_gmail_or_yahoo=COUNT(gmail_or_yahoo))
+GRAPH.CALCULATE(n_gmail_or_yahoo=COUNT(gmail_or_yahoo))
```
**Good Example #8**: Count how many people were born in the 1980s. [See here](functions.md#comparisons) for more details on the valid/invalid use of comparisons in Python.
@@ -616,7 +653,7 @@ GRAPH(n_gmail_or_yahoo=COUNT(gmail_or_yahoo))
eighties_babies = People.WHERE(
(1980 <= YEAR(birth_date)) & (YEAR(birth_date) < 1990)
)
-GRAPH(n_eighties_babies=COUNT(eighties_babies))
+GRAPH.CALCULATE(n_eighties_babies=COUNT(eighties_babies))
```
**Good Example #9**: Find every person whose has sent a package to Idaho.
@@ -637,35 +674,35 @@ People.WHERE(HASNOT(packages.WHERE(YEAR(order_date) == 2024)))
```py
%%pydough
-People.WHERE(PRESENT(phone_number))(first_name, last_name)
+People.WHERE(PRESENT(phone_number)).CALCULATE(first_name, last_name)
```
**Bad Example #2**: For every package, fetches the package id only if the package cost is greater than 100 and the shipping state is Texas. This is invalid because `and` is used instead of `&`. [See here](functions.md#logical) for more details on the valid/invalid use of logical operations in Python.
```py
%%pydough
-Packages.WHERE((package_cost > 100) and (shipping_address.state == "TX"))(package_id)
+Packages.WHERE((package_cost > 100) and (shipping_address.state == "TX")).CALCULATE(package_id)
```
**Bad Example #3**: For every package, fetches the package id only if the package is either being shipped from Pennsylvania or to Pennsylvania. This is invalid because `or` is used instead of `|`. [See here](functions.md#logical) for more details on the valid/invalid use of logical operations in Python.
```py
%%pydough
-Packages.WHERE((customer.current_address.state == "PA") or (shipping_address.state == "PA"))(package_id)
+Packages.WHERE((customer.current_address.state == "PA") or (shipping_address.state == "PA")).CALCULATE(package_id)
```
**Bad Example #4**: For every package, fetches the package id only if the customer's first name does not start with a J. This is invalid because `not` is used instead of `~`. [See here](functions.md#logical) for more details on the valid/invalid use of logical operations in Python.
```py
%%pydough
-Packages.WHERE(not STARTSWITH(customer.first_name, "J"))(package_id)
+Packages.WHERE(not STARTSWITH(customer.first_name, "J")).CALCULATE(package_id)
```
**Bad Example #5**: For every package, fetches the package id only if the package was ordered between February and May. [See here](functions.md#comparisons) for more details on the valid/invalid use of comparisons in Python.
```py
%%pydough
-Packages.WHERE(2 <= MONTH(arrival_date) <= 5)(package_id)
+Packages.WHERE(2 <= MONTH(arrival_date) <= 5).CALCULATE(package_id)
```
**Bad Example #6**: Obtain every person whose packages were shipped in the month of June. This is invalid because `packages` is a plural property of `People`, so `MONTH(packages.order_date) == 6` is a plural expression with regards to `People` that cannot be used as a filtering condition.
@@ -678,7 +715,7 @@ People.WHERE(MONTH(packages.order_date) == 6)
### ORDER_BY
-Another operation that can be done onto PyDough collections is sorting them. This is done by appending a collection with `.ORDER_BY(...)` which will order the collection by the collation terms between the parenthesis. The collation terms must be 1+ expressions that can be inside of a CALC term (singular expressions with regards to the current context), each decorated with information making it usable as a collation.
+Another operation that can be done onto PyDough collections is sorting them. This is done by appending a collection with `.ORDER_BY(...)` which will order the collection by the collation terms between the parenthesis. The collation terms must be 1+ expressions that can be inside of a `CALCULATE` term (singular expressions with regards to the current context), each decorated with information making it usable as a collation.
An expression becomes a collation expression when it is appended with `.ASC()` (indicating that the expression should be used to sort in ascending order) or `.DESC()` (indicating that the expression should be used to sort in descending order). Both `.ASC()` and `.DESC()` take in an optional argument `na_pos` indicating where to place null values. This keyword argument can be either `"first"` or `"last"`, and the default is `"first"` for `.ASC()` and `"last"` for `.DESC()`. The way the sorting works is that it orders by the first collation term provided, and in cases of ties it moves on to the second collation term, and if there are ties in that it moves on to the third, and so on until there are no more terms to sort by, at which point the ties are broken arbitrarily.
@@ -698,7 +735,7 @@ People.ORDER_BY(last_name.ASC(), first_name.ASC(), middle_name.ASC(na_pos="last"
```py
%%pydough
-People(
+People.CALCULATE(
ssn, n_packages=COUNT(packages).DESC()
).ORDER_BY(
n_packages.DESC(), birth_date.ASC()
@@ -727,15 +764,15 @@ People.ORDER_BY(
)
```
-**Good Example #5**: Same as good example #4, but written so it only includes people who are current occupants of an address in Ohio.
+**Good Example #5**: Same as good example #4, but written so it only includes people who are current occupants of an address in Ohio, and accesses the state/city via down-streaming.
```py
%%pydough
-Addresses.WHERE(
+Addresses.CALCULATE(state, city).WHERE(
state == "OHIO"
).current_occupants.ORDER_BY(
- BACK(1).state.ASC(),
- BACK(1).city.ASC(),
+ state.ASC(),
+ city.ASC(),
ssn.ASC(),
)
```
@@ -761,40 +798,27 @@ People.ORDER_BY(account_balance.DESC())
Addresses.ORDER_BY(current_occupants.ASC())
```
-**Bad Example #3**: Same as good example #5, but incorrect because `BACK(2)` is used, and `BACK(2)` refers to the 2nd ancestor of `current_occupants` which is `GRAPH`, which does not have any properties named `state` or `city`.
-
-```py
-%%pydough
-Addresses.WHERE(
- state == "OHIO"
-).current_occupants.ORDER_BY(
- BACK(2).state.ASC(),
- BACK(2).city.ASC(),
- ssn.ASC(),
-)
-```
-
-**Bad Example #4**: Same as bad example #3, but incorrect because `BACK(3)` is used, and `BACK(3)` refers to the 3rd ancestor of `current_occupants` which does not exist because the 2nd ancestor is `GRAPH`, which does not have any ancestors.
+**Bad Example #3**: Same as good example #5, but incorrect because `state` and `city` were not made available for down-streaming.
```py
%%pydough
Addresses.WHERE(
state == "OHIO"
).current_occupants.ORDER_BY(
- BACK(2).state.ASC(),
- BACK(2).city.ASC(),
+ state.ASC(),
+ city.ASC(),
ssn.ASC(),
)
```
-**Bad Example #5**: Sort every person by their first name. This is invalid because no `.ASC()` or `.DESC()` term is provided.
+**Bad Example #4**: Sort every person by their first name. This is invalid because no `.ASC()` or `.DESC()` term is provided.
```py
%%pydough
People.ORDER_BY(first_name)
```
-**Bad Example #6**: Sort every person. This is invalid because no collation terms are provided.
+**Bad Example #5**: Sort every person. This is invalid because no collation terms are provided.
```py
%%pydough
@@ -814,7 +838,7 @@ The terms in the collection are unchanged by the `TOP_K` clause, since the only
```py
%%pydough
-People(
+People.CALCULATE(
first_name,
last_name,
birth_date,
@@ -844,7 +868,7 @@ Addresses.TOP_K(10, by=most_recent_package.DESC())
```py
%%pydough
-People(
+People.CALCULATE(
first_name,
last_name,
total_package_cost=SUM(packages.package_cost)
@@ -862,7 +886,7 @@ People.TOP_K(5, by=gpa.ASC())
```py
%%pydough
-Addresses.packages_billed(25, by=gpa.packages_billed.arrival_date())
+Addresses.packages_billed.CALCULATE(25, by=gpa.packages_billed.arrival_date())
```
**Bad Example #3**: Find the top 100 people currently living in the city of San Francisco. This is invalid because the `by` clause is absent.
@@ -905,22 +929,22 @@ The syntax for this is `PARTITION(data, name="...", by=...)`. The `data` argumen
> [!WARNING]
> PyDough currently only supports using references to scalar expressions from the `data` collection itself as partition keys, not an ancestor term, or a term from a child collection, or the result of a function call.
-If the partitioned data is accessed, its original ancestry is lost. Instead, it inherits the ancestry of the `PARTITION` clause, so `BACK(1)` is the `PARTITION` clause, and `BACK(2)` is the ancestor of the partition clause (the default is the entire graph, just like collections).
+If the partitioned data is accessed, its original ancestry is lost. Instead, it inherits the ancestry from the `PARTITION` clause. The default ancestor of `PARTITION`, if not specified, is the entire graph (just like for table collections). The partitioned data still has access to any of the down-streamed terms from its original ancestry.
-The ancestry of the `PARTITION` clause can be changed by prepending it with another collection, separated by a dot. However, this is currently only supported in PyDough when the collection before the dot is just an augmented version of the graph context, as opposed to another collection (e.g. `GRAPH(x=42).PARTITION(...)` is supported, but `People.PARTITION(...)` is not).
+The ancestry of the `PARTITION` clause can be changed by prepending it with another collection, separated by a dot. However, this is currently only supported in PyDough when the collection before the dot is just an augmented version of the graph context, as opposed to another collection (e.g. `GRAPH.CALCULATE(x=42).PARTITION(...)` is supported, but `People.PARTITION(...)` is not).
**Good Example #1**: Find every unique state.
```py
%%pydough
-PARTITION(Addresses, name="addrs", by=state)(state)
+PARTITION(Addresses, name="addrs", by=state).CALCULATE(state)
```
**Good Example #2**: For every state, count how many addresses are in that state.
```py
%%pydough
-PARTITION(Addresses, name="addrs", by=state)(
+PARTITION(Addresses, name="addrs", by=state).CALCULATE(
state,
n_addr=COUNT(addrs)
)
@@ -930,7 +954,7 @@ PARTITION(Addresses, name="addrs", by=state)(
```py
%%pydough
-PARTITION(Addresses, name="addrs", by=(city, state))(
+PARTITION(Addresses, name="addrs", by=(city, state)).CALCULATE(
state,
city,
n_people=COUNT(addrs.current_occupants)
@@ -941,24 +965,24 @@ PARTITION(Addresses, name="addrs", by=(city, state))(
```py
%%pydough
-yahoo_people = People(
+yahoo_people = People.CALCULATE(
birth_year=YEAR(birth_date)
).WHERE(ENDSWITH(email, "@yahoo.com"))
-PARTITION(yahoo_people, name="yah_ppl", by=birth_year)(
+PARTITION(yahoo_people, name="yah_ppl", by=birth_year).CALCULATE(
birth_year,
n_people=COUNT(yah_ppl)
).TOP_K(5, by=n_people.DESC())
```
-**Good Example #4**: For every year/month, find all packages that were below the average cost of all packages ordered in that year/month.
+**Good Example #4**: For every year/month, find all packages that were below the average cost of all packages ordered in that year/month. Notice how `packs` can access `avg_package_cost`, which was defined by its ancestor (at the `PARTITION` level).
```py
%%pydough
-package_info = Packages(order_year=YEAR(order_date), order_month=MONTH(order_date))
-PARTITION(package_info, name="packs", by=(order_year, order_month))(
+package_info = Packages.CALCULATE(order_year=YEAR(order_date), order_month=MONTH(order_date))
+PARTITION(package_info, name="packs", by=(order_year, order_month)).CALCULATE(
avg_package_cost=AVG(packs.package_cost)
).packs.WHERE(
- package_cost < BACK(1).avg_package_cost
+ package_cost < avg_package_cost
)
```
@@ -966,26 +990,26 @@ PARTITION(package_info, name="packs", by=(order_year, order_month))(
```py
%%pydough
-PARTITION(Addresses, name="addrs", by=(city, state))(
+PARTITION(Addresses, name="addrs", by=(city, state)).CALCULATE(
total_packages=COUNT(addrs.current_occupants.packages)
-).addrs.current_occupants(
+).addrs.CALCULATE(city, state).current_occupants.CALCULATE(
first_name,
last_name,
- city=BACK(1).city,
- state=BACK(1).state,
- pct_of_packages=100.0 * COUNT(packages) / BACK(2).total_packages,
+ city=city,
+ state=state,
+ pct_of_packages=100.0 * COUNT(packages) / total_packages,
)
```
-**Good Example #6**: Identify the states whose current occupants account for at least 1% of all packages purchased. List the state and the percentage.
+**Good Example #6**: Identify the states whose current occupants account for at least 1% of all packages purchased. List the state and the percentage. Notice how `total_packages` is down-streamed from the graph-level `CALCULATE`.
```py
%%pydough
-GRAPH(
+GRAPH.CALCULATE(
total_packages=COUNT(Packages)
-).PARTITION(Addresses, name="addrs", by=state)(
+).PARTITION(Addresses, name="addrs", by=state).CALCULATE(
state,
- pct_of_packages=100.0 * COUNT(addrs.current_occupants.package) / BACK(1).packages
+    pct_of_packages=100.0 * COUNT(addrs.current_occupants.packages) / total_packages
).WHERE(pct_of_packages >= 1.0)
```
@@ -993,26 +1017,25 @@ GRAPH(
```py
%%pydough
-pack_info = Packages(order_month=MONTH(order_date))
-month_info = PARTITION(pack_info, name="packs", by=order_month)(
+pack_info = Packages.CALCULATE(order_month=MONTH(order_date))
+month_info = PARTITION(pack_info, name="packs", by=order_month).CALCULATE(
n_packages=COUNT(packs)
)
-GRAPH(
+GRAPH.CALCULATE(
avg_packages_per_month=AVG(month_info.n_packages)
-).PARTITION(pack_info, name="packs", by=order_month)(
+).PARTITION(pack_info, name="packs", by=order_month).CALCULATE(
month,
-).WHERE(COUNT(packs) > BACK(1).avg_packages_per_month)
+).WHERE(COUNT(packs) > avg_packages_per_month)
```
-**Good Example #8**: Find the 10 most frequent combinations of the state that the person lives in and the first letter of that person's name.
+**Good Example #8**: Find the 10 most frequent combinations of the state that the person lives in and the first letter of that person's name. Notice how `state` can be used as a partition key of `people_info` since it was made available via down-streaming.
```py
%%pydough
-people_info = Addresses.current_occupants(
- state=BACK(1).state,
+people_info = Addresses.CALCULATE(state).current_occupants.CALCULATE(
first_letter=first_name[:1],
)
-PARTITION(people_info, name="ppl", by=(state, first_letter))(
+PARTITION(people_info, name="ppl", by=(state, first_letter)).CALCULATE(
state,
first_letter,
n_people=COUNT(ppl),
@@ -1023,17 +1046,57 @@ PARTITION(people_info, name="ppl", by=(state, first_letter))(
```py
%%pydough
-people_info = People(
+people_info = People.CALCULATE(
state=DEFALT_TO(current_address.state, "N/A"),
first_letter=first_name[:1],
)
-PARTITION(people_info, name="ppl", by=(state, first_letter))(
+PARTITION(people_info, name="ppl", by=(state, first_letter)).CALCULATE(
state,
first_letter,
n_people=COUNT(ppl),
).TOP_K(10, by=n_people.DESC())
```
+**Good Example #10**: Partition the current occupants of each address by their birth year and filter to include individuals born in years with at least 10,000 births. For each such person, list their first/last name and the state they live in. This is valid because `state` was down-streamed to `people_info` before it was partitioned, so when `ppl` is accessed, it still has access to `state`.
+
+```py
+%%pydough
+people_info = Addresses.CALCULATE(state).current_occupants.CALCULATE(birth_year=YEAR(birth_date))
+GRAPH.PARTITION(people_info, name="ppl", by=birth_year).WHERE(
+    COUNT(ppl) >= 10000
+).ppl.CALCULATE(
+ first_name,
+ last_name,
+ state
+)
+```
+
+**Good Example #11**: Find all packages that meet the following criteria: they were ordered in the last year that any package in the database was ordered, their cost was below the average of all packages ever ordered, and the state it was shipped to received at least 10,000 packages that year.
+
+```py
+%%pydough
+package_info = Packages.CALCULATE(
+ order_year=YEAR(order_date),
+ shipping_state=shipping_address.state
+)
+GRAPH.CALCULATE(
+ avg_cost=AVG(package_info.package_cost),
+ final_year=MAX(package_info.order_year),
+).PARTITION(
+ package_info.WHERE(order_year == final_year),
+ name="packs",
+ by=shipping_state
+).WHERE(
+    COUNT(packs) >= 10000
+).packs.WHERE(
+ package_cost < avg_cost
+).CALCULATE(
+ shipping_state,
+ package_id,
+ order_date,
+)
+```
+
**Bad Example #1**: Partition a collection `Products` that does not exist in the graph.
```py
@@ -1055,83 +1118,59 @@ PARTITION(Addresses, by=state)
PARTITION(People, name="ppl")
```
-**Bad Example #4**: Count how many packages were ordered in each year. Invalid because `YEAR(order_date)` is not allowed ot be used as a partition term (it must be placed in a CALC so it is accessible as a named reference).
+**Bad Example #4**: Count how many packages were ordered in each year. Invalid because `YEAR(order_date)` is not allowed to be used as a partition term (it must be placed in a `CALCULATE` so it is accessible as a named reference).
```py
%%pydough
-PARTITION(Packages, name="packs", by=YEAR(order_date))(
+PARTITION(Packages, name="packs", by=YEAR(order_date)).CALCULATE(
n_packages=COUNT(packages)
)
```
-**Bad Example #5**: Count how many people live in each state. Invalid because `current_address.state` is not allowed to be used as a partition term (it must be placed in a CALC so it is accessible as a named reference).
+**Bad Example #5**: Count how many people live in each state. Invalid because `current_address.state` is not allowed to be used as a partition term (it must be placed in a `CALCULATE` so it is accessible as a named reference).
```py
%%pydough
-PARTITION(People, name="ppl", by=current_address.state)(
+PARTITION(People, name="ppl", by=current_address.state).CALCULATE(
n_packages=COUNT(packages)
)
```
-**Bad Example #6**: Invalid version of good example #8 that did not use a CALC to get rid of the `BACK` or `first_name[:1]`, which cannot be used as partition terms.
+**Bad Example #6**: Invalid version of good example #8 that did not use a `CALCULATE` to make `state` available via down-streaming or to bind `first_name[:1]` to a name, therefore neither can be used as a partition term.
```py
%%pydough
-PARTITION(Addresses.current_occupants, name="ppl", by=(BACK(1).state, first_name[:1]))(
- BACK(1).state,
+PARTITION(Addresses.current_occupants, name="ppl", by=(state, first_name[:1])).CALCULATE(
+ state,
first_name[:1],
n_people=COUNT(ppl),
).TOP_K(10, by=n_people.DESC())
```
-**Bad Example #7**: Partition people by their birth year to find the number of people born in each year. Invalid because the `email` property is referenced, which is not one of the properties accessible by the partition.
+**Bad Example #7**: Partition people by their birth year to find the number of people born in each year. Invalid because the `email` property is referenced, which is not one of the partition keys, even though the data being partitioned does have an `email` property.
```py
%%pydough
-PARTITION(People(birth_year=YEAR(birth_date)), name="ppl", by=birth_year)(
+PARTITION(People.CALCULATE(birth_year=YEAR(birth_date)), name="ppl", by=birth_year).CALCULATE(
birth_year,
email,
n_people=COUNT(ppl)
)
```
-**Bad Example #7**: For each person & year, count how many times that person ordered a packaged in that year. This is invalid because doing `.PARTITION` after `People` is unsupported, since `People` is not a graph-level collection like `GRAPH(...)`.
+**Bad Example #8**: For each person & year, count how many times that person ordered a package in that year. This is invalid because doing `.PARTITION` after `People` is unsupported, since `People` is not a graph-level collection like `GRAPH.CALCULATE(...)`.
```py
%%pydough
-People.PARTITION(packages(year=YEAR(order_date)), name="p", by=year)(
- ssn=BACK(1).ssn,
+People.CALCULATE(ssn).PARTITION(
+ packages.CALCULATE(year=YEAR(order_date)), name="p", by=year
+).CALCULATE(
+ ssn=ssn,
year=year,
n_packs=COUNT(p)
)
```
-**Bad Example #8**: Partition each address' current occupants by their birth year to get the number of people per birth year. This is invalid because the example includes a field `BACK(2).bar` which does not exist because the first ancestor of the partition is `GRAPH`, which does not have a second ancestor.
-
-```py
-%%pydough
-people_info = Addresses.current_occupants(birth_year=YEAR(birth_date))
-GRAPH.PARTITION(people_info, name="p", by=birth_year)(
- birth_year,
- n_people=COUNT(p),
- foo=BACK(2).bar,
-)
-```
-
-**Bad Example #9**: Partition the current occupants of each address by their birth year and filters to include only those born in years with at least 10,000 births. It then gets more information about people from those years. This query is invalid because, after accessing `.ppl`, the term `BACK(1).state` is used. This is not valid because, although the data that `.ppl` refers to (`people_info`) originally had access to `BACK(1).state`, that ancestry information was lost after partitioning `people_info`. Instead, `BACK(1)` now refers to the `PARTITION` clause, which does not have a state field.
-
-```py
-%%pydough
-people_info = Addresses.current_occupants(birth_year=YEAR(birth_date))
-GRAPH.PARTITION(people_info, name="ppl", by=birth_year).WHERE(
- COUNT(p) >= 10000
-).ppl(
- first_name,
- last_name,
- state=BACK(1).state,
-)
-```
-
### SINGULAR
@@ -1147,7 +1186,7 @@ Certain PyDough operations, such as specific filters, can cause plural data to b
most_recent_package = packages.WHERE(
RANKING(by=order_date.DESC(), levels=1) == 1
).SINGULAR()
-People(
+People.CALCULATE(
ssn,
first_name,
middle_name,
@@ -1165,7 +1204,7 @@ js = current_occupants.WHERE(
(last_name == "Smith") &
ABSENT(middle_name)
).SINGULAR()
-Addresses(
+Addresses.CALCULATE(
address_id,
john_smith_email=DEFAULT_TO(js.email, "NO JOHN SMITH LIVING HERE")
)
@@ -1177,7 +1216,7 @@ Addresses(
> [!IMPORTANT]
> This feature has not yet been implemented in PyDough
-In PyDough, it is also possible to access data from other records in the same collection that occur before or after the current record, when all the records are sorted. Similar to how `BACK(n)` can be used as a collection to access terms from an ancestor context, `PREV(n, by=...)` can be used to access terms from another record of the same context, specifically the record obtained by ordering by the `by` terms then looking for the record `n` entries before the current record. Similarly, `NEXT(n, ...)` is the same as `PREV(-n, ...)`.
+In PyDough, it is also possible to access data from other records in the same collection that occur before or after the current record, when all the records are sorted. Similar to how down-streaming can be used to access terms from an ancestor context, `PREV(n, by=...)` can be used as a collection to access terms from another record of the same context, specifically the record obtained by ordering by the `by` terms then looking for the record `n` entries before the current record. Similarly, `NEXT(n, ...)` is the same as `PREV(-n, ...)`.
The arguments to `NEXT` and `PREV` are as follows:
- `n` (optional): how many records before/after the current record to look. The default value is 1.
@@ -1190,7 +1229,7 @@ If the entry `n` records before/after the current entry does not exist, then acc
```py
%%pydough
-Packages(
+Packages.CALCULATE(
package_id,
same_customer_as_prev_package=customer_ssn == PREV(by=order_date.ASC()).ssn
)
@@ -1201,10 +1240,10 @@ Packages(
```py
%%pydough
prev_package = PREV(by=order_date.ASC(), levels=1)
-package_deltas = packages(
+package_deltas = packages.CALCULATE(
hour_difference=DATEDIFF('hours', order_date, prev_package.order_date)
)
-Customers(
+Customers.CALCULATE(
ssn,
avg_hours_between_purchases=AVG(package_deltas.hour_difference)
)
@@ -1217,7 +1256,7 @@ Customers(
first_after = NEXT(1, by=COUNT(packages).DESC())
second_after = NEXT(2, by=COUNT(packages).DESC())
third_after = NEXT(3, by=COUNT(packages).DESC())
-Customers(
+Customers.CALCULATE(
ssn,
same_state_as_order_neighbors=(
DEFAULT_TO(current_address.state == first_after.current_address.state, False) |
@@ -1231,7 +1270,7 @@ Customers(
```py
%%pydough
-Packages(
+Packages.CALCULATE(
hour_difference=DATEDIFF('hours', order_date, PREV().order_date)
)
```
@@ -1240,7 +1279,7 @@ Packages(
```py
%%pydough
-Packages(
+Packages.CALCULATE(
hour_difference=DATEDIFF('hours', order_date, NEXT(by=()).order_date)
)
```
@@ -1249,7 +1288,7 @@ Packages(
```py
%%pydough
-Packages(
+Packages.CALCULATE(
hour_difference=DATEDIFF('hours', order_date, PREV(5, by=order_date).order_date)
)
```
@@ -1258,7 +1297,7 @@ Packages(
```py
%%pydough
-Packages(
+Packages.CALCULATE(
hour_difference=DATEDIFF('hours', order_date, NEXT("ten", by=order_date.ASC()).order_date)
)
```
@@ -1267,7 +1306,7 @@ Packages(
```py
%%pydough
-Packages(
+Packages.CALCULATE(
hour_difference=DATEDIFF('hours', order_date, PREV(1, by=order_date.ASC()))
)
```
@@ -1276,7 +1315,7 @@ Packages(
```py
%%pydough
-Packages(
+Packages.CALCULATE(
hour_difference=DATEDIFF('hours', order_date, PREV(1, by=order_date.ASC()).odate)
)
```
@@ -1292,7 +1331,7 @@ Packages.PREV(order_date.ASC())
```py
%%pydough
-Customers.packages(
+Customers.packages.CALCULATE(
hour_difference=DATEDIFF('hours', order_date, PREV(1, by=order_date.ASC(), levels=5).order_date)
)
```
@@ -1307,7 +1346,7 @@ PyDough supports identifying a specific record from a sub-collection that is opt
A call to `BEST` can either be done with `.` syntax, to step from a parent collection to a child collection, or can be a freestanding accessor used inside of a collection operator, just like `BACK`, `PREV` or `NEXT`. For example, both `Parent.BEST(child, by=...)` and `Parent(x=BEST(child, by=...).y)` are allowed.
-The original ancestry of the sub-collection is intact. So, if doing `A.BEST(b.c.d, by=...)`, `BACK(1)` revers to `c`, `BACK(2)` refers to `b` and `BACK(3)` refers to `A`.
+The original ancestry of the sub-collection is intact, so any down-streaming is preserved.
Additional keyword arguments can be supplied to `BEST` that change its behavior:
- `allow_ties` (default=False): if True, changes the behavior to keep all records of the sub-collection that share the optimal values of the collation terms. If `allow_ties` is True, the `BEST` clause is no longer singular.
@@ -1317,7 +1356,7 @@ Additional keyword arguments can be supplied to `BEST` that change its behavior:
```py
%%pydough
-Customers.BEST(packages, by=order_date.ASC())(
+Customers.BEST(packages, by=order_date.ASC()).CALCULATE(
package_id,
shipping_address.zip_code
)
@@ -1327,7 +1366,7 @@ Customers.BEST(packages, by=order_date.ASC())(
```py
%%pydough
-Customers(
+Customers.CALCULATE(
ssn,
most_recent_cost=BEST(packages, by=order_date.DESC()).package_cost
)
@@ -1339,7 +1378,7 @@ Customers(
%%pydough
addr_info = Addresses.WHERE(
state == "NY"
-)(address_id, n_occupants=COUNT(current_occupants))
+).CALCULATE(address_id, n_occupants=COUNT(current_occupants))
GRAPH.BEST(addr_info, by=(n_occupants.DESC(), address_id.ASC()))
```
@@ -1348,7 +1387,7 @@ GRAPH.BEST(addr_info, by=(n_occupants.DESC(), address_id.ASC()))
```py
%%pydough
most_recent_package = BEST(packages, by=order_date.DESC())
-Customers(
+Customers.CALCULATE(
ssn,
n_occ_most_recent_addr=COUNT(most_recent_package.shipping_address.current_occupants)
)
@@ -1358,18 +1397,20 @@ Customers(
```py
%%pydough
-Addresses.WHERE(HAS(current_occupants))(
- n_occupants=COUNT(current_occupants)
+Addresses.WHERE(HAS(current_occupants)).CALCULATE(
+ city,
+ state,
+ n_occupants=COUNT(current_occupants),
).BEST(
- current_occupants(n_orders=COUNT(packages)),
+ current_occupants.CALCULATE(n_orders=COUNT(packages)),
by=(n_orders.DESC(), ssn.ASC())
-)(
+).CALCULATE(
first_name,
last_name,
n_orders,
- n_living_in_same_addr=BACK(1).n_occupants,
- city=BACK(1).city,
- state=BACK(1).state,
+ n_living_in_same_addr=n_occupants,
+ city=city,
+ state=state,
)
```
@@ -1378,20 +1419,21 @@ Addresses.WHERE(HAS(current_occupants))(
```py
%%pydough
five_most_recent = BEST(packages, by=order_date.DESC(), n_best=5)
-People(
+People.CALCULATE(
ssn,
value_most_recent_5=SUM(five_most_recent.package_cost)
)
```
-**Good Example #7**: For each address, find the package most recently ordered by one of the current occupants of that address, including the email of the occupant who ordered it and the address' id. Notice that `BACK(1)` refers to `current_occupants` and `BACK(2)` refers to `Addresses` as if the packages were accessed as `Addresses.current_occupants.packages` instead of using `BEST`.
+**Good Example #7**: For each address, find the package most recently ordered by one of the current occupants of that address, including the email of the occupant who ordered it and the address' id.
```py
%%pydough
-most_recent_package = BEST(current_occupants.packages, by=order_date.DESC())
-Addresses.most_recent_package(
- address_id=BACK(2).address_id,
- cust_email=BACK(1).email,
+packages_from_occupants = current_occupants.CALCULATE(email).packages
+most_recent_package = BEST(packages_from_occupants, by=order_date.DESC())
+Addresses.CALCULATE(address_id).most_recent_package.CALCULATE(
+ address_id,
+ email,
package_id=package_id,
order_date=order_date,
)
@@ -1401,7 +1443,7 @@ Addresses.most_recent_package(
```py
%%pydough
-People(first_name, BEST(email, by=birth_date.DESC()))
+People.CALCULATE(first_name, BEST(email, by=birth_date.DESC()))
```
**Bad Example #2**: For each person find their best package. This is invalid because the `by` argument is missing.
@@ -1432,39 +1474,39 @@ People.BEST(packages, by=())
People.BEST(packages, by=order_date.DESC(), n_best=5, allow_ties=True)
```
-**Bad Example #6**: For each person, find the package cost of their 10 most recent packages. This is invalid because `n_best` is greater than 1, which means that the `BEST` clause is non-singular so its terms cannot be accessed in the calc without aggregating.
+**Bad Example #6**: For each person, find the package cost of their 10 most recent packages. This is invalid because `n_best` is greater than 1, which means that the `BEST` clause is non-singular so its terms cannot be accessed in the `CALCULATE` without aggregating.
```py
%%pydough
best_packages = BEST(packages, by=order_date.DESC(), n_best=10)
-People(first_name, best_cost=best_packages.package_cost)
+People.CALCULATE(first_name, best_cost=best_packages.package_cost)
```
-**Bad Example #7**: For each person, find the package cost of their most expensive package(s), allowing ties. This is invalid because `allow_ties` is True, which means that the `BEST` clause is non-singular so its terms cannot be accessed in the calc without aggregating.
+**Bad Example #7**: For each person, find the package cost of their most expensive package(s), allowing ties. This is invalid because `allow_ties` is True, which means that the `BEST` clause is non-singular so its terms cannot be accessed in the `CALCULATE` without aggregating.
```py
%%pydough
best_packages = BEST(packages, by=package_cost.DESC(), allow_ties=True)
-People(first_name, best_cost=best_packages.package_cost)
+People.CALCULATE(first_name, best_cost=best_packages.package_cost)
```
-**Bad Example #8**: For each address, find the package most recently ordered by one of the current occupants of that address, including the address id of the address. This is invalid because `BACK(1)` refers to `current_occupants`, which does not have a field called `address_id`.
+**Bad Example #8**: For each address, find the package most recently ordered by one of the current occupants of that address, including the address id of the address. This is invalid because `address_id` is used despite not being down-streamed from `Addresses`.
```py
%%pydough
most_recent_package = BEST(current_occupants.packages, by=order_date.DESC())
-Addresses.most_recent_package(
- address_id=BACK(1).address_id,
+Addresses.most_recent_package.CALCULATE(
+ address_id=address_id,
package_id=package_id,
order_date=order_date,
)
```
-**Bad Example #9**: For each address find the oldest occupant. This is invalid because the `BEST` clause is placed in the calc without accessing any of its attributes.
+**Bad Example #9**: For each address find the oldest occupant. This is invalid because the `BEST` clause is placed in the `CALCULATE` without accessing any of its attributes.
```py
%%pydough
-Addresses(address_id, oldest_occupant=BEST(current_occupants, by=birth_date.ASC()))
+Addresses.CALCULATE(address_id, oldest_occupant=BEST(current_occupants, by=birth_date.ASC()))
```
@@ -1501,7 +1543,7 @@ The rest of the document are examples of questions asked about the data in the p
```py
%%pydough
# For each address, identify how many current occupants it has
-addr_info = Addresses(n_occupants=COUNT(current_occupants))
+addr_info = Addresses.CALCULATE(n_occupants=COUNT(current_occupants))
# Partition the addresses by the state, and for each state calculate the
# average value of `n_occupants` for all addresses in that state
@@ -1509,7 +1551,7 @@ states = PARTITION(
addr_info,
name="addrs",
by=state
-)(
+).CALCULATE(
state,
average_occupants=AVG(addrs.n_occupants)
)
@@ -1538,7 +1580,7 @@ to_east_coast = ISIN(shipping_address.state, east_coast_states)
# terms for if they are trans-coastal + the year they were ordered
package_info = Packages.WHERE(
PRESENT(arrival_date)
-)(
+).CALCULATE(
is_trans_coastal=from_west_coast & to_east_coast,
year=YEAR(order_date),
)
@@ -1549,7 +1591,7 @@ year_info = PARTITION(
package_info,
name="packs",
by=year,
-)(
+).CALCULATE(
year,
pct_trans_coastal=100.0 * SUM(packs.is_trans_coastal) / COUNT(packs),
)
@@ -1569,7 +1611,7 @@ result = year_info.ORDER_BY(year.ASC())
# Partition every address by the city/state
cities = PARTITION(
- Addresses,
+ Addresses.CALCULATE(city, state, zip_code),
name="addrs",
by=(city, state)
)
@@ -1579,11 +1621,11 @@ cities = PARTITION(
oldest_occupants = cities.BEST(
addrs.current_occupants.WHERE(HASNOT(packages)),
by=(birth_date.ASC(), ssn.ASC()),
-)(
- state=BACK(2).state,
- city=BACK(2).city,
- email=email,
- zip_code=BACK(1).zip_code,
+).CALCULATE(
+ state,
+ city,
+ email,
+ zip_code
)
# Sort the output by state, followed by city
@@ -1605,20 +1647,20 @@ result = oldest_occupants.ORDER_BY(
is_2017 = YEAR(order_date) == 2017
# Identify the average package cost of all packages ordered in 2017
-global_info = GRAPH(
+global_info = GRAPH.CALCULATE(
avg_package_cost=AVG(Packages.WHERE(is_2017).package_cost)
)
-# Identify all packages ordered in 2017, but where BACK(1) is global_info
-# instead of GRAPH, so we have access to global_info's terms.
+# Identify all packages ordered in 2017, but where the ancestor is global_info
+# instead of GRAPH, so `avg_package_cost` gets down-streamed.
selected_package = global_info.Packages.WHERE(is_2017)
# For each such package, identify the month it was ordered, and add a term to
# indicate if the cost of the package is at least 10x the average for all such
# packages.
-packages = selected_packages(
+package_info = selected_package.CALCULATE(
month=MONTH(order_date),
- is_10x_avg=package_cost >= (10.0 * BACK(1).avg_package_cost)
+ is_10x_avg=package_cost >= (10.0 * avg_package_cost)
)
# Partition the packages by the month they were ordered, and for each month
@@ -1628,7 +1670,7 @@ months = PARTITION(
package_info,
name="packs",
by=month
-)(
+).CALCULATE(
month,
pct_outliers=100.0 * SUM(packs.is_10x_avg) / COUNT(packs)
)
@@ -1649,28 +1691,28 @@ Note: uses the formula [discussed here](https://medium.com/swlh/linear-regressio
%%pydough
# Identify every year & how many packages were ordered that year
yearly_data = PARTITION(
- Packages(year=YEAR(order_date)),
+ Packages.CALCULATE(year=YEAR(order_date)),
name="packs",
by=year,
-)(
+).CALCULATE(
year,
n_orders = COUNT(packs),
)
# Obtain the global average of the year (x-coordinate) and
# n_orders (y-coordinate). These correspond to `x-bar` and `y-bar`.
-global_info = GRAPH(
+global_info = GRAPH.CALCULATE(
avg_x = AVG(yearly_data.year),
avg_y = AVG(yearly_data.n_orders),
)
# Contextless expression: corresponds to `x - x-bar` with regards to yearly_data
# inside of global_info
-dx = n_orders - BACK(1).avg_x
+dx = year - avg_x
# Contextless expression: corresponds to `y - y-bar` with regards to yearly_data
# inside of global_info
-dy = year - BACK(1).avg_y
+dy = n_orders - avg_y
# Contextless expression: derive the slope with regards to global_info
regression_data = yearly_data(value=(dx * dy) / (dx * dx))
diff --git a/documentation/functions.md b/documentation/functions.md
index b9e034a2..8f4afc79 100644
--- a/documentation/functions.md
+++ b/documentation/functions.md
@@ -85,7 +85,7 @@ Below is each binary operator currently supported in PyDough.
Supported mathematical operations: addition (`+`), subtraction (`-`), multiplication (`*`), division (`/`), exponentiation (`**`).
```py
-Lineitems(value = (extended_price * (1 - (discount ** 2)) + 1.0) / part.retail_price)
+Lineitems.CALCULATE(value = (extended_price * (1 - (discount ** 2)) + 1.0) / part.retail_price)
```
> [!WARNING]
@@ -98,7 +98,7 @@ Lineitems(value = (extended_price * (1 - (discount ** 2)) + 1.0) / part.retail_p
Expression values can be compared using standard comparison operators: `<=`, `<`, `==`, `!=`, `>` and `>=`:
```py
-Customers(
+Customers.CALCULATE(
in_debt = acctbal < 0,
at_most_12_orders = COUNT(orders) <= 12,
is_european = nation.region.name == "EUROPE",
@@ -121,7 +121,7 @@ Multiple boolean expression values can be logically combined with `&`, `|` and `
is_asian = nation.region.name == "ASIA"
is_european = nation.region.name == "EUROPE"
in_debt = acctbal < 0
-Customers(
+Customers.CALCULATE(
is_eurasian = is_asian | is_european,
is_not_eurasian = ~(is_asian | is_european),
is_european_in_debt = is_european & in_debt
@@ -144,7 +144,7 @@ Below is each unary operator currently supported in PyDough.
A numerical expression's sign can be flipped by prefixing it with the `-` operator:
```py
-Lineitems(lost_value = extended_price * (-discount))
+Lineitems.CALCULATE(lost_value = extended_price * (-discount))
```
@@ -160,7 +160,7 @@ Below are all other operators currently supported in PyDough that use other synt
A string expression can have a substring extracted with Python string slicing syntax `s[a:b:c]`:
```py
-Customers(
+Customers.CALCULATE(
country_code = phone[:3],
name_without_first_char = name[1:]
)
@@ -182,7 +182,7 @@ Below is each function currently supported in PyDough that operates on strings.
Calling `LOWER` on a string converts its characters to lowercase:
```py
-Customers(lowercase_name = LOWER(name))
+Customers.CALCULATE(lowercase_name = LOWER(name))
```
@@ -192,7 +192,7 @@ Customers(lowercase_name = LOWER(name))
Calling `UPPER` on a string converts its characters to uppercase:
```py
-Customers(uppercase_name = UPPER(name))
+Customers.CALCULATE(uppercase_name = UPPER(name))
```
@@ -202,7 +202,7 @@ Customers(uppercase_name = UPPER(name))
Calling `length` on a string returns the number of characters it contains:
```py
-Suppliers(n_chars_in_comment = LENGTH(comment))
+Suppliers.CALCULATE(n_chars_in_comment = LENGTH(comment))
```
@@ -212,7 +212,7 @@ Suppliers(n_chars_in_comment = LENGTH(comment))
The `STARTSWITH` function checks if its first argument begins with its second argument as a string prefix:
```py
-Parts(begins_with_yellow = STARTSWITH(name, "yellow"))
+Parts.CALCULATE(begins_with_yellow = STARTSWITH(name, "yellow"))
```
@@ -222,7 +222,7 @@ Parts(begins_with_yellow = STARTSWITH(name, "yellow"))
The `ENDSWITH` function checks if its first argument ends with its second argument as a string suffix:
```py
-Parts(ends_with_chocolate = ENDSWITH(name, "chocolate"))
+Parts.CALCULATE(ends_with_chocolate = ENDSWITH(name, "chocolate"))
```
@@ -232,7 +232,7 @@ Parts(ends_with_chocolate = ENDSWITH(name, "chocolate"))
The `CONTAINS` function checks if its first argument contains its second argument as a substring:
```py
-Parts(is_green = CONTAINS(name, "green"))
+Parts.CALCULATE(is_green = CONTAINS(name, "green"))
```
@@ -242,7 +242,7 @@ Parts(is_green = CONTAINS(name, "green"))
The `LIKE` function checks if the first argument matches the SQL pattern text of the second argument, where `_` is a 1 character wildcard and `%` is an 0+ character wildcard.
```py
-Orders(is_special_request = LIKE(comment, "%special%requests%"))
+Orders.CALCULATE(is_special_request = LIKE(comment, "%special%requests%"))
```
[This link](https://www.w3schools.com/sql/sql_like.asp) explains how these SQL pattern strings work and provides some examples.
@@ -254,8 +254,12 @@ Orders(is_special_request = LIKE(comment, "%special%requests%"))
The `JOIN_STRINGS` function concatenates all its string arguments, using the first argument as a delimiter between each of the following arguments (like the `.join` method in Python):
```py
-Regions.nations.customers(
- fully_qualified_name = JOIN_STRINGS("-", BACK(2).name, BACK(1).name, name)
+Regions.CALCULATE(
+ region_name=name
+).nations.CALCULATE(
+ nation_name=name
+).customers.CALCULATE(
+ fully_qualified_name = JOIN_STRINGS("-", region_name, nation_name, name)
)
```
@@ -298,7 +302,7 @@ If there are multiple modifiers, they operate left-to-right.
# 3. Exactly 12 hours from now
# 4. The last day of the previous year
# 5. The current day, at midnight
-TPCH(
+TPCH.CALCULATE(
ts_1=DATETIME('now'),
ts_2=DATETIME('NoW', 'start of month'),
ts_3=DATETIME(' CURRENT_DATE ', '12 hours'),
@@ -307,7 +311,7 @@ TPCH(
)
# For each order, truncates the order date to the first day of the year
-Orders(order_year=DATETIME(order_year, 'START OF Y'))
+Orders.CALCULATE(order_year=DATETIME(order_date, 'START OF Y'))
```
@@ -327,7 +331,7 @@ Orders.WHERE(YEAR(order_date) == 1995)
Calling `MONTH` on a date/timestamp extracts the month of the year it belongs to:
```py
-Orders(is_summer = (MONTH(order_date) >= 6) & (MONTH(order_date) <= 8))
+Orders.CALCULATE(is_summer = (MONTH(order_date) >= 6) & (MONTH(order_date) <= 8))
```
@@ -337,7 +341,7 @@ Orders(is_summer = (MONTH(order_date) >= 6) & (MONTH(order_date) <= 8))
Calling `DAY` on a date/timestamp extracts the day of the month it belongs to:
```py
-Orders(is_first_of_month = DAY(order_date) == 1)
+Orders.CALCULATE(is_first_of_month = DAY(order_date) == 1)
```
@@ -348,7 +352,7 @@ Calling `HOUR` on a date/timestamp extracts the hour it belongs to. The range of
is from 0-23:
```py
-Orders(is_12pm = HOUR(order_date) == 12)
+Orders.CALCULATE(is_12pm = HOUR(order_date) == 12)
```
@@ -359,7 +363,7 @@ Calling `MINUTE` on a date/timestamp extracts the minute. The range of output
is from 0-59:
```py
-Orders(is_half_hour = MINUTE(order_date) == 30)
+Orders.CALCULATE(is_half_hour = MINUTE(order_date) == 30)
```
@@ -370,7 +374,7 @@ Calling `SECOND` on a date/timestamp extracts the second. The range of output
is from 0-59:
```py
-Orders(is_lt_30_seconds = SECOND(order_date) < 30)
+Orders.CALCULATE(is_lt_30_seconds = SECOND(order_date) < 30)
```
@@ -389,7 +393,7 @@ Calling `DATEDIFF` between 2 timestamps returns the difference in one of `years`
```py
# Calculates, for each order, the number of days since January 1st 1992
# that the order was placed:
-orders(
+Orders.CALCULATE(
days_since=DATEDIFF("days",datetime.date(1992, 1, 1), order_date)
)
```
@@ -410,8 +414,8 @@ The `IFF` function cases on the True/False value of its first argument. If it is
```py
qty_from_germany = IFF(supplier.nation.name == "GERMANY", quantity, 0)
-Customers(
- total_quantity_shipped_from_germany = SUM(lines(q=qty_from_germany).q)
+Customers.CALCULATE(
+ total_quantity_shipped_from_germany = SUM(lines.CALCULATE(q=qty_from_germany).q)
)
```
@@ -432,7 +436,7 @@ Parts.WHERE(ISIN(size, (10, 11, 17, 19, 45)))
The `DEFAULT_TO` function returns the first of its arguments that is non-null (e.g. the same as the `COALESCE` function in SQL):
```py
-Lineitems(adj_tax = DEFAULT_TO(tax, 0))
+Lineitems.CALCULATE(adj_tax = DEFAULT_TO(tax, 0))
```
@@ -442,7 +446,7 @@ Lineitems(adj_tax = DEFAULT_TO(tax, 0))
The `PRESENT` function checks if its argument is non-null (e.g. the same as `IS NOT NULL` in SQL):
```py
-Lineitems(has_tax = PRESENT(tax))
+Lineitems.CALCULATE(has_tax = PRESENT(tax))
```
@@ -452,7 +456,7 @@ Lineitems(has_tax = PRESENT(tax))
The `ABSENT` function checks if its argument is null (e.g. the same as `IS NULL` in SQL):
```py
-Lineitems(no_tax = ABSENT(tax))
+Lineitems.CALCULATE(no_tax = ABSENT(tax))
```
@@ -462,7 +466,7 @@ Lineitems(no_tax = ABSENT(tax))
The `KEEP_IF` function returns the first function if the second arguments is True, otherwise it returns a null value. In other words, `KEEP_IF(a, b)` is equivalent to the SQL expression `CASE WHEN b THEN a END`.
```py
-TPCH(avg_non_debt_balance = AVG(Customers(no_debt_bal = KEEP_IF(acctbal, acctbal > 0)).no_debt_bal))
+TPCH.CALCULATE(avg_non_debt_balance = AVG(Customers.CALCULATE(no_debt_bal = KEEP_IF(acctbal, acctbal > 0)).no_debt_bal))
```
@@ -488,9 +492,9 @@ Below is each numerical function currently supported in PyDough.
The `ABS` function returns the absolute value of its input. The Python builtin `abs()` function can also be used to accomplish the same thing.
```py
-Customers(acct_magnitude = ABS(acctbal))
+Customers.CALCULATE(acct_magnitude = ABS(acctbal))
# The below statement is equivalent to above.
-Customers(acct_magnitude = abs(acctbal))
+Customers.CALCULATE(acct_magnitude = abs(acctbal))
```
@@ -500,18 +504,18 @@ Customers(acct_magnitude = abs(acctbal))
The `ROUND` function rounds its first argument to the precision of its second argument. The rounding rules used depend on the database's round function. The Python builtin `round()` function can also be used to accomplish the same thing.
```py
-Parts(rounded_price = ROUND(retail_price, 1))
+Parts.CALCULATE(rounded_price = ROUND(retail_price, 1))
# The below statement is equivalent to above.
-Parts(rounded_price = round(retail_price, 1))
+Parts.CALCULATE(rounded_price = round(retail_price, 1))
```
Note: The default precision for builtin `round` method is 0, to be in alignment with the Python implementation. The PyDough `ROUND` function requires the precision to be specified.
```py
# This is legal.
-Parts(rounded_price = round(retail_price))
+Parts.CALCULATE(rounded_price = round(retail_price))
# This is illegal as precision is not specified.
-Parts(rounded_price = ROUND(retail_price))
+Parts.CALCULATE(rounded_price = ROUND(retail_price))
```
@@ -521,7 +525,7 @@ Parts(rounded_price = ROUND(retail_price))
The `POWER` function exponentiates its first argument to the power of its second argument.
```py
-Parts(powered_price = POWER(retail_price, 2))
+Parts.CALCULATE(powered_price = POWER(retail_price, 2))
```
@@ -531,7 +535,7 @@ Parts(powered_price = POWER(retail_price, 2))
The `SQRT` function takes the square root of its input. It's equivalent to `POWER(x,0.5)`.
```py
-Parts(sqrt_price = SQRT(retail_price))
+Parts.CALCULATE(sqrt_price = SQRT(retail_price))
```
@@ -549,7 +553,7 @@ Aggregation functions are a special set of functions that, when called on their
The `SUM` function returns the sum of the plural set of numerical values it is called on.
```py
-Nations(total_consumer_wealth = SUM(customers.acctbal))
+Nations.CALCULATE(total_consumer_wealth = SUM(customers.acctbal))
```
@@ -559,7 +563,7 @@ Nations(total_consumer_wealth = SUM(customers.acctbal))
The `AVG` function takes the average of the plural set of numerical values it is called on.
```py
-Parts(average_shipment_size = AVG(lines.quantity))
+Parts.CALCULATE(average_shipment_size = AVG(lines.quantity))
```
@@ -569,7 +573,7 @@ Parts(average_shipment_size = AVG(lines.quantity))
The `MIN` function returns the smallest value from the set of numerical values it is called on.
```py
-Suppliers(cheapest_part_supplied = MIN(supply_records.supply_cost))
+Suppliers.CALCULATE(cheapest_part_supplied = MIN(supply_records.supply_cost))
```
@@ -579,7 +583,7 @@ Suppliers(cheapest_part_supplied = MIN(supply_records.supply_cost))
The `MAX` function returns the largest value from the set of numerical values it is called on.
```py
-Suppliers(most_expensive_part_supplied = MAX(supply_records.supply_cost))
+Suppliers.CALCULATE(most_expensive_part_supplied = MAX(supply_records.supply_cost))
```
@@ -589,13 +593,13 @@ Suppliers(most_expensive_part_supplied = MAX(supply_records.supply_cost))
The `COUNT` function returns how many non-null records exist on the set of plural values it is called on.
```py
-Customers(num_taxed_purchases = COUNT(orders.lines.tax))
+Customers.CALCULATE(num_taxed_purchases = COUNT(orders.lines.tax))
```
The `COUNT` function can also be called on a sub-collection, in which case it will return how many records from that sub-collection exist.
```py
-Nations(num_customers_in_debt = COUNT(customers.WHERE(acctbal < 0)))
+Nations.CALCULATE(num_customers_in_debt = COUNT(customers.WHERE(acctbal < 0)))
```
@@ -605,7 +609,7 @@ Nations(num_customers_in_debt = COUNT(customers.WHERE(acctbal < 0)))
The `NDISTINCT` function returns how many distinct values of its argument exist.
```py
-Customers(num_unique_parts_purchased = NDISTINCT(orders.lines.parts.key))
+Customers.CALCULATE(num_unique_parts_purchased = NDISTINCT(orders.lines.parts.key))
```
@@ -640,16 +644,16 @@ For example, if using the `RANKING` window function, consider the following exam
```py
# (no levels) rank every customer relative to all other customers
-Regions.nations.customers(r=RANKING(...))
+Regions.nations.customers.CALCULATE(r=RANKING(...))
# (levels=1) rank every customer relative to other customers in the same nation
-Regions.nations.customers(r=RANKING(..., levels=1))
+Regions.nations.customers.CALCULATE(r=RANKING(..., levels=1))
# (levels=2) rank every customer relative to other customers in the same region
-Regions.nations.customers(r=RANKING(..., levels=2))
+Regions.nations.customers.CALCULATE(r=RANKING(..., levels=2))
# (levels=3) rank every customer relative to all other customers
-Regions.nations.customers(r=RANKING(..., levels=3))
+Regions.nations.customers.CALCULATE(r=RANKING(..., levels=3))
```
Below is each window function currently supported in PyDough.
@@ -668,7 +672,7 @@ The `RANKING` function returns ordinal position of the current record when all r
```py
# Rank customers per-nation by their account balance
# (highest = rank #1, no ties)
-Nations.customers(r = RANKING(by=acctbal.DESC(), levels=1))
+Nations.customers.CALCULATE(r = RANKING(by=acctbal.DESC(), levels=1))
# For every customer, finds their most recent order
# (ties allowed)
diff --git a/documentation/usage.md b/documentation/usage.md
index 30a986fe..c4d36d9e 100644
--- a/documentation/usage.md
+++ b/documentation/usage.md
@@ -46,7 +46,7 @@ Once you have done all of these steps, you can run PyDough in any cell of the no
```py
%%pydough
-result = nations.WHERE(region.name == "ASIA")(name, n_cust=COUNT(customers))
+result = nations.WHERE(region.name == "ASIA").CALCULATE(name, n_cust=COUNT(customers))
pydough.to_df(result)
```
@@ -147,7 +147,7 @@ For example, consider this PyDough snippet:
```py
%%pydough
selected_customers = customers.WHERE(CONTAINS(name, '2468'))
-result = Nations(
+result = Nations.CALCULATE(
name,
total_bal=SUM(selected_customers.acctbal),
average_bal=AVG(selected_customers.acctbal),
@@ -212,6 +212,7 @@ This sections describes various APIs you can use to execute PyDough code.
The `to_sql` API takes in PyDough code and transforms it into SQL query text without executing it on a database. The first argument it takes in is the PyDough node for the collection being converted to SQL. It can optionally take in the following keyword arguments:
+- `columns`: which columns to include in the answer and what names to call them by (if omitted, uses the names/ordering from the last `CALCULATE` clause). This can either be a non-empty list of column name strings, or a non-empty dictionary where the values are column name strings, and the keys are the strings of the aliases they should be named as.
- `metadata`: the PyDough knowledge graph to use for the conversion (if omitted, `pydough.active_session.metadata` is used instead).
- `config`: the PyDough configuration settings to use for the conversion (if omitted, `pydough.active_session.config` is used instead).
- `database`: the database context to use for the conversion (if omitted, `pydough.active_session.database` is used instead). The database context matters because it controls which SQL dialect is used for the translation.
@@ -221,8 +222,8 @@ Below is an example of using `pydough.to_sql` and the output (the SQL output may
```py
%%pydough
european_countries = nations.WHERE(region.name == "EUROPE")
-result = european_countries(name, n_custs=COUNT(customers))
-pydough.to_sql(result)
+result = european_countries.CALCULATE(name, n_custs=COUNT(customers))
+pydough.to_sql(result, columns=["name", "n_custs"])
```
```sql
@@ -263,6 +264,8 @@ See the [demo notebooks](../demos/README.md) for more instances of how to use th
The `to_df` API does all the same steps as the [`to_sql` API](#pydoughto_sql), but goes a step further and executes the query using the provided database connection, returning the result as a pandas DataFrame. The first argument it takes in is the PyDough node for the collection being converted to SQL. It can optionally take in the following keyword arguments:
+
+- `columns`: which columns to include in the answer and what names to call them by (if omitted, uses the names/ordering from the last `CALCULATE` clause). This can either be a non-empty list of column name strings, or a non-empty dictionary where the values are column name strings, and the keys are the strings of the aliases they should be named as.
- `metadata`: the PyDough knowledge graph to use for the conversion (if omitted, `pydough.active_session.metadata` is used instead).
- `config`: the PyDough configuration settings to use for the conversion (if omitted, `pydough.active_session.config` is used instead).
- `database`: the database context to use for the conversion (if omitted, `pydough.active_session.database` is used instead). The database context matters because it controls which SQL dialect is used for the translation.
@@ -273,8 +276,8 @@ Below is an example of using `pydough.to_df` and the output, attached to a sqlit
```py
%%pydough
european_countries = nations.WHERE(region.name == "EUROPE")
-result = european_countries(name, n_custs=COUNT(customers))
-pydough.to_df(result)
+result = european_countries.CALCULATE(n=COUNT(customers))
+pydough.to_df(result, columns={"name": "name", "n_custs": "n"})
```
@@ -478,12 +481,10 @@ pydough.explain(result, verbose=True)
PyDough collection representing the following logic:
TPCH
-This node is a reference to the global context for the entire graph. An operation must be done onto this node (e.g. a CALC or accessing a collection) before it can be executed.
+This node is a reference to the global context for the entire graph. An operation must be done onto this node (e.g. a CALCULATE or accessing a collection) before it can be executed.
The collection does not have any terms that can be included in a result if it is executed.
-It is not possible to use BACK from this collection.
-
The collection has access to the following collections:
customers, lines, nations, orders, parts, regions, suppliers, supply_records
@@ -510,8 +511,6 @@ Call pydough.explain(graph['nations']) to learn more about this collection.
The following terms will be included in the result if this collection is executed:
comment, key, name, region_key
-It is possible to use BACK to go up to 1 level above this collection.
-
The collection has access to the following expressions:
comment, key, name, region_key
@@ -548,8 +547,6 @@ The main task of this node is to filter on the following conditions:
The following terms will be included in the result if this collection is executed:
comment, key, name, region_key
-It is possible to use BACK to go up to 1 level above this collection.
-
The collection has access to the following expressions:
comment, key, name, region_key
@@ -560,11 +557,11 @@ Call pydough.explain_term(collection, term) to learn more about any of these
expressions or collections that the collection has access to.
```
-4d. Calling `explain` on PyDough code for a collection (example 4: calc).
+4d. Calling `explain` on PyDough code for a collection (example 4: CALCULATE).
```py
%%pydough
-result = nations.WHERE(region.name == "EUROPE")(name, n_custs=COUNT(customers))
+result = nations.WHERE(region.name == "EUROPE").CALCULATE(name, n_custs=COUNT(customers))
pydough.explain(result, verbose=True)
```
@@ -575,7 +572,7 @@ PyDough collection representing the following logic:
├─┬─ Where[$1.name == 'EUROPE']
│ └─┬─ AccessChild
│ └─── SubCollection[region]
- └─┬─ Calc[name=name, n_custs=COUNT($1)]
+ └─┬─ Calculate[name=name, n_custs=COUNT($1)]
└─┬─ AccessChild
└─── SubCollection[customers]
@@ -590,8 +587,6 @@ The main task of this node is to calculate the following additional expressions
The following terms will be included in the result if this collection is executed:
n_custs, name
-It is possible to use BACK to go up to 1 level above this collection.
-
The collection has access to the following expressions:
comment, key, n_custs, name, region_key
@@ -634,9 +629,9 @@ The term is the following expression: name
This is column 'name' of collection 'nations'
-This term is singular with regards to the collection, meaning it can be placed in a CALC of a collection.
+This term is singular with regards to the collection, meaning it can be placed in a CALCULATE of a collection.
For example, the following is valid:
- TPCH.nations.WHERE(region.name == 'EUROPE')(name)
+ TPCH.nations.WHERE(region.name == 'EUROPE').CALCULATE(name)
```
2. Calling `explain_term` on a sub-collection of a collection.
@@ -662,7 +657,7 @@ The term is the following child of the collection:
This child is plural with regards to the collection, meaning its scalar terms can only be accessed by the collection if they are aggregated.
For example, the following are valid:
- TPCH.nations.WHERE(region.name == 'EUROPE')(COUNT(customers.acctbal))
+ TPCH.nations.WHERE(region.name == 'EUROPE').CALCULATE(COUNT(customers.acctbal))
TPCH.nations.WHERE(region.name == 'EUROPE').WHERE(HAS(customers))
TPCH.nations.WHERE(region.name == 'EUROPE').ORDER_BY(COUNT(customers).DESC())
@@ -695,9 +690,9 @@ The term is the following expression: $1.acctbal
This is a reference to expression 'acctbal' of child $1
-This expression is plural with regards to the collection, meaning it can be placed in a CALC of a collection if it is aggregated.
+This expression is plural with regards to the collection, meaning it can be placed in a CALCULATE of a collection if it is aggregated.
For example, the following is valid:
- TPCH.nations.WHERE(region.name == 'EUROPE')(COUNT(customers.acctbal))
+ TPCH.nations.WHERE(region.name == 'EUROPE').CALCULATE(COUNT(customers.acctbal))
```
4. Calling `explain_term` on an aggregation function call.
@@ -728,9 +723,9 @@ This expression calls the function 'AVG' on the following arguments, aggregating
Call pydough.explain_term with this collection and any of the arguments to learn more about them.
-This term is singular with regards to the collection, meaning it can be placed in a CALC of a collection.
+This term is singular with regards to the collection, meaning it can be placed in a CALCULATE of a collection.
For example, the following is valid:
- TPCH.nations.WHERE(region.name == 'EUROPE')(AVG(customers.acctbal))
+ TPCH.nations.WHERE(region.name == 'EUROPE').CALCULATE(AVG(customers.acctbal))
```
## Logging
@@ -755,7 +750,7 @@ logger.info("This is an info message.")
logger.error("This is an error message.")
```
-We can also set the level of logging via a function argument. Note that if `PYDOUGH_LOG_LEVEL` is available, the default_level argument is overriden.
+We can also set the level of logging via a function argument. Note that if `PYDOUGH_LOG_LEVEL` is available, the default_level argument is overridden.
```python
# Import the function
diff --git a/pydough/conversion/hybrid_decorrelater.py b/pydough/conversion/hybrid_decorrelater.py
index a1c36407..ea050c3f 100644
--- a/pydough/conversion/hybrid_decorrelater.py
+++ b/pydough/conversion/hybrid_decorrelater.py
@@ -11,7 +11,7 @@
from .hybrid_tree import (
ConnectionType,
HybridBackRefExpr,
- HybridCalc,
+ HybridCalculate,
HybridChildRefExpr,
HybridColumnExpr,
HybridConnection,
@@ -208,7 +208,7 @@ def correl_ref_purge(
operation.unique_exprs[idx] = self.remove_correl_refs(
expr, old_parent, child_height
)
- if isinstance(operation, HybridCalc):
+ if isinstance(operation, HybridCalculate):
for str, expr in operation.new_expressions.items():
operation.new_expressions[str] = self.remove_correl_refs(
expr, old_parent, child_height
diff --git a/pydough/conversion/hybrid_tree.py b/pydough/conversion/hybrid_tree.py
index 8e7d4a49..d7acf21b 100644
--- a/pydough/conversion/hybrid_tree.py
+++ b/pydough/conversion/hybrid_tree.py
@@ -6,7 +6,7 @@
__all__ = [
"HybridBackRefExpr",
- "HybridCalc",
+ "HybridCalculate",
"HybridChildRefExpr",
"HybridCollation",
"HybridCollectionAccess",
@@ -25,6 +25,7 @@
"HybridTree",
]
+import copy
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
@@ -39,7 +40,7 @@
)
from pydough.qdag import (
BackReferenceExpression,
- Calc,
+ Calculate,
ChildOperator,
ChildOperatorChildAccess,
ChildReferenceCollection,
@@ -400,10 +401,10 @@ class HybridOperation:
- `terms`: mapping of names to expressions accessible from that point in
the pipeline execution.
- `renamings`: mapping of names to a new name that should be used to access
- them from within `terms`. This is used when a `CALC` overrides a
- term name so that future invocations of the term name use the
- renamed version, while key operations like joins can still
- access the original version.
+ them from within `terms`. This is used when a `CALCULATE`
+ overrides a term name so that future invocations of the term
+ name use the renamed version, while key operations like joins
+ can still access the original version.
- `orderings`: list of collation expressions that specify the order
that a hybrid operation is sorted by.
- `unique_exprs`: list of expressions that are used to uniquely identify
@@ -446,11 +447,15 @@ def __init__(self, collection: CollectionAccess):
self.collection: CollectionAccess = collection
terms: dict[str, HybridExpr] = {}
for name in collection.calc_terms:
+ # Skip columns that are overloaded with a name from an ancestor,
+ # since they should not be used.
+ if name in collection.ancestral_mapping:
+ continue
expr = collection.get_expr(name)
assert isinstance(expr, ColumnProperty)
terms[name] = HybridColumnExpr(expr)
unique_exprs: list[HybridExpr] = []
- for name in collection.unique_terms:
+ for name in sorted(collection.unique_terms, key=str):
expr = collection.get_expr(name)
unique_exprs.append(HybridRefExpr(name, expr.pydough_type))
super().__init__(terms, {}, [], unique_exprs)
@@ -478,9 +483,9 @@ def __repr__(self):
return "PARTITION_CHILD[*]"
-class HybridCalc(HybridOperation):
+class HybridCalculate(HybridOperation):
"""
- Class for HybridOperation corresponding to a CALC operation.
+ Class for HybridOperation corresponding to a CALCULATE operation.
"""
def __init__(
@@ -515,11 +520,10 @@ def __init__(
expr = new_expressions.pop(old_name)
new_expressions[new_name] = expr
super().__init__(terms, renamings, orderings, predecessor.unique_exprs)
- self.calc = Calc
self.new_expressions = new_expressions
def __repr__(self):
- return f"CALC[{self.new_expressions}]"
+ return f"CALCULATE[{self.new_expressions}]"
class HybridFilter(HybridOperation):
@@ -609,7 +613,7 @@ class ConnectionType(Enum):
combined with the parent tree via a left join. The aggregate call may be
augmented after the left join, e.g. to coalesce with a default value if the
left join was not used. The grouping keys for the aggregate are the keys
- used to join the parent tree output onto the subtree ouput.
+ used to join the parent tree output onto the subtree output.
If this is used as a child access of a `PARTITION` node, there is no left
join, though some of the post-processing steps may still occur.
@@ -901,11 +905,13 @@ class HybridTree:
def __init__(
self,
root_operation: HybridOperation,
+ ancestral_mapping: dict[str, int],
is_hidden_level: bool = False,
is_connection_root: bool = False,
):
self._pipeline: list[HybridOperation] = [root_operation]
self._children: list[HybridConnection] = []
+ self._ancestral_mapping: dict[str, int] = dict(ancestral_mapping)
self._successor: HybridTree | None = None
self._parent: HybridTree | None = None
self._is_hidden_level: bool = is_hidden_level
@@ -953,6 +959,14 @@ def children(self) -> list[HybridConnection]:
"""
return self._children
+ @property
+ def ancestral_mapping(self) -> dict[str, int]:
+ """
+ The mapping used to identify terms that are references to an alias
+ defined in an ancestor.
+ """
+ return self._ancestral_mapping
+
@property
def correlated_children(self) -> set[int]:
"""
@@ -1264,14 +1278,14 @@ def populate_children(
child_idx_mapping: dict[int, int],
) -> None:
"""
- Helper utility that takes any children of a child operator (CALC,
+ Helper utility that takes any children of a child operator (CALCULATE,
WHERE, etc.) and builds the corresponding HybridTree subtree,
where the parent of the subtree's root is absent instead of the
current level, and inserts the corresponding HybridConnection node.
Args:
`hybrid`: the HybridTree having children added to it.
- `child_operator`: the collection QDAG node (CALC, WHERE, etc.)
+ `child_operator`: the collection QDAG node (CALCULATE, WHERE, etc.)
containing the children.
`child_idx_mapping`: a mapping of indices of children of the
original `child_operator` to the indices of children of the hybrid
@@ -1304,7 +1318,7 @@ def populate_children(
self.identify_connection_types(
col.expr, child_idx, reference_types
)
- case Calc():
+ case Calculate():
for expr in child_operator.calc_term_values.values():
self.identify_connection_types(expr, child_idx, reference_types)
case PartitionBy():
@@ -1607,9 +1621,10 @@ def make_hybrid_correl_expr(
parent_tree: HybridTree = self.stack.pop()
remaining_steps_back: int = back_expr.back_levels - steps_taken_so_far - 1
parent_result: HybridExpr
+ new_expr: PyDoughExpressionQDAG
# Special case: stepping out of the data argument of PARTITION back
# into its ancestor. For example:
- # TPCH(x=...).PARTITION(data.WHERE(y > BACK(1).x), ...)
+ # TPCH.CALCULATE(x=...).PARTITION(data.WHERE(y > BACK(1).x), ...)
partition_edge_case: bool = len(parent_tree.pipeline) == 1 and isinstance(
parent_tree.pipeline[0], HybridPartition
)
@@ -1626,20 +1641,28 @@ def make_hybrid_correl_expr(
elif remaining_steps_back == 0:
# If there are no more steps back to be made, then the correlated
# reference is to a reference from the current context.
- if back_expr.term_name not in parent_tree.pipeline[-1].terms:
+ if back_expr.term_name in parent_tree.ancestral_mapping:
+ new_expr = BackReferenceExpression(
+ collection,
+ back_expr.term_name,
+ parent_tree.ancestral_mapping[back_expr.term_name],
+ )
+ parent_result = self.make_hybrid_expr(parent_tree, new_expr, {}, False)
+ elif back_expr.term_name in parent_tree.pipeline[-1].terms:
+ parent_name: str = parent_tree.pipeline[-1].renamings.get(
+ back_expr.term_name, back_expr.term_name
+ )
+ parent_result = HybridRefExpr(parent_name, back_expr.pydough_type)
+ else:
raise ValueError(
f"Back reference to {back_expr.term_name} not found in parent"
)
- parent_name: str = parent_tree.pipeline[-1].renamings.get(
- back_expr.term_name, back_expr.term_name
- )
- parent_result = HybridRefExpr(parent_name, back_expr.pydough_type)
else:
# Otherwise, a back reference needs to be made from the current
# collection a number of steps back based on how many steps still
# need to be taken, and it must be recursively converted to a
# hybrid expression that gets wrapped in a correlated reference.
- new_expr: PyDoughExpressionQDAG = BackReferenceExpression(
+ new_expr = BackReferenceExpression(
collection, back_expr.term_name, remaining_steps_back
)
parent_result = self.make_hybrid_expr(parent_tree, new_expr, {}, False)
@@ -1681,6 +1704,7 @@ def make_hybrid_expr(
args: list[HybridExpr] = []
hybrid_arg: HybridExpr
ancestor_tree: HybridTree
+ collection: PyDoughCollectionQDAG
match expr:
case PartitionKey():
return self.make_hybrid_expr(
@@ -1714,7 +1738,7 @@ def make_hybrid_expr(
# Keep stepping backward until `expr.back_levels` non-hidden
# steps have been taken (to ignore steps that are part of a
# compound).
- collection: PyDoughCollectionQDAG = expr.collection
+ collection = expr.collection
while true_steps_back < expr.back_levels:
assert collection.ancestor_context is not None
collection = collection.ancestor_context
@@ -1731,6 +1755,23 @@ def make_hybrid_expr(
)
return HybridBackRefExpr(expr_name, expr.back_levels, expr.pydough_type)
case Reference():
+ if hybrid.ancestral_mapping.get(expr.term_name, 0) > 0:
+ collection = expr.collection
+ while (
+ isinstance(collection, PartitionChild)
+ and expr.term_name in collection.child_access.ancestral_mapping
+ ):
+ collection = collection.child_access
+ return self.make_hybrid_expr(
+ hybrid,
+ BackReferenceExpression(
+ collection,
+ expr.term_name,
+ hybrid.ancestral_mapping[expr.term_name],
+ ),
+ child_ref_mapping,
+ inside_agg,
+ )
expr_name = hybrid.pipeline[-1].renamings.get(
expr.term_name, expr.term_name
)
@@ -1789,7 +1830,9 @@ def make_hybrid_expr(
if ancestor_tree.parent is None:
raise ValueError("Window function references too far back")
ancestor_tree = ancestor_tree.parent
- for unique_term in ancestor_tree.pipeline[-1].unique_exprs:
+ for unique_term in sorted(
+ ancestor_tree.pipeline[-1].unique_exprs, key=str
+ ):
shifted_arg: HybridExpr | None = unique_term.shift_back(
expr.levels
)
@@ -1830,7 +1873,7 @@ def process_hybrid_collations(
Returns:
A tuple containing a dictionary of new expressions for generating
- a calc and a list of the new HybridCollation values.
+ a `CALCULATE` and a list of the new HybridCollation values.
"""
new_expressions: dict[str, HybridExpr] = {}
hybrid_orderings: list[HybridCollation] = []
@@ -1863,18 +1906,23 @@ def make_hybrid_tree(
The HybridTree representation of `node`.
"""
hybrid: HybridTree
+ subtree: HybridTree
successor_hybrid: HybridTree
expr: HybridExpr
+ hybrid_back_expr: HybridExpr
child_ref_mapping: dict[int, int] = {}
key_exprs: list[HybridExpr] = []
join_key_exprs: list[tuple[HybridExpr, HybridExpr]] = []
+ back_exprs: dict[str, HybridExpr] = {}
match node:
case GlobalContext():
- return HybridTree(HybridRoot())
+ return HybridTree(HybridRoot(), node.ancestral_mapping)
case CompoundSubCollection():
raise NotImplementedError(f"{node.__class__.__name__}")
case TableCollection() | SubCollection():
- successor_hybrid = HybridTree(HybridCollectionAccess(node))
+ successor_hybrid = HybridTree(
+ HybridCollectionAccess(node), node.ancestral_mapping
+ )
hybrid = self.make_hybrid_tree(node.ancestor_context, parent)
hybrid.add_successor(successor_hybrid)
return successor_hybrid
@@ -1886,11 +1934,14 @@ def make_hybrid_tree(
src_tree: HybridTree = hybrid
while isinstance(src_tree.pipeline[0], HybridPartitionChild):
src_tree = src_tree.pipeline[0].subtree
- subtree: HybridTree = src_tree.children[0].subtree
- successor_hybrid = HybridTree(HybridPartitionChild(subtree))
+ subtree = src_tree.children[0].subtree
+ successor_hybrid = HybridTree(
+ HybridPartitionChild(subtree),
+ node.ancestral_mapping,
+ )
hybrid.add_successor(successor_hybrid)
return successor_hybrid
- case Calc():
+ case Calculate():
hybrid = self.make_hybrid_tree(node.preceding_context, parent)
self.populate_children(hybrid, node, child_ref_mapping)
new_expressions: dict[str, HybridExpr] = {}
@@ -1900,7 +1951,7 @@ def make_hybrid_tree(
)
new_expressions[name] = expr
hybrid.pipeline.append(
- HybridCalc(
+ HybridCalculate(
hybrid.pipeline[-1],
new_expressions,
hybrid.pipeline[-1].orderings,
@@ -1918,11 +1969,33 @@ def make_hybrid_tree(
case PartitionBy():
hybrid = self.make_hybrid_tree(node.ancestor_context, parent)
partition: HybridPartition = HybridPartition()
- successor_hybrid = HybridTree(partition)
+ successor_hybrid = HybridTree(partition, node.ancestral_mapping)
hybrid.add_successor(successor_hybrid)
self.populate_children(successor_hybrid, node, child_ref_mapping)
partition_child_idx: int = child_ref_mapping[0]
- for key_name in node.calc_terms:
+ subtree = successor_hybrid.children[partition_child_idx].subtree
+ for name in subtree.ancestral_mapping:
+ # Skip adding backrefs for terms that remain part of the
+ # ancestry through the PARTITION, since this creates an
+                # unnecessary correlation.
+ if name in node.ancestor_context.ancestral_mapping:
+ continue
+ hybrid_back_expr = self.make_hybrid_expr(
+ subtree,
+ node.children[0].get_expr(name),
+ {},
+ False,
+ )
+ back_exprs[name] = hybrid_back_expr
+ if len(back_exprs):
+ subtree.pipeline.append(
+ HybridCalculate(
+ subtree.pipeline[-1],
+ back_exprs,
+ subtree.pipeline[-1].orderings,
+ )
+ )
+ for key_name in sorted(node.calc_terms, key=str):
key = node.get_expr(key_name)
expr = self.make_hybrid_expr(
successor_hybrid, key, child_ref_mapping, False
@@ -1942,7 +2015,7 @@ def make_hybrid_tree(
hybrid, node.collation, child_ref_mapping
)
hybrid.pipeline.append(
- HybridCalc(hybrid.pipeline[-1], new_nodes, hybrid_orderings)
+ HybridCalculate(hybrid.pipeline[-1], new_nodes, hybrid_orderings)
)
if isinstance(node, TopK):
hybrid.pipeline.append(
@@ -1956,7 +2029,8 @@ def make_hybrid_tree(
node.child_access, CompoundSubCollection
):
successor_hybrid = HybridTree(
- HybridCollectionAccess(node.child_access)
+ HybridCollectionAccess(node.child_access),
+ node.ancestral_mapping,
)
if isinstance(node.child_access, SubCollection):
join_key_exprs = HybridTranslator.get_join_keys(
@@ -1965,9 +2039,8 @@ def make_hybrid_tree(
successor_hybrid.pipeline[-1],
)
case PartitionChild():
- successor_hybrid = self.make_hybrid_tree(
- node.child_access.child_access, parent
- )
+ successor_hybrid = copy.deepcopy(parent.children[0].subtree)
+ successor_hybrid._ancestral_mapping = node.ancestral_mapping
partition_by = (
node.child_access.ancestor_context.starting_predecessor
)
@@ -1979,7 +2052,9 @@ def make_hybrid_tree(
child_ref_mapping,
False,
)
- assert isinstance(rhs_expr, HybridRefExpr)
+ assert isinstance(
+ rhs_expr, (HybridRefExpr, HybridBackRefExpr)
+ )
lhs_expr: HybridExpr = HybridChildRefExpr(
rhs_expr.name, 0, rhs_expr.typ
)
@@ -1987,7 +2062,7 @@ def make_hybrid_tree(
case PartitionBy():
partition = HybridPartition()
- successor_hybrid = HybridTree(partition)
+ successor_hybrid = HybridTree(partition, node.ancestral_mapping)
self.populate_children(
successor_hybrid, node.child_access, child_ref_mapping
)
@@ -1999,6 +2074,28 @@ def make_hybrid_tree(
)
partition.add_key(key_name, expr)
key_exprs.append(HybridRefExpr(key_name, expr.typ))
+ subtree = successor_hybrid.children[partition_child_idx].subtree
+ for name in subtree.ancestral_mapping:
+ # Skip adding backrefs for terms that remain part of the
+ # ancestry through the PARTITION, since this creates an
+                # unnecessary correlation.
+ if name in node.ancestor_context.ancestral_mapping:
+ continue
+ hybrid_back_expr = self.make_hybrid_expr(
+ subtree,
+ node.child_access.children[0].get_expr(name),
+ {},
+ False,
+ )
+ back_exprs[name] = hybrid_back_expr
+ if len(back_exprs):
+ subtree.pipeline.append(
+ HybridCalculate(
+ subtree.pipeline[-1],
+ back_exprs,
+ subtree.pipeline[-1].orderings,
+ )
+ )
successor_hybrid.children[
partition_child_idx
].subtree.agg_keys = key_exprs
diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py
index c48132a2..5a137833 100644
--- a/pydough/conversion/relational_converter.py
+++ b/pydough/conversion/relational_converter.py
@@ -14,7 +14,7 @@
SimpleTableMetadata,
)
from pydough.qdag import (
- Calc,
+ Calculate,
CollectionAccess,
PyDoughCollectionQDAG,
PyDoughExpressionQDAG,
@@ -48,7 +48,7 @@
from .hybrid_tree import (
ConnectionType,
HybridBackRefExpr,
- HybridCalc,
+ HybridCalculate,
HybridChildRefExpr,
HybridCollation,
HybridCollectionAccess,
@@ -203,6 +203,15 @@ def translate_expression(
return LiteralExpression(expr.literal.value, expr.typ)
case HybridRefExpr() | HybridChildRefExpr() | HybridBackRefExpr():
assert context is not None
+ if expr not in context.expressions:
+ if isinstance(expr, HybridRefExpr):
+ for back_expr in context.expressions:
+ if (
+ isinstance(back_expr, HybridBackRefExpr)
+ and back_expr.name == expr.name
+ ):
+ return context.expressions[back_expr]
+ raise ValueError(f"Context does not contain expression {expr}")
return context.expressions[expr]
case HybridFunctionExpr():
inputs = [self.translate_expression(arg, context) for arg in expr.args]
@@ -735,17 +744,17 @@ def translate_limit(
)
return TranslationOutput(out_rel, context.expressions)
- def translate_calc(
+ def translate_calculate(
self,
- node: HybridCalc,
+ node: HybridCalculate,
context: TranslationOutput,
) -> TranslationOutput:
"""
- Converts a calc into a project on top of its child to derive additional
- terms.
+ Converts a CALCULATE into a project on top of its child to derive
+ additional terms.
Args:
- `node`: the node corresponding to the calc being derived.
+ `node`: the node corresponding to the CALCULATE being derived.
`context`: the data structure storing information used by the
conversion, such as bindings of already translated terms from
preceding contexts and the corresponding relational node.
@@ -810,16 +819,17 @@ def translate_partition_child(
)
join_keys: list[tuple[HybridExpr, HybridExpr]] = []
assert node.subtree.agg_keys is not None
- for agg_key in node.subtree.agg_keys:
+ for agg_key in sorted(node.subtree.agg_keys, key=str):
join_keys.append((agg_key, agg_key))
- return self.join_outputs(
+ result = self.join_outputs(
context,
child_output,
JoinType.INNER,
join_keys,
None,
)
+ return result
def rel_translation(
self,
@@ -921,9 +931,9 @@ def rel_translation(
case HybridPartitionChild():
assert context is not None, "Malformed HybridTree pattern."
result = self.translate_partition_child(operation, context)
- case HybridCalc():
+ case HybridCalculate():
assert context is not None, "Malformed HybridTree pattern."
- result = self.translate_calc(operation, context)
+ result = self.translate_calculate(operation, context)
case HybridFilter():
assert context is not None, "Malformed HybridTree pattern."
result = self.translate_filter(operation, context)
@@ -948,21 +958,23 @@ def rel_translation(
@staticmethod
def preprocess_root(
node: PyDoughCollectionQDAG,
+ output_cols: list[tuple[str, str]] | None,
) -> PyDoughCollectionQDAG:
"""
- Transforms the final PyDough collection by appending it with an extra CALC
- containing all of the columns that are output.
+ Transforms the final PyDough collection by appending it with an extra
+ CALCULATE containing all of the columns that are output.
"""
# Fetch all of the expressions that should be kept in the final output
- original_calc_terms: set[str] = node.calc_terms
final_terms: list[tuple[str, PyDoughExpressionQDAG]] = []
- all_names: set[str] = set()
- for name in original_calc_terms:
- final_terms.append((name, Reference(node, name)))
- all_names.add(name)
- final_terms.sort(key=lambda term: node.get_expression_position(term[0]))
+ if output_cols is None:
+ for name in node.calc_terms:
+ final_terms.append((name, Reference(node, name)))
+ final_terms.sort(key=lambda term: node.get_expression_position(term[0]))
+ else:
+ for _, column in output_cols:
+ final_terms.append((column, Reference(node, column)))
children: list[PyDoughCollectionQDAG] = []
- final_calc: Calc = Calc(node, children).with_terms(final_terms)
+ final_calc: Calculate = Calculate(node, children).with_terms(final_terms)
return final_calc
@@ -994,7 +1006,9 @@ def make_relational_ordering(
def convert_ast_to_relational(
- node: PyDoughCollectionQDAG, configs: PyDoughConfigs
+ node: PyDoughCollectionQDAG,
+ columns: list[tuple[str, str]] | None,
+ configs: PyDoughConfigs,
) -> RelationalRoot:
"""
Main API for converting from the collection QDAG form into relational
@@ -1002,18 +1016,22 @@ def convert_ast_to_relational(
Args:
`node`: the PyDough QDAG collection node to be translated.
+ `columns`: a list of tuples in the form `(alias, column)`
+ describing every column that should be in the output, in the order
+ they should appear, and the alias they should be given. If None, uses
+ the most recent CALCULATE in the node to determine the columns.
+ `configs`: the configuration settings to use during translation.
Returns:
The RelationalRoot for the entire PyDough calculation that the
- collection node corresponds to. Ensures that the calc terms of
- `node` are included in the root in the correct order, and if it
+ collection node corresponds to. Ensures that the desired output columns
+ of `node` are included in the root in the correct order, and if it
has an ordering then the relational root stores that information.
"""
- # Pre-process the QDAG node so the final CALC term includes any ordering
+ # Pre-process the QDAG node so the final CALCULATE includes any ordering
# keys.
translator: RelTranslation = RelTranslation()
- final_terms: set[str] = node.calc_terms
- node = translator.preprocess_root(node)
+ node = translator.preprocess_root(node, columns)
# Convert the QDAG node to the hybrid form, decorrelate it, then invoke
# the relational conversion procedure. The first element in the returned
@@ -1033,12 +1051,18 @@ def convert_ast_to_relational(
rel_expr: RelationalExpression
name: str
original_name: str
- for original_name in final_terms:
- name = renamings.get(original_name, original_name)
- hybrid_expr = hybrid.pipeline[-1].terms[name]
- rel_expr = output.expressions[hybrid_expr]
- ordered_columns.append((original_name, rel_expr))
- ordered_columns.sort(key=lambda col: node.get_expression_position(col[0]))
+ if columns is None:
+ for original_name in node.calc_terms:
+ name = renamings.get(original_name, original_name)
+ hybrid_expr = hybrid.pipeline[-1].terms[name]
+ rel_expr = output.expressions[hybrid_expr]
+ ordered_columns.append((original_name, rel_expr))
+ ordered_columns.sort(key=lambda col: node.get_expression_position(col[0]))
+ else:
+ for alias, column in columns:
+ hybrid_expr = hybrid.pipeline[-1].terms[column]
+ rel_expr = output.expressions[hybrid_expr]
+ ordered_columns.append((alias, rel_expr))
hybrid_orderings: list[HybridCollation] = hybrid.pipeline[-1].orderings
if hybrid_orderings:
orderings = make_relational_ordering(hybrid_orderings, output.expressions)
diff --git a/pydough/evaluation/evaluate_unqualified.py b/pydough/evaluation/evaluate_unqualified.py
index 2b908f29..090c396a 100644
--- a/pydough/evaluation/evaluate_unqualified.py
+++ b/pydough/evaluation/evaluate_unqualified.py
@@ -66,6 +66,47 @@ def _load_session_info(
return metadata, config, database, bindings
+def _load_column_selection(kwargs: dict[str, object]) -> list[tuple[str, str]] | None:
+ """
+ Load the column selection from the keyword arguments if it is found.
+ The column selection must be a keyword argument `columns` that is either a
+ list of strings, or a dictionary mapping output column names to the column
+ they correspond to in the collection.
+
+ Args:
+ kwargs: The keyword arguments to load the column selection from.
+
+ Returns:
+ The column selection if it is found, otherwise None.
+ """
+ columns_arg = kwargs.pop("columns", None)
+ result: list[tuple[str, str]] = []
+ if columns_arg is None:
+ return None
+ elif isinstance(columns_arg, list):
+ for column in columns_arg:
+ assert isinstance(
+ column, str
+ ), f"Expected column name in `columns` argument to be a string, found {column.__class__.__name__}"
+ result.append((column, column))
+ elif isinstance(columns_arg, dict):
+ for alias, column in columns_arg.items():
+ assert isinstance(
+ alias, str
+ ), f"Expected alias name in `columns` argument to be a string, found {column.__class__.__name__}"
+ assert isinstance(
+ column, str
+ ), f"Expected column name in `columns` argument to be a string, found {column.__class__.__name__}"
+ result.append((alias, column))
+ else:
+ raise TypeError(
+ f"Expected `columns` argument to be a list or dictionary, found {columns_arg.__class__.__name__}"
+ )
+ if len(result) == 0:
+ raise ValueError("Column selection must not be empty")
+ return result
+
+
def to_sql(node: UnqualifiedNode, **kwargs) -> str:
"""
Convert the given unqualified tree to a SQL string.
@@ -84,13 +125,16 @@ def to_sql(node: UnqualifiedNode, **kwargs) -> str:
graph: GraphMetadata
config: PyDoughConfigs
database: DatabaseContext
+ column_selection: list[tuple[str, str]] | None = _load_column_selection(kwargs)
graph, config, database, bindings = _load_session_info(**kwargs)
qualified: PyDoughQDAG = qualify_node(node, graph)
if not isinstance(qualified, PyDoughCollectionQDAG):
raise TypeError(
f"Final qualified expression must be a collection, found {qualified.__class__.__name__}"
)
- relational: RelationalRoot = convert_ast_to_relational(qualified, config)
+ relational: RelationalRoot = convert_ast_to_relational(
+ qualified, column_selection, config
+ )
return convert_relation_to_sql(
relational, convert_dialect_to_sqlglot(database.dialect), bindings
)
@@ -115,6 +159,7 @@ def to_df(node: UnqualifiedNode, **kwargs) -> pd.DataFrame:
graph: GraphMetadata
config: PyDoughConfigs
database: DatabaseContext
+ column_selection: list[tuple[str, str]] | None = _load_column_selection(kwargs)
display_sql: bool = bool(kwargs.pop("display_sql", False))
graph, config, database, bindings = _load_session_info(**kwargs)
qualified: PyDoughQDAG = qualify_node(node, graph)
@@ -122,5 +167,7 @@ def to_df(node: UnqualifiedNode, **kwargs) -> pd.DataFrame:
raise TypeError(
f"Final qualified expression must be a collection, found {qualified.__class__.__name__}"
)
- relational: RelationalRoot = convert_ast_to_relational(qualified, config)
+ relational: RelationalRoot = convert_ast_to_relational(
+ qualified, column_selection, config
+ )
return execute_df(relational, database, bindings, display_sql)
diff --git a/pydough/exploration/explain.py b/pydough/exploration/explain.py
index d60bd2f3..843228b9 100644
--- a/pydough/exploration/explain.py
+++ b/pydough/exploration/explain.py
@@ -21,7 +21,8 @@
TableColumnMetadata,
)
from pydough.qdag import (
- Calc,
+ BackReferenceExpression,
+ Calculate,
ChildOperator,
ExpressionFunctionCall,
GlobalContext,
@@ -349,7 +350,7 @@ def explain_unqualified(node: UnqualifiedNode, verbose: bool) -> str:
match qualified_node:
case GlobalContext():
lines.append(
- "This node is a reference to the global context for the entire graph. An operation must be done onto this node (e.g. a CALC or accessing a collection) before it can be executed."
+ "This node is a reference to the global context for the entire graph. An operation must be done onto this node (e.g. a CALCULATE or accessing a collection) before it can be executed."
)
case TableCollection():
collection_name = qualified_node.collection.name
@@ -367,7 +368,6 @@ def explain_unqualified(node: UnqualifiedNode, verbose: bool) -> str:
lines.append(
f"This node, specifically, accesses the unpartitioned data of a partitioning (child name: {qualified_node.partition_child_name})."
)
- lines.append("Using BACK(1) will access the partitioned data.")
case ChildOperator():
if len(qualified_node.children):
lines.append(
@@ -382,7 +382,7 @@ def explain_unqualified(node: UnqualifiedNode, verbose: bool) -> str:
lines.append(f" child ${idx + 1}: {child.to_string()}")
lines.append("")
match qualified_node:
- case Calc():
+ case Calculate():
lines.append(
"The main task of this node is to calculate the following additional expressions that are added to the terms of the collection:"
)
@@ -401,6 +401,10 @@ def explain_unqualified(node: UnqualifiedNode, verbose: bool) -> str:
suffix += " (propagated from previous collection)"
else:
suffix += f" (overwrites existing value of {name})"
+ elif isinstance(expr, BackReferenceExpression):
+ suffix = (
+ " (referencing an alias defined in an ancestor)"
+ )
lines.append(f" {name} <- {tree_string}{suffix}")
case Where():
lines.append(
@@ -474,23 +478,6 @@ def explain_unqualified(node: UnqualifiedNode, verbose: bool) -> str:
"\nThe collection does not have any terms that can be included in a result if it is executed."
)
- # Identify the number of BACK levels that are accessible
- back_counter: int = 0
- copy_node: PyDoughCollectionQDAG = qualified_node
- while copy_node.ancestor_context is not None:
- back_counter += 1
- copy_node = copy_node.ancestor_context
- if back_counter == 0:
- lines.append("\nIt is not possible to use BACK from this collection.")
- elif back_counter == 1:
- lines.append(
- "\nIt is possible to use BACK to go up to 1 level above this collection."
- )
- else:
- lines.append(
- f"\nIt is possible to use BACK to go up to {back_counter} levels above this collection."
- )
-
# Dump the collection & expression terms of the collection
expr_names: list[str] = []
collection_names: list[str] = []
diff --git a/pydough/exploration/term.py b/pydough/exploration/term.py
index 50169080..b0e12b69 100644
--- a/pydough/exploration/term.py
+++ b/pydough/exploration/term.py
@@ -9,7 +9,6 @@
import pydough.pydough_operators as pydop
from pydough.qdag import (
- BackReferenceCollection,
BackReferenceExpression,
ChildReferenceExpression,
ColumnProperty,
@@ -22,7 +21,7 @@
)
from pydough.unqualified import (
UnqualifiedAccess,
- UnqualifiedCalc,
+ UnqualifiedCalculate,
UnqualifiedNode,
UnqualifiedOrderBy,
UnqualifiedPartition,
@@ -51,7 +50,7 @@ def find_unqualified_root(node: UnqualifiedNode) -> UnqualifiedRoot | None:
return node
case (
UnqualifiedAccess()
- | UnqualifiedCalc()
+ | UnqualifiedCalculate()
| UnqualifiedWhere()
| UnqualifiedOrderBy()
| UnqualifiedTopK()
@@ -69,8 +68,8 @@ def collection_in_context_string(
"""
Converts a collection in the context of another collection into a single
string in a way that elides back collection references. For example,
- if the context is A.B.C.D, and the collection is BACK(2).E.F, the result
- would be "A.B.E.F".
+ if the context is A.B.WHERE(C), and the collection is D.E, the result
+ would be "A.B.WHERE(C).D.E".
Args:
`context`: the collection representing the context that `collection`
@@ -80,13 +79,7 @@ def collection_in_context_string(
Returns:
The desired string representation of context and collection combined.
"""
- if isinstance(collection, BackReferenceCollection):
- ancestor: PyDoughCollectionQDAG = context
- for _ in range(collection.back_levels):
- assert ancestor.ancestor_context is not None
- ancestor = ancestor.ancestor_context
- return f"{ancestor.to_string()}.{collection.term_name}"
- elif (
+ if (
collection.preceding_context is not None
and collection.preceding_context is not context
):
@@ -281,19 +274,19 @@ def explain_term(
lines.append("")
if qualified_term.is_singular(qualified_node.starting_predecessor):
lines.append(
- "This term is singular with regards to the collection, meaning it can be placed in a CALC of a collection."
+ "This term is singular with regards to the collection, meaning it can be placed in a CALCULATE of a collection."
)
lines.append("For example, the following is valid:")
lines.append(
- f" {qualified_node.to_string()}({qualified_term.to_string()})"
+ f" {qualified_node.to_string()}.CALCULATE({qualified_term.to_string()})"
)
else:
lines.append(
- "This expression is plural with regards to the collection, meaning it can be placed in a CALC of a collection if it is aggregated."
+ "This expression is plural with regards to the collection, meaning it can be placed in a CALCULATE of a collection if it is aggregated."
)
lines.append("For example, the following is valid:")
lines.append(
- f" {qualified_node.to_string()}(COUNT({qualified_term.to_string()}))"
+ f" {qualified_node.to_string()}.CALCULATE(COUNT({qualified_term.to_string()}))"
)
else:
assert isinstance(qualified_term, PyDoughCollectionQDAG)
@@ -317,7 +310,7 @@ def explain_term(
)
lines.append("For example, the following is valid:")
lines.append(
- f" {qualified_node.to_string()}({qualified_term.to_string()}.{chosen_term_name})"
+ f" {qualified_node.to_string()}.CALCULATE({qualified_term.to_string()}.{chosen_term_name})"
)
else:
lines.append(
@@ -325,7 +318,7 @@ def explain_term(
)
lines.append("For example, the following are valid:")
lines.append(
- f" {qualified_node.to_string()}(COUNT({qualified_term.to_string()}.{chosen_term_name}))"
+ f" {qualified_node.to_string()}.CALCULATE(COUNT({qualified_term.to_string()}.{chosen_term_name}))"
)
lines.append(
f" {qualified_node.to_string()}.WHERE(HAS({qualified_term.to_string()}))"
diff --git a/pydough/jupyter_extensions/README.md b/pydough/jupyter_extensions/README.md
index 48af0c74..32611e69 100644
--- a/pydough/jupyter_extensions/README.md
+++ b/pydough/jupyter_extensions/README.md
@@ -26,7 +26,7 @@ Once the extension is loaded, you can use the `%%pydough` magic command to run P
```python
%%pydough
-result = Nations(
+result = Nations.CALCULATE(
nation_name=name,
region_name=region.name,
num_customers=COUNT(customers)
@@ -40,7 +40,7 @@ The transformed code will look like this:
from pydough.unqualified import UnqualifiedRoot
_ROOT = UnqualifiedRoot(pydough.active_session.metadata)
-result = _ROOT.Nations(
+result = _ROOT.Nations.CALCULATE(
nation_name=_ROOT.name,
region_name=_ROOT.region.name,
num_customers=_ROOT.COUNT(_ROOT.customers)
diff --git a/pydough/pydough_operators/expression_operators/README.md b/pydough/pydough_operators/expression_operators/README.md
index 4bfa3fdd..d63abf4b 100644
--- a/pydough/pydough_operators/expression_operators/README.md
+++ b/pydough/pydough_operators/expression_operators/README.md
@@ -136,7 +136,7 @@ These functions return an expression and use logic that produces a value that de
- `RANKING(by=..., levels=None, allow_ties=False, dense=False)`: returns the ordinal position of the current record when all records are sorted by the collation expressions in the `by` argument. By default, uses the same semantics as `ROW_NUMBER`. If `allow_ties=True`, instead uses `RANK`. If `allow_ties=True` and `dense=True`, instead uses `DENSE_RANK`.
- `PERCENTILE(by=..., levels=None, n_buckets=100)`: splits the data into `n_buckets` equal sized sections by ordering the data by the `by` arguments, where bucket `1` is the smallest data and bucket `n_buckets` is the largest. This is useful for understanding the relative position of a value within a group, like finding the top 10% of performers in a class.
-For an example of how `levels` works, when doing `Regions.nations.customers(r=RANKING(by=...))`:
+For an example of how `levels` works, when doing `Regions.nations.customers.CALCULATE(r=RANKING(by=...))`:
- If `levels=None` or `levels=3`, `r` is the ranking across all `customers`.
- If `levels=1`, `r` is the ranking of customers per-nation (meaning the ranking resets to 1 within each nation).
diff --git a/pydough/qdag/README.md b/pydough/qdag/README.md
index b27f6b82..dc936da0 100644
--- a/pydough/qdag/README.md
+++ b/pydough/qdag/README.md
@@ -77,10 +77,10 @@ sub_collection = builder.build_child_access("region", table_collection)
child_collection = ChildOperatorChildAccess(sub_collection)
child_reference_node = builder.build_child_reference_expression([child_collection], 0, "name")
-# Build a CALC node
-# Equivalent PyDough code: `TPCH.Nations(region_name=region.name)`
-calc_node = builder.build_calc(table_collection, [child_collection])
-calc_node = calc_node.with_terms([("region_name", child_reference_node)])
+# Build a CALCULATE node
+# Equivalent PyDough code: `TPCH.Nations.CALCULATE(region_name=region.name)`
+calculate_node = builder.build_calc(table_collection, [child_collection])
+calculate_node = calculate_node.with_terms([("region_name", child_reference_node)])
# Build a WHERE node
# Equivalent PyDough code: `TPCH.Nations.WHERE(region.name == "ASIA")`
@@ -111,12 +111,8 @@ partition_key = builder.build_reference(part_collection, "part_type")
partition_by_node = builder.build_partition(part_collection, child_collection, "p")
partition_by_node = partition_by_node.with_keys([partition_key])
-# Build a back reference collection node
-# Equivalent PyDough code: `BACK(1).subcollection`
-back_reference_collection_node = builder.build_back_reference_collection(table_collection, "subcollection", 1)
-
# Build a child reference collection node
-# Equivalent PyDough code: `Nations(n_customers=COUNT(customers))`
+# Equivalent PyDough code: `Nations.CALCULATE(n_customers=COUNT(customers))`
customers_sub_collection = builder.build_child_access("customers", table_collection)
customers_child = ChildOperatorChildAccess(customers_sub_collection)
child_reference_collection_node = builder.build_child_reference_collection(
@@ -127,8 +123,8 @@ count_call = builder.build_expression_function_call(
"COUNT",
[child_reference_collection_node]
)
-calc_node = builder.build_calc(table_collection, [customers_child])
-calc_node = calc_node.with_terms([("n_customers", count_call)])
+calculate_node = builder.build_calc(table_collection, [customers_child])
+calculate_node = calculate_node.with_terms([("n_customers", count_call)])
# Build a window function call node
# Equivalent PyDough code: `RANKING(by=TPCH.Nations.name, levels=1, allow_ties=True)`
@@ -144,11 +140,11 @@ Below are some examples of PyDough snippets that are/aren't affected by the rewr
```python
-# Will be rewritten to `Customers(name, has_orders=COUNT(orders) > 0)`
-Customers(name, has_orders=HAS(orders))
+# Will be rewritten to `Customers.CALCULATE(name, has_orders=COUNT(orders) > 0)`
+Customers.CALCULATE(name, has_orders=HAS(orders))
-# Will be rewritten to `Customers(name, never_made_order=COUNT(orders) == 0)`
-Customers(name, never_made_order=HASNOT(orders))
+# Will be rewritten to `Customers.CALCULATE(name, never_made_order=COUNT(orders) == 0)`
+Customers.CALCULATE(name, never_made_order=HASNOT(orders))
# Will not be rewritten
Customers.WHERE(HAS(orders) & (nation.region.name == "EUROPE"))
diff --git a/pydough/qdag/__init__.py b/pydough/qdag/__init__.py
index bf588c43..28793308 100644
--- a/pydough/qdag/__init__.py
+++ b/pydough/qdag/__init__.py
@@ -6,9 +6,8 @@
__all__ = [
"AstNodeBuilder",
- "BackReferenceCollection",
"BackReferenceExpression",
- "Calc",
+ "Calculate",
"ChildAccess",
"ChildOperator",
"ChildOperatorChildAccess",
@@ -39,8 +38,7 @@
from .abstract_pydough_qdag import PyDoughQDAG
from .collections import (
- BackReferenceCollection,
- Calc,
+ Calculate,
ChildAccess,
ChildOperator,
ChildOperatorChildAccess,
diff --git a/pydough/qdag/collections/README.md b/pydough/qdag/collections/README.md
index 155de9df..b780cd9b 100644
--- a/pydough/qdag/collections/README.md
+++ b/pydough/qdag/collections/README.md
@@ -16,10 +16,8 @@ The QDAG collections module contains the following hierarchy of collection class
- [`TableCollection`](table_collection.py) (concrete): Accessing a table collection directly.
- [`SubCollection`](sub_collection.py) (concrete): Accessing a subcolleciton of another collection.
- [`CompoundSubCollection`](sub_collection.py) (concrete): Accessing a subcollection of another collection where the subcollection property is a compound relationship.
- - [`BackReferenceCollection`](back_reference_collection.py) (concrete): Same idea as `ChildReferenceCollection`, but on a subcollection of an ancestor collection
- - [`HiddenBackReferenceCollection`](hidden_back_reference_collection.py) (concrete): Same idea as `BackReferenceCollection`, but where the back reference is hidden because it is a subcollection reference where the subcollection comes from a hidden ancestor of a compound relationship.
- [`ChildOperator`](child_operator.py) (abstract): Base class for collection QDAG nodes that need to access child contexts in order to make a child reference.
- - [`Calc`](calc.py) (concrete): Operation that defines new singular expression terms in the current context and names them.
+ - [`Calculate`](calculate.py) (concrete): Operation that defines new singular expression terms in the current context and names them.
- [`Where`](where.py) (concrete): Operation that filters the current context based on a predicate that is a singular expression.
- [`OrderBy`](order_by.py) (concrete): Operation that sorts the current context based on 1+ singular collation expressions.
- [`TopK`](top_k.py) (concrete): Operation that sorts the current context based on 1+ singular collation expressions and filters to only keep the first `k` records.
@@ -33,8 +31,8 @@ The base QDAG collection node contains the following interface:
- `all_terms`: Property that returns the set of all names of terms of the collection (collections or expressions).
- `is_singular`: Method that takes in a context and returns whether the current collection is singular with regards to that context. (Note: it is assumed that `.starting_predecessor` has been called on all the arguments already).
- `starting_predecessor`: Property that finds the furthest predecessor of the curren collection.
-- `verify_singular_terms`: Method that takes in a sequence of expression QDAG nodes and verifies that all of them are singular with regards to the current context (e.g. can they be used as CALC terms).
-- `get_expression_position`: Method that takes in the string name of a calc term and returns its ordinal position when placed in the output.
+- `verify_singular_terms`: Method that takes in a sequence of expression QDAG nodes and verifies that all of them are singular with regards to the current context (e.g. can they be used as CALCULATE terms).
+- `get_expression_position`: Method that takes in the string name of a CALCULATE term and returns its ordinal position when placed in the output.
- `get_term`: Method that takes in the string name of any term of the current context and returns the QDAG node for it with regards to the current context. E.g. if calling on the name of a subcollection, returns the subcollection node.
- `get_expr`: Same as `get_term` but specifically for expressions-only.
- `get_collection`: Same as `get_term` but specifically for collections-only.
@@ -57,22 +55,25 @@ The objects are created by calling the `to_tree_form` API of a collection QDAG n
Below is an example of a PyDough snippet and the corresponding tree string representation:
```python
-Nations.WHERE(
+Nations.CALCULATE(
+ nation_name=name,
+).WHERE(
region.name == "EUROPE"
-).suppliers(
+).suppliers.CALCULATE(
supplier_name=name,
- nation_name=BACK(1).name
+ nation_name=nation_name,
)
```
```
──┬─ TPCH
├─── TableCollection[Nations]
+ ├─── Calculate[nation_name=name]
└─┬─ Where[$1.name == 'EUROPE']
├─┬─ AccessChild
│ └─── SubCollection[region]
├─── SubCollection[suppliers]
- └─── Calc[supplier_name=name, nation_name=BACK(1).name]
+ └─── Calculate[supplier_name=name, nation_name=nation_name]
```
And below is another such example:
@@ -80,7 +81,7 @@ And below is another such example:
```python
german_suppliers = supply_records.WHERE(supplier.nation == "GERMANY")
selected_parts = parts.WHERE(HAS(german_suppliers))
-PARTITION(selected_parts, name="p", by=size)(
+PARTITION(selected_parts, name="p", by=size).CALCULATE(
size,
n_parts_with_german_supplier=COUNT(p)
).TOP_K(
@@ -101,7 +102,7 @@ PARTITION(selected_parts, name="p", by=size)(
│ └─┬─ AccessChild
│ └─┬─ SubCollection[supplier]
│ └─── SubCollection[nation]
- ├─┬─ Calc[size=size, n_parts_with_german_supplier=COUNT($1)]
+ ├─┬─ Calculate[size=size, n_parts_with_german_supplier=COUNT($1)]
│ └─┬─ AccessChild
│ └─── PartitionChild[p]
└─── TopK[10, n_parts_with_german_supplier.DESC(na_pos='last')]
diff --git a/pydough/qdag/collections/__init__.py b/pydough/qdag/collections/__init__.py
index 8412efed..dc7bba03 100644
--- a/pydough/qdag/collections/__init__.py
+++ b/pydough/qdag/collections/__init__.py
@@ -5,8 +5,7 @@
__all__ = [
"AugmentingChildOperator",
- "BackReferenceCollection",
- "Calc",
+ "Calculate",
"ChildAccess",
"ChildOperator",
"ChildOperatorChildAccess",
@@ -25,8 +24,7 @@
]
from .augmenting_child_operator import AugmentingChildOperator
-from .back_reference_collection import BackReferenceCollection
-from .calc import Calc
+from .calculate import Calculate
from .child_access import ChildAccess
from .child_operator import ChildOperator
from .child_operator_child_access import ChildOperatorChildAccess
diff --git a/pydough/qdag/collections/augmenting_child_operator.py b/pydough/qdag/collections/augmenting_child_operator.py
index a73d4c5d..176f9796 100644
--- a/pydough/qdag/collections/augmenting_child_operator.py
+++ b/pydough/qdag/collections/augmenting_child_operator.py
@@ -1,6 +1,6 @@
"""
Defines an abstract subclass of ChildOperator for operations that augment their
-preceding context without stepping down into another context, like CALC or
+preceding context without stepping down into another context, like CALCULATE or
WHERE.
"""
@@ -10,7 +10,7 @@
from functools import cache
from pydough.qdag.abstract_pydough_qdag import PyDoughQDAG
-from pydough.qdag.expressions import CollationExpression
+from pydough.qdag.expressions import CollationExpression, PyDoughExpressionQDAG
from .child_access import ChildAccess
from .child_operator import ChildOperator
@@ -40,10 +40,26 @@ def ancestor_context(self) -> PyDoughCollectionQDAG | None:
def preceding_context(self) -> PyDoughCollectionQDAG:
return self._preceding_context
+ @property
+ def ancestral_mapping(self) -> dict[str, int]:
+ return self.preceding_context.ancestral_mapping
+
+ @property
+ def inherited_downstreamed_terms(self) -> set[str]:
+ return self.preceding_context.inherited_downstreamed_terms
+
@property
def ordering(self) -> list[CollationExpression] | None:
return self.preceding_context.ordering
+ @property
+ def calc_terms(self) -> set[str]:
+ return self.preceding_context.calc_terms
+
+ @property
+ def all_terms(self) -> set[str]:
+ return self.preceding_context.all_terms
+
@property
def unique_terms(self) -> list[str]:
return self.preceding_context.unique_terms
@@ -58,7 +74,7 @@ def get_expression_position(self, expr_name: str) -> int:
@cache
def get_term(self, term_name: str) -> PyDoughQDAG:
- from pydough.qdag.expressions import PyDoughExpressionQDAG, Reference
+ from pydough.qdag.expressions import Reference
term: PyDoughQDAG = self.preceding_context.get_term(term_name)
if isinstance(term, ChildAccess):
@@ -67,6 +83,10 @@ def get_term(self, term_name: str) -> PyDoughQDAG:
term = Reference(self.preceding_context, term_name)
return term
+ @cache
+ def to_string(self) -> str:
+ return f"{self.preceding_context.to_string()}.{self.standalone_string}"
+
def to_tree_form(self, is_last: bool) -> CollectionTreeForm:
predecessor: CollectionTreeForm = self.preceding_context.to_tree_form(is_last)
predecessor.has_successor = True
diff --git a/pydough/qdag/collections/back_reference_collection.py b/pydough/qdag/collections/back_reference_collection.py
deleted file mode 100644
index 31fb2c4a..00000000
--- a/pydough/qdag/collections/back_reference_collection.py
+++ /dev/null
@@ -1,92 +0,0 @@
-"""
-Definition of PyDough QDAG collection type for accesses to a subcollection of an
-ancestor of the current context.
-"""
-
-__all__ = ["BackReferenceCollection"]
-
-
-from functools import cache
-
-from pydough.qdag.errors import PyDoughQDAGException
-
-from .child_access import ChildAccess
-from .collection_access import CollectionAccess
-from .collection_qdag import PyDoughCollectionQDAG
-
-
-class BackReferenceCollection(CollectionAccess):
- """
- The QDAG node implementation class representing a subcollection of an
- ancestor collection.
- """
-
- def __init__(
- self,
- parent: PyDoughCollectionQDAG,
- term_name: str,
- back_levels: int,
- ):
- if not (isinstance(back_levels, int) and back_levels > 0):
- raise PyDoughQDAGException(
- f"Expected number of levels in BACK to be a positive integer, received {back_levels!r}"
- )
- self._term_name: str = term_name
- self._back_levels: int = back_levels
- ancestor: PyDoughCollectionQDAG = parent
- for _ in range(back_levels):
- if ancestor.ancestor_context is None:
- msg: str = "1 level" if back_levels == 1 else f"{back_levels} levels"
- raise PyDoughQDAGException(
- f"Cannot reference back {msg} above {parent!r}"
- )
- ancestor = ancestor.ancestor_context
- access = ancestor.get_collection(term_name)
- assert isinstance(access, CollectionAccess)
- self._collection_access: CollectionAccess = access
- super().__init__(self._collection_access.collection, ancestor)
-
- def clone_with_parent(self, new_ancestor: PyDoughCollectionQDAG) -> ChildAccess:
- return BackReferenceCollection(new_ancestor, self.term_name, self.back_levels)
-
- @property
- def back_levels(self) -> int:
- """
- The number of levels upward that the backreference refers to.
- """
- return self._back_levels
-
- @property
- def term_name(self) -> str:
- """
- The name of the subcollection being accessed from the ancestor.
- """
- return self._term_name
-
- @property
- def collection_access(self) -> CollectionAccess:
- """
- The collection access property of the ancestor that BACK points to.
- """
- return self._collection_access
-
- @cache
- def is_singular(self, context: PyDoughCollectionQDAG) -> bool:
- return self.collection_access.is_singular(
- self.collection_access.ancestor_context.starting_predecessor
- )
-
- @property
- def key(self) -> str:
- return self.standalone_string
-
- @property
- def standalone_string(self) -> str:
- return f"BACK({self.back_levels}).{self.term_name}"
-
- def to_string(self) -> str:
- return self.standalone_string
-
- @property
- def tree_item_string(self) -> str:
- return f"BackSubCollection[{self.back_levels}, {self.term_name}]"
diff --git a/pydough/qdag/collections/calc.py b/pydough/qdag/collections/calculate.py
similarity index 57%
rename from pydough/qdag/collections/calc.py
rename to pydough/qdag/collections/calculate.py
index c1296e54..96ad0917 100644
--- a/pydough/qdag/collections/calc.py
+++ b/pydough/qdag/collections/calculate.py
@@ -1,26 +1,30 @@
"""
-Definition of PyDough QDAG collection type for a CALC, which defines new
+Definition of PyDough QDAG collection type for a CALCULATE, which defines new
expressions of the current context (or overrides existing definitions) that are
all singular with regards to it.
"""
-__all__ = ["Calc"]
+__all__ = ["Calculate"]
from collections.abc import MutableMapping, MutableSequence
from functools import cache
from pydough.qdag.abstract_pydough_qdag import PyDoughQDAG
from pydough.qdag.errors import PyDoughQDAGException
-from pydough.qdag.expressions import PyDoughExpressionQDAG
+from pydough.qdag.expressions import (
+ BackReferenceExpression,
+ PyDoughExpressionQDAG,
+ Reference,
+)
from pydough.qdag.has_hasnot_rewrite import has_hasnot_rewrite
from .augmenting_child_operator import AugmentingChildOperator
from .collection_qdag import PyDoughCollectionQDAG
-class Calc(AugmentingChildOperator):
+class Calculate(AugmentingChildOperator):
"""
- The QDAG node implementation class representing a CALC expression.
+ The QDAG node implementation class representing a CALCULATE expression.
"""
def __init__(
@@ -33,43 +37,64 @@ def __init__(
self._calc_term_indices: dict[str, int] | None = None
self._calc_term_values: MutableMapping[str, PyDoughExpressionQDAG] | None = None
self._all_term_names: set[str] = set()
+ self._ancestral_mapping: dict[str, int] = dict(
+ predecessor.ancestral_mapping.items()
+ )
+ self._calc_terms: set[str] = set()
def with_terms(
self, terms: MutableSequence[tuple[str, PyDoughExpressionQDAG]]
- ) -> "Calc":
+ ) -> "Calculate":
"""
- Specifies the terms that are calculated inside of a CALC node,
- returning the mutated CALC node afterwards. This is called after the
- CALC node is created so that the terms can be expressions that
- reference child nodes of the CALC. However, this must be called
- on the CALC node before any properties are accessed by `calc_terms`,
- `all_terms`, `to_string`, etc.
+ Specifies the terms that are calculated inside of a CALCULATE node,
+ returning the mutated CALCULATE node afterwards. This is called after
+ the CALCULATE node is created so that the terms can be expressions that
+ reference child nodes of the CALCULATE. However, this must be called
+ on the CALCULATE node before any properties are accessed by
+ `calc_terms`, `all_terms`, `to_string`, etc.
Args:
- `terms`: the list of terms calculated in the CALC node as a list of
- tuples in the form `(name, expression)`. Each `expression` can
- contain `ChildReferenceExpression` instances that refer to an property of one
- of the children of the CALC node.
+ `terms`: the list of terms calculated in the CALCULATE node as a
+ list of tuples in the form `(name, expression)`. Each `expression`
+ can contain `ChildReferenceExpression` instances that refer to a
+ property of one of the children of the CALCULATE node.
Returns:
- The mutated CALC node (which has also been modified in-place).
+ The mutated CALCULATE node (which has also been modified in-place).
Raises:
`PyDoughQDAGException` if the terms have already been added to the
- CALC node.
+ CALCULATE node.
"""
if self._calc_term_indices is not None:
raise PyDoughQDAGException(
- "Cannot call `with_terms` on a CALC node more than once"
+ "Cannot call `with_terms` on a CALCULATE node more than once"
)
- # Include terms from the predecessor, with the terms from this CALC
- # added in.
+ # Include terms from the predecessor, with the terms from this
+ # CALCULATE added in.
self._calc_term_indices = {}
self._calc_term_values = {}
for idx, (name, value) in enumerate(terms):
+ ancestral_idx: int = self.ancestral_mapping.get(name, 0)
+ if ancestral_idx > 0:
+ # Ignore no-op back-references, e.g.:
+                # region.CALCULATE(region_name=name).customers.CALCULATE(region_name=region_name)
+ if not (
+ (
+ isinstance(value, BackReferenceExpression)
+ and value.back_levels == ancestral_idx
+ and value.term_name == name
+ )
+ or isinstance(value, Reference)
+ ):
+ raise PyDoughQDAGException(
+ f"Cannot redefine term {name!r} in CALCULATE that is already defined in an ancestor"
+ )
self._calc_term_indices[name] = idx
self._calc_term_values[name] = has_hasnot_rewrite(value, False)
self._all_term_names.add(name)
+ self._calc_terms.add(name)
+ self.ancestral_mapping[name] = 0
self.all_terms.update(self.preceding_context.all_terms)
self.verify_singular_terms(self._calc_term_values.values())
return self
@@ -79,12 +104,12 @@ def calc_term_indices(
self,
) -> dict[str, int]:
"""
- Mapping of each named expression of the CALC to the index of the
- ordinal position of the property when included in a CALC.
+ Mapping of each named expression of the CALCULATE to the index of the
+ ordinal position of the property when included in a CALCULATE.
"""
if self._calc_term_indices is None:
raise PyDoughQDAGException(
- "Cannot access `calc_term_indices` of a Calc node before adding calc terms with `with_terms`"
+ "Cannot access `calc_term_indices` of a CALCULATE node before adding calc terms with `with_terms`"
)
return self._calc_term_indices
@@ -93,30 +118,34 @@ def calc_term_values(
self,
) -> MutableMapping[str, PyDoughExpressionQDAG]:
"""
- Mapping of each named expression of the CALC to the QDAG node for
+ Mapping of each named expression of the CALCULATE to the QDAG node for
that expression.
"""
if self._calc_term_values is None:
raise PyDoughQDAGException(
- "Cannot access `_calc_term_values` of a Calc node before adding calc terms with `with_terms`"
+ "Cannot access `_calc_term_values` of a CALCULATE node before adding calc terms with `with_terms`"
)
return self._calc_term_values
@property
def key(self) -> str:
- return f"{self.preceding_context.key}.CALC"
+ return f"{self.preceding_context.key}.CALCULATE"
@property
def calc_terms(self) -> set[str]:
- return set(self.calc_term_indices)
+ return self._calc_terms
@property
def all_terms(self) -> set[str]:
return self._all_term_names
+ @property
+ def ancestral_mapping(self) -> dict[str, int]:
+ return self._ancestral_mapping
+
def get_expression_position(self, expr_name: str) -> int:
if expr_name not in self.calc_terms:
- raise PyDoughQDAGException(f"Unrecognized CALC term: {expr_name!r}")
+ raise PyDoughQDAGException(f"Unrecognized CALCULATE term: {expr_name!r}")
return self.calc_term_indices[expr_name]
@cache
@@ -128,7 +157,7 @@ def get_term(self, term_name: str) -> PyDoughQDAG:
def calc_kwarg_strings(self, tree_form: bool) -> str:
"""
- Converts the terms of a CALC into a string in the form
+ Converts the terms of a CALCULATE into a string in the form
`"x=1, y=phone_number, z=STARTSWITH(LOWER(name), 'a')"`
Args:
@@ -147,16 +176,13 @@ def calc_kwarg_strings(self, tree_form: bool) -> str:
return ", ".join(kwarg_strings)
@property
+ @cache
def standalone_string(self) -> str:
- return f"({self.calc_kwarg_strings(False)})"
-
- def to_string(self) -> str:
- assert self.preceding_context is not None
- return f"{self.preceding_context.to_string()}{self.standalone_string}"
+ return f"CALCULATE({self.calc_kwarg_strings(False)})"
@property
def tree_item_string(self) -> str:
- return f"Calc[{self.calc_kwarg_strings(True)}]"
+ return f"Calculate[{self.calc_kwarg_strings(True)}]"
def equals(self, other: object) -> bool:
if self._calc_term_indices is None:
@@ -165,7 +191,7 @@ def equals(self, other: object) -> bool:
)
return (
super().equals(other)
- and isinstance(other, Calc)
+ and isinstance(other, Calculate)
and self._calc_term_indices == other._calc_term_indices
and self._calc_term_values == other._calc_term_values
)
diff --git a/pydough/qdag/collections/child_operator.py b/pydough/qdag/collections/child_operator.py
index 77aecb00..3e39e3dd 100644
--- a/pydough/qdag/collections/child_operator.py
+++ b/pydough/qdag/collections/child_operator.py
@@ -14,7 +14,7 @@
class ChildOperator(PyDoughCollectionQDAG):
"""
Base class for PyDough collection QDAG nodes that have access to
- child collections, such as CALC or WHERE.
+ child collections, such as CALCULATE or WHERE.
"""
def __init__(
diff --git a/pydough/qdag/collections/child_operator_child_access.py b/pydough/qdag/collections/child_operator_child_access.py
index 4a972f13..e041abd5 100644
--- a/pydough/qdag/collections/child_operator_child_access.py
+++ b/pydough/qdag/collections/child_operator_child_access.py
@@ -10,7 +10,6 @@
from pydough.qdag.abstract_pydough_qdag import PyDoughQDAG
-from .back_reference_collection import BackReferenceCollection
from .child_access import ChildAccess
from .collection_qdag import PyDoughCollectionQDAG
from .collection_tree_form import CollectionTreeForm
@@ -53,6 +52,14 @@ def calc_terms(self) -> set[str]:
def all_terms(self) -> set[str]:
return self.child_access.all_terms
+ @property
+ def ancestral_mapping(self) -> dict[str, int]:
+ return self.child_access.ancestral_mapping
+
+ @property
+ def inherited_downstreamed_terms(self) -> set[str]:
+ return self.child_access.inherited_downstreamed_terms
+
@property
def unique_terms(self) -> list[str]:
return self.child_access.unique_terms
@@ -78,9 +85,7 @@ def is_singular(self, context: PyDoughCollectionQDAG) -> bool:
assert ancestor is not None
relative_context: PyDoughCollectionQDAG = ancestor.starting_predecessor
return self.child_access.is_singular(relative_context) and (
- isinstance(self.child_access, BackReferenceCollection)
- or (context == relative_context)
- or relative_context.is_singular(context)
+ (context == relative_context) or relative_context.is_singular(context)
)
@property
@@ -89,7 +94,7 @@ def standalone_string(self) -> str:
def to_string(self) -> str:
# Does not include the parent since this exists within the context
- # of a CALC node.
+ # of an operator such as a CALCULATE node.
return self.standalone_string
@property
diff --git a/pydough/qdag/collections/child_reference_collection.py b/pydough/qdag/collections/child_reference_collection.py
index bc23898d..fda74e82 100644
--- a/pydough/qdag/collections/child_reference_collection.py
+++ b/pydough/qdag/collections/child_reference_collection.py
@@ -9,7 +9,7 @@
from functools import cache
from pydough.qdag.abstract_pydough_qdag import PyDoughQDAG
-from pydough.qdag.expressions.collation_expression import CollationExpression
+from pydough.qdag.expressions import CollationExpression
from .child_access import ChildAccess
from .collection_qdag import PyDoughCollectionQDAG
@@ -19,7 +19,7 @@
class ChildReferenceCollection(ChildAccess):
"""
The QDAG node implementation class representing a reference to a collection
- term in a child collection of a CALC or other child operator.
+ term in a child collection of a CALCULATE or other child operator.
"""
def __init__(
@@ -45,7 +45,7 @@ def collection(self) -> PyDoughCollectionQDAG:
@property
def child_idx(self) -> int:
"""
- The integer index of the child from the CALC that the
+ The integer index of the child from the CALCULATE that the
ChildReferenceCollection refers to.
"""
return self._child_idx
@@ -62,6 +62,14 @@ def calc_terms(self) -> set[str]:
def all_terms(self) -> set[str]:
return self.collection.all_terms
+ @property
+ def ancestral_mapping(self) -> dict[str, int]:
+ return self.collection.ancestral_mapping
+
+ @property
+ def inherited_downstreamed_terms(self) -> set[str]:
+ return self.collection.inherited_downstreamed_terms
+
@property
def ordering(self) -> list[CollationExpression] | None:
return self.collection.ordering
diff --git a/pydough/qdag/collections/collection_access.py b/pydough/qdag/collections/collection_access.py
index e714ee04..c976bb7c 100644
--- a/pydough/qdag/collections/collection_access.py
+++ b/pydough/qdag/collections/collection_access.py
@@ -18,7 +18,11 @@
from pydough.metadata.properties import SubcollectionRelationshipMetadata
from pydough.qdag.abstract_pydough_qdag import PyDoughQDAG
from pydough.qdag.errors import PyDoughQDAGException
-from pydough.qdag.expressions import CollationExpression, ColumnProperty
+from pydough.qdag.expressions import (
+ BackReferenceExpression,
+ CollationExpression,
+ ColumnProperty,
+)
from .child_access import ChildAccess
from .collection_qdag import PyDoughCollectionQDAG
@@ -41,6 +45,10 @@ def __init__(
self._all_property_names: set[str] = set()
self._calc_property_names: set[str] = set()
self._calc_property_order: dict[str, int] = {}
+ self._ancestral_mapping: dict[str, int] = {
+ name: level + 1 for name, level in ancestor.ancestral_mapping.items()
+ }
+ self._all_property_names.update(self._ancestral_mapping)
for property_name in sorted(
collection.get_property_names(),
key=lambda name: collection.definition_order[name],
@@ -70,6 +78,14 @@ def calc_terms(self) -> set[str]:
def all_terms(self) -> set[str]:
return self._all_property_names
+ @property
+ def ancestral_mapping(self) -> dict[str, int]:
+ return self._ancestral_mapping
+
+ @property
+ def inherited_downstreamed_terms(self) -> set[str]:
+ return self.ancestor_context.inherited_downstreamed_terms
+
@property
def ordering(self) -> list[CollationExpression] | None:
return None
@@ -92,6 +108,20 @@ def get_term(self, term_name: str) -> PyDoughQDAG:
from .compound_sub_collection import CompoundSubCollection
from .sub_collection import SubCollection
+ # Special handling of terms down-streamed from an ancestor CALCULATE
+ # clause.
+ if term_name in self.ancestral_mapping:
+ # Verify that the ancestor name is not also a name in the current
+ # context.
+ if term_name in self.calc_terms:
+ raise PyDoughQDAGException(
+ f"Cannot have term name {term_name!r} used in an ancestor of collection {self!r}"
+ )
+ # Create a back-reference to the ancestor term.
+ return BackReferenceExpression(
+ self, term_name, self.ancestral_mapping[term_name]
+ )
+
if term_name not in self.all_terms:
raise PyDoughQDAGException(
f"Unrecognized term of {self.collection.error_name}: {term_name!r}"
diff --git a/pydough/qdag/collections/collection_qdag.py b/pydough/qdag/collections/collection_qdag.py
index 49daba8d..7caf265b 100644
--- a/pydough/qdag/collections/collection_qdag.py
+++ b/pydough/qdag/collections/collection_qdag.py
@@ -30,9 +30,9 @@ def __repr__(self):
@property
def ancestor_context(self) -> Union["PyDoughCollectionQDAG", None]:
"""
- The ancestor context from which this collection is derived, e.g. what
- is accessed by `BACK(1)`. Returns None if there is no ancestor context,
- e.g. because the collection is the top of the hierarchy.
+ The ancestor context from which this collection is derived. Returns
+ None if there is no ancestor context because the collection is the top
+ of the hierarchy.
"""
@property
@@ -40,7 +40,7 @@ def ancestor_context(self) -> Union["PyDoughCollectionQDAG", None]:
def preceding_context(self) -> Union["PyDoughCollectionQDAG", None]:
"""
The preceding context from which this collection is derived, e.g. an
- ORDER BY term before a CALC. Returns None if there is no preceding
+ ORDER BY term before a CALCULATE. Returns None if there is no preceding
context, e.g. because the collection is the start of a pipeline
within a larger ancestor context.
"""
@@ -51,7 +51,8 @@ def calc_terms(self) -> set[str]:
"""
The list of expressions that would be retrieved if the collection
were to have its results evaluated. This is the set of names in the
- most-recent CALC, potentially with extra expressions added since then.
+ most-recent CALCULATE, potentially with extra expressions added since
+ then.
"""
@property
@@ -61,6 +62,28 @@ def all_terms(self) -> set[str]:
The set of expression/subcollection names accessible by the context.
"""
+ @property
+ @abstractmethod
+ def ancestral_mapping(self) -> dict[str, int]:
+ """
+ A mapping of names created by the current context and its ancestors
+ describing terms defined inside a CALCULATE clause that are available
+ to the current context & descendants to back-reference via that name
+ to the number of ancestors up required to find the back-referenced
+ term.
+ """
+
+ @property
+ @abstractmethod
+ def inherited_downstreamed_terms(self) -> set[str]:
+ """
+ A set of names created by indirect ancestors of the current context
+ that can be used to back-reference. The specific index of the
+ back-reference is handled during the hybrid conversion process, when
+ implicit back-references are flushed to populate the base of the tree
+ input to a PARTITION node.
+ """
+
@abstractmethod
def is_singular(self, context: "PyDoughCollectionQDAG") -> bool:
"""
@@ -98,7 +121,7 @@ def starting_predecessor(self) -> "PyDoughCollectionQDAG":
def verify_singular_terms(self, exprs: Iterable[PyDoughExpressionQDAG]) -> None:
"""
Verifies that a list of expressions is singular with regards to the
- current collection, e.g. they can used as CALC terms.
+ current collection, e.g. they can be used as CALCULATE terms.
Args:
`exprs`: the list of expression to be checked.
@@ -254,8 +277,14 @@ def to_tree_string(self) -> str:
structured. For example, consider the following PyDough snippet:
```
- Regions.WHERE(ENDSWITH(name, 's')).nations.WHERE(name != 'USA')(
- a=BACK(1).name,
+ Regions.CALCULATE(
+ region_name=name,
+ ).WHERE(
+ ENDSWITH(name, 's')
+ ).nations.WHERE(
+ name != 'USA'
+ ).CALCULATE(
+ a=region_name,
b=name,
c=MAX(YEAR(suppliers.WHERE(STARTSWITH(phone, '415')).supply_records.lines.ship_date)),
d=COUNT(customers.WHERE(acctbal > 0))
@@ -271,10 +300,11 @@ def to_tree_string(self) -> str:
```
──┬─ TPCH
├─── TableCollection[Regions]
+ ├─── Calculate[region_name=name]
└─┬─ Where[ENDSWITH(name, 's')]
├─── SubCollection[nations]
├─── Where[name != 'USA']
- ├─┬─ Calc[a=[BACK(1).name], b=[name], c=[MAX($2._expr1)], d=[COUNT($1)]]
+ ├─┬─ Calculate[a=[region_name], b=[name], c=[MAX($2._expr1)], d=[COUNT($1)]]
│ ├─┬─ AccessChild
│ │ ├─ SubCollection[customers]
│ │ └─── Where[acctbal > 0]
@@ -283,7 +313,7 @@ def to_tree_string(self) -> str:
│ ├─── Where[STARTSWITH(phone, '415')]
│ └─┬─ SubCollection[supply_records]
│ └─┬─ SubCollection[lines]
- │ └─── Calc[_expr1=YEAR(ship_date)]
+ │ └─── Calculate[_expr1=YEAR(ship_date)]
├─── Where[c > 1000]
└─── OrderBy[d.DESC()]
```
diff --git a/pydough/qdag/collections/compound_sub_collection.py b/pydough/qdag/collections/compound_sub_collection.py
index ea40f9cf..7163b65c 100644
--- a/pydough/qdag/collections/compound_sub_collection.py
+++ b/pydough/qdag/collections/compound_sub_collection.py
@@ -180,12 +180,8 @@ def get_term(self, term_name: str) -> PyDoughQDAG:
original_name: str = self._inheritance_source_name[term_name]
expr = ancestor.get_term(original_name)
if isinstance(expr, PyDoughCollectionQDAG):
- from .hidden_back_reference_collection import (
- HiddenBackReferenceCollection,
- )
-
- return HiddenBackReferenceCollection(
- self, term_name, original_name, back_levels
+ raise NotImplementedError(
+ f"Cannot access subcollection property {term_name} of compound subcollection {self.subcollection_property.name}"
)
else:
return HiddenBackReferenceExpression(
diff --git a/pydough/qdag/collections/global_context.py b/pydough/qdag/collections/global_context.py
index ca0ebf1f..bad09573 100644
--- a/pydough/qdag/collections/global_context.py
+++ b/pydough/qdag/collections/global_context.py
@@ -67,6 +67,16 @@ def calc_terms(self) -> set[str]:
# A global context does not have any calc terms
return set()
+ @property
+ def ancestral_mapping(self) -> dict[str, int]:
+ # A global context does not have any ancestral terms
+ return {}
+
+ @property
+ def inherited_downstreamed_terms(self) -> set[str]:
+ # A global context does not have any inherited downstreamed terms
+ return set()
+
@property
def all_terms(self) -> set[str]:
return set(self.collections)
diff --git a/pydough/qdag/collections/hidden_back_reference_collection.py b/pydough/qdag/collections/hidden_back_reference_collection.py
deleted file mode 100644
index d87ab9ea..00000000
--- a/pydough/qdag/collections/hidden_back_reference_collection.py
+++ /dev/null
@@ -1,94 +0,0 @@
-"""
-Definition of PyDough QDAG collection type for accesses to a subcollection of an
-ancestor of the current context in a manner that is hidden because the ancestor
-is from a compound subcollection access.
-"""
-
-__all__ = ["HiddenBackReferenceCollection"]
-
-
-from pydough.qdag.errors import PyDoughQDAGException
-
-from .back_reference_collection import BackReferenceCollection
-from .collection_access import CollectionAccess
-from .collection_qdag import PyDoughCollectionQDAG
-from .compound_sub_collection import CompoundSubCollection
-
-
-class HiddenBackReferenceCollection(BackReferenceCollection):
- """
- The QDAG node implementation class representing a subcollection of an
- ancestor collection.
- """
-
- def __init__(
- self,
- context: PyDoughCollectionQDAG,
- alias: str,
- term_name: str,
- back_levels: int,
- ):
- self._context: PyDoughCollectionQDAG = context
- self._term_name: str = term_name
- self._back_levels: int = back_levels
- self._alias: str = alias
-
- compound: PyDoughCollectionQDAG = context
- while compound.preceding_context is not None:
- compound = compound.preceding_context
- if not isinstance(compound, CompoundSubCollection):
- raise PyDoughQDAGException(
- f"Malformed hidden backreference expression: {self.to_string()}"
- )
- self._compound: CompoundSubCollection = compound
- hidden_ancestor: CollectionAccess = compound.subcollection_chain[-back_levels]
- collection_access = hidden_ancestor.get_collection(term_name)
- assert isinstance(collection_access, CollectionAccess)
- self._collection_access = collection_access
- super(BackReferenceCollection, self).__init__(
- collection_access.collection, context
- )
-
- def clone_with_parent(
- self, new_ancestor: PyDoughCollectionQDAG
- ) -> CollectionAccess:
- return HiddenBackReferenceCollection(
- new_ancestor, self.alias, self.term_name, self.back_levels
- )
-
- @property
- def context(self) -> PyDoughCollectionQDAG:
- """
- The collection context the hidden backreference operates within.
- """
- return self._context
-
- @property
- def compound(self) -> CompoundSubCollection:
- """
- The compound subcollection access that the hidden back reference
- traces to.
- """
- return self._compound
-
- @property
- def alias(self) -> str:
- """
- The alias that the back reference uses.
- """
- return self._alias
-
- @property
- def key(self) -> str:
- return f"{self.context.key}.{self.alias}"
-
- @property
- def standalone_string(self) -> str:
- return self.alias
-
- def to_string(self) -> str:
- return f"{self.context.to_string()}.{self.standalone_string}"
-
- @property
- def tree_item_string(self) -> str:
- return f"SubCollection[{self.standalone_string}]"
diff --git a/pydough/qdag/collections/order_by.py b/pydough/qdag/collections/order_by.py
index fa89f312..6a461182 100644
--- a/pydough/qdag/collections/order_by.py
+++ b/pydough/qdag/collections/order_by.py
@@ -7,6 +7,7 @@
from collections.abc import MutableSequence
+from functools import cache
from pydough.qdag.errors import PyDoughQDAGException
from pydough.qdag.expressions import CollationExpression
@@ -76,27 +77,16 @@ def collation(self) -> list[CollationExpression]:
def key(self) -> str:
return f"{self.preceding_context.key}.ORDERBY"
- @property
- def calc_terms(self) -> set[str]:
- return self.preceding_context.calc_terms
-
- @property
- def all_terms(self) -> set[str]:
- return self.preceding_context.all_terms
-
@property
def ordering(self) -> list[CollationExpression]:
return self.collation
@property
+ @cache
def standalone_string(self) -> str:
collation_str: str = ", ".join([expr.to_string() for expr in self.collation])
return f"ORDER_BY({collation_str})"
- def to_string(self) -> str:
- assert self.preceding_context is not None
- return f"{self.preceding_context.to_string()}.{self.standalone_string}"
-
@property
def tree_item_string(self) -> str:
collation_str: str = ", ".join(
diff --git a/pydough/qdag/collections/partition_by.py b/pydough/qdag/collections/partition_by.py
index dc2e7b94..0744382f 100644
--- a/pydough/qdag/collections/partition_by.py
+++ b/pydough/qdag/collections/partition_by.py
@@ -12,6 +12,7 @@
from pydough.qdag.abstract_pydough_qdag import PyDoughQDAG
from pydough.qdag.errors import PyDoughQDAGException
from pydough.qdag.expressions import (
+ BackReferenceExpression,
ChildReferenceExpression,
CollationExpression,
PartitionKey,
@@ -40,6 +41,11 @@ def __init__(
self._child_name: str = child_name
self._keys: list[PartitionKey] | None = None
self._key_name_indices: dict[str, int] = {}
+ self._ancestral_mapping: dict[str, int] = {
+ name: level + 1 for name, level in ancestor.ancestral_mapping.items()
+ }
+ self._calc_terms: set[str] = set()
+ self._all_terms: set[str] = set(self.ancestral_mapping) | {self.child_name}
def with_keys(self, keys: list[ChildReferenceExpression]) -> "PartitionBy":
"""
@@ -63,6 +69,8 @@ def with_keys(self, keys: list[ChildReferenceExpression]) -> "PartitionBy":
self._keys = [PartitionKey(self, key) for key in keys]
for idx, ref in enumerate(keys):
self._key_name_indices[ref.term_name] = idx
+ self._calc_terms.add(ref.term_name)
+ self.all_terms.update(self._calc_terms)
self.verify_singular_terms(self._keys)
return self
@@ -89,7 +97,7 @@ def keys(self) -> list[PartitionKey]:
def key_name_indices(self) -> dict[str, int]:
"""
The names of the partitioning keys for the PARTITION BY clause and the
- index they have in a CALC.
+ index they have in a CALCULATE.
"""
if self._keys is None:
raise PyDoughQDAGException(
@@ -118,11 +126,19 @@ def key(self) -> str:
@property
def calc_terms(self) -> set[str]:
- return set(self._key_name_indices)
+ return self._calc_terms
@property
def all_terms(self) -> set[str]:
- return self.calc_terms | {self.child_name}
+ return self._all_terms
+
+ @property
+ def ancestral_mapping(self) -> dict[str, int]:
+ return self._ancestral_mapping
+
+ @property
+ def inherited_downstreamed_terms(self) -> set[str]:
+ return self.ancestor_context.inherited_downstreamed_terms
@property
def ordering(self) -> list[CollationExpression] | None:
@@ -139,6 +155,7 @@ def is_singular(self, context: PyDoughCollectionQDAG) -> bool:
return False
@property
+ @cache
def standalone_string(self) -> str:
keys_str: str
if len(self.keys) == 1:
@@ -147,6 +164,7 @@ def standalone_string(self) -> str:
keys_str = str(tuple([expr.expr.term_name for expr in self.keys]))
return f"Partition({self.child.to_string()}, name={self.child_name!r}, by={keys_str})"
+ @cache
def to_string(self) -> str:
return f"{self.ancestor_context.to_string()}.{self.standalone_string}"
@@ -164,7 +182,11 @@ def get_expression_position(self, expr_name: str) -> int:
@cache
def get_term(self, term_name: str) -> PyDoughQDAG:
- if term_name in self._key_name_indices:
+ if term_name in self.ancestral_mapping:
+ return BackReferenceExpression(
+ self, term_name, self.ancestral_mapping[term_name]
+ )
+ elif term_name in self._key_name_indices:
term: PartitionKey = self.keys[self._key_name_indices[term_name]]
return term
elif term_name == self.child_name:
diff --git a/pydough/qdag/collections/partition_child.py b/pydough/qdag/collections/partition_child.py
index 1301b407..74f11086 100644
--- a/pydough/qdag/collections/partition_child.py
+++ b/pydough/qdag/collections/partition_child.py
@@ -6,7 +6,13 @@
__all__ = ["PartitionChild"]
-from pydough.qdag.expressions.collation_expression import CollationExpression
+from functools import cache
+
+from pydough.qdag.expressions import (
+ BackReferenceExpression,
+ CollationExpression,
+ Reference,
+)
from .child_access import ChildAccess
from .child_operator_child_access import ChildOperatorChildAccess
@@ -32,6 +38,22 @@ def __init__(
self._is_last = True
self._partition_child_name: str = partition_child_name
self._ancestor: PyDoughCollectionQDAG = ancestor
+ self._ancestral_mapping: dict[str, int] = {
+ name: level + 1 for name, level in ancestor.ancestral_mapping.items()
+ }
+ self._inherited_downstreamed_terms: set[str] = set(
+ self.ancestor_context.inherited_downstreamed_terms
+ )
+ for name in self._child_access.ancestral_mapping:
+ self._inherited_downstreamed_terms.add(name)
+ for name in self._child_access.inherited_downstreamed_terms:
+ self._inherited_downstreamed_terms.add(name)
+
+ self._all_terms: set[str] = (
+ self.child_access.all_terms
+ | set(self.ancestral_mapping)
+ | self._inherited_downstreamed_terms
+ )
def clone_with_parent(self, new_ancestor: PyDoughCollectionQDAG) -> ChildAccess:
return PartitionChild(
@@ -53,6 +75,35 @@ def key(self) -> str:
def ordering(self) -> list[CollationExpression] | None:
return self._child_access.ordering
+ @property
+ def ancestral_mapping(self) -> dict[str, int]:
+ return self._ancestral_mapping
+
+ @property
+ def all_terms(self) -> set[str]:
+ return self._all_terms
+
+ @property
+ def inherited_downstreamed_terms(self) -> set[str]:
+ return self._inherited_downstreamed_terms
+
+ @cache
+ def get_term(self, term_name: str):
+ if term_name in self.ancestral_mapping:
+ return BackReferenceExpression(
+ self, term_name, self.ancestral_mapping[term_name]
+ )
+ if term_name in self.inherited_downstreamed_terms:
+ context: PyDoughCollectionQDAG = self.child_access
+ while term_name not in context.all_terms:
+ if context is self.child_access:
+ context = self.ancestor_context
+ else:
+ assert context.ancestor_context is not None
+ context = context.ancestor_context
+ return Reference(context, term_name)
+ return super().get_term(term_name)
+
def is_singular(self, context: PyDoughCollectionQDAG) -> bool:
# The child of a PARTITION BY clause is always presumed to be plural
# since PyDough must assume that multiple records can be grouped
@@ -63,6 +114,7 @@ def is_singular(self, context: PyDoughCollectionQDAG) -> bool:
def standalone_string(self) -> str:
return self.partition_child_name
+ @cache
def to_string(self) -> str:
return f"{self.ancestor_context.to_string()}.{self.standalone_string}"
diff --git a/pydough/qdag/collections/top_k.py b/pydough/qdag/collections/top_k.py
index 6d909435..cc0ea5b2 100644
--- a/pydough/qdag/collections/top_k.py
+++ b/pydough/qdag/collections/top_k.py
@@ -8,6 +8,7 @@
from collections.abc import MutableSequence
+from functools import cache
from .collection_qdag import PyDoughCollectionQDAG
from .order_by import OrderBy
@@ -39,11 +40,8 @@ def key(self) -> str:
return f"{self.preceding_context.key}.TOPK"
@property
- def calc_terms(self) -> set[str]:
- return self.preceding_context.calc_terms
-
- @property
- def standalone_string(self) -> str:
+ @cache
+ def standalone_string(self):
collation_str: str = ", ".join([expr.to_string() for expr in self.collation])
return f"TOP_K({self.records_to_keep}, {collation_str})"
diff --git a/pydough/qdag/collections/where.py b/pydough/qdag/collections/where.py
index 9a3f2289..ef25f998 100644
--- a/pydough/qdag/collections/where.py
+++ b/pydough/qdag/collections/where.py
@@ -7,6 +7,7 @@
from collections.abc import MutableSequence
+from functools import cache
from pydough.qdag.errors import PyDoughQDAGException
from pydough.qdag.expressions import PyDoughExpressionQDAG
@@ -62,7 +63,7 @@ def condition(self) -> PyDoughExpressionQDAG:
"""
if self._condition is None:
raise PyDoughQDAGException(
- "Cannot access `condition` of a Calc node before adding calc terms with `with_condition`"
+ "Cannot access `condition` of a WHERE node before adding the predicate with `with_condition`"
)
return self._condition
@@ -71,21 +72,10 @@ def key(self) -> str:
return f"{self.preceding_context.key}.WHERE"
@property
- def calc_terms(self) -> set[str]:
- return self.preceding_context.calc_terms
-
- @property
- def all_terms(self) -> set[str]:
- return self.preceding_context.all_terms
-
- @property
+ @cache
def standalone_string(self) -> str:
return f"WHERE({self.condition.to_string()})"
- def to_string(self) -> str:
- assert self.preceding_context is not None
- return f"{self.preceding_context.to_string()}.{self.standalone_string}"
-
@property
def tree_item_string(self) -> str:
return f"Where[{self.condition.to_string(True)}]"
diff --git a/pydough/qdag/expressions/back_reference_expression.py b/pydough/qdag/expressions/back_reference_expression.py
index 7d178e81..70498908 100644
--- a/pydough/qdag/expressions/back_reference_expression.py
+++ b/pydough/qdag/expressions/back_reference_expression.py
@@ -65,7 +65,7 @@ def requires_enclosing_parens(self, parent: PyDoughExpressionQDAG) -> bool:
return False
def to_string(self, tree_form: bool = False) -> str:
- return f"BACK({self.back_levels}).{self.term_name}"
+ return self.term_name
def equals(self, other: object) -> bool:
return (
diff --git a/pydough/qdag/expressions/child_reference_expression.py b/pydough/qdag/expressions/child_reference_expression.py
index 5a122e6a..66b1bfa8 100644
--- a/pydough/qdag/expressions/child_reference_expression.py
+++ b/pydough/qdag/expressions/child_reference_expression.py
@@ -1,7 +1,7 @@
"""
Definition of PyDough QDAG nodes for referencing expressions from a child
collection of a child operator, e.g. `orders.order_date` in
-`customers(most_recent_order=MAX(orders.order_date))`.
+`customers.CALCULATE(most_recent_order=MAX(orders.order_date))`.
"""
__all__ = ["ChildReferenceExpression"]
@@ -19,7 +19,7 @@
class ChildReferenceExpression(Reference):
"""
The QDAG node implementation class representing a reference to a term in
- a child collection of a CALC.
+ a child collection of a CALCULATE or similar child operator node.
"""
def __init__(
@@ -37,8 +37,8 @@ def __init__(
@property
def child_idx(self) -> int:
"""
- The integer index of the child from the CALC that the ChildReferenceExpression
- refers to.
+ The integer index of the child from the child operator that the
+ ChildReferenceExpression refers to.
"""
return self._child_idx
diff --git a/pydough/qdag/expressions/expression_function_call.py b/pydough/qdag/expressions/expression_function_call.py
index 557fde26..6df9aa26 100644
--- a/pydough/qdag/expressions/expression_function_call.py
+++ b/pydough/qdag/expressions/expression_function_call.py
@@ -73,9 +73,6 @@ def requires_enclosing_parens(self, parent: PyDoughExpressionQDAG) -> bool:
return self.operator.requires_enclosing_parens(parent)
def to_string(self, tree_form: bool = False) -> str:
- from pydough.qdag.collections.back_reference_collection import (
- BackReferenceCollection,
- )
from pydough.qdag.collections.child_reference_collection import (
ChildReferenceCollection,
)
@@ -90,7 +87,7 @@ def to_string(self, tree_form: bool = False) -> str:
elif isinstance(arg, PyDoughCollectionQDAG):
if tree_form:
assert isinstance(
- arg, (ChildReferenceCollection, BackReferenceCollection)
+ arg, ChildReferenceCollection
), f"Unexpected argument to function call {arg}: expected an expression, or reference to a collection"
arg_string = arg.tree_item_string
else:
diff --git a/pydough/qdag/node_builder.py b/pydough/qdag/node_builder.py
index 05157a55..53d7abb6 100644
--- a/pydough/qdag/node_builder.py
+++ b/pydough/qdag/node_builder.py
@@ -23,8 +23,7 @@
from .abstract_pydough_qdag import PyDoughQDAG
from .collections import (
- BackReferenceCollection,
- Calc,
+ Calculate,
ChildAccess,
ChildReferenceCollection,
GlobalContext,
@@ -191,7 +190,7 @@ def build_child_reference_expression(
) -> Reference:
"""
Creates a new reference to an expression from a child collection of a
- CALC.
+ CALCULATE or similar operator.
Args:
`children`: the child collections that the reference accesses.
@@ -264,23 +263,23 @@ def build_child_access(
assert isinstance(term, ChildAccess)
return term
- def build_calc(
+ def build_calculate(
self,
preceding_context: PyDoughCollectionQDAG,
children: MutableSequence[PyDoughCollectionQDAG],
- ) -> Calc:
+ ) -> Calculate:
"""
- Creates a CALC instance, but `with_terms` still needs to be called on
+ Creates a CALCULATE instance, but `with_terms` still needs to be called on
the output.
Args:
`preceding_context`: the preceding collection.
- `children`: the child collections accessed by the CALC term.
+ `children`: the child collections accessed by the CALCULATE term.
Returns:
- The newly created PyDough CALC term.
+ The newly created PyDough CALCULATE term.
"""
- return Calc(preceding_context, children)
+ return Calculate(preceding_context, children)
def build_where(
self,
@@ -358,28 +357,6 @@ def build_partition(
"""
return PartitionBy(preceding_context, child, child_name)
- def build_back_reference_collection(
- self,
- collection: PyDoughCollectionQDAG,
- term_name: str,
- back_levels: int,
- ) -> BackReferenceCollection:
- """
- Creates a reference to a a subcollection of an ancestor.
-
- Args:
- `collection`: the preceding collection.
- `term_name`: the name of the subcollection being accessed.
- `back_levels`: the number of levels up in the ancestry tree to go.
-
- Returns:
- The newly created PyDough CALC term.
-
- Raises:
- `PyDoughQDAGException`: if the terms are invalid for the CALC term.
- """
- return BackReferenceCollection(collection, term_name, back_levels)
-
def build_child_reference_collection(
self,
preceding_context: PyDoughCollectionQDAG,
@@ -388,7 +365,7 @@ def build_child_reference_collection(
) -> ChildReferenceCollection:
"""
Creates a new reference to a collection from a child collection of a
- CALC or other child operator.
+ CALCULATE or other child operator.
Args:
`preceding_context`: the preceding collection.
diff --git a/pydough/relational/relational_nodes/empty_singleton.py b/pydough/relational/relational_nodes/empty_singleton.py
index 6dbe714f..be062750 100644
--- a/pydough/relational/relational_nodes/empty_singleton.py
+++ b/pydough/relational/relational_nodes/empty_singleton.py
@@ -1,9 +1,6 @@
"""
-This file contains the relational implementation for a "project". This is our
-relational representation for a "calc" that involves any compute steps and can include
-adding or removing columns (as well as technically reordering). In general, we seek to
-avoid introducing extra nodes just to reorder or prune columns, so ideally their use
-should be sparse.
+This file contains the relational implementation for a dummy relational node
+with 1 row and 0 columns.
"""
from collections.abc import MutableMapping, MutableSequence
diff --git a/pydough/relational/relational_nodes/project.py b/pydough/relational/relational_nodes/project.py
index c2413dff..00f42b92 100644
--- a/pydough/relational/relational_nodes/project.py
+++ b/pydough/relational/relational_nodes/project.py
@@ -1,6 +1,6 @@
"""
This file contains the relational implementation for a "project". This is our
-relational representation for a "calc" that involves any compute steps and can include
+relational representation for a "calculate" that involves any compute steps and can include
adding or removing columns (as well as technically reordering). In general, we seek to
avoid introducing extra nodes just to reorder or prune columns, so ideally their use
should be sparse.
@@ -20,9 +20,9 @@
class Project(SingleRelational):
"""
- The Project node in the relational tree. This node represents a "calc" in
- relational algebra, which should involve some "compute" functions and may
- involve adding, removing, or reordering columns.
+ The Project node in the relational tree. This node represents a "calculate"
+ in relational algebra, which should involve some "compute" functions and
+ may involve adding, removing, or reordering columns.
"""
def __init__(
diff --git a/pydough/sqlglot/transform_bindings.py b/pydough/sqlglot/transform_bindings.py
index d5ac1cb8..3461063b 100644
--- a/pydough/sqlglot/transform_bindings.py
+++ b/pydough/sqlglot/transform_bindings.py
@@ -32,6 +32,15 @@
sake of precedence.
"""
+trunc_pattern = re.compile(r"\s*start\s+of\s+(\w+)\s*", re.IGNORECASE)
+"""
+The REGEX pattern for truncation modifiers in DATETIME call.
+"""
+
+offset_pattern = re.compile(r"\s*([+-]?)\s*(\d+)\s+(\w+)\s*", re.IGNORECASE)
+"""
+The REGEX pattern for offset modifiers in DATETIME call.
+"""
year_units = ("years", "year", "y")
"""
@@ -305,10 +314,6 @@ def impl(
raw_args: Sequence[RelationalExpression] | None,
sql_glot_args: Sequence[SQLGlotExpression],
):
- # Regex pattern for truncation modifiers and offset strings.
- trunc_pattern = re.compile(r"\s*start\s+of\s+(\w+)\s*", re.IGNORECASE)
- offset_pattern = re.compile(r"\s*([+-]?)\s*(\d+)\s+(\w+)\s*", re.IGNORECASE)
-
# Handle the first argument
assert len(sql_glot_args) > 0
result: SQLGlotExpression = handle_datetime_base_arg(sql_glot_args[0], dialect)
diff --git a/pydough/unqualified/README.md b/pydough/unqualified/README.md
index 218126c7..311302d1 100644
--- a/pydough/unqualified/README.md
+++ b/pydough/unqualified/README.md
@@ -12,7 +12,7 @@ Unqualified nodes are the first intermediate representation (IR) created by PyDo
- `UnqualifiedRoot`: Represents the root of an unqualified node tree.
- `UnqualifiedLiteral`: Represents a literal value in an unqualified node tree.
- `UnqualifiedAccess`: Represents accessing a property from another unqualified node.
-- `UnqualifiedCalc`: Represents a CALC clause being done onto another unqualified node.
+- `UnqualifiedCalculate`: Represents a CALCULATE clause being done onto another unqualified node.
- `UnqualifiedWhere`: Represents a WHERE clause being done onto another unqualified node.
- `UnqualifiedOrderBy`: Represents an ORDER BY clause being done onto another unqualified node.
- `UnqualifiedTopK`: Represents a TOP K clause being done onto another unqualified node.
@@ -21,7 +21,6 @@ Unqualified nodes are the first intermediate representation (IR) created by PyDo
- `UnqualifiedBinaryOperation`: Represents a binary operation.
- `UnqualifiedCollation`: Represents a collation expression.
- `UnqualifiedOperator`: Represents a function that has yet to be called.
-- `UnqualifiedBack`: Represents a BACK node.
- `UnqualifiedWindow`: Represents a window operation.
## Code Transformation
@@ -48,7 +47,7 @@ graph = parse_json_metadata_from_file("path/to/metadata.json", "example_graph")
# Define a function with the init_pydough_context decorator
@init_pydough_context(graph)
def example_function():
- return Nations(
+ return Nations.CALCULATE(
nation_name=name,
region_name=region.name,
num_customers=COUNT(customers)
@@ -57,7 +56,7 @@ def example_function():
# Transform the source code of the function
source_code = """
def example_function():
- return Nations(
+ return Nations.CALCULATE(
nation_name=name,
region_name=region.name,
num_customers=COUNT(customers)
@@ -73,7 +72,7 @@ print(ast.unparse(transformed_ast))
# Transform a Jupyter cell
cell_code = """
-result = Nations(
+result = Nations.CALCULATE(
nation_name=name,
region_name=region.name,
num_customers=COUNT(customers)
@@ -92,7 +91,7 @@ from pydough.unqualified import UnqualifiedRoot
_ROOT = UnqualifiedRoot(example_graph)
def example_function():
- return _ROOT.Nations(
+ return _ROOT.Nations.CALCULATE(
nation_name=_ROOT.name,
region_name=_ROOT.region.name,
num_customers=_ROOT.COUNT(_ROOT.customers)
@@ -105,7 +104,7 @@ The transformed Python code for the Jupyter cell will look like this:
from pydough.unqualified import UnqualifiedRoot
_ROOT = UnqualifiedRoot(example_graph)
-result = _ROOT.Nations(
+result = _ROOT.Nations.CALCULATE(
nation_name=_ROOT.name,
region_name=_ROOT.region.name,
num_customers=_ROOT.COUNT(_ROOT.customers)
diff --git a/pydough/unqualified/__init__.py b/pydough/unqualified/__init__.py
index 1dbcdb36..39522204 100644
--- a/pydough/unqualified/__init__.py
+++ b/pydough/unqualified/__init__.py
@@ -7,9 +7,8 @@
__all__ = [
"PyDoughUnqualifiedException",
"UnqualifiedAccess",
- "UnqualifiedBack",
"UnqualifiedBinaryOperation",
- "UnqualifiedCalc",
+ "UnqualifiedCalculate",
"UnqualifiedLiteral",
"UnqualifiedNode",
"UnqualifiedOperation",
@@ -32,9 +31,8 @@
from .qualification import qualify_node, qualify_term
from .unqualified_node import (
UnqualifiedAccess,
- UnqualifiedBack,
UnqualifiedBinaryOperation,
- UnqualifiedCalc,
+ UnqualifiedCalculate,
UnqualifiedLiteral,
UnqualifiedNode,
UnqualifiedOperation,
diff --git a/pydough/unqualified/qualification.py b/pydough/unqualified/qualification.py
index 0a532d78..deab5b34 100644
--- a/pydough/unqualified/qualification.py
+++ b/pydough/unqualified/qualification.py
@@ -14,7 +14,7 @@
)
from pydough.qdag import (
AstNodeBuilder,
- Calc,
+ Calculate,
ChildOperatorChildAccess,
ChildReferenceExpression,
CollationExpression,
@@ -34,9 +34,8 @@
from .errors import PyDoughUnqualifiedException
from .unqualified_node import (
UnqualifiedAccess,
- UnqualifiedBack,
UnqualifiedBinaryOperation,
- UnqualifiedCalc,
+ UnqualifiedCalculate,
UnqualifiedCollation,
UnqualifiedLiteral,
UnqualifiedNode,
@@ -348,7 +347,7 @@ def qualify_access(
`children`: the list where collection nodes that must be derived
as children of `context` should be appended.
`is_child`: whether the collection is being qualified as a child
- of a child operator context, such as CALC or PARTITION.
+ of a child operator context, such as CALCULATE or PARTITION.
Returns:
The PyDough QDAG object for the qualified collection or expression
@@ -362,91 +361,68 @@ def qualify_access(
unqualified_parent: UnqualifiedNode = unqualified._parcel[0]
name: str = unqualified._parcel[1]
term: PyDoughQDAG
- if isinstance(unqualified_parent, UnqualifiedBack):
- # If the parent is an `UnqualifiedBack`, it means that this is an
- # access in the form "BACK(n).term_name". First, fetch the ancestor
- # context in question.
- levels: int = unqualified_parent._parcel[0]
- ancestor: PyDoughCollectionQDAG = context
- for _ in range(levels):
- if ancestor.ancestor_context is None:
- raise PyDoughUnqualifiedException(
- f"Cannot back reference {levels} above {context}"
- )
- ancestor = ancestor.ancestor_context
+ # First, qualify the parent collection.
+ qualified_parent: PyDoughCollectionQDAG = self.qualify_collection(
+ unqualified_parent, context, is_child
+ )
+ if (
+ isinstance(qualified_parent, GlobalContext)
+ and name == qualified_parent.graph.name
+ ):
+ # Special case: if the parent is the root context and the child
+ # is named after the graph name, return the parent since the
+ # child is just a de-sugared invocation of the global context.
+ return qualified_parent
+ else:
# Identify whether the access is an expression or a collection
- term = ancestor.get_term(name)
+ term = qualified_parent.get_term(name)
if isinstance(term, PyDoughCollectionQDAG):
- return self.builder.build_back_reference_collection(
- context, name, levels
- )
- else:
- return self.builder.build_back_reference_expression(
- context, name, levels
+ # If it is a collection that is not the special case,
+ # access the child collection from the qualified parent
+ # collection.
+ answer: PyDoughCollectionQDAG = self.builder.build_child_access(
+ name, qualified_parent
)
- else:
- # First, qualify the parent collection.
- qualified_parent: PyDoughCollectionQDAG = self.qualify_collection(
- unqualified_parent, context, is_child
- )
- if (
- isinstance(qualified_parent, GlobalContext)
- and name == qualified_parent.graph.name
- ):
- # Special case: if the parent is the root context and the child
- # is named after the graph name, return the parent since the
- # child is just a de-sugared invocation of the global context.
- return qualified_parent
+ if isinstance(unqualified_parent, UnqualifiedRoot) and is_child:
+ answer = ChildOperatorChildAccess(answer)
+ return answer
else:
- # Identify whether the access is an expression or a collection
- term = qualified_parent.get_term(name)
- if isinstance(term, PyDoughCollectionQDAG):
- # If it is a collection that is not the special case,
- # access the child collection from the qualified parent
- # collection.
- answer: PyDoughCollectionQDAG = self.builder.build_child_access(
- name, qualified_parent
- )
- if isinstance(unqualified_parent, UnqualifiedRoot) and is_child:
- answer = ChildOperatorChildAccess(answer)
- return answer
+ assert isinstance(term, PyDoughExpressionQDAG)
+ if isinstance(unqualified_parent, UnqualifiedRoot):
+ # If at the root, the access must be a reference to a scalar
+ # attribute accessible in the current context.
+ return self.builder.build_reference(context, name)
else:
- assert isinstance(term, PyDoughExpressionQDAG)
- if isinstance(unqualified_parent, UnqualifiedRoot):
- # If at the root, the access must be a reference to a scalar
- # attribute accessible in the current context.
- return self.builder.build_reference(context, name)
+ # Otherwise, the access is a reference to a scalar attribute of
+ # a child collection node of the current context. Add this new
+ # child to the list of children, unless already present, then
+ # return the answer as a reference to a field of the child.
+ ref_num: int
+ if qualified_parent in children:
+ ref_num = children.index(qualified_parent)
else:
- # Otherwise, the access is a reference to a scalar attribute of
- # a child collection node of the current context. Add this new
- # child to the list of children, unless already present, then
- # return the answer as a reference to a field of the child.
- ref_num: int
- if qualified_parent in children:
- ref_num = children.index(qualified_parent)
- else:
- ref_num = len(children)
- children.append(qualified_parent)
- return self.builder.build_child_reference_expression(
- children, ref_num, name
- )
-
- def qualify_calc(
+ ref_num = len(children)
+ children.append(qualified_parent)
+ return self.builder.build_child_reference_expression(
+ children, ref_num, name
+ )
+
+ def qualify_calculate(
self,
- unqualified: UnqualifiedCalc,
+ unqualified: UnqualifiedCalculate,
context: PyDoughCollectionQDAG,
is_child: bool,
) -> PyDoughCollectionQDAG:
"""
- Transforms an `UnqualifiedCalc` into a PyDoughCollectionQDAG node.
+ Transforms an `UnqualifiedCalculate` into a PyDoughCollectionQDAG node.
Args:
- `unqualified`: the UnqualifiedCalc instance to be transformed.
+ `unqualified`: the UnqualifiedCalculate instance to be transformed.
`builder`: a builder object used to create new qualified nodes.
`context`: the collection QDAG whose context the collection is being
evaluated within.
`is_child`: whether the collection is being qualified as a child
- of a child operator context, such as CALC or PARTITION.
+ of a child operator context, such as CALCULATE or PARTITION.
Returns:
The PyDough QDAG object for the qualified collection node.
@@ -463,16 +439,16 @@ def qualify_calc(
qualified_parent: PyDoughCollectionQDAG = self.qualify_collection(
unqualified_parent, context, is_child
)
- # Qualify all of the CALC terms, storing the children built along
+ # Qualify all of the CALCULATE terms, storing the children built along
# the way.
children: MutableSequence[PyDoughCollectionQDAG] = []
qualified_terms: MutableSequence[tuple[str, PyDoughExpressionQDAG]] = []
for name, term in unqualified_terms:
qualified_term = self.qualify_expression(term, qualified_parent, children)
qualified_terms.append((name, qualified_term))
- # Use the qualified children & terms to create a new CALC node.
- calc: Calc = self.builder.build_calc(qualified_parent, children)
- return calc.with_terms(qualified_terms)
+ # Use the qualified children & terms to create a new CALCULATE node.
+ calculate: Calculate = self.builder.build_calculate(qualified_parent, children)
+ return calculate.with_terms(qualified_terms)
def qualify_where(
self,
@@ -489,7 +465,7 @@ def qualify_where(
`context`: the collection QDAG whose context the collection is being
evaluated within.
`is_child`: whether the collection is being qualified as a child
- of a child operator context, such as CALC or PARTITION.
+ of a child operator context, such as CALCULATE or PARTITION.
Returns:
The PyDough QDAG object for the qualified collection node.
@@ -529,7 +505,7 @@ def qualify_order_by(
`context`: the collection QDAG whose context the collection is being
evaluated within.
`is_child`: whether the collection is being qualified as a child
- of a child operator context, such as CALC or PARTITION.
+ of a child operator context, such as CALCULATE or PARTITION.
Returns:
The PyDough QDAG object for the qualified collection node.
@@ -577,7 +553,7 @@ def qualify_top_k(
`context`: the collection QDAG whose context the collection is being
evaluated within.
`is_child`: whether the collection is being qualified as a child
- of a child operator context, such as CALC or PARTITION.
+ of a child operator context, such as CALCULATE or PARTITION.
Returns:
The PyDough QDAG object for the qualified collection node.
@@ -633,7 +609,7 @@ def qualify_partition(
`context`: the collection QDAG whose context the collection is being
evaluated within.
`is_child`: whether the collection is being qualified as a child
- of a child operator context, such as CALC or PARTITION.
+ of a child operator context, such as CALCULATE or PARTITION.
Returns:
The PyDough QDAG object for the qualified collection node.
@@ -698,7 +674,7 @@ def qualify_collection(
`context`: the collection QDAG whose context the collection is being
evaluated within.
`is_child`: whether the collection is being qualified as a child
- of a child operator context, such as CALC or PARTITION.
+ of a child operator context, such as CALCULATE or PARTITION.
Returns:
The PyDough QDAG object for the qualified collection node.
@@ -764,7 +740,7 @@ def qualify_node(
`children`: the list where collection nodes that must be derived
as children of `context` should be appended.
`is_child`: whether the collection is being qualified as a child
- of a child operator context, such as CALC or PARTITION.
+ of a child operator context, such as CALCULATE or PARTITION.
Returns:
The PyDough QDAG object for the qualified node. The result can be either
@@ -789,8 +765,8 @@ def qualify_node(
answer = context
case UnqualifiedAccess():
answer = self.qualify_access(unqualified, context, children, is_child)
- case UnqualifiedCalc():
- answer = self.qualify_calc(unqualified, context, is_child)
+ case UnqualifiedCalculate():
+ answer = self.qualify_calculate(unqualified, context, is_child)
case UnqualifiedWhere():
answer = self.qualify_where(unqualified, context, is_child)
case UnqualifiedOrderBy():
diff --git a/pydough/unqualified/unqualified_node.py b/pydough/unqualified/unqualified_node.py
index e22b8bfe..f22ec8c3 100644
--- a/pydough/unqualified/unqualified_node.py
+++ b/pydough/unqualified/unqualified_node.py
@@ -5,9 +5,8 @@
__all__ = [
"UnqualifiedAccess",
- "UnqualifiedBack",
"UnqualifiedBinaryOperation",
- "UnqualifiedCalc",
+ "UnqualifiedCalculate",
"UnqualifiedLiteral",
"UnqualifiedNode",
"UnqualifiedOperation",
@@ -135,6 +134,11 @@ def __getitem__(self, key):
f"Cannot index into PyDough object {self} with {key!r}"
)
+ def __call__(self, *args, **kwargs):
+ raise PyDoughUnqualifiedException(
+ f"PyDough node {self!r} is not callable. Did you mean to use a function?"
+ )
+
def __bool__(self):
raise PyDoughUnqualifiedException(
"PyDough code cannot be treated as a boolean. If you intend to do a logical operation, use `|`, `&` and `~` instead of `or`, `and` and `not`."
@@ -245,7 +249,7 @@ def __neg__(self):
def __invert__(self):
return UnqualifiedOperation("NOT", [self])
- def __call__(self, *args, **kwargs: dict[str, object]):
+ def CALCULATE(self, *args, **kwargs: dict[str, object]):
calc_args: list[tuple[str, UnqualifiedNode]] = []
counter = 0
for arg in args:
@@ -262,7 +266,7 @@ def __call__(self, *args, **kwargs: dict[str, object]):
calc_args.append((name, unqualified_arg))
for name, arg in kwargs.items():
calc_args.append((name, self.coerce_to_unqualified(arg)))
- return UnqualifiedCalc(self, calc_args)
+ return UnqualifiedCalculate(self, calc_args)
def __abs__(self):
return UnqualifiedOperation("ABS", [self])
@@ -376,12 +380,6 @@ def PARTITION(
else:
return UnqualifiedPartition(self, data, name, list(by))
- def BACK(self, levels: int) -> "UnqualifiedBack":
- """
- Method used to create a BACK node.
- """
- return UnqualifiedBack(levels)
-
class UnqualifiedRoot(UnqualifiedNode):
"""
@@ -408,17 +406,6 @@ def __getattribute__(self, name: str) -> Any:
return super().__getattribute__(name)
-class UnqualifiedBack(UnqualifiedNode):
- """
- Implementation of UnqualifiedNode used to refer to a BACK node, meaning that
- anything pointing to this node as an ancestor/predecessor must be derivable
- by looking at the ancestors of the context it is placed within.
- """
-
- def __init__(self, levels: int):
- self._parcel: tuple[int] = (levels,)
-
-
class UnqualifiedLiteral(UnqualifiedNode):
"""
Implementation of UnqualifiedNode used to refer to a literal whose value is
@@ -580,9 +567,9 @@ def __init__(self, predecessor: UnqualifiedNode, name: str):
self._parcel: tuple[UnqualifiedNode, str] = (predecessor, name)
-class UnqualifiedCalc(UnqualifiedNode):
+class UnqualifiedCalculate(UnqualifiedNode):
"""
- Implementation of UnqualifiedNode used to refer to a CALC clause being
+ Implementation of UnqualifiedNode used to refer to a CALCULATE clause being
done onto another UnqualifiedNode.
"""
@@ -678,9 +665,7 @@ def display_raw(unqualified: UnqualifiedNode) -> str:
operands_str: str
match unqualified:
case UnqualifiedRoot():
- return "?"
- case UnqualifiedBack():
- return f"BACK({unqualified._parcel[0]})"
+ return unqualified._parcel[0].name
case UnqualifiedLiteral():
literal_value: Any = unqualified._parcel[0]
match literal_value:
@@ -718,11 +703,13 @@ def display_raw(unqualified: UnqualifiedNode) -> str:
pos: str = "'last'" if unqualified._parcel[2] else "'first'"
return f"{display_raw(unqualified._parcel[0])}.{method}(na_pos={pos})"
case UnqualifiedAccess():
+ if isinstance(unqualified._parcel[0], UnqualifiedRoot):
+ return unqualified._parcel[1]
return f"{display_raw(unqualified._parcel[0])}.{unqualified._parcel[1]}"
- case UnqualifiedCalc():
+ case UnqualifiedCalculate():
for name, node in unqualified._parcel[1]:
term_strings.append(f"{name}={display_raw(node)}")
- return f"{display_raw(unqualified._parcel[0])}({', '.join(term_strings)})"
+ return f"{display_raw(unqualified._parcel[0])}.CALCULATE({', '.join(term_strings)})"
case UnqualifiedWhere():
return f"{display_raw(unqualified._parcel[0])}.WHERE({display_raw(unqualified._parcel[1])})"
case UnqualifiedTopK():
@@ -738,6 +725,8 @@ def display_raw(unqualified: UnqualifiedNode) -> str:
case UnqualifiedPartition():
for node in unqualified._parcel[3]:
term_strings.append(display_raw(node))
+ if isinstance(unqualified._parcel[0], UnqualifiedRoot):
+ return f"PARTITION({display_raw(unqualified._parcel[1])}, name={unqualified._parcel[2]!r}, by=({', '.join(term_strings)}))"
return f"{display_raw(unqualified._parcel[0])}.PARTITION({display_raw(unqualified._parcel[1])}, name={unqualified._parcel[2]!r}, by=({', '.join(term_strings)}))"
case _:
raise PyDoughUnqualifiedException(
diff --git a/tests/README.md b/tests/README.md
index 4078e893..022fcb4d 100644
--- a/tests/README.md
+++ b/tests/README.md
@@ -24,15 +24,14 @@ The `TestInfo` classes are used to specify information about a QDAG (Qualified D
- `ChildReferenceExpressionInfo`: Class for building a child reference expression.
- `TableCollectionInfo`: Class for building a table collection.
- `SubCollectionInfo`: Class for creating a subcollection access.
-- `ChildOperatorChildAccessInfo`: Class for wrapping around a subcollection info within a Calc context.
-- `BackReferenceCollectionInfo`: Class for building a reference to an ancestor collection.
+- `ChildOperatorChildAccessInfo`: Class for wrapping around a subcollection info within a child operator context, such as a `CALCULATE`.
- `ChildReferenceCollectionInfo`: Class for building a reference to a child collection.
- `ChildOperatorInfo`: Base class for types of CollectionTestInfo that have child nodes.
-- `CalcInfo`: Class for building a CALC node.
-- `WhereInfo`: Class for building a WHERE clause.
-- `OrderInfo`: Class for building an ORDER BY clause.
-- `TopKInfo`: Class for building a TOP K clause.
-- `PartitionInfo`: Class for building a PARTITION BY clause.
+- `CalcInfo`: Class for building a `CALCULATE` node.
+- `WhereInfo`: Class for building a `WHERE` clause.
+- `OrderInfo`: Class for building an `ORDER_BY` clause.
+- `TopKInfo`: Class for building a `TOP_K` clause.
+- `PartitionInfo`: Class for building a `PARTITION` clause.
### Using the `**` Operator
diff --git a/tests/bad_pydough_functions.py b/tests/bad_pydough_functions.py
index 47d4a244..21a13c3f 100644
--- a/tests/bad_pydough_functions.py
+++ b/tests/bad_pydough_functions.py
@@ -24,122 +24,122 @@ def bad_bool_3():
def bad_window_1():
# Missing `by`
- return Orders(RANKING())
+ return Orders.CALCULATE(RANKING())
def bad_window_2():
# Empty `by`
- return Orders(PERCENTILE(by=()))
+ return Orders.CALCULATE(PERCENTILE(by=()))
def bad_window_3():
# Non-collations in `by`
- return Orders(RANKING(by=order_key))
+ return Orders.CALCULATE(RANKING(by=order_key))
def bad_window_4():
# Non-positive levels
- return Orders(RANKING(by=order_key.ASC(), levels=0))
+ return Orders.CALCULATE(RANKING(by=order_key.ASC(), levels=0))
def bad_window_5():
# Non-integer levels
- return Orders(RANKING(by=order_key.ASC(), levels="hello"))
+ return Orders.CALCULATE(RANKING(by=order_key.ASC(), levels="hello"))
def bad_window_6():
# Non-positive n_buckets
- return Orders(PERCENTILE(by=order_key.ASC(), n_buckets=-3))
+ return Orders.CALCULATE(PERCENTILE(by=order_key.ASC(), n_buckets=-3))
def bad_window_7():
# Non-integer n_buckets
- return Orders(PERCENTILE(by=order_key.ASC(), n_buckets=[1, 2, 3]))
+ return Orders.CALCULATE(PERCENTILE(by=order_key.ASC(), n_buckets=[1, 2, 3]))
def bad_slice_1():
# Unsupported slicing: negative stop
- return Customers(name[:-1])
+ return Customers.CALCULATE(name[:-1])
def bad_slice_2():
# Unsupported slicing: negative start
- return Customers(name[-5:])
+ return Customers.CALCULATE(name[-5:])
def bad_slice_3():
# Unsupported slicing: skipping
- return Customers(name[1:10:2])
+ return Customers.CALCULATE(name[1:10:2])
def bad_slice_4():
# Unsupported slicing: reversed
- return Customers(name[::-1])
+ return Customers.CALCULATE(name[::-1])
def bad_floor():
# Using `math.floor` (calls __floor__)
- return Customer(age=math.floor(order.total_price))
+ return Customers.CALCULATE(age=math.floor(order.total_price))
def bad_ceil():
# Using `math.ceil` (calls __ceil__)
- return Customer(age=math.ceil(order.total_price))
+ return Customers.CALCULATE(age=math.ceil(order.total_price))
def bad_trunc():
# Using `math.trunc` (calls __trunc__)
- return Customer(age=math.trunc(order.total_price))
+ return Customers.CALCULATE(age=math.trunc(order.total_price))
def bad_reversed():
# Using `reversed` (calls __reversed__)
- return Regions(backwards_name=reversed(name))
+ return Regions.CALCULATE(backwards_name=reversed(name))
def bad_int():
# Casting to int (calls __int__)
- return Orders(limit=int(order.total_price))
+ return Orders.CALCULATE(limit=int(order.total_price))
def bad_float():
# Casting to float (calls __float__)
- return Orders(limit=float(order.quantity))
+ return Orders.CALCULATE(limit=float(order.quantity))
def bad_complex():
# Casting to complex (calls __complex__)
- return Orders(limit=complex(order.total_price))
+ return Orders.CALCULATE(limit=complex(order.total_price))
def bad_index():
# Using as an index (calls __index__)
- return Orders(s="ABCDE"[:order_priority])
+ return Orders.CALCULATE(s="ABCDE"[:order_priority])
def bad_nonzero():
# Using in a boolean context (calls __nonzero__)
- return Lineitems(is_taxed=bool(tax))
+ return Lineitems.CALCULATE(is_taxed=bool(tax))
def bad_len():
# Using `len` (calls __len__)
- return Customers(len(customer.name))
+ return Customers.CALCULATE(len(customer.name))
def bad_contains():
# Using `in` operator (calls __contains__)
- return Orders("discount" in order.details)
+ return Orders.CALCULATE("discount" in comment)
def bad_setitem():
# Assigning to an index (calls __setitem__)
- order.details["discount"] = True
- return order
+ Orders["discount"] = True
+ return Orders
def bad_iter():
# Iterating over an object (calls __iter__)
- for item in customer:
+ for item in Customers:
print(item)
- return customer
+ return Customers
diff --git a/tests/correlated_pydough_functions.py b/tests/correlated_pydough_functions.py
index 3d8c4eb7..2f3d9bd6 100644
--- a/tests/correlated_pydough_functions.py
+++ b/tests/correlated_pydough_functions.py
@@ -13,9 +13,14 @@ def correl_1():
# For each region, count how many of its its nations start with the same
# letter as the region. This is a true correlated join doing an aggregated
# access without requiring the RHS be present.
- return Regions(
- name, n_prefix_nations=COUNT(nations.WHERE(name[:1] == BACK(1).name[:1]))
- ).ORDER_BY(name.ASC())
+ return (
+ Regions.CALCULATE(region_name=name)
+ .CALCULATE(
+ region_name,
+ n_prefix_nations=COUNT(nations.WHERE(name[:1] == region_name[:1])),
+ )
+ .ORDER_BY(region_name.ASC())
+ )
def correl_2():
@@ -24,10 +29,11 @@ def correl_2():
# starting with the same letter as the region. Exclude regions that start
# with the letter a. This is a true correlated join doing an aggregated
# access without requiring the RHS be present.
- selected_custs = customers.WHERE(comment[:1] == LOWER(BACK(2).name[:1]))
+ selected_custs = customers.WHERE(comment[:1] == LOWER(region_name[:1]))
return (
- Regions.WHERE(~STARTSWITH(name, "A"))
- .nations(
+ Regions.CALCULATE(region_name=name)
+ .WHERE(~STARTSWITH(name, "A"))
+ .nations.CALCULATE(
name,
n_selected_custs=COUNT(selected_custs),
)
@@ -41,9 +47,11 @@ def correl_3():
# whose comment starts with the same 2 letter as the region. This is a true
# correlated join doing an aggregated access without requiring the RHS be
# present.
- selected_custs = customers.WHERE(comment[:2] == LOWER(BACK(2).name[:2]))
- return Regions(name, n_nations=COUNT(nations.WHERE(HAS(selected_custs)))).ORDER_BY(
- name.ASC()
+ selected_custs = customers.WHERE(comment[:2] == LOWER(region_name[:2]))
+ return (
+ Regions.CALCULATE(region_name=name)
+ .CALCULATE(region_name, n_nations=COUNT(nations.WHERE(HAS(selected_custs))))
+ .ORDER_BY(name.ASC())
)
@@ -52,12 +60,12 @@ def correl_4():
# Find every nation that does not have a customer whose account balance is
# within $5 of the smallest known account balance globally.
# (This is a correlated ANTI-join)
- selected_customers = customers.WHERE(acctbal <= (BACK(2).smallest_bal + 5.0))
+ selected_customers = customers.WHERE(acctbal <= (smallest_bal + 5.0))
return (
- TPCH(
+ TPCH.CALCULATE(
smallest_bal=MIN(Customers.acctbal),
)
- .Nations(name)
+ .Nations.CALCULATE(name)
.WHERE(HASNOT(selected_customers))
.ORDER_BY(name.ASC())
)
@@ -69,13 +77,13 @@ def correl_5():
# within $4 of the smallest known account balance globally.
# (This is a correlated SEMI-join)
selected_suppliers = nations.suppliers.WHERE(
- account_balance <= (BACK(3).smallest_bal + 4.0)
+ account_balance <= (smallest_bal + 4.0)
)
return (
- TPCH(
+ TPCH.CALCULATE(
smallest_bal=MIN(Suppliers.account_balance),
)
- .Regions(name)
+ .Regions.CALCULATE(name)
.WHERE(HAS(selected_suppliers))
.ORDER_BY(name.ASC())
)
@@ -87,9 +95,11 @@ def correl_6():
# letter as the region, but only keep regions with at least one such nation.
# This is a true correlated join doing an aggregated access that does NOT
# require that records without the RHS be kept.
- selected_nations = nations.WHERE(name[:1] == BACK(1).name[:1])
- return Regions.WHERE(HAS(selected_nations))(
- name, n_prefix_nations=COUNT(selected_nations)
+ selected_nations = nations.WHERE(name[:1] == region_name[:1])
+ return (
+ Regions.CALCULATE(region_name=name)
+ .WHERE(HAS(selected_nations))
+ .CALCULATE(name, n_prefix_nations=COUNT(selected_nations))
)
@@ -98,9 +108,11 @@ def correl_7():
# For each region, count how many of its its nations start with the same
# letter as the region, but only keep regions without at least one such
# nation. The true correlated join is trumped by the correlated ANTI-join.
- selected_nations = nations.WHERE(name[:1] == BACK(1).name[:1])
- return Regions.WHERE(HASNOT(selected_nations))(
- name, n_prefix_nations=COUNT(selected_nations)
+ selected_nations = nations.WHERE(name[:1] == region_name[:1])
+ return (
+ Regions.CALCULATE(region_name=name)
+ .WHERE(HASNOT(selected_nations))
+ .CALCULATE(name, n_prefix_nations=COUNT(selected_nations))
)
@@ -111,8 +123,12 @@ def correl_8():
# (otherwise, returns NULL). This is a true correlated join doing an
# access without aggregation without requiring the RHS be
# present.
- aug_region = region.WHERE(name[:1] == BACK(1).name[:1])
- return Nations(name, rname=aug_region.name).ORDER_BY(name.ASC())
+ aug_region = region.WHERE(name[:1] == nation_name[:1])
+ return (
+ Nations.CALCULATE(nation_name=name)
+ .CALCULATE(name, rname=aug_region.name)
+ .ORDER_BY(name.ASC())
+ )
def correl_9():
@@ -121,9 +137,12 @@ def correl_9():
# so it only keeps it if it starts with the same letter as the nation
# (otherwise, omit the nation). This is a true correlated join doing an
# access that also requires the RHS records be present.
- aug_region = region.WHERE(name[:1] == BACK(1).name[:1])
- return Nations.WHERE(HAS(aug_region))(name, rname=aug_region.name).ORDER_BY(
- name.ASC()
+ aug_region = region.WHERE(name[:1] == nation_name[:1])
+ return (
+ Nations.CALCULATE(nation_name=name)
+ .WHERE(HAS(aug_region))
+ .CALCULATE(name, rname=aug_region.name)
+ .ORDER_BY(name.ASC())
)
@@ -134,9 +153,12 @@ def correl_10():
# (otherwise, returns NULL), and also filter the nations to only keep
# records where the region is NULL. The true correlated join is trumped by
# the correlated ANTI-join.
- aug_region = region.WHERE(name[:1] == BACK(1).name[:1])
- return Nations.WHERE(HASNOT(aug_region))(name, rname=aug_region.name).ORDER_BY(
- name.ASC()
+ aug_region = region.WHERE(name[:1] == nation_name[:1])
+ return (
+ Nations.CALCULATE(nation_name=name)
+ .WHERE(HASNOT(aug_region))
+ .CALCULATE(name, rname=aug_region.name)
+ .ORDER_BY(name.ASC())
)
@@ -145,10 +167,12 @@ def correl_11():
# Which part brands have at least 1 part that more than 40% above the
# average retail price for all parts from that brand.
# (This is a correlated SEMI-join)
- brands = PARTITION(Parts, name="p", by=brand)(avg_price=AVG(p.retail_price))
- outlier_parts = p.WHERE(retail_price > 1.4 * BACK(1).avg_price)
+ brands = PARTITION(Parts, name="p", by=brand).CALCULATE(
+ avg_price=AVG(p.retail_price)
+ )
+ outlier_parts = p.WHERE(retail_price > 1.4 * avg_price)
selected_brands = brands.WHERE(HAS(outlier_parts))
- return selected_brands(brand).ORDER_BY(brand.ASC())
+ return selected_brands.CALCULATE(brand).ORDER_BY(brand.ASC())
def correl_12():
@@ -157,17 +181,17 @@ def correl_12():
# price for parts of that brand, below the average retail price for all
# parts, and has a size below 3.
# (This is a correlated SEMI-join)
- global_info = TPCH(avg_price=AVG(Parts.retail_price))
- brands = global_info.PARTITION(Parts, name="p", by=brand)(
- avg_price=AVG(p.retail_price)
+ global_info = TPCH.CALCULATE(global_avg_price=AVG(Parts.retail_price))
+ brands = global_info.PARTITION(Parts, name="p", by=brand).CALCULATE(
+ brand_avg_price=AVG(p.retail_price)
)
selected_parts = p.WHERE(
- (retail_price > BACK(1).avg_price)
- & (retail_price < BACK(2).avg_price)
+ (retail_price > brand_avg_price)
+ & (retail_price < global_avg_price)
& (size < 3)
)
selected_brands = brands.WHERE(HAS(selected_parts))
- return selected_brands(brand).ORDER_BY(brand.ASC())
+ return selected_brands.CALCULATE(brand).ORDER_BY(brand.ASC())
def correl_13():
@@ -177,14 +201,16 @@ def correl_13():
# from nations #1/#2/#3, and small parts.
# (This is a correlated SEMI-joins)
selected_part = part.WHERE(
- STARTSWITH(container, "SM") & (retail_price < (BACK(1).supplycost * 1.5))
+ STARTSWITH(container, "SM") & (retail_price < (supplycost * 1.5))
+ )
+ selected_supply_records = supply_records.CALCULATE(supplycost).WHERE(
+ HAS(selected_part)
)
- selected_supply_records = supply_records.WHERE(HAS(selected_part))
- supplier_info = Suppliers.WHERE(nation_key <= 3)(
+ supplier_info = Suppliers.WHERE(nation_key <= 3).CALCULATE(
avg_price=AVG(supply_records.part.retail_price)
)
selected_suppliers = supplier_info.WHERE(COUNT(selected_supply_records) > 0)
- return TPCH(n=COUNT(selected_suppliers))
+ return TPCH.CALCULATE(n=COUNT(selected_suppliers))
def correl_14():
@@ -196,15 +222,17 @@ def correl_14():
# (This is multiple correlated SEMI-joins)
selected_part = part.WHERE(
(container == "LG DRUM")
- & (retail_price < (BACK(1).supplycost * 1.5))
- & (retail_price < BACK(2).avg_price)
+ & (retail_price < (supplycost * 1.5))
+ & (retail_price < avg_price)
+ )
+ selected_supply_records = supply_records.CALCULATE(supplycost).WHERE(
+ HAS(selected_part)
)
- selected_supply_records = supply_records.WHERE(HAS(selected_part))
- supplier_info = Suppliers.WHERE(nation_key == 19)(
+ supplier_info = Suppliers.WHERE(nation_key == 19).CALCULATE(
avg_price=AVG(supply_records.part.retail_price)
)
selected_suppliers = supplier_info.WHERE(HAS(selected_supply_records))
- return TPCH(n=COUNT(selected_suppliers))
+ return TPCH.CALCULATE(n=COUNT(selected_suppliers))
def correl_15():
@@ -217,17 +245,19 @@ def correl_15():
# (This is multiple correlated SEMI-joins & a correlated aggregate)
selected_part = part.WHERE(
(container == "LG DRUM")
- & (retail_price < (BACK(1).supplycost * 1.5))
- & (retail_price < BACK(2).avg_price)
- & (retail_price < BACK(3).avg_price * 0.85)
+ & (retail_price < (supplycost * 1.5))
+ & (retail_price < supplier_avg_price)
+ & (retail_price < global_avg_price * 0.85)
)
- selected_supply_records = supply_records.WHERE(HAS(selected_part))
- supplier_info = Suppliers.WHERE(nation_key == 19)(
- avg_price=AVG(supply_records.part.retail_price)
+ selected_supply_records = supply_records.CALCULATE(supplycost).WHERE(
+ HAS(selected_part)
+ )
+ supplier_info = Suppliers.WHERE(nation_key == 19).CALCULATE(
+ supplier_avg_price=AVG(supply_records.part.retail_price)
)
selected_suppliers = supplier_info.WHERE(HAS(selected_supply_records))
- global_info = TPCH(avg_price=AVG(Parts.retail_price))
- return global_info(n=COUNT(selected_suppliers))
+ global_info = TPCH.CALCULATE(global_avg_price=AVG(Parts.retail_price))
+ return global_info.CALCULATE(n=COUNT(selected_suppliers))
def correl_16():
@@ -237,15 +267,15 @@ def correl_16():
# customer's percentile value of account balance relative to all other
# customers. Percentile should be measured down to increments of 0.01%.
# (This is a correlated SEMI-joins)
- selected_customers = nation(rname=region.name).customers.WHERE(
- (PERCENTILE(by=(acctbal.ASC(), key.ASC()), n_buckets=10000) == BACK(2).tile)
- & (BACK(1).rname == "EUROPE")
+ selected_customers = nation.CALCULATE(rname=region.name).customers.WHERE(
+ (PERCENTILE(by=(acctbal.ASC(), key.ASC()), n_buckets=10000) == tile)
+ & (rname == "EUROPE")
)
- supplier_info = Suppliers(
+ supplier_info = Suppliers.CALCULATE(
tile=PERCENTILE(by=(account_balance.ASC(), key.ASC()), n_buckets=10000)
)
selected_suppliers = supplier_info.WHERE(HAS(selected_customers))
- return TPCH(n=COUNT(selected_suppliers))
+ return TPCH.CALCULATE(n=COUNT(selected_suppliers))
def correl_17():
@@ -253,9 +283,9 @@ def correl_17():
# An extremely roundabout way of getting each region_name-nation_name
# pair as a string.
# (This is a correlated singular/semi access)
- region_info = region(fname=JOIN_STRINGS("-", LOWER(name), BACK(1).lname))
- nation_info = Nations(lname=LOWER(name)).WHERE(HAS(region_info))
- return nation_info(fullname=region_info.fname).ORDER_BY(fullname.ASC())
+ region_info = region.CALCULATE(fname=JOIN_STRINGS("-", LOWER(name), lname))
+ nation_info = Nations.CALCULATE(lname=LOWER(name)).WHERE(HAS(region_info))
+ return nation_info.CALCULATE(fullname=region_info.fname).ORDER_BY(fullname.ASC())
def correl_18():
@@ -270,10 +300,14 @@ def correl_18():
name="o",
by=(customer_key, order_date),
)
- selected_groups = cust_date_groups.WHERE(COUNT(o) > 1)(
- total_price=SUM(o.total_price),
- )(n_above_avg=COUNT(o.WHERE(total_price >= 0.5 * BACK(1).total_price)))
- return TPCH(n=SUM(selected_groups.n_above_avg))
+ selected_groups = (
+ cust_date_groups.WHERE(COUNT(o) > 1)
+ .CALCULATE(
+ total_price_sum=SUM(o.total_price),
+ )
+ .CALCULATE(n_above_avg=COUNT(o.WHERE(total_price >= 0.5 * total_price_sum)))
+ )
+ return TPCH.CALCULATE(n=SUM(selected_groups.n_above_avg))
def correl_19():
@@ -282,9 +316,11 @@ def correl_19():
# higher account balance than that supplier. Pick the 5 suppliers with the
# largest such count.
# (This is a correlated aggregation access)
- super_cust = customers.WHERE(acctbal > BACK(2).account_balance)
- return Suppliers.nation(name=BACK(1).name, n_super_cust=COUNT(super_cust)).TOP_K(
- 5, n_super_cust.DESC()
+ super_cust = customers.WHERE(acctbal > account_balance)
+ return (
+ Suppliers.CALCULATE(account_balance, supplier_name=name)
+ .nation.CALCULATE(supplier_name, n_super_cust=COUNT(super_cust))
+ .TOP_K(5, n_super_cust.DESC())
)
@@ -294,12 +330,12 @@ def correl_20():
# customer in the same nation, only counting instances where the order was
# made in June of 1998.
# (This is a correlated singular/semi access)
- is_domestic = nation(domestic=name == BACK(5).name).domestic
- selected_orders = Nations.customers.orders.WHERE(
+ is_domestic = nation.CALCULATE(domestic=name == source_nation_name).domestic
+ selected_orders = Nations.CALCULATE(source_nation_name=name).customers.orders.WHERE(
(YEAR(order_date) == 1998) & (MONTH(order_date) == 6)
)
instances = selected_orders.lines.supplier.WHERE(is_domestic)
- return TPCH(n=COUNT(instances))
+ return TPCH.CALCULATE(n=COUNT(instances))
def correl_21():
@@ -307,9 +343,9 @@ def correl_21():
# Count how many part sizes have an above-average number of parts
# of that size.
# (This is a correlated aggregation access)
- sizes = PARTITION(Parts, name="p", by=size)(n_parts=COUNT(p))
- return TPCH(avg_n_parts=AVG(sizes.n_parts))(
- n_sizes=COUNT(sizes.WHERE(n_parts > BACK(1).avg_n_parts))
+ sizes = PARTITION(Parts, name="p", by=size).CALCULATE(n_parts=COUNT(p))
+ return TPCH.CALCULATE(avg_n_parts=AVG(sizes.n_parts)).CALCULATE(
+ n_sizes=COUNT(sizes.WHERE(n_parts > avg_n_parts))
)
@@ -319,16 +355,17 @@ def correl_22():
# where the average retail price of parts of that container type
# & part type is above the global average retail price.
# (This is a correlated aggregation access)
- ct_combos = PARTITION(Parts, name="p", by=(container, part_type))(
+ ct_combos = PARTITION(Parts, name="p", by=(container, part_type)).CALCULATE(
avg_price=AVG(p.retail_price)
)
return (
- TPCH(global_avg_price=AVG(Parts.retail_price))
+ TPCH.CALCULATE(global_avg_price=AVG(Parts.retail_price))
.PARTITION(
- ct_combos.WHERE(avg_price > BACK(1).global_avg_price),
+ ct_combos.WHERE(avg_price > global_avg_price),
name="ct",
by=container,
- )(container, n_types=COUNT(ct))
+ )
+ .CALCULATE(container, n_types=COUNT(ct))
.TOP_K(5, (n_types.DESC(), container.ASC()))
)
@@ -339,7 +376,7 @@ def correl_23():
# of part types/containers.
# (This is a correlated aggregation access)
combos = PARTITION(Parts, name="p", by=(size, part_type, container))
- sizes = PARTITION(combos, name="c", by=size)(n_combos=COUNT(c))
- return TPCH(avg_n_combo=AVG(sizes.n_combos))(
- n_sizes=COUNT(sizes.WHERE(n_combos > BACK(1).avg_n_combo)),
+ sizes = PARTITION(combos, name="c", by=size).CALCULATE(n_combos=COUNT(c))
+ return TPCH.CALCULATE(avg_n_combo=AVG(sizes.n_combos)).CALCULATE(
+ n_sizes=COUNT(sizes.WHERE(n_combos > avg_n_combo)),
)
diff --git a/tests/defog_test_functions.py b/tests/defog_test_functions.py
index 0a749f98..d104f3e0 100644
--- a/tests/defog_test_functions.py
+++ b/tests/defog_test_functions.py
@@ -26,7 +26,7 @@ def impl_defog_broker_adv1():
Who are the top 5 customers by total transaction amount? Return their name
and total amount.
"""
- return Customers(name, total_amount=SUM(transactions_made.amount)).TOP_K(
+ return Customers.CALCULATE(name, total_amount=SUM(transactions_made.amount)).TOP_K(
5, by=total_amount.DESC()
)
@@ -39,9 +39,11 @@ def impl_defog_broker_adv3():
"""
n_transactions = COUNT(transactions_made)
n_success = SUM(transactions_made.status == "success")
- return Customers.WHERE(n_transactions >= 5)(
- name, success_rate=100.0 * n_success / n_transactions
- ).ORDER_BY(success_rate.ASC(na_pos="first"))
+ return (
+ Customers.WHERE(n_transactions >= 5)
+ .CALCULATE(name, success_rate=100.0 * n_success / n_transactions)
+ .ORDER_BY(success_rate.ASC(na_pos="first"))
+ )
def impl_defog_broker_adv6():
@@ -53,7 +55,7 @@ def impl_defog_broker_adv6():
with rank 1 being the customer with the highest total transaction amount.
"""
total_amount = SUM(transactions_made.amount)
- return Customers.WHERE(HAS(transactions_made))(
+ return Customers.WHERE(HAS(transactions_made)).CALCULATE(
name,
num_tx=COUNT(transactions_made),
total_amount=total_amount,
@@ -69,7 +71,7 @@ def impl_defog_broker_adv11():
FAANG companies (Amazon, Apple, Google, Meta or Netflix)?
"""
faang = ("AMZN", "AAPL", "GOOGL", "META", "NFLX")
- return Broker(
+ return Broker.CALCULATE(
n_customers=COUNT(
Customers.WHERE(
ENDSWITH(email, ".com")
@@ -90,7 +92,7 @@ def impl_defog_broker_adv12():
(STARTSWITH(LOWER(name), "j") | ENDSWITH(LOWER(name), "ez"))
& ENDSWITH(LOWER(state), "a")
)
- return Broker(n_customers=COUNT(selected_customers))
+ return Broker.CALCULATE(n_customers=COUNT(selected_customers))
def impl_defog_broker_adv15():
@@ -105,7 +107,7 @@ def impl_defog_broker_adv15():
countries = PARTITION(selected_customers, name="custs", by=country)
n_active = SUM(custs.status == "active")
n_custs = COUNT(custs)
- return countries(
+ return countries.CALCULATE(
country,
ar=100 * DEFAULT_TO(n_active / n_custs, 0.0),
)
@@ -118,7 +120,7 @@ def impl_defog_broker_basic3():
What are the 2 most frequently bought stock ticker symbols in the past 10
days? Return the ticker symbol and number of buy transactions.
"""
- return Tickers(
+ return Tickers.CALCULATE(
symbol,
num_transactions=COUNT(transactions_of),
total_amount=SUM(transactions_of.amount),
@@ -132,12 +134,16 @@ def impl_defog_broker_basic4():
What are the top 5 combinations of customer state and ticker type by
number of transactions? Return the customer state, ticker type and number of transactions.
"""
- data = Customers.transactions_made.ticker(state=BACK(2).state)
- return PARTITION(data, name="combo", by=(state, ticker_type))(
- state,
- ticker_type,
- num_transactions=COUNT(combo),
- ).TOP_K(5, by=num_transactions.DESC())
+ data = Customers.CALCULATE(state=state).transactions_made.ticker
+ return (
+ PARTITION(data, name="combo", by=(state, ticker_type))
+ .CALCULATE(
+ state,
+ ticker_type,
+ num_transactions=COUNT(combo),
+ )
+ .TOP_K(5, by=num_transactions.DESC())
+ )
def impl_defog_broker_basic5():
@@ -146,7 +152,9 @@ def impl_defog_broker_basic5():
Return the distinct list of customer IDs who have made a 'buy' transaction.
"""
- return Customers.WHERE(HAS(transactions_made.WHERE(transaction_type == "buy")))(_id)
+ return Customers.WHERE(
+ HAS(transactions_made.WHERE(transaction_type == "buy"))
+ ).CALCULATE(_id)
def impl_defog_broker_basic7():
@@ -156,9 +164,11 @@ def impl_defog_broker_basic7():
What are the top 3 transaction statuses by number of transactions? Return
the status and number of transactions.
"""
- return PARTITION(Transactions, name="status_group", by=status)(
- status, num_transactions=COUNT(status_group)
- ).TOP_K(3, by=num_transactions.DESC())
+ return (
+ PARTITION(Transactions, name="status_group", by=status)
+ .CALCULATE(status, num_transactions=COUNT(status_group))
+ .TOP_K(3, by=num_transactions.DESC())
+ )
def impl_defog_broker_basic8():
@@ -168,9 +178,11 @@ def impl_defog_broker_basic8():
What are the top 5 countries by number of customers? Return the country
name and number of customers.
"""
- return PARTITION(Customers, name="custs", by=country)(
- country, num_customers=COUNT(custs)
- ).TOP_K(5, by=num_customers.DESC())
+ return (
+ PARTITION(Customers, name="custs", by=country)
+ .CALCULATE(country, num_customers=COUNT(custs))
+ .TOP_K(5, by=num_customers.DESC())
+ )
def impl_defog_broker_basic9():
@@ -180,7 +192,7 @@ def impl_defog_broker_basic9():
Return the customer ID and name of customers who have not made any
transactions.
"""
- return Customers.WHERE(HASNOT(transactions_made))(_id, name)
+ return Customers.WHERE(HASNOT(transactions_made)).CALCULATE(_id, name)
def impl_defog_broker_basic10():
@@ -190,4 +202,4 @@ def impl_defog_broker_basic10():
Return the ticker ID and symbol of tickers that do not have any daily
price records.
"""
- return Tickers.WHERE(HASNOT(historical_prices))(_id, symbol)
+ return Tickers.WHERE(HASNOT(historical_prices)).CALCULATE(_id, symbol)
diff --git a/tests/exploration_examples.py b/tests/exploration_examples.py
index a7dfaee2..205340d4 100644
--- a/tests/exploration_examples.py
+++ b/tests/exploration_examples.py
@@ -4,7 +4,6 @@
__all__ = [
"contextless_aggfunc_impl",
- "contextless_back_impl",
"contextless_collections_impl",
"contextless_expr_impl",
"contextless_func_impl",
@@ -14,10 +13,6 @@
"global_calc_impl",
"global_impl",
"lineitems_arithmetic_impl",
- "lps_back_lines_impl",
- "lps_back_lines_price_impl",
- "lps_back_supplier_impl",
- "lps_back_supplier_name_impl",
"nation_expr_impl",
"nation_impl",
"nation_name_impl",
@@ -59,25 +54,35 @@ def global_impl() -> UnqualifiedNode:
def global_calc_impl() -> UnqualifiedNode:
- return TPCH(x=42, y=13)
+ return TPCH.CALCULATE(x=42, y=13)
def global_agg_calc_impl() -> UnqualifiedNode:
- return TPCH(n_customers=COUNT(Customers), avg_part_price=AVG(Parts.retail_price))
+ return TPCH.CALCULATE(
+ n_customers=COUNT(Customers), avg_part_price=AVG(Parts.retail_price)
+ )
def table_calc_impl() -> UnqualifiedNode:
- return Nations(name, region_name=region.name, num_customers=COUNT(customers))
+ return Nations.CALCULATE(
+ name, region_name=region.name, num_customers=COUNT(customers)
+ )
def subcollection_calc_backref_impl() -> UnqualifiedNode:
- return Regions.nations.customers(
- name, nation_name=BACK(1).name, region_name=BACK(2).name
+ return (
+ Regions.CALCULATE(region_name=name)
+ .nations.CALCULATE(nation_name=name)
+ .customers.CALCULATE(name, nation_name, region_name)
)
+def calc_subcollection_impl() -> UnqualifiedNode:
+ return Nations.CALCULATE(nation_name=name).region
+
+
def filter_impl() -> UnqualifiedNode:
- return Nations(name).WHERE(
+ return Nations.CALCULATE(nation_name=name).WHERE(
(region.name == "ASIA")
& HAS(customers.orders.lines.WHERE(CONTAINS(part.name, "STEEL")))
& (COUNT(suppliers.WHERE(account_balance >= 0.0)) > 100)
@@ -85,11 +90,11 @@ def filter_impl() -> UnqualifiedNode:
def order_by_impl() -> UnqualifiedNode:
- return Nations(name).ORDER_BY(COUNT(suppliers).DESC(), name.ASC())
+ return Nations.CALCULATE(name).ORDER_BY(COUNT(suppliers).DESC(), name.ASC())
def top_k_impl() -> UnqualifiedNode:
- return Parts(name, n_suppliers=COUNT(suppliers_of_part)).TOP_K(
+ return Parts.CALCULATE(name, n_suppliers=COUNT(suppliers_of_part)).TOP_K(
100, by=(n_suppliers.DESC(), name.ASC())
)
@@ -100,7 +105,8 @@ def partition_impl() -> UnqualifiedNode:
def partition_child_impl() -> UnqualifiedNode:
return (
- PARTITION(Parts, name="p", by=part_type)(
+ PARTITION(Parts, name="p", by=part_type)
+ .CALCULATE(
part_type,
avg_price=AVG(p.retail_price),
)
@@ -118,11 +124,7 @@ def contextless_expr_impl() -> UnqualifiedNode:
def contextless_collections_impl() -> UnqualifiedNode:
- return lines(extended_price, name=part.name)
-
-
-def contextless_back_impl() -> UnqualifiedNode:
- return BACK(1).fizz
+ return lines.CALCULATE(extended_price, name=part.name)
def contextless_func_impl() -> UnqualifiedNode:
@@ -154,23 +156,7 @@ def region_nations_suppliers_name_impl() -> tuple[UnqualifiedNode, UnqualifiedNo
def region_nations_back_name() -> tuple[UnqualifiedNode, UnqualifiedNode]:
- return Regions.nations, BACK(1).name
-
-
-def lps_back_supplier_name_impl() -> tuple[UnqualifiedNode, UnqualifiedNode]:
- return Lineitems.part, BACK(1).supplier.name
-
-
-def lps_back_supplier_impl() -> tuple[UnqualifiedNode, UnqualifiedNode]:
- return Lineitems.part, BACK(1).supplier
-
-
-def lps_back_lines_price_impl() -> tuple[UnqualifiedNode, UnqualifiedNode]:
- return PartSupp.part, BACK(1).lines.extended_price
-
-
-def lps_back_lines_impl() -> tuple[UnqualifiedNode, UnqualifiedNode]:
- return PartSupp.part, BACK(1).lines
+ return Regions.CALCULATE(region_name=name).nations, region_name
def region_n_suppliers_in_red_impl() -> tuple[UnqualifiedNode, UnqualifiedNode]:
diff --git a/tests/simple_pydough_functions.py b/tests/simple_pydough_functions.py
index d4036f5c..4e7add25 100644
--- a/tests/simple_pydough_functions.py
+++ b/tests/simple_pydough_functions.py
@@ -10,62 +10,71 @@
def simple_scan():
- return Orders(key)
+ return Orders.CALCULATE(key)
def simple_filter():
- # Note: The SQL is non-deterministic once we add nested expressions.
- return Orders(o_orderkey=key, o_totalprice=total_price).WHERE(o_totalprice < 1000.0)
+ return Orders.CALCULATE(o_orderkey=key, o_totalprice=total_price).WHERE(
+ o_totalprice < 1000.0
+ )
def simple_scan_top_five():
- return Orders(key).TOP_K(5, by=key.ASC())
+ return Orders.CALCULATE(key).TOP_K(5, by=key.ASC())
def simple_filter_top_five():
- return Orders(key, total_price).WHERE(total_price < 1000.0).TOP_K(5, by=key.DESC())
+ return (
+ Orders.CALCULATE(key, total_price)
+ .WHERE(total_price < 1000.0)
+ .TOP_K(5, by=key.DESC())
+ )
def rank_a():
- return Customers(rank=RANKING(by=acctbal.DESC()))
+ return Customers.CALCULATE(rank=RANKING(by=acctbal.DESC()))
def rank_b():
- return Orders(rank=RANKING(by=(order_priority.ASC()), allow_ties=True))
+ return Orders.CALCULATE(rank=RANKING(by=(order_priority.ASC()), allow_ties=True))
def rank_c():
- return Orders(
+ return Orders.CALCULATE(
order_date, rank=RANKING(by=order_date.ASC(), allow_ties=True, dense=True)
)
def rank_nations_by_region():
- return Nations(name, rank=RANKING(by=region.name.ASC(), allow_ties=True))
+ return Nations.CALCULATE(name, rank=RANKING(by=region.name.ASC(), allow_ties=True))
def rank_nations_per_region_by_customers():
- return Regions.nations(
+ return Regions.nations.CALCULATE(
name, rank=RANKING(by=COUNT(customers).DESC(), levels=1)
).TOP_K(5, by=rank.ASC())
def rank_parts_per_supplier_region_by_size():
- return Regions.nations.suppliers.supply_records.part(
- key,
- region=BACK(4).name,
- rank=RANKING(
- by=(size.DESC(), container.DESC(), part_type.DESC()),
- levels=4,
- allow_ties=True,
- dense=True,
- ),
- ).TOP_K(15, by=key.ASC())
+ return (
+ Regions.CALCULATE(region_name=name)
+ .nations.suppliers.supply_records.part.CALCULATE(
+ key,
+ region=region_name,
+ rank=RANKING(
+ by=(size.DESC(), container.DESC(), part_type.DESC()),
+ levels=4,
+ allow_ties=True,
+ dense=True,
+ ),
+ )
+ .TOP_K(15, by=key.ASC())
+ )
def rank_with_filters_a():
return (
- Customers(n=name, r=RANKING(by=acctbal.DESC()))
+ Customers.CALCULATE(n=name, r=RANKING(by=acctbal.DESC()))
.WHERE(ENDSWITH(name, "0"))
.WHERE(r <= 30)
)
@@ -73,7 +82,7 @@ def rank_with_filters_a():
def rank_with_filters_b():
return (
- Customers(n=name, r=RANKING(by=acctbal.DESC()))
+ Customers.CALCULATE(n=name, r=RANKING(by=acctbal.DESC()))
.WHERE(r <= 30)
.WHERE(ENDSWITH(name, "0"))
)
@@ -83,7 +92,7 @@ def rank_with_filters_c():
return (
PARTITION(Parts, name="p", by=size)
.TOP_K(5, by=size.DESC())
- .p(size, name)
+ .p.CALCULATE(size, name)
.WHERE(RANKING(by=retail_price.DESC(), levels=1) == 1)
)
@@ -91,7 +100,7 @@ def rank_with_filters_c():
def percentile_nations():
# For every nation, give its name & its bucket from 1-5 ordered by name
# alphabetically
- return Nations(name, p=PERCENTILE(by=name.ASC(), n_buckets=5))
+ return Nations.CALCULATE(name, p=PERCENTILE(by=name.ASC(), n_buckets=5))
def percentile_customers_per_region():
@@ -100,7 +109,7 @@ def percentile_customers_per_region():
# means more money) and whose phone number ends in two zeros, sorted by the
# name of the customers
return (
- Regions.nations.customers(name)
+ Regions.nations.customers.CALCULATE(name)
.WHERE((PERCENTILE(by=(acctbal.ASC()), levels=2) == 95) & ENDSWITH(phone, "00"))
.ORDER_BY(name.ASC())
)
@@ -112,14 +121,18 @@ def regional_suppliers_percentile():
pct = PERCENTILE(
by=(COUNT(supply_records).ASC(), name.ASC()), levels=2, n_buckets=1000
)
- return Regions.nations.suppliers(name).WHERE(HAS(supply_records) & (pct == 1000))
+ return Regions.nations.suppliers.CALCULATE(name).WHERE(
+ HAS(supply_records) & (pct == 1000)
+ )
def function_sampler():
# Examples of using different functions
return (
- Regions.nations.customers(
- a=JOIN_STRINGS("-", BACK(2).name, BACK(1).name, name[16:]),
+ Regions.CALCULATE(region_name=name)
+ .nations.CALCULATE(nation_name=name)
+ .customers.CALCULATE(
+ a=JOIN_STRINGS("-", region_name, nation_name, name[16:]),
b=ROUND(acctbal, 1),
c=KEEP_IF(name, phone[:1] == "3"),
d=PRESENT(KEEP_IF(name, phone[1:2] == "1")),
@@ -131,7 +144,7 @@ def function_sampler():
def datetime_current():
- return TPCH(
+ return TPCH.CALCULATE(
d1=DATETIME("now", "start of year", "5 months", "-1 DAY"),
d2=DATETIME("current_date", "start of mm", "+24 hours"),
d3=DATETIME(
@@ -144,7 +157,7 @@ def datetime_relative():
selected_orders = Orders.TOP_K(
10, by=(customer_key.ASC(), order_date.ASC())
).ORDER_BY(order_date.ASC())
- return selected_orders(
+ return selected_orders.CALCULATE(
d1=DATETIME(order_date, "Start of Year"),
d2=DATETIME(order_date, "START OF MONTHS"),
d3=DATETIME(
@@ -166,7 +179,7 @@ def datetime_sampler():
# Near-exhaustive edge cases coverage testing for DATETIME strings. The
# terms were generated via random combination selection of various ways
# of augmenting the base/modifier terms.
- return Orders(
+ return Orders.CALCULATE(
DATETIME("2025-07-04 12:58:45"),
DATETIME("2024-12-31 11:59:00"),
DATETIME("2025-01-01"),
@@ -334,21 +347,21 @@ def datetime_sampler():
def loop_generated_terms():
- # Using a loop & dictionary to generate PyDough calc terms
+ # Using a loop & dictionary to generate PyDough calculate terms
terms = {"name": name}
for i in range(3):
terms[f"interval_{i}"] = COUNT(
customers.WHERE(MONOTONIC(i * 1000, acctbal, (i + 1) * 1000))
)
- return Nations(**terms)
+ return Nations.CALCULATE(**terms)
def function_defined_terms():
- # Using a regular function to generate PyDough calc terms
+ # Using a regular function to generate PyDough calculate terms
def interval_n(n):
return COUNT(customers.WHERE(MONOTONIC(n * 1000, acctbal, (n + 1) * 1000)))
- return Nations(
+ return Nations.CALCULATE(
name,
interval_7=interval_n(7),
interval_4=interval_n(4),
@@ -357,11 +370,12 @@ def interval_n(n):
def function_defined_terms_with_duplicate_names():
- # Using a regular function to generate PyDough calc terms with the function argument same as collection's fields.
+ # Using a regular function to generate PyDough calculate terms with the
+ # function argument same as collection's fields.
def interval_n(n, name="test"):
return COUNT(customers.WHERE(MONOTONIC(n * 1000, acctbal, (n + 1) * 1000)))
- return Nations(
+ return Nations.CALCULATE(
name,
redefined_name=name,
interval_7=interval_n(7),
@@ -371,12 +385,12 @@ def interval_n(n, name="test"):
def lambda_defined_terms():
- # Using a lambda function to generate PyDough calc terms
+ # Using a lambda function to generate PyDough calculate terms
interval_n = lambda n: COUNT(
customers.WHERE(MONOTONIC(n * 1000, acctbal, (n + 1) * 1000))
)
- return Nations(
+ return Nations.CALCULATE(
name,
interval_7=interval_n(7),
interval_4=interval_n(4),
@@ -385,7 +399,7 @@ def lambda_defined_terms():
def dict_comp_terms():
- # Using a dictionary comprehension to generate PyDough calc terms
+ # Using a dictionary comprehension to generate PyDough calculate terms
terms = {"name": name}
terms.update(
{
@@ -395,11 +409,11 @@ def dict_comp_terms():
for i in range(3)
}
)
- return Nations(**terms)
+ return Nations.CALCULATE(**terms)
def list_comp_terms():
- # Using a list comprehension to generate PyDough calc terms
+ # Using a list comprehension to generate PyDough calculate terms
terms = [name]
terms.extend(
[
@@ -407,11 +421,11 @@ def list_comp_terms():
for i in range(3)
]
)
- return Nations(*terms)
+ return Nations.CALCULATE(*terms)
def set_comp_terms():
- # Using a set comprehension to generate PyDough calc terms
+ # Using a set comprehension to generate PyDough calculate terms
terms = [name]
terms.extend(
set(
@@ -422,11 +436,11 @@ def set_comp_terms():
)
)
sorted_terms = sorted(terms, key=lambda x: repr(x))
- return Nations(*sorted_terms)
+ return Nations.CALCULATE(*sorted_terms)
def generator_comp_terms():
- # Using a generator comprehension to generate PyDough calc terms
+ # Using a generator comprehension to generate PyDough calculate terms
terms = {"name": name}
for term, value in (
(
@@ -436,21 +450,30 @@ def generator_comp_terms():
for i in range(3)
):
terms[term] = value
- return Nations(**terms)
+ return Nations.CALCULATE(**terms)
+
+
+def partition_as_child():
+ # Count how many part sizes have an above-average number of parts of that
+ # size.
+ sizes = PARTITION(Parts, name="p", by=size).CALCULATE(n_parts=COUNT(p))
+ return TPCH.CALCULATE(avg_n_parts=AVG(sizes.n_parts)).CALCULATE(
+ n_parts=COUNT(sizes.WHERE(n_parts > avg_n_parts))
+ )
def agg_partition():
# Doing a global aggregation on the output of a partition aggregation
- yearly_data = PARTITION(Orders(year=YEAR(order_date)), name="orders", by=year)(
- n_orders=COUNT(orders)
- )
- return TPCH(best_year=MAX(yearly_data.n_orders))
+ yearly_data = PARTITION(
+ Orders.CALCULATE(year=YEAR(order_date)), name="orders", by=year
+ ).CALCULATE(n_orders=COUNT(orders))
+ return TPCH.CALCULATE(best_year=MAX(yearly_data.n_orders))
def multi_partition_access_1():
# A use of multiple PARTITION and stepping into partition children that is
# a no-op.
- data = Tickers(symbol).TOP_K(5, by=symbol.ASC())
+ data = Tickers.CALCULATE(symbol).TOP_K(5, by=symbol.ASC())
grps_a = PARTITION(data, name="child_3", by=(currency, exchange, ticker_type))
grps_b = PARTITION(grps_a, name="child_2", by=(currency, exchange))
grps_c = PARTITION(grps_b, name="child_1", by=exchange)
@@ -461,42 +484,147 @@ def multi_partition_access_2():
# Identify transactions that are below the average number of shares for
# transactions of the same combinations of (customer, stock, type), or
# the same combination of (customer, stock), or the same customer.
- grps_a = PARTITION(
- Transactions, name="child_3", by=(customer_id, ticker_id, transaction_type)
- )(avg_shares_a=AVG(child_3.shares))
- grps_b = PARTITION(grps_a, name="child_2", by=(customer_id, ticker_id))(
- avg_shares_b=AVG(child_2.child_3.shares)
- )
- grps_c = PARTITION(grps_b, name="child_1", by=customer_id)(
- avg_shares_c=AVG(child_1.child_2.child_3.shares)
- )
- return grps_c.child_1.child_2.child_3.WHERE(
- (shares < BACK(1).avg_shares_a)
- & (shares < BACK(2).avg_shares_b)
- & (shares < BACK(3).avg_shares_c)
- )(
- transaction_id,
- customer.name,
- ticker.symbol,
- transaction_type,
- BACK(1).avg_shares_a,
- BACK(2).avg_shares_b,
- BACK(3).avg_shares_c,
- ).ORDER_BY(transaction_id.ASC())
+ cust_tick_typ_groups = PARTITION(
+ Transactions,
+ name="original_data",
+ by=(customer_id, ticker_id, transaction_type),
+ ).CALCULATE(cus_tick_typ_avg_shares=AVG(original_data.shares))
+ cust_tick_groups = PARTITION(
+ cust_tick_typ_groups, name="typs", by=(customer_id, ticker_id)
+ ).CALCULATE(cust_tick_avg_shares=AVG(typs.original_data.shares))
+ cus_groups = PARTITION(cust_tick_groups, name="ticks", by=customer_id).CALCULATE(
+ cust_avg_shares=AVG(ticks.typs.original_data.shares)
+ )
+ return (
+ cus_groups.ticks.typs.original_data.WHERE(
+ (shares < cus_tick_typ_avg_shares)
+ & (shares < cust_tick_avg_shares)
+ & (shares < cust_avg_shares)
+ )
+ .CALCULATE(
+ transaction_id,
+ customer.name,
+ ticker.symbol,
+ transaction_type,
+ cus_tick_typ_avg_shares,
+ cust_tick_avg_shares,
+ cust_avg_shares,
+ )
+ .ORDER_BY(transaction_id.ASC())
+ )
+
+
+def multi_partition_access_3():
+ # Find all daily price updates whose closing price was the high mark for
+ # that ticker, but not for tickers of that type.
+ data = Tickers.CALCULATE(symbol, ticker_type).historical_prices
+ ticker_groups = PARTITION(data, name="ticker_data", by=ticker_id).CALCULATE(
+ ticker_high_price=MAX(ticker_data.close)
+ )
+ type_groups = PARTITION(
+ ticker_groups.ticker_data, name="type_data", by=ticker_type
+ ).CALCULATE(type_high_price=MAX(type_data.close))
+ return (
+ type_groups.type_data.WHERE(
+ (close == ticker_high_price) & (close < type_high_price)
+ )
+ .CALCULATE(symbol, close)
+ .ORDER_BY(symbol.ASC())
+ )
+
+
+def multi_partition_access_4():
+    # Find all transactions that were the largest for a customer of that ticker
+ # (by number of shares) but not the largest for that customer overall.
+ cust_ticker_groups = PARTITION(
+ Transactions, name="data", by=(customer_id, ticker_id)
+ ).CALCULATE(cust_ticker_max_shares=MAX(data.shares))
+ cust_groups = PARTITION(
+ cust_ticker_groups, name="ticker_groups", by=customer_id
+ ).CALCULATE(cust_max_shares=MAX(ticker_groups.cust_ticker_max_shares))
+ return (
+ cust_groups.ticker_groups.data.WHERE(
+ (shares >= cust_ticker_max_shares) & (shares < cust_max_shares)
+ )
+ .CALCULATE(transaction_id)
+ .ORDER_BY(transaction_id.ASC())
+ )
+
+
+def multi_partition_access_5():
+ # Find all transactions where more than 80% of all transactions of that
+    # ticker were of that type, but less than 20% of all transactions of
+ # that type were from that ticker. List the transaction ID, the number of
+ # transactions of that ticker/type, ticker, and type. Sort by the number of
+    # transactions of that ticker/type, breaking ties by transaction ID.
+ ticker_type_groups = PARTITION(
+ Transactions, name="data", by=(ticker_id, transaction_type)
+ ).CALCULATE(n_ticker_type_trans=COUNT(data))
+ ticker_groups = PARTITION(
+ ticker_type_groups, name="sub_trans", by=ticker_id
+ ).CALCULATE(n_ticker_trans=SUM(sub_trans.n_ticker_type_trans))
+ type_groups = PARTITION(
+ ticker_groups.sub_trans, name="sub_trans", by=transaction_type
+ ).CALCULATE(n_type_trans=SUM(sub_trans.n_ticker_type_trans))
+ return (
+ type_groups.sub_trans.data.CALCULATE(
+ transaction_id,
+ n_ticker_type_trans,
+ n_ticker_trans,
+ n_type_trans,
+ )
+ .WHERE(
+ ((n_ticker_type_trans / n_ticker_trans) > 0.8)
+ & ((n_ticker_type_trans / n_type_trans) < 0.2)
+ )
+ .ORDER_BY(n_ticker_type_trans.ASC(), transaction_id.ASC())
+ )
+
+
+def multi_partition_access_6():
+ # Find all transactions that are the only transaction of that type for
+ # that ticker, or the only transaction of that type for that customer,
+ # but not the only transaction for that customer, type, or ticker. List
+ # the transaction IDs in ascending order.
+ ticker_type_groups = PARTITION(
+ Transactions, name="data", by=(ticker_id, transaction_type)
+ ).CALCULATE(n_ticker_type_trans=COUNT(data))
+ ticker_groups = PARTITION(
+ ticker_type_groups, name="sub_trans", by=ticker_id
+ ).CALCULATE(n_ticker_trans=SUM(sub_trans.n_ticker_type_trans))
+ type_groups = PARTITION(
+ ticker_groups.sub_trans, name="sub_trans", by=transaction_type
+ ).CALCULATE(n_type_trans=SUM(sub_trans.n_ticker_type_trans))
+ cust_type_groups = PARTITION(
+ type_groups.sub_trans.data, name="data", by=(customer_id, transaction_type)
+ ).CALCULATE(n_cust_type_trans=COUNT(data))
+ cust_groups = PARTITION(
+ cust_type_groups, name="sub_trans", by=customer_id
+ ).CALCULATE(n_cust_trans=SUM(sub_trans.n_cust_type_trans))
+ return (
+ cust_groups.sub_trans.data.CALCULATE(transaction_id)
+ .WHERE(
+ ((n_ticker_type_trans == 1) | (n_cust_type_trans == 1))
+ & (n_cust_trans > 1)
+ & (n_type_trans > 1)
+ & (n_ticker_trans > 1)
+ )
+ .ORDER_BY(transaction_id.ASC())
+ )
def double_partition():
# Doing a partition aggregation on the output of a partition aggregation
year_month_data = PARTITION(
- Orders(year=YEAR(order_date), month=MONTH(order_date)),
+ Orders.CALCULATE(year=YEAR(order_date), month=MONTH(order_date)),
name="orders",
by=(year, month),
- )(n_orders=COUNT(orders))
+ ).CALCULATE(n_orders=COUNT(orders))
return PARTITION(
year_month_data,
name="months",
by=year,
- )(year, best_month=MAX(months.n_orders))
+ ).CALCULATE(year, best_month=MAX(months.n_orders))
def triple_partition():
@@ -507,30 +635,29 @@ def triple_partition():
# customer region. Only considers lineitems from June of 1992 where the
# container is small.
line_info = (
- Parts.WHERE(
- STARTSWITH(container, "SM"),
- )
- .lines.WHERE((MONTH(ship_date) == 6) & (YEAR(ship_date) == 1992))(
- supp_region=supplier.nation.region.name,
- )
- .order.WHERE(YEAR(order_date) == 1992)(
- supp_region=BACK(1).supp_region,
- part_type=BACK(2).part_type,
- cust_region=customer.nation.region.name,
- )
+ Parts.CALCULATE(part_type)
+ .WHERE(STARTSWITH(container, "SM"))
+ .lines.WHERE((MONTH(ship_date) == 6) & (YEAR(ship_date) == 1992))
+ .CALCULATE(supp_region=supplier.nation.region.name)
+ .order.WHERE(YEAR(order_date) == 1992)
+ .CALCULATE(cust_region=customer.nation.region.name)
)
rrt_combos = PARTITION(
line_info, name="lines", by=(supp_region, cust_region, part_type)
- )(n_instances=COUNT(lines))
- rr_combos = PARTITION(rrt_combos, name="part_types", by=(supp_region, cust_region))(
+ ).CALCULATE(n_instances=COUNT(lines))
+ rr_combos = PARTITION(
+ rrt_combos, name="part_types", by=(supp_region, cust_region)
+ ).CALCULATE(
percentage=100.0 * MAX(part_types.n_instances) / SUM(part_types.n_instances)
)
- return PARTITION(
- rr_combos,
- name="cust_regions",
- by=supp_region,
- )(supp_region, avg_percentage=AVG(cust_regions.percentage)).ORDER_BY(
- supp_region.ASC()
+ return (
+ PARTITION(
+ rr_combos,
+ name="cust_regions",
+ by=supp_region,
+ )
+ .CALCULATE(supp_region, avg_percentage=AVG(cust_regions.percentage))
+ .ORDER_BY(supp_region.ASC())
)
@@ -541,7 +668,7 @@ def hour_minute_day():
ordered by transaction ID in ascending order.
"""
return (
- Transactions(
+ Transactions.CALCULATE(
transaction_id, HOUR(date_time), MINUTE(date_time), SECOND(date_time)
)
.WHERE(ISIN(ticker.symbol, ("AAPL", "GOOGL", "NFLX")))
@@ -550,7 +677,7 @@ def hour_minute_day():
def exponentiation():
- return DailyPrices(
+ return DailyPrices.CALCULATE(
low_square=low**2,
low_sqrt=SQRT(low),
low_cbrt=POWER(low, 1 / 3),
@@ -564,7 +691,7 @@ def impl(*args, **kwargs):
terms[f"n_{color}"] = COUNT(parts.WHERE(CONTAINS(part_name, color)))
for n, size in kwargs.items():
terms[n] = COUNT(parts.WHERE(size == size))
- return TPCH(**terms)
+ return TPCH.CALCULATE(**terms)
result = impl("tomato", "almond", small=10, large=40)
return result
@@ -586,7 +713,7 @@ def unpacking_in_iterable():
terms = {}
for i, j in zip(range(5), range(1992, 1997)):
terms[f"c{i}"] = COUNT(orders.WHERE(YEAR(order_date) == j))
- return Nations(**terms)
+ return Nations.CALCULATE(**terms)
def with_import_statement():
@@ -637,52 +764,66 @@ def annotated_assignment():
def abs_round_magic_method():
- return DailyPrices(abs_low=abs(low), round_low=round(low, 2), round_zero=round(low))
+ return DailyPrices.CALCULATE(
+ abs_low=abs(low), round_low=round(low, 2), round_zero=round(low)
+ )
def years_months_days_hours_datediff():
y1_datetime = datetime.datetime(2025, 5, 2, 11, 00, 0)
- return Transactions.WHERE((YEAR(date_time) < 2025))(
- x=date_time,
- y1=y1_datetime,
- years_diff=DATEDIFF("years", date_time, y1_datetime),
- c_years_diff=DATEDIFF("YEARS", date_time, y1_datetime),
- c_y_diff=DATEDIFF("Y", date_time, y1_datetime),
- y_diff=DATEDIFF("y", date_time, y1_datetime),
- months_diff=DATEDIFF("months", date_time, y1_datetime),
- c_months_diff=DATEDIFF("MONTHS", date_time, y1_datetime),
- mm_diff=DATEDIFF("mm", date_time, y1_datetime),
- days_diff=DATEDIFF("days", date_time, y1_datetime),
- c_days_diff=DATEDIFF("DAYS", date_time, y1_datetime),
- c_d_diff=DATEDIFF("D", date_time, y1_datetime),
- d_diff=DATEDIFF("d", date_time, y1_datetime),
- hours_diff=DATEDIFF("hours", date_time, y1_datetime),
- c_hours_diff=DATEDIFF("HOURS", date_time, y1_datetime),
- c_h_diff=DATEDIFF("H", date_time, y1_datetime),
- ).TOP_K(30, by=years_diff.ASC())
+ return (
+ Transactions.WHERE((YEAR(date_time) < 2025))
+ .CALCULATE(
+ x=date_time,
+ y1=y1_datetime,
+ years_diff=DATEDIFF("years", date_time, y1_datetime),
+ c_years_diff=DATEDIFF("YEARS", date_time, y1_datetime),
+ c_y_diff=DATEDIFF("Y", date_time, y1_datetime),
+ y_diff=DATEDIFF("y", date_time, y1_datetime),
+ months_diff=DATEDIFF("months", date_time, y1_datetime),
+ c_months_diff=DATEDIFF("MONTHS", date_time, y1_datetime),
+ mm_diff=DATEDIFF("mm", date_time, y1_datetime),
+ days_diff=DATEDIFF("days", date_time, y1_datetime),
+ c_days_diff=DATEDIFF("DAYS", date_time, y1_datetime),
+ c_d_diff=DATEDIFF("D", date_time, y1_datetime),
+ d_diff=DATEDIFF("d", date_time, y1_datetime),
+ hours_diff=DATEDIFF("hours", date_time, y1_datetime),
+ c_hours_diff=DATEDIFF("HOURS", date_time, y1_datetime),
+ c_h_diff=DATEDIFF("H", date_time, y1_datetime),
+ )
+ .TOP_K(30, by=years_diff.ASC())
+ )
def minutes_seconds_datediff():
y_datetime = datetime.datetime(2023, 4, 3, 13, 16, 30)
- return Transactions.WHERE(YEAR(date_time) <= 2024)(
- x=date_time,
- y=y_datetime,
- minutes_diff=DATEDIFF("m", date_time, y_datetime),
- seconds_diff=DATEDIFF("s", date_time, y_datetime),
- ).TOP_K(30, by=x.DESC())
+ return (
+ Transactions.WHERE(YEAR(date_time) <= 2024)
+ .CALCULATE(
+ x=date_time,
+ y=y_datetime,
+ minutes_diff=DATEDIFF("m", date_time, y_datetime),
+ seconds_diff=DATEDIFF("s", date_time, y_datetime),
+ )
+ .TOP_K(30, by=x.DESC())
+ )
def datediff():
y1_datetime = datetime.datetime(2025, 5, 2, 11, 00, 0)
y_datetime = datetime.datetime(2023, 4, 3, 13, 16, 30)
- return Transactions.WHERE((YEAR(date_time) < 2025))(
- x=date_time,
- y1=y1_datetime,
- y=y_datetime,
- years_diff=DATEDIFF("years", date_time, y1_datetime),
- months_diff=DATEDIFF("months", date_time, y1_datetime),
- days_diff=DATEDIFF("days", date_time, y1_datetime),
- hours_diff=DATEDIFF("hours", date_time, y1_datetime),
- minutes_diff=DATEDIFF("minutes", date_time, y_datetime),
- seconds_diff=DATEDIFF("seconds", date_time, y_datetime),
- ).TOP_K(30, by=years_diff.ASC())
+ return (
+ Transactions.WHERE((YEAR(date_time) < 2025))
+ .CALCULATE(
+ x=date_time,
+ y1=y1_datetime,
+ y=y_datetime,
+ years_diff=DATEDIFF("years", date_time, y1_datetime),
+ months_diff=DATEDIFF("months", date_time, y1_datetime),
+ days_diff=DATEDIFF("days", date_time, y1_datetime),
+ hours_diff=DATEDIFF("hours", date_time, y1_datetime),
+ minutes_diff=DATEDIFF("minutes", date_time, y_datetime),
+ seconds_diff=DATEDIFF("seconds", date_time, y_datetime),
+ )
+ .TOP_K(30, by=years_diff.ASC())
+ )
diff --git a/tests/test_exploration.py b/tests/test_exploration.py
index 02f7b703..cb369425 100644
--- a/tests/test_exploration.py
+++ b/tests/test_exploration.py
@@ -6,8 +6,8 @@
import pytest
from exploration_examples import (
+ calc_subcollection_impl,
contextless_aggfunc_impl,
- contextless_back_impl,
contextless_collections_impl,
contextless_expr_impl,
contextless_func_impl,
@@ -17,10 +17,6 @@
global_calc_impl,
global_impl,
lineitems_arithmetic_impl,
- lps_back_lines_impl,
- lps_back_lines_price_impl,
- lps_back_supplier_impl,
- lps_back_supplier_name_impl,
nation_expr_impl,
nation_impl,
nation_name_impl,
@@ -769,8 +765,6 @@ def test_graph_structure(
The following terms will be included in the result if this collection is executed:
comment, key, name, region_key
-It is possible to use BACK to go up to 1 level above this collection.
-
The collection has access to the following expressions:
comment, key, name, region_key
@@ -806,12 +800,10 @@ def test_graph_structure(
PyDough collection representing the following logic:
TPCH
-This node is a reference to the global context for the entire graph. An operation must be done onto this node (e.g. a CALC or accessing a collection) before it can be executed.
+This node is a reference to the global context for the entire graph. An operation must be done onto this node (e.g. a CALCULATE or accessing a collection) before it can be executed.
The collection does not have any terms that can be included in a result if it is executed.
-It is not possible to use BACK from this collection.
-
The collection has access to the following collections:
Customers, Lineitems, Nations, Orders, PartSupp, Parts, Regions, Suppliers
@@ -819,7 +811,7 @@ def test_graph_structure(
expressions or collections that the collection has access to.
""",
"""
-This node is a reference to the global context for the entire graph. An operation must be done onto this node (e.g. a CALC or accessing a collection) before it can be executed.
+This node is a reference to the global context for the entire graph. An operation must be done onto this node (e.g. a CALCULATE or accessing a collection) before it can be executed.
The collection has access to the following collections:
Customers, Lineitems, Nations, Orders, PartSupp, Parts, Regions, Suppliers
@@ -839,7 +831,7 @@ def test_graph_structure(
"""
PyDough collection representing the following logic:
┌─── TPCH
- └─── Calc[x=42, y=13]
+ └─── Calculate[x=42, y=13]
The main task of this node is to calculate the following additional expressions that are added to the terms of the collection:
x <- 42
@@ -848,8 +840,6 @@ def test_graph_structure(
The following terms will be included in the result if this collection is executed:
x, y
-It is not possible to use BACK from this collection.
-
The collection has access to the following expressions:
x, y
@@ -885,7 +875,7 @@ def test_graph_structure(
"""
PyDough collection representing the following logic:
┌─── TPCH
- └─┬─ Calc[n_customers=COUNT($1), avg_part_price=AVG($2.retail_price)]
+ └─┬─ Calculate[n_customers=COUNT($1), avg_part_price=AVG($2.retail_price)]
├─┬─ AccessChild
│ └─── TableCollection[Customers]
└─┬─ AccessChild
@@ -904,8 +894,6 @@ def test_graph_structure(
The following terms will be included in the result if this collection is executed:
avg_part_price, n_customers
-It is not possible to use BACK from this collection.
-
The collection has access to the following expressions:
avg_part_price, n_customers
@@ -946,7 +934,7 @@ def test_graph_structure(
PyDough collection representing the following logic:
──┬─ TPCH
├─── TableCollection[Nations]
- └─┬─ Calc[name=name, region_name=$1.name, num_customers=COUNT($2)]
+ └─┬─ Calculate[name=name, region_name=$1.name, num_customers=COUNT($2)]
├─┬─ AccessChild
│ └─── SubCollection[region]
└─┬─ AccessChild
@@ -966,8 +954,6 @@ def test_graph_structure(
The following terms will be included in the result if this collection is executed:
name, num_customers, region_name
-It is possible to use BACK to go up to 1 level above this collection.
-
The collection has access to the following expressions:
comment, key, name, num_customers, region_key, region_name
@@ -1008,21 +994,21 @@ def test_graph_structure(
"""
PyDough collection representing the following logic:
──┬─ TPCH
- └─┬─ TableCollection[Regions]
- └─┬─ SubCollection[nations]
+ ├─── TableCollection[Regions]
+ └─┬─ Calculate[region_name=name]
+ ├─── SubCollection[nations]
+ └─┬─ Calculate[nation_name=name]
├─── SubCollection[customers]
- └─── Calc[name=name, nation_name=BACK(1).name, region_name=BACK(2).name]
+ └─── Calculate[name=name, nation_name=nation_name, region_name=region_name]
The main task of this node is to calculate the following additional expressions that are added to the terms of the collection:
name <- name (propagated from previous collection)
- nation_name <- BACK(1).name
- region_name <- BACK(2).name
+ nation_name <- nation_name (propagated from previous collection)
+ region_name <- region_name (propagated from previous collection)
The following terms will be included in the result if this collection is executed:
name, nation_name, region_name
-It is possible to use BACK to go up to 3 levels above this collection.
-
The collection has access to the following expressions:
acctbal, address, comment, key, mktsegment, name, nation_key, nation_name, phone, region_name
@@ -1035,8 +1021,8 @@ def test_graph_structure(
"""
The main task of this node is to calculate the following additional expressions that are added to the terms of the collection:
name <- name (propagated from previous collection)
- nation_name <- BACK(1).name
- region_name <- BACK(2).name
+ nation_name <- nation_name (propagated from previous collection)
+ region_name <- region_name (propagated from previous collection)
The collection has access to the following expressions:
acctbal, address, comment, key, mktsegment, name, nation_key, nation_name, phone, region_name
@@ -1052,6 +1038,48 @@ def test_graph_structure(
),
id="subcollection_calc_backref",
),
+ pytest.param(
+ (
+ "TPCH",
+ calc_subcollection_impl,
+ """
+PyDough collection representing the following logic:
+ ──┬─ TPCH
+ ├─── TableCollection[Nations]
+ └─┬─ Calculate[nation_name=name]
+ └─── SubCollection[region]
+
+This node, specifically, accesses the subcollection Nations.region. Call pydough.explain(graph['Nations']['region']) to learn more about this subcollection property.
+
+The following terms will be included in the result if this collection is executed:
+ comment, key, name
+
+The collection has access to the following expressions:
+ comment, key, name, nation_name
+
+The collection has access to the following collections:
+ customers, lines_sourced_from, nations, orders_shipped_to, suppliers
+
+Call pydough.explain_term(collection, term) to learn more about any of these
+expressions or collections that the collection has access to.
+ """,
+ """
+This node, specifically, accesses the subcollection Nations.region. Call pydough.explain(graph['Nations']['region']) to learn more about this subcollection property.
+
+The collection has access to the following expressions:
+ comment, key, name, nation_name
+
+The collection has access to the following collections:
+ customers, lines_sourced_from, nations, orders_shipped_to, suppliers
+
+Call pydough.explain_term(collection, term) to learn more about any of these
+expressions or collections that the collection has access to.
+
+Call pydough.explain(collection, verbose=True) for more details.
+ """,
+ ),
+ id="calc_subcollection",
+ ),
pytest.param(
(
"TPCH",
@@ -1060,7 +1088,7 @@ def test_graph_structure(
PyDough collection representing the following logic:
──┬─ TPCH
├─── TableCollection[Nations]
- ├─── Calc[name=name]
+ ├─── Calculate[nation_name=name]
└─┬─ Where[($1.name == 'ASIA') & HAS($2) & (COUNT($3) > 100)]
├─┬─ AccessChild
│ └─── SubCollection[region]
@@ -1095,12 +1123,10 @@ def test_graph_structure(
COUNT($3) > 100, aka COUNT(suppliers.WHERE(account_balance >= 0.0)) > 100
The following terms will be included in the result if this collection is executed:
- name
-
-It is possible to use BACK to go up to 1 level above this collection.
+ nation_name
The collection has access to the following expressions:
- comment, key, name, region_key
+ comment, key, name, nation_name, region_key
The collection has access to the following collections:
customers, orders_shipped_to, region, suppliers
@@ -1120,7 +1146,7 @@ def test_graph_structure(
COUNT($3) > 100, aka COUNT(suppliers.WHERE(account_balance >= 0.0)) > 100
The collection has access to the following expressions:
- comment, key, name, region_key
+ comment, key, name, nation_name, region_key
The collection has access to the following collections:
customers, orders_shipped_to, region, suppliers
@@ -1141,7 +1167,7 @@ def test_graph_structure(
PyDough collection representing the following logic:
──┬─ TPCH
├─── TableCollection[Nations]
- ├─── Calc[name=name]
+ ├─── Calculate[name=name]
└─┬─ OrderBy[COUNT($1).DESC(na_pos='last'), name.ASC(na_pos='first')]
└─┬─ AccessChild
└─── SubCollection[suppliers]
@@ -1157,8 +1183,6 @@ def test_graph_structure(
The following terms will be included in the result if this collection is executed:
name
-It is possible to use BACK to go up to 1 level above this collection.
-
The collection has access to the following expressions:
comment, key, name, region_key
@@ -1198,7 +1222,7 @@ def test_graph_structure(
PyDough collection representing the following logic:
──┬─ TPCH
├─── TableCollection[Parts]
- ├─┬─ Calc[name=name, n_suppliers=COUNT($1)]
+ ├─┬─ Calculate[name=name, n_suppliers=COUNT($1)]
│ └─┬─ AccessChild
│ └─── SubCollection[suppliers_of_part]
└─── TopK[100, n_suppliers.DESC(na_pos='last'), name.ASC(na_pos='first')]
@@ -1210,8 +1234,6 @@ def test_graph_structure(
The following terms will be included in the result if this collection is executed:
n_suppliers, name
-It is possible to use BACK to go up to 1 level above this collection.
-
The collection has access to the following expressions:
brand, comment, container, key, manufacturer, n_suppliers, name, part_type, retail_price, size
@@ -1262,8 +1284,6 @@ def test_graph_structure(
The following terms will be included in the result if this collection is executed:
part_type
-It is possible to use BACK to go up to 1 level above this collection.
-
The collection has access to the following expressions:
part_type
@@ -1305,22 +1325,19 @@ def test_graph_structure(
├─┬─ Partition[name='p', by=part_type]
│ └─┬─ AccessChild
│ └─── TableCollection[Parts]
- ├─┬─ Calc[part_type=part_type, avg_price=AVG($1.retail_price)]
+ ├─┬─ Calculate[part_type=part_type, avg_price=AVG($1.retail_price)]
│ └─┬─ AccessChild
│ └─── PartitionChild[p]
└─┬─ Where[avg_price >= 27.5]
└─── PartitionChild[p]
This node, specifically, accesses the unpartitioned data of a partitioning (child name: p).
-Using BACK(1) will access the partitioned data.
The following terms will be included in the result if this collection is executed:
brand, comment, container, key, manufacturer, name, part_type, retail_price, size
-It is possible to use BACK to go up to 2 levels above this collection.
-
The collection has access to the following expressions:
- brand, comment, container, key, manufacturer, name, part_type, retail_price, size
+ avg_price, brand, comment, container, key, manufacturer, name, part_type, retail_price, size
The collection has access to the following collections:
lines, suppliers_of_part, supply_records
@@ -1330,10 +1347,9 @@ def test_graph_structure(
""",
"""
This node, specifically, accesses the unpartitioned data of a partitioning (child name: p).
-Using BACK(1) will access the partitioned data.
The collection has access to the following expressions:
- brand, comment, container, key, manufacturer, name, part_type, retail_price, size
+ avg_price, brand, comment, container, key, manufacturer, name, part_type, retail_price, size
The collection has access to the following collections:
lines, suppliers_of_part, supply_records
@@ -1403,50 +1419,35 @@ def test_graph_structure(
),
id="not_qualified_collection_c",
),
- pytest.param(
- (
- "TPCH",
- contextless_back_impl,
- """
-Cannot call pydough.explain on BACK(1).fizz.
-Did you mean to use pydough.explain_term?
-""",
- """
-Cannot call pydough.explain on BACK(1).fizz.
-Did you mean to use pydough.explain_term?
-""",
- ),
- id="not_qualified_collection_d",
- ),
pytest.param(
(
"TPCH",
contextless_aggfunc_impl,
"""
-Cannot call pydough.explain on COUNT(?.customers).
+Cannot call pydough.explain on COUNT(customers).
Did you mean to use pydough.explain_term?
""",
"""
-Cannot call pydough.explain on COUNT(?.customers).
+Cannot call pydough.explain on COUNT(customers).
Did you mean to use pydough.explain_term?
""",
),
- id="not_qualified_collection_e",
+ id="not_qualified_collection_d",
),
pytest.param(
(
"TPCH",
contextless_func_impl,
"""
-Cannot call pydough.explain on LOWER(((?.first_name + ' ') + ?.last_name)).
+Cannot call pydough.explain on LOWER(((first_name + ' ') + last_name)).
Did you mean to use pydough.explain_term?
""",
"""
-Cannot call pydough.explain on LOWER(((?.first_name + ' ') + ?.last_name)).
+Cannot call pydough.explain on LOWER(((first_name + ' ') + last_name)).
Did you mean to use pydough.explain_term?
""",
),
- id="not_qualified_collection_f",
+ id="not_qualified_collection_e",
),
]
)
@@ -1512,9 +1513,9 @@ def test_unqualified_node_exploration(
This is column 'name' of collection 'Nations'
-This term is singular with regards to the collection, meaning it can be placed in a CALC of a collection.
+This term is singular with regards to the collection, meaning it can be placed in a CALCULATE of a collection.
For example, the following is valid:
- TPCH.Nations(name)
+ TPCH.Nations.CALCULATE(name)
""",
"""
Collection: TPCH.Nations
@@ -1543,9 +1544,9 @@ def test_unqualified_node_exploration(
This is a reference to expression 'name' of child $1
-This term is singular with regards to the collection, meaning it can be placed in a CALC of a collection.
+This term is singular with regards to the collection, meaning it can be placed in a CALCULATE of a collection.
For example, the following is valid:
- TPCH.Nations(region.name)
+ TPCH.Nations.CALCULATE(region.name)
""",
"""
Collection: TPCH.Nations
@@ -1575,7 +1576,7 @@ def test_unqualified_node_exploration(
This child is singular with regards to the collection, meaning its scalar terms can be accessed by the collection as if they were scalar terms of the expression.
For example, the following is valid:
- TPCH.Nations(region.comment)
+ TPCH.Nations.CALCULATE(region.comment)
To learn more about this child, you can try calling pydough.explain on the following:
TPCH.Nations.region
@@ -1605,7 +1606,7 @@ def test_unqualified_node_exploration(
This child is plural with regards to the collection, meaning its scalar terms can only be accessed by the collection if they are aggregated.
For example, the following are valid:
- TPCH.Regions(COUNT(nations.suppliers.account_balance))
+ TPCH.Regions.CALCULATE(COUNT(nations.suppliers.account_balance))
TPCH.Regions.WHERE(HAS(nations.suppliers))
TPCH.Regions.ORDER_BY(COUNT(nations.suppliers).DESC())
@@ -1639,9 +1640,9 @@ def test_unqualified_node_exploration(
This is a reference to expression 'name' of child $1
-This expression is plural with regards to the collection, meaning it can be placed in a CALC of a collection if it is aggregated.
+This expression is plural with regards to the collection, meaning it can be placed in a CALCULATE of a collection if it is aggregated.
For example, the following is valid:
- TPCH.Regions(COUNT(nations.suppliers.name))
+ TPCH.Regions.CALCULATE(COUNT(nations.suppliers.name))
""",
"""
Collection: TPCH.Regions
@@ -1663,166 +1664,32 @@ def test_unqualified_node_exploration(
"""
Collection:
──┬─ TPCH
- └─┬─ TableCollection[Regions]
+ ├─── TableCollection[Regions]
+ └─┬─ Calculate[region_name=name]
└─── SubCollection[nations]
-The term is the following expression: BACK(1).name
+The term is the following expression: region_name
-This is a reference to expression 'name' of the 1st ancestor of the collection, which is the following:
+This is a reference to expression 'region_name' of the 1st ancestor of the collection, which is the following:
──┬─ TPCH
- └─── TableCollection[Regions]
+ ├─── TableCollection[Regions]
+ └─── Calculate[region_name=name]
-This term is singular with regards to the collection, meaning it can be placed in a CALC of a collection.
+This term is singular with regards to the collection, meaning it can be placed in a CALCULATE of a collection.
For example, the following is valid:
- TPCH.Regions.nations(BACK(1).name)
+ TPCH.Regions.CALCULATE(region_name=name).nations.CALCULATE(region_name)
""",
"""
-Collection: TPCH.Regions.nations
+Collection: TPCH.Regions.CALCULATE(region_name=name).nations
-The term is the following expression: BACK(1).name
+The term is the following expression: region_name
-This is a reference to expression 'name' of the 1st ancestor of the collection, which is the following:
- TPCH.Regions
+This is a reference to expression 'region_name' of the 1st ancestor of the collection, which is the following:
+ TPCH.Regions.CALCULATE(region_name=name)
""",
),
id="region_nations-back_name",
),
- pytest.param(
- (
- "TPCH",
- lps_back_supplier_name_impl,
- """
-Collection:
- ──┬─ TPCH
- └─┬─ TableCollection[Lineitems]
- └─── SubCollection[part]
-
-The evaluation of this term first derives the following additional children to the collection before doing its main task:
- child $1:
- └─┬─ TableCollection[Lineitems]
- └─── BackSubCollection[1, supplier]
-
-The term is the following expression: $1.name
-
-This is a reference to expression 'name' of child $1
-
-This term is singular with regards to the collection, meaning it can be placed in a CALC of a collection.
-For example, the following is valid:
- TPCH.Lineitems.part(BACK(1).supplier.name)
- """,
- """
-Collection: TPCH.Lineitems.part
-
-The evaluation of this term first derives the following additional children to the collection before doing its main task:
- child $1: BACK(1).supplier
-
-The term is the following expression: $1.name
-
-This is a reference to expression 'name' of child $1
- """,
- ),
- id="lineitem_part-back_supplier_name",
- ),
- pytest.param(
- (
- "TPCH",
- lps_back_supplier_impl,
- """
-Collection:
- ──┬─ TPCH
- └─┬─ TableCollection[Lineitems]
- └─── SubCollection[part]
-
-The term is the following child of the collection:
- ──┬─ TPCH
- └─┬─ TableCollection[Lineitems]
- └─── BackSubCollection[1, supplier]
-
-This child is singular with regards to the collection, meaning its scalar terms can be accessed by the collection as if they were scalar terms of the expression.
-For example, the following is valid:
- TPCH.Lineitems.part(BACK(1).supplier.account_balance)
-
-To learn more about this child, you can try calling pydough.explain on the following:
- TPCH.Lineitems.supplier
- """,
- """
-Collection: TPCH.Lineitems.part
-
-The term is the following child of the collection:
- BACK(1).supplier
- """,
- ),
- id="lineitem_part-back_supplier",
- ),
- pytest.param(
- (
- "TPCH",
- lps_back_lines_price_impl,
- """
-Collection:
- ──┬─ TPCH
- └─┬─ TableCollection[PartSupp]
- └─── SubCollection[part]
-
-The evaluation of this term first derives the following additional children to the collection before doing its main task:
- child $1:
- └─┬─ TableCollection[PartSupp]
- └─── BackSubCollection[1, lines]
-
-The term is the following expression: $1.extended_price
-
-This is a reference to expression 'extended_price' of child $1
-
-This expression is plural with regards to the collection, meaning it can be placed in a CALC of a collection if it is aggregated.
-For example, the following is valid:
- TPCH.PartSupp.part(COUNT(BACK(1).lines.extended_price))
- """,
- """
-Collection: TPCH.PartSupp.part
-
-The evaluation of this term first derives the following additional children to the collection before doing its main task:
- child $1: BACK(1).lines
-
-The term is the following expression: $1.extended_price
-
-This is a reference to expression 'extended_price' of child $1
- """,
- ),
- id="partsupp_part-back_lines_price",
- ),
- pytest.param(
- (
- "TPCH",
- lps_back_lines_impl,
- """
-Collection:
- ──┬─ TPCH
- └─┬─ TableCollection[PartSupp]
- └─── SubCollection[part]
-
-The term is the following child of the collection:
- ──┬─ TPCH
- └─┬─ TableCollection[PartSupp]
- └─── BackSubCollection[1, lines]
-
-This child is plural with regards to the collection, meaning its scalar terms can only be accessed by the collection if they are aggregated.
-For example, the following are valid:
- TPCH.PartSupp.part(COUNT(BACK(1).lines.comment))
- TPCH.PartSupp.part.WHERE(HAS(BACK(1).lines))
- TPCH.PartSupp.part.ORDER_BY(COUNT(BACK(1).lines).DESC())
-
-To learn more about this child, you can try calling pydough.explain on the following:
- TPCH.PartSupp.lines
- """,
- """
-Collection: TPCH.PartSupp.part
-
-The term is the following child of the collection:
- BACK(1).lines
- """,
- ),
- id="partsupp_part-back_lines",
- ),
pytest.param(
(
"TPCH",
@@ -1845,9 +1712,9 @@ def test_unqualified_node_exploration(
Call pydough.explain_term with this collection and any of the arguments to learn more about them.
-This term is singular with regards to the collection, meaning it can be placed in a CALC of a collection.
+This term is singular with regards to the collection, meaning it can be placed in a CALCULATE of a collection.
For example, the following is valid:
- TPCH.Regions(COUNT(nations.suppliers.WHERE(account_balance > 0)))
+ TPCH.Regions.CALCULATE(COUNT(nations.suppliers.WHERE(account_balance > 0)))
""",
"""
Collection: TPCH.Regions
@@ -1887,9 +1754,9 @@ def test_unqualified_node_exploration(
Call pydough.explain_term with this collection and any of the arguments to learn more about them.
-This term is singular with regards to the collection, meaning it can be placed in a CALC of a collection.
+This term is singular with regards to the collection, meaning it can be placed in a CALCULATE of a collection.
For example, the following is valid:
- TPCH.Partition(Parts, name='p', by=part_type)(AVG(p.retail_price))
+ TPCH.Partition(Parts, name='p', by=part_type).CALCULATE(AVG(p.retail_price))
""",
"""
Collection: TPCH.Partition(Parts, name='p', by=part_type)
@@ -1927,7 +1794,7 @@ def test_unqualified_node_exploration(
This child is plural with regards to the collection, meaning its scalar terms can only be accessed by the collection if they are aggregated.
For example, the following are valid:
- TPCH.Partition(Parts, name='p', by=part_type).WHERE(AVG(p.retail_price) >= 27.5)(COUNT(p.brand))
+ TPCH.Partition(Parts, name='p', by=part_type).WHERE(AVG(p.retail_price) >= 27.5).CALCULATE(COUNT(p.brand))
TPCH.Partition(Parts, name='p', by=part_type).WHERE(AVG(p.retail_price) >= 27.5).WHERE(HAS(p))
TPCH.Partition(Parts, name='p', by=part_type).WHERE(AVG(p.retail_price) >= 27.5).ORDER_BY(COUNT(p).DESC())
@@ -1959,9 +1826,9 @@ def test_unqualified_node_exploration(
Call pydough.explain_term with this collection and any of the arguments to learn more about them.
-This term is singular with regards to the collection, meaning it can be placed in a CALC of a collection.
+This term is singular with regards to the collection, meaning it can be placed in a CALCULATE of a collection.
For example, the following is valid:
- TPCH.Nations(LOWER(name))
+ TPCH.Nations.CALCULATE(LOWER(name))
""",
"""
Collection: TPCH.Nations
@@ -1993,9 +1860,9 @@ def test_unqualified_node_exploration(
Call pydough.explain_term with this collection and any of the arguments to learn more about them.
-This term is singular with regards to the collection, meaning it can be placed in a CALC of a collection.
+This term is singular with regards to the collection, meaning it can be placed in a CALCULATE of a collection.
For example, the following is valid:
- TPCH.Lineitems(extended_price * (1 - discount))
+ TPCH.Lineitems.CALCULATE(extended_price * (1 - discount))
""",
"""
Collection: TPCH.Lineitems
@@ -2029,9 +1896,9 @@ def test_unqualified_node_exploration(
Call pydough.explain_term with this collection and any of the arguments to learn more about them.
-This term is singular with regards to the collection, meaning it can be placed in a CALC of a collection.
+This term is singular with regards to the collection, meaning it can be placed in a CALCULATE of a collection.
For example, the following is valid:
- TPCH.Suppliers(IFF(account_balance < 0, 0, account_balance))
+ TPCH.Suppliers.CALCULATE(IFF(account_balance < 0, 0, account_balance))
""",
"""
Collection: TPCH.Suppliers
@@ -2068,9 +1935,9 @@ def test_unqualified_node_exploration(
Call pydough.explain_term with this collection and any of the arguments to learn more about them.
-This term is singular with regards to the collection, meaning it can be placed in a CALC of a collection.
+This term is singular with regards to the collection, meaning it can be placed in a CALCULATE of a collection.
For example, the following is valid:
- TPCH.Customers(HASNOT(orders))
+ TPCH.Customers.CALCULATE(HASNOT(orders))
""",
"""
Collection: TPCH.Customers
@@ -2112,9 +1979,9 @@ def test_unqualified_node_exploration(
Call pydough.explain_term with this collection and any of the arguments to learn more about them.
-This term is singular with regards to the collection, meaning it can be placed in a CALC of a collection.
+This term is singular with regards to the collection, meaning it can be placed in a CALCULATE of a collection.
For example, the following is valid:
- TPCH.Parts(HAS(supply_records.supplier.WHERE(nation.name == 'GERMANY')))
+ TPCH.Parts.CALCULATE(HAS(supply_records.supplier.WHERE(nation.name == 'GERMANY')))
""",
"""
Collection: TPCH.Parts
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index 2cb612bd..adb5f5b8 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -48,6 +48,10 @@
minutes_seconds_datediff,
multi_partition_access_1,
multi_partition_access_2,
+ multi_partition_access_3,
+ multi_partition_access_4,
+ multi_partition_access_5,
+ multi_partition_access_6,
percentile_customers_per_region,
percentile_nations,
rank_nations_by_region,
@@ -58,6 +62,7 @@
rank_with_filters_c,
regional_suppliers_percentile,
simple_filter_top_five,
+ simple_scan,
simple_scan_top_five,
triple_partition,
years_months_days_hours_datediff,
@@ -118,6 +123,7 @@
from pydough.configs import PyDoughConfigs
from pydough.conversion.relational_converter import convert_ast_to_relational
from pydough.database_connectors import DatabaseContext
+from pydough.evaluation.evaluate_unqualified import _load_column_selection
from pydough.metadata import GraphMetadata
from pydough.qdag import PyDoughCollectionQDAG, PyDoughQDAG
from pydough.relational import RelationalRoot
@@ -133,6 +139,7 @@
pytest.param(
(
impl_tpch_q1,
+ None,
"tpch_q1",
tpch_q1_output,
),
@@ -141,6 +148,7 @@
pytest.param(
(
impl_tpch_q2,
+ None,
"tpch_q2",
tpch_q2_output,
),
@@ -149,6 +157,7 @@
pytest.param(
(
impl_tpch_q3,
+ None,
"tpch_q3",
tpch_q3_output,
),
@@ -157,6 +166,7 @@
pytest.param(
(
impl_tpch_q4,
+ None,
"tpch_q4",
tpch_q4_output,
),
@@ -165,6 +175,7 @@
pytest.param(
(
impl_tpch_q5,
+ None,
"tpch_q5",
tpch_q5_output,
),
@@ -173,6 +184,7 @@
pytest.param(
(
impl_tpch_q6,
+ None,
"tpch_q6",
tpch_q6_output,
),
@@ -181,6 +193,7 @@
pytest.param(
(
impl_tpch_q7,
+ None,
"tpch_q7",
tpch_q7_output,
),
@@ -189,6 +202,7 @@
pytest.param(
(
impl_tpch_q8,
+ None,
"tpch_q8",
tpch_q8_output,
),
@@ -197,6 +211,7 @@
pytest.param(
(
impl_tpch_q9,
+ None,
"tpch_q9",
tpch_q9_output,
),
@@ -205,6 +220,7 @@
pytest.param(
(
impl_tpch_q10,
+ None,
"tpch_q10",
tpch_q10_output,
),
@@ -213,6 +229,7 @@
pytest.param(
(
impl_tpch_q11,
+ None,
"tpch_q11",
tpch_q11_output,
),
@@ -221,6 +238,7 @@
pytest.param(
(
impl_tpch_q12,
+ None,
"tpch_q12",
tpch_q12_output,
),
@@ -229,6 +247,7 @@
pytest.param(
(
impl_tpch_q13,
+ None,
"tpch_q13",
tpch_q13_output,
),
@@ -237,6 +256,7 @@
pytest.param(
(
impl_tpch_q14,
+ None,
"tpch_q14",
tpch_q14_output,
),
@@ -245,6 +265,7 @@
pytest.param(
(
impl_tpch_q15,
+ None,
"tpch_q15",
tpch_q15_output,
),
@@ -253,6 +274,7 @@
pytest.param(
(
impl_tpch_q16,
+ None,
"tpch_q16",
tpch_q16_output,
),
@@ -261,6 +283,7 @@
pytest.param(
(
impl_tpch_q17,
+ None,
"tpch_q17",
tpch_q17_output,
),
@@ -269,6 +292,7 @@
pytest.param(
(
impl_tpch_q18,
+ None,
"tpch_q18",
tpch_q18_output,
),
@@ -277,6 +301,7 @@
pytest.param(
(
impl_tpch_q19,
+ None,
"tpch_q19",
tpch_q19_output,
),
@@ -285,6 +310,7 @@
pytest.param(
(
impl_tpch_q20,
+ None,
"tpch_q20",
tpch_q20_output,
),
@@ -293,6 +319,7 @@
pytest.param(
(
impl_tpch_q21,
+ None,
"tpch_q21",
tpch_q21_output,
),
@@ -301,6 +328,7 @@
pytest.param(
(
impl_tpch_q22,
+ None,
"tpch_q22",
tpch_q22_output,
),
@@ -309,6 +337,7 @@
pytest.param(
(
simple_scan_top_five,
+ None,
"simple_scan_top_five",
lambda: pd.DataFrame(
{
@@ -321,11 +350,11 @@
pytest.param(
(
simple_filter_top_five,
+ ["key"],
"simple_filter_top_five",
lambda: pd.DataFrame(
{
"key": [5989315, 5935174, 5881093, 5876066, 5866437],
- "total_price": [947.81, 974.01, 995.6, 967.55, 916.41],
}
),
),
@@ -334,6 +363,7 @@
pytest.param(
(
rank_nations_by_region,
+ None,
"rank_nations_by_region",
lambda: pd.DataFrame(
{
@@ -373,6 +403,7 @@
pytest.param(
(
rank_nations_per_region_by_customers,
+ None,
"rank_nations_per_region_by_customers",
lambda: pd.DataFrame(
{
@@ -386,6 +417,7 @@
pytest.param(
(
rank_parts_per_supplier_region_by_size,
+ None,
"rank_parts_per_supplier_region_by_size",
lambda: pd.DataFrame(
{
@@ -432,6 +464,7 @@
pytest.param(
(
rank_with_filters_a,
+ None,
"rank_with_filters_a",
lambda: pd.DataFrame(
{
@@ -449,6 +482,7 @@
pytest.param(
(
rank_with_filters_b,
+ None,
"rank_with_filters_b",
lambda: pd.DataFrame(
{
@@ -466,17 +500,18 @@
pytest.param(
(
rank_with_filters_c,
+ {"pname": "name", "psize": "size"},
"rank_with_filters_c",
lambda: pd.DataFrame(
{
- "size": [46, 47, 48, 49, 50],
- "name": [
+ "pname": [
"frosted powder drab burnished grey",
"lace khaki orange bisque beige",
"steel chartreuse navy ivory brown",
"forest azure almond antique violet",
"blanched floral red maroon papaya",
],
+ "psize": [46, 47, 48, 49, 50],
}
),
),
@@ -485,6 +520,7 @@
pytest.param(
(
percentile_nations,
+ {"name": "name", "p1": "p", "p2": "p"},
"percentile_nations",
lambda: pd.DataFrame(
{
@@ -515,7 +551,8 @@
"UNITED STATES",
"VIETNAM",
],
- "p": [1] * 5 + [2] * 5 + [3] * 5 + [4] * 5 + [5] * 5,
+ "p1": [1] * 5 + [2] * 5 + [3] * 5 + [4] * 5 + [5] * 5,
+ "p2": [1] * 5 + [2] * 5 + [3] * 5 + [4] * 5 + [5] * 5,
}
),
),
@@ -524,6 +561,7 @@
pytest.param(
(
percentile_customers_per_region,
+ None,
"percentile_customers_per_region",
lambda: pd.DataFrame(
{
@@ -547,6 +585,7 @@
pytest.param(
(
regional_suppliers_percentile,
+ ["name"],
"regional_suppliers_percentile",
lambda: pd.DataFrame(
{
@@ -568,6 +607,7 @@
pytest.param(
(
function_sampler,
+ None,
"function_sampler",
lambda: pd.DataFrame(
{
@@ -609,6 +649,7 @@
pytest.param(
(
datetime_current,
+ None,
"datetime_current",
lambda: pd.DataFrame(
{
@@ -630,6 +671,7 @@
pytest.param(
(
datetime_relative,
+ None,
"datetime_relative",
lambda: pd.DataFrame(
{
@@ -672,6 +714,7 @@
pytest.param(
(
agg_partition,
+ None,
"agg_partition",
lambda: pd.DataFrame(
{
@@ -684,6 +727,7 @@
pytest.param(
(
double_partition,
+ None,
"double_partition",
lambda: pd.DataFrame(
{
@@ -697,17 +741,18 @@
pytest.param(
(
triple_partition,
+ {"region": "supp_region", "avgpct": "avg_percentage"},
"triple_partition",
lambda: pd.DataFrame(
{
- "supp_region": [
+ "region": [
"AFRICA",
"AMERICA",
"ASIA",
"EUROPE",
"MIDDLE EAST",
],
- "avg_percentage": [
+ "avgpct": [
1.8038152,
1.9968418,
1.6850716,
@@ -722,10 +767,17 @@
pytest.param(
(
correl_1,
+ None,
"correl_1",
lambda: pd.DataFrame(
{
- "name": ["AFRICA", "AMERICA", "ASIA", "EUROPE", "MIDDLE EAST"],
+ "region_name": [
+ "AFRICA",
+ "AMERICA",
+ "ASIA",
+ "EUROPE",
+ "MIDDLE EAST",
+ ],
"n_prefix_nations": [1, 1, 0, 0, 0],
}
),
@@ -735,6 +787,7 @@
pytest.param(
(
correl_2,
+ None,
"correl_2",
lambda: pd.DataFrame(
{
@@ -770,10 +823,17 @@
pytest.param(
(
correl_3,
+ None,
"correl_3",
lambda: pd.DataFrame(
{
- "name": ["AFRICA", "AMERICA", "ASIA", "EUROPE", "MIDDLE EAST"],
+ "region_name": [
+ "AFRICA",
+ "AMERICA",
+ "ASIA",
+ "EUROPE",
+ "MIDDLE EAST",
+ ],
"n_nations": [5, 5, 5, 0, 2],
}
),
@@ -783,6 +843,7 @@
pytest.param(
(
correl_4,
+ None,
"correl_4",
lambda: pd.DataFrame(
{
@@ -795,6 +856,7 @@
pytest.param(
(
correl_5,
+ None,
"correl_5",
lambda: pd.DataFrame(
{
@@ -807,6 +869,7 @@
pytest.param(
(
correl_6,
+ None,
"correl_6",
lambda: pd.DataFrame(
{
@@ -820,6 +883,7 @@
pytest.param(
(
correl_7,
+ None,
"correl_7",
lambda: pd.DataFrame(
{
@@ -833,6 +897,7 @@
pytest.param(
(
correl_8,
+ None,
"correl_8",
lambda: pd.DataFrame(
{
@@ -872,6 +937,7 @@
pytest.param(
(
correl_9,
+ None,
"correl_9",
lambda: pd.DataFrame(
{
@@ -888,6 +954,7 @@
pytest.param(
(
correl_10,
+ None,
"correl_10",
lambda: pd.DataFrame(
{
@@ -925,6 +992,7 @@
pytest.param(
(
correl_11,
+ None,
"correl_11",
lambda: pd.DataFrame(
{"brand": ["Brand#33", "Brand#43", "Brand#45", "Brand#55"]}
@@ -935,6 +1003,7 @@
pytest.param(
(
correl_12,
+ None,
"correl_12",
lambda: pd.DataFrame(
{
@@ -953,6 +1022,7 @@
pytest.param(
(
correl_13,
+ None,
"correl_13",
lambda: pd.DataFrame({"n": [1129]}),
),
@@ -961,6 +1031,7 @@
pytest.param(
(
correl_14,
+ None,
"correl_14",
lambda: pd.DataFrame({"n": [66]}),
),
@@ -969,6 +1040,7 @@
pytest.param(
(
correl_15,
+ None,
"correl_15",
lambda: pd.DataFrame({"n": [61]}),
),
@@ -977,6 +1049,7 @@
pytest.param(
(
correl_16,
+ None,
"correl_16",
lambda: pd.DataFrame({"n": [929]}),
),
@@ -985,6 +1058,7 @@
pytest.param(
(
correl_17,
+ None,
"correl_17",
lambda: pd.DataFrame(
{
@@ -1023,6 +1097,7 @@
pytest.param(
(
correl_18,
+ None,
"correl_18",
lambda: pd.DataFrame({"n": [697]}),
),
@@ -1031,10 +1106,11 @@
pytest.param(
(
correl_19,
+ None,
"correl_19",
lambda: pd.DataFrame(
{
- "name": [
+ "supplier_name": [
"Supplier#000003934",
"Supplier#000003887",
"Supplier#000002628",
@@ -1050,6 +1126,7 @@
pytest.param(
(
correl_20,
+ None,
"correl_20",
lambda: pd.DataFrame({"n": [3002]}),
),
@@ -1058,6 +1135,7 @@
pytest.param(
(
correl_21,
+ None,
"correl_21",
lambda: pd.DataFrame({"n_sizes": [30]}),
),
@@ -1066,6 +1144,7 @@
pytest.param(
(
correl_22,
+ None,
"correl_22",
lambda: pd.DataFrame(
{
@@ -1085,6 +1164,7 @@
pytest.param(
(
correl_23,
+ None,
"correl_23",
lambda: pd.DataFrame({"n_sizes": [23]}),
),
@@ -1095,16 +1175,21 @@
def pydough_pipeline_test_data(
request,
) -> tuple[
- Callable[[UnqualifiedRoot], UnqualifiedNode], str, Callable[[], pd.DataFrame]
+ Callable[[], UnqualifiedNode],
+ dict[str, str] | list[str] | None,
+ str,
+ Callable[[], pd.DataFrame],
]:
"""
Test data for test_pydough_pipeline. Returns a tuple of the following
arguments:
1. `unqualified_impl`: a function that takes in an unqualified root and
creates the unqualified node for the TPCH query.
- 2. `file_name`: the name of the file containing the expected relational
+ 2. `columns`: a valid value for the `columns` argument of `to_sql` or
+ `to_df`.
+ 3. `file_name`: the name of the file containing the expected relational
plan.
- 3. `answer_impl`: a function that takes in nothing and returns the answer
+ 4. `answer_impl`: a function that takes in nothing and returns the answer
to a TPCH query as a Pandas DataFrame.
"""
return request.param
@@ -1112,7 +1197,10 @@ def pydough_pipeline_test_data(
def test_pipeline_until_relational(
pydough_pipeline_test_data: tuple[
- Callable[[UnqualifiedRoot], UnqualifiedNode], str, Callable[[], pd.DataFrame]
+ Callable[[], UnqualifiedNode],
+ dict[str, str] | list[str] | None,
+ str,
+ Callable[[], pd.DataFrame],
],
get_sample_graph: graph_fetcher,
default_config: PyDoughConfigs,
@@ -1126,7 +1214,7 @@ def test_pipeline_until_relational(
# Run the query through the stages from unqualified node to qualified node
# to relational tree, and confirm the tree string matches the expected
# structure.
- unqualified_impl, file_name, _ = pydough_pipeline_test_data
+ unqualified_impl, columns, file_name, _ = pydough_pipeline_test_data
file_path: str = get_plan_test_filename(file_name)
graph: GraphMetadata = get_sample_graph("TPCH")
UnqualifiedRoot(graph)
@@ -1135,7 +1223,9 @@ def test_pipeline_until_relational(
assert isinstance(
qualified, PyDoughCollectionQDAG
), "Expected qualified answer to be a collection, not an expression"
- relational: RelationalRoot = convert_ast_to_relational(qualified, default_config)
+ relational: RelationalRoot = convert_ast_to_relational(
+ qualified, _load_column_selection({"columns": columns}), default_config
+ )
if update_tests:
with open(file_path, "w") as f:
f.write(relational.to_tree_string() + "\n")
@@ -1150,7 +1240,10 @@ def test_pipeline_until_relational(
@pytest.mark.execute
def test_pipeline_e2e(
pydough_pipeline_test_data: tuple[
- Callable[[UnqualifiedRoot], UnqualifiedNode], str, Callable[[], pd.DataFrame]
+ Callable[[], UnqualifiedNode],
+ dict[str, str] | list[str] | None,
+ str,
+ Callable[[], pd.DataFrame],
],
get_sample_graph: graph_fetcher,
sqlite_tpch_db_context: DatabaseContext,
@@ -1158,41 +1251,78 @@ def test_pipeline_e2e(
"""
Test executing the TPC-H queries from the original code generation.
"""
- unqualified_impl, _, answer_impl = pydough_pipeline_test_data
+ unqualified_impl, columns, _, answer_impl = pydough_pipeline_test_data
graph: GraphMetadata = get_sample_graph("TPCH")
root: UnqualifiedNode = init_pydough_context(graph)(unqualified_impl)()
- result: pd.DataFrame = to_df(root, metadata=graph, database=sqlite_tpch_db_context)
+ result: pd.DataFrame = to_df(
+ root, columns=columns, metadata=graph, database=sqlite_tpch_db_context
+ )
pd.testing.assert_frame_equal(result, answer_impl())
@pytest.mark.execute
@pytest.mark.parametrize(
- "impl, error_msg",
+ "impl, columns, error_msg",
[
pytest.param(
bad_slice_1,
+ None,
"SLICE function currently only supports non-negative stop indices",
id="bad_slice_1",
),
pytest.param(
bad_slice_2,
+ None,
"SLICE function currently only supports non-negative start indices",
id="bad_slice_2",
),
pytest.param(
bad_slice_3,
+ None,
"SLICE function currently only supports a step of 1",
id="bad_slice_3",
),
pytest.param(
bad_slice_4,
+ None,
"SLICE function currently only supports a step of 1",
id="bad_slice_4",
),
+ pytest.param(
+ simple_scan,
+ [],
+ "Column selection must not be empty",
+ id="bad_columns_1",
+ ),
+ pytest.param(
+ simple_scan,
+ {},
+ "Column selection must not be empty",
+ id="bad_columns_2",
+ ),
+ pytest.param(
+ simple_scan,
+ ["A", "B", "C"],
+ "Unrecognized term of simple table collection 'Orders' in graph 'TPCH': 'A'",
+ id="bad_columns_3",
+ ),
+ pytest.param(
+ simple_scan,
+ {"X": "key", "W": "Y"},
+ "Unrecognized term of simple table collection 'Orders' in graph 'TPCH': 'Y'",
+ id="bad_columns_4",
+ ),
+ pytest.param(
+ simple_scan,
+ ["key", "key"],
+ "Duplicate column names found in root.",
+ id="bad_columns_5",
+ ),
],
)
def test_pipeline_e2e_errors(
- impl: Callable[[UnqualifiedRoot], UnqualifiedNode],
+ impl: Callable[[], UnqualifiedNode],
+ columns: dict[str, str] | list[str] | None,
error_msg: str,
get_sample_graph: graph_fetcher,
sqlite_tpch_db_context: DatabaseContext,
@@ -1204,7 +1334,7 @@ def test_pipeline_e2e_errors(
graph: GraphMetadata = get_sample_graph("TPCH")
with pytest.raises(Exception, match=error_msg):
root: UnqualifiedNode = init_pydough_context(graph)(impl)()
- to_df(root, metadata=graph, database=sqlite_tpch_db_context)
+ to_df(root, columns=columns, metadata=graph, database=sqlite_tpch_db_context)
@pytest.fixture(
@@ -1212,7 +1342,9 @@ def test_pipeline_e2e_errors(
pytest.param(
(
multi_partition_access_1,
+ None,
"Broker",
+ "multi_partition_access_1",
lambda: pd.DataFrame(
{"symbol": ["AAPL", "AMZN", "BRK.B", "FB", "GOOG"]}
),
@@ -1222,7 +1354,9 @@ def test_pipeline_e2e_errors(
pytest.param(
(
multi_partition_access_2,
+ None,
"Broker",
+ "multi_partition_access_2",
lambda: pd.DataFrame(
{
"transaction_id": [f"TX{i:03}" for i in (22, 24, 25, 27, 56)],
@@ -1235,18 +1369,154 @@ def test_pipeline_e2e_errors(
],
"symbol": ["MSFT", "TSLA", "GOOGL", "BRK.B", "FB"],
"transaction_type": ["sell", "sell", "buy", "buy", "sell"],
- "avg_shares_a": [56.66667, 55.0, 4.0, 55.5, 47.5],
- "avg_shares_b": [50.0, 41.66667, 3.33333, 37.33333, 47.5],
- "avg_shares_c": [50.625, 46.25, 40.0, 37.33333, 50.625],
+ "cus_tick_typ_avg_shares": [56.66667, 55.0, 4.0, 55.5, 47.5],
+ "cust_tick_avg_shares": [
+ 50.0,
+ 41.66667,
+ 3.33333,
+ 37.33333,
+ 47.5,
+ ],
+ "cust_avg_shares": [50.625, 46.25, 40.0, 37.33333, 50.625],
}
),
),
id="multi_partition_access_2",
),
+ pytest.param(
+ (
+ multi_partition_access_3,
+ None,
+ "Broker",
+ "multi_partition_access_3",
+ lambda: pd.DataFrame(
+ {
+ "symbol": [
+ "AAPL",
+ "AMZN",
+ "FB",
+ "GOOGL",
+ "JPM",
+ "MSFT",
+ "NFLX",
+ "PG",
+ "TSLA",
+ "V",
+ ],
+ "close": [
+ 153.5,
+ 3235,
+ 207,
+ 2535,
+ 133.75,
+ 284,
+ 320.5,
+ 143.25,
+ 187.75,
+ 223.5,
+ ],
+ }
+ ),
+ ),
+ id="multi_partition_access_3",
+ ),
+ pytest.param(
+ (
+ multi_partition_access_4,
+ None,
+ "Broker",
+ "multi_partition_access_4",
+ lambda: pd.DataFrame(
+ {
+ "transaction_id": [
+ f"TX{i:03}"
+ for i in (3, 4, 5, 6, 7, 8, 9, 40, 41, 42, 43, 47, 48, 49)
+ ],
+ }
+ ),
+ ),
+ id="multi_partition_access_4",
+ ),
+ pytest.param(
+ (
+ multi_partition_access_5,
+ None,
+ "Broker",
+ "multi_partition_access_5",
+ lambda: pd.DataFrame(
+ {
+ "transaction_id": [
+ f"TX{i:03}"
+ for i in (
+ 40,
+ 41,
+ 42,
+ 43,
+ 2,
+ 4,
+ 6,
+ 22,
+ 24,
+ 26,
+ 32,
+ 34,
+ 36,
+ 46,
+ 48,
+ 50,
+ 52,
+ 54,
+ 56,
+ )
+ ],
+ "n_ticker_type_trans": [1] * 4 + [5] * 15,
+ "n_ticker_trans": [1] * 4 + [6] * 15,
+ "n_type_trans": [29, 27] * 2 + [27] * 15,
+ }
+ ),
+ ),
+ id="multi_partition_access_5",
+ ),
+ pytest.param(
+ (
+ multi_partition_access_6,
+ None,
+ "Broker",
+ "multi_partition_access_6",
+ lambda: pd.DataFrame(
+ {
+ "transaction_id": [
+ f"TX{i:03}"
+ for i in (
+ 11,
+ 12,
+ 13,
+ 14,
+ 15,
+ 16,
+ 17,
+ 18,
+ 19,
+ 20,
+ 30,
+ 46,
+ 47,
+ 48,
+ 49,
+ 50,
+ )
+ ],
+ }
+ ),
+ ),
+ id="multi_partition_access_6",
+ ),
pytest.param(
(
hour_minute_day,
+ None,
"Broker",
+ "hour_minute_day",
lambda: pd.DataFrame(
{
"transaction_id": [
@@ -1271,12 +1541,14 @@ def test_pipeline_e2e_errors(
}
),
),
- id="broker_basic1",
+ id="hour_minute_day",
),
pytest.param(
(
exponentiation,
+ None,
"Broker",
+ "exponentiation",
lambda: pd.DataFrame(
{
"low_square": [
@@ -1323,7 +1595,9 @@ def test_pipeline_e2e_errors(
pytest.param(
(
years_months_days_hours_datediff,
+ None,
"Broker",
+ "years_months_days_hours_datediff",
lambda: pd.DataFrame(
data={
"x": [
@@ -1474,7 +1748,9 @@ def test_pipeline_e2e_errors(
pytest.param(
(
minutes_seconds_datediff,
+ None,
"Broker",
+ "minutes_seconds_datediff",
lambda: pd.DataFrame(
{
"x": [
@@ -1583,21 +1859,76 @@ def test_pipeline_e2e_errors(
)
def custom_defog_test_data(
request,
-) -> tuple[Callable[[], UnqualifiedNode], str, pd.DataFrame]:
+) -> tuple[
+ Callable[[], UnqualifiedNode],
+ dict[str, str] | list[str] | None,
+ str,
+ str,
+ pd.DataFrame,
+]:
"""
Test data for test_defog_e2e. Returns a tuple of the following
arguments:
1. `unqualified_impl`: a PyDough implementation function.
- 2. `graph_name`: the name of the graph from the defog database to use.
- 3. `answer_impl`: a function that takes in nothing and returns the answer
+ 2. `columns`: the columns to select from the relational plan (optional).
+ 3. `graph_name`: the name of the graph from the defog database to use.
+ 4. `file_name`: the name of the file containing the expected relational
+ plan.
+ 5. `answer_impl`: a function that takes in nothing and returns the answer
to a defog query as a Pandas DataFrame.
"""
return request.param
+def test_defog_until_relational(
+ custom_defog_test_data: tuple[
+ Callable[[], UnqualifiedNode],
+ dict[str, str] | list[str] | None,
+ str,
+ str,
+ pd.DataFrame,
+ ],
+ defog_graphs: graph_fetcher,
+ default_config: PyDoughConfigs,
+ get_plan_test_filename: Callable[[str], str],
+ update_tests: bool,
+):
+ """
+ Same as `test_pipeline_until_relational`, but for defog data.
+ """
+ unqualified_impl, columns, graph_name, file_name, _ = custom_defog_test_data
+ graph: GraphMetadata = defog_graphs(graph_name)
+ init_pydough_context(graph)(unqualified_impl)()
+ file_path: str = get_plan_test_filename(file_name)
+ UnqualifiedRoot(graph)
+ unqualified: UnqualifiedNode = init_pydough_context(graph)(unqualified_impl)()
+ qualified: PyDoughQDAG = qualify_node(unqualified, graph)
+ assert isinstance(
+ qualified, PyDoughCollectionQDAG
+ ), "Expected qualified answer to be a collection, not an expression"
+ relational: RelationalRoot = convert_ast_to_relational(
+ qualified, _load_column_selection({"columns": columns}), default_config
+ )
+ if update_tests:
+ with open(file_path, "w") as f:
+ f.write(relational.to_tree_string() + "\n")
+ else:
+ with open(file_path) as f:
+ expected_relational_string: str = f.read()
+ assert (
+ relational.to_tree_string() == expected_relational_string.strip()
+ ), "Mismatch between tree string representation of relational node and expected Relational tree string"
+
+
@pytest.mark.execute
def test_defog_e2e_with_custom_data(
- custom_defog_test_data: tuple[Callable[[], UnqualifiedNode], str, pd.DataFrame],
+ custom_defog_test_data: tuple[
+ Callable[[], UnqualifiedNode],
+ dict[str, str] | list[str] | None,
+ str,
+ str,
+ pd.DataFrame,
+ ],
defog_graphs: graph_fetcher,
sqlite_defog_connection: DatabaseContext,
):
@@ -1606,8 +1937,10 @@ def test_defog_e2e_with_custom_data(
comparing against the result of running the reference SQL query text on the
same database connector.
"""
- unqualified_impl, graph_name, answer_impl = custom_defog_test_data
+ unqualified_impl, columns, graph_name, _, answer_impl = custom_defog_test_data
graph: GraphMetadata = defog_graphs(graph_name)
root: UnqualifiedNode = init_pydough_context(graph)(unqualified_impl)()
- result: pd.DataFrame = to_df(root, metadata=graph, database=sqlite_defog_connection)
+ result: pd.DataFrame = to_df(
+ root, columns=columns, metadata=graph, database=sqlite_defog_connection
+ )
pd.testing.assert_frame_equal(result, answer_impl())
diff --git a/tests/test_plan_refsols/correl_1.txt b/tests/test_plan_refsols/correl_1.txt
index bcc6d73d..04d8dd6a 100644
--- a/tests/test_plan_refsols/correl_1.txt
+++ b/tests/test_plan_refsols/correl_1.txt
@@ -1,10 +1,12 @@
-ROOT(columns=[('name', name), ('n_prefix_nations', n_prefix_nations)], orderings=[(ordering_1):asc_first])
- PROJECT(columns={'n_prefix_nations': n_prefix_nations, 'name': name, 'ordering_1': name})
- PROJECT(columns={'n_prefix_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name})
- JOIN(conditions=[t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name})
- SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
+ROOT(columns=[('region_name', region_name), ('n_prefix_nations', n_prefix_nations)], orderings=[(ordering_1):asc_first])
+ PROJECT(columns={'n_prefix_nations': n_prefix_nations, 'ordering_1': region_name, 'region_name': region_name})
+ PROJECT(columns={'n_prefix_nations': DEFAULT_TO(agg_0, 0:int64), 'region_name': region_name})
+ JOIN(conditions=[t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'region_name': t0.region_name})
+ PROJECT(columns={'key': key, 'region_name': name})
+ SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
AGGREGATE(keys={'key': key}, aggregations={'agg_0': COUNT()})
- FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(name, None:unknown, 1:int64, None:unknown), columns={'key': key})
- JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name})
- SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
+ FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(region_name, None:unknown, 1:int64, None:unknown), columns={'key': key})
+ JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'name_3': t1.name, 'region_name': t0.region_name})
+ PROJECT(columns={'key': key, 'region_name': name})
+ SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey})
diff --git a/tests/test_plan_refsols/correl_10.txt b/tests/test_plan_refsols/correl_10.txt
index e762421d..b1417167 100644
--- a/tests/test_plan_refsols/correl_10.txt
+++ b/tests/test_plan_refsols/correl_10.txt
@@ -3,6 +3,7 @@ ROOT(columns=[('name', name), ('rname', rname)], orderings=[(ordering_0):asc_fir
PROJECT(columns={'name': name, 'rname': NULL_2})
FILTER(condition=True:bool, columns={'NULL_2': NULL_2, 'name': name})
JOIN(conditions=[t0.region_key == t1.key], types=['anti'], columns={'NULL_2': None:unknown, 'name': t0.name}, correl_name='corr1')
- SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey})
- FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'key': key})
+ PROJECT(columns={'name': name, 'nation_name': name, 'region_key': region_key})
+ SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey})
+ FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.nation_name, None:unknown, 1:int64, None:unknown), columns={'key': key})
SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
diff --git a/tests/test_plan_refsols/correl_12.txt b/tests/test_plan_refsols/correl_12.txt
index 60626462..aef57cf1 100644
--- a/tests/test_plan_refsols/correl_12.txt
+++ b/tests/test_plan_refsols/correl_12.txt
@@ -1,13 +1,13 @@
ROOT(columns=[('brand', brand)], orderings=[(ordering_2):asc_first])
PROJECT(columns={'brand': brand, 'ordering_2': brand})
FILTER(condition=True:bool, columns={'brand': brand})
- JOIN(conditions=[t0.brand == t1.brand], types=['semi'], columns={'brand': t0.brand}, correl_name='corr2')
- PROJECT(columns={'avg_price': avg_price, 'avg_price_2': agg_1, 'brand': brand})
- JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1, 'avg_price': t0.avg_price, 'brand': t1.brand})
- PROJECT(columns={'avg_price': agg_0})
+ JOIN(conditions=[t0.brand == t1.brand], types=['semi'], columns={'brand': t0.brand}, correl_name='corr1')
+ PROJECT(columns={'brand': brand, 'brand_avg_price': agg_1, 'global_avg_price': global_avg_price})
+ JOIN(conditions=[True:bool], types=['left'], columns={'agg_1': t1.agg_1, 'brand': t1.brand, 'global_avg_price': t0.global_avg_price})
+ PROJECT(columns={'global_avg_price': agg_0})
AGGREGATE(keys={}, aggregations={'agg_0': AVG(retail_price)})
SCAN(table=tpch.PART, columns={'retail_price': p_retailprice})
AGGREGATE(keys={'brand': brand}, aggregations={'agg_1': AVG(retail_price)})
SCAN(table=tpch.PART, columns={'brand': p_brand, 'retail_price': p_retailprice})
- FILTER(condition=retail_price > corr2.avg_price_2 & retail_price < corr2.avg_price & size < 3:int64, columns={'brand': brand})
+ FILTER(condition=retail_price > corr1.brand_avg_price & retail_price < corr1.global_avg_price & size < 3:int64, columns={'brand': brand})
SCAN(table=tpch.PART, columns={'brand': p_brand, 'retail_price': p_retailprice, 'size': p_size})
diff --git a/tests/test_plan_refsols/correl_15.txt b/tests/test_plan_refsols/correl_15.txt
index b71be80d..c063805c 100644
--- a/tests/test_plan_refsols/correl_15.txt
+++ b/tests/test_plan_refsols/correl_15.txt
@@ -2,12 +2,12 @@ ROOT(columns=[('n', n)], orderings=[])
PROJECT(columns={'n': agg_1})
AGGREGATE(keys={}, aggregations={'agg_1': COUNT()})
FILTER(condition=True:bool, columns={'account_balance': account_balance})
- JOIN(conditions=[t0.key == t1.supplier_key], types=['semi'], columns={'account_balance': t0.account_balance}, correl_name='corr4')
- PROJECT(columns={'account_balance': account_balance, 'avg_price': avg_price, 'avg_price_3': agg_0, 'key': key})
- JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'agg_0': t1.agg_0, 'avg_price': t0.avg_price, 'key': t0.key})
- FILTER(condition=nation_key == 19:int64, columns={'account_balance': account_balance, 'avg_price': avg_price, 'key': key})
- JOIN(conditions=[True:bool], types=['inner'], columns={'account_balance': t1.account_balance, 'avg_price': t0.avg_price, 'key': t1.key, 'nation_key': t1.nation_key})
- PROJECT(columns={'avg_price': agg_0})
+ JOIN(conditions=[t0.key == t1.supplier_key], types=['semi'], columns={'account_balance': t0.account_balance}, correl_name='corr3')
+ PROJECT(columns={'account_balance': account_balance, 'global_avg_price': global_avg_price, 'key': key, 'supplier_avg_price': agg_0})
+ JOIN(conditions=[t0.key == t1.supplier_key], types=['left'], columns={'account_balance': t0.account_balance, 'agg_0': t1.agg_0, 'global_avg_price': t0.global_avg_price, 'key': t0.key})
+ FILTER(condition=nation_key == 19:int64, columns={'account_balance': account_balance, 'global_avg_price': global_avg_price, 'key': key})
+ JOIN(conditions=[True:bool], types=['inner'], columns={'account_balance': t1.account_balance, 'global_avg_price': t0.global_avg_price, 'key': t1.key, 'nation_key': t1.nation_key})
+ PROJECT(columns={'global_avg_price': agg_0})
AGGREGATE(keys={}, aggregations={'agg_0': AVG(retail_price)})
SCAN(table=tpch.PART, columns={'retail_price': p_retailprice})
SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey, 'nation_key': s_nationkey})
@@ -16,7 +16,7 @@ ROOT(columns=[('n', n)], orderings=[])
SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey})
SCAN(table=tpch.PART, columns={'key': p_partkey, 'retail_price': p_retailprice})
FILTER(condition=True:bool, columns={'supplier_key': supplier_key})
- JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr3')
+ JOIN(conditions=[t0.part_key == t1.key], types=['semi'], columns={'supplier_key': t0.supplier_key}, correl_name='corr2')
SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost})
- FILTER(condition=container == 'LG DRUM':string & retail_price < corr3.supplycost * 1.5:float64 & retail_price < corr4.avg_price_3 & retail_price < corr4.avg_price * 0.85:float64, columns={'key': key})
+ FILTER(condition=container == 'LG DRUM':string & retail_price < corr2.supplycost * 1.5:float64 & retail_price < corr3.supplier_avg_price & retail_price < corr3.global_avg_price * 0.85:float64, columns={'key': key})
SCAN(table=tpch.PART, columns={'container': p_container, 'key': p_partkey, 'retail_price': p_retailprice})
diff --git a/tests/test_plan_refsols/correl_18.txt b/tests/test_plan_refsols/correl_18.txt
index 74fab0da..b8762428 100644
--- a/tests/test_plan_refsols/correl_18.txt
+++ b/tests/test_plan_refsols/correl_18.txt
@@ -8,10 +8,10 @@ ROOT(columns=[('n', n)], orderings=[])
FILTER(condition=YEAR(order_date) == 1993:int64, columns={'customer_key': customer_key, 'order_date': order_date})
SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'order_date': o_orderdate})
AGGREGATE(keys={'customer_key': customer_key, 'order_date': order_date}, aggregations={'agg_2': COUNT()})
- FILTER(condition=total_price_3 >= 0.5:float64 * total_price, columns={'customer_key': customer_key, 'order_date': order_date})
- FILTER(condition=YEAR(order_date_2) == 1993:int64, columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': total_price, 'total_price_3': total_price_3})
- JOIN(conditions=[t0.customer_key == t1.customer_key & t0.order_date == t1.order_date], types=['inner'], columns={'customer_key': t0.customer_key, 'order_date': t0.order_date, 'order_date_2': t1.order_date, 'total_price': t0.total_price, 'total_price_3': t1.total_price})
- PROJECT(columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': DEFAULT_TO(agg_1, 0:int64)})
+ FILTER(condition=total_price >= 0.5:float64 * total_price_sum, columns={'customer_key': customer_key, 'order_date': order_date})
+ FILTER(condition=YEAR(order_date_2) == 1993:int64, columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': total_price, 'total_price_sum': total_price_sum})
+ JOIN(conditions=[t0.customer_key == t1.customer_key & t0.order_date == t1.order_date], types=['inner'], columns={'customer_key': t0.customer_key, 'order_date': t0.order_date, 'order_date_2': t1.order_date, 'total_price': t1.total_price, 'total_price_sum': t0.total_price_sum})
+ PROJECT(columns={'customer_key': customer_key, 'order_date': order_date, 'total_price_sum': DEFAULT_TO(agg_1, 0:int64)})
FILTER(condition=DEFAULT_TO(agg_0, 0:int64) > 1:int64, columns={'agg_1': agg_1, 'customer_key': customer_key, 'order_date': order_date})
AGGREGATE(keys={'customer_key': customer_key, 'order_date': order_date}, aggregations={'agg_0': COUNT(), 'agg_1': SUM(total_price)})
FILTER(condition=YEAR(order_date) == 1993:int64, columns={'customer_key': customer_key, 'order_date': order_date, 'total_price': total_price})
diff --git a/tests/test_plan_refsols/correl_19.txt b/tests/test_plan_refsols/correl_19.txt
index ddaa017a..7499f08d 100644
--- a/tests/test_plan_refsols/correl_19.txt
+++ b/tests/test_plan_refsols/correl_19.txt
@@ -1,10 +1,11 @@
-ROOT(columns=[('name', name_0), ('n_super_cust', n_super_cust)], orderings=[(ordering_1):desc_last])
- LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'n_super_cust': n_super_cust, 'name_0': name_0, 'ordering_1': ordering_1}, orderings=[(ordering_1):desc_last])
- PROJECT(columns={'n_super_cust': n_super_cust, 'name_0': name_0, 'ordering_1': n_super_cust})
- PROJECT(columns={'n_super_cust': DEFAULT_TO(agg_0, 0:int64), 'name_0': name})
- JOIN(conditions=[t0.key_2 == t1.key_5 & t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name})
- JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name': t0.name})
- SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name, 'nation_key': s_nationkey})
+ROOT(columns=[('supplier_name', supplier_name), ('n_super_cust', n_super_cust)], orderings=[(ordering_1):desc_last])
+ LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'n_super_cust': n_super_cust, 'ordering_1': ordering_1, 'supplier_name': supplier_name}, orderings=[(ordering_1):desc_last])
+ PROJECT(columns={'n_super_cust': n_super_cust, 'ordering_1': n_super_cust, 'supplier_name': supplier_name})
+ PROJECT(columns={'n_super_cust': DEFAULT_TO(agg_0, 0:int64), 'supplier_name': supplier_name})
+ JOIN(conditions=[t0.key_2 == t1.key_5 & t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'supplier_name': t0.supplier_name})
+ JOIN(conditions=[t0.nation_key == t1.key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'supplier_name': t0.supplier_name})
+ PROJECT(columns={'key': key, 'nation_key': nation_key, 'supplier_name': name})
+ SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'name': s_name, 'nation_key': s_nationkey})
SCAN(table=tpch.NATION, columns={'key': n_nationkey})
AGGREGATE(keys={'key': key, 'key_5': key_5}, aggregations={'agg_0': COUNT()})
FILTER(condition=acctbal > account_balance, columns={'key': key, 'key_5': key_5})
diff --git a/tests/test_plan_refsols/correl_2.txt b/tests/test_plan_refsols/correl_2.txt
index b5d64b7c..3efa5d67 100644
--- a/tests/test_plan_refsols/correl_2.txt
+++ b/tests/test_plan_refsols/correl_2.txt
@@ -8,10 +8,11 @@ ROOT(columns=[('name', name_12), ('n_selected_custs', n_selected_custs)], orderi
SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey})
AGGREGATE(keys={'key': key, 'key_5': key_5}, aggregations={'agg_0': COUNT()})
- FILTER(condition=SLICE(comment_7, None:unknown, 1:int64, None:unknown) == LOWER(SLICE(name, None:unknown, 1:int64, None:unknown)), columns={'key': key, 'key_5': key_5})
- JOIN(conditions=[t0.key_5 == t1.nation_key], types=['inner'], columns={'comment_7': t1.comment, 'key': t0.key, 'key_5': t0.key_5, 'name': t0.name})
- JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_5': t1.key, 'name': t0.name})
- FILTER(condition=NOT(STARTSWITH(name, 'A':string)), columns={'key': key, 'name': name})
- SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
+ FILTER(condition=SLICE(comment_7, None:unknown, 1:int64, None:unknown) == LOWER(SLICE(region_name, None:unknown, 1:int64, None:unknown)), columns={'key': key, 'key_5': key_5})
+ JOIN(conditions=[t0.key_5 == t1.nation_key], types=['inner'], columns={'comment_7': t1.comment, 'key': t0.key, 'key_5': t0.key_5, 'region_name': t0.region_name})
+ JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_5': t1.key, 'region_name': t0.region_name})
+ FILTER(condition=NOT(STARTSWITH(name, 'A':string)), columns={'key': key, 'region_name': region_name})
+ PROJECT(columns={'key': key, 'name': name, 'region_name': name})
+ SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey})
SCAN(table=tpch.CUSTOMER, columns={'comment': c_comment, 'nation_key': c_nationkey})
diff --git a/tests/test_plan_refsols/correl_20.txt b/tests/test_plan_refsols/correl_20.txt
index 66ed3e35..18e2ee76 100644
--- a/tests/test_plan_refsols/correl_20.txt
+++ b/tests/test_plan_refsols/correl_20.txt
@@ -13,14 +13,15 @@ ROOT(columns=[('n', n)], orderings=[])
SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate})
SCAN(table=tpch.LINEITEM, columns={'line_number': l_linenumber, 'order_key': l_orderkey, 'supplier_key': l_suppkey})
SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'key': s_suppkey})
- PROJECT(columns={'domestic': name_27 == name, 'key': key, 'key_14': key_14, 'key_17': key_17, 'key_21': key_21, 'line_number': line_number, 'order_key': order_key})
- JOIN(conditions=[t0.nation_key_23 == t1.key], types=['inner'], columns={'key': t0.key, 'key_14': t0.key_14, 'key_17': t0.key_17, 'key_21': t0.key_21, 'line_number': t0.line_number, 'name': t0.name, 'name_27': t1.name, 'order_key': t0.order_key})
- JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'key': t0.key, 'key_14': t0.key_14, 'key_17': t0.key_17, 'key_21': t1.key, 'line_number': t0.line_number, 'name': t0.name, 'nation_key_23': t1.nation_key, 'order_key': t0.order_key})
- JOIN(conditions=[t0.key_17 == t1.order_key], types=['inner'], columns={'key': t0.key, 'key_14': t0.key_14, 'key_17': t0.key_17, 'line_number': t1.line_number, 'name': t0.name, 'order_key': t1.order_key, 'supplier_key': t1.supplier_key})
- FILTER(condition=YEAR(order_date) == 1998:int64 & MONTH(order_date) == 6:int64, columns={'key': key, 'key_14': key_14, 'key_17': key_17, 'name': name})
- JOIN(conditions=[t0.key_14 == t1.customer_key], types=['inner'], columns={'key': t0.key, 'key_14': t0.key_14, 'key_17': t1.key, 'name': t0.name, 'order_date': t1.order_date})
- JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'key': t0.key, 'key_14': t1.key, 'name': t0.name})
- SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name})
+ PROJECT(columns={'domestic': name_27 == source_nation_name, 'key': key, 'key_14': key_14, 'key_17': key_17, 'key_21': key_21, 'line_number': line_number, 'order_key': order_key})
+ JOIN(conditions=[t0.nation_key_23 == t1.key], types=['inner'], columns={'key': t0.key, 'key_14': t0.key_14, 'key_17': t0.key_17, 'key_21': t0.key_21, 'line_number': t0.line_number, 'name_27': t1.name, 'order_key': t0.order_key, 'source_nation_name': t0.source_nation_name})
+ JOIN(conditions=[t0.supplier_key == t1.key], types=['inner'], columns={'key': t0.key, 'key_14': t0.key_14, 'key_17': t0.key_17, 'key_21': t1.key, 'line_number': t0.line_number, 'nation_key_23': t1.nation_key, 'order_key': t0.order_key, 'source_nation_name': t0.source_nation_name})
+ JOIN(conditions=[t0.key_17 == t1.order_key], types=['inner'], columns={'key': t0.key, 'key_14': t0.key_14, 'key_17': t0.key_17, 'line_number': t1.line_number, 'order_key': t1.order_key, 'source_nation_name': t0.source_nation_name, 'supplier_key': t1.supplier_key})
+ FILTER(condition=YEAR(order_date) == 1998:int64 & MONTH(order_date) == 6:int64, columns={'key': key, 'key_14': key_14, 'key_17': key_17, 'source_nation_name': source_nation_name})
+ JOIN(conditions=[t0.key_14 == t1.customer_key], types=['inner'], columns={'key': t0.key, 'key_14': t0.key_14, 'key_17': t1.key, 'order_date': t1.order_date, 'source_nation_name': t0.source_nation_name})
+ JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'key': t0.key, 'key_14': t1.key, 'source_nation_name': t0.source_nation_name})
+ PROJECT(columns={'key': key, 'source_nation_name': name})
+ SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name})
SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey})
SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate})
SCAN(table=tpch.LINEITEM, columns={'line_number': l_linenumber, 'order_key': l_orderkey, 'supplier_key': l_suppkey})
diff --git a/tests/test_plan_refsols/correl_3.txt b/tests/test_plan_refsols/correl_3.txt
index 2bbb01bc..8ad33431 100644
--- a/tests/test_plan_refsols/correl_3.txt
+++ b/tests/test_plan_refsols/correl_3.txt
@@ -1,13 +1,15 @@
-ROOT(columns=[('name', name), ('n_nations', n_nations)], orderings=[(ordering_1):asc_first])
- PROJECT(columns={'n_nations': n_nations, 'name': name, 'ordering_1': name})
- PROJECT(columns={'n_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name})
- JOIN(conditions=[t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name})
- SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
+ROOT(columns=[('region_name', region_name), ('n_nations', n_nations)], orderings=[(ordering_1):asc_first])
+ PROJECT(columns={'n_nations': n_nations, 'ordering_1': name, 'region_name': region_name})
+ PROJECT(columns={'n_nations': DEFAULT_TO(agg_0, 0:int64), 'name': name, 'region_name': region_name})
+ JOIN(conditions=[t0.key == t1.key], types=['left'], columns={'agg_0': t1.agg_0, 'name': t0.name, 'region_name': t0.region_name})
+ PROJECT(columns={'key': key, 'name': name, 'region_name': name})
+ SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
AGGREGATE(keys={'key': key}, aggregations={'agg_0': COUNT()})
FILTER(condition=True:bool, columns={'key': key})
JOIN(conditions=[t0.key_2 == t1.nation_key], types=['semi'], columns={'key': t0.key}, correl_name='corr4')
- JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name': t0.name})
- SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
+ JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'region_name': t0.region_name})
+ PROJECT(columns={'key': key, 'region_name': name})
+ SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey})
- FILTER(condition=SLICE(comment, None:unknown, 2:int64, None:unknown) == LOWER(SLICE(corr4.name, None:unknown, 2:int64, None:unknown)), columns={'nation_key': nation_key})
+ FILTER(condition=SLICE(comment, None:unknown, 2:int64, None:unknown) == LOWER(SLICE(corr4.region_name, None:unknown, 2:int64, None:unknown)), columns={'nation_key': nation_key})
SCAN(table=tpch.CUSTOMER, columns={'comment': c_comment, 'nation_key': c_nationkey})
diff --git a/tests/test_plan_refsols/correl_5.txt b/tests/test_plan_refsols/correl_5.txt
index 73d793f0..3ea31b09 100644
--- a/tests/test_plan_refsols/correl_5.txt
+++ b/tests/test_plan_refsols/correl_5.txt
@@ -1,13 +1,13 @@
ROOT(columns=[('name', name)], orderings=[(ordering_1):asc_first])
PROJECT(columns={'name': name, 'ordering_1': name})
FILTER(condition=True:bool, columns={'name': name})
- JOIN(conditions=[t0.key == t1.region_key], types=['semi'], columns={'name': t0.name}, correl_name='corr4')
+ JOIN(conditions=[t0.key == t1.region_key], types=['semi'], columns={'name': t0.name}, correl_name='corr3')
JOIN(conditions=[True:bool], types=['inner'], columns={'key': t1.key, 'name': t1.name, 'smallest_bal': t0.smallest_bal})
PROJECT(columns={'smallest_bal': agg_0})
AGGREGATE(keys={}, aggregations={'agg_0': MIN(account_balance)})
SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal})
SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
- FILTER(condition=account_balance <= corr4.smallest_bal + 4.0:float64, columns={'region_key': region_key})
+ FILTER(condition=account_balance <= corr3.smallest_bal + 4.0:float64, columns={'region_key': region_key})
JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'account_balance': t1.account_balance, 'region_key': t0.region_key})
SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey})
SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'nation_key': s_nationkey})
diff --git a/tests/test_plan_refsols/correl_6.txt b/tests/test_plan_refsols/correl_6.txt
index a6829877..d073f37b 100644
--- a/tests/test_plan_refsols/correl_6.txt
+++ b/tests/test_plan_refsols/correl_6.txt
@@ -4,7 +4,8 @@ ROOT(columns=[('name', name), ('n_prefix_nations', n_prefix_nations)], orderings
JOIN(conditions=[t0.key == t1.key], types=['inner'], columns={'agg_0': t1.agg_0, 'name': t0.name})
SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
AGGREGATE(keys={'key': key}, aggregations={'agg_0': COUNT()})
- FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(name, None:unknown, 1:int64, None:unknown), columns={'key': key})
- JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name})
- SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
+ FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(region_name, None:unknown, 1:int64, None:unknown), columns={'key': key})
+ JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'name_3': t1.name, 'region_name': t0.region_name})
+ PROJECT(columns={'key': key, 'region_name': name})
+ SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey})
diff --git a/tests/test_plan_refsols/correl_7.txt b/tests/test_plan_refsols/correl_7.txt
index 94a129af..37269060 100644
--- a/tests/test_plan_refsols/correl_7.txt
+++ b/tests/test_plan_refsols/correl_7.txt
@@ -2,6 +2,7 @@ ROOT(columns=[('name', name), ('n_prefix_nations', n_prefix_nations)], orderings
PROJECT(columns={'n_prefix_nations': DEFAULT_TO(NULL_2, 0:int64), 'name': name})
FILTER(condition=True:bool, columns={'NULL_2': NULL_2, 'name': name})
JOIN(conditions=[t0.key == t1.region_key], types=['anti'], columns={'NULL_2': None:unknown, 'name': t0.name}, correl_name='corr1')
- SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
- FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.name, None:unknown, 1:int64, None:unknown), columns={'region_key': region_key})
+ PROJECT(columns={'key': key, 'name': name, 'region_name': name})
+ SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
+ FILTER(condition=SLICE(name, None:unknown, 1:int64, None:unknown) == SLICE(corr1.region_name, None:unknown, 1:int64, None:unknown), columns={'region_key': region_key})
SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey})
diff --git a/tests/test_plan_refsols/correl_8.txt b/tests/test_plan_refsols/correl_8.txt
index 8da228f0..c26832eb 100644
--- a/tests/test_plan_refsols/correl_8.txt
+++ b/tests/test_plan_refsols/correl_8.txt
@@ -3,7 +3,8 @@ ROOT(columns=[('name', name), ('rname', rname)], orderings=[(ordering_0):asc_fir
PROJECT(columns={'name': name, 'rname': name_3})
JOIN(conditions=[t0.key == t1.key], types=['left'], columns={'name': t0.name, 'name_3': t1.name_3})
SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name})
- FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(name, None:unknown, 1:int64, None:unknown), columns={'key': key, 'name_3': name_3})
- JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name})
- SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey})
+ FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(nation_name, None:unknown, 1:int64, None:unknown), columns={'key': key, 'name_3': name_3})
+ JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_3': t1.name, 'nation_name': t0.nation_name})
+ PROJECT(columns={'key': key, 'nation_name': name, 'region_key': region_key})
+ SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey})
SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
diff --git a/tests/test_plan_refsols/correl_9.txt b/tests/test_plan_refsols/correl_9.txt
index 449cceea..2f81e6fa 100644
--- a/tests/test_plan_refsols/correl_9.txt
+++ b/tests/test_plan_refsols/correl_9.txt
@@ -4,7 +4,8 @@ ROOT(columns=[('name', name), ('rname', rname)], orderings=[(ordering_0):asc_fir
FILTER(condition=True:bool, columns={'name': name, 'name_3': name_3})
JOIN(conditions=[t0.key == t1.key], types=['inner'], columns={'name': t0.name, 'name_3': t1.name_3})
SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name})
- FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(name, None:unknown, 1:int64, None:unknown), columns={'key': key, 'name_3': name_3})
- JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name': t0.name, 'name_3': t1.name})
- SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey})
+ FILTER(condition=SLICE(name_3, None:unknown, 1:int64, None:unknown) == SLICE(nation_name, None:unknown, 1:int64, None:unknown), columns={'key': key, 'name_3': name_3})
+ JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'key': t0.key, 'name_3': t1.name, 'nation_name': t0.nation_name})
+ PROJECT(columns={'key': key, 'nation_name': name, 'region_key': region_key})
+ SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey})
SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
diff --git a/tests/test_plan_refsols/exponentiation.txt b/tests/test_plan_refsols/exponentiation.txt
new file mode 100644
index 00000000..709f9b85
--- /dev/null
+++ b/tests/test_plan_refsols/exponentiation.txt
@@ -0,0 +1,5 @@
+ROOT(columns=[('low_square', low_square), ('low_sqrt', low_sqrt), ('low_cbrt', low_cbrt)], orderings=[(ordering_0):asc_first])
+ LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'low_cbrt': low_cbrt, 'low_sqrt': low_sqrt, 'low_square': low_square, 'ordering_0': ordering_0}, orderings=[(ordering_0):asc_first])
+ PROJECT(columns={'low_cbrt': low_cbrt, 'low_sqrt': low_sqrt, 'low_square': low_square, 'ordering_0': low_square})
+ PROJECT(columns={'low_cbrt': POWER(low, 0.3333333333333333:float64), 'low_sqrt': SQRT(low), 'low_square': low ** 2:int64})
+ SCAN(table=main.sbDailyPrice, columns={'low': sbDpLow})
diff --git a/tests/test_plan_refsols/function_sampler.txt b/tests/test_plan_refsols/function_sampler.txt
index 86507720..a1b94cea 100644
--- a/tests/test_plan_refsols/function_sampler.txt
+++ b/tests/test_plan_refsols/function_sampler.txt
@@ -2,9 +2,11 @@ ROOT(columns=[('a', a), ('b', b), ('c', c), ('d', d), ('e', e)], orderings=[(ord
LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'a': a, 'b': b, 'c': c, 'd': d, 'e': e, 'ordering_0': ordering_0}, orderings=[(ordering_0):asc_first])
PROJECT(columns={'a': a, 'b': b, 'c': c, 'd': d, 'e': e, 'ordering_0': address})
FILTER(condition=MONOTONIC(0.0:float64, acctbal, 100.0:float64), columns={'a': a, 'address': address, 'b': b, 'c': c, 'd': d, 'e': e})
- PROJECT(columns={'a': JOIN_STRINGS('-':string, name, name_3, SLICE(name_6, 16:int64, None:unknown, None:unknown)), 'acctbal': acctbal, 'address': address, 'b': ROUND(acctbal, 1:int64), 'c': KEEP_IF(name_6, SLICE(phone, None:unknown, 1:int64, None:unknown) == '3':string), 'd': PRESENT(KEEP_IF(name_6, SLICE(phone, 1:int64, 2:int64, None:unknown) == '1':string)), 'e': ABSENT(KEEP_IF(name_6, SLICE(phone, 14:int64, None:unknown, None:unknown) == '7':string))})
- JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'acctbal': t1.acctbal, 'address': t1.address, 'name': t0.name, 'name_3': t0.name_3, 'name_6': t1.name, 'phone': t1.phone})
- JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name, 'name_3': t1.name})
- SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
- SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey})
+ PROJECT(columns={'a': JOIN_STRINGS('-':string, region_name, nation_name, SLICE(name_6, 16:int64, None:unknown, None:unknown)), 'acctbal': acctbal, 'address': address, 'b': ROUND(acctbal, 1:int64), 'c': KEEP_IF(name_6, SLICE(phone, None:unknown, 1:int64, None:unknown) == '3':string), 'd': PRESENT(KEEP_IF(name_6, SLICE(phone, 1:int64, 2:int64, None:unknown) == '1':string)), 'e': ABSENT(KEEP_IF(name_6, SLICE(phone, 14:int64, None:unknown, None:unknown) == '7':string))})
+ JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'acctbal': t1.acctbal, 'address': t1.address, 'name_6': t1.name, 'nation_name': t0.nation_name, 'phone': t1.phone, 'region_name': t0.region_name})
+ PROJECT(columns={'key_2': key_2, 'nation_name': name_3, 'region_name': region_name})
+ JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key_2': t1.key, 'name_3': t1.name, 'region_name': t0.region_name})
+ PROJECT(columns={'key': key, 'region_name': name})
+ SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
+ SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey})
SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'address': c_address, 'name': c_name, 'nation_key': c_nationkey, 'phone': c_phone})
diff --git a/tests/test_plan_refsols/hour_minute_day.txt b/tests/test_plan_refsols/hour_minute_day.txt
new file mode 100644
index 00000000..f3c52728
--- /dev/null
+++ b/tests/test_plan_refsols/hour_minute_day.txt
@@ -0,0 +1,7 @@
+ROOT(columns=[('transaction_id', transaction_id), ('_expr0', _expr0), ('_expr1', _expr1), ('_expr2', _expr2)], orderings=[(ordering_0):asc_first])
+ PROJECT(columns={'_expr0': _expr0, '_expr1': _expr1, '_expr2': _expr2, 'ordering_0': transaction_id, 'transaction_id': transaction_id})
+ FILTER(condition=ISIN(symbol, ['AAPL', 'GOOGL', 'NFLX']:array[unknown]), columns={'_expr0': _expr0, '_expr1': _expr1, '_expr2': _expr2, 'transaction_id': transaction_id})
+ JOIN(conditions=[t0.ticker_id == t1._id], types=['left'], columns={'_expr0': t0._expr0, '_expr1': t0._expr1, '_expr2': t0._expr2, 'symbol': t1.symbol, 'transaction_id': t0.transaction_id})
+ PROJECT(columns={'_expr0': HOUR(date_time), '_expr1': MINUTE(date_time), '_expr2': SECOND(date_time), 'ticker_id': ticker_id, 'transaction_id': transaction_id})
+ SCAN(table=main.sbTransaction, columns={'date_time': sbTxDateTime, 'ticker_id': sbTxTickerId, 'transaction_id': sbTxId})
+ SCAN(table=main.sbTicker, columns={'_id': sbTickerId, 'symbol': sbTickerSymbol})
diff --git a/tests/test_plan_refsols/lines_shipping_vs_customer_region.txt b/tests/test_plan_refsols/lines_shipping_vs_customer_region.txt
index 643e3473..d1d9d589 100644
--- a/tests/test_plan_refsols/lines_shipping_vs_customer_region.txt
+++ b/tests/test_plan_refsols/lines_shipping_vs_customer_region.txt
@@ -1,5 +1,5 @@
-ROOT(columns=[('order_year', order_year), ('customer_region', customer_region), ('customer_nation', customer_nation), ('supplier_region', supplier_region), ('nation_name', nation_name)], orderings=[])
- PROJECT(columns={'customer_nation': name_3, 'customer_region': name, 'nation_name': nation_name, 'order_year': YEAR(order_date), 'supplier_region': name_16})
+ROOT(columns=[('order_year', order_year), ('customer_region_name', customer_region_name), ('customer_nation_name', customer_nation_name), ('supplier_region_name', supplier_region_name), ('nation_name', nation_name)], orderings=[])
+ PROJECT(columns={'customer_nation_name': name_3, 'customer_region_name': name, 'nation_name': nation_name, 'order_year': YEAR(order_date), 'supplier_region_name': name_16})
JOIN(conditions=[t0.part_key == t1.part_key & t0.supplier_key == t1.supplier_key], types=['left'], columns={'name': t0.name, 'name_16': t1.name_16, 'name_3': t0.name_3, 'nation_name': t1.nation_name, 'order_date': t0.order_date})
JOIN(conditions=[t0.key_8 == t1.order_key], types=['inner'], columns={'name': t0.name, 'name_3': t0.name_3, 'order_date': t0.order_date, 'part_key': t1.part_key, 'supplier_key': t1.supplier_key})
JOIN(conditions=[t0.key_5 == t1.customer_key], types=['inner'], columns={'key_8': t1.key, 'name': t0.name, 'name_3': t0.name_3, 'order_date': t1.order_date})
diff --git a/tests/test_plan_refsols/minutes_seconds_datediff.txt b/tests/test_plan_refsols/minutes_seconds_datediff.txt
new file mode 100644
index 00000000..a36bf553
--- /dev/null
+++ b/tests/test_plan_refsols/minutes_seconds_datediff.txt
@@ -0,0 +1,6 @@
+ROOT(columns=[('x', x), ('y', y), ('minutes_diff', minutes_diff), ('seconds_diff', seconds_diff)], orderings=[(ordering_0):desc_last])
+ LIMIT(limit=Literal(value=30, type=Int64Type()), columns={'minutes_diff': minutes_diff, 'ordering_0': ordering_0, 'seconds_diff': seconds_diff, 'x': x, 'y': y}, orderings=[(ordering_0):desc_last])
+ PROJECT(columns={'minutes_diff': minutes_diff, 'ordering_0': x, 'seconds_diff': seconds_diff, 'x': x, 'y': y})
+ PROJECT(columns={'minutes_diff': DATEDIFF('m':string, date_time, datetime.datetime(2023, 4, 3, 13, 16, 30):date), 'seconds_diff': DATEDIFF('s':string, date_time, datetime.datetime(2023, 4, 3, 13, 16, 30):date), 'x': date_time, 'y': datetime.datetime(2023, 4, 3, 13, 16, 30):date})
+ FILTER(condition=YEAR(date_time) <= 2024:int64, columns={'date_time': date_time})
+ SCAN(table=main.sbTransaction, columns={'date_time': sbTxDateTime})
diff --git a/tests/test_plan_refsols/multi_partition_access_1.txt b/tests/test_plan_refsols/multi_partition_access_1.txt
new file mode 100644
index 00000000..30b49da7
--- /dev/null
+++ b/tests/test_plan_refsols/multi_partition_access_1.txt
@@ -0,0 +1,22 @@
+ROOT(columns=[('symbol', symbol)], orderings=[(ordering_0):asc_first])
+ JOIN(conditions=[t0.currency_5 == t1.currency & t0.exchange_7 == t1.exchange & t0.ticker_type_8 == t1.ticker_type], types=['inner'], columns={'ordering_0': t1.ordering_0, 'symbol': t1.symbol})
+ JOIN(conditions=[t0.currency_1 == t1.currency & t0.exchange_3 == t1.exchange], types=['inner'], columns={'currency_5': t1.currency, 'exchange_7': t1.exchange, 'ticker_type_8': t1.ticker_type})
+ JOIN(conditions=[t0.exchange == t1.exchange], types=['inner'], columns={'currency_1': t1.currency, 'exchange_3': t1.exchange})
+ AGGREGATE(keys={'exchange': exchange}, aggregations={})
+ AGGREGATE(keys={'currency': currency, 'exchange': exchange}, aggregations={})
+ AGGREGATE(keys={'currency': currency, 'exchange': exchange, 'ticker_type': ticker_type}, aggregations={})
+ LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'currency': currency, 'exchange': exchange, 'ticker_type': ticker_type}, orderings=[(ordering_0):asc_first])
+ PROJECT(columns={'currency': currency, 'exchange': exchange, 'ordering_0': symbol, 'ticker_type': ticker_type})
+ SCAN(table=main.sbTicker, columns={'currency': sbTickerCurrency, 'exchange': sbTickerExchange, 'symbol': sbTickerSymbol, 'ticker_type': sbTickerType})
+ AGGREGATE(keys={'currency': currency, 'exchange': exchange}, aggregations={})
+ AGGREGATE(keys={'currency': currency, 'exchange': exchange, 'ticker_type': ticker_type}, aggregations={})
+ LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'currency': currency, 'exchange': exchange, 'ticker_type': ticker_type}, orderings=[(ordering_0):asc_first])
+ PROJECT(columns={'currency': currency, 'exchange': exchange, 'ordering_0': symbol, 'ticker_type': ticker_type})
+ SCAN(table=main.sbTicker, columns={'currency': sbTickerCurrency, 'exchange': sbTickerExchange, 'symbol': sbTickerSymbol, 'ticker_type': sbTickerType})
+ AGGREGATE(keys={'currency': currency, 'exchange': exchange, 'ticker_type': ticker_type}, aggregations={})
+ LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'currency': currency, 'exchange': exchange, 'ticker_type': ticker_type}, orderings=[(ordering_0):asc_first])
+ PROJECT(columns={'currency': currency, 'exchange': exchange, 'ordering_0': symbol, 'ticker_type': ticker_type})
+ SCAN(table=main.sbTicker, columns={'currency': sbTickerCurrency, 'exchange': sbTickerExchange, 'symbol': sbTickerSymbol, 'ticker_type': sbTickerType})
+ LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'currency': currency, 'exchange': exchange, 'ordering_0': ordering_0, 'symbol': symbol, 'ticker_type': ticker_type}, orderings=[(ordering_0):asc_first])
+ PROJECT(columns={'currency': currency, 'exchange': exchange, 'ordering_0': symbol, 'symbol': symbol, 'ticker_type': ticker_type})
+ SCAN(table=main.sbTicker, columns={'currency': sbTickerCurrency, 'exchange': sbTickerExchange, 'symbol': sbTickerSymbol, 'ticker_type': sbTickerType})
diff --git a/tests/test_plan_refsols/multi_partition_access_2.txt b/tests/test_plan_refsols/multi_partition_access_2.txt
new file mode 100644
index 00000000..aa5d5ca1
--- /dev/null
+++ b/tests/test_plan_refsols/multi_partition_access_2.txt
@@ -0,0 +1,52 @@
+ROOT(columns=[('transaction_id', transaction_id), ('name', name), ('symbol', symbol), ('transaction_type', transaction_type), ('cus_tick_typ_avg_shares', cus_tick_typ_avg_shares), ('cust_tick_avg_shares', cust_tick_avg_shares), ('cust_avg_shares', cust_avg_shares)], orderings=[(ordering_1):asc_first])
+ PROJECT(columns={'cus_tick_typ_avg_shares': cus_tick_typ_avg_shares, 'cust_avg_shares': cust_avg_shares, 'cust_tick_avg_shares': cust_tick_avg_shares, 'name': name, 'ordering_1': transaction_id, 'symbol': symbol, 'transaction_id': transaction_id, 'transaction_type': transaction_type})
+ PROJECT(columns={'cus_tick_typ_avg_shares': cus_tick_typ_avg_shares, 'cust_avg_shares': cust_avg_shares, 'cust_tick_avg_shares': cust_tick_avg_shares, 'name': name, 'symbol': symbol, 'transaction_id': transaction_id, 'transaction_type': transaction_type_29})
+ JOIN(conditions=[t0.ticker_id == t1._id], types=['left'], columns={'cus_tick_typ_avg_shares': t0.cus_tick_typ_avg_shares, 'cust_avg_shares': t0.cust_avg_shares, 'cust_tick_avg_shares': t0.cust_tick_avg_shares, 'name': t0.name, 'symbol': t1.symbol, 'transaction_id': t0.transaction_id, 'transaction_type_29': t0.transaction_type_29})
+ JOIN(conditions=[t0.customer_id_28 == t1._id], types=['left'], columns={'cus_tick_typ_avg_shares': t0.cus_tick_typ_avg_shares, 'cust_avg_shares': t0.cust_avg_shares, 'cust_tick_avg_shares': t0.cust_tick_avg_shares, 'name': t1.name, 'ticker_id': t0.ticker_id, 'transaction_id': t0.transaction_id, 'transaction_type_29': t0.transaction_type_29})
+ FILTER(condition=shares < cus_tick_typ_avg_shares & shares < cust_tick_avg_shares & shares < cust_avg_shares, columns={'cus_tick_typ_avg_shares': cus_tick_typ_avg_shares, 'cust_avg_shares': cust_avg_shares, 'cust_tick_avg_shares': cust_tick_avg_shares, 'customer_id_28': customer_id_28, 'ticker_id': ticker_id, 'transaction_id': transaction_id, 'transaction_type_29': transaction_type_29})
+ JOIN(conditions=[t0.customer_id_24 == t1.customer_id & t0.ticker_id_26 == t1.ticker_id & t0.transaction_type_27 == t1.transaction_type], types=['inner'], columns={'cus_tick_typ_avg_shares': t0.cus_tick_typ_avg_shares, 'cust_avg_shares': t0.cust_avg_shares, 'cust_tick_avg_shares': t0.cust_tick_avg_shares, 'customer_id_28': t1.customer_id, 'shares': t1.shares, 'ticker_id': t1.ticker_id, 'transaction_id': t1.transaction_id, 'transaction_type_29': t1.transaction_type})
+ JOIN(conditions=[t0.customer_id_21 == t1.customer_id & t0.ticker_id_22 == t1.ticker_id], types=['inner'], columns={'cus_tick_typ_avg_shares': t1.cus_tick_typ_avg_shares, 'cust_avg_shares': t0.cust_avg_shares, 'cust_tick_avg_shares': t0.cust_tick_avg_shares, 'customer_id_24': t1.customer_id, 'ticker_id_26': t1.ticker_id, 'transaction_type_27': t1.transaction_type})
+ JOIN(conditions=[t0.customer_id == t1.customer_id], types=['inner'], columns={'cust_avg_shares': t0.cust_avg_shares, 'cust_tick_avg_shares': t1.cust_tick_avg_shares, 'customer_id_21': t1.customer_id, 'ticker_id_22': t1.ticker_id})
+ PROJECT(columns={'cust_avg_shares': agg_0, 'customer_id': customer_id})
+ JOIN(conditions=[t0.customer_id == t1.customer_id], types=['left'], columns={'agg_0': t1.agg_0, 'customer_id': t0.customer_id})
+ AGGREGATE(keys={'customer_id': customer_id}, aggregations={})
+ JOIN(conditions=[t0.customer_id == t1.customer_id & t0.ticker_id == t1.ticker_id], types=['left'], columns={'customer_id': t0.customer_id})
+ AGGREGATE(keys={'customer_id': customer_id, 'ticker_id': ticker_id}, aggregations={})
+ AGGREGATE(keys={'customer_id': customer_id, 'ticker_id': ticker_id, 'transaction_type': transaction_type}, aggregations={})
+ SCAN(table=main.sbTransaction, columns={'customer_id': sbTxCustId, 'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ AGGREGATE(keys={'customer_id': customer_id, 'ticker_id': ticker_id}, aggregations={})
+ JOIN(conditions=[t0.customer_id == t1.customer_id & t0.ticker_id == t1.ticker_id & t0.transaction_type == t1.transaction_type], types=['inner'], columns={'customer_id': t0.customer_id, 'ticker_id': t0.ticker_id})
+ AGGREGATE(keys={'customer_id': customer_id, 'ticker_id': ticker_id, 'transaction_type': transaction_type}, aggregations={})
+ SCAN(table=main.sbTransaction, columns={'customer_id': sbTxCustId, 'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ SCAN(table=main.sbTransaction, columns={'customer_id': sbTxCustId, 'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ AGGREGATE(keys={'customer_id': customer_id}, aggregations={'agg_0': AVG(shares)})
+ JOIN(conditions=[t0.customer_id_9 == t1.customer_id & t0.ticker_id_11 == t1.ticker_id & t0.transaction_type_12 == t1.transaction_type], types=['inner'], columns={'customer_id': t0.customer_id, 'shares': t1.shares})
+ JOIN(conditions=[t0.customer_id == t1.customer_id & t0.ticker_id == t1.ticker_id], types=['inner'], columns={'customer_id': t0.customer_id, 'customer_id_9': t1.customer_id, 'ticker_id_11': t1.ticker_id, 'transaction_type_12': t1.transaction_type})
+ JOIN(conditions=[t0.customer_id == t1.customer_id & t0.ticker_id == t1.ticker_id], types=['left'], columns={'customer_id': t0.customer_id, 'ticker_id': t0.ticker_id})
+ AGGREGATE(keys={'customer_id': customer_id, 'ticker_id': ticker_id}, aggregations={})
+ AGGREGATE(keys={'customer_id': customer_id, 'ticker_id': ticker_id, 'transaction_type': transaction_type}, aggregations={})
+ SCAN(table=main.sbTransaction, columns={'customer_id': sbTxCustId, 'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ AGGREGATE(keys={'customer_id': customer_id, 'ticker_id': ticker_id}, aggregations={})
+ JOIN(conditions=[t0.customer_id == t1.customer_id & t0.ticker_id == t1.ticker_id & t0.transaction_type == t1.transaction_type], types=['inner'], columns={'customer_id': t0.customer_id, 'ticker_id': t0.ticker_id})
+ AGGREGATE(keys={'customer_id': customer_id, 'ticker_id': ticker_id, 'transaction_type': transaction_type}, aggregations={})
+ SCAN(table=main.sbTransaction, columns={'customer_id': sbTxCustId, 'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ SCAN(table=main.sbTransaction, columns={'customer_id': sbTxCustId, 'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ AGGREGATE(keys={'customer_id': customer_id, 'ticker_id': ticker_id, 'transaction_type': transaction_type}, aggregations={})
+ SCAN(table=main.sbTransaction, columns={'customer_id': sbTxCustId, 'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ SCAN(table=main.sbTransaction, columns={'customer_id': sbTxCustId, 'shares': sbTxShares, 'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ PROJECT(columns={'cust_tick_avg_shares': agg_0, 'customer_id': customer_id, 'ticker_id': ticker_id})
+ JOIN(conditions=[t0.customer_id == t1.customer_id & t0.ticker_id == t1.ticker_id], types=['left'], columns={'agg_0': t1.agg_0, 'customer_id': t0.customer_id, 'ticker_id': t0.ticker_id})
+ AGGREGATE(keys={'customer_id': customer_id, 'ticker_id': ticker_id}, aggregations={})
+ AGGREGATE(keys={'customer_id': customer_id, 'ticker_id': ticker_id, 'transaction_type': transaction_type}, aggregations={})
+ SCAN(table=main.sbTransaction, columns={'customer_id': sbTxCustId, 'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ AGGREGATE(keys={'customer_id': customer_id, 'ticker_id': ticker_id}, aggregations={'agg_0': AVG(shares)})
+ JOIN(conditions=[t0.customer_id == t1.customer_id & t0.ticker_id == t1.ticker_id & t0.transaction_type == t1.transaction_type], types=['inner'], columns={'customer_id': t0.customer_id, 'shares': t1.shares, 'ticker_id': t0.ticker_id})
+ AGGREGATE(keys={'customer_id': customer_id, 'ticker_id': ticker_id, 'transaction_type': transaction_type}, aggregations={})
+ SCAN(table=main.sbTransaction, columns={'customer_id': sbTxCustId, 'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ SCAN(table=main.sbTransaction, columns={'customer_id': sbTxCustId, 'shares': sbTxShares, 'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ PROJECT(columns={'cus_tick_typ_avg_shares': agg_0, 'customer_id': customer_id, 'ticker_id': ticker_id, 'transaction_type': transaction_type})
+ AGGREGATE(keys={'customer_id': customer_id, 'ticker_id': ticker_id, 'transaction_type': transaction_type}, aggregations={'agg_0': AVG(shares)})
+ SCAN(table=main.sbTransaction, columns={'customer_id': sbTxCustId, 'shares': sbTxShares, 'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ SCAN(table=main.sbTransaction, columns={'customer_id': sbTxCustId, 'shares': sbTxShares, 'ticker_id': sbTxTickerId, 'transaction_id': sbTxId, 'transaction_type': sbTxType})
+ SCAN(table=main.sbCustomer, columns={'_id': sbCustId, 'name': sbCustName})
+ SCAN(table=main.sbTicker, columns={'_id': sbTickerId, 'symbol': sbTickerSymbol})
diff --git a/tests/test_plan_refsols/multi_partition_access_3.txt b/tests/test_plan_refsols/multi_partition_access_3.txt
new file mode 100644
index 00000000..c9020854
--- /dev/null
+++ b/tests/test_plan_refsols/multi_partition_access_3.txt
@@ -0,0 +1,25 @@
+ROOT(columns=[('symbol', symbol_10), ('close', close)], orderings=[(ordering_1):asc_first])
+ PROJECT(columns={'close': close, 'ordering_1': ordering_1, 'symbol_10': symbol_9})
+ PROJECT(columns={'close': close, 'ordering_1': symbol_9, 'symbol_9': symbol_9})
+ PROJECT(columns={'close': close, 'symbol_9': symbol_4})
+ FILTER(condition=close == ticker_high_price_7 & close < type_high_price, columns={'close': close, 'symbol_4': symbol_4})
+ JOIN(conditions=[t0.ticker_type_3 == t1.ticker_type_6], types=['inner'], columns={'close': t1.close, 'symbol_4': t1.symbol_4, 'ticker_high_price_7': t1.ticker_high_price, 'type_high_price': t0.type_high_price})
+ PROJECT(columns={'ticker_type_3': ticker_type_3, 'type_high_price': agg_0})
+ AGGREGATE(keys={'ticker_type_3': ticker_type_3}, aggregations={'agg_0': MAX(close)})
+ JOIN(conditions=[t0.ticker_id == t1.ticker_id], types=['inner'], columns={'close': t1.close, 'ticker_type_3': t1.ticker_type})
+ AGGREGATE(keys={'ticker_id': ticker_id}, aggregations={})
+ JOIN(conditions=[t0._id == t1.ticker_id], types=['inner'], columns={'ticker_id': t1.ticker_id})
+ SCAN(table=main.sbTicker, columns={'_id': sbTickerId})
+ SCAN(table=main.sbDailyPrice, columns={'ticker_id': sbDpTickerId})
+ JOIN(conditions=[t0._id == t1.ticker_id], types=['inner'], columns={'close': t1.close, 'ticker_id': t1.ticker_id, 'ticker_type': t0.ticker_type})
+ SCAN(table=main.sbTicker, columns={'_id': sbTickerId, 'ticker_type': sbTickerType})
+ SCAN(table=main.sbDailyPrice, columns={'close': sbDpClose, 'ticker_id': sbDpTickerId})
+ JOIN(conditions=[t0.ticker_id == t1.ticker_id], types=['inner'], columns={'close': t1.close, 'symbol_4': t1.symbol, 'ticker_high_price': t0.ticker_high_price, 'ticker_type_6': t1.ticker_type})
+ PROJECT(columns={'ticker_high_price': agg_0, 'ticker_id': ticker_id})
+ AGGREGATE(keys={'ticker_id': ticker_id}, aggregations={'agg_0': MAX(close)})
+ JOIN(conditions=[t0._id == t1.ticker_id], types=['inner'], columns={'close': t1.close, 'ticker_id': t1.ticker_id})
+ SCAN(table=main.sbTicker, columns={'_id': sbTickerId})
+ SCAN(table=main.sbDailyPrice, columns={'close': sbDpClose, 'ticker_id': sbDpTickerId})
+ JOIN(conditions=[t0._id == t1.ticker_id], types=['inner'], columns={'close': t1.close, 'symbol': t0.symbol, 'ticker_id': t1.ticker_id, 'ticker_type': t0.ticker_type})
+ SCAN(table=main.sbTicker, columns={'_id': sbTickerId, 'symbol': sbTickerSymbol, 'ticker_type': sbTickerType})
+ SCAN(table=main.sbDailyPrice, columns={'close': sbDpClose, 'ticker_id': sbDpTickerId})
diff --git a/tests/test_plan_refsols/multi_partition_access_4.txt b/tests/test_plan_refsols/multi_partition_access_4.txt
new file mode 100644
index 00000000..6d3f10c0
--- /dev/null
+++ b/tests/test_plan_refsols/multi_partition_access_4.txt
@@ -0,0 +1,14 @@
+ROOT(columns=[('transaction_id', transaction_id)], orderings=[(ordering_1):asc_first])
+ PROJECT(columns={'ordering_1': transaction_id, 'transaction_id': transaction_id})
+ FILTER(condition=shares >= cust_ticker_max_shares & shares < cust_max_shares, columns={'transaction_id': transaction_id})
+ JOIN(conditions=[t0.customer_id_3 == t1.customer_id & t0.ticker_id_4 == t1.ticker_id], types=['inner'], columns={'cust_max_shares': t0.cust_max_shares, 'cust_ticker_max_shares': t0.cust_ticker_max_shares, 'shares': t1.shares, 'transaction_id': t1.transaction_id})
+ JOIN(conditions=[t0.customer_id == t1.customer_id], types=['inner'], columns={'cust_max_shares': t0.cust_max_shares, 'cust_ticker_max_shares': t1.cust_ticker_max_shares, 'customer_id_3': t1.customer_id, 'ticker_id_4': t1.ticker_id})
+ PROJECT(columns={'cust_max_shares': agg_0, 'customer_id': customer_id})
+ AGGREGATE(keys={'customer_id': customer_id}, aggregations={'agg_0': MAX(cust_ticker_max_shares)})
+ PROJECT(columns={'cust_ticker_max_shares': agg_0, 'customer_id': customer_id})
+ AGGREGATE(keys={'customer_id': customer_id, 'ticker_id': ticker_id}, aggregations={'agg_0': MAX(shares)})
+ SCAN(table=main.sbTransaction, columns={'customer_id': sbTxCustId, 'shares': sbTxShares, 'ticker_id': sbTxTickerId})
+ PROJECT(columns={'cust_ticker_max_shares': agg_0, 'customer_id': customer_id, 'ticker_id': ticker_id})
+ AGGREGATE(keys={'customer_id': customer_id, 'ticker_id': ticker_id}, aggregations={'agg_0': MAX(shares)})
+ SCAN(table=main.sbTransaction, columns={'customer_id': sbTxCustId, 'shares': sbTxShares, 'ticker_id': sbTxTickerId})
+ SCAN(table=main.sbTransaction, columns={'customer_id': sbTxCustId, 'shares': sbTxShares, 'ticker_id': sbTxTickerId, 'transaction_id': sbTxId})
diff --git a/tests/test_plan_refsols/multi_partition_access_5.txt b/tests/test_plan_refsols/multi_partition_access_5.txt
new file mode 100644
index 00000000..c3192bb9
--- /dev/null
+++ b/tests/test_plan_refsols/multi_partition_access_5.txt
@@ -0,0 +1,24 @@
+ROOT(columns=[('transaction_id', transaction_id), ('n_ticker_type_trans', n_ticker_type_trans), ('n_ticker_trans', n_ticker_trans), ('n_type_trans', n_type_trans)], orderings=[(ordering_1):asc_first, (ordering_2):asc_first])
+ PROJECT(columns={'n_ticker_trans': n_ticker_trans, 'n_ticker_type_trans': n_ticker_type_trans, 'n_type_trans': n_type_trans, 'ordering_1': n_ticker_type_trans, 'ordering_2': transaction_id, 'transaction_id': transaction_id})
+ FILTER(condition=n_ticker_type_trans / n_ticker_trans > 0.8:float64 & n_ticker_type_trans / n_type_trans < 0.2:float64, columns={'n_ticker_trans': n_ticker_trans, 'n_ticker_type_trans': n_ticker_type_trans, 'n_type_trans': n_type_trans, 'transaction_id': transaction_id})
+ JOIN(conditions=[t0.ticker_id_7 == t1.ticker_id & t0.transaction_type_8 == t1.transaction_type], types=['inner'], columns={'n_ticker_trans': t0.n_ticker_trans, 'n_ticker_type_trans': t0.n_ticker_type_trans, 'n_type_trans': t0.n_type_trans, 'transaction_id': t1.transaction_id})
+ JOIN(conditions=[t0.transaction_type_4 == t1.transaction_type_8], types=['inner'], columns={'n_ticker_trans': t1.n_ticker_trans, 'n_ticker_type_trans': t1.n_ticker_type_trans, 'n_type_trans': t0.n_type_trans, 'ticker_id_7': t1.ticker_id_7, 'transaction_type_8': t1.transaction_type_8})
+ PROJECT(columns={'n_type_trans': DEFAULT_TO(agg_0, 0:int64), 'transaction_type_4': transaction_type_4})
+ AGGREGATE(keys={'transaction_type_4': transaction_type_4}, aggregations={'agg_0': SUM(n_ticker_type_trans)})
+ JOIN(conditions=[t0.ticker_id == t1.ticker_id], types=['inner'], columns={'n_ticker_type_trans': t1.n_ticker_type_trans, 'transaction_type_4': t1.transaction_type})
+ AGGREGATE(keys={'ticker_id': ticker_id}, aggregations={})
+ AGGREGATE(keys={'ticker_id': ticker_id, 'transaction_type': transaction_type}, aggregations={})
+ SCAN(table=main.sbTransaction, columns={'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ PROJECT(columns={'n_ticker_type_trans': DEFAULT_TO(agg_0, 0:int64), 'ticker_id': ticker_id, 'transaction_type': transaction_type})
+ AGGREGATE(keys={'ticker_id': ticker_id, 'transaction_type': transaction_type}, aggregations={'agg_0': COUNT()})
+ SCAN(table=main.sbTransaction, columns={'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ JOIN(conditions=[t0.ticker_id == t1.ticker_id], types=['inner'], columns={'n_ticker_trans': t0.n_ticker_trans, 'n_ticker_type_trans': t1.n_ticker_type_trans, 'ticker_id_7': t1.ticker_id, 'transaction_type_8': t1.transaction_type})
+ PROJECT(columns={'n_ticker_trans': DEFAULT_TO(agg_0, 0:int64), 'ticker_id': ticker_id})
+ AGGREGATE(keys={'ticker_id': ticker_id}, aggregations={'agg_0': SUM(n_ticker_type_trans)})
+ PROJECT(columns={'n_ticker_type_trans': DEFAULT_TO(agg_0, 0:int64), 'ticker_id': ticker_id})
+ AGGREGATE(keys={'ticker_id': ticker_id, 'transaction_type': transaction_type}, aggregations={'agg_0': COUNT()})
+ SCAN(table=main.sbTransaction, columns={'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ PROJECT(columns={'n_ticker_type_trans': DEFAULT_TO(agg_0, 0:int64), 'ticker_id': ticker_id, 'transaction_type': transaction_type})
+ AGGREGATE(keys={'ticker_id': ticker_id, 'transaction_type': transaction_type}, aggregations={'agg_0': COUNT()})
+ SCAN(table=main.sbTransaction, columns={'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ SCAN(table=main.sbTransaction, columns={'ticker_id': sbTxTickerId, 'transaction_id': sbTxId, 'transaction_type': sbTxType})
diff --git a/tests/test_plan_refsols/multi_partition_access_6.txt b/tests/test_plan_refsols/multi_partition_access_6.txt
new file mode 100644
index 00000000..e5ab9cf4
--- /dev/null
+++ b/tests/test_plan_refsols/multi_partition_access_6.txt
@@ -0,0 +1,64 @@
+ROOT(columns=[('transaction_id', transaction_id)], orderings=[(ordering_1):asc_first])
+ PROJECT(columns={'ordering_1': transaction_id, 'transaction_id': transaction_id})
+ FILTER(condition=n_ticker_type_trans == 1:int64 | n_cust_type_trans == 1:int64 & n_cust_trans > 1:int64 & n_type_trans_39 > 1:int64 & n_ticker_trans > 1:int64, columns={'transaction_id': transaction_id})
+ JOIN(conditions=[t0.customer_id_25 == t1.customer_id & t0.transaction_type_22_26 == t1.transaction_type_37], types=['inner'], columns={'n_cust_trans': t0.n_cust_trans, 'n_cust_type_trans': t0.n_cust_type_trans, 'n_ticker_trans': t1.n_ticker_trans, 'n_ticker_type_trans': t1.n_ticker_type_trans, 'n_type_trans_39': t1.n_type_trans, 'transaction_id': t1.transaction_id})
+ JOIN(conditions=[t0.customer_id == t1.customer_id], types=['inner'], columns={'customer_id_25': t1.customer_id, 'n_cust_trans': t0.n_cust_trans, 'n_cust_type_trans': t1.n_cust_type_trans, 'transaction_type_22_26': t1.transaction_type_22})
+ PROJECT(columns={'customer_id': customer_id, 'n_cust_trans': DEFAULT_TO(agg_0, 0:int64)})
+ AGGREGATE(keys={'customer_id': customer_id}, aggregations={'agg_0': SUM(n_cust_type_trans)})
+ PROJECT(columns={'customer_id': customer_id, 'n_cust_type_trans': DEFAULT_TO(agg_0, 0:int64)})
+ AGGREGATE(keys={'customer_id': customer_id, 'transaction_type_11': transaction_type_11}, aggregations={'agg_0': COUNT()})
+ JOIN(conditions=[t0.ticker_id_7 == t1.ticker_id & t0.transaction_type_8 == t1.transaction_type], types=['inner'], columns={'customer_id': t1.customer_id, 'transaction_type_11': t1.transaction_type})
+ JOIN(conditions=[t0.transaction_type_4 == t1.transaction_type_8], types=['inner'], columns={'ticker_id_7': t1.ticker_id_7, 'transaction_type_8': t1.transaction_type_8})
+ AGGREGATE(keys={'transaction_type_4': transaction_type_4}, aggregations={})
+ JOIN(conditions=[t0.ticker_id == t1.ticker_id], types=['inner'], columns={'transaction_type_4': t1.transaction_type})
+ AGGREGATE(keys={'ticker_id': ticker_id}, aggregations={})
+ AGGREGATE(keys={'ticker_id': ticker_id, 'transaction_type': transaction_type}, aggregations={})
+ SCAN(table=main.sbTransaction, columns={'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ AGGREGATE(keys={'ticker_id': ticker_id, 'transaction_type': transaction_type}, aggregations={})
+ SCAN(table=main.sbTransaction, columns={'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ JOIN(conditions=[t0.ticker_id == t1.ticker_id], types=['inner'], columns={'ticker_id_7': t1.ticker_id, 'transaction_type_8': t1.transaction_type})
+ AGGREGATE(keys={'ticker_id': ticker_id}, aggregations={})
+ AGGREGATE(keys={'ticker_id': ticker_id, 'transaction_type': transaction_type}, aggregations={})
+ SCAN(table=main.sbTransaction, columns={'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ AGGREGATE(keys={'ticker_id': ticker_id, 'transaction_type': transaction_type}, aggregations={})
+ SCAN(table=main.sbTransaction, columns={'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ SCAN(table=main.sbTransaction, columns={'customer_id': sbTxCustId, 'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ PROJECT(columns={'customer_id': customer_id, 'n_cust_type_trans': DEFAULT_TO(agg_0, 0:int64), 'transaction_type_22': transaction_type_22})
+ AGGREGATE(keys={'customer_id': customer_id, 'transaction_type_22': transaction_type_22}, aggregations={'agg_0': COUNT()})
+ JOIN(conditions=[t0.ticker_id_18 == t1.ticker_id & t0.transaction_type_19 == t1.transaction_type], types=['inner'], columns={'customer_id': t1.customer_id, 'transaction_type_22': t1.transaction_type})
+ JOIN(conditions=[t0.transaction_type_15 == t1.transaction_type_19], types=['inner'], columns={'ticker_id_18': t1.ticker_id_18, 'transaction_type_19': t1.transaction_type_19})
+ AGGREGATE(keys={'transaction_type_15': transaction_type_15}, aggregations={})
+ JOIN(conditions=[t0.ticker_id == t1.ticker_id], types=['inner'], columns={'transaction_type_15': t1.transaction_type})
+ AGGREGATE(keys={'ticker_id': ticker_id}, aggregations={})
+ AGGREGATE(keys={'ticker_id': ticker_id, 'transaction_type': transaction_type}, aggregations={})
+ SCAN(table=main.sbTransaction, columns={'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ AGGREGATE(keys={'ticker_id': ticker_id, 'transaction_type': transaction_type}, aggregations={})
+ SCAN(table=main.sbTransaction, columns={'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ JOIN(conditions=[t0.ticker_id == t1.ticker_id], types=['inner'], columns={'ticker_id_18': t1.ticker_id, 'transaction_type_19': t1.transaction_type})
+ AGGREGATE(keys={'ticker_id': ticker_id}, aggregations={})
+ AGGREGATE(keys={'ticker_id': ticker_id, 'transaction_type': transaction_type}, aggregations={})
+ SCAN(table=main.sbTransaction, columns={'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ AGGREGATE(keys={'ticker_id': ticker_id, 'transaction_type': transaction_type}, aggregations={})
+ SCAN(table=main.sbTransaction, columns={'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ SCAN(table=main.sbTransaction, columns={'customer_id': sbTxCustId, 'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ JOIN(conditions=[t0.ticker_id_33 == t1.ticker_id & t0.transaction_type_34 == t1.transaction_type], types=['inner'], columns={'customer_id': t1.customer_id, 'n_ticker_trans': t0.n_ticker_trans, 'n_ticker_type_trans': t0.n_ticker_type_trans, 'n_type_trans': t0.n_type_trans, 'transaction_id': t1.transaction_id, 'transaction_type_37': t1.transaction_type})
+ JOIN(conditions=[t0.transaction_type_30 == t1.transaction_type_34], types=['inner'], columns={'n_ticker_trans': t1.n_ticker_trans, 'n_ticker_type_trans': t1.n_ticker_type_trans, 'n_type_trans': t0.n_type_trans, 'ticker_id_33': t1.ticker_id_33, 'transaction_type_34': t1.transaction_type_34})
+ PROJECT(columns={'n_type_trans': DEFAULT_TO(agg_0, 0:int64), 'transaction_type_30': transaction_type_30})
+ AGGREGATE(keys={'transaction_type_30': transaction_type_30}, aggregations={'agg_0': SUM(n_ticker_type_trans)})
+ JOIN(conditions=[t0.ticker_id == t1.ticker_id], types=['inner'], columns={'n_ticker_type_trans': t1.n_ticker_type_trans, 'transaction_type_30': t1.transaction_type})
+ AGGREGATE(keys={'ticker_id': ticker_id}, aggregations={})
+ AGGREGATE(keys={'ticker_id': ticker_id, 'transaction_type': transaction_type}, aggregations={})
+ SCAN(table=main.sbTransaction, columns={'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ PROJECT(columns={'n_ticker_type_trans': DEFAULT_TO(agg_0, 0:int64), 'ticker_id': ticker_id, 'transaction_type': transaction_type})
+ AGGREGATE(keys={'ticker_id': ticker_id, 'transaction_type': transaction_type}, aggregations={'agg_0': COUNT()})
+ SCAN(table=main.sbTransaction, columns={'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ JOIN(conditions=[t0.ticker_id == t1.ticker_id], types=['inner'], columns={'n_ticker_trans': t0.n_ticker_trans, 'n_ticker_type_trans': t1.n_ticker_type_trans, 'ticker_id_33': t1.ticker_id, 'transaction_type_34': t1.transaction_type})
+ PROJECT(columns={'n_ticker_trans': DEFAULT_TO(agg_0, 0:int64), 'ticker_id': ticker_id})
+ AGGREGATE(keys={'ticker_id': ticker_id}, aggregations={'agg_0': SUM(n_ticker_type_trans)})
+ PROJECT(columns={'n_ticker_type_trans': DEFAULT_TO(agg_0, 0:int64), 'ticker_id': ticker_id})
+ AGGREGATE(keys={'ticker_id': ticker_id, 'transaction_type': transaction_type}, aggregations={'agg_0': COUNT()})
+ SCAN(table=main.sbTransaction, columns={'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ PROJECT(columns={'n_ticker_type_trans': DEFAULT_TO(agg_0, 0:int64), 'ticker_id': ticker_id, 'transaction_type': transaction_type})
+ AGGREGATE(keys={'ticker_id': ticker_id, 'transaction_type': transaction_type}, aggregations={'agg_0': COUNT()})
+ SCAN(table=main.sbTransaction, columns={'ticker_id': sbTxTickerId, 'transaction_type': sbTxType})
+ SCAN(table=main.sbTransaction, columns={'customer_id': sbTxCustId, 'ticker_id': sbTxTickerId, 'transaction_id': sbTxId, 'transaction_type': sbTxType})
diff --git a/tests/test_plan_refsols/multiple_simple_aggregations_multiple_calcs.txt b/tests/test_plan_refsols/multiple_simple_aggregations_multiple_calcs.txt
index 8431aba3..4b41ffc1 100644
--- a/tests/test_plan_refsols/multiple_simple_aggregations_multiple_calcs.txt
+++ b/tests/test_plan_refsols/multiple_simple_aggregations_multiple_calcs.txt
@@ -1,8 +1,8 @@
-ROOT(columns=[('nation_name', nation_name_0), ('total_consumer_value', total_consumer_value), ('total_supplier_value', total_supplier_value), ('avg_consumer_value', avg_consumer_value), ('avg_supplier_value', avg_supplier_value), ('best_consumer_value', best_consumer_value), ('best_supplier_value', best_supplier_value)], orderings=[])
- PROJECT(columns={'avg_consumer_value': avg_consumer_value, 'avg_supplier_value': avg_supplier_value, 'best_consumer_value': agg_4, 'best_supplier_value': agg_5, 'nation_name_0': key, 'total_consumer_value': total_consumer_value, 'total_supplier_value': total_supplier_value})
- PROJECT(columns={'agg_4': agg_4, 'agg_5': agg_5, 'avg_consumer_value': avg_consumer_value, 'avg_supplier_value': agg_2, 'key': key, 'total_consumer_value': total_consumer_value, 'total_supplier_value': DEFAULT_TO(agg_3, 0:int64)})
- JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_2': t1.agg_2, 'agg_3': t1.agg_3, 'agg_4': t0.agg_4, 'agg_5': t1.agg_5, 'avg_consumer_value': t0.avg_consumer_value, 'key': t0.key, 'total_consumer_value': t0.total_consumer_value})
- PROJECT(columns={'agg_4': agg_4, 'avg_consumer_value': agg_0, 'key': key, 'total_consumer_value': DEFAULT_TO(agg_1, 0:int64)})
+ROOT(columns=[('nation_name', nation_name), ('total_consumer_value', total_consumer_value), ('total_supplier_value', total_supplier_value), ('avg_consumer_value', avg_consumer_value), ('avg_supplier_value', avg_supplier_value), ('best_consumer_value', best_consumer_value), ('best_supplier_value', best_supplier_value)], orderings=[])
+ PROJECT(columns={'avg_consumer_value': avg_consumer_value_a, 'avg_supplier_value': avg_supplier_value_a, 'best_consumer_value': agg_4, 'best_supplier_value': agg_5, 'nation_name': key, 'total_consumer_value': total_consumer_value_a, 'total_supplier_value': total_supplier_value_a})
+ PROJECT(columns={'agg_4': agg_4, 'agg_5': agg_5, 'avg_consumer_value_a': avg_consumer_value_a, 'avg_supplier_value_a': agg_2, 'key': key, 'total_consumer_value_a': total_consumer_value_a, 'total_supplier_value_a': DEFAULT_TO(agg_3, 0:int64)})
+ JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_2': t1.agg_2, 'agg_3': t1.agg_3, 'agg_4': t0.agg_4, 'agg_5': t1.agg_5, 'avg_consumer_value_a': t0.avg_consumer_value_a, 'key': t0.key, 'total_consumer_value_a': t0.total_consumer_value_a})
+ PROJECT(columns={'agg_4': agg_4, 'avg_consumer_value_a': agg_0, 'key': key, 'total_consumer_value_a': DEFAULT_TO(agg_1, 0:int64)})
JOIN(conditions=[t0.key == t1.nation_key], types=['left'], columns={'agg_0': t1.agg_0, 'agg_1': t1.agg_1, 'agg_4': t1.agg_4, 'key': t0.key})
SCAN(table=tpch.NATION, columns={'key': n_nationkey})
AGGREGATE(keys={'nation_key': nation_key}, aggregations={'agg_0': AVG(acctbal), 'agg_1': SUM(acctbal), 'agg_4': MAX(acctbal)})
diff --git a/tests/test_plan_refsols/percentile_nations.txt b/tests/test_plan_refsols/percentile_nations.txt
index 19a91388..4eb536a2 100644
--- a/tests/test_plan_refsols/percentile_nations.txt
+++ b/tests/test_plan_refsols/percentile_nations.txt
@@ -1,3 +1,3 @@
-ROOT(columns=[('name', name), ('p', p)], orderings=[])
+ROOT(columns=[('name', name), ('p1', p), ('p2', p)], orderings=[])
PROJECT(columns={'name': name, 'p': PERCENTILE(args=[], partition=[], order=[(name):asc_last], n_buckets=5)})
SCAN(table=tpch.NATION, columns={'name': n_name})
diff --git a/tests/test_plan_refsols/rank_parts_per_supplier_region_by_size.txt b/tests/test_plan_refsols/rank_parts_per_supplier_region_by_size.txt
index f189b66e..e6e0780a 100644
--- a/tests/test_plan_refsols/rank_parts_per_supplier_region_by_size.txt
+++ b/tests/test_plan_refsols/rank_parts_per_supplier_region_by_size.txt
@@ -2,12 +2,13 @@ ROOT(columns=[('key', key_13), ('region', region), ('rank', rank)], orderings=[(
PROJECT(columns={'key_13': key_12, 'ordering_0': ordering_0, 'rank': rank, 'region': region})
LIMIT(limit=Literal(value=15, type=Int64Type()), columns={'key_12': key_12, 'ordering_0': ordering_0, 'rank': rank, 'region': region}, orderings=[(ordering_0):asc_first])
PROJECT(columns={'key_12': key_12, 'ordering_0': key_12, 'rank': rank, 'region': region})
- PROJECT(columns={'key_12': key_9, 'rank': RANKING(args=[], partition=[key], order=[(size):desc_first, (container):desc_first, (part_type):desc_first], allow_ties=True, dense=True), 'region': name})
- JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'container': t1.container, 'key': t0.key, 'key_9': t1.key, 'name': t0.name, 'part_type': t1.part_type, 'size': t1.size})
- JOIN(conditions=[t0.key_5 == t1.supplier_key], types=['inner'], columns={'key': t0.key, 'name': t0.name, 'part_key': t1.part_key})
- JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'key': t0.key, 'key_5': t1.key, 'name': t0.name})
- JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'name': t0.name})
- SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
+ PROJECT(columns={'key_12': key_9, 'rank': RANKING(args=[], partition=[key], order=[(size):desc_first, (container):desc_first, (part_type):desc_first], allow_ties=True, dense=True), 'region': region_name})
+ JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'container': t1.container, 'key': t0.key, 'key_9': t1.key, 'part_type': t1.part_type, 'region_name': t0.region_name, 'size': t1.size})
+ JOIN(conditions=[t0.key_5 == t1.supplier_key], types=['inner'], columns={'key': t0.key, 'part_key': t1.part_key, 'region_name': t0.region_name})
+ JOIN(conditions=[t0.key_2 == t1.nation_key], types=['inner'], columns={'key': t0.key, 'key_5': t1.key, 'region_name': t0.region_name})
+ JOIN(conditions=[t0.key == t1.region_key], types=['inner'], columns={'key': t0.key, 'key_2': t1.key, 'region_name': t0.region_name})
+ PROJECT(columns={'key': key, 'region_name': name})
+ SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'region_key': n_regionkey})
SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey})
SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey})
diff --git a/tests/test_plan_refsols/rank_with_filters_c.txt b/tests/test_plan_refsols/rank_with_filters_c.txt
index 0e7375af..2fc09d92 100644
--- a/tests/test_plan_refsols/rank_with_filters_c.txt
+++ b/tests/test_plan_refsols/rank_with_filters_c.txt
@@ -1,4 +1,4 @@
-ROOT(columns=[('size', size_4), ('name', name)], orderings=[])
+ROOT(columns=[('pname', name), ('psize', size_4)], orderings=[])
PROJECT(columns={'name': name, 'size_4': size_3})
FILTER(condition=RANKING(args=[], partition=[size], order=[(retail_price):desc_first]) == 1:int64, columns={'name': name, 'size_3': size_3})
PROJECT(columns={'name': name, 'retail_price': retail_price, 'size': size, 'size_3': size_1})
diff --git a/tests/test_plan_refsols/scan_calc_calc.txt b/tests/test_plan_refsols/scan_calc_calc.txt
index f6094a35..08ebe930 100644
--- a/tests/test_plan_refsols/scan_calc_calc.txt
+++ b/tests/test_plan_refsols/scan_calc_calc.txt
@@ -1,4 +1,3 @@
ROOT(columns=[('fizz', fizz), ('buzz', buzz)], orderings=[])
- PROJECT(columns={'buzz': key, 'fizz': name_0})
- PROJECT(columns={'key': key, 'name_0': 'foo':string})
- SCAN(table=tpch.REGION, columns={'key': r_regionkey})
+ PROJECT(columns={'buzz': key, 'fizz': name})
+ SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
diff --git a/tests/test_plan_refsols/scan_customer_call_functions.txt b/tests/test_plan_refsols/scan_customer_call_functions.txt
index 56c3fc4b..92205afe 100644
--- a/tests/test_plan_refsols/scan_customer_call_functions.txt
+++ b/tests/test_plan_refsols/scan_customer_call_functions.txt
@@ -1,3 +1,3 @@
-ROOT(columns=[('name', name_0), ('country_code', country_code), ('adjusted_account_balance', adjusted_account_balance), ('is_named_john', is_named_john)], orderings=[])
- PROJECT(columns={'adjusted_account_balance': IFF(acctbal < 0:int64, 0:int64, acctbal), 'country_code': SLICE(phone, 0:int64, 3:int64, 1:int64), 'is_named_john': LOWER(name) < 'john':string, 'name_0': LOWER(name)})
+ROOT(columns=[('lname', lname), ('country_code', country_code), ('adjusted_account_balance', adjusted_account_balance), ('is_named_john', is_named_john)], orderings=[])
+ PROJECT(columns={'adjusted_account_balance': IFF(acctbal < 0:int64, 0:int64, acctbal), 'country_code': SLICE(phone, 0:int64, 3:int64, 1:int64), 'is_named_john': LOWER(name) < 'john':string, 'lname': LOWER(name)})
SCAN(table=tpch.CUSTOMER, columns={'acctbal': c_acctbal, 'name': c_name, 'phone': c_phone})
diff --git a/tests/test_plan_refsols/simple_filter_top_five.txt b/tests/test_plan_refsols/simple_filter_top_five.txt
index 9049674a..99c94e47 100644
--- a/tests/test_plan_refsols/simple_filter_top_five.txt
+++ b/tests/test_plan_refsols/simple_filter_top_five.txt
@@ -1,5 +1,5 @@
-ROOT(columns=[('key', key), ('total_price', total_price)], orderings=[(ordering_0):desc_last])
- LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'key': key, 'ordering_0': ordering_0, 'total_price': total_price}, orderings=[(ordering_0):desc_last])
- PROJECT(columns={'key': key, 'ordering_0': key, 'total_price': total_price})
- FILTER(condition=total_price < 1000.0:float64, columns={'key': key, 'total_price': total_price})
+ROOT(columns=[('key', key)], orderings=[(ordering_0):desc_last])
+ LIMIT(limit=Literal(value=5, type=Int64Type()), columns={'key': key, 'ordering_0': ordering_0}, orderings=[(ordering_0):desc_last])
+ PROJECT(columns={'key': key, 'ordering_0': key})
+ FILTER(condition=total_price < 1000.0:float64, columns={'key': key})
SCAN(table=tpch.ORDERS, columns={'key': o_orderkey, 'total_price': o_totalprice})
diff --git a/tests/test_plan_refsols/singular_anti.txt b/tests/test_plan_refsols/singular_anti.txt
index 2beb3726..6a7e2c3b 100644
--- a/tests/test_plan_refsols/singular_anti.txt
+++ b/tests/test_plan_refsols/singular_anti.txt
@@ -1,6 +1,6 @@
-ROOT(columns=[('name', name), ('region_name', region_name)], orderings=[])
- FILTER(condition=True:bool, columns={'name': name, 'region_name': region_name})
- PROJECT(columns={'name': name, 'region_name': NULL_1})
+ROOT(columns=[('nation_name', nation_name), ('region_name', region_name)], orderings=[])
+ FILTER(condition=True:bool, columns={'nation_name': nation_name, 'region_name': region_name})
+ PROJECT(columns={'nation_name': name, 'region_name': NULL_1})
JOIN(conditions=[t0.region_key == t1.key], types=['anti'], columns={'NULL_1': None:unknown, 'name': t0.name})
SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey})
FILTER(condition=name != 'ASIA':string, columns={'key': key})
diff --git a/tests/test_plan_refsols/singular_semi.txt b/tests/test_plan_refsols/singular_semi.txt
index d09452bf..c7466e73 100644
--- a/tests/test_plan_refsols/singular_semi.txt
+++ b/tests/test_plan_refsols/singular_semi.txt
@@ -1,6 +1,6 @@
-ROOT(columns=[('name', name), ('region_name', region_name)], orderings=[])
- FILTER(condition=True:bool, columns={'name': name, 'region_name': region_name})
- PROJECT(columns={'name': name, 'region_name': name_3})
+ROOT(columns=[('nation_name', nation_name), ('region_name', region_name)], orderings=[])
+ FILTER(condition=True:bool, columns={'nation_name': nation_name, 'region_name': region_name})
+ PROJECT(columns={'nation_name': name, 'region_name': name_3})
JOIN(conditions=[t0.region_key == t1.key], types=['inner'], columns={'name': t0.name, 'name_3': t1.name})
SCAN(table=tpch.NATION, columns={'name': n_name, 'region_key': n_regionkey})
FILTER(condition=name != 'ASIA':string, columns={'key': key, 'name': name})
diff --git a/tests/test_plan_refsols/tpch_q16.txt b/tests/test_plan_refsols/tpch_q16.txt
index 094e075a..86dac1b0 100644
--- a/tests/test_plan_refsols/tpch_q16.txt
+++ b/tests/test_plan_refsols/tpch_q16.txt
@@ -5,9 +5,9 @@ ROOT(columns=[('P_BRAND', P_BRAND), ('P_TYPE', P_TYPE), ('P_SIZE', P_SIZE), ('SU
AGGREGATE(keys={'p_brand': p_brand, 'p_size': p_size, 'p_type': p_type}, aggregations={'agg_0': NDISTINCT(supplier_key)})
FILTER(condition=NOT(LIKE(comment_2, '%Customer%Complaints%':string)), columns={'p_brand': p_brand, 'p_size': p_size, 'p_type': p_type, 'supplier_key': supplier_key})
JOIN(conditions=[t0.supplier_key == t1.key], types=['left'], columns={'comment_2': t1.comment, 'p_brand': t0.p_brand, 'p_size': t0.p_size, 'p_type': t0.p_type, 'supplier_key': t0.supplier_key})
- PROJECT(columns={'p_brand': brand, 'p_size': size, 'p_type': part_type, 'supplier_key': supplier_key})
- JOIN(conditions=[t0.key == t1.part_key], types=['inner'], columns={'brand': t0.brand, 'part_type': t0.part_type, 'size': t0.size, 'supplier_key': t1.supplier_key})
+ JOIN(conditions=[t0.key == t1.part_key], types=['inner'], columns={'p_brand': t0.p_brand, 'p_size': t0.p_size, 'p_type': t0.p_type, 'supplier_key': t1.supplier_key})
+ PROJECT(columns={'key': key, 'p_brand': brand, 'p_size': size, 'p_type': part_type})
FILTER(condition=brand != 'BRAND#45':string & NOT(STARTSWITH(part_type, 'MEDIUM POLISHED%':string)) & ISIN(size, [49, 14, 23, 45, 19, 3, 36, 9]:array[unknown]), columns={'brand': brand, 'key': key, 'part_type': part_type, 'size': size})
SCAN(table=tpch.PART, columns={'brand': p_brand, 'key': p_partkey, 'part_type': p_type, 'size': p_size})
- SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey})
+ SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey})
SCAN(table=tpch.SUPPLIER, columns={'comment': s_comment, 'key': s_suppkey})
diff --git a/tests/test_plan_refsols/tpch_q17.txt b/tests/test_plan_refsols/tpch_q17.txt
index 09e258eb..15d80270 100644
--- a/tests/test_plan_refsols/tpch_q17.txt
+++ b/tests/test_plan_refsols/tpch_q17.txt
@@ -1,9 +1,9 @@
ROOT(columns=[('AVG_YEARLY', AVG_YEARLY)], orderings=[])
PROJECT(columns={'AVG_YEARLY': DEFAULT_TO(agg_0, 0:int64) / 7.0:float64})
AGGREGATE(keys={}, aggregations={'agg_0': SUM(extended_price)})
- FILTER(condition=quantity < 0.2:float64 * avg_quantity, columns={'extended_price': extended_price})
- JOIN(conditions=[t0.key == t1.part_key], types=['inner'], columns={'avg_quantity': t0.avg_quantity, 'extended_price': t1.extended_price, 'quantity': t1.quantity})
- PROJECT(columns={'avg_quantity': agg_0, 'key': key})
+ FILTER(condition=quantity < 0.2:float64 * part_avg_quantity, columns={'extended_price': extended_price})
+ JOIN(conditions=[t0.key == t1.part_key], types=['inner'], columns={'extended_price': t1.extended_price, 'part_avg_quantity': t0.part_avg_quantity, 'quantity': t1.quantity})
+ PROJECT(columns={'key': key, 'part_avg_quantity': agg_0})
JOIN(conditions=[t0.key == t1.part_key], types=['left'], columns={'agg_0': t1.agg_0, 'key': t0.key})
FILTER(condition=brand == 'Brand#23':string & container == 'MED BOX':string, columns={'key': key})
SCAN(table=tpch.PART, columns={'brand': p_brand, 'container': p_container, 'key': p_partkey})
diff --git a/tests/test_plan_refsols/tpch_q2.txt b/tests/test_plan_refsols/tpch_q2.txt
index 461d2e30..55104390 100644
--- a/tests/test_plan_refsols/tpch_q2.txt
+++ b/tests/test_plan_refsols/tpch_q2.txt
@@ -1,9 +1,9 @@
ROOT(columns=[('S_ACCTBAL', S_ACCTBAL), ('S_NAME', S_NAME), ('N_NAME', N_NAME), ('P_PARTKEY', P_PARTKEY), ('P_MFGR', P_MFGR), ('S_ADDRESS', S_ADDRESS), ('S_PHONE', S_PHONE), ('S_COMMENT', S_COMMENT)], orderings=[(ordering_1):desc_last, (ordering_2):asc_first, (ordering_3):asc_first, (ordering_4):asc_first])
LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'N_NAME': N_NAME, 'P_MFGR': P_MFGR, 'P_PARTKEY': P_PARTKEY, 'S_ACCTBAL': S_ACCTBAL, 'S_ADDRESS': S_ADDRESS, 'S_COMMENT': S_COMMENT, 'S_NAME': S_NAME, 'S_PHONE': S_PHONE, 'ordering_1': ordering_1, 'ordering_2': ordering_2, 'ordering_3': ordering_3, 'ordering_4': ordering_4}, orderings=[(ordering_1):desc_last, (ordering_2):asc_first, (ordering_3):asc_first, (ordering_4):asc_first])
PROJECT(columns={'N_NAME': N_NAME, 'P_MFGR': P_MFGR, 'P_PARTKEY': P_PARTKEY, 'S_ACCTBAL': S_ACCTBAL, 'S_ADDRESS': S_ADDRESS, 'S_COMMENT': S_COMMENT, 'S_NAME': S_NAME, 'S_PHONE': S_PHONE, 'ordering_1': S_ACCTBAL, 'ordering_2': N_NAME, 'ordering_3': S_NAME, 'ordering_4': P_PARTKEY})
- PROJECT(columns={'N_NAME': n_name, 'P_MFGR': manufacturer, 'P_PARTKEY': key_19, 'S_ACCTBAL': s_acctbal, 'S_ADDRESS': s_address, 'S_COMMENT': s_comment, 'S_NAME': s_name, 'S_PHONE': s_phone})
- FILTER(condition=supplycost_21 == best_cost & ENDSWITH(part_type, 'BRASS':string) & size == 15:int64, columns={'key_19': key_19, 'manufacturer': manufacturer, 'n_name': n_name, 's_acctbal': s_acctbal, 's_address': s_address, 's_comment': s_comment, 's_name': s_name, 's_phone': s_phone})
- JOIN(conditions=[t0.key_9 == t1.key_19], types=['inner'], columns={'best_cost': t0.best_cost, 'key_19': t1.key_19, 'manufacturer': t1.manufacturer, 'n_name': t1.n_name, 'part_type': t1.part_type, 's_acctbal': t1.s_acctbal, 's_address': t1.s_address, 's_comment': t1.s_comment, 's_name': t1.s_name, 's_phone': t1.s_phone, 'size': t1.size, 'supplycost_21': t1.supplycost})
+ PROJECT(columns={'N_NAME': n_name_21, 'P_MFGR': manufacturer, 'P_PARTKEY': key_19, 'S_ACCTBAL': s_acctbal_22, 'S_ADDRESS': s_address_23, 'S_COMMENT': s_comment_24, 'S_NAME': s_name_25, 'S_PHONE': s_phone_26})
+ FILTER(condition=supplycost_27 == best_cost & ENDSWITH(part_type, 'BRASS':string) & size == 15:int64, columns={'key_19': key_19, 'manufacturer': manufacturer, 'n_name_21': n_name_21, 's_acctbal_22': s_acctbal_22, 's_address_23': s_address_23, 's_comment_24': s_comment_24, 's_name_25': s_name_25, 's_phone_26': s_phone_26})
+ JOIN(conditions=[t0.key_9 == t1.key_19], types=['inner'], columns={'best_cost': t0.best_cost, 'key_19': t1.key_19, 'manufacturer': t1.manufacturer, 'n_name_21': t1.n_name, 'part_type': t1.part_type, 's_acctbal_22': t1.s_acctbal, 's_address_23': t1.s_address, 's_comment_24': t1.s_comment, 's_name_25': t1.s_name, 's_phone_26': t1.s_phone, 'size': t1.size, 'supplycost_27': t1.supplycost})
PROJECT(columns={'best_cost': agg_0, 'key_9': key_9})
AGGREGATE(keys={'key_9': key_9}, aggregations={'agg_0': MIN(supplycost)})
FILTER(condition=ENDSWITH(part_type, 'BRASS':string) & size == 15:int64, columns={'key_9': key_9, 'supplycost': supplycost})
@@ -18,14 +18,15 @@ ROOT(columns=[('S_ACCTBAL', S_ACCTBAL), ('S_NAME', S_NAME), ('N_NAME', N_NAME),
SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost})
SCAN(table=tpch.PART, columns={'key': p_partkey, 'part_type': p_type, 'size': p_size})
FILTER(condition=ENDSWITH(part_type, 'BRASS':string) & size == 15:int64, columns={'key_19': key_19, 'manufacturer': manufacturer, 'n_name': n_name, 'part_type': part_type, 's_acctbal': s_acctbal, 's_address': s_address, 's_comment': s_comment, 's_name': s_name, 's_phone': s_phone, 'size': size, 'supplycost': supplycost})
- PROJECT(columns={'key_19': key_19, 'manufacturer': manufacturer, 'n_name': name, 'part_type': part_type, 's_acctbal': account_balance, 's_address': address, 's_comment': comment_14, 's_name': name_16, 's_phone': phone, 'size': size, 'supplycost': supplycost})
- JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'account_balance': t0.account_balance, 'address': t0.address, 'comment_14': t0.comment_14, 'key_19': t1.key, 'manufacturer': t1.manufacturer, 'name': t0.name, 'name_16': t0.name_16, 'part_type': t1.part_type, 'phone': t0.phone, 'size': t1.size, 'supplycost': t0.supplycost})
- JOIN(conditions=[t0.key_15 == t1.supplier_key], types=['inner'], columns={'account_balance': t0.account_balance, 'address': t0.address, 'comment_14': t0.comment_14, 'name': t0.name, 'name_16': t0.name_16, 'part_key': t1.part_key, 'phone': t0.phone, 'supplycost': t1.supplycost})
- JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'account_balance': t1.account_balance, 'address': t1.address, 'comment_14': t1.comment, 'key_15': t1.key, 'name': t0.name, 'name_16': t1.name, 'phone': t1.phone})
- FILTER(condition=name_13 == 'EUROPE':string, columns={'key': key, 'name': name})
- JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'key': t0.key, 'name': t0.name, 'name_13': t1.name})
- SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey})
+ JOIN(conditions=[t0.part_key == t1.key], types=['inner'], columns={'key_19': t1.key, 'manufacturer': t1.manufacturer, 'n_name': t0.n_name, 'part_type': t1.part_type, 's_acctbal': t0.s_acctbal, 's_address': t0.s_address, 's_comment': t0.s_comment, 's_name': t0.s_name, 's_phone': t0.s_phone, 'size': t1.size, 'supplycost': t0.supplycost})
+ JOIN(conditions=[t0.key_15 == t1.supplier_key], types=['inner'], columns={'n_name': t0.n_name, 'part_key': t1.part_key, 's_acctbal': t0.s_acctbal, 's_address': t0.s_address, 's_comment': t0.s_comment, 's_name': t0.s_name, 's_phone': t0.s_phone, 'supplycost': t1.supplycost})
+ PROJECT(columns={'key_15': key_15, 'n_name': n_name, 's_acctbal': account_balance, 's_address': address, 's_comment': comment_14, 's_name': name_16, 's_phone': phone})
+ JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'account_balance': t1.account_balance, 'address': t1.address, 'comment_14': t1.comment, 'key_15': t1.key, 'n_name': t0.n_name, 'name_16': t1.name, 'phone': t1.phone})
+ FILTER(condition=name_13 == 'EUROPE':string, columns={'key': key, 'n_name': n_name})
+ JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'key': t0.key, 'n_name': t0.n_name, 'name_13': t1.name})
+ PROJECT(columns={'key': key, 'n_name': name, 'region_key': region_key})
+ SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey})
SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
SCAN(table=tpch.SUPPLIER, columns={'account_balance': s_acctbal, 'address': s_address, 'comment': s_comment, 'key': s_suppkey, 'name': s_name, 'nation_key': s_nationkey, 'phone': s_phone})
- SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost})
- SCAN(table=tpch.PART, columns={'key': p_partkey, 'manufacturer': p_mfgr, 'part_type': p_type, 'size': p_size})
+ SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost})
+ SCAN(table=tpch.PART, columns={'key': p_partkey, 'manufacturer': p_mfgr, 'part_type': p_type, 'size': p_size})
diff --git a/tests/test_plan_refsols/tpch_q21.txt b/tests/test_plan_refsols/tpch_q21.txt
index fdf99273..0a315d5d 100644
--- a/tests/test_plan_refsols/tpch_q21.txt
+++ b/tests/test_plan_refsols/tpch_q21.txt
@@ -10,12 +10,13 @@ ROOT(columns=[('S_NAME', S_NAME), ('NUMWAIT', NUMWAIT)], orderings=[(ordering_1)
AGGREGATE(keys={'supplier_key': supplier_key}, aggregations={'agg_0': COUNT()})
FILTER(condition=order_status == 'F':string & True:bool & True:bool, columns={'supplier_key': supplier_key})
JOIN(conditions=[t0.key == t1.order_key], types=['anti'], columns={'order_status': t0.order_status, 'supplier_key': t0.supplier_key}, correl_name='corr6')
- JOIN(conditions=[t0.key == t1.order_key], types=['semi'], columns={'key': t0.key, 'order_status': t0.order_status, 'supplier_key': t0.supplier_key}, correl_name='corr5')
- JOIN(conditions=[t0.order_key == t1.key], types=['inner'], columns={'key': t1.key, 'order_status': t1.order_status, 'supplier_key': t0.supplier_key})
- FILTER(condition=receipt_date > commit_date, columns={'order_key': order_key, 'supplier_key': supplier_key})
- SCAN(table=tpch.LINEITEM, columns={'commit_date': l_commitdate, 'order_key': l_orderkey, 'receipt_date': l_receiptdate, 'supplier_key': l_suppkey})
+ JOIN(conditions=[t0.key == t1.order_key], types=['semi'], columns={'key': t0.key, 'order_status': t0.order_status, 'original_key': t0.original_key, 'supplier_key': t0.supplier_key}, correl_name='corr5')
+ JOIN(conditions=[t0.order_key == t1.key], types=['inner'], columns={'key': t1.key, 'order_status': t1.order_status, 'original_key': t0.original_key, 'supplier_key': t0.supplier_key})
+ FILTER(condition=receipt_date > commit_date, columns={'order_key': order_key, 'original_key': original_key, 'supplier_key': supplier_key})
+ PROJECT(columns={'commit_date': commit_date, 'order_key': order_key, 'original_key': supplier_key, 'receipt_date': receipt_date, 'supplier_key': supplier_key})
+ SCAN(table=tpch.LINEITEM, columns={'commit_date': l_commitdate, 'order_key': l_orderkey, 'receipt_date': l_receiptdate, 'supplier_key': l_suppkey})
SCAN(table=tpch.ORDERS, columns={'key': o_orderkey, 'order_status': o_orderstatus})
- FILTER(condition=supplier_key != corr5.supplier_key, columns={'order_key': order_key})
+ FILTER(condition=supplier_key != corr5.original_key, columns={'order_key': order_key})
SCAN(table=tpch.LINEITEM, columns={'order_key': l_orderkey, 'supplier_key': l_suppkey})
- FILTER(condition=supplier_key != corr6.supplier_key & receipt_date > commit_date, columns={'order_key': order_key})
+ FILTER(condition=supplier_key != corr6.original_key & receipt_date > commit_date, columns={'order_key': order_key})
SCAN(table=tpch.LINEITEM, columns={'commit_date': l_commitdate, 'order_key': l_orderkey, 'receipt_date': l_receiptdate, 'supplier_key': l_suppkey})
diff --git a/tests/test_plan_refsols/tpch_q22.txt b/tests/test_plan_refsols/tpch_q22.txt
index 01b9f860..ba731e1b 100644
--- a/tests/test_plan_refsols/tpch_q22.txt
+++ b/tests/test_plan_refsols/tpch_q22.txt
@@ -2,12 +2,12 @@ ROOT(columns=[('CNTRY_CODE', CNTRY_CODE), ('NUM_CUSTS', NUM_CUSTS), ('TOTACCTBAL
PROJECT(columns={'CNTRY_CODE': CNTRY_CODE, 'NUM_CUSTS': NUM_CUSTS, 'TOTACCTBAL': TOTACCTBAL, 'ordering_3': CNTRY_CODE})
PROJECT(columns={'CNTRY_CODE': cntry_code, 'NUM_CUSTS': DEFAULT_TO(agg_1, 0:int64), 'TOTACCTBAL': DEFAULT_TO(agg_2, 0:int64)})
AGGREGATE(keys={'cntry_code': cntry_code}, aggregations={'agg_1': COUNT(), 'agg_2': SUM(acctbal)})
- FILTER(condition=acctbal > avg_balance & DEFAULT_TO(agg_0, 0:int64) == 0:int64, columns={'acctbal': acctbal, 'cntry_code': cntry_code})
- JOIN(conditions=[t0.key == t1.customer_key], types=['left'], columns={'acctbal': t0.acctbal, 'agg_0': t1.agg_0, 'avg_balance': t0.avg_balance, 'cntry_code': t0.cntry_code})
- FILTER(condition=ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17']:array[unknown]), columns={'acctbal': acctbal, 'avg_balance': avg_balance, 'cntry_code': cntry_code, 'key': key})
- PROJECT(columns={'acctbal': acctbal, 'avg_balance': avg_balance, 'cntry_code': SLICE(phone, None:unknown, 2:int64, None:unknown), 'key': key})
- JOIN(conditions=[True:bool], types=['inner'], columns={'acctbal': t1.acctbal, 'avg_balance': t0.avg_balance, 'key': t1.key, 'phone': t1.phone})
- PROJECT(columns={'avg_balance': agg_0})
+ FILTER(condition=acctbal > global_avg_balance & DEFAULT_TO(agg_0, 0:int64) == 0:int64, columns={'acctbal': acctbal, 'cntry_code': cntry_code})
+ JOIN(conditions=[t0.key == t1.customer_key], types=['left'], columns={'acctbal': t0.acctbal, 'agg_0': t1.agg_0, 'cntry_code': t0.cntry_code, 'global_avg_balance': t0.global_avg_balance})
+ FILTER(condition=ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17']:array[unknown]), columns={'acctbal': acctbal, 'cntry_code': cntry_code, 'global_avg_balance': global_avg_balance, 'key': key})
+ PROJECT(columns={'acctbal': acctbal, 'cntry_code': SLICE(phone, None:unknown, 2:int64, None:unknown), 'global_avg_balance': global_avg_balance, 'key': key})
+ JOIN(conditions=[True:bool], types=['inner'], columns={'acctbal': t1.acctbal, 'global_avg_balance': t0.global_avg_balance, 'key': t1.key, 'phone': t1.phone})
+ PROJECT(columns={'global_avg_balance': agg_0})
AGGREGATE(keys={}, aggregations={'agg_0': AVG(acctbal)})
FILTER(condition=acctbal > 0.0:float64, columns={'acctbal': acctbal})
FILTER(condition=ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17']:array[unknown]), columns={'acctbal': acctbal})
diff --git a/tests/test_plan_refsols/tpch_q5.txt b/tests/test_plan_refsols/tpch_q5.txt
index 7cb9925b..c1fc94de 100644
--- a/tests/test_plan_refsols/tpch_q5.txt
+++ b/tests/test_plan_refsols/tpch_q5.txt
@@ -8,15 +8,16 @@ ROOT(columns=[('N_NAME', N_NAME), ('REVENUE', REVENUE)], orderings=[(ordering_1)
SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
AGGREGATE(keys={'key': key}, aggregations={'agg_0': SUM(value)})
PROJECT(columns={'key': key, 'value': extended_price * 1:int64 - discount})
- FILTER(condition=name_15 == name, columns={'discount': discount, 'extended_price': extended_price, 'key': key})
- JOIN(conditions=[t0.supplier_key == t1.key], types=['left'], columns={'discount': t0.discount, 'extended_price': t0.extended_price, 'key': t0.key, 'name': t0.name, 'name_15': t1.name_15})
- JOIN(conditions=[t0.key_11 == t1.order_key], types=['inner'], columns={'discount': t1.discount, 'extended_price': t1.extended_price, 'key': t0.key, 'name': t0.name, 'supplier_key': t1.supplier_key})
- FILTER(condition=order_date >= datetime.date(1994, 1, 1):date & order_date < datetime.date(1995, 1, 1):date, columns={'key': key, 'key_11': key_11, 'name': name})
- JOIN(conditions=[t0.key_8 == t1.customer_key], types=['inner'], columns={'key': t0.key, 'key_11': t1.key, 'name': t0.name, 'order_date': t1.order_date})
- JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'key': t0.key, 'key_8': t1.key, 'name': t0.name})
- FILTER(condition=name_6 == 'ASIA':string, columns={'key': key, 'name': name})
- JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'key': t0.key, 'name': t0.name, 'name_6': t1.name})
- SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey})
+ FILTER(condition=name_15 == nation_name, columns={'discount': discount, 'extended_price': extended_price, 'key': key})
+ JOIN(conditions=[t0.supplier_key == t1.key], types=['left'], columns={'discount': t0.discount, 'extended_price': t0.extended_price, 'key': t0.key, 'name_15': t1.name_15, 'nation_name': t0.nation_name})
+ JOIN(conditions=[t0.key_11 == t1.order_key], types=['inner'], columns={'discount': t1.discount, 'extended_price': t1.extended_price, 'key': t0.key, 'nation_name': t0.nation_name, 'supplier_key': t1.supplier_key})
+ FILTER(condition=order_date >= datetime.date(1994, 1, 1):date & order_date < datetime.date(1995, 1, 1):date, columns={'key': key, 'key_11': key_11, 'nation_name': nation_name})
+ JOIN(conditions=[t0.key_8 == t1.customer_key], types=['inner'], columns={'key': t0.key, 'key_11': t1.key, 'nation_name': t0.nation_name, 'order_date': t1.order_date})
+ JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'key': t0.key, 'key_8': t1.key, 'nation_name': t0.nation_name})
+ FILTER(condition=name_6 == 'ASIA':string, columns={'key': key, 'nation_name': nation_name})
+ JOIN(conditions=[t0.region_key == t1.key], types=['left'], columns={'key': t0.key, 'name_6': t1.name, 'nation_name': t0.nation_name})
+ PROJECT(columns={'key': key, 'nation_name': name, 'region_key': region_key})
+ SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name, 'region_key': n_regionkey})
SCAN(table=tpch.REGION, columns={'key': r_regionkey, 'name': r_name})
SCAN(table=tpch.CUSTOMER, columns={'key': c_custkey, 'nation_key': c_nationkey})
SCAN(table=tpch.ORDERS, columns={'customer_key': o_custkey, 'key': o_orderkey, 'order_date': o_orderdate})
diff --git a/tests/test_plan_refsols/tpch_q8.txt b/tests/test_plan_refsols/tpch_q8.txt
index 7ae110a9..398b235f 100644
--- a/tests/test_plan_refsols/tpch_q8.txt
+++ b/tests/test_plan_refsols/tpch_q8.txt
@@ -3,15 +3,16 @@ ROOT(columns=[('O_YEAR', O_YEAR), ('MKT_SHARE', MKT_SHARE)], orderings=[])
AGGREGATE(keys={'o_year': o_year}, aggregations={'agg_0': SUM(brazil_volume), 'agg_1': SUM(volume)})
FILTER(condition=order_date >= datetime.date(1995, 1, 1):date & order_date <= datetime.date(1996, 12, 31):date & name_18 == 'AMERICA':string, columns={'brazil_volume': brazil_volume, 'o_year': o_year, 'volume': volume})
JOIN(conditions=[t0.customer_key == t1.key], types=['left'], columns={'brazil_volume': t0.brazil_volume, 'name_18': t1.name_18, 'o_year': t0.o_year, 'order_date': t0.order_date, 'volume': t0.volume})
- PROJECT(columns={'brazil_volume': IFF(name == 'BRAZIL':string, volume, 0:int64), 'customer_key': customer_key, 'o_year': YEAR(order_date), 'order_date': order_date, 'volume': volume})
- JOIN(conditions=[t0.order_key == t1.key], types=['inner'], columns={'customer_key': t1.customer_key, 'name': t0.name, 'order_date': t1.order_date, 'volume': t0.volume})
- PROJECT(columns={'name': name, 'order_key': order_key, 'volume': extended_price * 1:int64 - discount})
- JOIN(conditions=[t0.part_key == t1.part_key & t0.supplier_key == t1.supplier_key], types=['inner'], columns={'discount': t1.discount, 'extended_price': t1.extended_price, 'name': t0.name, 'order_key': t1.order_key})
- FILTER(condition=part_type == 'ECONOMY ANODIZED STEEL':string, columns={'name': name, 'part_key': part_key, 'supplier_key': supplier_key})
- JOIN(conditions=[t0.part_key == t1.key], types=['left'], columns={'name': t0.name, 'part_key': t0.part_key, 'part_type': t1.part_type, 'supplier_key': t0.supplier_key})
- JOIN(conditions=[t0.key_2 == t1.supplier_key], types=['inner'], columns={'name': t0.name, 'part_key': t1.part_key, 'supplier_key': t1.supplier_key})
- JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name})
- SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name})
+ PROJECT(columns={'brazil_volume': IFF(nation_name == 'BRAZIL':string, volume, 0:int64), 'customer_key': customer_key, 'o_year': YEAR(order_date), 'order_date': order_date, 'volume': volume})
+ JOIN(conditions=[t0.order_key == t1.key], types=['inner'], columns={'customer_key': t1.customer_key, 'nation_name': t0.nation_name, 'order_date': t1.order_date, 'volume': t0.volume})
+ PROJECT(columns={'nation_name': nation_name, 'order_key': order_key, 'volume': extended_price * 1:int64 - discount})
+ JOIN(conditions=[t0.part_key == t1.part_key & t0.supplier_key == t1.supplier_key], types=['inner'], columns={'discount': t1.discount, 'extended_price': t1.extended_price, 'nation_name': t0.nation_name, 'order_key': t1.order_key})
+ FILTER(condition=part_type == 'ECONOMY ANODIZED STEEL':string, columns={'nation_name': nation_name, 'part_key': part_key, 'supplier_key': supplier_key})
+ JOIN(conditions=[t0.part_key == t1.key], types=['left'], columns={'nation_name': t0.nation_name, 'part_key': t0.part_key, 'part_type': t1.part_type, 'supplier_key': t0.supplier_key})
+ JOIN(conditions=[t0.key_2 == t1.supplier_key], types=['inner'], columns={'nation_name': t0.nation_name, 'part_key': t1.part_key, 'supplier_key': t1.supplier_key})
+ JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'key_2': t1.key, 'nation_name': t0.nation_name})
+ PROJECT(columns={'key': key, 'nation_name': name})
+ SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name})
SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey})
SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey})
SCAN(table=tpch.PART, columns={'key': p_partkey, 'part_type': p_type})
diff --git a/tests/test_plan_refsols/tpch_q9.txt b/tests/test_plan_refsols/tpch_q9.txt
index 73b07439..4ef2e0df 100644
--- a/tests/test_plan_refsols/tpch_q9.txt
+++ b/tests/test_plan_refsols/tpch_q9.txt
@@ -1,16 +1,17 @@
ROOT(columns=[('NATION', NATION), ('O_YEAR', O_YEAR), ('AMOUNT', AMOUNT)], orderings=[(ordering_1):asc_first, (ordering_2):desc_last])
LIMIT(limit=Literal(value=10, type=Int64Type()), columns={'AMOUNT': AMOUNT, 'NATION': NATION, 'O_YEAR': O_YEAR, 'ordering_1': ordering_1, 'ordering_2': ordering_2}, orderings=[(ordering_1):asc_first, (ordering_2):desc_last])
PROJECT(columns={'AMOUNT': AMOUNT, 'NATION': NATION, 'O_YEAR': O_YEAR, 'ordering_1': NATION, 'ordering_2': O_YEAR})
- PROJECT(columns={'AMOUNT': DEFAULT_TO(agg_0, 0:int64), 'NATION': nation, 'O_YEAR': o_year})
- AGGREGATE(keys={'nation': nation, 'o_year': o_year}, aggregations={'agg_0': SUM(value)})
- PROJECT(columns={'nation': name, 'o_year': YEAR(order_date), 'value': extended_price * 1:int64 - discount - supplycost * quantity})
- JOIN(conditions=[t0.order_key == t1.key], types=['left'], columns={'discount': t0.discount, 'extended_price': t0.extended_price, 'name': t0.name, 'order_date': t1.order_date, 'quantity': t0.quantity, 'supplycost': t0.supplycost})
- JOIN(conditions=[t0.part_key == t1.part_key & t0.supplier_key == t1.supplier_key], types=['inner'], columns={'discount': t1.discount, 'extended_price': t1.extended_price, 'name': t0.name, 'order_key': t1.order_key, 'quantity': t1.quantity, 'supplycost': t0.supplycost})
- FILTER(condition=CONTAINS(name_7, 'green':string), columns={'name': name, 'part_key': part_key, 'supplier_key': supplier_key, 'supplycost': supplycost})
- JOIN(conditions=[t0.part_key == t1.key], types=['left'], columns={'name': t0.name, 'name_7': t1.name, 'part_key': t0.part_key, 'supplier_key': t0.supplier_key, 'supplycost': t0.supplycost})
- JOIN(conditions=[t0.key_2 == t1.supplier_key], types=['inner'], columns={'name': t0.name, 'part_key': t1.part_key, 'supplier_key': t1.supplier_key, 'supplycost': t1.supplycost})
- JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'key_2': t1.key, 'name': t0.name})
- SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name})
+ PROJECT(columns={'AMOUNT': DEFAULT_TO(agg_0, 0:int64), 'NATION': nation_name, 'O_YEAR': o_year})
+ AGGREGATE(keys={'nation_name': nation_name, 'o_year': o_year}, aggregations={'agg_0': SUM(value)})
+ PROJECT(columns={'nation_name': nation_name, 'o_year': YEAR(order_date), 'value': extended_price * 1:int64 - discount - supplycost * quantity})
+ JOIN(conditions=[t0.order_key == t1.key], types=['left'], columns={'discount': t0.discount, 'extended_price': t0.extended_price, 'nation_name': t0.nation_name, 'order_date': t1.order_date, 'quantity': t0.quantity, 'supplycost': t0.supplycost})
+ JOIN(conditions=[t0.part_key == t1.part_key & t0.supplier_key == t1.supplier_key], types=['inner'], columns={'discount': t1.discount, 'extended_price': t1.extended_price, 'nation_name': t0.nation_name, 'order_key': t1.order_key, 'quantity': t1.quantity, 'supplycost': t0.supplycost})
+ FILTER(condition=CONTAINS(name_7, 'green':string), columns={'nation_name': nation_name, 'part_key': part_key, 'supplier_key': supplier_key, 'supplycost': supplycost})
+ JOIN(conditions=[t0.part_key == t1.key], types=['left'], columns={'name_7': t1.name, 'nation_name': t0.nation_name, 'part_key': t0.part_key, 'supplier_key': t0.supplier_key, 'supplycost': t0.supplycost})
+ JOIN(conditions=[t0.key_2 == t1.supplier_key], types=['inner'], columns={'nation_name': t0.nation_name, 'part_key': t1.part_key, 'supplier_key': t1.supplier_key, 'supplycost': t1.supplycost})
+ JOIN(conditions=[t0.key == t1.nation_key], types=['inner'], columns={'key_2': t1.key, 'nation_name': t0.nation_name})
+ PROJECT(columns={'key': key, 'nation_name': name})
+ SCAN(table=tpch.NATION, columns={'key': n_nationkey, 'name': n_name})
SCAN(table=tpch.SUPPLIER, columns={'key': s_suppkey, 'nation_key': s_nationkey})
SCAN(table=tpch.PARTSUPP, columns={'part_key': ps_partkey, 'supplier_key': ps_suppkey, 'supplycost': ps_supplycost})
SCAN(table=tpch.PART, columns={'key': p_partkey, 'name': p_name})
diff --git a/tests/test_plan_refsols/triple_partition.txt b/tests/test_plan_refsols/triple_partition.txt
index d078683d..27b03c42 100644
--- a/tests/test_plan_refsols/triple_partition.txt
+++ b/tests/test_plan_refsols/triple_partition.txt
@@ -1,4 +1,4 @@
-ROOT(columns=[('supp_region', supp_region), ('avg_percentage', avg_percentage)], orderings=[(ordering_1):asc_first])
+ROOT(columns=[('region', supp_region), ('avgpct', avg_percentage)], orderings=[(ordering_1):asc_first])
PROJECT(columns={'avg_percentage': avg_percentage, 'ordering_1': supp_region, 'supp_region': supp_region})
PROJECT(columns={'avg_percentage': agg_0, 'supp_region': supp_region})
AGGREGATE(keys={'supp_region': supp_region}, aggregations={'agg_0': AVG(percentage)})
diff --git a/tests/test_plan_refsols/years_months_days_hours_datediff.txt b/tests/test_plan_refsols/years_months_days_hours_datediff.txt
new file mode 100644
index 00000000..f1bd8571
--- /dev/null
+++ b/tests/test_plan_refsols/years_months_days_hours_datediff.txt
@@ -0,0 +1,6 @@
+ROOT(columns=[('x', x), ('y1', y1), ('years_diff', years_diff), ('c_years_diff', c_years_diff), ('c_y_diff', c_y_diff), ('y_diff', y_diff), ('months_diff', months_diff), ('c_months_diff', c_months_diff), ('mm_diff', mm_diff), ('days_diff', days_diff), ('c_days_diff', c_days_diff), ('c_d_diff', c_d_diff), ('d_diff', d_diff), ('hours_diff', hours_diff), ('c_hours_diff', c_hours_diff), ('c_h_diff', c_h_diff)], orderings=[(ordering_0):asc_first])
+ LIMIT(limit=Literal(value=30, type=Int64Type()), columns={'c_d_diff': c_d_diff, 'c_days_diff': c_days_diff, 'c_h_diff': c_h_diff, 'c_hours_diff': c_hours_diff, 'c_months_diff': c_months_diff, 'c_y_diff': c_y_diff, 'c_years_diff': c_years_diff, 'd_diff': d_diff, 'days_diff': days_diff, 'hours_diff': hours_diff, 'mm_diff': mm_diff, 'months_diff': months_diff, 'ordering_0': ordering_0, 'x': x, 'y1': y1, 'y_diff': y_diff, 'years_diff': years_diff}, orderings=[(ordering_0):asc_first])
+ PROJECT(columns={'c_d_diff': c_d_diff, 'c_days_diff': c_days_diff, 'c_h_diff': c_h_diff, 'c_hours_diff': c_hours_diff, 'c_months_diff': c_months_diff, 'c_y_diff': c_y_diff, 'c_years_diff': c_years_diff, 'd_diff': d_diff, 'days_diff': days_diff, 'hours_diff': hours_diff, 'mm_diff': mm_diff, 'months_diff': months_diff, 'ordering_0': years_diff, 'x': x, 'y1': y1, 'y_diff': y_diff, 'years_diff': years_diff})
+ PROJECT(columns={'c_d_diff': DATEDIFF('D':string, date_time, datetime.datetime(2025, 5, 2, 11, 0):date), 'c_days_diff': DATEDIFF('DAYS':string, date_time, datetime.datetime(2025, 5, 2, 11, 0):date), 'c_h_diff': DATEDIFF('H':string, date_time, datetime.datetime(2025, 5, 2, 11, 0):date), 'c_hours_diff': DATEDIFF('HOURS':string, date_time, datetime.datetime(2025, 5, 2, 11, 0):date), 'c_months_diff': DATEDIFF('MONTHS':string, date_time, datetime.datetime(2025, 5, 2, 11, 0):date), 'c_y_diff': DATEDIFF('Y':string, date_time, datetime.datetime(2025, 5, 2, 11, 0):date), 'c_years_diff': DATEDIFF('YEARS':string, date_time, datetime.datetime(2025, 5, 2, 11, 0):date), 'd_diff': DATEDIFF('d':string, date_time, datetime.datetime(2025, 5, 2, 11, 0):date), 'days_diff': DATEDIFF('days':string, date_time, datetime.datetime(2025, 5, 2, 11, 0):date), 'hours_diff': DATEDIFF('hours':string, date_time, datetime.datetime(2025, 5, 2, 11, 0):date), 'mm_diff': DATEDIFF('mm':string, date_time, datetime.datetime(2025, 5, 2, 11, 0):date), 'months_diff': DATEDIFF('months':string, date_time, datetime.datetime(2025, 5, 2, 11, 0):date), 'x': date_time, 'y1': datetime.datetime(2025, 5, 2, 11, 0):date, 'y_diff': DATEDIFF('y':string, date_time, datetime.datetime(2025, 5, 2, 11, 0):date), 'years_diff': DATEDIFF('years':string, date_time, datetime.datetime(2025, 5, 2, 11, 0):date)})
+ FILTER(condition=YEAR(date_time) < 2025:int64, columns={'date_time': date_time})
+ SCAN(table=main.sbTransaction, columns={'date_time': sbTxDateTime})
diff --git a/tests/test_pydough_to_sql.py b/tests/test_pydough_to_sql.py
index 34723beb..302b941b 100644
--- a/tests/test_pydough_to_sql.py
+++ b/tests/test_pydough_to_sql.py
@@ -28,35 +28,41 @@
@pytest.mark.parametrize(
- "pydough_code, test_name",
+ "pydough_code, columns, test_name",
[
pytest.param(
simple_scan,
+ None,
"simple_scan",
id="simple_scan",
),
pytest.param(
simple_filter,
+ ["order_date", "o_orderkey", "o_totalprice"],
"simple_filter",
id="simple_filter",
),
pytest.param(
rank_a,
+ {"id": "key", "rk": "rank"},
"rank_a",
id="rank_a",
),
pytest.param(
rank_b,
+ {"order_key": "key", "rank": "rank"},
"rank_b",
id="rank_b",
),
pytest.param(
rank_c,
+ None,
"rank_c",
id="rank_c",
),
pytest.param(
datetime_sampler,
+ None,
"datetime_sampler",
id="datetime_sampler",
),
@@ -64,6 +70,7 @@
)
def test_pydough_to_sql_tpch(
pydough_code: Callable[[], UnqualifiedNode],
+ columns: dict[str, str] | list[str] | None,
test_name: str,
get_sample_graph: graph_fetcher,
get_sql_test_filename: Callable[[str, DatabaseDialect], str],
@@ -77,7 +84,7 @@ def test_pydough_to_sql_tpch(
graph: GraphMetadata = get_sample_graph("TPCH")
root: UnqualifiedNode = init_pydough_context(graph)(pydough_code)()
actual_sql: str = to_sql(
- root, metadata=graph, database=empty_context_database
+ root, columns=columns, metadata=graph, database=empty_context_database
).strip()
file_path: str = get_sql_test_filename(test_name, empty_context_database.dialect)
if update_tests:
diff --git a/tests/test_qdag_collection.py b/tests/test_qdag_collection.py
index 931ee9d8..86fb5346 100644
--- a/tests/test_qdag_collection.py
+++ b/tests/test_qdag_collection.py
@@ -4,9 +4,8 @@
import pytest
from test_utils import (
- BackReferenceCollectionInfo,
BackReferenceExpressionInfo,
- CalcInfo,
+ CalculateInfo,
ChildReferenceCollectionInfo,
ChildReferenceExpressionInfo,
CollectionTestInfo,
@@ -31,20 +30,20 @@
@pytest.fixture
-def region_intra_ratio() -> tuple[CollectionTestInfo, str, str]:
+def region_intra_pct() -> tuple[CollectionTestInfo, str, str]:
"""
- The QDAG node info for a query that calculates the ratio for each region
- between the of all part sale values (retail price of the part times the
- quantity purchased for each time it was purchased) that were sold to a
- customer in the same region vs all part sale values from that region, only
- counting sales where the shipping mode was 'AIR'.
+ The QDAG node info for a query that calculates, for each region, the
+ percentage of all part sale values (retail price of the part times the
+ total quantity purchased for each time it was purchased) that were sold to
+ a customer in the same region vs all part sale values from that region,
+ only counting sales where the shipping mode was 'AIR'.
Equivalent SQL query:
```
SELECT
R1.name AS region_name
- SUM(P.p_retailprice * IFF(R1.region_name == R2.region_name, L.quantity, 0)) /
- SUM(P.p_retailprice * L.quantity) AS intra_ratio
+ 100.0 * SUM(P.p_retailprice * (R1.region_name == R2.region_name)) /
+ SUM(P.p_retailprice * L.quantity) AS intra_pct
FROM
REGION R1,
NATION N1,
@@ -70,100 +69,117 @@ def region_intra_ratio() -> tuple[CollectionTestInfo, str, str]:
Equivalent PyDough code:
```
- is_intra_line = IFF(order.customer.region.name == BACK(3).name, 1, 0)
- part_sales = suppliers.parts_supplied.ps_lines.WHERE(
+ part_sales = nations.suppliers.supply_records.CALCULATE(
+ retail_price=part.retail_price
+ ).lines.WHERE(
shipmode == 'AIR
- )(
- intra_value = BACK(1).retail_price * quantity * is_intra_line,
- total_value = BACK(1).retail_price * quantity,
+ ).CALCULATE(
+ is_intra = order.customer.region.name == region_name,
+ value = retail_price * quantity,
)
- Regions(
- region_name = name
- intra_sales = SUM(part_sales.intra_value) / SUM(part_sales.total_value)
+ result = Regions.CALCULATE(region_name = name).CALCULATE(
+ region_name,
+ intra_pct = 100.0 * SUM(part_sales.value * part_sales.is_intra) / SUM(part_sales.value)
)
```
"""
- test_info: CollectionTestInfo = TableCollectionInfo("Regions") ** CalcInfo(
- [
- SubCollectionInfo("suppliers")
- ** SubCollectionInfo("parts_supplied")
- ** SubCollectionInfo("ps_lines")
- ** WhereInfo(
- [],
- FunctionInfo(
- "EQU",
- [ReferenceInfo("ship_mode"), LiteralInfo("AIR", StringType())],
- ),
- )
- ** CalcInfo(
- [
- SubCollectionInfo("order")
- ** SubCollectionInfo("customer")
- ** SubCollectionInfo("region")
- ],
- value=FunctionInfo(
- "MUL",
- [
- BackReferenceExpressionInfo("retail_price", 1),
- ReferenceInfo("quantity"),
- ],
- ),
- adj_value=FunctionInfo(
- "MUL",
+ test_info: CollectionTestInfo = (
+ TableCollectionInfo("Regions")
+ ** CalculateInfo([], region_name=ReferenceInfo("name"))
+ ** CalculateInfo(
+ [
+ SubCollectionInfo("nations")
+ ** SubCollectionInfo("suppliers")
+ ** SubCollectionInfo("supply_records")
+ ** CalculateInfo(
+ [SubCollectionInfo("part")],
+ retail_price=ChildReferenceExpressionInfo("retail_price", 0),
+ )
+ ** SubCollectionInfo("lines")
+ ** WhereInfo(
+ [],
+ FunctionInfo(
+ "EQU",
+ [ReferenceInfo("ship_mode"), LiteralInfo("AIR", StringType())],
+ ),
+ )
+ ** CalculateInfo(
[
- BackReferenceExpressionInfo("retail_price", 1),
- FunctionInfo(
- "MUL",
- [
- ReferenceInfo("quantity"),
- FunctionInfo(
- "IFF",
- [
- FunctionInfo(
- "EQU",
- [
- BackReferenceExpressionInfo("name", 3),
- ChildReferenceExpressionInfo("name", 0),
- ],
- ),
- LiteralInfo(1, Int64Type()),
- LiteralInfo(0, Int64Type()),
- ],
- ),
- ],
- ),
+ SubCollectionInfo("order")
+ ** SubCollectionInfo("customer")
+ ** SubCollectionInfo("region")
],
- ),
- )
- ],
- region_name=ReferenceInfo("name"),
- intra_ratio=FunctionInfo(
- "DIV",
- [
- FunctionInfo("SUM", [ChildReferenceExpressionInfo("adj_value", 0)]),
- FunctionInfo("SUM", [ChildReferenceExpressionInfo("value", 0)]),
+ is_intra=FunctionInfo(
+ "EQU",
+ [
+ ChildReferenceExpressionInfo("name", 0),
+ BackReferenceExpressionInfo("region_name", 4),
+ ],
+ ),
+ value=FunctionInfo(
+ "MUL",
+ [
+ BackReferenceExpressionInfo("retail_price", 1),
+ ReferenceInfo("quantity"),
+ ],
+ ),
+ )
],
- ),
+ region_name=ReferenceInfo("region_name"),
+ intra_pct=FunctionInfo(
+ "MUL",
+ [
+ LiteralInfo(100.0, Float64Type()),
+ FunctionInfo(
+ "DIV",
+ [
+ FunctionInfo(
+ "SUM",
+ [
+ FunctionInfo(
+ "MUL",
+ [
+ ChildReferenceExpressionInfo("value", 0),
+ ChildReferenceExpressionInfo("is_intra", 0),
+ ],
+ )
+ ],
+ ),
+ FunctionInfo(
+ "SUM", [ChildReferenceExpressionInfo("value", 0)]
+ ),
+ ],
+ ),
+ ],
+ ),
+ )
+ )
+ is_intra: str = "order.customer.region.name == region_name"
+ base_value: str = "retail_price * quantity"
+ path_to_lines = "nations.suppliers.supply_records.CALCULATE(retail_price=part.retail_price).lines.WHERE(ship_mode == 'AIR')"
+ part_values: str = (
+ f"{path_to_lines}.CALCULATE(is_intra={is_intra}, value={base_value})"
)
- adjustment: str = "IFF(BACK(3).name == order.customer.region.name, 1, 0)"
- base_value: str = "BACK(1).retail_price * quantity"
- adjusted_value: str = f"BACK(1).retail_price * (quantity * {adjustment})"
- part_values: str = f"suppliers.parts_supplied.ps_lines.WHERE(ship_mode == 'AIR')(value={base_value}, adj_value={adjusted_value})"
- string_representation: str = f"TPCH.Regions(region_name=name, intra_ratio=SUM({part_values}.adj_value) / SUM({part_values}.value))"
+ string_representation: str = f"TPCH.Regions.CALCULATE(region_name=name).CALCULATE(region_name=region_name, intra_pct=100.0 * (SUM({part_values}.value * {part_values}.is_intra) / SUM({part_values}.value)))"
tree_string_representation: str = """
──┬─ TPCH
├─── TableCollection[Regions]
- └─┬─ Calc[region_name=name, intra_ratio=SUM($1.adj_value) / SUM($1.value)]
+ ├─── Calculate[region_name=name]
+ └─┬─ Calculate[region_name=region_name, intra_pct=100.0 * (SUM($1.value * $1.is_intra) / SUM($1.value))]
└─┬─ AccessChild
- └─┬─ SubCollection[suppliers]
- └─┬─ SubCollection[parts_supplied]
- ├─── SubCollection[ps_lines]
- ├─── Where[ship_mode == 'AIR']
- └─┬─ Calc[value=BACK(1).retail_price * quantity, adj_value=BACK(1).retail_price * (quantity * IFF(BACK(3).name == $1.name, 1, 0))]
- └─┬─ AccessChild
- └─┬─ SubCollection[order]
- └─┬─ SubCollection[customer]
- └─── SubCollection[region]
+ └─┬─ SubCollection[nations]
+ └─┬─ SubCollection[suppliers]
+ ├─── SubCollection[supply_records]
+ └─┬─ Calculate[retail_price=$1.retail_price]
+ ├─┬─ AccessChild
+ │ └─── SubCollection[part]
+ ├─── SubCollection[lines]
+ ├─── Where[ship_mode == 'AIR']
+ └─┬─ Calculate[is_intra=$1.name == region_name, value=retail_price * quantity]
+ └─┬─ AccessChild
+ └─┬─ SubCollection[order]
+ └─┬─ SubCollection[customer]
+ └─── SubCollection[region]
"""
return test_info, string_representation, tree_string_representation
@@ -172,7 +188,9 @@ def region_intra_ratio() -> tuple[CollectionTestInfo, str, str]:
"calc_pipeline, expected_calcs, expected_total_names",
[
pytest.param(
- CalcInfo([], x=LiteralInfo(1, Int64Type()), y=LiteralInfo(3, Int64Type())),
+ CalculateInfo(
+ [], x=LiteralInfo(1, Int64Type()), y=LiteralInfo(3, Int64Type())
+ ),
{"x": 0, "y": 1},
{
"x",
@@ -189,12 +207,12 @@ def region_intra_ratio() -> tuple[CollectionTestInfo, str, str]:
id="global_calc",
),
pytest.param(
- CalcInfo(
+ CalculateInfo(
[
TableCollectionInfo("Suppliers"),
TableCollectionInfo("Parts")
** SubCollectionInfo("lines")
- ** CalcInfo(
+ ** CalculateInfo(
[],
value=FunctionInfo(
"MUL", [ReferenceInfo("quantity"), ReferenceInfo("tax")]
@@ -253,7 +271,7 @@ def region_intra_ratio() -> tuple[CollectionTestInfo, str, str]:
),
pytest.param(
TableCollectionInfo("Regions")
- ** CalcInfo(
+ ** CalculateInfo(
[],
),
{},
@@ -272,7 +290,7 @@ def region_intra_ratio() -> tuple[CollectionTestInfo, str, str]:
pytest.param(
(
TableCollectionInfo("Regions")
- ** CalcInfo(
+ ** CalculateInfo(
[], foo=LiteralInfo(42, Int64Type()), bar=ReferenceInfo("name")
)
),
@@ -294,13 +312,15 @@ def region_intra_ratio() -> tuple[CollectionTestInfo, str, str]:
pytest.param(
(
TableCollectionInfo("Regions")
- ** CalcInfo(
+ ** CalculateInfo(
[], foo=LiteralInfo(42, Int64Type()), bar=ReferenceInfo("name")
)
** SubCollectionInfo("nations")
),
{"key": 0, "name": 1, "region_key": 2, "comment": 3},
{
+ "foo",
+ "bar",
"name",
"key",
"region_key",
@@ -316,7 +336,7 @@ def region_intra_ratio() -> tuple[CollectionTestInfo, str, str]:
(
TableCollectionInfo("Regions")
** SubCollectionInfo("nations")
- ** CalcInfo(
+ ** CalculateInfo(
[], foo=LiteralInfo(42, Int64Type()), bar=ReferenceInfo("name")
)
),
@@ -338,10 +358,10 @@ def region_intra_ratio() -> tuple[CollectionTestInfo, str, str]:
pytest.param(
(
TableCollectionInfo("Regions")
- ** CalcInfo(
+ ** CalculateInfo(
[], foo=LiteralInfo(42, Int64Type()), bar=ReferenceInfo("name")
)
- ** CalcInfo(
+ ** CalculateInfo(
[],
fizz=FunctionInfo(
"ADD", [ReferenceInfo("foo"), LiteralInfo(1, Int64Type())]
@@ -405,7 +425,7 @@ def region_intra_ratio() -> tuple[CollectionTestInfo, str, str]:
(
TableCollectionInfo("Parts")
** SubCollectionInfo("suppliers_of_part")
- ** CalcInfo(
+ ** CalculateInfo(
[],
good_comment=FunctionInfo(
"EQU",
@@ -543,7 +563,7 @@ def region_intra_ratio() -> tuple[CollectionTestInfo, str, str]:
"NEQ", [ReferenceInfo("name"), LiteralInfo("USA", StringType())]
),
)
- ** CalcInfo(
+ ** CalculateInfo(
[],
region_name=BackReferenceExpressionInfo("name", 1),
nation_name=ReferenceInfo("name"),
@@ -571,7 +591,7 @@ def region_intra_ratio() -> tuple[CollectionTestInfo, str, str]:
** SubCollectionInfo("suppliers")
** SubCollectionInfo("supply_records")
** SubCollectionInfo("lines")
- ** CalcInfo(
+ ** CalculateInfo(
[],
region_name=BackReferenceExpressionInfo("name", 4),
nation_name=BackReferenceExpressionInfo("name", 3),
@@ -620,7 +640,7 @@ def region_intra_ratio() -> tuple[CollectionTestInfo, str, str]:
(
TableCollectionInfo("Regions")
** SubCollectionInfo("lines_sourced_from")
- ** CalcInfo(
+ ** CalculateInfo(
[],
source_region_name=BackReferenceExpressionInfo("name", 1),
taxation=ReferenceInfo("tax"),
@@ -666,7 +686,7 @@ def region_intra_ratio() -> tuple[CollectionTestInfo, str, str]:
"parts",
[ChildReferenceExpressionInfo("container", 0)],
)
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("parts")],
container=ReferenceInfo("container"),
total_price=FunctionInfo(
@@ -684,7 +704,7 @@ def region_intra_ratio() -> tuple[CollectionTestInfo, str, str]:
"parts",
[ChildReferenceExpressionInfo("container", 0)],
)
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("parts")],
container=ReferenceInfo("container"),
total_price=FunctionInfo(
@@ -692,7 +712,7 @@ def region_intra_ratio() -> tuple[CollectionTestInfo, str, str]:
),
)
** SubCollectionInfo("parts")
- ** CalcInfo(
+ ** CalculateInfo(
[],
part_name=ReferenceInfo("name"),
container=ReferenceInfo("container"),
@@ -720,6 +740,7 @@ def region_intra_ratio() -> tuple[CollectionTestInfo, str, str]:
"suppliers_of_part",
"supply_records",
"part_type",
+ "total_price",
},
id="partition_data_with_data_order",
),
@@ -733,18 +754,18 @@ def test_collections_calc_terms(
) -> None:
"""
Tests that a sequence of collection-producing QDAG nodes results in the
- correct calc terms & total set of available terms.
+ correct calculate terms & total set of available terms.
"""
collection: PyDoughCollectionQDAG = calc_pipeline.build(tpch_node_builder)
assert collection.calc_terms == set(
expected_calcs
- ), "Mismatch between set of calc terms and expected value"
+ ), "Mismatch between set of calculate terms and expected value"
actual_calcs: dict[str, int] = {
expr: collection.get_expression_position(expr) for expr in collection.calc_terms
}
assert (
actual_calcs == expected_calcs
- ), "Mismatch between positions of calc terms and expected value"
+ ), "Mismatch between positions of calculate terms and expected value"
assert (
collection.all_terms == expected_total_names
), "Mismatch between set of all terms and expected value"
@@ -754,21 +775,23 @@ def test_collections_calc_terms(
"calc_pipeline, expected_string, expected_tree_string",
[
pytest.param(
- CalcInfo([], x=LiteralInfo(1, Int64Type()), y=LiteralInfo(3, Int64Type())),
- "TPCH(x=1, y=3)",
+ CalculateInfo(
+ [], x=LiteralInfo(1, Int64Type()), y=LiteralInfo(3, Int64Type())
+ ),
+ "TPCH.CALCULATE(x=1, y=3)",
"""
┌─── TPCH
-└─── Calc[x=1, y=3]
+└─── Calculate[x=1, y=3]
""",
id="global_calc",
),
pytest.param(
- CalcInfo(
+ CalculateInfo(
[
TableCollectionInfo("Suppliers"),
TableCollectionInfo("Parts")
** SubCollectionInfo("lines")
- ** CalcInfo(
+ ** CalculateInfo(
[],
value=FunctionInfo(
"MUL", [ReferenceInfo("quantity"), ReferenceInfo("tax")]
@@ -780,16 +803,16 @@ def test_collections_calc_terms(
),
t_value=FunctionInfo("SUM", [ChildReferenceExpressionInfo("value", 1)]),
),
- "TPCH(n_balance=SUM(Suppliers.account_balance), t_value=SUM(Parts.lines(value=quantity * tax).value))",
+ "TPCH.CALCULATE(n_balance=SUM(Suppliers.account_balance), t_value=SUM(Parts.lines.CALCULATE(value=quantity * tax).value))",
"""
┌─── TPCH
-└─┬─ Calc[n_balance=SUM($1.account_balance), t_value=SUM($2.value)]
+└─┬─ Calculate[n_balance=SUM($1.account_balance), t_value=SUM($2.value)]
├─┬─ AccessChild
│ └─── TableCollection[Suppliers]
└─┬─ AccessChild
└─┬─ TableCollection[Parts]
├─── SubCollection[lines]
- └─── Calc[value=quantity * tax]
+ └─── Calculate[value=quantity * tax]
""",
id="global_nested_calc",
),
@@ -824,20 +847,20 @@ def test_collections_calc_terms(
),
pytest.param(
TableCollectionInfo("Regions")
- ** CalcInfo(
+ ** CalculateInfo(
[],
),
- "TPCH.Regions()",
+ "TPCH.Regions.CALCULATE()",
"""
──┬─ TPCH
├─── TableCollection[Regions]
- └─── Calc[]
+ └─── Calculate[]
""",
id="regions_empty_calc",
),
pytest.param(
TableCollectionInfo("Regions")
- ** CalcInfo(
+ ** CalculateInfo(
[],
region_name=ReferenceInfo("name"),
adjusted_key=FunctionInfo(
@@ -850,11 +873,11 @@ def test_collections_calc_terms(
],
),
),
- "TPCH.Regions(region_name=name, adjusted_key=(key - 1) * 2)",
+ "TPCH.Regions.CALCULATE(region_name=name, adjusted_key=(key - 1) * 2)",
"""
──┬─ TPCH
├─── TableCollection[Regions]
- └─── Calc[region_name=name, adjusted_key=(key - 1) * 2]
+ └─── Calculate[region_name=name, adjusted_key=(key - 1) * 2]
""",
id="regions_calc",
),
@@ -873,67 +896,54 @@ def test_collections_calc_terms(
"NEQ", [ReferenceInfo("name"), LiteralInfo("USA", StringType())]
),
)
- ** CalcInfo(
+ ** CalculateInfo(
[],
region_name=BackReferenceExpressionInfo("name", 1),
nation_name=ReferenceInfo("name"),
),
- "TPCH.Regions.WHERE(name != 'ASIA').nations.WHERE(name != 'USA')(region_name=BACK(1).name, nation_name=name)",
+ "TPCH.Regions.WHERE(name != 'ASIA').nations.WHERE(name != 'USA').CALCULATE(region_name=name, nation_name=name)",
"""
──┬─ TPCH
├─── TableCollection[Regions]
└─┬─ Where[name != 'ASIA']
├─── SubCollection[nations]
├─── Where[name != 'USA']
- └─── Calc[region_name=BACK(1).name, nation_name=name]
+ └─── Calculate[region_name=name, nation_name=name]
""",
id="regions_nations_calc",
),
pytest.param(
TableCollectionInfo("Regions")
** SubCollectionInfo("suppliers")
- ** CalcInfo(
+ ** CalculateInfo(
[],
region_name=BackReferenceExpressionInfo("name", 1),
nation_name=ReferenceInfo("nation_name"),
supplier_name=ReferenceInfo("name"),
),
- "TPCH.Regions.suppliers(region_name=BACK(1).name, nation_name=nation_name, supplier_name=name)",
+ "TPCH.Regions.suppliers.CALCULATE(region_name=name, nation_name=nation_name, supplier_name=name)",
"""
──┬─ TPCH
└─┬─ TableCollection[Regions]
├─── SubCollection[suppliers]
- └─── Calc[region_name=BACK(1).name, nation_name=nation_name, supplier_name=name]
+ └─── Calculate[region_name=name, nation_name=nation_name, supplier_name=name]
""",
id="regions_suppliers_calc",
),
- pytest.param(
- TableCollectionInfo("Parts")
- ** SubCollectionInfo("suppliers_of_part")
- ** SubCollectionInfo("ps_lines"),
- "TPCH.Parts.suppliers_of_part.ps_lines",
- """
-──┬─ TPCH
- └─┬─ TableCollection[Parts]
- └─┬─ SubCollection[suppliers_of_part]
- └─── SubCollection[ps_lines]
-""",
- id="parts_suppliers_lines",
- ),
pytest.param(
TableCollectionInfo("Nations")
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("suppliers")],
nation_name=ReferenceInfo("name"),
total_supplier_balances=FunctionInfo(
"SUM", [ChildReferenceExpressionInfo("account_balance", 0)]
),
),
- "TPCH.Nations(nation_name=name, total_supplier_balances=SUM(suppliers.account_balance))",
+ "TPCH.Nations.CALCULATE(nation_name=name, total_supplier_balances=SUM(suppliers.account_balance))",
"""
──┬─ TPCH
├─── TableCollection[Nations]
- └─┬─ Calc[nation_name=name, total_supplier_balances=SUM($1.account_balance)]
+ └─┬─ Calculate[nation_name=name, total_supplier_balances=SUM($1.account_balance)]
└─┬─ AccessChild
└─── SubCollection[suppliers]
""",
@@ -941,26 +951,28 @@ def test_collections_calc_terms(
),
pytest.param(
TableCollectionInfo("Regions")
- ** CalcInfo([], adj_name=FunctionInfo("LOWER", [ReferenceInfo("name")]))
+ ** CalculateInfo(
+ [], adj_name=FunctionInfo("LOWER", [ReferenceInfo("name")])
+ )
** SubCollectionInfo("nations")
- ** CalcInfo(
+ ** CalculateInfo(
[],
region_name=BackReferenceExpressionInfo("adj_name", 1),
nation_name=ReferenceInfo("name"),
),
- "TPCH.Regions(adj_name=LOWER(name)).nations(region_name=BACK(1).adj_name, nation_name=name)",
+ "TPCH.Regions.CALCULATE(adj_name=LOWER(name)).nations.CALCULATE(region_name=adj_name, nation_name=name)",
"""
──┬─ TPCH
├─── TableCollection[Regions]
- └─┬─ Calc[adj_name=LOWER(name)]
+ └─┬─ Calculate[adj_name=LOWER(name)]
├─── SubCollection[nations]
- └─── Calc[region_name=BACK(1).adj_name, nation_name=name]
+ └─── Calculate[region_name=adj_name, nation_name=name]
""",
id="regions_calc_nations_calc",
),
pytest.param(
TableCollectionInfo("Suppliers")
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("parts_supplied")],
supplier_name=ReferenceInfo("name"),
total_retail_price=FunctionInfo(
@@ -976,11 +988,11 @@ def test_collections_calc_terms(
],
),
),
- "TPCH.Suppliers(supplier_name=name, total_retail_price=SUM(parts_supplied.retail_price - 1.0))",
+ "TPCH.Suppliers.CALCULATE(supplier_name=name, total_retail_price=SUM(parts_supplied.retail_price - 1.0))",
"""
──┬─ TPCH
├─── TableCollection[Suppliers]
- └─┬─ Calc[supplier_name=name, total_retail_price=SUM($1.retail_price - 1.0)]
+ └─┬─ Calculate[supplier_name=name, total_retail_price=SUM($1.retail_price - 1.0)]
└─┬─ AccessChild
└─── SubCollection[parts_supplied]
""",
@@ -988,10 +1000,10 @@ def test_collections_calc_terms(
),
pytest.param(
TableCollectionInfo("Suppliers")
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("parts_supplied")
- ** CalcInfo(
+ ** CalculateInfo(
[],
adj_retail_price=FunctionInfo(
"SUB",
@@ -1007,115 +1019,77 @@ def test_collections_calc_terms(
"SUM", [ChildReferenceExpressionInfo("adj_retail_price", 0)]
),
),
- "TPCH.Suppliers(supplier_name=name, total_retail_price=SUM(parts_supplied(adj_retail_price=retail_price - 1.0).adj_retail_price))",
+ "TPCH.Suppliers.CALCULATE(supplier_name=name, total_retail_price=SUM(parts_supplied.CALCULATE(adj_retail_price=retail_price - 1.0).adj_retail_price))",
"""
──┬─ TPCH
├─── TableCollection[Suppliers]
- └─┬─ Calc[supplier_name=name, total_retail_price=SUM($1.adj_retail_price)]
+ └─┬─ Calculate[supplier_name=name, total_retail_price=SUM($1.adj_retail_price)]
└─┬─ AccessChild
├─── SubCollection[parts_supplied]
- └─── Calc[adj_retail_price=retail_price - 1.0]
+ └─── Calculate[adj_retail_price=retail_price - 1.0]
""",
id="suppliers_childcalc_parts_b",
),
pytest.param(
TableCollectionInfo("Suppliers")
- ** SubCollectionInfo("parts_supplied")
- ** CalcInfo(
- [
- SubCollectionInfo("ps_lines"),
- BackReferenceCollectionInfo("nation", 1)
- ** CalcInfo([], nation_name=ReferenceInfo("name")),
- ],
- nation_name=ChildReferenceExpressionInfo("nation_name", 1),
- supplier_name=BackReferenceExpressionInfo("name", 1),
- part_name=ReferenceInfo("name"),
+ ** SubCollectionInfo("supply_records")
+ ** CalculateInfo(
+ [SubCollectionInfo("lines")],
ratio=FunctionInfo(
"DIV",
[
FunctionInfo(
"SUM", [ChildReferenceExpressionInfo("quantity", 0)]
),
- ReferenceInfo("ps_availqty"),
+ ReferenceInfo("availqty"),
],
),
- ),
- "TPCH.Suppliers.parts_supplied(nation_name=BACK(1).nation(nation_name=name).nation_name, supplier_name=BACK(1).name, part_name=name, ratio=SUM(ps_lines.quantity) / ps_availqty)",
- """
-──┬─ TPCH
- └─┬─ TableCollection[Suppliers]
- ├─── SubCollection[parts_supplied]
- └─┬─ Calc[nation_name=$2.nation_name, supplier_name=BACK(1).name, part_name=name, ratio=SUM($1.quantity) / ps_availqty]
- ├─┬─ AccessChild
- │ └─── SubCollection[ps_lines]
- └─┬─ AccessChild
- ├─── BackSubCollection[1, nation]
- └─── Calc[nation_name=name]
-""",
- id="suppliers_parts_childcalc_a",
- ),
- pytest.param(
- TableCollectionInfo("Suppliers")
- ** SubCollectionInfo("parts_supplied")
- ** CalcInfo(
- [
- SubCollectionInfo("ps_lines")
- ** CalcInfo(
- [],
- ratio=FunctionInfo(
- "DIV",
- [
- ReferenceInfo("quantity"),
- BackReferenceExpressionInfo("ps_availqty", 1),
- ],
- ),
- ),
- BackReferenceCollectionInfo("nation", 1),
- ],
- nation_name=ChildReferenceExpressionInfo("name", 1),
- supplier_name=BackReferenceExpressionInfo("name", 1),
+ )
+ ** SubCollectionInfo("part")
+ ** CalculateInfo(
+ [],
+ supplier_name=BackReferenceExpressionInfo("name", 2),
part_name=ReferenceInfo("name"),
- ratio=FunctionInfo("MAX", [ChildReferenceExpressionInfo("ratio", 0)]),
+ ratio=BackReferenceExpressionInfo("ratio", 1),
),
- "TPCH.Suppliers.parts_supplied(nation_name=BACK(1).nation.name, supplier_name=BACK(1).name, part_name=name, ratio=MAX(ps_lines(ratio=quantity / BACK(1).ps_availqty).ratio))",
+ "TPCH.Suppliers.supply_records.CALCULATE(ratio=SUM(lines.quantity) / availqty).part.CALCULATE(supplier_name=name, part_name=name, ratio=ratio)",
"""
──┬─ TPCH
└─┬─ TableCollection[Suppliers]
- ├─── SubCollection[parts_supplied]
- └─┬─ Calc[nation_name=$2.name, supplier_name=BACK(1).name, part_name=name, ratio=MAX($1.ratio)]
+ ├─── SubCollection[supply_records]
+ └─┬─ Calculate[ratio=SUM($1.quantity) / availqty]
├─┬─ AccessChild
- │ ├─── SubCollection[ps_lines]
- │ └─── Calc[ratio=quantity / BACK(1).ps_availqty]
- └─┬─ AccessChild
- └─── BackSubCollection[1, nation]
+ │ └─── SubCollection[lines]
+ ├─── SubCollection[part]
+ └─── Calculate[supplier_name=name, part_name=name, ratio=ratio]
""",
- id="suppliers_parts_childcalc_b",
+ id="suppliers_parts_childcalc",
),
pytest.param(
(
- CalcInfo(
+ CalculateInfo(
[TableCollectionInfo("Customers")],
total_balance=FunctionInfo(
"SUM", [ChildReferenceExpressionInfo("acctbal", 0)]
),
)
),
- "TPCH(total_balance=SUM(Customers.acctbal))",
+ "TPCH.CALCULATE(total_balance=SUM(Customers.acctbal))",
"""
┌─── TPCH
-└─┬─ Calc[total_balance=SUM($1.acctbal)]
+└─┬─ Calculate[total_balance=SUM($1.acctbal)]
└─┬─ AccessChild
└─── TableCollection[Customers]
""",
id="globalcalc_a",
),
pytest.param(
- CalcInfo(
+ CalculateInfo(
[
TableCollectionInfo("Customers"),
TableCollectionInfo("Suppliers")
** SubCollectionInfo("parts_supplied")
- ** CalcInfo(
+ ** CalculateInfo(
[],
value=FunctionInfo(
"MUL",
@@ -1133,16 +1107,16 @@ def test_collections_calc_terms(
"SUM", [ChildReferenceExpressionInfo("value", 1)]
),
),
- "TPCH(total_demand=SUM(Customers.acctbal), total_supply=SUM(Suppliers.parts_supplied(value=ps_availqty * retail_price).value))",
+ "TPCH.CALCULATE(total_demand=SUM(Customers.acctbal), total_supply=SUM(Suppliers.parts_supplied.CALCULATE(value=ps_availqty * retail_price).value))",
"""
┌─── TPCH
-└─┬─ Calc[total_demand=SUM($1.acctbal), total_supply=SUM($2.value)]
+└─┬─ Calculate[total_demand=SUM($1.acctbal), total_supply=SUM($2.value)]
├─┬─ AccessChild
│ └─── TableCollection[Customers]
└─┬─ AccessChild
└─┬─ TableCollection[Suppliers]
├─── SubCollection[parts_supplied]
- └─── Calc[value=ps_availqty * retail_price]
+ └─── Calculate[value=ps_availqty * retail_price]
""",
id="globalcalc_b",
),
@@ -1159,16 +1133,16 @@ def test_collections_calc_terms(
),
pytest.param(
TableCollectionInfo("Nations")
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("customers")],
nation_name=ReferenceInfo("name"),
n_customers=FunctionInfo("COUNT", [ChildReferenceCollectionInfo(0)]),
),
- "TPCH.Nations(nation_name=name, n_customers=COUNT(customers))",
+ "TPCH.Nations.CALCULATE(nation_name=name, n_customers=COUNT(customers))",
"""
──┬─ TPCH
├─── TableCollection[Nations]
- └─┬─ Calc[nation_name=name, n_customers=COUNT($1)]
+ └─┬─ Calculate[nation_name=name, n_customers=COUNT($1)]
└─┬─ AccessChild
└─── SubCollection[customers]
""",
@@ -1176,7 +1150,7 @@ def test_collections_calc_terms(
),
pytest.param(
TableCollectionInfo("Nations")
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("customers"),
SubCollectionInfo("customers")
@@ -1213,11 +1187,11 @@ def test_collections_calc_terms(
"NDISTINCT", [ChildReferenceCollectionInfo(3)]
),
),
- "TPCH.Nations(nation_name=name, n_customers=COUNT(customers), n_customers_without_orders=COUNT(customers.WHERE(COUNT(orders) == 0)), n_lines_with_tax=COUNT(customers.orders.lines.tax), n_part_orders=COUNT(customers.orders.lines.part), n_unique_parts_ordered=NDISTINCT(customers.orders.lines.part))",
+ "TPCH.Nations.CALCULATE(nation_name=name, n_customers=COUNT(customers), n_customers_without_orders=COUNT(customers.WHERE(COUNT(orders) == 0)), n_lines_with_tax=COUNT(customers.orders.lines.tax), n_part_orders=COUNT(customers.orders.lines.part), n_unique_parts_ordered=NDISTINCT(customers.orders.lines.part))",
"""
──┬─ TPCH
├─── TableCollection[Nations]
- └─┬─ Calc[nation_name=name, n_customers=COUNT($1), n_customers_without_orders=COUNT($2), n_lines_with_tax=COUNT($3.tax), n_part_orders=COUNT($4), n_unique_parts_ordered=NDISTINCT($4)]
+ └─┬─ Calculate[nation_name=name, n_customers=COUNT($1), n_customers_without_orders=COUNT($2), n_lines_with_tax=COUNT($3.tax), n_part_orders=COUNT($4), n_unique_parts_ordered=NDISTINCT($4)]
├─┬─ AccessChild
│ └─── SubCollection[customers]
├─┬─ AccessChild
@@ -1290,7 +1264,7 @@ def test_collections_calc_terms(
"GRT", [ReferenceInfo("acctbal"), LiteralInfo(1000, Int64Type())]
),
)
- ** CalcInfo([], region_name=BackReferenceExpressionInfo("name", 1))
+ ** CalculateInfo([], region_name=BackReferenceExpressionInfo("name", 1))
** OrderInfo([], (ReferenceInfo("region_name"), True, True))
** WhereInfo(
[],
@@ -1299,7 +1273,7 @@ def test_collections_calc_terms(
[ReferenceInfo("region_name"), LiteralInfo("ASIA", StringType())],
),
),
- "TPCH.Regions.ORDER_BY(name.ASC(na_pos='last')).customers.ORDER_BY(key.ASC(na_pos='last')).WHERE(acctbal > 1000)(region_name=BACK(1).name).ORDER_BY(region_name.ASC(na_pos='last')).WHERE(region_name != 'ASIA')",
+ "TPCH.Regions.ORDER_BY(name.ASC(na_pos='last')).customers.ORDER_BY(key.ASC(na_pos='last')).WHERE(acctbal > 1000).CALCULATE(region_name=name).ORDER_BY(region_name.ASC(na_pos='last')).WHERE(region_name != 'ASIA')",
"""
──┬─ TPCH
├─── TableCollection[Regions]
@@ -1307,7 +1281,7 @@ def test_collections_calc_terms(
├─── SubCollection[customers]
├─── OrderBy[key.ASC(na_pos='last')]
├─── Where[acctbal > 1000]
- ├─── Calc[region_name=BACK(1).name]
+ ├─── Calculate[region_name=name]
├─── OrderBy[region_name.ASC(na_pos='last')]
└─── Where[region_name != 'ASIA']
""",
@@ -1319,20 +1293,20 @@ def test_collections_calc_terms(
"parts",
[ChildReferenceExpressionInfo("container", 0)],
)
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("parts")],
container=ReferenceInfo("container"),
total_price=FunctionInfo(
"SUM", [ChildReferenceExpressionInfo("retail_price", 0)]
),
),
- "TPCH.Partition(Parts, name='parts', by=container)(container=container, total_price=SUM(parts.retail_price))",
+ "TPCH.Partition(Parts, name='parts', by=container).CALCULATE(container=container, total_price=SUM(parts.retail_price))",
"""
──┬─ TPCH
├─┬─ Partition[name='parts', by=container]
│ └─┬─ AccessChild
│ └─── TableCollection[Parts]
- └─┬─ Calc[container=container, total_price=SUM($1.retail_price)]
+ └─┬─ Calculate[container=container, total_price=SUM($1.retail_price)]
└─┬─ AccessChild
└─── PartitionChild[parts]
""",
@@ -1347,7 +1321,7 @@ def test_collections_calc_terms(
"EQU", [ReferenceInfo("tax"), LiteralInfo(0, Int64Type())]
),
)
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("order")
** SubCollectionInfo("shipping_region"),
@@ -1362,7 +1336,7 @@ def test_collections_calc_terms(
ChildReferenceExpressionInfo("part_type", 0),
],
)
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("lines")],
region_name=ReferenceInfo("region_name"),
part_type=ReferenceInfo("part_type"),
@@ -1370,20 +1344,20 @@ def test_collections_calc_terms(
"SUM", [ChildReferenceExpressionInfo("extended_price", 0)]
),
),
- "TPCH.Partition(Lineitems.WHERE(tax == 0)(region_name=order.shipping_region.name, part_type=part.part_type), name='lines', by=('region_name', 'part_type'))(region_name=region_name, part_type=part_type, total_price=SUM(lines.extended_price))",
+ "TPCH.Partition(Lineitems.WHERE(tax == 0).CALCULATE(region_name=order.shipping_region.name, part_type=part.part_type), name='lines', by=('region_name', 'part_type')).CALCULATE(region_name=region_name, part_type=part_type, total_price=SUM(lines.extended_price))",
"""
──┬─ TPCH
├─┬─ Partition[name='lines', by=(region_name, part_type)]
│ └─┬─ AccessChild
│ ├─── TableCollection[Lineitems]
│ ├─── Where[tax == 0]
- │ └─┬─ Calc[region_name=$1.name, part_type=$2.part_type]
+ │ └─┬─ Calculate[region_name=$1.name, part_type=$2.part_type]
│ ├─┬─ AccessChild
│ │ └─┬─ SubCollection[order]
│ │ └─── SubCollection[shipping_region]
│ └─┬─ AccessChild
│ └─── SubCollection[part]
- └─┬─ Calc[region_name=region_name, part_type=part_type, total_price=SUM($1.extended_price)]
+ └─┬─ Calculate[region_name=region_name, part_type=part_type, total_price=SUM($1.extended_price)]
└─┬─ AccessChild
└─── PartitionChild[lines]
""",
@@ -1395,7 +1369,7 @@ def test_collections_calc_terms(
"parts",
[ChildReferenceExpressionInfo("container", 0)],
)
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("parts")],
container=ReferenceInfo("container"),
total_price=FunctionInfo(
@@ -1403,13 +1377,13 @@ def test_collections_calc_terms(
),
)
** OrderInfo([], (ReferenceInfo("total_price"), False, True)),
- "TPCH.Partition(Parts, name='parts', by=container)(container=container, total_price=SUM(parts.retail_price)).ORDER_BY(total_price.DESC(na_pos='last'))",
+ "TPCH.Partition(Parts, name='parts', by=container).CALCULATE(container=container, total_price=SUM(parts.retail_price)).ORDER_BY(total_price.DESC(na_pos='last'))",
"""
──┬─ TPCH
├─┬─ Partition[name='parts', by=container]
│ └─┬─ AccessChild
│ └─── TableCollection[Parts]
- ├─┬─ Calc[container=container, total_price=SUM($1.retail_price)]
+ ├─┬─ Calculate[container=container, total_price=SUM($1.retail_price)]
│ └─┬─ AccessChild
│ └─── PartitionChild[parts]
└─── OrderBy[total_price.DESC(na_pos='last')]
@@ -1423,7 +1397,7 @@ def test_collections_calc_terms(
"parts",
[ChildReferenceExpressionInfo("container", 0)],
)
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("parts")],
container=ReferenceInfo("container"),
total_price=FunctionInfo(
@@ -1431,7 +1405,7 @@ def test_collections_calc_terms(
),
)
** SubCollectionInfo("parts")
- ** CalcInfo(
+ ** CalculateInfo(
[],
part_name=ReferenceInfo("name"),
container=ReferenceInfo("container"),
@@ -1443,35 +1417,35 @@ def test_collections_calc_terms(
],
),
),
- "TPCH.Partition(Parts.ORDER_BY(retail_price.DESC(na_pos='last')), name='parts', by=container)(container=container, total_price=SUM(parts.retail_price)).parts(part_name=name, container=container, ratio=retail_price / BACK(1).total_price)",
+ "TPCH.Partition(Parts.ORDER_BY(retail_price.DESC(na_pos='last')), name='parts', by=container).CALCULATE(container=container, total_price=SUM(parts.retail_price)).parts.CALCULATE(part_name=name, container=container, ratio=retail_price / total_price)",
"""
──┬─ TPCH
├─┬─ Partition[name='parts', by=container]
│ └─┬─ AccessChild
│ ├─── TableCollection[Parts]
│ └─── OrderBy[retail_price.DESC(na_pos='last')]
- └─┬─ Calc[container=container, total_price=SUM($1.retail_price)]
+ └─┬─ Calculate[container=container, total_price=SUM($1.retail_price)]
├─┬─ AccessChild
│ └─── PartitionChild[parts]
├─── PartitionChild[parts]
- └─── Calc[part_name=name, container=container, ratio=retail_price / BACK(1).total_price]
+ └─── Calculate[part_name=name, container=container, ratio=retail_price / total_price]
""",
id="partition_data_with_data_order",
),
pytest.param(
TableCollectionInfo("Nations")
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("suppliers")],
total_sum=FunctionInfo(
"SUM", [ChildReferenceExpressionInfo("account_balance", 0)]
),
)
** TopKInfo([], 5, (ReferenceInfo("total_sum"), False, True)),
- "TPCH.Nations(total_sum=SUM(suppliers.account_balance)).TOP_K(5, total_sum.DESC(na_pos='last'))",
+ "TPCH.Nations.CALCULATE(total_sum=SUM(suppliers.account_balance)).TOP_K(5, total_sum.DESC(na_pos='last'))",
"""
──┬─ TPCH
├─── TableCollection[Nations]
- ├─┬─ Calc[total_sum=SUM($1.account_balance)]
+ ├─┬─ Calculate[total_sum=SUM($1.account_balance)]
│ └─┬─ AccessChild
│ └─── SubCollection[suppliers]
└─── TopK[5, total_sum.DESC(na_pos='last')]
@@ -1511,7 +1485,7 @@ def test_collections_calc_terms(
),
pytest.param(
TableCollectionInfo("Customers")
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("nation")],
cust_nation_name=ChildReferenceExpressionInfo("name", 0),
)
@@ -1532,18 +1506,18 @@ def test_collections_calc_terms(
],
FunctionInfo("HASNOT", [ChildReferenceCollectionInfo(0)]),
),
- "TPCH.Customers(cust_nation_name=nation.name).WHERE(HASNOT(orders.lines.WHERE(supplier.nation.name != BACK(2).cust_nation_name)))",
+ "TPCH.Customers.CALCULATE(cust_nation_name=nation.name).WHERE(HASNOT(orders.lines.WHERE(supplier.nation.name != cust_nation_name)))",
"""
──┬─ TPCH
├─── TableCollection[Customers]
- ├─┬─ Calc[cust_nation_name=$1.name]
+ ├─┬─ Calculate[cust_nation_name=$1.name]
│ └─┬─ AccessChild
│ └─── SubCollection[nation]
└─┬─ Where[HASNOT($1)]
└─┬─ AccessChild
└─┬─ SubCollection[orders]
├─── SubCollection[lines]
- └─┬─ Where[$1.name != BACK(2).cust_nation_name]
+ └─┬─ Where[$1.name != cust_nation_name]
└─┬─ AccessChild
└─┬─ SubCollection[supplier]
└─── SubCollection[nation]
@@ -1552,7 +1526,7 @@ def test_collections_calc_terms(
),
pytest.param(
TableCollectionInfo("Suppliers")
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("parts_supplied")
** WhereInfo(
@@ -1642,11 +1616,11 @@ def test_collections_calc_terms(
(FunctionInfo("HAS", [ChildReferenceCollectionInfo(1)]), True, True),
(ReferenceInfo("key"), True, True),
),
- "TPCH.Suppliers(has_part_24=COUNT(parts_supplied.WHERE(size == 24)) > 0, hasnot_part_25=COUNT(parts_supplied.WHERE(size == 25)) == 0).WHERE(HASNOT(parts_supplied.WHERE(size == 26)) & HAS(parts_supplied.WHERE(size == 27))).TOP_K(100, (COUNT(parts_supplied.WHERE(size == 28)) > 0).ASC(na_pos='last'), (COUNT(parts_supplied.WHERE(size == 29)) == 0).ASC(na_pos='last'), key.ASC(na_pos='last')).ORDER_BY((COUNT(parts_supplied.WHERE(size == 30)) == 0).ASC(na_pos='last'), (COUNT(parts_supplied.WHERE(size == 31)) > 0).ASC(na_pos='last'), key.ASC(na_pos='last'))",
+ "TPCH.Suppliers.CALCULATE(has_part_24=COUNT(parts_supplied.WHERE(size == 24)) > 0, hasnot_part_25=COUNT(parts_supplied.WHERE(size == 25)) == 0).WHERE(HASNOT(parts_supplied.WHERE(size == 26)) & HAS(parts_supplied.WHERE(size == 27))).TOP_K(100, (COUNT(parts_supplied.WHERE(size == 28)) > 0).ASC(na_pos='last'), (COUNT(parts_supplied.WHERE(size == 29)) == 0).ASC(na_pos='last'), key.ASC(na_pos='last')).ORDER_BY((COUNT(parts_supplied.WHERE(size == 30)) == 0).ASC(na_pos='last'), (COUNT(parts_supplied.WHERE(size == 31)) > 0).ASC(na_pos='last'), key.ASC(na_pos='last'))",
"""
──┬─ TPCH
├─── TableCollection[Suppliers]
- ├─┬─ Calc[has_part_24=COUNT($1) > 0, hasnot_part_25=COUNT($2) == 0]
+ ├─┬─ Calculate[has_part_24=COUNT($1) > 0, hasnot_part_25=COUNT($2) == 0]
│ ├─┬─ AccessChild
│ │ ├─── SubCollection[parts_supplied]
│ │ └─── Where[size == 24]
@@ -2026,41 +2000,41 @@ def test_collections_calc_terms(
),
pytest.param(
TableCollectionInfo("Customers")
- ** CalcInfo(
+ ** CalculateInfo(
[],
name=ReferenceInfo("name"),
cust_rank=WindowInfo(
"RANKING", (ReferenceInfo("acctbal"), False, True)
),
),
- "TPCH.Customers(name=name, cust_rank=RANKING(by=(acctbal.DESC(na_pos='last'))))",
+ "TPCH.Customers.CALCULATE(name=name, cust_rank=RANKING(by=(acctbal.DESC(na_pos='last'))))",
"""
──┬─ TPCH
├─── TableCollection[Customers]
- └─── Calc[name=name, cust_rank=RANKING(by=(acctbal.DESC(na_pos='last')))]
+ └─── Calculate[name=name, cust_rank=RANKING(by=(acctbal.DESC(na_pos='last')))]
""",
id="rank_customers_a",
),
pytest.param(
TableCollectionInfo("Customers")
- ** CalcInfo(
+ ** CalculateInfo(
[],
name=ReferenceInfo("name"),
cust_rank=WindowInfo(
"RANKING", (ReferenceInfo("acctbal"), False, True), allow_ties=True
),
),
- "TPCH.Customers(name=name, cust_rank=RANKING(by=(acctbal.DESC(na_pos='last')), allow_ties=True))",
+ "TPCH.Customers.CALCULATE(name=name, cust_rank=RANKING(by=(acctbal.DESC(na_pos='last')), allow_ties=True))",
"""
──┬─ TPCH
├─── TableCollection[Customers]
- └─── Calc[name=name, cust_rank=RANKING(by=(acctbal.DESC(na_pos='last')), allow_ties=True)]
+ └─── Calculate[name=name, cust_rank=RANKING(by=(acctbal.DESC(na_pos='last')), allow_ties=True)]
""",
id="rank_customers_b",
),
pytest.param(
TableCollectionInfo("Customers")
- ** CalcInfo(
+ ** CalculateInfo(
[],
name=ReferenceInfo("name"),
cust_rank=WindowInfo(
@@ -2070,18 +2044,18 @@ def test_collections_calc_terms(
dense=True,
),
),
- "TPCH.Customers(name=name, cust_rank=RANKING(by=(acctbal.DESC(na_pos='last')), allow_ties=True, dense=True))",
+ "TPCH.Customers.CALCULATE(name=name, cust_rank=RANKING(by=(acctbal.DESC(na_pos='last')), allow_ties=True, dense=True))",
"""
──┬─ TPCH
├─── TableCollection[Customers]
- └─── Calc[name=name, cust_rank=RANKING(by=(acctbal.DESC(na_pos='last')), allow_ties=True, dense=True)]
+ └─── Calculate[name=name, cust_rank=RANKING(by=(acctbal.DESC(na_pos='last')), allow_ties=True, dense=True)]
""",
id="rank_customers_c",
),
pytest.param(
TableCollectionInfo("Nations")
** SubCollectionInfo("customers")
- ** CalcInfo(
+ ** CalculateInfo(
[],
nation_name=BackReferenceExpressionInfo("name", 1),
name=ReferenceInfo("name"),
@@ -2089,12 +2063,12 @@ def test_collections_calc_terms(
"RANKING", (ReferenceInfo("acctbal"), False, True), levels=1
),
),
- "TPCH.Nations.customers(nation_name=BACK(1).name, name=name, cust_rank=RANKING(by=(acctbal.DESC(na_pos='last')), levels=1))",
+ "TPCH.Nations.customers.CALCULATE(nation_name=name, name=name, cust_rank=RANKING(by=(acctbal.DESC(na_pos='last')), levels=1))",
"""
──┬─ TPCH
└─┬─ TableCollection[Nations]
├─── SubCollection[customers]
- └─── Calc[nation_name=BACK(1).name, name=name, cust_rank=RANKING(by=(acctbal.DESC(na_pos='last')), levels=1)]
+ └─── Calculate[nation_name=name, name=name, cust_rank=RANKING(by=(acctbal.DESC(na_pos='last')), levels=1)]
""",
id="rank_customers_per_nation",
),
@@ -2102,7 +2076,7 @@ def test_collections_calc_terms(
TableCollectionInfo("Regions")
** SubCollectionInfo("nations")
** SubCollectionInfo("customers")
- ** CalcInfo(
+ ** CalculateInfo(
[],
region_name=BackReferenceExpressionInfo("name", 2),
nation_name=BackReferenceExpressionInfo("name", 1),
@@ -2115,13 +2089,13 @@ def test_collections_calc_terms(
dense=True,
),
),
- "TPCH.Regions.nations.customers(region_name=BACK(2).name, nation_name=BACK(1).name, name=name, cust_rank=RANKING(by=(acctbal.DESC(na_pos='last')), levels=2, allow_ties=True, dense=True))",
+ "TPCH.Regions.nations.customers.CALCULATE(region_name=name, nation_name=name, name=name, cust_rank=RANKING(by=(acctbal.DESC(na_pos='last')), levels=2, allow_ties=True, dense=True))",
"""
──┬─ TPCH
└─┬─ TableCollection[Regions]
└─┬─ SubCollection[nations]
├─── SubCollection[customers]
- └─── Calc[region_name=BACK(2).name, nation_name=BACK(1).name, name=name, cust_rank=RANKING(by=(acctbal.DESC(na_pos='last')), levels=2, allow_ties=True, dense=True)]
+ └─── Calculate[region_name=name, nation_name=name, name=name, cust_rank=RANKING(by=(acctbal.DESC(na_pos='last')), levels=2, allow_ties=True, dense=True)]
""",
id="rank_customers_per_region",
),
@@ -2150,14 +2124,16 @@ def test_collections_to_string(
"calc_pipeline, expected_collation_strings",
[
pytest.param(
- CalcInfo([], x=LiteralInfo(1, Int64Type()), y=LiteralInfo(3, Int64Type())),
+ CalculateInfo(
+ [], x=LiteralInfo(1, Int64Type()), y=LiteralInfo(3, Int64Type())
+ ),
None,
id="global_calc",
),
pytest.param(
TableCollectionInfo("Regions")
** SubCollectionInfo("suppliers")
- ** CalcInfo(
+ ** CalculateInfo(
[],
region_name=BackReferenceExpressionInfo("name", 1),
nation_name=ReferenceInfo("nation_name"),
@@ -2221,7 +2197,7 @@ def test_collections_to_string(
"GRT", [ReferenceInfo("acctbal"), LiteralInfo(1000, Int64Type())]
),
)
- ** CalcInfo([], region_name=BackReferenceExpressionInfo("name", 1))
+ ** CalculateInfo([], region_name=BackReferenceExpressionInfo("name", 1))
** OrderInfo([], (ReferenceInfo("region_name"), True, True))
** WhereInfo(
[],
@@ -2244,7 +2220,7 @@ def test_collections_to_string(
"GRT", [ReferenceInfo("acctbal"), LiteralInfo(1000, Int64Type())]
),
)
- ** CalcInfo([], region_name=BackReferenceExpressionInfo("name", 1))
+ ** CalculateInfo([], region_name=BackReferenceExpressionInfo("name", 1))
** WhereInfo(
[],
FunctionInfo(
@@ -2261,7 +2237,7 @@ def test_collections_to_string(
"parts",
[ChildReferenceExpressionInfo("container", 0)],
)
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("parts")],
container=ReferenceInfo("container"),
total_price=FunctionInfo(
@@ -2277,7 +2253,7 @@ def test_collections_to_string(
"parts",
[ChildReferenceExpressionInfo("container", 0)],
)
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("parts")],
container=ReferenceInfo("container"),
total_price=FunctionInfo(
@@ -2295,7 +2271,7 @@ def test_collections_to_string(
"parts",
[ChildReferenceExpressionInfo("container", 0)],
)
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("parts")],
container=ReferenceInfo("container"),
total_price=FunctionInfo(
@@ -2312,7 +2288,7 @@ def test_collections_to_string(
"parts",
[ChildReferenceExpressionInfo("container", 0)],
)
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("parts")],
container=ReferenceInfo("container"),
total_price=FunctionInfo(
@@ -2320,7 +2296,7 @@ def test_collections_to_string(
),
)
** SubCollectionInfo("parts")
- ** CalcInfo(
+ ** CalculateInfo(
[],
part_name=ReferenceInfo("name"),
container=ReferenceInfo("container"),
@@ -2337,7 +2313,7 @@ def test_collections_to_string(
),
pytest.param(
TableCollectionInfo("Nations")
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("suppliers")],
total_sum=FunctionInfo(
"SUM", [ChildReferenceExpressionInfo("account_balance", 0)]
@@ -2375,15 +2351,15 @@ def test_collections_ordering(
), "Mismatch between string representation of collation keys and expected value"
-def test_regions_intra_ratio_string_order(
- region_intra_ratio: tuple[CollectionTestInfo, str, str],
+def test_regions_intra_pct_string_order(
+ region_intra_pct: tuple[CollectionTestInfo, str, str],
tpch_node_builder: AstNodeBuilder,
) -> None:
"""
Same as `test_collections_to_string` and `test_collections_ordering`, but
- specifically on the structure from the `region_intra_ratio` fixture.
+ specifically on the structure from the `region_intra_pct` fixture.
"""
- calc_pipeline, expected_string, expected_tree_string = region_intra_ratio
+ calc_pipeline, expected_string, expected_tree_string = region_intra_pct
collection: PyDoughCollectionQDAG = calc_pipeline.build(tpch_node_builder)
assert collection.to_string() == expected_string
assert collection.to_tree_string() == expected_tree_string.strip()
diff --git a/tests/test_qdag_collection_errors.py b/tests/test_qdag_collection_errors.py
index 5dde8ae0..9d51cbf0 100644
--- a/tests/test_qdag_collection_errors.py
+++ b/tests/test_qdag_collection_errors.py
@@ -7,10 +7,8 @@
import pytest
from test_utils import (
AstNodeTestInfo,
- BackReferenceCollectionInfo,
BackReferenceExpressionInfo,
- CalcInfo,
- ChildReferenceCollectionInfo,
+ CalculateInfo,
ChildReferenceExpressionInfo,
FunctionInfo,
OrderInfo,
@@ -39,50 +37,51 @@
id="subcollection_dne",
),
pytest.param(
- TableCollectionInfo("Regions") ** CalcInfo([], foo=ReferenceInfo("bar")),
+ TableCollectionInfo("Regions")
+ ** CalculateInfo([], foo=ReferenceInfo("bar")),
"Unrecognized term of simple table collection 'Regions' in graph 'TPCH': 'bar'",
id="reference_dne",
),
pytest.param(
TableCollectionInfo("Nations")
** SubCollectionInfo("suppliers")
- ** CalcInfo([], foo=ReferenceInfo("region_key")),
+ ** CalculateInfo([], foo=ReferenceInfo("region_key")),
"Unrecognized term of simple table collection 'Suppliers' in graph 'TPCH': 'region_key'",
id="reference_bad_ancestry",
),
pytest.param(
TableCollectionInfo("Regions")
- ** CalcInfo([], foo=BackReferenceExpressionInfo("foo", 0)),
+ ** CalculateInfo([], foo=BackReferenceExpressionInfo("foo", 0)),
"Expected number of levels in BACK to be a positive integer, received 0",
id="back_zero",
),
pytest.param(
TableCollectionInfo("Regions")
- ** CalcInfo([], foo=BackReferenceExpressionInfo("foo", 1)),
+ ** CalculateInfo([], foo=BackReferenceExpressionInfo("foo", 1)),
"Unrecognized term of graph 'TPCH': 'foo'",
id="back_on_root",
),
pytest.param(
TableCollectionInfo("Regions")
** SubCollectionInfo("nations")
- ** CalcInfo([], foo=BackReferenceExpressionInfo("foo", 3)),
+ ** CalculateInfo([], foo=BackReferenceExpressionInfo("foo", 3)),
"Cannot reference back 3 levels above TPCH.Regions.nations",
id="back_too_far",
),
pytest.param(
TableCollectionInfo("Regions")
** SubCollectionInfo("nations")
- ** CalcInfo([], foo=BackReferenceExpressionInfo("foo", 1)),
+ ** CalculateInfo([], foo=BackReferenceExpressionInfo("foo", 1)),
"Unrecognized term of simple table collection 'Regions' in graph 'TPCH': 'foo'",
id="back_dne",
),
pytest.param(
- CalcInfo([], foo=ChildReferenceExpressionInfo("foo", 0)),
+ CalculateInfo([], foo=ChildReferenceExpressionInfo("foo", 0)),
"Invalid child reference index 0 with 0 children",
id="child_dne",
),
pytest.param(
- CalcInfo(
+ CalculateInfo(
[TableCollectionInfo("Regions")],
foo=ChildReferenceExpressionInfo("bar", 0),
),
@@ -91,39 +90,42 @@
),
pytest.param(
TableCollectionInfo("Regions")
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("nations")],
nation_name=ChildReferenceExpressionInfo("name", 0),
),
- "Expected all terms in (nation_name=nations.name) to be singular, but encountered a plural expression: nations.name",
+ "Expected all terms in CALCULATE(nation_name=nations.name) to be singular, but encountered a plural expression: nations.name",
id="bad_plural_a",
),
pytest.param(
- TableCollectionInfo("Nations")
- ** SubCollectionInfo("customers")
- ** CalcInfo(
- [BackReferenceCollectionInfo("suppliers", 1)],
- customer_name=ChildReferenceExpressionInfo("name", 0),
+ TableCollectionInfo("Customers")
+ ** CalculateInfo(
+ [
+ SubCollectionInfo("nation")
+ ** SubCollectionInfo("region")
+ ** SubCollectionInfo("nations")
+ ],
+ nation_name=ChildReferenceExpressionInfo("name", 0),
),
- "Expected all terms in (customer_name=BACK(1).suppliers.name) to be singular, but encountered a plural expression: BACK(1).suppliers.name",
+ "Expected all terms in CALCULATE(nation_name=nation.region.nations.name) to be singular, but encountered a plural expression: nation.region.nations.name",
id="bad_plural_b",
),
pytest.param(
TableCollectionInfo("Parts")
- ** SubCollectionInfo("suppliers_of_part")
- ** CalcInfo(
- [SubCollectionInfo("ps_lines")],
+ ** SubCollectionInfo("supply_records")
+ ** CalculateInfo(
+ [SubCollectionInfo("lines")],
extended_price=ChildReferenceExpressionInfo("extended_price", 0),
),
- "Expected all terms in (extended_price=ps_lines.extended_price) to be singular, but encountered a plural expression: ps_lines.extended_price",
+ "Expected all terms in CALCULATE(extended_price=lines.extended_price) to be singular, but encountered a plural expression: lines.extended_price",
id="bad_plural_c",
),
pytest.param(
- CalcInfo(
+ CalculateInfo(
[TableCollectionInfo("Customers")],
cust_name=ChildReferenceExpressionInfo("name", 0),
),
- "Expected all terms in (cust_name=Customers.name) to be singular, but encountered a plural expression: Customers.name",
+ "Expected all terms in CALCULATE(cust_name=Customers.name) to be singular, but encountered a plural expression: Customers.name",
id="bad_plural_d",
),
pytest.param(
@@ -166,12 +168,12 @@
"parts",
[ChildReferenceExpressionInfo("container", 0)],
)
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("parts")],
container=ReferenceInfo("container"),
price=ChildReferenceExpressionInfo("retail_price", 0),
),
- "Expected all terms in (container=container, price=parts.retail_price) to be singular, but encountered a plural expression: parts.retail_price",
+ "Expected all terms in CALCULATE(container=container, price=parts.retail_price) to be singular, but encountered a plural expression: parts.retail_price",
id="bad_plural_h",
),
pytest.param(
@@ -180,35 +182,14 @@
"parts",
[ChildReferenceExpressionInfo("container", 0)],
)
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("parts") ** SubCollectionInfo("suppliers_of_part")],
container=ReferenceInfo("container"),
balance=ChildReferenceExpressionInfo("account_balance", 0),
),
- "Expected all terms in (container=container, balance=parts.suppliers_of_part.account_balance) to be singular, but encountered a plural expression: parts.suppliers_of_part.account_balance",
+ "Expected all terms in CALCULATE(container=container, balance=parts.suppliers_of_part.account_balance) to be singular, but encountered a plural expression: parts.suppliers_of_part.account_balance",
id="bad_plural_i",
),
- pytest.param(
- TableCollectionInfo("Nations")
- ** SubCollectionInfo("suppliers")
- ** PartitionInfo(
- SubCollectionInfo("parts_supplied"),
- "parts",
- [ChildReferenceExpressionInfo("part_type", 0)],
- )
- ** CalcInfo(
- [
- SubCollectionInfo("parts")
- ** SubCollectionInfo("suppliers_of_part"),
- BackReferenceCollectionInfo("customers", 2),
- ],
- part_type=ReferenceInfo("part_type"),
- num_parts=FunctionInfo("COUNT", [ChildReferenceCollectionInfo(0)]),
- cust_name=ChildReferenceExpressionInfo("name", 1),
- ),
- "Expected all terms in (part_type=part_type, num_parts=COUNT(parts.suppliers_of_part), cust_name=BACK(2).customers.name) to be singular, but encountered a plural expression: BACK(2).customers.name",
- id="bad_plural_j",
- ),
],
)
def test_malformed_collection_sequences(
diff --git a/tests/test_qdag_conversion.py b/tests/test_qdag_conversion.py
index 056ef168..02e08f3a 100644
--- a/tests/test_qdag_conversion.py
+++ b/tests/test_qdag_conversion.py
@@ -8,7 +8,7 @@
import pytest
from test_utils import (
BackReferenceExpressionInfo,
- CalcInfo,
+ CalculateInfo,
ChildReferenceCollectionInfo,
ChildReferenceExpressionInfo,
CollectionTestInfo,
@@ -54,7 +54,7 @@
pytest.param(
(
TableCollectionInfo("Regions")
- ** CalcInfo(
+ ** CalculateInfo(
[],
region_name=ReferenceInfo("name"),
magic_word=LiteralInfo("foo", StringType()),
@@ -66,8 +66,10 @@
pytest.param(
(
TableCollectionInfo("Regions")
- ** CalcInfo([], name=LiteralInfo("foo", StringType()))
- ** CalcInfo([], fizz=ReferenceInfo("name"), buzz=ReferenceInfo("key")),
+ ** CalculateInfo([], hello=LiteralInfo("foo", StringType()))
+ ** CalculateInfo(
+ [], fizz=ReferenceInfo("name"), buzz=ReferenceInfo("key")
+ ),
"scan_calc_calc",
),
id="scan_calc_calc",
@@ -91,9 +93,9 @@
pytest.param(
(
TableCollectionInfo("Customers")
- ** CalcInfo(
+ ** CalculateInfo(
[],
- name=FunctionInfo("LOWER", [ReferenceInfo("name")]),
+ lname=FunctionInfo("LOWER", [ReferenceInfo("name")]),
country_code=FunctionInfo(
"SLICE",
[
@@ -132,7 +134,7 @@
pytest.param(
(
TableCollectionInfo("Nations")
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("region")],
nation_name=ReferenceInfo("name"),
region_name=ChildReferenceExpressionInfo("name", 0),
@@ -144,7 +146,7 @@
pytest.param(
(
TableCollectionInfo("Lineitems")
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("part_and_supplier")
** SubCollectionInfo("supplier")
@@ -177,26 +179,8 @@
pytest.param(
(
TableCollectionInfo("Regions")
- ** CalcInfo([], key=LiteralInfo(-1, Int64Type()))
** SubCollectionInfo("nations")
- ** CalcInfo([], key=LiteralInfo(-2, Int64Type()))
- ** SubCollectionInfo("customers")
- ** CalcInfo(
- [],
- key=LiteralInfo(-3, Int64Type()),
- name=ReferenceInfo("name"),
- phone=ReferenceInfo("phone"),
- mktsegment=ReferenceInfo("mktsegment"),
- ),
- "join_regions_nations_calc_override",
- ),
- id="join_regions_nations_calc_override",
- ),
- pytest.param(
- (
- TableCollectionInfo("Regions")
- ** SubCollectionInfo("nations")
- ** CalcInfo(
+ ** CalculateInfo(
[],
region_name=BackReferenceExpressionInfo("name", 1),
nation_name=ReferenceInfo("name"),
@@ -212,22 +196,22 @@
** TableCollectionInfo("customers")
** TableCollectionInfo("orders")
** SubCollectionInfo("lines")
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("part_and_supplier")
** SubCollectionInfo("supplier")
** SubCollectionInfo("nation")
** SubCollectionInfo("region")
- ** CalcInfo(
+ ** CalculateInfo(
[], nation_name=BackReferenceExpressionInfo("name", 1)
)
],
order_year=FunctionInfo(
"YEAR", [BackReferenceExpressionInfo("order_date", 1)]
),
- customer_region=BackReferenceExpressionInfo("name", 4),
- customer_nation=BackReferenceExpressionInfo("name", 3),
- supplier_region=ChildReferenceExpressionInfo("name", 0),
+ customer_region_name=BackReferenceExpressionInfo("name", 4),
+ customer_nation_name=BackReferenceExpressionInfo("name", 3),
+ supplier_region_name=ChildReferenceExpressionInfo("name", 0),
nation_name=ChildReferenceExpressionInfo("nation_name", 0),
),
"lines_shipping_vs_customer_region",
@@ -237,7 +221,7 @@
pytest.param(
(
TableCollectionInfo("Orders")
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("lines")],
okey=ReferenceInfo("key"),
lsum=FunctionInfo(
@@ -251,7 +235,7 @@
pytest.param(
(
TableCollectionInfo("Customers")
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("orders") ** SubCollectionInfo("lines")],
okey=ReferenceInfo("key"),
lsum=FunctionInfo(
@@ -265,7 +249,7 @@
pytest.param(
(
TableCollectionInfo("Nations")
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("customers")
** SubCollectionInfo("orders")
@@ -283,7 +267,7 @@
pytest.param(
(
TableCollectionInfo("Regions")
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("nations")
** SubCollectionInfo("customers")
@@ -302,7 +286,7 @@
pytest.param(
(
TableCollectionInfo("Orders")
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("lines")],
okey=ReferenceInfo("key"),
lavg=FunctionInfo(
@@ -326,7 +310,7 @@
pytest.param(
(
TableCollectionInfo("Nations")
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("customers"),
SubCollectionInfo("suppliers"),
@@ -346,32 +330,32 @@
pytest.param(
(
TableCollectionInfo("Nations")
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("customers")],
- total_consumer_value=FunctionInfo(
+ total_consumer_value_a=FunctionInfo(
"SUM", [ChildReferenceExpressionInfo("acctbal", 0)]
),
- avg_consumer_value=FunctionInfo(
+ avg_consumer_value_a=FunctionInfo(
"AVG", [ChildReferenceExpressionInfo("acctbal", 0)]
),
)
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("suppliers")],
- nation_name=ReferenceInfo("key"),
- total_supplier_value=FunctionInfo(
+ nation_name_a=ReferenceInfo("key"),
+ total_supplier_value_a=FunctionInfo(
"SUM", [ChildReferenceExpressionInfo("account_balance", 0)]
),
- avg_supplier_value=FunctionInfo(
+ avg_supplier_value_a=FunctionInfo(
"AVG", [ChildReferenceExpressionInfo("account_balance", 0)]
),
)
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("suppliers"), SubCollectionInfo("customers")],
nation_name=ReferenceInfo("key"),
- total_consumer_value=ReferenceInfo("total_consumer_value"),
- total_supplier_value=ReferenceInfo("total_supplier_value"),
- avg_consumer_value=ReferenceInfo("avg_consumer_value"),
- avg_supplier_value=ReferenceInfo("avg_supplier_value"),
+ total_consumer_value=ReferenceInfo("total_consumer_value_a"),
+ total_supplier_value=ReferenceInfo("total_supplier_value_a"),
+ avg_consumer_value=ReferenceInfo("avg_consumer_value_a"),
+ avg_supplier_value=ReferenceInfo("avg_supplier_value_a"),
best_consumer_value=FunctionInfo(
"MAX", [ChildReferenceExpressionInfo("acctbal", 1)]
),
@@ -386,7 +370,7 @@
pytest.param(
(
TableCollectionInfo("Nations")
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("customers")],
nation_name=ReferenceInfo("key"),
num_customers=FunctionInfo(
@@ -400,7 +384,7 @@
pytest.param(
(
TableCollectionInfo("Nations")
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("customers"), SubCollectionInfo("suppliers")],
nation_name=ReferenceInfo("key"),
num_customers=FunctionInfo(
@@ -429,7 +413,7 @@
pytest.param(
(
TableCollectionInfo("Nations")
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("customers"),
],
@@ -461,11 +445,11 @@
pytest.param(
(
TableCollectionInfo("Orders")
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("lines")
** SubCollectionInfo("part_and_supplier")
- ** CalcInfo(
+ ** CalculateInfo(
[],
ratio=FunctionInfo(
"DIV",
@@ -488,14 +472,14 @@
pytest.param(
(
TableCollectionInfo("Orders")
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("lines")],
total_quantity=FunctionInfo(
"SUM", [ChildReferenceExpressionInfo("quantity", 0)]
),
)
** SubCollectionInfo("lines")
- ** CalcInfo(
+ ** CalculateInfo(
[],
part_key=ReferenceInfo("part_key"),
supplier_key=ReferenceInfo("supplier_key"),
@@ -514,7 +498,7 @@
),
pytest.param(
(
- CalcInfo(
+ CalculateInfo(
[],
a=LiteralInfo(0, Int64Type()),
b=LiteralInfo("X", StringType()),
@@ -527,12 +511,12 @@
),
pytest.param(
(
- CalcInfo(
+ CalculateInfo(
[],
a=LiteralInfo(0, Int64Type()),
b=LiteralInfo("X", StringType()),
)
- ** CalcInfo(
+ ** CalculateInfo(
[],
a=ReferenceInfo("a"),
b=ReferenceInfo("b"),
@@ -545,7 +529,7 @@
),
pytest.param(
(
- CalcInfo(
+ CalculateInfo(
[TableCollectionInfo("Customers")],
total_bal=FunctionInfo(
"SUM", [ChildReferenceExpressionInfo("acctbal", 0)]
@@ -570,7 +554,7 @@
),
pytest.param(
(
- CalcInfo(
+ CalculateInfo(
[
TableCollectionInfo("Customers"),
TableCollectionInfo("Suppliers"),
@@ -586,13 +570,13 @@
),
pytest.param(
(
- CalcInfo(
+ CalculateInfo(
[],
a=LiteralInfo(28.15, Int64Type()),
b=LiteralInfo("NICKEL", StringType()),
)
** TableCollectionInfo("Parts")
- ** CalcInfo(
+ ** CalculateInfo(
[],
part_name=ReferenceInfo("name"),
is_above_cutoff=FunctionInfo(
@@ -616,14 +600,14 @@
),
pytest.param(
(
- CalcInfo(
+ CalculateInfo(
[TableCollectionInfo("Parts")],
avg_price=FunctionInfo(
"AVG", [ChildReferenceExpressionInfo("retail_price", 0)]
),
)
** TableCollectionInfo("Parts")
- ** CalcInfo(
+ ** CalculateInfo(
[],
part_name=ReferenceInfo("name"),
is_above_avg=FunctionInfo(
@@ -641,7 +625,7 @@
pytest.param(
(
TableCollectionInfo("Parts")
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("supply_records")],
name=ReferenceInfo("name"),
total_delta=FunctionInfo(
@@ -669,7 +653,7 @@
"p",
[ChildReferenceExpressionInfo("part_type", 0)],
)
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("p")],
part_type=ReferenceInfo("part_type"),
num_parts=FunctionInfo("COUNT", [ChildReferenceCollectionInfo(0)]),
@@ -685,7 +669,7 @@
(
PartitionInfo(
TableCollectionInfo("Orders")
- ** CalcInfo(
+ ** CalculateInfo(
[],
year=FunctionInfo("YEAR", [ReferenceInfo("order_date")]),
month=FunctionInfo("MONTH", [ReferenceInfo("order_date")]),
@@ -696,7 +680,7 @@
ChildReferenceExpressionInfo("month", 0),
],
)
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("o")],
year=ReferenceInfo("year"),
month=ReferenceInfo("month"),
@@ -712,7 +696,7 @@
(
PartitionInfo(
TableCollectionInfo("Orders")
- ** CalcInfo(
+ ** CalculateInfo(
[],
year=FunctionInfo("YEAR", [ReferenceInfo("order_date")]),
month=FunctionInfo("MONTH", [ReferenceInfo("order_date")]),
@@ -723,7 +707,7 @@
ChildReferenceExpressionInfo("month", 0),
],
)
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("o"),
SubCollectionInfo("o")
@@ -759,7 +743,7 @@
(
PartitionInfo(
TableCollectionInfo("Orders")
- ** CalcInfo(
+ ** CalculateInfo(
[],
year=FunctionInfo("YEAR", [ReferenceInfo("order_date")]),
month=FunctionInfo("MONTH", [ReferenceInfo("order_date")]),
@@ -770,7 +754,7 @@
ChildReferenceExpressionInfo("month", 0),
],
)
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("o")
** WhereInfo(
@@ -808,7 +792,7 @@
** SubCollectionInfo("part_and_supplier")
** SubCollectionInfo("supplier")
** SubCollectionInfo("nation")
- ** CalcInfo(
+ ** CalculateInfo(
[],
year=FunctionInfo(
"YEAR", [BackReferenceExpressionInfo("order_date", 4)]
@@ -824,7 +808,7 @@
ChildReferenceExpressionInfo("supplier_nation", 0),
],
)
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("combos")],
year=ReferenceInfo("year"),
customer_nation=ReferenceInfo("customer_nation"),
@@ -842,7 +826,7 @@
),
pytest.param(
(
- CalcInfo(
+ CalculateInfo(
[TableCollectionInfo("Parts")],
total_num_parts=FunctionInfo(
"COUNT", [ChildReferenceCollectionInfo(0)]
@@ -856,7 +840,7 @@
"p",
[ChildReferenceExpressionInfo("part_type", 0)],
)
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("p")],
part_type=ReferenceInfo("part_type"),
percentage_of_parts=FunctionInfo(
@@ -904,7 +888,7 @@
),
)
** SubCollectionInfo("p")
- ** CalcInfo(
+ ** CalculateInfo(
[],
part_name=ReferenceInfo("name"),
part_type=ReferenceInfo("part_type"),
@@ -921,14 +905,14 @@
"p",
[ChildReferenceExpressionInfo("part_type", 0)],
)
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("p")],
avg_price=FunctionInfo(
"AVG", [ChildReferenceExpressionInfo("retail_price", 0)]
),
)
** SubCollectionInfo("p")
- ** CalcInfo(
+ ** CalculateInfo(
[],
part_name=ReferenceInfo("name"),
part_type=ReferenceInfo("part_type"),
@@ -951,14 +935,14 @@
"p",
[ChildReferenceExpressionInfo("part_type", 0)],
)
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("p")],
avg_price=FunctionInfo(
"AVG", [ChildReferenceExpressionInfo("retail_price", 0)]
),
)
** SubCollectionInfo("p")
- ** CalcInfo(
+ ** CalculateInfo(
[],
part_name=ReferenceInfo("name"),
part_type=ReferenceInfo("part_type"),
@@ -1041,7 +1025,7 @@
],
),
)
- ** CalcInfo(
+ ** CalculateInfo(
[],
order_key=ReferenceInfo("order_key"),
ship_date=ReferenceInfo("ship_date"),
@@ -1088,7 +1072,7 @@
],
),
)
- ** CalcInfo(
+ ** CalculateInfo(
[],
rname=BackReferenceExpressionInfo("name", 4),
price=ReferenceInfo("extended_price"),
@@ -1119,7 +1103,7 @@
],
),
)
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("order")
** SubCollectionInfo("customer")
@@ -1161,7 +1145,7 @@
pytest.param(
(
TableCollectionInfo("Nations")
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("suppliers")
** WhereInfo(
@@ -1191,7 +1175,7 @@
pytest.param(
(
TableCollectionInfo("Nations")
- ** CalcInfo([], name=ReferenceInfo("name"))
+ ** CalculateInfo([], name=ReferenceInfo("name"))
** WhereInfo(
[
SubCollectionInfo("suppliers")
@@ -1233,7 +1217,7 @@
pytest.param(
(
TableCollectionInfo("Nations")
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("suppliers")
** WhereInfo(
@@ -1297,7 +1281,7 @@
pytest.param(
(
TableCollectionInfo("Nations")
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("suppliers")
** WhereInfo(
@@ -1353,7 +1337,7 @@
TableCollectionInfo("Regions")
** SubCollectionInfo("nations")
** TopKInfo([], 10, (ReferenceInfo("name"), True, True))
- ** CalcInfo(
+ ** CalculateInfo(
[],
region_name=BackReferenceExpressionInfo("name", 1),
nation_name=ReferenceInfo("name"),
@@ -1375,7 +1359,7 @@
TableCollectionInfo("Regions")
** SubCollectionInfo("nations")
** OrderInfo([], (ReferenceInfo("name"), False, True))
- ** CalcInfo(
+ ** CalculateInfo(
[],
region_name=BackReferenceExpressionInfo("name", 1),
nation_name=ReferenceInfo("name"),
@@ -1389,7 +1373,7 @@
TableCollectionInfo("Regions")
** SubCollectionInfo("nations")
** OrderInfo([], (ReferenceInfo("name"), True, True))
- ** CalcInfo(
+ ** CalculateInfo(
[],
region_name=BackReferenceExpressionInfo("name", 1),
nation_name=ReferenceInfo("name"),
@@ -1413,7 +1397,7 @@
TableCollectionInfo("Regions")
** OrderInfo([], (ReferenceInfo("name"), True, True))
** TopKInfo([], 10, (ReferenceInfo("name"), True, True))
- ** CalcInfo(
+ ** CalculateInfo(
[],
region_name=ReferenceInfo("name"),
name_length=FunctionInfo("LENGTH", [ReferenceInfo("name")]),
@@ -1502,7 +1486,7 @@
pytest.param(
(
TableCollectionInfo("Nations")
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("suppliers")
** TopKInfo(
@@ -1561,7 +1545,7 @@
True,
),
)
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("suppliers")],
name=ReferenceInfo("name"),
total_bal=FunctionInfo(
@@ -1576,7 +1560,7 @@
(
TableCollectionInfo("Regions")
** SubCollectionInfo("nations")
- ** CalcInfo(
+ ** CalculateInfo(
[],
region_name=BackReferenceExpressionInfo("name", 1),
nation_name=ReferenceInfo("name"),
@@ -1590,7 +1574,7 @@
(
TableCollectionInfo("Regions")
** SubCollectionInfo("nations")
- ** CalcInfo(
+ ** CalculateInfo(
[],
nation_name=ReferenceInfo("name"),
)
@@ -1602,7 +1586,7 @@
pytest.param(
(
TableCollectionInfo("Nations")
- ** CalcInfo(
+ ** CalculateInfo(
[],
ordering_0=ReferenceInfo("name"),
ordering_1=ReferenceInfo("key"),
@@ -1614,7 +1598,7 @@
(FunctionInfo("ABS", [ReferenceInfo("key")]), False, True),
(FunctionInfo("LENGTH", [ReferenceInfo("comment")]), True, False),
)
- ** CalcInfo(
+ ** CalculateInfo(
[],
ordering_0=ReferenceInfo("ordering_2"),
ordering_1=ReferenceInfo("ordering_0"),
@@ -1637,7 +1621,7 @@
[SubCollectionInfo("orders")],
FunctionInfo("HAS", [ChildReferenceCollectionInfo(0)]),
)
- ** CalcInfo(
+ ** CalculateInfo(
[],
name=ReferenceInfo("name"),
),
@@ -1662,7 +1646,7 @@
],
FunctionInfo("HAS", [ChildReferenceCollectionInfo(0)]),
)
- ** CalcInfo(
+ ** CalculateInfo(
[],
name=ReferenceInfo("name"),
),
@@ -1677,7 +1661,7 @@
[SubCollectionInfo("orders")],
FunctionInfo("HASNOT", [ChildReferenceCollectionInfo(0)]),
)
- ** CalcInfo(
+ ** CalculateInfo(
[],
name=ReferenceInfo("name"),
),
@@ -1702,7 +1686,7 @@
],
FunctionInfo("HASNOT", [ChildReferenceCollectionInfo(0)]),
)
- ** CalcInfo(
+ ** CalculateInfo(
[],
name=ReferenceInfo("name"),
),
@@ -1729,7 +1713,7 @@
],
FunctionInfo("HAS", [ChildReferenceCollectionInfo(0)]),
)
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("region")
** WhereInfo(
@@ -1753,7 +1737,7 @@
pytest.param(
(
TableCollectionInfo("Nations")
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("region")
** WhereInfo(
@@ -1767,7 +1751,7 @@
),
)
],
- name=ReferenceInfo("name"),
+ nation_name=ReferenceInfo("name"),
region_name=ChildReferenceExpressionInfo("name", 0),
)
** WhereInfo(
@@ -1807,7 +1791,7 @@
],
FunctionInfo("HAS", [ChildReferenceCollectionInfo(0)]),
)
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("supply_records")
** SubCollectionInfo("part")
@@ -1837,7 +1821,7 @@
pytest.param(
(
TableCollectionInfo("Suppliers")
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("supply_records")
** SubCollectionInfo("part")
@@ -1897,7 +1881,7 @@
],
FunctionInfo("HASNOT", [ChildReferenceCollectionInfo(0)]),
)
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("region")
** WhereInfo(
@@ -1921,7 +1905,7 @@
pytest.param(
(
TableCollectionInfo("Nations")
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("region")
** WhereInfo(
@@ -1935,7 +1919,7 @@
),
)
],
- name=ReferenceInfo("name"),
+ nation_name=ReferenceInfo("name"),
region_name=ChildReferenceExpressionInfo("name", 0),
)
** WhereInfo(
@@ -1975,7 +1959,7 @@
],
FunctionInfo("HASNOT", [ChildReferenceCollectionInfo(0)]),
)
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("supply_records")
** SubCollectionInfo("part")
@@ -2005,7 +1989,7 @@
pytest.param(
(
TableCollectionInfo("Suppliers")
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("supply_records")
** SubCollectionInfo("part")
@@ -2100,7 +2084,7 @@
],
),
)
- ** CalcInfo([], name=ReferenceInfo("name")),
+ ** CalculateInfo([], name=ReferenceInfo("name")),
"multiple_has_hasnot",
),
id="multiple_has_hasnot",
@@ -2108,7 +2092,7 @@
pytest.param(
(
TableCollectionInfo("Customers")
- ** CalcInfo(
+ ** CalculateInfo(
[],
name=ReferenceInfo("name"),
cust_rank=WindowInfo(
@@ -2124,7 +2108,7 @@
(
TableCollectionInfo("Nations")
** SubCollectionInfo("customers")
- ** CalcInfo(
+ ** CalculateInfo(
[],
nation_name=BackReferenceExpressionInfo("name", 1),
name=ReferenceInfo("name"),
@@ -2144,7 +2128,7 @@
TableCollectionInfo("Regions")
** SubCollectionInfo("nations")
** SubCollectionInfo("customers")
- ** CalcInfo(
+ ** CalculateInfo(
[],
nation_name=BackReferenceExpressionInfo("name", 1),
name=ReferenceInfo("name"),
@@ -2163,10 +2147,10 @@
pytest.param(
(
TableCollectionInfo("Nations")
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("customers")
- ** CalcInfo(
+ ** CalculateInfo(
[],
cust_rank=WindowInfo(
"RANKING",
@@ -2209,7 +2193,7 @@ def test_ast_to_relational(
calc_pipeline, file_name = relational_test_data
file_path: str = get_plan_test_filename(file_name)
collection: PyDoughCollectionQDAG = calc_pipeline.build(tpch_node_builder)
- relational = convert_ast_to_relational(collection, default_config)
+ relational = convert_ast_to_relational(collection, None, default_config)
if update_tests:
with open(file_path, "w") as f:
f.write(relational.to_tree_string() + "\n")
@@ -2226,7 +2210,7 @@ def test_ast_to_relational(
pytest.param(
(
TableCollectionInfo("Nations")
- ** CalcInfo(
+ ** CalculateInfo(
[SubCollectionInfo("customers")],
nation_name=ReferenceInfo("name"),
total_bal=FunctionInfo(
@@ -2252,7 +2236,7 @@ def test_ast_to_relational(
),
pytest.param(
(
- CalcInfo(
+ CalculateInfo(
[TableCollectionInfo("Customers")],
total_bal=FunctionInfo(
"SUM", [ChildReferenceExpressionInfo("acctbal", 0)]
@@ -2292,7 +2276,7 @@ def test_ast_to_relational(
],
FunctionInfo("HASNOT", [ChildReferenceCollectionInfo(0)]),
)
- ** CalcInfo(
+ ** CalculateInfo(
[
SubCollectionInfo("supply_records")
** SubCollectionInfo("part")
@@ -2349,7 +2333,7 @@ def test_ast_to_relational_alternative_aggregation_configs(
default_config.sum_default_zero = False
default_config.avg_default_zero = True
collection: PyDoughCollectionQDAG = calc_pipeline.build(tpch_node_builder)
- relational = convert_ast_to_relational(collection, default_config)
+ relational = convert_ast_to_relational(collection, None, default_config)
if update_tests:
with open(file_path, "w") as f:
f.write(relational.to_tree_string() + "\n")
diff --git a/tests/test_qualification.py b/tests/test_qualification.py
index eab7c54b..88e7c0d6 100644
--- a/tests/test_qualification.py
+++ b/tests/test_qualification.py
@@ -3,721 +3,138 @@
into qualified DAG nodes.
"""
-import datetime
from collections.abc import Callable
import pytest
+from simple_pydough_functions import partition_as_child
from test_utils import (
graph_fetcher,
)
+from tpch_test_functions import (
+ impl_tpch_q1,
+ impl_tpch_q2,
+ impl_tpch_q3,
+ impl_tpch_q4,
+ impl_tpch_q5,
+ impl_tpch_q6,
+ impl_tpch_q7,
+ impl_tpch_q8,
+ impl_tpch_q9,
+ impl_tpch_q10,
+ impl_tpch_q11,
+ impl_tpch_q12,
+ impl_tpch_q13,
+ impl_tpch_q14,
+ impl_tpch_q15,
+ impl_tpch_q16,
+ impl_tpch_q17,
+ impl_tpch_q18,
+ impl_tpch_q19,
+ impl_tpch_q20,
+ impl_tpch_q21,
+ impl_tpch_q22,
+)
+from pydough import init_pydough_context
from pydough.metadata import GraphMetadata
from pydough.qdag import PyDoughCollectionQDAG, PyDoughQDAG
from pydough.unqualified import (
UnqualifiedNode,
- UnqualifiedRoot,
qualify_node,
)
-def pydough_impl_misc_01(root: UnqualifiedNode) -> UnqualifiedNode:
- """
- Creates an UnqualifiedNode for the following PyDough snippet:
- ```
- TPCH.Nations(nation_name=name, total_balance=SUM(customers.acctbal))
- ```
- """
- return root.Nations(
- nation_name=root.name, total_balance=root.SUM(root.customers.acctbal)
- )
-
-
-def pydough_impl_misc_02(root: UnqualifiedNode) -> UnqualifiedNode:
- """
- Creates an UnqualifiedNode for the following PyDough snippet:
- ```
- lines_1994 = orders.WHERE(
- (datetime.date(1994, 1, 1) <= order_date) &
- (order_date < datetime.date(1995, 1, 1))
- ).lines
- lines_1995 = orders.WHERE(
- (datetime.date(1995, 1, 1) <= order_date) &
- (order_date < datetime.date(1996, 1, 1))
- ).lines
- TPCH.Nations.customers(
- name=LOWER(name),
- nation_name=BACK(1).name,
- total_1994=SUM(lines_1994.extended_price - lines_1994.tax / 2),
- total_1995=SUM(lines_1995.extended_price - lines_1995.tax / 2),
- )
- ```
- """
- lines_1994 = root.orders.WHERE(
- (datetime.date(1994, 1, 1) <= root.order_date)
- & (root.order_date < datetime.date(1995, 1, 1))
- ).lines
- lines_1995 = root.orders.WHERE(
- (datetime.date(1995, 1, 1) <= root.order_date)
- & (root.order_date < datetime.date(1996, 1, 1))
- ).lines
- return root.Nations.customers(
- name=root.LOWER(root.name),
- nation_name=root.BACK(1).name,
- total_1994=root.SUM(lines_1994.extended_price - lines_1994.tax / 2),
- total_1995=root.SUM(lines_1995.extended_price - lines_1995.tax / 2),
- )
-
-
-def pydough_impl_misc_03(root: UnqualifiedNode) -> UnqualifiedNode:
- """
- Creates an UnqualifiedNode for the following PyDough snippet:
- ```
- sizes = PARTITION(Parts, name="p", by=size)(n_parts=COUNT(p))
- TPCH(
- avg_n_parts=AVG(sizes.n_parts)
- )(
- n_parts=COUNT(sizes.WHERE(n_parts > BACK(1).avg_n_parts))
- )
- ```
- """
- sizes = root.PARTITION(root.Parts, name="p", by=root.size)(
- n_parts=root.COUNT(root.p)
- )
- return root.TPCH(avg_n_parts=root.AVG(sizes.n_parts))(
- n_parts=root.COUNT(sizes.WHERE(root.n_parts > root.BACK(1).avg_n_parts))
- )
-
-
-def pydough_impl_tpch_q1(root: UnqualifiedNode) -> UnqualifiedNode:
- """
- Creates an UnqualifiedNode for TPC-H query 1.
- """
- selected_lines = root.Lineitems.WHERE(root.ship_date <= datetime.date(1998, 12, 1))
- return root.PARTITION(selected_lines, name="l", by=(root.return_flag, root.status))(
- l_returnflag=root.return_flag,
- l_linestatus=root.status,
- sum_qty=root.SUM(root.l.quantity),
- sum_base_price=root.SUM(root.l.extended_price),
- sum_disc_price=root.SUM(root.l.extended_price * (1 - root.l.discount)),
- sum_charge=root.SUM(
- root.l.extended_price * (1 - root.l.discount) * (1 + root.l.tax)
- ),
- avg_qty=root.AVG(root.l.quantity),
- avg_price=root.AVG(root.l.extended_price),
- avg_disc=root.AVG(root.l.discount),
- count_order=root.COUNT(root.l),
- ).ORDER_BY(
- root.return_flag.ASC(),
- root.status.ASC(),
- )
-
-
-def pydough_impl_tpch_q2(root: UnqualifiedNode) -> UnqualifiedNode:
- """
- Creates an UnqualifiedNode for TPC-H query 2.
- """
- selected_parts = (
- root.Nations.WHERE(root.region.name == "EUROPE")
- .suppliers.supply_records.part(
- s_acctbal=root.BACK(2).account_balance,
- s_name=root.BACK(2).name,
- n_name=root.BACK(3).name,
- s_address=root.BACK(2).address,
- s_phone=root.BACK(2).phone,
- s_comment=root.BACK(2).comment,
- supplycost=root.BACK(1).supplycost,
- )
- .WHERE(root.ENDSWITH(root.part_type, "BRASS") & (root.size == 15))
- )
-
- return (
- root.PARTITION(selected_parts, name="p", by=root.key)(
- best_cost=root.MIN(root.p.supplycost)
- )
- .p.WHERE(root.supplycost == root.BACK(1).best_cost)(
- s_acctbal=root.s_acctbal,
- s_name=root.s_name,
- n_name=root.n_name,
- p_partkey=root.key,
- p_mfgr=root.manufacturer,
- s_address=root.s_address,
- s_phone=root.s_phone,
- s_comment=root.s_comment,
- )
- .TOP_K(
- 10,
- by=(
- root.s_acctbal.DESC(),
- root.n_name.ASC(),
- root.s_name.ASC(),
- root.p_partkey.ASC(),
- ),
- )
- )
-
-
-def pydough_impl_tpch_q3(root: UnqualifiedNode) -> UnqualifiedNode:
- """
- Creates an UnqualifiedNode for TPC-H query 3.
- """
- selected_lines = root.Orders.WHERE(
- (root.customer.mktsegment == "BUILDING")
- & (root.order_date < datetime.date(1995, 3, 15))
- ).lines.WHERE(root.ship_date > datetime.date(1995, 3, 15))(
- root.BACK(1).order_date,
- root.BACK(1).ship_priority,
- )
-
- return root.PARTITION(
- selected_lines,
- name="l",
- by=(root.order_key, root.order_date, root.ship_priority),
- )(
- l_orderkey=root.order_key,
- revenue=root.SUM(root.l.extended_price * (1 - root.l.discount)),
- o_orderdate=root.order_date,
- o_shippriority=root.ship_priority,
- ).TOP_K(10, by=(root.revenue.DESC(), root.o_orderdate.ASC(), root.l_orderkey.ASC()))
-
-
-def pydough_impl_tpch_q4(root: UnqualifiedNode) -> UnqualifiedNode:
- """
- Creates an UnqualifiedNode for TPC-H query 4.
- """
- selected_lines = root.lines.WHERE(root.commit_date < root.receipt_date)
- selected_orders = root.Orders.WHERE(
- (root.order_date >= datetime.date(1993, 7, 1))
- & (root.order_date < datetime.date(1993, 10, 1))
- & root.HAS(selected_lines)
- )
- return root.PARTITION(selected_orders, name="o", by=root.order_priority)(
- o_orderpriority=root.order_priority,
- order_count=root.COUNT(root.o),
- ).ORDER_BY(root.order_priority.ASC())
-
-
-def pydough_impl_tpch_q5(root: UnqualifiedNode) -> UnqualifiedNode:
- """
- Creates an UnqualifiedNode for TPC-H query 5.
- """
- selected_lines = root.customers.orders.WHERE(
- (root.order_date >= datetime.date(1994, 1, 1))
- & (root.order_date < datetime.date(1995, 1, 1))
- ).lines.WHERE(root.supplier.nation.name == root.BACK(3).name)(
- value=root.extended_price * (1 - root.discount)
- )
- return root.Nations.WHERE(root.region.name == "ASIA")(
- n_name=root.name, revenue=root.SUM(selected_lines.value)
- ).ORDER_BY(root.revenue.DESC())
-
-
-def pydough_impl_tpch_q6(root: UnqualifiedNode) -> UnqualifiedNode:
- """
- Creates an UnqualifiedNode for TPC-H query 6.
- """
- selected_lines = root.Lineitems.WHERE(
- (root.ship_date >= datetime.date(1994, 1, 1))
- & (root.ship_date < datetime.date(1995, 1, 1))
- & (0.05 <= root.discount)
- & (root.discount <= 0.07)
- & (root.quantity < 24)
- )(amt=root.extended_price * root.discount)
- return root.TPCH(revenue=root.SUM(selected_lines.amt))
-
-
-def pydough_impl_tpch_q7(root: UnqualifiedNode) -> UnqualifiedNode:
- """
- Creates an UnqualifiedNode for TPC-H query 7.
- """
- line_info = root.Lineitems(
- supp_nation=root.supplier.nation.name,
- cust_nation=root.order.customer.nation.name,
- l_year=root.YEAR(root.ship_date),
- volume=root.extended_price * (1 - root.discount),
- ).WHERE(
- (root.ship_date >= datetime.date(1995, 1, 1))
- & (root.ship_date <= datetime.date(1996, 12, 31))
- & (
- ((root.supp_nation == "FRANCE") & (root.cust_nation == "GERMANY"))
- | ((root.supp_nation == "GERMANY") & (root.cust_nation == "FRANCE"))
- )
- )
-
- return root.PARTITION(
- line_info, name="l", by=(root.supp_nation, root.cust_nation, root.l_year)
- )(
- root.supp_nation,
- root.cust_nation,
- root.l_year,
- revenue=root.SUM(root.l.volume),
- ).ORDER_BY(
- root.supp_nation.ASC(),
- root.cust_nation.ASC(),
- root.l_year.ASC(),
- )
-
-
-def pydough_impl_tpch_q8(root: UnqualifiedNode) -> UnqualifiedNode:
- """
- Creates an UnqualifiedNode for TPC-H query 8.
- """
- volume_data = (
- root.Nations.suppliers.supply_records.WHERE(
- root.part.part_type == "ECONOMY ANODIZED STEEL"
- )
- .lines(volume=root.extended_price * (1 - root.discount))
- .order(
- o_year=root.YEAR(root.order_date),
- volume=root.BACK(1).volume,
- brazil_volume=root.IFF(
- root.BACK(4).name == "BRAZIL", root.BACK(1).volume, 0
- ),
- )
- .WHERE(
- (root.order_date >= datetime.date(1995, 1, 1))
- & (root.order_date <= datetime.date(1996, 12, 31))
- & (root.customer.nation.region.name == "AMERICA")
- )
- )
-
- return root.PARTITION(volume_data, name="v", by=root.o_year)(
- o_year=root.o_year,
- mkt_share=root.SUM(root.v.brazil_volume) / root.SUM(root.v.volume),
- )
-
-
-def pydough_impl_tpch_q9(root: UnqualifiedNode) -> UnqualifiedNode:
- """
- Creates an UnqualifiedNode for TPC-H query 9, truncated to 10 rows.
- """
- selected_lines = root.Nations.suppliers.supply_records.WHERE(
- root.CONTAINS(root.part.name, "green")
- ).lines(
- nation=root.BACK(3).name,
- o_year=root.YEAR(root.order.order_date),
- value=root.extended_price * (1 - root.discount)
- - root.BACK(1).supplycost * root.quantity,
- )
- return root.PARTITION(selected_lines, name="l", by=(root.nation, root.o_year))(
- nation=root.nation, o_year=root.o_year, amount=root.SUM(root.l.value)
- ).TOP_K(
- 10,
- by=(root.nation.ASC(), root.o_year.DESC()),
- )
-
-
-def pydough_impl_tpch_q10(root: UnqualifiedNode) -> UnqualifiedNode:
- """
- Creates an UnqualifiedNode for TPC-H query 10.
- """
- selected_lines = root.orders.WHERE(
- (root.order_date >= datetime.date(1993, 10, 1))
- & (root.order_date < datetime.date(1994, 1, 1))
- ).lines.WHERE(root.return_flag == "R")(
- amt=root.extended_price * (1 - root.discount)
- )
- return root.Customers(
- c_custkey=root.key,
- c_name=root.name,
- revenue=root.SUM(selected_lines.amt),
- c_acctbal=root.acctbal,
- n_name=root.nation.name,
- c_address=root.address,
- c_phone=root.phone,
- c_comment=root.comment,
- ).TOP_K(20, by=(root.revenue.DESC(), root.c_custkey.ASC()))
-
-
-def pydough_impl_tpch_q11(root: UnqualifiedNode) -> UnqualifiedNode:
- """
- Creates an UnqualifiedNode for TPC-H query 11, truncated to 10 rows.
- """
- is_german_supplier = root.supplier.nation.name == "GERMANY"
- selected_records = root.PartSupp.WHERE(is_german_supplier)(
- metric=root.supplycost * root.availqty
- )
- return (
- root.TPCH(min_market_share=root.SUM(selected_records.metric) * 0.0001)
- .PARTITION(selected_records, name="ps", by=root.part_key)(
- ps_partkey=root.part_key, value=root.SUM(root.ps.metric)
- )
- .WHERE(root.value > root.BACK(1).min_market_share)
- .TOP_K(10, by=root.value.DESC())
- )
-
-
-def pydough_impl_tpch_q12(root: UnqualifiedNode) -> UnqualifiedNode:
- """
- Creates an UnqualifiedNode for TPC-H query 12.
- """
- selected_lines = root.Lineitems.WHERE(
- ((root.ship_mode == "MAIL") | (root.ship_mode == "SHIP"))
- & (root.ship_date < root.commit_date)
- & (root.commit_date < root.receipt_date)
- & (root.receipt_date >= datetime.date(1994, 1, 1))
- & (root.receipt_date < datetime.date(1995, 1, 1))
- )(
- is_high_priority=(root.order.order_priority == "1-URGENT")
- | (root.order.order_priority == "2-HIGH"),
- )
- return root.PARTITION(selected_lines, "l", by=root.ship_mode)(
- l_shipmode=root.ship_mode,
- high_line_count=root.SUM(root.l.is_high_priority),
- low_line_count=root.SUM(~(root.l.is_high_priority)),
- ).ORDER_BY(root.ship_mode.ASC())
-
-
-def pydough_impl_tpch_q13(root: UnqualifiedNode) -> UnqualifiedNode:
- """
- Creates an UnqualifiedNode for TPC-H query 13, truncated to 10 rows.
- """
- customer_info = root.Customers(
- root.key,
- num_non_special_orders=root.COUNT(
- root.orders.WHERE(~(root.LIKE(root.comment, "%special%requests%")))
- ),
- )
- return root.PARTITION(customer_info, name="custs", by=root.num_non_special_orders)(
- c_count=root.num_non_special_orders, custdist=root.COUNT(root.custs)
- ).TOP_K(10, by=(root.custdist.DESC(), root.c_count.DESC()))
-
-
-def pydough_impl_tpch_q14(root: UnqualifiedNode) -> UnqualifiedNode:
- """
- Creates an UnqualifiedNode for TPC-H query 14.
- """
- value = root.extended_price * (1 - root.discount)
- selected_lines = root.Lineitems.WHERE(
- (root.ship_date >= datetime.date(1995, 9, 1))
- & (root.ship_date < datetime.date(1995, 10, 1))
- )(
- value=value,
- promo_value=root.IFF(root.STARTSWITH(root.part.part_type, "PROMO"), value, 0),
- )
- return root.TPCH(
- promo_revenue=100.0
- * root.SUM(selected_lines.promo_value)
- / root.SUM(selected_lines.value)
- )
-
-
-def pydough_impl_tpch_q15(root: UnqualifiedNode) -> UnqualifiedNode:
- """
- Creates an UnqualifiedNode for TPC-H query 15.
- """
- selected_lines = root.lines.WHERE(
- (root.ship_date >= datetime.date(1996, 1, 1))
- & (root.ship_date < datetime.date(1996, 4, 1))
- )
- total = root.SUM(selected_lines.extended_price * (1 - selected_lines.discount))
- return (
- root.TPCH(
- max_revenue=root.MAX(root.Suppliers(total_revenue=total).total_revenue)
- )
- .Suppliers(
- s_suppkey=root.key,
- s_name=root.name,
- s_address=root.address,
- s_phone=root.phone,
- total_revenue=total,
- )
- .WHERE(root.total_revenue == root.BACK(1).max_revenue)
- .ORDER_BY(root.s_suppkey.ASC())
- )
-
-
-def pydough_impl_tpch_q16(root: UnqualifiedNode) -> UnqualifiedNode:
- """
- Creates an UnqualifiedNode for TPC-H query 16, truncated to 10 rows.
- """
- selected_records = (
- root.Parts.WHERE(
- (root.brand != "BRAND#45")
- & ~root.STARTSWITH(root.part_type, "MEDIUM POLISHED%")
- & root.ISIN(root.size, [49, 14, 23, 45, 19, 3, 36, 9])
- )
- .supply_records(
- p_brand=root.BACK(1).brand,
- p_type=root.BACK(1).part_type,
- p_size=root.BACK(1).size,
- ps_suppkey=root.supplier_key,
- )
- .WHERE(~root.LIKE(root.supplier.comment, "%Customer%Complaints%"))
- )
- return root.PARTITION(
- selected_records, name="ps", by=(root.p_brand, root.p_type, root.p_size)
- )(
- root.p_brand,
- root.p_type,
- root.p_size,
- supplier_count=root.NDISTINCT(root.ps.supplier_key),
- ).TOP_K(10, by=(root.supplier_count.DESC(), root.p_brand.ASC(), root.p_type.ASC()))
-
-
-def pydough_impl_tpch_q17(root: UnqualifiedNode) -> UnqualifiedNode:
- """
- Creates an UnqualifiedNode for TPC-H query 17.
- """
- selected_lines = root.Parts.WHERE(
- (root.brand == "Brand#23") & (root.container == "MED BOX")
- )(avg_quantity=root.AVG(root.lines.quantity)).lines.WHERE(
- root.quantity < 0.2 * root.BACK(1).avg_quantity
- )
- return root.TPCH(avg_yearly=root.SUM(selected_lines.extended_price) / 7.0)
-
-
-def pydough_impl_tpch_q18(root: UnqualifiedNode) -> UnqualifiedNode:
- """
- Creates an UnqualifiedNode for TPC-H query 18, truncated to 10 rows.
- """
- return (
- root.Orders(
- c_name=root.customer.name,
- c_custkey=root.customer.key,
- o_orderkey=root.key,
- o_orderdate=root.order_date,
- o_totalprice=root.total_price,
- total_quantity=root.SUM(root.lines.quantity),
- )
- .WHERE(root.total_quantity > 300)
- .TOP_K(10, by=(root.o_totalprice.DESC(), root.o_orderdate.ASC()))
- )
-
-
-def pydough_impl_tpch_q19(root: UnqualifiedNode) -> UnqualifiedNode:
- """
- Creates an UnqualifiedNode for TPC-H query 19.
- """
- selected_lines = root.Lineitems.WHERE(
- (root.ISIN(root.ship_mode, ("AIR", "AIR REG")))
- & (root.ship_instruct == "DELIVER IN PERSON")
- & (root.part.size >= 1)
- & (
- (
- (root.part.size <= 5)
- & (root.quantity >= 1)
- & (root.quantity <= 11)
- & root.ISIN(
- root.part.container,
- ("SM CASE", "SM BOX", "SM PACK", "SM PKG"),
- )
- & (root.part.brand == "Brand#12")
- )
- | (
- (root.part.size <= 10)
- & (root.quantity >= 10)
- & (root.quantity <= 20)
- & root.ISIN(
- root.part.container,
- ("MED BAG", "MED BOX", "MED PACK", "MED PKG"),
- )
- & (root.part.brand == "Brand#23")
- )
- | (
- (root.part.size <= 15)
- & (root.quantity >= 20)
- & (root.quantity <= 30)
- & root.ISIN(
- root.part.container,
- ("LG CASE", "LG BOX", "LG PACK", "LG PKG"),
- )
- & (root.part.brand == "Brand#34")
- )
- )
- )
- return root.TPCH(
- revenue=root.SUM(selected_lines.extended_price * (1 - selected_lines.discount))
- )
-
-
-def pydough_impl_tpch_q20(root: UnqualifiedNode) -> UnqualifiedNode:
- """
- Creates an UnqualifiedNode for TPC-H query 20, truncated to 10 rows.
- """
- selected_lines = root.lines.WHERE(
- (root.ship_date >= datetime.date(1994, 1, 1))
- & (root.ship_date < datetime.date(1995, 1, 1))
- )
-
- selected_part_supplied = root.supply_records.part.WHERE(
- root.STARTSWITH(root.name, "forest")
- & root.HAS(selected_lines)
- & (root.BACK(1).availqty > (root.SUM(selected_lines.quantity) * 0.5))
- )
-
- return (
- root.Suppliers(
- s_name=root.name,
- s_address=root.address,
- )
- .WHERE(
- (root.nation.name == "CANADA") & (root.COUNT(selected_part_supplied) > 0)
- )
- .TOP_K(10, by=root.s_name.ASC())
- )
-
-
-def pydough_impl_tpch_q21(root: UnqualifiedNode) -> UnqualifiedNode:
- """
- Creates an UnqualifiedNode for TPC-H query 21, truncated to 10 rows.
- """
- date_check = root.receipt_date > root.commit_date
- different_supplier = root.supplier_key != root.BACK(2).supplier_key
- waiting_entries = root.lines.WHERE(
- root.receipt_date > root.commit_date
- ).order.WHERE(
- (root.order_status == "F")
- & root.HAS(root.lines.WHERE(different_supplier))
- & root.HASNOT(root.lines.WHERE(different_supplier & date_check))
- )
- return root.Suppliers.WHERE(root.nation.name == "SAUDI ARABIA")(
- s_name=root.name,
- numwait=root.COUNT(waiting_entries),
- ).TOP_K(
- 10,
- by=(root.numwait.DESC(), root.s_name.ASC()),
- )
-
-
-def pydough_impl_tpch_q22(root: UnqualifiedNode) -> UnqualifiedNode:
- """
- Creates an UnqualifiedNode for TPC-H query 22.
- """
- selected_customers = root.Customers(cntry_code=root.phone[:2]).WHERE(
- root.ISIN(root.cntry_code, ("13", "31", "23", "29", "30", "18", "17"))
- & root.HASNOT(root.orders)
- )
- return root.TPCH(
- avg_balance=root.AVG(selected_customers.WHERE(root.acctbal > 0.0).acctbal)
- ).PARTITION(
- selected_customers.WHERE(root.acctbal > root.BACK(1).avg_balance),
- name="custs",
- by=root.cntry_code,
- )(
- root.cntry_code,
- num_custs=root.COUNT(root.custs),
- totacctbal=root.SUM(root.custs.acctbal),
- )
-
-
@pytest.mark.parametrize(
"impl, answer_tree_str",
[
pytest.param(
- pydough_impl_misc_01,
- """
-──┬─ TPCH
- ├─── TableCollection[Nations]
- └─┬─ Calc[nation_name=name, total_balance=SUM($1.acctbal)]
- └─┬─ AccessChild
- └─── SubCollection[customers]
-""",
- id="misc_01",
- ),
- pytest.param(
- pydough_impl_misc_02,
- """
-──┬─ TPCH
- └─┬─ TableCollection[Nations]
- ├─── SubCollection[customers]
- └─┬─ Calc[name=LOWER(name), nation_name=BACK(1).name, total_1994=SUM($1.extended_price - ($1.tax / 2)), total_1995=SUM($2.extended_price - ($2.tax / 2))]
- ├─┬─ AccessChild
- │ ├─── SubCollection[orders]
- │ └─┬─ Where[(order_date >= datetime.date(1994, 1, 1)) & (order_date < datetime.date(1995, 1, 1))]
- │ └─── SubCollection[lines]
- └─┬─ AccessChild
- ├─── SubCollection[orders]
- └─┬─ Where[(order_date >= datetime.date(1995, 1, 1)) & (order_date < datetime.date(1996, 1, 1))]
- └─── SubCollection[lines]
-""",
- id="misc_02",
- ),
- pytest.param(
- pydough_impl_misc_03,
+ partition_as_child,
"""
┌─── TPCH
-├─┬─ Calc[avg_n_parts=AVG($1.n_parts)]
+├─┬─ Calculate[avg_n_parts=AVG($1.n_parts)]
│ └─┬─ AccessChild
│ ├─┬─ Partition[name='p', by=size]
│ │ └─┬─ AccessChild
│ │ └─── TableCollection[Parts]
-│ └─┬─ Calc[n_parts=COUNT($1)]
+│ └─┬─ Calculate[n_parts=COUNT($1)]
│ └─┬─ AccessChild
│ └─── PartitionChild[p]
-└─┬─ Calc[n_parts=COUNT($1)]
+└─┬─ Calculate[n_parts=COUNT($1)]
└─┬─ AccessChild
├─┬─ Partition[name='p', by=size]
│ └─┬─ AccessChild
│ └─── TableCollection[Parts]
- ├─┬─ Calc[n_parts=COUNT($1)]
+ ├─┬─ Calculate[n_parts=COUNT($1)]
│ └─┬─ AccessChild
│ └─── PartitionChild[p]
- └─── Where[n_parts > BACK(1).avg_n_parts]
+ └─── Where[n_parts > avg_n_parts]
""",
- id="misc_03",
+ id="partition_as_child",
),
pytest.param(
- pydough_impl_tpch_q1,
+ impl_tpch_q1,
"""
──┬─ TPCH
├─┬─ Partition[name='l', by=(return_flag, status)]
│ └─┬─ AccessChild
│ ├─── TableCollection[Lineitems]
│ └─── Where[ship_date <= datetime.date(1998, 12, 1)]
- ├─┬─ Calc[l_returnflag=return_flag, l_linestatus=status, sum_qty=SUM($1.quantity), sum_base_price=SUM($1.extended_price), sum_disc_price=SUM($1.extended_price * (1 - $1.discount)), sum_charge=SUM(($1.extended_price * (1 - $1.discount)) * (1 + $1.tax)), avg_qty=AVG($1.quantity), avg_price=AVG($1.extended_price), avg_disc=AVG($1.discount), count_order=COUNT($1)]
+ ├─┬─ Calculate[L_RETURNFLAG=return_flag, L_LINESTATUS=status, SUM_QTY=SUM($1.quantity), SUM_BASE_PRICE=SUM($1.extended_price), SUM_DISC_PRICE=SUM($1.extended_price * (1 - $1.discount)), SUM_CHARGE=SUM(($1.extended_price * (1 - $1.discount)) * (1 + $1.tax)), AVG_QTY=AVG($1.quantity), AVG_PRICE=AVG($1.extended_price), AVG_DISC=AVG($1.discount), COUNT_ORDER=COUNT($1)]
│ └─┬─ AccessChild
│ └─── PartitionChild[l]
- └─── OrderBy[return_flag.ASC(na_pos='first'), status.ASC(na_pos='first')]
+ └─── OrderBy[L_RETURNFLAG.ASC(na_pos='first'), L_LINESTATUS.ASC(na_pos='first')]
""",
id="tpch_q1",
),
pytest.param(
- pydough_impl_tpch_q2,
+ impl_tpch_q2,
"""
──┬─ TPCH
├─┬─ Partition[name='p', by=key]
│ └─┬─ AccessChild
│ ├─── TableCollection[Nations]
+ │ ├─── Calculate[n_name=name]
│ └─┬─ Where[$1.name == 'EUROPE']
│ ├─┬─ AccessChild
│ │ └─── SubCollection[region]
- │ └─┬─ SubCollection[suppliers]
- │ └─┬─ SubCollection[supply_records]
+ │ ├─── SubCollection[suppliers]
+ │ └─┬─ Calculate[s_acctbal=account_balance, s_name=name, s_address=address, s_phone=phone, s_comment=comment]
+ │ ├─── SubCollection[supply_records]
+ │ └─┬─ Calculate[supplycost=supplycost]
│ ├─── SubCollection[part]
- │ ├─── Calc[s_acctbal=BACK(2).account_balance, s_name=BACK(2).name, n_name=BACK(3).name, s_address=BACK(2).address, s_phone=BACK(2).phone, s_comment=BACK(2).comment, supplycost=BACK(1).supplycost]
│ └─── Where[ENDSWITH(part_type, 'BRASS') & (size == 15)]
- └─┬─ Calc[best_cost=MIN($1.supplycost)]
+ └─┬─ Calculate[best_cost=MIN($1.supplycost)]
├─┬─ AccessChild
│ └─── PartitionChild[p]
├─── PartitionChild[p]
- ├─── Where[supplycost == BACK(1).best_cost]
- ├─── Calc[s_acctbal=s_acctbal, s_name=s_name, n_name=n_name, p_partkey=key, p_mfgr=manufacturer, s_address=s_address, s_phone=s_phone, s_comment=s_comment]
- └─── TopK[10, s_acctbal.DESC(na_pos='last'), n_name.ASC(na_pos='first'), s_name.ASC(na_pos='first'), p_partkey.ASC(na_pos='first')]
+ ├─── Where[(supplycost == best_cost) & ENDSWITH(part_type, 'BRASS') & (size == 15)]
+ ├─── Calculate[S_ACCTBAL=s_acctbal, S_NAME=s_name, N_NAME=n_name, P_PARTKEY=key, P_MFGR=manufacturer, S_ADDRESS=s_address, S_PHONE=s_phone, S_COMMENT=s_comment]
+ └─── TopK[10, S_ACCTBAL.DESC(na_pos='last'), N_NAME.ASC(na_pos='first'), S_NAME.ASC(na_pos='first'), P_PARTKEY.ASC(na_pos='first')]
""",
id="tpch_q2",
),
pytest.param(
- pydough_impl_tpch_q3,
+ impl_tpch_q3,
"""
──┬─ TPCH
├─┬─ Partition[name='l', by=(order_key, order_date, ship_priority)]
│ └─┬─ AccessChild
│ ├─── TableCollection[Orders]
+ │ ├─── Calculate[order_date=order_date, ship_priority=ship_priority]
│ └─┬─ Where[($1.mktsegment == 'BUILDING') & (order_date < datetime.date(1995, 3, 15))]
│ ├─┬─ AccessChild
│ │ └─── SubCollection[customer]
│ ├─── SubCollection[lines]
- │ ├─── Where[ship_date > datetime.date(1995, 3, 15)]
- │ └─── Calc[order_date=BACK(1).order_date, ship_priority=BACK(1).ship_priority]
- ├─┬─ Calc[l_orderkey=order_key, revenue=SUM($1.extended_price * (1 - $1.discount)), o_orderdate=order_date, o_shippriority=ship_priority]
+ │ └─── Where[ship_date > datetime.date(1995, 3, 15)]
+ ├─┬─ Calculate[L_ORDERKEY=order_key, REVENUE=SUM($1.extended_price * (1 - $1.discount)), O_ORDERDATE=order_date, O_SHIPPRIORITY=ship_priority]
│ └─┬─ AccessChild
│ └─── PartitionChild[l]
- └─── TopK[10, revenue.DESC(na_pos='last'), o_orderdate.ASC(na_pos='first'), l_orderkey.ASC(na_pos='first')]
+ └─── TopK[10, REVENUE.DESC(na_pos='last'), O_ORDERDATE.ASC(na_pos='first'), L_ORDERKEY.ASC(na_pos='first')]
""",
id="tpch_q3",
),
pytest.param(
- pydough_impl_tpch_q4,
+ impl_tpch_q4,
"""
──┬─ TPCH
├─┬─ Partition[name='o', by=order_priority]
@@ -727,56 +144,57 @@ def pydough_impl_tpch_q22(root: UnqualifiedNode) -> UnqualifiedNode:
│ └─┬─ AccessChild
│ ├─── SubCollection[lines]
│ └─── Where[commit_date < receipt_date]
- ├─┬─ Calc[o_orderpriority=order_priority, order_count=COUNT($1)]
+ ├─┬─ Calculate[O_ORDERPRIORITY=order_priority, ORDER_COUNT=COUNT($1)]
│ └─┬─ AccessChild
│ └─── PartitionChild[o]
- └─── OrderBy[order_priority.ASC(na_pos='first')]
+ └─── OrderBy[O_ORDERPRIORITY.ASC(na_pos='first')]
""",
id="tpch_q4",
),
pytest.param(
- pydough_impl_tpch_q5,
+ impl_tpch_q5,
"""
──┬─ TPCH
├─── TableCollection[Nations]
+ ├─── Calculate[nation_name=name]
├─┬─ Where[$1.name == 'ASIA']
│ └─┬─ AccessChild
│ └─── SubCollection[region]
- ├─┬─ Calc[n_name=name, revenue=SUM($1.value)]
+ ├─┬─ Calculate[N_NAME=name, REVENUE=SUM($1.value)]
│ └─┬─ AccessChild
│ └─┬─ SubCollection[customers]
│ ├─── SubCollection[orders]
│ └─┬─ Where[(order_date >= datetime.date(1994, 1, 1)) & (order_date < datetime.date(1995, 1, 1))]
│ ├─── SubCollection[lines]
- │ ├─┬─ Where[$1.name == BACK(3).name]
+ │ ├─┬─ Where[$1.name == nation_name]
│ │ └─┬─ AccessChild
│ │ └─┬─ SubCollection[supplier]
│ │ └─── SubCollection[nation]
- │ └─── Calc[value=extended_price * (1 - discount)]
- └─── OrderBy[revenue.DESC(na_pos='last')]
+ │ └─── Calculate[value=extended_price * (1 - discount)]
+ └─── OrderBy[REVENUE.DESC(na_pos='last')]
""",
id="tpch_q5",
),
pytest.param(
- pydough_impl_tpch_q6,
+ impl_tpch_q6,
"""
┌─── TPCH
-└─┬─ Calc[revenue=SUM($1.amt)]
+└─┬─ Calculate[REVENUE=SUM($1.amt)]
└─┬─ AccessChild
├─── TableCollection[Lineitems]
├─── Where[(ship_date >= datetime.date(1994, 1, 1)) & (ship_date < datetime.date(1995, 1, 1)) & (discount >= 0.05) & (discount <= 0.07) & (quantity < 24)]
- └─── Calc[amt=extended_price * discount]
+ └─── Calculate[amt=extended_price * discount]
""",
id="tpch_q6",
),
pytest.param(
- pydough_impl_tpch_q7,
+ impl_tpch_q7,
"""
──┬─ TPCH
├─┬─ Partition[name='l', by=(supp_nation, cust_nation, l_year)]
│ └─┬─ AccessChild
│ ├─── TableCollection[Lineitems]
- │ ├─┬─ Calc[supp_nation=$1.name, cust_nation=$2.name, l_year=YEAR(ship_date), volume=extended_price * (1 - discount)]
+ │ ├─┬─ Calculate[supp_nation=$1.name, cust_nation=$2.name, l_year=YEAR(ship_date), volume=extended_price * (1 - discount)]
│ │ ├─┬─ AccessChild
│ │ │ └─┬─ SubCollection[supplier]
│ │ │ └─── SubCollection[nation]
@@ -785,93 +203,96 @@ def pydough_impl_tpch_q22(root: UnqualifiedNode) -> UnqualifiedNode:
│ │ └─┬─ SubCollection[customer]
│ │ └─── SubCollection[nation]
│ └─── Where[(ship_date >= datetime.date(1995, 1, 1)) & (ship_date <= datetime.date(1996, 12, 31)) & (((supp_nation == 'FRANCE') & (cust_nation == 'GERMANY')) | ((supp_nation == 'GERMANY') & (cust_nation == 'FRANCE')))]
- ├─┬─ Calc[supp_nation=supp_nation, cust_nation=cust_nation, l_year=l_year, revenue=SUM($1.volume)]
+ ├─┬─ Calculate[SUPP_NATION=supp_nation, CUST_NATION=cust_nation, L_YEAR=l_year, REVENUE=SUM($1.volume)]
│ └─┬─ AccessChild
│ └─── PartitionChild[l]
- └─── OrderBy[supp_nation.ASC(na_pos='first'), cust_nation.ASC(na_pos='first'), l_year.ASC(na_pos='first')]
+ └─── OrderBy[SUPP_NATION.ASC(na_pos='first'), CUST_NATION.ASC(na_pos='first'), L_YEAR.ASC(na_pos='first')]
""",
id="tpch_q7",
),
pytest.param(
- pydough_impl_tpch_q8,
+ impl_tpch_q8,
"""
──┬─ TPCH
├─┬─ Partition[name='v', by=o_year]
│ └─┬─ AccessChild
- │ └─┬─ TableCollection[Nations]
+ │ ├─── TableCollection[Nations]
+ │ └─┬─ Calculate[nation_name=name]
│ └─┬─ SubCollection[suppliers]
│ ├─── SubCollection[supply_records]
│ └─┬─ Where[$1.part_type == 'ECONOMY ANODIZED STEEL']
│ ├─┬─ AccessChild
│ │ └─── SubCollection[part]
│ ├─── SubCollection[lines]
- │ └─┬─ Calc[volume=extended_price * (1 - discount)]
+ │ └─┬─ Calculate[volume=extended_price * (1 - discount)]
│ ├─── SubCollection[order]
- │ ├─── Calc[o_year=YEAR(order_date), volume=BACK(1).volume, brazil_volume=IFF(BACK(4).name == 'BRAZIL', BACK(1).volume, 0)]
+ │ ├─── Calculate[o_year=YEAR(order_date), brazil_volume=IFF(nation_name == 'BRAZIL', volume, 0)]
│ └─┬─ Where[(order_date >= datetime.date(1995, 1, 1)) & (order_date <= datetime.date(1996, 12, 31)) & ($1.name == 'AMERICA')]
│ └─┬─ AccessChild
│ └─┬─ SubCollection[customer]
│ └─┬─ SubCollection[nation]
│ └─── SubCollection[region]
- └─┬─ Calc[o_year=o_year, mkt_share=SUM($1.brazil_volume) / SUM($1.volume)]
+ └─┬─ Calculate[O_YEAR=o_year, MKT_SHARE=SUM($1.brazil_volume) / SUM($1.volume)]
└─┬─ AccessChild
└─── PartitionChild[v]
""",
id="tpch_q8",
),
pytest.param(
- pydough_impl_tpch_q9,
+ impl_tpch_q9,
"""
──┬─ TPCH
- ├─┬─ Partition[name='l', by=(nation, o_year)]
+ ├─┬─ Partition[name='l', by=(nation_name, o_year)]
│ └─┬─ AccessChild
- │ └─┬─ TableCollection[Nations]
+ │ ├─── TableCollection[Nations]
+ │ └─┬─ Calculate[nation_name=name]
│ └─┬─ SubCollection[suppliers]
│ ├─── SubCollection[supply_records]
+ │ ├─── Calculate[supplycost=supplycost]
│ └─┬─ Where[CONTAINS($1.name, 'green')]
│ ├─┬─ AccessChild
│ │ └─── SubCollection[part]
│ ├─── SubCollection[lines]
- │ └─┬─ Calc[nation=BACK(3).name, o_year=YEAR($1.order_date), value=(extended_price * (1 - discount)) - (BACK(1).supplycost * quantity)]
+ │ └─┬─ Calculate[o_year=YEAR($1.order_date), value=(extended_price * (1 - discount)) - (supplycost * quantity)]
│ └─┬─ AccessChild
│ └─── SubCollection[order]
- ├─┬─ Calc[nation=nation, o_year=o_year, amount=SUM($1.value)]
+ ├─┬─ Calculate[NATION=nation_name, O_YEAR=o_year, AMOUNT=SUM($1.value)]
│ └─┬─ AccessChild
│ └─── PartitionChild[l]
- └─── TopK[10, nation.ASC(na_pos='first'), o_year.DESC(na_pos='last')]
+ └─── TopK[10, NATION.ASC(na_pos='first'), O_YEAR.DESC(na_pos='last')]
""",
id="tpch_q9",
),
pytest.param(
- pydough_impl_tpch_q10,
+ impl_tpch_q10,
"""
──┬─ TPCH
├─── TableCollection[Customers]
- ├─┬─ Calc[c_custkey=key, c_name=name, revenue=SUM($1.amt), c_acctbal=acctbal, n_name=$2.name, c_address=address, c_phone=phone, c_comment=comment]
+ ├─┬─ Calculate[C_CUSTKEY=key, C_NAME=name, REVENUE=SUM($1.amt), C_ACCTBAL=acctbal, N_NAME=$2.name, C_ADDRESS=address, C_PHONE=phone, C_COMMENT=comment]
│ ├─┬─ AccessChild
│ │ ├─── SubCollection[orders]
│ │ └─┬─ Where[(order_date >= datetime.date(1993, 10, 1)) & (order_date < datetime.date(1994, 1, 1))]
│ │ ├─── SubCollection[lines]
│ │ ├─── Where[return_flag == 'R']
- │ │ └─── Calc[amt=extended_price * (1 - discount)]
+ │ │ └─── Calculate[amt=extended_price * (1 - discount)]
│ └─┬─ AccessChild
│ └─── SubCollection[nation]
- └─── TopK[20, revenue.DESC(na_pos='last'), c_custkey.ASC(na_pos='first')]
+ └─── TopK[20, REVENUE.DESC(na_pos='last'), C_CUSTKEY.ASC(na_pos='first')]
""",
id="tpch_q10",
),
pytest.param(
- pydough_impl_tpch_q11,
+ impl_tpch_q11,
"""
┌─── TPCH
-└─┬─ Calc[min_market_share=SUM($1.metric) * 0.0001]
+└─┬─ Calculate[min_market_share=SUM($1.metric) * 0.0001]
├─┬─ AccessChild
│ ├─── TableCollection[PartSupp]
│ ├─┬─ Where[$1.name == 'GERMANY']
│ │ └─┬─ AccessChild
│ │ └─┬─ SubCollection[supplier]
│ │ └─── SubCollection[nation]
- │ └─── Calc[metric=supplycost * availqty]
+ │ └─── Calculate[metric=supplycost * availqty]
├─┬─ Partition[name='ps', by=part_key]
│ └─┬─ AccessChild
│ ├─── TableCollection[PartSupp]
@@ -879,142 +300,142 @@ def pydough_impl_tpch_q22(root: UnqualifiedNode) -> UnqualifiedNode:
│ │ └─┬─ AccessChild
│ │ └─┬─ SubCollection[supplier]
│ │ └─── SubCollection[nation]
- │ └─── Calc[metric=supplycost * availqty]
- ├─┬─ Calc[ps_partkey=part_key, value=SUM($1.metric)]
+ │ └─── Calculate[metric=supplycost * availqty]
+ ├─┬─ Calculate[PS_PARTKEY=part_key, VALUE=SUM($1.metric)]
│ └─┬─ AccessChild
│ └─── PartitionChild[ps]
- ├─── Where[value > BACK(1).min_market_share]
- └─── TopK[10, value.DESC(na_pos='last')]
+ ├─── Where[VALUE > min_market_share]
+ └─── TopK[10, VALUE.DESC(na_pos='last')]
""",
id="tpch_q11",
),
pytest.param(
- pydough_impl_tpch_q12,
+ impl_tpch_q12,
"""
──┬─ TPCH
├─┬─ Partition[name='l', by=ship_mode]
│ └─┬─ AccessChild
│ ├─── TableCollection[Lineitems]
│ ├─── Where[((ship_mode == 'MAIL') | (ship_mode == 'SHIP')) & (ship_date < commit_date) & (commit_date < receipt_date) & (receipt_date >= datetime.date(1994, 1, 1)) & (receipt_date < datetime.date(1995, 1, 1))]
- │ └─┬─ Calc[is_high_priority=($1.order_priority == '1-URGENT') | ($1.order_priority == '2-HIGH')]
+ │ └─┬─ Calculate[is_high_priority=($1.order_priority == '1-URGENT') | ($1.order_priority == '2-HIGH')]
│ └─┬─ AccessChild
│ └─── SubCollection[order]
- ├─┬─ Calc[l_shipmode=ship_mode, high_line_count=SUM($1.is_high_priority), low_line_count=SUM(NOT($1.is_high_priority))]
+ ├─┬─ Calculate[L_SHIPMODE=ship_mode, HIGH_LINE_COUNT=SUM($1.is_high_priority), LOW_LINE_COUNT=SUM(NOT($1.is_high_priority))]
│ └─┬─ AccessChild
│ └─── PartitionChild[l]
- └─── OrderBy[ship_mode.ASC(na_pos='first')]
+ └─── OrderBy[L_SHIPMODE.ASC(na_pos='first')]
""",
id="tpch_q12",
),
pytest.param(
- pydough_impl_tpch_q13,
+ impl_tpch_q13,
"""
──┬─ TPCH
├─┬─ Partition[name='custs', by=num_non_special_orders]
│ └─┬─ AccessChild
│ ├─── TableCollection[Customers]
- │ └─┬─ Calc[key=key, num_non_special_orders=COUNT($1)]
+ │ └─┬─ Calculate[num_non_special_orders=COUNT($1)]
│ └─┬─ AccessChild
│ ├─── SubCollection[orders]
│ └─── Where[NOT(LIKE(comment, '%special%requests%'))]
- ├─┬─ Calc[c_count=num_non_special_orders, custdist=COUNT($1)]
+ ├─┬─ Calculate[C_COUNT=num_non_special_orders, CUSTDIST=COUNT($1)]
│ └─┬─ AccessChild
│ └─── PartitionChild[custs]
- └─── TopK[10, custdist.DESC(na_pos='last'), c_count.DESC(na_pos='last')]
+ └─── TopK[10, CUSTDIST.DESC(na_pos='last'), C_COUNT.DESC(na_pos='last')]
""",
id="tpch_q13",
),
pytest.param(
- pydough_impl_tpch_q14,
+ impl_tpch_q14,
"""
┌─── TPCH
-└─┬─ Calc[promo_revenue=(100.0 * SUM($1.promo_value)) / SUM($1.value)]
+└─┬─ Calculate[PROMO_REVENUE=(100.0 * SUM($1.promo_value)) / SUM($1.value)]
└─┬─ AccessChild
├─── TableCollection[Lineitems]
├─── Where[(ship_date >= datetime.date(1995, 9, 1)) & (ship_date < datetime.date(1995, 10, 1))]
- └─┬─ Calc[value=extended_price * (1 - discount), promo_value=IFF(STARTSWITH($1.part_type, 'PROMO'), extended_price * (1 - discount), 0)]
+ └─┬─ Calculate[value=extended_price * (1 - discount), promo_value=IFF(STARTSWITH($1.part_type, 'PROMO'), extended_price * (1 - discount), 0)]
└─┬─ AccessChild
└─── SubCollection[part]
""",
id="tpch_q14",
),
pytest.param(
- pydough_impl_tpch_q15,
+ impl_tpch_q15,
"""
┌─── TPCH
-└─┬─ Calc[max_revenue=MAX($1.total_revenue)]
+└─┬─ Calculate[max_revenue=MAX($1.total_revenue)]
├─┬─ AccessChild
│ ├─── TableCollection[Suppliers]
- │ └─┬─ Calc[total_revenue=SUM($1.extended_price * (1 - $1.discount))]
+ │ └─┬─ Calculate[total_revenue=SUM($1.extended_price * (1 - $1.discount))]
│ └─┬─ AccessChild
│ ├─── SubCollection[lines]
│ └─── Where[(ship_date >= datetime.date(1996, 1, 1)) & (ship_date < datetime.date(1996, 4, 1))]
├─── TableCollection[Suppliers]
- ├─┬─ Calc[s_suppkey=key, s_name=name, s_address=address, s_phone=phone, total_revenue=SUM($1.extended_price * (1 - $1.discount))]
+ ├─┬─ Calculate[S_SUPPKEY=key, S_NAME=name, S_ADDRESS=address, S_PHONE=phone, TOTAL_REVENUE=SUM($1.extended_price * (1 - $1.discount))]
│ └─┬─ AccessChild
│ ├─── SubCollection[lines]
│ └─── Where[(ship_date >= datetime.date(1996, 1, 1)) & (ship_date < datetime.date(1996, 4, 1))]
- ├─── Where[total_revenue == BACK(1).max_revenue]
- └─── OrderBy[s_suppkey.ASC(na_pos='first')]
+ ├─── Where[TOTAL_REVENUE == max_revenue]
+ └─── OrderBy[S_SUPPKEY.ASC(na_pos='first')]
""",
id="tpch_q15",
),
pytest.param(
- pydough_impl_tpch_q16,
+ impl_tpch_q16,
"""
──┬─ TPCH
├─┬─ Partition[name='ps', by=(p_brand, p_type, p_size)]
│ └─┬─ AccessChild
│ ├─── TableCollection[Parts]
- │ └─┬─ Where[(brand != 'BRAND#45') & NOT(STARTSWITH(part_type, 'MEDIUM POLISHED%')) & ISIN(size, [49, 14, 23, 45, 19, 3, 36, 9])]
+ │ ├─── Where[(brand != 'BRAND#45') & NOT(STARTSWITH(part_type, 'MEDIUM POLISHED%')) & ISIN(size, [49, 14, 23, 45, 19, 3, 36, 9])]
+ │ └─┬─ Calculate[p_brand=brand, p_type=part_type, p_size=size]
│ ├─── SubCollection[supply_records]
- │ ├─── Calc[p_brand=BACK(1).brand, p_type=BACK(1).part_type, p_size=BACK(1).size, ps_suppkey=supplier_key]
│ └─┬─ Where[NOT(LIKE($1.comment, '%Customer%Complaints%'))]
│ └─┬─ AccessChild
│ └─── SubCollection[supplier]
- ├─┬─ Calc[p_brand=p_brand, p_type=p_type, p_size=p_size, supplier_count=NDISTINCT($1.supplier_key)]
+ ├─┬─ Calculate[P_BRAND=p_brand, P_TYPE=p_type, P_SIZE=p_size, SUPPLIER_COUNT=NDISTINCT($1.supplier_key)]
│ └─┬─ AccessChild
│ └─── PartitionChild[ps]
- └─── TopK[10, supplier_count.DESC(na_pos='last'), p_brand.ASC(na_pos='first'), p_type.ASC(na_pos='first')]
+ └─── TopK[10, SUPPLIER_COUNT.DESC(na_pos='last'), P_BRAND.ASC(na_pos='first'), P_TYPE.ASC(na_pos='first'), P_SIZE.ASC(na_pos='first')]
""",
id="tpch_q16",
),
pytest.param(
- pydough_impl_tpch_q17,
+ impl_tpch_q17,
"""
┌─── TPCH
-└─┬─ Calc[avg_yearly=SUM($1.extended_price) / 7.0]
+└─┬─ Calculate[AVG_YEARLY=SUM($1.extended_price) / 7.0]
└─┬─ AccessChild
├─── TableCollection[Parts]
├─── Where[(brand == 'Brand#23') & (container == 'MED BOX')]
- └─┬─ Calc[avg_quantity=AVG($1.quantity)]
+ └─┬─ Calculate[part_avg_quantity=AVG($1.quantity)]
├─┬─ AccessChild
│ └─── SubCollection[lines]
├─── SubCollection[lines]
- └─── Where[quantity < (0.2 * BACK(1).avg_quantity)]
+ └─── Where[quantity < (0.2 * part_avg_quantity)]
""",
id="tpch_q17",
),
pytest.param(
- pydough_impl_tpch_q18,
+ impl_tpch_q18,
"""
──┬─ TPCH
├─── TableCollection[Orders]
- ├─┬─ Calc[c_name=$1.name, c_custkey=$1.key, o_orderkey=key, o_orderdate=order_date, o_totalprice=total_price, total_quantity=SUM($2.quantity)]
+ ├─┬─ Calculate[C_NAME=$1.name, C_CUSTKEY=$1.key, O_ORDERKEY=key, O_ORDERDATE=order_date, O_TOTALPRICE=total_price, TOTAL_QUANTITY=SUM($2.quantity)]
│ ├─┬─ AccessChild
│ │ └─── SubCollection[customer]
│ └─┬─ AccessChild
│ └─── SubCollection[lines]
- ├─── Where[total_quantity > 300]
- └─── TopK[10, o_totalprice.DESC(na_pos='last'), o_orderdate.ASC(na_pos='first')]
+ ├─── Where[TOTAL_QUANTITY > 300]
+ └─── TopK[10, O_TOTALPRICE.DESC(na_pos='last'), O_ORDERDATE.ASC(na_pos='first')]
""",
id="tpch_q18",
),
pytest.param(
- pydough_impl_tpch_q19,
+ impl_tpch_q19,
"""
┌─── TPCH
-└─┬─ Calc[revenue=SUM($1.extended_price * (1 - $1.discount))]
+└─┬─ Calculate[REVENUE=SUM($1.extended_price * (1 - $1.discount))]
└─┬─ AccessChild
├─── TableCollection[Lineitems]
└─┬─ Where[ISIN(ship_mode, ['AIR', 'AIR REG']) & (ship_instruct == 'DELIVER IN PERSON') & ($1.size >= 1) & (((($1.size <= 5) & (quantity >= 1) & (quantity <= 11) & ISIN($1.container, ['SM CASE', 'SM BOX', 'SM PACK', 'SM PKG']) & ($1.brand == 'Brand#12')) | (($1.size <= 10) & (quantity >= 10) & (quantity <= 20) & ISIN($1.container, ['MED BAG', 'MED BOX', 'MED PACK', 'MED PKG']) & ($1.brand == 'Brand#23'))) | (($1.size <= 15) & (quantity >= 20) & (quantity <= 30) & ISIN($1.container, ['LG CASE', 'LG BOX', 'LG PACK', 'LG PKG']) & ($1.brand == 'Brand#34')))]
@@ -1024,79 +445,80 @@ def pydough_impl_tpch_q22(root: UnqualifiedNode) -> UnqualifiedNode:
id="tpch_q19",
),
pytest.param(
- pydough_impl_tpch_q20,
+ impl_tpch_q20,
"""
──┬─ TPCH
├─── TableCollection[Suppliers]
- ├─── Calc[s_name=name, s_address=address]
- ├─┬─ Where[($1.name == 'CANADA') & (COUNT($2) > 0)]
+ ├─── Calculate[S_NAME=name, S_ADDRESS=address]
+ ├─┬─ Where[(($1.name == 'CANADA') & COUNT($2)) > 0]
│ ├─┬─ AccessChild
│ │ └─── SubCollection[nation]
│ └─┬─ AccessChild
- │ └─┬─ SubCollection[supply_records]
+ │ ├─── SubCollection[supply_records]
+ │ └─┬─ Calculate[availqty=availqty]
│ ├─── SubCollection[part]
- │ └─┬─ Where[STARTSWITH(name, 'forest') & HAS($1) & (BACK(1).availqty > (SUM($1.quantity) * 0.5))]
+ │ └─┬─ Where[STARTSWITH(name, 'forest') & (availqty > (SUM($1.quantity) * 0.5))]
│ └─┬─ AccessChild
│ ├─── SubCollection[lines]
│ └─── Where[(ship_date >= datetime.date(1994, 1, 1)) & (ship_date < datetime.date(1995, 1, 1))]
- └─── TopK[10, s_name.ASC(na_pos='first')]
+ └─── TopK[10, S_NAME.ASC(na_pos='first')]
""",
id="tpch_q20",
),
pytest.param(
- pydough_impl_tpch_q21,
+ impl_tpch_q21,
"""
──┬─ TPCH
├─── TableCollection[Suppliers]
├─┬─ Where[$1.name == 'SAUDI ARABIA']
│ └─┬─ AccessChild
│ └─── SubCollection[nation]
- ├─┬─ Calc[s_name=name, numwait=COUNT($1)]
+ ├─┬─ Calculate[S_NAME=name, NUMWAIT=COUNT($1)]
│ └─┬─ AccessChild
│ ├─── SubCollection[lines]
+ │ ├─── Calculate[original_key=supplier_key]
│ └─┬─ Where[receipt_date > commit_date]
│ ├─── SubCollection[order]
│ └─┬─ Where[(order_status == 'F') & HAS($1) & HASNOT($2)]
│ ├─┬─ AccessChild
│ │ ├─── SubCollection[lines]
- │ │ └─── Where[supplier_key != BACK(2).supplier_key]
+ │ │ └─── Where[supplier_key != original_key]
│ └─┬─ AccessChild
│ ├─── SubCollection[lines]
- │ └─── Where[(supplier_key != BACK(2).supplier_key) & (receipt_date > commit_date)]
- └─── TopK[10, numwait.DESC(na_pos='last'), s_name.ASC(na_pos='first')]
+ │ └─── Where[(supplier_key != original_key) & (receipt_date > commit_date)]
+ └─── TopK[10, NUMWAIT.DESC(na_pos='last'), S_NAME.ASC(na_pos='first')]
""",
id="tpch_q21",
),
pytest.param(
- pydough_impl_tpch_q22,
+ impl_tpch_q22,
"""
┌─── TPCH
-└─┬─ Calc[avg_balance=AVG($1.acctbal)]
+└─┬─ Calculate[global_avg_balance=AVG($1.acctbal)]
├─┬─ AccessChild
│ ├─── TableCollection[Customers]
- │ ├─── Calc[cntry_code=SLICE(phone, None, 2, None)]
- │ ├─┬─ Where[ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17']) & HASNOT($1)]
- │ │ └─┬─ AccessChild
- │ │ └─── SubCollection[orders]
+ │ ├─── Calculate[cntry_code=SLICE(phone, None, 2, None)]
+ │ ├─── Where[ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17'])]
│ └─── Where[acctbal > 0.0]
├─┬─ Partition[name='custs', by=cntry_code]
│ └─┬─ AccessChild
│ ├─── TableCollection[Customers]
- │ ├─── Calc[cntry_code=SLICE(phone, None, 2, None)]
- │ ├─┬─ Where[ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17']) & HASNOT($1)]
- │ │ └─┬─ AccessChild
- │ │ └─── SubCollection[orders]
- │ └─── Where[acctbal > BACK(1).avg_balance]
- └─┬─ Calc[cntry_code=cntry_code, num_custs=COUNT($1), totacctbal=SUM($1.acctbal)]
- └─┬─ AccessChild
- └─── PartitionChild[custs]
+ │ ├─── Calculate[cntry_code=SLICE(phone, None, 2, None)]
+ │ ├─── Where[ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17'])]
+ │ └─┬─ Where[(acctbal > global_avg_balance) & (COUNT($1) == 0)]
+ │ └─┬─ AccessChild
+ │ └─── SubCollection[orders]
+ ├─┬─ Calculate[CNTRY_CODE=cntry_code, NUM_CUSTS=COUNT($1), TOTACCTBAL=SUM($1.acctbal)]
+ │ └─┬─ AccessChild
+ │ └─── PartitionChild[custs]
+ └─── OrderBy[CNTRY_CODE.ASC(na_pos='first')]
""",
id="tpch_q22",
),
],
)
def test_qualify_node_to_ast_string(
- impl: Callable[[UnqualifiedNode], UnqualifiedNode],
+ impl: Callable[[], UnqualifiedNode],
answer_tree_str: str,
get_sample_graph: graph_fetcher,
) -> None:
@@ -1105,8 +527,7 @@ def test_qualify_node_to_ast_string(
qualified DAG version, with the correct string representation.
"""
graph: GraphMetadata = get_sample_graph("TPCH")
- root: UnqualifiedNode = UnqualifiedRoot(graph)
- unqualified: UnqualifiedNode = impl(root)
+ unqualified: UnqualifiedNode = init_pydough_context(graph)(impl)()
qualified: PyDoughQDAG = qualify_node(unqualified, graph)
assert isinstance(
qualified, PyDoughCollectionQDAG
diff --git a/tests/test_qualification_errors.py b/tests/test_qualification_errors.py
index 03a70a12..ec2f871f 100644
--- a/tests/test_qualification_errors.py
+++ b/tests/test_qualification_errors.py
@@ -23,37 +23,41 @@ def bad_pydough_impl_01(root: UnqualifiedNode) -> UnqualifiedNode:
"""
Creates an UnqualifiedNode for the following invalid PyDough snippet:
```
- TPCH.Nations(nation_name=name, total_balance=SUM(acctbal))
+ TPCH.Nations.CALCULATE(nation_name=name, total_balance=SUM(acctbal))
```
The problem: there is no property `acctbal` to be accessed from Nations.
"""
- return root.Nations(nation_name=root.name, total_balance=root.SUM(root.acctbal))
+ return root.Nations.CALCULATE(
+ nation_name=root.name, total_balance=root.SUM(root.acctbal)
+ )
def bad_pydough_impl_02(root: UnqualifiedNode) -> UnqualifiedNode:
"""
Creates an UnqualifiedNode for the following invalid PyDough snippet:
```
- TPCH.Nations(nation_name=FIZZBUZZ(name))
+ TPCH.Nations.CALCULATE(nation_name=FIZZBUZZ(name))
```
The problem: there is no function named FIZZBUZZ, so this looks like a
- CALC term of a subcollection, which cannot be used as an expression inside
- a CALC.
+ CALCULATE being done onto a subcollection, which cannot be used as an
+ expression inside a CALCULATE.
"""
- return root.Nations(nation_name=root.FIZZBUZZ(root.name))
+ return root.Nations.CALCULATE(nation_name=root.FIZZBUZZ(root.name))
def bad_pydough_impl_03(root: UnqualifiedNode) -> UnqualifiedNode:
"""
Creates an UnqualifiedNode for the following invalid PyDough snippet:
```
- TPCH.Nations(y=suppliers(x=COUNT(parts_supplied)).x)
+ TPCH.Nations.CALCULATE(y=suppliers.CALCULATE(x=COUNT(parts_supplied)).x)
```
- The problem: `suppliers(x=COUNT(parts_supplied))` is plural with regards
+ The problem: `suppliers.CALCULATE(x=COUNT(parts_supplied))` is plural with regards
to Nations, so accessing its `x` property is still plural, therefore it
- cannot be used as a calc term relative to Nations.
+ cannot be used as a term inside a CALCULATE from the context of Nations.
"""
- return root.Nations(y=root.suppliers(x=root.COUNT(root.parts_supplied)).x)
+ return root.Nations.CALCULATE(
+ y=root.suppliers.CALCULATE(x=root.COUNT(root.parts_supplied)).x
+ )
def bad_pydough_impl_04(root: UnqualifiedNode) -> UnqualifiedNode:
@@ -75,61 +79,50 @@ def bad_pydough_impl_05(root: UnqualifiedNode) -> UnqualifiedNode:
TPCH.Customer(r=nation.region)
```
The problem: nation.region is a collection, therefore cannot be used as
- an expression in a CALC term.
+ an expression in a CALCULATE.
"""
- return root.Customers(r=root.nation.region)
+ return root.Customers.CALCULATE(r=root.nation.region)
def bad_pydough_impl_06(root: UnqualifiedNode) -> UnqualifiedNode:
"""
Creates an UnqualifiedNode for the following invalid PyDough snippet:
```
- TPCH.Suppliers.parts_supplied(o=ps_lines.order.order_date)
+ TPCH.Suppliers.supply_records.CALCULATE(o=lines.order.order_date)
```
- The problem: ps_lines is plural with regards to parts_supplied, therefore
- ps_lines.order.order_date is also plural and it cannot be used as a calc
- term relative to parts_supplied.
+ The problem: lines is plural with regards to supply_records, therefore
+ lines.order.order_date is also plural and it cannot be used in a CALCULATE
+ in the context of supply_records.
"""
- return root.Suppliers.parts_supplied(o=root.ps_lines.order.order_date)
+ return root.Suppliers.supply_records.CALCULATE(o=root.lines.order.order_date)
def bad_pydough_impl_07(root: UnqualifiedNode) -> UnqualifiedNode:
"""
Creates an UnqualifiedNode for the following invalid PyDough snippet:
```
- TPCH.Nations.suppliers.parts_supplied(cust_name=BACK(2).customers.name)
- ```
- The problem: customers is plural with regards to BACK(2), therefore
- BACK(2).customers.name is also plural and it cannot be used as a calc
- term relative to parts_supplied.
- """
- return root.Suppliers.parts_supplied(o=root.ps_lines.order.order_date)
-
-
-def bad_pydough_impl_08(root: UnqualifiedNode) -> UnqualifiedNode:
- """
- Creates an UnqualifiedNode for the following invalid PyDough snippet:
- ```
- TPCH.Lineitems(v=MUL(extended_price, SUB(1, discount)))
+ TPCH.Lineitems.CALCULATE(v=MUL(extended_price, SUB(1, discount)))
```
The problem: there is no function named MUL or SUB, so this looks like a
- CALC term of a subcollection, which cannot be used as an expression inside
- a CALC.
+ CALCULATE operation on a subcollection, which cannot be used as an
+ expression inside of a CALCULATE.
"""
- return root.Lineitems(v=root.MUL(root.extended_price, root.SUB(1, root.discount)))
+ return root.Lineitems.CALCULATE(
+ v=root.MUL(root.extended_price, root.SUB(1, root.discount))
+ )
-def bad_pydough_impl_09(root: UnqualifiedNode) -> UnqualifiedNode:
+def bad_pydough_impl_08(root: UnqualifiedNode) -> UnqualifiedNode:
"""
Creates an UnqualifiedNode for the following invalid PyDough snippet:
```
TPCH.Lineitems.tax = 0
- TPCH.Lineitems(value=extended_price * tax)
+ TPCH.Lineitems.CALCULATE(value=extended_price * tax)
```
The problem: writing to an unqualified node is not yet supported.
"""
root.Lineitems.tax = 0
- return root.Lineitems(value=root.extended_price * root.tax)
+ return root.Lineitems.CALCULATE(value=root.extended_price * root.tax)
@pytest.mark.parametrize(
@@ -142,12 +135,12 @@ def bad_pydough_impl_09(root: UnqualifiedNode) -> UnqualifiedNode:
),
pytest.param(
bad_pydough_impl_02,
- "Unrecognized term of simple table collection 'Nations' in graph 'TPCH': 'FIZZBUZZ'",
+ "PyDough nodes FIZZBUZZ is not callable. Did you mean to use a function?",
id="02",
),
pytest.param(
bad_pydough_impl_03,
- "Expected all terms in (y=suppliers(x=COUNT(parts_supplied)).x) to be singular, but encountered a plural expression: suppliers(x=COUNT(parts_supplied)).x",
+ "Expected all terms in CALCULATE(y=suppliers.CALCULATE(x=COUNT(parts_supplied)).x) to be singular, but encountered a plural expression: suppliers.CALCULATE(x=COUNT(parts_supplied)).x",
id="03",
),
pytest.param(
@@ -162,23 +155,18 @@ def bad_pydough_impl_09(root: UnqualifiedNode) -> UnqualifiedNode:
),
pytest.param(
bad_pydough_impl_06,
- "Expected all terms in (o=ps_lines.order.order_date) to be singular, but encountered a plural expression: ps_lines.order.order_date",
+ "Expected all terms in CALCULATE(o=lines.order.order_date) to be singular, but encountered a plural expression: lines.order.order_date",
id="06",
),
pytest.param(
bad_pydough_impl_07,
- "Expected all terms in (o=ps_lines.order.order_date) to be singular, but encountered a plural expression: ps_lines.order.order_date",
+ "PyDough nodes SUB is not callable. Did you mean to use a function?",
id="07",
),
pytest.param(
bad_pydough_impl_08,
- "Unrecognized term of simple table collection 'Lineitems' in graph 'TPCH': 'MUL'",
- id="08",
- ),
- pytest.param(
- bad_pydough_impl_09,
"PyDough objects do not yet support writing properties to them.",
- id="09",
+ id="08",
),
],
)
diff --git a/tests/test_sql_refsols/rank_a_ansi.sql b/tests/test_sql_refsols/rank_a_ansi.sql
index 8b1718e5..fd52284d 100644
--- a/tests/test_sql_refsols/rank_a_ansi.sql
+++ b/tests/test_sql_refsols/rank_a_ansi.sql
@@ -1,7 +1,9 @@
SELECT
- ROW_NUMBER() OVER (ORDER BY acctbal DESC NULLS FIRST) AS rank
+ key AS id,
+ ROW_NUMBER() OVER (ORDER BY acctbal DESC NULLS FIRST) AS rk
FROM (
SELECT
- c_acctbal AS acctbal
+ c_acctbal AS acctbal,
+ c_custkey AS key
FROM tpch.CUSTOMER
)
diff --git a/tests/test_sql_refsols/rank_a_sqlite.sql b/tests/test_sql_refsols/rank_a_sqlite.sql
index 9ae452fa..3682e60d 100644
--- a/tests/test_sql_refsols/rank_a_sqlite.sql
+++ b/tests/test_sql_refsols/rank_a_sqlite.sql
@@ -1,7 +1,9 @@
SELECT
- ROW_NUMBER() OVER (ORDER BY acctbal DESC) AS rank
+ key AS id,
+ ROW_NUMBER() OVER (ORDER BY acctbal DESC) AS rk
FROM (
SELECT
- c_acctbal AS acctbal
+ c_acctbal AS acctbal,
+ c_custkey AS key
FROM tpch.CUSTOMER
)
diff --git a/tests/test_sql_refsols/rank_b_ansi.sql b/tests/test_sql_refsols/rank_b_ansi.sql
index 2bb3c57c..a63bb06d 100644
--- a/tests/test_sql_refsols/rank_b_ansi.sql
+++ b/tests/test_sql_refsols/rank_b_ansi.sql
@@ -1,7 +1,9 @@
SELECT
+ key AS order_key,
RANK() OVER (ORDER BY order_priority NULLS LAST) AS rank
FROM (
SELECT
+ o_orderkey AS key,
o_orderpriority AS order_priority
FROM tpch.ORDERS
)
diff --git a/tests/test_sql_refsols/rank_b_sqlite.sql b/tests/test_sql_refsols/rank_b_sqlite.sql
index d7f98d3b..f03027ac 100644
--- a/tests/test_sql_refsols/rank_b_sqlite.sql
+++ b/tests/test_sql_refsols/rank_b_sqlite.sql
@@ -1,7 +1,9 @@
SELECT
+ key AS order_key,
RANK() OVER (ORDER BY order_priority) AS rank
FROM (
SELECT
+ o_orderkey AS key,
o_orderpriority AS order_priority
FROM tpch.ORDERS
)
diff --git a/tests/test_sql_refsols/simple_filter_ansi.sql b/tests/test_sql_refsols/simple_filter_ansi.sql
index bf21ef7a..78aaeaa8 100644
--- a/tests/test_sql_refsols/simple_filter_ansi.sql
+++ b/tests/test_sql_refsols/simple_filter_ansi.sql
@@ -1,8 +1,10 @@
SELECT
+ order_date,
o_orderkey,
o_totalprice
FROM (
SELECT
+ o_orderdate AS order_date,
o_orderkey AS o_orderkey,
o_totalprice AS o_totalprice
FROM tpch.ORDERS
diff --git a/tests/test_sql_refsols/simple_filter_sqlite.sql b/tests/test_sql_refsols/simple_filter_sqlite.sql
index bf21ef7a..78aaeaa8 100644
--- a/tests/test_sql_refsols/simple_filter_sqlite.sql
+++ b/tests/test_sql_refsols/simple_filter_sqlite.sql
@@ -1,8 +1,10 @@
SELECT
+ order_date,
o_orderkey,
o_totalprice
FROM (
SELECT
+ o_orderdate AS order_date,
o_orderkey AS o_orderkey,
o_totalprice AS o_totalprice
FROM tpch.ORDERS
diff --git a/tests/test_unqualified_node.py b/tests/test_unqualified_node.py
index b294f31e..9fb1af08 100644
--- a/tests/test_unqualified_node.py
+++ b/tests/test_unqualified_node.py
@@ -140,116 +140,116 @@ def verify_pydough_code_exec_match_unqualified(
[
pytest.param(
"answer = _ROOT.Parts",
- "?.Parts",
+ "Parts",
id="access_collection",
),
pytest.param(
"answer = _ROOT.Regions.nations",
- "?.Regions.nations",
+ "Regions.nations",
id="access_subcollection",
),
pytest.param(
"answer = _ROOT.Regions.name",
- "?.Regions.name",
+ "Regions.name",
id="access_property",
),
pytest.param(
- "answer = _ROOT.Regions(region_name=_ROOT.name, region_key=_ROOT.key)",
- "?.Regions(region_name=?.name, region_key=?.key)",
+ "answer = _ROOT.Regions.CALCULATE(region_name=_ROOT.name, region_key=_ROOT.key)",
+ "Regions.CALCULATE(region_name=name, region_key=key)",
id="simple_calc",
),
pytest.param(
- "answer = _ROOT.Nations(nation_name=_ROOT.UPPER(_ROOT.name), total_balance=_ROOT.SUM(_ROOT.customers.acct_bal))",
- "?.Nations(nation_name=UPPER(?.name), total_balance=SUM(?.customers.acct_bal))",
+ "answer = _ROOT.Nations.CALCULATE(nation_name=_ROOT.UPPER(_ROOT.name), total_balance=_ROOT.SUM(_ROOT.customers.acct_bal))",
+ "Nations.CALCULATE(nation_name=UPPER(name), total_balance=SUM(customers.acct_bal))",
id="calc_with_functions",
),
pytest.param(
"answer = _ROOT.x + 1",
- "(?.x + 1)",
+ "(x + 1)",
id="arithmetic_01",
),
pytest.param(
"answer = 2 + _ROOT.x",
- "(2 + ?.x)",
+ "(2 + x)",
id="arithmetic_02",
),
pytest.param(
"answer = ((1.5 * _ROOT.x) - 1)",
- "((1.5 * ?.x) - 1)",
+ "((1.5 * x) - 1)",
id="arithmetic_03",
),
pytest.param(
"answer = ((1.5 * _ROOT.x) - 1)",
- "((1.5 * ?.x) - 1)",
+ "((1.5 * x) - 1)",
id="arithmetic_03",
),
pytest.param(
"answer = (_ROOT.STARTSWITH(_ROOT.x, 'hello') | _ROOT.ENDSWITH(_ROOT.x, 'world')) & _ROOT.CONTAINS(_ROOT.x, ' ')",
- "((STARTSWITH(?.x, 'hello') | ENDSWITH(?.x, 'world')) & CONTAINS(?.x, ' '))",
+ "((STARTSWITH(x, 'hello') | ENDSWITH(x, 'world')) & CONTAINS(x, ' '))",
id="arithmetic_04",
),
pytest.param(
"answer = (1 / _ROOT.x) ** 2 - _ROOT.y",
- "(((1 / ?.x) ** 2) - ?.y)",
+ "(((1 / x) ** 2) - y)",
id="arithmetic_05",
),
pytest.param(
"answer = -(_ROOT.x % 10) / 3.1415",
- "((0 - (?.x % 10)) / 3.1415)",
+ "((0 - (x % 10)) / 3.1415)",
id="arithmetic_06",
),
pytest.param(
"answer = (+_ROOT.x < -_ROOT.y) ^ (_ROOT.y == _ROOT.z)",
- "((?.x < (0 - ?.y)) ^ (?.y == ?.z))",
+ "((x < (0 - y)) ^ (y == z))",
id="arithmetic_07",
),
pytest.param(
"answer = 'Hello' != _ROOT.word",
- "(?.word != 'Hello')",
+ "(word != 'Hello')",
id="arithmetic_08",
),
pytest.param(
"answer = _ROOT.order_date >= datetime.date(2020, 1, 1)",
- "(?.order_date >= datetime.date(2020, 1, 1))",
+ "(order_date >= datetime.date(2020, 1, 1))",
id="arithmetic_09",
),
pytest.param(
"answer = True & (0 >= _ROOT.x)",
- "(True & (?.x <= 0))",
+ "(True & (x <= 0))",
id="arithmetic_10",
),
pytest.param(
"answer = (_ROOT.x == 42) | (45 == _ROOT.x) | ((_ROOT.x < 16) & (_ROOT.x != 0)) | ((100 < _ROOT.x) ^ (0 == _ROOT.y))",
- "((((?.x == 42) | (?.x == 45)) | ((?.x < 16) & (?.x != 0))) | ((?.x > 100) ^ (?.y == 0)))",
+ "((((x == 42) | (x == 45)) | ((x < 16) & (x != 0))) | ((x > 100) ^ (y == 0)))",
id="arithmetic_11",
),
pytest.param(
"answer = False ^ 100 % 2.718281828 ** _ROOT.x",
- "(False ^ (100 % (2.718281828 ** ?.x)))",
+ "(False ^ (100 % (2.718281828 ** x)))",
id="arithmetic_12",
),
pytest.param(
- "answer = _ROOT.Parts(part_name=_ROOT.LOWER(_ROOT.name)).suppliers_of_part.region(part_name=_ROOT.BACK(2).part_name)",
- "?.Parts(part_name=LOWER(?.name)).suppliers_of_part.region(part_name=BACK(2).part_name)",
+ "answer = _ROOT.Parts.CALCULATE(part_name=_ROOT.LOWER(_ROOT.name)).suppliers_of_part.region.CALCULATE(part_name=_ROOT.part_name)",
+ "Parts.CALCULATE(part_name=LOWER(name)).suppliers_of_part.region.CALCULATE(part_name=part_name)",
id="multi_calc_with_back",
),
pytest.param(
"""\
-x = _ROOT.Parts(part_name=_ROOT.LOWER(_ROOT.name))
+x = _ROOT.Parts.CALCULATE(part_name=_ROOT.LOWER(_ROOT.name))
y = x.WHERE(_ROOT.STARTSWITH(_ROOT.part_name, 'a'))
answer = y.ORDER_BY(_ROOT.retail_price.DESC())\
""",
- "?.Parts(part_name=LOWER(?.name)).WHERE(STARTSWITH(?.part_name, 'a')).ORDER_BY(?.retail_price.DESC(na_pos='last'))",
+ "Parts.CALCULATE(part_name=LOWER(name)).WHERE(STARTSWITH(part_name, 'a')).ORDER_BY(retail_price.DESC(na_pos='last'))",
id="calc_with_where_order",
),
pytest.param(
"answer = _ROOT.Parts.TOP_K(10, by=(1 / (_ROOT.retail_price - 30.0)).ASC(na_pos='last'))",
- "?.Parts.TOP_K(10, by=((1 / (?.retail_price - 30.0)).ASC(na_pos='last')))",
+ "Parts.TOP_K(10, by=((1 / (retail_price - 30.0)).ASC(na_pos='last')))",
id="topk_single",
),
pytest.param(
"answer = _ROOT.Parts.TOP_K(10, by=(_ROOT.size.DESC(), _ROOT.part_type.DESC()))",
- "?.Parts.TOP_K(10, by=(?.size.DESC(na_pos='last'), ?.part_type.DESC(na_pos='last')))",
+ "Parts.TOP_K(10, by=(size.DESC(na_pos='last'), part_type.DESC(na_pos='last')))",
id="topk_multiple",
),
pytest.param(
@@ -257,27 +257,27 @@ def verify_pydough_code_exec_match_unqualified(
x = _ROOT.Parts.ORDER_BY(_ROOT.retail_price.ASC(na_pos='first'))
answer = x.TOP_K(100)\
""",
- "?.Parts.ORDER_BY(?.retail_price.ASC(na_pos='first')).TOP_K(100)",
+ "Parts.ORDER_BY(retail_price.ASC(na_pos='first')).TOP_K(100)",
id="order_topk_empty",
),
pytest.param(
- "answer = _ROOT.Parts(_ROOT.name, rank=_ROOT.RANKING(by=_ROOT.retail_price.DESC()))",
- "?.Parts(name=?.name, rank=RANKING(by=(?.retail_price.DESC(na_pos='last')))",
+ "answer = _ROOT.Parts.CALCULATE(_ROOT.name, rank=_ROOT.RANKING(by=_ROOT.retail_price.DESC()))",
+ "Parts.CALCULATE(name=name, rank=RANKING(by=(retail_price.DESC(na_pos='last')))",
id="ranking_1",
),
pytest.param(
- "answer = _ROOT.Parts(_ROOT.name, rank=_ROOT.RANKING(by=_ROOT.retail_price.DESC(), levels=1))",
- "?.Parts(name=?.name, rank=RANKING(by=(?.retail_price.DESC(na_pos='last'), levels=1))",
+ "answer = _ROOT.Parts.CALCULATE(_ROOT.name, rank=_ROOT.RANKING(by=_ROOT.retail_price.DESC(), levels=1))",
+ "Parts.CALCULATE(name=name, rank=RANKING(by=(retail_price.DESC(na_pos='last'), levels=1))",
id="ranking_2",
),
pytest.param(
- "answer = _ROOT.Parts(_ROOT.name, rank=_ROOT.RANKING(by=_ROOT.retail_price.DESC(), allow_ties=True))",
- "?.Parts(name=?.name, rank=RANKING(by=(?.retail_price.DESC(na_pos='last'), allow_ties=True))",
+ "answer = _ROOT.Parts.CALCULATE(_ROOT.name, rank=_ROOT.RANKING(by=_ROOT.retail_price.DESC(), allow_ties=True))",
+ "Parts.CALCULATE(name=name, rank=RANKING(by=(retail_price.DESC(na_pos='last'), allow_ties=True))",
id="ranking_3",
),
pytest.param(
- "answer = _ROOT.Parts(_ROOT.name, rank=_ROOT.RANKING(by=_ROOT.retail_price.DESC(), levels=2, allow_ties=True, dense=True))",
- "?.Parts(name=?.name, rank=RANKING(by=(?.retail_price.DESC(na_pos='last'), levels=2, allow_ties=True, dense=True))",
+ "answer = _ROOT.Parts.CALCULATE(_ROOT.name, rank=_ROOT.RANKING(by=_ROOT.retail_price.DESC(), levels=2, allow_ties=True, dense=True))",
+ "Parts.CALCULATE(name=name, rank=RANKING(by=(retail_price.DESC(na_pos='last'), levels=2, allow_ties=True, dense=True))",
id="ranking_4",
),
],
@@ -327,200 +327,197 @@ def test_unqualified_to_string(
[
pytest.param(
impl_tpch_q1,
- "?.PARTITION(?.Lineitems.WHERE((?.ship_date <= datetime.date(1998, 12, 1))), name='l', by=(?.return_flag, ?.status))(L_RETURNFLAG=?.return_flag, L_LINESTATUS=?.status, SUM_QTY=SUM(?.l.quantity), SUM_BASE_PRICE=SUM(?.l.extended_price), SUM_DISC_PRICE=SUM((?.l.extended_price * (1 - ?.l.discount))), SUM_CHARGE=SUM(((?.l.extended_price * (1 - ?.l.discount)) * (1 + ?.l.tax))), AVG_QTY=AVG(?.l.quantity), AVG_PRICE=AVG(?.l.extended_price), AVG_DISC=AVG(?.l.discount), COUNT_ORDER=COUNT(?.l)).ORDER_BY(?.L_RETURNFLAG.ASC(na_pos='first'), ?.L_LINESTATUS.ASC(na_pos='first'))",
+ "PARTITION(Lineitems.WHERE((ship_date <= datetime.date(1998, 12, 1))), name='l', by=(return_flag, status)).CALCULATE(L_RETURNFLAG=return_flag, L_LINESTATUS=status, SUM_QTY=SUM(l.quantity), SUM_BASE_PRICE=SUM(l.extended_price), SUM_DISC_PRICE=SUM((l.extended_price * (1 - l.discount))), SUM_CHARGE=SUM(((l.extended_price * (1 - l.discount)) * (1 + l.tax))), AVG_QTY=AVG(l.quantity), AVG_PRICE=AVG(l.extended_price), AVG_DISC=AVG(l.discount), COUNT_ORDER=COUNT(l)).ORDER_BY(L_RETURNFLAG.ASC(na_pos='first'), L_LINESTATUS.ASC(na_pos='first'))",
id="tpch_q1",
),
pytest.param(
impl_tpch_q2,
- "?.PARTITION(?.Nations.WHERE((?.region.name == 'EUROPE')).suppliers.supply_records.part(s_acctbal=BACK(2).account_balance, s_name=BACK(2).name, n_name=BACK(3).name, s_address=BACK(2).address, s_phone=BACK(2).phone, s_comment=BACK(2).comment, supplycost=BACK(1).supplycost).WHERE((ENDSWITH(?.part_type, 'BRASS') & (?.size == 15))), name='p', by=(?.key))(best_cost=MIN(?.p.supplycost)).p.WHERE((((?.supplycost == BACK(1).best_cost) & ENDSWITH(?.part_type, 'BRASS')) & (?.size == 15)))(S_ACCTBAL=?.s_acctbal, S_NAME=?.s_name, N_NAME=?.n_name, P_PARTKEY=?.key, P_MFGR=?.manufacturer, S_ADDRESS=?.s_address, S_PHONE=?.s_phone, S_COMMENT=?.s_comment).TOP_K(10, by=(?.S_ACCTBAL.DESC(na_pos='last'), ?.N_NAME.ASC(na_pos='first'), ?.S_NAME.ASC(na_pos='first'), ?.P_PARTKEY.ASC(na_pos='first')))",
+ "PARTITION(Nations.CALCULATE(n_name=name).WHERE((region.name == 'EUROPE')).suppliers.CALCULATE(s_acctbal=account_balance, s_name=name, s_address=address, s_phone=phone, s_comment=comment).supply_records.CALCULATE(supplycost=supplycost).part.WHERE((ENDSWITH(part_type, 'BRASS') & (size == 15))), name='p', by=(key)).CALCULATE(best_cost=MIN(p.supplycost)).p.WHERE((((supplycost == best_cost) & ENDSWITH(part_type, 'BRASS')) & (size == 15))).CALCULATE(S_ACCTBAL=s_acctbal, S_NAME=s_name, N_NAME=n_name, P_PARTKEY=key, P_MFGR=manufacturer, S_ADDRESS=s_address, S_PHONE=s_phone, S_COMMENT=s_comment).TOP_K(10, by=(S_ACCTBAL.DESC(na_pos='last'), N_NAME.ASC(na_pos='first'), S_NAME.ASC(na_pos='first'), P_PARTKEY.ASC(na_pos='first')))",
id="tpch_q2",
),
pytest.param(
impl_tpch_q3,
- "?.PARTITION(?.Orders.WHERE(((?.customer.mktsegment == 'BUILDING') & (?.order_date < datetime.date(1995, 3, 15)))).lines.WHERE((?.ship_date > datetime.date(1995, 3, 15)))(order_date=BACK(1).order_date, ship_priority=BACK(1).ship_priority), name='l', by=(?.order_key, ?.order_date, ?.ship_priority))(L_ORDERKEY=?.order_key, REVENUE=SUM((?.l.extended_price * (1 - ?.l.discount))), O_ORDERDATE=?.order_date, O_SHIPPRIORITY=?.ship_priority).TOP_K(10, by=(?.REVENUE.DESC(na_pos='last'), ?.O_ORDERDATE.ASC(na_pos='first'), ?.L_ORDERKEY.ASC(na_pos='first')))",
+ "PARTITION(Orders.CALCULATE(order_date=order_date, ship_priority=ship_priority).WHERE(((customer.mktsegment == 'BUILDING') & (order_date < datetime.date(1995, 3, 15)))).lines.WHERE((ship_date > datetime.date(1995, 3, 15))), name='l', by=(order_key, order_date, ship_priority)).CALCULATE(L_ORDERKEY=order_key, REVENUE=SUM((l.extended_price * (1 - l.discount))), O_ORDERDATE=order_date, O_SHIPPRIORITY=ship_priority).TOP_K(10, by=(REVENUE.DESC(na_pos='last'), O_ORDERDATE.ASC(na_pos='first'), L_ORDERKEY.ASC(na_pos='first')))",
id="tpch_q3",
),
pytest.param(
impl_tpch_q4,
- "?.PARTITION(?.Orders.WHERE((((?.order_date >= datetime.date(1993, 7, 1)) & (?.order_date < datetime.date(1993, 10, 1))) & HAS(?.lines.WHERE((?.commit_date < ?.receipt_date))))), name='o', by=(?.order_priority))(O_ORDERPRIORITY=?.order_priority, ORDER_COUNT=COUNT(?.o)).ORDER_BY(?.O_ORDERPRIORITY.ASC(na_pos='first'))",
+ "PARTITION(Orders.WHERE((((order_date >= datetime.date(1993, 7, 1)) & (order_date < datetime.date(1993, 10, 1))) & HAS(lines.WHERE((commit_date < receipt_date))))), name='o', by=(order_priority)).CALCULATE(O_ORDERPRIORITY=order_priority, ORDER_COUNT=COUNT(o)).ORDER_BY(O_ORDERPRIORITY.ASC(na_pos='first'))",
id="tpch_q4",
),
pytest.param(
impl_tpch_q5,
- "?.Nations.WHERE((?.region.name == 'ASIA'))(N_NAME=?.name, REVENUE=SUM(?.customers.orders.WHERE(((?.order_date >= datetime.date(1994, 1, 1)) & (?.order_date < datetime.date(1995, 1, 1)))).lines.WHERE((?.supplier.nation.name == BACK(3).name))(value=(?.extended_price * (1 - ?.discount))).value)).ORDER_BY(?.REVENUE.DESC(na_pos='last'))",
+ "Nations.CALCULATE(nation_name=name).WHERE((region.name == 'ASIA')).CALCULATE(N_NAME=name, REVENUE=SUM(customers.orders.WHERE(((order_date >= datetime.date(1994, 1, 1)) & (order_date < datetime.date(1995, 1, 1)))).lines.WHERE((supplier.nation.name == nation_name)).CALCULATE(value=(extended_price * (1 - discount))).value)).ORDER_BY(REVENUE.DESC(na_pos='last'))",
id="tpch_q5",
),
pytest.param(
impl_tpch_q6,
- "?.TPCH(REVENUE=SUM(?.Lineitems.WHERE((((((?.ship_date >= datetime.date(1994, 1, 1)) & (?.ship_date < datetime.date(1995, 1, 1))) & (?.discount >= 0.05)) & (?.discount <= 0.07)) & (?.quantity < 24)))(amt=(?.extended_price * ?.discount)).amt))",
+ "TPCH.CALCULATE(REVENUE=SUM(Lineitems.WHERE((((((ship_date >= datetime.date(1994, 1, 1)) & (ship_date < datetime.date(1995, 1, 1))) & (discount >= 0.05)) & (discount <= 0.07)) & (quantity < 24))).CALCULATE(amt=(extended_price * discount)).amt))",
id="tpch_q6",
),
pytest.param(
impl_tpch_q7,
- "?.PARTITION(?.Lineitems(supp_nation=?.supplier.nation.name, cust_nation=?.order.customer.nation.name, l_year=YEAR(?.ship_date), volume=(?.extended_price * (1 - ?.discount))).WHERE((((?.ship_date >= datetime.date(1995, 1, 1)) & (?.ship_date <= datetime.date(1996, 12, 31))) & (((?.supp_nation == 'FRANCE') & (?.cust_nation == 'GERMANY')) | ((?.supp_nation == 'GERMANY') & (?.cust_nation == 'FRANCE'))))), name='l', by=(?.supp_nation, ?.cust_nation, ?.l_year))(SUPP_NATION=?.supp_nation, CUST_NATION=?.cust_nation, L_YEAR=?.l_year, REVENUE=SUM(?.l.volume)).ORDER_BY(?.SUPP_NATION.ASC(na_pos='first'), ?.CUST_NATION.ASC(na_pos='first'), ?.L_YEAR.ASC(na_pos='first'))",
+ "PARTITION(Lineitems.CALCULATE(supp_nation=supplier.nation.name, cust_nation=order.customer.nation.name, l_year=YEAR(ship_date), volume=(extended_price * (1 - discount))).WHERE((((ship_date >= datetime.date(1995, 1, 1)) & (ship_date <= datetime.date(1996, 12, 31))) & (((supp_nation == 'FRANCE') & (cust_nation == 'GERMANY')) | ((supp_nation == 'GERMANY') & (cust_nation == 'FRANCE'))))), name='l', by=(supp_nation, cust_nation, l_year)).CALCULATE(SUPP_NATION=supp_nation, CUST_NATION=cust_nation, L_YEAR=l_year, REVENUE=SUM(l.volume)).ORDER_BY(SUPP_NATION.ASC(na_pos='first'), CUST_NATION.ASC(na_pos='first'), L_YEAR.ASC(na_pos='first'))",
id="tpch_q7",
),
pytest.param(
impl_tpch_q8,
- "?.PARTITION(?.Nations.suppliers.supply_records.WHERE((?.part.part_type == 'ECONOMY ANODIZED STEEL')).lines(volume=(?.extended_price * (1 - ?.discount))).order(o_year=YEAR(?.order_date), volume=BACK(1).volume, brazil_volume=IFF((BACK(4).name == 'BRAZIL'), BACK(1).volume, 0)).WHERE((((?.order_date >= datetime.date(1995, 1, 1)) & (?.order_date <= datetime.date(1996, 12, 31))) & (?.customer.nation.region.name == 'AMERICA'))), name='v', by=(?.o_year))(O_YEAR=?.o_year, MKT_SHARE=(SUM(?.v.brazil_volume) / SUM(?.v.volume)))",
+ "PARTITION(Nations.CALCULATE(nation_name=name).suppliers.supply_records.WHERE((part.part_type == 'ECONOMY ANODIZED STEEL')).lines.CALCULATE(volume=(extended_price * (1 - discount))).order.CALCULATE(o_year=YEAR(order_date), brazil_volume=IFF((nation_name == 'BRAZIL'), volume, 0)).WHERE((((order_date >= datetime.date(1995, 1, 1)) & (order_date <= datetime.date(1996, 12, 31))) & (customer.nation.region.name == 'AMERICA'))), name='v', by=(o_year)).CALCULATE(O_YEAR=o_year, MKT_SHARE=(SUM(v.brazil_volume) / SUM(v.volume)))",
id="tpch_q8",
),
pytest.param(
impl_tpch_q9,
- "?.PARTITION(?.Nations.suppliers.supply_records.WHERE(CONTAINS(?.part.name, 'green')).lines(nation=BACK(3).name, o_year=YEAR(?.order.order_date), value=((?.extended_price * (1 - ?.discount)) - (BACK(1).supplycost * ?.quantity))), name='l', by=(?.nation, ?.o_year))(NATION=?.nation, O_YEAR=?.o_year, AMOUNT=SUM(?.l.value)).TOP_K(10, by=(?.NATION.ASC(na_pos='first'), ?.O_YEAR.DESC(na_pos='last')))",
+ "PARTITION(Nations.CALCULATE(nation_name=name).suppliers.supply_records.CALCULATE(supplycost=supplycost).WHERE(CONTAINS(part.name, 'green')).lines.CALCULATE(o_year=YEAR(order.order_date), value=((extended_price * (1 - discount)) - (supplycost * quantity))), name='l', by=(nation_name, o_year)).CALCULATE(NATION=nation_name, O_YEAR=o_year, AMOUNT=SUM(l.value)).TOP_K(10, by=(NATION.ASC(na_pos='first'), O_YEAR.DESC(na_pos='last')))",
id="tpch_q9",
),
pytest.param(
impl_tpch_q10,
- "?.Customers(C_CUSTKEY=?.key, C_NAME=?.name, REVENUE=SUM(?.orders.WHERE(((?.order_date >= datetime.date(1993, 10, 1)) & (?.order_date < datetime.date(1994, 1, 1)))).lines.WHERE((?.return_flag == 'R'))(amt=(?.extended_price * (1 - ?.discount))).amt), C_ACCTBAL=?.acctbal, N_NAME=?.nation.name, C_ADDRESS=?.address, C_PHONE=?.phone, C_COMMENT=?.comment).TOP_K(20, by=(?.REVENUE.DESC(na_pos='last'), ?.C_CUSTKEY.ASC(na_pos='first')))",
+ "Customers.CALCULATE(C_CUSTKEY=key, C_NAME=name, REVENUE=SUM(orders.WHERE(((order_date >= datetime.date(1993, 10, 1)) & (order_date < datetime.date(1994, 1, 1)))).lines.WHERE((return_flag == 'R')).CALCULATE(amt=(extended_price * (1 - discount))).amt), C_ACCTBAL=acctbal, N_NAME=nation.name, C_ADDRESS=address, C_PHONE=phone, C_COMMENT=comment).TOP_K(20, by=(REVENUE.DESC(na_pos='last'), C_CUSTKEY.ASC(na_pos='first')))",
id="tpch_q10",
),
pytest.param(
impl_tpch_q11,
- "?.TPCH(min_market_share=(SUM(?.PartSupp.WHERE((?.supplier.nation.name == 'GERMANY'))(metric=(?.supplycost * ?.availqty)).metric) * 0.0001)).PARTITION(?.PartSupp.WHERE((?.supplier.nation.name == 'GERMANY'))(metric=(?.supplycost * ?.availqty)), name='ps', by=(?.part_key))(PS_PARTKEY=?.part_key, VALUE=SUM(?.ps.metric)).WHERE((?.VALUE > BACK(1).min_market_share)).TOP_K(10, by=(?.VALUE.DESC(na_pos='last')))",
+ "TPCH.CALCULATE(min_market_share=(SUM(PartSupp.WHERE((supplier.nation.name == 'GERMANY')).CALCULATE(metric=(supplycost * availqty)).metric) * 0.0001)).PARTITION(PartSupp.WHERE((supplier.nation.name == 'GERMANY')).CALCULATE(metric=(supplycost * availqty)), name='ps', by=(part_key)).CALCULATE(PS_PARTKEY=part_key, VALUE=SUM(ps.metric)).WHERE((VALUE > min_market_share)).TOP_K(10, by=(VALUE.DESC(na_pos='last')))",
id="tpch_q11",
),
pytest.param(
impl_tpch_q12,
- "?.PARTITION(?.Lineitems.WHERE(((((((?.ship_mode == 'MAIL') | (?.ship_mode == 'SHIP')) & (?.ship_date < ?.commit_date)) & (?.commit_date < ?.receipt_date)) & (?.receipt_date >= datetime.date(1994, 1, 1))) & (?.receipt_date < datetime.date(1995, 1, 1))))(is_high_priority=((?.order.order_priority == '1-URGENT') | (?.order.order_priority == '2-HIGH'))), name='l', by=(?.ship_mode))(L_SHIPMODE=?.ship_mode, HIGH_LINE_COUNT=SUM(?.l.is_high_priority), LOW_LINE_COUNT=SUM(NOT(?.l.is_high_priority))).ORDER_BY(?.L_SHIPMODE.ASC(na_pos='first'))",
+ "PARTITION(Lineitems.WHERE(((((((ship_mode == 'MAIL') | (ship_mode == 'SHIP')) & (ship_date < commit_date)) & (commit_date < receipt_date)) & (receipt_date >= datetime.date(1994, 1, 1))) & (receipt_date < datetime.date(1995, 1, 1)))).CALCULATE(is_high_priority=((order.order_priority == '1-URGENT') | (order.order_priority == '2-HIGH'))), name='l', by=(ship_mode)).CALCULATE(L_SHIPMODE=ship_mode, HIGH_LINE_COUNT=SUM(l.is_high_priority), LOW_LINE_COUNT=SUM(NOT(l.is_high_priority))).ORDER_BY(L_SHIPMODE.ASC(na_pos='first'))",
id="tpch_q12",
),
pytest.param(
impl_tpch_q13,
- "?.PARTITION(?.Customers(key=?.key, num_non_special_orders=COUNT(?.orders.WHERE(NOT(LIKE(?.comment, '%special%requests%'))))), name='custs', by=(?.num_non_special_orders))(C_COUNT=?.num_non_special_orders, CUSTDIST=COUNT(?.custs)).TOP_K(10, by=(?.CUSTDIST.DESC(na_pos='last'), ?.C_COUNT.DESC(na_pos='last')))",
+ "PARTITION(Customers.CALCULATE(num_non_special_orders=COUNT(orders.WHERE(NOT(LIKE(comment, '%special%requests%'))))), name='custs', by=(num_non_special_orders)).CALCULATE(C_COUNT=num_non_special_orders, CUSTDIST=COUNT(custs)).TOP_K(10, by=(CUSTDIST.DESC(na_pos='last'), C_COUNT.DESC(na_pos='last')))",
id="tpch_q13",
),
pytest.param(
impl_tpch_q14,
- "?.TPCH(PROMO_REVENUE=((100.0 * SUM(?.Lineitems.WHERE(((?.ship_date >= datetime.date(1995, 9, 1)) & (?.ship_date < datetime.date(1995, 10, 1))))(value=(?.extended_price * (1 - ?.discount)), promo_value=IFF(STARTSWITH(?.part.part_type, 'PROMO'), (?.extended_price * (1 - ?.discount)), 0)).promo_value)) / SUM(?.Lineitems.WHERE(((?.ship_date >= datetime.date(1995, 9, 1)) & (?.ship_date < datetime.date(1995, 10, 1))))(value=(?.extended_price * (1 - ?.discount)), promo_value=IFF(STARTSWITH(?.part.part_type, 'PROMO'), (?.extended_price * (1 - ?.discount)), 0)).value)))",
+ "TPCH.CALCULATE(PROMO_REVENUE=((100.0 * SUM(Lineitems.WHERE(((ship_date >= datetime.date(1995, 9, 1)) & (ship_date < datetime.date(1995, 10, 1)))).CALCULATE(value=(extended_price * (1 - discount)), promo_value=IFF(STARTSWITH(part.part_type, 'PROMO'), (extended_price * (1 - discount)), 0)).promo_value)) / SUM(Lineitems.WHERE(((ship_date >= datetime.date(1995, 9, 1)) & (ship_date < datetime.date(1995, 10, 1)))).CALCULATE(value=(extended_price * (1 - discount)), promo_value=IFF(STARTSWITH(part.part_type, 'PROMO'), (extended_price * (1 - discount)), 0)).value)))",
id="tpch_q14",
),
pytest.param(
impl_tpch_q15,
- "?.TPCH(max_revenue=MAX(?.Suppliers(total_revenue=SUM((?.lines.WHERE(((?.ship_date >= datetime.date(1996, 1, 1)) & (?.ship_date < datetime.date(1996, 4, 1)))).extended_price * (1 - ?.lines.WHERE(((?.ship_date >= datetime.date(1996, 1, 1)) & (?.ship_date < datetime.date(1996, 4, 1)))).discount)))).total_revenue)).Suppliers(S_SUPPKEY=?.key, S_NAME=?.name, S_ADDRESS=?.address, S_PHONE=?.phone, TOTAL_REVENUE=SUM((?.lines.WHERE(((?.ship_date >= datetime.date(1996, 1, 1)) & (?.ship_date < datetime.date(1996, 4, 1)))).extended_price * (1 - ?.lines.WHERE(((?.ship_date >= datetime.date(1996, 1, 1)) & (?.ship_date < datetime.date(1996, 4, 1)))).discount)))).WHERE((?.TOTAL_REVENUE == BACK(1).max_revenue)).ORDER_BY(?.S_SUPPKEY.ASC(na_pos='first'))",
+ "TPCH.CALCULATE(max_revenue=MAX(Suppliers.CALCULATE(total_revenue=SUM((lines.WHERE(((ship_date >= datetime.date(1996, 1, 1)) & (ship_date < datetime.date(1996, 4, 1)))).extended_price * (1 - lines.WHERE(((ship_date >= datetime.date(1996, 1, 1)) & (ship_date < datetime.date(1996, 4, 1)))).discount)))).total_revenue)).Suppliers.CALCULATE(S_SUPPKEY=key, S_NAME=name, S_ADDRESS=address, S_PHONE=phone, TOTAL_REVENUE=SUM((lines.WHERE(((ship_date >= datetime.date(1996, 1, 1)) & (ship_date < datetime.date(1996, 4, 1)))).extended_price * (1 - lines.WHERE(((ship_date >= datetime.date(1996, 1, 1)) & (ship_date < datetime.date(1996, 4, 1)))).discount)))).WHERE((TOTAL_REVENUE == max_revenue)).ORDER_BY(S_SUPPKEY.ASC(na_pos='first'))",
id="tpch_q15",
),
pytest.param(
impl_tpch_q16,
- "?.PARTITION(?.Parts.WHERE((((?.brand != 'BRAND#45') & NOT(STARTSWITH(?.part_type, 'MEDIUM POLISHED%'))) & ISIN(?.size, [49, 14, 23, 45, 19, 3, 36, 9]))).supply_records(p_brand=BACK(1).brand, p_type=BACK(1).part_type, p_size=BACK(1).size, ps_suppkey=?.supplier_key).WHERE(NOT(LIKE(?.supplier.comment, '%Customer%Complaints%'))), name='ps', by=(?.p_brand, ?.p_type, ?.p_size))(P_BRAND=?.p_brand, P_TYPE=?.p_type, P_SIZE=?.p_size, SUPPLIER_COUNT=NDISTINCT(?.ps.supplier_key)).TOP_K(10, by=(?.SUPPLIER_COUNT.DESC(na_pos='last'), ?.P_BRAND.ASC(na_pos='first'), ?.P_TYPE.ASC(na_pos='first'), ?.P_SIZE.ASC(na_pos='first')))",
+ "PARTITION(Parts.WHERE((((brand != 'BRAND#45') & NOT(STARTSWITH(part_type, 'MEDIUM POLISHED%'))) & ISIN(size, [49, 14, 23, 45, 19, 3, 36, 9]))).CALCULATE(p_brand=brand, p_type=part_type, p_size=size).supply_records.WHERE(NOT(LIKE(supplier.comment, '%Customer%Complaints%'))), name='ps', by=(p_brand, p_type, p_size)).CALCULATE(P_BRAND=p_brand, P_TYPE=p_type, P_SIZE=p_size, SUPPLIER_COUNT=NDISTINCT(ps.supplier_key)).TOP_K(10, by=(SUPPLIER_COUNT.DESC(na_pos='last'), P_BRAND.ASC(na_pos='first'), P_TYPE.ASC(na_pos='first'), P_SIZE.ASC(na_pos='first')))",
id="tpch_q16",
),
pytest.param(
impl_tpch_q17,
- "?.TPCH(AVG_YEARLY=(SUM(?.Parts.WHERE(((?.brand == 'Brand#23') & (?.container == 'MED BOX')))(avg_quantity=AVG(?.lines.quantity)).lines.WHERE((?.quantity < (0.2 * BACK(1).avg_quantity))).extended_price) / 7.0))",
+ "TPCH.CALCULATE(AVG_YEARLY=(SUM(Parts.WHERE(((brand == 'Brand#23') & (container == 'MED BOX'))).CALCULATE(part_avg_quantity=AVG(lines.quantity)).lines.WHERE((quantity < (0.2 * part_avg_quantity))).extended_price) / 7.0))",
id="tpch_q17",
),
pytest.param(
impl_tpch_q18,
- "?.Orders(C_NAME=?.customer.name, C_CUSTKEY=?.customer.key, O_ORDERKEY=?.key, O_ORDERDATE=?.order_date, O_TOTALPRICE=?.total_price, TOTAL_QUANTITY=SUM(?.lines.quantity)).WHERE((?.TOTAL_QUANTITY > 300)).TOP_K(10, by=(?.O_TOTALPRICE.DESC(na_pos='last'), ?.O_ORDERDATE.ASC(na_pos='first')))",
+ "Orders.CALCULATE(C_NAME=customer.name, C_CUSTKEY=customer.key, O_ORDERKEY=key, O_ORDERDATE=order_date, O_TOTALPRICE=total_price, TOTAL_QUANTITY=SUM(lines.quantity)).WHERE((TOTAL_QUANTITY > 300)).TOP_K(10, by=(O_TOTALPRICE.DESC(na_pos='last'), O_ORDERDATE.ASC(na_pos='first')))",
id="tpch_q18",
),
pytest.param(
impl_tpch_q19,
- "?.TPCH(REVENUE=SUM((?.Lineitems.WHERE((((ISIN(?.ship_mode, ['AIR', 'AIR REG']) & (?.ship_instruct == 'DELIVER IN PERSON')) & (?.part.size >= 1)) & (((((((?.part.size <= 5) & (?.quantity >= 1)) & (?.quantity <= 11)) & ISIN(?.part.container, ['SM CASE', 'SM BOX', 'SM PACK', 'SM PKG'])) & (?.part.brand == 'Brand#12')) | (((((?.part.size <= 10) & (?.quantity >= 10)) & (?.quantity <= 20)) & ISIN(?.part.container, ['MED BAG', 'MED BOX', 'MED PACK', 'MED PKG'])) & (?.part.brand == 'Brand#23'))) | (((((?.part.size <= 15) & (?.quantity >= 20)) & (?.quantity <= 30)) & ISIN(?.part.container, ['LG CASE', 'LG BOX', 'LG PACK', 'LG PKG'])) & (?.part.brand == 'Brand#34'))))).extended_price * (1 - ?.Lineitems.WHERE((((ISIN(?.ship_mode, ['AIR', 'AIR REG']) & (?.ship_instruct == 'DELIVER IN PERSON')) & (?.part.size >= 1)) & (((((((?.part.size <= 5) & (?.quantity >= 1)) & (?.quantity <= 11)) & ISIN(?.part.container, ['SM CASE', 'SM BOX', 'SM PACK', 'SM PKG'])) & (?.part.brand == 'Brand#12')) | (((((?.part.size <= 10) & (?.quantity >= 10)) & (?.quantity <= 20)) & ISIN(?.part.container, ['MED BAG', 'MED BOX', 'MED PACK', 'MED PKG'])) & (?.part.brand == 'Brand#23'))) | (((((?.part.size <= 15) & (?.quantity >= 20)) & (?.quantity <= 30)) & ISIN(?.part.container, ['LG CASE', 'LG BOX', 'LG PACK', 'LG PKG'])) & (?.part.brand == 'Brand#34'))))).discount))))",
+ "TPCH.CALCULATE(REVENUE=SUM((Lineitems.WHERE((((ISIN(ship_mode, ['AIR', 'AIR REG']) & (ship_instruct == 'DELIVER IN PERSON')) & (part.size >= 1)) & (((((((part.size <= 5) & (quantity >= 1)) & (quantity <= 11)) & ISIN(part.container, ['SM CASE', 'SM BOX', 'SM PACK', 'SM PKG'])) & (part.brand == 'Brand#12')) | (((((part.size <= 10) & (quantity >= 10)) & (quantity <= 20)) & ISIN(part.container, ['MED BAG', 'MED BOX', 'MED PACK', 'MED PKG'])) & (part.brand == 'Brand#23'))) | (((((part.size <= 15) & (quantity >= 20)) & (quantity <= 30)) & ISIN(part.container, ['LG CASE', 'LG BOX', 'LG PACK', 'LG PKG'])) & (part.brand == 'Brand#34'))))).extended_price * (1 - Lineitems.WHERE((((ISIN(ship_mode, ['AIR', 'AIR REG']) & (ship_instruct == 'DELIVER IN PERSON')) & (part.size >= 1)) & (((((((part.size <= 5) & (quantity >= 1)) & (quantity <= 11)) & ISIN(part.container, ['SM CASE', 'SM BOX', 'SM PACK', 'SM PKG'])) & (part.brand == 'Brand#12')) | (((((part.size <= 10) & (quantity >= 10)) & (quantity <= 20)) & ISIN(part.container, ['MED BAG', 'MED BOX', 'MED PACK', 'MED PKG'])) & (part.brand == 'Brand#23'))) | (((((part.size <= 15) & (quantity >= 20)) & (quantity <= 30)) & ISIN(part.container, ['LG CASE', 'LG BOX', 'LG PACK', 'LG PKG'])) & (part.brand == 'Brand#34'))))).discount))))",
id="tpch_q19",
),
pytest.param(
impl_tpch_q20,
- "?.Suppliers(S_NAME=?.name, S_ADDRESS=?.address).WHERE((((?.nation.name == 'CANADA') & COUNT(?.supply_records.part.WHERE((STARTSWITH(?.name, 'forest') & (BACK(1).availqty > (SUM(?.lines.WHERE(((?.ship_date >= datetime.date(1994, 1, 1)) & (?.ship_date < datetime.date(1995, 1, 1)))).quantity) * 0.5)))))) > 0)).TOP_K(10, by=(?.S_NAME.ASC(na_pos='first')))",
+ "Suppliers.CALCULATE(S_NAME=name, S_ADDRESS=address).WHERE((((nation.name == 'CANADA') & COUNT(supply_records.CALCULATE(availqty=availqty).part.WHERE((STARTSWITH(name, 'forest') & (availqty > (SUM(lines.WHERE(((ship_date >= datetime.date(1994, 1, 1)) & (ship_date < datetime.date(1995, 1, 1)))).quantity) * 0.5)))))) > 0)).TOP_K(10, by=(S_NAME.ASC(na_pos='first')))",
id="tpch_q20",
),
pytest.param(
impl_tpch_q21,
- "?.Suppliers.WHERE((?.nation.name == 'SAUDI ARABIA'))(S_NAME=?.name, NUMWAIT=COUNT(?.lines.WHERE((?.receipt_date > ?.commit_date)).order.WHERE((((?.order_status == 'F') & HAS(?.lines.WHERE((?.supplier_key != BACK(2).supplier_key)))) & HASNOT(?.lines.WHERE(((?.supplier_key != BACK(2).supplier_key) & (?.receipt_date > ?.commit_date)))))))).TOP_K(10, by=(?.NUMWAIT.DESC(na_pos='last'), ?.S_NAME.ASC(na_pos='first')))",
+ "Suppliers.WHERE((nation.name == 'SAUDI ARABIA')).CALCULATE(S_NAME=name, NUMWAIT=COUNT(lines.CALCULATE(original_key=supplier_key).WHERE((receipt_date > commit_date)).order.WHERE((((order_status == 'F') & HAS(lines.WHERE((supplier_key != original_key)))) & HASNOT(lines.WHERE(((supplier_key != original_key) & (receipt_date > commit_date)))))))).TOP_K(10, by=(NUMWAIT.DESC(na_pos='last'), S_NAME.ASC(na_pos='first')))",
id="tpch_q21",
),
pytest.param(
impl_tpch_q22,
- "?.TPCH(avg_balance=AVG(?.Customers(cntry_code=SLICE(?.phone, None, 2, None)).WHERE(ISIN(?.cntry_code, ['13', '31', '23', '29', '30', '18', '17'])).WHERE((?.acctbal > 0.0)).acctbal)).PARTITION(?.Customers(cntry_code=SLICE(?.phone, None, 2, None)).WHERE(ISIN(?.cntry_code, ['13', '31', '23', '29', '30', '18', '17'])).WHERE(((?.acctbal > BACK(1).avg_balance) & (COUNT(?.orders) == 0))), name='custs', by=(?.cntry_code))(CNTRY_CODE=?.cntry_code, NUM_CUSTS=COUNT(?.custs), TOTACCTBAL=SUM(?.custs.acctbal)).ORDER_BY(?.CNTRY_CODE.ASC(na_pos='first'))",
+ "TPCH.CALCULATE(global_avg_balance=AVG(Customers.CALCULATE(cntry_code=SLICE(phone, None, 2, None)).WHERE(ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17'])).WHERE((acctbal > 0.0)).acctbal)).PARTITION(Customers.CALCULATE(cntry_code=SLICE(phone, None, 2, None)).WHERE(ISIN(cntry_code, ['13', '31', '23', '29', '30', '18', '17'])).WHERE(((acctbal > global_avg_balance) & (COUNT(orders) == 0))), name='custs', by=(cntry_code)).CALCULATE(CNTRY_CODE=cntry_code, NUM_CUSTS=COUNT(custs), TOTACCTBAL=SUM(custs.acctbal)).ORDER_BY(CNTRY_CODE.ASC(na_pos='first'))",
id="tpch_q22",
),
pytest.param(
loop_generated_terms,
- "?.Nations(name=?.name, interval_0=COUNT(?.customers.WHERE(MONOTONIC(0, ?.acctbal, 1000))), interval_1=COUNT(?.customers.WHERE(MONOTONIC(1000, ?.acctbal, 2000))), interval_2=COUNT(?.customers.WHERE(MONOTONIC(2000, ?.acctbal, 3000))))",
+ "Nations.CALCULATE(name=name, interval_0=COUNT(customers.WHERE(MONOTONIC(0, acctbal, 1000))), interval_1=COUNT(customers.WHERE(MONOTONIC(1000, acctbal, 2000))), interval_2=COUNT(customers.WHERE(MONOTONIC(2000, acctbal, 3000))))",
id="loop_generated_terms",
),
pytest.param(
function_defined_terms,
- "?.Nations(name=?.name, interval_7=COUNT(?.customers.WHERE(MONOTONIC(7000, ?.acctbal, 8000))), interval_4=COUNT(?.customers.WHERE(MONOTONIC(4000, ?.acctbal, 5000))), interval_13=COUNT(?.customers.WHERE(MONOTONIC(13000, ?.acctbal, 14000))))",
+ "Nations.CALCULATE(name=name, interval_7=COUNT(customers.WHERE(MONOTONIC(7000, acctbal, 8000))), interval_4=COUNT(customers.WHERE(MONOTONIC(4000, acctbal, 5000))), interval_13=COUNT(customers.WHERE(MONOTONIC(13000, acctbal, 14000))))",
id="function_defined_terms",
),
pytest.param(
function_defined_terms_with_duplicate_names,
- "?.Nations(name=?.name, redefined_name=?.name, interval_7=COUNT(?.customers.WHERE(MONOTONIC(7000, ?.acctbal, 8000))), interval_4=COUNT(?.customers.WHERE(MONOTONIC(4000, ?.acctbal, 5000))), interval_13=COUNT(?.customers.WHERE(MONOTONIC(13000, ?.acctbal, 14000))))",
+ "Nations.CALCULATE(name=name, redefined_name=name, interval_7=COUNT(customers.WHERE(MONOTONIC(7000, acctbal, 8000))), interval_4=COUNT(customers.WHERE(MONOTONIC(4000, acctbal, 5000))), interval_13=COUNT(customers.WHERE(MONOTONIC(13000, acctbal, 14000))))",
id="function_defined_terms_with_duplicate_names",
- # marks=pytest.mark.skip(
- # "TODO: (gh #222) ensure PyDough code is compatible with full Python syntax "
- # ),
),
pytest.param(
lambda_defined_terms,
- "?.Nations(name=?.name, interval_7=COUNT(?.customers.WHERE(MONOTONIC(7000, ?.acctbal, 8000))), interval_4=COUNT(?.customers.WHERE(MONOTONIC(4000, ?.acctbal, 5000))), interval_13=COUNT(?.customers.WHERE(MONOTONIC(13000, ?.acctbal, 14000))))",
+ "Nations.CALCULATE(name=name, interval_7=COUNT(customers.WHERE(MONOTONIC(7000, acctbal, 8000))), interval_4=COUNT(customers.WHERE(MONOTONIC(4000, acctbal, 5000))), interval_13=COUNT(customers.WHERE(MONOTONIC(13000, acctbal, 14000))))",
id="lambda_defined_terms",
),
pytest.param(
dict_comp_terms,
- "?.Nations(name=?.name, interval_0=COUNT(?.customers.WHERE(MONOTONIC(0, ?.acctbal, 1000))), interval_1=COUNT(?.customers.WHERE(MONOTONIC(1000, ?.acctbal, 2000))), interval_2=COUNT(?.customers.WHERE(MONOTONIC(2000, ?.acctbal, 3000))))",
+ "Nations.CALCULATE(name=name, interval_0=COUNT(customers.WHERE(MONOTONIC(0, acctbal, 1000))), interval_1=COUNT(customers.WHERE(MONOTONIC(1000, acctbal, 2000))), interval_2=COUNT(customers.WHERE(MONOTONIC(2000, acctbal, 3000))))",
id="dict_comp_terms",
),
pytest.param(
list_comp_terms,
- "?.Nations(name=?.name, _expr0=COUNT(?.customers.WHERE(MONOTONIC(0, ?.acctbal, 1000))), _expr1=COUNT(?.customers.WHERE(MONOTONIC(1000, ?.acctbal, 2000))), _expr2=COUNT(?.customers.WHERE(MONOTONIC(2000, ?.acctbal, 3000))))",
+ "Nations.CALCULATE(name=name, _expr0=COUNT(customers.WHERE(MONOTONIC(0, acctbal, 1000))), _expr1=COUNT(customers.WHERE(MONOTONIC(1000, acctbal, 2000))), _expr2=COUNT(customers.WHERE(MONOTONIC(2000, acctbal, 3000))))",
id="list_comp_terms",
),
pytest.param(
set_comp_terms,
- "?.Nations(name=?.name, _expr0=COUNT(?.customers.WHERE(MONOTONIC(0, ?.acctbal, 1000))), _expr1=COUNT(?.customers.WHERE(MONOTONIC(1000, ?.acctbal, 2000))), _expr2=COUNT(?.customers.WHERE(MONOTONIC(2000, ?.acctbal, 3000))))",
+ "Nations.CALCULATE(_expr0=COUNT(customers.WHERE(MONOTONIC(0, acctbal, 1000))), _expr1=COUNT(customers.WHERE(MONOTONIC(1000, acctbal, 2000))), _expr2=COUNT(customers.WHERE(MONOTONIC(2000, acctbal, 3000))), name=name)",
id="set_comp_terms",
),
pytest.param(
generator_comp_terms,
- "?.Nations(name=?.name, interval_0=COUNT(?.customers.WHERE(MONOTONIC(0, ?.acctbal, 1000))), interval_1=COUNT(?.customers.WHERE(MONOTONIC(1000, ?.acctbal, 2000))), interval_2=COUNT(?.customers.WHERE(MONOTONIC(2000, ?.acctbal, 3000))))",
+ "Nations.CALCULATE(name=name, interval_0=COUNT(customers.WHERE(MONOTONIC(0, acctbal, 1000))), interval_1=COUNT(customers.WHERE(MONOTONIC(1000, acctbal, 2000))), interval_2=COUNT(customers.WHERE(MONOTONIC(2000, acctbal, 3000))))",
id="generator_comp_terms",
),
pytest.param(
args_kwargs,
- "?.TPCH(n_tomato=COUNT(?.parts.WHERE(CONTAINS(?.part_name, 'tomato'))), n_almond=COUNT(?.parts.WHERE(CONTAINS(?.part_name, 'almond'))), small=COUNT(?.parts.WHERE(True)), large=COUNT(?.parts.WHERE(True)))",
+ "TPCH.CALCULATE(n_tomato=COUNT(parts.WHERE(CONTAINS(part_name, 'tomato'))), n_almond=COUNT(parts.WHERE(CONTAINS(part_name, 'almond'))), small=COUNT(parts.WHERE(True)), large=COUNT(parts.WHERE(True)))",
id="args_kwargs",
),
pytest.param(
unpacking,
- "?.orders.WHERE(MONOTONIC(1992, YEAR(?.order_date), 1994))",
+ "orders.WHERE(MONOTONIC(1992, YEAR(order_date), 1994))",
id="unpacking",
),
pytest.param(
nested_unpacking,
- "?.customers.WHERE(ISIN(?.nation.name, ['GERMANY', 'FRANCE', 'ARGENTINA']))",
+ "customers.WHERE(ISIN(nation.name, ['GERMANY', 'FRANCE', 'ARGENTINA']))",
id="nested_unpacking",
),
pytest.param(
unpacking_in_iterable,
- "?.Nations(c0=COUNT(?.orders.WHERE((YEAR(?.order_date) == 1992))), c1=COUNT(?.orders.WHERE((YEAR(?.order_date) == 1993))), c2=COUNT(?.orders.WHERE((YEAR(?.order_date) == 1994))), c3=COUNT(?.orders.WHERE((YEAR(?.order_date) == 1995))), c4=COUNT(?.orders.WHERE((YEAR(?.order_date) == 1996))))",
+ "Nations.CALCULATE(c0=COUNT(orders.WHERE((YEAR(order_date) == 1992))), c1=COUNT(orders.WHERE((YEAR(order_date) == 1993))), c2=COUNT(orders.WHERE((YEAR(order_date) == 1994))), c3=COUNT(orders.WHERE((YEAR(order_date) == 1995))), c4=COUNT(orders.WHERE((YEAR(order_date) == 1996))))",
id="unpacking_in_iterable",
),
pytest.param(
with_import_statement,
- "?.customers.WHERE(ISIN(?.nation.name, ['Canada', 'Mexico']))",
+ "customers.WHERE(ISIN(nation.name, ['Canada', 'Mexico']))",
id="with_import_statement",
),
pytest.param(
exception_handling,
- "?.customers.WHERE(ISIN(?.nation.name, ['Canada', 'Mexico']))",
+ "customers.WHERE(ISIN(nation.name, ['Canada', 'Mexico']))",
id="exception_handling",
),
pytest.param(
class_handling,
- "?.customers.WHERE(ISIN(?.nation.name, ['Canada', 'Mexico']))",
+ "customers.WHERE(ISIN(nation.name, ['Canada', 'Mexico']))",
id="class_handling",
),
pytest.param(
annotated_assignment,
- "?.Nations.WHERE((?.region.name == 'SOUTH WEST AMERICA'))",
+ "Nations.WHERE((region.name == 'SOUTH WEST AMERICA'))",
id="annotated_assignment",
),
pytest.param(
abs_round_magic_method,
- "?.DailyPrices(abs_low=ABS(?.low), round_low=ROUND(?.low, 2), round_zero=ROUND(?.low, 0))",
+ "DailyPrices.CALCULATE(abs_low=ABS(low), round_low=ROUND(low, 2), round_zero=ROUND(low, 0))",
id="abs_round_magic_method",
),
],
@@ -661,7 +658,7 @@ def test_init_pydough_context(
),
pytest.param(
bad_iter,
- "Cannot index into PyDough object \?.customer with 0",
+ "Cannot index into PyDough object Customers with 0",
id="bad_iter",
),
],
diff --git a/tests/test_utils.py b/tests/test_utils.py
index f60b8b00..108342cd 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -5,9 +5,8 @@
__all__ = [
"AstNodeTestInfo",
- "BackReferenceCollectionInfo",
"BackReferenceExpressionInfo",
- "CalcInfo",
+ "CalculateInfo",
"ChildReferenceExpressionInfo",
"ColumnInfo",
"FunctionInfo",
@@ -32,7 +31,7 @@
from pydough.metadata import GraphMetadata
from pydough.qdag import (
AstNodeBuilder,
- Calc,
+ Calculate,
ChildOperatorChildAccess,
ChildReferenceExpression,
CollationExpression,
@@ -102,7 +101,8 @@ def build(
`context`: an optional collection QDAG used as the context within
which the QDAG is created.
`children_contexts`: an optional list of collection QDAGs of
- child nodes of a CALC that are accessible for ChildReferenceExpression usage.
+ child nodes of a CALCULATE that are accessible for
+ ChildReferenceExpression usage.
Returns:
The new instance of the QDAG object.
@@ -259,7 +259,7 @@ class ReferenceInfo(AstNodeTestInfo):
"""
TestInfo implementation class to build a reference. Contains the
following fields:
- - `name`: the name of the calc term being referenced.
+ - `name`: the name of the term being referenced from the preceding context.
"""
def __init__(self, name: str):
@@ -284,7 +284,7 @@ class BackReferenceExpressionInfo(AstNodeTestInfo):
"""
TestInfo implementation class to build a back reference expression.
Contains the following fields:
- - `name`: the name of the calc term being referenced.
+ - `name`: the name of the term being referenced.
- `levels`: the number of levels upward to reference.
"""
@@ -311,7 +311,7 @@ class ChildReferenceExpressionInfo(AstNodeTestInfo):
"""
TestInfo implementation class to build a child reference expression.
Contains the following fields:
- - `name`: the name of the calc term being referenced.
+ - `name`: the name of the term being referenced.
- `child_idx`: the index of the child being referenced.
"""
@@ -394,7 +394,8 @@ def local_build(
`context`: an optional collection QDAG used as the context within
which the QDAG is created.
`children_contexts`: an optional list of collection QDAG of child
- nodes of a CALC that are accessible for ChildReferenceExpression usage.
+ nodes of a CALCULATE that are accessible for
+ ChildReferenceExpression usage.
Returns:
The new instance of the collection QDAG object.
@@ -467,7 +468,7 @@ def local_build(
class ChildOperatorChildAccessInfo(CollectionTestInfo):
"""
CollectionTestInfo implementation class that wraps around a subcollection
- info within a Calc context. Contains the following fields:
+ info within a CALCULATE context. Contains the following fields:
- `child_info`: the collection info for the child subcollection.
NOTE: must provide a `context` when building.
@@ -495,41 +496,11 @@ def local_build(
return ChildOperatorChildAccess(access)
-class BackReferenceCollectionInfo(CollectionTestInfo):
- """
- CollectionTestInfo implementation class to build a reference to an
- ancestor collection. Contains the following fields:
- - `name`: the name of the calc term being referenced.
- - `levels`: the number of levels upward to reference.
-
- NOTE: must provide a `context` when building.
- """
-
- def __init__(self, name: str, levels: int):
- super().__init__()
- self.name: str = name
- self.levels: int = levels
-
- def local_string(self) -> str:
- return f"BackReferenceCollection[{self.levels}:{self.name}]"
-
- def local_build(
- self,
- builder: AstNodeBuilder,
- context: PyDoughCollectionQDAG | None = None,
- children_contexts: MutableSequence[PyDoughCollectionQDAG] | None = None,
- ) -> PyDoughCollectionQDAG:
- assert (
- context is not None
- ), "Cannot call .build() on BackReferenceCollectionInfo without providing a context"
- return builder.build_back_reference_collection(context, self.name, self.levels)
-
-
class ChildReferenceCollectionInfo(CollectionTestInfo):
"""
CollectionTestInfo implementation class to build a reference to a
child collection. Contains the following fields:
- - `idx`: the index of the calc term being referenced.
+ - `idx`: the index of the child collection being referenced.
NOTE: must provide a `context` when building.
"""
@@ -561,7 +532,7 @@ def local_build(
class ChildOperatorInfo(CollectionTestInfo):
"""
Base class for types of CollectionTestInfo that have child nodes, such as
- CALC or WHERE. Contains the following fields:
+ CALCULATE or WHERE. Contains the following fields:
- `children_info`: a list of CollectionTestInfo objects that will be used
to build the child contexts.
"""
@@ -605,14 +576,14 @@ def build_children(
return children
-class CalcInfo(ChildOperatorInfo):
+class CalculateInfo(ChildOperatorInfo):
"""
- CollectionTestInfo implementation class to build a CALC node.
+ CollectionTestInfo implementation class to build a CALCULATE node.
Contains the following fields:
- `children_info`: a list of CollectionTestInfo objects that will be used
to build the child contexts.
- `args`: a list tuples containing a field name and a test info to derive
- an expression in the CALC. Passed in via keyword arguments to the
+ an expression in the CALCULATE. Passed in via keyword arguments to the
constructor, where the argument names are the field names and the
argument values are the expression infos.
"""
@@ -625,7 +596,7 @@ def local_string(self) -> str:
args_strings: MutableSequence[str] = [
f"{name}={arg.to_string()}" for name, arg in self.args
]
- return f"Calc[{self.child_strings()}{', '.join(args_strings)}]"
+ return f"Calculate[{self.child_strings()}{', '.join(args_strings)}]"
def local_build(
self,
@@ -639,8 +610,8 @@ def local_build(
builder,
context,
)
- raw_calc = builder.build_calc(context, children)
- assert isinstance(raw_calc, Calc)
+ raw_calc = builder.build_calculate(context, children)
+ assert isinstance(raw_calc, Calculate)
args: MutableSequence[tuple[str, PyDoughExpressionQDAG]] = []
for name, info in self.args:
expr = info.build(builder, context, children)
diff --git a/tests/tpch_test_functions.py b/tests/tpch_test_functions.py
index c1df2cf8..a5176ff8 100644
--- a/tests/tpch_test_functions.py
+++ b/tests/tpch_test_functions.py
@@ -31,18 +31,22 @@ def impl_tpch_q1():
PyDough implementation of TPCH Q1.
"""
selected_lines = Lineitems.WHERE((ship_date <= datetime.date(1998, 12, 1)))
- return PARTITION(selected_lines, name="l", by=(return_flag, status))(
- L_RETURNFLAG=return_flag,
- L_LINESTATUS=status,
- SUM_QTY=SUM(l.quantity),
- SUM_BASE_PRICE=SUM(l.extended_price),
- SUM_DISC_PRICE=SUM(l.extended_price * (1 - l.discount)),
- SUM_CHARGE=SUM(l.extended_price * (1 - l.discount) * (1 + l.tax)),
- AVG_QTY=AVG(l.quantity),
- AVG_PRICE=AVG(l.extended_price),
- AVG_DISC=AVG(l.discount),
- COUNT_ORDER=COUNT(l),
- ).ORDER_BY(L_RETURNFLAG.ASC(), L_LINESTATUS.ASC())
+ return (
+ PARTITION(selected_lines, name="l", by=(return_flag, status))
+ .CALCULATE(
+ L_RETURNFLAG=return_flag,
+ L_LINESTATUS=status,
+ SUM_QTY=SUM(l.quantity),
+ SUM_BASE_PRICE=SUM(l.extended_price),
+ SUM_DISC_PRICE=SUM(l.extended_price * (1 - l.discount)),
+ SUM_CHARGE=SUM(l.extended_price * (1 - l.discount) * (1 + l.tax)),
+ AVG_QTY=AVG(l.quantity),
+ AVG_PRICE=AVG(l.extended_price),
+ AVG_DISC=AVG(l.discount),
+ COUNT_ORDER=COUNT(l),
+ )
+ .ORDER_BY(L_RETURNFLAG.ASC(), L_LINESTATUS.ASC())
+ )
def impl_tpch_q2():
@@ -50,26 +54,28 @@ def impl_tpch_q2():
PyDough implementation of TPCH Q2, truncated to 10 rows.
"""
selected_parts = (
- Nations.WHERE(region.name == "EUROPE")
- .suppliers.supply_records.part(
- s_acctbal=BACK(2).account_balance,
- s_name=BACK(2).name,
- n_name=BACK(3).name,
- s_address=BACK(2).address,
- s_phone=BACK(2).phone,
- s_comment=BACK(2).comment,
- supplycost=BACK(1).supplycost,
+ Nations.CALCULATE(n_name=name)
+ .WHERE(region.name == "EUROPE")
+ .suppliers.CALCULATE(
+ s_acctbal=account_balance,
+ s_name=name,
+ s_address=address,
+ s_phone=phone,
+ s_comment=comment,
+ )
+ .supply_records.CALCULATE(
+ supplycost=supplycost,
)
- .WHERE(ENDSWITH(part_type, "BRASS") & (size == 15))
+ .part.WHERE(ENDSWITH(part_type, "BRASS") & (size == 15))
)
return (
- PARTITION(selected_parts, name="p", by=key)(best_cost=MIN(p.supplycost))
+ PARTITION(selected_parts, name="p", by=key)
+ .CALCULATE(best_cost=MIN(p.supplycost))
.p.WHERE(
- (supplycost == BACK(1).best_cost)
- & ENDSWITH(part_type, "BRASS")
- & (size == 15)
- )(
+ (supplycost == best_cost) & ENDSWITH(part_type, "BRASS") & (size == 15)
+ )
+ .CALCULATE(
S_ACCTBAL=s_acctbal,
S_NAME=s_name,
N_NAME=n_name,
@@ -90,21 +96,25 @@ def impl_tpch_q3():
"""
PyDough implementation of TPCH Q3.
"""
- selected_lines = Orders.WHERE(
- (customer.mktsegment == "BUILDING") & (order_date < datetime.date(1995, 3, 15))
- ).lines.WHERE(ship_date > datetime.date(1995, 3, 15))(
- BACK(1).order_date,
- BACK(1).ship_priority,
+ selected_lines = (
+ Orders.CALCULATE(order_date, ship_priority)
+ .WHERE(
+ (customer.mktsegment == "BUILDING")
+ & (order_date < datetime.date(1995, 3, 15))
+ )
+ .lines.WHERE(ship_date > datetime.date(1995, 3, 15))
)
- return PARTITION(
- selected_lines, name="l", by=(order_key, order_date, ship_priority)
- )(
- L_ORDERKEY=order_key,
- REVENUE=SUM(l.extended_price * (1 - l.discount)),
- O_ORDERDATE=order_date,
- O_SHIPPRIORITY=ship_priority,
- ).TOP_K(10, by=(REVENUE.DESC(), O_ORDERDATE.ASC(), L_ORDERKEY.ASC()))
+ return (
+ PARTITION(selected_lines, name="l", by=(order_key, order_date, ship_priority))
+ .CALCULATE(
+ L_ORDERKEY=order_key,
+ REVENUE=SUM(l.extended_price * (1 - l.discount)),
+ O_ORDERDATE=order_date,
+ O_SHIPPRIORITY=ship_priority,
+ )
+ .TOP_K(10, by=(REVENUE.DESC(), O_ORDERDATE.ASC(), L_ORDERKEY.ASC()))
+ )
def impl_tpch_q4():
@@ -117,25 +127,34 @@ def impl_tpch_q4():
& (order_date < datetime.date(1993, 10, 1))
& HAS(selected_lines)
)
- return PARTITION(selected_orders, name="o", by=order_priority)(
- O_ORDERPRIORITY=order_priority,
- ORDER_COUNT=COUNT(o),
- ).ORDER_BY(O_ORDERPRIORITY.ASC())
+ return (
+ PARTITION(selected_orders, name="o", by=order_priority)
+ .CALCULATE(
+ O_ORDERPRIORITY=order_priority,
+ ORDER_COUNT=COUNT(o),
+ )
+ .ORDER_BY(O_ORDERPRIORITY.ASC())
+ )
def impl_tpch_q5():
"""
PyDough implementation of TPCH Q5.
"""
- selected_lines = customers.orders.WHERE(
- (order_date >= datetime.date(1994, 1, 1))
- & (order_date < datetime.date(1995, 1, 1))
- ).lines.WHERE(supplier.nation.name == BACK(3).name)(
- value=extended_price * (1 - discount)
+ selected_lines = (
+ customers.orders.WHERE(
+ (order_date >= datetime.date(1994, 1, 1))
+ & (order_date < datetime.date(1995, 1, 1))
+ )
+ .lines.WHERE(supplier.nation.name == nation_name)
+ .CALCULATE(value=extended_price * (1 - discount))
+ )
+ return (
+ Nations.CALCULATE(nation_name=name)
+ .WHERE(region.name == "ASIA")
+ .CALCULATE(N_NAME=name, REVENUE=SUM(selected_lines.value))
+ .ORDER_BY(REVENUE.DESC())
)
- return Nations.WHERE(region.name == "ASIA")(
- N_NAME=name, REVENUE=SUM(selected_lines.value)
- ).ORDER_BY(REVENUE.DESC())
def impl_tpch_q6():
@@ -148,15 +167,15 @@ def impl_tpch_q6():
& (0.05 <= discount)
& (discount <= 0.07)
& (quantity < 24)
- )(amt=extended_price * discount)
- return TPCH(REVENUE=SUM(selected_lines.amt))
+ ).CALCULATE(amt=extended_price * discount)
+ return TPCH.CALCULATE(REVENUE=SUM(selected_lines.amt))
def impl_tpch_q7():
"""
PyDough implementation of TPCH Q7.
"""
- line_info = Lineitems(
+ line_info = Lineitems.CALCULATE(
supp_nation=supplier.nation.name,
cust_nation=order.customer.nation.name,
l_year=YEAR(ship_date),
@@ -170,15 +189,19 @@ def impl_tpch_q7():
)
)
- return PARTITION(line_info, name="l", by=(supp_nation, cust_nation, l_year))(
- SUPP_NATION=supp_nation,
- CUST_NATION=cust_nation,
- L_YEAR=l_year,
- REVENUE=SUM(l.volume),
- ).ORDER_BY(
- SUPP_NATION.ASC(),
- CUST_NATION.ASC(),
- L_YEAR.ASC(),
+ return (
+ PARTITION(line_info, name="l", by=(supp_nation, cust_nation, l_year))
+ .CALCULATE(
+ SUPP_NATION=supp_nation,
+ CUST_NATION=cust_nation,
+ L_YEAR=l_year,
+ REVENUE=SUM(l.volume),
+ )
+ .ORDER_BY(
+ SUPP_NATION.ASC(),
+ CUST_NATION.ASC(),
+ L_YEAR.ASC(),
+ )
)
@@ -187,14 +210,12 @@ def impl_tpch_q8():
PyDough implementation of TPCH Q8.
"""
volume_data = (
- Nations.suppliers.supply_records.WHERE(
- part.part_type == "ECONOMY ANODIZED STEEL"
- )
- .lines(volume=extended_price * (1 - discount))
- .order(
+ Nations.CALCULATE(nation_name=name)
+ .suppliers.supply_records.WHERE(part.part_type == "ECONOMY ANODIZED STEEL")
+ .lines.CALCULATE(volume=extended_price * (1 - discount))
+ .order.CALCULATE(
o_year=YEAR(order_date),
- volume=BACK(1).volume,
- brazil_volume=IFF(BACK(4).name == "BRAZIL", BACK(1).volume, 0),
+ brazil_volume=IFF(nation_name == "BRAZIL", volume, 0),
)
.WHERE(
(order_date >= datetime.date(1995, 1, 1))
@@ -203,7 +224,7 @@ def impl_tpch_q8():
)
)
- return PARTITION(volume_data, name="v", by=o_year)(
+ return PARTITION(volume_data, name="v", by=o_year).CALCULATE(
O_YEAR=o_year,
MKT_SHARE=SUM(v.brazil_volume) / SUM(v.volume),
)
@@ -213,18 +234,22 @@ def impl_tpch_q9():
"""
PyDough implementation of TPCH Q9, truncated to 10 rows.
"""
- selected_lines = Nations.suppliers.supply_records.WHERE(
- CONTAINS(part.name, "green")
- ).lines(
- nation=BACK(3).name,
- o_year=YEAR(order.order_date),
- value=extended_price * (1 - discount) - BACK(1).supplycost * quantity,
+ selected_lines = (
+ Nations.CALCULATE(nation_name=name)
+ .suppliers.supply_records.CALCULATE(supplycost)
+ .WHERE(CONTAINS(part.name, "green"))
+ .lines.CALCULATE(
+ o_year=YEAR(order.order_date),
+ value=extended_price * (1 - discount) - supplycost * quantity,
+ )
)
- return PARTITION(selected_lines, name="l", by=(nation, o_year))(
- NATION=nation, O_YEAR=o_year, AMOUNT=SUM(l.value)
- ).TOP_K(
- 10,
- by=(NATION.ASC(), O_YEAR.DESC()),
+ return (
+ PARTITION(selected_lines, name="l", by=(nation_name, o_year))
+ .CALCULATE(NATION=nation_name, O_YEAR=o_year, AMOUNT=SUM(l.value))
+ .TOP_K(
+ 10,
+ by=(NATION.ASC(), O_YEAR.DESC()),
+ )
)
@@ -232,11 +257,15 @@ def impl_tpch_q10():
"""
PyDough implementation of TPCH Q10.
"""
- selected_lines = orders.WHERE(
- (order_date >= datetime.date(1993, 10, 1))
- & (order_date < datetime.date(1994, 1, 1))
- ).lines.WHERE(return_flag == "R")(amt=extended_price * (1 - discount))
- return Customers(
+ selected_lines = (
+ orders.WHERE(
+ (order_date >= datetime.date(1993, 10, 1))
+ & (order_date < datetime.date(1994, 1, 1))
+ )
+ .lines.WHERE(return_flag == "R")
+ .CALCULATE(amt=extended_price * (1 - discount))
+ )
+ return Customers.CALCULATE(
C_CUSTKEY=key,
C_NAME=name,
REVENUE=SUM(selected_lines.amt),
@@ -253,13 +282,14 @@ def impl_tpch_q11():
PyDough implementation of TPCH Q11, truncated to 10 rows
"""
is_german_supplier = supplier.nation.name == "GERMANY"
- selected_records = PartSupp.WHERE(is_german_supplier)(metric=supplycost * availqty)
+ selected_records = PartSupp.WHERE(is_german_supplier).CALCULATE(
+ metric=supplycost * availqty
+ )
return (
- TPCH(min_market_share=SUM(selected_records.metric) * 0.0001)
- .PARTITION(selected_records, name="ps", by=part_key)(
- PS_PARTKEY=part_key, VALUE=SUM(ps.metric)
- )
- .WHERE(VALUE > BACK(1).min_market_share)
+ TPCH.CALCULATE(min_market_share=SUM(selected_records.metric) * 0.0001)
+ .PARTITION(selected_records, name="ps", by=part_key)
+ .CALCULATE(PS_PARTKEY=part_key, VALUE=SUM(ps.metric))
+ .WHERE(VALUE > min_market_share)
.TOP_K(10, by=VALUE.DESC())
)
@@ -274,30 +304,32 @@ def impl_tpch_q12():
& (commit_date < receipt_date)
& (receipt_date >= datetime.date(1994, 1, 1))
& (receipt_date < datetime.date(1995, 1, 1))
- )(
+ ).CALCULATE(
is_high_priority=(order.order_priority == "1-URGENT")
| (order.order_priority == "2-HIGH"),
)
- return PARTITION(selected_lines, "l", by=ship_mode)(
- L_SHIPMODE=ship_mode,
- HIGH_LINE_COUNT=SUM(l.is_high_priority),
- LOW_LINE_COUNT=SUM(~(l.is_high_priority)),
- ).ORDER_BY(L_SHIPMODE.ASC())
+ return (
+ PARTITION(selected_lines, "l", by=ship_mode)
+ .CALCULATE(
+ L_SHIPMODE=ship_mode,
+ HIGH_LINE_COUNT=SUM(l.is_high_priority),
+ LOW_LINE_COUNT=SUM(~(l.is_high_priority)),
+ )
+ .ORDER_BY(L_SHIPMODE.ASC())
+ )
def impl_tpch_q13():
"""
PyDough implementation of TPCH Q13, truncated to 10 rows.
"""
- customer_info = Customers(
- key,
- num_non_special_orders=COUNT(
- orders.WHERE(~(LIKE(comment, "%special%requests%")))
- ),
+ selected_orders = orders.WHERE(~(LIKE(comment, "%special%requests%")))
+ customer_info = Customers.CALCULATE(num_non_special_orders=COUNT(selected_orders))
+ return (
+ PARTITION(customer_info, name="custs", by=num_non_special_orders)
+ .CALCULATE(C_COUNT=num_non_special_orders, CUSTDIST=COUNT(custs))
+ .TOP_K(10, by=(CUSTDIST.DESC(), C_COUNT.DESC()))
)
- return PARTITION(customer_info, name="custs", by=num_non_special_orders)(
- C_COUNT=num_non_special_orders, CUSTDIST=COUNT(custs)
- ).TOP_K(10, by=(CUSTDIST.DESC(), C_COUNT.DESC()))
def impl_tpch_q14():
@@ -308,11 +340,11 @@ def impl_tpch_q14():
selected_lines = Lineitems.WHERE(
(ship_date >= datetime.date(1995, 9, 1))
& (ship_date < datetime.date(1995, 10, 1))
- )(
+ ).CALCULATE(
value=value,
promo_value=IFF(STARTSWITH(part.part_type, "PROMO"), value, 0),
)
- return TPCH(
+ return TPCH.CALCULATE(
PROMO_REVENUE=100.0
* SUM(selected_lines.promo_value)
/ SUM(selected_lines.value)
@@ -329,15 +361,17 @@ def impl_tpch_q15():
)
total = SUM(selected_lines.extended_price * (1 - selected_lines.discount))
return (
- TPCH(max_revenue=MAX(Suppliers(total_revenue=total).total_revenue))
- .Suppliers(
+ TPCH.CALCULATE(
+ max_revenue=MAX(Suppliers.CALCULATE(total_revenue=total).total_revenue)
+ )
+ .Suppliers.CALCULATE(
S_SUPPKEY=key,
S_NAME=name,
S_ADDRESS=address,
S_PHONE=phone,
TOTAL_REVENUE=total,
)
- .WHERE(TOTAL_REVENUE == BACK(1).max_revenue)
+ .WHERE(TOTAL_REVENUE == max_revenue)
.ORDER_BY(S_SUPPKEY.ASC())
)
@@ -352,30 +386,36 @@ def impl_tpch_q16():
& ~STARTSWITH(part_type, "MEDIUM POLISHED%")
& ISIN(size, [49, 14, 23, 45, 19, 3, 36, 9])
)
- .supply_records(
- p_brand=BACK(1).brand,
- p_type=BACK(1).part_type,
- p_size=BACK(1).size,
- ps_suppkey=supplier_key,
+ .CALCULATE(
+ p_brand=brand,
+ p_type=part_type,
+ p_size=size,
+ )
+ .supply_records.WHERE(~LIKE(supplier.comment, "%Customer%Complaints%"))
+ )
+ return (
+ PARTITION(selected_records, name="ps", by=(p_brand, p_type, p_size))
+ .CALCULATE(
+ P_BRAND=p_brand,
+ P_TYPE=p_type,
+ P_SIZE=p_size,
+ SUPPLIER_COUNT=NDISTINCT(ps.supplier_key),
+ )
+ .TOP_K(
+ 10, by=(SUPPLIER_COUNT.DESC(), P_BRAND.ASC(), P_TYPE.ASC(), P_SIZE.ASC())
)
- .WHERE(~LIKE(supplier.comment, "%Customer%Complaints%"))
)
- return PARTITION(selected_records, name="ps", by=(p_brand, p_type, p_size))(
- P_BRAND=p_brand,
- P_TYPE=p_type,
- P_SIZE=p_size,
- SUPPLIER_COUNT=NDISTINCT(ps.supplier_key),
- ).TOP_K(10, by=(SUPPLIER_COUNT.DESC(), P_BRAND.ASC(), P_TYPE.ASC(), P_SIZE.ASC()))
def impl_tpch_q17():
"""
PyDough implementation of TPCH Q17.
"""
- selected_lines = Parts.WHERE((brand == "Brand#23") & (container == "MED BOX"))(
- avg_quantity=AVG(lines.quantity)
- ).lines.WHERE(quantity < 0.2 * BACK(1).avg_quantity)
- return TPCH(AVG_YEARLY=SUM(selected_lines.extended_price) / 7.0)
+ part_info = Parts.WHERE((brand == "Brand#23") & (container == "MED BOX")).CALCULATE(
+ part_avg_quantity=AVG(lines.quantity)
+ )
+ selected_lines = part_info.lines.WHERE(quantity < 0.2 * part_avg_quantity)
+ return TPCH.CALCULATE(AVG_YEARLY=SUM(selected_lines.extended_price) / 7.0)
def impl_tpch_q18():
@@ -383,7 +423,7 @@ def impl_tpch_q18():
PyDough implementation of TPCH Q18, truncated to 10 rows
"""
return (
- Orders(
+ Orders.CALCULATE(
C_NAME=customer.name,
C_CUSTKEY=customer.key,
O_ORDERKEY=key,
@@ -440,7 +480,7 @@ def impl_tpch_q19():
)
)
)
- return TPCH(
+ return TPCH.CALCULATE(
REVENUE=SUM(selected_lines.extended_price * (1 - selected_lines.discount))
)
@@ -455,12 +495,12 @@ def impl_tpch_q20():
& (ship_date < datetime.date(1995, 1, 1))
).quantity
)
- selected_part_supplied = supply_records.part.WHERE(
- STARTSWITH(name, "forest") & (BACK(1).availqty > part_qty * 0.5)
+ selected_part_supplied = supply_records.CALCULATE(availqty).part.WHERE(
+ STARTSWITH(name, "forest") & (availqty > part_qty * 0.5)
)
return (
- Suppliers(
+ Suppliers.CALCULATE(
S_NAME=name,
S_ADDRESS=address,
)
@@ -474,18 +514,23 @@ def impl_tpch_q21():
PyDough implementation of TPCH Q21, truncated to 10 rows.
"""
date_check = receipt_date > commit_date
- different_supplier = supplier_key != BACK(2).supplier_key
- waiting_entries = lines.WHERE(date_check).order.WHERE(
+ selected_orders = lines.CALCULATE(original_key=supplier_key).WHERE(date_check).order
+ different_supplier = supplier_key != original_key
+ waiting_entries = selected_orders.WHERE(
(order_status == "F")
& HAS(lines.WHERE(different_supplier))
& HASNOT(lines.WHERE(different_supplier & date_check))
)
- return Suppliers.WHERE(nation.name == "SAUDI ARABIA")(
- S_NAME=name,
- NUMWAIT=COUNT(waiting_entries),
- ).TOP_K(
- 10,
- by=(NUMWAIT.DESC(), S_NAME.ASC()),
+ return (
+ Suppliers.WHERE(nation.name == "SAUDI ARABIA")
+ .CALCULATE(
+ S_NAME=name,
+ NUMWAIT=COUNT(waiting_entries),
+ )
+ .TOP_K(
+ 10,
+ by=(NUMWAIT.DESC(), S_NAME.ASC()),
+ )
)
@@ -493,18 +538,21 @@ def impl_tpch_q22():
"""
PyDough implementation of TPCH Q22.
"""
- selected_customers = Customers(cntry_code=phone[:2]).WHERE(
+ selected_customers = Customers.CALCULATE(cntry_code=phone[:2]).WHERE(
ISIN(cntry_code, ("13", "31", "23", "29", "30", "18", "17"))
)
return (
- TPCH(avg_balance=AVG(selected_customers.WHERE(acctbal > 0.0).acctbal))
+ TPCH.CALCULATE(
+ global_avg_balance=AVG(selected_customers.WHERE(acctbal > 0.0).acctbal)
+ )
.PARTITION(
selected_customers.WHERE(
- (acctbal > BACK(1).avg_balance) & (COUNT(orders) == 0)
+ (acctbal > global_avg_balance) & (COUNT(orders) == 0)
),
name="custs",
by=cntry_code,
- )(
+ )
+ .CALCULATE(
CNTRY_CODE=cntry_code,
NUM_CUSTS=COUNT(custs),
TOTACCTBAL=SUM(custs.acctbal),