Caffe2 - C++ API
A deep learning, cross platform ML framework
blob_serialization.h
1 #ifndef CAFFE2_CORE_BLOB_SERIALIZATION_H_
2 #define CAFFE2_CORE_BLOB_SERIALIZATION_H_
3 
4 #include <limits>
5 #include <future>
6 
7 #include <google/protobuf/repeated_field.h>
8 
9 #include "caffe2/core/blob.h"
10 #include "caffe2/core/blob_serializer_base.h"
11 #include "caffe2/core/tensor.h"
12 #include "caffe2/core/typeid.h"
13 #include "caffe2/core/types.h"
14 
15 CAFFE2_DECLARE_int(caffe2_tensor_chunk_size);
16 
17 namespace caffe2 {
18 
19 constexpr auto kTensorBlobType = "Tensor";
20 // String used to separate chunk id from the blob name when storing in DB
21 constexpr auto kChunkIdSeparator = "#%";
22 
23 // The Blob serialization registry and serializer creator functions.
24 CAFFE_DECLARE_TYPED_REGISTRY(
25  BlobSerializerRegistry,
26  CaffeTypeId,
27  BlobSerializerBase);
28 #define REGISTER_BLOB_SERIALIZER(id, ...) \
29  CAFFE_REGISTER_TYPED_CLASS(BlobSerializerRegistry, id, __VA_ARGS__)
30 // Creates an operator with the given operator definition.
31 inline unique_ptr<BlobSerializerBase> CreateSerializer(CaffeTypeId id) {
32  return BlobSerializerRegistry()->Create(id);
33 }
34 
41 template <class Context>
43  public:
44  TensorSerializer() : context_() {}
45  ~TensorSerializer() {}
50  void Serialize(
51  const Blob& blob,
52  const string& name,
53  SerializationAcceptor acceptor) override;
54  void SerializeWithChunkSize(
55  const Blob& blob,
56  const string& name,
57  SerializationAcceptor acceptor,
58  int chunk_size) override;
59 
60  void Serialize(const Tensor<Context>& tensor, const string& name,
61  TensorProto* proto, size_t chunkBegin, int32_t chunkSize);
62 
63  private:
64  // A utility function to store the device context detauls.
65  void StoreDeviceDetail(const Tensor<Context>& input, TensorProto* proto);
66  Context context_;
67 };
68 
74  public:
75  virtual ~BlobDeserializerBase() {}
76 
77  // Deserializes from a BlobProto object.
78  virtual void Deserialize(const BlobProto& proto, Blob* blob) = 0;
79 };
80 
81 CAFFE_DECLARE_REGISTRY(BlobDeserializerRegistry, BlobDeserializerBase);
82 #define REGISTER_BLOB_DESERIALIZER(name, ...) \
83  CAFFE_REGISTER_CLASS(BlobDeserializerRegistry, name, __VA_ARGS__)
84 // Creates an operator with the given operator definition.
85 inline unique_ptr<BlobDeserializerBase> CreateDeserializer(const string& type) {
86  return BlobDeserializerRegistry()->Create(type);
87 }
88 
97 template <class Context>
99  public:
100  void Deserialize(const BlobProto& proto, Blob* blob) override;
101  void Deserialize(const TensorProto& proto, Tensor<Context>* tensor);
102 };
103 
105 // Implementations
107 
108 namespace detail {
109 template <typename SrcType, typename DstType, class Context>
110 inline void CopyToProtoAsIs(
111  const size_t size,
112  const SrcType* src,
113  google::protobuf::RepeatedField<DstType>* field,
114  Context* context) {
115  static_assert(
116  sizeof(SrcType) == sizeof(DstType),
117  "The source type and dest type cannot be copied as-is. Did "
118  "you mean CopyToProtoWithCast?");
119  field->Reserve(size);
120  for (int i = 0; i < size; ++i) {
121  field->Add(0);
122  }
123  context->template Copy<SrcType, Context, CPUContext>(
124  size, src, reinterpret_cast<SrcType*>(field->mutable_data()));
125  // Make sure that we finish the copy into the protobuf.
126  context->FinishDeviceComputation();
127 }
128 
129 template <typename SrcType, typename DstType, class Context>
130 inline void CopyToProtoWithCast(
131  const size_t size,
132  const SrcType* src,
133  google::protobuf::RepeatedField<DstType>* field,
134  Context* context) {
135  // TODO: we are having one unnecessary copy here if the context is already
136  // CPUContext. Remove it if it is performance critical.
137  unique_ptr<SrcType[]> buffer(new SrcType[size]);
138  context->template Copy<SrcType, Context, CPUContext>(
139  size, src, buffer.get());
140  context->FinishDeviceComputation();
141  field->Reserve(size);
142  for (int i = 0; i < size; ++i) {
143  field->Add(static_cast<DstType>(buffer[i]));
144  }
145 }
146 
147 template <typename SrcType, typename DstType, class Context>
148 inline void CopyFromProtoAsIs(
149  const size_t size,
150  const google::protobuf::RepeatedField<SrcType>& field,
151  DstType* dst,
152  Context* context) {
153  static_assert(
154  sizeof(SrcType) == sizeof(DstType),
155  "The source type and dest type cannot be copied as-is. Did "
156  "you mean CopyFromProtoWithCast?");
157  CAFFE_ENFORCE_EQ(size, field.size(), "Incorrect proto field size.");
158  context->template Copy<DstType, CPUContext, Context>(
159  size, reinterpret_cast<const DstType*>(field.data()), dst);
160 }
161 
162 template <typename SrcType, typename DstType, class Context>
163 inline void CopyFromProtoWithCast(
164  const size_t size,
165  const google::protobuf::RepeatedField<SrcType>& field,
166  DstType* dst,
167  Context* context) {
168  CAFFE_ENFORCE_EQ(size, field.size(), "Incorrect proto field size.");
169  // TODO: we are having one unnecessary copy here if the context is already
170  // CPUContext. Remove it if it is performance critical.
171  unique_ptr<DstType[]> buffer(new DstType[size]);
172  const SrcType* src = field.data();
173  for (int i = 0; i < size; ++i) {
174  buffer[i] = static_cast<DstType>(src[i]);
175  }
176  context->template Copy<DstType, CPUContext, Context>(size, buffer.get(), dst);
177 }
178 
179 } // namespace detail
180 
181 template <class Context>
183  const Blob& blob,
184  const string& name,
185  BlobSerializerBase::SerializationAcceptor acceptor) {
186  this->SerializeWithChunkSize(blob, name, acceptor, kDefaultChunkSize);
187 }
188 
189 template <class Context>
191  const Blob& blob,
192  const string& name,
193  BlobSerializerBase::SerializationAcceptor acceptor,
194  int chunk_size) {
195  CAFFE_ENFORCE(blob.IsType<Tensor<Context>>());
196  const auto& tensor = blob.template Get<Tensor<Context>>();
197  if (chunk_size == kNoChunking) {
198  chunk_size = tensor.size() + 1; // to account for empty tensors
199  } else if (chunk_size == kDefaultChunkSize) {
200  chunk_size = FLAGS_caffe2_tensor_chunk_size;
201  }
202 
203 #ifndef __ANDROID__
204  std::vector<std::future<void>> futures;
205 #endif
206 
207  VLOG(1) << "Serializing blob " << name;
208  // Serialize whole vector. If vector is empty, it's shape still needs to be
209  // serialized in empty proto
210  for (size_t chunkBegin = 0;
211  chunkBegin < std::max(tensor.size(), static_cast<TIndex>(1));
212  chunkBegin += chunk_size) {
213  VLOG(2) << "Starting a chunk at " << chunkBegin;
214  auto task = [&](size_t chunkStart) {
215  BlobProto blob_proto;
216  blob_proto.set_name(name);
217  blob_proto.set_type(kTensorBlobType);
218  TensorProto& proto = *blob_proto.mutable_tensor();
219  proto.set_name(name);
220  this->Serialize(
221  tensor, name, blob_proto.mutable_tensor(), chunkStart, chunk_size);
222  acceptor(
223  MakeString(
224  name, kChunkIdSeparator, chunkStart / chunk_size),
225  blob_proto.SerializeAsString());
226  };
227 #ifndef __ANDROID__
228  if (tensor.size() > chunk_size) {
229  futures.emplace_back(std::async(std::launch::async, task, chunkBegin));
230  } else {
231  // Sync mode for small tensors
232  task(chunkBegin);
233  }
234 #else
235  // Since Android does not have std::future, we will always do sync mode
236  task(chunkBegin);
237 #endif
238  }
239 
240 #ifndef __ANDROID__
241  for (auto& fut : futures) {
242  fut.get();
243  }
244 #endif
245 }
246 
247 template <class Context>
249  const Tensor<Context>& input, const string& name,
250  TensorProto* proto_ptr, size_t chunkBegin, int32_t chunkSize) {
251  CAFFE_ENFORCE(
252  chunkBegin <= input.size(),
253  "Chunk begin is out of tensor: ",
254  chunkBegin,
255  ' ',
256  input.size());
257  if (chunkBegin + chunkSize > input.size()) {
258  chunkSize = input.size() - chunkBegin;
259  }
260 
261  CAFFE_ENFORCE(
262  input.raw_data() || chunkSize == 0,
263  "The input does not have data input yet. This is probably because you "
264  "created a tensor of non-zero shape but never filled its data via "
265  "mutable_data() calls. This means that it makes no sense to serialize "
266  "the tensor content.");
267 
268  TensorProto& proto = *proto_ptr;
269  proto.mutable_segment()->set_begin(chunkBegin);
270  proto.mutable_segment()->set_end(chunkBegin + chunkSize);
271 
272  for (int i = 0; i < input.ndim(); ++i) {
273  proto.add_dims(input.dim(i));
274  }
275  const TensorProto::DataType data_type = TypeMetaToDataType(input.meta());
276  proto.set_data_type(data_type);
277  StoreDeviceDetail(input, &proto);
278 
279  // A lot of copypaste is error prone. Should we create a macro for this?
280  switch (data_type) {
281  case TensorProto_DataType_FLOAT:
282  detail::CopyToProtoAsIs(
283  chunkSize,
284  input.template data<float>() + chunkBegin,
285  proto.mutable_float_data(),
286  &this->context_);
287  break;
288  case TensorProto_DataType_INT32:
289  detail::CopyToProtoAsIs(
290  chunkSize,
291  input.template data<int>() + chunkBegin,
292  proto.mutable_int32_data(),
293  &this->context_);
294  break;
295  case TensorProto_DataType_BYTE:
296  LOG(FATAL) << "This should not happen. When serializing, "
297  "BYTE is deprecated and moved to UINT8.";
298  break;
299  case TensorProto_DataType_STRING:
300  {
301  proto.mutable_string_data()->Reserve(chunkSize);
302  const string* content = input.template data<string>();
303  for (int i = chunkBegin; i < chunkBegin + chunkSize; ++i) {
304  proto.add_string_data(content[i]);
305  }
306  break;
307  }
308  case TensorProto_DataType_BOOL:
309  detail::CopyToProtoWithCast(
310  chunkSize,
311  input.template data<bool>() + chunkBegin,
312  proto.mutable_int32_data(),
313  &this->context_);
314  break;
315  case TensorProto_DataType_UINT8:
316  detail::CopyToProtoWithCast(
317  chunkSize,
318  input.template data<uint8_t>() + chunkBegin,
319  proto.mutable_int32_data(),
320  &this->context_);
321  break;
322  case TensorProto_DataType_INT8:
323  detail::CopyToProtoWithCast(
324  chunkSize,
325  input.template data<int8_t>() + chunkBegin,
326  proto.mutable_int32_data(),
327  &this->context_);
328  break;
329  case TensorProto_DataType_UINT16:
330  detail::CopyToProtoWithCast(
331  chunkSize,
332  input.template data<uint16_t>() + chunkBegin,
333  proto.mutable_int32_data(),
334  &this->context_);
335  break;
336  case TensorProto_DataType_INT16:
337  detail::CopyToProtoWithCast(
338  chunkSize,
339  input.template data<int16_t>() + chunkBegin,
340  proto.mutable_int32_data(),
341  &this->context_);
342  break;
343  case TensorProto_DataType_INT64:
344  detail::CopyToProtoAsIs(
345  chunkSize,
346  input.template data<int64_t>() + chunkBegin,
347  proto.mutable_int64_data(),
348  &this->context_);
349  break;
350  case TensorProto_DataType_FLOAT16:
351  detail::CopyToProtoWithCast(
352  chunkSize,
353  reinterpret_cast<const uint16_t*>(input.template data<float16>()) +
354  chunkBegin,
355  proto.mutable_int32_data(),
356  &this->context_);
357  break;
358  case TensorProto_DataType_DOUBLE:
359  detail::CopyToProtoAsIs(
360  chunkSize,
361  input.template data<double>() + chunkBegin,
362  proto.mutable_double_data(),
363  &this->context_);
364  break;
365  case TensorProto_DataType_UNDEFINED:
366  LOG(FATAL) << "TensorSerializer does not have a serialization "
367  "implementation for " << input.meta().name();
368  break;
369  // Note: we intentially do not provide "default:" so if any new data types
370  // are added, the compiler should warn the user to add the case here.
371  }
372 }
373 
374 template <class Context>
376  const BlobProto& blob_proto,
377  Blob* blob) {
378  Deserialize(blob_proto.tensor(), blob->GetMutable<Tensor<Context>>());
379 }
380 
381 template <class Context>
383  const TensorProto& proto,
384  Tensor<Context>* tensor) {
385  // We create a local context for deserializing. Since Caffe2 contexts are
386  // usually lightweighted, this should not involve too much overhead.
387  Context context(proto.device_detail());
388  context.SwitchToDevice(0);
389  vector<TIndex> dims;
390  for (const TIndex d : proto.dims()) {
391  dims.push_back(d);
392  }
393  tensor->Resize(dims);
394 
395  int64_t chunkBegin = 0;
396  auto chunkEnd = tensor->size();
397  if (proto.has_segment()) {
398  chunkBegin = proto.segment().begin();
399  chunkEnd = proto.segment().end();
400  }
401  CAFFE_ENFORCE(
402  0 <= chunkBegin && chunkBegin <= chunkEnd && chunkEnd <= tensor->size(),
403  "Invalid chunk ",
404  chunkBegin,
405  ' ',
406  chunkEnd,
407  " with total tensor size ",
408  tensor->size());
409  auto chunkSize = chunkEnd - chunkBegin;
410 
411  switch (proto.data_type()) {
412  case TensorProto_DataType_FLOAT:
413  detail::CopyFromProtoAsIs(
414  chunkSize,
415  proto.float_data(),
416  tensor->template mutable_data<float>() + chunkBegin,
417  &context);
418  break;
419  case TensorProto_DataType_INT32:
420  detail::CopyFromProtoAsIs(
421  chunkSize,
422  proto.int32_data(),
423  tensor->template mutable_data<int>() + chunkBegin,
424  &context);
425  break;
426  case TensorProto_DataType_BYTE:
427  // Since BYTE stores the data in a string field instead of a repreated
428  // field we will have it special cased.
429  CAFFE_ENFORCE_EQ(
430  chunkSize, proto.byte_data().size(), "Incorrect proto field size.");
431  context.template Copy<uint8_t, Context, CPUContext>(
432  chunkSize,
433  reinterpret_cast<const uint8_t*>(proto.byte_data().data()),
434  tensor->template mutable_data<uint8_t>() + chunkBegin);
435  break;
436  case TensorProto_DataType_STRING:
437  // Special handing of string because it is a non-fundamental type.
438  {
439  string* content = tensor->template mutable_data<string>();
440  for (int i = 0; i < chunkSize; ++i) {
441  content[i + chunkBegin] = proto.string_data(i);
442  }
443  }
444  break;
445  case TensorProto_DataType_BOOL:
446  detail::CopyFromProtoWithCast(
447  chunkSize,
448  proto.int32_data(),
449  tensor->template mutable_data<bool>() + chunkBegin,
450  &context);
451  break;
452  case TensorProto_DataType_UINT8:
453  detail::CopyFromProtoWithCast(
454  chunkSize,
455  proto.int32_data(),
456  tensor->template mutable_data<uint8_t>() + chunkBegin,
457  &context);
458  break;
459  case TensorProto_DataType_INT8:
460  detail::CopyFromProtoWithCast(
461  chunkSize,
462  proto.int32_data(),
463  tensor->template mutable_data<int8_t>() + chunkBegin,
464  &context);
465  break;
466  case TensorProto_DataType_UINT16:
467  detail::CopyFromProtoWithCast(
468  chunkSize,
469  proto.int32_data(),
470  tensor->template mutable_data<uint16_t>() + chunkBegin,
471  &context);
472  break;
473  case TensorProto_DataType_INT16:
474  detail::CopyFromProtoWithCast(
475  chunkSize,
476  proto.int32_data(),
477  tensor->template mutable_data<int16_t>() + chunkBegin,
478  &context);
479  break;
480  case TensorProto_DataType_INT64:
481  detail::CopyFromProtoAsIs(
482  chunkSize,
483  proto.int64_data(),
484  tensor->template mutable_data<int64_t>() + chunkBegin,
485  &context);
486  break;
487  case TensorProto_DataType_FLOAT16:
488  detail::CopyFromProtoWithCast(
489  chunkSize,
490  proto.int32_data(),
491  reinterpret_cast<uint16_t*>(
492  tensor->template mutable_data<float16>()) +
493  chunkBegin,
494  &context);
495  break;
496  case TensorProto_DataType_DOUBLE:
497  detail::CopyFromProtoAsIs(
498  chunkSize,
499  proto.double_data(),
500  tensor->template mutable_data<double>() + chunkBegin,
501  &context);
502  break;
503  case TensorProto_DataType_UNDEFINED:
504  CAFFE_THROW("Cannot deserialize from a TensorProto UNDEFINED data type.");
505  }
506  context.FinishDeviceComputation();
507 }
508 
509 } // namespace caffe2
510 
511 #endif // CAFFE2_CORE_BLOB_SERIALIZATION_H_
int ndim() const
Returns the number of dimensions of the data.
Definition: tensor.h:530
TIndex dim(const int i) const
Returns the i-th dimension of the tensor.
Definition: tensor.h:606
const void * raw_data() const
Returns a const raw void* pointer of the underlying storage.
Definition: tensor.h:421
TIndex size() const
Returns the size (i.e. the number of elements) of the tensor.
Definition: tensor.h:534
const TypeMeta & meta() const
Returns the TypeMeta object associated with the current data type.
Definition: tensor.h:585
BlobDeserializerBase is an abstract class that deserializes a blob from a BlobProto or a TensorProto.
void Serialize(const Blob &blob, const string &name, SerializationAcceptor acceptor) override
Serializes a Blob.
TensorSerializer is the serializer for Tensors.
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Definition: tensor.h:73
Simple registry implementation in Caffe2 that uses static variables to register object creators during program initialization.
Blob is a general container that hosts a typed pointer.
Definition: blob.h:25
BlobSerializerBase is an abstract class that serializes a blob to a string.
T * GetMutable(bool *is_new_object=nullptr)
Gets a mutable pointer to the stored object.
Definition: blob.h:96
void Resize(Ts... dim_source)
Resizes a tensor.
Definition: tensor.h:263
bool IsType() const
Checks if the content stored in the blob is of type T.
Definition: blob.h:56
TensorDeserializer is the deserializer for Tensors.