@@ -2172,6 +2172,18 @@ def test_dirichlet(self, n):
21722172 dirichlet_logpdf ,
21732173 )
21742174
2175+ def test_dirichlet_invalid (self ):
2176+ # Test non-scalar invalid parameters/values
2177+ value = np .array ([[0.1 , 0.2 , 0.7 ], [0.3 , 0.3 , 0.4 ]])
2178+
2179+ invalid_dist = Dirichlet .dist (a = [- 1 , 1 , 2 ], size = 2 )
2180+ with pytest .raises (ParameterValueError ):
2181+ pm .logp (invalid_dist , value ).eval ()
2182+
2183+ value [1 ] -= 1
2184+ valid_dist = Dirichlet .dist (a = [1 , 1 , 1 ])
2185+ assert np .all (np .isfinite (pm .logp (valid_dist , value ).eval ()) == np .array ([True , False ]))
2186+
21752187 @pytest .mark .parametrize (
21762188 "a" ,
21772189 [
@@ -2203,6 +2215,20 @@ def test_multinomial(self, n):
22032215 lambda value , n , p : scipy .stats .multinomial .logpmf (value , n , p ),
22042216 )
22052217
2218+ def test_multinomial_invalid (self ):
2219+ # Test non-scalar invalid parameters/values
2220+ value = np .array ([[1 , 2 , 2 ], [4 , 0 , 1 ]])
2221+
2222+ invalid_dist = Multinomial .dist (n = 5 , p = [- 1 , 1 , 1 ], size = 2 )
2223+ # TODO: Multinomial normalizes p, so it is impossible to trigger p checks
2224+ # with pytest.raises(ParameterValueError):
2225+ with does_not_raise ():
2226+ pm .logp (invalid_dist , value ).eval ()
2227+
2228+ value [1 ] -= 1
2229+ valid_dist = Multinomial .dist (n = 5 , p = np .ones (3 ) / 3 )
2230+ assert np .all (np .isfinite (pm .logp (valid_dist , value ).eval ()) == np .array ([True , False ]))
2231+
22062232 @pytest .mark .parametrize ("n" , [(10 ), ([10 , 11 ]), ([[5 , 6 ], [10 , 11 ]])])
22072233 @pytest .mark .parametrize (
22082234 "p" ,
@@ -2243,6 +2269,18 @@ def test_dirichlet_multinomial(self, n):
22432269 dirichlet_multinomial_logpmf ,
22442270 )
22452271
2272+ def test_dirichlet_multinomial_invalid (self ):
2273+ # Test non-scalar invalid parameters/values
2274+ value = np .array ([[1 , 2 , 2 ], [4 , 0 , 1 ]])
2275+
2276+ invalid_dist = DirichletMultinomial .dist (n = 5 , a = [- 1 , 1 , 1 ], size = 2 )
2277+ with pytest .raises (ParameterValueError ):
2278+ pm .logp (invalid_dist , value ).eval ()
2279+
2280+ value [1 ] -= 1
2281+ valid_dist = DirichletMultinomial .dist (n = 5 , a = [1 , 1 , 1 ])
2282+ assert np .all (np .isfinite (pm .logp (valid_dist , value ).eval ()) == np .array ([True , False ]))
2283+
22462284 def test_dirichlet_multinomial_matches_beta_binomial (self ):
22472285 a , b , n = 2 , 1 , 5
22482286 ns = np .arange (n + 1 )
0 commit comments