
Commit cda2371

kc611 authored and michaelosthege committed
Refactored Wishart and MatrixNormal distribution
1 parent de83381 commit cda2371

File tree

3 files changed: +202 -243 lines changed


pymc3/distributions/multivariate.py

+105 -146
@@ -47,6 +47,7 @@
 from pymc3.distributions.continuous import ChiSquared, Normal, assert_negative_support
 from pymc3.distributions.dist_math import bound, factln, logpow, multigammaln
 from pymc3.distributions.distribution import Continuous, Discrete
+from pymc3.distributions.shape_utils import broadcast_dist_samples_to, to_tuple
 from pymc3.math import kron_diag, kron_dot
 
 __all__ = [
@@ -739,6 +740,26 @@ def __str__(self):
 matrix_pos_def = PosDefMatrix()
 
 
+class WishartRV(RandomVariable):
+    name = "wishart"
+    ndim_supp = 2
+    ndims_params = [0, 2]
+    dtype = "floatX"
+    _print_name = ("Wishart", "\\operatorname{Wishart}")
+
+    def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
+        # The shape of second parameter `V` defines the shape of the output.
+        return dist_params[1].shape
+
+    @classmethod
+    def rng_fn(cls, rng, nu, V, size=None):
+        size = size if size else 1  # Default size for Scipy's wishart.rvs is 1
+        return stats.wishart.rvs(np.int(nu), V, size=size, random_state=rng)
+
+
+wishart = WishartRV()
+
+
 class Wishart(Continuous):
     r"""
     Wishart log-likelihood.
@@ -775,9 +796,13 @@ class Wishart(Continuous):
     This distribution is unusable in a PyMC3 model. You should instead
     use LKJCholeskyCov or LKJCorr.
     """
+    rv_op = wishart
+
+    @classmethod
+    def dist(cls, nu, V, *args, **kwargs):
+        nu = at.as_tensor_variable(intX(nu))
+        V = at.as_tensor_variable(floatX(V))
 
-    def __init__(self, nu, V, *args, **kwargs):
-        super().__init__(*args, **kwargs)
         warnings.warn(
             "The Wishart distribution can currently not be used "
             "for MCMC sampling. The probability of sampling a "
@@ -787,34 +812,13 @@ def __init__(self, nu, V, *args, **kwargs):
             "https://github.com/pymc-devs/pymc3/issues/538.",
             UserWarning,
         )
-        self.nu = nu = at.as_tensor_variable(nu)
-        self.p = p = at.as_tensor_variable(V.shape[0])
-        self.V = V = at.as_tensor_variable(V)
-        self.mean = nu * V
-        self.mode = at.switch(at.ge(nu, p + 1), (nu - p - 1) * V, np.nan)
 
-    def random(self, point=None, size=None):
-        """
-        Draw random values from Wishart distribution.
+        # mean = nu * V
+        # p = V.shape[0]
+        # mode = at.switch(at.ge(nu, p + 1), (nu - p - 1) * V, np.nan)
+        return super().dist([nu, V], *args, **kwargs)
 
-        Parameters
-        ----------
-        point: dict, optional
-            Dict of variable values on which random values are to be
-            conditioned (uses default point if not specified).
-        size: int, optional
-            Desired size of random sample (returns one sample if not
-            specified).
-
-        Returns
-        -------
-        array
-        """
-        # nu, V = draw_values([self.nu, self.V], point=point, size=size)
-        # size = 1 if size is None else size
-        # return generate_samples(stats.wishart.rvs, nu.item(), V, broadcast_shape=(size,))
-
-    def logp(self, X):
+    def logp(X, nu, V):
         """
         Calculate log-probability of Wishart distribution
         at specified value.
@@ -828,9 +832,8 @@ def logp(self, X):
         -------
         TensorVariable
         """
-        nu = self.nu
-        p = self.p
-        V = self.V
+
+        p = V.shape[0]
 
         IVI = det(V)
         IXI = det(X)
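
Note: the refactored `WishartRV.rng_fn` above is a thin wrapper around `scipy.stats.wishart.rvs`. A minimal standalone sketch of the same draw (plain NumPy/SciPy, with illustrative values for `nu` and `V` that are not part of this commit):

    import numpy as np
    from scipy import stats

    # Draw from a Wishart with nu degrees of freedom and scale matrix V,
    # mirroring what WishartRV.rng_fn does with the RandomVariable's rng.
    rng = np.random.default_rng(42)
    nu = 5
    V = np.array([[1.0, 0.3], [0.3, 2.0]])  # positive-definite scale matrix

    size = 1  # rng_fn falls back to size=1 when no size is passed
    draw = stats.wishart.rvs(int(nu), V, size=size, random_state=rng)

    print(draw.shape)                             # (2, 2): one p x p draw
    print(np.all(np.linalg.eigvalsh(draw) > 0))   # Wishart draws are positive definite
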
@@ -1445,6 +1448,36 @@ def _distr_parameters_for_repr(self):
         return ["eta", "n"]
 
 
+class MatrixNormalRV(RandomVariable):
+    name = "matrixnormal"
+    ndim_supp = 2
+    ndims_params = [2, 2, 2]
+    dtype = "floatX"
+    _print_name = ("MatrixNormal", "\\operatorname{MatrixNormal}")
+
+    @classmethod
+    def rng_fn(cls, rng, mu, rowchol, colchol, size=None):
+
+        size = to_tuple(size)
+        dist_shape = to_tuple([rowchol.shape[0], colchol.shape[0]])
+        output_shape = size + dist_shape
+
+        # Broadcasting all parameters
+        (mu,) = broadcast_dist_samples_to(to_shape=output_shape, samples=[mu], size=size)
+        rowchol = np.broadcast_to(rowchol, shape=size + rowchol.shape[-2:])
+
+        colchol = np.broadcast_to(colchol, shape=size + colchol.shape[-2:])
+        colchol = np.swapaxes(colchol, -1, -2)  # Take transpose
+
+        standard_normal = rng.standard_normal(output_shape)
+        samples = mu + np.matmul(rowchol, np.matmul(standard_normal, colchol))
+
+        return samples
+
+
+matrixnormal = MatrixNormalRV()
+
+
 class MatrixNormal(Continuous):
     r"""
     Matrix-valued normal log-likelihood.
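
Note: `MatrixNormalRV.rng_fn` above samples through the identity X = M + L_row Z L_col^T, where Z is a matrix of i.i.d. standard normals and L_row, L_col are the Cholesky factors of the row and column covariances. A rough standalone NumPy sketch of that construction (the covariance values below are illustrative only):

    import numpy as np

    rng = np.random.default_rng(0)
    m, n = 3, 2
    mu = np.zeros((m, n))
    rowcov = np.array([[1.0, 0.5, 0.0], [0.5, 2.0, 0.0], [0.0, 0.0, 1.0]])
    colcov = np.array([[1.0, 0.2], [0.2, 1.0]])
    rowchol = np.linalg.cholesky(rowcov)  # lower-triangular L_row
    colchol = np.linalg.cholesky(colcov)  # lower-triangular L_col

    # X = mu + L_row @ Z @ L_col.T, matching rng_fn's
    # mu + matmul(rowchol, matmul(standard_normal, swapaxes(colchol, -1, -2)))
    z = rng.standard_normal((m, n))
    x = mu + rowchol @ z @ colchol.T

    # Sanity check: vec(X) has covariance colcov (x) rowcov (Kronecker structure),
    # so Var(X[0, 0]) should approach rowcov[0, 0] * colcov[0, 0] = 1.0 over many draws.
    draws = mu + rowchol @ rng.standard_normal((10_000, m, n)) @ colchol.T
    print(draws[:, 0, 0].var())
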
@@ -1533,175 +1566,101 @@ class MatrixNormal(Continuous):
         vals = pm.MatrixNormal('vals', mu=mu, colchol=colchol, rowcov=rowcov,
                                observed=data, shape=(m, n))
     """
+    rv_op = matrixnormal
 
-    def __init__(
-        self,
-        mu=0,
+    @classmethod
+    def dist(
+        cls,
+        mu,
         rowcov=None,
         rowchol=None,
-        rowtau=None,
         colcov=None,
         colchol=None,
-        coltau=None,
         shape=None,
         *args,
         **kwargs,
     ):
-        self._setup_matrices(colcov, colchol, coltau, rowcov, rowchol, rowtau)
-        if shape is None:
-            raise TypeError("shape is a required argument")
-        assert len(shape) == 2, "shape must have length 2: mxn"
-        self.shape = shape
-        super().__init__(shape=shape, *args, **kwargs)
-        self.mu = at.as_tensor_variable(mu)
-        self.mean = self.median = self.mode = self.mu
-        self.solve_lower = solve_lower_triangular
-        self.solve_upper = solve_upper_triangular
-
-    def _setup_matrices(self, colcov, colchol, coltau, rowcov, rowchol, rowtau):
+
         cholesky = Cholesky(lower=True, on_error="raise")
 
+        if mu.ndim == 1:
+            raise ValueError(
+                "1x1 Matrix was provided. Please use Normal distribution for such cases."
+            )
+
         # Among-row matrices
-        if len([i for i in [rowtau, rowcov, rowchol] if i is not None]) != 1:
+        if len([i for i in [rowcov, rowchol] if i is not None]) != 1:
             raise ValueError(
-                "Incompatible parameterization. "
-                "Specify exactly one of rowtau, rowcov, "
-                "or rowchol."
+                "Incompatible parameterization. Specify exactly one of rowcov, or rowchol."
             )
         if rowcov is not None:
-            self.m = rowcov.shape[0]
-            self._rowcov_type = "cov"
-            rowcov = at.as_tensor_variable(rowcov)
             if rowcov.ndim != 2:
                 raise ValueError("rowcov must be two dimensional.")
-            self.rowchol_cov = cholesky(rowcov)
-            self.rowcov = rowcov
-        elif rowtau is not None:
-            raise ValueError("rowtau not supported at this time")
-            self.m = rowtau.shape[0]
-            self._rowcov_type = "tau"
-            rowtau = at.as_tensor_variable(rowtau)
-            if rowtau.ndim != 2:
-                raise ValueError("rowtau must be two dimensional.")
-            self.rowchol_tau = cholesky(rowtau)
-            self.rowtau = rowtau
+            rowchol_cov = cholesky(rowcov)
         else:
-            self.m = rowchol.shape[0]
-            self._rowcov_type = "chol"
             if rowchol.ndim != 2:
                 raise ValueError("rowchol must be two dimensional.")
-            self.rowchol_cov = at.as_tensor_variable(rowchol)
+            rowchol_cov = at.as_tensor_variable(rowchol)
 
         # Among-column matrices
-        if len([i for i in [coltau, colcov, colchol] if i is not None]) != 1:
+        if len([i for i in [colcov, colchol] if i is not None]) != 1:
             raise ValueError(
-                "Incompatible parameterization. "
-                "Specify exactly one of coltau, colcov, "
-                "or colchol."
+                "Incompatible parameterization. Specify exactly one of colcov, or colchol."
             )
         if colcov is not None:
-            self.n = colcov.shape[0]
-            self._colcov_type = "cov"
             colcov = at.as_tensor_variable(colcov)
             if colcov.ndim != 2:
                 raise ValueError("colcov must be two dimensional.")
-            self.colchol_cov = cholesky(colcov)
-            self.colcov = colcov
-        elif coltau is not None:
-            raise ValueError("coltau not supported at this time")
-            self.n = coltau.shape[0]
-            self._colcov_type = "tau"
-            coltau = at.as_tensor_variable(coltau)
-            if coltau.ndim != 2:
-                raise ValueError("coltau must be two dimensional.")
-            self.colchol_tau = cholesky(coltau)
-            self.coltau = coltau
+            colchol_cov = cholesky(colcov)
         else:
-            self.n = colchol.shape[0]
-            self._colcov_type = "chol"
             if colchol.ndim != 2:
                 raise ValueError("colchol must be two dimensional.")
-            self.colchol_cov = at.as_tensor_variable(colchol)
+            colchol_cov = at.as_tensor_variable(colchol)
 
-    def random(self, point=None, size=None):
+        mu = at.as_tensor_variable(floatX(mu))
+        # mean = median = mode = mu
+
+        return super().dist([mu, rowchol_cov, colchol_cov], **kwargs)
+
+    def logp(value, mu, rowchol, colchol):
         """
-        Draw random values from Matrix-valued Normal distribution.
+        Calculate log-probability of Matrix-valued Normal distribution
+        at specified value.
 
         Parameters
         ----------
-        point: dict, optional
-            Dict of variable values on which random values are to be
-            conditioned (uses default point if not specified).
-        size: int, optional
-            Desired size of random sample (returns one sample if not
-            specified).
+        value: numeric
+            Value for which log-probability is calculated.
 
         Returns
         -------
-        array
+        TensorVariable
         """
-        # mu, colchol, rowchol = draw_values(
-        #     [self.mu, self.colchol_cov, self.rowchol_cov], point=point, size=size
-        # )
-        # size = to_tuple(size)
-        # dist_shape = to_tuple(self.shape)
-        # output_shape = size + dist_shape
-        #
-        # # Broadcasting all parameters
-        # (mu,) = broadcast_dist_samples_to(to_shape=output_shape, samples=[mu], size=size)
-        # rowchol = np.broadcast_to(rowchol, shape=size + rowchol.shape[-2:])
-        #
-        # colchol = np.broadcast_to(colchol, shape=size + colchol.shape[-2:])
-        # colchol = np.swapaxes(colchol, -1, -2)  # Take transpose
-        #
-        # standard_normal = np.random.standard_normal(output_shape)
-        # samples = mu + np.matmul(rowchol, np.matmul(standard_normal, colchol))
-        # return samples
-
-    def _trquaddist(self, value):
-        """Compute Tr[colcov^-1 @ (x - mu).T @ rowcov^-1 @ (x - mu)] and
-        the logdet of colcov and rowcov."""
 
-        delta = value - self.mu
-        rowchol_cov = self.rowchol_cov
-        colchol_cov = self.colchol_cov
+        # Compute Tr[colcov^-1 @ (x - mu).T @ rowcov^-1 @ (x - mu)] and
+        # the logdet of colcov and rowcov.
+        delta = value - mu
 
         # Find exponent piece by piece
-        right_quaddist = self.solve_lower(rowchol_cov, delta)
+        right_quaddist = solve_lower_triangular(rowchol, delta)
         quaddist = at.nlinalg.matrix_dot(right_quaddist.T, right_quaddist)
-        quaddist = self.solve_lower(colchol_cov, quaddist)
-        quaddist = self.solve_upper(colchol_cov.T, quaddist)
+        quaddist = solve_lower_triangular(colchol, quaddist)
+        quaddist = solve_upper_triangular(colchol.T, quaddist)
         trquaddist = at.nlinalg.trace(quaddist)
 
-        coldiag = at.diag(colchol_cov)
-        rowdiag = at.diag(rowchol_cov)
+        coldiag = at.diag(colchol)
+        rowdiag = at.diag(rowchol)
         half_collogdet = at.sum(at.log(coldiag))  # logdet(M) = 2*Tr(log(L))
         half_rowlogdet = at.sum(at.log(rowdiag))  # Using Cholesky: M = L L^T
-        return trquaddist, half_collogdet, half_rowlogdet
-
-    def logp(self, value):
-        """
-        Calculate log-probability of Matrix-valued Normal distribution
-        at specified value.
 
-        Parameters
-        ----------
-        value: numeric
-            Value for which log-probability is calculated.
+        m = rowchol.shape[0]
+        n = colchol.shape[0]
 
-        Returns
-        -------
-        TensorVariable
-        """
-        trquaddist, half_collogdet, half_rowlogdet = self._trquaddist(value)
-        m = self.m
-        n = self.n
         norm = -0.5 * m * n * pm.floatX(np.log(2 * np.pi))
         return norm - 0.5 * trquaddist - m * half_collogdet - n * half_rowlogdet
 
     def _distr_parameters_for_repr(self):
-        mapping = {"tau": "tau", "cov": "cov", "chol": "chol_cov"}
-        return ["mu", "row" + mapping[self._rowcov_type], "col" + mapping[self._colcov_type]]
+        return ["mu"]
 
 class KroneckerNormalRV(RandomVariable):
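
Note: the refactored `MatrixNormal.logp` above builds the trace term from triangular solves against the row and column Cholesky factors and takes the half log-determinants from their diagonals. A standalone NumPy/SciPy sketch of the same computation, cross-checked against `scipy.stats.matrix_normal` (all values below are illustrative, not part of the commit):

    import numpy as np
    from scipy import stats
    from scipy.linalg import solve_triangular

    rng = np.random.default_rng(1)
    m, n = 3, 2
    mu = rng.normal(size=(m, n))
    rowcov = np.array([[1.0, 0.5, 0.0], [0.5, 2.0, 0.0], [0.0, 0.0, 1.0]])
    colcov = np.array([[1.0, 0.2], [0.2, 1.0]])
    rowchol = np.linalg.cholesky(rowcov)
    colchol = np.linalg.cholesky(colcov)
    value = rng.normal(size=(m, n))

    # Tr[colcov^-1 @ (x - mu).T @ rowcov^-1 @ (x - mu)] via triangular solves.
    delta = value - mu
    right_quaddist = solve_triangular(rowchol, delta, lower=True)
    quaddist = right_quaddist.T @ right_quaddist
    quaddist = solve_triangular(colchol, quaddist, lower=True)
    quaddist = solve_triangular(colchol.T, quaddist, lower=False)
    trquaddist = np.trace(quaddist)

    # logdet(M) = 2 * sum(log(diag(L))) when M = L @ L.T
    half_collogdet = np.sum(np.log(np.diag(colchol)))
    half_rowlogdet = np.sum(np.log(np.diag(rowchol)))

    norm = -0.5 * m * n * np.log(2 * np.pi)
    logp = norm - 0.5 * trquaddist - m * half_collogdet - n * half_rowlogdet

    ref = stats.matrix_normal.logpdf(value, mean=mu, rowcov=rowcov, colcov=colcov)
    print(np.isclose(logp, ref))  # expected: True
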
