Skip to content

Commit 7d62c53

Browse files
Fix bug when model has exactly one variable (#437)
1 parent 0a00a31 commit 7d62c53

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

pymc_extras/inference/laplace.py

+3
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,10 @@ def sample_laplace_posterior(
377377
posterior_dist = stats.multivariate_normal(
378378
mean=mu.data, cov=H_inv, allow_singular=True, seed=rng
379379
)
380+
380381
posterior_draws = posterior_dist.rvs(size=(chains, draws))
382+
if mu.data.shape == (1,):
383+
posterior_draws = np.expand_dims(posterior_draws, -1)
381384

382385
if transform_samples:
383386
constrained_rvs, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(model)

tests/test_laplace.py

+16
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,19 @@ def test_fit_laplace(fit_in_unconstrained_space, mode, gradient_backend: Gradien
263263
else:
264264
assert idata.fit.rows.values.tolist() == ["mu", "sigma"]
265265
np.testing.assert_allclose(idata.fit.mean_vector.values, np.array([3.0, 1.5]), atol=0.1)
266+
267+
268+
def test_laplace_scalar():
269+
# Example model from Statistical Rethinking
270+
data = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1])
271+
272+
with pm.Model():
273+
p = pm.Uniform("p", 0, 1)
274+
w = pm.Binomial("w", n=len(data), p=p, observed=data.sum())
275+
276+
idata_laplace = pmx.fit_laplace(progressbar=False)
277+
278+
assert idata_laplace.fit.mean_vector.shape == (1,)
279+
assert idata_laplace.fit.covariance_matrix.shape == (1, 1)
280+
281+
np.testing.assert_allclose(idata_laplace.fit.mean_vector.values.item(), data.mean(), atol=0.1)

0 commit comments

Comments
 (0)