diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 30b545270..540baf690 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -84,6 +84,7 @@ from pandas._typing import ( FilePath, FillnaOptions, FormattersType, + Frequency, GroupByObjectNonScalar, HashableT, HashableT1, @@ -855,7 +856,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): def shift( self, periods: int = ..., - freq=..., + freq: Frequency | dt.timedelta | None = ..., axis: Axis = ..., fill_value: Hashable | None = ..., ) -> Self: ... diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index f6afe8e53..393ad4ccc 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -119,6 +119,7 @@ from pandas._typing import ( FilePath, FillnaOptions, FloatDtypeArg, + Frequency, GroupByObjectNonScalar, HashableT1, IgnoreRaise, @@ -1219,10 +1220,10 @@ class Series(IndexOpsMixin[S1], NDFrame): def shift( self, periods: int = ..., - freq=..., + freq: Frequency | timedelta | None = ..., axis: AxisIndex = ..., fill_value: object | None = ..., - ) -> Series[S1]: ... + ) -> Series: ... def memory_usage(self, index: _bool = ..., deep: _bool = ...) -> int: ... def isin(self, values: Iterable | Series[S1] | dict) -> Series[_bool]: ... def between( diff --git a/tests/test_frame.py b/tests/test_frame.py index eb2fcdc97..53b581e6d 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -516,10 +516,13 @@ def test_types_sort_values_with_key() -> None: def test_types_shift() -> None: - df = pd.DataFrame(data={"col1": [1, 1], "col2": [3, 4]}) - df.shift() - df.shift(1) - df.shift(-1) + df = pd.DataFrame( + data={"col1": [1, 1], "col2": [3, 4]}, index=pd.date_range("2020", periods=2) + ) + check(assert_type(df.shift(), pd.DataFrame), pd.DataFrame) + check(assert_type(df.shift(1), pd.DataFrame), pd.DataFrame) + check(assert_type(df.shift(-1), pd.DataFrame), pd.DataFrame) + check(assert_type(df.shift(freq="1D"), pd.DataFrame), pd.DataFrame) def test_types_rank() -> None: diff --git a/tests/test_series.py b/tests/test_series.py index 2857b9f60..2f9b0d63b 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -416,10 +416,15 @@ def test_types_sort_values_with_key() -> None: def test_types_shift() -> None: - s = pd.Series([1, 2, 3]) - s.shift() - s.shift(axis=0, periods=1) - s.shift(-1, fill_value=0) + s = pd.Series([1, 2, 3], index=pd.date_range("2020", periods=3)) + check(assert_type(s.shift(), pd.Series), pd.Series, np.floating) + check( + assert_type(s.shift(axis=0, periods=1), pd.Series), + pd.Series, + np.floating, + ) + check(assert_type(s.shift(-1, fill_value=0), pd.Series), pd.Series, np.integer) + check(assert_type(s.shift(freq="1D"), pd.Series), pd.Series, np.integer) def test_types_rank() -> None: