From efe1a002b5d462fe332c98b9b7b7782c5cb5a9ff Mon Sep 17 00:00:00 2001
From: Harrison Metzger
Date: Sun, 12 Oct 2025 05:40:03 -0500
Subject: [PATCH 1/2] return known error if batch function omits providing a
 result

---
 dataloader.go | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/dataloader.go b/dataloader.go
index 45e229f..a20e3db 100644
--- a/dataloader.go
+++ b/dataloader.go
@@ -29,8 +29,12 @@ type Interface[K comparable, V any] interface {
 	Flush()
 }
 
+var ErrNoResultProvided = errors.New("no result provided")
+
 // BatchFunc is a function, which when given a slice of keys (string), returns a slice of `results`.
 // It's important that the length of the input keys matches the length of the output results.
+// Should the batch function return nil for a result, it will be treated as returning an error
+// of `ErrNoResultProvided` for that key.
 //
 // The keys passed to this function are guaranteed to be unique
 type BatchFunc[K comparable, V any] func(context.Context, []K) []*Result[V]
@@ -269,7 +273,7 @@ func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] {
 		defer result.mu.RUnlock()
 		var ev *PanicErrorWrapper
 		var es *SkipCacheError
-		if result.value.Error != nil && (errors.As(result.value.Error, &ev) || errors.As(result.value.Error, &es)){
+		if result.value.Error != nil && (errors.As(result.value.Error, &ev) || errors.As(result.value.Error, &es)) {
 			l.Clear(ctx, key)
 		}
 		return result.value.Data, result.value.Error
@@ -524,8 +528,16 @@ func (b *batcher[K, V]) batch(originalContext context.Context) {
 		return
 	}
 
+	var notSetResult *Result[V] // don't allocate unless we need it
 	for i, req := range reqs {
-		req.channel <- items[i]
+		if items[i] == nil {
+			if notSetResult == nil {
+				notSetResult = &Result[V]{Error: ErrNoResultProvided}
+			}
+			req.channel <- notSetResult
+		} else {
+			req.channel <- items[i]
+		}
 		close(req.channel)
 	}
 }

From a59f279cf8c22d46b6dd606119ee61660399f51d Mon Sep 17 00:00:00 2001
From: Harrison Metzger
Date: Sun, 28 Sep 2025 13:39:07 +0200
Subject: [PATCH 2/2] Remove thunk mutex and use a signal channel and atomic
 pointer

---
 dataloader.go | 80 ++++++++++++++++-----------------------------------
 go.mod        |  2 +-
 2 files changed, 26 insertions(+), 56 deletions(-)

diff --git a/dataloader.go b/dataloader.go
index a20e3db..65de9ed 100644
--- a/dataloader.go
+++ b/dataloader.go
@@ -9,6 +9,7 @@ import (
 	"log"
 	"runtime"
 	"sync"
+	"sync/atomic"
 	"time"
 )
 
@@ -135,8 +136,9 @@ type ThunkMany[V any] func() ([]V, []error)
 
 // type used to on input channel
 type batchRequest[K comparable, V any] struct {
-	key     K
-	channel chan *Result[V]
+	key    K
+	result atomic.Pointer[Result[V]]
+	done   chan struct{}
 }
 
 // Option allows for configuration of Loader fields.
@@ -225,11 +227,9 @@ func NewBatchedLoader[K comparable, V any](batchFn BatchFunc[K, V], opts ...Opti
 // the registered BatchFunc.
 func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] {
 	ctx, finish := l.tracer.TraceLoad(originalContext, key)
-
-	c := make(chan *Result[V], 1)
-	var result struct {
-		mu    sync.RWMutex
-		value *Result[V]
+	req := &batchRequest[K, V]{
+		key:  key,
+		done: make(chan struct{}),
 	}
 
 	// We need to lock both the batchLock and cacheLock because the batcher can
@@ -258,34 +258,19 @@ func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] {
 	defer l.cacheLock.Unlock()
 
 	thunk := func() (V, error) {
-		result.mu.RLock()
-		resultNotSet := result.value == nil
-		result.mu.RUnlock()
-
-		if resultNotSet {
-			result.mu.Lock()
-			if v, ok := <-c; ok {
-				result.value = v
-			}
-			result.mu.Unlock()
-		}
-		result.mu.RLock()
-		defer result.mu.RUnlock()
+		<-req.done
+		result := req.result.Load()
 		var ev *PanicErrorWrapper
 		var es *SkipCacheError
-		if result.value.Error != nil && (errors.As(result.value.Error, &ev) || errors.As(result.value.Error, &es)) {
+		if result.Error != nil && (errors.As(result.Error, &ev) || errors.As(result.Error, &es)) {
 			l.Clear(ctx, key)
 		}
-		return result.value.Data, result.value.Error
+		return result.Data, result.Error
 	}
 	defer finish(thunk)
 
 	l.cache.Set(ctx, key, thunk)
 
-	// this is sent to batch fn. It contains the key and the channel to return
-	// the result on
-	req := &batchRequest[K, V]{key, c}
-
 	// start the batch window if it hasn't already started.
 	if l.curBatcher == nil {
 		l.curBatcher = l.newBatcher(l.silent, l.tracer)
@@ -342,8 +327,9 @@ func (l *Loader[K, V]) LoadMany(originalContext context.Context, keys []K) Thunk
 		length = len(keys)
 		data   = make([]V, length)
 		errors = make([]error, length)
-		c      = make(chan *ResultMany[V], 1)
+		result atomic.Pointer[ResultMany[V]]
 		wg     sync.WaitGroup
+		done   = make(chan struct{})
 	)
 
 	resolve := func(ctx context.Context, i int) {
@@ -360,6 +346,7 @@ func (l *Loader[K, V]) LoadMany(originalContext context.Context, keys []K) Thunk
 	}
 
 	go func() {
+		defer close(done)
 		wg.Wait()
 
 		// errs is nil unless there exists a non-nil error.
@@ -372,30 +359,13 @@ func (l *Loader[K, V]) LoadMany(originalContext context.Context, keys []K) Thunk
 			}
 		}
 
-		c <- &ResultMany[V]{Data: data, Error: errs}
-		close(c)
+		result.Store(&ResultMany[V]{Data: data, Error: errs})
 	}()
 
-	var result struct {
-		mu    sync.RWMutex
-		value *ResultMany[V]
-	}
-
 	thunkMany := func() ([]V, []error) {
-		result.mu.RLock()
-		resultNotSet := result.value == nil
-		result.mu.RUnlock()
-
-		if resultNotSet {
-			result.mu.Lock()
-			if v, ok := <-c; ok {
-				result.value = v
-			}
-			result.mu.Unlock()
-		}
-		result.mu.RLock()
-		defer result.mu.RUnlock()
-		return result.value.Data, result.value.Error
+		<-done
+		r := result.Load()
+		return r.Data, r.Error
 	}
 	defer finish(thunkMany)
 
@@ -502,8 +472,8 @@ func (b *batcher[K, V]) batch(originalContext context.Context) {
 
 	if panicErr != nil {
 		for _, req := range reqs {
-			req.channel <- &Result[V]{Error: &PanicErrorWrapper{panicError: fmt.Errorf("Panic received in batch function: %v", panicErr)}}
-			close(req.channel)
+			req.result.Store(&Result[V]{Error: &PanicErrorWrapper{panicError: fmt.Errorf("Panic received in batch function: %v", panicErr)}})
+			close(req.done)
 		}
 		return
 	}
@@ -521,8 +491,8 @@ func (b *batcher[K, V]) batch(originalContext context.Context) {
 		`, keys, items)}
 
 		for _, req := range reqs {
-			req.channel <- err
-			close(req.channel)
+			req.result.Store(err)
+			close(req.done)
 		}
 
 		return
@@ -534,11 +504,11 @@ func (b *batcher[K, V]) batch(originalContext context.Context) {
 			if notSetResult == nil {
 				notSetResult = &Result[V]{Error: ErrNoResultProvided}
 			}
-			req.channel <- notSetResult
+			req.result.Store(notSetResult)
 		} else {
-			req.channel <- items[i]
+			req.result.Store(items[i])
 		}
-		close(req.channel)
+		close(req.done)
 	}
 }
 
diff --git a/go.mod b/go.mod
index d5fe9d0..863539d 100644
--- a/go.mod
+++ b/go.mod
@@ -1,6 +1,6 @@
 module github.com/graph-gophers/dataloader/v7
 
-go 1.18
+go 1.19
 
 require (
 	github.com/hashicorp/golang-lru v0.5.4
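
For illustration, a minimal sketch of how a caller might observe the new sentinel from patch 1: the batch function leaves the slot for an unknown key as nil, and the thunk for that key then reports ErrNoResultProvided. The userByID map and the integer keys are hypothetical, not part of the patch.

package main

import (
	"context"
	"errors"
	"fmt"

	"github.com/graph-gophers/dataloader/v7"
)

func main() {
	// Hypothetical backing data; key 3 is deliberately absent.
	userByID := map[int]string{1: "alice", 2: "bob"}

	batchFn := func(_ context.Context, keys []int) []*dataloader.Result[string] {
		// One slot per key, as the contract requires; slots for unknown keys stay nil.
		results := make([]*dataloader.Result[string], len(keys))
		for i, k := range keys {
			if name, ok := userByID[k]; ok {
				results[i] = &dataloader.Result[string]{Data: name}
			}
		}
		return results
	}

	loader := dataloader.NewBatchedLoader(batchFn)

	thunk := loader.Load(context.Background(), 3)
	if _, err := thunk(); errors.Is(err, dataloader.ErrNoResultProvided) {
		fmt.Println("no result was provided for key 3")
	}
}

Before the change, a nil entry reached the thunk as a nil *Result and panicked on dereference; with patch 1 it becomes an ordinary error the caller can match.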
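
Patch 2 swaps the per-request result channel and the mutex-guarded result struct for a write-once atomic.Pointer that is stored before a done channel is closed; the go.mod bump to 1.19 lines up with sync/atomic gaining the generic Pointer type in that release. A standalone sketch of the pattern under those assumptions, with illustrative names rather than the library's internals:

package main

import (
	"fmt"
	"sync/atomic"
)

// request mirrors the shape used above: the producer stores the result exactly
// once and then closes done; consumers block on done, then Load() the pointer.
type request[V any] struct {
	result atomic.Pointer[V]
	done   chan struct{}
}

func main() {
	req := &request[string]{done: make(chan struct{})}

	go func() {
		v := "hello"
		req.result.Store(&v) // the Store happens before the close below ...
		close(req.done)      // ... so any goroutine past <-req.done observes it
	}()

	<-req.done // wait for the completion signal
	fmt.Println(*req.result.Load())
}

The channel close provides the same happens-before guarantee the old buffered result channel did, while letting the thunk be invoked any number of times without taking a lock.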