diff --git a/.claude/rules/base.md b/.claude/rules/base.md new file mode 100644 index 0000000..6967553 --- /dev/null +++ b/.claude/rules/base.md @@ -0,0 +1,21 @@ +# Base Rules + +BEFORE DOING ANYTHING, RESEARCH THE BEST APPROACH. ONCE YOU HAVE DONE SO, COME UP WITH A PLAN, PAUSE AND PROMPT FOR CONFIRMATION. + +DO NOT ASSUME I AM RIGHT, VERIFY WHAT I ASSERT + +- When working on a list of tasks in a README.md bullet list `- [ ] ...`, pick the next incomplete one and implement that, mark it as complete, then stop. +- If you are in read-only "Ask" mode, and are asked to modify something, immediately abort saying you can't modify anything. +- Do exactly what I ask, no more. Don't add extra scripts, documentation, etc. +- Always run tests to verify correctness. +- Always write tests for updated/new code. +- Be succinct. +- Don't write comments if the related code itself is simple. +- If you're not sure of next steps, ask for clarification. +- Prefer to create helper functions rather than writing single giant functions. +- Search for existing functions and reuse them where possible, refactoring them if the old functionality and new desired functionality is similar. +- Be succinct when writing documentation and comments. +- Always use ripgrep rather than grep. +- Extend or create abstractions where appropriate, rather than inlining large amounts of bespoke code. +- For changes that are fairly mechanical across 10 or more locations, prefer to create a temporary script that makes the change in one shot. +- If asked to research or plan something, do NOT include detailed configuration or code, just human readable descriptions. diff --git a/.claude/rules/go.md b/.claude/rules/go.md new file mode 100644 index 0000000..d52de8c --- /dev/null +++ b/.claude/rules/go.md @@ -0,0 +1,44 @@ +## Go (Golang) Code + +- We are targeting Go 1.24 or newer. +- Use Go 1.22+'s new "for range" syntax everywhere possible. 
+- Combine multiple if clauses whose bodies do the same thing, into single expressions. +- Always use `any` rather than `interface{}` +- Use `github.com/alecthomas/errors` for errors if the project already uses it. It has `Errorf`, `New`, `Wrap`, etc. +- Always wrap errors, but try to be succinct if possible. +- Never use underscore in names. +- Use `github.com/alecthomas/assert/v2` for test assertions. In particular note that `assert.Equal()` performs a deep comparison. +- Prefer to compare whole objects rather than individual fields, using `assert.Equal(t, expected, actual, assert.Exclude[T]())` to exclude dynamic values like time. +- ALWAYS use table-driven tests if the tests can be parameterised on data. If not, just create distinct test functions. +- When writing "sub tests", their names MUST be UpperCamelCase. +- Test functions must always be UpperCamelCase, never with underscores. +- When writing code, avoid using `strings.Contains()` and string comparisons to compare types. Instead, use existing helper functions or methods, or write new ones. +- Where it makes sense, update existing test rather than creating new ones. +- ALWAYS run tests with `-timeout 30s` to ensure that wedged tests don't last forever. +- Don't run tests with `-v` in general, as it produces a large amount of output. +- Once the change is complete and working, run `golangci-lint run` and fix any linter errors introduced before adding the files to git. Do NOT EVER run `golangci-lint` on individual files. +- For "unparam" linter warnings about "XXX is unused", remove the parameter unless the type is part of an interface implementation or callback system. +- ALWAYS respect encapsulation of struct fields, even between types in the same package. +- ALWAYS apply the Go proverb "align the happy path to the left", to avoid deep nesting. + + eg. 
instead of: + ```go + if a, ok := doA(); ok { + if b, ok := doB(); ok { + // Code + } + } + ``` + + Do this: + + ```go + a, ok := doA() + if !ok { + continue // Or return + } + b, ok := doB() + if !ok { + continue // Or return + } + ``` diff --git a/.golangci.yml b/.golangci.yml index aff0fb5..72bd61a 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -223,10 +223,6 @@ linters: forbid-mutex: true errcheck: - # Report about not checking of errors in type assertions: `a := b.(MyStruct)`. - # Such cases aren't reported by default. - # Default: false - check-type-assertions: true check-blank: true exhaustive: diff --git a/docs/git-strategy-research.md b/docs/git-strategy-research.md deleted file mode 100644 index 7ad0561..0000000 --- a/docs/git-strategy-research.md +++ /dev/null @@ -1,230 +0,0 @@ -# Git Caching Strategy Research - -## Goals - -1. Minimize impact on upstream Git servers -2. Make git clones as fast as possible -3. Efficiently handle incremental fetches - -## Three-Layer Approach - -### Layer 1: Snapshot Tarballs (Fastest Initial Clones) - -**Observation**: `tar` is significantly faster than Git at populating a repository because: -- No pack negotiation overhead -- No delta resolution computation -- Single sequential read/write operation -- Can use fast compression (zstd) - -**Approach**: -1. Cache server maintains full clones of upstream repositories -2. Generate daily tarballs of the full clone -3. Client downloads and extracts tarball, then runs `git fetch` to catch up - -**Client-side workflow**: -``` -# Instead of: git clone https://github.com/org/repo -cachew git clone https://github.com/org/repo -``` - -Under the hood: -1. Check if snapshot tarball exists for repo -2. Download and extract: curl ... | zstd -d | tar -xf - -3. Set remote URL to upstream (or through cache proxy) -4. git fetch to get any updates since snapshot -5. 
git checkout as normal - -### Layer 2: Daily Bundles (Fallback for Non-Tarball Clients) - -For clients that don't use the tarball option, daily bundles provide a simpler optimisation. - -**Approach**: -- Generate one daily bundle containing all refs -- Cache server advertises bundle URI via protocol v2 `bundle-uri` capability -- Client cloning through cache proxy automatically fetches bundle first -- Git then negotiates remaining objects via normal protocol - -### Layer 3: Git Protocol Proxy (Normal Fetches) - -Proxy `git-upload-pack` requests, always serving from the local clone. - -**Approach**: -- Cache server intercepts git protocol requests -- Always serves objects from local clone (never proxies to upstream) -- Local clone is kept fresh via periodic background fetches - -**Cache Key Strategy**: - -To cache packfile responses, normalize and hash the request: -``` -cache_key = hash(repo_url, sorted(want_refs), sorted(have_refs)) -``` - -**Normalization**: -- Sort want/have OIDs lexicographically -- Include repo identifier -- Optionally include filter spec (for partial clones) - -**Example**: -``` -wants: [abc123, def456, 789xyz] -haves: [111aaa, 222bbb] - -normalized = "{host}/{path}:wants=789xyz,abc123,def456:haves=111aaa,222bbb" -cache_key = sha256(normalized) -``` - -**Benefits**: -- Zero load on upstream for git protocol operations -- Multiple clients with same repo state get cache hits -- CI builds cloning same commit hit cache -- Works transparently with standard git - -**Considerations**: -- Local clone freshness depends on background fetch interval -- May need to handle shallow clones separately - -## Architecture - -``` -┌─────────────────────────────────────────────────────────────┐ -│ Cache Server │ -├─────────────────────────────────────────────────────────────┤ -│ │ -│ ┌─────────────────┐ ┌─────────────────────────────────┐ │ -│ │ Full Clone │ │ Daily Generators │ │ -│ │ Storage │───▶│ - Tarball snapshots (.tar.zst) │ │ -│ │ │ │ - Bundle files 
(.bundle) │ │ -│ │ /repos/ │ └─────────────────────────────────┘ │ -│ │ {host}/{path} │ │ │ -│ │ │ ▼ │ -│ └────────┬────────┘ ┌─────────────────────────────────┐ │ -│ │ │ Object Cache │ │ -│ │ │ - Snapshots │ │ -│ │ │ - Bundles │ │ -│ └────────────▶│ - Packfile responses │ │ -│ └─────────────────────────────────┘ │ -│ │ │ -│ ▼ │ -│ ┌─────────────────────────────────────────────────────────┐│ -│ │ HTTP Endpoints ││ -│ │ ││ -│ │ GET /git/{host}/{path}/snapshot.tar.zst ││ -│ │ GET /git/{host}/{path}/bundle.bundle ││ -│ │ POST /git/{host}/{path}/git-upload-pack ││ -│ │ ││ -│ └─────────────────────────────────────────────────────────┘│ -└─────────────────────────────────────────────────────────────┘ -``` - -### Client Options - -**Option A: Wrapper Script** (`cachew-git`) - Recommended -- Intercepts `clone` command -- Downloads snapshot tarball, extracts, fetches updates -- Falls back to bundle-uri or cached git protocol - -**Option B: Git Config Redirect** -- Configure `url..insteadOf` to redirect through cache -- Works with standard git commands -- Only benefits from protocol caching and bundles (no tarball support) - -### Data Flow: Initial Clone (Tarball Client) - -``` -Client Cache Server Upstream - │ │ │ - │ GET /snapshot.tar.zst │ │ - │────────────────────────────▶│ │ - │◀────────────────────────────│ (serve from cache) │ - │ tar -xf │ │ - │ │ │ - │ git fetch (via cache) │ │ - │────────────────────────────▶│ │ - │ │ (cache lookup by │ - │ │ hashed refs) │ - │◀────────────────────────────│ │ -``` - -### Data Flow: Normal Git Clone (Protocol Proxy) - -``` -Client Cache Server Upstream - │ │ │ - │ git-upload-pack │ │ - │ wants=[...] haves=[...] 
│ │ - │────────────────────────────▶│ │ - │ │ hash(wants, haves) │ - │ │ cache lookup │ - │ │ │ - │ │ MISS: serve from local │ - │ │ clone, cache response │ - │◀────────────────────────────│ │ - │ │ │ - │ │ HIT: serve from cache │ - │◀────────────────────────────│ │ -``` - -## Implementation Plan - -### Phase 1: Clone Management -1. Storage for full clones on cache server -2. Background job to `git fetch` from upstream periodically -3. Track last-fetched time per repository - -### Phase 2: Snapshot Tarballs -1. Daily tarball generation from full clones -2. HTTP endpoint to serve snapshots -3. Client wrapper script (`cachew-git clone`) - -### Phase 3: Git Protocol Proxy -1. Implement `git-upload-pack` endpoint -2. Parse wants/haves from request -3. Normalize and hash for cache key -4. Serve from local clone, cache packfile responses - -### Phase 4: Bundle Support -1. Daily bundle generation from full clones -2. HTTP endpoint to serve bundle file -3. Advertise bundle-uri in protocol v2 capability during git-upload-pack - -## Key Decisions - -### Git Version Requirement -- Git 2.38+ for bundle-uri support -- Client wrapper works with any Git version - -### Compression -- Tarballs: zstd (fast decompression, good ratio) -- Bundles: Git's native pack compression - -### Cache Keys -- Snapshots: `git/{host}/{path}/snapshot-{date}.tar.zst` -- Bundles: `git/{host}/{path}/bundle-{date}.bundle` -- Packfiles: `git/{host}/{path}/pack-{hash(wants,haves)}.pack` - -### Freshness -- Bare clone fetch: every 5-15 minutes (configurable) -- Snapshots: generated daily -- Bundles: generated daily -- Packfiles: long TTL (immutable for given inputs) - -### Storage -- Full clones: local filesystem (fast access needed) -- Everything else: cache backend (tiered) - -## Risks and Mitigations - -| Risk | Mitigation | -|------|------------| -| Stale snapshots | Always `git fetch` after snapshot extract | -| Large repositories | Consider blobless partial clone support later | -| Upstream auth | 
Pass through credentials or use deployment keys | -| Storage growth | Retention policies, single clone per repo | -| Packfile cache misses | Most CI builds have identical state = high hit rate | - -## References - -- [Git Bundle-URI Documentation](https://git-scm.com/docs/bundle-uri) -- [Git Protocol v2](https://git-scm.com/docs/protocol-v2) -- [Git Pack Protocol](https://git-scm.com/docs/pack-protocol) diff --git a/internal/strategy/git/git.go b/internal/strategy/git/git.go index fb9c843..8c2fb6e 100644 --- a/internal/strategy/git/git.go +++ b/internal/strategy/git/git.go @@ -2,13 +2,19 @@ package git import ( + "bytes" "context" + "crypto/sha256" + "encoding/hex" "io" "log/slog" "net/http" "net/http/httputil" + "net/url" "os" + "path/filepath" "strings" + "sync" "time" "github.com/alecthomas/errors" @@ -42,6 +48,8 @@ type Strategy struct { proxy *httputil.ReverseProxy ctx context.Context scheduler jobscheduler.Scheduler + spoolsMu sync.Mutex + spools map[string]*RepoSpools } func New(ctx context.Context, config Config, scheduler jobscheduler.Scheduler, cache cache.Cache, mux strategy.Mux) (*Strategy, error) { @@ -71,6 +79,10 @@ func New(ctx context.Context, config Config, scheduler jobscheduler.Scheduler, c gitclone.SetShared(cloneManager) + if err := os.RemoveAll(filepath.Join(config.MirrorRoot, ".spools")); err != nil { + return nil, errors.Wrap(err, "clean up stale spools") + } + s := &Strategy{ config: config, cache: cache, @@ -78,6 +90,7 @@ func New(ctx context.Context, config Config, scheduler jobscheduler.Scheduler, c httpClient: http.DefaultClient, ctx: ctx, scheduler: scheduler.WithQueuePrefix("git"), + spools: make(map[string]*RepoSpools), } existing, err := s.cloneManager.DiscoverExisting(ctx) @@ -123,6 +136,13 @@ func New(ctx context.Context, config Config, scheduler jobscheduler.Scheduler, c var _ strategy.Strategy = (*Strategy)(nil) +// SetHTTPTransport overrides the HTTP transport used for upstream requests. +// This is intended for testing. 
+func (s *Strategy) SetHTTPTransport(t http.RoundTripper) { + s.httpClient.Transport = t + s.proxy.Transport = t +} + func (s *Strategy) String() string { return "git" } func (s *Strategy) handleRequest(w http.ResponseWriter, r *http.Request) { @@ -181,17 +201,130 @@ func (s *Strategy) handleRequest(w http.ResponseWriter, r *http.Request) { s.maybeBackgroundFetch(repo) s.serveFromBackend(w, r, repo) - case gitclone.StateCloning: - logger.DebugContext(ctx, "Clone in progress, forwarding to upstream") + case gitclone.StateCloning, gitclone.StateEmpty: + if state == gitclone.StateEmpty { + logger.DebugContext(ctx, "Starting background clone, forwarding to upstream") + s.scheduler.Submit(repo.UpstreamURL(), "clone", func(ctx context.Context) error { + s.startClone(ctx, repo) + return nil + }) + } + s.serveWithSpool(w, r, host, pathValue, upstreamURL) + } +} + +// SpoolKeyForRequest returns the spool key for a request, or empty string if the +// request is not spoolable. For POST requests, the body is hashed to differentiate +// protocol v2 commands (e.g. ls-refs vs fetch) that share the same URL. The request +// body is buffered and replaced so it can still be read by the caller. 
+func SpoolKeyForRequest(pathValue string, r *http.Request) (string, error) { + if !strings.HasSuffix(pathValue, "/git-upload-pack") { + return "", nil + } + if r.Method != http.MethodPost || r.Body == nil { + return "upload-pack", nil + } + body, err := io.ReadAll(r.Body) + if err != nil { + return "", errors.Wrap(err, "read request body for spool key") + } + r.Body = io.NopCloser(bytes.NewReader(body)) + h := sha256.Sum256(body) + return "upload-pack-" + hex.EncodeToString(h[:8]), nil +} + +func spoolDirForURL(mirrorRoot, upstreamURL string) string { + parsed, err := url.Parse(upstreamURL) + if err != nil { + return filepath.Join(mirrorRoot, ".spools", "unknown") + } + repoPath := strings.TrimSuffix(parsed.Path, ".git") + return filepath.Join(mirrorRoot, ".spools", parsed.Host, repoPath) +} + +func (s *Strategy) getOrCreateRepoSpools(upstreamURL string) *RepoSpools { + s.spoolsMu.Lock() + defer s.spoolsMu.Unlock() + rp, exists := s.spools[upstreamURL] + if !exists { + dir := spoolDirForURL(s.config.MirrorRoot, upstreamURL) + rp = NewRepoSpools(dir) + s.spools[upstreamURL] = rp + } + return rp +} + +func (s *Strategy) cleanupSpools(upstreamURL string) { + s.spoolsMu.Lock() + rp, exists := s.spools[upstreamURL] + if exists { + delete(s.spools, upstreamURL) + } + s.spoolsMu.Unlock() + if rp != nil { + if err := rp.Close(); err != nil { + logging.FromContext(s.ctx).WarnContext(s.ctx, "Failed to clean up spools", + slog.String("upstream", upstreamURL), + slog.String("error", err.Error())) + } + } +} + +func (s *Strategy) serveWithSpool(w http.ResponseWriter, r *http.Request, host, pathValue, upstreamURL string) { + ctx := r.Context() + logger := logging.FromContext(ctx) + + key, err := SpoolKeyForRequest(pathValue, r) + if err != nil { + logger.WarnContext(ctx, "Failed to compute spool key, forwarding to upstream", + slog.String("error", err.Error())) + s.forwardToUpstream(w, r, host, pathValue) + return + } + if key == "" { + s.forwardToUpstream(w, r, host, 
pathValue) + return + } + + rp := s.getOrCreateRepoSpools(upstreamURL) + spool, isWriter, err := rp.GetOrCreate(key) + if err != nil { + logger.WarnContext(ctx, "Failed to create spool, forwarding to upstream", + slog.String("error", err.Error())) s.forwardToUpstream(w, r, host, pathValue) + return + } - case gitclone.StateEmpty: - logger.DebugContext(ctx, "Starting background clone, forwarding to upstream") - s.scheduler.Submit(repo.UpstreamURL(), "clone", func(ctx context.Context) error { - s.startClone(ctx, repo) - return nil - }) + if isWriter { + logger.DebugContext(ctx, "Spooling upstream response", + slog.String("key", key), + slog.String("upstream", upstreamURL)) + tw := NewSpoolTeeWriter(w, spool) + s.forwardToUpstream(tw, r, host, pathValue) + spool.MarkComplete() + return + } + + if spool.Failed() { + logger.DebugContext(ctx, "Spool failed, forwarding to upstream", + slog.String("key", key)) s.forwardToUpstream(w, r, host, pathValue) + return + } + + logger.DebugContext(ctx, "Serving from spool", + slog.String("key", key), + slog.String("upstream", upstreamURL)) + if err := spool.ServeTo(w); err != nil { + if errors.Is(err, ErrSpoolFailed) { + logger.DebugContext(ctx, "Spool failed before response started, forwarding to upstream", + slog.String("key", key)) + s.forwardToUpstream(w, r, host, pathValue) + return + } + logger.WarnContext(ctx, "Spool read failed mid-stream", + slog.String("key", key), + slog.String("error", err.Error())) } } @@ -267,6 +400,10 @@ func (s *Strategy) startClone(ctx context.Context, repo *gitclone.Repository) { err := repo.Clone(ctx, gitcloneConfig) + // Clean up spools regardless of clone success or failure, so that subsequent + // requests either serve from the local backend or go directly to upstream. 
+ s.cleanupSpools(repo.UpstreamURL()) + if err != nil { logger.ErrorContext(ctx, "Clone failed", slog.String("upstream", repo.UpstreamURL()), diff --git a/internal/strategy/git/integration_test.go b/internal/strategy/git/integration_test.go index 861bfa3..22005d9 100644 --- a/internal/strategy/git/integration_test.go +++ b/internal/strategy/git/integration_test.go @@ -6,11 +6,14 @@ import ( "context" "fmt" "io" + "log/slog" "net/http" "net/http/httptest" "os" "os/exec" "path/filepath" + "strings" + "sync/atomic" "testing" "time" @@ -266,3 +269,102 @@ func TestIntegrationPushForwardsToUpstream(t *testing.T) { // if we had wired up the server properly) t.Logf("Push forwarding test completed, pushReceived=%v", pushReceived) } + +// countingTransport wraps an http.RoundTripper to count outbound requests by URL path pattern. +type countingTransport struct { + inner http.RoundTripper + counter *atomic.Int32 + pattern string +} + +func (ct *countingTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if strings.Contains(req.URL.Path, ct.pattern) { + ct.counter.Add(1) + } + return ct.inner.RoundTrip(req) +} + +// TestIntegrationSpoolReusesDuringClone clones github.com/git/git through the proxy, +// waits 5 seconds (enough for the first clone to start but not finish), then clones +// again. The second clone should be served from the spool rather than making a new +// upstream request. +func TestIntegrationSpoolReusesDuringClone(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + if _, err := exec.LookPath("git"); err != nil { + t.Skip("git not found in PATH") + } + + _, ctx := logging.Configure(context.Background(), logging.Config{Level: slog.LevelDebug}) + tmpDir := t.TempDir() + clonesDir := filepath.Join(tmpDir, "clones") + workDir := filepath.Join(tmpDir, "work") + err := os.MkdirAll(workDir, 0o750) + assert.NoError(t, err) + + // Count actual outbound upstream requests via a transport wrapper. 
+ var upstreamUploadPackRequests atomic.Int32 + + mux := http.NewServeMux() + strategy, err := git.New(ctx, git.Config{ + MirrorRoot: clonesDir, + FetchInterval: 15, + }, jobscheduler.New(ctx, jobscheduler.Config{}), nil, mux) + assert.NoError(t, err) + + strategy.SetHTTPTransport(&countingTransport{ + inner: http.DefaultTransport, + counter: &upstreamUploadPackRequests, + pattern: "git-upload-pack", + }) + + server := testServerWithLogging(ctx, mux) + defer server.Close() + + repoURL := fmt.Sprintf("%s/git/github.com/git/git", server.URL) + + // First clone – triggers upstream pass-through and background clone. + t.Log("Starting first clone") + cmd := exec.Command("git", "clone", "--depth=1", repoURL, filepath.Join(workDir, "repo1")) + cmd.Env = append(os.Environ(), "GIT_TERMINAL_PROMPT=0") + output, err := cmd.CombinedOutput() + if err != nil { + t.Logf("first clone output: %s", output) + } + assert.NoError(t, err) + + // Record how many upstream upload-pack requests the first clone made. + firstCloneCount := upstreamUploadPackRequests.Load() + t.Logf("Upstream upload-pack requests after first clone: %d", firstCloneCount) + assert.True(t, firstCloneCount > 0, "first clone should have made upstream requests") + + // Wait long enough for the background clone to have started but (likely) not + // finished for a repo as large as git/git. + t.Log("Waiting 5 seconds for background clone to be in progress") + time.Sleep(5 * time.Second) + + // Second clone – should be served from the spool if the background clone is + // still running, or from the local backend if it already finished. + t.Log("Starting second clone") + cmd = exec.Command("git", "clone", "--depth=1", repoURL, filepath.Join(workDir, "repo2")) + cmd.Env = append(os.Environ(), "GIT_TERMINAL_PROMPT=0") + output, err = cmd.CombinedOutput() + if err != nil { + t.Logf("second clone output: %s", output) + } + assert.NoError(t, err) + + // Verify both clones produced a working checkout. 
+ for _, name := range []string{"repo1", "repo2"} { + gitDir := filepath.Join(workDir, name, ".git") + _, statErr := os.Stat(gitDir) + assert.NoError(t, statErr, "expected .git in %s", name) + } + + // The second clone should not have generated any new upstream upload-pack + // requests — it should have been served entirely from the spool or local backend. + totalCount := upstreamUploadPackRequests.Load() + t.Logf("Total upstream upload-pack requests: %d (first clone: %d)", totalCount, firstCloneCount) + assert.Equal(t, firstCloneCount, totalCount, "second clone should not have made additional upstream upload-pack requests") +} diff --git a/internal/strategy/git/spool.go b/internal/strategy/git/spool.go new file mode 100644 index 0000000..766ceb8 --- /dev/null +++ b/internal/strategy/git/spool.go @@ -0,0 +1,281 @@ +package git + +import ( + "io" + "maps" + "net/http" + "os" + "path/filepath" + "sync" + "sync/atomic" + + "github.com/alecthomas/errors" +) + +// ErrSpoolFailed is returned by ServeTo when the spool failed before any +// headers were written to the client, allowing the caller to fall back to +// upstream. +var ErrSpoolFailed = errors.New("spool failed before response started") + +// ResponseSpool captures a single HTTP response (headers + body) to a file on disk, +// allowing one writer and multiple concurrent readers. Readers follow the writer, +// blocking when caught up until the write completes. 
+type ResponseSpool struct { + mu sync.Mutex + cond *sync.Cond + filePath string + file *os.File + status int + headers http.Header + written int64 + complete bool + err error + readers sync.WaitGroup +} + +func NewResponseSpool(filePath string) (*ResponseSpool, error) { + if err := os.MkdirAll(filepath.Dir(filePath), 0o750); err != nil { + return nil, errors.Wrap(err, "create spool directory") + } + f, err := os.Create(filePath) + if err != nil { + return nil, errors.Wrap(err, "create spool file") + } + rs := &ResponseSpool{ + filePath: filePath, + file: f, + } + rs.cond = sync.NewCond(&rs.mu) + return rs, nil +} + +func (rs *ResponseSpool) CaptureHeader(status int, header http.Header) { + rs.mu.Lock() + defer rs.mu.Unlock() + rs.status = status + rs.headers = header.Clone() + rs.cond.Broadcast() +} + +func (rs *ResponseSpool) Write(data []byte) error { + rs.mu.Lock() + defer rs.mu.Unlock() + if rs.err != nil { + return rs.err + } + n, err := rs.file.Write(data) + rs.written += int64(n) + if err != nil { + rs.err = errors.Wrap(err, "write to spool file") + } + rs.cond.Broadcast() + return rs.err +} + +func (rs *ResponseSpool) MarkComplete() { + rs.mu.Lock() + defer rs.mu.Unlock() + if rs.complete { + return + } + rs.complete = true + rs.err = errors.Join(rs.err, rs.file.Close()) + rs.cond.Broadcast() +} + +func (rs *ResponseSpool) MarkError(err error) { + rs.mu.Lock() + defer rs.mu.Unlock() + if rs.complete { + return + } + rs.err = errors.Join(err, rs.file.Close()) + rs.complete = true + rs.cond.Broadcast() +} + +func (rs *ResponseSpool) Failed() bool { + rs.mu.Lock() + defer rs.mu.Unlock() + return rs.err != nil +} + +// ServeTo streams the spooled response to w, blocking when caught up to the writer. +func (rs *ResponseSpool) ServeTo(w http.ResponseWriter) error { + rs.readers.Add(1) + defer rs.readers.Done() + + // Wait for headers to be available. 
+ rs.mu.Lock() + for rs.status == 0 && rs.err == nil { + rs.cond.Wait() + } + if rs.err != nil && rs.status == 0 { + rs.mu.Unlock() + return ErrSpoolFailed + } + status := rs.status + headers := rs.headers.Clone() + rs.mu.Unlock() + + maps.Copy(w.Header(), headers) + w.WriteHeader(status) + + f, err := os.Open(rs.filePath) + if err != nil { + return errors.Wrap(err, "open spool file for reading") + } + defer f.Close() + + buf := make([]byte, 32*1024) + var offset int64 + for { + rs.mu.Lock() + for offset >= rs.written && !rs.complete && rs.err == nil { + rs.cond.Wait() + } + written := rs.written + complete := rs.complete + spoolErr := rs.err + rs.mu.Unlock() + + // Read all available data up to `written`. + for offset < written { + toRead := min(written-offset, int64(len(buf))) + n, readErr := f.ReadAt(buf[:toRead], offset) + if n > 0 { + if _, writeErr := w.Write(buf[:n]); writeErr != nil { + return errors.Wrap(writeErr, "write to client from spool") + } + offset += int64(n) + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + } + if readErr != nil && readErr != io.EOF { + return errors.Wrap(readErr, "read spool") + } + } + + if complete && offset >= written { + if spoolErr != nil { + return spoolErr + } + return nil + } + } +} + +// WaitForReaders blocks until all active spool readers have finished. +func (rs *ResponseSpool) WaitForReaders() { + rs.readers.Wait() +} + +// SpoolTeeWriter wraps an http.ResponseWriter to capture the response into a spool +// while simultaneously streaming it to the original client. +type SpoolTeeWriter struct { + inner http.ResponseWriter + spool *ResponseSpool + wroteHeader bool +} + +// NewSpoolTeeWriter creates a new SpoolTeeWriter that tees writes to both the +// inner ResponseWriter and the given spool. 
+func NewSpoolTeeWriter(inner http.ResponseWriter, spool *ResponseSpool) *SpoolTeeWriter { + return &SpoolTeeWriter{inner: inner, spool: spool} +} + +func (w *SpoolTeeWriter) Header() http.Header { + return w.inner.Header() +} + +func (w *SpoolTeeWriter) WriteHeader(code int) { + if w.wroteHeader { + return + } + w.wroteHeader = true + if code >= 200 && code < 300 { + w.spool.CaptureHeader(code, w.inner.Header()) + } else { + w.spool.MarkError(errors.Errorf("upstream returned status %d", code)) + } + w.inner.WriteHeader(code) +} + +func (w *SpoolTeeWriter) Write(data []byte) (int, error) { + if !w.wroteHeader { + w.WriteHeader(http.StatusOK) + } + if err := w.spool.Write(data); err != nil { + // Spool write failed; still try to serve the client. + n, writeErr := w.inner.Write(data) + return n, errors.Wrap(writeErr, "write to client") + } + n, err := w.inner.Write(data) + if err != nil { + err = errors.Wrap(err, "write to client") + w.spool.MarkError(err) + } + return n, err +} + +func (w *SpoolTeeWriter) Flush() { + if f, ok := w.inner.(http.Flusher); ok { + f.Flush() + } +} + +// RepoSpools manages all response spools for a single repository. +type RepoSpools struct { + mu sync.Mutex + dir string + spools map[string]*ResponseSpool + closed atomic.Bool +} + +func NewRepoSpools(dir string) *RepoSpools { + return &RepoSpools{ + dir: dir, + spools: make(map[string]*ResponseSpool), + } +} + +// GetOrCreate returns an existing spool for the key, or creates a new one. +// isWriter is true if the caller created the spool and should act as the writer. 
+func (rp *RepoSpools) GetOrCreate(key string) (spool *ResponseSpool, isWriter bool, err error) { + if rp.closed.Load() { + return nil, false, errors.New("repo spools closed") + } + + rp.mu.Lock() + defer rp.mu.Unlock() + + if s, exists := rp.spools[key]; exists { + return s, false, nil + } + + s, err := NewResponseSpool(filepath.Join(rp.dir, key+".spool")) + if err != nil { + return nil, false, err + } + rp.spools[key] = s + return s, true, nil +} + +// Close marks the repo spools as closed, waits for all readers to finish, +// and removes spool files from disk. +func (rp *RepoSpools) Close() error { + rp.closed.Store(true) + + rp.mu.Lock() + spools := make([]*ResponseSpool, 0, len(rp.spools)) + for _, s := range rp.spools { + spools = append(spools, s) + } + rp.mu.Unlock() + + for _, s := range spools { + s.WaitForReaders() + } + return errors.Wrap(os.RemoveAll(rp.dir), "remove spool directory") +} diff --git a/internal/strategy/git/spool_test.go b/internal/strategy/git/spool_test.go new file mode 100644 index 0000000..a9cf3db --- /dev/null +++ b/internal/strategy/git/spool_test.go @@ -0,0 +1,402 @@ +package git_test + +import ( + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "sync" + "testing" + "time" + + "github.com/alecthomas/assert/v2" + + "github.com/block/cachew/internal/strategy/git" +) + +func TestResponseSpoolWriteAndRead(t *testing.T) { + dir := t.TempDir() + rs, err := git.NewResponseSpool(filepath.Join(dir, "test.spool")) + assert.NoError(t, err) + + rs.CaptureHeader(http.StatusOK, http.Header{ + "Content-Type": []string{"application/octet-stream"}, + }) + assert.NoError(t, rs.Write([]byte("hello "))) + assert.NoError(t, rs.Write([]byte("world"))) + rs.MarkComplete() + + rec := httptest.NewRecorder() + err = rs.ServeTo(rec) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "application/octet-stream", rec.Header().Get("Content-Type")) + assert.Equal(t, "hello world", rec.Body.String()) +} + 
+func TestResponseSpoolConcurrentReaders(t *testing.T) { + dir := t.TempDir() + rs, err := git.NewResponseSpool(filepath.Join(dir, "test.spool")) + assert.NoError(t, err) + + rs.CaptureHeader(http.StatusOK, http.Header{}) + + const numReaders = 5 + var wg sync.WaitGroup + recorders := make([]*httptest.ResponseRecorder, numReaders) + for i := range numReaders { + recorders[i] = httptest.NewRecorder() + wg.Add(1) + go func(rec *httptest.ResponseRecorder) { + defer wg.Done() + assert.NoError(t, rs.ServeTo(rec)) + }(recorders[i]) + } + + // Write data in chunks with small delays so readers exercise the blocking path. + chunks := []string{"chunk1-", "chunk2-", "chunk3"} + for _, chunk := range chunks { + time.Sleep(5 * time.Millisecond) + assert.NoError(t, rs.Write([]byte(chunk))) + } + rs.MarkComplete() + + wg.Wait() + + expected := "chunk1-chunk2-chunk3" + for i, rec := range recorders { + assert.Equal(t, http.StatusOK, rec.Code, "reader %d status", i) + assert.Equal(t, expected, rec.Body.String(), "reader %d body", i) + } +} + +func TestResponseSpoolReaderFollowsWriter(t *testing.T) { + dir := t.TempDir() + rs, err := git.NewResponseSpool(filepath.Join(dir, "test.spool")) + assert.NoError(t, err) + + rs.CaptureHeader(http.StatusOK, http.Header{}) + + rec := httptest.NewRecorder() + readDone := make(chan struct{}) + go func() { + defer close(readDone) + assert.NoError(t, rs.ServeTo(rec)) + }() + + // Write progressively and give the reader time to consume. 
+	for i := range 10 {
+		assert.NoError(t, rs.Write([]byte{byte('a' + i)}))
+		time.Sleep(2 * time.Millisecond) // best-effort pause so the reader interleaves with the writer
+	}
+	rs.MarkComplete()
+
+	<-readDone
+	assert.Equal(t, "abcdefghij", rec.Body.String())
+}
+
+func TestResponseSpoolErrorPropagation(t *testing.T) {
+	dir := t.TempDir()
+	rs, err := git.NewResponseSpool(filepath.Join(dir, "test.spool"))
+	assert.NoError(t, err)
+
+	rs.CaptureHeader(http.StatusOK, http.Header{})
+	assert.NoError(t, rs.Write([]byte("partial")))
+
+	writeErr := os.ErrClosed // arbitrary sentinel standing in for a real upstream write failure
+	rs.MarkError(writeErr)
+
+	assert.True(t, rs.Failed())
+
+	rec := httptest.NewRecorder()
+	err = rs.ServeTo(rec)
+	// Headers were captured before the error, so the reader serves partial data
+	// and returns the original error from the read loop.
+	assert.IsError(t, err, writeErr)
+	assert.Equal(t, http.StatusOK, rec.Code)
+	assert.Equal(t, "partial", rec.Body.String())
+}
+
+func TestResponseSpoolErrorBeforeHeader(t *testing.T) {
+	dir := t.TempDir()
+	rs, err := git.NewResponseSpool(filepath.Join(dir, "test.spool"))
+	assert.NoError(t, err)
+
+	rs.MarkError(os.ErrClosed)
+
+	rec := httptest.NewRecorder()
+	err = rs.ServeTo(rec)
+	// No headers were captured, so ServeTo returns ErrSpoolFailed to allow
+	// the caller to fall back to upstream.
+	assert.IsError(t, err, git.ErrSpoolFailed)
+}
+
+func TestResponseSpoolWaitForReaders(t *testing.T) {
+	dir := t.TempDir()
+	rs, err := git.NewResponseSpool(filepath.Join(dir, "test.spool"))
+	assert.NoError(t, err)
+
+	rs.CaptureHeader(http.StatusOK, http.Header{})
+
+	rec := httptest.NewRecorder()
+	readerStarted := make(chan struct{})
+	readDone := make(chan struct{})
+	go func() {
+		defer close(readDone)
+		close(readerStarted)
+		_ = rs.ServeTo(rec)
+	}()
+
+	<-readerStarted
+	time.Sleep(10 * time.Millisecond) // best-effort: let the reader block on the still-incomplete spool
+
+	// Complete the spool so the reader can finish.
+	assert.NoError(t, rs.Write([]byte("data")))
+	rs.MarkComplete()
+
+	// WaitForReaders should return once the reader goroutine finishes.
+	done := make(chan struct{})
+	go func() {
+		rs.WaitForReaders()
+		close(done)
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(5 * time.Second): // generous bound — only hit if WaitForReaders wedges
+		t.Fatal("WaitForReaders timed out")
+	}
+}
+
+func TestSpoolTeeWriter(t *testing.T) {
+	dir := t.TempDir()
+	rs, err := git.NewResponseSpool(filepath.Join(dir, "test.spool"))
+	assert.NoError(t, err)
+
+	rec := httptest.NewRecorder()
+	tw := git.NewSpoolTeeWriter(rec, rs) // tee: the response goes to rec and is also captured by rs
+
+	tw.Header().Set("X-Custom", "value")
+	tw.WriteHeader(http.StatusCreated)
+	_, err = tw.Write([]byte("tee-data"))
+	assert.NoError(t, err)
+	rs.MarkComplete()
+
+	// Verify original writer got the response.
+	assert.Equal(t, http.StatusCreated, rec.Code)
+	assert.Equal(t, "value", rec.Header().Get("X-Custom"))
+	assert.Equal(t, "tee-data", rec.Body.String())
+
+	// Verify spool captured the response.
+	rec2 := httptest.NewRecorder()
+	err = rs.ServeTo(rec2)
+	assert.NoError(t, err)
+	assert.Equal(t, http.StatusCreated, rec2.Code)
+	assert.Equal(t, "tee-data", rec2.Body.String())
+}
+
+func TestSpoolTeeWriterUpstreamError(t *testing.T) {
+	dir := t.TempDir()
+	rs, err := git.NewResponseSpool(filepath.Join(dir, "test.spool"))
+	assert.NoError(t, err)
+
+	rec := httptest.NewRecorder()
+	tw := git.NewSpoolTeeWriter(rec, rs)
+
+	// Upstream returns a 502 — the spool should be marked as failed.
+	tw.WriteHeader(http.StatusBadGateway)
+	_, err = tw.Write([]byte("bad gateway"))
+	assert.NoError(t, err)
+	rs.MarkComplete()
+
+	// The original client still gets the error response.
+	assert.Equal(t, http.StatusBadGateway, rec.Code)
+	assert.Equal(t, "bad gateway", rec.Body.String())
+
+	// The spool should be marked as failed so readers fall back to upstream.
+	assert.True(t, rs.Failed())
+}
+
+func TestSpoolTeeWriterImplicitHeader(t *testing.T) {
+	dir := t.TempDir()
+	rs, err := git.NewResponseSpool(filepath.Join(dir, "test.spool"))
+	assert.NoError(t, err)
+
+	rec := httptest.NewRecorder()
+	tw := git.NewSpoolTeeWriter(rec, rs)
+
+	// Write without explicit WriteHeader; should default to 200.
+	_, err = tw.Write([]byte("implicit"))
+	assert.NoError(t, err)
+	rs.MarkComplete()
+
+	assert.Equal(t, http.StatusOK, rec.Code)
+
+	rec2 := httptest.NewRecorder()
+	err = rs.ServeTo(rec2)
+	assert.NoError(t, err)
+	assert.Equal(t, http.StatusOK, rec2.Code)
+	assert.Equal(t, "implicit", rec2.Body.String())
+}
+
+func TestRepoSpoolsGetOrCreate(t *testing.T) {
+	dir := t.TempDir()
+	rp := git.NewRepoSpools(dir)
+
+	s1, isWriter1, err := rp.GetOrCreate("info-refs")
+	assert.NoError(t, err)
+	assert.True(t, isWriter1) // first caller for a key becomes the writer
+	assert.NotZero(t, s1)
+
+	s2, isWriter2, err := rp.GetOrCreate("info-refs")
+	assert.NoError(t, err)
+	assert.False(t, isWriter2)
+	assert.Equal(t, s1, s2) // same key yields the cached spool
+
+	s3, isWriter3, err := rp.GetOrCreate("upload-pack")
+	assert.NoError(t, err)
+	assert.True(t, isWriter3)
+	assert.NotEqual(t, s1, s3)
+}
+
+func TestRepoSpoolsClose(t *testing.T) {
+	dir := filepath.Join(t.TempDir(), "spooldir")
+	rp := git.NewRepoSpools(dir)
+
+	s1, _, err := rp.GetOrCreate("info-refs")
+	assert.NoError(t, err)
+	s1.CaptureHeader(http.StatusOK, http.Header{})
+	assert.NoError(t, s1.Write([]byte("data")))
+	s1.MarkComplete()
+
+	assert.NoError(t, rp.Close())
+
+	// Spool directory should be removed.
+	_, err = os.Stat(dir)
+	assert.True(t, os.IsNotExist(err)) // NOTE(review): errors.Is(err, fs.ErrNotExist) is the preferred modern form
+
+	// Further GetOrCreate calls should fail.
+	_, _, err = rp.GetOrCreate("upload-pack")
+	assert.Error(t, err)
+}
+
+func TestRepoSpoolsCloseWaitsForReaders(t *testing.T) {
+	dir := filepath.Join(t.TempDir(), "spooldir")
+	rp := git.NewRepoSpools(dir)
+
+	s1, _, err := rp.GetOrCreate("test")
+	assert.NoError(t, err)
+	s1.CaptureHeader(http.StatusOK, http.Header{})
+
+	rec := httptest.NewRecorder()
+	readerRunning := make(chan struct{})
+	go func() {
+		close(readerRunning)
+		_ = s1.ServeTo(rec)
+	}()
+
+	<-readerRunning
+	time.Sleep(10 * time.Millisecond) // best-effort: let the reader block inside ServeTo before closing
+
+	closed := make(chan struct{})
+	go func() {
+		assert.NoError(t, rp.Close()) // NOTE(review): fatal assert off the test goroutine — confirm, or send the error over a channel
+		close(closed)
+	}()
+
+	// Close should block because a reader is active.
+	select {
+	case <-closed:
+		t.Fatal("Close returned before reader finished")
+	case <-time.After(50 * time.Millisecond):
+	}
+
+	// Complete the write so the reader can finish.
+	assert.NoError(t, s1.Write([]byte("ok")))
+	s1.MarkComplete()
+
+	select {
+	case <-closed:
+	case <-time.After(5 * time.Second):
+		t.Fatal("Close timed out after reader finished")
+	}
+}
+
+func TestSpoolKeyForRequest(t *testing.T) {
+	tests := []struct {
+		name     string
+		path     string
+		method   string
+		body     string
+		expected string
+	}{
+		{name: "InfoRefs", path: "org/repo.git/info/refs", method: http.MethodGet, expected: ""},
+		{name: "UploadPackGET", path: "org/repo.git/git-upload-pack", method: http.MethodGet, expected: "upload-pack"},
+		{name: "Unknown", path: "org/repo.git/something-else", method: http.MethodGet, expected: ""},
+		{name: "Plain", path: "org/repo", method: http.MethodGet, expected: ""},
+		{name: "UploadPackPOSTSameBody", path: "org/repo.git/git-upload-pack", method: http.MethodPost, body: "command=ls-refs\n", expected: ""},
+		{name: "UploadPackPOSTDiffBody", path: "org/repo.git/git-upload-pack", method: http.MethodPost, body: "command=fetch\n", expected: ""},
+	}
+
+	// Compute expected keys for the two POST cases by running them first.
+	var lsRefsKey, fetchKey string
+	for i, tt := range tests {
+		r := httptest.NewRequest(tt.method, "/"+tt.path, strings.NewReader(tt.body))
+		key, err := git.SpoolKeyForRequest(tt.path, r)
+		assert.NoError(t, err)
+		switch tt.name {
+		case "UploadPackPOSTSameBody":
+			lsRefsKey = key
+			tests[i].expected = key
+		case "UploadPackPOSTDiffBody":
+			fetchKey = key
+			tests[i].expected = key
+		}
+	}
+	// The two POST keys must differ (different bodies).
+	assert.NotEqual(t, lsRefsKey, fetchKey)
+	// Both keys must contain "upload-pack-" (Contains is a substring check, not a prefix check).
+	assert.Contains(t, lsRefsKey, "upload-pack-")
+	assert.Contains(t, fetchKey, "upload-pack-")
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			r := httptest.NewRequest(tt.method, "/"+tt.path, strings.NewReader(tt.body))
+			key, err := git.SpoolKeyForRequest(tt.path, r)
+			assert.NoError(t, err)
+			assert.Equal(t, tt.expected, key)
+		})
+	}
+}
+
+func TestResponseSpoolLargeData(t *testing.T) {
+	dir := t.TempDir()
+	rs, err := git.NewResponseSpool(filepath.Join(dir, "large.spool"))
+	assert.NoError(t, err)
+
+	rs.CaptureHeader(http.StatusOK, http.Header{})
+
+	// Write 1MB in 4KB chunks.
+	chunk := make([]byte, 4096)
+	for i := range chunk {
+		chunk[i] = byte(i % 256) // deterministic pattern; byte() truncation makes the %256 redundant but harmless
+	}
+	totalChunks := 256
+	totalSize := len(chunk) * totalChunks
+
+	rec := httptest.NewRecorder()
+	readDone := make(chan struct{})
+	go func() {
+		defer close(readDone)
+		assert.NoError(t, rs.ServeTo(rec)) // NOTE(review): fatal assert off the test goroutine — confirm safe
+	}()
+
+	for range totalChunks {
+		assert.NoError(t, rs.Write(chunk))
+	}
+	rs.MarkComplete()
+
+	<-readDone
+	assert.Equal(t, totalSize, rec.Body.Len())
+}