Skip to content

Commit 54a763c

Browse files
Fix signature of Series.map() (#942)
* Fix signature of Series.map() * Add tests for Series.map() hints.
1 parent 56ddb6f commit 54a763c

File tree

3 files changed

+76
-1
lines changed

3 files changed

+76
-1
lines changed

pandas-stubs/_typing.pyi

+19
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,25 @@ S1 = TypeVar(
549549
| BaseOffset,
550550
)
551551

552+
S2 = TypeVar(
553+
"S2",
554+
bound=str
555+
| bytes
556+
| datetime.date
557+
| datetime.time
558+
| bool
559+
| int
560+
| float
561+
| complex
562+
| Dtype
563+
| datetime.datetime # includes pd.Timestamp
564+
| datetime.timedelta # includes pd.Timedelta
565+
| Period
566+
| Interval
567+
| CategoricalDtype
568+
| BaseOffset,
569+
)
570+
552571
IndexingInt: TypeAlias = (
553572
int | np.int_ | np.integer | np.unsignedinteger | np.signedinteger | np.int8
554573
)

pandas-stubs/core/series.pyi

+13-1
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ from pandas._libs.tslibs import BaseOffset
9393
from pandas._libs.tslibs.nattype import NaTType
9494
from pandas._typing import (
9595
S1,
96+
S2,
9697
AggFuncTypeBase,
9798
AggFuncTypeDictFrame,
9899
AggFuncTypeSeriesToFrame,
@@ -913,7 +914,18 @@ class Series(IndexOpsMixin[S1], NDFrame):
913914
level: Level = ...,
914915
fill_value: int | _str | dict | None = ...,
915916
) -> DataFrame: ...
916-
def map(self, arg, na_action: Literal["ignore"] | None = ...) -> Series[S1]: ...
917+
@overload
918+
def map(
919+
self,
920+
arg: Callable[[S1], S2 | NAType] | Mapping[S1, S2] | Series[S2],
921+
na_action: Literal["ignore"] = ...,
922+
) -> Series[S2]: ...
923+
@overload
924+
def map(
925+
self,
926+
arg: Callable[[S1 | NAType], S2 | NAType] | Mapping[S1, S2] | Series[S2],
927+
na_action: None = ...,
928+
) -> Series[S2]: ...
917929
@overload
918930
def aggregate( # type: ignore[overload-overlap]
919931
self: Series[int],

tests/test_series.py

+44
Original file line numberDiff line numberDiff line change
@@ -3230,3 +3230,47 @@ def test_operator_constistency() -> None:
32303230
pd.Series,
32313231
pd.Timedelta,
32323232
)
3233+
3234+
3235+
def test_map() -> None:
3236+
s = pd.Series([1, 2, 3])
3237+
3238+
mapping = {1: "a", 2: "b", 3: "c"}
3239+
check(
3240+
assert_type(s.map(mapping, na_action="ignore"), "pd.Series[str]"),
3241+
pd.Series,
3242+
str,
3243+
)
3244+
3245+
def callable(x: int) -> str:
3246+
return str(x)
3247+
3248+
check(
3249+
assert_type(s.map(callable, na_action="ignore"), "pd.Series[str]"),
3250+
pd.Series,
3251+
str,
3252+
)
3253+
3254+
series = pd.Series(["a", "b", "c"])
3255+
check(
3256+
assert_type(s.map(series, na_action="ignore"), "pd.Series[str]"), pd.Series, str
3257+
)
3258+
3259+
3260+
def test_map_na() -> None:
3261+
s: pd.Series[int] = pd.Series([1, pd.NA, 3])
3262+
3263+
mapping = {1: "a", 2: "b", 3: "c"}
3264+
check(assert_type(s.map(mapping, na_action=None), "pd.Series[str]"), pd.Series, str)
3265+
3266+
def callable(x: int | NAType) -> str | NAType:
3267+
if isinstance(x, int):
3268+
return str(x)
3269+
return x
3270+
3271+
check(
3272+
assert_type(s.map(callable, na_action=None), "pd.Series[str]"), pd.Series, str
3273+
)
3274+
3275+
series = pd.Series(["a", "b", "c"])
3276+
check(assert_type(s.map(series, na_action=None), "pd.Series[str]"), pd.Series, str)

0 commit comments

Comments
 (0)