Skip to content

Commit af458e0

Browse files
committed
Merge branch 'master' of github.com:MasonProtter/ArrayInterface.jl
2 parents d450944 + adeb1d9 commit af458e0

File tree

3 files changed

+71
-7
lines changed

3 files changed

+71
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "3.1.12"
3+
version = "3.1.13"
44

55
[deps]
66
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"

src/stridelayout.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,32 @@ stride_rank(x, i) = stride_rank(x)[i]
218218
function stride_rank(::Type{R}) where {T,N,S,A<:Array{S},R<:Base.ReinterpretArray{T,N,S,A}}
219219
return nstatic(Val(N))
220220
end
221+
if VERSION v"1.6.0-DEV.1581"
222+
@inline function stride_rank(::Type{A}) where {NB, NA, B <: AbstractArray{<:Any,NB},A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true}}
223+
_stride_rank_reinterpret(stride_rank(B), gt(StaticInt{NB}(), StaticInt{NA}()))
224+
end
225+
@inline _stride_rank_reinterpret(sr, ::False) = (One(), map(Base.Fix2(+,One()),sr)...)
226+
@inline _stride_rank_reinterpret(sr::Tuple{One,Vararg}, ::True) = map(Base.Fix2(-,One()), tail(sr))
227+
# if the leading dim's `stride_rank` is not one, then that means the individual elements are split across an axis, which ArrayInterface
228+
# doesn't currently have a means of representing.
229+
@inline function contiguous_axis(::Type{A}) where {NB, NA, B <: AbstractArray{<:Any,NB},A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true}}
230+
_reinterpret_contiguous_axis(stride_rank(B), dense_dims(B), contiguous_axis(B), gt(StaticInt{NB}(), StaticInt{NA}()))
231+
end
232+
@inline _reinterpret_contiguous_axis(::Any, ::Any, ::Any, ::False) = One()
233+
@inline _reinterpret_contiguous_axis(::Any, ::Any, ::Any, ::True) = Zero()
234+
@generated function _reinterpret_contiguous_axis(t::Tuple{One,Vararg{StaticInt,N}}, d::Tuple{True,Vararg{StaticBool,N}}, ::One, ::True) where {N}
235+
for n in 1:N
236+
if t.parameters[n+1].parameters[1] === 2
237+
if d.parameters[n+1] === True
238+
return :(StaticInt{$n}())
239+
else
240+
return :(Zero())
241+
end
242+
end
243+
end
244+
:(Zero())
245+
end
246+
end
221247

222248
function stride_rank(::Type{Base.ReshapedArray{T, N, P, Tuple{Vararg{Base.SignedMultiplicativeInverse{Int},M}}}}) where {T,N,P,M}
223249
_reshaped_striderank(is_column_major(P), Val{N}(), Val{M}())

test/runtests.jl

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -321,17 +321,17 @@ using OffsetArrays
321321
x = zeros(100);
322322
# R = reshape(view(x, 1:100), (10,10));
323323
# A = zeros(3,4,5);
324-
A = Wrapper(reshape(view(x, 1:60), (3,4,5)))
324+
A = Wrapper(reshape(view(x, 1:60), (3,4,5)));
325325
B = A .== 0;
326-
D1 = view(A, 1:2:3, :, :) # first dimension is discontiguous
327-
D2 = view(A, :, 2:2:4, :) # first dimension is contiguous
326+
D1 = view(A, 1:2:3, :, :); # first dimension is discontiguous
327+
D2 = view(A, :, 2:2:4, :); # first dimension is contiguous
328328

329329
@test @inferred(ArrayInterface.defines_strides(x))
330330
@test @inferred(ArrayInterface.defines_strides(A))
331331
@test @inferred(ArrayInterface.defines_strides(D1))
332332
@test !@inferred(ArrayInterface.defines_strides(view(A, :, [1,2],1)))
333333
@test @inferred(ArrayInterface.defines_strides(DenseWrapper{Int,2,Matrix{Int}}))
334-
334+
335335
@test @inferred(device(A)) === ArrayInterface.CPUPointer()
336336
@test @inferred(device(B)) === ArrayInterface.CPUIndex()
337337
@test @inferred(device(-1:19)) === ArrayInterface.CPUIndex()
@@ -347,7 +347,6 @@ using OffsetArrays
347347
@test @inferred(device(OffsetArray(@MArray(zeros(2,2,2)),8,-2,-5))) === ArrayInterface.CPUPointer()
348348
@test isnothing(device("Hello, world!"))
349349
@test @inferred(device(DenseWrapper{Int,2,Matrix{Int}})) === ArrayInterface.CPUPointer()
350-
351350
#=
352351
@btime ArrayInterface.contiguous_axis($(reshape(view(zeros(100), 1:60), (3,4,5))))
353352
0.047 ns (0 allocations: 0 bytes)
@@ -373,7 +372,7 @@ using OffsetArrays
373372
@test @inferred(contiguous_axis(PermutedDimsArray(DummyZeros(3,4), (2, 1)))) === nothing
374373
@test @inferred(contiguous_axis(view(DummyZeros(3,4), 1, :))) === nothing
375374
@test @inferred(contiguous_axis(view(DummyZeros(3,4), 1, :)')) === nothing
376-
375+
377376
@test @inferred(ArrayInterface.contiguous_axis_indicator(@SArray(zeros(2,2,2)))) == (true,false,false)
378377
@test @inferred(ArrayInterface.contiguous_axis_indicator(A)) == (true,false,false)
379378
@test @inferred(ArrayInterface.contiguous_axis_indicator(B)) == (true,false,false)
@@ -418,6 +417,8 @@ using OffsetArrays
418417
@test @inferred(stride_rank(DummyZeros(3,4)')) === nothing
419418
@test @inferred(stride_rank(PermutedDimsArray(DummyZeros(3,4), (2, 1)))) === nothing
420419
@test @inferred(stride_rank(view(DummyZeros(3,4), 1, :))) === nothing
420+
421+
421422
#=
422423
@btime ArrayInterface.is_column_major($(PermutedDimsArray(A,(3,1,2))))
423424
0.047 ns (0 allocations: 0 bytes)
@@ -479,6 +480,43 @@ using OffsetArrays
479480
Am = @MMatrix rand(2,10);
480481
@test @inferred(ArrayInterface.strides(view(Am,1,:))) === (StaticInt(2),)
481482

483+
if VERSION v"1.6.0-DEV.1581" # reinterpret(reshape,...) tests
484+
C1 = reinterpret(reshape, Float64, PermutedDimsArray(Array{Complex{Float64}}(undef, 3,4,5), (2,1,3)));
485+
C2 = reinterpret(reshape, Complex{Float64}, PermutedDimsArray(view(A,1:2,:,:), (1,3,2)));
486+
C3 = reinterpret(reshape, Complex{Float64}, PermutedDimsArray(Wrapper(reshape(view(x, 1:24), (2,3,4))), (1,3,2)));
487+
488+
@test @inferred(ArrayInterface.defines_strides(C1))
489+
@test @inferred(ArrayInterface.defines_strides(C2))
490+
@test @inferred(ArrayInterface.defines_strides(C3))
491+
492+
@test @inferred(device(C1)) === ArrayInterface.CPUPointer()
493+
@test @inferred(device(C2)) === ArrayInterface.CPUPointer()
494+
@test @inferred(device(C3)) === ArrayInterface.CPUPointer()
495+
496+
@test @inferred(contiguous_batch_size(C1)) === ArrayInterface.StaticInt(0)
497+
@test @inferred(contiguous_batch_size(C2)) === ArrayInterface.StaticInt(0)
498+
@test @inferred(contiguous_batch_size(C3)) === ArrayInterface.StaticInt(0)
499+
500+
@test @inferred(stride_rank(C1)) == (1,3,2,4)
501+
@test @inferred(stride_rank(C2)) == (2,1)
502+
@test @inferred(stride_rank(C3)) == (2,1)
503+
504+
@test @inferred(contiguous_axis(C1)) === StaticInt(1)
505+
@test @inferred(contiguous_axis(C2)) === StaticInt(0)
506+
@test @inferred(contiguous_axis(C3)) === StaticInt(2)
507+
508+
@test @inferred(ArrayInterface.contiguous_axis_indicator(C1)) == (true,false,false,false)
509+
@test @inferred(ArrayInterface.contiguous_axis_indicator(C2)) == (false,false)
510+
@test @inferred(ArrayInterface.contiguous_axis_indicator(C3)) == (false,true)
511+
512+
@test @inferred(ArrayInterface.is_column_major(C1)) === False()
513+
@test @inferred(ArrayInterface.is_column_major(C2)) === False()
514+
@test @inferred(ArrayInterface.is_column_major(C3)) === False()
515+
516+
@test @inferred(dense_dims(C1)) == (true,true,true,true)
517+
@test @inferred(dense_dims(C2)) == (false,false)
518+
@test @inferred(dense_dims(C3)) == (true,true)
519+
end
482520
end
483521

484522
@testset "Static-Dynamic Size, Strides, and Offsets" begin

0 commit comments

Comments
 (0)