@@ -157,11 +157,11 @@ cdef int _datacopied(cnp.ndarray arr, object orig):
157
157
158
158
159
159
def fft (x , n = None , axis = - 1 , overwrite_x = False , fwd_scale = 1.0 ):
160
- return _fft1d_impl (x , n = n , axis = axis , overwrite_arg = overwrite_x , direction = + 1 , fsc = fwd_scale )
160
+ return _fft1d_impl (x , n = n , axis = axis , overwrite_x = overwrite_x , direction = + 1 , fsc = fwd_scale )
161
161
162
162
163
163
def ifft (x , n = None , axis = - 1 , overwrite_x = False , fwd_scale = 1.0 ):
164
- return _fft1d_impl (x , n = n , axis = axis , overwrite_arg = overwrite_x , direction = - 1 , fsc = fwd_scale )
164
+ return _fft1d_impl (x , n = n , axis = axis , overwrite_x = overwrite_x , direction = - 1 , fsc = fwd_scale )
165
165
166
166
167
167
cdef cnp .ndarray pad_array (cnp .ndarray x_arr , cnp .npy_intp n , int axis , int realQ ):
@@ -200,7 +200,7 @@ cdef cnp.ndarray pad_array(cnp.ndarray x_arr, cnp.npy_intp n, int axis, int real
200
200
201
201
202
202
cdef cnp .ndarray __process_arguments (object x , object n , object axis ,
203
- object overwrite_arg , object direction ,
203
+ object overwrite_x , object direction ,
204
204
long * axis_ , long * n_ , int * in_place ,
205
205
int * xnd , int * dir_ , int realQ ):
206
206
"Internal utility to validate and process input arguments of 1D FFT functions"
@@ -213,7 +213,7 @@ cdef cnp.ndarray __process_arguments(object x, object n, object axis,
213
213
else :
214
214
dir_ [0 ] = - 1 if direction is - 1 else + 1
215
215
216
- in_place [0 ] = 1 if overwrite_arg is True else 0
216
+ in_place [0 ] = 1 if overwrite_x else 0
217
217
218
218
# convert x to ndarray, ensure that strides are multiples of itemsize
219
219
x_arr = PyArray_CheckFromAny (
@@ -294,7 +294,7 @@ cdef cnp.ndarray __allocate_result(cnp.ndarray x_arr, long n_, long axis_, int f
294
294
# Float/double inputs are not cast to complex, but are effectively
295
295
# treated as complexes with zero imaginary parts.
296
296
# All other types are cast to complex double.
297
- def _fft1d_impl (x , n = None , axis = - 1 , overwrite_arg = False , direction = + 1 , double fsc = 1.0 ):
297
+ def _fft1d_impl (x , n = None , axis = - 1 , overwrite_x = False , direction = + 1 , double fsc = 1.0 ):
298
298
"""
299
299
Uses MKL to perform 1D FFT on the input array x along the given axis.
300
300
"""
@@ -308,7 +308,7 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1, double fs
308
308
cdef bytes py_error_msg
309
309
cdef DftiCache * _cache
310
310
311
- x_arr = __process_arguments (x , n , axis , overwrite_arg , direction ,
311
+ x_arr = __process_arguments (x , n , axis , overwrite_x , direction ,
312
312
& axis_ , & n_ , & in_place , & xnd , & dir_ , 0 )
313
313
314
314
x_type = cnp .PyArray_TYPE (x_arr )
@@ -410,12 +410,12 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1, double fs
410
410
411
411
def rfftpack (x , n = None , axis = - 1 , overwrite_x = False , fwd_scale = 1.0 ):
412
412
"""Packed real-valued harmonics of FFT of a real sequence x"""
413
- return _rr_fft1d_impl2 (x , n = n , axis = axis , overwrite_arg = overwrite_x , fsc = fwd_scale )
413
+ return _rr_fft1d_impl2 (x , n = n , axis = axis , overwrite_x = overwrite_x , fsc = fwd_scale )
414
414
415
415
416
416
def irfftpack (x , n = None , axis = - 1 , overwrite_x = False , fwd_scale = 1.0 ):
417
417
"""Inverse FFT of a real sequence, takes packed real-valued harmonics of FFT"""
418
- return _rr_ifft1d_impl2 (x , n = n , axis = axis , overwrite_arg = overwrite_x , fsc = fwd_scale )
418
+ return _rr_ifft1d_impl2 (x , n = n , axis = axis , overwrite_x = overwrite_x , fsc = fwd_scale )
419
419
420
420
421
421
cdef object _rc_to_rr (cnp .ndarray rc_arr , int n , int axis , int xnd , int x_type ):
@@ -520,12 +520,12 @@ def _repack_rc_to_rr(x, n, axis):
520
520
return _rc_to_rr (x , n_ , axis_ , cnp .PyArray_NDIM (x_arr ), x_type )
521
521
522
522
523
- def _rr_fft1d_impl2 (x , n = None , axis = - 1 , overwrite_arg = False , double fsc = 1.0 ):
523
+ def _rr_fft1d_impl2 (x , n = None , axis = - 1 , overwrite_x = False , double fsc = 1.0 ):
524
524
"""
525
525
Uses MKL to perform real packed 1D FFT on the input array x along the given axis.
526
526
527
527
This done by using rfft and post-processing the result.
528
- Thus overwrite_arg is effectively discarded.
528
+ Thus overwrite_x is effectively discarded.
529
529
530
530
Functionally equivalent to scipy.fftpack.rfft
531
531
"""
@@ -539,7 +539,7 @@ def _rr_fft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
539
539
cdef bytes py_error_msg
540
540
cdef DftiCache * _cache
541
541
542
- x_arr = __process_arguments (x , n , axis , overwrite_arg , < object > (+ 1 ),
542
+ x_arr = __process_arguments (x , n , axis , overwrite_x , < object > (+ 1 ),
543
543
& axis_ , & n_ , & in_place , & xnd , & dir_ , 1 )
544
544
545
545
x_type = cnp .PyArray_TYPE (x_arr )
@@ -576,12 +576,12 @@ def _rr_fft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
576
576
return _rc_to_rr (f_arr , n_ , axis_ , xnd , x_type )
577
577
578
578
579
- def _rr_ifft1d_impl2 (x , n = None , axis = - 1 , overwrite_arg = False , double fsc = 1.0 ):
579
+ def _rr_ifft1d_impl2 (x , n = None , axis = - 1 , overwrite_x = False , double fsc = 1.0 ):
580
580
"""
581
581
Uses MKL to perform real packed 1D FFT on the input array x along the given axis.
582
582
583
583
This done by using rfft and post-processing the result.
584
- Thus overwrite_arg is effectively discarded.
584
+ Thus overwrite_x is effectively discarded.
585
585
586
586
Functionally equivalent to scipy.fftpack.irfft
587
587
"""
@@ -595,7 +595,7 @@ def _rr_ifft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
595
595
cdef bytes py_error_msg
596
596
cdef DftiCache * _cache
597
597
598
- x_arr = __process_arguments (x , n , axis , overwrite_arg , < object > (- 1 ),
598
+ x_arr = __process_arguments (x , n , axis , overwrite_x , < object > (- 1 ),
599
599
& axis_ , & n_ , & in_place , & xnd , & dir_ , 1 )
600
600
601
601
x_type = cnp .PyArray_TYPE (x_arr )
@@ -645,7 +645,7 @@ def _rr_ifft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
645
645
646
646
647
647
# this routine is functionally equivalent to numpy.fft.rfft
648
- def _rc_fft1d_impl (x , n = None , axis = - 1 , overwrite_arg = False , double fsc = 1.0 ):
648
+ def _rc_fft1d_impl (x , n = None , axis = - 1 , overwrite_x = False , double fsc = 1.0 ):
649
649
"""
650
650
Uses MKL to perform 1D FFT on the real input array x along the given axis,
651
651
producing complex output, but giving only half of the harmonics.
@@ -663,13 +663,13 @@ def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
663
663
cdef bytes py_error_msg
664
664
cdef DftiCache * _cache
665
665
666
- x_arr = __process_arguments (x , n , axis , overwrite_arg , direction ,
666
+ x_arr = __process_arguments (x , n , axis , overwrite_x , direction ,
667
667
& axis_ , & n_ , & in_place , & xnd , & dir_ , 1 )
668
668
669
669
x_type = cnp .PyArray_TYPE (x_arr )
670
670
671
671
if x_type is cnp .NPY_CFLOAT or x_type is cnp .NPY_CDOUBLE or x_type is cnp .NPY_CLONGDOUBLE :
672
- raise TypeError ("1st argument must be a real sequence 1 " )
672
+ raise TypeError ("1st argument must be a real sequence. " )
673
673
elif x_type is cnp .NPY_FLOAT or x_type is cnp .NPY_DOUBLE :
674
674
pass
675
675
else :
@@ -723,7 +723,7 @@ cdef int _is_integral(object num):
723
723
724
724
725
725
# this routine is functionally equivalent to numpy.fft.irfft
726
- def _rc_ifft1d_impl (x , n = None , axis = - 1 , overwrite_arg = False , double fsc = 1.0 ):
726
+ def _rc_ifft1d_impl (x , n = None , axis = - 1 , overwrite_x = False , double fsc = 1.0 ):
727
727
"""
728
728
Uses MKL to perform 1D FFT on the real input array x along the given axis,
729
729
producing complex output, but giving only half of the harmonics.
@@ -743,7 +743,7 @@ def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
743
743
int_n = _is_integral (n )
744
744
# nn gives the number elements along axis of the input that we use
745
745
nn = (n // 2 + 1 ) if int_n and n > 0 else n
746
- x_arr = __process_arguments (x , nn , axis , overwrite_arg , direction ,
746
+ x_arr = __process_arguments (x , nn , axis , overwrite_x , direction ,
747
747
& axis_ , & n_ , & in_place , & xnd , & dir_ , 0 )
748
748
n_ = 2 * (n_ - 1 )
749
749
if int_n and (n % 2 == 1 ):
@@ -907,10 +907,10 @@ def _cook_nd_args(a, s=None, axes=None, invreal=0):
907
907
return s , axes
908
908
909
909
910
- def _iter_fftnd (a , s = None , axes = None , function = fft , overwrite_arg = False , scale_function = lambda n , ind : 1.0 ):
910
+ def _iter_fftnd (a , s = None , axes = None , function = fft , overwrite_x = False , scale_function = lambda n , ind : 1.0 ):
911
911
a = np .asarray (a )
912
912
s , axes = _init_nd_shape_and_axes (a , s , axes )
913
- ovwr = overwrite_arg
913
+ ovwr = overwrite_x
914
914
for ii in reversed (range (len (axes ))):
915
915
a = function (a , n = s [ii ], axis = axes [ii ], overwrite_x = ovwr , fwd_scale = scale_function (s [ii ], ii ))
916
916
ovwr = True
@@ -959,7 +959,7 @@ def iter_complementary(x, axes, func, kwargs, result):
959
959
return result
960
960
961
961
962
- def _direct_fftnd (x , overwrite_arg = False , direction = + 1 , double fsc = 1.0 ):
962
+ def _direct_fftnd (x , overwrite_x = False , direction = + 1 , double fsc = 1.0 ):
963
963
"""Perform n-dimensional FFT over all axes"""
964
964
cdef int err
965
965
cdef long n_max = 0
@@ -972,7 +972,7 @@ def _direct_fftnd(x, overwrite_arg=False, direction=+1, double fsc=1.0):
972
972
else :
973
973
dir_ = - 1 if direction is - 1 else + 1
974
974
975
- in_place = 1 if overwrite_arg is True else 0
975
+ in_place = 1 if overwrite_x else 0
976
976
977
977
# convert x to ndarray, ensure that strides are multiples of itemsize
978
978
x_arr = PyArray_CheckFromAny (
@@ -1069,56 +1069,56 @@ def _output_dtype(dt):
1069
1069
return dt
1070
1070
1071
1071
1072
- def _fftnd_impl (x , shape = None , axes = None , overwrite_x = False , direction = + 1 , double fsc = 1.0 ):
1072
+ def _fftnd_impl (x , s = None , axes = None , overwrite_x = False , direction = + 1 , double fsc = 1.0 ):
1073
1073
if direction not in [- 1 , + 1 ]:
1074
1074
raise ValueError ("Direction of FFT should +1 or -1" )
1075
1075
1076
1076
# _direct_fftnd requires complex type, and full-dimensional transform
1077
1077
if isinstance (x , np .ndarray ) and x .size != 0 and x .ndim > 1 :
1078
- _direct = shape is None and axes is None
1078
+ _direct = s is None and axes is None
1079
1079
if _direct :
1080
1080
_direct = x .ndim <= 7 # Intel MKL only supports FFT up to 7D
1081
1081
if not _direct :
1082
- xs , xa = _cook_nd_args (x , shape , axes )
1082
+ xs , xa = _cook_nd_args (x , s , axes )
1083
1083
if _check_shapes_for_direct (xs , x .shape , xa ):
1084
1084
_direct = True
1085
1085
_direct = _direct and x .dtype in [np .complex64 , np .complex128 , np .float32 , np .float64 ]
1086
1086
else :
1087
1087
_direct = False
1088
1088
1089
1089
if _direct :
1090
- return _direct_fftnd (x , overwrite_arg = overwrite_x , direction = direction , fsc = fsc )
1090
+ return _direct_fftnd (x , overwrite_x = overwrite_x , direction = direction , fsc = fsc )
1091
1091
else :
1092
- if (shape is None and x .dtype in [np .csingle , np .cdouble , np .single , np .double ]):
1092
+ if (s is None and x .dtype in [np .csingle , np .cdouble , np .single , np .double ]):
1093
1093
x = np .asarray (x )
1094
1094
res = np .empty (x .shape , dtype = _output_dtype (x .dtype ))
1095
1095
return iter_complementary (
1096
1096
x , axes ,
1097
1097
_direct_fftnd ,
1098
- {'overwrite_arg ' : overwrite_x , 'direction' : direction , 'fsc' : fsc },
1098
+ {'overwrite_x ' : overwrite_x , 'direction' : direction , 'fsc' : fsc },
1099
1099
res
1100
1100
)
1101
1101
else :
1102
1102
sc = < object > fsc
1103
- return _iter_fftnd (x , s = shape , axes = axes ,
1104
- overwrite_arg = overwrite_x , scale_function = lambda n , i : sc if i == 0 else 1. ,
1103
+ return _iter_fftnd (x , s = s , axes = axes ,
1104
+ overwrite_x = overwrite_x , scale_function = lambda n , i : sc if i == 0 else 1. ,
1105
1105
function = fft if direction == 1 else ifft )
1106
1106
1107
1107
1108
- def fft2 (x , shape = None , axes = (- 2 ,- 1 ), overwrite_x = False , fwd_scale = 1.0 ):
1109
- return _fftnd_impl (x , shape = shape , axes = axes , overwrite_x = overwrite_x , direction = + 1 , fsc = fwd_scale )
1108
+ def fft2 (x , s = None , axes = (- 2 ,- 1 ), overwrite_x = False , fwd_scale = 1.0 ):
1109
+ return _fftnd_impl (x , s = s , axes = axes , overwrite_x = overwrite_x , direction = + 1 , fsc = fwd_scale )
1110
1110
1111
1111
1112
- def ifft2 (x , shape = None , axes = (- 2 ,- 1 ), overwrite_x = False , fwd_scale = 1.0 ):
1113
- return _fftnd_impl (x , shape = shape , axes = axes , overwrite_x = overwrite_x , direction = - 1 , fsc = fwd_scale )
1112
+ def ifft2 (x , s = None , axes = (- 2 ,- 1 ), overwrite_x = False , fwd_scale = 1.0 ):
1113
+ return _fftnd_impl (x , s = s , axes = axes , overwrite_x = overwrite_x , direction = - 1 , fsc = fwd_scale )
1114
1114
1115
1115
1116
- def fftn (x , shape = None , axes = None , overwrite_x = False , fwd_scale = 1.0 ):
1117
- return _fftnd_impl (x , shape = shape , axes = axes , overwrite_x = overwrite_x , direction = + 1 , fsc = fwd_scale )
1116
+ def fftn (x , s = None , axes = None , overwrite_x = False , fwd_scale = 1.0 ):
1117
+ return _fftnd_impl (x , s = s , axes = axes , overwrite_x = overwrite_x , direction = + 1 , fsc = fwd_scale )
1118
1118
1119
1119
1120
- def ifftn (x , shape = None , axes = None , overwrite_x = False , fwd_scale = 1.0 ):
1121
- return _fftnd_impl (x , shape = shape , axes = axes , overwrite_x = overwrite_x , direction = - 1 , fsc = fwd_scale )
1120
+ def ifftn (x , s = None , axes = None , overwrite_x = False , fwd_scale = 1.0 ):
1121
+ return _fftnd_impl (x , s = s , axes = axes , overwrite_x = overwrite_x , direction = - 1 , fsc = fwd_scale )
1122
1122
1123
1123
1124
1124
def rfft2 (x , s = None , axes = (- 2 ,- 1 ), fwd_scale = 1.0 ):
@@ -1154,7 +1154,7 @@ cdef cnp.ndarray _trim_array(cnp.ndarray arr, object s, object axes):
1154
1154
raise ValueError ("Invalid axis (%d) specified" % ai )
1155
1155
if si < shp_i :
1156
1156
if no_trim :
1157
- ind = [slice (None ,None ,None ),] * len (s )
1157
+ ind = [slice (None ,None ,None ),] * len (arr_shape )
1158
1158
no_trim = False
1159
1159
ind [ai ] = slice (None , si , None )
1160
1160
if no_trim :
@@ -1203,12 +1203,12 @@ def rfftn(x, s=None, axes=None, fwd_scale=1.0):
1203
1203
tind = tuple (ind )
1204
1204
a_inp = a [tind ]
1205
1205
a_res = _fftnd_impl (
1206
- a_inp , shape = ss , axes = aa ,
1206
+ a_inp , s = ss , axes = aa ,
1207
1207
overwrite_x = True , direction = 1 )
1208
1208
if a_res is not a_inp :
1209
1209
a [tind ] = a_res # copy in place
1210
1210
else :
1211
- for ii in range (len (axes )- 1 ):
1211
+ for ii in range (len (axes ) - 2 , - 1 , - 1 ):
1212
1212
a = fft (a , s [ii ], axes [ii ], overwrite_x = True )
1213
1213
return a
1214
1214
@@ -1218,6 +1218,8 @@ def irfftn(x, s=None, axes=None, fwd_scale=1.0):
1218
1218
no_trim = (s is None ) and (axes is None )
1219
1219
s , axes = _cook_nd_args (a , s , axes , invreal = True )
1220
1220
la = axes [- 1 ]
1221
+ if not no_trim :
1222
+ a = _trim_array (a , s , axes )
1221
1223
if len (s ) > 1 :
1222
1224
if not no_trim :
1223
1225
a = _fix_dimensions (a , s , axes )
@@ -1227,14 +1229,18 @@ def irfftn(x, s=None, axes=None, fwd_scale=1.0):
1227
1229
if not ovr_x :
1228
1230
a = a .copy ()
1229
1231
ovr_x = True
1232
+ if not np .issubdtype (a .dtype , np .complexfloating ):
1233
+ # copy is needed, because output of complex type will be copied to input
1234
+ a = a .astype (np .complex64 ) if a .dtype == np .float32 else a .astype (np .complex128 )
1235
+ ovr_x = True
1230
1236
ss , aa = _remove_axis (s , axes , - 1 )
1231
- ind = [slice (None ,None ,1 ),] * len (s )
1237
+ ind = [slice (None , None , 1 ),] * len (s )
1232
1238
for ii in range (a .shape [la ]):
1233
1239
ind [la ] = ii
1234
1240
tind = tuple (ind )
1235
1241
a_inp = a [tind ]
1236
1242
a_res = _fftnd_impl (
1237
- a_inp , shape = ss , axes = aa ,
1243
+ a_inp , s = ss , axes = aa ,
1238
1244
overwrite_x = True , direction = - 1 )
1239
1245
if a_res is not a_inp :
1240
1246
a [tind ] = a_res # copy in place
0 commit comments