|
| 1 | +module StructArraysStaticArraysExt |
| 2 | + |
| 3 | +using StructArrays |
| 4 | +using StaticArrays: StaticArray, FieldArray, tuple_prod |
| 5 | + |
| 6 | +""" |
| 7 | + StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T} |
| 8 | +
|
| 9 | +The `staticschema` of a `StaticArray` element type is the `staticschema` of the underlying `Tuple`. |
| 10 | +```julia |
| 11 | +julia> StructArrays.staticschema(SVector{2, Float64}) |
| 12 | +Tuple{Float64, Float64} |
| 13 | +``` |
| 14 | +The one exception to this rule is `<:StaticArrays.FieldArray`, since `FieldArray` is based on a |
| 15 | +struct. In this case, `staticschema(<:FieldArray)` returns the `staticschema` for the struct |
| 16 | +which subtypes `FieldArray`. |
| 17 | +""" |
| 18 | +@generated function StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T} |
| 19 | + return quote |
| 20 | + Base.@_inline_meta |
| 21 | + return NTuple{$(tuple_prod(S)), T} |
| 22 | + end |
| 23 | +end |
| 24 | +StructArrays.createinstance(::Type{T}, args...) where {T<:StaticArray} = T(args) |
| 25 | +StructArrays.component(s::StaticArray, i) = getindex(s, i) |
| 26 | + |
| 27 | +# invoke general fallbacks for a `FieldArray` type. |
| 28 | +@inline function StructArrays.staticschema(T::Type{<:FieldArray}) |
| 29 | + invoke(StructArrays.staticschema, Tuple{Type{<:Any}}, T) |
| 30 | +end |
| 31 | +StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i) |
| 32 | +StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(StructArrays.createinstance, Tuple{Type{<:Any}, Vararg}, T, args...) |
| 33 | + |
| 34 | +# Broadcast overload |
| 35 | +using StaticArrays: StaticArrayStyle, similar_type, Size, SOneTo |
| 36 | +using StaticArrays: broadcast_flatten, broadcast_sizes, first_statictype |
| 37 | +using StructArrays: isnonemptystructtype |
| 38 | +using Base.Broadcast: Broadcasted, _broadcast_getindex |
| 39 | + |
| 40 | +# StaticArrayStyle has no similar defined. |
| 41 | +# Overload `try_struct_copy` instead. |
| 42 | +@inline function StructArrays.try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M} |
| 43 | + flat = broadcast_flatten(bc); as = flat.args; f = flat.f |
| 44 | + argsizes = broadcast_sizes(as...) |
| 45 | + ax = axes(bc) |
| 46 | + ax isa Tuple{Vararg{SOneTo}} || error("Dimension is not static. Please file a bug at `StaticArrays.jl`.") |
| 47 | + return _broadcast(f, Size(map(length, ax)), argsizes, as...) |
| 48 | +end |
| 49 | + |
| 50 | +# A functor generates the ith component of StructStaticBroadcast. |
| 51 | +struct Similar_ith{SA, E<:Tuple} |
| 52 | + elements::E |
| 53 | + Similar_ith{SA}(elements::Tuple) where {SA} = new{SA, typeof(elements)}(elements) |
| 54 | +end |
| 55 | +function (s::Similar_ith{SA})(i::Int) where {SA} |
| 56 | + ith_elements = ntuple(Val(length(s.elements))) do j |
| 57 | + getfield(s.elements[j], i) |
| 58 | + end |
| 59 | + ith_SA = similar_type(SA, fieldtype(eltype(SA), i)) |
| 60 | + return @inbounds ith_SA(ith_elements) |
| 61 | +end |
| 62 | + |
| 63 | +@inline function _broadcast(f, sz::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where {newsize} |
| 64 | + first_staticarray = first_statictype(a...) |
| 65 | + elements, ET = if prod(newsize) == 0 |
| 66 | + # Use inference to get eltype in empty case (following StaticBroadcast defined in StaticArrays.jl) |
| 67 | + eltys = Tuple{map(eltype, a)...} |
| 68 | + (), Core.Compiler.return_type(f, eltys) |
| 69 | + else |
| 70 | + temp = __broadcast(f, sz, s, a...) |
| 71 | + temp, eltype(temp) |
| 72 | + end |
| 73 | + if isnonemptystructtype(ET) |
| 74 | + SA = similar_type(first_staticarray, ET, sz) |
| 75 | + arrs = ntuple(Similar_ith{SA}(elements), Val(fieldcount(ET))) |
| 76 | + return StructArray{ET}(arrs) |
| 77 | + else |
| 78 | + @inbounds return similar_type(first_staticarray, ET, sz)(elements) |
| 79 | + end |
| 80 | +end |
| 81 | + |
| 82 | +# The `__broadcast` kernal is copied from `StaticArrays.jl`. |
| 83 | +# see https://github.com/JuliaArrays/StaticArrays.jl/blob/master/src/broadcast.jl |
| 84 | +@generated function __broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize |
| 85 | + sizes = [sz.parameters[1] for sz ∈ s.parameters] |
| 86 | + |
| 87 | + indices = CartesianIndices(newsize) |
| 88 | + exprs = similar(indices, Expr) |
| 89 | + for (j, current_ind) ∈ enumerate(indices) |
| 90 | + exprs_vals = (broadcast_getindex(sz, i, current_ind) for (i, sz) in enumerate(sizes)) |
| 91 | + exprs[j] = :(f($(exprs_vals...))) |
| 92 | + end |
| 93 | + |
| 94 | + return quote |
| 95 | + Base.@_inline_meta |
| 96 | + return tuple($(exprs...)) |
| 97 | + end |
| 98 | +end |
| 99 | + |
| 100 | +broadcast_getindex(::Tuple{}, i::Int, I::CartesianIndex) = return :(_broadcast_getindex(a[$i], $I)) |
| 101 | +function broadcast_getindex(oldsize::Tuple, i::Int, newindex::CartesianIndex) |
| 102 | + li = LinearIndices(oldsize) |
| 103 | + ind = _broadcast_getindex(li, newindex) |
| 104 | + return :(a[$i][$ind]) |
| 105 | +end |
| 106 | + |
| 107 | +end |
0 commit comments