Skip to content

Commit 1e3380c

Browse files
committed
type DataFrame.assign
1 parent 7328e89 commit 1e3380c

File tree

3 files changed

+32
-1
lines changed

3 files changed

+32
-1
lines changed

pandas-stubs/_typing.pyi

+3
Original file line numberDiff line numberDiff line change
@@ -837,4 +837,7 @@ ExcelWriteEngine: TypeAlias = Literal["openpyxl", "odf", "xlsxwriter"]
837837
# https://github.com/pandas-dev/pandas-stubs/pull/1151#issuecomment-2715130190
838838
TimeZones: TypeAlias = str | tzinfo | None | int
839839

840+
# Value accepted by DataFrame.assign: either an array-like, or a callable that takes the DataFrame and returns an array-like.
841+
IntoArrayLike: TypeAlias = AnyArrayLike | Callable[[DataFrame], AnyArrayLike]
842+
840843
__all__ = ["npt", "type_t"]

pandas-stubs/core/frame.pyi

+2-1
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ from pandas._typing import (
100100
InterpolateOptions,
101101
IntervalClosedType,
102102
IntervalT,
103+
IntoArrayLike,
103104
JoinHow,
104105
JsonFrameOrient,
105106
Label,
@@ -728,7 +729,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
728729
value: Scalar | ListLikeU | None,
729730
allow_duplicates: _bool = ...,
730731
) -> None: ...
731-
def assign(self, **kwargs) -> Self: ...
732+
def assign(self, **kwargs: IntoArrayLike) -> Self: ...
732733
def align(
733734
self,
734735
other: NDFrameT,

tests/test_frame.py

+27
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,33 @@ def test_types_assign() -> None:
308308
df["col3"] = df.sum(axis=1)
309309

310310

311+
def test_assign() -> None:
312+
df = pd.DataFrame({"a": [1, 2, 3], 1: [4, 5, 6]})
313+
314+
my_unnamed_func = lambda df: df["a"] * 2
315+
316+
def my_named_func_1(df: pd.DataFrame) -> pd.Series[str]:
317+
return df["a"]
318+
319+
def my_named_func_2(df: pd.DataFrame) -> pd.Series[Any]:
320+
return df["a"]
321+
322+
check(assert_type(df.assign(c=lambda df: df["a"] * 2), pd.DataFrame), pd.DataFrame)
323+
check(
324+
assert_type(df.assign(c=lambda df: df["a"].index), pd.DataFrame), pd.DataFrame
325+
)
326+
check(
327+
assert_type(df.assign(c=lambda df: df["a"].to_numpy()), pd.DataFrame),
328+
pd.DataFrame,
329+
)
330+
check(assert_type(df.assign(c=df["a"] * 2), pd.DataFrame), pd.DataFrame)
331+
check(assert_type(df.assign(c=df["a"].index), pd.DataFrame), pd.DataFrame)
332+
check(assert_type(df.assign(c=df["a"].to_numpy()), pd.DataFrame), pd.DataFrame)
333+
check(assert_type(df.assign(c=my_unnamed_func), pd.DataFrame), pd.DataFrame)
334+
check(assert_type(df.assign(c=my_named_func_1), pd.DataFrame), pd.DataFrame)
335+
check(assert_type(df.assign(c=my_named_func_2), pd.DataFrame), pd.DataFrame)
336+
337+
311338
def test_types_sample() -> None:
312339
df = pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})
313340
# GH 67

0 commit comments

Comments
 (0)