Skip to content

Commit 25b06f2

Browse files
authored
Fix layout transformer for FusedConv (#24169)
### Description Fix layout transformer for FusedConv. The current layout transformer will transform `FusedConv` (kMSDomain) into `FusedConv` (kMSInternalNHWCDomain) if the EP wants channels_last. However, kMSInternalNHWCDomain uses OpType `Conv` for both Conv and FusedConv, so `FusedConv` (kMSInternalNHWCDomain) is invalid (unregistered op). This PR fixes this and allows layout transformer change `FusedConv` (kMSDomain) into `Conv` (kMSInternalNHWCDomain). ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
1 parent 1ef3044 commit 25b06f2

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,14 @@ Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvid
145145
}
146146

147147
if (ConvertNodeLayout(*node)) {
148+
// domain kMSInternalNHWCDomain uses OpType "Conv" for both Conv and FusedConv.
149+
// So, change the OpType to "Conv" for FusedConv.
150+
std::string_view op_type = node->OpType() == "FusedConv" ? "Conv" : node->OpType();
151+
148152
// if already transformed then change the domain to kMSInternalNHWCDomain this way the EP
149153
// knows this op is in the expected format.
150154
if (node->GetAttributeIntDefault("channels_last", 0) == 1) {
151-
SwapNodeOpTypeAndDomain(*api_graph, *node, node->OpType(), kMSInternalNHWCDomain);
155+
SwapNodeOpTypeAndDomain(*api_graph, *node, op_type, kMSInternalNHWCDomain);
152156
// Changing the domain for the node requires creating a new node and replacing the old one
153157
// therefore set the modified flag.
154158
modified = true;
@@ -175,7 +179,7 @@ Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvid
175179
// Except for resize and convolution ops, all the other layout sensitive ops only require layout transformation
176180
// for 0th input and output. For resize, add the other relevant inputs which need conversion. For Conv - layout
177181
// transformer only converts layout for 0th input, weights should be handled by every EP.
178-
if (node->OpType() == "Resize") {
182+
if (op_type == "Resize") {
179183
// Older versions of resize have a bug where ROI and Scales cannot be made empty inputs. To handle this case,
180184
// we need to jump a few extra hoops to make sure its inputs are correctly handled.
181185
//
@@ -205,7 +209,7 @@ Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvid
205209
WrapTransposesAroundNode(*api_graph, *node, {&input_perm}, {&output_perm});
206210
}
207211

208-
SwapNodeOpTypeAndDomain(*api_graph, *node, node->OpType(), kMSInternalNHWCDomain);
212+
SwapNodeOpTypeAndDomain(*api_graph, *node, op_type, kMSInternalNHWCDomain);
209213
modified = true;
210214
}
211215
}

0 commit comments

Comments
 (0)