
Commit 1d1171f

Add a CUDA graphs example
1 parent fd8e07b commit 1d1171f

File tree

1 file changed: +168 −0

cuda_core/examples/cuda_graphs.py

Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
# Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: Apache-2.0

# ################################################################################
#
# This demo illustrates how to use CUDA graphs to capture and execute
# multiple kernel launches with minimal overhead. The graph performs a
# sequence of vector operations: add, multiply, and subtract.
#
# ################################################################################

import time

import cupy as cp

from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch

def main():
    # CUDA kernels for vector operations
    code = """
    template<typename T>
    __global__ void vector_add(const T* A, const T* B, T* C, size_t N) {
        const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x;
        for (size_t i = tid; i < N; i += gridDim.x * blockDim.x) {
            C[i] = A[i] + B[i];
        }
    }

    template<typename T>
    __global__ void vector_multiply(const T* A, const T* B, T* C, size_t N) {
        const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x;
        for (size_t i = tid; i < N; i += gridDim.x * blockDim.x) {
            C[i] = A[i] * B[i];
        }
    }

    template<typename T>
    __global__ void vector_subtract(const T* A, const T* B, T* C, size_t N) {
        const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x;
        for (size_t i = tid; i < N; i += gridDim.x * blockDim.x) {
            C[i] = A[i] - B[i];
        }
    }
    """

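    # Note on the kernels above: each uses a grid-stride loop, so correctness
    # does not depend on the grid exactly covering N elements; whatever grid
    # size is launched, every element of the array gets processed.
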
    # Initialize device and stream
    dev = Device()
    dev.set_current()
    stream = dev.create_stream()

    # Compile the program
    arch = "".join(f"{i}" for i in dev.compute_capability)
    program_options = ProgramOptions(std="c++17", arch=f"sm_{arch}")
    prog = Program(code, code_type="c++", options=program_options)
    mod = prog.compile(
        "cubin", name_expressions=("vector_add<float>", "vector_multiply<float>", "vector_subtract<float>")
    )

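    # Note: the kernels are C++ templates, so every instantiation used below
    # must be listed in name_expressions when compiling; get_kernel() then
    # looks up the matching compiled symbol in the cubin.
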
    # Get kernel functions
    add_kernel = mod.get_kernel("vector_add<float>")
    multiply_kernel = mod.get_kernel("vector_multiply<float>")
    subtract_kernel = mod.get_kernel("vector_subtract<float>")

    # Prepare data
    size = 1000000
    dtype = cp.float32

    # Create input arrays
    rng = cp.random.default_rng(42)  # Fixed seed for reproducibility
    a = rng.random(size, dtype=dtype)
    b = rng.random(size, dtype=dtype)
    c = rng.random(size, dtype=dtype)

    # Create output arrays
    result1 = cp.empty_like(a)
    result2 = cp.empty_like(a)
    result3 = cp.empty_like(a)

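    # Note: these CuPy arrays live in device memory; `.data.ptr` (used in the
    # launches below) exposes the raw device pointers passed as kernel
    # arguments, and cp.uint64(size) matches the kernels' size_t parameter.
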
    # Prepare launch configuration
    block_size = 256
    grid_size = (size + block_size - 1) // block_size
    config = LaunchConfig(grid=grid_size, block=block_size)

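    # Worked out: with size = 1,000,000 and block_size = 256, the ceiling
    # division gives grid_size = 3907 (3906 * 256 = 999,936 would fall short).
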
    # Sync before graph capture
    dev.sync()

    print("Building CUDA graph...")

    # Build the graph
    graph_builder = stream.create_graph_builder()
    graph_builder.begin_building()

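    # While building, launches issued against the graph builder are recorded
    # rather than executed; no GPU work runs until the completed graph is
    # launched further below.
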
    # Add multiple kernel launches to the graph
    # Kernel 1: result1 = a + b
    launch(graph_builder, config, add_kernel, a.data.ptr, b.data.ptr, result1.data.ptr, cp.uint64(size))

    # Kernel 2: result2 = result1 * c
    launch(graph_builder, config, multiply_kernel, result1.data.ptr, c.data.ptr, result2.data.ptr, cp.uint64(size))

    # Kernel 3: result3 = result2 - a
    launch(graph_builder, config, subtract_kernel, result2.data.ptr, a.data.ptr, result3.data.ptr, cp.uint64(size))

    # Complete the graph
    graph = graph_builder.end_building().complete()

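    # Because each kernel consumes the previous kernel's output, the capture
    # order above yields a linear dependency chain: add -> multiply -> subtract.
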
    print("Graph built successfully!")

    # Upload the graph to the stream
    graph.upload(stream)

    # Execute the entire graph with a single launch
    print("Executing graph...")
    start_time = time.time()
    graph.launch(stream)
    stream.sync()
    end_time = time.time()

    graph_execution_time = end_time - start_time
    print(f"Graph execution time: {graph_execution_time:.6f} seconds")

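    # Note: the graph was uploaded before the timed region, so the measurement
    # covers one launch of the whole three-kernel graph plus the sync, not the
    # one-time build and instantiation cost.
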
    # Verify results
    expected_result1 = a + b
    expected_result2 = expected_result1 * c
    expected_result3 = expected_result2 - a

    print("Verifying results...")
    assert cp.allclose(result1, expected_result1, rtol=1e-5, atol=1e-5), "Result 1 mismatch"
    assert cp.allclose(result2, expected_result2, rtol=1e-5, atol=1e-5), "Result 2 mismatch"
    assert cp.allclose(result3, expected_result3, rtol=1e-5, atol=1e-5), "Result 3 mismatch"
    print("All results verified successfully!")

    # Demonstrate the performance benefit by running the same operations without a graph
    print("\nRunning the same operations without a graph for comparison...")

    # Reset results
    result1.fill(0)
    result2.fill(0)
    result3.fill(0)

    start_time = time.time()

    # Individual kernel launches
    launch(stream, config, add_kernel, a.data.ptr, b.data.ptr, result1.data.ptr, cp.uint64(size))
    launch(stream, config, multiply_kernel, result1.data.ptr, c.data.ptr, result2.data.ptr, cp.uint64(size))
    launch(stream, config, subtract_kernel, result2.data.ptr, a.data.ptr, result3.data.ptr, cp.uint64(size))

    stream.sync()
    end_time = time.time()

    individual_execution_time = end_time - start_time
    print(f"Individual kernel execution time: {individual_execution_time:.6f} seconds")

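    # Caveat: with three launches of fairly large kernels, the gap measured here
    # is mostly CPU-side launch overhead; graphs typically pay off most when many
    # short kernels are replayed repeatedly, and a single wall-clock sample is
    # noisy, so treat the reported speedup as indicative only.
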
    # Calculate speedup
    speedup = individual_execution_time / graph_execution_time
    print(f"Graph provides {speedup:.2f}x speedup")

    # Verify results again
    assert cp.allclose(result1, expected_result1, rtol=1e-5, atol=1e-5), "Result 1 mismatch"
    assert cp.allclose(result2, expected_result2, rtol=1e-5, atol=1e-5), "Result 2 mismatch"
    assert cp.allclose(result3, expected_result3, rtol=1e-5, atol=1e-5), "Result 3 mismatch"

    print("\nExample completed successfully!")


if __name__ == "__main__":
    main()
