Skip to content

Commit 9b697f6

Browse files
committed
use from_pymc3(..., log_likelihood=False) and update requirements
1 parent e7bc832 commit 9b697f6

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

pymc3/backends/report.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def n_draws(self) -> typing.Optional[int]:
8787
def t_sampling(self) -> typing.Optional[float]:
8888
"""
8989
Number of seconds that the sampling procedure took.
90-
90+
9191
(Includes parallelization overhead.)
9292
"""
9393
return self._t_sampling
@@ -108,6 +108,7 @@ def _run_convergence_checks(self, trace, model):
108108
return
109109

110110
from pymc3 import rhat, ess
111+
from arviz import from_pymc3
111112

112113
valid_name = [rv.name for rv in model.free_RVs + model.deterministics]
113114
varnames = []
@@ -119,8 +120,9 @@ def _run_convergence_checks(self, trace, model):
119120
if rv_name in trace.varnames:
120121
varnames.append(rv_name)
121122

122-
self._ess = ess = ess(trace, var_names=varnames)
123-
self._rhat = rhat = rhat(trace, var_names=varnames)
123+
idata = from_pymc3(trace, log_likelihood=False)
124+
self._ess = ess = ess(idata, var_names=varnames)
125+
self._rhat = rhat = rhat(idata, var_names=varnames)
124126

125127
warnings = []
126128
rhat_max = max(val.max() for val in rhat.values())

requirements.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
arviz>=0.4.1
1+
arviz>=0.7.0
22
theano>=1.0.4
33
numpy>=1.13.0
44
scipy>=0.18.1
@@ -7,4 +7,4 @@ patsy>=0.5.1
77
fastprogress>=0.2.0
88
h5py>=2.7.0
99
typing-extensions>=3.7.4
10-
contextvars; python_version < '3.7'
10+
contextvars; python_version < '3.7'

0 commit comments

Comments
 (0)