Caffe2 - Python API
A deep learning, cross-platform ML framework
workspace.py
1 
3 import contextlib
4 from google.protobuf.message import Message
5 from multiprocessing import Process
6 import os
7 import shutil
8 import socket
9 import tempfile
10 import logging
11 
12 from six import string_types
13 
14 import numpy as np
15 from caffe2.proto import caffe2_pb2
16 from caffe2.python import scope, utils
17 
18 import caffe2.python._import_c_extension as C
19 
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Re-export selected entry points of the C extension module under the
# CamelCase names used throughout this file's public API.
Blobs = C.blobs
CreateBlob = C.create_blob
CurrentWorkspace = C.current_workspace
DeserializeBlob = C.deserialize_blob
GlobalInit = C.global_init
HasBlob = C.has_blob
RegisteredOperators = C.registered_operators
SerializeBlob = C.serialize_blob
SwitchWorkspace = C.switch_workspace
RootFolder = C.root_folder
Workspaces = C.workspaces
BenchmarkNet = C.benchmark_net
Predictor = C.Predictor
35 
is_asan = C.is_asan
has_gpu_support = C.has_gpu_support
if has_gpu_support:
    # GPU build: forward directly to the C extension.
    NumCudaDevices = C.num_cuda_devices
    SetDefaultGPUID = C.set_default_gpu_id
    GetDefaultGPUID = C.get_default_gpu_id
    GetCuDNNVersion = C.get_cudnn_version

    def GetCudaPeerAccessPattern():
        """Return the extension's CUDA peer access pattern as a numpy array."""
        return np.asarray(C.get_cuda_peer_access_pattern())
else:
    # No GPU support: provide stand-ins with the same names so callers do not
    # need to branch on has_gpu_support themselves.
    NumCudaDevices = lambda: 0  # noqa
    SetDefaultGPUID = lambda x: None  # noqa
    GetDefaultGPUID = lambda: 0  # noqa
    GetCuDNNVersion = lambda: 0  # noqa
    GetCudaPeerAccessPattern = lambda: np.array([])  # noqa
52 
53 
# Python 2 and 3 compatibility: test if basestring exists
try:
    basestring  # NOQA
except NameError:
    # This is python3 so we define basestring.
    basestring = str
60 
61 
def _GetFreeFlaskPort():
    """Get a free flask port.

    Port 5000 is preferred. If something is already listening on it, an
    OS-assigned port is returned instead.

    Returns:
      int: a port number that looked free at the time of the check.
    """
    # We will prefer to use 5000. connect_ex() returns 0 only when the
    # connection succeeds, i.e. when some process is already listening on the
    # port, so a zero result means the port is taken rather than free.
    with contextlib.closing(
            socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as probe:
        port_in_use = probe.connect_ex(('127.0.0.1', 5000)) == 0
    if not port_in_use:
        return 5000
    # Otherwise pick a random port: binding to port 0 lets the OS choose one.
    with contextlib.closing(socket.socket()) as s:
        s.bind(('', 0))
        port = s.getsockname()[1]
    # Race condition: between the interval we close the socket and actually
    # start a mint process, another process might have occupied the port. We
    # don't do much here as this is mostly for convenience in research
    # rather than 24x7 service.
    return port
79 
80 
def StartMint(root_folder=None, port=None):
    """Start a mint instance.

    TODO(Yangqing): this does not work well under ipython yet. According to
        https://github.com/ipython/ipython/issues/5862
    writing up some fix is a todo item.

    Inputs:
      root_folder (optional): folder passed to mint via '-r'. Defaults to the
          value of C.root_folder().
      port (optional): port passed to mint via '-p'. Defaults to the result
          of _GetFreeFlaskPort().
    Returns:
      The started multiprocessing.Process that runs app.main.
    """
    # Imported lazily so that mint is only required when it is actually used.
    from caffe2.python.mint import app
    if root_folder is None:
        # Get the root folder from the current workspace
        root_folder = C.root_folder()
    if port is None:
        port = _GetFreeFlaskPort()
    process = Process(
        target=app.main,
        args=(
            ['-p', str(port), '-r', root_folder],
        )
    )
    process.start()
    print('Mint running at http://{}:{}'.format(socket.getfqdn(), port))
    return process
103 
104 
def StringifyProto(obj):
    """Stringify a protocol buffer object.

    Inputs:
      obj: a string (returned unchanged), a protocol buffer object, or a
          Pycaffe2 object that has a Proto() function.
    Outputs:
      string: the output protobuf string.
    Raises:
      ValueError: if the passed in object is none of the supported kinds.
    """
    if isinstance(obj, string_types):
        # Already a string: hand it back untouched.
        return obj
    if isinstance(obj, Message):
        # First, see if this object is a protocol buffer, which we can
        # simply serialize with the SerializeToString() call.
        return obj.SerializeToString()
    if hasattr(obj, 'Proto'):
        # Wrapper objects expose their underlying message through Proto().
        return obj.Proto().SerializeToString()
    raise ValueError("Unexpected argument to StringifyProto of type " +
                     type(obj).__name__)
128 
129 
def ResetWorkspace(root_folder=None):
    """Reset the workspace through the C extension.

    Inputs:
      root_folder (optional): root folder to reset the workspace with. The
          folder is created if it does not exist. If not specified, the
          current root folder setting is kept.
    Returns:
      Whatever C.reset_workspace returns.
    """
    if root_folder is None:
        # Reset the workspace, but keep the current root folder setting.
        return C.reset_workspace(C.root_folder())
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)
    return C.reset_workspace(root_folder)
138 
139 
def CreateNet(net, overwrite=False, input_blobs=None):
    """Create a net in the workspace.

    Inputs:
      net: anything accepted by StringifyProto.
      overwrite: forwarded unchanged to C.create_net.
      input_blobs (optional): blob names to create, via C.create_blob, before
          the net itself is created.
    Returns:
      Whatever C.create_net returns.
    """
    if input_blobs is None:
        input_blobs = []
    for input_blob in input_blobs:
        C.create_blob(input_blob)
    return C.create_net(StringifyProto(net), overwrite)
146 
147 
def RunOperatorOnce(operator):
    """Serialize `operator` with StringifyProto and pass it to
    C.run_operator_once, returning that call's result."""
    return C.run_operator_once(StringifyProto(operator))
150 
151 
def RunOperatorsOnce(operators):
    """Run each operator in order with RunOperatorOnce.

    Inputs:
      operators: an iterable of operators accepted by RunOperatorOnce.
    Returns:
      False as soon as one run reports a falsy result (later operators are
      not run); True if every run succeeded or `operators` is empty.
    """
    # all() short-circuits on the first falsy result, like an explicit loop
    # with an early return.
    return all(RunOperatorOnce(op) for op in operators)
158 
159 
def RunNetOnce(net):
    """Serialize `net` with StringifyProto and pass it to C.run_net_once,
    returning that call's result."""
    return C.run_net_once(StringifyProto(net))
162 
163 
def RunNet(name, num_iter=1):
    """Runs a given net.

    Inputs:
      name: the name of the net, or a reference to the net.
      num_iter: number of iterations to run
    Returns:
      True or an exception.
    """
    net_name = StringifyNetName(name)
    return C.run_net(net_name, num_iter)
174 
175 
def RunPlan(plan_or_step):
    """Run a plan through C.run_plan.

    Inputs:
      plan_or_step: a plan accepted by StringifyProto, or a
          core.ExecutionStep, which is first wrapped in a core.Plan.
    Returns:
      Whatever C.run_plan returns.
    """
    # TODO(jiayq): refactor core.py/workspace.py to avoid circular deps
    import caffe2.python.core as core
    if isinstance(plan_or_step, core.ExecutionStep):
        plan_or_step = core.Plan(plan_or_step)
    return C.run_plan(StringifyProto(plan_or_step))
182 
183 
def InferShapesAndTypes(nets, blob_dimensions=None):
    """Infers the shapes and types for the specified nets.

    Inputs:
      nets: the list of nets
      blob_dimensions (optional): a dictionary of blobs and their dimensions.
          If not specified, the workspace blobs are used.
    Returns:
      A tuple of (shapes, types) dictionaries keyed by blob name. Blobs whose
      shape is reported as unknown are left out of both dictionaries.
    """
    net_protos = [StringifyProto(n.Proto()) for n in nets]
    if blob_dimensions is None:
        blobdesc_prototxt = C.infer_shapes_and_types_from_workspace(net_protos)
    else:
        blobdesc_prototxt = C.infer_shapes_and_types_from_map(
            net_protos, blob_dimensions
        )
    # The extension hands back a serialized TensorShapes message.
    blobdesc_proto = caffe2_pb2.TensorShapes()
    blobdesc_proto.ParseFromString(blobdesc_prototxt)
    shapes = {}
    types = {}
    for ts in blobdesc_proto.shapes:
        if not ts.unknown_shape:
            shapes[ts.name] = list(ts.dims)
            types[ts.name] = ts.data_type

    return (shapes, types)
211 
212 
def _StringifyName(name, expected_type):
    """Return `name` as a string.

    Inputs:
      name: a string, or an object whose class name equals `expected_type`.
      expected_type: the class name (a string) accepted besides plain strings.
    Returns:
      `name` itself if it is a string, otherwise str(name).
    Raises:
      AssertionError: if `name` is neither a string nor of the expected type.
    """
    if isinstance(name, basestring):
        return name
    # The type is compared by class name rather than with isinstance().
    # NOTE(review): presumably this avoids importing the class here - confirm.
    assert type(name).__name__ == expected_type, \
        "Expected a string or %s" % expected_type
    return str(name)
219 
220 
def StringifyBlobName(name):
    """Return a blob name as a string; accepts a string or a BlobReference."""
    return _StringifyName(name, "BlobReference")
223 
224 
def StringifyNetName(name):
    """Return a net name as a string; accepts a string or a Net."""
    return _StringifyName(name, "Net")
227 
228 
def FeedBlob(name, arr, device_option=None):
    """Feeds a blob into the workspace.

    Inputs:
      name: the name of the blob.
      arr: either a TensorProto object or a numpy array object to be fed into
          the workspace.
      device_option (optional): the device option to feed the data with.
          Defaults to scope.CurrentDeviceScope().
    Returns:
      True or False, stating whether the feed is successful.
    """
    if type(arr) is caffe2_pb2.TensorProto:
        # Convert the proto to a numpy array before feeding it.
        arr = utils.Caffe2TensorToNumpyArray(arr)
    if type(arr) is np.ndarray and arr.dtype.kind == 'S':
        # Plain NumPy strings are weird, let's use objects instead.
        # The builtin `object` is used because the `np.object` alias is
        # no longer available in current NumPy releases.
        arr = arr.astype(object)

    if device_option is None:
        device_option = scope.CurrentDeviceScope()

    if device_option and device_option.device_type == caffe2_pb2.CUDA:
        # Only inspect the dtype of inputs that actually have one, so that
        # non-array inputs do not fail on this advisory check.
        if hasattr(arr, 'dtype') and arr.dtype == np.dtype('float64'):
            logger.warning(
                "CUDA operators do not support 64-bit doubles, " +
                "please use arr.astype(np.float32) or np.int32 for ints." +
                " Blob: {}".format(name) +
                " type: {}".format(str(arr.dtype))
            )

    name = StringifyBlobName(name)
    if device_option is not None:
        return C.feed_blob(name, arr, StringifyProto(device_option))
    return C.feed_blob(name, arr)
263 
264 
def FetchBlobs(names):
    """Fetches a list of blobs from the workspace.

    Inputs:
      names: list of names of blobs - strings or BlobReferences
    Returns:
      list of fetched blobs, in the same order as `names`
    """
    return [FetchBlob(blob_name) for blob_name in names]
274 
275 
def FetchBlob(name):
    """Fetches a blob from the workspace.

    Inputs:
      name: the name of the blob - a string or a BlobReference
    Returns:
      Fetched blob (numpy array or string) if successful
    """
    blob_name = StringifyBlobName(name)
    return C.fetch_blob(blob_name)
285 
286 
def GetNameScope():
    """Return the current namescope string. To be used to fetch blobs"""
    return scope.CurrentNameScope()
290 
291 
class _BlobDict(object):
    """Provides python dict compatible way to do fetching and feeding"""

    def __getitem__(self, key):
        # blobs[key] fetches the blob through FetchBlob.
        return FetchBlob(key)

    def __setitem__(self, key, value):
        # blobs[key] = value feeds the blob through FeedBlob.
        return FeedBlob(key, value)

    def __len__(self):
        return len(C.blobs())

    def __iter__(self):
        return C.blobs().__iter__()

    def __contains__(self, item):
        return C.has_blob(item)


# Module-level instance, so callers can write workspace.blobs[name].
blobs = _BlobDict()
312 
313 
314 
333 
# State of the immediate mode: whether it is active, the name of the
# workspace it runs in, and the temporary root folder created for it.
_immediate_mode = False
_immediate_workspace_name = "_CAFFE2_IMMEDIATE"
_immediate_root_folder = ''


def IsImmediate():
    """Return whether immediate mode is currently turned on."""
    return _immediate_mode
341 
342 
@contextlib.contextmanager
def WorkspaceGuard(workspace_name):
    """Context manager that runs its body in the workspace `workspace_name`.

    The workspace that was current on entry is switched back to on exit,
    including when the body raises an exception.

    Inputs:
      workspace_name: name of the workspace to switch to for the body.
    """
    current = CurrentWorkspace()
    # NOTE(review): the second argument presumably asks the extension to
    # create the workspace when it does not exist yet - confirm.
    SwitchWorkspace(workspace_name, True)
    try:
        yield
    finally:
        # Always restore the previous workspace so a failure inside the body
        # cannot leave the process in the guarded workspace.
        SwitchWorkspace(current)
349 
350 
def StartImmediate(i_know=False):
    """Turn on immediate mode.

    If immediate mode is already on, it is stopped first and restarted from
    scratch. A fresh temporary directory is created and used as the root
    folder of the immediate workspace.

    Inputs:
      i_know: if True, skip printing the warning about immediate mode.
    """
    global _immediate_mode
    global _immediate_root_folder
    if IsImmediate():
        # already in immediate mode. We will kill the previous one
        # and start from fresh.
        StopImmediate()
    _immediate_mode = True
    with WorkspaceGuard(_immediate_workspace_name):
        _immediate_root_folder = tempfile.mkdtemp()
        ResetWorkspace(_immediate_root_folder)
    if i_know:
        # if the user doesn't want to see the warning message, sure...
        return
    # NOTE(review): the text below refers to FinishImmediate(); the function
    # that stops immediate mode appears to be StopImmediate() - confirm which
    # name is intended before changing the message.
    print("""
    Enabling immediate mode in caffe2 python is an EXTREMELY EXPERIMENTAL
    feature and may very easily go wrong. This is because Caffe2 uses a
    declarative way of defining operators and models, which is essentially
    not meant to run things in an interactive way. Read the following carefully
    to make sure that you understand the caveats.

    (1) You need to make sure that the sequences of operators you create are
    actually runnable sequentially. For example, if you create an op that takes
    an input X, somewhere earlier you should have already created X.

    (2) Caffe2 immediate uses one single workspace, so if the set of operators
    you run are intended to be under different workspaces, they will not run.
    To create boundaries between such use cases, you can call FinishImmediate()
    and StartImmediate() manually to flush out everything no longer needed.

    (3) Underlying objects held by the immediate mode may interfere with your
    normal run. For example, if there is a leveldb that you opened in immediate
    mode and did not close, your main run will fail because leveldb does not
    support double opening. Immediate mode may also occupy a lot of memory esp.
    on GPUs. Call FinishImmediate() as soon as possible when you no longer
    need it.

    (4) Immediate is designed to be slow. Every immediate call implicitly
    creates a temp operator object, runs it, and destroys the operator. This
    slow-speed run is by design to discourage abuse. For most use cases other
    than debugging, do NOT turn on immediate mode.

    (5) If there is anything FATAL happening in the underlying C++ code, the
    immediate mode will immediately (pun intended) cause the runtime to crash.

    Thus you should use immediate mode with extra care. If you still would
    like to, have fun [https://xkcd.com/149/].
    """)
399 
400 
def StopImmediate():
    """Stops an immediate mode run."""
    # Phew, that was a dangerous ride.
    global _immediate_mode
    global _immediate_root_folder
    if not IsImmediate():
        # Nothing to stop.
        return
    with WorkspaceGuard(_immediate_workspace_name):
        ResetWorkspace()
    # Remove the temporary root folder that StartImmediate() created.
    shutil.rmtree(_immediate_root_folder)
    _immediate_root_folder = ''
    _immediate_mode = False
413 
414 
def ImmediateBlobs():
    """Return the result of Blobs() evaluated inside the immediate workspace."""
    with WorkspaceGuard(_immediate_workspace_name):
        return Blobs()
418 
419 
def RunOperatorImmediate(op):
    """Run `op` once inside the immediate workspace.

    The result of RunOperatorOnce is not returned.
    """
    with WorkspaceGuard(_immediate_workspace_name):
        RunOperatorOnce(op)
423 
424 
def FetchImmediate(*args, **kwargs):
    """Call FetchBlob with the given arguments inside the immediate
    workspace and return its result."""
    with WorkspaceGuard(_immediate_workspace_name):
        return FetchBlob(*args, **kwargs)
428 
429 
def FeedImmediate(*args, **kwargs):
    """Call FeedBlob with the given arguments inside the immediate
    workspace and return its result."""
    with WorkspaceGuard(_immediate_workspace_name):
        return FeedBlob(*args, **kwargs)
433 
434 
# CWorkspace utilities

def _Workspace_create_net(ws, net, overwrite=False):
    """Serialize `net` with StringifyProto and pass it, together with
    `overwrite`, to ws._create_net, returning that call's result."""
    return ws._create_net(StringifyProto(net), overwrite)


# Attach the wrapper to the extension's Workspace class as a method.
C.Workspace.create_net = _Workspace_create_net
442 
443 
def _Workspace_run(ws, obj):
    """Run a plan, net or operator on the workspace `ws`.

    Inputs:
      ws: the workspace object to run on.
      obj: a PlanDef, NetDef or OperatorDef message, or an object whose
          Proto() method returns one of those.
    Returns:
      The result of the matching ws._run_plan / ws._run_net /
      ws._run_operator call.
    Raises:
      ValueError: if `obj` is none of the supported message types.
    """
    if hasattr(obj, 'Proto'):
        # Unwrap wrapper objects to their underlying message.
        obj = obj.Proto()
    if isinstance(obj, caffe2_pb2.PlanDef):
        return ws._run_plan(obj.SerializeToString())
    if isinstance(obj, caffe2_pb2.NetDef):
        return ws._run_net(obj.SerializeToString())
    if isinstance(obj, caffe2_pb2.OperatorDef):
        return ws._run_operator(obj.SerializeToString())
    raise ValueError(
        "Don't know how to do Workspace.run() on {}".format(type(obj)))


# Attach the dispatcher to the extension's Workspace class as a method.
C.Workspace.run = _Workspace_run
458 
459 
def _Blob_feed(blob, arg, device_option=None):
    """Feed `arg` into `blob` through blob._feed.

    Inputs:
      blob: the blob object to feed.
      arg: the value to feed, passed through unchanged.
      device_option (optional): serialized with StringifyProto when given;
          otherwise None is passed on.
    Returns:
      Whatever blob._feed returns.
    """
    if device_option is not None:
        device_option = StringifyProto(device_option)
    return blob._feed(arg, device_option)


# Attach the wrapper to the extension's Blob class as a method.
C.Blob.feed = _Blob_feed
def StopImmediate()
Definition: workspace.py:401
def CurrentDeviceScope()
Definition: scope.py:33
def CurrentNameScope()
Definition: scope.py:26
def Caffe2TensorToNumpyArray(tensor)
Definition: utils.py:29
def RunNet(name, num_iter=1)
Definition: workspace.py:164
def InferShapesAndTypes(nets, blob_dimensions=None)
Definition: workspace.py:184
def GetNameScope()
Definition: workspace.py:287
def FeedBlob(name, arr, device_option=None)
Definition: workspace.py:229
def FetchBlobs(names)
Definition: workspace.py:265
def StringifyProto(obj)
Definition: workspace.py:105
def FetchBlob(name)
Definition: workspace.py:276
def StartMint(root_folder=None, port=None)
Definition: workspace.py:81