diff --git a/2560.py b/2560.py new file mode 100644 index 0000000000..2f085baf0d --- /dev/null +++ b/2560.py @@ -0,0 +1,15 @@ +import pymc3 as pm +with pm.Model() as model: + a = pm.Gamma('a', mu=10.0, sd=2.0) + b = pm.Gamma('b', mu=10.0, sd=2.0) + + trace = pm.sample(trace=[model.a, model.a_log__]) + assert len(trace.varnames) == 2 +# you have to provide 4 vars: +# trace = pm.sample(trace=[model.a, model.a_log__, model.b, model.b_log__]) + + pm.backends.text.dump('trace.text', trace) + + loaded = pm.backends.text.load('trace.text') + print("loaded: ", loaded) + x = loaded[0] #!!! Will throw a KeyError looking for 'b_log__' diff --git a/pymc3/backends/text.py b/pymc3/backends/text.py index d5bec95639..3b84a72017 100644 --- a/pymc3/backends/text.py +++ b/pymc3/backends/text.py @@ -19,6 +19,7 @@ import os import pandas as pd +from ..model import modelcontext from ..backends import base, ndarray from . import tracetab as ttab from ..theanof import floatX @@ -184,7 +185,7 @@ def load(name, model=None): return base.MultiTrace(straces) -def dump(name, trace, chains=None): +def dump(name, trace, model=None, chains=None): """Store values from NDArray trace as CSV files. Parameters @@ -200,6 +201,12 @@ def dump(name, trace, chains=None): os.mkdir(name) if chains is None: chains = trace.chains + model = modelcontext(model) + model_vars = sorted([x.name for x in model.unobserved_RVs]) + trace_vars = sorted(trace.varnames) + + if model_vars != trace_vars: + raise ValueError('Variables mismatch: model_vars=' + str(model_vars) + ", trace_vars=" + str(trace_vars)) for chain in chains: filename = os.path.join(name, 'chain-{}.csv'.format(chain))