diff --git a/.golangci.yml b/.golangci.yml index 7f13bdd..a83a1ea 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -483,7 +483,6 @@ linters: path: "(internal/cache/disk\\.go|internal/strategy/git/spool\\.go)" - text: "G704:" linters: [gosec] - path: "(internal/cache/(http|remote)\\.go|internal/githubapp/tokens\\.go|internal/strategy/(git/proxy|github_releases|handler/handler)\\.go)" - text: "avoid package names that conflict with Go standard library" linters: [revive] path: "(internal/httputil/|internal/metrics/)" diff --git a/client/archive.go b/client/archive.go new file mode 100644 index 0000000..5cc7f52 --- /dev/null +++ b/client/archive.go @@ -0,0 +1,152 @@ +package client + +import ( + "bytes" + "context" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "runtime" + + "github.com/alecthomas/errors" +) + +// Archive writes a tar+zstd stream of the given paths to w. Each entry in +// includePaths is relative to baseDir and must exist. Exclude patterns use +// tar's --exclude syntax. threads controls zstd parallelism; 0 uses all CPU +// cores. +func Archive(ctx context.Context, w io.Writer, baseDir string, includePaths []string, excludePatterns []string, threads int) error { + if threads <= 0 { + threads = runtime.NumCPU() + } + + if len(includePaths) == 0 { + return errors.New("includePaths must not be empty") + } + + info, err := os.Stat(baseDir) + if err != nil { + return errors.Wrap(err, "failed to stat base directory") + } + if !info.IsDir() { + return errors.Errorf("not a directory: %s", baseDir) + } + for _, path := range includePaths { + if _, err := os.Stat(filepath.Join(baseDir, path)); err != nil { + return errors.Wrapf(err, "failed to stat include path %q", path) + } + } + + tarArgs := []string{"-cpf", "-", "-C", baseDir} + for _, pattern := range excludePatterns { + tarArgs = append(tarArgs, "--exclude", pattern) + } + tarArgs = append(tarArgs, "--") + tarArgs = append(tarArgs, includePaths...) + + return runTarZstdPipeline(ctx, tarArgs, threads, w) +} + +// Extract decompresses a zstd+tar stream from r into directory, preserving +// file permissions, ownership, and symlinks. threads controls zstd +// parallelism; 0 uses all CPU cores. +func Extract(ctx context.Context, r io.Reader, directory string, threads int) error { + if threads <= 0 { + threads = runtime.NumCPU() + } + + if err := os.MkdirAll(directory, 0o750); err != nil { + return errors.Wrap(err, "failed to create target directory") + } + + zstdCmd := exec.CommandContext(ctx, "zstd", "-dc", fmt.Sprintf("-T%d", threads)) //nolint:gosec + tarCmd := exec.CommandContext(ctx, "tar", "-xpf", "-", "-C", directory) + + pr, pw, err := os.Pipe() + if err != nil { + return errors.Wrap(err, "failed to create pipe") + } + + var zstdStderr, tarStderr bytes.Buffer + zstdCmd.Stdin = r + zstdCmd.Stdout = pw + zstdCmd.Stderr = &zstdStderr + + tarCmd.Stdin = pr + tarCmd.Stderr = &tarStderr + + if err := zstdCmd.Start(); err != nil { + pw.Close() //nolint:errcheck,gosec + pr.Close() //nolint:errcheck,gosec + return errors.Wrap(err, "failed to start zstd") + } + pw.Close() //nolint:errcheck,gosec + + if err := tarCmd.Start(); err != nil { + pr.Close() //nolint:errcheck,gosec + return errors.Join(errors.Wrap(err, "failed to start tar"), zstdCmd.Wait()) + } + pr.Close() //nolint:errcheck,gosec + + zstdErr := zstdCmd.Wait() + tarErr := tarCmd.Wait() + + var errs []error + if zstdErr != nil { + errs = append(errs, errors.Errorf("zstd failed: %w: %s", zstdErr, zstdStderr.String())) + } + if tarErr != nil { + errs = append(errs, errors.Errorf("tar failed: %w: %s", tarErr, tarStderr.String())) + } + return errors.Join(errs...) +} + +// runTarZstdPipeline runs tar piped through zstd, writing compressed output +// to w. The caller is responsible for closing w after this returns. +func runTarZstdPipeline(ctx context.Context, tarArgs []string, threads int, w io.Writer) error { + tarCmd := exec.CommandContext(ctx, "tar", tarArgs...) + zstdCmd := exec.CommandContext(ctx, "zstd", "-c", fmt.Sprintf("-T%d", threads)) //nolint:gosec + + // Manual pipe so we can close both ends in the parent after starting + // children. Prevents deadlock if zstd exits while tar is still writing: + // closing the read end ensures tar receives SIGPIPE instead of blocking. + pr, pw, err := os.Pipe() + if err != nil { + return errors.Wrap(err, "failed to create pipe") + } + + var tarStderr, zstdStderr bytes.Buffer + tarCmd.Stdout = pw + tarCmd.Stderr = &tarStderr + + zstdCmd.Stdin = pr + zstdCmd.Stdout = w + zstdCmd.Stderr = &zstdStderr + + if err := tarCmd.Start(); err != nil { + pw.Close() //nolint:errcheck,gosec + pr.Close() //nolint:errcheck,gosec + return errors.Wrap(err, "failed to start tar") + } + pw.Close() //nolint:errcheck,gosec + + if err := zstdCmd.Start(); err != nil { + pr.Close() //nolint:errcheck,gosec + return errors.Join(errors.Wrap(err, "failed to start zstd"), tarCmd.Wait()) + } + pr.Close() //nolint:errcheck,gosec + + tarErr := tarCmd.Wait() + zstdErr := zstdCmd.Wait() + + var errs []error + if tarErr != nil { + errs = append(errs, errors.Errorf("tar failed: %w: %s", tarErr, tarStderr.String())) + } + if zstdErr != nil { + errs = append(errs, errors.Errorf("zstd failed: %w: %s", zstdErr, zstdStderr.String())) + } + return errors.Join(errs...) +} diff --git a/client/client.go b/client/client.go new file mode 100644 index 0000000..147d15e --- /dev/null +++ b/client/client.go @@ -0,0 +1,339 @@ +// Package client provides a standalone HTTP client for the Cachew cache server. +package client + +import ( + "context" + "encoding/json" + "fmt" + "io" + "maps" + "net/http" + "os" + "time" + + "github.com/alecthomas/errors" +) + +// transportHeaders are headers added by the HTTP transport layer that should +// not be surfaced as cached-object metadata on responses. +var transportHeaders = []string{ //nolint:gochecknoglobals + "Content-Length", + "Date", + "Accept-Encoding", + "User-Agent", + "Transfer-Encoding", + "Time-To-Live", +} + +// HeaderFunc returns headers to attach to each outgoing request. +type HeaderFunc func() http.Header + +// NewHTTPClient creates an *http.Client that attaches headerFunc headers +// to every outgoing request. Useful for callers that need to talk to +// non-API endpoints (e.g. /git/) with the same auth as the cache client. +func NewHTTPClient(headerFunc HeaderFunc) *http.Client { + transport := http.DefaultTransport.(*http.Transport).Clone() //nolint:errcheck + transport.MaxIdleConns = 100 + transport.MaxIdleConnsPerHost = 100 + + var rt http.RoundTripper = transport + if headerFunc != nil { + rt = &headerTransport{base: transport, headerFunc: headerFunc} + } + return &http.Client{Transport: rt} +} + +type headerTransport struct { + base http.RoundTripper + headerFunc HeaderFunc +} + +func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + for key, values := range t.headerFunc() { + for _, value := range values { + req.Header.Add(key, value) + } + } + return t.base.RoundTrip(req) //nolint:wrapcheck +} + +// Client is an HTTP client for a Cachew cache server. Its method set mirrors +// the cache.Cache interface, so it can be used as the transport for a remote +// cache backend. +type Client struct { + baseURL string + http *http.Client + namespace Namespace +} + +// New creates a Client against the given base URL (e.g. "http://localhost:8080"). +// If headerFunc is non-nil, its returned headers are added to every outgoing +// request. +func New(baseURL string, headerFunc HeaderFunc) *Client { + return &Client{ + baseURL: baseURL + "/api/v1", + http: NewHTTPClient(headerFunc), + } +} + +// NewWithHTTPClient creates a Client against baseURL using the supplied +// *http.Client. Callers are responsible for configuring authentication on +// the supplied client (e.g. via a custom RoundTripper). +func NewWithHTTPClient(baseURL string, httpClient *http.Client) *Client { + return &Client{ + baseURL: baseURL + "/api/v1", + http: httpClient, + } +} + +// HTTP returns the underlying HTTP client, for callers needing to talk to +// non-API endpoints with the same auth configuration. +func (c *Client) HTTP() *http.Client { return c.http } + +// BaseURL returns the /api/v1 base URL this client targets. +func (c *Client) BaseURL() string { return c.baseURL } + +// String describes the client. +func (c *Client) String() string { return "remote:" + c.baseURL } + +// Namespace returns a derived client that targets the given namespace. +func (c *Client) Namespace(namespace Namespace) *Client { + return &Client{ + baseURL: c.baseURL, + http: c.http, + namespace: namespace, + } +} + +func (c *Client) resolvedNamespace() Namespace { + if c.namespace == "" { + return DefaultNamespace + } + return c.namespace +} + +func (c *Client) objectURL(key Key) string { + return fmt.Sprintf("%s/object/%s/%s", c.baseURL, c.resolvedNamespace(), key.String()) +} + +// Open retrieves an object from the cache server. +func (c *Client) Open(ctx context.Context, key Key) (io.ReadCloser, http.Header, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.objectURL(key), nil) + if err != nil { + return nil, nil, errors.Wrap(err, "failed to create request") + } + + resp, err := c.http.Do(req) + if err != nil { + return nil, nil, errors.Wrap(err, "failed to execute request") + } + + if resp.StatusCode == http.StatusNotFound { + _, _ = io.Copy(io.Discard, resp.Body) //nolint:errcheck,gosec + return nil, nil, errors.Join(os.ErrNotExist, resp.Body.Close()) + } + + if resp.StatusCode != http.StatusOK { + _, _ = io.Copy(io.Discard, resp.Body) //nolint:errcheck,gosec + return nil, nil, errors.Join(errors.Errorf("unexpected status code: %d", resp.StatusCode), resp.Body.Close()) + } + + return resp.Body, filterHeaders(resp.Header, transportHeaders...), nil +} + +// Stat retrieves headers for an object from the cache server. +func (c *Client) Stat(ctx context.Context, key Key) (http.Header, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodHead, c.objectURL(key), nil) + if err != nil { + return nil, errors.Wrap(err, "failed to create request") + } + + resp, err := c.http.Do(req) + if err != nil { + return nil, errors.Wrap(err, "failed to execute request") + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNotFound { + return nil, os.ErrNotExist + } + + if resp.StatusCode != http.StatusOK { + return nil, errors.Errorf("unexpected status code: %d", resp.StatusCode) + } + + return filterHeaders(resp.Header, transportHeaders...), nil +} + +// Create stores a new object in the cache server. The returned io.WriteCloser +// must be closed to complete the upload; if the context is cancelled before +// Close returns, the object is not made available in the cache. +func (c *Client) Create(ctx context.Context, key Key, headers http.Header, ttl time.Duration) (io.WriteCloser, error) { + pr, pw := io.Pipe() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.objectURL(key), pr) + if err != nil { + return nil, errors.Join(errors.Wrap(err, "failed to create request"), pr.Close(), pw.Close()) + } + + maps.Copy(req.Header, headers) + + if ttl > 0 { + req.Header.Set("Time-To-Live", ttl.String()) + } + + wc := &writeCloser{ + pw: pw, + done: make(chan error, 1), + ctx: ctx, + } + + go func() { + resp, err := c.http.Do(req) + if err != nil { + wc.done <- errors.Wrap(err, "failed to execute request") + return + } + _, _ = io.Copy(io.Discard, resp.Body) //nolint:errcheck,gosec + _ = resp.Body.Close() //nolint:gosec + + if resp.StatusCode != http.StatusOK { + wc.done <- errors.Errorf("unexpected status code: %d", resp.StatusCode) + return + } + + wc.done <- nil + }() + + return wc, nil +} + +// Delete removes an object from the cache server. +func (c *Client) Delete(ctx context.Context, key Key) error { + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.objectURL(key), nil) + if err != nil { + return errors.Wrap(err, "failed to create request") + } + + resp, err := c.http.Do(req) + if err != nil { + return errors.Wrap(err, "failed to execute request") + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNotFound { + return os.ErrNotExist + } + + if resp.StatusCode != http.StatusOK { + return errors.Errorf("unexpected status code: %d", resp.StatusCode) + } + + return nil +} + +// Close releases resources held by the client. +func (c *Client) Close() error { + c.http.CloseIdleConnections() + return nil +} + +// Stats retrieves cache statistics from the server. +func (c *Client) Stats(ctx context.Context) (Stats, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/stats", nil) + if err != nil { + return Stats{}, errors.Wrap(err, "failed to create request") + } + + resp, err := c.http.Do(req) + if err != nil { + return Stats{}, errors.Wrap(err, "failed to execute request") + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNotImplemented { + return Stats{}, ErrStatsUnavailable + } + + if resp.StatusCode != http.StatusOK { + return Stats{}, errors.Errorf("unexpected status code: %d", resp.StatusCode) + } + + var stats Stats + if err := json.NewDecoder(resp.Body).Decode(&stats); err != nil { + return Stats{}, errors.Wrap(err, "failed to decode stats response") + } + + return stats, nil +} + +// ListNamespaces requests the namespace list from the server. +func (c *Client) ListNamespaces(ctx context.Context) ([]string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/namespaces", nil) + if err != nil { + return nil, errors.WithStack(err) + } + + resp, err := c.http.Do(req) + if err != nil { + return nil, errors.WithStack(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) //nolint:errcheck + return nil, errors.Errorf("unexpected status %d: %s", resp.StatusCode, body) + } + + var namespaces []string + if err := json.NewDecoder(resp.Body).Decode(&namespaces); err != nil { + return nil, errors.WithStack(err) + } + + return namespaces, nil +} + +// filterHeaders returns a copy of headers with the specified keys removed. +func filterHeaders(headers http.Header, skip ...string) http.Header { + skipSet := make(map[string]bool, len(skip)) + for _, s := range skip { + skipSet[http.CanonicalHeaderKey(s)] = true + } + filtered := make(http.Header, len(headers)) + for key, values := range headers { + if skipSet[http.CanonicalHeaderKey(key)] { + continue + } + filtered[key] = values + } + return filtered +} + +// writeCloser wraps a pipe writer and waits for the HTTP request to complete. +type writeCloser struct { + pw *io.PipeWriter + done chan error + ctx context.Context +} + +func (wc *writeCloser) Write(p []byte) (int, error) { + n, err := wc.pw.Write(p) + return n, errors.WithStack(err) +} + +func (wc *writeCloser) Close() error { + if err := wc.ctx.Err(); err != nil { + _ = wc.pw.CloseWithError(err) + <-wc.done + return errors.Wrap(err, "create operation cancelled") + } + if err := wc.pw.Close(); err != nil { + <-wc.done + return errors.Wrap(err, "failed to close pipe writer") + } + err := <-wc.done + if err != nil { + return errors.Wrap(err, "request failed") + } + return nil +} diff --git a/client/client_test.go b/client/client_test.go new file mode 100644 index 0000000..322f5c0 --- /dev/null +++ b/client/client_test.go @@ -0,0 +1,308 @@ +package client_test + +import ( + "bytes" + "encoding/json" + "io" + "maps" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "slices" + "sort" + "strings" + "sync" + "testing" + + "github.com/alecthomas/assert/v2" + + "github.com/block/cachew/client" +) + +// fakeServer is a minimal /api/v1 backend for exercising the client over HTTP. +type fakeServer struct { + mu sync.Mutex + objects map[string]fakeObject // keyed by "namespace/hex-key" + stats *client.Stats // nil signals ErrStatsUnavailable +} + +type fakeObject struct { + body []byte + headers http.Header +} + +func newFakeServer(stats *client.Stats) *httptest.Server { + fs := &fakeServer{objects: make(map[string]fakeObject), stats: stats} + mux := http.NewServeMux() + mux.HandleFunc("GET /api/v1/object/{namespace}/{key}", fs.get) + mux.HandleFunc("HEAD /api/v1/object/{namespace}/{key}", fs.stat) + mux.HandleFunc("POST /api/v1/object/{namespace}/{key}", fs.put) + mux.HandleFunc("DELETE /api/v1/object/{namespace}/{key}", fs.delete) + mux.HandleFunc("GET /api/v1/namespaces", fs.namespaces) + mux.HandleFunc("GET /api/v1/stats", fs.getStats) + return httptest.NewServer(mux) +} + +func (fs *fakeServer) key(r *http.Request) string { + return r.PathValue("namespace") + "/" + r.PathValue("key") +} + +func (fs *fakeServer) get(w http.ResponseWriter, r *http.Request) { + fs.mu.Lock() + obj, ok := fs.objects[fs.key(r)] + fs.mu.Unlock() + if !ok { + http.NotFound(w, r) + return + } + maps.Copy(w.Header(), obj.headers) + w.WriteHeader(http.StatusOK) + w.Write(obj.body) //nolint:errcheck +} + +func (fs *fakeServer) stat(w http.ResponseWriter, r *http.Request) { + fs.mu.Lock() + obj, ok := fs.objects[fs.key(r)] + fs.mu.Unlock() + if !ok { + http.NotFound(w, r) + return + } + maps.Copy(w.Header(), obj.headers) + w.WriteHeader(http.StatusOK) +} + +func (fs *fakeServer) put(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + headers := make(http.Header) + for k, v := range r.Header { + if strings.EqualFold(k, "Content-Length") || strings.EqualFold(k, "User-Agent") || + strings.EqualFold(k, "Accept-Encoding") { + continue + } + headers[k] = v + } + fs.mu.Lock() + fs.objects[fs.key(r)] = fakeObject{body: body, headers: headers} + fs.mu.Unlock() + w.WriteHeader(http.StatusOK) +} + +func (fs *fakeServer) delete(w http.ResponseWriter, r *http.Request) { + fs.mu.Lock() + _, ok := fs.objects[fs.key(r)] + if ok { + delete(fs.objects, fs.key(r)) + } + fs.mu.Unlock() + if !ok { + http.NotFound(w, r) + return + } + w.WriteHeader(http.StatusOK) +} + +func (fs *fakeServer) namespaces(w http.ResponseWriter, _ *http.Request) { + fs.mu.Lock() + seen := make(map[string]struct{}) + for k := range fs.objects { + ns := strings.SplitN(k, "/", 2)[0] + seen[ns] = struct{}{} + } + fs.mu.Unlock() + out := make([]string, 0, len(seen)) + for ns := range seen { + out = append(out, ns) + } + sort.Strings(out) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(out) //nolint:errcheck +} + +func (fs *fakeServer) getStats(w http.ResponseWriter, _ *http.Request) { + if fs.stats == nil { + http.Error(w, "stats unavailable", http.StatusNotImplemented) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(fs.stats) //nolint:errcheck +} + +func TestObjectRoundTrip(t *testing.T) { + srv := newFakeServer(nil) + defer srv.Close() + + c := client.New(srv.URL, nil).Namespace("test") + defer c.Close() + + ctx := t.Context() + key := client.NewKey("hello") + payload := []byte("hello world") + + wc, err := c.Create(ctx, key, http.Header{"Content-Type": {"text/plain"}}, 0) + assert.NoError(t, err) + _, err = wc.Write(payload) + assert.NoError(t, err) + assert.NoError(t, wc.Close()) + + headers, err := c.Stat(ctx, key) + assert.NoError(t, err) + assert.Equal(t, "text/plain", headers.Get("Content-Type")) + + rc, headers, err := c.Open(ctx, key) + assert.NoError(t, err) + got, err := io.ReadAll(rc) + assert.NoError(t, err) + assert.NoError(t, rc.Close()) + assert.Equal(t, payload, got) + assert.Equal(t, "text/plain", headers.Get("Content-Type")) + + assert.NoError(t, c.Delete(ctx, key)) + _, err = c.Stat(ctx, key) + assert.Error(t, err) + assert.True(t, isNotExist(err)) +} + +func isNotExist(err error) bool { return err != nil && os.IsNotExist(err) } + +func TestHeaderFuncAppliesAuth(t *testing.T) { + var seenAuth string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + seenAuth = r.Header.Get("Authorization") + http.NotFound(w, r) + })) + defer srv.Close() + + c := client.New(srv.URL, func() http.Header { + return http.Header{"Authorization": {"Bearer token"}} + }) + defer c.Close() + + _, err := c.Stat(t.Context(), client.NewKey("missing")) + assert.Error(t, err) + assert.Equal(t, "Bearer token", seenAuth) +} + +func TestListNamespaces(t *testing.T) { + srv := newFakeServer(nil) + defer srv.Close() + + c := client.New(srv.URL, nil) + defer c.Close() + ctx := t.Context() + + for _, ns := range []string{"alpha", "beta"} { + nsClient := c.Namespace(client.Namespace(ns)) + wc, err := nsClient.Create(ctx, client.NewKey("x"), nil, 0) + assert.NoError(t, err) + _, err = wc.Write([]byte("x")) + assert.NoError(t, err) + assert.NoError(t, wc.Close()) + } + + namespaces, err := c.ListNamespaces(ctx) + assert.NoError(t, err) + assert.Equal(t, []string{"alpha", "beta"}, namespaces) +} + +func TestStatsUnavailable(t *testing.T) { + srv := newFakeServer(nil) + defer srv.Close() + + c := client.New(srv.URL, nil) + defer c.Close() + + _, err := c.Stats(t.Context()) + assert.Equal(t, client.ErrStatsUnavailable, err) +} + +func TestStatsReturned(t *testing.T) { + want := client.Stats{Objects: 3, Size: 1024, Capacity: 4096} + srv := newFakeServer(&want) + defer srv.Close() + + c := client.New(srv.URL, nil) + defer c.Close() + + got, err := c.Stats(t.Context()) + assert.NoError(t, err) + assert.Equal(t, want, got) +} + +func TestSnapshotRoundTrip(t *testing.T) { + srv := newFakeServer(nil) + defer srv.Close() + + c := client.New(srv.URL, nil).Namespace("snap") + defer c.Close() + ctx := t.Context() + + src := t.TempDir() + assert.NoError(t, os.MkdirAll(filepath.Join(src, "sub"), 0o755)) + assert.NoError(t, os.WriteFile(filepath.Join(src, "a.txt"), []byte("alpha"), 0o644)) + assert.NoError(t, os.WriteFile(filepath.Join(src, "sub", "b.txt"), []byte("bravo"), 0o644)) + + key := client.NewKey("snapshot") + assert.NoError(t, c.Snapshot(ctx, key, src, client.SnapshotOptions{})) + + dst := filepath.Join(t.TempDir(), "out") + assert.NoError(t, c.Restore(ctx, key, dst, client.RestoreOptions{})) + + a, err := os.ReadFile(filepath.Join(dst, "a.txt")) + assert.NoError(t, err) + assert.Equal(t, "alpha", string(a)) + + b, err := os.ReadFile(filepath.Join(dst, "sub", "b.txt")) + assert.NoError(t, err) + assert.Equal(t, "bravo", string(b)) +} + +func TestArchiveExtract(t *testing.T) { + src := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(src, "x.txt"), []byte("x"), 0o644)) + assert.NoError(t, os.WriteFile(filepath.Join(src, "y.log"), []byte("y"), 0o644)) + + var buf bytes.Buffer + assert.NoError(t, client.Archive(t.Context(), &buf, src, []string{"."}, []string{"*.log"}, 0)) + + dst := filepath.Join(t.TempDir(), "out") + assert.NoError(t, client.Extract(t.Context(), &buf, dst, 0)) + + entries, err := os.ReadDir(dst) + assert.NoError(t, err) + names := make([]string, 0, len(entries)) + for _, e := range entries { + names = append(names, e.Name()) + } + assert.True(t, slices.Contains(names, "x.txt")) + assert.False(t, slices.Contains(names, "y.log")) +} + +func TestParseKey(t *testing.T) { + tests := []struct { + name string + input string + }{ + {name: "RawString", input: "hello"}, + {name: "HexString", input: keyHex(client.NewKey("hello"))}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := client.ParseKey(tt.input) + assert.NoError(t, err) + assert.Equal(t, client.NewKey("hello"), got) + }) + } +} + +func keyHex(k client.Key) string { return (&k).String() } + +func TestParseNamespaceInvalid(t *testing.T) { + _, err := client.ParseNamespace("_bad") + assert.Error(t, err) +} diff --git a/client/key.go b/client/key.go new file mode 100644 index 0000000..9ce6ebc --- /dev/null +++ b/client/key.go @@ -0,0 +1,36 @@ +package client + +import ( + "crypto/sha256" + "encoding/hex" +) + +// Key represents a unique identifier for a cached object. +type Key [32]byte + +// ParseKey from its hex-encoded string form. +func ParseKey(key string) (Key, error) { + var k Key + return k, k.UnmarshalText([]byte(key)) +} + +// NewKey returns the SHA256 of s. +func NewKey(s string) Key { return Key(sha256.Sum256([]byte(s))) } + +func (k *Key) String() string { return hex.EncodeToString(k[:]) } + +func (k *Key) UnmarshalText(text []byte) error { + if len(text) == 64 { + bytes, err := hex.DecodeString(string(text)) + if err == nil && len(bytes) == len(*k) { + copy(k[:], bytes) + return nil + } + } + *k = NewKey(string(text)) + return nil +} + +func (k *Key) MarshalText() ([]byte, error) { + return []byte(k.String()), nil +} diff --git a/client/namespace.go b/client/namespace.go new file mode 100644 index 0000000..ade5036 --- /dev/null +++ b/client/namespace.go @@ -0,0 +1,44 @@ +package client + +import ( + "regexp" + + "github.com/alecthomas/errors" +) + +// DefaultNamespace is used when a namespace is not explicitly specified. +const DefaultNamespace Namespace = "default" + +var namespaceRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]*$`) + +// Namespace identifies a logical partition within a cache or metadata store. +// Valid names start with an alphanumeric character and contain only +// alphanumerics, hyphens, and underscores. +type Namespace string + +// ValidateNamespace checks that a namespace name is valid. +func ValidateNamespace(name string) error { + if !namespaceRe.MatchString(name) { + return errors.Errorf("invalid namespace %q: must match %s", name, namespaceRe) + } + return nil +} + +// ParseNamespace validates and returns a Namespace from a plain string. +func ParseNamespace(name string) (Namespace, error) { + if err := ValidateNamespace(name); err != nil { + return "", err + } + return Namespace(name), nil +} + +func (n *Namespace) String() string { return string(*n) } + +// UnmarshalText implements encoding.TextUnmarshaler with validation. +func (n *Namespace) UnmarshalText(text []byte) error { + if err := ValidateNamespace(string(text)); err != nil { + return err + } + *n = Namespace(text) + return nil +} diff --git a/client/snapshot.go b/client/snapshot.go new file mode 100644 index 0000000..6d24189 --- /dev/null +++ b/client/snapshot.go @@ -0,0 +1,71 @@ +package client + +import ( + "context" + "fmt" + "net/http" + "path/filepath" + "time" + + "github.com/alecthomas/errors" +) + +// SnapshotOptions control how an archive is created and uploaded. +type SnapshotOptions struct { + // TTL for the uploaded object. Zero uses the server default. + TTL time.Duration + // Exclude patterns (tar --exclude syntax). + Exclude []string + // ZstdThreads controls zstd parallelism; 0 uses all CPU cores. + ZstdThreads int + // ExtraHeaders are merged into the upload headers alongside Content-Type + // and Content-Disposition. + ExtraHeaders http.Header +} + +// RestoreOptions control how an archive is downloaded and extracted. +type RestoreOptions struct { + // ZstdThreads controls zstd parallelism; 0 uses all CPU cores. + ZstdThreads int +} + +// Snapshot archives a directory and uploads the tar+zstd stream under the +// given key. +func (c *Client) Snapshot(ctx context.Context, key Key, directory string, opts SnapshotOptions) error { + return c.SnapshotPaths(ctx, key, directory, filepath.Base(directory), []string{"."}, opts) +} + +// SnapshotPaths archives named paths within baseDir and uploads the tar+zstd +// stream under the given key. archiveName is used to set the upload's +// Content-Disposition filename. +func (c *Client) SnapshotPaths(ctx context.Context, key Key, baseDir, archiveName string, includePaths []string, opts SnapshotOptions) error { + headers := make(http.Header) + headers.Set("Content-Type", "application/zstd") + headers.Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", archiveName+".tar.zst")) + for k, values := range opts.ExtraHeaders { + for _, v := range values { + headers.Set(k, v) + } + } + + wc, err := c.Create(ctx, key, headers, opts.TTL) + if err != nil { + return errors.Wrap(err, "failed to create object") + } + + if err := Archive(ctx, wc, baseDir, includePaths, opts.Exclude, opts.ZstdThreads); err != nil { + return errors.Join(err, wc.Close()) + } + return errors.Wrap(wc.Close(), "failed to close writer") +} + +// Restore downloads an archive by key and extracts it into directory. +func (c *Client) Restore(ctx context.Context, key Key, directory string, opts RestoreOptions) error { + rc, _, err := c.Open(ctx, key) + if err != nil { + return errors.Wrap(err, "failed to open object") + } + defer rc.Close() + + return errors.WithStack(Extract(ctx, rc, directory, opts.ZstdThreads)) +} diff --git a/client/stats.go b/client/stats.go new file mode 100644 index 0000000..7dc844a --- /dev/null +++ b/client/stats.go @@ -0,0 +1,16 @@ +package client + +import "github.com/alecthomas/errors" + +// ErrStatsUnavailable is returned when a cache backend cannot provide statistics. +var ErrStatsUnavailable = errors.New("stats unavailable") + +// Stats contains health and usage statistics for a cache. +type Stats struct { + // Objects is the number of objects currently in the cache. + Objects int64 `json:"objects"` + // Size is the total size of all objects in the cache in bytes. + Size int64 `json:"size"` + // Capacity is the maximum size of the cache in bytes (0 if unlimited). + Capacity int64 `json:"capacity"` +} diff --git a/cmd/cachew/main.go b/cmd/cachew/main.go index 97ff618..3737bc0 100644 --- a/cmd/cachew/main.go +++ b/cmd/cachew/main.go @@ -13,9 +13,8 @@ import ( "github.com/alecthomas/errors" "github.com/alecthomas/kong" - "github.com/block/cachew/internal/cache" + "github.com/block/cachew/client" "github.com/block/cachew/internal/logging" - "github.com/block/cachew/internal/snapshot" ) type CLI struct { @@ -45,33 +44,31 @@ func main() { ctx := context.Background() _, ctx = logging.Configure(ctx, cli.LoggingConfig) - var headerFunc cache.HeaderFunc + var headerFunc client.HeaderFunc if cli.Authorization != "" { headerFunc = func() http.Header { return http.Header{"Authorization": {cli.Authorization}} } } - remote := cache.NewRemote(cli.URL, headerFunc) - defer remote.Close() - httpClient := cache.NewHTTPClient(headerFunc) + c := client.New(cli.URL, headerFunc) + defer c.Close() kctx.BindTo(ctx, (*context.Context)(nil)) - kctx.BindTo(remote, (*cache.Cache)(nil)) - kctx.Bind(httpClient) + kctx.Bind(c) + kctx.Bind(c.HTTP()) kctx.FatalIfErrorf(kctx.Run(ctx)) } type GetCmd struct { - Namespace cache.Namespace `arg:"" help:"Namespace for organizing cache objects."` - Key PlatformKey `arg:"" help:"Object key (hex or string)."` - Output *os.File `short:"o" help:"Output file (default: stdout)." default:"-"` + Namespace client.Namespace `arg:"" help:"Namespace for organizing cache objects."` + Key PlatformKey `arg:"" help:"Object key (hex or string)."` + Output *os.File `short:"o" help:"Output file (default: stdout)." default:"-"` } -func (c *GetCmd) Run(ctx context.Context, cache cache.Cache) error { +func (c *GetCmd) Run(ctx context.Context, api *client.Client) error { defer c.Output.Close() - namespacedCache := cache.Namespace(c.Namespace) - rc, headers, err := namespacedCache.Open(ctx, c.Key.Key()) + rc, headers, err := api.Namespace(c.Namespace).Open(ctx, c.Key.Key()) if err != nil { return errors.Wrap(err, "failed to open object") } @@ -88,13 +85,12 @@ func (c *GetCmd) Run(ctx context.Context, cache cache.Cache) error { } type StatCmd struct { - Namespace cache.Namespace `arg:"" help:"Namespace for organizing cache objects."` - Key PlatformKey `arg:"" help:"Object key (hex or string)."` + Namespace client.Namespace `arg:"" help:"Namespace for organizing cache objects."` + Key PlatformKey `arg:"" help:"Object key (hex or string)."` } -func (c *StatCmd) Run(ctx context.Context, cache cache.Cache) error { - namespacedCache := cache.Namespace(c.Namespace) - headers, err := namespacedCache.Stat(ctx, c.Key.Key()) +func (c *StatCmd) Run(ctx context.Context, api *client.Client) error { + headers, err := api.Namespace(c.Namespace).Stat(ctx, c.Key.Key()) if err != nil { return errors.Wrap(err, "failed to stat object") } @@ -109,14 +105,14 @@ func (c *StatCmd) Run(ctx context.Context, cache cache.Cache) error { } type PutCmd struct { - Namespace cache.Namespace `arg:"" help:"Namespace for organizing cache objects."` + Namespace client.Namespace `arg:"" help:"Namespace for organizing cache objects."` Key PlatformKey `arg:"" help:"Object key (hex or string)."` Input *os.File `arg:"" help:"Input file (default: stdin)." default:"-"` TTL time.Duration `help:"Time to live for the object."` Headers map[string]string `short:"H" help:"Additional headers (key=value)."` } -func (c *PutCmd) Run(ctx context.Context, cache cache.Cache) error { +func (c *PutCmd) Run(ctx context.Context, api *client.Client) error { defer c.Input.Close() headers := make(http.Header) @@ -128,8 +124,7 @@ func (c *PutCmd) Run(ctx context.Context, cache cache.Cache) error { headers.Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filepath.Base(filename))) //nolint:perfsprint } - namespacedCache := cache.Namespace(c.Namespace) - wc, err := namespacedCache.Create(ctx, c.Key.Key(), headers, c.TTL) + wc, err := api.Namespace(c.Namespace).Create(ctx, c.Key.Key(), headers, c.TTL) if err != nil { return errors.Wrap(err, "failed to create object") } @@ -142,19 +137,18 @@ func (c *PutCmd) Run(ctx context.Context, cache cache.Cache) error { } type DeleteCmd struct { - Namespace cache.Namespace `arg:"" help:"Namespace for organizing cache objects."` - Key PlatformKey `arg:"" help:"Object key (hex or string)."` + Namespace client.Namespace `arg:"" help:"Namespace for organizing cache objects."` + Key PlatformKey `arg:"" help:"Object key (hex or string)."` } -func (c *DeleteCmd) Run(ctx context.Context, cache cache.Cache) error { - namespacedCache := cache.Namespace(c.Namespace) - return errors.Wrap(namespacedCache.Delete(ctx, c.Key.Key()), "failed to delete object") +func (c *DeleteCmd) Run(ctx context.Context, api *client.Client) error { + return errors.Wrap(api.Namespace(c.Namespace).Delete(ctx, c.Key.Key()), "failed to delete object") } type NamespacesCmd struct{} -func (c *NamespacesCmd) Run(ctx context.Context, cache cache.Cache) error { - namespaces, err := cache.ListNamespaces(ctx) +func (c *NamespacesCmd) Run(ctx context.Context, api *client.Client) error { + namespaces, err := api.ListNamespaces(ctx) if err != nil { return errors.Wrap(err, "failed to list namespaces") } @@ -171,18 +165,22 @@ func (c *NamespacesCmd) Run(ctx context.Context, cache cache.Cache) error { } type SnapshotCmd struct { - Namespace cache.Namespace `arg:"" help:"Namespace for organizing cache objects."` - Key PlatformKey `arg:"" help:"Object key (hex or string)."` - Directory string `arg:"" help:"Directory to archive." type:"path"` - TTL time.Duration `help:"Time to live for the object."` - Exclude []string `help:"Patterns to exclude (tar --exclude syntax)."` - ZstdThreads int `help:"Threads for zstd compression (0 = all CPU cores)." default:"0"` + Namespace client.Namespace `arg:"" help:"Namespace for organizing cache objects."` + Key PlatformKey `arg:"" help:"Object key (hex or string)."` + Directory string `arg:"" help:"Directory to archive." type:"path"` + TTL time.Duration `help:"Time to live for the object."` + Exclude []string `help:"Patterns to exclude (tar --exclude syntax)."` + ZstdThreads int `help:"Threads for zstd compression (0 = all CPU cores)." default:"0"` } -func (c *SnapshotCmd) Run(ctx context.Context, cache cache.Cache) error { +func (c *SnapshotCmd) Run(ctx context.Context, api *client.Client) error { fmt.Fprintf(os.Stderr, "Archiving %s...\n", c.Directory) //nolint:forbidigo - namespacedCache := cache.Namespace(c.Namespace) - if err := snapshot.Create(ctx, namespacedCache, c.Key.Key(), c.Directory, c.TTL, c.Exclude, c.ZstdThreads); err != nil { + opts := client.SnapshotOptions{ + TTL: c.TTL, + Exclude: c.Exclude, + ZstdThreads: c.ZstdThreads, + } + if err := api.Namespace(c.Namespace).Snapshot(ctx, c.Key.Key(), c.Directory, opts); err != nil { return errors.Wrap(err, "failed to create snapshot") } @@ -191,16 +189,16 @@ func (c *SnapshotCmd) Run(ctx context.Context, cache cache.Cache) error { } type RestoreCmd struct { - Namespace cache.Namespace `arg:"" help:"Namespace for organizing cache objects."` - Key PlatformKey `arg:"" help:"Object key (hex or string)."` - Directory string `arg:"" help:"Target directory for extraction." type:"path"` - ZstdThreads int `help:"Threads for zstd decompression (0 = all CPU cores)." default:"0"` + Namespace client.Namespace `arg:"" help:"Namespace for organizing cache objects."` + Key PlatformKey `arg:"" help:"Object key (hex or string)."` + Directory string `arg:"" help:"Target directory for extraction." type:"path"` + ZstdThreads int `help:"Threads for zstd decompression (0 = all CPU cores)." default:"0"` } -func (c *RestoreCmd) Run(ctx context.Context, cache cache.Cache) error { +func (c *RestoreCmd) Run(ctx context.Context, api *client.Client) error { fmt.Fprintf(os.Stderr, "Restoring to %s...\n", c.Directory) //nolint:forbidigo - namespacedCache := cache.Namespace(c.Namespace) - if err := snapshot.Restore(ctx, namespacedCache, c.Key.Key(), c.Directory, c.ZstdThreads); err != nil { + opts := client.RestoreOptions{ZstdThreads: c.ZstdThreads} + if err := api.Namespace(c.Namespace).Restore(ctx, c.Key.Key(), c.Directory, opts); err != nil { return errors.Wrap(err, "failed to restore snapshot") } @@ -221,10 +219,10 @@ func getFilename(f *os.File) string { return f.Name() } -// PlatformKey wraps a cache.Key and stores the original input for platform prefixing. +// PlatformKey wraps a client.Key and stores the original input for platform prefixing. type PlatformKey struct { raw string - key cache.Key + key client.Key } func (pk *PlatformKey) UnmarshalText(text []byte) error { @@ -232,7 +230,7 @@ func (pk *PlatformKey) UnmarshalText(text []byte) error { return errors.WithStack(pk.key.UnmarshalText(text)) } -func (pk *PlatformKey) Key() cache.Key { +func (pk *PlatformKey) Key() client.Key { return pk.key } @@ -243,12 +241,10 @@ func (pk *PlatformKey) String() string { func (pk *PlatformKey) AfterApply(cli *CLI) error { prefixed := pk.raw - // Apply platform prefix if enabled if cli.Platform { prefixed = fmt.Sprintf("%s-%s-%s", runtime.GOOS, runtime.GOARCH, prefixed) } - // Apply time-based prefix if enabled (goes first in final order) now := time.Now() if cli.Hourly { prefixed = now.Format("2006-01-02-15-") + prefixed diff --git a/internal/cache/api.go b/internal/cache/api.go index de7a11a..3ea41a4 100644 --- a/internal/cache/api.go +++ b/internal/cache/api.go @@ -3,57 +3,33 @@ package cache import ( "context" - "crypto/sha256" - "encoding/hex" "io" "net/http" "os" - "regexp" "time" "github.com/alecthomas/errors" "github.com/alecthomas/hcl/v2" -) -var namespaceRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]*$`) + "github.com/block/cachew/client" +) // Namespace identifies a logical partition within a cache or metadata store. -// Valid names start with an alphanumeric character and contain only -// alphanumerics, hyphens, and underscores. -type Namespace string +type Namespace = client.Namespace // ValidateNamespace checks that a namespace name is valid. -func ValidateNamespace(name string) error { - if !namespaceRe.MatchString(name) { - return errors.Errorf("invalid namespace %q: must match %s", name, namespaceRe) - } - return nil -} +func ValidateNamespace(name string) error { return errors.WithStack(client.ValidateNamespace(name)) } // ParseNamespace validates and returns a Namespace from a plain string. func ParseNamespace(name string) (Namespace, error) { - if err := ValidateNamespace(name); err != nil { - return "", err - } - return Namespace(name), nil -} - -func (n *Namespace) String() string { return string(*n) } - -// UnmarshalText implements encoding.TextUnmarshaler with validation. -func (n *Namespace) UnmarshalText(text []byte) error { - if err := ValidateNamespace(string(text)); err != nil { - return err - } - *n = Namespace(text) - return nil + return errors.WithStack2(client.ParseNamespace(name)) } // ErrNotFound is returned when a cache backend is not found. var ErrNotFound = errors.New("cache backend not found") // ErrStatsUnavailable is returned when a cache backend cannot provide statistics. -var ErrStatsUnavailable = errors.New("stats unavailable") +var ErrStatsUnavailable = client.ErrStatsUnavailable type registryEntry struct { schema *hcl.Block @@ -130,45 +106,16 @@ func (r *Registry) Create(ctx context.Context, name string, config *hcl.Block, v } // Key represents a unique identifier for a cached object. -type Key [32]byte +type Key = client.Key // ParseKey from its hex-encoded string form. -func ParseKey(key string) (Key, error) { - var k Key - return k, k.UnmarshalText([]byte(key)) -} - -func NewKey(url string) Key { return Key(sha256.Sum256([]byte(url))) } +func ParseKey(key string) (Key, error) { return errors.WithStack2(client.ParseKey(key)) } -func (k *Key) String() string { return hex.EncodeToString(k[:]) } - -func (k *Key) UnmarshalText(text []byte) error { - // Try to decode as SHA256 hex encoded string - if len(text) == 64 { - bytes, err := hex.DecodeString(string(text)) - if err == nil && len(bytes) == len(*k) { - copy(k[:], bytes) - return nil - } - } - // If not valid hex, treat as string and SHA256 it - *k = NewKey(string(text)) - return nil -} - -func (k *Key) MarshalText() ([]byte, error) { - return []byte(k.String()), nil -} +// NewKey returns the SHA256 of s. +func NewKey(s string) Key { return client.NewKey(s) } // Stats contains health and usage statistics for a cache. -type Stats struct { - // Objects is the number of objects currently in the cache. - Objects int64 `json:"objects"` - // Size is the total size of all objects in the cache in bytes. - Size int64 `json:"size"` - // Capacity is the maximum size of the cache in bytes (0 if unlimited). - Capacity int64 `json:"capacity"` -} +type Stats = client.Stats // A Cache knows how to retrieve, create and delete objects from a cache. // diff --git a/internal/cache/remote.go b/internal/cache/remote.go index f0ed89f..d4cce7e 100644 --- a/internal/cache/remote.go +++ b/internal/cache/remote.go @@ -2,310 +2,65 @@ package cache import ( "context" - "encoding/json" - "fmt" "io" - "maps" "net/http" - "os" "time" "github.com/alecthomas/errors" - "github.com/block/cachew/internal/httputil" + "github.com/block/cachew/client" ) -const defaultNamespace Namespace = "default" - -// Remote implements Cache as a client for the remote cache server. -type Remote struct { - baseURL string - client *http.Client - namespace Namespace -} - -var _ Cache = (*Remote)(nil) - // HeaderFunc returns headers to attach to each outgoing request. -type HeaderFunc func() http.Header +type HeaderFunc = client.HeaderFunc // NewHTTPClient creates an *http.Client that attaches headerFunc headers -// to every outgoing request. Useful for callers that need to talk to -// non-API endpoints (e.g. /git/) with the same auth as the cache client. -func NewHTTPClient(headerFunc HeaderFunc) *http.Client { - transport := http.DefaultTransport.(*http.Transport).Clone() //nolint:errcheck - transport.MaxIdleConns = 100 - transport.MaxIdleConnsPerHost = 100 +// to every outgoing request. +func NewHTTPClient(headerFunc HeaderFunc) *http.Client { return client.NewHTTPClient(headerFunc) } - var rt http.RoundTripper = transport - if headerFunc != nil { - rt = &headerTransport{base: transport, headerFunc: headerFunc} - } - return &http.Client{Transport: rt} +// Remote implements Cache as a client for the remote cache server, wrapping +// a *client.Client. +type Remote struct { + c *client.Client } +var _ Cache = (*Remote)(nil) + // NewRemote creates a new remote cache client. If headerFunc is non-nil, // its returned headers are added to every outgoing request. func NewRemote(baseURL string, headerFunc HeaderFunc) *Remote { - return &Remote{ - baseURL: baseURL + "/api/v1", - client: NewHTTPClient(headerFunc), - } + return &Remote{c: client.New(baseURL, headerFunc)} } -type headerTransport struct { - base http.RoundTripper - headerFunc HeaderFunc -} +func (r *Remote) String() string { return r.c.String() } -func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) { - for key, values := range t.headerFunc() { - for _, value := range values { - req.Header.Add(key, value) - } - } - return t.base.RoundTrip(req) //nolint:wrapcheck +func (r *Remote) Namespace(namespace Namespace) Cache { + return &Remote{c: r.c.Namespace(namespace)} } -func (c *Remote) String() string { return "remote:" + c.baseURL } - -// Open retrieves an object from the remote. -func (c *Remote) Open(ctx context.Context, key Key) (io.ReadCloser, http.Header, error) { - namespace := c.namespace - if namespace == "" { - namespace = defaultNamespace - } - url := fmt.Sprintf("%s/object/%s/%s", c.baseURL, namespace, key.String()) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) - if err != nil { - return nil, nil, errors.Wrap(err, "failed to create request") - } - - resp, err := c.client.Do(req) - if err != nil { - return nil, nil, errors.Wrap(err, "failed to execute request") - } - - if resp.StatusCode == http.StatusNotFound { - _, _ = io.Copy(io.Discard, resp.Body) //nolint:errcheck,gosec - return nil, nil, errors.Join(os.ErrNotExist, resp.Body.Close()) - } - - if resp.StatusCode != http.StatusOK { - _, _ = io.Copy(io.Discard, resp.Body) //nolint:errcheck,gosec - return nil, nil, errors.Join(errors.Errorf("unexpected status code: %d", resp.StatusCode), resp.Body.Close()) - } - - // Filter out HTTP transport headers - headers := httputil.FilterHeaders(resp.Header, httputil.TransportHeaders...) - - return resp.Body, headers, nil +func (r *Remote) Open(ctx context.Context, key Key) (io.ReadCloser, http.Header, error) { + rc, h, err := r.c.Open(ctx, key) + return rc, h, errors.WithStack(err) } -// Stat retrieves headers for an object from the remote. -func (c *Remote) Stat(ctx context.Context, key Key) (http.Header, error) { - namespace := c.namespace - if namespace == "" { - namespace = defaultNamespace - } - url := fmt.Sprintf("%s/object/%s/%s", c.baseURL, namespace, key.String()) - req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) - if err != nil { - return nil, errors.Wrap(err, "failed to create request") - } - - resp, err := c.client.Do(req) - if err != nil { - return nil, errors.Wrap(err, "failed to execute request") - } - defer resp.Body.Close() - - if resp.StatusCode == http.StatusNotFound { - return nil, os.ErrNotExist - } - - if resp.StatusCode != http.StatusOK { - return nil, errors.Errorf("unexpected status code: %d", resp.StatusCode) - } - - // Filter out HTTP transport headers - headers := httputil.FilterHeaders(resp.Header, httputil.TransportHeaders...) - - return headers, nil +func (r *Remote) Stat(ctx context.Context, key Key) (http.Header, error) { + return errors.WithStack2(r.c.Stat(ctx, key)) } -// Create stores a new object in the remote. -func (c *Remote) Create(ctx context.Context, key Key, headers http.Header, ttl time.Duration) (io.WriteCloser, error) { - pr, pw := io.Pipe() - - namespace := c.namespace - if namespace == "" { - namespace = defaultNamespace - } - url := fmt.Sprintf("%s/object/%s/%s", c.baseURL, namespace, key.String()) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, pr) - if err != nil { - return nil, errors.Join(errors.Wrap(err, "failed to create request"), pr.Close(), pw.Close()) - } - - maps.Copy(req.Header, headers) - - if ttl > 0 { - req.Header.Set("Time-To-Live", ttl.String()) - } - - wc := &writeCloser{ - pw: pw, - done: make(chan error, 1), - ctx: ctx, - } - - go func() { - resp, err := c.client.Do(req) - if err != nil { - wc.done <- errors.Wrap(err, "failed to execute request") - return - } - _, _ = io.Copy(io.Discard, resp.Body) //nolint:errcheck,gosec - _ = resp.Body.Close() //nolint:gosec - - if resp.StatusCode != http.StatusOK { - wc.done <- errors.Errorf("unexpected status code: %d", resp.StatusCode) - return - } - - wc.done <- nil - }() - - return wc, nil +func (r *Remote) Create(ctx context.Context, key Key, headers http.Header, ttl time.Duration) (io.WriteCloser, error) { + return errors.WithStack2(r.c.Create(ctx, key, headers, ttl)) } -// Delete removes an object from the remote. -func (c *Remote) Delete(ctx context.Context, key Key) error { - namespace := c.namespace - if namespace == "" { - namespace = defaultNamespace - } - url := fmt.Sprintf("%s/object/%s/%s", c.baseURL, namespace, key.String()) - req, err := http.NewRequestWithContext(ctx, http.MethodDelete, url, nil) - if err != nil { - return errors.Wrap(err, "failed to create request") - } - - resp, err := c.client.Do(req) - if err != nil { - return errors.Wrap(err, "failed to execute request") - } - defer resp.Body.Close() - - if resp.StatusCode == http.StatusNotFound { - return os.ErrNotExist - } - - if resp.StatusCode != http.StatusOK { - return errors.Errorf("unexpected status code: %d", resp.StatusCode) - } - - return nil +func (r *Remote) Delete(ctx context.Context, key Key) error { + return errors.WithStack(r.c.Delete(ctx, key)) } -// Close closes the client and releases resources. -func (c *Remote) Close() error { - c.client.CloseIdleConnections() - return nil +func (r *Remote) Stats(ctx context.Context) (Stats, error) { + return errors.WithStack2(r.c.Stats(ctx)) } -// Stats retrieves cache statistics from the remote server. -func (c *Remote) Stats(ctx context.Context) (Stats, error) { - url := c.baseURL + "/stats" - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) - if err != nil { - return Stats{}, errors.Wrap(err, "failed to create request") - } - - resp, err := c.client.Do(req) - if err != nil { - return Stats{}, errors.Wrap(err, "failed to execute request") - } - defer resp.Body.Close() - - if resp.StatusCode == http.StatusNotImplemented { - return Stats{}, ErrStatsUnavailable - } - - if resp.StatusCode != http.StatusOK { - return Stats{}, errors.Errorf("unexpected status code: %d", resp.StatusCode) - } - - var stats Stats - if err := json.NewDecoder(resp.Body).Decode(&stats); err != nil { - return Stats{}, errors.Wrap(err, "failed to decode stats response") - } - - return stats, nil +func (r *Remote) ListNamespaces(ctx context.Context) ([]string, error) { + return errors.WithStack2(r.c.ListNamespaces(ctx)) } -// writeCloser wraps a pipe writer and waits for the HTTP request to complete. -type writeCloser struct { - pw *io.PipeWriter - done chan error - ctx context.Context -} - -func (wc *writeCloser) Write(p []byte) (int, error) { - n, err := wc.pw.Write(p) - return n, errors.WithStack(err) -} - -func (wc *writeCloser) Close() error { - if err := wc.ctx.Err(); err != nil { - _ = wc.pw.CloseWithError(err) - <-wc.done // Wait for goroutine to finish and release connection - return errors.Wrap(err, "create operation cancelled") - } - if err := wc.pw.Close(); err != nil { - <-wc.done // Wait for goroutine to finish and release connection - return errors.Wrap(err, "failed to close pipe writer") - } - err := <-wc.done - if err != nil { - return errors.Wrap(err, "request failed") - } - return nil -} - -// Namespace creates a namespaced view of the remote cache. -func (c *Remote) Namespace(namespace Namespace) Cache { - return &Remote{ - baseURL: c.baseURL, - client: c.client, - namespace: namespace, - } -} - -// ListNamespaces requests namespace list from the remote server. -func (c *Remote) ListNamespaces(ctx context.Context) ([]string, error) { - url := c.baseURL + "/namespaces" - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) - if err != nil { - return nil, errors.WithStack(err) - } - - resp, err := c.client.Do(req) - if err != nil { - return nil, errors.WithStack(err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) //nolint:errcheck - return nil, errors.Errorf("unexpected status %d: %s", resp.StatusCode, body) - } - - var namespaces []string - if err := json.NewDecoder(resp.Body).Decode(&namespaces); err != nil { - return nil, errors.WithStack(err) - } - - return namespaces, nil -} +func (r *Remote) Close() error { return errors.WithStack(r.c.Close()) } diff --git a/internal/snapshot/snapshot.go b/internal/snapshot/snapshot.go index 20e50a1..6d1e4db 100644 --- a/internal/snapshot/snapshot.go +++ b/internal/snapshot/snapshot.go @@ -2,19 +2,16 @@ package snapshot import ( - "bytes" "context" "fmt" "io" "net/http" - "os" - "os/exec" "path/filepath" - "runtime" "time" "github.com/alecthomas/errors" + "github.com/block/cachew/client" "github.com/block/cachew/internal/cache" ) @@ -35,31 +32,9 @@ func Create(ctx context.Context, remote cache.Cache, key cache.Key, directory st // // The archive preserves all file permissions, ownership, and symlinks. // Each entry in includePaths is archived relative to baseDir and must exist. -// This allows callers to archive either an entire directory with "." or a -// specific subtree such as "lfs" while preserving that relative path prefix. // Exclude patterns use tar's --exclude syntax. // threads controls zstd parallelism; 0 uses all available CPU cores. func CreatePaths(ctx context.Context, remote cache.Cache, key cache.Key, baseDir, archiveName string, includePaths []string, ttl time.Duration, excludePatterns []string, threads int, extraHeaders ...http.Header) error { - if threads <= 0 { - threads = runtime.NumCPU() - } - - if len(includePaths) == 0 { - return errors.New("includePaths must not be empty") - } - - if info, err := os.Stat(baseDir); err != nil { - return errors.Wrap(err, "failed to stat base directory") - } else if !info.IsDir() { - return errors.Errorf("not a directory: %s", baseDir) - } - for _, path := range includePaths { - targetPath := filepath.Join(baseDir, path) - if _, err := os.Stat(targetPath); err != nil { - return errors.Wrapf(err, "failed to stat include path %q", path) - } - } - headers := make(http.Header) headers.Set("Content-Type", "application/zstd") headers.Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", archiveName+".tar.zst")) @@ -76,129 +51,18 @@ func CreatePaths(ctx context.Context, remote cache.Cache, key cache.Key, baseDir return errors.Wrap(err, "failed to create object") } - tarArgs := []string{"-cpf", "-", "-C", baseDir} - for _, pattern := range excludePatterns { - tarArgs = append(tarArgs, "--exclude", pattern) - } - tarArgs = append(tarArgs, "--") - tarArgs = append(tarArgs, includePaths...) - - if err := runTarZstdPipeline(ctx, tarArgs, threads, wc); err != nil { + if err := client.Archive(ctx, wc, baseDir, includePaths, excludePatterns, threads); err != nil { return errors.Join(err, wc.Close()) } return errors.Wrap(wc.Close(), "failed to close writer") } -// runTarZstdPipeline runs tar piped through zstd, writing compressed output to w. -// The caller is responsible for closing w after this returns. -func runTarZstdPipeline(ctx context.Context, tarArgs []string, threads int, w io.Writer) error { - tarCmd := exec.CommandContext(ctx, "tar", tarArgs...) - zstdCmd := exec.CommandContext(ctx, "zstd", "-c", fmt.Sprintf("-T%d", threads)) //nolint:gosec // threads is a validated integer, not user input - - // Use manual pipe so we can close both ends in the parent after starting - // children. This prevents deadlock if zstd exits while tar is still writing: - // closing the read end ensures tar receives SIGPIPE instead of blocking. - pr, pw, err := os.Pipe() - if err != nil { - return errors.Wrap(err, "failed to create pipe") - } - - var tarStderr, zstdStderr bytes.Buffer - tarCmd.Stdout = pw - tarCmd.Stderr = &tarStderr - - zstdCmd.Stdin = pr - zstdCmd.Stdout = w - zstdCmd.Stderr = &zstdStderr - - if err := tarCmd.Start(); err != nil { - pw.Close() //nolint:errcheck,gosec // best-effort cleanup - pr.Close() //nolint:errcheck,gosec // best-effort cleanup - return errors.Wrap(err, "failed to start tar") - } - pw.Close() //nolint:errcheck,gosec // parent no longer needs write end; tar holds its own copy - - if err := zstdCmd.Start(); err != nil { - pr.Close() //nolint:errcheck,gosec // let tar receive SIGPIPE so it exits - return errors.Join(errors.Wrap(err, "failed to start zstd"), tarCmd.Wait()) - } - pr.Close() //nolint:errcheck,gosec // parent no longer needs read end; if zstd dies, tar gets SIGPIPE - - tarErr := tarCmd.Wait() - zstdErr := zstdCmd.Wait() - - var errs []error - if tarErr != nil { - errs = append(errs, errors.Errorf("tar failed: %w: %s", tarErr, tarStderr.String())) - } - if zstdErr != nil { - errs = append(errs, errors.Errorf("zstd failed: %w: %s", zstdErr, zstdStderr.String())) - } - return errors.Join(errs...) -} - // StreamTo archives a directory using tar with zstd compression and streams the // output directly to w. Unlike Create, it does not upload to any cache backend. // This is used on cache miss to serve the client immediately while a background // job populates the cache. func StreamTo(ctx context.Context, w io.Writer, directory string, excludePatterns []string, threads int) error { - if threads <= 0 { - threads = runtime.NumCPU() - } - - if info, err := os.Stat(directory); err != nil { - return errors.Wrap(err, "failed to stat directory") - } else if !info.IsDir() { - return errors.Errorf("not a directory: %s", directory) - } - - tarArgs := []string{"-cpf", "-", "-C", directory} - for _, pattern := range excludePatterns { - tarArgs = append(tarArgs, "--exclude", pattern) - } - tarArgs = append(tarArgs, ".") - - tarCmd := exec.CommandContext(ctx, "tar", tarArgs...) - zstdCmd := exec.CommandContext(ctx, "zstd", "-c", fmt.Sprintf("-T%d", threads)) //nolint:gosec // threads is a validated integer, not user input - - pr, pw, err := os.Pipe() - if err != nil { - return errors.Wrap(err, "failed to create pipe") - } - - var tarStderr, zstdStderr bytes.Buffer - tarCmd.Stdout = pw - tarCmd.Stderr = &tarStderr - - zstdCmd.Stdin = pr - zstdCmd.Stdout = w - zstdCmd.Stderr = &zstdStderr - - if err := tarCmd.Start(); err != nil { - pw.Close() //nolint:errcheck,gosec // best-effort cleanup - pr.Close() //nolint:errcheck,gosec // best-effort cleanup - return errors.Wrap(err, "failed to start tar") - } - pw.Close() //nolint:errcheck,gosec // parent no longer needs write end; tar holds its own copy - - if err := zstdCmd.Start(); err != nil { - pr.Close() //nolint:errcheck,gosec // let tar receive SIGPIPE so it exits - return errors.Join(errors.Wrap(err, "failed to start zstd"), tarCmd.Wait()) - } - pr.Close() //nolint:errcheck,gosec // parent no longer needs read end; if zstd dies, tar gets SIGPIPE - - tarErr := tarCmd.Wait() - zstdErr := zstdCmd.Wait() - - var errs []error - if tarErr != nil { - errs = append(errs, errors.Errorf("tar failed: %w: %s", tarErr, tarStderr.String())) - } - if zstdErr != nil { - errs = append(errs, errors.Errorf("zstd failed: %w: %s", zstdErr, zstdStderr.String())) - } - - return errors.Join(errs...) + return errors.WithStack(client.Archive(ctx, w, directory, []string{"."}, excludePatterns, threads)) } // Restore downloads an archive from the cache and extracts it to a directory. @@ -221,53 +85,5 @@ func Restore(ctx context.Context, remote cache.Cache, key cache.Key, directory s // permissions, ownership, and symlinks. threads controls zstd parallelism; // 0 uses all available CPU cores. func Extract(ctx context.Context, r io.Reader, directory string, threads int) error { - if threads <= 0 { - threads = runtime.NumCPU() - } - - if err := os.MkdirAll(directory, 0o750); err != nil { - return errors.Wrap(err, "failed to create target directory") - } - - zstdCmd := exec.CommandContext(ctx, "zstd", "-dc", fmt.Sprintf("-T%d", threads)) //nolint:gosec // threads is a validated integer, not user input - tarCmd := exec.CommandContext(ctx, "tar", "-xpf", "-", "-C", directory) - - pr, pw, err := os.Pipe() - if err != nil { - return errors.Wrap(err, "failed to create pipe") - } - - var zstdStderr, tarStderr bytes.Buffer - zstdCmd.Stdin = r - zstdCmd.Stdout = pw - zstdCmd.Stderr = &zstdStderr - - tarCmd.Stdin = pr - tarCmd.Stderr = &tarStderr - - if err := zstdCmd.Start(); err != nil { - pw.Close() //nolint:errcheck,gosec // best-effort cleanup - pr.Close() //nolint:errcheck,gosec // best-effort cleanup - return errors.Wrap(err, "failed to start zstd") - } - pw.Close() //nolint:errcheck,gosec // parent no longer needs write end; zstd holds its own copy - - if err := tarCmd.Start(); err != nil { - pr.Close() //nolint:errcheck,gosec // let zstd receive SIGPIPE so it exits - return errors.Join(errors.Wrap(err, "failed to start tar"), zstdCmd.Wait()) - } - pr.Close() //nolint:errcheck,gosec // parent no longer needs read end; if tar dies, zstd gets SIGPIPE - - zstdErr := zstdCmd.Wait() - tarErr := tarCmd.Wait() - - var errs []error - if zstdErr != nil { - errs = append(errs, errors.Errorf("zstd failed: %w: %s", zstdErr, zstdStderr.String())) - } - if tarErr != nil { - errs = append(errs, errors.Errorf("tar failed: %w: %s", tarErr, tarStderr.String())) - } - - return errors.Join(errs...) + return errors.WithStack(client.Extract(ctx, r, directory, threads)) }