Caffe2 - C++ API
A deep learning, cross platform ML framework
predictor.cc
1 #include "caffe2/core/predictor.h"
2 
3 namespace caffe2 {
4 
5 namespace {
6 
7 void enforceIsTensor(Workspace* ws, const std::string& name) {
8  auto blob = ws->GetBlob(name);
9  CAFFE_ENFORCE(blob, "Blob does not exist: ", name);
10  CAFFE_ENFORCE(
11  blob->template IsType<TensorCPU>(), "Blob is not a CPU Tensor: ", name);
12 }
13 
14 void shareInputTensor(
15  Workspace* ws,
16  const std::string& name,
17  TensorCPU* input) {
18  enforceIsTensor(ws, name);
19  auto* blob = ws->GetBlob(name);
20  CAFFE_ENFORCE(blob, "Blob: ", name, " does not exist");
21  auto* tensor = blob->template GetMutable<TensorCPU>();
22  tensor->ResizeLike(*input);
23  tensor->ShareData(*input);
24 }
25 
26 TensorCPU* extractOutputTensor(Workspace* ws, const std::string& name) {
27  enforceIsTensor(ws, name);
28  auto* blob = ws->GetBlob(name);
29  CAFFE_ENFORCE(blob, "Blob: ", name, " does not exist");
30  return blob->template GetMutable<TensorCPU>();
31 }
32 }
33 
34 Predictor::Predictor(
35  const NetDef& init_net,
36  const NetDef& run_net,
37  Workspace* parent)
38  : run_net_(run_net), ws_(parent) {
39  CAFFE_ENFORCE(ws_.RunNetOnce(init_net));
40  CAFFE_ENFORCE(ws_.CreateNet(run_net));
41 }
42 
43 void Predictor::run(const TensorVector& inputs, TensorVector* outputs) {
44  CAFFE_ENFORCE(inputs.size() <= run_net_.external_input_size());
45  for (auto i = 0; i < inputs.size(); ++i) {
46  shareInputTensor(&ws_, run_net_.external_input(i), inputs[i]);
47  }
48 
49  CAFFE_ENFORCE(ws_.RunNet(run_net_.name()));
50 
51  outputs->resize(run_net_.external_output_size());
52  for (auto i = 0; i < outputs->size(); ++i) {
53  (*outputs)[i] = extractOutputTensor(&ws_, run_net_.external_output(i));
54  }
55 }
56 }
Simple registry implementation in Caffe2 that uses static variables to register object creators during program initialization time.