Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Static"
uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
authors = ["chriselrod", "ChrisRackauckas", "Tokazama"]
version = "1.3.0"
version = "1.3.1"

[deps]
CommonWorldInvalidations = "f70d9fcc-98c5-4d4a-abd7-e4cdeebd8ca8"
Expand Down
29 changes: 15 additions & 14 deletions src/Static.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ import IfElse: ifelse
using SciMLPublic: @public

export StaticInt, StaticFloat64, StaticSymbol, True, False, StaticBool, NDIndex
export dynamic, is_static, known, static, static_promote, static_first, static_step,
static_last
export dynamic, is_static, known, static, static_promote

@public OptionallyStaticRange,
OptionallyStaticUnitRange, OptionallyStaticStepRange, SUnitRange, SOneTo
Expand Down Expand Up @@ -353,18 +352,20 @@ julia> static_promote(1:2:9, static(1):static(2):static(9))
static(1):static(2):static(9)
```
"""
Base.@propagate_inbounds @inline function static_promote(x::AbstractUnitRange{<:Integer},
y::AbstractUnitRange{<:Integer})
fst = static_promote(static_first(x), static_first(y))
lst = static_promote(static_last(x), static_last(y))
return OptionallyStaticUnitRange(fst, lst)
end
Base.@propagate_inbounds @inline function static_promote(x::AbstractRange{<:Integer},
y::AbstractRange{<:Integer})
fst = static_promote(static_first(x), static_first(y))
stp = static_promote(static_step(x), static_step(y))
lst = static_promote(static_last(x), static_last(y))
return _OptionallyStaticStepRange(fst, stp, lst)
@inline function static_promote(
x0::AbstractRange{<:Integer},
y0::AbstractRange{<:Integer},
)
x = OptionallyStaticStepRange(x0)
y = OptionallyStaticStepRange(y0)
fst = static_promote(getfield(x, :start), getfield(y, :start))
stp = static_promote(getfield(x, :step), getfield(y, :step))
lst = static_promote(getfield(x, :stop), getfield(y, :stop))
if isa(stp, One)
return _OptionallyStaticUnitRange(fst, lst)
else
return _OptionallyStaticStepRange(fst, stp, lst)
end
end
function static_promote(x::Base.Slice, y::Base.Slice)
Base.Slice(static_promote(x.indices, y.indices))
Expand Down
184 changes: 69 additions & 115 deletions src/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,9 @@ struct OptionallyStaticUnitRange{F <: IntType, L <: IntType} <:
start::F
stop::L

function OptionallyStaticUnitRange(start::IntType,
stop::IntType)
global function _OptionallyStaticUnitRange(start::IntType, stop::IntType)
new{typeof(start), typeof(stop)}(start, stop)
end
function OptionallyStaticUnitRange(start, stop)
OptionallyStaticUnitRange(IntType(start), IntType(stop))
end
OptionallyStaticUnitRange(@nospecialize x::OptionallyStaticUnitRange) = x
function OptionallyStaticUnitRange(x::AbstractRange)
step(x) == 1 && return OptionallyStaticUnitRange(static_first(x), static_last(x))

errmsg(x) = throw(ArgumentError("step must be 1, got $(step(x))")) # avoid GC frame
errmsg(x)
end
function OptionallyStaticUnitRange{F, L}(x::AbstractRange) where {F, L}
OptionallyStaticUnitRange(x)
end
function OptionallyStaticUnitRange{StaticInt{F}, StaticInt{L}}() where {F, L}
new{StaticInt{F}, StaticInt{L}}()
end
end

"""
Expand Down Expand Up @@ -97,9 +80,20 @@ function OptionallyStaticStepRange(@nospecialize(start::IntType),
end
end
OptionallyStaticStepRange(@nospecialize x::OptionallyStaticStepRange) = x
function OptionallyStaticStepRange(x::Union{Base.Slice, Base.IdentityUnitRange})
OptionallyStaticStepRange(x.indices)
end
function OptionallyStaticStepRange(x::Base.OneTo)
_OptionallyStaticStepRange(static(1), static(1), Int(last(x)))
end
function OptionallyStaticStepRange(x::OptionallyStaticUnitRange)
_OptionallyStaticStepRange(getfield(x, :start), static(1), getfield(x, :stop))
end
function OptionallyStaticStepRange(x::AbstractUnitRange)
_OptionallyStaticStepRange(Int(first(x)), static(1), Int(last(x)))
end
function OptionallyStaticStepRange(x::AbstractRange)
_OptionallyStaticStepRange(IntType(static_first(x)), IntType(static_step(x)),
IntType(static_last(x)))
_OptionallyStaticStepRange(Int(first(x)), Int(step(x)), Int(last(x)))
end

# to make StepRange constructor inlineable, so optimizer can see `step` value
Expand Down Expand Up @@ -129,6 +123,37 @@ end
end
end

OptionallyStaticUnitRange(@nospecialize x::OptionallyStaticUnitRange) = x
OptionallyStaticUnitRange(x::Base.OneTo) = OptionallyStaticUnitRange(static(1), Int(last(x)))
function OptionallyStaticUnitRange(x::Union{Base.Slice, Base.IdentityUnitRange})
OptionallyStaticUnitRange(x.indices)
end
function OptionallyStaticUnitRange(x::OptionallyStaticStepRange)
assert_unit_step(step(x))
_OptionallyStaticUnitRange(getfield(x, :start), getfield(x, :stop))
end
function OptionallyStaticUnitRange(x::AbstractRange)
assert_unit_step(step(x))
_OptionallyStaticUnitRange(first(x), last(x))
end
function OptionallyStaticUnitRange{F, L}(x::AbstractRange) where {F, L}
OptionallyStaticUnitRange(x)
end
function OptionallyStaticUnitRange(start::IntType, stop::IntType)
_OptionallyStaticUnitRange(start, stop)
end
function OptionallyStaticUnitRange(start, stop)
OptionallyStaticUnitRange(IntType(start), IntType(stop))
end
function OptionallyStaticUnitRange{StaticInt{F}, StaticInt{L}}() where {F, L}
_OptionallyStaticUnitRange(StaticInt{F}(), StaticInt{L}())
end
function assert_unit_step(s::Int)
s == 1 && return nothing
errmsg(x) = throw(ArgumentError(LazyString("step must be 1, got ", s))) # avoid GC frame
errmsg(s)
end

"""
SUnitRange(start::Int, stop::Int)

Expand All @@ -150,82 +175,6 @@ const OptionallyStaticRange{
F, L} = Union{OptionallyStaticUnitRange{F, L},
OptionallyStaticStepRange{F, <:Any, L}}

"""
static_first(x::AbstractRange)

Attempt to return `static(first(x))`, if known at compile time. Otherwise, return
`first(x)`.

See also: [`static_step`](@ref), [`static_last`](@ref)

# Examples

```julia
julia> static_first(static(2):10)
static(2)

julia> static_first(1:10)
1

julia> static_first(Base.OneTo(10))
static(1)

```
"""
static_first(x::Base.OneTo) = StaticInt(1)
static_first(x::Union{Base.Slice, Base.IdentityUnitRange}) = static_first(x.indices)
static_first(x::OptionallyStaticRange) = getfield(x, :start)
static_first(x) = first(x)

"""
static_step(x::AbstractRange)

Attempt to return `static(step(x))`, if known at compile time. Otherwise, return
`step(x)`.

See also: [`static_first`](@ref), [`static_last`](@ref)

# Examples

```julia
julia> static_step(static(1):static(3):9)
static(3)

julia> static_step(1:3:9)
3

julia> static_step(1:9)
static(1)

```
"""
static_step(@nospecialize x::AbstractUnitRange) = StaticInt(1)
static_step(x::OptionallyStaticStepRange) = getfield(x, :step)
static_step(x) = step(x)

"""
static_last(x::AbstractRange)

Attempt to return `static(last(x))`, if known at compile time. Otherwise, return
`last(x)`.

See also: [`static_first`](@ref), [`static_step`](@ref)

# Examples

```julia
julia> static_last(static(1):static(10))
static(10)

julia> static_last(static(1):10)
10

```
"""
static_last(x::OptionallyStaticRange) = getfield(x, :stop)
static_last(x) = last(x)
static_last(x::Union{Base.Slice, Base.IdentityUnitRange}) = static_last(x.indices)

Base.first(x::OptionallyStaticRange{Int}) = getfield(x, :start)
Base.first(::OptionallyStaticRange{StaticInt{F}}) where {F} = F
Base.step(x::OptionallyStaticStepRange{<:Any, Int}) = getfield(x, :step)
Expand Down Expand Up @@ -285,12 +234,15 @@ function Base.checkindex(::Type{Bool},
(F1::Int <= F2::Int) && (L1::Int >= L2::Int)
end

function Base.getindex(r::OptionallyStaticUnitRange,
s::AbstractUnitRange{<:Integer})
function Base.getindex(
r::OptionallyStaticUnitRange,
i::AbstractUnitRange{<:Integer},
)
s = OptionallyStaticUnitRange(i)
@boundscheck checkbounds(r, s)
f = static_first(r)
f = getfield(r, :start)
fnew = f - one(f)
return (fnew + static_first(s)):(fnew + static_last(s))
return (fnew + getfield(s, :start)):(fnew + getfield(s, :stop))
end

function Base.getindex(x::OptionallyStaticUnitRange{StaticInt{1}}, i::Int)
Expand Down Expand Up @@ -320,11 +272,10 @@ end

Base.AbstractUnitRange{Int}(r::OptionallyStaticUnitRange) = r
function Base.AbstractUnitRange{T}(r::OptionallyStaticUnitRange) where {T}
start = static_first(r)
if isa(start, StaticInt{1}) && T <: Integer
return Base.OneTo{T}(T(static_last(r)))
if isa(getfield(r, :start), StaticInt{1}) && T <: Integer
return Base.OneTo{T}(T(last(r)))
else
return UnitRange{T}(T(static_first(r)), T(static_last(r)))
return UnitRange{T}(T(first(r)), T(last(r)))
end
end

Expand Down Expand Up @@ -373,7 +324,10 @@ end
Base.axes1(x::Base.Slice{<:OptionallyStaticUnitRange{One}}) = x.indices
Base.axes1(x::Base.Slice{<:OptionallyStaticRange}) = Base.IdentityUnitRange(x.indices)

Base.:(-)(r::OptionallyStaticRange) = (-static_first(r)):(-static_step(r)):(-static_last(r))
function Base.:(-)(r::OptionallyStaticRange)
s = isa(r, OptionallyStaticStepRange) ? -getfield(r, :step) : -One()
(-getfield(r, :start)):s:(-getfield(r, :stop))
end

function Base.reverse(x::OptionallyStaticUnitRange)
_OptionallyStaticStepRange(getfield(x, :stop), StaticInt(-1), getfield(x, :start))
Expand Down Expand Up @@ -435,25 +389,25 @@ end

function Base.first(x::OptionallyStaticUnitRange, n::IntType)
n < 0 && throw(ArgumentError("Number of elements must be nonnegative"))
start = static_first(x)
OptionallyStaticUnitRange(start, min(start - one(start) + n, static_last(x)))
start = getfield(x, :start)
OptionallyStaticUnitRange(start, min(start - one(start) + n, getfield(x, :stop)))
end
function Base.first(x::OptionallyStaticStepRange, n::IntType)
n < 0 && throw(ArgumentError("Number of elements must be nonnegative"))
start = static_first(x)
s = static_step(x)
stop = min(((n - one(n)) * s) + static_first(x), static_last(x))
start = getfield(x, :start)
s = getfield(x, :step)
stop = min(((n - one(n)) * s) + start, getfield(x, :stop))
OptionallyStaticStepRange(start, s, stop)
end
function Base.last(x::OptionallyStaticUnitRange, n::IntType)
n < 0 && throw(ArgumentError("Number of elements must be nonnegative"))
stop = static_last(x)
OptionallyStaticUnitRange(max(stop + one(stop) - n, static_first(x)), stop)
stop = getfield(x, :stop)
OptionallyStaticUnitRange(max(stop + one(stop) - n, getfield(x, :start)), stop)
end
function Base.last(x::OptionallyStaticStepRange, n::IntType)
n < 0 && throw(ArgumentError("Number of elements must be nonnegative"))
start = static_first(x)
s = static_step(x)
stop = static_last(x)
start = getfield(x, :start)
s = getfield(x, :step)
stop = getfield(x, :stop)
OptionallyStaticStepRange(max(stop + one(stop) - (n * s), start), s, stop)
end
3 changes: 0 additions & 3 deletions test/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,6 @@ end
@test_throws BoundsError getindex(Static.OptionallyStaticUnitRange(static(1), 10), 11)
@test_throws BoundsError getindex(Static.OptionallyStaticStepRange(static(1), 2, 10), 11)

@test Static.static_first(Base.OneTo(one(UInt))) === static(1)
@test Static.static_step(Base.OneTo(one(UInt))) === static(1)

@test @inferred(eachindex(static(-7):static(7))) === static(1):static(15)
@test @inferred((static(-7):static(7))[first(eachindex(static(-7):static(7)))]) == -7

Expand Down
Loading