
Commit 4a6b78f

pssrawat authored and facebook-github-bot committed
Implement _fft_c2r core ATen op
Summary: Add _fft_c2r

Differential Revision: D73006888
1 parent 092da57 commit 4a6b78f

File tree

7 files changed, +367 -0 lines changed


kernels/aten/functions.yaml (+2)

@@ -6,6 +6,8 @@

 - op: _fake_quantize_per_tensor_affine_cachemask_tensor_qparams.out

+- op: _fft_c2r.out
+
 - op: _fft_r2c.out

 - op: _linalg_det.result

kernels/optimized/cpu/op_fft_c2r.cpp (new file, +181)

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/runtime/core/span.h>
#include <executorch/runtime/kernel/kernel_includes.h>

#include <pocketfft_hdronly.h>

#include <optional>

namespace torch::executor::native {

// TODO: contents of this anonymous namespace are copy/pasted from
// PyTorch core (aten/src/ATen/native/mkl/SpectralOps.cpp). Small
// portions (the parts that don't depend on Tensor) could be reused;
// refactor to enable that once we can share headers from PyTorch
// core.
namespace {

pocketfft::stride_t stride_from_tensor(const Tensor& t) {
  pocketfft::stride_t stride(t.strides().begin(), t.strides().end());
  for (auto& s : stride) {
    s *= t.element_size();
  }
  return stride;
}

pocketfft::shape_t shape_from_tensor(const Tensor& t) {
  return pocketfft::shape_t(t.sizes().begin(), t.sizes().end());
}

// NOTE: The reinterpret_cast in tensor_cdata is UB, but it's what
// PyTorch core does and I'm not aware of a portable way to do this
// that doesn't rely on UB.
template <typename T>
inline std::complex<T>* tensor_cdata(Tensor& t) {
  return reinterpret_cast<std::complex<T>*>(
      t.data_ptr<executorch::runtime::etensor::complex<T>>());
}

template <typename T>
inline const std::complex<T>* tensor_cdata(const Tensor& t) {
  return reinterpret_cast<const std::complex<T>*>(
      t.const_data_ptr<executorch::runtime::etensor::complex<T>>());
}

// NOTE: in particular this is in ATen/native/SpectralOpsUtils.h and
// could be shared immediately.
enum class fft_norm_mode {
  none, // No normalization
  by_root_n, // Divide by sqrt(signal_size)
  by_n, // Divide by signal_size
};

// NOTE: slight fork from upstream PyTorch to use ET_KERNEL_CHECK;
// upstream with TORCH_CHECK will be fine to use once we have code
// sharing.
template <typename T>
std::optional<T>
compute_fct(KernelRuntimeContext& ctx, int64_t size, int64_t normalization) {
  constexpr auto one = static_cast<T>(1);
  switch (static_cast<fft_norm_mode>(normalization)) {
    case fft_norm_mode::none:
      return one;
    case fft_norm_mode::by_n:
      return one / static_cast<T>(size);
    case fft_norm_mode::by_root_n:
      return one / std::sqrt(static_cast<T>(size));
  }
  ET_KERNEL_CHECK_MSG(
      ctx,
      false,
      InvalidArgument,
      std::nullopt,
      "Unsupported normalization type: %" PRId64,
      normalization);
}

template <typename T>
std::optional<T> compute_fct(
    KernelRuntimeContext& ctx,
    const Tensor& t,
    IntArrayRef dim,
    int64_t normalization) {
  if (static_cast<fft_norm_mode>(normalization) == fft_norm_mode::none) {
    return static_cast<T>(1);
  }
  const auto& sizes = t.sizes();
  int64_t n = 1;
  for (auto idx : dim) {
    n *= sizes[idx];
  }
  return compute_fct<T>(ctx, n, normalization);
}

} // namespace

Tensor& opt_fft_c2r_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    IntArrayRef dim,
    int64_t normalization,
    int64_t last_dim_size,
    Tensor& out) {
  auto in_sizes = in.sizes();
  ET_KERNEL_CHECK(ctx, in.dim() <= kTensorDimensionLimit, InvalidArgument, out);

  ET_KERNEL_CHECK(ctx, !dim.empty(), InvalidArgument, out);
  ET_KERNEL_CHECK(ctx, last_dim_size >= 1, InvalidArgument, out);

  // Determine the output size
  std::array<Tensor::SizesType, kTensorDimensionLimit> out_sizes_storage{};
  executorch::runtime::Span<Tensor::SizesType> out_sizes(
      out_sizes_storage.data(), in_sizes.size());
  std::copy(in_sizes.begin(), in_sizes.end(), out_sizes.begin());
  out_sizes[dim.back()] = last_dim_size;

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  ET_KERNEL_CHECK_MSG(
      ctx,
      in.scalar_type() == executorch::runtime::toComplexType(out.scalar_type()),
      InvalidArgument,
      out,
      "the input type for _fft_c2r must be the Complex type corresponding to the output type");

  for (auto d : dim) {
    ET_KERNEL_CHECK_MSG(
        ctx,
        d >= 0 && d < in.dim(),
        InvalidArgument,
        out,
        "dims must be in bounds (got %" PRId64 ")",
        d);
  }

  ET_KERNEL_CHECK_MSG(
      ctx,
      resize_tensor(
          out,
          executorch::runtime::ArrayRef<Tensor::SizesType>(
              out_sizes.data(), out_sizes.size())) == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor (last dim %d).",
      out_sizes[dim.back()]);

  pocketfft::shape_t axes(dim.begin(), dim.end());
  auto out_shape = shape_from_tensor(out);
  // TODO: if arbitrary strides are a possibility, we need to validate
  // these, because pocketfft README says "Strides that lead to
  // multiple accesses of the same memory address are not allowed."
  auto in_stride = stride_from_tensor(in);
  auto out_stride = stride_from_tensor(out);
  // NOTE: as of this writing, upstream PyTorch only supports
  // float/double, so we follow suit.
  ET_SWITCH_FLOAT_TYPES(out.scalar_type(), ctx, "_fft_c2r.out", CTYPE_OUT, [&] {
    auto fct = compute_fct<CTYPE_OUT>(ctx, out, dim, normalization);
    if (!fct) {
      // Check failed, just bail out of the lambda.
      return;
    }
    pocketfft::c2r<CTYPE_OUT>(
        out_shape,
        in_stride,
        out_stride,
        axes,
        false /* forward */,
        tensor_cdata<CTYPE_OUT>(in),
        out.mutable_data_ptr<CTYPE_OUT>(),
        *fct);
  });
  return out;
}

} // namespace torch::executor::native
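For reference, the call pattern the kernel hands to pocketfft can be exercised standalone. The sketch below is not part of this commit; it assumes only that pocketfft_hdronly.h is on the include path. It runs a length-4 inverse real FFT of the one-sided spectrum of [1, 2, 3, 4], with byte strides exactly as stride_from_tensor produces them and fct = 1/4 playing the role of compute_fct with fft_norm_mode::by_n.

// Standalone sketch (not part of this commit): same pocketfft::c2r call
// pattern as opt_fft_c2r_out, on a 1-D contiguous signal.
#include <complex>
#include <cstdio>
#include <vector>

#include <pocketfft_hdronly.h>

int main() {
  // One-sided spectrum of [1, 2, 3, 4]: n/2 + 1 = 3 bins for output length n = 4.
  std::vector<std::complex<float>> in = {{10.f, 0.f}, {-2.f, 2.f}, {-2.f, 0.f}};
  std::vector<float> out(4, 0.f);

  pocketfft::shape_t out_shape = {4}; // real output length (last_dim_size)
  pocketfft::stride_t in_stride = {sizeof(std::complex<float>)}; // strides in bytes
  pocketfft::stride_t out_stride = {sizeof(float)};
  pocketfft::shape_t axes = {0};

  // forward = false selects the c2r (inverse) direction; fct = 1/4 corresponds
  // to fft_norm_mode::by_n for a size-4 transform.
  pocketfft::c2r<float>(
      out_shape, in_stride, out_stride, axes, false, in.data(), out.data(), 0.25f);

  for (float v : out) {
    std::printf("%g ", v); // prints: 1 2 3 4
  }
  std::printf("\n");
  return 0;
}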

kernels/optimized/cpu/targets.bzl (+8)

@@ -34,6 +34,14 @@ _OPTIMIZED_ATEN_OPS = (
         ],
     ),
     op_target(name = "op_exp"),
+    op_target(
+        name = "op_fft_c2r",
+        compiler_flags = [] if runtime.is_oss else [
+            "-Wno-global-constructors",
+            "-Wno-shadow",
+        ],
+        deps = [] if runtime.is_oss else ["fbsource//third-party/pocket_fft:pocketfft"],
+    ),
     op_target(
         name = "op_fft_r2c",
         compiler_flags = [] if runtime.is_oss else [

kernels/optimized/optimized.yaml (+5)

@@ -2,6 +2,11 @@
 #
 # This yaml file contains operators that have optimized kernels available.

+- op: _fft_c2r.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::opt_fft_c2r_out
+
 - op: _fft_r2c.out
   kernels:
     - arg_meta: null

kernels/test/CMakeLists.txt (+1)

@@ -276,6 +276,7 @@ set(_optimized_kernels_test_sources
     "op_div_test.cpp"
    "op_elu_test.cpp"
    "op_exp_test.cpp"
+   "op_fft_c2r_test.cpp"
    "op_fft_r2c_test.cpp"
    "op_gelu_test.cpp"
    "op_le_test.cpp"

kernels/test/op_fft_c2r_test.cpp (new file, +169)

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
#include <executorch/kernels/test/TestUtil.h>
#include <executorch/kernels/test/supported_features.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>

#include <gtest/gtest.h>

using executorch::aten::IntArrayRef;
using executorch::aten::ScalarType;
using executorch::aten::Tensor;
using executorch::runtime::testing::TensorFactory;

class OpFftC2rOutTest : public OperatorTest {
 protected:
  Tensor& op_fft_c2r_out(
      const Tensor& in,
      IntArrayRef dim,
      int64_t normalization,
      int64_t last_dim_size,
      Tensor& out) {
    return torch::executor::aten::_fft_c2r_outf(
        context_, in, dim, normalization, last_dim_size, out);
  }

  template <
      class CTYPE_OUT,
      executorch::aten::ScalarType DTYPE_OUT,
      bool expect_failure = false>
  void test_dtype(int64_t norm, int64_t dim = 0) {
    TensorFactory<DTYPE_OUT> tf_out;
    constexpr auto DTYPE_IN = executorch::runtime::toComplexType(DTYPE_OUT);
    TensorFactory<DTYPE_IN> tf_in;

    using CTYPE_IN =
        typename executorch::runtime::ScalarTypeToCppType<DTYPE_IN>::type;

    std::vector<CTYPE_IN> input_data = {
        CTYPE_IN{24, 4},
        CTYPE_IN{4, -8},
        CTYPE_IN{0, 4},

        CTYPE_IN{8, -16},
        CTYPE_IN{-4, 0},
        CTYPE_IN{0, 32},

        CTYPE_IN{12, 0},
        CTYPE_IN{0, 4},
        CTYPE_IN{-8, 4},

        CTYPE_IN{0, 8},
        CTYPE_IN{-4, 8},
        CTYPE_IN{8, 0},
    };

    Tensor in = tf_in.make({4, 3}, input_data);
    Tensor out = tf_out.full({4, 3}, 0);

    int64_t last_dim_size =
        (dim >= 0 && dim < out.dim()) ? out.sizes()[dim] : 0;
    op_fft_c2r_out(in, {dim}, norm, last_dim_size, out);

    double norm_factor = 1;
    if (norm == 1) {
      norm_factor = 2;
    } else if (norm == 2) {
      norm_factor = 4;
    }
    std::vector<CTYPE_OUT> expected_data = {
        52., -4., -8., 44., 4., -56., 20., 12., -8., -20., 4., 72.};
    for (auto& elem : expected_data) {
      elem /= norm_factor;
    }
    Tensor expected = tf_out.make({4, 3}, expected_data);

    if (!expect_failure) {
      EXPECT_TENSOR_CLOSE(out, expected);
    }
  }

  template <class CTYPE_OUT, executorch::aten::ScalarType DTYPE_OUT>
  void test_dtype_multiple_axes() {
    TensorFactory<DTYPE_OUT> tf_out;
    constexpr auto DTYPE_IN = executorch::runtime::toComplexType(DTYPE_OUT);
    TensorFactory<DTYPE_IN> tf_in;

    using CTYPE_IN =
        typename executorch::runtime::ScalarTypeToCppType<DTYPE_IN>::type;

    std::vector<CTYPE_IN> input_data = {
        CTYPE_IN{16, 4},
        CTYPE_IN{4, -8},
        CTYPE_IN{0, 4},

        CTYPE_IN{8, -16},
        CTYPE_IN{-4, 0},
        CTYPE_IN{0, 36},

        CTYPE_IN{32, 0},
        CTYPE_IN{0, 4},
        CTYPE_IN{-8, 4},

        CTYPE_IN{0, 8},
        CTYPE_IN{-4, 8},
        CTYPE_IN{8, 0},
    };

    Tensor in = tf_in.make({4, 3}, input_data);
    Tensor out = tf_out.full({4, 4}, 0);

    int64_t last_dim_size = out.sizes()[0];
    std::array<int64_t, 2> dim = {0, 1};
    op_fft_c2r_out(in, dim, 1, last_dim_size, out);

    std::vector<CTYPE_OUT> expected_data = {
        12., 12., 16., 16., 1., 15., -11., 3.,
        12., 20., 0., 8., -1., -15., 3., -27.};
    Tensor expected = tf_out.make({4, 4}, expected_data);
    EXPECT_TENSOR_CLOSE(out, expected);
  }
};

TEST_F(OpFftC2rOutTest, AllDtypesSupported) {
#define TEST_ENTRY(ctype, dtype)           \
  test_dtype<ctype, ScalarType::dtype>(0); \
  test_dtype<ctype, ScalarType::dtype>(1); \
  test_dtype<ctype, ScalarType::dtype>(2);
  ET_FORALL_FLOAT_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}

TEST_F(OpFftC2rOutTest, MultipleDims) {
#define TEST_ENTRY(ctype, dtype) \
  test_dtype_multiple_axes<ctype, ScalarType::dtype>();
  ET_FORALL_FLOAT_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}

TEST_F(OpFftC2rOutTest, InvalidNorm) {
  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
    GTEST_SKIP() << "ATen MKL path does not validate norm";
    return;
  }
  auto invalid_norm = [this](int64_t norm) {
    test_dtype<float, ScalarType::Float, /* expect_failure = */ true>(norm);
  };
  ET_EXPECT_KERNEL_FAILURE(context_, invalid_norm(3));
  ET_EXPECT_KERNEL_FAILURE(context_, invalid_norm(4));
  ET_EXPECT_KERNEL_FAILURE(context_, invalid_norm(-1));
  ET_EXPECT_KERNEL_FAILURE(context_, invalid_norm(9999999));
}

TEST_F(OpFftC2rOutTest, InvalidDim) {
  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
    GTEST_SKIP() << "ATen fails UBSAN";
    return;
  }
  auto negative_dim = [this]() {
    test_dtype<float, ScalarType::Float, /* expect_failure = */ true>(0, -1);
    test_dtype<float, ScalarType::Float, /* expect_failure = */ true>(0, 3);
    test_dtype<float, ScalarType::Float, /* expect_failure = */ true>(0, 9001);
  };
  ET_EXPECT_KERNEL_FAILURE(context_, negative_dim());
}
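As a sanity check on the expected values above (not part of this commit), column 0 of the AllDtypesSupported test can be reproduced by hand. With norm = 0 the kernel applies no scaling, and the length-4 inverse real FFT of the one-sided spectrum {24+4i, 8-16i, 12} (column 0 of input_data, rows 0-2; pocketfft reads only the first last_dim_size/2 + 1 = 3 entries along the transformed dim, so row 3 does not enter the computation, and only the real parts of bins 0 and N/2 contribute) gives {52, 44, 20, -20}, i.e. column 0 of expected_data. A naive reference sketch:

// Naive inverse real DFT for one column of the AllDtypesSupported test
// (standalone sketch, not part of this commit).
#include <cmath>
#include <complex>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<std::complex<double>> spectrum = {
      {24, 4}, {8, -16}, {12, 0}}; // column 0 of input_data, rows 0-2
  const int N = 4; // last_dim_size
  const double pi = std::acos(-1.0);
  for (int n = 0; n < N; ++n) {
    // Bins 0 and N/2 contribute only their real parts; interior bins count twice.
    double x = spectrum.front().real() +
        (n % 2 == 0 ? 1.0 : -1.0) * spectrum.back().real();
    for (int k = 1; k < N / 2; ++k) {
      const double angle = 2.0 * pi * k * n / N;
      x += 2.0 * (spectrum[k].real() * std::cos(angle) -
                  spectrum[k].imag() * std::sin(angle));
    }
    std::printf("%g ", x); // prints: 52 44 20 -20
  }
  std::printf("\n");
  return 0;
}

With norm = 1 (fft_norm_mode::by_root_n) the same values are divided by sqrt(4) = 2, and with norm = 2 (fft_norm_mode::by_n) by 4, which is where norm_factor in test_dtype comes from.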
