1 #include "caffe2/core/predictor.h" 7 void enforceIsTensor(Workspace* ws,
const std::string& name) {
8 auto blob = ws->GetBlob(name);
9 CAFFE_ENFORCE(blob,
"Blob does not exist: ", name);
11 blob->template IsType<TensorCPU>(),
"Blob is not a CPU Tensor: ", name);
14 void shareInputTensor(
16 const std::string& name,
18 enforceIsTensor(ws, name);
19 auto* blob = ws->GetBlob(name);
20 CAFFE_ENFORCE(blob,
"Blob: ", name,
" does not exist");
21 auto* tensor = blob->template GetMutable<TensorCPU>();
22 tensor->ResizeLike(*input);
23 tensor->ShareData(*input);
26 TensorCPU* extractOutputTensor(Workspace* ws,
const std::string& name) {
27 enforceIsTensor(ws, name);
28 auto* blob = ws->GetBlob(name);
29 CAFFE_ENFORCE(blob,
"Blob: ", name,
" does not exist");
30 return blob->template GetMutable<TensorCPU>();
35 const NetDef& init_net,
36 const NetDef& run_net,
38 : run_net_(run_net), ws_(parent) {
39 CAFFE_ENFORCE(ws_.RunNetOnce(init_net));
40 CAFFE_ENFORCE(ws_.CreateNet(run_net));
43 void Predictor::run(
const TensorVector& inputs, TensorVector* outputs) {
44 CAFFE_ENFORCE(inputs.size() <= run_net_.external_input_size());
45 for (
auto i = 0; i < inputs.size(); ++i) {
46 shareInputTensor(&ws_, run_net_.external_input(i), inputs[i]);
49 CAFFE_ENFORCE(ws_.RunNet(run_net_.name()));
51 outputs->resize(run_net_.external_output_size());
52 for (
auto i = 0; i < outputs->size(); ++i) {
53 (*outputs)[i] = extractOutputTensor(&ws_, run_net_.external_output(i));
Simple registry implementation in Caffe2 that uses static variables to register object creators during program initialization time.