Skip to content

Commit 28caadf

Browse files
committed
changes to raise warning only when numpy > 2.0 is used
1 parent 65727ea commit 28caadf

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

Diff for: mkl_fft/_numpy_fft.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
import re
7474
import warnings
7575

76+
import numpy as np
7677
from numpy import array, asanyarray, conjugate, prod, sqrt, take
7778

7879
from . import _float_utils
@@ -701,7 +702,7 @@ def _cook_nd_args(a, s=None, axes=None, invreal=False):
701702
shapeless = False
702703
s = list(s)
703704
if axes is None:
704-
if not shapeless:
705+
if not shapeless and np.__version__ >= "2.0":
705706
msg = (
706707
"`axes` should not be `None` if `s` is not `None` "
707708
"(Deprecated in NumPy 2.0). In a future version of NumPy, "
@@ -716,7 +717,7 @@ def _cook_nd_args(a, s=None, axes=None, invreal=False):
716717
raise ValueError("Shape and axes have different lengths.")
717718
if invreal and shapeless:
718719
s[-1] = (a.shape[axes[-1]] - 1) * 2
719-
if None in s:
720+
if None in s and np.__version__ >= "2.0":
720721
msg = (
721722
"Passing an array containing `None` values to `s` is "
722723
"deprecated in NumPy 2.0 and will raise an error in "

Diff for: mkl_fft/tests/test_pocketfft.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,7 @@ def test_s_negative_1(self, op):
510510
# should use the whole input array along the first axis
511511
assert op(x, s=(-1, 5), axes=(0, 1)).shape == (10, 5)
512512

513-
@pytest.mark.skip("no warning is raised in mkl_ftt")
513+
@pytest.mark.skipif(np.__version__ < "2.0", reason="Requires numpy >= 2.0")
514514
@pytest.mark.parametrize(
515515
"op", [mkl_fft.fftn, mkl_fft.ifftn, mkl_fft.rfftn, mkl_fft.irfftn]
516516
)
@@ -519,13 +519,14 @@ def test_s_axes_none(self, op):
519519
with pytest.warns(match="`axes` should not be `None` if `s`"):
520520
op(x, s=(-1, 5))
521521

522+
@pytest.mark.skipif(np.__version__ < "2.0", reason="Requires numpy >= 2.0")
522523
@pytest.mark.parametrize("op", [mkl_fft.fft2, mkl_fft.ifft2])
523524
def test_s_axes_none_2D(self, op):
524525
x = np.arange(100).reshape(10, 10)
525526
with pytest.warns(match="`axes` should not be `None` if `s`"):
526527
op(x, s=(-1, 5), axes=None)
527528

528-
@pytest.mark.skip("no warning is raised in mkl_ftt")
529+
@pytest.mark.skipif(np.__version__ < "2.0", reason="Requires numpy >= 2.0")
529530
@pytest.mark.parametrize(
530531
"op",
531532
[

0 commit comments

Comments
 (0)