Skip to content

Commit 5f43bb4

Browse files
committed
Reenable old MatrixNormal test
1 parent 45bbc9e commit 5f43bb4

File tree

2 files changed

+25
-26
lines changed

2 files changed

+25
-26
lines changed

Diff for: pymc/tests/test_distributions_random.py

+23-24
Original file line numberDiff line numberDiff line change
@@ -1756,6 +1756,7 @@ class TestMatrixNormal(BaseTestDistributionRandom):
17561756
"check_pymc_params_match_rv_op",
17571757
"check_draws",
17581758
"check_errors",
1759+
"check_random_variable_prior",
17591760
]
17601761

17611762
def check_draws(self):
@@ -1824,6 +1825,28 @@ def check_errors(self):
18241825
shape=15,
18251826
)
18261827

1828+
def check_random_variable_prior(self):
1829+
"""
1830+
This test checks for shape correctness when using MatrixNormal distribution
1831+
with parameters as random variables.
1832+
Originally reported - https://github.com/pymc-devs/pymc/issues/3585
1833+
"""
1834+
K = 3
1835+
D = 15
1836+
mu_0 = np.zeros((D, K))
1837+
lambd = 1.0
1838+
with pm.Model() as model:
1839+
sd_dist = pm.HalfCauchy.dist(beta=2.5, size=D)
1840+
packedL = pm.LKJCholeskyCov("packedL", eta=2, n=D, sd_dist=sd_dist, compute_corr=False)
1841+
L = pm.expand_packed_triangular(D, packedL, lower=True)
1842+
Sigma = pm.Deterministic("Sigma", L.dot(L.T)) # D x D covariance
1843+
mu = pm.MatrixNormal(
1844+
"mu", mu=mu_0, rowcov=(1 / lambd) * Sigma, colcov=np.eye(K), shape=(D, K)
1845+
)
1846+
prior = pm.sample_prior_predictive(2, return_inferencedata=False)
1847+
1848+
assert prior["mu"].shape == (2, D, K)
1849+
18271850

18281851
class TestInterpolated(BaseTestDistributionRandom):
18291852
def interpolated_rng_fn(self, size, mu, sigma, rng):
@@ -2435,30 +2458,6 @@ def generate_shapes(include_params=False):
24352458
return data
24362459

24372460

2438-
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
2439-
def test_matrix_normal_random_with_random_variables():
2440-
"""
2441-
This test checks for shape correctness when using MatrixNormal distribution
2442-
with parameters as random variables.
2443-
Originally reported - https://github.com/pymc-devs/pymc/issues/3585
2444-
"""
2445-
K = 3
2446-
D = 15
2447-
mu_0 = np.zeros((D, K))
2448-
lambd = 1.0
2449-
with pm.Model() as model:
2450-
sd_dist = pm.HalfCauchy.dist(beta=2.5)
2451-
packedL = pm.LKJCholeskyCov("packedL", eta=2, n=D, sd_dist=sd_dist)
2452-
L = pm.expand_packed_triangular(D, packedL, lower=True)
2453-
Sigma = pm.Deterministic("Sigma", L.dot(L.T)) # D x D covariance
2454-
mu = pm.MatrixNormal(
2455-
"mu", mu=mu_0, rowcov=(1 / lambd) * Sigma, colcov=np.eye(K), shape=(D, K)
2456-
)
2457-
prior = pm.sample_prior_predictive(2)
2458-
2459-
assert prior["mu"].shape == (2, D, K)
2460-
2461-
24622461
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
24632462
class TestMvGaussianRandomWalk(SeededTest):
24642463
@pytest.mark.parametrize(

Diff for: pymc/tests/test_idata_conversion.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -331,14 +331,14 @@ def test_missing_data_model(self):
331331
# See https://github.com/pymc-devs/pymc/issues/5255
332332
assert inference_data.log_likelihood["y_observed"].shape == (2, 100, 3)
333333

334-
@pytest.mark.xfal(reason="Multivariate partial observed RVs not implemented for V4")
334+
@pytest.mark.xfail(reason="Multivariate partial observed RVs not implemented for V4")
335335
def test_mv_missing_data_model(self):
336336
data = ma.masked_values([[1, 2], [2, 2], [-1, 4], [2, -1], [-1, -1]], value=-1)
337337

338338
model = pm.Model()
339339
with model:
340340
mu = pm.Normal("mu", 0, 1, size=2)
341-
sd_dist = pm.HalfNormal.dist(1.0)
341+
sd_dist = pm.HalfNormal.dist(1.0, size=2)
342342
# pylint: disable=unpacking-non-sequence
343343
chol, *_ = pm.LKJCholeskyCov("chol_cov", n=2, eta=1, sd_dist=sd_dist, compute_corr=True)
344344
# pylint: enable=unpacking-non-sequence

0 commit comments

Comments
 (0)