From 08a81c0879d81138dfdb37fccee34012c174c930 Mon Sep 17 00:00:00 2001 From: kc611 Date: Sat, 26 Jun 2021 23:46:37 +0530 Subject: [PATCH] Porting kroneckernormal distribution to v4 Co-authored-by: Ricardo --- pymc3/distributions/multivariate.py | 201 +++++++++-------------- pymc3/tests/test_distributions.py | 12 +- pymc3/tests/test_distributions_random.py | 97 ++++------- pymc3/tests/test_starting.py | 2 +- 4 files changed, 112 insertions(+), 200 deletions(-) diff --git a/pymc3/distributions/multivariate.py b/pymc3/distributions/multivariate.py index b7566d59ac..3366d95593 100644 --- a/pymc3/distributions/multivariate.py +++ b/pymc3/distributions/multivariate.py @@ -17,6 +17,8 @@ import warnings +from functools import reduce + import aesara import aesara.tensor as at import numpy as np @@ -45,7 +47,7 @@ from pymc3.distributions.continuous import ChiSquared, Normal, assert_negative_support from pymc3.distributions.dist_math import bound, factln, logpow, multigammaln from pymc3.distributions.distribution import Continuous, Discrete -from pymc3.math import kron_diag, kron_dot, kron_solve_lower, kronecker +from pymc3.math import kron_diag, kron_dot __all__ = [ "MvNormal", @@ -1702,6 +1704,32 @@ def _distr_parameters_for_repr(self): return ["mu", "row" + mapping[self._rowcov_type], "col" + mapping[self._colcov_type]] +class KroneckerNormalRV(RandomVariable): + name = "kroneckernormal" + ndim_supp = 2 + ndims_params = [1, 0, 2] + dtype = "floatX" + _print_name = ("KroneckerNormal", "\\operatorname{KroneckerNormal}") + + def _shape_from_params(self, dist_params, rep_param_idx=0, param_shapes=None): + return default_shape_from_params(1, dist_params, rep_param_idx, param_shapes) + + def rng_fn(self, rng, mu, sigma, *covs, size=None): + size = size if size else covs[-1] + covs = covs[:-1] if covs[-1] == size else covs + + cov = reduce(linalg.kron, covs) + + if sigma: + cov = cov + sigma ** 2 * np.eye(cov.shape[0]) + + x = multivariate_normal.rng_fn(rng=rng, mean=mu, cov=cov, size=size) + return x + + +kroneckernormal = KroneckerNormalRV() + + class KroneckerNormal(Continuous): r""" Multivariate normal log-likelihood with Kronecker-structured covariance. @@ -1790,160 +1818,79 @@ class KroneckerNormal(Continuous): ---------- .. [1] Saatchi, Y. (2011). "Scalable inference for structured Gaussian process models" """ + rv_op = kroneckernormal - def __init__(self, mu, covs=None, chols=None, evds=None, sigma=None, *args, **kwargs): - self._setup(covs, chols, evds, sigma) - super().__init__(*args, **kwargs) - self.mu = at.as_tensor_variable(mu) - self.mean = self.median = self.mode = self.mu + @classmethod + def dist(cls, mu, covs=None, chols=None, evds=None, sigma=None, *args, **kwargs): - def _setup(self, covs, chols, evds, sigma): - self.cholesky = Cholesky(lower=True, on_error="raise") if len([i for i in [covs, chols, evds] if i is not None]) != 1: raise ValueError( "Incompatible parameterization. Specify exactly one of covs, chols, or evds." ) - self._isEVD = False - self.sigma = sigma - self.is_noisy = self.sigma is not None and self.sigma != 0 - if covs is not None: - self._cov_type = "cov" - self.covs = covs - if self.is_noisy: - # Noise requires eigendecomposition - eigh_map = map(eigh, covs) - self._setup_evd(eigh_map) - else: - # Otherwise use cholesky as usual - self.chols = list(map(self.cholesky, self.covs)) - self.chol_diags = list(map(at.diag, self.chols)) - self.sizes = at.as_tensor_variable([chol.shape[0] for chol in self.chols]) - self.N = at.prod(self.sizes) - elif chols is not None: - self._cov_type = "chol" - if self.is_noisy: # A strange case... - # Noise requires eigendecomposition - covs = [at.dot(chol, chol.T) for chol in chols] - eigh_map = map(eigh, covs) - self._setup_evd(eigh_map) - else: - self.chols = chols - self.chol_diags = list(map(at.diag, self.chols)) - self.sizes = at.as_tensor_variable([chol.shape[0] for chol in self.chols]) - self.N = at.prod(self.sizes) - else: - self._cov_type = "evd" - self._setup_evd(evds) - def _setup_evd(self, eigh_iterable): - self._isEVD = True - eigs_sep, Qs = zip(*eigh_iterable) # Unzip - self.Qs = list(map(at.as_tensor_variable, Qs)) - self.QTs = list(map(at.transpose, self.Qs)) - - self.eigs_sep = list(map(at.as_tensor_variable, eigs_sep)) - self.eigs = kron_diag(*self.eigs_sep) # Combine separate eigs - if self.is_noisy: - self.eigs += self.sigma ** 2 - self.N = self.eigs.shape[0] - - def _setup_random(self): - if not hasattr(self, "mv_params"): - self.mv_params = {"mu": self.mu} - if self._cov_type == "cov": - cov = kronecker(*self.covs) - if self.is_noisy: - cov = cov + self.sigma ** 2 * at.identity_like(cov) - self.mv_params["cov"] = cov - elif self._cov_type == "chol": - if self.is_noisy: - covs = [] - for eig, Q in zip(self.eigs_sep, self.Qs): - cov_i = at.dot(Q, at.dot(at.diag(eig), Q.T)) - covs.append(cov_i) - cov = kronecker(*covs) - if self.is_noisy: - cov = cov + self.sigma ** 2 * at.identity_like(cov) - self.mv_params["chol"] = self.cholesky(cov) - else: - self.mv_params["chol"] = kronecker(*self.chols) - elif self._cov_type == "evd": - covs = [] - for eig, Q in zip(self.eigs_sep, self.Qs): - cov_i = at.dot(Q, at.dot(at.diag(eig), Q.T)) - covs.append(cov_i) - cov = kronecker(*covs) - if self.is_noisy: - cov = cov + self.sigma ** 2 * at.identity_like(cov) - self.mv_params["cov"] = cov + sigma = sigma if sigma else 0 - def random(self, point=None, size=None): + if chols is not None: + covs = [chol.dot(chol.T) for chol in chols] + elif evds is not None: + eigh_iterable = evds + covs = [] + eigs_sep, Qs = zip(*eigh_iterable) # Unzip + for eig, Q in zip(eigs_sep, Qs): + cov_i = at.dot(Q, at.dot(at.diag(eig), Q.T)) + covs.append(cov_i) + + mu = at.as_tensor_variable(mu) + + # mean = median = mode = mu + return super().dist([mu, sigma, *covs], **kwargs) + + def logp(value, mu, sigma, *covs): """ - Draw random values from Multivariate Normal distribution - with Kronecker-structured covariance. + Calculate log-probability of Multivariate Normal distribution + with Kronecker-structured covariance at specified value. Parameters ---------- - point: dict, optional - Dict of variable values on which random values are to be - conditioned (uses default point if not specified). - size: int, optional - Desired size of random sample (returns one sample if not - specified). + value: numeric + Value for which log-probability is calculated. Returns ------- - array + TensorVariable """ - # Expand params into terms MvNormal can understand to force consistency - self._setup_random() - self.mv_params["shape"] = self.shape - dist = MvNormal.dist(**self.mv_params) - return dist.random(point, size) - - def _quaddist(self, value): - """Computes the quadratic (x-mu)^T @ K^-1 @ (x-mu) and log(det(K))""" + # Computes the quadratic (x-mu)^T @ K^-1 @ (x-mu) and log(det(K)) if value.ndim > 2 or value.ndim == 0: - raise ValueError("Invalid dimension for value: %s" % value.ndim) + raise ValueError(f"Invalid dimension for value: {value.ndim}") if value.ndim == 1: onedim = True value = value[None, :] else: onedim = False - delta = value - self.mu - if self._isEVD: - sqrt_quad = kron_dot(self.QTs, delta.T) - sqrt_quad = sqrt_quad / at.sqrt(self.eigs[:, None]) - logdet = at.sum(at.log(self.eigs)) - else: - sqrt_quad = kron_solve_lower(self.chols, delta.T) - logdet = 0 - for chol_size, chol_diag in zip(self.sizes, self.chol_diags): - logchol = at.log(chol_diag) * self.N / chol_size - logdet += at.sum(2 * logchol) + delta = value - mu + + eigh_iterable = map(eigh, covs) + eigs_sep, Qs = zip(*eigh_iterable) # Unzip + Qs = list(map(at.as_tensor_variable, Qs)) + QTs = list(map(at.transpose, Qs)) + + eigs_sep = list(map(at.as_tensor_variable, eigs_sep)) + eigs = kron_diag(*eigs_sep) # Combine separate eigs + eigs += sigma ** 2 + N = eigs.shape[0] + + sqrt_quad = kron_dot(QTs, delta.T) + sqrt_quad = sqrt_quad / at.sqrt(eigs[:, None]) + logdet = at.sum(at.log(eigs)) + # Square each sample quad = at.batched_dot(sqrt_quad.T, sqrt_quad.T) if onedim: quad = quad[0] - return quad, logdet - def logp(self, value): - """ - Calculate log-probability of Multivariate Normal distribution - with Kronecker-structured covariance at specified value. - - Parameters - ---------- - value: numeric - Value for which log-probability is calculated. - - Returns - ------- - TensorVariable - """ - quad, logdet = self._quaddist(value) - return -(quad + logdet + self.N * at.log(2 * np.pi)) / 2.0 + a = -(quad + logdet + N * at.log(2 * np.pi)) / 2.0 + return a def _distr_parameters_for_repr(self): return ["mu"] diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index 1e57071070..dbf2d42515 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -388,19 +388,19 @@ def matrix_normal_logpdf_chol(value, mu, rowchol, colchol): ) -def kron_normal_logpdf_cov(value, mu, covs, sigma): +def kron_normal_logpdf_cov(value, mu, covs, sigma, size=None): cov = kronecker(*covs).eval() if sigma is not None: cov += sigma ** 2 * np.eye(*cov.shape) return scipy.stats.multivariate_normal.logpdf(value, mu, cov).sum() -def kron_normal_logpdf_chol(value, mu, chols, sigma): +def kron_normal_logpdf_chol(value, mu, chols, sigma, size=None): covs = [np.dot(chol, chol.T) for chol in chols] return kron_normal_logpdf_cov(value, mu, covs, sigma=sigma) -def kron_normal_logpdf_evd(value, mu, evds, sigma): +def kron_normal_logpdf_evd(value, mu, evds, sigma, size=None): covs = [] for eigs, Q in evds: try: @@ -1943,8 +1943,7 @@ def test_matrixnormal(self, n): @pytest.mark.parametrize("n", [2, 3]) @pytest.mark.parametrize("m", [3]) - @pytest.mark.parametrize("sigma", [None, 1.0]) - @pytest.mark.xfail(reason="Distribution not refactored yet") + @pytest.mark.parametrize("sigma", [None, 1]) def test_kroneckernormal(self, n, m, sigma): np.random.seed(5) N = n * m @@ -1990,6 +1989,9 @@ def test_kroneckernormal(self, n, m, sigma): ) dom = Domain([np.random.randn(2, N) * 0.1], edges=(None, None), shape=(2, N)) + cov_args["size"] = 2 + chol_args["size"] = 2 + evd_args["size"] = 2 self.check_logp( KroneckerNormal, diff --git a/pymc3/tests/test_distributions_random.py b/pymc3/tests/test_distributions_random.py index ad1a731ba7..0d4fb6248c 100644 --- a/pymc3/tests/test_distributions_random.py +++ b/pymc3/tests/test_distributions_random.py @@ -1322,12 +1322,8 @@ def interpolated_rng_fn(self, size, mu, sigma, rng): reference_dist = lambda self: functools.partial( self.interpolated_rng_fn, rng=self.get_random_state() ) - tests_to_run = [ - "check_rv_size", - ] - + tests_to_run = ["check_rv_size", "test_interpolated"] -class TestInterpolatedSeeded(SeededTest): @pytest.mark.xfail(condition=(aesara.config.floatX == "float32"), reason="Fails on float32") def test_interpolated(self): for mu in R.vals: @@ -1348,6 +1344,35 @@ def dist(cls, **kwargs): pymc3_random(TestedInterpolated, {}, ref_rand=ref_rand) +class TestKroneckerNormal(BaseTestDistribution): + def kronecker_rng_fn(self, size, mu, covs=None, sigma=None, rng=None): + cov = pm.math.kronecker(covs[0], covs[1]).eval() + cov += sigma ** 2 * np.identity(cov.shape[0]) + return st.multivariate_normal.rvs(mean=mu, cov=cov, size=size) + + pymc_dist = pm.KroneckerNormal + + n = 3 + N = n ** 2 + covs = [RandomPdMatrix(n), RandomPdMatrix(n)] + mu = np.random.random(N) * 0.1 + sigma = 1 + + pymc_dist_params = {"mu": mu, "covs": covs, "sigma": sigma} + expected_rv_op_params = {"mu": mu, "covs": covs, "sigma": sigma} + reference_dist_params = {"mu": mu, "covs": covs, "sigma": sigma} + sizes_to_check = [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)] + sizes_expected = [(N,), (N,), (1, N), (1, N), (5, N), (4, 5, N), (2, 4, 2, N)] + + reference_dist = lambda self: functools.partial( + self.kronecker_rng_fn, rng=self.get_random_state() + ) + tests_to_run = [ + "check_pymc_draws_match_reference", + "check_rv_size", + ] + + class TestScalarParameterSamples(SeededTest): @pytest.mark.xfail(reason="This distribution has not been refactored for v4") def test_bounded(self): @@ -1473,68 +1498,6 @@ def ref_rand_uchol(size, mu, rowchol, colchol): ref_rand=ref_rand_chol_transpose, ) - @pytest.mark.xfail(reason="This distribution has not been refactored for v4") - def test_kronecker_normal(self): - def ref_rand(size, mu, covs, sigma): - cov = pm.math.kronecker(covs[0], covs[1]).eval() - cov += sigma ** 2 * np.identity(cov.shape[0]) - return st.multivariate_normal.rvs(mean=mu, cov=cov, size=size) - - def ref_rand_chol(size, mu, chols, sigma): - covs = [np.dot(chol, chol.T) for chol in chols] - return ref_rand(size, mu, covs, sigma) - - def ref_rand_evd(size, mu, evds, sigma): - covs = [] - for eigs, Q in evds: - covs.append(np.dot(Q, np.dot(np.diag(eigs), Q.T))) - return ref_rand(size, mu, covs, sigma) - - sizes = [2, 3] - sigmas = [0, 1] - for n, sigma in zip(sizes, sigmas): - N = n ** 2 - covs = [RandomPdMatrix(n), RandomPdMatrix(n)] - chols = list(map(np.linalg.cholesky, covs)) - evds = list(map(np.linalg.eigh, covs)) - dom = Domain([np.random.randn(N) * 0.1], edges=(None, None), shape=N) - mu = Domain([np.random.randn(N) * 0.1], edges=(None, None), shape=N) - - std_args = {"mu": mu} - cov_args = {"covs": covs} - chol_args = {"chols": chols} - evd_args = {"evds": evds} - if sigma is not None and sigma != 0: - std_args["sigma"] = Domain([sigma], edges=(None, None)) - else: - for args in [cov_args, chol_args, evd_args]: - args["sigma"] = sigma - - pymc3_random( - pm.KroneckerNormal, - std_args, - valuedomain=dom, - ref_rand=ref_rand, - extra_args=cov_args, - model_args=cov_args, - ) - pymc3_random( - pm.KroneckerNormal, - std_args, - valuedomain=dom, - ref_rand=ref_rand_chol, - extra_args=chol_args, - model_args=chol_args, - ) - pymc3_random( - pm.KroneckerNormal, - std_args, - valuedomain=dom, - ref_rand=ref_rand_evd, - extra_args=evd_args, - model_args=evd_args, - ) - @pytest.mark.xfail(reason="This distribution has not been refactored for v4") def test_dirichlet_multinomial(self): def ref_rand(size, a, n): diff --git a/pymc3/tests/test_starting.py b/pymc3/tests/test_starting.py index 42837571e8..0806f1f796 100644 --- a/pymc3/tests/test_starting.py +++ b/pymc3/tests/test_starting.py @@ -109,7 +109,7 @@ def test_find_MAP_issue_4488(): map_estimate = find_MAP() assert not set.difference({"x_missing", "x_missing_log__", "y"}, set(map_estimate.keys())) - np.testing.assert_allclose(map_estimate["x_missing"], 0.2, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(map_estimate["x_missing"], 0.2, rtol=1e-4, atol=1e-4) np.testing.assert_allclose(map_estimate["y"], [2.0, map_estimate["x_missing"][0] + 1])