@@ -308,6 +308,38 @@ def test_types_assign() -> None:
308308 df ["col3" ] = df .sum (axis = 1 )
309309
310310
311+ def test_assign () -> None :
312+ df = pd .DataFrame ({"a" : [1 , 2 , 3 ], 1 : [4 , 5 , 6 ]})
313+
314+ my_unnamed_func = lambda df : df ["a" ] * 2
315+
316+ def my_named_func_1 (df : pd .DataFrame ) -> pd .Series [str ]:
317+ return df ["a" ]
318+
319+ def my_named_func_2 (df : pd .DataFrame ) -> pd .Series [Any ]:
320+ return df ["a" ]
321+
322+ check (assert_type (df .assign (c = lambda df : df ["a" ] * 2 ), pd .DataFrame ), pd .DataFrame )
323+ check (
324+ assert_type (df .assign (c = lambda df : df ["a" ].index ), pd .DataFrame ), pd .DataFrame
325+ )
326+ check (
327+ assert_type (df .assign (c = lambda df : df ["a" ].to_numpy ()), pd .DataFrame ),
328+ pd .DataFrame ,
329+ )
330+ check (
331+ assert_type (df .assign (c = lambda df : df ["a" ].max ()), pd .DataFrame ),
332+ pd .DataFrame ,
333+ )
334+ check (assert_type (df .assign (c = df ["a" ] * 2 ), pd .DataFrame ), pd .DataFrame )
335+ check (assert_type (df .assign (c = df ["a" ].index ), pd .DataFrame ), pd .DataFrame )
336+ check (assert_type (df .assign (c = df ["a" ].to_numpy ()), pd .DataFrame ), pd .DataFrame )
337+ check (assert_type (df .assign (c = 2 ), pd .DataFrame ), pd .DataFrame )
338+ check (assert_type (df .assign (c = my_unnamed_func ), pd .DataFrame ), pd .DataFrame )
339+ check (assert_type (df .assign (c = my_named_func_1 ), pd .DataFrame ), pd .DataFrame )
340+ check (assert_type (df .assign (c = my_named_func_2 ), pd .DataFrame ), pd .DataFrame )
341+
342+
311343def test_types_sample () -> None :
312344 df = pd .DataFrame (data = {"col1" : [1 , 2 ], "col2" : [3 , 4 ]})
313345 # GH 67
0 commit comments