Skip to content

Commit 8d4f77b

Browse files
committed
restore test
1 parent f5ec176 commit 8d4f77b

File tree

3 files changed

+85
-60
lines changed

3 files changed

+85
-60
lines changed

narwhals/__init__.py

+8
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
from narwhals.dataframe import DataFrame
77
from narwhals.expression import col
88
from narwhals.expression import len
9+
from narwhals.expression import max
10+
from narwhals.expression import mean
11+
from narwhals.expression import min
12+
from narwhals.expression import sum
913
from narwhals.translate import get_namespace
1014
from narwhals.translate import to_native
1115
from narwhals.translate import translate_any
@@ -27,5 +31,9 @@
2731
"to_native",
2832
"col",
2933
"len",
34+
"min",
35+
"max",
36+
"mean",
37+
"sum",
3038
"DataFrame",
3139
]

narwhals/expression.py

+16
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,19 @@ def col(col_name: str):
138138

139139
def len():
140140
return NarwhalsExpr(lambda plx: plx.len())
141+
142+
143+
def sum(col_name):
144+
return NarwhalsExpr(lambda plx: plx.sum(col_name))
145+
146+
147+
def mean(col_name):
148+
return NarwhalsExpr(lambda plx: plx.mean(col_name))
149+
150+
151+
def min(col_name):
152+
return NarwhalsExpr(lambda plx: plx.min(col_name))
153+
154+
155+
def max(col_name):
156+
return NarwhalsExpr(lambda plx: plx.max(col_name))

tests/tpch_q1_test.py

+61-60
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22

3+
import os
34
from datetime import datetime
45
from typing import Any
6+
from unittest import mock
57

68
import polars
79
import pytest
@@ -71,63 +73,62 @@ def test_q1(df_raw: Any) -> None:
7173
compare_dicts(result, expected)
7274

7375

74-
# @pytest.mark.parametrize(
75-
# "df_raw",
76-
# [
77-
# (polars.read_parquet("tests/data/lineitem.parquet").to_pandas()),
78-
# ],
79-
# )
80-
# @mock.patch.dict(os.environ, {"NARWHALS_FORCE_GENERIC": "1"})
81-
# def test_q1_w_pandas_agg_generic_path(df_raw: Any) -> None:
82-
# var_1 = datetime(1998, 9, 2)
83-
# df = translate_frame(df_raw, is_lazy=True)
84-
# pl = get_namespace(df)
85-
# query_result = (
86-
# df.filter(pl.col("l_shipdate") <= var_1)
87-
# .group_by(["l_returnflag", "l_linestatus"])
88-
# .agg(
89-
# [
90-
# pl.sum("l_quantity").alias("sum_qty"),
91-
# pl.sum("l_extendedprice").alias("sum_base_price"),
92-
# (pl.col("l_extendedprice") * (1 - pl.col("l_discount")))
93-
# .sum()
94-
# .alias("sum_disc_price"),
95-
# (
96-
# pl.col("l_extendedprice")
97-
# * (1.0 - pl.col("l_discount"))
98-
# * (1.0 + pl.col("l_tax"))
99-
# )
100-
# .sum()
101-
# .alias("sum_charge"),
102-
# pl.mean("l_quantity").alias("avg_qty"),
103-
# pl.mean("l_extendedprice").alias("avg_price"),
104-
# pl.mean("l_discount").alias("avg_disc"),
105-
# pl.len().alias("count_order"),
106-
# ],
107-
# )
108-
# .sort(["l_returnflag", "l_linestatus"])
109-
# )
110-
# result = query_result.collect().to_dict(as_series=False)
111-
# expected = {
112-
# "l_returnflag": ["A", "N", "N", "R"],
113-
# "l_linestatus": ["F", "F", "O", "F"],
114-
# "sum_qty": [2109.0, 29.0, 3682.0, 1876.0],
115-
# "sum_base_price": [3114026.44, 39824.83, 5517101.99, 2947892.16],
116-
# "sum_disc_price": [2954950.8082, 39028.3334, 5205468.4852, 2816542.4816999994],
117-
# "sum_charge": [
118-
# 3092840.4194289995,
119-
# 39808.900068,
120-
# 5406966.873419,
121-
# 2935797.8313019997,
122-
# ],
123-
# "avg_qty": [27.75, 29.0, 25.047619047619047, 26.422535211267604],
124-
# "avg_price": [
125-
# 40974.032105263155,
126-
# 39824.83,
127-
# 37531.30605442177,
128-
# 41519.607887323946,
129-
# ],
130-
# "avg_disc": [0.05039473684210526, 0.02, 0.05537414965986395, 0.04507042253521127],
131-
# "count_order": [76, 1, 147, 71],
132-
# }
133-
# compare_dicts(result, expected)
76+
@pytest.mark.parametrize(
77+
"df_raw",
78+
[
79+
(polars.read_parquet("tests/data/lineitem.parquet").to_pandas()),
80+
],
81+
)
82+
@mock.patch.dict(os.environ, {"NARWHALS_FORCE_GENERIC": "1"})
83+
def test_q1_w_pandas_agg_generic_path(df_raw: Any) -> None:
84+
var_1 = datetime(1998, 9, 2)
85+
df = nw.DataFrame(df_raw, is_lazy=True)
86+
query_result = (
87+
df.filter(nw.col("l_shipdate") <= var_1)
88+
.group_by(["l_returnflag", "l_linestatus"])
89+
.agg(
90+
[
91+
nw.sum("l_quantity").alias("sum_qty"),
92+
nw.sum("l_extendedprice").alias("sum_base_price"),
93+
(nw.col("l_extendedprice") * (1 - nw.col("l_discount")))
94+
.sum()
95+
.alias("sum_disc_price"),
96+
(
97+
nw.col("l_extendedprice")
98+
* (1.0 - nw.col("l_discount"))
99+
* (1.0 + nw.col("l_tax"))
100+
)
101+
.sum()
102+
.alias("sum_charge"),
103+
nw.mean("l_quantity").alias("avg_qty"),
104+
nw.mean("l_extendedprice").alias("avg_price"),
105+
nw.mean("l_discount").alias("avg_disc"),
106+
nw.len().alias("count_order"),
107+
],
108+
)
109+
.sort(["l_returnflag", "l_linestatus"])
110+
)
111+
result = query_result.collect().to_dict(as_series=False)
112+
expected = {
113+
"l_returnflag": ["A", "N", "N", "R"],
114+
"l_linestatus": ["F", "F", "O", "F"],
115+
"sum_qty": [2109.0, 29.0, 3682.0, 1876.0],
116+
"sum_base_price": [3114026.44, 39824.83, 5517101.99, 2947892.16],
117+
"sum_disc_price": [2954950.8082, 39028.3334, 5205468.4852, 2816542.4816999994],
118+
"sum_charge": [
119+
3092840.4194289995,
120+
39808.900068,
121+
5406966.873419,
122+
2935797.8313019997,
123+
],
124+
"avg_qty": [27.75, 29.0, 25.047619047619047, 26.422535211267604],
125+
"avg_price": [
126+
40974.032105263155,
127+
39824.83,
128+
37531.30605442177,
129+
41519.607887323946,
130+
],
131+
"avg_disc": [0.05039473684210526, 0.02, 0.05537414965986395, 0.04507042253521127],
132+
"count_order": [76, 1, 147, 71],
133+
}
134+
compare_dicts(result, expected)

0 commit comments

Comments
 (0)