Skip to content

Commit d2d55da

Browse files
MarcoGorelli, dangotbanned, pre-commit-ci[bot]
authored
chore: fix pandas-stubs issues (#2008)
* chore: fix pandas-stubs issues
* fix
* fix
* fx
* remove from mypy overrides
* fix typing tests
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci
* fix(typing): avoid `TypeVar` in tests
* remove unused ignores
* fix(DRAFT): maybe detect `"numpy_nullable"` w/ public api

  Need to see if we can drop the `suppress` on older versions
  #2008 (comment)
* test: resolve most of the new warnings

  #2008 (comment)
* test: resolve remaining warnings

  follows (a4dc352)
* fix(typing): resolve `polars` warning

  https://github.com/narwhals-dev/narwhals/actions/runs/13372996489/job/37345672357?pr=2008

---------

Co-authored-by: dangotbanned <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6c66272 commit d2d55da

32 files changed

+160
-96
lines changed

narwhals/_dask/group_by.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@
2626

2727

2828
def n_unique() -> dd.Aggregation:
29-
def chunk(s: pd.core.groupby.generic.SeriesGroupBy) -> int:
29+
def chunk(s: pd.core.groupby.generic.SeriesGroupBy) -> pd.Series[Any]:
3030
return s.nunique(dropna=False) # type: ignore[no-any-return]
3131

32-
def agg(s0: pd.core.groupby.generic.SeriesGroupBy) -> int:
32+
def agg(s0: pd.core.groupby.generic.SeriesGroupBy) -> pd.Series[Any]:
3333
return s0.sum() # type: ignore[no-any-return]
3434

3535
return dd.Aggregation(

narwhals/_dask/namespace.py

+6-12
Original file line numberDiff line numberDiff line change
@@ -77,18 +77,12 @@ def nth(self: Self, *column_indices: int) -> DaskExpr:
7777

7878
def lit(self: Self, value: Any, dtype: DType | None) -> DaskExpr:
7979
def func(df: DaskLazyFrame) -> list[dx.Series]:
80-
return [
81-
dd.from_pandas(
82-
pd.Series(
83-
[value],
84-
dtype=narwhals_to_native_dtype(dtype, self._version)
85-
if dtype is not None
86-
else None,
87-
name="literal",
88-
),
89-
npartitions=df._native_frame.npartitions,
90-
)
91-
]
80+
if dtype is not None:
81+
native_dtype = narwhals_to_native_dtype(dtype, self._version)
82+
s = pd.Series([value], dtype=native_dtype)
83+
else:
84+
s = pd.Series([value])
85+
return [dd.from_pandas(s, npartitions=df._native_frame.npartitions)]
9286

9387
return DaskExpr(
9488
func,

narwhals/_pandas_like/utils.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Iterable
1010
from typing import Sequence
1111
from typing import TypeVar
12+
from typing import cast
1213

1314
import pandas as pd
1415

@@ -511,7 +512,11 @@ def get_dtype_backend(dtype: Any, implementation: Implementation) -> DTypeBacken
511512
if hasattr(pd, "ArrowDtype") and isinstance(dtype, pd.ArrowDtype):
512513
return "pyarrow"
513514
with suppress(AttributeError):
514-
if isinstance(dtype, pd.core.dtypes.dtypes.BaseMaskedDtype):
515+
sentinel = object()
516+
if (
517+
isinstance(dtype, pd.api.extensions.ExtensionDtype)
518+
and getattr(dtype, "base", sentinel) is None
519+
):
515520
return "numpy_nullable"
516521
return None
517522

@@ -722,8 +727,17 @@ def int_dtype_mapper(dtype: Any) -> str:
722727
def convert_str_slice_to_int_slice(
723728
str_slice: slice, columns: pd.Index
724729
) -> tuple[int | None, int | None, int | None]:
725-
start = columns.get_loc(str_slice.start) if str_slice.start is not None else None
726-
stop = columns.get_loc(str_slice.stop) + 1 if str_slice.stop is not None else None
730+
# We can safely cast to int because we know that `columns` doesn't contain duplicates.
731+
start = (
732+
cast(int, columns.get_loc(str_slice.start))
733+
if str_slice.start is not None
734+
else None
735+
)
736+
stop = (
737+
cast(int, columns.get_loc(str_slice.stop)) + 1
738+
if str_slice.stop is not None
739+
else None
740+
)
727741
step = str_slice.step
728742
return (start, stop, step)
729743

narwhals/_polars/series.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ def hist(
483483
version=self._version,
484484
)
485485
elif (self._backend_version < (1, 15)) and self._native_series.count() < 1:
486-
data_dict: dict[str, list[int | float] | pl.Series | pl.Expr]
486+
data_dict: dict[str, Sequence[Any] | pl.Series]
487487
if bins is not None:
488488
data_dict = {
489489
"breakpoint": bins[1:],

narwhals/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -977,7 +977,7 @@ def is_ordered_categorical(series: Series[Any]) -> bool:
977977
if is_polars_series(native_series):
978978
return native_series.dtype.ordering == "physical" # type: ignore[attr-defined, no-any-return]
979979
if is_pandas_series(native_series):
980-
return native_series.cat.ordered # type: ignore[no-any-return]
980+
return bool(native_series.cat.ordered)
981981
if is_modin_series(native_series): # pragma: no cover
982982
return native_series.cat.ordered # type: ignore[no-any-return]
983983
if is_cudf_series(native_series): # pragma: no cover

pyproject.toml

+1-3
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ tests = [
4747
"typing_extensions",
4848
]
4949
typing = [
50+
"pandas-stubs",
5051
"typing_extensions",
5152
"mypy~=1.15.0",
5253
]
@@ -226,16 +227,13 @@ pretty = true
226227

227228
[[tool.mypy.overrides]]
228229
module = [
229-
# TODO: enable step by step when it makes sense
230-
# e.g. the pandas API is just too inconsistent for type hinting to be useful.
231230
"cudf.*",
232231
"dask.*",
233232
"dask_expr.*",
234233
"duckdb.*",
235234
"ibis.*",
236235
"modin.*",
237236
"numpy.*",
238-
"pandas.*",
239237
"pyspark.*",
240238
"sklearn.*",
241239
"sqlframe.*",

tests/conftest.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def dask_lazy_p2_constructor(obj: dict[str, list[Any]]) -> IntoFrame: # pragma:
132132
return dd.from_dict(obj, npartitions=2) # type: ignore[no-any-return]
133133

134134

135-
def pyarrow_table_constructor(obj: dict[str, list[Any]]) -> IntoDataFrame:
135+
def pyarrow_table_constructor(obj: dict[str, Any]) -> IntoDataFrame:
136136
return pa.table(obj) # type: ignore[no-any-return]
137137

138138

@@ -227,7 +227,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
227227
if x not in GPU_CONSTRUCTORS and x != "modin" # too slow
228228
]
229229
else: # pragma: no cover
230-
selected_constructors = metafunc.config.getoption("constructors").split(",")
230+
selected_constructors = metafunc.config.getoption("constructors").split(",") # pyright: ignore[reportAttributeAccessIssue]
231231

232232
eager_constructors: list[Callable[[Any], IntoDataFrame]] = []
233233
eager_constructors_ids: list[str] = []

tests/dependencies/is_into_dataframe_test.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from typing import TYPE_CHECKING
44
from typing import Any
5+
from typing import Mapping
56

67
import numpy as np
78
import pandas as pd
@@ -16,7 +17,7 @@
1617

1718

1819
class DictDataFrame:
19-
def __init__(self: Self, data: dict[str, list[Any]]) -> None:
20+
def __init__(self: Self, data: Mapping[str, Any]) -> None:
2021
self._data = data
2122

2223
def __len__(self) -> int: # pragma: no cover
@@ -27,7 +28,7 @@ def __narwhals_dataframe__(self) -> Self: # pragma: no cover
2728

2829

2930
def test_is_into_dataframe() -> None:
30-
data = {"a": [1, 2, 3], "b": [4, 5, 6]}
31+
data: dict[str, Any] = {"a": [1, 2, 3], "b": [4, 5, 6]}
3132
assert is_into_dataframe(pa.table(data))
3233
assert is_into_dataframe(pl.DataFrame(data))
3334
assert is_into_dataframe(pd.DataFrame(data))

tests/dtypes_test.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
import narwhals.stable.v1 as nw
1616
from tests.utils import PANDAS_VERSION
1717
from tests.utils import POLARS_VERSION
18+
from tests.utils import PYARROW_VERSION
1819

1920
if TYPE_CHECKING:
21+
from narwhals.typing import IntoSeries
2022
from tests.utils import Constructor
2123

2224

@@ -135,6 +137,8 @@ def test_struct_hashes() -> None:
135137
def test_2d_array(constructor: Constructor, request: pytest.FixtureRequest) -> None:
136138
if any(x in str(constructor) for x in ("dask", "modin", "cudf", "pyspark")):
137139
request.applymarker(pytest.mark.xfail)
140+
if "pyarrow_table" in str(constructor) and PYARROW_VERSION < (14,):
141+
request.applymarker(pytest.mark.xfail)
138142
data = {"a": [[[1, 2], [3, 4], [5, 6]]]}
139143
df = nw.from_native(constructor(data)).with_columns(
140144
a=nw.col("a").cast(nw.Array(nw.Int64(), (3, 2)))
@@ -144,13 +148,15 @@ def test_2d_array(constructor: Constructor, request: pytest.FixtureRequest) -> N
144148

145149

146150
def test_second_time_unit() -> None:
147-
s = pd.Series(np.array([np.datetime64("2020-01-01", "s")]))
151+
s: IntoSeries = pd.Series(np.array([np.datetime64("2020-01-01", "s")]))
148152
result = nw.from_native(s, series_only=True)
149153
if PANDAS_VERSION < (2,): # pragma: no cover
150154
assert result.dtype == nw.Datetime("ns")
151155
else:
152156
assert result.dtype == nw.Datetime("s")
153-
s = pa.chunked_array([pa.array([datetime(2020, 1, 1)], type=pa.timestamp("s"))])
157+
ts_sec = pa.timestamp("s")
158+
dur_sec = pa.duration("s")
159+
s = pa.chunked_array([pa.array([datetime(2020, 1, 1)], type=ts_sec)], type=ts_sec)
154160
result = nw.from_native(s, series_only=True)
155161
assert result.dtype == nw.Datetime("s")
156162
s = pd.Series(np.array([np.timedelta64(1, "s")]))
@@ -159,7 +165,7 @@ def test_second_time_unit() -> None:
159165
assert result.dtype == nw.Duration("ns")
160166
else:
161167
assert result.dtype == nw.Duration("s")
162-
s = pa.chunked_array([pa.array([timedelta(1)], type=pa.duration("s"))])
168+
s = pa.chunked_array([pa.array([timedelta(1)], type=dur_sec)], type=dur_sec)
163169
result = nw.from_native(s, series_only=True)
164170
assert result.dtype == nw.Duration("s")
165171

tests/expr_and_series/arithmetic_test.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,9 @@ def test_floordiv(left: int, right: int) -> None:
166166
# test is a bit manual unfortunately
167167
assume(right != 0)
168168
expected = {"a": [left // right]}
169-
result = nw.from_native(pd.DataFrame({"a": [left]}), eager_only=True).select(
170-
nw.col("a") // right
171-
)
169+
result: nw.DataFrame[Any] = nw.from_native(
170+
pd.DataFrame({"a": [left]}), eager_only=True
171+
).select(nw.col("a") // right)
172172
assert_equal_data(result, expected)
173173
if PANDAS_VERSION < (2, 2): # pragma: no cover
174174
# Bug in old version of pandas
@@ -201,9 +201,9 @@ def test_mod(left: int, right: int) -> None:
201201
# test is a bit manual unfortunately
202202
assume(right != 0)
203203
expected = {"a": [left % right]}
204-
result = nw.from_native(pd.DataFrame({"a": [left]}), eager_only=True).select(
205-
nw.col("a") % right
206-
)
204+
result: nw.DataFrame[Any] = nw.from_native(
205+
pd.DataFrame({"a": [left]}), eager_only=True
206+
).select(nw.col("a") % right)
207207
assert_equal_data(result, expected)
208208
result = nw.from_native(
209209
pd.DataFrame({"a": [left]}).convert_dtypes(), eager_only=True

tests/expr_and_series/nth_test.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from __future__ import annotations
22

3+
from typing import Any
4+
from typing import Mapping
5+
36
import polars as pl
47
import pytest
58

@@ -8,7 +11,7 @@
811
from tests.utils import Constructor
912
from tests.utils import assert_equal_data
1013

11-
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.1, 8.0, 9.0]}
14+
data: Mapping[str, Any] = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.1, 8.0, 9.0]}
1215

1316

1417
@pytest.mark.parametrize(

tests/frame/getitem_test.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22

3+
from typing import TYPE_CHECKING
34
from typing import Any
5+
from typing import cast
46

57
import numpy as np
68
import pandas as pd
@@ -12,7 +14,10 @@
1214
from tests.utils import ConstructorEager
1315
from tests.utils import assert_equal_data
1416

15-
data = {
17+
if TYPE_CHECKING:
18+
from narwhals.typing import _1DArray
19+
20+
data: dict[str, Any] = {
1621
"a": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
1722
"b": [11, 12, 13, 14, 15, 16],
1823
}
@@ -74,7 +79,8 @@ def test_gather(constructor_eager: ConstructorEager) -> None:
7479
"b": [11, 14, 12],
7580
}
7681
assert_equal_data(result, expected)
77-
result = df[np.array([0, 3, 1])]
82+
arr = cast("_1DArray", np.array([0, 3, 1]))
83+
result = df[arr]
7884
assert_equal_data(result, expected)
7985

8086

@@ -96,10 +102,10 @@ def test_gather_rows_cols(constructor_eager: ConstructorEager) -> None:
96102

97103
expected = {"b": [11, 14, 12]}
98104

99-
result = {"b": df[[0, 3, 1], 1]}
105+
result: Any = {"b": df[[0, 3, 1], 1]}
100106
assert_equal_data(result, expected)
101-
102-
result = {"b": df[np.array([0, 3, 1]), "b"]}
107+
arr = cast("_1DArray", np.array([0, 3, 1]))
108+
result = {"b": df[arr, "b"]}
103109
assert_equal_data(result, expected)
104110

105111

tests/frame/interchange_native_namespace_test.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from __future__ import annotations
22

3+
from typing import Any
4+
from typing import Mapping
5+
36
import polars as pl
47
import pytest
58

69
import narwhals.stable.v1 as nw
710

8-
data = {"a": [1, 2, 3], "b": [4.5, 6.7, 8.9], "z": ["x", "y", "w"]}
11+
data: Mapping[str, Any] = {"a": [1, 2, 3], "b": [4.5, 6.7, 8.9], "z": ["x", "y", "w"]}
912

1013

1114
def test_interchange() -> None:

tests/frame/interchange_select_test.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from typing import TYPE_CHECKING
44
from typing import Any
5+
from typing import Mapping
56

67
import polars as pl
78
import pytest
@@ -11,7 +12,7 @@
1112
if TYPE_CHECKING:
1213
from typing_extensions import Self
1314

14-
data = {"a": [1, 2, 3], "b": [4.0, 5.0, 6.1], "z": ["x", "y", "z"]}
15+
data: Mapping[str, Any] = {"a": [1, 2, 3], "b": [4.0, 5.0, 6.1], "z": ["x", "y", "z"]}
1516

1617

1718
class InterchangeDataFrame:

tests/frame/interchange_to_arrow_test.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
from __future__ import annotations
22

3+
from typing import Any
4+
from typing import Mapping
5+
36
import polars as pl
47
import pyarrow as pa
58
import pytest
69

710
import narwhals.stable.v1 as nw
811

9-
data = {"a": [1, 2, 3], "b": [4.0, 5.0, 6.1], "z": ["x", "y", "z"]}
12+
data: Mapping[str, Any] = {"a": [1, 2, 3], "b": [4.0, 5.0, 6.1], "z": ["x", "y", "z"]}
1013

1114

1215
def test_interchange_to_arrow() -> None:

tests/frame/invalid_test.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from typing import TYPE_CHECKING
4+
35
import pandas as pd
46
import polars as pl
57
import pyarrow as pa
@@ -8,10 +10,13 @@
810
import narwhals.stable.v1 as nw
911
from tests.utils import NUMPY_VERSION
1012

13+
if TYPE_CHECKING:
14+
from narwhals.typing import Frame
15+
1116

1217
def test_invalid() -> None:
1318
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]}
14-
df = nw.from_native(pa.table({"a": [1, 2], "b": [3, 4]}))
19+
df: Frame = nw.from_native(pa.table({"a": [1, 2], "b": [3, 4]}))
1520
with pytest.raises(ValueError, match="Multi-output"):
1621
df.select(nw.all() + nw.all())
1722
df = nw.from_native(pd.DataFrame(data))
@@ -24,14 +29,14 @@ def test_invalid() -> None:
2429

2530

2631
def test_native_vs_non_native() -> None:
27-
s = pd.Series([1, 2, 3])
28-
df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
32+
s_pd = pd.Series([1, 2, 3])
33+
df_pd = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
2934
with pytest.raises(TypeError, match="Perhaps you forgot"):
30-
nw.from_native(df).filter(s > 1)
31-
s = pl.Series([1, 2, 3])
32-
df = pl.DataFrame({"a": [2, 2, 3], "b": [4, 5, 6]})
35+
nw.from_native(df_pd).filter(s_pd > 1) # type: ignore[arg-type]
36+
s_pl = pl.Series([1, 2, 3])
37+
df_pl = pl.DataFrame({"a": [2, 2, 3], "b": [4, 5, 6]})
3338
with pytest.raises(TypeError, match="Perhaps you\n- forgot"):
34-
nw.from_native(df).filter(s > 1)
39+
nw.from_native(df_pl).filter(s_pl > 1)
3540

3641

3742
def test_validate_laziness() -> None:

0 commit comments

Comments
 (0)