Skip to content

Commit 28387c1

Browse files
committed
Remove unused inplace option in DimShuffle
1 parent 355f4ee commit 28387c1

File tree

5 files changed

+14
-49
lines changed

5 files changed

+14
-49
lines changed

pytensor/link/jax/dispatch/elemwise.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,7 @@ def dimshuffle(x):
7979
for augm in op.augment:
8080
shape.insert(augm, 1)
8181

82-
res = jnp.reshape(res, shape)
83-
84-
if not op.inplace:
85-
res = jnp.copy(res)
86-
87-
return res
82+
return jnp.reshape(res, shape)
8883

8984
return dimshuffle
9085

pytensor/link/numba/dispatch/elemwise.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,6 @@ def numba_funcify_DimShuffle(op, node, **kwargs):
414414
shuffle = tuple(op.shuffle)
415415
transposition = tuple(op.transposition)
416416
augment = tuple(op.augment)
417-
inplace = op.inplace
418417

419418
ndim_new_shape = len(shuffle) + len(augment)
420419

@@ -474,12 +473,7 @@ def dimshuffle_inner(x, shuffle):
474473
new_shape = find_shape(shuffle_shape)
475474

476475
# FIXME: Numba's `array.reshape` only accepts C arrays.
477-
res_reshape = np.reshape(np.ascontiguousarray(x), new_shape)
478-
479-
if not inplace:
480-
return res_reshape.copy()
481-
else:
482-
return res_reshape
476+
return np.reshape(np.ascontiguousarray(x), new_shape)
483477

484478
else:
485479

pytensor/link/pytorch/dispatch/elemwise.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,7 @@ def dimshuffle(x):
6161
for augm in op.augment:
6262
shape.insert(augm, 1)
6363

64-
res = torch.reshape(res, shape)
65-
66-
if not op.inplace:
67-
res = res.clone()
68-
69-
return res
64+
return torch.reshape(res, shape)
7065

7166
return dimshuffle
7267

pytensor/tensor/c_code/dimshuffle.c

+8-18
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
77
npy_intp* dimensions;
88
npy_intp* strides;
99

10-
// This points to either the original input or a copy we create below.
11-
// Either way, this is what we should be working on/with.
12-
PyArrayObject *_input;
1310

1411
if (!PyArray_IS_C_CONTIGUOUS(params->_new_order)) {
1512
PyErr_SetString(PyExc_RuntimeError, "DimShuffle: param _new_order must be C-contiguous.");
@@ -34,12 +31,12 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
3431
return 1;
3532
};
3633

37-
npy_intp original_size = PyArray_SIZE(_input);
34+
npy_intp original_size = PyArray_SIZE(input);
3835
npy_intp new_size = 1;
3936
for (npy_intp i = 0; i < nd_out; ++i) {
4037
if (new_order[i] != -1) {
41-
dimensions[i] = PyArray_DIMS(_input)[new_order[i]];
42-
strides[i] = PyArray_DIMS(_input)[new_order[i]] == 1 ? 0 : PyArray_STRIDES(_input)[new_order[i]];
38+
dimensions[i] = PyArray_DIMS(input)[new_order[i]];
39+
strides[i] = PyArray_DIMS(input)[new_order[i]] == 1 ? 0 : PyArray_STRIDES(input)[new_order[i]];
4340
} else {
4441
dimensions[i] = 1;
4542
strides[i] = 0;
@@ -57,22 +54,15 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
5754
if (*res)
5855
Py_XDECREF(*res);
5956

60-
if (params->inplace) {
61-
_input = input;
62-
Py_INCREF((PyObject*)_input);
63-
} else {
64-
_input = (PyArrayObject *)PyArray_FromAny(
65-
(PyObject *)input, NULL, 0, 0, NPY_ARRAY_ALIGNED | NPY_ARRAY_ENSURECOPY,
66-
NULL);
67-
}
57+
Py_INCREF((PyObject*)input);
6858

6959
// Create the new array.
7060
*res = (PyArrayObject*)PyArray_New(&PyArray_Type, nd_out, dimensions,
71-
PyArray_TYPE(_input), strides,
72-
PyArray_DATA(_input), PyArray_ITEMSIZE(_input),
61+
PyArray_TYPE(input), strides,
62+
PyArray_DATA(input), PyArray_ITEMSIZE(input),
7363
// borrow only the writable flag from the base
7464
// the NPY_OWNDATA flag will default to 0.
75-
(NPY_ARRAY_WRITEABLE * PyArray_ISWRITEABLE(_input)),
65+
(NPY_ARRAY_WRITEABLE * PyArray_ISWRITEABLE(input)),
7666
NULL);
7767

7868
if (*res == NULL) {
@@ -85,7 +75,7 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
8575
PyArray_UpdateFlags(*res, NPY_ARRAY_UPDATE_ALL);
8676

8777
// we are making a view in both inplace and non-inplace cases
88-
PyArray_SetBaseObject(*res, (PyObject*)_input);
78+
PyArray_SetBaseObject(*res, (PyObject*)input);
8979

9080
free(strides);
9181
free(dimensions);

pytensor/tensor/elemwise.py

+3-12
Original file line numberDiff line numberDiff line change
@@ -114,15 +114,15 @@ class DimShuffle(ExternalCOp):
114114

115115
_f16_ok = True
116116
check_input = False
117-
__props__ = ("input_ndim", "new_order", "inplace")
117+
__props__ = ("input_ndim", "new_order")
118118
c_func_file = "c_code/dimshuffle.c"
119119
c_func_name = "APPLY_SPECIFIC(cpu_dimshuffle)"
120+
view_map = {0: [0]}
120121

121122
@property
122123
def params_type(self):
123124
return ParamsType(
124125
_new_order=lvector,
125-
inplace=scalar_bool,
126126
input_ndim=int64,
127127
)
128128

@@ -135,7 +135,6 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
135135
self.input_ndim = input_ndim
136136
self.new_order = tuple(new_order)
137137
self._new_order = [(-1 if x == "x" else x) for x in self.new_order]
138-
self.inplace = True
139138

140139
for i, j in enumerate(new_order):
141140
if j != "x":
@@ -178,9 +177,6 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
178177
:input_ndim
179178
] == list(range(input_ndim))
180179

181-
if self.inplace:
182-
self.view_map = {0: [0]}
183-
184180
def __setstate__(self, state):
185181
self.__dict__.update(state)
186182
if not hasattr(self, "func_files"):
@@ -248,12 +244,7 @@ def perform(self, node, inp, out):
248244
new_shape = list(res.shape[: len(self.shuffle)])
249245
for augm in self.augment:
250246
new_shape.insert(augm, 1)
251-
res = res.reshape(new_shape)
252-
253-
if not self.inplace:
254-
res = np.copy(res)
255-
256-
out[0][0] = res
247+
out[0][0] = res.reshape(new_shape)
257248

258249
def infer_shape(self, fgraph, node, shapes):
259250
(ishp,) = shapes

0 commit comments

Comments
 (0)