diff --git a/singleflight/singleflight.go b/singleflight/singleflight.go index 97a1aa4..cdaad2d 100644 --- a/singleflight/singleflight.go +++ b/singleflight/singleflight.go @@ -6,7 +6,10 @@ // mechanism. package singleflight // import "golang.org/x/sync/singleflight" -import "sync" +import ( + "sync" + "sync/atomic" +) // call is an in-flight or completed singleflight.Do call type call struct { @@ -24,8 +27,8 @@ type call struct { // These fields are read and written with the singleflight // mutex held before the WaitGroup is done, and are read but // not written after the WaitGroup is done. - dups int - chans []chan<- Result + refCount int64 + chans []chan<- Result } // Group represents a class of work and forms a namespace in @@ -40,32 +43,35 @@ type Group struct { type Result struct { Val interface{} Err error - Shared bool + Shared RefShared +} + +// RefShared struct encapsulates both "shared boolean" as well as actual reference counter +// callers can call RefShared.Decrement to determine when last caller is done using result, so cleanup if needed can be performed +type RefShared struct { + shared bool + refCount *int64 +} + +// Decrement will atomically decrement refcounter and will return new value +func (rs *RefShared) Decrement() int64 { + return atomic.AddInt64(rs.refCount, -1) +} + +// returns boolean indicator of whether original "ref counter" had more than 1 reference +// it will return same value regardless of whether Decrement() method was called +func (rs *RefShared) Shared() bool { + return rs.shared } // Do executes and returns the results of the given function, making // sure that only one execution is in-flight for a given key at a // time. If a duplicate comes in, the duplicate caller waits for the // original to complete and receives the same results. -// The return value shared indicates whether v was given to multiple callers. -func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, err error, shared bool) { - g.mu.Lock() - if g.m == nil { - g.m = make(map[string]*call) - } - if c, ok := g.m[key]; ok { - c.dups++ - g.mu.Unlock() - c.wg.Wait() - return c.val, c.err, true - } - c := new(call) - c.wg.Add(1) - g.m[key] = c - g.mu.Unlock() - - g.doCall(c, key, fn) - return c.val, c.err, c.dups > 0 +// The return value shared indicates whether v was given to multiple callers (and a reference counter for callers too). +func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, err error, shared RefShared) { + r := <-g.DoChan(key, fn) + return r.Val, r.Err, r.Shared } // DoChan is like Do but returns a channel that will receive the @@ -77,12 +83,12 @@ func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result g.m = make(map[string]*call) } if c, ok := g.m[key]; ok { - c.dups++ + c.refCount++ c.chans = append(c.chans, ch) g.mu.Unlock() return ch } - c := &call{chans: []chan<- Result{ch}} + c := &call{refCount: 1, chans: []chan<- Result{ch}} c.wg.Add(1) g.m[key] = c g.mu.Unlock() @@ -101,8 +107,9 @@ func (g *Group) doCall(c *call, key string, fn func() (interface{}, error)) { if !c.forgotten { delete(g.m, key) } + shared := RefShared{shared: c.refCount > 1, refCount: &c.refCount} for _, ch := range c.chans { - ch <- Result{c.val, c.err, c.dups > 0} + ch <- Result{c.val, c.err, shared} } g.mu.Unlock() } diff --git a/singleflight/singleflight_test.go b/singleflight/singleflight_test.go index ad04037..a4d6ea3 100644 --- a/singleflight/singleflight_test.go +++ b/singleflight/singleflight_test.go @@ -15,7 +15,7 @@ import ( func TestDo(t *testing.T) { var g Group - v, err, _ := g.Do("key", func() (interface{}, error) { + v, err, shared := g.Do("key", func() (interface{}, error) { return "bar", nil }) if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want { @@ -24,6 +24,12 @@ func TestDo(t *testing.T) { if err != nil { t.Errorf("Do error = %v", err) } + if shared.Decrement() != 0 { + t.Errorf("ref counter is expected to be 0") + } + if shared.Shared() { + t.Errorf("Do returned shared") + } } func TestDoErr(t *testing.T) {