|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
| 3 | +import numbers |
| 4 | +from datetime import datetime |
| 5 | +from datetime import timedelta |
3 | 6 | from functools import wraps
|
4 | 7 | from typing import TYPE_CHECKING
|
5 | 8 | from typing import Any
|
|
9 | 12 | from typing import overload
|
10 | 13 |
|
11 | 14 | from narwhals.dependencies import get_cudf
|
| 15 | +from narwhals.dependencies import get_cupy |
12 | 16 | from narwhals.dependencies import get_dask
|
13 | 17 | from narwhals.dependencies import get_dask_expr
|
14 | 18 | from narwhals.dependencies import get_modin
|
| 19 | +from narwhals.dependencies import get_numpy |
15 | 20 | from narwhals.dependencies import get_pandas
|
16 | 21 | from narwhals.dependencies import get_polars
|
17 | 22 | from narwhals.dependencies import get_pyarrow
|
@@ -776,8 +781,70 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
|
776 | 781 | return decorator(func)
|
777 | 782 |
|
778 | 783 |
|
| 784 | +def to_py_scalar(scalar_like: Any) -> Any: |
| 785 | + """If a scalar is not Python native, converts it to Python native. |
| 786 | +
|
| 787 | + Raises: |
| 788 | + ValueError: If the object is not convertible to a scalar. |
| 789 | +
|
| 790 | + Examples: |
| 791 | + >>> import narwhals as nw |
| 792 | + >>> import pandas as pd |
| 793 | + >>> df = nw.from_native(pd.DataFrame({"a": [1, 2, 3]})) |
| 794 | + >>> nw.to_py_scalar(df["a"].item(0)) |
| 795 | + 1 |
| 796 | + >>> import pyarrow as pa |
| 797 | + >>> df = nw.from_native(pa.table({"a": [1, 2, 3]})) |
| 798 | + >>> nw.to_py_scalar(df["a"].item(0)) |
| 799 | + 1 |
| 800 | + >>> nw.to_py_scalar(1) |
| 801 | + 1 |
| 802 | + """ |
| 803 | + |
| 804 | + pa = get_pyarrow() |
| 805 | + if pa and isinstance(scalar_like, pa.Scalar): |
| 806 | + return scalar_like.as_py() |
| 807 | + |
| 808 | + cupy = get_cupy() |
| 809 | + if ( # pragma: no cover |
| 810 | + cupy and isinstance(scalar_like, cupy.ndarray) and scalar_like.size == 1 |
| 811 | + ): |
| 812 | + return scalar_like.item() |
| 813 | + |
| 814 | + np = get_numpy() |
| 815 | + if np and np.isscalar(scalar_like) and hasattr(scalar_like, "item"): |
| 816 | + return scalar_like.item() |
| 817 | + |
| 818 | + pd = get_pandas() |
| 819 | + if pd and isinstance(scalar_like, pd.Timestamp): |
| 820 | + return scalar_like.to_pydatetime() |
| 821 | + if pd and isinstance(scalar_like, pd.Timedelta): |
| 822 | + return scalar_like.to_pytimedelta() |
| 823 | + |
| 824 | + all_scalar_types = ( |
| 825 | + int, |
| 826 | + float, |
| 827 | + complex, |
| 828 | + bool, |
| 829 | + bytes, |
| 830 | + str, |
| 831 | + datetime, |
| 832 | + timedelta, |
| 833 | + numbers.Number, |
| 834 | + ) |
| 835 | + if isinstance(scalar_like, all_scalar_types): |
| 836 | + return scalar_like |
| 837 | + |
| 838 | + msg = ( |
| 839 | + f"Expected object convertible to a scalar, found {type(scalar_like)}. " |
| 840 | + "Please report a bug to https://github.com/narwhals-dev/narwhals/issues" |
| 841 | + ) |
| 842 | + raise ValueError(msg) |
| 843 | + |
| 844 | + |
779 | 845 | __all__ = [
|
780 | 846 | "get_native_namespace",
|
781 | 847 | "to_native",
|
782 | 848 | "narwhalify",
|
| 849 | + "to_py_scalar", |
783 | 850 | ]
|
0 commit comments