21
21
if TYPE_CHECKING :
22
22
from collections .abc import Sequence
23
23
24
+ from typing_extensions import Self
25
+
24
26
25
27
class DataFrame (DataFrameT ):
26
28
"""dataframe object"""
@@ -136,7 +138,7 @@ def join(
136
138
.collect ()
137
139
)
138
140
139
- def lazy (self ) -> LazyFrameT :
141
+ def lazy (self ) -> LazyFrame :
140
142
return LazyFrame (
141
143
self .dataframe ,
142
144
api_version = self .api_version ,
@@ -217,8 +219,8 @@ def _validate_booleanness(self) -> None:
217
219
msg ,
218
220
)
219
221
220
- def _from_dataframe (self , df : Any ) -> LazyFrameT :
221
- return LazyFrame (
222
+ def _from_dataframe (self , df : Any ) -> Self :
223
+ return self . __class__ (
222
224
df ,
223
225
api_version = self .api_version ,
224
226
implementation = self ._implementation ,
@@ -245,7 +247,7 @@ def select(
245
247
self ,
246
248
* exprs : IntoExpr | Iterable [IntoExpr ],
247
249
** named_exprs : IntoExpr ,
248
- ) -> LazyFrameT :
250
+ ) -> Self :
249
251
new_series = evaluate_into_exprs (self , * exprs , ** named_exprs )
250
252
df = horizontal_concat (
251
253
[series .series for series in new_series ], # type: ignore[attr-defined]
@@ -256,7 +258,7 @@ def select(
256
258
def filter (
257
259
self ,
258
260
* predicates : IntoExpr | Iterable [IntoExpr ],
259
- ) -> LazyFrameT :
261
+ ) -> Self :
260
262
plx = self .__lazyframe_namespace__ ()
261
263
expr = plx .all_horizontal (* predicates )
262
264
# Safety: all_horizontal's expression only returns a single column.
@@ -268,7 +270,7 @@ def with_columns(
268
270
self ,
269
271
* exprs : IntoExpr | Iterable [IntoExpr ],
270
272
** named_exprs : IntoExpr ,
271
- ) -> LazyFrameT :
273
+ ) -> Self :
272
274
new_series = evaluate_into_exprs (self , * exprs , ** named_exprs )
273
275
df = self .dataframe .assign (
274
276
** {
@@ -283,7 +285,7 @@ def sort(
283
285
by : str | Iterable [str ],
284
286
* more_by : str ,
285
287
descending : bool | Iterable [bool ] = False ,
286
- ) -> LazyFrameT :
288
+ ) -> Self :
287
289
flat_keys = flatten_str ([* flatten_str (by ), * more_by ])
288
290
if not flat_keys :
289
291
flat_keys = self .dataframe .columns .tolist ()
@@ -304,7 +306,7 @@ def join(
304
306
how : Literal ["left" , "inner" , "outer" ] = "inner" ,
305
307
left_on : str | list [str ],
306
308
right_on : str | list [str ],
307
- ) -> LazyFrameT :
309
+ ) -> Self :
308
310
if how not in ["inner" ]:
309
311
msg = "Only inner join supported for now, others coming soon"
310
312
raise ValueError (msg )
@@ -337,14 +339,14 @@ def collect(self) -> DataFrameT:
337
339
implementation = self ._implementation ,
338
340
)
339
341
340
- def cache (self ) -> LazyFrameT :
342
+ def cache (self ) -> Self :
341
343
return self
342
344
343
- def head (self , n : int ) -> LazyFrameT :
345
+ def head (self , n : int ) -> Self :
344
346
return self ._from_dataframe (self .dataframe .head (n ))
345
347
346
- def unique (self , subset : list [str ]) -> LazyFrameT :
348
+ def unique (self , subset : list [str ]) -> Self :
347
349
return self ._from_dataframe (self .dataframe .drop_duplicates (subset = subset ))
348
350
349
- def rename (self , mapping : dict [str , str ]) -> LazyFrameT :
351
+ def rename (self , mapping : dict [str , str ]) -> Self :
350
352
return self ._from_dataframe (self .dataframe .rename (columns = mapping ))
0 commit comments