diff --git a/pytensor/link/jax/dispatch/elemwise.py b/pytensor/link/jax/dispatch/elemwise.py index 7d9532557b..d4c8e7b605 100644 --- a/pytensor/link/jax/dispatch/elemwise.py +++ b/pytensor/link/jax/dispatch/elemwise.py @@ -79,12 +79,7 @@ def dimshuffle(x): for augm in op.augment: shape.insert(augm, 1) - res = jnp.reshape(res, shape) - - if not op.inplace: - res = jnp.copy(res) - - return res + return jnp.reshape(res, shape) return dimshuffle diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 03c7084a8f..9fd81dadcf 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -4,6 +4,7 @@ import numba import numpy as np from numba.core.extending import overload +from numpy.lib.stride_tricks import as_strided from pytensor.graph.op import Op from pytensor.link.numba.dispatch import basic as numba_basic @@ -411,97 +412,38 @@ def numba_funcify_CAReduce(op, node, **kwargs): @numba_funcify.register(DimShuffle) def numba_funcify_DimShuffle(op, node, **kwargs): - shuffle = tuple(op.shuffle) - transposition = tuple(op.transposition) - augment = tuple(op.augment) - inplace = op.inplace + # We use `as_strided` to achieve the DimShuffle behavior of transposing and expanding/squeezing dimensions in one call + # Numba doesn't currently support multiple expand/squeeze, and reshape is limited to contiguous arrays. 
+ new_order = tuple(op._new_order) + shape_template = (1,) * node.outputs[0].ndim + strides_template = (0,) * node.outputs[0].ndim - ndim_new_shape = len(shuffle) + len(augment) - - no_transpose = all(i == j for i, j in enumerate(transposition)) - if no_transpose: - - @numba_basic.numba_njit - def transpose(x): - return x - - else: + if new_order == (): + # Special case needed because of https://github.com/numba/numba/issues/9933 @numba_basic.numba_njit - def transpose(x): - return np.transpose(x, transposition) - - shape_template = (1,) * ndim_new_shape - - # When `len(shuffle) == 0`, the `shuffle_shape[j]` expression below - # is typed as `getitem(Tuple(), int)`, which has no implementation - # (since getting an item from an empty sequence doesn't make sense). - # To avoid this compile-time error, we omit the expression altogether. - if len(shuffle) > 0: - # Use the statically known shape if available - if all(length is not None for length in node.outputs[0].type.shape): - shape = node.outputs[0].type.shape - - @numba_basic.numba_njit - def find_shape(array_shape): - return shape - - else: - - @numba_basic.numba_njit - def find_shape(array_shape): - shape = shape_template - j = 0 - for i in range(ndim_new_shape): - if i not in augment: - length = array_shape[j] - shape = numba_basic.tuple_setitem(shape, i, length) - j = j + 1 - return shape - - else: - - @numba_basic.numba_njit - def find_shape(array_shape): - return shape_template - - if ndim_new_shape > 0: - - @numba_basic.numba_njit - def dimshuffle_inner(x, shuffle): - x = transpose(x) - shuffle_shape = x.shape[: len(shuffle)] - new_shape = find_shape(shuffle_shape) - - # FIXME: Numba's `array.reshape` only accepts C arrays. 
- res_reshape = np.reshape(np.ascontiguousarray(x), new_shape) + def squeeze_to_0d(x): + return as_strided(x, shape=(), strides=()) - if not inplace: - return res_reshape.copy() - else: - return res_reshape + return squeeze_to_0d else: @numba_basic.numba_njit - def dimshuffle_inner(x, shuffle): - return np.reshape(np.ascontiguousarray(x), ()) - - # Without the following wrapper function we would see this error: - # E No implementation of function Function() found for signature: - # E - # E >>> getitem(UniTuple(int64 x 2), slice) - # E - # E There are 22 candidate implementations: - # E - Of which 22 did not match due to: - # E Overload of function 'getitem': File: : Line N/A. - # E With argument(s): '(UniTuple(int64 x 2), slice)': - # E No match. - # ...(on this line)... - # E shuffle_shape = res.shape[: len(shuffle)] - @numba_basic.numba_njit(inline="always") - def dimshuffle(x): - return dimshuffle_inner(np.asarray(x), shuffle) + def dimshuffle(x): + old_shape = x.shape + old_strides = x.strides + + new_shape = shape_template + new_strides = strides_template + for i, o in enumerate(new_order): + if o != -1: + new_shape = numba_basic.tuple_setitem(new_shape, i, old_shape[o]) + new_strides = numba_basic.tuple_setitem( + new_strides, i, old_strides[o] + ) + + return as_strided(x, shape=new_shape, strides=new_strides) return dimshuffle diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index c22945d914..a3b7683004 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -61,12 +61,7 @@ def dimshuffle(x): for augm in op.augment: shape.insert(augm, 1) - res = torch.reshape(res, shape) - - if not op.inplace: - res = res.clone() - - return res + return torch.reshape(res, shape) return dimshuffle diff --git a/pytensor/tensor/c_code/dimshuffle.c b/pytensor/tensor/c_code/dimshuffle.c index 6c67bd1bfb..b99a0ee419 100644 --- a/pytensor/tensor/c_code/dimshuffle.c +++ 
b/pytensor/tensor/c_code/dimshuffle.c @@ -1,82 +1,81 @@ #section support_code_apply -int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, - PARAMS_TYPE *params) { - - // This points to either the original input or a copy we create below. - // Either way, this is what we should be working on/with. - PyArrayObject *_input; - - if (*res) - Py_XDECREF(*res); - - if (params->inplace) { - _input = input; - Py_INCREF((PyObject *)_input); - } else { - _input = (PyArrayObject *)PyArray_FromAny( - (PyObject *)input, NULL, 0, 0, NPY_ARRAY_ALIGNED | NPY_ARRAY_ENSURECOPY, - NULL); - } - - PyArray_Dims permute; - - if (!PyArray_IntpConverter((PyObject *)params->transposition, &permute)) { - return 1; - } - - /* - res = res.transpose(self.transposition) - */ - PyArrayObject *transposed_input = - (PyArrayObject *)PyArray_Transpose(_input, &permute); - - Py_DECREF(_input); - - PyDimMem_FREE(permute.ptr); - - npy_intp *res_shape = PyArray_DIMS(transposed_input); - npy_intp N_shuffle = PyArray_SIZE(params->shuffle); - npy_intp N_augment = PyArray_SIZE(params->augment); - npy_intp N = N_augment + N_shuffle; - npy_intp *_reshape_shape = PyDimMem_NEW(N); - - if (_reshape_shape == NULL) { - PyErr_NoMemory(); - return 1; - } +int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PARAMS_TYPE *params) { + npy_int64* new_order; + npy_intp nd_in; + npy_intp nd_out; + npy_intp* dimensions; + npy_intp* strides; + + if (!PyArray_IS_C_CONTIGUOUS(params->_new_order)) { + PyErr_SetString(PyExc_RuntimeError, "DimShuffle: param _new_order must be C-contiguous."); + return 1; + } + new_order = (npy_int64*) PyArray_DATA(params->_new_order); + nd_in = (npy_intp)(params->input_ndim); + nd_out = PyArray_SIZE(params->_new_order); - /* - shape = list(res.shape[: len(self.shuffle)]) - for augm in self.augment: - shape.insert(augm, 1) - */ - npy_intp aug_idx = 0; - int res_idx = 0; - for (npy_intp i = 0; i < N; i++) { - if (aug_idx < N_augment && - i == 
*((npy_intp *)PyArray_GetPtr(params->augment, &aug_idx))) { - _reshape_shape[i] = 1; - aug_idx++; - } else { - _reshape_shape[i] = res_shape[res_idx]; - res_idx++; + if (PyArray_NDIM(input) != nd_in) { + PyErr_SetString(PyExc_ValueError, "DimShuffle: Input has less dimensions than expected."); + return 1; } - } - PyArray_Dims reshape_shape = {.ptr = _reshape_shape, .len = (int)N}; + // Compute new dimensions and strides + dimensions = (npy_intp*) malloc(nd_out * sizeof(npy_intp)); + strides = (npy_intp*) malloc(nd_out * sizeof(npy_intp)); + if (dimensions == NULL || strides == NULL) { + PyErr_NoMemory(); + free(dimensions); + free(strides); + return 1; + }; + + npy_intp original_size = PyArray_SIZE(input); + npy_intp new_size = 1; + for (npy_intp i = 0; i < nd_out; ++i) { + if (new_order[i] != -1) { + dimensions[i] = PyArray_DIMS(input)[new_order[i]]; + strides[i] = PyArray_DIMS(input)[new_order[i]] == 1 ? 0 : PyArray_STRIDES(input)[new_order[i]]; + } else { + dimensions[i] = 1; + strides[i] = 0; + } + new_size *= dimensions[i]; + } - /* res = res.reshape(shape) */ - *res = (PyArrayObject *)PyArray_Newshape(transposed_input, &reshape_shape, - NPY_CORDER); + if (original_size != new_size) { + PyErr_SetString(PyExc_ValueError, "DimShuffle: Attempting to squeeze axes with size not equal to one."); + free(dimensions); + free(strides); + return 1; + } - Py_DECREF(transposed_input); + if (*res) + Py_XDECREF(*res); + + // Create the new array. + *res = (PyArrayObject*)PyArray_New(&PyArray_Type, nd_out, dimensions, + PyArray_TYPE(input), strides, + PyArray_DATA(input), PyArray_ITEMSIZE(input), + // borrow only the writable flag from the base + // the NPY_OWNDATA flag will default to 0. 
+ (NPY_ARRAY_WRITEABLE * PyArray_ISWRITEABLE(input)), + NULL); + + if (*res == NULL) { + free(dimensions); + free(strides); + return 1; + } - PyDimMem_FREE(reshape_shape.ptr); + // Declare it a view of the original input + Py_INCREF((PyObject*)input); + PyArray_SetBaseObject(*res, (PyObject*)input); - if (!*res) { - return 1; - } + // recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED + PyArray_UpdateFlags(*res, NPY_ARRAY_UPDATE_ALL); - return 0; -} + free(strides); + free(dimensions); + return 0; +} \ No newline at end of file diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 37acfc8e86..d653c49296 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -19,9 +19,8 @@ from pytensor.npy_2_compat import normalize_axis_tuple from pytensor.printing import Printer, pprint from pytensor.scalar import get_scalar_type -from pytensor.scalar.basic import bool as scalar_bool from pytensor.scalar.basic import identity as scalar_identity -from pytensor.scalar.basic import transfer_type, upcast +from pytensor.scalar.basic import int64, transfer_type, upcast from pytensor.tensor import elemwise_cgen as cgen from pytensor.tensor import get_vector_length from pytensor.tensor.basic import _get_vector_length, as_tensor_variable @@ -114,17 +113,16 @@ class DimShuffle(ExternalCOp): _f16_ok = True check_input = False - __props__ = ("input_ndim", "new_order", "inplace") + __props__ = ("input_ndim", "new_order") c_func_file = "c_code/dimshuffle.c" c_func_name = "APPLY_SPECIFIC(cpu_dimshuffle)" + view_map = {0: [0]} @property def params_type(self): return ParamsType( - shuffle=lvector, - augment=lvector, - transposition=lvector, - inplace=scalar_bool, + _new_order=lvector, + input_ndim=int64, ) def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]): @@ -135,7 +133,7 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]): self.input_ndim = input_ndim self.new_order = tuple(new_order) - 
self.inplace = True + self._new_order = [(-1 if x == "x" else x) for x in self.new_order] for i, j in enumerate(new_order): if j != "x": @@ -178,9 +176,6 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]): :input_ndim ] == list(range(input_ndim)) - if self.inplace: - self.view_map = {0: [0]} - def __setstate__(self, state): self.__dict__.update(state) if not hasattr(self, "func_files"): @@ -231,10 +226,15 @@ def __str__(self): def perform(self, node, inp, out): (res,) = inp - (storage,) = out - if not isinstance(res, np.ndarray | np.memmap): - raise TypeError(res) + # This C-like impl is very slow in Python compared to transpose+reshape + # new_order = self._new_order + # old_shape = inp.shape + # old_strides = inp.strides + # res = as_strided( + # shape = [1 if i == -1 else old_shape[i] for i in new_order], + # strides=[0 if i == -1 else old_strides[i] for i in new_order], + # ) # Put dropped axis at end res = res.transpose(self.transposition) @@ -243,12 +243,7 @@ def perform(self, node, inp, out): new_shape = list(res.shape[: len(self.shuffle)]) for augm in self.augment: new_shape.insert(augm, 1) - res = res.reshape(new_shape) - - if not self.inplace: - res = np.copy(res) - - storage[0] = np.asarray(res) + out[0][0] = res.reshape(new_shape) def infer_shape(self, fgraph, node, shapes): (ishp,) = shapes diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index e839ac1f08..1fc4e6dd2b 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -16,7 +16,6 @@ from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType from pytensor.npy_2_compat import normalize_axis_tuple -from pytensor.scalar import int32 from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length from pytensor.tensor import basic as ptb from pytensor.tensor.elemwise import get_normalized_batch_axes @@ -628,14 +627,11 @@ class Reshape(COp): check_input = False __props__ = ("ndim",) - 
params_type = ParamsType(ndim=int32) - # name does not participate because it doesn't affect computations - def __init__(self, ndim, name=None): + def __init__(self, ndim): self.ndim = int(ndim) if ndim < 0: raise ValueError("The output dimensions after reshape must be 0 or greater") - assert name is None, "name attribute for Reshape has been deprecated" def __str__(self): return f"{self.__class__.__name__}{{{self.ndim}}}" @@ -795,33 +791,32 @@ def infer_shape(self, fgraph, node, ishapes): ] def c_code_cache_version(self): - return (9,) + return (10,) def c_code(self, node, name, inputs, outputs, sub): x, shp = inputs + shp_dtype = node.inputs[1].type.dtype_specs()[1] (z,) = outputs fail = sub["fail"] - params = sub["params"] + ndim = self.ndim + return f""" assert (PyArray_NDIM({shp}) == 1); - PyArray_Dims newshape; - - if (!PyArray_IntpConverter((PyObject *){shp}, &newshape)) {{ - {fail}; + // Unpack shape into new_dims + npy_intp new_dims[{ndim}]; + for (int ii = 0; ii < {ndim}; ++ii) + {{ + new_dims[ii] = (({shp_dtype}*)(PyArray_BYTES({shp}) + ii * PyArray_STRIDES({shp})[0]))[0]; }} - if ({params}->ndim != newshape.len) {{ - PyErr_SetString(PyExc_ValueError, "Shape argument to Reshape has incorrect length"); - PyDimMem_FREE(newshape.ptr); - {fail}; - }} + PyArray_Dims newshape; + newshape.len = {ndim}; + newshape.ptr = new_dims; Py_XDECREF({z}); {z} = (PyArrayObject *) PyArray_Newshape({x}, &newshape, NPY_CORDER); - PyDimMem_FREE(newshape.ptr); - if (!{z}) {{ //The error message should have been set by PyArray_Newshape {fail}; diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py index eaa0fa951d..7ef5705c07 100644 --- a/tests/link/numba/test_elemwise.py +++ b/tests/link/numba/test_elemwise.py @@ -23,6 +23,7 @@ from tests.tensor.test_elemwise import ( careduce_benchmark_tester, check_elemwise_runtime_broadcast, + dimshuffle_benchmark, ) @@ -201,7 +202,7 @@ def test_Dimshuffle_returns_array(): def test_Dimshuffle_non_contiguous(): 
"""The numba impl of reshape doesn't work with - non-contiguous arrays, make sure we work around thpt.""" + non-contiguous arrays, make sure we work around that.""" x = pt.dvector() idx = pt.vector(dtype="int64") op = DimShuffle(input_ndim=1, new_order=[]) @@ -643,3 +644,7 @@ def test_numba_careduce_benchmark(self, axis, c_contiguous, benchmark): return careduce_benchmark_tester( axis, c_contiguous, mode="NUMBA", benchmark=benchmark ) + + @pytest.mark.parametrize("c_contiguous", (True, False)) + def test_dimshuffle(self, c_contiguous, benchmark): + dimshuffle_benchmark("NUMBA", c_contiguous, benchmark) diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 5ce533d3a3..d5aac0113b 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -1,3 +1,4 @@ +import itertools import math import re import tracemalloc @@ -10,6 +11,7 @@ import pytensor.scalar as ps import pytensor.tensor as pt import tests.unittest_tools as utt +from pytensor import In, Out from pytensor.compile.function import function from pytensor.compile.mode import Mode from pytensor.configdefaults import config @@ -35,6 +37,7 @@ matrix, scalar, tensor, + tensor3, vector, vectors, ) @@ -63,6 +66,30 @@ def custom_reduce(a): return np.apply_along_axis(custom_reduce, axis, x) +def dimshuffle_benchmark(mode, c_contiguous, benchmark): + x = tensor3("x") + if c_contiguous: + x_val = np.random.random((2, 3, 4)).astype(config.floatX) + else: + x_val = np.random.random((200, 300, 400)).transpose(1, 2, 0) + ys = [x.transpose(t) for t in itertools.permutations((0, 1, 2))] + ys += [ + x[None], + x[:, None], + x[:, :, None], + x[:, :, :, None], + ] + # Borrow to avoid deepcopy overhead + fn = pytensor.function( + [In(x, borrow=True)], + [Out(y, borrow=True) for y in ys], + mode=mode, + ) + fn.trust_input = True + fn(x_val) # JIT compile for JIT backends + benchmark(fn, x_val) + + class TestDimShuffle(unittest_tools.InferShapeTester): op = DimShuffle type = TensorType @@ 
-158,11 +185,14 @@ def test_c_views(self): # as the broadcasted value; that way, we'll be able to tell that we're getting # junk data from a poorly constructed array view. x_val = np.broadcast_to(2039, (5000,)) - for i in range(1000): + expected_x_val = x_val[None] + for i in range(1): inputs[0].storage[0] = x_val thunk() # Make sure it's a view of the original data assert np.shares_memory(x_val, outputs[0].storage[0]) + # Confirm the right strides + assert outputs[0].storage[0].strides == expected_x_val.strides # Confirm the broadcasted value in the output assert np.array_equiv(outputs[0].storage[0], 2039) @@ -212,6 +242,10 @@ def test_valid_input_ndim(self): with pytest.raises(TypeError, match="input_ndim must be an integer"): DimShuffle(input_ndim=(True, False), new_order=(1, 0)) + @pytest.mark.parametrize("c_contiguous", [True, False]) + def test_benchmark(self, c_contiguous, benchmark): + dimshuffle_benchmark("FAST_RUN", c_contiguous, benchmark) + class TestBroadcast: # this is to allow other types to reuse this class to test their ops diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py index c387152757..b03d591705 100644 --- a/tests/tensor/test_extra_ops.py +++ b/tests/tensor/test_extra_ops.py @@ -480,10 +480,7 @@ def test_invalid_input(self): assert f([0]) == 0 # Test that we cannot squeeze dimensions whose length is greater than 1 - with pytest.raises( - ValueError, - match="cannot reshape array of size 3 into shape ()", - ): + with pytest.raises(ValueError): f([0, 1, 2]) diff --git a/tests/tensor/test_shape.py b/tests/tensor/test_shape.py index 3f0b04d45d..090819f349 100644 --- a/tests/tensor/test_shape.py +++ b/tests/tensor/test_shape.py @@ -4,7 +4,7 @@ import pytest import pytensor -from pytensor import Mode, function, grad +from pytensor import In, Mode, Out, function, grad from pytensor.compile.ops import DeepCopyOp from pytensor.configdefaults import config from pytensor.graph.basic import Variable, equal_computations @@ -12,7 
+12,7 @@ from pytensor.graph.type import Type from pytensor.scalar.basic import ScalarConstant from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row -from pytensor.tensor.basic import MakeVector, constant, stack +from pytensor.tensor.basic import MakeVector, arange, constant, stack from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.shape import ( Reshape, @@ -373,6 +373,43 @@ def test_static_shape(self): ): reshape(x2, (6, 3, 99)) + def test_shape_strides(self): + # Directly test the concern behind commit 223ee1548574b6bb8e73611ed605a97e29f13e7b + x = arange(8) + shape = vector("shape", dtype=int, shape=(3,)) + fn = function([shape], x.reshape(shape)) + + # Zero strides + test_shape = np.broadcast_to(np.array(2), (3,)) + assert test_shape.strides == (0,) + np.testing.assert_array_equal( + fn(test_shape), + np.arange(8).reshape(test_shape), + ) + + # Negative non-contiguous strides + test_shape = np.array([0, 4, 0, 2, 0, 1])[::-2] + assert np.all(test_shape == (1, 2, 4)) + assert test_shape.strides == (-16,) + np.testing.assert_array_equal( + fn(test_shape), + np.arange(8).reshape(test_shape), + ) + + def test_benchmark(self, benchmark): + x = tensor3("x") + x_val = np.random.random((2, 3, 4)).astype(config.floatX) + y1 = x.reshape((6, 4)) + y2 = x.reshape((2, 12)) + y3 = x.reshape((-1,)) + # Borrow to avoid deepcopy overhead + reshape_fn = pytensor.function( + [In(x, borrow=True)], + [Out(y1, borrow=True), Out(y2, borrow=True), Out(y3, borrow=True)], + ) + reshape_fn.trust_input = True + benchmark(reshape_fn, x_val) + def test_shape_i_hash(): assert isinstance(Shape_i(np.int64(1)).__hash__(), int)