Skip to content

Commit 9109c65

Browse files
committed
Automatically resize sd_dist in _LKJCholeskyCov
1 parent a92a414 commit 9109c65

File tree

2 files changed

+33
-4
lines changed

2 files changed

+33
-4
lines changed

pymc/distributions/multivariate.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141

4242
import pymc as pm
4343

44-
from pymc.aesaraf import floatX, intX
44+
from pymc.aesaraf import change_rv_size, floatX, intX
4545
from pymc.distributions import transforms
4646
from pymc.distributions.continuous import (
4747
BoundedContinuous,
@@ -1199,8 +1199,21 @@ def dist(cls, eta, n, sd_dist, size=None, **kwargs):
11991199
isinstance(sd_dist, Variable)
12001200
and sd_dist.owner is not None
12011201
and isinstance(sd_dist.owner.op, RandomVariable)
1202+
and sd_dist.owner.op.ndim_supp < 2
12021203
):
1203-
raise TypeError("sd_dist must be a Distribution variable")
1204+
raise TypeError("sd_dist must be a scalar or vector distribution variable")
1205+
1206+
# We resize the sd_dist automatically so that it has (size x n) independent draws
1207+
# which is what the `_LKJCholeskyCovRV.rng_fn` expects. This makes the random
1208+
# and logp methods equivalent, as the latter also assumes a unique value for each
1209+
# diagonal element.
1210+
# Since `eta` and `n` are forced to be scalars we don't need to worry about
1211+
# implied batched dimensions for the time being.
1212+
if sd_dist.owner.op.ndim_supp == 0:
1213+
sd_dist = change_rv_size(sd_dist, to_tuple(size) + (n,))
1214+
else:
1215+
# The support shape must be `n` but we have no way of controlling it
1216+
sd_dist = change_rv_size(sd_dist, to_tuple(size))
12041217

12051218
# sd_dist is part of the generative graph, but should be completely ignored
12061219
# by the logp graph, since the LKJ logp explicitly includes these terms.
@@ -1288,7 +1301,9 @@ class LKJCholeskyCov:
12881301
n: int
12891302
Dimension of the covariance matrix (n > 1).
12901303
sd_dist: pm.Distribution
1291-
A distribution for the standard deviations, should have `size=n`.
1304+
A positive scalar or vector distribution for the standard deviations, created
1305+
with the `.dist()` API. Should have `shape[-1]=n`. Scalar distributions will be
1306+
automatically resized to ensure this.
12921307
compute_corr: bool, default=True
12931308
If `True`, returns three values: the Cholesky decomposition, the correlations
12941309
and the standard deviations of the covariance matrix. Otherwise, only returns

pymc/tests/test_distributions.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -3382,7 +3382,7 @@ def test_dist(self):
33823382
def test_sd_dist_distribution(self):
33833383
with pm.Model() as m:
33843384
sd_dist = at.constant([1, 2, 3])
3385-
with pytest.raises(TypeError, match="sd_dist must be a Distribution variable"):
3385+
with pytest.raises(TypeError, match="^sd_dist must be a scalar or vector distribution"):
33863386
x = pm.LKJCholeskyCov("x", n=3, eta=1, sd_dist=sd_dist)
33873387

33883388
def test_sd_dist_registered(self):
@@ -3402,3 +3402,17 @@ def test_no_warning_logp(self):
34023402
with pytest.warns(None) as record:
34033403
m.logpt()
34043404
assert not record
3405+
3406+
@pytest.mark.parametrize(
3407+
"sd_dist",
3408+
[
3409+
pm.Exponential.dist(1),
3410+
pm.MvNormal.dist(np.ones(3), np.eye(3)),
3411+
],
3412+
)
3413+
def test_sd_dist_automatically_resized(self, sd_dist):
3414+
x = pm.LKJCholeskyCov.dist(n=3, eta=1, sd_dist=sd_dist, size=10, compute_corr=False)
3415+
resized_sd_dist = x.owner.inputs[-1]
3416+
assert resized_sd_dist.eval().shape == (10, 3)
3417+
# LKJCov has support shape `(n * (n+1)) // 2`
3418+
assert x.eval().shape == (10, 6)

0 commit comments

Comments
 (0)