Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions internal/snapshot/pipe_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package snapshot_test

import (
"os"
"os/exec"
"testing"
"time"

"github.com/alecthomas/assert/v2"
)

// TestPipeLifecycleNoDeadlock verifies the pipe pattern used by Create,
// StreamTo, and Extract does not deadlock when the downstream process exits
// while the upstream is still producing data.
//
// Background: exec.Cmd.StdoutPipe() retains the pipe read end in the parent
// until Wait() runs closeAfterWait. If the downstream exits early, the
// upstream cannot receive SIGPIPE (the parent still holds the read end), so
// it blocks on pipe write and Wait() deadlocks.
//
// The fix (used in Create/StreamTo/Extract): use os.Pipe() manually and close
// both ends in the parent immediately after starting child processes. This
// ensures that when the downstream exits, the upstream receives SIGPIPE.
func TestPipeLifecycleNoDeadlock(t *testing.T) {
t.Run("StdoutPipeDeadlocks", func(t *testing.T) {
// Demonstrates the broken pattern: StdoutPipe holds the read end,
// preventing SIGPIPE delivery to the upstream.
upstream := exec.Command("yes")
downstream := exec.Command("head", "-c", "100")

pipeRead, _ := upstream.StdoutPipe()
downstream.Stdin = pipeRead

assert.NoError(t, upstream.Start())
assert.NoError(t, downstream.Start())

done := make(chan struct{})
go func() {
_ = upstream.Wait()
_ = downstream.Wait()
close(done)
}()

select {
case <-done:
t.Fatal("expected StdoutPipe pattern to deadlock, but it completed")
case <-time.After(2 * time.Second):
// Deadlock confirmed — clean up.
upstream.Process.Kill() //nolint:errcheck
downstream.Process.Kill() //nolint:errcheck
<-done
}
})

t.Run("ManualPipeWorks", func(t *testing.T) {
// The fixed pattern: parent closes both pipe ends after starting
// children, so the upstream gets SIGPIPE when downstream exits.
upstream := exec.Command("yes")
downstream := exec.Command("head", "-c", "100")

pr, pw, err := os.Pipe()
assert.NoError(t, err)

upstream.Stdout = pw
downstream.Stdin = pr

assert.NoError(t, upstream.Start())
pw.Close() //nolint:errcheck,gosec

assert.NoError(t, downstream.Start())
pr.Close() //nolint:errcheck,gosec

done := make(chan struct{})
go func() {
_ = upstream.Wait()
_ = downstream.Wait()
close(done)
}()

select {
case <-done:
// OK: no deadlock.
case <-time.After(5 * time.Second):
upstream.Process.Kill() //nolint:errcheck
downstream.Process.Kill() //nolint:errcheck
<-done
t.Fatal("manual pipe pattern deadlocked unexpectedly")
}
})
}
41 changes: 31 additions & 10 deletions internal/snapshot/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,25 +63,34 @@ func Create(ctx context.Context, remote cache.Cache, key cache.Key, directory st
tarCmd := exec.CommandContext(ctx, "tar", tarArgs...)
zstdCmd := exec.CommandContext(ctx, "zstd", "-c", fmt.Sprintf("-T%d", threads)) //nolint:gosec // threads is a validated integer, not user input

tarStdout, err := tarCmd.StdoutPipe()
// Use manual pipe so we can close both ends in the parent after starting
// children. This prevents deadlock if zstd exits while tar is still writing:
// closing the read end ensures tar receives SIGPIPE instead of blocking.
pr, pw, err := os.Pipe()
if err != nil {
return errors.Join(errors.Wrap(err, "failed to create tar stdout pipe"), wc.Close())
return errors.Join(errors.Wrap(err, "failed to create pipe"), wc.Close())
}

var tarStderr, zstdStderr bytes.Buffer
tarCmd.Stdout = pw
tarCmd.Stderr = &tarStderr

zstdCmd.Stdin = tarStdout
zstdCmd.Stdin = pr
zstdCmd.Stdout = wc
zstdCmd.Stderr = &zstdStderr

if err := tarCmd.Start(); err != nil {
pw.Close() //nolint:errcheck,gosec // best-effort cleanup
pr.Close() //nolint:errcheck,gosec // best-effort cleanup
return errors.Join(errors.Wrap(err, "failed to start tar"), wc.Close())
}
pw.Close() //nolint:errcheck,gosec // parent no longer needs write end; tar holds its own copy

if err := zstdCmd.Start(); err != nil {
pr.Close() //nolint:errcheck,gosec // let tar receive SIGPIPE so it exits
return errors.Join(errors.Wrap(err, "failed to start zstd"), tarCmd.Wait(), wc.Close())
}
pr.Close() //nolint:errcheck,gosec // parent no longer needs read end; if zstd dies, tar gets SIGPIPE

tarErr := tarCmd.Wait()
zstdErr := zstdCmd.Wait()
Expand Down Expand Up @@ -125,25 +134,31 @@ func StreamTo(ctx context.Context, w io.Writer, directory string, excludePattern
tarCmd := exec.CommandContext(ctx, "tar", tarArgs...)
zstdCmd := exec.CommandContext(ctx, "zstd", "-c", fmt.Sprintf("-T%d", threads)) //nolint:gosec // threads is a validated integer, not user input

tarStdout, err := tarCmd.StdoutPipe()
pr, pw, err := os.Pipe()
if err != nil {
return errors.Wrap(err, "failed to create tar stdout pipe")
return errors.Wrap(err, "failed to create pipe")
}

var tarStderr, zstdStderr bytes.Buffer
tarCmd.Stdout = pw
tarCmd.Stderr = &tarStderr

zstdCmd.Stdin = tarStdout
zstdCmd.Stdin = pr
zstdCmd.Stdout = w
zstdCmd.Stderr = &zstdStderr

if err := tarCmd.Start(); err != nil {
pw.Close() //nolint:errcheck,gosec // best-effort cleanup
pr.Close() //nolint:errcheck,gosec // best-effort cleanup
return errors.Wrap(err, "failed to start tar")
}
pw.Close() //nolint:errcheck,gosec // parent no longer needs write end; tar holds its own copy

if err := zstdCmd.Start(); err != nil {
pr.Close() //nolint:errcheck,gosec // let tar receive SIGPIPE so it exits
return errors.Join(errors.Wrap(err, "failed to start zstd"), tarCmd.Wait())
}
pr.Close() //nolint:errcheck,gosec // parent no longer needs read end; if zstd dies, tar gets SIGPIPE

tarErr := tarCmd.Wait()
zstdErr := zstdCmd.Wait()
Expand Down Expand Up @@ -190,25 +205,31 @@ func Extract(ctx context.Context, r io.Reader, directory string, threads int) er
zstdCmd := exec.CommandContext(ctx, "zstd", "-dc", fmt.Sprintf("-T%d", threads)) //nolint:gosec // threads is a validated integer, not user input
tarCmd := exec.CommandContext(ctx, "tar", "-xpf", "-", "-C", directory)

zstdCmd.Stdin = r
zstdStdout, err := zstdCmd.StdoutPipe()
pr, pw, err := os.Pipe()
if err != nil {
return errors.Wrap(err, "failed to create zstd stdout pipe")
return errors.Wrap(err, "failed to create pipe")
}

var zstdStderr, tarStderr bytes.Buffer
zstdCmd.Stdin = r
zstdCmd.Stdout = pw
zstdCmd.Stderr = &zstdStderr

tarCmd.Stdin = zstdStdout
tarCmd.Stdin = pr
tarCmd.Stderr = &tarStderr

if err := zstdCmd.Start(); err != nil {
pw.Close() //nolint:errcheck,gosec // best-effort cleanup
pr.Close() //nolint:errcheck,gosec // best-effort cleanup
return errors.Wrap(err, "failed to start zstd")
}
pw.Close() //nolint:errcheck,gosec // parent no longer needs write end; zstd holds its own copy

if err := tarCmd.Start(); err != nil {
pr.Close() //nolint:errcheck,gosec // let zstd receive SIGPIPE so it exits
return errors.Join(errors.Wrap(err, "failed to start tar"), zstdCmd.Wait())
}
pr.Close() //nolint:errcheck,gosec // parent no longer needs read end; if tar dies, zstd gets SIGPIPE

zstdErr := zstdCmd.Wait()
tarErr := tarCmd.Wait()
Expand Down