RTBKit  0.9
Open-source framework to create real-time ad bidding systems.
soa/service/zmq_endpoint.h
00001 /* zmq_endpoint.h                                                  -*- C++ -*-
00002    Jeremy Barnes, 25 September 2012
00003    Copyright (c) 2012 Datacratic Inc.  All rights reserved.
00004 
00005    Endpoints for zeromq.
00006 */
00007 
00008 #ifndef __service__zmq_endpoint_h__
00009 #define __service__zmq_endpoint_h__
00010 
#include "named_endpoint.h"
#include "message_loop.h"
#include <cstring>
#include <iostream>
#include <map>
#include <mutex>
#include <set>
#include <sstream>
#include <type_traits>
#include "jml/utils/smart_ptr_utils.h"
#include "jml/utils/vector_utils.h"
#include <boost/make_shared.hpp>
#include "jml/arch/backtrace.h"
#include "jml/arch/timers.h"
#include "jml/arch/cmp_xchg.h"
#include "zmq_utils.h"
00022 
00023 namespace Datacratic {
00024 
00025 
00026 /*****************************************************************************/
00027 /* ZMQ EVENT SOURCE                                                          */
00028 /*****************************************************************************/
00029 
00032 struct ZmqEventSource : public AsyncEventSource {
00033 
00034     typedef std::function<void (std::vector<std::string>)>
00035         AsyncMessageHandler;
00036     AsyncMessageHandler asyncMessageHandler;
00037 
00038     typedef std::function<std::vector<std::string> (std::vector<std::string>)>
00039         SyncMessageHandler;
00040     SyncMessageHandler syncMessageHandler;
00041 
00042     typedef std::mutex SocketLock;
00043 
00044     ZmqEventSource();
00045 
00046     ZmqEventSource(zmq::socket_t & socket, SocketLock * lock = nullptr);
00047 
00053     template<typename T>
00054     ZmqEventSource(zmq::socket_t & socket,
00055                    const T & handler,
00056                    SocketLock * lock = nullptr,
00057                    typename std::enable_if<!std::is_convertible<decltype(std::declval<T>()(std::declval<std::vector<std::string> >())),
00058                                                                 std::vector<std::string> >::value, void>::type * = 0)
00059         : asyncMessageHandler(handler)
00060     {
00061         init(socket, lock);
00062     }
00063 
00068     template<typename T>
00069     ZmqEventSource(zmq::socket_t & socket,
00070                    const T & handler,
00071                    SocketLock * lock = nullptr,
00072                    typename std::enable_if<std::is_convertible<decltype(std::declval<T>()(std::declval<std::vector<std::string> >())),
00073                                                                 std::vector<std::string> >::value, void>::type * = 0)
00074         : syncMessageHandler(handler)
00075     {
00076         init(socket, lock);
00077     }
00078 
00079     void init(zmq::socket_t & socket, SocketLock * lock = nullptr);
00080 
00081     virtual int selectFd() const;
00082 
00083     virtual bool poll() const;
00084 
00085     virtual bool processOne();
00086 
00091     virtual void handleMessage(const std::vector<std::string> & message);
00092 
00096     virtual std::vector<std::string>
00097     handleSyncMessage(const std::vector<std::string> & message);
00098 
00099     zmq::socket_t & socket() const
00100     {
00101         ExcAssert(socket_);
00102         return *socket_;
00103     }
00104 
00105     SocketLock * socketLock() const
00106     {
00107         return socketLock_;
00108     }
00109 
00110     zmq::socket_t * socket_;
00111 
00112     SocketLock * socketLock_;
00113 };
00114 
00115 
00116 /*****************************************************************************/
00117 /* ZMQ BINARY EVENT SOURCE                                                   */
00118 /*****************************************************************************/
00119 
00122 struct ZmqBinaryEventSource : public AsyncEventSource {
00123 
00124     typedef std::function<void (std::vector<zmq::message_t> &&)>
00125         MessageHandler;
00126     MessageHandler messageHandler;
00127 
00128     ZmqBinaryEventSource()
00129         : socket_(0)
00130     {
00131         needsPoll = true;
00132     }
00133 
00134     ZmqBinaryEventSource(zmq::socket_t & socket,
00135                          MessageHandler messageHandler = MessageHandler())
00136         : messageHandler(std::move(messageHandler)),
00137           socket_(&socket)
00138     {
00139         needsPoll = true;
00140     }
00141 
00142     void init(zmq::socket_t & socket)
00143     {
00144         socket_ = &socket;
00145         needsPoll = true;
00146     }
00147 
00148     virtual int selectFd() const
00149     {
00150         int res = -1;
00151         size_t resSize = sizeof(int);
00152         socket().getsockopt(ZMQ_FD, &res, &resSize);
00153         if (res == -1)
00154             throw ML::Exception("no fd for zeromq socket");
00155         return res;
00156     }
00157 
00158     virtual bool poll() const
00159     {
00160         return getEvents(socket()).first;
00161     }
00162 
00163     virtual bool processOne()
00164     {
00165         ExcAssert(socket_);
00166 
00167         std::vector<zmq::message_t> messages;
00168 
00169         int64_t more = 1;
00170         size_t more_size = sizeof (more);
00171 
00172         while (more) {
00173             zmq::message_t message;
00174             bool got = socket_->recv(&message, messages.empty() ? ZMQ_NOBLOCK: 0);
00175             if (!got) return false;  // no first part available
00176             messages.emplace_back(std::move(message));
00177             socket_->getsockopt(ZMQ_RCVMORE, &more, &more_size);
00178         }
00179 
00180         handleMessage(std::move(messages));
00181 
00182         return poll();
00183     }
00184 
00189     virtual void handleMessage(std::vector<zmq::message_t> && message)
00190     {
00191         if (messageHandler)
00192             messageHandler(std::move(message));
00193         else throw ML::Exception("need to override handleMessage");
00194     }
00195 
00196     zmq::socket_t & socket() const
00197     {
00198         ExcAssert(socket_);
00199         return *socket_;
00200     }
00201 
00202     zmq::socket_t * socket_;
00203 
00204 };
00205 
00206 
00207 /*****************************************************************************/
00208 /* ZMQ BINARY TYPED EVENT SOURCE                                             */
00209 /*****************************************************************************/
00210 
00215 template<typename Arg>
00216 struct ZmqBinaryTypedEventSource: public AsyncEventSource {
00217 
00218     typedef std::function<void (Arg)> MessageHandler;
00219     MessageHandler messageHandler;
00220 
00221     ZmqBinaryTypedEventSource()
00222         : socket_(0)
00223     {
00224         needsPoll = true;
00225     }
00226 
00227     ZmqBinaryTypedEventSource(zmq::socket_t & socket,
00228                               MessageHandler messageHandler = MessageHandler())
00229         : messageHandler(std::move(messageHandler)),
00230           socket_(&socket)
00231     {
00232         needsPoll = true;
00233     }
00234 
00235     void init(zmq::socket_t & socket)
00236     {
00237         socket_ = &socket;
00238         needsPoll = true;
00239     }
00240 
00241     virtual int selectFd() const
00242     {
00243         int res = -1;
00244         size_t resSize = sizeof(int);
00245         socket().getsockopt(ZMQ_FD, &res, &resSize);
00246         if (res == -1)
00247             throw ML::Exception("no fd for zeromq socket");
00248         return res;
00249     }
00250 
00251     virtual bool poll() const
00252     {
00253         return getEvents(socket()).first;
00254     }
00255 
00256 #if 0
00257     template<typename Arg, int Index>
00258     const Arg &
00259     getArg(const std::vector<zmq::message_t> & messages,
00260            const ML::InPosition<Arg, Index> * arg)
00261     {
00262         auto & m = messages.at(Index);
00263         ExcAssertEqual(m.size(), sizeof(Arg));
00264         return * reinterpret_cast<const Arg *>(msg.data());
00265     }
00266 #endif
00267 
00268     virtual bool processOne()
00269     {
00270         zmq::message_t message;
00271         ExcAssert(socket_);
00272         bool got = socket_->recv(&message, ZMQ_NOBLOCK);
00273         if (!got) return false;
00274 
00275         handleMessage(* reinterpret_cast<const Arg *>(message.data()));
00276 
00277         return poll();
00278     }
00279 
00280     virtual void handleMessage(const Arg & arg)
00281     {
00282         if (messageHandler)
00283             messageHandler(arg);
00284         else throw ML::Exception("handleMessage not done");
00285     }
00286 
00287     zmq::socket_t & socket() const
00288     {
00289         ExcAssert(socket_);
00290         return *socket_;
00291     }
00292 
00293     zmq::socket_t * socket_;
00294 };
00295 
00296 
00297 /*****************************************************************************/
00298 /* ZMQ TYPED EVENT SOURCE                                                    */
00299 /*****************************************************************************/
00300 
00311 template<typename T>
00312 struct ZmqTypedEventSource: public ZmqEventSource {
00313     typedef std::function<void (T &&, std::string address)>
00314       OnMessage;
00315 
00316     OnMessage onMessage;
00317     bool routable;
00318     std::string messageTopic;
00319 
00320     ZmqTypedEventSource()
00321     {
00322     }
00323 
00324     ZmqTypedEventSource(zmq::socket_t & socket,
00325                         bool routable,
00326                         const std::string & messageTopic)
00327     {
00328         init(routable, messageTopic);
00329     }
00330 
00331     void init(zmq::socket_t & socket,
00332               bool routable,
00333               const std::string messageTopic)
00334     {
00335         this->routable = routable;
00336         this->messageTopic = messageTopic;
00337     }
00338 
00339     virtual void handleMessage(const std::vector<std::string> & message)
00340     {
00341         int expectedSize = routable + 2;
00342         if (message.size() != expectedSize)
00343             throw ML::Exception("unexpected message size in ZmqTypedMessageSink");
00344 
00345         int i = routable;
00346         if (message[i + 1] != messageTopic)
00347             throw ML::Exception("unexpected messake kind in ZmqTypedMessageSink");
00348 
00349         std::istringstream stream(message[i + 2]);
00350         ML::DB::Store_Reader store(stream);
00351         T result;
00352         store >> result;
00353 
00354         handleTypedMessage(std::move(result), routable ? message[0] : "");
00355     }
00356 
00357     virtual void handleTypedMessage(T && message, const std::string & address)
00358     {
00359         if (onMessage)
00360             onMessage(message, address);
00361         else
00362             throw ML::Exception("need to override handleTypedMessage or assign "
00363                                 "to onMessage");
00364     }
00365 };
00366 
00367 
00368 /*****************************************************************************/
00369 /* ZMQ SOCKET MONITOR                                                        */
00370 /*****************************************************************************/
00371 
00380 struct ZmqSocketMonitor : public ZmqBinaryTypedEventSource<zmq_event_t> {
00381 
00382     ZmqSocketMonitor(zmq::context_t & context);
00383 
00384     ~ZmqSocketMonitor()
00385     {
00386         shutdown();
00387     }
00388 
00389     void shutdown();
00390 
00397     void init(zmq::socket_t & socketToMonitor, int events = ZMQ_EVENT_ALL);
00398 
00410     typedef std::function<void (std::string, int, zmq_event_t)> EventHandler;
00411 
00412     // Success handlers
00413     EventHandler connectHandler, bindHandler, acceptHandler;
00414 
00415     // Socket event handlers
00416     EventHandler closeHandler, disconnectHandler;
00417 
00418     // Failure handlers
00419     EventHandler connectFailureHandler, acceptFailureHandler,
00420         bindFailureHandler, closeFailureHandler;
00421 
00422     // Retry handlers
00423     EventHandler connectRetryHandler;
00424 
00425     // Catch all handler, for when other handlers aren't registered
00426     EventHandler defaultHandler;
00427 
00434     virtual int handleEvent(const zmq_event_t & event);
00435 
00436 private:
00437     typedef std::mutex Lock;
00438     mutable Lock lock;
00439 
00441     std::string connectedUri;
00442 
00444     std::unique_ptr<zmq::socket_t> monitorEndpoint;
00445 
00447     zmq::socket_t * monitoredSocket;
00448 };
00449 
00450 
00451 /*****************************************************************************/
00452 /* ZEROMQ NAMED ENDPOINT                                                     */
00453 /*****************************************************************************/
00454 
00461 struct ZmqNamedEndpoint : public NamedEndpoint, public MessageLoop {
00462 
00463     ZmqNamedEndpoint(std::shared_ptr<zmq::context_t> context);
00464 
00465     ~ZmqNamedEndpoint()
00466     {
00467         shutdown();
00468     }
00469 
00470     void init(std::shared_ptr<ConfigurationService> config,
00471               int socketType,
00472               const std::string & endpointName);
00473 
00474     void shutdown()
00475     {
00476         MessageLoop::shutdown();
00477 
00478         if (socket_) {
00479             unbindAll();
00480             socket_.reset();
00481         }
00482 
00483         //ML::sleep(0.1);
00484         monitor.shutdown();
00485     }
00486 
00492     std::string bindTcp(PortRange const & portRange = PortRange(), std::string host = "");
00493 
00495     void bind(const std::string & address)
00496     {
00497         if (!socket_)
00498             throw ML::Exception("need to call ZmqNamedEndpoint::init() before "
00499                                 "bind");
00500 
00501         std::unique_lock<Lock> guard(lock);
00502         socket_->bind(address);
00503         boundAddresses[address];
00504     }
00505 
00507     void unbindAll()
00508     {
00509         std::unique_lock<Lock> guard(lock);
00510         ExcAssert(socket_);
00511         for (auto addr: boundAddresses)
00512             socket_->tryUnbind(addr.first);
00513     }
00514 
00515     template<typename... Args>
00516     void sendMessage(Args&&... args)
00517     {
00518         using namespace std;
00519         std::unique_lock<Lock> guard(lock);
00520         ExcAssert(socket_);
00521         Datacratic::sendMessage(*socket_, std::forward<Args>(args)...);
00522     }
00523 
00524     void sendMessage(const std::vector<std::string> & message)
00525     {
00526         using namespace std;
00527         std::unique_lock<Lock> guard(lock);
00528         ExcAssert(socket_);
00529         Datacratic::sendAll(*socket_, message);
00530     }
00531 
00533     void sendMessage(std::vector<zmq::message_t> && message)
00534     {
00535         using namespace std;
00536         std::unique_lock<Lock> guard(lock);
00537         ExcAssert(socket_);
00538         for (unsigned i = 0;  i < message.size();  ++i) {
00539             socket_->send(message[i],
00540                           i == message.size() - 1
00541                           ? 0 : ZMQ_SNDMORE);
00542         }
00543     }
00544 
00546     zmq::socket_t & getSocketUnsafe() const
00547     {
00548         ExcAssert(socket_);
00549         return *socket_;
00550     }
00551 
00552     typedef std::function<void (std::vector<zmq::message_t> &&)>
00553         RawMessageHandler;
00554     RawMessageHandler rawMessageHandler;
00555 
00556     typedef std::function<void (std::vector<std::string> &&)> MessageHandler;
00557     MessageHandler messageHandler;
00558 
00563     virtual void handleRawMessage(std::vector<zmq::message_t> && message)
00564     {
00565         if (rawMessageHandler)
00566             rawMessageHandler(std::move(message));
00567         else {
00568             std::vector<std::string> msg2;
00569             for (unsigned i = 0;  i < message.size();  ++i) {
00570                 msg2.push_back(message[i].toString());
00571             }
00572             handleMessage(std::move(msg2));
00573         }
00574     }
00575 
00576     virtual void handleMessage(std::vector<std::string> && message)
00577     {
00578         if (messageHandler)
00579             messageHandler(std::move(message));
00580         else throw ML::Exception("need to override handleRawMessage or "
00581                                  "handleMessage");
00582     }
00583 
00584     typedef std::function<void (std::string bindAddress)>
00585         ConnectionEventHandler;
00586 
00588     ConnectionEventHandler acceptEventHandler;
00589 
00591     ConnectionEventHandler disconnectEventHandler;
00592 
00594     ConnectionEventHandler closeEventHandler;
00595 
00597     virtual void handleAcceptEvent(std::string boundAddress)
00598     {
00599         if (acceptEventHandler)
00600             acceptEventHandler(boundAddress);
00601     }
00602 
00604     virtual void handleDisconnectEvent(std::string boundAddress)
00605     {
00606         if (disconnectEventHandler)
00607             disconnectEventHandler(boundAddress);
00608     }
00609 
00611     virtual void handleCloseEvent(std::string boundAddress)
00612     {
00613         if (closeEventHandler)
00614             closeEventHandler(boundAddress);
00615     }
00616 
00618     size_t numBoundAddresses() const
00619     {
00620         std::unique_lock<Lock> guard(lock);
00621         return boundAddresses.size();
00622     }
00623 
00627     size_t numActiveConnections(std::string addr = "") const
00628     {
00629         std::unique_lock<Lock> guard(lock);
00630         if (addr == "") {
00631             size_t result = 0;
00632             for (auto & addr: boundAddresses)
00633                 result += addr.second.connectedFds.size();
00634             return result;
00635         }
00636         else {
00637             auto it = boundAddresses.find(addr);
00638             if (it == boundAddresses.end())
00639                 return 0;
00640             return it->second.connectedFds.size();
00641         }
00642     }
00643 
00644 private:
00645     typedef std::mutex Lock;
00646     mutable Lock lock;
00647 
00648     std::shared_ptr<zmq::context_t> context_;
00649     std::shared_ptr<zmq::socket_t> socket_;
00650 
00653     ZmqSocketMonitor monitor;
00654 
00655     struct AddressInfo {
00656         AddressInfo()
00657             : listeningFd(-1)
00658         {
00659         }
00660 
00662         int listeningFd;
00663 
00665         std::set<int> connectedFds;
00666     };
00667 
00669     std::map<std::string, AddressInfo> boundAddresses;
00670 
00672     int socketType;
00673 };
00674 
00675 
00676 /*****************************************************************************/
00677 /* ZMQ NAMED CLIENT BUS                                                      */
00678 /*****************************************************************************/
00679 
00683 struct ZmqNamedClientBus: public ZmqNamedEndpoint {
00684 
00685     ZmqNamedClientBus(std::shared_ptr<zmq::context_t> context,
00686                       double deadClientDelay = 5.0)
00687         : ZmqNamedEndpoint(context), deadClientDelay(deadClientDelay)
00688     {
00689     }
00690 
00691     void init(std::shared_ptr<ConfigurationService> config,
00692               const std::string & endpointName)
00693     {
00694         ZmqNamedEndpoint::init(config, ZMQ_XREP, endpointName);
00695         addPeriodic("ZmqNamedClientBus::checkClient", 1.0,
00696                     [=] (uint64_t v) { this->onCheckClient(v); });
00697     }
00698 
00699     virtual ~ZmqNamedClientBus()
00700     {
00701         shutdown();
00702     }
00703 
00704     void shutdown()
00705     {
00706         MessageLoop::shutdown();
00707         ZmqNamedEndpoint::shutdown();
00708     }
00709 
00713     double deadClientDelay;
00714 
00716     std::function<void (std::string)> onConnection;
00717 
00721     std::function<void (std::string)> onDisconnection;
00722 
00723 
00724 
00725     template<typename... Args>
00726     void sendMessage(const std::string & address,
00727                      const std::string & topic,
00728                      Args&&... args)
00729     {
00730         ZmqNamedEndpoint::sendMessage(address, topic,
00731                                       std::forward<Args>(args)...);
00732     }
00733 
00734     virtual void handleMessage(std::vector<std::string> && message)
00735     {
00736         using namespace std;
00737         //cerr << "ZmqNamedClientBus got message " << message << endl;
00738 
00739         const std::string & agent = message.at(0);
00740         const std::string & topic = message.at(1);
00741 
00742         if (topic == "HEARTBEAT") {
00743             // Not the first message from the client
00744             auto it = clientInfo.find(agent);
00745             if (it == clientInfo.end()) {
00746                 // Disconnection then reconnection
00747                 if (onConnection)
00748                     onConnection(agent);
00749                 it = clientInfo.insert(make_pair(agent, ClientInfo())).first;
00750             }
00751             it->second.lastHeartbeat = Date::now();
00752             sendMessage(agent, "HEARTBEAT");
00753         }
00754         else if (topic == "HELLO") {
00755             // First message from client
00756             auto it = clientInfo.find(agent);
00757             if (it == clientInfo.end()) {
00758                 // New connection
00759                 if (onConnection)
00760                     onConnection(agent);
00761                 it = clientInfo.insert(make_pair(agent, ClientInfo())).first;
00762             }
00763             else {
00764                 // Client must have disappeared then reappared without us
00765                 // noticing.
00766                 // Do this disconnection, then the reconnection
00767                 if (onDisconnection)
00768                     onDisconnection(agent);
00769                 if (onConnection)
00770                     onConnection(agent);
00771             }
00772             it->second.lastHeartbeat = Date::now();
00773             sendMessage(agent, "HEARTBEAT");
00774         }
00775         else {
00776             handleClientMessage(message);
00777         }
00778 
00779 #if 0
00780         cerr << "poll() returned " << poll() << endl;
00781 
00782         std::vector<std::string> msg
00783             = recvAllNonBlocking(socket());
00784 
00785         while (!msg.empty()) {
00786             cerr << "*** GOT FURTHER MESSAGE " << msg << endl;
00787             msg = recvAllNonBlocking(socket());
00788         }
00789 
00790         cerr << "poll() returned " << poll() << endl;
00791 #endif
00792     }
00793 
00794     typedef std::function<void (std::vector<std::string>)>
00795     ClientMessageHandler;
00796     ClientMessageHandler clientMessageHandler;
00797 
00798     virtual void handleClientMessage(const std::vector<std::string> & message)
00799     {
00800         if (clientMessageHandler)
00801             clientMessageHandler(message);
00802         else {
00803             throw ML::Exception("need to assign to onClientMessage "
00804                                 "or override handleClientMessage for message "
00805                                 + message.at(1));
00806         }
00807 #if 0
00808         using namespace std;
00809         cerr << "ZmqNamedClientBus handleClientMessage " << message << endl;
00810         throw ML::Exception("handleClientMessage");
00811 #endif
00812     }
00813 
00814 private:
00815     void onCheckClient(uint64_t numEvents)
00816     {
00817         Date now = Date::now();
00818         Date expiry = now.plusSeconds(-deadClientDelay);
00819 
00820         std::vector<std::string> deadClients;
00821 
00822         for (auto & c: clientInfo)
00823             if (c.second.lastHeartbeat < expiry)
00824                 deadClients.push_back(c.first);
00825 
00826         for (auto d: deadClients) {
00827             if (onDisconnection)
00828                 onDisconnection(d);
00829             clientInfo.erase(d);
00830         }
00831     }
00832 
00833     struct ClientInfo {
00834         ClientInfo()
00835             : lastHeartbeat(Date::now())
00836         {
00837         }
00838 
00839         Date lastHeartbeat;
00840     };
00841 
00842     std::map<std::string, ClientInfo> clientInfo;
00843 };
00844 
00845 
/** How ZmqNamedProxy::connect / connectToServiceClass should behave
    (CS_ASYNCHRONOUS is their default).  The exact semantics live in
    zmq_endpoint.cc.
*/
enum ConnectionStyle {
    CS_ASYNCHRONOUS = 0,  ///< NOTE(review): presumably connect in the background
    CS_SYNCHRONOUS = 1,   ///< NOTE(review): presumably attempt to connect before returning
    CS_MUST_SUCCEED = 2   ///< NOTE(review): presumably fail if no connection can be made
};
00852 
00853 
00854 
00855 
00856 /*****************************************************************************/
00857 /* ZEROMQ NAMED PROXY                                                        */
00858 /*****************************************************************************/
00859 
00862 // THIS SHOULD BE REPLACED BY ZmqNamedSocket
00863 
00864 struct ZmqNamedProxy: public MessageLoop {
00865 
00866     ZmqNamedProxy();
00867 
00868     ZmqNamedProxy(std::shared_ptr<zmq::context_t> context);
00869 
00870     ~ZmqNamedProxy()
00871     {
00872         shutdown();
00873     }
00874 
00875     void shutdown()
00876     {
00877         MessageLoop::shutdown();
00878         if(socket_) {
00879             std::lock_guard<ZmqEventSource::SocketLock> guard(socketLock_);
00880             socket_.reset();
00881         }
00882     }
00883 
00884     bool isConnected() const { return connectionState == CONNECTED; }
00885 
00887     typedef std::function<void (std::string)> ConnectionHandler;
00888 
00891     ConnectionHandler connectHandler;
00892 
00896     virtual void onConnect(const std::string & source)
00897     {
00898         if (connectHandler)
00899             connectHandler(source);
00900     }
00901 
00904     ConnectionHandler disconnectHandler;
00905 
00909     virtual void onDisconnect(const std::string & source)
00910     {
00911         if (disconnectHandler)
00912             disconnectHandler(source);
00913     }
00914 
00915     void init(std::shared_ptr<ConfigurationService> config,
00916               int socketType,
00917               const std::string & identity = "");
00918 
00928     bool connect(const std::string & endpointName,
00929                  ConnectionStyle style = CS_ASYNCHRONOUS);
00930 
00937     bool connectToServiceClass(const std::string & serviceClass,
00938                                const std::string & endpointName,
00939                                ConnectionStyle style = CS_ASYNCHRONOUS);
00940 
00942     bool onConfigChange(ConfigurationService::ChangeType change,
00943                         const std::string & key,
00944                         const Json::Value & newValue);
00945 
00947     zmq::socket_t & socket() const
00948     {
00949         ExcAssert(socket_);
00950         return *socket_;
00951     }
00952 
00953     ZmqEventSource::SocketLock * socketLock() const
00954     {
00955         return &socketLock_;
00956     }
00957 
00958     template<typename... Args>
00959     void sendMessage(Args&&... args)
00960     {
00961         std::lock_guard<ZmqEventSource::SocketLock> guard(socketLock_);
00962 
00963         ExcCheckNotEqual(connectionState, NOT_CONNECTED,
00964                 "sending on an unconnected socket: " + endpointName);
00965 
00966         if (connectionState == CONNECTION_PENDING) {
00967             std::cerr << ("dropping message for " + endpointName + "\n");
00968             return;
00969         }
00970 
00971         Datacratic::sendMessage(socket(), std::forward<Args>(args)...);
00972     }
00973 
00974     void disconnect()
00975     {
00976         if (connectionState == NOT_CONNECTED) return;
00977 
00978         {
00979             std::lock_guard<ZmqEventSource::SocketLock> guard(socketLock_);
00980 
00981             if (connectionState == CONNECTED)
00982                 socket_->disconnect(connectedUri);
00983 
00984             connectionState = NOT_CONNECTED;
00985         }
00986 
00987         onDisconnect(connectedUri);
00988     }
00989 
00990 
00991 protected:
00992     ConfigurationService::Watch serviceWatch, endpointWatch;
00993     std::shared_ptr<ConfigurationService> config;
00994     std::shared_ptr<zmq::context_t> context_;
00995     std::shared_ptr<zmq::socket_t> socket_;
00996 
00997     mutable ZmqEventSource::SocketLock socketLock_;
00998 
00999     enum ConnectionType {
01000         NO_CONNECTION,        
01001         CONNECT_DIRECT,       
01002         CONNECT_TO_CLASS,     
01003     } connectionType;
01004 
01005     enum ConnectionState {
01006         NOT_CONNECTED,      // connect() was not called
01007         CONNECTION_PENDING, // connect() was called but service is not available
01008         CONNECTED           // connect() was called and the socket was connected
01009     } connectionState;
01010 
01011     void onServiceNodeChange(const std::string & path,
01012                              ConfigurationService::ChangeType change);
01013     void onEndpointNodeChange(const std::string & path,
01014                               ConfigurationService::ChangeType change);
01015 
01016     std::string serviceClass;      
01017     std::string endpointName;      
01018     std::string connectedService;  
01019     std::string connectedUri;      
01020 };
01021 
01022 
01023 /*****************************************************************************/
01024 /* ZEROMQ NAMED CLIENT BUS PROXY                                             */
01025 /*****************************************************************************/
01026 
01033 struct ZmqNamedClientBusProxy : public ZmqNamedProxy {
01034 
01035     ZmqNamedClientBusProxy()
01036         : timeout(2.0)
01037     {
01038     }
01039 
01040     ZmqNamedClientBusProxy(std::shared_ptr<zmq::context_t> context)
01041         : ZmqNamedProxy(context), timeout(2.0)
01042     {
01043     }
01044 
01045     ~ZmqNamedClientBusProxy()
01046     {
01047         shutdown();
01048     }
01049 
01050     void init(std::shared_ptr<ConfigurationService> config,
01051               const std::string & identity = "")
01052     {
01053         ZmqNamedProxy::init(config, ZMQ_XREQ, identity);
01054 
01055         auto doMessage = [=] (const std::vector<std::string> & message)
01056             {
01057                 const std::string & topic = message.at(0);
01058                 if (topic == "HEARTBEAT")
01059                     this->lastHeartbeat = Date::now();
01060                 else handleMessage(message);
01061             };
01062 
01063         addSource("ZmqNamedClientBusProxy::doMessage",
01064                   std::make_shared<ZmqEventSource>(socket(), doMessage, socketLock()));
01065  
01066         auto doHeartbeat = [=] (int64_t skipped)
01067             {
01068                 if (connectionState != CONNECTED) return;
01069 
01070                 sendMessage("HEARTBEAT");
01071 
01072                 auto now = Date::now();
01073                 auto end = now.plusSeconds(-timeout);
01074                 //if(lastHeartbeat < end) {
01075                     //std::cerr << "no heartbeat for " << timeout << "s... should be disconnecting from " << connectedUri << std::endl;
01076                     //disconnect();
01077                 //}
01078             };
01079 
01080         addPeriodic("ZmqNamedClientBusProxy::doHeartbeat", 1.0, doHeartbeat);
01081    }
01082 
01083     void shutdown()
01084     {
01085         MessageLoop::shutdown();
01086         ZmqNamedProxy::shutdown();
01087     }
01088 
01089     virtual void onConnect(const std::string & where)
01090     {
01091         lastHeartbeat = Date::now();
01092 
01093         sendMessage("HELLO");
01094 
01095         if (connectHandler)
01096             connectHandler(where);
01097     }
01098 
01099     virtual void onDisconnect(const std::string & where)
01100     {
01101         if (disconnectHandler)
01102             disconnectHandler(where);
01103     }
01104 
01105     ZmqEventSource::AsyncMessageHandler messageHandler;
01106 
01107     virtual void handleMessage(const std::vector<std::string> & message)
01108     {
01109         if (messageHandler)
01110             messageHandler(message);
01111         else
01112             throw ML::Exception("need to override on messageHandler or handleMessage");
01113     }
01114 
01115     Date lastHeartbeat;
01116     double timeout;
01117 };
01118 
01119 
01120 /*****************************************************************************/
01121 /* ZEROMQ MULTIPLE NAMED CLIENT BUS PROXY                                    */
01122 /*****************************************************************************/
01123 
01130 struct ZmqMultipleNamedClientBusProxy: public MessageLoop {
01131 
01132     ZmqMultipleNamedClientBusProxy()
01133         : zmqContext(new zmq::context_t(1))
01134     {
01135         connected = false;
01136     }
01137 
01138     ZmqMultipleNamedClientBusProxy(std::shared_ptr<zmq::context_t> context)
01139         : zmqContext(context)
01140     {
01141         connected = false;
01142     }
01143 
01144     ~ZmqMultipleNamedClientBusProxy()
01145     {
01146         shutdown();
01147     }
01148 
01149     void init(std::shared_ptr<ConfigurationService> config,
01150               const std::string & identity = "")
01151     {
01152         this->config = config;
01153         this->identity = identity;
01154     }
01155 
01156     void shutdown()
01157     {
01158         MessageLoop::shutdown();
01159         for (auto & c: connections)
01160             if (c.second)
01161                 c.second->shutdown();
01162     }
01163 
01164     template<typename... Args>
01165     void sendMessage(const std::string & recipient,
01166                      const std::string & topic,
01167                      Args&&... args) const
01168     {
01169         std::unique_lock<Lock> guard(connectionsLock);
01170         auto it = connections.find(recipient);
01171         if (it == connections.end()) {
01172             throw ML::Exception("attempt to deliver " + topic + " message to unknown client "
01173                                 + recipient);
01174         }
01175         it->second->sendMessage(topic, std::forward<Args>(args)...);
01176     }
01177 
01181     void connectAllServiceProviders(const std::string & serviceClass,
01182                                     const std::string & endpointName)
01183     {
01184         if (connected)
01185             throw ML::Exception("alread connected to service providers");
01186 
01187         this->serviceClass = serviceClass;
01188         this->endpointName = endpointName;
01189 
01190         serviceProvidersWatch.init([=] (const std::string & path,
01191                                         ConfigurationService::ChangeType change)
01192                                    {
01193                                        onServiceProvidersChanged("serviceClass/" + serviceClass);
01194                                    });
01195 
01196         onServiceProvidersChanged("serviceClass/" + serviceClass);
01197         // std::cerr << "++++after call to onServiceProvidersChanged " << std::endl;
01198         connected = true;
01199     }
01200 
01201 
01203     void connectSingleServiceProvider(const std::string & service);
01204 
01206     void connectUri(const std::string & zmqUri);
01207 
01208     size_t connectionCount() const
01209     {
01210         std::unique_lock<Lock> guard(connectionsLock);
01211         return connections.size();
01212     }
01213 
01215     typedef std::function<void (std::string)> ConnectionHandler;
01216 
01219     ConnectionHandler connectHandler;
01220 
01224     virtual void onConnect(const std::string & source)
01225     {
01226         if (connectHandler)
01227             connectHandler(source);
01228     }
01229 
01232     ConnectionHandler disconnectHandler;
01233 
01237     virtual void onDisconnect(const std::string & source)
01238     {
01239         if (disconnectHandler)
01240             disconnectHandler(source);
01241     }
01242 
01244     typedef std::function<void (std::string, std::vector<std::string>)>
01245     MessageHandler;
01246 
01250     MessageHandler messageHandler;
01251 
01256     virtual void handleMessage(const std::string & source,
01257                                const std::vector<std::string> & message)
01258     {
01259         if (messageHandler)
01260             messageHandler(source, message);
01261         else
01262             throw ML::Exception("need to override on messageHandler or handleMessage");
01263     }
01264 
01265 private:
01267     bool connected;
01268 
01270     std::shared_ptr<zmq::context_t> zmqContext;
01271 
01273     std::shared_ptr<ConfigurationService> config;
01274 
01276     std::string serviceClass;
01277 
01279     std::string endpointName;
01280 
01282     std::string identity;
01283 
01284     typedef ML::Spinlock Lock;
01285 
01287     mutable Lock connectionsLock;
01288 
01290     std::map<std::string, std::unique_ptr<ZmqNamedClientBusProxy> > connections;
01291 
01293     ConfigurationService::Watch serviceProvidersWatch;
01294 
01298     void onServiceProvidersChanged(const std::string & path)
01299     {
01300         using namespace std;
01301         //cerr << "onServiceProvidersChanged(" << path << ")" << endl;
01302 
01303         // The list of service providers has changed
01304 
01305         vector<string> children
01306             = config->getChildren(path, serviceProvidersWatch);
01307         for (auto c: children) {
01308             Json::Value value = config->getJson(path + "/" + c);
01309             std::string name = value["serviceName"].asString();
01310             std::string path = value["servicePath"].asString();
01311 
01312             watchServiceProvider(name, path);
01313         }
01314     }
01315 
01331     struct OnConnectCallback
01332     {
01333         OnConnectCallback(const ConnectionHandler& fn, std::string name) :
01334             fn(fn), name(name), state(DEFER)
01335         {}
01336 
01338         void release()
01339         {
01340             State old = state;
01341 
01342             ExcAssertNotEqual(old, CALL);
01343 
01344             // If the callback wasn't triggered while we were holding the lock
01345             // then trigger it the next time we see it.
01346             if (old == DEFER && ML::cmp_xchg(state, old, CALL)) return;
01347 
01348             ExcAssertEqual(old, DEFERRED);
01349             fn(name);
01350         }
01351 
01352         void operator() (std::string)
01353         {
01354             State old = state;
01355             ExcAssertNotEqual(old, DEFERRED);
01356 
01357             // If we're still in the locked section then trigger the callback
01358             // when release is called.
01359             if (old == DEFER && ML::cmp_xchg(state, old, DEFERRED)) return;
01360 
01361             // We're out of the locked section so just trigger the callback.
01362             ExcAssertEqual(old, CALL);
01363             fn(name);
01364         }
01365 
01366     private:
01367 
01368         ConnectionHandler fn;
01369         std::string name;
01370 
01371         enum State {
01372             DEFER,    // We're holding the lock so defer an incoming callback.
01373             DEFERRED, // We were called while holding the lock.
01374             CALL      // We were not called while holding the lock.
01375         } state;
01376     };
01377 
01379     void watchServiceProvider(const std::string & name, const std::string & path)
01380     {
01381         // Protects the connections map... I think.
01382         std::unique_lock<Lock> guard(connectionsLock);
01383 
01384         auto & c = connections[name];
01385 
01386         // already connected
01387         if (c) return;
01388 
01389         try {
01390             std::unique_ptr<ZmqNamedClientBusProxy> newClient
01391                 (new ZmqNamedClientBusProxy(zmqContext));
01392             newClient->init(config, identity);
01393 
01394             // The connect call below could trigger this callback while we're
01395             // holding the connectionsLock which is a big no-no. This fancy
01396             // wrapper ensures that it's only called after we call its release
01397             // function.
01398             newClient->connectHandler = OnConnectCallback(connectHandler, name);
01399 
01400             newClient->disconnectHandler = [=] (std::string s)
01401                 {
01402                     // TODO: chain in so that we know it's not around any more
01403                     this->onDisconnect(s);
01404                 };
01405 
01406             newClient->connect(path + "/" + endpointName);
01407             newClient->messageHandler = [=] (const std::vector<std::string> & msg)
01408                 {
01409                     this->handleMessage(name, msg);
01410                 };
01411             //newClient->debug(true);
01412 
01413             c.reset(newClient.release());
01414 
01415             // Add it to our message loop so that it can process messages
01416             addSource("ZmqMultipleNamedClientBusProxy child " + name, *c);
01417 
01418             guard.unlock();
01419             c->connectHandler.target<OnConnectCallback>()->release();
01420 
01421         } catch (...) {
01422             connections.erase(name);
01423             throw;
01424         }
01425     }
01426 
01427 };
01428 
01429 
01430 
01431 } // namespace Datacratic
01432 
01433 #endif /* __service__zmq_endpoint_h__ */
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator