Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 53 additions & 31 deletions src/PeerUtils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,35 @@

#include "CudaUtils.cuh"

// CUDA 13 added a cudaGraphEdgeData* parameter to the graph dependency
// and stream capture APIs. These shims present the CUDA 12 signature
// regardless of toolkit version, passing nullptr for the new parameter on 13+.
//
// Both branches are function-like macros with the same fixed arity. This is
// deliberate: a call site that omits an argument (relying on a defaulted
// parameter of the underlying runtime function) is rejected on every
// toolkit, instead of compiling on older toolkits and breaking only when
// built against CUDA 13+.
//
// Shim arguments, in order:
//   cudaGraphAddDependencies_compat(graph, from, to, numDependencies)
//   cudaGraphNodeGetDependencies_compat(node, dependenciesOut, countInOut)
//   cudaGraphNodeGetDependentNodes_compat(node, dependentsOut, countInOut)
//   cudaStreamGetCaptureInfo_compat(stream, statusOut, idOut, graphOut, depsOut, numDepsOut)
//   cudaStreamUpdateCaptureDependencies_compat(stream, nodes, numNodes, flags)
#if CUDART_VERSION >= 13000
#define cudaGraphAddDependencies_compat(g, from, to, n) \
    cudaGraphAddDependencies(g, from, to, nullptr, n)
#define cudaGraphNodeGetDependencies_compat(node, deps, n) \
    cudaGraphNodeGetDependencies(node, deps, nullptr, n)
#define cudaGraphNodeGetDependentNodes_compat(node, deps, n) \
    cudaGraphNodeGetDependentNodes(node, deps, nullptr, n)
#define cudaStreamGetCaptureInfo_compat(s, status, id, graph, deps, ndeps) \
    cudaStreamGetCaptureInfo(s, status, id, graph, deps, nullptr, ndeps)
#define cudaStreamUpdateCaptureDependencies_compat(s, nodes, n, flags) \
    cudaStreamUpdateCaptureDependencies(s, nodes, nullptr, n, flags)
#else
// Pre-13 toolkits: forward the arguments unchanged (no edge-data parameter).
#define cudaGraphAddDependencies_compat(g, from, to, n) \
    cudaGraphAddDependencies(g, from, to, n)
#define cudaGraphNodeGetDependencies_compat(node, deps, n) \
    cudaGraphNodeGetDependencies(node, deps, n)
#define cudaGraphNodeGetDependentNodes_compat(node, deps, n) \
    cudaGraphNodeGetDependentNodes(node, deps, n)
#define cudaStreamGetCaptureInfo_compat(s, status, id, graph, deps, ndeps) \
    cudaStreamGetCaptureInfo(s, status, id, graph, deps, ndeps)
#define cudaStreamUpdateCaptureDependencies_compat(s, nodes, n, flags) \
    cudaStreamUpdateCaptureDependencies(s, nodes, n, flags)
#endif

std::vector<cudaGraphNode_t> get_current_capture_dependencies(cudaStream_t s) {
cudaStreamCaptureStatus status;
cudaGraphNode_t* depNodes = nullptr;
size_t numDeps = 0;

cudaError_t err =
cudaStreamGetCaptureInfo(s, &status, nullptr, nullptr, ((const cudaGraphNode_t**)&depNodes), &numDeps);
cudaStreamGetCaptureInfo_compat(s, &status, nullptr, nullptr, ((const cudaGraphNode_t**)&depNodes), &numDeps);

std::vector<cudaGraphNode_t> deps;
if (err == cudaSuccess && numDeps > 0 && depNodes) {
Expand Down Expand Up @@ -99,7 +121,7 @@ static void launch(cudaFunction_t kernel, dim3 grid, dim3 block, void** args, ui
// === Step 7: Retrieve snapshot of graph AFTER launch ===
// NCCL explicitly updates capture dependencies after cuLaunchKernelEx
cudaError_t err =
cudaStreamUpdateCaptureDependencies(s, // The capturing stream
cudaStreamUpdateCaptureDependencies_compat(s, // The capturing stream
nullptr, // No specific nodes to depend on
0, // No additional nodes
cudaStreamAddCaptureDependencies // Add implicit dependencies
Expand All @@ -121,7 +143,7 @@ static void launch(cudaFunction_t kernel, dim3 grid, dim3 block, void** args, ui
cudaGraphNode_t* deps_after_array = nullptr;
size_t numDeps_after = 0;

err = cudaStreamGetCaptureInfo(s, &status,
err = cudaStreamGetCaptureInfo_compat(s, &status,
nullptr, // captureID (optional)
&captured_graph, (const cudaGraphNode_t**)&deps_after_array, &numDeps_after);

Expand Down Expand Up @@ -177,7 +199,7 @@ static void launch(cudaFunction_t kernel, dim3 grid, dim3 block, void** args, ui
std::cout << "[P2P] Adding explicit dependencies from " << deps_before.size()
<< " predecessor nodes\n";
CudaCheckErrorModNoSync;
err = cudaGraphAddDependencies(captured_graph,
err = cudaGraphAddDependencies_compat(captured_graph,
deps_before.data(), // From these nodes
&latest_node, // To this kernel node
deps_before.size() // This many dependencies
Expand Down Expand Up @@ -229,7 +251,7 @@ void pollingKernel(TimelineSemaphore* gpu1_complete_flag, uint64_t value, cudaSt
const cudaGraphNode_t* deps;
size_t dep_count;

cudaStreamGetCaptureInfo(s, &capture_status, nullptr, &capturing_graph, &deps, &dep_count);
cudaStreamGetCaptureInfo_compat(s, &capture_status, nullptr, &capturing_graph, &deps, &dep_count);

// ========================================================================
// Create peer access kernel node with V2 parameters (supports attributes)
Expand Down Expand Up @@ -259,7 +281,7 @@ void pollingKernel(TimelineSemaphore* gpu1_complete_flag, uint64_t value, cudaSt
cudaGraphKernelNodeSetAttribute(peer_kernel_node, cudaLaunchAttributeMemSyncDomain, &attr_value);

// Update stream dependencies so subsequent work depends on peer kernel
cudaStreamUpdateCaptureDependencies(s, &peer_kernel_node, 1, 1);
cudaStreamUpdateCaptureDependencies_compat(s, &peer_kernel_node, 1, 1);
}
}

Expand Down Expand Up @@ -295,7 +317,7 @@ void notifyKernel(TimelineSemaphore* gpu1_complete_flag, uint64_t value, cudaStr
const cudaGraphNode_t* deps;
size_t dep_count;

cudaStreamGetCaptureInfo(s, &capture_status, nullptr, &capturing_graph, &deps, &dep_count);
cudaStreamGetCaptureInfo_compat(s, &capture_status, nullptr, &capturing_graph, &deps, &dep_count);

// ========================================================================
// Create peer access kernel node with V2 parameters (supports attributes)
Expand Down Expand Up @@ -325,7 +347,7 @@ void notifyKernel(TimelineSemaphore* gpu1_complete_flag, uint64_t value, cudaStr
cudaGraphKernelNodeSetAttribute(peer_kernel_node, cudaLaunchAttributeMemSyncDomain, &attr_value);

// Update stream dependencies so subsequent work depends on peer kernel
cudaStreamUpdateCaptureDependencies(s, &peer_kernel_node, 1, 1);
cudaStreamUpdateCaptureDependencies_compat(s, &peer_kernel_node, 1, 1);
}
}
__global__ void p2p_polling_kernel(volatile uint32_t* completion_flag, uint32_t value) {
Expand All @@ -346,7 +368,7 @@ __global__ void hostpin_polling_kernel(TimelineSemaphore* completion_flag, uint6
}
/**
 * Reports whether `stream` is currently in an active stream capture.
 *
 * @param stream  Stream to query.
 * @return true only when the capture-info query succeeds and reports
 *         cudaStreamCaptureStatusActive; false otherwise, including when
 *         the query itself returns an error.
 */
bool is_stream_being_captured(cudaStream_t stream) {
    // Start from a defined value so a failed query can never lead to reading
    // an indeterminate status below.
    cudaStreamCaptureStatus capture_status = cudaStreamCaptureStatusNone;

    // Go through the version shim only: a direct cudaStreamGetCaptureInfo call
    // with the six-argument form does not match the CUDA 13+ signature.
    const cudaError_t err =
        cudaStreamGetCaptureInfo_compat(stream, &capture_status, nullptr, nullptr, nullptr, nullptr);
    if (err != cudaSuccess) {
        // Could not determine the capture state; treat as "not capturing"
        // rather than acting on an unreliable status value.
        return false;
    }

    return capture_status == cudaStreamCaptureStatusActive;
}
Expand All @@ -360,7 +382,7 @@ void verify_all_streams_joined(cudaStream_t main_stream) {

// Get all subsidiary streams involved
cudaGraph_t capturing_graph;
cudaStreamGetCaptureInfo(main_stream, &status, NULL, &capturing_graph, NULL, NULL);
cudaStreamGetCaptureInfo_compat(main_stream, &status, NULL, &capturing_graph, NULL, NULL);

// Get all nodes to check for any outstanding work
size_t num_nodes;
Expand Down Expand Up @@ -424,19 +446,19 @@ void printGraphDependencies(cudaGraph_t graph, const char* name) {

// Incoming edges (dependencies)
size_t num_incoming;
cudaGraphNodeGetDependencies(nodes[i], NULL, &num_incoming);
cudaGraphNodeGetDependencies_compat(nodes[i], NULL, &num_incoming);

// Outgoing edges (dependents)
size_t num_outgoing;

cudaGraphNodeGetDependentNodes(nodes[i], NULL, &num_outgoing);
cudaGraphNodeGetDependentNodes_compat(nodes[i], NULL, &num_outgoing);

printf("Node[%zu]: %s | Incoming: %zu | Outgoing: %zu\n", i, getNodeTypeName(type), num_incoming, num_outgoing);

// Print incoming edges
if (num_incoming > 0) {
std::vector<cudaGraphNode_t> deps(num_incoming);
cudaGraphNodeGetDependencies(nodes[i], deps.data(), &num_incoming);
cudaGraphNodeGetDependencies_compat(nodes[i], deps.data(), &num_incoming);

printf(" ← Depends on: ");
for (size_t j = 0; j < num_incoming; j++) {
Expand All @@ -450,7 +472,7 @@ void printGraphDependencies(cudaGraph_t graph, const char* name) {
// Print outgoing edges
if (num_outgoing > 0) {
std::vector<cudaGraphNode_t> dependents(num_outgoing);
cudaGraphNodeGetDependentNodes(nodes[i], dependents.data(), &num_outgoing);
cudaGraphNodeGetDependentNodes_compat(nodes[i], dependents.data(), &num_outgoing);

printf(" → Used by: ");
for (size_t j = 0; j < num_outgoing; j++) {
Expand All @@ -477,7 +499,7 @@ void printGraphDependencies(cudaGraph_t graph, const char* name) {
int source_count = 0;
for (size_t i = 0; i < num_nodes; i++) {
size_t num_deps;
cudaGraphNodeGetDependencies(nodes[i], NULL, &num_deps);
cudaGraphNodeGetDependencies_compat(nodes[i], NULL, &num_deps);
if (num_deps == 0) {
printf("[%zu] ", i);
source_count++;
Expand All @@ -490,7 +512,7 @@ void printGraphDependencies(cudaGraph_t graph, const char* name) {
int sink_count = 0;
for (size_t i = 0; i < num_nodes; i++) {
size_t num_dependents;
cudaGraphNodeGetDependentNodes(nodes[i], NULL, &num_dependents);
cudaGraphNodeGetDependentNodes_compat(nodes[i], NULL, &num_dependents);
if (num_dependents == 0) {
printf("[%zu] ", i);
sink_count++;
Expand All @@ -503,7 +525,7 @@ void printGraphDependencies(cudaGraph_t graph, const char* name) {
bool all_connected = true;
for (size_t i = 0; i < num_nodes; i++) {
size_t num_deps, num_dependents;
cudaGraphNodeGetDependentNodes(nodes[i], NULL, &num_deps);
cudaGraphNodeGetDependentNodes_compat(nodes[i], NULL, &num_deps);

// Orphaned if no incoming AND no outgoing (except sources/sinks)
if (num_deps == 0 && num_dependents == 0 && num_nodes > 1) {
Expand All @@ -527,11 +549,11 @@ bool pathExists(cudaGraphNode_t source, cudaGraphNode_t sink, const std::vector<

// Get dependents of source
size_t num_dependents;
cudaGraphNodeGetDependentNodes(source, NULL, &num_dependents);
cudaGraphNodeGetDependentNodes_compat(source, NULL, &num_dependents);

if (num_dependents > 0) {
std::vector<cudaGraphNode_t> dependents(num_dependents);
cudaGraphNodeGetDependentNodes(source, dependents.data(), &num_dependents);
cudaGraphNodeGetDependentNodes_compat(source, dependents.data(), &num_dependents);

for (auto& dep : dependents) {
if (visited.find((uintptr_t)dep) == visited.end()) {
Expand Down Expand Up @@ -575,17 +597,17 @@ void printGraphDependencies2(cudaGraph_t graph, const char* name) {
cudaGraphNodeGetType(nodes[i], &type);
CudaCheckErrorModNoSync;
size_t num_incoming;
cudaGraphNodeGetDependencies(nodes[i], NULL, &num_incoming);
cudaGraphNodeGetDependencies_compat(nodes[i], NULL, &num_incoming);
CudaCheckErrorModNoSync;
size_t num_outgoing;
cudaGraphNodeGetDependentNodes(nodes[i], NULL, &num_outgoing);
cudaGraphNodeGetDependentNodes_compat(nodes[i], NULL, &num_outgoing);
CudaCheckErrorModNoSync;
printf("Node[%zu]: %s | Incoming: %zu | Outgoing: %zu\n", i, getNodeTypeName(type), num_incoming, num_outgoing);

// Print incoming edges
if (num_incoming > 0) {
std::vector<cudaGraphNode_t> deps(num_incoming);
cudaGraphNodeGetDependencies(nodes[i], deps.data(), &num_incoming);
cudaGraphNodeGetDependencies_compat(nodes[i], deps.data(), &num_incoming);
CudaCheckErrorModNoSync;
printf(" ← Depends on: ");
for (size_t j = 0; j < num_incoming; j++) {
Expand All @@ -599,7 +621,7 @@ void printGraphDependencies2(cudaGraph_t graph, const char* name) {
// Print outgoing edges
if (num_outgoing > 0) {
std::vector<cudaGraphNode_t> dependents(num_outgoing);
cudaGraphNodeGetDependentNodes(nodes[i], dependents.data(), &num_outgoing);
cudaGraphNodeGetDependentNodes_compat(nodes[i], dependents.data(), &num_outgoing);
CudaCheckErrorModNoSync;
printf(" → Used by: ");
for (size_t j = 0; j < num_outgoing; j++) {
Expand Down Expand Up @@ -628,7 +650,7 @@ void printGraphDependencies2(cudaGraph_t graph, const char* name) {
printf("Source nodes: ");
for (size_t i = 0; i < num_nodes; i++) {
size_t num_deps;
cudaGraphNodeGetDependencies(nodes[i], NULL, &num_deps);
cudaGraphNodeGetDependencies_compat(nodes[i], NULL, &num_deps);
if (num_deps == 0) {
printf("[%zu] ", i);
source_nodes.push_back(i);
Expand All @@ -641,7 +663,7 @@ void printGraphDependencies2(cudaGraph_t graph, const char* name) {
printf("Sink nodes: ");
for (size_t i = 0; i < num_nodes; i++) {
size_t num_dependents;
cudaGraphNodeGetDependentNodes(nodes[i], NULL, &num_dependents);
cudaGraphNodeGetDependentNodes_compat(nodes[i], NULL, &num_dependents);
if (num_dependents == 0) {
printf("[%zu] ", i);
sink_nodes.push_back(i);
Expand All @@ -658,13 +680,13 @@ void printGraphDependencies2(cudaGraph_t graph, const char* name) {

for (size_t sink_idx : sink_nodes) {
size_t num_incoming;
cudaGraphNodeGetDependencies(nodes[sink_idx], NULL, &num_incoming);
cudaGraphNodeGetDependencies_compat(nodes[sink_idx], NULL, &num_incoming);
CudaCheckErrorModNoSync;
printf("Sink[%zu]: Has %zu incoming dependencies\n", sink_idx, num_incoming);

// Get dependencies for this sink
std::vector<cudaGraphNode_t> deps(num_incoming);
cudaGraphNodeGetDependencies(nodes[sink_idx], deps.data(), &num_incoming);
cudaGraphNodeGetDependencies_compat(nodes[sink_idx], deps.data(), &num_incoming);
CudaCheckErrorModNoSync;
printf(" Depends on: ");
for (size_t i = 0; i < num_incoming; i++) {
Expand Down Expand Up @@ -747,7 +769,7 @@ void transferKernel(float* src, float* dst, size_t elems, cudaStream_t s, size_t
cudaStreamCaptureStatus _capture_status;
const cudaGraphNode_t* _deps;
size_t _dep_count;
cudaStreamGetCaptureInfo(s, &_capture_status, nullptr, &_capturing_graph, &_deps, &_dep_count);
cudaStreamGetCaptureInfo_compat(s, &_capture_status, nullptr, &_capturing_graph, &_deps, &_dep_count);

cudaGraphNode_t copy_0to1;
cudaMemcpy3DParms memcpyParams = {0};
Expand All @@ -765,7 +787,7 @@ void transferKernel(float* src, float* dst, size_t elems, cudaStream_t s, size_t

cudaGraphAddMemcpyNode(&copy_0to1, _capturing_graph, _deps, _dep_count, &memcpyParams);

cudaStreamUpdateCaptureDependencies(s, &copy_0to1, 1, 1);
cudaStreamUpdateCaptureDependencies_compat(s, &copy_0to1, 1, 1);

} else if (!is_stream_being_captured(s)) {
// p2p_transfer_1d<<<involved_sm, 128, 0, s>>>(src, dst, elems);
Expand Down Expand Up @@ -793,7 +815,7 @@ void transferKernel(float* src, float* dst, size_t elems, cudaStream_t s, size_t
const cudaGraphNode_t* deps;
size_t dep_count;

cudaStreamGetCaptureInfo(s, &capture_status, nullptr, &capturing_graph, &deps, &dep_count);
cudaStreamGetCaptureInfo_compat(s, &capture_status, nullptr, &capturing_graph, &deps, &dep_count);

// ========================================================================
// Create peer access kernel node with V2 parameters (supports attributes)
Expand Down Expand Up @@ -823,7 +845,7 @@ void transferKernel(float* src, float* dst, size_t elems, cudaStream_t s, size_t
cudaGraphKernelNodeSetAttribute(peer_kernel_node, cudaLaunchAttributeMemSyncDomain, &attr_value);

// Update stream dependencies so subsequent work depends on peer kernel
cudaStreamUpdateCaptureDependencies(s, &peer_kernel_node, 1, 1);
cudaStreamUpdateCaptureDependencies_compat(s, &peer_kernel_node, 1, 1);
}
}
__global__ void notify_kernel(volatile uint32_t* gpu_complete_flag, uint32_t value) {
Expand Down