From 3bfa1eb1af4346e4caca1e6ea4e9b5eccbaaa96e Mon Sep 17 00:00:00 2001 From: Alec Thomas Date: Thu, 8 Jan 2026 13:25:20 +1100 Subject: [PATCH] feat: add `cache.Fetch()` function for fetching+caching a HTTP response --- internal/cache/http.go | 82 ++++++++++++++++++ internal/cache/http_test.go | 83 +++++++++++++++++++ .../cache/{remote/client.go => remote.go} | 34 ++++---- .../{remote/client_test.go => remote_test.go} | 5 +- internal/strategy/host.go | 41 ++------- 5 files changed, 192 insertions(+), 53 deletions(-) create mode 100644 internal/cache/http.go create mode 100644 internal/cache/http_test.go rename internal/cache/{remote/client.go => remote.go} (75%) rename internal/cache/{remote/client_test.go => remote_test.go} (88%) diff --git a/internal/cache/http.go b/internal/cache/http.go new file mode 100644 index 0000000..9d2e1bb --- /dev/null +++ b/internal/cache/http.go @@ -0,0 +1,82 @@ +package cache + +import ( + "fmt" + "io" + "maps" + "net/http" + "net/textproto" + "os" + + "github.com/alecthomas/errors" +) + +type HTTPError struct { + status int + err error +} + +func (h HTTPError) Error() string { return fmt.Sprintf("%d: %s", h.status, h.err) } +func (h HTTPError) Unwrap() error { return h.err } +func (h HTTPError) StatusCode() int { return h.status } + +func HTTPErrorf(status int, format string, args ...any) error { + return HTTPError{ + status: status, + err: errors.Errorf(format, args...), + } +} + +// Fetch retrieves a response from cache or fetches from the request URL and caches it. +// The response is streamed without buffering. Returns HTTPError for semantic errors. +// The caller must close the response body. +func Fetch(client *http.Client, r *http.Request, c Cache) (*http.Response, error) { + url := r.URL.String() + key := NewKey(url) + + cr, headers, err := c.Open(r.Context(), key) + if err == nil { + return &http.Response{ + Status: "200 OK", + StatusCode: http.StatusOK, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header(headers), + Body: cr, + ContentLength: -1, + Request: r, + }, nil + } + if !errors.Is(err, os.ErrNotExist) { + return nil, HTTPErrorf(http.StatusInternalServerError, "failed to open cache: %w", err) + } + + resp, err := client.Do(r) //nolint:bodyclose // Body is returned to caller + if err != nil { + return nil, HTTPErrorf(http.StatusBadGateway, "failed to fetch: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return resp, nil + } + + responseHeaders := textproto.MIMEHeader(maps.Clone(resp.Header)) + cw, err := c.Create(r.Context(), key, responseHeaders, 0) + if err != nil { + _ = resp.Body.Close() + return nil, HTTPErrorf(http.StatusInternalServerError, "failed to create cache entry: %w", err) + } + + originalBody := resp.Body + pr, pw := io.Pipe() + go func() { + mw := io.MultiWriter(pw, cw) + _, copyErr := io.Copy(mw, originalBody) + closeErr := errors.Join(cw.Close(), originalBody.Close()) + pw.CloseWithError(errors.Join(copyErr, closeErr)) + }() + + resp.Body = pr + return resp, nil +} diff --git a/internal/cache/http_test.go b/internal/cache/http_test.go new file mode 100644 index 0000000..3a8d710 --- /dev/null +++ b/internal/cache/http_test.go @@ -0,0 +1,83 @@ +package cache_test + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/alecthomas/assert/v2" + + "github.com/block/sfptc/internal/cache" +) + +func TestCachedFetch(t *testing.T) { + ctx := context.Background() + memCache, err := cache.NewMemory(ctx, cache.MemoryConfig{MaxTTL: time.Hour}) + assert.NoError(t, err) + defer memCache.Close() + + callCount := 0 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + callCount++ + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("hello world")) + })) + defer backend.Close() + + client := &http.Client{} + + // First request - should hit backend + req1, err := http.NewRequestWithContext(ctx, http.MethodGet, backend.URL+"/test", nil) + assert.NoError(t, err) + resp1, err := cache.Fetch(client, req1, memCache) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp1.StatusCode) + assert.Equal(t, "text/plain", resp1.Header.Get("Content-Type")) + body1, err := io.ReadAll(resp1.Body) + assert.NoError(t, err) + assert.NoError(t, resp1.Body.Close()) + assert.Equal(t, "hello world", string(body1)) + assert.Equal(t, 1, callCount) + + // Second request - should hit cache + req2, err := http.NewRequestWithContext(ctx, http.MethodGet, backend.URL+"/test", nil) + assert.NoError(t, err) + resp2, err := cache.Fetch(client, req2, memCache) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp2.StatusCode) + assert.Equal(t, "text/plain", resp2.Header.Get("Content-Type")) + body2, err := io.ReadAll(resp2.Body) + assert.NoError(t, err) + assert.NoError(t, resp2.Body.Close()) + assert.Equal(t, "hello world", string(body2)) + assert.Equal(t, 1, callCount, "should serve from cache") +} + +func TestCachedFetchNonOKStatus(t *testing.T) { + ctx := context.Background() + memCache, err := cache.NewMemory(ctx, cache.MemoryConfig{MaxTTL: time.Hour}) + assert.NoError(t, err) + defer memCache.Close() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte("not found")) + })) + defer backend.Close() + + client := &http.Client{} + req, err := http.NewRequestWithContext(ctx, http.MethodGet, backend.URL+"/missing", nil) + assert.NoError(t, err) + + resp, err := cache.Fetch(client, req, memCache) + assert.NoError(t, err) + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.NoError(t, resp.Body.Close()) + assert.Equal(t, "not found", string(body)) +} diff --git a/internal/cache/remote/client.go b/internal/cache/remote.go similarity index 75% rename from internal/cache/remote/client.go rename to internal/cache/remote.go index 80688ff..039e7c6 100644 --- a/internal/cache/remote/client.go +++ b/internal/cache/remote.go @@ -1,4 +1,4 @@ -package remote +package cache import ( "context" @@ -11,30 +11,28 @@ import ( "time" "github.com/alecthomas/errors" - - "github.com/block/sfptc/internal/cache" ) -// Client implements cache.Cache as a client for the remote cache server. -type Client struct { +// Remote implements Cache as a client for the remote cache server. +type Remote struct { baseURL string client *http.Client } -var _ cache.Cache = (*Client)(nil) +var _ Cache = (*Remote)(nil) -// NewClient creates a new remote cache client. -func NewClient(baseURL string) *Client { - return &Client{ +// NewRemote creates a new remote cache client. +func NewRemote(baseURL string) *Remote { + return &Remote{ baseURL: baseURL, client: &http.Client{}, } } -func (c *Client) String() string { return "remote:" + c.baseURL } +func (c *Remote) String() string { return "remote:" + c.baseURL } -// Open retrieves an object from the remote cache. -func (c *Client) Open(ctx context.Context, key cache.Key) (io.ReadCloser, textproto.MIMEHeader, error) { +// Open retrieves an object from the remote. +func (c *Remote) Open(ctx context.Context, key Key) (io.ReadCloser, textproto.MIMEHeader, error) { url := fmt.Sprintf("%s/%s", c.baseURL, key.String()) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -55,13 +53,13 @@ func (c *Client) Open(ctx context.Context, key cache.Key) (io.ReadCloser, textpr } // Filter out HTTP transport headers - headers := cache.FilterTransportHeaders(textproto.MIMEHeader(resp.Header)) + headers := FilterTransportHeaders(textproto.MIMEHeader(resp.Header)) return resp.Body, headers, nil } -// Create stores a new object in the remote cache. -func (c *Client) Create(ctx context.Context, key cache.Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error) { +// Create stores a new object in the remote. +func (c *Remote) Create(ctx context.Context, key Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error) { pr, pw := io.Pipe() url := fmt.Sprintf("%s/%s", c.baseURL, key.String()) @@ -100,8 +98,8 @@ func (c *Client) Create(ctx context.Context, key cache.Key, headers textproto.MI return wc, nil } -// Delete removes an object from the remote cache. -func (c *Client) Delete(ctx context.Context, key cache.Key) error { +// Delete removes an object from the remote. +func (c *Remote) Delete(ctx context.Context, key Key) error { url := fmt.Sprintf("%s/%s", c.baseURL, key.String()) req, err := http.NewRequestWithContext(ctx, http.MethodDelete, url, nil) if err != nil { @@ -126,7 +124,7 @@ func (c *Client) Delete(ctx context.Context, key cache.Key) error { } // Close closes the client and releases resources. -func (c *Client) Close() error { +func (c *Remote) Close() error { c.client.CloseIdleConnections() return nil } diff --git a/internal/cache/remote/client_test.go b/internal/cache/remote_test.go similarity index 88% rename from internal/cache/remote/client_test.go rename to internal/cache/remote_test.go index 871c4b4..fe6b46e 100644 --- a/internal/cache/remote/client_test.go +++ b/internal/cache/remote_test.go @@ -1,4 +1,4 @@ -package remote_test +package cache_test import ( "log/slog" @@ -10,7 +10,6 @@ import ( "github.com/block/sfptc/internal/cache" "github.com/block/sfptc/internal/cache/cachetest" - "github.com/block/sfptc/internal/cache/remote" "github.com/block/sfptc/internal/logging" "github.com/block/sfptc/internal/strategy" ) @@ -30,7 +29,7 @@ func TestRemoteClient(t *testing.T) { ts := httptest.NewServer(server) t.Cleanup(ts.Close) - client := remote.NewClient(ts.URL) + client := cache.NewRemote(ts.URL) return client }) } diff --git a/internal/strategy/host.go b/internal/strategy/host.go index 86a1862..fc17f11 100644 --- a/internal/strategy/host.go +++ b/internal/strategy/host.go @@ -7,9 +7,7 @@ import ( "log/slog" "maps" "net/http" - "net/textproto" "net/url" - "os" "github.com/alecthomas/errors" @@ -70,31 +68,19 @@ func (d *Host) ServeHTTP(w http.ResponseWriter, r *http.Request) { targetURL.RawQuery = r.URL.RawQuery fullURL := targetURL.String() - key := cache.NewKey(fullURL) - - cr, headers, err := d.cache.Open(r.Context(), key) - if err == nil { - defer cr.Close() - maps.Copy(w.Header(), headers) - if _, err := io.Copy(w, cr); err != nil { - d.logger.Error("Failed to copy cached response", slog.String("error", err.Error()), slog.String("url", fullURL)) - } - return - } - - if !errors.Is(err, os.ErrNotExist) { - d.logger.Error("Failed to open cache", slog.String("error", err.Error()), slog.String("url", fullURL)) - } - req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, fullURL, nil) if err != nil { d.httpError(w, http.StatusInternalServerError, err, "Failed to create request", slog.String("url", fullURL)) return } - resp, err := d.client.Do(req) + resp, err := cache.Fetch(d.client, req, d.cache) if err != nil { - d.httpError(w, http.StatusBadGateway, err, "Failed to fetch from target", slog.String("url", fullURL)) + if httpErr, ok := errors.AsType[cache.HTTPError](err); ok { + d.httpError(w, httpErr.StatusCode(), httpErr, httpErr.Error(), slog.String("url", fullURL)) + } else { + d.httpError(w, http.StatusInternalServerError, err, "Failed to fetch", slog.String("url", fullURL)) + } return } defer resp.Body.Close() @@ -107,18 +93,9 @@ func (d *Host) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - responseHeaders := textproto.MIMEHeader(maps.Clone(resp.Header)) - cw, err := d.cache.Create(r.Context(), key, responseHeaders, 0) - if err != nil { - d.httpError(w, http.StatusInternalServerError, err, "Failed to create cache entry", slog.String("url", fullURL)) - return - } - - mw := io.MultiWriter(w, cw) - _, copyErr := io.Copy(mw, resp.Body) - closeErr := cw.Close() - if err := errors.Join(copyErr, closeErr); err != nil { - d.logger.Error("Failed to write to cache", slog.String("error", err.Error()), slog.String("url", fullURL)) + maps.Copy(w.Header(), resp.Header) + if _, err := io.Copy(w, resp.Body); err != nil { + d.logger.Error("Failed to copy response", slog.String("error", err.Error()), slog.String("url", fullURL)) } }