Skip to content

Commit 48a305c

Browse files
committed
fix dtype repr and stream pass-through
1 parent ab83c5b commit 48a305c

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

cuda_py/cuda/py/_memoryview.pyx

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ cdef class GPUMemoryView:
2929
return (f"GPUMemoryView(ptr={self.ptr},\n"
3030
+ f" shape={self.shape},\n"
3131
+ f" strides={self.strides},\n"
32-
+ f" dtype={get_simple_repr(numpy.dtype(self.dtype))},\n"
32+
+ f" dtype={self.dtype.__name__},\n"
3333
+ f" device_id={self.device_id},\n"
3434
+ f" device_accessible={self.device_accessible},\n"
3535
+ f" readonly={self.readonly},\n"
@@ -39,7 +39,7 @@ cdef class GPUMemoryView:
3939
cdef str get_simple_repr(obj):
4040
cdef object obj_class = obj.__class__
4141
cdef str obj_repr
42-
if obj_class.__module__ in (None, "__builtin__"):
42+
if obj_class.__module__ in (None, "builtins"):
4343
obj_repr = obj_class.__name__
4444
else:
4545
obj_repr = f"{obj_class.__module__}.{obj_class.__name__}"
@@ -78,17 +78,24 @@ cdef GPUMemoryView view_as_dlpack(obj, stream_ptr):
7878
if dldevice == _kDLCPU:
7979
device_accessible = False
8080
assert device_id == 0
81-
stream_ptr = None
81+
if stream_ptr is None:
82+
raise BufferError("stream=None is ambiguous with view()")
83+
elif stream_ptr == -1:
84+
stream_ptr = None
8285
elif dldevice == _kDLCUDA:
8386
device_accessible = True
84-
stream_ptr = -1
87+
# no need to check other stream values, it's a pass-through
88+
if stream_ptr is None:
89+
raise BufferError("stream=None is ambiguous with view()")
8590
elif dldevice == _kDLCUDAHost:
8691
device_accessible = True
8792
assert device_id == 0
88-
stream_ptr = None
93+
# just do a pass-through without any checks, as pinned memory can be
94+
# accessed on both host and device
8995
elif dldevice == _kDLCUDAManaged:
9096
device_accessible = True
91-
stream_ptr = -1
97+
# just do a pass-through without any checks, as managed memory can be
98+
# accessed on both host and device
9299
else:
93100
raise BufferError("device not supported")
94101

0 commit comments

Comments
 (0)