Skip to content

Commit 384e522

Browse files
authored
Merge pull request statsmodels#6130 from ChadFulton/fix-predict-dates
BUG: Incorrect TSA index if loc resolves to slice
2 parents 7610de9 + 51ee4bf commit 384e522

File tree

2 files changed

+137
-4
lines changed

2 files changed

+137
-4
lines changed

statsmodels/tsa/base/tests/test_tsa_indexes.py

Lines changed: 136 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,7 @@ def test_prediction_increment_pandas_noindex():
659659
assert_equal(prediction_index.equals(pd.Index(np.arange(1, 6))), True)
660660

661661

662-
def test_prediction_increment_pandas_dates():
662+
def test_prediction_increment_pandas_dates_daily():
663663
# Date-based index
664664
endog = dta[2].copy()
665665
endog.index = date_indexes[0][0] # Daily, 1950-01-01, 1950-01-02, ...
@@ -680,6 +680,18 @@ def test_prediction_increment_pandas_dates():
680680
assert type(prediction_index) is type(endog.index) # noqa: E721
681681
assert_equal(prediction_index.equals(mod._index), True)
682682

683+
# In-sample prediction: [0, 3]; the index is a subset of the date index
684+
start_key = 0
685+
end_key = 3
686+
start, end, out_of_sample, prediction_index = (
687+
mod._get_prediction_index(start_key, end_key))
688+
689+
assert_equal(start, 0)
690+
assert_equal(end, 3)
691+
assert_equal(out_of_sample, 0)
692+
assert type(prediction_index) is type(endog.index) # noqa: E721
693+
assert_equal(prediction_index.equals(mod._index[:4]), True)
694+
683695
# Negative index: [-2, end]
684696
start_key = -2
685697
end_key = -1
@@ -705,6 +717,20 @@ def test_prediction_increment_pandas_dates():
705717
assert_equal(prediction_index.equals(desired_index), True)
706718

707719
# Date-based keys
720+
721+
# In-sample prediction (equivalent to [1, 3])
722+
start_key = '1950-01-02'
723+
end_key = '1950-01-04'
724+
start, end, out_of_sample, prediction_index = (
725+
mod._get_prediction_index(start_key, end_key))
726+
727+
assert_equal(start, 1)
728+
assert_equal(end, 3)
729+
assert_equal(out_of_sample, 0)
730+
assert type(prediction_index) is type(endog.index) # noqa: E721
731+
assert_equal(prediction_index.equals(mod._index[1:4]), True)
732+
733+
# Out-of-sample forecasting (equivalent to [0, 7])
708734
start_key = '1950-01-01'
709735
end_key = '1950-01-08'
710736
start, end, out_of_sample, prediction_index = (
@@ -716,7 +742,6 @@ def test_prediction_increment_pandas_dates():
716742
desired_index = pd.date_range(start='1950-01-01', periods=8, freq='D')
717743
assert_equal(prediction_index.equals(desired_index), True)
718744

719-
720745
# Test getting a location that exists in the (internal) index
721746
loc, index, index_was_expanded = mod._get_index_loc(2)
722747
assert_equal(loc, 2)
@@ -741,10 +766,118 @@ def test_prediction_increment_pandas_dates():
741766
assert_equal(index_was_expanded, False)
742767

743768

769+
def test_prediction_increment_pandas_dates_monthly():
770+
# Date-based index
771+
endog = dta[2].copy()
772+
endog.index = date_indexes[2][0] # Monthly, 1950-01, 1950-02, ...
773+
mod = tsa_model.TimeSeriesModel(endog)
774+
775+
# Tests three common use cases: basic prediction, negative indexes, and
776+
# out-of-sample indexes.
777+
778+
# Basic prediction: [0, end]; the index is the date index
779+
start_key = 0
780+
end_key = None
781+
start, end, out_of_sample, prediction_index = (
782+
mod._get_prediction_index(start_key, end_key))
783+
784+
assert_equal(start, 0)
785+
assert_equal(end, nobs-1)
786+
assert_equal(out_of_sample, 0)
787+
assert type(prediction_index) is type(endog.index) # noqa: E721
788+
assert_equal(prediction_index.equals(mod._index), True)
789+
790+
# In-sample prediction: [0, 3]; the index is a subset of the date index
791+
start_key = 0
792+
end_key = 3
793+
start, end, out_of_sample, prediction_index = (
794+
mod._get_prediction_index(start_key, end_key))
795+
796+
assert_equal(start, 0)
797+
assert_equal(end, 3)
798+
assert_equal(out_of_sample, 0)
799+
assert type(prediction_index) is type(endog.index) # noqa: E721
800+
assert_equal(prediction_index.equals(mod._index[:4]), True)
801+
802+
# Negative index: [-2, end]
803+
start_key = -2
804+
end_key = -1
805+
start, end, out_of_sample, prediction_index = (
806+
mod._get_prediction_index(start_key, end_key))
807+
808+
assert_equal(start, 3)
809+
assert_equal(end, 4)
810+
assert_equal(out_of_sample, 0)
811+
assert type(prediction_index) is type(endog.index) # noqa: E721
812+
assert_equal(prediction_index.equals(mod._index[3:]), True)
813+
814+
# Forecasting: [1, 5]; the index is an extended version of the date index
815+
start_key = 1
816+
end_key = nobs
817+
start, end, out_of_sample, prediction_index = (
818+
mod._get_prediction_index(start_key, end_key))
819+
820+
assert_equal(start, 1)
821+
assert_equal(end, 4)
822+
assert_equal(out_of_sample, 1)
823+
desired_index = pd.date_range(start='1950-02', periods=5, freq='M')
824+
assert_equal(prediction_index.equals(desired_index), True)
825+
826+
# Date-based keys
827+
828+
# In-sample prediction (equivalent to [1, 3])
829+
start_key = '1950-02'
830+
end_key = '1950-04'
831+
start, end, out_of_sample, prediction_index = (
832+
mod._get_prediction_index(start_key, end_key))
833+
834+
assert_equal(start, 1)
835+
assert_equal(end, 3)
836+
assert_equal(out_of_sample, 0)
837+
assert type(prediction_index) is type(endog.index) # noqa: E721
838+
assert_equal(prediction_index.equals(mod._index[1:4]), True)
839+
840+
# Out-of-sample forecasting (equivalent to [0, 7])
841+
start_key = '1950-01'
842+
end_key = '1950-08'
843+
start, end, out_of_sample, prediction_index = (
844+
mod._get_prediction_index(start_key, end_key))
845+
846+
assert_equal(start, 0)
847+
assert_equal(end, 4)
848+
assert_equal(out_of_sample, 3)
849+
desired_index = pd.date_range(start='1950-01', periods=8, freq='M')
850+
assert_equal(prediction_index.equals(desired_index), True)
851+
852+
# Test getting a location that exists in the (internal) index
853+
loc, index, index_was_expanded = mod._get_index_loc(2)
854+
assert_equal(loc, 2)
855+
desired_index = pd.date_range(start='1950-01', periods=3, freq='M')
856+
assert_equal(index.equals(desired_index), True)
857+
assert_equal(index_was_expanded, False)
858+
859+
# Test getting a location that exists in the (internal) index
860+
# when using the function that alternatively falls back to the row labels
861+
loc, index, index_was_expanded = mod._get_index_label_loc(2)
862+
assert_equal(loc, 2)
863+
desired_index = pd.date_range(start='1950-01', periods=3, freq='M')
864+
assert_equal(index.equals(desired_index), True)
865+
assert_equal(index_was_expanded, False)
866+
867+
# Test getting a location that exists in the given (unsupported) index
868+
# Note that the returned index is now like the row labels
869+
loc, index, index_was_expanded = mod._get_index_label_loc('1950-03')
870+
assert_equal(loc, slice(2, 3, None))
871+
desired_index = mod.data.row_labels[:3]
872+
assert_equal(index.equals(desired_index), True)
873+
assert_equal(index_was_expanded, False)
874+
875+
744876
def test_prediction_increment_pandas_dates_nanosecond():
745877
# Date-based index
746878
endog = dta[2].copy()
747-
endog.index = pd.date_range(start='1970-01-01', periods=len(endog), freq='N')
879+
endog.index = pd.date_range(start='1970-01-01', periods=len(endog),
880+
freq='N')
748881
mod = tsa_model.TimeSeriesModel(endog)
749882

750883
# Tests three common use cases: basic prediction, negative indexes, and

statsmodels/tsa/base/tsa_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ def _get_index_loc(self, key, base_index=None):
377377

378378
# Return the index through the end of the loc / slice (slice.stop is exclusive)
379379
if isinstance(loc, slice):
380-
end = loc.stop
380+
end = loc.stop - 1
381381
else:
382382
end = loc
383383

0 commit comments

Comments
 (0)