Skip to content

Commit 879d3cf

Browse files
authored
feat: add from_arrow (which uses the PyCapsule Interface) (#1181)
1 parent 9b628ee commit 879d3cf

File tree

5 files changed

+198
-0
lines changed

5 files changed

+198
-0
lines changed

docs/api-reference/narwhals.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Here are the top-level functions available in Narwhals.
1414
- concat_str
1515
- from_dict
1616
- from_native
17+
- from_arrow
1718
- get_level
1819
- get_native_namespace
1920
- is_ordered_categorical

narwhals/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from narwhals.expr import sum_horizontal
4646
from narwhals.expr import when
4747
from narwhals.functions import concat
48+
from narwhals.functions import from_arrow
4849
from narwhals.functions import from_dict
4950
from narwhals.functions import get_level
5051
from narwhals.functions import new_series
@@ -69,6 +70,7 @@
6970
"selectors",
7071
"concat",
7172
"from_dict",
73+
"from_arrow",
7274
"get_level",
7375
"new_series",
7476
"to_native",

narwhals/functions.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Any
77
from typing import Iterable
88
from typing import Literal
9+
from typing import Protocol
910
from typing import TypeVar
1011
from typing import Union
1112

@@ -21,6 +22,7 @@
2122
# The rest of the annotations seem to work fine with this anyway
2223
FrameT = TypeVar("FrameT", bound=Union[DataFrame, LazyFrame]) # type: ignore[type-arg]
2324

25+
2426
if TYPE_CHECKING:
2527
from types import ModuleType
2628

@@ -29,6 +31,11 @@
2931
from narwhals.series import Series
3032
from narwhals.typing import DTypes
3133

34+
class ArrowStreamExportable(Protocol):
35+
def __arrow_c_stream__(
36+
self, requested_schema: object | None = None
37+
) -> object: ...
38+
3239

3340
def concat(
3441
items: Iterable[FrameT],
@@ -406,6 +413,100 @@ def _from_dict_impl(
406413
return from_native(native_frame, eager_only=True)
407414

408415

416+
def from_arrow(
417+
native_frame: ArrowStreamExportable, *, native_namespace: ModuleType
418+
) -> DataFrame[Any]:
419+
"""
420+
Construct a DataFrame from an object which supports the PyCapsule Interface.
421+
422+
Arguments:
423+
native_frame: Object which implements `__arrow_c_stream__`.
424+
native_namespace: The native library to use for DataFrame creation.
425+
426+
Examples:
427+
>>> import pandas as pd
428+
>>> import polars as pl
429+
>>> import pyarrow as pa
430+
>>> import narwhals as nw
431+
>>> data = {"a": [1, 2, 3], "b": [4, 5, 6]}
432+
433+
Let's define a dataframe-agnostic function which creates a PyArrow
434+
Table.
435+
436+
>>> @nw.narwhalify
437+
... def func(df):
438+
... return nw.from_arrow(df, native_namespace=pa)
439+
440+
Let's see what happens when passing pandas / Polars input:
441+
442+
>>> func(pd.DataFrame(data)) # doctest: +SKIP
443+
pyarrow.Table
444+
a: int64
445+
b: int64
446+
----
447+
a: [[1,2,3]]
448+
b: [[4,5,6]]
449+
>>> func(pl.DataFrame(data)) # doctest: +SKIP
450+
pyarrow.Table
451+
a: int64
452+
b: int64
453+
----
454+
a: [[1,2,3]]
455+
b: [[4,5,6]]
456+
"""
457+
if not hasattr(native_frame, "__arrow_c_stream__"):
458+
msg = f"Given object of type {type(native_frame)} does not support PyCapsule interface"
459+
raise TypeError(msg)
460+
implementation = Implementation.from_native_namespace(native_namespace)
461+
462+
if implementation is Implementation.POLARS and parse_version(
463+
native_namespace.__version__
464+
) >= (1, 3):
465+
native_frame = native_namespace.DataFrame(native_frame)
466+
elif implementation in {
467+
Implementation.PANDAS,
468+
Implementation.MODIN,
469+
Implementation.CUDF,
470+
Implementation.POLARS,
471+
}:
472+
# These don't (yet?) support the PyCapsule Interface for import
473+
# so we go via PyArrow
474+
try:
475+
import pyarrow as pa # ignore-banned-import
476+
except ModuleNotFoundError as exc: # pragma: no cover
477+
msg = f"PyArrow>=14.0.0 is required for `from_arrow` for object of type {native_namespace}"
478+
raise ModuleNotFoundError(msg) from exc
479+
if parse_version(pa.__version__) < (14, 0): # pragma: no cover
480+
msg = f"PyArrow>=14.0.0 is required for `from_arrow` for object of type {native_namespace}"
481+
raise ModuleNotFoundError(msg) from None
482+
483+
tbl = pa.table(native_frame)
484+
if implementation is Implementation.PANDAS:
485+
native_frame = tbl.to_pandas()
486+
elif implementation is Implementation.MODIN: # pragma: no cover
487+
from modin.pandas.utils import from_arrow
488+
489+
native_frame = from_arrow(tbl)
490+
elif implementation is Implementation.CUDF: # pragma: no cover
491+
native_frame = native_namespace.DataFrame.from_arrow(tbl)
492+
elif implementation is Implementation.POLARS: # pragma: no cover
493+
native_frame = native_namespace.from_arrow(tbl)
494+
else: # pragma: no cover
495+
msg = "congratulations, you entered unrecheable code - please report a bug"
496+
raise AssertionError(msg)
497+
elif implementation is Implementation.PYARROW:
498+
native_frame = native_namespace.table(native_frame)
499+
else: # pragma: no cover
500+
try:
501+
# implementation is UNKNOWN, Narwhals extension using this feature should
502+
# implement PyCapsule support
503+
native_frame = native_namespace.DataFrame(native_frame)
504+
except AttributeError as e:
505+
msg = "Unknown namespace is expected to implement `DataFrame` class which accepts object which supports PyCapsule Interface."
506+
raise AttributeError(msg) from e
507+
return from_native(native_frame, eager_only=True)
508+
509+
409510
def _get_sys_info() -> dict[str, str]:
410511
"""System information
411512

narwhals/stable/v1/__init__.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from narwhals.expr import when as nw_when
2222
from narwhals.functions import _from_dict_impl
2323
from narwhals.functions import _new_series_impl
24+
from narwhals.functions import from_arrow as nw_from_arrow
2425
from narwhals.functions import show_versions
2526
from narwhals.schema import Schema as NwSchema
2627
from narwhals.series import Series as NwSeries
@@ -66,6 +67,7 @@
6667
from typing_extensions import Self
6768

6869
from narwhals.dtypes import DType
70+
from narwhals.functions import ArrowStreamExportable
6971
from narwhals.typing import IntoExpr
7072

7173
T = TypeVar("T")
@@ -2183,6 +2185,52 @@ def new_series(
21832185
)
21842186

21852187

2188+
def from_arrow(
2189+
native_frame: ArrowStreamExportable, *, native_namespace: ModuleType
2190+
) -> DataFrame[Any]:
2191+
"""
2192+
Construct a DataFrame from an object which supports the PyCapsule Interface.
2193+
2194+
Arguments:
2195+
native_frame: Object which implements `__arrow_c_stream__`.
2196+
native_namespace: The native library to use for DataFrame creation.
2197+
2198+
Examples:
2199+
>>> import pandas as pd
2200+
>>> import polars as pl
2201+
>>> import pyarrow as pa
2202+
>>> import narwhals.stable.v1 as nw
2203+
>>> data = {"a": [1, 2, 3], "b": [4, 5, 6]}
2204+
2205+
Let's define a dataframe-agnostic function which creates a PyArrow
2206+
Table.
2207+
2208+
>>> @nw.narwhalify
2209+
... def func(df):
2210+
... return nw.from_arrow(df, native_namespace=pa)
2211+
2212+
Let's see what happens when passing pandas / Polars input:
2213+
2214+
>>> func(pd.DataFrame(data)) # doctest: +SKIP
2215+
pyarrow.Table
2216+
a: int64
2217+
b: int64
2218+
----
2219+
a: [[1,2,3]]
2220+
b: [[4,5,6]]
2221+
>>> func(pl.DataFrame(data)) # doctest: +SKIP
2222+
pyarrow.Table
2223+
a: int64
2224+
b: int64
2225+
----
2226+
a: [[1,2,3]]
2227+
b: [[4,5,6]]
2228+
"""
2229+
return _stableify( # type: ignore[no-any-return]
2230+
nw_from_arrow(native_frame, native_namespace=native_namespace)
2231+
)
2232+
2233+
21862234
def from_dict(
21872235
data: dict[str, Any],
21882236
schema: dict[str, DType] | Schema | None = None,
@@ -2307,5 +2355,6 @@ def from_dict(
23072355
"show_versions",
23082356
"Schema",
23092357
"from_dict",
2358+
"from_arrow",
23102359
"new_series",
23112360
]

tests/from_pycapsule_test.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import sys
2+
3+
import pandas as pd
4+
import polars as pl
5+
import pyarrow as pa
6+
import pytest
7+
8+
import narwhals.stable.v1 as nw
9+
from narwhals.utils import parse_version
10+
from tests.utils import compare_dicts
11+
12+
13+
@pytest.mark.xfail(parse_version(pa.__version__) < (14,), reason="too old")
14+
def test_from_arrow_to_arrow() -> None:
15+
df = nw.from_native(pl.DataFrame({"ab": [1, 2, 3], "ba": [4, 5, 6]}), eager_only=True)
16+
result = nw.from_arrow(df, native_namespace=pa)
17+
assert isinstance(result.to_native(), pa.Table)
18+
expected = {"ab": [1, 2, 3], "ba": [4, 5, 6]}
19+
compare_dicts(result, expected)
20+
21+
22+
@pytest.mark.xfail(parse_version(pa.__version__) < (14,), reason="too old")
23+
def test_from_arrow_to_polars(monkeypatch: pytest.MonkeyPatch) -> None:
24+
tbl = pa.table({"ab": [1, 2, 3], "ba": [4, 5, 6]})
25+
monkeypatch.delitem(sys.modules, "pandas")
26+
df = nw.from_native(tbl, eager_only=True)
27+
result = nw.from_arrow(df, native_namespace=pl)
28+
assert isinstance(result.to_native(), pl.DataFrame)
29+
expected = {"ab": [1, 2, 3], "ba": [4, 5, 6]}
30+
compare_dicts(result, expected)
31+
assert "pandas" not in sys.modules
32+
33+
34+
@pytest.mark.xfail(parse_version(pa.__version__) < (14,), reason="too old")
35+
def test_from_arrow_to_pandas() -> None:
36+
df = nw.from_native(pa.table({"ab": [1, 2, 3], "ba": [4, 5, 6]}), eager_only=True)
37+
result = nw.from_arrow(df, native_namespace=pd)
38+
assert isinstance(result.to_native(), pd.DataFrame)
39+
expected = {"ab": [1, 2, 3], "ba": [4, 5, 6]}
40+
compare_dicts(result, expected)
41+
42+
43+
def test_from_arrow_invalid() -> None:
44+
with pytest.raises(TypeError, match="PyCapsule"):
45+
nw.from_arrow({"a": [1]}, native_namespace=pa) # type: ignore[arg-type]

0 commit comments

Comments
 (0)