Skip to content

Commit b991ee4

Browse files
authored
Cleanup NNAPI code (microsoft#5505)
* Cleanup NNAPI code * Check return of GetNCHWInput
1 parent 6f65e2a commit b991ee4

File tree

4 files changed

+118
-141
lines changed

4 files changed

+118
-141
lines changed

onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ Status ModelBuilder::Prepare() {
122122
RETURN_STATUS_ON_ERROR(nnapi_->ANeuralNetworksModel_create(&nnapi_model_->model_));
123123
ORT_RETURN_IF_ERROR(GetTargetDevices());
124124
PreprocessInitializers();
125+
PreprocessActivations();
125126
ORT_RETURN_IF_ERROR(RegisterInitializers());
126127
ORT_RETURN_IF_ERROR(RegisterModelInputs());
127128
ORT_RETURN_IF_ERROR(AddOperations());
@@ -190,6 +191,28 @@ void ModelBuilder::PreprocessInitializers() {
190191
}
191192
}
192193

194+
void ModelBuilder::PreprocessActivations() {
195+
const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder();
196+
for (size_t i = 0; i < node_indices.size(); i++) {
197+
const auto* node(graph_viewer_.GetNode(node_indices[i]));
198+
const auto& op_type(node->OpType());
199+
200+
if (op_type == "Relu") {
201+
activation_nodes_.emplace(node->Index(), ANEURALNETWORKS_FUSED_RELU);
202+
} else if (op_type == "Clip") { // Relu1 or Relu6
203+
float min, max;
204+
if (!GetClipMinMax(*this, *node, min, max))
205+
continue;
206+
207+
if (min == -1.0f && max == 1.0f) {
208+
activation_nodes_.emplace(node->Index(), ANEURALNETWORKS_FUSED_RELU1);
209+
} else if (min == 0.0f && max == 6.0f) {
210+
activation_nodes_.emplace(node->Index(), ANEURALNETWORKS_FUSED_RELU6);
211+
}
212+
}
213+
}
214+
}
215+
193216
// Help to get all quantized operators' input and the node(s) using the input
194217
std::unordered_map<std::string, vector<const Node*>> GetAllQuantizedOpInputs(const GraphViewer& graph_viewer) {
195218
std::unordered_map<std::string, vector<const Node*>> all_quantized_op_inputs;
@@ -554,9 +577,9 @@ int32_t ModelBuilder::FindActivation(const Node& node, const NodeArg& output) {
554577
for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) {
555578
const auto& dst_node = it->GetNode();
556579
const auto* dst_input = dst_node.InputDefs()[it->GetDstArgIndex()];
557-
if (dst_node.OpType() == "Relu") {
580+
if (Contains(activation_nodes_, dst_node.Index())) {
558581
if (&output == dst_input) {
559-
fuse_code = ANEURALNETWORKS_FUSED_RELU;
582+
fuse_code = activation_nodes_.at(dst_node.Index());
560583
}
561584
} else {
562585
// if there is any other non-relu node using the output

onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ class ModelBuilder {
132132
std::unordered_map<std::string, const ONNX_NAMESPACE::TensorProto&> initializers_;
133133
std::unordered_set<std::string> skipped_initializers_;
134134

135+
// All activation nodes (Relu, Relu1, Relu6) as a map <NodeIndex, activeation_code>
136+
std::unordered_map<NodeIndex, int32_t> activation_nodes_;
137+
135138
std::unordered_map<std::string, std::shared_ptr<IOpBuilder>> op_builders_;
136139

137140
// Operands in nhwc
@@ -157,12 +160,18 @@ class ModelBuilder {
157160
Status Prepare() ORT_MUST_USE_RESULT;
158161

159162
Status GetTargetDevices() ORT_MUST_USE_RESULT;
163+
// Get names of all the initializers
160164
void GetAllInitializers();
165+
// If a NNAPI operation will use initializers directly, we will add the initializers to the skip list
161166
void PreprocessInitializers();
167+
// Preprocess all the activation nodes (Relu/Relu1/Relu6) for easy query later
168+
void PreprocessActivations();
169+
// Copy and process all the initializers to NNAPI model
162170
Status RegisterInitializers() ORT_MUST_USE_RESULT;
163171
Status RegisterModelInputs() ORT_MUST_USE_RESULT;
164172
Status AddOperations() ORT_MUST_USE_RESULT;
165173
Status RegisterModelOutputs() ORT_MUST_USE_RESULT;
174+
// After constructing the NNAPI model, will set the shape inferencing record to the Model
166175
void RegisterModelShaper();
167176

168177
Status SetOperandValue(uint32_t index, Model::NNMemory* memory,

0 commit comments

Comments
 (0)