From ff9c82d016395da2e879d84e44c30b8977dd1c9a Mon Sep 17 00:00:00 2001 From: Wael Nasreddine Date: Wed, 11 Feb 2026 22:06:20 -0800 Subject: [PATCH] feat: implement ZSTD compression pooling and integrate across project This change introduces a new pkg/zstd package that provides a pool of ZSTD encoders and decoders. This optimization reduces allocation overhead during high-concurrency compression and decompression operations. The implementation: - Uses sync.Pool for managing zstd.Encoder and zstd.Decoder instances. - Provides PooledWriter and PooledReader wrappers for easy resource management. - Integrates the pool into the local chunk store for chunk compression. - Integrates the pool into the server middleware for HTTP response compression. - Updates NAR reading and test utilities to use the centralized compression package. This was needed to improve the efficiency of the cache server when handling many small and large NAR files simultaneously, minimizing garbage collection pressure. 
--- pkg/cache/cache.go | 2 +- pkg/cache/cache_internal_test.go | 11 +- pkg/cache/cache_test.go | 8 +- pkg/cache/cdc_test.go | 2 +- pkg/cache/export_test.go | 15 +- pkg/nar/reader.go | 7 +- pkg/nar/reader_test.go | 9 +- pkg/server/server.go | 23 +-- pkg/server/server_test.go | 10 +- pkg/storage/chunk/local.go | 41 ++-- pkg/storage/chunk/local_test.go | 7 +- pkg/storage/chunk/s3.go | 50 ++--- pkg/storage/chunk/s3_test.go | 66 ++++-- pkg/zstd/README.md | 333 +++++++++++++++++++++++++++++++ pkg/zstd/zstd.go | 193 ++++++++++++++++++ pkg/zstd/zstd_test.go | 313 +++++++++++++++++++++++++++++ testdata/server.go | 16 +- testdata/server_test.go | 8 +- 18 files changed, 971 insertions(+), 143 deletions(-) create mode 100644 pkg/zstd/README.md create mode 100644 pkg/zstd/zstd.go create mode 100644 pkg/zstd/zstd_test.go diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index 65424bdc..6cc8fd99 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -4081,7 +4081,7 @@ func zstdMutator( return func(r *http.Request) { zerolog.Ctx(ctx). Debug(). 
- Msg("narinfo compress is none will set Accept-Encoding to zstd") + Msg("narinfo compression is none will set Accept-Encoding to zstd") r.Header.Set("Accept-Encoding", "zstd") diff --git a/pkg/cache/cache_internal_test.go b/pkg/cache/cache_internal_test.go index a6939fd2..953e9ad4 100644 --- a/pkg/cache/cache_internal_test.go +++ b/pkg/cache/cache_internal_test.go @@ -14,7 +14,6 @@ import ( "testing" "time" - "github.com/klauspost/compress/zstd" "github.com/nix-community/go-nix/pkg/narinfo" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" @@ -26,6 +25,7 @@ import ( "github.com/kalbasit/ncps/pkg/database" "github.com/kalbasit/ncps/pkg/nar" "github.com/kalbasit/ncps/pkg/storage/local" + "github.com/kalbasit/ncps/pkg/zstd" "github.com/kalbasit/ncps/testdata" "github.com/kalbasit/ncps/testhelper" ) @@ -311,13 +311,14 @@ func testRunLRU(factory cacheFactory) func(*testing.T) { narNone := nar.CompressionTypeNone for _, entry := range entries { if entry.NarCompression == narNone { - encoder, _ := zstd.NewWriter(nil) + enc := zstd.GetWriter() + defer zstd.PutWriter(enc) var compressed bytes.Buffer - encoder.Reset(&compressed) - _, err = encoder.Write([]byte(entry.NarText)) + enc.Reset(&compressed) + _, err = enc.Write([]byte(entry.NarText)) require.NoError(t, err) - err = encoder.Close() + err = enc.Close() require.NoError(t, err) zstdSizes[entry.NarInfoHash] = uint64(compressed.Len()) //nolint:gosec diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go index 996ccc2b..f42cf96f 100644 --- a/pkg/cache/cache_test.go +++ b/pkg/cache/cache_test.go @@ -15,7 +15,6 @@ import ( "testing" "time" - "github.com/klauspost/compress/zstd" "github.com/nix-community/go-nix/pkg/narinfo" "github.com/nix-community/go-nix/pkg/narinfo/signature" "github.com/rs/zerolog" @@ -31,6 +30,7 @@ import ( "github.com/kalbasit/ncps/pkg/storage" "github.com/kalbasit/ncps/pkg/storage/chunk" "github.com/kalbasit/ncps/pkg/storage/local" + "github.com/kalbasit/ncps/pkg/zstd" 
"github.com/kalbasit/ncps/testdata" "github.com/kalbasit/ncps/testhelper" @@ -648,10 +648,10 @@ func testGetNarInfo(factory cacheFactory) func(*testing.T) { require.NoError(t, err) if assert.NotEqual(t, narEntry.NarText, string(body), "narText should be stored compressed in the store") { - decoder, err := zstd.NewReader(nil) - require.NoError(t, err) + dec := zstd.GetReader() + defer zstd.PutReader(dec) - plain, err := decoder.DecodeAll(body, []byte{}) + plain, err := dec.DecodeAll(body, []byte{}) require.NoError(t, err) assert.Equal(t, narEntry.NarText, string(plain)) diff --git a/pkg/cache/cdc_test.go b/pkg/cache/cdc_test.go index 72b2982c..660f4ae9 100644 --- a/pkg/cache/cdc_test.go +++ b/pkg/cache/cdc_test.go @@ -334,7 +334,7 @@ func testCDCChunksAreCompressed(factory cacheFactory) func(*testing.T) { // Use highly compressible data (repeated bytes) content := strings.Repeat("compressible", 1000) - nu := nar.URL{Hash: "testnar-compress", Compression: nar.CompressionTypeNone} + nu := nar.URL{Hash: "testnar-zstd", Compression: nar.CompressionTypeNone} r := io.NopCloser(strings.NewReader(content)) err = c.PutNar(ctx, nu, r) diff --git a/pkg/cache/export_test.go b/pkg/cache/export_test.go index d4912eed..9ebb4ff8 100644 --- a/pkg/cache/export_test.go +++ b/pkg/cache/export_test.go @@ -6,11 +6,13 @@ import ( "strings" "testing" - "github.com/klauspost/compress/zstd" "github.com/nix-community/go-nix/pkg/narinfo" "github.com/nix-community/go-nix/pkg/narinfo/signature" "github.com/nix-community/go-nix/pkg/nixhash" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/kalbasit/ncps/pkg/zstd" ) // CheckAndFixNarInfo is a test-only export of the unexported checkAndFixNarInfo method. 
@@ -33,13 +35,14 @@ func compressZstd(t *testing.T, data string) string { var buf strings.Builder - enc, err := zstd.NewWriter(&buf) - require.NoError(t, err) - _, err = io.WriteString(enc, data) - require.NoError(t, err) - err = enc.Close() + pw := zstd.NewPooledWriter(&buf) + + _, err := io.WriteString(pw, data) require.NoError(t, err) + err = pw.Close() + assert.NoError(t, err) //nolint:testifylint + return buf.String() } diff --git a/pkg/nar/reader.go b/pkg/nar/reader.go index a4627028..5e59d7af 100644 --- a/pkg/nar/reader.go +++ b/pkg/nar/reader.go @@ -7,10 +7,11 @@ import ( "io" "github.com/andybalholm/brotli" - "github.com/klauspost/compress/zstd" "github.com/pierrec/lz4/v4" "github.com/sorairolake/lzip-go" "github.com/ulikunitz/xz" + + "github.com/kalbasit/ncps/pkg/zstd" ) // ErrUnsupportedCompressionType is returned when an unsupported compression type is encountered. @@ -31,12 +32,12 @@ func DecompressReader(r io.Reader, comp CompressionType) (io.ReadCloser, error) return io.NopCloser(bzip2.NewReader(r)), nil case CompressionTypeZstd: - zr, err := zstd.NewReader(r) + pr, err := zstd.NewPooledReader(r) if err != nil { return nil, fmt.Errorf("failed to create zstd reader: %w", err) } - return zr.IOReadCloser(), nil + return pr, nil case CompressionTypeLz4: return io.NopCloser(lz4.NewReader(r)), nil diff --git a/pkg/nar/reader_test.go b/pkg/nar/reader_test.go index 25aeb2d8..1fdd6532 100644 --- a/pkg/nar/reader_test.go +++ b/pkg/nar/reader_test.go @@ -7,7 +7,6 @@ import ( "testing" "github.com/andybalholm/brotli" - "github.com/klauspost/compress/zstd" "github.com/pierrec/lz4/v4" "github.com/sorairolake/lzip-go" "github.com/stretchr/testify/assert" @@ -15,6 +14,7 @@ import ( "github.com/ulikunitz/xz" "github.com/kalbasit/ncps/pkg/nar" + "github.com/kalbasit/ncps/pkg/zstd" ) func TestDecompressReader(t *testing.T) { @@ -70,11 +70,10 @@ func TestDecompressReader(t *testing.T) { getInput: func(t *testing.T) io.Reader { var buf bytes.Buffer - zw, err := 
zstd.NewWriter(&buf) + pw := zstd.NewPooledWriter(&buf) + _, err := pw.Write(content) require.NoError(t, err) - _, err = zw.Write(content) - require.NoError(t, err) - require.NoError(t, zw.Close()) + require.NoError(t, pw.Close()) return &buf }, diff --git a/pkg/server/server.go b/pkg/server/server.go index 5edb9c54..6e5357da 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -8,12 +8,10 @@ import ( "runtime/debug" "strconv" "strings" - "sync" "time" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" - "github.com/klauspost/compress/zstd" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/riandyrn/otelchi" "github.com/rs/zerolog" @@ -31,6 +29,7 @@ import ( "github.com/kalbasit/ncps/pkg/nar" "github.com/kalbasit/ncps/pkg/narinfo" "github.com/kalbasit/ncps/pkg/storage" + "github.com/kalbasit/ncps/pkg/zstd" ) const ( @@ -60,17 +59,6 @@ var tracer trace.Tracer //nolint:gochecknoglobals var prometheusGatherer promclient.Gatherer -//nolint:gochecknoglobals -var zstdWriterPool = sync.Pool{ - New: func() interface{} { - // Not providing any options will use the default compression level. - // The error is ignored as NewWriter(nil) with no options doesn't error. 
- enc, _ := zstd.NewWriter(nil) - - return enc - }, -} - //nolint:gochecknoinits func init() { tracer = otel.Tracer(otelPackageName) @@ -591,16 +579,13 @@ func (s *Server) getNar(withBody bool) http.HandlerFunc { var out io.Writer = w if useZstd { - enc := zstdWriterPool.Get().(*zstd.Encoder) - enc.Reset(w) - out = enc + pw := zstd.NewPooledWriter(w) + out = pw defer func() { - if err := enc.Close(); err != nil { + if err := pw.Close(); err != nil { zerolog.Ctx(r.Context()).Error().Err(err).Msg("failed to close zstd writer") } - - zstdWriterPool.Put(enc) }() } diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index 374b7da6..3cabe14b 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -13,7 +13,6 @@ import ( "testing" "time" - "github.com/klauspost/compress/zstd" "github.com/nix-community/go-nix/pkg/narinfo" "github.com/nix-community/go-nix/pkg/narinfo/signature" "github.com/rs/zerolog" @@ -30,6 +29,7 @@ import ( "github.com/kalbasit/ncps/pkg/server" "github.com/kalbasit/ncps/pkg/storage" "github.com/kalbasit/ncps/pkg/storage/local" + "github.com/kalbasit/ncps/pkg/zstd" "github.com/kalbasit/ncps/testdata" "github.com/kalbasit/ncps/testhelper" ) @@ -949,13 +949,13 @@ func TestGetNar_ZstdCompression(t *testing.T) { assert.Equal(t, "application/x-nix-nar", resp.Header.Get("Content-Type")) assert.Empty(t, resp.Header.Get("Content-Length")) - // 3. Decompress the body and verify content - dec, err := zstd.NewReader(resp.Body) + // 3. 
DecompressReader the body and verify content + pr, err := zstd.NewPooledReader(resp.Body) require.NoError(t, err) - defer dec.Close() + defer pr.Close() - decompressed, err := io.ReadAll(dec) + decompressed, err := io.ReadAll(pr) require.NoError(t, err) assert.Equal(t, narData, string(decompressed)) } diff --git a/pkg/storage/chunk/local.go b/pkg/storage/chunk/local.go index dbbd1681..076466e7 100644 --- a/pkg/storage/chunk/local.go +++ b/pkg/storage/chunk/local.go @@ -7,17 +7,17 @@ import ( "os" "path/filepath" - "github.com/klauspost/compress/zstd" + "github.com/kalbasit/ncps/pkg/zstd" ) -// localReadCloser wraps a zstd decoder and file to properly close both on Close(). +// localReadCloser wraps a pooled zstd reader and file to properly close both on Close(). type localReadCloser struct { - *zstd.Decoder + *zstd.PooledReader file *os.File } func (r *localReadCloser) Close() error { - r.Decoder.Close() + _ = r.PooledReader.Close() return r.file.Close() } @@ -25,34 +25,15 @@ func (r *localReadCloser) Close() error { // localStore implements Store for local filesystem. type localStore struct { baseDir string - encoder *zstd.Encoder - decoder *zstd.Decoder } // NewLocalStore returns a new local chunk store. 
func NewLocalStore(baseDir string) (Store, error) { - encoder, err := zstd.NewWriter(nil) - if err != nil { - return nil, fmt.Errorf("failed to create zstd encoder: %w", err) - } - - decoder, err := zstd.NewReader(nil) - if err != nil { - encoder.Close() - - return nil, fmt.Errorf("failed to create zstd decoder: %w", err) - } - s := &localStore{ baseDir: baseDir, - encoder: encoder, - decoder: decoder, } // Ensure base directory exists if err := os.MkdirAll(s.storeDir(), 0o755); err != nil { - encoder.Close() - decoder.Close() - return nil, fmt.Errorf("failed to create chunk store directory: %w", err) } @@ -94,15 +75,15 @@ func (s *localStore) GetChunk(_ context.Context, hash string) (io.ReadCloser, er return nil, err } - // Create a new decoder for this specific file - decoder, err := zstd.NewReader(f) + // Use pooled reader instead of creating new instance + pr, err := zstd.NewPooledReader(f) if err != nil { f.Close() - return nil, fmt.Errorf("failed to create zstd decoder: %w", err) + return nil, fmt.Errorf("failed to create zstd reader: %w", err) } - return &localReadCloser{decoder, f}, nil + return &localReadCloser{pr, f}, nil } func (s *localStore) PutChunk(_ context.Context, hash string, data []byte) (bool, int64, error) { @@ -114,8 +95,12 @@ func (s *localStore) PutChunk(_ context.Context, hash string, data []byte) (bool return false, 0, err } + // Use pooled encoder + enc := zstd.GetWriter() + defer zstd.PutWriter(enc) + // Compress data with zstd - compressed := s.encoder.EncodeAll(data, nil) + compressed := enc.EncodeAll(data, nil) // Write to temporary file first to ensure atomicity tmpFile, err := os.CreateTemp(dir, "chunk-*") diff --git a/pkg/storage/chunk/local_test.go b/pkg/storage/chunk/local_test.go index 183fd5d4..0670ed69 100644 --- a/pkg/storage/chunk/local_test.go +++ b/pkg/storage/chunk/local_test.go @@ -164,7 +164,7 @@ func TestLocalStore(t *testing.T) { // Use highly compressible data (repeated bytes) data := 
bytes.Repeat([]byte("compressible"), 1024) - isNew, compressedSize, err := store.PutChunk(ctx, "test-hash-compress-1", data) + isNew, compressedSize, err := store.PutChunk(ctx, testhelper.MustRandNarHash(), data) require.NoError(t, err) assert.True(t, isNew) assert.Greater(t, int64(len(data)), compressedSize, "compressed size should be less than original") @@ -175,10 +175,11 @@ func TestLocalStore(t *testing.T) { t.Parallel() data := []byte("hello, compressed world! hello, compressed world! hello, compressed world!") - _, _, err := store.PutChunk(ctx, "test-hash-roundtrip", data) + hash := testhelper.MustRandNarHash() + _, _, err := store.PutChunk(ctx, hash, data) require.NoError(t, err) - rc, err := store.GetChunk(ctx, "test-hash-roundtrip") + rc, err := store.GetChunk(ctx, hash) require.NoError(t, err) defer rc.Close() diff --git a/pkg/storage/chunk/s3.go b/pkg/storage/chunk/s3.go index 9871dc39..d5344173 100644 --- a/pkg/storage/chunk/s3.go +++ b/pkg/storage/chunk/s3.go @@ -10,12 +10,12 @@ import ( "path" "time" - "github.com/klauspost/compress/zstd" "github.com/minio/minio-go/v7" "github.com/minio/minio-go/v7/pkg/credentials" "github.com/kalbasit/ncps/pkg/lock" "github.com/kalbasit/ncps/pkg/s3" + "github.com/kalbasit/ncps/pkg/zstd" ) // ErrBucketNotFound is returned when the bucket is not found. @@ -28,25 +28,23 @@ const ( chunkPutLockTTL = 5 * time.Minute ) -// s3ReadCloser wraps a zstd decoder and io.ReadCloser to properly close both. +// s3ReadCloser wraps a pooled zstd reader and io.ReadCloser to properly close both. type s3ReadCloser struct { - *zstd.Decoder + *zstd.PooledReader body io.ReadCloser } func (r *s3ReadCloser) Close() error { - r.Decoder.Close() + _ = r.PooledReader.Close() return r.body.Close() } // s3Store implements Store for S3 storage. 
type s3Store struct { - client *minio.Client - locker lock.Locker - bucket string - encoder *zstd.Encoder - decoder *zstd.Decoder + client *minio.Client + locker lock.Locker + bucket string } // NewS3Store returns a new S3 chunk store. @@ -88,24 +86,10 @@ func NewS3Store(ctx context.Context, cfg s3.Config, locker lock.Locker) (Store, return nil, fmt.Errorf("%w: %s", ErrBucketNotFound, cfg.Bucket) } - encoder, err := zstd.NewWriter(nil) - if err != nil { - return nil, fmt.Errorf("failed to create zstd encoder: %w", err) - } - - decoder, err := zstd.NewReader(nil) - if err != nil { - encoder.Close() - - return nil, fmt.Errorf("failed to create zstd decoder: %w", err) - } - return &s3Store{ - client: client, - locker: locker, - bucket: cfg.Bucket, - encoder: encoder, - decoder: decoder, + client: client, + locker: locker, + bucket: cfg.Bucket, }, nil } @@ -147,15 +131,15 @@ func (s *s3Store) GetChunk(ctx context.Context, hash string) (io.ReadCloser, err return nil, err } - // Create a new decoder for this specific object - decoder, err := zstd.NewReader(obj) + // Use pooled reader instead of creating new instance + pr, err := zstd.NewPooledReader(obj) if err != nil { obj.Close() - return nil, fmt.Errorf("failed to create zstd decoder: %w", err) + return nil, fmt.Errorf("failed to create zstd reader: %w", err) } - return &s3ReadCloser{decoder, obj}, nil + return &s3ReadCloser{pr, obj}, nil } func (s *s3Store) PutChunk(ctx context.Context, hash string, data []byte) (bool, int64, error) { @@ -172,8 +156,12 @@ func (s *s3Store) PutChunk(ctx context.Context, hash string, data []byte) (bool, _ = s.locker.Unlock(ctx, lockKey) }() + // Use pooled encoder + enc := zstd.GetWriter() + defer zstd.PutWriter(enc) + // Compress data with zstd - compressed := s.encoder.EncodeAll(data, nil) + compressed := enc.EncodeAll(data, nil) // Check if exists. 
exists, err := s.HasChunk(ctx, hash) diff --git a/pkg/storage/chunk/s3_test.go b/pkg/storage/chunk/s3_test.go index 615dfc93..701e8b25 100644 --- a/pkg/storage/chunk/s3_test.go +++ b/pkg/storage/chunk/s3_test.go @@ -117,30 +117,34 @@ func TestS3Store_Integration(t *testing.T) { t.Run("stored chunk is zstd-compressed in S3", func(t *testing.T) { t.Parallel() + hash := testhelper.MustRandNarHash() + data := bytes.Repeat([]byte("compressible"), 1024) - isNew, compressedSize, err := store.PutChunk(ctx, "test-hash-s3-compress", data) + isNew, compressedSize, err := store.PutChunk(ctx, hash, data) require.NoError(t, err) assert.True(t, isNew) assert.Greater(t, int64(len(data)), compressedSize, "compressed size should be less than original") assert.Positive(t, compressedSize) defer func() { - _ = store.DeleteChunk(ctx, "test-hash-s3-compress") + _ = store.DeleteChunk(ctx, hash) }() }) t.Run("compressed chunk round-trips correctly via S3", func(t *testing.T) { t.Parallel() + hash := testhelper.MustRandNarHash() + data := []byte("hello from S3 compressed chunk! 
hello from S3 compressed chunk!") - _, _, err := store.PutChunk(ctx, "test-hash-s3-roundtrip", data) + _, _, err := store.PutChunk(ctx, hash, data) require.NoError(t, err) defer func() { - _ = store.DeleteChunk(ctx, "test-hash-s3-roundtrip") + _ = store.DeleteChunk(ctx, hash) }() - rc, err := store.GetChunk(ctx, "test-hash-s3-roundtrip") + rc, err := store.GetChunk(ctx, hash) require.NoError(t, err) defer rc.Close() @@ -151,8 +155,18 @@ func TestS3Store_Integration(t *testing.T) { }) } -func TestS3Store_PutChunk_RaceCondition(t *testing.T) { +func TestS3Store_PutSameChunk_RaceCondition(t *testing.T) { t.Parallel() + runRaceConditionTest(t, false) +} + +func TestS3Store_PutDifferentChunk_RaceCondition(t *testing.T) { + t.Parallel() + runRaceConditionTest(t, true) +} + +func runRaceConditionTest(t *testing.T, distinctHashes bool) { + t.Helper() ctx := context.Background() @@ -162,37 +176,49 @@ func TestS3Store_PutChunk_RaceCondition(t *testing.T) { } // We pass a local locker to ensure thread safety during the test. 
- store, err := chunk.NewS3Store(ctx, *cfg, local.NewLocker()) require.NoError(t, err) - hash := "test-hash-race" - content := []byte(strings.Repeat("race condition content", 1024)) + const numGoRoutines = 10 + + hashes := make(chan string, numGoRoutines) defer func() { - _ = store.DeleteChunk(ctx, hash) + close(hashes) + + for hash := range hashes { + _ = store.DeleteChunk(ctx, hash) + } }() - const numGoRoutines = 10 + content := []byte(strings.Repeat("race condition content", 1024)) + sharedHash := testhelper.MustRandNarHash() results := make(chan bool, numGoRoutines) - errors := make(chan error, numGoRoutines) + errs := make(chan error, numGoRoutines) for range numGoRoutines { go func() { + hash := sharedHash + if distinctHashes { + hash = testhelper.MustRandNarHash() + } + + hashes <- hash + created, size, err := store.PutChunk(ctx, hash, content) results <- created assert.Greater(t, int64(len(content)), size) - errors <- err + errs <- err }() } createdCount := 0 for range numGoRoutines { - err := <-errors + err := <-errs require.NoError(t, err) if <-results { @@ -200,10 +226,14 @@ func TestS3Store_PutChunk_RaceCondition(t *testing.T) { } } - // The contract says true if chunk was new. In a race condition WITHOUT locking, - // multiple goroutines might see created: true. - // We want to ensure only ONE goroutine gets created: true. - assert.Equal(t, 1, createdCount, "Only one goroutine should have created the chunk") + if distinctHashes { + assert.Equal(t, numGoRoutines, createdCount, "All goroutines should have created their unique chunk") + } else { + // The contract says true if chunk was new. In a race condition WITHOUT locking, + // multiple goroutines might see created: true. + // We want to ensure only ONE goroutine gets created: true. 
+ assert.Equal(t, 1, createdCount, "Only one goroutine should have created the chunk") + } } func TestNewS3Store_Validation(t *testing.T) { diff --git a/pkg/zstd/README.md b/pkg/zstd/README.md new file mode 100644 index 00000000..9b8122fe --- /dev/null +++ b/pkg/zstd/README.md @@ -0,0 +1,333 @@ +# ZSTD Pool Management + +## Overview + +The `pkg/zstd` package provides a `sync.Pool`-based implementation for recycling zstd encoder and decoder instances. This reduces allocation overhead when creating multiple compression/decompression operations, which is especially beneficial in high-throughput scenarios like the NCPS cache server. + +## Motivation + +Creating new `zstd.Encoder` and `zstd.Decoder` instances is relatively expensive due to internal buffer allocations. When handling many compression/decompression operations (as in chunk storage and HTTP compression), reusing these instances via a pool significantly reduces garbage collection pressure and improves performance. + +## Quick Reference + +### Import + +```go +import "github.com/kalbasit/ncps/pkg/zstd" +``` + +### Common Patterns + +#### Compress Data + +```go +pw := zstd.NewPooledWriter(&buf) +defer pw.Close() +pw.Write(data) +``` + +#### Decompress Data + +```go +pr, err := zstd.NewPooledReader(reader) +if err != nil { + return err +} +defer pr.Close() +data, _ := io.ReadAll(pr) +``` + +#### One-Shot Encoding + +```go +enc := zstd.GetWriter() +defer zstd.PutWriter(enc) +compressed := enc.EncodeAll(data, nil) +``` + +#### One-Shot Decoding + +```go +dec := zstd.GetReader() +defer zstd.PutReader(dec) +dec.Reset(reader) +data, _ := io.ReadAll(dec) +``` + +### API Cheat Sheet + +| Function | Purpose | Returns | Error | +|----------|---------|---------|-------| +| `GetWriter()` | Get encoder from pool | `*zstd.Encoder` | N/A | +| `PutWriter(enc)` | Return encoder to pool | `void` | N/A | +| `GetReader()` | Get decoder from pool | `*zstd.Decoder` | N/A | +| `PutReader(dec)` | Return decoder to pool | `void` | N/A 
| +| `NewPooledWriter(w)` | Create auto-managed writer | `*PooledWriter` | N/A | +| `NewPooledReader(r)` | Create auto-managed reader | `*PooledReader` | error | +| `pw.Close()` | Close writer, return to pool | `error` | compression error | +| `pr.Close()` | Close reader, return to pool | `error` | nil | + +______________________________________________________________________ + +## API Documentation + +### Low-Level API (Manual Management) + +For fine-grained control, use the low-level functions: + +#### Writer Pool + +```go +// Get an encoder from the pool +enc := zstd.GetWriter() +defer zstd.PutWriter(enc) + +// Reset the encoder to write to a buffer +var buf bytes.Buffer +enc.Reset(&buf) + +// Use the encoder +enc.Write(data) +enc.Close() + +// The encoder is automatically reset before being returned to the pool +``` + +#### Reader Pool + +```go +// Get a decoder from the pool +dec := zstd.GetReader() +defer zstd.PutReader(dec) + +// Reset the decoder to read from a compressed source +dec.Reset(compressedReader) + +// Use the decoder +decompressed, err := io.ReadAll(dec) +``` + +### High-Level API (Automatic Management) + +For simplicity and to avoid resource leaks, use the wrapped types: + +#### PooledWriter + +```go +import "github.com/kalbasit/ncps/pkg/zstd" + +// Create a pooled writer - automatically manages the encoder +pw := zstd.NewPooledWriter(&buf) +defer pw.Close() // Automatically returns encoder to pool + +// Use like a normal zstd encoder +pw.Write(data) +pw.Close() +``` + +#### PooledReader + +```go +import "github.com/kalbasit/ncps/pkg/zstd" + +// Create a pooled reader - automatically manages the decoder +pr, err := zstd.NewPooledReader(compressedReader) +if err != nil { + return err +} +defer pr.Close() // Automatically returns decoder to pool + +// Use like a normal zstd decoder +data, err := io.ReadAll(pr) +``` + +## Usage Examples + +### Example 1: Compressing Multiple Data Chunks + +```go +func compressChunks(chunks [][]byte) ([][]byte, 
error) { + result := make([][]byte, len(chunks)) + + for i, chunk := range chunks { + var buf bytes.Buffer + pw := zstd.NewPooledWriter(&buf) + + if _, err := pw.Write(chunk); err != nil { + pw.Close() + return nil, err + } + + if err := pw.Close(); err != nil { + return nil, err + } + + result[i] = buf.Bytes() + } + + return result, nil +} +``` + +### Example 2: Decompressing Data + +```go +func decompressData(compressed []byte) ([]byte, error) { + pr, err := zstd.NewPooledReader(bytes.NewReader(compressed)) + if err != nil { + return nil, err + } + defer pr.Close() + + return io.ReadAll(pr) +} +``` + +### Example 3: Direct Encoding (No Streaming) + +```go +func quickCompress(data []byte) []byte { + enc := zstd.GetWriter() + defer zstd.PutWriter(enc) + + // Use EncodeAll for non-streaming compression + return enc.EncodeAll(data, nil) +} +``` + +## Pool Configuration + +Both pools use the default compression level and settings: + +- **WriterPool**: Default compression level (fast but good compression) +- **ReaderPool**: Default decompression settings + +For custom compression levels or options, create encoders/decoders directly without pooling: + +```go +// For custom compression level +enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.SpeedDefault)) +if err != nil { + return err +} +defer enc.Close() +``` + +## Performance Considerations + +1. **Pool Benefits**: Most beneficial when you have many compression/decompression operations +1. **Memory Trade-off**: The pool maintains encoder/decoder instances in memory, ready for reuse +1. **Thread-Safe**: `sync.Pool` is thread-safe and designed for concurrent use +1. 
**Automatic Cleanup**: Decoders and encoders are reset to a clean state before being returned to the pool + +## Integration Points + +The zstd pool is used in: + +- `pkg/server/server.go` - HTTP response compression +- `pkg/storage/chunk/local.go` - Local chunk storage compression +- `pkg/storage/chunk/s3.go` - S3 chunk storage compression +- Test utilities and helpers + +## Migration Guide + +To migrate existing code to use the zstd pool: + +### Before (Direct Creation) + +```go +import "github.com/klauspost/compress/zstd" + +encoder, err := zstd.NewWriter(&buf) +if err != nil { + return err +} +defer encoder.Close() +encoder.Write(data) +``` + +### After (Using Pool) + +```go +import "github.com/kalbasit/ncps/pkg/zstd" + +pw := zstd.NewPooledWriter(&buf) +defer pw.Close() +pw.Write(data) +``` + +## Best Practices + +1. **Always defer Close()**: Ensure pooled resources are returned promptly +1. **Use Wrapped Types**: Prefer `PooledWriter` and `PooledReader` for cleaner code +1. **Handle Errors**: Check errors from Close(), Reset(), and Read/Write operations +1. **One Writer/Reader Per Operation**: Get/release for each independent compression/decompression +1. **Avoid Nested Pools**: Don't hold multiple pooled instances simultaneously unless necessary + +## Testing + +The zstd pool includes comprehensive tests in `pkg/zstd/zstd_test.go`: + +```bash +go test ./pkg/zstd -v +``` + +Tests cover: + +- Pool allocation and reuse +- Round-trip compression/decompression +- Error handling +- Resource cleanup +- Concurrent pool access + +______________________________________________________________________ + +## Implementation Details + +### Files Created + +#### 1. `pkg/zstd/zstd.go` + +The main implementation file containing: + +- **WriterPool**: A `sync.Pool` managing reusable `zstd.Encoder` instances +- **ReaderPool**: A `sync.Pool` managing reusable `zstd.Decoder` instances + +#### 2. 
`pkg/zstd/zstd_test.go` + +Comprehensive test suite covering: + +- Pool get/put operations +- Pooled wrapper functionality +- Round-trip compression/decompression +- Error handling +- Multiple close operations +- Nil safety +- EncodeAll pattern support + +### Design Decisions + +#### Why `sync.Pool`? + +- Built into Go standard library +- Thread-safe without explicit locking +- Automatically adjusts to contention +- Zero-copy semantics + +#### Why Two APIs? + +- **Low-level**: For complex scenarios needing manual control +- **High-level**: For common cases with automatic cleanup +- Recommendation: Use high-level in most cases + +#### Why Default Compression Level? + +- Covers 99% of use cases +- Custom levels can use direct `zstd.NewWriter()` +- Simpler pool implementation + +#### Decoder Reset Pattern + +- Decoders are reset but not explicitly closed when returned to pool +- Prevents "decoder used after Close" errors +- Allows safe reuse of pooled decoders diff --git a/pkg/zstd/zstd.go b/pkg/zstd/zstd.go new file mode 100644 index 00000000..7c3f0682 --- /dev/null +++ b/pkg/zstd/zstd.go @@ -0,0 +1,193 @@ +// Package zstd provides compression utilities for the NCPS project. +package zstd + +import ( + "io" + "sync" + + "github.com/klauspost/compress/zstd" +) + +// writerPool manages a pool of zstd.Encoder instances for reuse. +// This pool is used to reduce allocation overhead when creating multiple +// compression writers. Encoders are reset before being returned to the pool +// and are ready for immediate reuse. +// +// The pool uses the default compression level (no options specified). +// For custom compression levels, create encoders directly with zstd.NewWriter. +// +//nolint:gochecknoglobals +var writerPool = sync.Pool{ + New: func() any { + // Not providing any options will use the default compression level. + // The error is ignored as NewWriter(nil) with no options doesn't error. 
+ enc, _ := zstd.NewWriter(nil) + + return enc + }, +} + +// GetWriter retrieves a zstd.Encoder from the pool, or creates a new one +// if the pool is empty. The caller must call PutWriter to return the encoder +// to the pool when done. +// +// Example: +// +// enc := GetWriter() +// defer PutWriter(enc) +// enc.Reset(buf) +// enc.Write(data) +// enc.Close() +func GetWriter() *zstd.Encoder { + return writerPool.Get().(*zstd.Encoder) +} + +// PutWriter returns a zstd.Encoder to the pool for reuse. +// The encoder is reset to nil before being returned to the pool. +// If enc is nil, this function is a no-op. +// +// Always pair calls to GetWriter with PutWriter in a defer statement +// or ensure it's called in all code paths. +func PutWriter(enc *zstd.Encoder) { + if enc != nil { + enc.Reset(nil) + writerPool.Put(enc) + } +} + +// readerPool manages a pool of zstd.Decoder instances for reuse. +// This pool is used to reduce allocation overhead when creating multiple +// decompression readers. Decoders are reset before being returned to the pool +// and are ready for immediate reuse. +// +// The pool uses the default decompression settings (no options specified). +// For custom decompression settings, create decoders directly with zstd.NewReader. +// +//nolint:gochecknoglobals +var readerPool = sync.Pool{ + New: func() any { + // Not providing any options will use the default decompression settings. + // The error is ignored as NewReader(nil) with no options doesn't error. + dec, _ := zstd.NewReader(nil) + + return dec + }, +} + +// GetReader retrieves a zstd.Decoder from the pool, or creates a new one +// if the pool is empty. The caller must call PutReader or use NewPooledReader +// for automatic pool management. +// +// Note: Prefer NewPooledReader for automatic resource cleanup. 
+// +// Example (manual management): +// +// dec := GetReader() +// defer PutReader(dec) +// dec.Reset(reader) +// data, err := io.ReadAll(dec) +func GetReader() *zstd.Decoder { + return readerPool.Get().(*zstd.Decoder) +} + +// PutReader returns a zstd.Decoder to the pool for reuse. +// The decoder is reset to nil before being returned to the pool. +// If dec is nil, this function is a no-op. +// +// Always pair calls to GetReader with PutReader in a defer statement +// or ensure it's called in all code paths. +func PutReader(dec *zstd.Decoder) { + if dec != nil { + _ = dec.Reset(nil) + readerPool.Put(dec) + } +} + +// PooledWriter wraps a zstd.Encoder with automatic pool management. +// When closed, the encoder is automatically returned to the pool. +// +// Example: +// +// pw := NewPooledWriter(&buf) +// defer pw.Close() +// pw.Write(data) +type PooledWriter struct { + *zstd.Encoder + w io.Writer +} + +// NewPooledWriter creates a new pooled writer that wraps the given io.Writer. +// The returned writer will automatically return its encoder to the pool when closed. +// This is the recommended way to use pooled writers for write operations. +func NewPooledWriter(w io.Writer) *PooledWriter { + enc := GetWriter() + enc.Reset(w) + + return &PooledWriter{ + Encoder: enc, + w: w, + } +} + +// Close closes the encoder and returns it to the pool. +// Multiple calls to Close are safe and will not panic. +func (pw *PooledWriter) Close() error { + if pw.Encoder == nil { + return nil + } + + err := pw.Encoder.Close() + PutWriter(pw.Encoder) + pw.Encoder = nil + + return err +} + +// PooledReader wraps a zstd.Decoder with automatic pool management. +// When closed, the decoder is automatically returned to the pool. 
+// +// Example: +// +// pr, err := NewPooledReader(compressedReader) +// if err != nil { +// return err +// } +// defer pr.Close() +// data, err := io.ReadAll(pr) +type PooledReader struct { + *zstd.Decoder + r io.Reader +} + +// NewPooledReader creates a new pooled reader that wraps the given io.Reader. +// The returned reader will automatically return its decoder to the pool when closed. +// This is the recommended way to use pooled readers for read operations. +// +// Returns an error if the decoder cannot be reset to read from the given reader. +func NewPooledReader(r io.Reader) (*PooledReader, error) { + dec := GetReader() + if err := dec.Reset(r); err != nil { + PutReader(dec) + + return nil, err + } + + return &PooledReader{ + Decoder: dec, + r: r, + }, nil +} + +// Close closes the reader and returns it to the pool. +// Multiple calls to Close are safe and will not panic. +// Note: The underlying decoder is not explicitly closed, only reset and returned to the pool. +func (pr *PooledReader) Close() error { + if pr.Decoder == nil { + return nil + } + + PutReader(pr.Decoder) + pr.Decoder = nil + + return nil +} diff --git a/pkg/zstd/zstd_test.go b/pkg/zstd/zstd_test.go new file mode 100644 index 00000000..fec86e04 --- /dev/null +++ b/pkg/zstd/zstd_test.go @@ -0,0 +1,313 @@ +package zstd_test + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/kalbasit/ncps/pkg/zstd" +) + +func TestGetAndPutWriter(t *testing.T) { + t.Parallel() + + // Get a writer from the pool + writer1 := zstd.GetWriter() + require.NotNil(t, writer1) + + // Reset and put it back + zstd.PutWriter(writer1) + + // Get another writer - should be the same instance if pool reuses + writer2 := zstd.GetWriter() + require.NotNil(t, writer2) + + zstd.PutWriter(writer2) +} + +func TestGetAndPutReader(t *testing.T) { + t.Parallel() + + // Get a reader from the pool + reader1 := zstd.GetReader() + require.NotNil(t, 
reader1) + + // Reset and put it back + zstd.PutReader(reader1) + + // Get another reader - should be the same instance if pool reuses + reader2 := zstd.GetReader() + require.NotNil(t, reader2) + + zstd.PutReader(reader2) +} + +func TestPutWriterWithNil(t *testing.T) { + t.Parallel() + + // Should not panic when putting nil + zstd.PutWriter(nil) +} + +func TestPutReaderWithNil(t *testing.T) { + t.Parallel() + + // Should not panic when putting nil + zstd.PutReader(nil) +} + +func TestPooledWriter(t *testing.T) { + t.Parallel() + + data := []byte("Hello, World!") + + var buf bytes.Buffer + + writer := zstd.NewPooledWriter(&buf) + require.NotNil(t, writer) + + // Write data + n, err := writer.Write(data) + require.NoError(t, err) + assert.Equal(t, len(data), n) + + // Close the writer + err = writer.Close() + require.NoError(t, err) + + // Verify data was compressed + assert.NotEmpty(t, buf.Bytes()) +} + +func TestPooledWriterCloseMultiple(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + + writer := zstd.NewPooledWriter(&buf) + require.NotNil(t, writer) + + // Close should be idempotent + err := writer.Close() + require.NoError(t, err) + + // Second close should not panic + err = writer.Close() + require.NoError(t, err) +} + +func TestPooledReader(t *testing.T) { + t.Parallel() + + // First, create and compress some data + originalData := []byte("Hello, Reader!") + + var compressed bytes.Buffer + + writer := zstd.NewPooledWriter(&compressed) + require.NotNil(t, writer) + + _, err := writer.Write(originalData) + require.NoError(t, err) + + err = writer.Close() + require.NoError(t, err) + + // Now decompress using the pooled reader + reader, err := zstd.NewPooledReader(bytes.NewReader(compressed.Bytes())) + require.NoError(t, err) + require.NotNil(t, reader) + + // Read all data + decompressed, err := io.ReadAll(reader) + require.NoError(t, err) + + assert.Equal(t, originalData, decompressed) + + // Close the reader + err = reader.Close() + require.NoError(t, err) +} + 
+func TestPooledReaderCloseMultiple(t *testing.T) { + t.Parallel() + + // Compress some data first + originalData := []byte("test data") + + var compressed bytes.Buffer + + writer := zstd.NewPooledWriter(&compressed) + _, err := writer.Write(originalData) + require.NoError(t, err) + err = writer.Close() + require.NoError(t, err) + + // Create pooled reader + reader, err := zstd.NewPooledReader(bytes.NewReader(compressed.Bytes())) + require.NoError(t, err) + + // Read the data before closing + _, err = io.ReadAll(reader) + require.NoError(t, err) + + // Close multiple times should not panic + err = reader.Close() + require.NoError(t, err) + + err = reader.Close() + require.NoError(t, err) +} + +func TestPooledReaderInvalidData(t *testing.T) { + t.Parallel() + + // Try to read from invalid zstd data + invalidData := []byte("not compressed data") + reader, err := zstd.NewPooledReader(bytes.NewReader(invalidData)) + // This should not error on creation, but on read + if err != nil { + // If error occurs during Reset, that's also acceptable + return + } + + require.NotNil(t, reader) + + // Reading should fail with invalid data + _, err = io.ReadAll(reader) + require.Error(t, err) + + reader.Close() +} + +func TestPooledReaderWithNilDecoder(t *testing.T) { + t.Parallel() + + // Create a pooled reader and close it without using it + originalData := []byte("test") + + var compressed bytes.Buffer + + writer := zstd.NewPooledWriter(&compressed) + _, err := writer.Write(originalData) + require.NoError(t, err) + err = writer.Close() + require.NoError(t, err) + + reader, err := zstd.NewPooledReader(bytes.NewReader(compressed.Bytes())) + require.NoError(t, err) + + // Manually set to nil to test Close with nil decoder + reader.Decoder = nil + err = reader.Close() + require.NoError(t, err) +} + +func TestWriterPoolReuse(t *testing.T) { + t.Parallel() + + // Test that pool actually reuses instances + writer1 := zstd.GetWriter() + ptr1 := writer1 + + zstd.PutWriter(writer1) + + 
writer2 := zstd.GetWriter() + ptr2 := writer2 + + zstd.PutWriter(writer2) + + // In most cases they should be the same pointer (pool reuse) + // But this is not guaranteed, so we just verify we can use them + assert.NotNil(t, ptr1) + assert.NotNil(t, ptr2) +} + +func TestReaderPoolReuse(t *testing.T) { + t.Parallel() + + // Test that pool actually reuses instances + reader1 := zstd.GetReader() + ptr1 := reader1 + + zstd.PutReader(reader1) + + reader2 := zstd.GetReader() + ptr2 := reader2 + + zstd.PutReader(reader2) + + // In most cases they should be the same pointer (pool reuse) + // But this is not guaranteed, so we just verify we can use them + assert.NotNil(t, ptr1) + assert.NotNil(t, ptr2) +} + +func TestPooledWriterAndReaderRoundTrip(t *testing.T) { + t.Parallel() + + testCases := []string{ + "Hello, World!", + "", + "a", + "The quick brown fox jumps over the lazy dog", + "Multiple\nlines\nof\ntext", + } + + for _, testData := range testCases { + testData := testData + t.Run(testData, func(t *testing.T) { + t.Parallel() + + // Compress + var compressed bytes.Buffer + + writer := zstd.NewPooledWriter(&compressed) + require.NotNil(t, writer) + + n, err := writer.Write([]byte(testData)) + require.NoError(t, err) + assert.Equal(t, len(testData), n) + + err = writer.Close() + require.NoError(t, err) + + // Decompress + reader, err := zstd.NewPooledReader(bytes.NewReader(compressed.Bytes())) + require.NoError(t, err) + require.NotNil(t, reader) + + decompressed, err := io.ReadAll(reader) + require.NoError(t, err) + + assert.Equal(t, testData, string(decompressed)) + + err = reader.Close() + require.NoError(t, err) + }) + } +} + +func TestPooledWriterEncodeAllPattern(t *testing.T) { + t.Parallel() + + testData := []byte("test data for encode all pattern") + + // Test the EncodeAll pattern used in chunk storage + writer := zstd.GetWriter() + compressed := writer.EncodeAll(testData, nil) + zstd.PutWriter(writer) + + // Verify the compressed data can be 
decompressed + reader, err := zstd.NewPooledReader(bytes.NewReader(compressed)) + require.NoError(t, err) + + decompressed, err := io.ReadAll(reader) + require.NoError(t, err) + require.NoError(t, reader.Close()) + + assert.Equal(t, testData, decompressed) +} diff --git a/testdata/server.go b/testdata/server.go index 18decef6..047e612e 100644 --- a/testdata/server.go +++ b/testdata/server.go @@ -9,9 +9,8 @@ import ( "sync" "testing" - "github.com/klauspost/compress/zstd" - "github.com/kalbasit/ncps/pkg/nar" + "github.com/kalbasit/ncps/pkg/zstd" "github.com/kalbasit/ncps/testhelper" ) @@ -37,7 +36,7 @@ func NewTestServer(t *testing.T, priority int) *Server { s.entries = append(s.entries, Entries...) - s.Server = httptest.NewServer(compressMiddleware(s.handler())) + s.Server = httptest.NewServer(zstdMiddleware(s.handler())) return s } @@ -74,7 +73,7 @@ func (s *Server) RemoveMaybeHandler(idx string) { delete(s.maybeHandlers, idx) } -func compressMiddleware(next http.Handler) http.Handler { +func zstdMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Header.Get("Accept-Encoding") != "zstd" { next.ServeHTTP(w, r) @@ -82,13 +81,10 @@ func compressMiddleware(next http.Handler) http.Handler { return } - encoder, err := zstd.NewWriter(w) - if !requireNoError(w, err) { - return - } - defer encoder.Close() + pw := zstd.NewPooledWriter(w) + defer pw.Close() - zw := &zstdResponseWriter{Writer: encoder, ResponseWriter: w} + zw := &zstdResponseWriter{Writer: pw, ResponseWriter: w} next.ServeHTTP(zw, r) }) diff --git a/testdata/server_test.go b/testdata/server_test.go index da414eee..5ee79ef2 100644 --- a/testdata/server_test.go +++ b/testdata/server_test.go @@ -6,10 +6,10 @@ import ( "net/http" "testing" - "github.com/klauspost/compress/zstd" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/kalbasit/ncps/pkg/zstd" "github.com/kalbasit/ncps/testdata" ) @@ -77,10 +77,10 @@ 
func TestNewTestServerWithZSTD(t *testing.T) { require.NoError(t, err) if assert.NotEqual(t, testdata.Nar1.NarText, string(body)) { - decoder, err := zstd.NewReader(nil) - require.NoError(t, err) + dec := zstd.GetReader() + defer zstd.PutReader(dec) - plain, err := decoder.DecodeAll(body, []byte{}) + plain, err := dec.DecodeAll(body, []byte{}) require.NoError(t, err) if assert.Len(t, testdata.Nar1.NarText, len(string(plain))) {