RTBKit  0.9
Open-source framework to create real-time ad bidding systems.
soa/service/zmq_endpoint.cc
00001 /* zmq_endpoint.cc
00002    Jeremy Barnes, 9 November 2012
00003    Copyright (c) 2012 Datacratic Inc.  All rights reserved.
00004 
00005 */
00006 
00007 #include "zmq_endpoint.h"
00008 #include "jml/utils/smart_ptr_utils.h"
00009 #include <sys/utsname.h>
00010 #include <thread>
00011 #include "jml/arch/timers.h"
00012 
00013 using namespace std;
00014 
00015 
00016 namespace Datacratic {
00017 
00018 
00019 /*****************************************************************************/
00020 /* ZMQ EVENT SOURCE                                                          */
00021 /*****************************************************************************/
00022 
00023 ZmqEventSource::
00024 ZmqEventSource()
00025     : socket_(0), socketLock_(nullptr)
00026 {
00027     needsPoll = true;
00028 }
00029 
00030 ZmqEventSource::
00031 ZmqEventSource(zmq::socket_t & socket, SocketLock * socketLock)
00032     : socket_(&socket), socketLock_(socketLock)
00033 {
00034     needsPoll = true;
00035 }
00036 
00037 void
00038 ZmqEventSource::
00039 init(zmq::socket_t & socket, SocketLock * socketLock)
00040 {
00041     socket_ = &socket;
00042     socketLock_ = socketLock;
00043     needsPoll = true;
00044 }
00045 
00046 int
00047 ZmqEventSource::
00048 selectFd() const
00049 {
00050     int res = -1;
00051     size_t resSize = sizeof(int);
00052     socket().getsockopt(ZMQ_FD, &res, &resSize);
00053     if (res == -1)
00054         throw ML::Exception("no fd for zeromq socket");
00055     using namespace std;
00056     //cerr << "select FD is " << res << endl;
00057     return res;
00058 }
00059 
00060 bool
00061 ZmqEventSource::
00062 poll() const
00063 {
00064     return getEvents(socket()).first;
00065 
00066 #if 0
00067     using namespace std;
00068 
00069     
00070 
00071     zmq_pollitem_t toPoll = { socket(), 0, ZMQ_POLLIN };
00072     int res = zmq_poll(&toPoll, 1, 0);
00073     //cerr << "poll returned " << res << endl;
00074     if (res == -1)
00075         throw ML::Exception(errno, "zmq_poll");
00076     return res;
00077 #endif
00078 }
00079 
00080 bool
00081 ZmqEventSource::
00082 processOne()
00083 {
00084     using namespace std;
00085     if (debug_)
00086         cerr << "called processOne on " << this << ", poll = " << poll() << endl;
00087 
00088     std::vector<std::string> msg;
00089 
00093     {
00094         std::unique_lock<SocketLock> guard;
00095         if (socketLock_)
00096             guard = std::unique_lock<SocketLock>(*socketLock_);
00097 
00098         msg = recvAllNonBlocking(socket());
00099     }
00100 
00101     if (!msg.empty()) {
00102         if (debug_)
00103             cerr << "got message of length " << msg.size() << endl;
00104         handleMessage(msg);
00105     }
00106 
00107     return poll();
00108 }
00109 
00110 void
00111 ZmqEventSource::
00112 handleMessage(const std::vector<std::string> & message)
00113 {
00114     if (asyncMessageHandler) {
00115         asyncMessageHandler(message);
00116         return;
00117     }
00118 
00119     auto reply = handleSyncMessage(message);
00120     if (!reply.empty()) {
00121         sendAll(socket(), reply);
00122     }
00123 }
00124 
00125 std::vector<std::string>
00126 ZmqEventSource::
00127 handleSyncMessage(const std::vector<std::string> & message)
00128 {
00129     if (syncMessageHandler)
00130         return syncMessageHandler(message);
00131     throw ML::Exception("need to assign to or override one of the "
00132                         "message handlers");
00133 }
00134     
00135 
00136 /*****************************************************************************/
00137 /* ZMQ SOCKET MONITOR                                                        */
00138 /*****************************************************************************/
00139 
00140 //static int numMonitors = 0;
00141 
00142 ZmqSocketMonitor::
00143 ZmqSocketMonitor(zmq::context_t & context)
00144     : monitorEndpoint(new zmq::socket_t(context, ZMQ_PAIR)),
00145       monitoredSocket(0)
00146 {
00147     //cerr << "creating socket monitor at " << this << endl;
00148     //__sync_fetch_and_add(&numMonitors, 1);
00149 }
00150 
00151 void
00152 ZmqSocketMonitor::
00153 shutdown()
00154 {
00155     if (!monitorEndpoint)
00156         return;
00157 
00158     //cerr << "shutting down socket monitor at " << this << endl;
00159     
00160     connectedUri.clear();
00161     std::unique_lock<Lock> guard(lock);
00162     monitorEndpoint.reset();
00163 
00164     //cerr << __sync_add_and_fetch(&numMonitors, -1) << " monitors still active"
00165     //     << endl;
00166 }
00167 
00168 void
00169 ZmqSocketMonitor::
00170 init(zmq::socket_t & socketToMonitor, int events)
00171 {
00172     static int serial = 0;
00173 
00174     // Initialize the monitor connection
00175     connectedUri
00176         = ML::format("inproc://monitor-%p-%d",
00177                      this, __sync_fetch_and_add(&serial, 1));
00178     monitoredSocket = &socketToMonitor;
00179         
00180     //using namespace std;
00181     //cerr << "connecting monitor to " << connectedUri << endl;
00182 
00183     int res = zmq_socket_monitor(socketToMonitor, connectedUri.c_str(), events);
00184     if (res == -1)
00185         throw zmq::error_t();
00186 
00187     // Connect it in
00188     monitorEndpoint->connect(connectedUri.c_str());
00189 
00190     // Make sure we receive events from it
00191     ZmqBinaryTypedEventSource<zmq_event_t>::init(*monitorEndpoint);
00192 
00193     messageHandler = [=] (const zmq_event_t & event)
00194         {
00195             this->handleEvent(event);
00196         };
00197 }
00198 
00199 bool debugZmqMonitorEvents = false;
00200 
00201 int
00202 ZmqSocketMonitor::
00203 handleEvent(const zmq_event_t & event)
00204 {
00205     if (debugZmqMonitorEvents) {
00206         cerr << "got socket event " << printZmqEvent(event.event)
00207              << " at " << this
00208              << " " << connectedUri
00209              << " for socket " << monitoredSocket << endl;
00210     }
00211 
00212     auto doEvent = [&] (const EventHandler & handler,
00213                         const char * addr,
00214                         int param)
00215         {
00216             if (handler)
00217                 handler(addr, param, event);
00218             else if (defaultHandler)
00219                 defaultHandler(addr, param, event);
00220             else return 0;
00221             return 1;
00222         };
00223 
00224     switch (event.event) {
00225 
00226         // Bind
00227     case ZMQ_EVENT_LISTENING:
00228         return doEvent(bindHandler,
00229                        event.data.listening.addr,
00230                        event.data.listening.fd);
00231 
00232     case ZMQ_EVENT_BIND_FAILED:
00233         return doEvent(bindFailureHandler,
00234                        event.data.bind_failed.addr,
00235                        event.data.bind_failed.err);
00236 
00237         // Accept
00238     case ZMQ_EVENT_ACCEPTED:
00239         return doEvent(acceptHandler,
00240                        event.data.accepted.addr,
00241                        event.data.accepted.fd);
00242     case ZMQ_EVENT_ACCEPT_FAILED:
00243         return doEvent(acceptFailureHandler,
00244                        event.data.accept_failed.addr,
00245                        event.data.accept_failed.err);
00246         break;
00247 
00248         // Connect
00249     case ZMQ_EVENT_CONNECTED:
00250         return doEvent(connectHandler,
00251                        event.data.connected.addr,
00252                        event.data.connected.fd);
00253     case ZMQ_EVENT_CONNECT_DELAYED:
00254         return doEvent(connectFailureHandler,
00255                        event.data.connect_delayed.addr,
00256                        event.data.connect_delayed.err);
00257     case ZMQ_EVENT_CONNECT_RETRIED:
00258         return doEvent(connectRetryHandler,
00259                        event.data.connect_retried.addr,
00260                        event.data.connect_retried.interval);
00261             
00262         // Close and disconnection
00263     case ZMQ_EVENT_CLOSE_FAILED:
00264         return doEvent(closeFailureHandler,
00265                        event.data.close_failed.addr,
00266                        event.data.close_failed.err);
00267     case ZMQ_EVENT_CLOSED:
00268         return doEvent(closeHandler,
00269                        event.data.closed.addr,
00270                        event.data.closed.fd);
00271 
00272     case ZMQ_EVENT_DISCONNECTED:
00273         return doEvent(disconnectHandler,
00274                        event.data.disconnected.addr,
00275                        event.data.disconnected.fd);
00276             
00277     default:
00278         using namespace std;
00279         cerr << "got unknown event type " << event.event << endl;
00280         return doEvent(defaultHandler, "", -1);
00281     }
00282 }
00283 
00284 
00285 /*****************************************************************************/
00286 /* NAMED ZEROMQ ENDPOINT                                                     */
00287 /*****************************************************************************/
00288 
00289 ZmqNamedEndpoint::
00290 ZmqNamedEndpoint(std::shared_ptr<zmq::context_t> context)
00291     : context_(context), monitor(*context)
00292 {
00293 }
00294 
00295 void
00296 ZmqNamedEndpoint::
00297 init(std::shared_ptr<ConfigurationService> config,
00298      int socketType,
00299      const std::string & endpointName)
00300 {
00301     NamedEndpoint::init(config, endpointName);
00302     this->socketType = socketType;
00303     this->socket_.reset(new zmq::socket_t(*context_, socketType));
00304     setHwm(*socket_, 65536);
00305     
00306     bool monitorSocket = false;
00307 
00308     if (monitorSocket) {
00309         monitor.init(*socket_);
00310 
00311         monitor.bindHandler = [=] (std::string addr, int fd, const zmq_event_t &)
00312             {
00313                 std::unique_lock<Lock> guard(lock);
00314                 ExcAssert(!boundAddresses.count(addr));
00315                 boundAddresses[addr].listeningFd = fd;
00316             };
00317 
00318         monitor.acceptHandler = [=] (std::string addr, int fd, const zmq_event_t &)
00319             {
00320                 {
00321                     std::unique_lock<Lock> guard(lock);
00322                     ExcAssert(boundAddresses.count(addr));
00323                     bool added = boundAddresses[addr].connectedFds.insert(fd).second;
00324                     ExcAssert(added);
00325                 }
00326 
00327                 handleAcceptEvent(addr);
00328             };
00329 
00330         monitor.disconnectHandler = [=] (std::string addr, int fd, const zmq_event_t &)
00331             {
00332                 {
00333                     std::unique_lock<Lock> guard(lock);
00334                     ExcAssert(boundAddresses.count(addr));
00335                     bool erased = boundAddresses[addr].connectedFds.erase(fd);
00336                     ExcAssert(erased);
00337                 }
00338             
00339                 handleDisconnectEvent(addr);
00340             };
00341 
00342         monitor.closeHandler = [=] (std::string addr, int fd, const zmq_event_t &)
00343             {
00344                 {
00345                     std::unique_lock<Lock> guard(lock);
00346                     if (boundAddresses[addr].listeningFd != -1)
00347                         ExcAssertEqual(boundAddresses[addr].listeningFd, fd);
00348                     boundAddresses.erase(addr);
00349                 }
00350             
00351                 handleDisconnectEvent(addr);
00352             };
00353     
00354         // zmq_bind() returns this information for us
00355         monitor.bindFailureHandler = [=] (std::string addr, int fd, const zmq_event_t &)
00356             {
00357             };
00358     
00359 
00360         monitor.defaultHandler = [=] (string addr, int param,
00361                                       const zmq_event_t & event)
00362             {
00363                 cerr << "ZmqNamedEndpoint got socket event "
00364                 << printZmqEvent(event.event)
00365                 << " on " << addr << " with " << param;
00366                 if (zmqEventIsError(event.event))
00367                     cerr << " " << strerror(param);
00368                 cerr<< endl;
00369             };
00370 
00371         addSource("ZmqNamedEndpoint::monitor", monitor);
00372     }
00373 
00374     addSource("ZmqNamedEndpoint::socket",
00375               std::make_shared<ZmqBinaryEventSource>
00376               (*socket_, [=] (std::vector<zmq::message_t> && message)
00377                {
00378                    handleRawMessage(std::move(message));
00379                }));
00380 }
00381 
00382 std::string
00383 ZmqNamedEndpoint::
00384 bindTcp(PortRange const & portRange, std::string host)
00385 {
00386     std::unique_lock<Lock> guard(lock);
00387 
00388     if (!socket_)
00389         throw ML::Exception("need to call ZmqNamedEndpoint::init() before "
00390                             "bind");
00391 
00392     using namespace std;
00393 
00394     if (host == "")
00395         host = "*";
00396 
00397     int port = bindAndReturnOpenTcpPort(*socket_, portRange, host);
00398 
00399     auto getUri = [&] (const std::string & host)
00400         {
00401             return "tcp://" + host + ":" + to_string(port);
00402         };
00403 
00404     Json::Value config;
00405 
00406     auto addEntry = [&] (const std::string & addr,
00407                          const std::string & hostScope,
00408                          const std::string & uri)
00409         {
00410             Json::Value & entry = config[config.size()];
00411             entry["zmqConnectUri"] = uri;
00412 
00413             Json::Value & transports = entry["transports"];
00414             transports[0]["name"] = "tcp";
00415             transports[0]["addr"] = addr;
00416             transports[0]["hostScope"] = hostScope;
00417             transports[0]["port"] = port;
00418             transports[1]["name"] = "zeromq";
00419             transports[1]["socketType"] = socketType;
00420             transports[1]["uri"] = uri;
00421         };
00422 
00423     if (host == "*") {
00424         auto interfaces = getInterfaces({AF_INET});
00425         for (unsigned i = 0;  i < interfaces.size();  ++i) {
00426             addEntry(interfaces[i].addr, interfaces[i].hostScope,
00427                      getUri(interfaces[i].addr));
00428         }
00429         publishAddress("tcp", config);
00430         return getUri(host);
00431     }
00432     else {
00433         string host2 = addrToIp(host);
00434         string uri = getUri(host2);
00435         // TODO: compute the host scope; don't just assume "*"
00436         addEntry(host2, "*", uri);
00437         publishAddress("tcp", config);
00438         return uri;
00439     }
00440 }
00441 
00442  
00443 
00444 /*****************************************************************************/
00445 /* NAMED ZEROMQ PROXY                                                        */
00446 /*****************************************************************************/
00447 
00448 ZmqNamedProxy::
00449 ZmqNamedProxy()
00450     : context_(new zmq::context_t(1))
00451 {
00452 }
00453 
00454 ZmqNamedProxy::
00455 ZmqNamedProxy(std::shared_ptr<zmq::context_t> context)
00456     : context_(context)
00457 {
00458 }
00459 
00460 void
00461 ZmqNamedProxy::
00462 init(std::shared_ptr<ConfigurationService> config,
00463      int socketType,
00464      const std::string & identity)
00465 {
00466     this->connectionType = NO_CONNECTION;
00467     this->connectionState = NOT_CONNECTED;
00468 
00469     this->config = config;
00470     socket_.reset(new zmq::socket_t(*context_, socketType));
00471     if (identity != "")
00472         setIdentity(*socket_, identity);
00473     setHwm(*socket_, 65536);
00474 
00475     serviceWatch.init(std::bind(&ZmqNamedProxy::onServiceNodeChange,
00476                                 this,
00477                                 std::placeholders::_1,
00478                                 std::placeholders::_2));
00479 
00480     endpointWatch.init(std::bind(&ZmqNamedProxy::onEndpointNodeChange,
00481                                  this,
00482                                  std::placeholders::_1,
00483                                  std::placeholders::_2));
00484 }
00485 
00486 bool
00487 ZmqNamedProxy::
00488 connect(const std::string & endpointName,
00489         ConnectionStyle style)
00490 {
00491     if (!config)
00492         throw ML::Exception("attempt to connect to named service "
00493                             + endpointName + " without calling init()");
00494 
00495     if (connectionState == CONNECTED)
00496         throw ML::Exception("already connected");
00497 
00498     this->connectedService = endpointName;
00499     if (connectionType == NO_CONNECTION)
00500         connectionType = CONNECT_DIRECT;
00501 
00502     cerr << "connecting to " << endpointName << endl;
00503 
00504     vector<string> children
00505         = config->getChildren(endpointName, endpointWatch);
00506 
00507     auto setPending = [&]
00508         {
00509             std::lock_guard<ZmqEventSource::SocketLock> guard(socketLock_);
00510 
00511             if (connectionState == NOT_CONNECTED)
00512                 connectionState = CONNECTION_PENDING;
00513         };
00514 
00515     for (auto c: children) {
00516         ExcAssertNotEqual(connectionState, CONNECTED);
00517         string key = endpointName + "/" + c;
00518         //cerr << "got key " << key << endl;
00519         Json::Value epConfig = config->getJson(key);
00520 
00521         //cerr << "epConfig for " << key << " is " << epConfig
00522         //     << endl;
00523                 
00524         for (auto & entry: epConfig) {
00525 
00526             //cerr << "entry is " << entry << endl;
00527 
00528             if (!entry.isMember("zmqConnectUri"))
00529                 return true;
00530 
00531             string uri = entry["zmqConnectUri"].asString();
00532 
00533             auto hs = entry["transports"][0]["hostScope"];
00534             if (!hs)
00535                 continue;
00536 
00537             string hostScope = hs.asString();
00538             if (hs != "*") {
00539                 utsname name;
00540                 if (uname(&name))
00541                     throw ML::Exception(errno, "uname");
00542                 if (hostScope != name.nodename)
00543                     continue;  // wrong host scope
00544             }
00545 
00546             {
00547                 std::lock_guard<ZmqEventSource::SocketLock> guard(socketLock_);
00548                 socket().connect(uri.c_str());
00549                 connectedUri = uri;
00550                 connectionState = CONNECTED;
00551             }
00552 
00553             cerr << "connected to " << uri << endl;
00554             onConnect(uri);
00555             return true;
00556         }
00557 
00558         setPending();
00559         return false;
00560     }
00561 
00562     if (style == CS_MUST_SUCCEED && connectionState != CONNECTED)
00563         throw ML::Exception("couldn't connect to any services of class "
00564                             + serviceClass);
00565 
00566     setPending();
00567     return connectionState == CONNECTED;
00568 }
00569 
00570 bool
00571 ZmqNamedProxy::
00572 connectToServiceClass(const std::string & serviceClass,
00573                       const std::string & endpointName,
00574                       ConnectionStyle style)
00575 {
00576     // TODO: exception safety... if we bail don't screw around the auction
00577     ExcAssertNotEqual(connectionType, CONNECT_DIRECT);
00578     ExcAssertNotEqual(serviceClass, "");
00579     ExcAssertNotEqual(endpointName, "");
00580 
00581     //cerr << "serviceClass = " << serviceClass << endl;
00582 
00583     this->serviceClass = serviceClass;
00584     this->endpointName = endpointName;
00585 
00586     if (connectionType == NO_CONNECTION)
00587         connectionType = CONNECT_TO_CLASS;
00588 
00589     if (!config)
00590         throw ML::Exception("attempt to connect to named service "
00591                             + endpointName + " without calling init()");
00592 
00593     if (connectionState == CONNECTED)
00594         throw ML::Exception("attempt to double connect connection");
00595 
00596     vector<string> children
00597         = config->getChildren("serviceClass/" + serviceClass, serviceWatch);
00598 
00599     for (auto c: children) {
00600         string key = "serviceClass/" + serviceClass + "/" + c;
00601         //cerr << "getting " << key << endl;
00602         Json::Value value = config->getJson(key);
00603         std::string name = value["serviceName"].asString();
00604         std::string path = value["servicePath"].asString();
00605 
00606         //cerr << "name = " << name << " path = " << path << endl;
00607         if (connect(path + "/" + endpointName,
00608                     style == CS_ASYNCHRONOUS ? CS_ASYNCHRONOUS : CS_SYNCHRONOUS))
00609             return true;
00610     }
00611 
00612     if (style == CS_MUST_SUCCEED && connectionState != CONNECTED)
00613         throw ML::Exception("couldn't connect to any services of class "
00614                             + serviceClass);
00615 
00616     {
00617         std::lock_guard<ZmqEventSource::SocketLock> guard(socketLock_);
00618 
00619         if (connectionState == NOT_CONNECTED)
00620             connectionState = CONNECTION_PENDING;
00621     }
00622 
00623     return connectionState == CONNECTED;
00624 }
00625 
00626 void
00627 ZmqNamedProxy::
00628 onServiceNodeChange(const std::string & path,
00629                     ConfigurationService::ChangeType change)
00630 {
00631     //cerr << "******* CHANGE TO SERVICE NODE " << path << endl;
00632 
00633     if (connectionState != CONNECTION_PENDING)
00634         return;  // no need to watch anymore
00635 
00636     connectToServiceClass(serviceClass, endpointName, CS_ASYNCHRONOUS);
00637 }
00638 
00639 void
00640 ZmqNamedProxy::
00641 onEndpointNodeChange(const std::string & path,
00642                      ConfigurationService::ChangeType change)
00643 {
00644     //cerr << "******* CHANGE TO ENDPOINT NODE " << path << endl;
00645 
00646     if (connectionState != CONNECTION_PENDING)
00647         return;  // no need to watch anymore
00648 
00649     connect(connectedService, CS_ASYNCHRONOUS);
00650 }
00651 
00652 
00653 } // namespace Datacratic
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator