|
41 | 41 |
|
42 | 42 | import pymc as pm
|
43 | 43 |
|
44 |
| -from pymc.aesaraf import floatX, intX |
| 44 | +from pymc.aesaraf import change_rv_size, floatX, intX |
45 | 45 | from pymc.distributions import transforms
|
46 | 46 | from pymc.distributions.continuous import (
|
47 | 47 | BoundedContinuous,
|
@@ -1199,8 +1199,21 @@ def dist(cls, eta, n, sd_dist, size=None, **kwargs):
|
1199 | 1199 | isinstance(sd_dist, Variable)
|
1200 | 1200 | and sd_dist.owner is not None
|
1201 | 1201 | and isinstance(sd_dist.owner.op, RandomVariable)
|
| 1202 | + and sd_dist.owner.op.ndim_supp < 2 |
1202 | 1203 | ):
|
1203 |
| - raise TypeError("sd_dist must be a Distribution variable") |
| 1204 | + raise TypeError("sd_dist must be a scalar or vector distribution variable") |
| 1205 | + |
| 1206 | + # We resize the sd_dist automatically so that it has (size x n) independent draws |
| 1207 | + # which is what the `_LKJCholeskyCovRV.rng_fn` expects. This makes the random |
| 1208 | + # and logp methods equivalent, as the latter also assumes a unique value for each |
| 1209 | + # diagonal element. |
| 1210 | + # Since `eta` and `n` are forced to be scalars we don't need to worry about |
| 1211 | + # implied batched dimensions for the time being. |
| 1212 | + if sd_dist.owner.op.ndim_supp == 0: |
| 1213 | + sd_dist = change_rv_size(sd_dist, to_tuple(size) + (n,)) |
| 1214 | + else: |
| 1215 | + # The support shape must be `n` but we have no way of controlling it |
| 1216 | + sd_dist = change_rv_size(sd_dist, to_tuple(size)) |
1204 | 1217 |
|
1205 | 1218 | # sd_dist is part of the generative graph, but should be completely ignored
|
1206 | 1219 | # by the logp graph, since the LKJ logp explicitly includes these terms.
|
@@ -1288,7 +1301,9 @@ class LKJCholeskyCov:
|
1288 | 1301 | n: int
|
1289 | 1302 | Dimension of the covariance matrix (n > 1).
|
1290 | 1303 | sd_dist: pm.Distribution
|
1291 |
| - A distribution for the standard deviations, should have `size=n`. |
| 1304 | + A positive scalar or vector distribution for the standard deviations, created |
| 1305 | + with the `.dist()` API. Should have `shape[-1]=n`. Scalar distributions will be |
| 1306 | + automatically resized to ensure this. |
1292 | 1307 | compute_corr: bool, default=True
|
1293 | 1308 | If `True`, returns three values: the Cholesky decomposition, the correlations
|
1294 | 1309 | and the standard deviations of the covariance matrix. Otherwise, only returns
|
|
0 commit comments