|
11 | 11 | import numpy as np
|
12 | 12 | from numpy.testing import (assert_array_equal, assert_allclose, assert_raises,
|
13 | 13 | assert_)
|
| 14 | +from numba import njit |
14 | 15 | from quantecon.random import probvec, sample_without_replacement, draw
|
15 | 16 |
|
16 | 17 |
|
@@ -67,32 +68,43 @@ def test_sample_without_replacement_value_error():
|
67 | 68 |
|
68 | 69 | # draw #
|
69 | 70 |
|
| 71 | +@njit |
| 72 | +def draw_jitted(cdf, size=None): |
| 73 | + return draw(cdf, size) |
| 74 | + |
| 75 | + |
70 | 76 | class TestDraw:
|
71 | 77 | def setup_method(self):
|
72 | 78 | self.pmf = np.array([0.4, 0.1, 0.5])
|
73 | 79 | self.cdf = np.cumsum(self.pmf)
|
74 | 80 | self.n = len(self.pmf)
|
| 81 | + self.draw_funcs = [draw, draw_jitted] |
75 | 82 |
|
76 | 83 | def test_return_types(self):
|
77 |
| - out = draw(self.cdf) |
78 |
| - assert_(isinstance(out, numbers.Integral)) |
| 84 | + for func in self.draw_funcs: |
| 85 | + out = func(self.cdf) |
| 86 | + assert_(isinstance(out, numbers.Integral)) |
79 | 87 |
|
80 | 88 | size = 10
|
81 |
| - out = draw(self.cdf, size) |
82 |
| - assert_(out.shape == (size,)) |
| 89 | + for func in self.draw_funcs: |
| 90 | + out = func(self.cdf, size) |
| 91 | + assert_(out.shape == (size,)) |
83 | 92 |
|
84 | 93 | def test_return_values(self):
|
85 |
| - out = draw(self.cdf) |
86 |
| - assert_(out in range(self.n)) |
| 94 | + for func in self.draw_funcs: |
| 95 | + out = func(self.cdf) |
| 96 | + assert_(out in range(self.n)) |
87 | 97 |
|
88 | 98 | size = 10
|
89 |
| - out = draw(self.cdf, size) |
90 |
| - assert_(np.isin(out, range(self.n)).all()) |
| 99 | + for func in self.draw_funcs: |
| 100 | + out = func(self.cdf, size) |
| 101 | + assert_(np.isin(out, range(self.n)).all()) |
91 | 102 |
|
92 | 103 | def test_lln(self):
|
93 | 104 | size = 1000000
|
94 |
| - out = draw(self.cdf, size) |
95 |
| - hist, bin_edges = np.histogram(out, bins=self.n, density=True) |
96 |
| - pmf_computed = hist * np.diff(bin_edges) |
97 |
| - atol = 1e-2 |
98 |
| - assert_allclose(pmf_computed, self.pmf, atol=atol) |
| 105 | + for func in self.draw_funcs: |
| 106 | + out = func(self.cdf, size) |
| 107 | + hist, bin_edges = np.histogram(out, bins=self.n, density=True) |
| 108 | + pmf_computed = hist * np.diff(bin_edges) |
| 109 | + atol = 1e-2 |
| 110 | + assert_allclose(pmf_computed, self.pmf, atol=atol) |
0 commit comments