Skip to content

Commit 5b5279d

Browse files
committed
OptionallyStaticRange -> OptionallyStaticUnitRange
Also indices now produces a `Base.Slice` whenever possible.
1 parent cd47c71 commit 5b5279d

File tree

2 files changed

+143
-52
lines changed

2 files changed

+143
-52
lines changed

src/ArrayInterface.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ using Requires
44
using LinearAlgebra
55
using SparseArrays
66

7+
using Base: OneTo
8+
79
Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
810
parameterless_type(x) = parameterless_type(typeof(x))
911
parameterless_type(x::Type) = __parameterless_type(x)
@@ -20,8 +22,21 @@ parent_type(::Type{Adjoint{T,S}}) where {T,S} = S
2022
parent_type(::Type{Transpose{T,S}}) where {T,S} = S
2123
parent_type(::Type{Symmetric{T,S}}) where {T,S} = S
2224
parent_type(::Type{<:LinearAlgebra.AbstractTriangular{T,S}}) where {T,S} = S
25+
parent_type(::Type{<:PermutedDimsArray{T,N,I1,I2,A}}) where {T,N,I1,I2,A} = A
26+
parent_type(::Type{Base.Slice{T}}) where {T} = T
2327
parent_type(::Type{T}) where {T} = T
2428

29+
"""
30+
known_length(::Type{T})
31+
32+
If `length` of an instance of type `T` is known at compile time, return it.
33+
Otherwise, return `nothing`.
34+
"""
35+
known_length(x) = known_length(typeof(x))
36+
known_length(::Type{<:NTuple{N,<:Any}}) where {N} = N
37+
known_length(::Type{<:NamedTuple{L}}) where {L} = length(L)
38+
known_length(::Type{T}) where {T<:Base.Slice} = known_length(parent_type(T))
39+
2540
"""
2641
can_change_size(::Type{T}) -> Bool
2742
@@ -536,6 +551,7 @@ function __init__()
536551

537552
known_first(::Type{<:StaticArrays.SOneTo}) = 1
538553
known_last(::Type{StaticArrays.SOneTo{N}}) where {N} = N
554+
known_length(::Type{StaticArrays.SOneTo{N}}) where {N} = N
539555

540556
@require Adapt="79e6a3ab-5dfb-504d-930d-738a2a938a0e" begin
541557
function Adapt.adapt_storage(::Type{<:StaticArrays.SArray{S}},xs::Array) where S

src/ranges.jl

Lines changed: 127 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

22
"""
3-
known_first(::Type{T})
3+
known_first(::Type{T})
44
55
If `first` of an instance of type `T` is known at compile time, return it.
66
Otherwise, return `nothing`.
@@ -11,9 +11,11 @@ Otherwise, return `nothing`.
1111
known_first(x) = known_first(typeof(x))
1212
known_first(::Type{T}) where {T} = nothing
1313
known_first(::Type{Base.OneTo{T}}) where {T} = one(T)
14+
known_first(::Type{T}) where {T<:Base.Slice} = known_first(parent_type(T))
15+
1416

1517
"""
16-
known_last(::Type{T})
18+
known_last(::Type{T})
1719
1820
If `last` of an instance of type `T` is known at compile time, return it.
1921
Otherwise, return `nothing`.
@@ -24,9 +26,10 @@ using StaticArrays
2426
"""
2527
known_last(x) = known_last(typeof(x))
2628
known_last(::Type{T}) where {T} = nothing
29+
known_last(::Type{T}) where {T<:Base.Slice} = known_last(parent_type(T))
2730

2831
"""
29-
known_step(::Type{T})
32+
known_step(::Type{T})
3033
3134
If `step` of an instance of type `T` is known at compile time, return it.
3235
Otherwise, return `nothing`.
@@ -38,77 +41,134 @@ known_step(x) = known_step(typeof(x))
3841
known_step(::Type{T}) where {T} = nothing
3942
known_step(::Type{<:AbstractUnitRange{T}}) where {T} = one(T)
4043

41-
_eltype(::Type{T}) where {T} = T
42-
_eltype(::Type{Val{V}}) where {V} = typeof(V)
44+
# add methods to support ArrayInterface
4345

44-
struct OptionallyStaticRange{T<:Integer,F,S,L} <: OrdinalRange{T,T}
45-
start::F
46-
step::S
47-
stop::L
46+
_get(x) = x
47+
_get(::Val{V}) where {V} = V
48+
_convert(::Type{T}, x) where {T} = convert(T, x)
49+
_convert(::Type{T}, ::Val{V}) where {T,V} = Val(convert(T, V))
4850

49-
function OptionallyStaticRange(start::F, step::S, stop::L) where {F,S,L}
50-
T = promote_type(_eltype(F), _eltype(S), eltype(L))
51-
return new{T,F,S,L}(start, step, stop)
52-
end
51+
"""
52+
OptionallyStaticUnitRange{T<:Integer}(start, stop) <: OrdinalRange{T,T}
53+
54+
This range permits diverse representations of arrays to comunicate common information
55+
about their indices. Each field may be an integer or `Val(<:Integer)` if it is known
56+
at compile time. An `OptionallyStaticUnitRange` is intended to be constructed internally
57+
from other valid indices. Therefore, users should not expect the same checks are used
58+
to ensure construction of a valid `OptionallyStaticUnitRange`.
59+
"""
60+
struct OptionallyStaticUnitRange{T,F,L} <: AbstractUnitRange{T}
61+
start::F
62+
stop::L
63+
64+
function OptionallyStaticUnitRange{T}(start, stop) where {T<:Real}
65+
if _get(start) isa T
66+
if _get(stop) isa T
67+
return new{T,typeof(start),typeof(stop)}(start, stop)
68+
else
69+
return OptionallyStaticUnitRange{T}(start, _convert(T, stop))
70+
end
71+
else
72+
return OptionallyStaticUnitRange{T}(_convert(T, start), stop)
73+
end
74+
end
75+
76+
function OptionallyStaticUnitRange(start, stop)
77+
T = promote_type(typeof(_get(start)), typeof(_get(stop)))
78+
return OptionallyStaticUnitRange{T}(start, stop)
79+
end
5380
end
5481

55-
Base.first(r::OptionallyStaticRange{<:Any,Val{F}}) where {F} = F
56-
Base.first(r::OptionallyStaticRange{<:Any,<:Any}) = getfield(r, :start)
82+
Base.first(r::OptionallyStaticUnitRange{<:Any,Val{F}}) where {F} = F
83+
Base.first(r::OptionallyStaticUnitRange{<:Any,<:Any}) = r.start
5784

58-
Base.step(r::OptionallyStaticRange{<:Any,<:Any,Val{S}}) where {S} = S
59-
Base.step(r::OptionallyStaticRange{<:Any,<:Any,<:Any}) = getfield(r, :step)
85+
Base.step(r::OptionallyStaticUnitRange{T}) where {T} = oneunit(T)
6086

61-
Base.last(r::OptionallyStaticRange{<:Any,<:Any,<:Any,Val{L}}) where {L} = L
62-
Base.last(r::OptionallyStaticRange{<:Any,<:Any,<:Any,<:Any}) = getfield(r, :stop)
87+
Base.last(r::OptionallyStaticUnitRange{<:Any,<:Any,Val{L}}) where {L} = L
88+
Base.last(r::OptionallyStaticUnitRange{<:Any,<:Any,<:Any}) = r.stop
6389

64-
ArrayInterface.known_first(::OptionallyStaticRange{<:Any,Val{F}}) where {F} = F
65-
ArrayInterface.known_step(::OptionallyStaticRange{<:Any,<:Any,Val{S}}) where {S} =S
66-
ArrayInterface.known_last(::OptionallyStaticRange{<:Any,<:Any,<:Any,Val{L}}) where {L} = L
90+
known_first(::Type{<:OptionallyStaticUnitRange{<:Any,Val{F}}}) where {F} = F
91+
known_step(::Type{<:OptionallyStaticUnitRange{T}}) where {T} = one(T)
92+
known_last(::Type{<:OptionallyStaticUnitRange{<:Any,<:Any,Val{L}}}) where {L} = L
6793

68-
function Base.isempty(r::OptionallyStaticRange)
69-
return (first(r) != last(r)) & ((step(r) > zero(step(r))) != (last(r) > first(r)))
94+
function Base.isempty(r::OptionallyStaticUnitRange)
95+
if known_first(r) === oneunit(eltype(r))
96+
return unsafe_isempty_one_to(last(r))
97+
else
98+
return unsafe_isempty_unit_range(first(r), last(r))
99+
end
70100
end
71101

72-
@inline function Base.length(r::OptionallyStaticRange{T}) where {T}
73-
if isempty(r)
74-
return zero(T)
102+
unsafe_isempty_one_to(lst) = lst <= zero(lst)
103+
unsafe_isempty_unit_range(fst, lst) = fst > lst
104+
105+
unsafe_isempty_unit_range(fst::T, lst::T) where {T} = Integer(lst - fst + one(T))
106+
107+
unsafe_length_one_to(lst::T) where {T<:Int} = T(lst)
108+
unsafe_length_one_to(lst::T) where {T} = Integer(lst - zero(lst))
109+
110+
Base.@propagate_inbounds function Base.getindex(r::OptionallyStaticUnitRange, i::Integer)
111+
if known_first(r) === oneunit(r)
112+
return get_index_one_to(r, i)
75113
else
76-
if known_step(r) === oneunit(T)
77-
if known_first(r) === oneunit(T)
78-
return last(r)
79-
else
80-
return last(r) - first(r) + step(r)
81-
end
82-
else
83-
return Integer(div((last(r) - first(r)) + step(r), step(r)))
84-
end
114+
return get_index_unit_range(r, i)
85115
end
86116
end
87117

118+
@inline function get_index_one_to(r, i)
119+
@boundscheck if ((i > 0) & (i <= last(r)))
120+
throw(BoundsError(r, i))
121+
end
122+
return convert(eltype(r), i)
123+
end
88124

89-
isempty(r::StepRange) =
90-
(r.start != r.stop) & ((r.step > zero(r.step)) != (r.stop > r.start))
91-
isempty(r::AbstractUnitRange) = first(r) > last(r)
92-
isempty(r::StepRangeLen) = length(r) == 0
93-
isempty(r::LinRange) = length(r) == 0
94-
95-
# add methods to support ArrayInterface
125+
@inline function get_index_unit_range(r, i)
126+
val = first(r) + (i - 1)
127+
@boundscheck if i > 0 && val <= last(r) && val >= first(r)
128+
throw(BoundsError(r, i))
129+
end
130+
return convert(eltype(r), val)
131+
end
96132

97133
_try_static(x, y) = Val(x)
98134
_try_static(::Nothing, y) = Val(y)
99135
_try_static(x, ::Nothing) = Val(x)
100136
_try_static(::Nothing, ::Nothing) = nothing
101137

102-
@inline function _pick_range(x, y)
103-
fst = _try_static(known_first(x), known_first(y))
104-
fst = fst === nothing ? first(x) : fst
138+
###
139+
### length
140+
###
141+
@inline function known_length(::Type{T}) where {T<:AbstractUnitRange}
142+
fst = known_first(T)
143+
lst = known_last(T)
144+
if stp === nothing || fst === nothing || lst === nothing
145+
return nothing
146+
else
147+
if fst === oneunit(eltype(T))
148+
return unsafe_length_one_to(lst)
149+
else
150+
return unsafe_length_unit_range(fst, lst)
151+
end
152+
end
153+
end
105154

106-
st = _try_static(known_step(x), known_step(y))
107-
st = st === nothing ? step(x) : st
155+
function Base.length(r::OptionallyStaticUnitRange{T}) where {T}
156+
if isempty(r)
157+
return zero(T)
158+
else
159+
if known_one(r) === one(T)
160+
return unsafe_length_one_to(last(r))
161+
else
162+
return unsafe_length_unit_range(first(r), last(r))
163+
end
164+
end
165+
end
108166

109-
lst = _try_static(known_last(x), known_last(y))
110-
lst = lst === nothing ? last(x) : lst
111-
return OptionallyStaticRange(fst, st, lst)
167+
function unsafe_length_unit_range(fst::T, lst::T) where {T<:Union{Int,Int64,Int128}}
168+
return Base.checked_add(Base.checked_sub(lst, fst), one(T))
169+
end
170+
function unsafe_length_unit_range(fst::T, lst::T) where {T<:Union{UInt,UInt64,UInt128}}
171+
return Base.checked_add(lst - fst, one(T))
112172
end
113173

114174
"""
@@ -120,7 +180,14 @@ returned. If any indices are not equal along dimension `d` an error is thrown. A
120180
tuple may be used to specify a different dimension for each array. If `d` is not
121181
specified then indices for visiting each index of `x` is returned.
122182
"""
123-
@inline indices(x) = eachindex(x)
183+
@inline function indices(x)
184+
inds = eachindex(x)
185+
if inds isa AbstractUnitRange{<:Integer}
186+
return Base.Slice(inds)
187+
else
188+
return inds
189+
end
190+
end
124191

125192
indices(x, d) = indices(axes(x, d))
126193

@@ -136,4 +203,12 @@ end
136203
return reduce(_pick_range, inds)
137204
end
138205

206+
@inline function _pick_range(x, y)
207+
fst = _try_static(known_first(x), known_first(y))
208+
fst = fst === nothing ? first(x) : fst
209+
210+
lst = _try_static(known_last(x), known_last(y))
211+
lst = lst === nothing ? last(x) : lst
212+
return Base.Slice(OptionallyStaticUnitRange(fst, lst))
213+
end
139214

0 commit comments

Comments
 (0)