@@ -122,6 +122,7 @@ Status ModelBuilder::Prepare() {
122
122
RETURN_STATUS_ON_ERROR (nnapi_->ANeuralNetworksModel_create (&nnapi_model_->model_ ));
123
123
ORT_RETURN_IF_ERROR (GetTargetDevices ());
124
124
PreprocessInitializers ();
125
+ PreprocessActivations ();
125
126
ORT_RETURN_IF_ERROR (RegisterInitializers ());
126
127
ORT_RETURN_IF_ERROR (RegisterModelInputs ());
127
128
ORT_RETURN_IF_ERROR (AddOperations ());
@@ -190,6 +191,28 @@ void ModelBuilder::PreprocessInitializers() {
190
191
}
191
192
}
192
193
194
+ void ModelBuilder::PreprocessActivations () {
195
+ const auto & node_indices = graph_viewer_.GetNodesInTopologicalOrder ();
196
+ for (size_t i = 0 ; i < node_indices.size (); i++) {
197
+ const auto * node (graph_viewer_.GetNode (node_indices[i]));
198
+ const auto & op_type (node->OpType ());
199
+
200
+ if (op_type == " Relu" ) {
201
+ activation_nodes_.emplace (node->Index (), ANEURALNETWORKS_FUSED_RELU);
202
+ } else if (op_type == " Clip" ) { // Relu1 or Relu6
203
+ float min, max;
204
+ if (!GetClipMinMax (*this , *node, min, max))
205
+ continue ;
206
+
207
+ if (min == -1 .0f && max == 1 .0f ) {
208
+ activation_nodes_.emplace (node->Index (), ANEURALNETWORKS_FUSED_RELU1);
209
+ } else if (min == 0 .0f && max == 6 .0f ) {
210
+ activation_nodes_.emplace (node->Index (), ANEURALNETWORKS_FUSED_RELU6);
211
+ }
212
+ }
213
+ }
214
+ }
215
+
193
216
// Helper to get all quantized operators' inputs and the node(s) using each input
194
217
std::unordered_map<std::string, vector<const Node*>> GetAllQuantizedOpInputs (const GraphViewer& graph_viewer) {
195
218
std::unordered_map<std::string, vector<const Node*>> all_quantized_op_inputs;
@@ -554,9 +577,9 @@ int32_t ModelBuilder::FindActivation(const Node& node, const NodeArg& output) {
554
577
for (auto it = node.OutputEdgesBegin (), end = node.OutputEdgesEnd (); it != end; ++it) {
555
578
const auto & dst_node = it->GetNode ();
556
579
const auto * dst_input = dst_node.InputDefs ()[it->GetDstArgIndex ()];
557
- if (dst_node.OpType () == " Relu " ) {
580
+ if (Contains (activation_nodes_, dst_node.Index ()) ) {
558
581
if (&output == dst_input) {
559
- fuse_code = ANEURALNETWORKS_FUSED_RELU ;
582
+ fuse_code = activation_nodes_. at (dst_node. Index ()) ;
560
583
}
561
584
} else {
562
585
// if there is any other node using the output that is not in activation_nodes_
0 commit comments