Skip to content

Commit 1d9ac1c

Browse files
Nathan-Bransby-NMTdevdanzinpre-commit-ci[bot]MarcoGorelli
authored
feat: add Series and Expr arg_min & arg_max (#1529)
--------- Co-authored-by: devdanzin <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Marco Gorelli <[email protected]>
1 parent b567aff commit 1d9ac1c

File tree

11 files changed

+281
-4
lines changed

11 files changed

+281
-4
lines changed

docs/api-reference/expr.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
- alias
99
- all
1010
- any
11+
- arg_max
12+
- arg_min
1113
- arg_true
1214
- cast
1315
- count

docs/api-reference/series.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
- alias
1212
- all
1313
- any
14+
- arg_max
15+
- arg_min
1416
- arg_true
1517
- cast
1618
- clip

narwhals/_arrow/expr.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,12 @@ def min(self: Self) -> Self:
272272
def max(self: Self) -> Self:
273273
return reuse_series_implementation(self, "max", returns_scalar=True)
274274

275+
def arg_min(self: Self) -> Self:
276+
return reuse_series_implementation(self, "arg_min", returns_scalar=True)
277+
278+
def arg_max(self: Self) -> Self:
279+
return reuse_series_implementation(self, "arg_max", returns_scalar=True)
280+
275281
def all(self: Self) -> Self:
276282
return reuse_series_implementation(self, "all", returns_scalar=True)
277283

narwhals/_arrow/series.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,18 @@ def max(self: Self, *, _return_py_scalar: bool = True) -> int:
288288

289289
return maybe_extract_py_scalar(pc.max(self._native_series), _return_py_scalar) # type: ignore[no-any-return]
290290

291+
def arg_min(self: Self, *, _return_py_scalar: bool = True) -> int:
292+
import pyarrow.compute as pc
293+
294+
index_min = pc.index(self._native_series, pc.min(self._native_series))
295+
return maybe_extract_py_scalar(index_min, _return_py_scalar) # type: ignore[no-any-return]
296+
297+
def arg_max(self: Self, *, _return_py_scalar: bool = True) -> int:
298+
import pyarrow.compute as pc
299+
300+
index_max = pc.index(self._native_series, pc.max(self._native_series))
301+
return maybe_extract_py_scalar(index_max, _return_py_scalar) # type: ignore[no-any-return]
302+
291303
def sum(self: Self, *, _return_py_scalar: bool = True) -> int:
292304
import pyarrow.compute as pc
293305

narwhals/_pandas_like/expr.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,12 @@ def max(self) -> Self:
264264
def min(self) -> Self:
265265
return reuse_series_implementation(self, "min", returns_scalar=True)
266266

267+
def arg_min(self) -> Self:
268+
return reuse_series_implementation(self, "arg_min", returns_scalar=True)
269+
270+
def arg_max(self) -> Self:
271+
return reuse_series_implementation(self, "arg_max", returns_scalar=True)
272+
267273
# Other
268274

269275
def clip(self, lower_bound: Any, upper_bound: Any) -> Self:

narwhals/_pandas_like/series.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,18 @@ def arg_true(self) -> PandasLikeSeries:
284284
result = ser.__class__(range(len(ser)), name=ser.name, index=ser.index).loc[ser]
285285
return self._from_native_series(result)
286286

287+
def arg_min(self) -> int:
288+
ser = self._native_series
289+
if self._implementation is Implementation.PANDAS and self._backend_version < (1,):
290+
return ser.values.argmin() # type: ignore[no-any-return] # noqa: PD011
291+
return ser.argmin() # type: ignore[no-any-return]
292+
293+
def arg_max(self) -> int:
294+
ser = self._native_series
295+
if self._implementation is Implementation.PANDAS and self._backend_version < (1,):
296+
return ser.values.argmax() # type: ignore[no-any-return] # noqa: PD011
297+
return ser.argmax() # type: ignore[no-any-return]
298+
287299
# Binary comparisons
288300

289301
def filter(self, other: Any) -> PandasLikeSeries:

narwhals/expr.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -962,6 +962,102 @@ def max(self) -> Self:
962962
"""
963963
return self.__class__(lambda plx: self._to_compliant_expr(plx).max())
964964

965+
def arg_min(self) -> Self:
966+
"""Returns the index of the minimum value.
967+
968+
Returns:
969+
A new expression.
970+
971+
Examples:
972+
>>> import polars as pl
973+
>>> import pandas as pd
974+
>>> import pyarrow as pa
975+
>>> import narwhals as nw
976+
>>> from narwhals.typing import IntoFrameT
977+
>>> df_pd = pd.DataFrame({"a": [10, 20], "b": [150, 100]})
978+
>>> df_pl = pl.DataFrame({"a": [10, 20], "b": [150, 100]})
979+
>>> df_pa = pa.table({"a": [10, 20], "b": [150, 100]})
980+
981+
Let's define a dataframe-agnostic function:
982+
983+
>>> def agnostic_arg_min(df_native: IntoFrameT) -> IntoFrameT:
984+
... df = nw.from_native(df_native)
985+
... return df.select(
986+
... nw.col("a", "b").arg_min().name.suffix("_arg_min")
987+
... ).to_native()
988+
989+
We can then pass any supported library such as Pandas, Polars, or PyArrow:
990+
991+
>>> agnostic_arg_min(df_pd)
992+
a_arg_min b_arg_min
993+
0 0 1
994+
>>> agnostic_arg_min(df_pl)
995+
shape: (1, 2)
996+
┌───────────┬───────────┐
997+
│ a_arg_min ┆ b_arg_min │
998+
│ --- ┆ --- │
999+
│ u32 ┆ u32 │
1000+
╞═══════════╪═══════════╡
1001+
│ 0 ┆ 1 │
1002+
└───────────┴───────────┘
1003+
>>> agnostic_arg_min(df_pa)
1004+
pyarrow.Table
1005+
a_arg_min: int64
1006+
b_arg_min: int64
1007+
----
1008+
a_arg_min: [[0]]
1009+
b_arg_min: [[1]]
1010+
"""
1011+
return self.__class__(lambda plx: self._to_compliant_expr(plx).arg_min())
1012+
1013+
def arg_max(self) -> Self:
1014+
"""Returns the index of the maximum value.
1015+
1016+
Returns:
1017+
A new expression.
1018+
1019+
Examples:
1020+
>>> import polars as pl
1021+
>>> import pandas as pd
1022+
>>> import pyarrow as pa
1023+
>>> import narwhals as nw
1024+
>>> from narwhals.typing import IntoFrameT
1025+
>>> df_pd = pd.DataFrame({"a": [10, 20], "b": [150, 100]})
1026+
>>> df_pl = pl.DataFrame({"a": [10, 20], "b": [150, 100]})
1027+
>>> df_pa = pa.table({"a": [10, 20], "b": [150, 100]})
1028+
1029+
Let's define a dataframe-agnostic function:
1030+
1031+
>>> def agnostic_arg_max(df_native: IntoFrameT) -> IntoFrameT:
1032+
... df = nw.from_native(df_native)
1033+
... return df.select(
1034+
... nw.col("a", "b").arg_max().name.suffix("_arg_max")
1035+
... ).to_native()
1036+
1037+
We can then pass any supported library such as Pandas, Polars, or PyArrow:
1038+
1039+
>>> agnostic_arg_max(df_pd)
1040+
a_arg_max b_arg_max
1041+
0 1 0
1042+
>>> agnostic_arg_max(df_pl)
1043+
shape: (1, 2)
1044+
┌───────────┬───────────┐
1045+
│ a_arg_max ┆ b_arg_max │
1046+
│ --- ┆ --- │
1047+
│ u32 ┆ u32 │
1048+
╞═══════════╪═══════════╡
1049+
│ 1 ┆ 0 │
1050+
└───────────┴───────────┘
1051+
>>> agnostic_arg_max(df_pa)
1052+
pyarrow.Table
1053+
a_arg_max: int64
1054+
b_arg_max: int64
1055+
----
1056+
a_arg_max: [[1]]
1057+
b_arg_max: [[0]]
1058+
"""
1059+
return self.__class__(lambda plx: self._to_compliant_expr(plx).arg_max())
1060+
9651061
def count(self) -> Self:
9661062
"""Returns the number of non-null elements in the column.
9671063

narwhals/series.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -931,6 +931,70 @@ def max(self) -> Any:
931931
"""
932932
return self._compliant_series.max()
933933

934+
def arg_min(self) -> int:
935+
"""Returns the index of the minimum value.
936+
937+
Examples:
938+
>>> import pandas as pd
939+
>>> import polars as pl
940+
>>> import pyarrow as pa
941+
>>> import narwhals as nw
942+
>>> from narwhals.typing import IntoSeries
943+
>>> s = [1, 2, 3]
944+
>>> s_pd = pd.Series(s)
945+
>>> s_pl = pl.Series(s)
946+
>>> s_pa = pa.chunked_array([s])
947+
948+
We define a library agnostic function:
949+
950+
>>> def agnostic_arg_min(s_native: IntoSeries):
951+
... s = nw.from_native(s_native, series_only=True)
952+
... return s.arg_min()
953+
954+
We can then pass either any supported library such as pandas, Polars,
955+
or PyArrow:
956+
957+
>>> agnostic_arg_min(s_pd)
958+
np.int64(0)
959+
>>> agnostic_arg_min(s_pl)
960+
0
961+
>>> agnostic_arg_min(s_pa)
962+
0
963+
"""
964+
return self._compliant_series.arg_min() # type: ignore[no-any-return]
965+
966+
def arg_max(self) -> int:
967+
"""Returns the index of the maximum value.
968+
969+
Examples:
970+
>>> import pandas as pd
971+
>>> import polars as pl
972+
>>> import pyarrow as pa
973+
>>> import narwhals as nw
974+
>>> from narwhals.typing import IntoSeries
975+
>>> s = [1, 2, 3]
976+
>>> s_pd = pd.Series(s)
977+
>>> s_pl = pl.Series(s)
978+
>>> s_pa = pa.chunked_array([s])
979+
980+
We define a library agnostic function:
981+
982+
>>> def agnostic_arg_max(s_native: IntoSeries):
983+
... s = nw.from_native(s_native, series_only=True)
984+
... return s.arg_max()
985+
986+
We can then pass either any supported library such as pandas, Polars,
987+
or PyArrow:
988+
989+
>>> agnostic_arg_max(s_pd)
990+
np.int64(2)
991+
>>> agnostic_arg_max(s_pl)
992+
2
993+
>>> agnostic_arg_max(s_pa)
994+
2
995+
"""
996+
return self._compliant_series.arg_max() # type: ignore[no-any-return]
997+
934998
def sum(self) -> Any:
935999
"""Reduce this Series to the sum value.
9361000

pyproject.toml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,20 +142,19 @@ omit = [
142142
'narwhals/typing.py',
143143
'narwhals/stable/v1/typing.py',
144144
'narwhals/this.py',
145-
# we can run this in every environment that we measure coverage on due to upper-bound constraits
145+
# we can't run this in every environment that we measure coverage on due to upper-bound constraits
146146
'narwhals/_ibis/*',
147147
# the latest pyspark (3.5) doesn't officially support Python 3.12 and 3.13
148148
'narwhals/_spark_like/*',
149149
'tests/spark_like_test.py',
150150
]
151151
exclude_also = [
152-
"> POLARS_VERSION",
153152
"if sys.version_info() <",
154153
"if (:?self._)?implementation is Implementation.MODIN",
155154
"if (:?self._)?implementation is Implementation.CUDF",
156155
'request.applymarker\(pytest.mark.xfail\)',
157-
'if \w+._backend_version < ',
158-
'if backend_version <',
156+
'\w+._backend_version < ',
157+
'backend_version <',
159158
'if "cudf" in str\(constructor'
160159
]
161160

tests/expr_and_series/arg_max_test.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
5+
import narwhals.stable.v1 as nw
6+
from tests.utils import Constructor
7+
from tests.utils import ConstructorEager
8+
from tests.utils import assert_equal_data
9+
10+
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9], "i": [3, 1, 5]}
11+
12+
13+
def test_expr_arg_max_expr(
14+
constructor: Constructor, request: pytest.FixtureRequest
15+
) -> None:
16+
if "dask" in str(constructor):
17+
# This operation is row-order dependent so we don't support it for Dask
18+
request.applymarker(pytest.mark.xfail)
19+
if "modin" in str(constructor):
20+
# TODO(unassigned): bug in modin?
21+
return
22+
df = nw.from_native(constructor(data))
23+
df = nw.maybe_set_index(df, "i")
24+
result = df.select(nw.col("a", "b", "z").arg_max())
25+
expected = {"a": [1], "b": [2], "z": [2]}
26+
assert_equal_data(result, expected)
27+
28+
29+
@pytest.mark.parametrize(("col", "expected"), [("a", 1), ("b", 2), ("z", 2)])
30+
def test_expr_arg_max_series(
31+
constructor_eager: ConstructorEager,
32+
col: str,
33+
expected: float,
34+
) -> None:
35+
if "modin" in str(constructor_eager):
36+
# TODO(unassigned): bug in modin?
37+
return
38+
series = nw.from_native(constructor_eager(data), eager_only=True)[col]
39+
series = nw.maybe_set_index(series, index=[1, 0, 9]) # type: ignore[arg-type]
40+
result = series.arg_max()
41+
assert_equal_data({col: [result]}, {col: [expected]})

tests/expr_and_series/arg_min_test.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
5+
import narwhals.stable.v1 as nw
6+
from tests.utils import Constructor
7+
from tests.utils import ConstructorEager
8+
from tests.utils import assert_equal_data
9+
10+
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
11+
12+
13+
def test_expr_arg_min_expr(
14+
constructor: Constructor, request: pytest.FixtureRequest
15+
) -> None:
16+
if "dask" in str(constructor):
17+
# This operation is row-order dependent so we don't support it for Dask
18+
request.applymarker(pytest.mark.xfail)
19+
df = nw.from_native(constructor(data))
20+
result = df.select(nw.col("a", "b", "z").arg_min())
21+
expected = {"a": [0], "b": [0], "z": [0]}
22+
assert_equal_data(result, expected)
23+
24+
25+
@pytest.mark.parametrize(("col", "expected"), [("a", 0), ("b", 0), ("z", 0)])
26+
def test_expr_arg_min_series(
27+
constructor_eager: ConstructorEager,
28+
col: str,
29+
expected: float,
30+
) -> None:
31+
if "modin" in str(constructor_eager):
32+
# TODO(unassigned): bug in modin?
33+
return
34+
series = nw.from_native(constructor_eager(data), eager_only=True)[col]
35+
series = nw.maybe_set_index(series, index=[1, 0, 9]) # type: ignore[arg-type]
36+
result = series.arg_min()
37+
assert_equal_data({col: [result]}, {col: [expected]})

0 commit comments

Comments
 (0)