Skip to content

Commit 54c49b0

Browse files
authored
chore(typing): Upgrade TypeGuard -> TypeIs (#2027)
* chore(typing): Upgrade `TypeGuard` -> `TypeIs` https://typing.readthedocs.io/en/latest/spec/narrowing.html#typeis * fix(typing): revert, add note on `is_numpy_scalar` * feat(typing): Preserve `TypeVar`(s) for `nw` guards
1 parent 825e027 commit 54c49b0

File tree

1 file changed

+36
-32
lines changed

1 file changed

+36
-32
lines changed

narwhals/dependencies.py

+36-32
Original file line numberDiff line numberDiff line change
@@ -8,28 +8,27 @@
88
from typing import Any
99

1010
if TYPE_CHECKING:
11-
import numpy as np
12-
import sqlframe
13-
14-
if sys.version_info >= (3, 10):
15-
from typing import TypeGuard
16-
else:
17-
from typing_extensions import TypeGuard
1811
import cudf
1912
import dask.dataframe as dd
2013
import duckdb
2114
import ibis
2215
import modin.pandas as mpd
16+
import numpy as np
2317
import pandas as pd
2418
import polars as pl
2519
import pyarrow as pa
2620
import pyspark.sql as pyspark_sql
21+
import sqlframe
22+
from typing_extensions import TypeGuard
2723
from typing_extensions import TypeIs
2824

2925
from narwhals.dataframe import DataFrame
3026
from narwhals.dataframe import LazyFrame
3127
from narwhals.series import Series
32-
from narwhals.typing import IntoSeries
28+
from narwhals.typing import DataFrameT
29+
from narwhals.typing import FrameT
30+
from narwhals.typing import IntoDataFrameT
31+
from narwhals.typing import IntoSeriesT
3332
from narwhals.typing import _1DArray
3433
from narwhals.typing import _2DArray
3534
from narwhals.typing import _NDArray
@@ -118,7 +117,7 @@ def get_sqlframe() -> Any:
118117
return sys.modules.get("sqlframe", None)
119118

120119

121-
def is_pandas_dataframe(df: Any) -> TypeGuard[pd.DataFrame]:
120+
def is_pandas_dataframe(df: Any) -> TypeIs[pd.DataFrame]:
122121
"""Check whether `df` is a pandas DataFrame without importing pandas."""
123122
return ((pd := get_pandas()) is not None and isinstance(df, pd.DataFrame)) or any(
124123
(mod := sys.modules.get(module_name, None)) is not None
@@ -127,7 +126,7 @@ def is_pandas_dataframe(df: Any) -> TypeGuard[pd.DataFrame]:
127126
)
128127

129128

130-
def is_pandas_series(ser: Any) -> TypeGuard[pd.Series[Any]]:
129+
def is_pandas_series(ser: Any) -> TypeIs[pd.Series[Any]]:
131130
"""Check whether `ser` is a pandas Series without importing pandas."""
132131
return ((pd := get_pandas()) is not None and isinstance(ser, pd.Series)) or any(
133132
(mod := sys.modules.get(module_name, None)) is not None
@@ -136,7 +135,7 @@ def is_pandas_series(ser: Any) -> TypeGuard[pd.Series[Any]]:
136135
)
137136

138137

139-
def is_pandas_index(index: Any) -> TypeGuard[pd.Index]:
138+
def is_pandas_index(index: Any) -> TypeIs[pd.Index]:
140139
"""Check whether `index` is a pandas Index without importing pandas."""
141140
return ((pd := get_pandas()) is not None and isinstance(index, pd.Index)) or any(
142141
(mod := sys.modules.get(module_name, None)) is not None
@@ -145,91 +144,91 @@ def is_pandas_index(index: Any) -> TypeGuard[pd.Index]:
145144
)
146145

147146

148-
def is_modin_dataframe(df: Any) -> TypeGuard[mpd.DataFrame]:
147+
def is_modin_dataframe(df: Any) -> TypeIs[mpd.DataFrame]:
149148
"""Check whether `df` is a modin DataFrame without importing modin."""
150149
return (mpd := get_modin()) is not None and isinstance(df, mpd.DataFrame)
151150

152151

153-
def is_modin_series(ser: Any) -> TypeGuard[mpd.Series]:
152+
def is_modin_series(ser: Any) -> TypeIs[mpd.Series]:
154153
"""Check whether `ser` is a modin Series without importing modin."""
155154
return (mpd := get_modin()) is not None and isinstance(ser, mpd.Series)
156155

157156

158-
def is_modin_index(index: Any) -> TypeGuard[mpd.Index]:
157+
def is_modin_index(index: Any) -> TypeIs[mpd.Index]:
159158
"""Check whether `index` is a modin Index without importing modin."""
160159
return (mpd := get_modin()) is not None and isinstance(
161160
index, mpd.Index
162161
) # pragma: no cover
163162

164163

165-
def is_cudf_dataframe(df: Any) -> TypeGuard[cudf.DataFrame]:
164+
def is_cudf_dataframe(df: Any) -> TypeIs[cudf.DataFrame]:
166165
"""Check whether `df` is a cudf DataFrame without importing cudf."""
167166
return (cudf := get_cudf()) is not None and isinstance(df, cudf.DataFrame)
168167

169168

170-
def is_cudf_series(ser: Any) -> TypeGuard[cudf.Series[Any]]:
169+
def is_cudf_series(ser: Any) -> TypeIs[cudf.Series[Any]]:
171170
"""Check whether `ser` is a cudf Series without importing cudf."""
172171
return (cudf := get_cudf()) is not None and isinstance(ser, cudf.Series)
173172

174173

175-
def is_cudf_index(index: Any) -> TypeGuard[cudf.Index]:
174+
def is_cudf_index(index: Any) -> TypeIs[cudf.Index]:
176175
"""Check whether `index` is a cudf Index without importing cudf."""
177176
return (cudf := get_cudf()) is not None and isinstance(
178177
index, cudf.Index
179178
) # pragma: no cover
180179

181180

182-
def is_dask_dataframe(df: Any) -> TypeGuard[dd.DataFrame]:
181+
def is_dask_dataframe(df: Any) -> TypeIs[dd.DataFrame]:
183182
"""Check whether `df` is a Dask DataFrame without importing Dask."""
184183
return (dd := get_dask_dataframe()) is not None and isinstance(df, dd.DataFrame)
185184

186185

187-
def is_duckdb_relation(df: Any) -> TypeGuard[duckdb.DuckDBPyRelation]:
186+
def is_duckdb_relation(df: Any) -> TypeIs[duckdb.DuckDBPyRelation]:
188187
"""Check whether `df` is a DuckDB Relation without importing DuckDB."""
189188
return (duckdb := get_duckdb()) is not None and isinstance(
190189
df, duckdb.DuckDBPyRelation
191190
)
192191

193192

194-
def is_ibis_table(df: Any) -> TypeGuard[ibis.Table]:
193+
def is_ibis_table(df: Any) -> TypeIs[ibis.Table]:
195194
"""Check whether `df` is a Ibis Table without importing Ibis."""
196195
return (ibis := get_ibis()) is not None and isinstance(df, ibis.expr.types.Table)
197196

198197

199-
def is_polars_dataframe(df: Any) -> TypeGuard[pl.DataFrame]:
198+
def is_polars_dataframe(df: Any) -> TypeIs[pl.DataFrame]:
200199
"""Check whether `df` is a Polars DataFrame without importing Polars."""
201200
return (pl := get_polars()) is not None and isinstance(df, pl.DataFrame)
202201

203202

204-
def is_polars_lazyframe(df: Any) -> TypeGuard[pl.LazyFrame]:
203+
def is_polars_lazyframe(df: Any) -> TypeIs[pl.LazyFrame]:
205204
"""Check whether `df` is a Polars LazyFrame without importing Polars."""
206205
return (pl := get_polars()) is not None and isinstance(df, pl.LazyFrame)
207206

208207

209-
def is_polars_series(ser: Any) -> TypeGuard[pl.Series]:
208+
def is_polars_series(ser: Any) -> TypeIs[pl.Series]:
210209
"""Check whether `ser` is a Polars Series without importing Polars."""
211210
return (pl := get_polars()) is not None and isinstance(ser, pl.Series)
212211

213212

214-
def is_pyarrow_chunked_array(ser: Any) -> TypeGuard[pa.ChunkedArray]:
213+
def is_pyarrow_chunked_array(ser: Any) -> TypeIs[pa.ChunkedArray]:
215214
"""Check whether `ser` is a PyArrow ChunkedArray without importing PyArrow."""
216215
return (pa := get_pyarrow()) is not None and isinstance(ser, pa.ChunkedArray)
217216

218217

219-
def is_pyarrow_table(df: Any) -> TypeGuard[pa.Table]:
218+
def is_pyarrow_table(df: Any) -> TypeIs[pa.Table]:
220219
"""Check whether `df` is a PyArrow Table without importing PyArrow."""
221220
return (pa := get_pyarrow()) is not None and isinstance(df, pa.Table)
222221

223222

224-
def is_pyspark_dataframe(df: Any) -> TypeGuard[pyspark_sql.DataFrame]:
223+
def is_pyspark_dataframe(df: Any) -> TypeIs[pyspark_sql.DataFrame]:
225224
"""Check whether `df` is a PySpark DataFrame without importing PySpark."""
226225
return bool(
227226
(pyspark_sql := get_pyspark_sql()) is not None
228227
and isinstance(df, pyspark_sql.DataFrame)
229228
)
230229

231230

232-
def is_sqlframe_dataframe(df: Any) -> TypeGuard[sqlframe.base.dataframe.BaseDataFrame]:
231+
def is_sqlframe_dataframe(df: Any) -> TypeIs[sqlframe.base.dataframe.BaseDataFrame]:
233232
"""Check whether `df` is a SQLFrame DataFrame without importing SQLFrame."""
234233
return bool(
235234
(sqlframe := get_sqlframe()) is not None
@@ -254,6 +253,9 @@ def is_numpy_array_2d(arr: Any) -> TypeIs[_2DArray]:
254253

255254
def is_numpy_scalar(scalar: Any) -> TypeGuard[np.generic]:
256255
"""Check whether `scalar` is a NumPy Scalar without importing NumPy."""
256+
# NOTE: Needs to stay as `TypeGuard`
257+
# - Used in `Series.__getitem__`, but not annotated
258+
# - `TypeGuard` is *hiding* that the check introduces an intersection
257259
return (np := get_numpy()) is not None and np.isscalar(scalar)
258260

259261

@@ -283,7 +285,7 @@ def is_pandas_like_index(index: Any) -> bool:
283285
) # pragma: no cover
284286

285287

286-
def is_into_series(native_series: IntoSeries) -> bool:
288+
def is_into_series(native_series: Any | IntoSeriesT) -> TypeIs[IntoSeriesT]:
287289
"""Check whether `native_series` can be converted to a Narwhals Series.
288290
289291
Arguments:
@@ -320,7 +322,7 @@ def is_into_series(native_series: IntoSeries) -> bool:
320322
)
321323

322324

323-
def is_into_dataframe(native_dataframe: Any) -> bool:
325+
def is_into_dataframe(native_dataframe: Any | IntoDataFrameT) -> TypeIs[IntoDataFrameT]:
324326
"""Check whether `native_dataframe` can be converted to a Narwhals DataFrame.
325327
326328
Arguments:
@@ -357,7 +359,9 @@ def is_into_dataframe(native_dataframe: Any) -> bool:
357359
)
358360

359361

360-
def is_narwhals_dataframe(df: Any) -> TypeGuard[DataFrame[Any]]:
362+
def is_narwhals_dataframe(
363+
df: Any | DataFrame[DataFrameT],
364+
) -> TypeIs[DataFrame[DataFrameT]]:
361365
"""Check whether `df` is a Narwhals DataFrame.
362366
363367
This is useful if you expect a user to pass in a Narwhals
@@ -369,7 +373,7 @@ def is_narwhals_dataframe(df: Any) -> TypeGuard[DataFrame[Any]]:
369373
return isinstance(df, DataFrame)
370374

371375

372-
def is_narwhals_lazyframe(lf: Any) -> TypeGuard[LazyFrame[Any]]:
376+
def is_narwhals_lazyframe(lf: Any | LazyFrame[FrameT]) -> TypeIs[LazyFrame[FrameT]]:
373377
"""Check whether `lf` is a Narwhals LazyFrame.
374378
375379
This is useful if you expect a user to pass in a Narwhals
@@ -381,7 +385,7 @@ def is_narwhals_lazyframe(lf: Any) -> TypeGuard[LazyFrame[Any]]:
381385
return isinstance(lf, LazyFrame)
382386

383387

384-
def is_narwhals_series(ser: Any) -> TypeGuard[Series[Any]]:
388+
def is_narwhals_series(ser: Any | Series[IntoSeriesT]) -> TypeIs[Series[IntoSeriesT]]:
385389
"""Check whether `ser` is a Narwhals Series.
386390
387391
This is useful if you expect a user to pass in a Narwhals

0 commit comments

Comments
 (0)