Skip to content

Commit 8fb41d6

Browse files
committed
Revert regression in Reshape c-impl speed
This was caused by 223ee15, which used the generic `PyArray_IntpConverter` to convert the shape numpy vector into a simple C-array for the Reshape operation. There seems to be no justification for this change. Profiling suggests the previous changes caused a 7.5x slowdown. The benchmark detects only a 2.3x slowdown due to the PyTensor call overhead.
1 parent e25e8a2 commit 8fb41d6

File tree

2 files changed

+52
-20
lines changed

2 files changed

+52
-20
lines changed

pytensor/tensor/shape.py

+13-18
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from pytensor.link.c.op import COp
1717
from pytensor.link.c.params_type import ParamsType
1818
from pytensor.npy_2_compat import normalize_axis_tuple
19-
from pytensor.scalar import int32
2019
from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
2120
from pytensor.tensor import basic as ptb
2221
from pytensor.tensor.elemwise import get_normalized_batch_axes
@@ -628,14 +627,11 @@ class Reshape(COp):
628627

629628
check_input = False
630629
__props__ = ("ndim",)
631-
params_type = ParamsType(ndim=int32)
632-
# name does not participate because it doesn't affect computations
633630

634-
def __init__(self, ndim, name=None):
631+
def __init__(self, ndim):
635632
self.ndim = int(ndim)
636633
if ndim < 0:
637634
raise ValueError("The output dimensions after reshape must be 0 or greater")
638-
assert name is None, "name attribute for Reshape has been deprecated"
639635

640636
def __str__(self):
641637
return f"{self.__class__.__name__}{{{self.ndim}}}"
@@ -795,33 +791,32 @@ def infer_shape(self, fgraph, node, ishapes):
795791
]
796792

797793
def c_code_cache_version(self):
798-
return (9,)
794+
return (10,)
799795

800796
def c_code(self, node, name, inputs, outputs, sub):
801797
x, shp = inputs
798+
shp_dtype = node.inputs[1].type.dtype_specs()[1]
802799
(z,) = outputs
803800
fail = sub["fail"]
804-
params = sub["params"]
801+
ndim = self.ndim
802+
805803
return f"""
806804
assert (PyArray_NDIM({shp}) == 1);
807805
808-
PyArray_Dims newshape;
809-
810-
if (!PyArray_IntpConverter((PyObject *){shp}, &newshape)) {{
811-
{fail};
806+
// Unpack shape into new_dims
807+
npy_intp new_dims[{ndim}];
808+
for (int ii = 0; ii < {ndim}; ++ii)
809+
{{
810+
new_dims[ii] = (({shp_dtype}*)(PyArray_BYTES({shp}) + ii * PyArray_STRIDES({shp})[0]))[0];
812811
}}
813812
814-
if ({params}->ndim != newshape.len) {{
815-
PyErr_SetString(PyExc_ValueError, "Shape argument to Reshape has incorrect length");
816-
PyDimMem_FREE(newshape.ptr);
817-
{fail};
818-
}}
813+
PyArray_Dims newshape;
814+
newshape.len = {ndim};
815+
newshape.ptr = new_dims;
819816
820817
Py_XDECREF({z});
821818
{z} = (PyArrayObject *) PyArray_Newshape({x}, &newshape, NPY_CORDER);
822819
823-
PyDimMem_FREE(newshape.ptr);
824-
825820
if (!{z}) {{
826821
//The error message should have been set by PyArray_Newshape
827822
{fail};

tests/tensor/test_shape.py

+39-2
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44
import pytest
55

66
import pytensor
7-
from pytensor import Mode, function, grad
7+
from pytensor import In, Mode, Out, function, grad
88
from pytensor.compile.ops import DeepCopyOp
99
from pytensor.configdefaults import config
1010
from pytensor.graph.basic import Variable, equal_computations
1111
from pytensor.graph.replace import clone_replace, vectorize_node
1212
from pytensor.graph.type import Type
1313
from pytensor.scalar.basic import ScalarConstant
1414
from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row
15-
from pytensor.tensor.basic import MakeVector, constant, stack
15+
from pytensor.tensor.basic import MakeVector, arange, constant, stack
1616
from pytensor.tensor.elemwise import DimShuffle, Elemwise
1717
from pytensor.tensor.shape import (
1818
Reshape,
@@ -373,6 +373,43 @@ def test_static_shape(self):
373373
):
374374
reshape(x2, (6, 3, 99))
375375

376+
def test_shape_strides(self):
377+
# Directly test the concern behind commit 223ee1548574b6bb8e73611ed605a97e29f13e7b
378+
x = arange(8)
379+
shape = vector("shape", dtype=int, shape=(3,))
380+
fn = function([shape], x.reshape(shape))
381+
382+
# Empty strides
383+
test_shape = np.broadcast_to(np.array(2), (3,))
384+
assert test_shape.strides == (0,)
385+
np.testing.assert_array_equal(
386+
fn(test_shape),
387+
np.arange(8).reshape(test_shape),
388+
)
389+
390+
# Negative non-contiguous strides
391+
test_shape = np.array([0, 4, 0, 2, 0, 1])[::-2]
392+
assert np.all(test_shape == (1, 2, 4))
393+
assert test_shape.strides == (-16,)
394+
np.testing.assert_array_equal(
395+
fn(test_shape),
396+
np.arange(8).reshape(test_shape),
397+
)
398+
399+
def test_benchmark(self, benchmark):
400+
x = tensor3("x")
401+
x_val = np.random.random((2, 3, 4)).astype(config.floatX)
402+
y1 = x.reshape((6, 4))
403+
y2 = x.reshape((2, 12))
404+
y3 = x.reshape((-1,))
405+
# Borrow to avoid deepcopy overhead
406+
reshape_fn = pytensor.function(
407+
[In(x, borrow=True)],
408+
[Out(y1, borrow=True), Out(y2, borrow=True), Out(y3, borrow=True)],
409+
)
410+
reshape_fn.trust_input = True
411+
benchmark(reshape_fn, x_val)
412+
376413

377414
def test_shape_i_hash():
378415
assert isinstance(Shape_i(np.int64(1)).__hash__(), int)

0 commit comments

Comments
 (0)