@@ -502,8 +502,10 @@ while still producing correct behavior when using valid cartesian indices, such
502502strides (A:: StrideIndex ) = getfield (A, :strides )
503503@inline strides (A:: Vector{<:Any} ) = (StaticInt (1 ),)
504504@inline strides (A:: Array{<:Any,N} ) where {N} = (StaticInt (1 ), Base. tail (Base. strides (A))... )
505- function strides (x)
506- if defines_strides (x)
505+ @inline function strides (x:: X ) where {X}
506+ if ! (parent_type (X) <: X )
507+ return strides (parent (x))
508+ elseif defines_strides (X)
507509 return size_to_strides (size (x), One ())
508510 else
509511 return Base. strides (x)
@@ -519,11 +521,19 @@ function strides(A::ReshapedArray{T,N,P}) where {T, N, P<:AbstractVector}
519521 return Base. strides (A)
520522 end
521523end
524+ function strides (A:: ReshapedArray{T,N,P} ) where {T, N, P}
525+ if defines_strides (A)
526+ return size_to_strides (size (A), static (1 ))
527+ else
528+ return Base. strides (A)
529+ end
530+ end
531+
522532
523533@inline bmap (f:: F , t:: Tuple{} , x:: Number ) where {F} = ()
524534@inline bmap (f:: F , t:: Tuple{T} , x:: Number ) where {F, T} = (f (first (t),x), )
525535@inline bmap (f:: F , t:: Tuple , x:: Number ) where {F} = (f (first (t),x), bmap (f, Base. tail (t), x)... )
526- if VERSION ≥ v " 1.6.0-DEV.1581"
536+ @static if VERSION ≥ v " 1.6.0-DEV.1581"
527537 # from `reinterpret(reshape, ...)`
528538 @inline function strides (A:: Base.ReinterpretArray{R, N, T, B, true} ) where {R,N,T,B}
529539 P = strides (parent (A))
@@ -541,7 +551,6 @@ if VERSION ≥ v"1.6.0-DEV.1581"
541551 (One (), bmap (* , P, StaticInt (sizeof (T)) ÷ StaticInt (sizeof (R)))... )
542552 end
543553 end
544-
545554 # plain `reinterpret(...)`
546555 @inline function strides (A:: Base.ReinterpretArray{R, N, T, B, false} ) where {R,N,T,B}
547556 P = strides (parent (A))
@@ -553,6 +562,18 @@ if VERSION ≥ v"1.6.0-DEV.1581"
553562 (first (P), bmap (* , Base. tail (P), StaticInt (sizeof (T)) ÷ StaticInt (sizeof (R)))... )
554563 end
555564 end
565+ else
566+ # plain `reinterpret(...)`
567+ @inline function strides (A:: Base.ReinterpretArray{R, N, T} ) where {R,N,T}
568+ P = strides (parent (A))
569+ if sizeof (R) == sizeof (T)
570+ P
571+ elseif sizeof (R) > sizeof (T)
572+ (first (P), bmap (÷ , Base. tail (P), StaticInt (sizeof (R)) ÷ StaticInt (sizeof (T)))... )
573+ else # sizeof(R) < sizeof(T)
574+ (first (P), bmap (* , Base. tail (P), StaticInt (sizeof (T)) ÷ StaticInt (sizeof (R)))... )
575+ end
576+ end
556577end
557578# @inline strides(A) = _strides(A, Base.strides(A), contiguous_axis(A))
558579
0 commit comments