|
1 | 1 | module StaticArrayInterfaceStaticArraysExt |
2 | 2 |
|
3 | | -using Adapt |
4 | | -using ArrayInterface |
| 3 | +using StaticArrayInterface |
5 | 4 | using LinearAlgebra |
| 5 | +using Static |
| 6 | +using Static: StaticInt |
| 7 | + |
6 | 8 | if isdefined(Base, :get_extension) |
7 | 9 | using StaticArrays |
8 | | - using Static |
9 | | - using Static: StaticInt |
10 | 10 | else |
11 | 11 | using ..StaticArrays |
12 | | - using ..Static |
13 | | - using ..Static: StaticInt |
14 | 12 | end |
15 | 13 |
|
16 | 14 | const CanonicalInt = Union{Int,StaticInt} |
17 | 15 |
|
18 | 16 | function Static.OptionallyStaticUnitRange(::StaticArrays.SOneTo{N}) where {N} |
19 | 17 | Static.OptionallyStaticUnitRange(StaticInt(1), StaticInt(N)) |
20 | 18 | end |
21 | | -ArrayInterface.known_first(::Type{<:StaticArrays.SOneTo}) = 1 |
22 | | -ArrayInterface.known_last(::Type{StaticArrays.SOneTo{N}}) where {N} = N |
23 | | -ArrayInterface.known_length(::Type{StaticArrays.SOneTo{N}}) where {N} = N |
24 | | -ArrayInterface.known_length(::Type{StaticArrays.Length{L}}) where {L} = L |
25 | | -function ArrayInterface.known_length(::Type{A}) where {A<:StaticArrays.StaticArray} |
26 | | - ArrayInterface.known_length(StaticArrays.Length(A)) |
| 19 | +StaticArrayInterface.known_first(::Type{<:StaticArrays.SOneTo}) = 1 |
| 20 | +StaticArrayInterface.known_last(::Type{StaticArrays.SOneTo{N}}) where {N} = N |
| 21 | +StaticArrayInterface.known_length(::Type{StaticArrays.SOneTo{N}}) where {N} = N |
| 22 | +StaticArrayInterface.known_length(::Type{StaticArrays.Length{L}}) where {L} = L |
| 23 | +function StaticArrayInterface.known_length(::Type{A}) where {A<:StaticArrays.StaticArray} |
| 24 | + StaticArrayInterface.known_length(StaticArrays.Length(A)) |
27 | 25 | end |
28 | 26 |
|
29 | | -@inline ArrayInterface.static_length(x::StaticArrays.StaticArray) = Static.maybe_static(ArrayInterface.known_length, Base.length, x) |
30 | | -ArrayInterface.device(::Type{<:StaticArrays.MArray}) = ArrayInterface.CPUPointer() |
31 | | -ArrayInterface.device(::Type{<:StaticArrays.SArray}) = ArrayInterface.CPUTuple() |
32 | | -ArrayInterface.contiguous_axis(::Type{<:StaticArrays.StaticArray}) = StaticInt{1}() |
33 | | -ArrayInterface.contiguous_batch_size(::Type{<:StaticArrays.StaticArray}) = StaticInt{0}() |
34 | | -function ArrayInterface.stride_rank(::Type{T}) where {N,T<:StaticArray{<:Any,<:Any,N}} |
| 27 | +@inline StaticArrayInterface.static_length(x::StaticArrays.StaticArray) = Static.maybe_static(StaticArrayInterface.known_length, Base.length, x) |
| 28 | +StaticArrayInterface.device(::Type{<:StaticArrays.MArray}) = StaticArrayInterface.CPUPointer() |
| 29 | +StaticArrayInterface.device(::Type{<:StaticArrays.SArray}) = StaticArrayInterface.CPUTuple() |
| 30 | +StaticArrayInterface.contiguous_axis(::Type{<:StaticArrays.StaticArray}) = StaticInt{1}() |
| 31 | +StaticArrayInterface.contiguous_batch_size(::Type{<:StaticArrays.StaticArray}) = StaticInt{0}() |
| 32 | +function StaticArrayInterface.stride_rank(::Type{T}) where {N,T<:StaticArray{<:Any,<:Any,N}} |
35 | 33 | ntuple(static, StaticInt(N)) |
36 | 34 | end |
37 | | -function ArrayInterface.dense_dims(::Type{<:StaticArray{S,T,N}}) where {S,T,N} |
38 | | - ArrayInterface._all_dense(Val(N)) |
| 35 | +function StaticArrayInterface.dense_dims(::Type{<:StaticArray{S,T,N}}) where {S,T,N} |
| 36 | + StaticArrayInterface._all_dense(Val(N)) |
39 | 37 | end |
40 | | -ArrayInterface.defines_strides(::Type{<:StaticArrays.SArray}) = true |
41 | | -ArrayInterface.defines_strides(::Type{<:StaticArrays.MArray}) = true |
| 38 | +StaticArrayInterface.defines_strides(::Type{<:StaticArrays.SArray}) = true |
| 39 | +StaticArrayInterface.defines_strides(::Type{<:StaticArrays.MArray}) = true |
42 | 40 |
|
43 | | -@generated function ArrayInterface.axes_types(::Type{<:StaticArrays.StaticArray{S}}) where {S} |
| 41 | +@generated function StaticArrayInterface.axes_types(::Type{<:StaticArrays.StaticArray{S}}) where {S} |
44 | 42 | Tuple{[StaticArrays.SOneTo{s} for s in S.parameters]...} |
45 | 43 | end |
46 | | -@generated function ArrayInterface.static_size(A::StaticArrays.StaticArray{S}) where {S} |
| 44 | +@generated function StaticArrayInterface.static_size(A::StaticArrays.StaticArray{S}) where {S} |
47 | 45 | t = Expr(:tuple) |
48 | 46 | Sp = S.parameters |
49 | 47 | for n = 1:length(Sp) |
50 | 48 | push!(t.args, Expr(:call, Expr(:curly, :StaticInt, Sp[n]))) |
51 | 49 | end |
52 | 50 | return t |
53 | 51 | end |
54 | | -@generated function ArrayInterface.static_strides(A::StaticArrays.StaticArray{S}) where {S} |
| 52 | +@generated function StaticArrayInterface.static_strides(A::StaticArrays.StaticArray{S}) where {S} |
55 | 53 | t = Expr(:tuple, Expr(:call, Expr(:curly, :StaticInt, 1))) |
56 | 54 | Sp = S.parameters |
57 | 55 | x = 1 |
|
61 | 59 | return t |
62 | 60 | end |
63 | 61 | if StaticArrays.SizedArray{Tuple{8,8},Float64,2,2} isa UnionAll |
64 | | - @inline ArrayInterface.static_strides(B::StaticArrays.SizedArray{S,T,M,N,A}) where {S,T,M,N,A<:SubArray} = ArrayInterface.static_strides(B.data) |
65 | | - ArrayInterface.parent_type(::Type{<:StaticArrays.SizedArray{S,T,M,N,A}}) where {S,T,M,N,A} = A |
| 62 | + @inline StaticArrayInterface.static_strides(B::StaticArrays.SizedArray{S,T,M,N,A}) where {S,T,M,N,A<:SubArray} = StaticArrayInterface.static_strides(B.data) |
| 63 | + StaticArrayInterface.parent_type(::Type{<:StaticArrays.SizedArray{S,T,M,N,A}}) where {S,T,M,N,A} = A |
66 | 64 | else |
67 | | - ArrayInterface.parent_type(::Type{<:StaticArrays.SizedArray{S,T,M,N}}) where {S,T,M,N} = Array{T,N} |
| 65 | + StaticArrayInterface.parent_type(::Type{<:StaticArrays.SizedArray{S,T,M,N}}) where {S,T,M,N} = Array{T,N} |
68 | 66 | end |
69 | 67 |
|
70 | 68 | end # module |
0 commit comments