Skip to content

Commit 7e1c8f5

Browse files
committed
implement kernel arg handling + check in saxpy sample
1 parent 905e5f4 commit 7e1c8f5

File tree

4 files changed

+309
-14
lines changed

4 files changed

+309
-14
lines changed
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
2+
#
3+
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
4+
5+
from cpython.mem cimport PyMem_Malloc, PyMem_Free
6+
from libc.stdint cimport (intptr_t,
7+
int8_t, int16_t, int32_t, int64_t,
8+
uint8_t, uint16_t, uint32_t, uint64_t,)
9+
from libcpp cimport bool as cpp_bool
10+
from libcpp.complex cimport complex as cpp_complex
11+
from libcpp cimport nullptr
12+
from libcpp cimport vector
13+
14+
import ctypes
15+
16+
# this might be an unnecessary assumption that NumPy does not exist...
17+
try:
18+
import numpy
19+
except ImportError:
20+
numpy = None
21+
22+
from cuda.core._memory import Buffer
23+
24+
25+
ctypedef cpp_complex.complex[float] cpp_single_complex
26+
ctypedef cpp_complex.complex[double] cpp_double_complex
27+
28+
29+
ctypedef fused supported_type:
30+
cpp_bool
31+
int8_t
32+
int16_t
33+
int32_t
34+
int64_t
35+
uint8_t
36+
uint16_t
37+
uint32_t
38+
uint64_t
39+
float
40+
double
41+
intptr_t
42+
cpp_single_complex
43+
cpp_double_complex
44+
45+
46+
# TODO: cache ctypes/numpy type objects to avoid attribute access
47+
48+
49+
# limitation due to cython/cython#534
50+
ctypedef void* voidptr
51+
52+
53+
# Cython can't infer the overload without at least one input argument with fused type
54+
cdef inline int prepare_arg(
55+
vector.vector[void*]& data,
56+
vector.vector[void*]& data_addresses,
57+
arg, # important: keep it a Python object and don't cast
58+
const size_t idx,
59+
const supported_type* __unused=NULL) except -1:
60+
cdef void* ptr = PyMem_Malloc(sizeof(supported_type))
61+
# note: this should also work once ctypes has complex support:
62+
# python/cpython#121248
63+
if supported_type is cpp_single_complex:
64+
(<supported_type*>ptr)[0] = cpp_complex.complex[float](arg.real, arg.imag)
65+
elif supported_type is cpp_double_complex:
66+
(<supported_type*>ptr)[0] = cpp_complex.complex[double](arg.real, arg.imag)
67+
else:
68+
(<supported_type*>ptr)[0] = <supported_type>(arg)
69+
data_addresses[idx] = ptr # take the address to the scalar
70+
data[idx] = ptr # for later dealloc
71+
return 0
72+
73+
74+
cdef inline int prepare_ctypes_arg(
75+
vector.vector[void*]& data,
76+
vector.vector[void*]& data_addresses,
77+
arg,
78+
const size_t idx) except -1:
79+
if isinstance(arg, ctypes.c_bool):
80+
return prepare_arg[cpp_bool](data, data_addresses, arg.value, idx)
81+
elif isinstance(arg, ctypes.c_int8):
82+
return prepare_arg[int8_t](data, data_addresses, arg.value, idx)
83+
elif isinstance(arg, ctypes.c_int16):
84+
return prepare_arg[int16_t](data, data_addresses, arg.value, idx)
85+
elif isinstance(arg, ctypes.c_int32):
86+
return prepare_arg[int32_t](data, data_addresses, arg.value, idx)
87+
elif isinstance(arg, ctypes.c_int64):
88+
return prepare_arg[int64_t](data, data_addresses, arg.value, idx)
89+
elif isinstance(arg, ctypes.c_uint8):
90+
return prepare_arg[uint8_t](data, data_addresses, arg.value, idx)
91+
elif isinstance(arg, ctypes.c_uint16):
92+
return prepare_arg[uint16_t](data, data_addresses, arg.value, idx)
93+
elif isinstance(arg, ctypes.c_uint32):
94+
return prepare_arg[uint32_t](data, data_addresses, arg.value, idx)
95+
elif isinstance(arg, ctypes.c_uint64):
96+
return prepare_arg[uint64_t](data, data_addresses, arg.value, idx)
97+
elif isinstance(arg, ctypes.c_float):
98+
return prepare_arg[float](data, data_addresses, arg.value, idx)
99+
elif isinstance(arg, ctypes.c_double):
100+
return prepare_arg[double](data, data_addresses, arg.value, idx)
101+
else:
102+
return 1
103+
104+
105+
cdef inline int prepare_numpy_arg(
106+
vector.vector[void*]& data,
107+
vector.vector[void*]& data_addresses,
108+
arg,
109+
const size_t idx) except -1:
110+
if not numpy:
111+
return 1
112+
113+
if isinstance(arg, numpy.bool_):
114+
return prepare_arg[cpp_bool](data, data_addresses, arg, idx)
115+
elif isinstance(arg, numpy.int8):
116+
return prepare_arg[int8_t](data, data_addresses, arg, idx)
117+
elif isinstance(arg, numpy.int16):
118+
return prepare_arg[int16_t](data, data_addresses, arg, idx)
119+
elif isinstance(arg, numpy.int32):
120+
return prepare_arg[int32_t](data, data_addresses, arg, idx)
121+
elif isinstance(arg, numpy.int64):
122+
return prepare_arg[int64_t](data, data_addresses, arg, idx)
123+
elif isinstance(arg, numpy.uint8):
124+
return prepare_arg[uint8_t](data, data_addresses, arg, idx)
125+
elif isinstance(arg, numpy.uint16):
126+
return prepare_arg[uint16_t](data, data_addresses, arg, idx)
127+
elif isinstance(arg, numpy.uint32):
128+
return prepare_arg[uint32_t](data, data_addresses, arg, idx)
129+
elif isinstance(arg, numpy.uint64):
130+
return prepare_arg[uint64_t](data, data_addresses, arg, idx)
131+
elif isinstance(arg, numpy.float16):
132+
# use int16 as a proxy
133+
return prepare_arg[int16_t](data, data_addresses, arg, idx)
134+
elif isinstance(arg, numpy.float32):
135+
return prepare_arg[float](data, data_addresses, arg, idx)
136+
elif isinstance(arg, numpy.float64):
137+
return prepare_arg[double](data, data_addresses, arg, idx)
138+
elif isinstance(arg, numpy.complex64):
139+
return prepare_arg[cpp_single_complex](data, data_addresses, arg, idx)
140+
elif isinstance(arg, numpy.complex128):
141+
return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx)
142+
else:
143+
return 1
144+
145+
146+
cdef class ParamHolder:
147+
148+
cdef:
149+
vector.vector[void*] data
150+
vector.vector[void*] data_addresses
151+
object kernel_args
152+
readonly intptr_t ptr
153+
154+
def __init__(self, kernel_args):
155+
if len(kernel_args) == 0:
156+
self.ptr = 0
157+
return
158+
159+
cdef size_t n_args = len(kernel_args)
160+
cdef size_t i
161+
cdef int not_prepared
162+
self.data = vector.vector[voidptr](n_args, nullptr)
163+
self.data_addresses = vector.vector[voidptr](n_args)
164+
for i, arg in enumerate(kernel_args):
165+
if isinstance(arg, Buffer):
166+
# we need the address of where the actual buffer address is stored
167+
self.data_addresses[i] = <void*><intptr_t>(arg._ptr.getPtr())
168+
continue
169+
elif isinstance(arg, int):
170+
# Here's the dilemma: We want to have a fast path to pass in Python
171+
# integers as pointer addresses, but one could also (mistakenly) pass
172+
# it with the intention of passing a scalar integer. It's a mistake
173+
# bacause a Python int is ambiguous (arbitrary width). Our judgement
174+
# call here is to treat it as a pointer address, without any warning!
175+
prepare_arg[intptr_t](self.data, self.data_addresses, arg, i)
176+
continue
177+
elif isinstance(arg, float):
178+
prepare_arg[double](self.data, self.data_addresses, arg, i)
179+
continue
180+
elif isinstance(arg, complex):
181+
prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i)
182+
continue
183+
elif isinstance(arg, bool):
184+
prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i)
185+
continue
186+
187+
not_prepared = prepare_numpy_arg(self.data, self.data_addresses, arg, i)
188+
if not_prepared != 0:
189+
not_prepared = prepare_ctypes_arg(self.data, self.data_addresses, arg, i)
190+
if not_prepared != 0:
191+
# TODO: support ctypes/numpy struct
192+
raise TypeError
193+
194+
self.kernel_args = kernel_args
195+
self.ptr = <intptr_t>self.data_addresses.data()
196+
197+
def __dealloc__(self):
198+
for data in self.data:
199+
if data:
200+
PyMem_Free(data)

cuda_core/cuda/core/_launcher.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
import numpy as np
99

1010
from cuda import cuda, cudart
11-
from cuda.core._utils import CUDAError, check_or_create_options, handle_return
11+
from cuda.core._kernel_arg_handler import ParamHolder
1212
from cuda.core._memory import Buffer
1313
from cuda.core._module import Kernel
1414
from cuda.core._stream import Stream
15+
from cuda.core._utils import CUDAError, check_or_create_options, handle_return
1516

1617

1718
@dataclass
@@ -80,19 +81,8 @@ def launch(kernel, config, *kernel_args):
8081
drv_cfg.numAttrs = 0 # FIXME
8182

8283
# TODO: merge with HelperKernelParams?
83-
num_args = len(kernel_args)
84-
args_ptr = 0
85-
if num_args:
86-
# FIXME: support args passed by value
87-
args = np.empty(num_args, dtype=np.intp)
88-
for i, arg in enumerate(kernel_args):
89-
if isinstance(arg, Buffer):
90-
# this is super weird... we need the address of where the actual
91-
# buffer address is stored...
92-
args[i] = arg._ptr.getPtr()
93-
else:
94-
raise NotImplementedError
95-
args_ptr = args.ctypes.data
84+
kernel_args = ParamHolder(kernel_args)
85+
args_ptr = kernel_args.ptr
9686

9787
handle_return(cuda.cuLaunchKernelEx(
9888
drv_cfg, int(kernel._handle), args_ptr, 0))

cuda_core/examples/saxpy.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import sys
2+
3+
from cuda.core import Device
4+
from cuda.core import LaunchConfig, launch
5+
from cuda.core import Program
6+
7+
import cupy as cp
8+
9+
10+
# compute out = a * x + y
11+
code = """
12+
template<typename T>
13+
__global__ void saxpy(const T a,
14+
const T* x,
15+
const T* y,
16+
T* out,
17+
size_t N) {
18+
const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x;
19+
for (size_t i=tid; i<N; i+=gridDim.x*blockDim.x) {
20+
out[tid] = a * x[tid] + y[tid];
21+
}
22+
}
23+
"""
24+
25+
26+
dev = Device()
27+
dev.set_current()
28+
s = dev.create_stream()
29+
30+
# prepare program
31+
prog = Program(code, code_type="c++")
32+
mod = prog.compile(
33+
"cubin",
34+
options=("-std=c++11", "-arch=sm_" + "".join(f"{i}" for i in dev.compute_capability),),
35+
logs=sys.stdout,
36+
name_expressions=("saxpy<float>", "saxpy<double>"))
37+
38+
# run in single precision
39+
ker = mod.get_kernel("saxpy<float>")
40+
dtype = cp.float32
41+
42+
# prepare input/output
43+
size = cp.uint64(64)
44+
a = dtype(10)
45+
x = cp.random.random(size, dtype=dtype)
46+
y = cp.random.random(size, dtype=dtype)
47+
out = cp.empty_like(x)
48+
dev.sync() # cupy runs on a different stream from s, so sync before accessing
49+
50+
# prepare launch
51+
block = 32
52+
grid = int((size + block - 1) // block)
53+
config = LaunchConfig(grid=grid, block=block, stream=s)
54+
ker_args = (a, x.data.ptr, y.data.ptr, out.data.ptr, size)
55+
56+
# launch kernel on stream s
57+
launch(ker, config, *ker_args)
58+
s.sync()
59+
60+
# check result
61+
assert cp.allclose(out, a*x+y)
62+
63+
# let's repeat again, this time allocates our own out buffer instead of cupy's
64+
# run in double precision
65+
ker = mod.get_kernel("saxpy<double>")
66+
dtype = cp.float64
67+
68+
# prepare input
69+
size = cp.uint64(128)
70+
a = dtype(42)
71+
x = cp.random.random(size, dtype=dtype)
72+
y = cp.random.random(size, dtype=dtype)
73+
dev.sync()
74+
75+
# prepare output
76+
buf = dev.allocate(size * 8, # = dtype.itemsize
77+
stream=s)
78+
79+
# prepare launch
80+
block = 64
81+
grid = int((size + block - 1) // block)
82+
config = LaunchConfig(grid=grid, block=block, stream=s)
83+
ker_args = (a, x.data.ptr, y.data.ptr, buf, size)
84+
85+
# launch kernel on stream s
86+
launch(ker, config, *ker_args)
87+
s.sync()
88+
89+
# check result
90+
# we wrap output buffer as a cupy array for simplicity
91+
out = cp.ndarray(size, dtype=dtype,
92+
memptr=cp.cuda.MemoryPointer(cp.cuda.UnownedMemory(int(buf.handle), buf.size, buf), 0))
93+
assert cp.allclose(out, a*x+y)
94+
95+
# clean up resources that we allocate
96+
# cupy cleans up automatically the rest
97+
buf.close(s)
98+
s.close()
99+
100+
print("done!")

cuda_core/setup.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@
1717
sources=["cuda/core/_memoryview.pyx"],
1818
language="c++",
1919
),
20+
Extension(
21+
"cuda.core._kernel_arg_handler",
22+
sources=["cuda/core/_kernel_arg_handler.pyx"],
23+
language="c++",
24+
),
2025
)
2126

2227

0 commit comments

Comments
 (0)