Skip to content

Commit 46990c2

Browse files
ggoossenfacebook-github-bot
authored andcommitted
Verify def before infer fensor (pytorch#18129)
Summary: Pull Request resolved: pytorch#18129 A lot of tensor interference function assume the operator passes the schema. So call Verity to make sure this is actually the case. Created diff before to add checking in Concat (pytorch#17110), but I encountered lot more places where this is assumed (for example ElementwiseOpShapeInference) Reviewed By: mdschatz Differential Revision: D14503933 fbshipit-source-id: cf0097b8c3e4beb1cded6b61e092a6adee4b8fcb
1 parent 77a7285 commit 46990c2

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

caffe2/core/operator_schema.h

+5
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "caffe2/core/logging.h"
1515
#include "caffe2/proto/caffe2_pb.h"
1616
#include "caffe2/utils/filler.h"
17+
#include "caffe2/utils/proto_utils.h"
1718

1819
namespace caffe2 {
1920

@@ -186,6 +187,10 @@ class CAFFE2_API OpSchema {
186187
inline vector<TensorShape> InferTensor(
187188
const OperatorDef& def,
188189
const vector<TensorShape>& input_type_shape) const {
190+
CAFFE_ENFORCE(
191+
Verify(def),
192+
"(InferTensor) Operator def did not pass schema checking: ",
193+
ProtoDebugString(def));
189194
return tensor_inference_function_(def, input_type_shape);
190195
}
191196

caffe2/python/operator_test/shape_inference_test.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -415,8 +415,8 @@ def testConcat(self):
415415
net = core.Net("concat")
416416

417417
net.Concat(["A", "B"], ["C", "splits"], axis=1)
418-
net.Concat(["C", "D"], ["E"], order="NCHW")
419-
net.Concat(["E", "F"], ["G"], add_axis=1, order="NHWC")
418+
net.Concat(["C", "D"], ["E", "splitsE"], order="NCHW")
419+
net.Concat(["E", "F"], ["G", "splitsG"], add_axis=1, order="NHWC")
420420
(shapes, types) = workspace.InferShapesAndTypes(
421421
[net],
422422
{
@@ -435,8 +435,8 @@ def testConcatInt32(self):
435435
net = core.Net("concat")
436436

437437
net.Concat(["A", "B"], ["C", "splits"], axis=1)
438-
net.Concat(["C", "D"], ["E"], order="NCHW")
439-
net.Concat(["E", "F"], ["G"], add_axis=1, order="NHWC")
438+
net.Concat(["C", "D"], ["E", "splitsE"], order="NCHW")
439+
net.Concat(["E", "F"], ["G", "splitsG"], add_axis=1, order="NHWC")
440440
(shapes, types) = workspace.InferShapesAndTypes(
441441
[net],
442442
blob_dimensions={

0 commit comments

Comments
 (0)