Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Slice sampler #11

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions examples/slcp_snl_masked_autoregressive.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from jax import numpy as jnp
from jax import random
from jax import scipy as jsp
from jax import vmap
from surjectors import Chain, TransformedDistribution
from surjectors.bijectors.masked_autoregressive import MaskedAutoregressive
from surjectors.bijectors.permutation import Permutation
Expand Down Expand Up @@ -155,7 +154,16 @@ def run(use_surjectors):
snl = SNL(fns, model)
optimizer = optax.adam(1e-3)
params, info = snl.fit(
random.PRNGKey(23), y_observed, optimizer, n_rounds=3, sampler="slice"
random.PRNGKey(23),
y_observed,
optimizer,
n_rounds=5,
sampler="slice",
n_thin=0,
max_n_iter=1,
n_samples=100,
n_warmup=50,
n_simulations_per_round=10,
)

snl_samples, _ = snl.sample_posterior(
Expand All @@ -170,8 +178,9 @@ def log_density_fn(theta, y):
lp = jnp.sum(prior_lp) + jnp.sum(likelihood_lp)
return lp

#
log_density_partial = partial(log_density_fn, y=y_observed)
log_density = lambda x: vmap(log_density_partial)(x)
log_density = lambda x: log_density_partial(**x)

rng_seq = hk.PRNGSequence(12)
slice_samples = sample_with_slice(
Expand Down Expand Up @@ -204,6 +213,6 @@ def log_density_fn(theta, y):

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--use-surjectors", action="store_true", default=True)
parser.add_argument("--use-surjectors", action="store_true", default=False)
args = parser.parse_args()
run(args.use_surjectors)
99 changes: 93 additions & 6 deletions sbijax/mcmc/slice.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import distrax
import jax
import numpy as np
import tensorflow_probability.substrates.jax as tfp

from sbijax.mcmc.slice_sampler import slice_sampler


# pylint: disable=too-many-arguments
def sample_with_slice(
Expand All @@ -10,9 +13,10 @@ def sample_with_slice(
n_samples,
n_warmup,
prior,
n_thin=2,
n_thin=0,
n_doubling=5,
step_size=1,
implementation="custom",
**kwargs,
):
"""
Sample from a distribution using the No-U-Turn sampler.
Expand All @@ -36,7 +40,36 @@ def sample_with_slice(
a JAX array of dimension n_samples \times n_chains \times len_theta
"""

initial_states = _slice_init(rng_seq, n_chains, prior)
if implementation == "custom":
fn = _custom_slice
else:
fn = _tfp_slice
return fn(
rng_seq,
lp,
n_chains,
n_samples,
n_warmup,
prior,
n_thin,
n_doubling,
**kwargs,
)


# pylint: disable=too-many-arguments
def _tfp_slice(
rng_seq,
lp,
n_chains,
n_samples,
n_warmup,
prior,
n_thin=0,
n_doubling=5,
step_size=1.0,
):
initial_states = prior(seed=next(rng_seq), sample_shape=(n_chains,))
samples = tfp.mcmc.sample_chain(
num_results=n_samples,
current_state=initial_states,
Expand All @@ -51,8 +84,62 @@ def sample_with_slice(
return samples


# pylint: disable=too-many-arguments
def _custom_slice(
rng_seq,
lp,
n_chains,
n_samples,
n_warmup,
prior,
n_thin=0,
n_doubling=5,
**kwargs,
):
def _inference_loop(rng_key, kernel, initial_state, n_samples):
@jax.jit
def _step(states, rng_key):
keys = jax.random.split(rng_key, n_chains)
states = jax.vmap(kernel)(keys, states)
return states, states

sampling_keys = jax.random.split(rng_key, n_samples)
_, states = jax.lax.scan(_step, initial_state, sampling_keys)
return states

initial_states, kernel = _slice_init(
rng_seq, prior, n_chains, lp, n_doubling
)

n_total_samples = _minimal_sample_size_with_thinning(
n_samples, n_warmup, n_thin
)
states = _inference_loop(
next(rng_seq), kernel, initial_states, n_total_samples
)
_ = states.position["theta"].block_until_ready()
thetas = states.position["theta"][n_warmup:, :, :]
# thinning: take the n_thin-th sample as first point and then step to
# the next sample by skipping n_thin indexes, i.e.
thetas = thetas[n_thin :: (n_thin + 1), ...]
return thetas


# pylint: disable=missing-function-docstring
def _slice_init(rng_seq, n_chains, prior: distrax.Distribution):
def _slice_init(rng_seq, prior, n_chains, lp, n_doublings):
initial_positions = prior(seed=next(rng_seq), sample_shape=(n_chains,))
initial_positions = {"theta": initial_positions}
init, kernel = slice_sampler(lp, n_doublings)
initial_states = jax.vmap(init)(initial_positions)
return initial_states, kernel


return initial_positions
def _minimal_sample_size_with_thinning(n_samples, n_warmup, n_thin):
n_effective_samples = n_samples - n_warmup
n_to_draw = n_effective_samples
while True:
n_returned = len(np.arange(n_to_draw)[n_thin :: (n_thin + 1)])
if n_returned >= n_effective_samples:
break
n_to_draw += 1
return n_warmup + n_to_draw
Loading