Type the keyword arguments of DataFrame.assign via a new IntoColumn alias.

IntoColumn describes the values DataFrame.assign takes for a new column:
array-likes, list-likes (the same ListLikeU alias DataFrame.insert uses),
scalars, and callables that compute any of those from the frame.

Note: the "index" blob ids below are carried over from the original patch
and do not reflect the revised hunks.

diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi
index 635a418f..2655f09f 100644
--- a/pandas-stubs/_typing.pyi
+++ b/pandas-stubs/_typing.pyi
@@ -837,4 +837,14 @@ ExcelWriteEngine: TypeAlias = Literal["openpyxl", "odf", "xlsxwriter"]
 # https://github.com/pandas-dev/pandas-stubs/pull/1151#issuecomment-2715130190
 TimeZones: TypeAlias = str | tzinfo | None | int
 
+# Evaluates to a DataFrame column in DataFrame.assign context: either the
+# column values themselves (array-like, list-like or scalar) or a callable
+# that computes them from the DataFrame being assigned to.
+IntoColumn: TypeAlias = (
+    AnyArrayLike
+    | ListLikeU
+    | Scalar
+    | Callable[[DataFrame], AnyArrayLike | ListLikeU | Scalar]
+)
+
 __all__ = ["npt", "type_t"]
diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi
index 11449493..f88bd0c6 100644
--- a/pandas-stubs/core/frame.pyi
+++ b/pandas-stubs/core/frame.pyi
@@ -100,6 +100,7 @@ from pandas._typing import (
     InterpolateOptions,
     IntervalClosedType,
     IntervalT,
+    IntoColumn,
     JoinHow,
     JsonFrameOrient,
     Label,
@@ -728,7 +729,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
         value: Scalar | ListLikeU | None,
         allow_duplicates: _bool = ...,
     ) -> None: ...
-    def assign(self, **kwargs) -> Self: ...
+    def assign(self, **kwargs: IntoColumn) -> Self: ...
     def align(
         self,
         other: NDFrameT,
diff --git a/tests/test_frame.py b/tests/test_frame.py
index eca3cec8..99f11b3c 100644
--- a/tests/test_frame.py
+++ b/tests/test_frame.py
@@ -308,6 +308,41 @@ def test_types_assign() -> None:
     df["col3"] = df.sum(axis=1)
 
 
+def test_assign() -> None:
+    """Test DataFrame.assign with column values and callables producing them."""
+    df = pd.DataFrame({"a": [1, 2, 3], 1: [4, 5, 6]})
+
+    my_unnamed_func = lambda df: df["a"] * 2
+
+    def my_named_func_1(df: pd.DataFrame) -> pd.Series[str]:
+        return df["a"]
+
+    def my_named_func_2(df: pd.DataFrame) -> pd.Series[Any]:
+        return df["a"]
+
+    check(assert_type(df.assign(c=lambda df: df["a"] * 2), pd.DataFrame), pd.DataFrame)
+    check(
+        assert_type(df.assign(c=lambda df: df["a"].index), pd.DataFrame), pd.DataFrame
+    )
+    check(
+        assert_type(df.assign(c=lambda df: df["a"].to_numpy()), pd.DataFrame),
+        pd.DataFrame,
+    )
+    check(
+        assert_type(df.assign(c=lambda df: df["a"].max()), pd.DataFrame),
+        pd.DataFrame,
+    )
+    check(assert_type(df.assign(c=lambda df: [7, 8, 9]), pd.DataFrame), pd.DataFrame)
+    check(assert_type(df.assign(c=df["a"] * 2), pd.DataFrame), pd.DataFrame)
+    check(assert_type(df.assign(c=df["a"].index), pd.DataFrame), pd.DataFrame)
+    check(assert_type(df.assign(c=df["a"].to_numpy()), pd.DataFrame), pd.DataFrame)
+    check(assert_type(df.assign(c=[7, 8, 9]), pd.DataFrame), pd.DataFrame)
+    check(assert_type(df.assign(c=2), pd.DataFrame), pd.DataFrame)
+    check(assert_type(df.assign(c=my_unnamed_func), pd.DataFrame), pd.DataFrame)
+    check(assert_type(df.assign(c=my_named_func_1), pd.DataFrame), pd.DataFrame)
+    check(assert_type(df.assign(c=my_named_func_2), pd.DataFrame), pd.DataFrame)
+
+
 def test_types_sample() -> None:
     df = pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})
     # GH 67