Skip to content

Commit e04cf78

Browse files
AlexAndorrabrandonwillard
authored andcommitted
Refactor return statement in rng_fn
1 parent d082d2f commit e04cf78

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

pymc3/distributions/multivariate.py

+23
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,29 @@ def _distr_parameters_for_repr(self):
654654
return ["n", "a"]
655655

656656

657+
class OrderedMultinomial(Multinomial):
658+
rv_op = multinomial
659+
660+
@classmethod
661+
def dist(cls, eta, cutpoints, n, *args, **kwargs):
662+
eta = at.as_tensor_variable(floatX(eta))
663+
cutpoints = at.as_tensor_variable(cutpoints)
664+
n = at.as_tensor_variable(n)
665+
666+
pa = sigmoid(cutpoints - at.shape_padright(eta))
667+
p_cum = at.concatenate(
668+
[
669+
at.zeros_like(at.shape_padright(pa[..., 0])),
670+
pa,
671+
at.ones_like(at.shape_padright(pa[..., 0])),
672+
],
673+
axis=-1,
674+
)
675+
p = p_cum[..., 1:] - p_cum[..., :-1]
676+
677+
return super().dist([n, p], *args, **kwargs)
678+
679+
657680
def posdef(AA):
658681
try:
659682
linalg.cholesky(AA)

0 commit comments

Comments
 (0)