diff --git a/CMakeLists.txt b/CMakeLists.txt index b10991f5f5..77703a4661 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -146,7 +146,7 @@ if(WITH_MESALINK) endif() set(CMAKE_CPP_FLAGS "${CMAKE_CPP_FLAGS} -DBTHREAD_USE_FAST_PTHREAD_MUTEX -D__const__=__unused__ -D_GNU_SOURCE -DUSE_SYMBOLIZE -DNO_TCMALLOC -D__STDC_FORMAT_MACROS -D__STDC_LIMIT_MACROS -D__STDC_CONSTANT_MACROS -DBRPC_REVISION=\\\"${BRPC_REVISION}\\\" -D__STRICT_ANSI__") set(CMAKE_CPP_FLAGS "${CMAKE_CPP_FLAGS} ${DEBUG_SYMBOL} ${THRIFT_CPP_FLAG}") -set(CMAKE_CXX_FLAGS "${CMAKE_CPP_FLAGS} -O2 -pipe -Wall -W -fPIC -fstrict-aliasing -Wno-invalid-offsetof -Wno-unused-parameter -fno-omit-frame-pointer") +set(CMAKE_CXX_FLAGS "${CMAKE_CPP_FLAGS} ${CMAKE_CXX_FLAGS} -O2 -pipe -Wall -W -fPIC -fstrict-aliasing -Wno-invalid-offsetof -Wno-unused-parameter -fno-omit-frame-pointer") set(CMAKE_C_FLAGS "${CMAKE_CPP_FLAGS} -O2 -pipe -Wall -W -fPIC -fstrict-aliasing -Wno-unused-parameter -fno-omit-frame-pointer") macro(use_cxx11) diff --git a/example/rdma_performance/client.cpp b/example/rdma_performance/client.cpp index 57d0c06c93..a7ed2c99c6 100644 --- a/example/rdma_performance/client.cpp +++ b/example/rdma_performance/client.cpp @@ -102,7 +102,7 @@ class PerformanceTest { int Init() { brpc::ChannelOptions options; - options.use_rdma = FLAGS_use_rdma; + options.socket_mode = FLAGS_use_rdma? brpc::SOCKET_MODE_RDMA : brpc::SOCKET_MODE_TCP; options.protocol = FLAGS_protocol; options.connection_type = FLAGS_connection_type; options.timeout_ms = FLAGS_rpc_timeout_ms; diff --git a/example/rdma_performance/server.cpp b/example/rdma_performance/server.cpp index d3d00057f4..2e93e1eec7 100644 --- a/example/rdma_performance/server.cpp +++ b/example/rdma_performance/server.cpp @@ -76,7 +76,7 @@ int main(int argc, char* argv[]) { g_last_time.store(0, butil::memory_order_relaxed); brpc::ServerOptions options; - options.use_rdma = FLAGS_use_rdma; + options.socket_mode = FLAGS_use_rdma? brpc::SOCKET_MODE_RDMA : brpc::SOCKET_MODE_TCP; if (server.Start(FLAGS_port, &options) != 0) { LOG(ERROR) << "Fail to start EchoServer"; return -1; diff --git a/src/brpc/acceptor.cpp b/src/brpc/acceptor.cpp index fd6564c987..f9c22a6848 100644 --- a/src/brpc/acceptor.cpp +++ b/src/brpc/acceptor.cpp @@ -21,8 +21,8 @@ #include "butil/fd_guard.h" // fd_guard #include "butil/fd_utility.h" // make_close_on_exec #include "butil/time.h" // gettimeofday_us -#include "brpc/rdma/rdma_endpoint.h" #include "brpc/acceptor.h" +#include "brpc/transport_factory.h" namespace brpc { @@ -40,7 +40,7 @@ Acceptor::Acceptor(bthread_keytable_pool_t* pool) , _empty_cond(&_map_mutex) , _force_ssl(false) , _ssl_ctx(NULL) - , _use_rdma(false) + , _socket_mode(SOCKET_MODE_TCP) , _bthread_tag(BTHREAD_TAG_DEFAULT) { } @@ -282,18 +282,10 @@ void Acceptor::OnNewConnectionsUntilEAGAIN(Socket* acception) { options.fd = in_fd; butil::sockaddr2endpoint(&in_addr, in_len, &options.remote_side); options.user = acception->user(); + options.need_on_edge_trigger = true; options.force_ssl = am->_force_ssl; options.initial_ssl_ctx = am->_ssl_ctx; -#if BRPC_WITH_RDMA - if (am->_use_rdma) { - options.on_edge_triggered_events = rdma::RdmaEndpoint::OnNewDataFromTcp; - } else { -#else - { -#endif - options.on_edge_triggered_events = InputMessenger::OnNewMessages; - } - options.use_rdma = am->_use_rdma; + options.socket_mode = am->_socket_mode; options.bthread_tag = am->_bthread_tag; if (Socket::Create(options, &socket_id) != 0) { LOG(ERROR) << "Fail to create Socket"; diff --git a/src/brpc/acceptor.h b/src/brpc/acceptor.h index 69f632aaca..77942beca2 100644 --- a/src/brpc/acceptor.h +++ b/src/brpc/acceptor.h @@ -22,6 +22,7 @@ #include "butil/synchronization/condition_variable.h" #include "butil/containers/flat_map.h" #include "brpc/input_messenger.h" +#include "brpc/socket_mode.h" namespace brpc { @@ -110,8 +111,8 @@ friend class Server; bool _force_ssl; std::shared_ptr _ssl_ctx; - // Whether to use rdma or not - bool _use_rdma; + // Choose to use a certain socket: 0 TCP, 1 RDMA + SocketMode _socket_mode; // Acceptor belongs to this tag bthread_tag_t _bthread_tag; diff --git a/src/brpc/channel.cpp b/src/brpc/channel.cpp index 0fd43d7c9c..86124c2552 100644 --- a/src/brpc/channel.cpp +++ b/src/brpc/channel.cpp @@ -37,6 +37,7 @@ #include "brpc/details/usercode_backup_pool.h" // TooManyUserCode #include "brpc/rdma/rdma_helper.h" #include "brpc/policy/esp_authenticator.h" +#include "brpc/transport_factory.h" namespace brpc { @@ -60,7 +61,7 @@ ChannelOptions::ChannelOptions() , connection_type(CONNECTION_TYPE_UNKNOWN) , succeed_without_server(true) , log_succeed_without_server(true) - , use_rdma(false) + , socket_mode(SOCKET_MODE_TCP) , auth(NULL) , backup_request_policy(NULL) , retry_policy(NULL) @@ -77,6 +78,8 @@ ChannelSSLOptions* ChannelOptions::mutable_ssl_options() { static ChannelSignature ComputeChannelSignature(const ChannelOptions& opt) { if (opt.auth == NULL && !opt.has_ssl_options() && + opt.client_host.empty() && + opt.device_name.empty() && opt.connection_group.empty() && opt.hc_option.health_check_path.empty()) { // Returning zeroized result by default is more intuitive for users. @@ -94,6 +97,14 @@ static ChannelSignature ComputeChannelSignature(const ChannelOptions& opt) { buf.append("|conng="); buf.append(opt.connection_group); } + if (!opt.client_host.empty()) { + buf.append("|clih="); + buf.append(opt.client_host); + } + if (!opt.device_name.empty()) { + buf.append("|devn="); + buf.append(opt.device_name); + } if (opt.auth) { buf.append("|auth="); buf.append((char*)&opt.auth, sizeof(opt.auth)); @@ -120,7 +131,7 @@ static ChannelSignature ComputeChannelSignature(const ChannelOptions& opt) { } else { // All disabled ChannelSSLOptions are the same } - if (opt.use_rdma) { + if (opt.socket_mode == SOCKET_MODE_RDMA) { buf.append("|rdma"); } butil::MurmurHash3_x64_128_Update(&mm_ctx, buf.data(), buf.size()); @@ -163,20 +174,6 @@ Channel::~Channel() { } } -#if BRPC_WITH_RDMA -static bool OptionsAvailableForRdma(const ChannelOptions* opt) { - if (opt->has_ssl_options()) { - LOG(WARNING) << "Cannot use SSL and RDMA at the same time"; - return false; - } - if (!rdma::SupportedByRdma(opt->protocol.name())) { - LOG(WARNING) << "Cannot use " << opt->protocol.name() - << " over RDMA"; - return false; - } - return true; -} -#endif int Channel::InitChannelOptions(const ChannelOptions* options) { if (options) { // Override default options if user provided one. @@ -191,19 +188,10 @@ int Channel::InitChannelOptions(const ChannelOptions* options) { _options.hc_option.health_check_path = FLAGS_health_check_path; _options.hc_option.health_check_timeout_ms = FLAGS_health_check_timeout_ms; } - if (_options.use_rdma) { -#if BRPC_WITH_RDMA - if (!OptionsAvailableForRdma(&_options)) { - return -1; - } - rdma::GlobalRdmaInitializeOrDie(); - if (!rdma::InitPollingModeWithTag(bthread_self_tag())) { - return -1; - } -#else - LOG(WARNING) << "Cannot use rdma since brpc does not compile with rdma"; + auto ret = TransportFactory::ContextInitOrDie(_options.socket_mode, false, &_options); + if (ret != 0) { + LOG(ERROR) << "Fail to initialize transport context for channel, ret=" << ret; return -1; -#endif } _serialize_request = protocol->serialize_request; @@ -362,14 +350,27 @@ int Channel::InitSingle(const butil::EndPoint& server_addr_and_port, LOG(ERROR) << "Invalid port=" << port; return -1; } + butil::EndPoint client_endpoint; + if (!_options.client_host.empty() && + butil::str2ip(_options.client_host.c_str(), &client_endpoint.ip) != 0 && + butil::hostname2ip(_options.client_host.c_str(), &client_endpoint.ip) != 0) { + LOG(ERROR) << "Invalid client host=`" << _options.client_host << '\''; + return -1; + } _server_address = server_addr_and_port; const ChannelSignature sig = ComputeChannelSignature(_options); std::shared_ptr ssl_ctx; if (CreateSocketSSLContext(_options, &ssl_ctx) != 0) { return -1; } + SocketOptions opt; + opt.local_side = client_endpoint; + opt.initial_ssl_ctx = ssl_ctx; + opt.socket_mode = _options.socket_mode; + opt.hc_option = _options.hc_option; + opt.device_name = _options.device_name; if (SocketMapInsert(SocketMapKey(server_addr_and_port, sig), - &_server_id, ssl_ctx, _options.use_rdma, _options.hc_option) != 0) { + &_server_id, opt) != 0) { LOG(ERROR) << "Fail to insert into SocketMap"; return -1; } @@ -397,6 +398,13 @@ int Channel::Init(const char* ns_url, _options.mutable_ssl_options()->sni_name = _service_name; } } + butil::EndPoint client_endpoint; + if (!_options.client_host.empty() && + butil::str2ip(_options.client_host.c_str(), &client_endpoint.ip) != 0 && + butil::hostname2ip(_options.client_host.c_str(), &client_endpoint.ip) != 0) { + LOG(ERROR) << "Invalid client host=`" << _options.client_host << '\''; + return -1; + } std::unique_ptr lb(new (std::nothrow) LoadBalancerWithNaming); if (NULL == lb) { @@ -406,10 +414,13 @@ int Channel::Init(const char* ns_url, GetNamingServiceThreadOptions ns_opt; ns_opt.succeed_without_server = _options.succeed_without_server; ns_opt.log_succeed_without_server = _options.log_succeed_without_server; - ns_opt.use_rdma = _options.use_rdma; + ns_opt.socket_option.socket_mode = _options.socket_mode; ns_opt.channel_signature = ComputeChannelSignature(_options); - ns_opt.hc_option = _options.hc_option; - if (CreateSocketSSLContext(_options, &ns_opt.ssl_ctx) != 0) { + ns_opt.socket_option.hc_option = _options.hc_option; + ns_opt.socket_option.local_side = client_endpoint; + ns_opt.socket_option.device_name = _options.device_name; + if (CreateSocketSSLContext(_options, + &ns_opt.socket_option.initial_ssl_ctx) != 0) { return -1; } if (lb->Init(ns_url, lb_name, _options.ns_filter, &ns_opt) != 0) { diff --git a/src/brpc/channel.h b/src/brpc/channel.h index c970209b3a..7c257c05d3 100644 --- a/src/brpc/channel.h +++ b/src/brpc/channel.h @@ -37,6 +37,7 @@ #include "brpc/backup_request_policy.h" #include "brpc/naming_service_filter.h" #include "brpc/health_check_option.h" +#include "brpc/socket_mode.h" namespace brpc { @@ -105,9 +106,9 @@ struct ChannelOptions { const ChannelSSLOptions& ssl_options() const { return *_ssl_options; } ChannelSSLOptions* mutable_ssl_options(); - // Let this channel use rdma rather than tcp. - // Default: false - bool use_rdma; + // Let this channel Choose to use a certain socket: 0 SOCKET_MODE_TCP, 1 SOCKET_MODE_RDMA. + // Default: SOCKET_MODE_TCP + SocketMode socket_mode; // Turn on authentication for this channel if `auth' is not NULL. // Note `auth' will not be deleted by channel and must remain valid when @@ -148,6 +149,16 @@ struct ChannelOptions { // Its priority is higher than FLAGS_health_check_path and FLAGS_health_check_timeout_ms. // When it is not set, FLAGS_health_check_path and FLAGS_health_check_timeout_ms will take effect. HealthCheckOption hc_option; + + // IP address or host name of the client. + // if the client_host is "", the client IP address is determined by the OS. + // Default: "" + std::string client_host; + + // The device name of the client's network adapter. + // if the device_name is "", the flow control is determined by the OS. + // Default: "" + std::string device_name; private: // SSLOptions is large and not often used, allocate it on heap to // prevent ChannelOptions from being bloated in most cases. diff --git a/src/brpc/details/naming_service_thread.cpp b/src/brpc/details/naming_service_thread.cpp index 341ca35b09..7eb005e8f0 100644 --- a/src/brpc/details/naming_service_thread.cpp +++ b/src/brpc/details/naming_service_thread.cpp @@ -125,8 +125,8 @@ void NamingServiceThread::Actions::ResetServers( // Socket. SocketMapKey may be passed through AddWatcher. Make sure // to pick those Sockets with the right settings during OnAddedServers const SocketMapKey key(_added[i], _owner->_options.channel_signature); - CHECK_EQ(0, SocketMapInsert(key, &tagged_id.id, _owner->_options.ssl_ctx, - _owner->_options.use_rdma, _owner->_options.hc_option)); + CHECK_EQ(0, SocketMapInsert(key, &tagged_id.id, + _owner->_options.socket_option)); _added_sockets.push_back(tagged_id); } diff --git a/src/brpc/details/naming_service_thread.h b/src/brpc/details/naming_service_thread.h index 1745e5f267..f01fbea6a4 100644 --- a/src/brpc/details/naming_service_thread.h +++ b/src/brpc/details/naming_service_thread.h @@ -27,6 +27,7 @@ #include "brpc/naming_service.h" // NamingService #include "brpc/naming_service_filter.h" // NamingServiceFilter #include "brpc/socket_map.h" +#include "brpc/socket_mode.h" namespace brpc { @@ -44,15 +45,14 @@ class NamingServiceWatcher { struct GetNamingServiceThreadOptions { GetNamingServiceThreadOptions() : succeed_without_server(false) - , log_succeed_without_server(true) - , use_rdma(false) {} + , log_succeed_without_server(true) { + socket_option.socket_mode = SOCKET_MODE_TCP; +} bool succeed_without_server; bool log_succeed_without_server; - bool use_rdma; - HealthCheckOption hc_option; ChannelSignature channel_signature; - std::shared_ptr ssl_ctx; + SocketOptions socket_option; }; // A dedicated thread to map a name to ServerIds diff --git a/src/brpc/input_message_base.h b/src/brpc/input_message_base.h index 86b25785cc..b117eb99c3 100644 --- a/src/brpc/input_message_base.h +++ b/src/brpc/input_message_base.h @@ -55,6 +55,7 @@ class InputMessageBase : public Destroyable { friend class InputMessenger; friend void* ProcessInputMessage(void*); friend class Stream; +friend class Transport; int64_t _received_us; int64_t _base_real_us; SocketUniquePtr _socket; diff --git a/src/brpc/input_messenger.cpp b/src/brpc/input_messenger.cpp index 1b8a86f2c6..c249cca22c 100644 --- a/src/brpc/input_messenger.cpp +++ b/src/brpc/input_messenger.cpp @@ -29,7 +29,7 @@ #include "brpc/protocol.h" // ListProtocols #include "brpc/rdma/rdma_endpoint.h" #include "brpc/input_messenger.h" - +#include "brpc/transport_factory.h" namespace brpc { @@ -112,8 +112,7 @@ ParseResult InputMessenger::CutInputMessage( // The length of `data' must be PROTO_DUMMY_LEN + 1 to store extra ending char '\0' char data[PROTO_DUMMY_LEN + 1]; m->_read_buf.copy_to_cstr(data, PROTO_DUMMY_LEN); - if (strncmp(data, "RDMA", PROTO_DUMMY_LEN) == 0 && - m->_rdma_state == Socket::RDMA_OFF) { + if (strncmp(data, "RDMA", PROTO_DUMMY_LEN) == 0) { // To avoid timeout when client uses RDMA but server uses TCP return MakeParseError(PARSE_ERROR_TRY_OTHERS); } @@ -191,46 +190,13 @@ struct RunLastMessage { } }; -static void QueueMessage(InputMessageBase* to_run_msg, - int* num_bthread_created, - bthread_keytable_pool_t* keytable_pool) { - if (!to_run_msg) { - return; - } - -#if BRPC_WITH_RDMA - if (rdma::FLAGS_rdma_disable_bthread) { - ProcessInputMessage(to_run_msg); - return; - } -#endif - // Create bthread for last_msg. The bthread is not scheduled - // until bthread_flush() is called (in the worse case). - - // TODO(gejun): Join threads. - bthread_t th; - bthread_attr_t tmp = (FLAGS_usercode_in_pthread ? - BTHREAD_ATTR_PTHREAD : - BTHREAD_ATTR_NORMAL) | BTHREAD_NOSIGNAL; - tmp.keytable_pool = keytable_pool; - tmp.tag = bthread_self_tag(); - bthread_attr_set_name(&tmp, "ProcessInputMessage"); - - if (!FLAGS_usercode_in_coroutine && bthread_start_background( - &th, &tmp, ProcessInputMessage, to_run_msg) == 0) { - ++*num_bthread_created; - } else { - ProcessInputMessage(to_run_msg); - } -} - -InputMessenger::InputMessageClosure::~InputMessageClosure() noexcept(false) { +InputMessageClosure::~InputMessageClosure() noexcept(false) { if (_msg) { ProcessInputMessage(_msg); } } -void InputMessenger::InputMessageClosure::reset(InputMessageBase* m) { +void InputMessageClosure::reset(InputMessageBase* m) { if (_msg) { ProcessInputMessage(_msg); } @@ -303,7 +269,7 @@ int InputMessenger::ProcessNewMessage( // This unique_ptr prevents msg to be lost before transfering // ownership to last_msg DestroyingPtr msg(pr.message()); - QueueMessage(last_msg.release(), &num_bthread_created, m->_keytable_pool); + m->_transport->QueueMessage(last_msg, &num_bthread_created, false); if (_handlers[index].process == NULL) { LOG(ERROR) << "process of index=" << index << " is NULL"; continue; @@ -336,22 +302,19 @@ int InputMessenger::ProcessNewMessage( // Transfer ownership to last_msg last_msg.reset(msg.release()); } else { - QueueMessage(msg.release(), &num_bthread_created, - m->_keytable_pool); + last_msg.reset(msg.release()); + m->_transport->QueueMessage(last_msg, &num_bthread_created, false); bthread_flush(); num_bthread_created = 0; } } -#if BRPC_WITH_RDMA // In RDMA polling mode, all messages must be executed in a new bthread and // not in the bthread where the polling bthread is located, because the // method for processing messages may call synchronization primitives, // causing the polling bthread to be scheduled out. - if (rdma::FLAGS_rdma_use_polling) { - QueueMessage(last_msg.release(), &num_bthread_created, - m->_keytable_pool); + if (m->_socket_mode == SOCKET_MODE_RDMA) { + m->_transport->QueueMessage(last_msg, &num_bthread_created, true); } -#endif if (num_bthread_created) { bthread_flush(); } @@ -414,8 +377,8 @@ void InputMessenger::OnNewMessages(Socket* m) { } } - if (m->_rdma_state == Socket::RDMA_OFF && messenger->ProcessNewMessage( - m, nr, read_eof, received_us, base_realtime, last_msg) < 0) { + if (messenger->ProcessNewMessage(m, nr, read_eof, received_us, + base_realtime, last_msg) < 0) { return; } } @@ -533,16 +496,7 @@ int InputMessenger::Create(const butil::EndPoint& remote_side, int InputMessenger::Create(SocketOptions options, SocketId* id) { options.user = this; -#if BRPC_WITH_RDMA - if (options.use_rdma) { - options.on_edge_triggered_events = rdma::RdmaEndpoint::OnNewDataFromTcp; - options.app_connect = std::make_shared(); - } else { -#else - { -#endif - options.on_edge_triggered_events = OnNewMessages; - } + options.need_on_edge_trigger = true; // Enable keepalive by options or Gflag. // Priority: options > Gflag. if (options.keepalive_options || FLAGS_socket_keepalive) { diff --git a/src/brpc/input_messenger.h b/src/brpc/input_messenger.h index 1c191a87c2..8482c3f3fc 100644 --- a/src/brpc/input_messenger.h +++ b/src/brpc/input_messenger.h @@ -29,7 +29,7 @@ namespace brpc { namespace rdma { class RdmaEndpoint; } - +class TcpTransport; struct InputMessageHandler { // The callback to cut a message from `source'. // Returned message will be passed to process_request or process_response @@ -70,9 +70,28 @@ struct InputMessageHandler { const char* name; }; +class InputMessageClosure { +public: + InputMessageClosure() : _msg(NULL) { } + ~InputMessageClosure() noexcept(false); + + InputMessageBase* release() { + InputMessageBase* m = _msg; + _msg = NULL; + return m; + } + + void reset(InputMessageBase* m); + +private: + InputMessageBase* _msg; +}; + // Process messages from connections. // `Message' corresponds to a client's request or a server's response. class InputMessenger : public SocketUser { +friend class Socket; +friend class TcpTransport; friend class rdma::RdmaEndpoint; public: explicit InputMessenger(size_t capacity = 128); @@ -111,22 +130,6 @@ friend class rdma::RdmaEndpoint; static void OnNewMessages(Socket* m); private: - class InputMessageClosure { - public: - InputMessageClosure() : _msg(NULL) { } - ~InputMessageClosure() noexcept(false); - - InputMessageBase* release() { - InputMessageBase* m = _msg; - _msg = NULL; - return m; - } - - void reset(InputMessageBase* m); - - private: - InputMessageBase* _msg; - }; // Find a valid scissor from `handlers' to cut off `header' and `payload' // from m->read_buf, save index of the scissor into `index'. diff --git a/src/brpc/rdma/rdma_endpoint.cpp b/src/brpc/rdma/rdma_endpoint.cpp index 616ef33252..3cc2107f23 100644 --- a/src/brpc/rdma/rdma_endpoint.cpp +++ b/src/brpc/rdma/rdma_endpoint.cpp @@ -30,6 +30,7 @@ #include "brpc/rdma/block_pool.h" #include "brpc/rdma/rdma_helper.h" #include "brpc/rdma/rdma_endpoint.h" +#include "brpc/rdma_transport.h" DECLARE_int32(task_group_ntags); @@ -239,14 +240,15 @@ void RdmaEndpoint::Reset() { void RdmaConnect::StartConnect(const Socket* socket, void (*done)(int err, void* data), void* data) { - CHECK(socket->_rdma_ep != NULL); + auto* rdma_transport = static_cast(socket->_transport.get()); + CHECK(rdma_transport->_rdma_ep != NULL); SocketUniquePtr s; if (Socket::Address(socket->id(), &s) != 0) { return; } if (!IsRdmaAvailable()) { - socket->_rdma_ep->_state = RdmaEndpoint::FALLBACK_TCP; - s->_rdma_state = Socket::RDMA_OFF; + rdma_transport->_rdma_ep->_state = RdmaEndpoint::FALLBACK_TCP; + rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; done(0, data); return; } @@ -256,7 +258,7 @@ void RdmaConnect::StartConnect(const Socket* socket, bthread_attr_t attr = BTHREAD_ATTR_NORMAL; bthread_attr_set_name(&attr, "RdmaProcessHandshakeAtClient"); if (bthread_start_background(&tid, &attr, - RdmaEndpoint::ProcessHandshakeAtClient, socket->_rdma_ep) < 0) { + RdmaEndpoint::ProcessHandshakeAtClient, rdma_transport->_rdma_ep) < 0) { LOG(FATAL) << "Fail to start handshake bthread"; Run(); } else { @@ -299,7 +301,8 @@ static void TryReadOnTcpDuringRdmaEst(Socket* s) { } void RdmaEndpoint::OnNewDataFromTcp(Socket* m) { - RdmaEndpoint* ep = m->_rdma_ep; + auto* rdma_transport = static_cast(m->_transport.get()); + RdmaEndpoint* ep = rdma_transport->GetRdmaEp(); CHECK(ep != NULL); int progress = Socket::PROGRESS_INIT; @@ -308,7 +311,7 @@ void RdmaEndpoint::OnNewDataFromTcp(Socket* m) { if (!m->CreatedByConnect()) { if (!IsRdmaAvailable()) { ep->_state = FALLBACK_TCP; - m->_rdma_state = Socket::RDMA_OFF; + rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; continue; } bthread_t tid; @@ -433,9 +436,10 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) { // First initialize CQ and QP resources ep->_state = C_ALLOC_QPCQ; + auto* rdma_transport = static_cast(s->_transport.get()); if (ep->AllocateResources() < 0) { LOG(WARNING) << "Fallback to tcp:" << s->description(); - s->_rdma_state = Socket::RDMA_OFF; + rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; ep->_state = FALLBACK_TCP; return NULL; } @@ -514,7 +518,7 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) { if (!HelloNegotiationValid(remote_msg)) { LOG(WARNING) << "Fail to negotiate with server, fallback to tcp:" << s->description(); - s->_rdma_state = Socket::RDMA_OFF; + rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; } else { ep->_remote_recv_block_size = remote_msg.block_size; ep->_local_window_capacity = @@ -530,16 +534,16 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) { ep->_state = C_BRINGUP_QP; if (ep->BringUpQp(remote_msg.lid, remote_msg.gid, remote_msg.qp_num) < 0) { LOG(WARNING) << "Fail to bringup QP, fallback to tcp:" << s->description(); - s->_rdma_state = Socket::RDMA_OFF; + rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; } else { - s->_rdma_state = Socket::RDMA_ON; + rdma_transport->_rdma_state = RdmaTransport::RDMA_ON; } } // Send ACK message to server ep->_state = C_ACK_SEND; uint32_t flags = 0; - if (s->_rdma_state != Socket::RDMA_OFF) { + if (rdma_transport->_rdma_state != RdmaTransport::RDMA_OFF) { flags |= ACK_MSG_RDMA_OK; } uint32_t* tmp = (uint32_t*)data; // avoid GCC warning on strict-aliasing @@ -553,7 +557,7 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) { return NULL; } - if (s->_rdma_state == Socket::RDMA_ON) { + if (rdma_transport->_rdma_state == RdmaTransport::RDMA_ON) { ep->_state = ESTABLISHED; LOG_IF(INFO, FLAGS_rdma_trace_verbose) << "Client handshake ends (use rdma) on " << s->description(); @@ -586,7 +590,7 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) { ep->_state = FAILED; return NULL; } - + auto* rdma_transport = static_cast(s->_transport.get()); if (memcmp(data, MAGIC_STR, MAGIC_STR_LEN) != 0) { LOG_IF(INFO, FLAGS_rdma_trace_verbose) << "It seems that the " << "client does not use RDMA, fallback to TCP:" @@ -594,7 +598,7 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) { // we need to copy data read back to _socket->_read_buf s->_read_buf.append(data, MAGIC_STR_LEN); ep->_state = FALLBACK_TCP; - s->_rdma_state = Socket::RDMA_OFF; + rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; ep->TryReadOnTcp(); return NULL; } @@ -626,7 +630,7 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) { if (!HelloNegotiationValid(remote_msg)) { LOG(WARNING) << "Fail to negotiate with client, fallback to tcp:" << s->description(); - s->_rdma_state = Socket::RDMA_OFF; + rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; } else { ep->_remote_recv_block_size = remote_msg.block_size; ep->_local_window_capacity = @@ -643,13 +647,13 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) { if (ep->AllocateResources() < 0) { LOG(WARNING) << "Fail to allocate rdma resources, fallback to tcp:" << s->description(); - s->_rdma_state = Socket::RDMA_OFF; + rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; } else { ep->_state = S_BRINGUP_QP; if (ep->BringUpQp(remote_msg.lid, remote_msg.gid, remote_msg.qp_num) < 0) { LOG(WARNING) << "Fail to bringup QP, fallback to tcp:" << s->description(); - s->_rdma_state = Socket::RDMA_OFF; + rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; } } } @@ -658,7 +662,7 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) { ep->_state = S_HELLO_SEND; HelloMessage local_msg; local_msg.msg_len = g_rdma_hello_msg_len; - if (s->_rdma_state == Socket::RDMA_OFF) { + if (rdma_transport->_rdma_state == RdmaTransport::RDMA_OFF) { local_msg.impl_ver = 0; local_msg.hello_ver = 0; } else { @@ -702,7 +706,7 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) { uint32_t* tmp = (uint32_t*)data; // avoid GCC warning on strict-aliasing uint32_t flags = butil::NetToHost32(*tmp); if (flags & ACK_MSG_RDMA_OK) { - if (s->_rdma_state == Socket::RDMA_OFF) { + if (rdma_transport->_rdma_state == RdmaTransport::RDMA_OFF) { LOG(WARNING) << "Fail to parse Hello Message length from client:" << s->description(); s->SetFailed(EPROTO, "Fail to complete rdma handshake from %s: %s", @@ -710,13 +714,13 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) { ep->_state = FAILED; return NULL; } else { - s->_rdma_state = Socket::RDMA_ON; + rdma_transport->_rdma_state = RdmaTransport::RDMA_ON; ep->_state = ESTABLISHED; LOG_IF(INFO, FLAGS_rdma_trace_verbose) << "Server handshake ends (use rdma) on " << s->description(); } } else { - s->_rdma_state = Socket::RDMA_OFF; + rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; ep->_state = FALLBACK_TCP; LOG_IF(INFO, FLAGS_rdma_trace_verbose) << "Server handshake ends (use tcp) on " << s->description(); @@ -1455,7 +1459,8 @@ void RdmaEndpoint::PollCq(Socket* m) { if (Socket::Address(ep->_socket->id(), &s) < 0) { return; } - CHECK(ep == s->_rdma_ep); + auto* rdma_transport = static_cast(s->_transport.get()); + CHECK(ep == rdma_transport->_rdma_ep); bool send = false; ibv_cq* cq = ep->_resource->recv_cq; @@ -1472,7 +1477,7 @@ void RdmaEndpoint::PollCq(Socket* m) { int progress = Socket::PROGRESS_INIT; bool notified = false; - InputMessenger::InputMessageClosure last_msg; + InputMessageClosure last_msg; ibv_wc wc[FLAGS_rdma_cqe_poll_once]; while (true) { int cnt = ibv_poll_cq(cq, FLAGS_rdma_cqe_poll_once, wc); diff --git a/src/brpc/rdma_transport.cpp b/src/brpc/rdma_transport.cpp new file mode 100644 index 0000000000..8fe88c6b4b --- /dev/null +++ b/src/brpc/rdma_transport.cpp @@ -0,0 +1,238 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#if BRPC_WITH_RDMA + +#include "brpc/rdma_transport.h" +#include "brpc/tcp_transport.h" +#include "brpc/rdma/rdma_endpoint.h" +#include "brpc/rdma/rdma_helper.h" + +namespace brpc { +DECLARE_bool(usercode_in_coroutine); +DECLARE_bool(usercode_in_pthread); + +extern SocketVarsCollector *g_vars; + +void RdmaTransport::Init(Socket *socket, const SocketOptions &options) { + CHECK(_rdma_ep == NULL); + if (options.socket_mode == SOCKET_MODE_RDMA) { + _rdma_ep = new(std::nothrow)rdma::RdmaEndpoint(socket); + if (!_rdma_ep) { + const int saved_errno = errno; + PLOG(ERROR) << "Fail to create RdmaEndpoint"; + socket->SetFailed( + saved_errno, "Fail to create RdmaEndpoint: %s", berror(saved_errno)); + } + _rdma_state = RDMA_UNKNOWN; + } else { + _rdma_state = RDMA_OFF; + socket->_socket_mode = SOCKET_MODE_TCP; + } + _socket = socket; + _default_connect = options.app_connect; + _on_edge_trigger = options.on_edge_triggered_events; + if (options.need_on_edge_trigger && _on_edge_trigger == NULL) { + _on_edge_trigger = rdma::RdmaEndpoint::OnNewDataFromTcp; + } + _tcp_transport = std::make_shared(); + _tcp_transport->Init(socket, options); +} + +void RdmaTransport::Release() { + if (_rdma_ep) { + delete _rdma_ep; + _rdma_ep = NULL; + _rdma_state = RDMA_UNKNOWN; + } +} + +int RdmaTransport::Reset(int32_t expected_nref) { + if (_rdma_ep) { + _rdma_ep->Reset(); + _rdma_state = RDMA_UNKNOWN; + } + return 0; +} + +std::shared_ptr RdmaTransport::Connect() { + if (_default_connect == nullptr) { + return std::make_shared(); + } + return _default_connect; +} + +int RdmaTransport::CutFromIOBuf(butil::IOBuf *buf) { + if (_rdma_ep && _rdma_state != RDMA_OFF) { + butil::IOBuf *data_arr[1] = {buf}; + return _rdma_ep->CutFromIOBufList(data_arr, 1); + } else { + return _tcp_transport->CutFromIOBuf(buf); + } +} + +ssize_t RdmaTransport::CutFromIOBufList(butil::IOBuf **buf, size_t ndata) { + if (_rdma_ep && _rdma_state != RDMA_OFF) { + return _rdma_ep->CutFromIOBufList(buf, ndata); + } + return _tcp_transport->CutFromIOBufList(buf, ndata); +} + +int RdmaTransport::WaitEpollOut(butil::atomic *_epollout_butex, + bool pollin, const timespec duetime) { + if (_rdma_state == RDMA_ON) { + const int expected_val = _epollout_butex->load(butil::memory_order_acquire); + CHECK(_rdma_ep != NULL); + if (!_rdma_ep->IsWritable()) { + g_vars->nwaitepollout << 1; + if (bthread::butex_wait(_epollout_butex, expected_val, &duetime) < 0) { + if (errno != EAGAIN && errno != ETIMEDOUT) { + const int saved_errno = errno; + PLOG(WARNING) << "Fail to wait rdma window of " << _socket; + _socket->SetFailed(saved_errno, + "Fail to wait rdma window of %s: %s", + _socket->description().c_str(), + berror(saved_errno)); + } + if (_socket->Failed()) { + // NOTE: + // Different from TCP, we cannot find the RDMA channel + // failed by writing to it. Thus we must check if it + // is already failed here. + return 1; + } + } + } + } else { + return _tcp_transport->WaitEpollOut(_epollout_butex, pollin, duetime); + } + return 0; +} + +void RdmaTransport::ProcessEvent(bthread_attr_t attr) { + bthread_t tid; + if (FLAGS_usercode_in_coroutine) { + OnEdge(_socket); + } else if (rdma::FLAGS_rdma_edisp_unsched == false) { + auto rc = bthread_start_background(&tid, &attr, OnEdge, _socket); + if (rc != 0) { + LOG(FATAL) << "Fail to start ProcessEvent"; + OnEdge(_socket); + } + } else if (bthread_start_urgent(&tid, &attr, OnEdge, _socket) != 0) { + LOG(FATAL) << "Fail to start ProcessEvent"; + OnEdge(_socket); + } +} + +void RdmaTransport::QueueMessage(InputMessageClosure& input_msg, + int* num_bthread_created, bool last_msg) { + if (last_msg && !rdma::FLAGS_rdma_use_polling) { + return; + } + InputMessageBase* to_run_msg = input_msg.release(); + if (!to_run_msg) { + return; + } + + if (rdma::FLAGS_rdma_disable_bthread) { + ProcessInputMessage(to_run_msg); + return; + } + // Create bthread for last_msg. The bthread is not scheduled + // until bthread_flush() is called (in the worse case). + + // TODO(gejun): Join threads. + bthread_t th; + bthread_attr_t tmp = (FLAGS_usercode_in_pthread ? + BTHREAD_ATTR_PTHREAD : + BTHREAD_ATTR_NORMAL) | BTHREAD_NOSIGNAL; + tmp.keytable_pool = _socket->keytable_pool(); + tmp.tag = bthread_self_tag(); + bthread_attr_set_name(&tmp, "ProcessInputMessage"); + + if (!FLAGS_usercode_in_coroutine && bthread_start_background( + &th, &tmp, ProcessInputMessage, to_run_msg) == 0) { + ++*num_bthread_created; + } else { + ProcessInputMessage(to_run_msg); + } +} + +void RdmaTransport::Debug(std::ostream &os) { + if (_rdma_state == RDMA_ON && _rdma_ep) { + _rdma_ep->DebugInfo(os); + } +} + +int RdmaTransport::ContextInitOrDie(bool serverOrNot, const void* _options) { + if (serverOrNot) { + if (!OptionsAvailableOverRdma(static_cast(_options))) { + return -1; + } + rdma::GlobalRdmaInitializeOrDie(); + if (!rdma::InitPollingModeWithTag(static_cast(_options)->bthread_tag)) { + return -1; + } + } else { + if (!OptionsAvailableForRdma(static_cast(_options))) { + return -1; + } + rdma::GlobalRdmaInitializeOrDie(); + if (!rdma::InitPollingModeWithTag(bthread_self_tag())) { + return -1; + } + return 0; + } + + return 0; +} + +bool RdmaTransport::OptionsAvailableForRdma(const ChannelOptions* opt) { + if (opt->has_ssl_options()) { + LOG(WARNING) << "Cannot use SSL and RDMA at the same time"; + return false; + } + if (!rdma::SupportedByRdma(opt->protocol.name())) { + LOG(WARNING) << "Cannot use " << opt->protocol.name() + << " over RDMA"; + return false; + } + return true; +} + +bool RdmaTransport::OptionsAvailableOverRdma(const ServerOptions* opt) { + if (opt->rtmp_service) { + LOG(WARNING) << "RTMP is not supported by RDMA"; + return false; + } + if (opt->has_ssl_options()) { + LOG(WARNING) << "SSL is not supported by RDMA"; + return false; + } + if (opt->nshead_service) { + LOG(WARNING) << "NSHEAD is not supported by RDMA"; + return false; + } + if (opt->mongo_service_adaptor) { + LOG(WARNING) << "MONGO is not supported by RDMA"; + return false; + } + return true; +} +} // namespace brpc +#endif \ No newline at end of file diff --git a/src/brpc/rdma_transport.h b/src/brpc/rdma_transport.h new file mode 100644 index 0000000000..65ae88f7a6 --- /dev/null +++ b/src/brpc/rdma_transport.h @@ -0,0 +1,65 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef BRPC_RDMA_TRANSPORT_H +#define BRPC_RDMA_TRANSPORT_H + +#if BRPC_WITH_RDMA +#include "brpc/socket.h" +#include "brpc/channel.h" +#include "brpc/transport.h" + +namespace brpc { +class RdmaTransport : public Transport { + friend class TransportFactory; + friend class rdma::RdmaEndpoint; + friend class rdma::RdmaConnect; +public: + void Init(Socket* socket, const SocketOptions& options) override; + void Release() override; + int Reset(int32_t expected_nref) override; + std::shared_ptr Connect() override; + int CutFromIOBuf(butil::IOBuf* buf) override; + ssize_t CutFromIOBufList(butil::IOBuf** buf, size_t ndata) override; + int WaitEpollOut(butil::atomic* _epollout_butex, bool pollin, const timespec duetime) override; + void ProcessEvent(bthread_attr_t attr) override; + void QueueMessage(InputMessageClosure& inputMsg, int* num_bthread_created, bool last_msg) override; + void Debug(std::ostream &os) override; + rdma::RdmaEndpoint* GetRdmaEp() { + CHECK(_rdma_ep != NULL); + return _rdma_ep; + } + static int ContextInitOrDie(bool serverOrNot, const void* _options); +private: + static bool OptionsAvailableForRdma(const ChannelOptions* opt); + static bool OptionsAvailableOverRdma(const ServerOptions* opt); +private: + // The on/off state of RDMA + enum RdmaState { + RDMA_ON, + RDMA_OFF, + RDMA_UNKNOWN + }; + // The RdmaEndpoint + rdma::RdmaEndpoint* _rdma_ep = NULL; + // Should use RDMA or not + RdmaState _rdma_state; + std::shared_ptr _tcp_transport; +}; +} // namespace brpc +#endif // BRPC_WITH_RDMA +#endif //BRPC_RDMA_TRANSPORT_H \ No newline at end of file diff --git a/src/brpc/selective_channel.cpp b/src/brpc/selective_channel.cpp index dd155a3044..ec93354121 100644 --- a/src/brpc/selective_channel.cpp +++ b/src/brpc/selective_channel.cpp @@ -41,10 +41,13 @@ typedef std::map ChannelToIdMap; class SubChannel : public SocketUser { public: ChannelBase* chan; + ChannelOwnership ownership; // internal channel is deleted after the fake Socket is SetFailed void BeforeRecycle(Socket*) { - delete chan; + if (ownership == OWNS_CHANNEL) { + delete chan; + } delete this; } @@ -83,7 +86,8 @@ class ChannelBalancer : public SharedLoadBalancer { ChannelBalancer() {} ~ChannelBalancer(); int Init(const char* lb_name); - int AddChannel(ChannelBase* sub_channel, const std::string& tag, + int AddChannel(ChannelBase* sub_channel, + const SelectiveChannel::SubChannelOptions& subopt, SelectiveChannel::ChannelHandle* handle); void RemoveAndDestroyChannel(const SelectiveChannel::ChannelHandle& handle); int SelectChannel(const LoadBalancer::SelectIn& in, SelectOut* out); @@ -168,7 +172,8 @@ int ChannelBalancer::Init(const char* lb_name) { return SharedLoadBalancer::Init(lb_name); } -int ChannelBalancer::AddChannel(ChannelBase* sub_channel, const std::string& tag, +int ChannelBalancer::AddChannel(ChannelBase* sub_channel, + const SelectiveChannel::SubChannelOptions& subopt, SelectiveChannel::ChannelHandle* handle) { if (NULL == sub_channel) { LOG(ERROR) << "Parameter[sub_channel] is NULL"; @@ -185,6 +190,7 @@ int ChannelBalancer::AddChannel(ChannelBase* sub_channel, const std::string& tag return -1; } sub_chan->chan = sub_channel; + sub_chan->ownership = subopt.ownership; SocketId sock_id; SocketOptions options; options.user = sub_chan; @@ -206,7 +212,7 @@ int ChannelBalancer::AddChannel(ChannelBase* sub_channel, const std::string& tag << sock_id << " is disabled"; return -1; } - if (!AddServer(ServerId(sock_id, tag))) { + if (!AddServer(ServerId(sock_id, subopt.tag))) { LOG(ERROR) << "Duplicated sub_channel=" << sub_channel; // sub_chan will be deleted when the socket is recycled. ptr->SetFailed(); @@ -215,10 +221,10 @@ int ChannelBalancer::AddChannel(ChannelBase* sub_channel, const std::string& tag return -1; } // The health-check-related reference has been held on created. - _chan_map[sub_channel]= ptr.get(); + _chan_map[sub_channel] = ptr.get(); if (handle) { handle->id = sock_id; - handle->tag = tag; + handle->tag = subopt.tag; } return 0; } @@ -532,12 +538,7 @@ bool SelectiveChannel::initialized() const { } int SelectiveChannel::AddChannel(ChannelBase* sub_channel, - ChannelHandle* handle) { - return AddChannel(sub_channel, "", handle); -} - -int SelectiveChannel::AddChannel(ChannelBase* sub_channel, - const std::string& tag, + const SubChannelOptions& option, ChannelHandle* handle) { schan::ChannelBalancer* lb = static_cast(_chan._lb.get()); @@ -545,7 +546,7 @@ int SelectiveChannel::AddChannel(ChannelBase* sub_channel, LOG(ERROR) << "You must call Init() to initialize a SelectiveChannel"; return -1; } - return lb->AddChannel(sub_channel, tag, handle); + return lb->AddChannel(sub_channel, option, handle); } void SelectiveChannel::RemoveAndDestroyChannel(const ChannelHandle& handle) { diff --git a/src/brpc/selective_channel.h b/src/brpc/selective_channel.h index 6c0af1da9c..fd8fb9cf9d 100644 --- a/src/brpc/selective_channel.h +++ b/src/brpc/selective_channel.h @@ -56,6 +56,11 @@ class SelectiveChannel : public ChannelBase/*non-copyable*/ { std::string tag; }; + struct SubChannelOptions { + std::string tag; + ChannelOwnership ownership = OWNS_CHANNEL; + }; + SelectiveChannel(); ~SelectiveChannel(); @@ -69,8 +74,16 @@ class SelectiveChannel : public ChannelBase/*non-copyable*/ { // On success, handle is set with the key for removal. // NOTE: Different from pchan, schan can add channels at any time. // Returns 0 on success, -1 otherwise. - int AddChannel(ChannelBase* sub_channel, ChannelHandle* handle); - int AddChannel(ChannelBase* sub_channel, const std::string& tag, ChannelHandle* handle); + int AddChannel(ChannelBase* sub_channel, ChannelHandle* handle) { + return AddChannel(sub_channel, SubChannelOptions(), handle); + } + int AddChannel(ChannelBase* sub_channel, const std::string& tag, ChannelHandle* handle) { + SubChannelOptions option; + option.tag = tag; + return AddChannel(sub_channel, option, handle); + } + int AddChannel(ChannelBase* sub_channel, const SubChannelOptions& option, + ChannelHandle* handle); // Remove and destroy the sub_channel associated with `handle'. void RemoveAndDestroyChannel(const ChannelHandle& handle); diff --git a/src/brpc/server.cpp b/src/brpc/server.cpp index 8e2368bcb2..9470220d09 100644 --- a/src/brpc/server.cpp +++ b/src/brpc/server.cpp @@ -81,6 +81,7 @@ #include "brpc/details/tcmalloc_extension.h" #include "brpc/rdma/rdma_helper.h" #include "brpc/baidu_master_service.h" +#include "brpc/transport_factory.h" inline std::ostream& operator<<(std::ostream& os, const timeval& tm) { const char old_fill = os.fill(); @@ -146,7 +147,7 @@ ServerOptions::ServerOptions() , internal_port(-1) , has_builtin_services(true) , force_ssl(false) - , use_rdma(false) + , socket_mode(SOCKET_MODE_TCP) , baidu_master_service(NULL) , http_master_service(NULL) , health_reporter(NULL) @@ -772,27 +773,6 @@ bool Server::CreateConcurrencyLimiter(const AdaptiveMaxConcurrency& amc, return true; } -#if BRPC_WITH_RDMA -static bool OptionsAvailableOverRdma(const ServerOptions* opt) { - if (opt->rtmp_service) { - LOG(WARNING) << "RTMP is not supported by RDMA"; - return false; - } - if (opt->has_ssl_options()) { - LOG(WARNING) << "SSL is not supported by RDMA"; - return false; - } - if (opt->nshead_service) { - LOG(WARNING) << "NSHEAD is not supported by RDMA"; - return false; - } - if (opt->mongo_service_adaptor) { - LOG(WARNING) << "MONGO is not supported by RDMA"; - return false; - } - return true; -} -#endif static AdaptiveMaxConcurrency g_default_max_concurrency_of_method(0); static bool g_default_ignore_eovercrowded(false); @@ -889,20 +869,10 @@ int Server::StartInternal(const butil::EndPoint& endpoint, << FLAGS_task_group_ntags << ")"; return -1; } - - if (_options.use_rdma) { -#if BRPC_WITH_RDMA - if (!OptionsAvailableOverRdma(&_options)) { - return -1; - } - rdma::GlobalRdmaInitializeOrDie(); - if (!rdma::InitPollingModeWithTag(_options.bthread_tag)) { - return -1; - } -#else - LOG(WARNING) << "Cannot use rdma since brpc does not compile with rdma"; + int ret = TransportFactory::ContextInitOrDie(_options.socket_mode, true, &_options); + if (ret != 0) { + LOG(ERROR) << "Fail to initialize transport context for server, ret=" << ret; return -1; -#endif } if (_options.http_master_service) { @@ -1170,7 +1140,7 @@ int Server::StartInternal(const butil::EndPoint& endpoint, LOG(ERROR) << "Fail to build acceptor"; return -1; } - _am->_use_rdma = _options.use_rdma; + _am->_socket_mode = _options.socket_mode; _am->_bthread_tag = _options.bthread_tag; } // Set `_status' to RUNNING before accepting connections diff --git a/src/brpc/server.h b/src/brpc/server.h index c262375c67..9f69a83458 100644 --- a/src/brpc/server.h +++ b/src/brpc/server.h @@ -45,6 +45,7 @@ #include "brpc/concurrency_limiter.h" #include "brpc/baidu_master_service.h" #include "brpc/rpc_pb_message_factory.h" +#include "brpc/socket_mode.h" namespace brpc { @@ -223,9 +224,9 @@ struct ServerOptions { // Force ssl for all connections of the port to Start(). bool force_ssl; - // Whether the server uses rdma or not - // Default: false - bool use_rdma; + // the server socket mode uses tcp or rdma or other + // Default: SOCKET_MODE_TCP + SocketMode socket_mode; // [CAUTION] This option is for implementing specialized baidu-std proxies, // most users don't need it. Don't change this option unless you fully diff --git a/src/brpc/socket.cpp b/src/brpc/socket.cpp index 9490650b78..b132f2acea 100644 --- a/src/brpc/socket.cpp +++ b/src/brpc/socket.cpp @@ -50,8 +50,7 @@ #include "brpc/policy/rtmp_protocol.h" // FIXME #include "brpc/periodic_task.h" #include "brpc/details/health_check.h" -#include "brpc/rdma/rdma_endpoint.h" -#include "brpc/rdma/rdma_helper.h" +#include "brpc/transport_factory.h" #if defined(OS_MACOSX) #include #endif @@ -456,6 +455,7 @@ Socket::Socket(Forbidden f) , _tos(0) , _reset_fd_real_us(-1) , _on_edge_triggered_events(NULL) + , _need_on_edge_trigger(false) , _user(NULL) , _conn(NULL) , _preferred_index(-1) @@ -473,8 +473,7 @@ Socket::Socket(Forbidden f) , _auth_context(NULL) , _ssl_state(SSL_UNKNOWN) , _ssl_session(NULL) - , _rdma_ep(NULL) - , _rdma_state(RDMA_OFF) + , _socket_mode(SOCKET_MODE_TCP) , _connection_type_for_progressive_read(CONNECTION_TYPE_UNKNOWN) , _controller_released_socket(false) , _overcrowded(false) @@ -601,7 +600,7 @@ int Socket::ResetFileDescriptor(int fd) { SetSocketOptions(fd); - if (_on_edge_triggered_events) { + if (_transport->HasOnEdgeTrigger()) { if (_io_event.AddConsumer(fd) != 0) { PLOG(ERROR) << "Fail to add SocketId=" << id() << " into EventDispatcher"; @@ -721,6 +720,11 @@ int Socket::OnCreated(const SocketOptions& options) { auto guard = butil::MakeScopeGuard([this] { _io_event.Reset(); }); + // start build the transport + _socket_mode = options.socket_mode; + _transport = TransportFactory::CreateTransport(options.socket_mode); + CHECK(NULL != _transport); + _transport->Init(this, options); g_vars->nsocket << 1; CHECK(NULL == _shared_part.load(butil::memory_order_relaxed)); @@ -728,11 +732,13 @@ int Socket::OnCreated(const SocketOptions& options) { _keytable_pool = options.keytable_pool; _tos = 0; _remote_side = options.remote_side; - _local_side = butil::EndPoint(); + _local_side = options.local_side; + _device_name = options.device_name; _on_edge_triggered_events = options.on_edge_triggered_events; + _need_on_edge_trigger = options.need_on_edge_trigger; _user = options.user; _conn = options.conn; - _app_connect = options.app_connect; + _app_connect = _transport->Connect(); _preferred_index = -1; _hc_count = 0; CHECK(_read_buf.empty()); @@ -756,22 +762,6 @@ int Socket::OnCreated(const SocketOptions& options) { _ssl_state = (options.initial_ssl_ctx == NULL ? SSL_OFF : SSL_UNKNOWN); _ssl_session = NULL; _ssl_ctx = options.initial_ssl_ctx; -#if BRPC_WITH_RDMA - CHECK(_rdma_ep == NULL); - if (options.use_rdma) { - _rdma_ep = new (std::nothrow)rdma::RdmaEndpoint(this); - if (!_rdma_ep) { - const int saved_errno = errno; - PLOG(ERROR) << "Fail to create RdmaEndpoint"; - SetFailed(saved_errno, "Fail to create RdmaEndpoint: %s", - berror(saved_errno)); - return -1; - } - _rdma_state = RDMA_UNKNOWN; - } else { - _rdma_state = RDMA_OFF; - } -#endif _connection_type_for_progressive_read = CONNECTION_TYPE_UNKNOWN; _controller_released_socket.store(false, butil::memory_order_relaxed); _overcrowded = false; @@ -851,7 +841,7 @@ void Socket::BeforeRecycled() { }; const int prev_fd = _fd.exchange(-1, butil::memory_order_relaxed); if (ValidFileDescriptor(prev_fd)) { - if (_on_edge_triggered_events != NULL) { + if (_transport->HasOnEdgeTrigger()) { _io_event.RemoveConsumer(prev_fd); } close(prev_fd); @@ -859,15 +849,7 @@ void Socket::BeforeRecycled() { g_vars->channel_conn << -1; } } - -#if BRPC_WITH_RDMA - if (_rdma_ep) { - delete _rdma_ep; - _rdma_ep = NULL; - _rdma_state = RDMA_UNKNOWN; - } -#endif - + _transport->Release(); reset_parsing_context(NULL); _read_buf.clear(); @@ -1012,7 +994,7 @@ int Socket::WaitAndReset(int32_t expected_nref) { // It's safe to close previous fd (provided expected_nref is correct). const int prev_fd = _fd.exchange(-1, butil::memory_order_relaxed); if (ValidFileDescriptor(prev_fd)) { - if (_on_edge_triggered_events != NULL) { + if (_transport->HasOnEdgeTrigger()) { _io_event.RemoveConsumer(prev_fd); } close(prev_fd); @@ -1020,13 +1002,7 @@ int Socket::WaitAndReset(int32_t expected_nref) { g_vars->channel_conn << -1; } } - -#if BRPC_WITH_RDMA - if (_rdma_ep) { - _rdma_ep->Reset(); - _rdma_state = RDMA_UNKNOWN; - } -#endif + _transport->Reset(expected_nref); _local_side = butil::EndPoint(); if (_ssl_session) { @@ -1180,13 +1156,6 @@ int Socket::Status(SocketId id, int32_t* nref) { return -1; } -void* Socket::ProcessEvent(void* arg) { - // the enclosed Socket is valid and free to access inside this function. - SocketUniquePtr s(static_cast(arg)); - s->_on_edge_triggered_events(s.get()); - return NULL; -} - // Check if there're new requests appended. // If yes, point old_head to reversed new requests and return false; // If no: @@ -1296,7 +1265,25 @@ int Socket::Connect(const timespec* abstime, CHECK_EQ(0, butil::make_close_on_exec(sockfd)); // We need to do async connect (to manage the timeout by ourselves). CHECK_EQ(0, butil::make_non_blocking(sockfd)); - + if (!_device_name.empty()) { + if (setsockopt(sockfd, SOL_SOCKET, SO_BINDTODEVICE, + _device_name.c_str(), _device_name.size()) < 0) { + PLOG(ERROR) << "Fail to set SO_BINDTODEVICE of fd=" << sockfd + << " to device_name=" << _device_name; + return -1; + } + } + if (local_side().ip != butil::IP_ANY) { + struct sockaddr_storage cli_addr; + if (butil::endpoint2sockaddr(local_side(), &cli_addr, &addr_size) != 0) { + PLOG(ERROR) << "Fail to get client sockaddr"; + return -1; + } + if (::bind(sockfd, (struct sockaddr*)&cli_addr, addr_size) != 0) { + PLOG(ERROR) << "Fail to bind client socket, errno=" << strerror(errno); + return -1; + } + } const int rc = ::connect( sockfd, (struct sockaddr*)&serv_addr, addr_size); if (rc != 0 && errno != EINPROGRESS) { @@ -1752,16 +1739,7 @@ int Socket::StartWrite(WriteRequest* req, const WriteOptions& opt) { butil::IOBuf* data_arr[1] = { &req->data }; nw = _conn->CutMessageIntoFileDescriptor(fd(), data_arr, 1); } else { -#if BRPC_WITH_RDMA - if (_rdma_ep && _rdma_state != RDMA_OFF) { - butil::IOBuf* data_arr[1] = { &req->data }; - nw = _rdma_ep->CutFromIOBufList(data_arr, 1); - } else { -#else - { -#endif - nw = req->data.cut_into_file_descriptor(fd()); - } + nw = _transport->CutFromIOBuf(&req->data); } if (nw < 0) { // RTMP may return EOVERCROWDED @@ -1863,45 +1841,11 @@ void* Socket::KeepWrite(void* void_arg) { // which may turn on _overcrowded to stop pending requests from // growing infinitely. const timespec duetime = - butil::milliseconds_from_now(WAIT_EPOLLOUT_TIMEOUT_MS); -#if BRPC_WITH_RDMA - if (s->_rdma_state == RDMA_ON) { - const int expected_val = s->_epollout_butex - ->load(butil::memory_order_acquire); - CHECK(s->_rdma_ep != NULL); - if (!s->_rdma_ep->IsWritable()) { - g_vars->nwaitepollout << 1; - if (bthread::butex_wait(s->_epollout_butex, - expected_val, &duetime) < 0) { - if (errno != EAGAIN && errno != ETIMEDOUT) { - const int saved_errno = errno; - PLOG(WARNING) << "Fail to wait rdma window of " << *s; - s->SetFailed(saved_errno, "Fail to wait rdma window of %s: %s", - s->description().c_str(), berror(saved_errno)); - } - if (s->Failed()) { - // NOTE: - // Different from TCP, we cannot find the RDMA channel - // failed by writing to it. Thus we must check if it - // is already failed here. - break; - } - } - } - } else { -#else - { -#endif - g_vars->nwaitepollout << 1; - bool pollin = (s->_on_edge_triggered_events != NULL); - const int rc = s->WaitEpollOut(s->fd(), pollin, &duetime); - if (rc < 0 && errno != ETIMEDOUT) { - const int saved_errno = errno; - PLOG(WARNING) << "Fail to wait epollout of " << *s; - s->SetFailed(saved_errno, "Fail to wait epollout of %s: %s", - s->description().c_str(), berror(saved_errno)); - break; - } + butil::milliseconds_from_now(WAIT_EPOLLOUT_TIMEOUT_MS); + bool pollin = s->_transport->HasOnEdgeTrigger(); + int ret = s->_transport->WaitEpollOut(s->_epollout_butex, pollin, duetime); + if (ret == 1) { + break; } } if (NULL == cur_tail) { @@ -1941,13 +1885,7 @@ ssize_t Socket::DoWrite(WriteRequest* req) { if (_conn) { return _conn->CutMessageIntoFileDescriptor(fd(), data_list, ndata); } else { -#if BRPC_WITH_RDMA - if (_rdma_ep && _rdma_state != RDMA_OFF) { - return _rdma_ep->CutFromIOBufList(data_list, ndata); - } -#endif - return butil::IOBuf::cut_multiple_into_file_descriptor( - fd(), data_list, ndata); + return _transport->CutFromIOBufList(data_list, ndata); } } @@ -2136,7 +2074,6 @@ ssize_t Socket::DoRead(size_t size_hint) { errno = ESSL; return -1; } - CHECK(_rdma_state == RDMA_OFF); return _read_buf.append_from_file_descriptor(fd(), size_hint); } @@ -2238,7 +2175,7 @@ int Socket::OnInputEvent(void* user_data, uint32_t events, if (Address(id, &s) < 0) { return -1; } - if (NULL == s->_on_edge_triggered_events) { + if (!s->_transport->HasOnEdgeTrigger()) { // Callback can be NULL when receiving error epoll events // (Added into epoll by `WaitConnected') return 0; @@ -2264,28 +2201,15 @@ int Socket::OnInputEvent(void* user_data, uint32_t events, // is just 1500~1700/s g_vars->neventthread << 1; - bthread_t tid; // transfer ownership as well, don't use s anymore! Socket* const p = s.release(); bthread_attr_t attr = thread_attr; attr.keytable_pool = p->_keytable_pool; attr.tag = bthread_self_tag(); - bthread_attr_set_name(&attr, "ProcessEvent"); - if (FLAGS_usercode_in_coroutine) { - ProcessEvent(p); -#if BRPC_WITH_RDMA - } else if (rdma::FLAGS_rdma_edisp_unsched) { - auto rc = bthread_start_background(&tid, &attr, ProcessEvent, p); - if (rc != 0) { - LOG(FATAL) << "Fail to start ProcessEvent"; - ProcessEvent(p); - } -#endif - } else if (bthread_start_urgent(&tid, &attr, ProcessEvent, p) != 0) { - LOG(FATAL) << "Fail to start ProcessEvent"; - ProcessEvent(p); - } + // Only event dispatcher thread has flag BTHREAD_GLOBAL_PRIORITY + attr.flags = attr.flags & (~BTHREAD_GLOBAL_PRIORITY); + p->_transport->ProcessEvent(attr); } return 0; } @@ -2587,11 +2511,7 @@ void Socket::DebugSocket(std::ostream& os, SocketId id) { << "\n}"; } #endif -#if BRPC_WITH_RDMA - if (ptr->_rdma_state == RDMA_ON && ptr->_rdma_ep) { - ptr->_rdma_ep->DebugInfo(os); - } -#endif + ptr->_transport->Debug(os); { os << "\nbthread_tag=" << ptr->_io_event.bthread_tag(); } } @@ -2811,12 +2731,14 @@ int Socket::GetPooledSocket(SocketUniquePtr* pooled_socket) { if (socket_pool == NULL) { SocketOptions opt; opt.remote_side = remote_side(); + opt.local_side = butil::EndPoint(local_side().ip, 0); opt.user = user(); opt.on_edge_triggered_events = _on_edge_triggered_events; + opt.need_on_edge_trigger = _need_on_edge_trigger; opt.initial_ssl_ctx = _ssl_ctx; opt.keytable_pool = _keytable_pool; opt.app_connect = _app_connect; - opt.use_rdma = (_rdma_ep) ? true : false; + opt.socket_mode = _socket_mode; socket_pool = new SocketPool(opt); SocketPool* expected = NULL; if (!main_sp->socket_pool.compare_exchange_strong( @@ -2912,12 +2834,14 @@ int Socket::GetShortSocket(SocketUniquePtr* short_socket) { SocketId id; SocketOptions opt; opt.remote_side = remote_side(); + opt.local_side = butil::EndPoint(local_side().ip, 0); opt.user = user(); opt.on_edge_triggered_events = _on_edge_triggered_events; + opt.need_on_edge_trigger = _need_on_edge_trigger; opt.initial_ssl_ctx = _ssl_ctx; opt.keytable_pool = _keytable_pool; opt.app_connect = _app_connect; - opt.use_rdma = (_rdma_ep) ? true : false; + opt.socket_mode = _socket_mode; if (get_client_side_messenger()->Create(opt, &id) != 0 || Address(id, short_socket) != 0) { return -1; diff --git a/src/brpc/socket.h b/src/brpc/socket.h index 03ad43f867..816fccdf27 100644 --- a/src/brpc/socket.h +++ b/src/brpc/socket.h @@ -42,6 +42,7 @@ #include "brpc/event_dispatcher.h" #include "brpc/versioned_ref_with_id.h" #include "brpc/health_check_option.h" +#include "brpc/socket_mode.h" namespace brpc { namespace policy { @@ -61,6 +62,7 @@ class Socket; class AuthContext; class EventDispatcher; class Stream; +class Transport; // A special closure for processing the about-to-recycle socket. Socket does // not delete SocketUser, if you want, `delete this' at the end of @@ -250,6 +252,8 @@ struct SocketOptions { // user->BeforeRecycle() before recycling. int fd{-1}; butil::EndPoint remote_side; + butil::EndPoint local_side; + std::string device_name; // If `connect_on_create' is true and `fd' is less than 0, // a client connection will be established to remote_side() // regarding deadline `connect_abstime' when Socket is being created. @@ -266,11 +270,20 @@ struct SocketOptions { // until new data arrives. The callback will not be called from more than // one thread at any time. void (*on_edge_triggered_events)(Socket*){NULL}; + // Indicates that this socket requires an edge-triggered event handler even + // if `on_edge_triggered_events` is left as NULL by the caller. When this + // flag is true and `on_edge_triggered_events` is NULL, the underlying + // transport-specific implementation (e.g. a transport subclass) is allowed + // to install a suitable default `on_edge_triggered_events` callback on + // behalf of the user. Typical usage is by transports/protocols that rely + // on edge-triggered I/O semantics but want the framework to provide the + // actual event handler. + bool need_on_edge_trigger{false}; int health_check_interval_s{-1}; // Only accept ssl connection. bool force_ssl{false}; std::shared_ptr initial_ssl_ctx; - bool use_rdma{false}; + SocketMode socket_mode{SOCKET_MODE_TCP}; bthread_keytable_pool_t* keytable_pool{NULL}; SocketConnection* conn{NULL}; std::shared_ptr app_connect; @@ -311,6 +324,10 @@ friend class policy::H2GlobalStreamCreator; friend class VersionedRefWithId; friend class IOEvent; friend void DereferenceSocket(Socket*); +friend class Transport; +friend class TcpTransport; +friend class RdmaTransport; +friend class TransportFactory; class SharedPart; struct WriteRequest; @@ -648,13 +665,6 @@ friend void DereferenceSocket(Socket*); private: DISALLOW_COPY_AND_ASSIGN(Socket); - // The on/off state of RDMA - enum RdmaState { - RDMA_ON, - RDMA_OFF, - RDMA_UNKNOWN - }; - int ConductError(bthread_id_t); int StartWrite(WriteRequest*, const WriteOptions&); @@ -730,7 +740,6 @@ friend void DereferenceSocket(Socket*); // Wait until nref hits `expected_nref' and reset some internal resources. int WaitAndReset(int32_t expected_nref); - static void* ProcessEvent(void*); static void* KeepWrite(void*); @@ -830,11 +839,14 @@ friend void DereferenceSocket(Socket*); // Address of self. Initialized in ResetFileDescriptor(). butil::EndPoint _local_side; + // The device name of the client's network adapter. + std::string _device_name; + // Called when edge-triggered events happened on `_fd'. Read comments // of EventDispatcher::AddConsumer (event_dispatcher.h) // carefully before implementing the callback. void (*_on_edge_triggered_events)(Socket*); - + bool _need_on_edge_trigger; // A set of callbacks to monitor important events of this socket. // Initialized by SocketOptions.user SocketUser* _user; @@ -913,10 +925,9 @@ friend void DereferenceSocket(Socket*); SSL* _ssl_session; // owner std::shared_ptr _ssl_ctx; - // The RdmaEndpoint - rdma::RdmaEndpoint* _rdma_ep; - // Should use RDMA or not - RdmaState _rdma_state; + // Should use SOCKET_MODE_RDMA or SOCKET_MODE_TCP or Other, default is SOCKET_MODE_TCP Transport + SocketMode _socket_mode; + std::unique_ptr _transport; // Pass from controller, for progressive reading. ConnectionType _connection_type_for_progressive_read; diff --git a/src/brpc/socket_map.cpp b/src/brpc/socket_map.cpp index 14bea71db5..3984f6b866 100644 --- a/src/brpc/socket_map.cpp +++ b/src/brpc/socket_map.cpp @@ -90,11 +90,9 @@ SocketMap* get_or_new_client_side_socket_map() { } int SocketMapInsert(const SocketMapKey& key, SocketId* id, - const std::shared_ptr& ssl_ctx, - bool use_rdma, - const HealthCheckOption& hc_option) { - return get_or_new_client_side_socket_map()->Insert(key, id, ssl_ctx, use_rdma, hc_option); -} + SocketOptions& opt) { + return get_or_new_client_side_socket_map()->Insert(key, id, opt); +} int SocketMapFind(const SocketMapKey& key, SocketId* id) { SocketMap* m = get_client_side_socket_map(); @@ -227,9 +225,7 @@ void SocketMap::ShowSocketMapInBvarIfNeed() { } int SocketMap::Insert(const SocketMapKey& key, SocketId* id, - const std::shared_ptr& ssl_ctx, - bool use_rdma, - const HealthCheckOption& hc_option) { + SocketOptions& opt) { ShowSocketMapInBvarIfNeed(); std::unique_lock mu(_mutex); @@ -249,11 +245,7 @@ int SocketMap::Insert(const SocketMapKey& key, SocketId* id, sc = NULL; } SocketId tmp_id; - SocketOptions opt; opt.remote_side = key.peer.addr; - opt.initial_ssl_ctx = ssl_ctx; - opt.use_rdma = use_rdma; - opt.hc_option = hc_option; if (_options.socket_creator->CreateSocket(opt, &tmp_id) != 0) { PLOG(FATAL) << "Fail to create socket to " << key.peer; return -1; diff --git a/src/brpc/socket_map.h b/src/brpc/socket_map.h index b0d542e78e..b1922bf86e 100644 --- a/src/brpc/socket_map.h +++ b/src/brpc/socket_map.h @@ -80,20 +80,30 @@ struct SocketMapKeyHasher { // successfully, SocketMapRemove() MUST be called when the Socket is not needed. // Return 0 on success, -1 otherwise. int SocketMapInsert(const SocketMapKey& key, SocketId* id, + SocketOptions& opt); + +inline int SocketMapInsert(const SocketMapKey& key, SocketId* id, const std::shared_ptr& ssl_ctx, - bool use_rdma, - const HealthCheckOption& hc_option); + SocketMode socket_mode, + const HealthCheckOption& hc_option) { + SocketOptions opt; + opt.remote_side = key.peer.addr; + opt.initial_ssl_ctx = ssl_ctx; + opt.socket_mode = socket_mode; + opt.hc_option = hc_option; + return SocketMapInsert(key, id, opt); +} inline int SocketMapInsert(const SocketMapKey& key, SocketId* id, const std::shared_ptr& ssl_ctx) { HealthCheckOption hc_option; - return SocketMapInsert(key, id, ssl_ctx, false, hc_option); + return SocketMapInsert(key, id, ssl_ctx, SOCKET_MODE_TCP, hc_option); } inline int SocketMapInsert(const SocketMapKey& key, SocketId* id) { std::shared_ptr empty_ptr; HealthCheckOption hc_option; - return SocketMapInsert(key, id, empty_ptr, false, hc_option); + return SocketMapInsert(key, id, empty_ptr, SOCKET_MODE_TCP, hc_option); } // Find the SocketId associated with `key'. @@ -154,19 +164,27 @@ class SocketMap { int Init(const SocketMapOptions&); int Insert(const SocketMapKey& key, SocketId* id, const std::shared_ptr& ssl_ctx, - bool use_rdma, - const HealthCheckOption& hc_option); + SocketMode socket_mode, + const HealthCheckOption& hc_option) { + SocketOptions opt; + opt.remote_side = key.peer.addr; + opt.initial_ssl_ctx = ssl_ctx; + opt.socket_mode = socket_mode; + opt.hc_option = hc_option; + return Insert(key, id, opt); +} int Insert(const SocketMapKey& key, SocketId* id, const std::shared_ptr& ssl_ctx) { HealthCheckOption hc_option; - return Insert(key, id, ssl_ctx, false, hc_option); + return Insert(key, id, ssl_ctx, SOCKET_MODE_TCP, hc_option); } int Insert(const SocketMapKey& key, SocketId* id) { std::shared_ptr empty_ptr; HealthCheckOption hc_option; - return Insert(key, id, empty_ptr, false, hc_option); + return Insert(key, id, empty_ptr, SOCKET_MODE_TCP, hc_option); } + int Insert(const SocketMapKey& key, SocketId* id, SocketOptions& opt); void Remove(const SocketMapKey& key, SocketId expected_id); int Find(const SocketMapKey& key, SocketId* id); diff --git a/src/brpc/socket_mode.h b/src/brpc/socket_mode.h new file mode 100644 index 0000000000..b5d42be4aa --- /dev/null +++ b/src/brpc/socket_mode.h @@ -0,0 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef BRPC_SOCKET_MODE_H +#define BRPC_SOCKET_MODE_H +namespace brpc { +enum SocketMode { + SOCKET_MODE_TCP = 0, + SOCKET_MODE_RDMA = 1 +}; +} // namespace brpc +#endif //BRPC_SOCKET_MODE_H \ No newline at end of file diff --git a/src/brpc/tcp_transport.cpp b/src/brpc/tcp_transport.cpp new file mode 100644 index 0000000000..37db7a8966 --- /dev/null +++ b/src/brpc/tcp_transport.cpp @@ -0,0 +1,99 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "brpc/tcp_transport.h" + +namespace brpc { +DECLARE_bool(usercode_in_coroutine); +DECLARE_bool(usercode_in_pthread); + +extern SocketVarsCollector* g_vars; + +void TcpTransport::Init(Socket* socket, const SocketOptions& options) { + _socket = socket; + _default_connect = options.app_connect; + _on_edge_trigger = options.on_edge_triggered_events; + if (options.need_on_edge_trigger && _on_edge_trigger == NULL) { + _on_edge_trigger = InputMessenger::OnNewMessages; + } +} + +void TcpTransport::Release(){} + +int TcpTransport::Reset(int32_t expected_nref) { + return 0; +} + +int TcpTransport::CutFromIOBuf(butil::IOBuf* buf) { + return buf->cut_into_file_descriptor(_socket->fd()); +} + +std::shared_ptr TcpTransport::Connect() { + return _default_connect; +} + +ssize_t TcpTransport::CutFromIOBufList(butil::IOBuf** buf, size_t ndata) { + return butil::IOBuf::cut_multiple_into_file_descriptor(_socket->fd(), buf, ndata); +} + +int TcpTransport::WaitEpollOut(butil::atomic* _epollout_butex, + bool pollin, timespec duetime) { + g_vars->nwaitepollout << 1; + const int rc = _socket->WaitEpollOut(_socket->fd(), pollin, &duetime); + if (rc < 0 && errno != ETIMEDOUT) { + const int saved_errno = errno; + PLOG(WARNING) << "Fail to wait epollout of " << _socket; + _socket->SetFailed(saved_errno, "Fail to wait epollout of %s: %s", + _socket->description().c_str(), berror(saved_errno)); + return 1; + } + return 0; +} + +void TcpTransport::ProcessEvent(bthread_attr_t attr) { + bthread_t tid; + if (FLAGS_usercode_in_coroutine) { + OnEdge(_socket); + } else if (bthread_start_urgent(&tid, &attr, OnEdge, _socket) != 0) { + LOG(FATAL) << "Fail to start ProcessEvent"; + OnEdge(_socket); + } +} +void TcpTransport::QueueMessage(InputMessageClosure& input_msg, + int* num_bthread_created, bool) { + InputMessageBase* to_run_msg = input_msg.release(); + if (!to_run_msg) { + return; + } + // Create bthread for last_msg. The bthread is not scheduled + // until bthread_flush() is called (in the worse case). + bthread_t th; + bthread_attr_t tmp = + (FLAGS_usercode_in_pthread ? BTHREAD_ATTR_PTHREAD : BTHREAD_ATTR_NORMAL) | + BTHREAD_NOSIGNAL; + tmp.keytable_pool = _socket->keytable_pool(); + tmp.tag = bthread_self_tag(); + bthread_attr_set_name(&tmp, "ProcessInputMessage"); + if (!FLAGS_usercode_in_coroutine && bthread_start_background( + &th, &tmp, ProcessInputMessage, to_run_msg) == 0) { + ++*num_bthread_created; + } else { + ProcessInputMessage(to_run_msg); + } +} + +} // namespace brpc \ No newline at end of file diff --git a/src/brpc/tcp_transport.h b/src/brpc/tcp_transport.h new file mode 100644 index 0000000000..8a06a85d37 --- /dev/null +++ b/src/brpc/tcp_transport.h @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef BRPC_TCP_TRANSPORT_H +#define BRPC_TCP_TRANSPORT_H + +#include "brpc/transport.h" +#include "brpc/socket.h" + +namespace brpc { +class TcpTransport : public Transport { + friend class TransportFactory; +public: + void Init(Socket* socket, const SocketOptions& options) override; + void Release() override; + int Reset(int32_t expected_nref) override; + std::shared_ptr Connect() override; + int CutFromIOBuf(butil::IOBuf* buf) override; + ssize_t CutFromIOBufList(butil::IOBuf** buf, size_t ndata) override; + int WaitEpollOut(butil::atomic* _epollout_butex, bool pollin, timespec duetime) override; + void ProcessEvent(bthread_attr_t attr) override; + void QueueMessage(InputMessageClosure& input_msg, int* num_bthread_created, bool last_msg) override; + void Debug(std::ostream &os) override {} +}; +} // namespace brpc + +#endif //BRPC_TCP_TRANSPORT_H \ No newline at end of file diff --git a/src/brpc/transport.h b/src/brpc/transport.h new file mode 100644 index 0000000000..a2cb868b89 --- /dev/null +++ b/src/brpc/transport.h @@ -0,0 +1,66 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef BRPC_TRANSPORT_H +#define BRPC_TRANSPORT_H +#include "brpc/input_messenger.h" +#include "brpc/socket.h" +#include "server.h" + +namespace brpc { +using OnEdgeTrigger = std::function; +class Transport { + friend class TransportFactory; +public: + static void* OnEdge(void* arg) { + // the enclosed Socket is valid and free to access inside this function. + SocketUniquePtr s(static_cast(arg)); + const OnEdgeTrigger on_edge_trigger = s->_transport->GetOnEdgeTrigger(); + on_edge_trigger(s.get()); + return NULL; + } + + static void* ProcessInputMessage(void* void_arg) { + InputMessageBase* msg = static_cast(void_arg); + msg->_process(msg); + return NULL; + } + virtual ~Transport() = default; + virtual void Init(Socket* socket, const SocketOptions& options) = 0; + virtual void Release() = 0; + virtual int Reset(int32_t expected_nref) = 0; + virtual std::shared_ptr Connect() = 0; + virtual int CutFromIOBuf(butil::IOBuf* buf) = 0; + virtual ssize_t CutFromIOBufList(butil::IOBuf** buf, size_t ndata) = 0; + virtual int WaitEpollOut(butil::atomic* _epollout_butex, bool pollin, timespec duetime) = 0; + virtual void ProcessEvent(bthread_attr_t attr) = 0; + virtual void QueueMessage(InputMessageClosure& input_msg, int* num_bthread_created, bool last_msg) = 0; + virtual void Debug(std::ostream &os) = 0; + + bool HasOnEdgeTrigger() { + return _on_edge_trigger != NULL; + } + OnEdgeTrigger GetOnEdgeTrigger() { + return _on_edge_trigger; + } +protected: + Socket* _socket; + std::shared_ptr _default_connect; + OnEdgeTrigger _on_edge_trigger; +}; +} +#endif //BRPC_TRANSPORT_H \ No newline at end of file diff --git a/src/brpc/transport_factory.cpp b/src/brpc/transport_factory.cpp new file mode 100644 index 0000000000..b689e2edd2 --- /dev/null +++ b/src/brpc/transport_factory.cpp @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "brpc/transport_factory.h" +#include "brpc/tcp_transport.h" +#include "brpc/rdma_transport.h" + +namespace brpc { +int TransportFactory::ContextInitOrDie(SocketMode mode, bool serverOrNot, const void* _options) { + if (mode == SOCKET_MODE_TCP) { + return 0; + } +#if BRPC_WITH_RDMA + else if (mode == SOCKET_MODE_RDMA) { + return RdmaTransport::ContextInitOrDie(serverOrNot, _options); + } +#endif + else { + LOG(ERROR) << "unknown transport type " << mode; + return 1; + } +} + +std::unique_ptr TransportFactory::CreateTransport(SocketMode mode) { + if (mode == SOCKET_MODE_TCP) { + return std::unique_ptr(new TcpTransport()); + } +#if BRPC_WITH_RDMA + else if (mode == SOCKET_MODE_RDMA) { + return std::unique_ptr(new RdmaTransport()); + } +#endif + else { + LOG(ERROR) << "socket_mode set error"; + return nullptr; + } +} +} // namespace brpc \ No newline at end of file diff --git a/src/brpc/transport_factory.h b/src/brpc/transport_factory.h new file mode 100644 index 0000000000..d933a130e1 --- /dev/null +++ b/src/brpc/transport_factory.h @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef BRPC_TRANSPORT_FACTORY_H +#define BRPC_TRANSPORT_FACTORY_H + +#include "brpc/socket_mode.h" +#include "brpc/transport.h" + +namespace brpc { +// TransportFactory to create transport instance with socket_mode {TCP, RDMA} +class TransportFactory { +public: + static int ContextInitOrDie(SocketMode mode, bool serverOrNot, const void* _options); + // Create transport instance with socket mode. + static std::unique_ptr CreateTransport(SocketMode mode); +}; +} // namespace brpc + +#endif //BRPC_TRANSPORT_FACTORY_H \ No newline at end of file diff --git a/src/bthread/bthread.h b/src/bthread/bthread.h index 7e42c96c9f..603cf04d0e 100644 --- a/src/bthread/bthread.h +++ b/src/bthread/bthread.h @@ -30,6 +30,7 @@ #if defined(__cplusplus) #include #include "bthread/mutex.h" // use bthread_mutex_t in the RAII way +#include "bthread/condition_variable.h" // use bthread_cond_t in the RAII way #endif // __cplusplus #include "bthread/id.h" diff --git a/src/bthread/condition_variable.h b/src/bthread/condition_variable.h index c684cf6cbd..fb6bb4bcb5 100644 --- a/src/bthread/condition_variable.h +++ b/src/bthread/condition_variable.h @@ -63,6 +63,20 @@ class ConditionVariable { bthread_cond_wait(&_cond, lock.mutex()); } + template + void wait(std::unique_lock& lock, Predicate p) { + while (!p()) { + bthread_cond_wait(&_cond, lock.mutex()->native_handler()); + } + } + + template + void wait(std::unique_lock& lock, Predicate p) { + while (!p()) { + bthread_cond_wait(&_cond, lock.mutex()); + } + } + // Unlike std::condition_variable, we return ETIMEDOUT when time expires // rather than std::timeout int wait_for(std::unique_lock& lock, diff --git a/test/brpc_server_unittest.cpp b/test/brpc_server_unittest.cpp index 4a774fab2a..8508a7986c 100644 --- a/test/brpc_server_unittest.cpp +++ b/test/brpc_server_unittest.cpp @@ -2070,4 +2070,49 @@ TEST_F(ServerTest, auth) { ASSERT_EQ(0, server.Join()); } +void TestClientHost(const butil::EndPoint& ep, + brpc::Controller& cntl, + int error_code, bool failed, + brpc::ChannelOptions& copt) { + brpc::Channel chan; + copt.max_retry = 0; + ASSERT_EQ(0, chan.Init(ep, &copt)); + + test::EchoRequest req; + test::EchoResponse res; + req.set_message(EXP_REQUEST); + test::EchoService_Stub stub(&chan); + stub.Echo(&cntl, &req, &res, NULL); + ASSERT_EQ(cntl.Failed(), failed) << cntl.ErrorText(); + ASSERT_EQ(cntl.ErrorCode(), error_code); +} + +TEST_F(ServerTest, bind_client_host_and_network_device) { + butil::EndPoint ep; + ASSERT_EQ(0, str2endpoint("127.0.0.1:8613", &ep)); + brpc::Server server; + EchoServiceImpl service; + ASSERT_EQ(0, server.AddService(&service, brpc::SERVER_DOESNT_OWN_SERVICE)); + brpc::ServerOptions opt; + ASSERT_EQ(0, server.Start(ep, &opt)); + + brpc::Controller cntl; + brpc::ChannelOptions copt; + copt.client_host = "localhost"; + copt.device_name = "lo"; + std::vector connection_types = { + brpc::CONNECTION_TYPE_SINGLE, + brpc::CONNECTION_TYPE_POOLED, + brpc::CONNECTION_TYPE_SHORT + }; + for (auto connect_type : connection_types) { + copt.connection_type = connect_type; + TestClientHost(ep, cntl, 0, false, copt); + cntl.Reset(); + } + + ASSERT_EQ(0, server.Stop(0)); + ASSERT_EQ(0, server.Join()); +} + } //namespace diff --git a/test/bthread_cond_unittest.cpp b/test/bthread_cond_unittest.cpp index d01ef69c26..f2dcddfe8c 100644 --- a/test/bthread_cond_unittest.cpp +++ b/test/bthread_cond_unittest.cpp @@ -138,7 +138,10 @@ TEST(CondTest, sanity) { struct WrapperArg { bthread::Mutex mutex; bthread::ConditionVariable cond; + bool ready = false; + static std::atomic wake_time; }; +std::atomic WrapperArg::wake_time{0}; void* cv_signaler(void* void_arg) { WrapperArg* a = (WrapperArg*)void_arg; @@ -168,6 +171,23 @@ void* cv_mutex_waiter(void* void_arg) { return NULL; } + +void* cv_bmutex_waiter_with_pred(void* void_arg) { + WrapperArg* a = (WrapperArg*)void_arg; + std::unique_lock lck(*a->mutex.native_handler()); + a->cond.wait(lck, [&] { return a->ready; }); + WrapperArg::wake_time.fetch_add(1); + return NULL; +} + +void* cv_mutex_waiter_with_pred(void* void_arg) { + WrapperArg* a = (WrapperArg*)void_arg; + std::unique_lock lck(a->mutex); + a->cond.wait(lck, [&] { return a->ready; }); + WrapperArg::wake_time.fetch_add(1); + return NULL; +} + #define COND_IN_PTHREAD #ifndef COND_IN_PTHREAD @@ -202,6 +222,37 @@ TEST(CondTest, cpp_wrapper) { } } +TEST(CondTest, cpp_wrapper2) { + stop = false; + bthread::ConditionVariable cond; + pthread_t bmutex_waiter_threads[8]; + pthread_t mutex_waiter_threads[8]; + pthread_t signal_thread; + WrapperArg a; + for (size_t i = 0; i < ARRAY_SIZE(bmutex_waiter_threads); ++i) { + ASSERT_EQ(0, pthread_create(&bmutex_waiter_threads[i], NULL, + cv_bmutex_waiter_with_pred, &a)); + ASSERT_EQ(0, pthread_create(&mutex_waiter_threads[i], NULL, + cv_mutex_waiter_with_pred, &a)); + } + ASSERT_EQ(0, pthread_create(&signal_thread, NULL, cv_signaler, &a)); + bthread_usleep(100L * 1000); + ASSERT_EQ(WrapperArg::wake_time, 0); + { + BAIDU_SCOPED_LOCK(a.mutex); + stop = true; + a.ready = true; + + } + pthread_join(signal_thread, NULL); + a.cond.notify_all(); + for (size_t i = 0; i < ARRAY_SIZE(bmutex_waiter_threads); ++i) { + pthread_join(bmutex_waiter_threads[i], NULL); + pthread_join(mutex_waiter_threads[i], NULL); + } + ASSERT_EQ(WrapperArg::wake_time, 16); +} + #ifndef COND_IN_PTHREAD #undef pthread_join #undef pthread_create