Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 115 additions & 34 deletions cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -128,67 +128,123 @@ cdef inline int prepare_ctypes_arg(
vector.vector[void*]& data_addresses,
arg,
const size_t idx) except -1:
if isinstance(arg, ctypes_bool):
cdef object arg_type = type(arg)
if arg_type is ctypes_bool:
return prepare_arg[cpp_bool](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int8):
elif arg_type is ctypes_int8:
return prepare_arg[int8_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int16):
elif arg_type is ctypes_int16:
return prepare_arg[int16_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int32):
elif arg_type is ctypes_int32:
return prepare_arg[int32_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int64):
elif arg_type is ctypes_int64:
return prepare_arg[int64_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint8):
elif arg_type is ctypes_uint8:
return prepare_arg[uint8_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint16):
elif arg_type is ctypes_uint16:
return prepare_arg[uint16_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint32):
elif arg_type is ctypes_uint32:
return prepare_arg[uint32_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint64):
elif arg_type is ctypes_uint64:
return prepare_arg[uint64_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_float):
elif arg_type is ctypes_float:
return prepare_arg[float](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_double):
elif arg_type is ctypes_double:
return prepare_arg[double](data, data_addresses, arg.value, idx)
else:
return 1
# If no exact types are found, fallback to slower `isinstance` check
if isinstance(arg, ctypes_bool):
return prepare_arg[cpp_bool](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int8):
return prepare_arg[int8_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int16):
return prepare_arg[int16_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int32):
return prepare_arg[int32_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int64):
return prepare_arg[int64_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint8):
return prepare_arg[uint8_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint16):
return prepare_arg[uint16_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint32):
return prepare_arg[uint32_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint64):
return prepare_arg[uint64_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_float):
return prepare_arg[float](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_double):
return prepare_arg[double](data, data_addresses, arg.value, idx)
else:
return 1


cdef inline int prepare_numpy_arg(
vector.vector[void*]& data,
vector.vector[void*]& data_addresses,
arg,
const size_t idx) except -1:
if isinstance(arg, numpy_bool):
cdef object arg_type = type(arg)
if arg_type is numpy_bool:
return prepare_arg[cpp_bool](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int8):
elif arg_type is numpy_int8:
return prepare_arg[int8_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int16):
elif arg_type is numpy_int16:
return prepare_arg[int16_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int32):
elif arg_type is numpy_int32:
return prepare_arg[int32_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int64):
elif arg_type is numpy_int64:
return prepare_arg[int64_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint8):
elif arg_type is numpy_uint8:
return prepare_arg[uint8_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint16):
elif arg_type is numpy_uint16:
return prepare_arg[uint16_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint32):
elif arg_type is numpy_uint32:
return prepare_arg[uint32_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint64):
elif arg_type is numpy_uint64:
return prepare_arg[uint64_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_float16):
elif arg_type is numpy_float16:
return prepare_arg[__half_raw](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_float32):
elif arg_type is numpy_float32:
return prepare_arg[float](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_float64):
elif arg_type is numpy_float64:
return prepare_arg[double](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_complex64):
elif arg_type is numpy_complex64:
return prepare_arg[cpp_single_complex](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_complex128):
elif arg_type is numpy_complex128:
return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx)
else:
return 1
# If no exact types are found, fallback to slower `isinstance` check
if isinstance(arg, numpy_bool):
return prepare_arg[cpp_bool](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int8):
return prepare_arg[int8_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int16):
return prepare_arg[int16_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int32):
return prepare_arg[int32_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int64):
return prepare_arg[int64_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint8):
return prepare_arg[uint8_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint16):
return prepare_arg[uint16_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint32):
return prepare_arg[uint32_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint64):
return prepare_arg[uint64_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_float16):
return prepare_arg[__half_raw](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_float32):
return prepare_arg[float](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_float64):
return prepare_arg[double](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_complex64):
return prepare_arg[cpp_single_complex](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_complex128):
return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx)
else:
return 1


cdef class ParamHolder:
Expand All @@ -207,34 +263,36 @@ cdef class ParamHolder:
cdef size_t n_args = len(kernel_args)
cdef size_t i
cdef int not_prepared
cdef object arg_type
self.data = vector.vector[voidptr](n_args, nullptr)
self.data_addresses = vector.vector[voidptr](n_args)
for i, arg in enumerate(kernel_args):
if isinstance(arg, Buffer):
arg_type = type(arg)
if arg_type is Buffer:
# we need the address of where the actual buffer address is stored
if isinstance(arg.handle, int):
if type(arg.handle) is int:
# see note below on handling int arguments
prepare_arg[intptr_t](self.data, self.data_addresses, arg.handle, i)
continue
else:
# it's a CUdeviceptr:
self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())
continue
elif isinstance(arg, int):
elif arg_type is int:
# Here's the dilemma: We want to have a fast path to pass in Python
# integers as pointer addresses, but one could also (mistakenly) pass
# it with the intention of passing a scalar integer. It's a mistake
# bacause a Python int is ambiguous (arbitrary width). Our judgement
# call here is to treat it as a pointer address, without any warning!
prepare_arg[intptr_t](self.data, self.data_addresses, arg, i)
continue
elif isinstance(arg, float):
elif arg_type is float:
prepare_arg[double](self.data, self.data_addresses, arg, i)
continue
elif isinstance(arg, complex):
elif arg_type is complex:
prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i)
continue
elif isinstance(arg, bool):
elif arg_type is bool:
prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i)
continue

Expand All @@ -243,7 +301,30 @@ cdef class ParamHolder:
not_prepared = prepare_ctypes_arg(self.data, self.data_addresses, arg, i)
if not_prepared:
# TODO: revisit this treatment if we decide to cythonize cuda.core
if isinstance(arg, driver.CUgraphConditionalHandle):
if arg_type is driver.CUgraphConditionalHandle:
prepare_arg[intptr_t](self.data, self.data_addresses, <intptr_t>int(arg), i)
continue
# If no exact types are found, fallback to slower `isinstance` check
elif isinstance(arg, Buffer):
if isinstance(arg.handle, int):
prepare_arg[intptr_t](self.data, self.data_addresses, arg.handle, i)
continue
else:
self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())
continue
elif isinstance(arg, int):
prepare_arg[intptr_t](self.data, self.data_addresses, arg, i)
continue
elif isinstance(arg, float):
prepare_arg[double](self.data, self.data_addresses, arg, i)
continue
elif isinstance(arg, complex):
prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i)
continue
elif isinstance(arg, bool):
prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i)
continue
elif isinstance(arg, driver.CUgraphConditionalHandle):
prepare_arg[intptr_t](self.data, self.data_addresses, <intptr_t>int(arg), i)
continue
# TODO: support ctypes/numpy struct
Expand Down