| 
4 | 4 | import pytest  | 
5 | 5 | 
 
  | 
6 | 6 | import pytensor  | 
7 |  | -from pytensor import Mode, function, grad  | 
 | 7 | +from pytensor import In, Mode, Out, function, grad  | 
8 | 8 | from pytensor.compile.ops import DeepCopyOp  | 
9 | 9 | from pytensor.configdefaults import config  | 
10 | 10 | from pytensor.graph.basic import Variable, equal_computations  | 
11 | 11 | from pytensor.graph.replace import clone_replace, vectorize_node  | 
12 | 12 | from pytensor.graph.type import Type  | 
13 | 13 | from pytensor.scalar.basic import ScalarConstant  | 
14 | 14 | from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row  | 
15 |  | -from pytensor.tensor.basic import MakeVector, constant, stack  | 
 | 15 | +from pytensor.tensor.basic import MakeVector, arange, constant, stack  | 
16 | 16 | from pytensor.tensor.elemwise import DimShuffle, Elemwise  | 
17 | 17 | from pytensor.tensor.shape import (  | 
18 | 18 |     Reshape,  | 
@@ -373,6 +373,43 @@ def test_static_shape(self):  | 
373 | 373 |         ):  | 
374 | 374 |             reshape(x2, (6, 3, 99))  | 
375 | 375 | 
 
  | 
 | 376 | +    def test_shape_strides(self):  | 
 | 377 | +        # Directly test the concern behind commit 223ee1548574b6bb8e73611ed605a97e29f13e7b  | 
 | 378 | +        x = arange(8)  | 
 | 379 | +        shape = vector("shape", dtype=int, shape=(3,))  | 
 | 380 | +        fn = function([shape], x.reshape(shape))  | 
 | 381 | + | 
 | 382 | +        # Empty strides  | 
 | 383 | +        test_shape = np.broadcast_to(np.array(2), (3,))  | 
 | 384 | +        assert test_shape.strides == (0,)  | 
 | 385 | +        np.testing.assert_array_equal(  | 
 | 386 | +            fn(test_shape),  | 
 | 387 | +            np.arange(8).reshape(test_shape),  | 
 | 388 | +        )  | 
 | 389 | + | 
 | 390 | +        # Negative non-contiguous strides  | 
 | 391 | +        test_shape = np.array([0, 4, 0, 2, 0, 1])[::-2]  | 
 | 392 | +        assert np.all(test_shape == (1, 2, 4))  | 
 | 393 | +        assert test_shape.strides == (-16,)  | 
 | 394 | +        np.testing.assert_array_equal(  | 
 | 395 | +            fn(test_shape),  | 
 | 396 | +            np.arange(8).reshape(test_shape),  | 
 | 397 | +        )  | 
 | 398 | + | 
 | 399 | +    def test_benchmark(self, benchmark):  | 
 | 400 | +        x = tensor3("x")  | 
 | 401 | +        x_val = np.random.random((2, 3, 4)).astype(config.floatX)  | 
 | 402 | +        y1 = x.reshape((6, 4))  | 
 | 403 | +        y2 = x.reshape((2, 12))  | 
 | 404 | +        y3 = x.reshape((-1,))  | 
 | 405 | +        # Borrow to avoid deepcopy overhead  | 
 | 406 | +        reshape_fn = pytensor.function(  | 
 | 407 | +            [In(x, borrow=True)],  | 
 | 408 | +            [Out(y1, borrow=True), Out(y2, borrow=True), Out(y3, borrow=True)],  | 
 | 409 | +        )  | 
 | 410 | +        reshape_fn.trust_input = True  | 
 | 411 | +        benchmark(reshape_fn, x_val)  | 
 | 412 | + | 
376 | 413 | 
 
  | 
377 | 414 | def test_shape_i_hash():  | 
378 | 415 |     assert isinstance(Shape_i(np.int64(1)).__hash__(), int)  | 
 | 
0 commit comments