1 #ifndef CAFFE2_UTILS_PROTO_UTILS_H_ 2 #define CAFFE2_UTILS_PROTO_UTILS_H_ 4 #include "google/protobuf/message_lite.h" 5 #ifndef CAFFE2_USE_LITE_PROTO 6 #include "google/protobuf/message.h" 7 #endif // !CAFFE2_USE_LITE_PROTO 9 #include "caffe2/core/logging.h" 10 #include "caffe2/proto/caffe2.pb.h" 11 #include "caffe2/proto/predictor_consts.pb.h" 16 using ::google::protobuf::MessageLite;
// Declared here; the definition is not visible in this header.
// NOTE(review): presumably returns a printable name for the device type code
// `d` (a caffe2 DeviceType value) — confirm against the implementation.
std::string DeviceTypeName(const int32_t& d);
28 bool ReadStringFromFile(
const char* filename,
string* str);
29 bool WriteStringToFile(
const string& str,
const char* filename);
32 bool ReadProtoFromBinaryFile(
const char* filename, MessageLite* proto);
33 inline bool ReadProtoFromBinaryFile(
const string filename, MessageLite* proto) {
34 return ReadProtoFromBinaryFile(filename.c_str(), proto);
37 void WriteProtoToBinaryFile(
const MessageLite& proto,
const char* filename);
38 inline void WriteProtoToBinaryFile(
const MessageLite& proto,
39 const string& filename) {
40 return WriteProtoToBinaryFile(proto, filename.c_str());
43 #ifdef CAFFE2_USE_LITE_PROTO 45 inline string ProtoDebugString(
const MessageLite& proto) {
46 return proto.SerializeAsString();
52 inline bool ReadProtoFromTextFile(
const char* filename, MessageLite* proto) {
53 LOG(FATAL) <<
"If you are running lite version, you should not be " 54 <<
"calling any text-format protobuffers.";
57 inline bool ReadProtoFromTextFile(
const string filename, MessageLite* proto) {
58 return ReadProtoFromTextFile(filename.c_str(), proto);
61 inline void WriteProtoToTextFile(
const MessageLite& proto,
62 const char* filename) {
63 LOG(FATAL) <<
"If you are running lite version, you should not be " 64 <<
"calling any text-format protobuffers.";
66 inline void WriteProtoToTextFile(
const MessageLite& proto,
67 const string& filename) {
68 return WriteProtoToTextFile(proto, filename.c_str());
71 inline bool ReadProtoFromFile(
const char* filename, MessageLite* proto) {
72 return (ReadProtoFromBinaryFile(filename, proto) ||
73 ReadProtoFromTextFile(filename, proto));
76 inline bool ReadProtoFromFile(
const string& filename, MessageLite* proto) {
77 return ReadProtoFromFile(filename.c_str(), proto);
80 #else // CAFFE2_USE_LITE_PROTO 82 using ::google::protobuf::Message;
84 inline string ProtoDebugString(
const Message& proto) {
85 return proto.ShortDebugString();
88 bool ReadProtoFromTextFile(
const char* filename, Message* proto);
89 inline bool ReadProtoFromTextFile(
const string filename, Message* proto) {
90 return ReadProtoFromTextFile(filename.c_str(), proto);
93 void WriteProtoToTextFile(
const Message& proto,
const char* filename);
94 inline void WriteProtoToTextFile(
const Message& proto,
const string& filename) {
95 return WriteProtoToTextFile(proto, filename.c_str());
99 inline bool ReadProtoFromFile(
const char* filename, Message* proto) {
100 return (ReadProtoFromBinaryFile(filename, proto) ||
101 ReadProtoFromTextFile(filename, proto));
104 inline bool ReadProtoFromFile(
const string& filename, Message* proto) {
105 return ReadProtoFromFile(filename.c_str(), proto);
108 #endif // CAFFE2_USE_LITE_PROTO 111 template <
class IterableInputs,
class IterableOutputs,
class IterableArgs>
112 OperatorDef CreateOperatorDef(
113 const string& type,
const string& name,
const IterableInputs& inputs,
114 const IterableOutputs& outputs,
const IterableArgs& args,
115 const DeviceOption& device_option,
const string& engine) {
119 for (
const string& in : inputs) {
122 for (
const string& out : outputs) {
125 for (
const Argument& arg : args) {
126 def.add_arg()->CopyFrom(arg);
128 if (device_option.has_device_type()) {
129 def.mutable_device_option()->CopyFrom(device_option);
132 def.set_engine(engine);
139 template <
class IterableInputs,
class IterableOutputs,
class IterableArgs>
140 inline OperatorDef CreateOperatorDef(
141 const string& type,
const string& name,
const IterableInputs& inputs,
142 const IterableOutputs& outputs,
const IterableArgs& args) {
143 return CreateOperatorDef(
144 type, name, inputs, outputs, args, DeviceOption(),
"");
149 template <
class IterableInputs,
class IterableOutputs>
150 inline OperatorDef CreateOperatorDef(
151 const string& type,
const string& name,
const IterableInputs& inputs,
152 const IterableOutputs& outputs) {
153 return CreateOperatorDef(type, name, inputs, outputs,
154 std::vector<Argument>(), DeviceOption(),
"");
157 inline bool HasArgument(
const OperatorDef& def,
const string& name) {
158 for (
const Argument& arg : def.arg()) {
159 if (arg.name() == name) {
178 bool HasArgument(
const string& name)
const;
180 template <
typename T>
181 T GetSingleArgument(
const string& name,
const T& default_value)
const;
182 template <
typename T>
183 bool HasSingleArgumentOfType(
const string& name)
const;
184 template <
typename T>
185 vector<T> GetRepeatedArgument(
187 const std::vector<T>& default_value = std::vector<T>())
const;
189 template <
typename MessageType>
190 MessageType GetMessageArgument(
const string& name)
const {
191 CAFFE_ENFORCE(arg_map_.count(name),
"Cannot find parameter named ", name);
193 if (arg_map_.at(name)->has_s()) {
195 message.ParseFromString(arg_map_.at(name)->s()),
196 "Faild to parse content from the string");
198 VLOG(1) <<
"Return empty message for parameter " << name;
203 template <
typename MessageType>
204 vector<MessageType> GetRepeatedMessageArgument(
const string& name)
const {
205 CAFFE_ENFORCE(arg_map_.count(name),
"Cannot find parameter named ", name);
206 vector<MessageType> messages(arg_map_.at(name)->strings_size());
207 for (
int i = 0; i < messages.size(); ++i) {
209 messages[i].ParseFromString(arg_map_.at(name)->strings(i)),
210 "Faild to parse content from the string");
216 CaffeMap<string, const Argument*> arg_map_;
219 const Argument& GetArgument(
const OperatorDef& def,
const string& name);
220 bool GetFlagArgument(
221 const OperatorDef& def,
223 bool def_value =
false);
225 Argument* GetMutableArgument(
227 const bool create_if_missing,
230 template <
typename T>
231 Argument MakeArgument(
const string& name,
const T& value);
233 template <
typename T>
234 inline void AddArgument(
const string& name,
const T& value, OperatorDef* def) {
235 GetMutableArgument(name,
true, def)->CopyFrom(MakeArgument(name, value));
240 #endif // CAFFE2_UTILS_PROTO_UTILS_H_ Simple registry implementation in Caffe2 that uses static variables to register object creators durin...
// A helper class to index into arguments.