@@ -2172,6 +2172,18 @@ def test_dirichlet(self, n):
2172
2172
dirichlet_logpdf ,
2173
2173
)
2174
2174
2175
+ def test_dirichlet_invalid (self ):
2176
+ # Test non-scalar invalid parameters/values
2177
+ value = np .array ([[0.1 , 0.2 , 0.7 ], [0.3 , 0.3 , 0.4 ]])
2178
+
2179
+ invalid_dist = Dirichlet .dist (a = [- 1 , 1 , 2 ], size = 2 )
2180
+ with pytest .raises (ParameterValueError ):
2181
+ pm .logp (invalid_dist , value ).eval ()
2182
+
2183
+ value [1 ] -= 1
2184
+ valid_dist = Dirichlet .dist (a = [1 , 1 , 1 ])
2185
+ assert np .all (np .isfinite (pm .logp (valid_dist , value ).eval ()) == np .array ([True , False ]))
2186
+
2175
2187
@pytest .mark .parametrize (
2176
2188
"a" ,
2177
2189
[
@@ -2203,6 +2215,20 @@ def test_multinomial(self, n):
2203
2215
lambda value , n , p : scipy .stats .multinomial .logpmf (value , n , p ),
2204
2216
)
2205
2217
2218
+ def test_multinomial_invalid (self ):
2219
+ # Test non-scalar invalid parameters/values
2220
+ value = np .array ([[1 , 2 , 2 ], [4 , 0 , 1 ]])
2221
+
2222
+ invalid_dist = Multinomial .dist (n = 5 , p = [- 1 , 1 , 1 ], size = 2 )
2223
+ # TODO: Multinomial normalizes p, so it is impossible to trigger p checks
2224
+ # with pytest.raises(ParameterValueError):
2225
+ with does_not_raise ():
2226
+ pm .logp (invalid_dist , value ).eval ()
2227
+
2228
+ value [1 ] -= 1
2229
+ valid_dist = Multinomial .dist (n = 5 , p = np .ones (3 ) / 3 )
2230
+ assert np .all (np .isfinite (pm .logp (valid_dist , value ).eval ()) == np .array ([True , False ]))
2231
+
2206
2232
@pytest .mark .parametrize ("n" , [(10 ), ([10 , 11 ]), ([[5 , 6 ], [10 , 11 ]])])
2207
2233
@pytest .mark .parametrize (
2208
2234
"p" ,
@@ -2243,6 +2269,18 @@ def test_dirichlet_multinomial(self, n):
2243
2269
dirichlet_multinomial_logpmf ,
2244
2270
)
2245
2271
2272
+ def test_dirichlet_multinomial_invalid (self ):
2273
+ # Test non-scalar invalid parameters/values
2274
+ value = np .array ([[1 , 2 , 2 ], [4 , 0 , 1 ]])
2275
+
2276
+ invalid_dist = DirichletMultinomial .dist (n = 5 , a = [- 1 , 1 , 1 ], size = 2 )
2277
+ with pytest .raises (ParameterValueError ):
2278
+ pm .logp (invalid_dist , value ).eval ()
2279
+
2280
+ value [1 ] -= 1
2281
+ valid_dist = DirichletMultinomial .dist (n = 5 , a = [1 , 1 , 1 ])
2282
+ assert np .all (np .isfinite (pm .logp (valid_dist , value ).eval ()) == np .array ([True , False ]))
2283
+
2246
2284
def test_dirichlet_multinomial_matches_beta_binomial (self ):
2247
2285
a , b , n = 2 , 1 , 5
2248
2286
ns = np .arange (n + 1 )
0 commit comments