Skip to content

Commit 4fee62a

Browse files
authored
Merge pull request statsmodels#8828 from bashtage/theta-method-bug
Theta method bug
2 parents 724d1c5 + 074d8ce commit 4fee62a

File tree

2 files changed

+28
-3
lines changed

2 files changed

+28
-3
lines changed

statsmodels/tsa/forecasting/tests/test_theta.py

+26
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,29 @@ def test_forecast_seasonal_alignment(data, period):
133133
index = np.arange(data.shape[0], data.shape[0] + comp.shape[0])
134134
expected = seasonal[index % period]
135135
np.testing.assert_allclose(comp.seasonal, expected)
136+
137+
138+
def test_auto(reset_randomstate):
139+
m = 250
140+
e = np.random.standard_normal(m)
141+
s = 10 * np.sin(np.linspace(0, np.pi, 12))
142+
s = np.tile(s, (m // 12 + 1))[:m]
143+
idx = pd.period_range("2000-01-01", freq="M", periods=m)
144+
x = e + s
145+
y = pd.DataFrame(10 + x - x.min(), index=idx)
146+
147+
tm = ThetaModel(y, method="auto")
148+
assert tm.method == "mul"
149+
res = tm.fit()
150+
151+
tm = ThetaModel(y, method="mul")
152+
assert tm.method == "mul"
153+
res2 = tm.fit()
154+
155+
np.testing.assert_allclose(res.params, res2.params)
156+
157+
tm = ThetaModel(y - y.mean(), method="auto")
158+
assert tm.method == "add"
159+
res3 = tm.fit()
160+
161+
assert not np.allclose(res.params, res3.params)

statsmodels/tsa/forecasting/theta.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ def __init__(
151151
"model",
152152
options=("auto", "additive", "multiplicative", "mul", "add"),
153153
)
154+
if self._method == "auto":
155+
self._method = "mul" if self._y.min() > 0 else "add"
154156
if self._period is None and self._deseasonalize:
155157
idx = getattr(endog, "index", None)
156158
pfreq = None
@@ -183,9 +185,6 @@ def _deseasonalize_data(self) -> Tuple[np.ndarray, np.ndarray]:
183185
y = self._y
184186
if not self._has_seasonality:
185187
return self._y, np.empty(0)
186-
self._method = (
187-
"mul" if self._method == "auto" and self._y.min() > 0 else "add"
188-
)
189188

190189
res = seasonal_decompose(y, model=self._method, period=self._period)
191190
if res.seasonal.min() <= 0:

0 commit comments

Comments
 (0)