Skip to content

Commit fb5790c

Browse files
Yinghai Lufacebook-github-bot
Yinghai Lu
authored andcommitted
Remove second output of Reshape during ONNXIFI transform (pytorch#17027)
Summary: Pull Request resolved: pytorch#17027 Glow doesn't support second output of Reshape right now and it's useless. For correctness, we do make sure that the second output of Reshape is of Constant type during bound shape inference. Reviewed By: ipiszy Differential Revision: D14056555 fbshipit-source-id: f39cca7ba941bf5a5cc3adc96e2b1f943cc0be93
1 parent 9d01be1 commit fb5790c

4 files changed

+64
-7
lines changed

caffe2/opt/bound_shape_inference_test.cc

+40
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,46 @@ TEST(BoundShapeInference, SparseLengthsSumFused8BitRowwise) {
110110
out_shape, "Out", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 50});
111111
}
112112

113+
TEST(BoundShapeInference, Reshape) {
114+
NetDef net;
115+
std::vector<int> new_shape{-1, 8};
116+
std::vector<int> new_shape2{2, 8};
117+
net.add_op()->CopyFrom(
118+
CreateOperatorDef("FC", "", {"X0", "W0", "B0"}, {"X1"}, {}));
119+
net.add_op()->CopyFrom(CreateOperatorDef(
120+
"Reshape",
121+
"",
122+
{"X1"},
123+
{"Y1", "old_shape"},
124+
{MakeArgument<std::vector<int>>("shape", new_shape)}));
125+
126+
// Cannot infer shape for this one because input/output shape doesn't match
127+
net.add_op()->CopyFrom(CreateOperatorDef(
128+
"Reshape",
129+
"",
130+
{"X1"},
131+
{"Y2", "old_shape2"},
132+
{MakeArgument<std::vector<int>>("shape", new_shape2)}));
133+
ShapeInfoMap shape_map;
134+
shape_map.emplace(
135+
"W0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1024}));
136+
shape_map.emplace("B0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {16}));
137+
BoundShapeSpec spec(20, 1000);
138+
BoundShapeInferencer eng(spec);
139+
eng.InferBoundShapeAndType(net, shape_map);
140+
const auto& out_shape = eng.shape_info();
141+
verifyShapeInfo(
142+
out_shape, "X0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 1024});
143+
verifyShapeInfo(
144+
out_shape, "X1", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 16});
145+
verifyShapeInfo(
146+
out_shape,
147+
"Y1",
148+
ShapeInfo::DimType::BATCH,
149+
{spec.max_batch_size * 16 / 8, 8});
150+
EXPECT_TRUE(out_shape.find("Y2") == out_shape.end());
151+
}
152+
113153
TEST(BoundShapeInference, ConcatMissingInput) {
114154
NetDef net;
115155
net.add_op()->CopyFrom(CreateOperatorDef(

caffe2/opt/bound_shape_inferencer.cc

+16-2
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ void BoundShapeInferencer::InferBoundShapeAndType(
5656
InferFC(op);
5757
} else if (op.type() == "Concat") {
5858
InferConcat(op);
59+
} else if (op.type() == "Reshape") {
60+
InferReshape(op);
5961
} else if (op.type() == "LengthsRangeFill") {
6062
InferLengthsRangeFill(op);
6163
} else {
@@ -198,6 +200,13 @@ void BoundShapeInferencer::InferSparseLengthsSum(const OperatorDef& op) {
198200
TensorProto_DataType_FLOAT);
199201
}
200202

203+
void BoundShapeInferencer::InferReshape(const OperatorDef& op) {
204+
InferCommonOp(op);
205+
// old_shape should be a constant
206+
if (op.output_size() > 1 && shape_info_.count(op.output(1))) {
207+
shape_info_[op.output(1)].dim_type = ShapeInfo::DimType::CONSTANT;
208+
}
209+
}
201210
// For concat net, if some inputs are missing and we have add_axis argument, it
202211
// means that all the inputs should be of the same dimension. In this case, we
203212
// can infer the shape of the missing inputs
@@ -253,7 +262,7 @@ void BoundShapeInferencer::InferConcat(const OperatorDef& op) {
253262
}
254263
InferCommonOp(op);
255264
// split_info should be a constant
256-
if (op.output_size() > 1) {
265+
if (op.output_size() > 1 && shape_info_.count(op.output(1))) {
257266
shape_info_[op.output(1)].dim_type = ShapeInfo::DimType::CONSTANT;
258267
}
259268
}
@@ -345,7 +354,12 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
345354

346355
const OpSchema* schema = OpSchemaRegistry::Schema(op.type());
347356
CAFFE_ENFORCE(schema);
348-
auto output_shapes = schema->InferTensor(op, input_shapes);
357+
std::vector<TensorShape> output_shapes;
358+
try {
359+
output_shapes = schema->InferTensor(op, input_shapes);
360+
} catch (const std::exception& e) {
361+
LOG(WARNING) << "Caught exception while inferring shapes for " << op.type();
362+
}
349363
int i = 0;
350364
for (const auto& shape : output_shapes) {
351365
if (shape.unknown_shape()) {

caffe2/opt/bound_shape_inferencer.h

+1
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ class CAFFE2_API BoundShapeInferencer {
6666
void InferSparseLengthsSum(const OperatorDef& op);
6767
void InferFC(const OperatorDef& op);
6868
void InferConcat(const OperatorDef& op);
69+
void InferReshape(const OperatorDef& op);
6970
void InferLengthsRangeFill(const OperatorDef& op);
7071

7172
// Standard shape/type inference using op schema registered shape inference

caffe2/opt/onnxifi_transformer.cc

+7-5
Original file line numberDiff line numberDiff line change
@@ -489,15 +489,16 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaC2(
489489
// We already have all the ops and external inputs and outputs!
490490
NetDef onnxifi_net(net);
491491

492-
// Remove the second output of Concat from external_output. In addition, we
493-
// remove those outputs from the Onnxifi op too.
492+
// Remove the second output of Concat/Reshape from external_output. In
493+
// addition, we remove those outputs from the Onnxifi op too.
494494
// TODO: This approach is a bit hacky as we assume that the second output is
495495
// never used. A more appropriate approach can be learned from the ONNX path,
496496
// where we statically computes the split_info given input shape and insert a
497497
// GivenTensorIntFill op
498498
std::unordered_set<std::string> split_infos;
499499
for (auto& op : *onnxifi_net.mutable_op()) {
500-
if (op.type() == "Concat" && op.output_size() == 2) {
500+
if ((op.type() == "Concat" || op.type() == "Reshape") &&
501+
op.output_size() == 2) {
501502
split_infos.emplace(op.output(1));
502503
}
503504
}
@@ -802,8 +803,9 @@ NetDef OnnxifiTransformer::TransformViaC2(
802803
for (const auto& o : op.output()) {
803804
net.add_external_output(o);
804805
}
805-
// Remove the second output of Concat from the external_output
806-
if (op.type() == "Concat" && op.output_size() == 2) {
806+
// Remove the second output of Concat/Reshape from the external_output
807+
if ((op.type() == "Concat" || op.type() == "Reshape") &&
808+
op.output_size() == 2) {
807809
net.mutable_external_output()->RemoveLast();
808810
}
809811

0 commit comments

Comments
 (0)