From b2cdffd362c35fe1bde5affbc17c690d9bf50a31 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 16 Sep 2023 16:50:33 +0200 Subject: [PATCH 1/2] Do not predefine custom Cholesky and SolveTriangular Ops --- pymc_experimental/gp/latent_approx.py | 14 ++++++++------ .../statespace/filters/kalman_filter.py | 6 ++---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pymc_experimental/gp/latent_approx.py b/pymc_experimental/gp/latent_approx.py index abb9fc02..8fa72f30 100644 --- a/pymc_experimental/gp/latent_approx.py +++ b/pymc_experimental/gp/latent_approx.py @@ -16,7 +16,7 @@ import numpy as np import pymc as pm import pytensor.tensor as pt -from pymc.gp.util import JITTER_DEFAULT, cholesky, solve_lower, solve_upper, stabilize +from pymc.gp.util import JITTER_DEFAULT, stabilize class LatentApprox(pm.gp.Latent): @@ -35,14 +35,16 @@ def __init__( def _build_prior(self, name, X, Xu, jitter=JITTER_DEFAULT, **kwargs): mu = self.mean_func(X) Kuu = self.cov_func(Xu) - L = cholesky(stabilize(Kuu, jitter)) + L = pt.linalg.cholesky(stabilize(Kuu, jitter)) n_inducing_points = np.shape(Xu)[0] v = pm.Normal(name + "_u_rotated_", mu=0.0, sigma=1.0, size=n_inducing_points, **kwargs) u = pm.Deterministic(name + "_u", L @ v) Kfu = self.cov_func(X, Xu) - Kuuiu = solve_upper(pt.transpose(L), solve_lower(L, u)) + Kuuiu = pt.linalg.solve_triangular( + pt.transpose(L), pt.linalg.solve_triangular(L, u), lower=False + ) return pm.Deterministic(name, mu + Kfu @ Kuuiu), Kuuiu, L @@ -62,10 +64,10 @@ def prior(self, name, X, Xu=None, jitter=JITTER_DEFAULT, **kwargs): def _build_conditional(self, name, Xnew, Xu, L, Kuuiu, jitter, **kwargs): Ksu = self.cov_func(Xnew, Xu) mu = self.mean_func(Xnew) + Ksu @ Kuuiu - tmp = solve_lower(L, pt.transpose(Ksu)) + tmp = pt.linalg.solve_triangular(L, pt.transpose(Ksu)) Qss = pt.transpose(tmp) @ tmp # Qss = tt.dot(tt.dot(Ksu, tt.nlinalg.pinv(Kuu)), Ksu.T) Kss = self.cov_func(Xnew) - Lss = cholesky(stabilize(Kss - Qss, jitter)) + Lss = 
pt.linalg.cholesky(stabilize(Kss - Qss, jitter)) return mu, Lss def conditional(self, name, Xnew, jitter=1e-6, **kwargs): @@ -123,7 +125,7 @@ def _build_conditional(self, Xnew, X, f, U, s, jitter): Kxxpinv = U @ pt.diag(1.0 / s) @ U.T mus = Kxs.T @ Kxxpinv @ f K = Kss - Kxs.T @ Kxxpinv @ Kxs - L = pm.gp.util.cholesky(pm.gp.util.stabilize(K, jitter)) + L = pt.linalg.cholesky(stabilize(K, jitter)) return mus, L def conditional(self, name, Xnew, jitter=1e-6, **kwargs): diff --git a/pymc_experimental/statespace/filters/kalman_filter.py b/pymc_experimental/statespace/filters/kalman_filter.py index 1bdde3d2..f3bcbd1e 100644 --- a/pymc_experimental/statespace/filters/kalman_filter.py +++ b/pymc_experimental/statespace/filters/kalman_filter.py @@ -9,7 +9,6 @@ from pytensor.raise_op import Assert from pytensor.tensor import TensorVariable from pytensor.tensor.nlinalg import matrix_dot -from pytensor.tensor.slinalg import SolveTriangular from pymc_experimental.statespace.filters.utilities import ( quad_form_sym, @@ -22,7 +21,6 @@ MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64")) PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"] -solve_lower_triangular = SolveTriangular(lower=True) assert_data_is_1d = Assert("UnivariateTimeSeries filter requires data be at most 1-dimensional") assert_time_varying_dim_correct = Assert( "The first dimension of a time varying matrix (the time dimension) must be " @@ -684,13 +682,13 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag): F_chol = pt.linalg.cholesky(F) # If everything is missing, K = 0, IKZ = I - K = solve_lower_triangular(F_chol.T, solve_lower_triangular(F_chol, PZT.T)).T + K = pt.linalg.solve_triangular(F_chol.T, pt.linalg.solve_triangular(F_chol, PZT.T)).T I_KZ = self.eye_states - K.dot(Z) a_filtered = a + K.dot(v) P_filtered = quad_form_sym(I_KZ, P) + quad_form_sym(K, H) - inner_term = solve_lower_triangular(F_chol.T, solve_lower_triangular(F_chol, v)) + inner_term = pt.linalg.solve_triangular(F_chol.T, 
pt.linalg.solve_triangular(F_chol, v)) n = y.shape[0] ll = pt.switch( From 2a76fec2894c616ed0cc173dc83550e46d8f4c34 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 16 Sep 2023 17:08:39 +0200 Subject: [PATCH 2/2] Import `solve_triangular` from `pytensor.tensor.slinalg` for better compatibility with previous pytensor versions --- pymc_experimental/gp/latent_approx.py | 7 +++---- pymc_experimental/statespace/filters/kalman_filter.py | 5 +++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pymc_experimental/gp/latent_approx.py b/pymc_experimental/gp/latent_approx.py index 8fa72f30..dbcc269e 100644 --- a/pymc_experimental/gp/latent_approx.py +++ b/pymc_experimental/gp/latent_approx.py @@ -17,6 +17,7 @@ import pymc as pm import pytensor.tensor as pt from pymc.gp.util import JITTER_DEFAULT, stabilize +from pytensor.tensor.slinalg import solve_triangular class LatentApprox(pm.gp.Latent): @@ -42,9 +43,7 @@ def _build_prior(self, name, X, Xu, jitter=JITTER_DEFAULT, **kwargs): u = pm.Deterministic(name + "_u", L @ v) Kfu = self.cov_func(X, Xu) - Kuuiu = pt.linalg.solve_triangular( - pt.transpose(L), pt.linalg.solve_triangular(L, u), lower=False - ) + Kuuiu = solve_triangular(pt.transpose(L), solve_triangular(L, u), lower=False) return pm.Deterministic(name, mu + Kfu @ Kuuiu), Kuuiu, L @@ -64,7 +63,7 @@ def prior(self, name, X, Xu=None, jitter=JITTER_DEFAULT, **kwargs): def _build_conditional(self, name, Xnew, Xu, L, Kuuiu, jitter, **kwargs): Ksu = self.cov_func(Xnew, Xu) mu = self.mean_func(Xnew) + Ksu @ Kuuiu - tmp = pt.linalg.solve_triangular(L, pt.transpose(Ksu)) + tmp = solve_triangular(L, pt.transpose(Ksu)) Qss = pt.transpose(tmp) @ tmp # Qss = tt.dot(tt.dot(Ksu, tt.nlinalg.pinv(Kuu)), Ksu.T) Kss = self.cov_func(Xnew) Lss = pt.linalg.cholesky(stabilize(Kss - Qss, jitter)) diff --git a/pymc_experimental/statespace/filters/kalman_filter.py b/pymc_experimental/statespace/filters/kalman_filter.py index f3bcbd1e..62aaea14 100644 --- 
a/pymc_experimental/statespace/filters/kalman_filter.py +++ b/pymc_experimental/statespace/filters/kalman_filter.py @@ -9,6 +9,7 @@ from pytensor.raise_op import Assert from pytensor.tensor import TensorVariable from pytensor.tensor.nlinalg import matrix_dot +from pytensor.tensor.slinalg import solve_triangular from pymc_experimental.statespace.filters.utilities import ( quad_form_sym, @@ -682,13 +683,13 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag): F_chol = pt.linalg.cholesky(F) # If everything is missing, K = 0, IKZ = I - K = pt.linalg.solve_triangular(F_chol.T, pt.linalg.solve_triangular(F_chol, PZT.T)).T + K = solve_triangular(F_chol.T, solve_triangular(F_chol, PZT.T)).T I_KZ = self.eye_states - K.dot(Z) a_filtered = a + K.dot(v) P_filtered = quad_form_sym(I_KZ, P) + quad_form_sym(K, H) - inner_term = pt.linalg.solve_triangular(F_chol.T, pt.linalg.solve_triangular(F_chol, v)) + inner_term = solve_triangular(F_chol.T, solve_triangular(F_chol, v)) n = y.shape[0] ll = pt.switch(