Skip to content

Commit eed60c3

Browse files
committed
Add LKJCorr moment
1 parent 5fd94ca commit eed60c3

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

pymc/distributions/multivariate.py

+3
Original file line numberDiff line numberDiff line change
@@ -1566,6 +1566,9 @@ def dist(cls, n, eta, **kwargs):
15661566
eta = at.as_tensor_variable(floatX(eta))
15671567
return super().dist([n, eta], **kwargs)
15681568

1569+
def get_moment(rv, *args):
1570+
return at.zeros_like(rv)
1571+
15691572
def logp(value, n, eta):
15701573
"""
15711574
Calculate log-probability of LKJ distribution at specified

pymc/tests/test_distributions_moments.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,14 @@
3939
KroneckerNormal,
4040
Kumaraswamy,
4141
Laplace,
42+
LKJCorr,
4243
Logistic,
4344
LogitNormal,
4445
LogNormal,
4546
MatrixNormal,
4647
Moyal,
4748
Multinomial,
49+
MvNormal,
4850
MvStudentT,
4951
NegativeBinomial,
5052
Normal,
@@ -68,7 +70,6 @@
6870
)
6971
from pymc.distributions.distribution import _get_moment, get_moment
7072
from pymc.distributions.logprob import joint_logpt
71-
from pymc.distributions.multivariate import MvNormal
7273
from pymc.distributions.shape_utils import rv_size_is_none, to_tuple
7374
from pymc.initial_point import make_initial_point_fn
7475
from pymc.model import Model
@@ -97,7 +98,6 @@ def test_all_distributions_have_moments():
9798

9899
# Distributions that have not been refactored for V4 yet
99100
not_implemented = {
100-
dist_module.multivariate.LKJCorr,
101101
dist_module.mixture.Mixture,
102102
dist_module.mixture.MixtureSameFamily,
103103
dist_module.mixture.NormalMixture,
@@ -1424,3 +1424,18 @@ def test_kronecker_normal_moments(mu, covs, size, expected):
14241424
with Model() as model:
14251425
KroneckerNormal("x", mu=mu, covs=covs, size=size)
14261426
assert_moment_is_expected(model, expected)
1427+
1428+
1429+
@pytest.mark.parametrize(
1430+
"n, eta, size, expected",
1431+
[
1432+
(3, 1, None, np.zeros(3)),
1433+
(5, 1, None, np.zeros(10)),
1434+
(3, 1, 1, np.zeros((1, 3))),
1435+
(5, 1, (2, 3), np.zeros((2, 3, 10))),
1436+
],
1437+
)
1438+
def test_lkjcorr_moment(n, eta, size, expected):
1439+
with Model() as model:
1440+
LKJCorr("x", n=n, eta=eta, size=size)
1441+
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)