Skip to content

Commit 7d07dc1

Browse files
authored
perf: make to_py_scalar 3x faster when argument is str (#1276)
1 parent fe2abe1 commit 7d07dc1

File tree

4 files changed

+42
-47
lines changed

4 files changed

+42
-47
lines changed

narwhals/stable/v1/__init__.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
from narwhals.translate import _from_native_impl
5252
from narwhals.translate import get_native_namespace as nw_get_native_namespace
5353
from narwhals.translate import to_native
54-
from narwhals.translate import to_py_scalar as nw_to_py_scalar
54+
from narwhals.translate import to_py_scalar
5555
from narwhals.typing import IntoDataFrameT
5656
from narwhals.typing import IntoFrameT
5757
from narwhals.typing import IntoSeriesT
@@ -956,28 +956,6 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
956956
return decorator(func)
957957

958958

959-
def to_py_scalar(scalar: Any) -> Any:
960-
"""If a scalar is not Python native, converts it to Python native.
961-
962-
Raises:
963-
ValueError: If the object is not convertible to a scalar.
964-
965-
Examples:
966-
>>> import narwhals.stable.v1 as nw
967-
>>> import pandas as pd
968-
>>> df = nw.from_native(pd.DataFrame({"a": [1, 2, 3]}))
969-
>>> nw.to_py_scalar(df["a"].item(0))
970-
1
971-
>>> import pyarrow as pa
972-
>>> df = nw.from_native(pa.table({"a": [1, 2, 3]}))
973-
>>> nw.to_py_scalar(df["a"].item(0))
974-
1
975-
>>> nw.to_py_scalar(1)
976-
1
977-
"""
978-
return _stableify(nw_to_py_scalar(scalar))
979-
980-
981959
def all() -> Expr:
982960
"""
983961
Instantiate an expression representing all columns.

narwhals/translate.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import numbers
43
from datetime import datetime
54
from datetime import timedelta
65
from functools import wraps
@@ -46,6 +45,15 @@
4645

4746
T = TypeVar("T")
4847

48+
NON_TEMPORAL_SCALAR_TYPES = (
49+
bool,
50+
bytes,
51+
str,
52+
int,
53+
float,
54+
complex,
55+
)
56+
4957

5058
@overload
5159
def to_native(
@@ -843,16 +851,10 @@ def to_py_scalar(scalar_like: Any) -> Any:
843851
>>> nw.to_py_scalar(1)
844852
1
845853
"""
846-
847-
pa = get_pyarrow()
848-
if pa and isinstance(scalar_like, pa.Scalar):
849-
return scalar_like.as_py()
850-
851-
cupy = get_cupy()
852-
if ( # pragma: no cover
853-
cupy and isinstance(scalar_like, cupy.ndarray) and scalar_like.size == 1
854-
):
855-
return scalar_like.item()
854+
if scalar_like is None:
855+
return None
856+
if isinstance(scalar_like, NON_TEMPORAL_SCALAR_TYPES):
857+
return scalar_like
856858

857859
np = get_numpy()
858860
if np and np.isscalar(scalar_like) and hasattr(scalar_like, "item"):
@@ -863,21 +865,30 @@ def to_py_scalar(scalar_like: Any) -> Any:
863865
return scalar_like.to_pydatetime()
864866
if pd and isinstance(scalar_like, pd.Timedelta):
865867
return scalar_like.to_pytimedelta()
866-
867-
all_scalar_types = (
868-
int,
869-
float,
870-
complex,
871-
bool,
872-
bytes,
873-
str,
874-
datetime,
875-
timedelta,
876-
numbers.Number,
877-
)
878-
if isinstance(scalar_like, all_scalar_types):
868+
if pd and pd.api.types.is_scalar(scalar_like):
869+
try:
870+
is_na = pd.isna(scalar_like)
871+
except Exception: # pragma: no cover # noqa: BLE001, S110
872+
pass
873+
else:
874+
if is_na:
875+
return None
876+
877+
# pd.Timestamp and pd.Timedelta subclass datetime and timedelta,
878+
# so we need to check this separately
879+
if isinstance(scalar_like, (datetime, timedelta)):
879880
return scalar_like
880881

882+
pa = get_pyarrow()
883+
if pa and isinstance(scalar_like, pa.Scalar):
884+
return scalar_like.as_py()
885+
886+
cupy = get_cupy()
887+
if ( # pragma: no cover
888+
cupy and isinstance(scalar_like, cupy.ndarray) and scalar_like.size == 1
889+
):
890+
return scalar_like.item()
891+
881892
msg = (
882893
f"Expected object convertible to a scalar, found {type(scalar_like)}. "
883894
"Please report a bug to https://github.com/narwhals-dev/narwhals/issues"

tests/stable_api_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ def test_stable_api_docstrings() -> None:
8787
nw_doc = getattr(nw, item).__doc__
8888
if item == "from_native":
8989
v1_doc = v1_doc.replace("native_dataframe", "native_object")
90+
if item in {"to_py_scalar"}:
91+
# We don't overwrite the docstring for these
92+
assert v1_doc == nw_doc
93+
continue
9094
assert (
9195
v1_doc.replace("import narwhals.stable.v1 as nw", "import narwhals as nw")
9296
== nw_doc

tests/translate/to_py_scalar_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
("input_value", "expected"),
2121
[
2222
(1, 1),
23+
(np.int64(1), 1),
2324
(1.0, 1.0),
25+
(None, None),
2426
("a", "a"),
2527
(True, True),
2628
(b"a", b"a"),

0 commit comments

Comments
 (0)