|
34 | 34 | HyperGeometric,
|
35 | 35 | Interpolated,
|
36 | 36 | InverseGamma,
|
| 37 | + KroneckerNormal, |
37 | 38 | Kumaraswamy,
|
38 | 39 | Laplace,
|
39 | 40 | Logistic,
|
@@ -110,7 +111,6 @@ def test_all_distributions_have_moments():
|
110 | 111 | dist_module.discrete.DiscreteWeibull,
|
111 | 112 | dist_module.multivariate.CAR,
|
112 | 113 | dist_module.multivariate.DirichletMultinomial,
|
113 |
| - dist_module.multivariate.KroneckerNormal, |
114 | 114 | dist_module.multivariate.Wishart,
|
115 | 115 | }
|
116 | 116 |
|
@@ -1316,3 +1316,32 @@ def normal_sim(rng, mu, sigma, size):
|
1316 | 1316 | cutoff = st.norm().ppf(1 - (alpha / 2))
|
1317 | 1317 |
|
1318 | 1318 | assert np.all(np.abs((result - expected_sample_mean) / expected_sample_mean_std) < cutoff)
|
| 1319 | + |
| 1320 | + |
| 1321 | +@pytest.mark.parametrize( |
| 1322 | + "mu, covs, size, expected", |
| 1323 | + [ |
| 1324 | + (np.ones(1), [np.identity(1), np.identity(1)], None, np.ones(1)), |
| 1325 | + (np.ones(6), [np.identity(2), np.identity(3)], 5, np.ones((5, 6))), |
| 1326 | + (np.zeros(6), [np.identity(2), np.identity(3)], 6, np.zeros((6, 6))), |
| 1327 | + (np.zeros(3), [np.identity(3), np.identity(1)], 6, np.zeros((6, 3))), |
| 1328 | + ( |
| 1329 | + np.array([1, 2, 3, 4]), |
| 1330 | + [ |
| 1331 | + np.array([[1.0, 0.5], [0.5, 2]]), |
| 1332 | + np.array([[1.0, 0.4], [0.4, 2]]), |
| 1333 | + ], |
| 1334 | + 2, |
| 1335 | + np.array( |
| 1336 | + [ |
| 1337 | + [1, 2, 3, 4], |
| 1338 | + [1, 2, 3, 4], |
| 1339 | + ] |
| 1340 | + ), |
| 1341 | + ), |
| 1342 | + ], |
| 1343 | +) |
| 1344 | +def test_kronecker_normal_moments(mu, covs, size, expected): |
| 1345 | + with Model() as model: |
| 1346 | + KroneckerNormal("x", mu=mu, covs=covs, size=size) |
| 1347 | + assert_moment_is_expected(model, expected) |
0 commit comments