Skip to content

Commit 0c1650c

Browse files
feat: add to_py_scalar (#1194)
* add to_py_scalar * fix tests * more fixes pragma and doctsting * fix test_to_py_scalar_cudf_series * convert numpy scalars * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove parse version * simplify test_to_py_scalar_arrays_series * add conversion for datetime and timedelta * stricter to_py_scalar --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 59aa483 commit 0c1650c

File tree

6 files changed

+162
-0
lines changed

6 files changed

+162
-0
lines changed

docs/api-reference/narwhals.md

+1
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,5 @@ Here are the top-level functions available in Narwhals.
3939
- when
4040
- show_versions
4141
- to_native
42+
- to_py_scalar
4243
show_source: false

narwhals/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
from narwhals.translate import get_native_namespace
5959
from narwhals.translate import narwhalify
6060
from narwhals.translate import to_native
61+
from narwhals.translate import to_py_scalar
6162
from narwhals.utils import is_ordered_categorical
6263
from narwhals.utils import maybe_align_index
6364
from narwhals.utils import maybe_convert_dtypes
@@ -84,6 +85,7 @@
8485
"maybe_reset_index",
8586
"maybe_set_index",
8687
"get_native_namespace",
88+
"to_py_scalar",
8789
"all",
8890
"all_horizontal",
8991
"any_horizontal",

narwhals/dependencies.py

+5
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ def get_cudf() -> Any:
4646
return sys.modules.get("cudf", None)
4747

4848

49+
def get_cupy() -> Any:
50+
"""Get cupy module (if already imported - else return None)."""
51+
return sys.modules.get("cupy", None)
52+
53+
4954
def get_pyarrow() -> Any: # pragma: no cover
5055
"""Get pyarrow module (if already imported - else return None)."""
5156
return sys.modules.get("pyarrow", None)

narwhals/stable/v1/__init__.py

+24
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from narwhals.translate import _from_native_impl
5252
from narwhals.translate import get_native_namespace as nw_get_native_namespace
5353
from narwhals.translate import to_native
54+
from narwhals.translate import to_py_scalar as nw_to_py_scalar
5455
from narwhals.typing import IntoDataFrameT
5556
from narwhals.typing import IntoFrameT
5657
from narwhals.typing import IntoSeriesT
@@ -952,6 +953,28 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
952953
return decorator(func)
953954

954955

956+
def to_py_scalar(scalar: Any) -> Any:
957+
"""If a scalar is not Python native, converts it to Python native.
958+
959+
Raises:
960+
ValueError: If the object is not convertible to a scalar.
961+
962+
Examples:
963+
>>> import narwhals.stable.v1 as nw
964+
>>> import pandas as pd
965+
>>> df = nw.from_native(pd.DataFrame({"a": [1, 2, 3]}))
966+
>>> nw.to_py_scalar(df["a"].item(0))
967+
1
968+
>>> import pyarrow as pa
969+
>>> df = nw.from_native(pa.table({"a": [1, 2, 3]}))
970+
>>> nw.to_py_scalar(df["a"].item(0))
971+
1
972+
>>> nw.to_py_scalar(1)
973+
1
974+
"""
975+
return _stableify(nw_to_py_scalar(scalar))
976+
977+
955978
def all() -> Expr:
956979
"""
957980
Instantiate an expression representing all columns.
@@ -2306,6 +2329,7 @@ def from_dict(
23062329
"dependencies",
23072330
"to_native",
23082331
"from_native",
2332+
"to_py_scalar",
23092333
"is_ordered_categorical",
23102334
"maybe_align_index",
23112335
"maybe_convert_dtypes",

narwhals/translate.py

+67
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from __future__ import annotations
22

3+
import numbers
4+
from datetime import datetime
5+
from datetime import timedelta
36
from functools import wraps
47
from typing import TYPE_CHECKING
58
from typing import Any
@@ -9,9 +12,11 @@
912
from typing import overload
1013

1114
from narwhals.dependencies import get_cudf
15+
from narwhals.dependencies import get_cupy
1216
from narwhals.dependencies import get_dask
1317
from narwhals.dependencies import get_dask_expr
1418
from narwhals.dependencies import get_modin
19+
from narwhals.dependencies import get_numpy
1520
from narwhals.dependencies import get_pandas
1621
from narwhals.dependencies import get_polars
1722
from narwhals.dependencies import get_pyarrow
@@ -776,8 +781,70 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
776781
return decorator(func)
777782

778783

784+
def to_py_scalar(scalar_like: Any) -> Any:
785+
"""If a scalar is not Python native, converts it to Python native.
786+
787+
Raises:
788+
ValueError: If the object is not convertible to a scalar.
789+
790+
Examples:
791+
>>> import narwhals as nw
792+
>>> import pandas as pd
793+
>>> df = nw.from_native(pd.DataFrame({"a": [1, 2, 3]}))
794+
>>> nw.to_py_scalar(df["a"].item(0))
795+
1
796+
>>> import pyarrow as pa
797+
>>> df = nw.from_native(pa.table({"a": [1, 2, 3]}))
798+
>>> nw.to_py_scalar(df["a"].item(0))
799+
1
800+
>>> nw.to_py_scalar(1)
801+
1
802+
"""
803+
804+
pa = get_pyarrow()
805+
if pa and isinstance(scalar_like, pa.Scalar):
806+
return scalar_like.as_py()
807+
808+
cupy = get_cupy()
809+
if ( # pragma: no cover
810+
cupy and isinstance(scalar_like, cupy.ndarray) and scalar_like.size == 1
811+
):
812+
return scalar_like.item()
813+
814+
np = get_numpy()
815+
if np and np.isscalar(scalar_like) and hasattr(scalar_like, "item"):
816+
return scalar_like.item()
817+
818+
pd = get_pandas()
819+
if pd and isinstance(scalar_like, pd.Timestamp):
820+
return scalar_like.to_pydatetime()
821+
if pd and isinstance(scalar_like, pd.Timedelta):
822+
return scalar_like.to_pytimedelta()
823+
824+
all_scalar_types = (
825+
int,
826+
float,
827+
complex,
828+
bool,
829+
bytes,
830+
str,
831+
datetime,
832+
timedelta,
833+
numbers.Number,
834+
)
835+
if isinstance(scalar_like, all_scalar_types):
836+
return scalar_like
837+
838+
msg = (
839+
f"Expected object convertible to a scalar, found {type(scalar_like)}. "
840+
"Please report a bug to https://github.com/narwhals-dev/narwhals/issues"
841+
)
842+
raise ValueError(msg)
843+
844+
779845
__all__ = [
780846
"get_native_namespace",
781847
"to_native",
782848
"narwhalify",
849+
"to_py_scalar",
783850
]

tests/translate/to_py_scalar_test.py

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from __future__ import annotations
2+
3+
from datetime import datetime
4+
from datetime import timedelta
5+
from typing import TYPE_CHECKING
6+
from typing import Any
7+
8+
import numpy as np
9+
import pandas as pd
10+
import pytest
11+
12+
import narwhals.stable.v1 as nw
13+
from narwhals.dependencies import get_cudf
14+
15+
if TYPE_CHECKING:
16+
from tests.utils import ConstructorEager
17+
18+
19+
@pytest.mark.parametrize(
20+
("input_value", "expected"),
21+
[
22+
(1, 1),
23+
(1.0, 1.0),
24+
("a", "a"),
25+
(True, True),
26+
(b"a", b"a"),
27+
(datetime(2021, 1, 1), datetime(2021, 1, 1)),
28+
(timedelta(days=1), timedelta(days=1)),
29+
],
30+
)
31+
def test_to_py_scalar(
32+
constructor_eager: ConstructorEager, input_value: Any, expected: Any
33+
) -> None:
34+
df = nw.from_native(constructor_eager({"a": [input_value]}))
35+
output = nw.to_py_scalar(df["a"].item(0))
36+
if expected == 1 and constructor_eager.__name__.startswith("pandas"):
37+
assert not isinstance(output, np.int64)
38+
elif isinstance(expected, datetime) and constructor_eager.__name__.startswith(
39+
"pandas"
40+
):
41+
assert not isinstance(output, pd.Timestamp)
42+
elif isinstance(expected, timedelta) and constructor_eager.__name__.startswith(
43+
"pandas"
44+
):
45+
assert not isinstance(output, pd.Timedelta)
46+
assert output == expected
47+
48+
49+
@pytest.mark.parametrize(
50+
"input_value",
51+
[np.array([1, 2]), [1, 2, 3], {"a": [1, 2, 3]}],
52+
)
53+
def test_to_py_scalar_value_error(input_value: Any) -> None:
54+
with pytest.raises(ValueError, match="Expected object convertible to a scalar"):
55+
nw.to_py_scalar(input_value)
56+
57+
58+
def test_to_py_scalar_value_error_cudf() -> None:
59+
if cudf := get_cudf(): # pragma: no cover
60+
df = nw.from_native(cudf.DataFrame({"a": [1, 2, 3]}))
61+
62+
with pytest.raises(ValueError, match="Expected object convertible to a scalar"):
63+
nw.to_py_scalar(df["a"])

0 commit comments

Comments
 (0)