diff --git a/conda-envs/environment-dev-py37.yml b/conda-envs/environment-dev-py37.yml index 90e22cf6b7..46b3d2d826 100644 --- a/conda-envs/environment-dev-py37.yml +++ b/conda-envs/environment-dev-py37.yml @@ -4,7 +4,7 @@ channels: - conda-forge - defaults dependencies: -- aeppl=0.0.17 +- aeppl=0.0.18 - aesara>=2.2.6 - arviz>=0.11.4 - cachetools>=4.2.1 diff --git a/conda-envs/environment-dev-py38.yml b/conda-envs/environment-dev-py38.yml index 69fd28473b..6b98167eb3 100644 --- a/conda-envs/environment-dev-py38.yml +++ b/conda-envs/environment-dev-py38.yml @@ -4,7 +4,7 @@ channels: - conda-forge - defaults dependencies: -- aeppl=0.0.17 +- aeppl=0.0.18 - aesara>=2.2.6 - arviz>=0.11.4 - cachetools>=4.2.1 diff --git a/conda-envs/environment-dev-py39.yml b/conda-envs/environment-dev-py39.yml index 6a0bfc917d..dd700dba6a 100644 --- a/conda-envs/environment-dev-py39.yml +++ b/conda-envs/environment-dev-py39.yml @@ -4,7 +4,7 @@ channels: - conda-forge - defaults dependencies: -- aeppl=0.0.17 +- aeppl=0.0.18 - aesara>=2.2.6 - arviz>=0.11.4 - cachetools>=4.2.1 diff --git a/conda-envs/environment-test-py37.yml b/conda-envs/environment-test-py37.yml index 964939f493..9aa684412d 100644 --- a/conda-envs/environment-test-py37.yml +++ b/conda-envs/environment-test-py37.yml @@ -4,7 +4,7 @@ channels: - conda-forge - defaults dependencies: -- aeppl=0.0.17 +- aeppl=0.0.18 - aesara>=2.2.6 - arviz>=0.11.4 - cachetools>=4.2.1 diff --git a/conda-envs/environment-test-py38.yml b/conda-envs/environment-test-py38.yml index cec0b1ffa2..8df3a4759e 100644 --- a/conda-envs/environment-test-py38.yml +++ b/conda-envs/environment-test-py38.yml @@ -4,7 +4,7 @@ channels: - conda-forge - defaults dependencies: -- aeppl=0.0.17 +- aeppl=0.0.18 - aesara>=2.2.6 - arviz>=0.11.4 - cachetools>=4.2.1 diff --git a/conda-envs/environment-test-py39.yml b/conda-envs/environment-test-py39.yml index 50d87383c6..e8d50dc2c5 100644 --- a/conda-envs/environment-test-py39.yml +++ b/conda-envs/environment-test-py39.yml @@ -4,7 +4,7 @@ channels: - conda-forge - defaults dependencies: -- aeppl=0.0.17 +- aeppl=0.0.18 - aesara>=2.2.6 - arviz>=0.11.4 - cachetools diff --git a/conda-envs/windows-environment-dev-py38.yml b/conda-envs/windows-environment-dev-py38.yml index ffac106acf..855187c643 100644 --- a/conda-envs/windows-environment-dev-py38.yml +++ b/conda-envs/windows-environment-dev-py38.yml @@ -4,7 +4,7 @@ channels: - defaults dependencies: # base dependencies (see install guide for Windows) -- aeppl=0.0.17 +- aeppl=0.0.18 - aesara>=2.2.6 - arviz>=0.11.4 - cachetools>=4.2.1 diff --git a/conda-envs/windows-environment-test-py38.yml b/conda-envs/windows-environment-test-py38.yml index 6ca7b966dd..e23664e815 100644 --- a/conda-envs/windows-environment-test-py38.yml +++ b/conda-envs/windows-environment-test-py38.yml @@ -4,7 +4,7 @@ channels: - defaults dependencies: # base dependencies (see install guide for Windows) -- aeppl=0.0.17 +- aeppl=0.0.18 - aesara>=2.2.6 - arviz>=0.11.2 - cachetools diff --git a/docs/source/contributing/developer_guide_implementing_distribution.md b/docs/source/contributing/developer_guide_implementing_distribution.md index 85ceee6f90..b1f036b71c 100644 --- a/docs/source/contributing/developer_guide_implementing_distribution.md +++ b/docs/source/contributing/developer_guide_implementing_distribution.md @@ -129,11 +129,11 @@ Here is how the example continues: from pymc.aesaraf import floatX, intX from pymc.distributions.continuous import PositiveContinuous -from pymc.distributions.dist_math import bound +from 
pymc.distributions.dist_math import check_parameters + # Subclassing `PositiveContinuous` will dispatch a default `log` transformation class Blah(PositiveContinuous): - # This will be used by the metaclass `DistributionMeta` to dispatch the # class `logp` and `logcdf` methods to the `blah` `op` rv_op = blah @@ -158,24 +158,25 @@ class Blah(PositiveContinuous): def logp(value, param1, param2): logp_expression = value * (param1 + at.log(param2)) - # We use `bound` for parameter validation. After the default expression, - # multiple comma-separated symbolic conditions can be added. Whenever - # a bound is invalidated, the returned expression evaluates to `-np.inf` - return bound( + # A switch is often used to enforce the distribution support domain + bounded_logp_expression = at.switch( + value >= 0, logp_expression, - value >= 0, - param2 >= 0, - # There is one sneaky optional keyowrd argument, that converts an - # otherwise elemwise `bound` to a reduced scalar bound. This is usually - # needed for multivariate distributions where the dimensionality - # of the bound conditions does not match that of the "value" / "logp" - # By default it is set to `True`. - broadcast_conditions=True, + -np.inf, ) - # logcdf works the same way as logp. For bounded variables, it is expected - # to return `-inf` for values below the domain start and `0` for values - # above the domain end, but only when the parameters are valid. + # We use `check_parameters` for parameter validation. After the default expression, + # multiple comma-separated symbolic conditions can be added. Whenever + # a condition is not met, the returned expression raises an error with the message + # defined in the optional `msg` keyword argument. + return check_parameters( + logp_expression, + param2 >= 0, + msg="param2 >= 0", + ) + + # logcdf works the same way as logp. For bounded variables, it is expected to return + # `-inf` for values below the domain start and `0` for values above the domain end. def logcdf(value, param1, param2): ... @@ -357,7 +358,7 @@ New distributions should have a rich docstring, following the same format as tha It generally looks something like this: ```python - r"""Univariate blah distribution. +r"""Univariate blah distribution. The pdf of this distribution is diff --git a/pymc/aesaraf.py b/pymc/aesaraf.py index 3a3203aedb..14fdad70f6 100644 --- a/pymc/aesaraf.py +++ b/pymc/aesaraf.py @@ -30,9 +30,11 @@ import numpy as np import scipy.sparse as sps +from aeppl.logprob import CheckParameterValue from aesara import config, scalar from aesara.compile.mode import Mode, get_mode from aesara.gradient import grad +from aesara.graph import local_optimizer from aesara.graph.basic import ( Apply, Constant, @@ -899,11 +901,65 @@ def take_along_axis(arr, indices, axis=0): return arr[_make_along_axis_idx(arr_shape, indices, _axis)] -def compile_rv_inplace(inputs, outputs, mode=None, **kwargs): - """Use ``aesara.function`` with the random_make_inplace optimization always enabled. +@local_optimizer(tracks=[CheckParameterValue]) +def local_remove_check_parameter(fgraph, node): + """Rewrite that removes Aeppl's CheckParameterValue - Using this function ensures that compiled functions containing random - variables will produce new samples on each call.
+ This is used by ``compile_pymc`` when the model ``check_bounds`` flag is set to False + """ + if isinstance(node.op, CheckParameterValue): + return [node.inputs[0]] + + +@local_optimizer(tracks=[CheckParameterValue]) +def local_check_parameter_to_ninf_switch(fgraph, node): + if isinstance(node.op, CheckParameterValue): + logp_expr, *logp_conds = node.inputs + if len(logp_conds) > 1: + logp_cond = at.all(logp_conds) + else: + (logp_cond,) = logp_conds + out = at.switch(logp_cond, logp_expr, -np.inf) + out.name = node.op.msg + + if out.dtype != node.outputs[0].dtype: + out = at.cast(out, node.outputs[0].dtype) + + return [out] + + +aesara.compile.optdb["canonicalize"].register( + "local_remove_check_parameter", + local_remove_check_parameter, + use_db_name_as_tag=False, +) + +aesara.compile.optdb["canonicalize"].register( + "local_check_parameter_to_ninf_switch", + local_check_parameter_to_ninf_switch, + use_db_name_as_tag=False, +) + + +def compile_pymc(inputs, outputs, mode=None, **kwargs): + """Use ``aesara.function`` with specialized pymc rewrites always enabled. + + Included rewrites + ----------------- + random_make_inplace + Ensures that compiled functions containing random variables will produce new + samples on each call. + local_check_parameter_to_ninf_switch + Replaces Aeppl's CheckParameterValue assertions in logp expressions with switches + that return -inf in place of the assert. + + Optional rewrites + ----------------- + local_remove_check_parameter + Removes Aeppl's CheckParameterValue assertions from logp expressions. This is used + as an alternative to the default local_check_parameter_to_ninf_switch whenever + this function is called within a model context and the model `check_bounds` flag + is set to False. """ # Avoid circular dependency @@ -921,8 +977,20 @@ def compile_rv_inplace(inputs, outputs, mode=None, **kwargs): if not hasattr(rng, "default_update"): rng.default_update = rv.owner.outputs[0] + # If called inside a model context, see if check_bounds flag is set to False + try: + from pymc.model import modelcontext + + model = modelcontext(None) + check_bounds = model.check_bounds + except TypeError: + check_bounds = True + check_parameter_opt = ( + "local_check_parameter_to_ninf_switch" if check_bounds else "local_remove_check_parameter" + ) + mode = get_mode(mode) - opt_qry = mode.provided_optimizer.including("random_make_inplace") + opt_qry = mode.provided_optimizer.including("random_make_inplace", check_parameter_opt) mode = Mode(linker=mode.linker, optimizer=opt_qry) aesara_function = aesara.function(inputs, outputs, mode=mode, **kwargs) return aesara_function diff --git a/pymc/distributions/bound.py b/pymc/distributions/bound.py index be773ec60e..076efebd60 100644 --- a/pymc/distributions/bound.py +++ b/pymc/distributions/bound.py @@ -12,17 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License.
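The interaction described in the ``compile_pymc`` docstring above can be sketched as follows. This is an editorial illustration rather than part of the patch, and it mirrors the ``test_check_bounds_flag`` test added further down in this diff.

import aesara
import aesara.tensor as at
import numpy as np
import pymc as pm

from pymc.aesaraf import compile_pymc
from pymc.distributions.dist_math import check_parameters

logp = at.ones(3)
cond = np.array([1, 0, 1])
checked = check_parameters(logp, cond, msg="cond must hold")

# A plain aesara.function keeps the CheckParameterValue assertion, so calling it
# raises ParameterValueError because one entry of `cond` is False:
# aesara.function([], checked)()

with pm.Model(check_bounds=True):
    print(compile_pymc([], checked)())  # assertion rewritten to a switch -> [-inf -inf -inf]

with pm.Model(check_bounds=False):
    print(compile_pymc([], checked)())  # assertion removed entirely -> [1. 1. 1.]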
+import aesara.tensor as at import numpy as np -from aeppl.logprob import logprob from aesara.tensor import as_tensor_variable from aesara.tensor.random.op import RandomVariable from aesara.tensor.var import TensorVariable from pymc.aesaraf import floatX, intX from pymc.distributions.continuous import BoundedContinuous -from pymc.distributions.dist_math import bound +from pymc.distributions.dist_math import check_parameters from pymc.distributions.distribution import Continuous, Discrete +from pymc.distributions.logprob import logp from pymc.distributions.shape_utils import to_tuple from pymc.model import modelcontext @@ -67,8 +68,17 @@ def logp(value, distribution, lower, upper): ------- TensorVariable """ - logp = logprob(distribution, value) - return bound(logp, (value >= lower), (value <= upper)) + res = at.switch( + at.or_(at.lt(value, lower), at.gt(value, upper)), + -np.inf, + logp(distribution, value), + ) + + return check_parameters( + res, + lower <= upper, + msg="lower <= upper", + ) class DiscreteBoundRV(BoundRV): @@ -107,8 +117,17 @@ def logp(value, distribution, lower, upper): ------- TensorVariable """ - logp = logprob(distribution, value) - return bound(logp, (value >= lower), (value <= upper)) + res = at.switch( + at.or_(at.lt(value, lower), at.gt(value, upper)), + -np.inf, + logp(distribution, value), + ) + + return check_parameters( + res, + lower <= upper, + msg="lower <= upper", + ) class Bound: diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 50f3f1f368..bd7a1e942d 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -24,9 +24,9 @@ import aesara.tensor as at import numpy as np -from aeppl.logprob import _logprob +from aeppl.logprob import _logprob, logcdf from aesara.assert_op import Assert -from aesara.graph.basic import Apply +from aesara.graph.basic import Apply, Variable from aesara.graph.op import Op from aesara.tensor import gammaln from aesara.tensor.extra_ops import broadcast_shape @@ -76,7 +76,7 @@ def polyagamma_cdf(*args, **kwargs): from pymc.distributions import logp_transform, transforms from pymc.distributions.dist_math import ( SplineWrapper, - bound, + check_parameters, clipped_beta_rvs, i0e, log_normal, @@ -230,13 +230,24 @@ def get_tau_sigma(tau=None, sigma=None): sigma = 1.0 tau = 1.0 else: - tau = sigma ** -2.0 + if isinstance(sigma, Variable): + sigma_ = check_parameters(sigma, sigma > 0, msg="sigma > 0") + else: + assert np.all(np.asarray(sigma) > 0) + sigma_ = sigma + tau = sigma_ ** -2.0 else: if sigma is not None: raise ValueError("Can't pass both tau and sigma") else: - sigma = tau ** -0.5 + if isinstance(tau, Variable): + tau_ = check_parameters(tau, tau > 0, msg="tau > 0") + else: + assert np.all(np.asarray(tau) > 0) + tau_ = tau + + sigma = tau_ ** -0.5 return floatX(tau), floatX(sigma) @@ -442,7 +453,7 @@ def logp(value): ------- TensorVariable """ - return bound(at.zeros_like(value), value > 0) + return at.switch(at.lt(value, 0), -np.inf, at.zeros_like(value)) def logcdf(value): """ @@ -565,9 +576,10 @@ def logcdf(value, mu, sigma): ------- TensorVariable """ - return bound( + return check_parameters( normal_lcdf(mu, sigma, value), 0 < sigma, + msg="sigma > 0", ) @@ -771,7 +783,7 @@ def logp( bounds.append(value <= upper) if not unbounded_lower and not unbounded_upper: bounds.append(lower <= upper) - return bound(logp, *bounds) + return check_parameters(logp, *bounds) class HalfNormal(PositiveContinuous): @@ -873,10 +885,16 @@ def logcdf(value, loc, sigma): 
TensorVariable """ z = zvalue(value, mu=loc, sigma=sigma) - return bound( + logcdf = at.switch( + at.lt(value, loc), + -np.inf, at.log1p(-at.erfc(z / at.sqrt(2.0))), - loc <= value, + ) + + return check_parameters( + logcdf, 0 < sigma, + msg="sigma > 0", ) @@ -1050,15 +1068,22 @@ def logp( TensorVariable """ centered_value = value - alpha - # value *must* be iid. Otherwise this is wrong. - return bound( - logpow(lam / (2.0 * np.pi), 0.5) - - logpow(centered_value, 1.5) - - (0.5 * lam / centered_value * ((centered_value - mu) / mu) ** 2), - centered_value > 0, + logp = at.switch( + at.le(centered_value, 0), + -np.inf, + ( + logpow(lam / (2.0 * np.pi), 0.5) + - logpow(centered_value, 1.5) + - (0.5 * lam / centered_value * ((centered_value - mu) / mu) ** 2) + ), + ) + + return check_parameters( + logp, mu > 0, lam > 0, alpha >= 0, + msg="mu > 0, lam > 0, alpha >= 0", ) def logcdf( @@ -1095,16 +1120,18 @@ def logcdf( a = normal_lcdf(0, 1, (q - 1.0) / r) b = 2.0 / l + normal_lcdf(0, 1, -(q + 1.0) / r) - return bound( + logcdf = at.switch( + at.le(value, 0), + -np.inf, at.switch( at.lt(value, np.inf), a + at.log1pexp(b - a), 0, ), - 0 < value, - 0 < mu, - 0 < lam, - 0 <= alpha, + ) + + return check_parameters( + logcdf, 0 < mu, 0 < lam, 0 <= alpha, msg="mu > 0, lam > 0, alpha >= 0" ) @@ -1219,9 +1246,6 @@ def get_alpha_beta(self, alpha=None, beta=None, mu=None, sigma=None): return alpha, beta - def _distr_parameters_for_repr(self): - return ["alpha", "beta"] - def logcdf(value, alpha, beta): """ Compute the log of the cumulative distribution function for Beta distribution @@ -1238,15 +1262,21 @@ def logcdf(value, alpha, beta): TensorVariable """ - return bound( + logcdf = at.switch( + at.lt(value, 0), + -np.inf, at.switch( at.lt(value, 1), at.log(at.betainc(alpha, beta, value)), 0, ), - 0 <= value, + ) + + return check_parameters( + logcdf, 0 < alpha, 0 < beta, + msg="alpha > 0, beta > 0", ) @@ -1340,9 +1370,18 @@ def logp(value, a, b): ------- TensorVariable """ - logp = at.log(a) + at.log(b) + (a - 1) * at.log(value) + (b - 1) * at.log(1 - value ** a) - - return bound(logp, value >= 0, value <= 1, a > 0, b > 0) + res = at.log(a) + at.log(b) + (a - 1) * at.log(value) + (b - 1) * at.log(1 - value ** a) + res = at.switch( + at.or_(at.lt(value, 0), at.gt(value, 1)), + -np.inf, + res, + ) + return check_parameters( + res, + a > 0, + b > 0, + msg="a > 0, b > 0", + ) def logcdf(value, a, b): r""" @@ -1360,8 +1399,22 @@ def logcdf(value, a, b): ------- TensorVariable """ - logcdf = at.log1mexp(b * at.log1p(-(value ** a))) - return bound(at.switch(value < 1, logcdf, 0), value >= 0, a > 0, b > 0) + res = at.switch( + at.lt(value, 0), + -np.inf, + at.switch( + at.lt(value, 1), + at.log1mexp(b * at.log1p(-(value ** a))), + 0, + ), + ) + + return check_parameters( + res, + a > 0, + b > 0, + msg="a > 0, b > 0", + ) class Exponential(PositiveContinuous): @@ -1434,12 +1487,14 @@ def logcdf(value, mu): TensorVariable """ lam = at.inv(mu) - return bound( + res = at.switch( + at.lt(value, 0), + -np.inf, at.log1mexp(-lam * value), - 0 <= value, - 0 <= lam, ) + return check_parameters(res, 0 <= lam, msg="lam >= 0") + class Laplace(Continuous): r""" @@ -1515,17 +1570,21 @@ def logcdf(value, mu, b): TensorVariable """ y = (value - mu) / b - return bound( + + res = at.switch( + at.le(value, mu), + at.log(0.5) + y, at.switch( - at.le(value, mu), - at.log(0.5) + y, - at.switch( - at.gt(y, 1), - at.log1p(-0.5 * at.exp(-y)), - at.log(1 - 0.5 * at.exp(-y)), - ), + at.gt(y, 1), + at.log1p(-0.5 * at.exp(-y)), + at.log(1 - 
0.5 * at.exp(-y)), ), + ) + + return check_parameters( + res, 0 < b, + msg="b > 0", ) @@ -1622,13 +1681,12 @@ def logp(value, b, kappa, mu): TensorVariable """ value = value - mu - return bound( - at.log(b / (kappa + (kappa ** -1))) - + (-value * b * at.sgn(value) * (kappa ** at.sgn(value))), - 0 < b, - 0 < kappa, + res = at.log(b / (kappa + (kappa ** -1))) + ( + -value * b * at.sgn(value) * (kappa ** at.sgn(value)) ) + return check_parameters(res, 0 < b, 0 < kappa, msg="b > 0, kappa > 0") + class LogNormal(PositiveContinuous): r""" @@ -1733,13 +1791,14 @@ def logcdf(value, mu, sigma): ------- TensorVariable """ - - return bound( + res = at.switch( + at.le(value, 0), + -np.inf, normal_lcdf(mu, sigma, at.log(value)), - 0 < value, - 0 < sigma, ) + return check_parameters(res, 0 < sigma, msg="sigma > 0") + Lognormal = LogNormal @@ -1857,17 +1916,17 @@ def logp(value, nu, mu, sigma): ------- TensorVariable """ - lam, sigma = get_tau_sigma(sigma=sigma) - return bound( + lam, _ = get_tau_sigma(sigma=sigma) + + res = ( gammaln((nu + 1.0) / 2.0) + 0.5 * at.log(lam / (nu * np.pi)) - gammaln(nu / 2.0) - - (nu + 1.0) / 2.0 * at.log1p(lam * (value - mu) ** 2 / nu), - lam > 0, - nu > 0, - sigma > 0, + - (nu + 1.0) / 2.0 * at.log1p(lam * (value - mu) ** 2 / nu) ) + return check_parameters(res, lam > 0, nu > 0, msg="lam > 0, nu > 0") + def logcdf(value, nu, mu, sigma): """ Compute the log of the cumulative distribution function for Student's T distribution @@ -1883,18 +1942,15 @@ def logcdf(value, nu, mu, sigma): ------- TensorVariable """ - lam, sigma = get_tau_sigma(sigma=sigma) + _, sigma = get_tau_sigma(sigma=sigma) t = (value - mu) / sigma sqrt_t2_nu = at.sqrt(t ** 2 + nu) x = (t + sqrt_t2_nu) / (2.0 * sqrt_t2_nu) - return bound( - at.log(at.betainc(nu / 2.0, nu / 2.0, x)), - 0 < nu, - 0 < sigma, - 0 < lam, - ) + res = at.log(at.betainc(nu / 2.0, nu / 2.0, x)) + + return check_parameters(res, 0 < nu, 0 < sigma, msg="nu > 0, sigma > 0") class Pareto(BoundedContinuous): @@ -1983,17 +2039,19 @@ def logcdf( TensorVariable """ arg = (m / value) ** alpha - return bound( + + res = at.switch( + at.lt(value, m), + -np.inf, at.switch( at.le(arg, 1e-5), at.log1p(-arg), at.log(1 - arg), ), - m <= value, - 0 < alpha, - 0 < m, ) + return check_parameters(res, 0 < alpha, 0 < m, msg="alpha > 0, m > 0") + class Cauchy(Continuous): r""" @@ -2071,9 +2129,11 @@ def logcdf(value, alpha, beta): ------- TensorVariable """ - return bound( - at.log(0.5 + at.arctan((value - alpha) / beta) / np.pi), + res = at.log(0.5 + at.arctan((value - alpha) / beta) / np.pi) + return check_parameters( + res, 0 < beta, + msg="beta > 0", ) @@ -2143,12 +2203,14 @@ def logcdf(value, loc, beta): ------- TensorVariable """ - return bound( + res = at.switch( + at.lt(value, loc), + -np.inf, at.log(2 * at.arctan((value - loc) / beta) / np.pi), - loc <= value, - 0 < beta, ) + return check_parameters(res, 0 < beta, msg="beta > 0") + class Gamma(PositiveContinuous): r""" @@ -2231,6 +2293,10 @@ def get_alpha_beta(cls, alpha=None, beta=None, mu=None, sigma=None): if (alpha is not None) and (beta is not None): pass elif (mu is not None) and (sigma is not None): + if isinstance(sigma, Variable): + sigma = check_parameters(sigma, sigma > 0, msg="sigma > 0") + else: + assert np.all(np.asarray(sigma) > 0) alpha = mu ** 2 / sigma ** 2 beta = mu / sigma ** 2 else: @@ -2266,14 +2332,14 @@ def logcdf(value, alpha, inv_beta): TensorVariable """ beta = at.inv(inv_beta) - - return bound( + res = at.switch( + at.lt(value, 0), + -np.inf, at.log(at.gammainc(alpha, 
beta * value)), - 0 <= value, - 0 < alpha, - 0 < beta, ) + return check_parameters(res, 0 < alpha, 0 < beta, msg="alpha > 0, beta > 0") + class InverseGamma(PositiveContinuous): r""" @@ -2355,6 +2421,10 @@ def _get_alpha_beta(cls, alpha, beta, mu, sigma): else: beta = 1 elif (mu is not None) and (sigma is not None): + if isinstance(sigma, Variable): + sigma = check_parameters(sigma, sigma > 0, msg="sigma > 0") + else: + assert np.all(np.asarray(sigma) > 0) alpha = (2 * sigma ** 2 + mu ** 2) / sigma ** 2 beta = mu * (mu ** 2 + sigma ** 2) / sigma ** 2 else: @@ -2385,14 +2455,14 @@ def logcdf(value, alpha, beta): ------- TensorVariable """ - - return bound( + res = at.switch( + at.lt(value, 0), + -np.inf, at.log(at.gammaincc(alpha, beta / value)), - 0 <= value, - 0 < alpha, - 0 < beta, ) + return check_parameters(res, 0 < alpha, 0 < beta, msg="alpha > 0, beta > 0") + class ChiSquared(PositiveContinuous): r""" @@ -2460,7 +2530,7 @@ def logcdf(value, nu): ------- TensorVariable """ - return Gamma.logcdf(value, nu / 2, 2) + return logcdf(Gamma.dist(alpha=nu / 2, beta=0.5), value) # TODO: Remove this once logpt for multiplication is working! @@ -2554,13 +2624,15 @@ def logcdf(value, alpha, beta): TensorVariable """ a = (value / beta) ** alpha - return bound( + + res = at.switch( + at.lt(value, 0), + -np.inf, at.log1mexp(-a), - 0 <= value, - 0 < alpha, - 0 < beta, ) + return check_parameters(res, 0 < alpha, 0 < beta, msg="alpha > 0, beta > 0") + class HalfStudentTRV(RandomVariable): name = "halfstudentt" @@ -2673,22 +2745,21 @@ def logp(value, nu, sigma): TensorVariable """ - lam, sigma = get_tau_sigma(None, sigma) - - return bound( + res = ( at.log(2) + gammaln((nu + 1.0) / 2.0) - gammaln(nu / 2.0) - 0.5 * at.log(nu * np.pi * sigma ** 2) - - (nu + 1.0) / 2.0 * at.log1p(value ** 2 / (nu * sigma ** 2)), - sigma > 0, - lam > 0, - nu > 0, - value >= 0, + - (nu + 1.0) / 2.0 * at.log1p(value ** 2 / (nu * sigma ** 2)) ) - def _distr_parameters_for_repr(self): - return ["nu", "lam"] + res = at.switch( + at.lt(value, 0), + -np.inf, + res, + ) + + return check_parameters(res, sigma > 0, nu > 0, msg="sigma > 0, nu > 0") class ExGaussianRV(RandomVariable): @@ -2811,19 +2882,21 @@ def logp(value, mu, sigma, nu): """ # Alogithm is adapted from dexGAUS.R from gamlss - return bound( - at.switch( - at.gt(nu, 0.05 * sigma), - ( - -at.log(nu) - + (mu - value) / nu - + 0.5 * (sigma / nu) ** 2 - + normal_lcdf(mu + (sigma ** 2) / nu, sigma, value) - ), - log_normal(value, mean=mu, sigma=sigma), + res = at.switch( + at.gt(nu, 0.05 * sigma), + ( + -at.log(nu) + + (mu - value) / nu + + 0.5 * (sigma / nu) ** 2 + + normal_lcdf(mu + (sigma ** 2) / nu, sigma, value) ), + log_normal(value, mean=mu, sigma=sigma), + ) + return check_parameters( + res, 0 < sigma, 0 < nu, + msg="nu > 0, sigma > 0", ) def logcdf(value, mu, sigma, nu): @@ -2849,25 +2922,20 @@ def logcdf(value, mu, sigma, nu): """ # Alogithm is adapted from pexGAUS.R from gamlss - return bound( - at.switch( - at.gt(nu, 0.05 * sigma), - logdiffexp( - normal_lcdf(mu, sigma, value), - ( - (mu - value) / nu - + 0.5 * (sigma / nu) ** 2 - + normal_lcdf(mu + (sigma ** 2) / nu, sigma, value) - ), - ), + res = at.switch( + at.gt(nu, 0.05 * sigma), + logdiffexp( normal_lcdf(mu, sigma, value), + ( + (mu - value) / nu + + 0.5 * (sigma / nu) ** 2 + + normal_lcdf(mu + (sigma ** 2) / nu, sigma, value) + ), ), - 0 < sigma, - 0 < nu, + normal_lcdf(mu, sigma, value), ) - def _distr_parameters_for_repr(self): - return ["mu", "sigma", "nu"] + return check_parameters(res, 0 < sigma, 0 
< nu, msg="sigma > 0, nu > 0") class VonMises(CircularContinuous): @@ -3042,14 +3110,15 @@ def logp(value, mu, sigma, alpha): ------- TensorVariable """ - tau, sigma = get_tau_sigma(sigma=sigma) - return bound( + tau, _ = get_tau_sigma(sigma=sigma) + + res = ( at.log(1 + at.erf(((value - mu) * at.sqrt(tau) * alpha) / at.sqrt(2))) - + (-tau * (value - mu) ** 2 + at.log(tau / np.pi / 2.0)) / 2.0, - tau > 0, - sigma > 0, + + (-tau * (value - mu) ** 2 + at.log(tau / np.pi / 2.0)) / 2.0 ) + return check_parameters(res, tau > 0, msg="tau > 0") + class Triangular(BoundedContinuous): r""" @@ -3138,22 +3207,25 @@ def logcdf(value, lower, c, upper): ------- TensorVariable """ - return bound( + res = at.switch( + at.le(value, lower), + -np.inf, at.switch( - at.le(value, lower), - -np.inf, + at.le(value, c), + at.log(((value - lower) ** 2) / ((upper - lower) * (c - lower))), at.switch( - at.le(value, c), - at.log(((value - lower) ** 2) / ((upper - lower) * (c - lower))), - at.switch( - at.lt(value, upper), - at.log1p(-((upper - value) ** 2) / ((upper - lower) * (upper - c))), - 0, - ), + at.lt(value, upper), + at.log1p(-((upper - value) ** 2) / ((upper - lower) * (upper - c))), + 0, ), ), + ) + + return check_parameters( + res, lower <= c, c <= upper, + msg="lower <= c <= upper", ) @@ -3248,10 +3320,9 @@ def logcdf( ------- TensorVariable """ - return bound( - -at.exp(-(value - mu) / beta), - 0 < beta, - ) + res = -at.exp(-(value - mu) / beta) + + return check_parameters(res, 0 < beta, msg="beta > 0") class RiceRV(RandomVariable): @@ -3384,10 +3455,18 @@ def logp(value, b, sigma): TensorVariable """ x = value / sigma - return bound( + + res = at.switch( + at.le(value, 0), + -np.inf, at.log(x * at.exp((-(x - b) * (x - b)) / 2) * i0e(x * b) / sigma), + ) + + return check_parameters( + res, sigma >= 0, - value > 0, + b >= 0, + msg="sigma >= 0, b >= 0", ) @@ -3464,10 +3543,12 @@ def logcdf(value, mu, s): ------- TensorVariable """ + res = -at.log1pexp(-(value - mu) / s) - return bound( - -at.log1pexp(-(value - mu) / s), + return check_parameters( + res, 0 < s, + msg="s > 0", ) @@ -3565,14 +3646,22 @@ def logp(value, mu, sigma): ------- TensorVariable """ - tau, sigma = get_tau_sigma(sigma=sigma) - return bound( - -0.5 * tau * (logit(value) - mu) ** 2 - + 0.5 * at.log(tau / (2.0 * np.pi)) - - at.log(value * (1 - value)), - value > 0, - value < 1, + tau, _ = get_tau_sigma(sigma=sigma) + + res = at.switch( + at.or_(at.le(value, 0), at.ge(value, 1)), + -np.inf, + ( + -0.5 * tau * (logit(value) - mu) ** 2 + + 0.5 * at.log(tau / (2.0 * np.pi)) + - at.log(value * (1 - value)) + ), + ) + + return check_parameters( + res, tau > 0, + msg="tau > 0", ) @@ -3810,10 +3899,8 @@ def logp(value, mu, sigma): TensorVariable """ scaled = (value - mu) / sigma - return bound( - (-(1 / 2) * (scaled + at.exp(-scaled)) - at.log(sigma) - (1 / 2) * at.log(2 * np.pi)), - 0 < sigma, - ) + res = -(1 / 2) * (scaled + at.exp(-scaled)) - at.log(sigma) - (1 / 2) * at.log(2 * np.pi) + return check_parameters(res, 0 < sigma, msg="sigma > 0") def logcdf(value, mu, sigma): """ @@ -3831,9 +3918,11 @@ def logcdf(value, mu, sigma): TensorVariable """ scaled = (value - mu) / sigma - return bound( - at.log(at.erfc(at.exp(-scaled / 2) * (2 ** -0.5))), + res = at.log(at.erfc(at.exp(-scaled / 2) * (2 ** -0.5))) + return check_parameters( + res, 0 < sigma, + msg="sigma > 0", ) @@ -4021,7 +4110,16 @@ def logp(value, h, z): TensorVariable """ - return bound(_PolyaGammaLogDistFunc(True)(value, h, z), h > 0, value > 0) + res = at.switch( + at.le(value, 
0), + -np.inf, + _PolyaGammaLogDistFunc(get_pdf=True)(value, h, z), + ) + return check_parameters( + res, + h > 0, + msg="h > 0", + ) def logcdf(value, h, z): """ @@ -4038,4 +4136,14 @@ def logcdf(value, h, z): ------- TensorVariable """ - return bound(_PolyaGammaLogDistFunc(False)(value, h, z), h > 0, value > 0) + res = at.switch( + at.le(value, 0), + -np.inf, + _PolyaGammaLogDistFunc(get_pdf=False)(value, h, z), + ) + + return check_parameters( + res, + h > 0, + msg="h > 0", + ) diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index 68a386208c..1a4900cc8a 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -33,7 +33,7 @@ from pymc.distributions.dist_math import ( betaln, binomln, - bound, + check_parameters, factln, log_diff_normal_cdf, logpow, @@ -135,14 +135,15 @@ def logp(value, n, p): ------- TensorVariable """ - return bound( + + res = at.switch( + at.or_(at.lt(value, 0), at.gt(value, n)), + -np.inf, binomln(n, value) + logpow(p, value) + logpow(1 - p, n - value), - 0 <= value, - value <= n, - 0 <= p, - p <= 1, ) + return check_parameters(res, 0 < n, 0 <= p, p <= 1, msg="n > 0, 0 <= p <= 1") + def logcdf(value, n, p): """ Compute the log of the cumulative distribution function for Binomial distribution @@ -160,16 +161,22 @@ def logcdf(value, n, p): """ value = at.floor(value) - return bound( + res = at.switch( + at.lt(value, 0), + -np.inf, at.switch( at.lt(value, n), at.log(at.betainc(n - value, value + 1, 1 - p)), 0, ), - 0 <= value, + ) + + return check_parameters( + res, 0 < n, 0 <= p, p <= 1, + msg="n > 0, 0 <= p <= 1", ) @@ -258,13 +265,12 @@ def logp(value, n, alpha, beta): ------- TensorVariable """ - return bound( + res = at.switch( + at.or_(at.lt(value, 0), at.gt(value, n)), + -np.inf, binomln(n, value) + betaln(value + alpha, n - value + beta) - betaln(alpha, beta), - value >= 0, - value <= n, - alpha > 0, - beta > 0, ) + return check_parameters(res, n >= 0, alpha > 0, beta > 0, msg="n >= 0, alpha > 0, beta > 0") def logcdf(value, n, alpha, beta): """ @@ -287,21 +293,22 @@ def logcdf(value, n, alpha, beta): ) safe_lower = at.switch(at.lt(value, 0), value, 0) - - return bound( + res = at.switch( + at.lt(value, 0), + -np.inf, at.switch( at.lt(value, n), at.logsumexp( - BetaBinomial.logp(at.arange(safe_lower, value + 1), n, alpha, beta), + logp( + BetaBinomial.dist(alpha=alpha, beta=beta, n=n), + at.arange(safe_lower, value + 1), + ), keepdims=False, ), 0, ), - 0 <= value, - 0 <= n, - 0 < alpha, - 0 < beta, ) + return check_parameters(res, 0 <= n, 0 < alpha, 0 < beta, msg="n >= 0, alpha > 0, beta > 0") class Bernoulli(Discrete): @@ -383,14 +390,14 @@ def logp(value, p): TensorVariable """ - return bound( + res = at.switch( + at.or_(at.lt(value, 0), at.gt(value, 1)), + -np.inf, at.switch(value, at.log(p), at.log1p(-p)), - value >= 0, - value <= 1, - p >= 0, - p <= 1, ) + return check_parameters(res, p >= 0, p <= 1, msg="0 <= p <= 1") + def logcdf(value, p): """ Compute the log of the cumulative distribution function for Bernoulli distribution @@ -406,17 +413,16 @@ def logcdf(value, p): ------- TensorVariable """ - - return bound( + res = at.switch( + at.lt(value, 0), + -np.inf, at.switch( at.lt(value, 1), at.log1p(-p), 0, ), - 0 <= value, - 0 <= p, - p <= 1, ) + return check_parameters(res, 0 <= p, p <= 1, msg="0 <= p <= 1") class DiscreteWeibullRV(RandomVariable): @@ -496,14 +502,15 @@ def logp(value, q, beta): ------- TensorVariable """ - return bound( + + res = at.switch( + at.lt(value, 0), + -np.inf, 
at.log(at.power(q, at.power(value, beta)) - at.power(q, at.power(value + 1, beta))), - 0 <= value, - 0 < q, - q < 1, - 0 < beta, ) + return check_parameters(res, 0 < q, q < 1, 0 < beta, msg="0 < q < 1, beta > 0") + def logcdf(value, q, beta): """ Compute the log of the cumulative distribution function for Discrete Weibull distribution @@ -519,13 +526,13 @@ def logcdf(value, q, beta): ------- TensorVariable """ - return bound( + + res = at.switch( + at.lt(value, 0), + -np.inf, at.log1p(-at.power(q, at.power(value + 1, beta))), - 0 <= value, - 0 < q, - q < 1, - 0 < beta, ) + return check_parameters(res, 0 < q, q < 1, 0 < beta, msg="0 < q < 1, beta > 0") class Poisson(Discrete): @@ -599,7 +606,12 @@ def logp(value, mu): ------- TensorVariable """ - log_prob = bound(logpow(mu, value) - factln(value) - mu, mu >= 0, value >= 0) + res = at.switch( + at.lt(value, 0), + -np.inf, + logpow(mu, value) - factln(value) - mu, + ) + log_prob = check_parameters(res, mu >= 0, msg="mu >= 0") # Return zero when mu and value are both zero return at.switch(at.eq(mu, 0) * at.eq(value, 0), 0, log_prob) @@ -623,12 +635,16 @@ def logcdf(value, mu): safe_mu = at.switch(at.lt(mu, 0), 0, mu) safe_value = at.switch(at.lt(value, 0), 0, value) - return bound( - at.log(at.gammaincc(safe_value + 1, safe_mu)), - 0 <= value, - 0 <= mu, + res = ( + at.switch( + at.lt(value, 0), + -np.inf, + at.log(at.gammaincc(safe_value + 1, safe_mu)), + ), ) + return check_parameters(res, 0 <= mu, msg="mu >= 0") + class NegativeBinomial(Discrete): R""" @@ -743,13 +759,22 @@ def logp(value, n, p): """ alpha = n mu = alpha * (1 - p) / p - negbinom = bound( - binomln(value + alpha - 1, value) - + logpow(mu / (mu + alpha), value) - + logpow(alpha / (mu + alpha), alpha), - value >= 0, + + res = at.switch( + at.lt(value, 0), + -np.inf, + ( + binomln(value + alpha - 1, value) + + logpow(mu / (mu + alpha), value) + + logpow(alpha / (mu + alpha), alpha) + ), + ) + + negbinom = check_parameters( + res, mu > 0, alpha > 0, + msg="mu > 0, alpha > 0", ) # Return Poisson when alpha gets very large. @@ -770,12 +795,17 @@ def logcdf(value, n, p): ------- TensorVariable """ - return bound( + res = at.switch( + at.lt(value, 0), + -np.inf, at.log(at.betainc(n, at.floor(value) + 1, p)), - 0 <= value, + ) + return check_parameters( + res, 0 < n, 0 <= p, p <= 1, + msg="0 < n, 0 <= p <= 1", ) @@ -844,11 +874,18 @@ def logp(value, p): ------- TensorVariable """ - return bound( + + res = at.switch( + at.lt(value, 1), + -np.inf, at.log(p) + logpow(1 - p, value - 1), + ) + + return check_parameters( + res, 0 <= p, p <= 1, - value >= 1, + msg="0 <= p <= 1", ) def logcdf(value, p): @@ -867,11 +904,16 @@ def logcdf(value, p): TensorVariable """ - return bound( + res = at.switch( + at.lt(value, 0), + -np.inf, at.log1mexp(at.log1p(-p) * value), - 0 <= value, + ) + return check_parameters( + res, 0 <= p, p <= 1, + msg="0 <= p <= 1", ) @@ -965,7 +1007,18 @@ def logp(value, good, bad, n): # value in [max(0, n - N + k), min(k, n)] lower = at.switch(at.gt(n - tot + good, 0), n - tot + good, 0) upper = at.switch(at.lt(good, n), good, n) - return bound(result, lower <= value, value <= upper) + + res = at.switch( + at.lt(value, lower), + -np.inf, + at.switch( + at.le(value, upper), + result, + -np.inf, + ), + ) + + return check_parameters(res, lower <= upper, msg="lower <= upper") def logcdf(value, good, bad, n): """ @@ -991,7 +1044,9 @@ def logcdf(value, good, bad, n): # TODO: Use lower upper in locgdf for smarter logsumexp? 
safe_lower = at.switch(at.lt(value, 0), value, 0) - return bound( + res = at.switch( + at.lt(value, 0), + -np.inf, at.switch( at.lt(value, n), at.logsumexp( @@ -1000,12 +1055,16 @@ def logcdf(value, good, bad, n): ), 0, ), - 0 <= value, + ) + + return check_parameters( + res, 0 < N, 0 <= good, 0 <= n, good <= N, n <= N, + msg="N > 0, 0 <= good <= N, 0 <= n <= N", ) @@ -1092,11 +1151,12 @@ def logp(value, lower, upper): ------- TensorVariable """ - return bound( + res = at.switch( + at.or_(at.lt(value, lower), at.gt(value, upper)), + -np.inf, at.fill(value, -at.log(upper - lower + 1)), - lower <= value, - value <= upper, ) + return check_parameters(res, lower <= upper, msg="lower <= upper") def logcdf(value, lower, upper): """ @@ -1114,16 +1174,18 @@ def logcdf(value, lower, upper): TensorVariable """ - return bound( + res = at.switch( + at.le(value, lower), + -np.inf, at.switch( at.lt(value, upper), at.log(at.minimum(at.floor(value), upper) - lower + 1) - at.log(upper - lower + 1), 0, ), - lower <= value, - lower <= upper, ) + return check_parameters(res, lower <= upper, msg="lower <= upper") + class Categorical(Discrete): R""" @@ -1205,12 +1267,14 @@ def logp(value, p): else: a = at.log(p[value_clip]) - return bound( + res = at.switch( + at.or_(at.lt(value, 0), at.gt(value, k - 1)), + -np.inf, a, - value >= 0, - value <= (k - 1), - at.all(p_ >= 0, axis=-1), - at.all(p <= 1, axis=-1), + ) + + return check_parameters( + res, at.all(p_ >= 0, axis=-1), at.all(p <= 1, axis=-1), msg="0 <= p <=1" ) @@ -1267,9 +1331,10 @@ def logp(value, c): ------- TensorVariable """ - return bound( - at.zeros_like(value), + return at.switch( at.eq(value, c), + at.zeros_like(value), + -np.inf, ) @@ -1368,18 +1433,20 @@ def logp(value, psi, theta): TensorVariable """ - logp_val = at.switch( + res = at.switch( at.gt(value, 0), at.log(psi) + logp(Poisson.dist(mu=theta), value), at.logaddexp(at.log1p(-psi), at.log(psi) - theta), ) - return bound( - logp_val, - 0 <= value, + res = at.switch(at.lt(value, 0), -np.inf, res) + + return check_parameters( + res, 0 <= psi, psi <= 1, 0 <= theta, + msg="0 <= psi <= 1, theta >= 0", ) def logcdf(value, psi, theta): @@ -1398,15 +1465,17 @@ def logcdf(value, psi, theta): TensorVariable """ - return bound( + res = at.switch( + at.lt(value, 0), + -np.inf, at.logaddexp( at.log1p(-psi), at.log(psi) + logcdf(Poisson.dist(mu=theta), value), ), - 0 <= value, - 0 <= psi, - psi <= 1, - 0 <= theta, + ) + + return check_parameters( + res, 0 <= psi, psi <= 1, 0 <= theta, msg="0 <= psi <= 1, theta >= 0" ) @@ -1507,20 +1576,25 @@ def logp(value, psi, n, p): TensorVariable """ - logp_val = at.switch( + res = at.switch( at.gt(value, 0), at.log(psi) + logp(Binomial.dist(n=n, p=p), value), at.logaddexp(at.log1p(-psi), at.log(psi) + n * at.log1p(-p)), ) - return bound( - logp_val, - 0 <= value, - value <= n, + res = at.switch( + at.lt(value, 0), + -np.inf, + res, + ) + + return check_parameters( + res, 0 <= psi, psi <= 1, 0 <= p, p <= 1, + msg="0 <= psi <= 1, 0 <= p <= 1", ) def logcdf(value, psi, n, p): @@ -1538,18 +1612,22 @@ def logcdf(value, psi, n, p): ------- TensorVariable """ - - return bound( + res = at.switch( + at.or_(at.lt(value, 0), at.gt(value, n)), + -np.inf, at.logaddexp( at.log1p(-psi), at.log(psi) + logcdf(Binomial.dist(n=n, p=p), value), ), - 0 <= value, - value <= n, + ) + + return check_parameters( + res, 0 <= psi, psi <= 1, 0 <= p, p <= 1, + msg="0 <= psi <= 1, 0 <= p <= 1", ) @@ -1683,18 +1761,26 @@ def logp(value, psi, n, p): TensorVariable """ - return bound( - at.switch( - 
at.gt(value, 0), - at.log(psi) + logp(NegativeBinomial.dist(n=n, p=p), value), - at.logaddexp(at.log1p(-psi), at.log(psi) + n * at.log(p)), - ), - 0 <= value, + res = at.switch( + at.gt(value, 0), + at.log(psi) + logp(NegativeBinomial.dist(n=n, p=p), value), + at.logaddexp(at.log1p(-psi), at.log(psi) + n * at.log(p)), + ) + + res = at.switch( + at.lt(value, 0), + -np.inf, + res, + ) + + return check_parameters( + res, 0 <= psi, psi <= 1, 0 < n, 0 <= p, p <= 1, + msg="0 <= psi <= 1, n > 0, 0 <= p <= 1", ) def logcdf(value, psi, n, p): @@ -1712,16 +1798,23 @@ def logcdf(value, psi, n, p): ------- TensorVariable """ - return bound( + res = at.switch( + at.lt(value, 0), + -np.inf, at.logaddexp( - at.log1p(-psi), at.log(psi) + logcdf(NegativeBinomial.dist(n=n, p=p), value) + at.log1p(-psi), + at.log(psi) + logcdf(NegativeBinomial.dist(n=n, p=p), value), ), - 0 <= value, + ) + + return check_parameters( + res, 0 <= psi, psi <= 1, 0 < n, 0 < p, p <= 1, + msg="0 <= psi <= 1, n > 0, 0 < p <= 1", ) diff --git a/pymc/distributions/dist_math.py b/pymc/distributions/dist_math.py index 9b4a2bb57b..4547a9aa30 100644 --- a/pymc/distributions/dist_math.py +++ b/pymc/distributions/dist_math.py @@ -19,14 +19,17 @@ """ import warnings +from typing import Iterable + import aesara import aesara.tensor as at import numpy as np import scipy.linalg import scipy.stats +from aeppl.logprob import CheckParameterValue from aesara.compile.builders import OpFromGraph -from aesara.graph.basic import Apply +from aesara.graph.basic import Apply, Variable from aesara.graph.op import Op from aesara.scalar import UnaryScalarOp, upgrade_to_float_no_complex from aesara.tensor import gammaln @@ -46,59 +49,22 @@ } -def bound(logp, *conditions, broadcast_conditions=True): +def check_parameters(logp: Variable, *conditions: Iterable[Variable], msg: str = ""): """ - Bounds a log probability density with several conditions. - When conditions are not met, the logp values are replaced by -inf. - - Note that bound should not be used to enforce the logic of the logp under the normal - support as it can be disabled by the user via check_bounds = False in pm.Model() - - Parameters - ---------- - logp: float - *conditions: booleans - broadcast_conditions: bool (optional, default=True) - If True, conditions are broadcasted and applied element-wise to each value in logp. - If False, conditions are collapsed via at.all(). As a consequence the entire logp - array is either replaced by -inf or unchanged. + Wrap a log probability graph in a CheckParameterValue that asserts several + conditions are True. When conditions are not met a ParameterValueError assertion is + raised, with an optional custom message defined by `msg` - Setting broadcasts_conditions to False is necessary for most (all?) multivariate - distributions where the dimensions of the conditions do not unambigously match - that of the logp. 
- - Returns - ------- - logp with elements set to -inf where any condition is False + Note that check_parameter should not be used to enforce the logic of the logp + expression under the normal parameter support as it can be disabled by the user via + check_bounds = False in pm.Model() """ - - # If called inside a model context, see if bounds check is disabled - try: - from pymc.model import modelcontext - - model = modelcontext(None) - if not model.check_bounds: - return logp - except TypeError: - pass # no model found - - if broadcast_conditions: - alltrue = alltrue_elemwise - else: - alltrue = alltrue_scalar - - return at.switch(alltrue(conditions), logp, -np.inf) - - -def alltrue_elemwise(vals): - ret = 1 - for c in vals: - ret = ret * (1 * c) - return ret - - -def alltrue_scalar(vals): - return at.all([at.all(1 * val) for val in vals]) + # at.all does not accept True/False, but accepts np.array(True)/np.array(False) + conditions = [ + cond if (cond is not True and cond is not False) else np.array(cond) for cond in conditions + ] + all_true_scalar = at.all([at.all(cond) for cond in conditions]) + return CheckParameterValue(msg)(logp, all_true_scalar) def logpow(x, m): diff --git a/pymc/distributions/mixture.py b/pymc/distributions/mixture.py index d7ca78a8c4..6105bb726c 100644 --- a/pymc/distributions/mixture.py +++ b/pymc/distributions/mixture.py @@ -20,7 +20,7 @@ from pymc.aesaraf import _conversion_map, take_along_axis from pymc.distributions.continuous import Normal, get_tau_sigma -from pymc.distributions.dist_math import bound +from pymc.distributions.dist_math import check_parameters from pymc.distributions.distribution import Discrete, Distribution from pymc.distributions.shape_utils import to_tuple from pymc.math import logsumexp @@ -416,7 +416,7 @@ def logp(self, value): """ w = self.w - return bound( + return check_parameters( logsumexp(at.log(w) + self._comp_logp(value), axis=-1, keepdims=False), w >= 0, w <= 1, @@ -744,7 +744,7 @@ def logp(self, value): value = at.shape_padaxis(value, axis=mixture_axis - comp_dists_ndim) comp_logp = comp_dists.logp(value) - return bound( + return check_parameters( logsumexp(at.log(w) + comp_logp, axis=mixture_axis, keepdims=False), w >= 0, w <= 1, diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 43a66fe1e5..746400c9a2 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -44,7 +44,7 @@ from pymc.aesaraf import floatX, intX from pymc.distributions import transforms from pymc.distributions.continuous import ChiSquared, Normal, assert_negative_support -from pymc.distributions.dist_math import bound, factln, logpow, multigammaln +from pymc.distributions.dist_math import check_parameters, factln, logpow, multigammaln from pymc.distributions.distribution import Continuous, Discrete from pymc.distributions.shape_utils import ( broadcast_dist_samples_to, @@ -252,10 +252,7 @@ def logp(value, mu, cov): quaddist, logdet, ok = quaddist_parse(value, mu, cov) k = floatX(value.shape[-1]) norm = -0.5 * k * pm.floatX(np.log(2 * np.pi)) - return bound(norm - 0.5 * quaddist - logdet, ok) - - def _distr_parameters_for_repr(self): - return ["mu", "cov"] + return check_parameters(norm - 0.5 * quaddist - logdet, ok) class MvStudentTRV(RandomVariable): @@ -380,10 +377,13 @@ def logp(value, nu, mu, cov): norm = gammaln((nu + k) / 2.0) - gammaln(nu / 2.0) - 0.5 * k * at.log(nu * np.pi) inner = -(nu + k) / 2.0 * at.log1p(quaddist / nu) - return bound(norm + inner - logdet, ok) + res = 
norm + inner - logdet - def _distr_parameters_for_repr(self): - return ["nu", "mu", "cov"] + return check_parameters( + res, + ok, + nu > 0, + ) class Dirichlet(Continuous): @@ -448,16 +448,20 @@ def logp(value, a): TensorVariable """ # only defined for sum(value) == 1 - return bound( - at.sum(logpow(value, a - 1) - gammaln(a), axis=-1) + gammaln(at.sum(a, axis=-1)), - at.all(value >= 0), - at.all(value <= 1), - at.all(a > 0), - broadcast_conditions=False, + res = at.sum(logpow(value, a - 1) - gammaln(a), axis=-1) + gammaln(at.sum(a, axis=-1)) + res = at.switch( + at.or_( + at.any(at.lt(value, 0), axis=-1), + at.any(at.gt(value, 1), axis=-1), + ), + -np.inf, + res, + ) + return check_parameters( + res, + a > 0, + msg="a > 0", ) - - def _distr_parameters_for_repr(self): - return ["a"] class MultinomialRV(MultinomialRV): @@ -554,14 +558,19 @@ def logp(value, n, p): ------- TensorVariable """ - return bound( - factln(n) + at.sum(-factln(value) + logpow(p, value), axis=-1), - at.all(value >= 0), - at.all(at.eq(at.sum(value, axis=-1), n)), - at.all(p <= 1), - at.all(at.eq(at.sum(p, axis=-1), 1)), - at.all(at.ge(n, 0)), - broadcast_conditions=False, + + res = factln(n) + at.sum(-factln(value) + logpow(p, value), axis=-1) + res = at.switch( + at.or_(at.any(at.lt(value, 0), axis=-1), at.neq(at.sum(value, axis=-1), n)), + -np.inf, + res, + ) + return check_parameters( + res, + p <= 1, + at.eq(at.sum(p, axis=-1), 1), + at.ge(n, 0), + msg="p <= 1, sum(p) = 1, n >= 0", ) @@ -659,17 +668,22 @@ def logp(value, n, a): sum_a = a.sum(axis=-1) const = (gammaln(n + 1) + gammaln(sum_a)) - gammaln(n + sum_a) series = gammaln(value + a) - (gammaln(value + 1) + gammaln(a)) - result = const + series.sum(axis=-1) + res = const + series.sum(axis=-1) + + res = at.switch( + at.or_( + at.any(at.lt(value, 0), axis=-1), + at.neq(at.sum(value, axis=-1), n), + ), + -np.inf, + res, + ) - # Bounds checking to confirm parameters and data meet all constraints - # and that each observation value_i sums to n_i. - return bound( - result, - value >= 0, + return check_parameters( + res, a > 0, n >= 0, - at.eq(value.sum(axis=-1), n), - broadcast_conditions=False, + msg="a > 0, n >= 0", ) @@ -937,7 +951,7 @@ def logp(X, nu, V): IVI = det(V) IXI = det(X) - return bound( + return check_parameters( ( (nu - p - 1) * at.log(IXI) - trace(matrix_inverse(V).dot(X)) @@ -949,7 +963,6 @@ def logp(X, nu, V): matrix_pos_def(X), at.eq(X, X.T), nu > (p - 1), - broadcast_conditions=False, ) @@ -1234,9 +1247,6 @@ def random(self, point=None, size=None): # samples = np.reshape(samples, size + sample_shape) # return samples - def _distr_parameters_for_repr(self): - return ["eta", "n"] - def LKJCholeskyCov(name, eta, n, sd_dist, compute_corr=False, store_in_trace=True, *args, **kwargs): r"""Wrapper function for covariance matrix with LKJ distributed correlations. 
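The multivariate conversions above (MvStudentT, Dirichlet, Multinomial, DirichletMultinomial) all follow the same two-step pattern: value-domain violations are mapped to -inf row-wise over the last axis, while parameter conditions go through ``check_parameters``. Below is a minimal sketch of that pattern, not part of the patch, with the ``logpow`` helper simplified to a plain ``at.log``.

import aesara.tensor as at
import numpy as np

from aesara.tensor import gammaln
from pymc.distributions.dist_math import check_parameters


def dirichlet_style_logp(value, a):
    # Core expression, simplified from the Dirichlet logp above.
    res = at.sum((a - 1) * at.log(value) - gammaln(a), axis=-1) + gammaln(at.sum(a, axis=-1))
    # Support violations in `value` give -inf for that row of the logp...
    res = at.switch(
        at.or_(at.any(at.lt(value, 0), axis=-1), at.any(at.gt(value, 1), axis=-1)),
        -np.inf,
        res,
    )
    # ...while invalid parameters raise, unless compile_pymc rewrites or removes the check.
    return check_parameters(res, a > 0, msg="a > 0")


print(dirichlet_style_logp(np.array([0.2, 0.3, 0.5]), np.array([1.0, 1.0, 1.0])).eval())
# ~0.6931 = log(2), the flat Dirichlet density on the 2-simplex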
@@ -1534,18 +1544,14 @@ def logp(self, x): result = _lkj_normalizing_constant(eta, n) result += (eta - 1.0) * at.log(det(X)) - return bound( + return check_parameters( result, X >= -1, X <= 1, matrix_pos_def(X), eta > 0, - broadcast_conditions=False, ) - def _distr_parameters_for_repr(self): - return ["eta", "n"] - class MatrixNormalRV(RandomVariable): name = "matrixnormal" @@ -1783,9 +1789,6 @@ def logp(value, mu, rowchol, colchol): norm = -0.5 * m * n * pm.floatX(np.log(2 * np.pi)) return norm - 0.5 * trquaddist - m * half_collogdet - n * half_rowlogdet - def _distr_parameters_for_repr(self): - return ["mu"] - class KroneckerNormalRV(RandomVariable): name = "kroneckernormal" @@ -1975,9 +1978,6 @@ def logp(value, mu, sigma, *covs): a = -(quad + logdet + N * at.log(2 * np.pi)) / 2.0 return a - def _distr_parameters_for_repr(self): - return ["mu"] - class CARRV(RandomVariable): name = "car" @@ -2146,9 +2146,10 @@ def logp(value, mu, W, alpha, tau): tau_dot_delta = D[None, :] * delta - alpha * Wdelta logquad = (tau * delta * tau_dot_delta).sum(axis=-1) - return bound( + return check_parameters( 0.5 * (logtau + logdet - logquad), - at.all(alpha <= 1), - at.all(alpha >= -1), + alpha <= 1, + alpha >= -1, tau > 0, + msg="-1 <= alpha <= 1, tau > 0", ) diff --git a/pymc/gp/util.py b/pymc/gp/util.py index 730e664eec..bd419d6b4f 100644 --- a/pymc/gp/util.py +++ b/pymc/gp/util.py @@ -31,7 +31,7 @@ from aesara.tensor.var import TensorConstant from scipy.cluster.vq import kmeans -from pymc.aesaraf import compile_rv_inplace, walk_model +from pymc.aesaraf import compile_pymc, walk_model # Avoid circular dependency when importing modelcontext from pymc.distributions.distribution import NoDistribution @@ -68,7 +68,7 @@ def replace_with_values(vars_needed, replacements=None, model=None): if len(inputs) == 0: return tuple(v.eval() for v in vars_needed) - fn = compile_rv_inplace( + fn = compile_pymc( inputs, vars_needed, allow_input_downcast=True, diff --git a/pymc/initial_point.py b/pymc/initial_point.py index b5f29a7e98..a53f9af33e 100644 --- a/pymc/initial_point.py +++ b/pymc/initial_point.py @@ -23,7 +23,7 @@ from aesara.graph.fg import FunctionGraph from aesara.tensor.var import TensorVariable -from pymc.aesaraf import compile_rv_inplace +from pymc.aesaraf import compile_pymc from pymc.util import get_transformed_name, get_untransformed_name, is_transformed_name StartDict = Dict[Union[Variable, str], Union[np.ndarray, Variable, str]] @@ -185,9 +185,7 @@ def find_rng_nodes(variables): new_rng = np.random.Generator(np.random.PCG64()) new_rng_nodes.append(aesara.shared(new_rng)) graph.replace_all(zip(rng_nodes, new_rng_nodes), import_missing=True) - func = compile_rv_inplace( - inputs=[], outputs=graph.outputs, mode=aesara.compile.mode.FAST_COMPILE - ) + func = compile_pymc(inputs=[], outputs=graph.outputs, mode=aesara.compile.mode.FAST_COMPILE) varnames = [] for var in model.free_RVs: diff --git a/pymc/model.py b/pymc/model.py index 4dcb20f16c..625b1f3bfa 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -48,7 +48,7 @@ from pandas import Series from pymc.aesaraf import ( - compile_rv_inplace, + compile_pymc, gradient, hessian, inputvars, @@ -462,7 +462,7 @@ def __init__( inputs = grad_vars - self._aesara_function = compile_rv_inplace(inputs, outputs, givens=givens, **kwargs) + self._aesara_function = compile_pymc(inputs, outputs, givens=givens, **kwargs) def set_weights(self, values): if values.shape != (self._n_costs - 1,): @@ -1400,7 +1400,7 @@ def makefn(self, outs, mode=None, *args, **kwargs): Compiled 
Aesara function """ with self: - return compile_rv_inplace( + return compile_pymc( self.value_vars, outs, allow_input_downcast=True, diff --git a/pymc/sampling.py b/pymc/sampling.py index ff1261e4ce..30be104412 100644 --- a/pymc/sampling.py +++ b/pymc/sampling.py @@ -37,7 +37,7 @@ import pymc as pm -from pymc.aesaraf import change_rv_size, compile_rv_inplace, inputvars, walk_model +from pymc.aesaraf import change_rv_size, compile_pymc, inputvars, walk_model from pymc.backends.arviz import _DefaultTrace from pymc.backends.base import BaseTrace, MultiTrace from pymc.backends.ndarray import NDArray @@ -1697,7 +1697,7 @@ def sample_posterior_predictive( if size is not None: vars_to_sample = [change_rv_size(v, size, expand=True) for v in vars_to_sample] - sampler_fn = compile_rv_inplace( + sampler_fn = compile_pymc( inputs, vars_to_sample, allow_input_downcast=True, @@ -2009,7 +2009,7 @@ def sample_prior_predictive( inputs = [i for i in inputvars(vars_to_sample) if not isinstance(i, SharedVariable)] - sampler_fn = compile_rv_inplace( + sampler_fn = compile_pymc( inputs, vars_to_sample, allow_input_downcast=True, accept_inplace=True, mode=mode ) diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 051e9818fd..278552ac1e 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -19,14 +19,14 @@ import numpy as np import pandas as pd -from aesara.assert_op import Assert +from aeppl.logprob import CheckParameterValue from aesara.compile import SharedVariable from aesara.graph.basic import clone_replace, graph_inputs from aesara.graph.fg import FunctionGraph from aesara.link.jax.dispatch import jax_funcify +from aesara.raise_op import Assert from pymc import Model, modelcontext -from pymc.aesaraf import compile_rv_inplace from pymc.backends.arviz import find_observations from pymc.distributions import logpt from pymc.util import get_default_varnames @@ -35,6 +35,7 @@ @jax_funcify.register(Assert) +@jax_funcify.register(CheckParameterValue) def jax_funcify_Assert(op, **kwargs): # Jax does not allow assert whose values aren't known during JIT compilation # within it's JIT-ed code. 
Hence we need to make a simple pass through diff --git a/pymc/smc/smc.py b/pymc/smc/smc.py index 88d15ced80..91fab8eed0 100644 --- a/pymc/smc/smc.py +++ b/pymc/smc/smc.py @@ -24,7 +24,7 @@ from scipy.stats import multivariate_normal from pymc.aesaraf import ( - compile_rv_inplace, + compile_pymc, floatX, inputvars, join_nonshared_inputs, @@ -570,6 +570,6 @@ def _logp_forw(point, out_vars, in_vars, shared): in_vars = new_in_vars out_list, inarray0 = join_nonshared_inputs(point, out_vars, in_vars, shared) - f = compile_rv_inplace([inarray0], out_list[0]) + f = compile_pymc([inarray0], out_list[0]) f.trust_input = True return f diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index b5b64225bb..fb601c58a0 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -22,7 +22,7 @@ import pymc as pm -from pymc.aesaraf import compile_rv_inplace, floatX, rvs_to_value_vars +from pymc.aesaraf import compile_pymc, floatX, rvs_to_value_vars from pymc.blocking import DictToArrayBijection, RaveledVars from pymc.step_methods.arraystep import ( ArrayStep, @@ -1006,6 +1006,6 @@ def delta_logp(point, logp, vars, shared): logp1 = pm.CallableTensor(logp0)(inarray1) - f = compile_rv_inplace([inarray1, inarray0], logp1 - logp0) + f = compile_pymc([inarray1, inarray0], logp1 - logp0) f.trust_input = True return f diff --git a/pymc/tests/test_aesaraf.py b/pymc/tests/test_aesaraf.py index a41625ac66..9c7d0e690d 100644 --- a/pymc/tests/test_aesaraf.py +++ b/pymc/tests/test_aesaraf.py @@ -23,6 +23,7 @@ import pytest import scipy.sparse as sps +from aeppl.logprob import ParameterValueError from aesara.graph.basic import Constant, Variable, ancestors, equal_computations from aesara.tensor.random.basic import normal, uniform from aesara.tensor.random.op import RandomVariable @@ -35,12 +36,14 @@ from pymc.aesaraf import ( _conversion_map, change_rv_size, + compile_pymc, extract_obs_data, pandas_to_array, rvs_to_value_vars, take_along_axis, walk_model, ) +from pymc.distributions.dist_math import check_parameters from pymc.exceptions import ShapeError from pymc.vartypes import int_types @@ -550,3 +553,24 @@ def test_rvs_to_value_vars_nested(): after = aesara.clone_replace(m.free_RVs) assert equal_computations(before, after) + + +def test_check_bounds_flag(): + """Test that CheckParameterValue Ops are replaced or removed when using compile_pymc""" + logp = at.ones(3) + cond = np.array([1, 0, 1]) + bound = check_parameters(logp, cond) + + with pm.Model() as m: + pass + + with pytest.raises(ParameterValueError): + aesara.function([], bound)() + + m.check_bounds = False + with m: + assert np.all(compile_pymc([], bound)() == 1) + + m.check_bounds = True + with m: + assert np.all(compile_pymc([], bound)() == -np.inf) diff --git a/pymc/tests/test_dist_math.py b/pymc/tests/test_dist_math.py index 0adcc3a239..c2f4147fc1 100644 --- a/pymc/tests/test_dist_math.py +++ b/pymc/tests/test_dist_math.py @@ -18,6 +18,7 @@ import pytest import scipy.special +from aeppl.logprob import ParameterValueError from aesara import config, function from aesara.tensor.random.basic import multinomial from scipy import interpolate, stats @@ -29,8 +30,7 @@ from pymc.distributions.dist_math import ( MvNormalLogp, SplineWrapper, - alltrue_scalar, - bound, + check_parameters, clipped_beta_rvs, factln, i0e, @@ -41,57 +41,35 @@ from pymc.tests.helpers import verify_grad -def test_bound(): - logp = at.ones((10, 10)) - cond = at.ones((10, 10)) - assert np.all(bound(logp, cond).eval() == logp.eval()) - - 
logp = at.ones((10, 10)) - cond = at.zeros((10, 10)) - assert np.all(bound(logp, cond).eval() == (-np.inf * logp).eval()) - - logp = at.ones((10, 10)) - cond = True - assert np.all(bound(logp, cond).eval() == logp.eval()) - - logp = at.ones(3) - cond = np.array([1, 0, 1]) - assert not np.all(bound(logp, cond).eval() == 1) - assert np.prod(bound(logp, cond).eval()) == -np.inf - - logp = at.ones((2, 3)) - cond = np.array([[1, 1, 1], [1, 0, 1]]) - assert not np.all(bound(logp, cond).eval() == 1) - assert np.prod(bound(logp, cond).eval()) == -np.inf - - -def test_check_bounds_false(): - with pm.Model(check_bounds=False): - logp = at.ones(3) - cond = np.array([1, 0, 1]) - assert np.all(bound(logp, cond).eval() == logp.eval()) - - -def test_alltrue_scalar(): - assert alltrue_scalar([]).eval() - assert alltrue_scalar([True]).eval() - assert alltrue_scalar([at.ones(10)]).eval() - assert alltrue_scalar([at.ones(10), 5 * at.ones(101)]).eval() - assert alltrue_scalar([np.ones(10), 5 * at.ones(101)]).eval() - assert alltrue_scalar([np.ones(10), True, 5 * at.ones(101)]).eval() - assert alltrue_scalar([np.array([1, 2, 3]), True, 5 * at.ones(101)]).eval() - - assert not alltrue_scalar([False]).eval() - assert not alltrue_scalar([at.zeros(10)]).eval() - assert not alltrue_scalar([True, False]).eval() - assert not alltrue_scalar([np.array([0, -1]), at.ones(60)]).eval() - assert not alltrue_scalar([np.ones(10), False, 5 * at.ones(101)]).eval() - +@pytest.mark.parametrize( + "conditions, succeeds", + [ + ([], True), + ([True], True), + ([at.ones(10)], True), + ([at.ones(10), 5 * at.ones(101)], True), + ([np.ones(10), 5 * at.ones(101)], True), + ([np.ones(10), True, 5 * at.ones(101)], True), + ([np.array([1, 2, 3]), True, 5 * at.ones(101)], True), + ([False], False), + ([at.zeros(10)], False), + ([True, False], False), + ([np.array([0, -1]), at.ones(60)], False), + ([np.ones(10), False, 5 * at.ones(101)], False), + ], +) +def test_check_parameters(conditions, succeeds): + ret = check_parameters(1, *conditions, msg="parameter check msg") + if succeeds: + assert ret.eval() + else: + with pytest.raises(ParameterValueError, match="^parameter check msg$"): + ret.eval() -def test_alltrue_shape(): - vals = [True, at.ones(10), at.zeros(5)] - assert alltrue_scalar(vals).eval().shape == () +def test_check_parameters_shape(): + conditions = [True, at.ones(10), at.ones(5)] + assert check_parameters(1, *conditions).eval().shape == () class MultinomialA(Discrete): @@ -102,13 +80,12 @@ def dist(cls, n, p, *args, **kwargs): return super().dist([n, p], **kwargs) def logp(value, n, p): - return bound( + return check_parameters( factln(n) - factln(value).sum() + (value * at.log(p)).sum(), value >= 0, 0 <= p, p <= 1, at.isclose(p.sum(), 1), - broadcast_conditions=False, ) @@ -120,17 +97,16 @@ def dist(cls, n, p, *args, **kwargs): return super().dist([n, p], **kwargs) def logp(value, n, p): - return bound( + return check_parameters( factln(n) - factln(value).sum() + (value * at.log(p)).sum(), at.all(value >= 0), at.all(0 <= p), at.all(p <= 1), at.isclose(p.sum(), 1), - broadcast_conditions=False, ) -def test_multinomial_bound(): +def test_multinomial_check_parameters(): x = np.array([1, 5]) n = x.sum() diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py index a4c1ee472d..ef7d7d9577 100644 --- a/pymc/tests/test_distributions.py +++ b/pymc/tests/test_distributions.py @@ -20,6 +20,9 @@ import numpy as np import numpy.random as nr +from aeppl.logprob import ParameterValueError + +from 
pymc.distributions.continuous import get_tau_sigma from pymc.util import UNSET try: @@ -118,7 +121,6 @@ def polyagamma_cdf(*args, **kwargs): ZeroInflatedBinomial, ZeroInflatedNegativeBinomial, ZeroInflatedPoisson, - continuous, logcdf, logp, logpt, @@ -150,15 +152,20 @@ def get_lkj_cases(): class Domain: - def __init__(self, vals, dtype=None, edges=None, shape=None): - avals = array(vals, dtype=dtype) - if dtype is None and not str(avals.dtype).startswith("int"): - avals = avals.astype(aesara.config.floatX) - vals = [array(v, dtype=avals.dtype) for v in vals] + def __init__(self, vals, dtype=aesara.config.floatX, edges=None, shape=None): + # Infinity values must be kept as floats + vals = [array(v, dtype=dtype) if np.all(np.isfinite(v)) else floatX(v) for v in vals] if edges is None: edges = array(vals[0]), array(vals[-1]) vals = vals[1:-1] + else: + edges = list(edges) + if edges[0] is None: + edges[0] = np.full_like(vals[0], -np.inf) + if edges[1] is None: + edges[1] = np.full_like(vals[0], np.inf) + edges = tuple(edges) if not vals: raise ValueError( @@ -168,13 +175,12 @@ def __init__(self, vals, dtype=None, edges=None, shape=None): ) if shape is None: - shape = avals[0].shape + shape = vals[0].shape self.vals = vals self.shape = shape - self.lower, self.upper = edges - self.dtype = avals.dtype + self.dtype = dtype def __add__(self, other): return Domain( @@ -249,17 +255,17 @@ def product(domains, n_samples=-1): Circ = Domain([-np.pi, -2.1, -1, -0.01, 0.0, 0.01, 1, 2.1, np.pi]) -Runif = Domain([-1, -0.4, 0, 0.4, 1]) -Rdunif = Domain([-10, 0, 10.0]) +Runif = Domain([-np.inf, -0.4, 0, 0.4, np.inf]) +Rdunif = Domain([-np.inf, -1, 0, 1, np.inf], "int64") Rplusunif = Domain([0, 0.5, inf]) -Rplusdunif = Domain([2, 10, 100], "int64") +Rplusdunif = Domain([0, 10, np.inf], "int64") -I = Domain([-1000, -3, -2, -1, 0, 1, 2, 3, 1000], "int64") +I = Domain([-np.inf, -3, -2, -1, 0, 1, 2, 3, np.inf], "int64") -NatSmall = Domain([0, 3, 4, 5, 1000], "int64") -Nat = Domain([0, 1, 2, 3, 2000], "int64") -NatBig = Domain([0, 1, 2, 3, 5000, 50000], "int64") -PosNat = Domain([1, 2, 3, 2000], "int64") +NatSmall = Domain([0, 3, 4, 5, np.inf], "int64") +Nat = Domain([0, 1, 2, 3, np.inf], "int64") +NatBig = Domain([0, 1, 2, 3, 5000, np.inf], "int64") +PosNat = Domain([1, 2, 3, np.inf], "int64") Bool = Domain([0, 0, 1, 1], "int64") @@ -521,20 +527,16 @@ def orderedprobit_logpdf(value, eta, cutpoints): return np.where(np.all(ps >= 0), np.log(p), -np.inf) -class Simplex: - def __init__(self, n): - self.vals = list(simplex_values(n)) - self.shape = (n,) - self.dtype = Unit.dtype +def Simplex(n): + return Domain(simplex_values(n), shape=(n,), dtype=Unit.dtype, edges=(None, None)) + +def MultiSimplex(n_dependent, n_independent): + vals = [] + for simplex_value in itertools.product(simplex_values(n_dependent), repeat=n_independent): + vals.append(np.vstack(simplex_value)) -class MultiSimplex: - def __init__(self, n_dependent, n_independent): - self.vals = [] - for simplex_value in itertools.product(simplex_values(n_dependent), repeat=n_independent): - self.vals.append(np.vstack(simplex_value)) - self.shape = (n_independent, n_dependent) - self.dtype = Unit.dtype + return Domain(vals, dtype=Unit.dtype, shape=(n_independent, n_dependent)) def PdMatrix(n): @@ -632,6 +634,7 @@ def check_logp( n_samples=100, extra_args=None, scipy_args=None, + skip_paramdomain_outside_edge_test=False, ): """ Generic test for PyMC logp methods @@ -677,14 +680,39 @@ def logp_reference(args): args.update(scipy_args) return scipy_logp(**args) + 
def _model_input_dict(model, param_vars, pt): + """Create a dict with only the necessary, transformed logp inputs.""" + pt_d = {} + for k, v in pt.items(): + rv_var = model.named_vars.get(k) + nv = param_vars.get(k, rv_var) + nv = getattr(nv.tag, "value_var", nv) + + transform = getattr(nv.tag, "transform", None) + if transform: + # todo: the compiled graph behind this should be cached and + # reused (if it isn't already). + v = transform.forward(rv_var, v).eval() + + if nv.name in param_vars: + # update the shared parameter variables in `param_vars` + param_vars[nv.name].set_value(v) + else: + # create an argument entry for the (potentially + # transformed) "value" variable + pt_d[nv.name] = v + + return pt_d + model, param_vars = build_model(pymc_dist, domain, paramdomains, extra_args) logp_pymc = model.fastlogp_nojac + # Test supported value and parameters domain matches scipy domains = paramdomains.copy() domains["value"] = domain for pt in product(domains, n_samples=n_samples): pt = dict(pt) - pt_d = self._model_input_dict(model, param_vars, pt) + pt_d = _model_input_dict(model, param_vars, pt) pt_logp = Point(pt_d, model=model) pt_ref = Point(pt, filter_model_vars=False, model=model) assert_almost_equal( @@ -694,29 +722,58 @@ def logp_reference(args): err_msg=str(pt), ) - def _model_input_dict(self, model, param_vars, pt): - """Create a dict with only the necessary, transformed logp inputs.""" - pt_d = {} - for k, v in pt.items(): - rv_var = model.named_vars.get(k) - nv = param_vars.get(k, rv_var) - nv = getattr(nv.tag, "value_var", nv) - - transform = getattr(nv.tag, "transform", None) - if transform: - # todo: the compiled graph behind this should be cached and - # reused (if it isn't already). - v = transform.forward(rv_var, v).eval() - - if nv.name in param_vars: - # update the shared parameter variables in `param_vars` - param_vars[nv.name].set_value(v) - else: - # create an argument entry for the (potentially - # transformed) "value" variable - pt_d[nv.name] = v + valid_value = domain.vals[0] + valid_params = {param: paramdomain.vals[0] for param, paramdomain in paramdomains.items()} + valid_dist = pymc_dist.dist(**valid_params, **extra_args) + + # Test pymc distribution raises ParameterValueError for scalar parameters outside + # the supported domain edges (excluding edges) + if not skip_paramdomain_outside_edge_test: + # Step1: collect potential invalid parameters + invalid_params = {param: [None, None] for param in paramdomains} + for param, paramdomain in paramdomains.items(): + if np.ndim(paramdomain.lower) != 0: + continue + if np.isfinite(paramdomain.lower): + invalid_params[param][0] = paramdomain.lower - 1 + if np.isfinite(paramdomain.upper): + invalid_params[param][1] = paramdomain.upper + 1 + + # Step2: test invalid parameters, one a time + for invalid_param, invalid_edges in invalid_params.items(): + for invalid_edge in invalid_edges: + if invalid_edge is None: + continue + test_params = valid_params.copy() # Shallow copy should be okay + test_params[invalid_param] = at.as_tensor_variable(invalid_edge) + # We need to remove `Assert`s introduced by checks like + # `assert_negative_support` and disable test values; + # otherwise, we won't be able to create the `RandomVariable` + with aesara.config.change_flags(compute_test_value="off"): + invalid_dist = pymc_dist.dist(**test_params, **extra_args) + with aesara.config.change_flags(mode=Mode("py")): + with pytest.raises(ParameterValueError): + logp(invalid_dist, valid_value).eval() + 
pytest.fail(f"test_params={test_params}, valid_value={valid_value}") + + # Test that values outside of scalar domain support evaluate to -np.inf + if np.ndim(domain.lower) != 0: + return + invalid_values = [None, None] + if np.isfinite(domain.lower): + invalid_values[0] = domain.lower - 1 + if np.isfinite(domain.upper): + invalid_values[1] = domain.upper + 1 - return pt_d + for invalid_value in invalid_values: + if invalid_value is None: + continue + with aesara.config.change_flags(mode=Mode("py")): + assert_equal( + logp(valid_dist, invalid_value).eval(), + -np.inf, + err_msg=str(invalid_value), + ) def check_logcdf( self, @@ -809,25 +866,22 @@ def check_logcdf( valid_params = {param: paramdomain.vals[0] for param, paramdomain in paramdomains.items()} valid_dist = pymc_dist.dist(**valid_params) - # Natural domains do not have inf as the upper edge, but should also be ignored - nat_domains = (NatSmall, Nat, NatBig, PosNat) - - # Test pymc distribution gives -inf for parameters outside the - # supported domain edges (excluding edgse) + # Test pymc distribution raises ParameterValueError for parameters outside the + # supported domain edges (excluding edges) if not skip_paramdomain_outside_edge_test: # Step1: collect potential invalid parameters invalid_params = {param: [None, None] for param in paramdomains} for param, paramdomain in paramdomains.items(): if np.isfinite(paramdomain.lower): invalid_params[param][0] = paramdomain.lower - 1 - if np.isfinite(paramdomain.upper) and paramdomain not in nat_domains: + if np.isfinite(paramdomain.upper): invalid_params[param][1] = paramdomain.upper + 1 # Step2: test invalid parameters, one a time for invalid_param, invalid_edges in invalid_params.items(): for invalid_edge in invalid_edges: if invalid_edge is not None: test_params = valid_params.copy() # Shallow copy should be okay - test_params[invalid_param] = invalid_edge + test_params[invalid_param] = at.as_tensor_variable(invalid_edge) # We need to remove `Assert`s introduced by checks like # `assert_negative_support` and disable test values; # otherwise, we won't be able to create the @@ -835,11 +889,8 @@ def check_logcdf( with aesara.config.change_flags(compute_test_value="off"): invalid_dist = pymc_dist.dist(**test_params) with aesara.config.change_flags(mode=Mode("py")): - assert_equal( - logcdf(invalid_dist, valid_value).eval(), - -np.inf, - err_msg=str(test_params), - ) + with pytest.raises(ParameterValueError): + logcdf(invalid_dist, valid_value).eval() # Test that values below domain edge evaluate to -np.inf if np.isfinite(domain.lower): @@ -852,7 +903,7 @@ def check_logcdf( ) # Test that values above domain edge evaluate to 0 - if domain not in nat_domains and np.isfinite(domain.upper): + if np.isfinite(domain.upper): above_domain = domain.upper + 1 with aesara.config.change_flags(mode=Mode("py")): assert_equal( @@ -938,6 +989,7 @@ def test_uniform(self): Runif, {"lower": -Rplusunif, "upper": Rplusunif}, lambda value, lower, upper: sp.uniform.logpdf(value, lower, upper - lower), + skip_paramdomain_outside_edge_test=True, ) self.check_logcdf( Uniform, @@ -958,6 +1010,7 @@ def test_triangular(self): Runif, {"lower": -Rplusunif, "c": Runif, "upper": Rplusunif}, lambda value, c, lower, upper: sp.triang.logpdf(value, c - lower, lower, upper - lower), + skip_paramdomain_outside_edge_test=True, ) self.check_logcdf( Triangular, @@ -977,11 +1030,13 @@ def test_triangular(self): # Invalid logp checks for triangular are being done in aeppl invalid_dist = Triangular.dist(lower=1, upper=0, c=0.1) 
with aesara.config.change_flags(mode=Mode("py")): - assert logcdf(invalid_dist, 2).eval() == -np.inf + with pytest.raises(ParameterValueError): + logcdf(invalid_dist, 2).eval() invalid_dist = Triangular.dist(lower=0, upper=1, c=2.0) with aesara.config.change_flags(mode=Mode("py")): - assert logcdf(invalid_dist, 2).eval() == -np.inf + with pytest.raises(ParameterValueError): + logcdf(invalid_dist, 2).eval() @pytest.mark.skipif( condition=_polyagamma_not_installed, @@ -1009,6 +1064,7 @@ def test_discrete_unif(self): Rdunif, {"lower": -Rplusdunif, "upper": Rplusdunif}, lambda value, lower, upper: sp.randint.logpmf(value, lower, upper + 1), + skip_paramdomain_outside_edge_test=True, ) self.check_logcdf( DiscreteUniform, @@ -1019,17 +1075,19 @@ def test_discrete_unif(self): ) self.check_selfconsistency_discrete_logcdf( DiscreteUniform, - Rdunif, + Domain([-10, 0, 10], "int64"), {"lower": -Rplusdunif, "upper": Rplusdunif}, ) # Custom logp / logcdf check for invalid parameters invalid_dist = DiscreteUniform.dist(lower=1, upper=0) with aesara.config.change_flags(mode=Mode("py")): - assert logp(invalid_dist, 0.5).eval() == -np.inf - assert logcdf(invalid_dist, 2).eval() == -np.inf + with pytest.raises(ParameterValueError): + logp(invalid_dist, 0.5).eval() + with pytest.raises(ParameterValueError): + logcdf(invalid_dist, 2).eval() def test_flat(self): - self.check_logp(Flat, Runif, {}, lambda value: 0) + self.check_logp(Flat, R, {}, lambda value: 0) with Model(): x = Flat("a") self.check_logcdf(Flat, R, {}, lambda value: np.log(0.5)) @@ -1074,6 +1132,7 @@ def scipy_logp(value, mu, sigma, lower, upper): {"mu": R, "sigma": Rplusbig, "lower": -Rplusbig, "upper": Rplusbig}, scipy_logp, decimal=select_by_precision(float64=6, float32=1), + skip_paramdomain_outside_edge_test=True, ) self.check_logp( @@ -1082,6 +1141,7 @@ def scipy_logp(value, mu, sigma, lower, upper): {"mu": R, "sigma": Rplusbig, "upper": Rplusbig}, functools.partial(scipy_logp, lower=-np.inf), decimal=select_by_precision(float64=6, float32=1), + skip_paramdomain_outside_edge_test=True, ) self.check_logp( @@ -1090,6 +1150,7 @@ def scipy_logp(value, mu, sigma, lower, upper): {"mu": R, "sigma": Rplusbig, "lower": -Rplusbig}, functools.partial(scipy_logp, upper=np.inf), decimal=select_by_precision(float64=6, float32=1), + skip_paramdomain_outside_edge_test=True, ) def test_half_normal(self): @@ -1173,9 +1234,6 @@ def test_wald_logp_custom_points(self, value, mu, lam, phi, alpha, logp): decimals = select_by_precision(float64=6, float32=1) assert_almost_equal(model.fastlogp(pt), logp, decimal=decimals, err_msg=str(pt)) - @pytest.mark.xfail( - reason="Fails because mu and sigma values are being picked randomly from domains" - ) def test_beta_logp(self): self.check_logp( Beta, @@ -1355,6 +1413,7 @@ def test_negative_binomial_init_fail(self, mu, p, alpha, n, expected): with pytest.raises(ValueError, match=f"Incompatible parametrization. 
{expected}"): NegativeBinomial("x", mu=mu, p=p, alpha=alpha, n=n) + @pytest.mark.xfail(reason="Aeppl Laplace does not have a CheckParameterValue for b") def test_laplace(self): self.check_logp( Laplace, @@ -1682,13 +1741,13 @@ def test_discrete_weibull(self): self.check_logp( DiscreteWeibull, Nat, - {"q": Unit, "beta": Rplusdunif}, + {"q": Unit, "beta": NatSmall}, discrete_weibull_logpmf, ) self.check_selfconsistency_discrete_logcdf( DiscreteWeibull, Nat, - {"q": Unit, "beta": Rplusdunif}, + {"q": Unit, "beta": NatSmall}, ) def test_poisson(self): @@ -1918,14 +1977,16 @@ def test_mvnormal_indef(self): x.tag.test_value = np.zeros(2) mvn_logp = logp(MvNormal.dist(mu=mu, cov=cov), x) f_logp = aesara.function([cov, x], mvn_logp) - assert f_logp(cov_val, np.ones(2)) == -np.inf + with pytest.raises(ParameterValueError): + f_logp(cov_val, np.ones(2)) dlogp = at.grad(mvn_logp, cov) f_dlogp = aesara.function([cov, x], dlogp) assert not np.all(np.isfinite(f_dlogp(cov_val, np.ones(2)))) mvn_logp = logp(MvNormal.dist(mu=mu, tau=cov), x) f_logp = aesara.function([cov, x], mvn_logp) - assert f_logp(cov_val, np.ones(2)) == -np.inf + with pytest.raises(ParameterValueError): + f_logp(cov_val, np.ones(2)) dlogp = at.grad(mvn_logp, cov) f_dlogp = aesara.function([cov, x], dlogp) assert not np.all(np.isfinite(f_dlogp(cov_val, np.ones(2)))) @@ -2089,7 +2150,7 @@ def test_wishart(self, n): self.check_logp( Wishart, PdMatrix(n), - {"nu": Domain([3, 4, 2000]), "V": PdMatrix(n)}, + {"nu": Domain([0, 3, 4, np.inf], "int64"), "V": PdMatrix(n)}, lambda value, nu, V: scipy.stats.wishart.logpdf(value, np.int(nu), V), ) @@ -2112,6 +2173,18 @@ def test_dirichlet(self, n): dirichlet_logpdf, ) + def test_dirichlet_invalid(self): + # Test non-scalar invalid parameters/values + value = np.array([[0.1, 0.2, 0.7], [0.3, 0.3, 0.4]]) + + invalid_dist = Dirichlet.dist(a=[-1, 1, 2], size=2) + with pytest.raises(ParameterValueError): + pm.logp(invalid_dist, value).eval() + + value[1] -= 1 + valid_dist = Dirichlet.dist(a=[1, 1, 1]) + assert np.all(np.isfinite(pm.logp(valid_dist, value).eval()) == np.array([True, False])) + @pytest.mark.parametrize( "a", [ @@ -2143,6 +2216,20 @@ def test_multinomial(self, n): lambda value, n, p: scipy.stats.multinomial.logpmf(value, n, p), ) + def test_multinomial_invalid(self): + # Test non-scalar invalid parameters/values + value = np.array([[1, 2, 2], [4, 0, 1]]) + + invalid_dist = Multinomial.dist(n=5, p=[-1, 1, 1], size=2) + # TODO: Multinomial normalizes p, so it is impossible to trigger p checks + # with pytest.raises(ParameterValueError): + with does_not_raise(): + pm.logp(invalid_dist, value).eval() + + value[1] -= 1 + valid_dist = Multinomial.dist(n=5, p=np.ones(3) / 3) + assert np.all(np.isfinite(pm.logp(valid_dist, value).eval()) == np.array([True, False])) + @pytest.mark.parametrize("n", [(10), ([10, 11]), ([[5, 6], [10, 11]])]) @pytest.mark.parametrize( "p", @@ -2183,6 +2270,18 @@ def test_dirichlet_multinomial(self, n): dirichlet_multinomial_logpmf, ) + def test_dirichlet_multinomial_invalid(self): + # Test non-scalar invalid parameters/values + value = np.array([[1, 2, 2], [4, 0, 1]]) + + invalid_dist = DirichletMultinomial.dist(n=5, a=[-1, 1, 1], size=2) + with pytest.raises(ParameterValueError): + pm.logp(invalid_dist, value).eval() + + value[1] -= 1 + valid_dist = DirichletMultinomial.dist(n=5, a=[1, 1, 1]) + assert np.all(np.isfinite(pm.logp(valid_dist, value).eval()) == np.array([True, False])) + def test_dirichlet_multinomial_matches_beta_binomial(self): a, b, n = 2, 1, 5 ns = 
np.arange(n + 1) @@ -2232,34 +2331,31 @@ def test_categorical_bounds(self): assert np.isinf(logp(x, 3).eval()) @aesara.config.change_flags(compute_test_value="raise") - def test_categorical_valid_p(self): - with Model(): - x = Categorical("x", p=np.array([-0.2, 0.3, 0.5])) - assert np.isinf(logp(x, 0).eval()) - assert np.isinf(logp(x, 1).eval()) - assert np.isinf(logp(x, 2).eval()) - with Model(): + @pytest.mark.parametrize( + "p", + [ + np.array([-0.2, 0.3, 0.5]), # A model where p sums to 1 but contains negative values - x = Categorical("x", p=np.array([-0.2, 0.7, 0.5])) - assert np.isinf(logp(x, 0).eval()) - assert np.isinf(logp(x, 1).eval()) - assert np.isinf(logp(x, 2).eval()) - with Model(): + np.array([-0.2, 0.7, 0.5]), # Hard edge case from #2082 # Early automatic normalization of p's sum would hide the negative # entries if there is a single or pair number of negative values # and the rest are zero - x = Categorical("x", p=np.array([-1, -1, 0, 0])) - assert np.isinf(logp(x, 0).eval()) - assert np.isinf(logp(x, 1).eval()) - assert np.isinf(logp(x, 2).eval()) - assert np.isinf(logp(x, 3).eval()) + np.array([-1, -1, 0, 0]), + ], + ) + def test_categorical_valid_p(self, p): + with Model(): + x = Categorical("x", p=p) + + with pytest.raises(ParameterValueError): + logp(x, 2).eval() @pytest.mark.parametrize("n", [2, 3, 4]) def test_categorical(self, n): self.check_logp( Categorical, - Domain(range(n), dtype="int64", edges=(None, None)), + Domain(range(n), dtype="int64", edges=(0, n)), {"p": Simplex(n)}, lambda value, p: categorical_logpdf(value, p), ) @@ -2290,8 +2386,19 @@ def logp(x): self.checkd(DensityDist, R, {}, extra_args={"logp": logp}) def test_get_tau_sigma(self): - sigma = np.array([2]) - assert_almost_equal(continuous.get_tau_sigma(sigma=sigma), [1.0 / sigma ** 2, sigma]) + sigma = np.array(2) + assert_almost_equal(get_tau_sigma(sigma=sigma), [1.0 / sigma ** 2, sigma]) + + tau = np.array(2) + assert_almost_equal(get_tau_sigma(tau=tau), [tau, tau ** -0.5]) + + tau, _ = get_tau_sigma(sigma=at.constant(-2)) + with pytest.raises(ParameterValueError): + tau.eval() + + _, sigma = get_tau_sigma(tau=at.constant(-2)) + with pytest.raises(ParameterValueError): + sigma.eval() @pytest.mark.parametrize( "value,mu,sigma,nu,logp", @@ -2361,8 +2468,8 @@ def test_ex_gaussian_cdf_outside_edges(self): def test_vonmises(self): self.check_logp( VonMises, - R, - {"mu": Circ, "kappa": Rplus}, + Circ, + {"mu": R, "kappa": Rplus}, lambda value, mu, kappa: floatX(sp.vonmises.logpdf(value, kappa, loc=mu)), ) diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index b5d1b0cfd9..e505d2d107 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -163,8 +163,8 @@ def test_normal_moment(mu, sigma, size, expected): [ (1, None, 1), (1, 5, np.ones(5)), - (np.arange(5), None, np.arange(5)), - (np.arange(5), (2, 5), np.full((2, 5), np.arange(5))), + (np.arange(1, 6), None, np.arange(1, 6)), + (np.arange(1, 6), (2, 5), np.full((2, 5), np.arange(1, 6))), ], ) def test_halfnormal_moment(sigma, size, expected): @@ -178,7 +178,7 @@ def test_halfnormal_moment(sigma, size, expected): [ (1, 1, None, 1), (1, 1, 5, np.ones(5)), - (1, np.arange(5), (2, 5), np.full((2, 5), np.arange(5))), + (1, np.arange(1, 6), (2, 5), np.full((2, 5), np.arange(1, 6))), (np.arange(1, 6), 1, None, np.full(5, 1)), ], ) diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py index 6978128872..6f3a37a719 100644 --- 
a/pymc/tests/test_sampling.py +++ b/pymc/tests/test_sampling.py @@ -31,7 +31,7 @@ import pymc as pm -from pymc.aesaraf import compile_rv_inplace +from pymc.aesaraf import compile_pymc from pymc.backends.base import MultiTrace from pymc.backends.ndarray import NDArray from pymc.exceptions import IncorrectArgumentsError, SamplingError @@ -1016,7 +1016,7 @@ def test_layers(self): a = pm.Uniform("a", lower=0, upper=1, size=10) b = pm.Binomial("b", n=1, p=a, size=10) - b_sampler = compile_rv_inplace([], b, mode="FAST_RUN") + b_sampler = compile_pymc([], b, mode="FAST_RUN") avg = np.stack([b_sampler() for i in range(10000)]).mean(0) npt.assert_array_almost_equal(avg, 0.5 * np.ones((10,)), decimal=2) diff --git a/requirements-dev.txt b/requirements-dev.txt index c5fbff991b..19d903eb68 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,7 +1,7 @@ # This file is auto-generated by scripts/generate_pip_deps_from_conda.py, do not modify. # See that file for comments about the need/usage of each dependency. -aeppl==0.0.17 +aeppl==0.0.18 aesara>=2.2.6 arviz>=0.11.4 cachetools>=4.2.1 diff --git a/requirements.txt b/requirements.txt index 75da478008..5dbc6e1f57 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -aeppl==0.0.17 +aeppl==0.0.18 aesara>=2.2.6 arviz>=0.11.4 cachetools>=4.2.1
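
The new `test_check_parameters` above pins down the replacement for the old `bound` helper: when every symbolic condition holds, `check_parameters` passes the log-probability expression through unchanged; when any condition fails, evaluating the graph raises `aeppl.logprob.ParameterValueError` carrying the supplied message, instead of silently returning `-inf`. A minimal sketch of that contract follows; it is not part of the patch, and the variable names and message are illustrative only.

```python
import numpy as np
import aesara.tensor as at

from aeppl.logprob import ParameterValueError
from pymc.distributions.dist_math import check_parameters

logp = at.ones(3)

# All conditions hold: the wrapped expression evaluates unchanged.
ok = check_parameters(logp, np.array([1, 1, 1]), msg="param2 >= 0")
assert np.all(ok.eval() == 1)

# A failing condition raises ParameterValueError with the given message,
# rather than collapsing the expression to -inf as `bound` used to do.
bad = check_parameters(logp, np.array([1, 0, 1]), msg="param2 >= 0")
try:
    bad.eval()
except ParameterValueError as err:
    print(err)  # the error message mentions "param2 >= 0"
```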
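
`test_check_bounds_flag` ties the new exception to sampling: `compile_pymc` rewrites Aeppl's `CheckParameterValue` Ops according to the model's `check_bounds` setting, either dropping them or converting them into `-inf` switches. The sketch below condenses what that test asserts; passing `check_bounds` to the `Model` constructor (as the removed `test_check_bounds_false` did) is assumed to be equivalent to setting the `m.check_bounds` attribute as the new test does.

```python
import numpy as np
import aesara.tensor as at

import pymc as pm
from pymc.aesaraf import compile_pymc
from pymc.distributions.dist_math import check_parameters

logp = at.ones(3)
guarded = check_parameters(logp, np.array([1, 0, 1]))  # middle condition fails

with pm.Model(check_bounds=False):
    # compile_pymc strips CheckParameterValue Ops entirely, so the failing
    # condition is ignored and the raw logp comes back.
    assert np.all(compile_pymc([], guarded)() == 1)

with pm.Model(check_bounds=True):
    # With bounds checking on, the check is rewritten into a -inf switch
    # instead of an exception, which is the behaviour step methods rely on.
    assert np.all(compile_pymc([], guarded)() == -np.inf)
```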
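
At the distribution level the split introduced by this change is: invalid *parameters* raise `ParameterValueError`, while invalid *values* still evaluate to a non-finite log-probability, as the new `test_dirichlet_invalid`, `test_multinomial_invalid`, and `test_dirichlet_multinomial_invalid` tests check. A sketch mirroring the Dirichlet case from the diff (not part of the patch):

```python
import numpy as np
import pymc as pm

from aeppl.logprob import ParameterValueError

value = np.array([[0.1, 0.2, 0.7], [0.3, 0.3, 0.4]])

# Invalid parameters (a negative concentration) now raise on evaluation.
invalid_dist = pm.Dirichlet.dist(a=[-1, 1, 2], size=2)
try:
    pm.logp(invalid_dist, value).eval()
except ParameterValueError:
    print("negative concentration rejected")

# Invalid *values* are still handled element-wise: only the offending row
# becomes non-finite, the rest of the log-probability stays usable.
bad_value = value.copy()
bad_value[1] -= 1  # second row leaves the simplex
valid_dist = pm.Dirichlet.dist(a=[1, 1, 1])
finite = np.isfinite(pm.logp(valid_dist, bad_value).eval())
assert np.all(finite == np.array([True, False]))
```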
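
The reworked `test_get_tau_sigma` shows the same pattern for `get_tau_sigma`: the tau/sigma conversion itself is unchanged, but a negative scale parameter now produces a graph that raises `ParameterValueError` when evaluated. A short sketch distilled from that test (illustrative only):

```python
import numpy as np
import aesara.tensor as at

from aeppl.logprob import ParameterValueError
from pymc.distributions.continuous import get_tau_sigma

# The usual round trip: sigma in, tau = 1 / sigma**2 out (and vice versa).
sigma = np.array(2)
tau, sigma_out = get_tau_sigma(sigma=sigma)
np.testing.assert_almost_equal([tau, sigma_out], [1.0 / sigma ** 2, sigma])

# Negative scale parameters are no longer accepted silently: the returned
# tensors carry a parameter check that fires on evaluation.
bad_tau, _ = get_tau_sigma(sigma=at.constant(-2))
try:
    bad_tau.eval()
except ParameterValueError:
    print("negative sigma rejected")
```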