#ifndef CAFFE2_CORE_BLOB_SERIALIZATION_H_
#define CAFFE2_CORE_BLOB_SERIALIZATION_H_

#include <algorithm>
#include <future>
#include <memory>
#include <vector>

#include <google/protobuf/repeated_field.h>

#include "caffe2/core/blob.h"
#include "caffe2/core/blob_serializer_base.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/typeid.h"
#include "caffe2/core/types.h"

CAFFE2_DECLARE_int(caffe2_tensor_chunk_size);

namespace caffe2 {

constexpr auto kTensorBlobType = "Tensor";
// String used to separate chunk id from the blob name when storing in DB.
constexpr auto kChunkIdSeparator = "#%";

// The blob serialization registry: serializers are keyed by the type id of
// the object stored in the blob.
CAFFE_DECLARE_TYPED_REGISTRY(
    BlobSerializerRegistry,
    CaffeTypeId,
    BlobSerializerBase);
#define REGISTER_BLOB_SERIALIZER(id, ...) \
  CAFFE_REGISTER_TYPED_CLASS(BlobSerializerRegistry, id, __VA_ARGS__)

// Creates a serializer registered for the given type id.
inline unique_ptr<BlobSerializerBase> CreateSerializer(CaffeTypeId id) {
  return BlobSerializerRegistry()->Create(id);
}

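// Usage sketch (not part of the API declared here; it mirrors how blob
// serialization typically looks up a serializer by type id -- the exact call
// site in blob.h may differ):
//
//   Blob blob;  // assume it holds a TensorCPU
//   auto serializer = CreateSerializer(blob.meta().id());
//   CAFFE_ENFORCE(serializer, "No known serializer for ", blob.meta().name());
//   serializer->Serialize(blob, "my_blob", acceptor);
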
/**
 * @brief TensorSerializer is the serializer for Tensors.
 */
template <class Context>
class TensorSerializer : public BlobSerializerBase {
 public:
  TensorSerializer() : context_() {}
  ~TensorSerializer() {}
  // Serializes a Blob holding a Tensor<Context> into one or more BlobProtos.
  void Serialize(
      const Blob& blob,
      const string& name,
      SerializationAcceptor acceptor) override;
  void SerializeWithChunkSize(
      const Blob& blob,
      const string& name,
      SerializationAcceptor acceptor,
      int chunk_size) override;
  // Serializes a single chunk of the tensor into the given proto.
  void Serialize(
      const Tensor<Context>& tensor,
      const string& name,
      TensorProto* proto,
      size_t chunkBegin,
      int32_t chunkSize);

 private:
  // Stores the device information of the input tensor into the proto.
  void StoreDeviceDetail(const Tensor<Context>& input, TensorProto* proto);
  Context context_;
};

/**
 * @brief BlobDeserializerBase is an abstract class that deserializes a blob
 * from a BlobProto or a TensorProto.
 */
class BlobDeserializerBase {
 public:
  virtual ~BlobDeserializerBase() {}
  // Deserializes from a BlobProto object.
  virtual void Deserialize(const BlobProto& proto, Blob* blob) = 0;
};

CAFFE_DECLARE_REGISTRY(BlobDeserializerRegistry, BlobDeserializerBase);
#define REGISTER_BLOB_DESERIALIZER(name, ...) \
  CAFFE_REGISTER_CLASS(BlobDeserializerRegistry, name, __VA_ARGS__)

// Creates a deserializer based on the serialized blob's type string.
inline unique_ptr<BlobDeserializerBase> CreateDeserializer(const string& type) {
  return BlobDeserializerRegistry()->Create(type);
}

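// Usage sketch (assumed typical pattern): the BlobProto's type string selects
// the registered deserializer, which then fills the destination blob.
//
//   BlobProto blob_proto;
//   CAFFE_ENFORCE(blob_proto.ParseFromString(serialized));
//   auto deserializer = CreateDeserializer(blob_proto.type());
//   deserializer->Deserialize(blob_proto, &blob);
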
/**
 * @brief TensorDeserializer is the deserializer for Tensors.
 */
template <class Context>
class TensorDeserializer : public BlobDeserializerBase {
 public:
  void Deserialize(const BlobProto& proto, Blob* blob) override;
  // Deserializes directly from a TensorProto into an existing tensor.
  void Deserialize(const TensorProto& proto, Tensor<Context>* tensor);
};

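// Round-trip sketch for the two classes above (a minimal example, assuming the
// SerializationAcceptor signature of (key, serialized BlobProto string) and
// the TensorCPU alias for Tensor<CPUContext>):
//
//   Blob blob;
//   auto* t = blob.GetMutable<TensorCPU>();
//   t->Resize(4);
//   for (int i = 0; i < 4; ++i) t->mutable_data<float>()[i] = i;
//
//   std::vector<std::string> chunks;
//   TensorSerializer<CPUContext> serializer;
//   serializer.Serialize(blob, "my_blob",
//       [&](const std::string& /*key*/, const std::string& data) {
//         chunks.push_back(data);
//       });
//
//   BlobProto proto;
//   CAFFE_ENFORCE(proto.ParseFromString(chunks[0]));
//   Blob restored;
//   TensorDeserializer<CPUContext>().Deserialize(proto, &restored);
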
namespace detail {

template <typename SrcType, typename DstType, class Context>
inline void CopyToProtoAsIs(
    const size_t size,
    const SrcType* src,
    google::protobuf::RepeatedField<DstType>* field,
    Context* context) {
  static_assert(
      sizeof(SrcType) == sizeof(DstType),
      "The source type and dest type cannot be copied as-is. Did "
      "you mean CopyToProtoWithCast?");
  // Grow the repeated field to the target size, then copy the raw bytes
  // straight into its storage.
  field->Reserve(size);
  for (int i = 0; i < size; ++i) {
    field->Add(0);
  }
  context->template Copy<SrcType, Context, CPUContext>(
      size, src, reinterpret_cast<SrcType*>(field->mutable_data()));
  // Make sure the device-to-CPU copy has finished before the proto is used.
  context->FinishDeviceComputation();
}

template <typename SrcType, typename DstType, class Context>
inline void CopyToProtoWithCast(
    const size_t size,
    const SrcType* src,
    google::protobuf::RepeatedField<DstType>* field,
    Context* context) {
  // Copy into a CPU-side buffer first, then cast element by element into the
  // proto field.
  unique_ptr<SrcType[]> buffer(new SrcType[size]);
  context->template Copy<SrcType, Context, CPUContext>(size, src, buffer.get());
  context->FinishDeviceComputation();
  field->Reserve(size);
  for (int i = 0; i < size; ++i) {
    field->Add(static_cast<DstType>(buffer[i]));
  }
}

template <typename SrcType, typename DstType, class Context>
inline void CopyFromProtoAsIs(
    const size_t size,
    const google::protobuf::RepeatedField<SrcType>& field,
    DstType* dst,
    Context* context) {
  static_assert(
      sizeof(SrcType) == sizeof(DstType),
      "The source type and dest type cannot be copied as-is. Did "
      "you mean CopyFromProtoWithCast?");
  CAFFE_ENFORCE_EQ(size, field.size(), "Incorrect proto field size.");
  context->template Copy<DstType, CPUContext, Context>(
      size, reinterpret_cast<const DstType*>(field.data()), dst);
}

template <typename SrcType, typename DstType, class Context>
inline void CopyFromProtoWithCast(
    const size_t size,
    const google::protobuf::RepeatedField<SrcType>& field,
    DstType* dst,
    Context* context) {
  CAFFE_ENFORCE_EQ(size, field.size(), "Incorrect proto field size.");
  // Cast into a CPU-side buffer first, then copy into the destination context.
  unique_ptr<DstType[]> buffer(new DstType[size]);
  const SrcType* src = field.data();
  for (int i = 0; i < size; ++i) {
    buffer[i] = static_cast<DstType>(src[i]);
  }
  context->template Copy<DstType, CPUContext, Context>(size, buffer.get(), dst);
}

} // namespace detail

template <class Context>
void TensorSerializer<Context>::Serialize(
    const Blob& blob,
    const string& name,
    BlobSerializerBase::SerializationAcceptor acceptor) {
  this->SerializeWithChunkSize(blob, name, acceptor, kDefaultChunkSize);
}

template <class Context>
void TensorSerializer<Context>::SerializeWithChunkSize(
    const Blob& blob,
    const string& name,
    BlobSerializerBase::SerializationAcceptor acceptor,
    int chunk_size) {
  CAFFE_ENFORCE(blob.IsType<Tensor<Context>>());
  const auto& tensor = blob.template Get<Tensor<Context>>();
  if (chunk_size == kNoChunking) {
    chunk_size = tensor.size() + 1; // to account for empty tensors
  } else if (chunk_size == kDefaultChunkSize) {
    chunk_size = FLAGS_caffe2_tensor_chunk_size;
  }

  std::vector<std::future<void>> futures;

  VLOG(1) << "Serializing blob " << name;
  // Serialize the whole tensor. Even if the tensor is empty, its shape still
  // needs to be serialized into an (otherwise empty) proto, hence the max()
  // in the loop bound below.
  for (size_t chunkBegin = 0;
       chunkBegin < std::max(tensor.size(), static_cast<TIndex>(1));
       chunkBegin += chunk_size) {
    VLOG(2) << "Starting a chunk at " << chunkBegin;
    auto task = [&](size_t chunkStart) {
      BlobProto blob_proto;
      blob_proto.set_name(name);
      blob_proto.set_type(kTensorBlobType);
      TensorProto& proto = *blob_proto.mutable_tensor();
      proto.set_name(name);
      this->Serialize(
          tensor, name, blob_proto.mutable_tensor(), chunkStart, chunk_size);
      acceptor(
          MakeString(name, kChunkIdSeparator, chunkStart / chunk_size),
          blob_proto.SerializeAsString());
    };
    if (tensor.size() > chunk_size) {
      // Tensors that span multiple chunks are serialized asynchronously.
      futures.emplace_back(std::async(std::launch::async, task, chunkBegin));
    } else {
      // Small tensors are serialized synchronously.
      task(chunkBegin);
    }
  }

  for (auto& fut : futures) {
    fut.get();
  }
}

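// For example, serializing a blob named "my_blob" that holds 2,500,000 floats
// with chunk_size = 1,000,000 invokes the acceptor three times, with keys
// "my_blob#%0", "my_blob#%1" and "my_blob#%2"; each value is a serialized
// BlobProto whose tensor segment covers the corresponding element range.
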
template <class Context>
void TensorSerializer<Context>::Serialize(
    const Tensor<Context>& input,
    const string& /*name*/,
    TensorProto* proto_ptr,
    size_t chunkBegin,
    int32_t chunkSize) {
  CAFFE_ENFORCE(
      chunkBegin <= input.size(),
      "Chunk begin is out of tensor: ",
      chunkBegin,
      " vs ",
      input.size());
  if (chunkBegin + chunkSize > input.size()) {
    chunkSize = input.size() - chunkBegin;
  }

  CAFFE_ENFORCE(
      input.raw_data() || chunkSize == 0,
      "The input does not have data input yet. This is probably because you "
      "created a tensor of non-zero shape but never filled its data via "
      "mutable_data() calls. This means that it makes no sense to serialize "
      "the tensor content.");

  TensorProto& proto = *proto_ptr;
  proto.mutable_segment()->set_begin(chunkBegin);
  proto.mutable_segment()->set_end(chunkBegin + chunkSize);

  for (int i = 0; i < input.ndim(); ++i) {
    proto.add_dims(input.dim(i));
  }
  const TensorProto::DataType data_type = TypeMetaToDataType(input.meta());
  proto.set_data_type(data_type);
  StoreDeviceDetail(input, &proto);

  switch (data_type) {
    case TensorProto_DataType_FLOAT:
      detail::CopyToProtoAsIs(
          chunkSize,
          input.template data<float>() + chunkBegin,
          proto.mutable_float_data(),
          &this->context_);
      break;
    case TensorProto_DataType_INT32:
      detail::CopyToProtoAsIs(
          chunkSize,
          input.template data<int>() + chunkBegin,
          proto.mutable_int32_data(),
          &this->context_);
      break;
    case TensorProto_DataType_BYTE:
      LOG(FATAL) << "This should not happen. When serializing, "
                    "BYTE is deprecated and moved to UINT8.";
      break;
    case TensorProto_DataType_STRING: {
      proto.mutable_string_data()->Reserve(chunkSize);
      const string* content = input.template data<string>();
      for (int i = chunkBegin; i < chunkBegin + chunkSize; ++i) {
        proto.add_string_data(content[i]);
      }
      break;
    }
    case TensorProto_DataType_BOOL:
      detail::CopyToProtoWithCast(
          chunkSize,
          input.template data<bool>() + chunkBegin,
          proto.mutable_int32_data(),
          &this->context_);
      break;
    case TensorProto_DataType_UINT8:
      detail::CopyToProtoWithCast(
          chunkSize,
          input.template data<uint8_t>() + chunkBegin,
          proto.mutable_int32_data(),
          &this->context_);
      break;
    case TensorProto_DataType_INT8:
      detail::CopyToProtoWithCast(
          chunkSize,
          input.template data<int8_t>() + chunkBegin,
          proto.mutable_int32_data(),
          &this->context_);
      break;
    case TensorProto_DataType_UINT16:
      detail::CopyToProtoWithCast(
          chunkSize,
          input.template data<uint16_t>() + chunkBegin,
          proto.mutable_int32_data(),
          &this->context_);
      break;
    case TensorProto_DataType_INT16:
      detail::CopyToProtoWithCast(
          chunkSize,
          input.template data<int16_t>() + chunkBegin,
          proto.mutable_int32_data(),
          &this->context_);
      break;
    case TensorProto_DataType_INT64:
      detail::CopyToProtoAsIs(
          chunkSize,
          input.template data<int64_t>() + chunkBegin,
          proto.mutable_int64_data(),
          &this->context_);
      break;
    case TensorProto_DataType_FLOAT16:
      detail::CopyToProtoWithCast(
          chunkSize,
          reinterpret_cast<const uint16_t*>(input.template data<float16>()) +
              chunkBegin,
          proto.mutable_int32_data(),
          &this->context_);
      break;
    case TensorProto_DataType_DOUBLE:
      detail::CopyToProtoAsIs(
          chunkSize,
          input.template data<double>() + chunkBegin,
          proto.mutable_double_data(),
          &this->context_);
      break;
    case TensorProto_DataType_UNDEFINED:
      LOG(FATAL) << "TensorSerializer does not have a serialization "
                    "implementation for " << input.meta().name();
      break;
    // Note: no default case, so the compiler warns if a newly added
    // TensorProto data type is not handled here.
  }
}

template <class Context>
void TensorDeserializer<Context>::Deserialize(
    const BlobProto& blob_proto,
    Blob* blob) {
  Deserialize(blob_proto.tensor(), blob->GetMutable<Tensor<Context>>());
}

template <class Context>
void TensorDeserializer<Context>::Deserialize(
    const TensorProto& proto,
    Tensor<Context>* tensor) {
  // Create a local context for deserializing; contexts are lightweight, so
  // this should not involve much overhead.
  Context context(proto.device_detail());
  context.SwitchToDevice(0);
  vector<TIndex> dims;
  for (const TIndex d : proto.dims()) {
    dims.push_back(d);
  }
  tensor->Resize(dims);

  int64_t chunkBegin = 0;
  auto chunkEnd = tensor->size();
  if (proto.has_segment()) {
    chunkBegin = proto.segment().begin();
    chunkEnd = proto.segment().end();
  }
  CAFFE_ENFORCE(
      0 <= chunkBegin && chunkBegin <= chunkEnd && chunkEnd <= tensor->size(),
      "Invalid chunk ",
      chunkBegin,
      " to ",
      chunkEnd,
      " with total tensor size ",
      tensor->size());
  auto chunkSize = chunkEnd - chunkBegin;

  switch (proto.data_type()) {
    case TensorProto_DataType_FLOAT:
      detail::CopyFromProtoAsIs(
          chunkSize,
          proto.float_data(),
          tensor->template mutable_data<float>() + chunkBegin,
          &context);
      break;
    case TensorProto_DataType_INT32:
      detail::CopyFromProtoAsIs(
          chunkSize,
          proto.int32_data(),
          tensor->template mutable_data<int>() + chunkBegin,
          &context);
      break;
    case TensorProto_DataType_BYTE:
      CAFFE_ENFORCE_EQ(
          chunkSize, proto.byte_data().size(), "Incorrect proto field size.");
      // The byte data lives in a CPU-side string field, so copy from
      // CPUContext into the tensor's context.
      context.template Copy<uint8_t, CPUContext, Context>(
          chunkSize,
          reinterpret_cast<const uint8_t*>(proto.byte_data().data()),
          tensor->template mutable_data<uint8_t>() + chunkBegin);
      break;
    case TensorProto_DataType_STRING: {
      CAFFE_ENFORCE_EQ(
          chunkSize, proto.string_data().size(), "Incorrect proto field size.");
      string* content = tensor->template mutable_data<string>();
      for (int i = 0; i < chunkSize; ++i) {
        content[i + chunkBegin] = proto.string_data(i);
      }
      break;
    }
    case TensorProto_DataType_BOOL:
      detail::CopyFromProtoWithCast(
          chunkSize,
          proto.int32_data(),
          tensor->template mutable_data<bool>() + chunkBegin,
          &context);
      break;
    case TensorProto_DataType_UINT8:
      detail::CopyFromProtoWithCast(
          chunkSize,
          proto.int32_data(),
          tensor->template mutable_data<uint8_t>() + chunkBegin,
          &context);
      break;
    case TensorProto_DataType_INT8:
      detail::CopyFromProtoWithCast(
          chunkSize,
          proto.int32_data(),
          tensor->template mutable_data<int8_t>() + chunkBegin,
          &context);
      break;
    case TensorProto_DataType_UINT16:
      detail::CopyFromProtoWithCast(
          chunkSize,
          proto.int32_data(),
          tensor->template mutable_data<uint16_t>() + chunkBegin,
          &context);
      break;
    case TensorProto_DataType_INT16:
      detail::CopyFromProtoWithCast(
          chunkSize,
          proto.int32_data(),
          tensor->template mutable_data<int16_t>() + chunkBegin,
          &context);
      break;
    case TensorProto_DataType_INT64:
      detail::CopyFromProtoAsIs(
          chunkSize,
          proto.int64_data(),
          tensor->template mutable_data<int64_t>() + chunkBegin,
          &context);
      break;
    case TensorProto_DataType_FLOAT16:
      detail::CopyFromProtoWithCast(
          chunkSize,
          proto.int32_data(),
          reinterpret_cast<uint16_t*>(
              tensor->template mutable_data<float16>()) + chunkBegin,
          &context);
      break;
    case TensorProto_DataType_DOUBLE:
      detail::CopyFromProtoAsIs(
          chunkSize,
          proto.double_data(),
          tensor->template mutable_data<double>() + chunkBegin,
          &context);
      break;
    case TensorProto_DataType_UNDEFINED:
      CAFFE_THROW("Cannot deserialize from a TensorProto UNDEFINED data type.");
  }
  context.FinishDeviceComputation();
}

} // namespace caffe2

#endif // CAFFE2_CORE_BLOB_SERIALIZATION_H_