From 79be6134d93a38f748aca864d6f200caa5aec355 Mon Sep 17 00:00:00 2001 From: Max Horn Date: Sat, 27 Feb 2021 15:09:30 +0100 Subject: [PATCH] Keep broadcasting information in make_shared_replacements It seems like broadcasting information gets lost when applying `pm.make_shared_replacements`, leading to problems with the Metropolis sampler. Potentially related issues below: - https://github.com/pymc-devs/pymc3/issues/1083 - https://github.com/pymc-devs/pymc3/issues/1304 - https://github.com/pymc-devs/pymc3/issues/1983 This fix was previously suggested in the following issue: - https://github.com/pymc-devs/pymc3/issues/3337 Further adaptations may be necessary, as indicated in that issue. Strangely, this does not seem to lead to problems when using NUTS. --- RELEASE-NOTES.md | 4 ++-- pymc3/tests/test_theanof.py | 25 +++++++++++++++++++++++++ pymc3/theanof.py | 7 ++++++- 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 71d0b35085..388ca345e4 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -9,8 +9,8 @@ + ... ### Maintenance -- ⚠ Our memoization mechanism wasn't robust against hash collisions (#4506), sometimes resulting in incorrect values in, for example, posterior predictives. The `pymc3.memoize` module was removed and replaced with `cachetools`. The `hashable` function and `WithMemoization` class were moved to `pymc3.util`. -- ... +- ⚠ Our memoization mechanism wasn't robust against hash collisions (#4506), sometimes resulting in incorrect values in, for example, posterior predictives. The `pymc3.memoize` module was removed and replaced with `cachetools`. The `hashable` function and `WithMemoization` class were moved to `pymc3.util` (see #4525). +- `pm.make_shared_replacements` now retains broadcasting information which fixes issues with Metropolis samplers (see [#4492](https://github.com/pymc-devs/pymc3/pull/4492)). 
## PyMC3 3.11.1 (12 February 2021) diff --git a/pymc3/tests/test_theanof.py b/pymc3/tests/test_theanof.py index d54aed680d..bf2ff61fce 100644 --- a/pymc3/tests/test_theanof.py +++ b/pymc3/tests/test_theanof.py @@ -19,6 +19,8 @@ import theano import theano.tensor as tt +import pymc3 as pm + from pymc3.theanof import _conversion_map, take_along_axis from pymc3.vartypes import int_types @@ -26,6 +28,29 @@ INTX = str(_conversion_map[FLOATX]) +class TestBroadcasting: + def test_make_shared_replacements(self): + """Check if pm.make_shared_replacements preserves broadcasting.""" + + with pm.Model() as test_model: + test1 = pm.Normal("test1", mu=0.0, sigma=1.0, shape=(1, 10)) + test2 = pm.Normal("test2", mu=0.0, sigma=1.0, shape=(10, 1)) + + # Replace test1 with a shared variable, keep test 2 the same + replacement = pm.make_shared_replacements([test_model.test2], test_model) + assert test_model.test1.broadcastable == replacement[test_model.test1].broadcastable + + def test_metropolis_sampling(self): + """Check if the Metropolis sampler can handle broadcasting.""" + with pm.Model() as test_model: + test1 = pm.Normal("test1", mu=0.0, sigma=1.0, shape=(1, 10)) + test2 = pm.Normal("test2", mu=test1, sigma=1.0, shape=(10, 10)) + + step = pm.Metropolis() + # This should fail immediately if broadcasting does not work. 
+ pm.sample(tune=5, draws=7, cores=1, step=step, compute_convergence_checks=False) + + def _make_along_axis_idx(arr_shape, indices, axis): # compute dimensions to iterate over if str(indices.dtype) not in int_types: diff --git a/pymc3/theanof.py b/pymc3/theanof.py index c40311da6e..50ab8afdaf 100644 --- a/pymc3/theanof.py +++ b/pymc3/theanof.py @@ -235,7 +235,12 @@ def make_shared_replacements(vars, model): Dict of variable -> new shared variable """ othervars = set(model.vars) - set(vars) - return {var: theano.shared(var.tag.test_value, var.name + "_shared") for var in othervars} + return { + var: theano.shared( + var.tag.test_value, var.name + "_shared", broadcastable=var.broadcastable + ) + for var in othervars + } def join_nonshared_inputs(xs, vars, shared, make_shared=False):