-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Make Metropolis cope better with multiple dimensions #5823
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -168,10 +168,12 @@ def __init__( | |||||||||
| vars = model.value_vars | ||||||||||
| else: | ||||||||||
| vars = [model.rvs_to_values.get(var, var) for var in vars] | ||||||||||
|
|
||||||||||
| vars = pm.inputvars(vars) | ||||||||||
|
|
||||||||||
| initial_values_shape = [initial_values[v.name].shape for v in vars] | ||||||||||
| if S is None: | ||||||||||
| S = np.ones(sum(initial_values[v.name].size for v in vars)) | ||||||||||
| S = np.ones(int(sum(np.prod(ivs) for ivs in initial_values_shape))) | ||||||||||
|
|
||||||||||
| if proposal_dist is not None: | ||||||||||
| self.proposal_dist = proposal_dist(S) | ||||||||||
|
|
@@ -186,7 +188,6 @@ def __init__( | |||||||||
| self.tune = tune | ||||||||||
| self.tune_interval = tune_interval | ||||||||||
| self.steps_until_tune = tune_interval | ||||||||||
| self.accepted = 0 | ||||||||||
|
|
||||||||||
| # Determine type of variables | ||||||||||
| self.discrete = np.concatenate( | ||||||||||
|
|
@@ -195,11 +196,33 @@ def __init__( | |||||||||
| self.any_discrete = self.discrete.any() | ||||||||||
| self.all_discrete = self.discrete.all() | ||||||||||
|
|
||||||||||
| # remember initial settings before tuning so they can be reset | ||||||||||
| self._untuned_settings = dict( | ||||||||||
| scaling=self.scaling, steps_until_tune=tune_interval, accepted=self.accepted | ||||||||||
| # Metropolis will try to handle one batched dimension at a time. This, however, | ||||||||||
| # is not safe for discrete multivariate distributions (looking at you Multinomial), | ||||||||||
| # due to high dependency among the support dimensions. For continuous multivariate | ||||||||||
| # distributions we assume they are being transformed in a way that makes each | ||||||||||
| # dimension semi-independent. | ||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
| is_scalar = len(initial_values_shape) == 1 and initial_values_shape[0] == () | ||||||||||
| self.elemwise_update = not ( | ||||||||||
| is_scalar | ||||||||||
| or ( | ||||||||||
| self.any_discrete | ||||||||||
| and max(getattr(model.values_to_rvs[var].owner.op, "ndim_supp", 1) for var in vars) | ||||||||||
| > 0 | ||||||||||
| ) | ||||||||||
| ) | ||||||||||
| if self.elemwise_update: | ||||||||||
| dims = int(sum(np.prod(ivs) for ivs in initial_values_shape)) | ||||||||||
|
Comment on lines
+213
to
+214
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Right now (reading the diff top to bottom) I'm confused because this smells like
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See my comment below. |
||||||||||
| else: | ||||||||||
| dims = 1 | ||||||||||
| self.enum_dims = np.arange(dims, dtype=int) | ||||||||||
| self.accept_rate_iter = np.zeros(dims, dtype=float) | ||||||||||
| self.accepted_iter = np.zeros(dims, dtype=bool) | ||||||||||
| self.accepted_sum = np.zeros(dims, dtype=int) | ||||||||||
|
|
||||||||||
| # remember initial settings before tuning so they can be reset | ||||||||||
| self._untuned_settings = dict(scaling=self.scaling, steps_until_tune=tune_interval) | ||||||||||
|
|
||||||||||
| # TODO: This is not being used when compiling the logp function! | ||||||||||
| self.mode = mode | ||||||||||
|
|
||||||||||
| shared = pm.make_shared_replacements(initial_values, vars, model) | ||||||||||
|
|
@@ -210,6 +233,7 @@ def reset_tuning(self): | |||||||||
| """Resets the tuned sampler parameters to their initial values.""" | ||||||||||
| for attr, initial_value in self._untuned_settings.items(): | ||||||||||
| setattr(self, attr, initial_value) | ||||||||||
| self.accepted_sum[:] = 0 | ||||||||||
| return | ||||||||||
|
|
||||||||||
| def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]: | ||||||||||
|
|
@@ -219,10 +243,10 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]: | |||||||||
|
|
||||||||||
| if not self.steps_until_tune and self.tune: | ||||||||||
| # Tune scaling parameter | ||||||||||
| self.scaling = tune(self.scaling, self.accepted / float(self.tune_interval)) | ||||||||||
| self.scaling = tune(self.scaling, self.accepted_sum / float(self.tune_interval)) | ||||||||||
| # Reset counter | ||||||||||
| self.steps_until_tune = self.tune_interval | ||||||||||
| self.accepted = 0 | ||||||||||
| self.accepted_sum[:] = 0 | ||||||||||
|
|
||||||||||
| delta = self.proposal_dist() * self.scaling | ||||||||||
|
|
||||||||||
|
|
@@ -237,23 +261,36 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]: | |||||||||
| else: | ||||||||||
| q = floatX(q0 + delta) | ||||||||||
|
|
||||||||||
| accept = self.delta_logp(q, q0) | ||||||||||
| q_new, accepted = metrop_select(accept, q, q0) | ||||||||||
|
|
||||||||||
| self.accepted += accepted | ||||||||||
| if self.elemwise_update: | ||||||||||
| q_temp = q0.copy() | ||||||||||
| # Shuffle order of updates (probably we don't need to do this in every step) | ||||||||||
| np.random.shuffle(self.enum_dims) | ||||||||||
| for i in self.enum_dims: | ||||||||||
| q_temp[i] = q[i] | ||||||||||
| accept_rate_i = self.delta_logp(q_temp, q0) | ||||||||||
| q_temp_, accepted_i = metrop_select(accept_rate_i, q_temp, q0) | ||||||||||
| q_temp[i] = q_temp_[i] | ||||||||||
| self.accept_rate_iter[i] = accept_rate_i | ||||||||||
| self.accepted_iter[i] = accepted_i | ||||||||||
| self.accepted_sum[i] += accepted_i | ||||||||||
| q = q_temp | ||||||||||
|
Comment on lines
+268
to
+276
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this doing what we usually do with `CompoundStep`? If not, maybe explain what the if/else blocks do in a code comment.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. `CompoundStep` assigns one variable per step; here we are sampling one dimension within a variable (or within multiple variables) semi-independently. |
||||||||||
| else: | ||||||||||
| accept_rate = self.delta_logp(q, q0) | ||||||||||
| q, accepted = metrop_select(accept_rate, q, q0) | ||||||||||
| self.accept_rate_iter = accept_rate | ||||||||||
| self.accepted_iter = accepted | ||||||||||
| self.accepted_sum += accepted | ||||||||||
|
|
||||||||||
| self.steps_until_tune -= 1 | ||||||||||
|
|
||||||||||
| stats = { | ||||||||||
| "tune": self.tune, | ||||||||||
| "scaling": self.scaling, | ||||||||||
| "accept": np.exp(accept), | ||||||||||
| "accepted": accepted, | ||||||||||
| "scaling": np.mean(self.scaling), | ||||||||||
| "accept": np.mean(np.exp(self.accept_rate_iter)), | ||||||||||
| "accepted": np.mean(self.accepted_iter), | ||||||||||
| } | ||||||||||
|
|
||||||||||
| q_new = RaveledVars(q_new, point_map_info) | ||||||||||
|
|
||||||||||
| return q_new, [stats] | ||||||||||
| return RaveledVars(q, point_map_info), [stats] | ||||||||||
|
|
||||||||||
| @staticmethod | ||||||||||
| def competence(var, has_grad): | ||||||||||
|
|
@@ -275,26 +312,38 @@ def tune(scale, acc_rate): | |||||||||
| >0.95 x 10 | ||||||||||
|
|
||||||||||
| """ | ||||||||||
| if acc_rate < 0.001: | ||||||||||
| return scale * np.where( | ||||||||||
ricardoV94 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||
| acc_rate < 0.001, | ||||||||||
| # reduce by 90 percent | ||||||||||
| return scale * 0.1 | ||||||||||
| elif acc_rate < 0.05: | ||||||||||
| # reduce by 50 percent | ||||||||||
| return scale * 0.5 | ||||||||||
| elif acc_rate < 0.2: | ||||||||||
| # reduce by ten percent | ||||||||||
| return scale * 0.9 | ||||||||||
| elif acc_rate > 0.95: | ||||||||||
| # increase by factor of ten | ||||||||||
| return scale * 10.0 | ||||||||||
| elif acc_rate > 0.75: | ||||||||||
| # increase by double | ||||||||||
| return scale * 2.0 | ||||||||||
| elif acc_rate > 0.5: | ||||||||||
| # increase by ten percent | ||||||||||
| return scale * 1.1 | ||||||||||
|
|
||||||||||
| return scale | ||||||||||
| 0.1, | ||||||||||
| np.where( | ||||||||||
| acc_rate < 0.05, | ||||||||||
| # reduce by 50 percent | ||||||||||
| 0.5, | ||||||||||
| np.where( | ||||||||||
| acc_rate < 0.2, | ||||||||||
| # reduce by ten percent | ||||||||||
| 0.9, | ||||||||||
| np.where( | ||||||||||
| acc_rate > 0.95, | ||||||||||
| # increase by factor of ten | ||||||||||
| 10.0, | ||||||||||
| np.where( | ||||||||||
| acc_rate > 0.75, | ||||||||||
| # increase by double | ||||||||||
| 2.0, | ||||||||||
| np.where( | ||||||||||
| acc_rate > 0.5, | ||||||||||
| # increase by ten percent | ||||||||||
| 1.1, | ||||||||||
| # Do not change | ||||||||||
| 1.0, | ||||||||||
| ), | ||||||||||
| ), | ||||||||||
| ), | ||||||||||
| ), | ||||||||||
| ), | ||||||||||
| ) | ||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ugh, can't we do this in a
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can probably index, that should be faster for large arrays because all branches of |
||||||||||
|
|
||||||||||
|
|
||||||||||
| class BinaryMetropolis(ArrayStep): | ||||||||||
|
|
||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -35,7 +35,9 @@ | |
| Beta, | ||
| Binomial, | ||
| Categorical, | ||
| Dirichlet, | ||
| HalfNormal, | ||
| Multinomial, | ||
| MvNormal, | ||
| Normal, | ||
| ) | ||
|
|
@@ -174,33 +176,6 @@ def test_step_categorical(self, proposal): | |
| self.check_stat(check, idata, step.__class__.__name__) | ||
|
|
||
|
|
||
| class TestMetropolisProposal: | ||
| def test_proposal_choice(self): | ||
| with aesara.config.change_flags(mode=fast_unstable_sampling_mode): | ||
| _, model, _ = mv_simple() | ||
| with model: | ||
| initial_point = model.initial_point() | ||
| initial_point_size = sum(initial_point[n.name].size for n in model.value_vars) | ||
|
|
||
| s = np.ones(initial_point_size) | ||
| sampler = Metropolis(S=s) | ||
| assert isinstance(sampler.proposal_dist, NormalProposal) | ||
| s = np.diag(s) | ||
| sampler = Metropolis(S=s) | ||
| assert isinstance(sampler.proposal_dist, MultivariateNormalProposal) | ||
| s[0, 0] = -s[0, 0] | ||
| with pytest.raises(np.linalg.LinAlgError): | ||
| sampler = Metropolis(S=s) | ||
|
|
||
| def test_mv_proposal(self): | ||
| np.random.seed(42) | ||
| cov = np.random.randn(5, 5) | ||
| cov = cov.dot(cov.T) | ||
| prop = MultivariateNormalProposal(cov) | ||
| samples = np.array([prop() for _ in range(10000)]) | ||
| npt.assert_allclose(np.cov(samples.T), cov, rtol=0.2) | ||
|
|
||
|
|
||
| class TestCompoundStep: | ||
| samplers = (Metropolis, Slice, HamiltonianMC, NUTS, DEMetropolis) | ||
|
|
||
|
|
@@ -383,6 +358,31 @@ def test_parallelized_chains_are_random(self): | |
|
|
||
|
|
||
| class TestMetropolis: | ||
| def test_proposal_choice(self): | ||
| with aesara.config.change_flags(mode=fast_unstable_sampling_mode): | ||
| _, model, _ = mv_simple() | ||
| with model: | ||
| initial_point = model.initial_point() | ||
| initial_point_size = sum(initial_point[n.name].size for n in model.value_vars) | ||
|
|
||
| s = np.ones(initial_point_size) | ||
| sampler = Metropolis(S=s) | ||
| assert isinstance(sampler.proposal_dist, NormalProposal) | ||
| s = np.diag(s) | ||
| sampler = Metropolis(S=s) | ||
| assert isinstance(sampler.proposal_dist, MultivariateNormalProposal) | ||
| s[0, 0] = -s[0, 0] | ||
| with pytest.raises(np.linalg.LinAlgError): | ||
| sampler = Metropolis(S=s) | ||
|
|
||
| def test_mv_proposal(self): | ||
| np.random.seed(42) | ||
| cov = np.random.randn(5, 5) | ||
| cov = cov.dot(cov.T) | ||
| prop = MultivariateNormalProposal(cov) | ||
| samples = np.array([prop() for _ in range(10000)]) | ||
| npt.assert_allclose(np.cov(samples.T), cov, rtol=0.2) | ||
|
|
||
| def test_tuning_reset(self): | ||
| """Re-use of the step method instance with cores=1 must not leak tuning information between chains.""" | ||
| with Model() as pmodel: | ||
|
|
@@ -403,6 +403,40 @@ def test_tuning_reset(self): | |
| assert tuned != 0.1 | ||
| np.testing.assert_array_equal(idata.sample_stats["scaling"].sel(chain=c).values, tuned) | ||
|
|
||
| @pytest.mark.parametrize( | ||
| "batched_dist", | ||
| ( | ||
| Binomial.dist(n=5, p=0.9), # scalar case | ||
| Binomial.dist(n=np.arange(40) + 1, p=np.linspace(0.1, 0.9, 40), shape=(40,)), | ||
| Binomial.dist( | ||
| n=(np.arange(20) + 1)[::-1], | ||
| p=np.linspace(0.1, 0.9, 20), | ||
| shape=( | ||
| 2, | ||
| 20, | ||
| ), | ||
| ), | ||
| Dirichlet.dist(a=np.ones(3) * (np.arange(40) + 1)[:, None], shape=(40, 3)), | ||
| Dirichlet.dist(a=np.ones(3) * (np.arange(20) + 1)[:, None], shape=(2, 20, 3)), | ||
| ), | ||
| ) | ||
| def test_elemwise_update(self, batched_dist): | ||
|
||
| with Model() as m: | ||
| m.register_rv(batched_dist, name="batched_dist") | ||
| step = pm.Metropolis([batched_dist]) | ||
| assert step.elemwise_update == (batched_dist.ndim > 0) | ||
| trace = pm.sample(draws=1000, chains=2, step=step, random_seed=428) | ||
|
|
||
| assert az.rhat(trace).max()["batched_dist"].values < 1.1 | ||
| assert az.ess(trace).min()["batched_dist"].values > 50 | ||
|
|
||
| def test_multinomial_no_elemwise_update(self): | ||
| with Model() as m: | ||
| batched_dist = Multinomial("batched_dist", n=5, p=np.ones(4) / 4, shape=(10, 4)) | ||
| with aesara.config.change_flags(mode=fast_unstable_sampling_mode): | ||
| step = pm.Metropolis([batched_dist]) | ||
| assert not step.elemwise_update | ||
|
|
||
|
|
||
| class TestDEMetropolisZ: | ||
| def test_tuning_lambda_sequential(self): | ||
|
|
@@ -1217,8 +1251,6 @@ def perform(self, node, inputs, outputs): | |
| mout = [] | ||
| coarse_models = [] | ||
|
|
||
| rng = np.random.RandomState(seed) | ||
|
|
||
| with Model() as coarse_model_0: | ||
| if aesara.config.floatX == "float32": | ||
| Q = Data("Q", np.float32(0.0)) | ||
|
|
@@ -1236,8 +1268,6 @@ def perform(self, node, inputs, outputs): | |
|
|
||
| coarse_models.append(coarse_model_0) | ||
|
|
||
| rng = np.random.RandomState(seed) | ||
|
|
||
| with Model() as coarse_model_1: | ||
| if aesara.config.floatX == "float32": | ||
| Q = Data("Q", np.float32(0.0)) | ||
|
|
@@ -1255,8 +1285,6 @@ def perform(self, node, inputs, outputs): | |
|
|
||
| coarse_models.append(coarse_model_1) | ||
|
|
||
| rng = np.random.RandomState(seed) | ||
|
|
||
| with Model() as model: | ||
| if aesara.config.floatX == "float32": | ||
| Q = Data("Q", np.float32(0.0)) | ||
|
|
@@ -1314,8 +1342,9 @@ def perform(self, node, inputs, outputs): | |
| (nchains, ndraws * nsub) | ||
| ) | ||
| Q_2_1 = np.concatenate(trace.get_sampler_stats("Q_2_1")).reshape((nchains, ndraws)) | ||
| assert Q_1_0.mean(axis=1) == 0.0 | ||
| assert Q_2_1.mean(axis=1) == 0.0 | ||
| # This used to be a strict zero equality! | ||
| assert np.isclose(Q_1_0.mean(axis=1), 0.0, atol=1e-4) | ||
| assert np.isclose(Q_2_1.mean(axis=1), 0.0, atol=1e-4) | ||
|
||
|
|
||
|
|
||
| class TestRVsAssignmentSteps: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.