Skip to content

Commit 01deb8a

Browse files
Bump minimum PyMC dependency
1 parent 13e88e8 commit 01deb8a

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

pymc_experimental/statespace/filters/kalman_filter.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pytensor.raise_op import Assert
1010
from pytensor.tensor import TensorVariable
1111
from pytensor.tensor.nlinalg import matrix_dot
12-
from pytensor.tensor.slinalg import SolveTriangular
12+
from pytensor.tensor.slinalg import solve_triangular
1313

1414
from pymc_experimental.statespace.filters.utilities import (
1515
quad_form_sym,
@@ -22,7 +22,6 @@
2222
MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64"))
2323
PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"]
2424

25-
solve_lower_triangular = SolveTriangular(lower=True)
2625
assert_data_is_1d = Assert("UnivariateTimeSeries filter requires data be at most 1-dimensional")
2726
assert_time_varying_dim_correct = Assert(
2827
"The first dimension of a time varying matrix (the time dimension) must be "
@@ -684,13 +683,13 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag):
684683
F_chol = pt.linalg.cholesky(F)
685684

686685
# If everything is missing, K = 0, IKZ = I
687-
K = solve_lower_triangular(F_chol.T, solve_lower_triangular(F_chol, PZT.T)).T
686+
K = solve_triangular(F_chol.T, solve_triangular(F_chol, PZT.T)).T
688687
I_KZ = self.eye_states - K.dot(Z)
689688

690689
a_filtered = a + K.dot(v)
691690
P_filtered = quad_form_sym(I_KZ, P) + quad_form_sym(K, H)
692691

693-
inner_term = solve_lower_triangular(F_chol.T, solve_lower_triangular(F_chol, v))
692+
inner_term = solve_triangular(F_chol.T, solve_triangular(F_chol, v))
694693
n = y.shape[0]
695694

696695
ll = pt.switch(

0 commit comments

Comments
 (0)