Caffe2 - Python API
A deep learning, cross platform ML framework
predictor_exporter.py
1 
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.proto import caffe2_pb2
9 from caffe2.proto import metanet_pb2
10 from caffe2.python import workspace, core
11 from caffe2.python.predictor_constants import predictor_constants
12 import caffe2.python.predictor.serde as serde
13 import caffe2.python.predictor.predictor_py_utils as utils
14 
15 import collections
16 
17 
class PredictorExportMeta(collections.namedtuple(
        'PredictorExportMeta',
        'predict_net, parameters, inputs, outputs, shapes, name, '
        'extra_init_net')):
    """
    Metadata to be used for serializing a net.

    parameters, inputs, outputs could be either BlobReference or blob's
    names; each is normalized to a list of str.

    predict_net can be a core.Net, core.Plan, NetDef or PlanDef; Net/Plan
    objects are converted to their protobuf via Proto().

    Override the named tuple to provide optional name parameter.
    name will be used to identify multiple prediction nets.
    """
    def __new__(
        cls,
        predict_net,
        parameters,
        inputs,
        outputs,
        shapes=None,
        name="",
        extra_init_net=None
    ):
        # Materialize as lists rather than keeping ``map`` objects: on
        # Python 3 ``map`` is a one-shot iterator, but these fields are
        # iterated more than once (e.g. ``parameters`` in the global init
        # net) and are concatenated with lists when saving.
        inputs = list(map(str, inputs))
        outputs = list(map(str, outputs))
        parameters = list(map(str, parameters))
        shapes = shapes or {}

        if isinstance(predict_net, (core.Net, core.Plan)):
            predict_net = predict_net.Proto()

        assert isinstance(predict_net, (caffe2_pb2.NetDef, caffe2_pb2.PlanDef)), (
            'predict_net must be a core.Net, core.Plan, NetDef or PlanDef, '
            'got {}'.format(type(predict_net)))
        return super(PredictorExportMeta, cls).__new__(
            cls, predict_net, parameters, inputs, outputs, shapes, name,
            extra_init_net)

    # Each *_name() helper returns the component key used inside the
    # MetaNetDef for this predictor, derived from the component type and
    # ``self.name`` via utils.get_comp_name.

    def inputs_name(self):
        """Component name under which the input blob names are stored."""
        return utils.get_comp_name(predictor_constants.INPUTS_BLOB_TYPE,
                                   self.name)

    def outputs_name(self):
        """Component name under which the output blob names are stored."""
        return utils.get_comp_name(predictor_constants.OUTPUTS_BLOB_TYPE,
                                   self.name)

    def parameters_name(self):
        """Component name under which the parameter blob names are stored."""
        return utils.get_comp_name(predictor_constants.PARAMETERS_BLOB_TYPE,
                                   self.name)

    def global_init_name(self):
        """Component name of the global init net."""
        return utils.get_comp_name(predictor_constants.GLOBAL_INIT_NET_TYPE,
                                   self.name)

    def predict_init_name(self):
        """Component name of the predict init net."""
        return utils.get_comp_name(predictor_constants.PREDICT_INIT_NET_TYPE,
                                   self.name)

    def predict_net_name(self):
        """Component name of the predict net."""
        return utils.get_comp_name(predictor_constants.PREDICT_NET_TYPE,
                                   self.name)

    def train_init_plan_name(self):
        """Component name of the training init plan."""
        return utils.get_comp_name(predictor_constants.TRAIN_INIT_PLAN_TYPE,
                                   self.name)

    def train_plan_name(self):
        """Component name of the training plan."""
        return utils.get_comp_name(predictor_constants.TRAIN_PLAN_TYPE,
                                   self.name)
84 
85 
def prepare_prediction_net(filename, db_type):
    '''
    Load a predictor DB and return its prediction net, ready to be used.

    The MetaNetDef is read from the DB, then its global init net and its
    predict init net are run (in that order) through the workspace module.
    Finally the predict net is wrapped in a core.Net and instantiated with
    workspace.CreateNet.
    '''
    meta = load_from_db(filename, db_type)

    # Both init nets are run exactly once, global first, then predict-init.
    init_net_types = (
        predictor_constants.GLOBAL_INIT_NET_TYPE,
        predictor_constants.PREDICT_INIT_NET_TYPE,
    )
    for net_type in init_net_types:
        workspace.RunNetOnce(utils.GetNet(meta, net_type))

    predict_net_def = utils.GetNet(meta, predictor_constants.PREDICT_NET_TYPE)
    net = core.Net(predict_net_def)
    workspace.CreateNet(net)
    return net
106 
107 
def _global_init_net(predictor_export_meta):
    """
    Build the global init net, which loads every parameter blob from the
    blob named predictor_constants.PREDICTOR_DBREADER.

    Args:
        predictor_export_meta: PredictorExportMeta whose ``parameters``
            lists the blob names to load.

    Returns:
        The NetDef proto of a net named "global-init". Its external input is
        the DB reader blob; its external outputs are the parameter blobs.
    """
    # Materialize once: the names are consumed twice below (by Load and by
    # external_output), and a one-shot iterator would leave the second use
    # silently empty.
    parameters = list(predictor_export_meta.parameters)

    net = core.Net("global-init")
    net.Load([predictor_constants.PREDICTOR_DBREADER], parameters)
    net.Proto().external_input.extend([predictor_constants.PREDICTOR_DBREADER])
    net.Proto().external_output.extend(parameters)
    return net.Proto()
116 
117 
def get_meta_net_def(predictor_export_meta, ws=None):
    """
    Build a MetaNetDef describing a predictor.

    The returned proto contains three nets and three blob-name lists. Each is
    registered under the component name returned by the corresponding
    ``predictor_export_meta.*_name()`` method:

    * predict init net (from utils.create_predict_init_net),
    * global init net (loads the parameters from the DB reader),
    * predict net (from utils.create_predict_net),
    * the parameter, input and output blob names.

    Args:
        predictor_export_meta: PredictorExportMeta describing the predictor.
        ws: workspace handed to utils.create_predict_init_net; defaults to
            workspace.C.Workspace.current.

    Returns:
        metanet_pb2.MetaNetDef
    """
    ws = ws or workspace.C.Workspace.current

    # Predict net is the core network that we use.
    meta_net_def = metanet_pb2.MetaNetDef()
    utils.AddNet(meta_net_def, predictor_export_meta.predict_init_name(),
                 utils.create_predict_init_net(ws, predictor_export_meta))
    utils.AddNet(meta_net_def, predictor_export_meta.global_init_name(),
                 _global_init_net(predictor_export_meta))
    utils.AddNet(meta_net_def, predictor_export_meta.predict_net_name(),
                 utils.create_predict_net(predictor_export_meta))
    utils.AddBlobs(meta_net_def, predictor_export_meta.parameters_name(),
                   predictor_export_meta.parameters)
    utils.AddBlobs(meta_net_def, predictor_export_meta.inputs_name(),
                   predictor_export_meta.inputs)
    utils.AddBlobs(meta_net_def, predictor_export_meta.outputs_name(),
                   predictor_export_meta.outputs)
    return meta_net_def
139 
140 
def set_model_info(meta_net_def, project_str, model_class_str, version):
    """
    Fill in ``meta_net_def.modelInfo`` in place.

    Sets the project, model class and version fields; returns None.
    """
    assert isinstance(meta_net_def, metanet_pb2.MetaNetDef)
    info = meta_net_def.modelInfo
    info.project = project_str
    info.modelClass = model_class_str
    info.version = version
146 
147 
def save_to_db(db_type, db_destination, predictor_export_meta):
    """
    Serialize a predictor into a DB.

    The MetaNetDef built by get_meta_net_def is serialized and fed into the
    workspace as predictor_constants.META_NET_DEF. A Save operator then
    writes that blob first, followed by every parameter blob. load_from_db
    expects META_NET_DEF to precede the parameters.

    Args:
        db_type: DB backend type passed to the Save operator.
        db_destination: DB location passed to Save as ``db`` (with
            ``absolute_path=True``).
        predictor_export_meta: PredictorExportMeta describing the predictor.
    """
    meta_net_def = get_meta_net_def(predictor_export_meta)
    workspace.FeedBlob(predictor_constants.META_NET_DEF,
                       serde.serialize_protobuf_struct(meta_net_def))

    # list() keeps this working whether ``parameters`` is a list or a
    # one-shot iterator (``map`` on Python 3); list + map raises TypeError.
    blobs_to_save = ([predictor_constants.META_NET_DEF] +
                     list(predictor_export_meta.parameters))
    op = core.CreateOperator(
        "Save",
        blobs_to_save, [],
        absolute_path=True,
        db=db_destination, db_type=db_type)

    # Creating the operator does not execute it; run it so the DB is
    # actually written.
    workspace.RunOperatorOnce(op)
160 
162 
163 
def load_from_db(filename, db_type):
    """
    Open a predictor DB and return the MetaNetDef stored in it.

    Side effects: creates the blobs predictor_constants.PREDICTOR_DBREADER
    (a DB reader) and predictor_constants.META_NET_DEF in the workspace.
    The reader blob is not removed, so the global init net in the returned
    MetaNetDef can later load the parameters from it.

    Args:
        filename: DB location passed to CreateDB as ``db``.
        db_type: DB backend type.

    Returns:
        metanet_pb2.MetaNetDef

    Raises:
        AssertionError: if creating the DB reader or loading the
            MetaNetDef blob reports failure.
    """
    # global_init_net in meta_net_def will load parameters from
    # predictor_constants.PREDICTOR_DBREADER
    create_db = core.CreateOperator(
        'CreateDB', [],
        [core.BlobReference(predictor_constants.PREDICTOR_DBREADER)],
        db=filename, db_type=db_type)
    # Run the operator outside ``assert``: asserts are stripped under
    # ``python -O``, which would skip the operator entirely.
    if not workspace.RunOperatorOnce(create_db):
        raise AssertionError('Failed to create db {}'.format(filename))

    # predictor_constants.META_NET_DEF is always stored before the parameters
    load_meta_net_def = core.CreateOperator(
        'Load',
        [core.BlobReference(predictor_constants.PREDICTOR_DBREADER)],
        [core.BlobReference(predictor_constants.META_NET_DEF)])
    if not workspace.RunOperatorOnce(load_meta_net_def):
        raise AssertionError('Failed to load {} from db {}'.format(
            predictor_constants.META_NET_DEF, filename))

    serialized = workspace.FetchBlob(predictor_constants.META_NET_DEF)
    # Protobuf parsing needs the raw bytes. str() on a Python 3 bytes object
    # yields its repr ("b'...'"). On Python 2 bytes is str, so this matches
    # the previous str() conversion there.
    if not isinstance(serialized, bytes):
        serialized = str(serialized)
    meta_net_def = serde.deserialize_protobuf_struct(
        serialized, metanet_pb2.MetaNetDef)
    return meta_net_def
185 
def prepare_prediction_net(filename, db_type)
def deserialize_protobuf_struct(serialized_protobuf, struct_type)
Definition: serde.py:13
def serialize_protobuf_struct(protobuf_struct)
Definition: serde.py:9
def RunNetOnce(net)
Definition: workspace.py:160
def get_meta_net_def(predictor_export_meta, ws=None)
def FeedBlob(name, arr, device_option=None)
Definition: workspace.py:229
def CreateOperator(operator_type, inputs, outputs, name='', control_input=None, device_option=None, arg=None, engine=None, kwargs)
Definition: core.py:259
def CreateNet(net, overwrite=False, input_blobs=None)
Definition: workspace.py:140
def RunOperatorOnce(operator)
Definition: workspace.py:148
def FetchBlob(name)
Definition: workspace.py:276