@@ -100,106 +100,50 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
100
100
continue ;
101
101
}
102
102
103
- if (node->GetExecutionProviderType () == onnxruntime::kCudaExecutionProvider ) {
104
- if (node->InputDefs ()[0 ]->TypeAsProto ()->tensor_type ().elem_type () !=
105
- ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
106
- continue ;
107
- }
108
- if (graph_utils::IsSupportedOptypeVersionAndDomain (next_node, " Relu" , {6 , 13 })) {
109
- Node& conv_node = *node;
110
- Node& act_node = *graph.GetNode (next_node.Index ());
111
- auto node_name = graph.GenerateNodeName (conv_node.Name () + " _" + act_node.Name ());
112
- Node& fused_conv = graph.AddNode (node_name,
113
- " FusedConv" ,
114
- node_name,
115
- conv_node.MutableInputDefs (),
116
- {},
117
- &conv_node.GetAttributes (),
118
- onnxruntime::kMSDomain );
119
- fused_conv.SetExecutionProviderType (conv_node.GetExecutionProviderType ());
120
- fused_conv.AddAttribute (" activation" , " Relu" );
121
- graph_utils::FinalizeNodeFusion (graph, {conv_node, act_node}, fused_conv);
122
- modified = true ;
123
- } else if (graph_utils::IsSupportedOptypeVersionAndDomain (next_node, " Add" , {6 , 7 , 13 })) {
124
- const auto & last_node = *(next_node.OutputNodesBegin ());
125
- if (last_node.GetExecutionProviderType () != node->GetExecutionProviderType ()) {
126
- continue ;
127
- }
128
- if (graph_utils::IsSupportedOptypeVersionAndDomain (last_node, " Relu" , {6 , 13 }) &&
129
- next_node.GetOutputEdgesCount () == 1 ) {
130
- Node& conv_node = *node;
131
- Node& add_node = *graph.GetNode (next_node.Index ());
132
- Node& act_node = *graph.GetNode (last_node.Index ());
133
- auto conv_inputs = conv_node.MutableInputDefs ();
134
- auto conv_outputs = conv_node.MutableOutputDefs ();
135
- auto add_inputs = add_node.MutableInputDefs ();
136
- for (auto add_input : add_inputs) {
137
- if (add_input->Name () != conv_outputs[0 ]->Name ()) {
138
- conv_inputs.push_back (add_input);
139
- break ;
140
- }
141
- }
142
- auto node_name = graph.GenerateNodeName (conv_node.Name () + " _" +
143
- add_node.Name () + " _" +
144
- act_node.Name ());
145
- Node& fused_conv = graph.AddNode (node_name,
146
- " FusedConv" ,
147
- node_name,
148
- conv_inputs,
149
- {}, &conv_node.GetAttributes (),
150
- onnxruntime::kMSDomain );
151
- fused_conv.SetExecutionProviderType (conv_node.GetExecutionProviderType ());
152
- fused_conv.AddAttribute (" activation" , " Relu" );
153
- graph_utils::FinalizeNodeFusion (graph, {conv_node, add_node, act_node}, fused_conv);
154
- modified = true ;
155
- }
156
- }
157
- } else {
158
- // Test if this is an activation that can be fused and also extract the
159
- // activation's parameters.
160
- std::vector<float > activation_params;
161
- if (!graph_utils::IsSupportedOptypeVersionAndDomain (next_node, " Relu" , {6 , 13 }) &&
162
- !graph_utils::IsSupportedOptypeVersionAndDomain (next_node, " Sigmoid" , {6 , 13 }) &&
163
- !graph_utils::IsSupportedOptypeVersionAndDomain (next_node, " Tanh" , {6 , 13 })) {
164
- if (graph_utils::IsSupportedOptypeVersionAndDomain (next_node, " LeakyRelu" , {6 })) {
165
- activation_params.push_back (graph_utils::GetNodeAttribute (next_node, " alpha" )->f ());
166
- } else if (graph_utils::IsSupportedOptypeVersionAndDomain (next_node, " Clip" , {6 , 11 , 12 , 13 })) {
167
- float min, max;
168
- if (GetClipConstantMinMax (graph, next_node, min, max)) {
169
- activation_params.push_back (min);
170
- activation_params.push_back (max);
171
- } else {
172
- continue ;
173
- }
103
+ // Test if this is an activation that can be fused and also extract the
104
+ // activation's parameters.
105
+ std::vector<float > activation_params;
106
+ if (!graph_utils::IsSupportedOptypeVersionAndDomain (next_node, " Relu" , {6 , 13 }) &&
107
+ !graph_utils::IsSupportedOptypeVersionAndDomain (next_node, " Sigmoid" , {6 , 13 }) &&
108
+ !graph_utils::IsSupportedOptypeVersionAndDomain (next_node, " Tanh" , {6 , 13 })) {
109
+ if (graph_utils::IsSupportedOptypeVersionAndDomain (next_node, " LeakyRelu" , {6 })) {
110
+ activation_params.push_back (graph_utils::GetNodeAttribute (next_node, " alpha" )->f ());
111
+ } else if (graph_utils::IsSupportedOptypeVersionAndDomain (next_node, " Clip" , {6 , 11 , 12 , 13 })) {
112
+ float min, max;
113
+ if (GetClipConstantMinMax (graph, next_node, min, max)) {
114
+ activation_params.push_back (min);
115
+ activation_params.push_back (max);
174
116
} else {
175
117
continue ;
176
118
}
119
+ } else {
120
+ continue ;
177
121
}
122
+ }
178
123
179
- Node& conv_node = *node;
180
- Node& act_node = *graph.GetNode (next_node.Index ());
124
+ Node& conv_node = *node;
125
+ Node& act_node = *graph.GetNode (next_node.Index ());
181
126
182
- Node& fused_conv = graph.AddNode (graph.GenerateNodeName (" fused " + conv_node.Name ()), " FusedConv" ,
183
- " fused Conv " + conv_node.Name () + " with activation " + act_node.OpType (),
184
- conv_node.MutableInputDefs (),
185
- {},
186
- &conv_node.GetAttributes (),
187
- " com.microsoft" );
127
+ Node& fused_conv = graph.AddNode (graph.GenerateNodeName (" fused " + conv_node.Name ()), " FusedConv" ,
128
+ " fused Conv " + conv_node.Name () + " with activation " + act_node.OpType (),
129
+ conv_node.MutableInputDefs (),
130
+ {},
131
+ &conv_node.GetAttributes (),
132
+ " com.microsoft" );
188
133
189
- // Assign provider to this new node. Provider should be same as the provider for old node.
190
- fused_conv.SetExecutionProviderType (conv_node.GetExecutionProviderType ());
134
+ // Assign provider to this new node. Provider should be same as the provider for old node.
135
+ fused_conv.SetExecutionProviderType (conv_node.GetExecutionProviderType ());
191
136
192
- // Add attributes to specify the activation type and parameters.
193
- fused_conv.AddAttribute (" activation" , next_node.OpType ());
194
- if (activation_params.size () > 0 ) {
195
- fused_conv.AddAttribute (" activation_params" , activation_params);
196
- }
137
+ // Add attributes to specify the activation type and parameters.
138
+ fused_conv.AddAttribute (" activation" , next_node.OpType ());
139
+ if (activation_params.size () > 0 ) {
140
+ fused_conv.AddAttribute (" activation_params" , activation_params);
141
+ }
197
142
198
- // move output definitions and edges from act_node to fused_conv. delete conv_node and act_node.
199
- graph_utils::FinalizeNodeFusion (graph, {conv_node, act_node}, fused_conv);
143
+ // move output definitions and edges from act_node to fused_conv. delete conv_node and act_node.
144
+ graph_utils::FinalizeNodeFusion (graph, {conv_node, act_node}, fused_conv);
200
145
201
- modified = true ;
202
- }
146
+ modified = true ;
203
147
}
204
148
205
149
return Status::OK ();
0 commit comments