Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BSE-4155] Add support for the Aggregate relational node #39

Merged
merged 89 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from 76 commits
Commits
Show all changes
89 commits
Select commit Hold shift + click to select a range
cce5e4c
Added the new base files
njriasan Nov 6, 2024
19460cc
Started added abstract classes
njriasan Nov 6, 2024
89f8e42
Added the file definitions
njriasan Nov 7, 2024
2786774
Added the sqlglot import
njriasan Nov 7, 2024
4ee70be
Added a basic test [run CI]
njriasan Nov 7, 2024
66107ec
Defined tests
njriasan Nov 8, 2024
256ecef
Defined the tests cases we want to write
njriasan Nov 8, 2024
d267275
Added the project file docstring
njriasan Nov 8, 2024
cb32b89
Added the class defintion
njriasan Nov 8, 2024
b11a50f
Started adding implementation parts
njriasan Nov 8, 2024
856e20e
Added the projection defintion
njriasan Nov 8, 2024
513abe3
Added the limit file
njriasan Nov 8, 2024
5ed09e0
Added a single relational base class
njriasan Nov 8, 2024
24dceb2
Merge branch 'nick/projection' into nick/limit
njriasan Nov 8, 2024
4cfd06e
Added limit support
njriasan Nov 8, 2024
7081a2c
Added the aggregate node
njriasan Nov 8, 2024
d318e64
Cleaned up the class structure
njriasan Nov 8, 2024
0764621
Merge branch 'nick/limit' into nick/aggregate
njriasan Nov 8, 2024
178ffd1
Updated scan
njriasan Nov 8, 2024
2e7d1a5
wrote tests for the basic scan ops
njriasan Nov 8, 2024
9b88810
Fixed the first basic unit test:
njriasan Nov 8, 2024
b844130
Added equality tests
njriasan Nov 8, 2024
24b620e
Added error test
njriasan Nov 8, 2024
63aa271
Added remaining tests [run CI]
njriasan Nov 8, 2024
d32f73e
Merged with prior PR
njriasan Nov 8, 2024
c2aef51
Updated orderings check
njriasan Nov 8, 2024
c832181
added the projection changes
njriasan Nov 8, 2024
135ab5f
added a test for project equality
njriasan Nov 9, 2024
ff3b25f
Added function definitions for the remaining unit tests
njriasan Nov 9, 2024
90cd1e5
Fixed the to_string() test
njriasan Nov 9, 2024
8bd7cfc
added a can merge test
njriasan Nov 9, 2024
0737744
added a merge test
njriasan Nov 9, 2024
95c281e
Finished adding project tests [run CI]
njriasan Nov 9, 2024
2fc3e83
Fixed scan node
njriasan Nov 9, 2024
47c1933
Simplified the equality checks
njriasan Nov 9, 2024
0e3166c
Added type checks
njriasan Nov 9, 2024
becb91a
Added the to_string test
njriasan Nov 9, 2024
7830f13
Added equals test
njriasan Nov 10, 2024
e8ac1e6
Added the can_merge test
njriasan Nov 10, 2024
4081ad8
Added the merge test
njriasan Nov 10, 2024
1c7e2de
Added remaining tests [run CI]
njriasan Nov 10, 2024
eb71c14
Merge branch 'nick/limit' into nick/aggregate
njriasan Nov 10, 2024
288ae56
Simplified the merge logic
njriasan Nov 10, 2024
631606d
Defined test signatures
njriasan Nov 10, 2024
411a623
Added to_string() test
njriasan Nov 10, 2024
106c708
Added test_aggregate_equals except aggregation function
njriasan Nov 10, 2024
4c76249
added can merge tests
njriasan Nov 10, 2024
4719a66
Added remaining aggregate tests [run CI]
njriasan Nov 10, 2024
ac2a6e4
Merged with prior [run CI]
njriasan Nov 11, 2024
44f709f
Merge branch 'nick/relational_scan' into nick/projection
njriasan Nov 11, 2024
33b5332
Merge branch 'nick/projection' into nick/limit
njriasan Nov 11, 2024
4d2a781
Merge branch 'nick/limit' into nick/aggregate
njriasan Nov 11, 2024
ecc1600
applied most of Kian's changes, need to test column expressions
njriasan Nov 13, 2024
25a7791
Removed unnecessary column class
njriasan Nov 13, 2024
ea9255c
Added remaining tests [run CI]
njriasan Nov 13, 2024
bf20591
Added remaining tests [run CI]
njriasan Nov 13, 2024
ab2f5db
added the literal
njriasan Nov 13, 2024
32cc1ea
added the remaining tests [run CI]
njriasan Nov 13, 2024
591b67d
Back-ported changes [run CI]
njriasan Nov 13, 2024
d96e9f6
Merged with prior PR
njriasan Nov 13, 2024
ce7b9fa
Fix a typo [run CI]
njriasan Nov 13, 2024
5376da1
Merge branch 'nick/relational_scan' into nick/projection [run CI]
njriasan Nov 13, 2024
7501f3c
Fixed the existing tests, need to update ordering
njriasan Nov 13, 2024
721c181
Added an assertion error check on the integer literal
njriasan Nov 13, 2024
bd84645
Added column ordering
njriasan Nov 13, 2024
876667f
Added to_string
njriasan Nov 13, 2024
316ea57
Added remaining limit tests
njriasan Nov 13, 2024
7fef483
Updated remaining tests [run CI]
njriasan Nov 13, 2024
9ce94c1
Fixed last comment [run CI]
njriasan Nov 13, 2024
2b37703
Merged with prior PR
njriasan Nov 13, 2024
55b03ed
Added expression tests
njriasan Nov 14, 2024
d3e2c89
Fixed code definitions, need to update tests
njriasan Nov 14, 2024
91bcc7c
added assertion tests
njriasan Nov 14, 2024
593247c
Added to_string tests
njriasan Nov 14, 2024
ab8c0c7
Finished adding tests [run CI]
njriasan Nov 14, 2024
f6d8eb2
Updated the base type [run CI]
njriasan Nov 14, 2024
90b4e7a
Renamed files
njriasan Nov 14, 2024
cc1a2d1
Applied actual refactoring, need to update test info
njriasan Nov 14, 2024
4a01824
Added test changes [run CI]
njriasan Nov 14, 2024
d05aa76
merged with prior PR
njriasan Nov 14, 2024
11b5d27
Added test docstrings
njriasan Nov 14, 2024
0f9329d
applied remaining feedback [run CI]
njriasan Nov 14, 2024
07f4e32
Merged with prior PR, need to clean up
njriasan Nov 14, 2024
d66dcee
Merged with prior PR
njriasan Nov 14, 2024
393057b
merged updates
njriasan Nov 14, 2024
8a78231
Updated testing [run CI]
njriasan Nov 14, 2024
e00ee40
Added aggregate
njriasan Nov 14, 2024
196a812
Propagated changes
njriasan Nov 14, 2024
16954e2
Merged with parent PR [run CI]
njriasan Nov 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions pydough/relational/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@
"""

__all__ = [
"Column",
"Relational",
]
from .abstract import (
Column,
Relational,
)
101 changes: 39 additions & 62 deletions pydough/relational/abstract.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
"""
This module contains the abstract base classes for the relational
This file contains the abstract base classes for the relational
representation. This roughly maps to a Relational Algebra representation
but is not exact because it needs to maintain PyDough traits that define
ordering and other properties of the relational expression.
"""

from abc import ABC, abstractmethod
from collections.abc import MutableMapping, MutableSequence
from typing import Any, NamedTuple
from typing import Any

from sqlglot.expressions import Expression
from sqlglot.expressions import Expression as SQLGlotExpression

from pydough.pydough_ast.expressions import PyDoughExpressionAST
from .relational_expressions import RelationalExpression


class Relational(ABC):
Expand All @@ -20,6 +20,9 @@ class Relational(ABC):
structure of all relational nodes in the PyDough system.
"""

def __init__(self, columns: MutableMapping[str, RelationalExpression]) -> None:
self._columns: MutableMapping[str, RelationalExpression] = columns

@property
@abstractmethod
def inputs(self) -> MutableSequence["Relational"]:
Expand All @@ -32,55 +35,62 @@ def inputs(self) -> MutableSequence["Relational"]:
"""

@property
def traits(self) -> MutableMapping[str, Any]:
def columns(self) -> MutableMapping[str, RelationalExpression]:
"""
Return the traits of the relational expression.
The traits in general may have a variable schema,
but each entry should be strongly defined. Here are
traits that should always be available:
Returns the columns of the relational expression.

- orderings: MutableSequence[PyDoughExpressionAST]
TODO: Associate an ordering in the future to avoid unnecessary SQL with the
final ordering of the root nodes.

Returns:
MutableMapping[str, Any]: The traits of the relational expression.
MutableMapping[str, RelationalExpression]: The columns of the relational expression.
This does not have a defined ordering.
"""
return {"orderings": self.orderings}
return self._columns

@property
@abstractmethod
def orderings(self) -> MutableSequence["PyDoughExpressionAST"]:
def node_equals(self, other: "Relational") -> bool:
"""
Returns the PyDoughExpressionAST that the relational expression is ordered by.
Each PyDoughExpressionAST is a result computed relative to the given set of columns.
Determine if two relational nodes are exactly identical,
excluding column ordering. This should be extended to avoid
duplicating equality logic shared across relational nodes.

Args:
other (Relational): The other relational node to compare against.

Returns:
MutableSequence[PyDoughExpressionAST]: The PyDoughExpressionAST that the relational expression is ordered by,
possibly empty.
bool: Are the two relational nodes equal.
"""

@property
@abstractmethod
def columns(self) -> MutableSequence["Column"]:
def equals(self, other: "Relational") -> bool:
"""
Returns the columns of the relational expression.
Determine if two relational nodes are exactly identical,
including column ordering.

Args:
other (Relational): The other relational node to compare against.

Returns:
MutableSequence[Column]: The columns of the relational expression.
bool: Are the two relational nodes equal.
"""
return self.node_equals(other) and self.columns == other.columns

def __eq__(self, other: Any) -> bool:
return isinstance(other, Relational) and self.equals(other)

@abstractmethod
def to_sqlglot(self) -> "Expression":
"""Translate the given relational expression
def to_sqlglot(self) -> SQLGlotExpression:
"""Translate the given relational node
and its children to a SQLGlot expression.

Returns:
Expression: A SqlGlot expression representing the relational expression.
Expression: A SqlGlot expression representing the relational node.
"""

@abstractmethod
def to_string(self) -> str:
"""
Convert the relational expression to a string.
Convert the relational node to a string.

TODO: Refactor this API to include some form of string
builder so we can draw lines between children properly.
Expand All @@ -90,38 +100,5 @@ def to_string(self) -> str:
with this node at the root.
"""

@abstractmethod
def can_merge(self, other: "Relational") -> bool:
"""
Determine if two relational nodes can be merged together.

Args:
other (Relational): The other relational node to merge with.

Returns:
bool: Can the two relational nodes be merged together.
"""

@abstractmethod
def merge(self, other: "Relational") -> "Relational":
"""
Merge two relational nodes together to produce one output
relational node. This requires can_merge to return True.

Args:
other (Relational): The other relational node to merge with.

Returns:
Relational: A new relational node that is the result of merging
the two input relational nodes together and removing any redundant
components.
"""


class Column(NamedTuple):
"""
An column expression consisting of a name and an expression.
"""

name: str
expr: "PyDoughExpressionAST"
def __repr__(self) -> str:
return self.to_string()
64 changes: 64 additions & 0 deletions pydough/relational/aggregate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""
This file contains the relational implementation for an aggregation. This is our
relational representation for any grouping operation that optionally involves
keys and aggregate functions.
"""

from collections.abc import MutableMapping

from sqlglot.expressions import Expression as SQLGlotExpression

from .abstract import Relational
from .relational_expressions.call_expression import CallExpression
from .relational_expressions.column_reference import ColumnReference
from .single_relational import SingleRelational


class Aggregate(SingleRelational):
"""
The Aggregate node in the relational tree. This node represents an aggregation
based on some keys, which should most commonly be column references, and some
aggregate functions.
"""

def __init__(
self,
input: Relational,
keys: MutableMapping[str, ColumnReference],
aggregations: MutableMapping[str, CallExpression],
) -> None:
total_cols = {**keys, **aggregations}
assert len(total_cols) == len(keys) + len(
aggregations
), "Keys and aggregations must have unique names"
super().__init__(input, total_cols)
self._keys: MutableMapping[str, ColumnReference] = keys
self._aggregations: MutableMapping[str, CallExpression] = aggregations
assert all(
agg.is_aggregation for agg in aggregations.values()
), "All functions used in aggregations must be aggregation functions"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can avoid this if we have a special version of CallExpression just for agg calls (its inputs are only allowed to be literals or column references, nothing else) -> use type annotations to verify that the aggregations argument maps to this special variant.

They can then check the operator passed in to make sure that its .is_aggregation field is true.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's consider this as a followup. I think the extra class isn't worth the safety.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also I'm not supporting literals yet. I still need to figure that out.


@property
def keys(self) -> MutableMapping[str, ColumnReference]:
return self._keys

@property
def aggregations(self) -> MutableMapping[str, CallExpression]:
return self._aggregations

def to_sqlglot(self) -> SQLGlotExpression:
raise NotImplementedError(
"Conversion to SQLGlot Expressions is not yet implemented."
)

def node_equals(self, other: Relational) -> bool:
return (
isinstance(other, Aggregate)
and self.keys == other.keys
and self.aggregations == other.aggregations
and super().node_equals(other)
)

def to_string(self) -> str:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For context because I didn't mention this in prior PRs, I've decided not to represent a tree for simplicity. I would rather eventually add a "TreeWriter" class that could unify nodes more easily.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case, I suggest we have Relational object have a standalone_string (like AST nodes now do), then in the base class we can add a to_string method that recursively calls to_string on all children, indents them by 1, and prepends with self.standalone_string.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to revisit this in a followup. It seems like a more reasonable pattern to have a separate writer object (equivalent to Calcite) so we can manage things like drawing the lines more elegantly.

# TODO: Should we visit the input?
return f"AGGREGATE(keys={self.keys}, aggregations={self.aggregations})"
67 changes: 67 additions & 0 deletions pydough/relational/limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""
This file contains the relational implementation for a "limit" operation.
This is the relational representation of top-n selection and typically depends
on explicit ordering of the input relation.
"""

from collections.abc import MutableMapping, MutableSequence

from sqlglot.expressions import Expression as SQLGlotExpression

from pydough.types.integer_types import IntegerType

from .abstract import Relational
from .relational_expressions import ColumnOrdering, RelationalExpression
from .single_relational import SingleRelational


class Limit(SingleRelational):
"""
The Limit node in the relational tree. This node represents any TOP-N
operations in the relational algebra. This operation is dependent on the
orderings of the input relation.
"""

def __init__(
self,
input: Relational,
limit: RelationalExpression,
columns: MutableMapping[str, RelationalExpression],
orderings: MutableSequence[ColumnOrdering] | None = None,
) -> None:
super().__init__(input, columns)
# Note: The limit is a relational expression because it should be a constant
# now but in the future could be a more complex expression that may require
# multi-step SQL to successfully evaluate.
assert isinstance(
limit.data_type, IntegerType
), "Limit must be an integer type."
self._limit: RelationalExpression = limit
self._orderings: MutableSequence[ColumnOrdering] = (
[] if orderings is None else orderings
)

@property
def limit(self) -> RelationalExpression:
return self._limit

@property
def orderings(self) -> MutableSequence[ColumnOrdering]:
return self._orderings

def to_sqlglot(self) -> SQLGlotExpression:
raise NotImplementedError(
"Conversion to SQLGlot Expressions is not yet implemented."
)

def node_equals(self, other: Relational) -> bool:
return (
isinstance(other, Limit)
and self.limit == other.limit
and self.orderings == other.orderings
and super().node_equals(other)
)

def to_string(self) -> str:
# TODO: Should we visit the input?
return f"LIMIT(limit={self.limit}, columns={self.columns}, orderings={self.orderings})"
42 changes: 42 additions & 0 deletions pydough/relational/project.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""
This file contains the relational implementation for a "project". This is our
relational representation for a "calc" that involves any compute steps and can include
adding or removing columns (as well as technically reordering). In general, we seek to
avoid introducing extra nodes just to reorder or prune columns, so ideally their use
should be sparse.
"""

from collections.abc import MutableMapping

from sqlglot.expressions import Expression as SQLGlotExpression

from .abstract import Relational
from .relational_expressions import RelationalExpression
from .single_relational import SingleRelational


class Project(SingleRelational):
"""
The Project node in the relational tree. This node represents a "calc" in
relational algebra, which should involve some "compute" functions and may
involve adding, removing, or reordering columns.
"""

def __init__(
self,
input: Relational,
columns: MutableMapping[str, RelationalExpression],
) -> None:
super().__init__(input, columns)

def to_sqlglot(self) -> SQLGlotExpression:
raise NotImplementedError(
"Conversion to SQLGlot Expressions is not yet implemented."
)

def node_equals(self, other: Relational) -> bool:
return isinstance(other, Project) and super().node_equals(other)

def to_string(self) -> str:
# TODO: Should we visit the input?
return f"PROJECT(columns={self.columns})"
10 changes: 10 additions & 0 deletions pydough/relational/relational_expressions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""
TODO: add module-level docstring
"""

__all__ = [
"ColumnOrdering",
"RelationalExpression",
]
from .abstract import RelationalExpression
from .column_ordering import ColumnOrdering
Loading