1
1
#section support_code_apply
2
2
3
- int APPLY_SPECIFIC (cpu_dimshuffle )(PyArrayObject * input , PyArrayObject * * res ,
4
- PARAMS_TYPE * params ) {
5
-
6
- // This points to either the original input or a copy we create below.
7
- // Either way, this is what we should be working on/with.
8
- PyArrayObject * _input ;
9
-
10
- if (* res )
11
- Py_XDECREF (* res );
12
-
13
- if (params -> inplace ) {
14
- _input = input ;
15
- Py_INCREF ((PyObject * )_input );
16
- } else {
17
- _input = (PyArrayObject * )PyArray_FromAny (
18
- (PyObject * )input , NULL , 0 , 0 , NPY_ARRAY_ALIGNED | NPY_ARRAY_ENSURECOPY ,
19
- NULL );
20
- }
21
-
22
- PyArray_Dims permute ;
23
-
24
- if (!PyArray_IntpConverter ((PyObject * )params -> transposition , & permute )) {
25
- return 1 ;
26
- }
27
-
28
- /*
29
- res = res.transpose(self.transposition)
30
- */
31
- PyArrayObject * transposed_input =
32
- (PyArrayObject * )PyArray_Transpose (_input , & permute );
33
-
34
- Py_DECREF (_input );
35
-
36
- PyDimMem_FREE (permute .ptr );
3
+ int APPLY_SPECIFIC (cpu_dimshuffle )(PyArrayObject * input , PyArrayObject * * res , PARAMS_TYPE * params ) {
4
+ npy_int64 * new_order ;
5
+ npy_intp nd_in ;
6
+ npy_intp nd_out ;
7
+ npy_intp * dimensions ;
8
+ npy_intp * strides ;
9
+
10
+ // This points to either the original input or a copy we create below.
11
+ // Either way, this is what we should be working on/with.
12
+ PyArrayObject * _input ;
13
+
14
+ if (!PyArray_IS_C_CONTIGUOUS (params -> _new_order )) {
15
+ PyErr_SetString (PyExc_RuntimeError , "DimShuffle: param _new_order must be C-contiguous." );
16
+ return 1 ;
17
+ }
18
+ new_order = (npy_int64 * ) PyArray_DATA (params -> _new_order );
19
+ nd_in = (npy_intp )(params -> input_ndim );
20
+ nd_out = PyArray_SIZE (params -> _new_order );
37
21
38
- npy_intp * res_shape = PyArray_DIMS (transposed_input );
39
- npy_intp N_shuffle = PyArray_SIZE (params -> shuffle );
40
- npy_intp N_augment = PyArray_SIZE (params -> augment );
41
- npy_intp N = N_augment + N_shuffle ;
42
- npy_intp * _reshape_shape = PyDimMem_NEW (N );
22
+ if (PyArray_NDIM (input ) != nd_in ) {
23
+ PyErr_SetString (PyExc_NotImplementedError , "DimShuffle: Input has less dimensions than expected." );
24
+ return 1 ;
25
+ }
43
26
44
- if (_reshape_shape == NULL ) {
45
- PyErr_NoMemory ();
46
- return 1 ;
47
- }
27
+ if (* res )
28
+ Py_XDECREF (* res );
48
29
49
- /*
50
- shape = list(res.shape[: len(self.shuffle)])
51
- for augm in self.augment:
52
- shape.insert(augm, 1)
53
- */
54
- npy_intp aug_idx = 0 ;
55
- int res_idx = 0 ;
56
- for (npy_intp i = 0 ; i < N ; i ++ ) {
57
- if (aug_idx < N_augment &&
58
- i == * ((npy_intp * )PyArray_GetPtr (params -> augment , & aug_idx ))) {
59
- _reshape_shape [i ] = 1 ;
60
- aug_idx ++ ;
30
+ if (params -> inplace ) {
31
+ _input = input ;
32
+ Py_INCREF ((PyObject * )_input );
61
33
} else {
62
- _reshape_shape [i ] = res_shape [res_idx ];
63
- res_idx ++ ;
34
+ _input = (PyArrayObject * )PyArray_FromAny (
35
+ (PyObject * )input , NULL , 0 , 0 , NPY_ARRAY_ALIGNED | NPY_ARRAY_ENSURECOPY ,
36
+ NULL );
64
37
}
65
- }
66
38
67
- PyArray_Dims reshape_shape = {.ptr = _reshape_shape , .len = (int )N };
39
+ // Compute new dimensions and strides
40
+ dimensions = (npy_intp * ) malloc (nd_out * sizeof (npy_intp ));
41
+ strides = (npy_intp * ) malloc (nd_out * sizeof (npy_intp ));
42
+ if (dimensions == NULL || strides == NULL ) {
43
+ PyErr_NoMemory ();
44
+ free (dimensions );
45
+ free (strides );
46
+ return 1 ;
47
+ };
48
+
49
+ npy_intp original_size = PyArray_SIZE (_input );
50
+ npy_intp new_size = 1 ;
51
+ for (npy_intp i = 0 ; i < nd_out ; ++ i ) {
52
+ if (new_order [i ] != -1 ) {
53
+ dimensions [i ] = PyArray_DIMS (_input )[new_order [i ]];
54
+ strides [i ] = PyArray_DIMS (_input )[new_order [i ]] == 1 ? 0 : PyArray_STRIDES (_input )[new_order [i ]];
55
+ } else {
56
+ dimensions [i ] = 1 ;
57
+ strides [i ] = 0 ;
58
+ }
59
+ new_size *= dimensions [i ];
60
+ }
68
61
69
- /* res = res.reshape(shape) */
70
- * res = (PyArrayObject * )PyArray_Newshape (transposed_input , & reshape_shape ,
71
- NPY_CORDER );
62
+ if (original_size != new_size ) {
63
+ PyErr_SetString (PyExc_ValueError , "DimShuffle: Attempting to squeeze axes with size not equal to one." );
64
+ free (dimensions );
65
+ free (strides );
66
+ return 1 ;
67
+ }
72
68
73
- Py_DECREF (transposed_input );
69
+ // Create the new array.
70
+ * res = (PyArrayObject * )PyArray_New (& PyArray_Type , nd_out , dimensions ,
71
+ PyArray_TYPE (_input ), strides ,
72
+ PyArray_DATA (_input ), PyArray_ITEMSIZE (_input ),
73
+ // borrow only the writable flag from the base
74
+ // the NPY_OWNDATA flag will default to 0.
75
+ (NPY_ARRAY_WRITEABLE * PyArray_ISWRITEABLE (_input )),
76
+ NULL );
77
+
78
+ if (* res == NULL ) {
79
+ free (dimensions );
80
+ free (strides );
81
+ return 1 ;
82
+ }
74
83
75
- PyDimMem_FREE (reshape_shape .ptr );
84
+ // recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED
85
+ PyArray_UpdateFlags (* res , NPY_ARRAY_UPDATE_ALL );
76
86
77
- if (!* res ) {
78
- return 1 ;
79
- }
87
+ // we are making a view in both inplace and non-inplace cases
88
+ PyArray_SetBaseObject (* res , (PyObject * )_input );
80
89
81
- return 0 ;
82
- }
90
+ free (strides );
91
+ free (dimensions );
92
+ return 0 ;
93
+ }
0 commit comments