Skip to content

Commit db1d61a

Browse files
Yinghai Lu, facebook-github-bot
Yinghai Lu
authored and committed
Add rule based filtering for ONNXIFI transformation (pytorch#17198)
Summary: Pull Request resolved: pytorch#17198. We have come to the point where we need to apply rules that bind certain ops together to avoid un-inferrable intermediate shapes: we either lower them to the backend together, or we lower neither. This diff adds a pass in which rules like this can be added. The first rule binds `Gather` with `SparseLengthsWeighted*`. Reviewed By: ipiszy. Differential Revision: D14118326. fbshipit-source-id: 14bc62e1feddae02a3dd8eae93b8f553d52ac951
1 parent 63214b5 commit db1d61a

6 files changed

+331
-212
lines changed

caffe2/onnx/onnxifi_graph_info.h

+9
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,26 @@ struct BackendGraphInfo {
1616
onnxBackend backend;
1717
onnxGraph graph;
1818
onnxifi_library* lib{nullptr};
19+
1920
BackendGraphInfo(
2021
onnxBackendID backend_id,
2122
onnxBackend backend,
2223
onnxGraph graph,
2324
onnxifi_library* lib)
2425
: backend_id(backend_id), backend(backend), graph(graph), lib(lib) {}
26+
2527
BackendGraphInfo(const BackendGraphInfo& other) = delete;
28+
2629
BackendGraphInfo& operator=(const BackendGraphInfo& other) = delete;
30+
2731
BackendGraphInfo(BackendGraphInfo&& other) noexcept {
2832
backend_id = other.backend_id;
2933
backend = other.backend;
3034
graph = other.graph;
3135
lib = other.lib;
3236
other.backend_id = other.backend = other.graph = other.lib = nullptr;
3337
}
38+
3439
BackendGraphInfo& operator=(BackendGraphInfo&& other) {
3540
backend_id = other.backend_id;
3641
backend = other.backend;
@@ -39,6 +44,7 @@ struct BackendGraphInfo {
3944
other.backend_id = other.backend = other.graph = other.lib = nullptr;
4045
return *this;
4146
}
47+
4248
~BackendGraphInfo() {
4349
if (lib) {
4450
onnxStatus err;
@@ -74,14 +80,17 @@ class OnnxBackendGraphMap {
7480
OnnxBackendGraphMap(OnnxBackendGraphMap&&) = delete;
7581
OnnxBackendGraphMap operator=(const OnnxBackendGraphMap&) = delete;
7682
OnnxBackendGraphMap operator=(OnnxBackendGraphMap&&) = delete;
83+
7784
SharedPtrBackendGraphInfo lookup(const std::string& key);
85+
7886
// If acquisition of graph_ptr fails then graph already exists. And the
7987
// corresponding SharedPtrBackendGraphInfo is returned. Otherwise graph_ptr is
8088
// inserted in the map and the wrapper SharedPtrBackendGraphInfo is
8189
// returned.
8290
SharedPtrBackendGraphInfo insert(
8391
const std::string& key,
8492
BackendGraphInfo graph);
93+
8594
void remove(const std::string& key);
8695

8796
private:

caffe2/operators/onnxifi_op.h

+22-13
Original file line numberDiff line numberDiff line change
@@ -171,19 +171,28 @@ class OnnxifiOp final : public Operator<Context> {
171171
}
172172
lib_->onnxReleaseBackendID(backend_ids[i]);
173173
}
174-
CAFFE_ENFORCE_EQ(
175-
lib_->onnxInitGraph(
176-
backend_,
177-
nullptr,
178-
onnx_model_str.size(),
179-
(const void*)(onnx_model_str.c_str()),
180-
weight_descs.size(),
181-
weight_descs.data(),
182-
&graph_),
183-
ONNXIFI_STATUS_SUCCESS);
184-
backend_graph_shared_ptr_ = backend_graph_map_ptr_->insert(
185-
op_id_string_,
186-
onnx::BackendGraphInfo(backend_id_, backend_, graph_, lib_));
174+
175+
// Lookup the backend first, if it's not there, create our own and try
176+
// submitting it to the backend_graph_map
177+
backend_graph_shared_ptr_ = backend_graph_map_ptr_->lookup(op_id_string_);
178+
if (!backend_graph_shared_ptr_) {
179+
LOG(INFO) << "Creating backend for " << op_id_string_;
180+
CAFFE_ENFORCE_EQ(
181+
lib_->onnxInitGraph(
182+
backend_,
183+
nullptr,
184+
onnx_model_str.size(),
185+
(const void*)(onnx_model_str.c_str()),
186+
weight_descs.size(),
187+
weight_descs.data(),
188+
&graph_),
189+
ONNXIFI_STATUS_SUCCESS);
190+
backend_graph_shared_ptr_ = backend_graph_map_ptr_->insert(
191+
op_id_string_,
192+
onnx::BackendGraphInfo(backend_id_, backend_, graph_, lib_));
193+
} else {
194+
LOG(INFO) << "Got cached backend for " << op_id_string_;
195+
}
187196
// This checks if our insertion was successful or some other thread did
188197
// the insert in the meantime.
189198
if (backend_graph_shared_ptr_->backend_id != backend_id_ ||

caffe2/opt/backend_transformer_base.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ std::string BackendTransformerBase::getModelId(const NetDef& net) {
2525

2626
TensorProto BackendTransformerBase::wrapShapeInfoIntoTensorProto(
2727
const std::string& name,
28-
const ShapeInfo& shape_info) {
28+
const ShapeInfo& shape_info) const {
2929
TensorProto t;
3030
t.set_name(name);
3131
t.set_data_type(shape_info.shape.data_type());

caffe2/opt/backend_transformer_base.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class BackendTransformerBase {
5252
// Wrap TensorShape into TensorProto
5353
TensorProto wrapShapeInfoIntoTensorProto(
5454
const std::string& name,
55-
const ShapeInfo& shape_info);
55+
const ShapeInfo& shape_info) const;
5656

5757
// Do bound shape inference and collect shape infos
5858
ShapeInfoMap inferShapes(

0 commit comments

Comments
 (0)