File tree Expand file tree Collapse file tree 1 file changed +1
-2
lines changed Expand file tree Collapse file tree 1 file changed +1
-2
lines changed Original file line number Diff line number Diff line change @@ -114,7 +114,6 @@ def sample_numpyro_nuts(
114
114
var_names = model .unobserved_value_vars
115
115
116
116
vars_to_sample = list (get_default_varnames (var_names , include_transformed = keep_untransformed ))
117
- inputs = [model .rvs_to_values [i ] for i in model .free_RVs ]
118
117
119
118
tic1 = pd .Timestamp .now ()
120
119
print ("Compiling..." , file = sys .stdout )
@@ -164,7 +163,7 @@ def sample_numpyro_nuts(
164
163
print ("Transforming variables..." , file = sys .stdout )
165
164
mcmc_samples = {}
166
165
for v in vars_to_sample :
167
- fgraph = FunctionGraph (inputs , [v ], clone = False )
166
+ fgraph = FunctionGraph (model . value_vars , [v ], clone = False )
168
167
jax_fn = jax_funcify (fgraph )
169
168
result = jax .vmap (jax .vmap (jax_fn ))(* raw_mcmc_samples )[0 ]
170
169
mcmc_samples [v .name ] = result
You can’t perform that action at this time.
0 commit comments