@@ -273,31 +273,52 @@ def test_from_buffer_disallowed_negative_offset():
273273 StridedMemoryView .from_buffer (buffer , layout )
274274
275275
276+ class _EnforceCAIView :
277+ def __init__ (self , array ):
278+ self .array = array
279+ self .__cuda_array_interface__ = array .__cuda_array_interface__
280+
281+
282+ def _get_ptr (array ):
283+ if isinstance (array , np .ndarray ):
284+ return array .ctypes .data
285+ else :
286+ assert isinstance (array , cp .ndarray )
287+ return array .data .ptr
288+
289+
276290@pytest .mark .parametrize (
277- ("shape" , "slices" , "stride_order" ),
291+ ("shape" , "slices" , "stride_order" , "view_as" ),
278292 [
279- (shape , slices , stride_order )
293+ (shape , slices , stride_order , view_as )
280294 for shape , slices in [
281295 ((5 , 6 ), (2 , slice (1 , - 1 ))),
282296 ((10 , 13 , 11 ), (slice (None , None , 2 ), slice (None , None , - 1 ), slice (2 , - 3 ))),
283297 ]
284298 for stride_order in ["C" , "F" ]
299+ for view_as in ["dlpack" , "cai" ]
285300 ],
286301)
287- def test_from_buffer_sliced_external (shape , slices , stride_order ):
288- if np is None :
289- pytest .skip ("NumPy is not installed" )
290- a = np .arange (math .prod (shape ), dtype = np .int32 ).reshape (shape , order = stride_order )
291- view = StridedMemoryView (a , - 1 )
302+ def test_from_buffer_sliced_external (shape , slices , stride_order , view_as ):
303+ if view_as == "dlpack" :
304+ if np is None :
305+ pytest .skip ("NumPy is not installed" )
306+ a = np .arange (math .prod (shape ), dtype = np .int32 ).reshape (shape , order = stride_order )
307+ view = StridedMemoryView (a , - 1 )
308+ else :
309+ if cp is None :
310+ pytest .skip ("CuPy is not installed" )
311+ a = cp .arange (math .prod (shape ), dtype = cp .int32 ).reshape (shape , order = stride_order )
312+ view = StridedMemoryView (_EnforceCAIView (a ), - 1 )
292313 layout = view .layout
293314 assert layout .is_dense
294315 assert layout .required_size_in_bytes () == a .nbytes
295- assert view .ptr == a . ctypes . data
316+ assert view .ptr == _get_ptr ( a )
296317
297318 sliced_layout = layout [slices ]
298319 sliced_view = view .view (sliced_layout )
299320 a_sliced = a [slices ]
300- assert sliced_view .ptr == a_sliced . ctypes . data
321+ assert sliced_view .ptr == _get_ptr ( a_sliced )
301322 assert sliced_view .ptr != view .ptr
302323
303324 assert 0 <= sliced_layout .required_size_in_bytes () <= a .nbytes
0 commit comments