Skip to content

Commit 27cdb5c

Browse files
[EP ABI] Add ability to drop constant initializers for fused nodes specified in GetCapability (#25137)
### Description - Add ability to drop constant initializers for fused nodes specified in GetCapability. - Rework how an EP specifies nodes that should be fused into one node within GetCapability. - Instead of passing the set of nodes as arguments to `GraphSupportInfo_AddNodesToFuse()`, the EP creates an `OrtNodeFusionOptions` object to specify the nodes and other relevant options. This makes it easier to extend the API in the future since we can't add more parameters to an existing function, but we can add more functions that modify an options object. ### Motivation and Context Add more functionality missing from GetCapability() in the EP ABI.
1 parent 18282b1 commit 27cdb5c

File tree

9 files changed

+249
-77
lines changed

9 files changed

+249
-77
lines changed

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -796,9 +796,6 @@ typedef struct OrtCompileApi OrtCompileApi;
796796
struct OrtEpApi;
797797
typedef struct OrtEpApi OrtEpApi;
798798

799-
struct OrtNodeComputeInfo;
800-
typedef struct OrtNodeComputeInfo OrtNodeComputeInfo;
801-
802799
/** \brief The helper interface to get the right version of OrtApi
803800
*
804801
* Get a pointer to this structure through ::OrtGetApiBase

include/onnxruntime/core/session/onnxruntime_ep_c_api.h

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,44 @@ ORT_RUNTIME_CLASS(EpFactory);
1212
ORT_RUNTIME_CLASS(EpGraphSupportInfo);
1313
ORT_RUNTIME_CLASS(NodeComputeContext);
1414

15+
struct OrtNodeFusionOptions;
16+
typedef struct OrtNodeFusionOptions OrtNodeFusionOptions;
17+
18+
struct OrtNodeComputeInfo;
19+
typedef struct OrtNodeComputeInfo OrtNodeComputeInfo;
20+
21+
/**
22+
* \brief The OrtNodeFusionOptions struct specifies options for fusing nodes supported by an execution provider.
23+
*
24+
* Refer to OrtEpApi::EpGraphSupportInfo_AddNodesToFuse.
25+
*
26+
* \since Version 1.23.
27+
*/
28+
struct OrtNodeFusionOptions {
29+
/** \brief The ONNX Runtime version the OrtNodeFusionOptions was compiled with.
30+
*
31+
* Implementation should set to ORT_API_VERSION.
32+
* ORT will use this to ensure it does not use members that were not available when the EP library was compiled.
33+
*
34+
* \since Version 1.23.
35+
*/
36+
uint32_t ort_version_supported;
37+
38+
/** \brief If set to true, specify that the execution provider does not require ONNX Runtime to provide constant
39+
* initializers as inputs to the fused node during model inference. This is used when the execution
40+
* provider saves a copy of constant initializers, and allows ONNX Runtime to release constant initializers that
41+
* are not used by any execution provider.
42+
*
43+
* If not specified, defaults to false. That is, ONNX Runtime provides constant initializers as inputs to
44+
* the fused node by default.
45+
*
46+
* \since Version 1.23.
47+
*/
48+
bool drop_constant_initializers;
49+
50+
// const OrtNode* fused_node_schema;
51+
};
52+
1553
/**
1654
* \brief The OrtNodeComputeInfo struct provides functions that an OrtEp implements to specify the compute
1755
* function for a compiled OrtGraph instance.
@@ -21,7 +59,7 @@ struct OrtNodeComputeInfo {
2159
/** \brief The ONNX Runtime version the OrtNodeComputeInfo was compiled with.
2260
*
2361
* Implementation should set to ORT_API_VERSION.
24-
* ORT will use this to ensure it does not call functions that were not available when the library was compiled.
62+
* ORT will use this to ensure it does not call functions that were not available when the EP library was compiled.
2563
*
2664
* \since Version 1.23.
2765
*/
@@ -87,9 +125,6 @@ struct OrtEpApi {
87125
ORT_CLASS_RELEASE(EpDevice);
88126

89127
/** \brief Specify nodes that are supported by an OrtEp and should be fused into one node.
90-
*
91-
* IMPORTANT: This is not the final version of this API function. This is currently experimental but will
92-
* be stabilized by the ONNX Runtime 1.23 release.
93128
*
94129
* Because the nodes will be fused into one "fused node", there must not exist an unsupported node in
95130
* a path between two of the provided nodes. Otherwise, the graph will become invalid.
@@ -100,14 +135,15 @@ struct OrtEpApi {
100135
* \param[in] graph_support_info OrtEpGraphSupportInfo instance to which to add the supported nodes.
101136
* \param[in] nodes Array of nodes supported by the EP that should be fused/compiled.
102137
* \param[in] num_nodes The number of supported nodes.
138+
* \param[in] node_fusion_options Optional node fusion options. Ignored if set to NULL.
103139
*
104140
* \snippet{doc} snippets.dox OrtStatus Return Value
105141
*
106142
* \since Version 1.23.
107143
*/
108144
ORT_API2_STATUS(EpGraphSupportInfo_AddNodesToFuse, _In_ OrtEpGraphSupportInfo* graph_support_info,
109-
_In_reads_(num_nodes) const OrtNode* const* nodes, _In_ size_t num_nodes
110-
/*, OrtFusedNodeSchema* optional_fused_node_schema, OrtNodesToOptimizeInfo* nodes_to_opt*/);
145+
_In_reads_(num_nodes) const OrtNode* const* nodes, _In_ size_t num_nodes,
146+
_In_opt_ const OrtNodeFusionOptions* node_fusion_options);
111147

112148
/** \brief Specify a node that is supported by an OrtEp and should be run with a registered EP kernel.
113149
*

onnxruntime/core/session/abi_ep_types.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
#include "core/graph/ep_api_types.h"
1111
#include "core/session/abi_devices.h"
1212

13-
onnxruntime::Status OrtEpGraphSupportInfo::AddNodesToFuse(gsl::span<const OrtNode* const> nodes) {
13+
onnxruntime::Status OrtEpGraphSupportInfo::AddNodesToFuse(gsl::span<const OrtNode* const> nodes,
14+
const OrtNodeFusionOptions* optional_fusion_options) {
1415
std::vector<const onnxruntime::EpNode*> ep_nodes;
1516
ep_nodes.reserve(nodes.size());
1617

@@ -20,14 +21,14 @@ onnxruntime::Status OrtEpGraphSupportInfo::AddNodesToFuse(gsl::span<const OrtNod
2021
ep_nodes.push_back(ep_node);
2122
}
2223

23-
node_groupings.emplace_back(NodeGroupingKind::kFusedNode, std::move(ep_nodes));
24+
node_groupings.emplace_back(NodeGroupingKind::kFusedNode, std::move(ep_nodes),
25+
optional_fusion_options != nullptr ? *optional_fusion_options : OrtNodeFusionOptions{});
2426
return onnxruntime::Status::OK();
2527
}
2628

2729
onnxruntime::Status OrtEpGraphSupportInfo::AddSingleNode(const OrtNode* node) {
2830
std::vector<const onnxruntime::EpNode*> ep_nodes;
2931
ep_nodes.push_back(onnxruntime::EpNode::ToInternal(node));
3032
node_groupings.emplace_back(NodeGroupingKind::kSingleAssignedNode, std::move(ep_nodes));
31-
3233
return onnxruntime::Status::OK();
3334
}

onnxruntime/core/session/abi_ep_types.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,19 @@ struct OrtEpGraphSupportInfo {
3030

3131
// A grouping of supported nodes that should be handled in a single ComputeCapability.
3232
struct NodeGrouping {
33-
NodeGrouping(NodeGroupingKind kind, std::vector<const onnxruntime::EpNode*>&& nodes)
34-
: kind(kind), nodes(std::move(nodes)) {}
33+
NodeGrouping(NodeGroupingKind kind, std::vector<const onnxruntime::EpNode*>&& nodes,
34+
const OrtNodeFusionOptions& fusion_options = {})
35+
: kind(kind), nodes(std::move(nodes)), fusion_options(fusion_options) {}
3536

3637
NodeGroupingKind kind = NodeGroupingKind::kInvalidGrouping;
3738
std::vector<const onnxruntime::EpNode*> nodes;
39+
OrtNodeFusionOptions fusion_options = {};
3840
};
3941

4042
explicit OrtEpGraphSupportInfo(const onnxruntime::EpGraph& graph) : ort_graph(graph) {}
4143

42-
onnxruntime::Status AddNodesToFuse(gsl::span<const OrtNode* const> nodes);
44+
onnxruntime::Status AddNodesToFuse(gsl::span<const OrtNode* const> nodes,
45+
const OrtNodeFusionOptions* node_fusion_options = nullptr);
4346
onnxruntime::Status AddSingleNode(const OrtNode* node);
4447

4548
const onnxruntime::EpGraph& ort_graph;

onnxruntime/core/session/ep_api.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ ORT_API(void, ReleaseEpDevice, _Frees_ptr_opt_ OrtEpDevice* device) {
4444
}
4545

4646
ORT_API_STATUS_IMPL(EpGraphSupportInfo_AddNodesToFuse, _In_ OrtEpGraphSupportInfo* ort_graph_support_info,
47-
_In_reads_(num_nodes) const OrtNode* const* nodes, size_t num_nodes) {
47+
_In_reads_(num_nodes) const OrtNode* const* nodes, size_t num_nodes,
48+
_In_opt_ const OrtNodeFusionOptions* node_fusion_options) {
4849
API_IMPL_BEGIN
4950
if (ort_graph_support_info == nullptr) {
5051
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid OrtGraph instance");
@@ -55,7 +56,7 @@ ORT_API_STATUS_IMPL(EpGraphSupportInfo_AddNodesToFuse, _In_ OrtEpGraphSupportInf
5556
}
5657

5758
gsl::span<const OrtNode* const> nodes_span(nodes, nodes + num_nodes);
58-
ORT_API_RETURN_IF_STATUS_NOT_OK(ort_graph_support_info->AddNodesToFuse(nodes_span));
59+
ORT_API_RETURN_IF_STATUS_NOT_OK(ort_graph_support_info->AddNodesToFuse(nodes_span, node_fusion_options));
5960
return nullptr;
6061
API_IMPL_END
6162
}

onnxruntime/core/session/ep_api.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ ORT_API_STATUS_IMPL(CreateEpDevice, _In_ OrtEpFactory* ep_factory,
1818
ORT_API(void, ReleaseEpDevice, _Frees_ptr_opt_ OrtEpDevice* device);
1919

2020
ORT_API_STATUS_IMPL(EpGraphSupportInfo_AddNodesToFuse, _In_ OrtEpGraphSupportInfo* graph_support_info,
21-
_In_reads_(num_nodes) const OrtNode* const* nodes, size_t num_nodes);
21+
_In_reads_(num_nodes) const OrtNode* const* nodes, _In_ size_t num_nodes,
22+
_In_opt_ const OrtNodeFusionOptions* node_fusion_options);
2223
ORT_API_STATUS_IMPL(EpGraphSupportInfo_AddSingleNode, _In_ OrtEpGraphSupportInfo* graph_support_info,
2324
_In_ const OrtNode* node);
2425
ORT_API(const char*, NodeComputeContext_NodeName, _In_ const OrtNodeComputeContext* context);

onnxruntime/core/session/ep_plugin_provider_interfaces.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
138138
} else if (node_grouping.kind == OrtEpGraphSupportInfo::NodeGroupingKind::kFusedNode) {
139139
std::unordered_set<const Node*> node_set;
140140
node_set.reserve(node_grouping.nodes.size());
141+
141142
for (const EpNode* ep_node : node_grouping.nodes) {
142143
node_set.insert(&ep_node->GetInternalNode());
143144
}
@@ -151,7 +152,8 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
151152
// unsupported nodes in any path between supported nodes.
152153
std::vector<std::unique_ptr<ComputeCapability>> capabilities = utils::CreateSupportedPartitions(
153154
graph_viewer, node_set, /*stop_ops*/ {}, PluginEpMetaDefNameFunctor(generator, graph_viewer, this->Type()),
154-
this->Type(), this->Type(), /*node_unit_map*/ nullptr);
155+
this->Type(), this->Type(), /*node_unit_map*/ nullptr,
156+
node_grouping.fusion_options.drop_constant_initializers);
155157

156158
if (capabilities.size() > 1) {
157159
LOGS_DEFAULT(ERROR) << "OrtEp::GetCapability() set nodes that cannot be fused together. "

0 commit comments

Comments
 (0)