From efa948fa570175c1cdc586f8bccad0d856c537cc Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Fri, 1 Nov 2024 14:24:11 -0400 Subject: [PATCH] Adding global calc definitions and simple tests [RUN CI] --- pydough/pydough_ast/__init__.py | 2 + pydough/pydough_ast/collections/__init__.py | 2 + pydough/pydough_ast/collections/calc.py | 2 +- .../pydough_ast/collections/global_calc.py | 127 ++++++++++++++++++ pydough/pydough_ast/node_builder.py | 25 +++- tests/test_ast_collection.py | 40 ++++++ tests/test_utils.py | 30 ++++- 7 files changed, 219 insertions(+), 9 deletions(-) create mode 100644 pydough/pydough_ast/collections/global_calc.py diff --git a/pydough/pydough_ast/__init__.py b/pydough/pydough_ast/__init__.py index 2921186f..d92cfc57 100644 --- a/pydough/pydough_ast/__init__.py +++ b/pydough/pydough_ast/__init__.py @@ -15,6 +15,7 @@ "SubCollection", "Calc", "CalcSubCollection", + "GlobalCalc", ] from .abstract_pydough_ast import PyDoughAST @@ -31,5 +32,6 @@ SubCollection, Calc, CalcSubCollection, + GlobalCalc, ) from .node_builder import AstNodeBuilder diff --git a/pydough/pydough_ast/collections/__init__.py b/pydough/pydough_ast/collections/__init__.py index 65b59774..1a58e186 100644 --- a/pydough/pydough_ast/collections/__init__.py +++ b/pydough/pydough_ast/collections/__init__.py @@ -10,6 +10,7 @@ "CalcSubCollection", "BackReferenceCollection", "HiddenBackReferenceCollection", + "GlobalCalc", ] from .collection_ast import PyDoughCollectionAST @@ -19,3 +20,4 @@ from .calc_sub_collection import CalcSubCollection from .back_reference_collection import BackReferenceCollection from .hidden_back_reference_collection import HiddenBackReferenceCollection +from .global_calc import GlobalCalc diff --git a/pydough/pydough_ast/collections/calc.py b/pydough/pydough_ast/collections/calc.py index b1021672..0e60cfab 100644 --- a/pydough/pydough_ast/collections/calc.py +++ b/pydough/pydough_ast/collections/calc.py @@ -31,7 +31,7 @@ def __init__( self._calc_term_indices: Dict[str, Tuple[int, PyDoughExpressionAST]] | None = ( None ) - self._all_terms: Dict[str, PyDoughExpressionAST] = None + self._all_terms: Dict[str, PyDoughAST] = None def with_terms(self, terms: List[Tuple[str, PyDoughExpressionAST]]) -> "Calc": """ diff --git a/pydough/pydough_ast/collections/global_calc.py b/pydough/pydough_ast/collections/global_calc.py new file mode 100644 index 00000000..84c3d4df --- /dev/null +++ b/pydough/pydough_ast/collections/global_calc.py @@ -0,0 +1,127 @@ +""" +TODO: add file-level docstring +""" + +__all__ = ["GlobalCalc"] + + +from typing import Dict, List, Tuple, Set + +from pydough.metadata import GraphMetadata +from pydough.pydough_ast.abstract_pydough_ast import PyDoughAST +from pydough.pydough_ast.errors import PyDoughASTException +from pydough.pydough_ast.expressions import PyDoughExpressionAST +from .collection_ast import PyDoughCollectionAST +from .table_collection import TableCollection + + +class GlobalCalc(PyDoughCollectionAST): + """ + The AST node implementation class representing a top-level CALC expression + without a parent context. + """ + + def __init__( + self, + graph: GraphMetadata, + children: List[PyDoughCollectionAST], + ): + self._graph: GraphMetadata = graph + self._children: List[PyDoughCollectionAST] = children + # Not defined until with_terms is called + self._calc_term_indices: Dict[str, Tuple[int, PyDoughExpressionAST]] | None = ( + None + ) + self._all_terms: Dict[str, PyDoughAST] = None + + def with_terms(self, terms: List[Tuple[str, PyDoughExpressionAST]]) -> "GlobalCalc": + """ + TODO: add function docstring + """ + if self._calc_term_indices is not None: + raise PyDoughCollectionAST( + "Cannot call `with_terms` more than once per GlobalCalc node" + ) + self._calc_term_indices = {name: idx for idx, (name, _) in enumerate(terms)} + # Include terms from the graph itself, with the terms from this CALC + # added in (overwriting any preceding properties with the same name) + self._all_terms = {} + for name in self.graph.get_collection_names(): + self._all_terms[name] = TableCollection(self.graph.get_collection(name)) + for name, property in terms: + self._all_terms[name] = property + return self + + @property + def graph(self) -> GraphMetadata: + """ + The graph that the global calc node is being done within. + """ + return self._graph + + @property + def children(self) -> List[PyDoughCollectionAST]: + """ + The child collections accessible from the global CALC used to derive + expressions in terms of a subcollection. + """ + return self._children + + @property + def calc_term_indices(self) -> Dict[str, Tuple[int, PyDoughExpressionAST]]: + """ + Mapping of each named expression of the CALC to a tuple (idx, expr) + where idx is the ordinal position of the property when included + in a CALC and property is the AST node representing the property. + """ + if self._calc_term_indices is None: + raise PyDoughCollectionAST( + "Cannot invoke `calc_term_indices` before calling `with_terms`" + ) + return self._calc_term_indices + + @property + def ancestor_context(self) -> PyDoughCollectionAST | None: + return None + + @property + def preceding_context(self) -> PyDoughCollectionAST | None: + return None + + @property + def calc_terms(self) -> Set[str]: + return set(self.calc_term_indices) + + @property + def all_terms(self) -> Set[str]: + return set(self._all_terms) + + def get_expression_position(self, expr_name: str) -> int: + if expr_name not in self.calc_term_indices: + raise PyDoughASTException(f"Unrecognized CALC term: {expr_name!r}") + return self.calc_term_indices[expr_name] + + def get_term(self, term_name: str) -> PyDoughAST: + if term_name not in self.all_terms: + raise PyDoughASTException(f"Unrecognized term: {term_name!r}") + return self._all_terms[term_name] + + def to_string(self) -> str: + kwarg_strings: List[str] = [] + for name in self._calc_term_indices: + expr: PyDoughExpressionAST = self.get_term(name) + kwarg_strings.append(f"{name}={expr.to_string()}") + return f"{self.graph.name}({', '.join(kwarg_strings)})" + + def to_tree_string(self) -> str: + raise NotImplementedError + + def equals(self, other: "GlobalCalc") -> bool: + if self._all_terms is None: + raise PyDoughCollectionAST( + "Cannot invoke `equals` before calling `with_terms`" + ) + return ( + super().equals(other) + and self._calc_term_indices == other._calc_term_indices + ) diff --git a/pydough/pydough_ast/node_builder.py b/pydough/pydough_ast/node_builder.py index a0f90ecb..45cb3b91 100644 --- a/pydough/pydough_ast/node_builder.py +++ b/pydough/pydough_ast/node_builder.py @@ -31,6 +31,7 @@ Calc, CalcSubCollection, BackReferenceCollection, + GlobalCalc, ) @@ -188,6 +189,21 @@ def build_back_reference_expression( """ return BackReferenceExpression(collection, name, levels) + def build_global_calc( + self, graph: GraphMetadata, children: List[CalcSubCollection] + ) -> GlobalCalc: + """ + Creates a new global CALC instance, but `with_terms` still needs to be + called on the output. + + Args: + `children`: the child subcollections accessed by the CALC term. + + Returns: + The newly created GlobalCalc. + """ + return GlobalCalc(graph, children) + def build_table_collection(self, name: str) -> TableCollection: """ Creates a new table collection invocation. @@ -234,8 +250,8 @@ def build_calc( children: List[CalcSubCollection], ) -> Calc: """ - Creates a CALC term, but `with_terms` still needs to be called on the - output. + Creates a CALC instance, but `with_terms` still needs to be called on + the output. Args: `collection`: the preceding collection. @@ -244,11 +260,8 @@ def build_calc( Returns: The newly created PyDough CALC term. - - Raises: - `PyDoughASTException`: if the terms are invalid for the CALC term. """ - return Calc(collection, CalcSubCollection) + return Calc(collection, children) def build_back_reference_collection( self, diff --git a/tests/test_ast_collection.py b/tests/test_ast_collection.py index cce8beab..8952f7c6 100644 --- a/tests/test_ast_collection.py +++ b/tests/test_ast_collection.py @@ -20,6 +20,7 @@ CalcInfo, ChildReferenceInfo, BackReferenceCollectionInfo, + GlobalCalcInfo, ) import pytest @@ -768,6 +769,45 @@ def test_collections_calc_terms( "Suppliers.parts_supplied(nation_name=nation(name=name).name, supplier_name=BACK(1).name, part_name=name, ratio=ps_lines(ratio=quantity / BACK(1).ps_availqty).ratio)", id="suppliers_parts_childcalc_b", ), + pytest.param( + ( + GlobalCalcInfo( + [TableCollectionInfo("Customers")], + total_balance=FunctionInfo( + "SUM", [ChildReferenceInfo("acctbal", 0)] + ), + ) + ), + "TPCH(total_balance=SUM(Customers.acctbal))", + id="globalcalc_a", + ), + pytest.param( + ( + GlobalCalcInfo( + [ + TableCollectionInfo("Customers"), + TableCollectionInfo("Suppliers") + ** SubCollectionInfo("parts_supplied") + ** CalcInfo( + [], + value=FunctionInfo( + "MUL", + [ + ReferenceInfo("ps_availqty"), + ReferenceInfo("retail_price"), + ], + ), + ), + ], + total_demand=FunctionInfo( + "SUM", [ChildReferenceInfo("acctbal", 0)] + ), + total_supply=FunctionInfo("SUM", [ChildReferenceInfo("value", 1)]), + ) + ), + "TPCH(total_demand=SUM(Customers.acctbal), total_supply=SUM(Suppliers.parts_supplied(value=ps_availqty * retail_price).value))", + id="globalcalc_b", + ), ], ) def test_collections_to_string( diff --git a/tests/test_utils.py b/tests/test_utils.py index 901c6b52..9782c090 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -17,6 +17,7 @@ "BackReferenceExpressionInfo", "ChildReferenceInfo", "BackReferenceCollectionInfo", + "GlobalCalcInfo", ] from pydough.metadata import GraphMetadata @@ -27,6 +28,7 @@ PyDoughExpressionAST, Calc, CalcSubCollection, + GlobalCalc, ) from typing import Dict, Set, Callable, Any, List, Tuple from pydough.types import PyDoughType @@ -156,7 +158,7 @@ def __init__(self, function_name: str, args_info: List[AstNodeTestInfo]): self.args_info: List[AstNodeTestInfo] = args_info def to_string(self) -> str: - arg_strings: List[str] = [arg.to_string for arg in self.args_info] + arg_strings: List[str] = [arg.to_string() for arg in self.args_info] return f"Call[{self.function_name} on ({', '.join(arg_strings)})]" def build( @@ -245,7 +247,7 @@ def build( children_contexts: List[PyDoughCollectionAST] | None = None, ) -> PyDoughAST: assert ( - context is not None + children_contexts is not None ), "Cannot call .build() on ChildReferenceInfo without providing a list of child contexts" return builder.build_child_reference( children_contexts, self.child_idx, self.name @@ -433,6 +435,30 @@ def local_build( return raw_calc.with_terms(args) +class GlobalCalcInfo(CalcInfo): + """ + CollectionTestInfo implementation class to build a global CALC. + """ + + def to_string(self) -> str: + return f"Global{super().to_string()}" + + def local_build( + self, + builder: AstNodeBuilder, + context: PyDoughCollectionAST | None = None, + children_contexts: List[PyDoughCollectionAST] | None = None, + ) -> PyDoughCollectionAST: + children: List[PyDoughCollectionAST] = [ + child.build(builder, context) for child in self.children_info + ] + raw_calc: GlobalCalc = builder.build_global_calc(builder.graph, children) + args: List[Tuple[str, PyDoughExpressionAST]] = [ + (name, info.build(builder, context, children)) for name, info in self.args + ] + return raw_calc.with_terms(args) + + class BackReferenceCollectionInfo(CollectionTestInfo): """ CollectionTestInfo implementation class to build a reference to an