Skip to content

Commit e67a317

Browse files
committed
change default transform LKJCOrr
1 parent 0fd7b9e commit e67a317

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

pymc/distributions/multivariate.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -1524,6 +1524,13 @@ def _random_corr_matrix(cls, rng, n, eta, flat_size):
15241524
lkjcorr = LKJCorrRV()
15251525

15261526

1527+
class MultivariateIntervalTransform(Interval):
1528+
name = "interval"
1529+
1530+
def log_jac_det(self, *args):
1531+
return super().log_jac_det(*args).sum(-1)
1532+
1533+
15271534
class LKJCorr(BoundedContinuous):
15281535
r"""
15291536
The LKJ (Lewandowski, Kurowicka and Joe) log-likelihood.
@@ -1623,7 +1630,7 @@ def logp(value, n, eta):
16231630

16241631
@_default_transform.register(LKJCorr)
16251632
def lkjcorr_default_transform(op, rv):
1626-
return Interval(floatX(-1.0), floatX(1.0))
1633+
return MultivariateIntervalTransform(floatX(-1.0), floatX(1.0))
16271634

16281635

16291636
class MatrixNormalRV(RandomVariable):

0 commit comments

Comments
 (0)