Skip to content

Commit bbb227b

Browse files
committed
Layout tests for SMV created from CAI
Signed-off-by: Kamil Tokarski <ktokarski@nvidia.com>
1 parent 598a2f1 commit bbb227b

File tree

1 file changed

+30
-9
lines changed

1 file changed

+30
-9
lines changed

cuda_core/tests/test_utils.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -273,31 +273,52 @@ def test_from_buffer_disallowed_negative_offset():
273273
StridedMemoryView.from_buffer(buffer, layout)
274274

275275

276+
class _EnforceCAIView:
277+
def __init__(self, array):
278+
self.array = array
279+
self.__cuda_array_interface__ = array.__cuda_array_interface__
280+
281+
282+
def _get_ptr(array):
283+
if isinstance(array, np.ndarray):
284+
return array.ctypes.data
285+
else:
286+
assert isinstance(array, cp.ndarray)
287+
return array.data.ptr
288+
289+
276290
@pytest.mark.parametrize(
277-
("shape", "slices", "stride_order"),
291+
("shape", "slices", "stride_order", "view_as"),
278292
[
279-
(shape, slices, stride_order)
293+
(shape, slices, stride_order, view_as)
280294
for shape, slices in [
281295
((5, 6), (2, slice(1, -1))),
282296
((10, 13, 11), (slice(None, None, 2), slice(None, None, -1), slice(2, -3))),
283297
]
284298
for stride_order in ["C", "F"]
299+
for view_as in ["dlpack", "cai"]
285300
],
286301
)
287-
def test_from_buffer_sliced_external(shape, slices, stride_order):
288-
if np is None:
289-
pytest.skip("NumPy is not installed")
290-
a = np.arange(math.prod(shape), dtype=np.int32).reshape(shape, order=stride_order)
291-
view = StridedMemoryView(a, -1)
302+
def test_from_buffer_sliced_external(shape, slices, stride_order, view_as):
303+
if view_as == "dlpack":
304+
if np is None:
305+
pytest.skip("NumPy is not installed")
306+
a = np.arange(math.prod(shape), dtype=np.int32).reshape(shape, order=stride_order)
307+
view = StridedMemoryView(a, -1)
308+
else:
309+
if cp is None:
310+
pytest.skip("CuPy is not installed")
311+
a = cp.arange(math.prod(shape), dtype=cp.int32).reshape(shape, order=stride_order)
312+
view = StridedMemoryView(_EnforceCAIView(a), -1)
292313
layout = view.layout
293314
assert layout.is_dense
294315
assert layout.required_size_in_bytes() == a.nbytes
295-
assert view.ptr == a.ctypes.data
316+
assert view.ptr == _get_ptr(a)
296317

297318
sliced_layout = layout[slices]
298319
sliced_view = view.view(sliced_layout)
299320
a_sliced = a[slices]
300-
assert sliced_view.ptr == a_sliced.ctypes.data
321+
assert sliced_view.ptr == _get_ptr(a_sliced)
301322
assert sliced_view.ptr != view.ptr
302323

303324
assert 0 <= sliced_layout.required_size_in_bytes() <= a.nbytes

0 commit comments

Comments
 (0)