diff --git a/sampling/reservoir_items_sketch.go b/sampling/reservoir_items_sketch.go index f9abaca..e219c13 100644 --- a/sampling/reservoir_items_sketch.go +++ b/sampling/reservoir_items_sketch.go @@ -180,6 +180,7 @@ func (s *ReservoirItemsSketch[T]) Copy() *ReservoirItemsSketch[T] { return &ReservoirItemsSketch[T]{ k: s.k, n: s.n, + rf: s.rf, data: dataCopy, } } @@ -191,7 +192,7 @@ func (s *ReservoirItemsSketch[T]) DownsampledCopy(newK int) (*ReservoirItemsSket return s.Copy(), nil } - result, err := NewReservoirItemsSketch[T](newK) + result, err := NewReservoirItemsSketch[T](newK, WithReservoirItemsSketchResizeFactor(s.rf)) if err != nil { return nil, err } @@ -230,14 +231,49 @@ const ( preambleIntsNonEmpty = 2 serVer = 2 flagEmpty = 0x04 - resizeFactorBits = 0xC0 // ResizeFactor X8 + resizeFactorMask = 0xC0 ) +func resizeFactorBitsFor(rf ResizeFactor) (byte, error) { + switch rf { + case ResizeX1: + return 0x00, nil + case ResizeX2: + return 0x40, nil + case ResizeX4: + return 0x80, nil + case ResizeX8: + return 0xC0, nil + default: + return 0, errors.New("unsupported resize factor") + } +} + +func resizeFactorFromHeaderByte(b byte) (ResizeFactor, error) { + switch (b & resizeFactorMask) >> 6 { + case 0: + return ResizeX1, nil + case 1: + return ResizeX2, nil + case 2: + return ResizeX4, nil + case 3: + return ResizeX8, nil + default: + return 0, errors.New("unsupported resize factor bits") + } +} + // ToSlice serializes the sketch to a byte slice. func (s *ReservoirItemsSketch[T]) ToSlice(serde ItemsSerDe[T]) ([]byte, error) { + rfBits, err := resizeFactorBitsFor(s.rf) + if err != nil { + return nil, err + } + if s.IsEmpty() { buf := make([]byte, 8) - buf[0] = resizeFactorBits | preambleIntsEmpty + buf[0] = rfBits | preambleIntsEmpty buf[1] = serVer buf[2] = byte(internal.FamilyEnum.ReservoirItems.Id) buf[3] = flagEmpty @@ -253,7 +289,7 @@ func (s *ReservoirItemsSketch[T]) ToSlice(serde ItemsSerDe[T]) ([]byte, error) { preambleBytes := preambleIntsNonEmpty * 8 buf := make([]byte, preambleBytes+len(itemsBytes)) - buf[0] = resizeFactorBits | preambleIntsNonEmpty + buf[0] = rfBits | preambleIntsNonEmpty buf[1] = serVer buf[2] = byte(internal.FamilyEnum.ReservoirItems.Id) buf[3] = 0 @@ -276,16 +312,29 @@ func NewReservoirItemsSketchFromSlice[T any](data []byte, serde ItemsSerDe[T]) ( family := data[2] flags := data[3] k := int(binary.LittleEndian.Uint32(data[4:])) + rf, err := resizeFactorFromHeaderByte(data[0]) + if err != nil { + return nil, err + } if ver != serVer { - return nil, errors.New("unsupported serialization version") + if ver == 1 { + encK := binary.LittleEndian.Uint16(data[4:]) + decodedK, err := decodeReservoirSize(encK) + if err != nil { + return nil, err + } + k = decodedK + } else { + return nil, errors.New("unsupported serialization version") + } } if family != byte(internal.FamilyEnum.ReservoirItems.Id) { return nil, errors.New("wrong sketch family") } if (flags&flagEmpty) != 0 || preambleInts == preambleIntsEmpty { - return NewReservoirItemsSketch[T](k) + return NewReservoirItemsSketch[T](k, WithReservoirItemsSketchResizeFactor(rf)) } preambleBytes := preambleIntsNonEmpty * 8 @@ -306,6 +355,7 @@ func NewReservoirItemsSketchFromSlice[T any](data []byte, serde ItemsSerDe[T]) ( return &ReservoirItemsSketch[T]{ k: k, n: n, + rf: rf, data: items, }, nil } diff --git a/sampling/reservoir_items_sketch_test.go b/sampling/reservoir_items_sketch_test.go index 468f0cd..f0a098e 100644 --- a/sampling/reservoir_items_sketch_test.go +++ b/sampling/reservoir_items_sketch_test.go @@ -18,6 +18,7 @@ package sampling import ( + "encoding/binary" "math" "testing" @@ -182,3 +183,32 @@ func TestReservoirItemsSketchGetSamplesIsCopy(t *testing.T) { assert.NotEqual(t, samples1[0], samples2[0]) assert.Equal(t, int64(42), samples2[0]) } + +func TestReservoirItemsSketchResizeFactorSerialization(t *testing.T) { + sketch, err := NewReservoirItemsSketch[int64](10, WithReservoirItemsSketchResizeFactor(ResizeX2)) + assert.NoError(t, err) + sketch.Update(1) + + data, err := sketch.ToSlice(Int64SerDe{}) + assert.NoError(t, err) + assert.Equal(t, byte(0x42), data[0]) // ResizeX2 (0x40) + preambleIntsNonEmpty (0x02) + + restored, err := NewReservoirItemsSketchFromSlice[int64](data, Int64SerDe{}) + assert.NoError(t, err) + assert.Equal(t, ResizeX2, restored.rf) +} + +func TestReservoirItemsSketchLegacySerVerEmpty(t *testing.T) { + data := make([]byte, 8) + data[0] = 0xC0 | preambleIntsEmpty + data[1] = 1 // legacy serVer + data[2] = byte(internal.FamilyEnum.ReservoirItems.Id) + data[3] = flagEmpty + binary.LittleEndian.PutUint16(data[4:], 0x5000) // p=10, i=0 => k=1024 + + sketch, err := NewReservoirItemsSketchFromSlice[int64](data, Int64SerDe{}) + assert.NoError(t, err) + assert.True(t, sketch.IsEmpty()) + assert.Equal(t, 1024, sketch.K()) + assert.Equal(t, ResizeX8, sketch.rf) +} diff --git a/sampling/reservoir_items_union.go b/sampling/reservoir_items_union.go index 8bcfd8b..f626602 100644 --- a/sampling/reservoir_items_union.go +++ b/sampling/reservoir_items_union.go @@ -109,6 +109,7 @@ func (u *ReservoirItemsUnion[T]) UpdateFromRaw(n int64, k int, items []T) error sketch := &ReservoirItemsSketch[T]{ k: k, n: n, + rf: defaultResizeFactor, data: items, } @@ -282,7 +283,16 @@ func NewReservoirItemsUnionFromSlice[T any](data []byte, serde ItemsSerDe[T]) (* } if ver != unionSerVer { - return nil, errors.New("unsupported serialization version") + if ver == 1 { + encMaxK := binary.LittleEndian.Uint16(data[4:]) + decodedMaxK, err := decodeReservoirSize(encMaxK) + if err != nil { + return nil, err + } + maxK = decodedMaxK + } else { + return nil, errors.New("unsupported serialization version") + } } if family != byte(internal.FamilyEnum.ReservoirUnion.Id) { return nil, errors.New("wrong sketch family") diff --git a/sampling/reservoir_items_union_test.go b/sampling/reservoir_items_union_test.go index 724e3df..b340478 100644 --- a/sampling/reservoir_items_union_test.go +++ b/sampling/reservoir_items_union_test.go @@ -18,8 +18,10 @@ package sampling import ( + "encoding/binary" "testing" + "github.com/apache/datasketches-go/internal" "github.com/stretchr/testify/assert" ) @@ -285,6 +287,23 @@ func TestReservoirItemsUnionSerialization(t *testing.T) { assert.True(t, result.IsEmpty()) }) + t.Run("LegacySerVerEmptyUnion", func(t *testing.T) { + data := make([]byte, 8) + data[0] = unionPreambleLongs + data[1] = 1 // legacy serVer + data[2] = byte(internal.FamilyEnum.ReservoirUnion.Id) + data[3] = unionFlagEmpty + binary.LittleEndian.PutUint16(data[4:], 0x5000) // p=10, i=0 => maxK=1024 + + union, err := NewReservoirItemsUnionFromSlice[int64](data, Int64SerDe{}) + assert.NoError(t, err) + assert.Equal(t, 1024, union.MaxK()) + + result, err := union.Result() + assert.NoError(t, err) + assert.True(t, result.IsEmpty()) + }) + t.Run("NonEmptyUnion", func(t *testing.T) { union, err := NewReservoirItemsUnion[int64](100) assert.NoError(t, err) diff --git a/sampling/reservoir_size.go b/sampling/reservoir_size.go new file mode 100644 index 0000000..840045f --- /dev/null +++ b/sampling/reservoir_size.go @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sampling + +import "errors" + +const ( + reservoirSizeBinsPerOctave = 2048 + reservoirSizeInvBinsPerOctave = 1.0 / reservoirSizeBinsPerOctave + reservoirSizeExponentMask = 0x1F + reservoirSizeExponentShift = 11 + reservoirSizeIndexMask = 0x07FF + reservoirSizeMaxEncValue = 0xF7FF // p=30, i=2047 +) + +func decodeReservoirSize(encoded uint16) (int, error) { + value := int(encoded) + if value > reservoirSizeMaxEncValue { + return 0, errors.New("invalid encoded reservoir size") + } + + p := (value >> reservoirSizeExponentShift) & reservoirSizeExponentMask + i := value & reservoirSizeIndexMask + + base := 1 << uint(p) + return int(float64(base) * ((float64(i) * reservoirSizeInvBinsPerOctave) + 1.0)), nil +}