Skip to content

Commit 60e8236

Browse files
committed
Fix ReinterpretArray's definition of strides to be inline with ArrayInterface's meaning
1 parent cd6e5d2 commit 60e8236

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

src/stridelayout.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,27 @@ function strides(x)
443443
return Base.strides(x)
444444
end
445445
end
446+
@inline bmap(f::F, t::Tuple{}, x::Number) where {F} = ()
447+
@inline bmap(f::F, t::Tuple{T}, x::Number) where {F, T} = (f(first(t),x), )
448+
@inline bmap(f::F, t::Tuple, x::Number) where {F} = (f(first(t),x), bmap(f, Base.tail(t), x)...)
449+
if VERSION v"1.6.0-DEV.1581"
450+
@inline @inline function strides(A::Base.ReinterpretArray{R, N, T, B, true}) where {R,N,T,B}
451+
P = strides(parent(A))
452+
if sizeof(R) == sizeof(T)
453+
P
454+
elseif sizeof(R) > sizeof(T)
455+
x = Base.tail(P)
456+
fx = first(x)
457+
if fx isa Int
458+
(One(), bmap(Base.sdiv_int, Base.tail(x), fx)...)
459+
else
460+
(One(), bmap(÷, Base.tail(x), fx)...)
461+
end
462+
else
463+
(One(), bmap(*, P, StaticInt(sizeof(T)) ÷ StaticInt(sizeof(R)))...)
464+
end
465+
end
466+
end
446467
#@inline strides(A) = _strides(A, Base.strides(A), contiguous_axis(A))
447468

448469
strides(::AbstractRange) = (One(),)

test/runtests.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,16 @@ end
647647
@test @inferred(ArrayInterface.dense_dims(view(Sr2,:,2))) === (True(),)
648648
@test @inferred(ArrayInterface.dense_dims(view(Sr2,:,2:3))) === (True(),True())
649649
@test @inferred(ArrayInterface.dense_dims(view(Sr2,2:3,:))) === (True(),False())
650+
651+
Ar2c = reinterpret(reshape, Complex{Float64}, view(rand(2, 5, 7), :, 2:4, 3:5));
652+
@test @inferred(ArrayInterface.strides(Ar2c)) === (StaticInt(1), 5)
653+
654+
Ac2r = reinterpret(reshape, Float64, view(rand(ComplexF64, 5, 7), 2:4, 3:6));
655+
@test @inferred(ArrayInterface.strides(Ac2r)) === (StaticInt(1), StaticInt(2), 10)
656+
657+
Ac2t = reinterpret(reshape, Tuple{Float64,Float64}, view(rand(ComplexF64, 5, 7), 2:4, 3:6));
658+
@test @inferred(ArrayInterface.strides(Ac2t)) === (StaticInt(1), 5)
659+
650660
end
651661
end
652662

0 commit comments

Comments
 (0)