Skip to content

Commit 77ce200

Browse files
kc611ricardoV94
andauthored
Porting kroneckernormal distribution to v4 (#4774)
Co-authored-by: Ricardo <[email protected]> Co-authored-by: Ricardo <[email protected]>
1 parent c9a2b40 commit 77ce200

File tree

4 files changed

+112
-200
lines changed

4 files changed

+112
-200
lines changed

pymc3/distributions/multivariate.py

+74-127
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import warnings
1919

20+
from functools import reduce
21+
2022
import aesara
2123
import aesara.tensor as at
2224
import numpy as np
@@ -45,7 +47,7 @@
4547
from pymc3.distributions.continuous import ChiSquared, Normal, assert_negative_support
4648
from pymc3.distributions.dist_math import bound, factln, logpow, multigammaln
4749
from pymc3.distributions.distribution import Continuous, Discrete
48-
from pymc3.math import kron_diag, kron_dot, kron_solve_lower, kronecker
50+
from pymc3.math import kron_diag, kron_dot
4951

5052
__all__ = [
5153
"MvNormal",
@@ -1702,6 +1704,32 @@ def _distr_parameters_for_repr(self):
17021704
return ["mu", "row" + mapping[self._rowcov_type], "col" + mapping[self._colcov_type]]
17031705

17041706

1707+
class KroneckerNormalRV(RandomVariable):
1708+
name = "kroneckernormal"
1709+
ndim_supp = 2
1710+
ndims_params = [1, 0, 2]
1711+
dtype = "floatX"
1712+
_print_name = ("KroneckerNormal", "\\operatorname{KroneckerNormal}")
1713+
1714+
def _shape_from_params(self, dist_params, rep_param_idx=0, param_shapes=None):
1715+
return default_shape_from_params(1, dist_params, rep_param_idx, param_shapes)
1716+
1717+
def rng_fn(self, rng, mu, sigma, *covs, size=None):
1718+
size = size if size else covs[-1]
1719+
covs = covs[:-1] if covs[-1] == size else covs
1720+
1721+
cov = reduce(linalg.kron, covs)
1722+
1723+
if sigma:
1724+
cov = cov + sigma ** 2 * np.eye(cov.shape[0])
1725+
1726+
x = multivariate_normal.rng_fn(rng=rng, mean=mu, cov=cov, size=size)
1727+
return x
1728+
1729+
1730+
kroneckernormal = KroneckerNormalRV()
1731+
1732+
17051733
class KroneckerNormal(Continuous):
17061734
r"""
17071735
Multivariate normal log-likelihood with Kronecker-structured covariance.
@@ -1790,160 +1818,79 @@ class KroneckerNormal(Continuous):
17901818
----------
17911819
.. [1] Saatchi, Y. (2011). "Scalable inference for structured Gaussian process models"
17921820
"""
1821+
rv_op = kroneckernormal
17931822

1794-
def __init__(self, mu, covs=None, chols=None, evds=None, sigma=None, *args, **kwargs):
1795-
self._setup(covs, chols, evds, sigma)
1796-
super().__init__(*args, **kwargs)
1797-
self.mu = at.as_tensor_variable(mu)
1798-
self.mean = self.median = self.mode = self.mu
1823+
@classmethod
1824+
def dist(cls, mu, covs=None, chols=None, evds=None, sigma=None, *args, **kwargs):
17991825

1800-
def _setup(self, covs, chols, evds, sigma):
1801-
self.cholesky = Cholesky(lower=True, on_error="raise")
18021826
if len([i for i in [covs, chols, evds] if i is not None]) != 1:
18031827
raise ValueError(
18041828
"Incompatible parameterization. Specify exactly one of covs, chols, or evds."
18051829
)
1806-
self._isEVD = False
1807-
self.sigma = sigma
1808-
self.is_noisy = self.sigma is not None and self.sigma != 0
1809-
if covs is not None:
1810-
self._cov_type = "cov"
1811-
self.covs = covs
1812-
if self.is_noisy:
1813-
# Noise requires eigendecomposition
1814-
eigh_map = map(eigh, covs)
1815-
self._setup_evd(eigh_map)
1816-
else:
1817-
# Otherwise use cholesky as usual
1818-
self.chols = list(map(self.cholesky, self.covs))
1819-
self.chol_diags = list(map(at.diag, self.chols))
1820-
self.sizes = at.as_tensor_variable([chol.shape[0] for chol in self.chols])
1821-
self.N = at.prod(self.sizes)
1822-
elif chols is not None:
1823-
self._cov_type = "chol"
1824-
if self.is_noisy: # A strange case...
1825-
# Noise requires eigendecomposition
1826-
covs = [at.dot(chol, chol.T) for chol in chols]
1827-
eigh_map = map(eigh, covs)
1828-
self._setup_evd(eigh_map)
1829-
else:
1830-
self.chols = chols
1831-
self.chol_diags = list(map(at.diag, self.chols))
1832-
self.sizes = at.as_tensor_variable([chol.shape[0] for chol in self.chols])
1833-
self.N = at.prod(self.sizes)
1834-
else:
1835-
self._cov_type = "evd"
1836-
self._setup_evd(evds)
18371830

1838-
def _setup_evd(self, eigh_iterable):
1839-
self._isEVD = True
1840-
eigs_sep, Qs = zip(*eigh_iterable) # Unzip
1841-
self.Qs = list(map(at.as_tensor_variable, Qs))
1842-
self.QTs = list(map(at.transpose, self.Qs))
1843-
1844-
self.eigs_sep = list(map(at.as_tensor_variable, eigs_sep))
1845-
self.eigs = kron_diag(*self.eigs_sep) # Combine separate eigs
1846-
if self.is_noisy:
1847-
self.eigs += self.sigma ** 2
1848-
self.N = self.eigs.shape[0]
1849-
1850-
def _setup_random(self):
1851-
if not hasattr(self, "mv_params"):
1852-
self.mv_params = {"mu": self.mu}
1853-
if self._cov_type == "cov":
1854-
cov = kronecker(*self.covs)
1855-
if self.is_noisy:
1856-
cov = cov + self.sigma ** 2 * at.identity_like(cov)
1857-
self.mv_params["cov"] = cov
1858-
elif self._cov_type == "chol":
1859-
if self.is_noisy:
1860-
covs = []
1861-
for eig, Q in zip(self.eigs_sep, self.Qs):
1862-
cov_i = at.dot(Q, at.dot(at.diag(eig), Q.T))
1863-
covs.append(cov_i)
1864-
cov = kronecker(*covs)
1865-
if self.is_noisy:
1866-
cov = cov + self.sigma ** 2 * at.identity_like(cov)
1867-
self.mv_params["chol"] = self.cholesky(cov)
1868-
else:
1869-
self.mv_params["chol"] = kronecker(*self.chols)
1870-
elif self._cov_type == "evd":
1871-
covs = []
1872-
for eig, Q in zip(self.eigs_sep, self.Qs):
1873-
cov_i = at.dot(Q, at.dot(at.diag(eig), Q.T))
1874-
covs.append(cov_i)
1875-
cov = kronecker(*covs)
1876-
if self.is_noisy:
1877-
cov = cov + self.sigma ** 2 * at.identity_like(cov)
1878-
self.mv_params["cov"] = cov
1831+
sigma = sigma if sigma else 0
18791832

1880-
def random(self, point=None, size=None):
1833+
if chols is not None:
1834+
covs = [chol.dot(chol.T) for chol in chols]
1835+
elif evds is not None:
1836+
eigh_iterable = evds
1837+
covs = []
1838+
eigs_sep, Qs = zip(*eigh_iterable) # Unzip
1839+
for eig, Q in zip(eigs_sep, Qs):
1840+
cov_i = at.dot(Q, at.dot(at.diag(eig), Q.T))
1841+
covs.append(cov_i)
1842+
1843+
mu = at.as_tensor_variable(mu)
1844+
1845+
# mean = median = mode = mu
1846+
return super().dist([mu, sigma, *covs], **kwargs)
1847+
1848+
def logp(value, mu, sigma, *covs):
18811849
"""
1882-
Draw random values from Multivariate Normal distribution
1883-
with Kronecker-structured covariance.
1850+
Calculate log-probability of Multivariate Normal distribution
1851+
with Kronecker-structured covariance at specified value.
18841852
18851853
Parameters
18861854
----------
1887-
point: dict, optional
1888-
Dict of variable values on which random values are to be
1889-
conditioned (uses default point if not specified).
1890-
size: int, optional
1891-
Desired size of random sample (returns one sample if not
1892-
specified).
1855+
value: numeric
1856+
Value for which log-probability is calculated.
18931857
18941858
Returns
18951859
-------
1896-
array
1860+
TensorVariable
18971861
"""
1898-
# Expand params into terms MvNormal can understand to force consistency
1899-
self._setup_random()
1900-
self.mv_params["shape"] = self.shape
1901-
dist = MvNormal.dist(**self.mv_params)
1902-
return dist.random(point, size)
1903-
1904-
def _quaddist(self, value):
1905-
"""Computes the quadratic (x-mu)^T @ K^-1 @ (x-mu) and log(det(K))"""
1862+
# Computes the quadratic (x-mu)^T @ K^-1 @ (x-mu) and log(det(K))
19061863
if value.ndim > 2 or value.ndim == 0:
1907-
raise ValueError("Invalid dimension for value: %s" % value.ndim)
1864+
raise ValueError(f"Invalid dimension for value: {value.ndim}")
19081865
if value.ndim == 1:
19091866
onedim = True
19101867
value = value[None, :]
19111868
else:
19121869
onedim = False
19131870

1914-
delta = value - self.mu
1915-
if self._isEVD:
1916-
sqrt_quad = kron_dot(self.QTs, delta.T)
1917-
sqrt_quad = sqrt_quad / at.sqrt(self.eigs[:, None])
1918-
logdet = at.sum(at.log(self.eigs))
1919-
else:
1920-
sqrt_quad = kron_solve_lower(self.chols, delta.T)
1921-
logdet = 0
1922-
for chol_size, chol_diag in zip(self.sizes, self.chol_diags):
1923-
logchol = at.log(chol_diag) * self.N / chol_size
1924-
logdet += at.sum(2 * logchol)
1871+
delta = value - mu
1872+
1873+
eigh_iterable = map(eigh, covs)
1874+
eigs_sep, Qs = zip(*eigh_iterable) # Unzip
1875+
Qs = list(map(at.as_tensor_variable, Qs))
1876+
QTs = list(map(at.transpose, Qs))
1877+
1878+
eigs_sep = list(map(at.as_tensor_variable, eigs_sep))
1879+
eigs = kron_diag(*eigs_sep) # Combine separate eigs
1880+
eigs += sigma ** 2
1881+
N = eigs.shape[0]
1882+
1883+
sqrt_quad = kron_dot(QTs, delta.T)
1884+
sqrt_quad = sqrt_quad / at.sqrt(eigs[:, None])
1885+
logdet = at.sum(at.log(eigs))
1886+
19251887
# Square each sample
19261888
quad = at.batched_dot(sqrt_quad.T, sqrt_quad.T)
19271889
if onedim:
19281890
quad = quad[0]
1929-
return quad, logdet
19301891

1931-
def logp(self, value):
1932-
"""
1933-
Calculate log-probability of Multivariate Normal distribution
1934-
with Kronecker-structured covariance at specified value.
1935-
1936-
Parameters
1937-
----------
1938-
value: numeric
1939-
Value for which log-probability is calculated.
1940-
1941-
Returns
1942-
-------
1943-
TensorVariable
1944-
"""
1945-
quad, logdet = self._quaddist(value)
1946-
return -(quad + logdet + self.N * at.log(2 * np.pi)) / 2.0
1892+
a = -(quad + logdet + N * at.log(2 * np.pi)) / 2.0
1893+
return a
19471894

19481895
def _distr_parameters_for_repr(self):
19491896
return ["mu"]

pymc3/tests/test_distributions.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -388,19 +388,19 @@ def matrix_normal_logpdf_chol(value, mu, rowchol, colchol):
388388
)
389389

390390

391-
def kron_normal_logpdf_cov(value, mu, covs, sigma):
391+
def kron_normal_logpdf_cov(value, mu, covs, sigma, size=None):
392392
cov = kronecker(*covs).eval()
393393
if sigma is not None:
394394
cov += sigma ** 2 * np.eye(*cov.shape)
395395
return scipy.stats.multivariate_normal.logpdf(value, mu, cov).sum()
396396

397397

398-
def kron_normal_logpdf_chol(value, mu, chols, sigma):
398+
def kron_normal_logpdf_chol(value, mu, chols, sigma, size=None):
399399
covs = [np.dot(chol, chol.T) for chol in chols]
400400
return kron_normal_logpdf_cov(value, mu, covs, sigma=sigma)
401401

402402

403-
def kron_normal_logpdf_evd(value, mu, evds, sigma):
403+
def kron_normal_logpdf_evd(value, mu, evds, sigma, size=None):
404404
covs = []
405405
for eigs, Q in evds:
406406
try:
@@ -1943,8 +1943,7 @@ def test_matrixnormal(self, n):
19431943

19441944
@pytest.mark.parametrize("n", [2, 3])
19451945
@pytest.mark.parametrize("m", [3])
1946-
@pytest.mark.parametrize("sigma", [None, 1.0])
1947-
@pytest.mark.xfail(reason="Distribution not refactored yet")
1946+
@pytest.mark.parametrize("sigma", [None, 1])
19481947
def test_kroneckernormal(self, n, m, sigma):
19491948
np.random.seed(5)
19501949
N = n * m
@@ -1990,6 +1989,9 @@ def test_kroneckernormal(self, n, m, sigma):
19901989
)
19911990

19921991
dom = Domain([np.random.randn(2, N) * 0.1], edges=(None, None), shape=(2, N))
1992+
cov_args["size"] = 2
1993+
chol_args["size"] = 2
1994+
evd_args["size"] = 2
19931995

19941996
self.check_logp(
19951997
KroneckerNormal,

0 commit comments

Comments
 (0)