Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions pkg/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -2552,6 +2552,10 @@ func (c *Cache) GetNarInfo(ctx context.Context, hash string) (*narinfo.NarInfo,
var err error

narInfo, err = c.getNarInfoFromDatabase(ctx, hash)
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return fmt.Errorf("error fetching narinfo from database: %w", err)
}

if err == nil {
metricAttrs = append(metricAttrs,
attribute.String("result", "hit"),
Expand Down Expand Up @@ -2626,7 +2630,12 @@ func (c *Cache) GetNarInfo(ctx context.Context, hash string) (*narinfo.NarInfo,
zerolog.Ctx(ctx).
Debug().
Msg("pulling nar in a go-routing and will wait for it")
<-ds.done

select {
case <-ctx.Done():
return ctx.Err()
case <-ds.done:
}

err = ds.getError()
if err != nil {
Expand Down Expand Up @@ -2984,15 +2993,15 @@ func (c *Cache) prePullNarInfo(ctx context.Context, hash string) *downloadState

return c.coordinateDownload(
ctx,
ctx,
context.WithoutCancel(ctx),
narInfoJobKey(hash),
hash,
true,
func(ctx context.Context) bool {
return c.narInfoStore.HasNarInfo(ctx, hash)
},
func(ds *downloadState) {
c.pullNarInfo(ctx, hash, ds)
c.pullNarInfo(context.WithoutCancel(ctx), hash, ds)
},
)
}
Expand Down Expand Up @@ -4397,6 +4406,7 @@ func (c *Cache) coordinateDownload(
// Download completed (successfully or with error)
case <-coordCtx.Done():
// Caller context canceled
return ds
}

return ds
Expand Down Expand Up @@ -5135,6 +5145,7 @@ func (c *Cache) selectUpstream(
errC := make(chan error, len(ucs))

ctx, cancel := context.WithCancel(ctx)
defer cancel()

var wg sync.WaitGroup
for _, uc := range ucs {
Expand All @@ -5155,8 +5166,12 @@ func (c *Cache) selectUpstream(

for {
select {
case uc := <-ch:
cancel()
case <-ctx.Done():
return nil, errors.Join(ctx.Err(), errs)
case uc, ok := <-ch:
if !ok {
return nil, errs
}

return uc, errs
case err := <-errC:
Expand Down
67 changes: 67 additions & 0 deletions pkg/cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3092,3 +3092,70 @@ func testBackgroundMigrateNarToChunksAfterCancellation(factory cacheFactory) fun
assert.True(t, hasChunks)
}
}

func TestIssue990_BackgroundJobContextCancellation(t *testing.T) {
t.Parallel()
// 1. Setup test environment
c, db, _, _, rebind, cleanup := setupSQLiteFactory(t)
defer cleanup()

// 2. Setup a slow upstream server
ts := testdata.NewTestServer(t, 40)
defer ts.Close()

// Add a handler that sleeps for 1 second before responding
ts.AddMaybeHandler(func(_ http.ResponseWriter, r *http.Request) bool {
if strings.HasSuffix(r.URL.Path, ".narinfo") {
time.Sleep(1 * time.Second)
// The default handler will handle it after we Return false here if we didn't write anything,
// but we want to actually respond here to be sure.
return false
}

return false
})

uc, err := upstream.New(context.Background(), testhelper.MustParseURL(t, ts.URL), &upstream.Options{
PublicKeys: testdata.PublicKeys(),
})
require.NoError(t, err)
c.AddUpstreamCaches(context.Background(), uc)

// Wait for upstream to be healthy
<-c.GetHealthChecker().Trigger()

// 3. Trigger GetNarInfo with a context that will be canceled
ctx, cancel := context.WithCancel(context.Background())

hash := testdata.Nar1.NarInfoHash

// Start GetNarInfo in a goroutine
errCh := make(chan error, 1)

go func() {
_, err := c.GetNarInfo(ctx, hash)
errCh <- err
}()

// Wait a bit then cancel the context
time.Sleep(200 * time.Millisecond)
cancel()

// Wait for GetNarInfo to return
err = <-errCh
require.ErrorIs(t, err, context.Canceled)

// 4. Wait for the background job to (hopefully) finish
time.Sleep(1500 * time.Millisecond)

// 5. Check if the narinfo is in the database
var count int

err = db.DB().QueryRowContext(context.Background(),
rebind("SELECT COUNT(*) FROM narinfos WHERE hash = ?"), hash).Scan(&count)
require.NoError(t, err)

// In the BUGGY version, count will be 0 because pullNarInfo was canceled.
// In the FIXED version, count should be 1.
assert.Equal(t, 1, count, "NarInfo should be in database even if the original request was canceled")
}
9 changes: 9 additions & 0 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package server

import (
"compress/gzip"
"context"
"encoding/json"
"errors"
"io"
Expand Down Expand Up @@ -344,6 +345,10 @@ func (s *Server) getNarInfo(withBody bool) http.HandlerFunc {
return
}

if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return
}

zerolog.Ctx(r.Context()).
Error().
Err(err).
Expand Down Expand Up @@ -562,6 +567,10 @@ func (s *Server) getNar(withBody bool) http.HandlerFunc {
return
}

if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return
}

zerolog.Ctx(r.Context()).
Error().
Err(err).
Expand Down