Skip to content

Commit 07f3d0e

Browse files
thevroTheo Everley RöhrkastenMarcoGorelli
authored
fix: allow np.scalar to be used in Series.__getitem__ (#1515)
--------- Co-authored-by: Theo Everley Röhrkasten <[email protected]> Co-authored-by: Marco Gorelli <[email protected]>
1 parent 1d9ac1c commit 07f3d0e

File tree

4 files changed

+17
-2
lines changed

4 files changed

+17
-2
lines changed

narwhals/_pandas_like/series.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from narwhals._pandas_like.utils import select_columns_by_name
2020
from narwhals._pandas_like.utils import set_axis
2121
from narwhals._pandas_like.utils import to_datetime
22+
from narwhals.dependencies import is_numpy_scalar
2223
from narwhals.typing import CompliantSeries
2324
from narwhals.utils import Implementation
2425
from narwhals.utils import import_dtypes_module
@@ -114,7 +115,7 @@ def __getitem__(self, idx: int) -> Any: ...
114115
def __getitem__(self, idx: slice | Sequence[int]) -> Self: ...
115116

116117
def __getitem__(self, idx: int | slice | Sequence[int]) -> Any | Self:
117-
if isinstance(idx, int):
118+
if isinstance(idx, int) or is_numpy_scalar(idx):
118119
return self._native_series.iloc[idx]
119120
return self._from_native_series(self._native_series.iloc[idx])
120121

narwhals/dependencies.py

+5
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,11 @@ def is_numpy_array(arr: Any) -> TypeGuard[np.ndarray]:
223223
return (np := get_numpy()) is not None and isinstance(arr, np.ndarray)
224224

225225

226+
def is_numpy_scalar(scalar: Any) -> TypeGuard[np.generic]:
227+
"""Check whether `scalar` is a NumPy Scalar without importing NumPy."""
228+
return (np := get_numpy()) is not None and np.isscalar(scalar)
229+
230+
226231
def is_pandas_like_dataframe(df: Any) -> bool:
227232
"""Check whether `df` is a pandas-like DataFrame without doing any imports.
228233

narwhals/series.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing import TypeVar
1212
from typing import overload
1313

14+
from narwhals.dependencies import is_numpy_scalar
1415
from narwhals.dtypes import _validate_dtype
1516
from narwhals.typing import IntoSeriesT
1617
from narwhals.utils import _validate_rolling_arguments
@@ -160,7 +161,9 @@ def __getitem__(self: Self, idx: int | slice | Sequence[int]) -> Any | Self:
160161
]
161162
]
162163
"""
163-
if isinstance(idx, int):
164+
if isinstance(idx, int) or (
165+
is_numpy_scalar(idx) and idx.dtype.kind in ("i", "u")
166+
):
164167
return self._compliant_series[idx]
165168
return self._from_compliant_series(self._compliant_series[idx])
166169

tests/series_only/slice_test.py

+6
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,9 @@ def test_slice(constructor_eager: ConstructorEager) -> None:
3939
def test_getitem_arrow_scalar() -> None:
4040
result = nw.from_native(pa.chunked_array([[1]]), series_only=True)[0]
4141
assert isinstance(result, int)
42+
43+
44+
def test_index(constructor_eager: ConstructorEager) -> None:
45+
df = constructor_eager({"a": [0, 1, 2]})
46+
snw = nw.from_native(df, eager_only=True)["a"]
47+
assert snw[snw[0]] == 0

0 commit comments

Comments
 (0)