Skip to content

[WIP] Porting kroneckernormal distribution to v4 #4774

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 74 additions & 127 deletions pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import warnings

from functools import reduce

import aesara
import aesara.tensor as at
import numpy as np
Expand Down Expand Up @@ -45,7 +47,7 @@
from pymc3.distributions.continuous import ChiSquared, Normal, assert_negative_support
from pymc3.distributions.dist_math import bound, factln, logpow, multigammaln
from pymc3.distributions.distribution import Continuous, Discrete
from pymc3.math import kron_diag, kron_dot, kron_solve_lower, kronecker
from pymc3.math import kron_diag, kron_dot

__all__ = [
"MvNormal",
Expand Down Expand Up @@ -1702,6 +1704,32 @@ def _distr_parameters_for_repr(self):
return ["mu", "row" + mapping[self._rowcov_type], "col" + mapping[self._colcov_type]]


class KroneckerNormalRV(RandomVariable):
    """RandomVariable op backing ``KroneckerNormal``: a multivariate normal
    whose covariance is the Kronecker product of several factor matrices,
    optionally plus ``sigma**2 * I`` white noise."""

    name = "kroneckernormal"
    ndim_supp = 2
    # Parameter core dims: mu is a vector (1), sigma a scalar (0),
    # each covariance factor a matrix (2).
    ndims_params = [1, 0, 2]
    dtype = "floatX"
    _print_name = ("KroneckerNormal", "\\operatorname{KroneckerNormal}")

    def _shape_from_params(self, dist_params, rep_param_idx=0, param_shapes=None):
        # Output shape is taken from the mu parameter (rep_param_idx=0),
        # treated as having 1 core dimension.
        return default_shape_from_params(1, dist_params, rep_param_idx, param_shapes)

    def rng_fn(self, rng, mu, sigma, *covs, size=None):
        # NOTE(review): this relies on the caller having appended the draw
        # size as the LAST element of `covs`; when `size` is falsy it is
        # recovered from there, and the sentinel is stripped before the
        # Kronecker product. Fragile — confirm against the `dist()` caller.
        size = size if size else covs[-1]
        covs = covs[:-1] if covs[-1] == size else covs

        # Dense covariance: Kronecker product of all factor matrices.
        cov = reduce(linalg.kron, covs)

        if sigma:
            # Homoscedastic white noise on the diagonal.
            cov = cov + sigma ** 2 * np.eye(cov.shape[0])

        x = multivariate_normal.rng_fn(rng=rng, mean=mu, cov=cov, size=size)
        return x


# Singleton op instance used as `rv_op` by the KroneckerNormal distribution.
kroneckernormal = KroneckerNormalRV()


class KroneckerNormal(Continuous):
r"""
Multivariate normal log-likelihood with Kronecker-structured covariance.
Expand Down Expand Up @@ -1790,160 +1818,79 @@ class KroneckerNormal(Continuous):
----------
.. [1] Saatchi, Y. (2011). "Scalable inference for structured Gaussian process models"
"""
rv_op = kroneckernormal

def __init__(self, mu, covs=None, chols=None, evds=None, sigma=None, *args, **kwargs):
self._setup(covs, chols, evds, sigma)
super().__init__(*args, **kwargs)
self.mu = at.as_tensor_variable(mu)
self.mean = self.median = self.mode = self.mu
@classmethod
def dist(cls, mu, covs=None, chols=None, evds=None, sigma=None, *args, **kwargs):

def _setup(self, covs, chols, evds, sigma):
self.cholesky = Cholesky(lower=True, on_error="raise")
if len([i for i in [covs, chols, evds] if i is not None]) != 1:
raise ValueError(
"Incompatible parameterization. Specify exactly one of covs, chols, or evds."
)
self._isEVD = False
self.sigma = sigma
self.is_noisy = self.sigma is not None and self.sigma != 0
if covs is not None:
self._cov_type = "cov"
self.covs = covs
if self.is_noisy:
# Noise requires eigendecomposition
eigh_map = map(eigh, covs)
self._setup_evd(eigh_map)
else:
# Otherwise use cholesky as usual
self.chols = list(map(self.cholesky, self.covs))
self.chol_diags = list(map(at.diag, self.chols))
self.sizes = at.as_tensor_variable([chol.shape[0] for chol in self.chols])
self.N = at.prod(self.sizes)
elif chols is not None:
self._cov_type = "chol"
if self.is_noisy: # A strange case...
# Noise requires eigendecomposition
covs = [at.dot(chol, chol.T) for chol in chols]
eigh_map = map(eigh, covs)
self._setup_evd(eigh_map)
else:
self.chols = chols
self.chol_diags = list(map(at.diag, self.chols))
self.sizes = at.as_tensor_variable([chol.shape[0] for chol in self.chols])
self.N = at.prod(self.sizes)
else:
self._cov_type = "evd"
self._setup_evd(evds)

def _setup_evd(self, eigh_iterable):
self._isEVD = True
eigs_sep, Qs = zip(*eigh_iterable) # Unzip
self.Qs = list(map(at.as_tensor_variable, Qs))
self.QTs = list(map(at.transpose, self.Qs))

self.eigs_sep = list(map(at.as_tensor_variable, eigs_sep))
self.eigs = kron_diag(*self.eigs_sep) # Combine separate eigs
if self.is_noisy:
self.eigs += self.sigma ** 2
self.N = self.eigs.shape[0]

def _setup_random(self):
if not hasattr(self, "mv_params"):
self.mv_params = {"mu": self.mu}
if self._cov_type == "cov":
cov = kronecker(*self.covs)
if self.is_noisy:
cov = cov + self.sigma ** 2 * at.identity_like(cov)
self.mv_params["cov"] = cov
elif self._cov_type == "chol":
if self.is_noisy:
covs = []
for eig, Q in zip(self.eigs_sep, self.Qs):
cov_i = at.dot(Q, at.dot(at.diag(eig), Q.T))
covs.append(cov_i)
cov = kronecker(*covs)
if self.is_noisy:
cov = cov + self.sigma ** 2 * at.identity_like(cov)
self.mv_params["chol"] = self.cholesky(cov)
else:
self.mv_params["chol"] = kronecker(*self.chols)
elif self._cov_type == "evd":
covs = []
for eig, Q in zip(self.eigs_sep, self.Qs):
cov_i = at.dot(Q, at.dot(at.diag(eig), Q.T))
covs.append(cov_i)
cov = kronecker(*covs)
if self.is_noisy:
cov = cov + self.sigma ** 2 * at.identity_like(cov)
self.mv_params["cov"] = cov
sigma = sigma if sigma else 0

def random(self, point=None, size=None):
if chols is not None:
covs = [chol.dot(chol.T) for chol in chols]
elif evds is not None:
eigh_iterable = evds
covs = []
eigs_sep, Qs = zip(*eigh_iterable) # Unzip
for eig, Q in zip(eigs_sep, Qs):
cov_i = at.dot(Q, at.dot(at.diag(eig), Q.T))
covs.append(cov_i)

mu = at.as_tensor_variable(mu)

# mean = median = mode = mu
return super().dist([mu, sigma, *covs], **kwargs)

def logp(value, mu, sigma, *covs):
"""
Draw random values from Multivariate Normal distribution
with Kronecker-structured covariance.
Calculate log-probability of Multivariate Normal distribution
with Kronecker-structured covariance at specified value.

Parameters
----------
point: dict, optional
Dict of variable values on which random values are to be
conditioned (uses default point if not specified).
size: int, optional
Desired size of random sample (returns one sample if not
specified).
value: numeric
Value for which log-probability is calculated.

Returns
-------
array
TensorVariable
"""
# Expand params into terms MvNormal can understand to force consistency
self._setup_random()
self.mv_params["shape"] = self.shape
dist = MvNormal.dist(**self.mv_params)
return dist.random(point, size)

def _quaddist(self, value):
"""Computes the quadratic (x-mu)^T @ K^-1 @ (x-mu) and log(det(K))"""
# Computes the quadratic (x-mu)^T @ K^-1 @ (x-mu) and log(det(K))
if value.ndim > 2 or value.ndim == 0:
raise ValueError("Invalid dimension for value: %s" % value.ndim)
raise ValueError(f"Invalid dimension for value: {value.ndim}")
if value.ndim == 1:
onedim = True
value = value[None, :]
else:
onedim = False

delta = value - self.mu
if self._isEVD:
sqrt_quad = kron_dot(self.QTs, delta.T)
sqrt_quad = sqrt_quad / at.sqrt(self.eigs[:, None])
logdet = at.sum(at.log(self.eigs))
else:
sqrt_quad = kron_solve_lower(self.chols, delta.T)
logdet = 0
for chol_size, chol_diag in zip(self.sizes, self.chol_diags):
logchol = at.log(chol_diag) * self.N / chol_size
logdet += at.sum(2 * logchol)
delta = value - mu

eigh_iterable = map(eigh, covs)
eigs_sep, Qs = zip(*eigh_iterable) # Unzip
Qs = list(map(at.as_tensor_variable, Qs))
QTs = list(map(at.transpose, Qs))

eigs_sep = list(map(at.as_tensor_variable, eigs_sep))
eigs = kron_diag(*eigs_sep) # Combine separate eigs
eigs += sigma ** 2
N = eigs.shape[0]

sqrt_quad = kron_dot(QTs, delta.T)
sqrt_quad = sqrt_quad / at.sqrt(eigs[:, None])
logdet = at.sum(at.log(eigs))

# Square each sample
quad = at.batched_dot(sqrt_quad.T, sqrt_quad.T)
if onedim:
quad = quad[0]
return quad, logdet

def logp(self, value):
"""
Calculate log-probability of Multivariate Normal distribution
with Kronecker-structured covariance at specified value.

Parameters
----------
value: numeric
Value for which log-probability is calculated.

Returns
-------
TensorVariable
"""
quad, logdet = self._quaddist(value)
return -(quad + logdet + self.N * at.log(2 * np.pi)) / 2.0
a = -(quad + logdet + N * at.log(2 * np.pi)) / 2.0
return a

    def _distr_parameters_for_repr(self):
        # Only the mean is shown in the string representation; the
        # covariance factors are omitted for brevity.
        return ["mu"]
Expand Down
12 changes: 7 additions & 5 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,19 +388,19 @@ def matrix_normal_logpdf_chol(value, mu, rowchol, colchol):
)


def kron_normal_logpdf_cov(value, mu, covs, sigma, size=None):
    """Reference logpdf for KroneckerNormal given covariance factors.

    Builds the dense covariance as the Kronecker product of ``covs``
    (plus ``sigma**2 * I`` noise when ``sigma`` is given) and sums the
    scipy multivariate-normal logpdf over the samples in ``value``.
    ``size`` is accepted (and ignored) so the signature matches the
    random-draw helpers used by the test harness.
    """
    cov = kronecker(*covs).eval()
    if sigma is not None:
        cov += sigma ** 2 * np.eye(*cov.shape)
    return scipy.stats.multivariate_normal.logpdf(value, mu, cov).sum()


def kron_normal_logpdf_chol(value, mu, chols, sigma, size=None):
    """Reference logpdf for KroneckerNormal given Cholesky factors.

    Recovers each covariance factor as ``chol @ chol.T`` and delegates to
    ``kron_normal_logpdf_cov``. ``size`` is accepted (and ignored) so the
    signature matches the random-draw helpers used by the test harness.
    """
    covs = [np.dot(chol, chol.T) for chol in chols]
    return kron_normal_logpdf_cov(value, mu, covs, sigma=sigma)


def kron_normal_logpdf_evd(value, mu, evds, sigma):
def kron_normal_logpdf_evd(value, mu, evds, sigma, size=None):
covs = []
for eigs, Q in evds:
try:
Expand Down Expand Up @@ -1943,8 +1943,7 @@ def test_matrixnormal(self, n):

@pytest.mark.parametrize("n", [2, 3])
@pytest.mark.parametrize("m", [3])
@pytest.mark.parametrize("sigma", [None, 1.0])
@pytest.mark.xfail(reason="Distribution not refactored yet")
@pytest.mark.parametrize("sigma", [None, 1])
def test_kroneckernormal(self, n, m, sigma):
np.random.seed(5)
N = n * m
Expand Down Expand Up @@ -1990,6 +1989,9 @@ def test_kroneckernormal(self, n, m, sigma):
)

dom = Domain([np.random.randn(2, N) * 0.1], edges=(None, None), shape=(2, N))
cov_args["size"] = 2
chol_args["size"] = 2
evd_args["size"] = 2

self.check_logp(
KroneckerNormal,
Expand Down
Loading