Caffe2 - C++ API
A deep learning, cross platform ML framework
zmq_helper.h
1 #ifndef CAFFE2_UTILS_ZMQ_HELPER_H_
2 #define CAFFE2_UTILS_ZMQ_HELPER_H_
3 
4 #include <zmq.h>
5 
6 #include "caffe2/core/logging.h"
7 
8 namespace caffe2 {
9 
10 class ZmqContext {
11  public:
12  explicit ZmqContext(int io_threads) : ptr_(zmq_ctx_new()) {
13  CAFFE_ENFORCE(ptr_ != nullptr, "Failed to create zmq context.");
14  int rc = zmq_ctx_set(ptr_, ZMQ_IO_THREADS, io_threads);
15  CAFFE_ENFORCE_EQ(rc, 0);
16  rc = zmq_ctx_set(ptr_, ZMQ_MAX_SOCKETS, ZMQ_MAX_SOCKETS_DFLT);
17  CAFFE_ENFORCE_EQ(rc, 0);
18  }
19  ~ZmqContext() {
20  int rc = zmq_ctx_destroy(ptr_);
21  CAFFE_ENFORCE_EQ(rc, 0);
22  }
23 
24  void* ptr() { return ptr_; }
25 
26  private:
27  void* ptr_;
28 
29  DISABLE_COPY_AND_ASSIGN(ZmqContext);
30 };
31 
32 class ZmqMessage {
33  public:
34  ZmqMessage() {
35  int rc = zmq_msg_init(&msg_);
36  CAFFE_ENFORCE_EQ(rc, 0);
37  }
38 
39  ~ZmqMessage() {
40  int rc = zmq_msg_close(&msg_);
41  CAFFE_ENFORCE_EQ(rc, 0);
42  }
43 
44  zmq_msg_t* msg() { return &msg_; }
45 
46  void* data() { return zmq_msg_data(&msg_); }
47  size_t size() { return zmq_msg_size(&msg_); }
48 
49  private:
50  zmq_msg_t msg_;
51  DISABLE_COPY_AND_ASSIGN(ZmqMessage);
52 };
53 
54 class ZmqSocket {
55  public:
56  explicit ZmqSocket(int type)
57  : context_(1), ptr_(zmq_socket(context_.ptr(), type)) {
58  CAFFE_ENFORCE(ptr_ != nullptr, "Faild to create zmq socket.");
59  }
60 
61  ~ZmqSocket() {
62  int rc = zmq_close(ptr_);
63  CAFFE_ENFORCE_EQ(rc, 0);
64  }
65 
66  void Bind(const string& addr) {
67  int rc = zmq_bind(ptr_, addr.c_str());
68  CAFFE_ENFORCE_EQ(rc, 0);
69  }
70 
71  void Unbind(const string& addr) {
72  int rc = zmq_unbind(ptr_, addr.c_str());
73  CAFFE_ENFORCE_EQ(rc, 0);
74  }
75 
76  void Connect(const string& addr) {
77  int rc = zmq_connect(ptr_, addr.c_str());
78  CAFFE_ENFORCE_EQ(rc, 0);
79  }
80 
81  void Disconnect(const string& addr) {
82  int rc = zmq_disconnect(ptr_, addr.c_str());
83  CAFFE_ENFORCE_EQ(rc, 0);
84  }
85 
86  int Send(const string& msg, int flags) {
87  int nbytes = zmq_send(ptr_, msg.c_str(), msg.size(), flags);
88  if (nbytes) {
89  return nbytes;
90  } else if (zmq_errno() == EAGAIN) {
91  return 0;
92  } else {
93  LOG(FATAL) << "Cannot send zmq message. Error number: "
94  << zmq_errno();
95  return 0;
96  }
97  }
98 
99  int SendTillSuccess(const string& msg, int flags) {
100  CAFFE_ENFORCE(msg.size(), "You cannot send an empty message.");
101  int nbytes = 0;
102  do {
103  nbytes = Send(msg, flags);
104  } while (nbytes == 0);
105  return nbytes;
106  }
107 
108  int Recv(ZmqMessage* msg) {
109  int nbytes = zmq_msg_recv(msg->msg(), ptr_, 0);
110  if (nbytes >= 0) {
111  return nbytes;
112  } else if (zmq_errno() == EAGAIN || zmq_errno() == EINTR) {
113  return 0;
114  } else {
115  LOG(FATAL) << "Cannot receive zmq message. Error number: "
116  << zmq_errno();
117  return 0;
118  }
119  }
120 
121  int RecvTillSuccess(ZmqMessage* msg) {
122  int nbytes = 0;
123  do {
124  nbytes = Recv(msg);
125  } while (nbytes == 0);
126  return nbytes;
127  }
128 
129  private:
130  ZmqContext context_;
131  void* ptr_;
132 };
133 
134 } // namespace caffe2
135 
136 
137 #endif // CAFFE2_UTILS_ZMQ_HELPER_H_
Simple registry implementation in Caffe2 that uses static variables to register object creators durin...