Skip to content

Commit e09a830

Browse files
committed
BUG/REF/TST: add consistency checks, fix smaller bugs
1 parent 9b1c6d1 commit e09a830

File tree

3 files changed

+54
-6
lines changed

3 files changed

+54
-6
lines changed

statsmodels/robust/norms.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -255,11 +255,17 @@ def weights(self, z):
255255
256256
weights(z) = t/\|z\| for \|z\| > t
257257
"""
258-
z = np.asarray(z)
258+
z_isscalar = np.isscalar(z)
259+
z = np.atleast_1d(z)
260+
259261
test = self._subset(z)
260262
absz = np.abs(z)
261263
absz[test] = 1.0
262-
return test + (1 - test) * self.t / absz
264+
v = test + (1 - test) * self.t / absz
265+
266+
if z_isscalar:
267+
v = v[0]
268+
return v
263269

264270
def psi_deriv(self, z):
265271
"""
@@ -269,7 +275,7 @@ def psi_deriv(self, z):
269275
-----
270276
Used to estimate the robust covariance matrix.
271277
"""
272-
return np.less_equal(np.abs(z), self.t)
278+
return np.less_equal(np.abs(z), self.t).astype(float)
273279

274280

275281
# TODO: untested, but looks right. RamsayE not available in R or SAS?
@@ -408,7 +414,7 @@ def rho(self, z):
408414
a = self.a
409415
z = np.asarray(z)
410416
test = self._subset(z)
411-
return (test * a * (1 - np.cos(z / a)) +
417+
return (test * a**2 * (0 - np.cos(z / a)) +
412418
(1 - test) * 2 * a)
413419

414420
def psi(self, z):
@@ -433,7 +439,7 @@ def psi(self, z):
433439
a = self.a
434440
z = np.asarray(z)
435441
test = self._subset(z)
436-
return test * np.sin(z / a)
442+
return test * a * np.sin(z / a)
437443

438444
def weights(self, z):
439445
r"""
@@ -477,7 +483,7 @@ def psi_deriv(self, z):
477483
"""
478484

479485
test = self._subset(z)
480-
return test*np.cos(z / self.a)/self.a
486+
return test * np.cos(z / self.a)
481487

482488

483489
# TODO: this is untested

statsmodels/robust/tests/results/results_norms.py

+9
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,12 @@
2222
0.0],
2323
weights=[0.0, 0.0, 0.5625, 0.87890625, 1.0, 0.87890625, 0.5625, 0.0, 0.0],
2424
)
25+
26+
res_huber = Holder(
27+
rho=[11.200487500000001, 7.165487499999999, 1.7854875000000001, 0.5, 0.0,
28+
0.5, 1.7854875000000001, 7.165487499999999, 11.200487500000001],
29+
psi=[-1.345, -1.345, -1.345, -1.0, 0.0, 1.0, 1.345, 1.345, 1.345],
30+
psi_deriv=[0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0],
31+
weights=[0.14944444444444444, 0.22416666666666665, 0.6725, 1.0, 1.0, 1.0,
32+
0.6725, 0.22416666666666665, 0.14944444444444444],
33+
)

statsmodels/robust/tests/test_norms.py

+33
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,15 @@
1313
cases = [
1414
(norms.Hampel, (1.5, 3.5, 8.), res_r.res_hampel),
1515
(norms.TukeyBiweight, (4,), res_r.res_biweight),
16+
(norms.HuberT, (1.345,), res_r.res_huber),
17+
]
18+
19+
norms_other = [
20+
(norms.LeastSquares, ()),
21+
(norms.TrimmedMean, (1.9,)), # avoid arg at integer used in example
22+
(norms.AndrewWave, ()),
23+
(norms.RamsayE, ()),
24+
# norms.MQuantileNorm, # requires keywords in init
1625
]
1726

1827
dtypes = ["int", np.float64, np.complex128]
@@ -22,6 +31,10 @@
2231
@pytest.mark.parametrize("case", cases)
2332
def test_norm(case, dtype):
2433
ncls, args, res = case
34+
if ncls in [norms.HuberT] and dtype == np.complex128:
35+
# skip for now
36+
return
37+
2538
norm = ncls(*args)
2639
x = np.array([-9, -6, -2, -1, 0, 1, 2, 6, 9], dtype=dtype)
2740

@@ -56,3 +69,23 @@ def test_norm(case, dtype):
5669
for meth in methods:
5770
resm = [getattr(norm, meth)(xi) for xi in x]
5871
assert_allclose(resm, getattr(res, meth))
72+
73+
74+
@pytest.mark.parametrize("case", norms_other)
75+
def test_norms_consistent(case):
76+
# test that norm methods are consistent with each other
77+
ncls, args = case
78+
norm = ncls(*args)
79+
x = np.array([-9, -6, -2, -1, 0, 1, 2, 6, 9], dtype=float)
80+
81+
weights = norm.weights(x)
82+
# rho = norm.rho(x) # not used
83+
psi = norm.psi(x)
84+
psi_deriv = norm.psi_deriv(x)
85+
86+
# avoid zero division nan:
87+
assert_allclose(weights, (psi + 1e-50) / (x + 1e-50), rtol=1e-6, atol=1e-8)
88+
psid = _approx_fprime_scalar(x, norm.rho)
89+
assert_allclose(psi, psid, rtol=1e-6, atol=1e-8)
90+
psidd = _approx_fprime_scalar(x, norm.psi)
91+
assert_allclose(psi_deriv, psidd, rtol=1e-6, atol=1e-8)

0 commit comments

Comments
 (0)