diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index 5db5cdc52..006fd4fbb 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -75,7 +75,6 @@ def native_to_narwhals_dtype(dtype: pa.DataType, version: Version) -> DType: for i in range(dtype.num_fields) ] ) - if pa.types.is_list(dtype) or pa.types.is_large_list(dtype): return dtypes.List(native_to_narwhals_dtype(dtype.value_type, version)) if pa.types.is_fixed_size_list(dtype): @@ -145,8 +144,13 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> pa ] ) if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover - msg = "Converting to Array dtype is not supported yet" - return NotImplementedError(msg) + inner = narwhals_to_native_dtype( + dtype.inner, # type: ignore[union-attr] + version=version, + ) + list_size = dtype.size # type: ignore[union-attr] + return pa.list_(inner, list_size=list_size) + msg = f"Unknown dtype: {dtype}" # pragma: no cover raise AssertionError(msg) @@ -224,7 +228,7 @@ def broadcast_and_extract_dataframe_comparand( if isinstance(other, ArrowSeries): len_other = len(other) - if len_other == 1: + if len_other == 1 and length != 1: import numpy as np # ignore-banned-import value = other._native_series[0] diff --git a/narwhals/_duckdb/utils.py b/narwhals/_duckdb/utils.py index 05c80c200..eb69d86b9 100644 --- a/narwhals/_duckdb/utils.py +++ b/narwhals/_duckdb/utils.py @@ -129,10 +129,13 @@ def native_to_narwhals_dtype(duckdb_dtype: str, version: Version) -> DType: ) if match_ := re.match(r"(.*)\[\]$", duckdb_dtype): return dtypes.List(native_to_narwhals_dtype(match_.group(1), version)) - if match_ := re.match(r"(\w+)\[(\d+)\]", duckdb_dtype): + if match_ := re.match(r"(\w+)((?:\[\d+\])+)", duckdb_dtype): + duckdb_inner_type = match_.group(1) + duckdb_shape = match_.group(2) + shape = tuple(int(value) for value in re.findall(r"\[(\d+)\]", duckdb_shape)) return dtypes.Array( - native_to_narwhals_dtype(match_.group(1), version), - int(match_.group(2)), + inner=native_to_narwhals_dtype(duckdb_inner_type, version), + shape=shape, ) if duckdb_dtype.startswith("DECIMAL("): return dtypes.Decimal() @@ -189,8 +192,13 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> st ) return f"STRUCT({inner})" if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover - msg = "todo" - raise NotImplementedError(msg) + shape: tuple[int] = dtype.shape # type: ignore[union-attr] + duckdb_shape_fmt = "".join(f"[{item}]" for item in shape) + inner_dtype = dtype + for _ in shape: + inner_dtype = inner_dtype.inner # type: ignore[union-attr] + duckdb_inner = narwhals_to_native_dtype(inner_dtype, version) + return f"{duckdb_inner}{duckdb_shape_fmt}" msg = f"Unknown dtype: {dtype}" # pragma: no cover raise AssertionError(msg) diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 28ca092a0..da9990095 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -167,7 +167,7 @@ def broadcast_and_extract_dataframe_comparand(index: Any, other: Any) -> Any: if isinstance(other, PandasLikeSeries): len_other = other.len() - if len_other == 1: + if len_other == 1 and len(index) != 1: # broadcast s = other._native_series return s.__class__(s.iloc[0], index=index, dtype=s.dtype, name=s.name) @@ -394,9 +394,7 @@ def rename( @lru_cache(maxsize=16) -def non_object_native_to_narwhals_dtype( - dtype: str, version: Version, _implementation: Implementation -) -> DType: +def non_object_native_to_narwhals_dtype(dtype: str, version: Version) -> DType: dtypes = import_dtypes_module(version) if dtype in {"int64", "Int64", "Int64[pyarrow]", "int64[pyarrow]"}: return dtypes.Int64() @@ -491,7 +489,7 @@ def native_to_narwhals_dtype( return arrow_native_to_narwhals_dtype(native_dtype.to_arrow(), version) return arrow_native_to_narwhals_dtype(native_dtype.pyarrow_dtype, version) if str_dtype != "object": - return non_object_native_to_narwhals_dtype(str_dtype, version, implementation) + return non_object_native_to_narwhals_dtype(str_dtype, version) elif implementation is Implementation.DASK: # Per conversations with their maintainers, they don't support arbitrary # objects, so we can just return String. @@ -643,38 +641,11 @@ def narwhals_to_native_dtype( # noqa: PLR0915 if isinstance_or_issubclass(dtype, dtypes.Enum): msg = "Converting to Enum is not (yet) supported" raise NotImplementedError(msg) - if isinstance_or_issubclass(dtype, dtypes.List): - from narwhals._arrow.utils import ( - narwhals_to_native_dtype as arrow_narwhals_to_native_dtype, - ) - - if implementation is Implementation.PANDAS and backend_version >= (2, 2): - try: - import pandas as pd - import pyarrow as pa # ignore-banned-import - except ImportError as exc: # pragma: no cover - msg = f"Unable to convert to {dtype} to to the following exception: {exc.msg}" - raise ImportError(msg) from exc - - return pd.ArrowDtype( - pa.list_( - value_type=arrow_narwhals_to_native_dtype( - dtype.inner, # type: ignore[union-attr] - version=version, - ) - ) - ) - else: # pragma: no cover - msg = ( - "Converting to List dtype is not supported for implementation " - f"{implementation} and version {version}." - ) - return NotImplementedError(msg) - if isinstance_or_issubclass(dtype, dtypes.Struct): + if isinstance_or_issubclass(dtype, (dtypes.Struct, dtypes.Array, dtypes.List)): if implementation is Implementation.PANDAS and backend_version >= (2, 2): try: import pandas as pd - import pyarrow as pa # ignore-banned-import + import pyarrow as pa # ignore-banned-import # noqa: F401 except ImportError as exc: # pragma: no cover msg = f"Unable to convert to {dtype} to to the following exception: {exc.msg}" raise ImportError(msg) from exc @@ -682,29 +653,13 @@ def narwhals_to_native_dtype( # noqa: PLR0915 narwhals_to_native_dtype as arrow_narwhals_to_native_dtype, ) - return pd.ArrowDtype( - pa.struct( - [ - ( - field.name, - arrow_narwhals_to_native_dtype( - field.dtype, - version=version, - ), - ) - for field in dtype.fields # type: ignore[union-attr] - ] - ) - ) + return pd.ArrowDtype(arrow_narwhals_to_native_dtype(dtype, version=version)) else: # pragma: no cover msg = ( - "Converting to Struct dtype is not supported for implementation " + f"Converting to {dtype} dtype is not supported for implementation " f"{implementation} and version {version}." ) - return NotImplementedError(msg) - if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover - msg = "Converting to Array dtype is not supported yet" - return NotImplementedError(msg) + raise NotImplementedError(msg) msg = f"Unknown dtype: {dtype}" # pragma: no cover raise AssertionError(msg) diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index 64cf907ef..8a1160d76 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -104,8 +104,6 @@ def _from_native_object( def __getattr__(self: Self, attr: str) -> Any: def func(*args: Any, **kwargs: Any) -> Any: - import polars as pl - args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] try: return self._from_native_object( @@ -179,7 +177,6 @@ def __getitem__(self: Self, item: Any) -> Any: ) msg = f"Expected slice of integers or strings, got: {type(item[1])}" # pragma: no cover raise TypeError(msg) # pragma: no cover - import polars as pl if ( isinstance(item, tuple) diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index 98a692a67..2b6969a3d 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -47,7 +47,7 @@ def func(*args: Any, **kwargs: Any) -> Any: def cast(self: Self, dtype: DType) -> Self: expr = self._native_expr - dtype_pl = narwhals_to_native_dtype(dtype, self._version) + dtype_pl = narwhals_to_native_dtype(dtype, self._version, self._backend_version) return self._from_native_expr(expr.cast(dtype_pl)) def ewm_mean( @@ -79,8 +79,6 @@ def ewm_mean( **extra_kwargs, ) if self._backend_version < (1,): # pragma: no cover - import polars as pl - return self._from_native_expr( pl.when(~expr.is_null()).then(native_expr).otherwise(None) ) @@ -193,7 +191,9 @@ def map_batches( return_dtype: DType | None, ) -> Self: if return_dtype is not None: - return_dtype_pl = narwhals_to_native_dtype(return_dtype, self._version) + return_dtype_pl = narwhals_to_native_dtype( + return_dtype, self._version, self._backend_version + ) return self._from_native_expr( self._native_expr.map_batches(function, return_dtype_pl) ) @@ -205,7 +205,7 @@ def replace_strict( ) -> Self: expr = self._native_expr return_dtype_pl = ( - narwhals_to_native_dtype(return_dtype, self._version) + narwhals_to_native_dtype(return_dtype, self._version, self._backend_version) if return_dtype else None ) @@ -350,14 +350,10 @@ def len(self: Self) -> PolarsExpr: native_result = native_expr.list.len() if self._expr._backend_version < (1, 16): # pragma: no cover - import polars as pl - native_result: pl.Expr = ( # type: ignore[no-redef] pl.when(~native_expr.is_null()).then(native_result).cast(pl.UInt32()) ) elif self._expr._backend_version < (1, 17): # pragma: no cover - import polars as pl - native_result = native_result.cast(pl.UInt32()) return self._expr._from_native_expr(native_result) diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index e1e00ba18..7d64a020d 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -116,7 +116,12 @@ def lit(self: Self, value: Any, dtype: DType | None) -> PolarsExpr: if dtype is not None: return PolarsExpr( - pl.lit(value, dtype=narwhals_to_native_dtype(dtype, self._version)), + pl.lit( + value, + dtype=narwhals_to_native_dtype( + dtype, self._version, self._backend_version + ), + ), version=self._version, backend_version=self._backend_version, ) @@ -210,9 +215,11 @@ def by_dtype(self: Self, dtypes: Iterable[DType]) -> PolarsExpr: from narwhals._polars.expr import PolarsExpr native_dtypes = [ - narwhals_to_native_dtype(dtype, self._version).__class__ + narwhals_to_native_dtype( + dtype, self._version, self._backend_version + ).__class__ if isinstance(dtype, type) and issubclass(dtype, DType) - else narwhals_to_native_dtype(dtype, self._version) + else narwhals_to_native_dtype(dtype, self._version, self._backend_version) for dtype in dtypes ] return PolarsExpr( @@ -233,8 +240,6 @@ def matches(self: Self, pattern: str) -> PolarsExpr: ) def numeric(self: Self) -> PolarsExpr: - import polars as pl - from narwhals._polars.expr import PolarsExpr return PolarsExpr( @@ -244,8 +249,6 @@ def numeric(self: Self) -> PolarsExpr: ) def boolean(self: Self) -> PolarsExpr: - import polars as pl - from narwhals._polars.expr import PolarsExpr return PolarsExpr( @@ -255,8 +258,6 @@ def boolean(self: Self) -> PolarsExpr: ) def string(self: Self) -> PolarsExpr: - import polars as pl - from narwhals._polars.expr import PolarsExpr return PolarsExpr( @@ -266,8 +267,6 @@ def string(self: Self) -> PolarsExpr: ) def categorical(self: Self) -> PolarsExpr: - import polars as pl - from narwhals._polars.expr import PolarsExpr return PolarsExpr( @@ -277,8 +276,6 @@ def categorical(self: Self) -> PolarsExpr: ) def all(self: Self) -> PolarsExpr: - import polars as pl - from narwhals._polars.expr import PolarsExpr return PolarsExpr( diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index 5bc2621dc..981812736 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -77,8 +77,6 @@ def _from_native_object(self: Self, series: T) -> T: ... def _from_native_object( self: Self, series: pl.Series | pl.DataFrame | T ) -> Self | PolarsDataFrame | T: - import polars as pl - if isinstance(series, pl.Series): return self._from_native_series(series) if isinstance(series, pl.DataFrame): @@ -130,7 +128,7 @@ def __getitem__(self: Self, item: int | slice | Sequence[int]) -> Any | Self: def cast(self: Self, dtype: DType) -> Self: ser = self._native_series - dtype_pl = narwhals_to_native_dtype(dtype, self._version) + dtype_pl = narwhals_to_native_dtype(dtype, self._version, self._backend_version) return self._from_native_series(ser.cast(dtype_pl)) def replace_strict( @@ -138,7 +136,7 @@ def replace_strict( ) -> Self: ser = self._native_series dtype = ( - narwhals_to_native_dtype(return_dtype, self._version) + narwhals_to_native_dtype(return_dtype, self._version, self._backend_version) if return_dtype else None ) @@ -244,8 +242,6 @@ def median(self: Self) -> Any: return self._native_series.median() def to_dummies(self: Self, *, separator: str, drop_first: bool) -> PolarsDataFrame: - import polars as pl - from narwhals._polars.dataframe import PolarsDataFrame if self._backend_version < (0, 20, 15): @@ -294,8 +290,6 @@ def ewm_mean( **extra_kwargs, ) if self._backend_version < (1,): # pragma: no cover - import polars as pl - return self._from_native_series( pl.select( pl.when(~native_series.is_null()).then(native_result).otherwise(None) @@ -405,8 +399,6 @@ def sort(self: Self, *, descending: bool, nulls_last: bool) -> Self: result = self._native_series.sort(descending=descending) if nulls_last: - import polars as pl - is_null = result.is_null() result = pl.concat([result.filter(~is_null), result.filter(is_null)]) else: @@ -433,8 +425,6 @@ def value_counts( from narwhals._polars.dataframe import PolarsDataFrame if self._backend_version < (1, 0, 0): - import polars as pl - value_name_ = name or ("proportion" if normalize else "count") result = self._native_series.value_counts(sort=sort, parallel=parallel) @@ -547,15 +537,11 @@ def len(self: Self) -> PolarsSeries: native_result = native_series.list.len() if self._series._backend_version < (1, 16): # pragma: no cover - import polars as pl - native_result = pl.select( pl.when(~native_series.is_null()).then(native_result).otherwise(None) )[native_series.name].cast(pl.UInt32()) elif self._series._backend_version < (1, 17): # pragma: no cover - import polars as pl - native_result = native_series.cast(pl.UInt32()) return self._series._from_native_series(native_result) diff --git a/narwhals/_polars/utils.py b/narwhals/_polars/utils.py index 929830316..25719157d 100644 --- a/narwhals/_polars/utils.py +++ b/narwhals/_polars/utils.py @@ -132,24 +132,19 @@ def native_to_narwhals_dtype( native_to_narwhals_dtype(dtype.inner, version, backend_version) # type: ignore[attr-defined] ) if dtype == pl.Array: - if backend_version < (0, 20, 30): # pragma: no cover - return dtypes.Array( - native_to_narwhals_dtype(dtype.inner, version, backend_version), # type: ignore[attr-defined] - dtype.width, # type: ignore[attr-defined] - ) - else: - return dtypes.Array( - native_to_narwhals_dtype(dtype.inner, version, backend_version), # type: ignore[attr-defined] - dtype.size, # type: ignore[attr-defined] - ) + outer_shape = dtype.width if backend_version < (0, 20, 30) else dtype.size # type: ignore[attr-defined] + return dtypes.Array( + inner=native_to_narwhals_dtype(dtype.inner, version, backend_version), # type: ignore[attr-defined] + shape=outer_shape, + ) if dtype == pl.Decimal: return dtypes.Decimal() return dtypes.Unknown() -def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> pl.DataType: - import polars as pl - +def narwhals_to_native_dtype( + dtype: DType | type[DType], version: Version, backend_version: tuple[int, ...] +) -> pl.DataType: dtypes = import_dtypes_module(version) if dtype == dtypes.Float64: @@ -193,20 +188,24 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> pl du_time_unit: TimeUnit = getattr(dtype, "time_unit", "us") return pl.Duration(time_unit=du_time_unit) # type: ignore[arg-type] if dtype == dtypes.List: - return pl.List(narwhals_to_native_dtype(dtype.inner, version)) # type: ignore[union-attr] + return pl.List(narwhals_to_native_dtype(dtype.inner, version, backend_version)) # type: ignore[union-attr] if dtype == dtypes.Struct: return pl.Struct( fields=[ pl.Field( name=field.name, - dtype=narwhals_to_native_dtype(field.dtype, version), + dtype=narwhals_to_native_dtype(field.dtype, version, backend_version), ) for field in dtype.fields # type: ignore[union-attr] ] ) if dtype == dtypes.Array: # pragma: no cover - msg = "Converting to Array dtype is not supported yet" - raise NotImplementedError(msg) + size = dtype.size # type: ignore[union-attr] + kwargs = {"width": size} if backend_version < (0, 20, 30) else {"shape": size} + return pl.Array( + inner=narwhals_to_native_dtype(dtype.inner, version, backend_version), # type: ignore[union-attr] + **kwargs, + ) return pl.Unknown() # pragma: no cover diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py index 7ef7f6b01..f4892cc28 100644 --- a/narwhals/_spark_like/utils.py +++ b/narwhals/_spark_like/utils.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: from types import ModuleType - import pyspark.types as pyspark_types + import pyspark.sql.types as pyspark_types from pyspark.sql import Column from narwhals._spark_like.dataframe import SparkLikeLazyFrame @@ -58,7 +58,8 @@ def native_to_narwhals_dtype( if isinstance(dtype, spark_types.ByteType): return dtypes.Int8() if isinstance( - dtype, (spark_types.StringType, spark_types.VarcharType, spark_types.CharType) + dtype, + (spark_types.StringType, spark_types.VarcharType, spark_types.CharType), ): return dtypes.String() if isinstance(dtype, spark_types.BooleanType): @@ -123,8 +124,13 @@ def narwhals_to_native_dtype( msg = "Converting to Struct dtype is not supported yet" raise NotImplementedError(msg) if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover - msg = "Converting to Array dtype is not supported yet" - raise NotImplementedError(msg) + inner = narwhals_to_native_dtype( + dtype.inner, # type: ignore[union-attr] + version=version, + spark_types=spark_types, + ) + return spark_types.ArrayType(elementType=inner) + if isinstance_or_issubclass( dtype, (dtypes.UInt64, dtypes.UInt32, dtypes.UInt16, dtypes.UInt8) ): # pragma: no cover diff --git a/narwhals/dtypes.py b/narwhals/dtypes.py index 4a9361073..7d9cf3136 100644 --- a/narwhals/dtypes.py +++ b/narwhals/dtypes.py @@ -765,7 +765,7 @@ class Array(NestedType): Arguments: inner: The datatype of the values within each array. - width: the length of each array. + shape: The shape of the arrays. Examples: >>> import pandas as pd @@ -778,21 +778,44 @@ class Array(NestedType): >>> ser_pa = pa.chunked_array([data], type=pa.list_(pa.int32(), 2)) >>> nw.from_native(ser_pd, series_only=True).dtype - Array(Int32, 2) + Array(Int32, shape=(2,)) >>> nw.from_native(ser_pl, series_only=True).dtype - Array(Int32, 2) + Array(Int32, shape=(2,)) >>> nw.from_native(ser_pa, series_only=True).dtype - Array(Int32, 2) + Array(Int32, shape=(2,)) """ + inner: DType | type[DType] + size: int + shape: tuple[int, ...] + def __init__( - self: Self, inner: DType | type[DType], width: int | None = None + self: Self, + inner: DType | type[DType], + shape: int | tuple[int, ...] | None = None, ) -> None: - self.inner = inner - if width is None: - error = "`width` must be specified when initializing an `Array`" - raise TypeError(error) - self.width = width + inner_shape: tuple[int, ...] = inner.shape if isinstance(inner, Array) else () + + if shape is None: # pragma: no cover + msg = "Array constructor is missing the required argument `shape`" + raise TypeError(msg) + + if isinstance(shape, int): + self.inner = inner + self.size = shape + self.shape = (shape, *inner_shape) + + elif isinstance(shape, tuple) and isinstance(shape[0], int): + if len(shape) > 1: + inner = Array(inner, shape[1:]) + + self.inner = inner + self.size = shape[0] + self.shape = shape + inner_shape + + else: + msg = f"invalid input for shape: {shape!r}" + raise TypeError(msg) def __eq__(self: Self, other: DType | type[DType]) -> bool: # type: ignore[override] # This equality check allows comparison of type classes and type instances. @@ -805,16 +828,24 @@ def __eq__(self: Self, other: DType | type[DType]) -> bool: # type: ignore[over if type(other) is type and issubclass(other, self.__class__): return True elif isinstance(other, self.__class__): - return self.inner == other.inner + if self.shape != other.shape: + return False + else: + return self.inner == other.inner else: return False def __hash__(self: Self) -> int: - return hash((self.__class__, self.inner, self.width)) + return hash((self.__class__, self.inner, self.shape)) + + def __repr__(self) -> str: + # Get leaf type + dtype_ = self + for _ in self.shape: + dtype_ = dtype_.inner # type: ignore[assignment] - def __repr__(self: Self) -> str: class_name = self.__class__.__name__ - return f"{class_name}({self.inner!r}, {self.width})" + return f"{class_name}({dtype_!r}, shape={self.shape})" class Date(TemporalType): diff --git a/narwhals/functions.py b/narwhals/functions.py index 052578bcf..51b80d11d 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -243,7 +243,10 @@ def _new_series_impl( narwhals_to_native_dtype as polars_narwhals_to_native_dtype, ) - dtype_pl = polars_narwhals_to_native_dtype(dtype, version=version) + backend_version = parse_version(native_namespace.__version__) + dtype_pl = polars_narwhals_to_native_dtype( + dtype, version=version, backend_version=backend_version + ) else: dtype_pl = None @@ -498,8 +501,13 @@ def _from_numpy_impl( narwhals_to_native_dtype as polars_narwhals_to_native_dtype, ) + backend_version = parse_version(native_namespace.__version__) schema = { - name: polars_narwhals_to_native_dtype(dtype, version=version) # type: ignore[misc] + name: polars_narwhals_to_native_dtype( # type: ignore[misc] + dtype, + version=version, + backend_version=backend_version, + ) for name, dtype in schema.items() } elif schema is None: diff --git a/narwhals/schema.py b/narwhals/schema.py index e6cb0dddb..b287e1aaf 100644 --- a/narwhals/schema.py +++ b/narwhals/schema.py @@ -202,10 +202,18 @@ def to_polars(self: Self) -> pl.Schema: from narwhals._polars.utils import narwhals_to_native_dtype + pl_version = parse_version(pl) schema = ( - (name, narwhals_to_native_dtype(dtype, self._version)) + ( + name, + narwhals_to_native_dtype( + dtype, self._version, backend_version=pl_version + ), + ) for name, dtype in self.items() ) - if parse_version(pl) >= (1, 0, 0): - return pl.Schema(schema) - return cast("pl.Schema", dict(schema)) # pragma: no cover + return ( + pl.Schema(schema) + if pl_version >= (1, 0, 0) + else cast("pl.Schema", dict(schema)) + ) diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index bd567478d..f523c3a63 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -3,6 +3,7 @@ from datetime import datetime from datetime import timedelta from datetime import timezone +from typing import TYPE_CHECKING from typing import Literal import numpy as np @@ -15,6 +16,9 @@ from tests.utils import PANDAS_VERSION from tests.utils import POLARS_VERSION +if TYPE_CHECKING: + from tests.utils import Constructor + @pytest.mark.parametrize("time_unit", ["us", "ns", "ms"]) @pytest.mark.parametrize("time_zone", ["Europe/Rome", timezone.utc, None]) @@ -73,9 +77,10 @@ def test_array_valid() -> None: dtype = nw.Array(nw.Int64, 2) assert dtype == nw.Array(nw.Int64, 2) assert dtype == nw.Array + assert dtype != nw.Array(nw.Int64, 3) assert dtype != nw.Array(nw.Float32, 2) assert dtype != nw.Duration - assert repr(dtype) == "Array(, 2)" + assert repr(dtype) == "Array(, shape=(2,))" dtype = nw.Array(nw.Array(nw.Int64, 2), 2) assert dtype == nw.Array(nw.Array(nw.Int64, 2), 2) assert dtype == nw.Array @@ -83,9 +88,12 @@ def test_array_valid() -> None: assert dtype in {nw.Array(nw.Array(nw.Int64, 2), 2)} with pytest.raises( - TypeError, match="`width` must be specified when initializing an `Array`" + TypeError, match="Array constructor is missing the required argument `shape`" ): - dtype = nw.Array(nw.Int64) + nw.Array(nw.Int64) + + with pytest.raises(TypeError, match="invalid input for shape"): + nw.Array(nw.Int64(), shape="invalid_type") # type: ignore[arg-type] def test_struct_valid() -> None: @@ -125,21 +133,16 @@ def test_struct_hashes() -> None: assert len({hash(tp) for tp in (dtypes)}) == 3 -@pytest.mark.skipif( - POLARS_VERSION < (1,) or PANDAS_VERSION < (2, 2), - reason="`shape` is only available after 1.0", -) -def test_polars_2d_array() -> None: - df = pl.DataFrame( - {"a": [[[1, 2], [3, 4], [5, 6]]]}, schema={"a": pl.Array(pl.Int64, (3, 2))} - ) - assert nw.from_native(df).collect_schema()["a"] == nw.Array(nw.Array(nw.Int64, 2), 3) - assert nw.from_native(df.to_arrow()).collect_schema()["a"] == nw.Array( - nw.Array(nw.Int64, 2), 3 +@pytest.mark.skipif(PANDAS_VERSION < (2, 2), reason="old pandas") +def test_2d_array(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if any(x in str(constructor) for x in ("dask", "modin", "cudf", "pyspark")): + request.applymarker(pytest.mark.xfail) + data = {"a": [[[1, 2], [3, 4], [5, 6]]]} + df = nw.from_native(constructor(data)).with_columns( + a=nw.col("a").cast(nw.Array(nw.Int64(), (3, 2))) ) - assert nw.from_native( - df.to_pandas(use_pyarrow_extension_array=True) - ).collect_schema()["a"] == nw.Array(nw.Array(nw.Int64, 2), 3) + assert df.collect_schema()["a"] == nw.Array(nw.Int64(), (3, 2)) + assert df.collect_schema()["a"] == nw.Array(nw.Array(nw.Int64(), 2), 3) def test_second_time_unit() -> None: