From 08c3ff746ff173c2e760acaf7d16b4b2f7117f13 Mon Sep 17 00:00:00 2001 From: Christopher Hlubek Date: Thu, 9 Oct 2025 21:39:28 +0200 Subject: [PATCH] fix: add locking via flock to filecache to prevent broken cache files fixes #9 --- filecache/filecache.go | 28 ++++ filecache/filecache_test.go | 260 ++++++++++++++++++++++++++++++++++++ go.mod | 1 + go.sum | 2 + 4 files changed, 291 insertions(+) create mode 100644 filecache/filecache_test.go diff --git a/filecache/filecache.go b/filecache/filecache.go index b449ab0..dad51af 100644 --- a/filecache/filecache.go +++ b/filecache/filecache.go @@ -7,6 +7,7 @@ import ( "path/filepath" "time" + "github.com/gofrs/flock" "github.com/pkg/errors" "go.jetify.com/pkg/cachehash" ) @@ -57,6 +58,13 @@ func (c *Cache[T]) Set(key string, val T, dur time.Duration) error { return errors.WithStack(err) } + // Acquire exclusive lock to prevent concurrent writes from corrupting the file + lock := flock.New(c.lockfile()) + if err := lock.Lock(); err != nil { + return errors.WithStack(err) + } + defer lock.Unlock() + return errors.WithStack(os.WriteFile(c.filename(key), d, 0o644)) } @@ -68,6 +76,13 @@ func (c *Cache[T]) SetWithTime(key string, val T, t time.Time) error { return errors.WithStack(err) } + // Acquire exclusive lock to prevent concurrent writes from corrupting the file + lock := flock.New(c.lockfile()) + if err := lock.Lock(); err != nil { + return errors.WithStack(err) + } + defer lock.Unlock() + return errors.WithStack(os.WriteFile(c.filename(key), d, 0o644)) } @@ -76,6 +91,13 @@ func (c *Cache[T]) Get(key string) (T, error) { path := c.filename(key) resultData := data[T]{} + // Acquire shared lock before checking file existence to prevent TOCTOU race + lock := flock.New(c.lockfile()) + if err := lock.RLock(); err != nil { + return resultData.Val, errors.WithStack(err) + } + defer lock.Unlock() + if _, err := os.Stat(path); errors.Is(err, os.ErrNotExist) { return resultData.Val, NotFound } @@ -147,3 +169,9 @@ func (c *Cache[T]) filename(key string) string { _ = os.MkdirAll(dir, 0o755) return filepath.Join(dir, cachehash.Slug(key)) } + +func (c *Cache[T]) lockfile() string { + dir := filepath.Join(c.cacheDir, c.domain) + _ = os.MkdirAll(dir, 0o755) + return filepath.Join(dir, ".lock") +} diff --git a/filecache/filecache_test.go b/filecache/filecache_test.go new file mode 100644 index 0000000..69511f3 --- /dev/null +++ b/filecache/filecache_test.go @@ -0,0 +1,260 @@ +package filecache_test + +import ( + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.jetify.com/pkg/filecache" +) + +type testData struct { + Value string + Counter int +} + +func TestCacheOperations(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T, cache *filecache.Cache[testData]) + }{ + { + name: "basic set and get", + run: func(t *testing.T, cache *filecache.Cache[testData]) { + // Test cache miss + _, err := cache.Get("key1") + assert.True(t, filecache.IsCacheMiss(err)) + + // Test Set and Get + data := testData{Value: "hello", Counter: 42} + err = cache.Set("key1", data, time.Hour) + require.NoError(t, err) + + result, err := cache.Get("key1") + require.NoError(t, err) + assert.Equal(t, data, result) + }, + }, + { + name: "set with time", + run: func(t *testing.T, cache *filecache.Cache[testData]) { + data := testData{Value: "world", Counter: 123} + expiration := time.Now().Add(time.Hour) + err := cache.SetWithTime("key1", data, expiration) + require.NoError(t, err) + + result, err := cache.Get("key1") + require.NoError(t, err) + assert.Equal(t, data, result) + }, + }, + { + name: "expiration", + run: func(t *testing.T, cache *filecache.Cache[testData]) { + data := testData{Value: "expires", Counter: 1} + // Set with expiration in the past + err := cache.SetWithTime("key1", data, time.Now().Add(-time.Hour)) + require.NoError(t, err) + + _, err = cache.Get("key1") + assert.True(t, filecache.IsCacheMiss(err)) + }, + }, + { + name: "get or set", + run: func(t *testing.T, cache *filecache.Cache[testData]) { + callCount := 0 + fetchFunc := func() (testData, time.Duration, error) { + callCount++ + return testData{Value: "fetched", Counter: callCount}, time.Hour, nil + } + + // First call should fetch + result1, err := cache.GetOrSet("key1", fetchFunc) + require.NoError(t, err) + assert.Equal(t, "fetched", result1.Value) + assert.Equal(t, 1, callCount) + + // Second call should use cache + result2, err := cache.GetOrSet("key1", fetchFunc) + require.NoError(t, err) + assert.Equal(t, "fetched", result2.Value) + assert.Equal(t, 1, callCount) // Should not increment + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cache := filecache.New[testData]("test-domain", filecache.WithCacheDir[testData](t.TempDir())) + tt.run(t, cache) + }) + } +} + +func TestConcurrentAccess(t *testing.T) { + t.Run("concurrent writes to same key", func(t *testing.T) { + cache := filecache.New[testData]("test-domain", filecache.WithCacheDir[testData](t.TempDir())) + + numGoroutines := 10 + var wg sync.WaitGroup + wg.Add(numGoroutines) + + // All goroutines write to the same key + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + data := testData{Value: fmt.Sprintf("writer-%d", id), Counter: id} + err := cache.Set("same-key", data, time.Hour) + assert.NoError(t, err) + }(i) + } + + wg.Wait() + + // The key should exist and contain valid data from one of the writers + result, err := cache.Get("same-key") + require.NoError(t, err) + assert.NotEmpty(t, result.Value) + assert.True(t, result.Counter >= 0 && result.Counter < numGoroutines) + }) + + t.Run("concurrent writes to different keys", func(t *testing.T) { + cache := filecache.New[testData]("test-domain", filecache.WithCacheDir[testData](t.TempDir())) + + numGoroutines := 20 + var wg sync.WaitGroup + wg.Add(numGoroutines) + + // Each goroutine writes to a different key + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + key := fmt.Sprintf("key-%d", id) + data := testData{Value: fmt.Sprintf("value-%d", id), Counter: id} + err := cache.Set(key, data, time.Hour) + assert.NoError(t, err) + }(i) + } + + wg.Wait() + + // Verify all keys were written correctly + for i := 0; i < numGoroutines; i++ { + key := fmt.Sprintf("key-%d", i) + result, err := cache.Get(key) + require.NoError(t, err, "Failed to get key %s", key) + assert.Equal(t, fmt.Sprintf("value-%d", i), result.Value) + assert.Equal(t, i, result.Counter) + } + }) + + t.Run("concurrent get or set same key", func(t *testing.T) { + cache := filecache.New[testData]("test-domain", filecache.WithCacheDir[testData](t.TempDir())) + + numGoroutines := 10 + var wg sync.WaitGroup + wg.Add(numGoroutines) + + callCount := 0 + var mu sync.Mutex + + fetchFunc := func() (testData, time.Duration, error) { + mu.Lock() + callCount++ + count := callCount + mu.Unlock() + // Simulate slow fetch + time.Sleep(10 * time.Millisecond) + return testData{Value: "shared", Counter: count}, time.Hour, nil + } + + // All goroutines try to GetOrSet the same key + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + result, err := cache.GetOrSet("shared-key", fetchFunc) + assert.NoError(t, err) + assert.Equal(t, "shared", result.Value) + }() + } + + wg.Wait() + + // The fetch function may be called multiple times due to race, + // but the final cached value should be valid + result, err := cache.Get("shared-key") + require.NoError(t, err) + assert.Equal(t, "shared", result.Value) + assert.True(t, result.Counter > 0) + }) + + t.Run("concurrent reads and writes", func(t *testing.T) { + cache := filecache.New[testData]("test-domain", filecache.WithCacheDir[testData](t.TempDir())) + + // Pre-populate the cache + err := cache.Set("key", testData{Value: "initial", Counter: 0}, time.Hour) + require.NoError(t, err) + + numReaders := 10 + numWriters := 5 + var wg sync.WaitGroup + wg.Add(numReaders + numWriters) + + // Spawn readers + for i := 0; i < numReaders; i++ { + go func() { + defer wg.Done() + for j := 0; j < 100; j++ { + result, err := cache.Get("key") + // We should either get valid data or an error, but never corrupted data + if err == nil { + assert.NotEmpty(t, result.Value) + } + } + }() + } + + // Spawn writers + for i := 0; i < numWriters; i++ { + go func(id int) { + defer wg.Done() + for j := 0; j < 50; j++ { + data := testData{Value: fmt.Sprintf("writer-%d-iteration-%d", id, j), Counter: j} + err := cache.Set("key", data, time.Hour) + assert.NoError(t, err) + } + }(i) + } + + wg.Wait() + + // Final read should succeed with valid data + result, err := cache.Get("key") + require.NoError(t, err) + assert.NotEmpty(t, result.Value) + }) +} + +func TestClear(t *testing.T) { + cache := filecache.New[testData]("test-domain", filecache.WithCacheDir[testData](t.TempDir())) + + // Add some data + err := cache.Set("key1", testData{Value: "value1", Counter: 1}, time.Hour) + require.NoError(t, err) + err = cache.Set("key2", testData{Value: "value2", Counter: 2}, time.Hour) + require.NoError(t, err) + + // Clear the cache + err = cache.Clear() + require.NoError(t, err) + + // Data should be gone + _, err = cache.Get("key1") + assert.True(t, filecache.IsCacheMiss(err)) + _, err = cache.Get("key2") + assert.True(t, filecache.IsCacheMiss(err)) +} diff --git a/go.mod b/go.mod index 8ee78b1..a650a0f 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/fatih/color v1.18.0 github.com/go-jose/go-jose/v4 v4.1.2 github.com/goccy/go-yaml v1.18.0 + github.com/gofrs/flock v0.12.1 github.com/google/go-github/v74 v74.0.0 github.com/google/renameio/v2 v2.0.0 github.com/gosimple/slug v1.15.0 diff --git a/go.sum b/go.sum index 3916a63..e3310f9 100644 --- a/go.sum +++ b/go.sum @@ -19,6 +19,8 @@ github.com/go-jose/go-jose/v4 v4.1.2 h1:TK/7NqRQZfgAh+Td8AlsrvtPoUyiHh0LqVvokh+1 github.com/go-jose/go-jose/v4 v4.1.2/go.mod h1:22cg9HWM1pOlnRiY+9cQYJ9XHmya1bYW8OeDM6Ku6Oo= github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw= github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= +github.com/gofrs/flock v0.12.1 h1:MTLVXXHf8ekldpJk3AKicLij9MdwOWkZ+a/jHHZby9E= +github.com/gofrs/flock v0.12.1/go.mod h1:9zxTsyu5xtJ9DK+1tFZyibEV7y3uwDxPPfbxeeHCoD0= github.com/gofrs/uuid/v5 v5.3.2 h1:2jfO8j3XgSwlz/wHqemAEugfnTlikAYHhnqQ8Xh4fE0= github.com/gofrs/uuid/v5 v5.3.2/go.mod h1:CDOjlDMVAtN56jqyRUZh58JT31Tiw7/oQyEXZV+9bD8= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=