From 7e2e1e575b43e2528b34cab7ea8d379c6474b6bb Mon Sep 17 00:00:00 2001 From: Alec Thomas Date: Wed, 7 Jan 2026 15:01:55 +1100 Subject: [PATCH] refactor: replace xattr with bbolt for storing TTLs for the disk cache xattr's are a bit fragile, and additionally if we use an on-disk database we will be able to snapshot the cache trivially. --- Justfile | 2 +- go.mod | 4 +- go.sum | 18 +++- internal/cache/disk.go | 150 ++++++++++++++-------------------- internal/cache/ttl_storage.go | 133 ++++++++++++++++++++++++++++++ 5 files changed, 213 insertions(+), 94 deletions(-) create mode 100644 internal/cache/ttl_storage.go diff --git a/Justfile b/Justfile index 3911149..9f18a8a 100644 --- a/Justfile +++ b/Justfile @@ -6,7 +6,7 @@ _help: # Run tests test: - @gotestsum --hide-summary output,skipped --format-hide-empty-pkg ${CI:+--format github-actions} + @gotestsum --hide-summary output,skipped --format-hide-empty-pkg ${CI:+--format github-actions} ./... -- -race -timeout 30s # Lint code lint: diff --git a/go.mod b/go.mod index 4635eb4..1a5e865 100644 --- a/go.mod +++ b/go.mod @@ -6,11 +6,12 @@ require ( github.com/alecthomas/hcl/v2 v2.3.1 github.com/alecthomas/kong v1.13.0 github.com/lmittmann/tint v1.1.2 + go.etcd.io/bbolt v1.4.3 ) require ( github.com/hexops/gotextdiff v1.0.3 // indirect - golang.org/x/sys v0.0.0-20220408201424-a24fb2fb8a0f // indirect + golang.org/x/sys v0.29.0 // indirect ) require ( @@ -18,5 +19,4 @@ require ( github.com/alecthomas/errors v0.9.1 github.com/alecthomas/participle/v2 v2.1.4 // indirect github.com/alecthomas/repr v0.5.2 // indirect - github.com/pkg/xattr v0.4.12 ) diff --git a/go.sum b/go.sum index c545120..6a8c87d 100644 --- a/go.sum +++ b/go.sum @@ -10,11 +10,21 @@ github.com/alecthomas/participle/v2 v2.1.4 h1:W/H79S8Sat/krZ3el6sQMvMaahJ+XcM9WS github.com/alecthomas/participle/v2 v2.1.4/go.mod h1:8tqVbpTX20Ru4NfYQgZf4mP18eXPTBViyMWiArNEgGI= github.com/alecthomas/repr v0.5.2 h1:SU73FTI9D1P5UNtvseffFSGmdNci/O6RsqzeXJtP0Qs= github.com/alecthomas/repr v0.5.2/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/lmittmann/tint v1.1.2 h1:2CQzrL6rslrsyjqLDwD11bZ5OpLBPU+g3G/r5LSfS8w= github.com/lmittmann/tint v1.1.2/go.mod h1:HIS3gSy7qNwGCj+5oRjAutErFBl4BzdQP6cJZ0NfMwE= -github.com/pkg/xattr v0.4.12 h1:rRTkSyFNTRElv6pkA3zpjHpQ90p/OdHQC1GmGh1aTjM= -github.com/pkg/xattr v0.4.12/go.mod h1:di8WF84zAKk8jzR1UBTEWh9AUlIZZ7M/JNt8e9B6ktU= -golang.org/x/sys v0.0.0-20220408201424-a24fb2fb8a0f h1:8w7RhxzTVgUzw/AH/9mUV5q0vMgy40SQRursCcfmkCw= -golang.org/x/sys v0.0.0-20220408201424-a24fb2fb8a0f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +go.etcd.io/bbolt v1.4.3 h1:dEadXpI6G79deX5prL3QRNP6JB8UxVkqo4UPnHaNXJo= +go.etcd.io/bbolt v1.4.3/go.mod h1:tKQlpPaYCVFctUIgFKFnAlvbmB3tpy1vkTnDWohtc0E= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/cache/disk.go b/internal/cache/disk.go index 45cf3e5..3797d33 100644 --- a/internal/cache/disk.go +++ b/internal/cache/disk.go @@ -2,7 +2,6 @@ package cache import ( "context" - "fmt" "io" "io/fs" "log/slog" @@ -14,13 +13,10 @@ import ( "github.com/alecthomas/errors" "github.com/alecthomas/kong" - "github.com/pkg/xattr" "github.com/block/sfptc/internal/logging" ) -const expiresAtXAttr = "user.expires-at" - func init() { Register("disk", NewDisk) } @@ -35,6 +31,7 @@ type DiskConfig struct { type Disk struct { logger *slog.Logger config DiskConfig + ttl *ttlStorage size atomic.Int64 runEviction chan struct{} stop context.CancelFunc @@ -47,7 +44,7 @@ var _ Cache = (*Disk)(nil) // config.Root MUST be set. // // This [Cache] implementation stores cache entries under a directory. If total usage exceeds the limit, entries are -// evicted based on their last access time. TTLs are stored in extended file attributes (xattr). If an entry exceeds its +// evicted based on their last access time. TTLs are stored in a bbolt database. If an entry exceeds its // TTL or the default, it is evicted. The implementation is safe for concurrent use within a single Go process. func NewDisk(ctx context.Context, config DiskConfig) (*Disk, error) { // Validate config @@ -67,17 +64,11 @@ func NewDisk(ctx context.Context, config DiskConfig) (*Disk, error) { return nil, errors.Errorf("failed to create cache root: %w", err) } - // Check if the filesystem supports xattr's by creating a temporary test file. - f, err := os.CreateTemp(config.Root, ".xattr-test-*") + // Open TTL storage + ttl, err := newTTLStorage(filepath.Join(config.Root, "metadata.db")) if err != nil { - return nil, errors.Errorf("failed to create xattr test file: %w", err) - } - testFile := f.Name() - if err := xattr.FSet(f, "user.limit-mb", fmt.Appendf(nil, "%x", config.LimitMB)); err != nil { - return nil, errors.Join(errors.Errorf("fatal: xattrs are not supported on %s: %w", config.Root, err), f.Close(), os.Remove(testFile)) + return nil, errors.Errorf("failed to create TTL storage: %w", err) } - _ = f.Close() - _ = os.Remove(testFile) // Determine the initial size. var size int64 @@ -88,6 +79,10 @@ func NewDisk(ctx context.Context, config DiskConfig) (*Disk, error) { if info.IsDir() { return nil } + // Skip metadata.db file + if info.Name() == "metadata.db" { + return nil + } size += info.Size() return nil }) @@ -102,6 +97,7 @@ func NewDisk(ctx context.Context, config DiskConfig) (*Disk, error) { disk := &Disk{ logger: logger, config: config, + ttl: ttl, runEviction: make(chan struct{}), stop: stop, } @@ -116,6 +112,9 @@ func (d *Disk) String() string { return "disk:" + d.config.Root } func (d *Disk) Close() error { d.stop() + if d.ttl != nil { + return d.ttl.close() + } return nil } @@ -147,6 +146,7 @@ func (d *Disk) Create(_ context.Context, key Key, ttl time.Duration) (io.WriteCl return &diskWriter{ disk: d, file: f, + key: key, path: fullPath, tempPath: tempPath, expiresAt: expiresAt, @@ -159,19 +159,9 @@ func (d *Disk) Delete(_ context.Context, key Key) error { // Check if file is expired expired := false - expiresAtBytes, err := xattr.Get(fullPath, expiresAtXAttr) - if err != nil { - if errors.Is(err, fs.ErrNotExist) { - return fs.ErrNotExist - } - // Continue with deletion even if we can't read xattr - } else { - var expiresAt time.Time - if err := expiresAt.UnmarshalBinary(expiresAtBytes); err == nil { - if time.Now().After(expiresAt) { - expired = true - } - } + expiresAt, err := d.ttl.get(key) + if err == nil && time.Now().After(expiresAt) { + expired = true } info, err := os.Stat(fullPath) @@ -183,6 +173,11 @@ func (d *Disk) Delete(_ context.Context, key Key) error { return errors.Errorf("failed to remove file: %w", err) } + // Remove TTL metadata + if err := d.ttl.delete(key); err != nil { + return errors.Errorf("failed to delete TTL metadata: %w", err) + } + d.size.Add(-info.Size()) if expired { @@ -200,16 +195,11 @@ func (d *Disk) Open(ctx context.Context, key Key) (io.ReadCloser, error) { return nil, errors.Errorf("failed to open file: %w", err) } - expiresAtBytes, err := xattr.FGet(f, expiresAtXAttr) + expiresAt, err := d.ttl.get(key) if err != nil { return nil, errors.Join(errors.Errorf("failed to get expiration time: %w", err), f.Close()) } - var expiresAt time.Time - if err := expiresAt.UnmarshalBinary(expiresAtBytes); err != nil { - return nil, errors.Join(errors.Errorf("failed to unmarshal expiration time: %w", err), f.Close()) - } - now := time.Now() if now.After(expiresAt) { return nil, errors.Join(fs.ErrNotExist, f.Close(), d.Delete(ctx, key)) @@ -218,12 +208,8 @@ func (d *Disk) Open(ctx context.Context, key Key) (io.ReadCloser, error) { // Reset expiration time to implement LRU ttl := min(expiresAt.Sub(now), d.config.MaxTTL) newExpiresAt := now.Add(ttl) - newExpiresAtBytes, err := newExpiresAt.MarshalBinary() - if err != nil { - return nil, errors.Join(errors.Errorf("failed to marshal new expiration time: %w", err), f.Close()) - } - if err := xattr.FSet(f, expiresAtXAttr, newExpiresAtBytes); err != nil { + if err := d.ttl.set(key, newExpiresAt); err != nil { return nil, errors.Join(errors.Errorf("failed to update expiration time: %w", err), f.Close()) } @@ -258,64 +244,52 @@ func (d *Disk) evictionLoop(ctx context.Context) { func (d *Disk) evict() error { type fileInfo struct { + key Key path string size int64 expiresAt time.Time accessedAt time.Time } - var files []fileInfo + var remainingFiles []fileInfo + var expiredKeys []Key now := time.Now() - err := filepath.Walk(d.config.Root, func(path string, info fs.FileInfo, err error) error { - if err != nil { - return errors.WithStack(err) - } - if info.IsDir() { - return nil - } - - relPath, err := filepath.Rel(d.config.Root, path) - if err != nil { - return errors.WithStack(err) - } + err := d.ttl.walk(func(key Key, expiresAt time.Time) error { + path := d.keyToPath(key) + fullPath := filepath.Join(d.config.Root, path) - expiresAtBytes, err := xattr.Get(path, expiresAtXAttr) + info, err := os.Stat(fullPath) if err != nil { - return nil //nolint:nilerr + if errors.Is(err, fs.ErrNotExist) { + expiredKeys = append(expiredKeys, key) + } + return nil } - var expiresAt time.Time - if err := expiresAt.UnmarshalBinary(expiresAtBytes); err != nil { - return nil //nolint:nilerr + if now.After(expiresAt) { + if err := os.Remove(fullPath); err != nil && !errors.Is(err, fs.ErrNotExist) { + return errors.Errorf("failed to delete expired file %s: %w", path, err) + } + expiredKeys = append(expiredKeys, key) + d.size.Add(-info.Size()) + } else { + remainingFiles = append(remainingFiles, fileInfo{ + key: key, + path: path, + size: info.Size(), + expiresAt: expiresAt, + accessedAt: info.ModTime(), + }) } - - files = append(files, fileInfo{ - path: relPath, - size: info.Size(), - expiresAt: expiresAt, - accessedAt: info.ModTime(), - }) - return nil }) - if err != nil { - return errors.Errorf("failed to walk cache directory: %w", err) + return errors.Errorf("failed to walk TTL entries: %w", err) } - var remainingFiles []fileInfo - - for _, f := range files { - if now.After(f.expiresAt) { - fullPath := filepath.Join(d.config.Root, f.path) - if err := os.Remove(fullPath); err != nil && !errors.Is(err, fs.ErrNotExist) { - return errors.Errorf("failed to delete expired file %s: %w", f.path, err) - } - d.size.Add(-f.size) - } else { - remainingFiles = append(remainingFiles, f) - } + if err := d.ttl.deleteAll(expiredKeys); err != nil { + return errors.Errorf("failed to delete TTL metadata: %w", err) } limitBytes := int64(d.config.LimitMB) * 1024 * 1024 @@ -328,6 +302,7 @@ func (d *Disk) evict() error { return remainingFiles[i].accessedAt.Before(remainingFiles[j].accessedAt) }) + var sizeEvictedKeys []Key for _, f := range remainingFiles { if d.size.Load() <= limitBytes { break @@ -337,15 +312,21 @@ func (d *Disk) evict() error { if err := os.Remove(fullPath); err != nil && !errors.Is(err, fs.ErrNotExist) { return errors.Errorf("failed to delete file during size eviction %s: %w", f.path, err) } + sizeEvictedKeys = append(sizeEvictedKeys, f.key) d.size.Add(-f.size) } + if err := d.ttl.deleteAll(sizeEvictedKeys); err != nil { + return errors.Errorf("failed to delete TTL metadata: %w", err) + } + return nil } type diskWriter struct { disk *Disk file *os.File + key Key path string tempPath string expiresAt time.Time @@ -359,15 +340,6 @@ func (w *diskWriter) Write(p []byte) (int, error) { } func (w *diskWriter) Close() error { - expiresAtBytes, err := w.expiresAt.MarshalBinary() - if err != nil { - return errors.Join(errors.Errorf("failed to marshal expiration time: %w", err), w.file.Close()) - } - - if err := xattr.FSet(w.file, expiresAtXAttr, expiresAtBytes); err != nil { - return errors.Join(errors.Errorf("failed to set expiration time: %w", err), w.file.Close()) - } - if err := w.file.Close(); err != nil { return errors.Errorf("failed to close file: %w", err) } @@ -376,6 +348,10 @@ func (w *diskWriter) Close() error { return errors.Errorf("failed to rename temp file: %w", err) } + if err := w.disk.ttl.set(w.key, w.expiresAt); err != nil { + return errors.Join(errors.Errorf("failed to set expiration time: %w", err), os.Remove(w.path)) + } + w.disk.size.Add(w.size) select { diff --git a/internal/cache/ttl_storage.go b/internal/cache/ttl_storage.go new file mode 100644 index 0000000..33e6582 --- /dev/null +++ b/internal/cache/ttl_storage.go @@ -0,0 +1,133 @@ +package cache + +import ( + "time" + + "github.com/alecthomas/errors" + "go.etcd.io/bbolt" +) + +var ttlBucketName = []byte("ttl") + +// ttlStorage manages expiration times for cache entries using bbolt. +type ttlStorage struct { + db *bbolt.DB +} + +// newTTLStorage creates a new bbolt-backed TTL storage. +func newTTLStorage(dbPath string) (*ttlStorage, error) { + db, err := bbolt.Open(dbPath, 0600, &bbolt.Options{ + Timeout: 1 * time.Second, + }) + if err != nil { + return nil, errors.Errorf("failed to open bbolt database: %w", err) + } + + // Create the bucket if it doesn't exist + err = db.Update(func(tx *bbolt.Tx) error { + _, err := tx.CreateBucketIfNotExists(ttlBucketName) + return errors.WithStack(err) + }) + if err != nil { + return nil, errors.Join(errors.Errorf("failed to create ttl bucket: %w", err), db.Close()) + } + + return &ttlStorage{db: db}, nil +} + +func (s *ttlStorage) set(key Key, expiresAt time.Time) error { + expiresAtBytes, err := expiresAt.MarshalBinary() + if err != nil { + return errors.Errorf("failed to marshal expiration time: %w", err) + } + + err = s.db.Update(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(ttlBucketName) + return bucket.Put(key[:], expiresAtBytes) + }) + if err != nil { + return errors.Errorf("failed to set expiration time: %w", err) + } + + return nil +} + +func (s *ttlStorage) get(key Key) (time.Time, error) { + var expiresAt time.Time + + err := s.db.View(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(ttlBucketName) + expiresAtBytes := bucket.Get(key[:]) + if expiresAtBytes == nil { + return errors.New("key not found") + } + return errors.WithStack(expiresAt.UnmarshalBinary(expiresAtBytes)) + }) + if err != nil { + return time.Time{}, errors.Errorf("failed to get expiration time: %w", err) + } + + return expiresAt, nil +} + +func (s *ttlStorage) delete(key Key) error { + err := s.db.Update(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(ttlBucketName) + return bucket.Delete(key[:]) + }) + if err != nil { + return errors.Errorf("failed to delete expiration time: %w", err) + } + return nil +} + +func (s *ttlStorage) deleteAll(keys []Key) error { + if len(keys) == 0 { + return nil + } + err := s.db.Update(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(ttlBucketName) + for _, key := range keys { + if err := bucket.Delete(key[:]); err != nil { + return errors.Errorf("failed to delete expiration time: %w", err) + } + } + return nil + }) + if err != nil { + return errors.Errorf("failed to delete expiration times: %w", err) + } + return nil +} + +func (s *ttlStorage) walk(fn func(key Key, expiresAt time.Time) error) error { + err := s.db.View(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(ttlBucketName) + if bucket == nil { + return nil + } + return bucket.ForEach(func(k, v []byte) error { + if len(k) != 32 { + return nil + } + var key Key + copy(key[:], k) + var expiresAt time.Time + if err := expiresAt.UnmarshalBinary(v); err != nil { + return nil //nolint:nilerr + } + return fn(key, expiresAt) + }) + }) + if err != nil { + return errors.Errorf("failed to walk TTL entries: %w", err) + } + return nil +} + +func (s *ttlStorage) close() error { + if err := s.db.Close(); err != nil { + return errors.Errorf("failed to close bbolt database: %w", err) + } + return nil +}