Skip to content

Commit 6db674a

Browse files
committed
broken typing
1 parent a3b27bf commit 6db674a

File tree

7 files changed

+78
-396
lines changed

7 files changed

+78
-396
lines changed

narwhals/pandas_like/dataframe.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
from typing import Iterable
77
from typing import Literal
88

9+
from narwhals.pandas_like.utils import evaluate_into_exprs
10+
from narwhals.pandas_like.utils import flatten_str
11+
from narwhals.pandas_like.utils import horizontal_concat
12+
from narwhals.pandas_like.utils import validate_dataframe_comparand
913
from narwhals.spec import DataFrame as DataFrameT
1014
from narwhals.spec import IntoExpr
1115
from narwhals.spec import LazyFrame as LazyFrameProtocol
1216
from narwhals.spec import Namespace as NamespaceProtocol
13-
from narwhals.utils import evaluate_into_exprs
14-
from narwhals.utils import flatten_str
15-
from narwhals.utils import horizontal_concat
16-
from narwhals.utils import validate_dataframe_comparand
1717

1818
if TYPE_CHECKING:
1919
from collections.abc import Sequence

narwhals/pandas_like/expr.py

+46-42
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
from __future__ import annotations
22

3+
from typing import TYPE_CHECKING
34
from typing import Any
45
from typing import Callable
56

67
from narwhals.pandas_like.series import Series
8+
from narwhals.pandas_like.utils import register_expression_call
79
from narwhals.spec import DataFrame as DataFrameT
810
from narwhals.spec import Expr as ExprT
911
from narwhals.spec import ExprStringNamespace as ExprStringNamespaceT
1012
from narwhals.spec import LazyFrame as LazyFrameProtocol
1113
from narwhals.spec import Namespace as NamespaceProtocol
1214
from narwhals.spec import Series as SeriesProtocol
13-
from narwhals.utils import register_expression_call
15+
16+
if TYPE_CHECKING:
17+
from typing_extensions import Self
1418

1519

1620
class Expr(ExprT):
@@ -44,8 +48,8 @@ def __repr__(self) -> str:
4448

4549
@classmethod
4650
def from_column_names(
47-
cls: type[Expr], *column_names: str, implementation: str
48-
) -> ExprT:
51+
cls: type[Self], *column_names: str, implementation: str
52+
) -> Self:
4953
return cls(
5054
lambda df: [
5155
Series(
@@ -70,124 +74,124 @@ def __expr_namespace__(self) -> NamespaceProtocol:
7074
implementation=self._implementation, # type: ignore[attr-defined]
7175
)
7276

73-
def __eq__(self, other: Expr | Any) -> ExprT: # type: ignore[override]
77+
def __eq__(self, other: Expr | Any) -> Self: # type: ignore[override]
7478
return register_expression_call(self, "__eq__", other)
7579

76-
def __ne__(self, other: Expr | Any) -> ExprT: # type: ignore[override]
80+
def __ne__(self, other: Expr | Any) -> Self: # type: ignore[override]
7781
return register_expression_call(self, "__ne__", other)
7882

79-
def __ge__(self, other: Expr | Any) -> ExprT:
83+
def __ge__(self, other: Expr | Any) -> Self:
8084
return register_expression_call(self, "__ge__", other)
8185

82-
def __gt__(self, other: Expr | Any) -> ExprT:
86+
def __gt__(self, other: Expr | Any) -> Self:
8387
return register_expression_call(self, "__gt__", other)
8488

85-
def __le__(self, other: Expr | Any) -> ExprT:
89+
def __le__(self, other: Expr | Any) -> Self:
8690
return register_expression_call(self, "__le__", other)
8791

88-
def __lt__(self, other: Expr | Any) -> ExprT:
92+
def __lt__(self, other: Expr | Any) -> Self:
8993
return register_expression_call(self, "__lt__", other)
9094

91-
def __and__(self, other: Expr | bool | Any) -> ExprT:
95+
def __and__(self, other: Expr | bool | Any) -> Self:
9296
return register_expression_call(self, "__and__", other)
9397

94-
def __rand__(self, other: Any) -> ExprT:
98+
def __rand__(self, other: Any) -> Self:
9599
return register_expression_call(self, "__rand__", other)
96100

97-
def __or__(self, other: Expr | bool | Any) -> ExprT:
101+
def __or__(self, other: Expr | bool | Any) -> Self:
98102
return register_expression_call(self, "__or__", other)
99103

100-
def __ror__(self, other: Any) -> ExprT:
104+
def __ror__(self, other: Any) -> Self:
101105
return register_expression_call(self, "__ror__", other)
102106

103-
def __add__(self, other: Expr | Any) -> ExprT: # type: ignore[override]
107+
def __add__(self, other: Expr | Any) -> Self: # type: ignore[override]
104108
return register_expression_call(self, "__add__", other)
105109

106-
def __radd__(self, other: Any) -> ExprT:
110+
def __radd__(self, other: Any) -> Self:
107111
return register_expression_call(self, "__radd__", other)
108112

109-
def __sub__(self, other: Expr | Any) -> ExprT:
113+
def __sub__(self, other: Expr | Any) -> Self:
110114
return register_expression_call(self, "__sub__", other)
111115

112-
def __rsub__(self, other: Any) -> ExprT:
116+
def __rsub__(self, other: Any) -> Self:
113117
return register_expression_call(self, "__rsub__", other)
114118

115-
def __mul__(self, other: Expr | Any) -> ExprT:
119+
def __mul__(self, other: Expr | Any) -> Self:
116120
return register_expression_call(self, "__mul__", other)
117121

118-
def __rmul__(self, other: Any) -> ExprT:
122+
def __rmul__(self, other: Any) -> Self:
119123
return self.__mul__(other)
120124

121-
def __truediv__(self, other: Expr | Any) -> ExprT:
125+
def __truediv__(self, other: Expr | Any) -> Self:
122126
return register_expression_call(self, "__truediv__", other)
123127

124-
def __rtruediv__(self, other: Any) -> ExprT:
128+
def __rtruediv__(self, other: Any) -> Self:
125129
raise NotImplementedError
126130

127-
def __floordiv__(self, other: Expr | Any) -> ExprT:
131+
def __floordiv__(self, other: Expr | Any) -> Self:
128132
return register_expression_call(self, "__floordiv__", other)
129133

130-
def __rfloordiv__(self, other: Any) -> ExprT:
134+
def __rfloordiv__(self, other: Any) -> Self:
131135
raise NotImplementedError
132136

133-
def __pow__(self, other: Expr | Any) -> ExprT:
137+
def __pow__(self, other: Expr | Any) -> Self:
134138
return register_expression_call(self, "__pow__", other)
135139

136-
def __rpow__(self, other: Any) -> ExprT: # pragma: no cover
140+
def __rpow__(self, other: Any) -> Self: # pragma: no cover
137141
raise NotImplementedError
138142

139-
def __mod__(self, other: Expr | Any) -> ExprT:
143+
def __mod__(self, other: Expr | Any) -> Self:
140144
return register_expression_call(self, "__mod__", other)
141145

142-
def __rmod__(self, other: Any) -> ExprT: # pragma: no cover
146+
def __rmod__(self, other: Any) -> Self: # pragma: no cover
143147
raise NotImplementedError
144148

145149
# Unary
146150

147-
def __invert__(self) -> ExprT:
151+
def __invert__(self) -> Self:
148152
return register_expression_call(self, "__invert__")
149153

150154
# Reductions
151155

152-
def sum(self) -> ExprT:
156+
def sum(self) -> Self:
153157
return register_expression_call(self, "sum")
154158

155-
def mean(self) -> ExprT:
159+
def mean(self) -> Self:
156160
return register_expression_call(self, "mean")
157161

158-
def max(self) -> ExprT:
162+
def max(self) -> Self:
159163
return register_expression_call(self, "max")
160164

161-
def min(self) -> ExprT:
165+
def min(self) -> Self:
162166
return register_expression_call(self, "min")
163167

164168
# Other
165169
def is_between(
166170
self, lower_bound: Any, upper_bound: Any, closed: str = "both"
167-
) -> ExprT:
171+
) -> Self:
168172
return register_expression_call(
169173
self, "is_between", lower_bound, upper_bound, closed
170174
)
171175

172-
def is_null(self) -> ExprT:
176+
def is_null(self) -> Self:
173177
return register_expression_call(self, "is_null")
174178

175-
def is_in(self, other: Any) -> ExprT:
179+
def is_in(self, other: Any) -> Self:
176180
return register_expression_call(self, "is_in", other)
177181

178-
def drop_nulls(self) -> ExprT:
182+
def drop_nulls(self) -> Self:
179183
return register_expression_call(self, "drop_nulls")
180184

181-
def n_unique(self) -> ExprT:
185+
def n_unique(self) -> Self:
182186
return register_expression_call(self, "n_unique")
183187

184-
def unique(self) -> ExprT:
188+
def unique(self) -> Self:
185189
return register_expression_call(self, "unique")
186190

187-
def sample(self, n: int, fraction: float, *, with_replacement: bool) -> ExprT:
191+
def sample(self, n: int, fraction: float, *, with_replacement: bool) -> Self:
188192
return register_expression_call(self, "sample", n, fraction, with_replacement)
189193

190-
def alias(self, name: str) -> ExprT:
194+
def alias(self, name: str) -> Self:
191195
# Define this one manually, so that we can
192196
# override `output_names` and not increase depth
193197
if self._depth is None:
@@ -211,7 +215,7 @@ class ExprStringNamespace(ExprStringNamespaceT):
211215
def __init__(self, expr: ExprT) -> None:
212216
self._expr = expr
213217

214-
def ends_with(self, suffix: str) -> ExprT:
218+
def ends_with(self, suffix: str) -> Expr:
215219
# TODO make a register_expression_call for namespaces
216220
return Expr(
217221
lambda df: [
@@ -229,7 +233,7 @@ def ends_with(self, suffix: str) -> ExprT:
229233
implementation=self._expr._implementation, # type: ignore[attr-defined]
230234
)
231235

232-
def strip_chars(self, characters: str = " ") -> ExprT:
236+
def strip_chars(self, characters: str = " ") -> Expr:
233237
return Expr(
234238
lambda df: [
235239
Series(

narwhals/pandas_like/group_by_object.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,17 @@
55
from typing import Iterable
66

77
from narwhals.pandas_like.dataframe import LazyFrame
8+
from narwhals.pandas_like.utils import dataframe_from_dict
9+
from narwhals.pandas_like.utils import evaluate_simple_aggregation
10+
from narwhals.pandas_like.utils import get_namespace
11+
from narwhals.pandas_like.utils import horizontal_concat
12+
from narwhals.pandas_like.utils import is_simple_aggregation
13+
from narwhals.pandas_like.utils import parse_into_exprs
814
from narwhals.spec import DataFrame as DataFrameT
915
from narwhals.spec import GroupBy as GroupByProtocol
1016
from narwhals.spec import IntoExpr
1117
from narwhals.spec import LazyFrame as LazyFrameProtocol
1218
from narwhals.spec import LazyGroupBy as LazyGroupByT
13-
from narwhals.utils import dataframe_from_dict
14-
from narwhals.utils import evaluate_simple_aggregation
15-
from narwhals.utils import get_namespace
16-
from narwhals.utils import horizontal_concat
17-
from narwhals.utils import is_simple_aggregation
18-
from narwhals.utils import parse_into_exprs
1919

2020

2121
class GroupBy(GroupByProtocol):

narwhals/pandas_like/namespace.py

+16-17
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,16 @@
99
from narwhals.pandas_like.dataframe import LazyFrame
1010
from narwhals.pandas_like.expr import Expr
1111
from narwhals.pandas_like.series import Series
12+
from narwhals.pandas_like.utils import flatten_str
13+
from narwhals.pandas_like.utils import horizontal_concat
14+
from narwhals.pandas_like.utils import parse_into_exprs
15+
from narwhals.pandas_like.utils import series_from_iterable
1216
from narwhals.spec import AnyDataFrame
1317
from narwhals.spec import DataFrame as DataFrameT
14-
from narwhals.spec import Expr as ExprT
1518
from narwhals.spec import IntoExpr
1619
from narwhals.spec import LazyFrame as LazyFrameProtocol
1720
from narwhals.spec import Namespace as NamespaceProtocol
1821
from narwhals.spec import Series as SeriesProtocol
19-
from narwhals.utils import flatten_str
20-
from narwhals.utils import horizontal_concat
21-
from narwhals.utils import parse_into_exprs
22-
from narwhals.utils import series_from_iterable
2322

2423

2524
class Namespace(NamespaceProtocol):
@@ -29,13 +28,13 @@ def __init__(self, *, api_version: str, implementation: str) -> None:
2928
self._implementation = implementation
3029

3130
# --- horizontal reductions
32-
def sum_horizontal(self, *exprs: IntoExpr | Iterable[IntoExpr]) -> ExprT:
31+
def sum_horizontal(self, *exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
3332
return reduce(lambda x, y: x + y, parse_into_exprs(self, *exprs))
3433

35-
def all_horizontal(self, *exprs: IntoExpr | Iterable[IntoExpr]) -> ExprT:
34+
def all_horizontal(self, *exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
3635
return reduce(lambda x, y: x & y, parse_into_exprs(self, *exprs))
3736

38-
def any_horizontal(self, *exprs: IntoExpr | Iterable[IntoExpr]) -> ExprT:
37+
def any_horizontal(self, *exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
3938
return reduce(lambda x, y: x | y, parse_into_exprs(self, *exprs))
4039

4140
def concat(self, items: Iterable[AnyDataFrame], *, how: str) -> AnyDataFrame:
@@ -62,32 +61,32 @@ def concat(self, items: Iterable[AnyDataFrame], *, how: str) -> AnyDataFrame:
6261
implementation=self._implementation,
6362
)
6463

65-
def col(self, *column_names: str | Iterable[str]) -> ExprT:
64+
def col(self, *column_names: str | Iterable[str]) -> Expr:
6665
return Expr.from_column_names(
6766
*flatten_str(*column_names), implementation=self._implementation
6867
)
6968

70-
def sum(self, *column_names: str) -> ExprT:
69+
def sum(self, *column_names: str) -> Expr:
7170
return Expr.from_column_names(
7271
*column_names, implementation=self._implementation
7372
).sum()
7473

75-
def mean(self, *column_names: str) -> ExprT:
74+
def mean(self, *column_names: str) -> Expr:
7675
return Expr.from_column_names(
7776
*column_names, implementation=self._implementation
7877
).mean()
7978

80-
def max(self, *column_names: str) -> ExprT:
79+
def max(self, *column_names: str) -> Expr:
8180
return Expr.from_column_names(
8281
*column_names, implementation=self._implementation
8382
).max()
8483

85-
def min(self, *column_names: str) -> ExprT:
84+
def min(self, *column_names: str) -> Expr:
8685
return Expr.from_column_names(
8786
*column_names, implementation=self._implementation
8887
).min()
8988

90-
def len(self) -> ExprT:
89+
def len(self) -> Expr:
9190
return Expr(
9291
lambda df: [
9392
Series(
@@ -116,7 +115,7 @@ def _create_expr_from_callable( # noqa: PLR0913
116115
function_name: str | None,
117116
root_names: list[str] | None,
118117
output_names: list[str] | None,
119-
) -> ExprT:
118+
) -> Expr:
120119
return Expr(
121120
func,
122121
depth=depth,
@@ -140,7 +139,7 @@ def _create_series_from_scalar(
140139
implementation=self._implementation,
141140
)
142141

143-
def _create_expr_from_series(self, series: SeriesProtocol) -> ExprT:
142+
def _create_expr_from_series(self, series: SeriesProtocol) -> Expr:
144143
return Expr(
145144
lambda _df: [series],
146145
depth=0,
@@ -150,7 +149,7 @@ def _create_expr_from_series(self, series: SeriesProtocol) -> ExprT:
150149
implementation=self._implementation,
151150
)
152151

153-
def all(self) -> ExprT:
152+
def all(self) -> Expr:
154153
return Expr(
155154
lambda df: [
156155
Series(

narwhals/pandas_like/series.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55

66
from pandas.api.types import is_extension_array_dtype
77

8+
from narwhals.pandas_like.utils import item
9+
from narwhals.pandas_like.utils import validate_column_comparand
810
from narwhals.spec import Series as SeriesProtocol
9-
from narwhals.utils import item
10-
from narwhals.utils import validate_column_comparand
1111

1212
if TYPE_CHECKING:
13+
from typing_extensions import Self
14+
1315
from narwhals.pandas_like.namespace import Namespace
1416

1517

@@ -296,7 +298,7 @@ def sort(
296298
ser.sort_values(ascending=not descending).rename(self.name)
297299
)
298300

299-
def alias(self, name: str) -> Series:
301+
def alias(self, name: str) -> Self:
300302
ser = self.series
301303
return self._from_series(ser.rename(name, copy=False))
302304

narwhals/spec/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
class Expr(Protocol):
15-
def alias(self, name: str) -> Expr:
15+
def alias(self, name: str) -> Self:
1616
...
1717

1818
def __and__(self, other: Any) -> Expr:

0 commit comments

Comments
 (0)