diff --git a/kernels/aten/functions.yaml b/kernels/aten/functions.yaml
index 2156e85ab3c..28f1a215562 100644
--- a/kernels/aten/functions.yaml
+++ b/kernels/aten/functions.yaml
@@ -6,6 +6,8 @@
 
 - op: _fake_quantize_per_tensor_affine_cachemask_tensor_qparams.out
 
+- op: _fft_c2r.out
+
 - op: _fft_r2c.out
 
 - op: _linalg_det.result
diff --git a/kernels/optimized/cpu/fft_utils.h b/kernels/optimized/cpu/fft_utils.h
new file mode 100644
index 00000000000..2225e8ddfa7
--- /dev/null
+++ b/kernels/optimized/cpu/fft_utils.h
@@ -0,0 +1,100 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/runtime/kernel/kernel_includes.h>
+#include <optional>
+#include <pocketfft_hdronly.h>
+
+namespace torch::executor::native {
+
+// TODO: contents of this anonymous namespace are copy/pasted from
+// PyTorch core (aten/src/ATen/native/mkl/SpectralOps.cpp). Small
+// portions (the parts that don't depend on Tensor) could be reused;
+// refactor to enable that once we can share headers from PyTorch
+// core.
+namespace {
+inline pocketfft::stride_t stride_from_tensor(const Tensor& t) {
+  pocketfft::stride_t stride(t.strides().begin(), t.strides().end());
+  for (auto& s : stride) {
+    s *= t.element_size();
+  }
+  return stride;
+}
+
+inline pocketfft::shape_t shape_from_tensor(const Tensor& t) {
+  return pocketfft::shape_t(t.sizes().begin(), t.sizes().end());
+}
+
+// NOTE: The reinterpret_cast in tensor_cdata is UB, but it's what
+// PyTorch core does and I'm not aware of a portable way to do this
+// that doesn't rely on UB.
+template <typename T>
+inline std::complex<T>* tensor_cdata(Tensor& t) {
+  return reinterpret_cast<std::complex<T>*>(
+      t.data_ptr<executorch::aten::complex<T>>());
+}
+
+template <typename T>
+inline const std::complex<T>* tensor_cdata(const Tensor& t) {
+  return reinterpret_cast<const std::complex<T>*>(
+      t.const_data_ptr<executorch::aten::complex<T>>());
+}
+
+// NOTE: in particular this is in ATen/native/SpectralOpsUtils.h and
+// could be shared immediately.
+enum class fft_norm_mode {
+  none, // No normalization
+  by_root_n, // Divide by sqrt(signal_size)
+  by_n, // Divide by signal_size
+};
+
+// NOTE: slight fork from upstream PyTorch to use ET_KERNEL_CHECK;
+// upstream with TORCH_CHECK will be fine to use once we have code
+// sharing.
+template <typename T>
+std::optional<T>
+compute_fct(KernelRuntimeContext& ctx, int64_t size, int64_t normalization) {
+  constexpr auto one = static_cast<T>(1);
+  switch (static_cast<fft_norm_mode>(normalization)) {
+    case fft_norm_mode::none:
+      return one;
+    case fft_norm_mode::by_n:
+      return one / static_cast<T>(size);
+    case fft_norm_mode::by_root_n:
+      return one / std::sqrt(static_cast<T>(size));
+  }
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      false,
+      InvalidArgument,
+      std::nullopt,
+      "Unsupported normalization type: %" PRId64,
+      normalization);
+}
+
+template <typename T>
+std::optional<T> compute_fct(
+    KernelRuntimeContext& ctx,
+    const Tensor& t,
+    IntArrayRef dim,
+    int64_t normalization) {
+  if (static_cast<fft_norm_mode>(normalization) == fft_norm_mode::none) {
+    return static_cast<T>(1);
+  }
+  const auto& sizes = t.sizes();
+  int64_t n = 1;
+  for (auto idx : dim) {
+    n *= sizes[idx];
+  }
+  return compute_fct<T>(ctx, n, normalization);
+}
+} // namespace
+
+} // namespace torch::executor::native
diff --git a/kernels/optimized/cpu/op_fft_c2r.cpp b/kernels/optimized/cpu/op_fft_c2r.cpp
new file mode 100644
index 00000000000..f595b5f7299
--- /dev/null
+++ b/kernels/optimized/cpu/op_fft_c2r.cpp
@@ -0,0 +1,91 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/optimized/cpu/fft_utils.h>
+#include <executorch/runtime/core/span.h>
+
+namespace torch::executor::native {
+Tensor& opt_fft_c2r_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& in,
+    IntArrayRef dim,
+    int64_t normalization,
+    int64_t last_dim_size,
+    Tensor& out) {
+  auto in_sizes = in.sizes();
+  ET_KERNEL_CHECK(ctx, in.dim() <= kTensorDimensionLimit, InvalidArgument, out);
+
+  ET_KERNEL_CHECK(ctx, !dim.empty(), InvalidArgument, out);
+  ET_KERNEL_CHECK(ctx, last_dim_size >= 1, InvalidArgument, out);
+
+  // Determine the output size
+  std::array<Tensor::SizesType, kTensorDimensionLimit> out_sizes_storage{};
+  executorch::runtime::Span<Tensor::SizesType> out_sizes(
+      out_sizes_storage.data(), in_sizes.size());
+  std::copy(in_sizes.begin(), in_sizes.end(), out_sizes.begin());
+  out_sizes[dim.back()] = last_dim_size;
+
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      in.scalar_type() == executorch::runtime::toComplexType(out.scalar_type()),
+      InvalidArgument,
+      out,
+      "the input type for _fft_c2r must be the Complex type corresponding to the output type");
+
+  for (auto d : dim) {
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        d >= 0 && d < in.dim(),
+        InvalidArgument,
+        out,
+        "dims must be in bounds (got %" PRId64 ")",
+        d);
+  }
+
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      resize_tensor(
+          out,
+          executorch::runtime::ArrayRef<Tensor::SizesType>(
+              out_sizes.data(), out_sizes.size())) == Error::Ok,
+      InvalidArgument,
+      out,
+      "Failed to resize output tensor (last dim %d).",
+      out_sizes[dim.back()]);
+
+  pocketfft::shape_t axes(dim.begin(), dim.end());
+  auto out_shape = shape_from_tensor(out);
+  // TODO: if arbitrary strides are a possibility, we need to validate
+  // these, because pocketfft README says "Strides that lead to
+  // multiple accesses of the same memory address are not allowed."
+  auto in_stride = stride_from_tensor(in);
+  auto out_stride = stride_from_tensor(out);
+  // NOTE: as of this writing, upstream PyTorch only supports
+  // float/double, so we follow suit.
+  ET_SWITCH_FLOAT_TYPES(out.scalar_type(), ctx, "_fft_c2r.out", CTYPE_OUT, [&] {
+    auto fct = compute_fct<CTYPE_OUT>(ctx, out, dim, normalization);
+    if (!fct) {
+      // Check failed, just bail out of the lambda.
+      return;
+    }
+    pocketfft::c2r(
+        out_shape,
+        in_stride,
+        out_stride,
+        axes,
+        false /* forward */,
+        tensor_cdata<CTYPE_OUT>(in),
+        out.mutable_data_ptr<CTYPE_OUT>(),
+        *fct);
+  });
+  return out;
+}
+} // namespace torch::executor::native
diff --git a/kernels/optimized/cpu/op_fft_r2c.cpp b/kernels/optimized/cpu/op_fft_r2c.cpp
index 45d3d9acb42..750a7e8f0a2 100644
--- a/kernels/optimized/cpu/op_fft_r2c.cpp
+++ b/kernels/optimized/cpu/op_fft_r2c.cpp
@@ -6,99 +6,10 @@
  * LICENSE file in the root directory of this source tree.
  */
 
+#include <executorch/kernels/optimized/cpu/fft_utils.h>
 #include <executorch/runtime/core/span.h>
-#include <executorch/runtime/kernel/kernel_includes.h>
-
-#include <pocketfft_hdronly.h>
-
-#include <optional>
 
 namespace torch::executor::native {
-
-// TODO: contents of this anonymous namespace are copy/pasted from
-// PyTorch core (aten/src/ATen/native/mkl/SpectralOps.cpp). Small
-// portions (the parts that don't depend on Tensor) could be reused;
-// refactor to enable that once we can share headers from PyTorch
-// core.
-namespace {
-pocketfft::stride_t stride_from_tensor(const Tensor& t) {
-  pocketfft::stride_t stride(t.strides().begin(), t.strides().end());
-  for (auto& s : stride) {
-    s *= t.element_size();
-  }
-  return stride;
-}
-
-pocketfft::shape_t shape_from_tensor(const Tensor& t) {
-  return pocketfft::shape_t(t.sizes().begin(), t.sizes().end());
-}
-
-// NOTE: The reinterpret_cast in tensor_cdata is UB, but it's what
-// PyTorch core does and I'm not aware of a portable way to do this
-// that doesn't rely on UB.
-template <typename T>
-inline std::complex<T>* tensor_cdata(Tensor& t) {
-  return reinterpret_cast<std::complex<T>*>(
-      t.data_ptr<executorch::aten::complex<T>>());
-}
-
-template <typename T>
-inline const std::complex<T>* tensor_cdata(const Tensor& t) {
-  return reinterpret_cast<const std::complex<T>*>(
-      t.const_data_ptr<executorch::aten::complex<T>>());
-}
-
-// NOTE: in particular this is in ATen/native/SpectralOpsUtils.h and
-// could be shared immediately.
-enum class fft_norm_mode {
-  none, // No normalization
-  by_root_n, // Divide by sqrt(signal_size)
-  by_n, // Divide by signal_size
-};
-
-// NOTE: slight fork from upstream PyTorch to use ET_KERNEL_CHECK;
-// upstream with TORCH_CHECK will be fine to use once we have code
-// sharing.
-template <typename T>
-std::optional<T>
-compute_fct(KernelRuntimeContext& ctx, int64_t size, int64_t normalization) {
-  constexpr auto one = static_cast<T>(1);
-  switch (static_cast<fft_norm_mode>(normalization)) {
-    case fft_norm_mode::none:
-      return one;
-    case fft_norm_mode::by_n:
-      return one / static_cast<T>(size);
-    case fft_norm_mode::by_root_n:
-      return one / std::sqrt(static_cast<T>(size));
-  }
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      false,
-      InvalidArgument,
-      std::nullopt,
-      "Unsupported normalization type: %" PRId64,
-      normalization);
-}
-
-template <typename T>
-std::optional<T> compute_fct(
-    KernelRuntimeContext& ctx,
-    const Tensor& t,
-    IntArrayRef dim,
-    int64_t normalization) {
-  if (static_cast<fft_norm_mode>(normalization) == fft_norm_mode::none) {
-    return static_cast<T>(1);
-  }
-  const auto& sizes = t.sizes();
-  int64_t n = 1;
-  for (auto idx : dim) {
-    n *= sizes[idx];
-  }
-  return compute_fct<T>(ctx, n, normalization);
-}
-
-} // namespace
-
 Tensor& opt_fft_r2c_out(
     KernelRuntimeContext& ctx,
     const Tensor& in,
diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl
index bf24e4de49c..7406cc21b53 100644
--- a/kernels/optimized/cpu/targets.bzl
+++ b/kernels/optimized/cpu/targets.bzl
@@ -35,13 +35,21 @@ _OPTIMIZED_ATEN_OPS = (
         ],
     ),
     op_target(name = "op_exp"),
+    op_target(
+        name = "op_fft_c2r",
+        compiler_flags = [] if runtime.is_oss else [
+            "-Wno-global-constructors",
+            "-Wno-shadow",
+        ],
+        deps = [":fft_utils"],
+    ),
     op_target(
         name = "op_fft_r2c",
         compiler_flags = [] if runtime.is_oss else [
             "-Wno-global-constructors",
             "-Wno-shadow",
         ],
-        deps = [] if runtime.is_oss else ["fbsource//third-party/pocket_fft:pocketfft"],
+        deps = [":fft_utils"],
     ),
     op_target(name = "op_sigmoid"),
     op_target(
@@ -143,6 +151,14 @@ def define_common_targets():
         exported_deps = ["//executorch/runtime/core:core"],
     )
 
+    runtime.cxx_library(
+        name = "fft_utils",
+        srcs = [],
+        exported_headers = ["fft_utils.h"],
+        visibility = ["//executorch/kernels/optimized/cpu/..."],
+        exported_deps = [] if runtime.is_oss else ["fbsource//third-party/pocket_fft:pocketfft"],
+    )
+
     runtime.cxx_library(
         name = "binary_ops",
         exported_headers = ["binary_ops.h"],
diff --git a/kernels/optimized/optimized.yaml b/kernels/optimized/optimized.yaml
index 864c3ed5780..42a065f63ed 100644
--- a/kernels/optimized/optimized.yaml
+++ b/kernels/optimized/optimized.yaml
@@ -2,6 +2,11 @@
 #
 # This yaml file contains operators that have optimized kernels
 # available.
+- op: _fft_c2r.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::opt_fft_c2r_out
+
 - op: _fft_r2c.out
   kernels:
     - arg_meta: null
diff --git a/kernels/test/CMakeLists.txt b/kernels/test/CMakeLists.txt
index dcefa8c2e68..deb61410b10 100644
--- a/kernels/test/CMakeLists.txt
+++ b/kernels/test/CMakeLists.txt
@@ -277,6 +277,7 @@ set(_optimized_kernels_test_sources
   "op_div_test.cpp"
   "op_elu_test.cpp"
   "op_exp_test.cpp"
+  "op_fft_c2r_test.cpp"
   "op_fft_r2c_test.cpp"
   "op_gelu_test.cpp"
   "op_le_test.cpp"
diff --git a/kernels/test/op_fft_c2r_test.cpp b/kernels/test/op_fft_c2r_test.cpp
new file mode 100644
index 00000000000..58c8a216e42
--- /dev/null
+++ b/kernels/test/op_fft_c2r_test.cpp
@@ -0,0 +1,187 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
+#include <executorch/kernels/test/TestUtil.h>
+#include <executorch/kernels/test/supported_features.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
+
+#include <gtest/gtest.h>
+
+using executorch::aten::IntArrayRef;
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
+using executorch::runtime::testing::TensorFactory;
+
+class OpFftC2rOutTest : public OperatorTest {
+ protected:
+  Tensor& op_fft_c2r_out(
+      const Tensor& in,
+      IntArrayRef dim,
+      int64_t normalization,
+      int64_t last_dim_size,
+      Tensor& out) {
+    return torch::executor::aten::_fft_c2r_outf(
+        context_, in, dim, normalization, last_dim_size, out);
+  }
+
+  template <
+      class CTYPE_OUT,
+      executorch::aten::ScalarType DTYPE_OUT,
+      bool expect_failure = false>
+  void test_dtype(int64_t norm, int64_t dim = 0) {
+    TensorFactory<DTYPE_OUT> tf_out;
+    constexpr auto DTYPE_IN = executorch::runtime::toComplexType(DTYPE_OUT);
+    TensorFactory<DTYPE_IN> tf_in;
+
+    using CTYPE_IN =
+        typename executorch::runtime::ScalarTypeToCppType<DTYPE_IN>::type;
+
+    std::vector<CTYPE_IN> input_data = {
+        CTYPE_IN{24, 4},
+        CTYPE_IN{4, -8},
+        CTYPE_IN{0, 4},
+
+        CTYPE_IN{8, -16},
+        CTYPE_IN{-4, 0},
+        CTYPE_IN{0, 32},
+
+        CTYPE_IN{12, 0},
+        CTYPE_IN{0, 4},
+        CTYPE_IN{-8, 4},
+
+        CTYPE_IN{0, 8},
+        CTYPE_IN{-4, 8},
+        CTYPE_IN{8, 0},
+    };
+
+    Tensor in = tf_in.make({4, 3}, input_data);
+    Tensor out = tf_out.full({4, 3}, 0);
+
+    int64_t last_dim_size =
+        (dim >= 0 && dim < out.dim()) ? out.sizes()[dim] : 0;
+    op_fft_c2r_out(in, {dim}, norm, last_dim_size, out);
+
+    double norm_factor = 1;
+    if (norm == 1) {
+      norm_factor = 2;
+    } else if (norm == 2) {
+      norm_factor = 4;
+    }
+    std::vector<CTYPE_OUT> expected_data = {
+        52., -4., -8., 44., 4., -56., 20., 12., -8., -20., 4., 72.};
+    for (auto& elem : expected_data) {
+      elem /= norm_factor;
+    }
+    Tensor expected = tf_out.make({4, 3}, expected_data);
+
+    if (!expect_failure) {
+      EXPECT_TENSOR_CLOSE(out, expected);
+    }
+  }
+
+  template <class CTYPE_OUT, executorch::aten::ScalarType DTYPE_OUT>
+  void test_dtype_multiple_axes() {
+    TensorFactory<DTYPE_OUT> tf_out;
+    constexpr auto DTYPE_IN = executorch::runtime::toComplexType(DTYPE_OUT);
+    TensorFactory<DTYPE_IN> tf_in;
+
+    using CTYPE_IN =
+        typename executorch::runtime::ScalarTypeToCppType<DTYPE_IN>::type;
+
+    std::vector<CTYPE_IN> input_data = {
+        CTYPE_IN{16, 4},
+        CTYPE_IN{4, -8},
+        CTYPE_IN{0, 4},
+
+        CTYPE_IN{8, -16},
+        CTYPE_IN{-4, 0},
+        CTYPE_IN{0, 36},
+
+        CTYPE_IN{32, 0},
+        CTYPE_IN{0, 4},
+        CTYPE_IN{-8, 4},
+
+        CTYPE_IN{0, 8},
+        CTYPE_IN{-4, 8},
+        CTYPE_IN{8, 0},
+    };
+
+    Tensor in = tf_in.make({4, 3}, input_data);
+    Tensor out = tf_out.full({4, 4}, 0);
+
+    int64_t last_dim_size = out.sizes()[0];
+    std::array<int64_t, 2> dim = {0, 1};
+    op_fft_c2r_out(in, dim, 1, last_dim_size, out);
+
+    std::vector<CTYPE_OUT> expected_data = {
+        12.,
+        12.,
+        16.,
+        16.,
+        1.,
+        15.,
+        -11.,
+        3.,
+        12.,
+        20.,
+        0.,
+        8.,
+        -1.,
+        -15.,
+        3.,
+        -27.};
+    Tensor expected = tf_out.make({4, 4}, expected_data);
+    EXPECT_TENSOR_CLOSE(out, expected);
+  }
+};
+
+TEST_F(OpFftC2rOutTest, AllDtypesSupported) {
+#define TEST_ENTRY(ctype, dtype) \
+  test_dtype<ctype, ScalarType::dtype>(0); \
+  test_dtype<ctype, ScalarType::dtype>(1); \
+  test_dtype<ctype, ScalarType::dtype>(2);
+  ET_FORALL_FLOAT_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
+}
+
+TEST_F(OpFftC2rOutTest, MultipleDims) {
+#define TEST_ENTRY(ctype, dtype) \
+  test_dtype_multiple_axes<ctype, ScalarType::dtype>();
+  ET_FORALL_FLOAT_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
+}
+
+TEST_F(OpFftC2rOutTest, InvalidNorm) {
+  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
+    GTEST_SKIP() << "ATen MKL path does not validate norm";
+    return;
+  }
+  auto invalid_norm = [this](int64_t norm) {
+    test_dtype<float, ScalarType::Float, true>(norm);
+  };
+  ET_EXPECT_KERNEL_FAILURE(context_, invalid_norm(3));
+  ET_EXPECT_KERNEL_FAILURE(context_, invalid_norm(4));
+  ET_EXPECT_KERNEL_FAILURE(context_, invalid_norm(-1));
+  ET_EXPECT_KERNEL_FAILURE(context_, invalid_norm(9999999));
+}
+
+TEST_F(OpFftC2rOutTest, InvalidDim) {
+  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
+    GTEST_SKIP() << "ATen fails UBSAN";
+    return;
+  }
+  auto negative_dim = [this]() {
+    test_dtype<float, ScalarType::Float, true>(0, -1);
+    test_dtype<float, ScalarType::Float, true>(0, 3);
+    test_dtype<float, ScalarType::Float, true>(0, 9001);
+  };
+  ET_EXPECT_KERNEL_FAILURE(context_, negative_dim());
+}
diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl
index 8c5fad1f588..b9e1d3d6dac 100644
--- a/kernels/test/targets.bzl
+++ b/kernels/test/targets.bzl
@@ -223,6 +223,7 @@ def define_common_targets():
     _common_op_test("op_exp_test", ["aten", "portable", "optimized"])
     _common_op_test("op_expand_copy_test", ["aten", "portable"])
    _common_op_test("op_expm1_test", ["aten", "portable"])
+    _common_op_test("op_fft_c2r_test", ["aten", "optimized"])
     _common_op_test("op_fft_r2c_test", ["aten", "optimized"])
     _common_op_test("op_fill_test", ["aten", "portable"])
     _common_op_test("op_flip_test", ["aten", "portable"])
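
Reviewer note (illustrative, not part of the diff): the sketch below is a minimal O(n^2) reference for the semantics the new kernel and its tests rely on — an inverse real DFT over a Hermitian half-spectrum, with the same normalization encoding that `fft_norm_mode`/`compute_fct` use above (0 = none, 1 = divide by sqrt(n), 2 = divide by n). The names `reference_irfft`, `half_spectrum`, and `n` are hypothetical and exist only for this example; the actual kernel delegates the transform to `pocketfft::c2r`.

```cpp
#include <cmath>
#include <complex>
#include <cstdint>
#include <vector>

// Naive inverse real DFT over a half-spectrum of n/2 + 1 complex bins.
// Hypothetical helper for illustration; mirrors _fft_c2r's normalization
// encoding (0 = none, 1 = 1/sqrt(n), 2 = 1/n) from fft_utils.h.
std::vector<double> reference_irfft(
    const std::vector<std::complex<double>>& half_spectrum,
    int64_t n,
    int64_t normalization) {
  // Rebuild the full Hermitian spectrum: X[n - k] == conj(X[k]).
  std::vector<std::complex<double>> full(static_cast<size_t>(n));
  for (int64_t k = 0; k <= n / 2; ++k) {
    full[k] = half_spectrum[k];
    if (k != 0 && k != n - k) {
      full[n - k] = std::conj(half_spectrum[k]);
    }
  }
  // Normalization factor, matching compute_fct / fft_norm_mode.
  double fct = 1.0;
  if (normalization == 1) {
    fct = 1.0 / std::sqrt(static_cast<double>(n));
  } else if (normalization == 2) {
    fct = 1.0 / static_cast<double>(n);
  }
  // Inverse DFT; the imaginary parts cancel because the spectrum is
  // Hermitian, so only the real part is kept (the "c2r" output).
  const double pi = std::acos(-1.0);
  std::vector<double> out(static_cast<size_t>(n));
  for (int64_t t = 0; t < n; ++t) {
    std::complex<double> acc{0.0, 0.0};
    for (int64_t k = 0; k < n; ++k) {
      const double angle =
          2.0 * pi * static_cast<double>(k * t) / static_cast<double>(n);
      acc += full[k] * std::complex<double>(std::cos(angle), std::sin(angle));
    }
    out[t] = fct * acc.real();
  }
  return out;
}
```

Under these assumptions, with a transform length of n = 4 as in `test_dtype`, normalization 1 scales the unnormalized result by 1/2 and normalization 2 by 1/4, which is exactly the `norm_factor` adjustment the test applies to its expected values.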