From f67385f52984ca89098aa7e43862670caabc78df Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Mon, 7 Nov 2022 07:00:00 -0500 Subject: [PATCH 1/5] Remove optionally static range types This is complimentary to https://github.com/SciML/Static.jl/pull/88 and would be a big move in disentangling static types from ArrayInterface --- Project.toml | 2 +- src/ArrayInterface.jl | 3 +- src/ranges.jl | 351 +----------------------------------------- test/ranges.jl | 94 ----------- 4 files changed, 6 insertions(+), 444 deletions(-) diff --git a/Project.toml b/Project.toml index ca79133ce..df4499373 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" ArrayInterfaceCore = "0.1.3" Compat = "3, 4" IfElse = "0.1" -Static = "0.7" +Static = "0.8" julia = "1.6" [extras] diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index 043b310bf..e226aa587 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -23,7 +23,8 @@ import ArrayInterfaceCore: known_first, known_step, known_last using Static using Static: Zero, One, nstatic, eq, ne, gt, ge, lt, le, eachop, eachop_tuple, - permute, invariant_permutation, field_type, reduce_tup, find_first_eq + permute, invariant_permutation, field_type, reduce_tup, find_first_eq, + OptionallyStaticUnitRange, OptionallyStaticStepRange, OptionallyStaticRange using IfElse diff --git a/src/ranges.jl b/src/ranges.jl index ca58ae620..c7767db3f 100644 --- a/src/ranges.jl +++ b/src/ranges.jl @@ -1,129 +1,4 @@ -""" - OptionallyStaticUnitRange(start, stop) <: AbstractUnitRange{Int} - -Similar to `UnitRange` except each field may be an `Int` or `StaticInt`. An -`OptionallyStaticUnitRange` is intended to be constructed internally from other valid -indices. Therefore, users should not expect the same checks are used to ensure construction -of a valid `OptionallyStaticUnitRange` as a `UnitRange`. -""" -struct OptionallyStaticUnitRange{F<:CanonicalInt,L<:CanonicalInt} <: AbstractUnitRange{Int} - start::F - stop::L - - function OptionallyStaticUnitRange(start::CanonicalInt, stop::CanonicalInt) - new{typeof(start),typeof(stop)}(start, stop) - end - function OptionallyStaticUnitRange(start, stop) - OptionallyStaticUnitRange(canonicalize(start), canonicalize(stop)) - end - function OptionallyStaticUnitRange(x::AbstractRange) - step(x) == 1 && return OptionallyStaticUnitRange(static_first(x), static_last(x)) - - errmsg(x) = throw(ArgumentError("step must be 1, got $(step(x))")) # avoid GC frame - errmsg(x) - end - OptionallyStaticUnitRange{F,L}(x::AbstractRange) where {F,L} = OptionallyStaticUnitRange(x) - function OptionallyStaticUnitRange{StaticInt{F},StaticInt{L}}() where {F,L} - new{StaticInt{F},StaticInt{L}}() - end -end - -""" - OptionallyStaticStepRange(start, step, stop) <: OrdinalRange{Int,Int} - -Similarly to [`OptionallyStaticUnitRange`](@ref), `OptionallyStaticStepRange` permits -a combination of static and standard primitive `Int`s to construct a range. It -specifically enables the use of ranges without a step size of 1. It may be constructed -through the use of `OptionallyStaticStepRange` directly or using static integers with -the range operator (i.e., `:`). - -```julia -julia> using ArrayInterface - -julia> x = ArrayInterface.static(2); - -julia> x:x:10 -static(2):static(2):10 - -julia> ArrayInterface.OptionallyStaticStepRange(x, x, 10) -static(2):static(2):10 - -``` -""" -struct OptionallyStaticStepRange{F<:CanonicalInt,S<:CanonicalInt,L<:CanonicalInt} <: OrdinalRange{Int,Int} - start::F - step::S - stop::L - - function OptionallyStaticStepRange(start::CanonicalInt, step::CanonicalInt, stop::CanonicalInt) - lst = _steprange_last(start, step, stop) - new{typeof(start),typeof(step),typeof(lst)}(start, step, lst) - end - function OptionallyStaticStepRange(start, step, stop) - OptionallyStaticStepRange(canonicalize(start), canonicalize(step), canonicalize(stop)) - end - function OptionallyStaticStepRange(x::AbstractRange) - return OptionallyStaticStepRange(static_first(x), static_step(x), static_last(x)) - end -end - -# to make StepRange constructor inlineable, so optimizer can see `step` value -@inline function _steprange_last(start::StaticInt, step::StaticInt, stop::StaticInt) - return StaticInt(_steprange_last(Int(start), Int(step), Int(stop))) -end -@inline function _steprange_last(start, step::StaticInt, stop::StaticInt) - if step === one(step) - # we don't need to check the `stop` if we know it acts like a unit range - return stop - else - return _steprange_last(start, Int(step), Int(stop)) - end -end -@inline function _steprange_last(start, step, stop) - z = zero(step) - if step === z - throw(ArgumentError("step cannot be zero")) - else - if stop == start - return Int(stop) - else - if step > z - if stop > start - return stop - Int(unsigned(stop - start) % step) - else - return Int(start - one(start)) - end - else - if stop > start - return Int(start + one(start)) - else - return stop + Int(unsigned(start - stop) % -step) - end - end - end - end -end - -""" - SUnitRange(start::Int, stop::Int) - -An alias for `OptionallyStaticUnitRange` where both the start and stop are known statically. -""" -const SUnitRange{F,L} = OptionallyStaticUnitRange{StaticInt{F},StaticInt{L}} -SUnitRange(start::Int, stop::Int) = SUnitRange{start,stop}() - -""" - SOneTo(n::Int) - -An alias for `OptionallyStaticUnitRange` usfeul for statically sized axes. -""" -const SOneTo{L} = SUnitRange{1,L} -SOneTo(n::Int) = SOneTo{n}() - -const OptionallyStaticRange = Union{<:OptionallyStaticUnitRange,<:OptionallyStaticStepRange} - - ArrayInterfaceCore.known_first(::Type{<:OptionallyStaticUnitRange{StaticInt{F}}}) where {F} = F::Int ArrayInterfaceCore.known_first(::Type{<:OptionallyStaticStepRange{StaticInt{F}}}) where {F} = F::Int @@ -132,210 +7,6 @@ ArrayInterfaceCore.known_step(::Type{<:OptionallyStaticStepRange{<:Any,StaticInt ArrayInterfaceCore.known_last(::Type{<:OptionallyStaticUnitRange{<:Any,StaticInt{L}}}) where {L} = L::Int ArrayInterfaceCore.known_last(::Type{<:OptionallyStaticStepRange{<:Any,<:Any,StaticInt{L}}}) where {L} = L::Int -@inline function Base.first(r::OptionallyStaticRange)::Int - if known_first(r) === nothing - return getfield(r, :start) - else - return known_first(r) - end -end -function Base.step(r::OptionallyStaticStepRange)::Int - if known_step(r) === nothing - return getfield(r, :step) - else - return known_step(r) - end -end -@inline function Base.last(r::OptionallyStaticRange)::Int - if known_last(r) === nothing - return getfield(r, :stop) - else - return known_last(r) - end -end - -Base.:(:)(L::Integer, ::StaticInt{U}) where {U} = OptionallyStaticUnitRange(L, StaticInt(U)) -Base.:(:)(::StaticInt{L}, U::Integer) where {L} = OptionallyStaticUnitRange(StaticInt(L), U) -function Base.:(:)(::StaticInt{L}, ::StaticInt{U}) where {L,U} - return OptionallyStaticUnitRange(StaticInt(L), StaticInt(U)) -end -function Base.:(:)(::StaticInt{F}, ::StaticInt{S}, ::StaticInt{L}) where {F,S,L} - return OptionallyStaticStepRange(StaticInt(F), StaticInt(S), StaticInt(L)) -end -function Base.:(:)(start::Integer, ::StaticInt{S}, ::StaticInt{L}) where {S,L} - return OptionallyStaticStepRange(start, StaticInt(S), StaticInt(L)) -end -function Base.:(:)(::StaticInt{F}, ::StaticInt{S}, stop::Integer) where {F,S} - return OptionallyStaticStepRange(StaticInt(F), StaticInt(S), stop) -end -function Base.:(:)(::StaticInt{F}, step::Integer, ::StaticInt{L}) where {F,L} - return OptionallyStaticStepRange(StaticInt(F), step, StaticInt(L)) -end -function Base.:(:)(start::Integer, step::Integer, ::StaticInt{L}) where {L} - return OptionallyStaticStepRange(start, step, StaticInt(L)) -end -function Base.:(:)(start::Integer, ::StaticInt{S}, stop::Integer) where {S} - return OptionallyStaticStepRange(start, StaticInt(S), stop) -end -function Base.:(:)(::StaticInt{F}, step::Integer, stop::Integer) where {F} - return OptionallyStaticStepRange(StaticInt(F), step, stop) -end -Base.:(:)(start::StaticInt{F}, ::StaticInt{1}, stop::StaticInt{L}) where {F,L} = start:stop -Base.:(:)(start::Integer, ::StaticInt{1}, stop::StaticInt{L}) where {L} = start:stop -Base.:(:)(start::StaticInt{F}, ::StaticInt{1}, stop::Integer) where {F} = start:stop -function Base.:(:)(start::Integer, ::StaticInt{1}, stop::Integer) - OptionallyStaticUnitRange(start, stop) -end - -Base.isempty(r::OptionallyStaticUnitRange{One}) = last(r) <= 0 -Base.isempty(r::OptionallyStaticUnitRange) = first(r) > last(r) -function Base.isempty(r::OptionallyStaticStepRange) - (r.start != r.stop) & ((r.step > 0) != (r.stop > r.start)) -end - -function Base.checkindex( - ::Type{Bool}, - ::SUnitRange{F1,L1}, - ::SUnitRange{F2,L2} -) where {F1,L1,F2,L2} - - (F1::Int <= F2::Int) && (L1::Int >= L2::Int) -end - -@propagate_inbounds function Base.getindex( - r::OptionallyStaticUnitRange, - s::AbstractUnitRange{<:Integer}, -) - @boundscheck checkbounds(r, s) - f = static_first(r) - fnew = f - one(f) - return (fnew+static_first(s)):(fnew+static_last(s)) -end - -@propagate_inbounds function Base.getindex(x::OptionallyStaticUnitRange{StaticInt{1}}, i::Int) - @boundscheck checkbounds(x, i) - i -end -@propagate_inbounds function Base.getindex(x::OptionallyStaticUnitRange, i::Int) - val = first(x) + (i - 1) - @boundscheck ((i < 1) || val > last(x)) && throw(BoundsError(x, i)) - val::Int -end - -@noinline unequal_error(x,y) = @assert false "Unequal Indices: x == $x != $y == y" -@inline check_equal(x, y) = x == y || unequal_error(x,y) -_try_static(::Nothing, ::Nothing) = nothing -_try_static(x::Int, ::Nothing) = x -_try_static(::Nothing, y::Int) = y -@inline _try_static(::StaticInt{N}, ::StaticInt{N}) where {N} = StaticInt{N}() -@inline function _try_static(::StaticInt{M}, ::StaticInt{N}) where {M,N} - @assert false "Unequal Indices: StaticInt{$M}() != StaticInt{$N}()" -end -@propagate_inbounds _try_static(::StaticInt{N}, x) where {N} = static(_try_static(N, x)) -@propagate_inbounds _try_static(x, ::StaticInt{N}) where {N} = static(_try_static(N, x)) -@propagate_inbounds function _try_static(x, y) - @boundscheck check_equal(x, y) - return x -end - -## length -@inline function Base.length(r::OptionallyStaticUnitRange) - if isempty(r) - return 0 - else - return last(r) - first(r) + 1 - end -end -Base.length(r::OptionallyStaticStepRange) = _range_length(first(r), step(r), last(r)) -_range_length(start, s, stop) = nothing -@inline function _range_length(start::Int, s::Int, stop::Int) - if s > 0 - if stop < start # isempty - return 0 - else - return Int(div(stop - start, s)) + 1 - end - else - if stop > start # isempty - return 0 - else - return Int(div(start - stop, -s)) + 1 - end - end -end - -Base.AbstractUnitRange{Int}(r::OptionallyStaticUnitRange) = r -function Base.AbstractUnitRange{T}(r::OptionallyStaticUnitRange) where {T} - if known_first(r) === 1 && T <: Integer - return OneTo{T}(last(r)) - else - return UnitRange{T}(first(r), last(r)) - end -end - -@inline function Base.iterate(r::OptionallyStaticRange) - isempty(r) && return nothing - fi = Int(first(r)); - fi, fi -end -function Base.iterate(::SUnitRange{F,L}) where {F,L} - if L::Int < F::Int - return nothing - else - return (F::Int, F::Int) - end -end -function Base.iterate(::SOneTo{n}, s::Int) where {n} - if s < n::Int - s2 = s + 1 - return (s2, s2) - else - return nothing - end -end - -Base.to_shape(x::OptionallyStaticRange) = Base.length(x) -Base.to_shape(x::Slice{T}) where {T<:OptionallyStaticRange} = Base.length(x) -Base.axes(S::Slice{<:OptionallyStaticUnitRange{One}}) = (S.indices,) -Base.axes(S::Slice{<:OptionallyStaticRange}) = (Base.IdentityUnitRange(S.indices),) - -Base.axes(x::OptionallyStaticRange) = (Base.axes1(x),) -Base.axes1(x::OptionallyStaticRange) = static(1):length(x) -Base.axes1(x::Slice{<:OptionallyStaticUnitRange{One}}) = x.indices -Base.axes1(x::Slice{<:OptionallyStaticRange}) = Base.IdentityUnitRange(x.indices) - -Base.:(-)(r::OptionallyStaticRange) = -static_first(r):-static_step(r):-static_last(r) - -Base.reverse(r::OptionallyStaticUnitRange) = static_last(r):static(-1):static_first(r) -function Base.reverse(r::OptionallyStaticStepRange) - OptionallyStaticStepRange(static_last(r), -static_step(r), static_first(r)) -end - -function Base.show(io::IO, ::MIME"text/plain", @nospecialize(r::OptionallyStaticUnitRange)) - print(io, "$(getfield(r, :start)):$(getfield(r, :stop))") -end -function Base.show(io::IO, ::MIME"text/plain", @nospecialize(r::OptionallyStaticStepRange)) - print(io, "$(getfield(r, :start)):$(getfield(r, :step)):$(getfield(r, :stop))") -end - -@inline function Base.getproperty(x::OptionallyStaticRange, s::Symbol) - if s === :start - return first(x) - elseif s === :step - return step(x) - elseif s === :stop - return last(x) - else - error("$x has no property $s") - end -end - -@propagate_inbounds function _pick_range(x, y) - fst = _try_static(static_first(x), static_first(y)) - lst = _try_static(static_last(x), static_last(y)) - return Base.Slice(OptionallyStaticUnitRange(fst, lst)) -end - """ indices(x, dim) -> AbstractUnitRange{Int} @@ -365,7 +36,7 @@ Returns valid indices for the entire length of each array in `x`. """ @propagate_inbounds function indices(x::Tuple) inds = map(eachindex, x) - return reduce_tup(_pick_range, inds) + return reduce_tup(static_promote, inds) end """ @@ -375,7 +46,7 @@ Returns valid indices for each array in `x` along dimension `dim` """ @propagate_inbounds function indices(x::Tuple, dim) inds = map(Base.Fix2(indices, dim), x) - return reduce_tup(_pick_range, inds) + return reduce_tup(static_promote, inds) end """ @@ -386,7 +57,7 @@ respective array (`dim`). """ @propagate_inbounds function indices(x::Tuple, dim::Tuple) inds = map(indices, x, dim) - return reduce_tup(_pick_range, inds) + return reduce_tup(static_promote, inds) end """ @@ -397,19 +68,3 @@ Returns valid indices for array `x` along each dimension specified in `dim`. @inline indices(x, dims::Tuple) = _indices(x, dims) _indices(x, dims::Tuple) = (indices(x, first(dims)), _indices(x, tail(dims))...) _indices(x, ::Tuple{}) = () - -function Base.Broadcast.axistype(r::OptionallyStaticUnitRange{StaticInt{1}}, _) - Base.OneTo(last(r)) -end -function Base.Broadcast.axistype(_, r::OptionallyStaticUnitRange{StaticInt{1}}) - Base.OneTo(last(r)) -end -function Base.Broadcast.axistype(r::OptionallyStaticUnitRange{StaticInt{1}}, ::OptionallyStaticUnitRange{StaticInt{1}}) - Base.OneTo(last(r)) -end -function Base.similar(::Type{<:Array{T}}, axes::Tuple{OptionallyStaticUnitRange{StaticInt{1}},Vararg{Union{Base.OneTo,OptionallyStaticUnitRange{StaticInt{1}}}}}) where {T} - Array{T}(undef, map(last, axes)) -end -function Base.similar(::Type{<:Array{T}}, axes::Tuple{Base.OneTo,OptionallyStaticUnitRange{StaticInt{1}},Vararg{Union{Base.OneTo,OptionallyStaticUnitRange{StaticInt{1}}}}}) where {T} - Array{T}(undef, map(last, axes)) -end diff --git a/test/ranges.jl b/test/ranges.jl index 098bfc435..846db0931 100644 --- a/test/ranges.jl +++ b/test/ranges.jl @@ -1,59 +1,4 @@ -@testset "Range Constructors" begin - @test @inferred(static(1):static(10)) == 1:10 - @test @inferred(ArrayInterface.SUnitRange{1,10}()) == 1:10 - @test @inferred(static(1):static(2):static(10)) == 1:2:10 - @test @inferred(1:static(2):static(10)) == 1:2:10 - @test @inferred(static(1):static(2):10) == 1:2:10 - @test @inferred(static(1):2:static(10)) == 1:2:10 - @test @inferred(1:2:static(10)) == 1:2:10 - @test @inferred(1:static(2):10) == 1:2:10 - @test @inferred(static(1):2:10) == 1:2:10 - @test @inferred(static(1):UInt(10)) === static(1):10 - @test @inferred(UInt(1):static(1):static(10)) === 1:static(10) - @test ArrayInterface.SUnitRange(1, 10) == 1:10 - @test @inferred(ArrayInterface.OptionallyStaticUnitRange{Int,Int}(1:10)) == 1:10 - @test @inferred(ArrayInterface.OptionallyStaticUnitRange(1:10)) == 1:10 - - @inferred(ArrayInterface.OptionallyStaticUnitRange(1:10)) - - @test @inferred(ArrayInterface.OptionallyStaticStepRange(static(1), static(1), static(1))) == 1:1:1 - @test @inferred(ArrayInterface.OptionallyStaticStepRange(static(1), 1, UInt(10))) == static(1):1:10 - @test @inferred(ArrayInterface.OptionallyStaticStepRange(UInt(1), 1, static(10))) == static(1):1:10 - @test @inferred(ArrayInterface.OptionallyStaticStepRange(1:10)) == 1:1:10 - - @test_throws ArgumentError ArrayInterface.OptionallyStaticUnitRange(1:2:10) - @test_throws ArgumentError ArrayInterface.OptionallyStaticUnitRange{Int,Int}(1:2:10) - @test_throws ArgumentError ArrayInterface.OptionallyStaticStepRange(1, 0, 10) - - @test @inferred(static(1):static(1):static(10)) === ArrayInterface.OptionallyStaticUnitRange(static(1), static(10)) - @test @inferred(static(1):static(1):10) === ArrayInterface.OptionallyStaticUnitRange(static(1), 10) - @test @inferred(1:static(1):10) === ArrayInterface.OptionallyStaticUnitRange(1, 10) - @test length(static(-1):static(-1):static(-10)) == 10 == lastindex(static(-1):static(-1):static(-10)) - - @test UnitRange(ArrayInterface.OptionallyStaticUnitRange(static(1), static(10))) === UnitRange(1, 10) - @test UnitRange{Int}(ArrayInterface.OptionallyStaticUnitRange(static(1), static(10))) === UnitRange(1, 10) - - @test AbstractUnitRange{Int}(ArrayInterface.OptionallyStaticUnitRange(static(1), static(10))) isa ArrayInterface.OptionallyStaticUnitRange - @test AbstractUnitRange{UInt}(ArrayInterface.OptionallyStaticUnitRange(static(1), static(10))) isa Base.OneTo - @test AbstractUnitRange{UInt}(ArrayInterface.OptionallyStaticUnitRange(static(2), static(10))) isa UnitRange - - @test @inferred((static(1):static(10))[static(2):static(3)]) === static(2):static(3) - @test @inferred((static(1):static(10))[static(2):3]) === static(2):3 - @test @inferred((static(1):static(10))[2:3]) === 2:3 - @test @inferred((1:static(10))[static(2):static(3)]) === 2:3 - - @test Base.checkindex(Bool, static(1):static(10), static(1):static(5)) - @test -(static(1):static(10)) === static(-1):static(-1):static(-10) - - @test reverse(static(1):static(10)) === static(10):static(-1):static(1) - @test reverse(static(1):static(2):static(9)) === static(9):static(-2):static(1) -end - -# iteration -@test iterate(static(1):static(5), 5) === nothing -@test iterate(static(2):static(5), 5) === nothing - @test isone(@inferred(ArrayInterface.known_first(typeof(static(1):2:10)))) @test isone(@inferred(ArrayInterface.known_last(typeof(static(-1):static(2):static(1))))) @@ -62,52 +7,14 @@ CI = CartesianIndices((static(1):static(2), static(1):static(2))) @test @inferred(ArrayInterface.known_last(typeof(CI))) == CartesianIndex(2, 2) @testset "length" begin - @test @inferred(length(ArrayInterface.OptionallyStaticUnitRange(1, 0))) == 0 - @test @inferred(length(ArrayInterface.OptionallyStaticUnitRange(1, 10))) == 10 - @test @inferred(length(ArrayInterface.OptionallyStaticUnitRange(static(1), 10))) == 10 - @test @inferred(length(ArrayInterface.OptionallyStaticUnitRange(static(0), 10))) == 11 - @test @inferred(length(ArrayInterface.OptionallyStaticUnitRange(static(1), static(10)))) == 10 - @test @inferred(length(ArrayInterface.OptionallyStaticUnitRange(static(0), static(10)))) == 11 - - @test @inferred(length(static(1):static(2):static(0))) == 0 - @test @inferred(length(static(0):static(-2):static(1))) == 0 - @test @inferred(ArrayInterface.known_length(typeof(ArrayInterface.OptionallyStaticStepRange(static(1), 2, 10)))) === nothing @test @inferred(ArrayInterface.known_length(typeof(ArrayInterface.SOneTo{-10}()))) === 0 @test @inferred(ArrayInterface.known_length(typeof(ArrayInterface.OptionallyStaticStepRange(static(1), static(1), static(10))))) === 10 @test @inferred(ArrayInterface.known_length(typeof(ArrayInterface.OptionallyStaticStepRange(static(2), static(1), static(10))))) === 9 @test @inferred(ArrayInterface.known_length(typeof(ArrayInterface.OptionallyStaticStepRange(static(2), static(2), static(10))))) === 5 @test @inferred(ArrayInterface.known_length(Int)) === 1 - - @test @inferred(length(ArrayInterface.OptionallyStaticStepRange(static(1), 2, 10))) == 5 - @test @inferred(length(ArrayInterface.OptionallyStaticStepRange(static(1), static(1), static(10)))) == 10 - @test @inferred(length(ArrayInterface.OptionallyStaticStepRange(static(2), static(1), static(10)))) == 9 - @test @inferred(length(ArrayInterface.OptionallyStaticStepRange(static(2), static(2), static(10)))) == 5 end -@test @inferred(getindex(ArrayInterface.OptionallyStaticUnitRange(static(1), 10), 1)) == 1 -@test @inferred(getindex(ArrayInterface.OptionallyStaticUnitRange(static(0), 10), 1)) == 0 -@test_throws BoundsError getindex(ArrayInterface.OptionallyStaticUnitRange(static(1), 10), 0) -@test_throws BoundsError getindex(ArrayInterface.OptionallyStaticStepRange(static(1), 2, 10), 0) -@test_throws BoundsError getindex(ArrayInterface.OptionallyStaticUnitRange(static(1), 10), 11) -@test_throws BoundsError getindex(ArrayInterface.OptionallyStaticStepRange(static(1), 2, 10), 11) - -@test ArrayInterface.static_first(Base.OneTo(one(UInt))) === static(1) -@test ArrayInterface.static_step(Base.OneTo(one(UInt))) === static(1) - -@test Base.setindex(1:5, [6,2], 1:2) == [6,2,3,4,5] - -@test @inferred(eachindex(static(-7):static(7))) === static(1):static(15) -@test @inferred((static(-7):static(7))[first(eachindex(static(-7):static(7)))]) == -7 - -@test @inferred(firstindex(128:static(-1):1)) == 1 - -@test identity.(static(1):5) isa Vector{Int} -@test (static(1):5) .+ (1:3)' isa Matrix{Int} -@test similar(Array{Int}, (static(1):(4),)) isa Vector{Int} -@test similar(Array{Int}, (static(1):(4), Base.OneTo(4))) isa Matrix{Int} -@test similar(Array{Int}, (Base.OneTo(4), static(1):(4))) isa Matrix{Int} - @testset "indices" begin A23 = ones(2,3); SA23 = MArray(A23); @@ -149,4 +56,3 @@ end @test ArrayInterface.indices((x',y'),StaticInt(1)) === Base.Slice(StaticInt(1):StaticInt(1)) @test ArrayInterface.indices((x,y), StaticInt(2)) === Base.Slice(StaticInt(1):StaticInt(1)) end - From ab6cb46c689a13c91de52a52adc98e3410f6ce00 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Mon, 7 Nov 2022 07:49:25 -0500 Subject: [PATCH 2/5] Replace CanonicalInt with Static.IntType --- src/ArrayInterface.jl | 10 +++------- src/array_index.jl | 2 +- src/axes.jl | 2 +- src/dimensions.jl | 10 +++++----- src/indexing.jl | 46 +++++++++++++++++++++---------------------- src/size.jl | 8 ++++---- src/stridelayout.jl | 8 ++++---- 7 files changed, 41 insertions(+), 45 deletions(-) diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index e226aa587..749dc61ea 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -24,7 +24,7 @@ import ArrayInterfaceCore: known_first, known_step, known_last using Static using Static: Zero, One, nstatic, eq, ne, gt, ge, lt, le, eachop, eachop_tuple, permute, invariant_permutation, field_type, reduce_tup, find_first_eq, - OptionallyStaticUnitRange, OptionallyStaticStepRange, OptionallyStaticRange + OptionallyStaticUnitRange, OptionallyStaticStepRange, OptionallyStaticRange, IntType using IfElse @@ -44,10 +44,6 @@ _sub1(@nospecialize x) = x - oneunit(x) Tuple{X.parameters...,Y.parameters...} end -const CanonicalInt = Union{Int,StaticInt} -canonicalize(x::Integer) = Int(x) -canonicalize(@nospecialize(x::StaticInt)) = x - abstract type AbstractArray2{T,N} <: AbstractArray{T,N} end Base.size(A::AbstractArray2) = map(Int, ArrayInterface.size(A)) @@ -94,10 +90,10 @@ end @inline static_last(x) = Static.maybe_static(known_last, last, x) @inline static_step(x) = Static.maybe_static(known_step, step, x) -@inline function _to_cartesian(a, i::CanonicalInt) +@inline function _to_cartesian(a, i::IntType) @inbounds(CartesianIndices(ntuple(dim -> indices(a, dim), Val(ndims(a))))[i]) end -@inline function _to_linear(a, i::Tuple{CanonicalInt,Vararg{CanonicalInt}}) +@inline function _to_linear(a, i::Tuple{IntType,Vararg{IntType}}) _strides2int(offsets(a), size_to_strides(size(a), static(1)), i) + static(1) end diff --git a/src/array_index.jl b/src/array_index.jl index 3a1812743..ee2ae2734 100644 --- a/src/array_index.jl +++ b/src/array_index.jl @@ -20,7 +20,7 @@ struct StrideIndex{N,R,C,S,O} <: ArrayIndex{N} end ## getindex -@propagate_inbounds Base.getindex(x::ArrayIndex, i::CanonicalInt, ii::CanonicalInt...) = x[NDIndex(i, ii...)] +@propagate_inbounds Base.getindex(x::ArrayIndex, i::IntType, ii::IntType...) = x[NDIndex(i, ii...)] @inline function Base.getindex(x::StrideIndex{N}, i::AbstractCartesianIndex) where {N} return _strides2int(offsets(x), strides(x), Tuple(i)) + static(1) diff --git a/src/axes.jl b/src/axes.jl index d992c512b..6492a63ed 100644 --- a/src/axes.jl +++ b/src/axes.jl @@ -219,7 +219,7 @@ Base.axes1(x::Slice{LazyAxis{N,A}}) where {N,A} = indices(getfield(x.indices, :p Base.axes1(x::Slice{LazyAxis{:,A}}) where {A} = indices(getfield(x.indices, :parent)) Base.to_shape(x::LazyAxis) = Base.length(x) -@propagate_inbounds function Base.getindex(x::LazyAxis, i::CanonicalInt) +@propagate_inbounds function Base.getindex(x::LazyAxis, i::IntType) @boundscheck checkindex(Bool, x, i) || throw(BoundsError(x, i)) return Int(i) end diff --git a/src/dimensions.jl b/src/dimensions.jl index 194658852..8f7466104 100644 --- a/src/dimensions.jl +++ b/src/dimensions.jl @@ -111,7 +111,7 @@ to `:_`, then `false` is returned. Return the names of the dimensions for `x`. `:_` is used to indicate a dimension does not have a name. """ -@inline known_dimnames(x, dim) = _known_dimname(known_dimnames(x), canonicalize(dim)) +@inline known_dimnames(x, dim) = _known_dimname(known_dimnames(x), IntType(dim)) known_dimnames(x) = known_dimnames(typeof(x)) function known_dimnames(@nospecialize T::Type{<:VecAdjTrans}) (:_, getfield(known_dimnames(parent_type(T)), 1)) @@ -159,7 +159,7 @@ end _unknown_dimnames(::Base.HasShape{N}) where {N} = ntuple(Compat.Returns(:_), StaticInt(N)) _unknown_dimnames(::Any) = (:_,) -@inline function _known_dimname(x::Tuple{Vararg{Any,N}}, dim::CanonicalInt) where {N} +@inline function _known_dimname(x::Tuple{Vararg{Any,N}}, dim::IntType) where {N} # we cannot have `@boundscheck`, else this will depend on bounds checking being enabled (dim > N || dim < 1) && return :_ return @inbounds(x[dim]) @@ -173,7 +173,7 @@ end Return the names of the dimensions for `x`. `:_` is used to indicate a dimension does not have a name. """ -@inline dimnames(x, dim) = _dimname(dimnames(x), canonicalize(dim)) +@inline dimnames(x, dim) = _dimname(dimnames(x), IntType(dim)) @inline function dimnames(x::Union{PermutedDimsArray,MatAdjTrans}) map(GetIndex{false}(dimnames(parent(x))), to_parent_dims(x)) end @@ -214,7 +214,7 @@ end return ntuple(Compat.Returns(static(:_)), StaticInt(ndims(x))) end end -@inline function _dimname(x::Tuple{Vararg{Any,N}}, dim::CanonicalInt) where {N} +@inline function _dimname(x::Tuple{Vararg{Any,N}}, dim::IntType) where {N} # we cannot have `@boundscheck`, else this will depend on bounds checking being enabled # for calls such as `dimnames(view(x, :, 1, :))` (dim > N || dim < 1) && return static(:_) @@ -228,7 +228,7 @@ end This returns the dimension(s) of `x` corresponding to `dim`. """ to_dims(x, dim::Colon) = dim -to_dims(x, @nospecialize(dim::CanonicalInt)) = dim +to_dims(x, @nospecialize(dim::IntType)) = dim to_dims(x, dim::Integer) = Int(dim) to_dims(x, dim::Union{StaticSymbol,Symbol}) = _to_dim(dimnames(x), dim) function to_dims(x, dims::Tuple{Vararg{Any,N}}) where {N} diff --git a/src/indexing.jl b/src/indexing.jl index a4b4bbca0..6cc65b1a5 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -162,16 +162,16 @@ to_index(::LinearIndices, i::AbstractArray{Bool}) = LogicalIndex{Int}(i) @inline to_index(x, i::NDIndex) = getfield(i, 1) @inline to_index(x, i::AbstractArray{<:AbstractCartesianIndex}) = i @inline function to_index(x, i::Base.Fix2{<:Union{typeof(<),typeof(isless)},<:Union{Base.BitInteger,StaticInt}}) - static_first(x):min(_sub1(canonicalize(i.x)), static_last(x)) + static_first(x):min(_sub1(IntType(i.x)), static_last(x)) end @inline function to_index(x, i::Base.Fix2{typeof(<=),<:Union{Base.BitInteger,StaticInt}}) - static_first(x):min(canonicalize(i.x), static_last(x)) + static_first(x):min(IntType(i.x), static_last(x)) end @inline function to_index(x, i::Base.Fix2{typeof(>=),<:Union{Base.BitInteger,StaticInt}}) - max(canonicalize(i.x), static_first(x)):static_last(x) + max(IntType(i.x), static_first(x)):static_last(x) end @inline function to_index(x, i::Base.Fix2{typeof(>),<:Union{Base.BitInteger,StaticInt}}) - max(_add1(canonicalize(i.x)), static_first(x)):static_last(x) + max(_add1(IntType(i.x)), static_first(x)):static_last(x) end # integer indexing to_index(x, i::AbstractArray{<:Integer}) = i @@ -232,7 +232,7 @@ indices calling [`to_axis`](@ref). end end # drop this dimension -to_axes(A, a::Tuple, i::Tuple{<:CanonicalInt,Vararg{Any}}) = to_axes(A, _maybe_tail(a), tail(i)) +to_axes(A, a::Tuple, i::Tuple{<:IntType,Vararg{Any}}) = to_axes(A, _maybe_tail(a), tail(i)) to_axes(A, a::Tuple, i::Tuple{I,Vararg{Any}}) where {I} = _to_axes(StaticInt(ndims_index(I)), A, a, i) function _to_axes(::StaticInt{1}, A, axs::Tuple, inds::Tuple) return (to_axis(_maybe_first(axs), first(inds)), to_axes(A, _maybe_tail(axs), tail(inds))...) @@ -309,7 +309,7 @@ function unsafe_getindex(a::A) where {A} end # TODO Need to manage index transformations between nested layers of arrays -function unsafe_getindex(a::A, i::CanonicalInt) where {A} +function unsafe_getindex(a::A, i::IntType) where {A} if IndexStyle(A) === IndexLinear() is_forwarding_wrapper(A) || throw(MethodError(unsafe_getindex, (A, i))) return unsafe_getindex(parent(a), i) @@ -317,7 +317,7 @@ function unsafe_getindex(a::A, i::CanonicalInt) where {A} return unsafe_getindex(a, _to_cartesian(a, i)...) end end -function unsafe_getindex(a::A, i::CanonicalInt, ii::Vararg{CanonicalInt}) where {A} +function unsafe_getindex(a::A, i::IntType, ii::Vararg{IntType}) where {A} if IndexStyle(A) === IndexLinear() return unsafe_getindex(a, _to_linear(a, (i, ii...))) else @@ -329,24 +329,24 @@ end unsafe_getindex(a, i::Vararg{Any}) = unsafe_get_collection(a, i) unsafe_getindex(A::Array) = Base.arrayref(false, A, 1) -unsafe_getindex(A::Array, i::CanonicalInt) = Base.arrayref(false, A, Int(i)) -@inline function unsafe_getindex(A::Array, i::CanonicalInt, ii::Vararg{CanonicalInt}) +unsafe_getindex(A::Array, i::IntType) = Base.arrayref(false, A, Int(i)) +@inline function unsafe_getindex(A::Array, i::IntType, ii::Vararg{IntType}) unsafe_getindex(A, _to_linear(A, (i, ii...))) end -unsafe_getindex(A::LinearIndices, i::CanonicalInt) = Int(i) -unsafe_getindex(A::CartesianIndices{N}, ii::Vararg{CanonicalInt,N}) where {N} = CartesianIndex(ii...) -unsafe_getindex(A::CartesianIndices, ii::Vararg{CanonicalInt}) = +unsafe_getindex(A::LinearIndices, i::IntType) = Int(i) +unsafe_getindex(A::CartesianIndices{N}, ii::Vararg{IntType,N}) where {N} = CartesianIndex(ii...) +unsafe_getindex(A::CartesianIndices, ii::Vararg{IntType}) = unsafe_getindex(A, Base.front(ii)...) -unsafe_getindex(A::CartesianIndices, i::CanonicalInt) = @inbounds(A[i]) +unsafe_getindex(A::CartesianIndices, i::IntType) = @inbounds(A[i]) -unsafe_getindex(A::ReshapedArray, i::CanonicalInt) = @inbounds(parent(A)[i]) -function unsafe_getindex(A::ReshapedArray, i::CanonicalInt, ii::Vararg{CanonicalInt}) +unsafe_getindex(A::ReshapedArray, i::IntType) = @inbounds(parent(A)[i]) +function unsafe_getindex(A::ReshapedArray, i::IntType, ii::Vararg{IntType}) @inbounds(parent(A)[_to_linear(A, (i, ii...))]) end -unsafe_getindex(A::SubArray, i::CanonicalInt) = @inbounds(A[i]) -unsafe_getindex(A::SubArray, i::CanonicalInt, ii::Vararg{CanonicalInt}) = @inbounds(A[i, ii...]) +unsafe_getindex(A::SubArray, i::IntType) = @inbounds(A[i]) +unsafe_getindex(A::SubArray, i::IntType, ii::Vararg{IntType}) = @inbounds(A[i, ii...]) # This is based on Base._unsafe_getindex from https://github.com/JuliaLang/julia/blob/c5ede45829bf8eb09f2145bfd6f089459d77b2b1/base/multidimensional.jl#L755. #= @@ -364,7 +364,7 @@ function unsafe_get_collection(A, inds) end return dest end -_ints2range(x::CanonicalInt) = x:x +_ints2range(x::IntType) = x:x _ints2range(x::AbstractRange) = x # apply _ints2range to front N elements _ints2range_front(::Val{N}, ind, inds...) where {N} = @@ -372,9 +372,9 @@ _ints2range_front(::Val{N}, ind, inds...) where {N} = _ints2range_front(::Val{0}, ind, inds...) = () _ints2range_front(::Val{0}) = () # get output shape with given indices -_output_shape(::CanonicalInt, inds...) = _output_shape(inds...) +_output_shape(::IntType, inds...) = _output_shape(inds...) _output_shape(ind::AbstractRange, inds...) = (Base.length(ind), _output_shape(inds...)...) -_output_shape(::CanonicalInt) = () +_output_shape(::IntType) = () _output_shape(x::AbstractRange) = (Base.length(x),) @inline function unsafe_get_collection(A::CartesianIndices{N}, inds) where {N} if (Base.length(inds) === 1 && N > 1) || stride_preserving_index(typeof(inds)) === False() @@ -426,7 +426,7 @@ function unsafe_setindex!(a::A, v) where {A} return unsafe_setindex!(parent(a), v) end # TODO Need to manage index transformations between nested layers of arrays -function unsafe_setindex!(a::A, v, i::CanonicalInt) where {A} +function unsafe_setindex!(a::A, v, i::IntType) where {A} if IndexStyle(A) === IndexLinear() is_forwarding_wrapper(A) || throw(MethodError(unsafe_setindex!, (A, v, i))) return unsafe_setindex!(parent(a), v, i) @@ -434,7 +434,7 @@ function unsafe_setindex!(a::A, v, i::CanonicalInt) where {A} return unsafe_setindex!(a, v, _to_cartesian(a, i)...) end end -function unsafe_setindex!(a::A, v, i::CanonicalInt, ii::Vararg{CanonicalInt}) where {A} +function unsafe_setindex!(a::A, v, i::IntType, ii::Vararg{IntType}) where {A} if IndexStyle(A) === IndexLinear() return unsafe_setindex!(a, v, _to_linear(a, (i, ii...))) else @@ -446,7 +446,7 @@ end function unsafe_setindex!(A::Array{T}, v) where {T} Base.arrayset(false, A, convert(T, v)::T, 1) end -function unsafe_setindex!(A::Array{T}, v, i::CanonicalInt) where {T} +function unsafe_setindex!(A::Array{T}, v, i::IntType) where {T} return Base.arrayset(false, A, convert(T, v)::T, Int(i)) end diff --git a/src/size.jl b/src/size.jl index 7cd9904d5..9c9e8d27e 100644 --- a/src/size.jl +++ b/src/size.jl @@ -64,8 +64,8 @@ end _sub_size(x::Tuple, ::StaticInt{dim}) where {dim} = length(getfield(x, dim)) size(a, dim) = size(a, to_dims(a, dim)) -size(a::Array, dim::CanonicalInt) = Base.arraysize(a, convert(Int, dim)) -function size(a::A, dim::CanonicalInt) where {A} +size(a::Array, dim::IntType) = Base.arraysize(a, convert(Int, dim)) +function size(a::A, dim::IntType) where {A} if is_forwarding_wrapper(A) return size(parent(a), dim) else @@ -161,7 +161,7 @@ end # 1. `Zip` doesn't check that its collections are compatible (same size) at construction, # but we assume as much b/c otherwise it will error while iterating. So we promote to the # known size if matching a `Nothing` and `Int` size. -# 2. `promote_shape(::Tuple{Vararg{CanonicalInt}}, ::Tuple{Vararg{CanonicalInt}})` promotes +# 2. `promote_shape(::Tuple{Vararg{IntType}}, ::Tuple{Vararg{IntType}})` promotes # trailing dimensions (which must be of size 1), to `static(1)`. We want to stick to # `Nothing` and `Int` types, so we do one last pass to ensure everything is dynamic @inline function known_size(::Type{<:Iterators.Zip{T}}) where {T} @@ -171,7 +171,7 @@ _unzip_size(::Type{T}, n::StaticInt{N}) where {T,N} = known_size(field_type(T, n _known_size(::Type{T}, dim::StaticInt) where {T} = known_length(field_type(T, dim)) @inline known_size(x, dim) = known_size(typeof(x), dim) @inline known_size(::Type{T}, dim) where {T} = known_size(T, to_dims(T, dim)) -known_size(T::Type, dim::CanonicalInt) = ndims(T) < dim ? 1 : known_size(T)[dim] +known_size(T::Type, dim::IntType) = ndims(T) < dim ? 1 : known_size(T)[dim] """ length(A) -> Union{Int,StaticInt} diff --git a/src/stridelayout.jl b/src/stridelayout.jl index 8605aade2..ec8ae3536 100644 --- a/src/stridelayout.jl +++ b/src/stridelayout.jl @@ -12,7 +12,7 @@ known_offsets(@nospecialize T::Type{<:Number}) = () # Int has no dimensions @inline function known_offsets(@nospecialize T::Type{<:SubArray}) flatten_tuples(map_tuple_type(known_offsets, fieldtype(T, :indices))) end -function known_offsets(::Type{T}, dim::CanonicalInt) where {T} +function known_offsets(::Type{T}, dim::IntType) where {T} if ndims(T) < dim return 1 else @@ -155,7 +155,7 @@ end I = field_type(fieldtype(T, :indices), c) if I <: AbstractUnitRange return from_parent_dims(T)[c] # FIXME get rid of from_parent_dims - elseif I <: AbstractArray || I <: CanonicalInt + elseif I <: AbstractArray || I <: IntType return StaticInt(-1) else return nothing @@ -446,7 +446,7 @@ compile time are represented by `nothing`. """ known_strides(x, dim) = known_strides(typeof(x), dim) known_strides(::Type{T}, dim) where {T} = known_strides(T, to_dims(T, dim)) -function known_strides(::Type{T}, dim::CanonicalInt) where {T} +function known_strides(::Type{T}, dim::IntType) where {T} # see https://github.com/JuliaLang/julia/blob/6468dcb04ea2947f43a11f556da9a5588de512a0/base/reinterpretarray.jl#L148 if ndims(T) < dim return known_length(T) @@ -678,7 +678,7 @@ maybe_static_step(_) = nothing end strides(a, dim) = strides(a, to_dims(a, dim)) -function strides(a::A, dim::CanonicalInt) where {A} +function strides(a::A, dim::IntType) where {A} if is_forwarding_wrapper(A) return strides(parent(a), dim) else From 3163aba9a051ec4dd3475f921a1ca402993cd53c Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Tue, 15 Nov 2022 07:58:55 -0500 Subject: [PATCH 3/5] Bump versions and add new imports from Static --- lib/ArrayInterfaceOffsetArrays/Project.toml | 2 +- lib/ArrayInterfaceStaticArrays/Project.toml | 2 +- .../src/ArrayInterfaceStaticArrays.jl | 3 +++ src/ArrayInterface.jl | 3 ++- src/axes.jl | 3 +++ src/indexing.jl | 23 ------------------- src/ranges.jl | 6 ++--- src/size.jl | 13 ++++++++++- test/indexing.jl | 7 ------ test/ranges.jl | 12 +++++----- 10 files changed, 31 insertions(+), 43 deletions(-) diff --git a/lib/ArrayInterfaceOffsetArrays/Project.toml b/lib/ArrayInterfaceOffsetArrays/Project.toml index 1200aacf1..3de5556ef 100644 --- a/lib/ArrayInterfaceOffsetArrays/Project.toml +++ b/lib/ArrayInterfaceOffsetArrays/Project.toml @@ -10,7 +10,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" [compat] ArrayInterface = "5, 6" OffsetArrays = "1.11" -Static = "0.7" +Static = "0.7, 0.8" julia = "1.6" [extras] diff --git a/lib/ArrayInterfaceStaticArrays/Project.toml b/lib/ArrayInterfaceStaticArrays/Project.toml index 61f119631..65da71050 100644 --- a/lib/ArrayInterfaceStaticArrays/Project.toml +++ b/lib/ArrayInterfaceStaticArrays/Project.toml @@ -16,7 +16,7 @@ Adapt = "3" ArrayInterface = "6" ArrayInterfaceCore = "0.1.21" ArrayInterfaceStaticArraysCore = "0.1" -Static = "0.7" +Static = "0.8" StaticArrays = "1.2.5, 1.3, 1.4" julia = "1.6" diff --git a/lib/ArrayInterfaceStaticArrays/src/ArrayInterfaceStaticArrays.jl b/lib/ArrayInterfaceStaticArrays/src/ArrayInterfaceStaticArrays.jl index f5c78b668..cb676189a 100644 --- a/lib/ArrayInterfaceStaticArrays/src/ArrayInterfaceStaticArrays.jl +++ b/lib/ArrayInterfaceStaticArrays/src/ArrayInterfaceStaticArrays.jl @@ -9,6 +9,9 @@ import ArrayInterfaceStaticArraysCore const CanonicalInt = Union{Int,StaticInt} +function Static.OptionallyStaticUnitRange(::StaticArrays.SOneTo{N}) where {N} + Static.OptionallyStaticUnitRange(StaticInt(1), StaticInt(N)) +end ArrayInterface.known_first(::Type{<:StaticArrays.SOneTo}) = 1 ArrayInterface.known_last(::Type{StaticArrays.SOneTo{N}}) where {N} = N ArrayInterface.known_length(::Type{StaticArrays.SOneTo{N}}) where {N} = N diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index 749dc61ea..32e28f0de 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -24,7 +24,8 @@ import ArrayInterfaceCore: known_first, known_step, known_last using Static using Static: Zero, One, nstatic, eq, ne, gt, ge, lt, le, eachop, eachop_tuple, permute, invariant_permutation, field_type, reduce_tup, find_first_eq, - OptionallyStaticUnitRange, OptionallyStaticStepRange, OptionallyStaticRange, IntType + OptionallyStaticUnitRange, OptionallyStaticStepRange, OptionallyStaticRange, IntType, + SOneTo, SUnitRange using IfElse diff --git a/src/axes.jl b/src/axes.jl index 6492a63ed..7676f38c1 100644 --- a/src/axes.jl +++ b/src/axes.jl @@ -185,6 +185,9 @@ Base.keys(x::LazyAxis) = keys(parent(x)) Base.IndexStyle(T::Type{<:LazyAxis}) = IndexStyle(parent_type(T)) +function Static.OptionallyStaticUnitRange(x::LazyAxis) + OptionallyStaticUnitRange(static_first(x), static_last(x)) +end ArrayInterfaceCore.can_change_size(@nospecialize T::Type{<:LazyAxis}) = can_change_size(fieldtype(T, :parent)) ArrayInterfaceCore.known_first(::Type{<:LazyAxis{N,P}}) where {N,P} = known_offsets(P, static(N)) diff --git a/src/indexing.jl b/src/indexing.jl index 6cc65b1a5..b98e3c286 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -1,27 +1,4 @@ -function known_lastindex(::Type{T}) where {T} - if known_offset1(T) === nothing || known_length(T) === nothing - return nothing - else - return known_length(T) - known_offset1(T) + 1 - end -end -known_lastindex(@nospecialize x) = known_lastindex(typeof(x)) - -@inline static_lastindex(x) = Static.maybe_static(known_lastindex, lastindex, x) - -function Base.first(x::AbstractVector, n::StaticInt) - @boundscheck n < 0 && throw(ArgumentError("Number of elements must be nonnegative")) - start = offset1(x) - @inbounds x[start:min((start - one(start)) + n, static_lastindex(x))] -end - -function Base.last(x::AbstractVector, n::StaticInt) - @boundscheck n < 0 && throw(ArgumentError("Number of elements must be nonnegative")) - stop = static_lastindex(x) - @inbounds x[max(offset1(x), (stop + one(stop)) - n):stop] -end - """ ArrayInterface.to_indices(A, I::Tuple) -> Tuple diff --git a/src/ranges.jl b/src/ranges.jl index c7767db3f..0fda69b17 100644 --- a/src/ranges.jl +++ b/src/ranges.jl @@ -36,7 +36,7 @@ Returns valid indices for the entire length of each array in `x`. """ @propagate_inbounds function indices(x::Tuple) inds = map(eachindex, x) - return reduce_tup(static_promote, inds) + return Base.Slice(reduce_tup(static_promote, inds)) end """ @@ -46,7 +46,7 @@ Returns valid indices for each array in `x` along dimension `dim` """ @propagate_inbounds function indices(x::Tuple, dim) inds = map(Base.Fix2(indices, dim), x) - return reduce_tup(static_promote, inds) + return Base.Slice(reduce_tup(static_promote, inds)) end """ @@ -57,7 +57,7 @@ respective array (`dim`). """ @propagate_inbounds function indices(x::Tuple, dim::Tuple) inds = map(indices, x, dim) - return reduce_tup(static_promote, inds) + return Base.Slice(reduce_tup(static_promote, inds)) end """ diff --git a/src/size.jl b/src/size.jl index 9c9e8d27e..f60f06757 100644 --- a/src/size.jl +++ b/src/size.jl @@ -141,7 +141,18 @@ end if is_forwarding_wrapper(T) return known_size(parent_type(T)) else - return (_range_length(known_first(T), known_step(T), known_last(T)),) + start = known_first(T) + s = known_step(T) + stop = known_last(T) + if stop !== nothing && s !== nothing && start !== nothing + if s > 0 + return (stop < start ? 0 : div(stop - start, s) + 1,) + else + return (stop > start ? 0 : div(start - stop, -s) + 1,) + end + else + return (nothing,) + end end end diff --git a/test/indexing.jl b/test/indexing.jl index c0eec4907..edef14b3e 100644 --- a/test/indexing.jl +++ b/test/indexing.jl @@ -240,13 +240,6 @@ end end end -@testset "n-first/last" begin - x = MArray([1, 2, 3, 4]) - n = static(2) - @test @inferred(first(x, n)) == [1, 2] - @test @inferred(last(x, n)) == [3, 4] -end - A = zeros(3, 4, 5); A[:] = 1:60 Ap = @view(PermutedDimsArray(A, (3, 1, 2))[:, 1:2, 1])'; diff --git a/test/ranges.jl b/test/ranges.jl index 846db0931..050362738 100644 --- a/test/ranges.jl +++ b/test/ranges.jl @@ -19,7 +19,7 @@ end A23 = ones(2,3); SA23 = MArray(A23); A32 = ones(3,2); - SA32 = MArray(A32) + SA32 = MArray(A32); @test @inferred(ArrayInterface.indices(A23, (static(1),static(2)))) === (Base.Slice(StaticInt(1):2), Base.Slice(StaticInt(1):3)) @test @inferred(ArrayInterface.indices((A23, A32))) == 1:6 @@ -40,11 +40,11 @@ end @test @inferred(ArrayInterface.indices((A23, SA23), StaticInt(1))) === Base.Slice(StaticInt(1):StaticInt(2)) @test @inferred(ArrayInterface.indices((SA23, SA23), StaticInt(1))) === Base.Slice(StaticInt(1):StaticInt(2)) - @test_throws AssertionError ArrayInterface.indices((A23, ones(3, 3)), 1) - @test_throws AssertionError ArrayInterface.indices((A23, ones(3, 3)), (1, 2)) - @test_throws AssertionError ArrayInterface.indices((SA23, ones(3, 3)), StaticInt(1)) - @test_throws AssertionError ArrayInterface.indices((SA23, ones(3, 3)), (StaticInt(1), 2)) - @test_throws AssertionError ArrayInterface.indices((SA23, SA23), (StaticInt(1), StaticInt(2))) + @test_throws ErrorException ArrayInterface.indices((A23, ones(3, 3)), 1) + @test_throws ErrorException ArrayInterface.indices((A23, ones(3, 3)), (1, 2)) + @test_throws ErrorException ArrayInterface.indices((SA23, ones(3, 3)), StaticInt(1)) + @test_throws ErrorException ArrayInterface.indices((SA23, ones(3, 3)), (StaticInt(1), 2)) + @test_throws ErrorException ArrayInterface.indices((SA23, SA23), (StaticInt(1), StaticInt(2))) @test size(similar(ones(2, 4), ArrayInterface.indices(ones(2, 4), 1), ArrayInterface.indices(ones(2, 4), 2))) == (2, 4) @test axes(ArrayInterface.indices(ones(2,2))) === (StaticInt(1):4,) From 0c2d8706bc8a81dae8db285bda7a1bbc924e8e91 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Tue, 15 Nov 2022 22:08:34 -0500 Subject: [PATCH 4/5] Add deprecation for _pick_range --- src/ranges.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/ranges.jl b/src/ranges.jl index 0fda69b17..26dbac86f 100644 --- a/src/ranges.jl +++ b/src/ranges.jl @@ -36,7 +36,7 @@ Returns valid indices for the entire length of each array in `x`. """ @propagate_inbounds function indices(x::Tuple) inds = map(eachindex, x) - return Base.Slice(reduce_tup(static_promote, inds)) + return reduce_tup(static_promote, inds) end """ @@ -46,7 +46,7 @@ Returns valid indices for each array in `x` along dimension `dim` """ @propagate_inbounds function indices(x::Tuple, dim) inds = map(Base.Fix2(indices, dim), x) - return Base.Slice(reduce_tup(static_promote, inds)) + return reduce_tup(static_promote, inds) end """ @@ -57,7 +57,7 @@ respective array (`dim`). """ @propagate_inbounds function indices(x::Tuple, dim::Tuple) inds = map(indices, x, dim) - return Base.Slice(reduce_tup(static_promote, inds)) + return reduce_tup(static_promote, inds) end """ @@ -68,3 +68,6 @@ Returns valid indices for array `x` along each dimension specified in `dim`. @inline indices(x, dims::Tuple) = _indices(x, dims) _indices(x, dims::Tuple) = (indices(x, first(dims)), _indices(x, tail(dims))...) _indices(x, ::Tuple{}) = () + +@deprecate _pick_range(x, y) Base.Slice(Static.static_promote(x, y)) false + From 5d41ce5efc8dc8a70ab9db5149938997b1be15db Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Wed, 16 Nov 2022 19:39:41 -0500 Subject: [PATCH 5/5] Remove `_pick_range` --- src/ranges.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/ranges.jl b/src/ranges.jl index 26dbac86f..93ad803ac 100644 --- a/src/ranges.jl +++ b/src/ranges.jl @@ -69,5 +69,3 @@ Returns valid indices for array `x` along each dimension specified in `dim`. _indices(x, dims::Tuple) = (indices(x, first(dims)), _indices(x, tail(dims))...) _indices(x, ::Tuple{}) = () -@deprecate _pick_range(x, y) Base.Slice(Static.static_promote(x, y)) false -