Skip to content

Commit e723876

Browse files
JacobSzwejbka authored and
facebook-github-bot committed
Runtime API to retrieve attributes (#10144)
Summary: Scan implementation today, which is kind of slow. I don't think people will regularly be scanning for tons of tensors though, so perf is not super sensitive. Can always add a cache later if necessary. Reviewed By: larryliu0820 Differential Revision: D72802654
1 parent 4b1b4ba commit e723876

11 files changed

+225
-17
lines changed

runtime/executor/method.cpp

+32
Original file line numberDiff line numberDiff line change
@@ -1593,6 +1593,38 @@ EValue& Method::mutable_input(size_t i) {
15931593
return mutable_value(get_input_index(i));
15941594
}
15951595

1596+
Result<executorch::aten::Tensor> Method::get_attribute(
1597+
executorch::aten::string_view name) {
1598+
auto flatbuffer_values = serialization_plan_->values();
1599+
size_t counter = 0;
1600+
1601+
for (size_t i = 0; i < flatbuffer_values->size(); ++i) {
1602+
auto serialization_value = flatbuffer_values->Get(i);
1603+
if (serialization_value->val_type() ==
1604+
executorch_flatbuffer::KernelTypes::Tensor) {
1605+
const auto s_tensor = static_cast<const executorch_flatbuffer::Tensor*>(
1606+
serialization_value->val());
1607+
if (s_tensor->extra_tensor_info() != nullptr &&
1608+
s_tensor->extra_tensor_info()->fully_qualified_name() != nullptr &&
1609+
executorch::aten::string_view{
1610+
s_tensor->extra_tensor_info()->fully_qualified_name()->c_str(),
1611+
s_tensor->extra_tensor_info()->fully_qualified_name()->size()} ==
1612+
name) {
1613+
if (!this->values_[counter].isTensor()) {
1614+
ET_LOG(
1615+
Error,
1616+
"Attribute tensor not at the expected location. The .pte is likely malformed. Please file a bug report on https://github.com/pytorch/executorch/issues");
1617+
return Error::Internal;
1618+
}
1619+
return this->values_[counter].toTensor();
1620+
}
1621+
}
1622+
++counter;
1623+
}
1624+
1625+
return Error::NotFound;
1626+
}
1627+
15961628
size_t Method::outputs_size() const {
15971629
const auto* outputs = serialization_plan_->outputs();
15981630
return outputs == nullptr ? 0 : outputs->size();

runtime/executor/method.h

+12
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,18 @@ class Method final {
192192
*/
193193
ET_NODISCARD Error get_inputs(EValue* input_evalues, size_t length);
194194

195+
/**
196+
*
197+
* Retrieves the attribute tensor associated with the given name.
198+
*
199+
* @param[in] name The name of the attribute tensor to retrieve.
200+
*
201+
* @returns Result containing the attribute tensor on success, non-Ok on
202+
* failure.
203+
*/
204+
ET_NODISCARD Result<executorch::aten::Tensor> get_attribute(
205+
executorch::aten::string_view name);
206+
195207
/**
196208
* Execute the method.
197209
*

runtime/executor/method_meta.cpp

+62-3
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,11 @@ TensorInfo::TensorInfo(
6969
Span<const int32_t> sizes,
7070
Span<const uint8_t> dim_order,
7171
executorch::aten::ScalarType scalar_type,
72-
const bool is_memory_planned)
72+
const bool is_memory_planned,
73+
std::optional<executorch::aten::string_view> name)
7374
: sizes_(sizes),
7475
dim_order_(dim_order),
76+
name_(name),
7577
scalar_type_(scalar_type),
7678
is_memory_planned_(is_memory_planned),
7779
nbytes_(calculate_nbytes(sizes_, scalar_type_)) {}
@@ -96,6 +98,10 @@ size_t TensorInfo::nbytes() const {
9698
return nbytes_;
9799
}
98100

101+
std::optional<executorch::aten::string_view> TensorInfo::name() const {
102+
return name_;
103+
}
104+
99105
MethodMeta::MethodMeta(const executorch_flatbuffer::ExecutionPlan* s_plan)
100106
: s_plan_(s_plan) {}
101107

@@ -150,7 +156,8 @@ Result<TensorInfo> MethodMeta::input_tensor_meta(size_t index) const {
150156
static_cast<executorch::aten::ScalarType>(tensor_value->scalar_type()),
151157
tensor_value->allocation_info() != nullptr ||
152158
tensor_value->data_buffer_idx() !=
153-
0); // Count constant returns as memory planned.
159+
0 /* is_memory_planned */); // Count constant returns as memory
160+
// planned.
154161
}
155162

156163
size_t MethodMeta::num_outputs() const {
@@ -201,7 +208,59 @@ Result<TensorInfo> MethodMeta::output_tensor_meta(size_t index) const {
201208
static_cast<executorch::aten::ScalarType>(tensor_value->scalar_type()),
202209
tensor_value->allocation_info() != nullptr ||
203210
tensor_value->data_buffer_idx() !=
204-
0); // Count constant returns as memory planned.
211+
0 /* is_memory_planned */); // Count constant returns as memory
212+
// planned.
213+
}
214+
215+
size_t MethodMeta::num_attributes() const {
216+
size_t counter = 0;
217+
auto values = s_plan_->values();
218+
for (size_t i = 0; i < values->size(); ++i) {
219+
auto value = values->Get(i);
220+
if (value->val_type() == executorch_flatbuffer::KernelTypes::Tensor) {
221+
auto tensor_value = value->val_as_Tensor();
222+
if (tensor_value->extra_tensor_info() != nullptr &&
223+
tensor_value->extra_tensor_info()->fully_qualified_name()->c_str() !=
224+
nullptr) {
225+
++counter;
226+
}
227+
}
228+
}
229+
return counter;
230+
}
231+
232+
Result<TensorInfo> MethodMeta::attribute_tensor_meta(size_t index) const {
233+
size_t counter = 0;
234+
auto values = s_plan_->values();
235+
for (size_t i = 0; i < values->size(); ++i) {
236+
auto value = values->Get(i);
237+
if (value->val_type() == executorch_flatbuffer::KernelTypes::Tensor) {
238+
auto tensor_value = value->val_as_Tensor();
239+
if (tensor_value->extra_tensor_info() != nullptr &&
240+
tensor_value->extra_tensor_info()->fully_qualified_name()->c_str() !=
241+
nullptr) {
242+
if (counter == index) {
243+
auto t_name =
244+
tensor_value->extra_tensor_info()->fully_qualified_name();
245+
// Count constant returns as memory planned
246+
return TensorInfo(
247+
Span<const int32_t>(
248+
tensor_value->sizes()->data(), tensor_value->sizes()->size()),
249+
Span<const uint8_t>(
250+
tensor_value->dim_order()->data(),
251+
tensor_value->dim_order()->size()),
252+
static_cast<executorch::aten::ScalarType>(
253+
tensor_value->scalar_type()),
254+
tensor_value->allocation_info() != nullptr ||
255+
tensor_value->data_buffer_idx() != 0 /* is_memory_planned */,
256+
executorch::aten::string_view{t_name->c_str(), t_name->size()});
257+
}
258+
++counter;
259+
}
260+
}
261+
}
262+
ET_LOG(Error, "No attribute tensor found at index %zu", index);
263+
return Error::InvalidArgument;
205264
}
206265

207266
size_t MethodMeta::num_memory_planned_buffers() const {

runtime/executor/method_meta.h

+25-1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ class TensorInfo final {
6262
*/
6363
size_t nbytes() const;
6464

65+
/**
66+
* Returns the fully qualified name of the Tensor if it has one.
67+
*/
68+
std::optional<executorch::aten::string_view> name() const;
69+
6570
private:
6671
// Let MethodMeta create TensorInfo.
6772
friend class MethodMeta;
@@ -70,7 +75,8 @@ class TensorInfo final {
7075
Span<const int32_t> sizes,
7176
Span<const uint8_t> dim_order,
7277
executorch::aten::ScalarType scalar_type,
73-
const bool is_memory_planned);
78+
const bool is_memory_planned,
79+
std::optional<executorch::aten::string_view> name = std::nullopt);
7480

7581
/**
7682
* The sizes of the tensor.
@@ -88,6 +94,9 @@ class TensorInfo final {
8894
*/
8995
Span<const uint8_t> dim_order_;
9096

97+
/// The fully qualified name of the Tensor.
98+
std::optional<executorch::aten::string_view> name_;
99+
91100
/// The scalar type of the tensor.
92101
executorch::aten::ScalarType scalar_type_;
93102

@@ -170,6 +179,21 @@ class MethodMeta final {
170179
*/
171180
Result<TensorInfo> output_tensor_meta(size_t index) const;
172181

182+
/**
183+
* Get the number of attribute tensors in this method.
184+
*
185+
* @returns The number of attribute tensors.
186+
*/
187+
size_t num_attributes() const;
188+
189+
/**
190+
* Get metadata about the specified attribute tensor.
191+
*
192+
* @param[in] index The index of the attribute tensor to look up.
193+
* @returns The metadata on success, or an error on failure.
194+
*/
195+
Result<TensorInfo> attribute_tensor_meta(size_t index) const;
196+
173197
/**
174198
* Get the number of memory-planned buffers this method requires.
175199
*

runtime/executor/test/CMakeLists.txt

+4-1
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,10 @@ add_custom_command(
2727
"${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd"
2828
"${CMAKE_CURRENT_BINARY_DIR}/ModuleMultipleEntry.pte"
2929
"${CMAKE_CURRENT_BINARY_DIR}/ModuleSimpleTrain.pte"
30+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleStateful.pte"
3031
COMMAND
3132
python3 -m test.models.export_program --modules
32-
"ModuleAdd,ModuleAddHalf,ModuleDynamicCatUnallocatedIO,ModuleIndex,ModuleLinear,ModuleMultipleEntry,ModuleSimpleTrain"
33+
"ModuleAdd,ModuleAddHalf,ModuleDynamicCatUnallocatedIO,ModuleIndex,ModuleLinear,ModuleMultipleEntry,ModuleSimpleTrain,ModuleStateful"
3334
--outdir "${CMAKE_CURRENT_BINARY_DIR}" 2> /dev/null
3435
COMMAND
3536
python3 -m test.models.export_program --modules "ModuleLinear"
@@ -51,6 +52,7 @@ add_custom_target(
5152
"${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd"
5253
"${CMAKE_CURRENT_BINARY_DIR}/ModuleMultipleEntry.pte"
5354
"${CMAKE_CURRENT_BINARY_DIR}/ModuleSimpleTrain.pte"
55+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleStateful.pte"
5456
)
5557

5658
set(test_env
@@ -64,6 +66,7 @@ set(test_env
6466
"ET_MODULE_LINEAR_DATA_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd"
6567
"ET_MODULE_MULTI_ENTRY_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleMultipleEntry.pte"
6668
"ET_MODULE_SIMPLE_TRAIN_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleSimpleTrain.pte"
69+
"ET_MODULE_STATEFUL_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleStateful.pte"
6770
)
6871

6972
et_cxx_test(

runtime/executor/test/method_meta_test.cpp

+36-11
Original file line numberDiff line numberDiff line change
@@ -26,26 +26,34 @@ using torch::executor::util::FileDataLoader;
2626

2727
class MethodMetaTest : public ::testing::Test {
2828
protected:
29-
void SetUp() override {
30-
// Create a loader for the serialized ModuleAdd program.
31-
const char* path = std::getenv("ET_MODULE_ADD_PATH");
29+
void load_program(const char* path, const char* module_name) {
30+
// Create a loader for the serialized program.
3231
Result<FileDataLoader> loader = FileDataLoader::from(path);
3332
ASSERT_EQ(loader.error(), Error::Ok);
34-
loader_ = std::make_unique<FileDataLoader>(std::move(loader.get()));
33+
loaders_.insert(
34+
{module_name,
35+
std::make_unique<FileDataLoader>(std::move(loader.get()))});
3536

3637
// Use it to load the program.
3738
Result<Program> program = Program::load(
38-
loader_.get(), Program::Verification::InternalConsistency);
39+
loaders_[module_name].get(),
40+
Program::Verification::InternalConsistency);
3941
ASSERT_EQ(program.error(), Error::Ok);
40-
program_ = std::make_unique<Program>(std::move(program.get()));
42+
programs_.insert(
43+
{module_name, std::make_unique<Program>(std::move(program.get()))});
44+
}
45+
46+
void SetUp() override {
47+
load_program(std::getenv("ET_MODULE_ADD_PATH"), "add");
48+
load_program(std::getenv("ET_MODULE_STATEFUL_PATH"), "stateful");
4149
}
4250

4351
private:
4452
// Must outlive program_, but tests shouldn't need to touch it.
45-
std::unique_ptr<FileDataLoader> loader_;
53+
std::unordered_map<std::string, std::unique_ptr<FileDataLoader>> loaders_;
4654

4755
protected:
48-
std::unique_ptr<Program> program_;
56+
std::unordered_map<std::string, std::unique_ptr<Program>> programs_;
4957
};
5058

5159
namespace {
@@ -67,7 +75,7 @@ void check_tensor(const TensorInfo& tensor_info) {
6775
} // namespace
6876

6977
TEST_F(MethodMetaTest, MethodMetaApi) {
70-
Result<MethodMeta> method_meta = program_->method_meta("forward");
78+
Result<MethodMeta> method_meta = programs_["add"]->method_meta("forward");
7179
ASSERT_EQ(method_meta.error(), Error::Ok);
7280

7381
// Appropriate amount of inputs
@@ -97,11 +105,12 @@ TEST_F(MethodMetaTest, MethodMetaApi) {
97105

98106
// Missing method fails
99107
EXPECT_EQ(
100-
program_->method_meta("not_a_method").error(), Error::InvalidArgument);
108+
programs_["add"]->method_meta("not_a_method").error(),
109+
Error::InvalidArgument);
101110
}
102111

103112
TEST_F(MethodMetaTest, TensorInfoApi) {
104-
Result<MethodMeta> method_meta = program_->method_meta("forward");
113+
Result<MethodMeta> method_meta = programs_["add"]->method_meta("forward");
105114
ASSERT_EQ(method_meta.error(), Error::Ok);
106115

107116
// Input 1
@@ -138,3 +147,19 @@ TEST_F(MethodMetaTest, TensorInfoApi) {
138147
EXPECT_EQ(
139148
method_meta->output_tensor_meta(-1).error(), Error::InvalidArgument);
140149
}
150+
151+
TEST_F(MethodMetaTest, MethodMetaAttribute) {
152+
Result<MethodMeta> method_meta =
153+
programs_["stateful"]->method_meta("forward");
154+
ASSERT_EQ(method_meta.error(), Error::Ok);
155+
156+
ASSERT_EQ(method_meta->num_attributes(), 1);
157+
auto state = method_meta->attribute_tensor_meta(0);
158+
ASSERT_TRUE(state.ok());
159+
160+
ASSERT_EQ(state->name(), "state");
161+
ASSERT_FALSE(state->is_memory_planned());
162+
163+
auto bad_access = method_meta->attribute_tensor_meta(1);
164+
ASSERT_EQ(bad_access.error(), Error::InvalidArgument);
165+
}

runtime/executor/test/method_test.cpp

+26
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ class MethodTest : public ::testing::Test {
7979
load_program(
8080
std::getenv("ET_MODULE_DYNAMIC_CAT_UNALLOCATED_IO_PATH"), "cat");
8181
load_program(std::getenv("ET_MODULE_LINEAR_PATH"), "linear");
82+
load_program(std::getenv("ET_MODULE_STATEFUL_PATH"), "stateful");
8283
load_program(
8384
std::getenv("DEPRECATED_ET_MODULE_LINEAR_CONSTANT_BUFFER_PATH"),
8485
"linear_constant_buffer");
@@ -339,6 +340,31 @@ TEST_F(MethodTest, ProgramDataSeparationTest) {
339340
ASSERT_EQ(err, Error::Ok);
340341
}
341342

343+
TEST_F(MethodTest, MethodGetAttributeTest) {
344+
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
345+
Result<Method> method =
346+
programs_["stateful"]->load_method("forward", &mmm.get());
347+
ASSERT_EQ(method.error(), Error::Ok);
348+
349+
auto res = method->get_attribute("state");
350+
ASSERT_TRUE(res.ok());
351+
// expect data to be empty
352+
EXPECT_EQ(res->const_data_ptr(), nullptr);
353+
354+
int32_t data = 0;
355+
res->set_data(&data);
356+
357+
// expect data to be set
358+
EXPECT_EQ(res->const_data_ptr(), &data);
359+
360+
// Can execute the method.
361+
Error err = method->execute();
362+
ASSERT_EQ(err, Error::Ok);
363+
364+
// Expect the state to be incremented
365+
EXPECT_EQ(res->const_data_ptr<int32_t>()[0], 1);
366+
}
367+
342368
/*
343369
* TODO(T161163608): Test is disabled due to a resize bug in tensor_index_out of
344370
* the portable op lib

runtime/executor/test/targets.bzl

+1
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def define_common_targets(is_fbcode = False):
122122
"ET_MODULE_LINEAR_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleLinear.pte])",
123123
"ET_MODULE_MULTI_ENTRY_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleMultipleEntry.pte])",
124124
"ET_MODULE_SIMPLE_TRAIN_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleSimpleTrain.pte])",
125+
"ET_MODULE_STATEFUL_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleStateful.pte])",
125126
"ET_MODULE_LINEAR_PROGRAM_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.pte])",
126127
"ET_MODULE_LINEAR_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.ptd])",
127128
}

test/end2end/exported_module.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def export(
7070
skip_type_promotion: bool = False,
7171
export_joint_graph: bool = False,
7272
external_constants: bool = False,
73+
export_state_names: bool = False,
7374
) -> "ExportedModule":
7475
"""
7576
Creates a new ExportedModule for the specified module class.
@@ -148,7 +149,9 @@ def return_wrapper():
148149
for method in methods:
149150
method_name_to_dynamic_shapes[method] = trace_dynamic_shapes
150151

151-
memory_planning_pass = MemoryPlanningPass()
152+
memory_planning_pass = MemoryPlanningPass(
153+
alloc_mutable_buffers=not export_state_names
154+
)
152155
if hasattr(eager_module, "get_memory_planning_pass"):
153156
memory_planning_pass = eager_module.get_memory_planning_pass() # type: ignore[operator]
154157

@@ -208,6 +211,7 @@ def __init__(self, method):
208211
memory_planning_pass=memory_planning_pass,
209212
to_out_var_pass=ToOutVarPass(ignore_to_out_var_failure),
210213
external_constants=external_constants,
214+
emit_mutable_buffer_names=export_state_names,
211215
)
212216
)
213217

0 commit comments

Comments
 (0)