Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions internal/cache/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/sha256"
"encoding/hex"
"io"
"net/textproto"
"time"

"github.com/alecthomas/errors"
Expand Down Expand Up @@ -70,6 +71,21 @@ func (k *Key) MarshalText() ([]byte, error) {
return []byte(k.String()), nil
}

// FilterTransportHeaders returns a copy of the given headers with standard HTTP transport headers removed.
// These headers are typically added by HTTP clients/servers and should not be cached.
func FilterTransportHeaders(headers textproto.MIMEHeader) textproto.MIMEHeader {
filtered := make(textproto.MIMEHeader)
for key, values := range headers {
// Skip standard HTTP headers added by transport layer or that shouldn't be cached
if key == "Content-Length" || key == "Date" || key == "Accept-Encoding" ||
key == "User-Agent" || key == "Transfer-Encoding" || key == "Time-To-Live" {
continue
}
filtered[key] = values
}
return filtered
}

// A Cache knows how to retrieve, create and delete objects from a cache.
type Cache interface {
// String describes the Cache implementation.
Expand All @@ -78,13 +94,13 @@ type Cache interface {
//
// Expired files SHOULD not be returned.
// Must return os.ErrNotExist if the file does not exist.
Open(ctx context.Context, key Key) (io.ReadCloser, error)
Open(ctx context.Context, key Key) (io.ReadCloser, textproto.MIMEHeader, error)
// Create a new file in the cache.
//
// If "ttl" is zero, a maximum TTL MUST be used by the implementation.
//
// The file MUST not be available for read until completely written and closed.
Create(ctx context.Context, key Key, ttl time.Duration) (io.WriteCloser, error)
Create(ctx context.Context, key Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error)
// Delete a file from the cache.
//
// MUST be atomic.
Expand Down
71 changes: 56 additions & 15 deletions internal/cache/cachetest/suite.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cachetest

import (
"io"
"net/textproto"
"os"
"testing"
"time"
Expand Down Expand Up @@ -41,6 +42,10 @@ func Suite(t *testing.T, newCache func(t *testing.T) cache.Cache) {
t.Run("NotAvailableUntilClosed", func(t *testing.T) {
testNotAvailableUntilClosed(t, newCache(t))
})

t.Run("Headers", func(t *testing.T) {
testHeaders(t, newCache(t))
})
}

func testCreateAndOpen(t *testing.T, c cache.Cache) {
Expand All @@ -49,7 +54,7 @@ func testCreateAndOpen(t *testing.T, c cache.Cache) {

key := cache.NewKey("test-key")

writer, err := c.Create(ctx, key, time.Hour)
writer, err := c.Create(ctx, key, nil, time.Hour)
assert.NoError(t, err)

_, err = writer.Write([]byte("hello world"))
Expand All @@ -58,7 +63,7 @@ func testCreateAndOpen(t *testing.T, c cache.Cache) {
err = writer.Close()
assert.NoError(t, err)

reader, err := c.Open(ctx, key)
reader, _, err := c.Open(ctx, key)
assert.NoError(t, err)
defer reader.Close()

Expand All @@ -73,7 +78,7 @@ func testNotFound(t *testing.T, c cache.Cache) {

key := cache.NewKey("nonexistent")

_, err := c.Open(ctx, key)
_, _, err := c.Open(ctx, key)
assert.IsError(t, err, os.ErrNotExist)
}

Expand All @@ -83,7 +88,7 @@ func testExpiration(t *testing.T, c cache.Cache) {

key := cache.NewKey("test-key")

writer, err := c.Create(ctx, key, 10*time.Millisecond)
writer, err := c.Create(ctx, key, nil, 10*time.Millisecond)
assert.NoError(t, err)

_, err = writer.Write([]byte("test data"))
Expand All @@ -92,13 +97,13 @@ func testExpiration(t *testing.T, c cache.Cache) {
err = writer.Close()
assert.NoError(t, err)

reader, err := c.Open(ctx, key)
reader, _, err := c.Open(ctx, key)
assert.NoError(t, err)
assert.NoError(t, reader.Close())

time.Sleep(20 * time.Millisecond)

_, err = c.Open(ctx, key)
_, _, err = c.Open(ctx, key)
assert.IsError(t, err, os.ErrNotExist)
}

Expand All @@ -108,7 +113,7 @@ func testDefaultTTL(t *testing.T, c cache.Cache) {

key := cache.NewKey("test-key")

writer, err := c.Create(ctx, key, 0)
writer, err := c.Create(ctx, key, nil, 0)
assert.NoError(t, err)

_, err = writer.Write([]byte("test data"))
Expand All @@ -117,7 +122,7 @@ func testDefaultTTL(t *testing.T, c cache.Cache) {
err = writer.Close()
assert.NoError(t, err)

reader, err := c.Open(ctx, key)
reader, _, err := c.Open(ctx, key)
assert.NoError(t, err)
assert.NoError(t, reader.Close())
}
Expand All @@ -128,7 +133,7 @@ func testDelete(t *testing.T, c cache.Cache) {

key := cache.NewKey("test-key")

writer, err := c.Create(ctx, key, time.Hour)
writer, err := c.Create(ctx, key, nil, time.Hour)
assert.NoError(t, err)

_, err = writer.Write([]byte("test data"))
Expand All @@ -140,7 +145,7 @@ func testDelete(t *testing.T, c cache.Cache) {
err = c.Delete(ctx, key)
assert.NoError(t, err)

_, err = c.Open(ctx, key)
_, _, err = c.Open(ctx, key)
assert.IsError(t, err, os.ErrNotExist)
}

Expand All @@ -150,7 +155,7 @@ func testMultipleWrites(t *testing.T, c cache.Cache) {

key := cache.NewKey("test-key")

writer, err := c.Create(ctx, key, time.Hour)
writer, err := c.Create(ctx, key, nil, time.Hour)
assert.NoError(t, err)

_, err = writer.Write([]byte("hello "))
Expand All @@ -162,7 +167,7 @@ func testMultipleWrites(t *testing.T, c cache.Cache) {
err = writer.Close()
assert.NoError(t, err)

reader, err := c.Open(ctx, key)
reader, _, err := c.Open(ctx, key)
assert.NoError(t, err)
defer reader.Close()

Expand All @@ -177,18 +182,54 @@ func testNotAvailableUntilClosed(t *testing.T, c cache.Cache) {

key := cache.NewKey("test-key")

writer, err := c.Create(ctx, key, time.Hour)
writer, err := c.Create(ctx, key, nil, time.Hour)
assert.NoError(t, err)

_, err = writer.Write([]byte("test data"))
assert.NoError(t, err)

_, err = c.Open(ctx, key)
_, _, err = c.Open(ctx, key)
assert.IsError(t, err, os.ErrNotExist)

err = writer.Close()
assert.NoError(t, err)

_, err = c.Open(ctx, key)
_, _, err = c.Open(ctx, key)
assert.NoError(t, err)
}

func testHeaders(t *testing.T, c cache.Cache) {
defer c.Close()
ctx := t.Context()

key := cache.NewKey("test-key-with-headers")

// Create headers to store
headers := textproto.MIMEHeader{
"Content-Type": []string{"application/json"},
"Cache-Control": []string{"max-age=3600"},
"X-Custom-Field": []string{"custom-value"},
}

writer, err := c.Create(ctx, key, headers, time.Hour)
assert.NoError(t, err)

_, err = writer.Write([]byte("test data with headers"))
assert.NoError(t, err)

err = writer.Close()
assert.NoError(t, err)

// Open and verify headers are returned
reader, returnedHeaders, err := c.Open(ctx, key)
assert.NoError(t, err)
defer reader.Close()

// Verify the data
data, err := io.ReadAll(reader)
assert.NoError(t, err)
assert.Equal(t, "test data with headers", string(data))

// Verify headers
assert.Equal(t, headers, returnedHeaders)
}
27 changes: 15 additions & 12 deletions internal/cache/disk.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"io"
"io/fs"
"log/slog"
"net/textproto"
"os"
"path/filepath"
"sort"
Expand Down Expand Up @@ -122,7 +123,7 @@ func (d *Disk) Size() int64 {
return d.size.Load()
}

func (d *Disk) Create(_ context.Context, key Key, ttl time.Duration) (io.WriteCloser, error) {
func (d *Disk) Create(_ context.Context, key Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error) {
if ttl > d.config.MaxTTL || ttl == 0 {
ttl = d.config.MaxTTL
}
Expand Down Expand Up @@ -150,6 +151,7 @@ func (d *Disk) Create(_ context.Context, key Key, ttl time.Duration) (io.WriteCl
path: fullPath,
tempPath: tempPath,
expiresAt: expiresAt,
headers: headers,
}, nil
}

Expand All @@ -159,7 +161,7 @@ func (d *Disk) Delete(_ context.Context, key Key) error {

// Check if file is expired
expired := false
expiresAt, err := d.ttl.get(key)
expiresAt, _, err := d.ttl.get(key)
if err == nil && time.Now().After(expiresAt) {
expired = true
}
Expand All @@ -186,34 +188,34 @@ func (d *Disk) Delete(_ context.Context, key Key) error {
return nil
}

func (d *Disk) Open(ctx context.Context, key Key) (io.ReadCloser, error) {
func (d *Disk) Open(ctx context.Context, key Key) (io.ReadCloser, textproto.MIMEHeader, error) {
path := d.keyToPath(key)
fullPath := filepath.Join(d.config.Root, path)

f, err := os.Open(fullPath)
if err != nil {
return nil, errors.Errorf("failed to open file: %w", err)
return nil, nil, errors.Errorf("failed to open file: %w", err)
}

expiresAt, err := d.ttl.get(key)
expiresAt, headers, err := d.ttl.get(key)
if err != nil {
return nil, errors.Join(errors.Errorf("failed to get expiration time: %w", err), f.Close())
return nil, nil, errors.Join(errors.Errorf("failed to get metadata: %w", err), f.Close())
}

now := time.Now()
if now.After(expiresAt) {
return nil, errors.Join(fs.ErrNotExist, f.Close(), d.Delete(ctx, key))
return nil, nil, errors.Join(fs.ErrNotExist, f.Close(), d.Delete(ctx, key))
}

// Reset expiration time to implement LRU
ttl := min(expiresAt.Sub(now), d.config.MaxTTL)
newExpiresAt := now.Add(ttl)

if err := d.ttl.set(key, newExpiresAt); err != nil {
return nil, errors.Join(errors.Errorf("failed to update expiration time: %w", err), f.Close())
if err := d.ttl.set(key, newExpiresAt, headers); err != nil {
return nil, nil, errors.Join(errors.Errorf("failed to update expiration time: %w", err), f.Close())
}

return f, nil
return f, headers, nil
}

func (d *Disk) keyToPath(key Key) string {
Expand Down Expand Up @@ -330,6 +332,7 @@ type diskWriter struct {
path string
tempPath string
expiresAt time.Time
headers textproto.MIMEHeader
size int64
}

Expand All @@ -348,8 +351,8 @@ func (w *diskWriter) Close() error {
return errors.Errorf("failed to rename temp file: %w", err)
}

if err := w.disk.ttl.set(w.key, w.expiresAt); err != nil {
return errors.Join(errors.Errorf("failed to set expiration time: %w", err), os.Remove(w.path))
if err := w.disk.ttl.set(w.key, w.expiresAt, w.headers); err != nil {
return errors.Join(errors.Errorf("failed to set metadata: %w", err), os.Remove(w.path))
}

w.disk.size.Add(w.size)
Expand Down
15 changes: 10 additions & 5 deletions internal/cache/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"fmt"
"io"
"net/textproto"
"os"
"sync"
"time"
Expand All @@ -24,6 +25,7 @@ type MemoryConfig struct {
type memoryEntry struct {
data []byte
expiresAt time.Time
headers textproto.MIMEHeader
}

type Memory struct {
Expand All @@ -42,23 +44,23 @@ func NewMemory(_ context.Context, config MemoryConfig) (*Memory, error) {

func (m *Memory) String() string { return fmt.Sprintf("memory:%dMB", m.config.LimitMB) }

func (m *Memory) Open(_ context.Context, key Key) (io.ReadCloser, error) {
func (m *Memory) Open(_ context.Context, key Key) (io.ReadCloser, textproto.MIMEHeader, error) {
m.mu.RLock()
defer m.mu.RUnlock()

entry, exists := m.entries[key]
if !exists {
return nil, os.ErrNotExist
return nil, nil, os.ErrNotExist
}

if time.Now().After(entry.expiresAt) {
return nil, os.ErrNotExist
return nil, nil, os.ErrNotExist
}

return io.NopCloser(bytes.NewReader(entry.data)), nil
return io.NopCloser(bytes.NewReader(entry.data)), entry.headers, nil
}

func (m *Memory) Create(_ context.Context, key Key, ttl time.Duration) (io.WriteCloser, error) {
func (m *Memory) Create(_ context.Context, key Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error) {
if ttl == 0 {
ttl = m.config.MaxTTL
}
Expand All @@ -68,6 +70,7 @@ func (m *Memory) Create(_ context.Context, key Key, ttl time.Duration) (io.Write
key: key,
buf: &bytes.Buffer{},
expiresAt: time.Now().Add(ttl),
headers: headers,
}

return writer, nil
Expand Down Expand Up @@ -99,6 +102,7 @@ type memoryWriter struct {
key Key
buf *bytes.Buffer
expiresAt time.Time
headers textproto.MIMEHeader
closed bool
}

Expand Down Expand Up @@ -139,6 +143,7 @@ func (w *memoryWriter) Close() error {
w.cache.entries[w.key] = &memoryEntry{
data: w.buf.Bytes(),
expiresAt: w.expiresAt,
headers: w.headers,
}
w.cache.currentSize += newSize

Expand Down
Loading