@@ -3,7 +3,7 @@ A type that stores an array of structures as a structure of arrays.
33# Fields:
44- `fieldarrays`: a (named) tuple of arrays. Also `fieldarrays(x)`
55"""
6- struct StructArray{T, N, C<: Tup } <: AbstractArray{T, N}
6+ struct StructArray{T, N, C<: Tup , I } <: AbstractArray{T, N}
77 fieldarrays:: C
88
99 function StructArray {T, N, C} (c) where {T, N, C<: Tup }
@@ -14,10 +14,18 @@ struct StructArray{T, N, C<:Tup} <: AbstractArray{T, N}
1414 axes (c[i]) == ax || error (" all field arrays must have same shape" )
1515 end
1616 end
17- new {T, N, C} (c)
17+ new {T, N, C, _best_index(c...) } (c)
1818 end
1919end
2020
21+ _best_index () = Int
22+ _best_index (col:: AbstractArray , cols:: AbstractArray... ) = _best_index (IndexStyle (col, cols... ), col)
23+ _best_index (:: IndexLinear , :: AbstractArray ) = Int
24+ _best_index (:: IndexCartesian , :: AbstractArray{T, N} ) where {T, N} = CartesianIndex{N}
25+ _best_index (:: Type{StructArray{T, N, C, I}} ) where {T, N, C, I} = I
26+ _indexstyle (:: Type{Int} ) = IndexLinear ()
27+ _indexstyle (:: Type{CartesianIndex{N}} ) where {N} = IndexCartesian ()
28+
2129_dims (c:: Tup ) = length (axes (c[1 ]))
2230_dims (c:: EmptyTup ) = 1
2331
@@ -37,17 +45,11 @@ _structarray(args::Tuple, names) = _structarray(args, Tuple(names))
3745_structarray (args:: Tuple , :: Tuple ) = _structarray (args, nothing )
3846_structarray (args:: NTuple{N, Any} , names:: NTuple{N, Symbol} ) where {N} = StructArray (NamedTuple {names} (args))
3947
40- const StructVector{T, C<: Tup } = StructArray{T, 1 , C}
48+ const StructVector{T, C<: Tup , I } = StructArray{T, 1 , C, I }
4149StructVector {T} (args... ; kwargs... ) where {T} = StructArray {T} (args... ; kwargs... )
4250StructVector (args... ; kwargs... ) = StructArray (args... ; kwargs... )
4351
44- _indexstyle (:: Type{Tuple{}} ) = IndexStyle (Union{})
45- _indexstyle (:: Type{T} ) where {T<: Tuple } = IndexStyle (IndexStyle (tuple_type_head (T)), _indexstyle (tuple_type_tail (T)))
46- _indexstyle (:: Type{NamedTuple{names, types}} ) where {names, types} = _indexstyle (types)
47-
48- function Base. IndexStyle (:: Type{StructArray{T, N, C}} ) where {T, N, C}
49- _indexstyle (C)
50- end
52+ Base. IndexStyle (:: Type{S} ) where {S<: StructArray } = _indexstyle (_best_index (S))
5153
5254_undef_array (:: Type{T} , sz; unwrap = t -> false ) where {T} = unwrap (T) ? StructArray {T} (undef, sz; unwrap = unwrap) : Array {T} (undef, sz)
5355
@@ -124,22 +126,34 @@ function get_ith(cols::NTuple{N, Any}, I...) where N
124126 end
125127end
126128
127- Base. @propagate_inbounds function Base. getindex (x:: StructArray{T, N, C} , I:: Int... ) where {T, N, C }
129+ Base. @propagate_inbounds function Base. getindex (x:: StructArray{T, <:Any, <:Any, CartesianIndex{N}} , I:: Vararg{ Int, N} ) where {T, N}
128130 cols = fieldarrays (x)
129131 @boundscheck checkbounds (x, I... )
130132 return createinstance (T, get_ith (cols, I... )... )
131133end
132134
135+ Base. @propagate_inbounds function Base. getindex (x:: StructArray{T, <:Any, <:Any, Int} , I:: Int ) where {T}
136+ cols = fieldarrays (x)
137+ @boundscheck checkbounds (x, I)
138+ return createinstance (T, get_ith (cols, I)... )
139+ end
140+
133141function Base. view (s:: StructArray{T, N, C} , I... ) where {T, N, C}
134142 StructArray {T} (map (v -> view (v, I... ), fieldarrays (s)))
135143end
136144
137- Base. @propagate_inbounds function Base. setindex! (s:: StructArray , vals, I:: Int... )
145+ Base. @propagate_inbounds function Base. setindex! (s:: StructArray{<:Any, <:Any, <:Any, CartesianIndex{N}} , vals, I:: Vararg{ Int, N} ) where {N}
138146 @boundscheck checkbounds (s, I... )
139147 foreachfield ((col, val) -> (@inbounds col[I... ] = val), s, vals)
140148 s
141149end
142150
151+ Base. @propagate_inbounds function Base. setindex! (s:: StructArray{<:Any, <:Any, <:Any, Int} , vals, I:: Int )
152+ @boundscheck checkbounds (s, I)
153+ foreachfield ((col, val) -> (@inbounds col[I] = val), s, vals)
154+ s
155+ end
156+
143157function Base. push! (s:: StructArray , vals)
144158 foreachfield (push!, s, vals)
145159 return s
0 commit comments