Skip to content

Commit 22ce463

Browse files
authored
refactor: Generic CompliantSelector (#2064)
1 parent 217bd7d commit 22ce463

17 files changed

+494
-961
lines changed

narwhals/_arrow/dataframe.py

+2
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@ def iter_columns(self) -> Iterator[ArrowSeries]:
164164
version=self._version,
165165
)
166166

167+
_iter_columns = iter_columns
168+
167169
def iter_rows(
168170
self: Self, *, named: bool, buffer_size: int
169171
) -> Iterator[tuple[Any, ...]] | Iterator[dict[str, Any]]:

narwhals/_arrow/selectors.py

+24-188
Original file line numberDiff line numberDiff line change
@@ -1,130 +1,45 @@
11
from __future__ import annotations
22

3-
import re
43
from typing import TYPE_CHECKING
5-
from typing import Any
6-
from typing import Callable
7-
from typing import Iterable
8-
from typing import Sequence
94

105
from narwhals._arrow.expr import ArrowExpr
11-
from narwhals.utils import _parse_time_unit_and_time_zone
12-
from narwhals.utils import dtype_matches_time_unit_and_time_zone
13-
from narwhals.utils import import_dtypes_module
6+
from narwhals._selectors import CompliantSelector
7+
from narwhals._selectors import EagerSelectorNamespace
148

159
if TYPE_CHECKING:
16-
from datetime import timezone
17-
1810
from typing_extensions import Self
1911

2012
from narwhals._arrow.dataframe import ArrowDataFrame
2113
from narwhals._arrow.series import ArrowSeries
22-
from narwhals.dtypes import DType
23-
from narwhals.typing import TimeUnit
24-
from narwhals.utils import _LimitedContext
25-
26-
27-
class ArrowSelectorNamespace:
28-
def __init__(self: Self, context: _LimitedContext, /) -> None:
29-
self._backend_version = context._backend_version
30-
self._version = context._version
31-
32-
def by_dtype(self: Self, dtypes: Iterable[DType | type[DType]]) -> ArrowSelector:
33-
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
34-
return [df[col] for col in df.columns if df.schema[col] in dtypes]
35-
36-
def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]:
37-
return [col for col in df.columns if df.schema[col] in dtypes]
38-
39-
return selector(self, func, evaluate_output_names)
40-
41-
def matches(self: Self, pattern: str) -> ArrowSelector:
42-
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
43-
return [df[col] for col in df.columns if re.search(pattern, col)]
44-
45-
def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]:
46-
return [col for col in df.columns if re.search(pattern, col)]
14+
from narwhals._selectors import EvalNames
15+
from narwhals._selectors import EvalSeries
16+
from narwhals.utils import _FullContext
4717

48-
return selector(self, func, evaluate_output_names)
4918

50-
def numeric(self: Self) -> ArrowSelector:
51-
dtypes = import_dtypes_module(self._version)
52-
return self.by_dtype(
53-
[
54-
dtypes.Int128,
55-
dtypes.Int64,
56-
dtypes.Int32,
57-
dtypes.Int16,
58-
dtypes.Int8,
59-
dtypes.UInt128,
60-
dtypes.UInt64,
61-
dtypes.UInt32,
62-
dtypes.UInt16,
63-
dtypes.UInt8,
64-
dtypes.Float64,
65-
dtypes.Float32,
66-
],
67-
)
68-
69-
def categorical(self: Self) -> ArrowSelector:
70-
dtypes = import_dtypes_module(self._version)
71-
return self.by_dtype([dtypes.Categorical])
72-
73-
def string(self: Self) -> ArrowSelector:
74-
dtypes = import_dtypes_module(self._version)
75-
return self.by_dtype([dtypes.String])
76-
77-
def boolean(self: Self) -> ArrowSelector:
78-
dtypes = import_dtypes_module(self._version)
79-
return self.by_dtype([dtypes.Boolean])
80-
81-
def all(self: Self) -> ArrowSelector:
82-
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
83-
return [df[col] for col in df.columns]
84-
85-
return selector(self, func, lambda df: df.columns)
86-
87-
def datetime(
88-
self: Self,
89-
time_unit: TimeUnit | Iterable[TimeUnit] | None,
90-
time_zone: str | timezone | Iterable[str | timezone | None] | None,
19+
class ArrowSelectorNamespace(EagerSelectorNamespace["ArrowDataFrame", "ArrowSeries"]):
20+
def _selector(
21+
self,
22+
call: EvalSeries[ArrowDataFrame, ArrowSeries],
23+
evaluate_output_names: EvalNames[ArrowDataFrame],
24+
/,
9125
) -> ArrowSelector:
92-
dtypes = import_dtypes_module(version=self._version)
93-
time_units, time_zones = _parse_time_unit_and_time_zone(
94-
time_unit=time_unit, time_zone=time_zone
26+
return ArrowSelector(
27+
call,
28+
depth=0,
29+
function_name="selector",
30+
evaluate_output_names=evaluate_output_names,
31+
alias_output_names=None,
32+
backend_version=self._backend_version,
33+
version=self._version,
9534
)
9635

97-
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
98-
return [
99-
df[col]
100-
for col in df.columns
101-
if dtype_matches_time_unit_and_time_zone(
102-
dtype=df.schema[col],
103-
dtypes=dtypes,
104-
time_units=time_units,
105-
time_zones=time_zones,
106-
)
107-
]
108-
109-
def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]:
110-
return [
111-
col
112-
for col in df.columns
113-
if dtype_matches_time_unit_and_time_zone(
114-
dtype=df.schema[col],
115-
dtypes=dtypes,
116-
time_units=time_units,
117-
time_zones=time_zones,
118-
)
119-
]
120-
121-
return selector(self, func, evaluate_output_names)
122-
36+
def __init__(self: Self, context: _FullContext, /) -> None:
37+
self._implementation = context._implementation
38+
self._backend_version = context._backend_version
39+
self._version = context._version
12340

124-
class ArrowSelector(ArrowExpr):
125-
def __repr__(self: Self) -> str: # pragma: no cover
126-
return f"ArrowSelector(depth={self._depth}, function_name={self._function_name})"
12741

42+
class ArrowSelector(CompliantSelector["ArrowDataFrame", "ArrowSeries"], ArrowExpr): # type: ignore[misc]
12843
def _to_expr(self: Self) -> ArrowExpr:
12944
return ArrowExpr(
13045
self._call,
@@ -135,82 +50,3 @@ def _to_expr(self: Self) -> ArrowExpr:
13550
backend_version=self._backend_version,
13651
version=self._version,
13752
)
138-
139-
def __sub__(self: Self, other: Self | Any) -> ArrowSelector | Any:
140-
if isinstance(other, ArrowSelector):
141-
142-
def call(df: ArrowDataFrame) -> list[ArrowSeries]:
143-
lhs_names = self._evaluate_output_names(df)
144-
rhs_names = other._evaluate_output_names(df)
145-
lhs = self._call(df)
146-
return [x for x, name in zip(lhs, lhs_names) if name not in rhs_names]
147-
148-
def evaluate_output_names(df: ArrowDataFrame) -> list[str]:
149-
lhs_names = self._evaluate_output_names(df)
150-
rhs_names = other._evaluate_output_names(df)
151-
return [x for x in lhs_names if x not in rhs_names]
152-
153-
return selector(self, call, evaluate_output_names)
154-
else:
155-
return self._to_expr() - other
156-
157-
def __or__(self: Self, other: Self | Any) -> ArrowSelector | Any:
158-
if isinstance(other, ArrowSelector):
159-
160-
def call(df: ArrowDataFrame) -> list[ArrowSeries]:
161-
lhs_names = self._evaluate_output_names(df)
162-
rhs_names = other._evaluate_output_names(df)
163-
lhs = self._call(df)
164-
rhs = other._call(df)
165-
return [
166-
*(x for x, name in zip(lhs, lhs_names) if name not in rhs_names),
167-
*rhs,
168-
]
169-
170-
def evaluate_output_names(df: ArrowDataFrame) -> list[str]:
171-
lhs_names = self._evaluate_output_names(df)
172-
rhs_names = other._evaluate_output_names(df)
173-
return [*(x for x in lhs_names if x not in rhs_names), *rhs_names]
174-
175-
return selector(self, call, evaluate_output_names)
176-
else:
177-
return self._to_expr() | other
178-
179-
def __and__(self: Self, other: Self | Any) -> ArrowSelector | Any:
180-
if isinstance(other, ArrowSelector):
181-
182-
def call(df: ArrowDataFrame) -> list[ArrowSeries]:
183-
lhs_names = self._evaluate_output_names(df)
184-
rhs_names = other._evaluate_output_names(df)
185-
lhs = self._call(df)
186-
return [x for x, name in zip(lhs, lhs_names) if name in rhs_names]
187-
188-
def evaluate_output_names(df: ArrowDataFrame) -> list[str]:
189-
lhs_names = self._evaluate_output_names(df)
190-
rhs_names = other._evaluate_output_names(df)
191-
return [x for x in lhs_names if x in rhs_names]
192-
193-
return selector(self, call, evaluate_output_names)
194-
195-
else:
196-
return self._to_expr() & other
197-
198-
def __invert__(self: Self) -> ArrowSelector:
199-
return ArrowSelectorNamespace(self).all() - self
200-
201-
202-
def selector(
203-
context: _LimitedContext,
204-
call: Callable[[ArrowDataFrame], Sequence[ArrowSeries]],
205-
evaluate_output_names: Callable[[ArrowDataFrame], Sequence[str]],
206-
/,
207-
) -> ArrowSelector:
208-
return ArrowSelector(
209-
call,
210-
depth=0,
211-
function_name="selector",
212-
evaluate_output_names=evaluate_output_names,
213-
alias_output_names=None,
214-
backend_version=context._backend_version,
215-
version=context._version,
216-
)

narwhals/_dask/dataframe.py

+6
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from typing import TYPE_CHECKING
44
from typing import Any
5+
from typing import Iterator
56
from typing import Literal
67
from typing import Sequence
78

@@ -24,6 +25,7 @@
2425
if TYPE_CHECKING:
2526
from types import ModuleType
2627

28+
import dask.dataframe.dask_expr as dx
2729
from typing_extensions import Self
2830

2931
from narwhals._dask.expr import DaskExpr
@@ -79,6 +81,10 @@ def _from_native_frame(self: Self, df: Any) -> Self:
7981
version=self._version,
8082
)
8183

84+
def _iter_columns(self) -> Iterator[dx.Series]:
85+
for _col, ser in self._native_frame.items(): # noqa: PERF102
86+
yield ser
87+
8288
def with_columns(self: Self, *exprs: DaskExpr) -> Self:
8389
df = self._native_frame
8490
new_series = evaluate_exprs(self, *exprs)

narwhals/_dask/namespace.py

+3
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from narwhals._expression_parsing import combine_alias_output_names
2525
from narwhals._expression_parsing import combine_evaluate_output_names
2626
from narwhals.typing import CompliantNamespace
27+
from narwhals.utils import Implementation
2728
from narwhals.utils import get_column_names
2829

2930
if TYPE_CHECKING:
@@ -39,6 +40,8 @@
3940

4041

4142
class DaskNamespace(CompliantNamespace[DaskLazyFrame, "dx.Series"]): # pyright: ignore[reportInvalidTypeArguments] (#2044)
43+
_implementation: Implementation = Implementation.DASK
44+
4245
@property
4346
def selectors(self: Self) -> DaskSelectorNamespace:
4447
return DaskSelectorNamespace(self)

0 commit comments

Comments
 (0)