From 85d4a3a46fcc17b3ab7f4b4af7ae42dd40c5b4f5 Mon Sep 17 00:00:00 2001
From: Kamil Tokarski
Date: Fri, 21 Nov 2025 19:18:23 +0100
Subject: [PATCH 01/20] Add StridedLayout

Signed-off-by: Kamil Tokarski
---
 cuda_core/cuda/core/experimental/_layout.pxd  | 658 +++++++++++++++
 cuda_core/cuda/core/experimental/_layout.pyx  | 799 ++++++++++++++++++
 .../cuda/core/experimental/_memoryview.pyx    |  95 ++-
 .../core/experimental/_utils/cuda_utils.pxd   |   7 +-
 .../cuda/core/experimental/include/layout.hpp |  50 ++
 5 files changed, 1573 insertions(+), 36 deletions(-)
 create mode 100644 cuda_core/cuda/core/experimental/_layout.pxd
 create mode 100644 cuda_core/cuda/core/experimental/_layout.pyx
 create mode 100644 cuda_core/cuda/core/experimental/include/layout.hpp

diff --git a/cuda_core/cuda/core/experimental/_layout.pxd b/cuda_core/cuda/core/experimental/_layout.pxd
new file mode 100644
index 0000000000..2576349cc3
--- /dev/null
+++ b/cuda_core/cuda/core/experimental/_layout.pxd
@@ -0,0 +1,658 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+cimport cython
+from cython.operator cimport dereference as deref
+
+from libc.stdint cimport int64_t, uint32_t, intptr_t
+from libcpp cimport vector
+
+ctypedef int64_t extent_t
+ctypedef int64_t stride_t
+ctypedef int axis_t
+
+ctypedef uint32_t axes_mask_t # MUST be exactly STRIDED_LAYOUT_MAX_NDIM bits wide
+ctypedef uint32_t property_mask_t
+
+ctypedef vector.vector[stride_t] extents_strides_t
+ctypedef vector.vector[axis_t] axis_vec_t
+
+from cuda.core.experimental._utils cimport cuda_utils
+
+
+ctypedef fused integer_t:
+    int64_t
+    int
+
+
+cdef extern from "include/layout.hpp":
+
+    cdef int STRIDED_LAYOUT_MAX_NDIM
+    cdef int AXIS_MASK_ALL
+    int64_t _c_abs(int64_t x) nogil
+    void _order_from_strides(axis_vec_t& indices, extent_t* shape, stride_t* strides, int ndim) except + nogil
+    void _swap(extents_strides_t &a, extents_strides_t &b) noexcept nogil
+    void _swap(int64_t* a, int64_t* b) noexcept nogil
+    void _swap(int a, int b) noexcept nogil
+    void _swap(axis_vec_t &a, axis_vec_t &b) noexcept nogil
+
+
+cdef enum OrderFlag:
+    ORDER_NONE = 0
+    ORDER_C = 1
+    ORDER_F = 2
+    ORDER_PERM = 3
+
+
+cdef enum Property:
+    PROP_IS_UNIQUE = 1 << 0
+    PROP_IS_CONTIGUOUS_C = 1 << 1
+    PROP_IS_CONTIGUOUS_F = 1 << 2
+    PROP_IS_CONTIGUOUS_ANY = 1 << 3
+    PROP_REQUIRED_SIZE_IN_BYTES = 1 << 4
+    PROP_SHAPE = 1 << 5
+    PROP_STRIDES = 1 << 6
+    PROP_STRIDES_IN_BYTES = 1 << 7
+    PROP_STRIDE_ORDER = 1 << 8
+    PROP_VOLUME = 1 << 9
+
+
+cdef struct BaseLayout:
+    # A struct holding the shape and strides for the layout.
+    # Use ``init_base_layout`` to initialize the layout; it will
+    # set the ``shape`` and ``strides`` pointers to point to
+    # ndim contiguous integer arrays.
+    # The ``shape`` pointer must not be NULL; ``strides`` can be
+    # set to NULL by the user to indicate a C-contiguous layout.
+    # Uses a single _mem allocation to reduce overhead
+    # (allocation and exception checks).
+
+    extents_strides_t _mem
+    extent_t* shape
+    stride_t* strides
+    int ndim
+
+
+@cython.final
+cdef class StridedLayout:
+
+    # Definition
+    cdef:
+        BaseLayout base
+
+    readonly:
+        int itemsize
+        stride_t slice_offset
+
+    # Lazy properties computed from the defining values. 
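+    #
+    # Example of how the defining values relate to the lazy properties
+    # (a hedged sketch of the intended Python-level usage, assuming the
+    # class is importable from cuda.core.experimental._layout; strides
+    # are expressed in elements, not bytes):
+    #
+    #     >>> layout = StridedLayout.dense((2, 3, 4), itemsize=4, stride_order='C')
+    #     >>> layout.strides                    # computed lazily, then memoized
+    #     (12, 4, 1)
+    #     >>> layout.volume
+    #     24
+    #     >>> layout.required_size_in_bytes     # 24 elements * 4 bytes each
+    #     96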
+ cdef: + # Set to 0 to invalidate all properties, + # whenever a defining value is changed + property_mask_t _prop_mask + + # C and Python properties + property_mask_t _boolean_props + int64_t _required_size_in_bytes + int64_t _volume + + # Python properties + tuple _py_shape + tuple _py_strides + tuple _py_strides_in_bytes + tuple _py_stride_order + + # ============================== + # Initialization + # ============================== + + cdef inline int _init(StridedLayout self, BaseLayout& base, int itemsize, bint strides_in_bytes=False) except -1 nogil: + _validate_itemsize(itemsize) + + if base.strides != NULL and strides_in_bytes: + _divide_strides(base, itemsize) + + self.itemsize = itemsize + self.slice_offset = 0 + _swap_layout(self.base, base) + return 0 + + cdef inline stride_t _init_dense(StridedLayout self, BaseLayout& base, int itemsize, OrderFlag order_flag, axis_vec_t* stride_order=NULL) except -1 nogil: + _validate_itemsize(itemsize) + + cdef stride_t volume + if order_flag == ORDER_C: + volume = _dense_strides_c(base) + elif order_flag == ORDER_F: + volume = _dense_strides_f(base) + elif order_flag == ORDER_PERM: + if stride_order == NULL: + raise ValueError("stride_order is required for ORDER_PERM") + volume = _dense_strides_in_order(base, deref(stride_order)) + else: + raise ValueError("The stride_order must be 'C', 'F', or a permutation.") + + self.itemsize = itemsize + self.slice_offset = 0 + _swap_layout(self.base, base) + self._volume = volume + _mark_property_valid(self, PROP_VOLUME) + return 0 + + cdef inline int init_from_ptr(StridedLayout self, int ndim, extent_t* shape, stride_t* strides, int itemsize, bint strides_in_bytes=False) except -1 nogil: + cdef BaseLayout base + _init_base_layout_from_ptr(base, ndim, shape, strides) + return self._init(base, itemsize, strides_in_bytes) + + cdef inline int init_dense_from_ptr(StridedLayout self, int ndim, extent_t* shape, int itemsize, OrderFlag order_flag, axis_vec_t* stride_order=NULL) except -1 nogil: + cdef BaseLayout base + _init_base_layout_from_ptr(base, ndim, shape, NULL) + return self._init_dense(base, itemsize, order_flag, stride_order) + + cdef inline int init_from_tuple(StridedLayout self, tuple shape, tuple strides, int itemsize, bint strides_in_bytes=False) except -1: + cdef BaseLayout base + _init_base_layout_from_tuple(base, shape, strides) + return self._init(base, itemsize, strides_in_bytes) + + cdef inline int init_dense_from_tuple(StridedLayout self, tuple shape, int itemsize, object stride_order) except -1: + cdef axis_vec_t stride_order_vec + cdef OrderFlag order_flag = _stride_order2vec(stride_order_vec, stride_order) + + if order_flag == ORDER_NONE: + raise ValueError(f"The stride_order must be 'C', 'F', or a permutation tuple. 
Got: {stride_order}") + + cdef BaseLayout base + _init_base_layout_from_tuple(base, shape, None) + return self._init_dense(base, itemsize, order_flag, &stride_order_vec) + + # ============================== + # Properties + # ============================== + + cdef inline tuple get_shape_tuple(StridedLayout self): + if not _has_valid_property(self, PROP_SHAPE): + self._py_shape = cuda_utils.carray_integer_t_to_tuple(self.base.shape, self.base.ndim) + _mark_property_valid(self, PROP_SHAPE) + return self._py_shape + + cdef inline tuple get_strides_tuple(StridedLayout self): + if not _has_valid_property(self, PROP_STRIDES): + if self.base.strides == NULL: + self._py_strides = None + else: + self._py_strides = cuda_utils.carray_integer_t_to_tuple(self.base.strides, self.base.ndim) + _mark_property_valid(self, PROP_STRIDES) + return self._py_strides + + cdef inline int get_strides_in_bytes(StridedLayout self, extents_strides_t& strides) except -1 nogil: + if self.base.strides != NULL: + strides.resize(self.base.ndim) + for i in range(self.base.ndim): + strides[i] = _overflow_checked_mul(self.base.strides[i], self.itemsize) + return 0 + + cdef inline tuple get_strides_in_bytes_tuple(StridedLayout self): + if _has_valid_property(self, PROP_STRIDES_IN_BYTES): + return self._py_strides_in_bytes + cdef extents_strides_t strides + if self.base.strides == NULL: + self._py_strides_in_bytes = None + else: + self.get_strides_in_bytes(strides) + self._py_strides_in_bytes = cuda_utils.carray_integer_t_to_tuple(strides.data(), strides.size()) + _mark_property_valid(self, PROP_STRIDES_IN_BYTES) + return self._py_strides_in_bytes + + cdef inline int64_t get_volume(StridedLayout self) except -1 nogil: + if not _has_valid_property(self, PROP_VOLUME): + self._volume = _volume(self.base) + _mark_property_valid(self, PROP_VOLUME) + return self._volume + + cdef inline int get_stride_order(StridedLayout self, axis_vec_t& stride_order) except -1 nogil: + _order_from_strides(stride_order, self.base.shape, self.base.strides, self.base.ndim) + return 0 + + cdef inline tuple get_stride_order_tuple(StridedLayout self): + if _has_valid_property(self, PROP_STRIDE_ORDER): + return self._py_stride_order + cdef axis_vec_t stride_order + self.get_stride_order(stride_order) + self._py_stride_order = cuda_utils.carray_integer_t_to_tuple(stride_order.data(), stride_order.size()) + _mark_property_valid(self, PROP_STRIDE_ORDER) + return self._py_stride_order + + cdef inline bint get_is_unique(StridedLayout self) except -1 nogil: + if _has_valid_property(self, PROP_IS_UNIQUE): + return _boolean_property(self, PROP_IS_UNIQUE) + if self.base.strides == NULL or self.get_volume() == 0: + return _set_boolean_property(self, PROP_IS_UNIQUE, True) + cdef axis_vec_t stride_order + self.get_stride_order(stride_order) + return _set_boolean_property(self, PROP_IS_UNIQUE, _is_unique(self.base, stride_order)) + + cdef inline bint get_is_contiguous_c(StridedLayout self) except -1 nogil: + if _has_valid_property(self, PROP_IS_CONTIGUOUS_C): + return _boolean_property(self, PROP_IS_CONTIGUOUS_C) + return _set_boolean_property(self, PROP_IS_CONTIGUOUS_C, _is_contiguous_c(self.get_volume(), self.base)) + + cdef inline bint get_is_contiguous_f(StridedLayout self) except -1 nogil: + if _has_valid_property(self, PROP_IS_CONTIGUOUS_F): + return _boolean_property(self, PROP_IS_CONTIGUOUS_F) + return _set_boolean_property(self, PROP_IS_CONTIGUOUS_F, _is_contiguous_f(self.get_volume(), self.base)) + + cdef inline bint get_is_contiguous_any(StridedLayout self) 
except -1 nogil: + if _has_valid_property(self, PROP_IS_CONTIGUOUS_ANY): + return _boolean_property(self, PROP_IS_CONTIGUOUS_ANY) + cdef axis_vec_t stride_order + self.get_stride_order(stride_order) + return _set_boolean_property(self, PROP_IS_CONTIGUOUS_ANY, _is_contiguous_any(self.get_volume(), self.base, stride_order)) + + cdef inline int get_offset_bounds(StridedLayout self, stride_t& min_offset, stride_t& max_offset) except -1 nogil: + min_offset = 0 + max_offset = 0 + if self.base.strides == NULL: + max_offset = self.get_volume() - 1 + return 0 + cdef int ndim = self.base.ndim + cdef stride_t stride + cdef extent_t extent + for i in range(ndim): + stride = self.base.strides[i] # can be negative + extent = self.base.shape[i] # must be non-negative + if extent == 0: + min_offset = 0 + max_offset = -1 # so that max_offset - min_offset + 1 = 0 + return 0 + if stride <= 0: + min_offset = _overflow_checked_sum(min_offset, _overflow_checked_mul(stride, (extent - 1))) + else: + max_offset = _overflow_checked_sum(max_offset, _overflow_checked_mul(stride, (extent - 1))) + return 0 + + cdef inline int64_t get_required_size_in_bytes(StridedLayout self) except -1 nogil: + if _has_valid_property(self, PROP_REQUIRED_SIZE_IN_BYTES): + return self._required_size_in_bytes + cdef stride_t min_offset = 0 + cdef stride_t max_offset = 0 + self.get_offset_bounds(min_offset, max_offset) + if self.slice_offset > 0: + min_offset = min(min_offset, -self.slice_offset) + elif self.slice_offset < 0 and max_offset >= 0: + max_offset = max(max_offset, -self.slice_offset) + cdef int64_t required_size_in_bytes = _overflow_checked_diff(max_offset, min_offset) + required_size_in_bytes = _overflow_checked_sum(required_size_in_bytes, 1) + self._required_size_in_bytes = _overflow_checked_mul(required_size_in_bytes, self.itemsize) + _mark_property_valid(self, PROP_REQUIRED_SIZE_IN_BYTES) + return self._required_size_in_bytes + + cdef inline int64_t get_slice_offset_in_bytes(StridedLayout self) except -1 nogil: + return _overflow_checked_mul(self.slice_offset, self.itemsize) + + cdef axes_mask_t get_flattened_axis_mask(StridedLayout self) except? 
-1 nogil + cdef int get_max_compatible_itemsize(StridedLayout self, int max_itemsize, intptr_t data_ptr, int axis=*) except -1 nogil + + # ============================== + # Layout manipulation + # ============================== + + + cdef int reshape_into(StridedLayout self, StridedLayout out_layout, BaseLayout& new_shape) except -1 nogil + cdef int permute_into(StridedLayout self, StridedLayout out_layout, axis_vec_t& axis_order) except -1 nogil + + cdef int flatten_into(StridedLayout self, StridedLayout out_layout, axes_mask_t axis_mask=*) except -1 nogil + cdef int squeeze_into(StridedLayout self, StridedLayout out_layout) except -1 nogil + cdef int unsqueeze_into(StridedLayout self, StridedLayout out_layout, axis_vec_t& axis_vec) except -1 nogil + cdef int broadcast_into(StridedLayout self, StridedLayout out_layout, BaseLayout& broadcast) except -1 nogil + cdef int pack_into(StridedLayout self, StridedLayout out_layout, int itemsize, intptr_t data_ptr, bint keep_dim, int axis=*) except -1 nogil + cdef int unpack_into(StridedLayout self, StridedLayout out_layout, int itemsize, int axis=*) except -1 nogil + cdef int slice_into(StridedLayout self, StridedLayout out_layout, tuple slices) except -1 + +# ============================== +# Base layout helpers +# ============================== + + +cdef inline int init_base_layout(BaseLayout& layout, int ndim) except -1 nogil: + if ndim > STRIDED_LAYOUT_MAX_NDIM: + raise ValueError(f"Unsupported number of dimensions: {ndim}. Max supported ndim is {STRIDED_LAYOUT_MAX_NDIM}") + # resize(0) is no op, that results in _mem.data() being NULL, + # which would make it tricky to distinguish between strides == NULL + # and strides == tuple() + layout._mem.resize(2 * max(ndim, 1)) + layout.shape = layout._mem.data() + layout.strides = layout._mem.data() + ndim + layout.ndim = ndim + return 0 + + +cdef inline int trim_base_layout(BaseLayout& layout, int ndim) except -1 nogil: + if ndim > layout.ndim: + raise AssertionError(f"Cannot trim layout to {ndim} dimensions, it has {layout.ndim} dimensions") + layout.ndim = ndim + return 0 + + +cdef inline void _swap_layout(BaseLayout& a, BaseLayout& b) noexcept nogil: + _swap(a._mem, b._mem) + _swap(a.shape, b.shape) + _swap(a.strides, b.strides) + _swap(a.ndim, b.ndim) + + +cdef inline void _assure_strides_ptr(BaseLayout& base) noexcept nogil: + if base.strides == NULL: + base.strides = base._mem.data() + base._mem.size() // 2 + + +cdef inline stride_t *get_strides_ptr(BaseLayout& base) except? NULL nogil: + if base.strides != NULL: + return base.strides + cdef stride_t* tmp_strides = base._mem.data() + base._mem.size() // 2 + _dense_strides_c_ptrs(base.ndim, base.shape, tmp_strides) + return tmp_strides + + +cdef inline bint _base_layout_equal(BaseLayout& a, BaseLayout& b) noexcept nogil: + if a.ndim != b.ndim: + return False + for i in range(a.ndim): + if a.shape[i] != b.shape[i]: + return False + if a.strides != NULL or b.strides != NULL: + if a.strides == NULL or b.strides == NULL: + return False + for i in range(a.ndim): + if a.strides[i] != b.strides[i]: + return False + return True + + +@cython.overflowcheck(True) +cdef inline int64_t _volume(BaseLayout& base) except? 
-1 nogil: + cdef int64_t vol = 1 + for i in range(base.ndim): + vol *= base.shape[i] + return vol + + +cdef inline int _divide_strides(BaseLayout& base, int itemsize) except -1 nogil: + cdef stride_t stride + if base.strides == NULL: + raise ValueError("cannot divide strides, layout has no strides") + for i in range(base.ndim): + stride = base.strides[i] // itemsize + if stride * itemsize != base.strides[i]: + raise ValueError("strides must be divisible by itemsize") + base.strides[i] = stride + return 0 + + +cdef inline void _zero_strides_ptr(int ndim, stride_t* strides) noexcept nogil: + for i in range(ndim): + strides[i] = 0 + + +cdef inline void _zero_strides(BaseLayout& base) noexcept nogil: + _assure_strides_ptr(base) + _zero_strides_ptr(base.ndim, base.strides) + + +cdef inline stride_t _dense_strides_c_ptrs(int ndim, extent_t* shape, stride_t* strides) except? -1 nogil: + cdef stride_t stride = 1 + cdef int i = ndim - 1 + while i >= 0: + strides[i] = stride + stride = _overflow_checked_mul(stride, shape[i]) + i -= 1 + if stride == 0: + _zero_strides_ptr(ndim, strides) + return stride + + +cdef inline stride_t _dense_strides_c(BaseLayout& base) except? -1 nogil: + cdef int ndim = base.ndim + _assure_strides_ptr(base) + return _dense_strides_c_ptrs(ndim, base.shape, base.strides) + + +cdef inline stride_t _dense_strides_f(BaseLayout& base) except? -1 nogil: + cdef int ndim = base.ndim + _assure_strides_ptr(base) + cdef stride_t stride = 1 + cdef int i = 0 + while i < ndim: + base.strides[i] = stride + stride = _overflow_checked_mul(stride, base.shape[i]) + i += 1 + if stride == 0: + _zero_strides(base) + return stride + + +cdef inline stride_t _dense_strides_in_order(BaseLayout& base, axis_vec_t& stride_order) except? -1 nogil: + cdef int ndim = base.ndim + if ndim != stride_order.size(): + raise ValueError(f"stride_order must have the same length as shape. Shape has {ndim} dimensions, but stride_order has {stride_order.size()} elements.") + _assure_strides_ptr(base) + cdef stride_t stride = 1 + cdef int i = ndim - 1 + cdef axes_mask_t axis_order_mask = 0 + cdef axes_mask_t axis_mask + cdef axis_t axis + while i >= 0: + axis = stride_order[i] + if not _normalize_axis(axis, ndim): + raise ValueError(f"Invalid stride order: axis {axis} out of range for {ndim}D tensor") + axis_mask = 1 << axis + if axis_order_mask & axis_mask: + raise ValueError(f"The stride order must be a permutation. 
Axis {axis} appears multiple times.") + axis_order_mask |= axis_mask + base.strides[axis] = stride + stride = _overflow_checked_mul(stride, base.shape[axis]) + i -= 1 + if stride == 0: + _zero_strides(base) + return stride + + +cdef inline bint _is_contiguous_c(int64_t volume, BaseLayout& base) except -1 nogil: + if volume == 0 or base.strides == NULL: + return True + cdef int64_t stride = 1 + cdef int64_t j = base.ndim - 1 + cdef extent_t extent + while j >= 0: + extent = base.shape[j] + if extent != 1: + if base.strides[j] != stride: + return False + stride *= extent + j -= 1 + return True + + +cdef inline bint _is_contiguous_f(int64_t volume, BaseLayout& base) except -1 nogil: + if volume == 0: + return True + cdef int ndim = base.ndim + cdef int64_t j = 0 + if base.strides == NULL: + # find first non-singleton dimension + while j < ndim and base.shape[j] == 1: + j += 1 + # if any subsequent dimension is not a singleton, return False + for i in range(j + 1, ndim): + if base.shape[i] != 1: + return False + return True + cdef int64_t stride = 1 + cdef extent_t extent + while j < ndim: + extent = base.shape[j] + if extent != 1: + if base.strides[j] != stride: + return False + stride *= extent + j += 1 + return True + + +cdef inline bint _is_contiguous_any(int64_t volume, BaseLayout& base, axis_vec_t& axis_order) except -1 nogil: + if volume == 0 or base.strides == NULL: + return True + cdef int64_t stride = 1 + cdef int64_t j = base.ndim - 1 + cdef axis_t axis + cdef extent_t extent + while j >= 0: + axis = axis_order[j] + extent = base.shape[axis] + if extent != 1: + if base.strides[axis] != stride: + return False + stride *= extent + j -= 1 + return True + + +cdef inline int _validate_shape(BaseLayout& base) except -1 nogil: + for i in range(base.ndim): + if base.shape[i] < 0: + raise ValueError("Extents must be non-negative") + return 0 + + +cdef inline int _init_base_layout_from_tuple(BaseLayout& base, tuple shape, tuple strides) except -1: + cdef int ndim = len(shape) + init_base_layout(base, ndim) + for i in range(ndim): + base.shape[i] = shape[i] + _validate_shape(base) + + if strides is None: + base.strides = NULL + else: + if len(strides) != ndim: + raise ValueError(f"Strides, if provided, must have the same length as shape. 
Shape has {ndim} dimensions, but strides has {len(strides)} elements.") + for i in range(ndim): + base.strides[i] = strides[i] + return 0 + + +cdef inline int _init_base_layout_from_ptr(BaseLayout& base, int ndim, extent_t* shape, stride_t* strides) except -1 nogil: + init_base_layout(base, ndim) + for i in range(ndim): + base.shape[i] = shape[i] + _validate_shape(base) + + if strides == NULL: + base.strides = NULL + else: + for i in range(ndim): + base.strides[i] = strides[i] + return 0 + +# ============================== +# Strided layout helpers +# ============================== + + +cdef inline bint _has_valid_property(StridedLayout self, Property prop) noexcept nogil: + return self._prop_mask & prop + + +cdef inline void _mark_property_valid(StridedLayout self, Property prop) noexcept nogil: + self._prop_mask |= prop + + +cdef inline bint _boolean_property(StridedLayout self, Property prop) noexcept nogil: + return self._boolean_props & prop + + +cdef inline bint _set_boolean_property(StridedLayout self, Property prop, bint value) noexcept nogil: + if value: + self._boolean_props |= prop + else: + self._boolean_props &= ~prop + _mark_property_valid(self, prop) + return value + + +# ============================== +# Conversion, validation and normalization helpers +# ============================== + +cdef inline OrderFlag _stride_order2vec(axis_vec_t& stride_order_vec, object stride_order) except? ORDER_NONE: + if stride_order == 'C': + return ORDER_C + elif stride_order == 'F': + return ORDER_F + elif isinstance(stride_order, tuple | list): + _tuple2axis_vec(stride_order_vec, stride_order) + return ORDER_PERM + return ORDER_NONE + + +cdef inline int _tuple2axis_vec(axis_vec_t& vec, object t) except -1: + cdef int ndim = len(t) + vec.resize(ndim) + for i in range(ndim): + vec[i] = t[i] + return 0 + + +cdef inline bint _normalize_axis(integer_t& axis, integer_t extent) except -1 nogil: + if axis < -extent or axis >= extent: + return False + if axis < 0: + axis += extent + return True + + +cdef inline int _validate_itemsize(int itemsize) except -1 nogil: + if itemsize <= 0: + raise ValueError("itemsize must be positive") + if itemsize & (itemsize - 1): + raise ValueError("itemsize must be a power of two") + return 0 + + +cdef inline bint _is_unique(BaseLayout& base, axis_vec_t& stride_order) except -1 nogil: + if base.strides == NULL: + return True + cdef int64_t cur_max_offset = 0 + cdef int i = base.ndim - 1 + cdef int64_t stride + cdef axis_t axis + cdef extent_t extent + while i >= 0: + axis = stride_order[i] + extent = base.shape[axis] + if extent != 1: + stride = _c_abs(base.strides[axis]) + if cur_max_offset >= stride: + return False + cur_max_offset = _overflow_checked_sum(cur_max_offset, _overflow_checked_mul(stride, (extent - 1))) + i -= 1 + return True + + +@cython.overflowcheck(True) +cdef inline int64_t _overflow_checked_mul(int64_t a, int64_t b) except? -1 nogil: + return a * b + + +@cython.overflowcheck(True) +cdef inline int64_t _overflow_checked_diff(int64_t a, int64_t b) except? -1 nogil: + return a - b + + +@cython.overflowcheck(True) +cdef inline int64_t _overflow_checked_sum(int64_t a, int64_t b) except? -1 nogil: + return a + b + + +@cython.overflowcheck(True) +cdef inline int64_t _overflow_checked_div_ceil(int64_t a, int64_t b) except? 
-1 nogil: + return (a + b - 1) // b diff --git a/cuda_core/cuda/core/experimental/_layout.pyx b/cuda_core/cuda/core/experimental/_layout.pyx new file mode 100644 index 0000000000..1df819bb88 --- /dev/null +++ b/cuda_core/cuda/core/experimental/_layout.pyx @@ -0,0 +1,799 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +cimport cython + +from libc.stdint cimport int64_t, intptr_t +from libcpp cimport vector + +from cpython.object cimport PyObject + + +cdef extern from "Python.h": + int _PySlice_Unpack "PySlice_Unpack" (PyObject *slice, Py_ssize_t *start, Py_ssize_t *stop, Py_ssize_t *step) except -1 + Py_ssize_t _PySlice_AdjustIndices "PySlice_AdjustIndices" (Py_ssize_t length, Py_ssize_t *start, Py_ssize_t *stop, Py_ssize_t step) noexcept nogil + + +@cython.final +cdef class StridedLayout: + + def __init__(StridedLayout self, object shape, object strides, int itemsize, bint strides_in_bytes=False): + self.init_from_tuple(shape, strides, itemsize, strides_in_bytes) + + @classmethod + def dense(cls, object shape, int itemsize, object stride_order='C'): + cdef StridedLayout new_layout = StridedLayout.__new__(cls) + new_layout.init_dense_from_tuple(shape, itemsize, stride_order) + return new_layout + + @classmethod + def dense_like(cls, StridedLayout other, object stride_order="K"): + cdef StridedLayout new_layout = StridedLayout.__new__(cls) + cdef OrderFlag order_flag + cdef axis_vec_t stride_order_vec + + if stride_order == "K": + other.get_stride_order(stride_order_vec) + order_flag = ORDER_PERM + else: + order_flag = _stride_order2vec(stride_order_vec, stride_order) + if order_flag == ORDER_NONE: + raise ValueError(f"The stride_order must be 'K', 'C', 'F', or a permutation tuple. 
Got: {stride_order}") + + new_layout.init_dense_from_ptr( + other.base.ndim, + other.base.shape, + other.itemsize, + order_flag, + &stride_order_vec + ) + return new_layout + + def __repr__(StridedLayout self): + if self.slice_offset == 0: + return ( + f"StridedLayout(shape={self.shape}, strides={self.strides}, itemsize={self.itemsize})" + ) + else: + return ( + f"StridedLayout(shape={self.shape}, strides={self.strides}, itemsize={self.itemsize}, _slice_offset={self.slice_offset})" + ) + + def __eq__(StridedLayout self, StridedLayout other): + return self.itemsize == other.itemsize and self.slice_offset == other.slice_offset and _base_layout_equal(self.base, other.base) + + @property + def ndim(StridedLayout self) -> int: + return self.base.ndim + + @property + def shape(StridedLayout self) -> tuple: + return self.get_shape_tuple() + + @property + def strides(StridedLayout self) -> tuple | None: + return self.get_strides_tuple() + + @property + def strides_in_bytes(StridedLayout self) -> tuple | None: + return self.get_strides_in_bytes_tuple() + + @property + def stride_order(StridedLayout self) -> tuple: + return self.get_stride_order_tuple() + + @property + def volume(StridedLayout self) -> int: + return self.get_volume() + + @property + def is_unique(StridedLayout self) -> bool: + return self.get_is_unique() + + @property + def is_contiguous_c(StridedLayout self): + return self.get_is_contiguous_c() + + @property + def is_contiguous_f(StridedLayout self): + return self.get_is_contiguous_f() + + @property + def is_contiguous_any(StridedLayout self): + return self.get_is_contiguous_any() + + @property + def offset_bounds(StridedLayout self): + cdef stride_t min_offset = 0 + cdef stride_t max_offset = 0 + self.get_offset_bounds(min_offset, max_offset) + return min_offset, max_offset + + @property + def required_size_in_bytes(StridedLayout self): + return self.get_required_size_in_bytes() + + @property + def slice_offset_in_bytes(StridedLayout self): + return self.get_slice_offset_in_bytes() + + def flattened_axis_mask(StridedLayout self): + return self.get_flattened_axis_mask() + + def reshaped(StridedLayout self, object shape): + cdef StridedLayout new_layout = StridedLayout.__new__(StridedLayout) + cdef BaseLayout new_shape + init_base_layout(new_shape, len(shape)) + for i in range(len(shape)): + new_shape.shape[i] = shape[i] + self.reshape_into(new_layout, new_shape) + return new_layout + + def permuted(StridedLayout self, object axis_order): + cdef StridedLayout new_layout = StridedLayout.__new__(StridedLayout) + cdef axis_vec_t axis_order_vec + _tuple2axis_vec(axis_order_vec, axis_order) + self.permute_into(new_layout, axis_order_vec) + return new_layout + + def flattened(StridedLayout self, start_axis=0, end_axis=-1, mask=None): + cdef StridedLayout new_layout = StridedLayout.__new__(StridedLayout) + cdef axes_mask_t axis_mask + if mask is None: + axis_mask = axis_mask_from_range(self.ndim, start_axis, end_axis) + else: + axis_mask = mask + self.flatten_into(new_layout, axis_mask) + return new_layout + + def flattened_axis_mask(StridedLayout self): + return self.get_flattened_axis_mask() + + def squeezed(StridedLayout self): + cdef StridedLayout new_layout = StridedLayout.__new__(StridedLayout) + self.squeeze_into(new_layout) + return new_layout + + def unsqueezed(StridedLayout self, object axis): + cdef axis_vec_t axis_vec + if isinstance(axis, int): + axis_vec.push_back(axis) + else: + _tuple2axis_vec(axis_vec, axis) + if axis_vec.size() == 0: + return self + cdef StridedLayout 
new_layout = StridedLayout.__new__(StridedLayout) + self.unsqueeze_into(new_layout, axis_vec) + return new_layout + + def broadcast_to(StridedLayout self, object shape): + cdef StridedLayout new_layout = StridedLayout.__new__(StridedLayout) + cdef BaseLayout new_shape + cdef int new_ndim = len(shape) + init_base_layout(new_shape, new_ndim) + for i in range(new_ndim): + new_shape.shape[i] = shape[i] + self.broadcast_into(new_layout, new_shape) + return new_layout + + def packed(StridedLayout self, int itemsize, intptr_t data_ptr=0, int axis=-1, bint keep_dim=True): + if itemsize == self.itemsize: + return self + cdef StridedLayout new_layout = StridedLayout.__new__(StridedLayout) + self.pack_into(new_layout, itemsize, data_ptr, keep_dim, axis) + return new_layout + + def unpacked(StridedLayout self, int itemsize, int axis=-1): + if itemsize == self.itemsize: + return self + cdef StridedLayout new_layout = StridedLayout.__new__(StridedLayout) + self.unpack_into(new_layout, itemsize, axis) + return new_layout + + def max_compatible_itemsize(StridedLayout self, int max_itemsize=16, intptr_t data_ptr=0, int axis=-1): + return self.get_max_compatible_itemsize(max_itemsize, data_ptr, axis) + + def sliced(StridedLayout self, object slices): + if not isinstance(slices, tuple): + slices = (slices,) + cdef StridedLayout new_layout = StridedLayout.__new__(StridedLayout) + self.slice_into(new_layout, slices) + return new_layout + + def __getitem__(StridedLayout self, object slices): + return self.sliced(slices) + + cdef axes_mask_t get_flattened_axis_mask(StridedLayout self) except? -1 nogil: + return flattened_strides_in_c_index_order_mask(self.base) + + cdef int get_max_compatible_itemsize(StridedLayout self, int max_itemsize, intptr_t data_ptr, int axis=-1) except -1 nogil: + return max_compatible_itemsize(self.base, self.slice_offset, self.itemsize, max_itemsize, data_ptr, axis) + + cdef int reshape_into(StridedLayout self, StridedLayout out_layout, BaseLayout& new_shape) except -1 nogil: + cdef int64_t old_volume = self.get_volume() + validate_reshaped_shape(new_shape, old_volume) + + cdef int ndim = new_shape.ndim + _zero_strides(new_shape) + + cdef BaseLayout flattened + if old_volume != 0: + flatten_strides_in_c_index_order(flattened, self.base, AXIS_MASK_ALL) + if not split_strides_in_c_index_order(new_shape, flattened): + raise ValueError("Layout strides are incompatible with the new shape") + + # Reset all memoized properties + out_layout._prop_mask = 0 + + # Copy preserved attributes + out_layout.slice_offset = self.slice_offset + out_layout.itemsize = self.itemsize + maybe_copy_volume(out_layout, self) + + # Set new attributes + _swap_layout(out_layout.base, new_shape) + return 0 + + cdef int permute_into(StridedLayout self, StridedLayout out_layout, axis_vec_t& axis_order) except -1 nogil: + if axis_order.size() != self.base.ndim: + raise ValueError(f"Permutation must have the same length as the number of dimensions, got {axis_order.size()} for {self.ndim}D tensor.") + + cdef BaseLayout permuted + permute_extents(permuted, self.base, axis_order) + + # Reset all memoized properties + out_layout._prop_mask = 0 + + # Preserved attributes + out_layout.itemsize = self.itemsize + out_layout.slice_offset = self.slice_offset + maybe_copy_volume(out_layout, self) + + # Set new attributes + _swap_layout(out_layout.base, permuted) + return 0 + + cdef int flatten_into(StridedLayout self, StridedLayout out_layout, axes_mask_t axis_mask=AXIS_MASK_ALL) except -1 nogil: + cdef BaseLayout flattened + cdef 
int ndim = flatten_strides_in_c_index_order(flattened, self.base, axis_mask) + + if out_layout is self and ndim == self.base.ndim: + return 0 + + # Reset all memoized properties + out_layout._prop_mask = 0 + + # Preserved attributes + out_layout.itemsize = self.itemsize + out_layout.slice_offset = self.slice_offset + maybe_copy_volume(out_layout, self) + + # Set new attributes + _swap_layout(out_layout.base, flattened) + return 0 + + cdef int squeeze_into(StridedLayout self, StridedLayout out_layout) except -1 nogil: + cdef BaseLayout squeezed + squeeze_extents(squeezed, self.base) + + if out_layout is self and squeezed.ndim == self.base.ndim: + return 0 + + # Reset all memoized properties + out_layout._prop_mask = 0 + + # Preserved attributes + out_layout.itemsize = self.itemsize + out_layout.slice_offset = self.slice_offset + maybe_copy_volume(out_layout, self) + + # Set new attributes + _swap_layout(out_layout.base, squeezed) + return 0 + + cdef int unsqueeze_into(StridedLayout self, StridedLayout out_layout, axis_vec_t& axis_vec) except -1 nogil: + if axis_vec.size() == 0 and self is out_layout: + return 0 + + cdef BaseLayout unsqueezed + unsqueeze_extents(unsqueezed, self.base, axis_vec) + + # Reset all memoized properties + out_layout._prop_mask = 0 + + # Preserved attributes + out_layout.itemsize = self.itemsize + out_layout.slice_offset = self.slice_offset + maybe_copy_volume(out_layout, self) + + # Set new attributes + _swap_layout(out_layout.base, unsqueezed) + return 0 + + cdef int broadcast_into(StridedLayout self, StridedLayout out_layout, BaseLayout& broadcast) except -1 nogil: + _validate_shape(broadcast) + broadcast_extents(broadcast, self.base) + + # Reset all memoized properties + out_layout._prop_mask = 0 + + # Preserved attributes + out_layout.itemsize = self.itemsize + out_layout.slice_offset = self.slice_offset + + # Set new attributes + _swap_layout(out_layout.base, broadcast) + return 0 + + cdef int pack_into(StridedLayout self, StridedLayout out_layout, int itemsize, intptr_t data_ptr, bint keep_dim, int axis=-1) except -1 nogil: + + cdef BaseLayout packed + cdef stride_t new_slice_offset = 0 + cdef int vec_size = pack_extents( + packed, + new_slice_offset, + self.base, + self.slice_offset, + self.itemsize, + itemsize, + data_ptr, + keep_dim, + axis + ) + + if vec_size == 1 and out_layout is self: + return 0 + + # Reset all memoized properties + out_layout._prop_mask = 0 + + # Set new attributes + out_layout.itemsize = itemsize + out_layout.slice_offset = new_slice_offset + _swap_layout(out_layout.base, packed) + return vec_size + + cdef int unpack_into(StridedLayout self, StridedLayout out_layout, int itemsize, int axis=-1) except -1 nogil: + cdef BaseLayout unpacked + cdef int vec_size = unpack_extents( + unpacked, + self.base, + self.itemsize, + itemsize, + axis + ) + if vec_size == 1 and out_layout is self: + return 0 + + cdef int64_t new_slice_offset = _overflow_checked_mul(self.slice_offset, vec_size) + + # Reset all memoized properties + out_layout._prop_mask = 0 + + # Set new attributes + out_layout.itemsize = itemsize + out_layout.slice_offset = new_slice_offset + _swap_layout(out_layout.base, unpacked) + return vec_size + + cdef int slice_into(StridedLayout self, StridedLayout out_layout, tuple slices) except -1: + cdef BaseLayout sliced + cdef stride_t slice_offset = slice_extents(sliced, self.base, slices) + cdef int64_t new_slice_offset = _overflow_checked_sum(self.slice_offset, slice_offset) + + # Reset all memoized properties + out_layout._prop_mask 
= 0 + + # Preserved attributes + out_layout.itemsize = self.itemsize + + # Set new attributes + _swap_layout(out_layout.base, sliced) + out_layout.slice_offset = new_slice_offset + return 0 + +cdef inline int maybe_copy_volume(StridedLayout out_layout, StridedLayout in_layout) except -1 nogil: + if _has_valid_property(in_layout, PROP_VOLUME): + out_layout._volume = in_layout.get_volume() + _mark_property_valid(out_layout, PROP_VOLUME) + return 0 + + +cdef inline int validate_reshaped_shape(BaseLayout& new_shape, int64_t old_volume) except -1 nogil: + cdef int ndim = new_shape.ndim + cdef int axis = -1 + cdef extent_t extent + for i in range(ndim): + extent = new_shape.shape[i] + if extent < -1: + raise ValueError("Extents must be non-negative") + elif extent == -1: + if axis == -1: + axis = i + else: + raise ValueError("There can be at most one -1 extent in a shape") + cdef int64_t new_volume = _c_abs(_volume(new_shape)) + if new_volume == 0 and axis != -1: + raise ValueError("The -1 extent is ambiguous when the volume is 0") + if new_volume != old_volume: + if axis == -1: + raise ValueError(f"The original volume {old_volume} and the new volume {new_volume} must be equal.") + extent = old_volume // new_volume + if extent * new_volume != old_volume: + raise ValueError(f"The original volume {old_volume} must be divisible by the specified sub-volume {new_volume}.") + new_shape.shape[axis] = extent + return 0 + + +cdef inline axes_mask_t axis_mask_from_range(int ndim, int start_axis, int end_axis) except? -1 nogil: + if ndim == 0 and start_axis == 0 and end_axis == -1: + return AXIS_MASK_ALL + cdef axes_mask_t axis_mask = AXIS_MASK_ALL + if not _normalize_axis(start_axis, ndim): + raise ValueError(f"Invalid start axis: {start_axis} out of range for {ndim}D tensor") + if not _normalize_axis(end_axis, ndim): + raise ValueError(f"Invalid end axis: {end_axis} out of range for {ndim}D tensor") + if start_axis > 0: + axis_mask &= (AXIS_MASK_ALL << start_axis + 1) + if end_axis < ndim: + axis_mask &= (AXIS_MASK_ALL >> (STRIDED_LAYOUT_MAX_NDIM - end_axis - 1)) + return axis_mask + + +cdef inline int flatten_strides_in_c_index_order(BaseLayout& out_layout, BaseLayout& in_layout, axes_mask_t axis_mask) except -1 nogil: + cdef int ndim = in_layout.ndim + init_base_layout(out_layout, ndim) + cdef int group_start = 0 + cdef int group_end = 0 + cdef int64_t group_vol + cdef int64_t group_stride + cdef int out_i = 0 + cdef extent_t* in_shape = in_layout.shape + cdef stride_t* in_strides = get_strides_ptr(in_layout) + while group_start < ndim: + group_vol = in_shape[group_start] + group_stride = in_strides[group_start] + group_end = group_start + 1 + while ( + group_end < ndim + and (axis_mask & (1 << group_end)) + and group_stride == _overflow_checked_mul(in_strides[group_end], in_shape[group_end]) + ): + group_vol = _overflow_checked_mul(group_vol, in_shape[group_end]) + group_stride = in_strides[group_end] + group_end += 1 + out_layout.shape[out_i] = group_vol + out_layout.strides[out_i] = group_stride + out_i += 1 + group_start = group_end + if out_i != ndim: + trim_base_layout(out_layout, out_i) + return out_i + + +cdef inline axes_mask_t flattened_strides_in_c_index_order_mask(BaseLayout& layout) except? 
-1 nogil: + if layout.strides == NULL: + return AXIS_MASK_ALL + cdef axes_mask_t axis_mask = 0 + cdef int ndim = layout.ndim + cdef int group_start = 0 + cdef int group_end = 0 + cdef int64_t group_vol + cdef int64_t group_stride + while group_start < ndim: + group_vol = layout.shape[group_start] + group_stride = layout.strides[group_start] + group_end = group_start + 1 + while group_end < ndim and group_stride == layout.strides[group_end] * layout.shape[group_end]: + group_vol = _overflow_checked_mul(group_vol, layout.shape[group_end]) + group_stride = layout.strides[group_end] + axis_mask |= (1 << group_end) + group_end += 1 + group_start = group_end + return axis_mask + + +cdef inline bint split_strides_in_c_index_order(BaseLayout& out_layout, BaseLayout& in_layout) except -1 nogil: + cdef int i = in_layout.ndim - 1 + cdef int new_i = out_layout.ndim - 1 + cdef extent_t extent + cdef extent_t new_extent + cdef extent_t group_vol + cdef stride_t group_stride + cdef extent_t* in_shape = in_layout.shape + cdef stride_t* in_strides = get_strides_ptr(in_layout) + if out_layout.strides == NULL: + _zero_strides(out_layout) + while i >= 0: + extent = in_shape[i] + group_vol = 1 + group_stride = in_strides[i] + while new_i >= 0 and group_vol < extent: + new_extent = out_layout.shape[new_i] + if new_extent == 0: + return False + group_vol = _overflow_checked_mul(group_vol, new_extent) + out_layout.strides[new_i] = group_stride + group_stride = _overflow_checked_mul(group_stride, new_extent) + new_i -= 1 + if group_vol != extent: + return False + i -= 1 + return True + + +cdef inline int permute_extents(BaseLayout& out_layout, BaseLayout& in_layout, axis_vec_t& axis_order) except -1 nogil: + cdef int ndim = in_layout.ndim + init_base_layout(out_layout, ndim) + cdef axis_t axis + cdef axes_mask_t axis_mask + cdef axes_mask_t axis_order_mask = 0 + cdef extent_t* in_shape = in_layout.shape + cdef stride_t* in_strides = get_strides_ptr(in_layout) + + for i in range(ndim): + axis = axis_order[i] + if not _normalize_axis(axis, ndim): + raise ValueError(f"Invalid permutation: axis {axis} out of range for {ndim}D tensor") + axis_mask = 1 << axis + if axis_order_mask & axis_mask: + raise ValueError(f"Invalid permutation: axis {axis_order[i]} appears multiple times.") + axis_order_mask |= axis_mask + out_layout.shape[i] = in_shape[axis] + out_layout.strides[i] = in_strides[axis] + return 0 + + +cdef inline stride_t slice_extents(BaseLayout& out_layout, BaseLayout& in_layout, tuple slices) except? 
-1: + cdef int ndim = in_layout.ndim + cdef int num_slices = len(slices) + if num_slices > ndim: + raise ValueError(f"The number of slices ({num_slices}) is greater than the number of dimensions ({ndim}).") + init_base_layout(out_layout, ndim) + cdef extent_t* in_shape = in_layout.shape + cdef stride_t* in_strides = get_strides_ptr(in_layout) + cdef stride_t slice_offset = 0 + cdef Py_ssize_t start + cdef Py_ssize_t stop + cdef Py_ssize_t step + cdef extent_t new_extent + cdef object py_slice + cdef bint zero_slice = False + cdef int out_i = 0 + for i in range(num_slices): + py_slice = slices[i] + if isinstance(py_slice, int): + start = py_slice + if not _normalize_axis(start, in_shape[i]): + raise ValueError(f"Invalid index: {start} out of range for axis {i} with extent {in_shape[i]}") + # single element index removes extent from the shape, + # just increase the offset and skip the shape and stride + slice_offset = _overflow_checked_sum(slice_offset, _overflow_checked_mul(start, in_strides[i])) + elif isinstance(py_slice, slice): + _PySlice_Unpack(py_slice, &start, &stop, &step) + new_extent = _PySlice_AdjustIndices(in_shape[i], &start, &stop, step) + if new_extent > 0: + # out_extent > 0 implies start is in [0, extent - 1] range + slice_offset = _overflow_checked_sum(slice_offset, _overflow_checked_mul(start, in_strides[i])) + else: + zero_slice = True + out_layout.shape[out_i] = new_extent + out_layout.strides[out_i] = _overflow_checked_mul(in_strides[i], step) + out_i += 1 + else: + raise TypeError(f"Invalid slice: {py_slice}. Expected slice instance or integer.") + for i in range(num_slices, ndim): + out_layout.shape[out_i] = in_shape[i] + out_layout.strides[out_i] = in_strides[i] + out_i += 1 + if out_i != ndim: + trim_base_layout(out_layout, out_i) + if zero_slice: + _zero_strides(out_layout) + return slice_offset + + +cdef inline int squeeze_extents(BaseLayout& out_layout, BaseLayout& in_layout) except -1 nogil: + cdef int ndim = in_layout.ndim + init_base_layout(out_layout, ndim) + cdef extent_t* in_shape = in_layout.shape + cdef stride_t* in_strides = get_strides_ptr(in_layout) + cdef int out_i = 0 + cdef extent_t extent + for i in range(ndim): + extent = in_shape[i] + if extent == 0: + trim_base_layout(out_layout, 1) + out_layout.shape[0] = 0 + out_layout.strides[0] = 0 + return 1 + elif extent != 1: + out_layout.shape[out_i] = extent + out_layout.strides[out_i] = in_strides[i] + out_i += 1 + if out_i != ndim: + trim_base_layout(out_layout, out_i) + return out_i + + +cdef inline int unsqueeze_extents(BaseLayout& out_layout, BaseLayout& in_layout, axis_vec_t& axis_vec) except -1 nogil: + cdef int ndim = in_layout.ndim + cdef int num_new_axes = axis_vec.size() + cdef int out_ndim = ndim + num_new_axes + # init_base_layout validates out_ndim + init_base_layout(out_layout, out_ndim) + cdef extent_t* in_shape = in_layout.shape + cdef stride_t* in_strides = get_strides_ptr(in_layout) + cdef axes_mask_t out_shape_mask = 0 + cdef axes_mask_t axis_mask = 0 + cdef axis_t axis + for i in range(num_new_axes): + axis = axis_vec[i] + if not _normalize_axis(axis, out_ndim): + raise ValueError(f"Invalid axis: {axis} out of range for {out_ndim}D tensor") + axis_mask = 1 << axis + if out_shape_mask & axis_mask: + raise ValueError(f"Axis {axis} appears multiple times.") + out_shape_mask |= axis_mask + cdef int in_i = 0 + for i in range(out_ndim): + # without the cast, cython has issues with + # recognizing 1 << i does not require Python interaction + axis_mask = 1 << i + if out_shape_mask & 
axis_mask: + out_layout.shape[i] = 1 + if in_i < ndim: + out_layout.strides[i] = _overflow_checked_mul(in_shape[in_i], in_strides[in_i]) + else: + if ndim > 0: + out_layout.strides[i] = in_strides[ndim - 1] + else: + out_layout.strides[i] = 1 + else: + out_layout.shape[i] = in_shape[in_i] + out_layout.strides[i] = in_strides[in_i] + in_i += 1 + assert in_i == ndim + return 0 + + +cdef inline int broadcast_extents(BaseLayout& broadcast, BaseLayout& in_layout) except -1 nogil: + if broadcast.ndim < in_layout.ndim: + raise ValueError( + f"The broadcast shape ndim ({broadcast.ndim}) must be " + f"greater than or equal to the input shape " + f"ndim ({in_layout.ndim})." + ) + cdef int ndim_diff = broadcast.ndim - in_layout.ndim + _zero_strides(broadcast) + cdef extent_t* in_shape = in_layout.shape + cdef stride_t* in_strides = get_strides_ptr(in_layout) + cdef extent_t* broadcast_shape = broadcast.shape + ndim_diff + cdef stride_t* broadcast_strides = broadcast.strides + ndim_diff + for i in range(in_layout.ndim): + if in_shape[i] == broadcast_shape[i]: + broadcast_strides[i] = in_strides[i] + elif in_shape[i] != 1: + raise ValueError( + f"Shapes cannot be broadcast together: " + f"the original extent must be 1 or be equal to broadcast extent, " + f"got {in_shape[i]} and {broadcast_shape[i]} for axis {i}." + ) + # else -> in_extent == 1, the broadcast extent and zero stride are already set + return 0 + + +cdef inline int64_t gcd(int64_t a, int64_t b) except? -1 nogil: + while b != 0: + a, b = b, a % b + return a + + +cdef inline int pack_extents(BaseLayout& out_layout, stride_t& out_slice_offset, BaseLayout& in_layout, stride_t slice_offset, int itemsize, int new_itemsize, intptr_t data_ptr, bint keep_dim, int axis) except -1 nogil: + cdef int ndim = in_layout.ndim + if new_itemsize <= 0 or new_itemsize & (new_itemsize - 1): + raise ValueError(f"new itemsize must be a power of two, got {new_itemsize}.") + if itemsize <= 0 or itemsize & (itemsize - 1): + raise ValueError(f"itemsize must be a power of two, got {itemsize}.") + if new_itemsize <= itemsize: + if new_itemsize == itemsize: + return 1 + raise ValueError(f"new itemsize ({new_itemsize}) must be greater than or equal to itemsize ({itemsize}).") + if not _normalize_axis(axis, ndim): + raise ValueError(f"Invalid axis: {axis} out of range for {ndim}D tensor") + if data_ptr % new_itemsize != 0: + raise ValueError(f"The data pointer ({data_ptr}) must be aligned to the packed itemsize ({new_itemsize}).") + + cdef extent_t* shape = in_layout.shape + cdef stride_t* strides = get_strides_ptr(in_layout) + if strides[axis] != 1: + raise ValueError(f"The axis {axis} stride must be 1, got {strides[axis]}.") + + cdef int vec_size = new_itemsize // itemsize + cdef extent_t packed_extent = shape[axis] + if packed_extent == 0: + raise ValueError(f"The axis {axis} extent must be non-zero, got {shape[axis]}.") + packed_extent //= vec_size + if packed_extent * vec_size != shape[axis]: + raise ValueError(f"The axis {axis} extent ({shape[axis]}) must be divisible by {vec_size}.") + + cdef stride_t new_slice_offset = slice_offset // vec_size + if new_slice_offset * vec_size != slice_offset: + raise ValueError(f"The slice offset ({slice_offset}) must be divisible by {vec_size}.") + out_slice_offset = new_slice_offset + + init_base_layout(out_layout, ndim) + cdef stride_t packed_stride + cdef int out_i = 0 + for i in range(ndim): + if i == axis: + if keep_dim or packed_extent != 1: # omit the packed axis if it is reduced to 1 + out_layout.shape[out_i] = 
packed_extent + out_layout.strides[out_i] = 1 + out_i += 1 + else: + packed_stride = strides[i] // vec_size + if packed_stride * vec_size != strides[i]: + raise ValueError(f"The {i} axis stride ({strides[i]}) must be divisible by {vec_size}.") + out_layout.shape[out_i] = shape[i] + out_layout.strides[out_i] = packed_stride + out_i += 1 + if out_i != ndim: + trim_base_layout(out_layout, out_i) + return vec_size + + +cdef inline int unpack_extents(BaseLayout &out_layout, BaseLayout &in_layout, int itemsize, int new_itemsize, int axis) except -1 nogil: + cdef int ndim = in_layout.ndim + if not _normalize_axis(axis, ndim): + raise ValueError(f"Invalid axis: {axis} out of range for {ndim}D tensor") + if new_itemsize <= 0 or new_itemsize & (new_itemsize - 1): + raise ValueError(f"new itemsize must be a power of two, got {new_itemsize}.") + if itemsize <= 0 or itemsize & (itemsize - 1): + raise ValueError(f"itemsize must be a power of two, got {itemsize}.") + if new_itemsize >= itemsize: + if new_itemsize == itemsize: + return 1 + raise ValueError(f"new itemsize ({new_itemsize}) must be less than or equal to itemsize ({itemsize}).") + + cdef extent_t* shape = in_layout.shape + cdef stride_t* strides = get_strides_ptr(in_layout) + if shape[axis] == 0: + raise ValueError(f"The axis {axis} extent must be non-zero, got {shape[axis]}.") + if strides[axis] != 1: + raise ValueError(f"The axis {axis} stride must be 1, got {strides[axis]}.") + + cdef int vec_size = itemsize // new_itemsize + init_base_layout(out_layout, ndim) + out_layout.shape[axis] = _overflow_checked_mul(shape[axis], vec_size) + out_layout.strides[axis] = 1 + + for i in range(ndim): + if i == axis: + continue + out_layout.shape[i] = shape[i] + out_layout.strides[i] = _overflow_checked_mul(strides[i], vec_size) + return vec_size + + +cdef inline int max_compatible_itemsize(BaseLayout& layout, stride_t slice_offset, int itemsize, int max_itemsize, intptr_t data_ptr, int axis) except? 
-1 nogil: + cdef int ndim = layout.ndim + if max_itemsize <= 0 or max_itemsize & (max_itemsize - 1): + raise ValueError(f"max_itemsize must be a power of two, got {max_itemsize}.") + if itemsize <= 0 or itemsize & (itemsize - 1): + raise ValueError(f"itemsize must be a power of two, got {itemsize}.") + if not _normalize_axis(axis, ndim): + raise ValueError(f"Invalid axis: {axis} out of range for {ndim}D tensor") + max_itemsize = gcd(max_itemsize, _c_abs(data_ptr)) + cdef extent_t* shape = layout.shape + cdef stride_t* strides = get_strides_ptr(layout) + if ndim < 1 or strides[axis] != 1 or shape[axis] == 0: + return min(max_itemsize, itemsize) + max_itemsize = gcd(max_itemsize, _overflow_checked_mul(slice_offset, itemsize)) + max_itemsize = gcd(max_itemsize, _overflow_checked_mul(shape[axis], itemsize)) + for i in range(ndim): + if i == axis: + continue + max_itemsize = gcd(max_itemsize, _overflow_checked_mul(_c_abs(strides[i]), itemsize)) + return max_itemsize diff --git a/cuda_core/cuda/core/experimental/_memoryview.pyx b/cuda_core/cuda/core/experimental/_memoryview.pyx index 40d70ad995..4992fd1b5b 100644 --- a/cuda_core/cuda/core/experimental/_memoryview.pyx +++ b/cuda_core/cuda/core/experimental/_memoryview.pyx @@ -4,14 +4,15 @@ from ._dlpack cimport * +cimport cython import functools from typing import Optional import numpy from cuda.core.experimental._utils.cuda_utils import handle_return, driver -from cuda.core.experimental._utils cimport cuda_utils +from cuda.core.experimental._layout cimport StridedLayout # TODO(leofang): support NumPy structured dtypes @@ -88,9 +89,7 @@ cdef class StridedMemoryView: cdef DLTensor *dl_tensor # Memoized properties - cdef tuple _shape - cdef tuple _strides - cdef bint _strides_init # Has the strides tuple been init'ed? 
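+    # The shape/strides bookkeeping is now delegated to a lazily created
+    # StridedLayout (see get_layout() below). A hedged usage sketch,
+    # assuming ``view`` is a StridedMemoryView over a 2x3 float32 tensor:
+    #
+    #     >>> view.layout.shape
+    #     (2, 3)
+    #     >>> view.layout.itemsize      # bytes per element for float32
+    #     4
+    #     >>> view.shape == view.layout.shape
+    #     True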
+ cdef StridedLayout _layout cdef object _dtype def __init__(self, obj=None, stream_ptr=None): @@ -120,41 +119,17 @@ cdef class StridedMemoryView: dlm_tensor = data dlm_tensor.deleter(dlm_tensor) + @property + def layout(self) -> StridedLayout: + return self.get_layout() + @property def shape(self) -> tuple[int]: - if self._shape is None: - if self.exporting_obj is not None: - if self.dl_tensor != NULL: - self._shape = cuda_utils.carray_int64_t_to_tuple( - self.dl_tensor.shape, - self.dl_tensor.ndim - ) - else: - self._shape = self.metadata["shape"] - else: - self._shape = () - return self._shape + return self.get_layout().get_shape_tuple() @property def strides(self) -> Optional[tuple[int]]: - cdef int itemsize - if self._strides_init is False: - if self.exporting_obj is not None: - if self.dl_tensor != NULL: - if self.dl_tensor.strides: - self._strides = cuda_utils.carray_int64_t_to_tuple( - self.dl_tensor.strides, - self.dl_tensor.ndim - ) - else: - # This is a Python interface anyway, so not much point - # to using the optimization in cuda_utils.carray_int64_t_to_tuple - strides = self.metadata.get("strides") - if strides is not None: - itemsize = self.dtype.itemsize - self._strides = tuple(x // itemsize for x in strides) - self._strides_init = True - return self._strides + return self.get_layout().get_strides_tuple() @property def dtype(self) -> Optional[numpy.dtype]: @@ -164,7 +139,7 @@ cdef class StridedMemoryView: self._dtype = dtype_dlpack_to_numpy(&self.dl_tensor.dtype) else: # TODO: this only works for built-in numeric types - self._dtype = numpy.dtype(self.metadata["typestr"]) + self._dtype = typestr2dtype(self.metadata["typestr"]) return self._dtype def __repr__(self): @@ -177,6 +152,16 @@ cdef class StridedMemoryView: + f" readonly={self.readonly},\n" + f" exporting_obj={get_simple_repr(self.exporting_obj)})") + cdef inline StridedLayout get_layout(self): + if self._layout is None: + if self.dl_tensor: + self._layout = layout_from_dlpack(self.dl_tensor) + elif self.metadata is not None: + self._layout = layout_from_cai(self.metadata) + else: + self._layout = StridedLayout.__new__(StridedLayout) + return self._layout + cdef str get_simple_repr(obj): # TODO: better handling in np.dtype objects @@ -433,3 +418,43 @@ def args_viewable_as_strided_memory(tuple arg_indices): return func(*args, **kwargs) return wrapped_func return wrapped_func_with_indices + + +cdef inline StridedLayout layout_from_dlpack(DLTensor* dl_tensor): + cdef StridedLayout layout = StridedLayout.__new__(StridedLayout) + cdef int nbits = dl_tensor.dtype.bits + cdef int itemsize = nbits >> 3 + if (itemsize << 3) != nbits: + raise ValueError("dl_tensor.dtype.bits must be a multiple of 8") + layout.init_from_ptr(dl_tensor.ndim, dl_tensor.shape, dl_tensor.strides, itemsize) + return layout + + +cdef StridedLayout layout_from_cai(object metadata): + cdef StridedLayout layout = StridedLayout.__new__(StridedLayout) + cdef object shape = metadata["shape"] + cdef object strides = metadata.get("strides") + cdef int itemsize = typestr2itemsize(metadata["typestr"]) + layout.init_from_tuple(shape, strides, itemsize, strides is not None) + return layout + + +_typestr2dtype_cache = {} +_typestr2itemsize_cache = {} + +cdef object typestr2dtype(str typestr): + global _typestr2dtype_cache + cdef object dtype = _typestr2dtype_cache.get(typestr) + if dtype is None: + dtype = numpy.dtype(typestr) + _typestr2dtype_cache[typestr] = dtype + return dtype + + +cdef inline int typestr2itemsize(str typestr): + global 
_typestr2itemsize_cache + cdef object itemsize = _typestr2itemsize_cache.get(typestr) + if itemsize is None: + itemsize = typestr2dtype(typestr).itemsize + _typestr2itemsize_cache[typestr] = itemsize + return itemsize diff --git a/cuda_core/cuda/core/experimental/_utils/cuda_utils.pxd b/cuda_core/cuda/core/experimental/_utils/cuda_utils.pxd index 0e75202498..8af02fd92f 100644 --- a/cuda_core/cuda/core/experimental/_utils/cuda_utils.pxd +++ b/cuda_core/cuda/core/experimental/_utils/cuda_utils.pxd @@ -13,6 +13,11 @@ ctypedef fused supported_error_type: cydriver.CUresult +ctypedef fused integer_t: + int64_t + int + + # mimic CU_DEVICE_INVALID cdef const cydriver.CUcontext CU_CONTEXT_INVALID = (-2) @@ -41,7 +46,7 @@ cdef extern from "Python.h": void _PyTuple_SET_ITEM "PyTuple_SET_ITEM" (object p, Py_ssize_t pos, PyObject *o) -cdef inline tuple carray_int64_t_to_tuple(int64_t *ptr, int length): +cdef inline tuple carray_integer_t_to_tuple(integer_t *ptr, int length): # Construct shape and strides tuples using the Python/C API for speed cdef tuple result = cpython.PyTuple_New(length) for i in range(length): diff --git a/cuda_core/cuda/core/experimental/include/layout.hpp b/cuda_core/cuda/core/experimental/include/layout.hpp new file mode 100644 index 0000000000..b84f74d8a2 --- /dev/null +++ b/cuda_core/cuda/core/experimental/include/layout.hpp @@ -0,0 +1,50 @@ +#ifndef CUDA_CORE_LAYOUT_HPP +#define CUDA_CORE_LAYOUT_HPP + +#include +#include +#include +#include + + +#define STRIDED_LAYOUT_MAX_NDIM 32 +#define AXIS_MASK_ALL 0xFFFFFFFE + +inline int64_t _c_abs(int64_t x) +{ + return std::abs(x); +} + +template +void _swap(T &a, T &b) noexcept +{ + std::swap(a, b); +} + +inline void _order_from_strides(std::vector& indices, const int64_t* shape, const int64_t* strides, int ndim) +{ + indices.resize(ndim); + std::iota(indices.begin(), indices.end(), 0); + if (!strides) { + return; + } + std::sort(indices.begin(), indices.end(), + [&strides, &shape](int i, int j) + { + int64_t stride_i = _c_abs(strides[i]); + int64_t stride_j = _c_abs(strides[j]); + if (stride_i != stride_j) + { + return stride_i > stride_j; + } + int64_t shape_i = shape[i]; + int64_t shape_j = shape[j]; + if (shape_i != shape_j) + { + return shape_i > shape_j; + } + return i < j; + }); +} + +#endif // CUDA_CORE_LAYOUT_HPP \ No newline at end of file From 23b35fed95ce708f216de7e48dc2a1e24bc0483a Mon Sep 17 00:00:00 2001 From: Kamil Tokarski Date: Mon, 17 Nov 2025 20:14:00 +0100 Subject: [PATCH 02/20] Support wrapping ptr in Buffer, create SMV from buffer and layout, dlpack export Signed-off-by: Kamil Tokarski --- cuda_core/cuda/core/experimental/_layout.pxd | 150 ++++++++------ cuda_core/cuda/core/experimental/_layout.pyx | 151 ++++++++++++-- .../core/experimental/_memory/_buffer.pxd | 9 + .../core/experimental/_memory/_buffer.pyx | 89 +++++++- .../cuda/core/experimental/_memoryview.pyx | 190 +++++++++++++++--- 5 files changed, 470 insertions(+), 119 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_layout.pxd b/cuda_core/cuda/core/experimental/_layout.pxd index 2576349cc3..5fd5c783b6 100644 --- a/cuda_core/cuda/core/experimental/_layout.pxd +++ b/cuda_core/cuda/core/experimental/_layout.pxd @@ -49,13 +49,14 @@ cdef enum Property: PROP_IS_UNIQUE = 1 << 0 PROP_IS_CONTIGUOUS_C = 1 << 1 PROP_IS_CONTIGUOUS_F = 1 << 2 - PROP_IS_CONTIGUOUS_ANY = 1 << 3 - PROP_REQUIRED_SIZE_IN_BYTES = 1 << 4 - PROP_SHAPE = 1 << 5 - PROP_STRIDES = 1 << 6 - PROP_STRIDES_IN_BYTES = 1 << 7 - PROP_STRIDE_ORDER = 1 << 8 - PROP_VOLUME = 1 << 9 + 
PROP_IS_DENSE = 1 << 3 + PROP_OFFSET_BOUNDS = 1 << 4 + PROP_REQUIRED_SIZE_IN_BYTES = 1 << 5 + PROP_SHAPE = 1 << 6 + PROP_STRIDES = 1 << 7 + PROP_STRIDES_IN_BYTES = 1 << 8 + PROP_STRIDE_ORDER = 1 << 9 + PROP_VOLUME = 1 << 10 cdef struct BaseLayout: @@ -80,20 +81,22 @@ cdef class StridedLayout: # Definition cdef: BaseLayout base - + readonly: int itemsize stride_t slice_offset # Lazy properties computed from the defining values. cdef: - # Set to 0 to invalidate all properties, + # Set to 0 to invalidate all properties, # whenever a defining value is changed property_mask_t _prop_mask # C and Python properties property_mask_t _boolean_props int64_t _required_size_in_bytes + stride_t _min_offset + stride_t _max_offset int64_t _volume # Python properties @@ -101,7 +104,7 @@ cdef class StridedLayout: tuple _py_strides tuple _py_strides_in_bytes tuple _py_stride_order - + # ============================== # Initialization # ============================== @@ -111,12 +114,12 @@ cdef class StridedLayout: if base.strides != NULL and strides_in_bytes: _divide_strides(base, itemsize) - + self.itemsize = itemsize self.slice_offset = 0 _swap_layout(self.base, base) return 0 - + cdef inline stride_t _init_dense(StridedLayout self, BaseLayout& base, int itemsize, OrderFlag order_flag, axis_vec_t* stride_order=NULL) except -1 nogil: _validate_itemsize(itemsize) @@ -138,26 +141,26 @@ cdef class StridedLayout: self._volume = volume _mark_property_valid(self, PROP_VOLUME) return 0 - + cdef inline int init_from_ptr(StridedLayout self, int ndim, extent_t* shape, stride_t* strides, int itemsize, bint strides_in_bytes=False) except -1 nogil: cdef BaseLayout base _init_base_layout_from_ptr(base, ndim, shape, strides) return self._init(base, itemsize, strides_in_bytes) - + cdef inline int init_dense_from_ptr(StridedLayout self, int ndim, extent_t* shape, int itemsize, OrderFlag order_flag, axis_vec_t* stride_order=NULL) except -1 nogil: cdef BaseLayout base _init_base_layout_from_ptr(base, ndim, shape, NULL) return self._init_dense(base, itemsize, order_flag, stride_order) - + cdef inline int init_from_tuple(StridedLayout self, tuple shape, tuple strides, int itemsize, bint strides_in_bytes=False) except -1: cdef BaseLayout base _init_base_layout_from_tuple(base, shape, strides) return self._init(base, itemsize, strides_in_bytes) - + cdef inline int init_dense_from_tuple(StridedLayout self, tuple shape, int itemsize, object stride_order) except -1: cdef axis_vec_t stride_order_vec cdef OrderFlag order_flag = _stride_order2vec(stride_order_vec, stride_order) - + if order_flag == ORDER_NONE: raise ValueError(f"The stride_order must be 'C', 'F', or a permutation tuple. 
Got: {stride_order}") @@ -183,14 +186,14 @@ cdef class StridedLayout: self._py_strides = cuda_utils.carray_integer_t_to_tuple(self.base.strides, self.base.ndim) _mark_property_valid(self, PROP_STRIDES) return self._py_strides - + cdef inline int get_strides_in_bytes(StridedLayout self, extents_strides_t& strides) except -1 nogil: if self.base.strides != NULL: strides.resize(self.base.ndim) for i in range(self.base.ndim): strides[i] = _overflow_checked_mul(self.base.strides[i], self.itemsize) return 0 - + cdef inline tuple get_strides_in_bytes_tuple(StridedLayout self): if _has_valid_property(self, PROP_STRIDES_IN_BYTES): return self._py_strides_in_bytes @@ -202,7 +205,7 @@ cdef class StridedLayout: self._py_strides_in_bytes = cuda_utils.carray_integer_t_to_tuple(strides.data(), strides.size()) _mark_property_valid(self, PROP_STRIDES_IN_BYTES) return self._py_strides_in_bytes - + cdef inline int64_t get_volume(StridedLayout self) except -1 nogil: if not _has_valid_property(self, PROP_VOLUME): self._volume = _volume(self.base) @@ -212,7 +215,7 @@ cdef class StridedLayout: cdef inline int get_stride_order(StridedLayout self, axis_vec_t& stride_order) except -1 nogil: _order_from_strides(stride_order, self.base.shape, self.base.strides, self.base.ndim) return 0 - + cdef inline tuple get_stride_order_tuple(StridedLayout self): if _has_valid_property(self, PROP_STRIDE_ORDER): return self._py_stride_order @@ -221,7 +224,7 @@ cdef class StridedLayout: self._py_stride_order = cuda_utils.carray_integer_t_to_tuple(stride_order.data(), stride_order.size()) _mark_property_valid(self, PROP_STRIDE_ORDER) return self._py_stride_order - + cdef inline bint get_is_unique(StridedLayout self) except -1 nogil: if _has_valid_property(self, PROP_IS_UNIQUE): return _boolean_property(self, PROP_IS_UNIQUE) @@ -234,61 +237,82 @@ cdef class StridedLayout: cdef inline bint get_is_contiguous_c(StridedLayout self) except -1 nogil: if _has_valid_property(self, PROP_IS_CONTIGUOUS_C): return _boolean_property(self, PROP_IS_CONTIGUOUS_C) - return _set_boolean_property(self, PROP_IS_CONTIGUOUS_C, _is_contiguous_c(self.get_volume(), self.base)) + cdef bint is_contiguous_c = ( + self.slice_offset == 0 and _is_contiguous_c(self.get_volume(), self.base) + ) + return _set_boolean_property(self, PROP_IS_CONTIGUOUS_C, is_contiguous_c) cdef inline bint get_is_contiguous_f(StridedLayout self) except -1 nogil: if _has_valid_property(self, PROP_IS_CONTIGUOUS_F): return _boolean_property(self, PROP_IS_CONTIGUOUS_F) - return _set_boolean_property(self, PROP_IS_CONTIGUOUS_F, _is_contiguous_f(self.get_volume(), self.base)) - - cdef inline bint get_is_contiguous_any(StridedLayout self) except -1 nogil: - if _has_valid_property(self, PROP_IS_CONTIGUOUS_ANY): - return _boolean_property(self, PROP_IS_CONTIGUOUS_ANY) + cdef bint is_contiguous_f = ( + self.slice_offset == 0 and _is_contiguous_f(self.get_volume(), self.base) + ) + return _set_boolean_property(self, PROP_IS_CONTIGUOUS_F, is_contiguous_f) + + cdef inline bint get_is_dense(StridedLayout self) except -1 nogil: + if _has_valid_property(self, PROP_IS_DENSE): + return _boolean_property(self, PROP_IS_DENSE) cdef axis_vec_t stride_order self.get_stride_order(stride_order) - return _set_boolean_property(self, PROP_IS_CONTIGUOUS_ANY, _is_contiguous_any(self.get_volume(), self.base, stride_order)) - + cdef bint is_dense = ( + self.slice_offset == 0 and _is_dense(self.get_volume(), self.base, stride_order) + ) + return _set_boolean_property(self, PROP_IS_DENSE, is_dense) + cdef inline int 
get_offset_bounds(StridedLayout self, stride_t& min_offset, stride_t& max_offset) except -1 nogil: - min_offset = 0 - max_offset = 0 - if self.base.strides == NULL: - max_offset = self.get_volume() - 1 + if _has_valid_property(self, PROP_OFFSET_BOUNDS): + min_offset = self._min_offset + max_offset = self._max_offset return 0 cdef int ndim = self.base.ndim cdef stride_t stride cdef extent_t extent - for i in range(ndim): - stride = self.base.strides[i] # can be negative - extent = self.base.shape[i] # must be non-negative - if extent == 0: - min_offset = 0 - max_offset = -1 # so that max_offset - min_offset + 1 = 0 - return 0 - if stride <= 0: - min_offset = _overflow_checked_sum(min_offset, _overflow_checked_mul(stride, (extent - 1))) - else: - max_offset = _overflow_checked_sum(max_offset, _overflow_checked_mul(stride, (extent - 1))) + min_offset = self.slice_offset + max_offset = self.slice_offset + if self.base.strides == NULL: + max_offset = _overflow_checked_sum(max_offset, self.get_volume() - 1) + else: + for i in range(ndim): + stride = self.base.strides[i] # can be negative + extent = self.base.shape[i] # must be non-negative + if extent == 0: + min_offset = 0 + max_offset = -1 # empty range + return 0 + if stride <= 0: + min_offset = _overflow_checked_sum(min_offset, _overflow_checked_mul(stride, (extent - 1))) + else: + max_offset = _overflow_checked_sum(max_offset, _overflow_checked_mul(stride, (extent - 1))) + self._min_offset = min_offset + self._max_offset = max_offset + _mark_property_valid(self, PROP_OFFSET_BOUNDS) return 0 - - cdef inline int64_t get_required_size_in_bytes(StridedLayout self) except -1 nogil: + + cdef inline int64_t get_required_size_in_bytes(StridedLayout self) except? -1 nogil: if _has_valid_property(self, PROP_REQUIRED_SIZE_IN_BYTES): return self._required_size_in_bytes cdef stride_t min_offset = 0 cdef stride_t max_offset = 0 self.get_offset_bounds(min_offset, max_offset) - if self.slice_offset > 0: - min_offset = min(min_offset, -self.slice_offset) - elif self.slice_offset < 0 and max_offset >= 0: - max_offset = max(max_offset, -self.slice_offset) - cdef int64_t required_size_in_bytes = _overflow_checked_diff(max_offset, min_offset) - required_size_in_bytes = _overflow_checked_sum(required_size_in_bytes, 1) + if min_offset < 0: + raise ValueError( + f"Allocation size for a layout that maps elements " + f"to negative memory offsets is ambiguous. " + f"The layout's min_offset is {min_offset}. " + f"To create a supported layout with the same shape " + f"please use StridedLayout.to_dense()." + ) + if max_offset < min_offset: + return 0 + cdef int64_t required_size_in_bytes = _overflow_checked_sum(max_offset, 1) self._required_size_in_bytes = _overflow_checked_mul(required_size_in_bytes, self.itemsize) _mark_property_valid(self, PROP_REQUIRED_SIZE_IN_BYTES) - return self._required_size_in_bytes - - cdef inline int64_t get_slice_offset_in_bytes(StridedLayout self) except -1 nogil: + return self._required_size_in_bytes + + cdef inline int64_t get_slice_offset_in_bytes(StridedLayout self) except? -1 nogil: return _overflow_checked_mul(self.slice_offset, self.itemsize) - + cdef axes_mask_t get_flattened_axis_mask(StridedLayout self) except? 
-1 nogil cdef int get_max_compatible_itemsize(StridedLayout self, int max_itemsize, intptr_t data_ptr, int axis=*) except -1 nogil @@ -296,10 +320,10 @@ cdef class StridedLayout: # Layout manipulation # ============================== - + cdef int reshape_into(StridedLayout self, StridedLayout out_layout, BaseLayout& new_shape) except -1 nogil cdef int permute_into(StridedLayout self, StridedLayout out_layout, axis_vec_t& axis_order) except -1 nogil - + cdef int flatten_into(StridedLayout self, StridedLayout out_layout, axes_mask_t axis_mask=*) except -1 nogil cdef int squeeze_into(StridedLayout self, StridedLayout out_layout) except -1 nogil cdef int unsqueeze_into(StridedLayout self, StridedLayout out_layout, axis_vec_t& axis_vec) except -1 nogil @@ -317,8 +341,8 @@ cdef inline int init_base_layout(BaseLayout& layout, int ndim) except -1 nogil: if ndim > STRIDED_LAYOUT_MAX_NDIM: raise ValueError(f"Unsupported number of dimensions: {ndim}. Max supported ndim is {STRIDED_LAYOUT_MAX_NDIM}") # resize(0) is no op, that results in _mem.data() being NULL, - # which would make it tricky to distinguish between strides == NULL - # and strides == tuple() + # which would make it tricky to distinguish between strides == NULL + # and strides == tuple() layout._mem.resize(2 * max(ndim, 1)) layout.shape = layout._mem.data() layout.strides = layout._mem.data() + ndim @@ -498,7 +522,7 @@ cdef inline bint _is_contiguous_f(int64_t volume, BaseLayout& base) except -1 no return True -cdef inline bint _is_contiguous_any(int64_t volume, BaseLayout& base, axis_vec_t& axis_order) except -1 nogil: +cdef inline bint _is_dense(int64_t volume, BaseLayout& base, axis_vec_t& axis_order) except -1 nogil: if volume == 0 or base.strides == NULL: return True cdef int64_t stride = 1 @@ -588,7 +612,7 @@ cdef inline OrderFlag _stride_order2vec(axis_vec_t& stride_order_vec, object str return ORDER_C elif stride_order == 'F': return ORDER_F - elif isinstance(stride_order, tuple | list): + elif isinstance(stride_order, tuple | list): _tuple2axis_vec(stride_order_vec, stride_order) return ORDER_PERM return ORDER_NONE diff --git a/cuda_core/cuda/core/experimental/_layout.pyx b/cuda_core/cuda/core/experimental/_layout.pyx index 1df819bb88..d116cd7018 100644 --- a/cuda_core/cuda/core/experimental/_layout.pyx +++ b/cuda_core/cuda/core/experimental/_layout.pyx @@ -29,18 +29,29 @@ cdef class StridedLayout: @classmethod def dense_like(cls, StridedLayout other, object stride_order="K"): - cdef StridedLayout new_layout = StridedLayout.__new__(cls) cdef OrderFlag order_flag cdef axis_vec_t stride_order_vec if stride_order == "K": + if other.get_is_dense(): + return other other.get_stride_order(stride_order_vec) order_flag = ORDER_PERM else: order_flag = _stride_order2vec(stride_order_vec, stride_order) if order_flag == ORDER_NONE: - raise ValueError(f"The stride_order must be 'K', 'C', 'F', or a permutation tuple. Got: {stride_order}") + raise ValueError( + f"The stride_order must be 'K', 'C', 'F', " + f"or a permutation tuple. Got: {stride_order}" + ) + elif order_flag == ORDER_C: + if other.get_is_contiguous_c(): + return other + elif order_flag == ORDER_F: + if other.get_is_contiguous_f(): + return other + cdef StridedLayout new_layout = StridedLayout.__new__(cls) new_layout.init_dense_from_ptr( other.base.ndim, other.base.shape, @@ -93,34 +104,121 @@ cdef class StridedLayout: @property def is_contiguous_c(StridedLayout self): + """ + True iff the layout is contiguous in C-order, i.e. 
+ the rightmost stride is 1 and each subsequent + stride to the left is the product of the + next extent and the stride. + In C-contigious layout, the strides are non-negative, + increase from the right to the left and the mapping + from indices to memory offsets is 1 to 1. + """ return self.get_is_contiguous_c() @property def is_contiguous_f(StridedLayout self): + """ + True iff the layout is contiguous in F-order, i.e. + the leftmost stride is 1 and each subsequent + stride to the right is the product of the + next stride and extent. + In F-contigious layout, the strides are non-negative, + increase from the left to the right and the mapping + from indices to memory offsets is 1 to 1. + """ return self.get_is_contiguous_f() @property - def is_contiguous_any(StridedLayout self): - return self.get_is_contiguous_any() + def is_dense(StridedLayout self): + """ + True iff the layout is contiguous in some axis order, i.e. + there exists a permutation of axes such that the layout + is C-contiguous. + In dense layout, the strides are non-negative and the mapping + from indices to memory offsets is 1 to 1. + """ + return self.get_is_dense() @property def offset_bounds(StridedLayout self): + """ + A tuple of ``(min_offset, max_offset)`` representing the + minimum and maximum offsets (as a number of elements, not bytes) + that the layout can map to. + I.e. there exist two ndim-tuples ``idx_min`` and ``idx_max``, + where ``0 <= idx[i] < shape[i]`` for ``0 <= i < ndim``, + such that: + ``min_offset = sum(idx_min[i] * strides[i] for i in range(ndim))`` + ``max_offset = sum(idx_max[i] * strides[i] for i in range(ndim))``, + and all other valid ndim-indices are mapped to offsets + in the range ``[min_offset, max_offset]``. + """ cdef stride_t min_offset = 0 cdef stride_t max_offset = 0 self.get_offset_bounds(min_offset, max_offset) return min_offset, max_offset @property - def required_size_in_bytes(StridedLayout self): - return self.get_required_size_in_bytes() + def min_offset(StridedLayout self): + """ + See ``offset_bounds`` for details. + """ + cdef stride_t min_offset = 0 + cdef stride_t max_offset = 0 + self.get_offset_bounds(min_offset, max_offset) + return min_offset + + @property + def max_offset(StridedLayout self): + """ + See ``offset_bounds`` for details. + """ + cdef stride_t min_offset = 0 + cdef stride_t max_offset = 0 + self.get_offset_bounds(min_offset, max_offset) + return max_offset @property def slice_offset_in_bytes(StridedLayout self): + """ + The memory offset (as a number of bytes) + of the element at index ``(0,) * ndim``. + The only way for the index 0 to be mapped to + non-zero offset in memory is if the layout + was sliced. + """ return self.get_slice_offset_in_bytes() - + + def required_size_in_bytes(StridedLayout self): + """ + The memory allocation size in bytes needed for all + elements of the ndim-tensor to be mapped to + offsets within the allocated memory range. + I.e. for any ndim-tuple ``idx``, such that + ``0 <= idx[i] < shape[i]`` for ``0 <= i < ndim``, + the ``sum(idx[i] * strides[i] for i in range(ndim))`` + is in the range ``[0, required_size_in_bytes - 1]``. + The function raises an error if the layout maps any element + to a negative memory offset (i.e. layout.offset_bounds[0] < 0). + """ + return self.get_required_size_in_bytes() + def flattened_axis_mask(StridedLayout self): + """ + A mask describing which axes can be merged + together preserving the index to memory offset mapping + (see more details in ``flattened`` method documentation). 
+ The only supported operation is the logical ``&`` + between masks coming from the layouts with equal ndim. + If such a mask is passed to the + ``flattened`` method, only the axes that are mergable + for all the layouts will be flattened. + """ return self.get_flattened_axis_mask() - + + def to_dense(StridedLayout self, object stride_order="K"): + return StridedLayout.dense_like(self, stride_order) + def reshaped(StridedLayout self, object shape): cdef StridedLayout new_layout = StridedLayout.__new__(StridedLayout) cdef BaseLayout new_shape @@ -138,6 +236,21 @@ cdef class StridedLayout: return new_layout def flattened(StridedLayout self, start_axis=0, end_axis=-1, mask=None): + """ + Merges consecutive axes into a single axis (where the new extent + is the product of merged extents) if the mapping of indices to + memory offsets is preserved (assuming the indices are iterated + in C-order, i.e. the rightmost axis is incremented first). + E.g. for ``StridedLayout((2, 2), (4, 2), 1)`` + and the C-ordered indices ``[(0, 0), (0, 1), (1, 0), (1, 1)]`` would + be mapped to offsets ``[0, 2, 4, 6]``, same as for the + flattened layout ``StridedLayout((4,), (2,), 1)`` + and the indices ``[0, 1, 2, 3]``. + If ``start_axis`` and ``end_axis`` are provided, only the axes in the + inclusive range ``[start_axis, end_axis]`` are considered for flattening. + Alternatively, a mask specifying which axes to consider can be provided + (see ``flattened_axis_mask`` method documentation for details). + """ cdef StridedLayout new_layout = StridedLayout.__new__(StridedLayout) cdef axes_mask_t axis_mask if mask is None: @@ -206,14 +319,14 @@ cdef class StridedLayout: cdef axes_mask_t get_flattened_axis_mask(StridedLayout self) except? -1 nogil: return flattened_strides_in_c_index_order_mask(self.base) - + cdef int get_max_compatible_itemsize(StridedLayout self, int max_itemsize, intptr_t data_ptr, int axis=-1) except -1 nogil: return max_compatible_itemsize(self.base, self.slice_offset, self.itemsize, max_itemsize, data_ptr, axis) cdef int reshape_into(StridedLayout self, StridedLayout out_layout, BaseLayout& new_shape) except -1 nogil: cdef int64_t old_volume = self.get_volume() validate_reshaped_shape(new_shape, old_volume) - + cdef int ndim = new_shape.ndim _zero_strides(new_shape) @@ -222,7 +335,7 @@ cdef class StridedLayout: flatten_strides_in_c_index_order(flattened, self.base, AXIS_MASK_ALL) if not split_strides_in_c_index_order(new_shape, flattened): raise ValueError("Layout strides are incompatible with the new shape") - + # Reset all memoized properties out_layout._prop_mask = 0 @@ -241,7 +354,7 @@ cdef class StridedLayout: cdef BaseLayout permuted permute_extents(permuted, self.base, axis_order) - + # Reset all memoized properties out_layout._prop_mask = 0 @@ -327,7 +440,7 @@ cdef class StridedLayout: return 0 cdef int pack_into(StridedLayout self, StridedLayout out_layout, int itemsize, intptr_t data_ptr, bint keep_dim, int axis=-1) except -1 nogil: - + cdef BaseLayout packed cdef stride_t new_slice_offset = 0 cdef int vec_size = pack_extents( @@ -365,9 +478,9 @@ cdef class StridedLayout: ) if vec_size == 1 and out_layout is self: return 0 - + cdef int64_t new_slice_offset = _overflow_checked_mul(self.slice_offset, vec_size) - + # Reset all memoized properties out_layout._prop_mask = 0 @@ -387,7 +500,7 @@ cdef class StridedLayout: # Preserved attributes out_layout.itemsize = self.itemsize - + # Set new attributes _swap_layout(out_layout.base, sliced) out_layout.slice_offset = new_slice_offset @@ -452,7 
+565,7 @@ cdef inline int flatten_strides_in_c_index_order(BaseLayout& out_layout, BaseLay cdef extent_t* in_shape = in_layout.shape cdef stride_t* in_strides = get_strides_ptr(in_layout) while group_start < ndim: - group_vol = in_shape[group_start] + group_vol = in_shape[group_start] group_stride = in_strides[group_start] group_end = group_start + 1 while ( @@ -482,7 +595,7 @@ cdef inline axes_mask_t flattened_strides_in_c_index_order_mask(BaseLayout& layo cdef int64_t group_vol cdef int64_t group_stride while group_start < ndim: - group_vol = layout.shape[group_start] + group_vol = layout.shape[group_start] group_stride = layout.strides[group_start] group_end = group_start + 1 while group_end < ndim and group_stride == layout.strides[group_end] * layout.shape[group_end]: @@ -763,7 +876,7 @@ cdef inline int unpack_extents(BaseLayout &out_layout, BaseLayout &in_layout, in raise ValueError(f"The axis {axis} extent must be non-zero, got {shape[axis]}.") if strides[axis] != 1: raise ValueError(f"The axis {axis} stride must be 1, got {strides[axis]}.") - + cdef int vec_size = itemsize // new_itemsize init_base_layout(out_layout, ndim) out_layout.shape[axis] = _overflow_checked_mul(shape[axis], vec_size) diff --git a/cuda_core/cuda/core/experimental/_memory/_buffer.pxd b/cuda_core/cuda/core/experimental/_memory/_buffer.pxd index 12da84b2bd..262e7f5dd8 100644 --- a/cuda_core/cuda/core/experimental/_memory/_buffer.pxd +++ b/cuda_core/cuda/core/experimental/_memory/_buffer.pxd @@ -7,13 +7,22 @@ from libc.stdint cimport uintptr_t from cuda.core.experimental._stream cimport Stream +cdef struct _MemAttrs: + int device_id + bint is_device_accessible + bint is_host_accessible + + cdef class Buffer: cdef: uintptr_t _ptr size_t _size MemoryResource _memory_resource + object _owner object _ptr_obj Stream _alloc_stream + _MemAttrs _mem_attrs + bint _mem_attrs_inited cdef class MemoryResource: diff --git a/cuda_core/cuda/core/experimental/_memory/_buffer.pyx b/cuda_core/cuda/core/experimental/_memory/_buffer.pyx index 1ad79538ac..ac611d5cdd 100644 --- a/cuda_core/cuda/core/experimental/_memory/_buffer.pyx +++ b/cuda_core/cuda/core/experimental/_memory/_buffer.pyx @@ -47,6 +47,8 @@ cdef class Buffer: self._memory_resource = None self._ptr_obj = None self._alloc_stream = None + self._owner = None + self._mem_attrs_inited = False def __init__(self, *args, **kwargs): raise RuntimeError("Buffer objects cannot be instantiated directly. " @@ -55,14 +57,17 @@ cdef class Buffer: @classmethod def _init( cls, ptr: DevicePointerT, size_t size, mr: MemoryResource | None = None, - stream: Stream | None = None + stream: Stream | None = None, owner : object | None = None, ): cdef Buffer self = Buffer.__new__(cls) self._ptr = (int(ptr)) self._ptr_obj = ptr self._size = size + if mr is not None and owner is not None: + raise ValueError("owner and memory resource cannot be both specified together") self._memory_resource = mr self._alloc_stream = (stream) if stream is not None else None + self._owner = owner return self def __dealloc__(self): @@ -74,7 +79,8 @@ cdef class Buffer: @staticmethod def from_handle( - ptr: DevicePointerT, size_t size, mr: MemoryResource | None = None + ptr: DevicePointerT, size_t size, mr: MemoryResource | None = None, + owner: object | None = None, ) -> Buffer: """Create a new :class:`Buffer` object from a pointer. 
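
A usage sketch for the new ``owner`` parameter introduced in the hunk above
(``ext`` here is a hypothetical object exposing a pointer to an allocation
that it keeps alive):

.. code-block:: python

    buf = Buffer.from_handle(ext.ptr, ext.size, owner=ext)
    assert buf.owner is ext  # ``ext`` stays referenced for the buffer's lifetime
    # ``owner`` and ``mr`` are mutually exclusive; passing both raises ValueError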
@@ -86,9 +92,13 @@ cdef class Buffer: Memory size of the buffer mr : :obj:`~_memory.MemoryResource`, optional Memory resource associated with the buffer + owner : object, optional + An object holding external allocation that the ``ptr`` points to. + The reference is kept as long as the buffer is alive. + The ``owner`` and ``mr`` cannot be specified together. """ # TODO: It is better to take a stream for latter deallocation - return Buffer._init(ptr, size, mr=mr) + return Buffer._init(ptr, size, mr=mr, owner=owner) @classmethod def from_ipc_descriptor( @@ -224,7 +234,9 @@ cdef class Buffer: """Return the device ordinal of this buffer.""" if self._memory_resource is not None: return self._memory_resource.device_id - raise NotImplementedError("WIP: Currently this property only supports buffers with associated MemoryResource") + else: + Buffer_init_mem_attrs(self) + return self._mem_attrs.device_id @property def handle(self) -> DevicePointerT: @@ -248,14 +260,18 @@ cdef class Buffer: """Return True if this buffer can be accessed by the GPU, otherwise False.""" if self._memory_resource is not None: return self._memory_resource.is_device_accessible - raise NotImplementedError("WIP: Currently this property only supports buffers with associated MemoryResource") + else: + Buffer_init_mem_attrs(self) + return self._mem_attrs.is_device_accessible @property def is_host_accessible(self) -> bool: """Return True if this buffer can be accessed by the CPU, otherwise False.""" if self._memory_resource is not None: return self._memory_resource.is_host_accessible - raise NotImplementedError("WIP: Currently this property only supports buffers with associated MemoryResource") + else: + Buffer_init_mem_attrs(self) + return self._mem_attrs.is_host_accessible @property def memory_resource(self) -> MemoryResource: @@ -267,20 +283,75 @@ cdef class Buffer: """Return the memory size of this buffer.""" return self._size + @property + def owner(self) -> object: + """Return the object holding external allocation.""" + return self._owner + # Buffer Implementation # --------------------- cdef inline void Buffer_close(Buffer self, stream): cdef Stream s - if self._ptr and self._memory_resource is not None: - s = Stream_accept(stream) if stream is not None else self._alloc_stream - self._memory_resource.deallocate(self._ptr, self._size, s) + if self._ptr: + if self._memory_resource is not None: + s = Stream_accept(stream) if stream is not None else self._alloc_stream + self._memory_resource.deallocate(self._ptr, self._size, s) self._ptr = 0 self._memory_resource = None + self._owner = None self._ptr_obj = None self._alloc_stream = None +cdef Buffer_init_mem_attrs(Buffer self): + if not self._mem_attrs_inited: + query_memory_attrs(self._mem_attrs, self._ptr) + self._mem_attrs_inited = True + + +cdef int query_memory_attrs(_MemAttrs &out, uintptr_t ptr) except -1: + cdef int memory_type + ret, attrs = _query_memory_attrs(ptr) + if ret == driver.CUresult.CUDA_ERROR_NOT_INITIALIZED: + # Device class handles the cuInit call internally + from cuda.core.experimental import Device as _Device + _Device() + ret, attrs = _query_memory_attrs(ptr) + raise_if_driver_error(ret) + memory_type = attrs[0] + + if memory_type == 0: + # unregistered host pointer + out.is_host_accessible = True + out.is_device_accessible = False + out.device_id = -1 + elif ( + memory_type == driver.CUmemorytype.CU_MEMORYTYPE_HOST + or memory_type == driver.CUmemorytype.CU_MEMORYTYPE_UNIFIED + ): + # TODO(ktokarski): should we compare host/device ptrs using 
cuPointerGetAttribute + # for exceptional cases when the same data can end up with different ptrs + # for host and device? + out.is_host_accessible = True + out.is_device_accessible = True + out.device_id = attrs[1] + else: + # device/texture + out.is_host_accessible = False + out.is_device_accessible = True + out.device_id = attrs[1] + return 0 + + +cdef inline _query_memory_attrs(uintptr_t ptr): + cdef tuple attrs = ( + driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_MEMORY_TYPE, + driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL, + ) + return driver.cuPointerGetAttributes(len(attrs), attrs, ptr) + + cdef class MemoryResource: """Abstract base class for memory resources that manage allocation and deallocation of buffers. diff --git a/cuda_core/cuda/core/experimental/_memoryview.pyx b/cuda_core/cuda/core/experimental/_memoryview.pyx index 4992fd1b5b..17ecddfb8e 100644 --- a/cuda_core/cuda/core/experimental/_memoryview.pyx +++ b/cuda_core/cuda/core/experimental/_memoryview.pyx @@ -13,6 +13,8 @@ import numpy from cuda.core.experimental._utils.cuda_utils import handle_return, driver from cuda.core.experimental._layout cimport StridedLayout +from cuda.core.experimental._dlpack import make_py_capsule +from cuda.core.experimental._memory import Buffer # TODO(leofang): support NumPy structured dtypes @@ -89,9 +91,18 @@ cdef class StridedMemoryView: cdef DLTensor *dl_tensor # Memoized properties + # Either lazily inferred from dl_tensor/metadata, + # or explicitly provided if created with from_buffer(). cdef StridedLayout _layout + # Either exporting_obj if it is a Buffer, otherwise a Buffer instance + # with owner set to the exporting object. + cdef object _buffer + # Either lazily inferred from dl_tensor/metadata, + # or explicitly provided if created with from_buffer(). + # In the latter case, it can be None. 
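+    # (__dlpack__ export requires a dtype and raises for untyped views.)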
cdef object _dtype
+
     def __init__(self, obj=None, stream_ptr=None):
         if obj is not None:
             # populate self's attributes
@@ -119,6 +130,25 @@ cdef class StridedMemoryView:
             dlm_tensor = <DLManagedTensor*>data
             dlm_tensor.deleter(dlm_tensor)
 
+    @classmethod
+    def from_buffer(cls, object buffer, StridedLayout layout, object dtype=None, bint is_readonly=False):
+        cdef StridedMemoryView view = StridedMemoryView.__new__(cls)
+        view_buffer_strided(view, buffer, layout, dtype, is_readonly)
+        return view
+
+    def view(self, StridedLayout layout=None, object dtype=None, object is_readonly=None):
+        if layout is None and dtype is None and (is_readonly is None or is_readonly == self.readonly):
+            return self
+        cdef StridedMemoryView view = StridedMemoryView.__new__(self.__class__)
+        if layout is None:
+            layout = self.get_layout()
+        if dtype is None:
+            dtype = self.get_dtype()
+        if is_readonly is None:
+            is_readonly = self.readonly
+        view_buffer_strided(view, self.get_buffer(), layout, dtype, is_readonly)
+        return view
+
     @property
     def layout(self) -> StridedLayout:
         return self.get_layout()
@@ -133,25 +163,57 @@ cdef class StridedMemoryView:
 
     @property
     def dtype(self) -> Optional[numpy.dtype]:
-        if self._dtype is None:
-            if self.exporting_obj is not None:
-                if self.dl_tensor != NULL:
-                    self._dtype = dtype_dlpack_to_numpy(&self.dl_tensor.dtype)
-                else:
-                    # TODO: this only works for built-in numeric types
-                    self._dtype = typestr2dtype(self.metadata["typestr"])
-        return self._dtype
+        return self.get_dtype()
 
     def __repr__(self):
         return (f"StridedMemoryView(ptr={self.ptr},\n"
               + f"                  shape={self.shape},\n"
               + f"                  strides={self.strides},\n"
+              + f"                  itemsize={self.layout.itemsize},\n"
               + f"                  dtype={get_simple_repr(self.dtype)},\n"
               + f"                  device_id={self.device_id},\n"
               + f"                  is_device_accessible={self.is_device_accessible},\n"
               + f"                  readonly={self.readonly},\n"
               + f"                  exporting_obj={get_simple_repr(self.exporting_obj)})")
 
+    def __dlpack__(
+        self,
+        *,
+        stream: int | None = None,
+        max_version: tuple[int, int] | None = None,
+        dl_device: tuple[int, int] | None = None,
+        copy: bool | None = None,
+    ) -> PyCapsule:
+        # Note: we ignore the stream argument entirely (as if it is -1).
+        # It is the user's responsibility to maintain stream order.
+        if dl_device is not None:
+            raise BufferError("Sorry, not supported: dl_device other than None")
+        if copy is True:
+            raise BufferError("Sorry, not supported: copy=True")
+        if max_version is None:
+            versioned = False
+        else:
+            if not isinstance(max_version, tuple) or len(max_version) != 2:
+                raise BufferError(f"Expected max_version tuple[int, int], got {max_version}")
+            versioned = max_version >= (1, 0)
+        cdef object dtype = self.get_dtype()
+        if dtype is None:
+            raise ValueError(
+                "Cannot export the StridedMemoryView without a dtype. "
+                "You can create a dtyped view by calling the view(dtype=...) method."
+            )
+        capsule = make_py_capsule(
+            self.get_buffer(),
+            self.ptr,
+            versioned,
+            self.get_layout(),
+            _numpy2dlpack_dtype[dtype],
+        )
+        return capsule
+
+    def __dlpack_device__(self) -> tuple[int, int]:
+        return self.get_buffer().__dlpack_device__()
+
     cdef inline StridedLayout get_layout(self):
         if self._layout is None:
             if self.dl_tensor:
@@ -162,6 +224,27 @@ cdef class StridedMemoryView:
                 self._layout = StridedLayout.__new__(StridedLayout)
         return self._layout
 
+    cdef inline object get_buffer(self):
+        """
+        Returns a Buffer instance wrapping the underlying data.
+        If the view was created from a Buffer, the same Buffer instance is returned.
+        Otherwise, a new instance is created with ``owner`` set to the exporting object.
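+
+        A sketch of the equivalent Python-level logic::
+
+            if isinstance(self.exporting_obj, Buffer):
+                return self.exporting_obj
+            # size=0: the foreign allocation's true size is not known here
+            return Buffer.from_handle(self.ptr, 0, owner=self.exporting_obj)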
+ """ + if self._buffer is None: + if isinstance(self.exporting_obj, Buffer): + self._buffer = self.exporting_obj + else: + self._buffer = Buffer.from_handle(self.ptr, 0, owner=self.exporting_obj) + return self._buffer + + cdef inline object get_dtype(self): + if self._dtype is None: + if self.dl_tensor != NULL: + self._dtype = dtype_dlpack_to_numpy(&self.dl_tensor.dtype) + elif self.metadata is not None: + # TODO: this only works for built-in numeric types + self._dtype = _typestr2dtype[self.metadata["typestr"]] + return self._dtype cdef str get_simple_repr(obj): # TODO: better handling in np.dtype objects @@ -279,6 +362,26 @@ cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view=None): return buf +_numpy2dlpack_dtype = { + numpy.dtype("uint8"): (kDLUInt, 8, 1), + numpy.dtype("uint16"): (kDLUInt, 16, 1), + numpy.dtype("uint32"): (kDLUInt, 32, 1), + numpy.dtype("uint64"): (kDLUInt, 64, 1), + numpy.dtype("int8"): (kDLInt, 8, 1), + numpy.dtype("int16"): (kDLInt, 16, 1), + numpy.dtype("int32"): (kDLInt, 32, 1), + numpy.dtype("int64"): (kDLInt, 64, 1), + numpy.dtype("float16"): (kDLFloat, 16, 1), + numpy.dtype("float32"): (kDLFloat, 32, 1), + numpy.dtype("float64"): (kDLFloat, 64, 1), + numpy.dtype("complex64"): (kDLComplex, 64, 1), + numpy.dtype("complex128"): (kDLComplex, 128, 1), + numpy.dtype("bool"): (kDLBool, 8, 1), +} +_typestr2dtype = {dtype.str: dtype for dtype in _numpy2dlpack_dtype.keys()} +_typestr2itemsize = {dtype.str: dtype.itemsize for dtype in _numpy2dlpack_dtype.keys()} + + cdef object dtype_dlpack_to_numpy(DLDataType* dtype): cdef int bits = dtype.bits if dtype.lanes != 1: @@ -434,27 +537,58 @@ cdef StridedLayout layout_from_cai(object metadata): cdef StridedLayout layout = StridedLayout.__new__(StridedLayout) cdef object shape = metadata["shape"] cdef object strides = metadata.get("strides") - cdef int itemsize = typestr2itemsize(metadata["typestr"]) + cdef int itemsize = _typestr2itemsize[metadata["typestr"]] layout.init_from_tuple(shape, strides, itemsize, strides is not None) return layout -_typestr2dtype_cache = {} -_typestr2itemsize_cache = {} - -cdef object typestr2dtype(str typestr): - global _typestr2dtype_cache - cdef object dtype = _typestr2dtype_cache.get(typestr) - if dtype is None: - dtype = numpy.dtype(typestr) - _typestr2dtype_cache[typestr] = dtype - return dtype - - -cdef inline int typestr2itemsize(str typestr): - global _typestr2itemsize_cache - cdef object itemsize = _typestr2itemsize_cache.get(typestr) - if itemsize is None: - itemsize = typestr2dtype(typestr).itemsize - _typestr2itemsize_cache[typestr] = itemsize - return itemsize +cdef inline intptr_t _get_data_ptr(object buffer, StridedLayout layout) except? 0: + cdef bint is_allocated = buffer.owner is None + if is_allocated: + if buffer.memory_resource is None: + raise ValueError( + "Ambiguous buffer instance. The buffer must either hold an allocation " + "(coming from MemoryResource, e.g. Device().memory_resource.allocate()) " + "or wrap external data and specify the owner " + "(`Buffer.from_handle(ptr, size, owner=...)`)." + ) + # For external buffers, we may not know the size. Even if we did, the size + # alone is not enough if the layout can map to negative offsets, i.e.: + # the valid range is not the [ptr, ptr + size - 1], but + # [ptr - offset, ptr + size - offset - 1]. The offset is not reported + # by the packages. + if is_allocated and buffer.size < layout.get_required_size_in_bytes(): + raise ValueError( + f"Buffer size is too small for the layout. 
" + f"Expected at least {layout.get_required_size_in_bytes()} bytes, " + f"got {buffer.size} bytes." + ) + return (buffer.handle) + layout.get_slice_offset_in_bytes() + + +cdef inline int view_buffer_strided( + StridedMemoryView view, + object buffer, + StridedLayout layout, + object dtype, + bint is_readonly, +) except -1: + if dtype is not None: + dtype = numpy.dtype(dtype) + if dtype.itemsize > layout.itemsize: + layout = layout.packed(dtype.itemsize, int(buffer.handle)) + elif dtype.itemsize < layout.itemsize: + layout = layout.unpacked(dtype.itemsize) + # set the public attributes + view.ptr = _get_data_ptr(buffer, layout) + view.device_id = buffer.device_id + view.is_device_accessible = buffer.is_device_accessible + view.readonly = is_readonly + view.exporting_obj = view._buffer = buffer + # no dlpack/cai metadata + view.dl_tensor = NULL + view.metadata = None + # we get the layout from the caller + view._layout = layout + view._dtype = dtype + return 0 From 247061741785174911acd250f57a219c88bc00da Mon Sep 17 00:00:00 2001 From: Kamil Tokarski Date: Tue, 18 Nov 2025 11:51:24 +0100 Subject: [PATCH 03/20] Documentation, linting, minor fixes Signed-off-by: Kamil Tokarski --- cuda_core/cuda/core/experimental/_layout.pxd | 62 +- cuda_core/cuda/core/experimental/_layout.pyx | 623 ++++++++++++++---- .../core/experimental/_memory/_buffer.pyx | 2 +- .../cuda/core/experimental/_memoryview.pyx | 149 +++-- .../cuda/core/experimental/include/layout.hpp | 6 +- cuda_core/cuda/core/experimental/utils.py | 1 + .../source/_templates/autosummary/cyclass.rst | 27 + cuda_core/docs/source/api.rst | 3 +- cuda_core/docs/source/conf.py | 18 + 9 files changed, 703 insertions(+), 188 deletions(-) create mode 100644 cuda_core/docs/source/_templates/autosummary/cyclass.rst diff --git a/cuda_core/cuda/core/experimental/_layout.pxd b/cuda_core/cuda/core/experimental/_layout.pxd index 5fd5c783b6..37ce85cc72 100644 --- a/cuda_core/cuda/core/experimental/_layout.pxd +++ b/cuda_core/cuda/core/experimental/_layout.pxd @@ -5,7 +5,7 @@ cimport cython from cython.operator cimport dereference as deref -from libc.stdint cimport int64_t, uint32_t, intptr_t +from libc.stdint cimport int64_t, uint32_t, uintptr_t from libcpp cimport vector ctypedef int64_t extent_t @@ -49,14 +49,15 @@ cdef enum Property: PROP_IS_UNIQUE = 1 << 0 PROP_IS_CONTIGUOUS_C = 1 << 1 PROP_IS_CONTIGUOUS_F = 1 << 2 - PROP_IS_DENSE = 1 << 3 - PROP_OFFSET_BOUNDS = 1 << 4 - PROP_REQUIRED_SIZE_IN_BYTES = 1 << 5 - PROP_SHAPE = 1 << 6 - PROP_STRIDES = 1 << 7 - PROP_STRIDES_IN_BYTES = 1 << 8 - PROP_STRIDE_ORDER = 1 << 9 - PROP_VOLUME = 1 << 10 + PROP_IS_CONTIGUOUS_ANY = 1 << 3 + PROP_IS_DENSE = 1 << 4 + PROP_OFFSET_BOUNDS = 1 << 5 + PROP_REQUIRED_SIZE_IN_BYTES = 1 << 6 + PROP_SHAPE = 1 << 7 + PROP_STRIDES = 1 << 8 + PROP_STRIDES_IN_BYTES = 1 << 9 + PROP_STRIDE_ORDER = 1 << 10 + PROP_VOLUME = 1 << 11 cdef struct BaseLayout: @@ -109,14 +110,15 @@ cdef class StridedLayout: # Initialization # ============================== - cdef inline int _init(StridedLayout self, BaseLayout& base, int itemsize, bint strides_in_bytes=False) except -1 nogil: + cdef inline int _init(StridedLayout self, BaseLayout& base, int itemsize, bint divide_strides=False) except -1 nogil: _validate_itemsize(itemsize) - if base.strides != NULL and strides_in_bytes: + if base.strides != NULL and divide_strides: _divide_strides(base, itemsize) self.itemsize = itemsize self.slice_offset = 0 + _swap_layout(self.base, base) return 0 @@ -142,20 +144,20 @@ cdef class StridedLayout: 
_mark_property_valid(self, PROP_VOLUME) return 0 - cdef inline int init_from_ptr(StridedLayout self, int ndim, extent_t* shape, stride_t* strides, int itemsize, bint strides_in_bytes=False) except -1 nogil: + cdef inline int init_from_ptr(StridedLayout self, int ndim, extent_t* shape, stride_t* strides, int itemsize, bint divide_strides=False) except -1 nogil: cdef BaseLayout base _init_base_layout_from_ptr(base, ndim, shape, strides) - return self._init(base, itemsize, strides_in_bytes) + return self._init(base, itemsize, divide_strides) cdef inline int init_dense_from_ptr(StridedLayout self, int ndim, extent_t* shape, int itemsize, OrderFlag order_flag, axis_vec_t* stride_order=NULL) except -1 nogil: cdef BaseLayout base _init_base_layout_from_ptr(base, ndim, shape, NULL) return self._init_dense(base, itemsize, order_flag, stride_order) - cdef inline int init_from_tuple(StridedLayout self, tuple shape, tuple strides, int itemsize, bint strides_in_bytes=False) except -1: + cdef inline int init_from_tuple(StridedLayout self, tuple shape, tuple strides, int itemsize, bint divide_strides=False) except -1: cdef BaseLayout base _init_base_layout_from_tuple(base, shape, strides) - return self._init(base, itemsize, strides_in_bytes) + return self._init(base, itemsize, divide_strides) cdef inline int init_dense_from_tuple(StridedLayout self, tuple shape, int itemsize, object stride_order) except -1: cdef axis_vec_t stride_order_vec @@ -237,28 +239,24 @@ cdef class StridedLayout: cdef inline bint get_is_contiguous_c(StridedLayout self) except -1 nogil: if _has_valid_property(self, PROP_IS_CONTIGUOUS_C): return _boolean_property(self, PROP_IS_CONTIGUOUS_C) - cdef bint is_contiguous_c = ( - self.slice_offset == 0 and _is_contiguous_c(self.get_volume(), self.base) - ) - return _set_boolean_property(self, PROP_IS_CONTIGUOUS_C, is_contiguous_c) + return _set_boolean_property(self, PROP_IS_CONTIGUOUS_C, _is_contiguous_c(self.get_volume(), self.base)) cdef inline bint get_is_contiguous_f(StridedLayout self) except -1 nogil: if _has_valid_property(self, PROP_IS_CONTIGUOUS_F): return _boolean_property(self, PROP_IS_CONTIGUOUS_F) - cdef bint is_contiguous_f = ( - self.slice_offset == 0 and _is_contiguous_f(self.get_volume(), self.base) - ) - return _set_boolean_property(self, PROP_IS_CONTIGUOUS_F, is_contiguous_f) + return _set_boolean_property(self, PROP_IS_CONTIGUOUS_F, _is_contiguous_f(self.get_volume(), self.base)) + + cdef inline bint get_is_contiguous_any(StridedLayout self) except -1 nogil: + if _has_valid_property(self, PROP_IS_CONTIGUOUS_ANY): + return _boolean_property(self, PROP_IS_CONTIGUOUS_ANY) + cdef axis_vec_t stride_order + self.get_stride_order(stride_order) + return _set_boolean_property(self, PROP_IS_CONTIGUOUS_ANY, _is_contiguous_any(self.get_volume(), self.base, stride_order)) cdef inline bint get_is_dense(StridedLayout self) except -1 nogil: if _has_valid_property(self, PROP_IS_DENSE): return _boolean_property(self, PROP_IS_DENSE) - cdef axis_vec_t stride_order - self.get_stride_order(stride_order) - cdef bint is_dense = ( - self.slice_offset == 0 and _is_dense(self.get_volume(), self.base, stride_order) - ) - return _set_boolean_property(self, PROP_IS_DENSE, is_dense) + return _set_boolean_property(self, PROP_IS_DENSE, self.slice_offset == 0 and self.get_is_contiguous_any()) cdef inline int get_offset_bounds(StridedLayout self, stride_t& min_offset, stride_t& max_offset) except -1 nogil: if _has_valid_property(self, PROP_OFFSET_BOUNDS): @@ -314,7 +312,7 @@ cdef class StridedLayout: 
return _overflow_checked_mul(self.slice_offset, self.itemsize) cdef axes_mask_t get_flattened_axis_mask(StridedLayout self) except? -1 nogil - cdef int get_max_compatible_itemsize(StridedLayout self, int max_itemsize, intptr_t data_ptr, int axis=*) except -1 nogil + cdef int get_max_compatible_itemsize(StridedLayout self, int max_itemsize, uintptr_t data_ptr, int axis=*) except -1 nogil # ============================== # Layout manipulation @@ -328,7 +326,7 @@ cdef class StridedLayout: cdef int squeeze_into(StridedLayout self, StridedLayout out_layout) except -1 nogil cdef int unsqueeze_into(StridedLayout self, StridedLayout out_layout, axis_vec_t& axis_vec) except -1 nogil cdef int broadcast_into(StridedLayout self, StridedLayout out_layout, BaseLayout& broadcast) except -1 nogil - cdef int pack_into(StridedLayout self, StridedLayout out_layout, int itemsize, intptr_t data_ptr, bint keep_dim, int axis=*) except -1 nogil + cdef int pack_into(StridedLayout self, StridedLayout out_layout, int itemsize, uintptr_t data_ptr, bint keep_dim, int axis=*) except -1 nogil cdef int unpack_into(StridedLayout self, StridedLayout out_layout, int itemsize, int axis=*) except -1 nogil cdef int slice_into(StridedLayout self, StridedLayout out_layout, tuple slices) except -1 @@ -522,7 +520,7 @@ cdef inline bint _is_contiguous_f(int64_t volume, BaseLayout& base) except -1 no return True -cdef inline bint _is_dense(int64_t volume, BaseLayout& base, axis_vec_t& axis_order) except -1 nogil: +cdef inline bint _is_contiguous_any(int64_t volume, BaseLayout& base, axis_vec_t& axis_order) except -1 nogil: if volume == 0 or base.strides == NULL: return True cdef int64_t stride = 1 diff --git a/cuda_core/cuda/core/experimental/_layout.pyx b/cuda_core/cuda/core/experimental/_layout.pyx index d116cd7018..2cac3d9c08 100644 --- a/cuda_core/cuda/core/experimental/_layout.pyx +++ b/cuda_core/cuda/core/experimental/_layout.pyx @@ -4,8 +4,7 @@ cimport cython -from libc.stdint cimport int64_t, intptr_t -from libcpp cimport vector +from libc.stdint cimport int64_t, uintptr_t from cpython.object cimport PyObject @@ -17,18 +16,123 @@ cdef extern from "Python.h": @cython.final cdef class StridedLayout: - - def __init__(StridedLayout self, object shape, object strides, int itemsize, bint strides_in_bytes=False): - self.init_from_tuple(shape, strides, itemsize, strides_in_bytes) + """ + A class describing the layout of a multi-dimensional tensor + with a shape, strides and itemsize. + + Parameters + ---------- + shape : tuple + A tuple of non-negative integers. + strides : tuple, optional + If provided, must be a tuple of integers of the same length as ``shape``. + Otherwise, the strides are assumed to be implicitly C-contiguous and the resulting + layout's :attr:`strides` will be None. + itemsize : int + The number of bytes per single element (dtype size). Must be a power of two. + divide_strides : bool, optional + If True, the provided :attr:`strides` will be divided by the :attr:`itemsize`. + + + See also :meth:`dense`. + + + Attributes + ---------- + itemsize : int + The number of bytes per single element (dtype size). Must be a power of two. + slice_offset : int + The offset (as a number of elements, not bytes) of the element at + index ``(0,) * ndim``. See also :attr:`slice_offset_in_bytes`. 
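+
+    A short construction example (using only the parameters described above):
+
+    .. highlight:: python
+    .. code-block:: python
+
+        # strides are given in element counts by default
+        layout = StridedLayout((5, 3, 7), (21, 7, 1), 4)
+        # with divide_strides=True, byte strides are divided by the itemsize
+        assert StridedLayout((5, 3, 7), (84, 28, 4), 4, divide_strides=True) == layout
+        # strides=None marks the layout implicitly C-contiguous
+        assert StridedLayout((5, 3, 7), None, 4).strides is None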
+ """ + + def __init__( + self : StridedLayout, + shape : tuple[int], + strides : tuple[int] | None, + itemsize : int, + divide_strides : bool = False + ) -> None: + self.init_from_tuple(shape, strides, itemsize, divide_strides) @classmethod - def dense(cls, object shape, int itemsize, object stride_order='C'): + def dense( + cls, + shape : tuple[int], + itemsize : int, + stride_order : str | tuple[int] = 'C' + ) -> StridedLayout: + """ + Creates a new StridedLayout instance with dense strides. + + Parameters + ---------- + shape : tuple + A tuple of non-negative integers. + itemsize : int + The number of bytes per single element of the tensor. + stride_order : str or tuple, optional + The order of the strides: + * 'C' (default) - the strides are computed in C-order (increasing from the right to the left) + * 'F' - the strides are computed in F-order (increasing from the left to the right) + * A tuple - it must be a permutation of ``tuple(range(len(shape)))``. + The last element of the tuple is the axis with stride 1. + + See also :attr:`stride_order`. + + + .. highlight:: python + .. code-block:: python + + assert StridedLayout.dense((5, 3, 7), 1, "C") == StridedLayout((5, 3, 7), (21, 7, 1), 1) + assert StridedLayout.dense((5, 3, 7), 1, "F") == StridedLayout((5, 3, 7), (1, 5, 15), 1) + assert StridedLayout.dense((5, 3, 7), 1, (2, 0, 1)) == StridedLayout((5, 3, 7), (3, 1, 15), 1) + + """ cdef StridedLayout new_layout = StridedLayout.__new__(cls) new_layout.init_dense_from_tuple(shape, itemsize, stride_order) return new_layout @classmethod - def dense_like(cls, StridedLayout other, object stride_order="K"): + def dense_like( + cls, + other : StridedLayout, + stride_order : str | tuple[int] = "K" + ) -> StridedLayout: + """ + Creates a StridedLayout with the same :attr:`shape` and :attr:`itemsize` as the other layout, + but with contiguous strides in the specified order and no slice offset. + + See also :attr:`is_dense`. + + Parameters + ---------- + other : StridedLayout + The StridedLayout to copy the :attr:`shape` and :attr:`itemsize` from. + stride_order : str or tuple, optional + The order of the strides: + * 'K' (default) - keeps the order of the strides as in the ``other`` layout. + * 'C' - the strides are computed in C-order (increasing from the right to the left) + * 'F' - the strides are computed in F-order (increasing from the left to the right) + * A tuple - it must be a permutation of ``tuple(range(len(shape)))``. + The last element of the tuple is the axis with stride 1. + + See also :attr:`stride_order`. + + + .. highlight:: python + .. 
code-block:: python + + layout = StridedLayout.dense((5, 3, 7), 1).permuted((2, 0, 1)) + assert layout == StridedLayout((7, 5, 3), (1, 21, 7), 1) + + # dense_like with the default "K" stride_order + # keeps the same order of strides as in the original layout + assert StridedLayout.dense_like(layout) == layout + # "C", "F" recompute the strides accordingly + assert StridedLayout.dense_like(layout, "C") == StridedLayout((7, 5, 3), (15, 3, 1), 1) + assert StridedLayout.dense_like(layout, "F") == StridedLayout((7, 5, 3), (1, 7, 35), 1) + """ cdef OrderFlag order_flag cdef axis_vec_t stride_order_vec @@ -61,7 +165,7 @@ cdef class StridedLayout: ) return new_layout - def __repr__(StridedLayout self): + def __repr__(self : StridedLayout) -> str: if self.slice_offset == 0: return ( f"StridedLayout(shape={self.shape}, strides={self.strides}, itemsize={self.itemsize})" @@ -71,87 +175,219 @@ cdef class StridedLayout: f"StridedLayout(shape={self.shape}, strides={self.strides}, itemsize={self.itemsize}, _slice_offset={self.slice_offset})" ) - def __eq__(StridedLayout self, StridedLayout other): + def __eq__(self : StridedLayout, other : StridedLayout) -> bool: return self.itemsize == other.itemsize and self.slice_offset == other.slice_offset and _base_layout_equal(self.base, other.base) @property - def ndim(StridedLayout self) -> int: + def ndim(self : StridedLayout): + """ + The number of dimensions (length of the shape tuple). + + :type: int + """ return self.base.ndim @property - def shape(StridedLayout self) -> tuple: + def shape(self : StridedLayout): + """ + Shape of the tensor. + + :type: tuple[int] + """ return self.get_shape_tuple() @property - def strides(StridedLayout self) -> tuple | None: + def strides(self : StridedLayout): + """ + Strides of the tensor (in **counts**, not bytes). + If StridedLayout was created with strides=None, the + returned value is None and layout is implicitly C-contiguous. + + :type: tuple[int] | None + """ return self.get_strides_tuple() @property - def strides_in_bytes(StridedLayout self) -> tuple | None: + def strides_in_bytes(self : StridedLayout): + """ + Strides of the tensor (in bytes). + + :type: tuple[int] | None + """ return self.get_strides_in_bytes_tuple() @property - def stride_order(StridedLayout self) -> tuple: + def stride_order(self : StridedLayout): + """ + A permutation of ``tuple(range(ndim))`` describing the + relative order of the strides. + + .. highlight:: python + .. code-block:: python + + # C-contiguous layout + assert StridedLayout.dense((5, 3, 7), 1).stride_order == (0, 1, 2) + # F-contiguous layout + assert StridedLayout.dense((5, 3, 7), 1, stride_order="F").stride_order == (2, 1, 0) + # Permuted layout + assert StridedLayout.dense((5, 3, 7), 1, stride_order=(2, 0, 1)).stride_order == (2, 0, 1) + + :type: tuple[int] + """ return self.get_stride_order_tuple() @property - def volume(StridedLayout self) -> int: + def volume(self : StridedLayout): + """ + The number of elements in the tensor, i.e. the product of the shape tuple. + + :type: int + """ return self.get_volume() @property - def is_unique(StridedLayout self) -> bool: + def is_unique(self : StridedLayout): + """ + If True, each element of a tensor with this layout is mapped to + a unique memory offset. + + All contiguous layouts are unique and so are layouts that can be created + by permuting, slicing, flattening, squeezing, repacking, or reshaping + a contiguous layout. + Conversely, broadcast layouts (layouts with a 0 stride + for some extent greater than 1) are not unique. 
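+
+        For example:
+
+        .. highlight:: python
+        .. code-block:: python
+
+            # a dense layout maps every index to a distinct offset
+            assert StridedLayout.dense((3, 4), 1).is_unique
+            # a broadcast layout (stride 0 over an extent > 1) repeats offsets
+            assert not StridedLayout((3, 4), (0, 1), 1).is_unique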
+ + For layouts resulting from manual stride manipulations + (such as with ``numpy.lib.stride_tricks``), the check + may inaccurately report False, as the exact uniqueness + check may be expensive. + + :type: bool + """ return self.get_is_unique() @property - def is_contiguous_c(StridedLayout self): + def is_contiguous_c(self : StridedLayout): """ True iff the layout is contiguous in C-order, i.e. the rightmost stride is 1 and each subsequent stride to the left is the product of the - next extent and the stride. - In C-contigious layout, the strides are non-negative, - increase from the right to the left and the mapping - from indices to memory offsets is 1 to 1. + extent and the stride to the right. + + .. highlight:: python + .. code-block:: python + + layout = StridedLayout.dense((2, 5, 3), 1, "C") + assert layout == StridedLayout((2, 5, 3), (15, 3, 1), 1) + assert layout.is_contiguous_c + + See also :attr:`is_contiguous_any`. + + :type: bool """ return self.get_is_contiguous_c() @property - def is_contiguous_f(StridedLayout self): + def is_contiguous_f(self : StridedLayout): """ True iff the layout is contiguous in F-order, i.e. the leftmost stride is 1 and each subsequent stride to the right is the product of the - next stride and extent. - In F-contigious layout, the strides are non-negative, - increase from the left to the right and the mapping - from indices to memory offsets is 1 to 1. + stride and extent to the left. + + .. highlight:: python + .. code-block:: python + + layout = StridedLayout.dense((2, 5, 3), 1, "F") + assert layout == StridedLayout((2, 5, 3), (1, 2, 10), 1) + assert layout.is_contiguous_f + + See also :attr:`is_contiguous_any`. + + :type: bool """ return self.get_is_contiguous_f() @property - def is_dense(StridedLayout self): + def is_contiguous_any(self : StridedLayout): """ True iff the layout is contiguous in some axis order, i.e. there exists a permutation of axes such that the layout is C-contiguous. - In dense layout, the strides are non-negative and the mapping - from indices to memory offsets is 1 to 1. + + In a contiguous layout, the strides are non-negative and + the mapping of elements to the memory offset range + ``[min_offset, max_offset]`` is 1-to-1. + + .. highlight:: python + .. code-block:: python + + # dense defaults to C-contiguous + layout = StridedLayout.dense((5, 3, 7), 1) + assert layout.is_contiguous_c and not layout.is_contiguous_f + assert layout.is_contiguous_any + + # reversing the order of axes gives F-contiguous layout + permuted = layout.permuted((2, 1, 0)) + assert not permuted.is_contiguous_c and permuted.is_contiguous_f + assert permuted.is_contiguous_any + + # neither C- nor F-order but still contiguous + permuted = layout.permuted((2, 0, 1)) + assert not permuted.is_contiguous_c and not permuted.is_contiguous_f + assert permuted.is_contiguous_any + + # slicing the right-most extent creates a gap in the + # offset_bounds range that is not reachable with any + # element in the sliced layout + sliced = layout[:, :, :-1] + assert not sliced.is_contiguous_c and not sliced.is_contiguous_f + assert not sliced.is_contiguous_any + + :type: bool + """ + return self.get_is_contiguous_any() + + @property + def is_dense(self : StridedLayout): + """ + A dense layout is contiguous (:attr:`is_contiguous_any` is True) + and has no slice offset (:attr:`slice_offset_in_bytes` is 0). + + In a dense layout, elements are mapped 1-to-1 to the ``[0, volume - 1]`` + memory offset range. 
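+
+        For example:
+
+        .. highlight:: python
+        .. code-block:: python
+
+            layout = StridedLayout.dense((5, 3, 7), 1)
+            assert layout.is_dense
+            # slicing the outermost axis keeps the strides contiguous,
+            # but introduces a non-zero slice offset, so the result
+            # is no longer dense
+            assert not layout[1:, :, :].is_dense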
+ + :type: bool """ return self.get_is_dense() @property - def offset_bounds(StridedLayout self): - """ - A tuple of ``(min_offset, max_offset)`` representing the - minimum and maximum offsets (as a number of elements, not bytes) - that the layout can map to. - I.e. there exist two ndim-tuples ``idx_min`` and ``idx_max``, - where ``0 <= idx[i] < shape[i]`` for ``0 <= i < ndim``, - such that: - ``min_offset = sum(idx_min[i] * strides[i] for i in range(ndim))`` - ``max_offset = sum(idx_max[i] * strides[i] for i in range(ndim))``, - and all other valid ndim-indices are mapped to offsets - in the range ``[min_offset, max_offset]``. + def offset_bounds(self : StridedLayout): + """ + The memory offset range ``[min_offset, max_offset]`` (in element counts, not bytes) + that elements of a tensor with this layout are mapped to. + + If the layout is empty (i.e. ``volume == 0``), the returned tuple is ``(0, -1)``. + Otherwise, ``min_offset <= max_offset`` and all elements of the tensor with + this layout are mapped within the ``[min_offset, max_offset]`` range. + + .. highlight:: python + .. code-block:: python + + # Possible implementation of the offset_bounds + def offset_bounds(layout : StridedLayout): + if layout.volume == 0: + return 0, -1 + ndim = layout.ndim + shape = layout.shape + strides = layout.strides + idx_min = [shape[i] - 1 if strides[i] < 0 else 0 for i in range(ndim)] + idx_max = [shape[i] - 1 if strides[i] > 0 else 0 for i in range(ndim)] + min_offset = sum(strides[i] * idx_min[i] for i in range(ndim)) + layout.slice_offset + max_offset = sum(strides[i] * idx_max[i] for i in range(ndim)) + layout.slice_offset + return min_offset, max_offset + + :type: tuple[int, int] """ cdef stride_t min_offset = 0 cdef stride_t max_offset = 0 @@ -159,9 +395,11 @@ cdef class StridedLayout: return min_offset, max_offset @property - def min_offset(StridedLayout self): + def min_offset(self : StridedLayout): """ - See ``offset_bounds`` for details. + See :attr:`offset_bounds` for details. + + :type: int """ cdef stride_t min_offset = 0 cdef stride_t max_offset = 0 @@ -169,9 +407,11 @@ cdef class StridedLayout: return min_offset @property - def max_offset(StridedLayout self): + def max_offset(self : StridedLayout): """ - See ``offset_bounds`` for details. + See :attr:`offset_bounds` for details. + + :type: int """ cdef stride_t min_offset = 0 cdef stride_t max_offset = 0 @@ -179,47 +419,91 @@ cdef class StridedLayout: return max_offset @property - def slice_offset_in_bytes(StridedLayout self): + def slice_offset_in_bytes(self : StridedLayout): """ - The memory offset (as a number of bytes) - of the element at index ``(0,) * ndim``. - The only way for the index 0 to be mapped to - non-zero offset in memory is if the layout - was sliced. + The memory offset (as a number of bytes) of the element at index ``(0,) * ndim``. + Equal to :attr:`itemsize` ``*`` :attr:`slice_offset`. + + .. note:: + The only way for the index ``(0,) * ndim`` to be mapped to a non-zero offset + is slicing with :meth:`sliced` method (or ``[]`` operator). + + :type: int """ return self.get_slice_offset_in_bytes() - def required_size_in_bytes(StridedLayout self): + def required_size_in_bytes(self : StridedLayout) -> int: """ - The memory allocation size in bytes needed for all - elements of the ndim-tensor to be mapped to - offsets within the allocated memory range. - I.e. 
for any ndim-tuple ``idx``, such that - ``0 <= idx[i] < shape[i]`` for ``0 <= i < ndim``, - the ``sum(idx[i] * strides[i] for i in range(ndim))`` - is in the range ``[0, required_size_in_bytes - 1]``. - The function raises an error if the layout maps any element - to a negative memory offset (i.e. layout.offset_bounds[0] < 0). + The memory allocation size (in bytes) needed so that + all elements of a tensor with this layout can be mapped + within the allocated memory range. + + The function raises an error if ``min_offset < 0``. + Otherwise, the returned value is equal to + ``(max_offset + 1) * itemsize``. + + .. hint:: + For dense layouts, the function always succeeds and the + ``(max_offset + 1) * itemsize`` is equal to the ``volume * itemsize``. + + .. highlight:: python + .. code-block:: python + + # Allocating memory on a device to copy a host tensor + def device_tensor_like(a : numpy.ndarray, device : ccx.Device) -> StridedMemoryView: + a_view = StridedMemoryView(a, -1) + # get the original layout of ``a`` and convert it to a dense layout + # to avoid overallocating memory (e.g. if the ``a`` was sliced) + layout = a_view.layout.to_dense() + # get the required size in bytes to fit the tensor + required_size = layout.required_size_in_bytes() + # allocate the memory on the device + device.set_current() + mem = device.allocate(required_size) + # create a view on the newly allocated device memory + b_view = StridedMemoryView.from_buffer(mem, layout, a_view.dtype) + return b_view """ return self.get_required_size_in_bytes() - def flattened_axis_mask(StridedLayout self): + def flattened_axis_mask(self : StridedLayout) -> axes_mask_t: """ - A mask describing which axes can be merged - together preserving the index to memory offset mapping - (see more details in ``flattened`` method documentation). - The only supported operation is the logical ``&`` - between masks coming from the layouts with equal ndim. - If such a mask is passed to the - ``flattened`` method, only the axes that are mergable - for all the layouts will be flattened. + A mask describing which axes of this layout are mergeable + using the :meth:`flattened` method. """ return self.get_flattened_axis_mask() - def to_dense(StridedLayout self, object stride_order="K"): + def to_dense(self : StridedLayout, object stride_order="K") -> StridedLayout: + """ + Returns a dense layout with the same shape and itemsize, + but with dense strides in the specified order. + + See :meth:`dense_like` method documentation for details. + """ return StridedLayout.dense_like(self, stride_order) - def reshaped(StridedLayout self, object shape): + def reshaped(self : StridedLayout, shape : tuple[int]) -> StridedLayout: + """ + Returns a layout with the new shape, if the new shape is compatible + with the current layout. + + The new shape is compatible if: + * the new and old shapes have the same volume + * the old strides can be split or flattened to match the new shape, + assuming indices are iterated in C-order + + A single extent in the ``shape`` tuple can be set to -1 to indicate + it should be inferred from the old volume and the other extents. + + .. highlight:: python + .. 
code-block:: python + + layout = StridedLayout.dense((5, 3, 4), 1) + assert layout.reshaped((20, 3)) == StridedLayout.dense((20, 3), 1) + assert layout.reshaped((4, -1)) == StridedLayout.dense((4, 15), 1) + assert layout.permuted((2, 0, 1)).reshaped((4, 15,)) == StridedLayout((4, 15), (1, 4), 1) + # layout.permuted((2, 0, 1)).reshaped((20, 3)) -> error + """ cdef StridedLayout new_layout = StridedLayout.__new__(StridedLayout) cdef BaseLayout new_shape init_base_layout(new_shape, len(shape)) @@ -228,28 +512,59 @@ cdef class StridedLayout: self.reshape_into(new_layout, new_shape) return new_layout - def permuted(StridedLayout self, object axis_order): + def permuted(self : StridedLayout, axis_order : tuple[int]) -> StridedLayout: + """ + Returns a new layout where the shape and strides tuples are permuted + according to the specified permutation of axes. + """ cdef StridedLayout new_layout = StridedLayout.__new__(StridedLayout) cdef axis_vec_t axis_order_vec _tuple2axis_vec(axis_order_vec, axis_order) self.permute_into(new_layout, axis_order_vec) return new_layout - def flattened(StridedLayout self, start_axis=0, end_axis=-1, mask=None): - """ - Merges consecutive axes into a single axis (where the new extent - is the product of merged extents) if the mapping of indices to - memory offsets is preserved (assuming the indices are iterated - in C-order, i.e. the rightmost axis is incremented first). - E.g. for ``StridedLayout((2, 2), (4, 2), 1)`` - and the C-ordered indices ``[(0, 0), (0, 1), (1, 0), (1, 1)]`` would - be mapped to offsets ``[0, 2, 4, 6]``, same as for the - flattened layout ``StridedLayout((4,), (2,), 1)`` - and the indices ``[0, 1, 2, 3]``. + def flattened(self : StridedLayout, start_axis : int = 0, end_axis : int = -1, mask : int | None = None) -> StridedLayout: + """ + Merges consecutive extents into a single extent (equal to the product of merged extents) + if the corresponding strides can be replaced with a single stride + (assuming indices are iterated in C-order, i.e. the rightmost + axis is incremented first). + + .. highlight:: python + .. code-block:: python + + # the two extents can be merged into a single extent + # because layout.strides[0] == layout.strides[1] * layout.shape[1] + layout = StridedLayout((3, 2), (2, 1), 1) + assert layout.flattened() == StridedLayout((6,), (1,), 1) + + # the two extents cannot be merged into a single extent + # because layout.strides[0] != layout.strides[1] * layout.shape[1] + layout = StridedLayout((3, 2), (1, 3), 1) + assert layout.flattened() == layout + If ``start_axis`` and ``end_axis`` are provided, only the axes in the inclusive range ``[start_axis, end_axis]`` are considered for flattening. - Alternatively, a mask specifying which axes to consider can be provided - (see ``flattened_axis_mask`` method documentation for details). + + Alternatively, a mask specifying which axes to consider can be provided. + A mask of mergeable extents can be obtained using the :meth:`flattened_axis_mask` method. + Masks for layouts with the same number of dimensions can be combined + using the logical ``&`` (bitwise AND) operator. + + .. highlight:: python + .. code-block:: python + + layout = StridedLayout.dense((4, 5, 3), 4) + layout2 = StridedLayout((4, 5, 3), (1, 12, 4), 4) + # Even though the two layouts have the same shape initially, + # their shapes differ after flattening. 
+ assert layout.flattened() == StridedLayout((60,), (1,), 4) + assert layout2.flattened() == StridedLayout((4, 15), (1, 4), 4) + # With the mask, only extents that are mergeable in both layouts are flattened + # and the resulting shape is the same for both layouts. + mask = layout.flattened_axis_mask() & layout2.flattened_axis_mask() + assert layout.flattened(mask=mask) == StridedLayout((4, 15), (15, 1), 4) + assert layout2.flattened(mask=mask) == StridedLayout((4, 15), (1, 4), 4) """ cdef StridedLayout new_layout = StridedLayout.__new__(StridedLayout) cdef axes_mask_t axis_mask @@ -260,15 +575,23 @@ cdef class StridedLayout: self.flatten_into(new_layout, axis_mask) return new_layout - def flattened_axis_mask(StridedLayout self): - return self.get_flattened_axis_mask() - - def squeezed(StridedLayout self): + def squeezed(self : StridedLayout) -> StridedLayout: + """ + Returns a new layout where all the singleton dimensions (extents equal to 1) + are removed. Additionally, if the layout volume is 0, + the returned layout will be reduced to a 1-dim layout + with shape (0,) and strides (0,). + """ cdef StridedLayout new_layout = StridedLayout.__new__(StridedLayout) self.squeeze_into(new_layout) return new_layout - def unsqueezed(StridedLayout self, object axis): + def unsqueezed(self : StridedLayout, axis : int | tuple[int]) -> StridedLayout: + """ + Returns a new layout where the specified axis or axes are added as singleton extents. + The ``axis`` can be either a single integer in range ``[0, ndim]`` + or a tuple of unique integers in range ``[0, ndim + len(axis) - 1]``. + """ cdef axis_vec_t axis_vec if isinstance(axis, int): axis_vec.push_back(axis) @@ -280,7 +603,19 @@ cdef class StridedLayout: self.unsqueeze_into(new_layout, axis_vec) return new_layout - def broadcast_to(StridedLayout self, object shape): + def broadcast_to(self : StridedLayout, shape : tuple[int]) -> StridedLayout: + """ + Returns a layout with the new shape, if the old shape can be + broadcast to the new one. + + The shapes are compatible if: + * the new shape has the same or greater number of dimensions + * starting from the right, each extent in the old shape must be 1 or + equal to the corresponding extent in the new shape. + + Strides of the added or modified extents are set to 0, the remaining ones are unchanged. + If the shapes are not compatible, a ValueError is raised. + """ cdef StridedLayout new_layout = StridedLayout.__new__(StridedLayout) cdef BaseLayout new_shape cdef int new_ndim = len(shape) @@ -290,44 +625,110 @@ cdef class StridedLayout: self.broadcast_into(new_layout, new_shape) return new_layout - def packed(StridedLayout self, int itemsize, intptr_t data_ptr=0, int axis=-1, bint keep_dim=True): - if itemsize == self.itemsize: - return self - cdef StridedLayout new_layout = StridedLayout.__new__(StridedLayout) - self.pack_into(new_layout, itemsize, data_ptr, keep_dim, axis) - return new_layout + def repacked(self : StridedLayout, itemsize : int, data_ptr : uintptr_t = 0, axis : int = -1, keep_dim : bool = True) -> StridedLayout: + """ + Converts the layout to match the specified itemsize. + If ``new_itemsize < itemsize``, each element of the tensor is **unpacked** into multiple elements, + i.e. the extent at ``axis`` increases by the factor ``itemsize // new_itemsize``. + If ``new_itemsize > itemsize``, the consecutive elements in the tensor are **packed** into a single element, + i.e. the extent at ``axis`` decreases by the factor ``new_itemsize // itemsize``. 
+ In either case, the ``volume * itemsize`` of the layout remains the same. + + The conversion is subject to the following constraints: + * The old and new itemsizes must be powers of two. + * The extent at ``axis`` must be a positive integer. + * The stride at ``axis`` must be 1. + + Moreover, if the ``new_itemsize > itemsize``: + * The extent at ``axis`` must be divisible by ``new_itemsize // itemsize``. + * All other strides must be divisible by ``new_itemsize // itemsize``. + * The ``slice_offset`` must be divisible by ``new_itemsize // itemsize``. + * If ``data_ptr`` is provided, it must be aligned to the new itemsize. + + The maximum itemsize that satisfies all the constraints + can be obtained using the :meth:`max_compatible_itemsize` method. + + If the ``keep_dim`` is False and the extent at ``axis`` would be reduced to 1, + it is omitted from the returned layout. + + .. highlight:: python + .. code-block:: python + + # Repacking the layout with itemsize = 4 bytes as 2, 8, and 16 sized layouts. + layout = StridedLayout.dense((5, 4), 4) + assert layout.repacked(2) == StridedLayout.dense((5, 8), 2) + assert layout.repacked(8) == StridedLayout.dense((5, 2), 8) + assert layout.repacked(16) == StridedLayout.dense((5, 1), 16) + assert layout.repacked(16, keep_dim=False) == StridedLayout.dense((5,), 16) + + + .. highlight:: python + .. code-block:: python + + # Viewing (5, 6) float array as (5, 3) complex64 array. + a = numpy.ones((5, 6), dtype=numpy.float32) + float_view = StridedMemoryView(a, -1) + layout = float_view.layout + assert layout.shape == (5, 6) + assert layout.itemsize == 4 + complex_view = float_view.view(layout.repacked(8), numpy.complex64) + assert complex_view.layout.shape == (5, 3) + assert complex_view.layout.itemsize == 8 + b = numpy.from_dlpack(complex_view) + assert b.shape == (5, 3) + """ - def unpacked(StridedLayout self, int itemsize, int axis=-1): if itemsize == self.itemsize: return self cdef StridedLayout new_layout = StridedLayout.__new__(StridedLayout) - self.unpack_into(new_layout, itemsize, axis) + if itemsize > self.itemsize: + self.pack_into(new_layout, itemsize, data_ptr, keep_dim, axis) + else: + self.unpack_into(new_layout, itemsize, axis) return new_layout - def max_compatible_itemsize(StridedLayout self, int max_itemsize=16, intptr_t data_ptr=0, int axis=-1): + def max_compatible_itemsize(self : StridedLayout, max_itemsize : int = 16, data_ptr : uintptr_t = 0, axis : int = -1) -> int: + """ + Returns the maximum itemsize (but no greater than ``max_itemsize``) that can be used + with the :meth:`repacked` method for the current layout. + """ return self.get_max_compatible_itemsize(max_itemsize, data_ptr, axis) - def sliced(StridedLayout self, object slices): + def sliced(self : StridedLayout, slices : int | slice | tuple[int | slice]) -> StridedLayout: + """ + Returns a sliced layout. + The ``slices`` parameter can be a single integer, a single :py:class:`slice` object + or a tuple of integers/slices. + + .. hint:: + For convenience, instead of calling this method directly, please rely + on the :py:meth:`~object.__getitem__` operator (i.e. bracket syntax), e.g.: + ``layout[:, start:end:step]``. + + .. note:: + Slicing is purely a layout transformation and does not involve + any data access. 
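+
+        For example (an illustrative sketch; slicing follows the usual
+        Python/numpy semantics, scaling strides by the slice step and
+        accumulating the slice start in the slice offset):
+
+        .. highlight:: python
+        .. code-block:: python
+
+            layout = StridedLayout.dense((4, 6), 2)  # strides (6, 1)
+            sliced = layout[1:, ::2]
+            assert sliced.shape == (3, 3)
+            assert sliced.strides == (6, 2)
+            assert sliced.slice_offset == 6
+            assert sliced.slice_offset_in_bytes == 12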
+ + """ if not isinstance(slices, tuple): slices = (slices,) cdef StridedLayout new_layout = StridedLayout.__new__(StridedLayout) self.slice_into(new_layout, slices) return new_layout - def __getitem__(StridedLayout self, object slices): + def __getitem__(self : StridedLayout, slices : int | slice | tuple[int | slice]) -> StridedLayout: return self.sliced(slices) cdef axes_mask_t get_flattened_axis_mask(StridedLayout self) except? -1 nogil: return flattened_strides_in_c_index_order_mask(self.base) - cdef int get_max_compatible_itemsize(StridedLayout self, int max_itemsize, intptr_t data_ptr, int axis=-1) except -1 nogil: + cdef int get_max_compatible_itemsize(StridedLayout self, int max_itemsize, uintptr_t data_ptr, int axis=-1) except -1 nogil: return max_compatible_itemsize(self.base, self.slice_offset, self.itemsize, max_itemsize, data_ptr, axis) cdef int reshape_into(StridedLayout self, StridedLayout out_layout, BaseLayout& new_shape) except -1 nogil: cdef int64_t old_volume = self.get_volume() - validate_reshaped_shape(new_shape, old_volume) - cdef int ndim = new_shape.ndim + validate_reshaped_shape(new_shape, old_volume) _zero_strides(new_shape) cdef BaseLayout flattened @@ -439,7 +840,7 @@ cdef class StridedLayout: _swap_layout(out_layout.base, broadcast) return 0 - cdef int pack_into(StridedLayout self, StridedLayout out_layout, int itemsize, intptr_t data_ptr, bint keep_dim, int axis=-1) except -1 nogil: + cdef int pack_into(StridedLayout self, StridedLayout out_layout, int itemsize, uintptr_t data_ptr, bint keep_dim, int axis=-1) except -1 nogil: cdef BaseLayout packed cdef stride_t new_slice_offset = 0 @@ -803,7 +1204,7 @@ cdef inline int64_t gcd(int64_t a, int64_t b) except? -1 nogil: return a -cdef inline int pack_extents(BaseLayout& out_layout, stride_t& out_slice_offset, BaseLayout& in_layout, stride_t slice_offset, int itemsize, int new_itemsize, intptr_t data_ptr, bint keep_dim, int axis) except -1 nogil: +cdef inline int pack_extents(BaseLayout& out_layout, stride_t& out_slice_offset, BaseLayout& in_layout, stride_t slice_offset, int itemsize, int new_itemsize, uintptr_t data_ptr, bint keep_dim, int axis) except -1 nogil: cdef int ndim = in_layout.ndim if new_itemsize <= 0 or new_itemsize & (new_itemsize - 1): raise ValueError(f"new itemsize must be a power of two, got {new_itemsize}.") @@ -890,7 +1291,7 @@ cdef inline int unpack_extents(BaseLayout &out_layout, BaseLayout &in_layout, in return vec_size -cdef inline int max_compatible_itemsize(BaseLayout& layout, stride_t slice_offset, int itemsize, int max_itemsize, intptr_t data_ptr, int axis) except? -1 nogil: +cdef inline int max_compatible_itemsize(BaseLayout& layout, stride_t slice_offset, int itemsize, int max_itemsize, uintptr_t data_ptr, int axis) except? 
-1 nogil: cdef int ndim = layout.ndim if max_itemsize <= 0 or max_itemsize & (max_itemsize - 1): raise ValueError(f"max_itemsize must be a power of two, got {max_itemsize}.") @@ -898,11 +1299,13 @@ cdef inline int max_compatible_itemsize(BaseLayout& layout, stride_t slice_offse raise ValueError(f"itemsize must be a power of two, got {itemsize}.") if not _normalize_axis(axis, ndim): raise ValueError(f"Invalid axis: {axis} out of range for {ndim}D tensor") + if max_itemsize < itemsize: + raise ValueError(f"max_itemsize ({max_itemsize}) cannot be less than itemsize ({itemsize}).") max_itemsize = gcd(max_itemsize, _c_abs(data_ptr)) cdef extent_t* shape = layout.shape cdef stride_t* strides = get_strides_ptr(layout) if ndim < 1 or strides[axis] != 1 or shape[axis] == 0: - return min(max_itemsize, itemsize) + return itemsize max_itemsize = gcd(max_itemsize, _overflow_checked_mul(slice_offset, itemsize)) max_itemsize = gcd(max_itemsize, _overflow_checked_mul(shape[axis], itemsize)) for i in range(ndim): diff --git a/cuda_core/cuda/core/experimental/_memory/_buffer.pyx b/cuda_core/cuda/core/experimental/_memory/_buffer.pyx index ac611d5cdd..c596dbcc59 100644 --- a/cuda_core/cuda/core/experimental/_memory/_buffer.pyx +++ b/cuda_core/cuda/core/experimental/_memory/_buffer.pyx @@ -203,7 +203,7 @@ cdef class Buffer: if not isinstance(max_version, tuple) or len(max_version) != 2: raise BufferError(f"Expected max_version tuple[int, int], got {max_version}") versioned = max_version >= (1, 0) - capsule = make_py_capsule(self, versioned) + capsule = make_py_capsule(self, versioned, int(self.handle)) return capsule def __dlpack_device__(self) -> tuple[int, int]: diff --git a/cuda_core/cuda/core/experimental/_memoryview.pyx b/cuda_core/cuda/core/experimental/_memoryview.pyx index 17ecddfb8e..4cc24a32bb 100644 --- a/cuda_core/cuda/core/experimental/_memoryview.pyx +++ b/cuda_core/cuda/core/experimental/_memoryview.pyx @@ -3,8 +3,9 @@ # SPDX-License-Identifier: Apache-2.0 from ._dlpack cimport * +from libc.stdint cimport uintptr_t +from cuda.core.experimental._layout cimport StridedLayout -cimport cython import functools from typing import Optional @@ -12,7 +13,7 @@ import numpy from cuda.core.experimental._utils.cuda_utils import handle_return, driver -from cuda.core.experimental._layout cimport StridedLayout + from cuda.core.experimental._dlpack import make_py_capsule from cuda.core.experimental._memory import Buffer @@ -20,14 +21,16 @@ from cuda.core.experimental._memory import Buffer cdef class StridedMemoryView: - """A dataclass holding metadata of a strided dense array/tensor. + """A class holding metadata of a strided dense array/tensor. - A :obj:`StridedMemoryView` instance can be created in two ways: + A :obj:`StridedMemoryView` instance can be created in three ways: 1. Using the :obj:`args_viewable_as_strided_memory` decorator (recommended) - 2. Explicit construction, see below + 2. Explicit construction relying on DLPack or CUDA Array Interface, see below. + 3. From :obj:`~_memory.Buffer` and a :obj:`StridedLayout` (see :meth:`from_buffer` classmethod) - This object supports both DLPack (up to v1.0) and CUDA Array Interface + ``StridedMemoryView(obj, stream_ptr)`` can be used to create a view from + objects supporting either DLPack (up to v1.0) or CUDA Array Interface (CAI) v3. When wrapping an arbitrary object it will try the DLPack protocol first, then the CAI protocol. A :obj:`BufferError` is raised if neither is supported. 
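+
+    For example, wrapping a DLPack-capable object (an illustrative
+    sketch; ``numpy.ndarray`` supports the DLPack protocol):
+
+    .. highlight:: python
+    .. code-block:: python
+
+        import numpy
+
+        a = numpy.arange(6).reshape(2, 3)
+        view = StridedMemoryView(a, -1)  # -1: no stream ordering established
+        assert view.shape == (2, 3)
+        assert view.device_id == -1  # CPU tensor
+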
@@ -45,16 +48,26 @@
     consistent with the CAI's semantics. For DLPack, ``stream=-1`` will be
     internally passed to ``obj.__dlpack__()`` instead.
 
-    Attributes
+    Parameters
     ----------
+    obj : Any
+        Any object that supports either DLPack (up to v1.0) or CUDA Array
+        Interface (v3).
+    stream_ptr: int
+        The pointer address (as Python `int`) to the **consumer** stream.
+        Stream ordering will be properly established unless ``-1`` is passed.
+
+
+    .. note::
+        The StridedMemoryView can be exported with DLPack, in which case
+        no synchronization is performed. It is the user's responsibility to ensure
+        the data in ``exporting_obj`` is properly synchronized when consuming the view.
+
+
+    Attributes
+    ----------
     ptr : int
         Pointer to the tensor buffer (as a Python `int`).
-    shape : tuple
-        Shape of the tensor.
-    strides : Optional[tuple]
-        Strides of the tensor (in **counts**, not bytes).
-    dtype: numpy.dtype
-        Data type of the tensor.
     device_id : int
         The device ID for where the tensor is located. It is -1 for CPU tensors
         (meaning those only accessible from the host).
@@ -64,18 +77,12 @@
         Whether the tensor data can be modified in place.
     exporting_obj : Any
         A reference to the original tensor object that is being viewed.
+        If the view is created with :meth:`from_buffer`,
+        it will be the Buffer instance passed to the method.
 
-    Parameters
-    ----------
-    obj : Any
-        Any objects that supports either DLPack (up to v1.0) or CUDA Array
-        Interface (v3).
-    stream_ptr: int
-        The pointer address (as Python `int`) to the **consumer** stream.
-        Stream ordering will be properly established unless ``-1`` is passed.
-
     """
     cdef readonly:
-        intptr_t ptr
+        uintptr_t ptr
         int device_id
         bint is_device_accessible
         bint readonly
@@ -131,12 +138,54 @@ cdef class StridedMemoryView:
             dlm_tensor.deleter(dlm_tensor)
 
     @classmethod
-    def from_buffer(cls, object buffer, StridedLayout layout, object dtype=None, bint is_readonly=False):
+    def from_buffer(
+        cls, buffer : Buffer, layout : StridedLayout,
+        dtype : numpy.dtype | None = None,
+        is_readonly : bool = False
+    ) -> StridedMemoryView:
+        """
+        Creates a :obj:`StridedMemoryView` instance from a :obj:`~_memory.Buffer` and a :obj:`StridedLayout`.
+        The Buffer can be either an allocation coming from a :obj:`MemoryResource` or an external allocation
+        wrapped in a :obj:`~_memory.Buffer` object with ``Buffer.from_handle(ptr, size, owner=...)``.
+
+        .. hint::
+            When allocating the memory for a given layout, the required allocation size
+            can be obtained with the :meth:`StridedLayout.required_size_in_bytes` method.
+            It is best to use the :meth:`StridedLayout.to_dense` method
+            first to make sure the layout is contiguous, to avoid overallocating memory
+            for layouts with gaps.
+
+        .. caution::
+            When creating a :obj:`StridedMemoryView` from a :obj:`~_memory.Buffer`,
+            no synchronization is performed. It is the user's responsibility to ensure
+            the data in ``buffer`` is properly synchronized when consuming the view.
+
+        Parameters
+        ----------
+        buffer : :obj:`~_memory.Buffer`
+            The buffer to create the view from.
+        layout : :obj:`StridedLayout`
+            The layout describing the shape, strides and itemsize of the elements in
+            the buffer.
+        dtype : :obj:`numpy.dtype`, optional
+            Optional dtype. The dtype is required for exporting the view with DLPack.
+            If specified, the dtype's itemsize must match the layout's itemsize.
+            To view the buffer with a different itemsize, please use :meth:`StridedLayout.repacked`
+            first to transform the layout to the desired itemsize.
+        is_readonly : bool, optional
+            Whether to mark the view as readonly. The flag will be forwarded to the DLPack capsule.
+        """
         cdef StridedMemoryView view = StridedMemoryView.__new__(cls)
         view_buffer_strided(view, buffer, layout, dtype, is_readonly)
         return view
 
-    def view(self, StridedLayout layout=None, object dtype=None, object is_readonly=None):
+    def view(
+        self, layout : StridedLayout | None = None, dtype : numpy.dtype | None = None
+    ) -> StridedMemoryView:
+        """
+        Creates a new view with adjusted layout and dtype.
+        Same as calling :meth:`from_buffer` with the current buffer, layout and dtype.
+        """
         cdef StridedMemoryView view = StridedMemoryView.__new__(self.__class__)
         if layout is None and dtype is None:
             return self
@@ -144,25 +193,36 @@
         layout = self.get_layout()
         if dtype is None:
             dtype = self.get_dtype()
-        if is_readonly is None:
-            is_readonly = self.readonly
-        view_buffer_strided(view, self.get_buffer(), layout, dtype, is_readonly)
+        view_buffer_strided(view, self.get_buffer(), layout, dtype, self.readonly)
         return view
 
     @property
     def layout(self) -> StridedLayout:
+        """
+        The layout of the tensor. For a StridedMemoryView created from DLPack or CAI,
+        the layout is inferred from the tensor object's metadata.
+        """
         return self.get_layout()
 
     @property
     def shape(self) -> tuple[int]:
+        """
+        Shape of the tensor.
+        """
        return self.get_layout().get_shape_tuple()
 
     @property
     def strides(self) -> Optional[tuple[int]]:
+        """
+        Strides of the tensor (in **counts**, not bytes).
+        """
        return self.get_layout().get_strides_tuple()
 
     @property
     def dtype(self) -> Optional[numpy.dtype]:
+        """
+        Data type of the tensor.
+        """
        return self.get_dtype()
 
     def __repr__(self):
@@ -199,13 +259,13 @@ cdef class StridedMemoryView:
         cdef object dtype = self.get_dtype()
         if dtype is None:
             raise ValueError(
-                f"Cannot export the StridedMemoryView without a dtype. "
-                f"You can create a dtyped view calling view(dtype=...) method."
+                "Cannot export the StridedMemoryView without a dtype. "
+                "You can create a dtyped view calling view(dtype=...) method."
             )
         capsule = make_py_capsule(
             self.get_buffer(),
-            self.ptr,
             versioned,
+            self.ptr,
             self.get_layout(),
             _numpy2dlpack_dtype[dtype],
         )
@@ -221,7 +281,7 @@
         elif self.metadata is not None:
             self._layout = layout_from_cai(self.metadata)
         else:
-            self._layout = StridedLayout.__new__(StridedLayout)
+            raise ValueError("Cannot infer layout from the exporting object")
         return self._layout
 
     cdef inline object get_buffer(self):
@@ -234,7 +294,7 @@
         if isinstance(self.exporting_obj, Buffer):
             self._buffer = self.exporting_obj
         else:
-            self._buffer = Buffer.from_handle(self.ptr, 0, owner=self.exporting_obj)
+            self._buffer = Buffer.from_handle(self.ptr, 0, owner=self.exporting_obj)
         return self._buffer
 
     cdef inline object get_dtype(self):
@@ -246,6 +306,7 @@
             self._dtype = _typestr2dtype[self.metadata["typestr"]]
         return self._dtype
 
+
 cdef str get_simple_repr(obj):
     # TODO: better handling in np.dtype objects
     cdef object obj_class
@@ -353,7 +414,7 @@ cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view=None):
     cdef StridedMemoryView buf = StridedMemoryView() if view is None else view
     buf.dl_tensor = dl_tensor
     buf.metadata = capsule
-    buf.ptr = <intptr_t>(dl_tensor.data)
+    buf.ptr = <uintptr_t>(dl_tensor.data)
     buf.device_id = device_id
     buf.is_device_accessible = is_device_accessible
     buf.readonly = is_readonly
@@ -463,13 +524,13 @@ cpdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None):
             driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
             buf.ptr))
 
-    cdef intptr_t producer_s, consumer_s
+    cdef uintptr_t producer_s, consumer_s
     stream_ptr = int(stream_ptr)
     if stream_ptr != -1:
         stream = cai_data.get("stream")
         if stream is not None:
-            producer_s = <intptr_t>(stream)
-            consumer_s = <intptr_t>(stream_ptr)
+            producer_s = <uintptr_t>(stream)
+            consumer_s = <uintptr_t>(stream_ptr)
             assert producer_s > 0
             # establish stream order
             if producer_s != consumer_s:
@@ -538,11 +599,11 @@ cdef StridedLayout layout_from_cai(object metadata):
     cdef object shape = metadata["shape"]
     cdef object strides = metadata.get("strides")
     cdef int itemsize = _typestr2itemsize[metadata["typestr"]]
-    layout.init_from_tuple(shape, strides, itemsize, strides is not None)
+    layout.init_from_tuple(shape, strides, itemsize, True)
     return layout
 
 
-cdef inline intptr_t _get_data_ptr(object buffer, StridedLayout layout) except? 0:
+cdef inline uintptr_t _get_data_ptr(object buffer, StridedLayout layout) except? 0:
     cdef bint is_allocated = buffer.owner is None
     if is_allocated:
         if buffer.memory_resource is None:
@@ -563,7 +624,7 @@ cdef inline intptr_t _get_data_ptr(object buffer, StridedLayout layout) except?
             f"Expected at least {layout.get_required_size_in_bytes()} bytes, "
             f"got {buffer.size} bytes."
         )
-    return <intptr_t>(buffer.handle) + layout.get_slice_offset_in_bytes()
+    return <uintptr_t>(buffer.handle) + layout.get_slice_offset_in_bytes()
 
 
 cdef inline int view_buffer_strided(
@@ -575,10 +636,12 @@
     view,
    object buffer,
    StridedLayout layout,
    object dtype,
    object is_readonly,
 ) except -1:
     if dtype is not None:
         dtype = numpy.dtype(dtype)
-        if dtype.itemsize > layout.itemsize:
-            layout = layout.packed(dtype.itemsize, int(buffer.handle))
-        elif dtype.itemsize < layout.itemsize:
-            layout = layout.unpacked(dtype.itemsize)
+        if dtype.itemsize != layout.itemsize:
+            raise ValueError(
+                f"The dtype's itemsize ({dtype.itemsize}) does not match the layout's "
+                f"itemsize ({layout.itemsize}). Please use :meth:`StridedLayout.repacked` "
+                f"to transform the layout to the desired itemsize."
+ ) # set the public attributes view.ptr = _get_data_ptr(buffer, layout) view.device_id = buffer.device_id diff --git a/cuda_core/cuda/core/experimental/include/layout.hpp b/cuda_core/cuda/core/experimental/include/layout.hpp index b84f74d8a2..bed485ea01 100644 --- a/cuda_core/cuda/core/experimental/include/layout.hpp +++ b/cuda_core/cuda/core/experimental/include/layout.hpp @@ -1,3 +1,7 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + #ifndef CUDA_CORE_LAYOUT_HPP #define CUDA_CORE_LAYOUT_HPP @@ -47,4 +51,4 @@ inline void _order_from_strides(std::vector& indices, const int64_t* shape, }); } -#endif // CUDA_CORE_LAYOUT_HPP \ No newline at end of file +#endif // CUDA_CORE_LAYOUT_HPP diff --git a/cuda_core/cuda/core/experimental/utils.py b/cuda_core/cuda/core/experimental/utils.py index 32f62918f6..3227f1eae1 100644 --- a/cuda_core/cuda/core/experimental/utils.py +++ b/cuda_core/cuda/core/experimental/utils.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +from cuda.core.experimental._layout import StridedLayout # noqa: F401 from cuda.core.experimental._memoryview import ( StridedMemoryView, # noqa: F401 args_viewable_as_strided_memory, # noqa: F401 diff --git a/cuda_core/docs/source/_templates/autosummary/cyclass.rst b/cuda_core/docs/source/_templates/autosummary/cyclass.rst new file mode 100644 index 0000000000..8728ab53ef --- /dev/null +++ b/cuda_core/docs/source/_templates/autosummary/cyclass.rst @@ -0,0 +1,27 @@ +.. SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +.. SPDX-License-Identifier: Apache-2.0 + +{{ fullname | escape | underline}} + +.. currentmodule:: {{ module }} + +.. autoclass:: {{ objname }} + + {% block attributes %} + {% if attributes %} + {% for item in attributes %} + .. autoattribute:: {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block methods %} + {% if methods %} + .. rubric:: {{ _('Methods') }} + + {% for item in methods %} + .. 
automethod:: {{ item }}
+    {%- endfor %}
+
+    {% endif %}
+    {% endblock %}
diff --git a/cuda_core/docs/source/api.rst b/cuda_core/docs/source/api.rst
index d7f4d3642d..45be638eb6 100644
--- a/cuda_core/docs/source/api.rst
+++ b/cuda_core/docs/source/api.rst
@@ -75,6 +75,7 @@ Utility functions
 
    args_viewable_as_strided_memory
 
-   :template: dataclass.rst
+   :template: autosummary/cyclass.rst
 
    StridedMemoryView
+   StridedLayout
 
 
diff --git a/cuda_core/docs/source/conf.py b/cuda_core/docs/source/conf.py
index e735918c41..bab2a2b942 100644
--- a/cuda_core/docs/source/conf.py
+++ b/cuda_core/docs/source/conf.py
@@ -123,5 +123,23 @@ def autodoc_process_docstring(app, what, name, obj, options, lines):
         lines.extend(new_lines)
 
 
+def skip_member(app, what, name, obj, skip, options):
+    # skip undocumented attributes for modules documented
+    # with cyclass.rst template where attributes
+    # are assumed to be properties (because cythonized
+    # properties are not recognized as such by autodoc)
+    excluded_dirs = [
+        "cuda.core.experimental._layout",
+        "cuda.core.experimental._memoryview",
+    ]
+    if what == "attribute" and getattr(obj, "__doc__", None) is None:
+        obj_module = getattr(getattr(obj, "__objclass__", None), "__module__", None)
+        if obj_module in excluded_dirs:
+            print(f"Skipping undocumented attribute {name} in {obj_module}")
+            return True
+    return None
+
+
 def setup(app):
     app.connect("autodoc-process-docstring", autodoc_process_docstring)
+    app.connect("autodoc-skip-member", skip_member)

From cf7eff5739d8609c7b10a76d3f40314a87f8270c Mon Sep 17 00:00:00 2001
From: Kamil Tokarski
Date: Fri, 21 Nov 2025 19:15:00 +0100
Subject: [PATCH 04/20] Add NotImplemented copy_from/copy_to

Signed-off-by: Kamil Tokarski
---
 .../cuda/core/experimental/_memoryview.pyx    | 55 +++++++++++++++++++
 1 file changed, 55 insertions(+)

diff --git a/cuda_core/cuda/core/experimental/_memoryview.pyx b/cuda_core/cuda/core/experimental/_memoryview.pyx
index 4cc24a32bb..d5221d4fcf 100644
--- a/cuda_core/cuda/core/experimental/_memoryview.pyx
+++ b/cuda_core/cuda/core/experimental/_memoryview.pyx
@@ -5,6 +5,7 @@
 from ._dlpack cimport *
 from libc.stdint cimport uintptr_t
 from cuda.core.experimental._layout cimport StridedLayout
+from cuda.core.experimental._stream import Stream
 
 import functools
 from typing import Optional
@@ -196,6 +197,60 @@ cdef class StridedMemoryView:
         view_buffer_strided(view, self.get_buffer(), layout, dtype, self.readonly)
         return view
 
+    def copy_from(
+        self, other : StridedMemoryView, stream : Stream,
+        allocator : MemoryResource | None = None,
+        blocking : bool | None = None,
+    ):
+        """
+        Copies the data from the other view into this view.
+
+        The copy can be performed between the following memory spaces:
+        host-to-device, device-to-host, device-to-device (on the same device).
+
+        The following conditions must be met:
+        * Both views must have compatible shapes, i.e. the shapes must be equal
+          or the source view's shape must be broadcastable to the target view's shape
+          (see :meth:`StridedLayout.broadcast_to`).
+        * Both views must have the same :attr:`dtype` (or :attr:`StridedLayout.itemsize`
+          if :attr:`dtype` is not specified).
+        * The destination's layout must be unique (see :meth:`StridedLayout.is_unique`).
+
+        Parameters
+        ----------
+        other : StridedMemoryView
+            The view to copy data from.
+        stream : Stream | None, optional
+            The stream to schedule the copy on.
+        allocator : MemoryResource | None, optional
+            If temporary buffers are needed, the specified memory resources
+            will be used to allocate the memory. 
If not specified, default + resources will be used. + blocking : bool | None, optional + Whether the call should block until the copy is complete. + * ``True``: the ``stream`` is synchronized with the host at the end of the call, + blocking until the copy is complete. + * ``False``: if possible, the call returns immediately once the copy is scheduled. + However, in some cases of host-to-device or device-to-host copies, the call may + still synchronize with the host if necessary. + * ``None`` (default): + * for device-to-device, it defaults to ``False`` (non-blocking), + * for host-to-device or device-to-host, it defaults to ``True`` (blocking). + """ + raise NotImplementedError("Sorry, not supported: copy_from") + + def copy_to( + self, other : StridedMemoryView, stream : Stream | None = None, + allocator : MemoryResource | None = None, + blocking : bool | None = None, + ): + """ + Copies the data from this view into the other view. + + For details, see :meth:`copy_from`. + """ + raise NotImplementedError("Sorry, not supported: copy_to") + @property def layout(self) -> StridedLayout: """ From 79010b8e25e95c57f9aaae9b7b3c71d5cad512f8 Mon Sep 17 00:00:00 2001 From: Kamil Tokarski Date: Wed, 26 Nov 2025 16:46:16 +0100 Subject: [PATCH 05/20] Adjust flattening scalars to numpy/cupy behavior, fix shape validation in reshape, fix to dense with sliced views Signed-off-by: Kamil Tokarski --- cuda_core/cuda/core/experimental/_layout.pyx | 36 +++++++++++++------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_layout.pyx b/cuda_core/cuda/core/experimental/_layout.pyx index 2cac3d9c08..93570687af 100644 --- a/cuda_core/cuda/core/experimental/_layout.pyx +++ b/cuda_core/cuda/core/experimental/_layout.pyx @@ -135,9 +135,10 @@ cdef class StridedLayout: """ cdef OrderFlag order_flag cdef axis_vec_t stride_order_vec + cdef bint is_dense = other.get_is_dense() if stride_order == "K": - if other.get_is_dense(): + if is_dense: return other other.get_stride_order(stride_order_vec) order_flag = ORDER_PERM @@ -149,10 +150,10 @@ cdef class StridedLayout: f"or a permutation tuple. 
Got: {stride_order}" ) elif order_flag == ORDER_C: - if other.get_is_contiguous_c(): + if is_dense and other.get_is_contiguous_c(): return other elif order_flag == ORDER_F: - if other.get_is_contiguous_f(): + if is_dense and other.get_is_contiguous_f(): return other cdef StridedLayout new_layout = StridedLayout.__new__(cls) @@ -928,11 +929,12 @@ cdef inline int validate_reshaped_shape(BaseLayout& new_shape, int64_t old_volum else: raise ValueError("There can be at most one -1 extent in a shape") cdef int64_t new_volume = _c_abs(_volume(new_shape)) - if new_volume == 0 and axis != -1: - raise ValueError("The -1 extent is ambiguous when the volume is 0") - if new_volume != old_volume: - if axis == -1: + if axis == -1: + if new_volume != old_volume: raise ValueError(f"The original volume {old_volume} and the new volume {new_volume} must be equal.") + else: + if new_volume == 0: + raise ValueError("The -1 extent is ambiguous when the specified sub-volume is 0") extent = old_volume // new_volume if extent * new_volume != old_volume: raise ValueError(f"The original volume {old_volume} must be divisible by the specified sub-volume {new_volume}.") @@ -957,6 +959,11 @@ cdef inline axes_mask_t axis_mask_from_range(int ndim, int start_axis, int end_a cdef inline int flatten_strides_in_c_index_order(BaseLayout& out_layout, BaseLayout& in_layout, axes_mask_t axis_mask) except -1 nogil: cdef int ndim = in_layout.ndim + if ndim == 0: + init_base_layout(out_layout, 1) + out_layout.shape[0] = 1 + out_layout.strides[0] = 1 + return 1 init_base_layout(out_layout, ndim) cdef int group_start = 0 cdef int group_end = 0 @@ -1021,16 +1028,19 @@ cdef inline bint split_strides_in_c_index_order(BaseLayout& out_layout, BaseLayo _zero_strides(out_layout) while i >= 0: extent = in_shape[i] - group_vol = 1 group_stride = in_strides[i] - while new_i >= 0 and group_vol < extent: + group_vol = 1 + while new_i >= 0: new_extent = out_layout.shape[new_i] if new_extent == 0: return False - group_vol = _overflow_checked_mul(group_vol, new_extent) - out_layout.strides[new_i] = group_stride - group_stride = _overflow_checked_mul(group_stride, new_extent) - new_i -= 1 + if new_extent == 1 or group_vol < extent: + out_layout.strides[new_i] = group_stride + group_stride = _overflow_checked_mul(group_stride, new_extent) + group_vol = _overflow_checked_mul(group_vol, new_extent) + new_i -= 1 + else: + break if group_vol != extent: return False i -= 1 From 4ca35678bcaa86203bb97b8ae03dc719cf14ea41 Mon Sep 17 00:00:00 2001 From: Kamil Tokarski Date: Wed, 26 Nov 2025 17:31:16 +0100 Subject: [PATCH 06/20] Add StridedLayout tests Signed-off-by: Kamil Tokarski --- cuda_core/tests/helpers/layout.py | 151 +++++ cuda_core/tests/test_strided_layout.py | 823 +++++++++++++++++++++++++ 2 files changed, 974 insertions(+) create mode 100644 cuda_core/tests/helpers/layout.py create mode 100644 cuda_core/tests/test_strided_layout.py diff --git a/cuda_core/tests/helpers/layout.py b/cuda_core/tests/helpers/layout.py new file mode 100644 index 0000000000..387a849e39 --- /dev/null +++ b/cuda_core/tests/helpers/layout.py @@ -0,0 +1,151 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# SPDX-License-Identifier: Apache-2.0 + +import itertools +from enum import Enum + +import numpy as np + + +class NamedParam: + def __init__(self, name, value): + self.name = name + self.value = value + + def __bool__(self): + return bool(self.value) + + def pretty_name(self): + if isinstance(self.value, Enum): + value_str = self.value.name + else: + value_str = str(self.value) + return f"{self.name}.{value_str}" + + +class DenseOrder(Enum): + """ + Whether to initialize the dense layout in C or F order. + For C, the strides can be explicit or implicit (None). + """ + + C = "C" + IMPLICIT_C = "implicit_c" + F = "F" + + +class _S: + """ + SliceSpec + """ + + def __init__(self): + self.slices = [] + + def __getitem__(self, value): + self.slices.append(value) + return self + + +class LayoutSpec: + """ + Pretty printable specification of a layout in a test case. + """ + + def __init__( + self, + shape, + itemsize, + stride_order=DenseOrder.C, + perm=None, + slices=None, + np_ref=None, + ): + self.shape = shape + self.itemsize = itemsize + self.stride_order = stride_order + self.perm = perm + if slices is not None: + assert isinstance(slices, _S) + slices = slices.slices + self.slices = slices + self.np_ref = np_ref + + def pretty_name(self): + desc = [ + f"ndim.{len(self.shape)}", + f"shape.{self.shape}", + f"itemsize.{self.itemsize}", + ] + if self.stride_order is not None: + if isinstance(self.stride_order, DenseOrder): + desc.append(f"stride_order.{self.stride_order.value}") + else: + assert isinstance(self.stride_order, tuple) + assert len(self.stride_order) == len(self.shape) + desc.append(f"stride_order.{self.stride_order}") + if self.perm is not None: + desc.append(f"perm.{self.perm}") + if self.slices is not None: + desc.append(f"slices.{self.slices}") + return "-".join(desc) + + def dtype_from_itemsize(self): + return dtype_from_itemsize(self.itemsize) + + def np_order(self): + return "F" if self.stride_order == DenseOrder.F else "C" + + def has_no_strides(self): + return self.stride_order == DenseOrder.IMPLICIT_C + + def has_no_strides_transformed(self): + return self.stride_order == DenseOrder.IMPLICIT_C and self.perm is None and self.slices is None + + +def dtype_from_itemsize(itemsize): + if itemsize <= 8: + return np.dtype(f"int{itemsize * 8}") + elif itemsize == 16: + return np.dtype("complex128") + else: + raise ValueError(f"Unsupported itemsize: {itemsize}") + + +def pretty_name(val): + """ + Pytest does not pretty print (repr/str) parameters of custom types. 
+    Use this function as the `ids` argument of `pytest.mark.parametrize`, e.g.:
+    ``@pytest.mark.parametrize(..., ids=pretty_name)``
+    """
+    if hasattr(val, "pretty_name"):
+        return val.pretty_name()
+    # use default pytest pretty printing
+    return None
+
+
+def flatten_mask2str(mask, ndim):
+    return "".join("1" if mask & (1 << i) else "0" for i in range(ndim))
+
+
+def random_permutations(rng, perm_len, cutoff_len=3, sample_size=6):
+    if perm_len <= cutoff_len:
+        return [perm for perm in itertools.permutations(range(perm_len))]
+    perms = []
+    for _ in range(sample_size):
+        perm = list(range(perm_len))
+        rng.shuffle(perm)
+        perms.append(tuple(perm))
+    return perms
+
+
+def inv_permutation(perm):
+    inv = [None] * len(perm)
+    for i, p in enumerate(perm):
+        inv[p] = i
+    return tuple(inv)
+
+
+def permuted(t, perm):
+    return tuple(t[i] for i in perm)
diff --git a/cuda_core/tests/test_strided_layout.py b/cuda_core/tests/test_strided_layout.py
new file mode 100644
index 0000000000..27c6c856cb
--- /dev/null
+++ b/cuda_core/tests/test_strided_layout.py
@@ -0,0 +1,823 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import itertools
+import math
+import random
+from enum import Enum
+
+import numpy as np
+import pytest
+from cuda.core.experimental._layout import StridedLayout
+from helpers.layout import (
+    _S,
+    DenseOrder,
+    LayoutSpec,
+    NamedParam,
+    dtype_from_itemsize,
+    flatten_mask2str,
+    inv_permutation,
+    permuted,
+    pretty_name,
+    random_permutations,
+)
+
+_ITEMSIZES = [1, 2, 4, 8, 16]
+
+py_rng = random.Random(42)
+
+
+def _setup_layout_and_np_ref(spec: LayoutSpec):
+    np_ref = np.arange(math.prod(spec.shape), dtype=spec.dtype_from_itemsize())
+
+    if isinstance(spec.stride_order, DenseOrder):
+        np_ref = np_ref.reshape(spec.shape, order=spec.np_order())
+        if spec.stride_order == DenseOrder.IMPLICIT_C:
+            layout = StridedLayout(spec.shape, None, spec.itemsize)
+        else:
+            layout = StridedLayout.dense(spec.shape, spec.itemsize, spec.stride_order.value)
+    else:
+        assert isinstance(spec.stride_order, tuple)
+        assert len(spec.stride_order) == len(spec.shape)
+        # numpy does not allow specifying the tuple order (only C/F)
+        np_ref = np_ref.reshape(permuted(spec.shape, spec.stride_order))
+        np_ref = np_ref.transpose(inv_permutation(spec.stride_order))
+        layout = StridedLayout.dense(spec.shape, spec.itemsize, spec.stride_order)
+    return layout, np_ref
+
+
+def _transform(layout: StridedLayout, np_ref: np.ndarray, spec: LayoutSpec):
+    if spec.perm is not None:
+        np_ref = np_ref.transpose(spec.perm)
+        layout = layout.permuted(spec.perm)
+    if spec.slices is not None:
+        for sl in spec.slices:
+            np_ref = np_ref[sl]
+            layout = layout.sliced(sl)
+    return layout, np_ref
+
+
+def _cmp_layout_and_array(layout: StridedLayout, arr: np.ndarray, expect_strides_none: bool):
+    """
+    Compare StridedLayout and numpy.ndarray.
+    Compares shape, strides, itemsize and contiguity flags.
+ """ + ndim = len(arr.shape) + assert layout.ndim == ndim + assert layout.shape == arr.shape + volume = math.prod(arr.shape) + assert layout.volume == volume + assert layout.itemsize == arr.itemsize + assert layout.slice_offset * layout.itemsize == layout.slice_offset_in_bytes + + ref_c_contig = arr.flags["C_CONTIGUOUS"] + ref_f_contig = arr.flags["F_CONTIGUOUS"] + assert layout.is_contiguous_c == ref_c_contig + assert layout.is_contiguous_f == ref_f_contig + ref_any_contig = ref_c_contig or ref_f_contig or arr.transpose(layout.stride_order).flags["C_CONTIGUOUS"] + assert layout.is_contiguous_any == ref_any_contig + assert layout.is_dense == (ref_any_contig and layout.slice_offset == 0) + + if expect_strides_none: + assert layout.strides is None + assert layout.strides_in_bytes is None + assert arr.flags["C_CONTIGUOUS"] + elif math.prod(arr.shape) == 0: + assert layout.strides_in_bytes == tuple(0 for _ in range(ndim)) + else: + assert layout.strides_in_bytes == arr.strides + + +def _cmp_layout_from_dense_vs_from_np(layout: StridedLayout, np_ref: np.ndarray, has_no_strides: bool): + """ + Compare the layout created through series of transformations vs + the layout created from numpy.ndarray transformed accordingly. + """ + + layout_from_np = StridedLayout(np_ref.shape, np_ref.strides, np_ref.itemsize, divide_strides=True) + assert layout_from_np.shape == layout.shape + assert layout_from_np.itemsize == layout.itemsize + assert layout_from_np.is_contiguous_c == layout.is_contiguous_c + assert layout_from_np.is_contiguous_f == layout.is_contiguous_f + assert layout_from_np.is_contiguous_any == layout.is_contiguous_any + assert layout_from_np.is_unique == layout.is_unique + volume = math.prod(np_ref.shape) + assert layout_from_np.volume == layout.volume == volume + + if volume > 0: + assert layout_from_np.stride_order == layout.stride_order + + if has_no_strides: + assert layout_from_np.is_contiguous_c + assert layout_from_np.is_contiguous_any + dense_layout = StridedLayout.dense(np_ref.shape, np_ref.itemsize) + assert layout_from_np.strides == dense_layout.strides + assert layout_from_np.strides_in_bytes == dense_layout.strides_in_bytes + else: + assert layout_from_np.strides == layout.strides + assert layout_from_np.strides_in_bytes == layout.strides_in_bytes + + +def _cmp_slice_offset( + layout_0: StridedLayout, + layout_1: StridedLayout, + np_ref_0: np.ndarray, + np_ref_1: np.ndarray, +): + # cannot access numpy's scalar data pointer + if layout_1.ndim > 0: + ref_offset = np_ref_1.ctypes.data - np_ref_0.ctypes.data + layout_offset = layout_1.slice_offset_in_bytes - layout_0.slice_offset_in_bytes + assert layout_offset == ref_offset + + +@pytest.mark.parametrize( + "layout_spec", + [ + LayoutSpec(shape, py_rng.choice(_ITEMSIZES), stride_order) + for shape in [tuple(), (5,), (7, 9), (2, 3, 4)] + for stride_order in random_permutations(py_rng, len(shape)) + ], + ids=pretty_name, +) +def test_dense_with_permutation_as_stride_order(layout_spec): + """ + Test creating StridedLayout with stride_order=tuple(...). 
+ """ + layout, np_ref = _setup_layout_and_np_ref(layout_spec) + _cmp_layout_and_array(layout, np_ref, False) + _cmp_layout_from_dense_vs_from_np(layout, np_ref, False) + assert layout.stride_order == layout_spec.stride_order + + +@pytest.mark.parametrize( + "layout_spec", + [ + LayoutSpec(shape, py_rng.choice(_ITEMSIZES), stride_order, perm=permutation) + for shape in [tuple(), (1,), (2, 3), (5, 6, 7), (5, 1, 7), (5, 2, 3, 4)] + for permutation in random_permutations(py_rng, len(shape)) + for stride_order in list(DenseOrder) + ], + ids=pretty_name, +) +def test_permuted(layout_spec): + """ + Test creating StridedLayout with dense(C/F) order or implict C order + StridedLayout(strides=None) and calling permuted(perm) on it. + Tests against numpy transpose and checks stride_order attribute. + """ + layout, np_ref = _setup_layout_and_np_ref(layout_spec) + _cmp_layout_and_array(layout, np_ref, layout_spec.has_no_strides()) + _cmp_layout_from_dense_vs_from_np(layout, np_ref, layout_spec.has_no_strides()) + layout, np_ref = _transform(layout, np_ref, layout_spec) + _cmp_layout_and_array(layout, np_ref, layout_spec.has_no_strides_transformed()) + _cmp_layout_from_dense_vs_from_np(layout, np_ref, layout_spec.has_no_strides_transformed()) + expected_order = inv_permutation(layout_spec.perm) + if layout_spec.np_order() == "F": + expected_order = tuple(reversed(expected_order)) + assert layout.stride_order == expected_order + + +class SliceErr(Enum): + ZERO_STEP = "slice step cannot be zer" + TOO_MANY_SLICES = "is greater than the number of dimensions" + OUT_OF_RANGE = "out of range for axis" + TYPE_ERROR = "Expected slice instance or integer." + + +@pytest.mark.parametrize( + ("layout_spec", "error_msg"), + [ + ( + LayoutSpec(shape, py_rng.choice(_ITEMSIZES), stride_order, slices=slices), + error_msg, + ) + for shape, slices, error_msg in [ + (tuple(), _S(), None), + ((12,), _S()[:], None), + ((13,), _S()[::-1], None), + ((13,), _S()[::-1][::-1], None), + ((13,), _S()[::-1][1:-1][::-1], None), + ((13,), _S()[2:-3], None), + ((13,), _S()[2:-3:2], None), + ((13,), _S()[-3:2:-2], None), + ((13,), _S()[-3:2:-2][1:3], None), + ((3, 5), _S()[:2][:, 3:], None), + ((3, 5), _S()[5:4], None), + ((3, 5), _S()[:, ::0], SliceErr.ZERO_STEP), + ((3, 5), _S()[:, :-1, :2], SliceErr.TOO_MANY_SLICES), + ((11, 12, 3), _S()[:, 0, :-1], None), + ((11, 12, 3), _S()[0, 1, :-1], None), + ((11, 12, 3, 5), _S()[0][1], None), + ((11, 12, 3, 5), _S()[:, 1, :-1], None), + ((11, 12, 3), _S()[0, 1, 2], None), + ((11, 12, 3), _S()[0, 1, 5], SliceErr.OUT_OF_RANGE), + ((11, 12, 3), _S()[-2], None), + ((11, 12, 3), _S()[-42], SliceErr.OUT_OF_RANGE), + ((11, 12, 3), _S()["abc"], SliceErr.TYPE_ERROR), + ] + for stride_order in list(DenseOrder) + ], + ids=pretty_name, +) +def test_slice(layout_spec, error_msg): + layout, np_ref = _setup_layout_and_np_ref(layout_spec) + _cmp_layout_and_array(layout, np_ref, layout_spec.has_no_strides()) + _cmp_layout_from_dense_vs_from_np(layout, np_ref, layout_spec.has_no_strides()) + + if error_msg is None: + for sl in layout_spec.slices: + sliced = layout[sl] + sliced_ref = np_ref[sl] + _cmp_layout_and_array(sliced, sliced_ref, False) + _cmp_layout_from_dense_vs_from_np(sliced, sliced_ref, False) + _cmp_slice_offset(layout, sliced, np_ref, sliced_ref) + layout = sliced + np_ref = sliced_ref + else: + error_cls = TypeError if error_msg == SliceErr.TYPE_ERROR else ValueError + with pytest.raises(error_cls, match=error_msg.value): + for sl in layout_spec.slices: + layout[sl] + + +class ReshapeErr(Enum): + 
VOLUME_MISMATCH = "The original volume \\d+ and the new volume \\d+ must be equal." + NEG_EXTENT = "Extents must be non-negative" + MULTI_NEG_EXTENTS = "There can be at most one -1 extent in a shape" + AMBIGUOUS_NEG_EXTENT = "The -1 extent is ambiguous when the specified sub-volume is 0" + DIVISIBILITY_VIOLATION = "The original volume \\d+ must be divisible by the specified sub-volume \\d+" + STRIDE = "Layout strides are incompatible with the new shape" + TYPE_ERROR = None + + +@pytest.mark.parametrize( + ("layout_spec", "new_shape", "error_msg"), + [ + ( + LayoutSpec( + shape, + py_rng.choice(_ITEMSIZES), + stride_order, + perm=permutation, + slices=slices, + ), + NamedParam("new_shape", new_shape), + error_msg, + ) + for shape, permutation, slices, new_shape, error_msg in [ + (tuple(), None, None, tuple(), None), + (tuple(), None, None, (1,), None), + (tuple(), None, None, (-1,), None), + (tuple(), None, None, (1, -1, 1), None), + ((1,), None, None, (-1,), None), + ((1,), None, None, tuple(), None), + ((12,), None, _S()[:], (12,), None), + ((12,), None, None, (11,), ReshapeErr.VOLUME_MISMATCH), + ((12,), None, _S()[1:], (11,), None), + ((0,), None, None, (0,), None), + ((0,), None, None, (1, 3), ReshapeErr.VOLUME_MISMATCH), + ((3,), None, _S()[3:], (3,), ReshapeErr.VOLUME_MISMATCH), + ((18,), None, None, (0,), ReshapeErr.VOLUME_MISMATCH), + ((3,), None, _S()[2:-1], (0,), None), + ((3,), None, _S()[3:], (-1,), None), + ((0,), None, None, (1, -1), None), + ((0,), None, None, (0, -1), ReshapeErr.AMBIGUOUS_NEG_EXTENT), + ((3, 0, 3), None, None, (2, 3, 4, 5, 6, 7, 0, 12), None), + ((3, 0, 3), None, None, (0,), None), + ((12,), None, None, (2, 3, 2), None), + ((12,), None, None, (2, 6), None), + ((12,), None, None, (4, 3), None), + ((12,), None, None, (3, 4), None), + ((7, 12), None, None, (7, 12), None), + ((7, 12), None, None, (12, 7), None), + ((12, 11), None, None, (2, 3, 2, 11), None), + ((12, 11), None, None, (2, 3, 11, 2), None), + ((12, 11), None, None, (2, 11, 3, 2), None), + ((12, 11), None, None, (11, 2, 3, 2), None), + ((12, 11), None, None, (2, 3, 2, -1), None), + ((12, 11), None, None, (2, 3, -1, 2), None), + ((12, 11), None, None, (2, -1, 3, 2), None), + ((12, 11), None, None, (-1, 2, 3, 2), None), + ((12, 11), None, None, (2, 3, -1, 11), None), + ((12, 11), None, None, (2, 3, 11, -1), None), + ((12, 11), None, None, (-1, 11, 3, 2), None), + ((12, 11), None, None, (11, 2, -1, 2), None), + ((5, 12), None, None, (2, 5, 6), None), + ((2, 3, 2), None, None, (12,), None), + ((2, 3, 2), None, None, (6, 2), None), + ((2, 3, 2), None, None, (2, 3, 2), None), + ((2, 3, 2), (1, 2, 0), None, (6, 2), None), + ((2, 3, 2), (1, 2, 0), None, (2, 6), ReshapeErr.STRIDE), + ((2, 3, 2), (1, 2, 0), None, (12,), ReshapeErr.STRIDE), + ((2, 3, 2), (1, 0, 2), None, (3, 2, 2), None), + ((2, 3, 2), (1, 0, 2), None, (3, 4), ReshapeErr.STRIDE), + ((2, 3, 2), (1, 0, 2), None, (6, 2), ReshapeErr.STRIDE), + ((2, 3, 2), (1, 0, 2), None, (12,), ReshapeErr.STRIDE), + ((10, 10, 10), None, _S()[::-1, ::-1, :], (10, 10, 10), None), + ((10, 10, 10), None, _S()[::-1, ::-1, ::-1], (1000,), None), + ((10, 10, 10), None, _S()[::-1, ::-1, :], (100, 10), None), + ((10, 10, 10), None, _S()[::-1, ::-1, :], (10, 100), ReshapeErr.STRIDE), + ((10, 10, 10), None, _S()[:, :, ::-1], (100, 10), None), + ((10, 10, 10), None, _S()[:, :, ::-1], (10, 100), ReshapeErr.STRIDE), + ((10, 10, 10), None, _S()[::-1, :, ::-1], (1000,), ReshapeErr.STRIDE), + ((10, 10, 10), (1, 0, 2), _S()[::-1, ::-1], (100, 10), ReshapeErr.STRIDE), + ((5, 3), 
None, _S()[:-1, :], (12,), None), + ((13, 3), None, _S()[1:, :], (6, 6), None), + ((12, 4), None, _S()[:, :-1], (6, 6), ReshapeErr.STRIDE), + ((12, 4), None, _S()[:, :-1], (6, 2, 3), None), + ((7, 6, 5), None, None, (70, -1), None), + ((7, 6, 5), None, None, (-1, 70), None), + ((7, 6, 5), None, None, (71, -1), ReshapeErr.DIVISIBILITY_VIOLATION), + ((7, 6, 5), None, None, (-1, 71), ReshapeErr.DIVISIBILITY_VIOLATION), + ((7, 6, 5), None, None, (71, -2), ReshapeErr.NEG_EXTENT), + ((7, 6, 5), None, None, (-2, 71), ReshapeErr.NEG_EXTENT), + ((7, 6, 5), None, None, (-1, 6, -1), ReshapeErr.MULTI_NEG_EXTENTS), + ((7, 6, 5), None, None, (-2, -1, -1), ReshapeErr.NEG_EXTENT), + ((7, 6, 5), None, None, (-2, -1, -2), ReshapeErr.NEG_EXTENT), + ((7, 6, 5), None, None, (-7, 6, -5), ReshapeErr.NEG_EXTENT), + ((7, 6, 5), None, None, (5, 0, -1), ReshapeErr.AMBIGUOUS_NEG_EXTENT), + ((7, 0, 5), None, None, (5, 0, -1), ReshapeErr.AMBIGUOUS_NEG_EXTENT), + ((7, 6, 5), None, None, map, ReshapeErr.TYPE_ERROR), + ] + for stride_order in [DenseOrder.C, DenseOrder.IMPLICIT_C] + ], + ids=pretty_name, +) +def test_reshape(layout_spec, new_shape, error_msg): + new_shape = new_shape.value + layout, np_ref = _setup_layout_and_np_ref(layout_spec) + _cmp_layout_and_array(layout, np_ref, layout_spec.has_no_strides()) + _cmp_layout_from_dense_vs_from_np(layout, np_ref, layout_spec.has_no_strides()) + + layout, np_ref = _transform(layout, np_ref, layout_spec) + _cmp_layout_and_array(layout, np_ref, layout_spec.has_no_strides_transformed()) + _cmp_layout_from_dense_vs_from_np(layout, np_ref, layout_spec.has_no_strides_transformed()) + + if error_msg is None: + reshaped = layout.reshaped(new_shape) + reshaped_ref = np_ref.reshape(new_shape, copy=False) + _cmp_layout_and_array(reshaped, reshaped_ref, False) + _cmp_layout_from_dense_vs_from_np(reshaped, reshaped_ref, False) + else: + # sanity check that numpy is not able to reshape without + # a copy as well + if error_msg == ReshapeErr.STRIDE: + with pytest.raises(ValueError): + np_ref.reshape(new_shape, copy=False) + + error_cls = TypeError if error_msg == ReshapeErr.TYPE_ERROR else ValueError + msg = None if error_msg == ReshapeErr.TYPE_ERROR else error_msg.value + with pytest.raises(error_cls, match=msg): + layout.reshaped(new_shape) + + +@pytest.mark.parametrize( + ( + "layout_spec", + "expected_shape", + "expected_strides", + "expected_axis_mask", + ), + [ + ( + LayoutSpec( + shape, + py_rng.choice(_ITEMSIZES), + stride_order, + perm=permutation, + slices=slices, + ), + NamedParam("expected_shape", expected_shape), + NamedParam("expected_strides", expected_strides), + NamedParam("expected_axis_mask", expected_axis_mask), + ) + for shape, permutation, slices, expected_shape, expected_strides, expected_axis_mask in [ + (tuple(), None, None, (1,), (1,), ""), + ((12,), None, _S()[:], (12,), (1,), "0"), + ((1, 2, 3, 4, 5), None, None, (120,), (1,), "01111"), + ((1, 2, 3, 0, 5), None, None, (0,), (0,), "01111"), + ((5, 1, 2, 4, 3), None, _S()[:, :, :, :, ::-2], (40, 2), (3, -2), "01110"), + ((5, 2, 4, 3), None, _S()[:, ::-1, :, :], (5, 2, 12), (24, -12, 1), "0001"), + ((5, 7, 4, 3), None, _S()[:, ::-1, ::-1], (5, 28, 3), (84, -3, 1), "0010"), + ((5, 4, 3, 7), (2, 3, 0, 1), _S()[:], (21, 20), (1, 21), "0101"), + ((5, 4, 3, 7), (3, 2, 0, 1), None, (7, 3, 20), (1, 7, 21), "0001"), + ] + for stride_order in [DenseOrder.C, DenseOrder.IMPLICIT_C] + ], + ids=pretty_name, +) +def test_flatten( + layout_spec, + expected_shape, + expected_strides, + expected_axis_mask, +): + expected_shape = 
expected_shape.value + expected_strides = expected_strides.value + expected_axis_mask = expected_axis_mask.value + + layout, np_ref = _setup_layout_and_np_ref(layout_spec) + _cmp_layout_and_array(layout, np_ref, layout_spec.has_no_strides()) + _cmp_layout_from_dense_vs_from_np(layout, np_ref, layout_spec.has_no_strides()) + + layout, np_ref = _transform(layout, np_ref, layout_spec) + _cmp_layout_and_array(layout, np_ref, layout_spec.has_no_strides_transformed()) + _cmp_layout_from_dense_vs_from_np(layout, np_ref, layout_spec.has_no_strides_transformed()) + + mask = flatten_mask2str(layout.flattened_axis_mask(), layout.ndim) + assert mask == expected_axis_mask + + flattened = layout.flattened() + assert flattened.shape == expected_shape + assert flattened.strides == expected_strides + assert flattened.itemsize == layout_spec.itemsize + assert flattened.slice_offset == layout.slice_offset + + # cannot be flattened any further + assert flattened.flattened_axis_mask() == 0 + + +@pytest.mark.parametrize( + ( + "layout_spec_0", + "layout_spec_1", + "expected_layout_spec_0", + "expected_layout_spec_1", + ), + [ + ( + layout_spec_0, + layout_spec_1, + expected_layout_spec_0, + expected_layout_spec_1, + ) + for layout_spec_0, layout_spec_1, expected_layout_spec_0, expected_layout_spec_1 in [ + ( + LayoutSpec(tuple(), 2, DenseOrder.C), + LayoutSpec(tuple(), 4, DenseOrder.C), + LayoutSpec((1,), 2, DenseOrder.C), + LayoutSpec((1,), 4, DenseOrder.C), + ), + ( + LayoutSpec(tuple(), 2, DenseOrder.IMPLICIT_C), + LayoutSpec(tuple(), 4, DenseOrder.IMPLICIT_C), + LayoutSpec((1,), 2, DenseOrder.C), + LayoutSpec((1,), 4, DenseOrder.C), + ), + ( + LayoutSpec((2, 7, 13, 5), 8, DenseOrder.C), + LayoutSpec((3, 5, 11, 1), 4, DenseOrder.C), + LayoutSpec((910,), 8, DenseOrder.C), + LayoutSpec((165,), 4, DenseOrder.C), + ), + ( + LayoutSpec((2, 7, 13, 5), 8, DenseOrder.IMPLICIT_C), + LayoutSpec((3, 5, 11, 1), 4, DenseOrder.IMPLICIT_C), + LayoutSpec((910,), 8, DenseOrder.C), + LayoutSpec((165,), 4, DenseOrder.C), + ), + ( + LayoutSpec((5, 7, 13, 2), 4, (3, 1, 2, 0)), + LayoutSpec((3, 5, 11, 1), 2, DenseOrder.IMPLICIT_C), + LayoutSpec((5, 91, 2), 4, (2, 1, 0)), + LayoutSpec((3, 55, 1), 2, DenseOrder.C), + ), + ( + LayoutSpec((2, 7, 13, 5), 16, DenseOrder.C), + LayoutSpec((11, 1, 3, 5), 1, (2, 3, 0, 1)), + LayoutSpec((14, 65), 16, DenseOrder.C), + LayoutSpec((11, 15), 1, (1, 0)), + ), + ( + LayoutSpec( + (4, 5, 11, 2, 3, 7), + 4, + (5, 3, 4, 0, 1, 2), + ), + LayoutSpec( + (3, 8, 5, 6, 7, 9), + 4, + (0, 1, 3, 4, 5, 2), + ), + LayoutSpec((20, 11, 6, 7), 4, (3, 2, 0, 1)), + LayoutSpec((24, 5, 42, 9), 4, (0, 2, 3, 1)), + ), + ] + ], + ids=pretty_name, +) +def test_flatten_together( + layout_spec_0, + layout_spec_1, + expected_layout_spec_0, + expected_layout_spec_1, +): + layouts = [] + for layout_spec in [ + layout_spec_0, + layout_spec_1, + expected_layout_spec_0, + expected_layout_spec_1, + ]: + layout, np_ref = _setup_layout_and_np_ref(layout_spec) + layout, np_ref = _transform(layout, np_ref, layout_spec) + _cmp_layout_and_array(layout, np_ref, layout_spec.has_no_strides_transformed()) + _cmp_layout_from_dense_vs_from_np(layout, np_ref, layout_spec.has_no_strides_transformed()) + layouts.append(layout) + + layout_0, layout_1, expected_layout_0, expected_layout_1 = layouts + + mask_0 = layout_0.flattened_axis_mask() + mask_1 = layout_1.flattened_axis_mask() + mask = mask_0 & mask_1 + + flattened_0 = layout_0.flattened(mask=mask) + flattened_1 = layout_1.flattened(mask=mask) + + for flattened, expected_layout in 
zip([flattened_0, flattened_1], [expected_layout_0, expected_layout_1]): + assert flattened == expected_layout + assert flattened.shape == expected_layout.shape + assert flattened.strides == expected_layout.strides + assert flattened.itemsize == expected_layout.itemsize + assert flattened.slice_offset == expected_layout.slice_offset + assert flattened.is_contiguous_c == expected_layout.is_contiguous_c + assert flattened.is_contiguous_f == expected_layout.is_contiguous_f + assert flattened.is_contiguous_any == expected_layout.is_contiguous_any + assert flattened.is_unique == expected_layout.is_unique + + +@pytest.mark.parametrize( + ("layout_spec",), + [ + ( + LayoutSpec( + shape, + py_rng.choice(_ITEMSIZES), + stride_order, + perm=permutation, + slices=slices, + ), + ) + for shape, permutation, slices in [ + (tuple(), None, None), + ((12,), None, None), + ((1, 5, 4, 3), None, None), + ((1, 5, 1, 4, 3), None, _S()[:, -1:, :]), + ((1, 5, 4, 3), None, _S()[:, -1:, :1, 1:2]), + ((7, 5, 3), (2, 0, 1), _S()[::-1, 3:2:-1, :]), + ((7, 5, 3), (2, 0, 1), _S()[:, 3:2, :]), + ] + for stride_order in list(DenseOrder) + ], + ids=pretty_name, +) +def test_squeezed(layout_spec): + layout, np_ref = _setup_layout_and_np_ref(layout_spec) + _cmp_layout_and_array(layout, np_ref, layout_spec.has_no_strides()) + _cmp_layout_from_dense_vs_from_np(layout, np_ref, layout_spec.has_no_strides()) + layout, np_ref = _transform(layout, np_ref, layout_spec) + _cmp_layout_and_array(layout, np_ref, layout_spec.has_no_strides_transformed()) + _cmp_layout_from_dense_vs_from_np(layout, np_ref, layout_spec.has_no_strides_transformed()) + + squeezed = layout.squeezed() + squeezed_ref = np_ref.squeeze() + if math.prod(np_ref.shape) != 0: + _cmp_layout_and_array(squeezed, squeezed_ref, False) + _cmp_layout_from_dense_vs_from_np(squeezed, squeezed_ref, False) + else: + assert squeezed.shape == (0,) + assert squeezed.strides == (0,) + assert squeezed.slice_offset == layout.slice_offset + + +@pytest.mark.parametrize( + ( + "layout_spec", + "axes", + ), + [ + ( + LayoutSpec(shape, py_rng.choice(_ITEMSIZES), stride_order, slices=slices), + NamedParam("axes", axes), + ) + for shape, slices in [ + (tuple(), None), + ((7,), None), + ((4, 5, 7, 11), _S()[1:-1, ::-1, 2:-1, ::3]), + ] + for stride_order in list(DenseOrder) + for num_axes in range(3) + for axes in itertools.combinations(list(range(len(shape) + num_axes)), num_axes) + ], + ids=pretty_name, +) +def test_unsqueezed_layout(layout_spec, axes): + axes = tuple(axes.value) + + layout, np_ref = _setup_layout_and_np_ref(layout_spec) + _cmp_layout_and_array(layout, np_ref, layout_spec.has_no_strides()) + _cmp_layout_from_dense_vs_from_np(layout, np_ref, layout_spec.has_no_strides()) + layout, np_ref = _transform(layout, np_ref, layout_spec) + _cmp_layout_and_array(layout, np_ref, layout_spec.has_no_strides_transformed()) + _cmp_layout_from_dense_vs_from_np(layout, np_ref, layout_spec.has_no_strides_transformed()) + + unsqueezed = layout.unsqueezed(axes) + unsqueezed_ref = np.expand_dims(np_ref, axis=axes) + # the implicit C layout is kept if the original layout has such strides + # and there are no actual transformations along the way: no slices + # and unsqueezing with empty axes tuple + has_no_strides = layout_spec.has_no_strides_transformed() and len(axes) == 0 + _cmp_layout_and_array(unsqueezed, unsqueezed_ref, has_no_strides) + _cmp_layout_from_dense_vs_from_np(unsqueezed, unsqueezed_ref, has_no_strides) + + +@pytest.mark.parametrize( + ( + "layout_spec", + "axis", + 
"expected_max_itemsize", + "new_itemsize", + ), + [ + ( + LayoutSpec(shape, itemsize, stride_order, perm=permutation, slices=slices), + NamedParam("axis", axis), + NamedParam("expected_max_itemsize", expected_max_itemsize), + NamedParam("new_itemsize", new_itemsize), + ) + for shape, permutation, slices, stride_order, itemsize, axis, expected_max_itemsize, new_itemsize in [ + ((12,), None, None, DenseOrder.C, 1, -1, 4, 1), + ((12,), None, None, DenseOrder.IMPLICIT_C, 1, -1, 4, 1), + ((12,), None, None, DenseOrder.F, 1, 0, 4, 1), + ((12,), None, None, DenseOrder.C, 4, -1, 16, 8), + ((12,), None, None, DenseOrder.IMPLICIT_C, 4, -1, 16, 8), + ((12,), None, None, DenseOrder.F, 4, 0, 16, 8), + ((16, 5, 4, 6), None, None, DenseOrder.C, 2, -1, 4, 4), + ((16, 5, 4, 6), None, None, DenseOrder.IMPLICIT_C, 2, -1, 4, 4), + ((16, 5, 4, 6), None, None, DenseOrder.F, 2, 0, 16, 4), + ((11, 5, 9), None, _S()[:, :, -1:], DenseOrder.C, 2, 2, 2, 2), + ((11, 5, 9), None, _S()[:, :, -1:], DenseOrder.IMPLICIT_C, 2, 2, 2, 2), + ((11, 5, 9), None, _S()[:, :, -1:], DenseOrder.F, 2, 0, 2, 2), + ((12, 3, 24), (1, 2, 0), _S()[::-1, 20:, 1:], DenseOrder.C, 2, 1, 8, 8), + ((12, 3, 24), (1, 2, 0), _S()[1:, ::-1, 10:], DenseOrder.F, 2, 2, 4, 4), + ] + ], + ids=pretty_name, +) +def test_packed_unpacked( + layout_spec, + axis, + expected_max_itemsize, + new_itemsize, +): + axis = axis.value + expected_max_itemsize = expected_max_itemsize.value + new_itemsize = new_itemsize.value + + layout, np_ref = _setup_layout_and_np_ref(layout_spec) + _cmp_layout_and_array(layout, np_ref, layout_spec.has_no_strides()) + _cmp_layout_from_dense_vs_from_np(layout, np_ref, layout_spec.has_no_strides()) + layout, np_ref = _transform(layout, np_ref, layout_spec) + _cmp_layout_and_array(layout, np_ref, layout_spec.has_no_strides_transformed()) + _cmp_layout_from_dense_vs_from_np(layout, np_ref, layout_spec.has_no_strides_transformed()) + + assert layout.max_compatible_itemsize(axis=axis) == expected_max_itemsize + packed = layout.repacked(new_itemsize, axis=axis) + # numpy does not allow specifying the axis to repack, + # so we need to transpose the array + packed_ref = ( + np_ref.transpose(layout.stride_order) + .view(dtype=dtype_from_itemsize(new_itemsize)) + .transpose(inv_permutation(layout.stride_order)) + ) + has_no_strides = layout_spec.has_no_strides_transformed() and layout.itemsize == new_itemsize + _cmp_layout_and_array(packed, packed_ref, has_no_strides) + _cmp_layout_from_dense_vs_from_np(packed, packed_ref, has_no_strides) + vec_size = new_itemsize // layout.itemsize + assert packed.slice_offset * vec_size == layout.slice_offset + unpacked = packed.repacked(layout.itemsize, axis=axis) + _cmp_layout_and_array(unpacked, np_ref, has_no_strides) + _cmp_layout_from_dense_vs_from_np(unpacked, np_ref, has_no_strides) + + +@pytest.mark.parametrize( + ( + "layout_spec", + "new_shape", + ), + [ + ( + LayoutSpec(shape, py_rng.choice(_ITEMSIZES), stride_order, slices=slices), + NamedParam("new_shape", new_shape), + ) + for shape, slices, new_shape in [ + (tuple(), None, tuple()), + (tuple(), None, (1,)), + (tuple(), None, (17, 1, 5)), + ((1,), None, (5,)), + ((1,), None, (3, 5, 2)), + ((7,), None, (7,)), + ((7,), None, (2, 7)), + ((5, 11), _S()[1:-1, ::-1], (3, 11)), + ((5, 11), _S()[1:-1, ::-1], (7, 3, 11)), + ((5, 11), _S()[::-1, 3:4], (5, 7)), + ((5, 11), _S()[::-1, 3:4], (5, 30)), + ((5, 11), _S()[::-1, 3:4], (4, 5, 12)), + ((5, 11), _S()[-1:,], (4, 13, 11)), + ] + for stride_order in list(DenseOrder) + ], + ids=pretty_name, +) +def 
test_broadcast_layout( + layout_spec, + new_shape, +): + new_shape = new_shape.value + layout, np_ref = _setup_layout_and_np_ref(layout_spec) + _cmp_layout_and_array(layout, np_ref, layout_spec.has_no_strides()) + _cmp_layout_from_dense_vs_from_np(layout, np_ref, layout_spec.has_no_strides()) + layout, np_ref = _transform(layout, np_ref, layout_spec) + _cmp_layout_and_array(layout, np_ref, layout_spec.has_no_strides_transformed()) + _cmp_layout_from_dense_vs_from_np(layout, np_ref, layout_spec.has_no_strides_transformed()) + + broadcasted = layout.broadcast_to(new_shape) + broadcasted_ref = np.broadcast_to(np_ref, new_shape) + _cmp_layout_and_array(broadcasted, broadcasted_ref, False) + _cmp_layout_from_dense_vs_from_np(broadcasted, broadcasted_ref, False) + assert layout.is_unique + ndim_diff = len(broadcasted_ref.shape) - len(np_ref.shape) + expect_unique = all(broadcasted_ref.shape[i] == 1 for i in range(ndim_diff)) + expect_unique = expect_unique and all( + broadcasted_ref.shape[i + ndim_diff] == np_ref.shape[i] for i in range(len(np_ref.shape)) + ) + assert broadcasted.is_unique is expect_unique + + +@pytest.mark.parametrize( + ( + "layout_spec", + "new_stride_order", + ), + [ + ( + LayoutSpec( + shape, + py_rng.choice(_ITEMSIZES), + stride_order, + perm=permutation, + slices=slices, + ), + NamedParam("new_stride_order", new_stride_order), + ) + for shape, permutation, slices in [ + (tuple(), None, None), + ((1,), None, None), + ((7,), None, None), + ((7,), None, _S()[3:6]), + ((7,), None, _S()[::-1]), + ((5, 11), None, None), + ((5, 11), None, _S()[1:-1]), + ((5, 11), None, _S()[::-1, 3:10]), + ((5, 11), None, _S()[1:4, ::-1]), + ((5, 11), None, _S()[-1:,]), + ((3, 5, 7), (1, 0, 2), None), + ] + for stride_order in list(DenseOrder) + for new_stride_order in ["C", "F", "K"] + random_permutations(py_rng, len(shape)) + ], + ids=pretty_name, +) +def test_to_dense(layout_spec, new_stride_order): + new_stride_order = new_stride_order.value + + layout, np_ref = _setup_layout_and_np_ref(layout_spec) + _cmp_layout_and_array(layout, np_ref, layout_spec.has_no_strides()) + _cmp_layout_from_dense_vs_from_np(layout, np_ref, layout_spec.has_no_strides()) + layout, np_ref = _transform(layout, np_ref, layout_spec) + _cmp_layout_and_array(layout, np_ref, layout_spec.has_no_strides_transformed()) + _cmp_layout_from_dense_vs_from_np(layout, np_ref, layout_spec.has_no_strides_transformed()) + + if isinstance(new_stride_order, str): + if new_stride_order == "K": + is_noop = layout.slice_offset == 0 and layout.is_contiguous_any + elif new_stride_order == "C": + is_noop = layout.slice_offset == 0 and layout.is_contiguous_c + elif new_stride_order == "F": + is_noop = layout.slice_offset == 0 and layout.is_contiguous_f + else: + raise AssertionError(f"Invalid new_stride_order: {new_stride_order}") + has_no_strides = layout_spec.has_no_strides_transformed() and is_noop + dense = layout.to_dense(new_stride_order) + dense_ref = np_ref.copy(order=new_stride_order) + _cmp_layout_and_array(dense, dense_ref, has_no_strides) + _cmp_layout_from_dense_vs_from_np(dense, dense_ref, has_no_strides) + else: + assert isinstance(new_stride_order, tuple) + assert len(new_stride_order) == len(layout.shape) + dense = layout.to_dense(new_stride_order) + dense_ref = np_ref.transpose(new_stride_order).copy(order="C").transpose(inv_permutation(new_stride_order)) + _cmp_layout_and_array(dense, dense_ref, False) + _cmp_layout_from_dense_vs_from_np(dense, dense_ref, False) From acdd6f85d60488204f36516a117bfe863787ff28 Mon Sep 17 
00:00:00 2001 From: Kamil Tokarski Date: Wed, 26 Nov 2025 18:24:00 +0100 Subject: [PATCH 07/20] Use explicit int32_t instead of int in integer fused type Signed-off-by: Kamil Tokarski --- cuda_core/cuda/core/experimental/_layout.pxd | 6 +++--- cuda_core/cuda/core/experimental/_utils/cuda_utils.pxd | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_layout.pxd b/cuda_core/cuda/core/experimental/_layout.pxd index 37ce85cc72..4235587d3b 100644 --- a/cuda_core/cuda/core/experimental/_layout.pxd +++ b/cuda_core/cuda/core/experimental/_layout.pxd @@ -5,12 +5,12 @@ cimport cython from cython.operator cimport dereference as deref -from libc.stdint cimport int64_t, uint32_t, uintptr_t +from libc.stdint cimport int64_t, int32_t, uint32_t, uintptr_t from libcpp cimport vector ctypedef int64_t extent_t ctypedef int64_t stride_t -ctypedef int axis_t +ctypedef int32_t axis_t ctypedef uint32_t axes_mask_t # MUST be exactly STRIDED_LAYOUT_MAX_NDIM bits wide ctypedef uint32_t property_mask_t @@ -23,7 +23,7 @@ from cuda.core.experimental._utils cimport cuda_utils ctypedef fused integer_t: int64_t - int + int32_t cdef extern from "include/layout.hpp": diff --git a/cuda_core/cuda/core/experimental/_utils/cuda_utils.pxd b/cuda_core/cuda/core/experimental/_utils/cuda_utils.pxd index 8af02fd92f..ce30285aa5 100644 --- a/cuda_core/cuda/core/experimental/_utils/cuda_utils.pxd +++ b/cuda_core/cuda/core/experimental/_utils/cuda_utils.pxd @@ -4,7 +4,7 @@ cimport cpython from cpython.object cimport PyObject -from libc.stdint cimport int64_t +from libc.stdint cimport int64_t, int32_t from cuda.bindings cimport cydriver @@ -15,7 +15,7 @@ ctypedef fused supported_error_type: ctypedef fused integer_t: int64_t - int + int32_t # mimic CU_DEVICE_INVALID From 60a0d668758dd43d3b763ad6b6c9f584976d532e Mon Sep 17 00:00:00 2001 From: Kamil Tokarski Date: Wed, 26 Nov 2025 18:42:57 +0100 Subject: [PATCH 08/20] Disable (for now) exporting the SMV via dlpack Signed-off-by: Kamil Tokarski --- .../cuda/core/experimental/_memoryview.pyx | 76 +++++-------------- 1 file changed, 19 insertions(+), 57 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_memoryview.pyx b/cuda_core/cuda/core/experimental/_memoryview.pyx index d5221d4fcf..b6a88d07fb 100644 --- a/cuda_core/cuda/core/experimental/_memoryview.pyx +++ b/cuda_core/cuda/core/experimental/_memoryview.pyx @@ -15,7 +15,6 @@ import numpy from cuda.core.experimental._utils.cuda_utils import handle_return, driver -from cuda.core.experimental._dlpack import make_py_capsule from cuda.core.experimental._memory import Buffer # TODO(leofang): support NumPy structured dtypes @@ -291,44 +290,6 @@ cdef class StridedMemoryView: + f" readonly={self.readonly},\n" + f" exporting_obj={get_simple_repr(self.exporting_obj)})") - def __dlpack__( - self, - *, - stream: int | None = None, - max_version: tuple[int, int] | None = None, - dl_device: tuple[int, int] | None = None, - copy: bool | None = None, - ) -> PyCapsule: - # Note: we ignore the stream argument entirely (as if it is -1). - # It is the user's responsibility to maintain stream order. 
- if dl_device is not None: - raise BufferError("Sorry, not supported: dl_device other than None") - if copy is True: - raise BufferError("Sorry, not supported: copy=True") - if max_version is None: - versioned = False - else: - if not isinstance(max_version, tuple) or len(max_version) != 2: - raise BufferError(f"Expected max_version tuple[int, int], got {max_version}") - versioned = max_version >= (1, 0) - cdef object dtype = self.get_dtype() - if dtype is None: - raise ValueError( - "Cannot export the StridedMemoryView without a dtype. " - "You can create a dtyped view calling view(dtype=...) method." - ) - capsule = make_py_capsule( - self.get_buffer(), - versioned, - self.ptr, - self.get_layout(), - _numpy2dlpack_dtype[dtype], - ) - return capsule - - def __dlpack_device__(self) -> tuple[int, int]: - return self.get_buffer().__dlpack_device__() - cdef inline StridedLayout get_layout(self): if self._layout is None: if self.dl_tensor: @@ -478,24 +439,25 @@ cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view=None): return buf -_numpy2dlpack_dtype = { - numpy.dtype("uint8"): (kDLUInt, 8, 1), - numpy.dtype("uint16"): (kDLUInt, 16, 1), - numpy.dtype("uint32"): (kDLUInt, 32, 1), - numpy.dtype("uint64"): (kDLUInt, 64, 1), - numpy.dtype("int8"): (kDLInt, 8, 1), - numpy.dtype("int16"): (kDLInt, 16, 1), - numpy.dtype("int32"): (kDLInt, 32, 1), - numpy.dtype("int64"): (kDLInt, 64, 1), - numpy.dtype("float16"): (kDLFloat, 16, 1), - numpy.dtype("float32"): (kDLFloat, 32, 1), - numpy.dtype("float64"): (kDLFloat, 64, 1), - numpy.dtype("complex64"): (kDLComplex, 64, 1), - numpy.dtype("complex128"): (kDLComplex, 128, 1), - numpy.dtype("bool"): (kDLBool, 8, 1), -} -_typestr2dtype = {dtype.str: dtype for dtype in _numpy2dlpack_dtype.keys()} -_typestr2itemsize = {dtype.str: dtype.itemsize for dtype in _numpy2dlpack_dtype.keys()} +_builtin_numeric_dtypes = [ + numpy.dtype("uint8"), + numpy.dtype("uint16"), + numpy.dtype("uint32"), + numpy.dtype("uint64"), + numpy.dtype("int8"), + numpy.dtype("int16"), + numpy.dtype("int32"), + numpy.dtype("int64"), + numpy.dtype("float16"), + numpy.dtype("float32"), + numpy.dtype("float64"), + numpy.dtype("complex64"), + numpy.dtype("complex128"), + numpy.dtype("bool"), +] +# Doing it once to avoid repeated overhead +_typestr2dtype = {dtype.str: dtype for dtype in _builtin_numeric_dtypes} +_typestr2itemsize = {dtype.str: dtype.itemsize for dtype in _builtin_numeric_dtypes} cdef object dtype_dlpack_to_numpy(DLDataType* dtype): From 1fa43d440c4fbb6bf978e2c98d6c0109dd697374 Mon Sep 17 00:00:00 2001 From: Kamil Tokarski Date: Wed, 26 Nov 2025 19:05:10 +0100 Subject: [PATCH 09/20] Revert dlpack changes Signed-off-by: Kamil Tokarski --- cuda_core/cuda/core/experimental/_memory/_buffer.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuda_core/cuda/core/experimental/_memory/_buffer.pyx b/cuda_core/cuda/core/experimental/_memory/_buffer.pyx index c596dbcc59..ac611d5cdd 100644 --- a/cuda_core/cuda/core/experimental/_memory/_buffer.pyx +++ b/cuda_core/cuda/core/experimental/_memory/_buffer.pyx @@ -203,7 +203,7 @@ cdef class Buffer: if not isinstance(max_version, tuple) or len(max_version) != 2: raise BufferError(f"Expected max_version tuple[int, int], got {max_version}") versioned = max_version >= (1, 0) - capsule = make_py_capsule(self, versioned, int(self.handle)) + capsule = make_py_capsule(self, versioned) return capsule def __dlpack_device__(self) -> tuple[int, int]: From 67c6c5e847cf1b1af89744beacb08c2a8c4b050d Mon Sep 17 00:00:00 2001 From: Kamil 
Tokarski Date: Thu, 27 Nov 2025 16:21:01 +0100 Subject: [PATCH 10/20] Support layouts up to 64 dims Signed-off-by: Kamil Tokarski --- cuda_core/cuda/core/experimental/_layout.pxd | 13 ++-- cuda_core/cuda/core/experimental/_layout.pyx | 14 ++-- .../cuda/core/experimental/include/layout.hpp | 4 +- cuda_core/tests/helpers/layout.py | 15 ++++- cuda_core/tests/test_strided_layout.py | 66 +++++++++++++++++-- 5 files changed, 90 insertions(+), 22 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_layout.pxd b/cuda_core/cuda/core/experimental/_layout.pxd index 4235587d3b..fe0e8d2e58 100644 --- a/cuda_core/cuda/core/experimental/_layout.pxd +++ b/cuda_core/cuda/core/experimental/_layout.pxd @@ -5,14 +5,14 @@ cimport cython from cython.operator cimport dereference as deref -from libc.stdint cimport int64_t, int32_t, uint32_t, uintptr_t +from libc.stdint cimport int64_t, int32_t, uint32_t, uint64_t, uintptr_t from libcpp cimport vector ctypedef int64_t extent_t ctypedef int64_t stride_t ctypedef int32_t axis_t -ctypedef uint32_t axes_mask_t # MUST be exactly STRIDED_LAYOUT_MAX_NDIM bits wide +ctypedef uint64_t axes_mask_t # MUST be exactly STRIDED_LAYOUT_MAX_NDIM bits wide ctypedef uint32_t property_mask_t ctypedef vector.vector[stride_t] extents_strides_t @@ -29,7 +29,7 @@ ctypedef fused integer_t: cdef extern from "include/layout.hpp": cdef int STRIDED_LAYOUT_MAX_NDIM - cdef int AXIS_MASK_ALL + cdef axes_mask_t AXIS_MASK_ALL int64_t _c_abs(int64_t x) nogil void _order_from_strides(axis_vec_t& indices, extent_t* extent_t, stride_t* stride_t, int ndim) except + nogil void _swap(extents_strides_t &a, extents_strides_t &b) noexcept nogil @@ -466,7 +466,7 @@ cdef inline stride_t _dense_strides_in_order(BaseLayout& base, axis_vec_t& strid axis = stride_order[i] if not _normalize_axis(axis, ndim): raise ValueError(f"Invalid stride order: axis {axis} out of range for {ndim}D tensor") - axis_mask = 1 << axis + axis_mask = _axis2mask(axis) if axis_order_mask & axis_mask: raise ValueError(f"The stride order must be a permutation. Axis {axis} appears multiple times.") axis_order_mask |= axis_mask @@ -605,6 +605,11 @@ cdef inline bint _set_boolean_property(StridedLayout self, Property prop, bint v # Conversion, validation and normalization helpers # ============================== + +cdef inline axes_mask_t _axis2mask(axis_t axis) noexcept nogil: + return 1ULL << axis + + cdef inline OrderFlag _stride_order2vec(axis_vec_t& stride_order_vec, object stride_order) except? 
ORDER_NONE: if stride_order == 'C': return ORDER_C diff --git a/cuda_core/cuda/core/experimental/_layout.pyx b/cuda_core/cuda/core/experimental/_layout.pyx index 93570687af..f1aa1ec78b 100644 --- a/cuda_core/cuda/core/experimental/_layout.pyx +++ b/cuda_core/cuda/core/experimental/_layout.pyx @@ -951,7 +951,7 @@ cdef inline axes_mask_t axis_mask_from_range(int ndim, int start_axis, int end_a if not _normalize_axis(end_axis, ndim): raise ValueError(f"Invalid end axis: {end_axis} out of range for {ndim}D tensor") if start_axis > 0: - axis_mask &= (AXIS_MASK_ALL << start_axis + 1) + axis_mask &= (AXIS_MASK_ALL << (start_axis + 1)) if end_axis < ndim: axis_mask &= (AXIS_MASK_ALL >> (STRIDED_LAYOUT_MAX_NDIM - end_axis - 1)) return axis_mask @@ -978,7 +978,7 @@ cdef inline int flatten_strides_in_c_index_order(BaseLayout& out_layout, BaseLay group_end = group_start + 1 while ( group_end < ndim - and (axis_mask & (1 << group_end)) + and (axis_mask & _axis2mask(group_end)) and group_stride == _overflow_checked_mul(in_strides[group_end], in_shape[group_end]) ): group_vol = _overflow_checked_mul(group_vol, in_shape[group_end]) @@ -1009,7 +1009,7 @@ cdef inline axes_mask_t flattened_strides_in_c_index_order_mask(BaseLayout& layo while group_end < ndim and group_stride == layout.strides[group_end] * layout.shape[group_end]: group_vol = _overflow_checked_mul(group_vol, layout.shape[group_end]) group_stride = layout.strides[group_end] - axis_mask |= (1 << group_end) + axis_mask |= _axis2mask(group_end) group_end += 1 group_start = group_end return axis_mask @@ -1060,7 +1060,7 @@ cdef inline int permute_extents(BaseLayout& out_layout, BaseLayout& in_layout, a axis = axis_order[i] if not _normalize_axis(axis, ndim): raise ValueError(f"Invalid permutation: axis {axis} out of range for {ndim}D tensor") - axis_mask = 1 << axis + axis_mask = _axis2mask(axis) if axis_order_mask & axis_mask: raise ValueError(f"Invalid permutation: axis {axis_order[i]} appears multiple times.") axis_order_mask |= axis_mask @@ -1156,15 +1156,13 @@ cdef inline int unsqueeze_extents(BaseLayout& out_layout, BaseLayout& in_layout, axis = axis_vec[i] if not _normalize_axis(axis, out_ndim): raise ValueError(f"Invalid axis: {axis} out of range for {out_ndim}D tensor") - axis_mask = 1 << axis + axis_mask = _axis2mask(axis) if out_shape_mask & axis_mask: raise ValueError(f"Axis {axis} appears multiple times.") out_shape_mask |= axis_mask cdef int in_i = 0 for i in range(out_ndim): - # without the cast, cython has issues with - # recognizing 1 << i does not require Python interaction - axis_mask = 1 << i + axis_mask = _axis2mask(i) if out_shape_mask & axis_mask: out_layout.shape[i] = 1 if in_i < ndim: diff --git a/cuda_core/cuda/core/experimental/include/layout.hpp b/cuda_core/cuda/core/experimental/include/layout.hpp index bed485ea01..58c408889e 100644 --- a/cuda_core/cuda/core/experimental/include/layout.hpp +++ b/cuda_core/cuda/core/experimental/include/layout.hpp @@ -11,8 +11,8 @@ #include -#define STRIDED_LAYOUT_MAX_NDIM 32 -#define AXIS_MASK_ALL 0xFFFFFFFE +#define STRIDED_LAYOUT_MAX_NDIM 64 +#define AXIS_MASK_ALL 0xFFFFFFFFFFFFFFFEULL inline int64_t _c_abs(int64_t x) { diff --git a/cuda_core/tests/helpers/layout.py b/cuda_core/tests/helpers/layout.py index 387a849e39..02fadac867 100644 --- a/cuda_core/tests/helpers/layout.py +++ b/cuda_core/tests/helpers/layout.py @@ -40,8 +40,12 @@ class _S: SliceSpec """ - def __init__(self): - self.slices = [] + def __init__(self, slices=None): + if slices is None: + slices = [] + else: + assert 
isinstance(slices, list) + self.slices = slices def __getitem__(self, value): self.slices.append(value) @@ -149,3 +153,10 @@ def inv_permutation(perm): def permuted(t, perm): return tuple(t[i] for i in perm) + + +def long_shape(rng, ndim, num_non_unit_dims=5, max_dim_size=6): + dims = [min(i + 2, max_dim_size) for i in range(num_non_unit_dims)] + dims.extend(1 for i in range(ndim - num_non_unit_dims)) + rng.shuffle(dims) + return tuple(dims) diff --git a/cuda_core/tests/test_strided_layout.py b/cuda_core/tests/test_strided_layout.py index 27c6c856cb..b783257054 100644 --- a/cuda_core/tests/test_strided_layout.py +++ b/cuda_core/tests/test_strided_layout.py @@ -18,6 +18,7 @@ dtype_from_itemsize, flatten_mask2str, inv_permutation, + long_shape, permuted, pretty_name, random_permutations, @@ -155,8 +156,16 @@ def test_dense_with_permutation_as_stride_order(layout_spec): "layout_spec", [ LayoutSpec(shape, py_rng.choice(_ITEMSIZES), stride_order, perm=permutation) - for shape in [tuple(), (1,), (2, 3), (5, 6, 7), (5, 1, 7), (5, 2, 3, 4)] - for permutation in random_permutations(py_rng, len(shape)) + for shape in [ + tuple(), + (1,), + (2, 3), + (5, 6, 7), + (5, 1, 7), + (5, 2, 3, 4), + long_shape(py_rng, 64), + ] + for permutation in random_permutations(py_rng, len(shape), sample_size=3) for stride_order in list(DenseOrder) ], ids=pretty_name, @@ -173,10 +182,44 @@ def test_permuted(layout_spec): layout, np_ref = _transform(layout, np_ref, layout_spec) _cmp_layout_and_array(layout, np_ref, layout_spec.has_no_strides_transformed()) _cmp_layout_from_dense_vs_from_np(layout, np_ref, layout_spec.has_no_strides_transformed()) - expected_order = inv_permutation(layout_spec.perm) - if layout_spec.np_order() == "F": - expected_order = tuple(reversed(expected_order)) - assert layout.stride_order == expected_order + unit_dims_count = sum(dim == 1 for dim in np_ref.shape) + if unit_dims_count <= 1: + # stride order with multiple unit dimensions is not unique + # a simple equality check won't do + expected_order = inv_permutation(layout_spec.perm) + if layout_spec.np_order() == "F": + expected_order = tuple(reversed(expected_order)) + assert layout.stride_order == expected_order + + +class PermutedErr(Enum): + REPEATED_AXIS = "axis -?\\d+ appears multiple times" + OUT_OF_RANGE = "axis -?\\d+ out of range for" + WRONG_LEN = "the same length as the number of dimensions" + + +@pytest.mark.parametrize( + ("layout_spec", "error_msg"), + [ + ( + LayoutSpec(shape, py_rng.choice(_ITEMSIZES), stride_order, perm=permutation), + error_msg, + ) + for shape, permutation, error_msg in [ + (tuple(), (5,), PermutedErr.WRONG_LEN), + ((1,), (0, 0), PermutedErr.WRONG_LEN), + ((2, 5, 3), (1, 0, 1), PermutedErr.REPEATED_AXIS), + ((5, 6, 7), (1, 3, 0), PermutedErr.OUT_OF_RANGE), + ((5, 6, 7), (1, -2000, 0), PermutedErr.OUT_OF_RANGE), + ] + for stride_order in list(DenseOrder) + ], + ids=pretty_name, +) +def test_permuted_validation(layout_spec, error_msg): + layout, _ = _setup_layout_and_np_ref(layout_spec) + with pytest.raises(ValueError, match=error_msg.value): + layout.permuted(layout_spec.perm) class SliceErr(Enum): @@ -216,6 +259,7 @@ class SliceErr(Enum): ((11, 12, 3), _S()[-2], None), ((11, 12, 3), _S()[-42], SliceErr.OUT_OF_RANGE), ((11, 12, 3), _S()["abc"], SliceErr.TYPE_ERROR), + (long_shape(py_rng, 64), _S([slice(None, None, -1)] * 64), None), ] for stride_order in list(DenseOrder) ], @@ -340,6 +384,8 @@ class ReshapeErr(Enum): ((7, 6, 5), None, None, (5, 0, -1), ReshapeErr.AMBIGUOUS_NEG_EXTENT), ((7, 0, 5), None, 
None, (5, 0, -1), ReshapeErr.AMBIGUOUS_NEG_EXTENT), ((7, 6, 5), None, None, map, ReshapeErr.TYPE_ERROR), + # random 64-dim shape with 5 non-unit extents 2, 3, 4, 5, 6 + (long_shape(py_rng, 64, 5, 6), None, None, (60, 12), None), ] for stride_order in [DenseOrder.C, DenseOrder.IMPLICIT_C] ], @@ -403,6 +449,8 @@ def test_reshape(layout_spec, new_shape, error_msg): ((5, 7, 4, 3), None, _S()[:, ::-1, ::-1], (5, 28, 3), (84, -3, 1), "0010"), ((5, 4, 3, 7), (2, 3, 0, 1), _S()[:], (21, 20), (1, 21), "0101"), ((5, 4, 3, 7), (3, 2, 0, 1), None, (7, 3, 20), (1, 7, 21), "0001"), + # random 64-dim shape with 4 non-unit extents 2, 3, 4, 5 + (long_shape(py_rng, 64, 4, 5), None, None, (120,), (1,), "0" + "1" * 63), ] for stride_order in [DenseOrder.C, DenseOrder.IMPLICIT_C] ], @@ -568,6 +616,8 @@ def test_flatten_together( ((1, 5, 4, 3), None, _S()[:, -1:, :1, 1:2]), ((7, 5, 3), (2, 0, 1), _S()[::-1, 3:2:-1, :]), ((7, 5, 3), (2, 0, 1), _S()[:, 3:2, :]), + (long_shape(py_rng, 64, 1), None, None), + (long_shape(py_rng, 33, 3), None, None), ] for stride_order in list(DenseOrder) ], @@ -662,6 +712,9 @@ def test_unsqueezed_layout(layout_spec, axes): ((11, 5, 9), None, _S()[:, :, -1:], DenseOrder.F, 2, 0, 2, 2), ((12, 3, 24), (1, 2, 0), _S()[::-1, 20:, 1:], DenseOrder.C, 2, 1, 8, 8), ((12, 3, 24), (1, 2, 0), _S()[1:, ::-1, 10:], DenseOrder.F, 2, 2, 4, 4), + ((1, 3) + (1,) * 61 + (4,), None, None, DenseOrder.C, 2, -1, 8, 8), + ((1, 3) + (1,) * 61 + (4,), None, None, DenseOrder.IMPLICIT_C, 2, -1, 8, 4), + ((4, 3) + (1,) * 61 + (3,), None, None, DenseOrder.F, 2, 0, 8, 4), ] ], ids=pretty_name, @@ -726,6 +779,7 @@ def test_packed_unpacked( ((5, 11), _S()[::-1, 3:4], (5, 30)), ((5, 11), _S()[::-1, 3:4], (4, 5, 12)), ((5, 11), _S()[-1:,], (4, 13, 11)), + ((2, 3, 3), _S()[:, 1:2], (401, 3) + (1,) * 59 + (2, 4, 3)), ] for stride_order in list(DenseOrder) ], From a96bec5090ff28abfe38e189ab8a2e2d8b19a7ee Mon Sep 17 00:00:00 2001 From: Kamil Tokarski Date: Thu, 27 Nov 2025 18:53:32 +0100 Subject: [PATCH 11/20] Use cydriver to query memory attributes, fix managed memory handling, add tests for the attributes Signed-off-by: Kamil Tokarski --- .../core/experimental/_memory/_buffer.pyx | 71 +++++++---- cuda_core/tests/test_memory.py | 118 ++++++++++++++++++ 2 files changed, 164 insertions(+), 25 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_memory/_buffer.pyx b/cuda_core/cuda/core/experimental/_memory/_buffer.pyx index ac611d5cdd..c31c5a8fe9 100644 --- a/cuda_core/cuda/core/experimental/_memory/_buffer.pyx +++ b/cuda_core/cuda/core/experimental/_memory/_buffer.pyx @@ -4,6 +4,7 @@ from __future__ import annotations +cimport cython from libc.stdint cimport uintptr_t from cuda.core.experimental._memory._device_memory_resource cimport DeviceMemoryResource @@ -12,7 +13,9 @@ from cuda.core.experimental._memory cimport _ipc from cuda.core.experimental._stream cimport Stream_accept, Stream from cuda.core.experimental._utils.cuda_utils cimport ( _check_driver_error as raise_if_driver_error, + HANDLE_RETURN, ) +from cuda.bindings cimport cydriver import abc from typing import TypeVar, Union @@ -310,46 +313,64 @@ cdef Buffer_init_mem_attrs(Buffer self): self._mem_attrs_inited = True -cdef int query_memory_attrs(_MemAttrs &out, uintptr_t ptr) except -1: - cdef int memory_type - ret, attrs = _query_memory_attrs(ptr) - if ret == driver.CUresult.CUDA_ERROR_NOT_INITIALIZED: - # Device class handles the cuInit call internally - from cuda.core.experimental import Device as _Device - _Device() - ret, attrs = 
_query_memory_attrs(ptr)
-    raise_if_driver_error(ret)
-    memory_type = attrs[0]
+cdef int query_memory_attrs(_MemAttrs &out, uintptr_t ptr) except -1 nogil:
+    cdef unsigned int memory_type = 0
+    cdef int is_managed = 0
+    cdef int device_id = 0
+    _query_memory_attrs(memory_type, is_managed, device_id, ptr)
     if memory_type == 0:
         # unregistered host pointer
         out.is_host_accessible = True
         out.is_device_accessible = False
         out.device_id = -1
+    # for managed memory, the memory type can be CU_MEMORYTYPE_DEVICE,
+    # so we need to check it first not to falsely claim it is not
+    # host accessible.
     elif (
-        memory_type == driver.CUmemorytype.CU_MEMORYTYPE_HOST
-        or memory_type == driver.CUmemorytype.CU_MEMORYTYPE_UNIFIED
+        is_managed
+        or memory_type == cydriver.CUmemorytype.CU_MEMORYTYPE_HOST
     ):
-        # TODO(ktokarski): should we compare host/device ptrs using cuPointerGetAttribute
-        # for exceptional cases when the same data can end up with different ptrs
-        # for host and device?
+        # For pinned memory allocated with cudaMallocHost or page-locked
+        # with cudaHostRegister, the memory_type is
+        # cydriver.CUmemorytype.CU_MEMORYTYPE_HOST.
+        # TODO(ktokarski): In some cases, the registered memory requires
+        # using different ptrs for device and host, we could check
+        # cuMemHostGetDevicePointer and
+        # CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM
+        # to double check the device accessibility.
         out.is_host_accessible = True
         out.is_device_accessible = True
-        out.device_id = attrs[1]
-    else:
-        # device/texture
+        out.device_id = device_id
+    elif memory_type == cydriver.CUmemorytype.CU_MEMORYTYPE_DEVICE:
         out.is_host_accessible = False
         out.is_device_accessible = True
-        out.device_id = attrs[1]
+        out.device_id = device_id
+    else:
+        raise ValueError(f"Unsupported memory type: {memory_type}")
     return 0
 
 
-cdef inline _query_memory_attrs(uintptr_t ptr):
-    cdef tuple attrs = (
-        driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
-        driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
-    )
-    return driver.cuPointerGetAttributes(len(attrs), attrs, ptr)
+cdef inline int _query_memory_attrs(unsigned int& memory_type, int & is_managed, int& device_id, cydriver.CUdeviceptr ptr) except -1 nogil:
+    cdef cydriver.CUpointer_attribute attrs[3]
+    cdef uintptr_t vals[3]
+    attrs[0] = cydriver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_MEMORY_TYPE
+    attrs[1] = cydriver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_IS_MANAGED
+    attrs[2] = cydriver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL
+    vals[0] = <uintptr_t>&memory_type
+    vals[1] = <uintptr_t>&is_managed
+    vals[2] = <uintptr_t>&device_id
+
+    cdef cydriver.CUresult ret
+    ret = cydriver.cuPointerGetAttributes(3, attrs, <void**>vals, ptr)
+    if ret == cydriver.CUresult.CUDA_ERROR_NOT_INITIALIZED:
+        with cython.gil:
+            # Device class handles the cuInit call internally
+            from cuda.core.experimental import Device
+            Device()
+        ret = cydriver.cuPointerGetAttributes(2, attrs, <void**>vals, ptr)
+    HANDLE_RETURN(ret)
+    return 0
 
 
 cdef class MemoryResource:
diff --git a/cuda_core/tests/test_memory.py b/cuda_core/tests/test_memory.py
index 796c12ea7d..268727d748 100644
--- a/cuda_core/tests/test_memory.py
+++ b/cuda_core/tests/test_memory.py
@@ -9,6 +9,7 @@
     from cuda.bindings import driver
 except ImportError:
     from cuda import cuda as driver
+
 try:
     import numpy as np
 except ImportError:
@@ -27,6 +28,9 @@
     VirtualMemoryResource,
     VirtualMemoryResourceOptions,
 )
+from cuda.core.experimental import (
+    system as ccx_system,
+)
 from cuda.core.experimental._dlpack import DLDeviceType
 from cuda.core.experimental._memory 
import IPCBufferDescriptor
 from cuda.core.experimental._utils.cuda_utils import handle_return
@@ -235,6 +239,120 @@ def test_buffer_close():
         buffer_close(DummyPinnedMemoryResource(device))
 
 
+def test_buffer_external_host():
+    a = (ctypes.c_byte * 20)()
+    ptr = ctypes.addressof(a)
+    buffer = Buffer.from_handle(ptr, 20, owner=ptr)
+    assert not buffer.is_device_accessible
+    assert buffer.is_host_accessible
+    assert buffer.device_id == -1
+    buffer.close()
+
+
+@pytest.mark.parametrize("change_device", [True, False])
+def test_buffer_external_device(change_device):
+    n = ccx_system.num_devices
+    if n < 1:
+        pytest.skip("No devices found")
+    dev_id = n - 1
+    d = Device(dev_id)
+    d.set_current()
+    buffer_ = d.allocate(size=32)
+
+    if change_device:
+        # let's switch to a different device if possible
+        # to make sure we get the original device id
+        d = Device(0)
+        d.set_current()
+
+    buffer = Buffer.from_handle(int(buffer_.handle), 32)
+    assert buffer.is_device_accessible
+    assert not buffer.is_host_accessible
+    assert buffer.device_id == dev_id
+    buffer.close()
+    buffer_.close()
+
+
+@pytest.mark.parametrize("change_device", [True, False])
+def test_buffer_external_pinned_alloc(change_device):
+    n = ccx_system.num_devices
+    if n < 1:
+        pytest.skip("No devices found")
+    dev_id = n - 1
+    d = Device(dev_id)
+    d.set_current()
+    mr = DummyPinnedMemoryResource(d)
+    buffer_ = mr.allocate(size=32)
+
+    if change_device:
+        # let's switch to a different device if possible
+        # to make sure we get the original device id
+        d = Device(0)
+        d.set_current()
+
+    buffer = Buffer.from_handle(int(buffer_.handle), 32)
+    assert buffer.is_device_accessible
+    assert buffer.is_host_accessible
+    assert buffer.device_id == dev_id
+    buffer.close()
+    buffer_.close()
+
+
+@pytest.mark.parametrize("change_device", [True, False])
+def test_buffer_external_pinned_registered(change_device):
+    n = ccx_system.num_devices
+    if n < 1:
+        pytest.skip("No devices found")
+    dev_id = n - 1
+    d = Device(dev_id)
+    d.set_current()
+    a = (ctypes.c_byte * 20)()
+    ptr = ctypes.addressof(a)
+
+    buffer = Buffer.from_handle(ptr, 20, owner=ptr)
+    assert not buffer.is_device_accessible
+    assert buffer.is_host_accessible
+    assert buffer.device_id == -1
+
+    handle_return(driver.cuMemHostRegister(ptr, 20, 0))
+    if change_device:
+        # let's switch to a different device if possible
+        # to make sure we get the original device id
+        d = Device(0)
+        d.set_current()
+
+    buffer = Buffer.from_handle(ptr, 20, owner=ptr)
+    assert buffer.is_device_accessible
+    assert buffer.is_host_accessible
+    assert buffer.device_id == dev_id
+    buffer.close()
+
+
+@pytest.mark.parametrize("change_device", [True, False])
+def test_buffer_external_managed(change_device):
+    n = ccx_system.num_devices
+    if n < 1:
+        pytest.skip("No devices found")
+    dev_id = n - 1
+    d = Device(dev_id)
+    d.set_current()
+    ptr = None
+    try:
+        ptr = handle_return(driver.cuMemAllocManaged(32, driver.CUmemAttach_flags.CU_MEM_ATTACH_GLOBAL.value))
+        if change_device:
+            # let's switch to a different device if possible
+            # to make sure we get the original device id
+            d = Device(0)
+            d.set_current()
+        buffer = Buffer.from_handle(ptr, 32)
+        assert buffer.is_device_accessible
+        assert buffer.is_host_accessible
+        assert buffer.device_id == dev_id
+    finally:
+        if ptr is not None:
+            handle_return(driver.cuMemFree(ptr))
+
+
 def test_buffer_dunder_dlpack():
     device = Device()
     device.set_current()
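
The subtlety patch 11 encodes is the branch order: CU_POINTER_ATTRIBUTE_IS_MANAGED has to be consulted before the reported memory type, because managed allocations can report CU_MEMORYTYPE_DEVICE even though they are host accessible. The same classification can be reproduced with the Python-level driver bindings; a minimal sketch, assuming an initialized context and a valid pointer ptr (classify is an illustrative helper, not part of the patch):

    from cuda.bindings import driver
    from cuda.core.experimental._utils.cuda_utils import handle_return

    def classify(ptr):
        attrs = (
            driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
            driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_IS_MANAGED,
            driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
        )
        memory_type, is_managed, device_id = handle_return(
            driver.cuPointerGetAttributes(len(attrs), attrs, ptr)
        )
        if int(memory_type) == 0:
            # plain malloc/ctypes memory, unknown to the driver
            return "unregistered host", -1
        # managed memory may report CU_MEMORYTYPE_DEVICE, so test is_managed first
        if is_managed or memory_type == driver.CUmemorytype.CU_MEMORYTYPE_HOST:
            return "host+device", device_id
        return "device only", device_id
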
From 91387b063ea64a300f1757ca2d34ea8e96e00484 Mon Sep 17 00:00:00 2001
From: Kamil Tokarski
Date: Thu, 27 Nov 2025 18:55:33 +0100
Subject: [PATCH 12/20] Test owner and mr cannot be specified together
Signed-off-by: Kamil Tokarski
---
 cuda_core/tests/test_memory.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/cuda_core/tests/test_memory.py b/cuda_core/tests/test_memory.py
index 268727d748..9a6b18d968 100644
--- a/cuda_core/tests/test_memory.py
+++ b/cuda_core/tests/test_memory.py
@@ -353,6 +353,13 @@ def test_buffer_external_managed(change_device):
             handle_return(driver.cuMemFree(ptr))
 
 
+def test_memory_resource_and_owner_disallowed():
+    with pytest.raises(ValueError, match="cannot be both specified together"):
+        a = (ctypes.c_byte * 20)()
+        ptr = ctypes.addressof(a)
+        Buffer.from_handle(ptr, 20, mr=DummyDeviceMemoryResource(Device()), owner=a)
+
+
 def test_buffer_dunder_dlpack():
     device = Device()
     device.set_current()

From 91c0af9938a142f4ebfd5419486268dd9552beed Mon Sep 17 00:00:00 2001
From: Kamil Tokarski
Date: Thu, 27 Nov 2025 18:57:23 +0100
Subject: [PATCH 13/20] Test Buffer.close with owner
Signed-off-by: Kamil Tokarski
---
 cuda_core/tests/test_memory.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/cuda_core/tests/test_memory.py b/cuda_core/tests/test_memory.py
index 9a6b18d968..ce61d8ad80 100644
--- a/cuda_core/tests/test_memory.py
+++ b/cuda_core/tests/test_memory.py
@@ -360,6 +360,17 @@ def test_memory_resource_and_owner_disallowed():
         Buffer.from_handle(ptr, 20, mr=DummyDeviceMemoryResource(Device()), owner=a)
 
 
+def test_owner_close():
+    a = (ctypes.c_byte * 20)()
+    ptr = ctypes.addressof(a)
+    before = sys.getrefcount(a)
+    buffer = Buffer.from_handle(ptr, 20, owner=a)
+    assert sys.getrefcount(a) != before
+    buffer.close()
+    after = sys.getrefcount(a)
+    assert after == before
+
+
 def test_buffer_dunder_dlpack():
     device = Device()
     device.set_current()

From b74ef2c8a7c72b7f946a90d6aaf06d9fd36f4d92 Mon Sep 17 00:00:00 2001
From: Kamil Tokarski
Date: Thu, 27 Nov 2025 19:45:48 +0100
Subject: [PATCH 14/20] Add envelope checks (required_size_in_bytes, offset_bounds)
Signed-off-by: Kamil Tokarski
---
 cuda_core/tests/test_strided_layout.py | 40 ++++++++++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/cuda_core/tests/test_strided_layout.py b/cuda_core/tests/test_strided_layout.py
index b783257054..a88ac67999 100644
--- a/cuda_core/tests/test_strided_layout.py
+++ b/cuda_core/tests/test_strided_layout.py
@@ -120,6 +120,30 @@ def _cmp_layout_from_dense_vs_from_np(layout: StridedLayout, np_ref: np.ndarray,
         assert layout_from_np.strides_in_bytes == layout.strides_in_bytes
 
 
+def _check_envelope(layout: StridedLayout, layout_spec: LayoutSpec):
+    original_vol = math.prod(layout_spec.shape)
+    min_offset, max_offset = layout.offset_bounds
+    if layout.volume == 0:
+        assert min_offset == 0
+        assert max_offset == -1
+    else:
+        assert min_offset >= 0
+        assert min_offset <= max_offset
+        assert max_offset <= original_vol - 1
+        if layout.is_dense:
+            assert min_offset == 0
+            assert max_offset == math.prod(layout.shape) - 1
+        else:
+            shape, strides = layout.shape, layout.strides
+            ref_min_offset = ref_max_offset = layout.slice_offset
+            ref_min_offset += sum(strides[i] * (shape[i] - 1) for i in range(layout.ndim) if strides[i] < 0)
+            ref_max_offset += sum(strides[i] * (shape[i] - 1) for i in range(layout.ndim) if strides[i] > 0)
+            assert min_offset == ref_min_offset
+            assert max_offset == ref_max_offset
+    assert 0 <= layout.required_size_in_bytes() <= original_vol * layout_spec.itemsize
+    assert layout.required_size_in_bytes() == (max_offset + 1) * layout.itemsize
+
+
 def _cmp_slice_offset(
     layout_0: 
StridedLayout, layout_1: StridedLayout, @@ -149,6 +173,7 @@ def test_dense_with_permutation_as_stride_order(layout_spec): layout, np_ref = _setup_layout_and_np_ref(layout_spec) _cmp_layout_and_array(layout, np_ref, False) _cmp_layout_from_dense_vs_from_np(layout, np_ref, False) + _check_envelope(layout, layout_spec) assert layout.stride_order == layout_spec.stride_order @@ -182,6 +207,7 @@ def test_permuted(layout_spec): layout, np_ref = _transform(layout, np_ref, layout_spec) _cmp_layout_and_array(layout, np_ref, layout_spec.has_no_strides_transformed()) _cmp_layout_from_dense_vs_from_np(layout, np_ref, layout_spec.has_no_strides_transformed()) + _check_envelope(layout, layout_spec) unit_dims_count = sum(dim == 1 for dim in np_ref.shape) if unit_dims_count <= 1: # stride order with multiple unit dimensions is not unique @@ -277,6 +303,7 @@ def test_slice(layout_spec, error_msg): _cmp_layout_and_array(sliced, sliced_ref, False) _cmp_layout_from_dense_vs_from_np(sliced, sliced_ref, False) _cmp_slice_offset(layout, sliced, np_ref, sliced_ref) + _check_envelope(sliced, layout_spec) layout = sliced np_ref = sliced_ref else: @@ -406,6 +433,7 @@ def test_reshape(layout_spec, new_shape, error_msg): reshaped_ref = np_ref.reshape(new_shape, copy=False) _cmp_layout_and_array(reshaped, reshaped_ref, False) _cmp_layout_from_dense_vs_from_np(reshaped, reshaped_ref, False) + _check_envelope(reshaped, layout_spec) else: # sanity check that numpy is not able to reshape without # a copy as well @@ -473,6 +501,7 @@ def test_flatten( layout, np_ref = _transform(layout, np_ref, layout_spec) _cmp_layout_and_array(layout, np_ref, layout_spec.has_no_strides_transformed()) _cmp_layout_from_dense_vs_from_np(layout, np_ref, layout_spec.has_no_strides_transformed()) + _check_envelope(layout, layout_spec) mask = flatten_mask2str(layout.flattened_axis_mask(), layout.ndim) assert mask == expected_axis_mask @@ -583,6 +612,8 @@ def test_flatten_together( flattened_0 = layout_0.flattened(mask=mask) flattened_1 = layout_1.flattened(mask=mask) + _check_envelope(flattened_0, layout_spec_0) + _check_envelope(flattened_1, layout_spec_1) for flattened, expected_layout in zip([flattened_0, flattened_1], [expected_layout_0, expected_layout_1]): assert flattened == expected_layout @@ -640,6 +671,7 @@ def test_squeezed(layout_spec): assert squeezed.shape == (0,) assert squeezed.strides == (0,) assert squeezed.slice_offset == layout.slice_offset + _check_envelope(squeezed, layout_spec) @pytest.mark.parametrize( @@ -681,6 +713,7 @@ def test_unsqueezed_layout(layout_spec, axes): has_no_strides = layout_spec.has_no_strides_transformed() and len(axes) == 0 _cmp_layout_and_array(unsqueezed, unsqueezed_ref, has_no_strides) _cmp_layout_from_dense_vs_from_np(unsqueezed, unsqueezed_ref, has_no_strides) + _check_envelope(unsqueezed, layout_spec) @pytest.mark.parametrize( @@ -748,11 +781,13 @@ def test_packed_unpacked( has_no_strides = layout_spec.has_no_strides_transformed() and layout.itemsize == new_itemsize _cmp_layout_and_array(packed, packed_ref, has_no_strides) _cmp_layout_from_dense_vs_from_np(packed, packed_ref, has_no_strides) + _check_envelope(packed, layout_spec) vec_size = new_itemsize // layout.itemsize assert packed.slice_offset * vec_size == layout.slice_offset unpacked = packed.repacked(layout.itemsize, axis=axis) _cmp_layout_and_array(unpacked, np_ref, has_no_strides) _cmp_layout_from_dense_vs_from_np(unpacked, np_ref, has_no_strides) + _check_envelope(unpacked, layout_spec) @pytest.mark.parametrize( @@ -801,6 +836,7 @@ def 
test_broadcast_layout(
     broadcasted_ref = np.broadcast_to(np_ref, new_shape)
     _cmp_layout_and_array(broadcasted, broadcasted_ref, False)
     _cmp_layout_from_dense_vs_from_np(broadcasted, broadcasted_ref, False)
+    _check_envelope(broadcasted, layout_spec)
     assert layout.is_unique
     ndim_diff = len(broadcasted_ref.shape) - len(np_ref.shape)
     expect_unique = all(broadcasted_ref.shape[i] == 1 for i in range(ndim_diff))
@@ -875,3 +911,7 @@ def test_to_dense(layout_spec, new_stride_order):
         dense_ref = np_ref.transpose(new_stride_order).copy(order="C").transpose(inv_permutation(new_stride_order))
         _cmp_layout_and_array(dense, dense_ref, False)
         _cmp_layout_from_dense_vs_from_np(dense, dense_ref, False)
+
+    assert dense.is_dense
+    assert dense.required_size_in_bytes() == np_ref.size * layout.itemsize
+    assert dense.offset_bounds == (0, np_ref.size - 1)

From 2c0343ffc6687b233d19ce8c41031506f310a349 Mon Sep 17 00:00:00 2001
From: Kamil Tokarski
Date: Thu, 27 Nov 2025 20:36:44 +0100
Subject: [PATCH 15/20] Docs, annotation fixes, remove dlpack export mentions
Signed-off-by: Kamil Tokarski
---
 cuda_core/cuda/core/experimental/_memoryview.pyx | 16 +++++-----------
 1 file changed, 5 insertions(+), 11 deletions(-)

diff --git a/cuda_core/cuda/core/experimental/_memoryview.pyx b/cuda_core/cuda/core/experimental/_memoryview.pyx
index b6a88d07fb..dba77f00ad 100644
--- a/cuda_core/cuda/core/experimental/_memoryview.pyx
+++ b/cuda_core/cuda/core/experimental/_memoryview.pyx
@@ -58,12 +58,6 @@ cdef class StridedMemoryView:
         Stream ordering will be properly established unless ``-1`` is passed.
 
-    .. note::
-        The StridedMemoryView can be exported with DLPack, in which case
-        no synchronization is performed. It is the user's responsibility to ensure
-        the data in ``exporting_obj`` is properly synchronized when consuming the view.
-
-
     Attributes
     -----------
     ptr : int
@@ -168,12 +162,12 @@ cdef class StridedMemoryView:
             The layout describing the shape, strides and itemsize of the
             elements in the buffer.
         dtype : :obj:`numpy.dtype`, optional
-            Optional dtype. The dtype is required for exporting the view with DLPack.
+            Optional dtype.
             If specified, the dtype's itemsize must match the layout's itemsize.
             To view the buffer with a different itemsize, please use
             :meth:`StridedLayout.repacked` first to transform the layout
             to the desired itemsize.
         is_readonly : bool, optional
-            Whether the mark the view as readonly. The flag will be forwarded to the DLPack capsule.
+            Whether to mark the view as readonly.
         """
         cdef StridedMemoryView view = StridedMemoryView.__new__(cls)
         view_buffer_strided(view, buffer, layout, dtype, is_readonly)
@@ -184,7 +178,7 @@ cdef class StridedMemoryView:
     ) -> StridedMemoryView:
         """
         Creates a new view with adjusted layout and dtype.
-        Same as calling :meth:`from_buffer` with the current buffer, layout and dtype.
+        Same as calling :meth:`from_buffer` with the current buffer. 
""" cdef StridedMemoryView view = StridedMemoryView.__new__(self.__class__) if layout is None and dtype is None: @@ -198,7 +192,7 @@ cdef class StridedMemoryView: def copy_from( self, other : StridedMemoryView, stream : Stream, - allocator : MemoryResource | None = None, + allocator = None, blocking : bool | None = None, ): """ @@ -240,7 +234,7 @@ cdef class StridedMemoryView: def copy_to( self, other : StridedMemoryView, stream : Stream | None = None, - allocator : MemoryResource | None = None, + allocator = None, blocking : bool | None = None, ): """ From 598a2f140893f8962c7426e3d387f59789f41a27 Mon Sep 17 00:00:00 2001 From: Kamil Tokarski Date: Thu, 27 Nov 2025 21:13:16 +0100 Subject: [PATCH 16/20] Add SMV.from_buffer/view tests Signed-off-by: Kamil Tokarski --- cuda_core/tests/test_utils.py | 114 +++++++++++++++++++++++++++++++++- 1 file changed, 113 insertions(+), 1 deletion(-) diff --git a/cuda_core/tests/test_utils.py b/cuda_core/tests/test_utils.py index 3580507250..5b14b31018 100644 --- a/cuda_core/tests/test_utils.py +++ b/cuda_core/tests/test_utils.py @@ -2,6 +2,8 @@ # # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE +import math + try: import cupy as cp except ImportError: @@ -15,7 +17,7 @@ import pytest from cuda.core.experimental import Device from cuda.core.experimental._memoryview import view_as_cai -from cuda.core.experimental.utils import StridedMemoryView, args_viewable_as_strided_memory +from cuda.core.experimental.utils import StridedLayout, StridedMemoryView, args_viewable_as_strided_memory def test_cast_to_3_tuple_success(): @@ -195,3 +197,113 @@ def _check_view(self, view, in_arr, dev): assert view.device_id == dev.device_id assert view.is_device_accessible is True assert view.exporting_obj is in_arr + + +def _dense_strides(shape, stride_order): + ndim = len(shape) + strides = [None] * ndim + if ndim > 0: + if stride_order == "C": + strides[-1] = 1 + for i in range(ndim - 2, -1, -1): + strides[i] = strides[i + 1] * shape[i + 1] + else: + assert stride_order == "F" + strides[0] = 1 + for i in range(1, ndim): + strides[i] = strides[i - 1] * shape[i - 1] + return tuple(strides) + + +@pytest.mark.parametrize("shape", [tuple(), (2, 3), (10, 10), (10, 13, 11)]) +@pytest.mark.parametrize("itemsize", [1, 4]) +@pytest.mark.parametrize("stride_order", ["C", "F"]) +@pytest.mark.parametrize("readonly", [True, False]) +def test_from_buffer(shape, itemsize, stride_order, readonly): + dev = Device() + dev.set_current() + layout = StridedLayout.dense(shape=shape, itemsize=itemsize, stride_order=stride_order) + required_size = layout.required_size_in_bytes() + assert required_size == math.prod(shape) * itemsize + buffer = dev.memory_resource.allocate(required_size) + view = StridedMemoryView.from_buffer(buffer, layout, is_readonly=readonly) + assert view.exporting_obj is buffer + assert view.layout is layout + assert view.ptr == int(buffer.handle) + assert view.shape == shape + assert view.strides == _dense_strides(shape, stride_order) + assert view.dtype is None + assert view.device_id == dev.device_id + assert view.is_device_accessible + assert view.readonly == readonly + + +@pytest.mark.parametrize("stride_order", ["C", "F"]) +def test_from_buffer_sliced(stride_order): + layout = StridedLayout.dense((5, 7), 2, stride_order=stride_order) + device = Device() + device.set_current() + buffer = device.memory_resource.allocate(layout.required_size_in_bytes()) + view = StridedMemoryView.from_buffer(buffer, layout) + assert view.shape == (5, 7) + + sliced_view = 
view.view(layout[:-2, 3:]) + assert sliced_view.shape == (3, 4) + expected_offset = 3 if stride_order == "C" else 3 * 5 + assert sliced_view.layout.slice_offset == expected_offset + assert sliced_view.layout.slice_offset_in_bytes == expected_offset * 2 + assert sliced_view.ptr == view.ptr + expected_offset * 2 + + +def test_from_buffer_too_small(): + layout = StridedLayout.dense((5, 4), 2) + d = Device() + d.set_current() + buffer = d.memory_resource.allocate(20) + with pytest.raises(ValueError, match="Expected at least 40 bytes, got 20 bytes."): + StridedMemoryView.from_buffer(buffer, layout) + + +def test_from_buffer_disallowed_negative_offset(): + layout = StridedLayout((5, 4), (-4, 1), 1) + d = Device() + d.set_current() + buffer = d.memory_resource.allocate(20) + with pytest.raises(ValueError, match="please use StridedLayout.to_dense()."): + StridedMemoryView.from_buffer(buffer, layout) + + +@pytest.mark.parametrize( + ("shape", "slices", "stride_order"), + [ + (shape, slices, stride_order) + for shape, slices in [ + ((5, 6), (2, slice(1, -1))), + ((10, 13, 11), (slice(None, None, 2), slice(None, None, -1), slice(2, -3))), + ] + for stride_order in ["C", "F"] + ], +) +def test_from_buffer_sliced_external(shape, slices, stride_order): + if np is None: + pytest.skip("NumPy is not installed") + a = np.arange(math.prod(shape), dtype=np.int32).reshape(shape, order=stride_order) + view = StridedMemoryView(a, -1) + layout = view.layout + assert layout.is_dense + assert layout.required_size_in_bytes() == a.nbytes + assert view.ptr == a.ctypes.data + + sliced_layout = layout[slices] + sliced_view = view.view(sliced_layout) + a_sliced = a[slices] + assert sliced_view.ptr == a_sliced.ctypes.data + assert sliced_view.ptr != view.ptr + + assert 0 <= sliced_layout.required_size_in_bytes() <= a.nbytes + assert not sliced_layout.is_dense + assert sliced_view.layout is sliced_layout + assert view.dtype == sliced_view.dtype + assert sliced_view.layout.itemsize == a_sliced.itemsize == layout.itemsize + assert sliced_view.shape == a_sliced.shape + assert sliced_view.layout.strides_in_bytes == a_sliced.strides From bbb227b793ad0d8416bdb08861b51f50e68c43bf Mon Sep 17 00:00:00 2001 From: Kamil Tokarski Date: Fri, 28 Nov 2025 12:28:01 +0100 Subject: [PATCH 17/20] Layout tests for SMV created from CAI Signed-off-by: Kamil Tokarski --- cuda_core/tests/test_utils.py | 39 +++++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/cuda_core/tests/test_utils.py b/cuda_core/tests/test_utils.py index 5b14b31018..9d72b98af0 100644 --- a/cuda_core/tests/test_utils.py +++ b/cuda_core/tests/test_utils.py @@ -273,31 +273,52 @@ def test_from_buffer_disallowed_negative_offset(): StridedMemoryView.from_buffer(buffer, layout) +class _EnforceCAIView: + def __init__(self, array): + self.array = array + self.__cuda_array_interface__ = array.__cuda_array_interface__ + + +def _get_ptr(array): + if isinstance(array, np.ndarray): + return array.ctypes.data + else: + assert isinstance(array, cp.ndarray) + return array.data.ptr + + @pytest.mark.parametrize( - ("shape", "slices", "stride_order"), + ("shape", "slices", "stride_order", "view_as"), [ - (shape, slices, stride_order) + (shape, slices, stride_order, view_as) for shape, slices in [ ((5, 6), (2, slice(1, -1))), ((10, 13, 11), (slice(None, None, 2), slice(None, None, -1), slice(2, -3))), ] for stride_order in ["C", "F"] + for view_as in ["dlpack", "cai"] ], ) -def test_from_buffer_sliced_external(shape, slices, stride_order): - if np is 

From 26dfe3bfe01844a2bcfabd5619407e0eb69ef08a Mon Sep 17 00:00:00 2001
From: Kamil Tokarski
Date: Mon, 1 Dec 2025 14:33:04 +0100
Subject: [PATCH 18/20] Fix missing host unregister call in buffer test

Signed-off-by: Kamil Tokarski

---
 cuda_core/tests/test_memory.py | 25 ++++++++++++++-----------
 1 file changed, 14 insertions(+), 11 deletions(-)

diff --git a/cuda_core/tests/test_memory.py b/cuda_core/tests/test_memory.py
index ce61d8ad80..77f34ab72b 100644
--- a/cuda_core/tests/test_memory.py
+++ b/cuda_core/tests/test_memory.py
@@ -242,7 +242,7 @@ def test_buffer_close():
 def test_buffer_external_host():
     a = (ctypes.c_byte * 20)()
     ptr = ctypes.addressof(a)
-    buffer = Buffer.from_handle(ptr, 20, owner=ptr)
+    buffer = Buffer.from_handle(ptr, 20, owner=a)
     assert not buffer.is_device_accessible
     assert buffer.is_host_accessible
     assert buffer.device_id == -1
@@ -315,17 +315,20 @@ def test_buffer_external_pinned_registered(change_device):
     assert buffer.device_id == -1
 
     handle_return(driver.cuMemHostRegister(ptr, 20, 0))
-    if change_device:
-        # let's switch to a different device if possibe
-        # to make sure we get the original device id
-        d = Device(0)
-        d.set_current()
+    try:
+        if change_device:
+            # let's switch to a different device if possible
+            # to make sure we get the original device id
+            d = Device(0)
+            d.set_current()
 
-    buffer = Buffer.from_handle(ptr, 20, owner=ptr)
-    assert buffer.is_device_accessible
-    assert buffer.is_host_accessible
-    assert buffer.device_id == dev_id
-    buffer.close()
+        buffer = Buffer.from_handle(ptr, 20, owner=ptr)
+        assert buffer.is_device_accessible
+        assert buffer.is_host_accessible
+        assert buffer.device_id == dev_id
+        buffer.close()
+    finally:
+        handle_return(driver.cuMemHostUnregister(ptr))
 
 
 @pytest.mark.parametrize("change_device", [True, False])
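
Note: the try/finally above guarantees cuMemHostUnregister runs even when an
assertion fails; otherwise later tests would inherit a page-locked region they
know nothing about. If more tests end up needing a registered host pointer, the
pairing could be factored into a fixture. A sketch under the same names
test_memory.py already imports (ctypes, pytest, driver, handle_return):

    @pytest.fixture
    def registered_host_ptr():
        # 20 bytes of host memory, page-locked for the duration of one test.
        a = (ctypes.c_byte * 20)()
        ptr = ctypes.addressof(a)
        handle_return(driver.cuMemHostRegister(ptr, 20, 0))
        try:
            # Keep `a` alive while the test runs; hand out the raw address.
            yield ptr
        finally:
            handle_return(driver.cuMemHostUnregister(ptr))
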

From 3adae5c5665f61b3e15b0670f41feb9b481bff0b Mon Sep 17 00:00:00 2001
From: Kamil Tokarski
Date: Mon, 1 Dec 2025 15:53:09 +0100
Subject: [PATCH 19/20] Fix num attrib on re-try

Signed-off-by: Kamil Tokarski

---
 cuda_core/cuda/core/experimental/_memory/_buffer.pyx | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cuda_core/cuda/core/experimental/_memory/_buffer.pyx b/cuda_core/cuda/core/experimental/_memory/_buffer.pyx
index c31c5a8fe9..c8a3b49c05 100644
--- a/cuda_core/cuda/core/experimental/_memory/_buffer.pyx
+++ b/cuda_core/cuda/core/experimental/_memory/_buffer.pyx
@@ -368,7 +368,7 @@ cdef inline int _query_memory_attrs(unsigned int& memory_type, int & is_managed,
         # Device class handles the cuInit call internally
         from cuda.core.experimental import Device
         Device()
-    ret = cydriver.cuPointerGetAttributes(2, attrs, vals, ptr)
+    ret = cydriver.cuPointerGetAttributes(3, attrs, vals, ptr)
     HANDLE_RETURN(ret)
     return 0
 

From 7554164409fb731fed747fe6f1012c374e438bbd Mon Sep 17 00:00:00 2001
From: Kamil Tokarski
Date: Mon, 1 Dec 2025 15:56:40 +0100
Subject: [PATCH 20/20] Call int on the buffer.handle

Signed-off-by: Kamil Tokarski

---
 cuda_core/cuda/core/experimental/_memoryview.pyx | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cuda_core/cuda/core/experimental/_memoryview.pyx b/cuda_core/cuda/core/experimental/_memoryview.pyx
index dba77f00ad..10d59c6763 100644
--- a/cuda_core/cuda/core/experimental/_memoryview.pyx
+++ b/cuda_core/cuda/core/experimental/_memoryview.pyx
@@ -635,7 +635,7 @@ cdef inline uintptr_t _get_data_ptr(object buffer, StridedLayout layout) except?
             f"Expected at least {layout.get_required_size_in_bytes()} bytes, "
             f"got {buffer.size} bytes."
         )
-    return <uintptr_t>(buffer.handle) + layout.get_slice_offset_in_bytes()
+    return <uintptr_t>(int(buffer.handle)) + layout.get_slice_offset_in_bytes()
 
 
 cdef inline int view_buffer_strided(