Skip to content

Commit 905e5f4

Browse files
committed
align with latest design
1 parent 4a5457e commit 905e5f4

File tree

7 files changed

+47
-28
lines changed

7 files changed

+47
-28
lines changed

cuda_core/cuda/core/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
#
33
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
44

5-
from cuda.core._compiler import Compiler
65
from cuda.core._device import Device
76
from cuda.core._event import EventOptions
87
from cuda.core._launcher import LaunchConfig, launch
8+
from cuda.core._program import Program
99
from cuda.core._stream import Stream, StreamOptions
1010
from cuda.core._version import __version__

cuda_core/cuda/core/_device.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from cuda import cuda, cudart
1010
from cuda.core._utils import handle_return, ComputeCapability, CUDAError, \
11-
precondition
11+
precondition
1212
from cuda.core._context import Context, ContextOptions
1313
from cuda.core._memory import _DefaultAsyncMempool, Buffer, MemoryResource
1414
from cuda.core._stream import default_stream, Stream, StreamOptions
@@ -50,7 +50,7 @@ def __new__(cls, device_id=None):
5050
def _check_context_initialized(self, *args, **kwargs):
5151
if not self._has_inited:
5252
raise CUDAError("the device is not yet initialized, "
53-
"perhaps you forgot to call .use() first?")
53+
"perhaps you forgot to call .set_current() first?")
5454

5555
@property
5656
def device_id(self) -> int:
@@ -120,14 +120,14 @@ def __int__(self):
120120
def __repr__(self):
121121
return f"<Device {self._id} ({self.name})>"
122122

123-
def use(self, ctx: Context=None) -> Union[Context, None]:
123+
def set_current(self, ctx: Context=None) -> Union[Context, None]:
124124
"""
125125
Entry point of this object. Users always start a code by
126126
calling this method, e.g.
127127
128128
>>> from cuda.core import Device
129129
>>> dev0 = Device(0)
130-
>>> dev0.use()
130+
>>> dev0.set_current()
131131
>>> # ... do work on device 0 ...
132132
133133
The optional ctx argument is for advanced users to bind a

cuda_core/cuda/core/_event.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
@dataclass
1515
class EventOptions:
16-
disable_timing: Optional[bool] = False
16+
enable_timing: Optional[bool] = False
1717
busy_waited_sync: Optional[bool] = False
1818
support_ipc: Optional[bool] = False
1919

@@ -37,8 +37,9 @@ def _init(options: Optional[EventOptions]=None):
3737

3838
options = check_or_create_options(EventOptions, options, "Event options")
3939
flags = 0x0
40-
self._timing_disabled = self._busy_waited = False
41-
if options.disable_timing:
40+
self._timing_disabled = False
41+
self._busy_waited = False
42+
if not options.enable_timing:
4243
flags |= cuda.CUevent_flags.CU_EVENT_DISABLE_TIMING
4344
self._timing_disabled = True
4445
if options.busy_waited_sync:
@@ -91,4 +92,4 @@ def is_done(self) -> bool:
9192

9293
@property
9394
def handle(self) -> int:
94-
return self._handle
95+
return int(self._handle)

cuda_core/cuda/core/_memory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def device_id(self) -> int:
7575
return self._mr.device_id
7676
raise NotImplementedError
7777

78-
def copy_to(self, dst: Buffer=None, stream=None) -> Buffer:
78+
def copy_to(self, dst: Buffer=None, *, stream) -> Buffer:
7979
# Copy from this buffer to the dst buffer asynchronously on the
8080
# given stream. The dst buffer is returned. If the dst is not provided,
8181
# allocate one from self.memory_resource. Raise an exception if the
@@ -92,7 +92,7 @@ def copy_to(self, dst: Buffer=None, stream=None) -> Buffer:
9292
cuda.cuMemcpyAsync(dst._ptr, self._ptr, self._size, stream._handle))
9393
return dst
9494

95-
def copy_from(self, src: Buffer, stream=None):
95+
def copy_from(self, src: Buffer, *, stream):
9696
# Copy from the src buffer to this buffer asynchronously on the
9797
# given stream. Raise an exception if the stream is not provided.
9898
if stream is None:

cuda_core/cuda/core/_memoryview.pyx

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,17 @@ cdef class GPUMemoryView:
3131
readonly: bool = None
3232
obj: Any = None
3333

34+
def __init__(self, obj=None, stream_ptr=None):
35+
if obj is not None:
36+
# populate self's attributes
37+
if check_has_dlpack(obj):
38+
view_as_dlpack(obj, stream_ptr, self)
39+
else:
40+
view_as_cai(obj, stream_ptr, self)
41+
else:
42+
# default construct
43+
pass
44+
3445
def __repr__(self):
3546
return (f"GPUMemoryView(ptr={self.ptr},\n"
3647
+ f" shape={self.shape},\n"
@@ -57,22 +68,27 @@ cdef str get_simple_repr(obj):
5768
return obj_repr
5869

5970

71+
cdef bint check_has_dlpack(obj) except*:
72+
cdef bint has_dlpack
73+
if hasattr(obj, "__dlpack__") and hasattr(obj, "__dlpack_device__"):
74+
has_dlpack = True
75+
elif hasattr(obj, "__cuda_array_interface__"):
76+
has_dlpack = False
77+
else:
78+
raise RuntimeError(
79+
"the input object does not support any data exchange protocol")
80+
return has_dlpack
81+
82+
6083
cdef class _GPUMemoryViewProxy:
6184

6285
cdef:
6386
object obj
6487
bint has_dlpack
6588

6689
def __init__(self, obj):
67-
if hasattr(obj, "__dlpack__") and hasattr(obj, "__dlpack_device__"):
68-
has_dlpack = True
69-
elif hasattr(obj, "__cuda_array_interface__"):
70-
has_dlpack = False
71-
else:
72-
raise RuntimeError(
73-
"the input object does not support any data exchange protocol")
7490
self.obj = obj
75-
self.has_dlpack = has_dlpack
91+
self.has_dlpack = check_has_dlpack(obj)
7692

7793
cpdef GPUMemoryView view(self, stream_ptr=None):
7894
if self.has_dlpack:
@@ -81,7 +97,7 @@ cdef class _GPUMemoryViewProxy:
8197
return view_as_cai(self.obj, stream_ptr)
8298

8399

84-
cdef GPUMemoryView view_as_dlpack(obj, stream_ptr):
100+
cdef GPUMemoryView view_as_dlpack(obj, stream_ptr, view=None):
85101
cdef int dldevice, device_id, i
86102
cdef bint device_accessible, versioned, is_readonly
87103
dldevice, device_id = obj.__dlpack_device__()
@@ -144,7 +160,7 @@ cdef GPUMemoryView view_as_dlpack(obj, stream_ptr):
144160
dl_tensor = &dlm_tensor.dl_tensor
145161
is_readonly = False
146162

147-
cdef GPUMemoryView buf = GPUMemoryView()
163+
cdef GPUMemoryView buf = GPUMemoryView() if view is None else view
148164
buf.ptr = <intptr_t>(dl_tensor.data)
149165
buf.shape = tuple(int(dl_tensor.shape[i]) for i in range(dl_tensor.ndim))
150166
if dl_tensor.strides:
@@ -226,7 +242,7 @@ cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
226242
return numpy.dtype(np_dtype)
227243

228244

229-
cdef GPUMemoryView view_as_cai(obj, stream_ptr):
245+
cdef GPUMemoryView view_as_cai(obj, stream_ptr, view=None):
230246
cdef dict cai_data = obj.__cuda_array_interface__
231247
if cai_data["version"] < 3:
232248
raise BufferError("only CUDA Array Interface v3 or above is supported")
@@ -235,7 +251,7 @@ cdef GPUMemoryView view_as_cai(obj, stream_ptr):
235251
if stream_ptr is None:
236252
raise BufferError("stream=None is ambiguous with view()")
237253

238-
cdef GPUMemoryView buf = GPUMemoryView()
254+
cdef GPUMemoryView buf = GPUMemoryView() if view is None else view
239255
buf.obj = obj
240256
buf.ptr, buf.readonly = cai_data["data"]
241257
buf.shape = cai_data["shape"]

cuda_core/cuda/core/_module.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,14 @@ def __init__(self):
3030
@staticmethod
3131
def _from_obj(obj, mod):
3232
assert isinstance(obj, (cuda.CUkernel, cuda.CUfunction))
33-
assert isinstance(mod, Module)
33+
assert isinstance(mod, ObjectCode)
3434
ker = Kernel.__new__(Kernel)
3535
ker._handle = obj
3636
ker._module = mod
3737
return ker
3838

3939

40-
class Module:
40+
class ObjectCode:
4141

4242
__slots__ = ("_handle", "_code_type", "_module", "_loader", "_sym_map")
4343
_supported_code_type = ("cubin", "ptx", "fatbin")
cuda_core/cuda/core/_compiler.py → cuda_core/cuda/core/_program.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44

55
from cuda import nvrtc
66
from cuda.core._utils import handle_return
7-
from cuda.core._module import Module
7+
from cuda.core._module import ObjectCode
88

99

10-
class Compiler:
10+
class Program:
1111

1212
__slots__ = ("_handle", "_backend", )
1313
_supported_code_type = ("c++", )
@@ -26,6 +26,8 @@ def __init__(self, code, code_type):
2626
self._handle = handle_return(
2727
nvrtc.nvrtcCreateProgram(code.encode(), b"", 0, [], []))
2828
self._backend = "nvrtc"
29+
else:
30+
raise NotImplementedError
2931

3032
def __del__(self):
3133
self.close()
@@ -72,7 +74,7 @@ def compile(self, target_type, options=(), name_expressions=(), logs=None):
7274

7375
# TODO: handle jit_options for ptx?
7476

75-
return Module(data, target_type, symbol_mapping=symbol_mapping)
77+
return ObjectCode(data, target_type, symbol_mapping=symbol_mapping)
7678

7779
@property
7880
def backend(self):

0 commit comments

Comments
 (0)