diff --git a/src/modules/transport/common/env_defs.h b/src/modules/transport/common/env_defs.h index aafc312..2cd690f 100644 --- a/src/modules/transport/common/env_defs.h +++ b/src/modules/transport/common/env_defs.h @@ -50,7 +50,7 @@ NVSHMEMI_ENV_DEF(DISABLE_DATA_DIRECT, bool, false, NVSHMEMI_ENV_CAT_TRANSPORT, NVSHMEMI_ENV_DEF(IB_DISABLE_DMABUF, bool, false, NVSHMEMI_ENV_CAT_TRANSPORT, "Disable use of DMABUF in IBRC/IBDEVX/IBGDA Transports") NVSHMEMI_ENV_DEF(IB_GID_INDEX, int, -1, NVSHMEMI_ENV_CAT_TRANSPORT, "Source GID Index for ROCE") -NVSHMEMI_ENV_DEF(IB_TRAFFIC_CLASS, int, 0, NVSHMEMI_ENV_CAT_TRANSPORT, "Traffic calss for ROCE") +NVSHMEMI_ENV_DEF(IB_TRAFFIC_CLASS, int, 0, NVSHMEMI_ENV_CAT_TRANSPORT, "Traffic class for ROCE") NVSHMEMI_ENV_DEF(IB_ADDR_FAMILY, string, "AF_INET", NVSHMEMI_ENV_CAT_TRANSPORT, "IP address family associated to IB GID " "dynamically selected by NVSHMEM when NVSHMEM_IB_GID_INDEX is left unset") @@ -69,6 +69,9 @@ NVSHMEMI_ENV_DEF(IB_NUM_RC_PER_DEVICE, int, 1, NVSHMEMI_ENV_CAT_TRANSPORT, "Number of RC qpairs to create per device in the IB proxy-based transports." "A device is each enumerated IB device, either a full HCA or a single port of a " "multi-port HCA.") +NVSHMEMI_ENV_DEF(IBGDA_ENABLE_SYSTEM_TRAFFIC_CLASS, bool, false, NVSHMEMI_ENV_CAT_TRANSPORT, + "When true, read and use global traffic class from sysfs (/sys/class/infiniband/.../traffic_class) " + "if available. System-level traffic class takes precedence over NVSHMEM_IB_TRAFFIC_CLASS.") NVSHMEMI_ENV_DEF(HCA_PREFIX, string, "mlx5", NVSHMEMI_ENV_CAT_TRANSPORT, "Prefix of HCA interface names. Example, mlx5, ibp.") diff --git a/src/modules/transport/common/transport_ib_common.cpp b/src/modules/transport/common/transport_ib_common.cpp index 195b516..80c91ee 100644 --- a/src/modules/transport/common/transport_ib_common.cpp +++ b/src/modules/transport/common/transport_ib_common.cpp @@ -21,6 +21,76 @@ #include "non_abi/nvshmem_build_options.h" // for NVSHMEM_USE_MLX5DV #include "transport_common.h" // for LOAD_SYM, INFO, MAXPAT... +static void nvshmemt_ib_read_traffic_class_from_sysfs(const char *sysfs_path, + struct nvshmemt_ib_traffic_class_info *tclass_info) { + FILE *fp = fopen(sysfs_path, "r"); + if (!fp) return; + + char line[MAXPATHSIZE]; + while (fgets(line, sizeof(line), fp)) { + int traffic_class_value; + if (sscanf(line, "Global tclass=%d", &traffic_class_value) == 1) { + tclass_info->global_tclass = traffic_class_value; + } + } + fclose(fp); +} + +static int nvshmemt_ib_query_device_traffic_class(const char *ib_device_name, + int port_number, + struct nvshmemt_ib_traffic_class_info *tclass_info, + int log_level) { + int status; + char tclass_sysfs_path[MAXPATHSIZE]; + + status = snprintf(tclass_sysfs_path, MAXPATHSIZE, + "/sys/class/infiniband/%s/tc/%d/traffic_class", + ib_device_name, port_number); + if (status < 0 || status >= MAXPATHSIZE) { + NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to construct traffic class sysfs path for device %s port %d.\n", + ib_device_name, port_number); + } + + if (access(tclass_sysfs_path, F_OK) == 0) { + nvshmemt_ib_read_traffic_class_from_sysfs(tclass_sysfs_path, tclass_info); + } else { + NVSHMEMI_WARN_PRINT("Traffic class sysfs file not found: %s", tclass_sysfs_path); + } + + status = NVSHMEMX_SUCCESS; +out: + return status; +} + +int nvshmemt_ib_get_tclass(const char *ib_device_name, int port_number, int log_level, + struct nvshmemi_options_s *options) { + int user_traffic_class = options ? options->IB_TRAFFIC_CLASS : 0; + + if (!options || !options->IB_ENABLE_SYSTEM_TRAFFIC_CLASS) { + return user_traffic_class; + } + + struct nvshmemt_ib_traffic_class_info tclass_info; + + memset(&tclass_info, -1, sizeof(struct nvshmemt_ib_traffic_class_info)); + + int status = nvshmemt_ib_query_device_traffic_class(ib_device_name, port_number, + &tclass_info, log_level); + + if (status != NVSHMEMX_SUCCESS) { + NVSHMEMI_WARN_PRINT("Failed to query traffic class for device %s port %d\n", ib_device_name, port_number); + return user_traffic_class; + } + + // If system traffic class is set (>0), use it; otherwise use user-specified value + if (tclass_info.global_tclass > 0) { + return tclass_info.global_tclass; + } + + return user_traffic_class; +} + int nvshmemt_ib_common_nv_peer_mem_available() { if (access("/sys/kernel/mm/memory_peers/nv_mem/version", F_OK) == 0) { return NVSHMEMX_SUCCESS; diff --git a/src/modules/transport/common/transport_ib_common.h b/src/modules/transport/common/transport_ib_common.h index 4ea66d4..a8fd0e0 100644 --- a/src/modules/transport/common/transport_ib_common.h +++ b/src/modules/transport/common/transport_ib_common.h @@ -133,6 +133,10 @@ struct nvshmemt_ib_common_mem_handle { bool local_only; }; +struct nvshmemt_ib_traffic_class_info { + int global_tclass; +}; + struct nvshmemt_ibv_function_table { int (*fork_init)(void); struct ibv_ah *(*create_ah)(struct ibv_pd *pd, struct ibv_ah_attr *ah_attr); @@ -219,6 +223,9 @@ int nvshmemt_ib_common_connect_endpoints(nvshmem_transport_t t, int *selected_de nvshmemt_ib_common_ep_ptr_t nvshmemt_ib_common_get_ep_from_qp_index(nvshmem_transport_t t, int qp_index, int pe_index); +int nvshmemt_ib_get_tclass(const char *ib_device_name, int port_number, int log_level, + struct nvshmemi_options_s *options); + /* The following code is for dynamic GID detection for RoCE platforms. It has been adapted from NCCL: https://gitlab-master.nvidia.com/nccl/nccl/-/merge_requests/359 */ static sa_family_t env_ib_addr_family(int log_level, nvshmemi_options_s *options) { diff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp index 7989e38..a6c62fa 100644 --- a/src/modules/transport/ibgda/ibgda.cpp +++ b/src/modules/transport/ibgda/ibgda.cpp @@ -275,6 +275,7 @@ struct ibgda_device { bool may_skip_cst; ibgda_nic_handler_t nic_handler; bool data_direct; + int tclass_val; }; struct nvshmemt_ibgda_device_state_cache { @@ -1650,7 +1651,7 @@ static int ibgda_qp_rst2init(struct ibgda_ep *ep, const struct ibgda_device *dev * ============================================================================= */ static int ibgda_dci_init2rtr(nvshmemt_ibgda_state_t *ibgda_state, struct ibgda_ep *ep, - const struct ibgda_device *device, int portid) { + const struct ibgda_device *device, int portid, int traffic_class) { int status = 0; uint8_t cmd_in[DEVX_ST_SZ_BYTES(init2rtr_qp_in)] = { @@ -1676,9 +1677,9 @@ static int ibgda_dci_init2rtr(nvshmemt_ibgda_state_t *ibgda_state, struct ibgda_ if (port_attr->link_layer == IBV_LINK_LAYER_INFINIBAND) { DEVX_SET(qpc, qpc, primary_address_path.sl, ibgda_state->options->IB_SL); } else if (port_attr->link_layer == IBV_LINK_LAYER_ETHERNET) { - DEVX_SET(qpc, qpc, primary_address_path.tclass, ibgda_state->options->IB_TRAFFIC_CLASS); + DEVX_SET(qpc, qpc, primary_address_path.tclass, traffic_class); DEVX_SET(qpc, qpc, primary_address_path.eth_prio, ibgda_state->options->IB_SL); - DEVX_SET(qpc, qpc, primary_address_path.dscp, ibgda_state->options->IB_TRAFFIC_CLASS >> 2); + DEVX_SET(qpc, qpc, primary_address_path.dscp, traffic_class >> 2); } status = mlx5dv_devx_obj_modify(ep->devx_qp, cmd_in, sizeof(cmd_in), cmd_out, sizeof(cmd_out)); @@ -1692,7 +1693,7 @@ static int ibgda_dci_init2rtr(nvshmemt_ibgda_state_t *ibgda_state, struct ibgda_ static int ibgda_rc_init2rtr(nvshmemt_ibgda_state_t *ibgda_state, struct ibgda_ep *ep, const struct ibgda_device *device, int portid, - struct ibgda_rc_handle *peer_ep_handle) { + struct ibgda_rc_handle *peer_ep_handle, int traffic_class) { int status = 0; uint8_t cmd_in[DEVX_ST_SZ_BYTES(init2rtr_qp_in)] = { @@ -1753,7 +1754,7 @@ static int ibgda_rc_init2rtr(nvshmemt_ibgda_state_t *ibgda_state, struct ibgda_e ah_attr.grh.dgid.global.subnet_prefix = peer_ep_handle->spn; ah_attr.grh.dgid.global.interface_id = peer_ep_handle->iid; ah_attr.grh.sgid_index = device->common_device.gid_info[portid - 1].local_gid_index; - ah_attr.grh.traffic_class = ibgda_state->options->IB_TRAFFIC_CLASS; + ah_attr.grh.traffic_class = traffic_class; ah_attr.sl = ibgda_state->options->IB_SL; ah_attr.src_path_bits = 0; @@ -1775,7 +1776,7 @@ static int ibgda_rc_init2rtr(nvshmemt_ibgda_state_t *ibgda_state, struct ibgda_e device->common_device.gid_info[portid - 1].local_gid_index); DEVX_SET(qpc, qpc, primary_address_path.eth_prio, ibgda_state->options->IB_SL); DEVX_SET(qpc, qpc, primary_address_path.udp_sport, ah_attr.dlid); - DEVX_SET(qpc, qpc, primary_address_path.dscp, ibgda_state->options->IB_TRAFFIC_CLASS >> 2); + DEVX_SET(qpc, qpc, primary_address_path.dscp, traffic_class >> 2); memcpy(DEVX_ADDR_OF(qpc, qpc, primary_address_path.rgid_rip), &dah.av->rgid, sizeof(dah.av->rgid)); @@ -2320,7 +2321,7 @@ static int ibgda_destroy_dct_shared_objects(nvshmemt_ibgda_state_t *ibgda_state, } static int ibgda_create_dct_shared_objects(nvshmemt_ibgda_state_t *ibgda_state, - struct ibgda_device *device, int portid) { + struct ibgda_device *device, int portid, int traffic_class) { int status = 0; const struct ibv_port_attr *port_attr = device->common_device.port_attr + (portid - 1); @@ -2400,7 +2401,7 @@ static int ibgda_create_dct_shared_objects(nvshmemt_ibgda_state_t *ibgda_state, device->common_device.gid_info[portid - 1].local_gid.global.interface_id; ah_attr.grh.flow_label = 0; ah_attr.grh.sgid_index = device->common_device.gid_info[portid - 1].local_gid_index; - ah_attr.grh.traffic_class = ibgda_state->options->IB_TRAFFIC_CLASS; + ah_attr.grh.traffic_class = traffic_class; ah_attr.grh.hop_limit = IBGDA_GRH_HOP_LIMIT; support_half_av_seg = false; } else { @@ -2833,7 +2834,7 @@ static int ibgda_setup_dci_endpoints(nvshmemt_ibgda_state_t *ibgda_state, NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibgda_qp_rst2init failed on DCI #%d.", i); - status = ibgda_dci_init2rtr(ibgda_state, device->dci.eps[i], device, portid); + status = ibgda_dci_init2rtr(ibgda_state, device->dci.eps[i], device, portid, device->tclass_val); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibgda_dci_init2rtr failed on DCI #%d.", i); @@ -3143,7 +3144,7 @@ static int ibgda_setup_rc_endpoints(nvshmemt_ibgda_state_t *ibgda_state, "ibgda_qp_rst2init failed on RC #%d.", ep_index); status = ibgda_rc_init2rtr(ibgda_state, device->rc.eps[ep_index], device, portid, - &peer_ep_handles[peer_handle_index]); + &peer_ep_handles[peer_handle_index], device->tclass_val); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibgda_rc_init2rtr failed on RC #%d.", ep_index); @@ -3834,7 +3835,7 @@ static int ibgda_connect_device_resources(nvshmemt_ibgda_state_t *ibgda_state, status = ibgda_create_qp_shared_objects(ibgda_state, device); if (status) return status; - status = ibgda_create_dct_shared_objects(ibgda_state, device, portid); + status = ibgda_create_dct_shared_objects(ibgda_state, device, portid, device->tclass_val); if (status) return status; status = ibgda_create_dci_shared_objects(ibgda_state, device); @@ -4792,6 +4793,10 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, continue; } + device->tclass_val = nvshmemt_ib_get_tclass(name, ibgda_state->port_ids[i], + ibgda_state->log_level, ibgda_state->options); + INFO(ibgda_state->log_level, "traffic class value for %s is %d.", name, device->tclass_val); + /* Report whether we need to do atomic endianness conversions on 8 byte operands. */ status = nvshmemt_ib_common_query_endianness_conversion_size(&atomic_host_endian_size, device->common_device.context);