From f228b584e87dfa1280561e15373f04d530b10762 Mon Sep 17 00:00:00 2001
From: Aaron Meurer <asmeurer@gmail.com>
Date: Mon, 8 Jan 2024 18:40:18 -0700
Subject: [PATCH 01/13] Add fft support for numpy and cupy

This is based off of https://github.com/numpy/numpy/pull/25317
---
 array_api_compat/common/_fft.py    | 183 +++++++++++++++++++++++++++++
 array_api_compat/cupy/__init__.py  |   2 +
 array_api_compat/cupy/fft.py       |  29 +++++
 array_api_compat/numpy/__init__.py |   2 +
 array_api_compat/numpy/fft.py      |  29 +++++
 5 files changed, 245 insertions(+)
 create mode 100644 array_api_compat/common/_fft.py
 create mode 100644 array_api_compat/cupy/fft.py
 create mode 100644 array_api_compat/numpy/fft.py

diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py
new file mode 100644
index 00000000..4d2bb3fe
--- /dev/null
+++ b/array_api_compat/common/_fft.py
@@ -0,0 +1,183 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Union, Optional, Literal
+
+if TYPE_CHECKING:
+    from ._typing import Device, ndarray
+    from collections.abc import Sequence
+
+# Note: NumPy fft functions improperly upcast float32 and complex64 to
+# complex128, which is why we require wrapping them all here.
+
+def fft(
+    x: ndarray,
+    /,
+    xp,
+    *,
+    n: Optional[int] = None,
+    axis: int = -1,
+    norm: Literal["backward", "ortho", "forward"] = "backward",
+) -> ndarray:
+    res = xp.fft.fft(x, n=n, axis=axis, norm=norm)
+    if x.dtype in [xp.float32, xp.complex64]:
+        return res.astype(xp.complex64)
+    return res
+
+def ifft(
+    x: ndarray,
+    /,
+    xp,
+    *,
+    n: Optional[int] = None,
+    axis: int = -1,
+    norm: Literal["backward", "ortho", "forward"] = "backward",
+) -> ndarray:
+    res = xp.fft.ifft(x, n=n, axis=axis, norm=norm)
+    if x.dtype in [xp.float32, xp.complex64]:
+        return res.astype(xp.complex64)
+    return res
+
+def fftn(
+    x: ndarray,
+    /,
+    xp,
+    *,
+    s: Sequence[int] = None,
+    axes: Sequence[int] = None,
+    norm: Literal["backward", "ortho", "forward"] = "backward",
+) -> ndarray:
+    res = xp.fft.fftn(x, s=s, axes=axes, norm=norm)
+    if x.dtype in [xp.float32, xp.complex64]:
+        return res.astype(xp.complex64)
+    return res
+
+def ifftn(
+    x: ndarray,
+    /,
+    xp,
+    *,
+    s: Sequence[int] = None,
+    axes: Sequence[int] = None,
+    norm: Literal["backward", "ortho", "forward"] = "backward",
+) -> ndarray:
+    res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm)
+    if x.dtype in [xp.float32, xp.complex64]:
+        return res.astype(xp.complex64)
+    return res
+
+def rfft(
+    x: ndarray,
+    /,
+    xp,
+    *,
+    n: Optional[int] = None,
+    axis: int = -1,
+    norm: Literal["backward", "ortho", "forward"] = "backward",
+) -> ndarray:
+    res = xp.fft.rfft(x, n=n, axis=axis, norm=norm)
+    if x.dtype == xp.float32:
+        return res.astype(xp.complex64)
+    return res
+
+def irfft(
+    x: ndarray,
+    /,
+    xp,
+    *,
+    n: Optional[int] = None,
+    axis: int = -1,
+    norm: Literal["backward", "ortho", "forward"] = "backward",
+) -> ndarray:
+    res = xp.fft.irfft(x, n=n, axis=axis, norm=norm)
+    if x.dtype == xp.complex64:
+        return res.astype(xp.float32)
+    return res
+
+def rfftn(
+    x: ndarray,
+    /,
+    xp,
+    *,
+    s: Sequence[int] = None,
+    axes: Sequence[int] = None,
+    norm: Literal["backward", "ortho", "forward"] = "backward",
+) -> ndarray:
+    res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm)
+    if x.dtype == xp.float32:
+        return res.astype(xp.complex64)
+    return res
+
+def irfftn(
+    x: ndarray,
+    /,
+    xp,
+    *,
+    s: Sequence[int] = None,
+    axes: Sequence[int] = None,
+    norm: Literal["backward", "ortho", "forward"] = "backward",
+) -> ndarray:
+    res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm)
+    if x.dtype == xp.complex64:
+        return res.astype(xp.float32)
+    return res
+
+def hfft(
+    x: ndarray,
+    /,
+    xp,
+    *,
+    n: Optional[int] = None,
+    axis: int = -1,
+    norm: Literal["backward", "ortho", "forward"] = "backward",
+) -> ndarray:
+    res = xp.fft.hfft(x, n=n, axis=axis, norm=norm)
+    if x.dtype in [xp.float32, xp.complex64]:
+        return res.astype(xp.complex64)
+    return res
+
+def ihfft(
+    x: ndarray,
+    /,
+    xp,
+    *,
+    n: Optional[int] = None,
+    axis: int = -1,
+    norm: Literal["backward", "ortho", "forward"] = "backward",
+) -> ndarray:
+    res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm)
+    if x.dtype in [xp.float32, xp.complex64]:
+        return res.astype(xp.complex64)
+    return res
+
+def fftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray:
+    if device not in ["cpu", None]:
+        raise ValueError(f"Unsupported device {device!r}")
+    return xp.fft.fftfreq(n, d=d)
+
+def rfftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray:
+    if device not in ["cpu", None]:
+        raise ValueError(f"Unsupported device {device!r}")
+    return xp.fft.rfftfreq(n, d=d)
+
+def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
+    return xp.fft.fftshift(x, axes=axes)
+
+def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
+    return xp.fft.ifftshift(x, axes=axes)
+
+__all__ = [
+    "fft",
+    "ifft",
+    "fftn",
+    "ifftn",
+    "rfft",
+    "irfft",
+    "rfftn",
+    "irfftn",
+    "hfft",
+    "ihfft",
+    "fftfreq",
+    "rfftfreq",
+    "fftshift",
+    "ifftshift",
+]
diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py
index ec113f9d..d820e44b 100644
--- a/array_api_compat/cupy/__init__.py
+++ b/array_api_compat/cupy/__init__.py
@@ -9,6 +9,8 @@
 # See the comment in the numpy __init__.py
 __import__(__package__ + '.linalg')
 
+__import__(__package__ + '.fft')
+
 from .linalg import matrix_transpose, vecdot
 
 from ..common._helpers import *
diff --git a/array_api_compat/cupy/fft.py b/array_api_compat/cupy/fft.py
new file mode 100644
index 00000000..8e83abb8
--- /dev/null
+++ b/array_api_compat/cupy/fft.py
@@ -0,0 +1,29 @@
+from cupy.fft import *
+from cupy.fft import __all__ as fft_all
+
+from ..common import _fft
+from .._internal import get_xp
+
+import cupy as cp
+
+fft = get_xp(cp)(_fft.fft),
+ifft = get_xp(cp)(_fft.ifft),
+fftn = get_xp(cp)(_fft.fftn),
+ifftn = get_xp(cp)(_fft.ifftn),
+rfft = get_xp(cp)(_fft.rfft),
+irfft = get_xp(cp)(_fft.irfft),
+rfftn = get_xp(cp)(_fft.rfftn),
+irfftn = get_xp(cp)(_fft.irfftn),
+hfft = get_xp(cp)(_fft.hfft),
+ihfft = get_xp(cp)(_fft.ihfft),
+fftfreq = get_xp(cp)(_fft.fftfreq),
+rfftfreq = get_xp(cp)(_fft.rfftfreq),
+fftshift = get_xp(cp)(_fft.fftshift),
+ifftshift = get_xp(cp)(_fft.ifftshift),
+
+__all__ = fft_all + _fft.__all__
+
+del get_xp
+del cp
+del fft_all
+del _fft
diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py
index 4a49f2f1..ff5efdfd 100644
--- a/array_api_compat/numpy/__init__.py
+++ b/array_api_compat/numpy/__init__.py
@@ -15,6 +15,8 @@
 # dynamically so that the library can be vendored.
 __import__(__package__ + '.linalg')
 
+__import__(__package__ + '.fft')
+
 from .linalg import matrix_transpose, vecdot
 
 from ..common._helpers import *
diff --git a/array_api_compat/numpy/fft.py b/array_api_compat/numpy/fft.py
new file mode 100644
index 00000000..6093b19d
--- /dev/null
+++ b/array_api_compat/numpy/fft.py
@@ -0,0 +1,29 @@
+from numpy.fft import *
+from numpy.fft import __all__ as fft_all
+
+from ..common import _fft
+from .._internal import get_xp
+
+import numpy as np
+
+fft = get_xp(np)(_fft.fft)
+ifft = get_xp(np)(_fft.ifft)
+fftn = get_xp(np)(_fft.fftn)
+ifftn = get_xp(np)(_fft.ifftn)
+rfft = get_xp(np)(_fft.rfft)
+irfft = get_xp(np)(_fft.irfft)
+rfftn = get_xp(np)(_fft.rfftn)
+irfftn = get_xp(np)(_fft.irfftn)
+hfft = get_xp(np)(_fft.hfft)
+ihfft = get_xp(np)(_fft.ihfft)
+fftfreq = get_xp(np)(_fft.fftfreq)
+rfftfreq = get_xp(np)(_fft.rfftfreq)
+fftshift = get_xp(np)(_fft.fftshift)
+ifftshift = get_xp(np)(_fft.ifftshift)
+
+__all__ = fft_all + _fft.__all__
+
+del get_xp
+del np
+del fft_all
+del _fft

From d7a9ecbad6e522ce43d2d4941e10e0792e252941 Mon Sep 17 00:00:00 2001
From: Aaron Meurer <asmeurer@gmail.com>
Date: Tue, 5 Mar 2024 16:37:57 -0700
Subject: [PATCH 02/13] Fix hfft downcasting logic

---
 array_api_compat/common/_fft.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py
index 4d2bb3fe..666b0b1f 100644
--- a/array_api_compat/common/_fft.py
+++ b/array_api_compat/common/_fft.py
@@ -132,7 +132,7 @@ def hfft(
 ) -> ndarray:
     res = xp.fft.hfft(x, n=n, axis=axis, norm=norm)
     if x.dtype in [xp.float32, xp.complex64]:
-        return res.astype(xp.complex64)
+        return res.astype(xp.float32)
     return res
 
 def ihfft(

From f97b59ec343a7290d0ae24554797f00195cb6e45 Mon Sep 17 00:00:00 2001
From: Aaron Meurer <asmeurer@gmail.com>
Date: Tue, 5 Mar 2024 16:39:44 -0700
Subject: [PATCH 03/13] Remove fft xfails

---
 cupy-xfails.txt       | 13 -------------
 numpy-1-21-xfails.txt | 13 -------------
 numpy-dev-xfails.txt  | 13 -------------
 numpy-xfails.txt      | 13 -------------
 torch-xfails.txt      | 13 -------------
 5 files changed, 65 deletions(-)

diff --git a/cupy-xfails.txt b/cupy-xfails.txt
index cfacbe33..e76c4c32 100644
--- a/cupy-xfails.txt
+++ b/cupy-xfails.txt
@@ -164,16 +164,3 @@ array_api_tests/test_special_cases.py::test_unary[sqrt(x_i is -0) -> -0]
 array_api_tests/test_special_cases.py::test_unary[tan(x_i is -0) -> -0]
 array_api_tests/test_special_cases.py::test_unary[tanh(x_i is -0) -> -0]
 array_api_tests/test_special_cases.py::test_unary[trunc(x_i is -0) -> -0]
-
-# fft functions are not yet supported
-# (https://github.com/data-apis/array-api-compat/issues/67)
-array_api_tests/test_fft.py::test_fft
-array_api_tests/test_fft.py::test_ifft
-array_api_tests/test_fft.py::test_fftn
-array_api_tests/test_fft.py::test_ifftn
-array_api_tests/test_fft.py::test_rfft
-array_api_tests/test_fft.py::test_irfft
-array_api_tests/test_fft.py::test_rfftn
-array_api_tests/test_fft.py::test_irfftn
-array_api_tests/test_fft.py::test_hfft
-array_api_tests/test_fft.py::test_ihfft
diff --git a/numpy-1-21-xfails.txt b/numpy-1-21-xfails.txt
index 9a0d2827..dce83859 100644
--- a/numpy-1-21-xfails.txt
+++ b/numpy-1-21-xfails.txt
@@ -50,19 +50,6 @@ array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x
 array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
 array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices
 
-# fft functions are not yet supported
-# (https://github.com/data-apis/array-api-compat/issues/67)
-array_api_tests/test_fft.py::test_fft
-array_api_tests/test_fft.py::test_ifft
-array_api_tests/test_fft.py::test_fftn
-array_api_tests/test_fft.py::test_ifftn
-array_api_tests/test_fft.py::test_rfft
-array_api_tests/test_fft.py::test_irfft
-array_api_tests/test_fft.py::test_rfftn
-array_api_tests/test_fft.py::test_irfftn
-array_api_tests/test_fft.py::test_hfft
-array_api_tests/test_fft.py::test_ihfft
-
 # NumPy 1.21 specific XFAILS
 ############################
 
diff --git a/numpy-dev-xfails.txt b/numpy-dev-xfails.txt
index 5e270a95..8d291d01 100644
--- a/numpy-dev-xfails.txt
+++ b/numpy-dev-xfails.txt
@@ -42,16 +42,3 @@ array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices
 # The test suite is incorrectly checking sums that have loss of significance
 # (https://github.com/data-apis/array-api-tests/issues/168)
 array_api_tests/test_statistical_functions.py::test_sum
-
-# fft functions are not yet supported
-# (https://github.com/data-apis/array-api-compat/issues/67)
-array_api_tests/test_fft.py::test_fft
-array_api_tests/test_fft.py::test_ifft
-array_api_tests/test_fft.py::test_fftn
-array_api_tests/test_fft.py::test_ifftn
-array_api_tests/test_fft.py::test_rfft
-array_api_tests/test_fft.py::test_irfft
-array_api_tests/test_fft.py::test_rfftn
-array_api_tests/test_fft.py::test_irfftn
-array_api_tests/test_fft.py::test_hfft
-array_api_tests/test_fft.py::test_ihfft
diff --git a/numpy-xfails.txt b/numpy-xfails.txt
index d0be245b..e44d7035 100644
--- a/numpy-xfails.txt
+++ b/numpy-xfails.txt
@@ -44,16 +44,3 @@ array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices
 # The test suite is incorrectly checking sums that have loss of significance
 # (https://github.com/data-apis/array-api-tests/issues/168)
 array_api_tests/test_statistical_functions.py::test_sum
-
-# fft functions are not yet supported
-# (https://github.com/data-apis/array-api-compat/issues/67)
-array_api_tests/test_fft.py::test_fft
-array_api_tests/test_fft.py::test_ifft
-array_api_tests/test_fft.py::test_fftn
-array_api_tests/test_fft.py::test_ifftn
-array_api_tests/test_fft.py::test_rfft
-array_api_tests/test_fft.py::test_irfft
-array_api_tests/test_fft.py::test_rfftn
-array_api_tests/test_fft.py::test_irfftn
-array_api_tests/test_fft.py::test_hfft
-array_api_tests/test_fft.py::test_ihfft
diff --git a/torch-xfails.txt b/torch-xfails.txt
index caf1aa65..a9106fae 100644
--- a/torch-xfails.txt
+++ b/torch-xfails.txt
@@ -190,16 +190,3 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_expm1
 array_api_tests/test_operators_and_elementwise_functions.py::test_round
 array_api_tests/test_set_functions.py::test_unique_counts
 array_api_tests/test_set_functions.py::test_unique_values
-
-# fft functions are not yet supported
-# (https://github.com/data-apis/array-api-compat/issues/67)
-array_api_tests/test_fft.py::test_fftn
-array_api_tests/test_fft.py::test_ifftn
-array_api_tests/test_fft.py::test_rfft
-array_api_tests/test_fft.py::test_irfft
-array_api_tests/test_fft.py::test_rfftn
-array_api_tests/test_fft.py::test_irfftn
-array_api_tests/test_fft.py::test_hfft
-array_api_tests/test_fft.py::test_ihfft
-array_api_tests/test_fft.py::test_shift_func[fftshift]
-array_api_tests/test_fft.py::test_shift_func[ifftshift]

From 1ea7ecd97f98e025c6c28ce2a18b7ff36e128354 Mon Sep 17 00:00:00 2001
From: Aaron Meurer <asmeurer@gmail.com>
Date: Tue, 5 Mar 2024 16:53:14 -0700
Subject: [PATCH 04/13] Add wrappers for torch.fft

The only thing that needs to be wrapped is a few functions which do not
properly map axes to dim.
---
 array_api_compat/torch/__init__.py |  2 +
 array_api_compat/torch/fft.py      | 84 ++++++++++++++++++++++++++++++
 2 files changed, 86 insertions(+)
 create mode 100644 array_api_compat/torch/fft.py

diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py
index 59898aab..172f5279 100644
--- a/array_api_compat/torch/__init__.py
+++ b/array_api_compat/torch/__init__.py
@@ -17,6 +17,8 @@
 # See the comment in the numpy __init__.py
 __import__(__package__ + '.linalg')
 
+__import__(__package__ + '.fft')
+
 from ..common._helpers import * # noqa: F403
 
 __array_api_version__ = '2022.12'
diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py
new file mode 100644
index 00000000..dbf74cb0
--- /dev/null
+++ b/array_api_compat/torch/fft.py
@@ -0,0 +1,84 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    import torch
+    array = torch.Tensor
+    from typing import Union, Sequence, Literal
+
+from torch.fft import * # noqa: F403
+import torch.fft
+
+# Several torch fft functions do not map axes to dim
+
+def fftn(
+    x: array,
+    /,
+    *,
+    s: Sequence[int] = None,
+    axes: Sequence[int] = None,
+    norm: Literal["backward", "ortho", "forward"] = "backward",
+    **kwargs,
+) -> array:
+    return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs)
+
+def ifftn(
+    x: array,
+    /,
+    *,
+    s: Sequence[int] = None,
+    axes: Sequence[int] = None,
+    norm: Literal["backward", "ortho", "forward"] = "backward",
+    **kwargs,
+) -> array:
+    return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs)
+
+def rfftn(
+    x: array,
+    /,
+    *,
+    s: Sequence[int] = None,
+    axes: Sequence[int] = None,
+    norm: Literal["backward", "ortho", "forward"] = "backward",
+    **kwargs,
+) -> array:
+    return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs)
+
+def irfftn(
+    x: array,
+    /,
+    *,
+    s: Sequence[int] = None,
+    axes: Sequence[int] = None,
+    norm: Literal["backward", "ortho", "forward"] = "backward",
+    **kwargs,
+) -> array:
+    return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs)
+
+def fftshift(
+    x: array,
+    /,
+    *,
+    axes: Union[int, Sequence[int]] = None,
+    **kwargs,
+) -> array:
+    return torch.fft.fftshift(x, dim=axes, **kwargs)
+
+def ifftshift(
+    x: array,
+    /,
+    *,
+    axes: Union[int, Sequence[int]] = None,
+    **kwargs,
+) -> array:
+    return torch.fft.ifftshift(x, dim=axes, **kwargs)
+
+
+__all__ = torch.fft.__all__ + [
+    "fftn",
+    "ifftn",
+    "rfftn",
+    "irfftn",
+    "fftshift",
+    "ifftshift",
+]

From 18960aafc9bbdad120a251e49e4e98a0bf89649f Mon Sep 17 00:00:00 2001
From: Aaron Meurer <asmeurer@gmail.com>
Date: Tue, 5 Mar 2024 18:05:59 -0700
Subject: [PATCH 05/13] Fix ruff and tests

---
 array_api_compat/cupy/__init__.py | 2 +-
 array_api_compat/cupy/fft.py      | 2 +-
 array_api_compat/numpy/fft.py     | 2 +-
 array_api_compat/torch/fft.py     | 2 ++
 4 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py
index 697c20f3..7968d68d 100644
--- a/array_api_compat/cupy/__init__.py
+++ b/array_api_compat/cupy/__init__.py
@@ -1,4 +1,4 @@
-from cupy import *
+from cupy import * # noqa: F403
 
 # from cupy import * doesn't overwrite these builtin names
 from cupy import abs, max, min, round # noqa: F401
diff --git a/array_api_compat/cupy/fft.py b/array_api_compat/cupy/fft.py
index 8e83abb8..297a52b6 100644
--- a/array_api_compat/cupy/fft.py
+++ b/array_api_compat/cupy/fft.py
@@ -1,4 +1,4 @@
-from cupy.fft import *
+from cupy.fft import * # noqa: F403
 from cupy.fft import __all__ as fft_all
 
 from ..common import _fft
diff --git a/array_api_compat/numpy/fft.py b/array_api_compat/numpy/fft.py
index 6093b19d..28667594 100644
--- a/array_api_compat/numpy/fft.py
+++ b/array_api_compat/numpy/fft.py
@@ -1,4 +1,4 @@
-from numpy.fft import *
+from numpy.fft import * # noqa: F403
 from numpy.fft import __all__ as fft_all
 
 from ..common import _fft
diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py
index dbf74cb0..3c9117ee 100644
--- a/array_api_compat/torch/fft.py
+++ b/array_api_compat/torch/fft.py
@@ -82,3 +82,5 @@ def ifftshift(
     "fftshift",
     "ifftshift",
 ]
+
+_all_ignore = ['torch']

From da6d4e44572339a6a40b5ce0ef6e46d417481e9c Mon Sep 17 00:00:00 2001
From: Aaron Meurer <asmeurer@gmail.com>
Date: Tue, 5 Mar 2024 18:12:02 -0700
Subject: [PATCH 06/13] Fix cupy fft __all__

---
 array_api_compat/cupy/fft.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/array_api_compat/cupy/fft.py b/array_api_compat/cupy/fft.py
index 297a52b6..db1f8047 100644
--- a/array_api_compat/cupy/fft.py
+++ b/array_api_compat/cupy/fft.py
@@ -1,5 +1,12 @@
 from cupy.fft import * # noqa: F403
-from cupy.fft import __all__ as fft_all
+# cupy.fft doesn't have __all__. If it is added, replace this with
+#
+# from cupy.fft import __all__ as linalg_all
+_n = {}
+exec('from cupy.fft import *', _n)
+del _n['__builtins__']
+fft_all = list(_n)
+del _n
 
 from ..common import _fft
 from .._internal import get_xp

From 912e80c146c41a25e034e5cd04b83cff9ee28c3e Mon Sep 17 00:00:00 2001
From: Aaron Meurer <asmeurer@gmail.com>
Date: Tue, 5 Mar 2024 18:13:43 -0700
Subject: [PATCH 07/13] Avoid testing against vendored array_api_compat in
 test_all

---
 tests/test_all.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_all.py b/tests/test_all.py
index 5b49fa14..7a6f74f0 100644
--- a/tests/test_all.py
+++ b/tests/test_all.py
@@ -21,7 +21,7 @@ def test_all(library):
     import_(library, wrapper=True)
 
     for mod_name in sys.modules:
-        if 'array_api_compat.' + library not in mod_name:
+        if not mod_name.startswith('array_api_compat.' + library):
             continue
 
         module = sys.modules[mod_name]

From 4018fe43d0ead5ed83fbcf07580197801bc35f1e Mon Sep 17 00:00:00 2001
From: Aaron Meurer <asmeurer@gmail.com>
Date: Tue, 5 Mar 2024 18:15:52 -0700
Subject: [PATCH 08/13] Fix import_('cupy', wrapper=True) tests helper

---
 tests/_helpers.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/_helpers.py b/tests/_helpers.py
index 23cb5db9..c41bc881 100644
--- a/tests/_helpers.py
+++ b/tests/_helpers.py
@@ -6,8 +6,6 @@
 
 
 def import_(library, wrapper=False):
-    if library == 'cupy':
-        return pytest.importorskip(library)
     if 'jax' in library and sys.version_info < (3, 9):
         pytest.skip('JAX array API support does not support Python 3.8')
 
@@ -16,5 +14,7 @@ def import_(library, wrapper=False):
             library = 'jax.experimental.array_api'
         else:
             library = 'array_api_compat.' + library
+    elif library == 'cupy':
+        return pytest.importorskip(library)
 
     return import_module(library)

From d7f95a32e61e168042d1baae4c3919da5a714662 Mon Sep 17 00:00:00 2001
From: Aaron Meurer <asmeurer@gmail.com>
Date: Tue, 5 Mar 2024 18:16:58 -0700
Subject: [PATCH 09/13] Fix test_all for cupy

---
 array_api_compat/cupy/_aliases.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py
index 968b974b..b9364ac6 100644
--- a/array_api_compat/cupy/_aliases.py
+++ b/array_api_compat/cupy/_aliases.py
@@ -77,3 +77,5 @@
                               'acosh', 'asin', 'asinh', 'atan', 'atan2',
                               'atanh', 'bitwise_left_shift', 'bitwise_invert',
                               'bitwise_right_shift', 'concat', 'pow']
+
+_all_ignore = ['cp', 'get_xp']

From 29ec4d6d7c79da4162f4ba3b3d1be267f91b2b91 Mon Sep 17 00:00:00 2001
From: Aaron Meurer <asmeurer@gmail.com>
Date: Tue, 5 Mar 2024 18:18:31 -0700
Subject: [PATCH 10/13] Fix array api tests pytest call in test_cupy.sh

---
 test_cupy.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test_cupy.sh b/test_cupy.sh
index 6b4d6b56..3d8c0711 100755
--- a/test_cupy.sh
+++ b/test_cupy.sh
@@ -26,4 +26,4 @@ mkdir -p $SCRIPT_DIR/.hypothesis
 ln -s $SCRIPT_DIR/.hypothesis .hypothesis
 
 export ARRAY_API_TESTS_MODULE=array_api_compat.cupy
-pytest ${PYTEST_ARGS} --xfails-file $SCRIPT_DIR/cupy-xfails.txt "$@"
+pytest array_api_tests/ ${PYTEST_ARGS} --xfails-file $SCRIPT_DIR/cupy-xfails.txt "$@"

From de64f5ecd7d1a1e5abb87b98ea7ad5e3759a5825 Mon Sep 17 00:00:00 2001
From: Aaron Meurer <asmeurer@gmail.com>
Date: Tue, 5 Mar 2024 18:59:31 -0700
Subject: [PATCH 11/13] Remove a bunch of incorrect trailing commas

---
 array_api_compat/cupy/fft.py | 28 ++++++++++++++--------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/array_api_compat/cupy/fft.py b/array_api_compat/cupy/fft.py
index db1f8047..307e0f72 100644
--- a/array_api_compat/cupy/fft.py
+++ b/array_api_compat/cupy/fft.py
@@ -13,20 +13,20 @@
 
 import cupy as cp
 
-fft = get_xp(cp)(_fft.fft),
-ifft = get_xp(cp)(_fft.ifft),
-fftn = get_xp(cp)(_fft.fftn),
-ifftn = get_xp(cp)(_fft.ifftn),
-rfft = get_xp(cp)(_fft.rfft),
-irfft = get_xp(cp)(_fft.irfft),
-rfftn = get_xp(cp)(_fft.rfftn),
-irfftn = get_xp(cp)(_fft.irfftn),
-hfft = get_xp(cp)(_fft.hfft),
-ihfft = get_xp(cp)(_fft.ihfft),
-fftfreq = get_xp(cp)(_fft.fftfreq),
-rfftfreq = get_xp(cp)(_fft.rfftfreq),
-fftshift = get_xp(cp)(_fft.fftshift),
-ifftshift = get_xp(cp)(_fft.ifftshift),
+fft = get_xp(cp)(_fft.fft)
+ifft = get_xp(cp)(_fft.ifft)
+fftn = get_xp(cp)(_fft.fftn)
+ifftn = get_xp(cp)(_fft.ifftn)
+rfft = get_xp(cp)(_fft.rfft)
+irfft = get_xp(cp)(_fft.irfft)
+rfftn = get_xp(cp)(_fft.rfftn)
+irfftn = get_xp(cp)(_fft.irfftn)
+hfft = get_xp(cp)(_fft.hfft)
+ihfft = get_xp(cp)(_fft.ihfft)
+fftfreq = get_xp(cp)(_fft.fftfreq)
+rfftfreq = get_xp(cp)(_fft.rfftfreq)
+fftshift = get_xp(cp)(_fft.fftshift)
+ifftshift = get_xp(cp)(_fft.ifftshift)
 
 __all__ = fft_all + _fft.__all__
 

From 9e613cce3197f53f5f0bd7a03d676c4218521923 Mon Sep 17 00:00:00 2001
From: Aaron Meurer <asmeurer@gmail.com>
Date: Thu, 7 Mar 2024 14:28:09 -0700
Subject: [PATCH 12/13] Fix cupy skipping in the tests

---
 tests/_helpers.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/_helpers.py b/tests/_helpers.py
index c41bc881..e8421b52 100644
--- a/tests/_helpers.py
+++ b/tests/_helpers.py
@@ -9,12 +9,12 @@ def import_(library, wrapper=False):
     if 'jax' in library and sys.version_info < (3, 9):
         pytest.skip('JAX array API support does not support Python 3.8')
 
+    if library == 'cupy':
+        pytest.importorskip(library)
     if wrapper:
         if 'jax' in library:
             library = 'jax.experimental.array_api'
         else:
             library = 'array_api_compat.' + library
-    elif library == 'cupy':
-        return pytest.importorskip(library)
 
     return import_module(library)

From cb46aad155d87d7a75aa2d04e8d4c80c7196d691 Mon Sep 17 00:00:00 2001
From: Aaron Meurer <asmeurer@gmail.com>
Date: Thu, 7 Mar 2024 14:36:57 -0700
Subject: [PATCH 13/13] Add xfails for cupy n-dim fft funcs

---
 cupy-xfails.txt | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/cupy-xfails.txt b/cupy-xfails.txt
index e76c4c32..85ca5aa4 100644
--- a/cupy-xfails.txt
+++ b/cupy-xfails.txt
@@ -164,3 +164,9 @@ array_api_tests/test_special_cases.py::test_unary[sqrt(x_i is -0) -> -0]
 array_api_tests/test_special_cases.py::test_unary[tan(x_i is -0) -> -0]
 array_api_tests/test_special_cases.py::test_unary[tanh(x_i is -0) -> -0]
 array_api_tests/test_special_cases.py::test_unary[trunc(x_i is -0) -> -0]
+
+# CuPy gives the wrong shape for n-dim fft funcs. See
+# https://github.com/data-apis/array-api-compat/pull/78#issuecomment-1984527870
+array_api_tests/test_fft.py::test_fftn
+array_api_tests/test_fft.py::test_ifftn
+array_api_tests/test_fft.py::test_rfftn