Skip to content

Commit 87ef5fb

Browse files
authored
resolve a few issues (#138)
* resolve a few issues * address comemnts * add a new test * Apply suggestions from code review
1 parent 9a87930 commit 87ef5fb

File tree

6 files changed

+129
-96
lines changed

6 files changed

+129
-96
lines changed

mkl_fft/_numpy_fft.py

+13-16
Original file line numberDiff line numberDiff line change
@@ -71,16 +71,13 @@ def _check_norm(norm):
7171

7272

7373
def frwd_sc_1d(n, s):
74-
nn = n if n else s
74+
nn = n if n is not None else s
7575
return 1/nn if nn != 0 else 1
7676

7777

78-
def frwd_sc_nd(s, axes, x_shape):
78+
def frwd_sc_nd(s, x_shape):
7979
ss = s if s is not None else x_shape
80-
if axes is not None:
81-
nn = prod([ss[ai] for ai in axes])
82-
else:
83-
nn = prod(ss)
80+
nn = prod(ss)
8481
return 1/nn if nn != 0 else 1
8582

8683

@@ -815,14 +812,14 @@ def fftn(a, s=None, axes=None, norm=None):
815812
if norm in (None, "backward"):
816813
fsc = 1.0
817814
elif norm == "forward":
818-
fsc = frwd_sc_nd(s, axes, x.shape)
815+
fsc = frwd_sc_nd(s, x.shape)
819816
else:
820-
fsc = sqrt(frwd_sc_nd(s, axes, x.shape))
817+
fsc = sqrt(frwd_sc_nd(s, x.shape))
821818

822819
return trycall(
823820
mkl_fft.fftn,
824821
(x,),
825-
{'shape': s, 'axes': axes,
822+
{'s': s, 'axes': axes,
826823
'fwd_scale': fsc})
827824

828825

@@ -931,14 +928,14 @@ def ifftn(a, s=None, axes=None, norm=None):
931928
if norm in (None, "backward"):
932929
fsc = 1.0
933930
elif norm == "forward":
934-
fsc = frwd_sc_nd(s, axes, x.shape)
931+
fsc = frwd_sc_nd(s, x.shape)
935932
else:
936-
fsc = sqrt(frwd_sc_nd(s, axes, x.shape))
933+
fsc = sqrt(frwd_sc_nd(s, x.shape))
937934

938935
return trycall(
939936
mkl_fft.ifftn,
940937
(x,),
941-
{'shape': s, 'axes': axes,
938+
{'s': s, 'axes': axes,
942939
'fwd_scale': fsc})
943940

944941

@@ -1230,11 +1227,11 @@ def rfftn(a, s=None, axes=None, norm=None):
12301227
elif norm == "forward":
12311228
x = asanyarray(x)
12321229
s, axes = _cook_nd_args(x, s, axes)
1233-
fsc = frwd_sc_nd(s, axes, x.shape)
1230+
fsc = frwd_sc_nd(s, x.shape)
12341231
else:
12351232
x = asanyarray(x)
12361233
s, axes = _cook_nd_args(x, s, axes)
1237-
fsc = sqrt(frwd_sc_nd(s, axes, x.shape))
1234+
fsc = sqrt(frwd_sc_nd(s, x.shape))
12381235

12391236
return trycall(
12401237
mkl_fft.rfftn,
@@ -1387,11 +1384,11 @@ def irfftn(a, s=None, axes=None, norm=None):
13871384
elif norm == "forward":
13881385
x = asanyarray(x)
13891386
s, axes = _cook_nd_args(x, s, axes, invreal=1)
1390-
fsc = frwd_sc_nd(s, axes, x.shape)
1387+
fsc = frwd_sc_nd(s, x.shape)
13911388
else:
13921389
x = asanyarray(x)
13931390
s, axes = _cook_nd_args(x, s, axes, invreal=1)
1394-
fsc = sqrt(frwd_sc_nd(s, axes, x.shape))
1391+
fsc = sqrt(frwd_sc_nd(s, x.shape))
13951392

13961393
return trycall(
13971394
mkl_fft.irfftn,

mkl_fft/_pydfti.pyx

+50-44
Original file line numberDiff line numberDiff line change
@@ -157,11 +157,11 @@ cdef int _datacopied(cnp.ndarray arr, object orig):
157157

158158

159159
def fft(x, n=None, axis=-1, overwrite_x=False, fwd_scale=1.0):
160-
return _fft1d_impl(x, n=n, axis=axis, overwrite_arg=overwrite_x, direction=+1, fsc=fwd_scale)
160+
return _fft1d_impl(x, n=n, axis=axis, overwrite_x=overwrite_x, direction=+1, fsc=fwd_scale)
161161

162162

163163
def ifft(x, n=None, axis=-1, overwrite_x=False, fwd_scale=1.0):
164-
return _fft1d_impl(x, n=n, axis=axis, overwrite_arg=overwrite_x, direction=-1, fsc=fwd_scale)
164+
return _fft1d_impl(x, n=n, axis=axis, overwrite_x=overwrite_x, direction=-1, fsc=fwd_scale)
165165

166166

167167
cdef cnp.ndarray pad_array(cnp.ndarray x_arr, cnp.npy_intp n, int axis, int realQ):
@@ -200,7 +200,7 @@ cdef cnp.ndarray pad_array(cnp.ndarray x_arr, cnp.npy_intp n, int axis, int real
200200

201201

202202
cdef cnp.ndarray __process_arguments(object x, object n, object axis,
203-
object overwrite_arg, object direction,
203+
object overwrite_x, object direction,
204204
long *axis_, long *n_, int *in_place,
205205
int *xnd, int *dir_, int realQ):
206206
"Internal utility to validate and process input arguments of 1D FFT functions"
@@ -213,7 +213,7 @@ cdef cnp.ndarray __process_arguments(object x, object n, object axis,
213213
else:
214214
dir_[0] = -1 if direction is -1 else +1
215215

216-
in_place[0] = 1 if overwrite_arg is True else 0
216+
in_place[0] = 1 if overwrite_x else 0
217217

218218
# convert x to ndarray, ensure that strides are multiples of itemsize
219219
x_arr = PyArray_CheckFromAny(
@@ -294,7 +294,7 @@ cdef cnp.ndarray __allocate_result(cnp.ndarray x_arr, long n_, long axis_, int f
294294
# Float/double inputs are not cast to complex, but are effectively
295295
# treated as complexes with zero imaginary parts.
296296
# All other types are cast to complex double.
297-
def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1, double fsc=1.0):
297+
def _fft1d_impl(x, n=None, axis=-1, overwrite_x=False, direction=+1, double fsc=1.0):
298298
"""
299299
Uses MKL to perform 1D FFT on the input array x along the given axis.
300300
"""
@@ -308,7 +308,7 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1, double fs
308308
cdef bytes py_error_msg
309309
cdef DftiCache *_cache
310310

311-
x_arr = __process_arguments(x, n, axis, overwrite_arg, direction,
311+
x_arr = __process_arguments(x, n, axis, overwrite_x, direction,
312312
&axis_, &n_, &in_place, &xnd, &dir_, 0)
313313

314314
x_type = cnp.PyArray_TYPE(x_arr)
@@ -410,12 +410,12 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1, double fs
410410

411411
def rfftpack(x, n=None, axis=-1, overwrite_x=False, fwd_scale=1.0):
412412
"""Packed real-valued harmonics of FFT of a real sequence x"""
413-
return _rr_fft1d_impl2(x, n=n, axis=axis, overwrite_arg=overwrite_x, fsc=fwd_scale)
413+
return _rr_fft1d_impl2(x, n=n, axis=axis, overwrite_x=overwrite_x, fsc=fwd_scale)
414414

415415

416416
def irfftpack(x, n=None, axis=-1, overwrite_x=False, fwd_scale=1.0):
417417
"""Inverse FFT of a real sequence, takes packed real-valued harmonics of FFT"""
418-
return _rr_ifft1d_impl2(x, n=n, axis=axis, overwrite_arg=overwrite_x, fsc=fwd_scale)
418+
return _rr_ifft1d_impl2(x, n=n, axis=axis, overwrite_x=overwrite_x, fsc=fwd_scale)
419419

420420

421421
cdef object _rc_to_rr(cnp.ndarray rc_arr, int n, int axis, int xnd, int x_type):
@@ -520,12 +520,12 @@ def _repack_rc_to_rr(x, n, axis):
520520
return _rc_to_rr(x, n_, axis_, cnp.PyArray_NDIM(x_arr), x_type)
521521

522522

523-
def _rr_fft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
523+
def _rr_fft1d_impl2(x, n=None, axis=-1, overwrite_x=False, double fsc=1.0):
524524
"""
525525
Uses MKL to perform real packed 1D FFT on the input array x along the given axis.
526526
527527
This done by using rfft and post-processing the result.
528-
Thus overwrite_arg is effectively discarded.
528+
Thus overwrite_x is effectively discarded.
529529
530530
Functionally equivalent to scipy.fftpack.rfft
531531
"""
@@ -539,7 +539,7 @@ def _rr_fft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
539539
cdef bytes py_error_msg
540540
cdef DftiCache *_cache
541541

542-
x_arr = __process_arguments(x, n, axis, overwrite_arg, <object>(+1),
542+
x_arr = __process_arguments(x, n, axis, overwrite_x, <object>(+1),
543543
&axis_, &n_, &in_place, &xnd, &dir_, 1)
544544

545545
x_type = cnp.PyArray_TYPE(x_arr)
@@ -576,12 +576,12 @@ def _rr_fft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
576576
return _rc_to_rr(f_arr, n_, axis_, xnd, x_type)
577577

578578

579-
def _rr_ifft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
579+
def _rr_ifft1d_impl2(x, n=None, axis=-1, overwrite_x=False, double fsc=1.0):
580580
"""
581581
Uses MKL to perform real packed 1D FFT on the input array x along the given axis.
582582
583583
This done by using rfft and post-processing the result.
584-
Thus overwrite_arg is effectively discarded.
584+
Thus overwrite_x is effectively discarded.
585585
586586
Functionally equivalent to scipy.fftpack.irfft
587587
"""
@@ -595,7 +595,7 @@ def _rr_ifft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
595595
cdef bytes py_error_msg
596596
cdef DftiCache *_cache
597597

598-
x_arr = __process_arguments(x, n, axis, overwrite_arg, <object>(-1),
598+
x_arr = __process_arguments(x, n, axis, overwrite_x, <object>(-1),
599599
&axis_, &n_, &in_place, &xnd, &dir_, 1)
600600

601601
x_type = cnp.PyArray_TYPE(x_arr)
@@ -645,7 +645,7 @@ def _rr_ifft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
645645

646646

647647
# this routine is functionally equivalent to numpy.fft.rfft
648-
def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
648+
def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_x=False, double fsc=1.0):
649649
"""
650650
Uses MKL to perform 1D FFT on the real input array x along the given axis,
651651
producing complex output, but giving only half of the harmonics.
@@ -663,13 +663,13 @@ def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
663663
cdef bytes py_error_msg
664664
cdef DftiCache *_cache
665665

666-
x_arr = __process_arguments(x, n, axis, overwrite_arg, direction,
666+
x_arr = __process_arguments(x, n, axis, overwrite_x, direction,
667667
&axis_, &n_, &in_place, &xnd, &dir_, 1)
668668

669669
x_type = cnp.PyArray_TYPE(x_arr)
670670

671671
if x_type is cnp.NPY_CFLOAT or x_type is cnp.NPY_CDOUBLE or x_type is cnp.NPY_CLONGDOUBLE:
672-
raise TypeError("1st argument must be a real sequence 1")
672+
raise TypeError("1st argument must be a real sequence.")
673673
elif x_type is cnp.NPY_FLOAT or x_type is cnp.NPY_DOUBLE:
674674
pass
675675
else:
@@ -723,7 +723,7 @@ cdef int _is_integral(object num):
723723

724724

725725
# this routine is functionally equivalent to numpy.fft.irfft
726-
def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
726+
def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_x=False, double fsc=1.0):
727727
"""
728728
Uses MKL to perform 1D FFT on the real input array x along the given axis,
729729
producing complex output, but giving only half of the harmonics.
@@ -743,7 +743,7 @@ def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
743743
int_n = _is_integral(n)
744744
# nn gives the number elements along axis of the input that we use
745745
nn = (n // 2 + 1) if int_n and n > 0 else n
746-
x_arr = __process_arguments(x, nn, axis, overwrite_arg, direction,
746+
x_arr = __process_arguments(x, nn, axis, overwrite_x, direction,
747747
&axis_, &n_, &in_place, &xnd, &dir_, 0)
748748
n_ = 2*(n_ - 1)
749749
if int_n and (n % 2 == 1):
@@ -907,10 +907,10 @@ def _cook_nd_args(a, s=None, axes=None, invreal=0):
907907
return s, axes
908908

909909

910-
def _iter_fftnd(a, s=None, axes=None, function=fft, overwrite_arg=False, scale_function=lambda n, ind: 1.0):
910+
def _iter_fftnd(a, s=None, axes=None, function=fft, overwrite_x=False, scale_function=lambda n, ind: 1.0):
911911
a = np.asarray(a)
912912
s, axes = _init_nd_shape_and_axes(a, s, axes)
913-
ovwr = overwrite_arg
913+
ovwr = overwrite_x
914914
for ii in reversed(range(len(axes))):
915915
a = function(a, n = s[ii], axis = axes[ii], overwrite_x=ovwr, fwd_scale=scale_function(s[ii], ii))
916916
ovwr = True
@@ -959,7 +959,7 @@ def iter_complementary(x, axes, func, kwargs, result):
959959
return result
960960

961961

962-
def _direct_fftnd(x, overwrite_arg=False, direction=+1, double fsc=1.0):
962+
def _direct_fftnd(x, overwrite_x=False, direction=+1, double fsc=1.0):
963963
"""Perform n-dimensional FFT over all axes"""
964964
cdef int err
965965
cdef long n_max = 0
@@ -972,7 +972,7 @@ def _direct_fftnd(x, overwrite_arg=False, direction=+1, double fsc=1.0):
972972
else:
973973
dir_ = -1 if direction is -1 else +1
974974

975-
in_place = 1 if overwrite_arg is True else 0
975+
in_place = 1 if overwrite_x else 0
976976

977977
# convert x to ndarray, ensure that strides are multiples of itemsize
978978
x_arr = PyArray_CheckFromAny(
@@ -1069,56 +1069,56 @@ def _output_dtype(dt):
10691069
return dt
10701070

10711071

1072-
def _fftnd_impl(x, shape=None, axes=None, overwrite_x=False, direction=+1, double fsc=1.0):
1072+
def _fftnd_impl(x, s=None, axes=None, overwrite_x=False, direction=+1, double fsc=1.0):
10731073
if direction not in [-1, +1]:
10741074
raise ValueError("Direction of FFT should +1 or -1")
10751075

10761076
# _direct_fftnd requires complex type, and full-dimensional transform
10771077
if isinstance(x, np.ndarray) and x.size != 0 and x.ndim > 1:
1078-
_direct = shape is None and axes is None
1078+
_direct = s is None and axes is None
10791079
if _direct:
10801080
_direct = x.ndim <= 7 # Intel MKL only supports FFT up to 7D
10811081
if not _direct:
1082-
xs, xa = _cook_nd_args(x, shape, axes)
1082+
xs, xa = _cook_nd_args(x, s, axes)
10831083
if _check_shapes_for_direct(xs, x.shape, xa):
10841084
_direct = True
10851085
_direct = _direct and x.dtype in [np.complex64, np.complex128, np.float32, np.float64]
10861086
else:
10871087
_direct = False
10881088

10891089
if _direct:
1090-
return _direct_fftnd(x, overwrite_arg=overwrite_x, direction=direction, fsc=fsc)
1090+
return _direct_fftnd(x, overwrite_x=overwrite_x, direction=direction, fsc=fsc)
10911091
else:
1092-
if (shape is None and x.dtype in [np.csingle, np.cdouble, np.single, np.double]):
1092+
if (s is None and x.dtype in [np.csingle, np.cdouble, np.single, np.double]):
10931093
x = np.asarray(x)
10941094
res = np.empty(x.shape, dtype=_output_dtype(x.dtype))
10951095
return iter_complementary(
10961096
x, axes,
10971097
_direct_fftnd,
1098-
{'overwrite_arg': overwrite_x, 'direction': direction, 'fsc': fsc},
1098+
{'overwrite_x': overwrite_x, 'direction': direction, 'fsc': fsc},
10991099
res
11001100
)
11011101
else:
11021102
sc = <object> fsc
1103-
return _iter_fftnd(x, s=shape, axes=axes,
1104-
overwrite_arg=overwrite_x, scale_function=lambda n, i: sc if i == 0 else 1.,
1103+
return _iter_fftnd(x, s=s, axes=axes,
1104+
overwrite_x=overwrite_x, scale_function=lambda n, i: sc if i == 0 else 1.,
11051105
function=fft if direction == 1 else ifft)
11061106

11071107

1108-
def fft2(x, shape=None, axes=(-2,-1), overwrite_x=False, fwd_scale=1.0):
1109-
return _fftnd_impl(x, shape=shape, axes=axes, overwrite_x=overwrite_x, direction=+1, fsc=fwd_scale)
1108+
def fft2(x, s=None, axes=(-2,-1), overwrite_x=False, fwd_scale=1.0):
1109+
return _fftnd_impl(x, s=s, axes=axes, overwrite_x=overwrite_x, direction=+1, fsc=fwd_scale)
11101110

11111111

1112-
def ifft2(x, shape=None, axes=(-2,-1), overwrite_x=False, fwd_scale=1.0):
1113-
return _fftnd_impl(x, shape=shape, axes=axes, overwrite_x=overwrite_x, direction=-1, fsc=fwd_scale)
1112+
def ifft2(x, s=None, axes=(-2,-1), overwrite_x=False, fwd_scale=1.0):
1113+
return _fftnd_impl(x, s=s, axes=axes, overwrite_x=overwrite_x, direction=-1, fsc=fwd_scale)
11141114

11151115

1116-
def fftn(x, shape=None, axes=None, overwrite_x=False, fwd_scale=1.0):
1117-
return _fftnd_impl(x, shape=shape, axes=axes, overwrite_x=overwrite_x, direction=+1, fsc=fwd_scale)
1116+
def fftn(x, s=None, axes=None, overwrite_x=False, fwd_scale=1.0):
1117+
return _fftnd_impl(x, s=s, axes=axes, overwrite_x=overwrite_x, direction=+1, fsc=fwd_scale)
11181118

11191119

1120-
def ifftn(x, shape=None, axes=None, overwrite_x=False, fwd_scale=1.0):
1121-
return _fftnd_impl(x, shape=shape, axes=axes, overwrite_x=overwrite_x, direction=-1, fsc=fwd_scale)
1120+
def ifftn(x, s=None, axes=None, overwrite_x=False, fwd_scale=1.0):
1121+
return _fftnd_impl(x, s=s, axes=axes, overwrite_x=overwrite_x, direction=-1, fsc=fwd_scale)
11221122

11231123

11241124
def rfft2(x, s=None, axes=(-2,-1), fwd_scale=1.0):
@@ -1154,7 +1154,7 @@ cdef cnp.ndarray _trim_array(cnp.ndarray arr, object s, object axes):
11541154
raise ValueError("Invalid axis (%d) specified" % ai)
11551155
if si < shp_i:
11561156
if no_trim:
1157-
ind = [slice(None,None,None),] * len(s)
1157+
ind = [slice(None,None,None),] * len(arr_shape)
11581158
no_trim = False
11591159
ind[ai] = slice(None, si, None)
11601160
if no_trim:
@@ -1203,12 +1203,12 @@ def rfftn(x, s=None, axes=None, fwd_scale=1.0):
12031203
tind = tuple(ind)
12041204
a_inp = a[tind]
12051205
a_res = _fftnd_impl(
1206-
a_inp, shape=ss, axes=aa,
1206+
a_inp, s=ss, axes=aa,
12071207
overwrite_x=True, direction=1)
12081208
if a_res is not a_inp:
12091209
a[tind] = a_res # copy in place
12101210
else:
1211-
for ii in range(len(axes)-1):
1211+
for ii in range(len(axes) - 2, -1, -1):
12121212
a = fft(a, s[ii], axes[ii], overwrite_x=True)
12131213
return a
12141214

@@ -1218,6 +1218,8 @@ def irfftn(x, s=None, axes=None, fwd_scale=1.0):
12181218
no_trim = (s is None) and (axes is None)
12191219
s, axes = _cook_nd_args(a, s, axes, invreal=True)
12201220
la = axes[-1]
1221+
if not no_trim:
1222+
a = _trim_array(a, s, axes)
12211223
if len(s) > 1:
12221224
if not no_trim:
12231225
a = _fix_dimensions(a, s, axes)
@@ -1227,14 +1229,18 @@ def irfftn(x, s=None, axes=None, fwd_scale=1.0):
12271229
if not ovr_x:
12281230
a = a.copy()
12291231
ovr_x = True
1232+
if not np.issubdtype(a.dtype, np.complexfloating):
1233+
# copy is needed, because output of complex type will be copied to input
1234+
a = a.astype(np.complex64) if a.dtype == np.float32 else a.astype(np.complex128)
1235+
ovr_x = True
12301236
ss, aa = _remove_axis(s, axes, -1)
1231-
ind = [slice(None,None,1),] * len(s)
1237+
ind = [slice(None, None, 1),] * len(s)
12321238
for ii in range(a.shape[la]):
12331239
ind[la] = ii
12341240
tind = tuple(ind)
12351241
a_inp = a[tind]
12361242
a_res = _fftnd_impl(
1237-
a_inp, shape=ss, axes=aa,
1243+
a_inp, s=ss, axes=aa,
12381244
overwrite_x=True, direction=-1)
12391245
if a_res is not a_inp:
12401246
a[tind] = a_res # copy in place

0 commit comments

Comments
 (0)