-
Notifications
You must be signed in to change notification settings - Fork 229
Add a "Getting Started" page to the documentation #720
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
1d1171f
Add a CUDA graphs example
shwina f161652
Add a brief description for each of the examples
shwina dc63296
Add getting started page
shwina 08919d9
Load nvrtc using path_finder
shwina aa0e36f
Tell CuPy to use our stream
shwina 5f12bb8
Fix dependency specification
shwina 0c9f25f
Too exciting
shwina 4af0bb7
Don't load nvrtc in cuda.core
shwina 310f391
Revert
shwina 806b9a7
update CI handling of installing cuda.core dependencies when testing …
leofang File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,112 @@ | ||
| # Overview | ||
|
|
||
| ## What is `cuda core`? | ||
|
|
||
| `cuda.core` provides a Pythonic interface to the CUDA runtime and other functionality, | ||
| including: | ||
|
|
||
| - Compiling and launching CUDA kernels | ||
| - Asynchronous concurrent execution with CUDA graphs, streams and events | ||
| - Coordinating work across multiple CUDA devices | ||
| - Allocating, transfering, and managing device memory | ||
| - Runtime linking of device code with Link-Time Optimization (LTO) | ||
| - and much more! | ||
|
|
||
| Rather than providing 1:1 equivalents of the CUDA driver and runtime APIs | ||
| (for that, see [`cuda.bindings`][bindings]), `cuda.core` provides high-level constructs such as: | ||
|
|
||
| - {class}`Device <cuda.core.experimental.Device>` class for GPU device operations and context management. | ||
| - {class}`Buffer <cuda.core.experimental.Buffer>` and {class}`MemoryResource <cuda.core.experimental.MemoryResource>` classes for memory allocation and management. | ||
| - {class}`Program <cuda.core.experimental.Program>` for JIT compilation of CUDA kernels. | ||
| - {class}`GraphBuilder <cuda.core.experimental.GraphBuilder>` for building and executing CUDA graphs. | ||
| - {class}`Stream <cuda.core.experimental.Stream>` and {class}`Event <cuda.core.experimental.Event>` for asynchronous execution and timing. | ||
|
|
||
| ## Example: Compiling and Launching a CUDA kernel | ||
|
|
||
| To get a taste for `cuda.core`, let's walk through a simple example that compiles and launches a vector addition kernel. | ||
| You can find the complete example in [`vector_add.py`][vector_add_example]. | ||
|
|
||
| First, we define a string containing the CUDA C++ kernel. Note that this is a templated kernel: | ||
|
|
||
| ```python | ||
| # compute c = a + b | ||
| code = """ | ||
| template<typename T> | ||
| __global__ void vector_add(const T* A, | ||
| const T* B, | ||
| T* C, | ||
| size_t N) { | ||
| const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x; | ||
| for (size_t i=tid; i<N; i+=gridDim.x*blockDim.x) { | ||
| C[tid] = A[tid] + B[tid]; | ||
| } | ||
| } | ||
| """ | ||
| ``` | ||
|
|
||
| Next, we create a {class}`Device <cuda.core.experimental.Device>` object | ||
| and a corresponding {class}`Stream <cuda.core.experimental.Stream>`. | ||
| Don't forget to use {meth}`Device.set_current() <cuda.core.experimental.Device.set_current>`! | ||
|
|
||
| ```python | ||
| from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch | ||
|
|
||
| dev = Device() | ||
| dev.set_current() | ||
| s = dev.create_stream() | ||
| ``` | ||
|
|
||
| Next, we compile the CUDA C++ kernel from earlier using the {class}`Program <cuda.core.experimental.Program>` class. | ||
| The result of the compilation is saved as a CUBIN. | ||
| Note the use of the `name_expressions` parameter to the {meth}`Program.compile() <cuda.core.experimental.Program.compile>` method to specify which kernel template instantiations to compile: | ||
|
|
||
| ```python | ||
| arch = "".join(f"{i}" for i in dev.compute_capability) | ||
leofang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| program_options = ProgramOptions(std="c++17", arch=f"sm_{arch}") | ||
| prog = Program(code, code_type="c++", options=program_options) | ||
| mod = prog.compile("cubin", name_expressions=("vector_add<float>",)) | ||
| ``` | ||
|
|
||
| Next, we retrieve the compiled kernel from the CUBIN and prepare the arguments and kernel configuration. | ||
| We're using [CuPy][cupy] arrays as inputs for this example, but you can use PyTorch tensors too | ||
| (we show how to do this in one of our [examples][examples]). | ||
|
|
||
| ```python | ||
| ker = mod.get_kernel("vector_add<float>") | ||
|
|
||
| # Prepare input/output arrays (using CuPy) | ||
| size = 50000 | ||
| a = rng.random(size, dtype=cp.float32) | ||
| b = rng.random(size, dtype=cp.float32) | ||
| c = cp.empty_like(a) | ||
|
|
||
| # Configure launch parameters | ||
| block = 256 | ||
| grid = (size + block - 1) // block | ||
| config = LaunchConfig(grid=grid, block=block) | ||
| ``` | ||
|
|
||
| Finally, we use the {func}`launch <cuda.core.experimental.launch>` function to execute our kernel on the specified stream with the given configuration and arguments. Note the use of `.data.ptr` to get the pointer to the array data. | ||
|
|
||
| ```python | ||
| launch(s, config, ker, a.data.ptr, b.data.ptr, c.data.ptr, cp.uint64(size)) | ||
| s.sync() | ||
| ``` | ||
|
|
||
| This example demonstrates one of the core workflows enabled by `cuda.core`: compiling and launching CUDA code. | ||
| Note the clean, Pythonic interface, and absense of any direct calls to the CUDA runtime/driver APIs. | ||
|
|
||
| ## Examples and Recipes | ||
|
|
||
| As we mentioned before, `cuda.core` can do much more than just compile and launch kernels. | ||
|
|
||
| The best way to explore and learn the different features `cuda.core` is through | ||
| our [`examples`][examples]. Find one that matches your use-case, and modify it to fit your needs! | ||
|
|
||
|
|
||
| [bindings]: https://nvidia.github.io/cuda-python/cuda-bindings/latest/ | ||
| [cai]: https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html | ||
| [cupy]: https://cupy.dev/ | ||
| [dlpack]: https://dmlc.github.io/dlpack/latest/ | ||
| [examples]: https://github.com/NVIDIA/cuda-python/tree/main/cuda_core/examples | ||
| [vector_add_example]: https://github.com/NVIDIA/cuda-python/tree/main/cuda_core/examples/vector_add.py | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,172 @@ | ||
| # Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. | ||
| # | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| # ################################################################################ | ||
| # | ||
| # This demo illustrates how to use CUDA graphs to capture and execute | ||
| # multiple kernel launches with minimal overhead. The graph performs a | ||
| # sequence of vector operations: add, multiply, and subtract. | ||
| # | ||
| # ################################################################################ | ||
|
|
||
| import time | ||
|
|
||
| import cupy as cp | ||
|
|
||
| from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch | ||
|
|
||
|
|
||
| def main(): | ||
| # CUDA kernels for vector operations | ||
| code = """ | ||
| template<typename T> | ||
| __global__ void vector_add(const T* A, const T* B, T* C, size_t N) { | ||
| const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x; | ||
| for (size_t i = tid; i < N; i += gridDim.x * blockDim.x) { | ||
| C[i] = A[i] + B[i]; | ||
| } | ||
| } | ||
|
|
||
| template<typename T> | ||
| __global__ void vector_multiply(const T* A, const T* B, T* C, size_t N) { | ||
| const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x; | ||
| for (size_t i = tid; i < N; i += gridDim.x * blockDim.x) { | ||
| C[i] = A[i] * B[i]; | ||
| } | ||
| } | ||
|
|
||
| template<typename T> | ||
| __global__ void vector_subtract(const T* A, const T* B, T* C, size_t N) { | ||
| const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x; | ||
| for (size_t i = tid; i < N; i += gridDim.x * blockDim.x) { | ||
| C[i] = A[i] - B[i]; | ||
| } | ||
| } | ||
| """ | ||
|
|
||
| # Initialize device and stream | ||
| dev = Device() | ||
| dev.set_current() | ||
| stream = dev.create_stream() | ||
| # tell CuPy to use our stream as the current stream: | ||
| cp.cuda.ExternalStream(int(stream.handle)).use() | ||
|
|
||
| # Compile the program | ||
| arch = "".join(f"{i}" for i in dev.compute_capability) | ||
| program_options = ProgramOptions(std="c++17", arch=f"sm_{arch}") | ||
| prog = Program(code, code_type="c++", options=program_options) | ||
| mod = prog.compile( | ||
| "cubin", name_expressions=("vector_add<float>", "vector_multiply<float>", "vector_subtract<float>") | ||
| ) | ||
|
|
||
| # Get kernel functions | ||
| add_kernel = mod.get_kernel("vector_add<float>") | ||
| multiply_kernel = mod.get_kernel("vector_multiply<float>") | ||
| subtract_kernel = mod.get_kernel("vector_subtract<float>") | ||
|
|
||
| # Prepare data | ||
| size = 1000000 | ||
| dtype = cp.float32 | ||
|
|
||
| # Create input arrays | ||
| rng = cp.random.default_rng(42) # Fixed seed for reproducibility | ||
| a = rng.random(size, dtype=dtype) | ||
| b = rng.random(size, dtype=dtype) | ||
| c = rng.random(size, dtype=dtype) | ||
|
|
||
| # Create output arrays | ||
| result1 = cp.empty_like(a) | ||
| result2 = cp.empty_like(a) | ||
| result3 = cp.empty_like(a) | ||
|
|
||
| # Prepare launch configuration | ||
| block_size = 256 | ||
| grid_size = (size + block_size - 1) // block_size | ||
| config = LaunchConfig(grid=grid_size, block=block_size) | ||
|
|
||
| # Sync before graph capture | ||
| dev.sync() | ||
|
|
||
| print("Building CUDA graph...") | ||
|
|
||
| # Build the graph | ||
| graph_builder = stream.create_graph_builder() | ||
| graph_builder.begin_building() | ||
|
|
||
| # Add multiple kernel launches to the graph | ||
| # Kernel 1: result1 = a + b | ||
| launch(graph_builder, config, add_kernel, a.data.ptr, b.data.ptr, result1.data.ptr, cp.uint64(size)) | ||
|
|
||
| # Kernel 2: result2 = result1 * c | ||
| launch(graph_builder, config, multiply_kernel, result1.data.ptr, c.data.ptr, result2.data.ptr, cp.uint64(size)) | ||
|
|
||
| # Kernel 3: result3 = result2 - a | ||
| launch(graph_builder, config, subtract_kernel, result2.data.ptr, a.data.ptr, result3.data.ptr, cp.uint64(size)) | ||
|
|
||
| # Complete the graph | ||
| graph = graph_builder.end_building().complete() | ||
|
|
||
| print("Graph built successfully!") | ||
|
|
||
| # Upload the graph to the stream | ||
| graph.upload(stream) | ||
|
|
||
| # Execute the entire graph with a single launch | ||
| print("Executing graph...") | ||
| start_time = time.time() | ||
| graph.launch(stream) | ||
| stream.sync() | ||
| end_time = time.time() | ||
|
|
||
| graph_execution_time = end_time - start_time | ||
| print(f"Graph execution time: {graph_execution_time:.6f} seconds") | ||
|
|
||
| # Verify results | ||
| expected_result1 = a + b | ||
| expected_result2 = expected_result1 * c | ||
| expected_result3 = expected_result2 - a | ||
|
|
||
| print("Verifying results...") | ||
| assert cp.allclose(result1, expected_result1, rtol=1e-5, atol=1e-5), "Result 1 mismatch" | ||
| assert cp.allclose(result2, expected_result2, rtol=1e-5, atol=1e-5), "Result 2 mismatch" | ||
| assert cp.allclose(result3, expected_result3, rtol=1e-5, atol=1e-5), "Result 3 mismatch" | ||
| print("All results verified successfully!") | ||
|
|
||
| # Demonstrate performance benefit by running the same operations without graph | ||
| print("\nRunning same operations without graph for comparison...") | ||
|
|
||
| # Reset results | ||
| result1.fill(0) | ||
| result2.fill(0) | ||
| result3.fill(0) | ||
|
|
||
| start_time = time.time() | ||
|
|
||
| # Individual kernel launches | ||
| launch(stream, config, add_kernel, a.data.ptr, b.data.ptr, result1.data.ptr, cp.uint64(size)) | ||
| launch(stream, config, multiply_kernel, result1.data.ptr, c.data.ptr, result2.data.ptr, cp.uint64(size)) | ||
| launch(stream, config, subtract_kernel, result2.data.ptr, a.data.ptr, result3.data.ptr, cp.uint64(size)) | ||
|
|
||
| stream.sync() | ||
| end_time = time.time() | ||
|
|
||
| individual_execution_time = end_time - start_time | ||
| print(f"Individual kernel execution time: {individual_execution_time:.6f} seconds") | ||
|
|
||
| # Calculate speedup | ||
| speedup = individual_execution_time / graph_execution_time | ||
| print(f"Graph provides {speedup:.2f}x speedup") | ||
|
|
||
| # Verify results again | ||
| assert cp.allclose(result1, expected_result1, rtol=1e-5, atol=1e-5), "Result 1 mismatch" | ||
| assert cp.allclose(result2, expected_result2, rtol=1e-5, atol=1e-5), "Result 2 mismatch" | ||
| assert cp.allclose(result3, expected_result3, rtol=1e-5, atol=1e-5), "Result 3 mismatch" | ||
|
|
||
| cp.cuda.Stream.null.use() # reset CuPy's current stream to the null stream | ||
|
|
||
| print("\nExample completed successfully!") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.