Skip to content

Commit 901b72f

Browse files
committed
Add bound tests for some non-scalar values and parameters
1 parent a3ab0f1 commit 901b72f

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

Diff for: pymc/tests/test_distributions.py

+38
Original file line numberDiff line numberDiff line change
@@ -2172,6 +2172,18 @@ def test_dirichlet(self, n):
21722172
dirichlet_logpdf,
21732173
)
21742174

2175+
def test_dirichlet_invalid(self):
2176+
# Test non-scalar invalid parameters/values
2177+
value = np.array([[0.1, 0.2, 0.7], [0.3, 0.3, 0.4]])
2178+
2179+
invalid_dist = Dirichlet.dist(a=[-1, 1, 2], size=2)
2180+
with pytest.raises(ParameterValueError):
2181+
pm.logp(invalid_dist, value).eval()
2182+
2183+
value[1] -= 1
2184+
valid_dist = Dirichlet.dist(a=[1, 1, 1])
2185+
assert np.all(np.isfinite(pm.logp(valid_dist, value).eval()) == np.array([True, False]))
2186+
21752187
@pytest.mark.parametrize(
21762188
"a",
21772189
[
@@ -2203,6 +2215,20 @@ def test_multinomial(self, n):
22032215
lambda value, n, p: scipy.stats.multinomial.logpmf(value, n, p),
22042216
)
22052217

2218+
def test_multinomial_invalid(self):
2219+
# Test non-scalar invalid parameters/values
2220+
value = np.array([[1, 2, 2], [4, 0, 1]])
2221+
2222+
invalid_dist = Multinomial.dist(n=5, p=[-1, 1, 1], size=2)
2223+
# TODO: Multinomial normalizes p, so it is impossible to trigger p checks
2224+
# with pytest.raises(ParameterValueError):
2225+
with does_not_raise():
2226+
pm.logp(invalid_dist, value).eval()
2227+
2228+
value[1] -= 1
2229+
valid_dist = Multinomial.dist(n=5, p=np.ones(3) / 3)
2230+
assert np.all(np.isfinite(pm.logp(valid_dist, value).eval()) == np.array([True, False]))
2231+
22062232
@pytest.mark.parametrize("n", [(10), ([10, 11]), ([[5, 6], [10, 11]])])
22072233
@pytest.mark.parametrize(
22082234
"p",
@@ -2243,6 +2269,18 @@ def test_dirichlet_multinomial(self, n):
22432269
dirichlet_multinomial_logpmf,
22442270
)
22452271

2272+
def test_dirichlet_multinomial_invalid(self):
2273+
# Test non-scalar invalid parameters/values
2274+
value = np.array([[1, 2, 2], [4, 0, 1]])
2275+
2276+
invalid_dist = DirichletMultinomial.dist(n=5, a=[-1, 1, 1], size=2)
2277+
with pytest.raises(ParameterValueError):
2278+
pm.logp(invalid_dist, value).eval()
2279+
2280+
value[1] -= 1
2281+
valid_dist = DirichletMultinomial.dist(n=5, a=[1, 1, 1])
2282+
assert np.all(np.isfinite(pm.logp(valid_dist, value).eval()) == np.array([True, False]))
2283+
22462284
def test_dirichlet_multinomial_matches_beta_binomial(self):
22472285
a, b, n = 2, 1, 5
22482286
ns = np.arange(n + 1)

0 commit comments

Comments
 (0)