diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index ebdaf3c3e..4e59049d3 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -300,7 +300,7 @@ def __init__(self, zerosum_axes): @staticmethod def extend_axis(array, axis): - n = (array.shape[axis] + 1).astype("floatX") + n = pt.cast(array.shape[axis] + 1, "floatX") sum_vals = array.sum(axis, keepdims=True) norm = sum_vals / (pt.sqrt(n) + n) fill_val = norm - sum_vals / pt.sqrt(n) @@ -312,7 +312,7 @@ def extend_axis(array, axis): def extend_axis_rev(array, axis): normalized_axis = normalize_axis_tuple(axis, array.ndim)[0] - n = array.shape[normalized_axis].astype("floatX") + n = pt.cast(array.shape[normalized_axis], "floatX") last = pt.take(array, [-1], axis=normalized_axis) sum_vals = -last * pt.sqrt(n) diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py index e28052bab..176bc24ce 100644 --- a/tests/distributions/test_transform.py +++ b/tests/distributions/test_transform.py @@ -170,6 +170,17 @@ def test_sum_to_1(): ) +def test_zerosumtransform(): + zst = tr.ZeroSumTransform([0]) + + # Check numpy input works, as it is not always converted to pytensor before + # Case where it failed was when setting initvals in model + val = np.array([1, 2, 3, 4]) + zval = zst.backward(val) + assert np.allclose(zval.eval().sum(), 0.0) + assert np.allclose(zst.forward(zval).eval(), val) + + def test_log(): check_transform(tr.log, Rplusbig)