diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 42a2a9f62c..ca1760c387 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -1524,6 +1524,13 @@ def _random_corr_matrix(cls, rng, n, eta, flat_size): lkjcorr = LKJCorrRV() +class MultivariateIntervalTransform(Interval): + name = "interval" + + def log_jac_det(self, *args): + return super().log_jac_det(*args).sum(-1) + + class LKJCorr(BoundedContinuous): r""" The LKJ (Lewandowski, Kurowicka and Joe) log-likelihood. @@ -1592,6 +1599,9 @@ def logp(value, n, eta): TensorVariable """ + if value.ndim > 1: + raise NotImplementedError("LKJCorr logp is only implemented for vector values (ndim=1)") + # TODO: PyTensor does not have a `triu_indices`, so we can only work with constant # n (or else find a different expression) if not isinstance(n, Constant): @@ -1623,7 +1633,7 @@ def logp(value, n, eta): @_default_transform.register(LKJCorr) def lkjcorr_default_transform(op, rv): - return Interval(floatX(-1.0), floatX(1.0)) + return MultivariateIntervalTransform(floatX(-1.0), floatX(1.0)) class MatrixNormalRV(RandomVariable): diff --git a/pymc/logprob/transform_value.py b/pymc/logprob/transform_value.py index ee37fdca1d..4bcf130f76 100644 --- a/pymc/logprob/transform_value.py +++ b/pymc/logprob/transform_value.py @@ -125,13 +125,12 @@ def transformed_value_logprob(op, values, *rv_outs, use_jacobian=True, **kwargs) raise NotImplementedError( f"Univariate transform {transform} cannot be applied to multivariate {rv_op}" ) - else: - # Check there is no broadcasting between logp and jacobian - if logp.type.broadcastable != log_jac_det.type.broadcastable: - raise ValueError( - f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. " - "There is a bug in the implementation of either one." - ) + # Check there is no broadcasting between logp and jacobian + if logp.type.broadcastable != log_jac_det.type.broadcastable: + raise ValueError( + f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. " + "There is a bug in the implementation of either one." + ) if use_jacobian: if value.name: diff --git a/tests/distributions/test_multivariate.py b/tests/distributions/test_multivariate.py index 027d0b1915..2390804c0f 100644 --- a/tests/distributions/test_multivariate.py +++ b/tests/distributions/test_multivariate.py @@ -13,7 +13,6 @@ # limitations under the License. import functools as ft -import re import warnings import numpy as np @@ -33,6 +32,7 @@ import pymc as pm from pymc.distributions.multivariate import ( + MultivariateIntervalTransform, _LKJCholeskyCov, _OrderedMultinomial, posdef, @@ -1306,8 +1306,26 @@ def test_kronecker_normal_moment(self, mu, covs, size, expected): [ (3, 1, None, np.zeros(3)), (5, 1, None, np.zeros(10)), - (3, 1, 1, np.zeros((1, 3))), - (5, 1, (2, 3), np.zeros((2, 3, 10))), + pytest.param( + 3, + 1, + 1, + np.zeros((1, 3)), + marks=pytest.mark.xfail( + raises=NotImplementedError, + reason="LKJCorr logp is only implemented for vector values (ndim=1)", + ), + ), + pytest.param( + 5, + 1, + (2, 3), + np.zeros((2, 3, 10)), + marks=pytest.mark.xfail( + raises=NotImplementedError, + reason="LKJCorr logp is only implemented for vector values (ndim=1)", + ), + ), ], ) def test_lkjcorr_moment(self, n, eta, size, expected): @@ -2122,6 +2140,26 @@ def ref_rand(size, n, eta): ) +@pytest.mark.parametrize( + argnames="shape", + argvalues=[ + (2,), + pytest.param( + (3, 2), + marks=pytest.mark.xfail( + raises=NotImplementedError, + reason="LKJCorr logp is only implemented for vector values (ndim=1)", + ), + ), + ], +) +def test_LKJCorr_default_transform(shape): + with pm.Model() as m: + x = pm.LKJCorr("x", n=2, eta=1, shape=shape) + assert isinstance(m.rvs_to_transforms[x], MultivariateIntervalTransform) + assert m.logp(sum=False)[0].type.shape == shape[:-1] + + class TestLKJCholeskyCov(BaseTestDistributionRandom): pymc_dist = _LKJCholeskyCov pymc_dist_params = {"n": 3, "eta": 1.0, "sd_dist": pm.DiracDelta.dist([0.5, 1.0, 2.0])}