Caffe2 - C++ API
A deep learning, cross-platform ML framework
thread_pool.h
1 #ifndef CAFFE2_UTILS_THREAD_POOL_H_
2 #define CAFFE2_UTILS_THREAD_POOL_H_
3 
#include <condition_variable>
#include <cstddef>
#include <exception>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>
9 
/// @brief A fixed-size pool of worker threads that run queued tasks.
///
/// Tasks submitted with runTask() are stored in a FIFO queue and executed by
/// whichever worker thread dequeues them first. waitWorkComplete() blocks
/// until the queue is empty and every worker is idle. Destroying the pool lets
/// each worker finish the task it is currently running; tasks still queued at
/// that point are discarded without being run.
///
/// All mutable state below is guarded by mutex_.
class TaskThreadPool {
 private:
  std::queue<std::function<void()>> tasks_;  // pending tasks, FIFO order
  std::vector<std::thread> threads_;         // worker threads that started
  std::mutex mutex_;
  std::condition_variable condition_;  // signalled on new task or shutdown
  std::condition_variable completed_;  // signalled when the pool goes idle
  bool running_;           // false once shutdown has begun
  bool complete_;          // true when no task is queued or running
  std::size_t available_;  // number of workers not currently running a task
  std::size_t total_;      // number of workers requested at construction

 public:
  /// @brief Constructor. Starts @p pool_size worker threads.
  /// @param pool_size Number of worker threads. With 0 workers no task ever
  ///        runs, so waitWorkComplete() after a runTask() would block forever.
  /// @throws std::system_error if a worker thread cannot be started; workers
  ///         that did start are stopped and joined before the exception
  ///         propagates.
  explicit TaskThreadPool(std::size_t pool_size)
      : running_(true),
        complete_(true),
        available_(pool_size),
        total_(pool_size) {
    threads_.reserve(pool_size);
    try {
      for (std::size_t i = 0; i < pool_size; ++i) {
        threads_.emplace_back([this] { main_loop(); });
      }
    } catch (...) {
      // The destructor does not run for a partially constructed object, and
      // destroying a joinable std::thread calls std::terminate, so shut down
      // the workers that did start before rethrowing.
      stop_and_join();
      throw;
    }
  }

  // The pool owns threads that hold `this`; it must not be copied.
  TaskThreadPool(const TaskThreadPool&) = delete;
  TaskThreadPool& operator=(const TaskThreadPool&) = delete;

  /// @brief Destructor. Stops and joins every worker thread.
  ///
  /// A worker that is running a task finishes it first; tasks still waiting
  /// in the queue are discarded without running.
  ~TaskThreadPool() {
    stop_and_join();
  }

  /// @brief Adds a task to the queue; it runs as soon as a worker is free.
  /// This call never waits for a worker to become available.
  /// @param task Any callable invocable as `task()`. It is moved into a
  ///        std::function<void()>; any exception it throws is swallowed.
  template <typename Task>
  void runTask(Task task) {
    std::lock_guard<std::mutex> lock(mutex_);
    tasks_.emplace(std::move(task));
    complete_ = false;
    // Wake one waiting worker (if any) to pick up the new task.
    condition_.notify_one();
  }

  /// @brief Blocks until the queue is empty and no worker is running a task.
  /// Returns immediately if the pool is already idle.
  void waitWorkComplete() {
    std::unique_lock<std::mutex> lock(mutex_);
    // Predicate form re-checks complete_, so spurious wakeups cannot cause an
    // early return.
    completed_.wait(lock, [this] { return complete_; });
  }

 private:
  /// @brief Clears running_, wakes every worker and joins each started
  /// thread. Errors from join() are suppressed so this is safe to call from
  /// the destructor.
  void stop_and_join() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      running_ = false;
    }
    condition_.notify_all();
    for (auto& t : threads_) {
      if (!t.joinable()) {
        continue;
      }
      try {
        t.join();
      } catch (const std::exception&) {
        // Suppressed: destructors must not throw. Keep joining the rest.
      }
    }
  }

  /// @brief Entry point for each worker thread: repeatedly dequeues and runs
  /// tasks until shutdown is requested.
  void main_loop() {
    std::unique_lock<std::mutex> lock(mutex_);
    for (;;) {
      // running_ and tasks_ are only read while holding mutex_, avoiding a
      // data race with stop_and_join().
      condition_.wait(lock, [this] { return !running_ || !tasks_.empty(); });
      if (!running_) {
        break;  // Shutting down: queued tasks are left unrun.
      }

      {
        std::function<void()> task = std::move(tasks_.front());
        tasks_.pop();
        --available_;  // This worker is now busy.
        lock.unlock();

        try {
          task();
        } catch (...) {
          // Suppress all exceptions so one failing task cannot terminate the
          // process.
        }
      }  // `task` and anything it captured (e.g. bound shared_ptrs) is
         // destroyed here, outside the lock and before completion is
         // signalled.

      lock.lock();
      ++available_;  // This worker is idle again.
      if (tasks_.empty() && available_ == total_) {
        complete_ = true;
        completed_.notify_all();
      }
    }
  }
};
117 
118 #endif
void runTask(Task task)
Add a task to the thread pool's queue; it runs as soon as a worker thread is free.
Definition: thread_pool.h:53
void waitWorkComplete()
Wait until the task queue is empty and every worker thread is idle.
Definition: thread_pool.h:64
~TaskThreadPool()
Destructor. Stops and joins all worker threads; tasks still waiting in the queue are discarded.
Definition: thread_pool.h:34
TaskThreadPool(std::size_t pool_size)
Constructor. Starts pool_size worker threads.
Definition: thread_pool.h:24