3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
8 from caffe2.proto
import caffe2_pb2
9 from caffe2.proto
import metanet_pb2
10 from caffe2.python
import workspace, core
11 from caffe2.python.predictor_constants
import predictor_constants
12 import caffe2.python.predictor.serde
as serde
13 import caffe2.python.predictor.predictor_py_utils
as utils
# NOTE(review): this section appears to be an extracted listing, not runnable
# source: the leading integer on each line is the original source line number.
# Lines 20, 22, 24, 26, 29-39, 43-45, 47 and the rest of the super() argument
# list after line 50 are not present here.
# PredictorExportMeta: a namedtuple whose field names are given on line 19.
# Its __new__ override (fragment below) converts inputs/outputs/parameters
# with str and converts predict_net to a proto before building the tuple.
19 'predict_net, parameters, inputs, outputs, shapes, name, extra_init_net')):
# NOTE(review): "serializaing" in the docstring text below should read
# "serializing".
# NOTE(review): under Python 3 (the __future__ imports suggest both 2 and 3
# are targeted), map() returns a one-shot iterator, so inputs/outputs/
# parameters are stored as iterators rather than lists. save_to_db's
# `[predictor_constants.META_NET_DEF] + predictor_export_meta.parameters`
# would then raise TypeError, and _global_init_net reads .parameters twice
# (the second read sees it exhausted). Consider list(map(str, ...)).
21 Metadata to be used for serializaing a net. 23 parameters, inputs, outputs could be either BlobReference or blob's names 25 predict_net can be either core.Net, NetDef, PlanDef or object 27 Override the named tuple to provide optional name parameter. 28 name will be used to identify multiple prediction nets. 40 inputs = map(str, inputs)
41 outputs = map(str, outputs)
42 parameters = map(str, parameters)
# The condition guarding this conversion (line 45) is not shown; per the
# docstring, predict_net may be passed as a core.Net rather than a proto.
46 predict_net = predict_net.Proto()
# After the optional conversion, only NetDef or PlanDef protos are accepted.
48 assert isinstance(predict_net, (caffe2_pb2.NetDef, caffe2_pb2.PlanDef))
49 return super(PredictorExportMeta, cls).__new__(
50 cls, predict_net, parameters, inputs, outputs, shapes, name,
# PredictorExportMeta component-name accessors. Each one returns
# utils.get_comp_name(<component-type constant>, ...) for one component of the
# export: inputs, outputs, parameters, global init net, predict init net,
# predict net, train init plan and train plan.
# NOTE(review): the line after each `return` (55, 59, 63, 67, 71, 75, 79, 83)
# is missing from this listing, so the second argument to get_comp_name is not
# visible. It is presumably self.name, given the class docstring's note that
# name identifies multiple prediction nets — verify against the real source.
53 def inputs_name(self):
54 return utils.get_comp_name(predictor_constants.INPUTS_BLOB_TYPE,
57 def outputs_name(self):
58 return utils.get_comp_name(predictor_constants.OUTPUTS_BLOB_TYPE,
61 def parameters_name(self):
62 return utils.get_comp_name(predictor_constants.PARAMETERS_BLOB_TYPE,
65 def global_init_name(self):
66 return utils.get_comp_name(predictor_constants.GLOBAL_INIT_NET_TYPE,
69 def predict_init_name(self):
70 return utils.get_comp_name(predictor_constants.PREDICT_INIT_NET_TYPE,
73 def predict_net_name(self):
74 return utils.get_comp_name(predictor_constants.PREDICT_NET_TYPE,
77 def train_init_plan_name(self):
78 return utils.get_comp_name(predictor_constants.TRAIN_INIT_PLAN_TYPE,
81 def train_plan_name(self):
82 return utils.get_comp_name(predictor_constants.TRAIN_PLAN_TYPE,
# Fragment of prepare_prediction_net(filename, db_type). Its `def` line,
# lines 90, 92, 95-96 and 99-101, and everything after line 102 are not in
# this listing.
# Visible steps: load the MetaNetDef with load_from_db, then fetch the
# global-init, predict-init and predict nets from it with utils.GetNet.
88 Helper function which loads all required blobs from the db 89 and returns prediction net ready to be used 91 metanet_def = load_from_db(filename, db_type)
93 global_init_net = utils.GetNet(
94 metanet_def, predictor_constants.GLOBAL_INIT_NET_TYPE)
97 predict_init_net = utils.GetNet(
98 metanet_def, predictor_constants.PREDICT_INIT_NET_TYPE)
# The next line closes a call that starts on a missing line (99-101). How
# global_init_net and predict_init_net are used, and what wraps the predict
# net, cannot be seen here.
102 utils.GetNet(metanet_def, predictor_constants.PREDICT_NET_TYPE))
# _global_init_net: builds a net whose declared external input is the
# PREDICTOR_DBREADER blob and whose external outputs are the parameter blobs
# of predictor_export_meta.
# NOTE(review): lines 109-110 are missing from this listing. They are where
# `net` is created and where the call whose arguments appear on lines 111-112
# begins. The return statement is missing too. That call presumably loads the
# parameters from the DB reader — verify against the real source.
108 def _global_init_net(predictor_export_meta):
111 [predictor_constants.PREDICTOR_DBREADER],
112 predictor_export_meta.parameters)
113 net.Proto().external_input.extend([predictor_constants.PREDICTOR_DBREADER])
# NOTE(review): .parameters is read twice (lines 112 and 114). If it is a
# Python 3 map iterator (see PredictorExportMeta.__new__), line 112 exhausts
# it and this extend adds nothing.
114 net.Proto().external_output.extend(predictor_export_meta.parameters)
# Fragment of get_meta_net_def(predictor_export_meta, ws=None): assembles a
# metanet_pb2.MetaNetDef describing the exported predictor. Its `def` line,
# lines 123-124 and the lines after 137 (presumably returning meta_net_def —
# verify) are not in this listing.
# Fall back to the current workspace when ws is not given (or is falsy).
122 ws = ws
or workspace.C.Workspace.current
125 meta_net_def = metanet_pb2.MetaNetDef()
# Register three nets under the component names from predictor_export_meta:
# the predict-init net (built from the workspace by
# utils.create_predict_init_net), the global-init net (_global_init_net) and
# the predict net (utils.create_predict_net).
126 utils.AddNet(meta_net_def, predictor_export_meta.predict_init_name(),
127 utils.create_predict_init_net(ws, predictor_export_meta))
128 utils.AddNet(meta_net_def, predictor_export_meta.global_init_name(),
129 _global_init_net(predictor_export_meta))
130 utils.AddNet(meta_net_def, predictor_export_meta.predict_net_name(),
131 utils.create_predict_net(predictor_export_meta))
# Record the parameter, input and output blob-name lists.
132 utils.AddBlobs(meta_net_def, predictor_export_meta.parameters_name(),
133 predictor_export_meta.parameters)
134 utils.AddBlobs(meta_net_def, predictor_export_meta.inputs_name(),
135 predictor_export_meta.inputs)
136 utils.AddBlobs(meta_net_def, predictor_export_meta.outputs_name(),
137 predictor_export_meta.outputs)
def set_model_info(meta_net_def, project_str, model_class_str, version):
    """Stamp model-identification fields onto a MetaNetDef in place.

    Args:
        meta_net_def: the ``metanet_pb2.MetaNetDef`` to modify.
        project_str: value written to ``meta_net_def.modelInfo.project``.
        model_class_str: value written to ``meta_net_def.modelInfo.modelClass``.
        version: value written to ``meta_net_def.modelInfo.version``.

    Raises:
        AssertionError: if ``meta_net_def`` is not a ``MetaNetDef``.
    """
    # Validate explicitly rather than with an ``assert`` statement: asserts are
    # stripped under ``python -O``, which would let a wrong object be mutated
    # silently. AssertionError is kept so existing callers see the same type.
    if not isinstance(meta_net_def, metanet_pb2.MetaNetDef):
        raise AssertionError(
            'meta_net_def must be a metanet_pb2.MetaNetDef, got {}'.format(
                type(meta_net_def).__name__))
    model_info = meta_net_def.modelInfo
    model_info.project = project_str
    model_info.modelClass = model_class_str
    model_info.version = version
# Fragment of save_to_db(db_type, db_destination, predictor_export_meta).
# Lines 149-152 and 155-158, and everything after 159, are not in this
# listing.
# Visible: the blobs to save are the META_NET_DEF blob followed by every
# parameter blob. Line 159 closes a call (begun on a missing line) that takes
# db=db_destination and db_type=db_type, presumably a Save operator — verify.
148 def save_to_db(db_type, db_destination, predictor_export_meta):
# NOTE(review): if predictor_export_meta.parameters is a Python 3 map iterator
# (see the map(str, ...) calls in PredictorExportMeta.__new__), the
# list + iterator concatenation below raises TypeError.
153 blobs_to_save = [predictor_constants.META_NET_DEF] + \
154 predictor_export_meta.parameters
159 db=db_destination, db_type=db_type)
# Fragment of load_from_db(filename, db_type). Lines 165-169, 171 and 173-182
# are not in this listing.
164 def load_from_db(filename, db_type):
# Closes a call (begun on a missing line) that takes db=filename and
# db_type=db_type; presumably a DB load/reader operator — verify.
170 db=filename, db_type=db_type)
# Closes a statement that carries this error message, formatted with filename.
# Whether it is an assert or a raise cannot be seen here.
172 'Failed to create db {}'.format(filename))
# Closes a call whose last argument is metanet_pb2.MetaNetDef. This is
# presumably serde.deserialize_protobuf_struct(..., metanet_pb2.MetaNetDef),
# given the struct_type parameter in the function index — verify.
183 metanet_pb2.MetaNetDef)
185 def prepare_prediction_net(filename, db_type)
def deserialize_protobuf_struct(serialized_protobuf, struct_type)
def serialize_protobuf_struct(protobuf_struct)
def get_meta_net_def(predictor_export_meta, ws=None)
def FeedBlob(name, arr, device_option=None)
def CreateOperator(operator_type, inputs, outputs, name='', control_input=None, device_option=None, arg=None, engine=None, **kwargs)
def CreateNet(net, overwrite=False, input_blobs=None)
def RunOperatorOnce(operator)