diff --git a/src/modules/transport/common/env_defs.h b/src/modules/transport/common/env_defs.h index aafc312..be844b7 100644 --- a/src/modules/transport/common/env_defs.h +++ b/src/modules/transport/common/env_defs.h @@ -98,6 +98,9 @@ NVSHMEMI_ENV_DEF(DISABLE_LOCAL_ONLY_PROXY, bool, false, NVSHMEMI_ENV_CAT_TRANSPO NVSHMEMI_ENV_DEF(LIBFABRIC_PROVIDER, string, "cxi", NVSHMEMI_ENV_CAT_TRANSPORT, "Set the feature set provider for the libfabric transport: cxi, efa, verbs") +NVSHMEMI_ENV_DEF(DISABLE_LIBFABRIC_EFA_DIRECT, bool, false, NVSHMEMI_ENV_CAT_TRANSPORT, + "Disable EFA direct fabric and use efa instead") + #if defined(NVSHMEM_IBGDA_SUPPORT) || defined(NVSHMEM_ENV_ALL) /** GPU-initiated communication **/ NVSHMEMI_ENV_DEF(IBGDA_ENABLE_MULTI_PORT, bool, false, NVSHMEMI_ENV_CAT_TRANSPORT, diff --git a/src/modules/transport/libfabric/libfabric.cpp b/src/modules/transport/libfabric/libfabric.cpp index 927c079..f897dfc 100644 --- a/src/modules/transport/libfabric/libfabric.cpp +++ b/src/modules/transport/libfabric/libfabric.cpp @@ -1907,7 +1907,7 @@ static int nvshmemt_libfabric_finalize(nvshmem_transport_t transport) { return 0; } -static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabric_state_t *state) { +static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabric_state_t *state, struct nvshmemi_options_s *options) { struct fi_info info; struct fi_tx_attr tx_attr; struct fi_rx_attr rx_attr; @@ -1952,6 +1952,9 @@ static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabr FI_MR_LOCAL | FI_MR_VIRT_ADDR | FI_MR_ALLOCATED | FI_MR_PROV_KEY | FI_MR_HMEM; info.caps |= FI_MSG; info.caps |= FI_SOURCE; + if (options->DISABLE_LIBFABRIC_EFA_DIRECT) { + info.fabric_attr->name = strdup("efa"); + } } if (use_staged_atomics) { @@ -2209,7 +2212,7 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, #undef NVSHMEMI_SET_ENV_VAR /* Prepare fabric state information. */ - status = nvshmemi_libfabric_init_state(transport, libfabric_state); + status = nvshmemi_libfabric_init_state(transport, libfabric_state, &options); if (status) { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out_clean, "Failed to initialize the libfabric state.\n");