-
-
Notifications
You must be signed in to change notification settings - Fork 2.3k
/
Copy pathtest_numba.py
77 lines (60 loc) · 2.33 KB
/
test_numba.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""
Tests for Numba support utilities
"""
import numpy as np
from numpy.testing import assert_array_equal, assert_
from numba import jit
from quantecon.util.numba import _numba_linalg_solve, comb_jit
@jit(nopython=True)
def _numba_linalg_solve_jitted(a, b):
    """Invoke `_numba_linalg_solve` from within nopython-compiled code.

    The status value produced by `_numba_linalg_solve` is passed back
    unchanged; the tests below treat 0 as success and nonzero as failure.
    """
    status = _numba_linalg_solve(a, b)
    return status
@jit(nopython=True)
def numba_linalg_solve_orig(a, b):
    """Reference solver: `np.linalg.solve` compiled in nopython mode.

    Its result is what the tests compare `_numba_linalg_solve` against.
    """
    solution = np.linalg.solve(a, b)
    return solution
class TestNumbaLinalgSolve:
    """Tests for `_numba_linalg_solve`, checked against `np.linalg.solve`."""

    def setup_method(self):
        self.dtypes = [np.float32, np.float64]
        self.a = np.array([[3, 2, 0], [1, -1, 0], [0, 5, 1]])
        self.b_1dim = np.array([2, 4, -1])
        self.b_2dim = np.array([[2, 3], [4, 1], [-1, 0]])
        # Singular: the third row is the sum of the first two rows.
        self.a_singular = np.array([[0, 1, 2], [3, 4, 5], [3, 5, 7]])

    def _check_solve(self, b_src):
        """Solve `self.a x = b_src` for every dtype and compare with NumPy.

        After the call, the solver's `b` argument is expected to hold the
        solution (that is what the final assertion checks). Fresh
        Fortran-ordered copies are therefore built from `self.a` and
        `b_src` on every iteration.
        """
        for dtype in self.dtypes:
            a = np.asfortranarray(self.a, dtype=dtype)
            b = np.asfortranarray(b_src, dtype=dtype)
            # Compute the reference first, before the in-place solver runs.
            sol_orig = numba_linalg_solve_orig(a, b)
            r = _numba_linalg_solve_jitted(a, b)
            assert_(r == 0)
            assert_array_equal(b, sol_orig)

    def test_b_1dim(self):
        self._check_solve(self.b_1dim)

    def test_b_2dim(self):
        self._check_solve(self.b_2dim)

    def test_singular_a(self):
        for b_src in [self.b_1dim, self.b_2dim]:
            for dtype in self.dtypes:
                a = np.asfortranarray(self.a_singular, dtype=dtype)
                # Use a name distinct from the loop variable. Rebinding it
                # would derive the next dtype's right-hand side from the
                # array the solver just received, not from the original data.
                b = np.asfortranarray(b_src, dtype=dtype)
                r = _numba_linalg_solve_jitted(a, b)
                assert_(r != 0)
class TestCombJit:
    """Tests for `comb_jit` (N choose k)."""

    def setup_method(self):
        self.MAX_INTP = np.iinfo(np.intp).max

    def test_comb(self):
        N, k = 10, 3
        N_choose_k = 120
        assert_(comb_jit(N, k) == N_choose_k)

    def test_comb_zeros(self):
        # Out-of-range arguments: k > N, negative N, negative k.
        assert_(comb_jit(2, 3) == 0)
        assert_(comb_jit(-1, 3) == 0)
        assert_(comb_jit(2, -1) == 0)
        # Results that do not fit in np.intp are expected to come back as 0.
        assert_(comb_jit(self.MAX_INTP, 2) == 0)
        # Find the smallest N with N * (N - 1) / 2 > MAX_INTP using exact
        # Python-int arithmetic. A pure float estimate,
        # intp(sqrt(2 * MAX_INTP)) + 1, can land on an N whose exact
        # binomial coefficient still fits: 65536 when intp is 32-bit.
        N = int((2 * self.MAX_INTP) ** 0.5)
        while N * (N - 1) // 2 <= self.MAX_INTP:
            N += 1
        assert_(comb_jit(np.intp(N), 2) == 0)

    def test_max_intp(self):
        assert_(comb_jit(self.MAX_INTP, 0) == 1)
        assert_(comb_jit(self.MAX_INTP, 1) == self.MAX_INTP)