Skip to content

Commit 7e464e5

Browse files
author
Brandon T. Willard
committed
Updates for Aesara Solve changes
1 parent a061106 commit 7e464e5

File tree

3 files changed

+19
-18
lines changed

3 files changed

+19
-18
lines changed

pymc3/distributions/dist_math.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@
3131
from aesara.scan import until
3232
from aesara.tensor import gammaln
3333
from aesara.tensor.elemwise import Elemwise
34-
from aesara.tensor.slinalg import Cholesky, Solve
34+
from aesara.tensor.slinalg import Cholesky
35+
from aesara.tensor.slinalg import solve_lower_triangular as solve_lower
36+
from aesara.tensor.slinalg import solve_upper_triangular as solve_upper
3537

3638
from pymc3.aesaraf import floatX
3739
from pymc3.distributions.shape_utils import to_tuple
@@ -267,8 +269,6 @@ def MvNormalLogp():
267269
delta = at.matrix("delta")
268270
delta.tag.test_value = floatX(np.zeros((2, 3)))
269271

270-
solve_lower = Solve(A_structure="lower_triangular")
271-
solve_upper = Solve(A_structure="upper_triangular")
272272
cholesky = Cholesky(lower=True, on_error="nan")
273273

274274
n, k = delta.shape

pymc3/distributions/multivariate.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,9 @@
3131
from aesara.tensor.random.basic import MultinomialRV, dirichlet, multivariate_normal
3232
from aesara.tensor.random.op import RandomVariable, default_shape_from_params
3333
from aesara.tensor.random.utils import broadcast_params
34-
from aesara.tensor.slinalg import (
35-
Cholesky,
36-
Solve,
37-
solve_lower_triangular,
38-
solve_upper_triangular,
39-
)
34+
from aesara.tensor.slinalg import Cholesky
35+
from aesara.tensor.slinalg import solve_lower_triangular as solve_lower
36+
from aesara.tensor.slinalg import solve_upper_triangular as solve_upper
4037
from aesara.tensor.type import TensorType
4138
from scipy import linalg, stats
4239

@@ -66,7 +63,6 @@
6663
"CAR",
6764
]
6865

69-
solve_lower = Solve(A_structure="lower_triangular")
7066
# Step methods and advi do not catch LinAlgErrors at the
7167
# moment. We work around that by using a cholesky op
7268
# that returns a nan as first entry instead of raising
@@ -1716,10 +1712,10 @@ def logp(value, mu, rowchol, colchol):
17161712
delta = value - mu
17171713

17181714
# Find exponent piece by piece
1719-
right_quaddist = solve_lower_triangular(rowchol, delta)
1715+
right_quaddist = solve_lower(rowchol, delta)
17201716
quaddist = at.nlinalg.matrix_dot(right_quaddist.T, right_quaddist)
1721-
quaddist = solve_lower_triangular(colchol, quaddist)
1722-
quaddist = solve_upper_triangular(colchol.T, quaddist)
1717+
quaddist = solve_lower(colchol, quaddist)
1718+
quaddist = solve_upper(colchol.T, quaddist)
17231719
trquaddist = at.nlinalg.trace(quaddist)
17241720

17251721
coldiag = at.diag(colchol)

pymc3/gp/util.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,19 @@
1717
import aesara.tensor as at
1818
import numpy as np
1919

20-
from aesara.tensor.slinalg import Solve, cholesky # pylint: disable=unused-import
20+
from aesara.tensor.slinalg import ( # noqa: W0611; pylint: disable=unused-import
21+
cholesky,
22+
solve,
23+
)
24+
from aesara.tensor.slinalg import ( # noqa: W0611; pylint: disable=unused-import
25+
solve_lower_triangular as solve_lower,
26+
)
27+
from aesara.tensor.slinalg import ( # noqa: W0611; pylint: disable=unused-import
28+
solve_upper_triangular as solve_upper,
29+
)
2130
from aesara.tensor.var import TensorConstant
2231
from scipy.cluster.vq import kmeans
2332

24-
solve_lower = Solve(A_structure="lower_triangular")
25-
solve_upper = Solve(A_structure="upper_triangular")
26-
solve = Solve(A_structure="general")
27-
2833

2934
def infer_shape(X, n_points=None):
3035
R"""

0 commit comments

Comments
 (0)