Skip to content

Commit 6cbdf1e

Browse files
committed
fixup flattening
1 parent d26d8a0 commit 6cbdf1e

File tree

6 files changed

+39
-17
lines changed

6 files changed

+39
-17
lines changed

narwhals/dataframe.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def _from_dataframe(self, df: Any) -> Self:
7272
def _flatten_and_extract(self, *args: Any, **kwargs: Any) -> Any:
7373
from narwhals.utils import flatten
7474

75-
args = [self._extract_native(v) for v in flatten(*args)] # type: ignore[assignment]
75+
args = [self._extract_native(v) for v in flatten(args)] # type: ignore[assignment]
7676
kwargs = {k: self._extract_native(v) for k, v in kwargs.items()}
7777
return args, kwargs
7878

narwhals/expression.py

+10
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,16 @@ def __rsub__(self, other: Any) -> Expr:
7070
lambda plx: self._call(plx).__rsub__(extract_native(plx, other))
7171
)
7272

73+
def __truediv__(self, other: Any) -> Expr:
74+
return self.__class__(
75+
lambda plx: self._call(plx).__truediv__(extract_native(plx, other))
76+
)
77+
78+
def __rtruediv__(self, other: Any) -> Expr:
79+
return self.__class__(
80+
lambda plx: self._call(plx).__rtruediv__(extract_native(plx, other))
81+
)
82+
7383
def __mul__(self, other: Any) -> Expr:
7484
return self.__class__(
7585
lambda plx: self._call(plx).__mul__(extract_native(plx, other))

narwhals/pandas_like/utils.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def parse_into_exprs(
8989
*exprs: IntoPandasExpr | Iterable[IntoPandasExpr],
9090
**named_exprs: IntoPandasExpr,
9191
) -> list[PandasExpr]:
92-
out = [parse_into_expr(implementation, into_expr) for into_expr in flatten(*exprs)] # type: ignore[arg-type]
92+
out = [parse_into_expr(implementation, into_expr) for into_expr in flatten(exprs)]
9393
for name, expr in named_exprs.items():
9494
out.append(parse_into_expr(implementation, expr).alias(name))
9595
return out
@@ -128,8 +128,7 @@ def evaluate_into_exprs(
128128
series: list[PandasSeries] = [
129129
item
130130
for sublist in [
131-
evaluate_into_expr(df, into_expr) # type: ignore[arg-type]
132-
for into_expr in flatten(*exprs)
131+
evaluate_into_expr(df, into_expr) for into_expr in flatten(*exprs)
133132
]
134133
for item in sublist
135134
]

narwhals/series.py

+6
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,9 @@ def to_numpy(self) -> Any:
130130

131131
def to_pandas(self) -> Any:
132132
return self._series.to_pandas()
133+
134+
def mean(self) -> Any:
135+
return self._series.mean()
136+
137+
def std(self) -> Any:
138+
return self._series.std()

narwhals/utils.py

+13-12
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,27 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
from typing import Any
44
from typing import Iterable
55

6-
if TYPE_CHECKING:
7-
from narwhals.typing import T
8-
96

107
def remove_prefix(text: str, prefix: str) -> str:
118
if text.startswith(prefix):
129
return text[len(prefix) :]
1310
return text
1411

1512

16-
def flatten(*args: T | Iterable[T]) -> list[T]:
17-
out: list[T] = []
18-
for arg in args:
19-
if isinstance(arg, (list, tuple)):
20-
out.extend(arg)
21-
else:
22-
out.append(arg) # type: ignore[arg-type]
23-
return out
13+
def flatten(args: Any) -> list[Any]:
14+
if not args:
15+
return []
16+
if len(args) == 1 and _is_iterable(args[0]):
17+
return args[0] # type: ignore[no-any-return]
18+
return args # type: ignore[no-any-return]
19+
20+
21+
def _is_iterable(arg: Any | Iterable[Any]) -> bool:
22+
from narwhals.series import Series
23+
24+
return isinstance(arg, Iterable) and not isinstance(arg, (str, bytes, Series))
2425

2526

2627
def flatten_str(*args: str | Iterable[str]) -> list[str]:

t.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
# type: ignore
33
import polars
44
import pandas as pd
5+
import polars as pl
56

67
import narwhals as nw
78

89
df_raw = pd.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]})
9-
df = nw.DataFrame(df_raw)
10+
df = nw.DataFrame(df_raw, features=["lazy"])
1011
df_raw_2 = pd.DataFrame({"a": [1, 3], "c": [7, 9]})
1112
df2 = nw.DataFrame(df_raw_2, features=["lazy"])
1213

@@ -123,3 +124,8 @@
123124

124125
result = df.select("a", "b").select(nw.all() + nw.col("a"))
125126
print(nw.to_native(result))
127+
128+
df = nw.DataFrame(df_raw, features=["eager"])
129+
print(df["a"].mean())
130+
df = nw.DataFrame(pl.from_pandas(df_raw), features=["eager"])
131+
print(df["a"].mean())

0 commit comments

Comments
 (0)