@@ -251,26 +251,52 @@ def ihfft(
251
251
return res
252
252
253
253
@requires_extension ('fft' )
254
- def fftfreq (n : int , / , * , d : float = 1.0 , device : Optional [Device ] = None ) -> Array :
254
+ def fftfreq (
255
+ n : int ,
256
+ / ,
257
+ * ,
258
+ d : float = 1.0 ,
259
+ dtype : Optional [dtype ] = None ,
260
+ device : Optional [Device ] = None
261
+ ) -> Array :
255
262
"""
256
263
Array API compatible wrapper for :py:func:`np.fft.fftfreq <numpy.fft.fftfreq>`.
257
264
258
265
See its docstring for more information.
259
266
"""
260
267
if device is not None and device not in ALL_DEVICES :
261
268
raise ValueError (f"Unsupported device { device !r} " )
262
- return Array ._new (np .fft .fftfreq (n , d = d ), device = device )
269
+ if dtype and not dtype in _real_floating_dtypes :
270
+ raise ValueError (f"`dtype` must be a real floating-point type. Got { dtype = } ." )
271
+
272
+ np_result = np .fft .fftfreq (n , d = d )
273
+ if dtype :
274
+ np_result = np_result .astype (dtype ._np_dtype )
275
+ return Array ._new (np_result , device = device )
263
276
264
277
@requires_extension ('fft' )
265
- def rfftfreq (n : int , / , * , d : float = 1.0 , device : Optional [Device ] = None ) -> Array :
278
+ def rfftfreq (
279
+ n : int ,
280
+ / ,
281
+ * ,
282
+ d : float = 1.0 ,
283
+ dtype : Optional [dtype ] = None ,
284
+ device : Optional [Device ] = None
285
+ ) -> Array :
266
286
"""
267
287
Array API compatible wrapper for :py:func:`np.fft.rfftfreq <numpy.fft.rfftfreq>`.
268
288
269
289
See its docstring for more information.
270
290
"""
271
291
if device is not None and device not in ALL_DEVICES :
272
292
raise ValueError (f"Unsupported device { device !r} " )
273
- return Array ._new (np .fft .rfftfreq (n , d = d ), device = device )
293
+ if dtype and not dtype in _real_floating_dtypes :
294
+ raise ValueError (f"`dtype` must be a real floating-point type. Got { dtype = } ." )
295
+
296
+ np_result = np .fft .rfftfreq (n , d = d )
297
+ if dtype :
298
+ np_result = np_result .astype (dtype ._np_dtype )
299
+ return Array ._new (np_result , device = device )
274
300
275
301
@requires_extension ('fft' )
276
302
def fftshift (x : Array , / , * , axes : Union [int , Sequence [int ]] = None ) -> Array :
0 commit comments