|
11 | 11 | import numpy as np |
12 | 12 | from numpy.testing import (assert_array_equal, assert_allclose, assert_raises, |
13 | 13 | assert_) |
| 14 | +from numba import njit |
14 | 15 | from quantecon.random import probvec, sample_without_replacement, draw |
15 | 16 |
|
16 | 17 |
|
@@ -67,32 +68,43 @@ def test_sample_without_replacement_value_error(): |
67 | 68 |
|
68 | 69 | # draw # |
69 | 70 |
|
| 71 | +@njit |
| 72 | +def draw_jitted(cdf, size=None): |
| 73 | + return draw(cdf, size) |
| 74 | + |
| 75 | + |
70 | 76 | class TestDraw: |
71 | 77 | def setup_method(self): |
72 | 78 | self.pmf = np.array([0.4, 0.1, 0.5]) |
73 | 79 | self.cdf = np.cumsum(self.pmf) |
74 | 80 | self.n = len(self.pmf) |
| 81 | + self.draw_funcs = [draw, draw_jitted] |
75 | 82 |
|
76 | 83 | def test_return_types(self): |
77 | | - out = draw(self.cdf) |
78 | | - assert_(isinstance(out, numbers.Integral)) |
| 84 | + for func in self.draw_funcs: |
| 85 | + out = func(self.cdf) |
| 86 | + assert_(isinstance(out, numbers.Integral)) |
79 | 87 |
|
80 | 88 | size = 10 |
81 | | - out = draw(self.cdf, size) |
82 | | - assert_(out.shape == (size,)) |
| 89 | + for func in self.draw_funcs: |
| 90 | + out = func(self.cdf, size) |
| 91 | + assert_(out.shape == (size,)) |
83 | 92 |
|
84 | 93 | def test_return_values(self): |
85 | | - out = draw(self.cdf) |
86 | | - assert_(out in range(self.n)) |
| 94 | + for func in self.draw_funcs: |
| 95 | + out = func(self.cdf) |
| 96 | + assert_(out in range(self.n)) |
87 | 97 |
|
88 | 98 | size = 10 |
89 | | - out = draw(self.cdf, size) |
90 | | - assert_(np.isin(out, range(self.n)).all()) |
| 99 | + for func in self.draw_funcs: |
| 100 | + out = func(self.cdf, size) |
| 101 | + assert_(np.isin(out, range(self.n)).all()) |
91 | 102 |
|
92 | 103 | def test_lln(self): |
93 | 104 | size = 1000000 |
94 | | - out = draw(self.cdf, size) |
95 | | - hist, bin_edges = np.histogram(out, bins=self.n, density=True) |
96 | | - pmf_computed = hist * np.diff(bin_edges) |
97 | | - atol = 1e-2 |
98 | | - assert_allclose(pmf_computed, self.pmf, atol=atol) |
| 105 | + for func in self.draw_funcs: |
| 106 | + out = func(self.cdf, size) |
| 107 | + hist, bin_edges = np.histogram(out, bins=self.n, density=True) |
| 108 | + pmf_computed = hist * np.diff(bin_edges) |
| 109 | + atol = 1e-2 |
| 110 | + assert_allclose(pmf_computed, self.pmf, atol=atol) |
0 commit comments