1
1
from __future__ import annotations
2
2
3
- import re
4
3
from typing import TYPE_CHECKING
5
- from typing import Any
6
- from typing import Callable
7
- from typing import Iterable
8
- from typing import Sequence
9
4
10
5
from narwhals ._arrow .expr import ArrowExpr
11
- from narwhals .utils import _parse_time_unit_and_time_zone
12
- from narwhals .utils import dtype_matches_time_unit_and_time_zone
13
- from narwhals .utils import import_dtypes_module
6
+ from narwhals ._selectors import CompliantSelector
7
+ from narwhals ._selectors import EagerSelectorNamespace
14
8
15
9
if TYPE_CHECKING :
16
- from datetime import timezone
17
-
18
10
from typing_extensions import Self
19
11
20
12
from narwhals ._arrow .dataframe import ArrowDataFrame
21
13
from narwhals ._arrow .series import ArrowSeries
22
- from narwhals .dtypes import DType
23
- from narwhals .typing import TimeUnit
24
- from narwhals .utils import _LimitedContext
25
-
26
-
27
- class ArrowSelectorNamespace :
28
def __init__(self: Self, context: _LimitedContext, /) -> None:
    """Copy the backend version and API version from ``context``.

    Args:
        context: Object exposing ``_backend_version`` and ``_version``.
    """
    self._backend_version, self._version = (
        context._backend_version,
        context._version,
    )
31
-
32
def by_dtype(self: Self, dtypes: Iterable[DType | type[DType]]) -> ArrowSelector:
    """Select the columns whose schema dtype is contained in ``dtypes``.

    Args:
        dtypes: Dtypes to match. Any iterable is accepted, including a
            one-shot iterator such as a generator.

    Returns:
        A selector yielding the matching columns, in ``df.columns`` order.
    """
    # ``dtypes`` is only promised to be an Iterable. Testing membership
    # directly against a one-shot iterator would consume it on the first
    # column, so later columns (and later evaluations) would never match.
    # A tuple keeps the same ``==``-based ``in`` semantics as a list would.
    wanted = tuple(dtypes)

    def matching_names(df: ArrowDataFrame) -> Sequence[str]:
        # Look the schema up once per evaluation rather than once per column.
        schema = df.schema
        return [col for col in df.columns if schema[col] in wanted]

    def func(df: ArrowDataFrame) -> list[ArrowSeries]:
        return [df[col] for col in matching_names(df)]

    return selector(self, func, matching_names)
40
-
41
- def matches (self : Self , pattern : str ) -> ArrowSelector :
42
- def func (df : ArrowDataFrame ) -> list [ArrowSeries ]:
43
- return [df [col ] for col in df .columns if re .search (pattern , col )]
44
-
45
- def evaluate_output_names (df : ArrowDataFrame ) -> Sequence [str ]:
46
- return [col for col in df .columns if re .search (pattern , col )]
14
+ from narwhals ._selectors import EvalNames
15
+ from narwhals ._selectors import EvalSeries
16
+ from narwhals .utils import _FullContext
47
17
48
- return selector (self , func , evaluate_output_names )
49
18
50
def numeric(self: Self) -> ArrowSelector:
    """Select columns whose dtype is one of the integer or float dtypes below."""
    dtypes = import_dtypes_module(self._version)
    # Attribute names are resolved on the version-specific dtypes module,
    # in the same order as they are listed here.
    numeric_names = (
        "Int128",
        "Int64",
        "Int32",
        "Int16",
        "Int8",
        "UInt128",
        "UInt64",
        "UInt32",
        "UInt16",
        "UInt8",
        "Float64",
        "Float32",
    )
    return self.by_dtype([getattr(dtypes, name) for name in numeric_names])
68
-
69
def categorical(self: Self) -> ArrowSelector:
    """Select columns whose dtype is ``Categorical``."""
    categorical_dtype = import_dtypes_module(self._version).Categorical
    return self.by_dtype([categorical_dtype])
72
-
73
def string(self: Self) -> ArrowSelector:
    """Select columns whose dtype is ``String``."""
    string_dtype = import_dtypes_module(self._version).String
    return self.by_dtype([string_dtype])
76
-
77
def boolean(self: Self) -> ArrowSelector:
    """Select columns whose dtype is ``Boolean``."""
    boolean_dtype = import_dtypes_module(self._version).Boolean
    return self.by_dtype([boolean_dtype])
80
-
81
def all(self: Self) -> ArrowSelector:
    """Select every column of the frame, in ``df.columns`` order."""

    def every_series(df: ArrowDataFrame) -> list[ArrowSeries]:
        return [df[name] for name in df.columns]

    def every_name(df: ArrowDataFrame) -> Sequence[str]:
        return df.columns

    return selector(self, every_series, every_name)
86
-
87
- def datetime (
88
- self : Self ,
89
- time_unit : TimeUnit | Iterable [TimeUnit ] | None ,
90
- time_zone : str | timezone | Iterable [str | timezone | None ] | None ,
19
+ class ArrowSelectorNamespace (EagerSelectorNamespace ["ArrowDataFrame" , "ArrowSeries" ]):
20
+ def _selector (
21
+ self ,
22
+ call : EvalSeries [ArrowDataFrame , ArrowSeries ],
23
+ evaluate_output_names : EvalNames [ArrowDataFrame ],
24
+ / ,
91
25
) -> ArrowSelector :
92
- dtypes = import_dtypes_module (version = self ._version )
93
- time_units , time_zones = _parse_time_unit_and_time_zone (
94
- time_unit = time_unit , time_zone = time_zone
26
+ return ArrowSelector (
27
+ call ,
28
+ depth = 0 ,
29
+ function_name = "selector" ,
30
+ evaluate_output_names = evaluate_output_names ,
31
+ alias_output_names = None ,
32
+ backend_version = self ._backend_version ,
33
+ version = self ._version ,
95
34
)
96
35
97
- def func (df : ArrowDataFrame ) -> list [ArrowSeries ]:
98
- return [
99
- df [col ]
100
- for col in df .columns
101
- if dtype_matches_time_unit_and_time_zone (
102
- dtype = df .schema [col ],
103
- dtypes = dtypes ,
104
- time_units = time_units ,
105
- time_zones = time_zones ,
106
- )
107
- ]
108
-
109
- def evaluate_output_names (df : ArrowDataFrame ) -> Sequence [str ]:
110
- return [
111
- col
112
- for col in df .columns
113
- if dtype_matches_time_unit_and_time_zone (
114
- dtype = df .schema [col ],
115
- dtypes = dtypes ,
116
- time_units = time_units ,
117
- time_zones = time_zones ,
118
- )
119
- ]
120
-
121
- return selector (self , func , evaluate_output_names )
122
-
36
def __init__(self: Self, context: _FullContext, /) -> None:
    """Copy implementation, backend version and API version from ``context``.

    Args:
        context: Object exposing ``_implementation``, ``_backend_version``
            and ``_version``.
    """
    self._implementation, self._backend_version, self._version = (
        context._implementation,
        context._backend_version,
        context._version,
    )
123
40
124
- class ArrowSelector (ArrowExpr ):
125
- def __repr__ (self : Self ) -> str : # pragma: no cover
126
- return f"ArrowSelector(depth={ self ._depth } , function_name={ self ._function_name } )"
127
41
42
+ class ArrowSelector (CompliantSelector ["ArrowDataFrame" , "ArrowSeries" ], ArrowExpr ): # type: ignore[misc]
128
43
def _to_expr (self : Self ) -> ArrowExpr :
129
44
return ArrowExpr (
130
45
self ._call ,
@@ -135,82 +50,3 @@ def _to_expr(self: Self) -> ArrowExpr:
135
50
backend_version = self ._backend_version ,
136
51
version = self ._version ,
137
52
)
138
-
139
def __sub__(self: Self, other: Self | Any) -> ArrowSelector | Any:
    """Return the set difference of two selectors, or subtract as an expression.

    When ``other`` is an ``ArrowSelector`` the result selects the columns of
    ``self`` whose names are not selected by ``other``. For any other operand
    the selector is converted with ``_to_expr()`` and ``-`` is applied to it.
    """
    if not isinstance(other, ArrowSelector):
        return self._to_expr() - other

    def difference_series(df: ArrowDataFrame) -> list[ArrowSeries]:
        own_names = self._evaluate_output_names(df)
        excluded = other._evaluate_output_names(df)
        own_series = self._call(df)
        return [
            series
            for series, name in zip(own_series, own_names)
            if name not in excluded
        ]

    def difference_names(df: ArrowDataFrame) -> list[str]:
        own_names = self._evaluate_output_names(df)
        excluded = other._evaluate_output_names(df)
        return [name for name in own_names if name not in excluded]

    return selector(self, difference_series, difference_names)
156
-
157
def __or__(self: Self, other: Self | Any) -> ArrowSelector | Any:
    """Return the union of two selectors, or apply ``|`` as an expression.

    When ``other`` is an ``ArrowSelector`` the result yields the columns of
    ``self`` that ``other`` does not name, followed by every column of
    ``other``. For any other operand the selector is converted with
    ``_to_expr()`` and ``|`` is applied to it.
    """
    if not isinstance(other, ArrowSelector):
        return self._to_expr() | other

    def union_series(df: ArrowDataFrame) -> list[ArrowSeries]:
        own_names = self._evaluate_output_names(df)
        other_names = other._evaluate_output_names(df)
        own_series = self._call(df)
        other_series = other._call(df)
        combined = [
            series
            for series, name in zip(own_series, own_names)
            if name not in other_names
        ]
        combined.extend(other_series)
        return combined

    def union_names(df: ArrowDataFrame) -> list[str]:
        own_names = self._evaluate_output_names(df)
        other_names = other._evaluate_output_names(df)
        combined = [name for name in own_names if name not in other_names]
        combined.extend(other_names)
        return combined

    return selector(self, union_series, union_names)
178
-
179
def __and__(self: Self, other: Self | Any) -> ArrowSelector | Any:
    """Return the intersection of two selectors, or apply ``&`` as an expression.

    When ``other`` is an ``ArrowSelector`` the result selects the columns of
    ``self`` whose names are also selected by ``other``. For any other operand
    the selector is converted with ``_to_expr()`` and ``&`` is applied to it.
    """
    if not isinstance(other, ArrowSelector):
        return self._to_expr() & other

    def shared_series(df: ArrowDataFrame) -> list[ArrowSeries]:
        own_names = self._evaluate_output_names(df)
        required = other._evaluate_output_names(df)
        own_series = self._call(df)
        return [
            series
            for series, name in zip(own_series, own_names)
            if name in required
        ]

    def shared_names(df: ArrowDataFrame) -> list[str]:
        own_names = self._evaluate_output_names(df)
        required = other._evaluate_output_names(df)
        return [name for name in own_names if name in required]

    return selector(self, shared_series, shared_names)
197
-
198
def __invert__(self: Self) -> ArrowSelector:
    """Return the complement: the all-columns selector minus this selector."""
    everything = ArrowSelectorNamespace(self).all()
    return everything - self
200
-
201
-
202
def selector(
    context: _LimitedContext,
    call: Callable[[ArrowDataFrame], Sequence[ArrowSeries]],
    evaluate_output_names: Callable[[ArrowDataFrame], Sequence[str]],
    /,
) -> ArrowSelector:
    """Build an ``ArrowSelector`` around ``call``.

    Args:
        context: Object supplying ``_backend_version`` and ``_version``.
        call: Callable producing the selected series for a frame.
        evaluate_output_names: Callable producing the selected column names.

    Returns:
        An ``ArrowSelector`` with depth 0, function name ``"selector"`` and
        no output-name aliasing.
    """
    backend_version = context._backend_version
    version = context._version
    return ArrowSelector(
        call,
        depth=0,
        function_name="selector",
        evaluate_output_names=evaluate_output_names,
        alias_output_names=None,
        backend_version=backend_version,
        version=version,
    )
0 commit comments