From b4ff6a0040ddb88b0c49cc615fd911dc9306a5b5 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Tue, 28 Jan 2025 07:47:57 +0000 Subject: [PATCH 01/13] feat: support more scalar operations for duckdb, Increase width for ipython (#1877) --- narwhals/_duckdb/dataframe.py | 27 +++- narwhals/_duckdb/expr.py | 123 +++++++++++-------- narwhals/_duckdb/expr_dt.py | 34 ++--- narwhals/_duckdb/expr_list.py | 2 +- narwhals/_duckdb/expr_name.py | 2 +- narwhals/_duckdb/expr_str.py | 20 +-- narwhals/_duckdb/namespace.py | 37 +++--- narwhals/_duckdb/selectors.py | 13 +- narwhals/_duckdb/utils.py | 47 +++++-- narwhals/utils.py | 4 +- tests/expr_and_series/lit_test.py | 10 -- tests/expr_and_series/sum_horizontal_test.py | 14 +++ tpch/execute.py | 2 +- 13 files changed, 208 insertions(+), 127 deletions(-) diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index c34028e841..c947e088fb 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -9,6 +9,7 @@ import duckdb from duckdb import ColumnExpression +from narwhals._duckdb.utils import ExprKind from narwhals._duckdb.utils import native_to_narwhals_dtype from narwhals._duckdb.utils import parse_exprs_and_named_exprs from narwhals.dependencies import get_duckdb @@ -109,14 +110,24 @@ def select( # TODO(marco): return empty relation with 0 columns? return self._from_native_frame(self._native_frame.limit(0)) - if all(getattr(x, "_returns_scalar", False) for x in exprs) and all( - getattr(x, "_returns_scalar", False) for x in named_exprs.values() + if not any(expr._expr_kind is ExprKind.TRANSFORM for expr in exprs) and not any( + expr._expr_kind is ExprKind.TRANSFORM for expr in named_exprs.values() ): return self._from_native_frame( self._native_frame.aggregate( [val.alias(col) for col, val in new_columns_map.items()] ) ) + if any(expr._expr_kind is ExprKind.AGGREGATION for expr in exprs) or any( + expr._expr_kind is ExprKind.AGGREGATION for expr in named_exprs.values() + ): + msg = ( + "Mixing expressions which aggregate and expressions which don't\n" + "is not yet supported by the DuckDB backend. Once they introduce\n" + "duckdb.WindowExpression to their Python API, we'll be able to\n" + "support this." + ) + raise NotImplementedError(msg) return self._from_native_frame( self._native_frame.select( @@ -140,6 +151,18 @@ def with_columns( **named_exprs: DuckDBExpr, ) -> Self: new_columns_map = parse_exprs_and_named_exprs(self)(*exprs, **named_exprs) + + if any(expr._expr_kind is ExprKind.AGGREGATION for expr in exprs) or any( + expr._expr_kind is ExprKind.AGGREGATION for expr in named_exprs.values() + ): + msg = ( + "Mixing expressions which aggregate and expressions which don't\n" + "is not yet supported by the DuckDB backend. Once they introduce\n" + "duckdb.WindowExpression to their Python API, we'll be able to\n" + "support this." 
+ ) + raise NotImplementedError(msg) + result = [] for col in self._native_frame.columns: if col in new_columns_map: diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index d4be0b45a5..590fd66f2d 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -16,8 +16,9 @@ from narwhals._duckdb.expr_list import DuckDBExprListNamespace from narwhals._duckdb.expr_name import DuckDBExprNameNamespace from narwhals._duckdb.expr_str import DuckDBExprStringNamespace -from narwhals._duckdb.utils import binary_operation_returns_scalar +from narwhals._duckdb.utils import ExprKind from narwhals._duckdb.utils import maybe_evaluate +from narwhals._duckdb.utils import n_ary_operation_expr_kind from narwhals._duckdb.utils import narwhals_to_native_dtype from narwhals.typing import CompliantExpr from narwhals.utils import Implementation @@ -43,9 +44,7 @@ def __init__( function_name: str, evaluate_output_names: Callable[[DuckDBLazyFrame], Sequence[str]], alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, - # Whether the expression is a length-1 Column resulting from - # a reduction, such as `nw.col('a').sum()` - returns_scalar: bool, + expr_kind: ExprKind, backend_version: tuple[int, ...], version: Version, ) -> None: @@ -54,7 +53,7 @@ def __init__( self._function_name = function_name self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names - self._returns_scalar = returns_scalar + self._expr_kind = expr_kind self._backend_version = backend_version self._version = version @@ -87,7 +86,7 @@ def func(_: DuckDBLazyFrame) -> list[duckdb.Expression]: function_name="col", evaluate_output_names=lambda _df: list(column_names), alias_output_names=None, - returns_scalar=False, + expr_kind=ExprKind.TRANSFORM, backend_version=backend_version, version=version, ) @@ -110,7 +109,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: function_name="nth", evaluate_output_names=lambda df: [df.columns[i] for i in column_indices], alias_output_names=None, - returns_scalar=False, + expr_kind=ExprKind.TRANSFORM, backend_version=backend_version, version=version, ) @@ -120,13 +119,23 @@ def _from_call( call: Callable[..., duckdb.Expression], expr_name: str, *, - returns_scalar: bool, + expr_kind: ExprKind, **expressifiable_args: Self | Any, ) -> Self: + """Create expression from callable. + + Arguments: + call: Callable from compliant DataFrame to native Expression + expr_name: Expression name + expr_kind: kind of output expression + expressifiable_args: arguments pass to expression which should be parsed + as expressions (e.g. 
in `nw.col('a').is_between('b', 'c')`) + """ + def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: native_series_list = self._call(df) other_native_series = { - key: maybe_evaluate(df, value) + key: maybe_evaluate(df, value, expr_kind=expr_kind) for key, value in expressifiable_args.items() } return [ @@ -140,7 +149,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: function_name=f"{self._function_name}->{expr_name}", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, - returns_scalar=returns_scalar, + expr_kind=expr_kind, backend_version=self._backend_version, version=self._version, ) @@ -150,7 +159,7 @@ def __and__(self: Self, other: DuckDBExpr) -> Self: lambda _input, other: _input & other, "__and__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __or__(self: Self, other: DuckDBExpr) -> Self: @@ -158,7 +167,7 @@ def __or__(self: Self, other: DuckDBExpr) -> Self: lambda _input, other: _input | other, "__or__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __add__(self: Self, other: DuckDBExpr) -> Self: @@ -166,7 +175,7 @@ def __add__(self: Self, other: DuckDBExpr) -> Self: lambda _input, other: _input + other, "__add__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __truediv__(self: Self, other: DuckDBExpr) -> Self: @@ -174,7 +183,7 @@ def __truediv__(self: Self, other: DuckDBExpr) -> Self: lambda _input, other: _input / other, "__truediv__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __floordiv__(self: Self, other: DuckDBExpr) -> Self: @@ -182,7 +191,7 @@ def __floordiv__(self: Self, other: DuckDBExpr) -> Self: lambda _input, other: _input.__floordiv__(other), "__floordiv__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __mod__(self: Self, other: DuckDBExpr) -> Self: @@ -190,7 +199,7 @@ def __mod__(self: Self, other: DuckDBExpr) -> Self: lambda _input, other: _input.__mod__(other), "__mod__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __sub__(self: Self, other: DuckDBExpr) -> Self: @@ -198,7 +207,7 @@ def __sub__(self: Self, other: DuckDBExpr) -> Self: lambda _input, other: _input - other, "__sub__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __mul__(self: Self, other: DuckDBExpr) -> Self: @@ -206,7 +215,7 @@ def __mul__(self: Self, other: DuckDBExpr) -> Self: lambda _input, other: _input * other, "__mul__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __pow__(self: Self, other: DuckDBExpr) -> Self: @@ -214,7 +223,7 @@ def __pow__(self: Self, other: DuckDBExpr) -> Self: lambda _input, other: _input**other, "__pow__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __lt__(self: Self, other: DuckDBExpr) -> Self: @@ -222,7 +231,7 @@ def __lt__(self: Self, other: DuckDBExpr) -> Self: lambda _input, other: _input < other, "__lt__", 
other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __gt__(self: Self, other: DuckDBExpr) -> Self: @@ -230,7 +239,7 @@ def __gt__(self: Self, other: DuckDBExpr) -> Self: lambda _input, other: _input > other, "__gt__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __le__(self: Self, other: DuckDBExpr) -> Self: @@ -238,7 +247,7 @@ def __le__(self: Self, other: DuckDBExpr) -> Self: lambda _input, other: _input <= other, "__le__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __ge__(self: Self, other: DuckDBExpr) -> Self: @@ -246,7 +255,7 @@ def __ge__(self: Self, other: DuckDBExpr) -> Self: lambda _input, other: _input >= other, "__ge__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __eq__(self: Self, other: DuckDBExpr) -> Self: # type: ignore[override] @@ -254,7 +263,7 @@ def __eq__(self: Self, other: DuckDBExpr) -> Self: # type: ignore[override] lambda _input, other: _input == other, "__eq__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __ne__(self: Self, other: DuckDBExpr) -> Self: # type: ignore[override] @@ -262,14 +271,14 @@ def __ne__(self: Self, other: DuckDBExpr) -> Self: # type: ignore[override] lambda _input, other: _input != other, "__ne__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __invert__(self: Self) -> Self: return self._from_call( lambda _input: ~_input, "__invert__", - returns_scalar=self._returns_scalar, + expr_kind=self._expr_kind, ) def alias(self: Self, name: str) -> Self: @@ -287,7 +296,7 @@ def alias_output_names(names: Sequence[str]) -> Sequence[str]: function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=alias_output_names, - returns_scalar=self._returns_scalar, + expr_kind=self._expr_kind, backend_version=self._backend_version, version=self._version, ) @@ -296,14 +305,14 @@ def abs(self: Self) -> Self: return self._from_call( lambda _input: FunctionExpression("abs", _input), "abs", - returns_scalar=self._returns_scalar, + expr_kind=self._expr_kind, ) def mean(self: Self) -> Self: return self._from_call( lambda _input: FunctionExpression("mean", _input), "mean", - returns_scalar=True, + expr_kind=ExprKind.AGGREGATION, ) def skew(self: Self) -> Self: @@ -321,27 +330,27 @@ def func(_input: duckdb.Expression) -> duckdb.Expression: ) ) - return self._from_call(func, "skew", returns_scalar=True) + return self._from_call(func, "skew", expr_kind=ExprKind.AGGREGATION) def median(self: Self) -> Self: return self._from_call( lambda _input: FunctionExpression("median", _input), "median", - returns_scalar=True, + expr_kind=ExprKind.AGGREGATION, ) def all(self: Self) -> Self: return self._from_call( lambda _input: FunctionExpression("bool_and", _input), "all", - returns_scalar=True, + expr_kind=ExprKind.AGGREGATION, ) def any(self: Self) -> Self: return self._from_call( lambda _input: FunctionExpression("bool_or", _input), "any", - returns_scalar=True, + expr_kind=ExprKind.AGGREGATION, ) def quantile( @@ -360,7 +369,7 @@ def func(_input: duckdb.Expression) -> duckdb.Expression: return self._from_call( func, 
"quantile", - returns_scalar=True, + expr_kind=ExprKind.AGGREGATION, ) def clip(self: Self, lower_bound: Any, upper_bound: Any) -> Self: @@ -376,7 +385,7 @@ def func( "clip", lower_bound=lower_bound, upper_bound=upper_bound, - returns_scalar=self._returns_scalar, + expr_kind=self._expr_kind, ) def is_between( @@ -401,12 +410,14 @@ def func( "is_between", lower_bound=lower_bound, upper_bound=upper_bound, - returns_scalar=self._returns_scalar, + expr_kind=self._expr_kind, ) def sum(self: Self) -> Self: return self._from_call( - lambda _input: FunctionExpression("sum", _input), "sum", returns_scalar=True + lambda _input: FunctionExpression("sum", _input), + "sum", + expr_kind=ExprKind.AGGREGATION, ) def n_unique(self: Self) -> Self: @@ -424,19 +435,21 @@ def func(_input: duckdb.Expression) -> duckdb.Expression: return self._from_call( func, "n_unique", - returns_scalar=True, + expr_kind=ExprKind.AGGREGATION, ) def count(self: Self) -> Self: return self._from_call( lambda _input: FunctionExpression("count", _input), "count", - returns_scalar=True, + expr_kind=ExprKind.AGGREGATION, ) def len(self: Self) -> Self: return self._from_call( - lambda _input: FunctionExpression("count"), "len", returns_scalar=True + lambda _input: FunctionExpression("count"), + "len", + expr_kind=ExprKind.AGGREGATION, ) def std(self: Self, ddof: int) -> Self: @@ -453,7 +466,7 @@ def _std(_input: duckdb.Expression, ddof: int) -> duckdb.Expression: _std, "std", ddof=ddof, - returns_scalar=True, + expr_kind=ExprKind.AGGREGATION, ) def var(self: Self, ddof: int) -> Self: @@ -465,50 +478,56 @@ def _var(_input: duckdb.Expression, ddof: int) -> duckdb.Expression: _var, "var", ddof=ddof, - returns_scalar=True, + expr_kind=ExprKind.AGGREGATION, ) def max(self: Self) -> Self: return self._from_call( - lambda _input: FunctionExpression("max", _input), "max", returns_scalar=True + lambda _input: FunctionExpression("max", _input), + "max", + expr_kind=ExprKind.AGGREGATION, ) def min(self: Self) -> Self: return self._from_call( - lambda _input: FunctionExpression("min", _input), "min", returns_scalar=True + lambda _input: FunctionExpression("min", _input), + "min", + expr_kind=ExprKind.AGGREGATION, ) def null_count(self: Self) -> Self: return self._from_call( lambda _input: FunctionExpression("sum", _input.isnull().cast("int")), "null_count", - returns_scalar=True, + expr_kind=ExprKind.AGGREGATION, ) def is_null(self: Self) -> Self: return self._from_call( - lambda _input: _input.isnull(), "is_null", returns_scalar=self._returns_scalar + lambda _input: _input.isnull(), + "is_null", + expr_kind=self._expr_kind, ) def is_nan(self: Self) -> Self: return self._from_call( lambda _input: FunctionExpression("isnan", _input), "is_nan", - returns_scalar=self._returns_scalar, + expr_kind=self._expr_kind, ) def is_finite(self: Self) -> Self: return self._from_call( lambda _input: FunctionExpression("isfinite", _input), "is_finite", - returns_scalar=self._returns_scalar, + expr_kind=self._expr_kind, ) def is_in(self: Self, other: Sequence[Any]) -> Self: return self._from_call( lambda _input: _input.isin(*[ConstantExpression(x) for x in other]), "is_in", - returns_scalar=self._returns_scalar, + expr_kind=self._expr_kind, ) def round(self: Self, decimals: int) -> Self: @@ -517,7 +536,7 @@ def round(self: Self, decimals: int) -> Self: "round", _input, ConstantExpression(decimals) ), "round", - returns_scalar=self._returns_scalar, + expr_kind=self._expr_kind, ) def fill_null(self: Self, value: Any, strategy: Any, limit: int | None) -> Self: @@ 
-528,7 +547,7 @@ def fill_null(self: Self, value: Any, strategy: Any, limit: int | None) -> Self: return self._from_call( lambda _input: CoalesceOperator(_input, ConstantExpression(value)), "fill_null", - returns_scalar=self._returns_scalar, + expr_kind=self._expr_kind, ) def cast(self: Self, dtype: DType | type[DType]) -> Self: @@ -539,7 +558,7 @@ def func(_input: duckdb.Expression) -> duckdb.Expression: return self._from_call( func, "cast", - returns_scalar=self._returns_scalar, + expr_kind=self._expr_kind, ) @property diff --git a/narwhals/_duckdb/expr_dt.py b/narwhals/_duckdb/expr_dt.py index 01556e9be8..b7a750d2ac 100644 --- a/narwhals/_duckdb/expr_dt.py +++ b/narwhals/_duckdb/expr_dt.py @@ -19,42 +19,42 @@ def year(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( lambda _input: FunctionExpression("year", _input), "year", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def month(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( lambda _input: FunctionExpression("month", _input), "month", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def day(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( lambda _input: FunctionExpression("day", _input), "day", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def hour(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( lambda _input: FunctionExpression("hour", _input), "hour", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def minute(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( lambda _input: FunctionExpression("minute", _input), "minute", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def second(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( lambda _input: FunctionExpression("second", _input), "second", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def millisecond(self: Self) -> DuckDBExpr: @@ -62,7 +62,7 @@ def millisecond(self: Self) -> DuckDBExpr: lambda _input: FunctionExpression("millisecond", _input) - FunctionExpression("second", _input) * 1_000, "millisecond", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def microsecond(self: Self) -> DuckDBExpr: @@ -70,7 +70,7 @@ def microsecond(self: Self) -> DuckDBExpr: lambda _input: FunctionExpression("microsecond", _input) - FunctionExpression("second", _input) * 1_000_000, "microsecond", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def nanosecond(self: Self) -> DuckDBExpr: @@ -78,7 +78,7 @@ def nanosecond(self: Self) -> DuckDBExpr: lambda _input: FunctionExpression("nanosecond", _input) - FunctionExpression("second", _input) * 1_000_000_000, "nanosecond", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def to_string(self: Self, format: str) -> DuckDBExpr: # noqa: A002 @@ -87,28 +87,28 @@ def to_string(self: Self, format: str) -> DuckDBExpr: # noqa: A002 "strftime", _input, ConstantExpression(format) ), "to_string", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def weekday(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( lambda _input: 
FunctionExpression("isodow", _input), "weekday", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def ordinal_day(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( lambda _input: FunctionExpression("dayofyear", _input), "ordinal_day", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def date(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( lambda _input: _input.cast("date"), "date", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def total_minutes(self: Self) -> DuckDBExpr: @@ -117,7 +117,7 @@ def total_minutes(self: Self) -> DuckDBExpr: "datepart", ConstantExpression("minute"), _input ), "total_minutes", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def total_seconds(self: Self) -> DuckDBExpr: @@ -126,7 +126,7 @@ def total_seconds(self: Self) -> DuckDBExpr: * FunctionExpression("datepart", ConstantExpression("minute"), _input) + FunctionExpression("datepart", ConstantExpression("second"), _input), "total_seconds", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def total_milliseconds(self: Self) -> DuckDBExpr: @@ -135,7 +135,7 @@ def total_milliseconds(self: Self) -> DuckDBExpr: * FunctionExpression("datepart", ConstantExpression("minute"), _input) + FunctionExpression("datepart", ConstantExpression("millisecond"), _input), "total_milliseconds", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def total_microseconds(self: Self) -> DuckDBExpr: @@ -144,7 +144,7 @@ def total_microseconds(self: Self) -> DuckDBExpr: * FunctionExpression("datepart", ConstantExpression("minute"), _input) + FunctionExpression("datepart", ConstantExpression("microsecond"), _input), "total_microseconds", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def total_nanoseconds(self: Self) -> DuckDBExpr: diff --git a/narwhals/_duckdb/expr_list.py b/narwhals/_duckdb/expr_list.py index 2a9349517a..134df90b6d 100644 --- a/narwhals/_duckdb/expr_list.py +++ b/narwhals/_duckdb/expr_list.py @@ -16,5 +16,5 @@ def len(self) -> DuckDBExpr: return self._compliant_expr._from_call( lambda _input: FunctionExpression("len", _input), "len", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) diff --git a/narwhals/_duckdb/expr_name.py b/narwhals/_duckdb/expr_name.py index e8f6a31420..5dd1e8b808 100644 --- a/narwhals/_duckdb/expr_name.py +++ b/narwhals/_duckdb/expr_name.py @@ -62,7 +62,7 @@ def _from_alias_output_names( function_name=self._compliant_expr._function_name, evaluate_output_names=self._compliant_expr._evaluate_output_names, alias_output_names=alias_output_names, - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, ) diff --git a/narwhals/_duckdb/expr_str.py b/narwhals/_duckdb/expr_str.py index a4ce8a579d..74ef5c272d 100644 --- a/narwhals/_duckdb/expr_str.py +++ b/narwhals/_duckdb/expr_str.py @@ -23,7 +23,7 @@ def starts_with(self: Self, prefix: str) -> DuckDBExpr: "starts_with", _input, ConstantExpression(prefix) ), "starts_with", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def 
ends_with(self: Self, suffix: str) -> DuckDBExpr: @@ -32,7 +32,7 @@ def ends_with(self: Self, suffix: str) -> DuckDBExpr: "ends_with", _input, ConstantExpression(suffix) ), "ends_with", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def contains(self: Self, pattern: str, *, literal: bool) -> DuckDBExpr: @@ -44,7 +44,7 @@ def func(_input: duckdb.Expression) -> duckdb.Expression: ) return self._compliant_expr._from_call( - func, "contains", returns_scalar=self._compliant_expr._returns_scalar + func, "contains", expr_kind=self._compliant_expr._expr_kind ) def slice(self: Self, offset: int, length: int) -> DuckDBExpr: @@ -61,28 +61,28 @@ def func(_input: duckdb.Expression) -> duckdb.Expression: ) return self._compliant_expr._from_call( - func, "slice", returns_scalar=self._compliant_expr._returns_scalar + func, "slice", expr_kind=self._compliant_expr._expr_kind ) def len_chars(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( lambda _input: FunctionExpression("length", _input), "len_chars", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def to_lowercase(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( lambda _input: FunctionExpression("lower", _input), "to_lowercase", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def to_uppercase(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( lambda _input: FunctionExpression("upper", _input), "to_uppercase", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def strip_chars(self: Self, characters: str | None) -> DuckDBExpr: @@ -97,7 +97,7 @@ def strip_chars(self: Self, characters: str | None) -> DuckDBExpr: ), ), "strip_chars", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def replace_all(self: Self, pattern: str, value: str, *, literal: bool) -> DuckDBExpr: @@ -111,14 +111,14 @@ def replace_all(self: Self, pattern: str, value: str, *, literal: bool) -> DuckD ConstantExpression("g"), ), "replace_all", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) return self._compliant_expr._from_call( lambda _input: FunctionExpression( "replace", _input, ConstantExpression(pattern), ConstantExpression(value) ), "replace_all", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def replace( diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index c6cd6708e9..d4abcd3ec6 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -18,6 +18,8 @@ from narwhals._duckdb.expr import DuckDBExpr from narwhals._duckdb.selectors import DuckDBSelectorNamespace +from narwhals._duckdb.utils import ExprKind +from narwhals._duckdb.utils import n_ary_operation_expr_kind from narwhals._duckdb.utils import narwhals_to_native_dtype from narwhals._expression_parsing import combine_alias_output_names from narwhals._expression_parsing import combine_evaluate_output_names @@ -55,7 +57,7 @@ def _all(df: DuckDBLazyFrame) -> list[duckdb.Expression]: function_name="all", evaluate_output_names=lambda df: df.columns, alias_output_names=None, - returns_scalar=False, + expr_kind=ExprKind.TRANSFORM, backend_version=self._backend_version, version=self._version, ) @@ -138,7 +140,7 @@ def func(df: DuckDBLazyFrame) -> 
list[duckdb.Expression]: function_name="concat_str", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), - returns_scalar=False, + expr_kind=n_ary_operation_expr_kind(*exprs), backend_version=self._backend_version, version=self._version, ) @@ -154,7 +156,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: function_name="all_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), - returns_scalar=False, + expr_kind=n_ary_operation_expr_kind(*exprs), backend_version=self._backend_version, version=self._version, ) @@ -170,7 +172,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: function_name="or_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), - returns_scalar=False, + expr_kind=n_ary_operation_expr_kind(*exprs), backend_version=self._backend_version, version=self._version, ) @@ -186,7 +188,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: function_name="max_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), - returns_scalar=False, + expr_kind=n_ary_operation_expr_kind(*exprs), backend_version=self._backend_version, version=self._version, ) @@ -202,7 +204,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: function_name="min_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), - returns_scalar=False, + expr_kind=n_ary_operation_expr_kind(*exprs), backend_version=self._backend_version, version=self._version, ) @@ -223,7 +225,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: function_name="sum_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), - returns_scalar=False, + expr_kind=n_ary_operation_expr_kind(*exprs), backend_version=self._backend_version, version=self._version, ) @@ -247,7 +249,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: function_name="mean_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), - returns_scalar=False, + expr_kind=n_ary_operation_expr_kind(*exprs), backend_version=self._backend_version, version=self._version, ) @@ -259,7 +261,10 @@ def when( plx = self.__class__(backend_version=self._backend_version, version=self._version) condition = plx.all_horizontal(*predicates) return DuckDBWhen( - condition, self._backend_version, returns_scalar=False, version=self._version + condition, + self._backend_version, + expr_kind=ExprKind.TRANSFORM, + version=self._version, ) def col(self: Self, *column_names: str) -> DuckDBExpr: @@ -288,7 +293,7 @@ def func(_df: DuckDBLazyFrame) -> list[duckdb.Expression]: function_name="lit", evaluate_output_names=lambda _df: ["literal"], alias_output_names=None, - returns_scalar=True, + expr_kind=ExprKind.LITERAL, backend_version=self._backend_version, version=self._version, ) @@ -303,7 +308,7 @@ def func(_df: DuckDBLazyFrame) -> list[duckdb.Expression]: function_name="len", evaluate_output_names=lambda _df: ["len"], alias_output_names=None, - returns_scalar=True, + expr_kind=ExprKind.AGGREGATION, backend_version=self._backend_version, version=self._version, ) @@ -317,14 +322,14 @@ def __init__( then_value: Any = None, otherwise_value: Any = None, *, - 
returns_scalar: bool, + expr_kind: ExprKind, version: Version, ) -> None: self._backend_version = backend_version self._condition = condition self._then_value = then_value self._otherwise_value = otherwise_value - self._returns_scalar = returns_scalar + self._expr_kind = expr_kind self._version = version def __call__(self: Self, df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]: @@ -361,7 +366,7 @@ def then(self: Self, value: DuckDBExpr | Any) -> DuckDBThen: value, "_evaluate_output_names", lambda _df: ["literal"] ), alias_output_names=getattr(value, "_alias_output_names", None), - returns_scalar=self._returns_scalar, + expr_kind=self._expr_kind, backend_version=self._backend_version, version=self._version, ) @@ -376,7 +381,7 @@ def __init__( function_name: str, evaluate_output_names: Callable[[DuckDBLazyFrame], Sequence[str]], alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, - returns_scalar: bool, + expr_kind: ExprKind, backend_version: tuple[int, ...], version: Version, ) -> None: @@ -387,7 +392,7 @@ def __init__( self._function_name = function_name self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names - self._returns_scalar = returns_scalar + self._expr_kind = expr_kind def otherwise(self: Self, value: DuckDBExpr | Any) -> DuckDBExpr: # type ignore because we are setting the `_call` attribute to a diff --git a/narwhals/_duckdb/selectors.py b/narwhals/_duckdb/selectors.py index 11961e5237..5ef0b02fdd 100644 --- a/narwhals/_duckdb/selectors.py +++ b/narwhals/_duckdb/selectors.py @@ -7,6 +7,7 @@ from duckdb import ColumnExpression from narwhals._duckdb.expr import DuckDBExpr +from narwhals._duckdb.utils import ExprKind from narwhals.utils import import_dtypes_module if TYPE_CHECKING: @@ -41,7 +42,7 @@ def evalute_output_names(df: DuckDBLazyFrame) -> Sequence[str]: evaluate_output_names=evalute_output_names, alias_output_names=None, backend_version=self._backend_version, - returns_scalar=False, + expr_kind=ExprKind.TRANSFORM, version=self._version, ) @@ -87,7 +88,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: evaluate_output_names=lambda df: df.columns, alias_output_names=None, backend_version=self._backend_version, - returns_scalar=False, + expr_kind=ExprKind.TRANSFORM, version=self._version, ) @@ -108,7 +109,7 @@ def _to_expr(self: Self) -> DuckDBExpr: evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, backend_version=self._backend_version, - returns_scalar=self._returns_scalar, + expr_kind=self._expr_kind, version=self._version, ) @@ -133,7 +134,7 @@ def evaluate_output_names(df: DuckDBLazyFrame) -> list[str]: evaluate_output_names=evaluate_output_names, alias_output_names=None, backend_version=self._backend_version, - returns_scalar=self._returns_scalar, + expr_kind=self._expr_kind, version=self._version, ) else: @@ -164,7 +165,7 @@ def evaluate_output_names(df: DuckDBLazyFrame) -> list[str]: evaluate_output_names=evaluate_output_names, alias_output_names=None, backend_version=self._backend_version, - returns_scalar=self._returns_scalar, + expr_kind=self._expr_kind, version=self._version, ) else: @@ -191,7 +192,7 @@ def evaluate_output_names(df: DuckDBLazyFrame) -> list[str]: evaluate_output_names=evaluate_output_names, alias_output_names=None, backend_version=self._backend_version, - returns_scalar=self._returns_scalar, + expr_kind=self._expr_kind, version=self._version, ) else: diff --git a/narwhals/_duckdb/utils.py b/narwhals/_duckdb/utils.py index 
1d2d28bc78..ae0b465311 100644 --- a/narwhals/_duckdb/utils.py +++ b/narwhals/_duckdb/utils.py @@ -1,6 +1,8 @@ from __future__ import annotations import re +from enum import Enum +from enum import auto from functools import lru_cache from typing import TYPE_CHECKING from typing import Any @@ -18,7 +20,22 @@ from narwhals.utils import Version -def maybe_evaluate(df: DuckDBLazyFrame, obj: Any) -> Any: +class ExprKind(Enum): + """Describe which kind of expression we are dealing with. + + Composition rule is: + - LITERAL vs LITERAL -> LITERAL + - TRANSFORM vs anything -> TRANSFORM + - anything vs TRANSFORM -> TRANSFORM + - all remaining cases -> AGGREGATION + """ + + LITERAL = auto() # e.g. nw.lit(1) + AGGREGATION = auto() # e.g. nw.col('a').mean() + TRANSFORM = auto() # e.g. nw.col('a').round() + + +def maybe_evaluate(df: DuckDBLazyFrame, obj: Any, *, expr_kind: ExprKind) -> Any: from narwhals._duckdb.expr import DuckDBExpr if isinstance(obj, DuckDBExpr): @@ -27,8 +44,15 @@ def maybe_evaluate(df: DuckDBLazyFrame, obj: Any) -> Any: msg = "Multi-output expressions (e.g. `nw.all()` or `nw.col('a', 'b')`) not supported in this context" raise NotImplementedError(msg) column_result = column_results[0] - if obj._returns_scalar: - msg = "Reductions are not yet supported for DuckDB, at least until they implement duckdb.WindowExpression" + if obj._expr_kind is ExprKind.AGGREGATION and expr_kind is ExprKind.TRANSFORM: + # Returns scalar, but overall expression doesn't. + # Not yet supported. + msg = ( + "Mixing expressions which aggregate and expressions which don't\n" + "is not yet supported by the DuckDB backend. Once they introduce\n" + "duckdb.WindowExpression to their Python API, we'll be able to\n" + "support this." + ) raise NotImplementedError(msg) return column_result return duckdb.ConstantExpression(obj) @@ -40,7 +64,7 @@ def parse_exprs_and_named_exprs( def func( *exprs: DuckDBExpr, **named_exprs: DuckDBExpr ) -> dict[str, duckdb.Expression]: - native_results: dict[str, list[duckdb.Expression]] = {} + native_results: dict[str, duckdb.Expression] = {} for expr in exprs: native_series_list = expr._call(df) output_names = expr._evaluate_output_names(df) @@ -177,8 +201,13 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> st raise AssertionError(msg) -def binary_operation_returns_scalar(lhs: DuckDBExpr, rhs: DuckDBExpr | Any) -> bool: - # If `rhs` is a DuckDBExpr, we look at `_returns_scalar`. If it isn't, - # it means that it was a scalar (e.g. nw.col('a') + 1), and so we default - # to `True`. 
- return lhs._returns_scalar and getattr(rhs, "_returns_scalar", True) +def n_ary_operation_expr_kind(*args: DuckDBExpr | Any) -> ExprKind: + if all( + getattr(arg, "_expr_kind", ExprKind.LITERAL) is ExprKind.LITERAL for arg in args + ): + return ExprKind.LITERAL + if any( + getattr(arg, "_expr_kind", ExprKind.LITERAL) is ExprKind.TRANSFORM for arg in args + ): + return ExprKind.TRANSFORM + return ExprKind.AGGREGATION diff --git a/narwhals/utils.py b/narwhals/utils.py index 4192b820d4..dfde5c6fc6 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -1035,11 +1035,11 @@ def generate_repr(header: str, native_repr: str) -> str: try: terminal_width = os.get_terminal_size().columns except OSError: - terminal_width = 80 + terminal_width = int(os.getenv("COLUMNS", 80)) # noqa: PLW1508 native_lines = native_repr.splitlines() max_native_width = max(len(line) for line in native_lines) - if max_native_width + 2 < terminal_width: + if max_native_width + 2 <= terminal_width: length = max(max_native_width, len(header)) output = f"┌{'─'*length}┐\n" header_extra = length - len(header) diff --git a/tests/expr_and_series/lit_test.py b/tests/expr_and_series/lit_test.py index 525584f154..9005927666 100644 --- a/tests/expr_and_series/lit_test.py +++ b/tests/expr_and_series/lit_test.py @@ -92,13 +92,6 @@ def test_lit_operation_in_select( expected_result: list[int], request: pytest.FixtureRequest, ) -> None: - if "duckdb" in str(constructor) and col_name in ( - "left_scalar_with_agg", - "left_lit_with_agg", - "right_lit", - "right_lit_with_agg", - ): - request.applymarker(pytest.mark.xfail) if ( "dask" in str(constructor) and col_name in ("left_lit", "left_scalar") @@ -126,10 +119,7 @@ def test_lit_operation_in_with_columns( col_name: str, expr: nw.Expr, expected_result: list[int], - request: pytest.FixtureRequest, ) -> None: - if "duckdb" in str(constructor) and col_name == "scalar_and_lit": - request.applymarker(pytest.mark.xfail) data = {"a": [1, 3, 2]} df_raw = constructor(data) df = nw.from_native(df_raw).lazy() diff --git a/tests/expr_and_series/sum_horizontal_test.py b/tests/expr_and_series/sum_horizontal_test.py index 25ef6af81e..b26178d558 100644 --- a/tests/expr_and_series/sum_horizontal_test.py +++ b/tests/expr_and_series/sum_horizontal_test.py @@ -45,3 +45,17 @@ def test_sumh_all(constructor: Constructor) -> None: "c": [11, 22, 33], } assert_equal_data(result, expected) + + +def test_sumh_aggregations( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "dask" in str(constructor): + request.applymarker(pytest.mark.xfail) + data = {"a": [1, 2, 3], "b": [10, 20, 30]} + df = nw.from_native(constructor(data)) + result = df.select(nw.sum_horizontal(nw.all().mean().name.suffix("_foo"))) + expected = { + "a_foo": [22], + } + assert_equal_data(result, expected) diff --git a/tpch/execute.py b/tpch/execute.py index 5209ad48e3..25632754b0 100644 --- a/tpch/execute.py +++ b/tpch/execute.py @@ -40,7 +40,7 @@ "dask": lambda x: x.compute(), } -DUCKDB_SKIPS = ["q14", "q15"] +DUCKDB_SKIPS = ["q15"] QUERY_DATA_PATH_MAP = { "q1": (LINEITEM_PATH,), From daaaf974a4bee39b65d13f6500472a07dc41745a Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Tue, 28 Jan 2025 08:08:05 +0000 Subject: [PATCH 02/13] chore: remove unused `_depth` from `DuckDBExpr` and `SparkLikeExpr` (#1878) --- narwhals/_duckdb/expr.py | 9 +-------- narwhals/_duckdb/expr_name.py | 1 - narwhals/_duckdb/namespace.py | 13 ------------- narwhals/_duckdb/selectors.py | 12 +----------- narwhals/_spark_like/expr.py | 10 
+--------- narwhals/_spark_like/expr_name.py | 1 - narwhals/_spark_like/namespace.py | 13 ------------- narwhals/_spark_like/selectors.py | 12 +----------- 8 files changed, 4 insertions(+), 67 deletions(-) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 590fd66f2d..ccc741be45 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -35,12 +35,12 @@ class DuckDBExpr(CompliantExpr["duckdb.Expression"]): _implementation = Implementation.DUCKDB + _depth = 0 # Unused, just for compatibility with CompliantExpr def __init__( self: Self, call: Callable[[DuckDBLazyFrame], Sequence[duckdb.Expression]], *, - depth: int, function_name: str, evaluate_output_names: Callable[[DuckDBLazyFrame], Sequence[str]], alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, @@ -49,7 +49,6 @@ def __init__( version: Version, ) -> None: self._call = call - self._depth = depth self._function_name = function_name self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names @@ -82,7 +81,6 @@ def func(_: DuckDBLazyFrame) -> list[duckdb.Expression]: return cls( func, - depth=0, function_name="col", evaluate_output_names=lambda _df: list(column_names), alias_output_names=None, @@ -105,7 +103,6 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return cls( func, - depth=0, function_name="nth", evaluate_output_names=lambda df: [df.columns[i] for i in column_indices], alias_output_names=None, @@ -145,7 +142,6 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return self.__class__( func, - depth=self._depth + 1, function_name=f"{self._function_name}->{expr_name}", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, @@ -288,11 +284,8 @@ def alias_output_names(names: Sequence[str]) -> Sequence[str]: raise ValueError(msg) return [name] - # Define this one manually, so that we can - # override `output_names` and not increase depth return self.__class__( self._call, - depth=self._depth, function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=alias_output_names, diff --git a/narwhals/_duckdb/expr_name.py b/narwhals/_duckdb/expr_name.py index 5dd1e8b808..e2d6f2fb43 100644 --- a/narwhals/_duckdb/expr_name.py +++ b/narwhals/_duckdb/expr_name.py @@ -58,7 +58,6 @@ def _from_alias_output_names( ) -> DuckDBExpr: return self._compliant_expr.__class__( call=self._compliant_expr._call, - depth=self._compliant_expr._depth, function_name=self._compliant_expr._function_name, evaluate_output_names=self._compliant_expr._evaluate_output_names, alias_output_names=alias_output_names, diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index d4abcd3ec6..fd761eb137 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -53,7 +53,6 @@ def _all(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return DuckDBExpr( call=_all, - depth=0, function_name="all", evaluate_output_names=lambda df: df.columns, alias_output_names=None, @@ -136,7 +135,6 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return DuckDBExpr( call=func, - depth=max(x._depth for x in exprs) + 1, function_name="concat_str", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), @@ -152,7 +150,6 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return DuckDBExpr( call=func, - depth=max(x._depth for x in exprs) + 1, function_name="all_horizontal", 
evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), @@ -168,7 +165,6 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return DuckDBExpr( call=func, - depth=max(x._depth for x in exprs) + 1, function_name="or_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), @@ -184,7 +180,6 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return DuckDBExpr( call=func, - depth=max(x._depth for x in exprs) + 1, function_name="max_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), @@ -200,7 +195,6 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return DuckDBExpr( call=func, - depth=max(x._depth for x in exprs) + 1, function_name="min_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), @@ -221,7 +215,6 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return DuckDBExpr( call=func, - depth=max(x._depth for x in exprs) + 1, function_name="sum_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), @@ -245,7 +238,6 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return DuckDBExpr( call=func, - depth=max(x._depth for x in exprs) + 1, function_name="mean_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), @@ -289,7 +281,6 @@ def func(_df: DuckDBLazyFrame) -> list[duckdb.Expression]: return DuckDBExpr( func, - depth=0, function_name="lit", evaluate_output_names=lambda _df: ["literal"], alias_output_names=None, @@ -304,7 +295,6 @@ def func(_df: DuckDBLazyFrame) -> list[duckdb.Expression]: return DuckDBExpr( call=func, - depth=0, function_name="len", evaluate_output_names=lambda _df: ["len"], alias_output_names=None, @@ -360,7 +350,6 @@ def then(self: Self, value: DuckDBExpr | Any) -> DuckDBThen: return DuckDBThen( self, - depth=0, function_name="whenthen", evaluate_output_names=getattr( value, "_evaluate_output_names", lambda _df: ["literal"] @@ -377,7 +366,6 @@ def __init__( self: Self, call: DuckDBWhen, *, - depth: int, function_name: str, evaluate_output_names: Callable[[DuckDBLazyFrame], Sequence[str]], alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, @@ -388,7 +376,6 @@ def __init__( self._backend_version = backend_version self._version = version self._call = call - self._depth = depth self._function_name = function_name self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names diff --git a/narwhals/_duckdb/selectors.py b/narwhals/_duckdb/selectors.py index 5ef0b02fdd..59848a8020 100644 --- a/narwhals/_duckdb/selectors.py +++ b/narwhals/_duckdb/selectors.py @@ -37,7 +37,6 @@ def evalute_output_names(df: DuckDBLazyFrame) -> Sequence[str]: return DuckDBSelector( func, - depth=0, function_name="selector", evaluate_output_names=evalute_output_names, alias_output_names=None, @@ -83,7 +82,6 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return DuckDBSelector( func, - depth=0, function_name="selector", evaluate_output_names=lambda df: df.columns, alias_output_names=None, @@ -95,16 +93,11 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: class DuckDBSelector(DuckDBExpr): def __repr__(self: Self) -> str: # pragma: no cover - return ( - 
f"DuckDBSelector(" - f"depth={self._depth}, " - f"function_name={self._function_name})" - ) + return f"DuckDBSelector(" f"function_name={self._function_name})" def _to_expr(self: Self) -> DuckDBExpr: return DuckDBExpr( self._call, - depth=self._depth, function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, @@ -129,7 +122,6 @@ def evaluate_output_names(df: DuckDBLazyFrame) -> list[str]: return DuckDBSelector( call, - depth=0, function_name="selector", evaluate_output_names=evaluate_output_names, alias_output_names=None, @@ -160,7 +152,6 @@ def evaluate_output_names(df: DuckDBLazyFrame) -> list[str]: return DuckDBSelector( call, - depth=0, function_name="selector", evaluate_output_names=evaluate_output_names, alias_output_names=None, @@ -187,7 +178,6 @@ def evaluate_output_names(df: DuckDBLazyFrame) -> list[str]: return DuckDBSelector( call, - depth=0, function_name="selector", evaluate_output_names=evaluate_output_names, alias_output_names=None, diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 6ad08f6b6a..5483f1e776 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -31,12 +31,12 @@ class SparkLikeExpr(CompliantExpr["Column"]): _implementation = Implementation.PYSPARK + _depth = 0 # Unused, just for compatibility with CompliantExpr def __init__( self: Self, call: Callable[[SparkLikeLazyFrame], list[Column]], *, - depth: int, function_name: str, evaluate_output_names: Callable[[SparkLikeLazyFrame], Sequence[str]], alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, @@ -47,7 +47,6 @@ def __init__( version: Version, ) -> None: self._call = call - self._depth = depth self._function_name = function_name self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names @@ -80,7 +79,6 @@ def func(_: SparkLikeLazyFrame) -> list[Column]: return cls( func, - depth=0, function_name="col", evaluate_output_names=lambda _df: list(column_names), alias_output_names=None, @@ -102,7 +100,6 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: return cls( func, - depth=0, function_name="nth", evaluate_output_names=lambda df: [df.columns[i] for i in column_indices], alias_output_names=None, @@ -132,7 +129,6 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: return self.__class__( func, - depth=self._depth + 1, function_name=f"{self._function_name}->{expr_name}", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, @@ -281,11 +277,8 @@ def alias_output_names(names: Sequence[str]) -> Sequence[str]: raise ValueError(msg) return [name] - # Define this one manually, so that we can - # override `output_names` and not increase depth return self.__class__( self._call, - depth=self._depth, function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=alias_output_names, @@ -487,7 +480,6 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: return self.__class__( func, - depth=self._depth + 1, function_name=self._function_name + "->over", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, diff --git a/narwhals/_spark_like/expr_name.py b/narwhals/_spark_like/expr_name.py index 3e2ef7dd15..c1febdd7ad 100644 --- a/narwhals/_spark_like/expr_name.py +++ b/narwhals/_spark_like/expr_name.py @@ -58,7 +58,6 @@ def _from_alias_output_names( ) -> SparkLikeExpr: return self._compliant_expr.__class__( 
self._compliant_expr._call, - depth=self._compliant_expr._depth, function_name=self._compliant_expr._function_name, evaluate_output_names=self._compliant_expr._evaluate_output_names, alias_output_names=alias_output_names, diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index d57ddb133d..5b0ce71325 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -47,7 +47,6 @@ def _all(df: SparkLikeLazyFrame) -> list[Column]: return SparkLikeExpr( call=_all, - depth=0, function_name="all", evaluate_output_names=lambda df: df.columns, alias_output_names=None, @@ -78,7 +77,6 @@ def _lit(_: SparkLikeLazyFrame) -> list[Column]: return SparkLikeExpr( call=_lit, - depth=0, function_name="lit", evaluate_output_names=lambda _df: ["literal"], alias_output_names=None, @@ -93,7 +91,6 @@ def func(_: SparkLikeLazyFrame) -> list[Column]: return SparkLikeExpr( func, - depth=0, function_name="len", evaluate_output_names=lambda _df: ["len"], alias_output_names=None, @@ -109,7 +106,6 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: return SparkLikeExpr( call=func, - depth=max(x._depth for x in exprs) + 1, function_name="all_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), @@ -125,7 +121,6 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: return SparkLikeExpr( call=func, - depth=max(x._depth for x in exprs) + 1, function_name="any_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), @@ -146,7 +141,6 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: return SparkLikeExpr( call=func, - depth=max(x._depth for x in exprs) + 1, function_name="sum_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), @@ -170,7 +164,6 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: return SparkLikeExpr( call=func, - depth=max(x._depth for x in exprs) + 1, function_name="mean_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), @@ -186,7 +179,6 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: return SparkLikeExpr( call=func, - depth=max(x._depth for x in exprs) + 1, function_name="max_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), @@ -202,7 +194,6 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: return SparkLikeExpr( call=func, - depth=max(x._depth for x in exprs) + 1, function_name="min_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), @@ -295,7 +286,6 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: return SparkLikeExpr( call=func, - depth=max(x._depth for x in exprs) + 1, function_name="concat_str", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), @@ -352,7 +342,6 @@ def then(self: Self, value: SparkLikeExpr | Any) -> SparkLikeThen: return SparkLikeThen( self, - depth=0, function_name="whenthen", evaluate_output_names=getattr( value, "_evaluate_output_names", lambda _df: ["literal"] @@ -369,7 +358,6 @@ def __init__( self: Self, call: SparkLikeWhen, *, - depth: int, function_name: str, evaluate_output_names: Callable[[SparkLikeLazyFrame], Sequence[str]], alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, 
@@ -380,7 +368,6 @@ def __init__( self._backend_version = backend_version self._version = version self._call = call - self._depth = depth self._function_name = function_name self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names diff --git a/narwhals/_spark_like/selectors.py b/narwhals/_spark_like/selectors.py index 3004a34fa5..b9771bcd7f 100644 --- a/narwhals/_spark_like/selectors.py +++ b/narwhals/_spark_like/selectors.py @@ -34,7 +34,6 @@ def evalute_output_names(df: SparkLikeLazyFrame) -> Sequence[str]: return SparkLikeSelector( func, - depth=0, function_name="selector", evaluate_output_names=evalute_output_names, alias_output_names=None, @@ -80,7 +79,6 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: return SparkLikeSelector( func, - depth=0, function_name="selector", evaluate_output_names=lambda df: df.columns, alias_output_names=None, @@ -92,16 +90,11 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: class SparkLikeSelector(SparkLikeExpr): def __repr__(self: Self) -> str: # pragma: no cover - return ( - f"SparkLikeSelector(" - f"depth={self._depth}, " - f"function_name={self._function_name})" - ) + return f"SparkLikeSelector(" f"function_name={self._function_name})" def _to_expr(self: Self) -> SparkLikeExpr: return SparkLikeExpr( self._call, - depth=self._depth, function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, @@ -126,7 +119,6 @@ def evaluate_output_names(df: SparkLikeLazyFrame) -> list[str]: return SparkLikeSelector( call, - depth=0, function_name="selector", evaluate_output_names=evaluate_output_names, alias_output_names=None, @@ -157,7 +149,6 @@ def evaluate_output_names(df: SparkLikeLazyFrame) -> list[str]: return SparkLikeSelector( call, - depth=0, function_name="selector", evaluate_output_names=evaluate_output_names, alias_output_names=None, @@ -184,7 +175,6 @@ def evaluate_output_names(df: SparkLikeLazyFrame) -> list[str]: return SparkLikeSelector( call, - depth=0, function_name="selector", evaluate_output_names=evaluate_output_names, alias_output_names=None, From e997ff9a3468c99bbe6fe3f9c8e4426ebc3c427d Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Tue, 28 Jan 2025 08:35:40 +0000 Subject: [PATCH 03/13] chore: use positional-only arg in evaluate_into_exprs (#1879) --- narwhals/_arrow/dataframe.py | 4 +- narwhals/_dask/dataframe.py | 4 +- narwhals/_dask/utils.py | 50 ++++++++++---------- narwhals/_duckdb/dataframe.py | 4 +- narwhals/_duckdb/utils.py | 44 ++++++++---------- narwhals/_expression_parsing.py | 38 +++++++-------- narwhals/_pandas_like/dataframe.py | 8 ++-- narwhals/_spark_like/dataframe.py | 8 ++-- narwhals/_spark_like/utils.py | 74 ++++++++++++++---------------- 9 files changed, 107 insertions(+), 127 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 8953bd6f9b..0d4206fa4e 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -294,7 +294,7 @@ def simple_select(self, *column_names: str) -> Self: return self._from_native_frame(self._native_frame.select(list(column_names))) def select(self: Self, *exprs: IntoArrowExpr, **named_exprs: IntoArrowExpr) -> Self: - new_series: list[ArrowSeries] = evaluate_into_exprs(self)(*exprs, **named_exprs) + new_series: list[ArrowSeries] = evaluate_into_exprs(self, *exprs, **named_exprs) if not new_series: # return empty dataframe, like Polars does return self._from_native_frame(self._native_frame.__class__.from_arrays([])) @@ 
-306,7 +306,7 @@ def with_columns( self: Self, *exprs: IntoArrowExpr, **named_exprs: IntoArrowExpr ) -> Self: native_frame = self._native_frame - new_columns: list[ArrowSeries] = evaluate_into_exprs(self)(*exprs, **named_exprs) + new_columns: list[ArrowSeries] = evaluate_into_exprs(self, *exprs, **named_exprs) length = len(self) columns = self.columns diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index cb3c136be4..3213a0eaac 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -75,7 +75,7 @@ def _from_native_frame(self: Self, df: Any) -> Self: def with_columns(self: Self, *exprs: DaskExpr, **named_exprs: DaskExpr) -> Self: df = self._native_frame - new_series = parse_exprs_and_named_exprs(self)(*exprs, **named_exprs) + new_series = parse_exprs_and_named_exprs(self, *exprs, **named_exprs) df = df.assign(**new_series) return self._from_native_frame(df) @@ -115,7 +115,7 @@ def simple_select(self: Self, *column_names: str) -> Self: ) def select(self: Self, *exprs: DaskExpr, **named_exprs: DaskExpr) -> Self: - new_series = parse_exprs_and_named_exprs(self)(*exprs, **named_exprs) + new_series = parse_exprs_and_named_exprs(self, *exprs, **named_exprs) if not new_series: # return empty dataframe, like Polars does diff --git a/narwhals/_dask/utils.py b/narwhals/_dask/utils.py index 6a2222f95f..da247b78b5 100644 --- a/narwhals/_dask/utils.py +++ b/narwhals/_dask/utils.py @@ -2,7 +2,6 @@ from typing import TYPE_CHECKING from typing import Any -from typing import Callable from narwhals._expression_parsing import evaluate_output_names_and_aliases from narwhals._pandas_like.utils import select_columns_by_name @@ -45,32 +44,29 @@ def maybe_evaluate(df: DaskLazyFrame, obj: Any) -> Any: return obj -def parse_exprs_and_named_exprs(df: DaskLazyFrame) -> Callable[..., dict[str, dx.Series]]: - def func(*exprs: DaskExpr, **named_exprs: DaskExpr) -> dict[str, dx.Series]: - native_results: dict[str, dx.Series] = {} - for expr in exprs: - native_series_list = expr._call(df) - return_scalar = getattr(expr, "_returns_scalar", False) - _, aliases = evaluate_output_names_and_aliases(expr, df, []) - if len(aliases) != len(native_series_list): # pragma: no cover - msg = f"Internal error: got aliases {aliases}, but only got {len(native_series_list)} results" - raise AssertionError(msg) - for native_series, alias in zip(native_series_list, aliases): - native_results[alias] = ( - native_series[0] if return_scalar else native_series - ) - for name, value in named_exprs.items(): - native_series_list = value._call(df) - if len(native_series_list) != 1: # pragma: no cover - msg = "Named expressions must return a single column" - raise AssertionError(msg) - return_scalar = getattr(value, "_returns_scalar", False) - native_results[name] = ( - native_series_list[0][0] if return_scalar else native_series_list[0] - ) - return native_results - - return func +def parse_exprs_and_named_exprs( + df: DaskLazyFrame, /, *exprs: DaskExpr, **named_exprs: DaskExpr +) -> dict[str, dx.Series]: + native_results: dict[str, dx.Series] = {} + for expr in exprs: + native_series_list = expr._call(df) + return_scalar = getattr(expr, "_returns_scalar", False) + _, aliases = evaluate_output_names_and_aliases(expr, df, []) + if len(aliases) != len(native_series_list): # pragma: no cover + msg = f"Internal error: got aliases {aliases}, but only got {len(native_series_list)} results" + raise AssertionError(msg) + for native_series, alias in zip(native_series_list, aliases): + native_results[alias] = 
native_series[0] if return_scalar else native_series + for name, value in named_exprs.items(): + native_series_list = value._call(df) + if len(native_series_list) != 1: # pragma: no cover + msg = "Named expressions must return a single column" + raise AssertionError(msg) + return_scalar = getattr(value, "_returns_scalar", False) + native_results[name] = ( + native_series_list[0][0] if return_scalar else native_series_list[0] + ) + return native_results def add_row_index( diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index c947e088fb..f17c8073a5 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -105,7 +105,7 @@ def select( *exprs: DuckDBExpr, **named_exprs: DuckDBExpr, ) -> Self: - new_columns_map = parse_exprs_and_named_exprs(self)(*exprs, **named_exprs) + new_columns_map = parse_exprs_and_named_exprs(self, *exprs, **named_exprs) if not new_columns_map: # TODO(marco): return empty relation with 0 columns? return self._from_native_frame(self._native_frame.limit(0)) @@ -150,7 +150,7 @@ def with_columns( *exprs: DuckDBExpr, **named_exprs: DuckDBExpr, ) -> Self: - new_columns_map = parse_exprs_and_named_exprs(self)(*exprs, **named_exprs) + new_columns_map = parse_exprs_and_named_exprs(self, *exprs, **named_exprs) if any(expr._expr_kind is ExprKind.AGGREGATION for expr in exprs) or any( expr._expr_kind is ExprKind.AGGREGATION for expr in named_exprs.values() diff --git a/narwhals/_duckdb/utils.py b/narwhals/_duckdb/utils.py index ae0b465311..05c80c2005 100644 --- a/narwhals/_duckdb/utils.py +++ b/narwhals/_duckdb/utils.py @@ -6,7 +6,6 @@ from functools import lru_cache from typing import TYPE_CHECKING from typing import Any -from typing import Callable import duckdb @@ -59,30 +58,25 @@ def maybe_evaluate(df: DuckDBLazyFrame, obj: Any, *, expr_kind: ExprKind) -> Any def parse_exprs_and_named_exprs( - df: DuckDBLazyFrame, -) -> Callable[..., dict[str, duckdb.Expression]]: - def func( - *exprs: DuckDBExpr, **named_exprs: DuckDBExpr - ) -> dict[str, duckdb.Expression]: - native_results: dict[str, duckdb.Expression] = {} - for expr in exprs: - native_series_list = expr._call(df) - output_names = expr._evaluate_output_names(df) - if expr._alias_output_names is not None: - output_names = expr._alias_output_names(output_names) - if len(output_names) != len(native_series_list): # pragma: no cover - msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results" - raise AssertionError(msg) - native_results.update(zip(output_names, native_series_list)) - for col_alias, expr in named_exprs.items(): - native_series_list = expr._call(df) - if len(native_series_list) != 1: # pragma: no cover - msg = "Named expressions must return a single column" - raise ValueError(msg) - native_results[col_alias] = native_series_list[0] - return native_results - - return func + df: DuckDBLazyFrame, /, *exprs: DuckDBExpr, **named_exprs: DuckDBExpr +) -> dict[str, duckdb.Expression]: + native_results: dict[str, duckdb.Expression] = {} + for expr in exprs: + native_series_list = expr._call(df) + output_names = expr._evaluate_output_names(df) + if expr._alias_output_names is not None: + output_names = expr._alias_output_names(output_names) + if len(output_names) != len(native_series_list): # pragma: no cover + msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results" + raise AssertionError(msg) + native_results.update(zip(output_names, native_series_list)) + for col_alias, expr in 
named_exprs.items(): + native_series_list = expr._call(df) + if len(native_series_list) != 1: # pragma: no cover + msg = "Named expressions must return a single column" + raise ValueError(msg) + native_results[col_alias] = native_series_list[0] + return native_results @lru_cache(maxsize=16) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 5fa186dbbc..422614f891 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -62,28 +62,24 @@ def evaluate_into_expr( def evaluate_into_exprs( df: CompliantDataFrame, -) -> Callable[..., list[CompliantSeriesT_co]]: + /, + *exprs: IntoCompliantExpr[CompliantSeriesT_co], + **named_exprs: IntoCompliantExpr[CompliantSeriesT_co], +) -> list[CompliantSeriesT_co]: """Evaluate each expr into Series.""" - - def func( - *exprs: IntoCompliantExpr[CompliantSeriesT_co], - **named_exprs: IntoCompliantExpr[CompliantSeriesT_co], - ) -> list[CompliantSeriesT_co]: - series = [ - item - for sublist in (evaluate_into_expr(df, into_expr) for into_expr in exprs) - for item in sublist - ] - for name, expr in named_exprs.items(): - evaluated_expr = evaluate_into_expr(df, expr) - if len(evaluated_expr) > 1: - msg = "Named expressions must return a single column" # pragma: no cover - raise AssertionError(msg) - to_append = evaluated_expr[0].alias(name) - series.append(to_append) - return series - - return func + series = [ + item + for sublist in (evaluate_into_expr(df, into_expr) for into_expr in exprs) + for item in sublist + ] + for name, expr in named_exprs.items(): + evaluated_expr = evaluate_into_expr(df, expr) + if len(evaluated_expr) > 1: + msg = "Named expressions must return a single column" # pragma: no cover + raise AssertionError(msg) + to_append = evaluated_expr[0].alias(name) + series.append(to_append) + return series def maybe_evaluate_expr( diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 1b4c29b8e5..f07247371c 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -363,8 +363,8 @@ def select( *exprs: IntoPandasLikeExpr, **named_exprs: IntoPandasLikeExpr, ) -> Self: - new_series: list[PandasLikeSeries] = evaluate_into_exprs(self)( - *exprs, **named_exprs + new_series: list[PandasLikeSeries] = evaluate_into_exprs( + self, *exprs, **named_exprs ) if not new_series: # return empty dataframe, like Polars does @@ -435,8 +435,8 @@ def with_columns( **named_exprs: IntoPandasLikeExpr, ) -> Self: index = self._native_frame.index - new_columns: list[PandasLikeSeries] = evaluate_into_exprs(self)( - *exprs, **named_exprs + new_columns: list[PandasLikeSeries] = evaluate_into_exprs( + self, *exprs, **named_exprs ) if not new_columns and len(self) == 0: return self diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py index 62865159fd..12e175494e 100644 --- a/narwhals/_spark_like/dataframe.py +++ b/narwhals/_spark_like/dataframe.py @@ -97,8 +97,8 @@ def select( *exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr, ) -> Self: - new_columns, returns_scalar = parse_exprs_and_named_exprs(self)( - *exprs, **named_exprs + new_columns, returns_scalar = parse_exprs_and_named_exprs( + self, *exprs, **named_exprs ) if not new_columns: @@ -131,8 +131,8 @@ def with_columns( *exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr, ) -> Self: - new_columns, returns_scalar = parse_exprs_and_named_exprs(self)( - *exprs, **named_exprs + new_columns, returns_scalar = parse_exprs_and_named_exprs( + self, *exprs, **named_exprs ) 
new_columns_map = { diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py index ab7298e529..5eb9b304d5 100644 --- a/narwhals/_spark_like/utils.py +++ b/narwhals/_spark_like/utils.py @@ -3,7 +3,6 @@ from functools import lru_cache from typing import TYPE_CHECKING from typing import Any -from typing import Callable from pyspark.sql import Column from pyspark.sql import Window @@ -109,47 +108,42 @@ def narwhals_to_native_dtype( def parse_exprs_and_named_exprs( - df: SparkLikeLazyFrame, -) -> Callable[..., tuple[dict[str, Column], list[bool]]]: - def func( - *exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr - ) -> tuple[dict[str, Column], list[bool]]: - native_results: dict[str, list[Column]] = {} - - # `returns_scalar` keeps track if an expression returns a scalar and is not lit. - # Notice that lit is quite special case, since it gets broadcasted by pyspark - # without the need of adding `.over(Window.partitionBy(F.lit(1)))` - returns_scalar: list[bool] = [] - for expr in exprs: - native_series_list = expr._call(df) - output_names = expr._evaluate_output_names(df) - if expr._alias_output_names is not None: - output_names = expr._alias_output_names(output_names) - if len(output_names) != len(native_series_list): # pragma: no cover - msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results" - raise AssertionError(msg) - native_results.update(zip(output_names, native_series_list)) - returns_scalar.extend( - [ - expr._returns_scalar - and expr._function_name.split("->", maxsplit=1)[0] != "lit" - ] - * len(output_names) - ) - for col_alias, expr in named_exprs.items(): - native_series_list = expr._call(df) - if len(native_series_list) != 1: # pragma: no cover - msg = "Named expressions must return a single column" - raise ValueError(msg) - native_results[col_alias] = native_series_list[0] - returns_scalar.append( + df: SparkLikeLazyFrame, /, *exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr +) -> tuple[dict[str, Column], list[bool]]: + native_results: dict[str, list[Column]] = {} + + # `returns_scalar` keeps track if an expression returns a scalar and is not lit. 
+ # Notice that lit is quite special case, since it gets broadcasted by pyspark + # without the need of adding `.over(Window.partitionBy(F.lit(1)))` + returns_scalar: list[bool] = [] + for expr in exprs: + native_series_list = expr._call(df) + output_names = expr._evaluate_output_names(df) + if expr._alias_output_names is not None: + output_names = expr._alias_output_names(output_names) + if len(output_names) != len(native_series_list): # pragma: no cover + msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results" + raise AssertionError(msg) + native_results.update(zip(output_names, native_series_list)) + returns_scalar.extend( + [ expr._returns_scalar and expr._function_name.split("->", maxsplit=1)[0] != "lit" - ) - - return native_results, returns_scalar - - return func + ] + * len(output_names) + ) + for col_alias, expr in named_exprs.items(): + native_series_list = expr._call(df) + if len(native_series_list) != 1: # pragma: no cover + msg = "Named expressions must return a single column" + raise ValueError(msg) + native_results[col_alias] = native_series_list[0] + returns_scalar.append( + expr._returns_scalar + and expr._function_name.split("->", maxsplit=1)[0] != "lit" + ) + + return native_results, returns_scalar def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any, *, returns_scalar: bool) -> Column: From 8f8849c33ba12384730c816d9a14180b279e3d08 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Tue, 28 Jan 2025 09:26:46 +0000 Subject: [PATCH 04/13] chore: add "typing :: typed" classifier to pyproject.toml (#1881) --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index ffca10965c..fd2860b76f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ classifiers = [ "License :: OSI Approved :: MIT License", "Programming Language :: Python", "Operating System :: OS Independent", + "Typing :: Typed", ] [project.optional-dependencies] From f1c04ac46422d9c3ede987c333ab50808c72ccaf Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Tue, 28 Jan 2025 14:24:23 +0000 Subject: [PATCH 05/13] feat: add `selectors.matches` (#1882) * add selectors matches * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- docs/api-reference/selectors.md | 1 + narwhals/_arrow/selectors.py | 19 ++++++++++ narwhals/_dask/selectors.py | 22 ++++++++++++ narwhals/_duckdb/selectors.py | 20 +++++++++++ narwhals/_pandas_like/selectors.py | 20 +++++++++++ narwhals/_polars/namespace.py | 11 ++++++ narwhals/_spark_like/selectors.py | 18 ++++++++++ narwhals/selectors.py | 56 ++++++++++++++++++++++++++++++ narwhals/stable/v1/selectors.py | 2 ++ tests/selectors_test.py | 18 ++++++++++ 10 files changed, 187 insertions(+) diff --git a/docs/api-reference/selectors.md b/docs/api-reference/selectors.md index 6480a869e4..aa990ab18e 100644 --- a/docs/api-reference/selectors.md +++ b/docs/api-reference/selectors.md @@ -15,6 +15,7 @@ set operations are supported: - boolean - by_dtype - categorical + - matches - numeric - string show_root_heading: false diff --git a/narwhals/_arrow/selectors.py b/narwhals/_arrow/selectors.py index c83f4a6acd..e82a99c0bb 100644 --- a/narwhals/_arrow/selectors.py +++ b/narwhals/_arrow/selectors.py @@ -1,5 +1,6 @@ from __future__ import annotations +import re from typing import TYPE_CHECKING from typing 
import Any from typing import Sequence @@ -43,6 +44,24 @@ def evalute_output_names(df: ArrowDataFrame) -> Sequence[str]: kwargs={"dtypes": dtypes}, ) + def matches(self: Self, pattern: str) -> ArrowSelector: + def func(df: ArrowDataFrame) -> list[ArrowSeries]: + return [df[col] for col in df.columns if re.search(pattern, col)] + + def evalute_output_names(df: ArrowDataFrame) -> Sequence[str]: + return [col for col in df.columns if re.search(pattern, col)] + + return ArrowSelector( + func, + depth=0, + function_name="selector", + evaluate_output_names=evalute_output_names, + alias_output_names=None, + backend_version=self._backend_version, + version=self._version, + kwargs={"pattern": pattern}, + ) + def numeric(self: Self) -> ArrowSelector: dtypes = import_dtypes_module(self._version) return self.by_dtype( diff --git a/narwhals/_dask/selectors.py b/narwhals/_dask/selectors.py index 96f3feb089..8645d77a96 100644 --- a/narwhals/_dask/selectors.py +++ b/narwhals/_dask/selectors.py @@ -1,5 +1,6 @@ from __future__ import annotations +import re from typing import TYPE_CHECKING from typing import Any from typing import Sequence @@ -48,6 +49,27 @@ def evalute_output_names(df: DaskLazyFrame) -> Sequence[str]: kwargs={}, ) + def matches(self: Self, pattern: str) -> DaskSelector: + def func(df: DaskLazyFrame) -> list[dx.Series]: + return [ + df._native_frame[col] for col in df.columns if re.search(pattern, col) + ] + + def evalute_output_names(df: DaskLazyFrame) -> Sequence[str]: + return [col for col in df.columns if re.search(pattern, col)] + + return DaskSelector( + func, + depth=0, + function_name="selector", + evaluate_output_names=evalute_output_names, + alias_output_names=None, + backend_version=self._backend_version, + returns_scalar=False, + version=self._version, + kwargs={}, + ) + def numeric(self: Self) -> DaskSelector: dtypes = import_dtypes_module(self._version) return self.by_dtype( diff --git a/narwhals/_duckdb/selectors.py b/narwhals/_duckdb/selectors.py index 59848a8020..30fd32e338 100644 --- a/narwhals/_duckdb/selectors.py +++ b/narwhals/_duckdb/selectors.py @@ -1,5 +1,6 @@ from __future__ import annotations +import re from typing import TYPE_CHECKING from typing import Any from typing import Sequence @@ -45,6 +46,25 @@ def evalute_output_names(df: DuckDBLazyFrame) -> Sequence[str]: version=self._version, ) + def matches(self: Self, pattern: str) -> DuckDBSelector: + def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: + return [ + ColumnExpression(col) for col in df.columns if re.search(pattern, col) + ] + + def evalute_output_names(df: DuckDBLazyFrame) -> Sequence[str]: + return [col for col in df.columns if re.search(pattern, col)] + + return DuckDBSelector( + func, + function_name="selector", + evaluate_output_names=evalute_output_names, + alias_output_names=None, + backend_version=self._backend_version, + expr_kind=ExprKind.TRANSFORM, + version=self._version, + ) + def numeric(self: Self) -> DuckDBSelector: dtypes = import_dtypes_module(self._version) return self.by_dtype( diff --git a/narwhals/_pandas_like/selectors.py b/narwhals/_pandas_like/selectors.py index d7e18a2e82..2d90adc934 100644 --- a/narwhals/_pandas_like/selectors.py +++ b/narwhals/_pandas_like/selectors.py @@ -1,5 +1,6 @@ from __future__ import annotations +import re from typing import TYPE_CHECKING from typing import Any from typing import Sequence @@ -48,6 +49,25 @@ def evalute_output_names(df: PandasLikeDataFrame) -> Sequence[str]: kwargs={"dtypes": dtypes}, ) + def matches(self: Self, pattern: str) -> 
PandasSelector: + def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: + return [df[col] for col in df.columns if re.search(pattern, col)] + + def evalute_output_names(df: PandasLikeDataFrame) -> Sequence[str]: + return [col for col in df.columns if re.search(pattern, col)] + + return PandasSelector( + func, + depth=0, + function_name="selector", + evaluate_output_names=evalute_output_names, + alias_output_names=None, + implementation=self._implementation, + backend_version=self._backend_version, + version=self._version, + kwargs={"pattern": pattern}, + ) + def numeric(self: Self) -> PandasSelector: dtypes = import_dtypes_module(self._version) return self.by_dtype( diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index 005af3604a..4a73c0edb6 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -217,6 +217,17 @@ def by_dtype(self: Self, dtypes: Iterable[DType]) -> PolarsExpr: backend_version=self._backend_version, ) + def matches(self: Self, pattern: str) -> PolarsExpr: + import polars as pl + + from narwhals._polars.expr import PolarsExpr + + return PolarsExpr( + pl.selectors.matches(pattern=pattern), + version=self._version, + backend_version=self._backend_version, + ) + def numeric(self: Self) -> PolarsExpr: import polars as pl diff --git a/narwhals/_spark_like/selectors.py b/narwhals/_spark_like/selectors.py index b9771bcd7f..897702da08 100644 --- a/narwhals/_spark_like/selectors.py +++ b/narwhals/_spark_like/selectors.py @@ -1,5 +1,6 @@ from __future__ import annotations +import re from typing import TYPE_CHECKING from typing import Any from typing import Sequence @@ -42,6 +43,23 @@ def evalute_output_names(df: SparkLikeLazyFrame) -> Sequence[str]: version=self._version, ) + def matches(self: Self, pattern: str) -> SparkLikeSelector: + def func(df: SparkLikeLazyFrame) -> list[Column]: + return [F.col(col) for col in df.columns if re.search(pattern, col)] + + def evalute_output_names(df: SparkLikeLazyFrame) -> Sequence[str]: + return [col for col in df.columns if re.search(pattern, col)] + + return SparkLikeSelector( + func, + function_name="selector", + evaluate_output_names=evalute_output_names, + alias_output_names=None, + backend_version=self._backend_version, + returns_scalar=False, + version=self._version, + ) + def numeric(self: Self) -> SparkLikeSelector: dtypes = import_dtypes_module(self._version) return self.by_dtype( diff --git a/narwhals/selectors.py b/narwhals/selectors.py index dabf6f83f2..6228030b4b 100644 --- a/narwhals/selectors.py +++ b/narwhals/selectors.py @@ -87,6 +87,62 @@ def by_dtype(*dtypes: Any) -> Expr: ) +def matches(pattern: str) -> Expr: + """Select all columns that match the given regex pattern. + + Arguments: + pattern: A valid regular expression pattern. + + Returns: + A new expression. + + Examples: + >>> import narwhals as nw + >>> import narwhals.selectors as ncs + >>> import pandas as pd + >>> import polars as pl + >>> + >>> data = { + ... "foo": ["x", "y"], + ... "bar": [123, 456], + ... "baz": [2.0, 5.5], + ... "zap": [0, 1], + ... } + >>> df_pd = pd.DataFrame(data) + >>> df_pl = pl.DataFrame(data) + + Let's define a dataframe-agnostic function to select column names + containing an 'a', preceded by a character that is not 'z': + + >>> @nw.narwhalify + ... def func(df): + ... 
return df.select(ncs.matches("[^z]a")) + + We can then pass either pandas or Polars dataframes: + + >>> func(df_pd) + bar baz + 0 123 2.0 + 1 456 5.5 + >>> func(df_pl) + shape: (2, 2) + ┌─────┬─────┐ + │ bar ┆ baz │ + │ --- ┆ --- │ + │ i64 ┆ f64 │ + ╞═════╪═════╡ + │ 123 ┆ 2.0 │ + │ 456 ┆ 5.5 │ + └─────┴─────┘ + """ + return Selector( + lambda plx: plx.selectors.matches(pattern), + is_order_dependent=False, + changes_length=False, + aggregates=False, + ) + + def numeric() -> Expr: """Select numeric columns. diff --git a/narwhals/stable/v1/selectors.py b/narwhals/stable/v1/selectors.py index 0d82484e9b..f938dfa353 100644 --- a/narwhals/stable/v1/selectors.py +++ b/narwhals/stable/v1/selectors.py @@ -4,6 +4,7 @@ from narwhals.selectors import boolean from narwhals.selectors import by_dtype from narwhals.selectors import categorical +from narwhals.selectors import matches from narwhals.selectors import numeric from narwhals.selectors import string @@ -12,6 +13,7 @@ "boolean", "by_dtype", "categorical", + "matches", "numeric", "string", ] diff --git a/tests/selectors_test.py b/tests/selectors_test.py index 9e16da5a19..20a75d22bb 100644 --- a/tests/selectors_test.py +++ b/tests/selectors_test.py @@ -10,6 +10,7 @@ from narwhals.stable.v1.selectors import boolean from narwhals.stable.v1.selectors import by_dtype from narwhals.stable.v1.selectors import categorical +from narwhals.stable.v1.selectors import matches from narwhals.stable.v1.selectors import numeric from narwhals.stable.v1.selectors import string from tests.utils import POLARS_VERSION @@ -25,6 +26,13 @@ "d": [True, False, True], } +data_regex = { + "foo": ["x", "y"], + "bar": [123, 456], + "baz": [2.0, 5.5], + "zap": [0, 1], +} + def test_selectors(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) @@ -33,6 +41,16 @@ def test_selectors(constructor: Constructor) -> None: assert_equal_data(result, expected) +def test_matches(constructor: Constructor) -> None: + df = nw.from_native(constructor(data_regex)) + result = df.select(matches("[^z]a") + 1) + expected = { + "bar": [124, 457], + "baz": [3.0, 6.5], + } + assert_equal_data(result, expected) + + def test_numeric(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(numeric() + 1) From a11147bb2d061d168d8ab545f21820cf668da42e Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Tue, 28 Jan 2025 15:14:23 +0000 Subject: [PATCH 06/13] feat: move robust literal handling for PySpark (#1880) --- narwhals/_spark_like/dataframe.py | 25 +++---- narwhals/_spark_like/expr.py | 111 ++++++++++++++---------------- narwhals/_spark_like/expr_dt.py | 24 +++---- narwhals/_spark_like/expr_name.py | 2 +- narwhals/_spark_like/expr_str.py | 20 +++--- narwhals/_spark_like/namespace.py | 37 +++++----- narwhals/_spark_like/selectors.py | 13 ++-- narwhals/_spark_like/utils.py | 62 +++++++++-------- 8 files changed, 147 insertions(+), 147 deletions(-) diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py index 12e175494e..367b66da34 100644 --- a/narwhals/_spark_like/dataframe.py +++ b/narwhals/_spark_like/dataframe.py @@ -9,6 +9,7 @@ from pyspark.sql import Window from pyspark.sql import functions as F # noqa: N812 +from narwhals._spark_like.utils import ExprKind from narwhals._spark_like.utils import native_to_narwhals_dtype from narwhals._spark_like.utils import parse_exprs_and_named_exprs from narwhals.typing import CompliantLazyFrame @@ -97,9 +98,7 @@ def select( *exprs: SparkLikeExpr, **named_exprs: 
SparkLikeExpr, ) -> Self: - new_columns, returns_scalar = parse_exprs_and_named_exprs( - self, *exprs, **named_exprs - ) + new_columns, expr_kinds = parse_exprs_and_named_exprs(self, *exprs, **named_exprs) if not new_columns: # return empty dataframe, like Polars does @@ -110,7 +109,7 @@ def select( return self._from_native_frame(spark_df) - if all(returns_scalar): + if not any(expr_kind is ExprKind.TRANSFORM for expr_kind in expr_kinds): new_columns_list = [ col.alias(col_name) for col_name, col in new_columns.items() ] @@ -118,11 +117,9 @@ def select( else: new_columns_list = [ col.over(Window.partitionBy(F.lit(1))).alias(col_name) - if _returns_scalar + if expr_kind is ExprKind.AGGREGATION else col.alias(col_name) - for (col_name, col), _returns_scalar in zip( - new_columns.items(), returns_scalar - ) + for (col_name, col), expr_kind in zip(new_columns.items(), expr_kinds) ] return self._from_native_frame(self._native_frame.select(*new_columns_list)) @@ -131,15 +128,13 @@ def with_columns( *exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr, ) -> Self: - new_columns, returns_scalar = parse_exprs_and_named_exprs( - self, *exprs, **named_exprs - ) + new_columns, expr_kinds = parse_exprs_and_named_exprs(self, *exprs, **named_exprs) new_columns_map = { - col_name: col.over(Window.partitionBy(F.lit(1))) if _returns_scalar else col - for (col_name, col), _returns_scalar in zip( - new_columns.items(), returns_scalar - ) + col_name: col.over(Window.partitionBy(F.lit(1))) + if expr_kind is ExprKind.AGGREGATION + else col + for (col_name, col), expr_kind in zip(new_columns.items(), expr_kinds) } return self._from_native_frame(self._native_frame.withColumns(new_columns_map)) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 5483f1e776..36c40f2d16 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -12,8 +12,9 @@ from narwhals._spark_like.expr_dt import SparkLikeExprDateTimeNamespace from narwhals._spark_like.expr_name import SparkLikeExprNameNamespace from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace -from narwhals._spark_like.utils import binary_operation_returns_scalar +from narwhals._spark_like.utils import ExprKind from narwhals._spark_like.utils import maybe_evaluate +from narwhals._spark_like.utils import n_ary_operation_expr_kind from narwhals._spark_like.utils import narwhals_to_native_dtype from narwhals.typing import CompliantExpr from narwhals.utils import Implementation @@ -40,9 +41,7 @@ def __init__( function_name: str, evaluate_output_names: Callable[[SparkLikeLazyFrame], Sequence[str]], alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, - # Whether the expression is a length-1 Column resulting from - # a reduction, such as `nw.col('a').sum()` - returns_scalar: bool, + expr_kind: ExprKind, backend_version: tuple[int, ...], version: Version, ) -> None: @@ -50,7 +49,7 @@ def __init__( self._function_name = function_name self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names - self._returns_scalar = returns_scalar + self._expr_kind = expr_kind self._backend_version = backend_version self._version = version @@ -82,7 +81,7 @@ def func(_: SparkLikeLazyFrame) -> list[Column]: function_name="col", evaluate_output_names=lambda _df: list(column_names), alias_output_names=None, - returns_scalar=False, + expr_kind=ExprKind.TRANSFORM, backend_version=backend_version, version=version, ) @@ -103,7 +102,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: 
function_name="nth", evaluate_output_names=lambda df: [df.columns[i] for i in column_indices], alias_output_names=None, - returns_scalar=False, + expr_kind=ExprKind.TRANSFORM, backend_version=backend_version, version=version, ) @@ -113,13 +112,13 @@ def _from_call( call: Callable[..., Column], expr_name: str, *, - returns_scalar: bool, + expr_kind: ExprKind, **expressifiable_args: Self | Any, ) -> Self: def func(df: SparkLikeLazyFrame) -> list[Column]: native_series_list = self._call(df) other_native_series = { - key: maybe_evaluate(df, value, returns_scalar=returns_scalar) + key: maybe_evaluate(df, value, expr_kind=expr_kind) for key, value in expressifiable_args.items() } return [ @@ -132,7 +131,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: function_name=f"{self._function_name}->{expr_name}", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, - returns_scalar=returns_scalar, + expr_kind=expr_kind, backend_version=self._backend_version, version=self._version, ) @@ -142,7 +141,7 @@ def __eq__(self: Self, other: SparkLikeExpr) -> Self: # type: ignore[override] lambda _input, other: _input.__eq__(other), "__eq__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __ne__(self: Self, other: SparkLikeExpr) -> Self: # type: ignore[override] @@ -150,7 +149,7 @@ def __ne__(self: Self, other: SparkLikeExpr) -> Self: # type: ignore[override] lambda _input, other: _input.__ne__(other), "__ne__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __add__(self: Self, other: SparkLikeExpr) -> Self: @@ -158,7 +157,7 @@ def __add__(self: Self, other: SparkLikeExpr) -> Self: lambda _input, other: _input.__add__(other), "__add__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __sub__(self: Self, other: SparkLikeExpr) -> Self: @@ -166,7 +165,7 @@ def __sub__(self: Self, other: SparkLikeExpr) -> Self: lambda _input, other: _input.__sub__(other), "__sub__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __mul__(self: Self, other: SparkLikeExpr) -> Self: @@ -174,7 +173,7 @@ def __mul__(self: Self, other: SparkLikeExpr) -> Self: lambda _input, other: _input.__mul__(other), "__mul__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __truediv__(self: Self, other: SparkLikeExpr) -> Self: @@ -182,7 +181,7 @@ def __truediv__(self: Self, other: SparkLikeExpr) -> Self: lambda _input, other: _input.__truediv__(other), "__truediv__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __floordiv__(self: Self, other: SparkLikeExpr) -> Self: @@ -193,7 +192,7 @@ def _floordiv(_input: Column, other: Column) -> Column: _floordiv, "__floordiv__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __pow__(self: Self, other: SparkLikeExpr) -> Self: @@ -201,7 +200,7 @@ def __pow__(self: Self, other: SparkLikeExpr) -> Self: lambda _input, other: _input.__pow__(other), "__pow__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, 
other), ) def __mod__(self: Self, other: SparkLikeExpr) -> Self: @@ -209,7 +208,7 @@ def __mod__(self: Self, other: SparkLikeExpr) -> Self: lambda _input, other: _input.__mod__(other), "__mod__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __ge__(self: Self, other: SparkLikeExpr) -> Self: @@ -217,7 +216,7 @@ def __ge__(self: Self, other: SparkLikeExpr) -> Self: lambda _input, other: _input.__ge__(other), "__ge__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __gt__(self: Self, other: SparkLikeExpr) -> Self: @@ -225,7 +224,7 @@ def __gt__(self: Self, other: SparkLikeExpr) -> Self: lambda _input, other: _input > other, "__gt__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __le__(self: Self, other: SparkLikeExpr) -> Self: @@ -233,7 +232,7 @@ def __le__(self: Self, other: SparkLikeExpr) -> Self: lambda _input, other: _input.__le__(other), "__le__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __lt__(self: Self, other: SparkLikeExpr) -> Self: @@ -241,7 +240,7 @@ def __lt__(self: Self, other: SparkLikeExpr) -> Self: lambda _input, other: _input.__lt__(other), "__lt__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __and__(self: Self, other: SparkLikeExpr) -> Self: @@ -249,7 +248,7 @@ def __and__(self: Self, other: SparkLikeExpr) -> Self: lambda _input, other: _input.__and__(other), "__and__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __or__(self: Self, other: SparkLikeExpr) -> Self: @@ -257,18 +256,18 @@ def __or__(self: Self, other: SparkLikeExpr) -> Self: lambda _input, other: _input.__or__(other), "__or__", other=other, - returns_scalar=binary_operation_returns_scalar(self, other), + expr_kind=n_ary_operation_expr_kind(self, other), ) def __invert__(self: Self) -> Self: return self._from_call( lambda _input: _input.__invert__(), "__invert__", - returns_scalar=self._returns_scalar, + expr_kind=self._expr_kind, ) def abs(self: Self) -> Self: - return self._from_call(F.abs, "abs", returns_scalar=self._returns_scalar) + return self._from_call(F.abs, "abs", expr_kind=self._expr_kind) def alias(self: Self, name: str) -> Self: def alias_output_names(names: Sequence[str]) -> Sequence[str]: @@ -282,32 +281,32 @@ def alias_output_names(names: Sequence[str]) -> Sequence[str]: function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=alias_output_names, - returns_scalar=self._returns_scalar, + expr_kind=self._expr_kind, backend_version=self._backend_version, version=self._version, ) def all(self: Self) -> Self: - return self._from_call(F.bool_and, "all", returns_scalar=True) + return self._from_call(F.bool_and, "all", expr_kind=ExprKind.AGGREGATION) def any(self: Self) -> Self: - return self._from_call(F.bool_or, "any", returns_scalar=True) + return self._from_call(F.bool_or, "any", expr_kind=ExprKind.AGGREGATION) def cast(self: Self, dtype: DType | type[DType]) -> Self: def _cast(_input: Column) -> Column: spark_dtype = narwhals_to_native_dtype(dtype, self._version) return _input.cast(spark_dtype) - return self._from_call(_cast, 
"cast", returns_scalar=self._returns_scalar) + return self._from_call(_cast, "cast", expr_kind=self._expr_kind) def count(self: Self) -> Self: - return self._from_call(F.count, "count", returns_scalar=True) + return self._from_call(F.count, "count", expr_kind=ExprKind.AGGREGATION) def max(self: Self) -> Self: - return self._from_call(F.max, "max", returns_scalar=True) + return self._from_call(F.max, "max", expr_kind=ExprKind.AGGREGATION) def mean(self: Self) -> Self: - return self._from_call(F.mean, "mean", returns_scalar=True) + return self._from_call(F.mean, "mean", expr_kind=ExprKind.AGGREGATION) def median(self: Self) -> Self: def _median(_input: Column) -> Column: @@ -319,19 +318,19 @@ def _median(_input: Column) -> Column: return F.median(_input) - return self._from_call(_median, "median", returns_scalar=True) + return self._from_call(_median, "median", expr_kind=ExprKind.AGGREGATION) def min(self: Self) -> Self: - return self._from_call(F.min, "min", returns_scalar=True) + return self._from_call(F.min, "min", expr_kind=ExprKind.AGGREGATION) def null_count(self: Self) -> Self: def _null_count(_input: Column) -> Column: return F.count_if(F.isnull(_input)) - return self._from_call(_null_count, "null_count", returns_scalar=True) + return self._from_call(_null_count, "null_count", expr_kind=ExprKind.AGGREGATION) def sum(self: Self) -> Self: - return self._from_call(F.sum, "sum", returns_scalar=True) + return self._from_call(F.sum, "sum", expr_kind=ExprKind.AGGREGATION) def std(self: Self, ddof: int) -> Self: from functools import partial @@ -342,7 +341,7 @@ def std(self: Self, ddof: int) -> Self: func = partial(_std, ddof=ddof, np_version=parse_version(np.__version__)) - return self._from_call(func, "std", returns_scalar=True) + return self._from_call(func, "std", expr_kind=ExprKind.AGGREGATION) def var(self: Self, ddof: int) -> Self: from functools import partial @@ -353,7 +352,7 @@ def var(self: Self, ddof: int) -> Self: func = partial(_var, ddof=ddof, np_version=parse_version(np.__version__)) - return self._from_call(func, "var", returns_scalar=True) + return self._from_call(func, "var", expr_kind=ExprKind.AGGREGATION) def clip( self: Self, @@ -379,7 +378,7 @@ def _clip(_input: Column, lower_bound: Any, upper_bound: Any) -> Column: "clip", lower_bound=lower_bound, upper_bound=upper_bound, - returns_scalar=self._returns_scalar, + expr_kind=self._expr_kind, ) def is_between( @@ -402,7 +401,7 @@ def _is_between(_input: Column, lower_bound: Any, upper_bound: Any) -> Column: "is_between", lower_bound=lower_bound, upper_bound=upper_bound, - returns_scalar=self._returns_scalar, + expr_kind=self._expr_kind, ) def is_duplicated(self: Self) -> Self: @@ -410,9 +409,7 @@ def _is_duplicated(_input: Column) -> Column: # Create a window spec that treats each value separately. 
return F.count("*").over(Window.partitionBy(_input)) > 1 - return self._from_call( - _is_duplicated, "is_duplicated", returns_scalar=self._returns_scalar - ) + return self._from_call(_is_duplicated, "is_duplicated", expr_kind=self._expr_kind) def is_finite(self: Self) -> Self: def _is_finite(_input: Column) -> Column: @@ -423,9 +420,7 @@ def _is_finite(_input: Column) -> Column: ) return F.when(~F.isnull(_input), is_finite_condition).otherwise(None) - return self._from_call( - _is_finite, "is_finite", returns_scalar=self._returns_scalar - ) + return self._from_call(_is_finite, "is_finite", expr_kind=self._expr_kind) def is_in(self: Self, values: Sequence[Any]) -> Self: def _is_in(_input: Column) -> Column: @@ -434,7 +429,7 @@ def _is_in(_input: Column) -> Column: return self._from_call( _is_in, "is_in", - returns_scalar=self._returns_scalar, + expr_kind=self._expr_kind, ) def is_unique(self: Self) -> Self: @@ -442,16 +437,14 @@ def _is_unique(_input: Column) -> Column: # Create a window spec that treats each value separately return F.count("*").over(Window.partitionBy(_input)) == 1 - return self._from_call( - _is_unique, "is_unique", returns_scalar=self._returns_scalar - ) + return self._from_call(_is_unique, "is_unique", expr_kind=self._expr_kind) def len(self: Self) -> Self: def _len(_input: Column) -> Column: # Use count(*) to count all rows including nulls return F.count("*") - return self._from_call(_len, "len", returns_scalar=True) + return self._from_call(_len, "len", expr_kind=ExprKind.AGGREGATION) def round(self: Self, decimals: int) -> Self: def _round(_input: Column) -> Column: @@ -460,11 +453,11 @@ def _round(_input: Column) -> Column: return self._from_call( _round, "round", - returns_scalar=self._returns_scalar, + expr_kind=self._expr_kind, ) def skew(self: Self) -> Self: - return self._from_call(F.skewness, "skew", returns_scalar=True) + return self._from_call(F.skewness, "skew", expr_kind=ExprKind.AGGREGATION) def n_unique(self: Self) -> Self: from pyspark.sql.types import IntegerType @@ -472,7 +465,7 @@ def n_unique(self: Self) -> Self: def _n_unique(_input: Column) -> Column: return F.count_distinct(_input) + F.max(F.isnull(_input).cast(IntegerType())) - return self._from_call(_n_unique, "n_unique", returns_scalar=True) + return self._from_call(_n_unique, "n_unique", expr_kind=ExprKind.AGGREGATION) def over(self: Self, keys: list[str]) -> Self: def func(df: SparkLikeLazyFrame) -> list[Column]: @@ -485,17 +478,17 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: alias_output_names=self._alias_output_names, backend_version=self._backend_version, version=self._version, - returns_scalar=False, + expr_kind=ExprKind.TRANSFORM, ) def is_null(self: Self) -> Self: - return self._from_call(F.isnull, "is_null", returns_scalar=self._returns_scalar) + return self._from_call(F.isnull, "is_null", expr_kind=self._expr_kind) def is_nan(self: Self) -> Self: def _is_nan(_input: Column) -> Column: return F.when(F.isnull(_input), None).otherwise(F.isnan(_input)) - return self._from_call(_is_nan, "is_nan", returns_scalar=self._returns_scalar) + return self._from_call(_is_nan, "is_nan", expr_kind=self._expr_kind) @property def str(self: Self) -> SparkLikeExprStringNamespace: diff --git a/narwhals/_spark_like/expr_dt.py b/narwhals/_spark_like/expr_dt.py index ad5d58b829..d0229b381f 100644 --- a/narwhals/_spark_like/expr_dt.py +++ b/narwhals/_spark_like/expr_dt.py @@ -19,49 +19,49 @@ def date(self: Self) -> SparkLikeExpr: return self._compliant_expr._from_call( F.to_date, "date", - 
returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def year(self: Self) -> SparkLikeExpr: return self._compliant_expr._from_call( F.year, "year", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def month(self: Self) -> SparkLikeExpr: return self._compliant_expr._from_call( F.month, "month", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def day(self: Self) -> SparkLikeExpr: return self._compliant_expr._from_call( F.day, "day", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def hour(self: Self) -> SparkLikeExpr: return self._compliant_expr._from_call( F.hour, "hour", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def minute(self: Self) -> SparkLikeExpr: return self._compliant_expr._from_call( F.minute, "minute", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def second(self: Self) -> SparkLikeExpr: return self._compliant_expr._from_call( F.second, "second", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def millisecond(self: Self) -> SparkLikeExpr: @@ -71,7 +71,7 @@ def _millisecond(_input: Column) -> Column: return self._compliant_expr._from_call( _millisecond, "millisecond", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def microsecond(self: Self) -> SparkLikeExpr: @@ -81,7 +81,7 @@ def _microsecond(_input: Column) -> Column: return self._compliant_expr._from_call( _microsecond, "microsecond", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def nanosecond(self: Self) -> SparkLikeExpr: @@ -91,14 +91,14 @@ def _nanosecond(_input: Column) -> Column: return self._compliant_expr._from_call( _nanosecond, "nanosecond", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def ordinal_day(self: Self) -> SparkLikeExpr: return self._compliant_expr._from_call( F.dayofyear, "ordinal_day", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def weekday(self: Self) -> SparkLikeExpr: @@ -109,5 +109,5 @@ def _weekday(_input: Column) -> Column: return self._compliant_expr._from_call( _weekday, "weekday", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) diff --git a/narwhals/_spark_like/expr_name.py b/narwhals/_spark_like/expr_name.py index c1febdd7ad..bd3a6021d4 100644 --- a/narwhals/_spark_like/expr_name.py +++ b/narwhals/_spark_like/expr_name.py @@ -61,7 +61,7 @@ def _from_alias_output_names( function_name=self._compliant_expr._function_name, evaluate_output_names=self._compliant_expr._evaluate_output_names, alias_output_names=alias_output_names, - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, ) diff --git a/narwhals/_spark_like/expr_str.py b/narwhals/_spark_like/expr_str.py index ff2223bd91..e67fb14aaa 100644 --- a/narwhals/_spark_like/expr_str.py +++ b/narwhals/_spark_like/expr_str.py @@ -20,7 +20,7 @@ def len_chars(self: Self) -> SparkLikeExpr: return self._compliant_expr._from_call( F.char_length, "len", - 
returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def replace_all( @@ -33,7 +33,7 @@ def func(_input: Column) -> Column: return self._compliant_expr._from_call( func, "replace", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def strip_chars(self: Self, characters: str | None) -> SparkLikeExpr: @@ -46,21 +46,21 @@ def func(_input: Column) -> Column: return self._compliant_expr._from_call( func, "strip", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def starts_with(self: Self, prefix: str) -> SparkLikeExpr: return self._compliant_expr._from_call( lambda _input: F.startswith(_input, F.lit(prefix)), "starts_with", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def ends_with(self: Self, suffix: str) -> SparkLikeExpr: return self._compliant_expr._from_call( lambda _input: F.endswith(_input, F.lit(suffix)), "ends_with", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def contains(self: Self, pattern: str, *, literal: bool) -> SparkLikeExpr: @@ -71,7 +71,7 @@ def func(_input: Column) -> Column: return self._compliant_expr._from_call( func, "contains", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def slice(self: Self, offset: int, length: int | None) -> SparkLikeExpr: @@ -87,21 +87,21 @@ def func(_input: Column) -> Column: return self._compliant_expr._from_call( func, "slice", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def to_uppercase(self: Self) -> SparkLikeExpr: return self._compliant_expr._from_call( F.upper, "to_uppercase", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def to_lowercase(self: Self) -> SparkLikeExpr: return self._compliant_expr._from_call( F.lower, "to_lowercase", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) def to_datetime(self: Self, format: str | None) -> SparkLikeExpr: # noqa: A002 @@ -111,7 +111,7 @@ def to_datetime(self: Self, format: str | None) -> SparkLikeExpr: # noqa: A002 format=strptime_to_pyspark_format(format), ), "to_datetime", - returns_scalar=self._compliant_expr._returns_scalar, + expr_kind=self._compliant_expr._expr_kind, ) diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 5b0ce71325..f39cd29da0 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -17,6 +17,8 @@ from narwhals._spark_like.dataframe import SparkLikeLazyFrame from narwhals._spark_like.expr import SparkLikeExpr from narwhals._spark_like.selectors import SparkLikeSelectorNamespace +from narwhals._spark_like.utils import ExprKind +from narwhals._spark_like.utils import n_ary_operation_expr_kind from narwhals.typing import CompliantNamespace if TYPE_CHECKING: @@ -50,7 +52,7 @@ def _all(df: SparkLikeLazyFrame) -> list[Column]: function_name="all", evaluate_output_names=lambda df: df.columns, alias_output_names=None, - returns_scalar=False, + expr_kind=ExprKind.TRANSFORM, backend_version=self._backend_version, version=self._version, ) @@ -80,7 +82,7 @@ def _lit(_: SparkLikeLazyFrame) -> list[Column]: function_name="lit", evaluate_output_names=lambda _df: ["literal"], alias_output_names=None, - returns_scalar=True, + 
expr_kind=ExprKind.LITERAL, backend_version=self._backend_version, version=self._version, ) @@ -94,7 +96,7 @@ def func(_: SparkLikeLazyFrame) -> list[Column]: function_name="len", evaluate_output_names=lambda _df: ["len"], alias_output_names=None, - returns_scalar=True, + expr_kind=ExprKind.AGGREGATION, backend_version=self._backend_version, version=self._version, ) @@ -109,7 +111,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: function_name="all_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), - returns_scalar=False, + expr_kind=n_ary_operation_expr_kind(*exprs), backend_version=self._backend_version, version=self._version, ) @@ -124,7 +126,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: function_name="any_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), - returns_scalar=False, + expr_kind=n_ary_operation_expr_kind(*exprs), backend_version=self._backend_version, version=self._version, ) @@ -144,7 +146,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: function_name="sum_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), - returns_scalar=False, + expr_kind=n_ary_operation_expr_kind(*exprs), backend_version=self._backend_version, version=self._version, ) @@ -167,7 +169,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: function_name="mean_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), - returns_scalar=False, + expr_kind=n_ary_operation_expr_kind(*exprs), backend_version=self._backend_version, version=self._version, ) @@ -182,7 +184,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: function_name="max_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), - returns_scalar=False, + expr_kind=n_ary_operation_expr_kind(*exprs), backend_version=self._backend_version, version=self._version, ) @@ -197,7 +199,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: function_name="min_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), - returns_scalar=False, + expr_kind=n_ary_operation_expr_kind(*exprs), backend_version=self._backend_version, version=self._version, ) @@ -289,7 +291,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: function_name="concat_str", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), - returns_scalar=False, + expr_kind=n_ary_operation_expr_kind(*exprs), backend_version=self._backend_version, version=self._version, ) @@ -298,7 +300,10 @@ def when(self: Self, *predicates: SparkLikeExpr) -> SparkLikeWhen: plx = self.__class__(backend_version=self._backend_version, version=self._version) condition = plx.all_horizontal(*predicates) return SparkLikeWhen( - condition, self._backend_version, returns_scalar=False, version=self._version + condition, + self._backend_version, + expr_kind=ExprKind.TRANSFORM, + version=self._version, ) @@ -310,14 +315,14 @@ def __init__( then_value: Any | None = None, otherwise_value: Any | None = None, *, - returns_scalar: bool, + expr_kind: ExprKind, version: Version, ) -> None: self._backend_version = backend_version self._condition = condition self._then_value = then_value self._otherwise_value = 
otherwise_value - self._returns_scalar = returns_scalar + self._expr_kind = expr_kind self._version = version def __call__(self: Self, df: SparkLikeLazyFrame) -> list[Column]: @@ -347,7 +352,7 @@ def then(self: Self, value: SparkLikeExpr | Any) -> SparkLikeThen: value, "_evaluate_output_names", lambda _df: ["literal"] ), alias_output_names=getattr(value, "_alias_output_names", None), - returns_scalar=self._returns_scalar, + expr_kind=self._expr_kind, backend_version=self._backend_version, version=self._version, ) @@ -361,7 +366,7 @@ def __init__( function_name: str, evaluate_output_names: Callable[[SparkLikeLazyFrame], Sequence[str]], alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, - returns_scalar: bool, + expr_kind: ExprKind, backend_version: tuple[int, ...], version: Version, ) -> None: @@ -371,7 +376,7 @@ def __init__( self._function_name = function_name self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names - self._returns_scalar = returns_scalar + self._expr_kind = expr_kind def otherwise(self: Self, value: SparkLikeExpr | Any) -> SparkLikeExpr: # type ignore because we are setting the `_call` attribute to a diff --git a/narwhals/_spark_like/selectors.py b/narwhals/_spark_like/selectors.py index 897702da08..2e33b7bc63 100644 --- a/narwhals/_spark_like/selectors.py +++ b/narwhals/_spark_like/selectors.py @@ -8,6 +8,7 @@ from pyspark.sql import functions as F # noqa: N812 from narwhals._spark_like.expr import SparkLikeExpr +from narwhals._spark_like.utils import ExprKind from narwhals.utils import import_dtypes_module if TYPE_CHECKING: @@ -39,7 +40,7 @@ def evalute_output_names(df: SparkLikeLazyFrame) -> Sequence[str]: evaluate_output_names=evalute_output_names, alias_output_names=None, backend_version=self._backend_version, - returns_scalar=False, + expr_kind=ExprKind.TRANSFORM, version=self._version, ) @@ -101,7 +102,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: evaluate_output_names=lambda df: df.columns, alias_output_names=None, backend_version=self._backend_version, - returns_scalar=False, + expr_kind=ExprKind.TRANSFORM, version=self._version, ) @@ -117,7 +118,7 @@ def _to_expr(self: Self) -> SparkLikeExpr: evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, backend_version=self._backend_version, - returns_scalar=self._returns_scalar, + expr_kind=self._expr_kind, version=self._version, ) @@ -141,7 +142,7 @@ def evaluate_output_names(df: SparkLikeLazyFrame) -> list[str]: evaluate_output_names=evaluate_output_names, alias_output_names=None, backend_version=self._backend_version, - returns_scalar=self._returns_scalar, + expr_kind=self._expr_kind, version=self._version, ) else: @@ -171,7 +172,7 @@ def evaluate_output_names(df: SparkLikeLazyFrame) -> list[str]: evaluate_output_names=evaluate_output_names, alias_output_names=None, backend_version=self._backend_version, - returns_scalar=self._returns_scalar, + expr_kind=self._expr_kind, version=self._version, ) else: @@ -197,7 +198,7 @@ def evaluate_output_names(df: SparkLikeLazyFrame) -> list[str]: evaluate_output_names=evaluate_output_names, alias_output_names=None, backend_version=self._backend_version, - returns_scalar=self._returns_scalar, + expr_kind=self._expr_kind, version=self._version, ) else: diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py index 5eb9b304d5..a29d96add3 100644 --- a/narwhals/_spark_like/utils.py +++ b/narwhals/_spark_like/utils.py @@ -1,5 +1,7 @@ from __future__ import 
annotations +from enum import Enum +from enum import auto from functools import lru_cache from typing import TYPE_CHECKING from typing import Any @@ -22,6 +24,21 @@ from narwhals.utils import Version +class ExprKind(Enum): + """Describe which kind of expression we are dealing with. + + Composition rule is: + - LITERAL vs LITERAL -> LITERAL + - TRANSFORM vs anything -> TRANSFORM + - anything vs TRANSFORM -> TRANSFORM + - all remaining cases -> AGGREGATION + """ + + LITERAL = auto() # e.g. nw.lit(1) + AGGREGATION = auto() # e.g. nw.col('a').mean() + TRANSFORM = auto() # e.g. nw.col('a').round() + + @lru_cache(maxsize=16) def native_to_narwhals_dtype( dtype: pyspark_types.DataType, @@ -109,13 +126,10 @@ def narwhals_to_native_dtype( def parse_exprs_and_named_exprs( df: SparkLikeLazyFrame, /, *exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr -) -> tuple[dict[str, Column], list[bool]]: +) -> tuple[dict[str, Column], list[ExprKind]]: native_results: dict[str, list[Column]] = {} - # `returns_scalar` keeps track if an expression returns a scalar and is not lit. - # Notice that lit is quite special case, since it gets broadcasted by pyspark - # without the need of adding `.over(Window.partitionBy(F.lit(1)))` - returns_scalar: list[bool] = [] + expr_kinds: list[ExprKind] = [] for expr in exprs: native_series_list = expr._call(df) output_names = expr._evaluate_output_names(df) @@ -125,28 +139,19 @@ def parse_exprs_and_named_exprs( msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results" raise AssertionError(msg) native_results.update(zip(output_names, native_series_list)) - returns_scalar.extend( - [ - expr._returns_scalar - and expr._function_name.split("->", maxsplit=1)[0] != "lit" - ] - * len(output_names) - ) + expr_kinds.extend([expr._expr_kind] * len(output_names)) for col_alias, expr in named_exprs.items(): native_series_list = expr._call(df) if len(native_series_list) != 1: # pragma: no cover msg = "Named expressions must return a single column" raise ValueError(msg) native_results[col_alias] = native_series_list[0] - returns_scalar.append( - expr._returns_scalar - and expr._function_name.split("->", maxsplit=1)[0] != "lit" - ) + expr_kinds.append(expr._expr_kind) - return native_results, returns_scalar + return native_results, expr_kinds -def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any, *, returns_scalar: bool) -> Column: +def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any, *, expr_kind: ExprKind) -> Column: from narwhals._spark_like.expr import SparkLikeExpr if isinstance(obj, SparkLikeExpr): @@ -155,11 +160,7 @@ def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any, *, returns_scalar: bool) -> msg = "Multi-output expressions (e.g. `nw.all()` or `nw.col('a', 'b')`) not supported in this context" raise NotImplementedError(msg) column_result = column_results[0] - if ( - obj._returns_scalar - and obj._function_name.split("->", maxsplit=1)[0] != "lit" - and not returns_scalar - ): + if obj._expr_kind is ExprKind.AGGREGATION and expr_kind is ExprKind.TRANSFORM: # Returns scalar, but overall expression doesn't. # Let PySpark do its broadcasting return column_result.over(Window.partitionBy(F.lit(1))) @@ -195,8 +196,13 @@ def _var(_input: Column | str, ddof: int, np_version: tuple[int, ...]) -> Column return var(input_col, ddof=ddof) -def binary_operation_returns_scalar(lhs: SparkLikeExpr, rhs: SparkLikeExpr | Any) -> bool: - # If `rhs` is a SparkLikeExpr, we look at `_returns_scalar`. If it isn't, - # it means that it was a scalar (e.g. 
nw.col('a') + 1), and so we default - # to `True`. - return lhs._returns_scalar and getattr(rhs, "_returns_scalar", True) +def n_ary_operation_expr_kind(*args: SparkLikeExpr | Any) -> ExprKind: + if all( + getattr(arg, "_expr_kind", ExprKind.LITERAL) is ExprKind.LITERAL for arg in args + ): + return ExprKind.LITERAL + if any( + getattr(arg, "_expr_kind", ExprKind.LITERAL) is ExprKind.TRANSFORM for arg in args + ): + return ExprKind.TRANSFORM + return ExprKind.AGGREGATION From 97483094f20a2bb6740f5e8792ae8aec4f96a34b Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Tue, 28 Jan 2025 16:58:42 +0000 Subject: [PATCH 07/13] feat: start (silently) adding support for SQLFrame (#1883) --- .github/workflows/mkdocs.yml | 2 +- narwhals/_spark_like/dataframe.py | 88 ++++++++++++------- narwhals/_spark_like/expr.py | 139 +++++++++++++++++++++--------- narwhals/_spark_like/expr_dt.py | 28 +++--- narwhals/_spark_like/expr_name.py | 1 + narwhals/_spark_like/expr_str.py | 62 +++++++++---- narwhals/_spark_like/group_by.py | 3 +- narwhals/_spark_like/namespace.py | 106 +++++++++++++++-------- narwhals/_spark_like/selectors.py | 29 +++++-- narwhals/_spark_like/utils.py | 85 +++++++++--------- narwhals/dependencies.py | 14 +++ narwhals/translate.py | 28 +++++- narwhals/utils.py | 3 + pyproject.toml | 4 +- tests/conftest.py | 14 +++ 15 files changed, 415 insertions(+), 191 deletions(-) diff --git a/.github/workflows/mkdocs.yml b/.github/workflows/mkdocs.yml index 2e87e892ab..80930ea9c2 100644 --- a/.github/workflows/mkdocs.yml +++ b/.github/workflows/mkdocs.yml @@ -29,6 +29,6 @@ jobs: - name: griffe # hopefully temporary until https://github.com/mkdocstrings/mkdocstrings/issues/716 run: pip install git+https://github.com/MarcoGorelli/griffe.git@no-overloads - - run: pip install -e .[docs,pyspark,dask,duckdb] + - run: pip install -e .[docs,dask,duckdb] - run: mkdocs gh-deploy --force diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py index 367b66da34..13945b0606 100644 --- a/narwhals/_spark_like/dataframe.py +++ b/narwhals/_spark_like/dataframe.py @@ -6,9 +6,6 @@ from typing import Literal from typing import Sequence -from pyspark.sql import Window -from pyspark.sql import functions as F # noqa: N812 - from narwhals._spark_like.utils import ExprKind from narwhals._spark_like.utils import native_to_narwhals_dtype from narwhals._spark_like.utils import parse_exprs_and_named_exprs @@ -40,25 +37,54 @@ def __init__( *, backend_version: tuple[int, ...], version: Version, + implementation: Implementation, ) -> None: self._native_frame = native_dataframe self._backend_version = backend_version - self._implementation = Implementation.PYSPARK + self._implementation = implementation self._version = version validate_backend_version(self._implementation, self._backend_version) - def __native_namespace__(self: Self) -> ModuleType: # pragma: no cover - if self._implementation is Implementation.PYSPARK: - return self._implementation.to_native_namespace() + @property + def _F(self) -> Any: # noqa: N802 + if self._implementation is Implementation.SQLFRAME: + from sqlframe.duckdb import functions + + return functions + from pyspark.sql import functions + + return functions + + @property + def _native_dtypes(self) -> Any: + if self._implementation is Implementation.SQLFRAME: + from sqlframe.duckdb import types + + return types + from pyspark.sql import types + + return types + + @property + def _Window(self) -> Any: # noqa: N802 + if self._implementation is Implementation.SQLFRAME: + from 
sqlframe.duckdb import Window - msg = f"Expected pyspark, got: {type(self._implementation)}" # pragma: no cover - raise AssertionError(msg) + return Window + from pyspark.sql import Window + + return Window + + def __native_namespace__(self: Self) -> ModuleType: # pragma: no cover + return self._implementation.to_native_namespace() def __narwhals_namespace__(self: Self) -> SparkLikeNamespace: from narwhals._spark_like.namespace import SparkLikeNamespace return SparkLikeNamespace( - backend_version=self._backend_version, version=self._version + backend_version=self._backend_version, + version=self._version, + implementation=self._implementation, ) def __narwhals_lazyframe__(self: Self) -> Self: @@ -66,12 +92,18 @@ def __narwhals_lazyframe__(self: Self) -> Self: def _change_version(self: Self, version: Version) -> Self: return self.__class__( - self._native_frame, backend_version=self._backend_version, version=version + self._native_frame, + backend_version=self._backend_version, + version=version, + implementation=self._implementation, ) def _from_native_frame(self: Self, df: DataFrame) -> Self: return self.__class__( - df, backend_version=self._backend_version, version=self._version + df, + backend_version=self._backend_version, + version=self._version, + implementation=self._implementation, ) @property @@ -102,10 +134,10 @@ def select( if not new_columns: # return empty dataframe, like Polars does - from pyspark.sql.types import StructType - spark_session = self._native_frame.sparkSession - spark_df = spark_session.createDataFrame([], StructType([])) + spark_df = spark_session.createDataFrame( + [], self._native_dtypes.StructType([]) + ) return self._from_native_frame(spark_df) @@ -116,7 +148,7 @@ def select( return self._from_native_frame(self._native_frame.agg(*new_columns_list)) else: new_columns_list = [ - col.over(Window.partitionBy(F.lit(1))).alias(col_name) + col.over(self._Window().partitionBy(self._F.lit(1))).alias(col_name) if expr_kind is ExprKind.AGGREGATION else col.alias(col_name) for (col_name, col), expr_kind in zip(new_columns.items(), expr_kinds) @@ -131,7 +163,7 @@ def with_columns( new_columns, expr_kinds = parse_exprs_and_named_exprs(self, *exprs, **named_exprs) new_columns_map = { - col_name: col.over(Window.partitionBy(F.lit(1))) + col_name: col.over(self._Window().partitionBy(self._F.lit(1))) if expr_kind is ExprKind.AGGREGATION else col for (col_name, col), expr_kind in zip(new_columns.items(), expr_kinds) @@ -152,7 +184,9 @@ def filter(self: Self, *predicates: SparkLikeExpr, **constraints: Any) -> Self: def schema(self: Self) -> dict[str, DType]: return { field.name: native_to_narwhals_dtype( - dtype=field.dataType, version=self._version + dtype=field.dataType, + version=self._version, + spark_types=self._native_dtypes, ) for field in self._native_frame.schema } @@ -186,18 +220,18 @@ def sort( descending: bool | Sequence[bool], nulls_last: bool, ) -> Self: - import pyspark.sql.functions as F # noqa: N812 - if isinstance(descending, bool): descending = [descending] * len(by) if nulls_last: sort_funcs = ( - F.desc_nulls_last if d else F.asc_nulls_last for d in descending + self._F.desc_nulls_last if d else self._F.asc_nulls_last + for d in descending ) else: sort_funcs = ( - F.desc_nulls_first if d else F.asc_nulls_first for d in descending + self._F.desc_nulls_first if d else self._F.asc_nulls_first + for d in descending ) sort_cols = [sort_f(col) for col, sort_f in zip(by, sort_funcs)] @@ -207,14 +241,12 @@ def drop_nulls(self: Self, subset: list[str] | None) -> 
Self: return self._from_native_frame(self._native_frame.dropna(subset=subset)) def rename(self: Self, mapping: dict[str, str]) -> Self: - import pyspark.sql.functions as F # noqa: N812 - rename_mapping = { colname: mapping.get(colname, colname) for colname in self.columns } return self._from_native_frame( self._native_frame.select( - [F.col(old).alias(new) for old, new in rename_mapping.items()] + [self._F.col(old).alias(new) for old, new in rename_mapping.items()] ) ) @@ -238,8 +270,6 @@ def join( right_on: str | list[str] | None, suffix: str, ) -> Self: - import pyspark.sql.functions as F # noqa: N812 - self_native = self._native_frame other_native = other._native_frame @@ -262,7 +292,7 @@ def join( }, } other = other_native.select( - [F.col(old).alias(new) for old, new in rename_mapping.items()] + [self._F.col(old).alias(new) for old, new in rename_mapping.items()] ) # If how in {"semi", "anti"}, then resulting columns are same as left columns @@ -280,5 +310,5 @@ def join( ) return self._from_native_frame( - self_native.join(other=other, on=left_on, how=how).select(col_order) + self_native.join(other, on=left_on, how=how).select(col_order) ) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 36c40f2d16..f3a9de7f02 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -6,9 +6,6 @@ from typing import Literal from typing import Sequence -from pyspark.sql import Window -from pyspark.sql import functions as F # noqa: N812 - from narwhals._spark_like.expr_dt import SparkLikeExprDateTimeNamespace from narwhals._spark_like.expr_name import SparkLikeExprNameNamespace from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace @@ -31,7 +28,6 @@ class SparkLikeExpr(CompliantExpr["Column"]): - _implementation = Implementation.PYSPARK _depth = 0 # Unused, just for compatibility with CompliantExpr def __init__( @@ -44,6 +40,7 @@ def __init__( expr_kind: ExprKind, backend_version: tuple[int, ...], version: Version, + implementation: Implementation, ) -> None: self._call = call self._function_name = function_name @@ -52,10 +49,41 @@ def __init__( self._expr_kind = expr_kind self._backend_version = backend_version self._version = version + self._implementation = implementation def __call__(self: Self, df: SparkLikeLazyFrame) -> Sequence[Column]: return self._call(df) + @property + def _F(self) -> Any: # noqa: N802 + if self._implementation is Implementation.SQLFRAME: + from sqlframe.duckdb import functions + + return functions + from pyspark.sql import functions + + return functions + + @property + def _native_types(self) -> Any: + if self._implementation is Implementation.SQLFRAME: + from sqlframe.duckdb import types + + return types + from pyspark.sql import types + + return types + + @property + def _Window(self) -> Any: # noqa: N802 + if self._implementation is Implementation.SQLFRAME: + from sqlframe.duckdb import Window + + return Window + from pyspark.sql import Window + + return Window + def __narwhals_expr__(self: Self) -> None: ... 
def __narwhals_namespace__(self: Self) -> SparkLikeNamespace: # pragma: no cover @@ -63,7 +91,9 @@ def __narwhals_namespace__(self: Self) -> SparkLikeNamespace: # pragma: no cove from narwhals._spark_like.namespace import SparkLikeNamespace return SparkLikeNamespace( - backend_version=self._backend_version, version=self._version + backend_version=self._backend_version, + version=self._version, + implementation=self._implementation, ) @classmethod @@ -72,9 +102,10 @@ def from_column_names( *column_names: str, backend_version: tuple[int, ...], version: Version, + implementation: Implementation, ) -> Self: - def func(_: SparkLikeLazyFrame) -> list[Column]: - return [F.col(col_name) for col_name in column_names] + def func(df: SparkLikeLazyFrame) -> list[Column]: + return [df._F.col(col_name) for col_name in column_names] return cls( func, @@ -84,6 +115,7 @@ def func(_: SparkLikeLazyFrame) -> list[Column]: expr_kind=ExprKind.TRANSFORM, backend_version=backend_version, version=version, + implementation=implementation, ) @classmethod @@ -92,10 +124,11 @@ def from_column_indices( *column_indices: int, backend_version: tuple[int, ...], version: Version, + implementation: Implementation, ) -> Self: def func(df: SparkLikeLazyFrame) -> list[Column]: columns = df.columns - return [F.col(columns[i]) for i in column_indices] + return [df._F.col(columns[i]) for i in column_indices] return cls( func, @@ -105,6 +138,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: expr_kind=ExprKind.TRANSFORM, backend_version=backend_version, version=version, + implementation=implementation, ) def _from_call( @@ -134,6 +168,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: expr_kind=expr_kind, backend_version=self._backend_version, version=self._version, + implementation=self._implementation, ) def __eq__(self: Self, other: SparkLikeExpr) -> Self: # type: ignore[override] @@ -186,7 +221,7 @@ def __truediv__(self: Self, other: SparkLikeExpr) -> Self: def __floordiv__(self: Self, other: SparkLikeExpr) -> Self: def _floordiv(_input: Column, other: Column) -> Column: - return F.floor(_input / other) + return self._F.floor(_input / other) return self._from_call( _floordiv, @@ -267,7 +302,7 @@ def __invert__(self: Self) -> Self: ) def abs(self: Self) -> Self: - return self._from_call(F.abs, "abs", expr_kind=self._expr_kind) + return self._from_call(self._F.abs, "abs", expr_kind=self._expr_kind) def alias(self: Self, name: str) -> Self: def alias_output_names(names: Sequence[str]) -> Sequence[str]: @@ -284,29 +319,32 @@ def alias_output_names(names: Sequence[str]) -> Sequence[str]: expr_kind=self._expr_kind, backend_version=self._backend_version, version=self._version, + implementation=self._implementation, ) def all(self: Self) -> Self: - return self._from_call(F.bool_and, "all", expr_kind=ExprKind.AGGREGATION) + return self._from_call(self._F.bool_and, "all", expr_kind=ExprKind.AGGREGATION) def any(self: Self) -> Self: - return self._from_call(F.bool_or, "any", expr_kind=ExprKind.AGGREGATION) + return self._from_call(self._F.bool_or, "any", expr_kind=ExprKind.AGGREGATION) def cast(self: Self, dtype: DType | type[DType]) -> Self: def _cast(_input: Column) -> Column: - spark_dtype = narwhals_to_native_dtype(dtype, self._version) + spark_dtype = narwhals_to_native_dtype( + dtype, self._version, self._native_types + ) return _input.cast(spark_dtype) return self._from_call(_cast, "cast", expr_kind=self._expr_kind) def count(self: Self) -> Self: - return self._from_call(F.count, "count", expr_kind=ExprKind.AGGREGATION) + 
return self._from_call(self._F.count, "count", expr_kind=ExprKind.AGGREGATION) def max(self: Self) -> Self: - return self._from_call(F.max, "max", expr_kind=ExprKind.AGGREGATION) + return self._from_call(self._F.max, "max", expr_kind=ExprKind.AGGREGATION) def mean(self: Self) -> Self: - return self._from_call(F.mean, "mean", expr_kind=ExprKind.AGGREGATION) + return self._from_call(self._F.mean, "mean", expr_kind=ExprKind.AGGREGATION) def median(self: Self) -> Self: def _median(_input: Column) -> Column: @@ -314,23 +352,23 @@ def _median(_input: Column) -> Column: if parse_version(pyspark.__version__) < (3, 4): # Use percentile_approx with default accuracy parameter (10000) - return F.percentile_approx(_input.cast("double"), 0.5) + return self._F.percentile_approx(_input.cast("double"), 0.5) - return F.median(_input) + return self._F.median(_input) return self._from_call(_median, "median", expr_kind=ExprKind.AGGREGATION) def min(self: Self) -> Self: - return self._from_call(F.min, "min", expr_kind=ExprKind.AGGREGATION) + return self._from_call(self._F.min, "min", expr_kind=ExprKind.AGGREGATION) def null_count(self: Self) -> Self: def _null_count(_input: Column) -> Column: - return F.count_if(F.isnull(_input)) + return self._F.count_if(self._F.isnull(_input)) return self._from_call(_null_count, "null_count", expr_kind=ExprKind.AGGREGATION) def sum(self: Self) -> Self: - return self._from_call(F.sum, "sum", expr_kind=ExprKind.AGGREGATION) + return self._from_call(self._F.sum, "sum", expr_kind=ExprKind.AGGREGATION) def std(self: Self, ddof: int) -> Self: from functools import partial @@ -339,7 +377,12 @@ def std(self: Self, ddof: int) -> Self: from narwhals._spark_like.utils import _std - func = partial(_std, ddof=ddof, np_version=parse_version(np.__version__)) + func = partial( + _std, + ddof=ddof, + np_version=parse_version(np.__version__), + functions=self._F, + ) return self._from_call(func, "std", expr_kind=ExprKind.AGGREGATION) @@ -350,7 +393,12 @@ def var(self: Self, ddof: int) -> Self: from narwhals._spark_like.utils import _var - func = partial(_var, ddof=ddof, np_version=parse_version(np.__version__)) + func = partial( + _var, + ddof=ddof, + np_version=parse_version(np.__version__), + functions=self._F, + ) return self._from_call(func, "var", expr_kind=ExprKind.AGGREGATION) @@ -363,14 +411,14 @@ def _clip(_input: Column, lower_bound: Any, upper_bound: Any) -> Column: result = _input if lower_bound is not None: # Convert lower_bound to a literal Column - result = F.when(result < lower_bound, F.lit(lower_bound)).otherwise( - result - ) + result = self._F.when( + result < lower_bound, self._F.lit(lower_bound) + ).otherwise(result) if upper_bound is not None: # Convert upper_bound to a literal Column - result = F.when(result > upper_bound, F.lit(upper_bound)).otherwise( - result - ) + result = self._F.when( + result > upper_bound, self._F.lit(upper_bound) + ).otherwise(result) return result return self._from_call( @@ -407,7 +455,7 @@ def _is_between(_input: Column, lower_bound: Any, upper_bound: Any) -> Column: def is_duplicated(self: Self) -> Self: def _is_duplicated(_input: Column) -> Column: # Create a window spec that treats each value separately. 
- return F.count("*").over(Window.partitionBy(_input)) > 1 + return self._F.count("*").over(self._Window.partitionBy(_input)) > 1 return self._from_call(_is_duplicated, "is_duplicated", expr_kind=self._expr_kind) @@ -416,9 +464,13 @@ def _is_finite(_input: Column) -> Column: # A value is finite if it's not NaN, and not infinite, while NULLs should be # preserved is_finite_condition = ( - ~F.isnan(_input) & (_input != float("inf")) & (_input != float("-inf")) + ~self._F.isnan(_input) + & (_input != self._F.lit(float("inf"))) + & (_input != self._F.lit(float("-inf"))) + ) + return self._F.when(~self._F.isnull(_input), is_finite_condition).otherwise( + None ) - return F.when(~F.isnull(_input), is_finite_condition).otherwise(None) return self._from_call(_is_finite, "is_finite", expr_kind=self._expr_kind) @@ -435,20 +487,20 @@ def _is_in(_input: Column) -> Column: def is_unique(self: Self) -> Self: def _is_unique(_input: Column) -> Column: # Create a window spec that treats each value separately - return F.count("*").over(Window.partitionBy(_input)) == 1 + return self._F.count("*").over(self._Window.partitionBy(_input)) == 1 return self._from_call(_is_unique, "is_unique", expr_kind=self._expr_kind) def len(self: Self) -> Self: def _len(_input: Column) -> Column: # Use count(*) to count all rows including nulls - return F.count("*") + return self._F.count("*") return self._from_call(_len, "len", expr_kind=ExprKind.AGGREGATION) def round(self: Self, decimals: int) -> Self: def _round(_input: Column) -> Column: - return F.round(_input, decimals) + return self._F.round(_input, decimals) return self._from_call( _round, @@ -457,19 +509,19 @@ def _round(_input: Column) -> Column: ) def skew(self: Self) -> Self: - return self._from_call(F.skewness, "skew", expr_kind=ExprKind.AGGREGATION) + return self._from_call(self._F.skewness, "skew", expr_kind=ExprKind.AGGREGATION) def n_unique(self: Self) -> Self: - from pyspark.sql.types import IntegerType - def _n_unique(_input: Column) -> Column: - return F.count_distinct(_input) + F.max(F.isnull(_input).cast(IntegerType())) + return self._F.count_distinct(_input) + self._F.max( + self._F.isnull(_input).cast(self._native_types.IntegerType()) + ) return self._from_call(_n_unique, "n_unique", expr_kind=ExprKind.AGGREGATION) def over(self: Self, keys: list[str]) -> Self: def func(df: SparkLikeLazyFrame) -> list[Column]: - return [expr.over(Window.partitionBy(*keys)) for expr in self._call(df)] + return [expr.over(self._Window.partitionBy(*keys)) for expr in self._call(df)] return self.__class__( func, @@ -478,15 +530,18 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: alias_output_names=self._alias_output_names, backend_version=self._backend_version, version=self._version, + implementation=self._implementation, expr_kind=ExprKind.TRANSFORM, ) def is_null(self: Self) -> Self: - return self._from_call(F.isnull, "is_null", expr_kind=self._expr_kind) + return self._from_call(self._F.isnull, "is_null", expr_kind=self._expr_kind) def is_nan(self: Self) -> Self: def _is_nan(_input: Column) -> Column: - return F.when(F.isnull(_input), None).otherwise(F.isnan(_input)) + return self._F.when(self._F.isnull(_input), None).otherwise( + self._F.isnan(_input) + ) return self._from_call(_is_nan, "is_nan", expr_kind=self._expr_kind) diff --git a/narwhals/_spark_like/expr_dt.py b/narwhals/_spark_like/expr_dt.py index d0229b381f..20743ea35b 100644 --- a/narwhals/_spark_like/expr_dt.py +++ b/narwhals/_spark_like/expr_dt.py @@ -2,8 +2,6 @@ from typing import TYPE_CHECKING -from 
pyspark.sql import functions as F # noqa: N812 - if TYPE_CHECKING: from pyspark.sql import Column from typing_extensions import Self @@ -17,56 +15,58 @@ def __init__(self: Self, expr: SparkLikeExpr) -> None: def date(self: Self) -> SparkLikeExpr: return self._compliant_expr._from_call( - F.to_date, + self._compliant_expr._F.to_date, "date", expr_kind=self._compliant_expr._expr_kind, ) def year(self: Self) -> SparkLikeExpr: return self._compliant_expr._from_call( - F.year, + self._compliant_expr._F.year, "year", expr_kind=self._compliant_expr._expr_kind, ) def month(self: Self) -> SparkLikeExpr: return self._compliant_expr._from_call( - F.month, + self._compliant_expr._F.month, "month", expr_kind=self._compliant_expr._expr_kind, ) def day(self: Self) -> SparkLikeExpr: return self._compliant_expr._from_call( - F.day, + self._compliant_expr._F.day, "day", expr_kind=self._compliant_expr._expr_kind, ) def hour(self: Self) -> SparkLikeExpr: return self._compliant_expr._from_call( - F.hour, + self._compliant_expr._F.hour, "hour", expr_kind=self._compliant_expr._expr_kind, ) def minute(self: Self) -> SparkLikeExpr: return self._compliant_expr._from_call( - F.minute, + self._compliant_expr._F.minute, "minute", expr_kind=self._compliant_expr._expr_kind, ) def second(self: Self) -> SparkLikeExpr: return self._compliant_expr._from_call( - F.second, + self._compliant_expr._F.second, "second", expr_kind=self._compliant_expr._expr_kind, ) def millisecond(self: Self) -> SparkLikeExpr: def _millisecond(_input: Column) -> Column: - return F.floor((F.unix_micros(_input) % 1_000_000) / 1000) + return self._compliant_expr._F.floor( + (self._compliant_expr._F.unix_micros(_input) % 1_000_000) / 1000 + ) return self._compliant_expr._from_call( _millisecond, @@ -76,7 +76,7 @@ def _millisecond(_input: Column) -> Column: def microsecond(self: Self) -> SparkLikeExpr: def _microsecond(_input: Column) -> Column: - return F.unix_micros(_input) % 1_000_000 + return self._compliant_expr._F.unix_micros(_input) % 1_000_000 return self._compliant_expr._from_call( _microsecond, @@ -86,7 +86,7 @@ def _microsecond(_input: Column) -> Column: def nanosecond(self: Self) -> SparkLikeExpr: def _nanosecond(_input: Column) -> Column: - return (F.unix_micros(_input) % 1_000_000) * 1000 + return (self._compliant_expr._F.unix_micros(_input) % 1_000_000) * 1000 return self._compliant_expr._from_call( _nanosecond, @@ -96,7 +96,7 @@ def _nanosecond(_input: Column) -> Column: def ordinal_day(self: Self) -> SparkLikeExpr: return self._compliant_expr._from_call( - F.dayofyear, + self._compliant_expr._F.dayofyear, "ordinal_day", expr_kind=self._compliant_expr._expr_kind, ) @@ -104,7 +104,7 @@ def ordinal_day(self: Self) -> SparkLikeExpr: def weekday(self: Self) -> SparkLikeExpr: def _weekday(_input: Column) -> Column: # PySpark's dayofweek returns 1-7 for Sunday-Saturday - return (F.dayofweek(_input) + 6) % 7 + return (self._compliant_expr._F.dayofweek(_input) + 6) % 7 return self._compliant_expr._from_call( _weekday, diff --git a/narwhals/_spark_like/expr_name.py b/narwhals/_spark_like/expr_name.py index bd3a6021d4..61cd8b5cb2 100644 --- a/narwhals/_spark_like/expr_name.py +++ b/narwhals/_spark_like/expr_name.py @@ -64,4 +64,5 @@ def _from_alias_output_names( expr_kind=self._compliant_expr._expr_kind, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, + implementation=self._compliant_expr._implementation, ) diff --git a/narwhals/_spark_like/expr_str.py b/narwhals/_spark_like/expr_str.py index 
e67fb14aaa..8bae6a0307 100644 --- a/narwhals/_spark_like/expr_str.py +++ b/narwhals/_spark_like/expr_str.py @@ -3,8 +3,6 @@ from typing import TYPE_CHECKING from typing import overload -from pyspark.sql import functions as F # noqa: N812 - if TYPE_CHECKING: from pyspark.sql import Column from typing_extensions import Self @@ -18,7 +16,7 @@ def __init__(self: Self, expr: SparkLikeExpr) -> None: def len_chars(self: Self) -> SparkLikeExpr: return self._compliant_expr._from_call( - F.char_length, + self._compliant_expr._F.char_length, "len", expr_kind=self._compliant_expr._expr_kind, ) @@ -27,8 +25,16 @@ def replace_all( self: Self, pattern: str, value: str, *, literal: bool ) -> SparkLikeExpr: def func(_input: Column) -> Column: - replace_all_func = F.replace if literal else F.regexp_replace - return replace_all_func(_input, F.lit(pattern), F.lit(value)) + replace_all_func = ( + self._compliant_expr._F.replace + if literal + else self._compliant_expr._F.regexp_replace + ) + return replace_all_func( + _input, + self._compliant_expr._F.lit(pattern), + self._compliant_expr._F.lit(value), + ) return self._compliant_expr._from_call( func, @@ -41,7 +47,9 @@ def strip_chars(self: Self, characters: str | None) -> SparkLikeExpr: def func(_input: Column) -> Column: to_remove = characters if characters is not None else string.whitespace - return F.btrim(_input, F.lit(to_remove)) + return self._compliant_expr._F.btrim( + _input, self._compliant_expr._F.lit(to_remove) + ) return self._compliant_expr._from_call( func, @@ -51,22 +59,30 @@ def func(_input: Column) -> Column: def starts_with(self: Self, prefix: str) -> SparkLikeExpr: return self._compliant_expr._from_call( - lambda _input: F.startswith(_input, F.lit(prefix)), + lambda _input: self._compliant_expr._F.startswith( + _input, self._compliant_expr._F.lit(prefix) + ), "starts_with", expr_kind=self._compliant_expr._expr_kind, ) def ends_with(self: Self, suffix: str) -> SparkLikeExpr: return self._compliant_expr._from_call( - lambda _input: F.endswith(_input, F.lit(suffix)), + lambda _input: self._compliant_expr._F.endswith( + _input, self._compliant_expr._F.lit(suffix) + ), "ends_with", expr_kind=self._compliant_expr._expr_kind, ) def contains(self: Self, pattern: str, *, literal: bool) -> SparkLikeExpr: def func(_input: Column) -> Column: - contains_func = F.contains if literal else F.regexp - return contains_func(_input, F.lit(pattern)) + contains_func = ( + self._compliant_expr._F.contains + if literal + else self._compliant_expr._F.regexp + ) + return contains_func(_input, self._compliant_expr._F.lit(pattern)) return self._compliant_expr._from_call( func, @@ -78,10 +94,16 @@ def slice(self: Self, offset: int, length: int | None) -> SparkLikeExpr: # From the docs: https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.substring.html # The position is not zero based, but 1 based index. 
def func(_input: Column) -> Column: - col_length = F.char_length(_input) - - _offset = col_length + F.lit(offset + 1) if offset < 0 else F.lit(offset + 1) - _length = F.lit(length) if length is not None else col_length + col_length = self._compliant_expr._F.char_length(_input) + + _offset = ( + col_length + self._compliant_expr._F.lit(offset + 1) + if offset < 0 + else self._compliant_expr._F.lit(offset + 1) + ) + _length = ( + self._compliant_expr._F.lit(length) if length is not None else col_length + ) return _input.substr(_offset, _length) return self._compliant_expr._from_call( @@ -92,22 +114,26 @@ def func(_input: Column) -> Column: def to_uppercase(self: Self) -> SparkLikeExpr: return self._compliant_expr._from_call( - F.upper, + self._compliant_expr._F.upper, "to_uppercase", expr_kind=self._compliant_expr._expr_kind, ) def to_lowercase(self: Self) -> SparkLikeExpr: return self._compliant_expr._from_call( - F.lower, + self._compliant_expr._F.lower, "to_lowercase", expr_kind=self._compliant_expr._expr_kind, ) def to_datetime(self: Self, format: str | None) -> SparkLikeExpr: # noqa: A002 return self._compliant_expr._from_call( - lambda _input: F.to_timestamp( - F.replace(_input, F.lit("T"), F.lit(" ")), + lambda _input: self._compliant_expr._F.to_timestamp( + self._compliant_expr._F.replace( + _input, + self._compliant_expr._F.lit("T"), + self._compliant_expr._F.lit(" "), + ), format=strptime_to_pyspark_format(format), ), "to_datetime", diff --git a/narwhals/_spark_like/group_by.py b/narwhals/_spark_like/group_by.py index 7d87b11dcb..f18cb5e7c5 100644 --- a/narwhals/_spark_like/group_by.py +++ b/narwhals/_spark_like/group_by.py @@ -52,7 +52,6 @@ def agg(self: Self, *exprs: SparkLikeExpr) -> SparkLikeLazyFrame: return self._compliant_frame._from_native_frame( self._compliant_frame._native_frame.select(*self._keys).dropDuplicates() ) - return self._compliant_frame._from_native_frame( - self._compliant_frame._native_frame.groupBy(self._keys).agg(*agg_columns) + self._compliant_frame._native_frame.groupBy(*self._keys).agg(*agg_columns) ) diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index f39cd29da0..b7e3bfb786 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -9,9 +9,6 @@ from typing import Literal from typing import Sequence -from pyspark.sql import functions as F # noqa: N812 -from pyspark.sql.types import IntegerType - from narwhals._expression_parsing import combine_alias_output_names from narwhals._expression_parsing import combine_evaluate_output_names from narwhals._spark_like.dataframe import SparkLikeLazyFrame @@ -27,25 +24,33 @@ from typing_extensions import Self from narwhals.dtypes import DType + from narwhals.utils import Implementation from narwhals.utils import Version class SparkLikeNamespace(CompliantNamespace["Column"]): def __init__( - self: Self, *, backend_version: tuple[int, ...], version: Version + self: Self, + *, + backend_version: tuple[int, ...], + version: Version, + implementation: Implementation, ) -> None: self._backend_version = backend_version self._version = version + self._implementation = implementation @property def selectors(self: Self) -> SparkLikeSelectorNamespace: return SparkLikeSelectorNamespace( - backend_version=self._backend_version, version=self._version + backend_version=self._backend_version, + version=self._version, + implementation=self._implementation, ) def all(self: Self) -> SparkLikeExpr: def _all(df: SparkLikeLazyFrame) -> list[Column]: - return 
[F.col(col_name) for col_name in df.columns] + return [df._F.col(col_name) for col_name in df.columns] return SparkLikeExpr( call=_all, @@ -55,16 +60,23 @@ def _all(df: SparkLikeLazyFrame) -> list[Column]: expr_kind=ExprKind.TRANSFORM, backend_version=self._backend_version, version=self._version, + implementation=self._implementation, ) def col(self: Self, *column_names: str) -> SparkLikeExpr: return SparkLikeExpr.from_column_names( - *column_names, backend_version=self._backend_version, version=self._version + *column_names, + backend_version=self._backend_version, + version=self._version, + implementation=self._implementation, ) def nth(self: Self, *column_indices: int) -> SparkLikeExpr: return SparkLikeExpr.from_column_indices( - *column_indices, backend_version=self._backend_version, version=self._version + *column_indices, + backend_version=self._backend_version, + version=self._version, + implementation=self._implementation, ) def lit(self: Self, value: object, dtype: DType | None) -> SparkLikeExpr: @@ -72,10 +84,8 @@ def lit(self: Self, value: object, dtype: DType | None) -> SparkLikeExpr: msg = "todo" raise NotImplementedError(msg) - def _lit(_: SparkLikeLazyFrame) -> list[Column]: - import pyspark.sql.functions as F # noqa: N812 - - return [F.lit(value)] + def _lit(df: SparkLikeLazyFrame) -> list[Column]: + return [df._F.lit(value)] return SparkLikeExpr( call=_lit, @@ -85,11 +95,12 @@ def _lit(_: SparkLikeLazyFrame) -> list[Column]: expr_kind=ExprKind.LITERAL, backend_version=self._backend_version, version=self._version, + implementation=self._implementation, ) def len(self: Self) -> SparkLikeExpr: - def func(_: SparkLikeLazyFrame) -> list[Column]: - return [F.count("*")] + def func(df: SparkLikeLazyFrame) -> list[Column]: + return [df._F.count("*")] return SparkLikeExpr( func, @@ -99,6 +110,7 @@ def func(_: SparkLikeLazyFrame) -> list[Column]: expr_kind=ExprKind.AGGREGATION, backend_version=self._backend_version, version=self._version, + implementation=self._implementation, ) def all_horizontal(self: Self, *exprs: SparkLikeExpr) -> SparkLikeExpr: @@ -114,6 +126,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: expr_kind=n_ary_operation_expr_kind(*exprs), backend_version=self._backend_version, version=self._version, + implementation=self._implementation, ) def any_horizontal(self: Self, *exprs: SparkLikeExpr) -> SparkLikeExpr: @@ -129,6 +142,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: expr_kind=n_ary_operation_expr_kind(*exprs), backend_version=self._backend_version, version=self._version, + implementation=self._implementation, ) def sum_horizontal(self: Self, *exprs: SparkLikeExpr) -> SparkLikeExpr: @@ -137,7 +151,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: return [ reduce( operator.add, - (F.coalesce(col, F.lit(0)) for col in cols), + (df._F.coalesce(col, df._F.lit(0)) for col in cols), ) ] @@ -149,6 +163,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: expr_kind=n_ary_operation_expr_kind(*exprs), backend_version=self._backend_version, version=self._version, + implementation=self._implementation, ) def mean_horizontal(self: Self, *exprs: SparkLikeExpr) -> SparkLikeExpr: @@ -156,10 +171,16 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: cols = [c for _expr in exprs for c in _expr(df)] return [ ( - reduce(operator.add, (F.coalesce(col, F.lit(0)) for col in cols)) + reduce( + operator.add, + (df._F.coalesce(col, df._F.lit(0)) for col in cols), + ) / reduce( operator.add, - (col.isNotNull().cast(IntegerType()) for col in cols), + ( + 
col.isNotNull().cast(df._native_dtypes.IntegerType()) + for col in cols + ), ) ) ] @@ -172,12 +193,13 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: expr_kind=n_ary_operation_expr_kind(*exprs), backend_version=self._backend_version, version=self._version, + implementation=self._implementation, ) def max_horizontal(self: Self, *exprs: SparkLikeExpr) -> SparkLikeExpr: def func(df: SparkLikeLazyFrame) -> list[Column]: cols = [c for _expr in exprs for c in _expr(df)] - return [F.greatest(*cols)] + return [df._F.greatest(*cols)] return SparkLikeExpr( call=func, @@ -187,12 +209,13 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: expr_kind=n_ary_operation_expr_kind(*exprs), backend_version=self._backend_version, version=self._version, + implementation=self._implementation, ) def min_horizontal(self: Self, *exprs: SparkLikeExpr) -> SparkLikeExpr: def func(df: SparkLikeLazyFrame) -> list[Column]: cols = [c for _expr in exprs for c in _expr(df)] - return [F.least(*cols)] + return [df._F.least(*cols)] return SparkLikeExpr( call=func, @@ -202,6 +225,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: expr_kind=n_ary_operation_expr_kind(*exprs), backend_version=self._backend_version, version=self._version, + implementation=self._implementation, ) def concat( @@ -234,6 +258,7 @@ def concat( native_dataframe=reduce(lambda x, y: x.union(y), dfs), backend_version=self._backend_version, version=self._version, + implementation=self._implementation, ) if how == "diagonal": @@ -243,6 +268,7 @@ def concat( ), backend_version=self._backend_version, version=self._version, + implementation=self._implementation, ) raise NotImplementedError @@ -252,35 +278,36 @@ def concat_str( separator: str, ignore_nulls: bool, ) -> SparkLikeExpr: - from pyspark.sql.types import StringType - def func(df: SparkLikeLazyFrame) -> list[Column]: cols = [s for _expr in exprs for s in _expr(df)] - cols_casted = [s.cast(StringType()) for s in cols] - null_mask = [F.isnull(s) for _expr in exprs for s in _expr(df)] + cols_casted = [s.cast(df._native_dtypes.StringType()) for s in cols] + null_mask = [df._F.isnull(s) for _expr in exprs for s in _expr(df)] if not ignore_nulls: null_mask_result = reduce(lambda x, y: x | y, null_mask) - result = F.when( + result = df._F.when( ~null_mask_result, reduce( - lambda x, y: F.format_string(f"%s{separator}%s", x, y), + lambda x, y: df._F.format_string(f"%s{separator}%s", x, y), cols_casted, ), - ).otherwise(F.lit(None)) + ).otherwise(df._F.lit(None)) else: init_value, *values = [ - F.when(~nm, col).otherwise(F.lit("")) + df._F.when(~nm, col).otherwise(df._F.lit("")) for col, nm in zip(cols_casted, null_mask) ] separators = ( - F.when(nm, F.lit("")).otherwise(F.lit(separator)) + df._F.when(nm, df._F.lit("")).otherwise(df._F.lit(separator)) for nm in null_mask[:-1] ) result = reduce( - lambda x, y: F.format_string("%s%s", x, y), - (F.format_string("%s%s", s, v) for s, v in zip(separators, values)), + lambda x, y: df._F.format_string("%s%s", x, y), + ( + df._F.format_string("%s%s", s, v) + for s, v in zip(separators, values) + ), init_value, ) @@ -294,16 +321,22 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: expr_kind=n_ary_operation_expr_kind(*exprs), backend_version=self._backend_version, version=self._version, + implementation=self._implementation, ) def when(self: Self, *predicates: SparkLikeExpr) -> SparkLikeWhen: - plx = self.__class__(backend_version=self._backend_version, version=self._version) + plx = self.__class__( + backend_version=self._backend_version, + 
version=self._version, + implementation=self._implementation, + ) condition = plx.all_horizontal(*predicates) return SparkLikeWhen( condition, self._backend_version, expr_kind=ExprKind.TRANSFORM, version=self._version, + implementation=self._implementation, ) @@ -317,6 +350,7 @@ def __init__( *, expr_kind: ExprKind, version: Version, + implementation: Implementation, ) -> None: self._backend_version = backend_version self._condition = condition @@ -324,6 +358,7 @@ def __init__( self._otherwise_value = otherwise_value self._expr_kind = expr_kind self._version = version + self._implementation = implementation def __call__(self: Self, df: SparkLikeLazyFrame) -> list[Column]: condition = self._condition(df)[0] @@ -332,15 +367,15 @@ def __call__(self: Self, df: SparkLikeLazyFrame) -> list[Column]: value_ = self._then_value(df)[0] else: # `self._then_value` is a scalar - value_ = F.lit(self._then_value) + value_ = df._F.lit(self._then_value) if isinstance(self._otherwise_value, SparkLikeExpr): other_ = self._otherwise_value(df)[0] else: # `self._otherwise_value` is a scalar - other_ = F.lit(self._otherwise_value) + other_ = df._F.lit(self._otherwise_value) - return [F.when(condition=condition, value=value_).otherwise(value=other_)] + return [df._F.when(condition=condition, value=value_).otherwise(value=other_)] def then(self: Self, value: SparkLikeExpr | Any) -> SparkLikeThen: self._then_value = value @@ -355,6 +390,7 @@ def then(self: Self, value: SparkLikeExpr | Any) -> SparkLikeThen: expr_kind=self._expr_kind, backend_version=self._backend_version, version=self._version, + implementation=self._implementation, ) @@ -369,6 +405,7 @@ def __init__( expr_kind: ExprKind, backend_version: tuple[int, ...], version: Version, + implementation: Implementation, ) -> None: self._backend_version = backend_version self._version = version @@ -377,6 +414,7 @@ def __init__( self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names self._expr_kind = expr_kind + self._implementation = implementation def otherwise(self: Self, value: SparkLikeExpr | Any) -> SparkLikeExpr: # type ignore because we are setting the `_call` attribute to a diff --git a/narwhals/_spark_like/selectors.py b/narwhals/_spark_like/selectors.py index 2e33b7bc63..b0ad7c594a 100644 --- a/narwhals/_spark_like/selectors.py +++ b/narwhals/_spark_like/selectors.py @@ -5,10 +5,9 @@ from typing import Any from typing import Sequence -from pyspark.sql import functions as F # noqa: N812 - from narwhals._spark_like.expr import SparkLikeExpr from narwhals._spark_like.utils import ExprKind +from narwhals.utils import Implementation from narwhals.utils import import_dtypes_module if TYPE_CHECKING: @@ -22,14 +21,19 @@ class SparkLikeSelectorNamespace: def __init__( - self: Self, *, backend_version: tuple[int, ...], version: Version + self: Self, + *, + backend_version: tuple[int, ...], + version: Version, + implementation: Implementation, ) -> None: self._backend_version = backend_version self._version = version + self._implementation = implementation def by_dtype(self: Self, dtypes: list[DType | type[DType]]) -> SparkLikeSelector: def func(df: SparkLikeLazyFrame) -> list[Column]: - return [F.col(col) for col in df.columns if df.schema[col] in dtypes] + return [df._F.col(col) for col in df.columns if df.schema[col] in dtypes] def evalute_output_names(df: SparkLikeLazyFrame) -> Sequence[str]: return [col for col in df.columns if df.schema[col] in dtypes] @@ -42,11 +46,12 @@ def evalute_output_names(df: 
SparkLikeLazyFrame) -> Sequence[str]: backend_version=self._backend_version, expr_kind=ExprKind.TRANSFORM, version=self._version, + implementation=self._implementation, ) def matches(self: Self, pattern: str) -> SparkLikeSelector: def func(df: SparkLikeLazyFrame) -> list[Column]: - return [F.col(col) for col in df.columns if re.search(pattern, col)] + return [df._F.col(col) for col in df.columns if re.search(pattern, col)] def evalute_output_names(df: SparkLikeLazyFrame) -> Sequence[str]: return [col for col in df.columns if re.search(pattern, col)] @@ -57,8 +62,9 @@ def evalute_output_names(df: SparkLikeLazyFrame) -> Sequence[str]: evaluate_output_names=evalute_output_names, alias_output_names=None, backend_version=self._backend_version, - returns_scalar=False, + expr_kind=ExprKind.TRANSFORM, version=self._version, + implementation=self._implementation, ) def numeric(self: Self) -> SparkLikeSelector: @@ -94,7 +100,7 @@ def boolean(self: Self) -> SparkLikeSelector: def all(self: Self) -> SparkLikeSelector: def func(df: SparkLikeLazyFrame) -> list[Column]: - return [F.col(col) for col in df.columns] + return [df._F.col(col) for col in df.columns] return SparkLikeSelector( func, @@ -104,6 +110,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: backend_version=self._backend_version, expr_kind=ExprKind.TRANSFORM, version=self._version, + implementation=self._implementation, ) @@ -120,6 +127,7 @@ def _to_expr(self: Self) -> SparkLikeExpr: backend_version=self._backend_version, expr_kind=self._expr_kind, version=self._version, + implementation=self._implementation, ) def __sub__(self: Self, other: SparkLikeSelector | Any) -> SparkLikeSelector | Any: @@ -144,6 +152,7 @@ def evaluate_output_names(df: SparkLikeLazyFrame) -> list[str]: backend_version=self._backend_version, expr_kind=self._expr_kind, version=self._version, + implementation=self._implementation, ) else: return self._to_expr() - other @@ -174,6 +183,7 @@ def evaluate_output_names(df: SparkLikeLazyFrame) -> list[str]: backend_version=self._backend_version, expr_kind=self._expr_kind, version=self._version, + implementation=self._implementation, ) else: return self._to_expr() | other @@ -200,6 +210,7 @@ def evaluate_output_names(df: SparkLikeLazyFrame) -> list[str]: backend_version=self._backend_version, expr_kind=self._expr_kind, version=self._version, + implementation=self._implementation, ) else: return self._to_expr() & other @@ -207,7 +218,9 @@ def evaluate_output_names(df: SparkLikeLazyFrame) -> list[str]: def __invert__(self: Self) -> SparkLikeSelector: return ( SparkLikeSelectorNamespace( - backend_version=self._backend_version, version=self._version + backend_version=self._backend_version, + version=self._version, + implementation=self._implementation, ).all() - self ) diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py index a29d96add3..82d7feb28b 100644 --- a/narwhals/_spark_like/utils.py +++ b/narwhals/_spark_like/utils.py @@ -6,16 +6,14 @@ from typing import TYPE_CHECKING from typing import Any -from pyspark.sql import Column -from pyspark.sql import Window -from pyspark.sql import functions as F # noqa: N812 -from pyspark.sql import types as pyspark_types - from narwhals.exceptions import UnsupportedDTypeError from narwhals.utils import import_dtypes_module from narwhals.utils import isinstance_or_issubclass if TYPE_CHECKING: + from types import ModuleType + + import pyspark.types as pyspark_types from pyspark.sql import Column from narwhals._spark_like.dataframe import SparkLikeLazyFrame @@ 
-43,65 +41,66 @@ class ExprKind(Enum): def native_to_narwhals_dtype( dtype: pyspark_types.DataType, version: Version, + spark_types: ModuleType, ) -> DType: # pragma: no cover dtypes = import_dtypes_module(version=version) - if isinstance(dtype, pyspark_types.DoubleType): + if isinstance(dtype, spark_types.DoubleType): return dtypes.Float64() - if isinstance(dtype, pyspark_types.FloatType): + if isinstance(dtype, spark_types.FloatType): return dtypes.Float32() - if isinstance(dtype, pyspark_types.LongType): + if isinstance(dtype, spark_types.LongType): return dtypes.Int64() - if isinstance(dtype, pyspark_types.IntegerType): + if isinstance(dtype, spark_types.IntegerType): return dtypes.Int32() - if isinstance(dtype, pyspark_types.ShortType): + if isinstance(dtype, spark_types.ShortType): return dtypes.Int16() - if isinstance(dtype, pyspark_types.ByteType): + if isinstance(dtype, spark_types.ByteType): return dtypes.Int8() string_types = [ - pyspark_types.StringType, - pyspark_types.VarcharType, - pyspark_types.CharType, + spark_types.StringType, + spark_types.VarcharType, + spark_types.CharType, ] if any(isinstance(dtype, t) for t in string_types): return dtypes.String() - if isinstance(dtype, pyspark_types.BooleanType): + if isinstance(dtype, spark_types.BooleanType): return dtypes.Boolean() - if isinstance(dtype, pyspark_types.DateType): + if isinstance(dtype, spark_types.DateType): return dtypes.Date() datetime_types = [ - pyspark_types.TimestampType, - pyspark_types.TimestampNTZType, + spark_types.TimestampType, + spark_types.TimestampNTZType, ] if any(isinstance(dtype, t) for t in datetime_types): return dtypes.Datetime() - if isinstance(dtype, pyspark_types.DecimalType): # pragma: no cover + if isinstance(dtype, spark_types.DecimalType): # pragma: no cover # TODO(unassigned): cover this in dtypes_test.py return dtypes.Decimal() return dtypes.Unknown() def narwhals_to_native_dtype( - dtype: DType | type[DType], version: Version + dtype: DType | type[DType], version: Version, spark_types: ModuleType ) -> pyspark_types.DataType: dtypes = import_dtypes_module(version) if isinstance_or_issubclass(dtype, dtypes.Float64): - return pyspark_types.DoubleType() + return spark_types.DoubleType() if isinstance_or_issubclass(dtype, dtypes.Float32): - return pyspark_types.FloatType() + return spark_types.FloatType() if isinstance_or_issubclass(dtype, dtypes.Int64): - return pyspark_types.LongType() + return spark_types.LongType() if isinstance_or_issubclass(dtype, dtypes.Int32): - return pyspark_types.IntegerType() + return spark_types.IntegerType() if isinstance_or_issubclass(dtype, dtypes.Int16): - return pyspark_types.ShortType() + return spark_types.ShortType() if isinstance_or_issubclass(dtype, dtypes.Int8): - return pyspark_types.ByteType() + return spark_types.ByteType() if isinstance_or_issubclass(dtype, dtypes.String): - return pyspark_types.StringType() + return spark_types.StringType() if isinstance_or_issubclass(dtype, dtypes.Boolean): - return pyspark_types.BooleanType() + return spark_types.BooleanType() if isinstance_or_issubclass(dtype, (dtypes.Date, dtypes.Datetime)): msg = "Converting to Date or Datetime dtype is not supported yet" raise NotImplementedError(msg) @@ -163,36 +162,42 @@ def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any, *, expr_kind: ExprKind) -> if obj._expr_kind is ExprKind.AGGREGATION and expr_kind is ExprKind.TRANSFORM: # Returns scalar, but overall expression doesn't. 
# Let PySpark do its broadcasting - return column_result.over(Window.partitionBy(F.lit(1))) + return column_result.over(df._Window().partitionBy(df._F.lit(1))) return column_result - return F.lit(obj) + return df._F.lit(obj) -def _std(_input: Column | str, ddof: int, np_version: tuple[int, ...]) -> Column: +def _std( + _input: Column | str, ddof: int, np_version: tuple[int, ...], functions: Any +) -> Column: if np_version > (2, 0): if ddof == 1: - return F.stddev_samp(_input) + return functions.stddev_samp(_input) - n_rows = F.count(_input) - return F.stddev_samp(_input) * F.sqrt((n_rows - 1) / (n_rows - ddof)) + n_rows = functions.count(_input) + return functions.stddev_samp(_input) * functions.sqrt( + (n_rows - 1) / (n_rows - ddof) + ) from pyspark.pandas.spark.functions import stddev - input_col = F.col(_input) if isinstance(_input, str) else _input + input_col = functions.col(_input) if isinstance(_input, str) else _input return stddev(input_col, ddof=ddof) -def _var(_input: Column | str, ddof: int, np_version: tuple[int, ...]) -> Column: +def _var( + _input: Column | str, ddof: int, np_version: tuple[int, ...], functions: Any +) -> Column: if np_version > (2, 0): if ddof == 1: - return F.var_samp(_input) + return functions.var_samp(_input) - n_rows = F.count(_input) - return F.var_samp(_input) * (n_rows - 1) / (n_rows - ddof) + n_rows = functions.count(_input) + return functions.var_samp(_input) * (n_rows - 1) / (n_rows - ddof) from pyspark.pandas.spark.functions import var - input_col = F.col(_input) if isinstance(_input, str) else _input + input_col = functions.col(_input) if isinstance(_input, str) else _input return var(input_col, ddof=ddof) diff --git a/narwhals/dependencies.py b/narwhals/dependencies.py index df671be3ad..a9c457c08f 100644 --- a/narwhals/dependencies.py +++ b/narwhals/dependencies.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: import numpy as np + import sqlframe if sys.version_info >= (3, 10): from typing import TypeGuard @@ -107,6 +108,11 @@ def get_pyspark_sql() -> Any: return sys.modules.get("pyspark.sql", None) +def get_sqlframe() -> Any: + """Get sqlframe module (if already imported - else return None).""" + return sys.modules.get("sqlframe", None) + + def is_pandas_dataframe(df: Any) -> TypeGuard[pd.DataFrame]: """Check whether `df` is a pandas DataFrame without importing pandas.""" return ((pd := get_pandas()) is not None and isinstance(df, pd.DataFrame)) or any( @@ -218,6 +224,14 @@ def is_pyspark_dataframe(df: Any) -> TypeGuard[pyspark_sql.DataFrame]: ) +def is_sqlframe_dataframe(df: Any) -> TypeGuard[sqlframe.base.dataframe.BaseDataFrame]: + """Check whether `df` is a SQLFrame DataFrame without importing SQLFrame.""" + return bool( + (sqlframe := get_sqlframe()) is not None + and isinstance(df, sqlframe.base.dataframe.BaseDataFrame) + ) + + def is_numpy_array(arr: Any) -> TypeGuard[np.ndarray]: """Check whether `arr` is a NumPy Array without importing NumPy.""" return (np := get_numpy()) is not None and isinstance(arr, np.ndarray) diff --git a/narwhals/translate.py b/narwhals/translate.py index 9c455055a0..124945cf29 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -36,6 +36,7 @@ from narwhals.dependencies import is_pyarrow_chunked_array from narwhals.dependencies import is_pyarrow_table from narwhals.dependencies import is_pyspark_dataframe +from narwhals.dependencies import is_sqlframe_dataframe from narwhals.utils import Version if TYPE_CHECKING: @@ -378,8 +379,32 @@ def _from_native_impl( # noqa: PLR0915 msg = "Invalid parameter combination: 
`eager_only=True` and `eager_or_interchange_only=True`" raise ValueError(msg) + # SQLFrame + # This one needs checking before extensions as `hasattr` always returns `True`. + if is_sqlframe_dataframe(native_object): # pragma: no cover + from narwhals._spark_like.dataframe import SparkLikeLazyFrame + + if series_only: + msg = "Cannot only use `series_only` with pyspark DataFrame" + raise TypeError(msg) + if eager_only or eager_or_interchange_only: + msg = "Cannot only use `eager_only` or `eager_or_interchange_only` with pyspark DataFrame" + raise TypeError(msg) + import sqlframe._version + + backend_version = parse_version(sqlframe._version.__version__) + return LazyFrame( + SparkLikeLazyFrame( + native_object, + backend_version=backend_version, + version=version, + implementation=Implementation.SQLFRAME, + ), + level="lazy", + ) + # Extensions - if hasattr(native_object, "__narwhals_dataframe__"): + elif hasattr(native_object, "__narwhals_dataframe__"): if series_only: if not pass_through: msg = "Cannot only use `series_only` with dataframe" @@ -727,6 +752,7 @@ def _from_native_impl( # noqa: PLR0915 native_object, backend_version=parse_version(get_pyspark().__version__), version=version, + implementation=Implementation.PYSPARK, ), level="lazy", ) diff --git a/narwhals/utils.py b/narwhals/utils.py index dfde5c6fc6..bbfe3eeaf4 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -79,6 +79,8 @@ class Implementation(Enum): """DuckDB implementation.""" IBIS = auto() """Ibis implementation.""" + SQLFRAME = auto() + """SQLFrame implementation.""" UNKNOWN = auto() """Unknown implementation.""" @@ -300,6 +302,7 @@ def is_ibis(self: Self) -> bool: Implementation.DASK: (2024, 8), Implementation.DUCKDB: (1,), Implementation.IBIS: (6,), + Implementation.SQLFRAME: (3, 14, 2), } diff --git a/pyproject.toml b/pyproject.toml index fd2860b76f..80eee43741 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -170,8 +170,8 @@ filterwarnings = [ 'ignore:.*distutils Version classes are deprecated. 
Use packaging.version instead.*:DeprecationWarning:pyspark', 'ignore:.*is_datetime64tz_dtype is deprecated and will be removed in a future version.*:DeprecationWarning:pyspark', # Warning raised by PyArrow nightly just by importing pandas - 'ignore:.*Python binding for RankQuantileOptions not exposed:RuntimeWarning:pyarrow' - + 'ignore:.*Python binding for RankQuantileOptions not exposed:RuntimeWarning:pyarrow', + 'ignore:.*pandas only supports SQLAlchemy:UserWarning' ] xfail_strict = true markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"] diff --git a/tests/conftest.py b/tests/conftest.py index 960c80387f..a95f6a5f84 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -182,6 +182,17 @@ def _constructor(obj: dict[str, list[Any]]) -> IntoFrame: return _constructor +def sqlframe_pyspark_lazy_constructor( + obj: dict[str, Any], +) -> Callable[[Any], IntoFrame]: # pragma: no cover + from sqlframe.duckdb import DuckDBSession + + session = DuckDBSession() + return ( # type: ignore[no-any-return] + session.createDataFrame([*zip(*obj.values())], schema=[*obj.keys()]) + ) + + EAGER_CONSTRUCTORS: dict[str, Callable[[Any], IntoDataFrame]] = { "pandas": pandas_constructor, "pandas[nullable]": pandas_nullable_constructor, @@ -197,6 +208,9 @@ def _constructor(obj: dict[str, list[Any]]) -> IntoFrame: "polars[lazy]": polars_lazy_constructor, "duckdb": duckdb_lazy_constructor, "pyspark": pyspark_lazy_constructor, # type: ignore[dict-item] + # We've reported several bugs to sqlframe - once they address + # them, we can start testing them as part of our CI. + # "sqlframe": pyspark_lazy_constructor, # noqa: ERA001 } GPU_CONSTRUCTORS: dict[str, Callable[[Any], IntoFrame]] = {"cudf": cudf_constructor} From 23373a3abd0c044a5bb32fa47d9b53597026457d Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Tue, 28 Jan 2025 17:02:26 +0000 Subject: [PATCH 08/13] chore: fix pyspark/sqlframe typo (#1884) --- narwhals/translate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/translate.py b/narwhals/translate.py index 124945cf29..f370650869 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -385,10 +385,10 @@ def _from_native_impl( # noqa: PLR0915 from narwhals._spark_like.dataframe import SparkLikeLazyFrame if series_only: - msg = "Cannot only use `series_only` with pyspark DataFrame" + msg = "Cannot only use `series_only` with SQLFrame DataFrame" raise TypeError(msg) if eager_only or eager_or_interchange_only: - msg = "Cannot only use `eager_only` or `eager_or_interchange_only` with pyspark DataFrame" + msg = "Cannot only use `eager_only` or `eager_or_interchange_only` with SQLFrame DataFrame" raise TypeError(msg) import sqlframe._version From f644931076307b150c1f1333c48835e4ffe9342c Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Tue, 28 Jan 2025 19:16:49 +0000 Subject: [PATCH 09/13] release: Bump version to 1.24.1 (#1885) --- docs/installation.md | 2 +- narwhals/__init__.py | 2 +- pyproject.toml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/installation.md b/docs/installation.md index 54d4fb0c5b..f8e2d897f0 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -30,7 +30,7 @@ To verify the installation, start the Python REPL and execute: ```python >>> import narwhals >>> narwhals.__version__ -'1.24.0' +'1.24.1' ``` If you see the version number, then the installation was successful! 
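With the SQLFrame branch added to `_from_native_impl` above, a `sqlframe` DataFrame can be passed straight to `nw.from_native` and comes back as a `narwhals.LazyFrame` backed by `SparkLikeLazyFrame`. A minimal sketch of that round trip, reusing the `DuckDBSession`-based construction from the test constructor above — the data, column names, and variable names here are illustrative, not part of the patch:

```python
# Sketch only: assumes a narwhals build that includes the SQLFrame patches above,
# plus sqlframe installed; the data below is made up for illustration.
import narwhals as nw
from sqlframe.duckdb import DuckDBSession

session = DuckDBSession()
data = {"a": [1, 2, 3], "b": ["x", "y", "z"]}  # illustrative data
native_df = session.createDataFrame([*zip(*data.values())], schema=[*data.keys()])

# `from_native` checks for SQLFrame before the `__narwhals_dataframe__` extension
# hook (since `hasattr` always returns True for SQLFrame objects) and wraps the
# frame lazily with Implementation.SQLFRAME.
lf = nw.from_native(native_df)
result = lf.select(nw.col("a").sum()).to_native()
```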
diff --git a/narwhals/__init__.py b/narwhals/__init__.py index b5033514de..f7d0e0427e 100644 --- a/narwhals/__init__.py +++ b/narwhals/__init__.py @@ -79,7 +79,7 @@ from narwhals.utils import maybe_reset_index from narwhals.utils import maybe_set_index -__version__ = "1.24.0" +__version__ = "1.24.1" __all__ = [ "Array", diff --git a/pyproject.toml b/pyproject.toml index 80eee43741..094b5e78fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "hatchling.build" [project] name = "narwhals" -version = "1.24.0" +version = "1.24.1" dependencies = [] requires-python = ">=3.8" authors = [ From 83c1b2aa35b41c294daa55f095118999dec130c6 Mon Sep 17 00:00:00 2001 From: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> Date: Thu, 30 Jan 2025 19:50:59 +0100 Subject: [PATCH 10/13] feat: spark like `.list` namespace and lazyframe `.explode` (#1887) * feat: spark like list namespace and explode * some cleanup and use F.explode_outer --- narwhals/_duckdb/expr_list.py | 6 ++-- narwhals/_spark_like/dataframe.py | 48 ++++++++++++++++++++++++-- narwhals/_spark_like/expr.py | 5 +++ narwhals/_spark_like/expr_list.py | 20 +++++++++++ narwhals/_spark_like/utils.py | 29 +++++++++------- tests/expr_and_series/list/len_test.py | 4 +-- tests/frame/explode_test.py | 10 +++--- tests/utils.py | 7 +++- 8 files changed, 102 insertions(+), 27 deletions(-) create mode 100644 narwhals/_spark_like/expr_list.py diff --git a/narwhals/_duckdb/expr_list.py b/narwhals/_duckdb/expr_list.py index 134df90b6d..a2277c4b8d 100644 --- a/narwhals/_duckdb/expr_list.py +++ b/narwhals/_duckdb/expr_list.py @@ -5,14 +5,16 @@ from duckdb import FunctionExpression if TYPE_CHECKING: + from typing_extensions import Self + from narwhals._duckdb.expr import DuckDBExpr class DuckDBExprListNamespace: - def __init__(self, expr: DuckDBExpr) -> None: + def __init__(self: Self, expr: DuckDBExpr) -> None: self._compliant_expr = expr - def len(self) -> DuckDBExpr: + def len(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( lambda _input: FunctionExpression("len", _input), "len", diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py index 13945b0606..0661f3cfb6 100644 --- a/narwhals/_spark_like/dataframe.py +++ b/narwhals/_spark_like/dataframe.py @@ -9,9 +9,11 @@ from narwhals._spark_like.utils import ExprKind from narwhals._spark_like.utils import native_to_narwhals_dtype from narwhals._spark_like.utils import parse_exprs_and_named_exprs +from narwhals.exceptions import InvalidOperationError from narwhals.typing import CompliantLazyFrame from narwhals.utils import Implementation from narwhals.utils import check_column_exists +from narwhals.utils import import_dtypes_module from narwhals.utils import parse_columns_to_drop from narwhals.utils import parse_version from narwhals.utils import validate_backend_version @@ -46,7 +48,7 @@ def __init__( validate_backend_version(self._implementation, self._backend_version) @property - def _F(self) -> Any: # noqa: N802 + def _F(self: Self) -> Any: # noqa: N802 if self._implementation is Implementation.SQLFRAME: from sqlframe.duckdb import functions @@ -56,7 +58,7 @@ def _F(self) -> Any: # noqa: N802 return functions @property - def _native_dtypes(self) -> Any: + def _native_dtypes(self: Self) -> Any: if self._implementation is Implementation.SQLFRAME: from sqlframe.duckdb import types @@ -66,7 +68,7 @@ def _native_dtypes(self) -> Any: return types @property - def _Window(self) -> Any: # noqa: N802 + def _Window(self: Self) -> Any: 
# noqa: N802 if self._implementation is Implementation.SQLFRAME: from sqlframe.duckdb import Window @@ -312,3 +314,43 @@ def join( return self._from_native_frame( self_native.join(other, on=left_on, how=how).select(col_order) ) + + def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self: + dtypes = import_dtypes_module(self._version) + + to_explode = ( + [columns, *more_columns] + if isinstance(columns, str) + else [*columns, *more_columns] + ) + schema = self.collect_schema() + for col_to_explode in to_explode: + dtype = schema[col_to_explode] + + if dtype != dtypes.List: + msg = ( + f"`explode` operation not supported for dtype `{dtype}`, " + "expected List type" + ) + raise InvalidOperationError(msg) + + native_frame = self._native_frame + column_names = self.columns + + if len(to_explode) != 1: + msg = ( + "Exploding on multiple columns is not supported with SparkLike backend since " + "we cannot guarantee that the exploded columns have matching element counts." + ) + raise NotImplementedError(msg) + + return self._from_native_frame( + native_frame.select( + *[ + self._F.col(col_name).alias(col_name) + if col_name != to_explode[0] + else self._F.explode_outer(col_name).alias(col_name) + for col_name in column_names + ] + ) + ) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index f3a9de7f02..d3f590ad2d 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -7,6 +7,7 @@ from typing import Sequence from narwhals._spark_like.expr_dt import SparkLikeExprDateTimeNamespace +from narwhals._spark_like.expr_list import SparkLikeExprListNamespace from narwhals._spark_like.expr_name import SparkLikeExprNameNamespace from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace from narwhals._spark_like.utils import ExprKind @@ -556,3 +557,7 @@ def name(self: Self) -> SparkLikeExprNameNamespace: @property def dt(self: Self) -> SparkLikeExprDateTimeNamespace: return SparkLikeExprDateTimeNamespace(self) + + @property + def list(self: Self) -> SparkLikeExprListNamespace: + return SparkLikeExprListNamespace(self) diff --git a/narwhals/_spark_like/expr_list.py b/narwhals/_spark_like/expr_list.py new file mode 100644 index 0000000000..ba0dc31897 --- /dev/null +++ b/narwhals/_spark_like/expr_list.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing_extensions import Self + + from narwhals._spark_like.expr import SparkLikeExpr + + +class SparkLikeExprListNamespace: + def __init__(self: Self, expr: SparkLikeExpr) -> None: + self._compliant_expr = expr + + def len(self: Self) -> SparkLikeExpr: + return self._compliant_expr._from_call( + self._compliant_expr._F.array_size, + "len", + expr_kind=self._compliant_expr._expr_kind, + ) diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py index 82d7feb28b..3a728d5c59 100644 --- a/narwhals/_spark_like/utils.py +++ b/narwhals/_spark_like/utils.py @@ -57,26 +57,25 @@ def native_to_narwhals_dtype( return dtypes.Int16() if isinstance(dtype, spark_types.ByteType): return dtypes.Int8() - string_types = [ - spark_types.StringType, - spark_types.VarcharType, - spark_types.CharType, - ] - if any(isinstance(dtype, t) for t in string_types): + if isinstance( + dtype, (spark_types.StringType, spark_types.VarcharType, spark_types.CharType) + ): return dtypes.String() if isinstance(dtype, spark_types.BooleanType): return dtypes.Boolean() if isinstance(dtype, spark_types.DateType): return dtypes.Date() 
- datetime_types = [ - spark_types.TimestampType, - spark_types.TimestampNTZType, - ] - if any(isinstance(dtype, t) for t in datetime_types): + if isinstance(dtype, (spark_types.TimestampType, spark_types.TimestampNTZType)): return dtypes.Datetime() if isinstance(dtype, spark_types.DecimalType): # pragma: no cover # TODO(unassigned): cover this in dtypes_test.py return dtypes.Decimal() + if isinstance(dtype, spark_types.ArrayType): # pragma: no cover + return dtypes.List( + inner=native_to_narwhals_dtype( + dtype.elementType, version=version, spark_types=spark_types + ) + ) return dtypes.Unknown() @@ -105,8 +104,12 @@ def narwhals_to_native_dtype( msg = "Converting to Date or Datetime dtype is not supported yet" raise NotImplementedError(msg) if isinstance_or_issubclass(dtype, dtypes.List): # pragma: no cover - msg = "Converting to List dtype is not supported yet" - raise NotImplementedError(msg) + inner = narwhals_to_native_dtype( + dtype.inner, # type: ignore[union-attr] + version=version, + spark_types=spark_types, + ) + return spark_types.ArrayType(elementType=inner) if isinstance_or_issubclass(dtype, dtypes.Struct): # pragma: no cover msg = "Converting to Struct dtype is not supported yet" raise NotImplementedError(msg) diff --git a/tests/expr_and_series/list/len_test.py b/tests/expr_and_series/list/len_test.py index 375cfc7d88..7066fc6cf3 100644 --- a/tests/expr_and_series/list/len_test.py +++ b/tests/expr_and_series/list/len_test.py @@ -17,9 +17,7 @@ def test_len_expr( request: pytest.FixtureRequest, constructor: Constructor, ) -> None: - if any( - backend in str(constructor) for backend in ("dask", "modin", "cudf", "pyspark") - ): + if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")): request.applymarker(pytest.mark.xfail) if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2): diff --git a/tests/frame/explode_test.py b/tests/frame/explode_test.py index f3b0961944..db5a4fc5a3 100644 --- a/tests/frame/explode_test.py +++ b/tests/frame/explode_test.py @@ -40,7 +40,7 @@ def test_explode_single_col( ) -> None: if any( backend in str(constructor) - for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb", "pyspark") + for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb") ): request.applymarker(pytest.mark.xfail) @@ -110,7 +110,7 @@ def test_explode_shape_error( ) -> None: if any( backend in str(constructor) - for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb", "pyspark") + for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb") ): request.applymarker(pytest.mark.xfail) @@ -118,8 +118,8 @@ def test_explode_shape_error( request.applymarker(pytest.mark.xfail) with pytest.raises( - (ShapeError, PlShapeError), - match="exploded columns must have matching element counts", + (ShapeError, PlShapeError, NotImplementedError), + match=r".*exploded columns (must )?have matching element counts", ): _ = ( nw.from_native(constructor(data)) @@ -133,7 +133,7 @@ def test_explode_shape_error( def test_explode_invalid_operation_error( request: pytest.FixtureRequest, constructor: Constructor ) -> None: - if any(x in str(constructor) for x in ("pyarrow_table", "dask", "duckdb", "pyspark")): + if any(x in str(constructor) for x in ("pyarrow_table", "dask", "duckdb")): request.applymarker(pytest.mark.xfail) if "polars" in str(constructor) and POLARS_VERSION < (0, 20, 6): diff --git a/tests/utils.py b/tests/utils.py index f4f6126195..7174fbb9ee 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -66,7 +66,12 @@ def 
_sort_dict_by_key( ) -> dict[str, list[Any]]: # pragma: no cover sort_list = data_dict[key] sorted_indices = sorted( - range(len(sort_list)), key=lambda i: (sort_list[i] is None, sort_list[i]) + range(len(sort_list)), + key=lambda i: ( + (sort_list[i] is None) + or (isinstance(sort_list[i], float) and math.isnan(sort_list[i])), + sort_list[i], + ), ) return {key: [value[i] for i in sorted_indices] for key, value in data_dict.items()} From b53121316461d89f9dbad307fb403da1658104a5 Mon Sep 17 00:00:00 2001 From: Richard Iannone Date: Fri, 31 Jan 2025 15:23:37 -0500 Subject: [PATCH 11/13] docs(README.md): add mention of pointblank project (#1903) Add pointblank project repo to README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index eee90ebd96..533bd12de9 100644 --- a/README.md +++ b/README.md @@ -113,6 +113,7 @@ Join the party! - [marimo](https://github.com/marimo-team/marimo) - [panel-graphic-walker](https://github.com/panel-extensions/panel-graphic-walker) - [plotly](https://plotly.com) +- [pointblank](https://github.com/posit-dev/pointblank) - [pymarginaleffects](https://github.com/vincentarelbundock/pymarginaleffects) - [py-shiny](https://github.com/posit-dev/py-shiny) - [rio](https://github.com/rio-labs/rio) From 5f7fdabfdfcb9675a0d8f2e299d10b23338de80e Mon Sep 17 00:00:00 2001 From: Richard Iannone Date: Fri, 31 Jan 2025 16:49:55 -0500 Subject: [PATCH 12/13] docs(ecosystem.md): add `pointblank` to Used By section (#1904) Add `pointblank` to Used By section --- docs/ecosystem.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/ecosystem.md b/docs/ecosystem.md index 90de8f657b..7c9cbb6cc9 100644 --- a/docs/ecosystem.md +++ b/docs/ecosystem.md @@ -10,6 +10,7 @@ for their dataframe interoperability needs: * [marimo](https://github.com/marimo-team/marimo) * [panel-graphic-walker](https://github.com/panel-extensions/panel-graphic-walker) * [plotly](https://github.com/plotly/plotly.py) +* [pointblank](https://github.com/posit-dev/pointblank) * [pymarginaleffects](https://github.com/vincentarelbundock/pymarginaleffects) * [py-shiny](https://github.com/posit-dev/py-shiny) * [rio](https://github.com/rio-labs/rio) From 562b213c8e571b690db57b0145f3a0f023d7382e Mon Sep 17 00:00:00 2001 From: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> Date: Fri, 31 Jan 2025 23:07:30 +0100 Subject: [PATCH 13/13] docs: update selectors docstrings (#1905) * docs: update selector docstrings * mypy --- narwhals/_dask/selectors.py | 13 +- narwhals/_duckdb/selectors.py | 13 +- narwhals/_pandas_like/selectors.py | 13 +- narwhals/_spark_like/selectors.py | 13 +- narwhals/selectors.py | 244 ++++++++++++++++++++--------- 5 files changed, 196 insertions(+), 100 deletions(-) diff --git a/narwhals/_dask/selectors.py b/narwhals/_dask/selectors.py index 8645d77a96..1e49ceee9c 100644 --- a/narwhals/_dask/selectors.py +++ b/narwhals/_dask/selectors.py @@ -3,6 +3,7 @@ import re from typing import TYPE_CHECKING from typing import Any +from typing import Iterable from typing import Sequence from narwhals._dask.expr import DaskExpr @@ -28,7 +29,7 @@ def __init__( self._backend_version = backend_version self._version = version - def by_dtype(self: Self, dtypes: list[DType | type[DType]]) -> DaskSelector: + def by_dtype(self: Self, dtypes: Iterable[DType | type[DType]]) -> DaskSelector: def func(df: DaskLazyFrame) -> list[dx.Series]: return [ df._native_frame[col] for col in df.columns if df.schema[col] in dtypes @@ -73,7 +74,7 @@ def evalute_output_names(df: 
DaskLazyFrame) -> Sequence[str]: def numeric(self: Self) -> DaskSelector: dtypes = import_dtypes_module(self._version) return self.by_dtype( - [ + { dtypes.Int128, dtypes.Int64, dtypes.Int32, @@ -86,20 +87,20 @@ def numeric(self: Self) -> DaskSelector: dtypes.UInt8, dtypes.Float64, dtypes.Float32, - ], + }, ) def categorical(self: Self) -> DaskSelector: dtypes = import_dtypes_module(self._version) - return self.by_dtype([dtypes.Categorical]) + return self.by_dtype({dtypes.Categorical}) def string(self: Self) -> DaskSelector: dtypes = import_dtypes_module(self._version) - return self.by_dtype([dtypes.String]) + return self.by_dtype({dtypes.String}) def boolean(self: Self) -> DaskSelector: dtypes = import_dtypes_module(self._version) - return self.by_dtype([dtypes.Boolean]) + return self.by_dtype({dtypes.Boolean}) def all(self: Self) -> DaskSelector: def func(df: DaskLazyFrame) -> list[dx.Series]: diff --git a/narwhals/_duckdb/selectors.py b/narwhals/_duckdb/selectors.py index 30fd32e338..f11927ad42 100644 --- a/narwhals/_duckdb/selectors.py +++ b/narwhals/_duckdb/selectors.py @@ -3,6 +3,7 @@ import re from typing import TYPE_CHECKING from typing import Any +from typing import Iterable from typing import Sequence from duckdb import ColumnExpression @@ -27,7 +28,7 @@ def __init__( self._backend_version = backend_version self._version = version - def by_dtype(self: Self, dtypes: list[DType | type[DType]]) -> DuckDBSelector: + def by_dtype(self: Self, dtypes: Iterable[DType | type[DType]]) -> DuckDBSelector: def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return [ ColumnExpression(col) for col in df.columns if df.schema[col] in dtypes @@ -68,7 +69,7 @@ def evalute_output_names(df: DuckDBLazyFrame) -> Sequence[str]: def numeric(self: Self) -> DuckDBSelector: dtypes = import_dtypes_module(self._version) return self.by_dtype( - [ + { dtypes.Int128, dtypes.Int64, dtypes.Int32, @@ -81,20 +82,20 @@ def numeric(self: Self) -> DuckDBSelector: dtypes.UInt8, dtypes.Float64, dtypes.Float32, - ], + }, ) def categorical(self: Self) -> DuckDBSelector: # pragma: no cover dtypes = import_dtypes_module(self._version) - return self.by_dtype([dtypes.Categorical]) + return self.by_dtype({dtypes.Categorical}) def string(self: Self) -> DuckDBSelector: dtypes = import_dtypes_module(self._version) - return self.by_dtype([dtypes.String]) + return self.by_dtype({dtypes.String}) def boolean(self: Self) -> DuckDBSelector: dtypes = import_dtypes_module(self._version) - return self.by_dtype([dtypes.Boolean]) + return self.by_dtype({dtypes.Boolean}) def all(self: Self) -> DuckDBSelector: def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: diff --git a/narwhals/_pandas_like/selectors.py b/narwhals/_pandas_like/selectors.py index 2d90adc934..48bf63a04a 100644 --- a/narwhals/_pandas_like/selectors.py +++ b/narwhals/_pandas_like/selectors.py @@ -3,6 +3,7 @@ import re from typing import TYPE_CHECKING from typing import Any +from typing import Iterable from typing import Sequence from narwhals._pandas_like.expr import PandasLikeExpr @@ -30,7 +31,7 @@ def __init__( self._backend_version = backend_version self._version = version - def by_dtype(self: Self, dtypes: list[DType | type[DType]]) -> PandasSelector: + def by_dtype(self: Self, dtypes: Iterable[DType | type[DType]]) -> PandasSelector: def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return [df[col] for col in df.columns if df.schema[col] in dtypes] @@ -71,7 +72,7 @@ def evalute_output_names(df: PandasLikeDataFrame) -> Sequence[str]: def 
numeric(self: Self) -> PandasSelector: dtypes = import_dtypes_module(self._version) return self.by_dtype( - [ + { dtypes.Int128, dtypes.Int64, dtypes.Int32, @@ -84,20 +85,20 @@ def numeric(self: Self) -> PandasSelector: dtypes.UInt8, dtypes.Float64, dtypes.Float32, - ], + } ) def categorical(self: Self) -> PandasSelector: dtypes = import_dtypes_module(self._version) - return self.by_dtype([dtypes.Categorical]) + return self.by_dtype({dtypes.Categorical}) def string(self: Self) -> PandasSelector: dtypes = import_dtypes_module(self._version) - return self.by_dtype([dtypes.String]) + return self.by_dtype({dtypes.String}) def boolean(self: Self) -> PandasSelector: dtypes = import_dtypes_module(self._version) - return self.by_dtype([dtypes.Boolean]) + return self.by_dtype({dtypes.Boolean}) def all(self: Self) -> PandasSelector: def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: diff --git a/narwhals/_spark_like/selectors.py b/narwhals/_spark_like/selectors.py index b0ad7c594a..d2534cb149 100644 --- a/narwhals/_spark_like/selectors.py +++ b/narwhals/_spark_like/selectors.py @@ -3,6 +3,7 @@ import re from typing import TYPE_CHECKING from typing import Any +from typing import Iterable from typing import Sequence from narwhals._spark_like.expr import SparkLikeExpr @@ -31,7 +32,7 @@ def __init__( self._version = version self._implementation = implementation - def by_dtype(self: Self, dtypes: list[DType | type[DType]]) -> SparkLikeSelector: + def by_dtype(self: Self, dtypes: Iterable[DType | type[DType]]) -> SparkLikeSelector: def func(df: SparkLikeLazyFrame) -> list[Column]: return [df._F.col(col) for col in df.columns if df.schema[col] in dtypes] @@ -70,7 +71,7 @@ def evalute_output_names(df: SparkLikeLazyFrame) -> Sequence[str]: def numeric(self: Self) -> SparkLikeSelector: dtypes = import_dtypes_module(self._version) return self.by_dtype( - [ + { dtypes.Int128, dtypes.Int64, dtypes.Int32, @@ -83,20 +84,20 @@ def numeric(self: Self) -> SparkLikeSelector: dtypes.UInt8, dtypes.Float64, dtypes.Float32, - ], + }, ) def categorical(self: Self) -> SparkLikeSelector: dtypes = import_dtypes_module(self._version) - return self.by_dtype([dtypes.Categorical]) + return self.by_dtype({dtypes.Categorical}) def string(self: Self) -> SparkLikeSelector: dtypes = import_dtypes_module(self._version) - return self.by_dtype([dtypes.String]) + return self.by_dtype({dtypes.String}) def boolean(self: Self) -> SparkLikeSelector: dtypes = import_dtypes_module(self._version) - return self.by_dtype([dtypes.Boolean]) + return self.by_dtype({dtypes.Boolean}) def all(self: Self) -> SparkLikeSelector: def func(df: SparkLikeLazyFrame) -> list[Column]: diff --git a/narwhals/selectors.py b/narwhals/selectors.py index 6228030b4b..5c08103602 100644 --- a/narwhals/selectors.py +++ b/narwhals/selectors.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING from typing import Any +from typing import Iterable from typing import NoReturn from narwhals.expr import Expr @@ -10,6 +11,8 @@ if TYPE_CHECKING: from typing_extensions import Self + from narwhals.dtypes import DType + class Selector(Expr): def _to_expr(self: Self) -> Expr: @@ -36,7 +39,7 @@ def __ror__(self: Self, other: Any) -> NoReturn: raise NotImplementedError -def by_dtype(*dtypes: Any) -> Expr: +def by_dtype(*dtypes: DType | type[DType] | Iterable[DType | type[DType]]) -> Selector: """Select columns based on their dtype. Arguments: @@ -46,29 +49,34 @@ def by_dtype(*dtypes: Any) -> Expr: A new expression. 
Examples: - >>> import narwhals as nw - >>> import narwhals.selectors as ncs >>> import pandas as pd >>> import polars as pl + >>> import pyarrow as pa + >>> import narwhals as nw + >>> import narwhals.selectors as ncs + >>> from narwhals.typing import IntoFrameT >>> >>> data = {"a": [1, 2], "b": ["x", "y"], "c": [4.1, 2.3]} >>> df_pd = pd.DataFrame(data) >>> df_pl = pl.DataFrame(data) + >>> df_pa = pa.table(data) Let's define a dataframe-agnostic function to select int64 and float64 dtypes and multiplies each value by 2: - >>> @nw.narwhalify - ... def func(df): - ... return df.select(ncs.by_dtype(nw.Int64, nw.Float64) * 2) + >>> def agnostic_select_by_dtype(df_native: IntoFrameT) -> IntoFrameT: + ... df_nw = nw.from_native(df_native) + ... return df_nw.select(ncs.by_dtype(nw.Int64, nw.Float64) * 2).to_native() - We can then pass either pandas or Polars dataframes: + We can then pass any supported library such as pandas, Polars, or + PyArrow to `agnostic_select_by_dtype`: - >>> func(df_pd) + >>> agnostic_select_by_dtype(df_pd) a c 0 2 8.2 1 4 4.6 - >>> func(df_pl) + + >>> agnostic_select_by_dtype(df_pl) shape: (2, 2) ┌─────┬─────┐ │ a ┆ c │ @@ -78,6 +86,14 @@ def by_dtype(*dtypes: Any) -> Expr: │ 2 ┆ 8.2 │ │ 4 ┆ 4.6 │ └─────┴─────┘ + + >>> agnostic_select_by_dtype(df_pa) + pyarrow.Table + a: int64 + c: double + ---- + a: [[2,4]] + c: [[8.2,4.6]] """ return Selector( lambda plx: plx.selectors.by_dtype(flatten(dtypes)), @@ -87,7 +103,7 @@ def by_dtype(*dtypes: Any) -> Expr: ) -def matches(pattern: str) -> Expr: +def matches(pattern: str) -> Selector: """Select all columns that match the given regex pattern. Arguments: @@ -97,10 +113,12 @@ def matches(pattern: str) -> Expr: A new expression. Examples: - >>> import narwhals as nw - >>> import narwhals.selectors as ncs >>> import pandas as pd >>> import polars as pl + >>> import pyarrow as pa + >>> import narwhals as nw + >>> import narwhals.selectors as ncs + >>> from narwhals.typing import IntoFrameT >>> >>> data = { ... "foo": ["x", "y"], @@ -110,21 +128,24 @@ def matches(pattern: str) -> Expr: ... } >>> df_pd = pd.DataFrame(data) >>> df_pl = pl.DataFrame(data) + >>> df_pa = pa.table(data) Let's define a dataframe-agnostic function to select column names containing an 'a', preceded by a character that is not 'z': - >>> @nw.narwhalify - ... def func(df): - ... return df.select(ncs.matches("[^z]a")) + >>> def agnostic_select_match(df_native: IntoFrameT) -> IntoFrameT: + ... df_nw = nw.from_native(df_native) + ... return df_nw.select(ncs.matches("[^z]a")).to_native() - We can then pass either pandas or Polars dataframes: + We can then pass any supported library such as pandas, Polars, or + PyArrow to `agnostic_select_match`: - >>> func(df_pd) + >>> agnostic_select_match(df_pd) bar baz 0 123 2.0 1 456 5.5 - >>> func(df_pl) + + >>> agnostic_select_match(df_pl) shape: (2, 2) ┌─────┬─────┐ │ bar ┆ baz │ @@ -134,6 +155,14 @@ def matches(pattern: str) -> Expr: │ 123 ┆ 2.0 │ │ 456 ┆ 5.5 │ └─────┴─────┘ + + >>> agnostic_select_match(df_pa) + pyarrow.Table + bar: int64 + baz: double + ---- + bar: [[123,456]] + baz: [[2,5.5]] """ return Selector( lambda plx: plx.selectors.matches(pattern), @@ -143,36 +172,41 @@ def matches(pattern: str) -> Expr: ) -def numeric() -> Expr: +def numeric() -> Selector: """Select numeric columns. Returns: A new expression. 
Examples: - >>> import narwhals as nw - >>> import narwhals.selectors as ncs >>> import pandas as pd >>> import polars as pl + >>> import pyarrow as pa + >>> import narwhals as nw + >>> import narwhals.selectors as ncs + >>> from narwhals.typing import IntoFrameT >>> >>> data = {"a": [1, 2], "b": ["x", "y"], "c": [4.1, 2.3]} >>> df_pd = pd.DataFrame(data) >>> df_pl = pl.DataFrame(data) + >>> df_pa = pa.table(data) Let's define a dataframe-agnostic function to select numeric dtypes and multiplies each value by 2: - >>> @nw.narwhalify - ... def func(df): - ... return df.select(ncs.numeric() * 2) + >>> def agnostic_select_numeric(df_native: IntoFrameT) -> IntoFrameT: + ... df_nw = nw.from_native(df_native) + ... return df_nw.select(ncs.numeric() * 2).to_native() - We can then pass either pandas or Polars dataframes: + We can then pass any supported library such as pandas, Polars, or + PyArrow to `agnostic_select_numeric`: - >>> func(df_pd) + >>> agnostic_select_numeric(df_pd) a c 0 2 8.2 1 4 4.6 - >>> func(df_pl) + + >>> agnostic_select_numeric(df_pl) shape: (2, 2) ┌─────┬─────┐ │ a ┆ c │ @@ -182,6 +216,14 @@ def numeric() -> Expr: │ 2 ┆ 8.2 │ │ 4 ┆ 4.6 │ └─────┴─────┘ + + >>> agnostic_select_numeric(df_pa) + pyarrow.Table + a: int64 + c: double + ---- + a: [[2,4]] + c: [[8.2,4.6]] """ return Selector( lambda plx: plx.selectors.numeric(), @@ -191,36 +233,40 @@ def numeric() -> Expr: ) -def boolean() -> Expr: +def boolean() -> Selector: """Select boolean columns. Returns: A new expression. Examples: - >>> import narwhals as nw - >>> import narwhals.selectors as ncs >>> import pandas as pd >>> import polars as pl + >>> import pyarrow as pa + >>> import narwhals as nw + >>> import narwhals.selectors as ncs + >>> from narwhals.typing import IntoFrameT >>> >>> data = {"a": [1, 2], "b": ["x", "y"], "c": [False, True]} >>> df_pd = pd.DataFrame(data) >>> df_pl = pl.DataFrame(data) + >>> df_pa = pa.table(data) - Let's define a dataframe-agnostic function to select boolean - dtypes: + Let's define a dataframe-agnostic function to select boolean dtypes: - >>> @nw.narwhalify - ... def func(df): - ... return df.select(ncs.boolean()) + >>> def agnostic_select_boolean(df_native: IntoFrameT) -> IntoFrameT: + ... df_nw = nw.from_native(df_native) + ... return df_nw.select(ncs.boolean()).to_native() - We can then pass either pandas or Polars dataframes: + We can then pass any supported library such as pandas, Polars, or + PyArrow to `agnostic_select_boolean`: - >>> func(df_pd) + >>> agnostic_select_boolean(df_pd) c 0 False 1 True - >>> func(df_pl) + + >>> agnostic_select_boolean(df_pl) shape: (2, 1) ┌───────┐ │ c │ @@ -230,6 +276,12 @@ def boolean() -> Expr: │ false │ │ true │ └───────┘ + + >>> agnostic_select_boolean(df_pa) + pyarrow.Table + c: bool + ---- + c: [[false,true]] """ return Selector( lambda plx: plx.selectors.boolean(), @@ -239,36 +291,40 @@ def boolean() -> Expr: ) -def string() -> Expr: +def string() -> Selector: """Select string columns. Returns: A new expression. 
Examples: - >>> import narwhals as nw - >>> import narwhals.selectors as ncs >>> import pandas as pd >>> import polars as pl + >>> import pyarrow as pa + >>> import narwhals as nw + >>> import narwhals.selectors as ncs + >>> from narwhals.typing import IntoFrameT >>> >>> data = {"a": [1, 2], "b": ["x", "y"], "c": [False, True]} >>> df_pd = pd.DataFrame(data) >>> df_pl = pl.DataFrame(data) + >>> df_pa = pa.table(data) - Let's define a dataframe-agnostic function to select string - dtypes: + Let's define a dataframe-agnostic function to select string dtypes: - >>> @nw.narwhalify - ... def func(df): - ... return df.select(ncs.string()) + >>> def agnostic_select_string(df_native: IntoFrameT) -> IntoFrameT: + ... df_nw = nw.from_native(df_native) + ... return df_nw.select(ncs.string()).to_native() - We can then pass either pandas or Polars dataframes: + We can then pass any supported library such as pandas, Polars, or + PyArrow to `agnostic_select_string`: - >>> func(df_pd) + >>> agnostic_select_string(df_pd) b 0 x 1 y - >>> func(df_pl) + + >>> agnostic_select_string(df_pl) shape: (2, 1) ┌─────┐ │ b │ @@ -278,6 +334,12 @@ def string() -> Expr: │ x │ │ y │ └─────┘ + + >>> agnostic_select_string(df_pa) + pyarrow.Table + b: string + ---- + b: [["x","y"]] """ return Selector( lambda plx: plx.selectors.string(), @@ -287,36 +349,43 @@ def string() -> Expr: ) -def categorical() -> Expr: +def categorical() -> Selector: """Select categorical columns. Returns: A new expression. Examples: - >>> import narwhals as nw - >>> import narwhals.selectors as ncs >>> import pandas as pd >>> import polars as pl + >>> import pyarrow as pa + >>> import narwhals as nw + >>> import narwhals.selectors as ncs + >>> from narwhals.typing import IntoFrameT >>> >>> data = {"a": [1, 2], "b": ["x", "y"], "c": [False, True]} - >>> df_pd = pd.DataFrame(data).astype({"b": "category"}) - >>> df_pl = pl.DataFrame(data, schema_overrides={"b": pl.Categorical}) + >>> df_pd = pd.DataFrame(data) + >>> df_pl = pl.DataFrame(data) + >>> df_pa = pa.table(data) - Let's define a dataframe-agnostic function to select string - dtypes: + Let's define a dataframe-agnostic function that first converts column "b" to + categorical, and then selects categorical dtypes: - >>> @nw.narwhalify - ... def func(df): - ... return df.select(ncs.categorical()) + >>> def agnostic_select_categorical(df_native: IntoFrameT) -> IntoFrameT: + ... df_nw = nw.from_native(df_native).with_columns( + ... b=nw.col("b").cast(nw.Categorical()) + ... ) + ... return df_nw.select(ncs.categorical()).to_native() - We can then pass either pandas or Polars dataframes: + We can then pass any supported library such as pandas, Polars, or + PyArrow to `agnostic_select_categorical`: - >>> func(df_pd) + >>> agnostic_select_categorical(df_pd) b 0 x 1 y - >>> func(df_pl) + + >>> agnostic_select_categorical(df_pl) shape: (2, 1) ┌─────┐ │ b │ @@ -326,6 +395,14 @@ def categorical() -> Expr: │ x │ │ y │ └─────┘ + + >>> agnostic_select_categorical(df_pa) + pyarrow.Table + b: dictionary + ---- + b: [ -- dictionary: + ["x","y"] -- indices: + [0,1]] """ return Selector( lambda plx: plx.selectors.categorical(), @@ -335,45 +412,59 @@ def categorical() -> Expr: ) -def all() -> Expr: +def all() -> Selector: """Select all columns. Returns: A new expression. 
Examples: - >>> import narwhals as nw - >>> import narwhals.selectors as ncs >>> import pandas as pd >>> import polars as pl + >>> import pyarrow as pa + >>> import narwhals as nw + >>> import narwhals.selectors as ncs + >>> from narwhals.typing import IntoFrameT >>> >>> data = {"a": [1, 2], "b": ["x", "y"], "c": [False, True]} - >>> df_pd = pd.DataFrame(data).astype({"b": "category"}) - >>> df_pl = pl.DataFrame(data, schema_overrides={"b": pl.Categorical}) + >>> df_pd = pd.DataFrame(data) + >>> df_pl = pl.DataFrame(data) + >>> df_pa = pa.table(data) - Let's define a dataframe-agnostic function to select string - dtypes: + Let's define a dataframe-agnostic function to select all dtypes: - >>> @nw.narwhalify - ... def func(df): - ... return df.select(ncs.all()) + >>> def agnostic_select_all(df_native: IntoFrameT) -> IntoFrameT: + ... df_nw = nw.from_native(df_native) + ... return df_nw.select(ncs.all()).to_native() - We can then pass either pandas or Polars dataframes: + We can then pass any supported library such as pandas, Polars, or + PyArrow to `agnostic_select_all`: - >>> func(df_pd) + >>> agnostic_select_all(df_pd) a b c 0 1 x False 1 2 y True - >>> func(df_pl) + + >>> agnostic_select_all(df_pl) shape: (2, 3) ┌─────┬─────┬───────┐ │ a ┆ b ┆ c │ │ --- ┆ --- ┆ --- │ - │ i64 ┆ cat ┆ bool │ + │ i64 ┆ str ┆ bool │ ╞═════╪═════╪═══════╡ │ 1 ┆ x ┆ false │ │ 2 ┆ y ┆ true │ └─────┴─────┴───────┘ + + >>> agnostic_select_all(df_pa) + pyarrow.Table + a: int64 + b: string + c: bool + ---- + a: [[1,2]] + b: [["x","y"]] + c: [[false,true]] """ return Selector( lambda plx: plx.selectors.all(), @@ -388,6 +479,7 @@ def all() -> Expr: "boolean", "by_dtype", "categorical", + "matches", "numeric", "string", ]
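As a usage note for the `.list` namespace and `explode` support introduced a few patches back, the public API those tests exercise looks roughly like the sketch below. Polars stands in as the backend purely for illustration (the new code paths in this series target the Spark-like backends), and the data and names are made up:

```python
# Sketch only: any supported backend could replace Polars here; data is illustrative.
import polars as pl
import narwhals as nw

data = {"a": ["x", "y"], "l": [[1, 2, 3], [4, None]]}
lf = nw.from_native(pl.LazyFrame(data))

# `.list.len()` returns the number of elements in each list row.
lengths = lf.with_columns(nw.col("l").list.len().alias("len"))

# `explode` expands each list element into its own row, akin to F.explode_outer above.
exploded = lf.explode("l")

print(lengths.to_native().collect())
print(exploded.to_native().collect())
```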