Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ nvshmem_library_set_base_config(nvshmem_host)
## End generic variable configuration

## Start transports
set(TRANSPORT_VERSION_MAJOR 4)
set(TRANSPORT_VERSION_MAJOR 5)
set(TRANSPORT_VERSION_MINOR 0)
set(TRANSPORT_VERSION_PATCH 0)

Expand Down
37 changes: 30 additions & 7 deletions src/host/proxy/proxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
#include "internal/host/nvshmem_nvtx.hpp" // for nvshmem_nvtx_...
// IWYU pragma: no_include "nvtx3.hpp"

/* Element count used when allocating each of the proxy's put_signal arrays
 * (put_signal_local_desc_arr, put_signal_remote_desc_arr, put_signal_bytes_arr)
 * in nvshmemi_proxy_init. process_channel_put_signal stores one entry per
 * chunked write into these arrays.
 * NOTE(review): in the visible code the bound is only enforced by an assert
 * placed after the arrays have been filled -- confirm no earlier check exists. */
#define NVSHMEM_PUT_SIGNAL_MAX_WRITES 256

uint64_t proxy_channel_g_buf_size; /* Total size of g_buf in bytes */
uint64_t proxy_channel_g_buf_log_size; /* NOTE(review): presumably log2 of the g_buf size, not a byte count (previous comment duplicated the line above) -- confirm */

Expand Down Expand Up @@ -268,6 +270,19 @@ int nvshmemi_proxy_init(nvshmemi_state_t *state, int proxy_level) {

proxy_state->nvshmemi_state = state;

/* Allocate put_signal arrays */
proxy_state->put_signal_local_desc_arr =
(rma_memdesc_t *)malloc(NVSHMEM_PUT_SIGNAL_MAX_WRITES * sizeof(rma_memdesc_t));
NVSHMEMI_NULL_ERROR_JMP(proxy_state->put_signal_local_desc_arr, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out,
"Cannot allocate proxy put_signal_local_desc_arr.\n");
proxy_state->put_signal_remote_desc_arr =
(rma_memdesc_t *)malloc(NVSHMEM_PUT_SIGNAL_MAX_WRITES * sizeof(rma_memdesc_t));
NVSHMEMI_NULL_ERROR_JMP(proxy_state->put_signal_remote_desc_arr, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out,
"Cannot allocate proxy put_signal_remote_desc_arr.\n");
proxy_state->put_signal_bytes_arr =
(rma_bytesdesc_t *)malloc(NVSHMEM_PUT_SIGNAL_MAX_WRITES * sizeof(rma_bytesdesc_t));
NVSHMEMI_NULL_ERROR_JMP(proxy_state->put_signal_bytes_arr, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out,
"Cannot allocate proxy put_signal_bytes_arr.\n");
CUDA_RUNTIME_CHECK(
cudaMallocHost((void **)&proxy_state->global_exit_request_state, sizeof(int), 0));
CUDA_RUNTIME_CHECK(cudaMallocHost((void **)&proxy_state->global_exit_code, sizeof(int), 0));
Expand Down Expand Up @@ -1083,8 +1098,7 @@ inline int process_channel_put_signal(proxy_state_t *state, proxy_channel_t *ch,
rma_verb_t write_verb;
rma_bytesdesc_t write_bytes_desc;
rma_memdesc_t write_remote_desc, write_local_desc;
std::vector<rma_memdesc_t> local_write_desc_vec, remote_write_desc_vec;
std::vector<rma_bytesdesc_t> write_bytes_vec;
int num_writes = 0;
size_t chunk_size, local_chunk_size, remote_chunk_size, size_remaining;
amo_verb_t sig_verb;
amo_memdesc_t sig_target_desc;
Expand Down Expand Up @@ -1188,15 +1202,18 @@ inline int process_channel_put_signal(proxy_state_t *state, proxy_channel_t *ch,
chunk_size = std::min(local_chunk_size, std::min(remote_chunk_size, size_remaining));
write_bytes_desc.nelems = chunk_size;

local_write_desc_vec.push_back(write_local_desc);
remote_write_desc_vec.push_back(write_remote_desc);
write_bytes_vec.push_back(write_bytes_desc);
state->put_signal_local_desc_arr[num_writes] = write_local_desc;
state->put_signal_remote_desc_arr[num_writes] = write_remote_desc;
state->put_signal_bytes_arr[num_writes] = write_bytes_desc;
num_writes++;

size_remaining -= chunk_size;
lwrite_ptr = (char *)lwrite_ptr + chunk_size;
rwrite_ptr = (char *)rwrite_ptr + chunk_size;
}

assert(num_writes <= NVSHMEM_PUT_SIGNAL_MAX_WRITES);

/* build signal parameters */
memset(&sig_target_desc, 0, sizeof(amo_memdesc_t));
sig_verb.desc = (nvshmemi_amo_t)ps_req_3->sig_op;
Expand All @@ -1215,8 +1232,9 @@ inline int process_channel_put_signal(proxy_state_t *state, proxy_channel_t *ch,
TRACE(NVSHMEM_PROXY, "process_channel_put_signal laddr %p pe %d", lwrite_ptr, pe);

tcurr = state->transport[pe];
status = tcurr->host_ops.put_signal(tcurr, pe, write_verb, remote_write_desc_vec,
local_write_desc_vec, write_bytes_vec, sig_verb,
status = tcurr->host_ops.put_signal(tcurr, pe, write_verb, state->put_signal_remote_desc_arr,
state->put_signal_local_desc_arr,
state->put_signal_bytes_arr, num_writes, sig_verb,
&sig_target_desc, sig_bytes_desc, qp_index);
if (unlikely(status)) {
NVSHMEMI_ERROR_PRINT("aborting due to error in process_channel_put_signal\n");
Expand Down Expand Up @@ -1557,5 +1575,10 @@ int nvshmemi_proxy_finalize(nvshmemi_state_t *state) {
if (proxy_state->nvshmemi_timeout)
CUDA_RUNTIME_CHECK(cudaFreeHost(proxy_state->nvshmemi_timeout));

/* Free put_signal arrays */
free(proxy_state->put_signal_local_desc_arr);
free(proxy_state->put_signal_remote_desc_arr);
free(proxy_state->put_signal_bytes_arr);

return 0;
}
4 changes: 4 additions & 0 deletions src/host/proxy/proxy_host.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ typedef struct proxy_state {
int gdr_device_native_ordering;
int *global_exit_request_state;
int *global_exit_code;
/* put_signal arrays */
rma_memdesc_t *put_signal_local_desc_arr;
rma_memdesc_t *put_signal_remote_desc_arr;
rma_bytesdesc_t *put_signal_bytes_arr;
} proxy_state_t;

#endif
9 changes: 4 additions & 5 deletions src/include/internal/host_transport/transport.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,10 @@ typedef int (*amo_handle)(struct nvshmem_transport *tcurr, int pe, void *curetpt
typedef int (*fence_handle)(struct nvshmem_transport *tcurr, int pe, int qp_index, int is_multi);
typedef int (*quiet_handle)(struct nvshmem_transport *tcurr, int pe, int qp_index);
typedef int (*put_signal_handle)(struct nvshmem_transport *tcurr, int pe, rma_verb_t write_verb,
std::vector<rma_memdesc_t> &write_remote,
std::vector<rma_memdesc_t> &write_local,
std::vector<rma_bytesdesc_t> &write_bytesdesc, amo_verb_t sig_verb,
amo_memdesc_t *sig_target, amo_bytesdesc_t sig_bytesdesc,
int qp_index);
rma_memdesc_t *write_remote, rma_memdesc_t *write_local,
rma_bytesdesc_t *write_bytesdesc, int num_writes,
amo_verb_t sig_verb, amo_memdesc_t *sig_target,
amo_bytesdesc_t sig_bytesdesc, int qp_index);

struct nvshmem_transport_host_ops {
int (*can_reach_peer)(int *access, nvshmem_transport_pe_info_t *peer_info,
Expand Down
10 changes: 4 additions & 6 deletions src/modules/transport/common/transport_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,16 +224,14 @@ int nvshmemt_mem_handle_cache_fini(struct transport_mem_handle_info_cache *cache
}

int nvshmemt_put_signal(struct nvshmem_transport *tcurr, int pe, rma_verb_t write_verb,
std::vector<rma_memdesc_t> &write_remote,
std::vector<rma_memdesc_t> &write_local,
std::vector<rma_bytesdesc_t> &write_bytesdesc, amo_verb_t sig_verb,
rma_memdesc_t *write_remote, rma_memdesc_t *write_local,
rma_bytesdesc_t *write_bytesdesc, int num_writes, amo_verb_t sig_verb,
amo_memdesc_t *sig_target, amo_bytesdesc_t sig_bytesdesc, int is_proxy) {
int status = 0;
assert(tcurr->host_ops.rma);
assert(tcurr->host_ops.amo);
assert(write_remote.size() == write_local.size() &&
write_local.size() == write_bytesdesc.size());
for (size_t i = 0; i < write_remote.size(); i++) {

for (int i = 0; i < num_writes; i++) {
status = tcurr->host_ops.rma(tcurr, pe, write_verb, &write_remote[i], &write_local[i],
write_bytesdesc[i], is_proxy);
}
Expand Down
5 changes: 2 additions & 3 deletions src/modules/transport/common/transport_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,8 @@ static inline int nvshmemt_common_get_log_level(struct nvshmemi_options_s *optio
struct transport_mem_handle_info_cache; // IWYU pragma: keep

int nvshmemt_put_signal(struct nvshmem_transport *tcurr, int pe, rma_verb_t write_verb,
std::vector<rma_memdesc_t> &write_remote,
std::vector<rma_memdesc_t> &write_local,
std::vector<rma_bytesdesc_t> &write_bytesdesc, amo_verb_t sig_verb,
rma_memdesc_t *write_remote, rma_memdesc_t *write_local,
rma_bytesdesc_t *write_bytesdesc, int num_writes, amo_verb_t sig_verb,
amo_memdesc_t *sig_target, amo_bytesdesc_t sig_bytesdesc, int is_proxy);

struct nvshmemt_hca_info {
Expand Down
13 changes: 5 additions & 8 deletions src/modules/transport/libfabric/libfabric.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1002,9 +1002,8 @@ static int nvshmemt_libfabric_gdr_signal(struct nvshmem_transport *transport, in
}

int nvshmemt_put_signal_unordered(struct nvshmem_transport *tcurr, int pe, rma_verb_t write_verb,
std::vector<rma_memdesc_t> &write_remote,
std::vector<rma_memdesc_t> &write_local,
std::vector<rma_bytesdesc_t> &write_bytes_desc,
rma_memdesc_t *write_remote, rma_memdesc_t *write_local,
rma_bytesdesc_t *write_bytes_desc, int num_writes,
amo_verb_t sig_verb, amo_memdesc_t *sig_target,
amo_bytesdesc_t sig_bytes_desc, int is_proxy) {
nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state;
Expand Down Expand Up @@ -1037,22 +1036,20 @@ int nvshmemt_put_signal_unordered(struct nvshmem_transport *tcurr, int pe, rma_v
goto out;
}

assert(write_remote.size() == write_local.size() &&
write_local.size() == write_bytes_desc.size());
for (size_t i = 0; i < write_remote.size(); i++) {
for (int i = 0; i < num_writes; i++) {
status =
nvshmemt_libfabric_rma_impl(tcurr, pe, write_verb, &write_remote[i], &write_local[i],
write_bytes_desc[i], is_proxy, &sequence_count);
if (unlikely(status)) {
NVSHMEMI_ERROR_PRINT(
"Error in nvshmemt_put_signal_unordered, could not submit write #%lu\n", i);
"Error in nvshmemt_put_signal_unordered, could not submit write #%d\n", i);
goto out;
}
}

assert(use_staged_atomics == true);
status = nvshmemt_libfabric_gdr_signal(tcurr, pe, NULL, sig_verb, sig_target, sig_bytes_desc,
is_proxy, sequence_count, (uint16_t)write_remote.size());
is_proxy, sequence_count, (uint16_t)num_writes);
out:
if (status) {
NVSHMEMI_ERROR_PRINT(
Expand Down