Caffe2 - C++ API
A deep learning, cross-platform ML framework
zmqdb.cc
1 #include <atomic>
2 #include <condition_variable>
3 #include <mutex>
4 #include <thread> // NOLINT
5 
6 #include "caffe2/core/db.h"
7 #include "caffe2/utils/zmq_helper.h"
8 #include "caffe2/core/logging.h"
9 
10 namespace caffe2 {
11 namespace db {
12 
13 class ZmqDBCursor : public Cursor {
14  public:
15  explicit ZmqDBCursor(const string& source)
16  : source_(source), socket_(ZMQ_PULL),
17  prefetched_(false), finalize_(false) {
18  socket_.Connect(source_);
19  // Start prefetching thread.
20  prefetch_thread_.reset(
21  new std::thread([this] { this->Prefetch(); }));
22  // obtain the first value.
23  Next();
24  }
25 
26  ~ZmqDBCursor() {
27  finalize_ = true;
28  prefetched_ = false;
29  producer_.notify_one();
30  // Wait for the prefetch thread to finish elegantly.
31  prefetch_thread_->join();
32  socket_.Disconnect(source_);
33  }
34 
35  void Seek(const string& key) override { /* do nothing */ }
36 
37  void SeekToFirst() override { /* do nothing */ }
38 
39  void Next() override {
40  std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
41  while (!prefetched_) consumer_.wait(lock);
42  key_ = prefetch_key_;
43  value_ = prefetch_value_;
44  prefetched_ = false;
45  producer_.notify_one();
46  }
47 
48  string key() override { return key_; }
49  string value() override { return value_; }
50  bool Valid() override { return true; }
51 
52  private:
53 
54  void Prefetch() {
55  while (!finalize_) {
56  std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
57  while (prefetched_) producer_.wait(lock);
58  if (finalize_) {
59  return;
60  }
61  ZmqMessage msg;
62  socket_.RecvTillSuccess(&msg);
63  prefetch_key_.assign(static_cast<char*>(msg.data()), msg.size());
64  socket_.RecvTillSuccess(&msg);
65  prefetch_value_.assign(static_cast<char*>(msg.data()), msg.size());
66  prefetched_ = true;
67  consumer_.notify_one();
68  }
69  }
70 
71  string source_;
72  ZmqSocket socket_;
73  string key_;
74  string value_;
75  string prefetch_key_;
76  string prefetch_value_;
77 
78  unique_ptr<std::thread> prefetch_thread_;
79  std::mutex prefetch_access_mutex_;
80  std::condition_variable producer_, consumer_;
81  std::atomic<bool> prefetched_;
82  // finalize_ is used to tell the prefetcher to quit.
83  std::atomic<bool> finalize_;
84 };
85 
86 class ZmqDB : public DB {
87  public:
88  ZmqDB(const string& source, Mode mode)
89  : DB(source, mode), source_(source) {
90  CAFFE_ENFORCE(mode == READ, "ZeroMQ DB only supports read mode.");
91  }
92 
93  ~ZmqDB() {}
94 
95  void Close() override {}
96 
97  unique_ptr<Cursor> NewCursor() override {
98  return make_unique<ZmqDBCursor>(source_);
99  }
100 
101  unique_ptr<Transaction> NewTransaction() override {
102  CAFFE_THROW("ZeroMQ DB does not support writing with a transaction.");
103  return nullptr; // dummy placeholder to suppress old compiler warnings.
104  }
105 
106  private:
107  string source_;
108 };
109 
// Register the ZmqDB class with the Caffe2 DB registry under the name "ZmqDB".
110 REGISTER_CAFFE2_DB(ZmqDB, ZmqDB);
111 // For convenience, the same class is also registered under the lower-case name "zmqdb".
112 REGISTER_CAFFE2_DB(zmqdb, ZmqDB);
113 
114 } // namespace db
115 } // namespace caffe2
An abstract class for accessing a database of key-value pairs.
Definition: db.h:80
void Next() override
Go to the next location in the database.
Definition: zmqdb.cc:39
An abstract class for the cursor of the database while reading.
Definition: db.h:22
void SeekToFirst() override
Seek to the first key in the database.
Definition: zmqdb.cc:37
void Close() override
Closes the database.
Definition: zmqdb.cc:95
Simple registry implementation in Caffe2 that uses static variables to register object creators during program initialization time.
void Seek(const string &key) override
Seek to a specific key (or if the key does not exist, seek to the immediate next).
Definition: zmqdb.cc:35
unique_ptr< Cursor > NewCursor() override
Returns a cursor to read the database.
Definition: zmqdb.cc:97
string value() override
Returns the current value.
Definition: zmqdb.cc:49
unique_ptr< Transaction > NewTransaction() override
Returns a transaction to write data to the database.
Definition: zmqdb.cc:101
bool Valid() override
Returns whether the current location is valid - for example, if we have reached the end of the database, return false.
Definition: zmqdb.cc:50
string key() override
Returns the current key.
Definition: zmqdb.cc:48