1 #include "caffe2/core/net.h" 5 #include <unordered_map> 6 #include <unordered_set> 8 #include "caffe2/core/operator.h" 9 #include "caffe2/core/static_tracepoint.h" 10 #include "caffe2/core/timer.h" 11 #include "caffe2/proto/caffe2.pb.h" 12 #include "caffe2/utils/proto_utils.h" 15 caffe2_disable_chaining,
17 "Disable chaining logic (some latent multi-device issues).");
23 bool sameDevice(
const OperatorDef& lhs,
const OperatorDef& rhs) {
24 return lhs.device_option().device_type() ==
25 rhs.device_option().device_type() &&
26 lhs.device_option().cuda_gpu_id() == rhs.device_option().cuda_gpu_id();
30 DAGNetBase::ExecutionChains singleChains(
31 const std::vector<internal::OperatorNode>& nodes) {
32 DAGNetBase::ExecutionChains chains;
33 for (
auto i = 0; i < nodes.size(); ++i) {
// Iterative DFS over the sub-graph reachable from `node_idx` that removes
// redundant (transitively implied) parent edges from the dependency graph.
// NOTE(review): this extract dropped several interior lines; the embedded
// numbers are original source line numbers, so gaps below are expected.
39 static void prune(
int node_idx, std::vector<internal::OpGraphNode>& nodes) {
// ancestors[i] is set while node i is on the current DFS path and cleared
// again when the traversal unwinds past it.
41 std::vector<bool> ancestors(nodes.size(),
false);
// Stack entries are (current node, node we arrived from); the root's
// predecessor is encoded as -1.
43 std::stack<std::pair<int, int>> nodes_stack;
45 nodes_stack.push(std::make_pair(node_idx, -1));
47 while (!nodes_stack.empty()) {
48 const auto& node_pair = nodes_stack.top();
49 int curr = node_pair.first;
50 int prev = node_pair.second;
// Guard against a child index that is out of range for this graph.
54 CAFFE_ENFORCE(curr < ancestors.size(),
"Out of bound access");
// Second time on the stack (on the way back up): clear the on-path flag.
55 if (ancestors[curr]) {
56 ancestors[curr] =
false;
// Rebuild curr's parent list: a parent (other than the immediate
// predecessor) that is already an ancestor on this path is redundant, so
// its direct edge to curr is erased; other parents are kept.
66 std::vector<int> new_parents;
67 for (
auto parent : nodes[curr].parents_) {
68 if (parent != prev && ancestors[parent]) {
// Drop curr from the redundant parent's child list (erase-remove idiom;
// the std::remove call itself was lost in extraction).
70 nodes[parent].children_.erase(
72 nodes[parent].children_.begin(),
73 nodes[parent].children_.end(),
75 nodes[parent].children_.end());
77 new_parents.push_back(parent);
80 nodes[curr].parents_ = new_parents;
// Mark curr as being on the current DFS path.
83 ancestors[curr] =
true;
// Expand curr only once it has been reached through all of its original
// parents, so each node is descended into exactly once.
86 if (nodes[curr].visited_inputs == nodes[curr].num_orig_parents) {
87 const auto& children = nodes[curr].children_;
88 for (
auto child : children) {
89 nodes[child].visited_inputs++;
90 nodes_stack.push(std::make_pair(child, curr));
// Builds a lightweight OpGraphNode copy of the operator graph (only edges
// and parent counts — OperatorNode itself owns the operator instance and
// is not copied), prunes redundant edges from every root, and logs timing.
// NOTE(review): interior lines (Timer declaration, the prune(i, pruned)
// call, the return) were lost in extraction.
100 std::vector<internal::OpGraphNode> pruneOpNodeGraph(
101 const std::vector<internal::OperatorNode>& orig_nodes) {
103 std::vector<internal::OpGraphNode> pruned;
// Copy only the graph structure: children, parents, and the original
// parent count (used by prune() to decide when a node may be expanded).
109 for (
auto& node : orig_nodes) {
110 internal::OpGraphNode nd;
111 nd.children_ = node.children_;
112 nd.parents_ = node.parents_;
113 nd.num_orig_parents = nd.parents_.size();
114 pruned.push_back(nd);
// Kick off pruning from every root (a node with no parents).
117 for (
int i = 0; i < pruned.size(); ++i) {
118 if (pruned[i].parents_.size() == 0) {
// t is presumably a Timer started at function entry — declaration elided
// in this extract; confirm against the full source.
123 LOG(INFO) <<
"Operator graph pruning prior to chain compute took: " 124 << t.Seconds() <<
" secs";
// Computes execution chains: maximal sequences of operators that one
// worker can run back-to-back. Pass 1 counts how often a DFS from the
// roots reaches each node; pass 2 greedily grows chains along single-child
// paths whose nodes were seen exactly once and share a device.
// NOTE(review): several interior lines were lost in extraction (embedded
// numbers are original source line numbers).
128 DAGNetBase::ExecutionChains computeChains(
129 const std::vector<internal::OperatorNode>& orig_nodes) {
130 const std::vector<internal::OpGraphNode> nodes = pruneOpNodeGraph(orig_nodes);
// Roots of the pruned graph (no parents) seed both traversals.
131 vector<int> initial_frontier;
132 for (
int idx = 0; idx < nodes.size(); ++idx) {
133 if (nodes[idx].parents_.size() == 0) {
134 initial_frontier.push_back(idx);
// Pass 1: count visits per node. A node reached more than once is a join
// point and must not sit in the middle of a chain.
140 std::unordered_map<OpIndex, int> node_seen_count;
142 for (
int root_index : initial_frontier) {
143 const auto& root = nodes[root_index];
// Stack entries pair a node with an iterator into its child list so the
// DFS can resume where it left off after finishing a child.
144 std::stack<std::pair<OpIndex, std::vector<int>::const_iterator>>
146 depth_stack.push(make_pair(root_index, root.children_.begin()));
147 node_seen_count[root_index]++;
// A root must not be reachable from any other root.
149 node_seen_count[root_index] == 1,
152 " visit count must be == 1");
154 while (depth_stack.size() > 0) {
155 auto cur = depth_stack.top();
157 if (cur.second != nodes[cur.first].children_.end()) {
158 OpIndex node_index = *cur.second;
159 node_seen_count[node_index]++;
161 depth_stack.push(cur);
// Descend into a child only the first time it is seen.
162 if (node_seen_count[node_index] == 1) {
165 make_pair(node_index, nodes[node_index].children_.begin()));
// Pass 2 state: committed chains, globally visited nodes, and the chain
// currently being grown.
173 DAGNetBase::ExecutionChains chains;
174 std::unordered_set<OpIndex> seen_nodes;
175 std::vector<OpIndex> chain;
176 std::pair<OpIndex, std::vector<int>::const_iterator> cur;
177 std::stack<std::pair<OpIndex, std::vector<int>::const_iterator>> depth_stack;
// cur may extend the open chain only if it was reached exactly once
// (single effective predecessor) and runs on the same device as the
// chain's tail operator.
178 auto check_current_for_chaining = [&]() ->
bool {
180 node_seen_count[cur.first] == 1 &&
181 (chain.size() == 0 || sameDevice(
182 orig_nodes[cur.first].operator_->def(),
183 orig_nodes[chain.back()].operator_->def())));
// Stores the open chain keyed by its first op; enforces the key is new.
185 auto commit_chain = [&]() {
186 if (chain.size() > 0) {
188 chains.insert({chain.front(), chain}).second,
191 " was already added.");
192 VLOG(2) <<
"Added chain: " << chain.front() <<
"with elements";
193 for (
auto ch : chain) {
194 VLOG(2) << ch <<
", ";
// Skips children already processed, then schedules the next unseen child
// (re-pushing cur so the traversal can continue here afterwards).
199 auto depth_traverse = [&]() {
200 while (cur.second != nodes[cur.first].children_.end() &&
201 seen_nodes.find(*cur.second) != seen_nodes.end()) {
205 if (cur.second != nodes[cur.first].children_.end()) {
206 auto next = make_pair(*cur.second, nodes[*cur.second].children_.begin());
207 depth_stack.push(cur);
208 depth_stack.push(next);
// Pass 2: walk from every root again, growing and committing chains.
211 for (
int root_index : initial_frontier) {
213 make_pair(root_index, nodes[root_index].children_.begin()));
214 while (depth_stack.size() > 0) {
215 cur = depth_stack.top();
217 if (seen_nodes.find(cur.first) == seen_nodes.end()) {
218 seen_nodes.insert(cur.first);
// Exactly one child: either extend the open chain with cur and follow
// the single child, or (else branch, partly elided) commit the current
// chain and start a fresh one at cur.
221 if (nodes[cur.first].children_.size() == 1) {
222 if (check_current_for_chaining()) {
224 VLOG(1) <<
"Adding to existing chain" << cur.first;
225 chain.push_back(cur.first);
226 int index = *nodes[cur.first].children_.begin();
227 depth_stack.push(make_pair(index, nodes[index].children_.begin()));
232 chain.push_back(cur.first);
233 int index = *nodes[cur.first].children_.begin();
234 depth_stack.push(make_pair(index, nodes[index].children_.begin()));
// Leaf node that passes the chaining check: it closes the open chain.
237 nodes[cur.first].children_.size() == 0 &&
238 check_current_for_chaining()) {
240 chain.push_back(cur.first);
// Otherwise (fan-out or device change): cur begins its own chain.
247 chain.push_back(cur.first);
// Sanity check: the traversal must have covered every node exactly once.
263 seen_nodes.size() == nodes.size(),
264 "Haven't seen all the nodes, expected number of nodes ",
// Builds the DAG executor: creates each operator (inheriting the net-level
// device option when an op has none), derives dependency edges from blob
// reads/writes (RaW, WaW, WaR), computes execution chains, and starts the
// worker thread pool.
// NOTE(review): several interior lines were lost in extraction (embedded
// numbers are original source line numbers).
273 DAGNetBase::DAGNetBase(
const NetDef& net_def, Workspace* ws)
274 : NetBase(net_def, ws), operator_nodes_(net_def.op_size()) {
276 VLOG(1) <<
"Constructing DAGNet " << net_def.name();
// blob_creator: blob name -> index of the op that last wrote it.
277 std::map<string, int> blob_creator;
// blob_readers: blob name -> ops that read it since the last write.
278 std::map<string, std::set<int>> blob_readers;
279 bool net_def_has_device_option = net_def.has_device_option();
281 for (
int idx = 0; idx < net_def.op_size(); ++idx) {
282 const OperatorDef& op_def = net_def.op(idx);
283 VLOG(1) <<
"Creating operator #" << idx <<
": " << op_def.name() <<
// An op without its own device option inherits the net-wide one via a
// temporary copy of its def.
":" 285 if (!op_def.has_device_option() && net_def_has_device_option) {
286 OperatorDef temp_def(op_def);
287 temp_def.mutable_device_option()->CopyFrom(net_def.device_option());
288 operator_nodes_[idx].operator_ = CreateOperator(temp_def, ws);
290 operator_nodes_[idx].operator_ = CreateOperator(op_def, ws);
// For every input blob: add a read-after-write (RaW) edge from the op
// that produced it, and record this op as a reader of the blob.
295 [&](
const google::protobuf::RepeatedPtrField<std::string>& inputs) {
296 for (
const string& input : inputs) {
297 if (blob_creator.count(input) == 0) {
298 VLOG(1) <<
"Input " << input <<
" not produced by this net. " 299 <<
"Assuming it is pre-existing.";
301 int parent = blob_creator[input];
302 VLOG(1) <<
"op dependency (RaW " << input <<
"): " << parent
304 operator_nodes_[idx].parents_.push_back(parent);
305 operator_nodes_[parent].children_.push_back(idx);
308 blob_readers[input].insert(idx);
// Regular and control inputs produce the same kind of edges.
311 checkInputs(op_def.input());
312 checkInputs(op_def.control_input());
315 for (
const string& output : op_def.output()) {
// Rewriting an existing blob: add a write-after-write (WaW) edge from the
// previous writer ...
316 if (blob_creator.count(output) != 0) {
319 int waw_parent = blob_creator[output];
320 VLOG(1) <<
"op dependency (WaW " << output <<
"): " << waw_parent
322 operator_nodes_[idx].parents_.push_back(waw_parent);
323 operator_nodes_[waw_parent].children_.push_back(idx);
// ... and write-after-read (WaR) edges from every reader seen since the
// previous write.
327 for (
const int war_parent : blob_readers[output]) {
328 VLOG(1) <<
"op dependency (WaR " << output <<
"): " << war_parent
330 operator_nodes_[idx].parents_.push_back(war_parent);
331 operator_nodes_[war_parent].children_.push_back(idx);
// This op becomes the blob's creator; the reader set is reset because
// later readers will depend on this write instead.
334 blob_creator[output] = idx;
339 blob_readers[output].clear();
// Normalize the edge lists: sort, drop duplicates, drop self-loops.
345 for (
int i = 0; i < operator_nodes_.size(); ++i) {
346 auto& node = operator_nodes_[i];
348 auto& p = node.parents_;
349 std::sort(p.begin(), p.end());
350 p.erase(std::unique(p.begin(), p.end()), p.end());
351 p.erase(std::remove(p.begin(), p.end(), i), p.end());
353 auto& c = node.children_;
354 std::sort(c.begin(), c.end());
355 c.erase(std::unique(c.begin(), c.end()), c.end());
356 c.erase(std::remove(c.begin(), c.end(), i), c.end());
// Compute chains, or trivial one-op chains when chaining is disabled.
360 (FLAGS_caffe2_disable_chaining ? singleChains(operator_nodes_)
361 : computeChains(operator_nodes_));
// An op is a chain start iff it keys an entry in execution_chains_.
364 for (
int i = 0; i < operator_nodes_.size(); ++i) {
365 auto& node = operator_nodes_[i];
366 if (execution_chains_.find(i) != execution_chains_.end()) {
367 node.is_chain_start_ =
true;
369 node.is_chain_start_ =
false;
371 node.runtime_parent_count_ = 0;
374 LOG(INFO) <<
"Number of parallel execution chains " 375 << execution_chains_.size()
376 <<
" Number of operators = " << net_def.op_size();
// Ops with no parents form the initial frontier pushed at Run() time.
382 for (
int idx = 0; idx < operator_nodes_.size(); ++idx) {
383 if (operator_nodes_[idx].parents_.size() == 0) {
384 initial_frontier_.push_back(idx);
// Worker pool size comes from the NetDef (default 1, must be positive).
388 int num_workers = net_def.has_num_workers() ? net_def.num_workers() : 1;
389 CAFFE_ENFORCE(num_workers > 0,
"Must have a positive number of workers.");
390 if (num_workers == 1) {
391 LOG(WARNING) <<
"Number of workers is 1: this means that all operators " 392 <<
"will be executed sequentially. Did you forget to set " 393 <<
"num_workers in the NetDef?";
395 num_workers_ = num_workers;
// Optionally start only one worker for the first iteration; Run() tops
// the pool up to num_workers_ afterwards.
397 int num_workers_to_start = num_workers_;
402 ArgumentHelper arg_helper(net_def);
403 if (arg_helper.HasArgument(
"first_iter_only_one_worker")) {
404 if (arg_helper.GetSingleArgument<int64_t>(
405 "first_iter_only_one_worker", 0)) {
406 num_workers_to_start = 1;
410 for (
int i = 0; i < num_workers_to_start; ++i) {
411 VLOG(1) <<
"Start worker #" << i;
412 workers_.push_back(std::thread(&DAGNetBase::WorkerFunction,
this));
416 DAGNetBase::~DAGNetBase() {
418 job_queue_.NoMoreJobs();
419 VLOG(1) <<
"Joining workers.";
420 for (
auto& worker : workers_) {
// Executes one full pass of the net: seeds the job queue with the initial
// frontier, blocks until all operators have run, then returns the overall
// success flag accumulated by the workers.
// NOTE(review): several interior lines were lost in extraction (embedded
// numbers are original source line numbers).
425 bool DAGNetBase::Run() {
// Serializes concurrent Run() calls on the same net.
428 std::unique_lock<std::mutex> run_lock(run_in_progress_);
429 VLOG(1) <<
"Running parallel net.";
431 remaining_ops_ = operator_nodes_.size();
// Reset each node's runtime parent count; workers decrement it as parent
// chains complete.
435 for (
auto& node : operator_nodes_) {
436 node.runtime_parent_count_ = node.parents_.size();
// Kick-start the workers with the parentless ops.
439 for (
auto& value : initial_frontier_) {
440 job_queue_.Push(value);
// Wait on the condition variable until the workers drain all ops.
442 std::unique_lock<std::mutex> mutex_lock(remaining_ops_mutex_);
443 while (remaining_ops_ > 0) {
444 VLOG(2) <<
"Remaining ops to run: " << remaining_ops_;
445 cv_.wait(mutex_lock);
447 VLOG(2) <<
"All ops finished running.";
// Post-condition: every op must have had all its parents satisfied.
448 for (
const auto& op : operator_nodes_) {
450 op.runtime_parent_count_ == 0,
452 op.operator_->def().name(),
454 op.operator_->def().type(),
455 ") has some runtime parents left.");
// Top the pool up to num_workers_ (covers the first_iter_only_one_worker
// case where the constructor started fewer workers).
459 for (
auto i = workers_.size(); i < num_workers_; ++i) {
460 VLOG(1) <<
"Start worker #" << i;
461 workers_.push_back(std::thread(&DAGNetBase::WorkerFunction,
462 this));
// Worker-thread loop: pops a chain-start index off the job queue, runs the
// whole chain, decrements each child's runtime parent count, enqueues
// newly-ready chain starts, and updates the shared completion counters.
// NOTE(review): several interior lines (the enclosing while loop, some
// CAFFE_ENFORCE arguments) were lost in extraction.
468 void DAGNetBase::WorkerFunction() {
// Pop() returning false means NoMoreJobs() was called — exit the loop.
474 if (!job_queue_.Pop(&idx)) {
477 VLOG(1) <<
"Running operator #" << idx <<
" " 478 << operator_nodes_[idx].operator_->def().name() <<
"(" 479 << operator_nodes_[idx].operator_->def().type() <<
").";
// Only chain-start indices should ever be queued.
481 execution_chains_.find(idx) != execution_chains_.end(),
485 const auto& chain = execution_chains_[idx];
486 bool this_success = RunAt(execution_chains_[idx]);
488 LOG(ERROR) <<
"Operator chain failed: " 489 << ProtoDebugString(operator_nodes_[idx].operator_->def());
// For every op in the finished chain, release its children: when a
// child's runtime parent count reaches zero it is ready to run.
493 for (
const auto idx : chain) {
494 for (
const auto child : operator_nodes_[idx].children_) {
495 const int count = --operator_nodes_[child].runtime_parent_count_;
498 "Found runtime parent count smaller than zero for ",
500 operator_nodes_[child].operator_->def().name(),
502 operator_nodes_[child].operator_->def().type(),
// Only chain starts are scheduled; interior chain members run as part of
// their chain.
509 if (operator_nodes_[child].is_chain_start_) {
510 VLOG(2) <<
"Pushing chain #" << child <<
" to queue.";
511 job_queue_.Push(child);
// Account for the finished chain under the shared mutex and fold this
// chain's result into the net-wide success flag; Run() waits on cv_.
518 std::unique_lock<std::mutex> mutex_lock(remaining_ops_mutex_);
519 remaining_ops_ -= chain.size();
520 success_ &= this_success;
523 "All the operations should be finished by now, still have ",
528 VLOG(2) <<
"Finished executing operator #" << idx;
// Benchmarks the net: runs `warmup_runs` untimed passes, then times
// `main_runs` passes and reports milliseconds per iteration. Per-operator
// timing (`run_individual`) is not supported by DAGNet.
// NOTE(review): the function's opening line (return type / name) and the
// Timer lines were lost in extraction; the embedded numbers are original
// source line numbers.
533 const int warmup_runs,
535 const bool run_individual) {
536 LOG(INFO) <<
"Starting benchmark.";
537 LOG(INFO) <<
"Running warmup runs.";
540 "Number of warm up runs should be non negative, provided ",
// Untimed warmup passes; any failure aborts the benchmark.
543 for (
int i = 0; i < warmup_runs; ++i) {
544 CAFFE_ENFORCE(Run(),
"Warmup run ", i,
" has failed.");
547 LOG(INFO) <<
"Main runs.";
550 "Number of main runs should be non negative, provided ",
// Timed main passes; `millis` (declaration elided) holds total elapsed ms.
554 for (
int i = 0; i < main_runs; ++i) {
555 CAFFE_ENFORCE(Run(),
"Main run ", i,
" has failed.");
558 LOG(INFO) <<
"Main run finished. Milliseconds per iter: " 559 << millis / main_runs
560 <<
". Iters per second: " << 1000.0 * main_runs / millis;
562 if (run_individual) {
563 LOG(INFO) <<
"DAGNet does not do per-op benchmark. To do so, " 564 "switch to a simple net type.";
// Returns only the average latency; no per-op numbers for DAGNet.
566 return vector<float>{millis / main_runs};
571 using DAGNetBase::DAGNetBase;
574 bool RunAt(
const std::vector<int>& chain)
override {
576 const auto& net_name = name_.c_str();
577 for (
const auto i : chain) {
578 const auto& op = operator_nodes_[i].operator_.get();
579 const auto& op_name = op->def().name().c_str();
580 const auto& op_type = op->def().type().c_str();
581 CAFFE_SDT(operator_start, net_name, op_name, op_type, op);
582 success &= operator_nodes_[i].operator_->Run();
583 CAFFE_SDT(operator_done, net_name, op_name, op_type, op);
591 REGISTER_NET(dag,
DAGNet);
// --- Extraction residue: documentation snippets, kept for reference ---
// float MilliSeconds(): returns the elapsed time in milliseconds.
// Timer: a simple timer object for measuring time.
// Registry: simple registry implementation in Caffe2 that uses static
// variables to register object creators during program initialization.
// vector<float> TEST_Benchmark(const int warmup_runs, const int main_runs,
//     const bool run_individual) override: benchmarks a network.