Skip to content

Commit bef5225

Browse files
tests pass
1 parent 7635d02 commit bef5225

18 files changed

+1078
-933
lines changed

Project.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
name = "StaticArrayInterface"
2-
uuid = "8e56abe8-95dc-4044-b214-a7de3332c128"
2+
uuid = "0d7ed370-da01-4f52-bd93-41d350b8b718"
33
version = "1.0.0"
44

55
[deps]
@@ -22,6 +22,10 @@ Static = "0.8"
2222
Requires = "1"
2323
julia = "1.6"
2424

25+
[extensions]
26+
StaticArrayInterfaceOffsetArraysExt = "OffsetArrays"
27+
StaticArrayInterfaceStaticArraysExt = "StaticArrays"
28+
2529
[extras]
2630
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
2731
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
@@ -35,4 +39,3 @@ test = ["SafeTestsets", "Pkg", "Test", "OffsetArrays", "StaticArrays"]
3539
[weakdeps]
3640
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
3741
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
38-
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
module StaticArrayInterfaceOffsetArraysExt
22

3-
using ArrayInterface
3+
using StaticArrayInterface
4+
using Static
45
if isdefined(Base, :get_extension)
56
using OffsetArrays
6-
using Static
77
else
88
using ..OffsetArrays
9-
using ..Static
109
end
1110

1211
relative_offsets(r::OffsetArrays.IdOffsetRange) = (getfield(r, :offset),)
@@ -25,43 +24,43 @@ function relative_offsets(A::OffsetArrays.OffsetArray, dim::Int)
2524
return getfield(relative_offsets(A), dim)
2625
end
2726
end
28-
ArrayInterface.parent_type(::Type{<:OffsetArrays.OffsetArray{T,N,A}}) where {T,N,A} = A
27+
StaticArrayInterface.parent_type(::Type{<:OffsetArrays.OffsetArray{T,N,A}}) where {T,N,A} = A
2928
function _offset_axis_type(::Type{T}, dim::StaticInt{D}) where {T,D}
30-
OffsetArrays.IdOffsetRange{Int,ArrayInterface.axes_types(T, dim)}
29+
OffsetArrays.IdOffsetRange{Int,StaticArrayInterface.axes_types(T, dim)}
3130
end
32-
function ArrayInterface.axes_types(::Type{T}) where {T<:OffsetArrays.OffsetArray}
31+
function StaticArrayInterface.axes_types(::Type{T}) where {T<:OffsetArrays.OffsetArray}
3332
Static.eachop_tuple(
3433
_offset_axis_type,
3534
ntuple(static, StaticInt(ndims(T))),
36-
ArrayInterface.parent_type(T)
35+
StaticArrayInterface.parent_type(T)
3736
)
3837
end
39-
ArrayInterface.static_strides(A::OffsetArray) = ArrayInterface.static_strides(parent(A))
40-
function ArrayInterface.known_offsets(::Type{A}) where {A<:OffsetArrays.OffsetArray}
38+
StaticArrayInterface.static_strides(A::OffsetArray) = StaticArrayInterface.static_strides(parent(A))
39+
function StaticArrayInterface.known_offsets(::Type{A}) where {A<:OffsetArrays.OffsetArray}
4140
ntuple(identity -> nothing, Val(ndims(A)))
4241
end
43-
function ArrayInterface.offsets(A::OffsetArrays.OffsetArray)
44-
map(+, ArrayInterface.offsets(parent(A)), relative_offsets(A))
42+
function StaticArrayInterface.offsets(A::OffsetArrays.OffsetArray)
43+
map(+, StaticArrayInterface.offsets(parent(A)), relative_offsets(A))
4544
end
46-
@inline function ArrayInterface.offsets(A::OffsetArrays.OffsetArray, dim)
47-
d = ArrayInterface.to_dims(A, dim)
48-
ArrayInterface.offsets(parent(A), d) + relative_offsets(A, d)
45+
@inline function StaticArrayInterface.offsets(A::OffsetArrays.OffsetArray, dim)
46+
d = StaticArrayInterface.to_dims(A, dim)
47+
StaticArrayInterface.offsets(parent(A), d) + relative_offsets(A, d)
4948
end
50-
@inline function ArrayInterface.static_axes(A::OffsetArrays.OffsetArray)
51-
map(OffsetArrays.IdOffsetRange, ArrayInterface.static_axes(parent(A)), relative_offsets(A))
49+
@inline function StaticArrayInterface.static_axes(A::OffsetArrays.OffsetArray)
50+
map(OffsetArrays.IdOffsetRange, StaticArrayInterface.static_axes(parent(A)), relative_offsets(A))
5251
end
53-
@inline function ArrayInterface.static_axes(A::OffsetArrays.OffsetArray, dim)
54-
d = ArrayInterface.to_dims(A, dim)
55-
OffsetArrays.IdOffsetRange(ArrayInterface.static_axes(parent(A), d), relative_offsets(A, d))
52+
@inline function StaticArrayInterface.static_axes(A::OffsetArrays.OffsetArray, dim)
53+
d = StaticArrayInterface.to_dims(A, dim)
54+
OffsetArrays.IdOffsetRange(StaticArrayInterface.static_axes(parent(A), d), relative_offsets(A, d))
5655
end
57-
function ArrayInterface.stride_rank(T::Type{<:OffsetArray})
58-
ArrayInterface.stride_rank(ArrayInterface.parent_type(T))
56+
function StaticArrayInterface.stride_rank(T::Type{<:OffsetArray})
57+
StaticArrayInterface.stride_rank(StaticArrayInterface.parent_type(T))
5958
end
60-
function ArrayInterface.dense_dims(T::Type{<:OffsetArray})
61-
ArrayInterface.dense_dims(ArrayInterface.parent_type(T))
59+
function StaticArrayInterface.dense_dims(T::Type{<:OffsetArray})
60+
StaticArrayInterface.dense_dims(StaticArrayInterface.parent_type(T))
6261
end
63-
function ArrayInterface.contiguous_axis(T::Type{<:OffsetArray})
64-
ArrayInterface.contiguous_axis(ArrayInterface.parent_type(T))
62+
function StaticArrayInterface.contiguous_axis(T::Type{<:OffsetArray})
63+
StaticArrayInterface.contiguous_axis(StaticArrayInterface.parent_type(T))
6564
end
6665

6766
end # module
Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,55 @@
11
module StaticArrayInterfaceStaticArraysExt
22

3-
using Adapt
4-
using ArrayInterface
3+
using StaticArrayInterface
54
using LinearAlgebra
5+
using Static
6+
using Static: StaticInt
7+
68
if isdefined(Base, :get_extension)
79
using StaticArrays
8-
using Static
9-
using Static: StaticInt
1010
else
1111
using ..StaticArrays
12-
using ..Static
13-
using ..Static: StaticInt
1412
end
1513

1614
const CanonicalInt = Union{Int,StaticInt}
1715

1816
function Static.OptionallyStaticUnitRange(::StaticArrays.SOneTo{N}) where {N}
1917
Static.OptionallyStaticUnitRange(StaticInt(1), StaticInt(N))
2018
end
21-
ArrayInterface.known_first(::Type{<:StaticArrays.SOneTo}) = 1
22-
ArrayInterface.known_last(::Type{StaticArrays.SOneTo{N}}) where {N} = N
23-
ArrayInterface.known_length(::Type{StaticArrays.SOneTo{N}}) where {N} = N
24-
ArrayInterface.known_length(::Type{StaticArrays.Length{L}}) where {L} = L
25-
function ArrayInterface.known_length(::Type{A}) where {A<:StaticArrays.StaticArray}
26-
ArrayInterface.known_length(StaticArrays.Length(A))
19+
StaticArrayInterface.known_first(::Type{<:StaticArrays.SOneTo}) = 1
20+
StaticArrayInterface.known_last(::Type{StaticArrays.SOneTo{N}}) where {N} = N
21+
StaticArrayInterface.known_length(::Type{StaticArrays.SOneTo{N}}) where {N} = N
22+
StaticArrayInterface.known_length(::Type{StaticArrays.Length{L}}) where {L} = L
23+
function StaticArrayInterface.known_length(::Type{A}) where {A<:StaticArrays.StaticArray}
24+
StaticArrayInterface.known_length(StaticArrays.Length(A))
2725
end
2826

29-
@inline ArrayInterface.static_length(x::StaticArrays.StaticArray) = Static.maybe_static(ArrayInterface.known_length, Base.length, x)
30-
ArrayInterface.device(::Type{<:StaticArrays.MArray}) = ArrayInterface.CPUPointer()
31-
ArrayInterface.device(::Type{<:StaticArrays.SArray}) = ArrayInterface.CPUTuple()
32-
ArrayInterface.contiguous_axis(::Type{<:StaticArrays.StaticArray}) = StaticInt{1}()
33-
ArrayInterface.contiguous_batch_size(::Type{<:StaticArrays.StaticArray}) = StaticInt{0}()
34-
function ArrayInterface.stride_rank(::Type{T}) where {N,T<:StaticArray{<:Any,<:Any,N}}
27+
@inline StaticArrayInterface.static_length(x::StaticArrays.StaticArray) = Static.maybe_static(StaticArrayInterface.known_length, Base.length, x)
28+
StaticArrayInterface.device(::Type{<:StaticArrays.MArray}) = StaticArrayInterface.CPUPointer()
29+
StaticArrayInterface.device(::Type{<:StaticArrays.SArray}) = StaticArrayInterface.CPUTuple()
30+
StaticArrayInterface.contiguous_axis(::Type{<:StaticArrays.StaticArray}) = StaticInt{1}()
31+
StaticArrayInterface.contiguous_batch_size(::Type{<:StaticArrays.StaticArray}) = StaticInt{0}()
32+
function StaticArrayInterface.stride_rank(::Type{T}) where {N,T<:StaticArray{<:Any,<:Any,N}}
3533
ntuple(static, StaticInt(N))
3634
end
37-
function ArrayInterface.dense_dims(::Type{<:StaticArray{S,T,N}}) where {S,T,N}
38-
ArrayInterface._all_dense(Val(N))
35+
function StaticArrayInterface.dense_dims(::Type{<:StaticArray{S,T,N}}) where {S,T,N}
36+
StaticArrayInterface._all_dense(Val(N))
3937
end
40-
ArrayInterface.defines_strides(::Type{<:StaticArrays.SArray}) = true
41-
ArrayInterface.defines_strides(::Type{<:StaticArrays.MArray}) = true
38+
StaticArrayInterface.defines_strides(::Type{<:StaticArrays.SArray}) = true
39+
StaticArrayInterface.defines_strides(::Type{<:StaticArrays.MArray}) = true
4240

43-
@generated function ArrayInterface.axes_types(::Type{<:StaticArrays.StaticArray{S}}) where {S}
41+
@generated function StaticArrayInterface.axes_types(::Type{<:StaticArrays.StaticArray{S}}) where {S}
4442
Tuple{[StaticArrays.SOneTo{s} for s in S.parameters]...}
4543
end
46-
@generated function ArrayInterface.static_size(A::StaticArrays.StaticArray{S}) where {S}
44+
@generated function StaticArrayInterface.static_size(A::StaticArrays.StaticArray{S}) where {S}
4745
t = Expr(:tuple)
4846
Sp = S.parameters
4947
for n = 1:length(Sp)
5048
push!(t.args, Expr(:call, Expr(:curly, :StaticInt, Sp[n])))
5149
end
5250
return t
5351
end
54-
@generated function ArrayInterface.static_strides(A::StaticArrays.StaticArray{S}) where {S}
52+
@generated function StaticArrayInterface.static_strides(A::StaticArrays.StaticArray{S}) where {S}
5553
t = Expr(:tuple, Expr(:call, Expr(:curly, :StaticInt, 1)))
5654
Sp = S.parameters
5755
x = 1
@@ -61,10 +59,10 @@ end
6159
return t
6260
end
6361
if StaticArrays.SizedArray{Tuple{8,8},Float64,2,2} isa UnionAll
64-
@inline ArrayInterface.static_strides(B::StaticArrays.SizedArray{S,T,M,N,A}) where {S,T,M,N,A<:SubArray} = ArrayInterface.static_strides(B.data)
65-
ArrayInterface.parent_type(::Type{<:StaticArrays.SizedArray{S,T,M,N,A}}) where {S,T,M,N,A} = A
62+
@inline StaticArrayInterface.static_strides(B::StaticArrays.SizedArray{S,T,M,N,A}) where {S,T,M,N,A<:SubArray} = StaticArrayInterface.static_strides(B.data)
63+
StaticArrayInterface.parent_type(::Type{<:StaticArrays.SizedArray{S,T,M,N,A}}) where {S,T,M,N,A} = A
6664
else
67-
ArrayInterface.parent_type(::Type{<:StaticArrays.SizedArray{S,T,M,N}}) where {S,T,M,N} = Array{T,N}
65+
StaticArrayInterface.parent_type(::Type{<:StaticArrays.SizedArray{S,T,M,N}}) where {S,T,M,N} = Array{T,N}
6866
end
6967

7068
end # module

0 commit comments

Comments
 (0)