diff --git a/hparams/hparams.json b/hparams/hparams.json index 6583e9ce0..8485f7f8c 100644 --- a/hparams/hparams.json +++ b/hparams/hparams.json @@ -46,6 +46,7 @@ "checkpoint_init_version": "2.1.15", "checkpoint_init_window": 59637, "num_evaluation_bins": 5, + "compression_b_choices": [32, 64, 128], "quantization_bins": 4, "quantization_range": 6, "burn_rate": 0.5, diff --git a/neurons/miner.py b/neurons/miner.py index 84bb43d15..f6d21c6e4 100644 --- a/neurons/miner.py +++ b/neurons/miner.py @@ -218,6 +218,7 @@ def __init__(self): use_quantization=True, quantization_bins=self.hparams.quantization_bins, quantization_range=self.hparams.quantization_range, + b_choices=self.hparams.compression_b_choices, ) tplr.logger.info("[Init] compression pipeline ready") diff --git a/neurons/validator.py b/neurons/validator.py index bcc36ffc0..8e673510f 100644 --- a/neurons/validator.py +++ b/neurons/validator.py @@ -384,6 +384,7 @@ def __init__(self): use_quantization=True, quantization_bins=self.hparams.quantization_bins, quantization_range=self.hparams.quantization_range, + b_choices=self.hparams.compression_b_choices, ) # Init optimizer diff --git a/scripts/analyser.py b/scripts/analyser.py index d48167132..3e916ffc7 100644 --- a/scripts/analyser.py +++ b/scripts/analyser.py @@ -88,6 +88,7 @@ def __init__(self): use_quantization=True, quantization_bins=self.hparams.quantization_bins, quantization_range=self.hparams.quantization_range, + b_choices=self.hparams.compression_b_choices, ) # Initialize shapes for each parameter (like in miner/validator) diff --git a/src/tplr/comms.py b/src/tplr/comms.py index ae2194bc2..0eeadf38e 100644 --- a/src/tplr/comms.py +++ b/src/tplr/comms.py @@ -34,6 +34,7 @@ import bittensor as bt import boto3 import botocore +import numpy as np import torch from aiobotocore.client import AioBaseClient from aiobotocore.session import get_session @@ -43,7 +44,8 @@ import tplr from tplr.chain import ChainManager -from tplr.compress import TopKCompressor, 
unpack_12bit_indices +from tplr.compress import TopKCompressor +from tplr.compression import decode_batch_rows, unpack_12bit_indices from tplr.config import BUCKET_SECRETS, client_config from tplr.schemas import Bucket, CommsGetResult @@ -2478,10 +2480,8 @@ def check_compressed_indices( """ Validates the integrity and format of compressed gradient indices. - This is a crucial security and stability check to ensure that gradients - received from peers are well-formed. It verifies that indices are within - the expected bounds and that the compression format (e.g., 12-bit packing) - is correctly applied. + This ensures indices are within bounds and that the **new Rice/bitmap** + codec payload matches the provided values tensor shape (top‑k). Args: param_name (str): The name of the parameter being checked. @@ -2494,7 +2494,7 @@ def check_compressed_indices( Raises: ValueError: If any validation check fails, such as out-of-bounds - indices, incorrect data types, or malformed packed data. + indices, incorrect data types, or malformed payload. 
""" allowed_topk = ( min(self.hparams.topk_compression, totalk) @@ -2502,34 +2502,60 @@ def check_compressed_indices( else min(allowed_topk, totalk) ) - def _bounds_check(t: torch.Tensor): - """fast min/max bounds check""" - if t.numel() == 0: - raise ValueError(f"[{param_name}] empty index list") - if t.min().item() < 0 or t.max().item() >= totalk: - bad = t[(t < 0) | (t >= totalk)][0].item() + if not isinstance(idxs, torch.Tensor): + raise ValueError( + f"[{param_name}] Expected tensor for indices, got {type(idxs)}" + ) + if vals is None: + raise ValueError( + f"[{param_name}] Values tensor required for index validation" + ) + if idxs.dtype != torch.uint8: + raise ValueError( + f"[{param_name}] Expected uint8 (Rice/bitmap payload), got {idxs.dtype}" + ) + if idxs.numel() == 0: + raise ValueError(f"[{param_name}] Empty indices payload") + + try: + rows, C, N = decode_batch_rows(idxs) + if C != totalk: raise ValueError( - f"[{param_name}] Index {bad} out of bounds (totalk = {totalk})" + f"[{param_name}] Payload column size C={C} but expected {totalk}" ) - - # Handle 12-bit packed index format only - if isinstance(idxs, torch.Tensor): - if idxs.dtype != torch.uint8: + # compute expected rows from values shape (flatten all but last dim) + if vals.ndim == 0: + raise ValueError(f"[{param_name}] Values tensor has no top‑k dimension") + expected_rows = int(np.prod(vals.shape[:-1])) if vals.ndim > 1 else 1 + if N != expected_rows: raise ValueError( - f"[{param_name}] Expected uint8 for 12-bit packed indices, got {idxs.dtype}" + f"[{param_name}] Payload rows N={N} but values imply {expected_rows}" ) - # 12-bit packed format is the only supported format - if vals is None: + + k = vals.shape[-1] + if k != allowed_topk: raise ValueError( - f"[{param_name}] Values tensor required to validate 12-bit packed indices" + f"[{param_name}] Payload K={rows.shape[-1]} but values top-k={k}" ) - if idxs.numel() == 0: - raise ValueError(f"[{param_name}] Empty packed indices tensor") - # 
Unpack using the values shape + # Bounds check + if rows.numel() > 0: + min_idx = int(rows.min().item()) + max_idx = int(rows.max().item()) + else: + min_idx, max_idx = 0, -1 + if min_idx < 0 or max_idx >= totalk: + raise ValueError( + f"[{param_name}] Index out of bounds (min={min_idx}, max={max_idx}, totalk={totalk})" + ) + except ValueError as e: + # NB: legacy path + tplr.logger.warning( + f"Failed to unpack: {e} Falling back to legacy uncompress." + ) + # Fallback: likely old format -> try legacy decoder try: unpacked = unpack_12bit_indices(idxs, vals.shape) - # Validate that the last dimension matches allowed_topk if unpacked.shape[-1] != allowed_topk: raise ValueError( f"[{param_name}] Invalid topk dimension: " @@ -2537,9 +2563,11 @@ def _bounds_check(t: torch.Tensor): ) _bounds_check(unpacked) except Exception as e: - raise ValueError(f"[{param_name}] Failed to unpack 12-bit indices: {e}") - else: - raise ValueError(f"[{param_name}] Expected tensor but got {type(idxs)}") + raise ValueError( + f"[{param_name}] Failed to unpack legacy 12-bit indices: {e}" + ) + except Exception as e: + raise ValueError(f"[{param_name}] Failed to decode indices payload: {e}") async def s3_get_object_size(self, bucket: Bucket, key: str) -> int | None: """ diff --git a/src/tplr/compress.py b/src/tplr/compress.py index 206f3cd8d..b03decd13 100644 --- a/src/tplr/compress.py +++ b/src/tplr/compress.py @@ -29,6 +29,7 @@ from torch.distributed.tensor import DTensor as DT import tplr +from tplr.compression import decode_batch_rows, encode_batch_rows, unpack_12bit_indices # ─────────── type aliases ──────────────────────────────────────────────── # primitive shapes @@ -48,100 +49,6 @@ Q = TypeVar("Q", Literal[True], Literal[False]) -def pack_12bit_indices(indices: torch.Tensor) -> torch.Tensor: - """ - Pack int64 indices into 12-bit representation. - Every 2 indices (24 bits) are packed into 3 uint8 values. - Assumes even number of indices (topk is always even). 
- - Args: - indices: Tensor with values < 4096 (12-bit max), must have even number of elements - - Returns: - packed_tensor as uint8 - """ - # Ensure indices fit in 12 bits - max_idx = indices.max().item() if indices.numel() > 0 else 0 - if max_idx >= 4096: - raise ValueError(f"Index {max_idx} exceeds 12-bit limit (4095)") - - # Flatten the tensor - indices_flat = indices.flatten() - n_indices = indices_flat.numel() - - # Ensure we have even number of indices - if n_indices % 2 != 0: - raise ValueError(f"Number of indices must be even, got {n_indices}") - - # Convert to int32 for bit manipulation - indices_flat = indices_flat.to(torch.int32) - - # Process all as pairs - indices_pairs = indices_flat - n_pairs = n_indices // 2 - - # Calculate packed size - packed_size = n_pairs * 3 - packed = torch.zeros(packed_size, dtype=torch.uint8, device=indices.device) - - # Vectorized packing for pairs - if n_pairs > 0: - idx_pairs = indices_pairs.reshape(-1, 2) - idx1 = idx_pairs[:, 0] - idx2 = idx_pairs[:, 1] - - # Pack pairs: idx1 uses byte0 + lower 4 bits of byte1 - # idx2 uses upper 4 bits of byte1 + byte2 - packed[0::3] = (idx1 & 0xFF).to(torch.uint8) # Lower 8 bits of idx1 - packed[1::3] = (((idx1 >> 8) & 0x0F) | ((idx2 & 0x0F) << 4)).to(torch.uint8) - packed[2::3] = ((idx2 >> 4) & 0xFF).to(torch.uint8) # Upper 8 bits of idx2 - - return packed - - -def unpack_12bit_indices(packed: torch.Tensor, values_shape: ShapeT) -> torch.Tensor: - """ - Unpack 12-bit packed indices back to int64. - Assumes even number of indices. 
- - Args: - packed: Packed uint8 tensor - values_shape: Shape of the values tensor (same as original indices shape) - - Returns: - Unpacked indices as int64 tensor with original shape - """ - n_indices = int(torch.prod(torch.tensor(values_shape)).item()) - - if n_indices == 0: - return torch.zeros(values_shape, dtype=torch.int64, device=packed.device) - - # Ensure even number of indices - if n_indices % 2 != 0: - raise ValueError(f"Number of indices must be even, got {n_indices}") - - # Prepare output - indices = torch.zeros(n_indices, dtype=torch.int64, device=packed.device) - - # All indices are paired - n_pairs = n_indices // 2 - - if n_pairs > 0: - # Vectorized unpacking - byte0 = packed[0::3].to(torch.int64) - byte1 = packed[1::3].to(torch.int64) - byte2 = packed[2::3].to(torch.int64) - - # Reconstruct indices - indices[0::2] = byte0 | ((byte1 & 0x0F) << 8) # idx1 - indices[1::2] = ((byte1 >> 4) & 0x0F) | (byte2 << 4) # idx2 - - # Reshape to match values shape - indices = indices.reshape(values_shape) - - return indices - - class ChunkingTransformer: """ A transformer for chunking tensors to enable more efficient gradient processing. @@ -301,9 +208,12 @@ class TopKCompressor(Generic[Q]): It supports both 1D and 2D tensors. """ + DEFAULT_B_CHOICES: tuple[int, ...] = (64, 128) + use_quantization: Q n_bins: int range_in_sigmas: int + b_choices: tuple[int, ...] # for Rice/bitmap codec # ------------------------------------------------------------------ # # Constructor – two overloads so each instance "remembers" its mode @@ -315,6 +225,7 @@ def __init__( use_quantization: Literal[True] = True, quantization_bins: int = 256, quantization_range: int = 6, + b_choices: tuple[int, ...] | None = None, ) -> None: ... @overload @@ -324,6 +235,7 @@ def __init__( use_quantization: Literal[False] = False, quantization_bins: int = 256, quantization_range: int = 6, + b_choices: tuple[int, ...] | None = None, ) -> None: ... 
@torch.no_grad() @@ -333,6 +245,7 @@ def __init__( use_quantization: bool = False, quantization_bins: int = 256, quantization_range: int = 6, + b_choices: tuple[int, ...] | None = None, ) -> None: """ Initialise the TopKCompressor. @@ -349,6 +262,14 @@ def __init__( quantization_range # Quantization range in standard deviations ) + if b_choices is None: + b_choices = self.DEFAULT_B_CHOICES + b_choices = tuple(sorted(int(b) for b in b_choices)) + for b in b_choices: + if b <= 0 or (b & (b - 1)) != 0: + raise ValueError(f"b_choices must be powers of two > 0, got {b}") + self.b_choices = b_choices + def _clamp_topk(self, x, topk) -> int: """ Clamp the top-k value to be within the valid range and ensure it's even. @@ -406,21 +327,34 @@ def compress(self, x: torch.Tensor, topk: int): # type: ignore[override] totalk = x.shape[-1] topk = self._clamp_topk(x, topk) - idx_int64 = torch.topk( - x.abs(), k=topk, dim=-1, largest=True, sorted=False - ).indices - val = torch.gather(x, dim=-1, index=idx_int64) + topk_vals, idx = torch.topk(x.abs(), k=topk, dim=-1, largest=True, sorted=False) + del topk_vals + + idx = idx.to(torch.int32) + val = torch.gather(x, dim=-1, index=idx) + + # Flatten to [rows, k] for the codec + idx2d = idx.view(-1, topk) + val2d = val.view(-1, topk) + + # sort indices and apply same perm to values + idx_sorted, perm = torch.sort(idx2d, dim=1) + val = torch.gather(val2d, dim=1, index=perm) - # Pack indices into 12-bit representation for efficient storage - # This reduces storage by 25% compared to int16 - idx = pack_12bit_indices(idx_int64) + # pick only B_choices that divide this layer's C + valid_B = tuple(b for b in self.b_choices if totalk % b == 0) + if not valid_B: + raise ValueError( + f"No valid b_choices for C={totalk}; " + f"b_choices={self.b_choices} must divide C" + ) + idx_bytes, _meta = encode_batch_rows(idx_sorted, C=totalk, B_choices=valid_B) # Apply 8-bit quantization if enabled if self.use_quantization: - val, quant_params = 
self._quantize_values(val) - return idx, val, xshape, totalk, quant_params - - return idx, val, xshape, totalk + val, qparams = self._quantize_values(val) + return idx_bytes, val, xshape, totalk, qparams + return idx_bytes, val, xshape, totalk @torch.no_grad() def decompress( @@ -454,21 +388,31 @@ def decompress( if len(xshape) > 2: # 2D weights x = rearrange(x, "y x h w -> y x (h w)") - # Unpack 12-bit indices using val shape (if needed) + # Decode indices if idx.dtype == torch.uint8: - # 12-bit packed format - unpack it - idx_int64 = unpack_12bit_indices(idx, val.shape) + rows, C, _N = decode_batch_rows(idx) + if C != totalk: + raise ValueError(f"Index payload C={C} but expected {totalk}") + if rows.shape[-1] != val.shape[-1]: + raise ValueError( + f"Row-wise topk size mismatch: decoded K={rows.shape[-1]}, val K={val.shape[-1]}" + ) + idx_int64 = rows.to(device=p.device, dtype=torch.int64) + idx_int64 = idx_int64.view_as(val) elif idx.dtype in (torch.int64, torch.long): - # Already unpacked (from batch_decompress) - idx_int64 = idx + idx_int64 = idx.to(p.device) else: - raise ValueError( - f"Expected uint8 (packed) or int64 (unpacked) indices, got {idx.dtype}" - ) + raise ValueError(f"Unsupported index tensor dtype: {idx.dtype}") + # Ensure val has the same dtype as x for scatter operation if val.dtype != x.dtype: val = val.to(dtype=x.dtype) + # second condition for legacy decompress + if len(xshape) > 2 and len(x) != len(idx_int64): + idx_int64 = rearrange(idx_int64, "(y x) h -> y x h", y=xshape[0]) + val = rearrange(val, "(y x) h -> y x h", y=xshape[0]) + x.scatter_reduce_( dim=-1, index=idx_int64, src=val, reduce="mean", include_self=False ).reshape(xshape) @@ -562,14 +506,35 @@ def batch_decompress( idx_list = idx if isinstance(idx, Sequence) else [idx] for i, i_data in enumerate(idx_list): - if i_data.dtype != torch.uint8: - raise ValueError( - f"Expected uint8 for 12-bit packed indices, got {i_data.dtype}" - ) - # Unpack 12-bit format using corresponding 
values shape v_data = val_list[i] - idx_unpacked = unpack_12bit_indices(i_data.to(p.device), v_data.shape) - unpacked_indices.append(idx_unpacked) + if i_data.dtype == torch.uint8: + try: + rows, C, _N = decode_batch_rows(i_data) + if C != totalk: + raise ValueError(f"Index payload C={C} but expected {totalk}") + if rows.shape[-1] != v_data.shape[-1]: + raise ValueError( + f"Row-wise topk size mismatch: decoded K={rows.shape[-1]}, " + f"val K={v_data.shape[-1]}" + ) + idx_int64 = rows.to(device=p.device, dtype=torch.int64) + idx_unpacked = idx_int64.view_as(v_data) + except ValueError as e: + # NB: legacy path + tplr.logger.warning( + f"Failed to unpack: {e}. Falling back to legacy uncompress." + ) + # Fallback: likely old format -> try legacy decoder + idx_unpacked = unpack_12bit_indices( + i_data.to(p.device), v_data.shape + ) + + unpacked_indices.append(idx_unpacked) + elif i_data.dtype in (torch.int64, torch.long): + idx_unpacked = i_data.to(p.device) + unpacked_indices.append(idx_unpacked) + else: + raise ValueError(f"Unsupported index dtype in batch: {i_data.dtype}") idx_concat = torch.cat(unpacked_indices, dim=-1) val_concat = torch.cat(processed_vals, dim=-1).to(p.dtype) diff --git a/src/tplr/compression/__init__.py b/src/tplr/compression/__init__.py new file mode 100644 index 000000000..31d94bfb8 --- /dev/null +++ b/src/tplr/compression/__init__.py @@ -0,0 +1,29 @@ +# The MIT License (MIT) +# © 2025 tplr.ai +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all copies or substantial 
portions of +# the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +from .hybrid import ( + decode_batch_rows, # decoder (CPU) + encode_batch_rows, # GPU-accelerated encoder → bytes + perm + meta +) +from .pack12 import pack_12bit_indices, unpack_12bit_indices + +__all__ = [ + "encode_batch_rows", + "decode_batch_rows", + "pack_12bit_indices", + "unpack_12bit_indices", +] diff --git a/src/tplr/compression/hybrid.py b/src/tplr/compression/hybrid.py new file mode 100644 index 000000000..e8e5fb27f --- /dev/null +++ b/src/tplr/compression/hybrid.py @@ -0,0 +1,667 @@ +import math +import struct +from typing import Dict, Tuple, Union + +import numpy as np +import torch +import triton +import triton.language as tl + +BytesLike = Union[bytes, bytearray, np.ndarray, torch.Tensor] + + +@torch.no_grad() +def encode_batch_rows( + idx_sorted: torch.Tensor, *, C: int, B_choices: Tuple[int, ...] = (64, 128) +) -> Tuple[BytesLike, Dict]: + """ + Compresses a 2D int64 tensor of Top-K indices into a byte string + using a per-row adaptive Rice/Bitmap compression scheme on the GPU. + + Layout: + + [global header] + 0..3 : "CGRP" (magic) + 4..7 : C (uint32 LE) + 8..9 : K (uint16 LE) + 10..13 : R (uint32 LE, num_rows) + 14 : num_B (uint8) + 15.. 
: B_choices (num_B * uint16 LE) + + [row table] (num_rows entries, 3 bytes each) + - uint16 length_bytes[r] (payload size in BYTES for row r) + - uint8 header[r] ((best_B_idx << 1) | use_bitmap) + + [payload region] + - concatenated bitstreams, one per row, each length_bytes[r] bytes, + byte-aligned, containing ONLY the Rice/bitmap-coded deltas + (no per-row length or header in-band). + """ + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for this function.") + + if not isinstance(idx_sorted, torch.Tensor) or idx_sorted.ndim != 2: + raise ValueError( + f"idx must be a 2D int64 tensor, got {idx_sorted.shape} {idx_sorted.dtype}" + ) + + if not all(isinstance(b, int) and (b & (b - 1) == 0) and b > 0 for b in B_choices): + raise ValueError(f"All B_choices must be powers of two, got {B_choices}") + + if not all(C % b == 0 for b in B_choices): + raise ValueError(f"All B_choices must evenly divide C={C}, got {B_choices}") + + num_rows, k_dim = idx_sorted.shape + if num_rows == 0: + return b"", { + "total_bits": 0, + "avg_bits_per_row": 0.0, + "B_hist": {b: 0 for b in B_choices}, + } + + if not idx_sorted.is_cuda: + idx_sorted = idx_sorted.cuda() + idx_sorted = idx_sorted.contiguous() + dev = idx_sorted.device + + vals = torch.cat( + (idx_sorted[:, :1], idx_sorted[:, 1:] - idx_sorted[:, :-1]), + dim=1, + ).to(torch.int32) + + # k_rice parameters (log2(C // B)) + k_rice_choices = tuple(int(math.log2(C // b)) for b in B_choices) + num_B_choices = len(B_choices) + k_rice_choices_tensor = torch.tensor(k_rice_choices, dtype=torch.int32, device=dev) + + # Row header bits + B_choice_bits = (num_B_choices - 1).bit_length() + ROW_HEADER_BITS = 1 + B_choice_bits + + # Output tensors for cost kernel + costs = torch.empty((num_rows, num_B_choices), dtype=torch.int32, device=dev) + is_bitmap = torch.empty((num_rows, num_B_choices), dtype=torch.int8, device=dev) + + # Calculate grid for cost kernel + grid = (num_rows,) + + cost_kernel[grid]( + vals, + costs, 
+ is_bitmap, + k_dim=k_dim, + num_rows=num_rows, + num_B_choices=num_B_choices, + k_rice_choices_ptr=k_rice_choices_tensor, + ) + + # Best choice per row + min_costs, best_B_idx = torch.min(costs, dim=1) + is_bitmap_choice = ( + torch.gather(is_bitmap, 1, best_B_idx.unsqueeze(1)).squeeze(1).to(torch.int32) + ) + + # Payload sizing + row_payload_bits = min_costs + row_payload_bytes = ((row_payload_bits + 7) // 8).to(torch.int32) + + if torch.any(row_payload_bytes > 0xFFFF): + raise ValueError( + "Row payload length exceeds 65535 bytes; cannot store in uint16." + ) + + # Byte offsets + if num_rows == 1: + row_byte_offsets = torch.zeros(1, dtype=torch.int32, device=dev) + else: + row_byte_offsets = torch.nn.functional.pad( + torch.cumsum(row_payload_bytes, dim=0, dtype=torch.int32)[:-1], (1, 0) + ) + total_payload_bytes = int(row_payload_bytes.sum().item()) + + # Global header Construction + header_list = [] + header_list.append(b"CGRP") # 4B magic + header_list.append(struct.pack("> 8) & 0xFF).to(torch.uint8) + row_table_flat[:, 2] = (headers_i32 & ((1 << ROW_HEADER_BITS) - 1)).to(torch.uint8) + + payload_buf[global_header_len_bytes : global_header_len_bytes + row_table_bytes] = ( + row_table_flat.view(-1) + ) + + # Calculate absolute byte offsets for pack kernel + row_abs_byte_offsets = (payload_region_start + row_byte_offsets).to(torch.int32) + + # Pack payloads + pack_kernel[(num_rows,)]( + vals, + payload_buf, + row_abs_byte_offsets, + best_B_idx.to(torch.int32), + is_bitmap_choice, + k_rice_choices_tensor, + num_rows, + k_dim=k_dim, + ) + + # Meta stats + b_counts = torch.bincount(best_B_idx, minlength=len(B_choices)) + B_hist = {b: c.item() for b, c in zip(B_choices, b_counts)} + total_row_bytes = total_payload_bytes + row_entry_bytes * num_rows + total_bits = int(total_row_bytes * 8) + + meta = { + "total_bits": total_bits, + "avg_bits_per_row": float(total_bits / num_rows), + "avg_payload_bits_per_row": float(row_payload_bits.float().mean().item()), + 
"B_hist": B_hist, + } + return payload_buf, meta + + +@triton.jit +def cost_kernel( + delta_ptr, + costs_ptr, + is_bitmap_ptr, + k_dim: tl.constexpr, + num_rows: tl.int32, + num_B_choices: tl.int32, + k_rice_choices_ptr, +): + """ + Calculates bit cost. One row per program instance. + """ + row_idx = tl.program_id(0) + if row_idx >= num_rows: + return + + i = tl.arange(0, k_dim) + row_base = row_idx * k_dim + delta = tl.load(delta_ptr + row_base + i) + delta0 = tl.load(delta_ptr + row_base) + + b_idx = 0 + while b_idx < num_B_choices: + k_rice = tl.load(k_rice_choices_ptr + b_idx) + + q = delta >> k_rice + q0 = delta0 >> k_rice + + rice_cost = tl.sum(q + 1) + k_dim * k_rice + + # Bitmap cost: head is Rice, tail is (1 + k_rice) + bitmap_cost = (q0 + 1 + k_rice) + (k_dim - 1) * (1 + k_rice) + + # Allow bitmap only if tail q are in {0,1} + q_tail_max = tl.max(tl.where(i > 0, q, 0)) + bitmap_allowed = q_tail_max <= 1 + + use_bitmap = (bitmap_cost < rice_cost) & bitmap_allowed + min_cost = tl.where(use_bitmap, bitmap_cost, rice_cost) + + out_offset = row_idx * num_B_choices + b_idx + tl.store(costs_ptr + out_offset, min_cost) + tl.store(is_bitmap_ptr + out_offset, tl.where(use_bitmap, 1, 0).to(tl.int32)) + b_idx += 1 + + +@triton.jit +def pack_kernel( + delta_ptr, # (rows, k_dim) IN int32 + u8_payload_ptr, # OUT uint8 + row_abs_byte_offsets_ptr, # (rows,) IN int32 (byte offset where payload starts) + best_B_idx_ptr, # (rows,) IN + is_bitmap_ptr, # (rows,) IN + k_rice_choices_ptr, # [num_B] IN + num_rows: tl.int32, + k_dim: tl.int32, # dynamic +): + """ + Writes payload bits using a 64-bit register accumulator. + Uses unaligned byte stores (split into 4 bytes) to prevent cudaErrorMisalignedAddress. 
+ """ + row_idx = tl.program_id(0) + if row_idx >= num_rows: + return + + # Load row params + out_byte_off = tl.load(row_abs_byte_offsets_ptr + row_idx).to(tl.int32) + b_idx_i32 = tl.load(best_B_idx_ptr + row_idx).to(tl.int32) + use_bitmap_i32 = (tl.load(is_bitmap_ptr + row_idx) & 1).to(tl.int32) + k_rice_i32 = tl.load(k_rice_choices_ptr + b_idx_i32).to(tl.int32) + M_i32 = tl.full((), 1, dtype=tl.int32) << k_rice_i32 + + # Accumulator state + acc_data = tl.full((), 0, dtype=tl.uint64) + acc_bits = tl.full((), 0, dtype=tl.int32) + + # Output pointer (byte-aligned) + out_ptr_base = u8_payload_ptr + out_byte_off + + base_idx = row_idx * k_dim + + i = 0 + while i < k_dim: + val = tl.load(delta_ptr + base_idx + i).to(tl.int32) + + # Compute q, r + q = (val >> k_rice_i32).to(tl.uint64) + r = (val & (M_i32 - 1)).to(tl.uint64) + + is_rice = (i == 0) | (use_bitmap_i32 == 0) + + if is_rice: + # Rice: q '1's, then '0', then k_rice bits of r + q_count = q.to(tl.int32) + while q_count > 0: + acc_data |= tl.full((), 1, dtype=tl.uint64) << acc_bits + acc_bits += 1 + q_count -= 1 + + # Flush Check + if acc_bits >= 32: + # Unaligned Store (4 bytes separately) + val_u32 = acc_data.to(tl.uint32) + tl.store(out_ptr_base + 0, (val_u32 & 0xFF).to(tl.uint8)) + tl.store(out_ptr_base + 1, ((val_u32 >> 8) & 0xFF).to(tl.uint8)) + tl.store(out_ptr_base + 2, ((val_u32 >> 16) & 0xFF).to(tl.uint8)) + tl.store(out_ptr_base + 3, ((val_u32 >> 24) & 0xFF).to(tl.uint8)) + + out_ptr_base += 4 + acc_data >>= 32 + acc_bits -= 32 + + # Append Separator '0' + acc_bits += 1 + else: + # Bitmap: q is 1 bit + q_bit = tl.where(q > 0, 1, 0).to(tl.uint64) + acc_data |= q_bit << acc_bits + acc_bits += 1 + + # Flush Check (after separator/bitmap bit) + if acc_bits >= 32: + val_u32 = acc_data.to(tl.uint32) + tl.store(out_ptr_base + 0, (val_u32 & 0xFF).to(tl.uint8)) + tl.store(out_ptr_base + 1, ((val_u32 >> 8) & 0xFF).to(tl.uint8)) + tl.store(out_ptr_base + 2, ((val_u32 >> 16) & 0xFF).to(tl.uint8)) + 
tl.store(out_ptr_base + 3, ((val_u32 >> 24) & 0xFF).to(tl.uint8)) + + out_ptr_base += 4 + acc_data >>= 32 + acc_bits -= 32 + + # Append Remainder + acc_data |= r << acc_bits + acc_bits += k_rice_i32 + + # Flush Check + if acc_bits >= 32: + val_u32 = acc_data.to(tl.uint32) + tl.store(out_ptr_base + 0, (val_u32 & 0xFF).to(tl.uint8)) + tl.store(out_ptr_base + 1, ((val_u32 >> 8) & 0xFF).to(tl.uint8)) + tl.store(out_ptr_base + 2, ((val_u32 >> 16) & 0xFF).to(tl.uint8)) + tl.store(out_ptr_base + 3, ((val_u32 >> 24) & 0xFF).to(tl.uint8)) + + out_ptr_base += 4 + acc_data >>= 32 + acc_bits -= 32 + + i += 1 + + # Final Flush + while acc_bits > 0: + tl.store(out_ptr_base, (acc_data & 0xFF).to(tl.uint8)) + out_ptr_base += 1 + acc_data >>= 8 + acc_bits -= 8 + + +@triton.jit +def parse_header_kernel( + u8_payload_ptr, # (total_bytes,) uint8 + C_out_ptr, # (1,) int32 + K_out_ptr, # (1,) int32 + R_out_ptr, # (1,) int32 + num_B_out_ptr, # (1,) int32 + B_choices_out_ptr, # (MAX_B_CHOICES,) int32 + header_bytes_out_ptr, # (1,) int32 + max_num_B: tl.constexpr, +): + """ + Simple GPU kernel to parse the global header. 
+ """ + pid = tl.program_id(0) + if pid != 0: + return + + # C (uint32 LE at bytes 4..7) + b4 = tl.load(u8_payload_ptr + 4).to(tl.int32) + b5 = tl.load(u8_payload_ptr + 5).to(tl.int32) + b6 = tl.load(u8_payload_ptr + 6).to(tl.int32) + b7 = tl.load(u8_payload_ptr + 7).to(tl.int32) + C_val = b4 | (b5 << 8) | (b6 << 16) | (b7 << 24) + tl.store(C_out_ptr, C_val) + + # K (uint16 LE at bytes 8..9) + b8 = tl.load(u8_payload_ptr + 8).to(tl.int32) + b9 = tl.load(u8_payload_ptr + 9).to(tl.int32) + K_val = b8 | (b9 << 8) + tl.store(K_out_ptr, K_val) + + # R (uint32 LE at bytes 10..13) + b10 = tl.load(u8_payload_ptr + 10).to(tl.int32) + b11 = tl.load(u8_payload_ptr + 11).to(tl.int32) + b12 = tl.load(u8_payload_ptr + 12).to(tl.int32) + b13 = tl.load(u8_payload_ptr + 13).to(tl.int32) + R_val = b10 | (b11 << 8) | (b12 << 16) | (b13 << 24) + tl.store(R_out_ptr, R_val) + + # num_B at byte 14 + num_B_val = tl.load(u8_payload_ptr + 14).to(tl.int32) + tl.store(num_B_out_ptr, num_B_val) + + # Read B_choices (start at 15) + off = 15 + i = 0 + while i < max_num_B: + if i < num_B_val: + lo = tl.load(u8_payload_ptr + off).to(tl.int32) + hi = tl.load(u8_payload_ptr + off + 1).to(tl.int32) + B_val = lo | (hi << 8) + tl.store(B_choices_out_ptr + i, B_val) + off += 2 + i += 1 + + tl.store(header_bytes_out_ptr, off) + + +@triton.jit +def parse_row_table_kernel( + u8_payload_ptr, + row_payload_bytes_ptr, + best_B_idx_ptr, + use_bitmap_ptr, + row_table_start_ptr, # int32* from header kernel + num_rows_ptr, # int32* from header kernel + ROW_HEADER_BITS: tl.constexpr, +): + pid = tl.program_id(0) + num_rows = tl.load(num_rows_ptr) + if pid >= num_rows: + return + + row_table_start = tl.load(row_table_start_ptr) + entry_offset = row_table_start + pid * 3 + + b0 = tl.load(u8_payload_ptr + entry_offset).to(tl.int32) + b1 = tl.load(u8_payload_ptr + entry_offset + 1).to(tl.int32) + length_i32 = b0 | (b1 << 8) + tl.store(row_payload_bytes_ptr + pid, length_i32) + + header_byte = tl.load(u8_payload_ptr + 
entry_offset + 2).to(tl.int32) + header_mask = (tl.full((), 1, dtype=tl.int32) << ROW_HEADER_BITS) - 1 + header_i32 = header_byte & header_mask + + use_bitmap_i32 = header_i32 & 1 + best_B_idx_i32 = header_i32 >> 1 + + tl.store(use_bitmap_ptr + pid, use_bitmap_i32) + tl.store(best_B_idx_ptr + pid, best_B_idx_i32) + + +@triton.jit +def decode_rows_kernel( + u8_payload_ptr, + out_vals_ptr, + row_byte_offsets_ptr, + best_B_idx_ptr, + use_bitmap_ptr, + k_rice_choices_ptr, + num_rows_ptr, + K_ptr, +): + """ + Decodes each row's Rice/bitmap bitstream into *final* prefix-summed values. + + Uses a simple streaming bit-buffer over bytes: + - bitbuf: uint64 (low bits are next bits in the stream) + - bits_in_buf: int (# of valid bits in bitbuf) + - byte_offset: int (index into u8_payload_ptr) + """ + row_idx = tl.program_id(0) + num_rows = tl.load(num_rows_ptr) + if row_idx >= num_rows: + return + + K = tl.load(K_ptr) + + # Row params + start_byte = tl.load(row_byte_offsets_ptr + row_idx).to(tl.int32) + b_idx = tl.load(best_B_idx_ptr + row_idx).to(tl.int32) + use_bitmap = (tl.load(use_bitmap_ptr + row_idx) & 1).to(tl.int32) + + k_rice = tl.load(k_rice_choices_ptr + b_idx).to(tl.int32) + M = tl.full((), 1, dtype=tl.int32) << k_rice + + # Streaming bit-buffer state + byte_offset = start_byte + bitbuf = tl.full((), 0, dtype=tl.uint64) + bits_in_buf = tl.full((), 0, dtype=tl.int32) + + base_out = row_idx * K + prev = tl.full((), 0, dtype=tl.int32) + + i = 0 + while i < K: + is_rice = (i == 0) | (use_bitmap == 0) + + # --- Decode q --- + q = tl.full((), 0, dtype=tl.int32) + + if is_rice: + # Unary code: q times '1' then a '0' + reading = 1 + while reading > 0: + # Ensure at least one bit in buffer + if bits_in_buf == 0: + next_byte = tl.load(u8_payload_ptr + byte_offset).to(tl.uint64) + byte_offset += 1 + bitbuf |= next_byte << bits_in_buf + bits_in_buf += 8 + + bit = (bitbuf & 1).to(tl.int32) + bitbuf >>= 1 + bits_in_buf -= 1 + + if bit == 1: + q += 1 + else: + reading = 0 + 
def decode_batch_rows(
    payload: BytesLike,
    max_num_B: int = 16,
) -> tuple[torch.Tensor, int, int]:
    """
    Decode a payload produced by encode_batch_rows.

    The payload is treated as untrusted input: every header / row-table
    field that is later used as a size, a table index or a byte offset is
    validated on the host before a kernel dereferences it.

    Args:
        payload: Encoded bytes. Accepts a bytes-like object, a numpy array
            (cast to uint8) or a ``torch.uint8`` tensor on any device.
        max_num_B: Upper bound on the number of B choices the global header
            may declare; also sizes the scratch buffer the header kernel
            writes the choices into.

    Returns:
        idx (torch.Tensor[int64]): (rows, K) sorted indices, on the CUDA device
        C (int): vocabulary size parameter
        R (int): number of rows

    Raises:
        RuntimeError: If CUDA is not available.
        ValueError: If the payload is malformed (wrong dtype, bad magic,
            inconsistent header, row table or row sizes).
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA required")

    # Move to GPU/Tensor
    if isinstance(payload, torch.Tensor):
        # The kernels index the buffer byte-wise; any other dtype would be
        # silently misinterpreted.
        if payload.dtype != torch.uint8:
            raise ValueError(f"Expected torch.uint8 payload, got {payload.dtype}")
        raw = payload if payload.is_cuda else payload.cuda()
    elif isinstance(payload, np.ndarray):
        raw = torch.from_numpy(payload).to("cuda", dtype=torch.uint8)
    else:
        arr = np.frombuffer(bytes(payload), dtype=np.uint8)
        raw = torch.from_numpy(arr).to("cuda", dtype=torch.uint8)

    # Flatten so slicing / numel below always refer to the byte stream
    raw = raw.contiguous().reshape(-1)
    dev = raw.device
    total_bytes = int(raw.numel())

    if total_bytes == 0:
        return torch.empty((0, 0), dtype=torch.int64, device=dev), 0, 0
    if total_bytes < 15:
        raise ValueError("Malformed payload - too few bytes")
    magic = bytes(raw[:4].cpu().tolist())
    if magic != b"CGRP":
        raise ValueError("Invalid magic header")

    def _zero_padded(extra: int) -> torch.Tensor:
        # Copy of the payload followed by `extra` zero bytes
        buf = torch.zeros(total_bytes + extra, dtype=torch.uint8, device=dev)
        buf[:total_bytes].copy_(raw)
        return buf

    # Pad payload on GPU with a few zero bytes to make safe over-reads trivial
    payload_gpu = _zero_padded(8)

    # --- 1) Parse Header (GPU Kernel) ---
    C_out = torch.empty(1, dtype=torch.int32, device=dev)
    K_out = torch.empty(1, dtype=torch.int32, device=dev)
    R_out = torch.empty(1, dtype=torch.int32, device=dev)
    num_B_out = torch.empty(1, dtype=torch.int32, device=dev)
    B_choices_out = torch.empty(max_num_B, dtype=torch.int32, device=dev)
    header_bytes_out = torch.empty(1, dtype=torch.int32, device=dev)

    parse_header_kernel[(1,)](
        payload_gpu,
        C_out,
        K_out,
        R_out,
        num_B_out,
        B_choices_out,
        header_bytes_out,
        max_num_B=max_num_B,
    )

    # Single sync to fetch every scalar needed for validation and setup
    C, K, num_rows, num_B, header_bytes = torch.cat(
        [C_out, K_out, R_out, num_B_out, header_bytes_out]
    ).tolist()

    if C <= 0 or K < 0 or num_rows < 0 or header_bytes <= 0:
        raise ValueError("Malformed payload - invalid header fields")
    if not 1 <= num_B <= max_num_B:
        raise ValueError(
            f"Malformed payload - num_B={num_B} outside [1, {max_num_B}]"
        )

    # Row table starts immediately after global header, 3 bytes per row
    payload_region_start = header_bytes + 3 * num_rows
    if payload_region_start > total_bytes:
        raise ValueError("Malformed payload - row table exceeds payload size")

    # The decoder consumes at least one bit per value (unary terminator or
    # bitmap bit), so an honest payload cannot describe more values than it
    # has bits. This also bounds the output allocation by the payload size.
    if num_rows * K > 8 * (total_bytes - payload_region_start):
        raise ValueError("Malformed payload - rows*K exceeds available bits")

    if num_rows == 0:
        # Nothing to decode; avoid launching kernels with an empty grid
        return torch.empty((0, K), dtype=torch.int64, device=dev), C, 0

    B_choices_list = B_choices_out[:num_B].tolist()

    # --- 2) Prepare k_rice ---
    k_rice_choices = []
    for B in B_choices_list:
        M = C // B if B > 0 else 0
        if M < 1:
            raise ValueError(f"Malformed payload - invalid B choice {B} for C={C}")
        # Exact integer floor(log2(M)); avoids float rounding of math.log2
        k_rice_choices.append(M.bit_length() - 1)
    k_rice_choices_tensor = torch.tensor(k_rice_choices, dtype=torch.int32, device=dev)

    # 1 bit for the bitmap flag plus enough bits to index the B choices
    ROW_HEADER_BITS = 1 + (num_B - 1).bit_length()

    # A truncated row makes the decode kernel read past the payload. Once in
    # the zero padding every value costs at most (1 + k_rice) bits, so this
    # many extra zero bytes keep all such over-reads inside the buffer.
    tail_bytes = (K * (1 + max(k_rice_choices)) + 7) // 8 + 8
    if tail_bytes > 8:
        payload_gpu = _zero_padded(tail_bytes)

    # --- 3) Parse Row Table (GPU) ---
    row_payload_bytes = torch.empty(num_rows, dtype=torch.int32, device=dev)
    best_B_idx = torch.empty(num_rows, dtype=torch.int32, device=dev)
    use_bitmap = torch.empty(num_rows, dtype=torch.int32, device=dev)

    parse_row_table_kernel[(num_rows,)](
        payload_gpu,
        row_payload_bytes,
        best_B_idx,
        use_bitmap,
        header_bytes_out,
        R_out,
        ROW_HEADER_BITS=ROW_HEADER_BITS,
    )

    # --- 4) Offsets (all on GPU) ---
    row_payload_bytes_64 = row_payload_bytes.to(torch.int64)
    row_byte_ends_rel = torch.cumsum(row_payload_bytes_64, dim=0)

    # Validate the row table with one sync: the B index must address an
    # existing k_rice entry and the rows must lie inside the payload.
    max_b_idx, min_row_bytes, data_bytes = torch.stack(
        [
            best_B_idx.max().to(torch.int64),
            row_payload_bytes_64.min(),
            row_byte_ends_rel[-1],
        ]
    ).tolist()
    if max_b_idx >= num_B:
        raise ValueError(
            f"Malformed payload - row B index {max_b_idx} >= num_B={num_B}"
        )
    if min_row_bytes < 0 or payload_region_start + data_bytes > total_bytes:
        raise ValueError("Malformed payload - row data exceeds payload size")

    # Exclusive prefix sum for offsets, shifted to absolute byte offsets
    row_byte_offsets = (
        payload_region_start + (row_byte_ends_rel - row_payload_bytes_64)
    ).to(torch.int32)

    # --- 5) Decode (GPU Optimized) ---
    out_vals = torch.empty((num_rows, K), dtype=torch.int32, device=dev)

    decode_rows_kernel[(num_rows,)](
        payload_gpu,
        out_vals,
        row_byte_offsets,
        best_B_idx,
        use_bitmap,
        k_rice_choices_tensor,
        R_out,
        K_out,
    )

    # No host-side cumsum here: kernel already returns prefix sums
    return out_vals.to(torch.int64), C, num_rows
def pack_12bit_indices(indices: torch.Tensor) -> torch.Tensor:
    """
    Pack integer indices into a 12-bit representation.

    Every 2 indices (24 bits) are packed into 3 uint8 values:
    byte0 holds the low 8 bits of the first index, byte1 holds the high
    4 bits of the first index (low nibble) and the low 4 bits of the second
    index (high nibble), byte2 holds the high 8 bits of the second index.

    Args:
        indices: Tensor with values in [0, 4095] (12-bit range); must have
            an even number of elements. Any shape; it is flattened.

    Returns:
        1-D uint8 tensor of length ``3 * indices.numel() // 2`` on the same
        device as ``indices``.

    Raises:
        ValueError: If an index is negative or >= 4096, or if the number of
            indices is odd.
    """
    indices_flat = indices.flatten()
    n_indices = indices_flat.numel()

    # Ensure indices fit in 12 bits (and are not negative, which the bit
    # masks below would otherwise silently turn into unrelated values)
    if n_indices > 0:
        max_idx = indices_flat.max().item()
        if max_idx >= 4096:
            raise ValueError(f"Index {max_idx} exceeds 12-bit limit (4095)")
        min_idx = indices_flat.min().item()
        if min_idx < 0:
            raise ValueError(
                f"Index {min_idx} is negative; indices must be in [0, 4095]"
            )

    # Ensure we have even number of indices
    if n_indices % 2 != 0:
        raise ValueError(f"Number of indices must be even, got {n_indices}")

    n_pairs = n_indices // 2
    packed = torch.zeros(n_pairs * 3, dtype=torch.uint8, device=indices.device)

    if n_pairs > 0:
        # int32 is wide enough for the shifts and masks below
        idx_pairs = indices_flat.to(torch.int32).reshape(-1, 2)
        idx1 = idx_pairs[:, 0]
        idx2 = idx_pairs[:, 1]

        # Pack pairs: idx1 uses byte0 + lower 4 bits of byte1
        #             idx2 uses upper 4 bits of byte1 + byte2
        packed[0::3] = (idx1 & 0xFF).to(torch.uint8)
        packed[1::3] = (((idx1 >> 8) & 0x0F) | ((idx2 & 0x0F) << 4)).to(torch.uint8)
        packed[2::3] = ((idx2 >> 4) & 0xFF).to(torch.uint8)

    return packed


def unpack_12bit_indices(
    packed: torch.Tensor, values_shape: tuple[int, ...]
) -> torch.Tensor:
    """
    Unpack 12-bit packed indices back to int64.

    Inverse of ``pack_12bit_indices``.

    Args:
        packed: Packed uint8 tensor (3 bytes per pair of indices).
        values_shape: Shape of the values tensor (same as the original
            indices shape); its element count must be even.

    Returns:
        Unpacked indices as int64 tensor with shape ``values_shape`` on the
        same device as ``packed``.

    Raises:
        ValueError: If the number of indices is odd, or if ``packed`` does
            not hold exactly 3 bytes per pair of indices.
    """
    # Exact integer element count (no float tensor round-trip)
    n_indices = torch.Size(values_shape).numel()

    if n_indices == 0:
        return torch.zeros(values_shape, dtype=torch.int64, device=packed.device)

    # Ensure even number of indices
    if n_indices % 2 != 0:
        raise ValueError(f"Number of indices must be even, got {n_indices}")

    n_pairs = n_indices // 2
    packed_flat = packed.flatten()

    # A short buffer would otherwise broadcast into garbage or fail with an
    # opaque shape error inside the slice assignments below.
    if packed_flat.numel() != n_pairs * 3:
        raise ValueError(
            f"Packed data has {packed_flat.numel()} bytes, expected "
            f"{n_pairs * 3} for {n_indices} indices"
        )

    byte0 = packed_flat[0::3].to(torch.int64)
    byte1 = packed_flat[1::3].to(torch.int64)
    byte2 = packed_flat[2::3].to(torch.int64)

    # Reconstruct indices
    indices = torch.zeros(n_indices, dtype=torch.int64, device=packed.device)
    indices[0::2] = byte0 | ((byte1 & 0x0F) << 8)  # idx1
    indices[1::2] = ((byte1 >> 4) & 0x0F) | (byte2 << 4)  # idx2

    # Reshape to match values shape
    return indices.reshape(values_shape)
{type(x)}") + + decoded_per_peer: list[torch.Tensor] = [] - # Unpack all 12-bit packed indices using values shape - unpacked_indices = [] for i in range(Ptot): - idx_data = idxs_all[i] if isinstance(idxs_all, list) else idxs_all - val_data = vals_all[i] if isinstance(vals_all, list) else vals_all + idx_data = idxs_all[i] if isinstance(idxs_all, (list, tuple)) else idxs_all + payload = _as_bytes(idx_data) - # 12-bit packed format - use values shape for unpacking - unpacked = unpack_12bit_indices( - idx_data.to(neuron.config.device), val_data.shape - ) - unpacked_indices.append(unpacked) + rows_i, _C_codec, N_rows = decode_batch_rows( + payload + ) # rows_i: list[list[int]] + if N_rows == 0: + # no rows for this param/peer → skip param entirely + decoded_per_peer = [] + break + + # ensure rectangular (constant k) + k0 = len(rows_i[0]) + if not all(len(r) == k0 for r in rows_i): + raise ValueError("Rice payload has variable k per row; unsupported.") + + decoded_per_peer.append(torch.tensor(rows_i, dtype=torch.int64)) + + if not decoded_per_peer: + continue - idxs_tensor = torch.stack(unpacked_indices, dim=0) - P, *chunk_dims, k = idxs_tensor.shape - C = int(torch.prod(torch.tensor(chunk_dims))) # num chunks - idxs_flat = idxs_tensor.reshape(P, C, k) + idxs_tensor = torch.stack(decoded_per_peer, dim=0) # [P, C, K] + P, C_chunks, k = idxs_tensor.shape + idxs_flat = idxs_tensor # already [P, C, K] - param_weight = C * k # size weight + param_weight = C_chunks * k # size weight for i in range(P): for j in range(i + 1, P): diff --git a/tests/test_comms.py b/tests/test_comms.py index 6aea4f417..388896d5e 100644 --- a/tests/test_comms.py +++ b/tests/test_comms.py @@ -13,7 +13,7 @@ from datetime import datetime, timedelta, timezone from tplr import load_hparams -from tplr.compress import pack_12bit_indices +from tplr.compression import encode_batch_rows, pack_12bit_indices hparams = load_hparams() @@ -48,7 +48,7 @@ def create_xshapes_totalks(model): def 
create_valid_state_dict(model): state_dict = {} for name, _ in model.named_parameters(): - # Create 12-bit packed format + # Create legacy 12-bit packed format indices = torch.tensor([0, 1], dtype=torch.long) packed_data = pack_12bit_indices(indices) state_dict[name + "idxs"] = packed_data @@ -65,9 +65,9 @@ def create_missing_idxs(model): def create_packed_indices(indices_list): - """Helper function to create 12-bit packed indices from a list""" + """Helper function to create legacy 12-bit packed indices from a list""" indices = torch.tensor(indices_list, dtype=torch.long) - # Ensure even number of indices for 12-bit packing + # Ensure even number of indices for legacy 12-bit packing if len(indices_list) % 2 != 0: indices = torch.cat([indices, torch.tensor([0], dtype=torch.long)]) packed_data = pack_12bit_indices(indices) @@ -244,7 +244,7 @@ async def test_gather_basic_functionality(comms_instance, dummy_compressor): data={ "0.weightidxs": create_packed_indices( [0, 1, 2, 3] - ), # Even count for 12-bit + ), # Even count for legacy 12-bit "0.weightvals": torch.tensor([0.4, 0.5, 0.6, 0.7]), "totalks": {"0.weight": totalk_value}, }, @@ -255,7 +255,7 @@ async def test_gather_basic_functionality(comms_instance, dummy_compressor): data={ "0.weightidxs": create_packed_indices( [0, 1, 2, 3] - ), # Even count for 12-bit + ), # Even count for legacy 12-bit "0.weightvals": torch.tensor([0.7, 0.8, 0.9, 1.0]), "totalks": {"0.weight": totalk_value}, }, @@ -315,7 +315,7 @@ async def test_gather_normalization(comms_instance, dummy_compressor): data={ "0.weightidxs": create_packed_indices( [0, 1, 2, 3] - ), # Even count for 12-bit + ), # Even count for legacy 12-bit "0.weightvals": torch.tensor([0.4, 0.5, 0.6, 0.7]), "totalks": {"0.weight": totalk_value}, }, @@ -360,7 +360,7 @@ async def test_gather_quant_params_validation(comms_instance, dummy_compressor): val_key = f"{param_base}vals" qp_key = f"{param_base}quant_params" - idxs = create_packed_indices([0, 1, 2, 3]) # Even count 
for 12-bit + idxs = create_packed_indices([0, 1, 2, 3]) # Even count for legacy 12-bit vals = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=torch.uint8) # still quantised lookup = torch.zeros(256, dtype=torch.float32) # dummy LUT @@ -480,7 +480,7 @@ async def test_gather_averaging(comms_instance, dummy_compressor): data={ "0.weightidxs": create_packed_indices( [0, 1, 2, 3] - ), # Even count for 12-bit + ), # Even count for legacy 12-bit "0.weightvals": torch.tensor([0.4, 0.5, 0.6, 0.7]), "totalks": {"0.weight": totalk_value}, }, @@ -491,7 +491,7 @@ async def test_gather_averaging(comms_instance, dummy_compressor): data={ "0.weightidxs": create_packed_indices( [0, 1, 2, 3] - ), # Even count for 12-bit + ), # Even count for legacy 12-bit "0.weightvals": torch.tensor([0.8, 0.9, 1.0, 1.1]), "totalks": {"0.weight": totalk_value}, }, @@ -649,7 +649,7 @@ async def test_gather_complex_normalization(comms_instance, dummy_compressor): # Include totalks in each peer response using the key "layer." (so that stripping "idxs"/"vals" returns the same base key). 
peer1_response = CommsGetResult( data={ - "layer.idxs": create_packed_indices([0, 1, 2, 3]), # Even count for 12-bit + "layer.idxs": create_packed_indices([0, 1, 2, 3]), # Even count for legacy 12-bit "layer.vals": torch.tensor([1.0, 2.0, 2.0, 3.0]), # norm ≈ 3 "totalks": {"layer.": totalk_value}, }, @@ -658,7 +658,7 @@ async def test_gather_complex_normalization(comms_instance, dummy_compressor): ) peer2_response = CommsGetResult( data={ - "layer.idxs": create_packed_indices([0, 1, 2, 3]), # Even count for 12-bit + "layer.idxs": create_packed_indices([0, 1, 2, 3]), # Even count for legacy 12-bit "layer.vals": torch.tensor([10.0, 20.0, 20.0, 30.0]), # Larger scale "totalks": {"layer.": totalk_value}, }, @@ -667,7 +667,7 @@ async def test_gather_complex_normalization(comms_instance, dummy_compressor): ) peer3_response = CommsGetResult( data={ - "layer.idxs": create_packed_indices([0, 1, 2, 3]), # Even count for 12-bit + "layer.idxs": create_packed_indices([0, 1, 2, 3]), # Even count for legacy 12-bit "layer.vals": torch.tensor([-5.0, 5.0, 5.0, 10.0]), # Different sign "totalks": {"layer.": totalk_value}, }, @@ -1464,37 +1464,41 @@ def __init__(self): ) -def test_valid_12bit_packed_indices(): +def test_valid_rice_bitmap_encoded_indices(): """ - Test Case: test_valid_12bit_packed_indices - - Input: 12-bit packed indices with correct topk dimension + Test Case: test_valid_rice_bitmap_encoded_indices + - Input: Rice/bitmap encoded indices with correct topk dimension - Valid indices (all indices within [0, totalk-1]) - Expected Outcome: The function should complete without raising an error. """ dummy_comms = DummyComms() - # totalk is set to 10; allowed_topk is min(4, 10) == 4. - totalk = 10 - valid_indices = torch.tensor([1, 5, 9, 3], dtype=torch.long) - packed_data = pack_12bit_indices(valid_indices) - vals = torch.randn_like(valid_indices, dtype=torch.float32) + # totalk is set to 64; allowed_topk is min(4, 64) == 4. 
+ totalk = 64 + valid_indices = torch.tensor( + [[1, 5, 9, 3]], dtype=torch.long + ) # Shape [1, 4] for one row + # Use the new encoder format + idx, _ = encode_batch_rows(valid_indices, C=totalk) + vals = torch.randn(1, 4, dtype=torch.float32) # Match the shape [rows, k] # This call should complete without any error. - dummy_comms.check_compressed_indices("test_param", packed_data, totalk, vals=vals) + dummy_comms.check_compressed_indices("test_param", idx, totalk, vals=vals) -def test_valid_12bit_packed_multi_dim(): +def test_valid_rice_bitmap_encoded_multi_dim(): """ - Test 12-bit packed indices from multi-dimensional tensor where the last dimension + Test Rice/bitmap encoded indices from multi-dimensional tensor where the last dimension equals min(hparams.topk_compression, totalk) and all indices are within valid range. """ dummy_comms = DummyComms() - totalk = 20 # allowed_topk = min(4, 20) = 4 + totalk = 128 # allowed_topk = min(4, 128) = 4 # Create a valid 2D tensor (shape: 2 x 4) with valid indices. valid_indices = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=torch.long) - packed_data = pack_12bit_indices(valid_indices) - vals = torch.randn_like(valid_indices, dtype=torch.float32) - dummy_comms.check_compressed_indices("param", packed_data, totalk, vals=vals) + # Use the new encoder format + idx, _ = encode_batch_rows(valid_indices, C=totalk) + vals = torch.randn(2, 4, dtype=torch.float32) # Match the shape + dummy_comms.check_compressed_indices("param", idx, totalk, vals=vals) def test_invalid_not_packed_format(): @@ -1502,18 +1506,18 @@ def test_invalid_not_packed_format(): Test that non-packed formats (like regular tensors or lists) are rejected. 
""" dummy_comms = DummyComms() - totalk = 20 + totalk = 128 # Test with regular tensor (not packed) - should fail because it's not uint8 invalid_tensor = torch.tensor([0, 1, 2, 3], dtype=torch.long) vals = torch.randn(4, dtype=torch.float32) - # This should fail since only uint8 12-bit packed format is supported - with pytest.raises(ValueError, match="Expected uint8 for 12-bit packed indices"): + # This should fail since only uint8 Rice/bitmap encoded format is supported + with pytest.raises(ValueError, match="Expected uint8.*Rice/bitmap payload"): dummy_comms.check_compressed_indices("param", invalid_tensor, totalk, vals=vals) # Test with list (not a tensor) invalid_list = torch.tensor([0, 1, 2, 3]) - with pytest.raises(ValueError, match="Expected uint8 for 12-bit packed indices"): + with pytest.raises(ValueError, match="Expected uint8.*Rice/bitmap payload"): dummy_comms.check_compressed_indices("param", invalid_list, totalk, vals=vals) @@ -1522,125 +1526,132 @@ def test_invalid_wrong_dtype(): Test that packed data with wrong dtype is handled correctly. """ dummy_comms = DummyComms() - totalk = 20 + totalk = 128 # int32 tensor is not uint8, so it should fail fake_packed = torch.tensor([0, 1, 2, 3], dtype=torch.int32) vals = torch.randn(4, dtype=torch.float32) # Should fail since only uint8 format is supported - with pytest.raises(ValueError, match="Expected uint8 for 12-bit packed indices"): + with pytest.raises(ValueError, match="Expected uint8.*Rice/bitmap payload"): dummy_comms.check_compressed_indices("param", fake_packed, totalk, vals=vals) -def test_invalid_12bit_packed_wrong_topk(): +def test_invalid_rice_bitmap_wrong_topk(): """ - Test that 12-bit packed indices with wrong topk dimension raises ValueError. + Test that Rice/bitmap encoded indices with wrong topk dimension raises ValueError. 
""" dummy_comms = DummyComms() - totalk = 10 # allowed_topk = min(4, 10) = 4 + totalk = 64 # allowed_topk = min(4, 64) = 4 # Create packed indices with wrong topk (2 instead of 4) - invalid_indices = torch.tensor([0, 1], dtype=torch.long) - packed_data = pack_12bit_indices(invalid_indices) - vals = torch.randn(2, dtype=torch.float32) # Wrong shape - should be 4 - with pytest.raises(ValueError, match="Invalid topk dimension"): + invalid_indices = torch.tensor([[0, 1]], dtype=torch.long) + payload, _ = encode_batch_rows(invalid_indices, C=totalk) + packed_data = torch.tensor( + np.frombuffer(payload, dtype=np.uint8), dtype=torch.uint8 + ) + vals = torch.randn(1, 2, dtype=torch.float32) # Wrong shape - should be 4 + with pytest.raises(ValueError, match="Values top.*k=2 but allowed_topk=4"): dummy_comms.check_compressed_indices("param", packed_data, totalk, vals=vals) -def test_invalid_12bit_packed_multi_dim_wrong_topk(): +def test_invalid_rice_bitmap_multi_dim_wrong_topk(): """ - Test that 12-bit packed indices from multi-dimensional tensor with wrong last dimension + Test that Rice/bitmap encoded indices from multi-dimensional tensor with wrong last dimension raises ValueError indicating invalid topk dimension. 
""" dummy_comms = DummyComms() - totalk = 20 # allowed_topk = min(4, 20) = 4 + totalk = 128 # allowed_topk = min(4, 128) = 4 # Create a 2D tensor with last dimension size 6 (should be 4) invalid_indices = torch.tensor( [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]], dtype=torch.long ) - packed_data = pack_12bit_indices(invalid_indices) + idx, _ = encode_batch_rows(invalid_indices, C=totalk) vals = torch.randn(2, 6, dtype=torch.float32) # Wrong shape - should be (2, 4) - with pytest.raises(ValueError, match="Invalid topk dimension"): - dummy_comms.check_compressed_indices("param", packed_data, totalk, vals=vals) + with pytest.raises(ValueError, match="Values top.*k=6 but allowed_topk=4"): + dummy_comms.check_compressed_indices("param", idx, totalk, vals=vals) -# Removed test_invalid_12bit_packed_negative_index as pack_12bit_indices validates input +# Removed test_invalid_rice_bitmap_negative_index as encoder validates input -def test_invalid_12bit_packed_out_of_bounds(): +def test_invalid_rice_bitmap_out_of_bounds(): """ - Test that 12-bit packed indices with out-of-bounds values raise ValueError. + Test that Rice/bitmap encoded indices with out-of-bounds values raise ValueError. """ dummy_comms = DummyComms() - totalk = 10 # allowed_topk = min(4, 10) = 4 - # Index 10 is out-of-range because valid indices are 0 to 9. - invalid_indices = torch.tensor([0, 1, 10, 3], dtype=torch.long) - packed_data = pack_12bit_indices(invalid_indices) - vals = torch.randn(4, dtype=torch.float32) - with pytest.raises(ValueError, match="Index 10 out of bounds"): - dummy_comms.check_compressed_indices("param", packed_data, totalk, vals=vals) + totalk = 64 # allowed_topk = min(4, 64) = 4 + # Index 64 is out-of-range because valid indices are 0 to 63. 
def test_invalid_rice_bitmap_out_of_bounds():
    """
    Test that a Rice/bitmap payload encoded for a different column size than
    the expected totalk raises ValueError.
    """
    dummy_comms = DummyComms()
    totalk = 64  # allowed_topk = min(4, 64) = 4
    # An out-of-range index cannot be encoded directly against C=64 (the
    # original note here says the encoder fails first), so encode valid
    # indices against a larger C and validate against the smaller totalk.
    invalid_indices = torch.tensor([[0, 1, 9, 3]], dtype=torch.long)
    # NOTE: B_choices must be a sequence; "(64)" is just the int 64, so the
    # trailing comma is required to build a one-element tuple.
    idx, _ = encode_batch_rows(invalid_indices, C=128, B_choices=(64,))
    vals = torch.randn(1, 4, dtype=torch.float32)
    # The payload header says C=128 while the checker expects totalk=64
    with pytest.raises(ValueError, match="Payload column size C=128 but expected 64"):
        dummy_comms.check_compressed_indices("param", idx, totalk, vals=vals)


# Removed test_invalid_flat_list_wrong_length - covered by test_invalid_not_packed_format


# Removed test_valid_single_value - not applicable to Rice/bitmap encoded format


# Removed test_invalid_single_value_out_of_bounds - not applicable to Rice/bitmap encoded format


def test_override_allowed_topk_rice_bitmap():
    """
    Test using the optional allowed_topk parameter with Rice/bitmap encoded format.
    """
    dummy_comms = DummyComms()
    totalk = 64
    # One-element tuple (trailing comma); "(64)" would be the int 64
    B_choices = (64,)

    # Override allowed_topk to 2.
    valid_indices = torch.tensor(
        [[0, 9]], dtype=torch.long
    )  # Correct length: 2 elements.
    idx, _ = encode_batch_rows(valid_indices, C=totalk, B_choices=B_choices)
    vals = torch.randn(1, 2, dtype=torch.float32)
    dummy_comms.check_compressed_indices(
        "param", idx, totalk, allowed_topk=2, vals=vals
    )

    # Test with wrong topk
    invalid_indices = torch.tensor(
        [[0, 1, 2, 3]], dtype=torch.long
    )  # 4 elements instead of 2.
    idx, _ = encode_batch_rows(invalid_indices, C=totalk, B_choices=B_choices)
    vals = torch.randn(1, 4, dtype=torch.float32)  # Wrong shape for allowed_topk=2
    with pytest.raises(ValueError, match="Values top.*k=4 but allowed_topk=2"):
        dummy_comms.check_compressed_indices(
            "param", idx, totalk, allowed_topk=2, vals=vals
        )
+ idx, _ = encode_batch_rows(valid_indices, C=totalk) + vals = torch.randn(1, 4, dtype=torch.float32) + dummy_comms.check_compressed_indices("param", idx, totalk, vals=vals) + + # Note: Can't test with 1 element as encoder requires even number of indices + # Test with 6 elements (wrong topk) invalid_indices = torch.tensor( - [0, 1, 0, 1], dtype=torch.long - ) # 4 elements instead of 2. - packed_data = pack_12bit_indices(invalid_indices) - vals = torch.randn(4, dtype=torch.float32) # Wrong shape for allowed_topk=2 - with pytest.raises(ValueError, match="Invalid topk dimension"): - dummy_comms.check_compressed_indices("param", packed_data, totalk, vals=vals) + [[0, 1, 2, 3, 4, 5]], dtype=torch.long + ) # 6 elements instead of 4. + idx, _ = encode_batch_rows(invalid_indices, C=totalk) + vals = torch.randn(1, 6, dtype=torch.float32) # Wrong shape for allowed_topk=4 + with pytest.raises(ValueError, match="Values top.*k=6 but allowed_topk=4"): + dummy_comms.check_compressed_indices("param", idx, totalk, vals=vals) # Tests for `weighted_random_sample_no_replacement` diff --git a/tests/unit/test_compress.py b/tests/unit/test_compress.py index 2eaab4fcd..25d12f59f 100644 --- a/tests/unit/test_compress.py +++ b/tests/unit/test_compress.py @@ -1,5 +1,6 @@ from typing import Literal +import numpy as np import pytest import torch import torch.nn as nn @@ -10,9 +11,8 @@ _dct, _get_smaller_split, _idct, - pack_12bit_indices, - unpack_12bit_indices, ) +from tplr.compression import encode_batch_rows, pack_12bit_indices, unpack_12bit_indices class TestTopKCompressor: @@ -30,19 +30,19 @@ def compress_instance_quantized(self) -> TopKCompressor[Literal[True]]: use_quantization=True, quantization_bins=256, quantization_range=6 ) - def test_compress_produces_int16_indices( + def test_compress_produces_rice_bitmap_indices( self, compress_instance: TopKCompressor[Literal[False]] ): - """Test that compress() produces 12-bit packed indices""" + """Test that compress() produces Rice/bitmap 
encoded indices""" # Create test tensor - x = torch.randn(10, 10) - topk = 10 + x = torch.randn(8, 128) # 1024 elements total, last dim=64 + topk = 16 # Compress using actual method idx, val, xshape, totalk = compress_instance.compress(x, topk) - # Verify index format - should be uint8 tensor for 12-bit packed - assert idx.dtype == torch.uint8, f"Expected uint8 packed data, got {idx.dtype}" + # Verify index format - should be uint8 tensor for Rice/bitmap codec + assert idx.dtype == torch.uint8, f"Expected uint8 encoded data, got {idx.dtype}" assert val.shape[-1] == topk assert xshape == x.shape # totalk is the size of the last dimension after rearranging @@ -53,8 +53,8 @@ def test_compress_with_quantization( self, compress_instance_quantized: TopKCompressor[Literal[True]] ): """Test compression with quantization enabled""" - x = torch.randn(10, 10) - topk = 20 + x = torch.randn(8, 128) # 1024 elements total, last dim=64 + topk = 32 # Compress with quantization result = compress_instance_quantized.compress(x, topk) @@ -63,53 +63,52 @@ def test_compress_with_quantization( assert len(result) == 5 idx, val, _, _, qparams = result - # idx should be uint8 tensor for 12-bit packed format + # idx should be uint8 tensor for Rice/bitmap encoded format assert idx.dtype == torch.uint8 assert val.dtype == torch.uint8 # Quantized values assert qparams is not None assert len(qparams) == 5 # shift, scale, offset, lookup, orig_dtype - def test_decompress_with_12bit_tuple_format( + def test_decompress_with_rice_bitmap_format( self, compress_instance: TopKCompressor[Literal[False]] ): - """Test that decompress can handle 12-bit packed tuple format""" + """Test that decompress can handle Rice/bitmap encoded format""" # Setup - p = torch.zeros(10, 10) - xshape = (10, 10) - totalk = 100 + p = torch.zeros(8, 128) # 1024 elements total, last dim=128 + xshape = (8, 128) + totalk = 128 - # Create proper 12-bit packed format using the actual packing function - # Create indices that are 
within valid range for a 10x10 tensor (even count) + # Create proper Rice/bitmap encoded format using the encoder + # Create indices that are within valid range for a 8x64 tensor (even count) original_indices = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=torch.int64) - # Pack using the actual function - idx = pack_12bit_indices(original_indices) - + # Pack using the new encoder format + idx_bytes, _ = encode_batch_rows(original_indices, C=totalk) val = torch.tensor( [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]], dtype=torch.float32 ) # Test decompression with packed format - result = compress_instance.decompress(p, idx, val, xshape, totalk) + result = compress_instance.decompress(p, idx_bytes, val, xshape, totalk) assert result.shape == xshape assert result.dtype == p.dtype - def test_batch_decompress_multiple_12bit_formats( + def test_batch_decompress_multiple_rice_bitmap_formats( self, compress_instance: TopKCompressor[Literal[False]] ): - """Test batch_decompress with multiple 12-bit packed indices""" + """Test batch_decompress with multiple Rice/bitmap encoded indices""" # Setup - p = torch.zeros(10, 10) - xshape = (10, 10) - totalk = 100 + p = torch.zeros(8, 128) # 1024 elements total, last dim=128 + xshape = (8, 128) + totalk = 128 - # Create multiple 12-bit packed indices + # Create multiple Rice/bitmap encoded indices idx1_orig = torch.tensor([[0, 1], [2, 3]], dtype=torch.int64) idx2_orig = torch.tensor([[4, 5], [6, 7]], dtype=torch.int64) - # Pack them using the 12-bit format - idx1_packed = pack_12bit_indices(idx1_orig) - idx2_packed = pack_12bit_indices(idx2_orig) + # Pack them using the new encoder format + idx1_packed, _ = encode_batch_rows(idx1_orig, C=totalk) + idx2_packed, _ = encode_batch_rows(idx2_orig, C=totalk) idx_list = [idx1_packed, idx2_packed] @@ -128,7 +127,7 @@ def test_compress_decompress_round_trip( self, compress_instance: TopKCompressor[Literal[False]] ): """Test full compress-decompress round trip""" - x = torch.zeros(10, 10) + x 
= torch.zeros(8, 128) # 1024 elements total, last dim=128 x[0, 0] = 1.0 x[1, 1] = 2.0 x[2, 2] = 3.0 @@ -139,7 +138,9 @@ def test_compress_decompress_round_trip( idx, val, xshape, totalk = compress_instance.compress(x, topk) # Verify we got the top-k values - assert idx.dtype == torch.uint8, "Expected uint8 for 12-bit packed indices" + assert idx.dtype == torch.uint8, ( + "Expected uint8 for Rice/bitmap encoded indices" + ) assert val.shape[-1] == topk # Decompress @@ -157,44 +158,99 @@ def test_compress_decompress_round_trip( expected_vals = torch.tensor([4.0, 3.0, 2.0, 1.0]) assert torch.allclose(top_vals, expected_vals, atol=1e-5) - def test_12bit_index_value_range( + def test_encode_compress_decompress_round_trip( self, compress_instance: TopKCompressor[Literal[False]] ): - """Test that indices can represent values appropriate for 12-bit range""" + class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self.layer1 = nn.Linear(32, 64) + self.layer2 = nn.Linear(64, 256) + + target_chunk = 16 + transform = ChunkingTransformer(SimpleModel(), target_chunk) + x = torch.zeros(256, 64) + x[0, 0] = 1.0 + x[1, 1] = 2.0 + x[2, 2] = 3.0 + x[3, 3] = 4.0 + + topk = 4 + encoded = transform.encode(x) + idxs, vals, xshape, totalk = compress_instance.compress(encoded, topk) + + p = torch.zeros_like(x) + decompressed = compress_instance.decompress(p, idxs, vals, xshape, totalk) + assert torch.allclose(encoded, decompressed, atol=1e-5) + + def test_rice_bitmap_index_value_range( + self, compress_instance: TopKCompressor[Literal[False]] + ): + """Test that Rice/bitmap codec can handle large index ranges efficiently""" # Create a large tensor that would have indices beyond 8-bit range - x = torch.randn(100, 100) # 10,000 elements - topk = 100 + x = torch.randn(128, 128) # 16,384 elements + topk = 64 # Compress idx, val, _, totalk = compress_instance.compress(x, topk) - # Check that indices are 12-bit packed format - assert idx.dtype == torch.uint8, "Expected uint8 
for 12-bit packed indices" + # Check that indices are in the new codec format (uint8 bytes) + assert idx.dtype == torch.uint8, "Expected uint8 for Rice/bitmap codec" - # Since idx is packed, we can't directly check max values - # Instead verify the packing worked correctly - # Use val.shape since it has the same shape as the original indices - unpacked = unpack_12bit_indices(idx, val.shape) + # Since idx is a byte stream payload, we can't directly check max values + # Instead verify round-trip works correctly + p = torch.zeros_like(x) + result = compress_instance.decompress(p, idx, val, x.shape, totalk) - # Verify some indices might be larger than 255 (8-bit max) - max_idx = unpacked.max().item() - assert max_idx < 10000, f"Index {max_idx} exceeds tensor size" + # Check that decompression succeeded + assert result.shape == x.shape - # If tensor is large enough, we should have indices > 255 - if totalk > 256: - assert unpacked.max() > 255, ( - "Large tensor should have indices beyond 8-bit range" - ) + # For a 2D tensor, totalk is the size of the last dimension + assert totalk == 128, ( + f"Expected totalk=128 for 128x128 tensor (last dim), got {totalk}" + ) def test_batch_decompress_with_norm_options( self, compress_instance: TopKCompressor[Literal[False]] ): """Test batch_decompress with normalisation and clip_norm options""" - p = torch.zeros(10, 10) - xshape = (10, 10) - totalk = 100 + p = torch.zeros(8, 128) # 1024 elements total, last dim=128 + xshape = (8, 128) + totalk = 128 + + # Create test data with Rice/bitmap encoded format + idx_orig = torch.tensor([[0, 1, 2, 3]], dtype=torch.int64) # Even count + idx_packed, _ = encode_batch_rows(idx_orig, C=totalk) + idx = [idx_packed] + val = [torch.tensor([[10.0, 20.0, 30.0, 40.0]], dtype=torch.float32)] + + # Test with normalisation + result_norm = compress_instance.batch_decompress( + p, idx, val, xshape, totalk, normalise=True + ) + assert result_norm.shape == xshape + + # Test with clip_norm + result_clip = 
compress_instance.batch_decompress( + p, idx, val, xshape, totalk, clip_norm=True + ) + assert result_clip.shape == xshape + + # Test with block_norms provided + block_norms = torch.tensor([15.0]) + result_with_norms = compress_instance.batch_decompress( + p, idx, val, xshape, totalk, block_norms=block_norms, clip_norm=True + ) + assert result_with_norms.shape == xshape + + def test_batch_decompress_with_legacy_packing( + self, compress_instance: TopKCompressor[Literal[False]] + ): + p = torch.zeros(8, 128) # 1024 elements total, last dim=128 + xshape = (8, 128) + totalk = 128 - # Create test data with 12-bit packed format + # Create test data with legacy 12-bit packed format idx_orig = torch.tensor([[0, 1, 2, 3]], dtype=torch.int64) # Even count idx_packed = pack_12bit_indices(idx_orig) idx = [idx_packed]