Caffe2 - C++ API
A deep learning, cross-platform ML framework
blobs_queue.h
1 #pragma once
2 
3 #include <atomic>
4 #include <condition_variable>
5 #include <memory>
6 #include <mutex>
7 #include <queue>
8 
9 #include "caffe2/core/blob_stats.h"
10 #include "caffe2/core/logging.h"
11 #include "caffe2/core/stats.h"
12 #include "caffe2/core/tensor.h"
13 #include "caffe2/core/workspace.h"
14 
15 namespace caffe2 {
16 
17 // A thread-safe, bounded, blocking queue.
18 // Modelled as a circular buffer.
19 
20 // Containing blobs are owned by the workspace.
21 // On read, we swap out the underlying data for the blob passed in for blobs
22 
23 class BlobsQueue : public std::enable_shared_from_this<BlobsQueue> {
24  public:
25  BlobsQueue(
26  Workspace* ws,
27  const std::string& queueName,
28  size_t capacity,
29  size_t numBlobs,
30  bool enforceUniqueName,
31  const std::vector<std::string>& fieldNames = {})
32  : numBlobs_(numBlobs), stats_(queueName) {
33  if (!fieldNames.empty()) {
34  CAFFE_ENFORCE_EQ(
35  fieldNames.size(), numBlobs, "Wrong number of fieldNames provided.");
36  stats_.queue_dequeued_bytes.setDetails(fieldNames);
37  }
38  queue_.reserve(capacity);
39  for (auto i = 0; i < capacity; ++i) {
40  std::vector<Blob*> blobs;
41  blobs.reserve(numBlobs);
42  for (auto j = 0; j < numBlobs; ++j) {
43  const auto blobName =
44  queueName + "_" + to_string(i) + "_" + to_string(j);
45  if (enforceUniqueName) {
46  CAFFE_ENFORCE(
47  !ws->GetBlob(blobName),
48  "Queue internal blob already exists: ",
49  blobName);
50  }
51  blobs.push_back(ws->CreateBlob(blobName));
52  }
53  queue_.push_back(blobs);
54  }
55  DCHECK_EQ(queue_.size(), capacity);
56  }
57 
58  ~BlobsQueue() {
59  close();
60  }
61 
62  bool blockingRead(const std::vector<Blob*>& inputs) {
63  auto keeper = this->shared_from_this();
64  std::unique_lock<std::mutex> g(mutex_);
65  auto canRead = [this]() {
66  CAFFE_ENFORCE_LE(reader_, writer_);
67  return reader_ != writer_;
68  };
69  CAFFE_EVENT(stats_, queue_balance, -1);
70  cv_.wait(g, [this, canRead]() { return closing_ || canRead(); });
71  if (!canRead()) {
72  return false;
73  }
74  DCHECK(canRead());
75  auto& result = queue_[reader_ % queue_.size()];
76  CAFFE_ENFORCE(inputs.size() >= result.size());
77  for (auto i = 0; i < result.size(); ++i) {
78  auto bytes = BlobStat::sizeBytes(*result[i]);
79  CAFFE_EVENT(stats_, queue_dequeued_bytes, bytes, i);
80  using std::swap;
81  swap(*(inputs[i]), *(result[i]));
82  }
83  CAFFE_EVENT(stats_, queue_dequeued_records);
84  ++reader_;
85  cv_.notify_all();
86  return true;
87  }
88 
89  bool tryWrite(const std::vector<Blob*>& inputs) {
90  auto keeper = this->shared_from_this();
91  std::unique_lock<std::mutex> g(mutex_);
92  if (!canWrite()) {
93  return false;
94  }
95  CAFFE_EVENT(stats_, queue_balance, 1);
96  DCHECK(canWrite());
97  doWrite(inputs);
98  return true;
99  }
100 
101  bool blockingWrite(const std::vector<Blob*>& inputs) {
102  auto keeper = this->shared_from_this();
103  std::unique_lock<std::mutex> g(mutex_);
104  CAFFE_EVENT(stats_, queue_balance, 1);
105  cv_.wait(g, [this]() { return closing_ || canWrite(); });
106  if (!canWrite()) {
107  return false;
108  }
109  DCHECK(canWrite());
110  doWrite(inputs);
111  return true;
112  }
113 
114  void close() {
115  closing_ = true;
116 
117  std::lock_guard<std::mutex> g(mutex_);
118  cv_.notify_all();
119  }
120 
121  size_t getNumBlobs() const {
122  return numBlobs_;
123  }
124 
125  private:
126  bool canWrite() {
127  // writer is always within [reader, reader + size)
128  // we can write if reader is within [reader, reader + size)
129  CAFFE_ENFORCE_LE(reader_, writer_);
130  CAFFE_ENFORCE_LE(writer_, reader_ + queue_.size());
131  return writer_ != reader_ + queue_.size();
132  }
133 
134  void doWrite(const std::vector<Blob*>& inputs) {
135  auto& result = queue_[writer_ % queue_.size()];
136  CAFFE_ENFORCE(inputs.size() >= result.size());
137  for (auto i = 0; i < result.size(); ++i) {
138  using std::swap;
139  swap(*(inputs[i]), *(result[i]));
140  }
141  ++writer_;
142  cv_.notify_all();
143  }
144 
145  std::atomic<bool> closing_{false};
146 
147  size_t numBlobs_;
148  std::mutex mutex_; // protects all variables in the class.
149  std::condition_variable cv_;
150  int64_t reader_{0};
151  int64_t writer_{0};
152  std::vector<std::vector<Blob*>> queue_;
153 
154  struct QueueStats {
155  CAFFE_STAT_CTOR(QueueStats);
156  CAFFE_EXPORTED_STAT(queue_balance);
157  CAFFE_EXPORTED_STAT(queue_dequeued_records);
158  CAFFE_DETAILED_EXPORTED_STAT(queue_dequeued_bytes);
159  } stats_;
160 };
161 }
const Blob * GetBlob(const string &name) const
Gets the blob with the given name as a const pointer.
Definition: workspace.cc:238
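Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks. It is the owner of all these objects and deals with the scaffolding logistics.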
Definition: workspace.h:53
Simple registry implementation in Caffe2 that uses static variables to register object creators during program initialization time.
Blob * CreateBlob(const string &name)
Creates a blob of the given name.
Definition: workspace.cc:215