Caffe2 - C++ API
A deep learning, cross-platform ML framework
db.h
1 #ifndef CAFFE2_CORE_DB_H_
2 #define CAFFE2_CORE_DB_H_
3 
4 #include <mutex>
5 
6 #include "caffe2/core/blob_serialization.h"
7 #include "caffe2/core/registry.h"
8 #include "caffe2/proto/caffe2.pb.h"
9 
10 namespace caffe2 {
11 namespace db {
12 
/// Access mode requested when a DB backend is opened.
enum Mode {
  READ = 0,   ///< Open for reading; this is the mode DBReader::Open() uses.
  WRITE = 1,  ///< Presumably opens an existing db for writing (backend-defined).
  NEW = 2,    ///< Presumably creates a fresh db (backend-defined).
};
18 
22 class Cursor {
23  public:
24  Cursor() { }
25  virtual ~Cursor() { }
31  virtual void Seek(const string& key) = 0;
32  virtual bool SupportsSeek() { return false; }
36  virtual void SeekToFirst() = 0;
40  virtual void Next() = 0;
44  virtual string key() = 0;
48  virtual string value() = 0;
53  virtual bool Valid() = 0;
54 
55  DISABLE_COPY_AND_ASSIGN(Cursor);
56 };
57 
61 class Transaction {
62  public:
63  Transaction() { }
64  virtual ~Transaction() { }
68  virtual void Put(const string& key, const string& value) = 0;
72  virtual void Commit() = 0;
73 
74  DISABLE_COPY_AND_ASSIGN(Transaction);
75 };
76 
80 class DB {
81  public:
82  DB(const string& source, Mode mode) : mode_(mode) {}
83  virtual ~DB() { }
87  virtual void Close() = 0;
92  virtual std::unique_ptr<Cursor> NewCursor() = 0;
97  virtual std::unique_ptr<Transaction> NewTransaction() = 0;
98 
99  protected:
100  Mode mode_;
101 
102  DISABLE_COPY_AND_ASSIGN(DB);
103 };
104 
105 // Database classes are registered by their names so we can do optional
106 // dependencies.
107 CAFFE_DECLARE_REGISTRY(Caffe2DBRegistry, DB, const string&, Mode);
108 #define REGISTER_CAFFE2_DB(name, ...) \
109  CAFFE_REGISTER_CLASS(Caffe2DBRegistry, name, __VA_ARGS__)
110 
117 inline unique_ptr<DB> CreateDB(
118  const string& db_type, const string& source, Mode mode) {
119  auto result = Caffe2DBRegistry()->Create(db_type, source, mode);
120  VLOG(1) << ((!result) ? "not found db " : "found db ") << db_type;
121  return result;
122 }
123 
127 class DBReader {
128  public:
129 
130  friend class DBReaderSerializer;
131  DBReader() {}
132 
133  DBReader(
134  const string& db_type,
135  const string& source,
136  const int32_t num_shards = 1,
137  const int32_t shard_id = 0) {
138  Open(db_type, source, num_shards, shard_id);
139  }
140 
141  explicit DBReader(const DBReaderProto& proto) {
142  Open(proto.db_type(), proto.source());
143  if (proto.has_key()) {
144  CAFFE_ENFORCE(cursor_->SupportsSeek(),
145  "Encountering a proto that needs seeking but the db type "
146  "does not support it.");
147  cursor_->Seek(proto.key());
148  }
149  num_shards_ = 1;
150  shard_id_ = 0;
151  }
152 
153  explicit DBReader(std::unique_ptr<DB> db)
154  : db_type_("<memory-type>"),
155  source_("<memory-source>"),
156  db_(std::move(db)) {
157  CAFFE_ENFORCE(db_.get(), "Passed null db");
158  cursor_ = db_->NewCursor();
159  }
160 
161  void Open(
162  const string& db_type,
163  const string& source,
164  const int32_t num_shards = 1,
165  const int32_t shard_id = 0) {
166  // Note(jiayq): resetting is needed when we re-open e.g. leveldb where no
167  // concurrent access is allowed.
168  cursor_.reset();
169  db_.reset();
170  db_type_ = db_type;
171  source_ = source;
172  db_ = CreateDB(db_type_, source_, READ);
173  CAFFE_ENFORCE(db_,
174  "Cannot open db: ", source_, " of type ", db_type_);
175  CAFFE_ENFORCE(num_shards >= 1);
176  CAFFE_ENFORCE(shard_id >= 0);
177  CAFFE_ENFORCE(shard_id < num_shards);
178  num_shards_ = num_shards;
179  shard_id_ = shard_id;
180  cursor_ = db_->NewCursor();
181  SeekToFirst();
182  }
183 
184  public:
201  void Read(string* key, string* value) const {
202  CAFFE_ENFORCE(cursor_ != nullptr, "Reader not initialized.");
203  std::unique_lock<std::mutex> mutex_lock(reader_mutex_);
204  *key = cursor_->key();
205  *value = cursor_->value();
206 
207  // In sharded mode, each read skips num_shards_ records
208  for (int s = 0; s < num_shards_; s++) {
209  cursor_->Next();
210  if (!cursor_->Valid()) {
211  MoveToBeginning();
212  break;
213  }
214  }
215  }
216 
220  void SeekToFirst() const {
221  CAFFE_ENFORCE(cursor_ != nullptr, "Reader not initialized.");
222  std::unique_lock<std::mutex> mutex_lock(reader_mutex_);
223  MoveToBeginning();
224  }
225 
233  inline Cursor* cursor() const {
234  LOG(ERROR) << "Usually for a DBReader you should use Read() to be "
235  "thread safe. Consider refactoring your code.";
236  return cursor_.get();
237  }
238 
239  private:
240  void MoveToBeginning() const {
241  cursor_->SeekToFirst();
242  for (auto s = 0; s < shard_id_; s++) {
243  cursor_->Next();
244  CAFFE_ENFORCE(
245  cursor_->Valid(), "Db has less rows than shard id: ", s, shard_id_);
246  }
247  }
248 
249  string db_type_;
250  string source_;
251  unique_ptr<DB> db_;
252  unique_ptr<Cursor> cursor_;
253  mutable std::mutex reader_mutex_;
254  uint32_t num_shards_;
255  uint32_t shard_id_;
256 
257  DISABLE_COPY_AND_ASSIGN(DBReader);
258 };
259 
// Serializes a DBReader blob; DBReaderSerializer is a friend of DBReader so it
// can read the reader's private state. Presumably it records db_type, source
// and the current cursor key consumed by DBReader(const DBReaderProto&) —
// TODO confirm against the implementation file.
261  public:
266  void Serialize(
267  const Blob& blob,
268  const string& name,
269  BlobSerializerBase::SerializationAcceptor acceptor) override;
270 };
271 
// Deserializes a BlobProto into `blob`; presumably rebuilds a DBReader via
// DBReader(const DBReaderProto&) — TODO confirm against the implementation.
273  public:
274  void Deserialize(const BlobProto& proto, Blob* blob) override;
275 };
276 
277 } // namespace db
278 } // namespace caffe2
279 
280 #endif // CAFFE2_CORE_DB_H_
An abstract class for accessing a database of key-value pairs.
Definition: db.h:80
virtual string value()=0
Returns the current value.
void SeekToFirst() const
Seeks to the first key.
Definition: db.h:220
An abstract class for the cursor of the database while reading.
Definition: db.h:22
virtual bool Valid()=0
Returns whether the current location is valid — for example, it returns false once we have reached the end of the database.
BlobDeserializerBase is an abstract class that deserializes a blob from a BlobProto or a TensorProto.
Cursor * cursor() const
Returns the underlying cursor of the db reader.
Definition: db.h:233
virtual string key()=0
Returns the current key.
Simple registry implementation in Caffe2 that uses static variables to register object creators during program initialization time.
Blob is a general container that hosts a typed pointer.
Definition: blob.h:25
BlobSerializerBase is an abstract class that serializes a blob to a string.
virtual void Seek(const string &key)=0
Seek to a specific key (or, if the key does not exist, to the key immediately following it).
virtual void Next()=0
Go to the next location in the database.
virtual void SeekToFirst()=0
Seek to the first key in the database.
A reader wrapper for DB that also allows us to serialize it.
Definition: db.h:127
An abstract class for the current database transaction while writing.
Definition: db.h:61
void Read(string *key, string *value) const
Read one key-value pair from the db and advance the cursor to the next record (by num_shards records in sharded mode).
Definition: db.h:201