diff --git a/mkl_fft/_numpy_fft.py b/mkl_fft/_numpy_fft.py
index 2c832d5..0a0a07d 100644
--- a/mkl_fft/_numpy_fft.py
+++ b/mkl_fft/_numpy_fft.py
@@ -71,16 +71,13 @@ def _check_norm(norm):
 
 
 def frwd_sc_1d(n, s):
-    nn = n if n else s
+    nn = n if n is not None else s
     return 1/nn if nn != 0 else 1
 
 
-def frwd_sc_nd(s, axes, x_shape):
+def frwd_sc_nd(s, x_shape):
     ss = s if s is not None else x_shape
-    if axes is not None:
-        nn = prod([ss[ai] for ai in axes])
-    else:
-        nn = prod(ss)
+    nn = prod(ss)
     return 1/nn if nn != 0 else 1
 
 
@@ -815,14 +812,14 @@ def fftn(a, s=None, axes=None, norm=None):
     if norm in (None, "backward"):
         fsc = 1.0
     elif norm == "forward":
-        fsc = frwd_sc_nd(s, axes, x.shape)
+        fsc = frwd_sc_nd(s, x.shape)
     else:
-        fsc = sqrt(frwd_sc_nd(s, axes, x.shape))
+        fsc = sqrt(frwd_sc_nd(s, x.shape))
 
     return trycall(
         mkl_fft.fftn,
         (x,),
-        {'shape': s, 'axes': axes,
+        {'s': s, 'axes': axes,
          'fwd_scale': fsc})
 
 
@@ -931,14 +928,14 @@ def ifftn(a, s=None, axes=None, norm=None):
     if norm in (None, "backward"):
         fsc = 1.0
     elif norm == "forward":
-        fsc = frwd_sc_nd(s, axes, x.shape)
+        fsc = frwd_sc_nd(s, x.shape)
     else:
-        fsc = sqrt(frwd_sc_nd(s, axes, x.shape))
+        fsc = sqrt(frwd_sc_nd(s, x.shape))
 
     return trycall(
         mkl_fft.ifftn,
         (x,),
-        {'shape': s, 'axes': axes,
+        {'s': s, 'axes': axes,
          'fwd_scale': fsc})
 
 
@@ -1230,11 +1227,11 @@ def rfftn(a, s=None, axes=None, norm=None):
     elif norm == "forward":
         x = asanyarray(x)
         s, axes = _cook_nd_args(x, s, axes)
-        fsc = frwd_sc_nd(s, axes, x.shape)
+        fsc = frwd_sc_nd(s, x.shape)
     else:
         x = asanyarray(x)
         s, axes = _cook_nd_args(x, s, axes)
-        fsc = sqrt(frwd_sc_nd(s, axes, x.shape))
+        fsc = sqrt(frwd_sc_nd(s, x.shape))
 
     return trycall(
         mkl_fft.rfftn,
@@ -1387,11 +1384,11 @@ def irfftn(a, s=None, axes=None, norm=None):
     elif norm == "forward":
         x = asanyarray(x)
         s, axes = _cook_nd_args(x, s, axes, invreal=1)
-        fsc = frwd_sc_nd(s, axes, x.shape)
+        fsc = frwd_sc_nd(s, x.shape)
     else:
         x = asanyarray(x)
         s, axes = _cook_nd_args(x, s, axes, invreal=1)
-        fsc = sqrt(frwd_sc_nd(s, axes, x.shape))
+        fsc = sqrt(frwd_sc_nd(s, x.shape))
 
     return trycall(
         mkl_fft.irfftn,
diff --git a/mkl_fft/_pydfti.pyx b/mkl_fft/_pydfti.pyx
index 47e0bfd..ded8ce1 100644
--- a/mkl_fft/_pydfti.pyx
+++ b/mkl_fft/_pydfti.pyx
@@ -157,11 +157,11 @@ cdef int _datacopied(cnp.ndarray arr, object orig):
 
 
 def fft(x, n=None, axis=-1, overwrite_x=False, fwd_scale=1.0):
-    return _fft1d_impl(x, n=n, axis=axis, overwrite_arg=overwrite_x, direction=+1, fsc=fwd_scale)
+    return _fft1d_impl(x, n=n, axis=axis, overwrite_x=overwrite_x, direction=+1, fsc=fwd_scale)
 
 
 def ifft(x, n=None, axis=-1, overwrite_x=False, fwd_scale=1.0):
-    return _fft1d_impl(x, n=n, axis=axis, overwrite_arg=overwrite_x, direction=-1, fsc=fwd_scale)
+    return _fft1d_impl(x, n=n, axis=axis, overwrite_x=overwrite_x, direction=-1, fsc=fwd_scale)
 
 
 cdef cnp.ndarray pad_array(cnp.ndarray x_arr, cnp.npy_intp n, int axis, int realQ):
@@ -200,7 +200,7 @@ cdef cnp.ndarray pad_array(cnp.ndarray x_arr, cnp.npy_intp n, int axis, int real
 
 
 cdef cnp.ndarray  __process_arguments(object x, object n, object axis,
-                                      object overwrite_arg, object direction,
+                                      object overwrite_x, object direction,
                                       long *axis_, long *n_, int *in_place,
                                       int *xnd, int *dir_, int realQ):
     "Internal utility to validate and process input arguments of 1D FFT functions"
@@ -213,7 +213,7 @@ cdef cnp.ndarray  __process_arguments(object x, object n, object axis,
     else:
         dir_[0] = -1 if direction is -1 else +1
 
-    in_place[0] = 1 if overwrite_arg is True else 0
+    in_place[0] = 1 if overwrite_x else 0
 
     # convert x to ndarray, ensure that strides are multiples of itemsize
     x_arr = PyArray_CheckFromAny(
@@ -294,7 +294,7 @@ cdef cnp.ndarray __allocate_result(cnp.ndarray x_arr, long n_, long axis_, int f
 # Float/double inputs are not cast to complex, but are effectively
 # treated as complexes with zero imaginary parts.
 # All other types are cast to complex double.
-def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1, double fsc=1.0):
+def _fft1d_impl(x, n=None, axis=-1, overwrite_x=False, direction=+1, double fsc=1.0):
     """
     Uses MKL to perform 1D FFT on the input array x along the given axis.
     """
@@ -308,7 +308,7 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1, double fs
     cdef bytes py_error_msg
     cdef DftiCache *_cache
 
-    x_arr = __process_arguments(x, n, axis, overwrite_arg, direction,
+    x_arr = __process_arguments(x, n, axis, overwrite_x, direction,
                                 &axis_, &n_, &in_place, &xnd, &dir_, 0)
 
     x_type = cnp.PyArray_TYPE(x_arr)
@@ -410,12 +410,12 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1, double fs
 
 def rfftpack(x, n=None, axis=-1, overwrite_x=False, fwd_scale=1.0):
     """Packed real-valued harmonics of FFT of a real sequence x"""
-    return _rr_fft1d_impl2(x, n=n, axis=axis, overwrite_arg=overwrite_x, fsc=fwd_scale)
+    return _rr_fft1d_impl2(x, n=n, axis=axis, overwrite_x=overwrite_x, fsc=fwd_scale)
 
 
 def irfftpack(x, n=None, axis=-1, overwrite_x=False, fwd_scale=1.0):
     """Inverse FFT of a real sequence, takes packed real-valued harmonics of FFT"""
-    return _rr_ifft1d_impl2(x, n=n, axis=axis, overwrite_arg=overwrite_x, fsc=fwd_scale)
+    return _rr_ifft1d_impl2(x, n=n, axis=axis, overwrite_x=overwrite_x, fsc=fwd_scale)
 
 
 cdef object _rc_to_rr(cnp.ndarray rc_arr, int n, int axis, int xnd, int x_type):
@@ -520,12 +520,12 @@ def _repack_rc_to_rr(x, n, axis):
     return _rc_to_rr(x, n_, axis_, cnp.PyArray_NDIM(x_arr), x_type)
 
 
-def _rr_fft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
+def _rr_fft1d_impl2(x, n=None, axis=-1, overwrite_x=False, double fsc=1.0):
     """
     Uses MKL to perform real packed 1D FFT on the input array x along the given axis.
 
     This done by using rfft and post-processing the result.
-    Thus overwrite_arg is effectively discarded.
+    Thus overwrite_x is effectively discarded.
 
     Functionally equivalent to scipy.fftpack.rfft
     """
@@ -539,7 +539,7 @@ def _rr_fft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
     cdef bytes py_error_msg
     cdef DftiCache *_cache
 
-    x_arr = __process_arguments(x, n, axis, overwrite_arg, <object>(+1),
+    x_arr = __process_arguments(x, n, axis, overwrite_x, <object>(+1),
                                 &axis_, &n_, &in_place, &xnd, &dir_, 1)
 
     x_type = cnp.PyArray_TYPE(x_arr)
@@ -576,12 +576,12 @@ def _rr_fft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
     return _rc_to_rr(f_arr, n_, axis_, xnd, x_type)
 
 
-def _rr_ifft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
+def _rr_ifft1d_impl2(x, n=None, axis=-1, overwrite_x=False, double fsc=1.0):
     """
     Uses MKL to perform real packed 1D FFT on the input array x along the given axis.
 
     This done by using rfft and post-processing the result.
-    Thus overwrite_arg is effectively discarded.
+    Thus overwrite_x is effectively discarded.
 
     Functionally equivalent to scipy.fftpack.irfft
     """
@@ -595,7 +595,7 @@ def _rr_ifft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
     cdef bytes py_error_msg
     cdef DftiCache *_cache
 
-    x_arr = __process_arguments(x, n, axis, overwrite_arg, <object>(-1),
+    x_arr = __process_arguments(x, n, axis, overwrite_x, <object>(-1),
                                 &axis_, &n_, &in_place, &xnd, &dir_, 1)
 
     x_type = cnp.PyArray_TYPE(x_arr)
@@ -645,7 +645,7 @@ def _rr_ifft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
 
 
 # this routine is functionally equivalent to numpy.fft.rfft
-def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
+def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_x=False, double fsc=1.0):
     """
     Uses MKL to perform 1D FFT on the real input array x along the given axis,
     producing complex output, but giving only half of the harmonics.
@@ -663,13 +663,13 @@ def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
     cdef bytes py_error_msg
     cdef DftiCache *_cache
 
-    x_arr = __process_arguments(x, n, axis, overwrite_arg, direction,
+    x_arr = __process_arguments(x, n, axis, overwrite_x, direction,
                                 &axis_, &n_, &in_place, &xnd, &dir_, 1)
 
     x_type = cnp.PyArray_TYPE(x_arr)
 
     if x_type is cnp.NPY_CFLOAT or x_type is cnp.NPY_CDOUBLE or x_type is cnp.NPY_CLONGDOUBLE:
-        raise TypeError("1st argument must be a real sequence 1")
+        raise TypeError("1st argument must be a real sequence.")
     elif x_type is cnp.NPY_FLOAT or x_type is cnp.NPY_DOUBLE:
         pass
     else:
@@ -723,7 +723,7 @@ cdef int _is_integral(object num):
 
 
 # this routine is functionally equivalent to numpy.fft.irfft
-def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
+def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_x=False, double fsc=1.0):
     """
     Uses MKL to perform 1D FFT on the real input array x along the given axis,
     producing complex output, but giving only half of the harmonics.
@@ -743,7 +743,7 @@ def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
     int_n = _is_integral(n)
     # nn gives the number elements along axis of the input that we use
     nn = (n // 2 + 1) if int_n and n > 0 else n
-    x_arr = __process_arguments(x, nn, axis, overwrite_arg, direction,
+    x_arr = __process_arguments(x, nn, axis, overwrite_x, direction,
                                 &axis_, &n_, &in_place, &xnd, &dir_, 0)
     n_ = 2*(n_ - 1)
     if int_n and (n % 2 == 1):
@@ -907,10 +907,10 @@ def _cook_nd_args(a, s=None, axes=None, invreal=0):
     return s, axes
 
 
-def _iter_fftnd(a, s=None, axes=None, function=fft, overwrite_arg=False, scale_function=lambda n, ind: 1.0):
+def _iter_fftnd(a, s=None, axes=None, function=fft, overwrite_x=False, scale_function=lambda n, ind: 1.0):
     a = np.asarray(a)
     s, axes = _init_nd_shape_and_axes(a, s, axes)
-    ovwr = overwrite_arg
+    ovwr = overwrite_x
     for ii in reversed(range(len(axes))):
         a = function(a, n = s[ii], axis = axes[ii], overwrite_x=ovwr, fwd_scale=scale_function(s[ii], ii))
         ovwr = True
@@ -959,7 +959,7 @@ def iter_complementary(x, axes, func, kwargs, result):
     return result
 
 
-def _direct_fftnd(x, overwrite_arg=False, direction=+1, double fsc=1.0):
+def _direct_fftnd(x, overwrite_x=False, direction=+1, double fsc=1.0):
     """Perform n-dimensional FFT over all axes"""
     cdef int err
     cdef long n_max = 0
@@ -972,7 +972,7 @@ def _direct_fftnd(x, overwrite_arg=False, direction=+1, double fsc=1.0):
     else:
         dir_ = -1 if direction is -1 else +1
 
-    in_place = 1 if overwrite_arg is True else 0
+    in_place = 1 if overwrite_x else 0
 
     # convert x to ndarray, ensure that strides are multiples of itemsize
     x_arr = PyArray_CheckFromAny(
@@ -1069,17 +1069,17 @@ def _output_dtype(dt):
     return dt
 
 
-def _fftnd_impl(x, shape=None, axes=None, overwrite_x=False, direction=+1, double fsc=1.0):
+def _fftnd_impl(x, s=None, axes=None, overwrite_x=False, direction=+1, double fsc=1.0):
     if direction not in [-1, +1]:
         raise ValueError("Direction of FFT should +1 or -1")
 
     # _direct_fftnd requires complex type, and full-dimensional transform
     if isinstance(x, np.ndarray) and x.size != 0 and x.ndim > 1:
-        _direct = shape is None and axes is None
+        _direct = s is None and axes is None
         if _direct:
             _direct = x.ndim <= 7 # Intel MKL only supports FFT up to 7D
         if not _direct:
-            xs, xa = _cook_nd_args(x, shape, axes)
+            xs, xa = _cook_nd_args(x, s, axes)
             if _check_shapes_for_direct(xs, x.shape, xa):
                 _direct = True
         _direct = _direct and x.dtype in [np.complex64, np.complex128, np.float32, np.float64]
@@ -1087,38 +1087,38 @@ def _fftnd_impl(x, shape=None, axes=None, overwrite_x=False, direction=+1, doubl
         _direct = False
 
     if _direct:
-        return _direct_fftnd(x, overwrite_arg=overwrite_x, direction=direction, fsc=fsc)
+        return _direct_fftnd(x, overwrite_x=overwrite_x, direction=direction, fsc=fsc)
     else:
-        if (shape is None and x.dtype in [np.csingle, np.cdouble, np.single, np.double]):
+        if (s is None and x.dtype in [np.csingle, np.cdouble, np.single, np.double]):
             x = np.asarray(x)
             res = np.empty(x.shape, dtype=_output_dtype(x.dtype))
             return iter_complementary(
                 x, axes,
                 _direct_fftnd,
-                {'overwrite_arg': overwrite_x, 'direction': direction, 'fsc': fsc},
+                {'overwrite_x': overwrite_x, 'direction': direction, 'fsc': fsc},
                 res
                 )
         else:
             sc = <object> fsc
-            return _iter_fftnd(x, s=shape, axes=axes,
-                               overwrite_arg=overwrite_x, scale_function=lambda n, i: sc if i == 0 else 1.,
+            return _iter_fftnd(x, s=s, axes=axes,
+                               overwrite_x=overwrite_x, scale_function=lambda n, i: sc if i == 0 else 1.,
                                function=fft if direction == 1 else ifft)
 
 
-def fft2(x, shape=None, axes=(-2,-1), overwrite_x=False, fwd_scale=1.0):
-    return _fftnd_impl(x, shape=shape, axes=axes, overwrite_x=overwrite_x, direction=+1, fsc=fwd_scale)
+def fft2(x, s=None, axes=(-2,-1), overwrite_x=False, fwd_scale=1.0):
+    return _fftnd_impl(x, s=s, axes=axes, overwrite_x=overwrite_x, direction=+1, fsc=fwd_scale)
 
 
-def ifft2(x, shape=None, axes=(-2,-1), overwrite_x=False, fwd_scale=1.0):
-    return _fftnd_impl(x, shape=shape, axes=axes, overwrite_x=overwrite_x, direction=-1, fsc=fwd_scale)
+def ifft2(x, s=None, axes=(-2,-1), overwrite_x=False, fwd_scale=1.0):
+    return _fftnd_impl(x, s=s, axes=axes, overwrite_x=overwrite_x, direction=-1, fsc=fwd_scale)
 
 
-def fftn(x, shape=None, axes=None, overwrite_x=False, fwd_scale=1.0):
-    return _fftnd_impl(x, shape=shape, axes=axes, overwrite_x=overwrite_x, direction=+1, fsc=fwd_scale)
+def fftn(x, s=None, axes=None, overwrite_x=False, fwd_scale=1.0):
+    return _fftnd_impl(x, s=s, axes=axes, overwrite_x=overwrite_x, direction=+1, fsc=fwd_scale)
 
 
-def ifftn(x, shape=None, axes=None, overwrite_x=False, fwd_scale=1.0):
-    return _fftnd_impl(x, shape=shape, axes=axes, overwrite_x=overwrite_x, direction=-1, fsc=fwd_scale)
+def ifftn(x, s=None, axes=None, overwrite_x=False, fwd_scale=1.0):
+    return _fftnd_impl(x, s=s, axes=axes, overwrite_x=overwrite_x, direction=-1, fsc=fwd_scale)
 
 
 def rfft2(x, s=None, axes=(-2,-1), fwd_scale=1.0):
@@ -1154,7 +1154,7 @@ cdef cnp.ndarray _trim_array(cnp.ndarray arr, object s, object axes):
             raise ValueError("Invalid axis (%d) specified" % ai)
         if si < shp_i:
             if no_trim:
-                ind = [slice(None,None,None),] * len(s)
+                ind = [slice(None,None,None),] * len(arr_shape)
             no_trim = False
             ind[ai] = slice(None, si, None)
     if no_trim:
@@ -1203,12 +1203,12 @@ def rfftn(x, s=None, axes=None, fwd_scale=1.0):
                 tind = tuple(ind)
                 a_inp = a[tind]
                 a_res = _fftnd_impl(
-                    a_inp, shape=ss, axes=aa,
+                    a_inp, s=ss, axes=aa,
                     overwrite_x=True, direction=1)
                 if a_res is not a_inp:
                     a[tind] = a_res # copy in place
         else:
-            for ii in range(len(axes)-1):
+            for ii in range(len(axes) - 2, -1, -1):
                 a = fft(a, s[ii], axes[ii], overwrite_x=True)
     return a
 
@@ -1218,6 +1218,8 @@ def irfftn(x, s=None, axes=None, fwd_scale=1.0):
     no_trim = (s is None) and (axes is None)
     s, axes = _cook_nd_args(a, s, axes, invreal=True)
     la = axes[-1]
+    if not no_trim:
+        a = _trim_array(a, s, axes)
     if len(s) > 1:
         if not no_trim:
             a = _fix_dimensions(a, s, axes)
@@ -1227,14 +1229,18 @@ def irfftn(x, s=None, axes=None, fwd_scale=1.0):
             if not ovr_x:
                 a = a.copy()
                 ovr_x = True
+            if not np.issubdtype(a.dtype, np.complexfloating):
+                # copy is needed, because output of complex type will be copied to input
+                a = a.astype(np.complex64) if a.dtype == np.float32 else a.astype(np.complex128)
+                ovr_x = True
             ss, aa = _remove_axis(s, axes, -1)
-            ind = [slice(None,None,1),] * len(s)
+            ind = [slice(None, None, 1),] * len(s)
             for ii in range(a.shape[la]):
                 ind[la] = ii
                 tind = tuple(ind)
                 a_inp = a[tind]
                 a_res = _fftnd_impl(
-                    a_inp, shape=ss, axes=aa,
+                    a_inp, s=ss, axes=aa,
                     overwrite_x=True, direction=-1)
                 if a_res is not a_inp:
                     a[tind] = a_res # copy in place
diff --git a/mkl_fft/_scipy_fft.py b/mkl_fft/_scipy_fft.py
index 4d3c9ac..a9dd98b 100644
--- a/mkl_fft/_scipy_fft.py
+++ b/mkl_fft/_scipy_fft.py
@@ -200,16 +200,13 @@ def _check_plan(plan):
 
 
 def _frwd_sc_1d(n, s):
-    nn = n if n else s
+    nn = n if n is not None else s
     return 1/nn if nn != 0 else 1
 
 
-def _frwd_sc_nd(s, axes, x_shape):
+def _frwd_sc_nd(s, x_shape):
     ss = s if s is not None else x_shape
-    if axes is not None:
-        nn = prod([ss[ai] for ai in axes])
-    else:
-        nn = prod(ss)
+    nn = prod(ss)
     return 1/nn if nn != 0 else 1
 
 
@@ -233,9 +230,9 @@ def _compute_nd_fwd_scale(norm, s, axes, x_shape):
     if norm in (None, "backward"):
         fsc = 1.0
     elif norm == "forward":
-        fsc = _frwd_sc_nd(s, axes, x_shape)
+        fsc = _frwd_sc_nd(s, x_shape)
     elif norm == "ortho":
-        fsc = sqrt(_frwd_sc_nd(s, axes, x_shape))
+        fsc = sqrt(_frwd_sc_nd(s, x_shape))
     else:
         _check_norm(norm)
     return fsc
@@ -279,7 +276,7 @@ def fft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None, pl
     fsc = _compute_nd_fwd_scale(norm, s, axes, x.shape)
     _check_plan(plan)
     with Workers(workers):
-        output = mkl_fft.fftn(x, shape=s, axes=axes, overwrite_x=overwrite_x, fwd_scale=fsc)
+        output = mkl_fft.fftn(x, s=s, axes=axes, overwrite_x=overwrite_x, fwd_scale=fsc)
     return output
 
 
@@ -293,7 +290,7 @@ def ifft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None, p
     fsc = _compute_nd_fwd_scale(norm, s, axes, x.shape)
     _check_plan(plan)
     with Workers(workers):
-        output = mkl_fft.ifftn(x, shape=s, axes=axes, overwrite_x=overwrite_x, fwd_scale=fsc)
+        output = mkl_fft.ifftn(x, s=s, axes=axes, overwrite_x=overwrite_x, fwd_scale=fsc)
     return output
 
 
@@ -307,7 +304,7 @@ def fftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None, plan=
     fsc = _compute_nd_fwd_scale(norm, s, axes, x.shape)
     _check_plan(plan)
     with Workers(workers):
-        output = mkl_fft.fftn(x, shape=s, axes=axes, overwrite_x=overwrite_x, fwd_scale=fsc)
+        output = mkl_fft.fftn(x, s=s, axes=axes, overwrite_x=overwrite_x, fwd_scale=fsc)
     return output
 
 
@@ -321,7 +318,7 @@ def ifftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None, plan
     fsc = _compute_nd_fwd_scale(norm, s, axes, x.shape)
     _check_plan(plan)
     with Workers(workers):
-        output = mkl_fft.ifftn(x, shape=s, axes=axes, overwrite_x=overwrite_x, fwd_scale=fsc)
+        output = mkl_fft.ifftn(x, s=s, axes=axes, overwrite_x=overwrite_x, fwd_scale=fsc)
     return output
 
 
@@ -359,10 +356,10 @@ def _compute_nd_fwd_scale_for_rfft(norm, s, axes, x, invreal=False):
         fsc = 1.0
     elif norm == "forward":
         s, axes = _cook_nd_args(x, s, axes, invreal=invreal)
-        fsc = _frwd_sc_nd(s, axes, x.shape)
+        fsc = _frwd_sc_nd(s, x.shape)
     elif norm == "ortho":
         s, axes = _cook_nd_args(x, s, axes, invreal=invreal)
-        fsc = sqrt(_frwd_sc_nd(s, axes, x.shape))
+        fsc = sqrt(_frwd_sc_nd(s, x.shape))
     else:
         _check_norm(norm)
     return s, axes, fsc
diff --git a/mkl_fft/tests/test_fftnd.py b/mkl_fft/tests/test_fftnd.py
index 20c73e4..da02abd 100644
--- a/mkl_fft/tests/test_fftnd.py
+++ b/mkl_fft/tests/test_fftnd.py
@@ -31,7 +31,7 @@
 from numpy import random as rnd
 import sys
 import warnings
-
+import pytest
 import mkl_fft
 
 reps_64 = (2**11)*np.finfo(np.float64).eps
@@ -162,7 +162,7 @@ def test_gh64(self):
         a = np.arange(12).reshape((3,4))
         x = a.astype(np.cdouble)
         # should executed successfully
-        r1 = mkl_fft.fftn(a, shape=None, axes=(-2,-1))
+        r1 = mkl_fft.fftn(a, s=None, axes=(-2,-1))
         r2 = mkl_fft.fftn(x)
         r_tol, a_tol = _get_rtol_atol(x)
         assert_allclose(r1, r2, rtol=r_tol, atol=a_tol)
@@ -223,8 +223,43 @@ def test_gh109():
     b_int = np.array([[5, 7, 6, 5], [4, 6, 4, 8], [9, 3, 7, 5]], dtype=np.int64)
     b = np.asarray(b_int, dtype=np.float32)
 
-    r1 = mkl_fft.fftn(b, shape=None, axes=(0,), overwrite_x=False, fwd_scale=1/3)
-    r2 = mkl_fft.fftn(b_int, shape=None, axes=(0,), overwrite_x=False, fwd_scale=1/3)
+    r1 = mkl_fft.fftn(b, s=None, axes=(0,), overwrite_x=False, fwd_scale=1/3)
+    r2 = mkl_fft.fftn(b_int, s=None, axes=(0,), overwrite_x=False, fwd_scale=1/3)
 
     rtol, atol = _get_rtol_atol(b)
     assert_allclose(r1, r2, rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize("dtype", [complex, float])
+@pytest.mark.parametrize("s", [(15, 24, 10), [35, 25, 15], [25, 15, 5]])
+@pytest.mark.parametrize("axes", [(0, 1, 2), (-1, -2, -3), [1, 0, 2]])
+@pytest.mark.parametrize("func", ["fftn", "ifftn", "rfftn", "irfftn"])
+def test_s_axes(dtype, s, axes, func):
+    shape = (30, 20, 10)
+    if dtype is complex and func != "rfftn":
+        x = np.random.random(shape) + 1j * np.random.random(shape)
+    else:
+        x = np.random.random(shape)
+
+    r1 = getattr(mkl_fft, func)(x, s=s, axes=axes)
+    r2 = getattr(np.fft, func)(x, s=s, axes=axes)
+
+    rtol, atol = _get_rtol_atol(x)
+    assert_allclose(r1, r2, rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize("dtype", [complex, float])
+@pytest.mark.parametrize("axes", [(2, 0, 2, 0), (0, 1, 1), (2, 0, 1, 3, 2, 1)])
+@pytest.mark.parametrize("func", ["rfftn", "irfftn"])
+def test_repeated_axes(dtype, axes, func):
+    shape = (2, 3, 4, 5)
+    if dtype is complex and func != "rfftn":
+        x = np.random.random(shape) + 1j * np.random.random(shape)
+    else:
+        x = np.random.random(shape)
+
+    r1 = getattr(mkl_fft, func)(x, axes=axes)
+    r2 = getattr(np.fft, func)(x, axes=axes)
+
+    rtol, atol = _get_rtol_atol(x)
+    assert_allclose(r1, r2, rtol=rtol, atol=atol)
diff --git a/mkl_fft/tests/test_interfaces.py b/mkl_fft/tests/test_interfaces.py
index 954ed56..221e6f5 100644
--- a/mkl_fft/tests/test_interfaces.py
+++ b/mkl_fft/tests/test_interfaces.py
@@ -29,14 +29,6 @@
 import numpy as np
 
 
-def test_interfaces_has_numpy():
-    assert hasattr(mfi, 'numpy_fft')
-
-
-def test_interfaces_has_scipy():
-    assert hasattr(mfi, 'scipy_fft')
-
-
 @pytest.mark.parametrize('norm', [None, "forward", "backward", "ortho"])
 @pytest.mark.parametrize('dtype', [np.float32, np.float64, np.complex64, np.complex128])
 def test_scipy_fft(norm, dtype):
@@ -151,3 +143,15 @@ def test_scipy_fft_arg_validate():
     with pytest.raises(NotImplementedError):
         mfi.scipy_fft.fft([1,2,3,4], plan="magic")
 
+
+@pytest.mark.parametrize(
+    "func", 
+    [mfi.scipy_fft.rfft2, mfi.numpy_fft.rfft2],
+    ids=["scipy", "numpy"],
+)
+def test_axes(func):
+    x = np.arange(24.).reshape(2, 3, 4)
+    res = func(x, axes=(1, 2))
+    exp = np.fft.rfft2(x, axes=(1, 2))
+    tol = 64 * np.finfo(np.float64).eps
+    assert np.allclose(res, exp, atol=tol, rtol=tol)
diff --git a/mkl_fft/tests/test_pocketfft.py b/mkl_fft/tests/test_pocketfft.py
index aef1cdd..7f006a0 100644
--- a/mkl_fft/tests/test_pocketfft.py
+++ b/mkl_fft/tests/test_pocketfft.py
@@ -37,8 +37,7 @@ def test_identity(self):
             assert_allclose(mkl_fft.irfft(mkl_fft.rfft(xr[0:i]), i),
                             xr[0:i], atol=1e-12)
 
-    @pytest.mark.skip()
-    @pytest.mark.parametrize("dtype", [np.single, np.double, np.longdouble])
+    @pytest.mark.parametrize("dtype", [np.single, np.double]) #, np.longdouble])
     def test_identity_long_short(self, dtype):
         # Test with explicitly given number of points, both for n
         # smaller and for n larger than the input size.
@@ -56,8 +55,7 @@ def test_identity_long_short(self, dtype):
             assert check_r.dtype == dtype
             assert_allclose(check_r, xxr[0:i], atol=atol, rtol=0)
 
-    @pytest.mark.skip()
-    @pytest.mark.parametrize("dtype", [np.single, np.double, np.longdouble])
+    @pytest.mark.parametrize("dtype", [np.single, np.double]) #, np.longdouble])
     def test_identity_long_short_reversed(self, dtype):
         # Also test explicitly given number of points in reversed order.
         maxlen = 16
@@ -307,7 +305,6 @@ def test_irfft2(self):
         assert_allclose(x, mkl_fft.irfft2(mkl_fft.rfft2(x, norm="forward"),
                         norm="forward"), atol=1e-6)
 
-    @pytest.mark.skip("repeated axes")
     def test_rfftn(self):
         x = random((30, 20, 10))
         assert_allclose(mkl_fft.fftn(x)[:, :, :6], mkl_fft.rfftn(x), atol=1e-6)
@@ -360,7 +357,6 @@ def test_ihfft(self):
         assert_allclose(x_herm, mkl_fft.ihfft(mkl_fft.hfft(x_herm,
                         norm="forward"), norm="forward"), atol=1e-6)
 
-    @pytest.mark.skip("Casting complex values to real")
     @pytest.mark.parametrize("op", [mkl_fft.fftn, mkl_fft.ifftn,
                                     mkl_fft.rfftn, mkl_fft.irfftn])
     def test_axes(self, op):
@@ -483,7 +479,6 @@ def test_irfftn_out_and_s_interaction(self, s):
         assert_array_equal(result, expected)
 
 
-@pytest.mark.skip()
 @pytest.mark.parametrize(
         "dtype",
         [np.float32, np.float64, np.complex64, np.complex128])
@@ -518,7 +513,7 @@ def test_fft_with_order(dtype, order, fft):
         for ax in axes:
             X_res = fft(X, axes=ax)
             Y_res = fft(Y, axes=ax)
-            assert_allclose(X_res, Y_res, atol=_tol, rtol=_tol)
+            assert_allclose(X_res, Y_res, atol=_tol, rtol=10 * _tol)
     else:
         raise ValueError
 
@@ -591,7 +586,6 @@ def test_irfft_with_n_large_regression():
     assert_allclose(result, expected)
 
 
-@pytest.mark.skip()
 @pytest.mark.parametrize("fft", [
     mkl_fft.fft, mkl_fft.ifft, mkl_fft.rfft, mkl_fft.irfft
 ])
@@ -605,4 +599,4 @@ def test_fft_with_integer_or_bool_input(data, fft):
     result = fft(data)
     float_data = data.astype(np.result_type(data, 1.))
     expected = fft(float_data)
-    assert_array_equal(result, expected)
+    assert_allclose(result, expected, rtol=1e-15)