Skip to content

Commit f9b2dce

Browse files
authored
Add pre-commit CI to PyDough (#252)
Resolves #250. See issue for more details.
1 parent d491718 commit f9b2dce

9 files changed

+102
-51
lines changed

.pre-commit-config.yaml

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,6 @@
1+
ci:
2+
autoupdate_schedule: monthly
3+
14
repos:
25
- repo: https://github.com/astral-sh/ruff-pre-commit
36
rev: v0.6.7

pydough/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,11 +8,11 @@
88
"explain",
99
"explain_structure",
1010
"explain_term",
11+
"get_logger",
1112
"init_pydough_context",
1213
"parse_json_metadata_from_file",
1314
"to_df",
1415
"to_sql",
15-
"get_logger"
1616
]
1717

1818
from .configs import PyDoughSession

pydough/logger/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,8 +2,6 @@
22
Module of PyDough dealing with logging across the library
33
"""
44

5-
__all__ = [
6-
"get_logger"
7-
]
5+
__all__ = ["get_logger"]
86

97
from .logger import get_logger

pydough/sqlglot/execute_relational.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@ def convert_relation_to_sql(
2626
relational: RelationalRoot,
2727
dialect: SQLGlotDialect,
2828
bindings: SqlGlotTransformBindings,
29-
pretty_print_sql: bool = False
29+
pretty_print_sql: bool = False,
3030
) -> str:
3131
"""
3232
Convert the given relational tree to a SQL string using the given dialect.
@@ -44,7 +44,7 @@ def convert_relation_to_sql(
4444
glot_expr: SQLGlotExpression = SQLGlotRelationalVisitor(
4545
dialect, bindings
4646
).relational_to_sqlglot(relational)
47-
return glot_expr.sql(dialect,pretty=pretty_print_sql)
47+
return glot_expr.sql(dialect, pretty=pretty_print_sql)
4848

4949

5050
def convert_dialect_to_sqlglot(dialect: DatabaseDialect) -> SQLGlotDialect:
@@ -91,7 +91,9 @@ def execute_df(
9191
pretty_print_sql: bool = False
9292
if display_sql:
9393
pretty_print_sql = True
94-
sql: str = convert_relation_to_sql(relational, sqlglot_dialect, bindings,pretty_print_sql)
94+
sql: str = convert_relation_to_sql(
95+
relational, sqlglot_dialect, bindings, pretty_print_sql
96+
)
9597
if display_sql:
9698
pyd_logger = get_logger(__name__)
9799
pyd_logger.info(f"SQL query:\n {sql}")

pydough/sqlglot/sqlglot_relational_visitor.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -171,7 +171,7 @@ def _merge_selects(
171171
new_columns: list[SQLGlotExpression],
172172
orig_select: Select,
173173
deps: set[Identifier],
174-
sort: bool = True
174+
sort: bool = True,
175175
) -> Select:
176176
"""
177177
Attempt to merge a new select statement with an existing one.
@@ -195,7 +195,7 @@ def _merge_selects(
195195
new_columns, orig_select.expressions, deps
196196
)
197197
if sort:
198-
old_exprs = sorted(old_exprs,key=repr)
198+
old_exprs = sorted(old_exprs, key=repr)
199199
orig_select.set("expressions", old_exprs)
200200
if new_exprs is None:
201201
return orig_select
@@ -289,7 +289,7 @@ def _build_subquery(
289289
Select: A select statement representing the subquery.
290290
"""
291291
if sort:
292-
column_exprs = sorted(column_exprs,key=repr)
292+
column_exprs = sorted(column_exprs, key=repr)
293293
return (
294294
Select().select(*column_exprs).from_(Subquery(this=input_expr, alias=alias))
295295
)

pydough/sqlglot/transform_bindings.py

Lines changed: 20 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -505,6 +505,7 @@ def convert_ndistinct(
505505
this=sqlglot_expressions.Distinct(expressions=[column])
506506
)
507507

508+
508509
def create_convert_time_unit_function(unit: str):
509510
"""
510511
Creates a function that extracts a specific time unit
@@ -517,6 +518,7 @@ def create_convert_time_unit_function(unit: str):
517518
A function that can convert operands into a SQLGlot expression matching
518519
the functionality of `EXTRACT(unit FROM expression)`.
519520
"""
521+
520522
def convert_time_unit(
521523
raw_args: Sequence[RelationalExpression] | None,
522524
sql_glot_args: Sequence[SQLGlotExpression],
@@ -537,35 +539,35 @@ def convert_time_unit(
537539
from the first operand.
538540
"""
539541
return sqlglot_expressions.Extract(
540-
this=sqlglot_expressions.Var(this=unit),
541-
expression=sql_glot_args[0]
542+
this=sqlglot_expressions.Var(this=unit), expression=sql_glot_args[0]
542543
)
543544

544545
return convert_time_unit
545546

547+
546548
def convert_sqrt(
547-
raw_args: Sequence[RelationalExpression] | None,
548-
sql_glot_args: Sequence[SQLGlotExpression],
549-
) -> SQLGlotExpression:
549+
raw_args: Sequence[RelationalExpression] | None,
550+
sql_glot_args: Sequence[SQLGlotExpression],
551+
) -> SQLGlotExpression:
550552
"""
551-
Support for getting the square root of the operand.
553+
Support for getting the square root of the operand.
552554
553-
Args:
554-
`raw_args`: The operands passed to the function before they were converted to
555-
SQLGlot expressions. (Not actively used in this implementation.)
556-
`sql_glot_args`: The operands passed to the function after they were converted
557-
to SQLGlot expressions.
555+
Args:
556+
`raw_args`: The operands passed to the function before they were converted to
557+
SQLGlot expressions. (Not actively used in this implementation.)
558+
`sql_glot_args`: The operands passed to the function after they were converted
559+
to SQLGlot expressions.
558560
559-
Returns:
560-
The SQLGlot expression matching the functionality of
561-
`POWER(x,0.5)`,i.e the square root.
561+
Returns:
562+
The SQLGlot expression matching the functionality of
563+
`POWER(x,0.5)`,i.e the square root.
562564
"""
563565

564566
return sqlglot_expressions.Pow(
565-
this=sql_glot_args[0],
566-
expression=sqlglot_expressions.Literal.number(0.5)
567+
this=sql_glot_args[0], expression=sqlglot_expressions.Literal.number(0.5)
567568
)
568569

570+
569571
class SqlGlotTransformBindings:
570572
"""
571573
Binding infrastructure used to associate PyDough operators with a procedure
@@ -754,8 +756,8 @@ def add_builtin_bindings(self) -> None:
754756
self.bind_binop(pydop.NEQ, sqlglot_expressions.NEQ)
755757
self.bind_binop(pydop.BAN, sqlglot_expressions.And)
756758
self.bind_binop(pydop.BOR, sqlglot_expressions.Or)
757-
self.bind_binop(pydop.POW,sqlglot_expressions.Pow)
758-
self.bind_binop(pydop.POWER,sqlglot_expressions.Pow)
759+
self.bind_binop(pydop.POW, sqlglot_expressions.Pow)
760+
self.bind_binop(pydop.POWER, sqlglot_expressions.Pow)
759761
self.bindings[pydop.SQRT] = convert_sqrt
760762

761763
# Unary operators

tests/simple_pydough_functions.py

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -278,14 +278,18 @@ def hour_minute_day():
278278
transaction timestamps for specific ticker symbols ("AAPL","GOOGL","NFLX"),
279279
ordered by transaction ID in ascending order.
280280
"""
281-
return Transactions(
282-
transaction_id, HOUR(date_time), MINUTE(date_time), SECOND(date_time)
283-
).WHERE(
284-
ISIN(ticker.symbol,("AAPL", "GOOGL", "NFLX"))
285-
).ORDER_BY(
286-
transaction_id.ASC()
281+
return (
282+
Transactions(
283+
transaction_id, HOUR(date_time), MINUTE(date_time), SECOND(date_time)
284+
)
285+
.WHERE(ISIN(ticker.symbol, ("AAPL", "GOOGL", "NFLX")))
286+
.ORDER_BY(transaction_id.ASC())
287287
)
288288

289+
289290
def exponentiation():
290-
return DailyPrices(low_square = low ** 2, low_sqrt = SQRT(low),
291-
low_cbrt = POWER(low, 1/3), ).TOP_K(10, by=low_square.ASC())
291+
return DailyPrices(
292+
low_square=low**2,
293+
low_sqrt=SQRT(low),
294+
low_cbrt=POWER(low, 1 / 3),
295+
).TOP_K(10, by=low_square.ASC())

tests/test_pipeline.py

Lines changed: 56 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -754,15 +754,26 @@ def test_pipeline_e2e_errors(
754754
lambda: pd.DataFrame(
755755
{
756756
"transaction_id": [
757-
"TX001", "TX005", "TX011", "TX015", "TX021", "TX025",
758-
"TX031", "TX033", "TX035", "TX044", "TX045", "TX049",
759-
"TX051", "TX055"
757+
"TX001",
758+
"TX005",
759+
"TX011",
760+
"TX015",
761+
"TX021",
762+
"TX025",
763+
"TX031",
764+
"TX033",
765+
"TX035",
766+
"TX044",
767+
"TX045",
768+
"TX049",
769+
"TX051",
770+
"TX055",
760771
],
761772
"_expr0": [9, 12, 9, 12, 9, 12, 0, 0, 0, 10, 10, 16, 0, 0],
762773
"_expr1": [30, 30, 30, 30, 30, 30, 0, 0, 0, 0, 30, 0, 0, 0],
763774
"_expr2": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
764775
}
765-
)
776+
),
766777
),
767778
id="broker_basic1",
768779
),
@@ -772,22 +783,52 @@ def test_pipeline_e2e_errors(
772783
"Broker",
773784
lambda: pd.DataFrame(
774785
{
775-
"low_square" : [6642.2500, 6740.4100, 6839.2900, 6938.8900, 7039.2100,
776-
7140.2500, 7242.0100, 16576.5625, 16900.0000, 17292.2500],
777-
"low_sqrt" : [9.027735, 9.060905, 9.093954, 9.126883, 9.159694,
778-
9.192388, 9.224966, 11.346806, 11.401754, 11.467345],
779-
"low_cbrt" : [4.335633, 4.346247, 4.356809, 4.367320, 4.377781,
780-
4.388191, 4.398553, 5.049508, 5.065797, 5.085206]
786+
"low_square": [
787+
6642.2500,
788+
6740.4100,
789+
6839.2900,
790+
6938.8900,
791+
7039.2100,
792+
7140.2500,
793+
7242.0100,
794+
16576.5625,
795+
16900.0000,
796+
17292.2500,
797+
],
798+
"low_sqrt": [
799+
9.027735,
800+
9.060905,
801+
9.093954,
802+
9.126883,
803+
9.159694,
804+
9.192388,
805+
9.224966,
806+
11.346806,
807+
11.401754,
808+
11.467345,
809+
],
810+
"low_cbrt": [
811+
4.335633,
812+
4.346247,
813+
4.356809,
814+
4.367320,
815+
4.377781,
816+
4.388191,
817+
4.398553,
818+
5.049508,
819+
5.065797,
820+
5.085206,
821+
],
781822
}
782-
)
823+
),
783824
),
784825
id="exponentiation",
785826
),
786-
],
827+
],
787828
)
788829
def custom_defog_test_data(
789830
request,
790-
) -> tuple[Callable[[], UnqualifiedNode],str,pd.DataFrame]:
831+
) -> tuple[Callable[[], UnqualifiedNode], str, pd.DataFrame]:
791832
"""
792833
Test data for test_defog_e2e. Returns a tuple of the following
793834
arguments:
@@ -801,7 +842,7 @@ def custom_defog_test_data(
801842

802843
@pytest.mark.execute
803844
def test_defog_e2e_with_custom_data(
804-
custom_defog_test_data: tuple[Callable[[], UnqualifiedNode],str,pd.DataFrame],
845+
custom_defog_test_data: tuple[Callable[[], UnqualifiedNode], str, pd.DataFrame],
805846
defog_graphs: graph_fetcher,
806847
sqlite_defog_connection: DatabaseContext,
807848
):
@@ -810,7 +851,7 @@ def test_defog_e2e_with_custom_data(
810851
comparing against the result of running the reference SQL query text on the
811852
same database connector.
812853
"""
813-
unqualified_impl, graph_name ,answer_impl = custom_defog_test_data
854+
unqualified_impl, graph_name, answer_impl = custom_defog_test_data
814855
graph: GraphMetadata = defog_graphs(graph_name)
815856
root: UnqualifiedNode = init_pydough_context(graph)(unqualified_impl)()
816857
result: pd.DataFrame = to_df(root, metadata=graph, database=sqlite_defog_connection)

tests/test_pydough_to_sql.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -72,6 +72,7 @@ def test_pydough_to_sql(
7272
expected_sql = expected_sql.strip()
7373
assert actual_sql == expected_sql
7474

75+
7576
@pytest.mark.parametrize(
7677
"pydough_code,expected_sql,graph_name",
7778
[

0 commit comments

Comments
 (0)