Skip to content

Commit 20b2f6c

Browse files
committed
REF/ENH: robust norms dtype, plus pep-8, refactor for consistency
1 parent b50de1b commit 20b2f6c

File tree

3 files changed

+99
-29
lines changed

3 files changed

+99
-29
lines changed

statsmodels/robust/norms.py

Lines changed: 57 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,16 @@
33
# TODO: add plots to weighting functions for online docs.
44

55

6+
def _cabs(x):
7+
"""absolute value function that changes complex sign based on real sign
8+
9+
This could be useful for complex step derivatives of functions that
10+
need abs. Not yet used.
11+
"""
12+
sign = (x.real >= 0) * 2 - 1
13+
return sign * x
14+
15+
616
class RobustNorm:
717
"""
818
The parent class for the norms used for robust regression.
@@ -627,23 +637,25 @@ def rho(self, z):
627637
628638
rho(z) = a*(b + c - a) for \|z\| > c
629639
"""
640+
a, b, c = self.a, self.b, self.c
630641

631-
z = np.abs(z)
632642
z_isscalar = np.isscalar(z)
633643
z = np.atleast_1d(z)
634-
a = self.a; b = self.b; c = self.c
644+
635645
t1, t2, t3 = self._subset(z)
636-
t34 = ~(t1 | t2 )
637-
v = np.zeros(z.shape, float)
646+
t34 = ~(t1 | t2)
647+
dt = np.promote_types(z.dtype, "float")
648+
v = np.zeros(z.shape, dtype=dt)
649+
z = np.abs(z)
638650
v[t1] = z[t1]**2 * 0.5
639651
v[t2] = (a * (z[t2] - a) + a**2 * 0.5)
640-
v[t3] = a * (c - z[t3])**2 / (c - b) * (-0.5)
652+
v[t3] = a * (c - z[t3])**2 / (c - b) * (-0.5)
641653
v[t34] += a * (b + c - a) * 0.5
642654

643655
if z_isscalar:
644-
return v[0]
645-
else:
646-
return v
656+
v = v[0]
657+
658+
return v
647659

648660
def psi(self, z):
649661
r"""
@@ -667,14 +679,26 @@ def psi(self, z):
667679
668680
psi(z) = 0 for \|z\| > c
669681
"""
670-
z = np.asarray(z)
671-
a = self.a; b = self.b; c = self.c
682+
a, b, c = self.a, self.b, self.c
683+
684+
z_isscalar = np.isscalar(z)
685+
z = np.atleast_1d(z)
686+
672687
t1, t2, t3 = self._subset(z)
688+
dt = np.promote_types(z.dtype, "float")
689+
v = np.zeros(z.shape, dtype=dt)
673690
s = np.sign(z)
674691
z = np.abs(z)
675-
v = (t1 * z*s +
676-
t2 * a*s +
677-
t3 * a*s * (c - z) / (c - b))
692+
693+
v[t1] = z[t1] * s[t1]
694+
v[t2] = a * s[t2]
695+
v[t3] = a * s[t3] * (c - z[t3]) / (c - b)
696+
# v = (t1 * z*s +
697+
# t2 * a*s +
698+
# t3 * a*s * (c - z) / (c - b))
699+
700+
if z_isscalar:
701+
v = v[0]
678702
return v
679703

680704
def weights(self, z):
@@ -699,32 +723,43 @@ def weights(self, z):
699723
700724
weights(z) = 0 for \|z\| > c
701725
"""
702-
z = np.asarray(z)
703-
a = self.a
704-
b = self.b
705-
c = self.c
726+
a, b, c = self.a, self.b, self.c
727+
728+
z_isscalar = np.isscalar(z)
729+
z = np.atleast_1d(z)
730+
706731
t1, t2, t3 = self._subset(z)
707732

708-
v = np.zeros_like(z)
733+
dt = np.promote_types(z.dtype, "float")
734+
v = np.zeros(z.shape, dtype=dt)
709735
v[t1] = 1.0
710736
abs_z = np.abs(z)
711737
v[t2] = a / abs_z[t2]
712738
abs_zt3 = abs_z[t3]
713739
v[t3] = a * (c - abs_zt3) / (abs_zt3 * (c - b))
714-
v[np.where(np.isnan(v))] = 1. # TODO: for some reason 0 returns a nan?
740+
741+
if z_isscalar:
742+
v = v[0]
715743
return v
716744

717745
def psi_deriv(self, z):
718746
"""Derivative of psi function, second derivative of rho function.
719747
"""
720-
t1, _, t3 = self._subset(z)
721748
a, b, c = self.a, self.b, self.c
749+
750+
z_isscalar = np.isscalar(z)
722751
z = np.atleast_1d(z)
723-
# default is t1
724-
d = np.zeros_like(z)
752+
753+
t1, _, t3 = self._subset(z)
754+
755+
dt = np.promote_types(z.dtype, "float")
756+
d = np.zeros(z.shape, dtype=dt)
725757
d[t1] = 1.0
726758
zt3 = z[t3]
727759
d[t3] = -(a * np.sign(zt3) * zt3) / (np.abs(zt3) * (c - b))
760+
761+
if z_isscalar:
762+
d = d[0]
728763
return d
729764

730765

statsmodels/robust/tests/results/results_norms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@
1111
-0.3333333333333333, 0.0],
1212
weights=[0.0, 0.1111111111111111, 0.75, 1.0, 1.0, 1.0, 0.75,
1313
0.1111111111111111, 0.0],
14-
)
14+
)

statsmodels/robust/tests/test_norms.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,54 @@
44
from numpy.testing import assert_allclose
55

66
from statsmodels.robust import norms
7+
from statsmodels.tools.numdiff import (
8+
_approx_fprime_scalar,
9+
# _approx_fprime_cs_scalar, # not yet
10+
)
711
from .results import results_norms as res_r
812

913
cases = [
1014
(norms.Hampel, (1.5, 3.5, 8.), res_r.res_hampel)
1115
]
1216

17+
dtypes = ["int", np.float64, np.complex128]
18+
19+
20+
@pytest.mark.parametrize("dtype", dtypes)
1321
@pytest.mark.parametrize("case", cases)
14-
def test_norm(case):
22+
def test_norm(case, dtype):
1523
ncls, args, res = case
1624
norm = ncls(*args)
17-
x = np.array([-9., -6, -2, -1, 0, 1, 2, 6, 9])
25+
x = np.array([-9, -6, -2, -1, 0, 1, 2, 6, 9], dtype=dtype)
26+
27+
weights = norm.weights(x)
28+
rho = norm.rho(x)
29+
psi = norm.psi(x)
30+
psi_deriv = norm.psi_deriv(x)
31+
assert_allclose(weights, res.weights, rtol=1e-12, atol=1e-20)
32+
assert_allclose(rho, res.rho, rtol=1e-12, atol=1e-20)
33+
assert_allclose(psi, res.psi, rtol=1e-12, atol=1e-20)
34+
assert_allclose(psi_deriv, res.psi_deriv, rtol=1e-12, atol=1e-20)
35+
36+
dtype2 = np.promote_types(dtype, "float")
37+
assert weights.dtype == dtype2
38+
assert rho.dtype == dtype2
39+
assert psi.dtype == dtype2
40+
assert psi_deriv.dtype == dtype2
41+
42+
psid = _approx_fprime_scalar(x, norm.rho)
43+
assert_allclose(psid, res.psi, rtol=1e-6, atol=1e-8)
44+
psidd = _approx_fprime_scalar(x, norm.psi)
45+
assert_allclose(psidd, res.psi_deriv, rtol=1e-6, atol=1e-8)
46+
47+
# complex step derivatives are not yet supported if method uses np.abs
48+
# psid = _approx_fprime_cs_scalar(x, norm.rho)
49+
# assert_allclose(psid, res.psi, rtol=1e-12, atol=1e-20)
50+
# psidd = _approx_fprime_cs_scalar(x, norm.psi)
51+
# assert_allclose(psidd, res.psi_deriv, rtol=1e-12, atol=1e-20)
1852

19-
assert_allclose(norm.weights(x), res.weights, rtol=1e-12, atol=1e-20)
20-
assert_allclose(norm.rho(x), res.rho, rtol=1e-12, atol=1e-20)
21-
assert_allclose(norm.psi(x), res.psi, rtol=1e-12, atol=1e-20)
22-
assert_allclose(norm.psi_deriv(x), res.psi_deriv, rtol=1e-12, atol=1e-20)
53+
# check scalar value
54+
methods = ["weights", "rho", "psi", "psi_deriv"]
55+
for meth in methods:
56+
resm = [getattr(norm, meth)(xi) for xi in x]
57+
assert_allclose(resm, getattr(res, meth))

0 commit comments

Comments
 (0)