Skip to content

Commit 2b58ee7

Browse files
fix(typing): Resolve all mypy & pyright errors for _arrow (#2007)
Co-authored-by: Marco Gorelli <[email protected]>
1 parent ddf78c3 commit 2b58ee7

24 files changed

+606
-404
lines changed

narwhals/_arrow/dataframe.py

+22-23
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from functools import partial
34
from typing import TYPE_CHECKING
45
from typing import Any
56
from typing import Iterator
@@ -36,17 +37,17 @@
3637

3738
import pandas as pd
3839
import polars as pl
39-
from pyarrow._stubs_typing import ( # pyright: ignore[reportMissingModuleSource]
40-
Indices,
41-
)
42-
from pyarrow._stubs_typing import Order # pyright: ignore[reportMissingModuleSource]
4340
from typing_extensions import Self
4441
from typing_extensions import TypeAlias
4542

4643
from narwhals._arrow.expr import ArrowExpr
4744
from narwhals._arrow.group_by import ArrowGroupBy
4845
from narwhals._arrow.namespace import ArrowNamespace
4946
from narwhals._arrow.series import ArrowSeries
47+
from narwhals._arrow.typing import ArrowChunkedArray
48+
from narwhals._arrow.typing import Indices
49+
from narwhals._arrow.typing import Mask
50+
from narwhals._arrow.typing import Order
5051
from narwhals.dtypes import DType
5152
from narwhals.typing import SizeUnit
5253
from narwhals.typing import _1DArray
@@ -133,7 +134,7 @@ def __len__(self: Self) -> int:
133134
return len(self._native_frame)
134135

135136
def row(self: Self, index: int) -> tuple[Any, ...]:
136-
return tuple(col[index] for col in self._native_frame)
137+
return tuple(col[index] for col in self._native_frame.itercolumns())
137138

138139
@overload
139140
def rows(self: Self, *, named: Literal[True]) -> list[dict[str, Any]]: ...
@@ -371,7 +372,9 @@ def with_columns(self: Self, *exprs: ArrowExpr) -> Self:
371372

372373
native_frame = (
373374
native_frame.set_column(
374-
columns.index(col_name), field_=col_name, column=column
375+
columns.index(col_name),
376+
field_=col_name,
377+
column=column, # type: ignore[arg-type]
375378
)
376379
if col_name in columns
377380
else native_frame.append_column(field_=col_name, column=column)
@@ -532,17 +535,18 @@ def with_row_index(self: Self, name: str) -> Self:
532535
df.append_column(name, row_indices).select([name, *cols])
533536
)
534537

535-
def filter(self: Self, predicate: ArrowExpr | list[bool]) -> Self:
538+
def filter(self: Self, predicate: ArrowExpr | list[bool | None]) -> Self:
536539
if isinstance(predicate, list):
537-
mask_native = predicate
540+
mask_native: Mask | ArrowChunkedArray = predicate
538541
else:
539542
# `[0]` is safe as the predicate's expression only returns a single column
540543
mask = evaluate_into_exprs(self, predicate)[0]
541544
mask_native = broadcast_and_extract_dataframe_comparand(
542545
length=len(self), other=mask, backend_version=self._backend_version
543546
)
544547
return self._from_native_frame(
545-
self._native_frame.filter(mask_native), validate_column_names=False
548+
self._native_frame.filter(mask_native), # pyright: ignore[reportArgumentType]
549+
validate_column_names=False,
546550
)
547551

548552
def head(self: Self, n: int) -> Self:
@@ -745,17 +749,14 @@ def unique(
745749

746750
agg_func = agg_func_map[keep]
747751
col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
748-
keep_idx = (
752+
keep_idx_native = (
749753
df.append_column(col_token, pa.array(np.arange(len(self))))
750754
.group_by(subset)
751755
.aggregate([(col_token, agg_func)])
752756
.column(f"{col_token}_{agg_func}")
753757
)
754-
755-
return self._from_native_frame(
756-
pc.take(df, keep_idx), # type: ignore[call-overload, unused-ignore]
757-
validate_column_names=False,
758-
)
758+
indices = cast("Indices", keep_idx_native)
759+
return self._from_native_frame(df.take(indices), validate_column_names=False)
759760

760761
keep_idx = self.simple_select(*subset).is_unique()
761762
plx = self.__narwhals_namespace__()
@@ -804,30 +805,28 @@ def unpivot(
804805
on_: list[str] = (
805806
[c for c in self.columns if c not in index_] if on is None else on
806807
)
807-
808-
promote_kwargs: dict[Literal["promote_options"], PromoteOptions] = (
809-
{"promote_options": "permissive"}
808+
concat = (
809+
partial(pa.concat_tables, promote_options="permissive")
810810
if self._backend_version >= (14, 0, 0)
811-
else {}
811+
else pa.concat_tables
812812
)
813813
names = [*index_, variable_name, value_name]
814814
return self._from_native_frame(
815-
pa.concat_tables(
815+
concat(
816816
[
817817
pa.Table.from_arrays(
818818
[
819819
*(native_frame.column(idx_col) for idx_col in index_),
820820
cast(
821-
"pa.ChunkedArray",
821+
"ArrowChunkedArray",
822822
pa.array([on_col] * n_rows, pa.string()),
823823
),
824824
native_frame.column(on_col),
825825
],
826826
names=names,
827827
)
828828
for on_col in on_
829-
],
830-
**promote_kwargs,
829+
]
831830
)
832831
)
833832
# TODO(Unassigned): Even with promote_options="permissive", pyarrow does not

narwhals/_arrow/expr.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __init__(
4848
self._depth = depth
4949
self._function_name = function_name
5050
self._depth = depth
51-
self._evaluate_output_names = evaluate_output_names
51+
self._evaluate_output_names = evaluate_output_names # pyright: ignore[reportAttributeAccessIssue]
5252
self._alias_output_names = alias_output_names
5353
self._backend_version = backend_version
5454
self._version = version

narwhals/_arrow/group_by.py

+15-8
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import TYPE_CHECKING
66
from typing import Any
77
from typing import Iterator
8+
from typing import cast
89

910
import pyarrow as pa
1011
import pyarrow.compute as pc
@@ -18,6 +19,7 @@
1819

1920
from narwhals._arrow.dataframe import ArrowDataFrame
2021
from narwhals._arrow.expr import ArrowExpr
22+
from narwhals._arrow.typing import Incomplete
2123

2224
POLARS_TO_ARROW_AGGREGATIONS = {
2325
"sum": "sum",
@@ -68,7 +70,7 @@ def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame:
6870
)
6971
raise ValueError(msg)
7072

71-
aggs: list[tuple[str, str, pc.FunctionOptions | None]] = []
73+
aggs: list[tuple[str, str, Any]] = []
7274
expected_pyarrow_column_names: list[str] = self._keys.copy()
7375
new_column_names: list[str] = self._keys.copy()
7476

@@ -91,7 +93,7 @@ def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame:
9193

9294
function_name = re.sub(r"(\w+->)", "", expr._function_name)
9395
if function_name in {"std", "var"}:
94-
option = pc.VarianceOptions(ddof=expr._kwargs["ddof"])
96+
option: Any = pc.VarianceOptions(ddof=expr._kwargs["ddof"])
9597
elif function_name in {"len", "n_unique"}:
9698
option = pc.CountOptions(mode="all")
9799
elif function_name == "count":
@@ -139,14 +141,19 @@ def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame:
139141

140142
def __iter__(self: Self) -> Iterator[tuple[Any, ArrowDataFrame]]:
141143
col_token = generate_temporary_column_name(n_bytes=8, columns=self._df.columns)
142-
null_token = "__null_token_value__" # noqa: S105
144+
null_token: str = "__null_token_value__" # noqa: S105
143145

144146
table = self._df._native_frame
145-
key_values = pc.binary_join_element_wise(
146-
*[pc.cast(table[key], pa.string()) for key in self._keys],
147-
"",
148-
null_handling="replace",
149-
null_replacement=null_token,
147+
# NOTE: stubs fail in multiple places for `ChunkedArray`
148+
it = cast(
149+
"Iterator[pa.StringArray]",
150+
(table[key].cast(pa.string()) for key in self._keys),
151+
)
152+
# NOTE: stubs indicate `separator` must also be a `ChunkedArray`
153+
# Reality: `str` is fine
154+
concat_str: Incomplete = pc.binary_join_element_wise
155+
key_values = concat_str(
156+
*it, "", null_handling="replace", null_replacement=null_token
150157
)
151158
table = table.add_column(i=0, field_=col_token, column=key_values)
152159

narwhals/_arrow/namespace.py

+30-23
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from narwhals._arrow.utils import broadcast_series
2020
from narwhals._arrow.utils import diagonal_concat
2121
from narwhals._arrow.utils import horizontal_concat
22+
from narwhals._arrow.utils import nulls_like
2223
from narwhals._arrow.utils import vertical_concat
2324
from narwhals._expression_parsing import combine_alias_output_names
2425
from narwhals._expression_parsing import combine_evaluate_output_names
@@ -31,6 +32,7 @@
3132

3233
from typing_extensions import Self
3334

35+
from narwhals._arrow.typing import Incomplete
3436
from narwhals._arrow.typing import IntoArrowExpr
3537
from narwhals.dtypes import DType
3638
from narwhals.utils import Version
@@ -254,13 +256,16 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
254256
def min_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr:
255257
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
256258
init_series, *series = [s for _expr in exprs for s in _expr(df)]
259+
# NOTE: Stubs copy the wrong signature https://github.com/zen-xu/pyarrow-stubs/blob/d97063876720e6a5edda7eb15f4efe07c31b8296/pyarrow-stubs/compute.pyi#L963
260+
min_element_wise: Incomplete = pc.min_element_wise
261+
native_series = reduce(
262+
min_element_wise,
263+
[s._native_series for s in series],
264+
init_series._native_series,
265+
)
257266
return [
258267
ArrowSeries(
259-
native_series=reduce(
260-
pc.min_element_wise,
261-
[s._native_series for s in series],
262-
init_series._native_series,
263-
),
268+
native_series,
264269
name=init_series.name,
265270
backend_version=self._backend_version,
266271
version=self._version,
@@ -279,13 +284,17 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
279284
def max_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr:
280285
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
281286
init_series, *series = [s for _expr in exprs for s in _expr(df)]
287+
# NOTE: stubs are missing `ChunkedArray` support
288+
# https://github.com/zen-xu/pyarrow-stubs/blob/d97063876720e6a5edda7eb15f4efe07c31b8296/pyarrow-stubs/compute.pyi#L948-L954
289+
max_element_wise: Incomplete = pc.max_element_wise
290+
native_series = reduce(
291+
max_element_wise,
292+
[s._native_series for s in series],
293+
init_series._native_series,
294+
)
282295
return [
283296
ArrowSeries(
284-
native_series=reduce(
285-
pc.max_element_wise,
286-
[s._native_series for s in series],
287-
init_series._native_series,
288-
),
297+
native_series,
289298
name=init_series.name,
290299
backend_version=self._backend_version,
291300
version=self._version,
@@ -347,18 +356,19 @@ def concat_str(
347356
dtypes = import_dtypes_module(self._version)
348357

349358
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
350-
compliant_series_list = [
359+
compliant_series_list: list[ArrowSeries] = [
351360
s for _expr in exprs for s in _expr.cast(dtypes.String())(df)
352361
]
353-
null_handling = "skip" if ignore_nulls else "emit_null"
354-
result_series = pc.binary_join_element_wise(
355-
*(s._native_series for s in compliant_series_list),
356-
separator,
357-
null_handling=null_handling,
362+
null_handling: Literal["skip", "emit_null"] = (
363+
"skip" if ignore_nulls else "emit_null"
358364
)
365+
it = (s._native_series for s in compliant_series_list)
366+
# NOTE: stubs indicate `separator` must also be a `ChunkedArray`
367+
# Reality: `str` is fine
368+
concat_str: Incomplete = pc.binary_join_element_wise
359369
return [
360370
ArrowSeries(
361-
native_series=result_series,
371+
native_series=concat_str(*it, separator, null_handling=null_handling),
362372
name=compliant_series_list[0].name,
363373
backend_version=self._backend_version,
364374
version=self._version,
@@ -410,14 +420,11 @@ def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]:
410420
condition_native, value_series_native = broadcast_series(
411421
[condition, value_series]
412422
)
413-
414423
if self._otherwise_value is None:
415-
otherwise_native = pa.repeat(
416-
pa.scalar(None, type=value_series_native.type), len(condition_native)
417-
)
424+
otherwise_null = nulls_like(len(condition_native), value_series)
418425
return [
419426
value_series._from_native_series(
420-
pc.if_else(condition_native, value_series_native, otherwise_native)
427+
pc.if_else(condition_native, value_series_native, otherwise_null)
421428
)
422429
]
423430
if isinstance(self._otherwise_value, ArrowExpr):
@@ -474,7 +481,7 @@ def __init__(
474481
self._call = call
475482
self._depth = depth
476483
self._function_name = function_name
477-
self._evaluate_output_names = evaluate_output_names
484+
self._evaluate_output_names = evaluate_output_names # pyright: ignore[reportAttributeAccessIssue]
478485
self._alias_output_names = alias_output_names
479486
self._kwargs = kwargs
480487

0 commit comments

Comments
 (0)