Skip to content

Commit 6040d32

Browse files
committed
check for polars specifically
1 parent 487bfa8 commit 6040d32

File tree

7 files changed

+20
-32
lines changed

7 files changed

+20
-32
lines changed

narwhals/_arrow/utils.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any
55
from typing import Sequence
66

7+
from narwhals.dependencies import get_polars
78
from narwhals.utils import isinstance_or_issubclass
89

910
if TYPE_CHECKING:
@@ -76,7 +77,9 @@ def native_to_narwhals_dtype(dtype: Any, dtypes: DTypes) -> DType:
7677

7778

7879
def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any:
79-
if "polars" in str(type(dtype)):
80+
if (pl := get_polars()) is not None and isinstance(
81+
dtype, (pl.DataType, pl.DataType.__class__)
82+
):
8083
msg = (
8184
f"Expected Narwhals object, got: {type(dtype)}.\n\n"
8285
"Perhaps you:\n"

narwhals/_dask/utils.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any
55

66
from narwhals.dependencies import get_pandas
7+
from narwhals.dependencies import get_polars
78
from narwhals.dependencies import get_pyarrow
89
from narwhals.utils import isinstance_or_issubclass
910
from narwhals.utils import parse_version
@@ -85,7 +86,9 @@ def validate_comparand(lhs: dask_expr.Series, rhs: dask_expr.Series) -> None:
8586

8687

8788
def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any:
88-
if "polars" in str(type(dtype)):
89+
if (pl := get_polars()) is not None and isinstance(
90+
dtype, (pl.DataType, pl.DataType.__class__)
91+
):
8992
msg = (
9093
f"Expected Narwhals object, got: {type(dtype)}.\n\n"
9194
"Perhaps you:\n"

narwhals/_pandas_like/utils.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from narwhals._arrow.utils import (
1111
native_to_narwhals_dtype as arrow_native_to_narwhals_dtype,
1212
)
13+
from narwhals.dependencies import get_polars
1314
from narwhals.utils import Implementation
1415
from narwhals.utils import isinstance_or_issubclass
1516

@@ -339,7 +340,9 @@ def narwhals_to_native_dtype( # noqa: PLR0915
339340
backend_version: tuple[int, ...],
340341
dtypes: DTypes,
341342
) -> Any:
342-
if "polars" in str(type(dtype)):
343+
if (pl := get_polars()) is not None and isinstance(
344+
dtype, (pl.DataType, pl.DataType.__class__)
345+
):
343346
msg = (
344347
f"Expected Narwhals object, got: {type(dtype)}.\n\n"
345348
"Perhaps you:\n"

narwhals/_polars/utils.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from narwhals.dtypes import DType
99
from narwhals.typing import DTypes
1010

11+
from narwhals.dependencies import get_polars
1112
from narwhals.utils import parse_version
1213

1314

@@ -94,7 +95,9 @@ def native_to_narwhals_dtype(dtype: Any, dtypes: DTypes) -> DType:
9495

9596

9697
def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any:
97-
if "polars" in str(type(dtype)):
98+
if (pl := get_polars()) is not None and isinstance(
99+
dtype, (pl.DataType, pl.DataType.__class__)
100+
):
98101
msg = (
99102
f"Expected Narwhals object, got: {type(dtype)}.\n\n"
100103
"Perhaps you:\n"
@@ -141,7 +144,7 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any:
141144
if dtype == dtypes.Datetime or isinstance(dtype, dtypes.Datetime):
142145
dt_time_unit = getattr(dtype, "time_unit", "us")
143146
dt_time_zone = getattr(dtype, "time_zone", None)
144-
return pl.Datetime(dt_time_unit, dt_time_zone) # type: ignore[arg-type]
147+
return pl.Datetime(dt_time_unit, dt_time_zone)
145148
if dtype == dtypes.Duration or isinstance(dtype, dtypes.Duration):
146149
du_time_unit: Literal["us", "ns", "ms"] = getattr(dtype, "time_unit", "us")
147150
return pl.Duration(time_unit=du_time_unit)

narwhals/expr.py

-13
Original file line numberDiff line numberDiff line change
@@ -186,19 +186,6 @@ def cast(self: Self, dtype: DType | type[DType]) -> Self:
186186
foo: [[1,2,3]]
187187
bar: [[6,7,8]]
188188
"""
189-
# from narwhals.dtypes import DType
190-
191-
# if not (
192-
# isinstance(dtype, DType)
193-
# or (isinstance(dtype, type) and issubclass(dtype, DType))
194-
# ):
195-
# msg = (
196-
# f"Expected Narwhals DType, got: {type(dtype)}.\n\n"
197-
# "Hint: Perhaps you used Polars DataType instance `pl.dtype` instead of "
198-
# "Narwhals DType `nw.dtype`?"
199-
# )
200-
# raise TypeError(msg)
201-
202189
return self.__class__(
203190
lambda plx: self._call(plx).cast(dtype),
204191
)

narwhals/series.py

-10
Original file line numberDiff line numberDiff line change
@@ -420,16 +420,6 @@ def cast(self: Self, dtype: DType | type[DType]) -> Self:
420420
1
421421
]
422422
"""
423-
# from narwhals.dtypes import DType
424-
425-
# if not (isinstance(dtype, DType) or dtype == DType()):
426-
# msg = (
427-
# f"Expected Narwhals DType, got: {type(dtype)}.\n\n"
428-
# "Hint: Perhaps you used Polars DataType instance `pl.dtype` instead of "
429-
# "Narwhals DType `nw.dtype`?"
430-
# )
431-
# raise TypeError(msg)
432-
433423
return self._from_compliant_series(self._compliant_series.cast(dtype))
434424

435425
def to_frame(self) -> DataFrame[Any]:

tests/expr_and_series/cast_test.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import pandas as pd
99
import polars as pl
10-
import pyarrow as pa
1110
import pytest
1211

1312
import narwhals.stable.v1 as nw
@@ -194,7 +193,7 @@ class Banana:
194193
pass
195194

196195
with pytest.raises(AssertionError, match=r"Unknown dtype"):
197-
df.select(nw.col("a").cast(Banana))
196+
df.select(nw.col("a").cast(Banana)) # type: ignore[arg-type]
198197

199198

200199
def test_cast_datetime_tz_aware(
@@ -227,8 +226,8 @@ def test_cast_datetime_tz_aware(
227226
assert_equal_data(result, expected)
228227

229228

230-
@pytest.mark.parametrize("dtype", [pl.String, pl.String(), pa.float64(), str])
231-
def test_raise_if_not_narwhals_dtype(constructor: Constructor, dtype: Any) -> None:
229+
@pytest.mark.parametrize("dtype", [pl.String, pl.String()])
230+
def test_raise_if_polars_dtype(constructor: Constructor, dtype: Any) -> None:
232231
df = nw.from_native(constructor({"a": [1, 2, 3], "b": [4, 5, 6]}))
233232
with pytest.raises(TypeError, match="Expected Narwhals object, got:"):
234233
df.select(nw.col("a").cast(dtype))

0 commit comments

Comments
 (0)