1 #ifndef CAFFE2_CORE_NET_H_ 2 #define CAFFE2_CORE_NET_H_ 10 #include <unordered_map> 12 #include "caffe2/core/blob.h" 13 #include "caffe2/core/common.h" 14 #include "caffe2/core/logging.h" 15 #include "caffe2/core/registry.h" 16 #include "caffe2/core/operator_schema.h" 18 #include "caffe2/core/tensor.h" 19 #include "caffe2/core/workspace.h" 20 #include "caffe2/proto/caffe2.pb.h" 21 #include "caffe2/utils/simple_queue.h" 33 virtual bool Run() = 0;
39 virtual bool RunAsync() {
return Run(); }
50 const int warmup_runs,
52 const bool run_individual) {
53 LOG(ERROR) <<
"Benchmark not implemented for this net type.";
54 return vector<float>();
57 inline const vector<string>& external_output()
const {
58 return external_output_;
61 inline const vector<string>& external_input()
const {
62 return external_input_;
66 vector<string> external_input_;
67 vector<string> external_output_;
70 DISABLE_COPY_AND_ASSIGN(
NetBase);
74 #define REGISTER_NET_CREATOR(key, ...) \ 75 CAFFE_REGISTER_CREATOR(NetRegistry, key, __VA_ARGS__) 76 #define REGISTER_NET(name, ...) \ 77 CAFFE_REGISTER_CLASS(NetRegistry, name, __VA_ARGS__) 95 bool RunAsync()
override;
97 const int warmup_runs,
99 const bool run_individual)
override;
102 vector<unique_ptr<OperatorBase> > operators_;
109 unique_ptr<OperatorBase> operator_;
110 vector<int> children_;
111 vector<int> parents_;
// Runtime counter of parents; presumably decremented as parents finish during
// a run (TODO confirm against the DAG executor). Brace-initialized because,
// before C++20, a default-constructed std::atomic<int> holds an indeterminate
// value, so a freshly constructed node would otherwise start with garbage.
std::atomic<int> runtime_parent_count_{0};
// Whether this node begins an execution chain; defaults to false.
bool is_chain_start_ = false;
117 vector<int> children_;
118 vector<int> parents_;
// Counter of inputs visited so far; starts at zero.
int visited_inputs = 0;
// Parent count of the node; presumably the number of parents in the original
// (un-pruned) graph — confirm against the graph-building code. Previously left
// uninitialized, unlike its sibling above, so a default-constructed node
// carried an indeterminate value until explicitly assigned.
int num_orig_parents = 0;
126 using ExecutionChains = std::unordered_map<int, std::vector<int>>;
134 void WorkerFunction();
136 const int warmup_runs,
138 const bool run_individual)
override;
140 const ExecutionChains& TEST_execution_chains()
const {
141 return execution_chains_;
145 virtual bool RunAt(
const std::vector<int>& chain) = 0;
147 vector<internal::OperatorNode> operator_nodes_;
148 ExecutionChains execution_chains_;
149 vector<int> initial_frontier_;
151 std::vector<std::thread> workers_;
156 std::mutex remaining_ops_mutex_;
157 std::condition_variable cv_;
158 std::mutex run_in_progress_;
165 #endif // CAFFE2_CORE_NET_H_
virtual vector< float > TEST_Benchmark(const int warmup_runs, const int main_runs, const bool run_individual)
Benchmarks a network.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks.
Simple registry implementation in Caffe2 that uses static variables to register object creators during program initialization time.
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.