Skip to content

Commit 87fdadd

Browse files
ezyang authored and pytorchmergebot committed
Remove FFT from stride incorrect ops (pytorch#145080)
I gotta say, the FFT implementation is completely insane, there's gotta be a better way to do this than repeatedly inplace restriding the output tensor. Anyway, this is a faithful translation of both the MKL and cuFFT paths to Python.

Fixes pytorch#135087

Signed-off-by: Edward Z. Yang <[email protected]>

Pull Request resolved: pytorch#145080
Approved by: https://github.com/Skylion007, https://github.com/albanD
ghstack dependencies: pytorch#145530
1 parent b75afa2 commit 87fdadd

File tree

4 files changed

+123
-74
lines changed

4 files changed

+123
-74
lines changed

test/functorch/test_aotdispatch.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6509,21 +6509,6 @@ def _test_fn(fn, check_backward=True):
65096509
"linalg.householder_product",
65106510
decorator=unittest.skipIf(IS_MACOS and IS_X86, "flaky"),
65116511
),
6512-
# many complex operators incorrect striding, metadata
6513-
xfail("fft.fft", ""),
6514-
xfail("fft.hfft2", ""),
6515-
xfail("fft.hfft", ""),
6516-
xfail("fft.hfftn", ""),
6517-
xfail("fft.ifft", ""),
6518-
xfail("fft.ihfft2", ""),
6519-
xfail("fft.ihfft", ""),
6520-
xfail("fft.ihfftn", ""),
6521-
xfail("fft.irfft2", ""),
6522-
xfail("fft.irfft", ""),
6523-
xfail("fft.irfftn", ""),
6524-
xfail("fft.rfft2", ""),
6525-
xfail("fft.rfft", ""),
6526-
xfail("fft.rfftn", ""),
65276512
xfail("stft", ""), # Cannot call sizes() on tensor with symbolic sizes/strides
65286513
}
65296514

test/test_proxy_tensor.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2014,24 +2014,6 @@ def f(t):
20142014
xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
20152015

20162016
xfail('max_pool2d_with_indices_backward', ''), # Expected a value of type 'List[int]' for argument 'kernel_size' but...
2017-
2018-
# many complex operators incorrect striding, metadata
2019-
xfail('fft.fft', ''),
2020-
xfail('fft.hfft2', ''),
2021-
xfail('fft.hfft', ''),
2022-
xfail('fft.hfftn', ''),
2023-
xfail('fft.ifft', ''),
2024-
xfail('fft.ihfft2', ''),
2025-
xfail('fft.ihfft', ''),
2026-
xfail('fft.ihfftn', ''),
2027-
xfail('fft.ihfft2', ''),
2028-
xfail('fft.irfft2', ''),
2029-
xfail('fft.irfft', ''),
2030-
xfail('fft.irfftn', ''),
2031-
xfail('fft.rfft2', ''),
2032-
xfail('fft.rfft', ''),
2033-
xfail('fft.rfftn', ''),
2034-
xfail('stft', '')
20352017
}
20362018
symbolic_tensor_segfaults = {
20372019
skip('nn.functional.batch_norm') # Segfault??
@@ -2058,10 +2040,6 @@ def f(t):
20582040
xfail('angle', ''),
20592041
xfail('argmax', ''),
20602042
xfail('argmin', ''),
2061-
xfail('fft.fft2', ''),
2062-
xfail('fft.fftn', ''),
2063-
xfail('fft.ifft2', ''),
2064-
xfail('fft.ifftn', ''),
20652043
xfail('gather', ''),
20662044
xfail('linalg.pinv', ''),
20672045
xfail('linalg.pinv', 'hermitian'),

torch/_meta_registrations.py

Lines changed: 123 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,12 @@ def logcumsumexp(self, dim):
223223
return torch.empty_like(self).contiguous()
224224

225225

226-
# Stride-related code from _exec_fft in aten/src/ATen/native/cuda/SpectralOps.cpp
227-
def _exec_fft(out, self, out_sizes, dim, forward):
226+
# Stride-related code from _exec_fft in aten/src/ATen/native/mkl/SpectralOps.cpp
227+
# and aten/src/ATen/cuda/SpectralOps.cpp
228+
#
229+
# Although the actual FFT launch is different, all the permuting code appears
230+
# to be the same
231+
def _exec_fft(out, self, out_sizes, dim, *, forward):
228232
ndim = self.ndim
229233
signal_ndim = len(dim)
230234
batch_dims = ndim - signal_ndim
@@ -258,12 +262,12 @@ def _exec_fft(out, self, out_sizes, dim, forward):
258262

259263
batch_size = input.size(0)
260264
batched_sizes[0] = batch_size
261-
batched_out_sizes = batched_sizes
265+
batched_out_sizes = list(batched_sizes)
262266
for i in range(len(dim)):
263267
batched_out_sizes[i + 1] = out_sizes[dim[i]]
264-
out = out.reshape(batched_out_sizes)
268+
out.resize_(batched_out_sizes, memory_format=torch.contiguous_format)
265269

266-
# Reshaping to original batch shape and inverting the dimension permutation
270+
# Inplace reshaping to original batch shape and inverting the dimension permutation
267271
out_strides = [0 for _ in range(ndim)]
268272
batch_numel = 1
269273
i = batch_dims - 1
@@ -273,44 +277,102 @@ def _exec_fft(out, self, out_sizes, dim, forward):
273277
i -= 1
274278
for i in range(batch_dims, ndim):
275279
out_strides[dim_permute[i]] = out.stride(1 + (i - batch_dims))
276-
return out.as_strided(out_sizes, out_strides, out.storage_offset())
280+
out.as_strided_(out_sizes, out_strides, out.storage_offset())
281+
282+
return out
283+
284+
285+
def _sort_dims(self: Tensor, dim: list[int], exclude_last: bool = False):
286+
sorted_dims = list(dim)
287+
self_strides = self.stride()
288+
sorted_dims[: len(sorted_dims) - int(exclude_last)].sort(
289+
key=lambda i: self_strides[i]
290+
)
291+
return sorted_dims
277292

278293

279294
# See _fft_c2c_cufft in aten/src/ATen/native/cuda/SpectralOps.cpp
280295
# and _fft_c2c_mkl in aten/src/ATen/native/mkl/SpectralOps.cpp
281296
@register_meta([aten._fft_c2c.default, aten._fft_c2c.out])
282297
@out_wrapper()
283298
def meta_fft_c2c(self, dim, normalization, forward):
284-
assert self.dtype.is_complex
299+
torch._check(self.dtype.is_complex)
300+
if not dim:
301+
return self.clone()
285302

286-
out_sizes = self.shape
287-
output = self.new_empty(out_sizes)
303+
sorted_dims = _sort_dims(self, dim)
304+
out = self.new_empty(self.size())
305+
return _exec_fft(out, self, self.size(), sorted_dims, forward=forward)
288306

289-
if not dim:
290-
return output
291307

292-
sorted_dims = dim[:]
293-
self_strides = self.stride()
294-
sorted_dims.sort(key=lambda x: self_strides[x], reverse=True)
295-
output = _exec_fft(output, self, out_sizes, sorted_dims, forward)
308+
cufft_max_ndim = 3
296309

297-
return output
310+
311+
def use_optimized_cufft_path(dim: list[int]):
312+
if len(dim) > cufft_max_ndim or (len(dim) >= 2 and dim[0] == 0 and dim[1] == 1):
313+
return False
314+
else:
315+
return True
298316

299317

300318
@register_meta([aten._fft_r2c.default, aten._fft_r2c.out])
301319
@out_wrapper()
302320
def meta_fft_r2c(self, dim, normalization, onesided):
303-
assert self.dtype.is_floating_point
304-
output_sizes = list(self.size())
321+
torch._check(self.dtype.is_floating_point)
322+
input_sizes = list(self.size())
323+
out_sizes = list(input_sizes)
324+
last_dim = dim[-1]
325+
last_dim_halfsize = input_sizes[last_dim] // 2 + 1
326+
onesided_sizes = list(input_sizes)
327+
onesided_sizes[last_dim] = last_dim_halfsize
305328

306329
if onesided:
307-
last_dim = dim[-1]
308-
last_dim_halfsize = (output_sizes[last_dim] // 2) + 1
309-
output_sizes[last_dim] = last_dim_halfsize
330+
out_sizes[last_dim] = last_dim_halfsize
310331

311-
return self.new_empty(
312-
output_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
313-
)
332+
if device_hint(self) == "cuda":
333+
# _fft_r2c_cufft in aten/src/ATen/native/cuda/SpectralOps.cpp
334+
output = self.new_empty(
335+
out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
336+
)
337+
338+
working_tensor = self
339+
if use_optimized_cufft_path(dim):
340+
_exec_fft(output, working_tensor, out_sizes, dim, forward=True)
341+
else:
342+
# First do the R2C transform on the last dimension
343+
target_sizes = out_sizes if len(dim) == 1 else onesided_sizes
344+
_exec_fft(output, working_tensor, target_sizes, [last_dim], forward=True)
345+
if len(dim) > 1:
346+
working_tensor = self.new_empty(
347+
out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
348+
)
349+
350+
# Then any remaining C2C transforms
351+
sorted_dims = dim[:-1]
352+
while sorted_dims:
353+
output, working_tensor = working_tensor, output
354+
strides = working_tensor.stride()
355+
sorted_dims.sort(
356+
key=lambda i: strides[i], reverse=True
357+
) # NB reverse! Not sure if this is og bug
358+
max_dims = min(cufft_max_ndim, len(sorted_dims))
359+
last_dims = sorted_dims[len(sorted_dims) - max_dims :]
360+
_exec_fft(
361+
output, working_tensor, onesided_sizes, last_dims, forward=True
362+
)
363+
sorted_dims = sorted_dims[: len(sorted_dims) - max_dims]
364+
365+
if not onesided:
366+
if output.size(last_dim) != out_sizes[last_dim]:
367+
working_tensor.resize_(out_sizes, memory_format=torch.contiguous_format)
368+
output = working_tensor
369+
370+
return output
371+
372+
else:
373+
return self.new_empty(
374+
out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
375+
)
314376

315377

316378
@register_meta(aten.randperm.generator_out)
@@ -375,11 +437,43 @@ def meta_rand_default(size, *, dtype=None, layout=None, device=None, pin_memory=
375437

376438
@register_meta([aten._fft_c2r.default, aten._fft_c2r.out])
377439
@out_wrapper()
378-
def meta_fft_c2r(self, dim, normalization, lastdim):
379-
assert self.dtype.is_complex
380-
output_sizes = list(self.size())
381-
output_sizes[dim[-1]] = lastdim
382-
return self.new_empty(output_sizes, dtype=toRealValueType(self.dtype))
440+
def meta_fft_c2r(self: Tensor, dim: list[int], normalization: int, lastdim: int):
441+
# _fft_c2r_mkl
442+
torch._check(self.dtype.is_complex)
443+
444+
if device_hint(self) == "cuda":
445+
out_sizes = list(self.size())
446+
out_sizes[dim[-1]] = lastdim
447+
448+
output = self.new_empty(out_sizes, dtype=toRealValueType(self.dtype))
449+
450+
if use_optimized_cufft_path(dim):
451+
return _exec_fft(
452+
output,
453+
self.clone(memory_format=torch.contiguous_format),
454+
out_sizes,
455+
dim,
456+
forward=False,
457+
)
458+
else:
459+
# First complete any C2C transforms
460+
if len(dim) > 1:
461+
temp = meta_fft_c2c(self, dim[:-1], 0, lastdim) # fft_norm_mode::none
462+
else:
463+
temp = self.clone(memory_format=torch.contiguous_format)
464+
return _exec_fft(output, temp, out_sizes, [dim[-1]], forward=False)
465+
466+
else:
467+
input = self
468+
if len(dim) > 1:
469+
c2c_dims = dim[:-1]
470+
input = meta_fft_c2c(self, c2c_dims, normalization, forward=False)
471+
dim = dim[-1:]
472+
473+
out_sizes = list(input.size())
474+
out_sizes[dim[-1]] = lastdim
475+
out = self.new_empty(out_sizes, dtype=toRealValueType(self.dtype))
476+
return _exec_fft(out, input, out_sizes, dim, forward=False)
383477

384478

385479
@register_meta(aten.copy_.default)

torch/_subclasses/fake_impls.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -222,14 +222,6 @@ def non_kwarg_to(fake_mode, func, *args, **kwargs):
222222

223223

224224
def stride_incorrect_op(op):
225-
if op.namespace not in ("aten", "prims"):
226-
return False
227-
if op is aten._fft_c2c.default:
228-
return False
229-
230-
op_name = op.name()
231-
if "fft" in op_name:
232-
return True
233225
return False
234226

235227

0 commit comments

Comments (0)