diff --git a/go.mod b/go.mod index 082a180e..7bc14d88 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/alecthomas/kong v1.14.0 github.com/golang-jwt/jwt/v5 v5.3.1 github.com/goproxy/goproxy v0.25.0 + github.com/klauspost/compress v1.18.5 github.com/lmittmann/tint v1.1.3 github.com/minio/minio-go/v7 v7.0.98 github.com/open-policy-agent/opa v1.14.1 @@ -40,7 +41,6 @@ require ( github.com/google/uuid v1.6.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect github.com/hexops/gotextdiff v1.0.3 // indirect - github.com/klauspost/compress v1.18.2 // indirect github.com/klauspost/cpuid/v2 v2.2.11 // indirect github.com/klauspost/crc32 v1.3.0 // indirect github.com/lestrrat-go/blackmagic v1.0.4 // indirect diff --git a/go.sum b/go.sum index eeb5da27..2d3ba6f3 100644 --- a/go.sum +++ b/go.sum @@ -76,8 +76,8 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= -github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= -github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= +github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE= +github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.11 h1:0OwqZRYI2rFrjS4kvkDnqJkKHdHaRnCm68/DY4OxRzU= github.com/klauspost/cpuid/v2 v2.2.11/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= diff --git a/internal/snapshot/snapshot.go b/internal/snapshot/snapshot.go index 49e88ef2..a3147e81 100644 --- a/internal/snapshot/snapshot.go +++ 
b/internal/snapshot/snapshot.go @@ -2,6 +2,7 @@ package snapshot import ( + "bufio" "bytes" "context" "fmt" @@ -14,6 +15,7 @@ import ( "time" "github.com/alecthomas/errors" + "github.com/klauspost/compress/zstd" "github.com/block/cachew/internal/cache" ) @@ -61,38 +63,45 @@ func Create(ctx context.Context, remote cache.Cache, key cache.Key, directory st tarArgs = append(tarArgs, ".") tarCmd := exec.CommandContext(ctx, "tar", tarArgs...) - zstdCmd := exec.CommandContext(ctx, "zstd", "-c", fmt.Sprintf("-T%d", threads)) //nolint:gosec // threads is a validated integer, not user input tarStdout, err := tarCmd.StdoutPipe() if err != nil { return errors.Join(errors.Wrap(err, "failed to create tar stdout pipe"), wc.Close()) } - var tarStderr, zstdStderr bytes.Buffer + var tarStderr bytes.Buffer tarCmd.Stderr = &tarStderr - zstdCmd.Stdin = tarStdout - zstdCmd.Stdout = wc - zstdCmd.Stderr = &zstdStderr - if err := tarCmd.Start(); err != nil { return errors.Join(errors.Wrap(err, "failed to start tar"), wc.Close()) } - if err := zstdCmd.Start(); err != nil { - return errors.Join(errors.Wrap(err, "failed to start zstd"), tarCmd.Wait(), wc.Close()) + // Compression uses the in-process klauspost/compress/zstd encoder with up to + // `threads` goroutines, producing parallel frames that can be decompressed in parallel. + // This eliminates the zstd subprocess (one fewer fork/exec, one fewer + // kernel pipe) and removes the runtime dependency on the zstd binary. 
+ enc, err := zstd.NewWriter(wc, + zstd.WithEncoderConcurrency(threads), + zstd.WithWindowSize(zstd.MaxWindowSize)) + if err != nil { + return errors.Join(errors.Wrap(err, "failed to create zstd encoder"), tarCmd.Wait(), wc.Close()) } + _, copyErr := io.Copy(enc, tarStdout) + tarStdout.Close() //nolint:errcheck,gosec // best-effort; tar will exit via SIGPIPE + encErr := enc.Close() tarErr := tarCmd.Wait() - zstdErr := zstdCmd.Wait() closeErr := wc.Close() var errs []error if tarErr != nil { errs = append(errs, errors.Errorf("tar failed: %w: %s", tarErr, tarStderr.String())) } - if zstdErr != nil { - errs = append(errs, errors.Errorf("zstd failed: %w: %s", zstdErr, zstdStderr.String())) + if copyErr != nil { + errs = append(errs, errors.Wrap(copyErr, "failed to copy tar output to zstd encoder")) + } + if encErr != nil { + errs = append(errs, errors.Wrap(encErr, "failed to close zstd encoder")) } if closeErr != nil { errs = append(errs, errors.Wrap(closeErr, "failed to close writer")) @@ -123,37 +132,40 @@ func StreamTo(ctx context.Context, w io.Writer, directory string, excludePattern tarArgs = append(tarArgs, ".") tarCmd := exec.CommandContext(ctx, "tar", tarArgs...) 
- zstdCmd := exec.CommandContext(ctx, "zstd", "-c", fmt.Sprintf("-T%d", threads)) //nolint:gosec // threads is a validated integer, not user input tarStdout, err := tarCmd.StdoutPipe() if err != nil { return errors.Wrap(err, "failed to create tar stdout pipe") } - var tarStderr, zstdStderr bytes.Buffer + var tarStderr bytes.Buffer tarCmd.Stderr = &tarStderr - zstdCmd.Stdin = tarStdout - zstdCmd.Stdout = w - zstdCmd.Stderr = &zstdStderr - if err := tarCmd.Start(); err != nil { return errors.Wrap(err, "failed to start tar") } - if err := zstdCmd.Start(); err != nil { - return errors.Join(errors.Wrap(err, "failed to start zstd"), tarCmd.Wait()) + enc, err := zstd.NewWriter(w, + zstd.WithEncoderConcurrency(threads), + zstd.WithWindowSize(zstd.MaxWindowSize)) + if err != nil { + return errors.Join(errors.Wrap(err, "failed to create zstd encoder"), tarCmd.Wait()) } + _, copyErr := io.Copy(enc, tarStdout) + tarStdout.Close() //nolint:errcheck,gosec // best-effort; tar will exit via SIGPIPE + encErr := enc.Close() tarErr := tarCmd.Wait() - zstdErr := zstdCmd.Wait() var errs []error if tarErr != nil { errs = append(errs, errors.Errorf("tar failed: %w: %s", tarErr, tarStderr.String())) } - if zstdErr != nil { - errs = append(errs, errors.Errorf("zstd failed: %w: %s", zstdErr, zstdStderr.String())) + if copyErr != nil { + errs = append(errs, errors.Wrap(copyErr, "failed to copy tar output to zstd encoder")) + } + if encErr != nil { + errs = append(errs, errors.Wrap(encErr, "failed to close zstd encoder")) } return errors.Join(errs...) 
@@ -187,39 +199,28 @@ func Extract(ctx context.Context, r io.Reader, directory string, threads int) er return errors.Wrap(err, "failed to create target directory") } - zstdCmd := exec.CommandContext(ctx, "zstd", "-dc", fmt.Sprintf("-T%d", threads)) //nolint:gosec // threads is a validated integer, not user input - tarCmd := exec.CommandContext(ctx, "tar", "-xpf", "-", "-C", directory) - - zstdCmd.Stdin = r - zstdStdout, err := zstdCmd.StdoutPipe() + // Decompression uses the in-process Go zstd decoder to avoid subprocess IPC + // overhead (no kernel pipes, no process spawning, no goroutine synchronization + // across process boundaries). + // Buffer between the source reader and the zstd decoder. The reader may be an + // io.Pipe (zero-copy, one Read per Write), so without buffering each small + // decoder read stalls the upstream goroutine. 8 MiB lets the decoder read + // ahead while the source fills the next chunk. + dec, err := zstd.NewReader(bufio.NewReaderSize(r, 8<<20), zstd.WithDecoderConcurrency(threads)) if err != nil { - return errors.Wrap(err, "failed to create zstd stdout pipe") + return errors.Wrap(err, "failed to create zstd decoder") } + defer dec.Close() - var zstdStderr, tarStderr bytes.Buffer - zstdCmd.Stderr = &zstdStderr + tarCmd := exec.CommandContext(ctx, "tar", "-xpf", "-", "-C", directory) + tarCmd.Stdin = dec - tarCmd.Stdin = zstdStdout + var tarStderr bytes.Buffer tarCmd.Stderr = &tarStderr - if err := zstdCmd.Start(); err != nil { - return errors.Wrap(err, "failed to start zstd") + if err := tarCmd.Run(); err != nil { + return errors.Errorf("tar failed: %w: %s", err, tarStderr.String()) } - if err := tarCmd.Start(); err != nil { - return errors.Join(errors.Wrap(err, "failed to start tar"), zstdCmd.Wait()) - } - - zstdErr := zstdCmd.Wait() - tarErr := tarCmd.Wait() - - var errs []error - if zstdErr != nil { - errs = append(errs, errors.Errorf("zstd failed: %w: %s", zstdErr, zstdStderr.String())) - } - if tarErr != nil { - errs = 
append(errs, errors.Errorf("tar failed: %w: %s", tarErr, tarStderr.String())) - } - - return errors.Join(errs...) + return nil }