2
2
#include " caffe2/core/operator_schema.h"
3
3
#include " caffe2/core/tensor_impl.h"
4
4
#include " caffe2/utils/proto_utils.h"
5
+ #include " caffe2/utils/string_utils.h"
5
6
6
7
namespace caffe2 {
7
8
@@ -60,6 +61,10 @@ void BoundShapeInferencer::InferBoundShapeAndType(
60
61
InferReshape (op);
61
62
} else if (op.type () == " LengthsRangeFill" ) {
62
63
InferLengthsRangeFill (op);
64
+ } else if (
65
+ caffe2::StartsWith (op.type (), " GivenTensor" ) &&
66
+ caffe2::EndsWith (op.type (), " Fill" )) {
67
+ InferGivenTensorFill (op);
63
68
} else {
64
69
InferCommonOp (op);
65
70
}
@@ -122,6 +127,15 @@ std::vector<TensorShape> InferOutput(
122
127
return schema->InferTensor (op, input_shapes);
123
128
}
124
129
130
+ void BoundShapeInferencer::InferGivenTensorFill (const OperatorDef& op) {
131
+ CAFFE_ENFORCE_EQ (op.output_size (), 1 , op.type (), " must have 1 output" );
132
+ InferCommonOp (op);
133
+ auto it = shape_info_.find (op.output (0 ));
134
+ if (it != shape_info_.end ()) {
135
+ it->second .dim_type = ShapeInfo::DimType::CONSTANT;
136
+ }
137
+ }
138
+
125
139
void BoundShapeInferencer::InferLengthsRangeFill (const OperatorDef& op) {
126
140
CAFFE_ENFORCE_EQ (op.input_size (), 1 , " LengthsRangeFill must have 1 input" );
127
141
CAFFE_ENFORCE_EQ (op.output_size (), 1 , " LengthsRangeFill must have 1 output" );
@@ -342,6 +356,7 @@ void BoundShapeInferencer::InferFC(const OperatorDef& op) {
342
356
void BoundShapeInferencer::InferCommonOp (const OperatorDef& op) {
343
357
// First, we need to check that all the input shape/types are already
344
358
// presented
359
+ try {
345
360
std::vector<TensorShape> input_shapes;
346
361
for (const auto & input : op.input ()) {
347
362
const auto it = shape_info_.find (input);
@@ -356,11 +371,7 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
356
371
const OpSchema* schema = OpSchemaRegistry::Schema (op.type ());
357
372
CAFFE_ENFORCE (schema);
358
373
std::vector<TensorShape> output_shapes;
359
- try {
360
374
output_shapes = schema->InferTensor (op, input_shapes);
361
- } catch (const std::exception & e) {
362
- LOG (WARNING) << " Caught exception while inferring shapes for " << op.type ();
363
- }
364
375
int i = 0 ;
365
376
for (const auto & shape : output_shapes) {
366
377
if (shape.unknown_shape ()) {
@@ -373,6 +384,13 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
373
384
ConvertToVec (shape.dims ()),
374
385
shape.data_type ());
375
386
}
387
+ } catch (const caffe2::EnforceNotMet& e) {
388
+ LOG (ERROR) << " Enforce not met while inferring shapes for " << op.type ()
389
+ << " : " << e.msg ();
390
+ } catch (const std::exception & e) {
391
+ LOG (WARNING) << " Caught exception while inferring shapes for " << op.type ()
392
+ << " : " << e.what ();
393
+ }
376
394
}
377
395
378
396
} // namespace caffe2
0 commit comments