Commit e8b327d

Fix constant folding of node assigned to CUDA (microsoft#2510)
* Constant folding bug fix/improvements
  - Handle constant folding for a node that is assigned to a non-CPU EP
  - Check for errors in optimizer execution frame setup
  - Improve CUDA partitioning to look for initializers in parent graphs
  - Add a unit test

Fixes microsoft#2474
1 parent 4354023 commit e8b327d

File tree: 5 files changed, +79 −36 lines

onnxruntime/core/optimizer/constant_folding.cc

Lines changed: 23 additions & 2 deletions

@@ -25,6 +25,17 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,
 
     InitializedTensorSet constant_inputs;
 
+    // we currently constant fold using the CPU EP only.
+    // if the node is assigned to a different EP we can run it if it's an ONNX op as we have CPU based implementations
+    // for all ONNX ops. if it's from a different domain we can't.
+    // NOTE: This is in addition to the IsSupportedProvider check below which will optionally do further filtering
+    // on the EPs we constant fold for.
+    auto ep_type = node->GetExecutionProviderType();
+    bool cpu_ep = ep_type == kCpuExecutionProvider;
+    if (!cpu_ep && node->Domain() != kOnnxDomain) {
+      continue;
+    }
+
     // Check if constant folding can be applied on this node.
     if (!graph_utils::IsSupportedProvider(*node, GetCompatibleExecutionProviders()) ||
         excluded_op_types_.find(node->OpType()) != excluded_op_types_.end() ||
@@ -36,9 +47,19 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,
       continue;
     }
 
+    // override the EP while setting up OptimizerExecutionFrame::Info so that it will use the CPU kernel for Compute.
+    if (!cpu_ep) {
+      node->SetExecutionProviderType(kCpuExecutionProvider);
+    }
+
     // Create execution frame for executing constant nodes.
     OptimizerExecutionFrame::Info info({node}, constant_inputs);
 
+    // undo the EP change in case something fails prior to node removal
+    if (!cpu_ep) {
+      node->SetExecutionProviderType(ep_type);
+    }
+
     std::vector<int> fetch_mlvalue_idxs;
     for (const auto* node_out : node->OutputDefs()) {
       fetch_mlvalue_idxs.push_back(info.GetMLValueIndex(node_out->Name()));
@@ -62,8 +83,8 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,
       OrtValue& ort_value = fetches[fetch_idx];
 
       if (!ort_value.IsTensor()) {
-      LOGS(logger, WARNING) << "Unsupported output type of " << ort_value.Type()
-                            << ". Can't constant fold " << node->OpType() << " node '" << node->Name() << "'";
+        LOGS(logger, WARNING) << "Unsupported output type of " << ort_value.Type()
+                              << ". Can't constant fold " << node->OpType() << " node '" << node->Name() << "'";
         unsupported_output_type = true;
         break;
       }
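The essence of the constant_folding.cc change is a save/override/restore pattern around the node's EP assignment: claim the node for the CPU EP just long enough to set up a CPU kernel, then put the original assignment back in case folding bails out. A minimal standalone sketch of that pattern, using a stand-in Node type rather than the real onnxruntime::Node:

#include <string>

// Stand-in for onnxruntime::Node; only the EP accessors matter for this sketch.
struct Node {
  std::string ep;
  const std::string& GetExecutionProviderType() const { return ep; }
  void SetExecutionProviderType(const std::string& type) { ep = type; }
};

// Stand-in for the kCpuExecutionProvider constant.
const std::string kCpuProvider = "CPUExecutionProvider";

void FoldWithCpuKernel(Node& node) {
  const std::string original_ep = node.GetExecutionProviderType();
  const bool cpu_ep = (original_ep == kCpuProvider);

  // Temporarily assign the node to the CPU EP so kernel lookup
  // (OptimizerExecutionFrame::Info in the real code) resolves a CPU kernel.
  if (!cpu_ep) node.SetExecutionProviderType(kCpuProvider);

  // ... set up the execution frame and compute the node's outputs here ...

  // Restore the original assignment so the node is left untouched if
  // anything fails before the folded node is removed from the graph.
  if (!cpu_ep) node.SetExecutionProviderType(original_ep);
}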

onnxruntime/core/optimizer/constant_folding.h

Lines changed: 2 additions & 7 deletions

@@ -16,8 +16,8 @@ it statically computes parts of the graph that rely only on constant initializer
 */
 class ConstantFolding : public GraphTransformer {
  public:
-  ConstantFolding(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept :
-      GraphTransformer("ConstantFolding", compatible_execution_providers) {}
+  ConstantFolding(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
+      : GraphTransformer("ConstantFolding", compatible_execution_providers) {}
 
  private:
  /** Constant folding will not be applied to nodes whose op_type is included in this set.
@@ -26,11 +26,6 @@ class ConstantFolding : public GraphTransformer {
       {"RandomUniform", "RandomNormal", "RandomUniformLike", "RandomNormalLike", "Multinomial"};
 
   Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
-
-  /** Create a TensorProto that has the same value as the given OrtValue
-      and the same type and dimensions as the given NodeArg. */
-  void BuildTensorProtoForInitializer(const OrtValue& ort_value, const NodeArg& constant_node_arg,
-                                      ONNX_NAMESPACE::TensorProto& tensorproto) const;
 };
 
 }  // namespace onnxruntime

onnxruntime/core/optimizer/optimizer_execution_frame.cc

Lines changed: 7 additions & 6 deletions

@@ -57,8 +57,8 @@ OptimizerExecutionFrame::Info::Info(const std::vector<const Node*>& nodes,
 
   // TODO: node->ImplicitInputDefs() need to be added here for control flow nodes.
   for (auto* node : nodes) {
-    onnxruntime::Node::ForEachWithIndex(node->InputDefs(), initialize_maps);
-    onnxruntime::Node::ForEachWithIndex(node->OutputDefs(), initialize_maps);
+    ORT_THROW_IF_ERROR(onnxruntime::Node::ForEachWithIndex(node->InputDefs(), initialize_maps));
+    ORT_THROW_IF_ERROR(onnxruntime::Node::ForEachWithIndex(node->OutputDefs(), initialize_maps));
   }
 
   node_index_info_ = onnxruntime::make_unique<NodeIndexInfo>(nodes, ort_value_name_idx_map_);
@@ -67,8 +67,9 @@ OptimizerExecutionFrame::Info::Info(const std::vector<const Node*>& nodes,
   for (auto* node : nodes) {
     std::unique_ptr<OpKernel> op_kernel;
     std::shared_ptr<KernelRegistry> kernel_registry = cpu_execution_provider_->GetKernelRegistry();
-    auto status = kernel_registry->TryCreateKernel(*node, *cpu_execution_provider_, initializers_,
-                                                   ort_value_name_idx_map_, FuncManager(), data_transfer_mgr_, op_kernel);
+    ORT_THROW_IF_ERROR(kernel_registry->TryCreateKernel(*node, *cpu_execution_provider_, initializers_,
+                                                        ort_value_name_idx_map_, FuncManager(), data_transfer_mgr_,
+                                                        op_kernel));
     kernels_[node->Index()] = std::move(op_kernel);
   }
 }
@@ -118,8 +119,8 @@ Status OptimizerExecutionFrame::CreateNodeOutputMLValueImpl(OrtValue& ort_value,
   auto element_type = static_cast<const TensorTypeBase*>(ml_type)->GetElementType();
   AllocatorPtr allocator_ptr = info_.GetAllocator();
   std::unique_ptr<Tensor> p_tensor = onnxruntime::make_unique<Tensor>(element_type,
-                                                                     *shape,
-                                                                     allocator_ptr);
+                                                                      *shape,
+                                                                      allocator_ptr);
 
   auto ml_tensor = DataTypeImpl::GetType<Tensor>();
   ort_value.Init(p_tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc());
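Before this change the Status returned by ForEachWithIndex and TryCreateKernel was silently dropped, so a failed frame setup could go unnoticed. Wrapping the calls in ORT_THROW_IF_ERROR makes such failures surface immediately. A rough sketch of what a throw-if-error macro amounts to, with a simplified Status type standing in for onnxruntime's:

#include <stdexcept>
#include <string>

// Simplified stand-in for onnxruntime::common::Status.
struct Status {
  bool ok;
  std::string message;
  bool IsOK() const { return ok; }
  const std::string& ErrorMessage() const { return message; }
};

// Conceptual equivalent of ORT_THROW_IF_ERROR: evaluate the expression once
// and convert a failed Status into an exception.
#define THROW_IF_ERROR(expr)                                               \
  do {                                                                     \
    const Status _status = (expr);                                         \
    if (!_status.IsOK()) throw std::runtime_error(_status.ErrorMessage()); \
  } while (0)

// Hypothetical setup step that fails, e.g. no CPU kernel is registered.
Status TryCreateKernelStub() { return Status{false, "no CPU kernel registered"}; }

int main() {
  try {
    THROW_IF_ERROR(TryCreateKernelStub());  // failure now surfaces here
  } catch (const std::exception& ex) {
    // report or handle the setup error
  }
  return 0;
}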

onnxruntime/core/providers/cuda/cuda_execution_provider.cc

Lines changed: 20 additions & 20 deletions

@@ -3,11 +3,12 @@
 
 #include "cuda_common.h"
 #include "cuda_execution_provider.h"
-#include "core/framework/memcpy.h"
 #include "cuda_fence.h"
 #include "cuda_allocator.h"
 #include "core/framework/kernel_registry.h"
 #include "core/framework/compute_capability.h"
+#include "core/framework/memcpy.h"
+#include "core/graph/graph_utils.h"
 #include "core/providers/cuda/gpu_data_transfer.h"
 
 #ifndef DISABLE_CONTRIB_OPS
@@ -1303,28 +1304,27 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
       // Note that nodes with only inputs from initializer would not be place on CUDA
       // Ideally, those nodes should be eliminated in constant folding
       bool should_force_outside = true;
-      bool all_input_are_initializer = true;
-      node.ForEachWithIndex(
-          node.InputDefs(),
-          [&](const NodeArg& def, size_t index) {
-            const ONNX_NAMESPACE::TensorProto* initializer = nullptr;
-            // The input is not a initializer and the input is from CPU
-            // or the input declared as CPU memory and is from CPU
-            // in that case we should still keep the node on CUDA
-            bool initializer_input = graph.GetInitializedTensor(def.Name(), initializer);
-            bool input_is_on_cpu = defs_outside_cuda.count(&def) > 0;
-            if ((!initializer_input && !input_is_on_cpu) ||
-                (input_is_on_cpu && cuda_kernel_def->kernel_def->IsInputOnCpu(index)))
-              should_force_outside = false;
+      bool all_inputs_are_initializers = true;
+      node.ForEachWithIndex(node.InputDefs(),
+                            [&](const NodeArg& def, size_t index) {
+                              // The input is not a initializer and the input is from CPU
+                              // or the input declared as CPU memory and is from CPU
+                              // in that case we should still keep the node on CUDA
+                              bool initializer_input = graph.IsConstantInitializer(def.Name(), /*check_outer_scope*/ true);
+                              bool input_is_on_cpu = defs_outside_cuda.count(&def) > 0;
+                              if ((!initializer_input && !input_is_on_cpu) ||
+                                  (input_is_on_cpu && cuda_kernel_def->kernel_def->IsInputOnCpu(index))) {
+                                should_force_outside = false;
+                              }
 
-            if (!initializer_input) {
-              all_input_are_initializer = false;
-            }
-            return Status::OK();
-          });
+                              if (!initializer_input) {
+                                all_inputs_are_initializers = false;
+                              }
+                              return Status::OK();
+                            });
 
       // If all the inputs are initializers, we shouldn't force it to CPU
-      if (should_force_outside && !all_input_are_initializer) {
+      if (should_force_outside && !all_inputs_are_initializers) {
         force_outside = true;
       }
     }
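The behavioral piece of this diff is the lookup change: GetInitializedTensor only sees initializers in the current graph, while IsConstantInitializer with check_outer_scope=true also walks enclosing graphs, which matters for nodes inside subgraphs whose constant inputs live in a parent graph (my reading of the "parent graphs" note in the commit message). A toy illustration of the outer-scope walk, using a hypothetical, much-reduced Graph type:

#include <string>
#include <unordered_set>

// Hypothetical Graph: just a local initializer set plus a link to the
// enclosing graph; the real onnxruntime::Graph is far richer.
struct Graph {
  std::unordered_set<std::string> initializers;
  const Graph* parent = nullptr;

  // Local-only lookup, analogous to the old GetInitializedTensor call.
  bool HasLocalInitializer(const std::string& name) const {
    return initializers.count(name) > 0;
  }

  // Outer-scope-aware lookup, analogous to IsConstantInitializer with
  // check_outer_scope=true: on a local miss, try the enclosing graphs.
  bool IsConstantInitializer(const std::string& name, bool check_outer_scope) const {
    if (HasLocalInitializer(name)) return true;
    return check_outer_scope && parent != nullptr &&
           parent->IsConstantInitializer(name, true);
  }
};

With the old local-only lookup, a subgraph node fed entirely by parent-graph initializers would look as if it had non-initializer inputs, skewing the force-outside decision above.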

onnxruntime/test/optimizer/graph_transform_test.cc

Lines changed: 27 additions & 1 deletion

@@ -131,6 +131,33 @@ TEST(GraphTransformationTests, ConstantFolding) {
   ASSERT_TRUE(op_to_count["Unsqueeze"] == 0);
 }
 
+TEST(GraphTransformationTests, ConstantFoldingNodesOnDifferentEP) {
+  auto model_uri = MODEL_FOLDER "fusion/fuse-conv-bn-mul-add-unsqueeze.onnx";
+  std::shared_ptr<Model> model;
+  ASSERT_TRUE(Model::Load(model_uri, model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
+  Graph& graph = model->MainGraph();
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_TRUE(op_to_count["Unsqueeze"] == 2);
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  graph_transformation_mgr.Register(onnxruntime::make_unique<ConstantFolding>(), TransformerLevel::Level1);
+
+  // assign all nodes to CUDA. the constant folding should override this to perform the constant folding on cpu
+  for (auto& node : graph.Nodes()) {
+    node.SetExecutionProviderType(kCudaExecutionProvider);
+  }
+
+  ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK());
+
+  op_to_count = CountOpsInGraph(graph);
+  ASSERT_TRUE(op_to_count["Unsqueeze"] == 0);
+
+  // all remaining nodes should still be on CUDA
+  for (auto& node : graph.Nodes()) {
+    EXPECT_STREQ(node.GetExecutionProviderType().c_str(), kCudaExecutionProvider);
+  }
+}
+
 TEST(GraphTransformationTests, ConstantFoldingSubgraph) {
   TensorProto value_tensor;
   value_tensor.add_dims(1);
@@ -1010,7 +1037,6 @@ static void ValidateAttention(Graph& graph) {
     for (size_t i = 0; i < expected_value2.size(); i++) {
       EXPECT_EQ(data2[i], static_cast<float>(expected_value2[i]));
     }
-
   }
 }
}
