Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 131 additions & 24 deletions src/brpc/rdma/rdma_endpoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ DEFINE_int32(rdma_poller_num, 1, "Poller number in RDMA polling mode.");
DEFINE_bool(rdma_poller_yield, false, "Yield thread in RDMA polling mode.");
DEFINE_bool(rdma_disable_bthread, false, "Disable bthread in RDMA");
DEFINE_bool(rdma_ece, false, "Open ece in RDMA, should use this feature when rdma nics are from the same merchant.");
DEFINE_bool(rdma_extend, false, "Use the extend fields to negotiate the advance feature of rdma, such as mtu.");

static const size_t IOBUF_BLOCK_HEADER_LEN = 32; // implementation-dependent

Expand All @@ -81,14 +82,16 @@ static const size_t RESERVED_WR_NUM = 3;
// block size (4B)
// sq size (2B)
// rq size (2B)
// lid size (2B)
// GID (16B)
// QP number (4B)
// mtu type (2B)
static const char* MAGIC_STR = "RDMA";
static const size_t MAGIC_STR_LEN = 4;
static const size_t HELLO_MSG_LEN_MIN = 40;
// static const size_t HELLO_MSG_LEN_MAX = 4096;
static const size_t ACK_MSG_LEN = 4;
static uint16_t g_rdma_hello_msg_len = 40; // In Byte
static uint16_t g_rdma_hello_msg_len = 42; // In Byte
static uint16_t g_rdma_hello_version = 2;
Comment on lines +94 to 95
static uint16_t g_rdma_impl_version = 1;
static uint32_t g_rdma_recv_block_size = 0;
Expand All @@ -105,10 +108,16 @@ static const uint32_t ACK_MSG_RDMA_OK = 0x1;
static butil::Mutex* g_rdma_resource_mutex = NULL;
static RdmaResource* g_rdma_resource_list = NULL;

// The HelloMessage always carries all base fields; newer versions of
// HelloMessage may append extension fields after the base fields.

struct HelloMessage {
void Serialize(void* data) const;
void Deserialize(void* data);
void BaseSerialize(void* data) const;
void ExtSerialize(void* data) const;
void BaseDeserialize(void* data);
uint16_t ExtDeserialize(void* data, uint16_t ext_len);

// base fields
uint16_t msg_len;
uint16_t hello_ver;
uint16_t impl_ver;
Expand All @@ -118,9 +127,12 @@ struct HelloMessage {
uint16_t lid;
ibv_gid gid;
uint32_t qp_num;

// extension fields
uint16_t mtu_type;
};

void HelloMessage::Serialize(void* data) const {
void HelloMessage::BaseSerialize(void* data) const {
uint16_t* current_pos = (uint16_t*)data;
*(current_pos++) = butil::HostToNet16(msg_len);
*(current_pos++) = butil::HostToNet16(hello_ver);
Expand All @@ -132,11 +144,17 @@ void HelloMessage::Serialize(void* data) const {
*(current_pos++) = butil::HostToNet16(rq_size);
*(current_pos++) = butil::HostToNet16(lid);
memcpy(current_pos, gid.raw, 16);
uint32_t* qp_num_pos = (uint32_t*)((char*)current_pos + 16);
current_pos += 8;
uint32_t* qp_num_pos = (uint32_t*)(current_pos);
*qp_num_pos = butil::HostToNet32(qp_num);
}

void HelloMessage::Deserialize(void* data) {
void HelloMessage::ExtSerialize(void* data) const {
uint16_t* current_pos = (uint16_t*)data;
*(current_pos) = butil::HostToNet16(mtu_type);
}

void HelloMessage::BaseDeserialize(void* data) {
uint16_t* current_pos = (uint16_t*)data;
msg_len = butil::NetToHost16(*current_pos++);
hello_ver = butil::NetToHost16(*current_pos++);
Expand All @@ -147,7 +165,26 @@ void HelloMessage::Deserialize(void* data) {
rq_size = butil::NetToHost16(*current_pos++);
lid = butil::NetToHost16(*current_pos++);
memcpy(gid.raw, current_pos, 16);
qp_num = butil::NetToHost32(*(uint32_t*)((char*)current_pos + 16));
current_pos += 8;
qp_num = butil::NetToHost32(*(uint32_t*)(current_pos));
}

uint16_t HelloMessage::ExtDeserialize(void* data, uint16_t ext_len) {
if (ext_len == 0) {
return 0;
}

uint16_t remain_ext_len = ext_len;

// try to deserialize mtu_type
if (remain_ext_len < 2) {
LOG(FATAL) << "illegal HelloMessage, remain ext len is " << remain_ext_len << ", should not be less than 2!!!";
}
uint16_t* current_pos = (uint16_t*)data;
mtu_type = butil::NetToHost16(*current_pos++);
remain_ext_len -= 2;

return remain_ext_len;
}

RdmaResource::~RdmaResource() {
Expand Down Expand Up @@ -435,6 +472,7 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) {
<< "Start handshake on " << s->_local_side;

uint8_t data[g_rdma_hello_msg_len];
uint16_t local_mtu_type = GetLocalMtuType();

// First initialize CQ and QP resources
ep->_state = C_ALLOC_QPCQ;
Expand Down Expand Up @@ -463,9 +501,15 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) {
// Only happens in UT
local_msg.qp_num = 0;
}
local_msg.mtu_type = local_mtu_type;
memcpy(data, MAGIC_STR, 4);
local_msg.Serialize((char*)data + 4);
if (ep->WriteToFd(data, g_rdma_hello_msg_len) < 0) {
local_msg.BaseSerialize((char*)data + 4);
// If FLAGS_rdma_extend is not open, only send base fields of HelloMessage
if (FLAGS_rdma_extend) {
local_msg.ExtSerialize((char*)data + HELLO_MSG_LEN_MIN);
}
size_t msg_len = FLAGS_rdma_extend ? g_rdma_hello_msg_len : HELLO_MSG_LEN_MIN;
if (ep->WriteToFd(data, msg_len) < 0) {
const int saved_errno = errno;
PLOG(WARNING) << "Fail to send hello message to server:" << s->description();
s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s",
Expand Down Expand Up @@ -502,7 +546,7 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) {
return NULL;
}
HelloMessage remote_msg;
remote_msg.Deserialize(data);
remote_msg.BaseDeserialize(data);
if (remote_msg.msg_len < HELLO_MSG_LEN_MIN) {
LOG(WARNING) << "Fail to parse Hello Message length from server:"
<< s->description();
Expand All @@ -512,9 +556,27 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) {
return NULL;
}

if (remote_msg.msg_len > HELLO_MSG_LEN_MIN) {
// TODO: Read Hello Message customized data
// Just for future use, should not happen now
// In older versions of brpc, IBV_MTU_1024 is the default mtu type,
// so we default remote_mtu_type to IBV_MTU_1024 to stay compatible with older versions.
uint16_t remote_mtu_type = IBV_MTU_1024;
if (FLAGS_rdma_extend && remote_msg.msg_len > HELLO_MSG_LEN_MIN) {
// Read Hello Message customized data
uint16_t remote_msg_ext_len = remote_msg.msg_len - HELLO_MSG_LEN_MIN;
uint8_t ext_data[remote_msg_ext_len];
if (ep->ReadFromFd(ext_data, remote_msg_ext_len) < 0) {
const int saved_errno = errno;
PLOG(WARNING) << "Fail to get Hello Message ext fields from server:" << s->description();
s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s",
s->description().c_str(), berror(saved_errno));
ep->_state = FAILED;
return NULL;
}
remote_msg.ExtDeserialize(ext_data, remote_msg_ext_len);
if (remote_msg_ext_len >= 2) {
// mtu_type field is valid
remote_mtu_type = remote_msg.mtu_type;
}
Comment on lines +575 to +578
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By setting the default value of HelloMessage::mtu_type to IBV_MTU_1024, it's no longer necessary to check remote_msg_ext_len.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need this check, because the brpc version of peer can be older or newer, we should make sure the mtu_type field is valid in remote_msg.

// TODO: other extern fields
}

if (!HelloNegotiationValid(remote_msg)) {
Expand All @@ -534,7 +596,9 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) {
ep->_local_window_capacity, butil::memory_order_relaxed);

ep->_state = C_BRINGUP_QP;
if (ep->BringUpQp(remote_msg.lid, remote_msg.gid, remote_msg.qp_num) < 0) {
// use the minimum of local mtu type and remote mtu type
uint16_t min_mtu_type = std::min(local_mtu_type, remote_mtu_type);
if (ep->BringUpQp(remote_msg.lid, remote_msg.gid, remote_msg.qp_num, min_mtu_type) < 0) {
LOG(WARNING) << "Fail to bringup QP, fallback to tcp:" << s->description();
Comment on lines +599 to 602
rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF;
} else {
Expand Down Expand Up @@ -582,6 +646,7 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
<< "Start handshake on " << s->description();

uint8_t data[g_rdma_hello_msg_len];
uint16_t local_mtu_type = GetLocalMtuType();

ep->_state = S_HELLO_WAIT;
if (ep->ReadFromFd(data, MAGIC_STR_LEN) < 0) {
Expand All @@ -605,7 +670,7 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
return NULL;
}

if (ep->ReadFromFd(data, g_rdma_hello_msg_len - MAGIC_STR_LEN) < 0) {
if (ep->ReadFromFd(data, HELLO_MSG_LEN_MIN - MAGIC_STR_LEN) < 0) {
const int saved_errno = errno;
PLOG(WARNING) << "Fail to read Hello Message from client:" << s->description();
s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s",
Expand All @@ -615,7 +680,7 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
}

HelloMessage remote_msg;
remote_msg.Deserialize(data);
remote_msg.BaseDeserialize(data);
if (remote_msg.msg_len < HELLO_MSG_LEN_MIN) {
LOG(WARNING) << "Fail to parse Hello Message length from client:"
<< s->description();
Expand All @@ -624,9 +689,28 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
ep->_state = FAILED;
return NULL;
}
if (remote_msg.msg_len > HELLO_MSG_LEN_MIN) {
// TODO: Read Hello Message customized header
// Just for future use, should not happen now

// In older versions of brpc, IBV_MTU_1024 is the default mtu type,
// so we default remote_mtu_type to IBV_MTU_1024 to stay compatible with older versions.
uint16_t remote_mtu_type = IBV_MTU_1024;
if (FLAGS_rdma_extend && remote_msg.msg_len > HELLO_MSG_LEN_MIN) {
// Read Hello Message customized data
uint16_t remote_msg_ext_len = remote_msg.msg_len - HELLO_MSG_LEN_MIN;
uint8_t ext_data[remote_msg_ext_len];
if (ep->ReadFromFd(ext_data, remote_msg_ext_len) < 0) {
const int saved_errno = errno;
PLOG(WARNING) << "Fail to get Hello Message ext fields from client:" << s->description();
s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s",
s->description().c_str(), berror(saved_errno));
ep->_state = FAILED;
return NULL;
}
remote_msg.ExtDeserialize(ext_data, remote_msg_ext_len);
if (remote_msg_ext_len >= 2) {
// mtu_type field is valid
remote_mtu_type = remote_msg.mtu_type;
}
// TODO: other extern fields
}

if (!HelloNegotiationValid(remote_msg)) {
Expand All @@ -652,7 +736,9 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF;
} else {
ep->_state = S_BRINGUP_QP;
if (ep->BringUpQp(remote_msg.lid, remote_msg.gid, remote_msg.qp_num) < 0) {
// use the minimum of local mtu type and remote mtu type
uint16_t min_mtu_type = std::min(local_mtu_type, remote_mtu_type);
if (ep->BringUpQp(remote_msg.lid, remote_msg.gid, remote_msg.qp_num, min_mtu_type) < 0) {
LOG(WARNING) << "Fail to bringup QP, fallback to tcp:"
<< s->description();
rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF;
Expand Down Expand Up @@ -681,10 +767,16 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
// Only happens in UT
local_msg.qp_num = 0;
}
local_msg.mtu_type = local_mtu_type;
}
memcpy(data, MAGIC_STR, 4);
local_msg.Serialize((char*)data + 4);
if (ep->WriteToFd(data, g_rdma_hello_msg_len) < 0) {
local_msg.BaseSerialize((char*)data + 4);
// If FLAGS_rdma_extend is not open, only send base fields of HelloMessage
if (FLAGS_rdma_extend) {
local_msg.ExtSerialize((char*)data + HELLO_MSG_LEN_MIN);
}
size_t msg_len = FLAGS_rdma_extend ? g_rdma_hello_msg_len : HELLO_MSG_LEN_MIN;
if (ep->WriteToFd(data, msg_len) < 0) {
const int saved_errno = errno;
PLOG(WARNING) << "Fail to send Hello Message to client:" << s->description();
s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s",
Expand Down Expand Up @@ -1232,12 +1324,27 @@ int RdmaEndpoint::AllocateResources() {
return 0;
}

int RdmaEndpoint::BringUpQp(uint16_t lid, ibv_gid gid, uint32_t qp_num) {
int RdmaEndpoint::BringUpQp(uint16_t lid, ibv_gid gid, uint32_t qp_num, uint16_t mtu_type) {
if (BAIDU_UNLIKELY(g_skip_rdma_init)) {
// For UT
return 0;
}

if (mtu_type == IBV_MTU_256) {
LOG(INFO) << "negotiated mtu is 256";
} else if (mtu_type == IBV_MTU_512) {
LOG(INFO) << "negotiated mtu is 512";
} else if (mtu_type == IBV_MTU_1024) {
LOG(INFO) << "negotiated mtu is 1024";
} else if (mtu_type == IBV_MTU_2048) {
LOG(INFO) << "negotiated mtu is 2048";
} else if (mtu_type == IBV_MTU_4096) {
LOG(INFO) << "negotiated mtu is 4096";
Comment on lines +1334 to +1342
} else {
LOG(ERROR) << "unknown mtu " << mtu_type;
return -1;
}

ibv_qp_attr attr;

attr.qp_state = IBV_QPS_INIT;
Expand Down Expand Up @@ -1275,7 +1382,7 @@ int RdmaEndpoint::BringUpQp(uint16_t lid, ibv_gid gid, uint32_t qp_num) {
}

attr.qp_state = IBV_QPS_RTR;
attr.path_mtu = IBV_MTU_1024; // TODO: support more mtu in future
attr.path_mtu = ibv_mtu(mtu_type);
attr.ah_attr.grh.dgid = gid;
attr.ah_attr.grh.flow_label = 0;
attr.ah_attr.grh.sgid_index = GetRdmaGidIndex();
Expand Down
3 changes: 2 additions & 1 deletion src/brpc/rdma/rdma_endpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,11 @@ friend class Socket;
// lid: remote LID
// gid: remote GID
// qp_num: remote QP number
// mtu_type: the minimum of local mtu_type and remote mtu_type
// Return:
// 0: success
// -1: failed, errno set
int BringUpQp(uint16_t lid, ibv_gid gid, uint32_t qp_num);
int BringUpQp(uint16_t lid, ibv_gid gid, uint32_t qp_num, uint16_t mtu_type);

// Get event from comp channel and ack the events
int GetAndAckEvents(SocketUniquePtr& s);
Expand Down
41 changes: 41 additions & 0 deletions src/brpc/rdma/rdma_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ static int g_comp_vector_index = 0;

butil::atomic<bool> g_rdma_available(false);

static uint16_t local_mtu_type = IBV_MTU_4096;

DEFINE_int32(rdma_max_sge, 0, "Max SGE num in a WR");
DEFINE_string(rdma_device, "", "The name of the HCA device used "
"(Empty means using the first active device)");
Expand Down Expand Up @@ -455,6 +457,36 @@ static ibv_context* OpenDevice(int num_total, int* num_available_devices) {
return ret_context;
}

static uint16_t detect_mtu(struct ibv_context* ctx, int port_num) {
struct ibv_port_attr port_attr;

if (IbvQueryPort(ctx, port_num, &port_attr)) {
LOG(ERROR) << "IbvQueryPort failed";
return 0;
}

LOG(INFO) << "local active mtu type:" << port_attr.active_mtu
<< ", max mtu type:" << port_attr.max_mtu;

uint16_t mtu_type = port_attr.active_mtu;
if (mtu_type == IBV_MTU_256) {
LOG(INFO) << "local mtu is 256";
} else if (mtu_type == IBV_MTU_512) {
LOG(INFO) << "local mtu is 512";
} else if (mtu_type == IBV_MTU_1024) {
LOG(INFO) << "local mtu is 1024";
} else if (mtu_type == IBV_MTU_2048) {
LOG(INFO) << "local mtu is 2048";
} else if (mtu_type == IBV_MTU_4096) {
LOG(INFO) << "local mtu is 4096";
} else {
LOG(ERROR) << "unknown mtu type " << mtu_type;
return 0;
}

return mtu_type;
}

static void GlobalRdmaInitializeOrDieImpl() {
if (BAIDU_UNLIKELY(g_skip_rdma_init)) {
// Just for UT
Expand Down Expand Up @@ -549,6 +581,11 @@ static void GlobalRdmaInitializeOrDieImpl() {
g_max_sge = attr.max_sge;
}

local_mtu_type = detect_mtu(g_context, g_port_num);
if (!local_mtu_type) {
PLOG(ERROR) << "Fail to get local mtu type";
ExitWithError();
}
// Initialize RDMA memory pool (block_pool)
if (!InitBlockPool(RdmaRegisterMemory)) {
PLOG(ERROR) << "Fail to initialize RDMA memory pool";
Expand Down Expand Up @@ -701,6 +738,10 @@ bool SupportedByRdma(std::string protocol) {
return false;
}

// Returns the current value of the file-level local_mtu_type variable.
uint16_t GetLocalMtuType() {
    const uint16_t current_mtu_type = local_mtu_type;
    return current_mtu_type;
}

bool InitPollingModeWithTag(bthread_tag_t tag,
std::function<void(void)> callback,
std::function<void(void)> init_fn,
Expand Down
1 change: 1 addition & 0 deletions src/brpc/rdma/rdma_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ void GlobalDisableRdma();
// If the given protocol supported by RDMA
bool SupportedByRdma(std::string protocol);

uint16_t GetLocalMtuType();
} // namespace rdma
} // namespace brpc
#else
Expand Down
Loading