
Commit f801310

AlexAndorra authored and ricardoV94 committed
Refactor DirichletMultinomial for V4
1 parent e5e83d0 commit f801310

File tree

3 files changed: +123 −191 lines changed


pymc3/distributions/multivariate.py

Lines changed: 53 additions & 92 deletions
@@ -393,7 +393,6 @@ class Dirichlet(Continuous):
     a: array
         Concentration parameters (a > 0).
     """
-
     rv_op = dirichlet

     def __new__(cls, name, *args, **kwargs):
@@ -501,11 +500,6 @@ def dist(cls, n, p, *args, **kwargs):
         n = at.as_tensor_variable(n)
         p = at.as_tensor_variable(p)

-        # mean = n * p
-        # mode = at.cast(at.round(mean), "int32")
-        # diff = n - at.sum(mode, axis=-1, keepdims=True)
-        # inc_bool_arr = at.abs_(diff) > 0
-        # mode = at.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()])
         return super().dist([n, p], *args, **kwargs)

     def logp(value, n, p):
@@ -515,7 +509,7 @@ def logp(value, n, p):

         Parameters
         ----------
-        x: numeric
+        value: numeric
             Value for which log-probability is calculated.

         Returns
@@ -533,6 +527,46 @@ def logp(value, n, p):
         )


+class DirichletMultinomialRV(RandomVariable):
+    name = "dirichlet_multinomial"
+    ndim_supp = 1
+    ndims_params = [0, 1]
+    dtype = "int64"
+    _print_name = ("DirichletMN", "\\operatorname{DirichletMN}")
+
+    def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
+        return default_shape_from_params(self.ndim_supp, dist_params, rep_param_idx, param_shapes)
+
+    @classmethod
+    def rng_fn(cls, rng, n, a, size):
+
+        if n.ndim > 0 or a.ndim > 1:
+            n, a = broadcast_params([n, a], cls.ndims_params)
+            size = tuple(size or ())
+
+            if size:
+                n = np.broadcast_to(n, size + n.shape)
+                a = np.broadcast_to(a, size + a.shape)
+
+            res = np.empty(a.shape)
+            for idx in np.ndindex(a.shape[:-1]):
+                p = rng.dirichlet(a[idx])
+                res[idx] = rng.multinomial(n[idx], p)
+            return res
+        else:
+            # n is a scalar, a is a 1d array
+            p = rng.dirichlet(a, size=size)  # (size, a.shape)
+
+            res = np.empty(p.shape)
+            for idx in np.ndindex(p.shape[:-1]):
+                res[idx] = rng.multinomial(n, p[idx])
+
+            return res
+
+
+dirichlet_multinomial = DirichletMultinomialRV()
+
+
 class DirichletMultinomial(Discrete):
     r"""Dirichlet Multinomial log-likelihood.

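Note: the batched branch of `rng_fn` above pairs each leading index with its own (n_i, a_i); here is a minimal plain-NumPy sketch of that loop, standalone and with assumed example shapes (not the RandomVariable class itself):

import numpy as np

# One (n_i, a_i) pair per leading index, mirroring rng_fn's batched branch.
rng = np.random.default_rng(0)
n = np.array([10, 11])                     # one total count per batch
a = np.array([[0.2, 0.3, 0.5],
              [0.9, 0.09, 0.01]])          # one concentration row per batch

res = np.empty(a.shape, dtype=np.int64)
for idx in np.ndindex(a.shape[:-1]):       # idx runs over (0,), (1,)
    p = rng.dirichlet(a[idx])              # p_i ~ Dirichlet(a_i)
    res[idx] = rng.multinomial(n[idx], p)  # x_i ~ Multinomial(n_i, p_i)
print(res.sum(axis=-1))                    # [10 11]; each row sums to its n_i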
@@ -566,92 +600,16 @@ class DirichletMultinomial(Discrete):
         Describes shape of distribution. For example if n=array([5, 10]), and
         a=array([1, 1, 1]), shape should be (2, 3).
     """
+    rv_op = dirichlet_multinomial

-    def __init__(self, n, a, shape, *args, **kwargs):
-
-        super().__init__(shape=shape, defaults=("_defaultval",), *args, **kwargs)
-
+    @classmethod
+    def dist(cls, n, a, *args, **kwargs):
         n = intX(n)
         a = floatX(a)
-        if len(self.shape) > 1:
-            self.n = at.shape_padright(n)
-            self.a = at.as_tensor_variable(a) if a.ndim > 1 else at.shape_padleft(a)
-        else:
-            # n is a scalar, p is a 1d array
-            self.n = at.as_tensor_variable(n)
-            self.a = at.as_tensor_variable(a)
-
-        p = self.a / self.a.sum(-1, keepdims=True)
-
-        self.mean = self.n * p
-        # Mode is only an approximation. Exact computation requires a complex
-        # iterative algorithm as described in https://doi.org/10.1016/j.spl.2009.09.013
-        mode = at.cast(at.round(self.mean), "int32")
-        diff = self.n - at.sum(mode, axis=-1, keepdims=True)
-        inc_bool_arr = at.abs_(diff) > 0
-        mode = at.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()])
-        self._defaultval = mode
-
-    def _random(self, n, a, size=None):
-        # numpy will cast dirichlet and multinomial samples to float64 by default
-        original_dtype = a.dtype
-
-        # Thanks to the default shape handling done in generate_values, the last
-        # axis of n is a dummy axis that allows it to broadcast well with `a`
-        n = np.broadcast_to(n, size)
-        a = np.broadcast_to(a, size)
-        n = n[..., 0]
-
-        # np.random.multinomial needs `n` to be a scalar int and `a` a
-        # sequence so we semi flatten them and iterate over them
-        n_ = n.reshape([-1])
-        a_ = a.reshape([-1, a.shape[-1]])
-        p_ = np.array([np.random.dirichlet(aa) for aa in a_])
-        samples = np.array([np.random.multinomial(nn, pp) for nn, pp in zip(n_, p_)])
-        samples = samples.reshape(a.shape)
-
-        # We cast back to the original dtype
-        return samples.astype(original_dtype)

-    def random(self, point=None, size=None):
-        """
-        Draw random values from Dirichlet-Multinomial distribution.
-
-        Parameters
-        ----------
-        point: dict, optional
-            Dict of variable values on which random values are to be
-            conditioned (uses default point if not specified).
-        size: int, optional
-            Desired size of random sample (returns one sample if not
-            specified).
+        return super().dist([n, a], **kwargs)

-        Returns
-        -------
-        array
-        """
-        # n, a = draw_values([self.n, self.a], point=point, size=size)
-        # samples = generate_samples(
-        #     self._random,
-        #     n,
-        #     a,
-        #     dist_shape=self.shape,
-        #     size=size,
-        # )
-        #
-        # # If distribution is initialized with .dist(), valid init shape is not asserted.
-        # # Under normal use in a model context valid init shape is asserted at start.
-        # expected_shape = to_tuple(size) + to_tuple(self.shape)
-        # sample_shape = tuple(samples.shape)
-        # if sample_shape != expected_shape:
-        #     raise ShapeError(
-        #         f"Expected sample shape was {expected_shape} but got {sample_shape}. "
-        #         "This may reflect an invalid initialization shape."
-        #     )
-        #
-        # return samples
-
-    def logp(self, value):
+    def logp(value, n, a):
         """
         Calculate log-probability of DirichletMultinomial distribution
         at specified value.
@@ -665,13 +623,16 @@ def logp(self, value):
         -------
         TensorVariable
         """
-        a = self.a
-        n = self.n
-        sum_a = a.sum(axis=-1, keepdims=True)
+        if value.ndim >= 1:
+            n = at.shape_padright(n)
+            if a.ndim > 1:
+                a = at.shape_padleft(a)

+        sum_a = a.sum(axis=-1, keepdims=True)
         const = (gammaln(n + 1) + gammaln(sum_a)) - gammaln(n + sum_a)
         series = gammaln(value + a) - (gammaln(value + 1) + gammaln(a))
         result = const + series.sum(axis=-1, keepdims=True)
+
         # Bounds checking to confirm parameters and data meet all constraints
         # and that each observation value_i sums to n_i.
         return bound(
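For reference, `const` and `series` in the new `logp` assemble the standard Dirichlet-multinomial log-pmf; writing A for the summed concentrations, the expression being computed is:

\log P(x \mid n, a)
    = \underbrace{\log\Gamma(n + 1) + \log\Gamma(A) - \log\Gamma(n + A)}_{\texttt{const}}
    + \underbrace{\sum_{k}\bigl[\log\Gamma(x_k + a_k) - \log\Gamma(x_k + 1) - \log\Gamma(a_k)\bigr]}_{\texttt{series}},
\qquad A = \sum_{k} a_k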
@@ -968,7 +929,7 @@ def logp(X, nu, V):


 def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, initval=None):
-    R"""
+    r"""
     Bartlett decomposition of the Wishart distribution. As the Wishart
     distribution requires the matrix to be symmetric positive semi-definite
     it is impossible for MCMC to ever propose acceptable matrices.
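The new `DirichletMultinomialRV` replaces the deleted `_random`/`random` machinery with compound generation: p ~ Dirichlet(a), then x ~ Multinomial(n, p). A self-contained NumPy sketch of the scalar-n / 1d-a branch (the helper name is hypothetical, for illustration only):

import numpy as np

def sample_dirichlet_multinomial(rng, n, a, size=None):
    # Hypothetical helper mirroring rng_fn's scalar-n / 1d-a branch:
    # draw p ~ Dirichlet(a), then x ~ Multinomial(n, p), per requested draw.
    a = np.asarray(a, dtype=float)
    p = rng.dirichlet(a, size=size)        # shape: size + a.shape
    res = np.empty(p.shape, dtype=np.int64)
    for idx in np.ndindex(p.shape[:-1]):   # iterate over all leading dims
        res[idx] = rng.multinomial(n, p[idx])
    return res

rng = np.random.default_rng(123)
draws = sample_dirichlet_multinomial(rng, n=10, a=[0.2, 0.3, 0.5], size=4)
print(draws.sum(axis=-1))                  # [10 10 10 10]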

pymc3/tests/test_distributions.py

Lines changed: 31 additions & 28 deletions
@@ -2273,7 +2273,6 @@ def test_batch_multinomial(self):
         assert_allclose(sample, np.stack([vals, vals], axis=0))

     @pytest.mark.parametrize("n", [2, 3])
-    @pytest.mark.xfail(reason="Distribution not refactored yet")
     def test_dirichlet_multinomial(self, n):
         self.check_logp(
             DirichletMultinomial,
@@ -2282,43 +2281,47 @@ def test_dirichlet_multinomial(self, n):
             dirichlet_multinomial_logpmf,
         )

-    @pytest.mark.xfail(reason="Distribution not refactored yet")
     def test_dirichlet_multinomial_matches_beta_binomial(self):
         a, b, n = 2, 1, 5
         ns = np.arange(n + 1)
-        ns_dm = np.vstack((ns, n - ns)).T  # covert ns=1 to ns_dm=[1, 4], for all ns...
-        bb_logp = logpt(pm.BetaBinomial.dist(n=n, alpha=a, beta=b), ns).tag.test_value
-        dm_logp = logpt(
-            pm.DirichletMultinomial.dist(n=n, a=[a, b], size=(1, 2)), ns_dm
-        ).tag.test_value
-        dm_logp = dm_logp.ravel()
+        ns_dm = np.vstack((ns, n - ns)).T  # convert ns=1 to ns_dm=[1, 4], for all ns...
+
+        bb = pm.BetaBinomial.dist(n=n, alpha=a, beta=b, size=2)
+        bb_value = bb.type()
+        bb.tag.value_var = bb_value
+        bb_logp = logpt(var=bb, rv_values={bb: bb_value}).eval({bb_value: ns})
+
+        dm = pm.DirichletMultinomial.dist(n=n, a=[a, b], size=2)
+        dm_value = dm.type()
+        dm.tag.value_var = dm_value
+        dm_logp = logpt(var=dm, rv_values={dm: dm_value}).eval({dm_value: ns_dm}).ravel()
+
         assert_almost_equal(
             dm_logp,
             bb_logp,
             decimal=select_by_precision(float64=6, float32=3),
         )

-    @pytest.mark.xfail(reason="Distribution not refactored yet")
     def test_dirichlet_multinomial_vec(self):
         vals = np.array([[2, 4, 4], [3, 3, 4]])
         a = np.array([0.2, 0.3, 0.5])
         n = 10

         with Model() as model_single:
-            DirichletMultinomial("m", n=n, a=a, size=len(a))
+            DirichletMultinomial("m", n=n, a=a)

         with Model() as model_many:
-            DirichletMultinomial("m", n=n, a=a, size=vals.shape)
+            DirichletMultinomial("m", n=n, a=a, size=2)

         assert_almost_equal(
-            np.asarray([dirichlet_multinomial_logpmf(v, n, a) for v in vals]),
+            np.asarray([dirichlet_multinomial_logpmf(val, n, a) for val in vals]),
             np.asarray([model_single.fastlogp({"m": val}) for val in vals]),
             decimal=4,
         )

         assert_almost_equal(
-            np.asarray([dirichlet_multinomial_logpmf(v, n, a) for v in vals]),
-            model_many.free_RVs[0].logp_elemwise({"m": vals}).squeeze(),
+            np.asarray([dirichlet_multinomial_logpmf(val, n, a) for val in vals]),
+            logpt(model_many.m, vals).eval().squeeze(),
             decimal=4,
         )

@@ -2328,56 +2331,52 @@ def test_dirichlet_multinomial_vec(self):
             decimal=4,
         )

-    @pytest.mark.xfail(reason="Distribution not refactored yet")
     def test_dirichlet_multinomial_vec_1d_n(self):
         vals = np.array([[2, 4, 4], [4, 3, 4]])
         a = np.array([0.2, 0.3, 0.5])
         ns = np.array([10, 11])

         with Model() as model:
-            DirichletMultinomial("m", n=ns, a=a, size=vals.shape)
+            DirichletMultinomial("m", n=ns, a=a)

         assert_almost_equal(
             sum(dirichlet_multinomial_logpmf(val, n, a) for val, n in zip(vals, ns)),
             model.fastlogp({"m": vals}),
             decimal=4,
         )

-    @pytest.mark.xfail(reason="Distribution not refactored yet")
     def test_dirichlet_multinomial_vec_1d_n_2d_a(self):
         vals = np.array([[2, 4, 4], [4, 3, 4]])
         as_ = np.array([[0.2, 0.3, 0.5], [0.9, 0.09, 0.01]])
         ns = np.array([10, 11])

         with Model() as model:
-            DirichletMultinomial("m", n=ns, a=as_, size=vals.shape)
+            DirichletMultinomial("m", n=ns, a=as_)

         assert_almost_equal(
             sum(dirichlet_multinomial_logpmf(val, n, a) for val, n, a in zip(vals, ns, as_)),
             model.fastlogp({"m": vals}),
             decimal=4,
         )

-    @pytest.mark.xfail(reason="Distribution not refactored yet")
     def test_dirichlet_multinomial_vec_2d_a(self):
         vals = np.array([[2, 4, 4], [3, 3, 4]])
         as_ = np.array([[0.2, 0.3, 0.5], [0.3, 0.3, 0.4]])
         n = 10

         with Model() as model:
-            DirichletMultinomial("m", n=n, a=as_, size=vals.shape)
+            DirichletMultinomial("m", n=n, a=as_)

         assert_almost_equal(
             sum(dirichlet_multinomial_logpmf(val, n, a) for val, a in zip(vals, as_)),
             model.fastlogp({"m": vals}),
             decimal=4,
         )

-    @pytest.mark.xfail(reason="Distribution not refactored yet")
     def test_batch_dirichlet_multinomial(self):
         # Test that DM can handle a 3d array for `a`

-        # Create an almost deterministic DM by setting a to 0.001, everywehere
+        # Create an almost deterministic DM by setting a to 0.001, everywhere
         # except for one category / dimension which is given the value of 1000
         n = 5
         vals = np.zeros((4, 5, 3), dtype="int32")
@@ -2386,19 +2385,23 @@ def test_batch_dirichlet_multinomial(self):
         np.put_along_axis(vals, inds, n, axis=-1)
         np.put_along_axis(a, inds, 1000, axis=-1)

-        dist = DirichletMultinomial.dist(n=n, a=a, size=vals.shape)
+        dist = DirichletMultinomial.dist(n=n, a=a)

-        # Logp should be approx -9.924431e-06
-        dist_logp = logpt(dist, vals).tag.test_value
-        expected_logp = np.full(shape=vals.shape[:-1] + (1,), fill_value=-9.924431e-06)
+        # Logp should be approx -9.98004998e-06
+        value = at.tensor3(dtype="int32")
+        value.tag.test_value = np.zeros_like(vals, dtype="int32")
+        logp = logpt(dist, value)
+        f = aesara.function(inputs=[value], outputs=logp)
+        expected_logp = np.full(shape=f(vals).shape, fill_value=-9.98004998e-06)
         assert_almost_equal(
-            dist_logp,
+            f(vals),
             expected_logp,
             decimal=select_by_precision(float64=6, float32=3),
         )

         # Samples should be equal given the almost deterministic DM
-        sample = dist.random(size=2)
+        dist = DirichletMultinomial.dist(n=n, a=a, size=2)
+        sample = dist.eval()
         assert_allclose(sample, np.stack([vals, vals], axis=0))

     @aesara.config.change_flags(compute_test_value="raise")
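The updated tests show the user-facing pattern: draws come from evaluating the Aesara graph rather than from `.random()`, and log-probabilities are built with `logpt`. A hedged usage sketch, assuming a V4-development pymc3 in which `logpt` is importable as below (import paths moved around during the refactor):

import numpy as np
import pymc3 as pm
from pymc3.distributions import logpt  # assumed import location on the V4 branch

# .dist() returns an unnamed Aesara RandomVariable; eval() draws from it
dist = pm.DirichletMultinomial.dist(n=10, a=np.array([0.2, 0.3, 0.5]), size=2)
print(dist.eval())  # two draws, each row summing to n=10

# Build the log-probability graph with logpt, then evaluate at concrete values
vals = np.array([[2, 4, 4], [3, 3, 4]])
print(logpt(dist, vals).eval())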
