Skip to content

Commit 724d1c5

Browse files
authored
Merge pull request statsmodels#8831 from bashtage/fix-expon-smoothing-initial-values
BUG: Correct initial level, treand and seasonal
2 parents afac1aa + 0eb5ac7 commit 724d1c5

File tree

4 files changed

+109
-71
lines changed

4 files changed

+109
-71
lines changed

statsmodels/tsa/holtwinters/_smoothers.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def holt_mul_dam(x, hw_args: HoltWintersArgs):
192192
for i in range(1, hw_args.n):
193193
lvl[i] = (y_alpha[i - 1]) + (alphac * (lvl[i - 1] * b[i - 1] ** phi))
194194
b[i] = (beta * (lvl[i] / lvl[i - 1])) + (betac * b[i - 1] ** phi)
195-
return hw_args.y - lvl * b ** phi
195+
return hw_args.y - lvl * b**phi
196196

197197

198198
def holt_add_dam(x, hw_args: HoltWintersArgs):
@@ -337,7 +337,7 @@ def holt_win_mul_mul_dam(x, hw_args: HoltWintersArgs):
337337
s[i + m - 1] = (y_gamma[i - 1] / (lvl[i - 1] * b[i - 1] ** phi)) + (
338338
gammac * s[i - 1]
339339
)
340-
return hw_args.y - (lvl * b ** phi) * s[: -(m - 1)]
340+
return hw_args.y - (lvl * b**phi) * s[: -(m - 1)]
341341

342342

343343
def holt_win_add_add_dam(x, hw_args: HoltWintersArgs):

statsmodels/tsa/holtwinters/model.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,8 @@ def __init__(
287287
)
288288
estimated = self._initialization_method == "estimated"
289289
self._estimate_level = estimated
290-
self._estimate_trend = estimated and self.trend
291-
self._estimate_seasonal = estimated and self.seasonal
290+
self._estimate_trend = estimated and self.trend is not None
291+
self._estimate_seasonal = estimated and self.seasonal is not None
292292
self._bounds = self._check_bounds(bounds)
293293
self._use_boxcox = use_boxcox
294294
self._lambda = np.nan
@@ -765,8 +765,6 @@ def _optimize_parameters(
765765
beta = data.beta
766766
phi = data.phi
767767
gamma = data.gamma
768-
initial_level = data.level
769-
initial_trend = data.trend
770768
y = data.y
771769
start_params = data.params
772770

@@ -796,11 +794,11 @@ def _optimize_parameters(
796794
alpha is None,
797795
has_trend and beta is None,
798796
has_seasonal and gamma is None,
799-
initial_level is None,
800-
has_trend and initial_trend is None,
797+
self._estimate_level,
798+
self._estimate_trend,
801799
damped_trend and phi is None,
802800
]
803-
+ [has_seasonal] * m,
801+
+ [has_seasonal and self._estimate_seasonal] * m,
804802
)
805803
(
806804
sel,

statsmodels/tsa/holtwinters/results.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,7 @@ def simulate(
668668
resid = self.model._y - fitted
669669
else:
670670
resid = (self.model._y - fitted) / fitted
671-
sigma = np.sqrt(np.sum(resid ** 2) / (len(resid) - n_params))
671+
sigma = np.sqrt(np.sum(resid**2) / (len(resid) - n_params))
672672

673673
# get random error eps
674674
if isinstance(random_errors, np.ndarray):

statsmodels/tsa/holtwinters/tests/test_holtwinters.py

+101-61
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,62 @@
4343
SEASONALS = ("add", "mul", None)
4444
TRENDS = ("add", "mul", None)
4545

46+
# aust = pd.read_json(aust_json, typ='Series').sort_index()
47+
data = [
48+
41.727457999999999,
49+
24.04185,
50+
32.328102999999999,
51+
37.328707999999999,
52+
46.213152999999998,
53+
29.346326000000001,
54+
36.482909999999997,
55+
42.977719,
56+
48.901524999999999,
57+
31.180221,
58+
37.717880999999998,
59+
40.420211000000002,
60+
51.206862999999998,
61+
31.887228,
62+
40.978262999999998,
63+
43.772491000000002,
64+
55.558566999999996,
65+
33.850915000000001,
66+
42.076383,
67+
45.642291999999998,
68+
59.766779999999997,
69+
35.191876999999998,
70+
44.319737000000003,
71+
47.913736,
72+
]
73+
index = [
74+
"2005-03-01 00:00:00",
75+
"2005-06-01 00:00:00",
76+
"2005-09-01 00:00:00",
77+
"2005-12-01 00:00:00",
78+
"2006-03-01 00:00:00",
79+
"2006-06-01 00:00:00",
80+
"2006-09-01 00:00:00",
81+
"2006-12-01 00:00:00",
82+
"2007-03-01 00:00:00",
83+
"2007-06-01 00:00:00",
84+
"2007-09-01 00:00:00",
85+
"2007-12-01 00:00:00",
86+
"2008-03-01 00:00:00",
87+
"2008-06-01 00:00:00",
88+
"2008-09-01 00:00:00",
89+
"2008-12-01 00:00:00",
90+
"2009-03-01 00:00:00",
91+
"2009-06-01 00:00:00",
92+
"2009-09-01 00:00:00",
93+
"2009-12-01 00:00:00",
94+
"2010-03-01 00:00:00",
95+
"2010-06-01 00:00:00",
96+
"2010-09-01 00:00:00",
97+
"2010-12-01 00:00:00",
98+
]
99+
idx = pd.to_datetime(index)
100+
aust = pd.Series(data, index=pd.DatetimeIndex(idx, freq=pd.infer_freq(idx)))
101+
46102

47103
@pytest.fixture(scope="module")
48104
def ses():
@@ -240,63 +296,6 @@ def setup_class(cls):
240296
)
241297
cls.livestock2_livestock = livestock2_livestock
242298

243-
# aust = pd.read_json(aust_json, typ='Series').sort_index()
244-
data = [
245-
41.727457999999999,
246-
24.04185,
247-
32.328102999999999,
248-
37.328707999999999,
249-
46.213152999999998,
250-
29.346326000000001,
251-
36.482909999999997,
252-
42.977719,
253-
48.901524999999999,
254-
31.180221,
255-
37.717880999999998,
256-
40.420211000000002,
257-
51.206862999999998,
258-
31.887228,
259-
40.978262999999998,
260-
43.772491000000002,
261-
55.558566999999996,
262-
33.850915000000001,
263-
42.076383,
264-
45.642291999999998,
265-
59.766779999999997,
266-
35.191876999999998,
267-
44.319737000000003,
268-
47.913736,
269-
]
270-
index = [
271-
"2005-03-01 00:00:00",
272-
"2005-06-01 00:00:00",
273-
"2005-09-01 00:00:00",
274-
"2005-12-01 00:00:00",
275-
"2006-03-01 00:00:00",
276-
"2006-06-01 00:00:00",
277-
"2006-09-01 00:00:00",
278-
"2006-12-01 00:00:00",
279-
"2007-03-01 00:00:00",
280-
"2007-06-01 00:00:00",
281-
"2007-09-01 00:00:00",
282-
"2007-12-01 00:00:00",
283-
"2008-03-01 00:00:00",
284-
"2008-06-01 00:00:00",
285-
"2008-09-01 00:00:00",
286-
"2008-12-01 00:00:00",
287-
"2009-03-01 00:00:00",
288-
"2009-06-01 00:00:00",
289-
"2009-09-01 00:00:00",
290-
"2009-12-01 00:00:00",
291-
"2010-03-01 00:00:00",
292-
"2010-06-01 00:00:00",
293-
"2010-09-01 00:00:00",
294-
"2010-12-01 00:00:00",
295-
]
296-
aust = pd.Series(data, index)
297-
aust.index = pd.DatetimeIndex(
298-
aust.index, freq=pd.infer_freq(aust.index)
299-
)
300299
cls.aust = aust
301300
cls.start_params = [
302301
1.5520372162082909e-09,
@@ -519,7 +518,11 @@ def test_holt_damp_r(self):
519518
# livestock2_livestock <- c(...)
520519
# res <- ets(livestock2_livestock, model='AAN', damped_trend=TRUE,
521520
# phi=0.98)
522-
mod = Holt(self.livestock2_livestock, damped_trend=True)
521+
mod = Holt(
522+
self.livestock2_livestock,
523+
damped_trend=True,
524+
initialization_method="estimated",
525+
)
523526
params = {
524527
"smoothing_level": 0.97402626,
525528
"smoothing_trend": 0.00010006,
@@ -1646,11 +1649,11 @@ def test_error_boxcox():
16461649
with pytest.raises(TypeError, match="use_boxcox must be True"):
16471650
ExponentialSmoothing(y, use_boxcox="a", initialization_method="known")
16481651

1649-
mod = ExponentialSmoothing(y ** 2, use_boxcox=True)
1652+
mod = ExponentialSmoothing(y**2, use_boxcox=True)
16501653
assert isinstance(mod, ExponentialSmoothing)
16511654

16521655
mod = ExponentialSmoothing(
1653-
y ** 2, use_boxcox=True, initialization_method="legacy-heuristic"
1656+
y**2, use_boxcox=True, initialization_method="legacy-heuristic"
16541657
)
16551658
with pytest.raises(ValueError, match="use_boxcox was set"):
16561659
mod.fit(use_boxcox=False)
@@ -1950,7 +1953,7 @@ def test_attributes(ses):
19501953

19511954
def test_summary_boxcox(ses):
19521955
mod = ExponentialSmoothing(
1953-
ses ** 2, use_boxcox=True, initialization_method="heuristic"
1956+
ses**2, use_boxcox=True, initialization_method="heuristic"
19541957
)
19551958
with pytest.raises(ValueError, match="use_boxcox was set at model"):
19561959
mod.fit(use_boxcox=True)
@@ -2111,3 +2114,40 @@ def test_invalid_index(reset_randomstate):
21112114
fitted = model.fit(optimized=True, use_brute=True)
21122115
with pytest.warns(ValueWarning, match="No supported"):
21132116
fitted.forecast(steps=157200)
2117+
2118+
2119+
def test_initial_level():
2120+
# GH 8634
2121+
series = [0.0, 0.0, 0.0, 100.0, 0.0, 0.0, 0.0]
2122+
es = ExponentialSmoothing(
2123+
series, initialization_method="known", initial_level=20.0
2124+
)
2125+
es_fit = es.fit()
2126+
es_fit.params
2127+
assert_allclose(es_fit.params["initial_level"], 20.0)
2128+
2129+
2130+
def test_all_initial_values():
2131+
fit1 = ExponentialSmoothing(
2132+
aust,
2133+
seasonal_periods=4,
2134+
trend="add",
2135+
seasonal="mul",
2136+
initialization_method="estimated",
2137+
).fit()
2138+
lvl = np.round(fit1.params["initial_level"])
2139+
trend = np.round(fit1.params["initial_trend"], 1)
2140+
seas = np.round(fit1.params["initial_seasons"], 1)
2141+
fit2 = ExponentialSmoothing(
2142+
aust,
2143+
seasonal_periods=4,
2144+
trend="add",
2145+
seasonal="mul",
2146+
initialization_method="known",
2147+
initial_level=lvl,
2148+
initial_trend=trend,
2149+
initial_seasonal=seas,
2150+
).fit()
2151+
assert_allclose(fit2.params["initial_level"], lvl)
2152+
assert_allclose(fit2.params["initial_trend"], trend)
2153+
assert_allclose(fit2.params["initial_seasons"], seas)

0 commit comments

Comments
 (0)