Skip to content

Commit 6834740

Browse files
ricardoV94michaelosthege
authored andcommitted
Fix static type shape bug
1 parent 081a0b4 commit 6834740

File tree

5 files changed

+49
-6
lines changed

5 files changed

+49
-6
lines changed

pytensor/tensor/random/op.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def extract_batch_shape(p, ps, n):
191191
return shape
192192

193193
batch_shape = [
194-
s if b is False else constant(1, "int64")
194+
s if not b else constant(1, "int64")
195195
for s, b in zip(shape[:-n], p.type.broadcastable[:-n])
196196
]
197197
return batch_shape

pytensor/tensor/type.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,13 @@ def __init__(
109109
def parse_bcast_and_shape(s):
110110
if isinstance(s, (bool, np.bool_)):
111111
return 1 if s else None
112-
else:
112+
elif isinstance(s, (int, np.int_)):
113+
return int(s)
114+
elif s is None:
113115
return s
116+
raise ValueError(
117+
f"TensorType broadcastable/shape must be a boolean, integer or None, got {type(s)} {s}"
118+
)
114119

115120
self.shape = tuple(parse_bcast_and_shape(s) for s in shape)
116121
self.dtype_specs() # error checking is done there

tests/tensor/random/test_basic.py

+10
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pytensor.graph.op import get_test_value
1717
from pytensor.graph.replace import clone_replace
1818
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
19+
from pytensor.tensor import ones, stack
1920
from pytensor.tensor.random.basic import (
2021
_gamma,
2122
bernoulli,
@@ -1465,3 +1466,12 @@ def test_rebuild():
14651466
assert y_new.type.shape == (100,)
14661467
assert y_new.shape.eval({x_new: x_new_test}) == (100,)
14671468
assert y_new.eval({x_new: x_new_test}).shape == (100,)
1469+
1470+
1471+
def test_categorical_join_p_static_shape():
1472+
"""Regression test against a bug caused by misreading a numpy.bool_"""
1473+
p = ones(3) / 3
1474+
prob = stack([p, 1 - p], axis=-1)
1475+
assert prob.type.shape == (3, 2)
1476+
x = categorical(p=prob)
1477+
assert x.type.shape == (3,)

tests/tensor/test_basic.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -2046,17 +2046,24 @@ def test_mixed_ndim_error(self):
20462046
def test_static_shape_inference(self):
20472047
a = at.tensor(dtype="int8", shape=(2, 3))
20482048
b = at.tensor(dtype="int8", shape=(2, 5))
2049-
assert at.join(1, a, b).type.shape == (2, 8)
2050-
assert at.join(-1, a, b).type.shape == (2, 8)
2049+
2050+
res = at.join(1, a, b).type.shape
2051+
assert res == (2, 8)
2052+
assert all(isinstance(s, int) for s in res)
2053+
2054+
res = at.join(-1, a, b).type.shape
2055+
assert res == (2, 8)
2056+
assert all(isinstance(s, int) for s in res)
20512057

20522058
# Check early informative errors from static shape info
20532059
with pytest.raises(ValueError, match="must match exactly"):
20542060
at.join(0, at.ones((2, 3)), at.ones((2, 5)))
20552061

20562062
# Check partial inference
20572063
d = at.tensor(dtype="int8", shape=(2, None))
2058-
assert at.join(1, a, b, d).type.shape == (2, None)
2059-
return
2064+
res = at.join(1, a, b, d).type.shape
2065+
assert res == (2, None)
2066+
assert isinstance(res[0], int)
20602067

20612068
def test_split_0elem(self):
20622069
rng = np.random.default_rng(seed=utt.fetch_seed())

tests/tensor/test_type.py

+21
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,27 @@ def test_fixed_shape_basic():
267267
assert t2.shape == (2, 4)
268268

269269

270+
def test_shape_type_conversion():
271+
t1 = TensorType("float64", shape=np.array([3], dtype=int))
272+
assert t1.shape == (3,)
273+
assert isinstance(t1.shape[0], int)
274+
assert t1.broadcastable == (False,)
275+
assert isinstance(t1.broadcastable[0], bool)
276+
277+
t2 = TensorType("float64", broadcastable=np.array([True, False], dtype="bool"))
278+
assert t2.shape == (1, None)
279+
assert isinstance(t2.shape[0], int)
280+
assert t2.broadcastable == (True, False)
281+
assert isinstance(t2.broadcastable[0], bool)
282+
assert isinstance(t2.broadcastable[1], bool)
283+
284+
with pytest.raises(
285+
ValueError,
286+
match="TensorType broadcastable/shape must be a boolean, integer or None",
287+
):
288+
TensorType("float64", shape=("1", "2"))
289+
290+
270291
def test_fixed_shape_clone():
271292
t1 = TensorType("float64", (1,))
272293

0 commit comments

Comments
 (0)