RTBKit
0.9
Open-source framework to create real-time ad bidding systems.
|
00001 /* zmq_endpoint.h -*- C++ -*- 00002 Jeremy Barnes, 25 September 2012 00003 Copyright (c) 2012 Datacratic Inc. All rights reserved. 00004 00005 Endpoints for zeromq. 00006 */ 00007 00008 #ifndef __service__zmq_endpoint_h__ 00009 #define __service__zmq_endpoint_h__ 00010 00011 #include "named_endpoint.h" 00012 #include "message_loop.h" 00013 #include <set> 00014 #include <type_traits> 00015 #include "jml/utils/smart_ptr_utils.h" 00016 #include "jml/utils/vector_utils.h" 00017 #include <boost/make_shared.hpp> 00018 #include "jml/arch/backtrace.h" 00019 #include "jml/arch/timers.h" 00020 #include "jml/arch/cmp_xchg.h" 00021 #include "zmq_utils.h" 00022 00023 namespace Datacratic { 00024 00025 00026 /*****************************************************************************/ 00027 /* ZMQ EVENT SOURCE */ 00028 /*****************************************************************************/ 00029 00032 struct ZmqEventSource : public AsyncEventSource { 00033 00034 typedef std::function<void (std::vector<std::string>)> 00035 AsyncMessageHandler; 00036 AsyncMessageHandler asyncMessageHandler; 00037 00038 typedef std::function<std::vector<std::string> (std::vector<std::string>)> 00039 SyncMessageHandler; 00040 SyncMessageHandler syncMessageHandler; 00041 00042 typedef std::mutex SocketLock; 00043 00044 ZmqEventSource(); 00045 00046 ZmqEventSource(zmq::socket_t & socket, SocketLock * lock = nullptr); 00047 00053 template<typename T> 00054 ZmqEventSource(zmq::socket_t & socket, 00055 const T & handler, 00056 SocketLock * lock = nullptr, 00057 typename std::enable_if<!std::is_convertible<decltype(std::declval<T>()(std::declval<std::vector<std::string> >())), 00058 std::vector<std::string> >::value, void>::type * = 0) 00059 : asyncMessageHandler(handler) 00060 { 00061 init(socket, lock); 00062 } 00063 00068 template<typename T> 00069 ZmqEventSource(zmq::socket_t & socket, 00070 const T & handler, 00071 SocketLock * lock = nullptr, 00072 typename std::enable_if<std::is_convertible<decltype(std::declval<T>()(std::declval<std::vector<std::string> >())), 00073 std::vector<std::string> >::value, void>::type * = 0) 00074 : syncMessageHandler(handler) 00075 { 00076 init(socket, lock); 00077 } 00078 00079 void init(zmq::socket_t & socket, SocketLock * lock = nullptr); 00080 00081 virtual int selectFd() const; 00082 00083 virtual bool poll() const; 00084 00085 virtual bool processOne(); 00086 00091 virtual void handleMessage(const std::vector<std::string> & message); 00092 00096 virtual std::vector<std::string> 00097 handleSyncMessage(const std::vector<std::string> & message); 00098 00099 zmq::socket_t & socket() const 00100 { 00101 ExcAssert(socket_); 00102 return *socket_; 00103 } 00104 00105 SocketLock * socketLock() const 00106 { 00107 return socketLock_; 00108 } 00109 00110 zmq::socket_t * socket_; 00111 00112 SocketLock * socketLock_; 00113 }; 00114 00115 00116 /*****************************************************************************/ 00117 /* ZMQ BINARY EVENT SOURCE */ 00118 /*****************************************************************************/ 00119 00122 struct ZmqBinaryEventSource : public AsyncEventSource { 00123 00124 typedef std::function<void (std::vector<zmq::message_t> &&)> 00125 MessageHandler; 00126 MessageHandler messageHandler; 00127 00128 ZmqBinaryEventSource() 00129 : socket_(0) 00130 { 00131 needsPoll = true; 00132 } 00133 00134 ZmqBinaryEventSource(zmq::socket_t & socket, 00135 MessageHandler messageHandler = MessageHandler()) 00136 : messageHandler(std::move(messageHandler)), 00137 socket_(&socket) 00138 { 00139 needsPoll = true; 00140 } 00141 00142 void init(zmq::socket_t & socket) 00143 { 00144 socket_ = &socket; 00145 needsPoll = true; 00146 } 00147 00148 virtual int selectFd() const 00149 { 00150 int res = -1; 00151 size_t resSize = sizeof(int); 00152 socket().getsockopt(ZMQ_FD, &res, &resSize); 00153 if (res == -1) 00154 throw ML::Exception("no fd for zeromq socket"); 00155 return res; 00156 } 00157 00158 virtual bool poll() const 00159 { 00160 return getEvents(socket()).first; 00161 } 00162 00163 virtual bool processOne() 00164 { 00165 ExcAssert(socket_); 00166 00167 std::vector<zmq::message_t> messages; 00168 00169 int64_t more = 1; 00170 size_t more_size = sizeof (more); 00171 00172 while (more) { 00173 zmq::message_t message; 00174 bool got = socket_->recv(&message, messages.empty() ? ZMQ_NOBLOCK: 0); 00175 if (!got) return false; // no first part available 00176 messages.emplace_back(std::move(message)); 00177 socket_->getsockopt(ZMQ_RCVMORE, &more, &more_size); 00178 } 00179 00180 handleMessage(std::move(messages)); 00181 00182 return poll(); 00183 } 00184 00189 virtual void handleMessage(std::vector<zmq::message_t> && message) 00190 { 00191 if (messageHandler) 00192 messageHandler(std::move(message)); 00193 else throw ML::Exception("need to override handleMessage"); 00194 } 00195 00196 zmq::socket_t & socket() const 00197 { 00198 ExcAssert(socket_); 00199 return *socket_; 00200 } 00201 00202 zmq::socket_t * socket_; 00203 00204 }; 00205 00206 00207 /*****************************************************************************/ 00208 /* ZMQ BINARY TYPED EVENT SOURCE */ 00209 /*****************************************************************************/ 00210 00215 template<typename Arg> 00216 struct ZmqBinaryTypedEventSource: public AsyncEventSource { 00217 00218 typedef std::function<void (Arg)> MessageHandler; 00219 MessageHandler messageHandler; 00220 00221 ZmqBinaryTypedEventSource() 00222 : socket_(0) 00223 { 00224 needsPoll = true; 00225 } 00226 00227 ZmqBinaryTypedEventSource(zmq::socket_t & socket, 00228 MessageHandler messageHandler = MessageHandler()) 00229 : messageHandler(std::move(messageHandler)), 00230 socket_(&socket) 00231 { 00232 needsPoll = true; 00233 } 00234 00235 void init(zmq::socket_t & socket) 00236 { 00237 socket_ = &socket; 00238 needsPoll = true; 00239 } 00240 00241 virtual int selectFd() const 00242 { 00243 int res = -1; 00244 size_t resSize = sizeof(int); 00245 socket().getsockopt(ZMQ_FD, &res, &resSize); 00246 if (res == -1) 00247 throw ML::Exception("no fd for zeromq socket"); 00248 return res; 00249 } 00250 00251 virtual bool poll() const 00252 { 00253 return getEvents(socket()).first; 00254 } 00255 00256 #if 0 00257 template<typename Arg, int Index> 00258 const Arg & 00259 getArg(const std::vector<zmq::message_t> & messages, 00260 const ML::InPosition<Arg, Index> * arg) 00261 { 00262 auto & m = messages.at(Index); 00263 ExcAssertEqual(m.size(), sizeof(Arg)); 00264 return * reinterpret_cast<const Arg *>(msg.data()); 00265 } 00266 #endif 00267 00268 virtual bool processOne() 00269 { 00270 zmq::message_t message; 00271 ExcAssert(socket_); 00272 bool got = socket_->recv(&message, ZMQ_NOBLOCK); 00273 if (!got) return false; 00274 00275 handleMessage(* reinterpret_cast<const Arg *>(message.data())); 00276 00277 return poll(); 00278 } 00279 00280 virtual void handleMessage(const Arg & arg) 00281 { 00282 if (messageHandler) 00283 messageHandler(arg); 00284 else throw ML::Exception("handleMessage not done"); 00285 } 00286 00287 zmq::socket_t & socket() const 00288 { 00289 ExcAssert(socket_); 00290 return *socket_; 00291 } 00292 00293 zmq::socket_t * socket_; 00294 }; 00295 00296 00297 /*****************************************************************************/ 00298 /* ZMQ TYPED EVENT SOURCE */ 00299 /*****************************************************************************/ 00300 00311 template<typename T> 00312 struct ZmqTypedEventSource: public ZmqEventSource { 00313 typedef std::function<void (T &&, std::string address)> 00314 OnMessage; 00315 00316 OnMessage onMessage; 00317 bool routable; 00318 std::string messageTopic; 00319 00320 ZmqTypedEventSource() 00321 { 00322 } 00323 00324 ZmqTypedEventSource(zmq::socket_t & socket, 00325 bool routable, 00326 const std::string & messageTopic) 00327 { 00328 init(routable, messageTopic); 00329 } 00330 00331 void init(zmq::socket_t & socket, 00332 bool routable, 00333 const std::string messageTopic) 00334 { 00335 this->routable = routable; 00336 this->messageTopic = messageTopic; 00337 } 00338 00339 virtual void handleMessage(const std::vector<std::string> & message) 00340 { 00341 int expectedSize = routable + 2; 00342 if (message.size() != expectedSize) 00343 throw ML::Exception("unexpected message size in ZmqTypedMessageSink"); 00344 00345 int i = routable; 00346 if (message[i + 1] != messageTopic) 00347 throw ML::Exception("unexpected messake kind in ZmqTypedMessageSink"); 00348 00349 std::istringstream stream(message[i + 2]); 00350 ML::DB::Store_Reader store(stream); 00351 T result; 00352 store >> result; 00353 00354 handleTypedMessage(std::move(result), routable ? message[0] : ""); 00355 } 00356 00357 virtual void handleTypedMessage(T && message, const std::string & address) 00358 { 00359 if (onMessage) 00360 onMessage(message, address); 00361 else 00362 throw ML::Exception("need to override handleTypedMessage or assign " 00363 "to onMessage"); 00364 } 00365 }; 00366 00367 00368 /*****************************************************************************/ 00369 /* ZMQ SOCKET MONITOR */ 00370 /*****************************************************************************/ 00371 00380 struct ZmqSocketMonitor : public ZmqBinaryTypedEventSource<zmq_event_t> { 00381 00382 ZmqSocketMonitor(zmq::context_t & context); 00383 00384 ~ZmqSocketMonitor() 00385 { 00386 shutdown(); 00387 } 00388 00389 void shutdown(); 00390 00397 void init(zmq::socket_t & socketToMonitor, int events = ZMQ_EVENT_ALL); 00398 00410 typedef std::function<void (std::string, int, zmq_event_t)> EventHandler; 00411 00412 // Success handlers 00413 EventHandler connectHandler, bindHandler, acceptHandler; 00414 00415 // Socket event handlers 00416 EventHandler closeHandler, disconnectHandler; 00417 00418 // Failure handlers 00419 EventHandler connectFailureHandler, acceptFailureHandler, 00420 bindFailureHandler, closeFailureHandler; 00421 00422 // Retry handlers 00423 EventHandler connectRetryHandler; 00424 00425 // Catch all handler, for when other handlers aren't registered 00426 EventHandler defaultHandler; 00427 00434 virtual int handleEvent(const zmq_event_t & event); 00435 00436 private: 00437 typedef std::mutex Lock; 00438 mutable Lock lock; 00439 00441 std::string connectedUri; 00442 00444 std::unique_ptr<zmq::socket_t> monitorEndpoint; 00445 00447 zmq::socket_t * monitoredSocket; 00448 }; 00449 00450 00451 /*****************************************************************************/ 00452 /* ZEROMQ NAMED ENDPOINT */ 00453 /*****************************************************************************/ 00454 00461 struct ZmqNamedEndpoint : public NamedEndpoint, public MessageLoop { 00462 00463 ZmqNamedEndpoint(std::shared_ptr<zmq::context_t> context); 00464 00465 ~ZmqNamedEndpoint() 00466 { 00467 shutdown(); 00468 } 00469 00470 void init(std::shared_ptr<ConfigurationService> config, 00471 int socketType, 00472 const std::string & endpointName); 00473 00474 void shutdown() 00475 { 00476 MessageLoop::shutdown(); 00477 00478 if (socket_) { 00479 unbindAll(); 00480 socket_.reset(); 00481 } 00482 00483 //ML::sleep(0.1); 00484 monitor.shutdown(); 00485 } 00486 00492 std::string bindTcp(PortRange const & portRange = PortRange(), std::string host = ""); 00493 00495 void bind(const std::string & address) 00496 { 00497 if (!socket_) 00498 throw ML::Exception("need to call ZmqNamedEndpoint::init() before " 00499 "bind"); 00500 00501 std::unique_lock<Lock> guard(lock); 00502 socket_->bind(address); 00503 boundAddresses[address]; 00504 } 00505 00507 void unbindAll() 00508 { 00509 std::unique_lock<Lock> guard(lock); 00510 ExcAssert(socket_); 00511 for (auto addr: boundAddresses) 00512 socket_->tryUnbind(addr.first); 00513 } 00514 00515 template<typename... Args> 00516 void sendMessage(Args&&... args) 00517 { 00518 using namespace std; 00519 std::unique_lock<Lock> guard(lock); 00520 ExcAssert(socket_); 00521 Datacratic::sendMessage(*socket_, std::forward<Args>(args)...); 00522 } 00523 00524 void sendMessage(const std::vector<std::string> & message) 00525 { 00526 using namespace std; 00527 std::unique_lock<Lock> guard(lock); 00528 ExcAssert(socket_); 00529 Datacratic::sendAll(*socket_, message); 00530 } 00531 00533 void sendMessage(std::vector<zmq::message_t> && message) 00534 { 00535 using namespace std; 00536 std::unique_lock<Lock> guard(lock); 00537 ExcAssert(socket_); 00538 for (unsigned i = 0; i < message.size(); ++i) { 00539 socket_->send(message[i], 00540 i == message.size() - 1 00541 ? 0 : ZMQ_SNDMORE); 00542 } 00543 } 00544 00546 zmq::socket_t & getSocketUnsafe() const 00547 { 00548 ExcAssert(socket_); 00549 return *socket_; 00550 } 00551 00552 typedef std::function<void (std::vector<zmq::message_t> &&)> 00553 RawMessageHandler; 00554 RawMessageHandler rawMessageHandler; 00555 00556 typedef std::function<void (std::vector<std::string> &&)> MessageHandler; 00557 MessageHandler messageHandler; 00558 00563 virtual void handleRawMessage(std::vector<zmq::message_t> && message) 00564 { 00565 if (rawMessageHandler) 00566 rawMessageHandler(std::move(message)); 00567 else { 00568 std::vector<std::string> msg2; 00569 for (unsigned i = 0; i < message.size(); ++i) { 00570 msg2.push_back(message[i].toString()); 00571 } 00572 handleMessage(std::move(msg2)); 00573 } 00574 } 00575 00576 virtual void handleMessage(std::vector<std::string> && message) 00577 { 00578 if (messageHandler) 00579 messageHandler(std::move(message)); 00580 else throw ML::Exception("need to override handleRawMessage or " 00581 "handleMessage"); 00582 } 00583 00584 typedef std::function<void (std::string bindAddress)> 00585 ConnectionEventHandler; 00586 00588 ConnectionEventHandler acceptEventHandler; 00589 00591 ConnectionEventHandler disconnectEventHandler; 00592 00594 ConnectionEventHandler closeEventHandler; 00595 00597 virtual void handleAcceptEvent(std::string boundAddress) 00598 { 00599 if (acceptEventHandler) 00600 acceptEventHandler(boundAddress); 00601 } 00602 00604 virtual void handleDisconnectEvent(std::string boundAddress) 00605 { 00606 if (disconnectEventHandler) 00607 disconnectEventHandler(boundAddress); 00608 } 00609 00611 virtual void handleCloseEvent(std::string boundAddress) 00612 { 00613 if (closeEventHandler) 00614 closeEventHandler(boundAddress); 00615 } 00616 00618 size_t numBoundAddresses() const 00619 { 00620 std::unique_lock<Lock> guard(lock); 00621 return boundAddresses.size(); 00622 } 00623 00627 size_t numActiveConnections(std::string addr = "") const 00628 { 00629 std::unique_lock<Lock> guard(lock); 00630 if (addr == "") { 00631 size_t result = 0; 00632 for (auto & addr: boundAddresses) 00633 result += addr.second.connectedFds.size(); 00634 return result; 00635 } 00636 else { 00637 auto it = boundAddresses.find(addr); 00638 if (it == boundAddresses.end()) 00639 return 0; 00640 return it->second.connectedFds.size(); 00641 } 00642 } 00643 00644 private: 00645 typedef std::mutex Lock; 00646 mutable Lock lock; 00647 00648 std::shared_ptr<zmq::context_t> context_; 00649 std::shared_ptr<zmq::socket_t> socket_; 00650 00653 ZmqSocketMonitor monitor; 00654 00655 struct AddressInfo { 00656 AddressInfo() 00657 : listeningFd(-1) 00658 { 00659 } 00660 00662 int listeningFd; 00663 00665 std::set<int> connectedFds; 00666 }; 00667 00669 std::map<std::string, AddressInfo> boundAddresses; 00670 00672 int socketType; 00673 }; 00674 00675 00676 /*****************************************************************************/ 00677 /* ZMQ NAMED CLIENT BUS */ 00678 /*****************************************************************************/ 00679 00683 struct ZmqNamedClientBus: public ZmqNamedEndpoint { 00684 00685 ZmqNamedClientBus(std::shared_ptr<zmq::context_t> context, 00686 double deadClientDelay = 5.0) 00687 : ZmqNamedEndpoint(context), deadClientDelay(deadClientDelay) 00688 { 00689 } 00690 00691 void init(std::shared_ptr<ConfigurationService> config, 00692 const std::string & endpointName) 00693 { 00694 ZmqNamedEndpoint::init(config, ZMQ_XREP, endpointName); 00695 addPeriodic("ZmqNamedClientBus::checkClient", 1.0, 00696 [=] (uint64_t v) { this->onCheckClient(v); }); 00697 } 00698 00699 virtual ~ZmqNamedClientBus() 00700 { 00701 shutdown(); 00702 } 00703 00704 void shutdown() 00705 { 00706 MessageLoop::shutdown(); 00707 ZmqNamedEndpoint::shutdown(); 00708 } 00709 00713 double deadClientDelay; 00714 00716 std::function<void (std::string)> onConnection; 00717 00721 std::function<void (std::string)> onDisconnection; 00722 00723 00724 00725 template<typename... Args> 00726 void sendMessage(const std::string & address, 00727 const std::string & topic, 00728 Args&&... args) 00729 { 00730 ZmqNamedEndpoint::sendMessage(address, topic, 00731 std::forward<Args>(args)...); 00732 } 00733 00734 virtual void handleMessage(std::vector<std::string> && message) 00735 { 00736 using namespace std; 00737 //cerr << "ZmqNamedClientBus got message " << message << endl; 00738 00739 const std::string & agent = message.at(0); 00740 const std::string & topic = message.at(1); 00741 00742 if (topic == "HEARTBEAT") { 00743 // Not the first message from the client 00744 auto it = clientInfo.find(agent); 00745 if (it == clientInfo.end()) { 00746 // Disconnection then reconnection 00747 if (onConnection) 00748 onConnection(agent); 00749 it = clientInfo.insert(make_pair(agent, ClientInfo())).first; 00750 } 00751 it->second.lastHeartbeat = Date::now(); 00752 sendMessage(agent, "HEARTBEAT"); 00753 } 00754 else if (topic == "HELLO") { 00755 // First message from client 00756 auto it = clientInfo.find(agent); 00757 if (it == clientInfo.end()) { 00758 // New connection 00759 if (onConnection) 00760 onConnection(agent); 00761 it = clientInfo.insert(make_pair(agent, ClientInfo())).first; 00762 } 00763 else { 00764 // Client must have disappeared then reappared without us 00765 // noticing. 00766 // Do this disconnection, then the reconnection 00767 if (onDisconnection) 00768 onDisconnection(agent); 00769 if (onConnection) 00770 onConnection(agent); 00771 } 00772 it->second.lastHeartbeat = Date::now(); 00773 sendMessage(agent, "HEARTBEAT"); 00774 } 00775 else { 00776 handleClientMessage(message); 00777 } 00778 00779 #if 0 00780 cerr << "poll() returned " << poll() << endl; 00781 00782 std::vector<std::string> msg 00783 = recvAllNonBlocking(socket()); 00784 00785 while (!msg.empty()) { 00786 cerr << "*** GOT FURTHER MESSAGE " << msg << endl; 00787 msg = recvAllNonBlocking(socket()); 00788 } 00789 00790 cerr << "poll() returned " << poll() << endl; 00791 #endif 00792 } 00793 00794 typedef std::function<void (std::vector<std::string>)> 00795 ClientMessageHandler; 00796 ClientMessageHandler clientMessageHandler; 00797 00798 virtual void handleClientMessage(const std::vector<std::string> & message) 00799 { 00800 if (clientMessageHandler) 00801 clientMessageHandler(message); 00802 else { 00803 throw ML::Exception("need to assign to onClientMessage " 00804 "or override handleClientMessage for message " 00805 + message.at(1)); 00806 } 00807 #if 0 00808 using namespace std; 00809 cerr << "ZmqNamedClientBus handleClientMessage " << message << endl; 00810 throw ML::Exception("handleClientMessage"); 00811 #endif 00812 } 00813 00814 private: 00815 void onCheckClient(uint64_t numEvents) 00816 { 00817 Date now = Date::now(); 00818 Date expiry = now.plusSeconds(-deadClientDelay); 00819 00820 std::vector<std::string> deadClients; 00821 00822 for (auto & c: clientInfo) 00823 if (c.second.lastHeartbeat < expiry) 00824 deadClients.push_back(c.first); 00825 00826 for (auto d: deadClients) { 00827 if (onDisconnection) 00828 onDisconnection(d); 00829 clientInfo.erase(d); 00830 } 00831 } 00832 00833 struct ClientInfo { 00834 ClientInfo() 00835 : lastHeartbeat(Date::now()) 00836 { 00837 } 00838 00839 Date lastHeartbeat; 00840 }; 00841 00842 std::map<std::string, ClientInfo> clientInfo; 00843 }; 00844 00845 00847 enum ConnectionStyle { 00848 CS_ASYNCHRONOUS, 00849 CS_SYNCHRONOUS, 00850 CS_MUST_SUCCEED 00851 }; 00852 00853 00854 00855 00856 /*****************************************************************************/ 00857 /* ZEROMQ NAMED PROXY */ 00858 /*****************************************************************************/ 00859 00862 // THIS SHOULD BE REPLACED BY ZmqNamedSocket 00863 00864 struct ZmqNamedProxy: public MessageLoop { 00865 00866 ZmqNamedProxy(); 00867 00868 ZmqNamedProxy(std::shared_ptr<zmq::context_t> context); 00869 00870 ~ZmqNamedProxy() 00871 { 00872 shutdown(); 00873 } 00874 00875 void shutdown() 00876 { 00877 MessageLoop::shutdown(); 00878 if(socket_) { 00879 std::lock_guard<ZmqEventSource::SocketLock> guard(socketLock_); 00880 socket_.reset(); 00881 } 00882 } 00883 00884 bool isConnected() const { return connectionState == CONNECTED; } 00885 00887 typedef std::function<void (std::string)> ConnectionHandler; 00888 00891 ConnectionHandler connectHandler; 00892 00896 virtual void onConnect(const std::string & source) 00897 { 00898 if (connectHandler) 00899 connectHandler(source); 00900 } 00901 00904 ConnectionHandler disconnectHandler; 00905 00909 virtual void onDisconnect(const std::string & source) 00910 { 00911 if (disconnectHandler) 00912 disconnectHandler(source); 00913 } 00914 00915 void init(std::shared_ptr<ConfigurationService> config, 00916 int socketType, 00917 const std::string & identity = ""); 00918 00928 bool connect(const std::string & endpointName, 00929 ConnectionStyle style = CS_ASYNCHRONOUS); 00930 00937 bool connectToServiceClass(const std::string & serviceClass, 00938 const std::string & endpointName, 00939 ConnectionStyle style = CS_ASYNCHRONOUS); 00940 00942 bool onConfigChange(ConfigurationService::ChangeType change, 00943 const std::string & key, 00944 const Json::Value & newValue); 00945 00947 zmq::socket_t & socket() const 00948 { 00949 ExcAssert(socket_); 00950 return *socket_; 00951 } 00952 00953 ZmqEventSource::SocketLock * socketLock() const 00954 { 00955 return &socketLock_; 00956 } 00957 00958 template<typename... Args> 00959 void sendMessage(Args&&... args) 00960 { 00961 std::lock_guard<ZmqEventSource::SocketLock> guard(socketLock_); 00962 00963 ExcCheckNotEqual(connectionState, NOT_CONNECTED, 00964 "sending on an unconnected socket: " + endpointName); 00965 00966 if (connectionState == CONNECTION_PENDING) { 00967 std::cerr << ("dropping message for " + endpointName + "\n"); 00968 return; 00969 } 00970 00971 Datacratic::sendMessage(socket(), std::forward<Args>(args)...); 00972 } 00973 00974 void disconnect() 00975 { 00976 if (connectionState == NOT_CONNECTED) return; 00977 00978 { 00979 std::lock_guard<ZmqEventSource::SocketLock> guard(socketLock_); 00980 00981 if (connectionState == CONNECTED) 00982 socket_->disconnect(connectedUri); 00983 00984 connectionState = NOT_CONNECTED; 00985 } 00986 00987 onDisconnect(connectedUri); 00988 } 00989 00990 00991 protected: 00992 ConfigurationService::Watch serviceWatch, endpointWatch; 00993 std::shared_ptr<ConfigurationService> config; 00994 std::shared_ptr<zmq::context_t> context_; 00995 std::shared_ptr<zmq::socket_t> socket_; 00996 00997 mutable ZmqEventSource::SocketLock socketLock_; 00998 00999 enum ConnectionType { 01000 NO_CONNECTION, 01001 CONNECT_DIRECT, 01002 CONNECT_TO_CLASS, 01003 } connectionType; 01004 01005 enum ConnectionState { 01006 NOT_CONNECTED, // connect() was not called 01007 CONNECTION_PENDING, // connect() was called but service is not available 01008 CONNECTED // connect() was called and the socket was connected 01009 } connectionState; 01010 01011 void onServiceNodeChange(const std::string & path, 01012 ConfigurationService::ChangeType change); 01013 void onEndpointNodeChange(const std::string & path, 01014 ConfigurationService::ChangeType change); 01015 01016 std::string serviceClass; 01017 std::string endpointName; 01018 std::string connectedService; 01019 std::string connectedUri; 01020 }; 01021 01022 01023 /*****************************************************************************/ 01024 /* ZEROMQ NAMED CLIENT BUS PROXY */ 01025 /*****************************************************************************/ 01026 01033 struct ZmqNamedClientBusProxy : public ZmqNamedProxy { 01034 01035 ZmqNamedClientBusProxy() 01036 : timeout(2.0) 01037 { 01038 } 01039 01040 ZmqNamedClientBusProxy(std::shared_ptr<zmq::context_t> context) 01041 : ZmqNamedProxy(context), timeout(2.0) 01042 { 01043 } 01044 01045 ~ZmqNamedClientBusProxy() 01046 { 01047 shutdown(); 01048 } 01049 01050 void init(std::shared_ptr<ConfigurationService> config, 01051 const std::string & identity = "") 01052 { 01053 ZmqNamedProxy::init(config, ZMQ_XREQ, identity); 01054 01055 auto doMessage = [=] (const std::vector<std::string> & message) 01056 { 01057 const std::string & topic = message.at(0); 01058 if (topic == "HEARTBEAT") 01059 this->lastHeartbeat = Date::now(); 01060 else handleMessage(message); 01061 }; 01062 01063 addSource("ZmqNamedClientBusProxy::doMessage", 01064 std::make_shared<ZmqEventSource>(socket(), doMessage, socketLock())); 01065 01066 auto doHeartbeat = [=] (int64_t skipped) 01067 { 01068 if (connectionState != CONNECTED) return; 01069 01070 sendMessage("HEARTBEAT"); 01071 01072 auto now = Date::now(); 01073 auto end = now.plusSeconds(-timeout); 01074 //if(lastHeartbeat < end) { 01075 //std::cerr << "no heartbeat for " << timeout << "s... should be disconnecting from " << connectedUri << std::endl; 01076 //disconnect(); 01077 //} 01078 }; 01079 01080 addPeriodic("ZmqNamedClientBusProxy::doHeartbeat", 1.0, doHeartbeat); 01081 } 01082 01083 void shutdown() 01084 { 01085 MessageLoop::shutdown(); 01086 ZmqNamedProxy::shutdown(); 01087 } 01088 01089 virtual void onConnect(const std::string & where) 01090 { 01091 lastHeartbeat = Date::now(); 01092 01093 sendMessage("HELLO"); 01094 01095 if (connectHandler) 01096 connectHandler(where); 01097 } 01098 01099 virtual void onDisconnect(const std::string & where) 01100 { 01101 if (disconnectHandler) 01102 disconnectHandler(where); 01103 } 01104 01105 ZmqEventSource::AsyncMessageHandler messageHandler; 01106 01107 virtual void handleMessage(const std::vector<std::string> & message) 01108 { 01109 if (messageHandler) 01110 messageHandler(message); 01111 else 01112 throw ML::Exception("need to override on messageHandler or handleMessage"); 01113 } 01114 01115 Date lastHeartbeat; 01116 double timeout; 01117 }; 01118 01119 01120 /*****************************************************************************/ 01121 /* ZEROMQ MULTIPLE NAMED CLIENT BUS PROXY */ 01122 /*****************************************************************************/ 01123 01130 struct ZmqMultipleNamedClientBusProxy: public MessageLoop { 01131 01132 ZmqMultipleNamedClientBusProxy() 01133 : zmqContext(new zmq::context_t(1)) 01134 { 01135 connected = false; 01136 } 01137 01138 ZmqMultipleNamedClientBusProxy(std::shared_ptr<zmq::context_t> context) 01139 : zmqContext(context) 01140 { 01141 connected = false; 01142 } 01143 01144 ~ZmqMultipleNamedClientBusProxy() 01145 { 01146 shutdown(); 01147 } 01148 01149 void init(std::shared_ptr<ConfigurationService> config, 01150 const std::string & identity = "") 01151 { 01152 this->config = config; 01153 this->identity = identity; 01154 } 01155 01156 void shutdown() 01157 { 01158 MessageLoop::shutdown(); 01159 for (auto & c: connections) 01160 if (c.second) 01161 c.second->shutdown(); 01162 } 01163 01164 template<typename... Args> 01165 void sendMessage(const std::string & recipient, 01166 const std::string & topic, 01167 Args&&... args) const 01168 { 01169 std::unique_lock<Lock> guard(connectionsLock); 01170 auto it = connections.find(recipient); 01171 if (it == connections.end()) { 01172 throw ML::Exception("attempt to deliver " + topic + " message to unknown client " 01173 + recipient); 01174 } 01175 it->second->sendMessage(topic, std::forward<Args>(args)...); 01176 } 01177 01181 void connectAllServiceProviders(const std::string & serviceClass, 01182 const std::string & endpointName) 01183 { 01184 if (connected) 01185 throw ML::Exception("alread connected to service providers"); 01186 01187 this->serviceClass = serviceClass; 01188 this->endpointName = endpointName; 01189 01190 serviceProvidersWatch.init([=] (const std::string & path, 01191 ConfigurationService::ChangeType change) 01192 { 01193 onServiceProvidersChanged("serviceClass/" + serviceClass); 01194 }); 01195 01196 onServiceProvidersChanged("serviceClass/" + serviceClass); 01197 // std::cerr << "++++after call to onServiceProvidersChanged " << std::endl; 01198 connected = true; 01199 } 01200 01201 01203 void connectSingleServiceProvider(const std::string & service); 01204 01206 void connectUri(const std::string & zmqUri); 01207 01208 size_t connectionCount() const 01209 { 01210 std::unique_lock<Lock> guard(connectionsLock); 01211 return connections.size(); 01212 } 01213 01215 typedef std::function<void (std::string)> ConnectionHandler; 01216 01219 ConnectionHandler connectHandler; 01220 01224 virtual void onConnect(const std::string & source) 01225 { 01226 if (connectHandler) 01227 connectHandler(source); 01228 } 01229 01232 ConnectionHandler disconnectHandler; 01233 01237 virtual void onDisconnect(const std::string & source) 01238 { 01239 if (disconnectHandler) 01240 disconnectHandler(source); 01241 } 01242 01244 typedef std::function<void (std::string, std::vector<std::string>)> 01245 MessageHandler; 01246 01250 MessageHandler messageHandler; 01251 01256 virtual void handleMessage(const std::string & source, 01257 const std::vector<std::string> & message) 01258 { 01259 if (messageHandler) 01260 messageHandler(source, message); 01261 else 01262 throw ML::Exception("need to override on messageHandler or handleMessage"); 01263 } 01264 01265 private: 01267 bool connected; 01268 01270 std::shared_ptr<zmq::context_t> zmqContext; 01271 01273 std::shared_ptr<ConfigurationService> config; 01274 01276 std::string serviceClass; 01277 01279 std::string endpointName; 01280 01282 std::string identity; 01283 01284 typedef ML::Spinlock Lock; 01285 01287 mutable Lock connectionsLock; 01288 01290 std::map<std::string, std::unique_ptr<ZmqNamedClientBusProxy> > connections; 01291 01293 ConfigurationService::Watch serviceProvidersWatch; 01294 01298 void onServiceProvidersChanged(const std::string & path) 01299 { 01300 using namespace std; 01301 //cerr << "onServiceProvidersChanged(" << path << ")" << endl; 01302 01303 // The list of service providers has changed 01304 01305 vector<string> children 01306 = config->getChildren(path, serviceProvidersWatch); 01307 for (auto c: children) { 01308 Json::Value value = config->getJson(path + "/" + c); 01309 std::string name = value["serviceName"].asString(); 01310 std::string path = value["servicePath"].asString(); 01311 01312 watchServiceProvider(name, path); 01313 } 01314 } 01315 01331 struct OnConnectCallback 01332 { 01333 OnConnectCallback(const ConnectionHandler& fn, std::string name) : 01334 fn(fn), name(name), state(DEFER) 01335 {} 01336 01338 void release() 01339 { 01340 State old = state; 01341 01342 ExcAssertNotEqual(old, CALL); 01343 01344 // If the callback wasn't triggered while we were holding the lock 01345 // then trigger it the next time we see it. 01346 if (old == DEFER && ML::cmp_xchg(state, old, CALL)) return; 01347 01348 ExcAssertEqual(old, DEFERRED); 01349 fn(name); 01350 } 01351 01352 void operator() (std::string) 01353 { 01354 State old = state; 01355 ExcAssertNotEqual(old, DEFERRED); 01356 01357 // If we're still in the locked section then trigger the callback 01358 // when release is called. 01359 if (old == DEFER && ML::cmp_xchg(state, old, DEFERRED)) return; 01360 01361 // We're out of the locked section so just trigger the callback. 01362 ExcAssertEqual(old, CALL); 01363 fn(name); 01364 } 01365 01366 private: 01367 01368 ConnectionHandler fn; 01369 std::string name; 01370 01371 enum State { 01372 DEFER, // We're holding the lock so defer an incoming callback. 01373 DEFERRED, // We were called while holding the lock. 01374 CALL // We were not called while holding the lock. 01375 } state; 01376 }; 01377 01379 void watchServiceProvider(const std::string & name, const std::string & path) 01380 { 01381 // Protects the connections map... I think. 01382 std::unique_lock<Lock> guard(connectionsLock); 01383 01384 auto & c = connections[name]; 01385 01386 // already connected 01387 if (c) return; 01388 01389 try { 01390 std::unique_ptr<ZmqNamedClientBusProxy> newClient 01391 (new ZmqNamedClientBusProxy(zmqContext)); 01392 newClient->init(config, identity); 01393 01394 // The connect call below could trigger this callback while we're 01395 // holding the connectionsLock which is a big no-no. This fancy 01396 // wrapper ensures that it's only called after we call its release 01397 // function. 01398 newClient->connectHandler = OnConnectCallback(connectHandler, name); 01399 01400 newClient->disconnectHandler = [=] (std::string s) 01401 { 01402 // TODO: chain in so that we know it's not around any more 01403 this->onDisconnect(s); 01404 }; 01405 01406 newClient->connect(path + "/" + endpointName); 01407 newClient->messageHandler = [=] (const std::vector<std::string> & msg) 01408 { 01409 this->handleMessage(name, msg); 01410 }; 01411 //newClient->debug(true); 01412 01413 c.reset(newClient.release()); 01414 01415 // Add it to our message loop so that it can process messages 01416 addSource("ZmqMultipleNamedClientBusProxy child " + name, *c); 01417 01418 guard.unlock(); 01419 c->connectHandler.target<OnConnectCallback>()->release(); 01420 01421 } catch (...) { 01422 connections.erase(name); 01423 throw; 01424 } 01425 } 01426 01427 }; 01428 01429 01430 01431 } // namespace Datacratic 01432 01433 #endif /* __service__zmq_endpoint_h__ */