2 changes: 2 additions & 0 deletions src/modules/transport/common/env_defs.h
@@ -49,6 +49,8 @@ NVSHMEMI_ENV_DEF(DISABLE_DATA_DIRECT, bool, false, NVSHMEMI_ENV_CAT_TRANSPORT,
"Disable use of directNIC in IB Transport")
NVSHMEMI_ENV_DEF(IB_DISABLE_DMABUF, bool, false, NVSHMEMI_ENV_CAT_TRANSPORT,
"Disable use of DMABUF in IBRC/IBDEVX/IBGDA Transports")
NVSHMEMI_ENV_DEF(IB_DISABLE_SRQ, bool, false, NVSHMEMI_ENV_CAT_TRANSPORT,
"Disable use of srq in IB RC Transport")
NVSHMEMI_ENV_DEF(IB_GID_INDEX, int, -1, NVSHMEMI_ENV_CAT_TRANSPORT, "Source GID Index for ROCE")
NVSHMEMI_ENV_DEF(IB_TRAFFIC_CLASS, int, 0, NVSHMEMI_ENV_CAT_TRANSPORT, "Traffic calss for ROCE")
NVSHMEMI_ENV_DEF(IB_ADDR_FAMILY, string, "AF_INET", NVSHMEMI_ENV_CAT_TRANSPORT,
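For context on the hunk above: NVSHMEMI_ENV_DEF(IB_DISABLE_SRQ, ...) adds a boolean runtime knob next to the existing transport options. A minimal launcher-side sketch of how it would typically be toggled follows; the NVSHMEM_IB_DISABLE_SRQ spelling assumes the usual NVSHMEM_ prefix that env-defined options are exposed under and is not spelled out in this diff, so treat it as an assumption.

    // Illustrative usage only: request the per-QP receive-queue fallback
    // before the IBRC transport initializes.
    #include <cstdlib>
    #include <nvshmem.h>

    int main(void) {
        setenv("NVSHMEM_IB_DISABLE_SRQ", "1", 1);  // assumed env-var name (NVSHMEM_ + IB_DISABLE_SRQ)
        nvshmem_init();
        // ... application work ...
        nvshmem_finalize();
        return 0;
    }

Equivalently, a job launcher can export the variable before the application starts.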
1 change: 1 addition & 0 deletions src/modules/transport/common/transport_ib_common.h
@@ -109,6 +109,7 @@ struct nvshmemt_ib_common_state {
int selected_dev_id;
int log_level;
bool dmabuf_support;
bool srq_support;
nvshmemt_ib_common_ep_ptr_t cst_ep;
nvshmemt_ib_common_ep_ptr_t *ep;
struct nvshmemi_options_s *options;
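The srq_support flag added above selects between two receive-buffer recycling strategies in the ibrc.cpp changes below: a shared free pool drained into one device-wide SRQ, or a plain repost on each endpoint's own receive queue. A toy, dependency-free model of that difference, with all names as illustrative stand-ins rather than the transport's types:

    // Toy model only: "posting" is a counter standing in for
    // ibv_post_srq_recv / ibv_post_recv; no verbs objects are involved.
    #include <deque>
    #include <cstdio>

    struct Buf { int id; };

    // SRQ path (mirrors refill_srq): a completed buffer returns to the shared
    // free list, which is drained into the SRQ while below its depth.
    struct SharedSrqPool {
        std::deque<Buf *> free_list;
        int posted = 0;
        int depth;
        explicit SharedSrqPool(int d) : depth(d) {}
        void refill(Buf *done) {
            if (done) free_list.push_back(done);
            while (posted < depth && !free_list.empty()) {
                free_list.pop_back();  // "post" one buffer to the shared SRQ
                posted++;
            }
        }
    };

    // Non-SRQ path (mirrors refill_rq): each endpoint simply reposts the
    // completed buffer on its own receive queue.
    struct PerEndpointRq {
        int posted = 0;
        void refill(Buf *done) {
            if (done) posted++;  // "repost" on this endpoint's RQ
        }
    };

    int main() {
        Buf a{0}, b{1}, c{2}, d{3};

        SharedSrqPool pool(/*depth=*/2);
        pool.free_list = {&a, &b, &c};  // seed the pool, like bpool_free at init
        pool.refill(nullptr);           // initial fill: posts two, one buffer stays spare
        pool.posted--;                  // a completion consumed one posted receive
        pool.refill(&c);                // recycle the completed buffer
        std::printf("srq path: posted=%d spare=%zu\n", pool.posted, pool.free_list.size());

        PerEndpointRq ep;
        ep.refill(&d);                  // completed buffer goes straight back to the RQ
        std::printf("per-ep path: posted=%d\n", ep.posted);
        return 0;
    }

The real code layers MR bookkeeping on top of this: the SRQ buffers share one device-wide registration (bpool_mr), while each endpoint's recv_bufs get their own recv_bufs_mr.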
202 changes: 152 additions & 50 deletions src/modules/transport/ibrc/ibrc.cpp
@@ -47,8 +47,6 @@
#endif
// IWYU pragma: no_include <xmmintrin.h>

#define IBRC_MAX_INLINE_SIZE 128

// Helper functions to access qp_depth and srq_depth from state
static inline int get_ibrc_qp_depth(nvshmemt_ib_common_state_t state) { return state->qp_depth; }

@@ -58,6 +56,8 @@ static inline int get_ibrc_srq_depth(nvshmemt_ib_common_state_t state) { return
#define IBRC_REQUEST_QUEUE_MASK(state) (get_ibrc_qp_depth(state) - 1)
#define IBRC_BUF_SIZE 64

#define IBRC_MAX_INLINE_SIZE IBRC_BUF_SIZE

#if defined(NVSHMEM_X86_64)
#define IBRC_CACHELINE 64
#elif defined(NVSHMEM_PPC64LE)
@@ -130,6 +130,8 @@ struct ibrc_ep {
struct ibv_cq *send_cq;
struct ibv_cq *recv_cq;
struct ibrc_request *req;
struct ibv_mr *recv_bufs_mr;
ibrc_buf_t *recv_bufs;
};

typedef struct ibrc_mem_handle_info {
@@ -196,21 +198,26 @@ ibrc_mem_handle_info_t *get_mem_handle_info(nvshmem_transport_t t, void *gpu_ptr
return (ibrc_mem_handle_info_t *)nvshmemt_mem_handle_cache_get(t, ibrc_state->cache, gpu_ptr);
}

inline int refill_srq(struct ibrc_device *device, nvshmemt_ib_common_state_t ibrc_state) {
inline void init_recv_buf(ibrc_buf_t *buf) {
buf->rwr.next = NULL;
buf->rwr.wr_id = (uint64_t)buf;
buf->rwr.sg_list = &(buf->sge);
buf->rwr.num_sge = 1;

buf->sge.addr = (uint64_t)buf->buf;
buf->sge.length = IBRC_BUF_SIZE;
}

inline int refill_srq(struct ibrc_device *device, nvshmemt_ib_common_state_t ibrc_state, ibrc_buf_t *buf) {
int status = 0;

if (buf)
bpool_free.push_back((void *)buf);

while ((device->srq_posted < get_ibrc_srq_depth(ibrc_state)) && !bpool_free.empty()) {
ibrc_buf_t *buf = (ibrc_buf_t *)bpool_free.back();

buf->rwr.next = NULL;
buf->rwr.wr_id = (uint64_t)buf;
buf->rwr.sg_list = &(buf->sge);
buf->rwr.num_sge = 1;

buf->sge.addr = (uint64_t)buf->buf;
buf->sge.length = IBRC_BUF_SIZE;
buf->sge.lkey = device->bpool_mr->lkey;

status = ibv_post_srq_recv(device->srq, &buf->rwr, &buf->bad_rwr);
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibv_post_srq_recv failed \n");

@@ -222,6 +229,21 @@ inline int refill_srq(struct ibrc_device *device, nvshmemt_ib_common_state_t ibr
return status;
}

inline int refill_rq(struct ibrc_ep *ep, ibrc_buf_t *buf) {
int status = 0;

if (buf == NULL)
return 0;

buf->sge.lkey = ep->recv_bufs_mr->lkey;

status = ibv_post_recv(ep->qp, &buf->rwr, &buf->bad_rwr);
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibv_post_recv failed \n");

out:
return status;
}

int nvshmemt_ibrc_show_info(struct nvshmem_transport *transport, int style) {
NVSHMEMI_ERROR_PRINT("ibrc show info not implemented");
return 0;
@@ -285,33 +307,48 @@ static int ep_create(void **ep_ptr, int devid, nvshmem_transport_t t) {
assert(device->send_cq != NULL);
ep->send_cq = device->send_cq;

if (!device->srq) {
struct ibv_srq_init_attr srq_init_attr;
memset(&srq_init_attr, 0, sizeof(srq_init_attr));
if (ibrc_state->srq_support) {
if (!device->srq) {
struct ibv_srq_init_attr srq_init_attr;
memset(&srq_init_attr, 0, sizeof(srq_init_attr));

srq_init_attr.attr.max_wr = get_ibrc_srq_depth(ibrc_state);
srq_init_attr.attr.max_sge = 1;
srq_init_attr.attr.max_wr = get_ibrc_srq_depth(ibrc_state);
srq_init_attr.attr.max_sge = 1;

device->srq = ftable.create_srq(pd, &srq_init_attr);
NVSHMEMI_NULL_ERROR_JMP(device->srq, status, NVSHMEMX_ERROR_INTERNAL, out,
"srq creation failed \n");
device->srq = ftable.create_srq(pd, &srq_init_attr);
NVSHMEMI_NULL_ERROR_JMP(device->srq, status, NVSHMEMX_ERROR_INTERNAL, out,
"srq creation failed \n");

device->recv_cq = ftable.create_cq(context, get_ibrc_srq_depth(ibrc_state), NULL, NULL, 0);
NVSHMEMI_NULL_ERROR_JMP(device->recv_cq, status, NVSHMEMX_ERROR_INTERNAL, out,
"cq creation failed \n");
device->recv_cq = ftable.create_cq(context, get_ibrc_srq_depth(ibrc_state), NULL, NULL, 0);
NVSHMEMI_NULL_ERROR_JMP(device->recv_cq, status, NVSHMEMX_ERROR_INTERNAL, out,
"cq creation failed \n");
}
} else {
if (!device->recv_cq) {
device->recv_cq = ftable.create_cq(context, device->common_device.device_attr.max_cqe, NULL, NULL, 0);
NVSHMEMI_NULL_ERROR_JMP(device->recv_cq, status, NVSHMEMX_ERROR_INTERNAL, out,
"recv cq creation failed \n");
}
}
assert(device->recv_cq != NULL);
ep->recv_cq = device->recv_cq;

memset(&init_attr, 0, sizeof(struct ibv_qp_init_attr));
init_attr.srq = device->srq;
if (ibrc_state->srq_support) {
init_attr.srq = device->srq;
init_attr.cap.max_recv_wr = 0;
init_attr.cap.max_recv_sge = 0;
} else {
init_attr.srq = NULL;
init_attr.cap.max_recv_wr = get_ibrc_qp_depth(ibrc_state);
init_attr.cap.max_recv_sge = 1;
}

init_attr.send_cq = ep->send_cq;
init_attr.recv_cq = ep->recv_cq;
init_attr.qp_type = IBV_QPT_RC;
init_attr.cap.max_send_wr = get_ibrc_qp_depth(ibrc_state);
init_attr.cap.max_recv_wr = 0;
init_attr.cap.max_send_sge = 1;
init_attr.cap.max_recv_sge = 0;
init_attr.cap.max_inline_data = IBRC_MAX_INLINE_SIZE;

ep->qp = ftable.create_qp(pd, &init_attr);
@@ -328,6 +365,16 @@ static int ep_create(void **ep_ptr, int devid, nvshmem_transport_t t) {
status = ftable.modify_qp(ep->qp, &attr, flags);
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibv_modify_qp failed \n");

if (!ibrc_state->srq_support) {
ep->recv_bufs = (ibrc_buf_t *)malloc(sizeof(ibrc_buf_t) * get_ibrc_qp_depth(ibrc_state));
ep->recv_bufs_mr = ftable.reg_mr(
pd, ep->recv_bufs, get_ibrc_qp_depth(ibrc_state) * sizeof(ibrc_buf_t),
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ);
} else {
ep->recv_bufs = NULL;
ep->recv_bufs_mr = NULL;
}

ep->req =
(struct ibrc_request *)malloc(sizeof(struct ibrc_request) * get_ibrc_qp_depth(ibrc_state));
NVSHMEMI_NULL_ERROR_JMP(ep->req, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out,
@@ -414,17 +461,25 @@ static int ep_connect(struct ibrc_ep *ep, struct nvshmemt_ib_common_ep_handle *e
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibv_modify_qp failed \n");

// register and post receive buffer pool
if (!device->bpool_mr) {
device->bpool_mr = ftable.reg_mr(
device->common_device.pd, bpool, bpool_size * sizeof(ibrc_buf_t),
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ);
NVSHMEMI_NULL_ERROR_JMP(device->bpool_mr, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out,
"mem registration failed \n");

assert(device->srq != NULL);

status = refill_srq(device, ibrc_state);
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "refill_srq failed \n");
if (ibrc_state->srq_support) {
if (!device->bpool_mr) {
device->bpool_mr = ftable.reg_mr(
device->common_device.pd, bpool, bpool_size * sizeof(ibrc_buf_t),
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ);
NVSHMEMI_NULL_ERROR_JMP(device->bpool_mr, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out,
"mem registration failed \n");

assert(device->srq != NULL);

status = refill_srq(device, ibrc_state, NULL);
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "refill_srq failed \n");
}
} else {
for (int i = 0; i < get_ibrc_qp_depth(ibrc_state); i++) {
init_recv_buf(ep->recv_bufs + i);
status = refill_rq(ep, ep->recv_bufs + i);
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "refill_rq failed \n");
}
}

connected_qp_count++;
@@ -652,15 +707,35 @@ int nvshmemt_ibrc_finalize(nvshmem_transport_t transport) {
}
if (state->ep) {
for (int i = 0; i < state->ep_count; i++) {
status = ftable.destroy_qp(((struct ibrc_ep *)state->ep[i])->qp);
struct ibrc_ep *ep = (struct ibrc_ep *)state->ep[i];
status = ftable.destroy_qp(ep->qp);
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibv_destroy_qp failed \n");
if (ep->recv_bufs_mr) {
status = ftable.dereg_mr(ep->recv_bufs_mr);
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out,
"ibv_dereg_mr failed \n");
}
if (ep->recv_bufs)
free(ep->recv_bufs);
if (ep->req)
free(ep->req);
}
free(state->ep);
}

if (state->cst_ep) {
status = ftable.destroy_qp(((struct ibrc_ep *)state->cst_ep)->qp);
struct ibrc_ep *ep = (struct ibrc_ep *)state->cst_ep;
status = ftable.destroy_qp(ep->qp);
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibv_destroy_qp failed \n");
if (ep->recv_bufs_mr) {
status = ftable.dereg_mr(ep->recv_bufs_mr);
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out,
"ibv_dereg_mr failed \n");
}
if (ep->recv_bufs)
free(ep->recv_bufs);
if (ep->req)
free(ep->req);
free(state->cst_ep);
state->cst_ep = NULL;
}
@@ -954,27 +1029,32 @@ int poll_recv(nvshmemt_ib_common_state_t ibrc_state) {
if (wc.wc_flags & IBV_WC_WITH_IMM) {
atomics_acked++;
TRACE(ibrc_state->log_level, "[%d] atomic acked : %lu \n", getpid(), atomics_acked);
bpool_free.push_back((void *)buf);
} else {
struct ibrc_atomic_op *op = (struct ibrc_atomic_op *)buf->buf;
if (op->op == NVSHMEMI_AMO_ACK) {
atomics_acked++;
TRACE(ibrc_state->log_level, "[%d] atomic acked : %lu \n", getpid(),
atomics_acked);
bpool_free.push_back((void *)buf);
} else {
buf->qp_num = wc.qp_num;
atomics_received++;
TRACE(ibrc_state->log_level, "[%d] atomic received, enqueued : %lu \n",
getpid(), atomics_received);
bqueue_toprocess.push_back((void *)buf);
buf = NULL;
}
}
device->srq_posted--;
}

status = refill_srq(device, ibrc_state);
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "refill_sqr failed \n");
if (ibrc_state->srq_support) {
device->srq_posted--;
status = refill_srq(device, ibrc_state, buf);
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "refill_sqr failed \n");
} else {
struct ibrc_ep *ep = (struct ibrc_ep *)qp_map.find((unsigned int)wc.qp_num)->second;
status = refill_rq(ep, buf);
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "refill_rq failed \n");
}
}
}

out:
@@ -987,6 +1067,7 @@ int process_recv(nvshmem_transport_t t, nvshmemt_ib_common_state_t ibrc_state) {
if (!bqueue_toprocess.empty()) {
ibrc_buf_t *buf = (ibrc_buf_t *)bqueue_toprocess.front();
struct ibrc_ep *ep = (struct ibrc_ep *)qp_map.find((unsigned int)buf->qp_num)->second;
struct ibrc_device *device = ((struct ibrc_device *)ibrc_state->devices + ep->devid);
struct ibrc_atomic_op *op = (struct ibrc_atomic_op *)buf->buf;
ibrc_mem_handle_info_t *mem_handle_info = get_mem_handle_info(t, (void *)op->addr);
void *ptr = (void *)((uintptr_t)mem_handle_info->cpu_ptr +
@@ -1011,6 +1092,13 @@ int process_recv(nvshmem_transport_t t, nvshmemt_ib_common_state_t ibrc_state) {
atomics_processed);

bqueue_toprocess.pop_front();
if (ibrc_state->srq_support) {
status = refill_srq(device, ibrc_state, buf);
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "refill_sqr failed \n");
} else {
status = refill_rq(ep, buf);
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "refill_rq failed \n");
}
bpool_free.push_back((void *)buf);
}

@@ -1584,6 +1672,11 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table,
if (ibrc_state->options->DISABLE_IB_NATIVE_ATOMICS) {
use_ib_native_atomics = 0;
}

ibrc_state->srq_support = true;
if (ibrc_state->options->IB_DISABLE_SRQ)
ibrc_state->srq_support = false;

ibrc_state->qp_depth = ibrc_state->options->QP_DEPTH;
ibrc_state->srq_depth = ibrc_state->options->SRQ_DEPTH;
// qp_depth and srq_depth are now accessed directly from ibrc_state
@@ -1629,6 +1722,9 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table,
ftable.query_device(device->common_device.context, &device->common_device.device_attr);
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibv_query_device failed \n");

if ((device->common_device.device_attr).max_srq <= 0)
ibrc_state->srq_support = false;

NVSHMEMT_IBRC_MAX_RD_ATOMIC = (device->common_device.device_attr).max_qp_rd_atom;
INFO(ibrc_state->log_level,
"Enumerated IB devices in the system - device id=%d (of %d), name=%s, num_ports=%d", i,
@@ -1781,12 +1877,18 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table,
}

// allocate buffer pool
bpool_size = ibrc_state->srq_depth;
nvshmemi_ib_malloc((void **)&bpool, bpool_size * sizeof(ibrc_buf_t), ibrc_state->log_level);
NVSHMEMI_NULL_ERROR_JMP(bpool, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out,
"buf poll allocation failed \n");
for (int i = 0; i < bpool_size; i++) {
bpool_free.push_back((void *)(bpool + i));
if (ibrc_state->srq_support) {
bpool_size = ibrc_state->srq_depth;
nvshmemi_ib_malloc((void **)&bpool, bpool_size * sizeof(ibrc_buf_t), ibrc_state->log_level);
NVSHMEMI_NULL_ERROR_JMP(bpool, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out,
"buf poll allocation failed \n");
for (int i = 0; i < bpool_size; i++) {
init_recv_buf(bpool + i);
bpool_free.push_back((void *)(bpool + i));
}
} else {
bpool_size = 0;
bpool = NULL;
}

transport->host_ops.can_reach_peer = nvshmemt_ibrc_can_reach_peer;
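One of the ibrc.cpp changes above also falls back automatically when the HCA reports no SRQ capability (max_srq <= 0 from ibv_query_device). That check can be reproduced standalone with plain libibverbs; the probe below is a sketch for inspection using rdma-core directly, not the transport's ftable wrappers.

    // Minimal capability probe: list IB devices and whether the SRQ path is usable.
    #include <infiniband/verbs.h>
    #include <cstdio>

    int main() {
        int num = 0;
        struct ibv_device **list = ibv_get_device_list(&num);
        if (!list) { std::perror("ibv_get_device_list"); return 1; }
        for (int i = 0; i < num; i++) {
            struct ibv_context *ctx = ibv_open_device(list[i]);
            if (!ctx) continue;
            struct ibv_device_attr attr;
            if (ibv_query_device(ctx, &attr) == 0) {
                std::printf("%s: max_srq=%d -> %s\n", ibv_get_device_name(list[i]),
                            attr.max_srq,
                            attr.max_srq > 0 ? "SRQ usable" : "falls back to per-QP RQs");
            }
            ibv_close_device(ctx);
        }
        ibv_free_device_list(list);
        return 0;
    }

Compile with g++ probe.cpp -libverbs; on a device where max_srq is 0, the transport now behaves as if the new IB_DISABLE_SRQ option were set.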