3 from caffe2.proto
import caffe2_pb2
4 from google.protobuf.message
import DecodeError, Message
5 from google.protobuf
import text_format
12 if sys.version_info > (3,):
18 def CaffeBlobToNumpyArray(blob):
21 return (np.asarray(blob.data, dtype=np.float32)
22 .reshape(blob.num, blob.channels, blob.height, blob.width))
25 return (np.asarray(blob.data, dtype=np.float32)
26 .reshape(blob.shape.dim))
29 def Caffe2TensorToNumpyArray(tensor):
30 if tensor.data_type == caffe2_pb2.TensorProto.FLOAT:
32 tensor.float_data, dtype=np.float32).reshape(tensor.dims)
33 elif tensor.data_type == caffe2_pb2.TensorProto.DOUBLE:
35 tensor.double_data, dtype=np.float64).reshape(tensor.dims)
36 elif tensor.data_type == caffe2_pb2.TensorProto.INT32:
38 tensor.double_data, dtype=np.int).reshape(tensor.dims)
42 "Tensor data type not supported yet: " + str(tensor.data_type))
45 def NumpyArrayToCaffe2Tensor(arr, name=None):
46 tensor = caffe2_pb2.TensorProto()
47 tensor.dims.extend(arr.shape)
50 if arr.dtype == np.float32:
51 tensor.data_type = caffe2_pb2.TensorProto.FLOAT
52 tensor.float_data.extend(list(arr.flatten().astype(float)))
53 elif arr.dtype == np.float64:
54 tensor.data_type = caffe2_pb2.TensorProto.DOUBLE
55 tensor.double_data.extend(list(arr.flatten().astype(np.float64)))
56 elif arr.dtype == np.int:
57 tensor.data_type = caffe2_pb2.TensorProto.INT32
58 tensor.int32_data.extend(list(arr.flatten().astype(np.int)))
62 "Numpy data type not supported yet: " + str(arr.dtype))
67 """Makes an argument based on the value type.""" 68 argument = caffe2_pb2.Argument()
70 iterable = isinstance(value, collections.Iterable)
72 if isinstance(value, np.ndarray):
73 value = value.flatten().tolist()
74 elif isinstance(value, np.generic):
76 value = np.asscalar(value)
78 if type(value)
is float:
80 elif type(value)
is int
or type(value)
is bool
or type(value)
is long:
84 elif isinstance(value, basestring):
85 argument.s = (value
if type(value)
is bytes
86 else value.encode(
'utf-8'))
87 elif isinstance(value, Message):
88 argument.s = value.SerializeToString()
89 elif iterable
and all(type(v)
in [float, np.float_]
for v
in value):
90 argument.floats.extend(value)
91 elif iterable
and all(type(v)
in [int, bool, long, np.int_]
for v
in value):
92 argument.ints.extend(value)
93 elif iterable
and all(isinstance(v, basestring)
for v
in value):
94 argument.strings.extend([
95 (v
if type(v)
is bytes
else v.encode(
'utf-8'))
for v
in value])
96 elif iterable
and all(isinstance(v, Message)
for v
in value):
97 argument.strings.extend([v.SerializeToString()
for v
in value])
100 "Unknown argument type: key=%s value=%s, value type=%s" %
101 (key, str(value), str(type(value)))
107 """Reads a protobuffer with the given proto class. 110 cls: a protobuffer class. 111 s: a string of either binary or text protobuffer content. 114 proto: the protobuffer of cls 117 google.protobuf.message.DecodeError: if we cannot decode the message. 121 text_format.Parse(s, obj)
123 except text_format.ParseError:
124 obj.ParseFromString(s)
129 """Gets a specific field from a protocol buffer that matches the given class 131 for cls, func
in function_map.items():
136 def GetContentFromProtoString(s, function_map):
137 for cls, func
in function_map.items():
144 raise DecodeError(
"Cannot find a fit protobuffer class.")
148 """Convert a text file of the given protobuf class to binary.""" 150 with open(out_filename,
'w')
as fid:
151 fid.write(proto.SerializeToString())
156 This class allows to drop you into an interactive debugger 157 if there is an unhandled exception in your python script 165 if __name__ == '__main__': 166 from caffe2.python.utils import DebugMode 174 except KeyboardInterrupt:
180 'Entering interactive debugger. Type "bt" to print ' 181 'the full stacktrace. Type "help" to see command listing.')
182 print(sys.exc_info()[1])
192 Use this method to decorate your function with DebugMode's functionality 198 raise Exception("Bar") 203 def wrapper(*args, **kwargs):
205 return f(*args, **kwargs)
def TryReadProtoWithClass(cls, s)
def ConvertProtoToBinary(proto_class, filename, out_filename)
def MakeArgument(key, value)
def GetContentFromProto(obj, function_map)