Skip to content

Commit 4cc3474

Browse files
committed
Implement numba tridiagonal solve
1 parent 4acb254 commit 4cc3474

File tree

3 files changed

+313
-3
lines changed

3 files changed

+313
-3
lines changed
+292
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
from collections.abc import Callable
2+
3+
import numpy as np
4+
from numba.core.extending import overload
5+
from numba.np.linalg import ensure_lapack
6+
from numpy import ndarray
7+
from scipy import linalg
8+
9+
from pytensor.link.numba.dispatch.linalg._LAPACK import (
10+
_LAPACK,
11+
_get_underlying_float,
12+
int_ptr_to_val,
13+
val_to_int_ptr,
14+
)
15+
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
16+
from pytensor.link.numba.dispatch.linalg.utils import (
17+
_check_scipy_linalg_matrix,
18+
_copy_to_fortran_order_even_if_1d,
19+
_solve_check,
20+
_trans_char_to_int,
21+
)
22+
23+
24+
def _gttrf(
25+
dl: ndarray, d: ndarray, du: ndarray
26+
) -> tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]:
27+
"""Placeholder for LU factorization of tridiagonal matrix."""
28+
return # type: ignore
29+
30+
31+
@overload(_gttrf)
def gttrf_impl(
    dl: ndarray,
    d: ndarray,
    du: ndarray,
) -> Callable[
    [ndarray, ndarray, ndarray], tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]
]:
    """
    Numba overload of ``_gttrf``: LU-factorize a tridiagonal matrix via LAPACK.

    At typing time the three diagonals are validated with
    ``_check_scipy_linalg_matrix`` and the LAPACK ``xgttrf`` binding matching
    the dtype of ``d`` is looked up. The returned ``impl`` is what numba
    compiles.

    ``impl`` returns ``(dl, d, du, du2, ipiv, info)``. ``dl``, ``d`` and ``du``
    are the very arrays that were passed in; their data pointers are handed to
    LAPACK, which (per the LAPACK ``?gttrf`` documentation) overwrites them
    with the factors. Callers must therefore pass arrays they own.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(dl, "gttrf")
    _check_scipy_linalg_matrix(d, "gttrf")
    _check_scipy_linalg_matrix(du, "gttrf")
    dtype = d.dtype
    w_type = _get_underlying_float(dtype)
    numba_gttrf = _LAPACK().numba_xgttrf(dtype)

    def impl(
        dl: ndarray,
        d: ndarray,
        du: ndarray,
    ) -> tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]:
        n = np.int32(d.shape[-1])
        ipiv = np.empty(n, dtype=np.int32)
        # The second super-diagonal has n - 2 entries. Clamp at zero so that
        # 1x1 (or empty) systems do not request a negative-length array.
        du2 = np.empty(max(n - 2, 0), dtype=dtype)
        info = val_to_int_ptr(0)

        numba_gttrf(
            val_to_int_ptr(n),
            dl.view(w_type).ctypes,
            d.view(w_type).ctypes,
            du.view(w_type).ctypes,
            du2.view(w_type).ctypes,
            ipiv.ctypes,
            info,
        )

        return dl, d, du, du2, ipiv, int_ptr_to_val(info)

    return impl
70+
71+
72+
def _gttrs(
73+
dl: ndarray,
74+
d: ndarray,
75+
du: ndarray,
76+
du2: ndarray,
77+
ipiv: ndarray,
78+
b: ndarray,
79+
overwrite_b: bool,
80+
trans: bool,
81+
) -> tuple[ndarray, int]:
82+
"""Placeholder for solving an LU-decomposed tridiagonal system."""
83+
return # type: ignore
84+
85+
86+
@overload(_gttrs)
def gttrs_impl(
    dl: ndarray,
    d: ndarray,
    du: ndarray,
    du2: ndarray,
    ipiv: ndarray,
    b: ndarray,
    overwrite_b: bool,
    trans: bool,
) -> Callable[
    [ndarray, ndarray, ndarray, ndarray, ndarray, ndarray, bool, bool],
    tuple[ndarray, int],
]:
    """
    Numba overload of ``_gttrs``: solve with a factorized tridiagonal matrix.

    At typing time the diagonals and ``b`` are validated with
    ``_check_scipy_linalg_matrix`` and the LAPACK ``xgttrs`` binding matching
    the dtype of ``d`` is looked up. The returned ``impl`` is what numba
    compiles; it returns ``(x, info)`` where ``x`` is the buffer LAPACK wrote
    the solution into and ``info`` is the LAPACK status code.
    """
    ensure_lapack()
    for diagonal in (dl, d, du, du2):
        _check_scipy_linalg_matrix(diagonal, "gttrs")
    _check_scipy_linalg_matrix(b, "gttrs")
    dtype = d.dtype
    w_type = _get_underlying_float(dtype)
    numba_gttrs = _LAPACK().numba_xgttrs(dtype)

    def impl(
        dl: ndarray,
        d: ndarray,
        du: ndarray,
        du2: ndarray,
        ipiv: ndarray,
        b: ndarray,
        overwrite_b: bool,
        trans: bool,
    ) -> tuple[ndarray, int]:
        order = np.int32(d.shape[-1])
        # A 1-d right-hand side counts as a single column.
        if b.ndim == 1:
            n_rhs = 1
        else:
            n_rhs = int(b.shape[-1])
        status = val_to_int_ptr(0)

        # Work directly on ``b`` only when the caller allows it and it is
        # already Fortran-contiguous; otherwise hand LAPACK a copy.
        if not (overwrite_b and b.flags.f_contiguous):
            rhs = _copy_to_fortran_order_even_if_1d(b)
        else:
            rhs = b

        numba_gttrs(
            val_to_int_ptr(_trans_char_to_int(trans)),
            val_to_int_ptr(order),
            val_to_int_ptr(n_rhs),
            dl.view(w_type).ctypes,
            d.view(w_type).ctypes,
            du.view(w_type).ctypes,
            du2.view(w_type).ctypes,
            ipiv.ctypes,
            rhs.view(w_type).ctypes,
            # Leading dimension of the right-hand side buffer.
            val_to_int_ptr(order),
            status,
        )

        return rhs, int_ptr_to_val(status)

    return impl
146+
147+
148+
def _gtcon(
149+
dl: ndarray,
150+
d: ndarray,
151+
du: ndarray,
152+
du2: ndarray,
153+
ipiv: ndarray,
154+
anorm: float,
155+
norm: str,
156+
) -> tuple[ndarray, int]:
157+
"""Placeholder for computing the condition number of a tridiagonal system."""
158+
return # type: ignore
159+
160+
161+
@overload(_gtcon)
def gtcon_impl(
    dl: ndarray,
    d: ndarray,
    du: ndarray,
    du2: ndarray,
    ipiv: ndarray,
    anorm: float,
    norm: str,
) -> Callable[
    [ndarray, ndarray, ndarray, ndarray, ndarray, float, str], tuple[ndarray, int]
]:
    """
    Numba overload of ``_gtcon``: reciprocal condition number via LAPACK.

    At typing time the four diagonals are validated with
    ``_check_scipy_linalg_matrix`` and the LAPACK ``xgtcon`` binding matching
    the dtype of ``d`` is looked up. The returned ``impl`` is what numba
    compiles; it returns ``(rcond, info)`` where ``rcond`` is a length-1 array
    filled in by LAPACK and ``info`` is the LAPACK status code.
    """
    ensure_lapack()
    for diagonal in (dl, d, du, du2):
        _check_scipy_linalg_matrix(diagonal, "gtcon")
    dtype = d.dtype
    w_type = _get_underlying_float(dtype)
    numba_gtcon = _LAPACK().numba_xgtcon(dtype)

    def impl(
        dl: ndarray,
        d: ndarray,
        du: ndarray,
        du2: ndarray,
        ipiv: ndarray,
        anorm: float,
        norm: str,
    ) -> tuple[ndarray, int]:
        size = np.int32(d.shape[-1])

        # Output slot and scratch space handed to the LAPACK routine.
        rcond = np.empty(1, dtype=dtype)
        workspace = np.empty(2 * size, dtype=dtype)
        int_workspace = np.empty(size, dtype=np.int32)
        status = val_to_int_ptr(0)

        # LAPACK takes every scalar by pointer, so wrap them first.
        norm_code = val_to_int_ptr(ord(norm))
        anorm_buffer = np.array(anorm, dtype=dtype)

        numba_gtcon(
            norm_code,
            val_to_int_ptr(size),
            dl.view(w_type).ctypes,
            d.view(w_type).ctypes,
            du.view(w_type).ctypes,
            du2.view(w_type).ctypes,
            ipiv.ctypes,
            anorm_buffer.view(w_type).ctypes,
            rcond.view(w_type).ctypes,
            workspace.view(w_type).ctypes,
            int_workspace.ctypes,
            status,
        )

        return rcond, int_ptr_to_val(status)

    return impl
215+
216+
217+
def _solve_tridiagonal(
218+
a: ndarray,
219+
b: ndarray,
220+
lower: bool,
221+
overwrite_a: bool,
222+
overwrite_b: bool,
223+
check_finite: bool,
224+
transposed: bool,
225+
):
226+
"""
227+
Solve a positive-definite linear system using the Cholesky decomposition.
228+
"""
229+
return linalg.solve(
230+
a=a,
231+
b=b,
232+
lower=lower,
233+
overwrite_a=overwrite_a,
234+
overwrite_b=overwrite_b,
235+
check_finite=check_finite,
236+
transposed=transposed,
237+
assume_a="tridiagonal",
238+
)
239+
240+
241+
@overload(_solve_tridiagonal)
def _tridiagonal_solve_impl(
    A: ndarray,
    B: ndarray,
    lower: bool,
    overwrite_a: bool,
    overwrite_b: bool,
    check_finite: bool,
    transposed: bool,
) -> Callable[[ndarray, ndarray, bool, bool, bool, bool, bool], ndarray]:
    """
    Numba overload of ``_solve_tridiagonal``.

    The compiled ``impl`` extracts the three diagonals of ``A``, factorizes
    them with ``_gttrf``, solves for ``B`` with ``_gttrs`` and finally
    estimates the reciprocal condition number with ``_gtcon``. After each
    step the LAPACK status is passed to ``_solve_check`` (presumably this
    raises or warns on failure / ill-conditioning; see its definition).

    ``lower``, ``overwrite_a`` and ``check_finite`` are accepted so the
    signature matches the placeholder, but are not used by ``impl``.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "solve")
    _check_scipy_linalg_matrix(B, "solve")

    def impl(
        A: ndarray,
        B: ndarray,
        lower: bool,
        overwrite_a: bool,
        overwrite_b: bool,
        check_finite: bool,
        transposed: bool,
    ) -> ndarray:
        n = np.int32(A.shape[-1])
        _solve_check_input_shapes(A, B)
        norm = "1"

        # The transposition is applied here, on the matrix itself, so that
        # the diagonals, the norm and the factorization all describe the
        # system that is actually solved.
        if transposed:
            A = A.T
        dl, d, du = np.diag(A, -1), np.diag(A, 0), np.diag(A, 1)

        # 1-norm (largest absolute column sum) of the tridiagonal matrix.
        # Adapted from scipy _matrix_norm_tridiagonal:
        # https://github.com/scipy/scipy/blob/0f1fd4a7268b813fa2b844ca6038e4dfdf90084a/scipy/linalg/_basic.py#L356-L367
        column_sums = np.abs(d)
        column_sums[1:] += np.abs(du)
        column_sums[:-1] += np.abs(dl)
        anorm = column_sums.max()

        dl, d, du, du2, IPIV, INFO = _gttrf(dl, d, du)
        _solve_check(n, INFO)

        # ``A`` was already transposed above when requested, so LAPACK must
        # solve the factorized system as-is. Forwarding ``transposed`` here
        # as well would transpose twice and solve the untransposed system.
        X, INFO = _gttrs(
            dl, d, du, du2, IPIV, B, trans=False, overwrite_b=overwrite_b
        )
        _solve_check(n, INFO)

        RCOND, INFO = _gtcon(dl, d, du, du2, IPIV, anorm, norm)
        _solve_check(n, INFO, True, RCOND)

        return X

    return impl

Diff for: pytensor/link/numba/dispatch/slinalg.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pytensor.link.numba.dispatch.linalg.solve.posdef import _solve_psd
1010
from pytensor.link.numba.dispatch.linalg.solve.symmetric import _solve_symmetric
1111
from pytensor.link.numba.dispatch.linalg.solve.triangular import _solve_triangular
12+
from pytensor.link.numba.dispatch.linalg.solve.tridiagonal import _solve_tridiagonal
1213
from pytensor.tensor.slinalg import (
1314
BlockDiagonal,
1415
Cholesky,
@@ -114,10 +115,12 @@ def numba_funcify_Solve(op, node, **kwargs):
114115
solve_fn = _solve_symmetric
115116
elif assume_a == "pos":
116117
solve_fn = _solve_psd
118+
elif assume_a == "tridiagonal":
119+
solve_fn = _solve_tridiagonal
117120
else:
118121
warnings.warn(
119122
f"Numba assume_a={assume_a} not implemented. Falling back to general solve.\n"
120-
f"If appropriate, you may want to set assume_a to one of 'sym', 'pos', 'her', or 'triangular' to improve performance.",
123+
f"If appropriate, you may want to set assume_a to one of 'sym', 'pos', 'her', 'triangular' or 'tridiagonal' to improve performance.",
121124
UserWarning,
122125
)
123126
solve_fn = _solve_gen

Diff for: tests/link/numba/test_slinalg.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ class TestSolves:
9595
[(5, 1), (5, 5), (5,)],
9696
ids=["b_col_vec", "b_matrix", "b_vec"],
9797
)
98-
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str)
98+
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos", "tridiagonal"], ids=str)
9999
def test_solve(
100100
self,
101101
b_shape: tuple[int],
@@ -104,7 +104,7 @@ def test_solve(
104104
overwrite_a: bool,
105105
overwrite_b: bool,
106106
):
107-
if assume_a not in ("sym", "her", "pos") and not lower:
107+
if assume_a not in ("sym", "her", "pos", "tridiagonal") and not lower:
108108
# Avoid redundant tests with lower=True and lower=False for non symmetric matrices
109109
pytest.skip("Skipping redundant test already covered by lower=True")
110110

@@ -118,6 +118,14 @@ def A_func(x):
118118
# We have to set the unused triangle to something other than zero
119119
# to see lapack destroying it.
120120
x[np.triu_indices(n, 1) if lower else np.tril_indices(n, 1)] = np.pi
121+
elif assume_a == "tridiagonal":
122+
_x = x
123+
x = np.zeros_like(x)
124+
n = x.shape[-1]
125+
arange_n = np.arange(n)
126+
x[arange_n[1:], arange_n[:-1]] = np.diag(_x, k=-1)
127+
x[arange_n, arange_n] = np.diag(_x, k=0)
128+
x[arange_n[:-1], arange_n[1:]] = np.diag(_x, k=1)
121129
return x
122130

123131
A = pt.matrix("A", dtype=floatX)
@@ -144,7 +152,14 @@ def A_func(x):
144152

145153
op = f.maker.fgraph.outputs[0].owner.op
146154
assert isinstance(op, Solve)
155+
assert op.assume_a == assume_a
147156
destroy_map = op.destroy_map
157+
158+
if overwrite_a and assume_a == "tridiagonal":
159+
# Tridiagonal solve never destroys the A matrix
160+
# Treat test from here as if overwrite_a is False
161+
overwrite_a = False
162+
148163
if overwrite_a and overwrite_b:
149164
raise NotImplementedError(
150165
"Test not implemented for simultaneous overwrite_a and overwrite_b, as that's not currently supported by PyTensor"

0 commit comments

Comments
 (0)