diff --git a/examples/slcp_snl_masked_autoregressive.py b/examples/slcp_snl_masked_autoregressive.py
index 6e75fba..1024185 100644
--- a/examples/slcp_snl_masked_autoregressive.py
+++ b/examples/slcp_snl_masked_autoregressive.py
@@ -15,7 +15,6 @@
 from jax import numpy as jnp
 from jax import random
 from jax import scipy as jsp
-from jax import vmap
 from surjectors import Chain, TransformedDistribution
 from surjectors.bijectors.masked_autoregressive import MaskedAutoregressive
 from surjectors.bijectors.permutation import Permutation
@@ -155,7 +154,16 @@ def run(use_surjectors):
     snl = SNL(fns, model)
     optimizer = optax.adam(1e-3)
     params, info = snl.fit(
-        random.PRNGKey(23), y_observed, optimizer, n_rounds=3, sampler="slice"
+        random.PRNGKey(23),
+        y_observed,
+        optimizer,
+        n_rounds=5,
+        sampler="slice",
+        n_thin=0,
+        max_n_iter=1,
+        n_samples=100,
+        n_warmup=50,
+        n_simulations_per_round=10,
     )
 
     snl_samples, _ = snl.sample_posterior(
@@ -170,8 +178,8 @@ def log_density_fn(theta, y):
         lp = jnp.sum(prior_lp) + jnp.sum(likelihood_lp)
         return lp
 
     log_density_partial = partial(log_density_fn, y=y_observed)
-    log_density = lambda x: vmap(log_density_partial)(x)
+    log_density = lambda x: log_density_partial(**x)
 
     rng_seq = hk.PRNGSequence(12)
     slice_samples = sample_with_slice(
@@ -204,6 +213,6 @@ def log_density_fn(theta, y):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--use-surjectors", action="store_true", default=True)
+    parser.add_argument("--use-surjectors", action="store_true", default=False)
     args = parser.parse_args()
     run(args.use_surjectors)
diff --git a/sbijax/mcmc/slice.py b/sbijax/mcmc/slice.py
index 1bb1f8f..d2af50b 100644
--- a/sbijax/mcmc/slice.py
+++ b/sbijax/mcmc/slice.py
@@ -1,6 +1,9 @@
-import distrax
+import jax
+import numpy as np
 import tensorflow_probability.substrates.jax as tfp
 
+from sbijax.mcmc.slice_sampler import slice_sampler
+
 
 # pylint: disable=too-many-arguments
 def sample_with_slice(
@@ -10,9 +13,14 @@ def sample_with_slice(
     n_samples,
     n_warmup,
     prior,
-    n_thin=2,
+    n_thin=0,
     n_doubling=5,
-    step_size=1,
+    implementation="custom",
+    **kwargs,
 ):
     """
     Sample from a distribution using the slice sampler.
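+
+    Dispatches to a handwritten implementation of slice sampling when
+    ``implementation`` is "custom", and to TFP's ``SliceSampler`` kernel
+    otherwise.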
@@ -36,7 +40,36 @@ def sample_with_slice(
     a JAX array of dimension n_samples \times n_chains \times len_theta
     """
-    initial_states = _slice_init(rng_seq, n_chains, prior)
+    if implementation == "custom":
+        fn = _custom_slice
+    else:
+        fn = _tfp_slice
+    return fn(
+        rng_seq,
+        lp,
+        n_chains,
+        n_samples,
+        n_warmup,
+        prior,
+        n_thin,
+        n_doubling,
+        **kwargs,
+    )
+
+
+# pylint: disable=too-many-arguments
+def _tfp_slice(
+    rng_seq,
+    lp,
+    n_chains,
+    n_samples,
+    n_warmup,
+    prior,
+    n_thin=0,
+    n_doubling=5,
+    step_size=1.0,
+):
+    initial_states = prior(seed=next(rng_seq), sample_shape=(n_chains,))
     samples = tfp.mcmc.sample_chain(
         num_results=n_samples,
         current_state=initial_states,
@@ -51,8 +84,62 @@
     return samples
 
 
+# pylint: disable=too-many-arguments
+def _custom_slice(
+    rng_seq,
+    lp,
+    n_chains,
+    n_samples,
+    n_warmup,
+    prior,
+    n_thin=0,
+    n_doubling=5,
+    **kwargs,
+):
+    def _inference_loop(rng_key, kernel, initial_state, n_samples):
+        @jax.jit
+        def _step(states, rng_key):
+            keys = jax.random.split(rng_key, n_chains)
+            states = jax.vmap(kernel)(keys, states)
+            return states, states
+
+        sampling_keys = jax.random.split(rng_key, n_samples)
+        _, states = jax.lax.scan(_step, initial_state, sampling_keys)
+        return states
+
+    initial_states, kernel = _slice_init(
+        rng_seq, prior, n_chains, lp, n_doubling
+    )
+
+    n_total_samples = _minimal_sample_size_with_thinning(
+        n_samples, n_warmup, n_thin
+    )
+    states = _inference_loop(
+        next(rng_seq), kernel, initial_states, n_total_samples
+    )
+    _ = states.position["theta"].block_until_ready()
+    thetas = states.position["theta"][n_warmup:, :, :]
+    # thinning: keep the n_thin-th post-warmup draw as the first sample
+    # and every (n_thin + 1)-th draw after that
+    thetas = thetas[n_thin :: (n_thin + 1), ...]
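+    # e.g. with n_thin=2 this keeps post-warmup draws 2, 5, 8, ...;
+    # _minimal_sample_size_with_thinning guarantees that at least
+    # n_samples - n_warmup draws survive the thinning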
+    return thetas
+
+
 # pylint: disable=missing-function-docstring
-def _slice_init(rng_seq, n_chains, prior: distrax.Distribution):
+def _slice_init(rng_seq, prior, n_chains, lp, n_doublings):
     initial_positions = prior(seed=next(rng_seq), sample_shape=(n_chains,))
+    initial_positions = {"theta": initial_positions}
+    init, kernel = slice_sampler(lp, n_doublings)
+    initial_states = jax.vmap(init)(initial_positions)
+    return initial_states, kernel
+
 
-    return initial_positions
+def _minimal_sample_size_with_thinning(n_samples, n_warmup, n_thin):
+    # smallest number of draws that still leaves n_samples - n_warmup
+    # samples after warmup removal and thinning
+    n_effective_samples = n_samples - n_warmup
+    n_to_draw = n_effective_samples
+    while True:
+        n_returned = len(np.arange(n_to_draw)[n_thin :: (n_thin + 1)])
+        if n_returned >= n_effective_samples:
+            break
+        n_to_draw += 1
+    return n_warmup + n_to_draw
diff --git a/sbijax/mcmc/slice_sampler.py b/sbijax/mcmc/slice_sampler.py
new file mode 100644
index 0000000..c51d155
--- /dev/null
+++ b/sbijax/mcmc/slice_sampler.py
@@ -0,0 +1,261 @@
+from typing import Callable, NamedTuple
+
+import jax
+import jax.flatten_util
+from blackjax.base import MCMCSamplingAlgorithm
+from blackjax.types import PRNGKey, PyTree
+from jax import numpy as jnp
+from jax import random
+
+
+class SliceState(NamedTuple):
+    position: PyTree
+    logdensity: PyTree
+    widths: PyTree
+    n: jnp.ndarray
+
+
+def init(position: PyTree, logdensity_fn: Callable):
+    logdensity = logdensity_fn(position)
+    # start with an interval width estimate of 1 for every coordinate
+    widths = jax.tree_map(lambda x: jnp.full(x.shape, 1.0), position)
+    return SliceState(
+        position, jnp.atleast_1d(logdensity), widths, jnp.atleast_1d(0.0)
+    )
+
+
+def kernel(n_doublings: int, doubling_procedure) -> Callable:
+    if doubling_procedure == "tfp":
+        db_fn = _tfp_doubling_fn
+    else:
+        db_fn = _doubling_fn
+
+    def one_step(rng_key: PRNGKey, state: SliceState, logdensity_fn: Callable):
+        proposal_generator = slice_proposal(logdensity_fn, n_doublings, db_fn)
+        return proposal_generator(rng_key, state)
+
+    return one_step
+
+
+class slice_sampler:
+    init = staticmethod(init)
+    kernel = staticmethod(kernel)
+
+    def __new__(  # type: ignore[misc]
+        cls,
+        logdensity_fn: Callable,
+        n_doublings: int = 5,
+        doubling_procedure="tfp",
+    ) -> MCMCSamplingAlgorithm:
+        step = cls.kernel(n_doublings, doubling_procedure)
+
+        def init_fn(position: PyTree):
+            return cls.init(position, logdensity_fn)
+
+        def step_fn(rng_key: PRNGKey, state):
+            return step(rng_key, state, logdensity_fn)
+
+        return MCMCSamplingAlgorithm(init_fn, step_fn)
+
+
+def slice_proposal(logdensity_fn: Callable, n_doublings: int, db_fn):
+    def generate(rng_key, state):
+        order_key, rng_key = random.split(rng_key)
+        n = state.n[0]
+        positions, unravel_fn = jax.flatten_util.ravel_pytree(state.position)
+        widths, _ = jax.flatten_util.ravel_pytree(state.widths)
+
+        def conditional_proposal(rng_key, idx):
+            return _sample_conditionally(
+                rng_key,
+                logdensity_fn,
+                idx,
+                positions,
+                widths,
+                n_doublings,
+                db_fn,
+            )
+
+        def body_fn(carry, rn):
+            seed, idx = rn
+            positions, widths = carry
+            # univariate slice move on coordinate idx; the bracket width
+            # found for it refreshes the width estimate as a running average
+            xi, wi = conditional_proposal(seed, idx)
+            positions = positions.at[idx].set(xi)
+            nw = widths[idx] + (wi - widths[idx]) / (n + 1)
+            widths = widths.at[idx].set(nw)
+            return (positions, widths), (positions, widths)
+
+        # update the coordinates in random order
+        order = random.choice(
+            order_key,
+            jnp.arange(len(positions)),
+            shape=(len(positions),),
+            replace=False,
+        )
+
+        keys = random.split(rng_key, len(positions))
+        (new_positions, new_widths), _ = jax.lax.scan(
+            body_fn, (positions, widths), (keys, order)
+        )
+
+        new_positions = unravel_fn(new_positions)
+        new_widths = unravel_fn(new_widths)
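+        # carry the updated running-average widths and the incremented
+        # counter n over to the next state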
+        new_state = SliceState(
+            new_positions,
+            jnp.atleast_1d(logdensity_fn(new_positions)),
+            new_widths,
+            jnp.atleast_1d(n + 1.0),
+        )
+        return new_state
+
+    return generate
+
+
+def _sample_conditionally(
+    seed, logdensity_fn, idx, positions, widths, n_doublings, db_fn
+):
+    def cond_lp_fn(xi_to_set):
+        # conditional log-density of coordinate idx; assumes the position
+        # pytree has a single leaf called "theta"
+        return logdensity_fn({"theta": positions.at[idx].set(xi_to_set)})
+
+    key, seed1, seed2 = random.split(seed, 3)
+    x0, w0 = positions[idx], widths[idx]
+    # slice threshold y = log f(x0) - Exp(1); bracket the slice by
+    # doubling, then draw a point from the bracket by shrinkage
+    y = cond_lp_fn(x0) - random.exponential(key)
+    left, right, _ = db_fn(seed1, y, x0, cond_lp_fn, w0, n_doublings)
+    x1 = _shrinkage_fn(seed2, y, x0, cond_lp_fn, left, right, w0)
+    return x1, right - left
+
+
+def _doubling_fn(rng, y, x0, cond_lp_fn, w, n_doublings):
+    """
+    Implementation according to Fig. 4 in [1].
+
+    References
+    ----------
+    [1] Radford Neal, "Slice Sampling", 2003
+    """
+    key1, key2 = random.split(rng)
+    left = x0 - w * random.uniform(key1)
+    right = left + w
+    K = n_doublings
+
+    def cond_fn(state):
+        left, right, K, _ = state
+        return jnp.logical_and(
+            K > 0,
+            jnp.logical_or(y < cond_lp_fn(left), y < cond_lp_fn(right)),
+        )
+
+    def body_fn(state):
+        left, right, K, seed = state
+        rng, seed = random.split(seed)
+        v = random.uniform(rng)
+        left = jnp.where(v < 0.5, 2 * left - right, left)
+        right = jnp.where(v < 0.5, right, 2 * right - left)
+        return left, right, K - 1, seed
+
+    left, right, _, _ = jax.lax.while_loop(
+        cond_fn, body_fn, (left, right, K, key2)
+    )
+    return left, right, None
+
+
+def _tfp_doubling_fn(rng, y, x0, cond_lp_fn, w, n_doublings):
+    """
+    Vectorised variant of the doubling procedure of Fig. 4 in [1],
+    following the TFP implementation.
+
+    References
+    ----------
+    [1] Radford Neal, "Slice Sampling", 2003
+    """
+    key1, key2 = random.split(rng)
+    left = x0 - w * random.uniform(key1)
+
+    K = n_doublings + 1
+
+    left_expands = random.bernoulli(key2, 0.5, (K,))
+    width_multipliers = 2 ** jnp.arange(0, K, dtype=jnp.int32)
+    widths = width_multipliers * w
+    left_increments = jnp.cumsum(widths * left_expands)
+
+    lefts = left - left_increments
+    rights = left + widths
+
+    left_lps = jax.vmap(cond_lp_fn)(lefts)
+    right_lps = jax.vmap(cond_lp_fn)(rights)
+    left_oks = left_lps < y
+    right_oks = right_lps < y
+    both_ok = jnp.logical_and(left_oks, right_oks)
+    # take the first (smallest) doubling for which both interval
+    # endpoints lie outside the slice
+    best_interval_idx = _best_interval(both_ok.astype(jnp.int32))
+
+    return (
+        lefts[best_interval_idx],
+        rights[best_interval_idx],
+        both_ok[best_interval_idx],
+    )
+
+
+def _best_interval(x):
+    k = x.shape[0]
+    mults = jnp.arange(2 * k, k, -1, dtype=x.dtype)
+    shifts = jnp.arange(k, dtype=x.dtype)
+    indices = jnp.argmax(mults * x + shifts).astype(x.dtype)
+    return indices
+
+
+def _shrinkage_fn(seed, y, x0, cond_lp_fn, left, right, w):
+    def cond_fn(state):
+        *_, found = state
+        return jnp.logical_not(found)
+
+    def body_fn(state):
+        x1, left, right, seed, _ = state
+        key, seed = random.split(seed)
+        v = random.uniform(key)
+        x1 = left + v * (right - left)
+
+        found = jnp.logical_and(
+            y < cond_lp_fn(x1),
+            _accept_fn(y, x1, x0, cond_lp_fn, left, right, w),
+        )
+
+        left = jnp.where(x1 < x0, x1, left)
+        right = jnp.where(x1 >= x0, x1, right)
+
+        return x1, left, right, seed, found
+
+    x1, left, right, seed, _ = jax.lax.while_loop(
+        cond_fn, body_fn, (x0, left, right, seed, False)
+    )
+    return x1
+
+
+def _accept_fn(y, x1, x0, cond_lp_fn, left, right, w):
+    def cond_fn(state):
+        _, _, left, right, w, _, is_acceptable = state
+        return jnp.logical_and(right - left > 1.1 * w, is_acceptable)
+
+    def body_fn(state):
+        x1, x0, left, right, w, D, _ = state
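+        # work back through the doubling sequence (Neal 2003, Fig. 6):
+        # D records that x0 and x1 ever fell into different halves; if they
+        # did and both ends of the half containing x1 lie outside the
+        # slice, x1 is not acceptable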
+        mid = (left + right) / 2
+        D = jnp.logical_or(
+            jnp.logical_or(
+                jnp.logical_and(x0 < mid, x1 >= mid),
+                jnp.logical_and(x0 >= mid, x1 < mid),
+            ),
+            D,
+        )
+        right = jnp.where(x1 < mid, mid, right)
+        left = jnp.where(x1 >= mid, mid, left)
+
+        left_is_not_acceptable = y >= cond_lp_fn(left)
+        right_is_not_acceptable = y >= cond_lp_fn(right)
+        interval_is_not_acceptable = jnp.logical_and(
+            left_is_not_acceptable, right_is_not_acceptable
+        )
+        is_still_acceptable = jnp.logical_not(
+            jnp.logical_and(D, interval_is_not_acceptable)
+        )
+        return x1, x0, left, right, w, D, is_still_acceptable
+
+    *_, is_acceptable = jax.lax.while_loop(
+        cond_fn, body_fn, (x1, x0, left, right, w, False, True)
+    )
+    return is_acceptable
diff --git a/sbijax/mcmc/test_slice.py b/sbijax/mcmc/test_slice.py
new file mode 100644
index 0000000..3fff992
--- /dev/null
+++ b/sbijax/mcmc/test_slice.py
@@ -0,0 +1,139 @@
+# pylint: skip-file
+
+from functools import partial
+from timeit import default_timer as timer
+
+import distrax
+import jax
+import matplotlib.pyplot as plt
+import seaborn as sns
+import tensorflow_probability.substrates.jax as tfp
+from jax import numpy as jnp
+from jax import random
+from jax import scipy as jsp
+
+from sbijax.mcmc.slice_sampler import slice_sampler
+
+p = distrax.Independent(distrax.Uniform(jnp.full(5, -3.0), jnp.full(5, 3.0)), 1)
+
+
+def likelihood_fn(theta, y):
+    mu = jnp.tile(theta[:2], 4)
+    s1, s2 = theta[2] ** 2, theta[3] ** 2
+    corr = s1 * s2 * jnp.tanh(theta[4])
+    cov = jnp.array([[s1**2, corr], [corr, s2**2]])
+    cov = jsp.linalg.block_diag(*[cov for _ in range(4)])
+    dist = distrax.MultivariateNormalFullCovariance(mu, cov)
+    return dist.log_prob(y)
+
+
+def log_density_fn(theta, y):
+    prior_lp = p.log_prob(theta)
+    likelihood_lp = likelihood_fn(theta, y)
+
+    lp = jnp.sum(prior_lp) + jnp.sum(likelihood_lp)
+    return lp
+
+
+len_theta = 5
+y_observed = jnp.array(
+    [
+        [
+            -0.9707123,
+            -2.9461224,
+            -0.4494722,
+            -3.4231849,
+            -0.13285634,
+            -3.364017,
+            -0.85367596,
+            -2.4271638,
+        ]
+    ]
+)
+
+
+def inference_loop(rng_key, kernel, initial_state, n_samples, n_chains):
+    @jax.jit
+    def _step(states, rng_key):
+        keys = jax.random.split(rng_key, n_chains)
+        states = jax.vmap(kernel)(keys, states)
+        return states, states
+
+    sampling_keys = jax.random.split(rng_key, n_samples)
+    _, states = jax.lax.scan(_step, initial_state, sampling_keys)
+    return states
+
+
+# pylint: disable=too-many-locals,invalid-name,redefined-outer-name
+def run_slice(n_samples=30000, n_warmup=10000, n_chains=8, dbl_str="tfp"):
+    log_density_partial = partial(log_density_fn, y=y_observed)
+    log_density = lambda x: log_density_partial(**x)
+
+    init, kernel = slice_sampler(log_density, 2, dbl_str)
+    initial_positions = p.sample(
+        seed=random.PRNGKey(1), sample_shape=(n_chains,)
+    )
+    initial_positions = {"theta": initial_positions}
+
+    initial_states = jax.vmap(init)(initial_positions)
+    states = inference_loop(
+        random.PRNGKey(23), kernel, initial_states, n_samples, n_chains
+    )
+    samples = jax.block_until_ready(states.position["theta"])
+    samples = samples[n_warmup:, ...].reshape(-1, len_theta)
+
+    return samples
+
+
+def run_tfp_slice(n_samples=30000, n_warmup=10000, n_chains=8):
+    log_density_partial = partial(log_density_fn, y=y_observed)
+    log_density = lambda x: jax.vmap(log_density_partial)(x)
+
+    initial_states = p.sample(seed=random.PRNGKey(1), sample_shape=(n_chains,))
+    samples = tfp.mcmc.sample_chain(
+        num_results=n_samples,
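+        # num_burnin_steps below already discards the warmup draws, so the
+        # returned chain is not sliced again afterwards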
+        current_state=initial_states,
+        num_steps_between_results=0,
+        kernel=tfp.mcmc.SliceSampler(
+            log_density, step_size=0.1, max_doublings=2
+        ),
+        num_burnin_steps=n_warmup,
+        trace_fn=None,
+        seed=random.PRNGKey(23),
+    )
+
+    samples = samples.reshape(-1, len_theta)
+    return samples
+
+
+start = timer()
+s1 = run_slice(dbl_str="custom")
+end = timer()
+print(f"Custom slice sampler time: {end - start}")
+
+start = timer()
+s2 = run_slice(dbl_str="tfp")
+end = timer()
+print(f"Custom slice sampler with TFP doubling time: {end - start}")
+
+start = timer()
+s3 = run_tfp_slice()
+end = timer()
+print(f"TFP slice sampler time: {end - start}")
+
+
+fig, axes = plt.subplots(len_theta, 3)
+for i in range(len_theta):
+    sns.histplot(s1[:, i], color="darkgrey", ax=axes[i, 0])
+    sns.histplot(s2[:, i], color="darkgrey", ax=axes[i, 1])
+    sns.histplot(s3[:, i], color="darkgrey", ax=axes[i, 2])
+    axes[i, 0].set_title("Custom slice")
+    axes[i, 1].set_title("Custom slice, TFP doubling")
+    axes[i, 2].set_title("TFP slice")
+    for j in range(3):
+        axes[i, j].set_xlim(-5, 5)
+sns.despine()
+plt.tight_layout()
+plt.show()
diff --git a/sbijax/snl.py b/sbijax/snl.py
index 91f6372..668a2c3 100644
--- a/sbijax/snl.py
+++ b/sbijax/snl.py
@@ -279,7 +279,7 @@ def _joint_logdensity_fn(theta):
         if "sampler" in kwargs and kwargs["sampler"] == "slice":
 
             def lp__(theta):
-                return jax.vmap(_joint_logdensity_fn)(theta)
+                return _joint_logdensity_fn(**theta)
 
             kwargs.pop("sampler", None)
             samples = sample_with_slice(