@@ -47,7 +47,11 @@ def random_polyagamma(*args, **kwargs):
47
47
from pymc .distributions .discrete import _OrderedLogistic , _OrderedProbit
48
48
from pymc .distributions .dist_math import clipped_beta_rvs
49
49
from pymc .distributions .logprob import logp
50
- from pymc .distributions .multivariate import _OrderedMultinomial , quaddist_matrix
50
+ from pymc .distributions .multivariate import (
51
+ _LKJCholeskyCov ,
52
+ _OrderedMultinomial ,
53
+ quaddist_matrix ,
54
+ )
51
55
from pymc .distributions .shape_utils import to_tuple
52
56
from pymc .tests .helpers import SeededTest , select_by_precision
53
57
from pymc .tests .test_distributions import (
@@ -1867,6 +1871,43 @@ def ref_rand(size, n, eta):
1867
1871
)
1868
1872
1869
1873
1874
+ class TestLKJCholeskyCov (BaseTestDistributionRandom ):
1875
+ pymc_dist = _LKJCholeskyCov
1876
+ pymc_dist_params = {"eta" : 1.0 , "n" : 3 , "sd_dist" : pm .Constant .dist ([0.5 , 1.0 , 2.0 ])}
1877
+ expected_rv_op_params = {"n" : 3 , "eta" : 1.0 , "sd_dist" : pm .Constant .dist ([0.5 , 1.0 , 2.0 ])}
1878
+ size = None
1879
+
1880
+ sizes_to_check = [None , (), 1 , (1 ,), 5 , (4 , 5 ), (2 , 4 , 2 )]
1881
+ sizes_expected = [
1882
+ (6 ,),
1883
+ (6 ,),
1884
+ (1 , 6 ),
1885
+ (1 , 6 ),
1886
+ (5 , 6 ),
1887
+ (4 , 5 , 6 ),
1888
+ (2 , 4 , 2 , 6 ),
1889
+ ]
1890
+
1891
+ tests_to_run = [
1892
+ "check_rv_size" ,
1893
+ "check_draws_match_expected" ,
1894
+ ]
1895
+
1896
+ def check_rv_size (self ):
1897
+ for size , expected in zip (self .sizes_to_check , self .sizes_expected ):
1898
+ sd_dist = pm .Exponential .dist (1 , size = (* to_tuple (size ), 3 ))
1899
+ pymc_rv = self .pymc_dist .dist (n = 3 , eta = 1 , sd_dist = sd_dist , size = size )
1900
+ expected_symbolic = tuple (pymc_rv .shape .eval ())
1901
+ actual = pymc_rv .eval ().shape
1902
+ assert actual == expected_symbolic == expected
1903
+
1904
+ def check_draws_match_expected (self ):
1905
+ # TODO: Find better comparison:
1906
+ rng = aesara .shared (self .get_random_state (reset = True ))
1907
+ x = _LKJCholeskyCov .dist (n = 2 , eta = 10_000 , sd_dist = pm .Constant .dist ([0.5 , 2.0 ]), rng = rng )
1908
+ assert np .all (np .abs (x .eval () - np .array ([0.5 , 0 , 2.0 ])) < 0.01 )
1909
+
1910
+
1870
1911
class TestScalarParameterSamples (SeededTest ):
1871
1912
@pytest .mark .xfail (reason = "This distribution has not been refactored for v4" )
1872
1913
def test_normalmixture (self ):
@@ -2346,9 +2387,11 @@ def test_with_chol_rv(self, sample_shape, dist_shape, mu_shape):
2346
2387
with pm .Model () as model :
2347
2388
mu = pm .Normal ("mu" , 0.0 , 1.0 , shape = mu_shape )
2348
2389
sd_dist = pm .Exponential .dist (1.0 , shape = 3 )
2390
+ # pylint: disable=unpacking-non-sequence
2349
2391
chol , corr , stds = pm .LKJCholeskyCov (
2350
2392
"chol_cov" , n = 3 , eta = 2 , sd_dist = sd_dist , compute_corr = True
2351
2393
)
2394
+ # pylint: enable=unpacking-non-sequence
2352
2395
mv = pm .MvNormal ("mv" , mu , chol = chol , shape = dist_shape )
2353
2396
prior = pm .sample_prior_predictive (samples = sample_shape )
2354
2397
@@ -2363,9 +2406,11 @@ def test_with_cov_rv(self, sample_shape, dist_shape, mu_shape):
2363
2406
with pm .Model () as model :
2364
2407
mu = pm .Normal ("mu" , 0.0 , 1.0 , shape = mu_shape )
2365
2408
sd_dist = pm .Exponential .dist (1.0 , shape = 3 )
2409
+ # pylint: disable=unpacking-non-sequence
2366
2410
chol , corr , stds = pm .LKJCholeskyCov (
2367
2411
"chol_cov" , n = 3 , eta = 2 , sd_dist = sd_dist , compute_corr = True
2368
2412
)
2413
+ # pylint: enable=unpacking-non-sequence
2369
2414
mv = pm .MvNormal ("mv" , mu , cov = pm .math .dot (chol , chol .T ), shape = dist_shape )
2370
2415
prior = pm .sample_prior_predictive (samples = sample_shape )
2371
2416
@@ -2457,9 +2502,11 @@ def test_with_chol_rv(self, sample_shape, dist_shape, mu_shape):
2457
2502
with pm .Model () as model :
2458
2503
mu = pm .Normal ("mu" , 0.0 , 1.0 , shape = mu_shape )
2459
2504
sd_dist = pm .Exponential .dist (1.0 , shape = 3 )
2505
+ # pylint: disable=unpacking-non-sequence
2460
2506
chol , corr , stds = pm .LKJCholeskyCov (
2461
2507
"chol_cov" , n = 3 , eta = 2 , sd_dist = sd_dist , compute_corr = True
2462
2508
)
2509
+ # pylint: enable=unpacking-non-sequence
2463
2510
mv = pm .MvGaussianRandomWalk ("mv" , mu , chol = chol , shape = dist_shape )
2464
2511
prior = pm .sample_prior_predictive (samples = sample_shape )
2465
2512
@@ -2475,9 +2522,11 @@ def test_with_cov_rv(self, sample_shape, dist_shape, mu_shape):
2475
2522
with pm .Model () as model :
2476
2523
mu = pm .Normal ("mu" , 0.0 , 1.0 , shape = mu_shape )
2477
2524
sd_dist = pm .Exponential .dist (1.0 , shape = 3 )
2525
+ # pylint: disable=unpacking-non-sequence
2478
2526
chol , corr , stds = pm .LKJCholeskyCov (
2479
2527
"chol_cov" , n = 3 , eta = 2 , sd_dist = sd_dist , compute_corr = True
2480
2528
)
2529
+ # pylint: enable=unpacking-non-sequence
2481
2530
mv = pm .MvGaussianRandomWalk ("mv" , mu , cov = pm .math .dot (chol , chol .T ), shape = dist_shape )
2482
2531
prior = pm .sample_prior_predictive (samples = sample_shape )
2483
2532
0 commit comments