diff --git a/src/modules/transport/common/env_defs.h b/src/modules/transport/common/env_defs.h
index aafc312..db7b0ed 100644
--- a/src/modules/transport/common/env_defs.h
+++ b/src/modules/transport/common/env_defs.h
@@ -49,6 +49,8 @@ NVSHMEMI_ENV_DEF(DISABLE_DATA_DIRECT, bool, false, NVSHMEMI_ENV_CAT_TRANSPORT,
                  "Disable use of directNIC in IB Transport")
 NVSHMEMI_ENV_DEF(IB_DISABLE_DMABUF, bool, false, NVSHMEMI_ENV_CAT_TRANSPORT,
                  "Disable use of DMABUF in IBRC/IBDEVX/IBGDA Transports")
+NVSHMEMI_ENV_DEF(IB_DISABLE_SRQ, bool, false, NVSHMEMI_ENV_CAT_TRANSPORT,
+                "Disable use of srq in IB RC Transport")
 NVSHMEMI_ENV_DEF(IB_GID_INDEX, int, -1, NVSHMEMI_ENV_CAT_TRANSPORT, "Source GID Index for ROCE")
 NVSHMEMI_ENV_DEF(IB_TRAFFIC_CLASS, int, 0, NVSHMEMI_ENV_CAT_TRANSPORT, "Traffic calss for ROCE")
 NVSHMEMI_ENV_DEF(IB_ADDR_FAMILY, string, "AF_INET", NVSHMEMI_ENV_CAT_TRANSPORT,
diff --git a/src/modules/transport/common/transport_ib_common.h b/src/modules/transport/common/transport_ib_common.h
index 4ea66d4..71617cc 100644
--- a/src/modules/transport/common/transport_ib_common.h
+++ b/src/modules/transport/common/transport_ib_common.h
@@ -109,6 +109,7 @@ struct nvshmemt_ib_common_state {
     int selected_dev_id;
     int log_level;
     bool dmabuf_support;
+    bool srq_support;
     nvshmemt_ib_common_ep_ptr_t cst_ep;
     nvshmemt_ib_common_ep_ptr_t *ep;
     struct nvshmemi_options_s *options;
diff --git a/src/modules/transport/ibrc/ibrc.cpp b/src/modules/transport/ibrc/ibrc.cpp
index bcea10f..8a85706 100644
--- a/src/modules/transport/ibrc/ibrc.cpp
+++ b/src/modules/transport/ibrc/ibrc.cpp
@@ -47,8 +47,6 @@
 #endif
 // IWYU pragma: no_include
 
-#define IBRC_MAX_INLINE_SIZE 128
-
 // Helper functions to access qp_depth and srq_depth from state
 static inline int get_ibrc_qp_depth(nvshmemt_ib_common_state_t state) { return state->qp_depth; }
@@ -58,6 +56,8 @@ static inline int get_ibrc_srq_depth(nvshmemt_ib_common_state_t state) { return
 #define IBRC_REQUEST_QUEUE_MASK(state) (get_ibrc_qp_depth(state) - 1)
 
 #define IBRC_BUF_SIZE 64
+#define IBRC_MAX_INLINE_SIZE IBRC_BUF_SIZE
+
 #if defined(NVSHMEM_X86_64)
 #define IBRC_CACHELINE 64
 #elif defined(NVSHMEM_PPC64LE)
@@ -130,6 +130,8 @@ struct ibrc_ep {
     struct ibv_cq *send_cq;
     struct ibv_cq *recv_cq;
     struct ibrc_request *req;
+    struct ibv_mr *recv_bufs_mr;
+    ibrc_buf_t *recv_bufs;
 };
 
 typedef struct ibrc_mem_handle_info {
@@ -196,21 +198,26 @@ ibrc_mem_handle_info_t *get_mem_handle_info(nvshmem_transport_t t, void *gpu_ptr
     return (ibrc_mem_handle_info_t *)nvshmemt_mem_handle_cache_get(t, ibrc_state->cache, gpu_ptr);
 }
 
-inline int refill_srq(struct ibrc_device *device, nvshmemt_ib_common_state_t ibrc_state) {
+inline void init_recv_buf(ibrc_buf_t *buf) {
+    buf->rwr.next = NULL;
+    buf->rwr.wr_id = (uint64_t)buf;
+    buf->rwr.sg_list = &(buf->sge);
+    buf->rwr.num_sge = 1;
+
+    buf->sge.addr = (uint64_t)buf->buf;
+    buf->sge.length = IBRC_BUF_SIZE;
+}
+
+inline int refill_srq(struct ibrc_device *device, nvshmemt_ib_common_state_t ibrc_state, ibrc_buf_t *buf) {
     int status = 0;
 
+    if (buf)
+        bpool_free.push_back((void *)buf);
+
     while ((device->srq_posted < get_ibrc_srq_depth(ibrc_state)) && !bpool_free.empty()) {
         ibrc_buf_t *buf = (ibrc_buf_t *)bpool_free.back();
-        buf->rwr.next = NULL;
-        buf->rwr.wr_id = (uint64_t)buf;
-        buf->rwr.sg_list = &(buf->sge);
-        buf->rwr.num_sge = 1;
-
-        buf->sge.addr = (uint64_t)buf->buf;
-        buf->sge.length = IBRC_BUF_SIZE;
         buf->sge.lkey = device->bpool_mr->lkey;
-
         status = ibv_post_srq_recv(device->srq, &buf->rwr, &buf->bad_rwr);
         NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out,
                               "ibv_post_srq_recv failed \n");
 
@@ -222,6 +229,21 @@
     return status;
 }
 
+inline int refill_rq(struct ibrc_ep *ep, ibrc_buf_t *buf) {
+    int status = 0;
+
+    if (buf == NULL)
+        return 0;
+
+    buf->sge.lkey = ep->recv_bufs_mr->lkey;
+
+    status = ibv_post_recv(ep->qp, &buf->rwr, &buf->bad_rwr);
+    NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibv_post_recv failed \n");
+
+out:
+    return status;
+}
+
 int nvshmemt_ibrc_show_info(struct nvshmem_transport *transport, int style) {
     NVSHMEMI_ERROR_PRINT("ibrc show info not implemented");
     return 0;
@@ -285,33 +307,48 @@ static int ep_create(void **ep_ptr, int devid, nvshmem_transport_t t) {
     assert(device->send_cq != NULL);
     ep->send_cq = device->send_cq;
 
-    if (!device->srq) {
-        struct ibv_srq_init_attr srq_init_attr;
-        memset(&srq_init_attr, 0, sizeof(srq_init_attr));
-
-        srq_init_attr.attr.max_wr = get_ibrc_srq_depth(ibrc_state);
-        srq_init_attr.attr.max_sge = 1;
-
-        device->srq = ftable.create_srq(pd, &srq_init_attr);
-        NVSHMEMI_NULL_ERROR_JMP(device->srq, status, NVSHMEMX_ERROR_INTERNAL, out,
-                                "srq creation failed \n");
-
-        device->recv_cq = ftable.create_cq(context, get_ibrc_srq_depth(ibrc_state), NULL, NULL, 0);
-        NVSHMEMI_NULL_ERROR_JMP(device->recv_cq, status, NVSHMEMX_ERROR_INTERNAL, out,
-                                "cq creation failed \n");
+    if (ibrc_state->srq_support) {
+        if (!device->srq) {
+            struct ibv_srq_init_attr srq_init_attr;
+            memset(&srq_init_attr, 0, sizeof(srq_init_attr));
+
+            srq_init_attr.attr.max_wr = get_ibrc_srq_depth(ibrc_state);
+            srq_init_attr.attr.max_sge = 1;
+
+            device->srq = ftable.create_srq(pd, &srq_init_attr);
+            NVSHMEMI_NULL_ERROR_JMP(device->srq, status, NVSHMEMX_ERROR_INTERNAL, out,
+                                    "srq creation failed \n");
+
+            device->recv_cq = ftable.create_cq(context, get_ibrc_srq_depth(ibrc_state), NULL, NULL, 0);
+            NVSHMEMI_NULL_ERROR_JMP(device->recv_cq, status, NVSHMEMX_ERROR_INTERNAL, out,
+                                    "cq creation failed \n");
+        }
+    } else {
+        if (!device->recv_cq) {
+            device->recv_cq = ftable.create_cq(context, device->common_device.device_attr.max_cqe, NULL, NULL, 0);
+            NVSHMEMI_NULL_ERROR_JMP(device->recv_cq, status, NVSHMEMX_ERROR_INTERNAL, out,
+                                    "recv cq creation failed \n");
+        }
     }
     assert(device->recv_cq != NULL);
     ep->recv_cq = device->recv_cq;
 
     memset(&init_attr, 0, sizeof(struct ibv_qp_init_attr));
-    init_attr.srq = device->srq;
+    if (ibrc_state->srq_support) {
+        init_attr.srq = device->srq;
+        init_attr.cap.max_recv_wr = 0;
+        init_attr.cap.max_recv_sge = 0;
+    } else {
+        init_attr.srq = NULL;
+        init_attr.cap.max_recv_wr = get_ibrc_qp_depth(ibrc_state);
+        init_attr.cap.max_recv_sge = 1;
+    }
+
     init_attr.send_cq = ep->send_cq;
     init_attr.recv_cq = ep->recv_cq;
     init_attr.qp_type = IBV_QPT_RC;
     init_attr.cap.max_send_wr = get_ibrc_qp_depth(ibrc_state);
-    init_attr.cap.max_recv_wr = 0;
     init_attr.cap.max_send_sge = 1;
-    init_attr.cap.max_recv_sge = 0;
     init_attr.cap.max_inline_data = IBRC_MAX_INLINE_SIZE;
 
     ep->qp = ftable.create_qp(pd, &init_attr);
@@ -328,6 +365,16 @@
     status = ftable.modify_qp(ep->qp, &attr, flags);
     NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibv_modify_qp failed \n");
 
+    if (!ibrc_state->srq_support) {
+        ep->recv_bufs = (ibrc_buf_t *)malloc(sizeof(ibrc_buf_t) * get_ibrc_qp_depth(ibrc_state));
+        ep->recv_bufs_mr = ftable.reg_mr(
+            pd, ep->recv_bufs, get_ibrc_qp_depth(ibrc_state) * sizeof(ibrc_buf_t),
+            IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ);
+    } else {
+        ep->recv_bufs = NULL;
+        ep->recv_bufs_mr = NULL;
+    }
+
     ep->req = (struct ibrc_request *)malloc(sizeof(struct ibrc_request) *
                                             get_ibrc_qp_depth(ibrc_state));
     NVSHMEMI_NULL_ERROR_JMP(ep->req, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out,
@@ -414,17 +461,25 @@ static int ep_connect(struct ibrc_ep *ep, struct nvshmemt_ib_common_ep_handle *e
     NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibv_modify_qp failed \n");
 
     // register and post receive buffer pool
-    if (!device->bpool_mr) {
-        device->bpool_mr = ftable.reg_mr(
-            device->common_device.pd, bpool, bpool_size * sizeof(ibrc_buf_t),
-            IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ);
-        NVSHMEMI_NULL_ERROR_JMP(device->bpool_mr, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out,
-                                "mem registration failed \n");
-
-        assert(device->srq != NULL);
-
-        status = refill_srq(device, ibrc_state);
-        NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "refill_srq failed \n");
+    if (ibrc_state->srq_support) {
+        if (!device->bpool_mr) {
+            device->bpool_mr = ftable.reg_mr(
+                device->common_device.pd, bpool, bpool_size * sizeof(ibrc_buf_t),
+                IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ);
+            NVSHMEMI_NULL_ERROR_JMP(device->bpool_mr, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out,
+                                    "mem registration failed \n");
+
+            assert(device->srq != NULL);
+
+            status = refill_srq(device, ibrc_state, NULL);
+            NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "refill_srq failed \n");
+        }
+    } else {
+        for (int i = 0; i < get_ibrc_qp_depth(ibrc_state); i++) {
+            init_recv_buf(ep->recv_bufs + i);
+            status = refill_rq(ep, ep->recv_bufs + i);
+            NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "refill_rq failed \n");
+        }
     }
 
     connected_qp_count++;
@@ -652,15 +707,35 @@ int nvshmemt_ibrc_finalize(nvshmem_transport_t transport) {
     }
     if (state->ep) {
         for (int i = 0; i < state->ep_count; i++) {
-            status = ftable.destroy_qp(((struct ibrc_ep *)state->ep[i])->qp);
+            struct ibrc_ep *ep = (struct ibrc_ep *)state->ep[i];
+            status = ftable.destroy_qp(ep->qp);
             NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibv_destroy_qp failed \n");
+            if (ep->recv_bufs_mr) {
+                status = ftable.dereg_mr(ep->recv_bufs_mr);
+                NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out,
+                                      "ibv_dereg_mr failed \n");
+            }
+            if (ep->recv_bufs)
+                free(ep->recv_bufs);
+            if (ep->req)
+                free(ep->req);
         }
         free(state->ep);
     }
 
     if (state->cst_ep) {
-        status = ftable.destroy_qp(((struct ibrc_ep *)state->cst_ep)->qp);
+        struct ibrc_ep *ep = (struct ibrc_ep *)state->cst_ep;
+        status = ftable.destroy_qp(ep->qp);
         NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibv_destroy_qp failed \n");
+        if (ep->recv_bufs_mr) {
+            status = ftable.dereg_mr(ep->recv_bufs_mr);
+            NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out,
+                                  "ibv_dereg_mr failed \n");
+        }
+        if (ep->recv_bufs)
+            free(ep->recv_bufs);
+        if (ep->req)
+            free(ep->req);
         free(state->cst_ep);
         state->cst_ep = NULL;
     }
@@ -954,27 +1029,32 @@ int poll_recv(nvshmemt_ib_common_state_t ibrc_state) {
             if (wc.wc_flags & IBV_WC_WITH_IMM) {
                 atomics_acked++;
                 TRACE(ibrc_state->log_level, "[%d] atomic acked : %lu \n", getpid(),
                       atomics_acked);
-                bpool_free.push_back((void *)buf);
             } else {
                 struct ibrc_atomic_op *op = (struct ibrc_atomic_op *)buf->buf;
                 if (op->op == NVSHMEMI_AMO_ACK) {
                     atomics_acked++;
                     TRACE(ibrc_state->log_level, "[%d] atomic acked : %lu \n", getpid(),
                           atomics_acked);
-                    bpool_free.push_back((void *)buf);
                 } else {
                     buf->qp_num = wc.qp_num;
                     atomics_received++;
                     TRACE(ibrc_state->log_level, "[%d] atomic received, enqueued : %lu \n",
                           getpid(), atomics_received);
                     bqueue_toprocess.push_back((void *)buf);
+                    buf = NULL;
                 }
             }
-            device->srq_posted--;
-        }
-        status = refill_srq(device, ibrc_state);
-        NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "refill_sqr failed \n");
+            if (ibrc_state->srq_support) {
+                device->srq_posted--;
+                status = refill_srq(device, ibrc_state, buf);
+                NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "refill_srq failed \n");
+            } else {
+                struct ibrc_ep *ep = (struct ibrc_ep *)qp_map.find((unsigned int)wc.qp_num)->second;
+                status = refill_rq(ep, buf);
+                NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "refill_rq failed \n");
+            }
+        }
     }
 
 out:
@@ -987,6 +1067,7 @@ int process_recv(nvshmem_transport_t t, nvshmemt_ib_common_state_t ibrc_state) {
     if (!bqueue_toprocess.empty()) {
         ibrc_buf_t *buf = (ibrc_buf_t *)bqueue_toprocess.front();
         struct ibrc_ep *ep = (struct ibrc_ep *)qp_map.find((unsigned int)buf->qp_num)->second;
+        struct ibrc_device *device = ((struct ibrc_device *)ibrc_state->devices + ep->devid);
        struct ibrc_atomic_op *op = (struct ibrc_atomic_op *)buf->buf;
         ibrc_mem_handle_info_t *mem_handle_info = get_mem_handle_info(t, (void *)op->addr);
         void *ptr = (void *)((uintptr_t)mem_handle_info->cpu_ptr +
@@ -1011,6 +1092,12 @@ int process_recv(nvshmem_transport_t t, nvshmemt_ib_common_state_t ibrc_state) {
                  atomics_processed);
 
         bqueue_toprocess.pop_front();
+        if (ibrc_state->srq_support) {
+            status = refill_srq(device, ibrc_state, buf);
+            NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "refill_srq failed \n");
+        } else {
+            status = refill_rq(ep, buf);
+            NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "refill_rq failed \n");
+        }
 
-        bpool_free.push_back((void *)buf);
     }
@@ -1584,6 +1671,11 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table,
     if (ibrc_state->options->DISABLE_IB_NATIVE_ATOMICS) {
         use_ib_native_atomics = 0;
     }
+
+    ibrc_state->srq_support = true;
+    if (ibrc_state->options->IB_DISABLE_SRQ)
+        ibrc_state->srq_support = false;
+
     ibrc_state->qp_depth = ibrc_state->options->QP_DEPTH;
     ibrc_state->srq_depth = ibrc_state->options->SRQ_DEPTH;
     // qp_depth and srq_depth are now accessed directly from ibrc_state
@@ -1629,6 +1721,9 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table,
             ftable.query_device(device->common_device.context, &device->common_device.device_attr);
         NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibv_query_device failed \n");
 
+        if ((device->common_device.device_attr).max_srq <= 0)
+            ibrc_state->srq_support = false;
+
         NVSHMEMT_IBRC_MAX_RD_ATOMIC = (device->common_device.device_attr).max_qp_rd_atom;
         INFO(ibrc_state->log_level,
              "Enumerated IB devices in the system - device id=%d (of %d), name=%s, num_ports=%d", i,
@@ -1781,12 +1876,18 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table,
     }
 
     // allocate buffer pool
-    bpool_size = ibrc_state->srq_depth;
-    nvshmemi_ib_malloc((void **)&bpool, bpool_size * sizeof(ibrc_buf_t), ibrc_state->log_level);
-    NVSHMEMI_NULL_ERROR_JMP(bpool, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out,
-                            "buf poll allocation failed \n");
-    for (int i = 0; i < bpool_size; i++) {
-        bpool_free.push_back((void *)(bpool + i));
+    if (ibrc_state->srq_support) {
+        bpool_size = ibrc_state->srq_depth;
+        nvshmemi_ib_malloc((void **)&bpool, bpool_size * sizeof(ibrc_buf_t), ibrc_state->log_level);
+        NVSHMEMI_NULL_ERROR_JMP(bpool, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out,
+                                "buf pool allocation failed \n");
+        for (int i = 0; i < bpool_size; i++) {
+            init_recv_buf(bpool + i);
+            bpool_free.push_back((void *)(bpool + i));
+        }
+    } else {
+        bpool_size = 0;
+        bpool = NULL;
     }
 
     transport->host_ops.can_reach_peer = nvshmemt_ibrc_can_reach_peer;