diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 99568861bf..265d535b48 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2192,32 +2192,23 @@ def make_node(self, rng, size, dtype, alpha, K): alpha = at.as_tensor_variable(alpha) K = at.as_tensor_variable(intX(K)) - if alpha.ndim > 0: - raise ValueError("The concentration parameter needs to be a scalar.") - if K.ndim > 0: raise ValueError("K must be a scalar.") return super().make_node(rng, size, dtype, alpha, K) - def _infer_shape(self, size, dist_params, param_shapes=None): - alpha, K = dist_params - - size = tuple(size) - - return size + (K + 1,) + def _supp_shape_from_params(self, dist_params, **kwargs): + K = dist_params[1] + return (K + 1,) @classmethod def rng_fn(cls, rng, alpha, K, size): if K < 0: raise ValueError("K needs to be positive.") - if size is None: - size = (K,) - elif isinstance(size, int): - size = (size,) + (K,) - else: - size = tuple(size) + (K,) + size = to_tuple(size) if size is not None else alpha.shape + size = size + (K,) + alpha = alpha[..., np.newaxis] betas = rng.beta(1, alpha, size=size) @@ -2286,9 +2277,10 @@ def dist(cls, alpha, K, *args, **kwargs): return super().dist([alpha, K], **kwargs) def moment(rv, size, alpha, K): + alpha = alpha[..., np.newaxis] moment = (alpha / (1 + alpha)) ** at.arange(K) moment *= 1 / (1 + alpha) - moment = at.concatenate([moment, [(alpha / (1 + alpha)) ** K]], axis=-1) + moment = at.concatenate([moment, (alpha / (1 + alpha)) ** K], axis=-1) if not rv_size_is_none(size): moment_size = at.concatenate( [ diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py index 28705ba1b8..d0c1135bdd 100644 --- a/pymc/tests/test_distributions.py +++ b/pymc/tests/test_distributions.py @@ -23,6 +23,7 @@ from aeppl.logprob import ParameterValueError from aesara.tensor.random.utils import broadcast_params +from pymc.aesaraf import compile_pymc from pymc.distributions.continuous import get_tau_sigma from pymc.util import UNSET @@ -952,6 +953,17 @@ def test_hierarchical_obs_logp(): assert not any(isinstance(o, RandomVariable) for o in ops) +@pytest.fixture(scope="module") +def stickbreakingweights_logpdf(): + _value = at.vector() + _alpha = at.scalar() + _k = at.iscalar() + _logp = logp(StickBreakingWeights.dist(_alpha, _k), _value) + core_fn = compile_pymc([_value, _alpha, _k], _logp) + + return np.vectorize(core_fn, signature="(n),(),()->()") + + class TestMatchesScipy: def test_uniform(self): check_logp( @@ -2312,6 +2324,25 @@ def test_stickbreakingweights_invalid(self): assert pm.logp(sbw, np.array([0.4, 0.3, 0.2, -0.1])).eval() == -np.inf assert pm.logp(sbw_wrong_K, np.array([0.4, 0.3, 0.2, 0.1])).eval() == -np.inf + @pytest.mark.parametrize( + "alpha,K", + [ + (np.array([0.5, 1.0, 2.0]), 3), + (np.arange(1, 7, dtype="float64").reshape(2, 3), 5), + ], + ) + def test_stickbreakingweights_vectorized(self, alpha, K, stickbreakingweights_logpdf): + value = pm.StickBreakingWeights.dist(alpha, K).eval() + with Model(): + sbw = StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None) + pt = {"sbw": value} + assert_almost_equal( + pm.logp(sbw, value).eval(), + stickbreakingweights_logpdf(value, alpha, K), + decimal=select_by_precision(float64=6, float32=2), + err_msg=str(pt), + ) + @aesara.config.change_flags(compute_test_value="raise") def test_categorical_bounds(self): with Model(): diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index 75c29849bc..fac192315c 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -1166,6 +1166,32 @@ def test_rice_moment(nu, sigma, size, expected): fill_value=np.append((1 / 3) ** np.arange(5) * 2 / 3, (1 / 3) ** 5), ), ), + ( + np.array([1, 3]), + 11, + None, + np.array( + [ + np.append((1 / 2) ** np.arange(11) * 1 / 2, (1 / 2) ** 11), + np.append((3 / 4) ** np.arange(11) * 1 / 4, (3 / 4) ** 11), + ] + ), + ), + ( + np.array([1, 3, 5]), + 9, + (5, 3), + np.full( + shape=(5, 3, 10), + fill_value=np.array( + [ + np.append((1 / 2) ** np.arange(9) * 1 / 2, (1 / 2) ** 9), + np.append((3 / 4) ** np.arange(9) * 1 / 4, (3 / 4) ** 9), + np.append((5 / 6) ** np.arange(9) * 1 / 6, (5 / 6) ** 9), + ] + ), + ), + ), ], ) def test_stickbreakingweights_moment(alpha, K, size, expected): diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py index 5818c5adf5..de7e6cd088 100644 --- a/pymc/tests/test_distributions_random.py +++ b/pymc/tests/test_distributions_random.py @@ -1321,6 +1321,18 @@ def check_basic_properties(self): assert np.all(draws <= 1) +class TestStickBreakingWeights_1D_alpha(BaseTestDistributionRandom): + pymc_dist = pm.StickBreakingWeights + pymc_dist_params = {"alpha": [1.0, 2.0, 3.0], "K": 19} + expected_rv_op_params = {"alpha": [1.0, 2.0, 3.0], "K": 19} + sizes_to_check = [None, (3,), (5, 3)] + sizes_expected = [(3, 20), (3, 20), (5, 3, 20)] + checks_to_run = [ + "check_pymc_params_match_rv_op", + "check_rv_size", + ] + + class TestCategorical(BaseTestDistributionRandom): pymc_dist = pm.Categorical pymc_dist_params = {"p": np.array([0.28, 0.62, 0.10])}