Skip to content

Commit 0f0497f

Browse files
committed
Allow way to cancel Ctrie iterator
This allows a cancel channel to be passed in to the Ctrie Iterator. When the channel is closed, the iterator channel will close, freeing up the goroutine.
1 parent 1d6a105 commit 0f0497f

File tree

2 files changed

+38
-9
lines changed

2 files changed

+38
-9
lines changed

trie/ctrie/ctrie.go

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ package ctrie
2525

2626
import (
2727
"bytes"
28+
"errors"
2829
"hash"
2930
"hash/fnv"
3031
"sync/atomic"
@@ -354,12 +355,15 @@ func (c *Ctrie) Clear() {
354355
}
355356
}
356357

357-
// Iterator returns a channel which yields the Entries of the Ctrie.
358-
func (c *Ctrie) Iterator() <-chan *Entry {
358+
// Iterator returns a channel which yields the Entries of the Ctrie. If a
359+
// cancel channel is provided, closing it will terminate and close the iterator
360+
// channel. Note that if a cancel channel is not used and not every entry is
361+
// read from the iterator, a goroutine will leak.
362+
func (c *Ctrie) Iterator(cancel <-chan struct{}) <-chan *Entry {
359363
ch := make(chan *Entry)
360364
snapshot := c.ReadOnlySnapshot()
361365
go func() {
362-
traverse(snapshot.root, ch)
366+
traverse(snapshot.root, ch, cancel)
363367
close(ch)
364368
}()
365369
return ch
@@ -373,30 +377,43 @@ func (c *Ctrie) Size() uint {
373377
// computation is amortized across the update operations that occurred
374378
// since the last snapshot.
375379
size := uint(0)
376-
for _ = range c.Iterator() {
380+
for _ = range c.Iterator(nil) {
377381
size++
378382
}
379383
return size
380384
}
381385

382-
func traverse(i *iNode, ch chan<- *Entry) {
386+
var errCanceled = errors.New("canceled")
387+
388+
func traverse(i *iNode, ch chan<- *Entry, cancel <-chan struct{}) error {
383389
switch {
384390
case i.main.cNode != nil:
385391
for _, br := range i.main.cNode.array {
386392
switch b := br.(type) {
387393
case *iNode:
388-
traverse(b, ch)
394+
if err := traverse(b, ch, cancel); err != nil {
395+
return err
396+
}
389397
case *sNode:
390-
ch <- b.Entry
398+
select {
399+
case ch <- b.Entry:
400+
case <-cancel:
401+
return errCanceled
402+
}
391403
}
392404
}
393405
case i.main.lNode != nil:
394406
for _, e := range i.main.lNode.Map(func(sn interface{}) interface{} {
395407
return sn.(*sNode).Entry
396408
}) {
397-
ch <- e.(*Entry)
409+
select {
410+
case ch <- e.(*Entry):
411+
case <-cancel:
412+
return errCanceled
413+
}
398414
}
399415
}
416+
return nil
400417
}
401418

402419
func (c *Ctrie) assertReadWrite() {

trie/ctrie/ctrie_test.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,14 +273,26 @@ func TestIterator(t *testing.T) {
273273
}
274274

275275
count := 0
276-
for entry := range ctrie.Iterator() {
276+
for entry := range ctrie.Iterator(nil) {
277277
exp, ok := expected[string(entry.Key)]
278278
if assert.True(ok) {
279279
assert.Equal(exp, entry.Value)
280280
}
281281
count++
282282
}
283283
assert.Equal(len(expected), count)
284+
285+
// Closing cancel channel should close iterator channel.
286+
cancel := make(chan struct{})
287+
iter := ctrie.Iterator(cancel)
288+
entry := <-iter
289+
exp, ok := expected[string(entry.Key)]
290+
if assert.True(ok) {
291+
assert.Equal(exp, entry.Value)
292+
}
293+
close(cancel)
294+
_, ok = <-iter
295+
assert.False(ok)
284296
}
285297

286298
func TestSize(t *testing.T) {

0 commit comments

Comments
 (0)