From 0784b39e47a85e7c2890256cae5108b8672472ea Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Wed, 18 Jun 2025 09:47:52 -0400 Subject: [PATCH 01/11] Add memory ops example --- cuda_core/examples/memory_ops.py | 163 +++++++++++++++++++++++++++++++ 1 file changed, 163 insertions(+) create mode 100644 cuda_core/examples/memory_ops.py diff --git a/cuda_core/examples/memory_ops.py b/cuda_core/examples/memory_ops.py new file mode 100644 index 0000000000..ef8f736252 --- /dev/null +++ b/cuda_core/examples/memory_ops.py @@ -0,0 +1,163 @@ +import cupy as cp +import numpy as np +from cuda.core.experimental import ( + Device, LaunchConfig, Program, ProgramOptions, launch, + DeviceMemoryResource, LegacyPinnedMemoryResource, Buffer +) +from cuda.core.experimental._memory import MemoryResource +from cuda.core.experimental._utils.cuda_utils import handle_return +from cuda.bindings import driver + +# Kernel for memory operations +code = """ +extern "C" +__global__ void memory_ops(float* device_data, + float* pinned_data, + size_t N) { + const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid < N) { + // Access device memory + device_data[tid] = device_data[tid] + 1.0f; + + // Access pinned memory (zero-copy from GPU) + pinned_data[tid] = pinned_data[tid] * 3.0f; + } +} +""" + +dev = Device() +dev.set_current() +stream = dev.create_stream() + +# Compile kernel +arch = "".join(f"{i}" for i in dev.compute_capability) +program_options = ProgramOptions(std="c++17", arch=f"sm_{arch}") +prog = Program(code, code_type="c++", options=program_options) +mod = prog.compile("cubin") +kernel = mod.get_kernel("memory_ops") + +# Create different memory resources +device_mr = DeviceMemoryResource(dev.device_id) +pinned_mr = LegacyPinnedMemoryResource() + +# Allocate different types of memory +size = 1024 +dtype = cp.float32 +element_size = dtype().itemsize +total_size = size * element_size + +# 1. Device Memory (GPU-only) +device_buffer = device_mr.allocate(total_size, stream=stream) +device_array = cp.ndarray( + size, dtype=dtype, + memptr=cp.cuda.MemoryPointer( + cp.cuda.UnownedMemory(int(device_buffer.handle), device_buffer.size, device_buffer), 0 + ) +) + +# 2. 
Pinned Memory (CPU memory, GPU accessible) +pinned_buffer = pinned_mr.allocate(total_size, stream=stream) +pinned_array = cp.ndarray( + size, dtype=dtype, + memptr=cp.cuda.MemoryPointer( + cp.cuda.UnownedMemory(int(pinned_buffer.handle), pinned_buffer.size, pinned_buffer), 0 + ) +) + +# Initialize data +rng = cp.random.default_rng() +device_array[:] = rng.random(size, dtype=dtype) +pinned_array[:] = rng.random(size, dtype=dtype) + +# Store original values for verification +device_original = device_array.copy() +pinned_original = pinned_array.copy() + +# Sync before kernel launch +dev.sync() + +# Launch kernel +block = 256 +grid = (size + block - 1) // block +config = LaunchConfig(grid=grid, block=block) + +launch(stream, config, kernel, + device_buffer, pinned_buffer, cp.uint64(size)) +stream.sync() + +# Verify kernel operations +assert cp.allclose(device_array, device_original + 1.0), "Device memory operation failed" +assert cp.allclose(pinned_array, pinned_original * 3.0), "Pinned memory operation failed" + +# Demonstrate buffer copying operations +print("Memory buffer properties:") +print(f"Device buffer - Device accessible: {device_buffer.is_device_accessible}") +print(f"Pinned buffer - Device accessible: {pinned_buffer.is_device_accessible}") + +# Assert memory properties +assert device_buffer.is_device_accessible, "Device buffer should be device accessible" +assert not device_buffer.is_host_accessible, "Device buffer should not be host accessible" +assert pinned_buffer.is_device_accessible, "Pinned buffer should be device accessible" +assert pinned_buffer.is_host_accessible, "Pinned buffer should be host accessible" + +# Copy data between different memory types +print("\nCopying data between memory types...") + +# Copy from device to pinned memory +device_buffer.copy_to(pinned_buffer, stream=stream) +stream.sync() + +# Verify the copy operation +assert cp.allclose(pinned_array, device_array), "Device to pinned copy failed" + +# Create a new device buffer and copy from pinned +new_device_buffer = device_mr.allocate(total_size, stream=stream) +new_device_array = cp.ndarray( + size, dtype=dtype, + memptr=cp.cuda.MemoryPointer( + cp.cuda.UnownedMemory(int(new_device_buffer.handle), new_device_buffer.size, new_device_buffer), 0 + ) +) + +pinned_buffer.copy_to(new_device_buffer, stream=stream) +stream.sync() + +# Verify the copy operation +assert cp.allclose(new_device_array, pinned_array), "Pinned to device copy failed" + +# Demonstrate DLPack integration +print("\nDLPack device information:") +print(f"Device buffer DLPack device: {device_buffer.__dlpack_device__()}") +print(f"Pinned buffer DLPack device: {pinned_buffer.__dlpack_device__()}") + +# Assert DLPack device types +from cuda.core.experimental._memory import DLDeviceType + +device_dlpack = device_buffer.__dlpack_device__() +pinned_dlpack = pinned_buffer.__dlpack_device__() + +assert device_dlpack[0] == DLDeviceType.kDLCUDA, "Device buffer should have CUDA device type" +assert pinned_dlpack[0] == DLDeviceType.kDLCUDAHost, "Pinned buffer should have CUDA host device type" + +# Test buffer size properties +assert device_buffer.size == total_size, f"Device buffer size mismatch: expected {total_size}, got {device_buffer.size}" +assert pinned_buffer.size == total_size, f"Pinned buffer size mismatch: expected {total_size}, got {pinned_buffer.size}" +assert new_device_buffer.size == total_size, f"New device buffer size mismatch: expected {total_size}, got {new_device_buffer.size}" + +# Test memory resource properties +assert 
device_buffer.memory_resource == device_mr, "Device buffer should use device memory resource" +assert pinned_buffer.memory_resource == pinned_mr, "Pinned buffer should use pinned memory resource" +assert new_device_buffer.memory_resource == device_mr, "New device buffer should use device memory resource" + +# Clean up +device_buffer.close(stream) +pinned_buffer.close(stream) +new_device_buffer.close(stream) +stream.close() + +# Verify buffers are properly closed +assert device_buffer.handle == 0, "Device buffer should be closed" +assert pinned_buffer.handle == 0, "Pinned buffer should be closed" +assert new_device_buffer.handle == 0, "New device buffer should be closed" + +print("Memory management example completed!") \ No newline at end of file From e1d1f9ecec20979cdbfeea353cc1d88ac683f813 Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Wed, 18 Jun 2025 10:08:14 -0400 Subject: [PATCH 02/11] Fix handling of buffer with int handle --- cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx b/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx index edd2ab2c56..074724b7e0 100644 --- a/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx +++ b/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx @@ -212,7 +212,13 @@ cdef class ParamHolder: for i, arg in enumerate(kernel_args): if isinstance(arg, Buffer): # we need the address of where the actual buffer address is stored - self.data_addresses[i] = (arg.handle.getPtr()) + if isinstance(arg.handle, int): + # see note below on handling int arguments + prepare_arg[intptr_t](self.data, self.data_addresses, arg.handle, i) + continue + else: + # it's a CUdeviceptr: + self.data_addresses[i] = (arg.handle.getPtr()) continue elif isinstance(arg, int): # Here's the dilemma: We want to have a fast path to pass in Python From fb65ab89692e3daff8d87164bed6370db42dbf86 Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Wed, 18 Jun 2025 10:27:40 -0400 Subject: [PATCH 03/11] pre-commit fixes --- cuda_core/examples/memory_ops.py | 49 +++++++++++++++++++------------- 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/cuda_core/examples/memory_ops.py b/cuda_core/examples/memory_ops.py index ef8f736252..56d4db64f3 100644 --- a/cuda_core/examples/memory_ops.py +++ b/cuda_core/examples/memory_ops.py @@ -1,24 +1,31 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# +# SPDX-License-Identifier: Apache-2.0 + import cupy as cp -import numpy as np + from cuda.core.experimental import ( - Device, LaunchConfig, Program, ProgramOptions, launch, - DeviceMemoryResource, LegacyPinnedMemoryResource, Buffer + Device, + DeviceMemoryResource, + LaunchConfig, + LegacyPinnedMemoryResource, + Program, + ProgramOptions, + launch, ) -from cuda.core.experimental._memory import MemoryResource -from cuda.core.experimental._utils.cuda_utils import handle_return -from cuda.bindings import driver +from cuda.core.experimental._dlpack import DLDeviceType # Kernel for memory operations code = """ extern "C" -__global__ void memory_ops(float* device_data, +__global__ void memory_ops(float* device_data, float* pinned_data, size_t N) { const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x; if (tid < N) { // Access device memory device_data[tid] = device_data[tid] + 1.0f; - + // Access pinned memory (zero-copy from GPU) pinned_data[tid] = pinned_data[tid] * 3.0f; } @@ -49,19 +56,21 @@ # 1. 
Device Memory (GPU-only) device_buffer = device_mr.allocate(total_size, stream=stream) device_array = cp.ndarray( - size, dtype=dtype, + size, + dtype=dtype, memptr=cp.cuda.MemoryPointer( cp.cuda.UnownedMemory(int(device_buffer.handle), device_buffer.size, device_buffer), 0 - ) + ), ) # 2. Pinned Memory (CPU memory, GPU accessible) pinned_buffer = pinned_mr.allocate(total_size, stream=stream) pinned_array = cp.ndarray( - size, dtype=dtype, + size, + dtype=dtype, memptr=cp.cuda.MemoryPointer( cp.cuda.UnownedMemory(int(pinned_buffer.handle), pinned_buffer.size, pinned_buffer), 0 - ) + ), ) # Initialize data @@ -81,8 +90,7 @@ grid = (size + block - 1) // block config = LaunchConfig(grid=grid, block=block) -launch(stream, config, kernel, - device_buffer, pinned_buffer, cp.uint64(size)) +launch(stream, config, kernel, device_buffer, pinned_buffer, cp.uint64(size)) stream.sync() # Verify kernel operations @@ -113,10 +121,11 @@ # Create a new device buffer and copy from pinned new_device_buffer = device_mr.allocate(total_size, stream=stream) new_device_array = cp.ndarray( - size, dtype=dtype, + size, + dtype=dtype, memptr=cp.cuda.MemoryPointer( cp.cuda.UnownedMemory(int(new_device_buffer.handle), new_device_buffer.size, new_device_buffer), 0 - ) + ), ) pinned_buffer.copy_to(new_device_buffer, stream=stream) @@ -131,8 +140,6 @@ print(f"Pinned buffer DLPack device: {pinned_buffer.__dlpack_device__()}") # Assert DLPack device types -from cuda.core.experimental._memory import DLDeviceType - device_dlpack = device_buffer.__dlpack_device__() pinned_dlpack = pinned_buffer.__dlpack_device__() @@ -142,7 +149,9 @@ # Test buffer size properties assert device_buffer.size == total_size, f"Device buffer size mismatch: expected {total_size}, got {device_buffer.size}" assert pinned_buffer.size == total_size, f"Pinned buffer size mismatch: expected {total_size}, got {pinned_buffer.size}" -assert new_device_buffer.size == total_size, f"New device buffer size mismatch: expected {total_size}, got {new_device_buffer.size}" +assert new_device_buffer.size == total_size, ( + f"New device buffer size mismatch: expected {total_size}, got {new_device_buffer.size}" +) # Test memory resource properties assert device_buffer.memory_resource == device_mr, "Device buffer should use device memory resource" @@ -160,4 +169,4 @@ assert pinned_buffer.handle == 0, "Pinned buffer should be closed" assert new_device_buffer.handle == 0, "New device buffer should be closed" -print("Memory management example completed!") \ No newline at end of file +print("Memory management example completed!") From 55bd3b135ebf3068aa5cb60704c838fd13bbe1d8 Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Wed, 18 Jun 2025 15:10:45 -0400 Subject: [PATCH 04/11] Simplify pinned memory example --- cuda_core/examples/memory_ops.py | 50 ++++---------------------------- 1 file changed, 5 insertions(+), 45 deletions(-) diff --git a/cuda_core/examples/memory_ops.py b/cuda_core/examples/memory_ops.py index 56d4db64f3..92411ae8db 100644 --- a/cuda_core/examples/memory_ops.py +++ b/cuda_core/examples/memory_ops.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import cupy as cp +import numpy as np from cuda.core.experimental import ( Device, @@ -55,28 +56,16 @@ # 1. 
Device Memory (GPU-only) device_buffer = device_mr.allocate(total_size, stream=stream) -device_array = cp.ndarray( - size, - dtype=dtype, - memptr=cp.cuda.MemoryPointer( - cp.cuda.UnownedMemory(int(device_buffer.handle), device_buffer.size, device_buffer), 0 - ), -) +device_array = cp.from_dlpack(device_buffer).view(dtype=dtype) # 2. Pinned Memory (CPU memory, GPU accessible) pinned_buffer = pinned_mr.allocate(total_size, stream=stream) -pinned_array = cp.ndarray( - size, - dtype=dtype, - memptr=cp.cuda.MemoryPointer( - cp.cuda.UnownedMemory(int(pinned_buffer.handle), pinned_buffer.size, pinned_buffer), 0 - ), -) +pinned_array = np.from_dlpack(pinned_buffer).view(dtype=dtype) # Initialize data rng = cp.random.default_rng() device_array[:] = rng.random(size, dtype=dtype) -pinned_array[:] = rng.random(size, dtype=dtype) +pinned_array[:] = rng.random(size, dtype=dtype).get() # Store original values for verification device_original = device_array.copy() @@ -97,17 +86,6 @@ assert cp.allclose(device_array, device_original + 1.0), "Device memory operation failed" assert cp.allclose(pinned_array, pinned_original * 3.0), "Pinned memory operation failed" -# Demonstrate buffer copying operations -print("Memory buffer properties:") -print(f"Device buffer - Device accessible: {device_buffer.is_device_accessible}") -print(f"Pinned buffer - Device accessible: {pinned_buffer.is_device_accessible}") - -# Assert memory properties -assert device_buffer.is_device_accessible, "Device buffer should be device accessible" -assert not device_buffer.is_host_accessible, "Device buffer should not be host accessible" -assert pinned_buffer.is_device_accessible, "Pinned buffer should be device accessible" -assert pinned_buffer.is_host_accessible, "Pinned buffer should be host accessible" - # Copy data between different memory types print("\nCopying data between memory types...") @@ -120,13 +98,7 @@ # Create a new device buffer and copy from pinned new_device_buffer = device_mr.allocate(total_size, stream=stream) -new_device_array = cp.ndarray( - size, - dtype=dtype, - memptr=cp.cuda.MemoryPointer( - cp.cuda.UnownedMemory(int(new_device_buffer.handle), new_device_buffer.size, new_device_buffer), 0 - ), -) +new_device_array = cp.from_dlpack(new_device_buffer).view(dtype=dtype) pinned_buffer.copy_to(new_device_buffer, stream=stream) stream.sync() @@ -146,18 +118,6 @@ assert device_dlpack[0] == DLDeviceType.kDLCUDA, "Device buffer should have CUDA device type" assert pinned_dlpack[0] == DLDeviceType.kDLCUDAHost, "Pinned buffer should have CUDA host device type" -# Test buffer size properties -assert device_buffer.size == total_size, f"Device buffer size mismatch: expected {total_size}, got {device_buffer.size}" -assert pinned_buffer.size == total_size, f"Pinned buffer size mismatch: expected {total_size}, got {pinned_buffer.size}" -assert new_device_buffer.size == total_size, ( - f"New device buffer size mismatch: expected {total_size}, got {new_device_buffer.size}" -) - -# Test memory resource properties -assert device_buffer.memory_resource == device_mr, "Device buffer should use device memory resource" -assert pinned_buffer.memory_resource == pinned_mr, "Pinned buffer should use pinned memory resource" -assert new_device_buffer.memory_resource == device_mr, "New device buffer should use device memory resource" - # Clean up device_buffer.close(stream) pinned_buffer.close(stream) From 70f0da6d6e78a776dae6a418b054c3c09cac7621 Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Wed, 18 Jun 2025 15:45:41 -0400 Subject: [PATCH 
05/11] Copy pinned memory tests to test_launcher.py --- cuda_core/examples/memory_ops.py | 2 +- cuda_core/tests/test_launcher.py | 106 ++++++++++++++++++++++++++++++- 2 files changed, 106 insertions(+), 2 deletions(-) diff --git a/cuda_core/examples/memory_ops.py b/cuda_core/examples/memory_ops.py index 92411ae8db..6d54f49eb6 100644 --- a/cuda_core/examples/memory_ops.py +++ b/cuda_core/examples/memory_ops.py @@ -72,7 +72,7 @@ pinned_original = pinned_array.copy() # Sync before kernel launch -dev.sync() +stream.sync() # Launch kernel block = 256 diff --git a/cuda_core/tests/test_launcher.py b/cuda_core/tests/test_launcher.py index 3a02065de8..75466a12e0 100644 --- a/cuda_core/tests/test_launcher.py +++ b/cuda_core/tests/test_launcher.py @@ -9,7 +9,15 @@ import pytest from conftest import skipif_need_cuda_headers -from cuda.core.experimental import Device, LaunchConfig, LegacyPinnedMemoryResource, Program, ProgramOptions, launch +from cuda.core.experimental import ( + Device, + DeviceMemoryResource, + LaunchConfig, + LegacyPinnedMemoryResource, + Program, + ProgramOptions, + launch, +) def test_launch_config_init(init_cuda): @@ -197,3 +205,99 @@ def test_cooperative_launch(): config = LaunchConfig(grid=1, block=1, cooperative_launch=True) launch(s, config, ker) s.sync() + + +@pytest.mark.parametrize( + "memory_resource_class", + [ + DeviceMemoryResource, + LegacyPinnedMemoryResource, + ], +) +def test_launch_with_buffers_allocated_by_memory_resource(init_cuda, memory_resource_class): + """Test that kernels can access memory allocated by memory resources.""" + dev = Device() + dev.set_current() + stream = dev.create_stream() + + # Kernel that operates on memory + code = """ + extern "C" + __global__ void memory_ops(float* data, size_t N) { + const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid < N) { + // Access memory (device or pinned) + data[tid] = data[tid] * 3.0f; + } + } + """ + + # Compile kernel + arch = "".join(f"{i}" for i in dev.compute_capability) + program_options = ProgramOptions(std="c++17", arch=f"sm_{arch}") + prog = Program(code, code_type="c++", options=program_options) + mod = prog.compile("cubin") + kernel = mod.get_kernel("memory_ops") + + # Create memory resource + if memory_resource_class == DeviceMemoryResource: + mr = memory_resource_class(dev.device_id) + else: # LegacyPinnedMemoryResource + mr = memory_resource_class() + + # Allocate memory + size = 1024 + dtype = np.float32 + element_size = dtype().itemsize + total_size = size * element_size + + buffer = mr.allocate(total_size, stream=stream) + + # Create array view based on memory type + if mr.is_host_accessible: + # For pinned memory, use numpy + array = np.from_dlpack(buffer).view(dtype=dtype) + else: + # For device memory, use cupy + import cupy as cp + + array = cp.from_dlpack(buffer).view(dtype=dtype) + + # Initialize data with random values + if mr.is_host_accessible: + rng = np.random.default_rng() + array[:] = rng.random(size, dtype=dtype) + else: + import cupy as cp + + rng = cp.random.default_rng() + array[:] = rng.random(size, dtype=dtype) + + # Store original values for verification + original = array.copy() + + # Sync before kernel launch + stream.sync() + + # Launch kernel + block = 256 + grid = (size + block - 1) // block + config = LaunchConfig(grid=grid, block=block) + + launch(stream, config, kernel, buffer, np.uint64(size)) + stream.sync() + + # Verify kernel operations + if mr.is_host_accessible: + assert np.allclose(array, original * 3.0), 
f"{memory_resource_class.__name__} operation failed" + else: + import cupy as cp + + assert cp.allclose(array, original * 3.0), f"{memory_resource_class.__name__} operation failed" + + # Clean up + buffer.close(stream) + stream.close() + + # Verify buffer is properly closed + assert buffer.handle == 0, f"{memory_resource_class.__name__} buffer should be closed" From 40df59e89068af79353f44b1d5d6c3a2f8113ec8 Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Tue, 24 Jun 2025 09:10:59 -0400 Subject: [PATCH 06/11] Remove dlpack assertions and address other review comments --- cuda_core/examples/memory_ops.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/cuda_core/examples/memory_ops.py b/cuda_core/examples/memory_ops.py index 6d54f49eb6..eae086db53 100644 --- a/cuda_core/examples/memory_ops.py +++ b/cuda_core/examples/memory_ops.py @@ -14,7 +14,6 @@ ProgramOptions, launch, ) -from cuda.core.experimental._dlpack import DLDeviceType # Kernel for memory operations code = """ @@ -36,6 +35,8 @@ dev = Device() dev.set_current() stream = dev.create_stream() +# tell CuPy to use our stream as the current stream: +cp.cuda.ExternalStream(int(stream.handle)).use() # Compile kernel arch = "".join(f"{i}" for i in dev.compute_capability) @@ -106,23 +107,12 @@ # Verify the copy operation assert cp.allclose(new_device_array, pinned_array), "Pinned to device copy failed" -# Demonstrate DLPack integration -print("\nDLPack device information:") -print(f"Device buffer DLPack device: {device_buffer.__dlpack_device__()}") -print(f"Pinned buffer DLPack device: {pinned_buffer.__dlpack_device__()}") - -# Assert DLPack device types -device_dlpack = device_buffer.__dlpack_device__() -pinned_dlpack = pinned_buffer.__dlpack_device__() - -assert device_dlpack[0] == DLDeviceType.kDLCUDA, "Device buffer should have CUDA device type" -assert pinned_dlpack[0] == DLDeviceType.kDLCUDAHost, "Pinned buffer should have CUDA host device type" - # Clean up device_buffer.close(stream) pinned_buffer.close(stream) new_device_buffer.close(stream) stream.close() +cp.cuda.Stream.null.use() # reset CuPy's current stream to the null stream # Verify buffers are properly closed assert device_buffer.handle == 0, "Device buffer should be closed" From 0b2e2077aa3b73bf940f9d7b381544de5346b34b Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Tue, 24 Jun 2025 11:30:48 -0400 Subject: [PATCH 07/11] Try addressing issues that may be causing CI failures --- cuda_core/examples/memory_ops.py | 16 ++++++++++++++++ cuda_core/tests/test_launcher.py | 17 ++++++----------- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/cuda_core/examples/memory_ops.py b/cuda_core/examples/memory_ops.py index eae086db53..bde4243af4 100644 --- a/cuda_core/examples/memory_ops.py +++ b/cuda_core/examples/memory_ops.py @@ -2,6 +2,18 @@ # # SPDX-License-Identifier: Apache-2.0 +# ################################################################################ +# +# This demo illustrates: +# +# 1. How to use different memory resources to allocate and manage memory +# 2. How to copy data between different memory types +# 3. 
How to use DLPack to interoperate with other libraries +# +# ################################################################################ + +import sys + import cupy as cp import numpy as np @@ -15,6 +27,10 @@ launch, ) +if np.__version__ < "2.1.0": + print("This example requires NumPy 2.1.0 or later", file=sys.stderr) + sys.exit(0) + # Kernel for memory operations code = """ extern "C" diff --git a/cuda_core/tests/test_launcher.py b/cuda_core/tests/test_launcher.py index 75466a12e0..9b632c628a 100644 --- a/cuda_core/tests/test_launcher.py +++ b/cuda_core/tests/test_launcher.py @@ -5,6 +5,7 @@ import os import pathlib +import cupy as cp import numpy as np import pytest from conftest import skipif_need_cuda_headers @@ -219,6 +220,8 @@ def test_launch_with_buffers_allocated_by_memory_resource(init_cuda, memory_reso dev = Device() dev.set_current() stream = dev.create_stream() + # tell CuPy to use our stream as the current stream: + cp.cuda.ExternalStream(int(stream.handle)).use() # Kernel that operates on memory code = """ @@ -258,9 +261,6 @@ def test_launch_with_buffers_allocated_by_memory_resource(init_cuda, memory_reso # For pinned memory, use numpy array = np.from_dlpack(buffer).view(dtype=dtype) else: - # For device memory, use cupy - import cupy as cp - array = cp.from_dlpack(buffer).view(dtype=dtype) # Initialize data with random values @@ -268,8 +268,6 @@ def test_launch_with_buffers_allocated_by_memory_resource(init_cuda, memory_reso rng = np.random.default_rng() array[:] = rng.random(size, dtype=dtype) else: - import cupy as cp - rng = cp.random.default_rng() array[:] = rng.random(size, dtype=dtype) @@ -288,16 +286,13 @@ def test_launch_with_buffers_allocated_by_memory_resource(init_cuda, memory_reso stream.sync() # Verify kernel operations - if mr.is_host_accessible: - assert np.allclose(array, original * 3.0), f"{memory_resource_class.__name__} operation failed" - else: - import cupy as cp - - assert cp.allclose(array, original * 3.0), f"{memory_resource_class.__name__} operation failed" + assert cp.allclose(array, original * 3.0), f"{memory_resource_class.__name__} operation failed" # Clean up buffer.close(stream) stream.close() + cp.cuda.Stream.null.use() # reset CuPy's current stream to the null stream + # Verify buffer is properly closed assert buffer.handle == 0, f"{memory_resource_class.__name__} buffer should be closed" From 2699ff1858fee3972f5a522f1419ab28507453cc Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Wed, 25 Jun 2025 06:23:42 -0400 Subject: [PATCH 08/11] Use per device device MR, add numpy requirement to test --- cuda_core/examples/memory_ops.py | 3 +-- cuda_core/tests/test_launcher.py | 8 +++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/cuda_core/examples/memory_ops.py b/cuda_core/examples/memory_ops.py index bde4243af4..ceff29dd34 100644 --- a/cuda_core/examples/memory_ops.py +++ b/cuda_core/examples/memory_ops.py @@ -19,7 +19,6 @@ from cuda.core.experimental import ( Device, - DeviceMemoryResource, LaunchConfig, LegacyPinnedMemoryResource, Program, @@ -62,7 +61,7 @@ kernel = mod.get_kernel("memory_ops") # Create different memory resources -device_mr = DeviceMemoryResource(dev.device_id) +device_mr = dev.memory_resource pinned_mr = LegacyPinnedMemoryResource() # Allocate different types of memory diff --git a/cuda_core/tests/test_launcher.py b/cuda_core/tests/test_launcher.py index 9b632c628a..3095752757 100644 --- a/cuda_core/tests/test_launcher.py +++ b/cuda_core/tests/test_launcher.py @@ -212,7 +212,13 @@ def 
test_cooperative_launch(): "memory_resource_class", [ DeviceMemoryResource, - LegacyPinnedMemoryResource, + pytest.param( + LegacyPinnedMemoryResource, + marks=pytest.mark.skipif( + tuple(int(i) for i in np.__version__.split(".")[:3]) < (2, 2, 5), + reason="need numpy 2.2.5+, numpy GH #28632", + ), + ), ], ) def test_launch_with_buffers_allocated_by_memory_resource(init_cuda, memory_resource_class): From 246e8a1b09f9a8c03ffd91d81c417e88d5d4ea69 Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Wed, 25 Jun 2025 16:54:36 -0400 Subject: [PATCH 09/11] Use SynchronousMemoryResource if memory pools are not supported --- cuda_core/tests/test_launcher.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/cuda_core/tests/test_launcher.py b/cuda_core/tests/test_launcher.py index 3095752757..eeab1957f9 100644 --- a/cuda_core/tests/test_launcher.py +++ b/cuda_core/tests/test_launcher.py @@ -10,6 +10,7 @@ import pytest from conftest import skipif_need_cuda_headers +from cuda.bindings import driver from cuda.core.experimental import ( Device, DeviceMemoryResource, @@ -19,6 +20,8 @@ ProgramOptions, launch, ) +from cuda.core.experimental._memory import _SynchronousMemoryResource +from cuda.core.experimental._utils.cuda_utils import handle_return def test_launch_config_init(init_cuda): @@ -211,7 +214,7 @@ def test_cooperative_launch(): @pytest.mark.parametrize( "memory_resource_class", [ - DeviceMemoryResource, + "device_memory_resource", # kludgy, but can go away after #726 is resolved pytest.param( LegacyPinnedMemoryResource, marks=pytest.mark.skipif( @@ -249,9 +252,18 @@ def test_launch_with_buffers_allocated_by_memory_resource(init_cuda, memory_reso kernel = mod.get_kernel("memory_ops") # Create memory resource - if memory_resource_class == DeviceMemoryResource: - mr = memory_resource_class(dev.device_id) - else: # LegacyPinnedMemoryResource + if memory_resource_class == "device_memory_resource": + if ( + handle_return( + driver.cuDeviceGetAttribute( + driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, dev.device_id + ) + ) + ) == 1: + mr = DeviceMemoryResource(dev.device_id) + else: + mr = _SynchronousMemoryResource(dev.device_id) + else: mr = memory_resource_class() # Allocate memory From 7e3c468dfc98f5b73aa8abd37bc616aa74fe3b88 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Wed, 25 Jun 2025 20:01:47 -0400 Subject: [PATCH 10/11] apply nit --- cuda_core/tests/test_launcher.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/cuda_core/tests/test_launcher.py b/cuda_core/tests/test_launcher.py index eeab1957f9..d1151a3e28 100644 --- a/cuda_core/tests/test_launcher.py +++ b/cuda_core/tests/test_launcher.py @@ -253,13 +253,7 @@ def test_launch_with_buffers_allocated_by_memory_resource(init_cuda, memory_reso # Create memory resource if memory_resource_class == "device_memory_resource": - if ( - handle_return( - driver.cuDeviceGetAttribute( - driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, dev.device_id - ) - ) - ) == 1: + if dev.properties.memory_pools_supported: mr = DeviceMemoryResource(dev.device_id) else: mr = _SynchronousMemoryResource(dev.device_id) From c2ff8cc313387998afb2155f5df8b7835ffd9b5a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 26 Jun 2025 00:02:54 +0000 Subject: [PATCH 11/11] [pre-commit.ci] auto code formatting --- cuda_core/tests/test_launcher.py | 2 -- 1 file changed, 2 deletions(-) diff --git 
a/cuda_core/tests/test_launcher.py b/cuda_core/tests/test_launcher.py index d1151a3e28..a6648d8a4b 100644 --- a/cuda_core/tests/test_launcher.py +++ b/cuda_core/tests/test_launcher.py @@ -10,7 +10,6 @@ import pytest from conftest import skipif_need_cuda_headers -from cuda.bindings import driver from cuda.core.experimental import ( Device, DeviceMemoryResource, @@ -21,7 +20,6 @@ launch, ) from cuda.core.experimental._memory import _SynchronousMemoryResource -from cuda.core.experimental._utils.cuda_utils import handle_return def test_launch_config_init(init_cuda):
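
For reference, the end state the series converges on reduces to the minimal sketch below. It is not one of the patches; it only reuses APIs the patches themselves exercise (Device, dev.memory_resource, LegacyPinnedMemoryResource, DLPack interop via cp.from_dlpack/np.from_dlpack, Buffer.copy_to, launch with Buffer arguments). The kernel name "scale" and the exact variable names are illustrative, and NumPy >= 2.1 is assumed per the guard added in PATCH 07.

# Minimal end-state sketch; assumes cuda.core's experimental API as used in
# the patches above, plus CuPy and NumPy >= 2.1 (np.from_dlpack on a
# CUDA-host buffer).
import cupy as cp
import numpy as np

from cuda.core.experimental import (
    Device,
    LaunchConfig,
    LegacyPinnedMemoryResource,
    Program,
    ProgramOptions,
    launch,
)

# Illustrative kernel; the name "scale" is not from the patches.
code = """
extern "C"
__global__ void scale(float* data, size_t N) {
    const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x;
    if (tid < N) {
        data[tid] = data[tid] * 3.0f;
    }
}
"""

dev = Device()
dev.set_current()
stream = dev.create_stream()
cp.cuda.ExternalStream(int(stream.handle)).use()  # CuPy ops follow our stream

arch = "".join(f"{i}" for i in dev.compute_capability)
prog = Program(code, code_type="c++", options=ProgramOptions(std="c++17", arch=f"sm_{arch}"))
kernel = prog.compile("cubin").get_kernel("scale")

size = 1024
nbytes = size * np.float32().itemsize

# Device buffer from the device's default memory resource (PATCH 08);
# pinned buffer from the legacy pinned resource. Both export DLPack.
device_buffer = dev.memory_resource.allocate(nbytes, stream=stream)
pinned_buffer = LegacyPinnedMemoryResource().allocate(nbytes, stream=stream)
device_array = cp.from_dlpack(device_buffer).view(dtype=np.float32)
pinned_array = np.from_dlpack(pinned_buffer).view(dtype=np.float32)  # NumPy >= 2.1

device_array[:] = cp.random.default_rng().random(size, dtype=np.float32)
stream.sync()

block = 256
config = LaunchConfig(grid=(size + block - 1) // block, block=block)
# Buffers pass straight through as kernel arguments (enabled by PATCH 02).
launch(stream, config, kernel, device_buffer, np.uint64(size))

device_buffer.copy_to(pinned_buffer, stream=stream)  # device -> pinned copy
stream.sync()
assert np.allclose(pinned_array, cp.asnumpy(device_array))

device_buffer.close(stream)
pinned_buffer.close(stream)
stream.close()
cp.cuda.Stream.null.use()  # restore CuPy's default stream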
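
The ExternalStream(...).use() / Stream.null.use() bracketing in the sketch mirrors the change PATCH 06 made to the example and PATCH 07 made to the test: CuPy's array operations must be ordered on the same stream the buffers are allocated, launched, and closed on, otherwise host-visible initialization can race the kernel; resetting to the null stream afterwards keeps CuPy's process-wide current-stream state clean for whatever runs next.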