Skip to content

Commit b85805e

Browse files
authored
Handle edge case with implicit input and multiple levels of subgraphs (microsoft#4031)
* Handle edge case where an implicit input for a subgraph may not get wired in correctly. Conditions required:
  - two or more levels of nested subgraphs;
  - an implicit input from above the bottom two levels is used in both levels of subgraph — this creates a NodeArg for the implicit input at both levels;
  - something changes the first-level subgraph so that it no longer uses the implicit input. This could be constant folding, or partitioning of nodes resulting in a copy of the implicit input being made on a different device.

  When that occurs we lose the wiring through to the second level of nested subgraph, as there is a NodeArg in the first level but the implicit input is no longer used there. Fix that by doing a final check for outer-scope values once we know all the outputs produced by the current graph.

  Found by commenting out the CUDA implementations of the control-flow nodes and running ssd_mobilenet_300 from the MLPerf models.
* Add test case.
1 parent c331d8c commit b85805e

File tree

4 files changed

+118
-3
lines changed

4 files changed

+118
-3
lines changed

onnxruntime/core/graph/graph.cc

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1124,6 +1124,8 @@ Status Graph::BuildConnections(std::unordered_set<std::string>& outer_scope_node
11241124
const std::unordered_set<std::string>& outer_scope_node_args = resolve_context_.outer_scope_node_args;
11251125
std::unordered_set<Node*> inner_nodes;
11261126

1127+
std::unordered_set<std::string> node_args_consumed_by_subgraphs;
1128+
11271129
// recurse into subgraphs first so we can update any nodes in this graph that are used by those subgraphs
11281130
if (!resolve_context_.nodes_with_subgraphs.empty()) {
11291131
for (auto* node : resolve_context_.nodes_with_subgraphs) {
@@ -1157,6 +1159,12 @@ Status Graph::BuildConnections(std::unordered_set<std::string>& outer_scope_node
11571159
"This is an invalid model. Failed to find NodeArg in all parent graphs. Name=", node_arg_name,
11581160
" Graph may not conform to the ONNX spec and contain initializers that are not graph inputs.");
11591161
}
1162+
} else {
1163+
// this value may be produced by this graph, or it could still be coming from a parent graph if it
1164+
// is also directly consumed at this level as we create a NodeArg for all Node inputs in this graph.
1165+
// due to that we need to check the outputs from this level to determine if it is an outer scope value.
1166+
// we don't have that info yet so store and check before returning from BuildConnections
1167+
ORT_IGNORE_RETURN_VALUE(node_args_consumed_by_subgraphs.insert(node_arg_name));
11601168
}
11611169

11621170
// add it to the Node's list of implicit inputs
@@ -1178,8 +1186,9 @@ Status Graph::BuildConnections(std::unordered_set<std::string>& outer_scope_node
11781186

11791187
inner_nodes.insert(&output_node);
11801188

1181-
// If this Graph was built manually, remove the implicit input from the graph outputs if it is present there
1182-
// and not explicitly listed in the ordered graph outputs (as that implies we should leave it as an output).
1189+
// If this Graph was built manually, remove the implicit input from the graph outputs
1190+
// if it is present there and not explicitly listed in the ordered graph outputs
1191+
// (as that implies we should leave it as an output).
11831192
// If the Graph was loaded from a GraphProto, honor the explicit graph outputs and leave as is.
11841193
if (!is_loaded_from_model_file_) {
11851194
graph_outputs_.erase(std::remove(graph_outputs_.begin(), graph_outputs_.end(), node_arg),
@@ -1252,8 +1261,17 @@ Status Graph::BuildConnections(std::unordered_set<std::string>& outer_scope_node
12521261
}
12531262
}
12541263

1264+
// finally check any node args consumed by subgraphs to see if they're available locally.
1265+
// if not we add them to the list of outer scope values consumed.
1266+
for (const auto& name : node_args_consumed_by_subgraphs) {
1267+
if (node_arg_to_producer_node_.count(name) == 0 &&
1268+
resolve_context_.inputs_and_initializers.find(name) == resolve_context_.inputs_and_initializers.cend()) {
1269+
ORT_IGNORE_RETURN_VALUE(outer_scope_node_args_consumed.insert(name));
1270+
}
1271+
}
1272+
12551273
return Status::OK();
1256-
} // namespace onnxruntime
1274+
}
12571275

12581276
void Graph::ReverseDFSFrom(const std::vector<NodeIndex>& from,
12591277
const std::function<void(const Node*)>& enter,

onnxruntime/test/providers/cpu/controlflow/loop_test.cc

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -929,6 +929,37 @@ TEST(Loop, PassThroughSubgraphInputNoTypeOrShape) {
929929
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
930930
}
931931

932+
TEST(Loop, BugFixIssue4031_implicit_input_handling) {
933+
SessionOptions so;
934+
so.graph_optimization_level = TransformerLevel::Level2; // we need constant folding to run
935+
InferenceSession session_object{so, GetEnvironment()};
936+
static constexpr const ORTCHAR_T* MODEL_URI = ORT_TSTR("testdata/ort_github_issue_4031.onnx");
937+
938+
ASSERT_STATUS_OK(session_object.Load(MODEL_URI));
939+
ASSERT_STATUS_OK(session_object.Initialize());
940+
941+
onnxruntime::RunOptions run_options;
942+
run_options.run_tag = "BugFixIssue4031_implicit_input_handling";
943+
944+
// prepare inputs
945+
OrtValue ml_value;
946+
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), {1}, {123.f},
947+
&ml_value);
948+
NameMLValMap feeds;
949+
feeds.insert(std::make_pair("state_var_in", ml_value));
950+
951+
// prepare outputs
952+
std::vector<std::string> output_names{"state_var_out"};
953+
std::vector<OrtValue> fetches;
954+
955+
// Now run
956+
ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches));
957+
958+
const auto& output = fetches[0].Get<Tensor>();
959+
ASSERT_TRUE(output.Shape().Size() == 1);
960+
ASSERT_TRUE(output.Data<float>()[0] == 125.f);
961+
}
962+
932963
#ifdef USE_CUDA
933964
// test that when part of the subgraph run on CUDA it executes successfully
934965
TEST(Loop, MixedExecutionProviders) {
Binary file not shown.
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import onnx
2+
from onnx import helper
3+
from onnx import TensorProto
4+
5+
if_body = helper.make_graph(
6+
[
7+
# need to use main_graph_initializer in a way that can't be constant folded
8+
helper.make_node("Add", ["state_var_in", "main_graph_initializer"], ["add_out"], "If_add"),
9+
helper.make_node("Cast", ["add_out"], ["output"], to=TensorProto.BOOL),
10+
],
11+
"if_branch_body",
12+
[
13+
# no explicit inputs
14+
],
15+
[
16+
helper.make_tensor_value_info('output', TensorProto.BOOL, [1]), # how is this getting a type of float?
17+
])
18+
19+
# Loop body graph with If node and usage of main_graph_initializer on this level
20+
body = helper.make_graph(
21+
[
22+
# Add node that can be constant folded. Creates NodeArg when created but that implicit usage of an outer scope
23+
# value main_graph_initializer goes away after constant folding
24+
helper.make_node("Add", ["sub_graph_initializer", "main_graph_initializer"], ["initializer_sum"], "Add1"),
25+
helper.make_node("Add", ["initializer_sum", "loop_state_in"], ["loop_state_out"], "Add2"),
26+
# If node to create usage of main_graph_initializer another level down
27+
helper.make_node("If", ["subgraph_keep_going_in"], ["subgraph_keep_going_out"], "If1",
28+
then_branch=if_body, else_branch=if_body),
29+
],
30+
"Loop_body",
31+
[
32+
helper.make_tensor_value_info('iteration_num', TensorProto.INT64, [1]),
33+
helper.make_tensor_value_info('subgraph_keep_going_in', TensorProto.BOOL, [1]),
34+
helper.make_tensor_value_info('loop_state_in', TensorProto.FLOAT, [1])
35+
],
36+
[
37+
helper.make_tensor_value_info('subgraph_keep_going_out', TensorProto.BOOL, [1]),
38+
helper.make_tensor_value_info('loop_state_out', TensorProto.FLOAT, [1]),
39+
],
40+
[
41+
helper.make_tensor('sub_graph_initializer', TensorProto.FLOAT, [1], [1.])
42+
]
43+
)
44+
45+
# Create the main graph
46+
graph_proto = helper.make_graph(
47+
[
48+
helper.make_node("Loop", ["max_trip_count", "keep_going", "state_var_in"],
49+
["state_var_out"], "Loop1", body=body)
50+
],
51+
"Main_graph",
52+
[
53+
helper.make_tensor_value_info('state_var_in', TensorProto.FLOAT, [1]),
54+
],
55+
[
56+
helper.make_tensor_value_info('state_var_out', TensorProto.FLOAT, [1]),
57+
],
58+
[
59+
helper.make_tensor('max_trip_count', TensorProto.INT64, [1], [1]),
60+
helper.make_tensor('main_graph_initializer', TensorProto.FLOAT, [1], [1.]),
61+
helper.make_tensor('keep_going', TensorProto.BOOL, [1], [True]),
62+
]
63+
)
64+
65+
model = helper.make_model(graph_proto)
66+
onnx.save(model, 'ort_github_issue_4031.onnx')

0 commit comments

Comments
 (0)