Caffe2 - C++ API
A deep learning, cross-platform ML framework
net.h
1 #ifndef CAFFE2_CORE_NET_H_
2 #define CAFFE2_CORE_NET_H_
3 
4 #include <atomic>
5 #include <climits>
6 #include <cstddef>
7 #include <thread> // NOLINT
8 #include <typeinfo>
9 #include <vector>
10 #include <unordered_map>
11 
12 #include "caffe2/core/blob.h"
13 #include "caffe2/core/common.h"
14 #include "caffe2/core/logging.h"
15 #include "caffe2/core/registry.h"
16 #include "caffe2/core/operator_schema.h"
17 
18 #include "caffe2/core/tensor.h"
19 #include "caffe2/core/workspace.h"
20 #include "caffe2/proto/caffe2.pb.h"
21 #include "caffe2/utils/simple_queue.h"
22 
23 namespace caffe2 {
24 
25 class OperatorBase;
26 class Workspace;
27 // Net is a thin struct that owns all the operators together with the operator
28 // contexts.
29 class NetBase {
30  public:
31  NetBase(const NetDef& net_def, Workspace* ws);
32  virtual ~NetBase() noexcept {}
33  virtual bool Run() = 0;
34 
35  // RunAsync runs the net on the current stream, but potentially does
36  // not synchronize with respect to the host, and thus may require
37  // external synchronization (with respect to the current stream)
38  // after execution.
39  virtual bool RunAsync() { return Run(); }
49  virtual vector<float> TEST_Benchmark(
50  const int warmup_runs,
51  const int main_runs,
52  const bool run_individual) {
53  LOG(ERROR) << "Benchmark not implemented for this net type.";
54  return vector<float>();
55  }
56 
57  inline const vector<string>& external_output() const {
58  return external_output_;
59  }
60 
61  inline const vector<string>& external_input() const {
62  return external_input_;
63  }
64 
65  protected:
66  vector<string> external_input_;
67  vector<string> external_output_;
68  string name_;
69 
70  DISABLE_COPY_AND_ASSIGN(NetBase);
71 };
72 
73 CAFFE_DECLARE_REGISTRY(NetRegistry, NetBase, const NetDef&, Workspace*);
74 #define REGISTER_NET_CREATOR(key, ...) \
75  CAFFE_REGISTER_CREATOR(NetRegistry, key, __VA_ARGS__)
76 #define REGISTER_NET(name, ...) \
77  CAFFE_REGISTER_CLASS(NetRegistry, name, __VA_ARGS__)
78 
86 unique_ptr<NetBase> CreateNet(const NetDef& net_def, Workspace* ws);
87 
88 // This is the very basic structure you need to run a network - all it
89 // does is simply to run everything in sequence. If you want more fancy control
90 // such as a DAG-like execution, check out other better net implementations.
91 class SimpleNet : public NetBase {
92  public:
93  SimpleNet(const NetDef& net_def, Workspace* ws);
94  bool Run() override;
95  bool RunAsync() override;
96  vector<float> TEST_Benchmark(
97  const int warmup_runs,
98  const int main_runs,
99  const bool run_individual) override;
100 
101  protected:
102  vector<unique_ptr<OperatorBase> > operators_;
103 
104  DISABLE_COPY_AND_ASSIGN(SimpleNet);
105 };
106 
107 namespace internal {
108 struct OperatorNode {
109  unique_ptr<OperatorBase> operator_;
110  vector<int> children_;
111  vector<int> parents_;
112  std::atomic<int> runtime_parent_count_;
113  bool is_chain_start_ = false;
114 };
115 
// Lightweight graph node that carries only connectivity and counters.
// NOTE(review): presumably used while analysing the operator graph (e.g. when
// computing execution chains) -- confirm against net.cc.
struct OpGraphNode {
  // Links to child / parent nodes.
  std::vector<int> children_;
  std::vector<int> parents_;
  // Counter of inputs visited so far; starts at 0.
  int visited_inputs = 0;
  // Parent count recorded for this node. Initialized to 0 so a freshly
  // constructed node never exposes an indeterminate value, consistent with
  // visited_inputs above.
  int num_orig_parents = 0;
};
122 }
123 
124 class DAGNetBase : public NetBase {
125  public:
126  using ExecutionChains = std::unordered_map<int, std::vector<int>>;
127  DAGNetBase(const NetDef& net_def, Workspace* ws);
128  ~DAGNetBase();
129  bool Run() override;
130  // WorkerFunction() is a function wrapper to allow us to run worker threads.
131  // It checks out one ready-to-run operator from the job queue, runs it,
132  // notifies all its children, and for any children that is ready, enqueues
133  // it to the job queue.
134  void WorkerFunction();
135  vector<float> TEST_Benchmark(
136  const int warmup_runs,
137  const int main_runs,
138  const bool run_individual) override;
139 
140  const ExecutionChains& TEST_execution_chains() const {
141  return execution_chains_;
142  }
143 
144  protected:
145  virtual bool RunAt(const std::vector<int>& chain) = 0;
146 
147  vector<internal::OperatorNode> operator_nodes_;
148  ExecutionChains execution_chains_;
149  vector<int> initial_frontier_;
150  SimpleQueue<int> job_queue_;
151  std::vector<std::thread> workers_;
152  int num_workers_;
153  int remaining_ops_;
154 
155  bool success_;
156  std::mutex remaining_ops_mutex_;
157  std::condition_variable cv_;
158  std::mutex run_in_progress_;
159 
160  DISABLE_COPY_AND_ASSIGN(DAGNetBase);
161 };
162 
163 } // namespace caffe2
164 
165 #endif // CAFFE2_CORE_NET_H_
virtual vector< float > TEST_Benchmark(const int warmup_runs, const int main_runs, const bool run_individual)
Benchmarks a network.
Definition: net.h:49
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:53
Simple registry implementation in Caffe2 that uses static variables to register object creators durin...
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.
Definition: net.cc:66