1 #include "caffe2/utils/proto_utils.h" 7 #include "google/protobuf/io/coded_stream.h" 8 #include "google/protobuf/io/zero_copy_stream_impl.h" 10 #ifndef CAFFE2_USE_LITE_PROTO 11 #include "google/protobuf/text_format.h" 12 #endif // !CAFFE2_USE_LITE_PROTO 14 #include "caffe2/core/logging.h" 16 using ::google::protobuf::Message;
17 using ::google::protobuf::MessageLite;
// Returns a std::string name for the device-type integer `d`.
// NOTE(review): the lines between the signature and the tail of the error
// message are not visible in this listing, so the actual mapping is not
// documented here.
21 std::string DeviceTypeName(
const int32_t& d) {
// NOTE(review): the hint below tells the reader to update
// TensorDeviceTypeName(), but this function is named DeviceTypeName().
// The message text looks stale - confirm and correct the function name.
33 ". If you have recently updated the caffe2.proto file to add a new " 34 "device type, did you forget to update the TensorDeviceTypeName() " 35 "function to reflect such recent changes?");
// Reads the contents of `filename` into `*str`.
//
// filename: path handed to std::ifstream.
// str:      destination string; its buffer is written through &(*str)[0].
// Returns a bool; the failure path logs at VLOG(1) together with the
// stream's rdstate().
42 bool ReadStringFromFile(
const char* filename,
string* str) {
// NOTE(review): the stream is opened with std::ios::in only (no
// std::ios::binary). On platforms that translate line endings, the size
// obtained from tellg() may differ from the number of characters read().
43 std::ifstream ifs(filename, std::ios::in);
45 VLOG(1) <<
"File cannot be opened: " << filename
46 <<
" error: " << ifs.rdstate();
// Seek to the end so that tellg() reports the file length.
49 ifs.seekg(0, std::ios::end);
// NOTE(review): tellg() returns -1 on failure; storing it in a size_t turns
// that into a huge value. No check is visible in this listing.
50 size_t n = ifs.tellg();
// Reads n characters straight into the string's storage. The resize/rewind
// steps that must precede this are not visible in this listing.
53 ifs.read(&(*str)[0], n);
// Writes `str` to `filename`.
//
// str:      data to write.
// filename: path handed to std::ofstream.
// Returns a bool; the failure path logs at VLOG(1) together with the
// stream's rdstate().
57 bool WriteStringToFile(
const string& str,
const char* filename) {
// std::ios::trunc discards any existing contents of the file.
// NOTE(review): like the reader, this opens the stream without
// std::ios::binary - confirm that text-mode translation is acceptable.
58 std::ofstream ofs(filename, std::ios::out | std::ios::trunc);
60 VLOG(1) <<
"File cannot be created: " << filename
61 <<
" error: " << ofs.rdstate();
// Lite-proto build only (CAFFE2_USE_LITE_PROTO): a protobuf
// CopyingInputStream backed by a std::ifstream opened in binary mode.
71 #ifdef CAFFE2_USE_LITE_PROTO 76 class IfstreamInputStream :
public ::google::protobuf::io::CopyingInputStream {
// Opens `filename` for binary reading. Whether the open succeeded is not
// checked here.
78 explicit IfstreamInputStream(
const string& filename)
79 : ifs_(filename.c_str(),
std::ios::in |
std::ios::binary) {}
// NOTE(review): std::ifstream closes itself on destruction, so the explicit
// close() is redundant.
80 ~IfstreamInputStream() { ifs_.close(); }
// Copies up to `size` bytes from the file into `buffer`.
// NOTE(review): not marked `override`; the return statements are not
// visible in this listing.
82 int Read(
void* buffer,
int size) {
86 ifs_.read(static_cast<char*>(buffer), size);
// Lite-proto variant: parses a binary-serialized message from `filename`
// into `*proto` and returns the result of ParseFromCodedStream().
95 bool ReadProtoFromBinaryFile(
const char* filename, MessageLite* proto) {
96 ::google::protobuf::io::CopyingInputStreamAdaptor stream(
97 new IfstreamInputStream(filename));
// Hands ownership of the heap-allocated IfstreamInputStream to the adaptor.
98 stream.SetOwnsCopyingStream(
true);
101 ::google::protobuf::io::CodedInputStream coded_stream(&stream);
// 1024LL << 20 is 1 GiB and 512LL << 20 is 512 MiB.
// NOTE(review): the two-argument SetTotalBytesLimit() form is, as far as I
// know, deprecated in newer protobuf releases - confirm against the
// protobuf version in use.
102 coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20);
103 return proto->ParseFromCodedStream(&coded_stream);
// Lite-proto variant: binary writing is not supported. Both parameters are
// unused in the visible body, which only emits a LOG(FATAL) message.
106 void WriteProtoToBinaryFile(
const MessageLite& proto,
const char* filename) {
107 LOG(FATAL) <<
"Not implemented yet.";
110 #else // CAFFE2_USE_LITE_PROTO 114 using ::google::protobuf::io::FileInputStream;
115 using ::google::protobuf::io::FileOutputStream;
116 using ::google::protobuf::io::ZeroCopyInputStream;
117 using ::google::protobuf::io::CodedInputStream;
118 using ::google::protobuf::io::ZeroCopyOutputStream;
119 using ::google::protobuf::io::CodedOutputStream;
// Parses a text-format protobuf from `filename` into `*proto`.
//
// filename: path opened read-only with open().
// proto:    message populated by TextFormat::Parse().
// The parse result is captured in `success`; the return statement is not
// visible in this listing.
121 bool ReadProtoFromTextFile(
const char* filename, Message* proto) {
122 int fd = open(filename, O_RDONLY);
// NOTE(review): any open() failure (not only a missing file) is reported
// as "File not found", and errno is not included in the message.
123 CAFFE_ENFORCE_NE(fd, -1,
"File not found: ", filename);
// NOTE(review): raw owning pointer. The matching cleanup is not visible in
// this listing; a std::unique_ptr (as used in the non-lite
// ReadProtoFromBinaryFile) would make the release unconditional.
124 FileInputStream* input =
new FileInputStream(fd);
125 bool success = google::protobuf::TextFormat::Parse(input, proto);
// Writes `proto` in protobuf text format to `filename`, creating the file
// with mode 0644 or truncating it if it already exists.
131 void WriteProtoToTextFile(
const Message& proto,
const char* filename) {
132 int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
// NOTE(review): `fd` is used without being compared against -1 (listing
// lines 132 and 133 are consecutive), unlike WriteProtoToBinaryFile, which
// reports "File cannot be created". A failed open() reaches FileOutputStream
// as an invalid descriptor.
// NOTE(review): raw owning pointer; presumably leaked if the CAFFE_ENFORCE
// below fails before the cleanup runs - confirm how CAFFE_ENFORCE reports
// failure.
133 FileOutputStream* output =
new FileOutputStream(fd);
134 CAFFE_ENFORCE(google::protobuf::TextFormat::Print(proto, output));
// Full-proto variant: parses a binary-serialized message from `filename`
// into `*proto`.
//
// filename: path opened read-only with open(); O_BINARY is added under
//           _MSC_VER.
// proto:    message populated by ParseFromCodedStream().
// The parse result is captured in `success`; the return statement is not
// visible in this listing.
139 bool ReadProtoFromBinaryFile(
const char* filename, MessageLite* proto) {
140 #if defined (_MSC_VER) // for MSC compiler binary flag needs to be specified 141 int fd = open(filename, O_RDONLY | O_BINARY);
143 int fd = open(filename, O_RDONLY);
// NOTE(review): any open() failure is reported as "File not found", and
// errno is not included (WriteProtoToBinaryFile does include it).
145 CAFFE_ENFORCE_NE(fd, -1,
"File not found: ", filename);
146 std::unique_ptr<ZeroCopyInputStream> raw_input(
new FileInputStream(fd));
// coded_input borrows raw_input through a raw pointer; it is declared after
// raw_input, so it is destroyed first.
147 std::unique_ptr<CodedInputStream> coded_input(
148 new CodedInputStream(raw_input.get()));
// 1073741824 is 1 GiB and 536870912 is 512 MiB - the same limits the
// lite-proto variant spells as shifts.
150 coded_input->SetTotalBytesLimit(1073741824, 536870912);
151 bool success = proto->ParseFromCodedStream(coded_input.get());
// Full-proto variant: serializes `proto` in binary form to `filename`,
// creating the file with mode 0644 or truncating it if it already exists.
158 void WriteProtoToBinaryFile(
const MessageLite& proto,
const char* filename) {
// NOTE(review): the reader adds O_BINARY under _MSC_VER but this open()
// does not, so on Windows the file would be written in text mode - confirm.
159 int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
// Arguments of the check on `fd`; the macro name on the preceding line is
// not visible in this listing. The message includes errno.
161 fd, -1,
"File cannot be created: ", filename,
" error number: ", errno);
162 std::unique_ptr<ZeroCopyOutputStream> raw_output(
new FileOutputStream(fd));
163 std::unique_ptr<CodedOutputStream> coded_output(
164 new CodedOutputStream(raw_output.get()));
165 CAFFE_ENFORCE(proto.SerializeToCodedStream(coded_output.get()));
// Destroys the coded stream explicitly at this point, while raw_output is
// still alive (presumably so buffered bytes reach raw_output first).
166 coded_output.reset();
// Builds the name -> Argument lookup table from an OperatorDef.
//
// The table stores pointers to the Argument objects inside `def`
// (arg_map_[...] = &arg), so `def` has to outlive this helper.
// When a name repeats, the two arguments are compared through
// SerializeAsString(); the visible lines contain both an error message for
// differing contents and a LOG(WARNING) for a plain duplicate.
171 #endif // CAFFE2_USE_LITE_PROTO 173 ArgumentHelper::ArgumentHelper(
const OperatorDef& def) {
174 for (
auto& arg : def.arg()) {
175 if (arg_map_.count(arg.name())) {
// Both arguments are serialized in full just to compare them.
176 if (arg.SerializeAsString() !=
177 arg_map_[arg.name()]->SerializeAsString()) {
181 "Found argument of the same name ",
183 "but with different contents.",
184 ProtoDebugString(def));
186 LOG(WARNING) <<
"Duplicated argument name found in operator def: " 187 << ProtoDebugString(def);
// A later argument of the same name overwrites the earlier table entry.
190 arg_map_[arg.name()] = &arg;
// Builds the name -> Argument lookup table from a NetDef.
//
// The table stores pointers to the Argument objects inside `netdef`, so
// `netdef` has to outlive this helper. The visible check requires each name
// to be absent from the table before it is inserted; this is stricter than
// the OperatorDef constructor, which has a warning path for duplicates.
194 ArgumentHelper::ArgumentHelper(
const NetDef& netdef) {
195 for (
auto& arg : netdef.arg()) {
// Arguments of the duplicate-name check; the macro name on the preceding
// line is not visible in this listing.
197 arg_map_.count(arg.name()) == 0,
198 "Duplicated argument name found in net def: ",
199 ProtoDebugString(netdef));
200 arg_map_[arg.name()] = &arg;
204 bool ArgumentHelper::HasArgument(
const string& name)
const {
205 return arg_map_.count(name);
// Sign probe used by SupportsLosslessConversion. Strings have no sign, so
// this overload lets the std::string instantiation compile and always
// answers "not negative".
inline bool IsNegativeValue(const std::string& /*unused*/) {
  return false;
}

// Sign probe for arithmetic types. Converting to double preserves the sign
// of every integer and floating-point value and avoids an always-false
// "unsigned < 0" comparison.
template <typename ValueType>
bool IsNegativeValue(const ValueType& value) {
  return static_cast<double>(value) < 0.0;
}

// Returns true when `value` survives a conversion to TargetType unchanged.
//
// InputType:  type of the stored value (may be a const reference type, as
//             produced by decltype on a range-for variable).
// TargetType: type the caller wants to receive.
//
// Two conditions must hold:
//   1. converting to TargetType and back yields the original value, and
//   2. the converted value has the same sign as the original.
// The second condition catches a negative signed value converted to an
// unsigned type of the same width (e.g. int64_t -1 -> size_t): the round
// trip alone reproduces the original bits and would wrongly report success.
template <typename InputType, typename TargetType>
bool SupportsLosslessConversion(const InputType& value) {
  const TargetType converted = static_cast<TargetType>(value);
  const bool round_trips = static_cast<InputType>(converted) == value;
  return round_trips && IsNegativeValue(converted) == IsNegativeValue(value);
}
// INSTANTIATE_GET_SINGLE_ARGUMENT(T, fieldname, enforce_lossless_conversion)
// generates, for target type T:
//   * ArgumentHelper::GetSingleArgument<T>(name, default_value): returns
//     default_value (with a VLOG(1) message) when `name` is not in arg_map_;
//     otherwise requires has_<fieldname>() and reads <fieldname>(). When
//     enforce_lossless_conversion is true, the value must also pass
//     SupportsLosslessConversion<decltype(value), T>.
//   * ArgumentHelper::HasSingleArgumentOfType<T>(name): for a present
//     argument, returns has_<fieldname>().
// The instantiations below turn the lossless check on for the integer
// targets (int8_t .. size_t) and off for float, double, bool and string.
// NOTE(review): arg_map_ is searched several times per call (count() followed
// by repeated at()); a single find() would avoid the extra lookups.
217 #define INSTANTIATE_GET_SINGLE_ARGUMENT( \ 218 T, fieldname, enforce_lossless_conversion) \ 220 T ArgumentHelper::GetSingleArgument<T>( \ 221 const string& name, const T& default_value) const { \ 222 if (arg_map_.count(name) == 0) { \ 223 VLOG(1) << "Using default parameter value " << default_value \ 224 << " for parameter " << name; \ 225 return default_value; \ 228 arg_map_.at(name)->has_##fieldname(), \ 231 " does not have the right field: expected field " #fieldname); \ 232 auto value = arg_map_.at(name)->fieldname(); \ 233 if (enforce_lossless_conversion) { \ 234 auto supportsConversion = \ 235 SupportsLosslessConversion<decltype(value), T>(value); \ 237 supportsConversion, \ 242 "cannot be represented correctly in a target type"); \ 247 bool ArgumentHelper::HasSingleArgumentOfType<T>(const string& name) const { \ 248 if (arg_map_.count(name) == 0) { \ 251 return arg_map_.at(name)->has_##fieldname(); \ 254 INSTANTIATE_GET_SINGLE_ARGUMENT(
float, f,
false)
255 INSTANTIATE_GET_SINGLE_ARGUMENT(
double, f, false)
256 INSTANTIATE_GET_SINGLE_ARGUMENT(
bool, i, false)
257 INSTANTIATE_GET_SINGLE_ARGUMENT(int8_t, i, true)
258 INSTANTIATE_GET_SINGLE_ARGUMENT(int16_t, i, true)
259 INSTANTIATE_GET_SINGLE_ARGUMENT(
int, i, true)
260 INSTANTIATE_GET_SINGLE_ARGUMENT(int64_t, i, true)
261 INSTANTIATE_GET_SINGLE_ARGUMENT(uint8_t, i, true)
262 INSTANTIATE_GET_SINGLE_ARGUMENT(uint16_t, i, true)
263 INSTANTIATE_GET_SINGLE_ARGUMENT(
size_t, i, true)
264 INSTANTIATE_GET_SINGLE_ARGUMENT(
string, s, false)
// INSTANTIATE_GET_REPEATED_ARGUMENT(T, fieldname, enforce_lossless_conversion)
// generates ArgumentHelper::GetRepeatedArgument<T>(name, default_value):
// returns default_value when `name` is not in arg_map_; otherwise walks the
// repeated <fieldname>() of the stored argument and push_back()s each
// element into `values`. When enforce_lossless_conversion is true, each
// element must pass SupportsLosslessConversion<decltype(v), T>.
// `v` is declared `const auto&`, so decltype(v) is a const reference type.
// The instantiations below turn the lossless check on for the integer
// targets (int8_t .. size_t) and off for float, double, bool and string.
// NOTE(review): no reserve() on `values` is visible before the loop.
265 #undef INSTANTIATE_GET_SINGLE_ARGUMENT 267 #define INSTANTIATE_GET_REPEATED_ARGUMENT( \ 268 T, fieldname, enforce_lossless_conversion) \ 270 vector<T> ArgumentHelper::GetRepeatedArgument<T>( \ 271 const string& name, const std::vector<T>& default_value) const { \ 272 if (arg_map_.count(name) == 0) { \ 273 return default_value; \ 276 for (const auto& v : arg_map_.at(name)->fieldname()) { \ 277 if (enforce_lossless_conversion) { \ 278 auto supportsConversion = \ 279 SupportsLosslessConversion<decltype(v), T>(v); \ 281 supportsConversion, \ 286 "cannot be represented correctly in a target type"); \ 288 values.push_back(v); \ 293 INSTANTIATE_GET_REPEATED_ARGUMENT(
float, floats,
false)
294 INSTANTIATE_GET_REPEATED_ARGUMENT(
double, floats, false)
295 INSTANTIATE_GET_REPEATED_ARGUMENT(
bool, ints, false)
296 INSTANTIATE_GET_REPEATED_ARGUMENT(int8_t, ints, true)
297 INSTANTIATE_GET_REPEATED_ARGUMENT(int16_t, ints, true)
298 INSTANTIATE_GET_REPEATED_ARGUMENT(
int, ints, true)
299 INSTANTIATE_GET_REPEATED_ARGUMENT(int64_t, ints, true)
300 INSTANTIATE_GET_REPEATED_ARGUMENT(uint8_t, ints, true)
301 INSTANTIATE_GET_REPEATED_ARGUMENT(uint16_t, ints, true)
302 INSTANTIATE_GET_REPEATED_ARGUMENT(
size_t, ints, true)
303 INSTANTIATE_GET_REPEATED_ARGUMENT(
string, strings, false)
// CAFFE2_MAKE_SINGULAR_ARGUMENT(T, fieldname) generates
// MakeArgument(name, value) for type T: the resulting Argument gets `name`
// via set_name() and `value` via set_<fieldname>().
// Instantiated for bool (i), float (f), int (i), int64_t (i) and string (s).
304 #undef INSTANTIATE_GET_REPEATED_ARGUMENT 306 #define CAFFE2_MAKE_SINGULAR_ARGUMENT(T, fieldname) \ 308 Argument MakeArgument(const string& name, const T& value) { \ 310 arg.set_name(name); \ 311 arg.set_##fieldname(value); \ 315 CAFFE2_MAKE_SINGULAR_ARGUMENT(
bool, i)
316 CAFFE2_MAKE_SINGULAR_ARGUMENT(
float, f)
317 CAFFE2_MAKE_SINGULAR_ARGUMENT(
int, i)
318 CAFFE2_MAKE_SINGULAR_ARGUMENT(int64_t, i)
319 CAFFE2_MAKE_SINGULAR_ARGUMENT(
string, s)
// MakeArgument overload for protobuf messages: stores the serialized bytes
// of `value` (SerializeAsString()) in the Argument's `s` field.
// The lines that create `arg`, set its name and return it are not visible
// in this listing.
320 #undef CAFFE2_MAKE_SINGULAR_ARGUMENT 323 Argument MakeArgument(
const string& name,
const MessageLite& value) {
326 arg.set_s(value.SerializeAsString());
// CAFFE2_MAKE_REPEATED_ARGUMENT(T, fieldname) generates
// MakeArgument(name, vector<T>): the resulting Argument gets `name` via
// set_name() and every element of `value`, in order, via add_<fieldname>().
// Instantiated for float (floats), int (ints), int64_t (ints) and
// string (strings).
330 #define CAFFE2_MAKE_REPEATED_ARGUMENT(T, fieldname) \ 332 Argument MakeArgument(const string& name, const vector<T>& value) { \ 334 arg.set_name(name); \ 335 for (const auto& v : value) { \ 336 arg.add_##fieldname(v); \ 341 CAFFE2_MAKE_REPEATED_ARGUMENT(
float, floats)
342 CAFFE2_MAKE_REPEATED_ARGUMENT(
int, ints)
343 CAFFE2_MAKE_REPEATED_ARGUMENT(int64_t, ints)
344 CAFFE2_MAKE_REPEATED_ARGUMENT(
string, strings)
// Looks up the argument called `name` in `def` by a linear scan over
// def.arg() and returns a const reference to it.
// When no argument matches, the visible lines build an error message that
// includes ProtoDebugString(def); the macro that raises it is not visible
// in this listing.
// The returned reference points into `def`, so it is only valid while `def`
// is alive.
345 #undef CAFFE2_MAKE_REPEATED_ARGUMENT 347 const Argument& GetArgument(
const OperatorDef& def,
const string& name) {
348 for (
const Argument& arg : def.arg()) {
349 if (arg.name() == name) {
356 " does not exist in operator ",
357 ProtoDebugString(def));
// Reads a boolean flag from the argument of `def` whose name matches
// `name`, found by a linear scan over def.arg().
// A matching argument must have its `i` field set (has_i()); otherwise the
// "Can't parse argument as bool" message is produced.
// NOTE(review): the remaining parameters and the return statements are not
// visible in this listing, so the result for a missing argument is not
// documented here.
360 bool GetFlagArgument(
361 const OperatorDef& def,
364 for (
const Argument& arg : def.arg()) {
365 if (arg.name() == name) {
367 arg.has_i(),
"Can't parse argument as bool: ", ProtoDebugString(arg));
// Returns a mutable pointer to the argument of `def` called `name`.
// The scan is by index so that mutable_arg(i) can be returned for the first
// match. When nothing matches and create_if_missing is true, a new argument
// is appended with add_arg().
// NOTE(review): part of the parameter list and the tail of the function
// (naming the new argument, the not-found return value) are not visible in
// this listing.
374 Argument* GetMutableArgument(
376 const bool create_if_missing,
378 for (
int i = 0; i < def->arg_size(); ++i) {
379 if (def->arg(i).name() == name) {
380 return def->mutable_arg(i);
384 if (create_if_missing) {
385 Argument* arg = def->add_arg();
Simple registry implementation in Caffe2 that uses static variables to register object creators durin...