Skip to content

Commit 3bba101

Browse files
GH456 Attempt GroupBy.aggregate improved typing
1 parent 106a6f5 commit 3bba101

File tree

3 files changed

+56
-41
lines changed

3 files changed

+56
-41
lines changed

pandas-stubs/_typing.pyi

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -925,44 +925,6 @@ GroupByObjectNonScalar: TypeAlias = (
925925
| list[Grouper]
926926
)
927927
GroupByObject: TypeAlias = Scalar | Index | GroupByObjectNonScalar | Series
928-
GroupByFuncStrs: TypeAlias = Literal[
929-
# Reduction/aggregation functions
930-
"all",
931-
"any",
932-
"corrwith",
933-
"count",
934-
"first",
935-
"idxmax",
936-
"idxmin",
937-
"last",
938-
"max",
939-
"mean",
940-
"median",
941-
"min",
942-
"nunique",
943-
"prod",
944-
"quantile",
945-
"sem",
946-
"size",
947-
"skew",
948-
"std",
949-
"sum",
950-
"var",
951-
# Transformation functions
952-
"bfill",
953-
"cumcount",
954-
"cummax",
955-
"cummin",
956-
"cumprod",
957-
"cumsum",
958-
"diff",
959-
"ffill",
960-
"fillna",
961-
"ngroup",
962-
"pct_change",
963-
"rank",
964-
"shift",
965-
]
966928

967929
StataDateFormat: TypeAlias = Literal[
968930
"tc",

pandas-stubs/core/groupby/base.pyi

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,56 @@
11
from collections.abc import Hashable
22
import dataclasses
3+
from typing import (
4+
Literal,
5+
TypeAlias,
6+
)
37

48
@dataclasses.dataclass(order=True, frozen=True)
59
class OutputKey:
610
label: Hashable
711
position: int
12+
13+
reduction_kernels: TypeAlias = Literal[
14+
"all",
15+
"any",
16+
"corrwith",
17+
"count",
18+
"first",
19+
"idxmax",
20+
"idxmin",
21+
"last",
22+
"max",
23+
"mean",
24+
"median",
25+
"min",
26+
"nunique",
27+
"prod",
28+
# as long as `quantile`'s signature accepts only
29+
# a single quantile value, it's a reduction.
30+
# GH#27526 might change that.
31+
"quantile",
32+
"sem",
33+
"size",
34+
"skew",
35+
"std",
36+
"sum",
37+
"var",
38+
]
39+
40+
transformation_kernels: TypeAlias = Literal[
41+
"bfill",
42+
"cumcount",
43+
"cummax",
44+
"cummin",
45+
"cumprod",
46+
"cumsum",
47+
"diff",
48+
"ffill",
49+
"fillna",
50+
"ngroup",
51+
"pct_change",
52+
"rank",
53+
"shift",
54+
]
55+
56+
transform_kernel_allowlist: TypeAlias = reduction_kernels | transformation_kernels

pandas-stubs/core/groupby/generic.pyi

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ from typing import (
1919
from matplotlib.axes import Axes as PlotAxes
2020
import numpy as np
2121
from pandas.core.frame import DataFrame
22+
from pandas.core.groupby.base import transform_kernel_allowlist
2223
from pandas.core.groupby.groupby import (
2324
GroupBy,
2425
GroupByPlot,
@@ -41,7 +42,6 @@ from pandas._typing import (
4142
ByT,
4243
CorrelationMethod,
4344
Dtype,
44-
GroupByFuncStrs,
4545
IndexLabel,
4646
Level,
4747
ListLike,
@@ -109,7 +109,9 @@ class SeriesGroupBy(GroupBy[Series[S1]], Generic[S1, ByT]):
109109
**kwargs: Any,
110110
) -> UnknownSeries: ...
111111
@overload
112-
def transform(self, func: GroupByFuncStrs, *args, **kwargs) -> UnknownSeries: ...
112+
def transform(
113+
self, func: transform_kernel_allowlist, *args, **kwargs
114+
) -> UnknownSeries: ...
113115
def filter(
114116
self, func: Callable | str, dropna: bool = ..., *args, **kwargs
115117
) -> Series: ...
@@ -253,7 +255,9 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]):
253255
**kwargs: Any,
254256
) -> DataFrame: ...
255257
@overload
256-
def transform(self, func: GroupByFuncStrs, *args, **kwargs) -> DataFrame: ...
258+
def transform(
259+
self, func: transform_kernel_allowlist, *args, **kwargs
260+
) -> DataFrame: ...
257261
def filter(
258262
self, func: Callable, dropna: bool = ..., *args, **kwargs
259263
) -> DataFrame: ...

0 commit comments

Comments
 (0)