|
16 | 16 | from pytensor.link.c.op import COp
|
17 | 17 | from pytensor.link.c.params_type import ParamsType
|
18 | 18 | from pytensor.npy_2_compat import normalize_axis_tuple
|
19 |
| -from pytensor.scalar import int32 |
20 | 19 | from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
|
21 | 20 | from pytensor.tensor import basic as ptb
|
22 | 21 | from pytensor.tensor.elemwise import get_normalized_batch_axes
|
@@ -628,14 +627,11 @@ class Reshape(COp):
|
628 | 627 |
|
629 | 628 | check_input = False
|
630 | 629 | __props__ = ("ndim",)
|
631 |
| - params_type = ParamsType(ndim=int32) |
632 |
| - # name does not participate because it doesn't affect computations |
633 | 630 |
|
634 |
| - def __init__(self, ndim, name=None): |
| 631 | + def __init__(self, ndim): |
635 | 632 | self.ndim = int(ndim)
|
636 | 633 | if ndim < 0:
|
637 | 634 | raise ValueError("The output dimensions after reshape must be 0 or greater")
|
638 |
| - assert name is None, "name attribute for Reshape has been deprecated" |
639 | 635 |
|
640 | 636 | def __str__(self):
|
641 | 637 | return f"{self.__class__.__name__}{{{self.ndim}}}"
|
@@ -795,33 +791,32 @@ def infer_shape(self, fgraph, node, ishapes):
|
795 | 791 | ]
|
796 | 792 |
|
797 | 793 | def c_code_cache_version(self):
|
798 |
| - return (9,) |
| 794 | + return (10,) |
799 | 795 |
|
800 | 796 | def c_code(self, node, name, inputs, outputs, sub):
|
801 | 797 | x, shp = inputs
|
| 798 | + shp_dtype = node.inputs[1].type.dtype_specs()[1] |
802 | 799 | (z,) = outputs
|
803 | 800 | fail = sub["fail"]
|
804 |
| - params = sub["params"] |
| 801 | + ndim = self.ndim |
| 802 | + |
805 | 803 | return f"""
|
806 | 804 | assert (PyArray_NDIM({shp}) == 1);
|
807 | 805 |
|
808 |
| - PyArray_Dims newshape; |
809 |
| -
|
810 |
| - if (!PyArray_IntpConverter((PyObject *){shp}, &newshape)) {{ |
811 |
| - {fail}; |
| 806 | + // Unpack shape into new_dims |
| 807 | + npy_intp new_dims[{ndim}]; |
| 808 | + for (int ii = 0; ii < {ndim}; ++ii) |
| 809 | + {{ |
| 810 | + new_dims[ii] = (({shp_dtype}*)(PyArray_DATA({shp}) + ii * PyArray_STRIDES({shp})[0]))[0]; |
812 | 811 | }}
|
813 | 812 |
|
814 |
| - if ({params}->ndim != newshape.len) {{ |
815 |
| - PyErr_SetString(PyExc_ValueError, "Shape argument to Reshape has incorrect length"); |
816 |
| - PyDimMem_FREE(newshape.ptr); |
817 |
| - {fail}; |
818 |
| - }} |
| 813 | + PyArray_Dims newshape; |
| 814 | + newshape.len = {ndim}; |
| 815 | + newshape.ptr = new_dims; |
819 | 816 |
|
820 | 817 | Py_XDECREF({z});
|
821 | 818 | {z} = (PyArrayObject *) PyArray_Newshape({x}, &newshape, NPY_CORDER);
|
822 | 819 |
|
823 |
| - PyDimMem_FREE(newshape.ptr); |
824 |
| -
|
825 | 820 | if (!{z}) {{
|
826 | 821 | //The error message should have been set by PyArray_Newshape
|
827 | 822 | {fail};
|
|
0 commit comments