|
26 | 26 | from aesara.link.jax.dispatch import jax_funcify |
27 | 27 |
|
28 | 28 | from pymc import Model, modelcontext |
29 | | -from pymc.aesaraf import compile_rv_inplace |
| 29 | +from pymc.aesaraf import compile_rv_inplace, inputvars |
| 30 | +from pymc.util import get_default_varnames |
30 | 31 |
|
31 | 32 | warnings.warn("This module is experimental.") |
32 | 33 |
|
@@ -101,13 +102,19 @@ def sample_numpyro_nuts( |
101 | 102 | target_accept=0.8, |
102 | 103 | random_seed=10, |
103 | 104 | model=None, |
| 105 | + var_names=None, |
104 | 106 | progress_bar=True, |
105 | 107 | keep_untransformed=False, |
106 | 108 | ): |
107 | 109 | from numpyro.infer import MCMC, NUTS |
108 | 110 |
|
109 | 111 | model = modelcontext(model) |
110 | 112 |
|
| 113 | + if var_names is None: |
| 114 | + var_names = model.unobserved_value_vars |
| 115 | + |
| 116 | + vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed)) |
| 117 | + |
111 | 118 | tic1 = pd.Timestamp.now() |
112 | 119 | print("Compiling...", file=sys.stdout) |
113 | 120 |
|
@@ -143,45 +150,28 @@ def sample_numpyro_nuts( |
143 | 150 | seed = jax.random.PRNGKey(random_seed) |
144 | 151 | map_seed = jax.random.split(seed, chains) |
145 | 152 |
|
146 | | - pmap_numpyro.run(map_seed, init_params=init_state_batched, extra_fields=("num_steps",)) |
| 153 | + if chains == 1: |
| 154 | + pmap_numpyro.run(seed, init_params=init_state, extra_fields=("num_steps",)) |
| 155 | + else: |
| 156 | + pmap_numpyro.run(map_seed, init_params=init_state_batched, extra_fields=("num_steps",)) |
| 157 | + |
147 | 158 | raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True) |
148 | 159 |
|
149 | 160 | tic3 = pd.Timestamp.now() |
150 | 161 | print("Sampling time = ", tic3 - tic2, file=sys.stdout) |
151 | 162 |
|
152 | 163 | print("Transforming variables...", file=sys.stdout) |
153 | | - mcmc_samples = [] |
154 | | - for i, (value_var, raw_samples) in enumerate(zip(model.value_vars, raw_mcmc_samples)): |
155 | | - raw_samples = at.constant(np.asarray(raw_samples)) |
156 | | - |
157 | | - rv = model.values_to_rvs[value_var] |
158 | | - transform = getattr(value_var.tag, "transform", None) |
159 | | - |
160 | | - if transform is not None: |
161 | | - # TODO: This will fail when the transformation depends on another variable |
162 | | - # such as in interval transform with RVs as edges |
163 | | - trans_samples = transform.backward(raw_samples, *rv.owner.inputs) |
164 | | - trans_samples.name = rv.name |
165 | | - mcmc_samples.append(trans_samples) |
166 | | - |
167 | | - if keep_untransformed: |
168 | | - raw_samples.name = value_var.name |
169 | | - mcmc_samples.append(raw_samples) |
170 | | - else: |
171 | | - raw_samples.name = rv.name |
172 | | - mcmc_samples.append(raw_samples) |
173 | | - |
174 | | - mcmc_varnames = [var.name for var in mcmc_samples] |
175 | | - mcmc_samples = compile_rv_inplace( |
176 | | - [], |
177 | | - mcmc_samples, |
178 | | - mode="JAX", |
179 | | - )() |
| 164 | + mcmc_samples = {} |
| 165 | + for v in vars_to_sample: |
| 166 | + fgraph = FunctionGraph(model.value_vars, [v], clone=False) |
| 167 | + jax_fn = jax_funcify(fgraph) |
| 168 | + result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0] |
| 169 | + mcmc_samples[v.name] = result |
180 | 170 |
|
181 | 171 | tic4 = pd.Timestamp.now() |
182 | 172 | print("Transformation time = ", tic4 - tic3, file=sys.stdout) |
183 | 173 |
|
184 | | - posterior = {k: v for k, v in zip(mcmc_varnames, mcmc_samples)} |
| 174 | + posterior = mcmc_samples |
185 | 175 | az_trace = az.from_dict(posterior=posterior) |
186 | 176 |
|
187 | 177 | return az_trace |
0 commit comments