Caffe2 - Python API
A deep-learning, cross-platform ML framework
utils.py
1 
3 from caffe2.proto import caffe2_pb2
4 from google.protobuf.message import DecodeError, Message
5 from google.protobuf import text_format
6 import collections
7 import functools
8 import numpy as np
9 import sys
10 
11 
if sys.version_info[0] >= 3:
    # Python 3 dropped the Python 2 builtins ``basestring`` and ``long``;
    # alias them so the helpers in this module can keep using those names.
    basestring = str
    long = int
16 
17 
def CaffeBlobToNumpyArray(blob):
    """Convert a Caffe blob message into a float32 numpy array.

    A non-zero ``blob.num`` marks an old-style blob whose shape is given by
    the ``num``/``channels``/``height``/``width`` fields; otherwise the blob
    is new-style and its shape is listed in ``blob.shape.dim``.
    """
    values = np.asarray(blob.data, dtype=np.float32)
    if blob.num == 0:
        # New-style caffe blob: explicit shape list.
        return values.reshape(blob.shape.dim)
    # Old-style caffe blob: fixed (num, channels, height, width) layout.
    return values.reshape(blob.num, blob.channels, blob.height, blob.width)
27 
28 
def Caffe2TensorToNumpyArray(tensor):
    """Convert a caffe2_pb2.TensorProto into a numpy array of shape tensor.dims.

    Supported data types:
        FLOAT  -> values from ``float_data``,  dtype float32
        DOUBLE -> values from ``double_data``, dtype float64
        INT32  -> values from ``int32_data``,  dtype int32

    Raises:
        RuntimeError: if ``tensor.data_type`` is not one of the above.
    """
    if tensor.data_type == caffe2_pb2.TensorProto.FLOAT:
        data, dtype = tensor.float_data, np.float32
    elif tensor.data_type == caffe2_pb2.TensorProto.DOUBLE:
        data, dtype = tensor.double_data, np.float64
    elif tensor.data_type == caffe2_pb2.TensorProto.INT32:
        # INT32 payloads are stored in int32_data (see
        # NumpyArrayToCaffe2Tensor), not double_data. The old ``np.int``
        # alias no longer exists in current numpy, so name the width.
        data, dtype = tensor.int32_data, np.int32
    else:
        # TODO: complete the data type.
        raise RuntimeError(
            "Tensor data type not supported yet: " + str(tensor.data_type))
    return np.asarray(data, dtype=dtype).reshape(tensor.dims)
43 
44 
def NumpyArrayToCaffe2Tensor(arr, name=None):
    """Convert a numpy array into a caffe2_pb2.TensorProto.

    Args:
        arr: numpy array, or anything ``np.asarray`` accepts. Supported
            dtypes: float32 (-> FLOAT / float_data), float64 (-> DOUBLE /
            double_data), int32 or int64 (-> INT32 / int32_data).
        name: optional tensor name; only set when truthy.

    Returns:
        A TensorProto whose dims are ``arr.shape`` and whose data field
        holds the values in C (row-major) order.

    Raises:
        RuntimeError: if the array's dtype is not supported.
    """
    arr = np.asarray(arr)
    tensor = caffe2_pb2.TensorProto()
    tensor.dims.extend(arr.shape)
    if name:
        tensor.name = name
    if arr.dtype == np.float32:
        tensor.data_type = caffe2_pb2.TensorProto.FLOAT
        data_field = tensor.float_data
    elif arr.dtype == np.float64:
        tensor.data_type = caffe2_pb2.TensorProto.DOUBLE
        data_field = tensor.double_data
    elif arr.dtype in (np.int32, np.int64):
        # ``np.int`` (removed from numpy) only matched the platform's default
        # integer width; accept both common widths explicitly. int64 values
        # are stored into the int32 field with no range check here.
        tensor.data_type = caffe2_pb2.TensorProto.INT32
        data_field = tensor.int32_data
    else:
        # TODO: complete the data type.
        raise RuntimeError(
            "Numpy data type not supported yet: " + str(arr.dtype))
    # tolist() yields native Python scalars for the protobuf repeated field.
    data_field.extend(arr.ravel().tolist())
    return tensor
64 
65 
def MakeArgument(key, value):
    """Makes an argument based on the value type.

    Builds a caffe2_pb2.Argument named ``key`` and stores ``value`` in the
    field that matches its type:

      * float -> ``f``
      * int / bool / long -> ``i`` (booleans are stored as ints)
      * text or bytes -> ``s`` (text is UTF-8 encoded)
      * protobuf Message -> ``s`` (serialized)
      * iterable of floats -> ``floats``
      * iterable of ints / bools -> ``ints``
      * iterable of text / bytes -> ``strings`` (text is UTF-8 encoded)
      * iterable of Messages -> ``strings`` (serialized)

    numpy arrays are flattened to a list and numpy scalars are converted to
    the equivalent native Python value before dispatch. An empty iterable is
    stored as an empty ``floats`` list.

    Raises:
        ValueError: if the value (or an element of it) has an unsupported type.
    """
    try:
        from collections.abc import Iterable  # Python 3.3+
    except ImportError:  # Python 2
        from collections import Iterable

    def _unknown_type_error():
        return ValueError(
            "Unknown argument type: key=%s value=%s, value type=%s" %
            (key, str(value), str(type(value)))
        )

    argument = caffe2_pb2.Argument()
    argument.name = key

    if isinstance(value, np.ndarray):
        value = value.flatten().tolist()
    elif isinstance(value, np.generic):
        # Convert numpy scalar to native python type (``np.asscalar`` was
        # removed from numpy; ``.item()`` is its replacement).
        value = value.item()

    # On Python 3 ``basestring`` is ``str``, so ``bytes`` must be listed
    # explicitly; otherwise a bytes value would be treated as an iterable
    # of small ints.
    string_types = (basestring, bytes)

    if type(value) is float:
        argument.f = value
    elif type(value) in (int, bool, long):
        # We make a relaxation that a boolean variable will also be stored as
        # int.
        argument.i = value
    elif isinstance(value, string_types):
        argument.s = value if type(value) is bytes else value.encode('utf-8')
    elif isinstance(value, Message):
        argument.s = value.SerializeToString()
    elif isinstance(value, Iterable):
        # Materialize once so one-shot iterators (e.g. generators) are not
        # exhausted by the element-type checks before being stored.
        items = list(value)
        if all(isinstance(v, (float, np.floating)) for v in items):
            argument.floats.extend([float(v) for v in items])
        elif all(isinstance(v, (int, long, np.integer)) for v in items):
            argument.ints.extend([int(v) for v in items])
        elif all(isinstance(v, string_types) for v in items):
            argument.strings.extend([
                (v if type(v) is bytes else v.encode('utf-8'))
                for v in items])
        elif all(isinstance(v, Message) for v in items):
            argument.strings.extend([v.SerializeToString() for v in items])
        else:
            raise _unknown_type_error()
    else:
        raise _unknown_type_error()
    return argument
104 
105 
107  """Reads a protobuffer with the given proto class.
108 
109  Inputs:
110  cls: a protobuffer class.
111  s: a string of either binary or text protobuffer content.
112 
113  Outputs:
114  proto: the protobuffer of cls
115 
116  Throws:
117  google.protobuf.message.DecodeError: if we cannot decode the message.
118  """
119  obj = cls()
120  try:
121  text_format.Parse(s, obj)
122  return obj
123  except text_format.ParseError:
124  obj.ParseFromString(s)
125  return obj
126 
127 
def GetContentFromProto(obj, function_map):
    """Gets a specific field from a protocol buffer that matches the given class

    ``function_map`` maps classes to extractor callables. The first entry
    whose class is exactly ``type(obj)`` (subclasses do not match) is called
    with ``obj`` and its result returned; ``None`` is returned otherwise.
    """
    obj_type = type(obj)
    not_found = object()
    extractor = next(
        (fn for klass, fn in function_map.items() if klass is obj_type),
        not_found)
    if extractor is not_found:
        return None
    return extractor(obj)
134 
135 
def GetContentFromProtoString(s, function_map):
    """Parse ``s`` with the first class in ``function_map`` that accepts it.

    Each class in ``function_map`` is tried in iteration order via
    TryReadProtoWithClass. The callable paired with the first class that
    decodes ``s`` is applied to the parsed message and its result returned.

    Raises:
        DecodeError: if no class in ``function_map`` can decode ``s``.
    """
    for cls, func in function_map.items():
        try:
            obj = TryReadProtoWithClass(cls, s)
        except DecodeError:
            continue
        # func is called outside the try so that a DecodeError raised by func
        # itself propagates instead of being treated as a failed parse.
        return func(obj)
    raise DecodeError("Cannot find a fit protobuffer class.")
145 
146 
def ConvertProtoToBinary(proto_class, filename, out_filename):
    """Convert a text file of the given protobuf class to binary.

    Args:
        proto_class: protobuf message class describing the file's content.
        filename: path of the input file to parse.
        out_filename: path the binary serialization is written to
            (truncated/overwritten if it exists).
    """
    # Close the input file deterministically instead of leaking the handle.
    with open(filename) as fid:
        content = fid.read()
    proto = TryReadProtoWithClass(proto_class, content)
    # SerializeToString() returns bytes, so the output file must be opened
    # in binary mode (text mode raises TypeError on Python 3).
    with open(out_filename, 'wb') as fid:
        fid.write(proto.SerializeToString())
152 
153 
class DebugMode(object):
    '''
    This class allows to drop you into an interactive debugger
    if there is an unhandled exception in your python script

    Example of usage:

        def main():
            # your code here
            pass

        if __name__ == '__main__':
            from caffe2.python.utils import DebugMode
            DebugMode.run(main)
    '''

    @classmethod
    def run(cls, func):
        """Call ``func()`` and return its result.

        If ``func`` raises an ``Exception``, the exception is printed, a pdb
        post-mortem session is opened on it, and the process then exits with
        status 1. ``KeyboardInterrupt`` is re-raised untouched so Ctrl-C
        still aborts normally.
        """
        try:
            return func()
        except KeyboardInterrupt:
            raise
        except Exception:
            import pdb

            print(
                'Entering interactive debugger. Type "bt" to print '
                'the full stacktrace. Type "help" to see command listing.')
            print(sys.exc_info()[1])
            # Blank separator line: a bare ``print`` is a no-op expression on
            # Python 3, while print('') emits a newline on Python 2 and 3.
            print('')

            pdb.post_mortem()
            # sys.exit raises SystemExit, so nothing after this line runs.
            sys.exit(1)
188 
189 
def debug(f):
    '''
    Use this method to decorate your function with DebugMode's functionality

    Example:

        @debug
        def test_foo(self):
            raise Exception("Bar")

    The decorated function's return value is passed through to the caller;
    an unhandled exception drops into the debugger via DebugMode.run.
    '''

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        def func():
            return f(*args, **kwargs)
        # Propagate f's result to the caller instead of discarding it.
        return DebugMode.run(func)

    return wrapper
208  return wrapper
def TryReadProtoWithClass(cls, s)
Definition: utils.py:106
def ConvertProtoToBinary(proto_class, filename, out_filename)
Definition: utils.py:147
def MakeArgument(key, value)
Definition: utils.py:66
def debug(f)
Definition: utils.py:190
def GetContentFromProto(obj, function_map)
Definition: utils.py:128