Speedup DimShuffle and Reshape in C and Numba backends #1226

Merged
merged 4 commits on Feb 25, 2025
7 changes: 1 addition & 6 deletions pytensor/link/jax/dispatch/elemwise.py
@@ -79,12 +79,7 @@ def dimshuffle(x):
for augm in op.augment:
shape.insert(augm, 1)

res = jnp.reshape(res, shape)

if not op.inplace:
res = jnp.copy(res)

return res
return jnp.reshape(res, shape)

return dimshuffle
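
For context, the explicit copy was only needed while `DimShuffle` could be non-inplace; now that the op always declares its output as a view (`view_map = {0: [0]}`), the JAX backend can simply return the reshaped result. Below is a minimal sketch of the transpose-then-reshape pattern the dispatcher uses, with illustrative values for `op.shuffle` and `op.augment` (these particular values are assumptions, not taken from the diff):

```python
import jax.numpy as jnp

# Illustrative values for DimShuffle(input_ndim=2, new_order=(1, "x", 0)):
shuffle = (1, 0)   # kept input axes, in output order (op.shuffle)
augment = (1,)     # positions of inserted broadcastable axes (op.augment)

x = jnp.arange(6).reshape(2, 3)
res = jnp.transpose(x, shuffle)         # dropped axes (none here) would go last
shape = list(res.shape[: len(shuffle)])
for augm in augment:
    shape.insert(augm, 1)
print(jnp.reshape(res, shape).shape)    # (3, 1, 2)
```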

108 changes: 25 additions & 83 deletions pytensor/link/numba/dispatch/elemwise.py
@@ -4,6 +4,7 @@
import numba
import numpy as np
from numba.core.extending import overload
from numpy.lib.stride_tricks import as_strided

from pytensor.graph.op import Op
from pytensor.link.numba.dispatch import basic as numba_basic
@@ -411,97 +412,38 @@ def numba_funcify_CAReduce(op, node, **kwargs):

@numba_funcify.register(DimShuffle)
def numba_funcify_DimShuffle(op, node, **kwargs):
shuffle = tuple(op.shuffle)
transposition = tuple(op.transposition)
augment = tuple(op.augment)
inplace = op.inplace
# We use `as_strided` to achieve the DimShuffle behavior of transposing and expanding/squeezing dimensions in one call.
# Numba doesn't currently support multiple expand/squeeze, and reshape is limited to contiguous arrays.
new_order = tuple(op._new_order)
shape_template = (1,) * node.outputs[0].ndim
strides_template = (0,) * node.outputs[0].ndim

ndim_new_shape = len(shuffle) + len(augment)

no_transpose = all(i == j for i, j in enumerate(transposition))
if no_transpose:

@numba_basic.numba_njit
def transpose(x):
return x

else:
if new_order == ():
# Special case needed because of https://github.com/numba/numba/issues/9933

@numba_basic.numba_njit
def transpose(x):
return np.transpose(x, transposition)

shape_template = (1,) * ndim_new_shape

# When `len(shuffle) == 0`, the `shuffle_shape[j]` expression below
# is typed as `getitem(Tuple(), int)`, which has no implementation
# (since getting an item from an empty sequence doesn't make sense).
# To avoid this compile-time error, we omit the expression altogether.
if len(shuffle) > 0:
# Use the statically known shape if available
if all(length is not None for length in node.outputs[0].type.shape):
shape = node.outputs[0].type.shape

@numba_basic.numba_njit
def find_shape(array_shape):
return shape

else:

@numba_basic.numba_njit
def find_shape(array_shape):
shape = shape_template
j = 0
for i in range(ndim_new_shape):
if i not in augment:
length = array_shape[j]
shape = numba_basic.tuple_setitem(shape, i, length)
j = j + 1
return shape

else:

@numba_basic.numba_njit
def find_shape(array_shape):
return shape_template

if ndim_new_shape > 0:

@numba_basic.numba_njit
def dimshuffle_inner(x, shuffle):
x = transpose(x)
shuffle_shape = x.shape[: len(shuffle)]
new_shape = find_shape(shuffle_shape)

# FIXME: Numba's `array.reshape` only accepts C arrays.
res_reshape = np.reshape(np.ascontiguousarray(x), new_shape)
def squeeze_to_0d(x):
return as_strided(x, shape=(), strides=())

if not inplace:
return res_reshape.copy()
else:
return res_reshape
return squeeze_to_0d

else:

@numba_basic.numba_njit
def dimshuffle_inner(x, shuffle):
return np.reshape(np.ascontiguousarray(x), ())

# Without the following wrapper function we would see this error:
# E No implementation of function Function(<built-in function getitem>) found for signature:
# E
# E >>> getitem(UniTuple(int64 x 2), slice<a:b>)
# E
# E There are 22 candidate implementations:
# E - Of which 22 did not match due to:
# E Overload of function 'getitem': File: <numerous>: Line N/A.
# E With argument(s): '(UniTuple(int64 x 2), slice<a:b>)':
# E No match.
# ...(on this line)...
# E shuffle_shape = res.shape[: len(shuffle)]
@numba_basic.numba_njit(inline="always")
def dimshuffle(x):
return dimshuffle_inner(np.asarray(x), shuffle)
def dimshuffle(x):
old_shape = x.shape
old_strides = x.strides

new_shape = shape_template
new_strides = strides_template
for i, o in enumerate(new_order):
if o != -1:
new_shape = numba_basic.tuple_setitem(new_shape, i, old_shape[o])
new_strides = numba_basic.tuple_setitem(
new_strides, i, old_strides[o]
)

return as_strided(x, shape=new_shape, strides=new_strides)

return dimshuffle
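
A plain-NumPy sketch of the trick the new Numba kernel relies on (the helper name is made up for illustration): `op._new_order` encodes each output axis as either an input axis index or `-1` for an inserted broadcastable axis, and `as_strided` builds the transposed, expanded, and squeezed view in a single call, using stride 0 for inserted axes.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided


def dimshuffle_view(x, new_order):
    """Return a DimShuffle-style view of `x` without copying.

    `new_order[i] == -1` inserts a length-1 axis; input axes not listed in
    `new_order` are dropped and must have length 1.
    """
    new_shape = tuple(1 if o == -1 else x.shape[o] for o in new_order)
    new_strides = tuple(0 if o == -1 else x.strides[o] for o in new_order)
    return as_strided(x, shape=new_shape, strides=new_strides)


x = np.arange(6).reshape(1, 2, 3)
y = dimshuffle_view(x, (2, -1, 1))  # drop axis 0, swap the remaining axes, add one axis
print(y.shape)                      # (3, 1, 2)
print(np.shares_memory(x, y))       # True: the result is a view, no data is copied
```

The `new_order == ()` branch in the kernel exists only because of the 0-d corner case referenced in the diff above (numba/numba#9933).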

7 changes: 1 addition & 6 deletions pytensor/link/pytorch/dispatch/elemwise.py
@@ -61,12 +61,7 @@
for augm in op.augment:
shape.insert(augm, 1)

res = torch.reshape(res, shape)

if not op.inplace:
res = res.clone()

return res
return torch.reshape(res, shape)

Codecov warning: added line pytensor/link/pytorch/dispatch/elemwise.py#L64 was not covered by tests.

return dimshuffle
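
As with the JAX backend, the defensive `clone` is dropped because `DimShuffle` now always advertises itself as a view through `view_map = {0: [0]}`, so PyTensor handles any aliasing. A tiny sketch of the aliasing this permits: for a contiguous input, `torch.reshape` returns a view that shares memory with its argument.

```python
import torch

x = torch.arange(6)
y = torch.reshape(x, (2, 3))  # contiguous input: reshape returns a view
y[0, 0] = 42
print(x[0])                   # tensor(42): the output aliases the input
```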

143 changes: 71 additions & 72 deletions pytensor/tensor/c_code/dimshuffle.c
@@ -1,82 +1,81 @@
#section support_code_apply

int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res,
PARAMS_TYPE *params) {

// This points to either the original input or a copy we create below.
// Either way, this is what we should be working on/with.
PyArrayObject *_input;

if (*res)
Py_XDECREF(*res);

if (params->inplace) {
_input = input;
Py_INCREF((PyObject *)_input);
} else {
_input = (PyArrayObject *)PyArray_FromAny(
(PyObject *)input, NULL, 0, 0, NPY_ARRAY_ALIGNED | NPY_ARRAY_ENSURECOPY,
NULL);
}

PyArray_Dims permute;

if (!PyArray_IntpConverter((PyObject *)params->transposition, &permute)) {
return 1;
}

/*
res = res.transpose(self.transposition)
*/
PyArrayObject *transposed_input =
(PyArrayObject *)PyArray_Transpose(_input, &permute);

Py_DECREF(_input);

PyDimMem_FREE(permute.ptr);

npy_intp *res_shape = PyArray_DIMS(transposed_input);
npy_intp N_shuffle = PyArray_SIZE(params->shuffle);
npy_intp N_augment = PyArray_SIZE(params->augment);
npy_intp N = N_augment + N_shuffle;
npy_intp *_reshape_shape = PyDimMem_NEW(N);

if (_reshape_shape == NULL) {
PyErr_NoMemory();
return 1;
}
int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PARAMS_TYPE *params) {
npy_int64* new_order;
npy_intp nd_in;
npy_intp nd_out;
npy_intp* dimensions;
npy_intp* strides;

if (!PyArray_IS_C_CONTIGUOUS(params->_new_order)) {
PyErr_SetString(PyExc_RuntimeError, "DimShuffle: param _new_order must be C-contiguous.");
return 1;
}
new_order = (npy_int64*) PyArray_DATA(params->_new_order);
nd_in = (npy_intp)(params->input_ndim);
nd_out = PyArray_SIZE(params->_new_order);

/*
shape = list(res.shape[: len(self.shuffle)])
for augm in self.augment:
shape.insert(augm, 1)
*/
npy_intp aug_idx = 0;
int res_idx = 0;
for (npy_intp i = 0; i < N; i++) {
if (aug_idx < N_augment &&
i == *((npy_intp *)PyArray_GetPtr(params->augment, &aug_idx))) {
_reshape_shape[i] = 1;
aug_idx++;
} else {
_reshape_shape[i] = res_shape[res_idx];
res_idx++;
if (PyArray_NDIM(input) != nd_in) {
PyErr_SetString(PyExc_ValueError, "DimShuffle: Input has less dimensions than expected.");
return 1;
}
}

PyArray_Dims reshape_shape = {.ptr = _reshape_shape, .len = (int)N};
// Compute new dimensions and strides
dimensions = (npy_intp*) malloc(nd_out * sizeof(npy_intp));
strides = (npy_intp*) malloc(nd_out * sizeof(npy_intp));
Comment on lines +24 to +25 (Member Author):
Actually I don't know why we're using the externalCOp approach for this. It used to be a normal COp before, and then you didn't need this sort of alloc/params stuff...


if (dimensions == NULL || strides == NULL) {
PyErr_NoMemory();
free(dimensions);
free(strides);
return 1;
};

npy_intp original_size = PyArray_SIZE(input);
npy_intp new_size = 1;
for (npy_intp i = 0; i < nd_out; ++i) {
if (new_order[i] != -1) {
dimensions[i] = PyArray_DIMS(input)[new_order[i]];
strides[i] = PyArray_DIMS(input)[new_order[i]] == 1 ? 0 : PyArray_STRIDES(input)[new_order[i]];
} else {
dimensions[i] = 1;
strides[i] = 0;
}
new_size *= dimensions[i];
}

/* res = res.reshape(shape) */
*res = (PyArrayObject *)PyArray_Newshape(transposed_input, &reshape_shape,
NPY_CORDER);
if (original_size != new_size) {
PyErr_SetString(PyExc_ValueError, "DimShuffle: Attempting to squeeze axes with size not equal to one.");
free(dimensions);
free(strides);
return 1;
}

Py_DECREF(transposed_input);
if (*res)
Py_XDECREF(*res);

// Create the new array.
*res = (PyArrayObject*)PyArray_New(&PyArray_Type, nd_out, dimensions,
PyArray_TYPE(input), strides,
PyArray_DATA(input), PyArray_ITEMSIZE(input),
// borrow only the writable flag from the base
// the NPY_OWNDATA flag will default to 0.
(NPY_ARRAY_WRITEABLE * PyArray_ISWRITEABLE(input)),
NULL);

if (*res == NULL) {
free(dimensions);
free(strides);
return 1;
}

PyDimMem_FREE(reshape_shape.ptr);
// Declare it a view of the original input
Py_INCREF((PyObject*)input);
PyArray_SetBaseObject(*res, (PyObject*)input);

if (!*res) {
return 1;
}
// recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED
PyArray_UpdateFlags(*res, NPY_ARRAY_UPDATE_ALL);

return 0;
}
free(strides);
free(dimensions);
return 0;
}
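
Two details of the C code above are worth spelling out, sketched here in NumPy terms with illustrative names rather than the C variables: length-1 output axes get stride 0, which is harmless for a single element and makes the view broadcast-friendly, and because dropped axes simply never appear in `new_order`, comparing element counts is exactly the test for an illegal squeeze of an axis whose length is not 1.

```python
import numpy as np


def check_dimshuffle_sizes(x, new_order):
    # Mirrors the C size check: dropped axes contribute nothing to new_size,
    # so the products differ only if a dropped axis has length > 1.
    dimensions = [1 if o == -1 else x.shape[o] for o in new_order]
    if x.size != np.prod(dimensions, dtype=int):
        raise ValueError(
            "DimShuffle: Attempting to squeeze axes with size not equal to one."
        )


check_dimshuffle_sizes(np.zeros((1, 2, 3)), (2, 1))      # fine: dropped axis 0 has length 1
try:
    check_dimshuffle_sizes(np.zeros((4, 2, 3)), (2, 1))  # dropped axis 0 has length 4
except ValueError as exc:
    print(exc)
```
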
35 changes: 15 additions & 20 deletions pytensor/tensor/elemwise.py
@@ -19,9 +19,8 @@
from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.printing import Printer, pprint
from pytensor.scalar import get_scalar_type
from pytensor.scalar.basic import bool as scalar_bool
from pytensor.scalar.basic import identity as scalar_identity
from pytensor.scalar.basic import transfer_type, upcast
from pytensor.scalar.basic import int64, transfer_type, upcast
from pytensor.tensor import elemwise_cgen as cgen
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
@@ -114,17 +113,16 @@ class DimShuffle(ExternalCOp):

_f16_ok = True
check_input = False
__props__ = ("input_ndim", "new_order", "inplace")
__props__ = ("input_ndim", "new_order")
c_func_file = "c_code/dimshuffle.c"
c_func_name = "APPLY_SPECIFIC(cpu_dimshuffle)"
view_map = {0: [0]}

@property
def params_type(self):
return ParamsType(
shuffle=lvector,
augment=lvector,
transposition=lvector,
inplace=scalar_bool,
_new_order=lvector,
input_ndim=int64,
)

def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
@@ -135,7 +133,7 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):

self.input_ndim = input_ndim
self.new_order = tuple(new_order)
self.inplace = True
self._new_order = [(-1 if x == "x" else x) for x in self.new_order]

for i, j in enumerate(new_order):
if j != "x":
@@ -178,9 +176,6 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
:input_ndim
] == list(range(input_ndim))

if self.inplace:
self.view_map = {0: [0]}

def __setstate__(self, state):
self.__dict__.update(state)
if not hasattr(self, "func_files"):
@@ -231,10 +226,15 @@ def __str__(self):

def perform(self, node, inp, out):
(res,) = inp
(storage,) = out

if not isinstance(res, np.ndarray | np.memmap):
raise TypeError(res)
# This C-like impl is very slow in Python compared to transpose+reshape
# new_order = self._new_order
# old_shape = res.shape
# old_strides = res.strides
# res = as_strided(
#     res,
#     shape=[1 if i == -1 else old_shape[i] for i in new_order],
#     strides=[0 if i == -1 else old_strides[i] for i in new_order],
# )

# Put dropped axis at end
res = res.transpose(self.transposition)
@@ -243,12 +243,7 @@
new_shape = list(res.shape[: len(self.shuffle)])
for augm in self.augment:
new_shape.insert(augm, 1)
res = res.reshape(new_shape)

if not self.inplace:
res = np.copy(res)

storage[0] = np.asarray(res)
out[0][0] = res.reshape(new_shape)

def infer_shape(self, fgraph, node, shapes):
(ishp,) = shapes
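
To tie the Python-side attributes together, here is a small sketch (values chosen purely for illustration) of how a `new_order` containing `"x"` maps onto what the code above uses: `_new_order` replaces `"x"` with `-1` for the C and Numba kernels, while `perform` still works through `transposition` (kept axes first, dropped axes last) followed by a reshape that inserts the `augment` positions.

```python
import numpy as np

# A 3d input where axis 1 (length 1) is dropped:
new_order = (2, "x", 0)
input_ndim = 3

_new_order = [-1 if o == "x" else o for o in new_order]     # [2, -1, 0]
shuffle = [o for o in new_order if o != "x"]                # [2, 0]
augment = [i for i, o in enumerate(new_order) if o == "x"]  # [1]
drop = [i for i in range(input_ndim) if i not in shuffle]   # [1]
transposition = shuffle + drop                              # [2, 0, 1]

x = np.zeros((5, 1, 7))
res = x.transpose(transposition)             # dropped axis moved to the end
new_shape = list(res.shape[: len(shuffle)])  # [7, 5]
for augm in augment:
    new_shape.insert(augm, 1)
print(res.reshape(new_shape).shape)          # (7, 1, 5)
```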