diff --git a/internal/snapshot/pipe_test.go b/internal/snapshot/pipe_test.go new file mode 100644 index 00000000..5a463163 --- /dev/null +++ b/internal/snapshot/pipe_test.go @@ -0,0 +1,90 @@ +package snapshot_test + +import ( + "os" + "os/exec" + "testing" + "time" + + "github.com/alecthomas/assert/v2" +) + +// TestPipeLifecycleNoDeadlock verifies the pipe pattern used by Create, +// StreamTo, and Extract does not deadlock when the downstream process exits +// while the upstream is still producing data. +// +// Background: exec.Cmd.StdoutPipe() retains the pipe read end in the parent +// until Wait() runs closeAfterWait. If the downstream exits early, the +// upstream cannot receive SIGPIPE (the parent still holds the read end), so +// it blocks on pipe write and Wait() deadlocks. +// +// The fix (used in Create/StreamTo/Extract): use os.Pipe() manually and close +// both ends in the parent immediately after starting child processes. This +// ensures that when the downstream exits, the upstream receives SIGPIPE. +func TestPipeLifecycleNoDeadlock(t *testing.T) { + t.Run("StdoutPipeDeadlocks", func(t *testing.T) { + // Demonstrates the broken pattern: StdoutPipe holds the read end, + // preventing SIGPIPE delivery to the upstream. + upstream := exec.Command("yes") + downstream := exec.Command("head", "-c", "100") + + pipeRead, _ := upstream.StdoutPipe() + downstream.Stdin = pipeRead + + assert.NoError(t, upstream.Start()) + assert.NoError(t, downstream.Start()) + + done := make(chan struct{}) + go func() { + _ = upstream.Wait() + _ = downstream.Wait() + close(done) + }() + + select { + case <-done: + t.Fatal("expected StdoutPipe pattern to deadlock, but it completed") + case <-time.After(2 * time.Second): + // Deadlock confirmed — clean up. + upstream.Process.Kill() //nolint:errcheck + downstream.Process.Kill() //nolint:errcheck + <-done + } + }) + + t.Run("ManualPipeWorks", func(t *testing.T) { + // The fixed pattern: parent closes both pipe ends after starting + // children, so the upstream gets SIGPIPE when downstream exits. + upstream := exec.Command("yes") + downstream := exec.Command("head", "-c", "100") + + pr, pw, err := os.Pipe() + assert.NoError(t, err) + + upstream.Stdout = pw + downstream.Stdin = pr + + assert.NoError(t, upstream.Start()) + pw.Close() //nolint:errcheck,gosec + + assert.NoError(t, downstream.Start()) + pr.Close() //nolint:errcheck,gosec + + done := make(chan struct{}) + go func() { + _ = upstream.Wait() + _ = downstream.Wait() + close(done) + }() + + select { + case <-done: + // OK: no deadlock. + case <-time.After(5 * time.Second): + upstream.Process.Kill() //nolint:errcheck + downstream.Process.Kill() //nolint:errcheck + <-done + t.Fatal("manual pipe pattern deadlocked unexpectedly") + } + }) +} diff --git a/internal/snapshot/snapshot.go b/internal/snapshot/snapshot.go index 49e88ef2..80523b1a 100644 --- a/internal/snapshot/snapshot.go +++ b/internal/snapshot/snapshot.go @@ -63,25 +63,34 @@ func Create(ctx context.Context, remote cache.Cache, key cache.Key, directory st tarCmd := exec.CommandContext(ctx, "tar", tarArgs...) zstdCmd := exec.CommandContext(ctx, "zstd", "-c", fmt.Sprintf("-T%d", threads)) //nolint:gosec // threads is a validated integer, not user input - tarStdout, err := tarCmd.StdoutPipe() + // Use manual pipe so we can close both ends in the parent after starting + // children. This prevents deadlock if zstd exits while tar is still writing: + // closing the read end ensures tar receives SIGPIPE instead of blocking. + pr, pw, err := os.Pipe() if err != nil { - return errors.Join(errors.Wrap(err, "failed to create tar stdout pipe"), wc.Close()) + return errors.Join(errors.Wrap(err, "failed to create pipe"), wc.Close()) } var tarStderr, zstdStderr bytes.Buffer + tarCmd.Stdout = pw tarCmd.Stderr = &tarStderr - zstdCmd.Stdin = tarStdout + zstdCmd.Stdin = pr zstdCmd.Stdout = wc zstdCmd.Stderr = &zstdStderr if err := tarCmd.Start(); err != nil { + pw.Close() //nolint:errcheck,gosec // best-effort cleanup + pr.Close() //nolint:errcheck,gosec // best-effort cleanup return errors.Join(errors.Wrap(err, "failed to start tar"), wc.Close()) } + pw.Close() //nolint:errcheck,gosec // parent no longer needs write end; tar holds its own copy if err := zstdCmd.Start(); err != nil { + pr.Close() //nolint:errcheck,gosec // let tar receive SIGPIPE so it exits return errors.Join(errors.Wrap(err, "failed to start zstd"), tarCmd.Wait(), wc.Close()) } + pr.Close() //nolint:errcheck,gosec // parent no longer needs read end; if zstd dies, tar gets SIGPIPE tarErr := tarCmd.Wait() zstdErr := zstdCmd.Wait() @@ -125,25 +134,31 @@ func StreamTo(ctx context.Context, w io.Writer, directory string, excludePattern tarCmd := exec.CommandContext(ctx, "tar", tarArgs...) zstdCmd := exec.CommandContext(ctx, "zstd", "-c", fmt.Sprintf("-T%d", threads)) //nolint:gosec // threads is a validated integer, not user input - tarStdout, err := tarCmd.StdoutPipe() + pr, pw, err := os.Pipe() if err != nil { - return errors.Wrap(err, "failed to create tar stdout pipe") + return errors.Wrap(err, "failed to create pipe") } var tarStderr, zstdStderr bytes.Buffer + tarCmd.Stdout = pw tarCmd.Stderr = &tarStderr - zstdCmd.Stdin = tarStdout + zstdCmd.Stdin = pr zstdCmd.Stdout = w zstdCmd.Stderr = &zstdStderr if err := tarCmd.Start(); err != nil { + pw.Close() //nolint:errcheck,gosec // best-effort cleanup + pr.Close() //nolint:errcheck,gosec // best-effort cleanup return errors.Wrap(err, "failed to start tar") } + pw.Close() //nolint:errcheck,gosec // parent no longer needs write end; tar holds its own copy if err := zstdCmd.Start(); err != nil { + pr.Close() //nolint:errcheck,gosec // let tar receive SIGPIPE so it exits return errors.Join(errors.Wrap(err, "failed to start zstd"), tarCmd.Wait()) } + pr.Close() //nolint:errcheck,gosec // parent no longer needs read end; if zstd dies, tar gets SIGPIPE tarErr := tarCmd.Wait() zstdErr := zstdCmd.Wait() @@ -190,25 +205,31 @@ func Extract(ctx context.Context, r io.Reader, directory string, threads int) er zstdCmd := exec.CommandContext(ctx, "zstd", "-dc", fmt.Sprintf("-T%d", threads)) //nolint:gosec // threads is a validated integer, not user input tarCmd := exec.CommandContext(ctx, "tar", "-xpf", "-", "-C", directory) - zstdCmd.Stdin = r - zstdStdout, err := zstdCmd.StdoutPipe() + pr, pw, err := os.Pipe() if err != nil { - return errors.Wrap(err, "failed to create zstd stdout pipe") + return errors.Wrap(err, "failed to create pipe") } var zstdStderr, tarStderr bytes.Buffer + zstdCmd.Stdin = r + zstdCmd.Stdout = pw zstdCmd.Stderr = &zstdStderr - tarCmd.Stdin = zstdStdout + tarCmd.Stdin = pr tarCmd.Stderr = &tarStderr if err := zstdCmd.Start(); err != nil { + pw.Close() //nolint:errcheck,gosec // best-effort cleanup + pr.Close() //nolint:errcheck,gosec // best-effort cleanup return errors.Wrap(err, "failed to start zstd") } + pw.Close() //nolint:errcheck,gosec // parent no longer needs write end; zstd holds its own copy if err := tarCmd.Start(); err != nil { + pr.Close() //nolint:errcheck,gosec // let zstd receive SIGPIPE so it exits return errors.Join(errors.Wrap(err, "failed to start tar"), zstdCmd.Wait()) } + pr.Close() //nolint:errcheck,gosec // parent no longer needs read end; if tar dies, zstd gets SIGPIPE zstdErr := zstdCmd.Wait() tarErr := tarCmd.Wait()