From 874b7f4e4414786eae2ab7d3edb99460c3027ea8 Mon Sep 17 00:00:00 2001 From: Bharat Raghunathan Date: Sun, 30 Nov 2025 23:47:20 -0600 Subject: [PATCH 1/9] chore: Replace isinstance(obj, T) with type(obj) is T comparisons --- .../core/experimental/_kernel_arg_handler.pyx | 68 ++++++++++--------- 1 file changed, 36 insertions(+), 32 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx b/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx index 0bb40bf404..5ba50b42a7 100644 --- a/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx +++ b/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx @@ -128,27 +128,28 @@ cdef inline int prepare_ctypes_arg( vector.vector[void*]& data_addresses, arg, const size_t idx) except -1: - if isinstance(arg, ctypes_bool): + cdef object arg_type = type(arg) + if arg_type is ctypes_bool: return prepare_arg[cpp_bool](data, data_addresses, arg.value, idx) - elif isinstance(arg, ctypes_int8): + elif arg_type is ctypes_int8: return prepare_arg[int8_t](data, data_addresses, arg.value, idx) - elif isinstance(arg, ctypes_int16): + elif arg_type is ctypes_int16: return prepare_arg[int16_t](data, data_addresses, arg.value, idx) - elif isinstance(arg, ctypes_int32): + elif arg_type is ctypes_int32: return prepare_arg[int32_t](data, data_addresses, arg.value, idx) - elif isinstance(arg, ctypes_int64): + elif arg_type is ctypes_int64: return prepare_arg[int64_t](data, data_addresses, arg.value, idx) - elif isinstance(arg, ctypes_uint8): + elif arg_type is ctypes_uint8: return prepare_arg[uint8_t](data, data_addresses, arg.value, idx) - elif isinstance(arg, ctypes_uint16): + elif arg_type is ctypes_uint16: return prepare_arg[uint16_t](data, data_addresses, arg.value, idx) - elif isinstance(arg, ctypes_uint32): + elif arg_type is ctypes_uint32: return prepare_arg[uint32_t](data, data_addresses, arg.value, idx) - elif isinstance(arg, ctypes_uint64): + elif arg_type is ctypes_uint64: return 
prepare_arg[uint64_t](data, data_addresses, arg.value, idx) - elif isinstance(arg, ctypes_float): + elif arg_type is ctypes_float: return prepare_arg[float](data, data_addresses, arg.value, idx) - elif isinstance(arg, ctypes_double): + elif arg_type is ctypes_double: return prepare_arg[double](data, data_addresses, arg.value, idx) else: return 1 @@ -159,33 +160,34 @@ cdef inline int prepare_numpy_arg( vector.vector[void*]& data_addresses, arg, const size_t idx) except -1: - if isinstance(arg, numpy_bool): + cdef object arg_type = type(arg) + if arg_type is numpy_bool: return prepare_arg[cpp_bool](data, data_addresses, arg, idx) - elif isinstance(arg, numpy_int8): + elif arg_type is numpy_int8: return prepare_arg[int8_t](data, data_addresses, arg, idx) - elif isinstance(arg, numpy_int16): + elif arg_type is numpy_int16: return prepare_arg[int16_t](data, data_addresses, arg, idx) - elif isinstance(arg, numpy_int32): + elif arg_type is numpy_int32: return prepare_arg[int32_t](data, data_addresses, arg, idx) - elif isinstance(arg, numpy_int64): + elif arg_type is numpy_int64: return prepare_arg[int64_t](data, data_addresses, arg, idx) - elif isinstance(arg, numpy_uint8): + elif arg_type is numpy_uint8: return prepare_arg[uint8_t](data, data_addresses, arg, idx) - elif isinstance(arg, numpy_uint16): + elif arg_type is numpy_uint16: return prepare_arg[uint16_t](data, data_addresses, arg, idx) - elif isinstance(arg, numpy_uint32): + elif arg_type is numpy_uint32: return prepare_arg[uint32_t](data, data_addresses, arg, idx) - elif isinstance(arg, numpy_uint64): + elif arg_type is numpy_uint64: return prepare_arg[uint64_t](data, data_addresses, arg, idx) - elif isinstance(arg, numpy_float16): + elif arg_type is numpy_float16: return prepare_arg[__half_raw](data, data_addresses, arg, idx) - elif isinstance(arg, numpy_float32): + elif arg_type is numpy_float32: return prepare_arg[float](data, data_addresses, arg, idx) - elif isinstance(arg, numpy_float64): + elif arg_type is 
numpy_float64: return prepare_arg[double](data, data_addresses, arg, idx) - elif isinstance(arg, numpy_complex64): + elif arg_type is numpy_complex64: return prepare_arg[cpp_single_complex](data, data_addresses, arg, idx) - elif isinstance(arg, numpy_complex128): + elif arg_type is numpy_complex128: return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx) else: return 1 @@ -207,12 +209,14 @@ cdef class ParamHolder: cdef size_t n_args = len(kernel_args) cdef size_t i cdef int not_prepared + cdef object arg_type self.data = vector.vector[voidptr](n_args, nullptr) self.data_addresses = vector.vector[voidptr](n_args) for i, arg in enumerate(kernel_args): - if isinstance(arg, Buffer): + arg_type = type(arg) + if arg_type is Buffer: # we need the address of where the actual buffer address is stored - if isinstance(arg.handle, int): + if type(arg.handle) is int: # see note below on handling int arguments prepare_arg[intptr_t](self.data, self.data_addresses, arg.handle, i) continue @@ -220,7 +224,7 @@ cdef class ParamHolder: # it's a CUdeviceptr: self.data_addresses[i] = (arg.handle.getPtr()) continue - elif isinstance(arg, int): + elif arg_type is int: # Here's the dilemma: We want to have a fast path to pass in Python # integers as pointer addresses, but one could also (mistakenly) pass # it with the intention of passing a scalar integer. It's a mistake @@ -228,13 +232,13 @@ cdef class ParamHolder: # call here is to treat it as a pointer address, without any warning! 
prepare_arg[intptr_t](self.data, self.data_addresses, arg, i) continue - elif isinstance(arg, float): + elif arg_type is float: prepare_arg[double](self.data, self.data_addresses, arg, i) continue - elif isinstance(arg, complex): + elif arg_type is complex: prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i) continue - elif isinstance(arg, bool): + elif arg_type is bool: prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i) continue @@ -243,7 +247,7 @@ cdef class ParamHolder: not_prepared = prepare_ctypes_arg(self.data, self.data_addresses, arg, i) if not_prepared: # TODO: revisit this treatment if we decide to cythonize cuda.core - if isinstance(arg, driver.CUgraphConditionalHandle): + if arg_type is driver.CUgraphConditionalHandle: prepare_arg[intptr_t](self.data, self.data_addresses, int(arg), i) continue # TODO: support ctypes/numpy struct From eb9ca80ef5fa7f8039ea74207fb1ebc06fc66200 Mon Sep 17 00:00:00 2001 From: Bharat Raghunathan Date: Tue, 2 Dec 2025 14:57:03 -0600 Subject: [PATCH 2/9] Add isinstance fallback to maintain backward compat for subclasses Signed-off-by: Bharat Raghunathan --- .../core/experimental/_kernel_arg_handler.pyx | 81 ++++++++++++++++++- 1 file changed, 79 insertions(+), 2 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx b/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx index 5ba50b42a7..d249db4f25 100644 --- a/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx +++ b/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx @@ -152,7 +152,31 @@ cdef inline int prepare_ctypes_arg( elif arg_type is ctypes_double: return prepare_arg[double](data, data_addresses, arg.value, idx) else: - return 1 + # If no exact types are found, fallback to slower `isinstance` check + if isinstance(arg, ctypes_bool): + return prepare_arg[cpp_bool](data, data_addresses, arg.value, idx) + elif isinstance(arg, ctypes_int8): + return prepare_arg[int8_t](data, data_addresses, arg.value, 
idx) + elif isinstance(arg, ctypes_int16): + return prepare_arg[int16_t](data, data_addresses, arg.value, idx) + elif isinstance(arg, ctypes_int32): + return prepare_arg[int32_t](data, data_addresses, arg.value, idx) + elif isinstance(arg, ctypes_int64): + return prepare_arg[int64_t](data, data_addresses, arg.value, idx) + elif isinstance(arg, ctypes_uint8): + return prepare_arg[uint8_t](data, data_addresses, arg.value, idx) + elif isinstance(arg, ctypes_uint16): + return prepare_arg[uint16_t](data, data_addresses, arg.value, idx) + elif isinstance(arg, ctypes_uint32): + return prepare_arg[uint32_t](data, data_addresses, arg.value, idx) + elif isinstance(arg, ctypes_uint64): + return prepare_arg[uint64_t](data, data_addresses, arg.value, idx) + elif isinstance(arg, ctypes_float): + return prepare_arg[float](data, data_addresses, arg.value, idx) + elif isinstance(arg, ctypes_double): + return prepare_arg[double](data, data_addresses, arg.value, idx) + else: + return 1 cdef inline int prepare_numpy_arg( @@ -190,7 +214,37 @@ cdef inline int prepare_numpy_arg( elif arg_type is numpy_complex128: return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx) else: - return 1 + # If no exact types are found, fallback to slower `isinstance` check + if isinstance(arg, numpy_bool): + return prepare_arg[cpp_bool](data, data_addresses, arg, idx) + elif isinstance(arg, numpy_int8): + return prepare_arg[int8_t](data, data_addresses, arg, idx) + elif isinstance(arg, numpy_int16): + return prepare_arg[int16_t](data, data_addresses, arg, idx) + elif isinstance(arg, numpy_int32): + return prepare_arg[int32_t](data, data_addresses, arg, idx) + elif isinstance(arg, numpy_int64): + return prepare_arg[int64_t](data, data_addresses, arg, idx) + elif isinstance(arg, numpy_uint8): + return prepare_arg[uint8_t](data, data_addresses, arg, idx) + elif isinstance(arg, numpy_uint16): + return prepare_arg[uint16_t](data, data_addresses, arg, idx) + elif isinstance(arg, numpy_uint32): + 
return prepare_arg[uint32_t](data, data_addresses, arg, idx) + elif isinstance(arg, numpy_uint64): + return prepare_arg[uint64_t](data, data_addresses, arg, idx) + elif isinstance(arg, numpy_float16): + return prepare_arg[__half_raw](data, data_addresses, arg, idx) + elif isinstance(arg, numpy_float32): + return prepare_arg[float](data, data_addresses, arg, idx) + elif isinstance(arg, numpy_float64): + return prepare_arg[double](data, data_addresses, arg, idx) + elif isinstance(arg, numpy_complex64): + return prepare_arg[cpp_single_complex](data, data_addresses, arg, idx) + elif isinstance(arg, numpy_complex128): + return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx) + else: + return 1 cdef class ParamHolder: @@ -250,6 +304,29 @@ cdef class ParamHolder: if arg_type is driver.CUgraphConditionalHandle: prepare_arg[intptr_t](self.data, self.data_addresses, int(arg), i) continue + # If no exact types are found, fallback to slower `isinstance` check + elif isinstance(arg, Buffer): + if isinstance(arg.handle, int): + prepare_arg[intptr_t](self.data, self.data_addresses, arg.handle, i) + continue + else: + self.data_addresses[i] = (arg.handle.getPtr()) + continue + elif isinstance(arg, int): + prepare_arg[intptr_t](self.data, self.data_addresses, arg, i) + continue + elif isinstance(arg, float): + prepare_arg[double](self.data, self.data_addresses, arg, i) + continue + elif isinstance(arg, complex): + prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i) + continue + elif isinstance(arg, bool): + prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i) + continue + elif isinstance(arg, driver.CUgraphConditionalHandle): + prepare_arg[intptr_t](self.data, self.data_addresses, int(arg), i) + continue # TODO: support ctypes/numpy struct raise TypeError("the argument is of unsupported type: " + str(type(arg))) From 85bc86cc1bc5b3958d5583265dbee9e61c1e9cf3 Mon Sep 17 00:00:00 2001 From: Bharat Raghunathan Date: Wed, 3 Dec 2025 12:26:47 
-0600 Subject: [PATCH 3/9] Fix test failures (attempt 1/n) Signed-off-by: Bharat Raghunathan --- .../cuda/core/experimental/_kernel_arg_handler.pyx | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx b/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx index d249db4f25..2e791f08d5 100644 --- a/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx +++ b/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx @@ -278,6 +278,9 @@ cdef class ParamHolder: # it's a CUdeviceptr: self.data_addresses[i] = (arg.handle.getPtr()) continue + elif arg_type is bool: + prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i) + continue elif arg_type is int: # Here's the dilemma: We want to have a fast path to pass in Python # integers as pointer addresses, but one could also (mistakenly) pass @@ -292,9 +295,6 @@ cdef class ParamHolder: elif arg_type is complex: prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i) continue - elif arg_type is bool: - prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i) - continue not_prepared = prepare_numpy_arg(self.data, self.data_addresses, arg, i) if not_prepared: @@ -312,6 +312,9 @@ cdef class ParamHolder: else: self.data_addresses[i] = (arg.handle.getPtr()) continue + elif isinstance(arg, bool): + prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i) + continue elif isinstance(arg, int): prepare_arg[intptr_t](self.data, self.data_addresses, arg, i) continue @@ -321,9 +324,6 @@ cdef class ParamHolder: elif isinstance(arg, complex): prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i) continue - elif isinstance(arg, bool): - prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i) - continue elif isinstance(arg, driver.CUgraphConditionalHandle): prepare_arg[intptr_t](self.data, self.data_addresses, int(arg), i) continue From ed81c11fb5e5f990c1bd5fe353476b9713cee1be Mon Sep 17 00:00:00 2001 
From: Bharat Raghunathan Date: Wed, 3 Dec 2025 15:31:16 -0600 Subject: [PATCH 4/9] Fix test failures (attempt 2/n) Signed-off-by: Bharat Raghunathan --- cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx b/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx index 2e791f08d5..90c5856a08 100644 --- a/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx +++ b/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx @@ -279,7 +279,7 @@ cdef class ParamHolder: self.data_addresses[i] = (arg.handle.getPtr()) continue elif arg_type is bool: - prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i) + prepare_arg[int32_t](self.data, self.data_addresses, int(arg), i) continue elif arg_type is int: # Here's the dilemma: We want to have a fast path to pass in Python @@ -313,7 +313,7 @@ cdef class ParamHolder: self.data_addresses[i] = (arg.handle.getPtr()) continue elif isinstance(arg, bool): - prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i) + prepare_arg[int32_t](self.data, self.data_addresses, int(arg), i) continue elif isinstance(arg, int): prepare_arg[intptr_t](self.data, self.data_addresses, arg, i) From 3f9b5f12e2c366928a202bf42239da3d6cc4ac22 Mon Sep 17 00:00:00 2001 From: Bharat Raghunathan Date: Thu, 4 Dec 2025 10:31:02 -0600 Subject: [PATCH 5/9] Apply suggestions from code review by @mdboom and @leofang Explicit cast not needed since `prepare_arg` does it automatically Co-authored-by: Leo Fang --- cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx b/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx index 90c5856a08..2e791f08d5 100644 --- a/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx +++ b/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx @@ -279,7 
+279,7 @@ cdef class ParamHolder: self.data_addresses[i] = (arg.handle.getPtr()) continue elif arg_type is bool: - prepare_arg[int32_t](self.data, self.data_addresses, int(arg), i) + prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i) continue elif arg_type is int: # Here's the dilemma: We want to have a fast path to pass in Python @@ -313,7 +313,7 @@ cdef class ParamHolder: self.data_addresses[i] = (arg.handle.getPtr()) continue elif isinstance(arg, bool): - prepare_arg[int32_t](self.data, self.data_addresses, int(arg), i) + prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i) continue elif isinstance(arg, int): prepare_arg[intptr_t](self.data, self.data_addresses, arg, i) From 0fc1ed8d12161f8dbefca68f9bf9cf55f85fe637 Mon Sep 17 00:00:00 2001 From: Michael Droettboom Date: Thu, 4 Dec 2025 12:44:51 -0500 Subject: [PATCH 6/9] Fix test_graph.py to use bool correctly in cudaGraphSetConditional call. --- cuda_core/docs/source/release/0.5.x-notes.rst | 39 +++++++++++++++++++ cuda_core/tests/test_graph.py | 14 +++++-- 2 files changed, 50 insertions(+), 3 deletions(-) create mode 100644 cuda_core/docs/source/release/0.5.x-notes.rst diff --git a/cuda_core/docs/source/release/0.5.x-notes.rst b/cuda_core/docs/source/release/0.5.x-notes.rst new file mode 100644 index 0000000000..5f0298e502 --- /dev/null +++ b/cuda_core/docs/source/release/0.5.x-notes.rst @@ -0,0 +1,39 @@ +.. SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +.. SPDX-License-Identifier: Apache-2.0 + +.. currentmodule:: cuda.core.experimental + +``cuda.core`` 0.5.x Release Notes +================================= + + +Highlights +---------- + +None. + + +Breaking Changes +---------------- + +- Python ``bool`` objects are now converted to C++ ``bool`` type when passed as kernel + arguments. Previously, they were converted to ``int``. This brings them in line + with ``ctypes.c_bool`` and ``numpy.bool_``. + + +New features +------------ + +None.
+ + +New examples +------------ + +None. + + +Fixes and enhancements +---------------------- + +None. diff --git a/cuda_core/tests/test_graph.py b/cuda_core/tests/test_graph.py index 615f7242c4..2a31ad42f7 100644 --- a/cuda_core/tests/test_graph.py +++ b/cuda_core/tests/test_graph.py @@ -2,6 +2,8 @@ # # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE +import ctypes + import numpy as np import pytest @@ -41,7 +43,9 @@ def _common_kernels_conditional(): unsigned int value); __global__ void empty_kernel() {} __global__ void add_one(int *a) { *a += 1; } - __global__ void set_handle(cudaGraphConditionalHandle handle, int value) { cudaGraphSetConditional(handle, value); } + __global__ void set_handle(cudaGraphConditionalHandle handle, bool value) { + cudaGraphSetConditional(handle, value); + } __global__ void loop_kernel(cudaGraphConditionalHandle handle) { static int count = 10; @@ -216,7 +220,9 @@ def test_graph_capture_errors(init_cuda): gb.end_building().complete() -@pytest.mark.parametrize("condition_value", [True, False]) +@pytest.mark.parametrize( + "condition_value", [True, False, ctypes.c_bool(True), ctypes.c_bool(False), np.bool_(True), np.bool_(False)] +) @pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+") def test_graph_conditional_if(init_cuda, condition_value): mod = _common_kernels_conditional() @@ -278,7 +284,9 @@ def test_graph_conditional_if(init_cuda, condition_value): b.close() -@pytest.mark.parametrize("condition_value", [True, False]) +@pytest.mark.parametrize( + "condition_value", [True, False, ctypes.c_bool(True), ctypes.c_bool(False), np.bool_(True), np.bool_(False)] +) @pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+") def test_graph_conditional_if_else(init_cuda, condition_value): mod = _common_kernels_conditional() From 3025e049165c4805f8d14e03f9d569adafb81844 Mon Sep 17 00:00:00 2001 From: Michael Droettboom Date: 
Fri, 5 Dec 2025 07:55:09 -0500 Subject: [PATCH 7/9] More explicit typing of CUgraphConditionalHandle --- cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx b/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx index 2e791f08d5..3537fb0665 100644 --- a/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx +++ b/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx @@ -17,6 +17,7 @@ import numpy from cuda.core.experimental._memory import Buffer from cuda.core.experimental._utils.cuda_utils import driver +from cuda.bindings cimport cydriver ctypedef cpp_complex.complex[float] cpp_single_complex @@ -272,7 +273,7 @@ cdef class ParamHolder: # we need the address of where the actual buffer address is stored if type(arg.handle) is int: # see note below on handling int arguments - prepare_arg[intptr_t](self.data, self.data_addresses, arg.handle, i) + prepare_arg[cydriver.CUgraphConditionalHandle](self.data, self.data_addresses, arg.handle, i) continue else: # it's a CUdeviceptr: @@ -302,7 +303,7 @@ cdef class ParamHolder: if not_prepared: # TODO: revisit this treatment if we decide to cythonize cuda.core if arg_type is driver.CUgraphConditionalHandle: - prepare_arg[intptr_t](self.data, self.data_addresses, int(arg), i) + prepare_arg[cydriver.CUgraphConditionalHandle](self.data, self.data_addresses, int(arg), i) continue # If no exact types are found, fallback to slower `isinstance` check elif isinstance(arg, Buffer): From 48a08565c75335f9c87dc3518897642c5e6be7f7 Mon Sep 17 00:00:00 2001 From: Michael Droettboom Date: Fri, 5 Dec 2025 07:57:22 -0500 Subject: [PATCH 8/9] Improve conditional kernel tests by testing both bools and ints --- cuda_core/tests/test_graph.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/cuda_core/tests/test_graph.py b/cuda_core/tests/test_graph.py index 
2a31ad42f7..e988eeebf6 100644 --- a/cuda_core/tests/test_graph.py +++ b/cuda_core/tests/test_graph.py @@ -37,13 +37,20 @@ def _common_kernels(): return mod -def _common_kernels_conditional(): +def _common_kernels_conditional(cond_type): + if cond_type in (bool, np.bool_, ctypes.c_bool): + cond_type_str = "bool" + elif cond_type is int: + cond_type_str = "unsigned int" + else: + raise ValueError("Unsupported cond_type") + code = """ extern "C" __device__ __cudart_builtin__ void CUDARTAPI cudaGraphSetConditional(cudaGraphConditionalHandle handle, unsigned int value); __global__ void empty_kernel() {} __global__ void add_one(int *a) { *a += 1; } - __global__ void set_handle(cudaGraphConditionalHandle handle, bool value) { + __global__ void set_handle(cudaGraphConditionalHandle handle, $cond_type_str value) { cudaGraphSetConditional(handle, value); } __global__ void loop_kernel(cudaGraphConditionalHandle handle) @@ -51,7 +58,7 @@ def _common_kernels_conditional(): static int count = 10; cudaGraphSetConditional(handle, --count ? 
1 : 0); } - """ + """.replace("$cond_type_str", cond_type_str) arch = "".join(f"{i}" for i in Device().compute_capability) program_options = ProgramOptions(std="c++17", arch=f"sm_{arch}") prog = Program(code, code_type="c++", options=program_options) @@ -221,11 +228,11 @@ def test_graph_capture_errors(init_cuda): @pytest.mark.parametrize( - "condition_value", [True, False, ctypes.c_bool(True), ctypes.c_bool(False), np.bool_(True), np.bool_(False)] + "condition_value", [True, False, ctypes.c_bool(True), ctypes.c_bool(False), np.bool_(True), np.bool_(False), 1, 0] ) @pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+") def test_graph_conditional_if(init_cuda, condition_value): - mod = _common_kernels_conditional() + mod = _common_kernels_conditional(type(condition_value)) add_one = mod.get_kernel("add_one") set_handle = mod.get_kernel("set_handle") @@ -285,11 +292,11 @@ def test_graph_conditional_if(init_cuda, condition_value): @pytest.mark.parametrize( - "condition_value", [True, False, ctypes.c_bool(True), ctypes.c_bool(False), np.bool_(True), np.bool_(False)] + "condition_value", [True, False, ctypes.c_bool(True), ctypes.c_bool(False), np.bool_(True), np.bool_(False), 1, 0] ) @pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+") def test_graph_conditional_if_else(init_cuda, condition_value): - mod = _common_kernels_conditional() + mod = _common_kernels_conditional(type(condition_value)) add_one = mod.get_kernel("add_one") set_handle = mod.get_kernel("set_handle") @@ -361,7 +368,7 @@ def test_graph_conditional_if_else(init_cuda, condition_value): @pytest.mark.parametrize("condition_value", [0, 1, 2, 3]) @pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+") def test_graph_conditional_switch(init_cuda, condition_value): - mod = _common_kernels_conditional() + mod = 
_common_kernels_conditional(type(condition_value)) add_one = mod.get_kernel("add_one") set_handle = mod.get_kernel("set_handle") @@ -449,10 +456,10 @@ def test_graph_conditional_switch(init_cuda, condition_value): b.close() -@pytest.mark.parametrize("condition_value", [True, False]) +@pytest.mark.parametrize("condition_value", [True, False, 1, 0]) @pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+") def test_graph_conditional_while(init_cuda, condition_value): - mod = _common_kernels_conditional() + mod = _common_kernels_conditional(type(condition_value)) add_one = mod.get_kernel("add_one") loop_kernel = mod.get_kernel("loop_kernel") empty_kernel = mod.get_kernel("empty_kernel") @@ -553,7 +560,7 @@ def test_graph_child_graph(init_cuda): @pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+") def test_graph_update(init_cuda): - mod = _common_kernels_conditional() + mod = _common_kernels_conditional(int) add_one = mod.get_kernel("add_one") # Allocate memory @@ -676,7 +683,7 @@ def test_graph_stream_lifetime(init_cuda): def test_graph_dot_print_options(init_cuda, tmp_path): - mod = _common_kernels_conditional() + mod = _common_kernels_conditional(bool) set_handle = mod.get_kernel("set_handle") empty_kernel = mod.get_kernel("empty_kernel") From d9a7c14e6c97f414be68c69186e3fb01301ce890 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Fri, 5 Dec 2025 23:37:19 -0500 Subject: [PATCH 9/9] fix --- cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx b/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx index 3537fb0665..4cac74a25f 100644 --- a/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx +++ b/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx @@ -273,7 +273,7 @@ cdef class ParamHolder: # we need the address of 
where the actual buffer address is stored if type(arg.handle) is int: # see note below on handling int arguments - prepare_arg[cydriver.CUgraphConditionalHandle](self.data, self.data_addresses, arg.handle, i) + prepare_arg[intptr_t](self.data, self.data_addresses, arg.handle, i) continue else: # it's a CUdeviceptr: @@ -326,7 +326,7 @@ cdef class ParamHolder: prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i) continue elif isinstance(arg, driver.CUgraphConditionalHandle): - prepare_arg[intptr_t](self.data, self.data_addresses, int(arg), i) + prepare_arg[cydriver.CUgraphConditionalHandle](self.data, self.data_addresses, arg, i) continue # TODO: support ctypes/numpy struct raise TypeError("the argument is of unsupported type: " + str(type(arg)))