Skip to content

Commit 94a665c

Browse files
committed
remove include_groups. use P.args and P.kwargs in apply defs
1 parent d13dfe2 commit 94a665c

File tree

2 files changed

+15
-12
lines changed

2 files changed

+15
-12
lines changed

pandas-stubs/core/groupby/generic.pyi

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ class SeriesGroupBy(GroupBy[Series[S2]], Generic[S2, ByT]):
209209

210210
_TT = TypeVar("_TT", bound=Literal[True, False])
211211

212+
# ty ignore needed because of https://github.com/astral-sh/ty/issues/157#issuecomment-3017337945
212213
class DFCallable1(Protocol[P]): # ty: ignore[invalid-argument-type]
213214
def __call__(
214215
self, df: DataFrame, /, *args: P.args, **kwargs: P.kwargs
@@ -227,23 +228,26 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]):
227228
@overload # type: ignore[override]
228229
def apply(
229230
self,
230-
func: DFCallable1,
231-
*args,
232-
**kwargs,
231+
func: DFCallable1[P],
232+
/,
233+
*args: P.args,
234+
**kwargs: P.kwargs,
233235
) -> Series: ...
234236
@overload
235237
def apply(
236238
self,
237-
func: DFCallable2,
238-
*args,
239-
**kwargs,
239+
func: DFCallable2[P],
240+
/,
241+
*args: P.args,
242+
**kwargs: P.kwargs,
240243
) -> DataFrame: ...
241244
@overload
242-
def apply( # pyright: ignore[reportOverlappingOverload]
245+
def apply(
243246
self,
244-
func: DFCallable3,
245-
*args,
246-
**kwargs,
247+
func: DFCallable3[P],
248+
/,
249+
*args: P.args,
250+
**kwargs: P.kwargs,
247251
) -> DataFrame: ...
248252
# error: overload 1 overlaps overload 2 because of different return types
249253
@overload

tests/test_groupby.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def resample_interpolate(x: DataFrame) -> DataFrame:
273273

274274
check(
275275
assert_type(
276-
GB_DF.apply(resample_interpolate, include_groups=False),
276+
GB_DF.apply(resample_interpolate),
277277
DataFrame,
278278
),
279279
DataFrame,
@@ -286,7 +286,6 @@ def resample_interpolate_linear(x: DataFrame) -> DataFrame:
286286
assert_type(
287287
GB_DF.apply(
288288
resample_interpolate_linear,
289-
include_groups=False,
290289
),
291290
DataFrame,
292291
),

0 commit comments

Comments
 (0)