Skip to content

Commit

Permalink
ENH: add dtype kwarg to fft.{fftfreq, rfftfreq}
Browse files Browse the repository at this point in the history
  • Loading branch information
ev-br committed Feb 1, 2025
1 parent 1a288de commit fe9760e
Showing 1 changed file with 31 additions and 5 deletions.
36 changes: 31 additions & 5 deletions array_api_strict/_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

if TYPE_CHECKING:
from typing import Union, Optional, Literal
from ._typing import Device
from ._typing import Device, Dtype as DType
from collections.abc import Sequence

from ._dtypes import (
Expand Down Expand Up @@ -251,26 +251,52 @@ def ihfft(
return res

@requires_extension('fft')
def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Array:
def fftfreq(
n: int,
/,
*,
d: float = 1.0,
dtype: Optional[DType] = None,
device: Optional[Device] = None
) -> Array:
"""
Array API compatible wrapper for :py:func:`np.fft.fftfreq <numpy.fft.fftfreq>`.
See its docstring for more information.
"""
if device is not None and device not in ALL_DEVICES:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.fft.fftfreq(n, d=d), device=device)
if dtype and not dtype in _real_floating_dtypes:
raise ValueError(f"`dtype` must be a real floating-point type. Got {dtype=}.")

np_result = np.fft.fftfreq(n, d=d)
if dtype:
np_result = np_result.astype(dtype._np_dtype)
return Array._new(np_result, device=device)

@requires_extension('fft')
def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Array:
def rfftfreq(
n: int,
/,
*,
d: float = 1.0,
dtype: Optional[DType] = None,
device: Optional[Device] = None
) -> Array:
"""
Array API compatible wrapper for :py:func:`np.fft.rfftfreq <numpy.fft.rfftfreq>`.
See its docstring for more information.
"""
if device is not None and device not in ALL_DEVICES:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.fft.rfftfreq(n, d=d), device=device)
if dtype and not dtype in _real_floating_dtypes:
raise ValueError(f"`dtype` must be a real floating-point type. Got {dtype=}.")

np_result = np.fft.rfftfreq(n, d=d)
if dtype:
np_result = np_result.astype(dtype._np_dtype)
return Array._new(np_result, device=device)

@requires_extension('fft')
def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array:
Expand Down

0 comments on commit fe9760e

Please sign in to comment.