Skip to content

Fix failing default transform for LKJCorr #7065

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Dec 13, 2023
9 changes: 8 additions & 1 deletion pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1524,6 +1524,13 @@ def _random_corr_matrix(cls, rng, n, eta, flat_size):
lkjcorr = LKJCorrRV()


class MultivariateIntervalTransform(Interval):
name = "interval"

def log_jac_det(self, *args):
return super().log_jac_det(*args).sum(-1)


class LKJCorr(BoundedContinuous):
r"""
The LKJ (Lewandowski, Kurowicka and Joe) log-likelihood.
Expand Down Expand Up @@ -1623,7 +1630,7 @@ def logp(value, n, eta):

@_default_transform.register(LKJCorr)
def lkjcorr_default_transform(op, rv):
return Interval(floatX(-1.0), floatX(1.0))
return MultivariateIntervalTransform(floatX(-1.0), floatX(1.0))


class MatrixNormalRV(RandomVariable):
Expand Down
7 changes: 7 additions & 0 deletions tests/distributions/test_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import pymc as pm

from pymc.distributions.multivariate import (
MultivariateIntervalTransform,
_LKJCholeskyCov,
_OrderedMultinomial,
posdef,
Expand Down Expand Up @@ -2121,6 +2122,12 @@ def ref_rand(size, n, eta):
size=1000,
)

def test_default_transform(self):
with pm.Model() as m:
x = pm.LKJCorr("x", n=2, eta=1, shape=(3, 2))
assert isinstance(m.rvs_to_transforms[x], MultivariateIntervalTransform)
assert m.logp(sum=False)[0].shape == (3,)


class TestLKJCholeskyCov(BaseTestDistributionRandom):
pymc_dist = _LKJCholeskyCov
Expand Down