Skip to content

Commit 1a35a3d

Browse files
committed
Refactor LKJCholeskyCov for V4
Changes: * compute_corr now defaults to True * LKJCholeskyCov now also provides a `.dist` interface
1 parent eed60c3 commit 1a35a3d

9 files changed

+248
-172
lines changed

RELEASE-NOTES.md

+1
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ All of the above apply to:
8787
- ArviZ `plots` and `stats` *wrappers* were removed. The functions are now just available by their original names (see [#4549](https://github.com/pymc-devs/pymc/pull/4471) and `3.11.2` release notes).
8888
- `pm.sample_posterior_predictive(vars=...)` kwarg was removed in favor of `var_names` (see [#4343](https://github.com/pymc-devs/pymc/pull/4343)).
8989
- `ElemwiseCategorical` step method was removed (see [#4701](https://github.com/pymc-devs/pymc/pull/4701))
90+
- `LKJCholeskyCov` `compute_corr` keyword argument is now set to `True` by default (see[#5382](https://github.com/pymc-devs/pymc/pull/5382))
9091

9192
### Ongoing deprecations
9293
- Old API still works in `v4` and has a deprecation warning.

pymc/distributions/multivariate.py

+137-158
Large diffs are not rendered by default.

pymc/distributions/transforms.py

+17-9
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
RVTransform,
2323
Simplex,
2424
)
25-
from aesara.tensor.subtensor import advanced_set_subtensor1
2625

2726
__all__ = [
2827
"RVTransform",
@@ -97,22 +96,31 @@ def log_jac_det(self, value, *inputs):
9796

9897

9998
class CholeskyCovPacked(RVTransform):
99+
"""
100+
Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the
101+
log scale
102+
"""
103+
100104
name = "cholesky-cov-packed"
101105

102-
def __init__(self, param_extract_fn):
103-
self.param_extract_fn = param_extract_fn
106+
def __init__(self, n):
107+
"""
108+
109+
Parameters
110+
----------
111+
n: int
112+
Number of diagonal entries in the LKJCholeskyCov distribution
113+
"""
114+
self.diag_idxs = at.arange(1, n + 1).cumsum() - 1
104115

105116
def backward(self, value, *inputs):
106-
diag_idxs = self.param_extract_fn(inputs)
107-
return advanced_set_subtensor1(value, at.exp(value[diag_idxs]), diag_idxs)
117+
return at.set_subtensor(value[..., self.diag_idxs], at.exp(value[..., self.diag_idxs]))
108118

109119
def forward(self, value, *inputs):
110-
diag_idxs = self.param_extract_fn(inputs)
111-
return advanced_set_subtensor1(value, at.log(value[diag_idxs]), diag_idxs)
120+
return at.set_subtensor(value[..., self.diag_idxs], at.log(value[..., self.diag_idxs]))
112121

113122
def log_jac_det(self, value, *inputs):
114-
diag_idxs = self.param_extract_fn(inputs)
115-
return at.sum(value[diag_idxs])
123+
return at.sum(value[..., self.diag_idxs], axis=-1)
116124

117125

118126
class Chain(RVTransform):

pymc/tests/sampler_fixtures.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,9 @@ def make_model(cls):
122122
with pm.Model() as model:
123123
sd_mu = np.array([1, 2, 3, 4, 5])
124124
sd_dist = pm.LogNormal.dist(mu=sd_mu, sigma=sd_mu / 10.0, size=5)
125-
chol_packed = pm.LKJCholeskyCov("chol_packed", eta=3, n=5, sd_dist=sd_dist)
125+
chol_packed = pm.LKJCholeskyCov(
126+
"chol_packed", eta=3, n=5, sd_dist=sd_dist, compute_corr=False
127+
)
126128
chol = pm.expand_packed_triangular(5, chol_packed, lower=True)
127129
cov = at.dot(chol, chol.T)
128130
stds = at.sqrt(at.diag(cov))

pymc/tests/test_distributions.py

+37
Original file line numberDiff line numberDiff line change
@@ -3352,3 +3352,40 @@ def test_censored_invalid_dist(self):
33523352
match="The dist dist was already registered in the current model",
33533353
):
33543354
x = pm.Censored("x", registered_dist, lower=None, upper=None)
3355+
3356+
3357+
class TestLKJCholeskCov:
3358+
def test_dist(self):
3359+
sd_dist = pm.Exponential.dist(1, size=(10, 3))
3360+
x = pm.LKJCholeskyCov.dist(n=3, eta=1, sd_dist=sd_dist, size=10, compute_corr=False)
3361+
assert x.eval().shape == (10, 6)
3362+
3363+
sd_dist = pm.Exponential.dist(1, size=3)
3364+
chol, corr, stds = pm.LKJCholeskyCov.dist(n=3, eta=1, sd_dist=sd_dist)
3365+
assert chol.eval().shape == (3, 3)
3366+
assert corr.eval().shape == (3, 3)
3367+
assert stds.eval().shape == (3,)
3368+
3369+
def test_sd_dist_distribution(self):
3370+
with pm.Model() as m:
3371+
sd_dist = at.constant([1, 2, 3])
3372+
with pytest.raises(TypeError, match="sd_dist must be a Distribution variable"):
3373+
x = pm.LKJCholeskyCov("x", n=3, eta=1, sd_dist=sd_dist)
3374+
3375+
def test_sd_dist_registered(self):
3376+
with pm.Model() as m:
3377+
sd_dist = pm.Exponential("sd_dist", 1, size=3)
3378+
with pytest.raises(
3379+
ValueError, match="The dist sd_dist was already registered in the current model"
3380+
):
3381+
x = pm.LKJCholeskyCov("x", n=3, eta=1, sd_dist=sd_dist)
3382+
3383+
def test_no_warning_logp(self):
3384+
# Check that calling logp of a model with LKJCholeskyCov does not issue any warnings
3385+
# due to the RandomVariable in the graph
3386+
with pm.Model() as m:
3387+
sd_dist = pm.Exponential.dist(1, size=3)
3388+
x = pm.LKJCholeskyCov("x", n=3, eta=1, sd_dist=sd_dist)
3389+
with pytest.warns(None) as record:
3390+
m.logpt()
3391+
assert not record

pymc/tests/test_distributions_random.py

+50-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,11 @@ def random_polyagamma(*args, **kwargs):
4747
from pymc.distributions.discrete import _OrderedLogistic, _OrderedProbit
4848
from pymc.distributions.dist_math import clipped_beta_rvs
4949
from pymc.distributions.logprob import logp
50-
from pymc.distributions.multivariate import _OrderedMultinomial, quaddist_matrix
50+
from pymc.distributions.multivariate import (
51+
_LKJCholeskyCov,
52+
_OrderedMultinomial,
53+
quaddist_matrix,
54+
)
5155
from pymc.distributions.shape_utils import to_tuple
5256
from pymc.tests.helpers import SeededTest, select_by_precision
5357
from pymc.tests.test_distributions import (
@@ -1867,6 +1871,43 @@ def ref_rand(size, n, eta):
18671871
)
18681872

18691873

1874+
class TestLKJCholeskyCov(BaseTestDistributionRandom):
1875+
pymc_dist = _LKJCholeskyCov
1876+
pymc_dist_params = {"eta": 1.0, "n": 3, "sd_dist": pm.Constant.dist([0.5, 1.0, 2.0])}
1877+
expected_rv_op_params = {"n": 3, "eta": 1.0, "sd_dist": pm.Constant.dist([0.5, 1.0, 2.0])}
1878+
size = None
1879+
1880+
sizes_to_check = [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)]
1881+
sizes_expected = [
1882+
(6,),
1883+
(6,),
1884+
(1, 6),
1885+
(1, 6),
1886+
(5, 6),
1887+
(4, 5, 6),
1888+
(2, 4, 2, 6),
1889+
]
1890+
1891+
tests_to_run = [
1892+
"check_rv_size",
1893+
"check_draws_match_expected",
1894+
]
1895+
1896+
def check_rv_size(self):
1897+
for size, expected in zip(self.sizes_to_check, self.sizes_expected):
1898+
sd_dist = pm.Exponential.dist(1, size=(*to_tuple(size), 3))
1899+
pymc_rv = self.pymc_dist.dist(n=3, eta=1, sd_dist=sd_dist, size=size)
1900+
expected_symbolic = tuple(pymc_rv.shape.eval())
1901+
actual = pymc_rv.eval().shape
1902+
assert actual == expected_symbolic == expected
1903+
1904+
def check_draws_match_expected(self):
1905+
# TODO: Find better comparison:
1906+
rng = aesara.shared(self.get_random_state(reset=True))
1907+
x = _LKJCholeskyCov.dist(n=2, eta=10_000, sd_dist=pm.Constant.dist([0.5, 2.0]), rng=rng)
1908+
assert np.all(np.abs(x.eval() - np.array([0.5, 0, 2.0])) < 0.01)
1909+
1910+
18701911
class TestScalarParameterSamples(SeededTest):
18711912
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
18721913
def test_normalmixture(self):
@@ -2346,9 +2387,11 @@ def test_with_chol_rv(self, sample_shape, dist_shape, mu_shape):
23462387
with pm.Model() as model:
23472388
mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape)
23482389
sd_dist = pm.Exponential.dist(1.0, shape=3)
2390+
# pylint: disable=unpacking-non-sequence
23492391
chol, corr, stds = pm.LKJCholeskyCov(
23502392
"chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
23512393
)
2394+
# pylint: enable=unpacking-non-sequence
23522395
mv = pm.MvNormal("mv", mu, chol=chol, shape=dist_shape)
23532396
prior = pm.sample_prior_predictive(samples=sample_shape)
23542397

@@ -2363,9 +2406,11 @@ def test_with_cov_rv(self, sample_shape, dist_shape, mu_shape):
23632406
with pm.Model() as model:
23642407
mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape)
23652408
sd_dist = pm.Exponential.dist(1.0, shape=3)
2409+
# pylint: disable=unpacking-non-sequence
23662410
chol, corr, stds = pm.LKJCholeskyCov(
23672411
"chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
23682412
)
2413+
# pylint: enable=unpacking-non-sequence
23692414
mv = pm.MvNormal("mv", mu, cov=pm.math.dot(chol, chol.T), shape=dist_shape)
23702415
prior = pm.sample_prior_predictive(samples=sample_shape)
23712416

@@ -2457,9 +2502,11 @@ def test_with_chol_rv(self, sample_shape, dist_shape, mu_shape):
24572502
with pm.Model() as model:
24582503
mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape)
24592504
sd_dist = pm.Exponential.dist(1.0, shape=3)
2505+
# pylint: disable=unpacking-non-sequence
24602506
chol, corr, stds = pm.LKJCholeskyCov(
24612507
"chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
24622508
)
2509+
# pylint: enable=unpacking-non-sequence
24632510
mv = pm.MvGaussianRandomWalk("mv", mu, chol=chol, shape=dist_shape)
24642511
prior = pm.sample_prior_predictive(samples=sample_shape)
24652512

@@ -2475,9 +2522,11 @@ def test_with_cov_rv(self, sample_shape, dist_shape, mu_shape):
24752522
with pm.Model() as model:
24762523
mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape)
24772524
sd_dist = pm.Exponential.dist(1.0, shape=3)
2525+
# pylint: disable=unpacking-non-sequence
24782526
chol, corr, stds = pm.LKJCholeskyCov(
24792527
"chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
24802528
)
2529+
# pylint: enable=unpacking-non-sequence
24812530
mv = pm.MvGaussianRandomWalk("mv", mu, cov=pm.math.dot(chol, chol.T), shape=dist_shape)
24822531
prior = pm.sample_prior_predictive(samples=sample_shape)
24832532

pymc/tests/test_idata_conversion.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -332,15 +332,16 @@ def test_missing_data_model(self):
332332
assert inference_data.log_likelihood["y_observed"].shape == (2, 100, 3)
333333

334334
@pytest.mark.xfal(reason="Multivariate partial observed RVs not implemented for V4")
335-
@pytest.mark.xfail(reason="LKJCholeskyCov not refactored for v4")
336335
def test_mv_missing_data_model(self):
337336
data = ma.masked_values([[1, 2], [2, 2], [-1, 4], [2, -1], [-1, -1]], value=-1)
338337

339338
model = pm.Model()
340339
with model:
341340
mu = pm.Normal("mu", 0, 1, size=2)
342341
sd_dist = pm.HalfNormal.dist(1.0)
342+
# pylint: disable=unpacking-non-sequence
343343
chol, *_ = pm.LKJCholeskyCov("chol_cov", n=2, eta=1, sd_dist=sd_dist, compute_corr=True)
344+
# pylint: enable=unpacking-non-sequence
344345
y = pm.MvNormal("y", mu=mu, chol=chol, observed=data)
345346
inference_data = pm.sample(100, chains=2, return_inferencedata=True)
346347

pymc/tests/test_mixture.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ def build_toy_dataset(N, K):
368368
mu.append(pm.Normal("mu%i" % i, 0, 10, shape=D))
369369
packed_chol.append(
370370
pm.LKJCholeskyCov(
371-
"chol_cov_%i" % i, eta=2, n=D, sd_dist=pm.HalfNormal.dist(2.5)
371+
"chol_cov_%i" % i, eta=2, n=D, sd_dist=pm.HalfNormal.dist(2.5, size=D)
372372
)
373373
)
374374
chol.append(pm.expand_packed_triangular(D, packed_chol[i], lower=True))

pymc/tests/test_posteriors.py

-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ class TestNUTSNormalLong(sf.NutsFixture, sf.NormalFixture):
9595
atol = 0.001
9696

9797

98-
@pytest.mark.xfail(reason="LKJCholeskyCov not refactored for v4")
9998
class TestNUTSLKJCholeskyCov(sf.NutsFixture, sf.LKJCholeskyCovFixture):
10099
n_samples = 2000
101100
tune = 1000

0 commit comments

Comments
 (0)