Caffe2 - Python API
A deep-learning, cross-platform ML framework
predictor_py_utils.py
1 
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import core
9 
10 
def create_predict_net(predictor_export_meta):
    """
    Build a standalone prediction net from ``predictor_export_meta``.

    Only the operators of ``predictor_export_meta.predict_net`` are copied
    into a freshly constructed net. The new net is named after the original,
    or "predict" if the original has no name. Its external inputs are
    ``inputs + parameters`` and its external outputs are ``outputs``.

    Returns the protobuf obtained from ``core.Net.Proto()``.
    """
    meta = predictor_export_meta
    source_net = meta.predict_net
    # A brand-new Net is used on purpose, so that no other settings of the
    # original net carry over; only its ops are copied below.
    proto = core.Net(source_net.name or "predict").Proto()
    proto.op.extend(source_net.op)
    proto.external_input.extend(meta.inputs + meta.parameters)
    proto.external_output.extend(meta.outputs)
    return proto
22 
23 
def create_predict_init_net(ws, predictor_export_meta):
    """
    Build an initialization net that zero-fills every input and output blob.

    For each blob, the shape is taken from ``predictor_export_meta.shapes``
    when present. Otherwise it is read from the blob currently in workspace
    ``ws``. This is needed because Caffe2 has no shape inference
    functionality.

    All input and output blobs are declared as external inputs of the init
    net. If ``predictor_export_meta.extra_init_net`` is truthy, it is
    appended afterwards.

    Raises:
        Exception: a blob has no recorded shape and is also absent from
            ``ws.blobs``.

    Returns the protobuf obtained from ``core.Net.Proto()``.
    """
    meta = predictor_export_meta
    init_net = core.Net("predict-init")

    def _shape_for(blob):
        # Explicitly recorded shapes take precedence over the workspace.
        recorded = meta.shapes.get(blob)
        if recorded is not None:
            return recorded
        if blob not in ws.blobs:
            raise Exception(
                "{} not in workspace but needed for shape: {}".format(
                    blob, ws.blobs))
        return ws.blobs[blob].fetch().shape

    external_blobs = meta.inputs + meta.outputs
    for blob in external_blobs:
        # Resolve the shape first so a missing blob is reported before any
        # operator is added for it.
        blob_shape = _shape_for(blob)
        init_net.ConstantFill([], blob, shape=blob_shape, value=0.0)

    init_net.Proto().external_input.extend(external_blobs)
    if meta.extra_init_net:
        init_net.AppendNet(meta.extra_init_net)
    return init_net.Proto()
52 
53 
def get_comp_name(string, name):
    """
    Return ``string`` suffixed with ``'_' + name``.

    If ``name`` is falsy (for example ``None`` or an empty string),
    ``string`` is returned unchanged.
    """
    return string if not name else string + '_' + name
58 
59 
60 def _ProtoMapGet(field, key):
61  '''
62  Given the key, get the value of the repeated field.
63  Helper function used by protobuf since it doesn't have map construct
64  '''
65  for v in field:
66  if (v.key == key):
67  return v.value
68  return None
69 
70 
def GetPlan(meta_net_def, key):
    """
    Return the value of the first entry in ``meta_net_def.plans`` whose key
    equals ``key``, or None if there is no such entry.
    """
    plans = meta_net_def.plans
    return _ProtoMapGet(plans, key)
73 
74 
def GetPlanOriginal(meta_net_def, key):
    """
    Return the value of the first entry in ``meta_net_def.plans`` whose key
    equals ``key``, or None if there is no such entry.

    NOTE(review): this currently performs exactly the same lookup as GetPlan.
    """
    return _ProtoMapGet(field=meta_net_def.plans, key=key)
77 
78 
def GetBlobs(meta_net_def, key):
    """
    Return the value of the first entry in ``meta_net_def.blobs`` whose key
    equals ``key``.

    If no entry matches, a new empty list is returned. A matching entry is
    returned as-is, even when its value is empty.
    """
    found = _ProtoMapGet(meta_net_def.blobs, key)
    # `is None` rather than truthiness: an existing but empty value must be
    # returned unchanged, not replaced by a new list.
    return [] if found is None else found
84 
85 
def GetNet(meta_net_def, key):
    """
    Return the value of the first entry in ``meta_net_def.nets`` whose key
    equals ``key``, or None if there is no such entry.
    """
    nets = meta_net_def.nets
    return _ProtoMapGet(nets, key)
88 
89 
def GetNetOriginal(meta_net_def, key):
    """
    Return the value of the first entry in ``meta_net_def.nets`` whose key
    equals ``key``, or None if there is no such entry.

    NOTE(review): this currently performs exactly the same lookup as GetNet.
    """
    return _ProtoMapGet(field=meta_net_def.nets, key=key)
92 
93 
def GetApplicationSpecificInfo(meta_net_def, key):
    """
    Return the value of the first entry in
    ``meta_net_def.applicationSpecificInfo`` whose key equals ``key``, or
    None if there is no such entry.
    """
    info_entries = meta_net_def.applicationSpecificInfo
    return _ProtoMapGet(info_entries, key)
96 
97 
def AddBlobs(meta_net_def, blob_name, blob_def):
    """
    Append every item of ``blob_def`` to the value stored under
    ``blob_name`` in ``meta_net_def.blobs``.

    If no entry with that key exists yet, a new entry is added to
    ``meta_net_def.blobs``, keyed by ``blob_name``, and the items are
    appended to its value. Returns None.
    """
    target = _ProtoMapGet(meta_net_def.blobs, blob_name)
    if target is None:
        new_entry = meta_net_def.blobs.add()
        new_entry.key = blob_name
        target = new_entry.value
    # Appended one at a time (not extend) to keep the original per-item
    # semantics.
    for item in blob_def:
        target.append(item)
106 
107 
def AddPlan(meta_net_def, plan_name, plan_def):
    """
    Add a new ``(plan_name, plan_def)`` entry to ``meta_net_def.plans``.

    No check is made for an existing entry with the same key. A duplicate
    key adds a second entry, and first-match lookups such as GetPlan keep
    returning the earlier one. Returns None.
    """
    plans = meta_net_def.plans
    plans.add(value=plan_def, key=plan_name)
110 
111 
def AddNet(meta_net_def, net_name, net_def):
    """
    Add a new ``(net_name, net_def)`` entry to ``meta_net_def.nets``.

    No check is made for an existing entry with the same key. A duplicate
    key adds a second entry, and first-match lookups such as GetNet keep
    returning the earlier one. Returns None.
    """
    nets = meta_net_def.nets
    nets.add(value=net_def, key=net_name)
def create_predict_net(predictor_export_meta)
def create_predict_init_net(ws, predictor_export_meta)