Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 11 additions & 39 deletions src/host/proxy/proxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -686,51 +686,23 @@ int process_channel_amo(proxy_state_t *state, proxy_channel_t *ch, int *is_proce
}

void enforce_cst(proxy_state_t *proxy_state) {
#if defined(NVSHMEM_X86_64)
nvshmemi_state_t *state = proxy_state->nvshmemi_state;
#endif

int status = 0;

if (nvshmemi_options.BYPASS_FLUSH) return;

if (proxy_state->is_consistency_api_supported) {
if (CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TO_OWNER > proxy_state->gdr_device_native_ordering &&
CUPFN(nvshmemi_cuda_syms, cuFlushGPUDirectRDMAWrites)) {
status =
CUPFN(nvshmemi_cuda_syms,
cuFlushGPUDirectRDMAWrites(CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TARGET_CURRENT_CTX,
CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TO_OWNER));
/** We would want to use cudaFlushGPUDirectRDMAWritesToAllDevices when we enable
consistent access of data on any GPU (and not just self GPU) with
wait_until, quiet, barrier, etc. **/
if (status != CUDA_SUCCESS) {
NVSHMEMI_ERROR_EXIT("cuFlushGPUDirectRDMAWrites() failed in the proxy thread \n");
}
}
return;
}
#if defined(NVSHMEM_PPC64LE)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine to remove since Power support is deprecated/removed.

status = cudaEventRecord(proxy_state->cuev, proxy_state->stream);
if (unlikely(status != CUDA_SUCCESS)) {
NVSHMEMI_ERROR_EXIT("cuEventRecord() failed in the proxy thread \n");
}
#elif defined(NVSHMEM_X86_64)
for (int i = 0; i < state->num_initialized_transports; i++) {
if (!((state->transport_bitmap) & (1 << i))) continue;
struct nvshmem_transport *tcurr = state->transports[i];
if (!tcurr->host_ops.enforce_cst) continue;

// assuming the transport is connected - IB RC
if (tcurr->attr & NVSHMEM_TRANSPORT_ATTR_CONNECTED) {
status = tcurr->host_ops.enforce_cst(tcurr);
if (status) {
NVSHMEMI_ERROR_PRINT("aborting due to error in progress_cst \n");
exit(-1);
}
if (CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TO_OWNER > proxy_state->gdr_device_native_ordering &&
CUPFN(nvshmemi_cuda_syms, cuFlushGPUDirectRDMAWrites)) {
status =
CUPFN(nvshmemi_cuda_syms,
cuFlushGPUDirectRDMAWrites(CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TARGET_CURRENT_CTX,
CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TO_OWNER));
/** We would want to use cudaFlushGPUDirectRDMAWritesToAllDevices when we enable
consistent access of data on any GPU (and not just self GPU) with
wait_until, quiet, barrier, etc. **/
if (status != CUDA_SUCCESS) {
NVSHMEMI_ERROR_EXIT("cuFlushGPUDirectRDMAWrites() failed in the proxy thread \n");
}
}
#endif
}

inline void quiet_ack_channels(proxy_state_t *proxy_state) {
Expand Down
1 change: 0 additions & 1 deletion src/include/internal/host_transport/transport.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ struct nvshmem_transport_host_ops {
fence_handle fence;
quiet_handle quiet;
put_signal_handle put_signal;
int (*enforce_cst)(struct nvshmem_transport *transport);
Copy link
Collaborator

@seth-howell seth-howell Oct 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The more I look at this, the more I feel like enforce_cst_at_target should be removed from the function table and turned into a flag (somewhere in the nvshmem_transport struct starting on line 187). We are breaking the host-lib to transport API this release anyway, so it would be nice to just rip the band-aid off all at once.
That method has the added bonus of reducing confusion for anyone implementing a custom transport plugin in the future. It will also allow you to remove the enforce_cst code from ibrc.cpp to match the rest of the transports.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do agree that the flag makes more sense than the function table. It will be a little bit of time before I can go back and make this change.

int (*enforce_cst_at_target)(struct nvshmem_transport *transport);
int (*add_device_remote_mem_handles)(struct nvshmem_transport *transport, int transport_stride,
nvshmem_mem_handle_t *mem_handles, uint64_t heap_offset,
Expand Down
41 changes: 0 additions & 41 deletions src/modules/transport/ibdevx/ibdevx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1440,46 +1440,6 @@ int nvshmemt_ibdevx_amo(struct nvshmem_transport *tcurr, int pe, void *curetptr,
return status;
}

int nvshmemt_ibdevx_enforce_cst_at_target(struct nvshmem_transport *tcurr) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one we need to keep. It never goes through the proxy and the other code path.

Copy link
Contributor Author

@a-szegel a-szegel Oct 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need to keep this? It was only called at: transport->host_ops.enforce_cst = nvshmemt_ibdevx_enforce_cst_at_target;... i.e., the enforce_cst proxy API that we are getting rid of. I think one of two errors occurred:

  1. The function was named incorrectly
  2. We forgot to assign transport->host_ops.enforce_cst_at_target

nvshmemt_ib_common_state_t ibdevx_state = (nvshmemt_ib_common_state_t)tcurr->state;
struct ibdevx_ep *ep = (struct ibdevx_ep *)ibdevx_state->cst_ep;
struct ibdevx_rw_wqe *wqe;

int status = 0;

uintptr_t wqe_bb_idx_64 = ep->wqe_bb_idx;
uint32_t wqe_bb_idx_32 = ep->wqe_bb_idx;
size_t wqe_size;

wqe = (struct ibdevx_rw_wqe *)((char *)ep->wq_buf +
((wqe_bb_idx_64 % get_ibdevx_qp_depth(ibdevx_state))
<< NVSHMEMT_IBDEVX_WQE_BB_SHIFT));
wqe_size = sizeof(struct ibdevx_rw_wqe);
memset(wqe, 0, sizeof(struct ibdevx_rw_wqe));

wqe->ctrl.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE;
wqe->ctrl.qpn_ds =
htobe32((uint32_t)(wqe_size / NVSHMEMT_IBDEVX_MLX5_SEND_WQE_DS) | ep->qpid << 8);
wqe->ctrl.opmod_idx_opcode = htobe32(MLX5_OPCODE_RDMA_READ | (wqe_bb_idx_32 << 8));

wqe->raddr.raddr = htobe64((uintptr_t)local_dummy_mr.mr->addr);
wqe->raddr.rkey = htobe32(local_dummy_mr.rkey);

wqe->data.data_seg.byte_count = htobe32((uint32_t)4);
wqe->data.data_seg.lkey = htobe32(local_dummy_mr.lkey);
wqe->data.data_seg.addr = htobe64((uintptr_t)local_dummy_mr.mr->addr);

assert(wqe_size <= MLX5_SEND_WQE_BB);
ep->wqe_bb_idx++;
nvshmemt_ibdevx_post_send(ep, (void *)wqe, 1);

status = nvshmemt_ib_common_check_poll_avail(tcurr, ep, NVSHMEMT_IB_COMMON_WAIT_ALL);
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "check_poll failed \n");

out:
return status;
}

// Using common fence and quiet functions from transport_ib_common

int nvshmemt_ibdevx_ep_create(struct ibdevx_ep **ep, int devid, nvshmem_transport_t t,
Expand Down Expand Up @@ -1932,7 +1892,6 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table,
transport->host_ops.finalize = nvshmemt_ibdevx_finalize;
transport->host_ops.show_info = nvshmemt_ibdevx_show_info;
transport->host_ops.progress = nvshmemt_ibdevx_progress;
transport->host_ops.enforce_cst = nvshmemt_ibdevx_enforce_cst_at_target;
transport->host_ops.put_signal = nvshmemt_put_signal;

transport->attr = NVSHMEM_TRANSPORT_ATTR_CONNECTED;
Expand Down
1 change: 0 additions & 1 deletion src/modules/transport/ibgda/ibgda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4915,7 +4915,6 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table,
transport->host_ops.amo = NULL;
transport->host_ops.fence = NULL;
transport->host_ops.quiet = NULL;
transport->host_ops.enforce_cst = NULL;
transport->host_ops.add_device_remote_mem_handles =
nvshmemt_ibgda_add_device_remote_mem_handles;
transport->host_ops.put_signal = NULL;
Expand Down
1 change: 0 additions & 1 deletion src/modules/transport/ibrc/ibrc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1800,7 +1800,6 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table,
transport->host_ops.progress = nvshmemt_ibrc_progress;
transport->host_ops.put_signal = nvshmemt_put_signal;

transport->host_ops.enforce_cst = nvshmemt_ibrc_enforce_cst_at_target;
#if !defined(NVSHMEM_PPC64LE) && !defined(NVSHMEM_AARCH64)
if (!use_gdrcopy)
#endif
Expand Down
79 changes: 0 additions & 79 deletions src/modules/transport/libfabric/libfabric.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1071,71 +1071,6 @@ int nvshmemt_put_signal_unordered(struct nvshmem_transport *tcurr, int pe, rma_v
return status;
}

/* Enforce consistency (CST) on this PE: make previously delivered RDMA writes
 * observable before subsequent operations, by performing a read that targets
 * this PE's own memory.
 *
 * tcurr:  the libfabric transport instance (state cast to
 *         nvshmemt_libfabric_state_t).
 * return: the fi_readmsg status from the last attempt (0 on success).
 */
static int nvshmemt_libfabric_enforce_cst(struct nvshmem_transport *tcurr) {
    nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state;
    uint64_t num_retries = 0;
    int status;
    int target_ep;
    int mype = tcurr->my_pe;

#ifdef NVSHMEM_USE_GDRCOPY
    /* When GDRCopy is in use (and the provider is not Slingshot), first do a
     * CPU-side read of one int through the GDRCopy BAR mapping of the first
     * cached memory handle. Note execution then still falls through to the
     * RMA read below — this is an additional flush, not an early-out. */
    if (use_gdrcopy) {
        if (libfabric_state->provider != NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) {
            int temp;
            nvshmemt_libfabric_memhandle_info_t *mem_handle_info;

            mem_handle_info =
                (nvshmemt_libfabric_memhandle_info_t *)nvshmemt_mem_handle_cache_get_by_idx(
                    libfabric_state->cache, 0);
            if (!mem_handle_info) {
                goto skip; /* no cached handle yet; proceed with the RMA read only */
            }
            gdrcopy_ftable.copy_from_mapping(mem_handle_info->mh, &temp, mem_handle_info->cpu_ptr,
                                             sizeof(int));
        }
    }

skip:
#endif

    /* Post an 8-byte RDMA read against this PE's own proxy-endpoint MR
     * (offset 0 of local_mem_ptr). Retried via try_again() until accepted. */
    target_ep = mype * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + NVSHMEMT_LIBFABRIC_PROXY_EP_IDX;
    do {
        struct fi_msg_rma msg;
        struct iovec l_iov;
        struct fi_rma_iov r_iov;
        void *desc = libfabric_state->local_mr_desc[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX];
        uint64_t flags = 0;

        memset(&msg, 0, sizeof(struct fi_msg_rma));
        memset(&l_iov, 0, sizeof(struct iovec));
        memset(&r_iov, 0, sizeof(struct fi_rma_iov));

        l_iov.iov_base = libfabric_state->local_mem_ptr;
        l_iov.iov_len = 8;

        r_iov.addr = 0; // Zero offset
        r_iov.len = 8;
        r_iov.key = libfabric_state->local_mr_key[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX];

        msg.msg_iov = &l_iov;
        msg.desc = &desc;
        msg.iov_count = 1;
        msg.rma_iov = &r_iov;
        msg.rma_iov_count = 1;
        msg.context = NULL;
        msg.data = 0;

        /* Fence the read behind prior operations when the provider supports it. */
        if (libfabric_state->prov_info->caps & FI_FENCE) flags |= FI_FENCE;

        status =
            fi_readmsg(libfabric_state->eps[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX].endpoint, &msg, flags);
        /* This try_again makes an assumption that enforce_cst is only for proxy threaded ops */
    } while (try_again(tcurr, &status, &num_retries, 1));

    /* NOTE(review): the read is posted on eps[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX]
     * but the completion is accounted on eps[target_ep] (which differs for
     * mype != 0) — confirm this index mismatch is intentional. */
    libfabric_state->eps[target_ep].submitted_ops++;
    return status;
}

static int nvshmemt_libfabric_release_mem_handle(nvshmem_mem_handle_t *mem_handle,
nvshmem_transport_t t) {
nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)t->state;
Expand Down Expand Up @@ -1177,9 +1112,6 @@ static int nvshmemt_libfabric_release_mem_handle(nvshmem_mem_handle_t *mem_handl
max_reg = 1;

for (int i = 0; i < max_reg; i++) {
if (libfabric_state->local_mr[i] == fabric_handle->hdls[i].mr)
libfabric_state->local_mr[i] = NULL;

int status = fi_close(&fabric_handle->hdls[i].mr->fid);
if (status) {
NVSHMEMI_WARN_PRINT("Error releasing mem handle idx %d (%d): %s\n", i, status,
Expand Down Expand Up @@ -1359,15 +1291,6 @@ static int nvshmemt_libfabric_get_mem_handle(nvshmem_mem_handle_t *mem_handle, v
} while (curr_ptr < (char *)buf + length);
}

if (libfabric_state->local_mr[0] == NULL && !local_only) {
for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) {
libfabric_state->local_mr[i] = fabric_handle->hdls[i].mr;
libfabric_state->local_mr_key[i] = fabric_handle->hdls[i].key;
libfabric_state->local_mr_desc[i] = fabric_handle->hdls[i].local_desc;
}
libfabric_state->local_mem_ptr = buf;
}

out:
if (status) {
if (handle_info) {
Expand Down Expand Up @@ -2100,8 +2023,6 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table,
transport->host_ops.finalize = nvshmemt_libfabric_finalize;
transport->host_ops.show_info = nvshmemt_libfabric_show_info;
transport->host_ops.progress = nvshmemt_libfabric_progress;
transport->host_ops.enforce_cst = nvshmemt_libfabric_enforce_cst;

transport->attr = NVSHMEM_TRANSPORT_ATTR_CONNECTED;
transport->is_successfully_initialized = true;

Expand Down
5 changes: 0 additions & 5 deletions src/modules/transport/libfabric/libfabric.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,11 +278,6 @@ typedef struct {
struct fid_domain *domain;
struct fid_av *addresses[NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS];
nvshmemt_libfabric_endpoint_t *eps;
/* local_mr is used only for consistency ops. */
struct fid_mr *local_mr[2];
uint64_t local_mr_key[2];
void *local_mr_desc[2];
void *local_mem_ptr;
nvshmemt_libfabric_domain_name_t *domain_names;
int num_domains;
nvshmemt_libfabric_provider provider;
Expand Down
62 changes: 0 additions & 62 deletions src/modules/transport/ucx/ucx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1179,67 +1179,6 @@ int nvshmemt_ucx_finalize(nvshmem_transport_t transport) {
return 0;
}

/* Enforce consistency (CST) at the target PE via the UCX transport: force
 * previously delivered GPUDirect RDMA writes to be observable by performing a
 * read of this PE's own symmetric memory.
 *
 * tcurr:  the UCX transport instance (state cast to transport_ucx_state_t).
 * return: 0 on success (also when no memory handle is cached yet, or when the
 *         GDRCopy fast path is taken); NVSHMEMX_ERROR_INTERNAL on UCX error.
 */
int nvshmemt_ucx_enforce_cst_at_target(struct nvshmem_transport *tcurr) {
    transport_ucx_state_t *ucx_state = (transport_ucx_state_t *)tcurr->state;
    nvshmemt_ucx_mem_handle_info_t *mem_handle_info;

    mem_handle_info =
        (nvshmemt_ucx_mem_handle_info_t *)nvshmemt_mem_handle_cache_get_by_idx(ucx_state->cache, 0);

    /* Nothing registered yet, so there is nothing to make consistent. */
    if (!mem_handle_info) return 0;
#ifdef NVSHMEM_USE_GDRCOPY
    /* Fast path: a CPU-side read of one int through the GDRCopy BAR mapping
     * serves as the flushing read; no network round trip needed. */
    if (use_gdrcopy) {
        int temp;
        gdrcopy_ftable.copy_from_mapping(mem_handle_info->mh, &temp, mem_handle_info->cpu_ptr,
                                         sizeof(int));
        return 0;
    }
#endif
    /* Fallback: issue a loopback ucp_get_nbx of sizeof(int) from our own PE's
     * proxy endpoint and block until it completes. */
    int mype = tcurr->my_pe;
    int ep_index = (ucx_state->ep_count * mype + ucx_state->proxy_ep_idx);
    ucp_ep_h ep = ucx_state->endpoints[ep_index];
    ucp_request_param_t param;
    ucs_status_ptr_t ucs_ptr_rc = NULL;
    ucs_status_t ucs_rc;
    nvshmemt_ucx_mem_handle_t *mem_handle;
    ucp_rkey_h rkey;
    int local_int;

    mem_handle = mem_handle_info->mem_handle;
    /* Lazily unpack the remote key on first use and cache it on the handle. */
    if (unlikely(mem_handle->ep_rkey_host == NULL)) {
        ucs_rc = ucp_ep_rkey_unpack(ep, mem_handle->rkey_packed_buf, &mem_handle->ep_rkey_host);
        if (ucs_rc != UCS_OK) {
            NVSHMEMI_ERROR_EXIT("Unable to unpack rkey in UCS transport! Exiting.\n");
        }
    }
    rkey = mem_handle->ep_rkey_host;

    param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK;
    param.cb.send = nvshmemt_ucx_send_request_cb;

    ucs_ptr_rc =
        ucp_get_nbx(ep, &local_int, sizeof(int), (uint64_t)mem_handle_info->ptr, rkey, &param);

    /* Wait for completion of get. NULL means the get completed immediately. */
    if (ucs_ptr_rc != NULL) {
        if (UCS_PTR_IS_ERR(ucs_ptr_rc)) {
            NVSHMEMI_ERROR_PRINT("UCX CST request completed with error.\n");
            return NVSHMEMX_ERROR_INTERNAL;
        } else {
            /* Spin on the request, driving the worker until it finishes.
             * NOTE(review): the completed request is not freed here —
             * presumably released in nvshmemt_ucx_send_request_cb; confirm. */
            do {
                ucs_rc = ucp_request_check_status(ucs_ptr_rc);
                ucp_worker_progress(ucx_state->worker_context);
            } while (ucs_rc == UCS_INPROGRESS);
            if (ucs_rc != UCS_OK) {
                NVSHMEMI_ERROR_PRINT("UCX CST request completed with error.\n");
                return NVSHMEMX_ERROR_INTERNAL;
            }
        }
    }

    return 0;
}

int nvshmemt_ucx_show_info(struct nvshmem_transport *transport, int style) {
NVSHMEMI_ERROR_PRINT("UCX show info not implemented");
return 0;
Expand Down Expand Up @@ -1445,7 +1384,6 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table,
transport->host_ops.finalize = nvshmemt_ucx_finalize;
transport->host_ops.show_info = nvshmemt_ucx_show_info;
transport->host_ops.progress = nvshmemt_ucx_progress;
transport->host_ops.enforce_cst = nvshmemt_ucx_enforce_cst_at_target;
transport->host_ops.enforce_cst_at_target = NULL;
transport->host_ops.put_signal = nvshmemt_put_signal;
transport->attr = NVSHMEM_TRANSPORT_ATTR_CONNECTED;
Expand Down