Skip to content

Commit 1dc78e9

Browse files
author
Pietro Vertechi
authored
Merge pull request #78 from piever/pv/cart
Add type parameter for index style in StructArray
2 parents 5953d9d + 97e78b3 commit 1dc78e9

File tree

2 files changed

+32
-20
lines changed

2 files changed

+32
-20
lines changed

src/lazy.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
struct LazyRow{T, N, C, I}
2-
columns::StructArray{T, N, C} # a `Columns` object
2+
columns::StructArray{T, N, C, I} # a `Columns` object
33
index::I
44
end
55

@@ -30,19 +30,17 @@ iscompatible(::Type{<:LazyRow{S}}, ::Type{StructArray{T, N, C}}) where {S, T, N,
3030
(s::ArrayInitializer)(::Type{<:LazyRow{T}}, d) where {T} = buildfromschema(typ -> s(typ, d), T)
3131

3232
struct LazyRows{T, N, C, I} <: AbstractArray{LazyRow{T, N, C, I}, N}
33-
columns::StructArray{T, N, C}
33+
columns::StructArray{T, N, C, I}
3434
end
35-
LazyRows(s::S) where {S<:StructArray} = LazyRows(IndexStyle(S), s)
36-
LazyRows(::IndexLinear, s::StructArray{T, N, C}) where {T, N, C} = LazyRows{T, N, C, Int}(s)
37-
LazyRows(::IndexCartesian, s::StructArray{T, N, C}) where {T, N, C} = LazyRows{T, N, C, CartesianIndex{N}}(s)
3835
Base.parent(v::LazyRows) = getfield(v, 1)
3936
fieldarrays(v::LazyRows) = fieldarrays(parent(v))
4037

4138
Base.size(v::LazyRows) = size(parent(v))
42-
Base.getindex(v::LazyRows{<:Any, <:Any, <:Any, <:Integer}, i::Integer) = LazyRow(parent(v), i)
43-
Base.getindex(v::LazyRows{<:Any, <:Any, <:Any, <:CartesianIndex}, i::Integer...) = LazyRow(parent(v), CartesianIndex(i))
39+
Base.getindex(v::LazyRows{<:Any, <:Any, <:Any, Int}, i::Int) = LazyRow(parent(v), i)
40+
Base.getindex(v::LazyRows{<:Any, <:Any, <:Any, CartesianIndex{N}}, i::Vararg{Int, N}) where {N} = LazyRow(parent(v), CartesianIndex(i))
4441

45-
Base.IndexStyle(::Type{<:LazyRows{<:Any, <:Any, <:Any, <:Integer}}) = IndexLinear()
42+
_best_index(::Type{LazyRows{T, N, C, I}}) where {T, N, C, I} = I
43+
Base.IndexStyle(::Type{L}) where {L<:LazyRows} = _indexstyle(_best_index(L))
4644

4745
function Base.showarg(io::IO, s::LazyRows{T}, toplevel) where T
4846
print(io, "LazyRows")

src/structarray.jl

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ A type that stores an array of structures as a structure of arrays.
33
# Fields:
44
- `fieldarrays`: a (named) tuple of arrays. Also `fieldarrays(x)`
55
"""
6-
struct StructArray{T, N, C<:Tup} <: AbstractArray{T, N}
6+
struct StructArray{T, N, C<:Tup, I} <: AbstractArray{T, N}
77
fieldarrays::C
88

99
function StructArray{T, N, C}(c) where {T, N, C<:Tup}
@@ -14,10 +14,18 @@ struct StructArray{T, N, C<:Tup} <: AbstractArray{T, N}
1414
axes(c[i]) == ax || error("all field arrays must have same shape")
1515
end
1616
end
17-
new{T, N, C}(c)
17+
new{T, N, C, _best_index(c...)}(c)
1818
end
1919
end
2020

21+
_best_index() = Int
22+
_best_index(col::AbstractArray, cols::AbstractArray...) = _best_index(IndexStyle(col, cols...), col)
23+
_best_index(::IndexLinear, ::AbstractArray) = Int
24+
_best_index(::IndexCartesian, ::AbstractArray{T, N}) where {T, N} = CartesianIndex{N}
25+
_best_index(::Type{StructArray{T, N, C, I}}) where {T, N, C, I} = I
26+
_indexstyle(::Type{Int}) = IndexLinear()
27+
_indexstyle(::Type{CartesianIndex{N}}) where {N} = IndexCartesian()
28+
2129
_dims(c::Tup) = length(axes(c[1]))
2230
_dims(c::EmptyTup) = 1
2331

@@ -37,17 +45,11 @@ _structarray(args::Tuple, names) = _structarray(args, Tuple(names))
3745
_structarray(args::Tuple, ::Tuple) = _structarray(args, nothing)
3846
_structarray(args::NTuple{N, Any}, names::NTuple{N, Symbol}) where {N} = StructArray(NamedTuple{names}(args))
3947

40-
const StructVector{T, C<:Tup} = StructArray{T, 1, C}
48+
const StructVector{T, C<:Tup, I} = StructArray{T, 1, C, I}
4149
StructVector{T}(args...; kwargs...) where {T} = StructArray{T}(args...; kwargs...)
4250
StructVector(args...; kwargs...) = StructArray(args...; kwargs...)
4351

44-
_indexstyle(::Type{Tuple{}}) = IndexStyle(Union{})
45-
_indexstyle(::Type{T}) where {T<:Tuple} = IndexStyle(IndexStyle(tuple_type_head(T)), _indexstyle(tuple_type_tail(T)))
46-
_indexstyle(::Type{NamedTuple{names, types}}) where {names, types} = _indexstyle(types)
47-
48-
function Base.IndexStyle(::Type{StructArray{T, N, C}}) where {T, N, C}
49-
_indexstyle(C)
50-
end
52+
Base.IndexStyle(::Type{S}) where {S<:StructArray} = _indexstyle(_best_index(S))
5153

5254
_undef_array(::Type{T}, sz; unwrap = t -> false) where {T} = unwrap(T) ? StructArray{T}(undef, sz; unwrap = unwrap) : Array{T}(undef, sz)
5355

@@ -124,22 +126,34 @@ function get_ith(cols::NTuple{N, Any}, I...) where N
124126
end
125127
end
126128

127-
Base.@propagate_inbounds function Base.getindex(x::StructArray{T, N, C}, I::Int...) where {T, N, C}
129+
Base.@propagate_inbounds function Base.getindex(x::StructArray{T, <:Any, <:Any, CartesianIndex{N}}, I::Vararg{Int, N}) where {T, N}
128130
cols = fieldarrays(x)
129131
@boundscheck checkbounds(x, I...)
130132
return createinstance(T, get_ith(cols, I...)...)
131133
end
132134

135+
Base.@propagate_inbounds function Base.getindex(x::StructArray{T, <:Any, <:Any, Int}, I::Int) where {T}
136+
cols = fieldarrays(x)
137+
@boundscheck checkbounds(x, I)
138+
return createinstance(T, get_ith(cols, I)...)
139+
end
140+
133141
function Base.view(s::StructArray{T, N, C}, I...) where {T, N, C}
134142
StructArray{T}(map(v -> view(v, I...), fieldarrays(s)))
135143
end
136144

137-
Base.@propagate_inbounds function Base.setindex!(s::StructArray, vals, I::Int...)
145+
Base.@propagate_inbounds function Base.setindex!(s::StructArray{<:Any, <:Any, <:Any, CartesianIndex{N}}, vals, I::Vararg{Int, N}) where {N}
138146
@boundscheck checkbounds(s, I...)
139147
foreachfield((col, val) -> (@inbounds col[I...] = val), s, vals)
140148
s
141149
end
142150

151+
Base.@propagate_inbounds function Base.setindex!(s::StructArray{<:Any, <:Any, <:Any, Int}, vals, I::Int)
152+
@boundscheck checkbounds(s, I)
153+
foreachfield((col, val) -> (@inbounds col[I] = val), s, vals)
154+
s
155+
end
156+
143157
function Base.push!(s::StructArray, vals)
144158
foreachfield(push!, s, vals)
145159
return s

0 commit comments

Comments
 (0)