diff --git a/fetch.go b/fetch.go index 0ffc292..b39ec66 100644 --- a/fetch.go +++ b/fetch.go @@ -151,7 +151,7 @@ func getFetchBatch[V, T any](ctx context.Context, c *Client[T], ids []string, ke cacheMissesAndSyncRefreshes = append(cacheMissesAndSyncRefreshes, cacheMisses...) cacheMissesAndSyncRefreshes = append(cacheMissesAndSyncRefreshes, idsToSynchronouslyRefresh...) - callBatchOpts := callBatchOpts[T, T]{ids: cacheMissesAndSyncRefreshes, keyFn: keyFn, fn: wrappedFetch} + callBatchOpts := callBatchOpts[T]{ids: cacheMissesAndSyncRefreshes, keyFn: keyFn, fn: wrappedFetch} response, err := callAndCacheBatch(ctx, c, callBatchOpts) // If we did a call to synchronously refresh some of the records, and it @@ -175,10 +175,14 @@ func getFetchBatch[V, T any](ctx context.Context, c *Client[T], ids []string, ke } if err != nil && !errors.Is(err, ErrOnlyCachedRecords) { + // At this point, the calls for the IDs that we didn't have in the cache + // have failed. However, these IDs could have been picked from multiple + // in-flight requests. Hence, we'll check if we're able to add any of these + // IDs to the cached records before returning. 
if len(cachedRecords) > 0 { + maps.Copy(cachedRecords, response) return cachedRecords, errors.Join(ErrOnlyCachedRecords, err) } - return cachedRecords, err } maps.Copy(cachedRecords, response) diff --git a/fetch_test.go b/fetch_test.go index 18ca076..ec817d6 100644 --- a/fetch_test.go +++ b/fetch_test.go @@ -1568,3 +1568,243 @@ func TestGetFetchBatchMixOfSynchronousAndAsynchronousRefreshes(t *testing.T) { fetchObserver.AssertRequestedRecords(t, []string{"4"}) fetchObserver.AssertFetchCount(t, 4) } + +func TestGetOrFetchGenerics(t *testing.T) { + t.Parallel() + + capacity := 10 + numShards := 2 + ttl := time.Second * 2 + evictionPercentage := 10 + c := sturdyc.New[string](capacity, numShards, ttl, evictionPercentage, + sturdyc.WithNoContinuousEvictions(), + ) + + fetchFuncInt := func(_ context.Context) (int, error) { + return 1, nil + } + + _, err := sturdyc.GetOrFetch(context.Background(), c, "1", fetchFuncInt) + if !errors.Is(err, sturdyc.ErrInvalidType) { + t.Errorf("expected ErrInvalidType, got %v", err) + } + + if _, ok := c.Get("1"); ok { + t.Error("we should not have cached anything given that the value was of the wrong type") + } + + if c.Size() != 0 { + t.Errorf("expected cache size to be 0, got %d", c.Size()) + } + + fetchFuncString := func(_ context.Context) (string, error) { + return "value", nil + } + + val, err := sturdyc.GetOrFetch(context.Background(), c, "1", fetchFuncString) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if val != "value" { + t.Errorf("expected value to be value, got %v", val) + } + + cachedValue, ok := c.Get("1") + if !ok { + t.Error("expected value to be in the cache") + } + if cachedValue != "value" { + t.Errorf("expected value to be value, got %v", cachedValue) + } + if c.Size() != 1 { + t.Errorf("expected cache size to be 1, got %d", c.Size()) + } +} + +func TestGetOrFetchBatchGenerics(t *testing.T) { + t.Parallel() + + capacity := 10 + numShards := 2 + ttl := time.Second * 2 + evictionPercentage := 10 + 
c := sturdyc.New[string](capacity, numShards, ttl, evictionPercentage, + sturdyc.WithNoContinuousEvictions(), + ) + + fetchFuncInt := func(_ context.Context, ids []string) (map[string]int, error) { + response := make(map[string]int, len(ids)) + for i, id := range ids { + response[id] = i + } + return response, nil + } + + ids := []string{"1", "2", "3"} + _, err := sturdyc.GetOrFetchBatch(context.Background(), c, ids, c.BatchKeyFn("item"), fetchFuncInt) + if !errors.Is(err, sturdyc.ErrInvalidType) { + t.Errorf("expected ErrInvalidType, got %v", err) + } + if c.Size() != 0 { + t.Errorf("expected cache size to be 0, got %d", c.Size()) + } + + fetchFuncString := func(_ context.Context, ids []string) (map[string]string, error) { + response := make(map[string]string, len(ids)) + for _, id := range ids { + response[id] = "value" + id + } + return response, nil + } + + keyFunc := c.BatchKeyFn("item") + values, err := sturdyc.GetOrFetchBatch(context.Background(), c, ids, keyFunc, fetchFuncString) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if len(values) != 3 { + t.Errorf("expected 3 values, got %d", len(values)) + } + if values["1"] != "value1" { + t.Errorf("expected value to be value1, got %v", values["1"]) + } + if values["2"] != "value2" { + t.Errorf("expected value to be value2, got %v", values["2"]) + } + if values["3"] != "value3" { + t.Errorf("expected value to be value3, got %v", values["3"]) + } + + cacheKeys := make([]string, 0, len(ids)) + for _, id := range ids { + cacheKeys = append(cacheKeys, keyFunc(id)) + } + cachedValues := c.GetMany(cacheKeys) + if len(cachedValues) != 3 { + t.Errorf("expected 3 values, got %d", len(cachedValues)) + } + if cachedValues[cacheKeys[0]] != "value1" { + t.Errorf("expected value to be value1, got %v", cachedValues["1"]) + } + if cachedValues[cacheKeys[1]] != "value2" { + t.Errorf("expected value to be value2, got %v", cachedValues["2"]) + } + if cachedValues[cacheKeys[2]] != "value3" { + 
t.Errorf("expected value to be value3, got %v", cachedValues["3"]) + } +} + +func TestGetOrFetchGenericsClashingTypes(t *testing.T) { + t.Parallel() + + capacity := 10 + numShards := 2 + ttl := time.Second * 2 + evictionPercentage := 10 + c := sturdyc.New[any](capacity, numShards, ttl, evictionPercentage, + sturdyc.WithNoContinuousEvictions(), + ) + + resolve := make(chan struct{}) + fetchFuncInt := func(_ context.Context) (int, error) { + <-resolve + return 1, nil + } + fetchFuncString := func(_ context.Context) (string, error) { + return "value", nil + } + + go func() { + time.Sleep(time.Millisecond * 500) + resolve <- struct{}{} + }() + resOne, errOne := sturdyc.GetOrFetch(context.Background(), c, "1", fetchFuncInt) + _, errTwo := sturdyc.GetOrFetch(context.Background(), c, "1", fetchFuncString) + + if errOne != nil { + t.Errorf("expected no error, got %v", errOne) + } + if resOne != 1 { + t.Errorf("expected value to be 1, got %v", resOne) + } + + if !errors.Is(errTwo, sturdyc.ErrInvalidType) { + t.Errorf("expected ErrInvalidType, got %v", errTwo) + } +} + +func TestGetOrFetchBatchGenericsClashingTypes(t *testing.T) { + t.Parallel() + + capacity := 10 + numShards := 2 + ttl := time.Second * 2 + evictionPercentage := 10 + c := sturdyc.New[any](capacity, numShards, ttl, evictionPercentage, + sturdyc.WithNoContinuousEvictions(), + ) + + // We are going to test the behaviour for when the same IDs are passed to + // fetch functions which return different types. + cacheKeyFunc := c.BatchKeyFn("item") + firstCallIDs := []string{"1", "2", "3"} + secondCallIDs := []string{"1", "2", "3", "4"} + + // First, we'll create a fetch function which returns integers. This fetch + // function is going to wait for a message on the resolve channel before + // returning. This is so that we're able to create an in-flight batch for the + // first set of IDs. 
+ resolve := make(chan struct{}) + fetchFuncInt := func(_ context.Context, ids []string) (map[string]int, error) { + if len(ids) != 3 { + t.Fatalf("expected 3 IDs, got %d", len(ids)) + } + response := make(map[string]int, len(ids)) + for i, id := range ids { + response[id] = i + 1 + } + <-resolve + return response, nil + } + + // Next, we'll create a fetch function which returns strings. This fetch + // function should only be called with ID "4" because IDs 1-3 are picked from + // the first batch. + fetchFuncString := func(_ context.Context, ids []string) (map[string]string, error) { + if len(ids) != 1 { + t.Fatalf("expected 1 ID, got %d", len(ids)) + } + response := make(map[string]string, len(ids)) + for _, id := range ids { + response[id] = "value" + id + } + return response, nil + } + + go func() { + time.Sleep(time.Millisecond * 500) + resolve <- struct{}{} + }() + firstCallValues, firstCallErr := sturdyc.GetOrFetchBatch(context.Background(), c, firstCallIDs, cacheKeyFunc, fetchFuncInt) + _, secondCallErr := sturdyc.GetOrFetchBatch(context.Background(), c, secondCallIDs, cacheKeyFunc, fetchFuncString) + + if firstCallErr != nil { + t.Errorf("expected no error, got %v", firstCallErr) + } + if len(firstCallValues) != 3 { + t.Errorf("expected 3 values, got %d", len(firstCallValues)) + } + if firstCallValues["1"] != 1 { + t.Errorf("expected value to be 1, got %v", firstCallValues["1"]) + } + if firstCallValues["2"] != 2 { + t.Errorf("expected value to be 2, got %v", firstCallValues["2"]) + } + if firstCallValues["3"] != 3 { + t.Errorf("expected value to be 3, got %v", firstCallValues["3"]) + } + + if !errors.Is(secondCallErr, sturdyc.ErrInvalidType) { + t.Errorf("expected ErrInvalidType, got %v", secondCallErr) + } +} diff --git a/inflight.go b/inflight.go index f4f2464..499f259 100644 --- a/inflight.go +++ b/inflight.go @@ -21,7 +21,7 @@ func (c *Client[T]) newFlight(key string) *inFlightCall[T] { return call } -func makeCall[T, V any](ctx context.Context, c 
*Client[T], key string, fn FetchFn[V], call *inFlightCall[T]) { +func makeCall[T any](ctx context.Context, c *Client[T], key string, fn FetchFn[T], call *inFlightCall[T]) { defer func() { if err := recover(); err != nil { call.err = fmt.Errorf("sturdyc: panic recovered: %v", err) @@ -33,40 +33,34 @@ func makeCall[T, V any](ctx context.Context, c *Client[T], key string, fn FetchF }() response, err := fn(ctx) + call.val = response + if c.storeMissingRecords && errors.Is(err, ErrNotFound) { c.StoreMissingRecord(key) call.err = ErrMissingRecord return } + call.err = err if err != nil { - call.err = err return } - res, ok := any(response).(T) - if !ok { - call.err = ErrInvalidType - return - } - - call.err = nil - call.val = res - c.Set(key, res) + c.Set(key, response) } -func callAndCache[V, T any](ctx context.Context, c *Client[T], key string, fn FetchFn[V]) (V, error) { +func callAndCache[T any](ctx context.Context, c *Client[T], key string, fn FetchFn[T]) (T, error) { c.inFlightMutex.Lock() if call, ok := c.inFlightMap[key]; ok { c.inFlightMutex.Unlock() call.Wait() - return unwrap[V, T](call.val, call.err) + return call.val, call.err } call := c.newFlight(key) c.inFlightMutex.Unlock() makeCall(ctx, c, key, fn, call) - return unwrap[V, T](call.val, call.err) + return call.val, call.err } // newBatchFlight should be called with a lock. 
@@ -89,15 +83,26 @@ func (c *Client[T]) endBatchFlight(ids []string, keyFn KeyFn, call *inFlightCall c.inFlightBatchMutex.Unlock() } -type makeBatchCallOpts[T, V any] struct { +type makeBatchCallOpts[T any] struct { ids []string - fn BatchFetchFn[V] + fn BatchFetchFn[T] keyFn KeyFn call *inFlightCall[map[string]T] } -func makeBatchCall[T, V any](ctx context.Context, c *Client[T], opts makeBatchCallOpts[T, V]) { +func makeBatchCall[T any](ctx context.Context, c *Client[T], opts makeBatchCallOpts[T]) { response, err := opts.fn(ctx, opts.ids) + for id, record := range response { + // We never want to discard values from the fetch functions, even if they + // return an error. Instead, we'll pass them to the user along with any + // errors and let them decide what to do. + opts.call.val[id] = record + // However, we'll only write them to the cache if the fetch function returned a nil error. + if err == nil || errors.Is(err, errOnlyDistributedRecords) { + c.Set(opts.keyFn(id), record) + } + } + if err != nil && !errors.Is(err, errOnlyDistributedRecords) { opts.call.err = err return @@ -119,26 +124,15 @@ func makeBatchCall[T, V any](ctx context.Context, c *Client[T], opts makeBatchCa } } } - - // Store the records in the cache. 
- for id, record := range response { - v, ok := any(record).(T) - if !ok { - c.log.Error("sturdyc: invalid type for ID:" + id) - continue - } - c.Set(opts.keyFn(id), v) - opts.call.val[id] = v - } } -type callBatchOpts[T, V any] struct { +type callBatchOpts[T any] struct { ids []string keyFn KeyFn - fn BatchFetchFn[V] + fn BatchFetchFn[T] } -func callAndCacheBatch[V, T any](ctx context.Context, c *Client[T], opts callBatchOpts[T, V]) (map[string]V, error) { +func callAndCacheBatch[T any](ctx context.Context, c *Client[T], opts callBatchOpts[T]) (map[string]T, error) { c.inFlightBatchMutex.Lock() callIDs := make(map[*inFlightCall[map[string]T]][]string) @@ -161,40 +155,37 @@ func callAndCacheBatch[V, T any](ctx context.Context, c *Client[T], opts callBat } c.endBatchFlight(uniqueIDs, opts.keyFn, call) }() - batchCallOpts := makeBatchCallOpts[T, V]{ids: uniqueIDs, fn: opts.fn, keyFn: opts.keyFn, call: call} + batchCallOpts := makeBatchCallOpts[T]{ids: uniqueIDs, fn: opts.fn, keyFn: opts.keyFn, call: call} makeBatchCall(ctx, c, batchCallOpts) }() } c.inFlightBatchMutex.Unlock() var err error - response := make(map[string]V, len(opts.ids)) + response := make(map[string]T, len(opts.ids)) for call, callIDs := range callIDs { call.Wait() - // It could be only cached records here, if we we're able - // to get some of the IDs from the distributed storage. - if call.err != nil && !errors.Is(call.err, ErrOnlyCachedRecords) { - return response, call.err - } - - if errors.Is(call.err, ErrOnlyCachedRecords) { - err = ErrOnlyCachedRecords - } - // We need to iterate through the values that we want from this call. The - // batch could contain a hundred IDs, but we might only want a few of them. + // We need to iterate through the values that WE want from this call. The batch + // could contain hundreds of IDs, but we might only want a few of them. 
for _, id := range callIDs { - v, ok := call.val[id] - if !ok { - continue + if v, ok := call.val[id]; ok { + response[id] = v } + } - if val, ok := any(v).(V); ok { - response[id] = val - continue - } - return response, ErrInvalidType + // This handles the scenario where we either don't get an error, or are + // using the distributed storage option and are able to get some records + // while the request to the underlying data source fails. In the latter + // case, we'll continue to accumulate partial responses as long as the only + // issue is cached-only records. + if err == nil || errors.Is(call.err, ErrOnlyCachedRecords) { + err = call.err + continue } + + // For any other kind of error, we'll short-circuit the function and return. + return response, call.err } return response, err diff --git a/passthrough.go b/passthrough.go index 829fe53..145fbdf 100644 --- a/passthrough.go +++ b/passthrough.go @@ -67,7 +67,7 @@ func Passthrough[T, V any](ctx context.Context, c *Client[T], key string, fetchF // A map of IDs to their corresponding values, and an error if one occurred and // none of the IDs were found in the cache. 
func (c *Client[T]) PassthroughBatch(ctx context.Context, ids []string, keyFn KeyFn, fetchFn BatchFetchFn[T]) (map[string]T, error) { - res, err := callAndCacheBatch(ctx, c, callBatchOpts[T, T]{ids, keyFn, fetchFn}) + res, err := callAndCacheBatch(ctx, c, callBatchOpts[T]{ids, keyFn, fetchFn}) if err == nil { return res, nil } diff --git a/safe.go b/safe.go index ed33c30..c156891 100644 --- a/safe.go +++ b/safe.go @@ -2,7 +2,6 @@ package sturdyc import ( "context" - "errors" "fmt" ) @@ -22,16 +21,11 @@ func (c *Client[T]) safeGo(fn func()) { func wrap[T, V any](fetchFn FetchFn[V]) FetchFn[T] { return func(ctx context.Context) (T, error) { res, err := fetchFn(ctx) - if err != nil { - var zero T - return zero, err + if val, ok := any(res).(T); ok { + return val, err } - val, ok := any(res).(T) - if !ok { - var zero T - return zero, ErrInvalidType - } - return val, nil + var zero T + return zero, ErrInvalidType } } @@ -40,17 +34,12 @@ func unwrap[V, T any](val T, err error) (V, error) { if !ok { return v, ErrInvalidType } - return v, err } func wrapBatch[T, V any](fetchFn BatchFetchFn[V]) BatchFetchFn[T] { return func(ctx context.Context, ids []string) (map[string]T, error) { resV, err := fetchFn(ctx, ids) - if err != nil && !errors.Is(err, errOnlyDistributedRecords) { - return map[string]T{}, err - } - resT := make(map[string]T, len(resV)) for id, v := range resV { val, ok := any(v).(T) @@ -59,7 +48,6 @@ func wrapBatch[T, V any](fetchFn BatchFetchFn[V]) BatchFetchFn[T] { } resT[id] = val } - return resT, err } }