Skip to content

Commit 189bc9c

Browse files
author
Pietro Vertechi
authored
Fix collection of stateful iterator (#101)
* avoid repetition * readd tests
1 parent 1c447d6 commit 189bc9c

File tree

2 files changed

+29
-36
lines changed

2 files changed

+29
-36
lines changed

src/collect.jl

Lines changed: 28 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ reshapestructarray(v::AbstractArray, d) = reshape(v, d)
3030
reshapestructarray(v::StructArray{T}, d) where {T} =
3131
StructArray{T}(map(x -> reshapestructarray(x, d), fieldarrays(v)))
3232

33+
function collect_empty_structarray(itr::T; initializer = default_initializer) where {T}
34+
S = Core.Compiler.return_type(first, Tuple{T})
35+
res = initializer(S, (0,))
36+
_reshape(res, itr)
37+
end
38+
3339
"""
3440
`collect_structarray(itr, fr=iterate(itr); initializer = default_initializer)`
3541
@@ -39,28 +45,29 @@ and size `d`. By default `initializer` returns a `StructArray` of `Array` but cu
3945
may be used. `fr` represents the moment in the iteration of `itr` from which to start collecting.
4046
"""
4147
collect_structarray(itr; initializer = default_initializer) =
42-
collect_structarray(itr, iterate(itr); initializer = initializer)
43-
44-
collect_structarray(itr, fr; initializer = default_initializer) =
45-
collect_structarray(itr, fr, Base.IteratorSize(itr); initializer = initializer)
48+
_collect_structarray(itr, Base.IteratorSize(itr); initializer = initializer)
4649

47-
collect_structarray(itr, ::Nothing; initializer = default_initializer) =
48-
collect_empty_structarray(itr; initializer = initializer)
49-
50-
function collect_empty_structarray(itr::T; initializer = default_initializer) where {T}
51-
S = Core.Compiler.return_type(first, Tuple{T})
52-
res = initializer(S, (0,))
53-
_reshape(res, itr)
50+
function _collect_structarray(itr, sz::Union{Base.HasShape, Base.HasLength};
51+
initializer = default_initializer)
52+
len = length(itr)
53+
elem = iterate(itr)
54+
elem === nothing && return collect_empty_structarray(itr, initializer = initializer)
55+
el, st = elem
56+
S = typeof(el)
57+
dest = initializer(S, (len,))
58+
dest[1] = el
59+
v = collect_to_structarray!(dest, itr, 2, st)
60+
_reshape(v, itr, sz)
5461
end
5562

56-
function collect_structarray(itr, elem, sz::Union{Base.HasShape, Base.HasLength};
57-
initializer = default_initializer)
58-
el, i = elem
63+
function _collect_structarray(itr, ::Base.SizeUnknown; initializer = default_initializer)
64+
elem = iterate(itr)
65+
elem === nothing && return collect_empty_structarray(itr, initializer = initializer)
66+
el, st = elem
5967
S = typeof(el)
60-
dest = initializer(S, (length(itr),))
68+
dest = initializer(S, (1,))
6169
dest[1] = el
62-
v = collect_to_structarray!(dest, itr, 2, i)
63-
_reshape(v, itr, sz)
70+
grow_to_structarray!(dest, itr, iterate(itr, st))
6471
end
6572

6673
function collect_to_structarray!(dest::AbstractArray, itr, offs, st)
@@ -83,13 +90,6 @@ function collect_to_structarray!(dest::AbstractArray, itr, offs, st)
8390
return dest
8491
end
8592

86-
function collect_structarray(itr, elem, ::Base.SizeUnknown; initializer = default_initializer)
87-
el, st = elem
88-
dest = initializer(typeof(el), (1,))
89-
dest[1] = el
90-
grow_to_structarray!(dest, itr, iterate(itr, st))
91-
end
92-
9393
function grow_to_structarray!(dest::AbstractArray, itr, elem = iterate(itr))
9494
# collect to dest array, checking the type of each result. if a result does not
9595
# match, widen the result type and re-dispatch.
@@ -152,16 +152,10 @@ function _append!!(dest::AbstractVector, itr, ::Union{Base.HasShape, Base.HasLen
152152
fr === nothing && return dest
153153
el, st = fr
154154
i = lastindex(dest) + 1
155-
if iscompatible(el, dest)
156-
resize!(dest, length(dest) + n)
157-
@inbounds dest[i] = el
158-
return collect_to_structarray!(dest, itr, i + 1, st)
159-
else
160-
new = widenstructarray(dest, i, el)
161-
resize!(new, length(dest) + n)
162-
@inbounds new[i] = el
163-
return collect_to_structarray!(new, itr, i + 1, st)
164-
end
155+
new = iscompatible(el, dest) ? dest : widenstructarray(dest, i, el)
156+
resize!(new, length(dest) + n)
157+
@inbounds new[i] = el
158+
return collect_to_structarray!(new, itr, i + 1, st)
165159
end
166160

167161
_append!!(dest::AbstractVector, itr, ::Base.SizeUnknown) =

test/runtests.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -658,8 +658,7 @@ end
658658
("SizeUnknown", () -> (x for x in itr if isodd(x.a))),
659659
# Broken due to https://github.com/JuliaArrays/StructArrays.jl/issues/100:
660660
# ("empty", (x for x in itr if false)),
661-
# Broken due to https://github.com/JuliaArrays/StructArrays.jl/issues/99:
662-
# ("stateful", () -> Iterators.Stateful(itr)),
661+
("stateful", () -> Iterators.Stateful(itr)),
663662
]
664663
@testset "$destlabel $itrlabel" for (destlabel, dest) in dest_examples,
665664
(itrlabel, makeitr) in itr_examples

0 commit comments

Comments
 (0)