Skip to content

Commit 8e241a0

Browse files
committed
RFC: Replace @generated_jit with @overload in random.draw
To avoid `NumbaDeprecationWarning`
1 parent 482f64a commit 8e241a0

File tree

2 files changed

+42
-16
lines changed

2 files changed

+42
-16
lines changed

quantecon/random/tests/test_utilities.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import numpy as np
1212
from numpy.testing import (assert_array_equal, assert_allclose, assert_raises,
1313
assert_)
14+
from numba import njit
1415
from quantecon.random import probvec, sample_without_replacement, draw
1516

1617

@@ -67,32 +68,43 @@ def test_sample_without_replacement_value_error():
6768

6869
# draw #
6970

71+
@njit
72+
def draw_jitted(cdf, size=None):
73+
return draw(cdf, size)
74+
75+
7076
class TestDraw:
7177
def setup_method(self):
7278
self.pmf = np.array([0.4, 0.1, 0.5])
7379
self.cdf = np.cumsum(self.pmf)
7480
self.n = len(self.pmf)
81+
self.draw_funcs = [draw, draw_jitted]
7582

7683
def test_return_types(self):
77-
out = draw(self.cdf)
78-
assert_(isinstance(out, numbers.Integral))
84+
for func in self.draw_funcs:
85+
out = func(self.cdf)
86+
assert_(isinstance(out, numbers.Integral))
7987

8088
size = 10
81-
out = draw(self.cdf, size)
82-
assert_(out.shape == (size,))
89+
for func in self.draw_funcs:
90+
out = func(self.cdf, size)
91+
assert_(out.shape == (size,))
8392

8493
def test_return_values(self):
85-
out = draw(self.cdf)
86-
assert_(out in range(self.n))
94+
for func in self.draw_funcs:
95+
out = func(self.cdf)
96+
assert_(out in range(self.n))
8797

8898
size = 10
89-
out = draw(self.cdf, size)
90-
assert_(np.isin(out, range(self.n)).all())
99+
for func in self.draw_funcs:
100+
out = func(self.cdf, size)
101+
assert_(np.isin(out, range(self.n)).all())
91102

92103
def test_lln(self):
93104
size = 1000000
94-
out = draw(self.cdf, size)
95-
hist, bin_edges = np.histogram(out, bins=self.n, density=True)
96-
pmf_computed = hist * np.diff(bin_edges)
97-
atol = 1e-2
98-
assert_allclose(pmf_computed, self.pmf, atol=atol)
105+
for func in self.draw_funcs:
106+
out = func(self.cdf, size)
107+
hist, bin_edges = np.histogram(out, bins=self.n, density=True)
108+
pmf_computed = hist * np.diff(bin_edges)
109+
atol = 1e-2
110+
assert_allclose(pmf_computed, self.pmf, atol=atol)

quantecon/random/utilities.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
"""
55

66
import numpy as np
7-
from numba import guvectorize, generated_jit, types
8-
7+
from numba import guvectorize, types
8+
from numba.extending import overload
99
from ..util import check_random_state, searchsorted
1010

1111

@@ -169,7 +169,7 @@ def _sample_without_replacement(n, r, out):
169169
pool[idx] = pool[n-j-1]
170170

171171

172-
@generated_jit(nopython=True)
172+
# Pure python implementation that will run if the JIT compiler is disabled
173173
def draw(cdf, size=None):
174174
"""
175175
Generate a random sample according to the cumulative distribution
@@ -198,6 +198,20 @@ def draw(cdf, size=None):
198198
array([1, 0, 1, 0, 1, 0, 0, 0, 1, 0])
199199
200200
"""
201+
if isinstance(size, int):
202+
rs = np.random.random(size)
203+
out = np.empty(size, dtype=np.int_)
204+
for i in range(size):
205+
out[i] = searchsorted(cdf, rs[i])
206+
return out
207+
else:
208+
r = np.random.random()
209+
return searchsorted(cdf, r)
210+
211+
212+
# Overload for the `draw` function
213+
@overload(draw)
214+
def ol_draw(cdf, size):
201215
if isinstance(size, types.Integer):
202216
def draw_impl(cdf, size):
203217
rs = np.random.random(size)

0 commit comments

Comments
 (0)