Skip to content

Commit 257fcec

Browse files
Make index check on statespace data less strict (#434)
* Make index check less strict * Validate generic index values * Remove redundant check
1 parent fc39c49 commit 257fcec

File tree

2 files changed

+89
-9
lines changed

2 files changed

+89
-9
lines changed

pymc_extras/statespace/utils/data_tools.py

+24-9
Original file line numberDiff line numberDiff line change
@@ -87,23 +87,38 @@ def preprocess_pandas_data(data, n_obs, obs_coords=None, check_column_names=Fals
8787
col_names = data.columns
8888
_validate_data_shape(data.shape, n_obs, obs_coords, check_column_names, col_names)
8989

90-
if isinstance(data.index, pd.RangeIndex):
91-
if obs_coords is not None:
92-
warnings.warn(NO_TIME_INDEX_WARNING)
93-
return preprocess_numpy_data(data.values, n_obs, obs_coords)
94-
95-
elif isinstance(data.index, pd.DatetimeIndex):
90+
if isinstance(data.index, pd.DatetimeIndex):
9691
if data.index.freq is None:
9792
warnings.warn(NO_FREQ_INFO_WARNING)
9893
data.index.freq = data.index.inferred_freq
9994

10095
index = data.index
10196
return data.values, index
10297

98+
elif isinstance(data.index, pd.RangeIndex):
99+
if obs_coords is not None:
100+
warnings.warn(NO_TIME_INDEX_WARNING)
101+
return preprocess_numpy_data(data.values, n_obs, obs_coords)
102+
103+
elif isinstance(data.index, pd.MultiIndex):
104+
if obs_coords is not None:
105+
warnings.warn(NO_TIME_INDEX_WARNING)
106+
107+
raise NotImplementedError("MultiIndex panel data is not currently supported.")
108+
103109
else:
104-
raise IndexError(
105-
f"Expected pd.DatetimeIndex or pd.RangeIndex on data, found {type(data.index)}"
106-
)
110+
if obs_coords is not None:
111+
warnings.warn(NO_TIME_INDEX_WARNING)
112+
113+
index = data.index
114+
if not np.issubdtype(index.dtype, np.integer):
115+
raise IndexError("Provided index is not an integer index.")
116+
117+
index_diff = index.to_series().diff().dropna().values
118+
if not (index_diff == 1).all():
119+
raise IndexError("Provided index is not monotonic increasing.")
120+
121+
return preprocess_numpy_data(data.values, n_obs, obs_coords)
107122

108123

109124
def add_data_to_active_model(values, index, data_dims=None):

tests/statespace/test_coord_assignment.py

+65
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pytest
99

1010
from pymc_extras.statespace.models import structural
11+
from pymc_extras.statespace.models.structural import LevelTrendComponent
1112
from pymc_extras.statespace.utils.constants import (
1213
FILTER_OUTPUT_DIMS,
1314
FILTER_OUTPUT_NAMES,
@@ -114,3 +115,67 @@ def test_data_index_is_coord(f, warning, create_model):
114115
with warning:
115116
pymc_model = create_model(f)
116117
assert TIME_DIM in pymc_model.coords
118+
119+
120+
def make_model(index):
121+
n = len(index)
122+
a = pd.DataFrame(index=index, columns=["A", "B", "C", "D"], data=np.arange(n * 4).reshape(n, 4))
123+
124+
mod = LevelTrendComponent(order=2, innovations_order=[0, 1])
125+
ss_mod = mod.build(name="a", verbose=False)
126+
127+
initial_trend_dims, sigma_trend_dims, P0_dims = ss_mod.param_dims.values()
128+
coords = ss_mod.coords
129+
130+
with pm.Model(coords=coords) as model:
131+
P0_diag = pm.Gamma("P0_diag", alpha=5, beta=5)
132+
P0 = pm.Deterministic("P0", pt.eye(ss_mod.k_states) * P0_diag, dims=P0_dims)
133+
134+
initial_trend = pm.Normal("initial_trend", dims=initial_trend_dims)
135+
sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=50, dims=sigma_trend_dims)
136+
137+
with pytest.warns(UserWarning, match="No time index found on the supplied data"):
138+
ss_mod.build_statespace_graph(
139+
a["A"],
140+
mode="JAX",
141+
)
142+
return model
143+
144+
145+
def test_integer_index():
146+
index = np.arange(8).astype(int)
147+
model = make_model(index)
148+
assert TIME_DIM in model.coords
149+
np.testing.assert_allclose(model.coords[TIME_DIM], index)
150+
151+
152+
def test_float_index_raises():
153+
index = np.linspace(0, 1, 8)
154+
155+
with pytest.raises(IndexError, match="Provided index is not an integer index"):
156+
make_model(index)
157+
158+
159+
def test_non_strictly_monotone_index_raises():
160+
# Decreases
161+
index = [0, 1, 2, 1, 2, 3]
162+
with pytest.raises(IndexError, match="Provided index is not monotonic increasing"):
163+
make_model(index)
164+
165+
# Has gaps
166+
index = [0, 1, 2, 3, 5, 6]
167+
with pytest.raises(IndexError, match="Provided index is not monotonic increasing"):
168+
make_model(index)
169+
170+
# Has duplicates
171+
index = [0, 1, 1, 2, 3, 4]
172+
with pytest.raises(IndexError, match="Provided index is not monotonic increasing"):
173+
make_model(index)
174+
175+
176+
def test_multiindex_raises():
177+
index = pd.MultiIndex.from_tuples([(0, 0), (1, 1), (2, 2), (3, 3)])
178+
with pytest.raises(
179+
NotImplementedError, match="MultiIndex panel data is not currently supported"
180+
):
181+
make_model(index)

0 commit comments

Comments
 (0)