Skip to content

Commit a6949ab

Browse files
Michael Antonovfacebook-github-bot
Michael Antonov
authored andcommitted
Guard all Caffe2 protobuf string serializations with CAFFE_ENFORCE (fixed reverted bug) (pytorch#12848)
Summary: Pull Request resolved: pytorch#12848 Updated all non-test uses of protobuf::MessageLite::SerializeAsString to call SerializeAsString_EnforceCheck so that the return value is checked and can throw an exception if failing. Most of the affected code was called from classes derived from BlobSerializeBase. Didn't touch most tests and ENFORCE calls because they usually do checks anyway. Original commit changeset: c0760e73ecc7 Reviewed By: dzhulgakov Differential Revision: D10453456 fbshipit-source-id: d2f2b7b4578e721924354149f08f627c7e3bf070
1 parent dd00c29 commit a6949ab

14 files changed

+58
-18
lines changed

binaries/convert_caffe_image_db.cc

+1-2
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ int main(int argc, char** argv) {
7979
data->add_dims(datum.channels());
8080
data->set_byte_data(buffer, datum.data().size());
8181
}
82-
transaction->Put(cursor->key(), protos.SerializeAsString());
82+
transaction->Put(cursor->key(), SerializeAsString_EnforceCheck(protos));
8383
if (++count % FLAGS_batch_size == 0) {
8484
transaction->Commit();
8585
LOG(INFO) << "Converted " << count << " items so far.";
@@ -88,4 +88,3 @@ int main(int argc, char** argv) {
8888
LOG(INFO) << "A total of " << count << " items processed.";
8989
return 0;
9090
}
91-

caffe2/core/blob_serialization.cc

+21-2
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class StringSerializer : public BlobSerializerBase {
4747
blob_proto.set_name(name);
4848
blob_proto.set_type("std::string");
4949
blob_proto.set_content(*static_cast<const std::string*>(pointer));
50-
acceptor(name, blob_proto.SerializeAsString());
50+
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
5151
}
5252
};
5353

@@ -134,7 +134,7 @@ void TensorSerializer::SerializeWithChunkSize(
134134
tensor, name, blob_proto.mutable_tensor(), chunkStart, chunk_size);
135135
acceptor(
136136
c10::str(name, kChunkIdSeparator, chunkStart / chunk_size),
137-
blob_proto.SerializeAsString());
137+
SerializeBlobProtoAsString_EnforceCheck(blob_proto));
138138
};
139139

140140
#ifndef __ANDROID__
@@ -543,6 +543,25 @@ void TensorDeserializer::Deserialize(const TensorProto& proto, Tensor* tensor) {
543543
context->FinishDeviceComputation();
544544
}
545545

546+
////////////////////////////////////////////////////////////////////////////////
547+
// Serialization Helpers
548+
////////////////////////////////////////////////////////////////////////////////
549+
550+
std::string SerializeAsString_EnforceCheck(
551+
const google::protobuf::MessageLite& msg,
552+
const char* error_location) {
553+
std::string serialize_output;
554+
bool result = msg.SerializeToString(&serialize_output);
555+
if (!error_location) {
556+
CAFFE_ENFORCE(result, "protobuf::SerializeToString failed");
557+
} else {
558+
CAFFE_ENFORCE(result,
559+
"protobuf::SerializeToString failed for ", error_location);
560+
}
561+
return serialize_output;
562+
}
563+
564+
546565
namespace {
547566
// Serialize Tensor
548567
REGISTER_BLOB_SERIALIZER((TypeMeta::Id<Tensor>()), TensorSerializer);

caffe2/core/blob_serialization.h

+18
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,24 @@ inline void CopyFromProtoWithCast(
184184
}
185185

186186
} // namespace detail
187+
188+
////////////////////////////////////////////////////////////////////////////////
189+
// Serialization Helpers
190+
////////////////////////////////////////////////////////////////////////////////
191+
192+
// Converts MessageLite to string while also checking that SerializeAsString
193+
// succeeds. Pass description of class/function of the call if you'd
194+
// like it appended to the error message.
195+
CAFFE2_API std::string SerializeAsString_EnforceCheck(
196+
const google::protobuf::MessageLite&,
197+
const char* error_location = nullptr);
198+
199+
// Convert BlobProto to string with success checks.
200+
inline std::string SerializeBlobProtoAsString_EnforceCheck(
201+
const BlobProto& blob) {
202+
return SerializeAsString_EnforceCheck(blob, blob.name().c_str());
203+
}
204+
187205
} // namespace caffe2
188206

189207
#endif // CAFFE2_CORE_BLOB_SERIALIZATION_H_

caffe2/core/blob_test.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class BlobTestFooSerializer : public BlobSerializerBase {
6565
reinterpret_cast<const char*>(
6666
&static_cast<const BlobTestFoo*>(pointer)->val),
6767
sizeof(int32_t)));
68-
acceptor(name, blob_proto.SerializeAsString());
68+
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
6969
}
7070
};
7171

caffe2/core/db.cc

+2-2
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,8 @@ void DBReaderSerializer::Serialize(
186186
BlobProto blob_proto;
187187
blob_proto.set_name(name);
188188
blob_proto.set_type("DBReader");
189-
blob_proto.set_content(proto.SerializeAsString());
190-
acceptor(name, blob_proto.SerializeAsString());
189+
blob_proto.set_content(SerializeAsString_EnforceCheck(proto));
190+
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
191191
}
192192

193193
void DBReaderDeserializer::Deserialize(const BlobProto& proto, Blob* blob) {

caffe2/core/int8_serialization.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class Int8TensorCPUSerializer : public BlobSerializerBase {
5151
CAFFE_ENFORCE(false, "Unsupported data type in Int8TensorCPU");
5252
}
5353

54-
acceptor(name, blob_proto.SerializeAsString());
54+
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
5555
}
5656

5757
private:

caffe2/core/qtensor_serialization.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ void QTensorSerializer<Context>::Serialize(
5555
proto.set_is_signed(qtensor.is_signed());
5656
detail::CopyToProtoWithCast(
5757
qtensor.nbytes(), qtensor.data(), proto.mutable_data(), &this->context_);
58-
acceptor(name, blob_proto.SerializeAsString());
58+
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
5959
}
6060

6161
template <class Context>

caffe2/db/protodb.cc

+4-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ class ProtoDBCursor : public Cursor {
2020
void SeekToFirst() override { iter_ = 0; }
2121
void Next() override { ++iter_; }
2222
string key() override { return proto_->protos(iter_).name(); }
23-
string value() override { return proto_->protos(iter_).SerializeAsString(); }
23+
string value() override {
24+
return
25+
SerializeAsString_EnforceCheck(proto_->protos(iter_), "ProtoDBCursor");
26+
}
2427
bool Valid() override { return iter_ < proto_->protos_size(); }
2528

2629
private:

caffe2/operators/counter_ops.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ class CounterSerializer : public BlobSerializerBase {
155155
proto.add_int64_data(
156156
(*static_cast<const std::unique_ptr<Counter<int64_t>>*>(pointer))
157157
->retrieve());
158-
acceptor(name, blob_proto.SerializeAsString());
158+
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
159159
}
160160
};
161161

caffe2/operators/dataset_ops.cc

+2-2
Original file line numberDiff line numberDiff line change
@@ -1451,7 +1451,7 @@ class TreeCursorSerializer : public BlobSerializerBase {
14511451
}
14521452
blob_proto.set_content(os.str());
14531453

1454-
acceptor(name, blob_proto.SerializeAsString());
1454+
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
14551455
}
14561456
};
14571457

@@ -1513,7 +1513,7 @@ void SharedTensorVectorPtrSerializer::Serialize(
15131513
blob_proto.set_name(name);
15141514
blob_proto.set_type("std::shared_ptr<std::vector<TensorCPU>>");
15151515
blob_proto.set_content("");
1516-
acceptor(name, blob_proto.SerializeAsString());
1516+
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
15171517
};
15181518

15191519
void SharedTensorVectorPtrDeserializer::Deserialize(

caffe2/operators/index_ops.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ class IndexSerializer : public BlobSerializerBase {
381381
os << base->maxElements() << " " << base->isFrozen();
382382
blob_proto.set_content(os.str());
383383

384-
acceptor(name, blob_proto.SerializeAsString());
384+
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
385385
}
386386

387387
private:

caffe2/operators/map_ops.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,8 @@ class MapSerializer : public BlobSerializerBase {
225225
BlobProto blob_proto;
226226
blob_proto.set_name(name);
227227
blob_proto.set_type(MapTypeTraits<KEY_T, VALUE_T>::MapTypeName());
228-
blob_proto.set_content(tensor_protos.SerializeAsString());
229-
acceptor(name, blob_proto.SerializeAsString());
228+
blob_proto.set_content(SerializeAsString_EnforceCheck(tensor_protos));
229+
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
230230
}
231231
};
232232

caffe2/python/pybind_state.cc

+2-1
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,8 @@ void addObjectMethods(py::module& m) {
586586
const auto& meta = GetGradientForOp(def, output_gradients);
587587
std::vector<py::bytes> grad_ops;
588588
for (const auto& op : meta.ops_) {
589-
grad_ops.push_back(op.SerializeAsString());
589+
grad_ops.push_back(
590+
SerializeAsString_EnforceCheck(op, "addObjectMethods"));
590591
}
591592
return std::pair<std::vector<py::bytes>, std::vector<GradientWrapper>>{
592593
grad_ops, meta.g_input_};

caffe2/sgd/iter_op.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ void MutexSerializer::Serialize(
1717
blob_proto.set_name(name);
1818
blob_proto.set_type("std::unique_ptr<std::mutex>");
1919
blob_proto.set_content("");
20-
acceptor(name, blob_proto.SerializeAsString());
20+
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
2121
}
2222

2323
void MutexDeserializer::Deserialize(const BlobProto& /* unused */, Blob* blob) {

0 commit comments

Comments
 (0)