|
4 | 4 | import pytest
|
5 | 5 |
|
6 | 6 | import pytensor
|
7 |
| -from pytensor import Mode, function, grad |
| 7 | +from pytensor import In, Mode, Out, function, grad |
8 | 8 | from pytensor.compile.ops import DeepCopyOp
|
9 | 9 | from pytensor.configdefaults import config
|
10 | 10 | from pytensor.graph.basic import Variable, equal_computations
|
11 | 11 | from pytensor.graph.replace import clone_replace, vectorize_node
|
12 | 12 | from pytensor.graph.type import Type
|
13 | 13 | from pytensor.scalar.basic import ScalarConstant
|
14 | 14 | from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row
|
15 |
| -from pytensor.tensor.basic import MakeVector, constant, stack |
| 15 | +from pytensor.tensor.basic import MakeVector, arange, constant, stack |
16 | 16 | from pytensor.tensor.elemwise import DimShuffle, Elemwise
|
17 | 17 | from pytensor.tensor.shape import (
|
18 | 18 | Reshape,
|
@@ -373,6 +373,43 @@ def test_static_shape(self):
|
373 | 373 | ):
|
374 | 374 | reshape(x2, (6, 3, 99))
|
375 | 375 |
|
| 376 | + def test_shape_strides(self): |
| 377 | + # Directly test the concern behind commit 223ee1548574b6bb8e73611ed605a97e29f13e7b |
| 378 | + x = arange(8) |
| 379 | + shape = vector("shape", dtype=int, shape=(3,)) |
| 380 | + fn = function([shape], x.reshape(shape)) |
| 381 | + |
| 382 | + # Empty strides |
| 383 | + test_shape = np.broadcast_to(np.array(2), (3,)) |
| 384 | + assert test_shape.strides == (0,) |
| 385 | + np.testing.assert_array_equal( |
| 386 | + fn(test_shape), |
| 387 | + np.arange(8).reshape(test_shape), |
| 388 | + ) |
| 389 | + |
| 390 | + # Negative non-contiguous strides |
| 391 | + test_shape = np.array([0, 4, 0, 2, 0, 1])[::-2] |
| 392 | + assert np.all(test_shape == (1, 2, 4)) |
| 393 | + assert test_shape.strides == (-16,) |
| 394 | + np.testing.assert_array_equal( |
| 395 | + fn(test_shape), |
| 396 | + np.arange(8).reshape(test_shape), |
| 397 | + ) |
| 398 | + |
| 399 | + def test_benchmark(self, benchmark): |
| 400 | + x = tensor3("x") |
| 401 | + x_val = np.random.random((2, 3, 4)).astype(config.floatX) |
| 402 | + y1 = x.reshape((6, 4)) |
| 403 | + y2 = x.reshape((2, 12)) |
| 404 | + y3 = x.reshape((-1,)) |
| 405 | + # Borrow to avoid deepcopy overhead |
| 406 | + reshape_fn = pytensor.function( |
| 407 | + [In(x, borrow=True)], |
| 408 | + [Out(y1, borrow=True), Out(y2, borrow=True), Out(y3, borrow=True)], |
| 409 | + ) |
| 410 | + reshape_fn.trust_input = True |
| 411 | + benchmark(reshape_fn, x_val) |
| 412 | + |
376 | 413 |
|
377 | 414 | def test_shape_i_hash():
|
378 | 415 | assert isinstance(Shape_i(np.int64(1)).__hash__(), int)
|
|
0 commit comments