Skip to content

Commit 89d5366

Browse files
committed
Do not introduce 0 strides for broadcastable dimensions in DimShuffle
Some poorly implemented BLAS operations don't handle them correctly
1 parent bf628c9 commit 89d5366

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

Diff for: pytensor/tensor/c_code/dimshuffle.c

+7-2
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,17 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
3333
npy_intp original_size = PyArray_SIZE(input);
3434
npy_intp new_size = 1;
3535
for (npy_intp i = 0; i < nd_out; ++i) {
36+
// We set the strides of length 1 dimensions to PyArray_ITEMSIZE(input).
37+
// The value is arbitrary, because there is never a next element.
38+
// np.expand_dims(x, 0) and x[None] do different things here.
39+
// I would prefer zero, but there are some poorly implemented BLAS operations
40+
// That don't handle zero strides correctly. At least they won't fail because of DimShuffle.
3641
if (new_order[i] != -1) {
3742
dimensions[i] = PyArray_DIMS(input)[new_order[i]];
38-
strides[i] = PyArray_DIMS(input)[new_order[i]] == 1 ? 0 : PyArray_STRIDES(input)[new_order[i]];
43+
strides[i] = PyArray_DIMS(input)[new_order[i]] == 1 ? PyArray_ITEMSIZE(input) : PyArray_STRIDES(input)[new_order[i]];
3944
} else {
4045
dimensions[i] = 1;
41-
strides[i] = 0;
46+
strides[i] = PyArray_ITEMSIZE(input);
4247
}
4348
new_size *= dimensions[i];
4449
}

Diff for: tests/tensor/test_elemwise.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -185,14 +185,13 @@ def test_c_views(self):
185185
# as the broadcasted value; that way, we'll be able to tell that we're getting
186186
# junk data from a poorly constructed array view.
187187
x_val = np.broadcast_to(2039, (5000,))
188-
expected_x_val = x_val[None]
189188
for i in range(1):
190189
inputs[0].storage[0] = x_val
191190
thunk()
192191
# Make sure it's a view of the original data
193192
assert np.shares_memory(x_val, outputs[0].storage[0])
194193
# Confirm the right strides
195-
assert outputs[0].storage[0].strides == expected_x_val.strides
194+
assert outputs[0].storage[0].strides[-1] == 0
196195
# Confirm the broadcasted value in the output
197196
assert np.array_equiv(outputs[0].storage[0], 2039)
198197

0 commit comments

Comments
 (0)