Skip to content

Commit 6cbf199

Browse files
smessmerfacebook-github-bot
authored andcommitted
Serialization takes pointers instead of Blob (pytorch#11925)
Summary: Pull Request resolved: pytorch#11925 This is step 1 in the refactoring to remove Blob::ShareExternal(), i.e. Blob would then always own its contents. ShareExternal() is for example used to pass non-owning blobs to serialization. This diff prepares removing that. Reviewed By: ezyang Differential Revision: D9884177 fbshipit-source-id: d01df9a613a4fc62e5679fe45bfc47e2c899b818
1 parent 25db86c commit 6cbf199

15 files changed

+79
-46
lines changed

caffe2/core/blob_serialization.cc

+14-9
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,16 @@ class StringSerializer : public BlobSerializerBase {
3737
* otherwise this function produces a fatal error.
3838
*/
3939
void Serialize(
40-
const Blob& blob,
40+
const void* pointer,
41+
TypeMeta typeMeta,
4142
const string& name,
4243
SerializationAcceptor acceptor) override {
43-
CAFFE_ENFORCE(blob.IsType<std::string>());
44+
CAFFE_ENFORCE(typeMeta.Match<std::string>());
4445

4546
BlobProto blob_proto;
4647
blob_proto.set_name(name);
4748
blob_proto.set_type("std::string");
48-
blob_proto.set_content(blob.template Get<std::string>());
49+
blob_proto.set_content(*static_cast<const std::string*>(pointer));
4950
acceptor(name, blob_proto.SerializeAsString());
5051
}
5152
};
@@ -70,7 +71,8 @@ void SerializeBlob(
7071
std::unique_ptr<BlobSerializerBase> serializer(
7172
CreateSerializer(blob.meta().id()));
7273
CAFFE_ENFORCE(serializer, "No known serializer for ", blob.meta().name());
73-
serializer->SerializeWithChunkSize(blob, name, acceptor, chunk_size);
74+
serializer->SerializeWithChunkSize(
75+
blob.GetRaw(), blob.meta(), name, acceptor, chunk_size);
7476
}
7577

7678
// The blob serialization member function implementation.
@@ -86,19 +88,22 @@ std::string SerializeBlob(const Blob& blob, const string& name) {
8688
}
8789

8890
void TensorSerializer::Serialize(
89-
const Blob& blob,
91+
const void* pointer,
92+
TypeMeta typeMeta,
9093
const string& name,
9194
BlobSerializerBase::SerializationAcceptor acceptor) {
92-
this->SerializeWithChunkSize(blob, name, acceptor, kDefaultChunkSize);
95+
this->SerializeWithChunkSize(
96+
pointer, typeMeta, name, acceptor, kDefaultChunkSize);
9397
}
9498

9599
void TensorSerializer::SerializeWithChunkSize(
96-
const Blob& blob,
100+
const void* pointer,
101+
TypeMeta typeMeta,
97102
const string& name,
98103
BlobSerializerBase::SerializationAcceptor acceptor,
99104
int chunk_size) {
100-
CAFFE_ENFORCE(blob.IsType<Tensor>());
101-
const auto& tensor = blob.template Get<Tensor>();
105+
CAFFE_ENFORCE(typeMeta.Match<Tensor>());
106+
const auto& tensor = *static_cast<const Tensor*>(pointer);
102107
if (chunk_size == kNoChunking) {
103108
chunk_size = tensor.size() + 1; // to account for empty tensors
104109
} else if (chunk_size == kDefaultChunkSize) {

caffe2/core/blob_serialization.h

+4-2
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,13 @@ class CAFFE2_API TensorSerializer : public BlobSerializerBase {
7070
* otherwise this function produces a fatal error.
7171
*/
7272
void Serialize(
73-
const Blob& blob,
73+
const void* pointer,
74+
TypeMeta typeMeta,
7475
const string& name,
7576
SerializationAcceptor acceptor) override;
7677
void SerializeWithChunkSize(
77-
const Blob& blob,
78+
const void* pointer,
79+
TypeMeta typeMeta,
7880
const string& name,
7981
SerializationAcceptor acceptor,
8082
int chunk_size) override;

caffe2/core/blob_serializer_base.h

+8-4
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,20 @@ class BlobSerializerBase {
4343
* serailizer can use it to save blob in several chunks
4444
* acceptor should be thread-safe
4545
*/
46-
virtual void Serialize(const Blob& blob, const std::string& name,
47-
SerializationAcceptor acceptor) = 0;
46+
virtual void Serialize(
47+
const void* pointer,
48+
TypeMeta typeMeta,
49+
const std::string& name,
50+
SerializationAcceptor acceptor) = 0;
4851

4952
virtual void SerializeWithChunkSize(
50-
const Blob& blob,
53+
const void* pointer,
54+
TypeMeta typeMeta,
5155
const std::string& name,
5256
SerializationAcceptor acceptor,
5357
int /*chunk_size*/) {
5458
// Base implementation.
55-
Serialize(blob, name, acceptor);
59+
Serialize(pointer, typeMeta, name, acceptor);
5660
}
5761
};
5862

caffe2/core/blob_test.cc

+9-6
Original file line numberDiff line numberDiff line change
@@ -51,17 +51,19 @@ class BlobTestFooSerializer : public BlobSerializerBase {
5151
* otherwise this function produces a fatal error.
5252
*/
5353
void Serialize(
54-
const Blob& blob,
54+
const void* pointer,
55+
TypeMeta typeMeta,
5556
const string& name,
5657
SerializationAcceptor acceptor) override {
57-
CAFFE_ENFORCE(blob.IsType<BlobTestFoo>());
58+
CAFFE_ENFORCE(typeMeta.Match<BlobTestFoo>());
5859

5960
BlobProto blob_proto;
6061
blob_proto.set_name(name);
6162
blob_proto.set_type("BlobTestFoo");
6263
// For simplicity we will just serialize the 4-byte content as a string.
6364
blob_proto.set_content(std::string(
64-
reinterpret_cast<const char*>(&(blob.Get<BlobTestFoo>().val)),
65+
reinterpret_cast<const char*>(
66+
&static_cast<const BlobTestFoo*>(pointer)->val),
6567
sizeof(int32_t)));
6668
acceptor(name, blob_proto.SerializeAsString());
6769
}
@@ -942,11 +944,12 @@ class DummyTypeSerializer : public BlobSerializerBase {
942944
DummyTypeSerializer() {}
943945
~DummyTypeSerializer() {}
944946
void Serialize(
945-
const Blob& blob,
947+
const void* pointer,
948+
TypeMeta typeMeta,
946949
const string& name,
947950
SerializationAcceptor acceptor) override {
948-
CAFFE_ENFORCE(blob.IsType<DummyType>());
949-
const auto& container = blob.template Get<DummyType>();
951+
CAFFE_ENFORCE(typeMeta.Match<DummyType>());
952+
const auto& container = *static_cast<const DummyType*>(pointer);
950953
for (int k = 0; k < container.n_chunks; ++k) {
951954
std::string serialized_chunk = container.serialize(name, k);
952955
acceptor(c10::str(name, kChunkIdSeparator, k), serialized_chunk);

caffe2/core/db.cc

+4-3
Original file line numberDiff line numberDiff line change
@@ -170,11 +170,12 @@ REGISTER_CAFFE2_DB(MiniDB, MiniDB);
170170
REGISTER_CAFFE2_DB(minidb, MiniDB);
171171

172172
void DBReaderSerializer::Serialize(
173-
const Blob& blob,
173+
const void* pointer,
174+
TypeMeta typeMeta,
174175
const string& name,
175176
BlobSerializerBase::SerializationAcceptor acceptor) {
176-
CAFFE_ENFORCE(blob.IsType<DBReader>());
177-
auto& reader = blob.Get<DBReader>();
177+
CAFFE_ENFORCE(typeMeta.Match<DBReader>());
178+
const auto& reader = *static_cast<const DBReader*>(pointer);
178179
DBReaderProto proto;
179180
proto.set_name(name);
180181
proto.set_source(reader.source_);

caffe2/core/db.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,8 @@ class CAFFE2_API DBReaderSerializer : public BlobSerializerBase {
295295
* otherwise this function produces a fatal error.
296296
*/
297297
void Serialize(
298-
const Blob& blob,
298+
const void* pointer,
299+
TypeMeta typeMeta,
299300
const string& name,
300301
BlobSerializerBase::SerializationAcceptor acceptor) override;
301302
};

caffe2/core/int8_serialization.cc

+4-2
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@ namespace int8 {
1111
class Int8TensorCPUSerializer : public BlobSerializerBase {
1212
public:
1313
void Serialize(
14-
const Blob& blob,
14+
const void* pointer,
15+
TypeMeta typeMeta,
1516
const string& name,
1617
SerializationAcceptor acceptor) override {
17-
const auto& tensor = blob.template Get<Int8TensorCPU>();
18+
CAFFE_ENFORCE(typeMeta.Match<Int8TensorCPU>());
19+
const auto& tensor = *static_cast<const Int8TensorCPU*>(pointer);
1820
BlobProto blob_proto;
1921
blob_proto.set_name(name);
2022
blob_proto.set_type("Int8TensorCPU");

caffe2/core/qtensor_serialization.h

+6-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ class QTensorSerializer : public BlobSerializerBase {
1717
* Serializes a Blob. Note that this blob has to contain QTensor<Context>.
1818
*/
1919
void Serialize(
20-
const Blob& blob,
20+
const void* pointer,
21+
TypeMeta typeMeta,
2122
const string& name,
2223
SerializationAcceptor acceptor) override;
2324

@@ -34,10 +35,12 @@ class QTensorDeserializer : public BlobDeserializerBase {
3435

3536
template <class Context>
3637
void QTensorSerializer<Context>::Serialize(
37-
const Blob& blob,
38+
const void* pointer,
39+
TypeMeta typeMeta,
3840
const string& name,
3941
BlobSerializerBase::SerializationAcceptor acceptor) {
40-
const auto& qtensor = blob.template Get<QTensor<Context>>();
42+
CAFFE_ENFORCE(typeMeta.Match<QTensor<Context>>());
43+
const auto& qtensor = *static_cast<const QTensor<Context>*>(pointer);
4144
BlobProto blob_proto;
4245
blob_proto.set_name(name);
4346
blob_proto.set_type(kQTensorBlobQType);

caffe2/operators/counter_ops.cc

+5-3
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,11 @@ class CounterSerializer : public BlobSerializerBase {
139139
~CounterSerializer() {}
140140

141141
void Serialize(
142-
const Blob& blob,
142+
const void* pointer,
143+
TypeMeta typeMeta,
143144
const string& name,
144145
SerializationAcceptor acceptor) override {
145-
CAFFE_ENFORCE(blob.IsType<std::unique_ptr<Counter<int64_t>>>());
146+
CAFFE_ENFORCE(typeMeta.Match<std::unique_ptr<Counter<int64_t>>>());
146147

147148
BlobProto blob_proto;
148149
blob_proto.set_name(name);
@@ -152,7 +153,8 @@ class CounterSerializer : public BlobSerializerBase {
152153
proto.set_data_type(TensorProto_DataType_INT64);
153154
proto.add_dims(1);
154155
proto.add_int64_data(
155-
blob.template Get<std::unique_ptr<Counter<int64_t>>>()->retrieve());
156+
(*static_cast<const std::unique_ptr<Counter<int64_t>>*>(pointer))
157+
->retrieve());
156158
acceptor(name, blob_proto.SerializeAsString());
157159
}
158160
};

caffe2/operators/dataset_ops.cc

+8-4
Original file line numberDiff line numberDiff line change
@@ -1419,10 +1419,13 @@ class TreeCursorSerializer : public BlobSerializerBase {
14191419
~TreeCursorSerializer() {}
14201420

14211421
void Serialize(
1422-
const Blob& blob,
1422+
const void* pointer,
1423+
TypeMeta typeMeta,
14231424
const string& name,
14241425
SerializationAcceptor acceptor) override {
1425-
auto& cursor = blob.template Get<std::unique_ptr<TreeCursor>>();
1426+
CAFFE_ENFORCE(typeMeta.Match<std::unique_ptr<TreeCursor>>());
1427+
const auto& cursor =
1428+
*static_cast<const std::unique_ptr<TreeCursor>*>(pointer);
14261429
BlobProto blob_proto;
14271430

14281431
// serialize offsets as a tensor
@@ -1495,7 +1498,8 @@ REGISTER_BLOB_DESERIALIZER(std::unique_ptr<TreeCursor>, TreeCursorDeserializer);
14951498
} // namespace
14961499

14971500
void SharedTensorVectorPtrSerializer::Serialize(
1498-
const Blob& blob,
1501+
const void* pointer,
1502+
TypeMeta typeMeta,
14991503
const string& name,
15001504
BlobSerializerBase::SerializationAcceptor acceptor) {
15011505
/* This is dummy serialize that doesn't save anything. If saving the content
@@ -1504,7 +1508,7 @@ void SharedTensorVectorPtrSerializer::Serialize(
15041508
LastNWindowCollectorOp and ReservoirSamplingOp if this serializer actually
15051509
saves the content.
15061510
*/
1507-
CAFFE_ENFORCE(blob.IsType<std::shared_ptr<std::vector<TensorCPU>>>());
1511+
CAFFE_ENFORCE(typeMeta.Match<std::shared_ptr<std::vector<TensorCPU>>>());
15081512
BlobProto blob_proto;
15091513
blob_proto.set_name(name);
15101514
blob_proto.set_type("std::shared_ptr<std::vector<TensorCPU>>");

caffe2/operators/dataset_ops.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,8 @@ using TensorVectorPtr = std::unique_ptr<std::vector<Tensor>>;
196196
class SharedTensorVectorPtrSerializer : public BlobSerializerBase {
197197
public:
198198
void Serialize(
199-
const Blob& blob,
199+
const void* pointer,
200+
TypeMeta typeMeta,
200201
const string& name,
201202
BlobSerializerBase::SerializationAcceptor acceptor) override;
202203
};

caffe2/operators/index_ops.cc

+4-2
Original file line numberDiff line numberDiff line change
@@ -348,10 +348,12 @@ class IndexSerializer : public BlobSerializerBase {
348348
~IndexSerializer() {}
349349

350350
void Serialize(
351-
const Blob& blob,
351+
const void* pointer,
352+
TypeMeta typeMeta,
352353
const string& name,
353354
SerializationAcceptor acceptor) override {
354-
auto& base = blob.template Get<std::unique_ptr<IndexBase>>();
355+
CAFFE_ENFORCE(typeMeta.Match<std::unique_ptr<IndexBase>>());
356+
const auto& base = *static_cast<const std::unique_ptr<IndexBase>*>(pointer);
355357
Blob tensor_blob;
356358
auto* tensor_out = BlobGetMutableTensor(&tensor_blob, CPU);
357359

caffe2/operators/map_ops.h

+4-3
Original file line numberDiff line numberDiff line change
@@ -195,11 +195,12 @@ class MapSerializer : public BlobSerializerBase {
195195
using MapType = typename MapTypeTraits<KEY_T, VALUE_T>::MapType;
196196

197197
void Serialize(
198-
const Blob& blob,
198+
const void* pointer,
199+
TypeMeta typeMeta,
199200
const string& name,
200201
BlobSerializerBase::SerializationAcceptor acceptor) override {
201-
CAFFE_ENFORCE(blob.IsType<MapType>());
202-
const MapType& map_data = blob.template Get<MapType>();
202+
CAFFE_ENFORCE(typeMeta.Match<MapType>());
203+
const MapType& map_data = *static_cast<const MapType*>(pointer);
203204
int64_t sz = map_data.size();
204205
Tensor key_tensor(CPU);
205206
key_tensor.Resize(sz);

caffe2/sgd/iter_op.cc

+3-2
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
namespace caffe2 {
99

1010
void MutexSerializer::Serialize(
11-
const Blob& blob,
11+
const void* pointer,
12+
TypeMeta typeMeta,
1213
const string& name,
1314
BlobSerializerBase::SerializationAcceptor acceptor) {
14-
CAFFE_ENFORCE(blob.IsType<std::unique_ptr<std::mutex>>());
15+
CAFFE_ENFORCE(typeMeta.Match<std::unique_ptr<std::mutex>>());
1516
BlobProto blob_proto;
1617
blob_proto.set_name(name);
1718
blob_proto.set_type("std::unique_ptr<std::mutex>");

caffe2/sgd/iter_op.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ class MutexSerializer : public BlobSerializerBase {
8888
* fatal error.
8989
*/
9090
void Serialize(
91-
const Blob& blob,
91+
const void* pointer,
92+
TypeMeta typeMeta,
9293
const string& name,
9394
BlobSerializerBase::SerializationAcceptor acceptor) override;
9495
};

0 commit comments

Comments
 (0)