From 8db99399a131cc7985c787cdf48b61fe3cb43088 Mon Sep 17 00:00:00 2001 From: preich Date: Tue, 24 Oct 2017 19:01:51 +0100 Subject: [PATCH 1/2] example from https://github.com/pymc-devs/pymc3/issues/2560 --- 2560.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 2560.py diff --git a/2560.py b/2560.py new file mode 100644 index 0000000000..d934327db8 --- /dev/null +++ b/2560.py @@ -0,0 +1,12 @@ +import pymc3 as pm +with pm.Model() as model: + a = pm.Gamma('a', mu=10.0, sd=2.0) + b = pm.Gamma('b', mu=10.0, sd=2.0) + + trace = pm.sample(trace=[model.a, model.a_log__]) + assert len(trace.varnames) == 2 + + pm.backends.text.dump('trace.text', trace) + + loaded = pm.backends.text.load('trace.text') + x = loaded[0] #!!! Will throw a KeyError looking for 'b_log__' From 2d3499e3486a0b5e57b6b2387d6e66eeb7b4cd93 Mon Sep 17 00:00:00 2001 From: preich Date: Tue, 24 Oct 2017 20:59:38 +0100 Subject: [PATCH 2/2] https://github.com/pymc-devs/pymc3/issues/2560 fail-fast if model vars mismatches trace vars --- 2560.py | 3 +++ pymc3/backends/text.py | 9 ++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/2560.py b/2560.py index d934327db8..2f085baf0d 100644 --- a/2560.py +++ b/2560.py @@ -5,8 +5,11 @@ trace = pm.sample(trace=[model.a, model.a_log__]) assert len(trace.varnames) == 2 +# you have to provide 4 vars: +# trace = pm.sample(trace=[model.a, model.a_log__, model.b, model.b_log__]) pm.backends.text.dump('trace.text', trace) loaded = pm.backends.text.load('trace.text') + print("loaded: ", loaded) x = loaded[0] #!!! Will throw a KeyError looking for 'b_log__' diff --git a/pymc3/backends/text.py b/pymc3/backends/text.py index d5bec95639..3b84a72017 100644 --- a/pymc3/backends/text.py +++ b/pymc3/backends/text.py @@ -19,6 +19,7 @@ import os import pandas as pd +from ..model import modelcontext from ..backends import base, ndarray from . import tracetab as ttab from ..theanof import floatX @@ -184,7 +185,7 @@ def load(name, model=None): return base.MultiTrace(straces) -def dump(name, trace, chains=None): +def dump(name, trace, model=None, chains=None): """Store values from NDArray trace as CSV files. Parameters @@ -200,6 +201,12 @@ def dump(name, trace, chains=None): os.mkdir(name) if chains is None: chains = trace.chains + model = modelcontext(model) + model_vars = sorted([x.name for x in model.unobserved_RVs]) + trace_vars = sorted(trace.varnames) + + if model_vars != trace_vars: + raise ValueError('Variables mismatch: model_vars=' + str(model_vars) + ", trace_vars=" + str(trace_vars)) for chain in chains: filename = os.path.join(name, 'chain-{}.csv'.format(chain))