-
Notifications
You must be signed in to change notification settings - Fork 122
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: improving array casting #1865
Changes from 1 commit
1b8fd7d
569c204
ab25f4d
65287c9
7c84531
31dd90f
044db09
35c2a2e
c79507d
981f87c
690fc99
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -133,10 +133,13 @@ def native_to_narwhals_dtype(duckdb_dtype: str, version: Version) -> DType: | |
) | ||
if match_ := re.match(r"(.*)\[\]$", duckdb_dtype): | ||
return dtypes.List(native_to_narwhals_dtype(match_.group(1), version)) | ||
if match_ := re.match(r"(\w+)\[(\d+)\]", duckdb_dtype): | ||
if match_ := re.match(r"(\w+)((?:\[\d+\])+)", duckdb_dtype): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Array type in duckdb can have multiple dimensions. The resulting type is: With this new regex we can parse multiple instances of the dimensions |
||
duckdb_inner_type = match_.group(1) | ||
duckdb_shape = match_.group(2) | ||
shape = tuple(int(value) for value in re.findall(r"\[(\d+)\]", duckdb_shape)) | ||
return dtypes.Array( | ||
native_to_narwhals_dtype(match_.group(1), version), | ||
int(match_.group(2)), | ||
inner=native_to_narwhals_dtype(duckdb_inner_type, version), | ||
shape=shape, | ||
) | ||
if duckdb_dtype.startswith("DECIMAL("): | ||
return dtypes.Decimal() | ||
|
@@ -193,8 +196,11 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> st | |
) | ||
return f"STRUCT({inner})" | ||
if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover | ||
msg = "todo" | ||
raise NotImplementedError(msg) | ||
duckdb_shape_fmt = "".join(f"[{item}]" for item in dtype.shape) # type: ignore[union-attr] | ||
while isinstance(dtype.inner, dtypes.Array): # type: ignore[union-attr] | ||
dtype = dtype.inner # type: ignore[union-attr] | ||
inner = narwhals_to_native_dtype(dtype.inner, version) # type: ignore[union-attr] | ||
return f"{inner}{duckdb_shape_fmt}" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. First creates the shape |
||
msg = f"Unknown dtype: {dtype}" # pragma: no cover | ||
raise AssertionError(msg) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -160,7 +160,7 @@ def broadcast_and_extract_dataframe_comparand(index: Any, other: Any) -> Any: | |
if isinstance(other, PandasLikeSeries): | ||
len_other = other.len() | ||
|
||
if len_other == 1: | ||
if len_other == 1 and len(index) != 1: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similarly as for pyarrow, otherwise for list/array types we end up getting the first element |
||
# broadcast | ||
s = other._native_series | ||
return s.__class__(s.iloc[0], index=index, dtype=s.dtype, name=s.name) | ||
|
@@ -387,9 +387,7 @@ def rename( | |
|
||
|
||
@lru_cache(maxsize=16) | ||
def non_object_native_to_narwhals_dtype( | ||
dtype: str, version: Version, _implementation: Implementation | ||
) -> DType: | ||
def non_object_native_to_narwhals_dtype(dtype: str, version: Version) -> DType: | ||
dtypes = import_dtypes_module(version) | ||
if dtype in {"int64", "Int64", "Int64[pyarrow]", "int64[pyarrow]"}: | ||
return dtypes.Int64() | ||
|
@@ -465,7 +463,7 @@ def native_to_narwhals_dtype( | |
return arrow_native_to_narwhals_dtype(native_dtype.to_arrow(), version) | ||
return arrow_native_to_narwhals_dtype(native_dtype.pyarrow_dtype, version) | ||
if dtype != "object": | ||
return non_object_native_to_narwhals_dtype(dtype, version, implementation) | ||
return non_object_native_to_narwhals_dtype(dtype, version) | ||
if implementation is Implementation.DASK: | ||
# Dask columns are lazy, so we can't inspect values. | ||
# The most useful assumption is probably String | ||
|
@@ -649,68 +647,25 @@ def narwhals_to_native_dtype( # noqa: PLR0915 | |
if isinstance_or_issubclass(dtype, dtypes.Enum): | ||
msg = "Converting to Enum is not (yet) supported" | ||
raise NotImplementedError(msg) | ||
if isinstance_or_issubclass(dtype, dtypes.List): | ||
from narwhals._arrow.utils import ( | ||
narwhals_to_native_dtype as arrow_narwhals_to_native_dtype, | ||
) | ||
|
||
if implementation is Implementation.PANDAS and backend_version >= (2, 2): | ||
try: | ||
import pandas as pd | ||
import pyarrow as pa # ignore-banned-import | ||
except ImportError as exc: # pragma: no cover | ||
msg = f"Unable to convert to {dtype} to to the following exception: {exc.msg}" | ||
raise ImportError(msg) from exc | ||
|
||
return pd.ArrowDtype( | ||
pa.list_( | ||
value_type=arrow_narwhals_to_native_dtype( | ||
dtype.inner, # type: ignore[union-attr] | ||
version=version, | ||
) | ||
) | ||
) | ||
else: # pragma: no cover | ||
msg = ( | ||
"Converting to List dtype is not supported for implementation " | ||
f"{implementation} and version {version}." | ||
) | ||
return NotImplementedError(msg) | ||
if isinstance_or_issubclass(dtype, dtypes.Struct): | ||
if isinstance_or_issubclass(dtype, (dtypes.Struct, dtypes.Array, dtypes.List)): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This diff is quite nice π |
||
if implementation is Implementation.PANDAS and backend_version >= (2, 2): | ||
try: | ||
import pandas as pd | ||
import pyarrow as pa # ignore-banned-import | ||
import pyarrow as pa # ignore-banned-import # noqa: F401 | ||
except ImportError as exc: # pragma: no cover | ||
msg = f"Unable to convert to {dtype} to to the following exception: {exc.msg}" | ||
raise ImportError(msg) from exc | ||
from narwhals._arrow.utils import ( | ||
narwhals_to_native_dtype as arrow_narwhals_to_native_dtype, | ||
) | ||
|
||
return pd.ArrowDtype( | ||
pa.struct( | ||
[ | ||
( | ||
field.name, | ||
arrow_narwhals_to_native_dtype( | ||
field.dtype, | ||
version=version, | ||
), | ||
) | ||
for field in dtype.fields # type: ignore[union-attr] | ||
] | ||
) | ||
) | ||
return pd.ArrowDtype(arrow_narwhals_to_native_dtype(dtype, version=version)) | ||
else: # pragma: no cover | ||
msg = ( | ||
"Converting to Struct dtype is not supported for implementation " | ||
f"Converting to {dtype} dtype is not supported for implementation " | ||
f"{implementation} and version {version}." | ||
) | ||
return NotImplementedError(msg) | ||
if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover | ||
msg = "Converting to Array dtype is not supported yet" | ||
return NotImplementedError(msg) | ||
raise NotImplementedError(msg) | ||
msg = f"Unknown dtype: {dtype}" # pragma: no cover | ||
raise AssertionError(msg) | ||
|
||
|
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -731,14 +731,37 @@ class Array(DType): | |||||||
Array(Int32, 2) | ||||||||
""" | ||||||||
|
||||||||
inner: DType | type[DType] | ||||||||
size: int | ||||||||
shape: tuple[int, ...] | ||||||||
|
||||||||
def __init__( | ||||||||
self: Self, inner: DType | type[DType], width: int | None = None | ||||||||
self: Self, | ||||||||
inner: DType | type[DType], | ||||||||
shape: int | tuple[int, ...] | None = None, | ||||||||
Comment on lines
-789
to
+795
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a breaking change. However by checking how narwhals is used on github, I could not find any instances of Array |
||||||||
) -> None: | ||||||||
self.inner = inner | ||||||||
if width is None: | ||||||||
error = "`width` must be specified when initializing an `Array`" | ||||||||
raise TypeError(error) | ||||||||
self.width = width | ||||||||
inner_shape: tuple[int, ...] = inner.shape if isinstance(inner, Array) else () | ||||||||
|
||||||||
if shape is None: # pragma: no cover | ||||||||
msg = "Array constructor is missing the required argument `shape`" | ||||||||
raise TypeError(msg) | ||||||||
|
||||||||
if isinstance(shape, int): | ||||||||
self.inner = inner | ||||||||
self.size = shape | ||||||||
self.shape = (shape, *inner_shape) | ||||||||
|
||||||||
elif isinstance(shape, tuple): | ||||||||
if len(shape) > 1: | ||||||||
inner = Array(inner, shape[1:]) | ||||||||
|
||||||||
self.inner = inner | ||||||||
self.size = shape[0] | ||||||||
self.shape = shape + inner_shape | ||||||||
|
||||||||
else: | ||||||||
msg = f"invalid input for shape: {shape!r}" | ||||||||
raise TypeError(msg) | ||||||||
|
||||||||
def __eq__(self: Self, other: DType | type[DType]) -> bool: # type: ignore[override] | ||||||||
# This equality check allows comparison of type classes and type instances. | ||||||||
|
@@ -751,16 +774,24 @@ def __eq__(self: Self, other: DType | type[DType]) -> bool: # type: ignore[over | |||||||
if type(other) is type and issubclass(other, self.__class__): | ||||||||
return True | ||||||||
elif isinstance(other, self.__class__): | ||||||||
return self.inner == other.inner | ||||||||
if self.shape != other.shape: | ||||||||
return False | ||||||||
else: | ||||||||
return self.inner == other.inner | ||||||||
else: | ||||||||
return False | ||||||||
|
||||||||
def __hash__(self: Self) -> int: | ||||||||
return hash((self.__class__, self.inner, self.width)) | ||||||||
return hash((self.__class__, self.inner, self.shape)) | ||||||||
|
||||||||
def __repr__(self) -> str: | ||||||||
# Get leaf type | ||||||||
dtype = self.inner | ||||||||
while isinstance(dtype, Array): | ||||||||
dtype = dtype.inner | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we put a limit on the depth of this (just to avoid some infinite loop in some unexpected scenario)? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Disclaimer, from polars: https://github.com/pola-rs/polars/blob/faad12f7277751006e3faebf0fffb1f6bf9aa7e7/py-polars/polars/datatypes/classes.py#L905-L912 and the Lines 808 to 810 in 981f87c
In principle, max depth should be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i was thinking something like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I remember the trick, yet I don't feel super comfortable with introducing a fixed number of depth to check for. Since we know already that the max depth should be In practice, if someone generate a datatype to pass which ends up in an infinite recursion, maybe it's good to raise with sure error instead of an There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, shall we use len shape? In theory this shouldn't happen, just trying to guard against implementation oversights |
||||||||
|
||||||||
def __repr__(self: Self) -> str: | ||||||||
class_name = self.__class__.__name__ | ||||||||
return f"{class_name}({self.inner!r}, {self.width})" | ||||||||
return f"{class_name}({dtype!r}, shape={self.shape})" | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This diff is that's causing marimo ci to fail - we could provide a fix there π There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It also wouldn't be too bad to override repr in v1 right? Not saying we necessarily have to do that here though There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we maintain |
||||||||
|
||||||||
|
||||||||
class Date(TemporalType): | ||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Otherwise for list/array types we end up getting the first element