diff --git a/dataloader.go b/dataloader.go index 45e229f..65de9ed 100644 --- a/dataloader.go +++ b/dataloader.go @@ -9,6 +9,7 @@ import ( "log" "runtime" "sync" + "sync/atomic" "time" ) @@ -29,8 +30,12 @@ type Interface[K comparable, V any] interface { Flush() } +var ErrNoResultProvided = errors.New("no result provided") + // BatchFunc is a function, which when given a slice of keys (string), returns a slice of `results`. // It's important that the length of the input keys matches the length of the output results. +// Should the batch function return nil for a result, it will be treated as return an error +// of `ErrNoResultProvided` for that key. // // The keys passed to this function are guaranteed to be unique type BatchFunc[K comparable, V any] func(context.Context, []K) []*Result[V] @@ -131,8 +136,9 @@ type ThunkMany[V any] func() ([]V, []error) // type used to on input channel type batchRequest[K comparable, V any] struct { - key K - channel chan *Result[V] + key K + result atomic.Pointer[Result[V]] + done chan struct{} } // Option allows for configuration of Loader fields. @@ -221,11 +227,9 @@ func NewBatchedLoader[K comparable, V any](batchFn BatchFunc[K, V], opts ...Opti // the registered BatchFunc. func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] { ctx, finish := l.tracer.TraceLoad(originalContext, key) - - c := make(chan *Result[V], 1) - var result struct { - mu sync.RWMutex - value *Result[V] + req := &batchRequest[K, V]{ + key: key, + done: make(chan struct{}), } // We need to lock both the batchLock and cacheLock because the batcher can @@ -254,34 +258,19 @@ func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] { defer l.cacheLock.Unlock() thunk := func() (V, error) { - result.mu.RLock() - resultNotSet := result.value == nil - result.mu.RUnlock() - - if resultNotSet { - result.mu.Lock() - if v, ok := <-c; ok { - result.value = v - } - result.mu.Unlock() - } - result.mu.RLock() - defer result.mu.RUnlock() + <-req.done + result := req.result.Load() var ev *PanicErrorWrapper var es *SkipCacheError - if result.value.Error != nil && (errors.As(result.value.Error, &ev) || errors.As(result.value.Error, &es)){ + if result.Error != nil && (errors.As(result.Error, &ev) || errors.As(result.Error, &es)) { l.Clear(ctx, key) } - return result.value.Data, result.value.Error + return result.Data, result.Error } defer finish(thunk) l.cache.Set(ctx, key, thunk) - // this is sent to batch fn. It contains the key and the channel to return - // the result on - req := &batchRequest[K, V]{key, c} - // start the batch window if it hasn't already started. if l.curBatcher == nil { l.curBatcher = l.newBatcher(l.silent, l.tracer) @@ -338,8 +327,9 @@ func (l *Loader[K, V]) LoadMany(originalContext context.Context, keys []K) Thunk length = len(keys) data = make([]V, length) errors = make([]error, length) - c = make(chan *ResultMany[V], 1) + result atomic.Pointer[ResultMany[V]] wg sync.WaitGroup + done = make(chan struct{}) ) resolve := func(ctx context.Context, i int) { @@ -356,6 +346,7 @@ func (l *Loader[K, V]) LoadMany(originalContext context.Context, keys []K) Thunk } go func() { + defer close(done) wg.Wait() // errs is nil unless there exists a non-nil error. @@ -368,30 +359,13 @@ func (l *Loader[K, V]) LoadMany(originalContext context.Context, keys []K) Thunk } } - c <- &ResultMany[V]{Data: data, Error: errs} - close(c) + result.Store(&ResultMany[V]{Data: data, Error: errs}) }() - var result struct { - mu sync.RWMutex - value *ResultMany[V] - } - thunkMany := func() ([]V, []error) { - result.mu.RLock() - resultNotSet := result.value == nil - result.mu.RUnlock() - - if resultNotSet { - result.mu.Lock() - if v, ok := <-c; ok { - result.value = v - } - result.mu.Unlock() - } - result.mu.RLock() - defer result.mu.RUnlock() - return result.value.Data, result.value.Error + <-done + r := result.Load() + return r.Data, r.Error } defer finish(thunkMany) @@ -498,8 +472,8 @@ func (b *batcher[K, V]) batch(originalContext context.Context) { if panicErr != nil { for _, req := range reqs { - req.channel <- &Result[V]{Error: &PanicErrorWrapper{panicError: fmt.Errorf("Panic received in batch function: %v", panicErr)}} - close(req.channel) + req.result.Store(&Result[V]{Error: &PanicErrorWrapper{panicError: fmt.Errorf("Panic received in batch function: %v", panicErr)}}) + close(req.done) } return } @@ -517,16 +491,24 @@ func (b *batcher[K, V]) batch(originalContext context.Context) { `, keys, items)} for _, req := range reqs { - req.channel <- err - close(req.channel) + req.result.Store(err) + close(req.done) } return } + var notSetResult *Result[V] // don't allocate unless we need it for i, req := range reqs { - req.channel <- items[i] - close(req.channel) + if items[i] == nil { + if notSetResult == nil { + notSetResult = &Result[V]{Error: ErrNoResultProvided} + } + req.result.Store(notSetResult) + } else { + req.result.Store(items[i]) + } + close(req.done) } } diff --git a/go.mod b/go.mod index d5fe9d0..863539d 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/graph-gophers/dataloader/v7 -go 1.18 +go 1.19 require ( github.com/hashicorp/golang-lru v0.5.4