From 34b5ab5629067c58ec2712885907560f90239e66 Mon Sep 17 00:00:00 2001 From: Victor Conner Date: Wed, 12 Mar 2025 10:56:44 +0100 Subject: [PATCH] Reworked and improved the error handling --- distribution.go | 71 +++++++--- distribution_test.go | 314 ++++++++++++++++++++++++++++++++++++++++++- errors.go | 40 ++++-- fetch.go | 8 +- fetch_test.go | 278 ++++++++++++++++++++++++++++++++++++++ inflight.go | 97 +++++++------ passthrough.go | 2 +- refresh.go | 2 +- safe.go | 23 ++-- 9 files changed, 732 insertions(+), 103 deletions(-) diff --git a/distribution.go b/distribution.go index 73546b5..d4a002d 100644 --- a/distribution.go +++ b/distribution.go @@ -96,9 +96,10 @@ func distributedFetch[V, T any](c *Client[T], key string, fetchFn FetchFn[V]) Fe return func(ctx context.Context) (V, error) { stale, hasStale := *new(V), false - bytes, ok := c.distributedStorage.Get(ctx, key) - if ok { - c.reportDistributedCacheHit(true) + bytes, existsInDistributedStorage := c.distributedStorage.Get(ctx, key) + c.reportDistributedCacheHit(existsInDistributedStorage) + + if existsInDistributedStorage { record, unmarshalErr := unmarshalRecord[V](bytes, key, c.log) if unmarshalErr != nil { return record.Value, unmarshalErr } @@ -116,8 +117,16 @@ func distributedFetch[V, T any](c *Client[T], key string, fetchFn FetchFn[V]) Fe stale, hasStale = record.Value, true } - if !ok { - c.reportDistributedCacheHit(false) + // Before we call the fetchFn, we'll do a non-blocking read to see if the + // context has been cancelled. If it has, we'll return a stale value if we + // have one available. + select { + case <-ctx.Done(): + if hasStale { + return stale, errors.Join(errOnlyDistributedRecords, ctx.Err()) + } + return *(new(V)), ctx.Err() + default: } // If it's not fresh enough, we'll retrieve it from the source. 
@@ -146,7 +155,7 @@ func distributedFetch[V, T any](c *Client[T], key string, fetchFn FetchFn[V]) Fe if hasStale { c.reportDistributedStaleFallback() - return stale, nil + return stale, errors.Join(errOnlyDistributedRecords, fetchErr) } return response, fetchErr @@ -177,14 +186,14 @@ func distributedBatchFetch[V, T any](c *Client[T], keyFn KeyFn, fetchFn BatchFet idsToRefresh := make([]string, 0, len(ids)) for _, id := range ids { key := keyFn(id) - bytes, ok := distributedRecords[key] - if !ok { - c.reportDistributedCacheHit(false) + bytes, existsInDistributedStorage := distributedRecords[key] + c.reportDistributedCacheHit(existsInDistributedStorage) + + if !existsInDistributedStorage { idsToRefresh = append(idsToRefresh, id) continue } - c.reportDistributedCacheHit(true) record, unmarshalErr := unmarshalRecord[V](bytes, key, c.log) if unmarshalErr != nil { idsToRefresh = append(idsToRefresh, id) @@ -194,11 +203,12 @@ func distributedBatchFetch[V, T any](c *Client[T], keyFn KeyFn, fetchFn BatchFet // If early refreshes isn't enabled it means all records are fresh, otherwise we'll check the CreatedAt time. if !c.distributedEarlyRefreshes || c.clock.Since(record.CreatedAt) < c.distributedRefreshAfterDuration { // We never want to return missing records. - if !record.IsMissingRecord { - fresh[id] = record.Value - } else { + if record.IsMissingRecord { c.reportDistributedMissingRecord() + continue } + + fresh[id] = record.Value continue } @@ -206,17 +216,33 @@ func distributedBatchFetch[V, T any](c *Client[T], keyFn KeyFn, fetchFn BatchFet c.reportDistributedRefresh() // We never want to return missing records. - if !record.IsMissingRecord { - stale[id] = record.Value - } else { + if record.IsMissingRecord { c.reportDistributedMissingRecord() + continue } + stale[id] = record.Value } if len(idsToRefresh) == 0 { return fresh, nil } + // Before we call the fetchFn, we'll do an unblocking read to see if the + // context has been cancelled. 
If it has, we'll return any potential + // records we got from the distributed storage. + select { + case <-ctx.Done(): + maps.Copy(stale, fresh) + + // If we didn't get any records from the distributed storage, + // we'll return the error from the fetch function as-is. + if len(stale) < 1 { + return stale, ctx.Err() + } + return stale, errors.Join(errOnlyDistributedRecords, ctx.Err()) + default: + } + dataSourceResponses, err := fetchFn(ctx, idsToRefresh) // In case of an error, we'll proceed with the ones we got from the distributed storage. // NOTE: It's important that we return a specific error here, otherwise we'll potentially @@ -227,7 +253,14 @@ func distributedBatchFetch[V, T any](c *Client[T], keyFn KeyFn, fetchFn BatchFet c.reportDistributedStaleFallback() } maps.Copy(stale, fresh) - return stale, errOnlyDistributedRecords + + // If we didn't get any records from the distributed storage, + // we'll return the error from the fetch function as-is. + if len(stale) < 1 { + return dataSourceResponses, err + } + + return stale, errors.Join(errOnlyDistributedRecords, err) } // Next, we'll want to check if we should change any of the records to be missing or perform deletions. 
@@ -235,9 +268,7 @@ func distributedBatchFetch[V, T any](c *Client[T], keyFn KeyFn, fetchFn BatchFet keysToDelete := make([]string, 0, max(len(idsToRefresh)-len(dataSourceResponses), 0)) for _, id := range idsToRefresh { key := keyFn(id) - response, ok := dataSourceResponses[id] - - if ok { + if response, ok := dataSourceResponses[id]; ok { if recordBytes, marshalErr := marshalRecord[V](response, c); marshalErr == nil { recordsToWrite[key] = recordBytes } diff --git a/distribution_test.go b/distribution_test.go index c0c69b5..b8f0768 100644 --- a/distribution_test.go +++ b/distribution_test.go @@ -17,11 +17,17 @@ type mockStorage struct { setCount int deleteCount int records map[string][]byte + cancelFunc *context.CancelFunc } func (m *mockStorage) Get(_ context.Context, key string) ([]byte, bool) { m.Lock() - defer m.Unlock() + defer func() { + if m.cancelFunc != nil { + (*m.cancelFunc)() + } + m.Unlock() + }() m.getCount++ bytes, ok := m.records[key] @@ -48,7 +54,12 @@ func (m *mockStorage) Delete(_ context.Context, key string) { func (m *mockStorage) GetBatch(_ context.Context, _ []string) map[string][]byte { m.Lock() - defer m.Unlock() + defer func() { + if m.cancelFunc != nil { + (*m.cancelFunc)() + } + m.Unlock() + }() m.getCount++ return m.records } @@ -220,10 +231,14 @@ func TestDistributedStaleStorage(t *testing.T) { clock.Add(time.Minute * 2) // Now we can request the same key again, but we'll make the fetchFn error. + // This is going to result in an ErrOnlyCachedRecords error. 
fetchObserver.Err(errors.New("error")) res, err := sturdyc.GetOrFetch(ctx, c, key, fetchObserver.Fetch) - if err != nil { - t.Fatalf("expected no error, got %v", err) + if !errors.Is(err, sturdyc.ErrOnlyCachedRecords) { + t.Fatalf("expected ErrOnlyCachedRecords, got %v", err) + } + if !errors.Is(err, fetchObserver.err) { + t.Fatal("expected the original error to have been joined with ErrOnlyCachedRecords") } if res != "valuekey1" { t.Errorf("expected valuekey1, got %s", res) @@ -235,6 +250,17 @@ func TestDistributedStaleStorage(t *testing.T) { distributedStorage.assertGetCount(t, 2) distributedStorage.assertSetCount(t, 1) distributedStorage.assertDeleteCount(t, 0) + + // Getting the key now should not result in another error since we took the + // stale value from the distributed storage, and wrote it to the in-memory + // cache. + res, err = sturdyc.GetOrFetch(ctx, c, key, fetchObserver.Fetch) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if res != "valuekey1" { + t.Errorf("expected valuekey1, got %s", res) + } } func TestDistributedStaleStorageDeletes(t *testing.T) { @@ -803,7 +829,7 @@ func TestPartialResponseForRefreshesDoesNotResultInMissingRecords(t *testing.T) ids = append(ids, strconv.Itoa(i)) } - fetchObserver := NewFetchObserver(11) + fetchObserver := NewFetchObserver(1) fetchObserver.BatchResponse(ids) res, err := sturdyc.GetOrFetchBatch(ctx, c, ids, keyFn, fetchObserver.FetchBatch) if err != nil { @@ -871,3 +897,281 @@ func TestPartialResponseForRefreshesDoesNotResultInMissingRecords(t *testing.T) t.Fatalf("expected cache size to be 100, got %d", c.Size()) } } + +func TestReturnsDataSourceErrorsWhenThereAreNoDistributedRecords(t *testing.T) { + t.Parallel() + + ctx := context.Background() + ttl := time.Minute + distributedStorage := &mockStorage{} + c := sturdyc.New[string](1000, 10, ttl, 30, + sturdyc.WithNoContinuousEvictions(), + sturdyc.WithDistributedStorage(distributedStorage), + ) + + keyFn := c.BatchKeyFn("item") + ids := 
make([]string, 0, 100) + for i := 1; i <= 100; i++ { + ids = append(ids, strconv.Itoa(i)) + } + + fetchObserver := NewFetchObserver(1) + fetchObserver.err = context.Canceled + res, err := sturdyc.GetOrFetchBatch(ctx, c, ids, keyFn, fetchObserver.FetchBatch) + <-fetchObserver.FetchCompleted + + fetchObserver.AssertFetchCount(t, 1) + fetchObserver.AssertRequestedRecords(t, ids) + distributedStorage.assertGetCount(t, 1) + + if len(res) != 0 { + t.Fatalf("expected 0 records, got %d", len(res)) + } + + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled, got %v", err) + } + + if errors.Is(err, sturdyc.ErrOnlyCachedRecords) { + t.Error("expected no ErrOnlyCachedRecords since the response was empty") + } +} + +func TestReturnsPartialResultsAndJoinedErrorsOnDataSourceFailure(t *testing.T) { + t.Parallel() + + ctx := context.Background() + ttl := time.Minute + distributedStorage := &mockStorage{} + c := sturdyc.New[string](1000, 10, ttl, 30, + sturdyc.WithNoContinuousEvictions(), + sturdyc.WithDistributedStorage(distributedStorage), + ) + fetchObserver := NewFetchObserver(1) + + keyFn := c.BatchKeyFn("item") + firstBatchOfIDs := []string{"1", "2", "3"} + fetchObserver.BatchResponse(firstBatchOfIDs) + _, err := sturdyc.GetOrFetchBatch(ctx, c, firstBatchOfIDs, keyFn, fetchObserver.FetchBatch) + <-fetchObserver.FetchCompleted + + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + fetchObserver.AssertRequestedRecords(t, firstBatchOfIDs) + fetchObserver.AssertFetchCount(t, 1) + fetchObserver.Clear() + + // The keys are written asynchronously to the distributed storage. + time.Sleep(100 * time.Millisecond) + distributedStorage.assertRecords(t, firstBatchOfIDs, keyFn) + distributedStorage.assertGetCount(t, 1) + distributedStorage.assertSetCount(t, 1) + + // Next, we'll delete the records from the in-memory cache to ensure that we + // get them from the distributed storage on the next GetOrFetchBatchCall. 
+ for _, id := range firstBatchOfIDs { + c.Delete(keyFn(id)) + } + if c.Size() != 0 { + t.Fatalf("expected cache size to be 0, got %d", c.Size()) + } + + // Now we can request a second batch of IDs. This time though, we'll make the + // underlying data source error out. This means that we're only going to get + // ids 1-3 from the distributed storage. + secondBatchOfIDs := []string{"1", "2", "3", "4", "5", "6"} + fetchObserver.err = context.Canceled + res, err := sturdyc.GetOrFetchBatch(ctx, c, secondBatchOfIDs, keyFn, fetchObserver.FetchBatch) + <-fetchObserver.FetchCompleted + + // Assert that we got the 3 records which were stored in the distributed storage after the first request. + if len(res) != 3 { + t.Fatalf("expected 3 records, got %d", len(res)) + } + + // Assert that we're returning an ErrOnlyCachedRecords error to + // inform that the call to the underlying data source failed. + if !errors.Is(err, sturdyc.ErrOnlyCachedRecords) { + t.Fatalf("expected err to be ErrOnlyCachedRecords, got %v", err) + } + + // Assert that the error from the request to fetch the additional IDs is returned as well. + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected err to be context.Canceled, got %v", err) + } + + // Let's also make sure that we don't lose any data source errors when + // there is multiple joined errors returned from the fetch function. + fetchObserver.Clear() + fetchObserver.err = errors.Join(context.Canceled, context.DeadlineExceeded) + + // Clear the in-memory cache. 
+ for _, id := range firstBatchOfIDs { + c.Delete(keyFn(id)) + } + if c.Size() != 0 { + t.Fatalf("expected cache size to be 0, got %d", c.Size()) + } + + res, err = sturdyc.GetOrFetchBatch(ctx, c, secondBatchOfIDs, keyFn, fetchObserver.FetchBatch) + <-fetchObserver.FetchCompleted + + if len(res) != 3 { + t.Fatalf("expected 3 records, got %d", len(res)) + } + + if !errors.Is(err, sturdyc.ErrOnlyCachedRecords) { + t.Fatalf("expected err to be ErrOnlyCachedRecords, got %v", err) + } + + // Lastly, let's assert that we got both the context.Canceled and context.DeadlineExceeded errors. + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected err to be context.Canceled, got %v", err) + } + + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected err to be context.DeadlineExceeded, got %v", err) + } +} + +func TestGetOrFetchDoesNotProceedToCallTheFetchFunctionIfTheContextIsCancelled(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + ttl := time.Minute + distributedStorage := &mockStorage{} + clock := sturdyc.NewTestClock(time.Now()) + refreshAfter := time.Second * 10 + c := sturdyc.New[string](1000, 10, ttl, 30, + sturdyc.WithNoContinuousEvictions(), + sturdyc.WithDistributedStorageEarlyRefreshes(distributedStorage, refreshAfter), + sturdyc.WithClock(clock), + ) + fetchObserver := NewFetchObserver(1) + + key := "1" + fetchObserver.Response(key) + val, err := sturdyc.GetOrFetch(ctx, c, key, fetchObserver.Fetch) + if val != "value1" { + t.Fatalf("expected value1, got %s", val) + } + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + <-fetchObserver.FetchCompleted + fetchObserver.AssertFetchCount(t, 1) + fetchObserver.Clear() + + // The key is written asynchronously to the distributed storage. 
+ time.Sleep(100 * time.Millisecond) + distributedStorage.assertRecord(t, key) + distributedStorage.assertGetCount(t, 1) + distributedStorage.assertSetCount(t, 1) + + // Next, we'll delete the key from the in-memory cache. + c.Delete(key) + if c.Size() != 0 { + t.Fatalf("expected cache size to be 0, got %d", c.Size()) + } + + // Now we'll request the key again. This time though, we'll move the clock + // forward to indicate that the request has to be refreshed, but before that + // is done we'll cancel the context. + clock.Add(refreshAfter + time.Second) + distributedStorage.cancelFunc = &cancel + val, err = sturdyc.GetOrFetch(ctx, c, key, fetchObserver.Fetch) + + // Assert that we got the key from the distributed storage. + distributedStorage.assertGetCount(t, 2) + + // Assert that we didn't call the fetch function again. + if fetchObserver.fetchCount != 1 { + t.Fatalf("expected fetch count to be 1, got %d", fetchObserver.fetchCount) + } + + if val != "value1" { + t.Fatalf("expected value1, got %s", val) + } + + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected err to be context.Canceled, got %v", err) + } + + if !errors.Is(err, sturdyc.ErrOnlyCachedRecords) { + t.Fatalf("expected err to be context.Canceled, got %v", err) + } + + // Assert that we wrote the value from the distributed cache back to the in-memory cache. 
+ if c.Size() != 1 { + t.Fatalf("expected cache size to be 1, got %d", c.Size()) + } +} + +func TestGetOrFetchBatchDoesNotProceedToCallTheFetchFunctionIfTheContextIsCancelled(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + ttl := time.Minute + distributedStorage := &mockStorage{} + c := sturdyc.New[string](1000, 10, ttl, 30, + sturdyc.WithNoContinuousEvictions(), + sturdyc.WithDistributedStorage(distributedStorage), + ) + fetchObserver := NewFetchObserver(1) + + keyFn := c.BatchKeyFn("item") + firstBatchOfIDs := []string{"1", "2", "3"} + fetchObserver.BatchResponse(firstBatchOfIDs) + _, err := sturdyc.GetOrFetchBatch(ctx, c, firstBatchOfIDs, keyFn, fetchObserver.FetchBatch) + <-fetchObserver.FetchCompleted + + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + fetchObserver.AssertRequestedRecords(t, firstBatchOfIDs) + fetchObserver.AssertFetchCount(t, 1) + fetchObserver.Clear() + + // The keys are written asynchronously to the distributed storage. + time.Sleep(100 * time.Millisecond) + distributedStorage.assertRecords(t, firstBatchOfIDs, keyFn) + distributedStorage.assertGetCount(t, 1) + distributedStorage.assertSetCount(t, 1) + + // Next, we'll delete the records from the in-memory cache to ensure that we + // get them from the distributed storage on the next GetOrFetchBatchCall. + for _, id := range firstBatchOfIDs { + c.Delete(keyFn(id)) + } + if c.Size() != 0 { + t.Fatalf("expected cache size to be 0, got %d", c.Size()) + } + + // Now we can request a second batch of IDs. This time though, we'll cancel + // the context after the distributed cache has been called. + secondBatchOfIDs := []string{"1", "2", "3", "4", "5", "6"} + distributedStorage.cancelFunc = &cancel + res, err := sturdyc.GetOrFetchBatch(ctx, c, secondBatchOfIDs, keyFn, fetchObserver.FetchBatch) + + // Assert that we got the 3 records which were stored in the distributed storage after the first request. 
+ if len(res) != 3 { + t.Fatalf("expected 3 records, got %d", len(res)) + } + + // Assert that we didn't call the fetch function again. + if fetchObserver.fetchCount != 1 { + t.Fatalf("expected fetch count to be 1, got %d", fetchObserver.fetchCount) + } + + // Assert that the error is both an ErrOnlyCachedRecords and a context.Canceled error. + if !errors.Is(err, sturdyc.ErrOnlyCachedRecords) { + t.Fatalf("expected err to be ErrOnlyCachedRecords, got %v", err) + } + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected err to be context.Canceled, got %v", err) + } +} diff --git a/errors.go b/errors.go index e0eebf9..5c7da74 100644 --- a/errors.go +++ b/errors.go @@ -20,16 +20,40 @@ var ( // ErrMissingRecord is returned by client.GetOrFetch and client.Passthrough when a record has been marked // as missing. The cache will still try to refresh the record in the background if it's being requested. ErrMissingRecord = errors.New("sturdyc: the record has been marked as missing in the cache") - // ErrOnlyCachedRecords is returned by client.GetOrFetchBatch and - // client.PassthroughBatch when some of the requested records are available - // in the cache, but the attempt to fetch the remaining records failed. It - // may also be returned when you're using the WithEarlyRefreshes - // functionality, and the call to synchronously refresh a record failed. The - // cache will then give you the latest data it has cached, and you as the - // consumer can then decide whether to proceed with the cached records or if - // the newest data is necessary. + // ErrOnlyCachedRecords can be returned when you're using the cache with + // early refreshes or distributed storage functionality. It indicates that + // the records *should* have been refreshed from the underlying data source, + // but the operation failed. It is up to you to decide whether you want to + // proceed with the records that were retrieved from the cache. 
Note: For + // batch operations, this might contain only part of the batch. For example, + // if you requested keys 1-10, and we had IDs 1-3 in the cache, but the + // request to fetch records 4-10 failed. ErrOnlyCachedRecords = errors.New("sturdyc: failed to fetch the records that were not in the cache") // ErrInvalidType is returned when you try to use one of the generic // package level functions but the type assertion fails. ErrInvalidType = errors.New("sturdyc: invalid response type") ) + +// onlyCachedRecords is used when we were able to successfully retrieve some +// records from distributed storage, but the request to get additional records +// from the underlying data source failed. In this case, we wrap any potential +// errors from the underlying data source with an ErrOnlyCachedRecords to allow +// the user to decide whether to proceed with the cached records or not. +func onlyCachedRecords(err error) error { + multiErr, isMultiErr := err.(interface{ Unwrap() []error }) + if !isMultiErr { + if errors.Is(errOnlyDistributedRecords, err) { + return ErrOnlyCachedRecords + } + return errors.Join(err, ErrOnlyCachedRecords) + } + + var errs []error + errs = append(errs, ErrOnlyCachedRecords) + for _, e := range multiErr.Unwrap() { + if !errors.Is(e, errOnlyDistributedRecords) { + errs = append(errs, e) + } + } + return errors.Join(errs...) +} diff --git a/fetch.go b/fetch.go index 0ffc292..b39ec66 100644 --- a/fetch.go +++ b/fetch.go @@ -151,7 +151,7 @@ func getFetchBatch[V, T any](ctx context.Context, c *Client[T], ids []string, ke cacheMissesAndSyncRefreshes = append(cacheMissesAndSyncRefreshes, cacheMisses...) cacheMissesAndSyncRefreshes = append(cacheMissesAndSyncRefreshes, idsToSynchronouslyRefresh...) 
- callBatchOpts := callBatchOpts[T, T]{ids: cacheMissesAndSyncRefreshes, keyFn: keyFn, fn: wrappedFetch} + callBatchOpts := callBatchOpts[T]{ids: cacheMissesAndSyncRefreshes, keyFn: keyFn, fn: wrappedFetch} response, err := callAndCacheBatch(ctx, c, callBatchOpts) // If we did a call to synchronously refresh some of the records, and it @@ -175,10 +175,14 @@ func getFetchBatch[V, T any](ctx context.Context, c *Client[T], ids []string, ke } if err != nil && !errors.Is(err, ErrOnlyCachedRecords) { + // At this point, the call for the IDs that we didn't have in the cache + // has failed. However, these IDs could have been picked from multiple + // in-flight requests. Hence, we'll check if we're able to add any of these + // IDs to the cached records before returning. if len(cachedRecords) > 0 { + maps.Copy(cachedRecords, response) return cachedRecords, errors.Join(ErrOnlyCachedRecords, err) } - return cachedRecords, err } maps.Copy(cachedRecords, response) diff --git a/fetch_test.go b/fetch_test.go index 18ca076..4460fcc 100644 --- a/fetch_test.go +++ b/fetch_test.go @@ -213,6 +213,44 @@ func TestGetOrFetchMissingRecord(t *testing.T) { fetchObserver.AssertFetchCount(t, 2) } +func TestGetOrFetchMissingRecordError(t *testing.T) { + t.Parallel() + + ctx := context.Background() + capacity := 5 + numShards := 2 + ttl := time.Minute + evictionPercentage := 10 + c := sturdyc.New[any](capacity, numShards, ttl, evictionPercentage, + sturdyc.WithNoContinuousEvictions(), + sturdyc.WithMissingRecordStorage(), + ) + + // We'll make the fetch observer return an ErrNotFound error + // to indicate that the record should be stored as missing. 
+ fetchObserver := NewFetchObserver(1) + fetchObserver.Err(sturdyc.ErrNotFound) + id := "1" + + _, err := sturdyc.GetOrFetch(ctx, c, id, fetchObserver.Fetch) + if !errors.Is(err, sturdyc.ErrMissingRecord) { + t.Fatalf("expected missing record error, got %v", err) + } + + <-fetchObserver.FetchCompleted + fetchObserver.AssertFetchCount(t, 1) + fetchObserver.Err(sturdyc.ErrNotFound) + + // The second time we call Get, we should still get the missing record + _, err = sturdyc.GetOrFetch(ctx, c, id, fetchObserver.Fetch) + if !errors.Is(err, sturdyc.ErrMissingRecord) { + t.Fatalf("expected missing record error, got %v", err) + } + + // And only no more fetch should have been made + fetchObserver.AssertFetchCount(t, 1) +} + func TestGetOrFetchBatch(t *testing.T) { t.Parallel() @@ -1568,3 +1606,243 @@ func TestGetFetchBatchMixOfSynchronousAndAsynchronousRefreshes(t *testing.T) { fetchObserver.AssertRequestedRecords(t, []string{"4"}) fetchObserver.AssertFetchCount(t, 4) } + +func TestGetOrFetchGenerics(t *testing.T) { + t.Parallel() + + capacity := 10 + numShards := 2 + ttl := time.Second * 2 + evictionPercentage := 10 + c := sturdyc.New[string](capacity, numShards, ttl, evictionPercentage, + sturdyc.WithNoContinuousEvictions(), + ) + + fetchFuncInt := func(_ context.Context) (int, error) { + return 1, nil + } + + _, err := sturdyc.GetOrFetch(context.Background(), c, "1", fetchFuncInt) + if !errors.Is(err, sturdyc.ErrInvalidType) { + t.Errorf("expected ErrInvalidType, got %v", err) + } + + if _, ok := c.Get("1"); ok { + t.Error("we should not have cached anything given that the value was of the wrong type") + } + + if c.Size() != 0 { + t.Errorf("expected cache size to be 0, got %d", c.Size()) + } + + fetchFuncString := func(_ context.Context) (string, error) { + return "value", nil + } + + val, err := sturdyc.GetOrFetch(context.Background(), c, "1", fetchFuncString) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if val != "value" { + t.Errorf("expected 
value to be value, got %v", val) + } + + cachedValue, ok := c.Get("1") + if !ok { + t.Error("expected value to be in the cache") + } + if cachedValue != "value" { + t.Errorf("expected value to be value, got %v", cachedValue) + } + if c.Size() != 1 { + t.Errorf("expected cache size to be 1, got %d", c.Size()) + } +} + +func TestGetOrFetchBatchGenerics(t *testing.T) { + t.Parallel() + + capacity := 10 + numShards := 2 + ttl := time.Second * 2 + evictionPercentage := 10 + c := sturdyc.New[string](capacity, numShards, ttl, evictionPercentage, + sturdyc.WithNoContinuousEvictions(), + ) + + fetchFuncInt := func(_ context.Context, ids []string) (map[string]int, error) { + response := make(map[string]int, len(ids)) + for i, id := range ids { + response[id] = i + } + return response, nil + } + + ids := []string{"1", "2", "3"} + _, err := sturdyc.GetOrFetchBatch(context.Background(), c, ids, c.BatchKeyFn("item"), fetchFuncInt) + if !errors.Is(err, sturdyc.ErrInvalidType) { + t.Errorf("expected ErrInvalidType, got %v", err) + } + if c.Size() != 0 { + t.Errorf("expected cache size to be 0, got %d", c.Size()) + } + + fetchFuncString := func(_ context.Context, ids []string) (map[string]string, error) { + response := make(map[string]string, len(ids)) + for _, id := range ids { + response[id] = "value" + id + } + return response, nil + } + + keyFunc := c.BatchKeyFn("item") + values, err := sturdyc.GetOrFetchBatch(context.Background(), c, ids, keyFunc, fetchFuncString) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if len(values) != 3 { + t.Errorf("expected 3 values, got %d", len(values)) + } + if values["1"] != "value1" { + t.Errorf("expected value to be value1, got %v", values["1"]) + } + if values["2"] != "value2" { + t.Errorf("expected value to be value2, got %v", values["2"]) + } + if values["3"] != "value3" { + t.Errorf("expected value to be value3, got %v", values["3"]) + } + + cacheKeys := make([]string, 0, len(ids)) + for _, id := range ids { + 
cacheKeys = append(cacheKeys, keyFunc(id)) + } + cachedValues := c.GetMany(cacheKeys) + if len(cachedValues) != 3 { + t.Errorf("expected 3 values, got %d", len(cachedValues)) + } + if cachedValues[cacheKeys[0]] != "value1" { + t.Errorf("expected value to be value1, got %v", cachedValues["1"]) + } + if cachedValues[cacheKeys[1]] != "value2" { + t.Errorf("expected value to be value2, got %v", cachedValues["2"]) + } + if cachedValues[cacheKeys[2]] != "value3" { + t.Errorf("expected value to be value3, got %v", cachedValues["3"]) + } +} + +func TestGetOrFetchGenericsClashingTypes(t *testing.T) { + t.Parallel() + + capacity := 10 + numShards := 2 + ttl := time.Second * 2 + evictionPercentage := 10 + c := sturdyc.New[any](capacity, numShards, ttl, evictionPercentage, + sturdyc.WithNoContinuousEvictions(), + ) + + resolve := make(chan struct{}) + fetchFuncInt := func(_ context.Context) (int, error) { + <-resolve + return 1, nil + } + fetchFuncString := func(_ context.Context) (string, error) { + return "value", nil + } + + go func() { + time.Sleep(time.Millisecond * 500) + resolve <- struct{}{} + }() + resOne, errOne := sturdyc.GetOrFetch(context.Background(), c, "1", fetchFuncInt) + _, errTwo := sturdyc.GetOrFetch(context.Background(), c, "1", fetchFuncString) + + if errOne != nil { + t.Errorf("expected no error, got %v", errOne) + } + if resOne != 1 { + t.Errorf("expected value to be 1, got %v", resOne) + } + + if !errors.Is(errTwo, sturdyc.ErrInvalidType) { + t.Errorf("expected ErrInvalidType, got %v", errTwo) + } +} + +func TestGetOrFetchBatchGenericsClashingTypes(t *testing.T) { + t.Parallel() + + capacity := 10 + numShards := 2 + ttl := time.Second * 2 + evictionPercentage := 10 + c := sturdyc.New[any](capacity, numShards, ttl, evictionPercentage, + sturdyc.WithNoContinuousEvictions(), + ) + + // We are going to test the behaviour for when the same IDs are passed to + // fetch functions which return different types. 
+ cacheKeyFunc := c.BatchKeyFn("item") + firstCallIDs := []string{"1", "2", "3"} + secondCallIDs := []string{"1", "2", "3", "4"} + + // First, we'll create a fetch function which returns integers. This fetch + // function is going to wait for a message on the resolve channel before + // returning. This is so that we're able to create an in-flight batch for the + // first set of IDs. + resolve := make(chan struct{}) + fetchFuncInt := func(_ context.Context, ids []string) (map[string]int, error) { + if len(ids) != 3 { + t.Fatalf("expected 3 IDs, got %d", len(ids)) + } + response := make(map[string]int, len(ids)) + for i, id := range ids { + response[id] = i + 1 + } + <-resolve + return response, nil + } + + // Next, we'll create a fetch function which returns strings. This fetch + // function should only be called with ID "4" because IDs 1-3 are picked from + // the first batch. + fetchFuncString := func(_ context.Context, ids []string) (map[string]string, error) { + if len(ids) != 1 { + t.Fatalf("expected 1 ID, got %d", len(ids)) + } + response := make(map[string]string, len(ids)) + for _, id := range ids { + response[id] = "value" + id + } + return response, nil + } + + go func() { + time.Sleep(time.Millisecond * 500) + resolve <- struct{}{} + }() + firstCallValues, firstCallErr := sturdyc.GetOrFetchBatch(context.Background(), c, firstCallIDs, cacheKeyFunc, fetchFuncInt) + _, secondCallErr := sturdyc.GetOrFetchBatch(context.Background(), c, secondCallIDs, cacheKeyFunc, fetchFuncString) + + if firstCallErr != nil { + t.Errorf("expected no error, got %v", firstCallErr) + } + if len(firstCallValues) != 3 { + t.Errorf("expected 3 values, got %d", len(firstCallValues)) + } + if firstCallValues["1"] != 1 { + t.Errorf("expected value to be 1, got %v", firstCallValues["1"]) + } + if firstCallValues["2"] != 2 { + t.Errorf("expected value to be 2, got %v", firstCallValues["2"]) + } + if firstCallValues["3"] != 3 { + t.Errorf("expected value to be 3, got %v", 
firstCallValues["3"]) + } + + if !errors.Is(secondCallErr, sturdyc.ErrInvalidType) { + t.Errorf("expected ErrInvalidType, got %v", secondCallErr) + } +} diff --git a/inflight.go b/inflight.go index f4f2464..13d080a 100644 --- a/inflight.go +++ b/inflight.go @@ -21,7 +21,7 @@ func (c *Client[T]) newFlight(key string) *inFlightCall[T] { return call } -func makeCall[T, V any](ctx context.Context, c *Client[T], key string, fn FetchFn[V], call *inFlightCall[T]) { +func makeCall[T any](ctx context.Context, c *Client[T], key string, fn FetchFn[T], call *inFlightCall[T]) { defer func() { if err := recover(); err != nil { call.err = fmt.Errorf("sturdyc: panic recovered: %v", err) @@ -33,40 +33,38 @@ func makeCall[T, V any](ctx context.Context, c *Client[T], key string, fn FetchF }() response, err := fn(ctx) + call.val = response + if c.storeMissingRecords && errors.Is(err, ErrNotFound) { c.StoreMissingRecord(key) call.err = ErrMissingRecord return } - if err != nil { + if err != nil && !errors.Is(err, errOnlyDistributedRecords) { call.err = err return } - res, ok := any(response).(T) - if !ok { - call.err = ErrInvalidType - return + if errors.Is(err, errOnlyDistributedRecords) { + call.err = onlyCachedRecords(err) } - call.err = nil - call.val = res - c.Set(key, res) + c.Set(key, response) } -func callAndCache[V, T any](ctx context.Context, c *Client[T], key string, fn FetchFn[V]) (V, error) { +func callAndCache[T any](ctx context.Context, c *Client[T], key string, fn FetchFn[T]) (T, error) { c.inFlightMutex.Lock() if call, ok := c.inFlightMap[key]; ok { c.inFlightMutex.Unlock() call.Wait() - return unwrap[V, T](call.val, call.err) + return call.val, call.err } call := c.newFlight(key) c.inFlightMutex.Unlock() makeCall(ctx, c, key, fn, call) - return unwrap[V, T](call.val, call.err) + return call.val, call.err } // newBatchFlight should be called with a lock. 
@@ -89,22 +87,33 @@ func (c *Client[T]) endBatchFlight(ids []string, keyFn KeyFn, call *inFlightCall
 	c.inFlightBatchMutex.Unlock()
 }
 
-type makeBatchCallOpts[T, V any] struct {
+type makeBatchCallOpts[T any] struct {
 	ids   []string
-	fn    BatchFetchFn[V]
+	fn    BatchFetchFn[T]
 	keyFn KeyFn
 	call  *inFlightCall[map[string]T]
 }
 
-func makeBatchCall[T, V any](ctx context.Context, c *Client[T], opts makeBatchCallOpts[T, V]) {
+func makeBatchCall[T any](ctx context.Context, c *Client[T], opts makeBatchCallOpts[T]) {
 	response, err := opts.fn(ctx, opts.ids)
+	for id, record := range response {
+		// We never want to discard values from the fetch functions, even if they
+		// return an error. Instead, we'll pass them to the user along with any
+		// errors and let them decide what to do.
+		opts.call.val[id] = record
+		// However, we'll only write them to the cache if the fetch function returned a nil error.
+		if err == nil || errors.Is(err, errOnlyDistributedRecords) {
+			c.Set(opts.keyFn(id), record)
+		}
+	}
+
 	if err != nil && !errors.Is(err, errOnlyDistributedRecords) {
 		opts.call.err = err
 		return
 	}
 
 	if errors.Is(err, errOnlyDistributedRecords) {
-		opts.call.err = ErrOnlyCachedRecords
+		opts.call.err = onlyCachedRecords(err)
 	}
 
 	// Check if we should store any of these IDs as a missing record. However, we
@@ -119,26 +128,15 @@ func makeBatchCall[T, V any](ctx context.Context, c *Client[T], opts makeBatchCa
 		}
 	}
 }
-
-	// Store the records in the cache. 
- for id, record := range response { - v, ok := any(record).(T) - if !ok { - c.log.Error("sturdyc: invalid type for ID:" + id) - continue - } - c.Set(opts.keyFn(id), v) - opts.call.val[id] = v - } } -type callBatchOpts[T, V any] struct { +type callBatchOpts[T any] struct { ids []string keyFn KeyFn - fn BatchFetchFn[V] + fn BatchFetchFn[T] } -func callAndCacheBatch[V, T any](ctx context.Context, c *Client[T], opts callBatchOpts[T, V]) (map[string]V, error) { +func callAndCacheBatch[T any](ctx context.Context, c *Client[T], opts callBatchOpts[T]) (map[string]T, error) { c.inFlightBatchMutex.Lock() callIDs := make(map[*inFlightCall[map[string]T]][]string) @@ -161,40 +159,37 @@ func callAndCacheBatch[V, T any](ctx context.Context, c *Client[T], opts callBat } c.endBatchFlight(uniqueIDs, opts.keyFn, call) }() - batchCallOpts := makeBatchCallOpts[T, V]{ids: uniqueIDs, fn: opts.fn, keyFn: opts.keyFn, call: call} + batchCallOpts := makeBatchCallOpts[T]{ids: uniqueIDs, fn: opts.fn, keyFn: opts.keyFn, call: call} makeBatchCall(ctx, c, batchCallOpts) }() } c.inFlightBatchMutex.Unlock() var err error - response := make(map[string]V, len(opts.ids)) + response := make(map[string]T, len(opts.ids)) for call, callIDs := range callIDs { call.Wait() - // It could be only cached records here, if we we're able - // to get some of the IDs from the distributed storage. - if call.err != nil && !errors.Is(call.err, ErrOnlyCachedRecords) { - return response, call.err - } - - if errors.Is(call.err, ErrOnlyCachedRecords) { - err = ErrOnlyCachedRecords - } - // We need to iterate through the values that we want from this call. The - // batch could contain a hundred IDs, but we might only want a few of them. + // We need to iterate through the values that WE want from this call. The batch + // could contain hundreds of IDs, but we might only want a few of them. 
for _, id := range callIDs {
-			v, ok := call.val[id]
-			if !ok {
-				continue
+			if v, ok := call.val[id]; ok {
+				response[id] = v
 			}
+		}
 
-			if val, ok := any(v).(V); ok {
-				response[id] = val
-				continue
-			}
-			return response, ErrInvalidType
+		// This handles the scenario where we either don't get an error, or are
+		// using the distributed storage option and are able to get some records
+		// while the request to the underlying data source fails. In the latter
+		// case, we'll continue to accumulate partial responses as long as the only
+		// issue is cached-only records.
+		if err == nil || errors.Is(call.err, ErrOnlyCachedRecords) {
+			err = call.err
+			continue
 		}
+
+		// For any other kind of error, we'll short-circuit the function and return.
+		return response, call.err
 	}
 
 	return response, err
diff --git a/passthrough.go b/passthrough.go
index 829fe53..145fbdf 100644
--- a/passthrough.go
+++ b/passthrough.go
@@ -67,7 +67,7 @@ func Passthrough[T, V any](ctx context.Context, c *Client[T], key string, fetchF
 // A map of IDs to their corresponding values, and an error if one occurred and
 // none of the IDs were found in the cache. 
func (c *Client[T]) PassthroughBatch(ctx context.Context, ids []string, keyFn KeyFn, fetchFn BatchFetchFn[T]) (map[string]T, error) { - res, err := callAndCacheBatch(ctx, c, callBatchOpts[T, T]{ids, keyFn, fetchFn}) + res, err := callAndCacheBatch(ctx, c, callBatchOpts[T]{ids, keyFn, fetchFn}) if err == nil { return res, nil } diff --git a/refresh.go b/refresh.go index 7162de9..c32c954 100644 --- a/refresh.go +++ b/refresh.go @@ -7,7 +7,7 @@ import ( func (c *Client[T]) refresh(key string, fetchFn FetchFn[T]) { response, err := fetchFn(context.Background()) - if err != nil { + if err != nil && !errors.Is(err, errOnlyDistributedRecords) { if c.storeMissingRecords && errors.Is(err, ErrNotFound) { c.StoreMissingRecord(key) } diff --git a/safe.go b/safe.go index ed33c30..05657e1 100644 --- a/safe.go +++ b/safe.go @@ -22,35 +22,29 @@ func (c *Client[T]) safeGo(fn func()) { func wrap[T, V any](fetchFn FetchFn[V]) FetchFn[T] { return func(ctx context.Context) (T, error) { res, err := fetchFn(ctx) - if err != nil { - var zero T - return zero, err + if val, ok := any(res).(T); ok { + return val, err } - val, ok := any(res).(T) - if !ok { - var zero T - return zero, ErrInvalidType - } - return val, nil + var zero T + return zero, ErrInvalidType } } func unwrap[V, T any](val T, err error) (V, error) { + if errors.Is(err, ErrMissingRecord) { + return *new(V), err + } + v, ok := any(val).(V) if !ok { return v, ErrInvalidType } - return v, err } func wrapBatch[T, V any](fetchFn BatchFetchFn[V]) BatchFetchFn[T] { return func(ctx context.Context, ids []string) (map[string]T, error) { resV, err := fetchFn(ctx, ids) - if err != nil && !errors.Is(err, errOnlyDistributedRecords) { - return map[string]T{}, err - } - resT := make(map[string]T, len(resV)) for id, v := range resV { val, ok := any(v).(T) @@ -59,7 +53,6 @@ func wrapBatch[T, V any](fetchFn BatchFetchFn[V]) BatchFetchFn[T] { } resT[id] = val } - return resT, err } }