@@ -7,10 +7,6 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
7
7
npy_intp * dimensions ;
8
8
npy_intp * strides ;
9
9
10
- // This points to either the original input or a copy we create below.
11
- // Either way, this is what we should be working on/with.
12
- PyArrayObject * _input ;
13
-
14
10
if (!PyArray_IS_C_CONTIGUOUS (params -> _new_order )) {
15
11
PyErr_SetString (PyExc_RuntimeError , "DimShuffle: param _new_order must be C-contiguous." );
16
12
return 1 ;
@@ -20,7 +16,7 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
20
16
nd_out = PyArray_SIZE (params -> _new_order );
21
17
22
18
if (PyArray_NDIM (input ) != nd_in ) {
23
- PyErr_SetString (PyExc_NotImplementedError , "DimShuffle: Input has less dimensions than expected." );
19
+ PyErr_SetString (PyExc_ValueError , "DimShuffle: Input has less dimensions than expected." );
24
20
return 1 ;
25
21
}
26
22
@@ -34,12 +30,12 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
34
30
return 1 ;
35
31
};
36
32
37
- npy_intp original_size = PyArray_SIZE (_input );
33
+ npy_intp original_size = PyArray_SIZE (input );
38
34
npy_intp new_size = 1 ;
39
35
for (npy_intp i = 0 ; i < nd_out ; ++ i ) {
40
36
if (new_order [i ] != -1 ) {
41
- dimensions [i ] = PyArray_DIMS (_input )[new_order [i ]];
42
- strides [i ] = PyArray_DIMS (_input )[new_order [i ]] == 1 ? 0 : PyArray_STRIDES (_input )[new_order [i ]];
37
+ dimensions [i ] = PyArray_DIMS (input )[new_order [i ]];
38
+ strides [i ] = PyArray_DIMS (input )[new_order [i ]] == 1 ? 0 : PyArray_STRIDES (input )[new_order [i ]];
43
39
} else {
44
40
dimensions [i ] = 1 ;
45
41
strides [i ] = 0 ;
@@ -57,22 +53,13 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
57
53
if (* res )
58
54
Py_XDECREF (* res );
59
55
60
- if (params -> inplace ) {
61
- _input = input ;
62
- Py_INCREF ((PyObject * )_input );
63
- } else {
64
- _input = (PyArrayObject * )PyArray_FromAny (
65
- (PyObject * )input , NULL , 0 , 0 , NPY_ARRAY_ALIGNED | NPY_ARRAY_ENSURECOPY ,
66
- NULL );
67
- }
68
-
69
56
// Create the new array.
70
57
* res = (PyArrayObject * )PyArray_New (& PyArray_Type , nd_out , dimensions ,
71
- PyArray_TYPE (_input ), strides ,
72
- PyArray_DATA (_input ), PyArray_ITEMSIZE (_input ),
58
+ PyArray_TYPE (input ), strides ,
59
+ PyArray_DATA (input ), PyArray_ITEMSIZE (input ),
73
60
// borrow only the writable flag from the base
74
61
// the NPY_OWNDATA flag will default to 0.
75
- (NPY_ARRAY_WRITEABLE * PyArray_ISWRITEABLE (_input )),
62
+ (NPY_ARRAY_WRITEABLE * PyArray_ISWRITEABLE (input )),
76
63
NULL );
77
64
78
65
if (* res == NULL ) {
@@ -81,12 +68,13 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
81
68
return 1 ;
82
69
}
83
70
71
+ // Declare it a view of the original input
72
+ Py_INCREF ((PyObject * )input );
73
+ PyArray_SetBaseObject (* res , (PyObject * )input );
74
+
84
75
// recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED
85
76
PyArray_UpdateFlags (* res , NPY_ARRAY_UPDATE_ALL );
86
77
87
- // we are making a view in both inplace and non-inplace cases
88
- PyArray_SetBaseObject (* res , (PyObject * )_input );
89
-
90
78
free (strides );
91
79
free (dimensions );
92
80
return 0 ;
0 commit comments