Skip to content

fix an issue for real inputs of irfft #180

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 15, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/conda-package-cf.yml
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,8 @@ jobs:
FOR /F "tokens=* USEBACKQ" %%F IN (`python -c "%SCRIPT%"`) DO (
SET PACKAGE_VERSION=%%F
)
SET "TEST_DEPENDENCIES=pytest pytest-cov"
conda install -n ${{ env.TEST_ENV_NAME }} ${{ env.PACKAGE_NAME }}=%PACKAGE_VERSION% %TEST_DEPENDENCIES% python=${{ matrix.python }} ${{ matrix.numpy }} scipy -c ${{ env.workdir }}/channel ${{ env.CHANNELS }}
SET "TEST_DEPENDENCIES=pytest scipy"
conda install -n ${{ env.TEST_ENV_NAME }} ${{ env.PACKAGE_NAME }}=%PACKAGE_VERSION% %TEST_DEPENDENCIES% python=${{ matrix.python }} ${{ matrix.numpy }} -c ${{ env.workdir }}/channel ${{ env.CHANNELS }}
- name: Report content of test environment
shell: cmd /C CALL {0}
Expand Down
8 changes: 5 additions & 3 deletions .github/workflows/conda-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ jobs:
- name: Install mkl_fft
run: |
CHANNELS="-c $GITHUB_WORKSPACE/channel ${{ env.CHANNELS }}"
conda create -n ${{ env.TEST_ENV_NAME }} python=${{ matrix.python }} $PACKAGE_NAME pytest scipy $CHANNELS
conda create -n ${{ env.TEST_ENV_NAME }} python=${{ matrix.python }} "scipy>=1.10" $CHANNELS
conda install -n ${{ env.TEST_ENV_NAME }} $PACKAGE_NAME pytest $CHANNELS
# Test installed packages
conda list -n ${{ env.TEST_ENV_NAME }}

Expand Down Expand Up @@ -295,15 +296,16 @@ jobs:
FOR /F "tokens=* USEBACKQ" %%F IN (`python -c "%SCRIPT%"`) DO (
SET PACKAGE_VERSION=%%F
)
SET "TEST_DEPENDENCIES=pytest pytest-cov"
conda install -n ${{ env.TEST_ENV_NAME }} ${{ env.PACKAGE_NAME }}=%PACKAGE_VERSION% %TEST_DEPENDENCIES% python=${{ matrix.python }} scipy -c ${{ env.workdir }}/channel ${{ env.CHANNELS }}
SET "TEST_DEPENDENCIES=pytest scipy"
conda install -n ${{ env.TEST_ENV_NAME }} ${{ env.PACKAGE_NAME }}=%PACKAGE_VERSION% %TEST_DEPENDENCIES% python=${{ matrix.python }} -c ${{ env.workdir }}/channel ${{ env.CHANNELS }}

- name: Report content of test environment
shell: cmd /C CALL {0}
run: |
echo "Value of CONDA environment variable was: " %CONDA%
echo "Value of CONDA_PREFIX environment variable was: " %CONDA_PREFIX%
conda info && conda list -n ${{ env.TEST_ENV_NAME }}

- name: Run tests
shell: cmd /C CALL {0}
run: >-
Expand Down
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
* NumPy interface `mkl_fft.interfaces.numpy_fft` is aligned with numpy-2.* [gh-139](https://github.com/IntelPython/mkl_fft/pull/139), [gh-157](https://github.com/IntelPython/mkl_fft/pull/157)

### Fixed
* Fixed an issue for calling `mkl_fft.interfaces.numpy.fftn` with an empty axes [gh-139](https://github.com/IntelPython/mkl_fft/pull/139)
* Fixed an issue for calling `mkl_fft.interfaces.numpy.fftn` with a zero-size array [gh-139](https://github.com/IntelPython/mkl_fft/pull/139)
* Fixed inconsistency of input and output arrays dtype for `irfft` function [gh-180](https://github.com/IntelPython/mkl_fft/pull/180)

## [1.3.14] (04/10/2025)

resolves gh-152 by adding an explicit `mkl-service` dependency to `mkl-fft` when building the wheel
Expand Down
25 changes: 12 additions & 13 deletions mkl_fft/_pydfti.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def _tls_dfti_cache_capsule():
cdef extern from "Python.h":
ctypedef int size_t

long PyInt_AsLong(object ob)
long PyLong_AsLong(object ob)
int PyObject_HasAttrString(object, char*)


Expand Down Expand Up @@ -262,7 +262,7 @@ cdef cnp.ndarray _process_arguments(
xnd[0] = cnp.PyArray_NDIM(x_arr) # tensor-rank of the array

err = 0
axis_[0] = PyInt_AsLong(axis)
axis_[0] = PyLong_AsLong(axis)
if (axis_[0] == -1 and PyErr_Occurred()):
PyErr_Clear()
err = 1
Expand All @@ -278,7 +278,7 @@ cdef cnp.ndarray _process_arguments(
n_[0] = x_arr.shape[axis_[0]]
else:
try:
n_[0] = PyInt_AsLong(n)
n_[0] = PyLong_AsLong(n)
except:
err = 1

Expand Down Expand Up @@ -334,7 +334,7 @@ cdef int _is_integral(object num):
if num is None:
return 0
try:
n = PyInt_AsLong(num)
n = PyLong_AsLong(num)
_integral = 1 if n > 0 else 0
except:
_integral = 0
Expand Down Expand Up @@ -665,13 +665,12 @@ def _r2c_fft1d_impl(
return f_arr


# this routine is functionally equivalent to numpy.fft.irfft
def _c2r_fft1d_impl(
x, n=None, axis=-1, overwrite_x=False, double fsc=1.0, out=None
):
"""
Uses MKL to perform 1D FFT on the real input array x along the given axis,
producing complex output, but giving only half of the harmonics.
Uses MKL to perform 1D FFT on the real/complex input array x along the
given axis, producing real output.

cf. numpy.fft.irfft
"""
Expand Down Expand Up @@ -704,13 +703,13 @@ def _c2r_fft1d_impl(
else:
# we must cast the input and allocate the output,
# so we cast to complex double and operate in place
try:
if x_type is cnp.NPY_FLOAT:
x_arr = <cnp.ndarray> cnp.PyArray_FROM_OTF(
x_arr, cnp.NPY_CDOUBLE, cnp.NPY_BEHAVED)
except:
raise ValueError(
"First argument should be a real or "
"a complex sequence of single or double precision"
x_arr, cnp.NPY_CFLOAT, cnp.NPY_BEHAVED
)
else:
x_arr = <cnp.ndarray> cnp.PyArray_FROM_OTF(
x_arr, cnp.NPY_CDOUBLE, cnp.NPY_BEHAVED
)
x_type = cnp.PyArray_TYPE(x_arr)
in_place = 1
Expand Down
10 changes: 4 additions & 6 deletions mkl_fft/interfaces/_numpy_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,13 +295,11 @@ def hfft(a, n=None, axis=-1, norm=None, out=None):
"""
norm = _swap_direction(norm)
x = _downcast_float128_array(a)
x = np.array(x, copy=True)
np.conjugate(x, out=x)
fsc = _compute_fwd_scale(norm, n, 2 * (x.shape[axis] - 1))

return _trycall(
mkl_fft.irfft,
(x,),
(np.conjugate(x),),
{"n": n, "axis": axis, "fwd_scale": fsc, "out": out},
)

Expand All @@ -317,9 +315,9 @@ def ihfft(a, n=None, axis=-1, norm=None, out=None):
x = _downcast_float128_array(a)
fsc = _compute_fwd_scale(norm, n, x.shape[axis])

output = _trycall(
result = _trycall(
mkl_fft.rfft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc, "out": out}
)

np.conjugate(output, out=output)
return output
np.conjugate(result, out=result)
return result
10 changes: 10 additions & 0 deletions mkl_fft/tests/test_fft1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,3 +457,13 @@ def test_irfft_out_strided(axis):
expected = np.fft.irfft(x, axis=axis, out=out)

assert_allclose(result, expected)


@requires_numpy_2
@pytest.mark.parametrize("dt", ["i4", "f4", "f8", "c8", "c16"])
def test_irfft_dtype(dt):
x = np.array(rnd.random((20, 20)), dtype=dt)
result = mkl_fft.irfft(x)
expected = np.fft.irfft(x)
assert result.dtype == expected.dtype
assert_allclose(result, expected, rtol=1e-7, atol=1e-7)
10 changes: 4 additions & 6 deletions mkl_fft/tests/third_party/scipy/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,11 @@

# pylint: disable=possibly-used-before-assignment
if scipy.__version__ < "1.12":
# scipy from Intel channel is 1.10
pytest.skip(
"This test file needs scipy>=1.12",
allow_module_level=True,
)
# scipy from Intel channel is 1.10 with python 3.9 and 3.10
pytest.skip("This test file needs scipy>=1.12", allow_module_level=True)
elif scipy.__version__ < "1.14":
# For python<=3.9, scipy<1.14 is installed
# For pytho-3.11 and 3.12, scipy<1.14 is installed from Intel channel
# For python<=3.9, scipy<1.14 is installed from conda channel
# pylint: disable=no-name-in-module
from scipy._lib._array_api import size as xp_size
else:
Expand Down
Loading