Skip to content

Commit 9eb60eb

Browse files
committed
Make drop_warning_stat work with flat stat names for compound steps
1 parent 70cb73c commit 9eb60eb

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

pymc/util.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import functools
16+
import re
1617
import warnings
1718

1819
from collections import namedtuple
@@ -276,7 +277,12 @@ def drop_warning_stat(idata: arviz.InferenceData) -> arviz.InferenceData:
276277
nidata = arviz.InferenceData(attrs=idata.attrs)
277278
for gname, group in idata.items():
278279
if "sample_stat" in gname:
279-
group = group.drop_vars(names=["warning", "warning_dim_0"], errors="ignore")
280+
warning_vars = [
281+
name
282+
for name in group.data_vars
283+
if name == "warning" or re.match(r"sampler_\d+__warning", str(name))
284+
]
285+
group = group.drop_vars(names=[*warning_vars, "warning_dim_0"], errors="ignore")
280286
nidata.add_groups({gname: group}, coords=group.coords, dims=group.dims)
281287
return nidata
282288

0 commit comments

Comments
 (0)