Skip to content

Commit 2bf5c2b

Browse files
committed
test1 working for polars!
1 parent dd1f70e commit 2bf5c2b

File tree

5 files changed

+62
-17
lines changed

5 files changed

+62
-17
lines changed

narwhals/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from narwhals.containers import is_series
66
from narwhals.dataframe import NarwhalsFrame
77
from narwhals.expression import col
8+
from narwhals.expression import len
89
from narwhals.translate import get_namespace
910
from narwhals.translate import to_native
1011
from narwhals.translate import translate_any
@@ -25,5 +26,6 @@
2526
"get_namespace",
2627
"to_native",
2728
"col",
29+
"len",
2830
"NarwhalsFrame",
2931
]

narwhals/dataframe.py

+34-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,40 @@ def group_by(self, *keys: str | Iterable[str]) -> GroupBy:
101101

102102
return NarwhalsGroupBy(
103103
self,
104-
keys,
104+
*keys,
105105
is_eager=self._is_eager,
106106
is_lazy=self._is_lazy,
107107
)
108+
109+
def sort(
    self,
    by: str | Iterable[str],
    *more_by: str,
    descending: bool | Sequence[bool] = False,
) -> Self:
    """Return a new frame sorted by the given column(s).

    For the ``"polars"`` implementation the arguments are forwarded
    unchanged to the wrapped frame's own ``sort`` method and the result is
    re-wrapped with ``self._from_dataframe``.

    Args:
        by: Column name, or iterable of column names, to sort by.
        *more_by: Additional column names, passed positionally after ``by``.
        descending: Forwarded as the ``descending`` keyword of the native
            ``sort`` call.

    Returns:
        The value produced by ``self._from_dataframe`` for the sorted frame.

    Raises:
        NotImplementedError: If ``self._implementation`` is not ``"polars"``.
    """
    if self._implementation == "polars":
        return self._from_dataframe(
            self._dataframe.sort(by, *more_by, descending=descending)
        )
    # Previously this fell through and implicitly returned None, which only
    # surfaced later as an AttributeError on the next chained call.
    raise NotImplementedError(
        f"DataFrame.sort is not implemented for {self._implementation!r}"
    )
119+
120+
def collect(self) -> Self:
    """Materialise a lazy frame into an eager one.

    For the ``"polars"`` implementation this calls ``collect()`` on the
    wrapped ``pl.LazyFrame`` and wraps the result in a new instance of this
    class with ``is_eager=True`` and ``is_lazy=False``.

    Returns:
        A new instance of ``self.__class__`` holding the collected frame.

    Raises:
        RuntimeError: If the frame was not created with ``is_lazy=True``.
        NotImplementedError: If ``self._implementation`` is not ``"polars"``.
    """
    if not self._is_lazy:
        raise RuntimeError(
            "DataFrame.collect can only be called if frame was instantiated with `is_lazy=True`"
        )
    if self._implementation == "polars":
        # Imported lazily so polars is only needed on the polars code path.
        import polars as pl

        # Narrows the type for static checkers.
        assert isinstance(self._dataframe, pl.LazyFrame)
        return self.__class__(self._dataframe.collect(), is_eager=True, is_lazy=False)
    # Previously this fell through and implicitly returned None.
    raise NotImplementedError(
        f"DataFrame.collect is not implemented for {self._implementation!r}"
    )
130+
131+
def to_dict(self, *, as_series: bool = True) -> dict[str, Any]:
    """Convert an eager frame to a dictionary via the native ``to_dict``.

    Args:
        as_series: Forwarded as the ``as_series`` keyword of the wrapped
            frame's ``to_dict`` method.

    Returns:
        Whatever the wrapped ``pl.DataFrame.to_dict`` call returns.

    Raises:
        RuntimeError: If the frame was not created with ``is_eager=True``.
        NotImplementedError: If ``self._implementation`` is not ``"polars"``.
    """
    if not self._is_eager:
        raise RuntimeError(
            "DataFrame.to_dict can only be called if frame was instantiated with `is_eager=True`"
        )
    if self._implementation == "polars":
        # Imported lazily so polars is only needed on the polars code path.
        import polars as pl

        # Narrows the type for static checkers.
        assert isinstance(self._dataframe, pl.DataFrame)
        return self._dataframe.to_dict(as_series=as_series)
    # Previously this fell through and implicitly returned None.
    raise NotImplementedError(
        f"DataFrame.to_dict is not implemented for {self._implementation!r}"
    )

narwhals/expression.py

+10
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from typing import Any
44

5+
from narwhals.translate import get_polars
6+
57

68
def extract_native(expr, other: Any) -> Any:
79
if isinstance(other, NarwhalsExpr):
@@ -130,3 +132,11 @@ def sample(self, n: int, fraction: float, *, with_replacement: bool) -> Expr:
130132

131133
def col(col_name: str):
    """Build a ``NarwhalsExpr`` that refers to the column ``col_name``.

    The wrapped callable receives a column factory and invokes it with
    ``col_name``.
    """

    def _select_column(expr):
        return expr(col_name)

    return NarwhalsExpr(_select_column)
135+
136+
137+
def len():
    """Build a ``NarwhalsExpr`` that resolves to polars' ``pl.len()``.

    The wrapped callable returns ``pl.len()`` when polars is available and
    the object it receives is a subclass of ``pl.col``; in every other case
    it returns ``None``.

    Note: the name intentionally shadows the builtin ``len`` inside this
    module.
    """

    def _native_len(expr):
        polars = get_polars()
        if polars is None:
            return None
        if not issubclass(expr, polars.col):
            return None
        return polars.len()

    return NarwhalsExpr(_native_len)

tests/tpch_q1_test.py

+15-15
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,21 @@ def test_q1(df_raw: Any) -> None:
2626
.agg(
2727
[
2828
nw.col("l_quantity").sum().alias("sum_qty"),
29-
# nw.col("l_extendedprice").sum().alias("sum_base_price"),
30-
# (nw.col("l_extendedprice") * (1 - nw.col("l_discount")))
31-
# .sum()
32-
# .alias("sum_disc_price"),
33-
# (
34-
# nw.col("l_extendedprice")
35-
# * (1.0 - nw.col("l_discount"))
36-
# * (1.0 + nw.col("l_tax"))
37-
# )
38-
# .sum()
39-
# .alias("sum_charge"),
40-
# nw.col("l_quantity").mean().alias("avg_qty"),
41-
# nw.col("l_extendedprice").mean().alias("avg_price"),
42-
# nw.col("l_discount").mean().alias("avg_disc"),
43-
# nw.len().alias("count_order"),
29+
nw.col("l_extendedprice").sum().alias("sum_base_price"),
30+
(nw.col("l_extendedprice") * (1 - nw.col("l_discount")))
31+
.sum()
32+
.alias("sum_disc_price"),
33+
(
34+
nw.col("l_extendedprice")
35+
* (1.0 - nw.col("l_discount"))
36+
* (1.0 + nw.col("l_tax"))
37+
)
38+
.sum()
39+
.alias("sum_charge"),
40+
nw.col("l_quantity").mean().alias("avg_qty"),
41+
nw.col("l_extendedprice").mean().alias("avg_price"),
42+
nw.col("l_discount").mean().alias("avg_disc"),
43+
nw.len().alias("count_order"),
4444
],
4545
)
4646
.sort(["l_returnflag", "l_linestatus"])

tests/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55

66
def compare_dicts(result: dict[str, Any], expected: dict[str, Any]) -> None:
7-
for key in result:
7+
for key in expected:
88
for lhs, rhs in zip(result[key], expected[key]):
99
if isinstance(lhs, float):
1010
assert abs(lhs - rhs) < 1e-6

0 commit comments

Comments
 (0)