Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for global CALC nodes #23

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pydough/pydough_ast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"SubCollection",
"Calc",
"CalcSubCollection",
"GlobalCalc",
]

from .abstract_pydough_ast import PyDoughAST
Expand All @@ -31,5 +32,6 @@
SubCollection,
Calc,
CalcSubCollection,
GlobalCalc,
)
from .node_builder import AstNodeBuilder
2 changes: 2 additions & 0 deletions pydough/pydough_ast/collections/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"CalcSubCollection",
"BackReferenceCollection",
"HiddenBackReferenceCollection",
"GlobalCalc",
]

from .collection_ast import PyDoughCollectionAST
Expand All @@ -19,3 +20,4 @@
from .calc_sub_collection import CalcSubCollection
from .back_reference_collection import BackReferenceCollection
from .hidden_back_reference_collection import HiddenBackReferenceCollection
from .global_calc import GlobalCalc
2 changes: 1 addition & 1 deletion pydough/pydough_ast/collections/calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(
self._calc_term_indices: Dict[str, Tuple[int, PyDoughExpressionAST]] | None = (
None
)
self._all_terms: Dict[str, PyDoughExpressionAST] = None
self._all_terms: Dict[str, PyDoughAST] = None

def with_terms(self, terms: List[Tuple[str, PyDoughExpressionAST]]) -> "Calc":
"""
Expand Down
127 changes: 127 additions & 0 deletions pydough/pydough_ast/collections/global_calc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""
TODO: add file-level docstring
"""

__all__ = ["GlobalCalc"]


from typing import Dict, List, Tuple, Set

from pydough.metadata import GraphMetadata
from pydough.pydough_ast.abstract_pydough_ast import PyDoughAST
from pydough.pydough_ast.errors import PyDoughASTException
from pydough.pydough_ast.expressions import PyDoughExpressionAST
from .collection_ast import PyDoughCollectionAST
from .table_collection import TableCollection


class GlobalCalc(PyDoughCollectionAST):
"""
The AST node implementation class representing a top-level CALC expression
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not saying there is any need to change this, but could there ever be a situation where the "global context" is really graph-level context in the future, meaning tied to a particular graph but not any node in the graph, for when multiple graphs are supported?

If there is then maybe GraphLevelCalc which could have all the same properties is a more accurate description.

Copy link
Contributor Author

@knassre-bodo knassre-bodo Nov 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I just overhauled this by making a "global context" which table collections & calcs can be children of (the necessity of this became quite apparent when I started doing the tree strings), eliminating the need for this class. I'm closing this PR since it is now redundant.

without a parent context.
"""

def __init__(
self,
graph: GraphMetadata,
children: List[PyDoughCollectionAST],
):
self._graph: GraphMetadata = graph
self._children: List[PyDoughCollectionAST] = children
# Not defined until with_terms is called
self._calc_term_indices: Dict[str, Tuple[int, PyDoughExpressionAST]] | None = (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reason for having an optional type as opposed to an empty dictionary? Is this because the dictionary is effectively "frozen" once initialized so this is avoid any invalid behavior?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Explained in the next section. This seems fine.

None
)
self._all_terms: Dict[str, PyDoughAST] = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Type annotation needs optional


def with_terms(self, terms: List[Tuple[str, PyDoughExpressionAST]]) -> "GlobalCalc":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm assuming this is necessary because terms can't be assigned until a later time. This is a less desirable design pattern because it makes it harder to track when the state is valid, but it's an understandable limitation.

Can you write the docstring now though so we can be very clear about the explained delay between with_terms and the constructor?

"""
TODO: add function docstring
"""
if self._calc_term_indices is not None:
raise PyDoughCollectionAST(
"Cannot call `with_terms` more than once per GlobalCalc node"
)
self._calc_term_indices = {name: idx for idx, (name, _) in enumerate(terms)}
# Include terms from the graph itself, with the terms from this CALC
# added in (overwriting any preceding properties with the same name)
self._all_terms = {}
for name in self.graph.get_collection_names():
self._all_terms[name] = TableCollection(self.graph.get_collection(name))
for name, property in terms:
self._all_terms[name] = property
return self

@property
def graph(self) -> GraphMetadata:
"""
The graph that the global calc node is being done within.
"""
return self._graph

@property
def children(self) -> List[PyDoughCollectionAST]:
"""
The child collections accessible from the global CALC used to derive
expressions in terms of a subcollection.
"""
return self._children

@property
def calc_term_indices(self) -> Dict[str, Tuple[int, PyDoughExpressionAST]]:
"""
Mapping of each named expression of the CALC to a tuple (idx, expr)
where idx is the ordinal position of the property when included
in a CALC and property is the AST node representing the property.
"""
if self._calc_term_indices is None:
raise PyDoughCollectionAST(
"Cannot invoke `calc_term_indices` before calling `with_terms`"
)
return self._calc_term_indices

@property
def ancestor_context(self) -> PyDoughCollectionAST | None:
return None

@property
def preceding_context(self) -> PyDoughCollectionAST | None:
return None

@property
def calc_terms(self) -> Set[str]:
return set(self.calc_term_indices)

@property
def all_terms(self) -> Set[str]:
return set(self._all_terms)

def get_expression_position(self, expr_name: str) -> int:
if expr_name not in self.calc_term_indices:
raise PyDoughASTException(f"Unrecognized CALC term: {expr_name!r}")
return self.calc_term_indices[expr_name]

def get_term(self, term_name: str) -> PyDoughAST:
if term_name not in self.all_terms:
raise PyDoughASTException(f"Unrecognized term: {term_name!r}")
return self._all_terms[term_name]

def to_string(self) -> str:
kwarg_strings: List[str] = []
for name in self._calc_term_indices:
expr: PyDoughExpressionAST = self.get_term(name)
kwarg_strings.append(f"{name}={expr.to_string()}")
return f"{self.graph.name}({', '.join(kwarg_strings)})"

def to_tree_string(self) -> str:
raise NotImplementedError

def equals(self, other: "GlobalCalc") -> bool:
if self._all_terms is None:
raise PyDoughCollectionAST(
"Cannot invoke `equals` before calling `with_terms`"
)
return (
super().equals(other)
and self._calc_term_indices == other._calc_term_indices
)
25 changes: 19 additions & 6 deletions pydough/pydough_ast/node_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
Calc,
CalcSubCollection,
BackReferenceCollection,
GlobalCalc,
)


Expand Down Expand Up @@ -188,6 +189,21 @@ def build_back_reference_expression(
"""
return BackReferenceExpression(collection, name, levels)

def build_global_calc(
self, graph: GraphMetadata, children: List[CalcSubCollection]
) -> GlobalCalc:
"""
Creates a new global CALC instance, but `with_terms` still needs to be
called on the output.

Args:
`children`: the child subcollections accessed by the CALC term.

Returns:
The newly created GlobalCalc.
"""
return GlobalCalc(graph, children)

def build_table_collection(self, name: str) -> TableCollection:
"""
Creates a new table collection invocation.
Expand Down Expand Up @@ -234,8 +250,8 @@ def build_calc(
children: List[CalcSubCollection],
) -> Calc:
"""
Creates a CALC term, but `with_terms` still needs to be called on the
output.
Creates a CALC instance, but `with_terms` still needs to be called on
the output.

Args:
`collection`: the preceding collection.
Expand All @@ -244,11 +260,8 @@ def build_calc(

Returns:
The newly created PyDough CALC term.

Raises:
`PyDoughASTException`: if the terms are invalid for the CALC term.
"""
return Calc(collection, CalcSubCollection)
return Calc(collection, children)

def build_back_reference_collection(
self,
Expand Down
40 changes: 40 additions & 0 deletions tests/test_ast_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
CalcInfo,
ChildReferenceInfo,
BackReferenceCollectionInfo,
GlobalCalcInfo,
)
import pytest

Expand Down Expand Up @@ -768,6 +769,45 @@ def test_collections_calc_terms(
"Suppliers.parts_supplied(nation_name=nation(name=name).name, supplier_name=BACK(1).name, part_name=name, ratio=ps_lines(ratio=quantity / BACK(1).ps_availqty).ratio)",
id="suppliers_parts_childcalc_b",
),
pytest.param(
(
GlobalCalcInfo(
[TableCollectionInfo("Customers")],
total_balance=FunctionInfo(
"SUM", [ChildReferenceInfo("acctbal", 0)]
),
)
),
"TPCH(total_balance=SUM(Customers.acctbal))",
id="globalcalc_a",
),
pytest.param(
(
GlobalCalcInfo(
[
TableCollectionInfo("Customers"),
TableCollectionInfo("Suppliers")
** SubCollectionInfo("parts_supplied")
** CalcInfo(
[],
value=FunctionInfo(
"MUL",
[
ReferenceInfo("ps_availqty"),
ReferenceInfo("retail_price"),
],
),
),
],
total_demand=FunctionInfo(
"SUM", [ChildReferenceInfo("acctbal", 0)]
),
total_supply=FunctionInfo("SUM", [ChildReferenceInfo("value", 1)]),
)
),
"TPCH(total_demand=SUM(Customers.acctbal), total_supply=SUM(Suppliers.parts_supplied(value=ps_availqty * retail_price).value))",
id="globalcalc_b",
),
],
)
def test_collections_to_string(
Expand Down
30 changes: 28 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"BackReferenceExpressionInfo",
"ChildReferenceInfo",
"BackReferenceCollectionInfo",
"GlobalCalcInfo",
]

from pydough.metadata import GraphMetadata
Expand All @@ -27,6 +28,7 @@
PyDoughExpressionAST,
Calc,
CalcSubCollection,
GlobalCalc,
)
from typing import Dict, Set, Callable, Any, List, Tuple
from pydough.types import PyDoughType
Expand Down Expand Up @@ -156,7 +158,7 @@ def __init__(self, function_name: str, args_info: List[AstNodeTestInfo]):
self.args_info: List[AstNodeTestInfo] = args_info

def to_string(self) -> str:
arg_strings: List[str] = [arg.to_string for arg in self.args_info]
arg_strings: List[str] = [arg.to_string() for arg in self.args_info]
return f"Call[{self.function_name} on ({', '.join(arg_strings)})]"

def build(
Expand Down Expand Up @@ -245,7 +247,7 @@ def build(
children_contexts: List[PyDoughCollectionAST] | None = None,
) -> PyDoughAST:
assert (
context is not None
children_contexts is not None
), "Cannot call .build() on ChildReferenceInfo without providing a list of child contexts"
return builder.build_child_reference(
children_contexts, self.child_idx, self.name
Expand Down Expand Up @@ -433,6 +435,30 @@ def local_build(
return raw_calc.with_terms(args)


class GlobalCalcInfo(CalcInfo):
"""
CollectionTestInfo implementation class to build a global CALC.
"""

def to_string(self) -> str:
return f"Global{super().to_string()}"

def local_build(
self,
builder: AstNodeBuilder,
context: PyDoughCollectionAST | None = None,
children_contexts: List[PyDoughCollectionAST] | None = None,
) -> PyDoughCollectionAST:
children: List[PyDoughCollectionAST] = [
child.build(builder, context) for child in self.children_info
]
raw_calc: GlobalCalc = builder.build_global_calc(builder.graph, children)
args: List[Tuple[str, PyDoughExpressionAST]] = [
(name, info.build(builder, context, children)) for name, info in self.args
]
return raw_calc.with_terms(args)


class BackReferenceCollectionInfo(CollectionTestInfo):
"""
CollectionTestInfo implementation class to build a reference to an
Expand Down