Commit 40b0929

Revert fuse conv fix err (microsoft#6859)
* merge fuse cuda conv revert
* resolve merge conflict, revert exclude unsupported type
* add Stream for slicing
* remove file
* add Stream

Co-authored-by: RandySheriffH <[email protected]>
1 parent 29b30bb commit 40b0929

File tree

12 files changed: +299, -720 lines changed


onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc

Lines changed: 0 additions & 2 deletions
@@ -79,7 +79,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, DequantizeLinear);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_int8_t, QAttention);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FusedConv);
 
 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, FastGelu);
@@ -175,7 +174,6 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float, LayerNormalization)>,
 #endif
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FusedConv)>,
   };
 
   for (auto& function_table_entry : function_table) {
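For readers unfamiliar with this file: CUDA contrib kernels are registered in two steps, a typed kernel class declaration near the top of the file and a matching BuildKernelCreateInfo entry in the function table that RegisterCudaContribKernels walks; this commit removes both FusedConv entries. A minimal sketch of that pattern, with a hypothetical MyOp kernel standing in for a real one:

// Sketch only: the two-step registration pattern from cuda_contrib_kernels.cc.
// "MyOp" is a hypothetical kernel name used purely for illustration.
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MyOp);

Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
  static const BuildKernelCreateInfoFn function_table[] = {
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MyOp)>,
  };
  for (auto& function_table_entry : function_table) {
    // Each entry produces a KernelCreateInfo that the registry takes ownership of.
    ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry()));
  }
  return Status::OK();
}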

onnxruntime/contrib_ops/cuda/fused_conv.cc

Lines changed: 0 additions & 126 deletions
This file was deleted. (It contained the CUDA implementation of the com.microsoft FusedConv contrib op being reverted; its 126 lines are not shown here.)

onnxruntime/core/graph/contrib_ops/contrib_defs.cc

Lines changed: 0 additions & 6 deletions
@@ -1273,12 +1273,6 @@ activation.)DOC")
         "",
         "T",
         OpSchema::Optional)
-    .Input(
-        3,
-        "Z",
-        "",
-        "T",
-        OpSchema::Optional)
     .Output(
         0,
         "Y",

onnxruntime/core/optimizer/conv_activation_fusion.cc

Lines changed: 34 additions & 90 deletions
@@ -100,106 +100,50 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
       continue;
     }
 
-    if (node->GetExecutionProviderType() == onnxruntime::kCudaExecutionProvider) {
-      if (node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type() !=
-          ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
-        continue;
-      }
-      if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Relu", {6, 13})) {
-        Node& conv_node = *node;
-        Node& act_node = *graph.GetNode(next_node.Index());
-        auto node_name = graph.GenerateNodeName(conv_node.Name() + "_" + act_node.Name());
-        Node& fused_conv = graph.AddNode(node_name,
-                                         "FusedConv",
-                                         node_name,
-                                         conv_node.MutableInputDefs(),
-                                         {},
-                                         &conv_node.GetAttributes(),
-                                         onnxruntime::kMSDomain);
-        fused_conv.SetExecutionProviderType(conv_node.GetExecutionProviderType());
-        fused_conv.AddAttribute("activation", "Relu");
-        graph_utils::FinalizeNodeFusion(graph, {conv_node, act_node}, fused_conv);
-        modified = true;
-      } else if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {6, 7, 13})) {
-        const auto& last_node = *(next_node.OutputNodesBegin());
-        if (last_node.GetExecutionProviderType() != node->GetExecutionProviderType()) {
-          continue;
-        }
-        if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Relu", {6, 13}) &&
-            next_node.GetOutputEdgesCount() == 1) {
-          Node& conv_node = *node;
-          Node& add_node = *graph.GetNode(next_node.Index());
-          Node& act_node = *graph.GetNode(last_node.Index());
-          auto conv_inputs = conv_node.MutableInputDefs();
-          auto conv_outputs = conv_node.MutableOutputDefs();
-          auto add_inputs = add_node.MutableInputDefs();
-          for (auto add_input : add_inputs) {
-            if (add_input->Name() != conv_outputs[0]->Name()) {
-              conv_inputs.push_back(add_input);
-              break;
-            }
-          }
-          auto node_name = graph.GenerateNodeName(conv_node.Name() + "_" +
-                                                  add_node.Name() + "_" +
-                                                  act_node.Name());
-          Node& fused_conv = graph.AddNode(node_name,
-                                           "FusedConv",
-                                           node_name,
-                                           conv_inputs,
-                                           {}, &conv_node.GetAttributes(),
-                                           onnxruntime::kMSDomain);
-          fused_conv.SetExecutionProviderType(conv_node.GetExecutionProviderType());
-          fused_conv.AddAttribute("activation", "Relu");
-          graph_utils::FinalizeNodeFusion(graph, {conv_node, add_node, act_node}, fused_conv);
-          modified = true;
-        }
-      }
-    } else {
-      // Test if this is an activation that can be fused and also extract the
-      // activation's parameters.
-      std::vector<float> activation_params;
-      if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Relu", {6, 13}) &&
-          !graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Sigmoid", {6, 13}) &&
-          !graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Tanh", {6, 13})) {
-        if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "LeakyRelu", {6})) {
-          activation_params.push_back(graph_utils::GetNodeAttribute(next_node, "alpha")->f());
-        } else if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Clip", {6, 11, 12, 13})) {
-          float min, max;
-          if (GetClipConstantMinMax(graph, next_node, min, max)) {
-            activation_params.push_back(min);
-            activation_params.push_back(max);
+    // Test if this is an activation that can be fused and also extract the
+    // activation's parameters.
+    std::vector<float> activation_params;
+    if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Relu", {6, 13}) &&
+        !graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Sigmoid", {6, 13}) &&
+        !graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Tanh", {6, 13})) {
+      if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "LeakyRelu", {6})) {
+        activation_params.push_back(graph_utils::GetNodeAttribute(next_node, "alpha")->f());
+      } else if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Clip", {6, 11, 12, 13})) {
+        float min, max;
+        if (GetClipConstantMinMax(graph, next_node, min, max)) {
+          activation_params.push_back(min);
+          activation_params.push_back(max);
         } else {
           continue;
         }
+      } else {
+        continue;
       }
+    }
 
-      Node& conv_node = *node;
-      Node& act_node = *graph.GetNode(next_node.Index());
+    Node& conv_node = *node;
+    Node& act_node = *graph.GetNode(next_node.Index());
 
-      Node& fused_conv = graph.AddNode(graph.GenerateNodeName("fused " + conv_node.Name()), "FusedConv",
-                                       "fused Conv " + conv_node.Name() + "with activation " + act_node.OpType(),
-                                       conv_node.MutableInputDefs(),
-                                       {},
-                                       &conv_node.GetAttributes(),
-                                       "com.microsoft");
+    Node& fused_conv = graph.AddNode(graph.GenerateNodeName("fused " + conv_node.Name()), "FusedConv",
+                                     "fused Conv " + conv_node.Name() + "with activation " + act_node.OpType(),
+                                     conv_node.MutableInputDefs(),
+                                     {},
+                                     &conv_node.GetAttributes(),
+                                     "com.microsoft");
 
-      // Assign provider to this new node. Provider should be same as the provider for old node.
-      fused_conv.SetExecutionProviderType(conv_node.GetExecutionProviderType());
+    // Assign provider to this new node. Provider should be same as the provider for old node.
+    fused_conv.SetExecutionProviderType(conv_node.GetExecutionProviderType());
 
-      // Add attributes to specify the activation type and parameters.
-      fused_conv.AddAttribute("activation", next_node.OpType());
-      if (activation_params.size() > 0) {
-        fused_conv.AddAttribute("activation_params", activation_params);
-      }
+    // Add attributes to specify the activation type and parameters.
+    fused_conv.AddAttribute("activation", next_node.OpType());
+    if (activation_params.size() > 0) {
+      fused_conv.AddAttribute("activation_params", activation_params);
+    }
 
-      // move output definitions and edges from act_node to fused_conv. delete conv_node and act_node.
-      graph_utils::FinalizeNodeFusion(graph, {conv_node, act_node}, fused_conv);
+    // move output definitions and edges from act_node to fused_conv. delete conv_node and act_node.
+    graph_utils::FinalizeNodeFusion(graph, {conv_node, act_node}, fused_conv);
 
-      modified = true;
-    }
+    modified = true;
   }
 
   return Status::OK();
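After the revert, the pass is back to a single code path: find a Conv whose consumer is a fusable activation (Relu, Sigmoid, Tanh, LeakyRelu, or a Clip with constant min/max), then replace the pair with one com.microsoft FusedConv node carrying the activation name and parameters as attributes. A minimal usage sketch, under the assumption that the internal GraphTransformer::Apply(graph, modified, logger) signature of this era holds and that graph and logger already exist in the caller:

// Sketch: running the fusion pass standalone (assumed internal API; not part
// of this diff). After Apply, each eligible Conv+activation pair has been
// collapsed into a single FusedConv node in the com.microsoft domain.
#include "core/optimizer/conv_activation_fusion.h"

std::unordered_set<std::string> eps = {onnxruntime::kCpuExecutionProvider};
onnxruntime::ConvActivationFusion fusion(eps);
bool modified = false;
ORT_RETURN_IF_ERROR(fusion.Apply(graph, modified, logger));
// modified == true iff at least one Conv+activation pair was fused.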

onnxruntime/core/optimizer/graph_transformer_utils.cc

Lines changed: 2 additions & 2 deletions
@@ -142,9 +142,9 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerL
   transformers.emplace_back(onnxruntime::make_unique<DynamicQuantizeMatMulFusion>(cpu_execution_providers));
 
   std::unordered_set<std::string> cpu_acl_execution_providers = {onnxruntime::kCpuExecutionProvider, onnxruntime::kAclExecutionProvider};
-  std::unordered_set<std::string> cpu_cuda_acl_armnn_execution_providers = {onnxruntime::kCpuExecutionProvider, onnxruntime::kCudaExecutionProvider, onnxruntime::kAclExecutionProvider, onnxruntime::kArmNNExecutionProvider};
+  std::unordered_set<std::string> cpu_acl_armnn_execution_providers = {onnxruntime::kCpuExecutionProvider, onnxruntime::kAclExecutionProvider, onnxruntime::kArmNNExecutionProvider};
 
-  transformers.emplace_back(onnxruntime::make_unique<ConvActivationFusion>(cpu_cuda_acl_armnn_execution_providers));
+  transformers.emplace_back(onnxruntime::make_unique<ConvActivationFusion>(cpu_acl_armnn_execution_providers));
 
   std::unordered_set<std::string> cpu_cuda_execution_providers = {onnxruntime::kCpuExecutionProvider, onnxruntime::kCudaExecutionProvider};
   transformers.emplace_back(onnxruntime::make_unique<GeluFusion>(cpu_cuda_execution_providers));
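The only change here is the execution-provider set handed to ConvActivationFusion: kCudaExecutionProvider drops out, so Conv nodes placed on the CUDA EP are no longer fusion candidates. Inside a GraphTransformer this set acts as a per-node gate, roughly like the following sketch (the check itself is not part of this diff):

// Sketch: how a transformer's compatible-EP set filters candidate nodes.
// With CUDA removed from the set above, Conv nodes assigned to the CUDA EP
// fail this check and are skipped by ConvActivationFusion.
if (!graph_utils::IsSupportedProvider(*node, GetCompatibleExecutionProviders())) {
  continue;  // node's EP is not in the transformer's compatible set
}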
