Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions internal/cache/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ type Cache interface {
// If "ttl" is zero, a maximum TTL MUST be used by the implementation.
//
// The file MUST not be available for read until completely written and closed.
//
// If the context is cancelled the object MUST not be made available in the cache.
Create(ctx context.Context, key Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error)
// Delete a file from the cache.
//
Expand Down
34 changes: 34 additions & 0 deletions internal/cache/cachetest/suite.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cachetest

import (
"context"
"io"
"net/textproto"
"os"
Expand Down Expand Up @@ -46,6 +47,10 @@ func Suite(t *testing.T, newCache func(t *testing.T) cache.Cache) {
t.Run("Headers", func(t *testing.T) {
testHeaders(t, newCache(t))
})

t.Run("ContextCancellation", func(t *testing.T) {
testContextCancellation(t, newCache(t))
})
}

func testCreateAndOpen(t *testing.T, c cache.Cache) {
Expand Down Expand Up @@ -233,3 +238,32 @@ func testHeaders(t *testing.T, c cache.Cache) {
// Verify headers
assert.Equal(t, headers, returnedHeaders)
}

func testContextCancellation(t *testing.T, c cache.Cache) {
defer c.Close()
ctx := t.Context()

// Create a cancellable context
cancelledCtx, cancel := context.WithCancel(ctx)

// Create an object with the cancellable context
key := cache.NewKey("test-cancelled")
writer, err := c.Create(cancelledCtx, key, textproto.MIMEHeader{}, time.Hour)
assert.NoError(t, err)

// Write some data
_, err = writer.Write([]byte("test data"))
assert.NoError(t, err)

// Cancel the context before closing
cancel()

// Close should fail due to cancelled context
err = writer.Close()
assert.Error(t, err)
assert.Contains(t, err.Error(), "cancel")

// Object should not be in cache
_, _, err = c.Open(ctx, key)
assert.IsError(t, err, os.ErrNotExist)
}
10 changes: 9 additions & 1 deletion internal/cache/disk.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func (d *Disk) Size() int64 {
return d.size.Load()
}

func (d *Disk) Create(_ context.Context, key Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error) {
func (d *Disk) Create(ctx context.Context, key Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error) {
if ttl > d.config.MaxTTL || ttl == 0 {
ttl = d.config.MaxTTL
}
Expand Down Expand Up @@ -152,6 +152,7 @@ func (d *Disk) Create(_ context.Context, key Key, headers textproto.MIMEHeader,
tempPath: tempPath,
expiresAt: expiresAt,
headers: headers,
ctx: ctx,
}, nil
}

Expand Down Expand Up @@ -334,6 +335,7 @@ type diskWriter struct {
expiresAt time.Time
headers textproto.MIMEHeader
size int64
ctx context.Context
}

func (w *diskWriter) Write(p []byte) (int, error) {
Expand All @@ -347,6 +349,12 @@ func (w *diskWriter) Close() error {
return errors.Errorf("failed to close file: %w", err)
}

// Check if context was cancelled
if err := w.ctx.Err(); err != nil {
// Clean up temp file and abort
return errors.Join(errors.Wrap(err, "create operation cancelled"), os.Remove(w.tempPath))
}

if err := os.Rename(w.tempPath, w.path); err != nil {
return errors.Errorf("failed to rename temp file: %w", err)
}
Expand Down
9 changes: 8 additions & 1 deletion internal/cache/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (m *Memory) Open(_ context.Context, key Key) (io.ReadCloser, textproto.MIME
return io.NopCloser(bytes.NewReader(entry.data)), entry.headers, nil
}

func (m *Memory) Create(_ context.Context, key Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error) {
func (m *Memory) Create(ctx context.Context, key Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error) {
if ttl == 0 {
ttl = m.config.MaxTTL
}
Expand All @@ -71,6 +71,7 @@ func (m *Memory) Create(_ context.Context, key Key, headers textproto.MIMEHeader
buf: &bytes.Buffer{},
expiresAt: time.Now().Add(ttl),
headers: headers,
ctx: ctx,
}

return writer, nil
Expand Down Expand Up @@ -104,6 +105,7 @@ type memoryWriter struct {
expiresAt time.Time
headers textproto.MIMEHeader
closed bool
ctx context.Context
}

func (w *memoryWriter) Write(p []byte) (int, error) {
Expand All @@ -119,6 +121,11 @@ func (w *memoryWriter) Close() error {
}
w.closed = true

// Check if context was cancelled
if err := w.ctx.Err(); err != nil {
return errors.Wrap(err, "create operation cancelled")
}

w.cache.mu.Lock()
defer w.cache.mu.Unlock()

Expand Down
5 changes: 5 additions & 0 deletions internal/cache/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ func (c *Remote) Create(ctx context.Context, key Key, headers textproto.MIMEHead
wc := &writeCloser{
pw: pw,
done: make(chan error, 1),
ctx: ctx,
}

go func() {
Expand Down Expand Up @@ -133,6 +134,7 @@ func (c *Remote) Close() error {
type writeCloser struct {
pw *io.PipeWriter
done chan error
ctx context.Context
}

func (wc *writeCloser) Write(p []byte) (int, error) {
Expand All @@ -141,6 +143,9 @@ func (wc *writeCloser) Write(p []byte) (int, error) {
}

func (wc *writeCloser) Close() error {
if err := wc.ctx.Err(); err != nil {
return errors.Join(errors.Wrap(err, "create operation cancelled"), wc.pw.CloseWithError(err))
}
if err := wc.pw.Close(); err != nil {
return errors.Wrap(err, "failed to close pipe writer")
}
Expand Down