Skip to content

Commit b543ee9

Browse files
authored
update _scipy_fft.py (#145)
* create a new util file * update _scipy_fft.py * update _scipy_fftpack.py * add new scipy tests * get rid of double leading underscore
1 parent ba40a93 commit b543ee9

16 files changed

+1002
-330
lines changed

Diff for: .github/workflows/conda-package-cf.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ jobs:
117117
- name: Install mkl_fft
118118
run: |
119119
CHANNELS="-c $GITHUB_WORKSPACE/channel ${{ env.CHANNELS }}"
120-
conda create -n ${{ env.TEST_ENV_NAME }} python=${{ matrix.python_ver }} ${{ matrix.numpy }} $PACKAGE_NAME pytest $CHANNELS
120+
conda create -n ${{ env.TEST_ENV_NAME }} python=${{ matrix.python_ver }} ${{ matrix.numpy }} $PACKAGE_NAME pytest scipy $CHANNELS
121121
# Test installed packages
122122
conda list -n ${{ env.TEST_ENV_NAME }}
123123
- name: Run tests
@@ -268,7 +268,7 @@ jobs:
268268
SET PACKAGE_VERSION=%%F
269269
)
270270
SET "TEST_DEPENDENCIES=pytest pytest-cov"
271-
conda install -n ${{ env.TEST_ENV_NAME }} ${{ env.PACKAGE_NAME }}=%PACKAGE_VERSION% %TEST_DEPENDENCIES% python=${{ matrix.python }} ${{ matrix.numpy }} -c ${{ env.workdir }}/channel ${{ env.CHANNELS }}
271+
conda install -n ${{ env.TEST_ENV_NAME }} ${{ env.PACKAGE_NAME }}=%PACKAGE_VERSION% %TEST_DEPENDENCIES% python=${{ matrix.python }} ${{ matrix.numpy }} scipy -c ${{ env.workdir }}/channel ${{ env.CHANNELS }}
272272
- name: Report content of test environment
273273
shell: cmd /C CALL {0}
274274
run: |

Diff for: .github/workflows/conda-package.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ jobs:
116116
- name: Install mkl_fft
117117
run: |
118118
CHANNELS="-c $GITHUB_WORKSPACE/channel ${{ env.CHANNELS }}"
119-
conda create -n ${{ env.TEST_ENV_NAME }} python=${{ matrix.python }} $PACKAGE_NAME pytest $CHANNELS
119+
conda create -n ${{ env.TEST_ENV_NAME }} python=${{ matrix.python }} $PACKAGE_NAME pytest scipy $CHANNELS
120120
# Test installed packages
121121
conda list -n ${{ env.TEST_ENV_NAME }}
122122
- name: Run tests
@@ -267,7 +267,7 @@ jobs:
267267
SET PACKAGE_VERSION=%%F
268268
)
269269
SET "TEST_DEPENDENCIES=pytest pytest-cov"
270-
conda install -n ${{ env.TEST_ENV_NAME }} ${{ env.PACKAGE_NAME }}=%PACKAGE_VERSION% %TEST_DEPENDENCIES% python=${{ matrix.python }} -c ${{ env.workdir }}/channel ${{ env.CHANNELS }}
270+
conda install -n ${{ env.TEST_ENV_NAME }} ${{ env.PACKAGE_NAME }}=%PACKAGE_VERSION% %TEST_DEPENDENCIES% python=${{ matrix.python }} scipy -c ${{ env.workdir }}/channel ${{ env.CHANNELS }}
271271
- name: Report content of test environment
272272
shell: cmd /C CALL {0}
273273
run: |

Diff for: .gitignore

+6-5
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
_vendored/__pycache__/
1+
# CMake build and local install directory
22
build/
33
mkl_fft.egg-info/
4-
mkl_fft/__pycache__/
4+
5+
# Byte-compiled / optimized / DLL files
6+
__pycache__/
7+
58
mkl_fft/_pydfti.c
6-
mkl_fft/_pydfti.cpython-310-x86_64-linux-gnu.so
7-
mkl_fft/interfaces/__pycache__/
9+
mkl_fft/_pydfti.cpython*.so
810
mkl_fft/src/mklfft.c
9-
mkl_fft/tests/__pycache__/

Diff for: conda-recipe-cf/meta.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ test:
3232
- pytest -v --pyargs mkl_fft
3333
requires:
3434
- pytest
35+
- scipy
3536
imports:
3637
- mkl_fft
3738
- mkl_fft.interfaces

Diff for: conda-recipe/meta.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ test:
3232
- pytest -v --pyargs mkl_fft
3333
requires:
3434
- pytest
35+
- scipy
3536
imports:
3637
- mkl_fft
3738
- mkl_fft.interfaces

Diff for: mkl_fft/_fft_utils.py

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#!/usr/bin/env python
2+
# Copyright (c) 2025, Intel Corporation
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# * Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
# * Redistributions in binary form must reproduce the above copyright
10+
# notice, this list of conditions and the following disclaimer in the
11+
# documentation and/or other materials provided with the distribution.
12+
# * Neither the name of Intel Corporation nor the names of its contributors
13+
# may be used to endorse or promote products derived from this software
14+
# without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
20+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
import numpy as np
28+
29+
__all__ = ["_check_norm", "_compute_fwd_scale"]
30+
31+
32+
def _check_norm(norm):
33+
if norm not in (None, "ortho", "forward", "backward"):
34+
raise ValueError(
35+
f"Invalid norm value {norm} should be None, 'ortho', 'forward', "
36+
"or 'backward'."
37+
)
38+
39+
40+
def _compute_fwd_scale(norm, n, shape):
41+
_check_norm(norm)
42+
if norm in (None, "backward"):
43+
return 1.0
44+
45+
ss = n if n is not None else shape
46+
nn = np.prod(ss)
47+
fsc = 1 / nn if nn != 0 else 1
48+
if norm == "forward":
49+
return fsc
50+
else: # norm == "ortho"
51+
return np.sqrt(fsc)

Diff for: mkl_fft/_float_utils.py

+18-18
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@
2727
import numpy as np
2828

2929
__all__ = [
30-
"__upcast_float16_array",
31-
"__downcast_float128_array",
32-
"__supported_array_or_not_implemented",
30+
"_upcast_float16_array",
31+
"_downcast_float128_array",
32+
"_supported_array_or_not_implemented",
3333
]
3434

3535

36-
def __upcast_float16_array(x):
36+
def _upcast_float16_array(x):
3737
"""
3838
Used in _scipy_fft to upcast float16 to float32,
3939
instead of float64, as mkl_fft would do
@@ -46,18 +46,18 @@ def __upcast_float16_array(x):
4646
if xdt == np.longdouble and not xdt == np.float64:
4747
raise ValueError("type %s is not supported" % xdt)
4848
if not isinstance(x, np.ndarray):
49-
__x = np.asarray(x)
50-
xdt = __x.dtype
49+
_x = np.asarray(x)
50+
xdt = _x.dtype
5151
if xdt == np.half:
5252
# no half-precision routines, so convert to single precision
53-
return np.asarray(__x, dtype=np.float32)
53+
return np.asarray(_x, dtype=np.float32)
5454
if xdt == np.longdouble and not xdt == np.float64:
5555
raise ValueError("type %s is not supported" % xdt)
56-
return __x
56+
return _x
5757
return x
5858

5959

60-
def __downcast_float128_array(x):
60+
def _downcast_float128_array(x):
6161
"""
6262
Used in _numpy_fft to unsafely downcast float128/complex256 to
6363
complex128, instead of raising an error
@@ -69,27 +69,27 @@ def __downcast_float128_array(x):
6969
elif xdt == np.clongdouble and not xdt == np.complex128:
7070
return np.asarray(x, dtype=np.complex128)
7171
if not isinstance(x, np.ndarray):
72-
__x = np.asarray(x)
73-
xdt = __x.dtype
72+
_x = np.asarray(x)
73+
xdt = _x.dtype
7474
if xdt == np.longdouble and not xdt == np.float64:
7575
return np.asarray(x, dtype=np.float64)
7676
elif xdt == np.clongdouble and not xdt == np.complex128:
7777
return np.asarray(x, dtype=np.complex128)
78-
return __x
78+
return _x
7979
return x
8080

8181

82-
def __supported_array_or_not_implemented(x):
82+
def _supported_array_or_not_implemented(x):
8383
"""
8484
Used in _scipy_fft to convert array to float32,
85-
float64, complex64, or complex128 type or return NotImplemented
85+
float64, complex64, or complex128 type or raise NotImplementedError
8686
"""
87-
__x = np.asarray(x)
87+
_x = np.asarray(x)
8888
black_list = [np.half]
8989
if hasattr(np, "float128"):
9090
black_list.append(np.float128)
9191
if hasattr(np, "complex256"):
9292
black_list.append(np.complex256)
93-
if __x.dtype in black_list:
94-
return NotImplemented
95-
return __x
93+
if _x.dtype in black_list:
94+
raise NotImplementedError
95+
return _x

0 commit comments

Comments
 (0)