
Commit ddfcfcd

pssrawat authored and facebook-github-bot committed

Implement _fft_c2r core ATen op (#10208)

Summary: Pull Request resolved: #10208. Add fft_c2r. Differential Revision: D73006888

1 parent 900b42c · commit ddfcfcd
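For context: _fft_c2r is the core ATen op behind inverse real FFTs such as torch.fft.irfft. It consumes a complex tensor holding a Hermitian half-spectrum (the last transformed dimension stores only last_dim_size / 2 + 1 bins) and writes a real tensor whose last transformed dimension has length last_dim_size. The core ATen schema, for reference, is _fft_c2r.out(Tensor self, int[] dim, int normalization, SymInt last_dim_size, *, Tensor(a!) out) -> Tensor(a!).

The kernel added here delegates the transform to pocketfft. A minimal standalone sketch (not part of this commit; it assumes only that pocketfft_hdronly.h is on the include path) showing the half-spectrum convention and the normalization factor the kernel computes:

#include <pocketfft_hdronly.h>

#include <complex>
#include <cstdio>
#include <vector>

int main() {
  const size_t n = 8; // real signal length, i.e. last_dim_size
  std::vector<double> signal = {1, 2, 3, 4, 4, 3, 2, 1};
  // Hermitian symmetry: only n/2 + 1 complex bins are stored.
  std::vector<std::complex<double>> spectrum(n / 2 + 1);

  pocketfft::shape_t shape{n}; // shape of the *real* side
  pocketfft::stride_t rstride{sizeof(double)};
  pocketfft::stride_t cstride{sizeof(std::complex<double>)};
  pocketfft::shape_t axes{0};

  // Forward real-to-complex transform (what _fft_r2c wraps).
  pocketfft::r2c(shape, rstride, cstride, axes, true /* forward */,
      signal.data(), spectrum.data(), 1.0);

  // Inverse complex-to-real (what this kernel wraps); fct = 1/n is the
  // by_n normalization mode and makes the round trip exact.
  std::vector<double> recovered(n);
  pocketfft::c2r(shape, cstride, rstride, axes, false /* forward */,
      spectrum.data(), recovered.data(), 1.0 / n);

  for (double v : recovered) {
    std::printf("%g ", v); // prints 1 2 3 4 4 3 2 1
  }
  std::printf("\n");
  return 0;
}

Compiled with any C++17 compiler, this prints the original signal back, because the 1/n factor on the inverse undoes the unnormalized forward transform.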

File tree

7 files changed: +385 −0 lines changed

kernels/aten/functions.yaml (+2)

@@ -6,6 +6,8 @@
 
 - op: _fake_quantize_per_tensor_affine_cachemask_tensor_qparams.out
 
+- op: _fft_c2r.out
+
 - op: _fft_r2c.out
 
 - op: _linalg_det.result

kernels/optimized/cpu/op_fft_c2r.cpp (+181, new file)

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/runtime/core/span.h>
#include <executorch/runtime/kernel/kernel_includes.h>

#include <pocketfft_hdronly.h>

#include <optional>

namespace torch::executor::native {

// TODO: contents of this anonymous namespace are copy/pasted from
// PyTorch core (aten/src/ATen/native/mkl/SpectralOps.cpp). Small
// portions (the parts that don't depend on Tensor) could be reused;
// refactor to enable that once we can share headers from PyTorch
// core.
namespace {
pocketfft::stride_t stride_from_tensor(const Tensor& t) {
  pocketfft::stride_t stride(t.strides().begin(), t.strides().end());
  for (auto& s : stride) {
    s *= t.element_size();
  }
  return stride;
}

pocketfft::shape_t shape_from_tensor(const Tensor& t) {
  return pocketfft::shape_t(t.sizes().begin(), t.sizes().end());
}

// NOTE: The reinterpret_cast in tensor_cdata is UB, but it's what
// PyTorch core does and I'm not aware of a portable way to do this
// that doesn't rely on UB.
template <typename T>
inline std::complex<T>* tensor_cdata(Tensor& t) {
  return reinterpret_cast<std::complex<T>*>(
      t.data_ptr<executorch::runtime::etensor::complex<T>>());
}

template <typename T>
inline const std::complex<T>* tensor_cdata(const Tensor& t) {
  return reinterpret_cast<const std::complex<T>*>(
      t.const_data_ptr<executorch::runtime::etensor::complex<T>>());
}

// NOTE: in particular this is in ATen/native/SpectralOpsUtils.h and
// could be shared immediately.
enum class fft_norm_mode {
  none, // No normalization
  by_root_n, // Divide by sqrt(signal_size)
  by_n, // Divide by signal_size
};

// NOTE: slight fork from upstream PyTorch to use ET_KERNEL_CHECK;
// upstream with TORCH_CHECK will be fine to use once we have code
// sharing.
template <typename T>
std::optional<T>
compute_fct(KernelRuntimeContext& ctx, int64_t size, int64_t normalization) {
  constexpr auto one = static_cast<T>(1);
  switch (static_cast<fft_norm_mode>(normalization)) {
    case fft_norm_mode::none:
      return one;
    case fft_norm_mode::by_n:
      return one / static_cast<T>(size);
    case fft_norm_mode::by_root_n:
      return one / std::sqrt(static_cast<T>(size));
  }
  ET_KERNEL_CHECK_MSG(
      ctx,
      false,
      InvalidArgument,
      std::nullopt,
      "Unsupported normalization type: %" PRId64,
      normalization);
}

template <typename T>
std::optional<T> compute_fct(
    KernelRuntimeContext& ctx,
    const Tensor& t,
    IntArrayRef dim,
    int64_t normalization) {
  if (static_cast<fft_norm_mode>(normalization) == fft_norm_mode::none) {
    return static_cast<T>(1);
  }
  const auto& sizes = t.sizes();
  int64_t n = 1;
  for (auto idx : dim) {
    n *= sizes[idx];
  }
  return compute_fct<T>(ctx, n, normalization);
}

} // namespace

Tensor& opt_fft_c2r_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    IntArrayRef dim,
    int64_t normalization,
    int64_t last_dim_size,
    Tensor& out) {
  auto in_sizes = in.sizes();
  ET_KERNEL_CHECK(ctx, in.dim() <= kTensorDimensionLimit, InvalidArgument, out);

  ET_KERNEL_CHECK(ctx, !dim.empty(), InvalidArgument, out);
  ET_KERNEL_CHECK(ctx, last_dim_size >= 1, InvalidArgument, out);

  // Determine the output size
  std::array<Tensor::SizesType, kTensorDimensionLimit> out_sizes_storage{};
  executorch::runtime::Span<Tensor::SizesType> out_sizes(
      out_sizes_storage.data(), in_sizes.size());
  std::copy(in_sizes.begin(), in_sizes.end(), out_sizes.begin());
  out_sizes[dim.back()] = last_dim_size;

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  ET_KERNEL_CHECK_MSG(
      ctx,
      in.scalar_type() == executorch::runtime::toComplexType(out.scalar_type()),
      InvalidArgument,
      out,
      "the input type for _fft_c2r must be the Complex type corresponding to the output type");

  for (auto d : dim) {
    ET_KERNEL_CHECK_MSG(
        ctx,
        d >= 0 && d < in.dim(),
        InvalidArgument,
        out,
        "dims must be in bounds (got %" PRId64 ")",
        d);
  }

  ET_KERNEL_CHECK_MSG(
      ctx,
      resize_tensor(
          out,
          executorch::runtime::ArrayRef<Tensor::SizesType>(
              out_sizes.data(), out_sizes.size())) == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor (last dim %d).",
      out_sizes[dim.back()]);

  pocketfft::shape_t axes(dim.begin(), dim.end());
  auto out_shape = shape_from_tensor(out);
  // TODO: if arbitrary strides are a possibility, we need to validate
  // these, because pocketfft README says "Strides that lead to
  // multiple accesses of the same memory address are not allowed."
  auto in_stride = stride_from_tensor(in);
  auto out_stride = stride_from_tensor(out);
  // NOTE: as of this writing, upstream PyTorch only supports
  // float/double, so we follow suit.
  ET_SWITCH_FLOAT_TYPES(out.scalar_type(), ctx, "_fft_c2r.out", CTYPE_OUT, [&] {
    auto fct = compute_fct<CTYPE_OUT>(ctx, out, dim, normalization);
    if (!fct) {
      // Check failed, just bail out of the lambda.
      return;
    }
    pocketfft::c2r<CTYPE_OUT>(
        out_shape,
        in_stride,
        out_stride,
        axes,
        false /* forward */,
        tensor_cdata<CTYPE_OUT>(in),
        out.mutable_data_ptr<CTYPE_OUT>(),
        *fct);
  });
  return out;
}

} // namespace torch::executor::native
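The normalization argument arrives as an integer indexing fft_norm_mode above. If I read the upstream convention correctly, torch.fft.irfft's norm strings map onto these modes as "forward" → none, "ortho" → by_root_n, and "backward" (the default) → by_n, with n taken from the output (real) sizes over the transformed dims. A small worked sketch of the arithmetic, mirroring compute_fct (not from this commit; the enum ordering is taken from the code above, the norm-string mapping is an assumption based on upstream ATen):

#include <cmath>
#include <cstdio>

enum class fft_norm_mode { none, by_root_n, by_n };

double fct_for(fft_norm_mode mode, double n) {
  switch (mode) {
    case fft_norm_mode::none:
      return 1.0; // torch.fft.irfft(..., norm="forward"), assumed mapping
    case fft_norm_mode::by_root_n:
      return 1.0 / std::sqrt(n); // norm="ortho"
    case fft_norm_mode::by_n:
      return 1.0 / n; // norm="backward" (the default)
  }
  return 0.0; // unreachable
}

int main() {
  // Output shape {4, 8} transformed over both dims: n = 4 * 8 = 32,
  // exactly what the Tensor overload of compute_fct would compute.
  const double n = 4 * 8;
  std::printf("none=%g ortho=%g backward=%g\n",
      fct_for(fft_norm_mode::none, n),      // 1
      fct_for(fft_norm_mode::by_root_n, n), // ~0.176777
      fct_for(fft_norm_mode::by_n, n));     // 0.03125
  return 0;
}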

kernels/optimized/cpu/targets.bzl (+8)

@@ -34,6 +34,14 @@ _OPTIMIZED_ATEN_OPS = (
         ],
     ),
     op_target(name = "op_exp"),
+    op_target(
+        name = "op_fft_c2r",
+        compiler_flags = [] if runtime.is_oss else [
+            "-Wno-global-constructors",
+            "-Wno-shadow",
+        ],
+        deps = [] if runtime.is_oss else ["fbsource//third-party/pocket_fft:pocketfft"],
+    ),
     op_target(
         name = "op_fft_r2c",
         compiler_flags = [] if runtime.is_oss else [

kernels/optimized/optimized.yaml (+5)

@@ -2,6 +2,11 @@
 #
 # This yaml file contains operators that have optimized kernels available.
 
+- op: _fft_c2r.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::opt_fft_c2r_out
+
 - op: _fft_r2c.out
   kernels:
     - arg_meta: null

kernels/test/CMakeLists.txt (+1)

@@ -276,6 +276,7 @@ set(_optimized_kernels_test_sources
   "op_div_test.cpp"
   "op_elu_test.cpp"
   "op_exp_test.cpp"
+  "op_fft_c2r_test.cpp"
   "op_fft_r2c_test.cpp"
   "op_gelu_test.cpp"
   "op_le_test.cpp"

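op_fft_c2r_test.cpp is added to the test list here, but its contents are not shown in this diff view. Purely for illustration, a sketch of what such a test might look like, following the pattern of the existing op_fft_r2c_test.cpp; the wrapper name torch::executor::aten::_fft_c2r_outf, the OperatorTest fixture, and the complex TensorFactory initialization are assumptions, not verified against this commit:

#include <executorch/kernels/test/FunctionHeaderWrapper.h>
#include <executorch/kernels/test/TestUtil.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>

#include <gtest/gtest.h>

using executorch::aten::ArrayRef;
using executorch::aten::ScalarType;
using executorch::aten::Tensor;
using torch::executor::testing::TensorFactory;

class OpFftC2rOutTest : public OperatorTest {
 protected:
  Tensor& op_fft_c2r_out(
      const Tensor& in,
      ArrayRef<int64_t> dim,
      int64_t normalization,
      int64_t last_dim_size,
      Tensor& out) {
    // Hypothetical generated wrapper, named by analogy with _fft_r2c_outf.
    return torch::executor::aten::_fft_c2r_outf(
        context_, in, dim, normalization, last_dim_size, out);
  }
};

TEST_F(OpFftC2rOutTest, DcOnlySpectrumGivesConstantSignal) {
  TensorFactory<ScalarType::ComplexFloat> tf_in;
  TensorFactory<ScalarType::Float> tf_out;

  // Half-spectrum of length n/2 + 1 = 3 for a real signal of length 4.
  // A spectrum that is 1 at DC and 0 elsewhere inverts, under by_n
  // normalization (enum value 2), to a constant signal of 1/n = 0.25.
  // Brace-initializing the complex elements is assumed to work with
  // ExecuTorch's etensor::complex type.
  Tensor in = tf_in.make({3}, {{1.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}});
  Tensor out = tf_out.zeros({4});

  int64_t dim_array[1] = {0};
  op_fft_c2r_out(
      in, ArrayRef<int64_t>(dim_array, 1), /*normalization=*/2,
      /*last_dim_size=*/4, out);

  EXPECT_TENSOR_CLOSE(out, tf_out.make({4}, {0.25f, 0.25f, 0.25f, 0.25f}));
}

A DC-only half-spectrum is a convenient fixture because the expected inverse is just a constant signal with amplitude equal to the normalization factor, so the test exercises the pocketfft call, the output resize, and compute_fct in one shot.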