Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improving array casting #1865

Merged
merged 11 commits into from
Feb 10, 2025
Merged
11 changes: 8 additions & 3 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,13 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> pa
]
)
if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover
msg = "Converting to Array dtype is not supported yet"
return NotImplementedError(msg)
inner = narwhals_to_native_dtype(
dtype.inner, # type: ignore[union-attr]
version=version,
)
list_size = dtype.size # type: ignore[union-attr]
return pa.list_(inner, list_size=list_size)

msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)

Expand Down Expand Up @@ -220,7 +225,7 @@ def broadcast_and_extract_dataframe_comparand(

if isinstance(other, ArrowSeries):
len_other = len(other)
if len_other == 1:
if len_other == 1 and length != 1:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise for list/array types we end up getting the first element

import numpy as np # ignore-banned-import

value = other._native_series[0]
Expand Down
16 changes: 11 additions & 5 deletions narwhals/_duckdb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,13 @@ def native_to_narwhals_dtype(duckdb_dtype: str, version: Version) -> DType:
)
if match_ := re.match(r"(.*)\[\]$", duckdb_dtype):
return dtypes.List(native_to_narwhals_dtype(match_.group(1), version))
if match_ := re.match(r"(\w+)\[(\d+)\]", duckdb_dtype):
if match_ := re.match(r"(\w+)((?:\[\d+\])+)", duckdb_dtype):
Copy link
Member Author

@FBruzzesi FBruzzesi Jan 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Array type in duckdb can have multiple dimensions. The resulting type is: INNER[d1][d2][...]

With this new regex we can parse multiple instances of the dimensions

duckdb_inner_type = match_.group(1)
duckdb_shape = match_.group(2)
shape = tuple(int(value) for value in re.findall(r"\[(\d+)\]", duckdb_shape))
return dtypes.Array(
native_to_narwhals_dtype(match_.group(1), version),
int(match_.group(2)),
inner=native_to_narwhals_dtype(duckdb_inner_type, version),
shape=shape,
)
if duckdb_dtype.startswith("DECIMAL("):
return dtypes.Decimal()
Expand Down Expand Up @@ -193,8 +196,11 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> st
)
return f"STRUCT({inner})"
if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover
msg = "todo"
raise NotImplementedError(msg)
duckdb_shape_fmt = "".join(f"[{item}]" for item in dtype.shape) # type: ignore[union-attr]
while isinstance(dtype.inner, dtypes.Array): # type: ignore[union-attr]
dtype = dtype.inner # type: ignore[union-attr]
inner = narwhals_to_native_dtype(dtype.inner, version) # type: ignore[union-attr]
return f"{inner}{duckdb_shape_fmt}"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First creates the shape [d1][d2]... then find the inner type recursively (first being non array)

msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)

Expand Down
61 changes: 8 additions & 53 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def broadcast_and_extract_dataframe_comparand(index: Any, other: Any) -> Any:
if isinstance(other, PandasLikeSeries):
len_other = other.len()

if len_other == 1:
if len_other == 1 and len(index) != 1:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly as for pyarrow, otherwise for list/array types we end up getting the first element

# broadcast
s = other._native_series
return s.__class__(s.iloc[0], index=index, dtype=s.dtype, name=s.name)
Expand Down Expand Up @@ -387,9 +387,7 @@ def rename(


@lru_cache(maxsize=16)
def non_object_native_to_narwhals_dtype(
dtype: str, version: Version, _implementation: Implementation
) -> DType:
def non_object_native_to_narwhals_dtype(dtype: str, version: Version) -> DType:
dtypes = import_dtypes_module(version)
if dtype in {"int64", "Int64", "Int64[pyarrow]", "int64[pyarrow]"}:
return dtypes.Int64()
Expand Down Expand Up @@ -465,7 +463,7 @@ def native_to_narwhals_dtype(
return arrow_native_to_narwhals_dtype(native_dtype.to_arrow(), version)
return arrow_native_to_narwhals_dtype(native_dtype.pyarrow_dtype, version)
if dtype != "object":
return non_object_native_to_narwhals_dtype(dtype, version, implementation)
return non_object_native_to_narwhals_dtype(dtype, version)
if implementation is Implementation.DASK:
# Dask columns are lazy, so we can't inspect values.
# The most useful assumption is probably String
Expand Down Expand Up @@ -649,68 +647,25 @@ def narwhals_to_native_dtype( # noqa: PLR0915
if isinstance_or_issubclass(dtype, dtypes.Enum):
msg = "Converting to Enum is not (yet) supported"
raise NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.List):
from narwhals._arrow.utils import (
narwhals_to_native_dtype as arrow_narwhals_to_native_dtype,
)

if implementation is Implementation.PANDAS and backend_version >= (2, 2):
try:
import pandas as pd
import pyarrow as pa # ignore-banned-import
except ImportError as exc: # pragma: no cover
msg = f"Unable to convert to {dtype} to to the following exception: {exc.msg}"
raise ImportError(msg) from exc

return pd.ArrowDtype(
pa.list_(
value_type=arrow_narwhals_to_native_dtype(
dtype.inner, # type: ignore[union-attr]
version=version,
)
)
)
else: # pragma: no cover
msg = (
"Converting to List dtype is not supported for implementation "
f"{implementation} and version {version}."
)
return NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.Struct):
if isinstance_or_issubclass(dtype, (dtypes.Struct, dtypes.Array, dtypes.List)):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This diff is quite nice πŸ˜‰

if implementation is Implementation.PANDAS and backend_version >= (2, 2):
try:
import pandas as pd
import pyarrow as pa # ignore-banned-import
import pyarrow as pa # ignore-banned-import # noqa: F401
except ImportError as exc: # pragma: no cover
msg = f"Unable to convert to {dtype} to to the following exception: {exc.msg}"
raise ImportError(msg) from exc
from narwhals._arrow.utils import (
narwhals_to_native_dtype as arrow_narwhals_to_native_dtype,
)

return pd.ArrowDtype(
pa.struct(
[
(
field.name,
arrow_narwhals_to_native_dtype(
field.dtype,
version=version,
),
)
for field in dtype.fields # type: ignore[union-attr]
]
)
)
return pd.ArrowDtype(arrow_narwhals_to_native_dtype(dtype, version=version))
else: # pragma: no cover
msg = (
"Converting to Struct dtype is not supported for implementation "
f"Converting to {dtype} dtype is not supported for implementation "
f"{implementation} and version {version}."
)
return NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover
msg = "Converting to Array dtype is not supported yet"
return NotImplementedError(msg)
raise NotImplementedError(msg)
msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)

Expand Down
21 changes: 9 additions & 12 deletions narwhals/_polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,11 @@ def native_to_narwhals_dtype(
native_to_narwhals_dtype(dtype.inner, version, backend_version) # type: ignore[attr-defined]
)
if dtype == pl.Array:
if backend_version < (0, 20, 30): # pragma: no cover
return dtypes.Array(
native_to_narwhals_dtype(dtype.inner, version, backend_version), # type: ignore[attr-defined]
dtype.width, # type: ignore[attr-defined]
)
else:
return dtypes.Array(
native_to_narwhals_dtype(dtype.inner, version, backend_version), # type: ignore[attr-defined]
dtype.size, # type: ignore[attr-defined]
)
outer_shape = dtype.width if backend_version < (0, 20, 30) else dtype.size # type: ignore[attr-defined]
return dtypes.Array(
inner=native_to_narwhals_dtype(dtype.inner, version, backend_version), # type: ignore[attr-defined]
shape=outer_shape,
)
if dtype == pl.Decimal:
return dtypes.Decimal()
return dtypes.Unknown()
Expand Down Expand Up @@ -205,8 +200,10 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> pl
]
)
if dtype == dtypes.Array: # pragma: no cover
msg = "Converting to Array dtype is not supported yet"
raise NotImplementedError(msg)
return pl.Array(
inner=narwhals_to_native_dtype(dtype.inner, version), # type: ignore[union-attr]
shape=dtype.size, # type: ignore[union-attr]
)
return pl.Unknown() # pragma: no cover


Expand Down
28 changes: 15 additions & 13 deletions narwhals/_spark_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,26 +41,24 @@ def native_to_narwhals_dtype(
return dtypes.Int16()
if isinstance(dtype, pyspark_types.ByteType):
return dtypes.Int8()
string_types = [
pyspark_types.StringType,
pyspark_types.VarcharType,
pyspark_types.CharType,
]
if any(isinstance(dtype, t) for t in string_types):
if isinstance(
dtype,
(pyspark_types.StringType, pyspark_types.VarcharType, pyspark_types.CharType),
):
return dtypes.String()
if isinstance(dtype, pyspark_types.BooleanType):
return dtypes.Boolean()
if isinstance(dtype, pyspark_types.DateType):
return dtypes.Date()
datetime_types = [
pyspark_types.TimestampType,
pyspark_types.TimestampNTZType,
]
if any(isinstance(dtype, t) for t in datetime_types):
if isinstance(dtype, (pyspark_types.TimestampType, pyspark_types.TimestampNTZType)):
return dtypes.Datetime()
if isinstance(dtype, pyspark_types.DecimalType): # pragma: no cover
# TODO(unassigned): cover this in dtypes_test.py
return dtypes.Decimal()
if isinstance(dtype, pyspark_types.ArrayType): # pragma: no cover
return dtypes.List(
inner=native_to_narwhals_dtype(dtype.elementType, version=version)
)
FBruzzesi marked this conversation as resolved.
Show resolved Hide resolved
return dtypes.Unknown()


Expand Down Expand Up @@ -97,8 +95,12 @@ def narwhals_to_native_dtype(
msg = "Converting to Struct dtype is not supported yet"
raise NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover
msg = "Converting to Array dtype is not supported yet"
raise NotImplementedError(msg)
inner = narwhals_to_native_dtype(
dtype.inner, # type: ignore[union-attr]
version=version,
)
return pyspark_types.ArrayType(elementType=inner)

if isinstance_or_issubclass(
dtype, (dtypes.UInt64, dtypes.UInt32, dtypes.UInt16, dtypes.UInt8)
): # pragma: no cover
Expand Down
51 changes: 41 additions & 10 deletions narwhals/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,14 +731,37 @@ class Array(DType):
Array(Int32, 2)
"""

inner: DType | type[DType]
size: int
shape: tuple[int, ...]

def __init__(
self: Self, inner: DType | type[DType], width: int | None = None
self: Self,
inner: DType | type[DType],
shape: int | tuple[int, ...] | None = None,
Comment on lines -789 to +795
Copy link
Member Author

@FBruzzesi FBruzzesi Jan 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a breaking change. However by checking how narwhals is used on github, I could not find any instances of Array

) -> None:
self.inner = inner
if width is None:
error = "`width` must be specified when initializing an `Array`"
raise TypeError(error)
self.width = width
inner_shape: tuple[int, ...] = inner.shape if isinstance(inner, Array) else ()

if shape is None: # pragma: no cover
msg = "Array constructor is missing the required argument `shape`"
raise TypeError(msg)

if isinstance(shape, int):
self.inner = inner
self.size = shape
self.shape = (shape, *inner_shape)

elif isinstance(shape, tuple):
if len(shape) > 1:
inner = Array(inner, shape[1:])

self.inner = inner
self.size = shape[0]
self.shape = shape + inner_shape

else:
msg = f"invalid input for shape: {shape!r}"
raise TypeError(msg)

def __eq__(self: Self, other: DType | type[DType]) -> bool: # type: ignore[override]
# This equality check allows comparison of type classes and type instances.
Expand All @@ -751,16 +774,24 @@ def __eq__(self: Self, other: DType | type[DType]) -> bool: # type: ignore[over
if type(other) is type and issubclass(other, self.__class__):
return True
elif isinstance(other, self.__class__):
return self.inner == other.inner
if self.shape != other.shape:
return False
else:
return self.inner == other.inner
else:
return False

def __hash__(self: Self) -> int:
return hash((self.__class__, self.inner, self.width))
return hash((self.__class__, self.inner, self.shape))

def __repr__(self) -> str:
# Get leaf type
dtype = self.inner
while isinstance(dtype, Array):
dtype = dtype.inner
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we put a limit on the depth of this (just to avoid some infinite loop in some unexpected scenario)?
same with the other place where there's the while loop

Copy link
Member Author

@FBruzzesi FBruzzesi Feb 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Disclaimer, from polars: https://github.com/pola-rs/polars/blob/faad12f7277751006e3faebf0fffb1f6bf9aa7e7/py-polars/polars/datatypes/classes.py#L905-L912

and the __init__ is also recursive:

narwhals/narwhals/dtypes.py

Lines 808 to 810 in 981f87c

elif isinstance(shape, tuple) and isinstance(shape[0], int):
if len(shape) > 1:
inner = Array(inner, shape[1:])

In principle, max depth should be len(shape). Is that a good depth to check? WDYT?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember the trick, yet I don't feel super comfortable with introducing a fixed number of depth to check for.

Since we know already that the max depth should be len(shape), I think that would be the best limit to introduce, if any.

In practice, if someone generate a datatype to pass which ends up in an infinite recursion, maybe it's good to raise with sure error instead of an AssertionError. I am thinking out loud here though

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, shall we use len shape?

In theory this shouldn't happen, just trying to guard against implementation oversights


def __repr__(self: Self) -> str:
class_name = self.__class__.__name__
return f"{class_name}({self.inner!r}, {self.width})"
return f"{class_name}({dtype!r}, shape={self.shape})"
Copy link
Member Author

@FBruzzesi FBruzzesi Feb 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This diff is that's causing marimo ci to fail - we could provide a fix there πŸ™ˆ

Copy link
Member

@MarcoGorelli MarcoGorelli Feb 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It also wouldn't be too bad to override repr in v1 right?

Not saying we necessarily have to do that here though

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we maintain width in v1 then? My understanding is that we are limited to 1d arrays with the v1 implementation.
Or should we just keep the __repr__ and show f"{class_name}({self.inner!r}, {width})" with width=self.shape[0] if shape has only one dimension else the tuple self.shape?



class Date(TemporalType):
Expand Down
24 changes: 19 additions & 5 deletions tests/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import TYPE_CHECKING
from typing import Literal

import numpy as np
Expand All @@ -15,6 +16,9 @@
from tests.utils import PANDAS_VERSION
from tests.utils import POLARS_VERSION

if TYPE_CHECKING:
from tests.utils import Constructor


@pytest.mark.parametrize("time_unit", ["us", "ns", "ms"])
@pytest.mark.parametrize("time_zone", ["Europe/Rome", timezone.utc, None])
Expand Down Expand Up @@ -75,15 +79,15 @@ def test_array_valid() -> None:
assert dtype == nw.Array
assert dtype != nw.Array(nw.Float32, 2)
assert dtype != nw.Duration
assert repr(dtype) == "Array(<class 'narwhals.dtypes.Int64'>, 2)"
assert repr(dtype) == "Array(<class 'narwhals.dtypes.Int64'>, shape=(2,))"
dtype = nw.Array(nw.Array(nw.Int64, 2), 2)
assert dtype == nw.Array(nw.Array(nw.Int64, 2), 2)
assert dtype == nw.Array
assert dtype != nw.Array(nw.Array(nw.Float32, 2), 2)
assert dtype in {nw.Array(nw.Array(nw.Int64, 2), 2)}

with pytest.raises(
TypeError, match="`width` must be specified when initializing an `Array`"
TypeError, match="Array constructor is missing the required argument `shape`"
):
dtype = nw.Array(nw.Int64)

Expand Down Expand Up @@ -133,13 +137,23 @@ def test_polars_2d_array() -> None:
df = pl.DataFrame(
{"a": [[[1, 2], [3, 4], [5, 6]]]}, schema={"a": pl.Array(pl.Int64, (3, 2))}
)
assert nw.from_native(df).collect_schema()["a"] == nw.Array(nw.Array(nw.Int64, 2), 3)
assert nw.from_native(df).collect_schema()["a"] == nw.Array(nw.Int64(), (3, 2))
assert nw.from_native(df.to_arrow()).collect_schema()["a"] == nw.Array(
nw.Array(nw.Int64, 2), 3
nw.Array(nw.Int64(), 2), 3
)
assert nw.from_native(
df.to_pandas(use_pyarrow_extension_array=True)
).collect_schema()["a"] == nw.Array(nw.Array(nw.Int64, 2), 3)
).collect_schema()["a"] == nw.Array(nw.Array(nw.Int64(), 2), 3)


def test_2d_array(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if any(x in str(constructor) for x in ("dask", "modin", "cudf", "pyspark")):
request.applymarker(pytest.mark.xfail)
data = {"a": [[[1, 2], [3, 4], [5, 6]]]}
df = nw.from_native(constructor(data)).with_columns(
a=nw.col("a").cast(nw.Array(nw.Int64(), (3, 2)))
)
assert df.collect_schema()["a"] == nw.Array(nw.Int64(), (3, 2))


def test_second_time_unit() -> None:
Expand Down
Loading