Caffe2 - C++ API
A deep learning, cross platform ML framework
proto_utils.cc
#include "caffe2/utils/proto_utils.h"

#include <fcntl.h>
#include <cerrno>
#include <fstream>
#include <type_traits>

#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"

#ifndef CAFFE2_USE_LITE_PROTO
#include "google/protobuf/text_format.h"
#endif // !CAFFE2_USE_LITE_PROTO

#include "caffe2/core/logging.h"
15 
16 using ::google::protobuf::Message;
17 using ::google::protobuf::MessageLite;
18 
19 namespace caffe2 {
20 
21 std::string DeviceTypeName(const int32_t& d) {
22  switch (d) {
23  case CPU:
24  return "CPU";
25  case CUDA:
26  return "CUDA";
27  case MKLDNN:
28  return "MKLDNN";
29  default:
30  CAFFE_THROW(
31  "Unknown device: ",
32  d,
33  ". If you have recently updated the caffe2.proto file to add a new "
34  "device type, did you forget to update the TensorDeviceTypeName() "
35  "function to reflect such recent changes?");
36  // The below code won't run but is needed to suppress some compiler
37  // warnings.
38  return "";
39  }
40 };
41 
42 bool ReadStringFromFile(const char* filename, string* str) {
43  std::ifstream ifs(filename, std::ios::in);
44  if (!ifs) {
45  VLOG(1) << "File cannot be opened: " << filename
46  << " error: " << ifs.rdstate();
47  return false;
48  }
49  ifs.seekg(0, std::ios::end);
50  size_t n = ifs.tellg();
51  str->resize(n);
52  ifs.seekg(0);
53  ifs.read(&(*str)[0], n);
54  return true;
55 }
56 
57 bool WriteStringToFile(const string& str, const char* filename) {
58  std::ofstream ofs(filename, std::ios::out | std::ios::trunc);
59  if (!ofs.is_open()) {
60  VLOG(1) << "File cannot be created: " << filename
61  << " error: " << ofs.rdstate();
62  return false;
63  }
64  ofs << str;
65  return true;
66 }
67 
68 // IO-specific proto functions: we will deal with the protocol buffer lite and
69 // full versions differently.
70 
71 #ifdef CAFFE2_USE_LITE_PROTO
72 
73 // Lite runtime.
74 
75 namespace {
76 class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream {
77  public:
78  explicit IfstreamInputStream(const string& filename)
79  : ifs_(filename.c_str(), std::ios::in | std::ios::binary) {}
80  ~IfstreamInputStream() { ifs_.close(); }
81 
82  int Read(void* buffer, int size) {
83  if (!ifs_) {
84  return -1;
85  }
86  ifs_.read(static_cast<char*>(buffer), size);
87  return ifs_.gcount();
88  }
89 
90  private:
91  std::ifstream ifs_;
92 };
93 } // namespace
94 
95 bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto) {
96  ::google::protobuf::io::CopyingInputStreamAdaptor stream(
97  new IfstreamInputStream(filename));
98  stream.SetOwnsCopyingStream(true);
99  // Total bytes hard limit / warning limit are set to 1GB and 512MB
100  // respectively.
101  ::google::protobuf::io::CodedInputStream coded_stream(&stream);
102  coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20);
103  return proto->ParseFromCodedStream(&coded_stream);
104 }
105 
106 void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename) {
107  LOG(FATAL) << "Not implemented yet.";
108 }
109 
110 #else // CAFFE2_USE_LITE_PROTO
111 
112 // Full protocol buffer.
113 
114 using ::google::protobuf::io::FileInputStream;
115 using ::google::protobuf::io::FileOutputStream;
116 using ::google::protobuf::io::ZeroCopyInputStream;
117 using ::google::protobuf::io::CodedInputStream;
118 using ::google::protobuf::io::ZeroCopyOutputStream;
119 using ::google::protobuf::io::CodedOutputStream;
120 
121 bool ReadProtoFromTextFile(const char* filename, Message* proto) {
122  int fd = open(filename, O_RDONLY);
123  CAFFE_ENFORCE_NE(fd, -1, "File not found: ", filename);
124  FileInputStream* input = new FileInputStream(fd);
125  bool success = google::protobuf::TextFormat::Parse(input, proto);
126  delete input;
127  close(fd);
128  return success;
129 }
130 
131 void WriteProtoToTextFile(const Message& proto, const char* filename) {
132  int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
133  FileOutputStream* output = new FileOutputStream(fd);
134  CAFFE_ENFORCE(google::protobuf::TextFormat::Print(proto, output));
135  delete output;
136  close(fd);
137 }
138 
139 bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto) {
140 #if defined (_MSC_VER) // for MSC compiler binary flag needs to be specified
141  int fd = open(filename, O_RDONLY | O_BINARY);
142 #else
143  int fd = open(filename, O_RDONLY);
144 #endif
145  CAFFE_ENFORCE_NE(fd, -1, "File not found: ", filename);
146  std::unique_ptr<ZeroCopyInputStream> raw_input(new FileInputStream(fd));
147  std::unique_ptr<CodedInputStream> coded_input(
148  new CodedInputStream(raw_input.get()));
149  // A hack to manually allow using very large protocol buffers.
150  coded_input->SetTotalBytesLimit(1073741824, 536870912);
151  bool success = proto->ParseFromCodedStream(coded_input.get());
152  coded_input.reset();
153  raw_input.reset();
154  close(fd);
155  return success;
156 }
157 
158 void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename) {
159  int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
160  CAFFE_ENFORCE_NE(
161  fd, -1, "File cannot be created: ", filename, " error number: ", errno);
162  std::unique_ptr<ZeroCopyOutputStream> raw_output(new FileOutputStream(fd));
163  std::unique_ptr<CodedOutputStream> coded_output(
164  new CodedOutputStream(raw_output.get()));
165  CAFFE_ENFORCE(proto.SerializeToCodedStream(coded_output.get()));
166  coded_output.reset();
167  raw_output.reset();
168  close(fd);
169 }
170 
171 #endif // CAFFE2_USE_LITE_PROTO
172 
173 ArgumentHelper::ArgumentHelper(const OperatorDef& def) {
174  for (auto& arg : def.arg()) {
175  if (arg_map_.count(arg.name())) {
176  if (arg.SerializeAsString() !=
177  arg_map_[arg.name()]->SerializeAsString()) {
178  // If there are two arguments of the same name but different contents,
179  // we will throw an error.
180  CAFFE_THROW(
181  "Found argument of the same name ",
182  arg.name(),
183  "but with different contents.",
184  ProtoDebugString(def));
185  } else {
186  LOG(WARNING) << "Duplicated argument name found in operator def: "
187  << ProtoDebugString(def);
188  }
189  }
190  arg_map_[arg.name()] = &arg;
191  }
192 }
193 
194 ArgumentHelper::ArgumentHelper(const NetDef& netdef) {
195  for (auto& arg : netdef.arg()) {
196  CAFFE_ENFORCE(
197  arg_map_.count(arg.name()) == 0,
198  "Duplicated argument name found in net def: ",
199  ProtoDebugString(netdef));
200  arg_map_[arg.name()] = &arg;
201  }
202 }
203 
204 bool ArgumentHelper::HasArgument(const string& name) const {
205  return arg_map_.count(name);
206 }
207 
namespace {
// Sign test used by SupportsLosslessConversion. The std::true_type overload is
// selected only for signed input types, so comparing against zero is always
// meaningful there. Other types (unsigned, floating-point targets, strings)
// never need the check.
template <typename T>
bool IsNegativeValue(const T& value, std::true_type /* may_flip_sign */) {
  return value < 0;
}
template <typename T>
bool IsNegativeValue(const T& /* value */, std::false_type /* may_flip_sign */) {
  return false;
}

// Helper function to verify that conversion from InputType to TargetType won't
// lose any significant bit, nor change the sign of the value.
//
// InputType may be a reference type (e.g. `const int64&` when deduced via
// decltype from a range-for variable); it is decayed before use.
template <typename InputType, typename TargetType>
bool SupportsLosslessConversion(const InputType& value) {
  using Input = typename std::decay<InputType>::type;
  const TargetType converted = static_cast<TargetType>(value);
  if (static_cast<Input>(converted) != value) {
    return false;
  }
  // A negative signed value converted to an unsigned type at least as wide
  // (e.g. int64 -1 -> size_t) round-trips exactly but changes meaning, so the
  // round-trip check alone would wrongly accept it.
  using MayFlipSign = std::integral_constant<
      bool,
      std::is_signed<Input>::value && std::is_unsigned<TargetType>::value>;
  return !IsNegativeValue(value, MayFlipSign());
}
} // namespace
216 
217 #define INSTANTIATE_GET_SINGLE_ARGUMENT( \
218  T, fieldname, enforce_lossless_conversion) \
219  template <> \
220  T ArgumentHelper::GetSingleArgument<T>( \
221  const string& name, const T& default_value) const { \
222  if (arg_map_.count(name) == 0) { \
223  VLOG(1) << "Using default parameter value " << default_value \
224  << " for parameter " << name; \
225  return default_value; \
226  } \
227  CAFFE_ENFORCE( \
228  arg_map_.at(name)->has_##fieldname(), \
229  "Argument ", \
230  name, \
231  " does not have the right field: expected field " #fieldname); \
232  auto value = arg_map_.at(name)->fieldname(); \
233  if (enforce_lossless_conversion) { \
234  auto supportsConversion = \
235  SupportsLosslessConversion<decltype(value), T>(value); \
236  CAFFE_ENFORCE( \
237  supportsConversion, \
238  "Value", \
239  value, \
240  " of argument ", \
241  name, \
242  "cannot be represented correctly in a target type"); \
243  } \
244  return value; \
245  } \
246  template <> \
247  bool ArgumentHelper::HasSingleArgumentOfType<T>(const string& name) const { \
248  if (arg_map_.count(name) == 0) { \
249  return false; \
250  } \
251  return arg_map_.at(name)->has_##fieldname(); \
252  }
253 
254 INSTANTIATE_GET_SINGLE_ARGUMENT(float, f, false)
255 INSTANTIATE_GET_SINGLE_ARGUMENT(double, f, false)
256 INSTANTIATE_GET_SINGLE_ARGUMENT(bool, i, false)
257 INSTANTIATE_GET_SINGLE_ARGUMENT(int8_t, i, true)
258 INSTANTIATE_GET_SINGLE_ARGUMENT(int16_t, i, true)
259 INSTANTIATE_GET_SINGLE_ARGUMENT(int, i, true)
260 INSTANTIATE_GET_SINGLE_ARGUMENT(int64_t, i, true)
261 INSTANTIATE_GET_SINGLE_ARGUMENT(uint8_t, i, true)
262 INSTANTIATE_GET_SINGLE_ARGUMENT(uint16_t, i, true)
263 INSTANTIATE_GET_SINGLE_ARGUMENT(size_t, i, true)
264 INSTANTIATE_GET_SINGLE_ARGUMENT(string, s, false)
265 #undef INSTANTIATE_GET_SINGLE_ARGUMENT
266 
267 #define INSTANTIATE_GET_REPEATED_ARGUMENT( \
268  T, fieldname, enforce_lossless_conversion) \
269  template <> \
270  vector<T> ArgumentHelper::GetRepeatedArgument<T>( \
271  const string& name, const std::vector<T>& default_value) const { \
272  if (arg_map_.count(name) == 0) { \
273  return default_value; \
274  } \
275  vector<T> values; \
276  for (const auto& v : arg_map_.at(name)->fieldname()) { \
277  if (enforce_lossless_conversion) { \
278  auto supportsConversion = \
279  SupportsLosslessConversion<decltype(v), T>(v); \
280  CAFFE_ENFORCE( \
281  supportsConversion, \
282  "Value", \
283  v, \
284  " of argument ", \
285  name, \
286  "cannot be represented correctly in a target type"); \
287  } \
288  values.push_back(v); \
289  } \
290  return values; \
291  }
292 
293 INSTANTIATE_GET_REPEATED_ARGUMENT(float, floats, false)
294 INSTANTIATE_GET_REPEATED_ARGUMENT(double, floats, false)
295 INSTANTIATE_GET_REPEATED_ARGUMENT(bool, ints, false)
296 INSTANTIATE_GET_REPEATED_ARGUMENT(int8_t, ints, true)
297 INSTANTIATE_GET_REPEATED_ARGUMENT(int16_t, ints, true)
298 INSTANTIATE_GET_REPEATED_ARGUMENT(int, ints, true)
299 INSTANTIATE_GET_REPEATED_ARGUMENT(int64_t, ints, true)
300 INSTANTIATE_GET_REPEATED_ARGUMENT(uint8_t, ints, true)
301 INSTANTIATE_GET_REPEATED_ARGUMENT(uint16_t, ints, true)
302 INSTANTIATE_GET_REPEATED_ARGUMENT(size_t, ints, true)
303 INSTANTIATE_GET_REPEATED_ARGUMENT(string, strings, false)
304 #undef INSTANTIATE_GET_REPEATED_ARGUMENT
305 
306 #define CAFFE2_MAKE_SINGULAR_ARGUMENT(T, fieldname) \
307 template <> \
308 Argument MakeArgument(const string& name, const T& value) { \
309  Argument arg; \
310  arg.set_name(name); \
311  arg.set_##fieldname(value); \
312  return arg; \
313 }
314 
315 CAFFE2_MAKE_SINGULAR_ARGUMENT(bool, i)
316 CAFFE2_MAKE_SINGULAR_ARGUMENT(float, f)
317 CAFFE2_MAKE_SINGULAR_ARGUMENT(int, i)
318 CAFFE2_MAKE_SINGULAR_ARGUMENT(int64_t, i)
319 CAFFE2_MAKE_SINGULAR_ARGUMENT(string, s)
320 #undef CAFFE2_MAKE_SINGULAR_ARGUMENT
321 
322 template <>
323 Argument MakeArgument(const string& name, const MessageLite& value) {
324  Argument arg;
325  arg.set_name(name);
326  arg.set_s(value.SerializeAsString());
327  return arg;
328 }
329 
330 #define CAFFE2_MAKE_REPEATED_ARGUMENT(T, fieldname) \
331 template <> \
332 Argument MakeArgument(const string& name, const vector<T>& value) { \
333  Argument arg; \
334  arg.set_name(name); \
335  for (const auto& v : value) { \
336  arg.add_##fieldname(v); \
337  } \
338  return arg; \
339 }
340 
341 CAFFE2_MAKE_REPEATED_ARGUMENT(float, floats)
342 CAFFE2_MAKE_REPEATED_ARGUMENT(int, ints)
343 CAFFE2_MAKE_REPEATED_ARGUMENT(int64_t, ints)
344 CAFFE2_MAKE_REPEATED_ARGUMENT(string, strings)
345 #undef CAFFE2_MAKE_REPEATED_ARGUMENT
346 
347 const Argument& GetArgument(const OperatorDef& def, const string& name) {
348  for (const Argument& arg : def.arg()) {
349  if (arg.name() == name) {
350  return arg;
351  }
352  }
353  CAFFE_THROW(
354  "Argument named ",
355  name,
356  " does not exist in operator ",
357  ProtoDebugString(def));
358 }
359 
360 bool GetFlagArgument(
361  const OperatorDef& def,
362  const string& name,
363  bool def_value) {
364  for (const Argument& arg : def.arg()) {
365  if (arg.name() == name) {
366  CAFFE_ENFORCE(
367  arg.has_i(), "Can't parse argument as bool: ", ProtoDebugString(arg));
368  return arg.i();
369  }
370  }
371  return def_value;
372 }
373 
374 Argument* GetMutableArgument(
375  const string& name,
376  const bool create_if_missing,
377  OperatorDef* def) {
378  for (int i = 0; i < def->arg_size(); ++i) {
379  if (def->arg(i).name() == name) {
380  return def->mutable_arg(i);
381  }
382  }
383  // If no argument of the right name is found...
384  if (create_if_missing) {
385  Argument* arg = def->add_arg();
386  arg->set_name(name);
387  return arg;
388  } else {
389  return nullptr;
390  }
391 }
392 
393 } // namespace caffe2
Definition: types.h:57
Simple registry implementation in Caffe2 that uses static variables to register object creators durin...