diff --git a/runtime/backend/backend_options.h b/runtime/backend/backend_options.h new file mode 100644 index 00000000000..6c106cc1561 --- /dev/null +++ b/runtime/backend/backend_options.h @@ -0,0 +1,93 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include +#include +#include +#include + +namespace executorch { +namespace runtime { + +// Strongly-typed option key template +template +struct OptionKey { + const char* key; + constexpr explicit OptionKey(const char* k) : key(k) {} +}; + +// Union replaced with std::variant +using OptionValue = std::variant; + +struct BackendOption { + const char* key; // key is the name of the backend option, like num_threads, + // enable_profiling, etc + OptionValue + value; // value is the value of the backend option, like 4, true, etc +}; + +template +class BackendOptions { + public: + // Initialize with zero options + BackendOptions() : size_(0) {} + + // Type-safe setters + template + void set_option(OptionKey key, T value) { + const char* k = key.key; + // Update existing if found + for (size_t i = 0; i < size_; ++i) { + if (strcmp(options_[i].key, k) == 0) { + options_[i].value = value; + return; + } + } + // Add new option if space available + if (size_ < MaxCapacity) { + options_[size_++] = BackendOption{k, value}; + } + } + + // Type-safe getters + template + Error get_option(OptionKey key, T& out) const { + const char* k = key.key; + for (size_t i = 0; i < size_; ++i) { + if (strcmp(options_[i].key, k) == 0) { + if (auto* val = std::get_if(&options_[i].value)) { + out = *val; + return Error::Ok; + } + return Error::InvalidArgument; + } + } + return Error::NotFound; + } + + private: + BackendOption options_[MaxCapacity]{}; // Storage for backend options + size_t size_; // Current number of options +}; + +// Helper functions for creating typed option keys (unchanged) +constexpr OptionKey BoolKey(const char* k) { + return OptionKey(k); +} + +constexpr OptionKey IntKey(const char* k) { + return OptionKey(k); +} + +constexpr OptionKey StrKey(const char* k) { + return OptionKey(k); +} + +} // namespace runtime +} // namespace executorch diff --git a/runtime/backend/targets.bzl b/runtime/backend/targets.bzl index d2187afb5fc..49a14d4d0d6 100644 --- a/runtime/backend/targets.bzl +++ b/runtime/backend/targets.bzl @@ -17,6 +17,7 @@ def define_common_targets(): exported_headers = [ "backend_execution_context.h", "backend_init_context.h", + "backend_options.h", "interface.h", ], preprocessor_flags = ["-DUSE_ATEN_LIB"] if aten_mode else [], diff --git a/runtime/backend/test/backend_options_test.cpp b/runtime/backend/test/backend_options_test.cpp new file mode 100644 index 00000000000..32d5c6008f5 --- /dev/null +++ b/runtime/backend/test/backend_options_test.cpp @@ -0,0 +1,130 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +using namespace ::testing; +using executorch::runtime::BackendOptions; +using executorch::runtime::BoolKey; +using executorch::runtime::Error; +using executorch::runtime::IntKey; +using executorch::runtime::OptionKey; +using executorch::runtime::StrKey; + +class BackendOptionsTest : public ::testing::Test { + protected: + void SetUp() override { + // Since these tests cause ET_LOG to be called, the PAL must be initialized + // first. + executorch::runtime::runtime_init(); + } + BackendOptions<5> options; // Capacity of 5 for testing limits +}; + +// Test basic string functionality +TEST_F(BackendOptionsTest, HandlesStringOptions) { + // Set and retrieve valid string + options.set_option(StrKey("backend_type"), "GPU"); + const char* result = nullptr; + EXPECT_EQ(options.get_option(StrKey("backend_type"), result), Error::Ok); + EXPECT_STREQ(result, "GPU"); + + // Update existing key + options.set_option(StrKey("backend_type"), "CPU"); + EXPECT_EQ(options.get_option(StrKey("backend_type"), result), Error::Ok); + EXPECT_STREQ(result, "CPU"); +} + +// Test boolean options +TEST_F(BackendOptionsTest, HandlesBoolOptions) { + options.set_option(BoolKey("debug"), true); + bool debug = false; + EXPECT_EQ(options.get_option(BoolKey("debug"), debug), Error::Ok); + EXPECT_TRUE(debug); + + // Test false value + options.set_option(BoolKey("verbose"), false); + EXPECT_EQ(options.get_option(BoolKey("verbose"), debug), Error::Ok); + EXPECT_FALSE(debug); +} + +// Test integer options +TEST_F(BackendOptionsTest, HandlesIntOptions) { + options.set_option(IntKey("num_threads"), 256); + int num_threads = 0; + EXPECT_EQ(options.get_option(IntKey("num_threads"), num_threads), Error::Ok); + EXPECT_EQ(num_threads, 256); +} + +// Test error conditions +TEST_F(BackendOptionsTest, HandlesErrors) { + // Non-existent key + bool dummy_bool; + EXPECT_EQ( + options.get_option(BoolKey("missing"), dummy_bool), Error::NotFound); + + // Type mismatch + options.set_option(IntKey("threshold"), 100); + const char* dummy_str = nullptr; + EXPECT_EQ( + options.get_option(StrKey("threshold"), dummy_str), + Error::InvalidArgument); + + // Null value handling + options.set_option(StrKey("nullable"), static_cast(nullptr)); + EXPECT_EQ(options.get_option(StrKey("nullable"), dummy_str), Error::Ok); + EXPECT_EQ(dummy_str, nullptr); +} + +// Test capacity limits +TEST_F(BackendOptionsTest, HandlesCapacity) { + // Use persistent storage for keys + std::vector keys = {"key0", "key1", "key2", "key3", "key4"}; + + // Fill to capacity with persistent keys + for (int i = 0; i < 5; i++) { + options.set_option(IntKey(keys[i].c_str()), i); + } + + // Verify all exist + int value; + for (int i = 0; i < 5; i++) { + EXPECT_EQ(options.get_option(IntKey(keys[i].c_str()), value), Error::Ok); + EXPECT_EQ(value, i); + } + + // Add beyond capacity - should fail + const char* overflow_key = "overflow"; + options.set_option(IntKey(overflow_key), 99); + EXPECT_EQ(options.get_option(IntKey(overflow_key), value), Error::NotFound); + + // Update existing within capacity + options.set_option(IntKey(keys[2].c_str()), 222); + EXPECT_EQ(options.get_option(IntKey(keys[2].c_str()), value), Error::Ok); + EXPECT_EQ(value, 222); +} + +// Test type-specific keys +TEST_F(BackendOptionsTest, EnforcesKeyTypes) { + // Same key name - later set operations overwrite earlier ones + options.set_option(BoolKey("flag"), true); + options.set_option(IntKey("flag"), 123); // Overwrites the boolean entry + + bool bval; + int ival; + + // Boolean get should fail - type was overwritten to INT + EXPECT_EQ(options.get_option(BoolKey("flag"), bval), Error::InvalidArgument); + + // Integer get should succeed with correct value + EXPECT_EQ(options.get_option(IntKey("flag"), ival), Error::Ok); + EXPECT_EQ(ival, 123); +} diff --git a/runtime/backend/test/targets.bzl b/runtime/backend/test/targets.bzl index 9ea585f650c..97299bbcb35 100644 --- a/runtime/backend/test/targets.bzl +++ b/runtime/backend/test/targets.bzl @@ -1,7 +1,16 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + def define_common_targets(): """Defines targets that should be shared between fbcode and xplat. The directory containing this targets.bzl file should also contain both TARGETS and BUCK files that call this function. """ - pass + runtime.cxx_test( + name = "backend_options_test", + srcs = ["backend_options_test.cpp"], + deps = [ + "//executorch/runtime/core:core", + "//executorch/runtime/backend:interface", + ], + )