@@ -223,8 +223,12 @@ def logcumsumexp(self, dim):
223223 return torch .empty_like (self ).contiguous ()
224224
225225
226- # Stride-related code from _exec_fft in aten/src/ATen/native/cuda/SpectralOps.cpp
227- def _exec_fft (out , self , out_sizes , dim , forward ):
226+ # Stride-related code from _exec_fft in aten/src/ATen/native/mkl/SpectralOps.cpp
227+ # and aten/src/ATen/native/cuda/SpectralOps.cpp
228+ #
229+ # Although the actual FFT launch is different, all the permuting code appears
230+ # to be the same
231+ def _exec_fft (out , self , out_sizes , dim , * , forward ):
228232 ndim = self .ndim
229233 signal_ndim = len (dim )
230234 batch_dims = ndim - signal_ndim
@@ -258,12 +262,12 @@ def _exec_fft(out, self, out_sizes, dim, forward):
258262
259263 batch_size = input .size (0 )
260264 batched_sizes [0 ] = batch_size
261- batched_out_sizes = batched_sizes
265+ batched_out_sizes = list ( batched_sizes )
262266 for i in range (len (dim )):
263267 batched_out_sizes [i + 1 ] = out_sizes [dim [i ]]
264- out = out . reshape (batched_out_sizes )
268+ out . resize_ (batched_out_sizes , memory_format = torch . contiguous_format )
265269
266- # Reshaping to original batch shape and inverting the dimension permutation
270+ # Inplace reshaping to original batch shape and inverting the dimension permutation
267271 out_strides = [0 for _ in range (ndim )]
268272 batch_numel = 1
269273 i = batch_dims - 1
@@ -273,44 +277,102 @@ def _exec_fft(out, self, out_sizes, dim, forward):
273277 i -= 1
274278 for i in range (batch_dims , ndim ):
275279 out_strides [dim_permute [i ]] = out .stride (1 + (i - batch_dims ))
276- return out .as_strided (out_sizes , out_strides , out .storage_offset ())
280+ out .as_strided_ (out_sizes , out_strides , out .storage_offset ())
281+
282+ return out
283+
284+
285+ def _sort_dims (self : Tensor , dim : list [int ], exclude_last : bool = False ):
286+ sorted_dims = list (dim )
287+ self_strides = self .stride ()
288+ sorted_dims [: len (sorted_dims ) - int (exclude_last )].sort (
289+ key = lambda i : self_strides [i ]
290+ )
291+ return sorted_dims
277292
278293
# See _fft_c2c_cufft in aten/src/ATen/native/cuda/SpectralOps.cpp
# and _fft_c2c_mkl in aten/src/ATen/native/mkl/SpectralOps.cpp
@register_meta([aten._fft_c2c.default, aten._fft_c2c.out])
@out_wrapper()
def meta_fft_c2c(self, dim, normalization, forward):
    """Meta implementation of the complex-to-complex FFT.

    Args:
        self: Input tensor; its dtype must be complex.
        dim: Dimensions to transform. When empty, a clone of ``self`` is
            returned.
        normalization: Accepted for signature compatibility; not read here.
        forward: Passed through to ``_exec_fft``.

    Returns:
        A tensor with the same sizes as ``self``. For a non-empty ``dim`` it
        is a fresh tensor handed to ``_exec_fft``, which sets its layout.
    """
    torch._check(self.dtype.is_complex)

    if dim:
        transform_dims = _sort_dims(self, dim)
        shape = self.size()
        result = self.new_empty(shape)
        return _exec_fft(result, self, shape, transform_dims, forward=forward)

    # Nothing to transform.
    return self.clone()
# Upper bound on how many transform dimensions are handed to a single
# _exec_fft call on the CUDA path.
cufft_max_ndim = 3


def use_optimized_cufft_path(dim: list[int]) -> bool:
    """Report whether ``dim`` can be transformed in one optimized pass.

    Returns ``False`` when ``dim`` has more than ``cufft_max_ndim`` entries,
    or when it has at least two entries and starts with dimensions 0 and 1.
    Returns ``True`` otherwise, including for an empty ``dim``.
    """
    if len(dim) > cufft_max_ndim:
        return False
    starts_with_leading_dims = len(dim) >= 2 and dim[0] == 0 and dim[1] == 1
    return not starts_with_leading_dims
298316
299317
@register_meta([aten._fft_r2c.default, aten._fft_r2c.out])
@out_wrapper()
def meta_fft_r2c(self, dim, normalization, onesided):
    """Meta implementation of the real-to-complex FFT.

    Args:
        self: Input tensor; its dtype must be floating point.
        dim: Dimensions to transform; ``dim[-1]`` is the dimension that is
            halved for one-sided output.
        normalization: Accepted for signature compatibility; not read here.
        onesided: When true, the output size along ``dim[-1]`` is
            ``size // 2 + 1``; otherwise the output keeps the input sizes.

    Returns:
        An empty tensor of the corresponding complex dtype. On the "cuda"
        device hint the tensor is additionally run through ``_exec_fft`` so
        its layout follows that helper; on other devices it comes straight
        from ``new_empty``.
    """
    torch._check(self.dtype.is_floating_point)

    input_sizes = list(self.size())
    last_dim = dim[-1]
    last_dim_halfsize = input_sizes[last_dim] // 2 + 1
    onesided_sizes = list(input_sizes)
    onesided_sizes[last_dim] = last_dim_halfsize
    # Keep out_sizes a separate list from onesided_sizes in both cases.
    out_sizes = list(onesided_sizes) if onesided else list(input_sizes)

    complex_dtype = utils.corresponding_complex_dtype(self.dtype)

    if device_hint(self) != "cuda":
        return self.new_empty(out_sizes, dtype=complex_dtype)

    # _fft_r2c_cufft in aten/src/ATen/native/cuda/SpectralOps.cpp
    output = self.new_empty(out_sizes, dtype=complex_dtype)
    working_tensor = self

    if use_optimized_cufft_path(dim):
        _exec_fft(output, working_tensor, out_sizes, dim, forward=True)
    else:
        # Step 1: R2C transform over the last dimension only.
        target_sizes = out_sizes if len(dim) == 1 else onesided_sizes
        _exec_fft(output, working_tensor, target_sizes, [last_dim], forward=True)
        if len(dim) > 1:
            working_tensor = self.new_empty(out_sizes, dtype=complex_dtype)

        # Step 2: C2C transforms over the remaining dimensions, taken in
        # groups of at most cufft_max_ndim from the end of the sorted list.
        remaining = dim[:-1]
        while remaining:
            # Ping-pong: the previous result becomes the next input.
            output, working_tensor = working_tensor, output
            strides = working_tensor.stride()
            # NOTE: sorted by *descending* stride (reverse=True). It is not
            # clear whether this ordering reproduces an upstream quirk.
            remaining.sort(key=lambda i: strides[i], reverse=True)
            split = len(remaining) - min(cufft_max_ndim, len(remaining))
            batch = remaining[split:]
            _exec_fft(output, working_tensor, onesided_sizes, batch, forward=True)
            remaining = remaining[:split]

    # For two-sided output, switch to a full-size buffer if the result is
    # still one-sided along last_dim.
    if not onesided and output.size(last_dim) != out_sizes[last_dim]:
        working_tensor.resize_(out_sizes, memory_format=torch.contiguous_format)
        output = working_tensor

    return output
314376
315377
316378@register_meta (aten .randperm .generator_out )
@@ -375,11 +437,43 @@ def meta_rand_default(size, *, dtype=None, layout=None, device=None, pin_memory=
375437
@register_meta([aten._fft_c2r.default, aten._fft_c2r.out])
@out_wrapper()
def meta_fft_c2r(self: Tensor, dim: list[int], normalization: int, lastdim: int):
    """Meta implementation of the complex-to-real FFT.

    Args:
        self: Input tensor; its dtype must be complex.
        dim: Dimensions to transform; ``dim[-1]`` is the one whose output size
            is set to ``lastdim``.
        normalization: Normalization mode. It is forwarded to the intermediate
            C2C step on the non-CUDA path only.
        lastdim: Output size along ``dim[-1]``.

    Returns:
        A real-valued tensor with the sizes of ``self`` except along
        ``dim[-1]``, with its layout set by ``_exec_fft``.
    """
    torch._check(self.dtype.is_complex)

    if device_hint(self) == "cuda":
        out_sizes = list(self.size())
        out_sizes[dim[-1]] = lastdim

        output = self.new_empty(out_sizes, dtype=toRealValueType(self.dtype))

        if use_optimized_cufft_path(dim):
            return _exec_fft(
                output,
                self.clone(memory_format=torch.contiguous_format),
                out_sizes,
                dim,
                forward=False,
            )

        # First complete any C2C transforms
        if len(dim) > 1:
            # The fourth parameter of meta_fft_c2c is `forward`; this is an
            # inverse transform, so pass False rather than `lastdim`.
            temp = meta_fft_c2c(
                self, dim[:-1], 0, forward=False  # 0 == fft_norm_mode::none
            )
        else:
            temp = self.clone(memory_format=torch.contiguous_format)
        return _exec_fft(output, temp, out_sizes, [dim[-1]], forward=False)

    # Non-CUDA path; follows _fft_c2r_mkl.
    c2r_input = self
    if len(dim) > 1:
        # Run every dimension except the last through the C2C meta first.
        c2c_dims = dim[:-1]
        c2r_input = meta_fft_c2c(self, c2c_dims, normalization, forward=False)
        dim = dim[-1:]

    out_sizes = list(c2r_input.size())
    out_sizes[dim[-1]] = lastdim
    out = self.new_empty(out_sizes, dtype=toRealValueType(self.dtype))
    return _exec_fft(out, c2r_input, out_sizes, dim, forward=False)
383477
384478
385479@register_meta (aten .copy_ .default )
0 commit comments