Skip to content

Commit cc45303

Browse files
Always define the diverging variable in sampling generator (#3901)
1 parent eec70e8 commit cc45303

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

Diff for: pymc3/sampling.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,7 @@ def _iter_sample(
912912
step.reset_tuning()
913913
for i in range(draws):
914914
stats = None
915+
diverging = False
915916

916917
if i == 0 and hasattr(step, "iter_count"):
917918
step.iter_count = 0
@@ -927,7 +928,6 @@ def _iter_sample(
927928
else:
928929
point = step.step(point)
929930
strace.record(point)
930-
diverging = False
931931
if callback is not None:
932932
warns = getattr(step, "warnings", None)
933933
callback(trace=strace, draw=Draw(chain, i == draws, i, i < tune, stats, point, warns))

Diff for: pymc3/tests/test_text_backend.py

+5
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ def test_supports_sampler_stats(self):
2828
db = text.Text(self.name)
2929
pm.sample(20, tune=10, init=None, trace=db, cores=2)
3030

31+
def test_supports_sampler_stats_diverging(self):
32+
with pm.Model():
33+
pm.Normal("mu", mu=0, sigma=1, shape=2)
34+
pm.sample(20, tune=10, init=None, trace='text', cores=1)
35+
3136
def teardown_method(self):
3237
bf.remove_file_or_directory(self.name)
3338

0 commit comments

Comments
 (0)