Skip to content

Commit 25f9204

Browse files
committed
Fix setting initvals on ZeroSumTransform rv
1 parent 920f043 commit 25f9204

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

pymc/distributions/transforms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ def __init__(self, zerosum_axes):
300300

301301
@staticmethod
302302
def extend_axis(array, axis):
303-
n = (array.shape[axis] + 1).astype("floatX")
303+
n = pt.cast(array.shape[axis] + 1, "floatX")
304304
sum_vals = array.sum(axis, keepdims=True)
305305
norm = sum_vals / (pt.sqrt(n) + n)
306306
fill_val = norm - sum_vals / pt.sqrt(n)
@@ -312,7 +312,7 @@ def extend_axis(array, axis):
312312
def extend_axis_rev(array, axis):
313313
normalized_axis = normalize_axis_tuple(axis, array.ndim)[0]
314314

315-
n = array.shape[normalized_axis].astype("floatX")
315+
n = pt.cast(array.shape[normalized_axis], "floatX")
316316
last = pt.take(array, [-1], axis=normalized_axis)
317317

318318
sum_vals = -last * pt.sqrt(n)

tests/distributions/test_transform.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,17 @@ def test_sum_to_1():
170170
)
171171

172172

173+
def test_zerosumtransform():
174+
zst = tr.ZeroSumTransform([0])
175+
176+
# Check numpy input works, as it is not always converted to pytensor before
177+
# Case where it failed was when setting initvals in model
178+
val = np.array([1, 2, 3, 4])
179+
zval = zst.backward(val)
180+
assert np.allclose(zval.eval().sum(), 0.0)
181+
assert np.allclose(zst.forward(zval).eval(), val)
182+
183+
173184
def test_log():
174185
check_transform(tr.log, Rplusbig)
175186

0 commit comments

Comments
 (0)