Skip to content

Commit a7ac221

Browse files
committed
BUG/REF/DOC: more internal consistency fixes, update docstrings
1 parent e09a830 commit a7ac221

File tree

2 files changed

+27
-24
lines changed

2 files changed

+27
-24
lines changed

statsmodels/robust/norms.py

+21-23
Original file line numberDiff line numberDiff line change
@@ -406,16 +406,16 @@ def rho(self, z):
406406
Returns
407407
-------
408408
rho : ndarray
409-
rho(z) = a*(1-cos(z/a)) for \|z\| <= a*pi
409+
rho(z) = a**2 *(1-cos(z/a)) for \|z\| <= a*pi
410410
411-
rho(z) = 2*a for \|z\| > a*pi
411+
rho(z) = 2*a for \|z\| > a*pi
412412
"""
413413

414414
a = self.a
415415
z = np.asarray(z)
416416
test = self._subset(z)
417-
return (test * a**2 * (0 - np.cos(z / a)) +
418-
(1 - test) * 2 * a)
417+
return (test * a**2 * (1 - np.cos(z / a)) +
418+
(1 - test) * a**2 * 2)
419419

420420
def psi(self, z):
421421
r"""
@@ -431,7 +431,7 @@ def psi(self, z):
431431
Returns
432432
-------
433433
psi : ndarray
434-
psi(z) = sin(z/a) for \|z\| <= a*pi
434+
psi(z) = a * sin(z/a) for \|z\| <= a*pi
435435
436436
psi(z) = 0 for \|z\| > a*pi
437437
"""
@@ -455,9 +455,9 @@ def weights(self, z):
455455
Returns
456456
-------
457457
weights : ndarray
458-
weights(z) = sin(z/a)/(z/a) for \|z\| <= a*pi
458+
weights(z) = sin(z/a) / (z/a) for \|z\| <= a*pi
459459
460-
weights(z) = 0 for \|z\| > a*pi
460+
weights(z) = 0 for \|z\| > a*pi
461461
"""
462462
a = self.a
463463
z = np.asarray(z)
@@ -527,12 +527,12 @@ def rho(self, z):
527527
rho : ndarray
528528
rho(z) = (1/2.)*z**2 for \|z\| <= c
529529
530-
rho(z) = 0 for \|z\| > c
530+
rho(z) = (1/2.)*c**2 for \|z\| > c
531531
"""
532532

533533
z = np.asarray(z)
534534
test = self._subset(z)
535-
return test * z**2 * 0.5
535+
return test * z**2 * 0.5 + (1 - test) * self.c**2 * 0.5
536536

537537
def psi(self, z):
538538
r"""
@@ -635,13 +635,13 @@ def rho(self, z):
635635
Returns
636636
-------
637637
rho : ndarray
638-
rho(z) = (1/2.)*z**2 for \|z\| <= a
638+
rho(z) = z**2 / 2 for \|z\| <= a
639639
640-
rho(z) = a*\|z\| - 1/2.*a**2 for a < \|z\| <= b
640+
rho(z) = a*\|z\| - 1/2.*a**2 for a < \|z\| <= b
641641
642-
rho(z) = a*(c*\|z\|-(1/2.)*z**2)/(c-b) for b < \|z\| <= c
642+
rho(z) = a*(c - \|z\|)**2 / (c - b) / 2 for b < \|z\| <= c
643643
644-
rho(z) = a*(b + c - a) for \|z\| > c
644+
rho(z) = a*(b + c - a) / 2 for \|z\| > c
645645
"""
646646
a, b, c = self.a, self.b, self.c
647647

@@ -654,7 +654,8 @@ def rho(self, z):
654654
v = np.zeros(z.shape, dtype=dt)
655655
z = np.abs(z)
656656
v[t1] = z[t1]**2 * 0.5
657-
v[t2] = (a * (z[t2] - a) + a**2 * 0.5)
657+
# v[t2] = (a * (z[t2] - a) + a**2 * 0.5)
658+
v[t2] = (a * z[t2] - a**2 * 0.5)
658659
v[t3] = a * (c - z[t3])**2 / (c - b) * (-0.5)
659660
v[t34] += a * (b + c - a) * 0.5
660661

@@ -694,14 +695,11 @@ def psi(self, z):
694695
dt = np.promote_types(z.dtype, "float")
695696
v = np.zeros(z.shape, dtype=dt)
696697
s = np.sign(z)
697-
z = np.abs(z)
698+
za = np.abs(z)
698699

699-
v[t1] = z[t1] * s[t1]
700+
v[t1] = z[t1]
700701
v[t2] = a * s[t2]
701-
v[t3] = a * s[t3] * (c - z[t3]) / (c - b)
702-
# v = (t1 * z*s +
703-
# t2 * a*s +
704-
# t3 * a*s * (c - z) / (c - b))
702+
v[t3] = a * s[t3] * (c - za[t3]) / (c - b)
705703

706704
if z_isscalar:
707705
v = v[0]
@@ -721,13 +719,13 @@ def weights(self, z):
721719
Returns
722720
-------
723721
weights : ndarray
724-
weights(z) = 1 for \|z\| <= a
722+
weights(z) = 1 for \|z\| <= a
725723
726-
weights(z) = a/\|z\| for a < \|z\| <= b
724+
weights(z) = a/\|z\| for a < \|z\| <= b
727725
728726
weights(z) = a*(c - \|z\|)/(\|z\|*(c-b)) for b < \|z\| <= c
729727
730-
weights(z) = 0 for \|z\| > c
728+
weights(z) = 0 for \|z\| > c
731729
"""
732730
a, b, c = self.a, self.b, self.c
733731

statsmodels/robust/tests/test_norms.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,15 @@ def test_norms_consistent(case):
7979
x = np.array([-9, -6, -2, -1, 0, 1, 2, 6, 9], dtype=float)
8080

8181
weights = norm.weights(x)
82-
# rho = norm.rho(x) # not used
82+
rho = norm.rho(x) # not used
8383
psi = norm.psi(x)
8484
psi_deriv = norm.psi_deriv(x)
8585

86+
# check location and u-shape of rho
87+
assert rho[4] == 0
88+
assert np.all(np.diff(rho[4:]) >= 0)
89+
assert np.all(np.diff(rho[:4]) <= 0)
90+
8691
# avoid zero division nan:
8792
assert_allclose(weights, (psi + 1e-50) / (x + 1e-50), rtol=1e-6, atol=1e-8)
8893
psid = _approx_fprime_scalar(x, norm.rho)

0 commit comments

Comments
 (0)