Skip to content

Commit bc0d8f5

Browse files
Update tests/distributions/test_multivariate.py
Co-authored-by: Ricardo Vieira <[email protected]>
1 parent 301b159 commit bc0d8f5

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tests/distributions/test_multivariate.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -2121,11 +2121,11 @@ def ref_rand(size, n, eta):
21212121
size=1000,
21222122
)
21232123

2124-
@pytest.mark.parametrize(argnames="n", argvalues=[2, 3], ids=["n=2", "n=3"])
2125-
def test_default_transform(self, n):
2124+
def test_default_transform(self):
21262125
with pm.Model() as m:
2127-
pm.LKJCorr("x", n=n, eta=1)
2128-
m.logp()
2126+
x = pm.LKJCorr("x", n=2, eta=1, shape=(3, 2))
2127+
assert isinstance(m.rvs_to_transforms[x], MultivariateIntervalTransform)
2128+
assert m.logp(sum=False)[0].shape == (3,)
21292129

21302130

21312131
class TestLKJCholeskyCov(BaseTestDistributionRandom):

0 commit comments

Comments
 (0)