|
| 1 | +from collections.abc import Callable |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +from numba.core.extending import overload |
| 5 | +from numba.np.linalg import ensure_lapack |
| 6 | +from numpy import ndarray |
| 7 | +from scipy import linalg |
| 8 | + |
| 9 | +from pytensor.link.numba.dispatch.linalg._LAPACK import ( |
| 10 | + _LAPACK, |
| 11 | + _get_underlying_float, |
| 12 | + int_ptr_to_val, |
| 13 | + val_to_int_ptr, |
| 14 | +) |
| 15 | +from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes |
| 16 | +from pytensor.link.numba.dispatch.linalg.utils import ( |
| 17 | + _check_scipy_linalg_matrix, |
| 18 | + _copy_to_fortran_order_even_if_1d, |
| 19 | + _solve_check, |
| 20 | + _trans_char_to_int, |
| 21 | +) |
| 22 | + |
| 23 | + |
| 24 | +def _gttrf( |
| 25 | + dl: ndarray, d: ndarray, du: ndarray |
| 26 | +) -> tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]: |
| 27 | + """Placeholder for LU factorization of tridiagonal matrix.""" |
| 28 | + return # type: ignore |
| 29 | + |
| 30 | + |
| 31 | +@overload(_gttrf) |
| 32 | +def gttrf_impl( |
| 33 | + dl: ndarray, |
| 34 | + d: ndarray, |
| 35 | + du: ndarray, |
| 36 | +) -> Callable[ |
| 37 | + [ndarray, ndarray, ndarray], tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int] |
| 38 | +]: |
| 39 | + ensure_lapack() |
| 40 | + _check_scipy_linalg_matrix(dl, "gttrf") |
| 41 | + _check_scipy_linalg_matrix(d, "gttrf") |
| 42 | + _check_scipy_linalg_matrix(du, "gttrf") |
| 43 | + dtype = d.dtype |
| 44 | + w_type = _get_underlying_float(dtype) |
| 45 | + numba_gttrf = _LAPACK().numba_xgttrf(dtype) |
| 46 | + |
| 47 | + def impl( |
| 48 | + dl: ndarray, |
| 49 | + d: ndarray, |
| 50 | + du: ndarray, |
| 51 | + ) -> tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]: |
| 52 | + n = np.int32(d.shape[-1]) |
| 53 | + ipiv = np.empty(n, dtype=np.int32) |
| 54 | + du2 = np.empty(n - 2, dtype=dtype) |
| 55 | + info = val_to_int_ptr(0) |
| 56 | + |
| 57 | + numba_gttrf( |
| 58 | + val_to_int_ptr(n), |
| 59 | + dl.view(w_type).ctypes, |
| 60 | + d.view(w_type).ctypes, |
| 61 | + du.view(w_type).ctypes, |
| 62 | + du2.view(w_type).ctypes, |
| 63 | + ipiv.ctypes, |
| 64 | + info, |
| 65 | + ) |
| 66 | + |
| 67 | + return dl, d, du, du2, ipiv, int_ptr_to_val(info) |
| 68 | + |
| 69 | + return impl |
| 70 | + |
| 71 | + |
| 72 | +def _gttrs( |
| 73 | + dl: ndarray, |
| 74 | + d: ndarray, |
| 75 | + du: ndarray, |
| 76 | + du2: ndarray, |
| 77 | + ipiv: ndarray, |
| 78 | + b: ndarray, |
| 79 | + overwrite_b: bool, |
| 80 | + trans: bool, |
| 81 | +) -> tuple[ndarray, int]: |
| 82 | + """Placeholder for solving an LU-decomposed tridiagonal system.""" |
| 83 | + return # type: ignore |
| 84 | + |
| 85 | + |
| 86 | +@overload(_gttrs) |
| 87 | +def gttrs_impl( |
| 88 | + dl: ndarray, |
| 89 | + d: ndarray, |
| 90 | + du: ndarray, |
| 91 | + du2: ndarray, |
| 92 | + ipiv: ndarray, |
| 93 | + b: ndarray, |
| 94 | + overwrite_b: bool, |
| 95 | + trans: bool, |
| 96 | +) -> Callable[ |
| 97 | + [ndarray, ndarray, ndarray, ndarray, ndarray, ndarray, bool, bool], |
| 98 | + tuple[ndarray, int], |
| 99 | +]: |
| 100 | + ensure_lapack() |
| 101 | + _check_scipy_linalg_matrix(dl, "gttrs") |
| 102 | + _check_scipy_linalg_matrix(d, "gttrs") |
| 103 | + _check_scipy_linalg_matrix(du, "gttrs") |
| 104 | + _check_scipy_linalg_matrix(du2, "gttrs") |
| 105 | + _check_scipy_linalg_matrix(b, "gttrs") |
| 106 | + dtype = d.dtype |
| 107 | + w_type = _get_underlying_float(dtype) |
| 108 | + numba_gttrs = _LAPACK().numba_xgttrs(dtype) |
| 109 | + |
| 110 | + def impl( |
| 111 | + dl: ndarray, |
| 112 | + d: ndarray, |
| 113 | + du: ndarray, |
| 114 | + du2: ndarray, |
| 115 | + ipiv: ndarray, |
| 116 | + b: ndarray, |
| 117 | + overwrite_b: bool, |
| 118 | + trans: bool, |
| 119 | + ) -> tuple[ndarray, int]: |
| 120 | + n = np.int32(d.shape[-1]) |
| 121 | + nrhs = 1 if b.ndim == 1 else int(b.shape[-1]) |
| 122 | + info = val_to_int_ptr(0) |
| 123 | + |
| 124 | + if overwrite_b and b.flags.f_contiguous: |
| 125 | + b_copy = b |
| 126 | + else: |
| 127 | + b_copy = _copy_to_fortran_order_even_if_1d(b) |
| 128 | + |
| 129 | + numba_gttrs( |
| 130 | + val_to_int_ptr(_trans_char_to_int(trans)), |
| 131 | + val_to_int_ptr(n), |
| 132 | + val_to_int_ptr(nrhs), |
| 133 | + dl.view(w_type).ctypes, |
| 134 | + d.view(w_type).ctypes, |
| 135 | + du.view(w_type).ctypes, |
| 136 | + du2.view(w_type).ctypes, |
| 137 | + ipiv.ctypes, |
| 138 | + b_copy.view(w_type).ctypes, |
| 139 | + val_to_int_ptr(n), |
| 140 | + info, |
| 141 | + ) |
| 142 | + |
| 143 | + return b_copy, int_ptr_to_val(info) |
| 144 | + |
| 145 | + return impl |
| 146 | + |
| 147 | + |
| 148 | +def _gtcon( |
| 149 | + dl: ndarray, |
| 150 | + d: ndarray, |
| 151 | + du: ndarray, |
| 152 | + du2: ndarray, |
| 153 | + ipiv: ndarray, |
| 154 | + anorm: float, |
| 155 | + norm: str, |
| 156 | +) -> tuple[ndarray, int]: |
| 157 | + """Placeholder for computing the condition number of a tridiagonal system.""" |
| 158 | + return # type: ignore |
| 159 | + |
| 160 | + |
| 161 | +@overload(_gtcon) |
| 162 | +def gtcon_impl( |
| 163 | + dl: ndarray, |
| 164 | + d: ndarray, |
| 165 | + du: ndarray, |
| 166 | + du2: ndarray, |
| 167 | + ipiv: ndarray, |
| 168 | + anorm: float, |
| 169 | + norm: str, |
| 170 | +) -> Callable[ |
| 171 | + [ndarray, ndarray, ndarray, ndarray, ndarray, float, str], tuple[ndarray, int] |
| 172 | +]: |
| 173 | + ensure_lapack() |
| 174 | + _check_scipy_linalg_matrix(dl, "gtcon") |
| 175 | + _check_scipy_linalg_matrix(d, "gtcon") |
| 176 | + _check_scipy_linalg_matrix(du, "gtcon") |
| 177 | + _check_scipy_linalg_matrix(du2, "gtcon") |
| 178 | + dtype = d.dtype |
| 179 | + w_type = _get_underlying_float(dtype) |
| 180 | + numba_gtcon = _LAPACK().numba_xgtcon(dtype) |
| 181 | + |
| 182 | + def impl( |
| 183 | + dl: ndarray, |
| 184 | + d: ndarray, |
| 185 | + du: ndarray, |
| 186 | + du2: ndarray, |
| 187 | + ipiv: ndarray, |
| 188 | + anorm: float, |
| 189 | + norm: str, |
| 190 | + ) -> tuple[ndarray, int]: |
| 191 | + n = np.int32(d.shape[-1]) |
| 192 | + rcond = np.empty(1, dtype=dtype) |
| 193 | + work = np.empty(2 * n, dtype=dtype) |
| 194 | + iwork = np.empty(n, dtype=np.int32) |
| 195 | + info = val_to_int_ptr(0) |
| 196 | + |
| 197 | + numba_gtcon( |
| 198 | + val_to_int_ptr(ord(norm)), |
| 199 | + val_to_int_ptr(n), |
| 200 | + dl.view(w_type).ctypes, |
| 201 | + d.view(w_type).ctypes, |
| 202 | + du.view(w_type).ctypes, |
| 203 | + du2.view(w_type).ctypes, |
| 204 | + ipiv.ctypes, |
| 205 | + np.array(anorm, dtype=dtype).view(w_type).ctypes, |
| 206 | + rcond.view(w_type).ctypes, |
| 207 | + work.view(w_type).ctypes, |
| 208 | + iwork.ctypes, |
| 209 | + info, |
| 210 | + ) |
| 211 | + |
| 212 | + return rcond, int_ptr_to_val(info) |
| 213 | + |
| 214 | + return impl |
| 215 | + |
| 216 | + |
| 217 | +def _solve_tridiagonal( |
| 218 | + a: ndarray, |
| 219 | + b: ndarray, |
| 220 | + lower: bool, |
| 221 | + overwrite_a: bool, |
| 222 | + overwrite_b: bool, |
| 223 | + check_finite: bool, |
| 224 | + transposed: bool, |
| 225 | +): |
| 226 | + """ |
| 227 | + Solve a positive-definite linear system using the Cholesky decomposition. |
| 228 | + """ |
| 229 | + return linalg.solve( |
| 230 | + a=a, |
| 231 | + b=b, |
| 232 | + lower=lower, |
| 233 | + overwrite_a=overwrite_a, |
| 234 | + overwrite_b=overwrite_b, |
| 235 | + check_finite=check_finite, |
| 236 | + transposed=transposed, |
| 237 | + assume_a="tridiagonal", |
| 238 | + ) |
| 239 | + |
| 240 | + |
| 241 | +@overload(_solve_tridiagonal) |
| 242 | +def _tridiagonal_solve_impl( |
| 243 | + A: ndarray, |
| 244 | + B: ndarray, |
| 245 | + lower: bool, |
| 246 | + overwrite_a: bool, |
| 247 | + overwrite_b: bool, |
| 248 | + check_finite: bool, |
| 249 | + transposed: bool, |
| 250 | +) -> Callable[[ndarray, ndarray, bool, bool, bool, bool, bool], ndarray]: |
| 251 | + ensure_lapack() |
| 252 | + _check_scipy_linalg_matrix(A, "solve") |
| 253 | + _check_scipy_linalg_matrix(B, "solve") |
| 254 | + |
| 255 | + def impl( |
| 256 | + A: ndarray, |
| 257 | + B: ndarray, |
| 258 | + lower: bool, |
| 259 | + overwrite_a: bool, |
| 260 | + overwrite_b: bool, |
| 261 | + check_finite: bool, |
| 262 | + transposed: bool, |
| 263 | + ) -> ndarray: |
| 264 | + n = np.int32(A.shape[-1]) |
| 265 | + _solve_check_input_shapes(A, B) |
| 266 | + norm = "1" |
| 267 | + |
| 268 | + if transposed: |
| 269 | + A = A.T |
| 270 | + dl, d, du = np.diag(A, -1), np.diag(A, 0), np.diag(A, 1) |
| 271 | + |
| 272 | + # Adapted from scipy _matrix_norm_tridiagonal: |
| 273 | + # https://github.com/scipy/scipy/blob/0f1fd4a7268b813fa2b844ca6038e4dfdf90084a/scipy/linalg/_basic.py#L356-L367 |
| 274 | + anorm = np.abs(d) |
| 275 | + anorm[1:] += np.abs(du) |
| 276 | + anorm[:-1] += np.abs(dl) |
| 277 | + anorm = anorm.max() |
| 278 | + |
| 279 | + dl, d, du, du2, IPIV, INFO = _gttrf(dl, d, du) |
| 280 | + _solve_check(n, INFO) |
| 281 | + |
| 282 | + X, INFO = _gttrs( |
| 283 | + dl, d, du, du2, IPIV, B, trans=transposed, overwrite_b=overwrite_b |
| 284 | + ) |
| 285 | + _solve_check(n, INFO) |
| 286 | + |
| 287 | + RCOND, INFO = _gtcon(dl, d, du, du2, IPIV, anorm, norm) |
| 288 | + _solve_check(n, INFO, True, RCOND) |
| 289 | + |
| 290 | + return X |
| 291 | + |
| 292 | + return impl |
0 commit comments