Skip to content

RFC: Replace @generated_jit with @overload #701

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ exclude_lines =
@jit
@jit\(.*nopython=True
@njit
@generated_jit\(.*nopython=True
@overload
@guvectorize\(.*nopython=True
2 changes: 1 addition & 1 deletion docs/rtd-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
sphinx
ipython
numpydoc
numba>=0.38
numba>=0.49
numpy>=1.17
sympy
scipy>=1.5
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ keywords = [
dynamic = ["description", "version"]
requires-python = ">=3.7"
dependencies = [
'numba',
'numba>=0.49.0',
'numpy>=1.17.0',
'requests',
'scipy>=1.5.0',
Expand Down
11 changes: 8 additions & 3 deletions quantecon/_compute_fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import time
import warnings
import numpy as np
from numba import jit, generated_jit, types
from numba import jit, types
from numba.extending import overload
from .game_theory.lemke_howson import _lemke_howson_tbl, _get_mixed_actions


Expand Down Expand Up @@ -352,8 +353,12 @@ def _initialize_tableaux_ig(X, Y, tableaux, bases):
return tableaux, bases


@generated_jit(nopython=True, cache=True)
def _square_sum(a):
def _square_sum(a): # pragma: no cover
pass


@overload(_square_sum, jit_options={'cache':True})
def _square_sum_ol(a):
if isinstance(a, types.Number):
return lambda a: a**2
elif isinstance(a, types.Array):
Expand Down
38 changes: 25 additions & 13 deletions quantecon/random/tests/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import numpy as np
from numpy.testing import (assert_array_equal, assert_allclose, assert_raises,
assert_)
from numba import njit
from quantecon.random import probvec, sample_without_replacement, draw


Expand Down Expand Up @@ -67,32 +68,43 @@ def test_sample_without_replacement_value_error():

# draw #

@njit
def draw_jitted(cdf, size=None):
return draw(cdf, size)


class TestDraw:
def setup_method(self):
self.pmf = np.array([0.4, 0.1, 0.5])
self.cdf = np.cumsum(self.pmf)
self.n = len(self.pmf)
self.draw_funcs = [draw, draw_jitted]

def test_return_types(self):
out = draw(self.cdf)
assert_(isinstance(out, numbers.Integral))
for func in self.draw_funcs:
out = func(self.cdf)
assert_(isinstance(out, numbers.Integral))

size = 10
out = draw(self.cdf, size)
assert_(out.shape == (size,))
for func in self.draw_funcs:
out = func(self.cdf, size)
assert_(out.shape == (size,))

def test_return_values(self):
out = draw(self.cdf)
assert_(out in range(self.n))
for func in self.draw_funcs:
out = func(self.cdf)
assert_(out in range(self.n))

size = 10
out = draw(self.cdf, size)
assert_(np.isin(out, range(self.n)).all())
for func in self.draw_funcs:
out = func(self.cdf, size)
assert_(np.isin(out, range(self.n)).all())

def test_lln(self):
size = 1000000
out = draw(self.cdf, size)
hist, bin_edges = np.histogram(out, bins=self.n, density=True)
pmf_computed = hist * np.diff(bin_edges)
atol = 1e-2
assert_allclose(pmf_computed, self.pmf, atol=atol)
for func in self.draw_funcs:
out = func(self.cdf, size)
hist, bin_edges = np.histogram(out, bins=self.n, density=True)
pmf_computed = hist * np.diff(bin_edges)
atol = 1e-2
assert_allclose(pmf_computed, self.pmf, atol=atol)
22 changes: 18 additions & 4 deletions quantecon/random/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
"""

import numpy as np
from numba import guvectorize, generated_jit, types

from numba import guvectorize, types
from numba.extending import overload
from ..util import check_random_state, searchsorted


Expand Down Expand Up @@ -63,7 +63,7 @@ def probvec(m, k, random_state=None, parallel=True):
return x


def _probvec(r, out):
def _probvec(r, out): # pragma: no cover
"""
Fill `out` with randomly sampled probability vectors as rows.

Expand Down Expand Up @@ -169,7 +169,7 @@ def _sample_without_replacement(n, r, out):
pool[idx] = pool[n-j-1]


@generated_jit(nopython=True)
# Pure python implementation that will run if the JIT compiler is disabled
def draw(cdf, size=None):
"""
Generate a random sample according to the cumulative distribution
Expand Down Expand Up @@ -198,6 +198,20 @@ def draw(cdf, size=None):
array([1, 0, 1, 0, 1, 0, 0, 0, 1, 0])

"""
if isinstance(size, int):
rs = np.random.random(size)
out = np.empty(size, dtype=np.int_)
for i in range(size):
out[i] = searchsorted(cdf, rs[i])
return out
else:
r = np.random.random()
return searchsorted(cdf, r)


# Overload for the `draw` function
@overload(draw)
def ol_draw(cdf, size):
if isinstance(size, types.Integer):
def draw_impl(cdf, size):
rs = np.random.random(size)
Expand Down
17 changes: 10 additions & 7 deletions quantecon/util/numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@

"""
import numpy as np
from numba import jit, generated_jit, types
try:
from numba.np.linalg import _LAPACK # for Numba >= 0.49.0
except ModuleNotFoundError:
from numba.targets.linalg import _LAPACK # for Numba < 0.49.0
from numba import jit, types
from numba.extending import overload
from numba.np.linalg import _LAPACK


# BLAS kinds as letters
Expand All @@ -19,14 +17,19 @@
}


@generated_jit(nopython=True, cache=True)
def _numba_linalg_solve(a, b):
def _numba_linalg_solve(a, b): # pragma: no cover
pass


@overload(_numba_linalg_solve, jit_options={'cache':True})
def _numba_linalg_solve_ol(a, b):
"""
Solve the linear equation ax = b directly calling a Numba internal
function. The data in `a` and `b` are interpreted in Fortran order,
and dtype of `a` and `b` must be the same, one of {float32, float64,
complex64, complex128}. `a` and `b` are modified in place, and the
solution is stored in `b`. *No error check is made for the inputs.*
Only work in a Numba-jitted function.

Parameters
----------
Expand Down
11 changes: 8 additions & 3 deletions quantecon/util/tests/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
from quantecon.util.numba import _numba_linalg_solve, comb_jit


@jit(nopython=True)
def _numba_linalg_solve_jitted(a, b):
return _numba_linalg_solve(a, b)


@jit(nopython=True)
def numba_linalg_solve_orig(a, b):
return np.linalg.solve(a, b)
Expand All @@ -26,7 +31,7 @@ def test_b_1dim(self):
a = np.asfortranarray(self.a, dtype=dtype)
b = np.asfortranarray(self.b_1dim, dtype=dtype)
sol_orig = numba_linalg_solve_orig(a, b)
r = _numba_linalg_solve(a, b)
r = _numba_linalg_solve_jitted(a, b)
assert_(r == 0)
assert_array_equal(b, sol_orig)

Expand All @@ -35,7 +40,7 @@ def test_b_2dim(self):
a = np.asfortranarray(self.a, dtype=dtype)
b = np.asfortranarray(self.b_2dim, dtype=dtype)
sol_orig = numba_linalg_solve_orig(a, b)
r = _numba_linalg_solve(a, b)
r = _numba_linalg_solve_jitted(a, b)
assert_(r == 0)
assert_array_equal(b, sol_orig)

Expand All @@ -44,7 +49,7 @@ def test_singular_a(self):
for dtype in self.dtypes:
a = np.asfortranarray(self.a_singular, dtype=dtype)
b = np.asfortranarray(b, dtype=dtype)
r = _numba_linalg_solve(a, b)
r = _numba_linalg_solve_jitted(a, b)
assert_(r != 0)


Expand Down