1 #include "caffe2/core/workspace.h" 7 #include "caffe2/core/logging.h" 8 #include "caffe2/core/net.h" 9 #include "caffe2/core/operator.h" 10 #include "caffe2/core/tensor.h" 11 #include "caffe2/core/timer.h" 12 #include "caffe2/proto/caffe2.pb.h" 15 caffe2_handle_executor_threads_exceptions,
17 "If used we will handle exceptions in executor threads. " 18 "This avoids SIGABRT but may cause process to deadlock");
21 caffe2_print_blob_sizes_at_exit,
23 "If true, workspace destructor will print all blob shapes");
// Command-line flags consulted by Workspace::GetThreadPool() in this file: the
// platform-matching flag is copied into a local `bool applyCap` there
// (presumably gating the thread-count reduction that follows — the connecting
// lines are not visible here; confirm).
// NOTE(review): both flags are defined with the int macro but are given boolean
// literals as defaults and are only read as booleans in the visible code.
// Confirm whether a bool flag was intended before changing the type, since any
// matching DECLARE sites elsewhere would have to change with it.
29 CAFFE2_DEFINE_int(caffe2_threadpool_android_cap,
true,
"");
// iOS counterpart of the Android flag above; its default is false.
32 CAFFE2_DEFINE_int(caffe2_threadpool_ios_cap,
false,
"");
34 #endif // CAFFE2_MOBILE 41 inline const bool getShouldStop(
const Blob* b) {
42 if (!b || !b->meta().id()) {
46 const auto& t = b->Get<TensorCPU>();
47 CAFFE_ENFORCE(t.IsType<
bool>() && t.size() == 1,
"expects a scalar boolean");
48 return *(t.template data<bool>());
53 std::function<bool(int64_t)> getContinuationTest(
55 const ExecutionStep& step) {
56 if (step.has_should_stop_blob()) {
59 "Must not specify num_iter if should_stop_blob is set");
62 if (!step.has_should_stop_blob()) {
63 CAFFE_ENFORCE(!step.has_only_once(),
"not supported");
64 int64_t iterations = step.has_num_iter() ? step.num_iter() : 1;
65 VLOG(1) <<
"Will execute step " << step.name() <<
" for " << iterations
67 return [=](int64_t i) {
return i < iterations; };
69 bool onlyOnce = step.has_only_once() && step.only_once();
70 VLOG(1) <<
"Will execute step" << step.name() << (onlyOnce ?
" once " :
"")
71 <<
" until stopped by blob " << step.should_stop_blob();
73 return [](int64_t i) {
return i == 0; };
75 return [](int64_t i) {
return true; };
// Predicate type for the continuation tests stored in this file
// (substepShouldContinue, netShouldContinue, shouldContinue): it is called with
// the current iteration index and returns true while iteration should go on.
// NOTE(review): the parameter type is int, but the loops in this file pass an
// int64_t `iter` and the lambdas assigned to this type take int64_t, so the
// index is narrowed to int at the call boundary — confirm this is acceptable
// for very long-running steps.
82 typedef std::function<bool(int)> ShouldContinue;
85 const ExecutionStep* mainStep,
87 ShouldContinue externalShouldContinue)
90 (step->substep_size() == 0 || step->network_size() == 0),
91 "An ExecutionStep should either have substep or networks" 94 if (step->has_should_stop_blob()) {
95 shouldStop = ws->
GetBlob(step->should_stop_blob());
97 shouldStop,
"blob ", step->should_stop_blob(),
" does not exist");
100 if (step->substep_size()) {
101 ShouldContinue substepShouldContinue;
102 if (!step->concurrent_substeps() || step->substep().size() <= 1) {
103 substepShouldContinue = externalShouldContinue;
105 substepShouldContinue = [
this, externalShouldContinue](int64_t it) {
106 return !gotFailure && externalShouldContinue(it);
110 for (
const auto& ss : step->substep()) {
111 auto compiledSubstep = std::make_shared<CompiledExecutionStep>(
112 &ss, ws, substepShouldContinue);
113 if (ss.has_run_every_ms()) {
114 reportSubsteps.push_back(compiledSubstep);
116 recurringSubsteps.push_back(compiledSubstep);
120 for (
const string& network_name : step->network()) {
121 auto* net = ws->
GetNet(network_name);
122 CAFFE_ENFORCE(net !=
nullptr,
"Network ", network_name,
" not found.");
123 networks.push_back(net);
127 netShouldContinue = getContinuationTest(ws, *step);
128 shouldContinue = [
this, externalShouldContinue](int64_t iter) {
129 return externalShouldContinue(iter) && this->netShouldContinue(iter);
// Step definition this compiled step wraps. Raw pointer: substeps are compiled
// from the address of each entry in the parent's substep() list, so the
// definition must outlive this object.
133 const ExecutionStep* step;
// Compiled substeps whose definition has run_every_ms set. ExecuteStepRecursive
// registers each one with the Reporter, using run_every_ms as the interval.
134 vector<std::shared_ptr<CompiledExecutionStep>> reportSubsteps;
// The remaining compiled substeps (presumably those without run_every_ms — the
// branching line is not visible here). ExecuteStepRecursive runs them on every
// iteration, either in order or from worker threads.
135 vector<std::shared_ptr<CompiledExecutionStep>> recurringSubsteps;
// Nets named in step->network(), resolved through ws->GetNet(); construction
// enforces that each lookup is non-null. Run one after another per iteration.
137 vector<NetBase*> networks;
// Blob looked up from step->should_stop_blob() when the step declares one;
// stays nullptr otherwise. Consulted through CHECK_SHOULD_STOP.
138 Blob* shouldStop{
nullptr};
// Continuation test derived from the step itself via getContinuationTest().
139 ShouldContinue netShouldContinue;
// Combined test driving the iteration loops: true only while both the
// externally supplied predicate and netShouldContinue return true.
140 ShouldContinue shouldContinue;
// Set to true by a concurrent-substep worker when a substep fails or throws;
// read by the substep continuation test and after the workers are joined.
141 std::atomic<bool> gotFailure{
false};
144 void Workspace::PrintBlobSizes() {
145 vector<string> blobs = LocalBlobs();
149 vector<std::pair<size_t, std::string>> blob_sizes;
150 for (
const auto& s : blobs) {
151 Blob* b = this->GetBlob(s);
152 ShapeCall shape_fun = GetShapeCallFunction(b->
meta().
id());
154 bool shares_data =
false;
156 auto shape = shape_fun(b->GetRaw(), shares_data, capacity);
161 cumtotal += capacity;
162 blob_sizes.push_back(make_pair(capacity, s));
168 [](
const std::pair<size_t, std::string>& a,
169 const std::pair<size_t, std::string>& b) {
170 return b.first < a.first;
174 LOG(INFO) <<
"---- Workspace blobs: ---- ";
175 LOG(INFO) <<
"name;current shape;capacity bytes;percentage";
176 for (
const auto& sb : blob_sizes) {
177 Blob* b = this->GetBlob(sb.second);
178 ShapeCall shape_fun = GetShapeCallFunction(b->
meta().
id());
179 CHECK(shape_fun !=
nullptr);
180 bool _shares_data =
false;
182 auto shape = shape_fun(b->GetRaw(), _shares_data, capacity);
183 std::stringstream ss;
184 ss << sb.second <<
";";
185 for (
const auto d : shape) {
188 LOG(INFO) << ss.str() <<
";" << sb.first <<
";" << std::setprecision(3)
189 << (cumtotal > 0 ? 100.0 * double(sb.first) / cumtotal : 0.0)
192 LOG(INFO) <<
"Total;;" << cumtotal <<
";100%";
196 vector<string> names;
197 for (
auto& entry : blob_map_) {
198 names.push_back(entry.first);
204 vector<string> names;
205 for (
auto& entry : blob_map_) {
206 names.push_back(entry.first);
209 vector<string> shared_blobs = shared_->Blobs();
210 names.insert(names.end(), shared_blobs.begin(), shared_blobs.end());
217 VLOG(1) <<
"Blob " << name <<
" already exists. Skipping.";
219 VLOG(1) <<
"Creating blob " << name;
220 blob_map_[name] = unique_ptr<Blob>(
new Blob());
222 return GetBlob(name);
226 auto it = blob_map_.find(name);
227 if (it != blob_map_.end()) {
228 VLOG(1) <<
"Removing blob " << name <<
" from this workspace.";
234 VLOG(1) <<
"Blob " << name <<
" not exists. Skipping.";
239 if (blob_map_.count(name)) {
240 return blob_map_.at(name).get();
241 }
else if (shared_ && shared_->HasBlob(name)) {
242 return shared_->GetBlob(name);
244 LOG(WARNING) <<
"Blob " << name <<
" not in the workspace.";
255 return const_cast<Blob*
>(
256 static_cast<const Workspace*
>(
this)->GetBlob(name));
260 CAFFE_ENFORCE(net_def.has_name(),
"Net definition should have a name.");
261 if (net_map_.count(net_def.name()) > 0) {
264 "I respectfully refuse to overwrite an existing net of the same " 267 "\", unless you explicitly specify overwrite=true.");
269 VLOG(1) <<
"Deleting existing network of the same name.";
274 net_map_.erase(net_def.name());
277 VLOG(1) <<
"Initializing network " << net_def.name();
278 net_map_[net_def.name()] =
280 if (net_map_[net_def.name()].get() ==
nullptr) {
281 LOG(ERROR) <<
"Error when creating the network.";
282 net_map_.erase(net_def.name());
285 return net_map_[net_def.name()].get();
289 if (!net_map_.count(name)) {
292 return net_map_[name].get();
297 if (net_map_.count(name)) {
298 net_map_.erase(name);
303 if (!net_map_.count(name)) {
304 LOG(ERROR) <<
"Network " << name <<
" does not exist yet.";
307 return net_map_[name]->Run();
310 bool Workspace::RunOperatorOnce(
const OperatorDef& op_def) {
311 std::unique_ptr<OperatorBase> op(CreateOperator(op_def,
this));
312 if (op.get() ==
nullptr) {
313 LOG(ERROR) <<
"Cannot create operator of type " << op_def.type();
317 LOG(ERROR) <<
"Error when running operator " << op_def.type();
322 bool Workspace::RunNetOnce(
const NetDef& net_def) {
325 LOG(ERROR) <<
"Error when running network " << net_def.name();
332 ShouldContinue shouldContinue) {
333 LOG(INFO) <<
"Started executing plan.";
334 if (plan.execution_step_size() == 0) {
335 LOG(WARNING) <<
"Nothing to run - did you define a correct plan?";
339 LOG(INFO) <<
"Initializing networks.";
341 std::set<string> seen_net_names_in_plan;
342 for (
const NetDef& net_def : plan.network()) {
344 seen_net_names_in_plan.count(net_def.name()) == 0,
345 "Your plan contains networks of the same name \"",
347 "\", which should not happen. Check your plan to see " 348 "if you made a programming error in creating the plan.");
349 seen_net_names_in_plan.insert(net_def.name());
356 LOG(ERROR) <<
"Failed initializing the networks.";
361 for (
const ExecutionStep& step : plan.execution_step()) {
364 if (!ExecuteStepRecursive(compiledStep)) {
365 LOG(ERROR) <<
"Failed initializing step " << step.name();
368 LOG(INFO) <<
"Step " << step.name() <<
" took " << step_timer.
Seconds()
371 LOG(INFO) <<
"Total plan took " << plan_timer.
Seconds() <<
" seconds.";
372 LOG(INFO) <<
"Plan executed successfully.";
377 ThreadPool* Workspace::GetThreadPool() {
378 std::lock_guard<std::mutex> guard(thread_pool_creation_mutex_);
381 int numThreads = std::thread::hardware_concurrency();
383 bool applyCap =
false;
385 applyCap = caffe2::FLAGS_caffe2_threadpool_android_cap;
387 applyCap = caffe2::FLAGS_caffe2_threadpool_ios_cap;
389 #error Undefined architecture 399 if (numThreads <= 3) {
401 }
else if (numThreads <= 5) {
406 numThreads = numThreads / 2;
410 LOG(INFO) <<
"Constructing thread pool with " << numThreads <<
" threads";
411 thread_pool_.reset(
new ThreadPool(numThreads));
414 return thread_pool_.get();
416 #endif // CAFFE2_MOBILE 421 struct ReporterInstance {
422 std::mutex report_mutex;
423 std::condition_variable report_cv;
424 std::thread report_thread;
425 ReporterInstance(
int intervalMillis,
bool* done, std::function<
void()> f) {
426 auto interval = std::chrono::milliseconds(intervalMillis);
427 auto reportWorker = [=]() {
428 std::unique_lock<std::mutex> lk(report_mutex);
430 report_cv.wait_for(lk, interval, [&]() {
return *done; });
434 report_thread = std::thread(reportWorker);
438 void start(int64_t intervalMillis, std::function<
void()> f) {
439 instances_.emplace_back(
new ReporterInstance(intervalMillis, &done, f));
444 for (
auto& instance : instances_) {
445 if (!instance->report_thread.joinable()) {
448 instance->report_cv.notify_all();
449 instance->report_thread.join();
454 std::vector<std::unique_ptr<ReporterInstance>> instances_;
460 #define CHECK_SHOULD_STOP(step, shouldStop) \ 461 if (getShouldStop(shouldStop)) { \ 462 VLOG(1) << "Execution step " << step.name() << " stopped by " \ 463 << step.should_stop_blob(); \ 468 const auto& step = *(compiledStep.step);
469 VLOG(1) <<
"Running execution step " << step.name();
471 std::unique_ptr<Reporter> reporter;
472 if (step.has_report_net() || compiledStep.reportSubsteps.size() > 0) {
473 reporter = caffe2::make_unique<Reporter>();
474 if (step.has_report_net()) {
476 step.has_report_interval(),
477 "A report_interval must be provided if report_net is set.");
478 if (net_map_.count(step.report_net()) == 0) {
479 LOG(ERROR) <<
"Report net " << step.report_net() <<
" not found.";
481 VLOG(1) <<
"Starting reporter net";
482 auto* net = net_map_[step.report_net()].get();
483 reporter->start(step.report_interval() * 1000, [=]() {
485 LOG(WARNING) <<
"Error running report_net.";
489 for (
auto& compiledSubstep : compiledStep.reportSubsteps) {
490 reporter->start(compiledSubstep->step->run_every_ms(), [=]() {
491 if (!ExecuteStepRecursive(*compiledSubstep)) {
492 LOG(WARNING) <<
"Error running report step.";
498 const Blob* shouldStop = compiledStep.shouldStop;
500 if (step.substep_size()) {
501 bool sequential = !step.concurrent_substeps() || step.substep().size() <= 1;
502 for (int64_t iter = 0; compiledStep.shouldContinue(iter); ++iter) {
504 VLOG(1) <<
"Executing step " << step.name() <<
" iteration " << iter;
505 for (
auto& compiledSubstep : compiledStep.recurringSubsteps) {
506 if (!ExecuteStepRecursive(*compiledSubstep)) {
509 CHECK_SHOULD_STOP(step, shouldStop);
512 VLOG(1) <<
"Executing step " << step.name() <<
" iteration " << iter
513 <<
" with " << step.substep().size() <<
" concurrent substeps";
515 std::atomic<int> next_substep{0};
516 std::mutex exception_mutex;
517 string first_exception;
518 auto worker = [&]() {
520 int substep_id = next_substep++;
521 if (compiledStep.gotFailure ||
522 (substep_id >= compiledStep.recurringSubsteps.size())) {
526 if (!ExecuteStepRecursive(
527 *compiledStep.recurringSubsteps.at(substep_id))) {
528 compiledStep.gotFailure =
true;
530 }
catch (
const std::exception& ex) {
531 std::lock_guard<std::mutex> guard(exception_mutex);
532 if (!first_exception.size()) {
533 first_exception = GetExceptionString(ex);
534 LOG(ERROR) <<
"Parallel worker exception:\n" << first_exception;
536 compiledStep.gotFailure =
true;
537 if (!FLAGS_caffe2_handle_executor_threads_exceptions) {
548 std::vector<std::thread> threads;
549 for (int64_t i = 0; i < step.substep().size(); ++i) {
550 if (!step.substep().Get(i).has_run_every_ms()) {
551 threads.emplace_back(worker);
554 for (
auto& thread: threads) {
557 if (compiledStep.gotFailure) {
558 LOG(ERROR) <<
"One of the workers failed.";
559 if (first_exception.size()) {
561 "One of the workers died with an unhandled exception ",
567 CHECK_SHOULD_STOP(step, shouldStop);
573 for (int64_t iter = 0; compiledStep.shouldContinue(iter); ++iter) {
574 VLOG(1) <<
"Executing networks " << step.name() <<
" iteration " << iter;
575 for (
NetBase* network : compiledStep.networks) {
576 if (!network->Run()) {
579 CHECK_SHOULD_STOP(step, shouldStop);
586 #undef CHECK_SHOULD_STOP
const Blob * GetBlob(const string &name) const
Gets the blob with the given name as a const pointer.
bool RunPlan(const PlanDef &plan_def, ShouldContinue should_continue=StopOnSignal{})
Runs a plan that has multiple nets and execution steps.
NetBase * GetNet(const string &net_name)
Gets the pointer to a created net.
float Seconds()
Returns the elapsed time in seconds.
NetBase * CreateNet(const NetDef &net_def, bool overwrite=false)
Creates a network with the given NetDef, and returns the pointer to the network.
bool RunNet(const string &net_name)
Finds and runs the instantiated network with the given name.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks.
vector< string > Blobs() const
Return a list of blob names.
A simple timer object for measuring time.
vector< string > LocalBlobs() const
Return list of blobs owned by this Workspace, not including blobs shared from parent workspace.
Simple registry implementation in Caffe2 that uses static variables to register object creators during program initialization time.
Blob is a general container that hosts a typed pointer.
bool RemoveBlob(const string &name)
Remove the blob of the given name.
void DeleteNet(const string &net_name)
Deletes the instantiated network with the given name.
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.
const TypeMeta & meta() const
Returns the meta info of the blob.
Blob * CreateBlob(const string &name)
Creates a blob of the given name.