Skip to content

Commit 183bc5f

Browse files
Bruschkovtwiecki
authored andcommitted
Text backend does not support tallying a subset of variables (raises KeyError) #2560 (#3492)
1 parent e7d28f4 commit 183bc5f

File tree

3 files changed

+33
-5
lines changed

3 files changed

+33
-5
lines changed

pymc3/backends/text.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717
"""
1818
from glob import glob
1919
import os
20+
import re
2021
import pandas as pd
2122

2223
from ..backends import base, ndarray
2324
from . import tracetab as ttab
2425
from ..theanof import floatX
26+
from ..model import modelcontext
2527

2628

2729
class Text(base.BaseTrace):
@@ -112,7 +114,6 @@ def _load_df(self):
112114
if "float" in str(dtype):
113115
self.df[key] = floatX(self.df[key])
114116

115-
116117
def __len__(self):
117118
if self.filename is None:
118119
return 0
@@ -178,13 +179,25 @@ def load(name, model=None):
178179
straces = []
179180
for f in files:
180181
chain = int(os.path.splitext(f)[0].rsplit('-', 1)[1])
181-
strace = Text(name, model=model)
182+
model_vars_in_chain = _parse_chain_vars(f, model)
183+
strace = Text(name, model=model, vars=model_vars_in_chain)
182184
strace.chain = chain
183185
strace.filename = f
184186
straces.append(strace)
185187
return base.MultiTrace(straces)
186188

187189

190+
def _parse_chain_vars(filepath, model):
191+
with open(filepath) as f:
192+
header = f.readline().split("\n", 1)[0]
193+
shape_pattern = re.compile(r"__\d+_\d+")
194+
chain_vars = [shape_pattern.split(v)[0] for v in header.split(",")]
195+
chain_vars = list(set(chain_vars))
196+
m = modelcontext(model)
197+
model_vars_in_chain = [v for v in m.unobserved_RVs if v.name in chain_vars]
198+
return model_vars_in_chain
199+
200+
188201
def dump(name, trace, chains=None):
189202
"""Store values from NDArray trace as CSV files.
190203

pymc3/tests/backend_fixtures.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,20 @@ class ModelBackendSampledTestCase:
124124
125125
Children may define
126126
- sampler_vars
127+
- write_partial_chain
127128
"""
128129
@classmethod
129130
def setup_class(cls):
130131
cls.test_point, cls.model, _ = models.beta_bernoulli(cls.shape)
132+
133+
if hasattr(cls, 'write_partial_chain') and cls.write_partial_chain is True:
134+
cls.chain_vars = cls.model.unobserved_RVs[1:]
135+
else:
136+
cls.chain_vars = cls.model.unobserved_RVs
137+
131138
with cls.model:
132-
strace0 = cls.backend(cls.name)
133-
strace1 = cls.backend(cls.name)
139+
strace0 = cls.backend(cls.name, vars=cls.chain_vars)
140+
strace1 = cls.backend(cls.name, vars=cls.chain_vars)
134141

135142
if not hasattr(cls, 'sampler_vars'):
136143
cls.sampler_vars = None
@@ -459,7 +466,7 @@ def test_values(self):
459466
trace = self.mtrace
460467
dumped = self.dumped
461468
for chain in trace.chains:
462-
for varname in self.test_point.keys():
469+
for varname in self.chain_vars:
463470
data = trace.get_values(varname, chains=[chain])
464471
dumped_data = dumped.get_values(varname, chains=[chain])
465472
npt.assert_equal(data, dumped_data)

pymc3/tests/test_text_backend.py

+8
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,14 @@ class TestTextDumpLoad(bf.DumpLoadTestCase):
6262
shape = (2, 3)
6363

6464

65+
class TestTextDumpLoadWithPartialChain(bf.DumpLoadTestCase):
66+
backend = text.Text
67+
load_func = staticmethod(text.load)
68+
name = 'text-db'
69+
shape = (2, 3)
70+
write_partial_chain = True
71+
72+
6573
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
6674
class TestTextDumpFunction(bf.BackendEqualityTestCase):
6775
backend0 = backend1 = ndarray.NDArray

0 commit comments

Comments
 (0)