Skip to content

Commit a038c8e

Browse files
committed
Implement numba tridiagonal solve
1 parent 1902354 commit a038c8e

File tree

3 files changed

+320
-3
lines changed

3 files changed

+320
-3
lines changed
+299
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
from collections.abc import Callable
2+
3+
import numpy as np
4+
from numba.core.extending import overload
5+
from numba.np.linalg import ensure_lapack
6+
from numpy import ndarray
7+
from scipy import linalg
8+
9+
from pytensor.link.numba.dispatch.basic import numba_njit
10+
from pytensor.link.numba.dispatch.linalg._LAPACK import (
11+
_LAPACK,
12+
_get_underlying_float,
13+
int_ptr_to_val,
14+
val_to_int_ptr,
15+
)
16+
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
17+
from pytensor.link.numba.dispatch.linalg.utils import (
18+
_check_scipy_linalg_matrix,
19+
_copy_to_fortran_order_even_if_1d,
20+
_solve_check,
21+
_trans_char_to_int,
22+
)
23+
24+
25+
@numba_njit
26+
def tridiagonal_norm(du, d, dl):
27+
# Adapted from scipy _matrix_norm_tridiagonal:
28+
# https://github.com/scipy/scipy/blob/0f1fd4a7268b813fa2b844ca6038e4dfdf90084a/scipy/linalg/_basic.py#L356-L367
29+
anorm = np.abs(d)
30+
anorm[1:] += np.abs(du)
31+
anorm[:-1] += np.abs(dl)
32+
anorm = anorm.max()
33+
return anorm
34+
35+
36+
def _gttrf(
37+
dl: ndarray, d: ndarray, du: ndarray
38+
) -> tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]:
39+
"""Placeholder for LU factorization of tridiagonal matrix."""
40+
return # type: ignore
41+
42+
43+
@overload(_gttrf)
44+
def gttrf_impl(
45+
dl: ndarray,
46+
d: ndarray,
47+
du: ndarray,
48+
) -> Callable[
49+
[ndarray, ndarray, ndarray], tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]
50+
]:
51+
ensure_lapack()
52+
_check_scipy_linalg_matrix(dl, "gttrf")
53+
_check_scipy_linalg_matrix(d, "gttrf")
54+
_check_scipy_linalg_matrix(du, "gttrf")
55+
dtype = d.dtype
56+
w_type = _get_underlying_float(dtype)
57+
numba_gttrf = _LAPACK().numba_xgttrf(dtype)
58+
59+
def impl(
60+
dl: ndarray,
61+
d: ndarray,
62+
du: ndarray,
63+
) -> tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]:
64+
n = np.int32(d.shape[-1])
65+
ipiv = np.empty(n, dtype=np.int32)
66+
du2 = np.empty(n - 2, dtype=dtype)
67+
info = val_to_int_ptr(0)
68+
69+
numba_gttrf(
70+
val_to_int_ptr(n),
71+
dl.view(w_type).ctypes,
72+
d.view(w_type).ctypes,
73+
du.view(w_type).ctypes,
74+
du2.view(w_type).ctypes,
75+
ipiv.ctypes,
76+
info,
77+
)
78+
79+
return dl, d, du, du2, ipiv, int_ptr_to_val(info)
80+
81+
return impl
82+
83+
84+
def _gttrs(
85+
dl: ndarray,
86+
d: ndarray,
87+
du: ndarray,
88+
du2: ndarray,
89+
ipiv: ndarray,
90+
b: ndarray,
91+
overwrite_b: bool,
92+
trans: bool,
93+
) -> tuple[ndarray, int]:
94+
"""Placeholder for solving an LU-decomposed tridiagonal system."""
95+
return # type: ignore
96+
97+
98+
@overload(_gttrs)
99+
def gttrs_impl(
100+
dl: ndarray,
101+
d: ndarray,
102+
du: ndarray,
103+
du2: ndarray,
104+
ipiv: ndarray,
105+
b: ndarray,
106+
overwrite_b: bool,
107+
trans: bool,
108+
) -> Callable[
109+
[ndarray, ndarray, ndarray, ndarray, ndarray, ndarray, bool, bool],
110+
tuple[ndarray, int],
111+
]:
112+
ensure_lapack()
113+
_check_scipy_linalg_matrix(dl, "gttrs")
114+
_check_scipy_linalg_matrix(d, "gttrs")
115+
_check_scipy_linalg_matrix(du, "gttrs")
116+
_check_scipy_linalg_matrix(du2, "gttrs")
117+
_check_scipy_linalg_matrix(b, "gttrs")
118+
dtype = d.dtype
119+
w_type = _get_underlying_float(dtype)
120+
numba_gttrs = _LAPACK().numba_xgttrs(dtype)
121+
122+
def impl(
123+
dl: ndarray,
124+
d: ndarray,
125+
du: ndarray,
126+
du2: ndarray,
127+
ipiv: ndarray,
128+
b: ndarray,
129+
overwrite_b: bool,
130+
trans: bool,
131+
) -> tuple[ndarray, int]:
132+
n = np.int32(d.shape[-1])
133+
nrhs = 1 if b.ndim == 1 else int(b.shape[-1])
134+
info = val_to_int_ptr(0)
135+
136+
if overwrite_b and b.flags.f_contiguous:
137+
b_copy = b
138+
else:
139+
b_copy = _copy_to_fortran_order_even_if_1d(b)
140+
141+
numba_gttrs(
142+
val_to_int_ptr(_trans_char_to_int(trans)),
143+
val_to_int_ptr(n),
144+
val_to_int_ptr(nrhs),
145+
dl.view(w_type).ctypes,
146+
d.view(w_type).ctypes,
147+
du.view(w_type).ctypes,
148+
du2.view(w_type).ctypes,
149+
ipiv.ctypes,
150+
b_copy.view(w_type).ctypes,
151+
val_to_int_ptr(n),
152+
info,
153+
)
154+
155+
return b_copy, int_ptr_to_val(info)
156+
157+
return impl
158+
159+
160+
def _gtcon(
161+
dl: ndarray,
162+
d: ndarray,
163+
du: ndarray,
164+
du2: ndarray,
165+
ipiv: ndarray,
166+
anorm: float,
167+
norm: str,
168+
) -> tuple[ndarray, int]:
169+
"""Placeholder for computing the condition number of a tridiagonal system."""
170+
return # type: ignore
171+
172+
173+
@overload(_gtcon)
174+
def gtcon_impl(
175+
dl: ndarray,
176+
d: ndarray,
177+
du: ndarray,
178+
du2: ndarray,
179+
ipiv: ndarray,
180+
anorm: float,
181+
norm: str,
182+
) -> Callable[
183+
[ndarray, ndarray, ndarray, ndarray, ndarray, float, str], tuple[ndarray, int]
184+
]:
185+
ensure_lapack()
186+
_check_scipy_linalg_matrix(dl, "gtcon")
187+
_check_scipy_linalg_matrix(d, "gtcon")
188+
_check_scipy_linalg_matrix(du, "gtcon")
189+
_check_scipy_linalg_matrix(du2, "gtcon")
190+
dtype = d.dtype
191+
w_type = _get_underlying_float(dtype)
192+
numba_gtcon = _LAPACK().numba_xgtcon(dtype)
193+
194+
def impl(
195+
dl: ndarray,
196+
d: ndarray,
197+
du: ndarray,
198+
du2: ndarray,
199+
ipiv: ndarray,
200+
anorm: float,
201+
norm: str,
202+
) -> tuple[ndarray, int]:
203+
n = np.int32(d.shape[-1])
204+
rcond = np.empty(1, dtype=dtype)
205+
work = np.empty(2 * n, dtype=dtype)
206+
iwork = np.empty(n, dtype=np.int32)
207+
info = val_to_int_ptr(0)
208+
209+
numba_gtcon(
210+
val_to_int_ptr(ord(norm)),
211+
val_to_int_ptr(n),
212+
dl.view(w_type).ctypes,
213+
d.view(w_type).ctypes,
214+
du.view(w_type).ctypes,
215+
du2.view(w_type).ctypes,
216+
ipiv.ctypes,
217+
np.array(anorm, dtype=dtype).view(w_type).ctypes,
218+
rcond.view(w_type).ctypes,
219+
work.view(w_type).ctypes,
220+
iwork.ctypes,
221+
info,
222+
)
223+
224+
return rcond, int_ptr_to_val(info)
225+
226+
return impl
227+
228+
229+
def _solve_tridiagonal(
230+
a: ndarray,
231+
b: ndarray,
232+
lower: bool,
233+
overwrite_a: bool,
234+
overwrite_b: bool,
235+
check_finite: bool,
236+
transposed: bool,
237+
):
238+
"""
239+
Solve a positive-definite linear system using the Cholesky decomposition.
240+
"""
241+
return linalg.solve(
242+
a=a,
243+
b=b,
244+
lower=lower,
245+
overwrite_a=overwrite_a,
246+
overwrite_b=overwrite_b,
247+
check_finite=check_finite,
248+
transposed=transposed,
249+
assume_a="tridiagonal",
250+
)
251+
252+
253+
@overload(_solve_tridiagonal)
254+
def _tridiagonal_solve_impl(
255+
A: ndarray,
256+
B: ndarray,
257+
lower: bool,
258+
overwrite_a: bool,
259+
overwrite_b: bool,
260+
check_finite: bool,
261+
transposed: bool,
262+
) -> Callable[[ndarray, ndarray, bool, bool, bool, bool, bool], ndarray]:
263+
ensure_lapack()
264+
_check_scipy_linalg_matrix(A, "solve")
265+
_check_scipy_linalg_matrix(B, "solve")
266+
267+
def impl(
268+
A: ndarray,
269+
B: ndarray,
270+
lower: bool,
271+
overwrite_a: bool,
272+
overwrite_b: bool,
273+
check_finite: bool,
274+
transposed: bool,
275+
) -> ndarray:
276+
n = np.int32(A.shape[-1])
277+
_solve_check_input_shapes(A, B)
278+
norm = "1"
279+
280+
if transposed:
281+
A = A.T
282+
dl, d, du = np.diag(A, -1), np.diag(A, 0), np.diag(A, 1)
283+
284+
anorm = tridiagonal_norm(du, d, dl)
285+
286+
dl, d, du, du2, IPIV, INFO = _gttrf(dl, d, du)
287+
_solve_check(n, INFO)
288+
289+
X, INFO = _gttrs(
290+
dl, d, du, du2, IPIV, B, trans=transposed, overwrite_b=overwrite_b
291+
)
292+
_solve_check(n, INFO)
293+
294+
RCOND, INFO = _gtcon(dl, d, du, du2, IPIV, anorm, norm)
295+
_solve_check(n, INFO, True, RCOND)
296+
297+
return X
298+
299+
return impl

Diff for: pytensor/link/numba/dispatch/slinalg.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pytensor.link.numba.dispatch.linalg.solve.posdef import _solve_psd
1010
from pytensor.link.numba.dispatch.linalg.solve.symmetric import _solve_symmetric
1111
from pytensor.link.numba.dispatch.linalg.solve.triangular import _solve_triangular
12+
from pytensor.link.numba.dispatch.linalg.solve.tridiagonal import _solve_tridiagonal
1213
from pytensor.tensor.slinalg import (
1314
BlockDiagonal,
1415
Cholesky,
@@ -114,10 +115,12 @@ def numba_funcify_Solve(op, node, **kwargs):
114115
solve_fn = _solve_symmetric
115116
elif assume_a == "pos":
116117
solve_fn = _solve_psd
118+
elif assume_a == "tridiagonal":
119+
solve_fn = _solve_tridiagonal
117120
else:
118121
warnings.warn(
119122
f"Numba assume_a={assume_a} not implemented. Falling back to general solve.\n"
120-
f"If appropriate, you may want to set assume_a to one of 'sym', 'pos', 'her', or 'triangular' to improve performance.",
123+
f"If appropriate, you may want to set assume_a to one of 'sym', 'pos', 'her', 'triangular' or 'tridiagonal' to improve performance.",
121124
UserWarning,
122125
)
123126
solve_fn = _solve_gen

Diff for: tests/link/numba/test_slinalg.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ class TestSolves:
9797
[(5, 1), (5, 5), (5,)],
9898
ids=["b_col_vec", "b_matrix", "b_vec"],
9999
)
100-
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str)
100+
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos", "tridiagonal"], ids=str)
101101
def test_solve(
102102
self,
103103
b_shape: tuple[int],
@@ -106,7 +106,7 @@ def test_solve(
106106
overwrite_a: bool,
107107
overwrite_b: bool,
108108
):
109-
if assume_a not in ("sym", "her", "pos") and not lower:
109+
if assume_a not in ("sym", "her", "pos", "tridiagonal") and not lower:
110110
# Avoid redundant tests with lower=True and lower=False for non symmetric matrices
111111
pytest.skip("Skipping redundant test already covered by lower=True")
112112

@@ -120,6 +120,14 @@ def A_func(x):
120120
# We have to set the unused triangle to something other than zero
121121
# to see lapack destroying it.
122122
x[np.triu_indices(n, 1) if lower else np.tril_indices(n, 1)] = np.pi
123+
elif assume_a == "tridiagonal":
124+
_x = x
125+
x = np.zeros_like(x)
126+
n = x.shape[-1]
127+
arange_n = np.arange(n)
128+
x[arange_n[1:], arange_n[:-1]] = np.diag(_x, k=-1)
129+
x[arange_n, arange_n] = np.diag(_x, k=0)
130+
x[arange_n[:-1], arange_n[1:]] = np.diag(_x, k=1)
123131
return x
124132

125133
A = pt.matrix("A", dtype=floatX)
@@ -146,7 +154,14 @@ def A_func(x):
146154

147155
op = f.maker.fgraph.outputs[0].owner.op
148156
assert isinstance(op, Solve)
157+
assert op.assume_a == assume_a
149158
destroy_map = op.destroy_map
159+
160+
if overwrite_a and assume_a == "tridiagonal":
161+
# Tridiagonal solve never destroys the A matrix
162+
# Treat test from here as if overwrite_a is False
163+
overwrite_a = False
164+
150165
if overwrite_a and overwrite_b:
151166
raise NotImplementedError(
152167
"Test not implemented for simultaneous overwrite_a and overwrite_b, as that's not currently supported by PyTensor"

0 commit comments

Comments
 (0)