from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core
def create_predict_net(predictor_export_meta):
    """
    Return the input prediction net.

    Builds a new ``core.Net`` named after ``predictor_export_meta.predict_net``
    (or "predict" when that name is empty), copies the exported ops into it,
    declares the export's inputs and parameters as external inputs and its
    outputs as external outputs.

    Args:
        predictor_export_meta: export metadata exposing ``predict_net``,
            ``inputs``, ``parameters`` and ``outputs``.

    Returns:
        The proto of the newly built net (``net.Proto()``).
    """
    # Build a fresh net instead of reusing predict_net directly, so only the
    # ops and the explicitly declared external blobs are carried over.
    net = core.Net(predictor_export_meta.predict_net.name or "predict")
    net.Proto().op.extend(predictor_export_meta.predict_net.op)
    net.Proto().external_input.extend(
        predictor_export_meta.inputs + predictor_export_meta.parameters)
    net.Proto().external_output.extend(predictor_export_meta.outputs)
    return net.Proto()
def create_predict_init_net(ws, predictor_export_meta):
    """
    Return an initialization net that zero-fills all the input and
    output blobs, using the shapes from the provided workspace. This is
    necessary as there is no shape inference functionality in Caffe2.

    Args:
        ws: workspace whose ``blobs`` mapping is consulted for any blob that
            has no entry in ``predictor_export_meta.shapes``.
        predictor_export_meta: export metadata exposing ``inputs``,
            ``outputs``, ``shapes`` and ``extra_init_net``.

    Returns:
        The proto of the init net (``net.Proto()``).

    Raises:
        Exception: if a blob has no exported shape and is not in ``ws.blobs``.
    """
    net = core.Net("predict-init")

    def zero_fill(blob):
        # Prefer the explicitly exported shape; only fall back to the shape
        # of the blob's current value in the workspace when none is given.
        shape = predictor_export_meta.shapes.get(blob)
        if shape is None:
            if blob not in ws.blobs:
                raise Exception(
                    "{} not in workspace but needed for shape: {}".format(
                        blob, ws.blobs))
            shape = ws.blobs[blob].fetch().shape
        net.ConstantFill([], blob, shape=shape, value=0.0)

    external_blobs = predictor_export_meta.inputs + \
        predictor_export_meta.outputs
    for blob in external_blobs:
        zero_fill(blob)

    net.Proto().external_input.extend(external_blobs)
    if predictor_export_meta.extra_init_net:
        net.AppendNet(predictor_export_meta.extra_init_net)
    return net.Proto()
def get_comp_name(string, name):
    """
    Return ``string`` suffixed with ``'_' + name``.

    When ``name`` is empty or None, ``string`` is returned unchanged, so the
    result has no dangling trailing underscore and None does not raise a
    TypeError during concatenation.
    """
    if not name:
        return string
    return string + '_' + name
60 def _ProtoMapGet(field, key):
62 Given the key, get the value of the repeated field. 63 Helper function used by protobuf since it doesn't have map construct 71 def GetPlan(meta_net_def, key):
72 return _ProtoMapGet(meta_net_def.plans, key)
def GetPlanOriginal(meta_net_def, key):
    """
    Look up ``key`` among the ``plans`` entries of ``meta_net_def``.

    Performs the same lookup as ``GetPlan`` and returns whatever
    ``_ProtoMapGet`` yields for that key.
    """
    plan_entries = meta_net_def.plans
    return _ProtoMapGet(plan_entries, key)
def GetBlobs(meta_net_def, key):
    """
    Return the value stored under ``key`` in ``meta_net_def.blobs``.

    When the lookup yields None (no such entry), an empty list is returned
    instead, so callers can iterate the result unconditionally.
    """
    blobs = _ProtoMapGet(meta_net_def.blobs, key)
    if blobs is None:
        return []
    return blobs
def GetNet(meta_net_def, key):
    """
    Look up ``key`` among the ``nets`` entries of ``meta_net_def``.

    Returns whatever ``_ProtoMapGet`` yields for that key.
    """
    net_entries = meta_net_def.nets
    return _ProtoMapGet(net_entries, key)
def GetNetOriginal(meta_net_def, key):
    """
    Look up ``key`` among the ``nets`` entries of ``meta_net_def``.

    Performs the same lookup as ``GetNet`` and returns whatever
    ``_ProtoMapGet`` yields for that key.
    """
    net_entries = meta_net_def.nets
    return _ProtoMapGet(net_entries, key)
def GetApplicationSpecificInfo(meta_net_def, key):
    """
    Look up ``key`` among the ``applicationSpecificInfo`` entries of
    ``meta_net_def``.

    Returns whatever ``_ProtoMapGet`` yields for that key.
    """
    info_entries = meta_net_def.applicationSpecificInfo
    return _ProtoMapGet(info_entries, key)
def AddBlobs(meta_net_def, blob_name, blob_def):
    """
    Append every item of ``blob_def`` to the ``blobs`` entry keyed by
    ``blob_name`` in ``meta_net_def``, creating that entry if it is absent.

    Reusing an existing entry avoids adding a second entry with the same key
    to the repeated ``blobs`` field.
    """
    blobs = _ProtoMapGet(meta_net_def.blobs, blob_name)
    if blobs is None:
        # No entry for this key yet: add one and append into its value list.
        entry = meta_net_def.blobs.add()
        entry.key = blob_name
        blobs = entry.value
    for blob in blob_def:
        blobs.append(blob)
def AddPlan(meta_net_def, plan_name, plan_def):
    """
    Add a new ``plans`` entry with key ``plan_name`` and value ``plan_def``.

    A new entry is always added; no check is made for an existing entry with
    the same key.
    """
    plans = meta_net_def.plans
    plans.add(key=plan_name, value=plan_def)
def AddNet(meta_net_def, net_name, net_def):
    """
    Add a new ``nets`` entry with key ``net_name`` and value ``net_def``.

    A new entry is always added; no check is made for an existing entry with
    the same key.
    """
    nets = meta_net_def.nets
    nets.add(key=net_name, value=net_def)
# Public entry points of this module:
#   create_predict_net(predictor_export_meta)
#   create_predict_init_net(ws, predictor_export_meta)