Skip to content

Commit b6f76e5

Browse files
Add moments for KroneckerNormalDistribution (#5235)
* moments * code format
1 parent 36c7553 commit b6f76e5

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

pymc/distributions/multivariate.py

+7
Original file line numberDiff line numberDiff line change
@@ -1938,6 +1938,13 @@ def dist(cls, mu, covs=None, chols=None, evds=None, sigma=None, *args, **kwargs)
19381938
# mean = median = mode = mu
19391939
return super().dist([mu, sigma, *covs], **kwargs)
19401940

1941+
def get_moment(rv, size, mu, covs, chols, evds):
1942+
mean = mu
1943+
if not rv_size_is_none(size):
1944+
moment_size = at.concatenate([size, mu.shape])
1945+
mean = at.full(moment_size, mu)
1946+
return mean
1947+
19411948
def logp(value, mu, sigma, *covs):
19421949
"""
19431950
Calculate log-probability of Multivariate Normal distribution

pymc/tests/test_distributions_moments.py

+30-1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
HyperGeometric,
3535
Interpolated,
3636
InverseGamma,
37+
KroneckerNormal,
3738
Kumaraswamy,
3839
Laplace,
3940
Logistic,
@@ -110,7 +111,6 @@ def test_all_distributions_have_moments():
110111
dist_module.discrete.DiscreteWeibull,
111112
dist_module.multivariate.CAR,
112113
dist_module.multivariate.DirichletMultinomial,
113-
dist_module.multivariate.KroneckerNormal,
114114
dist_module.multivariate.Wishart,
115115
}
116116

@@ -1316,3 +1316,32 @@ def normal_sim(rng, mu, sigma, size):
13161316
cutoff = st.norm().ppf(1 - (alpha / 2))
13171317

13181318
assert np.all(np.abs((result - expected_sample_mean) / expected_sample_mean_std) < cutoff)
1319+
1320+
1321+
@pytest.mark.parametrize(
1322+
"mu, covs, size, expected",
1323+
[
1324+
(np.ones(1), [np.identity(1), np.identity(1)], None, np.ones(1)),
1325+
(np.ones(6), [np.identity(2), np.identity(3)], 5, np.ones((5, 6))),
1326+
(np.zeros(6), [np.identity(2), np.identity(3)], 6, np.zeros((6, 6))),
1327+
(np.zeros(3), [np.identity(3), np.identity(1)], 6, np.zeros((6, 3))),
1328+
(
1329+
np.array([1, 2, 3, 4]),
1330+
[
1331+
np.array([[1.0, 0.5], [0.5, 2]]),
1332+
np.array([[1.0, 0.4], [0.4, 2]]),
1333+
],
1334+
2,
1335+
np.array(
1336+
[
1337+
[1, 2, 3, 4],
1338+
[1, 2, 3, 4],
1339+
]
1340+
),
1341+
),
1342+
],
1343+
)
1344+
def test_kronecker_normal_moments(mu, covs, size, expected):
1345+
with Model() as model:
1346+
KroneckerNormal("x", mu=mu, covs=covs, size=size)
1347+
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)