Skip to content

Commit d12e561

Browse files
authored
Merge pull request #249 from ev-br/fftfreq_dtypes
ENH: add dtype argument to fft.{fftfreq,rfftfreq}
2 parents cb6a3ec + 6cf2edf commit d12e561

File tree

2 files changed

+28
-5
lines changed

2 files changed

+28
-5
lines changed

array_api_compat/common/_fft.py

+27-5
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import TYPE_CHECKING, Union, Optional, Literal
44

55
if TYPE_CHECKING:
6-
from ._typing import Device, ndarray
6+
from ._typing import Device, ndarray, DType
77
from collections.abc import Sequence
88

99
# Note: NumPy fft functions improperly upcast float32 and complex64 to
@@ -149,15 +149,37 @@ def ihfft(
149149
return res.astype(xp.complex64)
150150
return res
151151

152-
def fftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray:
152+
def fftfreq(
153+
n: int,
154+
/,
155+
xp,
156+
*,
157+
d: float = 1.0,
158+
dtype: Optional[DType] = None,
159+
device: Optional[Device] = None
160+
) -> ndarray:
153161
if device not in ["cpu", None]:
154162
raise ValueError(f"Unsupported device {device!r}")
155-
return xp.fft.fftfreq(n, d=d)
163+
res = xp.fft.fftfreq(n, d=d)
164+
if dtype is not None:
165+
return res.astype(dtype)
166+
return res
156167

157-
def rfftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray:
168+
def rfftfreq(
169+
n: int,
170+
/,
171+
xp,
172+
*,
173+
d: float = 1.0,
174+
dtype: Optional[DType] = None,
175+
device: Optional[Device] = None
176+
) -> ndarray:
158177
if device not in ["cpu", None]:
159178
raise ValueError(f"Unsupported device {device!r}")
160-
return xp.fft.rfftfreq(n, d=d)
179+
res = xp.fft.rfftfreq(n, d=d)
180+
if dtype is not None:
181+
return res.astype(dtype)
182+
return res
161183

162184
def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
163185
return xp.fft.fftshift(x, axes=axes)

array_api_compat/common/_typing.py

+1
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@ def __len__(self, /) -> int: ...
2121

2222
Array = Any
2323
Device = Any
24+
DType = Any

0 commit comments

Comments
 (0)