Skip to content

Commit f49c692

Browse files
authored
Merge pull request statsmodels#4995 from jbrockmendel/main_tools4
TST make tools.linalg __main__ section into tests
2 parents d2e4a83 + 1775ac1 commit f49c692

File tree

3 files changed

+30
-22
lines changed

3 files changed

+30
-22
lines changed

lint.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ if [ "$LINT" == true ]; then
2222
statsmodels/regression/mixed_linear_model.py \
2323
statsmodels/duration/__init__.py \
2424
statsmodels/regression/recursive_ls.py \
25+
statsmodels/tools/linalg.py \
26+
statsmodels/tools/tests/test_linalg.py \
2527
conftest.py
2628
if [ $? -ne "0" ]; then
2729
RET=1

statsmodels/tools/linalg.py

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from numpy.linalg import LinAlgError
1717

1818

19-
### Linear Least Squares
19+
# Linear Least Squares
2020

2121
def lstsq(a, b, cond=None, overwrite_a=0, overwrite_b=0):
2222
"""Compute least-squares solution to equation :m:`a x = b`
@@ -246,11 +246,11 @@ def stationary_solve(r, b):
246246
dim = b.ndim
247247
if b.ndim == 1:
248248
b = b[:, None]
249-
x = b[0:1,:]
249+
x = b[0:1, :]
250250

251251
for j in range(1, len(b)):
252252
rf = r[0:j][::-1]
253-
a = (b[j,:] - np.dot(rf, x)) / (1 - np.dot(rf, db[::-1]))
253+
a = (b[j, :] - np.dot(rf, x)) / (1 - np.dot(rf, db[::-1]))
254254
z = x - np.outer(db[::-1], a)
255255
x = np.concatenate((z, a[None, :]), axis=0)
256256

@@ -266,22 +266,3 @@ def stationary_solve(r, b):
266266
x = x[:, 0]
267267

268268
return x
269-
270-
271-
272-
if __name__ == '__main__':
273-
#for checking only,
274-
#Note on Windows32:
275-
# linalg doesn't always produce the same results in each call
276-
import scipy.linalg
277-
a0 = np.random.randn(100,10)
278-
b0 = a0.sum(1)[:, None] + np.random.randn(100,3)
279-
lstsq(a0,b0)
280-
pinv(a0)
281-
pinv2(a0)
282-
x = pinv(a0)
283-
x2=scipy.linalg.pinv(a0)
284-
print(np.max(np.abs(x-x2)))
285-
x = pinv2(a0)
286-
x2 = scipy.linalg.pinv2(a0)
287-
print(np.max(np.abs(x-x2)))

statsmodels/tools/tests/test_linalg.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from numpy.testing import assert_allclose
44
from scipy.linalg import toeplitz
55

6+
67
def test_stationary_solve_1d():
78
b = np.random.uniform(size=10)
89
r = np.random.uniform(size=9)
@@ -21,3 +22,27 @@ def test_stationary_solve_2d():
2122
soln = np.linalg.solve(tmat, b)
2223
soln1 = linalg.stationary_solve(r, b)
2324
assert_allclose(soln, soln1, rtol=1e-5, atol=1e-5)
25+
26+
27+
def test_scipy_equivalence():
28+
# This test was moved from the __main__ section of tools.linalg
29+
# Note on Windows32:
30+
# linalg doesn't always produce the same results in each call
31+
import scipy.linalg
32+
a0 = np.random.randn(100, 10)
33+
b0 = a0.sum(1)[:, None] + np.random.randn(100, 3)
34+
35+
result = linalg.pinv(a0)
36+
expected = scipy.linalg.pinv(a0)
37+
assert_allclose(result, expected)
38+
39+
result = linalg.pinv2(a0)
40+
expected = scipy.linalg.pinv2(a0)
41+
assert_allclose(result, expected)
42+
43+
result = linalg.lstsq(a0, b0)
44+
expected = scipy.linalg.lstsq(a0, b0)
45+
assert_allclose(result[0], expected[0])
46+
assert_allclose(result[1], expected[1])
47+
assert_allclose(result[2], expected[2])
48+
assert_allclose(result[3], expected[3])

0 commit comments

Comments
 (0)