|
3 | 3 | from typing import TYPE_CHECKING, Union, Optional, Literal
|
4 | 4 |
|
5 | 5 | if TYPE_CHECKING:
|
6 |
| - from ._typing import Device, ndarray |
| 6 | + from ._typing import Device, ndarray, DType |
7 | 7 | from collections.abc import Sequence
|
8 | 8 |
|
9 | 9 | # Note: NumPy fft functions improperly upcast float32 and complex64 to
|
@@ -149,15 +149,37 @@ def ihfft(
|
149 | 149 | return res.astype(xp.complex64)
|
150 | 150 | return res
|
151 | 151 |
|
152 |
| -def fftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray: |
| 152 | +def fftfreq( |
| 153 | + n: int, |
| 154 | + /, |
| 155 | + xp, |
| 156 | + *, |
| 157 | + d: float = 1.0, |
| 158 | + dtype: Optional[DType] = None, |
| 159 | + device: Optional[Device] = None |
| 160 | +) -> ndarray: |
153 | 161 | if device not in ["cpu", None]:
|
154 | 162 | raise ValueError(f"Unsupported device {device!r}")
|
155 |
| - return xp.fft.fftfreq(n, d=d) |
| 163 | + res = xp.fft.fftfreq(n, d=d) |
| 164 | + if dtype is not None: |
| 165 | + return res.astype(dtype) |
| 166 | + return res |
156 | 167 |
|
157 |
| -def rfftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray: |
| 168 | +def rfftfreq( |
| 169 | + n: int, |
| 170 | + /, |
| 171 | + xp, |
| 172 | + *, |
| 173 | + d: float = 1.0, |
| 174 | + dtype: Optional[DType] = None, |
| 175 | + device: Optional[Device] = None |
| 176 | +) -> ndarray: |
158 | 177 | if device not in ["cpu", None]:
|
159 | 178 | raise ValueError(f"Unsupported device {device!r}")
|
160 |
| - return xp.fft.rfftfreq(n, d=d) |
| 179 | + res = xp.fft.rfftfreq(n, d=d) |
| 180 | + if dtype is not None: |
| 181 | + return res.astype(dtype) |
| 182 | + return res |
161 | 183 |
|
162 | 184 | def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
|
163 | 185 | return xp.fft.fftshift(x, axes=axes)
|
|
0 commit comments