@@ -135,9 +135,10 @@ cdef class StridedLayout:
135135 """
136136 cdef OrderFlag order_flag
137137 cdef axis_vec_t stride_order_vec
138+ cdef bint is_dense = other.get_is_dense()
138139
139140 if stride_order == "K":
140- if other.get_is_dense() :
141+ if is_dense :
141142 return other
142143 other.get_stride_order(stride_order_vec )
143144 order_flag = ORDER_PERM
@@ -149,10 +150,10 @@ cdef class StridedLayout:
149150 f"or a permutation tuple. Got: {stride_order}"
150151 )
151152 elif order_flag == ORDER_C:
152- if other.get_is_contiguous_c():
153+ if is_dense and other.get_is_contiguous_c():
153154 return other
154155 elif order_flag == ORDER_F:
155- if other.get_is_contiguous_f():
156+ if is_dense and other.get_is_contiguous_f():
156157 return other
157158
158159 cdef StridedLayout new_layout = StridedLayout.__new__ (cls )
@@ -928,11 +929,12 @@ cdef inline int validate_reshaped_shape(BaseLayout& new_shape, int64_t old_volum
928929 else :
929930 raise ValueError (" There can be at most one -1 extent in a shape" )
930931 cdef int64_t new_volume = _c_abs(_volume(new_shape))
931- if new_volume == 0 and axis != - 1 :
932- raise ValueError (" The -1 extent is ambiguous when the volume is 0" )
933- if new_volume != old_volume:
934- if axis == - 1 :
932+ if axis == - 1 :
933+ if new_volume != old_volume:
935934 raise ValueError (f" The original volume {old_volume} and the new volume {new_volume} must be equal." )
935+ else :
936+ if new_volume == 0 :
937+ raise ValueError (" The -1 extent is ambiguous when the specified sub-volume is 0" )
936938 extent = old_volume // new_volume
937939 if extent * new_volume != old_volume:
938940 raise ValueError (f" The original volume {old_volume} must be divisible by the specified sub-volume {new_volume}." )
@@ -957,6 +959,11 @@ cdef inline axes_mask_t axis_mask_from_range(int ndim, int start_axis, int end_a
957959
958960cdef inline int flatten_strides_in_c_index_order(BaseLayout& out_layout, BaseLayout& in_layout, axes_mask_t axis_mask) except - 1 nogil:
959961 cdef int ndim = in_layout.ndim
962+ if ndim == 0 :
963+ init_base_layout(out_layout, 1 )
964+ out_layout.shape[0 ] = 1
965+ out_layout.strides[0 ] = 1
966+ return 1
960967 init_base_layout(out_layout, ndim)
961968 cdef int group_start = 0
962969 cdef int group_end = 0
@@ -1021,16 +1028,19 @@ cdef inline bint split_strides_in_c_index_order(BaseLayout& out_layout, BaseLayo
10211028 _zero_strides(out_layout)
10221029 while i >= 0 :
10231030 extent = in_shape[i]
1024- group_vol = 1
10251031 group_stride = in_strides[i]
1026- while new_i >= 0 and group_vol < extent:
1032+ group_vol = 1
1033+ while new_i >= 0 :
10271034 new_extent = out_layout.shape[new_i]
10281035 if new_extent == 0 :
10291036 return False
1030- group_vol = _overflow_checked_mul(group_vol, new_extent)
1031- out_layout.strides[new_i] = group_stride
1032- group_stride = _overflow_checked_mul(group_stride, new_extent)
1033- new_i -= 1
1037+ if new_extent == 1 or group_vol < extent:
1038+ out_layout.strides[new_i] = group_stride
1039+ group_stride = _overflow_checked_mul(group_stride, new_extent)
1040+ group_vol = _overflow_checked_mul(group_vol, new_extent)
1041+ new_i -= 1
1042+ else :
1043+ break
10341044 if group_vol != extent:
10351045 return False
10361046 i -= 1
0 commit comments