Caffe2 - C++ API
A deep learning, cross-platform ML framework
proto_utils.h
1 #ifndef CAFFE2_UTILS_PROTO_UTILS_H_
2 #define CAFFE2_UTILS_PROTO_UTILS_H_
3 
4 #include "google/protobuf/message_lite.h"
5 #ifndef CAFFE2_USE_LITE_PROTO
6 #include "google/protobuf/message.h"
7 #endif // !CAFFE2_USE_LITE_PROTO
8 
9 #include "caffe2/core/logging.h"
10 #include "caffe2/proto/caffe2.pb.h"
11 #include "caffe2/proto/predictor_consts.pb.h"
12 
13 namespace caffe2 {
14 
15 using std::string;
16 using ::google::protobuf::MessageLite;
17 
// Returns the name string of device type `d` (a value of the caffe2
// `DeviceType` enum, passed as an int32_t) for use in blob serialization and
// deserialization. The mapping must stay in one-to-one correspondence with
// `enum DeviceType` in caffe2/proto/caffe2.proto.
//
// We cannot use the generated DeviceType_Name() here, because it is only
// available in protobuf-full, and some platforms (e.g. mobile) build against
// protobuf-lite instead.
std::string DeviceTypeName(const int32_t& d);
26 
// Common helpers that read a whole file into a string (`ReadStringFromFile`)
// and write a string out to a file (`WriteStringToFile`). Both are defined out
// of line and return a bool status (presumably true on success — confirm
// against the implementation).
bool ReadStringFromFile(const char* filename, string* str);
bool WriteStringToFile(const string& str, const char* filename);
30 
31 // Common interfaces that are supported by both lite and full protobuf.
32 bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto);
33 inline bool ReadProtoFromBinaryFile(const string filename, MessageLite* proto) {
34  return ReadProtoFromBinaryFile(filename.c_str(), proto);
35 }
36 
37 void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename);
38 inline void WriteProtoToBinaryFile(const MessageLite& proto,
39  const string& filename) {
40  return WriteProtoToBinaryFile(proto, filename.c_str());
41 }
42 
43 #ifdef CAFFE2_USE_LITE_PROTO
44 
45 inline string ProtoDebugString(const MessageLite& proto) {
46  return proto.SerializeAsString();
47 }
48 
49 // Text format MessageLite wrappers: these functions do nothing but just
50 // allowing things to compile. It will produce a runtime error if you are using
51 // MessageLite but still want text support.
52 inline bool ReadProtoFromTextFile(const char* filename, MessageLite* proto) {
53  LOG(FATAL) << "If you are running lite version, you should not be "
54  << "calling any text-format protobuffers.";
55  return false; // Just to suppress compiler warning.
56 }
57 inline bool ReadProtoFromTextFile(const string filename, MessageLite* proto) {
58  return ReadProtoFromTextFile(filename.c_str(), proto);
59 }
60 
61 inline void WriteProtoToTextFile(const MessageLite& proto,
62  const char* filename) {
63  LOG(FATAL) << "If you are running lite version, you should not be "
64  << "calling any text-format protobuffers.";
65 }
66 inline void WriteProtoToTextFile(const MessageLite& proto,
67  const string& filename) {
68  return WriteProtoToTextFile(proto, filename.c_str());
69 }
70 
71 inline bool ReadProtoFromFile(const char* filename, MessageLite* proto) {
72  return (ReadProtoFromBinaryFile(filename, proto) ||
73  ReadProtoFromTextFile(filename, proto));
74 }
75 
76 inline bool ReadProtoFromFile(const string& filename, MessageLite* proto) {
77  return ReadProtoFromFile(filename.c_str(), proto);
78 }
79 
80 #else // CAFFE2_USE_LITE_PROTO
81 
82 using ::google::protobuf::Message;
83 
84 inline string ProtoDebugString(const Message& proto) {
85  return proto.ShortDebugString();
86 }
87 
88 bool ReadProtoFromTextFile(const char* filename, Message* proto);
89 inline bool ReadProtoFromTextFile(const string filename, Message* proto) {
90  return ReadProtoFromTextFile(filename.c_str(), proto);
91 }
92 
93 void WriteProtoToTextFile(const Message& proto, const char* filename);
94 inline void WriteProtoToTextFile(const Message& proto, const string& filename) {
95  return WriteProtoToTextFile(proto, filename.c_str());
96 }
97 
98 // Read Proto from a file, letting the code figure out if it is text or binary.
99 inline bool ReadProtoFromFile(const char* filename, Message* proto) {
100  return (ReadProtoFromBinaryFile(filename, proto) ||
101  ReadProtoFromTextFile(filename, proto));
102 }
103 
104 inline bool ReadProtoFromFile(const string& filename, Message* proto) {
105  return ReadProtoFromFile(filename.c_str(), proto);
106 }
107 
108 #endif // CAFFE2_USE_LITE_PROTO
109 
110 
111 template <class IterableInputs, class IterableOutputs, class IterableArgs>
112 OperatorDef CreateOperatorDef(
113  const string& type, const string& name, const IterableInputs& inputs,
114  const IterableOutputs& outputs, const IterableArgs& args,
115  const DeviceOption& device_option, const string& engine) {
116  OperatorDef def;
117  def.set_type(type);
118  def.set_name(name);
119  for (const string& in : inputs) {
120  def.add_input(in);
121  }
122  for (const string& out : outputs) {
123  def.add_output(out);
124  }
125  for (const Argument& arg : args) {
126  def.add_arg()->CopyFrom(arg);
127  }
128  if (device_option.has_device_type()) {
129  def.mutable_device_option()->CopyFrom(device_option);
130  }
131  if (engine.size()) {
132  def.set_engine(engine);
133  }
134  return def;
135 }
136 
137 // A simplified version compared to the full CreateOperator, if you do not need
138 // to specify device option or engine.
139 template <class IterableInputs, class IterableOutputs, class IterableArgs>
140 inline OperatorDef CreateOperatorDef(
141  const string& type, const string& name, const IterableInputs& inputs,
142  const IterableOutputs& outputs, const IterableArgs& args) {
143  return CreateOperatorDef(
144  type, name, inputs, outputs, args, DeviceOption(), "");
145 }
146 
147 // A simplified version compared to the full CreateOperator, if you do not need
148 // to specify device option or engine or args.
149 template <class IterableInputs, class IterableOutputs>
150 inline OperatorDef CreateOperatorDef(
151  const string& type, const string& name, const IterableInputs& inputs,
152  const IterableOutputs& outputs) {
153  return CreateOperatorDef(type, name, inputs, outputs,
154  std::vector<Argument>(), DeviceOption(), "");
155 }
156 
157 inline bool HasArgument(const OperatorDef& def, const string& name) {
158  for (const Argument& arg : def.arg()) {
159  if (arg.name() == name) {
160  return true;
161  }
162  }
163  return false;
164 }
165 
 public:
  // Index the arguments of an operator definition or a net definition.
  // Defined out of line; presumably these populate arg_map_ by argument
  // name — confirm against the implementation.
  explicit ArgumentHelper(const OperatorDef& def);
  explicit ArgumentHelper(const NetDef& netdef);
  // Presumably reports whether an argument called `name` was indexed.
  bool HasArgument(const string& name) const;

  // Typed accessors, defined/specialized out of line. NOTE(review): the exact
  // fallback rules (e.g. when `default_value` is returned, or what happens on
  // a type mismatch) are not visible in this header — confirm against the
  // implementation.
  template <typename T>
  T GetSingleArgument(const string& name, const T& default_value) const;
  template <typename T>
  bool HasSingleArgumentOfType(const string& name) const;
  template <typename T>
  vector<T> GetRepeatedArgument(
      const string& name,
      const std::vector<T>& default_value = std::vector<T>()) const;
188 
189  template <typename MessageType>
190  MessageType GetMessageArgument(const string& name) const {
191  CAFFE_ENFORCE(arg_map_.count(name), "Cannot find parameter named ", name);
192  MessageType message;
193  if (arg_map_.at(name)->has_s()) {
194  CAFFE_ENFORCE(
195  message.ParseFromString(arg_map_.at(name)->s()),
196  "Faild to parse content from the string");
197  } else {
198  VLOG(1) << "Return empty message for parameter " << name;
199  }
200  return message;
201  }
202 
203  template <typename MessageType>
204  vector<MessageType> GetRepeatedMessageArgument(const string& name) const {
205  CAFFE_ENFORCE(arg_map_.count(name), "Cannot find parameter named ", name);
206  vector<MessageType> messages(arg_map_.at(name)->strings_size());
207  for (int i = 0; i < messages.size(); ++i) {
208  CAFFE_ENFORCE(
209  messages[i].ParseFromString(arg_map_.at(name)->strings(i)),
210  "Faild to parse content from the string");
211  }
212  return messages;
213  }
214 
 private:
  // Lookup table from argument name to the Argument it refers to. The raw
  // const pointers are presumably non-owning and point into the
  // OperatorDef/NetDef given to the constructor, which would then have to
  // outlive this helper — TODO confirm in the constructor implementation.
  CaffeMap<string, const Argument*> arg_map_;
};
218 
// Returns the argument named `name` in `def`. Defined out of line.
// NOTE(review): behavior when no such argument exists is not visible here —
// confirm against the implementation.
const Argument& GetArgument(const OperatorDef& def, const string& name);
// Reads argument `name` of `def` as a boolean flag. `def_value` (default
// false) is presumably returned when the argument is absent — confirm.
bool GetFlagArgument(
    const OperatorDef& def,
    const string& name,
    bool def_value = false);

// Returns a mutable pointer to the argument named `name` in `def`.
// `create_if_missing` presumably controls whether a new argument is appended
// when none exists (AddArgument passes true). NOTE(review): confirm what is
// returned when it is false and the argument is absent.
Argument* GetMutableArgument(
    const string& name,
    const bool create_if_missing,
    OperatorDef* def);

// Builds an Argument named `name` holding `value`. Defined out of line,
// presumably via per-type specializations for the supported T — confirm.
template <typename T>
Argument MakeArgument(const string& name, const T& value);
232 
233 template <typename T>
234 inline void AddArgument(const string& name, const T& value, OperatorDef* def) {
235  GetMutableArgument(name, true, def)->CopyFrom(MakeArgument(name, value));
236 }
237 
238 } // namespace caffe2
239 
240 #endif // CAFFE2_UTILS_PROTO_UTILS_H_
Simple registry implementation in Caffe2 that uses static variables to register object creators during program initialization time.
A helper class to index into arguments.
Definition: proto_utils.h:174