Skip to content

Make MvStudentT distribution v4 compatible #4731

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 14, 2021
107 changes: 58 additions & 49 deletions pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from aesara.tensor import gammaln
from aesara.tensor.nlinalg import det, eigh, matrix_inverse, trace
from aesara.tensor.random.basic import MultinomialRV, dirichlet, multivariate_normal
from aesara.tensor.random.op import RandomVariable, default_shape_from_params
from aesara.tensor.random.utils import broadcast_params
from aesara.tensor.slinalg import (
Cholesky,
Expand All @@ -41,7 +42,7 @@

from pymc3.aesaraf import floatX, intX
from pymc3.distributions import transforms
from pymc3.distributions.continuous import ChiSquared, Normal
from pymc3.distributions.continuous import ChiSquared, Normal, assert_negative_support
from pymc3.distributions.dist_math import bound, factln, logpow, multigammaln
from pymc3.distributions.distribution import Continuous, Discrete
from pymc3.math import kron_diag, kron_dot, kron_solve_lower, kronecker
Expand Down Expand Up @@ -248,6 +249,48 @@ def _distr_parameters_for_repr(self):
return ["mu", "cov"]


class MvStudentTRV(RandomVariable):
    """Random-variable op drawing from a multivariate Student-T distribution.

    The op takes three parameters, in order: ``nu`` (scalar), ``mu``
    (vector) and ``cov`` (matrix), as declared by ``ndims_params``.
    """

    name = "multivariate_studentt"
    ndim_supp = 1
    ndims_params = [0, 1, 2]
    dtype = "floatX"
    _print_name = ("MvStudentT", "\\operatorname{MvStudentT}")

    def __call__(self, nu, mu=None, cov=None, size=None, **kwargs):
        """Create the random variable, filling in a 1-d default ``mu``/``cov``.

        When omitted, ``mu`` defaults to ``[0.0]`` and ``cov`` to ``[[1.0]]``,
        both created with the op's resolved dtype.
        """
        # "floatX" is a placeholder that resolves to the configured float type.
        if self.dtype == "floatX":
            dtype = aesara.config.floatX
        else:
            dtype = self.dtype

        mu = np.array([0.0], dtype=dtype) if mu is None else mu
        cov = np.array([[1.0]], dtype=dtype) if cov is None else cov
        return super().__call__(nu, mu, cov, size=size, **kwargs)

    def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
        """Delegate to ``default_shape_from_params``.

        The default ``rep_param_idx=1`` selects ``mu``, the second op parameter.
        """
        return default_shape_from_params(self.ndim_supp, dist_params, rep_param_idx, param_shapes)

    @classmethod
    def rng_fn(cls, rng, nu, mu, cov, size):
        """Draw samples as ``z / sqrt(chi2 / nu) + mu``.

        ``z`` is a zero-mean multivariate normal draw with covariance ``cov``
        and ``chi2`` is a chi-square draw with ``nu`` degrees of freedom.
        """
        # Keep only the broadcasted mean: the multivariate normal sampler
        # expects a two-dimensional cov, so cov is passed through untouched.
        mu, _ = broadcast_params([mu, cov], cls.ndims_params[1:])

        # The chi-square draw happens before the normal draw; the order
        # determines how the RNG stream is consumed.
        scale = np.sqrt(rng.chisquare(nu, size=size) / nu)
        # Append singleton axes so the per-draw scale broadcasts against
        # the dimensions of mu.
        trailing_axes = (1,) * len(mu.shape)
        scale = scale.reshape(scale.shape + trailing_axes)

        centered = multivariate_normal.rng_fn(
            rng=rng, mean=np.zeros_like(mu), cov=cov, size=size
        )

        batch_shape = tuple(size or ())
        if batch_shape:
            mu = np.broadcast_to(mu, batch_shape + mu.shape)

        return (centered / scale) + mu


# Module-level op instance used by the MvStudentT distribution.
mv_studentt = MvStudentTRV()


class MvStudentT(Continuous):
r"""
Multivariate Student-T log-likelihood.
Expand All @@ -273,8 +316,8 @@ class MvStudentT(Continuous):

Parameters
----------
nu: int
Degrees of freedom.
nu: float
Degrees of freedom, should be a positive scalar.
Sigma: matrix
Covariance matrix. Use `cov` in new code.
mu: array
Expand All @@ -288,55 +331,21 @@ class MvStudentT(Continuous):
lower: bool, default=True
Whether the Cholesky factor is given as a lower triangular matrix.
"""
rv_op = mv_studentt

def __init__(
self, nu, Sigma=None, mu=None, cov=None, tau=None, chol=None, lower=True, *args, **kwargs
):
@classmethod
def dist(cls, nu, Sigma=None, mu=None, cov=None, tau=None, chol=None, lower=True, **kwargs):
if Sigma is not None:
if cov is not None:
raise ValueError("Specify only one of cov and Sigma")
cov = Sigma
super().__init__(mu=mu, cov=cov, tau=tau, chol=chol, lower=lower, *args, **kwargs)
self.nu = nu = at.as_tensor_variable(nu)
self.mean = self.median = self.mode = self.mu = self.mu

def random(self, point=None, size=None):
"""
Draw random values from Multivariate Student's T distribution.

Parameters
----------
point: dict, optional
Dict of variable values on which random values are to be
conditioned (uses default point if not specified).
size: int, optional
Desired size of random sample (returns one sample if not
specified).

Returns
-------
array
"""
# with _DrawValuesContext():
# nu, mu = draw_values([self.nu, self.mu], point=point, size=size)
# if self._cov_type == "cov":
# (cov,) = draw_values([self.cov], point=point, size=size)
# dist = MvNormal.dist(mu=np.zeros_like(mu), cov=cov, shape=self.shape)
# elif self._cov_type == "tau":
# (tau,) = draw_values([self.tau], point=point, size=size)
# dist = MvNormal.dist(mu=np.zeros_like(mu), tau=tau, shape=self.shape)
# else:
# (chol,) = draw_values([self.chol_cov], point=point, size=size)
# dist = MvNormal.dist(mu=np.zeros_like(mu), chol=chol, shape=self.shape)
#
# samples = dist.random(point, size)
#
# chi2_samples = np.random.chisquare(nu, size)
# # Add distribution shape to chi2 samples
# chi2_samples = chi2_samples.reshape(chi2_samples.shape + (1,) * len(self.shape))
# return (samples / np.sqrt(chi2_samples / nu)) + mu
nu = at.as_tensor_variable(floatX(nu))
mu = at.as_tensor_variable(floatX(mu))
cov = quaddist_matrix(cov, chol, tau, lower)
assert_negative_support(nu, "nu", "MvStudentT")
return super().dist([nu, mu, cov], **kwargs)

def logp(value, nu, cov):
def logp(value, nu, mu, cov):
"""
Calculate log-probability of Multivariate Student's T distribution
at specified value.
Expand All @@ -350,15 +359,15 @@ def logp(value, nu, cov):
-------
TensorVariable
"""
quaddist, logdet, ok = quaddist_parse(value, nu, cov)
quaddist, logdet, ok = quaddist_parse(value, mu, cov)
k = floatX(value.shape[-1])

norm = gammaln((nu + k) / 2.0) - gammaln(nu / 2.0) - 0.5 * k * floatX(np.log(nu * np.pi))
norm = gammaln((nu + k) / 2.0) - gammaln(nu / 2.0) - 0.5 * k * at.log(nu * np.pi)
inner = -(nu + k) / 2.0 * at.log1p(quaddist / nu)
return bound(norm + inner - logdet, ok)

def _distr_parameters_for_repr(self):
return ["mu", "nu", "cov"]
return ["nu", "mu", "cov"]


class Dirichlet(Continuous):
Expand Down
10 changes: 4 additions & 6 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2019,7 +2019,6 @@ def test_kroneckernormal(self, n, m, sigma):
)

@pytest.mark.parametrize("n", [1, 2])
@pytest.mark.xfail(reason="Distribution not refactored yet")
def test_mvt(self, n):
self.check_logp(
MvStudentT,
Expand All @@ -2032,6 +2031,7 @@ def test_mvt(self, n):
RealMatrix(2, n),
{"nu": Rplus, "Sigma": PdMatrix(n), "mu": Vector(R, n)},
mvt_logpdf,
extra_args={"size": 2},
)

@pytest.mark.parametrize("n", [2, 3, 4])
Expand Down Expand Up @@ -2937,13 +2937,11 @@ def test_car_logp(size):


class TestBugfixes:
@pytest.mark.parametrize(
"dist_cls,kwargs", [(MvNormal, dict(mu=0)), (MvStudentT, dict(mu=0, nu=2))]
)
@pytest.mark.parametrize("dist_cls,kwargs", [(MvNormal, dict()), (MvStudentT, dict(nu=2))])
@pytest.mark.parametrize("dims", [1, 2, 4])
@pytest.mark.xfail(reason="Distribution not refactored yet")
def test_issue_3051(self, dims, dist_cls, kwargs):
d = dist_cls.dist(**kwargs, cov=np.eye(dims), size=(dims,))
mu = np.repeat(0, dims)
d = dist_cls.dist(mu=mu, cov=np.eye(dims), **kwargs, size=(20))

X = np.random.normal(size=(20, dims))
actual_t = logpt(d, X)
Expand Down
82 changes: 65 additions & 17 deletions pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ class TestPoisson(BaseTestDistribution):
tests_to_run = ["check_pymc_params_match_rv_op"]


class TestMvNormal(BaseTestDistribution):
class TestMvNormalCov(BaseTestDistribution):
pymc_dist = pm.MvNormal
pymc_dist_params = {
"mu": np.array([1.0, 2.0]),
Expand Down Expand Up @@ -893,6 +893,70 @@ class TestMvNormalTau(BaseTestDistribution):
tests_to_run = ["check_pymc_params_match_rv_op"]


class TestMvStudentTCov(BaseTestDistribution):
    """Check ``pm.MvStudentT`` parametrized by ``cov`` against a NumPy reference sampler."""

    def mvstudentt_rng_fn(self, size, nu, mu, cov, rng):
        """Reference sampler: ``z / sqrt(chi2 / nu) + mu``.

        ``z`` is a zero-mean multivariate normal draw with covariance ``cov``
        and ``chi2`` a chi-square draw with ``nu`` degrees of freedom. The
        chi-square draws get one trailing axis so they broadcast over the
        support dimension for any ``size`` (``None``, an int, or a tuple).
        """
        # Draw order (chi-square first, then normal) fixes how the RNG
        # stream is consumed, so it must not be swapped.
        # np.asarray lets a scalar draw (size=None) be indexed like an array.
        chi2_samples = np.asarray(rng.chisquare(nu, size=size))
        mv_samples = rng.multivariate_normal(np.zeros_like(mu), cov, size=size)
        # `[..., None]` appends the support axis regardless of batch rank;
        # for a 1-d batch this equals the former `[:, None]`.
        return (mv_samples / np.sqrt(chi2_samples[..., None] / nu)) + mu

    pymc_dist = pm.MvStudentT
    pymc_dist_params = {
        "nu": 5,
        "mu": np.array([1.0, 2.0]),
        "cov": np.array([[2.0, 0.0], [0.0, 3.5]]),
    }
    expected_rv_op_params = {
        "nu": 5,
        "mu": np.array([1.0, 2.0]),
        "cov": np.array([[2.0, 0.0], [0.0, 3.5]]),
    }
    # Each requested size is paired with the expected draw shape: the
    # 2-element support is appended to the batch shape.
    sizes_to_check = [None, (1), (2, 3)]
    sizes_expected = [(2,), (1, 2), (2, 3, 2)]
    reference_dist_params = {
        "nu": 5,
        "mu": np.array([1.0, 2.0]),
        "cov": np.array([[2.0, 0.0], [0.0, 3.5]]),
    }
    reference_dist = lambda self: functools.partial(
        self.mvstudentt_rng_fn, rng=self.get_random_state()
    )
    tests_to_run = [
        "check_pymc_params_match_rv_op",
        "check_pymc_draws_match_reference",
        "check_rv_size",
    ]


class TestMvStudentTChol(BaseTestDistribution):
    """Check that a ``chol`` parametrization of ``pm.MvStudentT`` reaches the RV op as ``cov``."""

    pymc_dist = pm.MvStudentT
    pymc_dist_params = dict(
        nu=5,
        mu=np.array([1.0, 2.0]),
        chol=np.array([[2.0, 0.0], [0.0, 3.5]]),
    )
    # The expected "cov" is whatever quaddist_matrix produces from the
    # Cholesky factor given above.
    expected_rv_op_params = dict(
        nu=5,
        mu=np.array([1.0, 2.0]),
        cov=quaddist_matrix(chol=pymc_dist_params["chol"]).eval(),
    )
    tests_to_run = ["check_pymc_params_match_rv_op"]


class TestMvStudentTTau(BaseTestDistribution):
    """Check that a ``tau`` parametrization of ``pm.MvStudentT`` reaches the RV op as ``cov``."""

    pymc_dist = pm.MvStudentT
    pymc_dist_params = dict(
        nu=5,
        mu=np.array([1.0, 2.0]),
        tau=np.array([[2.0, 0.0], [0.0, 3.5]]),
    )
    # The expected "cov" is whatever quaddist_matrix produces from the
    # tau matrix given above.
    expected_rv_op_params = dict(
        nu=5,
        mu=np.array([1.0, 2.0]),
        cov=quaddist_matrix(tau=pymc_dist_params["tau"]).eval(),
    )
    tests_to_run = ["check_pymc_params_match_rv_op"]


class TestDirichlet(BaseTestDistribution):
pymc_dist = pm.Dirichlet
pymc_dist_params = {"a": np.array([1.0, 2.0])}
Expand Down Expand Up @@ -1402,22 +1466,6 @@ def ref_rand_evd(size, mu, evds, sigma):
model_args=evd_args,
)

@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
def test_mv_t(self):
def ref_rand(size, nu, Sigma, mu):
normal = st.multivariate_normal.rvs(cov=Sigma, size=size)
chi2 = st.chi2.rvs(df=nu, size=size)[..., None]
return mu + (normal / np.sqrt(chi2 / nu))

for n in [2, 3]:
pymc3_random(
pm.MvStudentT,
{"nu": Domain([5, 10, 25, 50]), "Sigma": PdMatrix(n), "mu": Vector(R, n)},
size=100,
valuedomain=Vector(R, n),
ref_rand=ref_rand,
)

@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
def test_dirichlet_multinomial(self):
def ref_rand(size, a, n):
Expand Down