Skip to content

Commit 8adacfa

Browse files
committed
Fix setting initvals on ZeroSumTransform rv
1 parent 920f043 commit 8adacfa

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

pymc/distributions/transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ def __init__(self, zerosum_axes):
300300

301301
@staticmethod
302302
def extend_axis(array, axis):
303-
n = (array.shape[axis] + 1).astype("floatX")
303+
n = pt.cast(array.shape[axis] + 1, "floatX")
304304
sum_vals = array.sum(axis, keepdims=True)
305305
norm = sum_vals / (pt.sqrt(n) + n)
306306
fill_val = norm - sum_vals / pt.sqrt(n)

0 commit comments

Comments
 (0)