Skip to content

Commit 4e37832

Browse files
committed
Revert regression in Reshape c-impl speed
This was caused by 223ee15, which used the generic `PyArray_IntpConverter` to convert the shape numpy vector into a simple C-array for the Reshape operation. There seems to be no justification for this change.
1 parent e25e8a2 commit 4e37832

File tree

2 files changed

+24
-18
lines changed

2 files changed

+24
-18
lines changed

pytensor/tensor/shape.py

+13-18
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from pytensor.link.c.op import COp
1717
from pytensor.link.c.params_type import ParamsType
1818
from pytensor.npy_2_compat import normalize_axis_tuple
19-
from pytensor.scalar import int32
2019
from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
2120
from pytensor.tensor import basic as ptb
2221
from pytensor.tensor.elemwise import get_normalized_batch_axes
@@ -628,14 +627,11 @@ class Reshape(COp):
628627

629628
check_input = False
630629
__props__ = ("ndim",)
631-
params_type = ParamsType(ndim=int32)
632-
# name does not participate because it doesn't affect computations
633630

634-
def __init__(self, ndim, name=None):
631+
def __init__(self, ndim):
635632
self.ndim = int(ndim)
636633
if ndim < 0:
637634
raise ValueError("The output dimensions after reshape must be 0 or greater")
638-
assert name is None, "name attribute for Reshape has been deprecated"
639635

640636
def __str__(self):
641637
return f"{self.__class__.__name__}{{{self.ndim}}}"
@@ -795,33 +791,32 @@ def infer_shape(self, fgraph, node, ishapes):
795791
]
796792

797793
def c_code_cache_version(self):
798-
return (9,)
794+
return (10,)
799795

800796
def c_code(self, node, name, inputs, outputs, sub):
801797
x, shp = inputs
798+
shp_dtype = node.inputs[1].type.dtype_specs()[1]
802799
(z,) = outputs
803800
fail = sub["fail"]
804-
params = sub["params"]
801+
ndim = self.ndim
802+
805803
return f"""
806804
assert (PyArray_NDIM({shp}) == 1);
807805
808-
PyArray_Dims newshape;
809-
810-
if (!PyArray_IntpConverter((PyObject *){shp}, &newshape)) {{
811-
{fail};
806+
// Unpack shape into new_dims
807+
npy_intp new_dims[{ndim}];
808+
for (int ii = 0; ii < {ndim}; ++ii)
809+
{{
810+
new_dims[ii] = (({shp_dtype}*)(PyArray_DATA({shp}) + ii * PyArray_STRIDES({shp})[0]))[0];
812811
}}
813812
814-
if ({params}->ndim != newshape.len) {{
815-
PyErr_SetString(PyExc_ValueError, "Shape argument to Reshape has incorrect length");
816-
PyDimMem_FREE(newshape.ptr);
817-
{fail};
818-
}}
813+
PyArray_Dims newshape;
814+
newshape.len = {ndim};
815+
newshape.ptr = new_dims;
819816
820817
Py_XDECREF({z});
821818
{z} = (PyArrayObject *) PyArray_Newshape({x}, &newshape, NPY_CORDER);
822819
823-
PyDimMem_FREE(newshape.ptr);
824-
825820
if (!{z}) {{
826821
//The error message should have been set by PyArray_Newshape
827822
{fail};

tests/tensor/test_shape.py

+11
Original file line numberDiff line numberDiff line change
@@ -858,3 +858,14 @@ def test_unbroadcast(self):
858858
node = unbroadcast(mat, 0).owner
859859
vect_node = vectorize_node(node, tns)
860860
assert equal_computations(vect_node.outputs, [unbroadcast(tns, 2)])
861+
862+
863+
def test_reshape_benchmark(benchmark):
864+
x = tensor3("x")
865+
x_val = np.random.random((2, 3, 4)).astype(config.floatX)
866+
867+
y = x.reshape((6, 4))
868+
reshape_fn = pytensor.function([x], y)
869+
870+
reshape_fn.trust_input = True
871+
benchmark(reshape_fn, x_val)

0 commit comments

Comments
 (0)