diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index 65424bdc..6cc8fd99 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -4081,7 +4081,7 @@ func zstdMutator( return func(r *http.Request) { zerolog.Ctx(ctx). Debug(). - Msg("narinfo compress is none will set Accept-Encoding to zstd") + Msg("narinfo compression is none will set Accept-Encoding to zstd") r.Header.Set("Accept-Encoding", "zstd") diff --git a/pkg/cache/cache_internal_test.go b/pkg/cache/cache_internal_test.go index a6939fd2..953e9ad4 100644 --- a/pkg/cache/cache_internal_test.go +++ b/pkg/cache/cache_internal_test.go @@ -14,7 +14,6 @@ import ( "testing" "time" - "github.com/klauspost/compress/zstd" "github.com/nix-community/go-nix/pkg/narinfo" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" @@ -26,6 +25,7 @@ import ( "github.com/kalbasit/ncps/pkg/database" "github.com/kalbasit/ncps/pkg/nar" "github.com/kalbasit/ncps/pkg/storage/local" + "github.com/kalbasit/ncps/pkg/zstd" "github.com/kalbasit/ncps/testdata" "github.com/kalbasit/ncps/testhelper" ) @@ -311,13 +311,14 @@ func testRunLRU(factory cacheFactory) func(*testing.T) { narNone := nar.CompressionTypeNone for _, entry := range entries { if entry.NarCompression == narNone { - encoder, _ := zstd.NewWriter(nil) + enc := zstd.GetWriter() + defer zstd.PutWriter(enc) var compressed bytes.Buffer - encoder.Reset(&compressed) - _, err = encoder.Write([]byte(entry.NarText)) + enc.Reset(&compressed) + _, err = enc.Write([]byte(entry.NarText)) require.NoError(t, err) - err = encoder.Close() + err = enc.Close() require.NoError(t, err) zstdSizes[entry.NarInfoHash] = uint64(compressed.Len()) //nolint:gosec diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go index 996ccc2b..f42cf96f 100644 --- a/pkg/cache/cache_test.go +++ b/pkg/cache/cache_test.go @@ -15,7 +15,6 @@ import ( "testing" "time" - "github.com/klauspost/compress/zstd" "github.com/nix-community/go-nix/pkg/narinfo" "github.com/nix-community/go-nix/pkg/narinfo/signature" "github.com/rs/zerolog" @@ -31,6 +30,7 @@ import ( "github.com/kalbasit/ncps/pkg/storage" "github.com/kalbasit/ncps/pkg/storage/chunk" "github.com/kalbasit/ncps/pkg/storage/local" + "github.com/kalbasit/ncps/pkg/zstd" "github.com/kalbasit/ncps/testdata" "github.com/kalbasit/ncps/testhelper" @@ -648,10 +648,10 @@ func testGetNarInfo(factory cacheFactory) func(*testing.T) { require.NoError(t, err) if assert.NotEqual(t, narEntry.NarText, string(body), "narText should be stored compressed in the store") { - decoder, err := zstd.NewReader(nil) - require.NoError(t, err) + dec := zstd.GetReader() + defer zstd.PutReader(dec) - plain, err := decoder.DecodeAll(body, []byte{}) + plain, err := dec.DecodeAll(body, []byte{}) require.NoError(t, err) assert.Equal(t, narEntry.NarText, string(plain)) diff --git a/pkg/cache/cdc_test.go b/pkg/cache/cdc_test.go index 72b2982c..660f4ae9 100644 --- a/pkg/cache/cdc_test.go +++ b/pkg/cache/cdc_test.go @@ -334,7 +334,7 @@ func testCDCChunksAreCompressed(factory cacheFactory) func(*testing.T) { // Use highly compressible data (repeated bytes) content := strings.Repeat("compressible", 1000) - nu := nar.URL{Hash: "testnar-compress", Compression: nar.CompressionTypeNone} + nu := nar.URL{Hash: "testnar-zstd", Compression: nar.CompressionTypeNone} r := io.NopCloser(strings.NewReader(content)) err = c.PutNar(ctx, nu, r) diff --git a/pkg/cache/export_test.go b/pkg/cache/export_test.go index d4912eed..9ebb4ff8 100644 --- a/pkg/cache/export_test.go +++ b/pkg/cache/export_test.go @@ -6,11 +6,13 @@ import ( "strings" "testing" - "github.com/klauspost/compress/zstd" "github.com/nix-community/go-nix/pkg/narinfo" "github.com/nix-community/go-nix/pkg/narinfo/signature" "github.com/nix-community/go-nix/pkg/nixhash" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/kalbasit/ncps/pkg/zstd" ) // CheckAndFixNarInfo is a test-only export of the unexported checkAndFixNarInfo method. @@ -33,13 +35,14 @@ func compressZstd(t *testing.T, data string) string { var buf strings.Builder - enc, err := zstd.NewWriter(&buf) - require.NoError(t, err) - _, err = io.WriteString(enc, data) - require.NoError(t, err) - err = enc.Close() + pw := zstd.NewPooledWriter(&buf) + + _, err := io.WriteString(pw, data) require.NoError(t, err) + err = pw.Close() + assert.NoError(t, err) //nolint:testifylint + return buf.String() } diff --git a/pkg/nar/reader.go b/pkg/nar/reader.go index a4627028..5e59d7af 100644 --- a/pkg/nar/reader.go +++ b/pkg/nar/reader.go @@ -7,10 +7,11 @@ import ( "io" "github.com/andybalholm/brotli" - "github.com/klauspost/compress/zstd" "github.com/pierrec/lz4/v4" "github.com/sorairolake/lzip-go" "github.com/ulikunitz/xz" + + "github.com/kalbasit/ncps/pkg/zstd" ) // ErrUnsupportedCompressionType is returned when an unsupported compression type is encountered. @@ -31,12 +32,12 @@ func DecompressReader(r io.Reader, comp CompressionType) (io.ReadCloser, error) return io.NopCloser(bzip2.NewReader(r)), nil case CompressionTypeZstd: - zr, err := zstd.NewReader(r) + pr, err := zstd.NewPooledReader(r) if err != nil { return nil, fmt.Errorf("failed to create zstd reader: %w", err) } - return zr.IOReadCloser(), nil + return pr, nil case CompressionTypeLz4: return io.NopCloser(lz4.NewReader(r)), nil diff --git a/pkg/nar/reader_test.go b/pkg/nar/reader_test.go index 25aeb2d8..1fdd6532 100644 --- a/pkg/nar/reader_test.go +++ b/pkg/nar/reader_test.go @@ -7,7 +7,6 @@ import ( "testing" "github.com/andybalholm/brotli" - "github.com/klauspost/compress/zstd" "github.com/pierrec/lz4/v4" "github.com/sorairolake/lzip-go" "github.com/stretchr/testify/assert" @@ -15,6 +14,7 @@ import ( "github.com/ulikunitz/xz" "github.com/kalbasit/ncps/pkg/nar" + "github.com/kalbasit/ncps/pkg/zstd" ) func TestDecompressReader(t *testing.T) { @@ -70,11 +70,10 @@ func TestDecompressReader(t *testing.T) { getInput: func(t *testing.T) io.Reader { var buf bytes.Buffer - zw, err := zstd.NewWriter(&buf) + pw := zstd.NewPooledWriter(&buf) + _, err := pw.Write(content) require.NoError(t, err) - _, err = zw.Write(content) - require.NoError(t, err) - require.NoError(t, zw.Close()) + require.NoError(t, pw.Close()) return &buf }, diff --git a/pkg/server/server.go b/pkg/server/server.go index 5edb9c54..6e5357da 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -8,12 +8,10 @@ import ( "runtime/debug" "strconv" "strings" - "sync" "time" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" - "github.com/klauspost/compress/zstd" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/riandyrn/otelchi" "github.com/rs/zerolog" @@ -31,6 +29,7 @@ import ( "github.com/kalbasit/ncps/pkg/nar" "github.com/kalbasit/ncps/pkg/narinfo" "github.com/kalbasit/ncps/pkg/storage" + "github.com/kalbasit/ncps/pkg/zstd" ) const ( @@ -60,17 +59,6 @@ var tracer trace.Tracer //nolint:gochecknoglobals var prometheusGatherer promclient.Gatherer -//nolint:gochecknoglobals -var zstdWriterPool = sync.Pool{ - New: func() interface{} { - // Not providing any options will use the default compression level. - // The error is ignored as NewWriter(nil) with no options doesn't error. - enc, _ := zstd.NewWriter(nil) - - return enc - }, -} - //nolint:gochecknoinits func init() { tracer = otel.Tracer(otelPackageName) @@ -591,16 +579,13 @@ func (s *Server) getNar(withBody bool) http.HandlerFunc { var out io.Writer = w if useZstd { - enc := zstdWriterPool.Get().(*zstd.Encoder) - enc.Reset(w) - out = enc + pw := zstd.NewPooledWriter(w) + out = pw defer func() { - if err := enc.Close(); err != nil { + if err := pw.Close(); err != nil { zerolog.Ctx(r.Context()).Error().Err(err).Msg("failed to close zstd writer") } - - zstdWriterPool.Put(enc) }() } diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index 374b7da6..3cabe14b 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -13,7 +13,6 @@ import ( "testing" "time" - "github.com/klauspost/compress/zstd" "github.com/nix-community/go-nix/pkg/narinfo" "github.com/nix-community/go-nix/pkg/narinfo/signature" "github.com/rs/zerolog" @@ -30,6 +29,7 @@ import ( "github.com/kalbasit/ncps/pkg/server" "github.com/kalbasit/ncps/pkg/storage" "github.com/kalbasit/ncps/pkg/storage/local" + "github.com/kalbasit/ncps/pkg/zstd" "github.com/kalbasit/ncps/testdata" "github.com/kalbasit/ncps/testhelper" ) @@ -949,13 +949,13 @@ func TestGetNar_ZstdCompression(t *testing.T) { assert.Equal(t, "application/x-nix-nar", resp.Header.Get("Content-Type")) assert.Empty(t, resp.Header.Get("Content-Length")) - // 3. Decompress the body and verify content - dec, err := zstd.NewReader(resp.Body) + // 3. DecompressReader the body and verify content + pr, err := zstd.NewPooledReader(resp.Body) require.NoError(t, err) - defer dec.Close() + defer pr.Close() - decompressed, err := io.ReadAll(dec) + decompressed, err := io.ReadAll(pr) require.NoError(t, err) assert.Equal(t, narData, string(decompressed)) } diff --git a/pkg/storage/chunk/local.go b/pkg/storage/chunk/local.go index dbbd1681..076466e7 100644 --- a/pkg/storage/chunk/local.go +++ b/pkg/storage/chunk/local.go @@ -7,17 +7,17 @@ import ( "os" "path/filepath" - "github.com/klauspost/compress/zstd" + "github.com/kalbasit/ncps/pkg/zstd" ) -// localReadCloser wraps a zstd decoder and file to properly close both on Close(). +// localReadCloser wraps a pooled zstd reader and file to properly close both on Close(). type localReadCloser struct { - *zstd.Decoder + *zstd.PooledReader file *os.File } func (r *localReadCloser) Close() error { - r.Decoder.Close() + _ = r.PooledReader.Close() return r.file.Close() } @@ -25,34 +25,15 @@ func (r *localReadCloser) Close() error { // localStore implements Store for local filesystem. type localStore struct { baseDir string - encoder *zstd.Encoder - decoder *zstd.Decoder } // NewLocalStore returns a new local chunk store. func NewLocalStore(baseDir string) (Store, error) { - encoder, err := zstd.NewWriter(nil) - if err != nil { - return nil, fmt.Errorf("failed to create zstd encoder: %w", err) - } - - decoder, err := zstd.NewReader(nil) - if err != nil { - encoder.Close() - - return nil, fmt.Errorf("failed to create zstd decoder: %w", err) - } - s := &localStore{ baseDir: baseDir, - encoder: encoder, - decoder: decoder, } // Ensure base directory exists if err := os.MkdirAll(s.storeDir(), 0o755); err != nil { - encoder.Close() - decoder.Close() - return nil, fmt.Errorf("failed to create chunk store directory: %w", err) } @@ -94,15 +75,15 @@ func (s *localStore) GetChunk(_ context.Context, hash string) (io.ReadCloser, er return nil, err } - // Create a new decoder for this specific file - decoder, err := zstd.NewReader(f) + // Use pooled reader instead of creating new instance + pr, err := zstd.NewPooledReader(f) if err != nil { f.Close() - return nil, fmt.Errorf("failed to create zstd decoder: %w", err) + return nil, fmt.Errorf("failed to create zstd reader: %w", err) } - return &localReadCloser{decoder, f}, nil + return &localReadCloser{pr, f}, nil } func (s *localStore) PutChunk(_ context.Context, hash string, data []byte) (bool, int64, error) { @@ -114,8 +95,12 @@ func (s *localStore) PutChunk(_ context.Context, hash string, data []byte) (bool return false, 0, err } + // Use pooled encoder + enc := zstd.GetWriter() + defer zstd.PutWriter(enc) + // Compress data with zstd - compressed := s.encoder.EncodeAll(data, nil) + compressed := enc.EncodeAll(data, nil) // Write to temporary file first to ensure atomicity tmpFile, err := os.CreateTemp(dir, "chunk-*") diff --git a/pkg/storage/chunk/local_test.go b/pkg/storage/chunk/local_test.go index 183fd5d4..0670ed69 100644 --- a/pkg/storage/chunk/local_test.go +++ b/pkg/storage/chunk/local_test.go @@ -164,7 +164,7 @@ func TestLocalStore(t *testing.T) { // Use highly compressible data (repeated bytes) data := bytes.Repeat([]byte("compressible"), 1024) - isNew, compressedSize, err := store.PutChunk(ctx, "test-hash-compress-1", data) + isNew, compressedSize, err := store.PutChunk(ctx, testhelper.MustRandNarHash(), data) require.NoError(t, err) assert.True(t, isNew) assert.Greater(t, int64(len(data)), compressedSize, "compressed size should be less than original") @@ -175,10 +175,11 @@ func TestLocalStore(t *testing.T) { t.Parallel() data := []byte("hello, compressed world! hello, compressed world! hello, compressed world!") - _, _, err := store.PutChunk(ctx, "test-hash-roundtrip", data) + hash := testhelper.MustRandNarHash() + _, _, err := store.PutChunk(ctx, hash, data) require.NoError(t, err) - rc, err := store.GetChunk(ctx, "test-hash-roundtrip") + rc, err := store.GetChunk(ctx, hash) require.NoError(t, err) defer rc.Close() diff --git a/pkg/storage/chunk/s3.go b/pkg/storage/chunk/s3.go index 9871dc39..d5344173 100644 --- a/pkg/storage/chunk/s3.go +++ b/pkg/storage/chunk/s3.go @@ -10,12 +10,12 @@ import ( "path" "time" - "github.com/klauspost/compress/zstd" "github.com/minio/minio-go/v7" "github.com/minio/minio-go/v7/pkg/credentials" "github.com/kalbasit/ncps/pkg/lock" "github.com/kalbasit/ncps/pkg/s3" + "github.com/kalbasit/ncps/pkg/zstd" ) // ErrBucketNotFound is returned when the bucket is not found. @@ -28,25 +28,23 @@ const ( chunkPutLockTTL = 5 * time.Minute ) -// s3ReadCloser wraps a zstd decoder and io.ReadCloser to properly close both. +// s3ReadCloser wraps a pooled zstd reader and io.ReadCloser to properly close both. type s3ReadCloser struct { - *zstd.Decoder + *zstd.PooledReader body io.ReadCloser } func (r *s3ReadCloser) Close() error { - r.Decoder.Close() + _ = r.PooledReader.Close() return r.body.Close() } // s3Store implements Store for S3 storage. type s3Store struct { - client *minio.Client - locker lock.Locker - bucket string - encoder *zstd.Encoder - decoder *zstd.Decoder + client *minio.Client + locker lock.Locker + bucket string } // NewS3Store returns a new S3 chunk store. @@ -88,24 +86,10 @@ func NewS3Store(ctx context.Context, cfg s3.Config, locker lock.Locker) (Store, return nil, fmt.Errorf("%w: %s", ErrBucketNotFound, cfg.Bucket) } - encoder, err := zstd.NewWriter(nil) - if err != nil { - return nil, fmt.Errorf("failed to create zstd encoder: %w", err) - } - - decoder, err := zstd.NewReader(nil) - if err != nil { - encoder.Close() - - return nil, fmt.Errorf("failed to create zstd decoder: %w", err) - } - return &s3Store{ - client: client, - locker: locker, - bucket: cfg.Bucket, - encoder: encoder, - decoder: decoder, + client: client, + locker: locker, + bucket: cfg.Bucket, }, nil } @@ -147,15 +131,15 @@ func (s *s3Store) GetChunk(ctx context.Context, hash string) (io.ReadCloser, err return nil, err } - // Create a new decoder for this specific object - decoder, err := zstd.NewReader(obj) + // Use pooled reader instead of creating new instance + pr, err := zstd.NewPooledReader(obj) if err != nil { obj.Close() - return nil, fmt.Errorf("failed to create zstd decoder: %w", err) + return nil, fmt.Errorf("failed to create zstd reader: %w", err) } - return &s3ReadCloser{decoder, obj}, nil + return &s3ReadCloser{pr, obj}, nil } func (s *s3Store) PutChunk(ctx context.Context, hash string, data []byte) (bool, int64, error) { @@ -172,8 +156,12 @@ func (s *s3Store) PutChunk(ctx context.Context, hash string, data []byte) (bool, _ = s.locker.Unlock(ctx, lockKey) }() + // Use pooled encoder + enc := zstd.GetWriter() + defer zstd.PutWriter(enc) + // Compress data with zstd - compressed := s.encoder.EncodeAll(data, nil) + compressed := enc.EncodeAll(data, nil) // Check if exists. exists, err := s.HasChunk(ctx, hash) diff --git a/pkg/storage/chunk/s3_test.go b/pkg/storage/chunk/s3_test.go index 615dfc93..701e8b25 100644 --- a/pkg/storage/chunk/s3_test.go +++ b/pkg/storage/chunk/s3_test.go @@ -117,30 +117,34 @@ func TestS3Store_Integration(t *testing.T) { t.Run("stored chunk is zstd-compressed in S3", func(t *testing.T) { t.Parallel() + hash := testhelper.MustRandNarHash() + data := bytes.Repeat([]byte("compressible"), 1024) - isNew, compressedSize, err := store.PutChunk(ctx, "test-hash-s3-compress", data) + isNew, compressedSize, err := store.PutChunk(ctx, hash, data) require.NoError(t, err) assert.True(t, isNew) assert.Greater(t, int64(len(data)), compressedSize, "compressed size should be less than original") assert.Positive(t, compressedSize) defer func() { - _ = store.DeleteChunk(ctx, "test-hash-s3-compress") + _ = store.DeleteChunk(ctx, hash) }() }) t.Run("compressed chunk round-trips correctly via S3", func(t *testing.T) { t.Parallel() + hash := testhelper.MustRandNarHash() + data := []byte("hello from S3 compressed chunk! hello from S3 compressed chunk!") - _, _, err := store.PutChunk(ctx, "test-hash-s3-roundtrip", data) + _, _, err := store.PutChunk(ctx, hash, data) require.NoError(t, err) defer func() { - _ = store.DeleteChunk(ctx, "test-hash-s3-roundtrip") + _ = store.DeleteChunk(ctx, hash) }() - rc, err := store.GetChunk(ctx, "test-hash-s3-roundtrip") + rc, err := store.GetChunk(ctx, hash) require.NoError(t, err) defer rc.Close() @@ -151,8 +155,18 @@ func TestS3Store_Integration(t *testing.T) { }) } -func TestS3Store_PutChunk_RaceCondition(t *testing.T) { +func TestS3Store_PutSameChunk_RaceCondition(t *testing.T) { t.Parallel() + runRaceConditionTest(t, false) +} + +func TestS3Store_PutDifferentChunk_RaceCondition(t *testing.T) { + t.Parallel() + runRaceConditionTest(t, true) +} + +func runRaceConditionTest(t *testing.T, distinctHashes bool) { + t.Helper() ctx := context.Background() @@ -162,37 +176,49 @@ func TestS3Store_PutChunk_RaceCondition(t *testing.T) { } // We pass a local locker to ensure thread safety during the test. - store, err := chunk.NewS3Store(ctx, *cfg, local.NewLocker()) require.NoError(t, err) - hash := "test-hash-race" - content := []byte(strings.Repeat("race condition content", 1024)) + const numGoRoutines = 10 + + hashes := make(chan string, numGoRoutines) defer func() { - _ = store.DeleteChunk(ctx, hash) + close(hashes) + + for hash := range hashes { + _ = store.DeleteChunk(ctx, hash) + } }() - const numGoRoutines = 10 + content := []byte(strings.Repeat("race condition content", 1024)) + sharedHash := testhelper.MustRandNarHash() results := make(chan bool, numGoRoutines) - errors := make(chan error, numGoRoutines) + errs := make(chan error, numGoRoutines) for range numGoRoutines { go func() { + hash := sharedHash + if distinctHashes { + hash = testhelper.MustRandNarHash() + } + + hashes <- hash + created, size, err := store.PutChunk(ctx, hash, content) results <- created assert.Greater(t, int64(len(content)), size) - errors <- err + errs <- err }() } createdCount := 0 for range numGoRoutines { - err := <-errors + err := <-errs require.NoError(t, err) if <-results { @@ -200,10 +226,14 @@ func TestS3Store_PutChunk_RaceCondition(t *testing.T) { } } - // The contract says true if chunk was new. In a race condition WITHOUT locking, - // multiple goroutines might see created: true. - // We want to ensure only ONE goroutine gets created: true. - assert.Equal(t, 1, createdCount, "Only one goroutine should have created the chunk") + if distinctHashes { + assert.Equal(t, numGoRoutines, createdCount, "All goroutines should have created their unique chunk") + } else { + // The contract says true if chunk was new. In a race condition WITHOUT locking, + // multiple goroutines might see created: true. + // We want to ensure only ONE goroutine gets created: true. + assert.Equal(t, 1, createdCount, "Only one goroutine should have created the chunk") + } } func TestNewS3Store_Validation(t *testing.T) { diff --git a/pkg/zstd/README.md b/pkg/zstd/README.md new file mode 100644 index 00000000..9b8122fe --- /dev/null +++ b/pkg/zstd/README.md @@ -0,0 +1,333 @@ +# ZSTD Pool Management + +## Overview + +The `pkg/zstd` package provides a `sync.Pool`-based implementation for recycling zstd encoder and decoder instances. This reduces allocation overhead when creating multiple compression/decompression operations, which is especially beneficial in high-throughput scenarios like the NCPS cache server. + +## Motivation + +Creating new `zstd.Encoder` and `zstd.Decoder` instances is relatively expensive due to internal buffer allocations. When handling many compression/decompression operations (as in chunk storage and HTTP compression), reusing these instances via a pool significantly reduces garbage collection pressure and improves performance. + +## Quick Reference + +### Import + +```go +import "github.com/kalbasit/ncps/pkg/zstd" +``` + +### Common Patterns + +#### Compress Data + +```go +pw := zstd.NewPooledWriter(&buf) +defer pw.Close() +pw.Write(data) +``` + +#### Decompress Data + +```go +pr, err := zstd.NewPooledReader(reader) +if err != nil { + return err +} +defer pr.Close() +data, _ := io.ReadAll(pr) +``` + +#### One-Shot Encoding + +```go +enc := zstd.GetWriter() +defer zstd.PutWriter(enc) +compressed := enc.EncodeAll(data, nil) +``` + +#### One-Shot Decoding + +```go +dec := zstd.GetReader() +defer zstd.PutReader(dec) +dec.Reset(reader) +data, _ := io.ReadAll(dec) +``` + +### API Cheat Sheet + +| Function | Purpose | Returns | Error | +|----------|---------|---------|-------| +| `GetWriter()` | Get encoder from pool | `*zstd.Encoder` | N/A | +| `PutWriter(enc)` | Return encoder to pool | `void` | N/A | +| `GetReader()` | Get decoder from pool | `*zstd.Decoder` | N/A | +| `PutReader(dec)` | Return decoder to pool | `void` | N/A | +| `NewPooledWriter(w)` | Create auto-managed writer | `*PooledWriter` | N/A | +| `NewPooledReader(r)` | Create auto-managed reader | `*PooledReader` | error | +| `pw.Close()` | Close writer, return to pool | `error` | compression error | +| `pr.Close()` | Close reader, return to pool | `error` | nil | + +______________________________________________________________________ + +## API Documentation + +### Low-Level API (Manual Management) + +For fine-grained control, use the low-level functions: + +#### Writer Pool + +```go +// Get an encoder from the pool +enc := zstd.GetWriter() +defer zstd.PutWriter(enc) + +// Reset the encoder to write to a buffer +var buf bytes.Buffer +enc.Reset(&buf) + +// Use the encoder +enc.Write(data) +enc.Close() + +// The encoder is automatically reset before being returned to the pool +``` + +#### Reader Pool + +```go +// Get a decoder from the pool +dec := zstd.GetReader() +defer zstd.PutReader(dec) + +// Reset the decoder to read from a compressed source +dec.Reset(compressedReader) + +// Use the decoder +decompressed, err := io.ReadAll(dec) +``` + +### High-Level API (Automatic Management) + +For simplicity and to avoid resource leaks, use the wrapped types: + +#### PooledWriter + +```go +import "github.com/kalbasit/ncps/pkg/zstd" + +// Create a pooled writer - automatically manages the encoder +pw := zstd.NewPooledWriter(&buf) +defer pw.Close() // Automatically returns encoder to pool + +// Use like a normal zstd encoder +pw.Write(data) +pw.Close() +``` + +#### PooledReader + +```go +import "github.com/kalbasit/ncps/pkg/zstd" + +// Create a pooled reader - automatically manages the decoder +pr, err := zstd.NewPooledReader(compressedReader) +if err != nil { + return err +} +defer pr.Close() // Automatically returns decoder to pool + +// Use like a normal zstd decoder +data, err := io.ReadAll(pr) +``` + +## Usage Examples + +### Example 1: Compressing Multiple Data Chunks + +```go +func compressChunks(chunks [][]byte) ([][]byte, error) { + result := make([][]byte, len(chunks)) + + for i, chunk := range chunks { + var buf bytes.Buffer + pw := zstd.NewPooledWriter(&buf) + + if _, err := pw.Write(chunk); err != nil { + pw.Close() + return nil, err + } + + if err := pw.Close(); err != nil { + return nil, err + } + + result[i] = buf.Bytes() + } + + return result, nil +} +``` + +### Example 2: Decompressing Data + +```go +func decompressData(compressed []byte) ([]byte, error) { + pr, err := zstd.NewPooledReader(bytes.NewReader(compressed)) + if err != nil { + return nil, err + } + defer pr.Close() + + return io.ReadAll(pr) +} +``` + +### Example 3: Direct Encoding (No Streaming) + +```go +func quickCompress(data []byte) []byte { + enc := zstd.GetWriter() + defer zstd.PutWriter(enc) + + // Use EncodeAll for non-streaming compression + return enc.EncodeAll(data, nil) +} +``` + +## Pool Configuration + +Both pools use the default zstdion level and settings: + +- **WriterPool**: Default compression level (fast but good compression) +- **ReaderPool**: Default decompression settings + +For custom zstdion levels or options, create encoders/decoders directly without pooling: + +```go +// For custom compression level +enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.SpeedDefault)) +if err != nil { + return err +} +defer enc.Close() +``` + +## Performance Considerations + +1. **Pool Benefits**: Most beneficial when you have many compression/decompression operations +1. **Memory Trade-off**: The pool maintains encoder/decoder instances in memory, ready for reuse +1. **Thread-Safe**: `sync.Pool` is thread-safe and designed for concurrent use +1. **Automatic Cleanup**: Decoders and encoders are reset to a clean state before being returned to the pool + +## Integration Points + +The zstd pool is used in: + +- `pkg/server/server.go` - HTTP response compression +- `pkg/storage/chunk/local.go` - Local chunk storage compression +- `pkg/storage/chunk/s3.go` - S3 chunk storage compression +- Test utilities and helpers + +## Migration Guide + +To migrate existing code to use the zstd pool: + +### Before (Direct Creation) + +```go +import "github.com/klauspost/zstd/zstd" + +encoder, err := zstd.NewWriter(&buf) +if err != nil { + return err +} +defer encoder.Close() +encoder.Write(data) +``` + +### After (Using Pool) + +```go +import "github.com/kalbasit/ncps/pkg/zstd" + +pw := zstd.NewPooledWriter(&buf) +defer pw.Close() +pw.Write(data) +``` + +## Best Practices + +1. **Always defer Close()**: Ensure pooled resources are returned promptly +1. **Use Wrapped Types**: Prefer `PooledWriter` and `PooledReader` for cleaner code +1. **Handle Errors**: Check errors from Close(), Reset(), and Read/Write operations +1. **One Writer/Reader Per Operation**: Get/release for each independent compression/decompression +1. **Avoid Nested Pools**: Don't hold multiple pooled instances simultaneously unless necessary + +## Testing + +The zstd pool includes comprehensive tests in `pkg/zstd/zstd_test.go`: + +```bash +go test ./pkg/zstd -v -run +``` + +Tests cover: + +- Pool allocation and reuse +- Round-trip compression/decompression +- Error handling +- Resource cleanup +- Concurrent pool access + +______________________________________________________________________ + +## Implementation Details + +### Files Created + +#### 1. `pkg/zstd/zstd.go` + +The main implementation file containing: + +- **WriterPool**: A `sync.Pool` managing reusable `zstd.Encoder` instances +- **ReaderPool**: A `sync.Pool` managing reusable `zstd.Decoder` instances + +#### 2. `pkg/zstd/zstd_test.go` + +Comprehensive test suite covering: + +- Pool get/put operations +- Pooled wrapper functionality +- Round-trip compression/decompression +- Error handling +- Multiple close operations +- Nil safety +- EncodeAll pattern support + +### Design Decisions + +#### Why `sync.Pool`? + +- Built into Go standard library +- Thread-safe without explicit locking +- Automatically adjusts to contention +- Zero-copy semantics + +#### Why Two APIs? + +- **Low-level**: For complex scenarios needing manual control +- **High-level**: For common cases with automatic cleanup +- Recommendation: Use high-level in most cases + +#### Why Default Compression Level? + +- Covers 99% of use cases +- Custom levels can use direct `zstd.NewWriter()` +- Simpler pool implementation + +#### Decoder Reset Pattern + +- Decoders are reset but not explicitly closed when returned to pool +- Prevents "decoder used after Close" errors +- Allows safe reuse of pooled decoders diff --git a/pkg/zstd/zstd.go b/pkg/zstd/zstd.go new file mode 100644 index 00000000..7c3f0682 --- /dev/null +++ b/pkg/zstd/zstd.go @@ -0,0 +1,193 @@ +// Package zstd provides compression utilities for the NCPS project. +package zstd + +import ( + "io" + "sync" + + "github.com/klauspost/compress/zstd" +) + +// writerPool manages a pool of zstd.Encoder instances for reuse. +// This pool is used to reduce allocation overhead when creating multiple +// compression writers. Encoders are reset before being returned to the pool +// and are ready for immediate reuse. +// +// The pool uses the default compression level (no options specified). +// For custom compression levels, create encoders directly with zstd.NewWriter. +// +//nolint:gochecknoglobals +var writerPool = sync.Pool{ + New: func() any { + // Not providing any options will use the default compression level. + // The error is ignored as NewWriter(nil) with no options doesn't error. + enc, _ := zstd.NewWriter(nil) + + return enc + }, +} + +// GetWriter retrieves a zstd.Encoder from the pool, or creates a new one +// if the pool is empty. The caller must call PutWriter to return the encoder +// to the pool when done. +// +// Example: +// +// enc := GetWriter() +// defer PutWriter(enc) +// enc.Reset(buf) +// enc.Write(data) +// enc.Close() +func GetWriter() *zstd.Encoder { + return writerPool.Get().(*zstd.Encoder) +} + +// PutWriter returns a zstd.Encoder to the pool for reuse. +// The encoder is reset to nil before being returned to the pool. +// If enc is nil, this function is a no-op. +// +// Always pair calls to GetWriter with PutWriter in a defer statement +// or ensure it's called in all code paths. +func PutWriter(enc *zstd.Encoder) { + if enc != nil { + enc.Reset(nil) + writerPool.Put(enc) + } +} + +// readerPool manages a pool of zstd.Decoder instances for reuse. +// This pool is used to reduce allocation overhead when creating multiple +// decompression readers. Decoders are reset before being returned to the pool +// and are ready for immediate reuse. +// +// The pool uses the default decompression settings (no options specified). +// For custom decompression settings, create decoders directly with zstd.NewReader. +// +//nolint:gochecknoglobals +var readerPool = sync.Pool{ + New: func() any { + // Not providing any options will use the default decompression settings. + // The error is ignored as NewReader(nil) with no options doesn't error. + dec, _ := zstd.NewReader(nil) + + return dec + }, +} + +// GetReader retrieves a zstd.Decoder from the pool, or creates a new one +// if the pool is empty. The caller must call PutReader or use NewPooledReader +// for automatic pool management. +// +// Note: Prefer NewPooledReader for automatic resource cleanup. +// +// Example (manual management): +// +// dec := GetReader() +// defer PutReader(dec) +// dec.Reset(reader) +// data, err := io.ReadAll(dec) +func GetReader() *zstd.Decoder { + return readerPool.Get().(*zstd.Decoder) +} + +// PutReader returns a zstd.Decoder to the pool for reuse. +// The decoder is reset to nil before being returned to the pool. +// If dec is nil, this function is a no-op. +// +// Always pair calls to GetReader with PutReader in a defer statement +// or ensure it's called in all code paths. +func PutReader(dec *zstd.Decoder) { + if dec != nil { + _ = dec.Reset(nil) + readerPool.Put(dec) + } +} + +// PooledWriter wraps a zstd.Encoder with automatic pool management. +// When closed, the encoder is automatically returned to the pool. +// +// Example: +// +// pw := NewPooledWriter(&buf) +// defer pw.Close() +// pw.Write(data) +type PooledWriter struct { + *zstd.Encoder + w io.Writer +} + +// NewPooledWriter creates a new pooled writer that wraps the given io.Writer. +// The returned writer will automatically return its encoder to the pool when closed. +// This is the recommended way to use pooled writers for write operations. +func NewPooledWriter(w io.Writer) *PooledWriter { + enc := GetWriter() + enc.Reset(w) + + return &PooledWriter{ + Encoder: enc, + w: w, + } +} + +// Close closes the encoder and returns it to the pool. +// Multiple calls to Close are safe and will not panic. +func (pw *PooledWriter) Close() error { + if pw.Encoder == nil { + return nil + } + + err := pw.Encoder.Close() + PutWriter(pw.Encoder) + pw.Encoder = nil + + return err +} + +// PooledReader wraps a zstd.Decoder with automatic pool management. +// When closed, the decoder is automatically returned to the pool. +// +// Example: +// +// pr, err := NewPooledReader(compressedReader) +// if err != nil { +// return err +// } +// defer pr.Close() +// data, err := io.ReadAll(pr) +type PooledReader struct { + *zstd.Decoder + r io.Reader +} + +// NewPooledReader creates a new pooled reader that wraps the given io.Reader. +// The returned reader will automatically return its decoder to the pool when closed. +// This is the recommended way to use pooled readers for read operations. +// +// Returns an error if the decoder cannot be reset to read from the given reader. +func NewPooledReader(r io.Reader) (*PooledReader, error) { + dec := GetReader() + if err := dec.Reset(r); err != nil { + PutReader(dec) + + return nil, err + } + + return &PooledReader{ + Decoder: dec, + r: r, + }, nil +} + +// Close closes the reader and returns it to the pool. +// Multiple calls to Close are safe and will not panic. +// Note: The underlying decoder is not explicitly closed, only reset and returned to the pool. +func (pr *PooledReader) Close() error { + if pr.Decoder == nil { + return nil + } + + PutReader(pr.Decoder) + pr.Decoder = nil + + return nil +} diff --git a/pkg/zstd/zstd_test.go b/pkg/zstd/zstd_test.go new file mode 100644 index 00000000..fec86e04 --- /dev/null +++ b/pkg/zstd/zstd_test.go @@ -0,0 +1,313 @@ +package zstd_test + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/kalbasit/ncps/pkg/zstd" +) + +func TestGetAndPutWriter(t *testing.T) { + t.Parallel() + + // Get a writer from the pool + writer1 := zstd.GetWriter() + require.NotNil(t, writer1) + + // Reset and put it back + zstd.PutWriter(writer1) + + // Get another writer - should be the same instance if pool reuses + writer2 := zstd.GetWriter() + require.NotNil(t, writer2) + + zstd.PutWriter(writer2) +} + +func TestGetAndPutReader(t *testing.T) { + t.Parallel() + + // Get a reader from the pool + reader1 := zstd.GetReader() + require.NotNil(t, reader1) + + // Reset and put it back + zstd.PutReader(reader1) + + // Get another reader - should be the same instance if pool reuses + reader2 := zstd.GetReader() + require.NotNil(t, reader2) + + zstd.PutReader(reader2) +} + +func TestPutWriterWithNil(t *testing.T) { + t.Parallel() + + // Should not panic when putting nil + zstd.PutWriter(nil) +} + +func TestPutReaderWithNil(t *testing.T) { + t.Parallel() + + // Should not panic when putting nil + zstd.PutReader(nil) +} + +func TestPooledWriter(t *testing.T) { + t.Parallel() + + data := []byte("Hello, World!") + + var buf bytes.Buffer + + writer := zstd.NewPooledWriter(&buf) + require.NotNil(t, writer) + + // Write data + n, err := writer.Write(data) + require.NoError(t, err) + assert.Equal(t, len(data), n) + + // Close the writer + err = writer.Close() + require.NoError(t, err) + + // Verify data was compressed + assert.NotEmpty(t, buf.Bytes()) +} + +func TestPooledWriterCloseMultiple(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + + writer := zstd.NewPooledWriter(&buf) + require.NotNil(t, writer) + + // Close should be idempotent + err := writer.Close() + require.NoError(t, err) + + // Second close should not panic + err = writer.Close() + require.NoError(t, err) +} + +func TestPooledReader(t *testing.T) { + t.Parallel() + + // First, create and zstd some data + originalData := []byte("Hello, Reader!") + + var compressed bytes.Buffer + + writer := zstd.NewPooledWriter(&compressed) + require.NotNil(t, writer) + + _, err := writer.Write(originalData) + require.NoError(t, err) + + err = writer.Close() + require.NoError(t, err) + + // Now dezstd using pooled reader + reader, err := zstd.NewPooledReader(bytes.NewReader(compressed.Bytes())) + require.NoError(t, err) + require.NotNil(t, reader) + + // Read all data + decompressed, err := io.ReadAll(reader) + require.NoError(t, err) + + assert.Equal(t, originalData, decompressed) + + // Close the reader + err = reader.Close() + require.NoError(t, err) +} + +func TestPooledReaderCloseMultiple(t *testing.T) { + t.Parallel() + + // Compress some data first + originalData := []byte("test data") + + var compressed bytes.Buffer + + writer := zstd.NewPooledWriter(&compressed) + _, err := writer.Write(originalData) + require.NoError(t, err) + err = writer.Close() + require.NoError(t, err) + + // Create pooled reader + reader, err := zstd.NewPooledReader(bytes.NewReader(compressed.Bytes())) + require.NoError(t, err) + + // Read the data before closing + _, err = io.ReadAll(reader) + require.NoError(t, err) + + // Close multiple times should not panic + err = reader.Close() + require.NoError(t, err) + + err = reader.Close() + require.NoError(t, err) +} + +func TestPooledReaderInvalidData(t *testing.T) { + t.Parallel() + + // Try to read from invalid zstd data + invalidData := []byte("not compressed data") + reader, err := zstd.NewPooledReader(bytes.NewReader(invalidData)) + // This should not error on creation, but on read + if err != nil { + // If error occurs during Reset, that's also acceptable + return + } + + require.NotNil(t, reader) + + // Reading should fail with invalid data + _, err = io.ReadAll(reader) + require.Error(t, err) + + reader.Close() +} + +func TestPooledReaderWithNilDecoder(t *testing.T) { + t.Parallel() + + // Create a pooled reader and close it without using it + originalData := []byte("test") + + var compressed bytes.Buffer + + writer := zstd.NewPooledWriter(&compressed) + _, err := writer.Write(originalData) + require.NoError(t, err) + err = writer.Close() + require.NoError(t, err) + + reader, err := zstd.NewPooledReader(bytes.NewReader(compressed.Bytes())) + require.NoError(t, err) + + // Manually set to nil to test Close with nil decoder + reader.Decoder = nil + err = reader.Close() + require.NoError(t, err) +} + +func TestWriterPoolReuse(t *testing.T) { + t.Parallel() + + // Test that pool actually reuses instances + writer1 := zstd.GetWriter() + ptr1 := writer1 + + zstd.PutWriter(writer1) + + writer2 := zstd.GetWriter() + ptr2 := writer2 + + zstd.PutWriter(writer2) + + // In most cases they should be the same pointer (pool reuse) + // But this is not guaranteed, so we just verify we can use them + assert.NotNil(t, ptr1) + assert.NotNil(t, ptr2) +} + +func TestReaderPoolReuse(t *testing.T) { + t.Parallel() + + // Test that pool actually reuses instances + reader1 := zstd.GetReader() + ptr1 := reader1 + + zstd.PutReader(reader1) + + reader2 := zstd.GetReader() + ptr2 := reader2 + + zstd.PutReader(reader2) + + // In most cases they should be the same pointer (pool reuse) + // But this is not guaranteed, so we just verify we can use them + assert.NotNil(t, ptr1) + assert.NotNil(t, ptr2) +} + +func TestPooledWriterAndReaderRoundTrip(t *testing.T) { + t.Parallel() + + testCases := []string{ + "Hello, World!", + "", + "a", + "The quick brown fox jumps over the lazy dog", + "Multiple\nlines\nof\ntext", + } + + for _, testData := range testCases { + testData := testData + t.Run(testData, func(t *testing.T) { + t.Parallel() + + // Compress + var compressed bytes.Buffer + + writer := zstd.NewPooledWriter(&compressed) + require.NotNil(t, writer) + + n, err := writer.Write([]byte(testData)) + require.NoError(t, err) + assert.Equal(t, len(testData), n) + + err = writer.Close() + require.NoError(t, err) + + // DecompressReader + reader, err := zstd.NewPooledReader(bytes.NewReader(compressed.Bytes())) + require.NoError(t, err) + require.NotNil(t, reader) + + decompressed, err := io.ReadAll(reader) + require.NoError(t, err) + + assert.Equal(t, testData, string(decompressed)) + + err = reader.Close() + require.NoError(t, err) + }) + } +} + +func TestPooledWriterEncodeAllPattern(t *testing.T) { + t.Parallel() + + testData := []byte("test data for encode all pattern") + + // Test the EncodeAll pattern used in chunk storage + writer := zstd.GetWriter() + compressed := writer.EncodeAll(testData, nil) + zstd.PutWriter(writer) + + // Verify the compressed data can be decompressed + reader, err := zstd.NewPooledReader(bytes.NewReader(compressed)) + require.NoError(t, err) + + decompressed, err := io.ReadAll(reader) + require.NoError(t, err) + require.NoError(t, reader.Close()) + + assert.Equal(t, testData, decompressed) +} diff --git a/testdata/server.go b/testdata/server.go index 18decef6..047e612e 100644 --- a/testdata/server.go +++ b/testdata/server.go @@ -9,9 +9,8 @@ import ( "sync" "testing" - "github.com/klauspost/compress/zstd" - "github.com/kalbasit/ncps/pkg/nar" + "github.com/kalbasit/ncps/pkg/zstd" "github.com/kalbasit/ncps/testhelper" ) @@ -37,7 +36,7 @@ func NewTestServer(t *testing.T, priority int) *Server { s.entries = append(s.entries, Entries...) - s.Server = httptest.NewServer(compressMiddleware(s.handler())) + s.Server = httptest.NewServer(zstdMiddleware(s.handler())) return s } @@ -74,7 +73,7 @@ func (s *Server) RemoveMaybeHandler(idx string) { delete(s.maybeHandlers, idx) } -func compressMiddleware(next http.Handler) http.Handler { +func zstdMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Header.Get("Accept-Encoding") != "zstd" { next.ServeHTTP(w, r) @@ -82,13 +81,10 @@ func compressMiddleware(next http.Handler) http.Handler { return } - encoder, err := zstd.NewWriter(w) - if !requireNoError(w, err) { - return - } - defer encoder.Close() + pw := zstd.NewPooledWriter(w) + defer pw.Close() - zw := &zstdResponseWriter{Writer: encoder, ResponseWriter: w} + zw := &zstdResponseWriter{Writer: pw, ResponseWriter: w} next.ServeHTTP(zw, r) }) diff --git a/testdata/server_test.go b/testdata/server_test.go index da414eee..5ee79ef2 100644 --- a/testdata/server_test.go +++ b/testdata/server_test.go @@ -6,10 +6,10 @@ import ( "net/http" "testing" - "github.com/klauspost/compress/zstd" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/kalbasit/ncps/pkg/zstd" "github.com/kalbasit/ncps/testdata" ) @@ -77,10 +77,10 @@ func TestNewTestServerWithZSTD(t *testing.T) { require.NoError(t, err) if assert.NotEqual(t, testdata.Nar1.NarText, string(body)) { - decoder, err := zstd.NewReader(nil) - require.NoError(t, err) + dec := zstd.GetReader() + defer zstd.PutReader(dec) - plain, err := decoder.DecodeAll(body, []byte{}) + plain, err := dec.DecodeAll(body, []byte{}) require.NoError(t, err) if assert.Len(t, testdata.Nar1.NarText, len(string(plain))) {