Skip to content

Commit ebc4263

Browse files
committed
Revert regression in DimShuffle C-impl speed
Introduced in e593b0a due to a bug when inputs had zero strides. The bug can be fixed simply by removing a block that assumed some `full`/`broadcasting` behavior on the part of the operation — behavior that DimShuffle does not actually perform.
1 parent 003e268 commit ebc4263

File tree

3 files changed

+116
-74
lines changed

3 files changed

+116
-74
lines changed

pytensor/tensor/c_code/dimshuffle.c

+80-69
Original file line numberDiff line numberDiff line change
@@ -1,82 +1,93 @@
11
#section support_code_apply
22

3-
int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res,
4-
PARAMS_TYPE *params) {
5-
6-
// This points to either the original input or a copy we create below.
7-
// Either way, this is what we should be working on/with.
8-
PyArrayObject *_input;
9-
10-
if (*res)
11-
Py_XDECREF(*res);
12-
13-
if (params->inplace) {
14-
_input = input;
15-
Py_INCREF((PyObject *)_input);
16-
} else {
17-
_input = (PyArrayObject *)PyArray_FromAny(
18-
(PyObject *)input, NULL, 0, 0, NPY_ARRAY_ALIGNED | NPY_ARRAY_ENSURECOPY,
19-
NULL);
20-
}
21-
22-
PyArray_Dims permute;
23-
24-
if (!PyArray_IntpConverter((PyObject *)params->transposition, &permute)) {
25-
return 1;
26-
}
27-
28-
/*
29-
res = res.transpose(self.transposition)
30-
*/
31-
PyArrayObject *transposed_input =
32-
(PyArrayObject *)PyArray_Transpose(_input, &permute);
33-
34-
Py_DECREF(_input);
35-
36-
PyDimMem_FREE(permute.ptr);
3+
int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PARAMS_TYPE *params) {
4+
npy_int64* new_order;
5+
npy_intp nd_in;
6+
npy_intp nd_out;
7+
npy_intp* dimensions;
8+
npy_intp* strides;
9+
10+
// This points to either the original input or a copy we create below.
11+
// Either way, this is what we should be working on/with.
12+
PyArrayObject *_input;
13+
14+
if (!PyArray_IS_C_CONTIGUOUS(params->_new_order)) {
15+
PyErr_SetString(PyExc_RuntimeError, "DimShuffle: param _new_order must be C-contiguous.");
16+
return 1;
17+
}
18+
new_order = (npy_int64*) PyArray_DATA(params->_new_order);
19+
nd_in = (npy_intp)(params->input_ndim);
20+
nd_out = PyArray_SIZE(params->_new_order);
3721

38-
npy_intp *res_shape = PyArray_DIMS(transposed_input);
39-
npy_intp N_shuffle = PyArray_SIZE(params->shuffle);
40-
npy_intp N_augment = PyArray_SIZE(params->augment);
41-
npy_intp N = N_augment + N_shuffle;
42-
npy_intp *_reshape_shape = PyDimMem_NEW(N);
22+
if (PyArray_NDIM(input) != nd_in) {
23+
PyErr_SetString(PyExc_NotImplementedError, "DimShuffle: Input has less dimensions than expected.");
24+
return 1;
25+
}
4326

44-
if (_reshape_shape == NULL) {
45-
PyErr_NoMemory();
46-
return 1;
47-
}
27+
if (*res)
28+
Py_XDECREF(*res);
4829

49-
/*
50-
shape = list(res.shape[: len(self.shuffle)])
51-
for augm in self.augment:
52-
shape.insert(augm, 1)
53-
*/
54-
npy_intp aug_idx = 0;
55-
int res_idx = 0;
56-
for (npy_intp i = 0; i < N; i++) {
57-
if (aug_idx < N_augment &&
58-
i == *((npy_intp *)PyArray_GetPtr(params->augment, &aug_idx))) {
59-
_reshape_shape[i] = 1;
60-
aug_idx++;
30+
if (params->inplace) {
31+
_input = input;
32+
Py_INCREF((PyObject*)_input);
6133
} else {
62-
_reshape_shape[i] = res_shape[res_idx];
63-
res_idx++;
34+
_input = (PyArrayObject *)PyArray_FromAny(
35+
(PyObject *)input, NULL, 0, 0, NPY_ARRAY_ALIGNED | NPY_ARRAY_ENSURECOPY,
36+
NULL);
6437
}
65-
}
6638

67-
PyArray_Dims reshape_shape = {.ptr = _reshape_shape, .len = (int)N};
39+
// Compute new dimensions and strides
40+
dimensions = (npy_intp*) malloc(nd_out * sizeof(npy_intp));
41+
strides = (npy_intp*) malloc(nd_out * sizeof(npy_intp));
42+
if (dimensions == NULL || strides == NULL) {
43+
PyErr_NoMemory();
44+
free(dimensions);
45+
free(strides);
46+
return 1;
47+
};
48+
49+
npy_intp original_size = PyArray_SIZE(_input);
50+
npy_intp new_size = 1;
51+
for (npy_intp i = 0; i < nd_out; ++i) {
52+
if (new_order[i] != -1) {
53+
dimensions[i] = PyArray_DIMS(_input)[new_order[i]];
54+
strides[i] = PyArray_DIMS(_input)[new_order[i]] == 1 ? 0 : PyArray_STRIDES(_input)[new_order[i]];
55+
} else {
56+
dimensions[i] = 1;
57+
strides[i] = 0;
58+
}
59+
new_size *= dimensions[i];
60+
}
6861

69-
/* res = res.reshape(shape) */
70-
*res = (PyArrayObject *)PyArray_Newshape(transposed_input, &reshape_shape,
71-
NPY_CORDER);
62+
if (original_size != new_size) {
63+
PyErr_SetString(PyExc_ValueError, "DimShuffle: Attempting to squeeze axes with size not equal to one.");
64+
free(dimensions);
65+
free(strides);
66+
return 1;
67+
}
7268

73-
Py_DECREF(transposed_input);
69+
// Create the new array.
70+
*res = (PyArrayObject*)PyArray_New(&PyArray_Type, nd_out, dimensions,
71+
PyArray_TYPE(_input), strides,
72+
PyArray_DATA(_input), PyArray_ITEMSIZE(_input),
73+
// borrow only the writable flag from the base
74+
// the NPY_OWNDATA flag will default to 0.
75+
(NPY_ARRAY_WRITEABLE * PyArray_ISWRITEABLE(_input)),
76+
NULL);
77+
78+
if (*res == NULL) {
79+
free(dimensions);
80+
free(strides);
81+
return 1;
82+
}
7483

75-
PyDimMem_FREE(reshape_shape.ptr);
84+
// recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED
85+
PyArray_UpdateFlags(*res, NPY_ARRAY_UPDATE_ALL);
7686

77-
if (!*res) {
78-
return 1;
79-
}
87+
// we are making a view in both inplace and non-inplace cases
88+
PyArray_SetBaseObject(*res, (PyObject*)_input);
8089

81-
return 0;
82-
}
90+
free(strides);
91+
free(dimensions);
92+
return 0;
93+
}

pytensor/tensor/elemwise.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from pytensor.scalar import get_scalar_type
2222
from pytensor.scalar.basic import bool as scalar_bool
2323
from pytensor.scalar.basic import identity as scalar_identity
24-
from pytensor.scalar.basic import transfer_type, upcast
24+
from pytensor.scalar.basic import int64, transfer_type, upcast
2525
from pytensor.tensor import elemwise_cgen as cgen
2626
from pytensor.tensor import get_vector_length
2727
from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
@@ -121,12 +121,18 @@ class DimShuffle(ExternalCOp):
121121
@property
122122
def params_type(self):
123123
return ParamsType(
124-
shuffle=lvector,
125-
augment=lvector,
126-
transposition=lvector,
124+
_new_order=lvector,
127125
inplace=scalar_bool,
126+
input_ndim=int64,
128127
)
129128

129+
@property
130+
def _new_order(self):
131+
# Param for C code.
132+
# self.new_order may contain 'x', which is not a valid integer value.
133+
# We replace it with -1.
134+
return [(-1 if x == "x" else x) for x in self.new_order]
135+
130136
def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
131137
super().__init__([self.c_func_file], self.c_func_name)
132138

tests/tensor/test_elemwise.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import itertools
12
import math
23
import re
34
import tracemalloc
@@ -10,6 +11,7 @@
1011
import pytensor.scalar as ps
1112
import pytensor.tensor as pt
1213
import tests.unittest_tools as utt
14+
from pytensor import In, Out
1315
from pytensor.compile.function import function
1416
from pytensor.compile.mode import Mode
1517
from pytensor.configdefaults import config
@@ -35,6 +37,7 @@
3537
matrix,
3638
scalar,
3739
tensor,
40+
tensor3,
3841
vector,
3942
vectors,
4043
)
@@ -158,11 +161,14 @@ def test_c_views(self):
158161
# as the broadcasted value; that way, we'll be able to tell that we're getting
159162
# junk data from a poorly constructed array view.
160163
x_val = np.broadcast_to(2039, (5000,))
161-
for i in range(1000):
164+
expected_x_val = x_val[None]
165+
for i in range(1):
162166
inputs[0].storage[0] = x_val
163167
thunk()
164168
# Make sure it's a view of the original data
165169
assert np.shares_memory(x_val, outputs[0].storage[0])
170+
# Confirm the right strides
171+
assert outputs[0].storage[0].strides == expected_x_val.strides
166172
# Confirm the broadcasted value in the output
167173
assert np.array_equiv(outputs[0].storage[0], 2039)
168174

@@ -212,6 +218,25 @@ def test_valid_input_ndim(self):
212218
with pytest.raises(TypeError, match="input_ndim must be an integer"):
213219
DimShuffle(input_ndim=(True, False), new_order=(1, 0))
214220

221+
def test_benchmark(self, benchmark):
222+
x = tensor3("x")
223+
x_val = np.random.random((2, 3, 4)).astype(config.floatX)
224+
ys = [x.transpose(t) for t in itertools.permutations((0, 1, 2))]
225+
ys += [
226+
x[None],
227+
x[:, None],
228+
x[:, :, None],
229+
x[:, :, :, None],
230+
]
231+
# Borrow to avoid deepcopy overhead
232+
fn = pytensor.function(
233+
[In(x, borrow=True)],
234+
[Out(y, borrow=True) for y in ys],
235+
)
236+
fn.dprint()
237+
fn.trust_input = True
238+
benchmark(fn, x_val)
239+
215240

216241
class TestBroadcast:
217242
# this is to allow other types to reuse this class to test their ops

0 commit comments

Comments
 (0)