|
9 | 9 | from pytensor.raise_op import Assert
|
10 | 10 | from pytensor.tensor import TensorVariable
|
11 | 11 | from pytensor.tensor.nlinalg import matrix_dot
|
12 |
| -from pytensor.tensor.slinalg import SolveTriangular |
| 12 | +from pytensor.tensor.slinalg import solve_triangular |
13 | 13 |
|
14 | 14 | from pymc_experimental.statespace.filters.utilities import (
|
15 | 15 | quad_form_sym,
|
|
22 | 22 | MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64"))
|
23 | 23 | PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"]
|
24 | 24 |
|
25 |
| -solve_lower_triangular = SolveTriangular(lower=True) |
26 | 25 | assert_data_is_1d = Assert("UnivariateTimeSeries filter requires data be at most 1-dimensional")
|
27 | 26 | assert_time_varying_dim_correct = Assert(
|
28 | 27 | "The first dimension of a time varying matrix (the time dimension) must be "
|
@@ -684,13 +683,13 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag):
|
684 | 683 | F_chol = pt.linalg.cholesky(F)
|
685 | 684 |
|
686 | 685 | # If everything is missing, K = 0, IKZ = I
|
687 |
| - K = solve_lower_triangular(F_chol.T, solve_lower_triangular(F_chol, PZT.T)).T |
| 686 | + K = solve_triangular(F_chol.T, solve_triangular(F_chol, PZT.T)).T |
688 | 687 | I_KZ = self.eye_states - K.dot(Z)
|
689 | 688 |
|
690 | 689 | a_filtered = a + K.dot(v)
|
691 | 690 | P_filtered = quad_form_sym(I_KZ, P) + quad_form_sym(K, H)
|
692 | 691 |
|
693 |
| - inner_term = solve_lower_triangular(F_chol.T, solve_lower_triangular(F_chol, v)) |
| 692 | + inner_term = solve_triangular(F_chol.T, solve_triangular(F_chol, v)) |
694 | 693 | n = y.shape[0]
|
695 | 694 |
|
696 | 695 | ll = pt.switch(
|
|
0 commit comments