
Commit f725f63

pssrawat authored and facebook-github-bot committed
Implement _fft_c2r core ATen op (#10208)
Summary: Add _fft_c2r.
Pull Request resolved: #10208
Differential Revision: D73006888
1 parent 605bfa6 · commit f725f63
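For orientation (not part of the commit message): judging from the kernel signature added below, the out-variant this kernel implements takes the complex input, the transformed dims, a normalization mode, and the real length of the last transformed dim, roughly:

    _fft_c2r.out(Tensor self, int[] dim, int normalization, int last_dim_size, *, Tensor(a!) out) -> Tensor(a!)

This is the complex-to-real (inverse real FFT) counterpart of the existing _fft_r2c op; ATen-level calls such as torch.fft.irfft lower to _fft_c2r.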

File tree

9 files changed: +405 −91 lines

kernels/aten/functions.yaml (+2)

@@ -6,6 +6,8 @@
 
 - op: _fake_quantize_per_tensor_affine_cachemask_tensor_qparams.out
 
+- op: _fft_c2r.out
+
 - op: _fft_r2c.out
 
 - op: _linalg_det.result

kernels/optimized/cpu/fft_utils.h — new file (+100)

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/runtime/kernel/kernel_includes.h>
#include <pocketfft_hdronly.h>
#include <optional>

namespace torch::executor::native {

// TODO: contents of this anonymous namespace are copy/pasted from
// PyTorch core (aten/src/ATen/native/mkl/SpectralOps.cpp). Small
// portions (the parts that don't depend on Tensor) could be reused;
// refactor to enable that once we can share headers from PyTorch
// core.
namespace {
pocketfft::stride_t stride_from_tensor(const Tensor& t) {
  pocketfft::stride_t stride(t.strides().begin(), t.strides().end());
  for (auto& s : stride) {
    s *= t.element_size();
  }
  return stride;
}

pocketfft::shape_t shape_from_tensor(const Tensor& t) {
  return pocketfft::shape_t(t.sizes().begin(), t.sizes().end());
}

// NOTE: The reinterpret_cast in tensor_cdata is UB, but it's what
// PyTorch core does and I'm not aware of a portable way to do this
// that doesn't rely on UB.
template <typename T>
inline std::complex<T>* tensor_cdata(Tensor& t) {
  return reinterpret_cast<std::complex<T>*>(
      t.data_ptr<executorch::runtime::etensor::complex<T>>());
}

template <typename T>
inline const std::complex<T>* tensor_cdata(const Tensor& t) {
  return reinterpret_cast<const std::complex<T>*>(
      t.const_data_ptr<executorch::runtime::etensor::complex<T>>());
}

// NOTE: in particular this is in ATen/native/SpectralOpsUtils.h and
// could be shared immediately.
enum class fft_norm_mode {
  none, // No normalization
  by_root_n, // Divide by sqrt(signal_size)
  by_n, // Divide by signal_size
};

// NOTE: slight fork from upstream PyTorch to use ET_KERNEL_CHECK;
// upstream with TORCH_CHECK will be fine to use once we have code
// sharing.
template <typename T>
std::optional<T>
compute_fct(KernelRuntimeContext& ctx, int64_t size, int64_t normalization) {
  constexpr auto one = static_cast<T>(1);
  switch (static_cast<fft_norm_mode>(normalization)) {
    case fft_norm_mode::none:
      return one;
    case fft_norm_mode::by_n:
      return one / static_cast<T>(size);
    case fft_norm_mode::by_root_n:
      return one / std::sqrt(static_cast<T>(size));
  }
  ET_KERNEL_CHECK_MSG(
      ctx,
      false,
      InvalidArgument,
      std::nullopt,
      "Unsupported normalization type: %" PRId64,
      normalization);
}

template <typename T>
std::optional<T> compute_fct(
    KernelRuntimeContext& ctx,
    const Tensor& t,
    IntArrayRef dim,
    int64_t normalization) {
  if (static_cast<fft_norm_mode>(normalization) == fft_norm_mode::none) {
    return static_cast<T>(1);
  }
  const auto& sizes = t.sizes();
  int64_t n = 1;
  for (auto idx : dim) {
    n *= sizes[idx];
  }
  return compute_fct<T>(ctx, n, normalization);
}
} // namespace

} // namespace torch::executor::native
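As a side note (not part of the commit), the normalization argument is the integer encoding of fft_norm_mode above, so 0 = none, 1 = by_root_n, 2 = by_n. A minimal standalone sketch of the factor compute_fct produces, with the ExecuTorch runtime types stripped out; the function name example_fct is illustrative only:

#include <cmath>
#include <cstdint>
#include <optional>

// Mirrors the fft_norm_mode mapping in fft_utils.h: returns the scale
// factor applied to the transform output, or nullopt for an unknown mode
// (which the real kernel rejects via ET_KERNEL_CHECK_MSG).
template <typename T>
std::optional<T> example_fct(int64_t size, int64_t normalization) {
  switch (normalization) {
    case 0: return static_cast<T>(1);                        // none
    case 1: return T(1) / std::sqrt(static_cast<T>(size));   // by_root_n
    case 2: return T(1) / static_cast<T>(size);              // by_n
    default: return std::nullopt;
  }
}

// e.g. example_fct<float>(8, 2) == 0.125f: the 1/n factor a "by_n"
// normalized transform over 8 output samples would apply.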

kernels/optimized/cpu/op_fft_c2r.cpp — new file (+91)

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/optimized/cpu/fft_utils.h>
#include <executorch/runtime/core/span.h>

namespace torch::executor::native {
Tensor& opt_fft_c2r_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    IntArrayRef dim,
    int64_t normalization,
    int64_t last_dim_size,
    Tensor& out) {
  auto in_sizes = in.sizes();
  ET_KERNEL_CHECK(ctx, in.dim() <= kTensorDimensionLimit, InvalidArgument, out);

  ET_KERNEL_CHECK(ctx, !dim.empty(), InvalidArgument, out);
  ET_KERNEL_CHECK(ctx, last_dim_size >= 1, InvalidArgument, out);

  // Determine the output size
  std::array<Tensor::SizesType, kTensorDimensionLimit> out_sizes_storage{};
  executorch::runtime::Span<Tensor::SizesType> out_sizes(
      out_sizes_storage.data(), in_sizes.size());
  std::copy(in_sizes.begin(), in_sizes.end(), out_sizes.begin());
  out_sizes[dim.back()] = last_dim_size;

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  ET_KERNEL_CHECK_MSG(
      ctx,
      in.scalar_type() == executorch::runtime::toComplexType(out.scalar_type()),
      InvalidArgument,
      out,
      "the input type for _fft_c2r must be the Complex type corresponding to the output type");

  for (auto d : dim) {
    ET_KERNEL_CHECK_MSG(
        ctx,
        d >= 0 && d < in.dim(),
        InvalidArgument,
        out,
        "dims must be in bounds (got %" PRId64 ")",
        d);
  }

  ET_KERNEL_CHECK_MSG(
      ctx,
      resize_tensor(
          out,
          executorch::runtime::ArrayRef<Tensor::SizesType>(
              out_sizes.data(), out_sizes.size())) == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor (last dim %d).",
      out_sizes[dim.back()]);

  pocketfft::shape_t axes(dim.begin(), dim.end());
  auto out_shape = shape_from_tensor(out);
  // TODO: if arbitrary strides are a possibility, we need to validate
  // these, because pocketfft README says "Strides that lead to
  // multiple accesses of the same memory address are not allowed."
  auto in_stride = stride_from_tensor(in);
  auto out_stride = stride_from_tensor(out);
  // NOTE: as of this writing, upstream PyTorch only supports
  // float/double, so we follow suit.
  ET_SWITCH_FLOAT_TYPES(out.scalar_type(), ctx, "_fft_c2r.out", CTYPE_OUT, [&] {
    auto fct = compute_fct<CTYPE_OUT>(ctx, out, dim, normalization);
    if (!fct) {
      // Check failed, just bail out of the lambda.
      return;
    }
    pocketfft::c2r<CTYPE_OUT>(
        out_shape,
        in_stride,
        out_stride,
        axes,
        false /* forward */,
        tensor_cdata<CTYPE_OUT>(in),
        out.mutable_data_ptr<CTYPE_OUT>(),
        *fct);
  });
  return out;
}
} // namespace torch::executor::native
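A usage note, not part of the commit: since this is the complex-to-real (inverse real FFT) primitive, a one-dimensional transform typically receives a complex input with last_dim_size / 2 + 1 bins along dim.back() and produces a real output of length last_dim_size. A minimal standalone sketch of the "Determine the output size" logic above, using plain vectors in place of ExecuTorch tensors (the helper name c2r_out_sizes is illustrative only):

#include <cstdint>
#include <vector>

// Copy the input sizes and replace the size along the last transformed
// dim with last_dim_size, as opt_fft_c2r_out does before resize_tensor.
std::vector<int64_t> c2r_out_sizes(
    const std::vector<int64_t>& in_sizes,
    const std::vector<int64_t>& dim,
    int64_t last_dim_size) {
  std::vector<int64_t> out_sizes = in_sizes;
  out_sizes[dim.back()] = last_dim_size;
  return out_sizes;
}

// Example: a complex input of shape {4, 9} with dim = {1} and
// last_dim_size = 16 (since 16 / 2 + 1 == 9 input bins) resizes the
// real output to {4, 16}.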

kernels/optimized/cpu/op_fft_r2c.cpp (+1 −90)

@@ -6,99 +6,10 @@
  * LICENSE file in the root directory of this source tree.
  */
 
+#include <executorch/kernels/optimized/cpu/fft_utils.h>
 #include <executorch/runtime/core/span.h>
-#include <executorch/runtime/kernel/kernel_includes.h>
-
-#include <pocketfft_hdronly.h>
-
-#include <optional>
 
 namespace torch::executor::native {
-
[... remaining deleted lines: the anonymous-namespace pocketfft helpers (stride_from_tensor, shape_from_tensor, tensor_cdata, fft_norm_mode, compute_fct), moved verbatim into the new fft_utils.h shown above ...]
-
 Tensor& opt_fft_r2c_out(
     KernelRuntimeContext& ctx,
     const Tensor& in,

kernels/optimized/cpu/targets.bzl (+17 −1)

@@ -35,13 +35,21 @@ _OPTIMIZED_ATEN_OPS = (
         ],
     ),
     op_target(name = "op_exp"),
+    op_target(
+        name = "op_fft_c2r",
+        compiler_flags = [] if runtime.is_oss else [
+            "-Wno-global-constructors",
+            "-Wno-shadow",
+        ],
+        deps = [":fft_utils"],
+    ),
     op_target(
         name = "op_fft_r2c",
        compiler_flags = [] if runtime.is_oss else [
            "-Wno-global-constructors",
            "-Wno-shadow",
        ],
-        deps = [] if runtime.is_oss else ["fbsource//third-party/pocket_fft:pocketfft"],
+        deps = [":fft_utils"],
     ),
     op_target(name = "op_sigmoid"),
     op_target(

@@ -143,6 +151,14 @@ def define_common_targets():
         exported_deps = ["//executorch/runtime/core:core"],
     )
 
+    runtime.cxx_library(
+        name = "fft_utils",
+        srcs = [],
+        exported_headers = ["fft_utils.h"],
+        visibility = ["//executorch/kernels/optimized/cpu/..."],
+        exported_deps = [] if runtime.is_oss else ["fbsource//third-party/pocket_fft:pocketfft"],
+    )
+
     runtime.cxx_library(
         name = "binary_ops",
         exported_headers = ["binary_ops.h"],

kernels/optimized/optimized.yaml (+5)

@@ -2,6 +2,11 @@
 #
 # This yaml file contains operators that have optimized kernels available.
 
+- op: _fft_c2r.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::opt_fft_c2r_out
+
 - op: _fft_r2c.out
   kernels:
     - arg_meta: null

kernels/test/CMakeLists.txt (+1)

@@ -276,6 +276,7 @@ set(_optimized_kernels_test_sources
     "op_div_test.cpp"
     "op_elu_test.cpp"
     "op_exp_test.cpp"
+    "op_fft_c2r_test.cpp"
     "op_fft_r2c_test.cpp"
     "op_gelu_test.cpp"
     "op_le_test.cpp"
