22#include " caffe2/core/operator_schema.h"
33#include " caffe2/core/tensor_impl.h"
44#include " caffe2/utils/proto_utils.h"
5+ #include " caffe2/utils/string_utils.h"
56
67namespace caffe2 {
78
@@ -60,6 +61,10 @@ void BoundShapeInferencer::InferBoundShapeAndType(
6061 InferReshape (op);
6162 } else if (op.type () == " LengthsRangeFill" ) {
6263 InferLengthsRangeFill (op);
64+ } else if (
65+ caffe2::StartsWith (op.type (), " GivenTensor" ) &&
66+ caffe2::EndsWith (op.type (), " Fill" )) {
67+ InferGivenTensorFill (op);
6368 } else {
6469 InferCommonOp (op);
6570 }
@@ -122,6 +127,15 @@ std::vector<TensorShape> InferOutput(
122127 return schema->InferTensor (op, input_shapes);
123128}
124129
130+ void BoundShapeInferencer::InferGivenTensorFill (const OperatorDef& op) {
131+ CAFFE_ENFORCE_EQ (op.output_size (), 1 , op.type (), " must have 1 output" );
132+ InferCommonOp (op);
133+ auto it = shape_info_.find (op.output (0 ));
134+ if (it != shape_info_.end ()) {
135+ it->second .dim_type = ShapeInfo::DimType::CONSTANT;
136+ }
137+ }
138+
125139void BoundShapeInferencer::InferLengthsRangeFill (const OperatorDef& op) {
126140 CAFFE_ENFORCE_EQ (op.input_size (), 1 , " LengthsRangeFill must have 1 input" );
127141 CAFFE_ENFORCE_EQ (op.output_size (), 1 , " LengthsRangeFill must have 1 output" );
@@ -342,6 +356,7 @@ void BoundShapeInferencer::InferFC(const OperatorDef& op) {
342356void BoundShapeInferencer::InferCommonOp (const OperatorDef& op) {
343357 // First, we need to check that all the input shape/types are already
344358 // presented
359+ try {
345360 std::vector<TensorShape> input_shapes;
346361 for (const auto & input : op.input ()) {
347362 const auto it = shape_info_.find (input);
@@ -356,11 +371,7 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
356371 const OpSchema* schema = OpSchemaRegistry::Schema (op.type ());
357372 CAFFE_ENFORCE (schema);
358373 std::vector<TensorShape> output_shapes;
359- try {
360374 output_shapes = schema->InferTensor (op, input_shapes);
361- } catch (const std::exception& e) {
362- LOG (WARNING) << " Caught exception while inferring shapes for " << op.type ();
363- }
364375 int i = 0 ;
365376 for (const auto & shape : output_shapes) {
366377 if (shape.unknown_shape ()) {
@@ -373,6 +384,13 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
373384 ConvertToVec (shape.dims ()),
374385 shape.data_type ());
375386 }
387+ } catch (const caffe2::EnforceNotMet& e) {
388+ LOG (ERROR) << " Enforce not met while inferring shapes for " << op.type ()
389+ << " : " << e.msg ();
390+ } catch (const std::exception& e) {
391+ LOG (WARNING) << " Caught exception while inferring shapes for " << op.type ()
392+ << " : " << e.what ();
393+ }
376394}
377395
378396} // namespace caffe2
0 commit comments