diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index e30887cfe3..edd37b4af5 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -3224,13 +3224,29 @@ def __init__(self, dtype): self.dtype = dtype def make_node(self, start, stop, step): + from math import ceil + start, stop, step = map(as_tensor_variable, (start, stop, step)) + assert start.ndim == 0 assert stop.ndim == 0 assert step.ndim == 0 + # if it is possible to directly determine the shape i.e static shape is present, we find it. + if ( + isinstance(start, TensorConstant) + and isinstance(stop, TensorConstant) + and isinstance(step, TensorConstant) + ): + length = max( + ceil((float(stop.data) - float(start.data)) / float(step.data)), 0 + ) + shape = (length,) + else: + shape = (None,) + inputs = [start, stop, step] - outputs = [tensor(dtype=self.dtype, shape=(None,))] + outputs = [tensor(dtype=self.dtype, shape=shape)] return Apply(self, inputs, outputs) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 60643e2984..19e08d49cd 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -2861,6 +2861,13 @@ def test_infer_shape(self, cast_policy): assert np.all(f(2) == len(np.arange(0, 2))) assert np.all(f(0) == len(np.arange(0, 0))) + def test_static_shape(self): + assert np.arange(1, 10).shape == arange(1, 10).type.shape + assert np.arange(10, 1, -1).shape == arange(10, 1, -1).type.shape + assert np.arange(1, -9, 2).shape == arange(1, -9, 2).type.shape + assert np.arange(1.3, 17.48, 2.67).shape == arange(1.3, 17.48, 2.67).type.shape + assert np.arange(-64, 64).shape == arange(-64, 64).type.shape + class TestNdGrid: def setup_method(self):