Skip to content

Commit 3ddb70d

Browse files
committed
RFC: Replace @generated_jit with @overload in util.numba
1 parent 8cebd03 commit 3ddb70d

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

quantecon/util/numba.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
44
"""
55
import numpy as np
6-
from numba import jit, generated_jit, types
6+
from numba import jit, types
7+
from numba.extending import overload
78
try:
89
from numba.np.linalg import _LAPACK # for Numba >= 0.49.0
910
except ModuleNotFoundError:
@@ -19,14 +20,19 @@
1920
}
2021

2122

22-
@generated_jit(nopython=True, cache=True)
23-
def _numba_linalg_solve(a, b):
23+
def _numba_linalg_solve(a, b): # pragma: no cover
24+
pass
25+
26+
27+
@overload(_numba_linalg_solve, jit_options={'cache':True})
28+
def _numba_linalg_solve_ol(a, b):
2429
"""
2530
Solve the linear equation ax = b directly calling a Numba internal
2631
function. The data in `a` and `b` are interpreted in Fortran order,
2732
and dtype of `a` and `b` must be the same, one of {float32, float64,
2833
complex64, complex128}. `a` and `b` are modified in place, and the
2934
solution is stored in `b`. *No error check is made for the inputs.*
35+
Only work in a Numba-jitted function.
3036
3137
Parameters
3238
----------

quantecon/util/tests/test_numba.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88
from quantecon.util.numba import _numba_linalg_solve, comb_jit
99

1010

11+
@jit(nopython=True)
12+
def _numba_linalg_solve_jitted(a, b):
13+
return _numba_linalg_solve(a, b)
14+
15+
1116
@jit(nopython=True)
1217
def numba_linalg_solve_orig(a, b):
1318
return np.linalg.solve(a, b)
@@ -26,7 +31,7 @@ def test_b_1dim(self):
2631
a = np.asfortranarray(self.a, dtype=dtype)
2732
b = np.asfortranarray(self.b_1dim, dtype=dtype)
2833
sol_orig = numba_linalg_solve_orig(a, b)
29-
r = _numba_linalg_solve(a, b)
34+
r = _numba_linalg_solve_jitted(a, b)
3035
assert_(r == 0)
3136
assert_array_equal(b, sol_orig)
3237

@@ -35,7 +40,7 @@ def test_b_2dim(self):
3540
a = np.asfortranarray(self.a, dtype=dtype)
3641
b = np.asfortranarray(self.b_2dim, dtype=dtype)
3742
sol_orig = numba_linalg_solve_orig(a, b)
38-
r = _numba_linalg_solve(a, b)
43+
r = _numba_linalg_solve_jitted(a, b)
3944
assert_(r == 0)
4045
assert_array_equal(b, sol_orig)
4146

@@ -44,7 +49,7 @@ def test_singular_a(self):
4449
for dtype in self.dtypes:
4550
a = np.asfortranarray(self.a_singular, dtype=dtype)
4651
b = np.asfortranarray(b, dtype=dtype)
47-
r = _numba_linalg_solve(a, b)
52+
r = _numba_linalg_solve_jitted(a, b)
4853
assert_(r != 0)
4954

5055

0 commit comments

Comments
 (0)