@@ -1756,6 +1756,7 @@ class TestMatrixNormal(BaseTestDistributionRandom):
1756
1756
"check_pymc_params_match_rv_op" ,
1757
1757
"check_draws" ,
1758
1758
"check_errors" ,
1759
+ "check_random_variable_prior" ,
1759
1760
]
1760
1761
1761
1762
def check_draws (self ):
@@ -1824,6 +1825,28 @@ def check_errors(self):
1824
1825
shape = 15 ,
1825
1826
)
1826
1827
1828
+ def check_random_variable_prior (self ):
1829
+ """
1830
+ This test checks for shape correctness when using MatrixNormal distribution
1831
+ with parameters as random variables.
1832
+ Originally reported - https://github.com/pymc-devs/pymc/issues/3585
1833
+ """
1834
+ K = 3
1835
+ D = 15
1836
+ mu_0 = np .zeros ((D , K ))
1837
+ lambd = 1.0
1838
+ with pm .Model () as model :
1839
+ sd_dist = pm .HalfCauchy .dist (beta = 2.5 , size = D )
1840
+ packedL = pm .LKJCholeskyCov ("packedL" , eta = 2 , n = D , sd_dist = sd_dist , compute_corr = False )
1841
+ L = pm .expand_packed_triangular (D , packedL , lower = True )
1842
+ Sigma = pm .Deterministic ("Sigma" , L .dot (L .T )) # D x D covariance
1843
+ mu = pm .MatrixNormal (
1844
+ "mu" , mu = mu_0 , rowcov = (1 / lambd ) * Sigma , colcov = np .eye (K ), shape = (D , K )
1845
+ )
1846
+ prior = pm .sample_prior_predictive (2 , return_inferencedata = False )
1847
+
1848
+ assert prior ["mu" ].shape == (2 , D , K )
1849
+
1827
1850
1828
1851
class TestInterpolated (BaseTestDistributionRandom ):
1829
1852
def interpolated_rng_fn (self , size , mu , sigma , rng ):
@@ -2435,30 +2458,6 @@ def generate_shapes(include_params=False):
2435
2458
return data
2436
2459
2437
2460
2438
- @pytest .mark .xfail (reason = "This distribution has not been refactored for v4" )
2439
- def test_matrix_normal_random_with_random_variables ():
2440
- """
2441
- This test checks for shape correctness when using MatrixNormal distribution
2442
- with parameters as random variables.
2443
- Originally reported - https://github.com/pymc-devs/pymc/issues/3585
2444
- """
2445
- K = 3
2446
- D = 15
2447
- mu_0 = np .zeros ((D , K ))
2448
- lambd = 1.0
2449
- with pm .Model () as model :
2450
- sd_dist = pm .HalfCauchy .dist (beta = 2.5 )
2451
- packedL = pm .LKJCholeskyCov ("packedL" , eta = 2 , n = D , sd_dist = sd_dist )
2452
- L = pm .expand_packed_triangular (D , packedL , lower = True )
2453
- Sigma = pm .Deterministic ("Sigma" , L .dot (L .T )) # D x D covariance
2454
- mu = pm .MatrixNormal (
2455
- "mu" , mu = mu_0 , rowcov = (1 / lambd ) * Sigma , colcov = np .eye (K ), shape = (D , K )
2456
- )
2457
- prior = pm .sample_prior_predictive (2 )
2458
-
2459
- assert prior ["mu" ].shape == (2 , D , K )
2460
-
2461
-
2462
2461
@pytest .mark .xfail (reason = "This distribution has not been refactored for v4" )
2463
2462
class TestMvGaussianRandomWalk (SeededTest ):
2464
2463
@pytest .mark .parametrize (
0 commit comments