Skip to content

Commit 9298953

Browse files
authored
Merge pull request statsmodels#5971 from bashtage/fix-future-holtwinters
BUG: Fix a future issue in ExpSmooth
2 parents 4b463a4 + a387cda commit 9298953

File tree

3 files changed

+32
-19
lines changed

3 files changed

+32
-19
lines changed

statsmodels/tsa/holtwinters.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@
2121
from scipy.stats import boxcox
2222

2323
from statsmodels.base.model import Results
24-
from statsmodels.base.wrapper import populate_wrapper, union_dicts, ResultsWrapper
25-
from statsmodels.tools.validation import array_like
24+
from statsmodels.base.wrapper import (populate_wrapper, union_dicts,
25+
ResultsWrapper)
26+
from statsmodels.tools.validation import (array_like, bool_like, float_like,
27+
string_like, int_like)
2628
from statsmodels.tsa.base.tsa_model import TimeSeriesModel
2729
from statsmodels.tsa.tsatools import freq_to_period
2830
import statsmodels.tsa._exponential_smoothers as smoothers
@@ -488,10 +490,14 @@ def __init__(self, endog, trend=None, damped=False, seasonal=None,
488490
self.endog = self.endog
489491
self._y = self._data = array_like(endog, 'endog', contiguous=True,
490492
order='C')
493+
options = ("add", "mul", "additive", "multiplicative")
494+
trend = string_like(trend, 'trend', options=options, optional=True)
491495
if trend in ['additive', 'multiplicative']:
492496
trend = {'additive': 'add', 'multiplicative': 'mul'}[trend]
493497
self.trend = trend
494-
self.damped = damped
498+
self.damped = bool_like(damped, 'damped')
499+
seasonal = string_like(seasonal, 'seasonal', options=options,
500+
optional=True)
495501
if seasonal in ['additive', 'multiplicative']:
496502
seasonal = {'additive': 'add', 'multiplicative': 'mul'}[seasonal]
497503
self.seasonal = seasonal
@@ -504,7 +510,8 @@ def __init__(self, endog, trend=None, damped=False, seasonal=None,
504510
if self.damped and not self.trending:
505511
raise ValueError('Can only dampen the trend component')
506512
if self.seasoning:
507-
self.seasonal_periods = seasonal_periods
513+
self.seasonal_periods = int_like(seasonal_periods,
514+
'seasonal_periods', optional=True)
508515
if seasonal_periods is None:
509516
self.seasonal_periods = freq_to_period(self._index_freq)
510517
if self.seasonal_periods <= 1:
@@ -608,13 +615,15 @@ def fit(self, smoothing_level=None, smoothing_slope=None, smoothing_seasonal=Non
608615
"""
609616
# Variable renames to alpha,beta, etc as this helps with following the
610617
# mathematical notation in general
611-
alpha = smoothing_level
612-
beta = smoothing_slope
613-
gamma = smoothing_seasonal
614-
phi = damping_slope
615-
l0 = self._l0 = initial_level
616-
b0 = self._b0 = initial_slope
617-
618+
alpha = float_like(smoothing_level, 'smoothing_level', True)
619+
beta = float_like(smoothing_slope, 'smoothing_slope', True)
620+
gamma = float_like(smoothing_seasonal, 'smoothing_seasonal', True)
621+
phi = float_like(damping_slope, 'damping_slope', True)
622+
l0 = self._l0 = float_like(initial_level, 'initial_level', True)
623+
b0 = self._b0 = float_like(initial_slope, 'initial_slope', True)
624+
if start_params is not None:
625+
start_params = array_like(start_params, 'start_params',
626+
contiguous=True)
618627
data = self._data
619628
damped = self.damped
620629
seasoning = self.seasoning
@@ -675,18 +684,22 @@ def fit(self, smoothing_level=None, smoothing_slope=None, smoothing_seasonal=Non
675684
txi = txi.astype(np.bool)
676685
bounds = np.array([(0.0, 1.0), (0.0, 1.0), (0.0, 1.0),
677686
(0.0, None), (0.0, None), (0.0, 1.0)] + [(None, None), ] * m)
678-
args = (txi.astype(np.uint8), p, y, lvls, b, s, m, self.nobs, max_seen)
687+
args = (txi.astype(np.uint8), p, y, lvls, b, s, m, self.nobs,
688+
max_seen)
679689
if start_params is None and np.any(txi) and use_brute:
680-
res = brute(func, bounds[txi], args, Ns=20, full_output=True, finish=None)
690+
res = brute(func, bounds[txi], args, Ns=20,
691+
full_output=True, finish=None)
681692
p[txi], max_seen, _, _ = res
682693
else:
683694
if start_params is not None:
684-
start_params = np.atleast_1d(np.squeeze(start_params))
685695
if len(start_params) != xi.sum():
686-
raise ValueError('start_params must have {0} values but '
687-
'has {1} instead'.format(len(xi), len(start_params)))
696+
msg = 'start_params must have {0} values but ' \
697+
'has {1} instead'
698+
nxi, nsp = len(xi), len(start_params)
699+
raise ValueError(msg.format(nxi, nsp))
688700
p[xi] = start_params
689-
args = (xi.astype(np.uint8), p, y, lvls, b, s, m, self.nobs, max_seen)
701+
args = (xi.astype(np.uint8), p, y, lvls, b, s, m,
702+
self.nobs, max_seen)
690703
max_seen = func(np.ascontiguousarray(p[xi]), *args)
691704
# alpha, beta, gamma, l0, b0, phi = p[:6]
692705
# s0 = p[6:]

statsmodels/tsa/tests/test_holtwinters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def test_negative_multipliative(trend_seasonal):
374374
@pytest.mark.parametrize('seasonal', SEASONALS)
375375
def test_dampen_no_trend(seasonal):
376376
y = -np.ones(100)
377-
with pytest.raises(ValueError):
377+
with pytest.raises(TypeError):
378378
ExponentialSmoothing(housing_data, trend=False, seasonal=seasonal, damped=True,
379379
seasonal_periods=10)
380380

statsmodels/tsa/tsatools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -804,4 +804,4 @@ def freq_to_period(freq):
804804

805805
__all__ = ['lagmat', 'lagmat2ds','add_trend', 'duplication_matrix',
806806
'elimination_matrix', 'commutation_matrix',
807-
'vec', 'vech', 'unvec', 'unvech']
807+
'vec', 'vech', 'unvec', 'unvech', 'freq_to_period']

0 commit comments

Comments
 (0)