Skip to content

Commit 4acb254

Browse files
committed
Refactor numba lapack codegen
1 parent 39704d1 commit 4acb254

File tree

16 files changed

+1271
-1070
lines changed

16 files changed

+1271
-1070
lines changed

pytensor/link/numba/dispatch/_LAPACK.py renamed to pytensor/link/numba/dispatch/linalg/_LAPACK.py

+67
Original file line numberDiff line numberDiff line change
@@ -390,3 +390,70 @@ def numba_xposv(cls, dtype):
390390
_ptr_int, # INFO
391391
)
392392
return functype(lapack_ptr)
393+
394+
@classmethod
395+
def numba_xgttrf(cls, dtype):
396+
"""
397+
Compute the LU factorization of a tridiagonal matrix A using row interchanges.
398+
399+
Called by scipy.linalg.lu_factor
400+
"""
401+
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "gttrf")
402+
functype = ctypes.CFUNCTYPE(
403+
None,
404+
_ptr_int, # N
405+
float_pointer, # DL
406+
float_pointer, # D
407+
float_pointer, # DU
408+
float_pointer, # DU2
409+
_ptr_int, # IPIV
410+
_ptr_int, # INFO
411+
)
412+
return functype(lapack_ptr)
413+
414+
@classmethod
415+
def numba_xgttrs(cls, dtype):
416+
"""
417+
Solve a system of linear equations A @ X = B with a tridiagonal matrix A using the LU factorization computed by numba_gttrf.
418+
419+
Called by scipy.linalg.lu_solve
420+
"""
421+
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "gttrs")
422+
functype = ctypes.CFUNCTYPE(
423+
None,
424+
_ptr_int, # TRANS
425+
_ptr_int, # N
426+
_ptr_int, # NRHS
427+
float_pointer, # DL
428+
float_pointer, # D
429+
float_pointer, # DU
430+
float_pointer, # DU2
431+
_ptr_int, # IPIV
432+
float_pointer, # B
433+
_ptr_int, # LDB
434+
_ptr_int, # INFO
435+
)
436+
return functype(lapack_ptr)
437+
438+
@classmethod
439+
def numba_xgtcon(cls, dtype):
440+
"""
441+
Estimate the reciprocal of the condition number of a tridiagonal matrix A using the LU factorization computed by numba_gttrf.
442+
"""
443+
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "gtcon")
444+
functype = ctypes.CFUNCTYPE(
445+
None,
446+
_ptr_int, # NORM
447+
_ptr_int, # N
448+
float_pointer, # DL
449+
float_pointer, # D
450+
float_pointer, # DU
451+
float_pointer, # DU2
452+
_ptr_int, # IPIV
453+
float_pointer, # ANORM
454+
float_pointer, # RCOND
455+
float_pointer, # WORK
456+
_ptr_int, # IWORK
457+
_ptr_int, # INFO
458+
)
459+
return functype(lapack_ptr)

pytensor/link/numba/dispatch/linalg/__init__.py

Whitespace-only changes.

pytensor/link/numba/dispatch/linalg/decomposition/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import numpy as np
2+
from numba.core.extending import overload
3+
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
4+
from scipy import linalg
5+
6+
from pytensor.link.numba.dispatch.linalg._LAPACK import (
7+
_LAPACK,
8+
_get_underlying_float,
9+
int_ptr_to_val,
10+
val_to_int_ptr,
11+
)
12+
from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix
13+
14+
15+
def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
16+
return (
17+
linalg.cholesky(
18+
a, lower=lower, overwrite_a=overwrite_a, check_finite=check_finite
19+
),
20+
0,
21+
)
22+
23+
24+
@overload(_cholesky)
25+
def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
26+
ensure_lapack()
27+
_check_scipy_linalg_matrix(A, "cholesky")
28+
dtype = A.dtype
29+
w_type = _get_underlying_float(dtype)
30+
numba_potrf = _LAPACK().numba_xpotrf(dtype)
31+
32+
def impl(A, lower=0, overwrite_a=False, check_finite=True):
33+
_N = np.int32(A.shape[-1])
34+
if A.shape[-2] != _N:
35+
raise linalg.LinAlgError("Last 2 dimensions of A must be square")
36+
37+
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
38+
N = val_to_int_ptr(_N)
39+
LDA = val_to_int_ptr(_N)
40+
INFO = val_to_int_ptr(0)
41+
42+
if overwrite_a and A.flags.f_contiguous:
43+
A_copy = A
44+
else:
45+
A_copy = _copy_to_fortran_order(A)
46+
47+
numba_potrf(
48+
UPLO,
49+
N,
50+
A_copy.view(w_type).ctypes,
51+
LDA,
52+
INFO,
53+
)
54+
55+
if lower:
56+
for j in range(1, _N):
57+
for i in range(j):
58+
A_copy[i, j] = 0.0
59+
else:
60+
for j in range(_N):
61+
for i in range(j + 1, _N):
62+
A_copy[i, j] = 0.0
63+
64+
return A_copy, int_ptr_to_val(INFO)
65+
66+
return impl

pytensor/link/numba/dispatch/linalg/solve/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import numpy as np
2+
from numba.core.extending import overload
3+
from numba.np.linalg import ensure_lapack
4+
from scipy import linalg
5+
6+
from pytensor.link.numba.dispatch.linalg._LAPACK import (
7+
_LAPACK,
8+
_get_underlying_float,
9+
int_ptr_to_val,
10+
val_to_int_ptr,
11+
)
12+
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
13+
from pytensor.link.numba.dispatch.linalg.utils import (
14+
_check_scipy_linalg_matrix,
15+
_copy_to_fortran_order_even_if_1d,
16+
_solve_check,
17+
)
18+
19+
20+
def _cho_solve(
21+
C: np.ndarray, B: np.ndarray, lower: bool, overwrite_b: bool, check_finite: bool
22+
):
23+
"""
24+
Solve a positive-definite linear system using the Cholesky decomposition.
25+
"""
26+
return linalg.cho_solve(
27+
(C, lower), b=B, overwrite_b=overwrite_b, check_finite=check_finite
28+
)
29+
30+
31+
@overload(_cho_solve)
32+
def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
33+
ensure_lapack()
34+
_check_scipy_linalg_matrix(C, "cho_solve")
35+
_check_scipy_linalg_matrix(B, "cho_solve")
36+
dtype = C.dtype
37+
w_type = _get_underlying_float(dtype)
38+
numba_potrs = _LAPACK().numba_xpotrs(dtype)
39+
40+
def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
41+
_solve_check_input_shapes(C, B)
42+
43+
_N = np.int32(C.shape[-1])
44+
if C.flags.f_contiguous or C.flags.c_contiguous:
45+
C_f = C
46+
if C.flags.c_contiguous:
47+
# An upper/lower triangular c_contiguous is the same as a lower/upper triangular f_contiguous
48+
lower = not lower
49+
else:
50+
C_f = np.asfortranarray(C)
51+
52+
if overwrite_b and B.flags.f_contiguous:
53+
B_copy = B
54+
else:
55+
B_copy = _copy_to_fortran_order_even_if_1d(B)
56+
57+
B_is_1d = B.ndim == 1
58+
if B_is_1d:
59+
B_copy = np.expand_dims(B_copy, -1)
60+
61+
NRHS = 1 if B_is_1d else int(B.shape[-1])
62+
63+
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
64+
N = val_to_int_ptr(_N)
65+
NRHS = val_to_int_ptr(NRHS)
66+
LDA = val_to_int_ptr(_N)
67+
LDB = val_to_int_ptr(_N)
68+
INFO = val_to_int_ptr(0)
69+
70+
numba_potrs(
71+
UPLO,
72+
N,
73+
NRHS,
74+
C_f.view(w_type).ctypes,
75+
LDA,
76+
B_copy.view(w_type).ctypes,
77+
LDB,
78+
INFO,
79+
)
80+
81+
_solve_check(_N, int_ptr_to_val(INFO))
82+
83+
if B_is_1d:
84+
return B_copy[..., 0]
85+
return B_copy
86+
87+
return impl

0 commit comments

Comments
 (0)