Skip to content

Commit bcfbb46

Browse files
committed
fix an issue for real inputs of irfft
1 parent 2eb3cfc commit bcfbb46

File tree

7 files changed

+41
-31
lines changed

7 files changed

+41
-31
lines changed

.github/workflows/conda-package-cf.yml

+3-3
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ jobs:
132132
- name: Install mkl_fft
133133
run: |
134134
CHANNELS="-c $GITHUB_WORKSPACE/channel ${{ env.CHANNELS }}"
135-
conda create -n ${{ env.TEST_ENV_NAME }} python=${{ matrix.python_ver }} ${{ matrix.numpy }} $PACKAGE_NAME pytest scipy $CHANNELS
135+
conda create -n ${{ env.TEST_ENV_NAME }} python=${{ matrix.python_ver }} ${{ matrix.numpy }} $PACKAGE_NAME pytest scipy>=1.10 $CHANNELS
136136
# Test installed packages
137137
conda list -n ${{ env.TEST_ENV_NAME }}
138138
@@ -295,8 +295,8 @@ jobs:
295295
FOR /F "tokens=* USEBACKQ" %%F IN (`python -c "%SCRIPT%"`) DO (
296296
SET PACKAGE_VERSION=%%F
297297
)
298-
SET "TEST_DEPENDENCIES=pytest pytest-cov"
299-
conda install -n ${{ env.TEST_ENV_NAME }} ${{ env.PACKAGE_NAME }}=%PACKAGE_VERSION% %TEST_DEPENDENCIES% python=${{ matrix.python }} ${{ matrix.numpy }} scipy -c ${{ env.workdir }}/channel ${{ env.CHANNELS }}
298+
SET "TEST_DEPENDENCIES=pytest scipy>=1.10"
299+
conda install -n ${{ env.TEST_ENV_NAME }} ${{ env.PACKAGE_NAME }}=%PACKAGE_VERSION% %TEST_DEPENDENCIES% python=${{ matrix.python }} ${{ matrix.numpy }} -c ${{ env.workdir }}/channel ${{ env.CHANNELS }}
300300
301301
- name: Report content of test environment
302302
shell: cmd /C CALL {0}

.github/workflows/conda-package.yml

+3-3
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ jobs:
131131
- name: Install mkl_fft
132132
run: |
133133
CHANNELS="-c $GITHUB_WORKSPACE/channel ${{ env.CHANNELS }}"
134-
conda create -n ${{ env.TEST_ENV_NAME }} python=${{ matrix.python }} $PACKAGE_NAME pytest scipy $CHANNELS
134+
conda create -n ${{ env.TEST_ENV_NAME }} python=${{ matrix.python }} $PACKAGE_NAME pytest scipy>=1.10 $CHANNELS
135135
# Test installed packages
136136
conda list -n ${{ env.TEST_ENV_NAME }}
137137
@@ -295,8 +295,8 @@ jobs:
295295
FOR /F "tokens=* USEBACKQ" %%F IN (`python -c "%SCRIPT%"`) DO (
296296
SET PACKAGE_VERSION=%%F
297297
)
298-
SET "TEST_DEPENDENCIES=pytest pytest-cov"
299-
conda install -n ${{ env.TEST_ENV_NAME }} ${{ env.PACKAGE_NAME }}=%PACKAGE_VERSION% %TEST_DEPENDENCIES% python=${{ matrix.python }} scipy -c ${{ env.workdir }}/channel ${{ env.CHANNELS }}
298+
SET "TEST_DEPENDENCIES=pytest scipy>=1.10.0"
299+
conda install -n ${{ env.TEST_ENV_NAME }} ${{ env.PACKAGE_NAME }}=%PACKAGE_VERSION% %TEST_DEPENDENCIES% python=${{ matrix.python }} -c ${{ env.workdir }}/channel ${{ env.CHANNELS }}
300300
301301
- name: Report content of test environment
302302
shell: cmd /C CALL {0}

CHANGELOG.md

+5
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313
### Changed
1414
* NumPy interface `mkl_fft.interfaces.numpy_fft` is aligned with numpy-2.* [gh-139](https://github.com/IntelPython/mkl_fft/pull/139), [gh-157](https://github.com/IntelPython/mkl_fft/pull/157)
1515

16+
### Fixed
17+
* Fixed an issue for calling `mkl_fft.interfaces.numpy.fftn` with an empty axes [gh-139](https://github.com/IntelPython/mkl_fft/pull/139)
18+
* Fixed an issue for calling `mkl_fft.interfaces.numpy.fftn` with a zero-size array [gh-139](https://github.com/IntelPython/mkl_fft/pull/139)
19+
* Fixed inconsistency of input and output arrays dtype for `irfft` function [gh-180](https://github.com/IntelPython/mkl_fft/pull/180)
20+
1621
## [1.3.14] (04/10/2025)
1722

1823
resolves gh-152 by adding an explicit `mkl-service` dependency to `mkl-fft` when building the wheel

mkl_fft/_pydfti.pyx

+12-13
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def _tls_dfti_cache_capsule():
8989
cdef extern from "Python.h":
9090
ctypedef int size_t
9191

92-
long PyInt_AsLong(object ob)
92+
long PyLong_AsLong(object ob)
9393
int PyObject_HasAttrString(object, char*)
9494

9595

@@ -262,7 +262,7 @@ cdef cnp.ndarray _process_arguments(
262262
xnd[0] = cnp.PyArray_NDIM(x_arr) # tensor-rank of the array
263263

264264
err = 0
265-
axis_[0] = PyInt_AsLong(axis)
265+
axis_[0] = PyLong_AsLong(axis)
266266
if (axis_[0] == -1 and PyErr_Occurred()):
267267
PyErr_Clear()
268268
err = 1
@@ -278,7 +278,7 @@ cdef cnp.ndarray _process_arguments(
278278
n_[0] = x_arr.shape[axis_[0]]
279279
else:
280280
try:
281-
n_[0] = PyInt_AsLong(n)
281+
n_[0] = PyLong_AsLong(n)
282282
except:
283283
err = 1
284284

@@ -334,7 +334,7 @@ cdef int _is_integral(object num):
334334
if num is None:
335335
return 0
336336
try:
337-
n = PyInt_AsLong(num)
337+
n = PyLong_AsLong(num)
338338
_integral = 1 if n > 0 else 0
339339
except:
340340
_integral = 0
@@ -665,13 +665,12 @@ def _r2c_fft1d_impl(
665665
return f_arr
666666

667667

668-
# this routine is functionally equivalent to numpy.fft.irfft
669668
def _c2r_fft1d_impl(
670669
x, n=None, axis=-1, overwrite_x=False, double fsc=1.0, out=None
671670
):
672671
"""
673-
Uses MKL to perform 1D FFT on the real input array x along the given axis,
674-
producing complex output, but giving only half of the harmonics.
672+
Uses MKL to perform 1D FFT on the real/complex input array x along the
673+
given axis, producing real output.
675674
676675
cf. numpy.fft.irfft
677676
"""
@@ -704,13 +703,13 @@ def _c2r_fft1d_impl(
704703
else:
705704
# we must cast the input and allocate the output,
706705
# so we cast to complex double and operate in place
707-
try:
706+
if x_type is cnp.NPY_FLOAT:
708707
x_arr = <cnp.ndarray> cnp.PyArray_FROM_OTF(
709-
x_arr, cnp.NPY_CDOUBLE, cnp.NPY_BEHAVED)
710-
except:
711-
raise ValueError(
712-
"First argument should be a real or "
713-
"a complex sequence of single or double precision"
708+
x_arr, cnp.NPY_CFLOAT, cnp.NPY_BEHAVED
709+
)
710+
else:
711+
x_arr = <cnp.ndarray> cnp.PyArray_FROM_OTF(
712+
x_arr, cnp.NPY_CDOUBLE, cnp.NPY_BEHAVED
714713
)
715714
x_type = cnp.PyArray_TYPE(x_arr)
716715
in_place = 1

mkl_fft/interfaces/_numpy_fft.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -295,13 +295,11 @@ def hfft(a, n=None, axis=-1, norm=None, out=None):
295295
"""
296296
norm = _swap_direction(norm)
297297
x = _downcast_float128_array(a)
298-
x = np.array(x, copy=True)
299-
np.conjugate(x, out=x)
300298
fsc = _compute_fwd_scale(norm, n, 2 * (x.shape[axis] - 1))
301299

302300
return _trycall(
303301
mkl_fft.irfft,
304-
(x,),
302+
(np.conjugate(x),),
305303
{"n": n, "axis": axis, "fwd_scale": fsc, "out": out},
306304
)
307305

@@ -317,9 +315,9 @@ def ihfft(a, n=None, axis=-1, norm=None, out=None):
317315
x = _downcast_float128_array(a)
318316
fsc = _compute_fwd_scale(norm, n, x.shape[axis])
319317

320-
output = _trycall(
318+
result = _trycall(
321319
mkl_fft.rfft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc, "out": out}
322320
)
323321

324-
np.conjugate(output, out=output)
325-
return output
322+
np.conjugate(result, out=result)
323+
return result

mkl_fft/tests/test_fft1d.py

+10
Original file line numberDiff line numberDiff line change
@@ -457,3 +457,13 @@ def test_irfft_out_strided(axis):
457457
expected = np.fft.irfft(x, axis=axis, out=out)
458458

459459
assert_allclose(result, expected)
460+
461+
462+
@requires_numpy_2
463+
@pytest.mark.parametrize("dt", ["i4", "f4", "f8", "c8", "c16"])
464+
def test_irfft_dtype(dt):
465+
x = np.array(rnd.random((20, 20)), dtype=dt)
466+
result = mkl_fft.irfft(x)
467+
expected = np.fft.irfft(x)
468+
assert result.dtype == expected.dtype
469+
assert_allclose(result, expected, rtol=1e-7, atol=1e-7)

mkl_fft/tests/third_party/scipy/test_basic.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,11 @@
1414

1515
# pylint: disable=possibly-used-before-assignment
1616
if scipy.__version__ < "1.12":
17-
# scipy from Intel channel is 1.10
18-
pytest.skip(
19-
"This test file needs scipy>=1.12",
20-
allow_module_level=True,
21-
)
17+
# scipy from Intel channel is 1.10 with python 3.9 and 3.10
18+
pytest.skip("This test file needs scipy>=1.12", allow_module_level=True)
2219
elif scipy.__version__ < "1.14":
23-
# For python<=3.9, scipy<1.14 is installed
20+
# For pytho-3.11 and 3.12, scipy<1.14 is installed from Intel channel
21+
# For python<=3.9, scipy<1.14 is installed from conda channel
2422
# pylint: disable=no-name-in-module
2523
from scipy._lib._array_api import size as xp_size
2624
else:

0 commit comments

Comments
 (0)