Skip to content

Fix failing default transform for LKJCorr #7065

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Dec 13, 2023
12 changes: 11 additions & 1 deletion pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1524,6 +1524,13 @@ def _random_corr_matrix(cls, rng, n, eta, flat_size):
lkjcorr = LKJCorrRV()


class MultivariateIntervalTransform(Interval):
name = "interval"

def log_jac_det(self, *args):
return super().log_jac_det(*args).sum(-1)


class LKJCorr(BoundedContinuous):
r"""
The LKJ (Lewandowski, Kurowicka and Joe) log-likelihood.
Expand Down Expand Up @@ -1592,6 +1599,9 @@ def logp(value, n, eta):
TensorVariable
"""

if value.ndim > 1:
raise NotImplementedError("LKJCorr logp is only implemented for vector values (ndim=1)")

# TODO: PyTensor does not have a `triu_indices`, so we can only work with constant
# n (or else find a different expression)
if not isinstance(n, Constant):
Expand Down Expand Up @@ -1623,7 +1633,7 @@ def logp(value, n, eta):

@_default_transform.register(LKJCorr)
def lkjcorr_default_transform(op, rv):
return Interval(floatX(-1.0), floatX(1.0))
return MultivariateIntervalTransform(floatX(-1.0), floatX(1.0))


class MatrixNormalRV(RandomVariable):
Expand Down
13 changes: 6 additions & 7 deletions pymc/logprob/transform_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,12 @@ def transformed_value_logprob(op, values, *rv_outs, use_jacobian=True, **kwargs)
raise NotImplementedError(
f"Univariate transform {transform} cannot be applied to multivariate {rv_op}"
)
else:
# Check there is no broadcasting between logp and jacobian
if logp.type.broadcastable != log_jac_det.type.broadcastable:
raise ValueError(
f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. "
"There is a bug in the implementation of either one."
)
# Check there is no broadcasting between logp and jacobian
if logp.type.broadcastable != log_jac_det.type.broadcastable:
raise ValueError(
f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. "
"There is a bug in the implementation of either one."
)

if use_jacobian:
if value.name:
Expand Down
44 changes: 41 additions & 3 deletions tests/distributions/test_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import functools as ft
import re
import warnings

import numpy as np
Expand All @@ -33,6 +32,7 @@
import pymc as pm

from pymc.distributions.multivariate import (
MultivariateIntervalTransform,
_LKJCholeskyCov,
_OrderedMultinomial,
posdef,
Expand Down Expand Up @@ -1306,8 +1306,26 @@ def test_kronecker_normal_moment(self, mu, covs, size, expected):
[
(3, 1, None, np.zeros(3)),
(5, 1, None, np.zeros(10)),
(3, 1, 1, np.zeros((1, 3))),
(5, 1, (2, 3), np.zeros((2, 3, 10))),
pytest.param(
3,
1,
1,
np.zeros((1, 3)),
marks=pytest.mark.xfail(
raises=NotImplementedError,
reason="LKJCorr logp is only implemented for vector values (ndim=1)",
),
),
pytest.param(
5,
1,
(2, 3),
np.zeros((2, 3, 10)),
marks=pytest.mark.xfail(
raises=NotImplementedError,
reason="LKJCorr logp is only implemented for vector values (ndim=1)",
),
),
],
)
def test_lkjcorr_moment(self, n, eta, size, expected):
Expand Down Expand Up @@ -2122,6 +2140,26 @@ def ref_rand(size, n, eta):
)


@pytest.mark.parametrize(
argnames="shape",
argvalues=[
(2,),
pytest.param(
(3, 2),
marks=pytest.mark.xfail(
raises=NotImplementedError,
reason="LKJCorr logp is only implemented for vector values (ndim=1)",
),
),
],
)
def test_LKJCorr_default_transform(shape):
with pm.Model() as m:
x = pm.LKJCorr("x", n=2, eta=1, shape=shape)
assert isinstance(m.rvs_to_transforms[x], MultivariateIntervalTransform)
assert m.logp(sum=False)[0].type.shape == shape[:-1]


class TestLKJCholeskyCov(BaseTestDistributionRandom):
pymc_dist = _LKJCholeskyCov
pymc_dist_params = {"n": 3, "eta": 1.0, "sd_dist": pm.DiracDelta.dist([0.5, 1.0, 2.0])}
Expand Down