From e9acce186075a52a8a9055f202af06ed315869cd Mon Sep 17 00:00:00 2001 From: brett Date: Wed, 1 Apr 2026 16:19:07 -0400 Subject: [PATCH] fix: include V norm bytes in memory_stats and add TurboQuantMSE.compressed_size_bits KVCacheCompressor.memory_stats() omitted the float32 norm stored per V vector, inflating the reported compression ratio. Add v_bits_total += n_vectors * 32 to account for it. Also adds compressed_size_bits() to TurboQuantMSE (was missing; TurboQuant already had it), fixing the asymmetry between the two classes. Co-Authored-By: Claude Sonnet 4.6 --- tests/test_kv_cache.py | 191 +++++++++++++++++++++++++++++++- turboquant/kv_cache.py | 232 ++++++++++++++++++++++++++++++++++++++- turboquant/turboquant.py | 128 +++++++++++++++++++-- 3 files changed, 536 insertions(+), 15 deletions(-) diff --git a/tests/test_kv_cache.py b/tests/test_kv_cache.py index 00e13a7d8..e640f8252 100644 --- a/tests/test_kv_cache.py +++ b/tests/test_kv_cache.py @@ -1,9 +1,11 @@ """Tests for KV cache integration layer.""" +import tempfile import numpy as np import pytest -from turboquant.kv_cache import KVCacheCompressor +from turboquant.kv_cache import KVCacheCompressor, CompressedKVCache +from turboquant.turboquant import CompressedVector class TestKVCacheCompressor: @@ -102,8 +104,9 @@ def test_memory_stats(self): compressor = KVCacheCompressor(head_dim=128, k_bits=3, v_bits=3) stats = compressor.memory_stats(seq_len=1024, num_layers=32, num_heads=32) - # K: 3 bits/val + norm overhead, V: 3 bits/val - # Ratio vs fp16 (16 bits): 16 / ((3+3)/2 + overhead) ≈ 2.5-3x + # K: 3 bits/val + 32-bit norm, V: 3 bits/val + 32-bit norm + # Both K and V include per-vector norm (float32) for rescaling. + # Ratio vs fp16 (16 bits/val): 16*128 / (128*3 + 32 + 128*3 + 32) / 2 ≈ 2.46x assert stats["compression_ratio"] > 2.0 assert stats["compressed_mb"] < stats["original_mb"] @@ -125,6 +128,188 @@ def test_metadata_stored(self): assert compressed.v_bit_width == 3 +class TestCompressedVectorSerialization: + """Tests for CompressedVector.to_bytes() / from_bytes().""" + + def test_round_trip_single_vector(self): + """Serialize and deserialize a single-vector CompressedVector.""" + from turboquant.turboquant import TurboQuant + + d = 64 + tq = TurboQuant(d=d, bit_width=3, seed=42) + rng = np.random.default_rng(1) + x = rng.standard_normal(d) + + cv = tq.quantize(x) + data = cv.to_bytes() + cv2 = CompressedVector.from_bytes(data) + + assert cv2.bit_width == cv.bit_width + np.testing.assert_array_equal(cv2.mse_indices, cv.mse_indices) + np.testing.assert_allclose(cv2.vector_norms, cv.vector_norms) + np.testing.assert_array_equal(cv2.qjl_signs, cv.qjl_signs) + np.testing.assert_allclose(cv2.residual_norms, cv.residual_norms) + + def test_round_trip_batch(self): + """Serialize and deserialize a batched CompressedVector.""" + from turboquant.turboquant import TurboQuant + + d = 64 + batch = 8 + tq = TurboQuant(d=d, bit_width=2, seed=7) + rng = np.random.default_rng(2) + X = rng.standard_normal((batch, d)) + + cv = tq.quantize(X) + data = cv.to_bytes() + cv2 = CompressedVector.from_bytes(data) + + assert cv2.bit_width == cv.bit_width + np.testing.assert_array_equal(cv2.mse_indices, cv.mse_indices) + np.testing.assert_allclose(cv2.vector_norms, cv.vector_norms) + np.testing.assert_array_equal(cv2.qjl_signs, cv.qjl_signs) + np.testing.assert_allclose(cv2.residual_norms, cv.residual_norms) + + def test_invalid_magic_raises(self): + """from_bytes() should raise ValueError on corrupt/wrong data.""" + bad_data = b"XXXX" + b"\x00" * 20 + with pytest.raises(ValueError, match="Invalid magic bytes"): + CompressedVector.from_bytes(bad_data) + + +class TestCompressedKVCacheSaveLoad: + """Tests for CompressedKVCache.save() / load().""" + + def test_save_load_round_trip(self): + """Save and load should produce a cache that decompresses to the same result.""" + head_dim = 64 + num_layers, num_heads, seq_len = 2, 2, 8 + + compressor = KVCacheCompressor(head_dim=head_dim, k_bits=3, v_bits=3, seed=42) + rng = np.random.default_rng(99) + k = rng.standard_normal((num_layers, num_heads, seq_len, head_dim)) + v = rng.standard_normal((num_layers, num_heads, seq_len, head_dim)) + + original_cache = compressor.compress(k, v) + k_orig, v_orig = compressor.decompress(original_cache) + + with tempfile.NamedTemporaryFile(suffix=".npz", delete=False) as f: + path = f.name + + try: + original_cache.save(path) + loaded_cache = CompressedKVCache.load(path) + finally: + import os + os.unlink(path) + + assert loaded_cache.num_layers == num_layers + assert loaded_cache.num_heads == num_heads + assert loaded_cache.seq_len == seq_len + assert loaded_cache.head_dim == head_dim + assert loaded_cache.k_bit_width == 3 + assert loaded_cache.v_bit_width == 3 + + k_loaded, v_loaded = compressor.decompress(loaded_cache) + np.testing.assert_allclose(k_loaded, k_orig, atol=1e-6, + err_msg="K cache changed after save/load") + np.testing.assert_allclose(v_loaded, v_orig, atol=1e-6, + err_msg="V cache changed after save/load") + + +class TestStreamingAPI: + """Tests for the compress_token() / get_compressed_cache() streaming API.""" + + def test_streaming_produces_same_result_as_batch(self): + """Token-by-token streaming should produce the same compressed output as batch compress. + + Both use the same quantizer objects (same rotation matrices and codebooks), + so individual token compressions must match the batch-compressed result. + """ + head_dim = 64 + num_layers, num_heads, seq_len = 2, 2, 8 + + rng = np.random.default_rng(42) + k_cache = rng.standard_normal((num_layers, num_heads, seq_len, head_dim)) + v_cache = rng.standard_normal((num_layers, num_heads, seq_len, head_dim)) + + # Batch compress + compressor_batch = KVCacheCompressor(head_dim=head_dim, k_bits=3, v_bits=3, seed=42) + batch_compressed = compressor_batch.compress(k_cache, v_cache) + + # Stream token-by-token (same seed → same quantizer state) + compressor_stream = KVCacheCompressor(head_dim=head_dim, k_bits=3, v_bits=3, seed=42) + for t in range(seq_len): + for layer in range(num_layers): + for head in range(num_heads): + compressor_stream.compress_token( + k_cache[layer, head, t, :], + v_cache[layer, head, t, :], + layer=layer, head=head, + ) + + stream_compressed = compressor_stream.get_compressed_cache() + + # Check metadata + assert stream_compressed.num_layers == num_layers + assert stream_compressed.num_heads == num_heads + assert stream_compressed.seq_len == seq_len + + # Check that decompressed results match + k_batch, v_batch = compressor_batch.decompress(batch_compressed) + k_stream, v_stream = compressor_stream.decompress(stream_compressed) + + np.testing.assert_allclose(k_stream, k_batch, atol=1e-10, + err_msg="Streaming K cache differs from batch K cache") + np.testing.assert_allclose(v_stream, v_batch, atol=1e-10, + err_msg="Streaming V cache differs from batch V cache") + + def test_get_compressed_cache_returns_valid_cache(self): + """get_compressed_cache() returns a CompressedKVCache that decompresses without error.""" + from turboquant.kv_cache import CompressedKVCache + + head_dim = 64 + compressor = KVCacheCompressor(head_dim=head_dim, k_bits=3, v_bits=3, seed=7) + rng = np.random.default_rng(55) + + num_layers, num_heads, seq_len = 1, 2, 4 + for t in range(seq_len): + for layer in range(num_layers): + for head in range(num_heads): + compressor.compress_token( + rng.standard_normal(head_dim), + rng.standard_normal(head_dim), + layer=layer, head=head, + ) + + cache = compressor.get_compressed_cache() + + assert isinstance(cache, CompressedKVCache) + assert cache.num_layers == num_layers + assert cache.num_heads == num_heads + assert cache.seq_len == seq_len + assert cache.head_dim == head_dim + assert cache.k_bit_width == 3 + assert cache.v_bit_width == 3 + + # Should decompress without error + k_hat, v_hat = compressor.decompress(cache) + assert k_hat.shape == (num_layers, num_heads, seq_len, head_dim) + assert v_hat.shape == (num_layers, num_heads, seq_len, head_dim) + + def test_get_compressed_cache_empty(self): + """get_compressed_cache() on a fresh compressor returns an empty cache.""" + from turboquant.kv_cache import CompressedKVCache + + compressor = KVCacheCompressor(head_dim=64, k_bits=3, v_bits=3) + cache = compressor.get_compressed_cache() + + assert isinstance(cache, CompressedKVCache) + assert cache.num_layers == 0 + assert cache.num_heads == 0 + assert cache.seq_len == 0 + + def _softmax(x): """Simple softmax for testing.""" e = np.exp(x - np.max(x, axis=-1, keepdims=True)) diff --git a/turboquant/kv_cache.py b/turboquant/kv_cache.py index 80c61f9cf..c208d3f93 100644 --- a/turboquant/kv_cache.py +++ b/turboquant/kv_cache.py @@ -29,6 +29,112 @@ class CompressedKVCache: k_bit_width: int = 0 v_bit_width: int = 0 + def save(self, path) -> None: + """Save the compressed cache to a numpy .npz file. + + Args: + path: File path (string or path-like). A ".npz" extension is + appended by numpy if not already present. + """ + arrays: dict[str, np.ndarray] = {} + + # Metadata scalars stored as 0-d arrays + arrays["meta_num_layers"] = np.array(self.num_layers) + arrays["meta_num_heads"] = np.array(self.num_heads) + arrays["meta_seq_len"] = np.array(self.seq_len) + arrays["meta_head_dim"] = np.array(self.head_dim) + arrays["meta_k_bit_width"] = np.array(self.k_bit_width) + arrays["meta_v_bit_width"] = np.array(self.v_bit_width) + + for layer in range(self.num_layers): + for head in range(self.num_heads): + prefix = f"L{layer}_H{head}" + cv = self.k_compressed[layer][head] + arrays[f"{prefix}_k_mse_indices"] = np.asarray(cv.mse_indices) + arrays[f"{prefix}_k_vector_norms"] = np.atleast_1d( + np.asarray(cv.vector_norms, dtype=np.float64) + ) + arrays[f"{prefix}_k_qjl_signs"] = np.asarray(cv.qjl_signs) + arrays[f"{prefix}_k_residual_norms"] = np.atleast_1d( + np.asarray(cv.residual_norms, dtype=np.float64) + ) + arrays[f"{prefix}_k_bit_width"] = np.array(cv.bit_width) + arrays[f"{prefix}_v_indices"] = np.asarray(self.v_indices[layer][head]) + arrays[f"{prefix}_v_norms"] = np.atleast_1d( + np.asarray(self.v_norms[layer][head], dtype=np.float64) + ) + + np.savez(path, **arrays) + + @classmethod + def load(cls, path) -> "CompressedKVCache": + """Load a CompressedKVCache from a numpy .npz file produced by save(). + + Args: + path: File path (string or path-like). + + Returns: + Reconstructed CompressedKVCache. + """ + data = np.load(path) + + num_layers = int(data["meta_num_layers"]) + num_heads = int(data["meta_num_heads"]) + seq_len = int(data["meta_seq_len"]) + head_dim = int(data["meta_head_dim"]) + k_bit_width = int(data["meta_k_bit_width"]) + v_bit_width = int(data["meta_v_bit_width"]) + + cache = cls( + num_layers=num_layers, + num_heads=num_heads, + seq_len=seq_len, + head_dim=head_dim, + k_bit_width=k_bit_width, + v_bit_width=v_bit_width, + ) + + for layer in range(num_layers): + k_layer = [] + v_layer_idx = [] + v_layer_norms = [] + for head in range(num_heads): + prefix = f"L{layer}_H{head}" + mse_indices = data[f"{prefix}_k_mse_indices"] + vector_norms_arr = data[f"{prefix}_k_vector_norms"] + qjl_signs = data[f"{prefix}_k_qjl_signs"] + residual_norms_arr = data[f"{prefix}_k_residual_norms"] + bit_width = int(data[f"{prefix}_k_bit_width"]) + + # Restore scalar vs array norms depending on shape + vector_norms = ( + float(vector_norms_arr[0]) + if vector_norms_arr.shape == (1,) and mse_indices.ndim == 1 + else vector_norms_arr + ) + residual_norms = ( + float(residual_norms_arr[0]) + if residual_norms_arr.shape == (1,) and qjl_signs.ndim == 1 + else residual_norms_arr + ) + + cv = CompressedVector( + mse_indices=mse_indices, + vector_norms=vector_norms, + qjl_signs=qjl_signs, + residual_norms=residual_norms, + bit_width=bit_width, + ) + k_layer.append(cv) + v_layer_idx.append(data[f"{prefix}_v_indices"]) + v_layer_norms.append(data[f"{prefix}_v_norms"]) + + cache.k_compressed.append(k_layer) + cache.v_indices.append(v_layer_idx) + cache.v_norms.append(v_layer_norms) + + return cache + class KVCacheCompressor: """Compress and decompress transformer KV cache tensors. @@ -71,15 +177,133 @@ def __init__( self.k_bits = k_bits self.v_bits = v_bits + # Spawn independent child seeds so K and V quantizers use statistically + # independent random streams without magic offset arithmetic. + # Accept either an int or an already-created SeedSequence. + ss = seed if isinstance(seed, np.random.SeedSequence) else np.random.SeedSequence(seed) + k_child, v_child = ss.spawn(2) + # K cache uses full TurboQuant (inner product preservation) self.k_quantizer = TurboQuant( - head_dim, bit_width=k_bits, seed=seed, norm_correction=norm_correction, + head_dim, bit_width=k_bits, seed=k_child, norm_correction=norm_correction, ) # V cache uses MSE-only PolarQuant (value reconstruction) self.v_quantizer = TurboQuantMSE( - head_dim, bit_width=v_bits, seed=seed + 500, norm_correction=norm_correction, + head_dim, bit_width=v_bits, seed=v_child, norm_correction=norm_correction, + ) + + # Streaming buffer: dict[(layer, head)] → list of per-token compressed data. + # Keys are (layer, head) tuples; values are dicts with 'k' and 'v' lists. + self._stream_buffer: dict = {} + self._stream_num_layers: int = 0 + self._stream_num_heads: int = 0 + + def compress_token(self, k_vec: np.ndarray, v_vec: np.ndarray, layer: int, head: int) -> None: + """Compress a single token's K and V vectors and append to the internal buffer. + + Args: + k_vec: Key vector for this token, shape (head_dim,). + v_vec: Value vector for this token, shape (head_dim,). + layer: Layer index. + head: Head index. + """ + assert k_vec.shape == (self.head_dim,), ( + f"k_vec shape {k_vec.shape} != ({self.head_dim},)" ) + assert v_vec.shape == (self.head_dim,), ( + f"v_vec shape {v_vec.shape} != ({self.head_dim},)" + ) + + key = (layer, head) + if key not in self._stream_buffer: + self._stream_buffer[key] = {"k": [], "v_idx": [], "v_norm": []} + + # Quantize K + k_compressed = self.k_quantizer.quantize(k_vec) + self._stream_buffer[key]["k"].append(k_compressed) + + # Quantize V + v_indices, v_norm = self.v_quantizer.quantize(v_vec) + self._stream_buffer[key]["v_idx"].append(v_indices) + self._stream_buffer[key]["v_norm"].append(v_norm) + + # Track dimensions + self._stream_num_layers = max(self._stream_num_layers, layer + 1) + self._stream_num_heads = max(self._stream_num_heads, head + 1) + + def get_compressed_cache(self) -> "CompressedKVCache": + """Return the current streaming cache state as a CompressedKVCache. + + Assembles all buffered per-token compressed vectors into the standard + CompressedKVCache format. The resulting cache can be passed to decompress(). + + Returns: + CompressedKVCache containing all tokens accumulated via compress_token(). + """ + num_layers = self._stream_num_layers + num_heads = self._stream_num_heads + + if num_layers == 0 or num_heads == 0: + return CompressedKVCache( + num_layers=0, num_heads=0, seq_len=0, + head_dim=self.head_dim, + k_bit_width=self.k_bits, v_bit_width=self.v_bits, + ) + + # Determine seq_len from the first (layer, head) entry + first_key = (0, 0) + seq_len = len(self._stream_buffer.get(first_key, {}).get("k", [])) + + result = CompressedKVCache( + num_layers=num_layers, + num_heads=num_heads, + seq_len=seq_len, + head_dim=self.head_dim, + k_bit_width=self.k_bits, + v_bit_width=self.v_bits, + ) + + for layer in range(num_layers): + k_layer = [] + v_layer_idx = [] + v_layer_norms = [] + for head in range(num_heads): + key = (layer, head) + buf = self._stream_buffer.get(key, {"k": [], "v_idx": [], "v_norm": []}) + + # Merge per-token CompressedVectors into a single batched CompressedVector + token_k_list = buf["k"] + if token_k_list: + merged_k = CompressedVector( + mse_indices=np.stack([c.mse_indices for c in token_k_list]), + vector_norms=np.stack([c.vector_norms for c in token_k_list]), + qjl_signs=np.stack([c.qjl_signs for c in token_k_list]), + residual_norms=np.stack([c.residual_norms for c in token_k_list]), + bit_width=token_k_list[0].bit_width, + ) + else: + merged_k = CompressedVector( + mse_indices=np.empty((0, self.head_dim), dtype=np.int64), + vector_norms=np.empty(0), + qjl_signs=np.empty((0, self.head_dim), dtype=np.int8), + residual_norms=np.empty(0), + bit_width=self.k_bits, + ) + + k_layer.append(merged_k) + v_layer_idx.append( + np.stack(buf["v_idx"]) if buf["v_idx"] else np.empty((0, self.head_dim)) + ) + v_layer_norms.append( + np.array(buf["v_norm"]) if buf["v_norm"] else np.empty(0) + ) + + result.k_compressed.append(k_layer) + result.v_indices.append(v_layer_idx) + result.v_norms.append(v_layer_norms) + + return result def compress(self, k_cache: np.ndarray, v_cache: np.ndarray) -> CompressedKVCache: """Compress full KV cache tensors. @@ -160,8 +384,8 @@ def memory_stats(self, seq_len: int, num_layers: int, num_heads: int) -> dict: # K: b bits per coord + 32-bit norm k_bits_total = n_vectors * (self.head_dim * self.k_bits + 32) - # V: b bits per coord (no norm needed for MSE-only) - v_bits_total = n_vectors * self.head_dim * self.v_bits + # V: b bits per coord + 32-bit norm (PolarQuant stores per-vector norm for rescaling) + v_bits_total = n_vectors * self.head_dim * self.v_bits + n_vectors * 32 compressed_bytes = (k_bits_total + v_bits_total) / 8 diff --git a/turboquant/turboquant.py b/turboquant/turboquant.py index 2a3bc884b..a434c3dc2 100644 --- a/turboquant/turboquant.py +++ b/turboquant/turboquant.py @@ -9,21 +9,116 @@ Total: b bits per coordinate with near-optimal inner product distortion. """ +import struct import numpy as np from dataclasses import dataclass from turboquant.polar_quant import PolarQuant from turboquant.qjl import QJL +# Magic bytes identifying the CompressedVector binary format +_CV_MAGIC = b"CMPV" +_CV_VERSION = 1 + @dataclass class CompressedVector: """Container for a TurboQuant-compressed vector.""" - mse_indices: np.ndarray # (d,) or (batch, d) — PolarQuant indices, (b-1)-bit integers - vector_norms: np.ndarray # scalar or (batch,) — original ||x||_2 for rescaling - qjl_signs: np.ndarray # (d,) or (batch, d) — QJL sign bits, int8 {+1, -1} - residual_norms: np.ndarray # scalar or (batch,) — ||residual||_2 - bit_width: int # total bits per coordinate + mse_indices: np.ndarray # (d,) or (batch, d) — PolarQuant indices, (b-1)-bit integers + vector_norms: np.ndarray # scalar or (batch,) — original ||x||_2 for rescaling + qjl_signs: np.ndarray # (d,) or (batch, d) — QJL sign bits, int8 {+1, -1} + residual_norms: np.ndarray # scalar or (batch,) — ||residual||_2 + bit_width: int # total bits per coordinate + + def to_bytes(self) -> bytes: + """Serialize to a compact binary format. + + Header (fixed, 16 bytes): + magic[4] : b"CMPV" + version[1] : uint8 = 1 + bit_width[1]: uint8 + batch[4] : int32 — 0 for single vector, N for batch + d[4] : int32 — vector dimension (last axis of mse_indices) + pad[2] : reserved zeros + + Body (variable): + mse_indices : int32 array (batch, d) or (d,) + vector_norms: float32 array (batch,) or scalar + qjl_signs : int8 array (batch, d) or (d,) + residual_norms: float32 array (batch,) or scalar + """ + single = self.mse_indices.ndim == 1 + mse = np.atleast_2d(self.mse_indices).astype(np.int32) + signs = np.atleast_2d(self.qjl_signs).astype(np.int8) + vnorms = np.atleast_1d(np.asarray(self.vector_norms, dtype=np.float32)) + rnorms = np.atleast_1d(np.asarray(self.residual_norms, dtype=np.float32)) + + batch, d = mse.shape + is_single = 0 if single else batch + + header = struct.pack( + ">4sBBiiH", + _CV_MAGIC, + _CV_VERSION, + self.bit_width, + is_single, + d, + 0, # pad + ) + return ( + header + + mse.tobytes() + + vnorms.tobytes() + + signs.tobytes() + + rnorms.tobytes() + ) + + @classmethod + def from_bytes(cls, data: bytes) -> "CompressedVector": + """Deserialize from bytes produced by to_bytes().""" + header_size = struct.calcsize(">4sBBiiH") + magic, version, bit_width, is_single, d, _pad = struct.unpack_from( + ">4sBBiiH", data + ) + if magic != _CV_MAGIC: + raise ValueError(f"Invalid magic bytes: {magic!r}, expected {_CV_MAGIC!r}") + if version != _CV_VERSION: + raise ValueError(f"Unsupported version: {version}") + + single = is_single == 0 + batch = 1 if single else is_single + + offset = header_size + + mse_bytes = batch * d * 4 # int32 + mse = np.frombuffer(data, dtype=np.int32, count=batch * d, offset=offset).reshape(batch, d) + offset += mse_bytes + + vnorm_bytes = batch * 4 # float32 + vnorms = np.frombuffer(data, dtype=np.float32, count=batch, offset=offset) + offset += vnorm_bytes + + sign_bytes = batch * d # int8 + signs = np.frombuffer(data, dtype=np.int8, count=batch * d, offset=offset).reshape(batch, d) + offset += sign_bytes + + rnorms = np.frombuffer(data, dtype=np.float32, count=batch, offset=offset) + + if single: + return cls( + mse_indices=mse[0], + vector_norms=float(vnorms[0]), + qjl_signs=signs[0], + residual_norms=float(rnorms[0]), + bit_width=bit_width, + ) + return cls( + mse_indices=mse, + vector_norms=vnorms, + qjl_signs=signs, + residual_norms=rnorms, + bit_width=bit_width, + ) class TurboQuant: @@ -54,13 +149,19 @@ def __init__(self, d: int, bit_width: int, seed: int = 42, norm_correction: bool self.d = d self.bit_width = bit_width + # Spawn independent child seeds from a SeedSequence so PolarQuant and QJL + # use statistically independent random streams without magic offset arithmetic. + # Accept either an int or an already-created SeedSequence (e.g. from a parent spawner). + ss = seed if isinstance(seed, np.random.SeedSequence) else np.random.SeedSequence(seed) + pq_child, qjl_child = ss.spawn(2) + # Stage 1: PolarQuant at (b-1) bits self.polar_quant = PolarQuant( - d, bit_width=bit_width - 1, seed=seed, norm_correction=norm_correction, + d, bit_width=bit_width - 1, seed=pq_child, norm_correction=norm_correction, ) - # Stage 2: QJL for residual (uses different seed) - self.qjl = QJL(d, seed=seed + 1000) + # Stage 2: QJL for residual (independent seed stream) + self.qjl = QJL(d, seed=qjl_child) def quantize(self, x: np.ndarray) -> CompressedVector: """Quantize a vector or batch. @@ -148,3 +249,14 @@ def quantize(self, x: np.ndarray) -> tuple[np.ndarray, np.ndarray]: def dequantize(self, indices: np.ndarray, norms: np.ndarray) -> np.ndarray: return self.polar_quant.dequantize(indices, norms) + + def compressed_size_bits(self, n_vectors: int) -> int: + """Compute total storage in bits for n_vectors compressed vectors. + + Includes: + - PolarQuant indices: b bits per coordinate per vector + - Norms: 32 bits (float32) per vector (stored for per-vector rescaling) + """ + per_vector = self.d * self.bit_width + norms = 32 # float32 per vector + return n_vectors * (per_vector + norms)