Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference/expr.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- alias
- all
- any
- any_value
- cast
- ceil
- clip
Expand Down
1 change: 1 addition & 0 deletions docs/api-reference/series.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
- alias
- all
- any
- any_value
- arg_max
- arg_min
- arg_true
Expand Down
17 changes: 13 additions & 4 deletions narwhals/_arrow/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr", "Aggregation"]):
"any": "any",
"first": "first",
"last": "last",
"any_value": "first",
}
_REMAP_UNIQUE: ClassVar[Mapping[UniqueKeepStrategy, Aggregation]] = {
"any": "min",
Expand All @@ -52,7 +53,7 @@ class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr", "Aggregation"]):
)
_OPTION_COUNT_VALID: ClassVar[frozenset[NarwhalsAggregation]] = frozenset(("count",))
_OPTION_ORDERED: ClassVar[frozenset[NarwhalsAggregation]] = frozenset(
("first", "last")
("first", "last", "any_value")
)
_OPTION_VARIANCE: ClassVar[frozenset[NarwhalsAggregation]] = frozenset(("std", "var"))
_OPTION_SCALAR: ClassVar[frozenset[NarwhalsAggregation]] = frozenset(
Expand Down Expand Up @@ -89,11 +90,19 @@ def _configure_agg(
elif function_name in self._OPTION_SCALAR:
option = pc.ScalarAggregateOptions(min_count=0)
elif function_name in self._OPTION_ORDERED:
grouped, option = self._ordered_agg(grouped, function_name)
ignore_nulls = kwargs.get("ignore_nulls", False)
grouped, option = self._ordered_agg(
grouped, function_name, ignore_nulls=ignore_nulls
)
return grouped, self._remap_expr_name(function_name), option

def _ordered_agg(
self, grouped: pa.TableGroupBy, name: NarwhalsAggregation, /
self,
grouped: pa.TableGroupBy,
name: NarwhalsAggregation,
/,
*,
ignore_nulls: bool,
) -> tuple[pa.TableGroupBy, AggregateOptions]:
"""The default behavior of `pyarrow` raises when `first` or `last` are used.

Expand All @@ -117,7 +126,7 @@ def _ordered_agg(
f"See https://github.com/apache/arrow/issues/36709"
)
raise NotImplementedError(msg)
return grouped, pc.ScalarAggregateOptions(skip_nulls=False)
return grouped, pc.ScalarAggregateOptions(skip_nulls=ignore_nulls)

def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame:
self._ensure_all_simple(exprs)
Expand Down
5 changes: 5 additions & 0 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,6 +1104,11 @@ def exp(self) -> Self:
def sqrt(self) -> Self:
return self._with_native(pc.sqrt(self.native))

def any_value(
self, *, ignore_nulls: bool, _return_py_scalar: bool = True
) -> PythonLiteral:
return self.drop_nulls().first() if ignore_nulls else self.first()

@property
def dt(self) -> ArrowSeriesDateTimeNamespace:
return ArrowSeriesDateTimeNamespace(self)
Expand Down
6 changes: 6 additions & 0 deletions narwhals/_compliant/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def broadcast(self) -> Self: ...
def alias(self, name: str) -> Self: ...
def all(self) -> Self: ...
def any(self) -> Self: ...
def any_value(self, *, ignore_nulls: bool) -> Self: ...
def count(self) -> Self: ...
def min(self) -> Self: ...
def max(self) -> Self: ...
Expand Down Expand Up @@ -824,6 +825,11 @@ def first(self) -> Self:
def last(self) -> Self:
return self._reuse_series("last", returns_scalar=True)

def any_value(self, *, ignore_nulls: bool) -> Self:
return self._reuse_series(
"any_value", returns_scalar=True, ignore_nulls=ignore_nulls
)

@property
def cat(self) -> EagerExprCatNamespace[Self]:
return EagerExprCatNamespace(self)
Expand Down
1 change: 1 addition & 0 deletions narwhals/_compliant/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def __rmul__(self, other: Any) -> Self: ...
def __ror__(self, other: Any) -> Self: ...
def all(self) -> bool: ...
def any(self) -> bool: ...
def any_value(self, *, ignore_nulls: bool) -> PythonLiteral: ...
def arg_max(self) -> int: ...
def arg_min(self) -> int: ...
def arg_true(self) -> Self: ...
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_compliant/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,11 @@ class ScalarKwargs(TypedDict, total=False):
"any",
"first",
"last",
"any_value",
]
"""`Expr` methods we aim to support in `DepthTrackingGroupBy`.

Be sure to update me if you're working on one of these:
- https://github.com/narwhals-dev/narwhals/issues/981
- https://github.com/narwhals-dev/narwhals/issues/2385
- https://github.com/narwhals-dev/narwhals/issues/2484
- https://github.com/narwhals-dev/narwhals/issues/2526
Expand Down
1 change: 1 addition & 0 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,7 @@ def str(self) -> DaskExprStringNamespace:
def dt(self) -> DaskExprDateTimeNamespace:
return DaskExprDateTimeNamespace(self)

any_value = not_implemented()
filter = not_implemented()
first = not_implemented()
rank = not_implemented()
Expand Down
9 changes: 9 additions & 0 deletions narwhals/_duckdb/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,15 @@ def _first(self, expr: Expression, *order_by: str) -> Expression:
def _last(self, expr: Expression, *order_by: str) -> Expression:
return self._first_last("last", expr, order_by)

def _any_value(self, expr: Expression, *, ignore_nulls: bool) -> Expression:
# !NOTE: DuckDB `any_value` returns the first non-null value
# See: https://duckdb.org/docs/stable/sql/functions/aggregates#any_valuearg
return (
self._function("any_value", expr)
if ignore_nulls
else self._function("first", expr)
)

def __narwhals_namespace__(self) -> DuckDBNamespace: # pragma: no cover
from narwhals._duckdb.namespace import DuckDBNamespace

Expand Down
6 changes: 6 additions & 0 deletions narwhals/_ibis/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,12 @@ def _last(self, expr: ir.Value, *order_by: str) -> ir.Value:
order_by=self._sort(*order_by), include_null=True
)

def _any_value(self, expr: ir.Value, *, ignore_nulls: bool) -> ir.Value:
# !NOTE: ibis arbitrary returns a random non-null value
# See: https://ibis-project.org/reference/expression-generic.html#ibis.expr.types.generic.Column.arbitrary
expr = cast("ir.Column", expr)
return expr.arbitrary() if ignore_nulls else expr.first(include_null=True)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

expr.first(include_null=not ignore_nulls) is not behaving as expected 🫠


def __narwhals_namespace__(self) -> IbisNamespace: # pragma: no cover
from narwhals._ibis.namespace import IbisNamespace

Expand Down
19 changes: 12 additions & 7 deletions narwhals/_pandas_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,13 @@ def window_kwargs_to_pandas_equivalent( # noqa: C901
"min_periods": kwargs["min_samples"],
"ignore_na": kwargs["ignore_nulls"],
}
elif function_name in {"first", "last"}:
elif function_name in {"first", "last", "any_value"}:
if kwargs.get("ignore_nulls"):
msg = (
"`Expr.any_value(ignore_nulls=True)` is not supported in a `over` "
"context for pandas-like backend."
)
raise NotImplementedError(msg)
Comment on lines +113 to +118
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an issue if a group has all null values

pandas_kwargs = {
"n": _REMAP_ORDERED_INDEX[cast("NarwhalsAggregation", function_name)]
}
Expand Down Expand Up @@ -357,17 +363,16 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: # noqa: C901,
msg = "Safety check failed, please report a bug."
raise AssertionError(msg)
res_native = grouped.transform("size").to_frame(aliases[0])
elif function_name in {"first", "last"}:
elif function_name in {"first", "last", "any_value"}:
with warnings.catch_warnings():
# Ignore settingwithcopy warnings/errors, they're false-positives here.
warnings.filterwarnings("ignore", message="\n.*copy of a slice")
_nth = getattr(
_agg = getattr(
grouped[[*partition_by, *aliases]], pandas_function_name
)(**pandas_kwargs)
_nth.reset_index(drop=True, inplace=True)
res_native = df.native[list(partition_by)].merge(
_nth, on=list(partition_by)
)[list(aliases)]
_agg.reset_index(drop=True, inplace=True)
keys = list(partition_by)
res_native = df.native[keys].merge(_agg, on=keys)[list(aliases)]
else:
res_native = grouped[list(aliases)].transform(
pandas_function_name, **pandas_kwargs
Expand Down
13 changes: 12 additions & 1 deletion narwhals/_pandas_like/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
_REMAP_ORDERED_INDEX: Mapping[NarwhalsAggregation, Literal[0, -1]] = {
"first": 0,
"last": -1,
"any_value": 0,
}


Expand Down Expand Up @@ -151,7 +152,7 @@ def _getitem_aggs(
for col in cols
]
)
elif self.is_last() or self.is_first():
elif self.is_last() or self.is_first() or self.is_any_value():
result = self.native_agg()(group_by._grouped[[*group_by._keys, *names]])
result.set_index(group_by._keys, inplace=True) # noqa: PD002
else:
Expand All @@ -175,6 +176,9 @@ def is_first(self) -> bool:
def is_mode(self) -> bool:
return self.leaf_name == "mode"

def is_any_value(self) -> bool:
return self.leaf_name == "any_value"

def is_top_level_function(self) -> bool:
# e.g. `nw.len()`.
return len(list(self.expr._metadata.op_nodes_reversed())) == 1
Expand All @@ -191,6 +195,12 @@ def native_agg(self) -> _NativeAgg:
native_name = PandasLikeGroupBy._remap_expr_name(self.leaf_name)
last_node = next(self.expr._metadata.op_nodes_reversed())
if self.leaf_name in _REMAP_ORDERED_INDEX:
if last_node.kwargs.get("ignore_nulls"):
msg = (
"`Expr.any_value(ignore_nulls=True)` is not supported in a `group_by` "
"context for pandas-like backend"
)
raise NotImplementedError(msg)
Comment on lines +198 to +203
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an issue if a group has all null values

return methodcaller("nth", n=_REMAP_ORDERED_INDEX[self.leaf_name])
return _native_agg(native_name, **last_node.kwargs)

Expand All @@ -215,6 +225,7 @@ class PandasLikeGroupBy(
"any": "any",
"first": "nth",
"last": "nth",
"any_value": "nth",
}
_original_columns: tuple[str, ...]
"""Column names *prior* to any aliasing in `ParseKeysGroupBy`."""
Expand Down
3 changes: 3 additions & 0 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,6 +1152,9 @@ def exp(self) -> Self:
def sqrt(self) -> Self:
return self._with_native(self.native.pow(0.5))

def any_value(self, *, ignore_nulls: bool) -> PythonLiteral:
return self.drop_nulls().first() if ignore_nulls else self.first()

@property
def str(self) -> PandasLikeSeriesStringNamespace:
return PandasLikeSeriesStringNamespace(self)
Expand Down
3 changes: 3 additions & 0 deletions narwhals/_polars/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,9 @@ def mode(self, *, keep: ModeKeepStrategy) -> Self:
result = self.native.mode()
return self._with_native(result.first() if keep == "any" else result)

def any_value(self, *, ignore_nulls: bool) -> Self:
return self.drop_nulls().first() if ignore_nulls else self.first()

@property
def dt(self) -> PolarsExprDateTimeNamespace:
return PolarsExprDateTimeNamespace(self)
Expand Down
3 changes: 3 additions & 0 deletions narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,9 @@ def last(self) -> PythonLiteral:
return self.native.item(-1) if len(self) else None
return self.native.last() # type: ignore[return-value]

def any_value(self, *, ignore_nulls: bool) -> PythonLiteral:
return self.drop_nulls().first() if ignore_nulls else self.first()

@property
def dt(self) -> PolarsSeriesDateTimeNamespace:
return PolarsSeriesDateTimeNamespace(self)
Expand Down
3 changes: 3 additions & 0 deletions narwhals/_spark_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ def _last(self, expr: Column, *order_by: str) -> Column: # pragma: no cover
msg = "`last` is not supported for PySpark."
raise NotImplementedError(msg)

def _any_value(self, expr: Column, *, ignore_nulls: bool) -> Column:
return self._F.any_value(expr, ignoreNulls=ignore_nulls)

def broadcast(self) -> Self:
return self.over([self._F.lit(1)], [])

Expand Down
6 changes: 6 additions & 0 deletions narwhals/_sql/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def _coalesce(self, *expr: NativeExprT) -> NativeExprT:
def _count_star(self) -> NativeExprT: ...
def _first(self, expr: NativeExprT, *order_by: str) -> NativeExprT: ...
def _last(self, expr: NativeExprT, *order_by: str) -> NativeExprT: ...
def _any_value(self, expr: NativeExprT, *, ignore_nulls: bool) -> NativeExprT: ...

def _when(
self,
Expand Down Expand Up @@ -759,6 +760,11 @@ def func(

return self._with_window_function(func)

def any_value(self, *, ignore_nulls: bool) -> Self:
return self._with_callable(
lambda expr: self._any_value(expr, ignore_nulls=ignore_nulls)
)

def rank(self, method: RankMethod, *, descending: bool) -> Self:
if method in {"min", "max", "average"}:
func = self._function("rank")
Expand Down
42 changes: 42 additions & 0 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2368,6 +2368,48 @@ def is_close( # noqa: PLR0914

return result

def any_value(self, *, ignore_nulls: bool = False) -> Self:
"""Get a random value from the column.

Arguments:
ignore_nulls: Whether to ignore null values or not.
If `True` and there are no not-null elements, then `None` is returned.

Examples:
>>> import pyarrow as pa
>>> import narwhals as nw
>>> data = {"a": [1, 1, 2, 2], "b": [None, "foo", "baz", None]}
>>> df_native = pa.table(data)
>>> df = nw.from_native(df_native)
>>> df.select(nw.all().any_value(ignore_nulls=False))
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
|Narwhals DataFrame|
|------------------|
| pyarrow.Table |
| a: int64 |
| b: null |
| ---- |
| a: [[1]] |
| b: [1 nulls] |
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

>>> df.group_by("a").agg(nw.col("b").any_value(ignore_nulls=True))
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
|Narwhals DataFrame|
|------------------|
|pyarrow.Table |
|a: int64 |
|b: string |
|---- |
|a: [[1,2]] |
|b: [["foo","baz"]]|
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

"""
return self._append_node(
ExprNode(ExprKind.AGGREGATION, "any_value", ignore_nulls=ignore_nulls)
)

@property
def str(self) -> ExprStringNamespace[Self]:
return ExprStringNamespace(self)
Expand Down
17 changes: 17 additions & 0 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2844,6 +2844,23 @@ def is_close(
result = result.rename(orig_name) if name_is_none else result
return cast("Self", result)

def any_value(self, *, ignore_nulls: bool = False) -> PythonLiteral:
"""Get a random value from the column.

Arguments:
ignore_nulls: Whether to ignore null values or not.
If `True` and there are no not-null elements, then `None` is returned.

Examples:
>>> import pyarrow as pa
>>> import narwhals as nw
>>> s_native = pa.chunked_array([[1, 2, None]])
>>> s = nw.from_native(s_native, series_only=True)
>>> s.any_value(ignore_nulls=True)
1
"""
return self._compliant_series.any_value(ignore_nulls=ignore_nulls)

@property
def str(self) -> SeriesStringNamespace[Self]:
return SeriesStringNamespace(self)
Expand Down
Loading
Loading