From 9c2164a16c9bdb108b5db95aeb2a3c99509fa9fa Mon Sep 17 00:00:00 2001 From: Alec Thomas Date: Sun, 11 Jan 2026 22:55:22 +1100 Subject: [PATCH] feat: objects will not be created if the Create() context is cancelled This allows callers to abort creation if some failure other than write error occurs. --- internal/cache/api.go | 2 ++ internal/cache/cachetest/suite.go | 34 +++++++++++++++++++++++++++++++ internal/cache/disk.go | 10 ++++++++- internal/cache/memory.go | 9 +++++++- internal/cache/remote.go | 5 +++++ 5 files changed, 58 insertions(+), 2 deletions(-) diff --git a/internal/cache/api.go b/internal/cache/api.go index 1df8b88..419f305 100644 --- a/internal/cache/api.go +++ b/internal/cache/api.go @@ -100,6 +100,8 @@ type Cache interface { // If "ttl" is zero, a maximum TTL MUST be used by the implementation. // // The file MUST not be available for read until completely written and closed. + // + // If the context is cancelled the object MUST not be made available in the cache. Create(ctx context.Context, key Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error) // Delete a file from the cache. // diff --git a/internal/cache/cachetest/suite.go b/internal/cache/cachetest/suite.go index 80db90a..706328d 100644 --- a/internal/cache/cachetest/suite.go +++ b/internal/cache/cachetest/suite.go @@ -1,6 +1,7 @@ package cachetest import ( + "context" "io" "net/textproto" "os" @@ -46,6 +47,10 @@ func Suite(t *testing.T, newCache func(t *testing.T) cache.Cache) { t.Run("Headers", func(t *testing.T) { testHeaders(t, newCache(t)) }) + + t.Run("ContextCancellation", func(t *testing.T) { + testContextCancellation(t, newCache(t)) + }) } func testCreateAndOpen(t *testing.T, c cache.Cache) { @@ -233,3 +238,32 @@ func testHeaders(t *testing.T, c cache.Cache) { // Verify headers assert.Equal(t, headers, returnedHeaders) } + +func testContextCancellation(t *testing.T, c cache.Cache) { + defer c.Close() + ctx := t.Context() + + // Create a cancellable context + cancelledCtx, cancel := context.WithCancel(ctx) + + // Create an object with the cancellable context + key := cache.NewKey("test-cancelled") + writer, err := c.Create(cancelledCtx, key, textproto.MIMEHeader{}, time.Hour) + assert.NoError(t, err) + + // Write some data + _, err = writer.Write([]byte("test data")) + assert.NoError(t, err) + + // Cancel the context before closing + cancel() + + // Close should fail due to cancelled context + err = writer.Close() + assert.Error(t, err) + assert.Contains(t, err.Error(), "cancel") + + // Object should not be in cache + _, _, err = c.Open(ctx, key) + assert.IsError(t, err, os.ErrNotExist) +} diff --git a/internal/cache/disk.go b/internal/cache/disk.go index 70af8f2..0de9058 100644 --- a/internal/cache/disk.go +++ b/internal/cache/disk.go @@ -123,7 +123,7 @@ func (d *Disk) Size() int64 { return d.size.Load() } -func (d *Disk) Create(_ context.Context, key Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error) { +func (d *Disk) Create(ctx context.Context, key Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error) { if ttl > d.config.MaxTTL || ttl == 0 { ttl = d.config.MaxTTL } @@ -152,6 +152,7 @@ func (d *Disk) Create(_ context.Context, key Key, headers textproto.MIMEHeader, tempPath: tempPath, expiresAt: expiresAt, headers: headers, + ctx: ctx, }, nil } @@ -334,6 +335,7 @@ type diskWriter struct { expiresAt time.Time headers textproto.MIMEHeader size int64 + ctx context.Context } func (w *diskWriter) Write(p []byte) (int, error) { @@ -347,6 +349,12 @@ func (w *diskWriter) Close() error { return errors.Errorf("failed to close file: %w", err) } + // Check if context was cancelled + if err := w.ctx.Err(); err != nil { + // Clean up temp file and abort + return errors.Join(errors.Wrap(err, "create operation cancelled"), os.Remove(w.tempPath)) + } + if err := os.Rename(w.tempPath, w.path); err != nil { return errors.Errorf("failed to rename temp file: %w", err) } diff --git a/internal/cache/memory.go b/internal/cache/memory.go index 6c4cd0c..f6e6c6a 100644 --- a/internal/cache/memory.go +++ b/internal/cache/memory.go @@ -60,7 +60,7 @@ func (m *Memory) Open(_ context.Context, key Key) (io.ReadCloser, textproto.MIME return io.NopCloser(bytes.NewReader(entry.data)), entry.headers, nil } -func (m *Memory) Create(_ context.Context, key Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error) { +func (m *Memory) Create(ctx context.Context, key Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error) { if ttl == 0 { ttl = m.config.MaxTTL } @@ -71,6 +71,7 @@ func (m *Memory) Create(_ context.Context, key Key, headers textproto.MIMEHeader buf: &bytes.Buffer{}, expiresAt: time.Now().Add(ttl), headers: headers, + ctx: ctx, } return writer, nil @@ -104,6 +105,7 @@ type memoryWriter struct { expiresAt time.Time headers textproto.MIMEHeader closed bool + ctx context.Context } func (w *memoryWriter) Write(p []byte) (int, error) { @@ -119,6 +121,11 @@ func (w *memoryWriter) Close() error { } w.closed = true + // Check if context was cancelled + if err := w.ctx.Err(); err != nil { + return errors.Wrap(err, "create operation cancelled") + } + w.cache.mu.Lock() defer w.cache.mu.Unlock() diff --git a/internal/cache/remote.go b/internal/cache/remote.go index 46e43d7..36559a2 100644 --- a/internal/cache/remote.go +++ b/internal/cache/remote.go @@ -77,6 +77,7 @@ func (c *Remote) Create(ctx context.Context, key Key, headers textproto.MIMEHead wc := &writeCloser{ pw: pw, done: make(chan error, 1), + ctx: ctx, } go func() { @@ -133,6 +134,7 @@ func (c *Remote) Close() error { type writeCloser struct { pw *io.PipeWriter done chan error + ctx context.Context } func (wc *writeCloser) Write(p []byte) (int, error) { @@ -141,6 +143,9 @@ func (wc *writeCloser) Write(p []byte) (int, error) { } func (wc *writeCloser) Close() error { + if err := wc.ctx.Err(); err != nil { + return errors.Join(errors.Wrap(err, "create operation cancelled"), wc.pw.CloseWithError(err)) + } if err := wc.pw.Close(); err != nil { return errors.Wrap(err, "failed to close pipe writer") }