Skip to content

Commit a6a08bb

Browse files
committed
Removed the shape validation check to even work for last dimensional shape as 1.
Modified the `test_dirichlet` function to check for the same.
1 parent 18fd1f5 commit a6a08bb

File tree

2 files changed

+1
-2
lines changed

2 files changed

+1
-2
lines changed

pymc3/distributions/multivariate.py

-1
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,6 @@ def logp(self, value):
522522
tt.sum(logpow(value, a - 1) - gammaln(a), axis=-1) + gammaln(tt.sum(a, axis=-1)),
523523
tt.all(value >= 0),
524524
tt.all(value <= 1),
525-
np.logical_not(a.broadcastable[-1]),
526525
tt.all(a > 0),
527526
broadcast_conditions=False,
528527
)

pymc3/tests/test_distributions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1692,7 +1692,7 @@ def test_lkj(self, x, eta, n, lp):
16921692
decimals = select_by_precision(float64=6, float32=4)
16931693
assert_almost_equal(model.fastlogp(pt), lp, decimal=decimals, err_msg=str(pt))
16941694

1695-
@pytest.mark.parametrize("n", [2, 3])
1695+
@pytest.mark.parametrize("n", [1, 2, 3])
16961696
def test_dirichlet(self, n):
16971697
self.pymc3_matches_scipy(Dirichlet, Simplex(n), {"a": Vector(Rplus, n)}, dirichlet_logpdf)
16981698

0 commit comments

Comments
 (0)