Skip to content

Commit a0cff37

Browse files
markvrmaricardoV94
authored andcommitted
Generalize Multinomial moment to arbitrary dimensions
1 parent 5a44793 commit a0cff37

File tree

2 files changed

+20
-31
lines changed

2 files changed

+20
-31
lines changed

Diff for: pymc/distributions/multivariate.py

+4-21
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
logpow,
5757
multigammaln,
5858
)
59-
from pymc.distributions.distribution import Continuous, Discrete
59+
from pymc.distributions.distribution import Continuous, Discrete, get_moment
6060
from pymc.distributions.shape_utils import (
6161
broadcast_dist_samples_to,
6262
rv_size_is_none,
@@ -558,11 +558,7 @@ def dist(cls, n, p, *args, **kwargs):
558558
return super().dist([n, p], *args, **kwargs)
559559

560560
def get_moment(rv, size, n, p):
561-
if p.ndim > 1:
562-
n = at.shape_padright(n)
563-
if (p.ndim == 1) & (n.ndim > 0):
564-
n = at.shape_padright(n)
565-
p = at.shape_padleft(p)
561+
n = at.shape_padright(n)
566562
mode = at.round(n * p)
567563
diff = n - at.sum(mode, axis=-1, keepdims=True)
568564
inc_bool_arr = at.abs_(diff) > 0
@@ -682,21 +678,8 @@ def dist(cls, n, a, *args, **kwargs):
682678
return super().dist([n, a], **kwargs)
683679

684680
def get_moment(rv, size, n, a):
685-
p = a / at.sum(a, axis=-1)
686-
mode = at.round(n * p)
687-
diff = n - at.sum(mode, axis=-1, keepdims=True)
688-
inc_bool_arr = at.abs_(diff) > 0
689-
mode = at.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()])
690-
691-
# Reshape mode according to dimensions implied by the parameters
692-
# This can include axes of length 1
693-
_, p_bcast = broadcast_params([n, p], ndims_params=[0, 1])
694-
mode = at.reshape(mode, p_bcast.shape)
695-
696-
if not rv_size_is_none(size):
697-
output_size = at.concatenate([size, [p.shape[-1]]])
698-
mode = at.full(output_size, mode)
699-
return mode
681+
p = a / at.sum(a, axis=-1, keepdims=True)
682+
return get_moment(Multinomial.dist(n=n, p=p, size=size))
700683

701684
def logp(value, n, a):
702685
"""

Diff for: pymc/tests/test_distributions_moments.py

+16-10
Original file line numberDiff line numberDiff line change
@@ -1308,22 +1308,22 @@ def test_polyagamma_moment(h, z, size, expected):
13081308
np.array([[4, 6, 0, 0], [4, 2, 2, 2]]),
13091309
),
13101310
(
1311-
np.array([[0.25, 0.25, 0.25, 0.25], [0.26, 0.26, 0.26, 0.22]]),
1312-
np.array([1, 10]),
1313-
None,
1314-
np.array([[1, 0, 0, 0], [2, 3, 3, 2]]),
1311+
np.array([0.3, 0.6, 0.05, 0.05]),
1312+
np.array([2, 10]),
1313+
(1, 2),
1314+
np.array([[[1, 1, 0, 0], [4, 6, 0, 0]]]),
13151315
),
13161316
(
1317-
np.array([0.26, 0.26, 0.26, 0.22]),
1317+
np.array([[0.25, 0.25, 0.25, 0.25], [0.26, 0.26, 0.26, 0.22]]),
13181318
np.array([1, 10]),
13191319
None,
13201320
np.array([[1, 0, 0, 0], [2, 3, 3, 2]]),
13211321
),
13221322
(
13231323
np.array([[0.25, 0.25, 0.25, 0.25], [0.26, 0.26, 0.26, 0.22]]),
13241324
np.array([1, 10]),
1325-
(2, 2),
1326-
np.full((2, 2, 4), [[1, 0, 0, 0], [2, 3, 3, 2]]),
1325+
(3, 2),
1326+
np.full((3, 2, 4), [[1, 0, 0, 0], [2, 3, 3, 2]]),
13271327
),
13281328
],
13291329
)
@@ -1470,10 +1470,16 @@ def test_lkjcholeskycov_moment(n, eta, size, expected):
14701470
(np.array([3, 6, 0.5, 0.5]), 2, None, np.array([1, 1, 0, 0])),
14711471
(np.array([30, 60, 5, 5]), 10, None, np.array([4, 6, 0, 0])),
14721472
(
1473-
np.array([[26, 26, 26, 22]]), # Dim: 1 x 4
1474-
np.array([[1], [10]]), # Dim: 2 x 1
1473+
np.array([[30, 60, 5, 5], [26, 26, 26, 22]]),
1474+
10,
1475+
(1, 2),
1476+
np.array([[[4, 6, 0, 0], [2, 3, 3, 2]]]),
1477+
),
1478+
(
1479+
np.array([26, 26, 26, 22]),
1480+
np.array([1, 10]),
14751481
None,
1476-
np.array([[[1, 0, 0, 0]], [[2, 3, 3, 2]]]), # Dim: 2 x 1 x 4
1482+
np.array([[1, 0, 0, 0], [2, 3, 3, 2]]),
14771483
),
14781484
(
14791485
np.array([[26, 26, 26, 22]]), # Dim: 1 x 4

0 commit comments

Comments
 (0)