diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index 59645422..5c0599c2 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -2552,6 +2552,10 @@ func (c *Cache) GetNarInfo(ctx context.Context, hash string) (*narinfo.NarInfo, var err error narInfo, err = c.getNarInfoFromDatabase(ctx, hash) + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error fetching narinfo from database: %w", err) + } + if err == nil { metricAttrs = append(metricAttrs, attribute.String("result", "hit"), @@ -2626,7 +2630,12 @@ func (c *Cache) GetNarInfo(ctx context.Context, hash string) (*narinfo.NarInfo, zerolog.Ctx(ctx). Debug(). Msg("pulling nar in a go-routing and will wait for it") - <-ds.done + + select { + case <-ctx.Done(): + return ctx.Err() + case <-ds.done: + } err = ds.getError() if err != nil { @@ -2984,7 +2993,7 @@ func (c *Cache) prePullNarInfo(ctx context.Context, hash string) *downloadState return c.coordinateDownload( ctx, - ctx, + context.WithoutCancel(ctx), narInfoJobKey(hash), hash, true, @@ -2992,7 +3001,7 @@ func (c *Cache) prePullNarInfo(ctx context.Context, hash string) *downloadState return c.narInfoStore.HasNarInfo(ctx, hash) }, func(ds *downloadState) { - c.pullNarInfo(ctx, hash, ds) + c.pullNarInfo(context.WithoutCancel(ctx), hash, ds) }, ) } @@ -4397,6 +4406,7 @@ func (c *Cache) coordinateDownload( // Download completed (successfully or with error) case <-coordCtx.Done(): // Caller context canceled + return ds } return ds @@ -5135,6 +5145,7 @@ func (c *Cache) selectUpstream( errC := make(chan error, len(ucs)) ctx, cancel := context.WithCancel(ctx) + defer cancel() var wg sync.WaitGroup for _, uc := range ucs { @@ -5155,8 +5166,12 @@ func (c *Cache) selectUpstream( for { select { - case uc := <-ch: - cancel() + case <-ctx.Done(): + return nil, errors.Join(ctx.Err(), errs) + case uc, ok := <-ch: + if !ok { + return nil, errs + } return uc, errs case err := <-errC: diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go index 93f209ed..f45ba281 100644 --- a/pkg/cache/cache_test.go +++ b/pkg/cache/cache_test.go @@ -3092,3 +3092,70 @@ func testBackgroundMigrateNarToChunksAfterCancellation(factory cacheFactory) fun assert.True(t, hasChunks) } } + +func TestIssue990_BackgroundJobContextCancellation(t *testing.T) { + t.Parallel() + // 1. Setup test environment + c, db, _, _, rebind, cleanup := setupSQLiteFactory(t) + defer cleanup() + + // 2. Setup a slow upstream server + ts := testdata.NewTestServer(t, 40) + defer ts.Close() + + // Add a handler that sleeps for 1 second before responding + ts.AddMaybeHandler(func(_ http.ResponseWriter, r *http.Request) bool { + if strings.HasSuffix(r.URL.Path, ".narinfo") { + time.Sleep(1 * time.Second) + // The default handler will handle it after we Return false here if we didn't write anything, + // but we want to actually respond here to be sure. + return false + } + + return false + }) + + uc, err := upstream.New(context.Background(), testhelper.MustParseURL(t, ts.URL), &upstream.Options{ + PublicKeys: testdata.PublicKeys(), + }) + require.NoError(t, err) + c.AddUpstreamCaches(context.Background(), uc) + + // Wait for upstream to be healthy + <-c.GetHealthChecker().Trigger() + + // 3. Trigger GetNarInfo with a context that will be canceled + ctx, cancel := context.WithCancel(context.Background()) + + hash := testdata.Nar1.NarInfoHash + + // Start GetNarInfo in a goroutine + errCh := make(chan error, 1) + + go func() { + _, err := c.GetNarInfo(ctx, hash) + errCh <- err + }() + + // Wait a bit then cancel the context + time.Sleep(200 * time.Millisecond) + cancel() + + // Wait for GetNarInfo to return + err = <-errCh + require.ErrorIs(t, err, context.Canceled) + + // 4. Wait for the background job to (hopefully) finish + time.Sleep(1500 * time.Millisecond) + + // 5. Check if the narinfo is in the database + var count int + + err = db.DB().QueryRowContext(context.Background(), + rebind("SELECT COUNT(*) FROM narinfos WHERE hash = ?"), hash).Scan(&count) + require.NoError(t, err) + + // In the BUGGY version, count will be 0 because pullNarInfo was canceled. + // In the FIXED version, count should be 1. + assert.Equal(t, 1, count, "NarInfo should be in database even if the original request was canceled") +} diff --git a/pkg/server/server.go b/pkg/server/server.go index ed62741e..718bd0f1 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -2,6 +2,7 @@ package server import ( "compress/gzip" + "context" "encoding/json" "errors" "io" @@ -344,6 +345,10 @@ func (s *Server) getNarInfo(withBody bool) http.HandlerFunc { return } + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return + } + zerolog.Ctx(r.Context()). Error(). Err(err). @@ -562,6 +567,10 @@ func (s *Server) getNar(withBody bool) http.HandlerFunc { return } + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return + } + zerolog.Ctx(r.Context()). Error(). Err(err).