Skip to content

Commit efed875

Browse files
Yinghai Lufacebook-github-bot
Yinghai Lu
authored andcommitted
Catch exceptions in bound_shape_inference (pytorch#17775)
Summary: Pull Request resolved: pytorch#17775 Handles use input shape hint properly. Reviewed By: zrphercule Differential Revision: D14368735 fbshipit-source-id: 504cd96589e47aa432617e56362aa6b01a25ba9b
1 parent 4a7c549 commit efed875

File tree

4 files changed

+41
-6
lines changed

4 files changed

+41
-6
lines changed

caffe2/opt/backend_transformer_base.cc

+5-1
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,15 @@ ShapeInfoMap BackendTransformerBase::inferShapes(
7878
shape_map[s] = shape_info;
7979
}
8080
}
81+
// We treat hinted shapes as BATCH. If there are shape hints on blobs in the
82+
// workspace, since they are already inserted as CONSTANT, it will take effect
83+
// here. For SEQ typed tensors, there are only a few of them and they will be
84+
// handled by BoundShapeInferencer.
8185
for (const auto& kv : shape_hints_mapped) {
8286
shape_map.emplace(
8387
std::piecewise_construct,
8488
std::forward_as_tuple(kv.first),
85-
std::forward_as_tuple(ShapeInfo::DimType::CONSTANT, kv.second));
89+
std::forward_as_tuple(ShapeInfo::DimType::BATCH, kv.second));
8690
}
8791
BoundShapeInferencer eng(spec);
8892
eng.InferBoundShapeAndType(*pred_net, shape_map);

caffe2/opt/bound_shape_inferencer.cc

+22-4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "caffe2/core/operator_schema.h"
33
#include "caffe2/core/tensor_impl.h"
44
#include "caffe2/utils/proto_utils.h"
5+
#include "caffe2/utils/string_utils.h"
56

67
namespace caffe2 {
78

@@ -60,6 +61,10 @@ void BoundShapeInferencer::InferBoundShapeAndType(
6061
InferReshape(op);
6162
} else if (op.type() == "LengthsRangeFill") {
6263
InferLengthsRangeFill(op);
64+
} else if (
65+
caffe2::StartsWith(op.type(), "GivenTensor") &&
66+
caffe2::EndsWith(op.type(), "Fill")) {
67+
InferGivenTensorFill(op);
6368
} else {
6469
InferCommonOp(op);
6570
}
@@ -122,6 +127,15 @@ std::vector<TensorShape> InferOutput(
122127
return schema->InferTensor(op, input_shapes);
123128
}
124129

130+
void BoundShapeInferencer::InferGivenTensorFill(const OperatorDef& op) {
131+
CAFFE_ENFORCE_EQ(op.output_size(), 1, op.type(), " must have 1 output");
132+
InferCommonOp(op);
133+
auto it = shape_info_.find(op.output(0));
134+
if (it != shape_info_.end()) {
135+
it->second.dim_type = ShapeInfo::DimType::CONSTANT;
136+
}
137+
}
138+
125139
void BoundShapeInferencer::InferLengthsRangeFill(const OperatorDef& op) {
126140
CAFFE_ENFORCE_EQ(op.input_size(), 1, "LengthsRangeFill must have 1 input");
127141
CAFFE_ENFORCE_EQ(op.output_size(), 1, "LengthsRangeFill must have 1 output");
@@ -342,6 +356,7 @@ void BoundShapeInferencer::InferFC(const OperatorDef& op) {
342356
void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
343357
// First, we need to check that all the input shape/types are already
344358
// presented
359+
try {
345360
std::vector<TensorShape> input_shapes;
346361
for (const auto& input : op.input()) {
347362
const auto it = shape_info_.find(input);
@@ -356,11 +371,7 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
356371
const OpSchema* schema = OpSchemaRegistry::Schema(op.type());
357372
CAFFE_ENFORCE(schema);
358373
std::vector<TensorShape> output_shapes;
359-
try {
360374
output_shapes = schema->InferTensor(op, input_shapes);
361-
} catch (const std::exception& e) {
362-
LOG(WARNING) << "Caught exception while inferring shapes for " << op.type();
363-
}
364375
int i = 0;
365376
for (const auto& shape : output_shapes) {
366377
if (shape.unknown_shape()) {
@@ -373,6 +384,13 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
373384
ConvertToVec(shape.dims()),
374385
shape.data_type());
375386
}
387+
} catch (const caffe2::EnforceNotMet& e) {
388+
LOG(ERROR) << "Enforce not met while inferring shapes for " << op.type()
389+
<< ": " << e.msg();
390+
} catch (const std::exception& e) {
391+
LOG(WARNING) << "Caught exception while inferring shapes for " << op.type()
392+
<< ": " << e.what();
393+
}
376394
}
377395

378396
} // namespace caffe2

caffe2/opt/bound_shape_inferencer.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ class CAFFE2_API BoundShapeInferencer {
6363
std::vector<int64_t> bound_dims,
6464
TensorProto::DataType type);
6565

66+
void InferGivenTensorFill(const OperatorDef& op);
6667
void InferSparseLengthsSum(const OperatorDef& op);
6768
void InferFC(const OperatorDef& op);
6869
void InferConcat(const OperatorDef& op);
@@ -74,7 +75,7 @@ class CAFFE2_API BoundShapeInferencer {
7475
void InferCommonOp(const OperatorDef& op);
7576

7677
const BoundShapeSpec spec_;
77-
ShapeInfo::DimType current_dim_type_{ShapeInfo::DimType::UNKNOWN};
78+
ShapeInfo::DimType current_dim_type_{ShapeInfo::DimType::BATCH};
7879
int64_t current_max_batch_size_{0};
7980
std::unordered_map<std::string, ShapeInfo> shape_info_;
8081
};

caffe2/utils/string_utils.h

+12
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,18 @@ CAFFE2_API inline bool StartsWith(const std::string& str, const std::string& pre
2121
prefix.end();
2222
}
2323

24+
CAFFE2_API inline bool EndsWith(
25+
const std::string& full,
26+
const std::string& ending) {
27+
if (full.length() >= ending.length()) {
28+
return (
29+
0 ==
30+
full.compare(full.length() - ending.length(), ending.length(), ending));
31+
} else {
32+
return false;
33+
}
34+
}
35+
2436
CAFFE2_API int32_t editDistanceHelper(const char* s1,
2537
size_t s1_len,
2638
const char* s2,

0 commit comments

Comments
 (0)