Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func getFetchBatch[V, T any](ctx context.Context, c *Client[T], ids []string, ke
cacheMissesAndSyncRefreshes = append(cacheMissesAndSyncRefreshes, cacheMisses...)
cacheMissesAndSyncRefreshes = append(cacheMissesAndSyncRefreshes, idsToSynchronouslyRefresh...)

callBatchOpts := callBatchOpts[T, T]{ids: cacheMissesAndSyncRefreshes, keyFn: keyFn, fn: wrappedFetch}
callBatchOpts := callBatchOpts[T]{ids: cacheMissesAndSyncRefreshes, keyFn: keyFn, fn: wrappedFetch}
response, err := callAndCacheBatch(ctx, c, callBatchOpts)

// If we did a call to synchronously refresh some of the records, and it
Expand All @@ -175,10 +175,14 @@ func getFetchBatch[V, T any](ctx context.Context, c *Client[T], ids []string, ke
}

if err != nil && !errors.Is(err, ErrOnlyCachedRecords) {
// At this point, the call for the IDs that we didn't have in the cache
// have failed. However, these ID's could have been picked from multiple
// in-flight requests. Hence, we'll check if we're able to add any of these
// IDs to the cached records before returning.
if len(cachedRecords) > 0 {
maps.Copy(cachedRecords, response)
return cachedRecords, errors.Join(ErrOnlyCachedRecords, err)
}
return cachedRecords, err
}

maps.Copy(cachedRecords, response)
Expand Down
240 changes: 240 additions & 0 deletions fetch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1568,3 +1568,243 @@ func TestGetFetchBatchMixOfSynchronousAndAsynchronousRefreshes(t *testing.T) {
fetchObserver.AssertRequestedRecords(t, []string{"4"})
fetchObserver.AssertFetchCount(t, 4)
}

func TestGetOrFetchGenerics(t *testing.T) {
t.Parallel()

capacity := 10
numShards := 2
ttl := time.Second * 2
evictionPercentage := 10
c := sturdyc.New[string](capacity, numShards, ttl, evictionPercentage,
sturdyc.WithNoContinuousEvictions(),
)

fetchFuncInt := func(_ context.Context) (int, error) {
return 1, nil
}

_, err := sturdyc.GetOrFetch(context.Background(), c, "1", fetchFuncInt)
if !errors.Is(err, sturdyc.ErrInvalidType) {
t.Errorf("expected ErrInvalidType, got %v", err)
}

if _, ok := c.Get("1"); ok {
t.Error("we should not have cached anything given that the value was of the wrong type")
}

if c.Size() != 0 {
t.Errorf("expected cache size to be 0, got %d", c.Size())
}

fetchFuncString := func(_ context.Context) (string, error) {
return "value", nil
}

val, err := sturdyc.GetOrFetch(context.Background(), c, "1", fetchFuncString)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if val != "value" {
t.Errorf("expected value to be value, got %v", val)
}

cachedValue, ok := c.Get("1")
if !ok {
t.Error("expected value to be in the cache")
}
if cachedValue != "value" {
t.Errorf("expected value to be value, got %v", cachedValue)
}
if c.Size() != 1 {
t.Errorf("expected cache size to be 1, got %d", c.Size())
}
}

func TestGetOrFetchBatchGenerics(t *testing.T) {
t.Parallel()

capacity := 10
numShards := 2
ttl := time.Second * 2
evictionPercentage := 10
c := sturdyc.New[string](capacity, numShards, ttl, evictionPercentage,
sturdyc.WithNoContinuousEvictions(),
)

fetchFuncInt := func(_ context.Context, ids []string) (map[string]int, error) {
response := make(map[string]int, len(ids))
for i, id := range ids {
response[id] = i
}
return response, nil
}

ids := []string{"1", "2", "3"}
_, err := sturdyc.GetOrFetchBatch(context.Background(), c, ids, c.BatchKeyFn("item"), fetchFuncInt)
if !errors.Is(err, sturdyc.ErrInvalidType) {
t.Errorf("expected ErrInvalidType, got %v", err)
}
if c.Size() != 0 {
t.Errorf("expected cache size to be 0, got %d", c.Size())
}

fetchFuncString := func(_ context.Context, ids []string) (map[string]string, error) {
response := make(map[string]string, len(ids))
for _, id := range ids {
response[id] = "value" + id
}
return response, nil
}

keyFunc := c.BatchKeyFn("item")
values, err := sturdyc.GetOrFetchBatch(context.Background(), c, ids, keyFunc, fetchFuncString)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if len(values) != 3 {
t.Errorf("expected 3 values, got %d", len(values))
}
if values["1"] != "value1" {
t.Errorf("expected value to be value1, got %v", values["1"])
}
if values["2"] != "value2" {
t.Errorf("expected value to be value2, got %v", values["2"])
}
if values["3"] != "value3" {
t.Errorf("expected value to be value3, got %v", values["3"])
}

cacheKeys := make([]string, 0, len(ids))
for _, id := range ids {
cacheKeys = append(cacheKeys, keyFunc(id))
}
cachedValues := c.GetMany(cacheKeys)
if len(cachedValues) != 3 {
t.Errorf("expected 3 values, got %d", len(cachedValues))
}
if cachedValues[cacheKeys[0]] != "value1" {
t.Errorf("expected value to be value1, got %v", cachedValues["1"])
}
if cachedValues[cacheKeys[1]] != "value2" {
t.Errorf("expected value to be value2, got %v", cachedValues["2"])
}
if cachedValues[cacheKeys[2]] != "value3" {
t.Errorf("expected value to be value3, got %v", cachedValues["3"])
}
}

func TestGetOrFetchGenericsClashingTypes(t *testing.T) {
t.Parallel()

capacity := 10
numShards := 2
ttl := time.Second * 2
evictionPercentage := 10
c := sturdyc.New[any](capacity, numShards, ttl, evictionPercentage,
sturdyc.WithNoContinuousEvictions(),
)

resolve := make(chan struct{})
fetchFuncInt := func(_ context.Context) (int, error) {
<-resolve
return 1, nil
}
fetchFuncString := func(_ context.Context) (string, error) {
return "value", nil
}

go func() {
time.Sleep(time.Millisecond * 500)
resolve <- struct{}{}
}()
resOne, errOne := sturdyc.GetOrFetch(context.Background(), c, "1", fetchFuncInt)
_, errTwo := sturdyc.GetOrFetch(context.Background(), c, "1", fetchFuncString)

if errOne != nil {
t.Errorf("expected no error, got %v", errOne)
}
if resOne != 1 {
t.Errorf("expected value to be 1, got %v", resOne)
}

if !errors.Is(errTwo, sturdyc.ErrInvalidType) {
t.Errorf("expected ErrInvalidType, got %v", errTwo)
}
}

func TestGetOrFetchBatchGenericsClashingTypes(t *testing.T) {
t.Parallel()

capacity := 10
numShards := 2
ttl := time.Second * 2
evictionPercentage := 10
c := sturdyc.New[any](capacity, numShards, ttl, evictionPercentage,
sturdyc.WithNoContinuousEvictions(),
)

// We are going to test the behaviour for when the same IDs are passed to
// fetch functions which return different types.
cacheKeyFunc := c.BatchKeyFn("item")
firstCallIDs := []string{"1", "2", "3"}
secondCallIDs := []string{"1", "2", "3", "4"}

// First, we'll create a fetch function which returns integers. This fetch
// function is going to wait for a message on the resolve channel before
// returning. This is so that we're able to create an in-flight batch for the
// first set of IDs.
resolve := make(chan struct{})
fetchFuncInt := func(_ context.Context, ids []string) (map[string]int, error) {
if len(ids) != 3 {
t.Fatalf("expected 3 IDs, got %d", len(ids))
}
response := make(map[string]int, len(ids))
for i, id := range ids {
response[id] = i + 1
}
<-resolve
return response, nil
}

// Next, we'll create a fetch function which returns strings. This fetch
// function should only be called with ID "4" because IDs 1-3 are picked from
// the first batch.
fetchFuncString := func(_ context.Context, ids []string) (map[string]string, error) {
if len(ids) != 1 {
t.Fatalf("expected 1 ID, got %d", len(ids))
}
response := make(map[string]string, len(ids))
for _, id := range ids {
response[id] = "value" + id
}
return response, nil
}

go func() {
time.Sleep(time.Millisecond * 500)
resolve <- struct{}{}
}()
firstCallValues, firstCallErr := sturdyc.GetOrFetchBatch(context.Background(), c, firstCallIDs, cacheKeyFunc, fetchFuncInt)
_, secondCallErr := sturdyc.GetOrFetchBatch(context.Background(), c, secondCallIDs, cacheKeyFunc, fetchFuncString)

if firstCallErr != nil {
t.Errorf("expected no error, got %v", firstCallErr)
}
if len(firstCallValues) != 3 {
t.Errorf("expected 3 values, got %d", len(firstCallValues))
}
if firstCallValues["1"] != 1 {
t.Errorf("expected value to be 1, got %v", firstCallValues["1"])
}
if firstCallValues["2"] != 2 {
t.Errorf("expected value to be 2, got %v", firstCallValues["2"])
}
if firstCallValues["3"] != 3 {
t.Errorf("expected value to be 3, got %v", firstCallValues["3"])
}

if !errors.Is(secondCallErr, sturdyc.ErrInvalidType) {
t.Errorf("expected ErrInvalidType, got %v", secondCallErr)
}
}
Loading