Skip to content

Commit fab878e

Browse files
committed
Adjust flattening scalars to numpy/cupy behavior, fix shape validation in reshape, fix to dense with sliced views
Signed-off-by: Kamil Tokarski <ktokarski@nvidia.com>
1 parent c64b6d4 commit fab878e

File tree

1 file changed

+23
-13
lines changed

1 file changed

+23
-13
lines changed

cuda_core/cuda/core/experimental/_layout.pyx

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,10 @@ cdef class StridedLayout:
135135
"""
136136
cdef OrderFlag order_flag
137137
cdef axis_vec_t stride_order_vec
138+
cdef bint is_dense = other.get_is_dense()
138139

139140
if stride_order == "K":
140-
if other.get_is_dense():
141+
if is_dense:
141142
return other
142143
other.get_stride_order(stride_order_vec)
143144
order_flag = ORDER_PERM
@@ -149,10 +150,10 @@ cdef class StridedLayout:
149150
f"or a permutation tuple. Got: {stride_order}"
150151
)
151152
elif order_flag == ORDER_C:
152-
if other.get_is_contiguous_c():
153+
if is_dense and other.get_is_contiguous_c():
153154
return other
154155
elif order_flag == ORDER_F:
155-
if other.get_is_contiguous_f():
156+
if is_dense and other.get_is_contiguous_f():
156157
return other
157158

158159
cdef StridedLayout new_layout = StridedLayout.__new__(cls)
@@ -928,11 +929,12 @@ cdef inline int validate_reshaped_shape(BaseLayout& new_shape, int64_t old_volum
928929
else:
929930
raise ValueError("There can be at most one -1 extent in a shape")
930931
cdef int64_t new_volume = _c_abs(_volume(new_shape))
931-
if new_volume == 0 and axis != -1:
932-
raise ValueError("The -1 extent is ambiguous when the volume is 0")
933-
if new_volume != old_volume:
934-
if axis == -1:
932+
if axis == -1:
933+
if new_volume != old_volume:
935934
raise ValueError(f"The original volume {old_volume} and the new volume {new_volume} must be equal.")
935+
else:
936+
if new_volume == 0:
937+
raise ValueError("The -1 extent is ambiguous when the specified sub-volume is 0")
936938
extent = old_volume // new_volume
937939
if extent * new_volume != old_volume:
938940
raise ValueError(f"The original volume {old_volume} must be divisible by the specified sub-volume {new_volume}.")
@@ -957,6 +959,11 @@ cdef inline axes_mask_t axis_mask_from_range(int ndim, int start_axis, int end_a
957959

958960
cdef inline int flatten_strides_in_c_index_order(BaseLayout& out_layout, BaseLayout& in_layout, axes_mask_t axis_mask) except -1 nogil:
959961
cdef int ndim = in_layout.ndim
962+
if ndim == 0:
963+
init_base_layout(out_layout, 1)
964+
out_layout.shape[0] = 1
965+
out_layout.strides[0] = 1
966+
return 1
960967
init_base_layout(out_layout, ndim)
961968
cdef int group_start = 0
962969
cdef int group_end = 0
@@ -1021,16 +1028,19 @@ cdef inline bint split_strides_in_c_index_order(BaseLayout& out_layout, BaseLayo
10211028
_zero_strides(out_layout)
10221029
while i >= 0:
10231030
extent = in_shape[i]
1024-
group_vol = 1
10251031
group_stride = in_strides[i]
1026-
while new_i >= 0 and group_vol < extent:
1032+
group_vol = 1
1033+
while new_i >= 0:
10271034
new_extent = out_layout.shape[new_i]
10281035
if new_extent == 0:
10291036
return False
1030-
group_vol = _overflow_checked_mul(group_vol, new_extent)
1031-
out_layout.strides[new_i] = group_stride
1032-
group_stride = _overflow_checked_mul(group_stride, new_extent)
1033-
new_i -= 1
1037+
if new_extent == 1 or group_vol < extent:
1038+
out_layout.strides[new_i] = group_stride
1039+
group_stride = _overflow_checked_mul(group_stride, new_extent)
1040+
group_vol = _overflow_checked_mul(group_vol, new_extent)
1041+
new_i -= 1
1042+
else:
1043+
break
10341044
if group_vol != extent:
10351045
return False
10361046
i -= 1

0 commit comments

Comments
 (0)