Skip to content

Commit 2ec2dce

Browse files
authored
Merge pull request #78 from asmeurer/fft
Add fft support for numpy and cupy
2 parents 8240d19 + 0864c73 commit 2ec2dce

16 files changed

+348
-65
lines changed

Diff for: array_api_compat/common/_fft.py

+183
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Union, Optional, Literal
4+
5+
if TYPE_CHECKING:
6+
from ._typing import Device, ndarray
7+
from collections.abc import Sequence
8+
9+
# Note: NumPy fft functions improperly upcast float32 and complex64 to
10+
# complex128, which is why we require wrapping them all here.
11+
12+
def fft(
13+
x: ndarray,
14+
/,
15+
xp,
16+
*,
17+
n: Optional[int] = None,
18+
axis: int = -1,
19+
norm: Literal["backward", "ortho", "forward"] = "backward",
20+
) -> ndarray:
21+
res = xp.fft.fft(x, n=n, axis=axis, norm=norm)
22+
if x.dtype in [xp.float32, xp.complex64]:
23+
return res.astype(xp.complex64)
24+
return res
25+
26+
def ifft(
27+
x: ndarray,
28+
/,
29+
xp,
30+
*,
31+
n: Optional[int] = None,
32+
axis: int = -1,
33+
norm: Literal["backward", "ortho", "forward"] = "backward",
34+
) -> ndarray:
35+
res = xp.fft.ifft(x, n=n, axis=axis, norm=norm)
36+
if x.dtype in [xp.float32, xp.complex64]:
37+
return res.astype(xp.complex64)
38+
return res
39+
40+
def fftn(
41+
x: ndarray,
42+
/,
43+
xp,
44+
*,
45+
s: Sequence[int] = None,
46+
axes: Sequence[int] = None,
47+
norm: Literal["backward", "ortho", "forward"] = "backward",
48+
) -> ndarray:
49+
res = xp.fft.fftn(x, s=s, axes=axes, norm=norm)
50+
if x.dtype in [xp.float32, xp.complex64]:
51+
return res.astype(xp.complex64)
52+
return res
53+
54+
def ifftn(
55+
x: ndarray,
56+
/,
57+
xp,
58+
*,
59+
s: Sequence[int] = None,
60+
axes: Sequence[int] = None,
61+
norm: Literal["backward", "ortho", "forward"] = "backward",
62+
) -> ndarray:
63+
res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm)
64+
if x.dtype in [xp.float32, xp.complex64]:
65+
return res.astype(xp.complex64)
66+
return res
67+
68+
def rfft(
69+
x: ndarray,
70+
/,
71+
xp,
72+
*,
73+
n: Optional[int] = None,
74+
axis: int = -1,
75+
norm: Literal["backward", "ortho", "forward"] = "backward",
76+
) -> ndarray:
77+
res = xp.fft.rfft(x, n=n, axis=axis, norm=norm)
78+
if x.dtype == xp.float32:
79+
return res.astype(xp.complex64)
80+
return res
81+
82+
def irfft(
83+
x: ndarray,
84+
/,
85+
xp,
86+
*,
87+
n: Optional[int] = None,
88+
axis: int = -1,
89+
norm: Literal["backward", "ortho", "forward"] = "backward",
90+
) -> ndarray:
91+
res = xp.fft.irfft(x, n=n, axis=axis, norm=norm)
92+
if x.dtype == xp.complex64:
93+
return res.astype(xp.float32)
94+
return res
95+
96+
def rfftn(
97+
x: ndarray,
98+
/,
99+
xp,
100+
*,
101+
s: Sequence[int] = None,
102+
axes: Sequence[int] = None,
103+
norm: Literal["backward", "ortho", "forward"] = "backward",
104+
) -> ndarray:
105+
res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm)
106+
if x.dtype == xp.float32:
107+
return res.astype(xp.complex64)
108+
return res
109+
110+
def irfftn(
111+
x: ndarray,
112+
/,
113+
xp,
114+
*,
115+
s: Sequence[int] = None,
116+
axes: Sequence[int] = None,
117+
norm: Literal["backward", "ortho", "forward"] = "backward",
118+
) -> ndarray:
119+
res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm)
120+
if x.dtype == xp.complex64:
121+
return res.astype(xp.float32)
122+
return res
123+
124+
def hfft(
125+
x: ndarray,
126+
/,
127+
xp,
128+
*,
129+
n: Optional[int] = None,
130+
axis: int = -1,
131+
norm: Literal["backward", "ortho", "forward"] = "backward",
132+
) -> ndarray:
133+
res = xp.fft.hfft(x, n=n, axis=axis, norm=norm)
134+
if x.dtype in [xp.float32, xp.complex64]:
135+
return res.astype(xp.float32)
136+
return res
137+
138+
def ihfft(
139+
x: ndarray,
140+
/,
141+
xp,
142+
*,
143+
n: Optional[int] = None,
144+
axis: int = -1,
145+
norm: Literal["backward", "ortho", "forward"] = "backward",
146+
) -> ndarray:
147+
res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm)
148+
if x.dtype in [xp.float32, xp.complex64]:
149+
return res.astype(xp.complex64)
150+
return res
151+
152+
def fftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray:
153+
if device not in ["cpu", None]:
154+
raise ValueError(f"Unsupported device {device!r}")
155+
return xp.fft.fftfreq(n, d=d)
156+
157+
def rfftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray:
158+
if device not in ["cpu", None]:
159+
raise ValueError(f"Unsupported device {device!r}")
160+
return xp.fft.rfftfreq(n, d=d)
161+
162+
def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
163+
return xp.fft.fftshift(x, axes=axes)
164+
165+
def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
166+
return xp.fft.ifftshift(x, axes=axes)
167+
168+
__all__ = [
169+
"fft",
170+
"ifft",
171+
"fftn",
172+
"ifftn",
173+
"rfft",
174+
"irfft",
175+
"rfftn",
176+
"irfftn",
177+
"hfft",
178+
"ihfft",
179+
"fftfreq",
180+
"rfftfreq",
181+
"fftshift",
182+
"ifftshift",
183+
]

Diff for: array_api_compat/cupy/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
# See the comment in the numpy __init__.py
1010
__import__(__package__ + '.linalg')
1111

12+
__import__(__package__ + '.fft')
13+
1214
from ..common._helpers import * # noqa: F401,F403
1315

1416
__array_api_version__ = '2022.12'

Diff for: array_api_compat/cupy/_aliases.py

+2
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,5 @@
7777
'acosh', 'asin', 'asinh', 'atan', 'atan2',
7878
'atanh', 'bitwise_left_shift', 'bitwise_invert',
7979
'bitwise_right_shift', 'concat', 'pow']
80+
81+
_all_ignore = ['cp', 'get_xp']

Diff for: array_api_compat/cupy/fft.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from cupy.fft import * # noqa: F403
2+
# cupy.fft doesn't have __all__. If it is added, replace this with
3+
#
4+
# from cupy.fft import __all__ as linalg_all
5+
_n = {}
6+
exec('from cupy.fft import *', _n)
7+
del _n['__builtins__']
8+
fft_all = list(_n)
9+
del _n
10+
11+
from ..common import _fft
12+
from .._internal import get_xp
13+
14+
import cupy as cp
15+
16+
fft = get_xp(cp)(_fft.fft)
17+
ifft = get_xp(cp)(_fft.ifft)
18+
fftn = get_xp(cp)(_fft.fftn)
19+
ifftn = get_xp(cp)(_fft.ifftn)
20+
rfft = get_xp(cp)(_fft.rfft)
21+
irfft = get_xp(cp)(_fft.irfft)
22+
rfftn = get_xp(cp)(_fft.rfftn)
23+
irfftn = get_xp(cp)(_fft.irfftn)
24+
hfft = get_xp(cp)(_fft.hfft)
25+
ihfft = get_xp(cp)(_fft.ihfft)
26+
fftfreq = get_xp(cp)(_fft.fftfreq)
27+
rfftfreq = get_xp(cp)(_fft.rfftfreq)
28+
fftshift = get_xp(cp)(_fft.fftshift)
29+
ifftshift = get_xp(cp)(_fft.ifftshift)
30+
31+
__all__ = fft_all + _fft.__all__
32+
33+
del get_xp
34+
del cp
35+
del fft_all
36+
del _fft

Diff for: array_api_compat/numpy/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
# dynamically so that the library can be vendored.
1616
__import__(__package__ + '.linalg')
1717

18+
__import__(__package__ + '.fft')
19+
1820
from .linalg import matrix_transpose, vecdot # noqa: F401
1921

2022
from ..common._helpers import * # noqa: F403

Diff for: array_api_compat/numpy/fft.py

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from numpy.fft import * # noqa: F403
2+
from numpy.fft import __all__ as fft_all
3+
4+
from ..common import _fft
5+
from .._internal import get_xp
6+
7+
import numpy as np
8+
9+
fft = get_xp(np)(_fft.fft)
10+
ifft = get_xp(np)(_fft.ifft)
11+
fftn = get_xp(np)(_fft.fftn)
12+
ifftn = get_xp(np)(_fft.ifftn)
13+
rfft = get_xp(np)(_fft.rfft)
14+
irfft = get_xp(np)(_fft.irfft)
15+
rfftn = get_xp(np)(_fft.rfftn)
16+
irfftn = get_xp(np)(_fft.irfftn)
17+
hfft = get_xp(np)(_fft.hfft)
18+
ihfft = get_xp(np)(_fft.ihfft)
19+
fftfreq = get_xp(np)(_fft.fftfreq)
20+
rfftfreq = get_xp(np)(_fft.rfftfreq)
21+
fftshift = get_xp(np)(_fft.fftshift)
22+
ifftshift = get_xp(np)(_fft.ifftshift)
23+
24+
__all__ = fft_all + _fft.__all__
25+
26+
del get_xp
27+
del np
28+
del fft_all
29+
del _fft

Diff for: array_api_compat/torch/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
# See the comment in the numpy __init__.py
1818
__import__(__package__ + '.linalg')
1919

20+
__import__(__package__ + '.fft')
21+
2022
from ..common._helpers import * # noqa: F403
2123

2224
__array_api_version__ = '2022.12'

Diff for: array_api_compat/torch/fft.py

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
if TYPE_CHECKING:
5+
import torch
6+
array = torch.Tensor
7+
from typing import Union, Sequence, Literal
8+
9+
from torch.fft import * # noqa: F403
10+
import torch.fft
11+
12+
# Several torch fft functions do not map axes to dim
13+
14+
def fftn(
15+
x: array,
16+
/,
17+
*,
18+
s: Sequence[int] = None,
19+
axes: Sequence[int] = None,
20+
norm: Literal["backward", "ortho", "forward"] = "backward",
21+
**kwargs,
22+
) -> array:
23+
return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs)
24+
25+
def ifftn(
26+
x: array,
27+
/,
28+
*,
29+
s: Sequence[int] = None,
30+
axes: Sequence[int] = None,
31+
norm: Literal["backward", "ortho", "forward"] = "backward",
32+
**kwargs,
33+
) -> array:
34+
return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs)
35+
36+
def rfftn(
37+
x: array,
38+
/,
39+
*,
40+
s: Sequence[int] = None,
41+
axes: Sequence[int] = None,
42+
norm: Literal["backward", "ortho", "forward"] = "backward",
43+
**kwargs,
44+
) -> array:
45+
return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs)
46+
47+
def irfftn(
48+
x: array,
49+
/,
50+
*,
51+
s: Sequence[int] = None,
52+
axes: Sequence[int] = None,
53+
norm: Literal["backward", "ortho", "forward"] = "backward",
54+
**kwargs,
55+
) -> array:
56+
return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs)
57+
58+
def fftshift(
59+
x: array,
60+
/,
61+
*,
62+
axes: Union[int, Sequence[int]] = None,
63+
**kwargs,
64+
) -> array:
65+
return torch.fft.fftshift(x, dim=axes, **kwargs)
66+
67+
def ifftshift(
68+
x: array,
69+
/,
70+
*,
71+
axes: Union[int, Sequence[int]] = None,
72+
**kwargs,
73+
) -> array:
74+
return torch.fft.ifftshift(x, dim=axes, **kwargs)
75+
76+
77+
__all__ = torch.fft.__all__ + [
78+
"fftn",
79+
"ifftn",
80+
"rfftn",
81+
"irfftn",
82+
"fftshift",
83+
"ifftshift",
84+
]
85+
86+
_all_ignore = ['torch']

Diff for: cupy-xfails.txt

+2-9
Original file line numberDiff line numberDiff line change
@@ -165,15 +165,8 @@ array_api_tests/test_special_cases.py::test_unary[tan(x_i is -0) -> -0]
165165
array_api_tests/test_special_cases.py::test_unary[tanh(x_i is -0) -> -0]
166166
array_api_tests/test_special_cases.py::test_unary[trunc(x_i is -0) -> -0]
167167

168-
# fft functions are not yet supported
169-
# (https://github.com/data-apis/array-api-compat/issues/67)
170-
array_api_tests/test_fft.py::test_fft
171-
array_api_tests/test_fft.py::test_ifft
168+
# CuPy gives the wrong shape for n-dim fft funcs. See
169+
# https://github.com/data-apis/array-api-compat/pull/78#issuecomment-1984527870
172170
array_api_tests/test_fft.py::test_fftn
173171
array_api_tests/test_fft.py::test_ifftn
174-
array_api_tests/test_fft.py::test_rfft
175-
array_api_tests/test_fft.py::test_irfft
176172
array_api_tests/test_fft.py::test_rfftn
177-
array_api_tests/test_fft.py::test_irfftn
178-
array_api_tests/test_fft.py::test_hfft
179-
array_api_tests/test_fft.py::test_ihfft

0 commit comments

Comments
 (0)