diff --git a/pymc_experimental/gp/latent_approx.py b/pymc_experimental/gp/latent_approx.py
index abb9fc02..dbcc269e 100644
--- a/pymc_experimental/gp/latent_approx.py
+++ b/pymc_experimental/gp/latent_approx.py
@@ -16,7 +16,8 @@
 import numpy as np
 import pymc as pm
 import pytensor.tensor as pt
-from pymc.gp.util import JITTER_DEFAULT, cholesky, solve_lower, solve_upper, stabilize
+from pymc.gp.util import JITTER_DEFAULT, stabilize
+from pytensor.tensor.slinalg import solve_triangular


 class LatentApprox(pm.gp.Latent):
@@ -35,14 +36,14 @@ def __init__(
     def _build_prior(self, name, X, Xu, jitter=JITTER_DEFAULT, **kwargs):
         mu = self.mean_func(X)
         Kuu = self.cov_func(Xu)
-        L = cholesky(stabilize(Kuu, jitter))
+        L = pt.linalg.cholesky(stabilize(Kuu, jitter))

         n_inducing_points = np.shape(Xu)[0]
         v = pm.Normal(name + "_u_rotated_", mu=0.0, sigma=1.0, size=n_inducing_points, **kwargs)
         u = pm.Deterministic(name + "_u", L @ v)

         Kfu = self.cov_func(X, Xu)
-        Kuuiu = solve_upper(pt.transpose(L), solve_lower(L, u))
+        Kuuiu = solve_triangular(pt.transpose(L), solve_triangular(L, u, lower=True), lower=False)

         return pm.Deterministic(name, mu + Kfu @ Kuuiu), Kuuiu, L

@@ -62,10 +63,10 @@ def prior(self, name, X, Xu=None, jitter=JITTER_DEFAULT, **kwargs):
     def _build_conditional(self, name, Xnew, Xu, L, Kuuiu, jitter, **kwargs):
         Ksu = self.cov_func(Xnew, Xu)
         mu = self.mean_func(Xnew) + Ksu @ Kuuiu
-        tmp = solve_lower(L, pt.transpose(Ksu))
+        tmp = solve_triangular(L, pt.transpose(Ksu), lower=True)
         Qss = pt.transpose(tmp) @ tmp  # Qss = tt.dot(tt.dot(Ksu, tt.nlinalg.pinv(Kuu)), Ksu.T)
         Kss = self.cov_func(Xnew)
-        Lss = cholesky(stabilize(Kss - Qss, jitter))
+        Lss = pt.linalg.cholesky(stabilize(Kss - Qss, jitter))
         return mu, Lss

     def conditional(self, name, Xnew, jitter=1e-6, **kwargs):
@@ -123,7 +124,7 @@ def _build_conditional(self, Xnew, X, f, U, s, jitter):
         Kxxpinv = U @ pt.diag(1.0 / s) @ U.T
         mus = Kxs.T @ Kxxpinv @ f
         K = Kss - Kxs.T @ Kxxpinv @ Kxs
-        L = pm.gp.util.cholesky(pm.gp.util.stabilize(K, jitter))
+        L = pt.linalg.cholesky(stabilize(K, jitter))
         return mus, L

     def conditional(self, name, Xnew, jitter=1e-6, **kwargs):
diff --git a/pymc_experimental/statespace/filters/kalman_filter.py b/pymc_experimental/statespace/filters/kalman_filter.py
index 1bdde3d2..62aaea14 100644
--- a/pymc_experimental/statespace/filters/kalman_filter.py
+++ b/pymc_experimental/statespace/filters/kalman_filter.py
@@ -9,7 +9,7 @@
 from pytensor.raise_op import Assert
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.nlinalg import matrix_dot
-from pytensor.tensor.slinalg import SolveTriangular
+from pytensor.tensor.slinalg import solve_triangular

 from pymc_experimental.statespace.filters.utilities import (
     quad_form_sym,
@@ -22,7 +22,6 @@
 MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64"))
 PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"]

-solve_lower_triangular = SolveTriangular(lower=True)
 assert_data_is_1d = Assert("UnivariateTimeSeries filter requires data be at most 1-dimensional")
 assert_time_varying_dim_correct = Assert(
     "The first dimension of a time varying matrix (the time dimension) must be "
@@ -684,13 +683,13 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag):
         F_chol = pt.linalg.cholesky(F)

         # If everything is missing, K = 0, IKZ = I
-        K = solve_lower_triangular(F_chol.T, solve_lower_triangular(F_chol, PZT.T)).T
+        K = solve_triangular(F_chol.T, solve_triangular(F_chol, PZT.T, lower=True)).T
         I_KZ = self.eye_states - K.dot(Z)

         a_filtered = a + K.dot(v)
         P_filtered = quad_form_sym(I_KZ, P) + quad_form_sym(K, H)

-        inner_term = solve_lower_triangular(F_chol.T, solve_lower_triangular(F_chol, v))
+        inner_term = solve_triangular(F_chol.T, solve_triangular(F_chol, v, lower=True))

         n = y.shape[0]
         ll = pt.switch(
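The one behavior-sensitive detail in this migration is the `lower` flag: `pytensor.tensor.slinalg.solve_triangular` defaults to `lower=False` (mirroring `scipy.linalg.solve_triangular`), whereas the removed `pymc.gp.util.solve_lower` helper was `SolveTriangular(lower=True)`, so every inner solve against the lower Cholesky factor has to pass `lower=True` explicitly. A minimal sketch of the two-triangular-solve pattern, using made-up data rather than anything from this diff:

```python
import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.slinalg import solve_triangular

# Illustrative SPD matrix K and right-hand side b (not from the PR).
rng = np.random.default_rng(0)
A = rng.normal(size=(5, 5))
K = A @ A.T + 5.0 * np.eye(5)
b = rng.normal(size=5)

# Lower Cholesky factor of K.
L = pt.linalg.cholesky(pt.as_tensor_variable(K))

# K^{-1} b via two triangular solves: L y = b (lower), then L.T x = y (upper).
x = solve_triangular(L.T, solve_triangular(L, pt.as_tensor_variable(b), lower=True), lower=False)

np.testing.assert_allclose(x.eval(), np.linalg.solve(K, b), rtol=1e-8)
```

Dropping the explicit `lower=True` on the inner call makes the solve read only the upper triangle of `L`, so the result is silently wrong rather than an error.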