Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/modules/transport/common/env_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ NVSHMEMI_ENV_DEF(DISABLE_DATA_DIRECT, bool, false, NVSHMEMI_ENV_CAT_TRANSPORT,
NVSHMEMI_ENV_DEF(IB_DISABLE_DMABUF, bool, false, NVSHMEMI_ENV_CAT_TRANSPORT,
"Disable use of DMABUF in IBRC/IBDEVX/IBGDA Transports")
NVSHMEMI_ENV_DEF(IB_GID_INDEX, int, -1, NVSHMEMI_ENV_CAT_TRANSPORT, "Source GID Index for ROCE")
NVSHMEMI_ENV_DEF(IB_TRAFFIC_CLASS, int, 0, NVSHMEMI_ENV_CAT_TRANSPORT, "Traffic calss for ROCE")
NVSHMEMI_ENV_DEF(IB_TRAFFIC_CLASS, int, 0, NVSHMEMI_ENV_CAT_TRANSPORT, "Traffic class for ROCE")
NVSHMEMI_ENV_DEF(IB_ADDR_FAMILY, string, "AF_INET", NVSHMEMI_ENV_CAT_TRANSPORT,
"IP address family associated to IB GID "
"dynamically selected by NVSHMEM when NVSHMEM_IB_GID_INDEX is left unset")
Expand All @@ -69,6 +69,9 @@ NVSHMEMI_ENV_DEF(IB_NUM_RC_PER_DEVICE, int, 1, NVSHMEMI_ENV_CAT_TRANSPORT,
"Number of RC qpairs to create per device in the IB proxy-based transports."
"A device is each enumerated IB device, either a full HCA or a single port of a "
"multi-port HCA.")
NVSHMEMI_ENV_DEF(IBGDA_ENABLE_SYSTEM_TRAFFIC_CLASS, bool, false, NVSHMEMI_ENV_CAT_TRANSPORT,
"When true, read and use global traffic class from sysfs (/sys/class/infiniband/.../traffic_class) "
"if available. System-level traffic class takes precedence over NVSHMEM_IB_TRAFFIC_CLASS.")

NVSHMEMI_ENV_DEF(HCA_PREFIX, string, "mlx5", NVSHMEMI_ENV_CAT_TRANSPORT,
"Prefix of HCA interface names. Example, mlx5, ibp.")
Expand Down
70 changes: 70 additions & 0 deletions src/modules/transport/common/transport_ib_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,76 @@
#include "non_abi/nvshmem_build_options.h" // for NVSHMEM_USE_MLX5DV
#include "transport_common.h" // for LOAD_SYM, INFO, MAXPAT...

static void nvshmemt_ib_read_traffic_class_from_sysfs(const char *sysfs_path,
struct nvshmemt_ib_traffic_class_info *tclass_info) {
FILE *fp = fopen(sysfs_path, "r");
if (!fp) return;

char line[MAXPATHSIZE];
while (fgets(line, sizeof(line), fp)) {
int traffic_class_value;
if (sscanf(line, "Global tclass=%d", &traffic_class_value) == 1) {
tclass_info->global_tclass = traffic_class_value;
}
}
fclose(fp);
}

static int nvshmemt_ib_query_device_traffic_class(const char *ib_device_name,
int port_number,
struct nvshmemt_ib_traffic_class_info *tclass_info,
int log_level) {
int status;
char tclass_sysfs_path[MAXPATHSIZE];

status = snprintf(tclass_sysfs_path, MAXPATHSIZE,
"/sys/class/infiniband/%s/tc/%d/traffic_class",
ib_device_name, port_number);
if (status < 0 || status >= MAXPATHSIZE) {
NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out,
"Unable to construct traffic class sysfs path for device %s port %d.\n",
ib_device_name, port_number);
}

if (access(tclass_sysfs_path, F_OK) == 0) {
nvshmemt_ib_read_traffic_class_from_sysfs(tclass_sysfs_path, tclass_info);
} else {
NVSHMEMI_WARN_PRINT("Traffic class sysfs file not found: %s", tclass_sysfs_path);
}

status = NVSHMEMX_SUCCESS;
out:
return status;
}

int nvshmemt_ib_get_tclass(const char *ib_device_name, int port_number, int log_level,
struct nvshmemi_options_s *options) {
int user_traffic_class = options ? options->IB_TRAFFIC_CLASS : 0;

if (!options || !options->IB_ENABLE_SYSTEM_TRAFFIC_CLASS) {
return user_traffic_class;
}

struct nvshmemt_ib_traffic_class_info tclass_info;

memset(&tclass_info, -1, sizeof(struct nvshmemt_ib_traffic_class_info));

int status = nvshmemt_ib_query_device_traffic_class(ib_device_name, port_number,
&tclass_info, log_level);

if (status != NVSHMEMX_SUCCESS) {
NVSHMEMI_WARN_PRINT("Failed to query traffic class for device %s port %d\n", ib_device_name, port_number);
return user_traffic_class;
}

// If system traffic class is set (>0), use it; otherwise use user-specified value
if (tclass_info.global_tclass > 0) {
return tclass_info.global_tclass;
}

return user_traffic_class;
}

int nvshmemt_ib_common_nv_peer_mem_available() {
if (access("/sys/kernel/mm/memory_peers/nv_mem/version", F_OK) == 0) {
return NVSHMEMX_SUCCESS;
Expand Down
7 changes: 7 additions & 0 deletions src/modules/transport/common/transport_ib_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ struct nvshmemt_ib_common_mem_handle {
bool local_only;
};

struct nvshmemt_ib_traffic_class_info {
int global_tclass;
};

struct nvshmemt_ibv_function_table {
int (*fork_init)(void);
struct ibv_ah *(*create_ah)(struct ibv_pd *pd, struct ibv_ah_attr *ah_attr);
Expand Down Expand Up @@ -219,6 +223,9 @@ int nvshmemt_ib_common_connect_endpoints(nvshmem_transport_t t, int *selected_de
nvshmemt_ib_common_ep_ptr_t nvshmemt_ib_common_get_ep_from_qp_index(nvshmem_transport_t t,
int qp_index, int pe_index);

int nvshmemt_ib_get_tclass(const char *ib_device_name, int port_number, int log_level,
struct nvshmemi_options_s *options);

/* The following code is for dynamic GID detection for RoCE platforms.
It has been adapted from NCCL: https://gitlab-master.nvidia.com/nccl/nccl/-/merge_requests/359 */
static sa_family_t env_ib_addr_family(int log_level, nvshmemi_options_s *options) {
Expand Down
27 changes: 16 additions & 11 deletions src/modules/transport/ibgda/ibgda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ struct ibgda_device {
bool may_skip_cst;
ibgda_nic_handler_t nic_handler;
bool data_direct;
int tclass_val;
};

struct nvshmemt_ibgda_device_state_cache {
Expand Down Expand Up @@ -1650,7 +1651,7 @@ static int ibgda_qp_rst2init(struct ibgda_ep *ep, const struct ibgda_device *dev
* ============================================================================= */

static int ibgda_dci_init2rtr(nvshmemt_ibgda_state_t *ibgda_state, struct ibgda_ep *ep,
const struct ibgda_device *device, int portid) {
const struct ibgda_device *device, int portid, int traffic_class) {
int status = 0;

uint8_t cmd_in[DEVX_ST_SZ_BYTES(init2rtr_qp_in)] = {
Expand All @@ -1676,9 +1677,9 @@ static int ibgda_dci_init2rtr(nvshmemt_ibgda_state_t *ibgda_state, struct ibgda_
if (port_attr->link_layer == IBV_LINK_LAYER_INFINIBAND) {
DEVX_SET(qpc, qpc, primary_address_path.sl, ibgda_state->options->IB_SL);
} else if (port_attr->link_layer == IBV_LINK_LAYER_ETHERNET) {
DEVX_SET(qpc, qpc, primary_address_path.tclass, ibgda_state->options->IB_TRAFFIC_CLASS);
DEVX_SET(qpc, qpc, primary_address_path.tclass, traffic_class);
DEVX_SET(qpc, qpc, primary_address_path.eth_prio, ibgda_state->options->IB_SL);
DEVX_SET(qpc, qpc, primary_address_path.dscp, ibgda_state->options->IB_TRAFFIC_CLASS >> 2);
DEVX_SET(qpc, qpc, primary_address_path.dscp, traffic_class >> 2);
}

status = mlx5dv_devx_obj_modify(ep->devx_qp, cmd_in, sizeof(cmd_in), cmd_out, sizeof(cmd_out));
Expand All @@ -1692,7 +1693,7 @@ static int ibgda_dci_init2rtr(nvshmemt_ibgda_state_t *ibgda_state, struct ibgda_

static int ibgda_rc_init2rtr(nvshmemt_ibgda_state_t *ibgda_state, struct ibgda_ep *ep,
const struct ibgda_device *device, int portid,
struct ibgda_rc_handle *peer_ep_handle) {
struct ibgda_rc_handle *peer_ep_handle, int traffic_class) {
int status = 0;

uint8_t cmd_in[DEVX_ST_SZ_BYTES(init2rtr_qp_in)] = {
Expand Down Expand Up @@ -1753,7 +1754,7 @@ static int ibgda_rc_init2rtr(nvshmemt_ibgda_state_t *ibgda_state, struct ibgda_e
ah_attr.grh.dgid.global.subnet_prefix = peer_ep_handle->spn;
ah_attr.grh.dgid.global.interface_id = peer_ep_handle->iid;
ah_attr.grh.sgid_index = device->common_device.gid_info[portid - 1].local_gid_index;
ah_attr.grh.traffic_class = ibgda_state->options->IB_TRAFFIC_CLASS;
ah_attr.grh.traffic_class = traffic_class;
ah_attr.sl = ibgda_state->options->IB_SL;
ah_attr.src_path_bits = 0;

Expand All @@ -1775,7 +1776,7 @@ static int ibgda_rc_init2rtr(nvshmemt_ibgda_state_t *ibgda_state, struct ibgda_e
device->common_device.gid_info[portid - 1].local_gid_index);
DEVX_SET(qpc, qpc, primary_address_path.eth_prio, ibgda_state->options->IB_SL);
DEVX_SET(qpc, qpc, primary_address_path.udp_sport, ah_attr.dlid);
DEVX_SET(qpc, qpc, primary_address_path.dscp, ibgda_state->options->IB_TRAFFIC_CLASS >> 2);
DEVX_SET(qpc, qpc, primary_address_path.dscp, traffic_class >> 2);

memcpy(DEVX_ADDR_OF(qpc, qpc, primary_address_path.rgid_rip), &dah.av->rgid,
sizeof(dah.av->rgid));
Expand Down Expand Up @@ -2320,7 +2321,7 @@ static int ibgda_destroy_dct_shared_objects(nvshmemt_ibgda_state_t *ibgda_state,
}

static int ibgda_create_dct_shared_objects(nvshmemt_ibgda_state_t *ibgda_state,
struct ibgda_device *device, int portid) {
struct ibgda_device *device, int portid, int traffic_class) {
int status = 0;

const struct ibv_port_attr *port_attr = device->common_device.port_attr + (portid - 1);
Expand Down Expand Up @@ -2400,7 +2401,7 @@ static int ibgda_create_dct_shared_objects(nvshmemt_ibgda_state_t *ibgda_state,
device->common_device.gid_info[portid - 1].local_gid.global.interface_id;
ah_attr.grh.flow_label = 0;
ah_attr.grh.sgid_index = device->common_device.gid_info[portid - 1].local_gid_index;
ah_attr.grh.traffic_class = ibgda_state->options->IB_TRAFFIC_CLASS;
ah_attr.grh.traffic_class = traffic_class;
ah_attr.grh.hop_limit = IBGDA_GRH_HOP_LIMIT;
support_half_av_seg = false;
} else {
Expand Down Expand Up @@ -2833,7 +2834,7 @@ static int ibgda_setup_dci_endpoints(nvshmemt_ibgda_state_t *ibgda_state,
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out,
"ibgda_qp_rst2init failed on DCI #%d.", i);

status = ibgda_dci_init2rtr(ibgda_state, device->dci.eps[i], device, portid);
status = ibgda_dci_init2rtr(ibgda_state, device->dci.eps[i], device, portid, device->tclass_val);
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out,
"ibgda_dci_init2rtr failed on DCI #%d.", i);

Expand Down Expand Up @@ -3143,7 +3144,7 @@ static int ibgda_setup_rc_endpoints(nvshmemt_ibgda_state_t *ibgda_state,
"ibgda_qp_rst2init failed on RC #%d.", ep_index);

status = ibgda_rc_init2rtr(ibgda_state, device->rc.eps[ep_index], device, portid,
&peer_ep_handles[peer_handle_index]);
&peer_ep_handles[peer_handle_index], device->tclass_val);
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out,
"ibgda_rc_init2rtr failed on RC #%d.", ep_index);

Expand Down Expand Up @@ -3834,7 +3835,7 @@ static int ibgda_connect_device_resources(nvshmemt_ibgda_state_t *ibgda_state,
status = ibgda_create_qp_shared_objects(ibgda_state, device);
if (status) return status;

status = ibgda_create_dct_shared_objects(ibgda_state, device, portid);
status = ibgda_create_dct_shared_objects(ibgda_state, device, portid, device->tclass_val);
if (status) return status;

status = ibgda_create_dci_shared_objects(ibgda_state, device);
Expand Down Expand Up @@ -4792,6 +4793,10 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table,
continue;
}

device->tclass_val = nvshmemt_ib_get_tclass(name, ibgda_state->port_ids[i],
ibgda_state->log_level, ibgda_state->options);
INFO(ibgda_state->log_level, "traffic class value for %s is %d.", name, device->tclass_val);

/* Report whether we need to do atomic endianness conversions on 8 byte operands. */
status = nvshmemt_ib_common_query_endianness_conversion_size(&atomic_host_endian_size,
device->common_device.context);
Expand Down