Skip to content

Commit 2235bae

Browse files
authored
fix: narwhals_to_native_dtype raise if polars dtype is passed (#1307)
* fix: raise for non-narwhals dtypes * check for polars specifically * mypy ignore
1 parent abd9d4a commit 2235bae

File tree

8 files changed

+54
-12
lines changed

8 files changed

+54
-12
lines changed

narwhals/_arrow/utils.py

+12
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any
55
from typing import Sequence
66

7+
from narwhals.dependencies import get_polars
78
from narwhals.utils import isinstance_or_issubclass
89

910
if TYPE_CHECKING:
@@ -76,6 +77,17 @@ def native_to_narwhals_dtype(dtype: Any, dtypes: DTypes) -> DType:
7677

7778

7879
def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any:
80+
if (pl := get_polars()) is not None and isinstance(
81+
dtype, (pl.DataType, pl.DataType.__class__)
82+
):
83+
msg = (
84+
f"Expected Narwhals object, got: {type(dtype)}.\n\n"
85+
"Perhaps you:\n"
86+
"- Forgot a `nw.from_native` somewhere?\n"
87+
"- Used `pl.Int64` instead of `nw.Int64`?"
88+
)
89+
raise TypeError(msg)
90+
7991
import pyarrow as pa # ignore-banned-import
8092

8193
if isinstance_or_issubclass(dtype, dtypes.Float64):

narwhals/_dask/utils.py

+12
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any
55

66
from narwhals.dependencies import get_pandas
7+
from narwhals.dependencies import get_polars
78
from narwhals.dependencies import get_pyarrow
89
from narwhals.utils import isinstance_or_issubclass
910
from narwhals.utils import parse_version
@@ -85,6 +86,17 @@ def validate_comparand(lhs: dask_expr.Series, rhs: dask_expr.Series) -> None:
8586

8687

8788
def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any:
89+
if (pl := get_polars()) is not None and isinstance(
90+
dtype, (pl.DataType, pl.DataType.__class__)
91+
):
92+
msg = (
93+
f"Expected Narwhals object, got: {type(dtype)}.\n\n"
94+
"Perhaps you:\n"
95+
"- Forgot a `nw.from_native` somewhere?\n"
96+
"- Used `pl.Int64` instead of `nw.Int64`?"
97+
)
98+
raise TypeError(msg)
99+
88100
if isinstance_or_issubclass(dtype, dtypes.Float64):
89101
return "float64"
90102
if isinstance_or_issubclass(dtype, dtypes.Float32):

narwhals/_pandas_like/utils.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from narwhals._arrow.utils import (
1111
native_to_narwhals_dtype as arrow_native_to_narwhals_dtype,
1212
)
13+
from narwhals.dependencies import get_polars
1314
from narwhals.utils import Implementation
1415
from narwhals.utils import isinstance_or_issubclass
1516

@@ -376,7 +377,9 @@ def narwhals_to_native_dtype( # noqa: PLR0915
376377
backend_version: tuple[int, ...],
377378
dtypes: DTypes,
378379
) -> Any:
379-
if "polars" in str(type(dtype)):
380+
if (pl := get_polars()) is not None and isinstance(
381+
dtype, (pl.DataType, pl.DataType.__class__)
382+
):
380383
msg = (
381384
f"Expected Narwhals object, got: {type(dtype)}.\n\n"
382385
"Perhaps you:\n"

narwhals/_polars/utils.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from narwhals.dtypes import DType
99
from narwhals.typing import DTypes
1010

11+
from narwhals.dependencies import get_polars
1112
from narwhals.utils import parse_version
1213

1314

@@ -94,6 +95,17 @@ def native_to_narwhals_dtype(dtype: Any, dtypes: DTypes) -> DType:
9495

9596

9697
def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any:
98+
if (pl := get_polars()) is not None and isinstance(
99+
dtype, (pl.DataType, pl.DataType.__class__)
100+
):
101+
msg = (
102+
f"Expected Narwhals object, got: {type(dtype)}.\n\n"
103+
"Perhaps you:\n"
104+
"- Forgot a `nw.from_native` somewhere?\n"
105+
"- Used `pl.Int64` instead of `nw.Int64`?"
106+
)
107+
raise TypeError(msg)
108+
97109
import polars as pl # ignore-banned-import()
98110

99111
if dtype == dtypes.Float64:
@@ -132,7 +144,7 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any:
132144
if dtype == dtypes.Datetime or isinstance(dtype, dtypes.Datetime):
133145
dt_time_unit = getattr(dtype, "time_unit", "us")
134146
dt_time_zone = getattr(dtype, "time_zone", None)
135-
return pl.Datetime(dt_time_unit, dt_time_zone) # type: ignore[arg-type]
147+
return pl.Datetime(dt_time_unit, dt_time_zone)
136148
if dtype == dtypes.Duration or isinstance(dtype, dtypes.Duration):
137149
du_time_unit: Literal["us", "ns", "ms"] = getattr(dtype, "time_unit", "us")
138150
return pl.Duration(time_unit=du_time_unit)

narwhals/expr.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,7 @@ def pipe(self, function: Callable[[Any], Self], *args: Any, **kwargs: Any) -> Se
135135
"""
136136
return function(self, *args, **kwargs)
137137

138-
def cast(
139-
self,
140-
dtype: Any,
141-
) -> Self:
138+
def cast(self: Self, dtype: DType | type[DType]) -> Self:
142139
"""
143140
Redefine an object's data type.
144141

narwhals/series.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -383,10 +383,7 @@ def name(self) -> str:
383383
"""
384384
return self._compliant_series.name # type: ignore[no-any-return]
385385

386-
def cast(
387-
self,
388-
dtype: Any,
389-
) -> Self:
386+
def cast(self: Self, dtype: DType | type[DType]) -> Self:
390387
"""
391388
Cast between data types.
392389

tests/expr_and_series/cast_test.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
from datetime import datetime
44
from datetime import timedelta
55
from datetime import timezone
6+
from typing import Any
67

78
import pandas as pd
9+
import polars as pl
810
import pytest
911

1012
import narwhals.stable.v1 as nw
@@ -191,7 +193,7 @@ class Banana:
191193
pass
192194

193195
with pytest.raises(AssertionError, match=r"Unknown dtype"):
194-
df.select(nw.col("a").cast(Banana))
196+
df.select(nw.col("a").cast(Banana)) # type: ignore[arg-type]
195197

196198

197199
def test_cast_datetime_tz_aware(
@@ -222,3 +224,10 @@ def test_cast_datetime_tz_aware(
222224
.str.slice(offset=0, length=19)
223225
)
224226
assert_equal_data(result, expected)
227+
228+
229+
@pytest.mark.parametrize("dtype", [pl.String, pl.String()])
230+
def test_raise_if_polars_dtype(constructor: Constructor, dtype: Any) -> None:
231+
df = nw.from_native(constructor({"a": [1, 2, 3], "b": [4, 5, 6]}))
232+
with pytest.raises(TypeError, match="Expected Narwhals object, got:"):
233+
df.select(nw.col("a").cast(dtype))

tests/frame/invalid_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def test_invalid() -> None:
2020
with pytest.raises(TypeError, match="Perhaps you:"):
2121
df.select([pl.col("a")]) # type: ignore[list-item]
2222
with pytest.raises(TypeError, match="Perhaps you:"):
23-
df.select([nw.col("a").cast(pl.Int64)])
23+
df.select([nw.col("a").cast(pl.Int64)]) # type: ignore[arg-type]
2424

2525

2626
def test_native_vs_non_native() -> None:

0 commit comments

Comments
 (0)