Skip to content

Commit b903e57

Browse files
committed
stack array stats
1 parent c3c76c5 commit b903e57

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

pymc/backends/base.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,10 +204,18 @@ def get_sampler_stats(self, stat_name, sampler_idx=None, burn=0, thin=1):
204204
vals = np.stack(
205205
[self._get_sampler_stats(stat_name, i, burn, thin) for i in sampler_idxs], axis=-1
206206
)
207+
207208
if vals.shape[-1] == 1:
208-
return vals[..., 0]
209-
else:
210-
return vals
209+
vals = vals[..., 0]
210+
211+
if vals.dtype == np.dtype(object):
212+
try:
213+
vals = np.vstack(vals)
214+
except ValueError:
215+
# Most likely due to non-identical shapes. Just stick with the object-array.
216+
pass
217+
218+
return vals
211219

212220
def _get_sampler_stats(self, stat_name, sampler_idx, burn, thin):
213221
"""Get sampler statistics."""

0 commit comments

Comments
 (0)