Skip to content

Commit b44d749

Browse files
authored
Merge pull request #9 from raisadz/main
Add tests for common dataframe operations
2 parents 9ce6f21 + 23ec222 commit b44d749

File tree

5 files changed

+134
-8
lines changed

5 files changed

+134
-8
lines changed

narwhals/expression.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Iterable
77

88
from narwhals.dtypes import translate_dtype
9+
from narwhals.utils import flatten
910

1011
if TYPE_CHECKING:
1112
from narwhals.typing import IntoExpr
@@ -184,7 +185,9 @@ def max(*columns: str) -> Expr:
184185

185186

186187
def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
187-
return Expr(lambda plx: plx.sum_horizontal(*exprs))
188+
return Expr(
189+
lambda plx: plx.sum_horizontal([extract_native(plx, v) for v in flatten(exprs)])
190+
)
188191

189192

190193
__all__ = [

narwhals/pandas_like/dataframe.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,9 @@ def filter(
9999
self,
100100
*predicates: IntoPandasExpr | Iterable[IntoPandasExpr],
101101
) -> Self:
102-
from narwhals.pandas_like.namespace import Namespace
102+
from narwhals.pandas_like.namespace import PandasNamespace
103103

104-
plx = Namespace(self._implementation)
104+
plx = PandasNamespace(self._implementation)
105105
expr = plx.all_horizontal(*predicates)
106106
# Safety: all_horizontal's expression only returns a single column.
107107
mask = expr._call(self)[0]

narwhals/pandas_like/namespace.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from narwhals.pandas_like.typing import IntoPandasExpr
2020

2121

22-
class Namespace:
22+
class PandasNamespace:
2323
Int64 = dtypes.Int64
2424
Int32 = dtypes.Int32
2525
Int16 = dtypes.Int16

narwhals/pandas_like/utils.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,13 @@ def parse_into_exprs(
9797

9898
def parse_into_expr(implementation: str, into_expr: IntoPandasExpr) -> PandasExpr:
9999
from narwhals.expression import Expr
100-
from narwhals.pandas_like.namespace import Namespace
100+
from narwhals.pandas_like.expr import PandasExpr
101+
from narwhals.pandas_like.namespace import PandasNamespace
101102

102-
plx = Namespace(implementation=implementation)
103+
plx = PandasNamespace(implementation=implementation)
103104

105+
if isinstance(into_expr, PandasExpr):
106+
return into_expr
104107
if isinstance(into_expr, Expr):
105108
return into_expr._call(plx)
106109
if isinstance(into_expr, str):
@@ -141,10 +144,10 @@ def evaluate_into_exprs(
141144

142145
def register_expression_call(expr: ExprT, attr: str, *args: Any, **kwargs: Any) -> ExprT:
143146
from narwhals.pandas_like.expr import PandasExpr
144-
from narwhals.pandas_like.namespace import Namespace
147+
from narwhals.pandas_like.namespace import PandasNamespace
145148
from narwhals.pandas_like.series import PandasSeries
146149

147-
plx = Namespace(implementation=expr._implementation)
150+
plx = PandasNamespace(implementation=expr._implementation)
148151

149152
def func(df: PandasDataFrame) -> list[PandasSeries]:
150153
out: list[PandasSeries] = []

tests/test_common.py

+120
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
import pandas as pd
6+
import polars as pl
7+
import pytest
8+
9+
import narwhals as nw
10+
from tests.utils import compare_dicts
11+
12+
df_pandas = pd.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]})
13+
df_polars = pl.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]})
14+
15+
16+
@pytest.mark.parametrize(
17+
"df_raw",
18+
[df_pandas, df_polars],
19+
)
20+
def test_sort(df_raw: Any) -> None:
21+
df = nw.DataFrame(df_raw)
22+
result = df.sort("a", "b")
23+
result_native = nw.to_native(result)
24+
expected = {
25+
"a": [1, 2, 3],
26+
"b": [4, 6, 4],
27+
"z": [7.0, 9.0, 8.0],
28+
}
29+
compare_dicts(result_native, expected)
30+
31+
32+
@pytest.mark.parametrize(
33+
"df_raw",
34+
[df_pandas, df_polars],
35+
)
36+
def test_filter(df_raw: Any) -> None:
37+
df = nw.DataFrame(df_raw)
38+
result = df.filter(nw.col("a") > 1)
39+
result_native = nw.to_native(result)
40+
expected = {"a": [3, 2], "b": [4, 6], "z": [8.0, 9.0]}
41+
compare_dicts(result_native, expected)
42+
43+
44+
@pytest.mark.parametrize(
45+
"df_raw",
46+
[df_pandas, df_polars],
47+
)
48+
def test_add(df_raw: Any) -> None:
49+
df = nw.DataFrame(df_raw)
50+
result = df.with_columns(
51+
c=nw.col("a") + nw.col("b"),
52+
d=nw.col("a") - nw.col("a").mean(),
53+
)
54+
result_native = nw.to_native(result)
55+
expected = {
56+
"a": [1, 3, 2],
57+
"b": [4, 4, 6],
58+
"z": [7.0, 8.0, 9.0],
59+
"c": [5, 7, 8],
60+
"d": [-1.0, 1.0, 0.0],
61+
}
62+
compare_dicts(result_native, expected)
63+
64+
65+
@pytest.mark.parametrize(
66+
"df_raw",
67+
[df_pandas, df_polars],
68+
)
69+
def test_double(df_raw: Any) -> None:
70+
df = nw.DataFrame(df_raw)
71+
result = df.with_columns(nw.all() * 2)
72+
result_native = nw.to_native(result)
73+
expected = {"a": [2, 6, 4], "b": [8, 8, 12], "z": [14.0, 16.0, 18.0]}
74+
compare_dicts(result_native, expected)
75+
76+
77+
@pytest.mark.parametrize("df_raw", [df_pandas, df_polars])
78+
def test_sumh(df_raw: Any) -> None:
79+
df = nw.DataFrame(df_raw)
80+
result = df.with_columns(horizonal_sum=nw.sum_horizontal(nw.col("a"), nw.col("b")))
81+
result_native = nw.to_native(result)
82+
expected = {
83+
"a": [1, 3, 2],
84+
"b": [4, 4, 6],
85+
"z": [7.0, 8.0, 9.0],
86+
"horizonal_sum": [5, 7, 8],
87+
}
88+
compare_dicts(result_native, expected)
89+
90+
91+
@pytest.mark.parametrize("df_raw", [df_pandas, df_polars])
92+
def test_sumh_literal(df_raw: Any) -> None:
93+
df = nw.DataFrame(df_raw)
94+
result = df.with_columns(horizonal_sum=nw.sum_horizontal("a", nw.col("b")))
95+
result_native = nw.to_native(result)
96+
expected = {
97+
"a": [1, 3, 2],
98+
"b": [4, 4, 6],
99+
"z": [7.0, 8.0, 9.0],
100+
"horizonal_sum": [5, 7, 8],
101+
}
102+
compare_dicts(result_native, expected)
103+
104+
105+
@pytest.mark.parametrize("df_raw", [df_pandas, df_polars])
106+
def test_sum_all(df_raw: Any) -> None:
107+
df = nw.DataFrame(df_raw)
108+
result = df.select(nw.all().sum())
109+
result_native = nw.to_native(result)
110+
expected = {"a": [6], "b": [14], "z": [24.0]}
111+
compare_dicts(result_native, expected)
112+
113+
114+
@pytest.mark.parametrize("df_raw", [df_pandas, df_polars])
115+
def test_double_selected(df_raw: Any) -> None:
116+
df = nw.DataFrame(df_raw)
117+
result = df.select(nw.col("a", "b") * 2)
118+
result_native = nw.to_native(result)
119+
expected = {"a": [2, 6, 4], "b": [8, 8, 12]}
120+
compare_dicts(result_native, expected)

0 commit comments

Comments
 (0)