@@ -223,8 +223,12 @@ def logcumsumexp(self, dim):
     return torch.empty_like(self).contiguous()


-# Stride-related code from _exec_fft in aten/src/ATen/native/cuda/SpectralOps.cpp
-def _exec_fft(out, self, out_sizes, dim, forward):
+# Stride-related code from _exec_fft in aten/src/ATen/native/mkl/SpectralOps.cpp
+# and aten/src/ATen/native/cuda/SpectralOps.cpp
+#
+# Although the actual FFT launch is different, all the permuting code appears
+# to be the same
+def _exec_fft(out, self, out_sizes, dim, *, forward):
     ndim = self.ndim
     signal_ndim = len(dim)
     batch_dims = ndim - signal_ndim
@@ -258,12 +262,12 @@ def _exec_fft(out, self, out_sizes, dim, forward):

     batch_size = input.size(0)
     batched_sizes[0] = batch_size
-    batched_out_sizes = batched_sizes
+    batched_out_sizes = list(batched_sizes)
     for i in range(len(dim)):
         batched_out_sizes[i + 1] = out_sizes[dim[i]]
-    out = out.reshape(batched_out_sizes)
+    out.resize_(batched_out_sizes, memory_format=torch.contiguous_format)

-    # Reshaping to original batch shape and inverting the dimension permutation
+    # Inplace reshaping to original batch shape and inverting the dimension permutation
     out_strides = [0 for _ in range(ndim)]
     batch_numel = 1
     i = batch_dims - 1
@@ -273,44 +277,102 @@ def _exec_fft(out, self, out_sizes, dim, forward):
         i -= 1
     for i in range(batch_dims, ndim):
         out_strides[dim_permute[i]] = out.stride(1 + (i - batch_dims))
-    return out.as_strided(out_sizes, out_strides, out.storage_offset())
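+    # View the contiguous batched output in place with the original sizes and
+    # the permutation-inverting strides computed above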
+    out.as_strided_(out_sizes, out_strides, out.storage_offset())
+
+    return out
+
+
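+# Helper: return `dim` ordered by increasing stride of `self`; with
+# exclude_last=True the final entry keeps its position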
+def _sort_dims(self: Tensor, dim: list[int], exclude_last: bool = False):
+    sorted_dims = list(dim)
+    self_strides = self.stride()
+    num_sorted = len(sorted_dims) - int(exclude_last)
+    sorted_dims[:num_sorted] = sorted(
+        sorted_dims[:num_sorted], key=lambda i: self_strides[i]
+    )
+    return sorted_dims


 # See _fft_c2c_cufft in aten/src/ATen/native/cuda/SpectralOps.cpp
 # and _fft_c2c_mkl in aten/src/ATen/native/mkl/SpectralOps.cpp
 @register_meta([aten._fft_c2c.default, aten._fft_c2c.out])
 @out_wrapper()
 def meta_fft_c2c(self, dim, normalization, forward):
-    assert self.dtype.is_complex
+    torch._check(self.dtype.is_complex)
+    if not dim:
+        return self.clone()

-    out_sizes = self.shape
-    output = self.new_empty(out_sizes)
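+    # C2C keeps the input shape; _exec_fft only lays out the output strides,
+    # following the transform dims sorted by input stride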
+    sorted_dims = _sort_dims(self, dim)
+    out = self.new_empty(self.size())
+    return _exec_fft(out, self, self.size(), sorted_dims, forward=forward)

-    if not dim:
-        return output

-    sorted_dims = dim[:]
-    self_strides = self.stride()
-    sorted_dims.sort(key=lambda x: self_strides[x], reverse=True)
-    output = _exec_fft(output, self, out_sizes, sorted_dims, forward)
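+# cuFFT plans cover at most three transform dimensions (1D, 2D or 3D)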
+cufft_max_ndim = 3

-    return output
+
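+# Mirrors the cuFFT path checks in aten/src/ATen/native/cuda/SpectralOps.cpp:
+# a single cuFFT plan covers at most cufft_max_ndim dims, and the optimized
+# path is also skipped when dim starts with (0, 1)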
+def use_optimized_cufft_path(dim: list[int]):
+    if len(dim) > cufft_max_ndim or (len(dim) >= 2 and dim[0] == 0 and dim[1] == 1):
+        return False
+    else:
+        return True


 @register_meta([aten._fft_r2c.default, aten._fft_r2c.out])
 @out_wrapper()
 def meta_fft_r2c(self, dim, normalization, onesided):
-    assert self.dtype.is_floating_point
-    output_sizes = list(self.size())
+    torch._check(self.dtype.is_floating_point)
+    input_sizes = list(self.size())
+    out_sizes = list(input_sizes)
+    last_dim = dim[-1]
+    last_dim_halfsize = input_sizes[last_dim] // 2 + 1
+    onesided_sizes = list(input_sizes)
+    onesided_sizes[last_dim] = last_dim_halfsize

     if onesided:
-        last_dim = dim[-1]
-        last_dim_halfsize = (output_sizes[last_dim] // 2) + 1
-        output_sizes[last_dim] = last_dim_halfsize
+        out_sizes[last_dim] = last_dim_halfsize

-    return self.new_empty(
-        output_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
-    )
+    if device_hint(self) == "cuda":
+        # _fft_r2c_cufft in aten/src/ATen/native/cuda/SpectralOps.cpp
+        output = self.new_empty(
+            out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
+        )
+
+        working_tensor = self
+        if use_optimized_cufft_path(dim):
+            _exec_fft(output, working_tensor, out_sizes, dim, forward=True)
+        else:
+            # First do the R2C transform on the last dimension
+            target_sizes = out_sizes if len(dim) == 1 else onesided_sizes
+            _exec_fft(output, working_tensor, target_sizes, [last_dim], forward=True)
+            if len(dim) > 1:
+                working_tensor = self.new_empty(
+                    out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
+                )
+
+            # Then any remaining C2C transforms
+            sorted_dims = dim[:-1]
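+            # Process at most cufft_max_ndim dims per pass, ping-ponging between
+            # `output` and `working_tensor`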
+            while sorted_dims:
+                output, working_tensor = working_tensor, output
+                strides = working_tensor.stride()
+                sorted_dims.sort(
+                    key=lambda i: strides[i], reverse=True
+                )  # NB reverse! Not sure if this is a bug in the original implementation
+                max_dims = min(cufft_max_ndim, len(sorted_dims))
+                last_dims = sorted_dims[len(sorted_dims) - max_dims :]
+                _exec_fft(
+                    output, working_tensor, onesided_sizes, last_dims, forward=True
+                )
+                sorted_dims = sorted_dims[: len(sorted_dims) - max_dims]
+
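+        # For a two-sided result, swap in a full-size buffer if the intermediate
+        # only holds the one-sided half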
+        if not onesided:
+            if output.size(last_dim) != out_sizes[last_dim]:
+                working_tensor.resize_(out_sizes, memory_format=torch.contiguous_format)
+                output = working_tensor
+
+        return output
+
+    else:
+        return self.new_empty(
+            out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
+        )


 @register_meta(aten.randperm.generator_out)
@@ -375,11 +437,43 @@ def meta_rand_default(size, *, dtype=None, layout=None, device=None, pin_memory=

 @register_meta([aten._fft_c2r.default, aten._fft_c2r.out])
 @out_wrapper()
-def meta_fft_c2r(self, dim, normalization, lastdim):
-    assert self.dtype.is_complex
-    output_sizes = list(self.size())
-    output_sizes[dim[-1]] = lastdim
-    return self.new_empty(output_sizes, dtype=toRealValueType(self.dtype))
+def meta_fft_c2r(self: Tensor, dim: list[int], normalization: int, lastdim: int):
+    # _fft_c2r_mkl
+    torch._check(self.dtype.is_complex)
+
+    if device_hint(self) == "cuda":
+        out_sizes = list(self.size())
+        out_sizes[dim[-1]] = lastdim
+
+        output = self.new_empty(out_sizes, dtype=toRealValueType(self.dtype))
+
+        if use_optimized_cufft_path(dim):
+            return _exec_fft(
+                output,
+                self.clone(memory_format=torch.contiguous_format),
+                out_sizes,
+                dim,
+                forward=False,
+            )
+        else:
+            # First complete any C2C transforms
+            if len(dim) > 1:
+                temp = meta_fft_c2c(self, dim[:-1], 0, forward=False)  # fft_norm_mode::none
+            else:
+                temp = self.clone(memory_format=torch.contiguous_format)
+            return _exec_fft(output, temp, out_sizes, [dim[-1]], forward=False)
+
+    else:
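+        # MKL-style path: any C2C transforms over dim[:-1] run first, then a
+        # single C2R transform over the last dimension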
+        input = self
+        if len(dim) > 1:
+            c2c_dims = dim[:-1]
+            input = meta_fft_c2c(self, c2c_dims, normalization, forward=False)
+            dim = dim[-1:]
+
+        out_sizes = list(input.size())
+        out_sizes[dim[-1]] = lastdim
+        out = self.new_empty(out_sizes, dtype=toRealValueType(self.dtype))
+        return _exec_fft(out, input, out_sizes, dim, forward=False)


 @register_meta(aten.copy_.default)