Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions internal/cache/http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package cache

import (
"fmt"
"io"
"maps"
"net/http"
"net/textproto"
"os"

"github.com/alecthomas/errors"
)

type HTTPError struct {
status int
err error
}

func (h HTTPError) Error() string { return fmt.Sprintf("%d: %s", h.status, h.err) }
func (h HTTPError) Unwrap() error { return h.err }
func (h HTTPError) StatusCode() int { return h.status }

func HTTPErrorf(status int, format string, args ...any) error {
return HTTPError{
status: status,
err: errors.Errorf(format, args...),
}
}

// Fetch retrieves a response from cache or fetches from the request URL and caches it.
// The response is streamed without buffering. Returns HTTPError for semantic errors.
// The caller must close the response body.
func Fetch(client *http.Client, r *http.Request, c Cache) (*http.Response, error) {
url := r.URL.String()
key := NewKey(url)

cr, headers, err := c.Open(r.Context(), key)
if err == nil {
return &http.Response{
Status: "200 OK",
StatusCode: http.StatusOK,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: http.Header(headers),
Body: cr,
ContentLength: -1,
Request: r,
}, nil
}
if !errors.Is(err, os.ErrNotExist) {
return nil, HTTPErrorf(http.StatusInternalServerError, "failed to open cache: %w", err)
}

resp, err := client.Do(r) //nolint:bodyclose // Body is returned to caller
if err != nil {
return nil, HTTPErrorf(http.StatusBadGateway, "failed to fetch: %w", err)
}

if resp.StatusCode != http.StatusOK {
return resp, nil
}

responseHeaders := textproto.MIMEHeader(maps.Clone(resp.Header))
cw, err := c.Create(r.Context(), key, responseHeaders, 0)
if err != nil {
_ = resp.Body.Close()
return nil, HTTPErrorf(http.StatusInternalServerError, "failed to create cache entry: %w", err)
}

originalBody := resp.Body
pr, pw := io.Pipe()
go func() {
mw := io.MultiWriter(pw, cw)
_, copyErr := io.Copy(mw, originalBody)
closeErr := errors.Join(cw.Close(), originalBody.Close())
pw.CloseWithError(errors.Join(copyErr, closeErr))
}()

resp.Body = pr
return resp, nil
}
83 changes: 83 additions & 0 deletions internal/cache/http_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package cache_test

import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/alecthomas/assert/v2"

"github.com/block/sfptc/internal/cache"
)

func TestCachedFetch(t *testing.T) {
ctx := context.Background()
memCache, err := cache.NewMemory(ctx, cache.MemoryConfig{MaxTTL: time.Hour})
assert.NoError(t, err)
defer memCache.Close()

callCount := 0
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
callCount++
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("hello world"))
}))
defer backend.Close()

client := &http.Client{}

// First request - should hit backend
req1, err := http.NewRequestWithContext(ctx, http.MethodGet, backend.URL+"/test", nil)
assert.NoError(t, err)
resp1, err := cache.Fetch(client, req1, memCache)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp1.StatusCode)
assert.Equal(t, "text/plain", resp1.Header.Get("Content-Type"))
body1, err := io.ReadAll(resp1.Body)
assert.NoError(t, err)
assert.NoError(t, resp1.Body.Close())
assert.Equal(t, "hello world", string(body1))
assert.Equal(t, 1, callCount)

// Second request - should hit cache
req2, err := http.NewRequestWithContext(ctx, http.MethodGet, backend.URL+"/test", nil)
assert.NoError(t, err)
resp2, err := cache.Fetch(client, req2, memCache)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp2.StatusCode)
assert.Equal(t, "text/plain", resp2.Header.Get("Content-Type"))
body2, err := io.ReadAll(resp2.Body)
assert.NoError(t, err)
assert.NoError(t, resp2.Body.Close())
assert.Equal(t, "hello world", string(body2))
assert.Equal(t, 1, callCount, "should serve from cache")
}

func TestCachedFetchNonOKStatus(t *testing.T) {
ctx := context.Background()
memCache, err := cache.NewMemory(ctx, cache.MemoryConfig{MaxTTL: time.Hour})
assert.NoError(t, err)
defer memCache.Close()

backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusNotFound)
_, _ = w.Write([]byte("not found"))
}))
defer backend.Close()

client := &http.Client{}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, backend.URL+"/missing", nil)
assert.NoError(t, err)

resp, err := cache.Fetch(client, req, memCache)
assert.NoError(t, err)
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.NoError(t, resp.Body.Close())
assert.Equal(t, "not found", string(body))
}
34 changes: 16 additions & 18 deletions internal/cache/remote/client.go → internal/cache/remote.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package remote
package cache

import (
"context"
Expand All @@ -11,30 +11,28 @@ import (
"time"

"github.com/alecthomas/errors"

"github.com/block/sfptc/internal/cache"
)

// Client implements cache.Cache as a client for the remote cache server.
type Client struct {
// Remote implements Cache as a client for the remote cache server.
type Remote struct {
baseURL string
client *http.Client
}

var _ cache.Cache = (*Client)(nil)
var _ Cache = (*Remote)(nil)

// NewClient creates a new remote cache client.
func NewClient(baseURL string) *Client {
return &Client{
// NewRemote creates a new remote cache client.
func NewRemote(baseURL string) *Remote {
return &Remote{
baseURL: baseURL,
client: &http.Client{},
}
}

func (c *Client) String() string { return "remote:" + c.baseURL }
func (c *Remote) String() string { return "remote:" + c.baseURL }

// Open retrieves an object from the remote cache.
func (c *Client) Open(ctx context.Context, key cache.Key) (io.ReadCloser, textproto.MIMEHeader, error) {
// Open retrieves an object from the remote.
func (c *Remote) Open(ctx context.Context, key Key) (io.ReadCloser, textproto.MIMEHeader, error) {
url := fmt.Sprintf("%s/%s", c.baseURL, key.String())
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
Expand All @@ -55,13 +53,13 @@ func (c *Client) Open(ctx context.Context, key cache.Key) (io.ReadCloser, textpr
}

// Filter out HTTP transport headers
headers := cache.FilterTransportHeaders(textproto.MIMEHeader(resp.Header))
headers := FilterTransportHeaders(textproto.MIMEHeader(resp.Header))

return resp.Body, headers, nil
}

// Create stores a new object in the remote cache.
func (c *Client) Create(ctx context.Context, key cache.Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error) {
// Create stores a new object in the remote.
func (c *Remote) Create(ctx context.Context, key Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error) {
pr, pw := io.Pipe()

url := fmt.Sprintf("%s/%s", c.baseURL, key.String())
Expand Down Expand Up @@ -100,8 +98,8 @@ func (c *Client) Create(ctx context.Context, key cache.Key, headers textproto.MI
return wc, nil
}

// Delete removes an object from the remote cache.
func (c *Client) Delete(ctx context.Context, key cache.Key) error {
// Delete removes an object from the remote.
func (c *Remote) Delete(ctx context.Context, key Key) error {
url := fmt.Sprintf("%s/%s", c.baseURL, key.String())
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, url, nil)
if err != nil {
Expand All @@ -126,7 +124,7 @@ func (c *Client) Delete(ctx context.Context, key cache.Key) error {
}

// Close closes the client and releases resources.
func (c *Client) Close() error {
func (c *Remote) Close() error {
c.client.CloseIdleConnections()
return nil
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package remote_test
package cache_test

import (
"log/slog"
Expand All @@ -10,7 +10,6 @@ import (

"github.com/block/sfptc/internal/cache"
"github.com/block/sfptc/internal/cache/cachetest"
"github.com/block/sfptc/internal/cache/remote"
"github.com/block/sfptc/internal/logging"
"github.com/block/sfptc/internal/strategy"
)
Expand All @@ -30,7 +29,7 @@ func TestRemoteClient(t *testing.T) {
ts := httptest.NewServer(server)
t.Cleanup(ts.Close)

client := remote.NewClient(ts.URL)
client := cache.NewRemote(ts.URL)
return client
})
}
41 changes: 9 additions & 32 deletions internal/strategy/host.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ import (
"log/slog"
"maps"
"net/http"
"net/textproto"
"net/url"
"os"

"github.com/alecthomas/errors"

Expand Down Expand Up @@ -70,31 +68,19 @@ func (d *Host) ServeHTTP(w http.ResponseWriter, r *http.Request) {
targetURL.RawQuery = r.URL.RawQuery
fullURL := targetURL.String()

key := cache.NewKey(fullURL)

cr, headers, err := d.cache.Open(r.Context(), key)
if err == nil {
defer cr.Close()
maps.Copy(w.Header(), headers)
if _, err := io.Copy(w, cr); err != nil {
d.logger.Error("Failed to copy cached response", slog.String("error", err.Error()), slog.String("url", fullURL))
}
return
}

if !errors.Is(err, os.ErrNotExist) {
d.logger.Error("Failed to open cache", slog.String("error", err.Error()), slog.String("url", fullURL))
}

req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, fullURL, nil)
if err != nil {
d.httpError(w, http.StatusInternalServerError, err, "Failed to create request", slog.String("url", fullURL))
return
}

resp, err := d.client.Do(req)
resp, err := cache.Fetch(d.client, req, d.cache)
if err != nil {
d.httpError(w, http.StatusBadGateway, err, "Failed to fetch from target", slog.String("url", fullURL))
if httpErr, ok := errors.AsType[cache.HTTPError](err); ok {
d.httpError(w, httpErr.StatusCode(), httpErr, httpErr.Error(), slog.String("url", fullURL))
} else {
d.httpError(w, http.StatusInternalServerError, err, "Failed to fetch", slog.String("url", fullURL))
}
return
}
defer resp.Body.Close()
Expand All @@ -107,18 +93,9 @@ func (d *Host) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}

responseHeaders := textproto.MIMEHeader(maps.Clone(resp.Header))
cw, err := d.cache.Create(r.Context(), key, responseHeaders, 0)
if err != nil {
d.httpError(w, http.StatusInternalServerError, err, "Failed to create cache entry", slog.String("url", fullURL))
return
}

mw := io.MultiWriter(w, cw)
_, copyErr := io.Copy(mw, resp.Body)
closeErr := cw.Close()
if err := errors.Join(copyErr, closeErr); err != nil {
d.logger.Error("Failed to write to cache", slog.String("error", err.Error()), slog.String("url", fullURL))
maps.Copy(w.Header(), resp.Header)
if _, err := io.Copy(w, resp.Body); err != nil {
d.logger.Error("Failed to copy response", slog.String("error", err.Error()), slog.String("url", fullURL))
}
}

Expand Down