|
17 | 17 |
|
18 | 18 | import warnings
|
19 | 19 |
|
| 20 | +from functools import reduce |
| 21 | + |
20 | 22 | import aesara
|
21 | 23 | import aesara.tensor as at
|
22 | 24 | import numpy as np
|
|
45 | 47 | from pymc3.distributions.continuous import ChiSquared, Normal, assert_negative_support
|
46 | 48 | from pymc3.distributions.dist_math import bound, factln, logpow, multigammaln
|
47 | 49 | from pymc3.distributions.distribution import Continuous, Discrete
|
48 |
| -from pymc3.math import kron_diag, kron_dot, kron_solve_lower, kronecker |
| 50 | +from pymc3.math import kron_diag, kron_dot |
49 | 51 |
|
50 | 52 | __all__ = [
|
51 | 53 | "MvNormal",
|
@@ -1702,6 +1704,32 @@ def _distr_parameters_for_repr(self):
|
1702 | 1704 | return ["mu", "row" + mapping[self._rowcov_type], "col" + mapping[self._colcov_type]]
|
1703 | 1705 |
|
1704 | 1706 |
|
class KroneckerNormalRV(RandomVariable):
    """Random-variable Op for a normal with Kronecker-structured covariance.

    Draws are produced by building the dense covariance as the Kronecker
    product of the supplied covariance factors (plus ``sigma ** 2`` times the
    identity when ``sigma`` is non-zero) and delegating to
    ``multivariate_normal.rng_fn``.
    """

    name = "kroneckernormal"
    # NOTE(review): `_shape_from_params` below uses a support dimension of 1
    # while `ndim_supp` is 2 -- confirm which one is intended.
    ndim_supp = 2
    ndims_params = [1, 0, 2]
    dtype = "floatX"
    _print_name = ("KroneckerNormal", "\\operatorname{KroneckerNormal}")

    def _shape_from_params(self, dist_params, rep_param_idx=0, param_shapes=None):
        """Infer the support shape by delegating to ``default_shape_from_params``."""
        return default_shape_from_params(1, dist_params, rep_param_idx, param_shapes)

    def rng_fn(self, rng, mu, sigma, *covs, size=None):
        """Draw samples from the Kronecker-structured normal.

        Parameters
        ----------
        rng
            Random generator forwarded to ``multivariate_normal.rng_fn``.
        mu
            Mean, forwarded as ``mean``.
        sigma
            Noise scale; when non-zero, ``sigma ** 2 * I`` is added to the
            Kronecker product of the covariance factors.
        *covs
            Covariance factors. When ``size`` is not given as a keyword, the
            last positional entry is taken to be the size.
        size
            Sample size forwarded to ``multivariate_normal.rng_fn``.

        Returns
        -------
        The array returned by ``multivariate_normal.rng_fn``.
        """
        if size is None:
            # NOTE(review): presumably the caller appends `size` after the
            # variadic covariance factors -- confirm against the Op's perform().
            # Split it off positionally: comparing a covariance matrix with a
            # size via `==` is element-wise, so its truth value is ambiguous
            # (or False for an empty size), and an explicit empty `size=()`
            # must not be mistaken for "missing".
            size = covs[-1]
            covs = covs[:-1]

        cov = reduce(linalg.kron, covs)

        if sigma:
            cov = cov + sigma ** 2 * np.eye(cov.shape[0])

        return multivariate_normal.rng_fn(rng=rng, mean=mu, cov=cov, size=size)


kroneckernormal = KroneckerNormalRV()
| 1731 | + |
| 1732 | + |
1705 | 1733 | class KroneckerNormal(Continuous):
|
1706 | 1734 | r"""
|
1707 | 1735 | Multivariate normal log-likelihood with Kronecker-structured covariance.
|
@@ -1790,160 +1818,79 @@ class KroneckerNormal(Continuous):
|
1790 | 1818 | ----------
|
1791 | 1819 | .. [1] Saatchi, Y. (2011). "Scalable inference for structured Gaussian process models"
|
1792 | 1820 | """
|
| 1821 | + rv_op = kroneckernormal |
1793 | 1822 |
|
1794 |
| - def __init__(self, mu, covs=None, chols=None, evds=None, sigma=None, *args, **kwargs): |
1795 |
| - self._setup(covs, chols, evds, sigma) |
1796 |
| - super().__init__(*args, **kwargs) |
1797 |
| - self.mu = at.as_tensor_variable(mu) |
1798 |
| - self.mean = self.median = self.mode = self.mu |
@classmethod
def dist(cls, mu, covs=None, chols=None, evds=None, sigma=None, *args, **kwargs):
    """
    Build the distribution from exactly one covariance parameterization.

    Whatever parameterization is given, it is converted to a list of
    covariance factors before being handed to the parent ``dist``.

    Parameters
    ----------
    mu: array-like
        Mean; converted with ``at.as_tensor_variable``.
    covs: list of arrays, optional
        Covariance factors, used as given.
    chols: list of arrays, optional
        Cholesky factors; each is expanded to ``chol.dot(chol.T)``.
    evds: iterable of (eigenvalues, eigenvectors) pairs, optional
        Each pair is expanded to ``Q @ diag(eig) @ Q.T``.
    sigma: scalar, optional
        Noise scale; ``None`` is replaced by 0.
    *args
        Accepted for signature compatibility; not forwarded.
    **kwargs
        Forwarded to the parent ``dist``.

    Raises
    ------
    ValueError
        If zero or more than one of ``covs``, ``chols``, ``evds`` is given.
    """
    if len([i for i in [covs, chols, evds] if i is not None]) != 1:
        raise ValueError(
            "Incompatible parameterization. Specify exactly one of covs, chols, or evds."
        )

    # Only a missing sigma defaults to 0. A truthiness test would raise for
    # symbolic tensors and for arrays with more than one element.
    if sigma is None:
        sigma = 0

    if chols is not None:
        covs = [chol.dot(chol.T) for chol in chols]
    elif evds is not None:
        covs = [at.dot(Q, at.dot(at.diag(eig), Q.T)) for eig, Q in evds]

    mu = at.as_tensor_variable(mu)

    return super().dist([mu, sigma, *covs], **kwargs)
| 1847 | + |
def logp(value, mu, sigma, *covs):
    """
    Log-density of the Kronecker-structured multivariate normal at ``value``.

    Each covariance factor is eigendecomposed; the eigenvalues are combined
    with ``kron_diag`` and shifted by ``sigma ** 2``, then used for both the
    quadratic form (x-mu)^T @ K^-1 @ (x-mu) and log(det(K)).

    Parameters
    ----------
    value: numeric
        Point(s) to evaluate; must be 1- or 2-dimensional. A 1-dimensional
        value is handled as a single row.
    mu: numeric
        Mean subtracted from ``value``.
    sigma: numeric
        Noise scale; ``sigma ** 2`` is added to every combined eigenvalue.
    *covs: numeric
        Covariance factors, each passed to ``eigh``.

    Returns
    -------
    TensorVariable

    Raises
    ------
    ValueError
        If ``value`` has zero or more than two dimensions.
    """
    ndim = value.ndim
    if ndim == 0 or ndim > 2:
        raise ValueError(f"Invalid dimension for value: {value.ndim}")

    # Remember whether a single vector came in so the result can be unwrapped.
    is_vector = ndim == 1
    if is_vector:
        value = value[None, :]

    delta = value - mu

    # One (eigenvalues, eigenvectors) pair per factor, unzipped into two tuples.
    eigs_sep, Qs = zip(*(eigh(cov) for cov in covs))
    QTs = [at.transpose(at.as_tensor_variable(Q)) for Q in Qs]

    eigs = kron_diag(*[at.as_tensor_variable(eig) for eig in eigs_sep])
    eigs += sigma ** 2
    N = eigs.shape[0]

    scaled = kron_dot(QTs, delta.T) / at.sqrt(eigs[:, None])
    logdet = at.sum(at.log(eigs))

    # Row-wise squared norm gives the quadratic form for each sample.
    quad = at.batched_dot(scaled.T, scaled.T)
    if is_vector:
        quad = quad[0]

    return -(quad + logdet + N * at.log(2 * np.pi)) / 2.0
1947 | 1894 |
|
def _distr_parameters_for_repr(self):
    """Return ``["mu"]``, the only parameter name this class reports for its repr."""
    repr_params = ["mu"]
    return repr_params
|
|
0 commit comments