diff --git a/hparams/hparams.json b/hparams/hparams.json index 97cfbec13..3e30f897b 100644 --- a/hparams/hparams.json +++ b/hparams/hparams.json @@ -10,9 +10,8 @@ "blocks_per_window": 65, "windows_per_weights": 5, "momentum_decay": 0.95, - "topk_compression": 32, + "topk_compression": 128, "target_chunk": 64, - "use_dct": false, "binary_score_ma_alpha": 0.05, "moving_average_window": 5, "model_size": "70B", diff --git a/neurons/evaluator.py b/neurons/evaluator.py index 5019b109c..b69e4a75c 100644 --- a/neurons/evaluator.py +++ b/neurons/evaluator.py @@ -60,6 +60,7 @@ import bittensor as bt import torch import torch.distributed as dist +from lm_eval import simple_evaluate from torch.cuda import device_count as _cuda_device_count from torch.utils.data import DataLoader from torchtitan.components.loss import cross_entropy_loss diff --git a/neurons/miner.py b/neurons/miner.py index e74bd5c35..eeab499c7 100644 --- a/neurons/miner.py +++ b/neurons/miner.py @@ -258,8 +258,7 @@ def __init__(self): ) enc = self.transformer.encode( - torch.empty(p.shape, dtype=torch.float16, device=self.device), - use_dct=self.hparams.use_dct, + torch.empty(p.shape, dtype=torch.float16, device=self.device) ) _, _, xshape, totalk, _ = self.compressor.compress( enc, diff --git a/neurons/trainer.py b/neurons/trainer.py index 615eb6f7d..103d0bdb1 100644 --- a/neurons/trainer.py +++ b/neurons/trainer.py @@ -676,7 +676,6 @@ def outer_step(self, gather_result): device=str(self.device), is_master=self.is_master, world_size=self.world_size, - use_dct=self.hparams.use_dct, ) return diff --git a/neurons/validator.py b/neurons/validator.py index 3f925f687..f11288b6c 100644 --- a/neurons/validator.py +++ b/neurons/validator.py @@ -259,18 +259,41 @@ def __init__(self): self.xshapes = {} self.totalks = {} # Use bare_model like the miner does to ensure consistent parameter iteration + import time + total_compress_time = 0.0 + total_encode_time = 0.0 + + # Enable debug timing in compressor + 
self.compressor._debug_timing = True + for n, p in self.model.named_parameters(): # Use the same approach as miner for creating xshapes and totalks + encode_start = time.time() enc = self.transformer.encode( - torch.empty(p.shape, dtype=torch.float16, device=self.device), - use_dct=self.hparams.use_dct, + torch.empty(p.shape, dtype=torch.float16, device=self.device) ) + encode_time = time.time() - encode_start + + compress_start = time.time() _, _, xshape, totalk, _ = self.compressor.compress( enc, self.hparams.topk_compression, ) + compress_time = time.time() - compress_start + self.xshapes[n] = xshape self.totalks[n] = totalk + + total_encode_time += encode_time + total_compress_time += compress_time + + # Log timing for each layer + tplr.logger.info(f"[COMPRESS TIMING] {n}: encode={encode_time:.3f}s, compress={compress_time:.3f}s, shape={p.shape}") + + tplr.logger.info(f"[COMPRESS TIMING TOTAL] encode={total_encode_time:.3f}s, compress={total_compress_time:.3f}s") + + # Disable debug timing after initialization + self.compressor._debug_timing = False self.openskill_model = PlackettLuce( beta=self.hparams.openskill_beta, tau=self.hparams.openskill_tau @@ -1698,7 +1721,6 @@ async def run(self): device=cast(str, self.device), is_master=self.is_master, world_size=self.world_size, - use_dct=self.hparams.use_dct, wandb_run=self.wandb if self.is_master else None, global_step=self.global_step, ) @@ -2804,9 +2826,7 @@ def update_model_with_gradient( quant_params, ) - full_grad_src = self.transformer.decode( - decompressed, use_dct=self.hparams.use_dct - ) + full_grad_src = self.transformer.decode(decompressed) # Single conversion to target dtype+device to avoid extra temporaries full_grad_src = full_grad_src.to( dtype=p.dtype, device=p.device, non_blocking=True diff --git a/src/tplr/comms.py b/src/tplr/comms.py index 19a862439..da839e9f0 100644 --- a/src/tplr/comms.py +++ b/src/tplr/comms.py @@ -34,6 +34,7 @@ import bittensor as bt import boto3 import botocore +import 
numpy as np import torch import torch.distributed as dist from aiobotocore.client import AioBaseClient @@ -48,7 +49,7 @@ import tplr from tplr.chain import ChainManager -from tplr.compress import TopKCompressor, unpack_12bit_indices +from tplr.compress import TopKCompressor, decode_batch_rows from tplr.config import BUCKET_SECRETS, client_config from tplr.schemas import Bucket, CommsGetResult @@ -2622,10 +2623,8 @@ def check_compressed_indices( """ Validates the integrity and format of compressed gradient indices. - This is a crucial security and stability check to ensure that gradients - received from peers are well-formed. It verifies that indices are within - the expected bounds and that the compression format (e.g., 12-bit packing) - is correctly applied. + This ensures indices are within bounds and that the **new Rice/bitmap** + codec payload matches the provided values tensor shape (top‑k). Args: param_name (str): The name of the parameter being checked. @@ -2633,12 +2632,11 @@ def check_compressed_indices( totalk (int): The total number of elements in the original uncompressed tensor. allowed_topk (int | None, optional): The expected number of top-k values. Defaults to the hparams configuration. - vals (torch.Tensor | None, optional): The corresponding values tensor, - required for validating 12-bit packed indices. Defaults to None. + vals (torch.Tensor | None, optional): The corresponding values tensor. Raises: ValueError: If any validation check fails, such as out-of-bounds - indices, incorrect data types, or malformed packed data. + indices, incorrect data types, or malformed payload. 
""" allowed_topk = ( min(self.hparams.topk_compression, totalk) @@ -2646,44 +2644,61 @@ def check_compressed_indices( else min(allowed_topk, totalk) ) - def _bounds_check(t: torch.Tensor): - """fast min/max bounds check""" - if t.numel() == 0: - raise ValueError(f"[{param_name}] empty index list") - if t.min().item() < 0 or t.max().item() >= totalk: - bad = t[(t < 0) | (t >= totalk)][0].item() - raise ValueError( - f"[{param_name}] Index {bad} out of bounds (totalk = {totalk})" - ) + if not isinstance(idxs, torch.Tensor): + raise ValueError( + f"[{param_name}] Expected tensor for indices, got {type(idxs)}" + ) + if vals is None: + raise ValueError( + f"[{param_name}] Values tensor required for index validation" + ) + if idxs.dtype != torch.uint8: + raise ValueError( + f"[{param_name}] Expected uint8 (Rice/bitmap payload), got {idxs.dtype}" + ) + if idxs.numel() == 0: + raise ValueError(f"[{param_name}] Empty indices payload") - # Handle 12-bit packed index format only - if isinstance(idxs, torch.Tensor): - if idxs.dtype != torch.uint8: - raise ValueError( - f"[{param_name}] Expected uint8 for 12-bit packed indices, got {idxs.dtype}" - ) - # 12-bit packed format is the only supported format - if vals is None: - raise ValueError( - f"[{param_name}] Values tensor required to validate 12-bit packed indices" - ) - if idxs.numel() == 0: - raise ValueError(f"[{param_name}] Empty packed indices tensor") + # Decode (CPU) and perform structural checks + try: + payload_bytes = idxs.detach().cpu().numpy().tobytes() + rows_list, C, N = decode_batch_rows(payload_bytes) + except Exception as e: + raise ValueError(f"[{param_name}] Failed to decode indices payload: {e}") - # Unpack using the values shape - try: - unpacked = unpack_12bit_indices(idxs, vals.shape) - # Validate that the last dimension matches allowed_topk - if unpacked.shape[-1] != allowed_topk: - raise ValueError( - f"[{param_name}] Invalid topk dimension: " - f"shape[-1]={unpacked.shape[-1]} but expected 
{allowed_topk}" - ) - _bounds_check(unpacked) - except Exception as e: - raise ValueError(f"[{param_name}] Failed to unpack 12-bit indices: {e}") - else: - raise ValueError(f"[{param_name}] Expected tensor but got {type(idxs)}") + if C != totalk: + raise ValueError( + f"[{param_name}] Payload column size C={C} but expected {totalk}" + ) + + # compute expected rows from values shape (flatten all but last dim) + if vals.ndim == 0: + raise ValueError(f"[{param_name}] Values tensor has no top‑k dimension") + expected_rows = int(np.prod(vals.shape[:-1])) if vals.ndim > 1 else 1 + if N != expected_rows: + raise ValueError( + f"[{param_name}] Payload rows N={N} but values imply {expected_rows}" + ) + + k = vals.shape[-1] + if k != allowed_topk: + raise ValueError( + f"[{param_name}] Values top‑k={k} but allowed_topk={allowed_topk}" + ) + if any(len(r) != k for r in rows_list): + raise ValueError( + f"[{param_name}] At least one row has mismatched top‑k size" + ) + + # bounds check without materialising full tensor + max_idx = max((max(r) if len(r) > 0 else -1) for r in rows_list) + min_idx = ( + min((min(r) if len(r) > 0 else 0) for r in rows_list) if rows_list else 0 + ) + if min_idx < 0 or max_idx >= totalk: + raise ValueError( + f"[{param_name}] Index out of bounds (min={min_idx}, max={max_idx}, totalk={totalk})" + ) async def s3_get_object_size(self, bucket: Bucket, key: str) -> int | None: """ diff --git a/src/tplr/compress/__init__.py b/src/tplr/compress/__init__.py new file mode 100644 index 000000000..b7dc9c53f --- /dev/null +++ b/src/tplr/compress/__init__.py @@ -0,0 +1,33 @@ +# The MIT License (MIT) +# © 2025 tplr.ai +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to 
permit persons to whom the Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of +# the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +from .bits import ( + decode_batch_rows, # decoder (CPU) + encode_batch_rows, # GPU-accelerated encoder → bytes + perm + meta +) +from .pack12 import pack_12bit_indices, unpack_12bit_indices # legacy +from .topk import ChunkingTransformer, TopKCompressor + +__all__ = [ + # High level + "TopKCompressor", + "ChunkingTransformer", + "encode_batch_rows", + "decode_batch_rows", + "pack_12bit_indices", + "unpack_12bit_indices", +] diff --git a/src/tplr/compress/bits.py b/src/tplr/compress/bits.py new file mode 100644 index 000000000..612f07962 --- /dev/null +++ b/src/tplr/compress/bits.py @@ -0,0 +1,487 @@ +# bits.py +# The MIT License (MIT) +# © 2025 tplr.ai +# +# Triton-powered Rice/bitmap encoder (per-row scheme) with a CPU decoder. 
+# Payload layout matches the CPU reference: +# [ C-1 : 12b ][ N : 16b ][ reserved : 1b ] +# then, for each row r=0..N-1: +# [ row_len_bytes[r] : 16b ][ row_payload_bits[r] ] +# +# Dependencies: torch, triton (runtime), numpy (only for decode consumer code elsewhere if needed) + +from __future__ import annotations +import math +from typing import Sequence, Tuple + +import torch +import torch.nn.functional as F + +try: + import triton + import triton.language as tl + + TRITON_AVAILABLE = True +except Exception: + TRITON_AVAILABLE = False + + +# -------------------------- CPU decoder (unchanged format) -------------------------- + + +class _BitReader: + def __init__(self, data: bytes) -> None: + self.data = data + self.idx = 0 + self.cur = 0 + self.nbits = 0 + + def _fill(self, n: int) -> None: + while self.nbits < n and self.idx < len(self.data): + self.cur |= int(self.data[self.idx]) << self.nbits + self.idx += 1 + self.nbits += 8 + + def read_bits(self, n: int) -> int: + if n <= 0: + return 0 + self._fill(n) + mask = (1 << n) - 1 + out = self.cur & mask + self.cur >>= n + self.nbits -= n + return out + + def read_unary(self) -> int: + q = 0 + while True: + bit = self.read_bits(1) + if bit == 0: + break + q += 1 + return q + + def read_bytes(self, n: int) -> bytes: + out = bytearray() + for _ in range(n): + out.append(self.read_bits(8)) + return bytes(out) + + +def _rice_read(br: _BitReader, k: int) -> int: + q = br.read_unary() + r = br.read_bits(k) if k > 0 else 0 + return (q << k) + r + + +def decode_batch_rows(payload: bytes) -> tuple[list[list[int]], int, int]: + """ + Decode payload created by encode_batch_rows(...). + Returns (rows, C, N) where `rows` is a list of per-row global indices. 
+ """ + br = _BitReader(payload) + C = br.read_bits(12) + 1 + N = br.read_bits(16) + _ = br.read_bits(1) # reserved + + rows: list[list[int]] = [] + for _i in range(N): + row_len = br.read_bits(16) + row_bytes = br.read_bytes(row_len) + rr = _BitReader(row_bytes) + lb = rr.read_bits(5) + k_param = rr.read_bits(4) + use_bitmap = rr.read_bits(1) + B = 1 << lb + n_sub = C // B + + indices: list[int] = [] + for j in range(n_sub): + s_len = _rice_read(rr, k_param) + if s_len == 0: + continue + if use_bitmap: + bitmask = rr.read_bits(B) + for loc in range(B): + if (bitmask >> loc) & 1: + indices.append(j * B + loc) + else: + for _ in range(s_len): + loc = rr.read_bits(lb) + indices.append(j * B + loc) + rows.append(indices) + return rows, C, N + + +# --------------------------- GPU-side param selection --------------------------- + + +def _rice_k_from_mean(lmbda: float) -> int: + if lmbda <= 0.0: + return 0 + return max(0, round(math.log2(max(lmbda, 1e-9)))) + + +@torch.no_grad() +def _estimate_best_params_per_row( + idx: torch.Tensor, C: int, B_choices: Sequence[int] +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Torch (GPU) estimate of best B, use_bitmap, and k_param per row. + Mirrors your previous vectorised selector. 
+ """ + assert idx.dtype == torch.int64 + rows, k = idx.shape + device = idx.device + + B_sorted = tuple( + sorted([b for b in B_choices if b > 0 and (C % b) == 0 and (b & (b - 1)) == 0]) + ) + if not B_sorted: + raise ValueError("No valid B choices for C") + + header = 5 + 4 + 1 + best_bits = torch.full((rows,), 1 << 60, device=device, dtype=torch.int64) + best_B = torch.full((rows,), B_sorted[0], device=device, dtype=torch.int64) + best_use_bitmap = torch.zeros((rows,), device=device, dtype=torch.bool) + + Bmin = B_sorted[0] + if all((b % Bmin) == 0 for b in B_sorted): + n_sub_min = C // Bmin + js_min = (idx // Bmin).to(torch.int64) # [rows, k] + counts_min = ( + F.one_hot(js_min, num_classes=n_sub_min).sum(dim=1).to(torch.int64) + ) # [rows, n_sub_min] + lmbda_base = k / max(1, C) + + for B in B_sorted: + g = B // Bmin + counts_B = ( + counts_min + if g == 1 + else counts_min.reshape(rows, n_sub_min // g, g).sum(dim=2) + ) + lb = int(math.ceil(math.log2(B))) + n_sub = C // B + k_param = int(max(0, round(math.log2(max(lmbda_base * B, 1e-9))))) + m = 1 << k_param + q = counts_B // m + rb_sum = q.sum(dim=1) + (1 + k_param) * n_sub + nonzero = (counts_B > 0).sum(dim=1) + bits_local = header + rb_sum + lb * k + bits_bitmap = header + rb_sum + B * nonzero + cur_bits = torch.minimum(bits_local, bits_bitmap).to(torch.int64) + use_bitmap = bits_bitmap < bits_local + update = cur_bits < best_bits + best_bits = torch.where(update, cur_bits, best_bits) + best_B = torch.where(update, torch.full_like(best_B, B), best_B) + best_use_bitmap = torch.where(update, use_bitmap, best_use_bitmap) + else: + for B in B_sorted: + lb = int(math.ceil(math.log2(B))) + n_sub = C // B + js = (idx // B).to(torch.int64) + # Bincount per row with a constant bin width M=max_n_sub to avoid collisions + M = C // min(B_sorted) + row_ids = torch.arange(rows, device=device, dtype=torch.int64).unsqueeze(1) + flat = (row_ids * M + js).reshape(-1) + counts = torch.bincount(flat, minlength=rows * 
M).reshape(rows, M)[ + :, :n_sub + ] + lmbda = (k / max(1, C)) * B + k_param = int(max(0, round(math.log2(max(lmbda, 1e-9))))) + m = 1 << k_param + q = counts // m + rb_sum = q.sum(dim=1) + (1 + k_param) * n_sub + nonzero = (counts > 0).sum(dim=1) + bits_local = header + rb_sum + lb * k + bits_bitmap = header + rb_sum + B * nonzero + cur_bits = torch.minimum(bits_local, bits_bitmap).to(torch.int64) + use_bitmap = bits_bitmap < bits_local + update = cur_bits < best_bits + best_bits = torch.where(update, cur_bits, best_bits) + best_B = torch.where(update, torch.full_like(best_B, B), best_B) + best_use_bitmap = torch.where(update, use_bitmap, best_use_bitmap) + + # Rice k for chosen B + lmbda = (idx.shape[1] / max(1, C)) * best_B.float() + k_param = torch.clamp((lmbda.clamp_min(1e-9).log2().round()).to(torch.int64), min=0) + + return ( + best_B.to(torch.int32), + best_use_bitmap.to(torch.uint8), + k_param.to(torch.int32), + ) + + +# --------------------------- Triton helpers (no tl.uint32 calls) --------------------------- + + +@triton.jit +def _write_bits_u32(buf_ptr, bitpos, value, nbits): + # Write 'nbits' LSBs from value, advance bitpos, and return the new bitpos. + i = tl.zeros((), dtype=tl.int32) + bp = bitpos.to(tl.int64) + v = value.to(tl.int32) + while i < nbits: + bit = (v >> i) & 1 + byte_idx = bp // 8 + off = bp % 8 + p = buf_ptr + byte_idx + old = tl.load(p, mask=True, other=0).to(tl.int32) + newv = old | (bit << off) + tl.store(p, newv.to(tl.uint8)) + bp += 1 + i += 1 + return bp + + +@triton.jit +def _write_unary(buf_ptr, bitpos, q): + # Write q ones then a zero; zeros are already in buffer → just advance for trailing zero. 
+ bp = bitpos.to(tl.int64) + i = tl.zeros((), dtype=tl.int32) + while i < q: + byte_idx = bp // 8 + off = bp % 8 + p = buf_ptr + byte_idx + old = tl.load(p, mask=True, other=0).to(tl.int32) + newv = old | (1 << off) + tl.store(p, newv.to(tl.uint8)) + bp += 1 + i += 1 + # trailing zero bit: buffer is zero-initialized, so just skip one bit + bp += 1 + return bp + + +@triton.jit +def _write_rice(buf_ptr, bitpos, x, kparam): + # Golomb-Rice for non-negative x with parameter k. + k = kparam.to(tl.int32) + if k == 0: + return _write_unary(buf_ptr, bitpos, x) + m = 1 << k + q = x // m + r = x & (m - 1) + bp = _write_unary(buf_ptr, bitpos, q) + bp = _write_bits_u32(buf_ptr, bp, r, k) + return bp + + +@triton.jit +def _set_one_bit(buf_ptr, bitpos): + bp = bitpos.to(tl.int64) + byte_idx = bp // 8 + off = bp % 8 + p = buf_ptr + byte_idx + old = tl.load(p, mask=True, other=0).to(tl.int32) + newv = old | (1 << off) + tl.store(p, newv.to(tl.uint8)) + + +# ------------------------------ Triton kernel --------------------------------- + + +@triton.jit +def _kernel_write_rows( + idx_ptr, # int64 [N*K] + rows, + k, + C, # ints + bestB_ptr, # int32 [N] + usebm_ptr, # uint8 [N] (0/1) + kparam_ptr, # int32 [N] + row_bytes_ptr, # int32 [N] + len_bitpos_ptr, # int64 [N] (bitpos of 16-bit length) + pay_bitpos_ptr, # int64 [N] (bitpos of first payload bit for the row) + payload_ptr, # uint8 [TOTAL_BYTES] + K_MAX: tl.constexpr, # upper bound for K (e.g., 256 or 512) +): + r = tl.program_id(0) + if r >= rows: + return + + # Per-row params + B = tl.load(bestB_ptr + r).to(tl.int32) + use_bitmap = tl.load(usebm_ptr + r).to(tl.int1) + kparam = tl.load(kparam_ptr + r).to(tl.int32) + # B is power-of-two ⇒ lb = log2(B) exactly (cast to float for tl.log2) + lb = tl.log2(B.to(tl.float32)).to(tl.int32) + n_sub = C // B + + # Write 16-bit length at its position (interleaved layout) + row_len = tl.load(row_bytes_ptr + r).to(tl.int32) + len_bp = tl.load(len_bitpos_ptr + r) + _ = 
_write_bits_u32(payload_ptr, len_bp, row_len, 16) + + # Row payload header + bp = tl.load(pay_bitpos_ptr + r) + bp = _write_bits_u32(payload_ptr, bp, lb, 5) # lb + bp = _write_bits_u32(payload_ptr, bp, kparam, 4) # k + bp = _write_bits_u32(payload_ptr, bp, use_bitmap.to(tl.int32), 1) # mode + + # Emit each sub-chunk j in ascending order + j = tl.zeros((), dtype=tl.int32) + while j < n_sub: + # -- first pass: count how many entries go to sub j + got = tl.zeros((), dtype=tl.int32) + t = tl.zeros((), dtype=tl.int32) + while t < k: + v = tl.load(idx_ptr + r * k + t).to(tl.int64) + jj = (v // B).to(tl.int32) + got += jj == j + t += 1 + + # write Rice length + bp = _write_rice(payload_ptr, bp, got, kparam) + + if got > 0: + if use_bitmap: + # second pass: set bits for locations; advance by B bits + start = bp + t = tl.zeros((), dtype=tl.int32) + while t < k: + v = tl.load(idx_ptr + r * k + t).to(tl.int64) + jj = (v // B).to(tl.int32) + if jj == j: + loc = (v - j.to(tl.int64) * B.to(tl.int64)).to(tl.int32) + _set_one_bit(payload_ptr, start + loc) + t += 1 + bp = start + B + else: + # local list: second pass writing lb-bit locs (order needn't be sorted) + t = tl.zeros((), dtype=tl.int32) + while t < k: + v = tl.load(idx_ptr + r * k + t).to(tl.int64) + jj = (v // B).to(tl.int32) + if jj == j: + loc = (v - j.to(tl.int64) * B.to(tl.int64)).to(tl.int32) + bp = _write_bits_u32(payload_ptr, bp, loc, lb) + t += 1 + j += 1 + # done + + +# -------------------------------- Public API ---------------------------------- + + +@torch.no_grad() +def encode_batch_rows( + idx: torch.Tensor, # [rows, k] int64 (CUDA strongly recommended) + *, + C: int, + B_choices: tuple[int, ...] = (64, 128), +) -> tuple[bytes, dict]: + """ + Triton encoder for per-row Rice/bitmap codec. + + Returns: + payload: bytes + meta: {total_bits, avg_bits_per_row, B_hist} + """ + if not TRITON_AVAILABLE: + raise RuntimeError("Triton is not available. 
`pip install triton` and re-run.") + + if idx.dtype != torch.int64: + idx = idx.to(torch.int64) + if not idx.is_cuda: + idx = idx.cuda() + idx = idx.contiguous() + + rows, k = idx.shape + device = idx.device + + # 1) Pick best params per row (GPU) + best_B, use_bitmap, k_param = _estimate_best_params_per_row( + idx, C=C, B_choices=B_choices + ) + + # 2) Compute exact per-row bit counts to size output + header_bits_row = 5 + 4 + 1 # lb + k + mode + lb = (best_B.float().log2().round().to(torch.int64)).clamp_min( + 1 + ) # exact for power-of-two + n_sub = C // best_B.to(torch.int64) + + # Bincount per row with constant width M to prevent collisions + M = int(max(C // int(b) for b in best_B.unique().tolist())) + row_ids = torch.arange(rows, device=device, dtype=torch.int64).unsqueeze(1) + js = (idx // best_B.to(torch.int64).unsqueeze(1)).to(torch.int64) # [rows, k] + flat = (row_ids * M + js).reshape(-1) + counts = torch.bincount(flat, minlength=rows * M).reshape(rows, M) # [rows, M] + # limit to effective n_sub per row when summing + # rice bits: Σ(q + 1 + k) with q = c // 2^k + m = 1 << k_param.to(torch.int64) + q = counts // m.unsqueeze(1) + rb_sum = q.sum(dim=1) + (1 + k_param.to(torch.int64)) * n_sub.to(torch.int64) + nonzero = (counts > 0).sum(dim=1).to(torch.int64) + + bits_local = header_bits_row + rb_sum + lb * k + bits_bitmap = header_bits_row + rb_sum + best_B.to(torch.int64) * nonzero + row_bits = torch.minimum(bits_local, bits_bitmap).to(torch.int64) + row_bytes = ((row_bits + 7) // 8).to(torch.int32) + + # 3) Allocate payload buffer (global header + interleaved [len16 | payload]) + total_bits_rows = int((16 * rows + 8 * row_bytes.sum().item())) + total_bits = 12 + 16 + 1 + total_bits_rows + total_bytes = (total_bits + 7) // 8 + payload = torch.zeros(total_bytes, dtype=torch.uint8, device=device) + + # 4) Compute interleaved bit positions for each row + header_bits = 12 + 16 + 1 + body_chunk_bits = 16 + 8 * row_bytes.to(torch.int64) # [rows] + prefix = 
torch.zeros_like(body_chunk_bits) + if rows > 0: + prefix[1:] = torch.cumsum(body_chunk_bits[:-1], dim=0) + len_bitpos = header_bits + prefix # [rows] + pay_bitpos = len_bitpos + 16 # [rows] + + # 5) Write global header in-place (LSB-first) using torch ops + def _write_scalar_bits(val: int, nbits: int, start_bit: int): + v = int(val) + bp = int(start_bit) + nb = int(nbits) + while nb > 0: + byte_idx = bp // 8 + off = bp % 8 + take = min(nb, 8 - off) + mask = ((v & ((1 << take) - 1)) << off) & 0xFF + payload[byte_idx] |= torch.as_tensor(mask, dtype=torch.uint8, device=device) + v >>= take + bp += take + nb -= take + + _write_scalar_bits(C - 1, 12, 0) + _write_scalar_bits(rows, 16, 12) + _write_scalar_bits(0, 1, 28) # reserved + + # 6) Launch Triton to write all rows + # Choose a safe K_MAX for your top-k; 256 covers k<=256; use 512 if you push k higher. + K_MAX = 256 if k <= 256 else 512 + grid = (rows,) + _kernel_write_rows[grid]( + idx, + rows, + k, + C, + best_B, + use_bitmap, + k_param, + row_bytes, + len_bitpos.to(torch.int64), + pay_bitpos.to(torch.int64), + payload, + K_MAX=K_MAX, + ) + + # 7) Return bytes + meta + payload_bytes = bytes(payload.detach().cpu().numpy().tobytes()) + B_hist = {int(b): int((best_B == b).sum().item()) for b in best_B.unique()} + meta = { + "total_bits": total_bits, + "avg_bits_per_row": float(row_bits.float().mean().item()) if rows > 0 else 0.0, + "B_hist": B_hist, + } + return payload_bytes, meta + diff --git a/src/tplr/compress/pack12.py b/src/tplr/compress/pack12.py new file mode 100644 index 000000000..388aa3502 --- /dev/null +++ b/src/tplr/compress/pack12.py @@ -0,0 +1,75 @@ +# The MIT License (MIT) +# © 2025 tplr.ai +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 
sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of +# the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import torch + + +def pack_12bit_indices(indices: torch.Tensor) -> torch.Tensor: + """ + Legacy helper: Pack int64 indices into a 12‑bit representation (pairs → 3 bytes). + Requires an even count of indices and values < 4096. + """ + max_idx = indices.max().item() if indices.numel() > 0 else 0 + if max_idx >= 4096: + raise ValueError(f"Index {max_idx} exceeds 12-bit limit (4095)") + + flat = indices.flatten() + n = flat.numel() + if n % 2 != 0: + raise ValueError(f"Number of indices must be even, got {n}") + + flat = flat.to(torch.int32) + n_pairs = n // 2 + packed = torch.zeros(n_pairs * 3, dtype=torch.uint8, device=indices.device) + + if n_pairs > 0: + pairs = flat.reshape(-1, 2) + idx1 = pairs[:, 0] + idx2 = pairs[:, 1] + packed[0::3] = (idx1 & 0xFF).to(torch.uint8) + packed[1::3] = (((idx1 >> 8) & 0x0F) | ((idx2 & 0x0F) << 4)).to(torch.uint8) + packed[2::3] = ((idx2 >> 4) & 0xFF).to(torch.uint8) + + return packed + + +def unpack_12bit_indices( + packed: torch.Tensor, values_shape: tuple[int, ...] +) -> torch.Tensor: + """ + Legacy helper: Unpack 12‑bit representation back into int64 indices and reshape + to the provided `values_shape` (which must match the original indices shape). 
+ """ + device = packed.device + n_indices = 1 + for d in values_shape: + n_indices *= int(d) + if n_indices == 0: + return torch.zeros(values_shape, dtype=torch.int64, device=device) + if n_indices % 2 != 0: + raise ValueError(f"Number of indices must be even, got {n_indices}") + + out = torch.zeros(n_indices, dtype=torch.int64, device=device) + n_pairs = n_indices // 2 + if n_pairs > 0: + b0 = packed[0::3].to(torch.int64) + b1 = packed[1::3].to(torch.int64) + b2 = packed[2::3].to(torch.int64) + + out[0::2] = b0 | ((b1 & 0x0F) << 8) + out[1::2] = ((b1 >> 4) & 0x0F) | (b2 << 4) + return out.view(*values_shape) diff --git a/src/tplr/compress.py b/src/tplr/compress/topk.py similarity index 62% rename from src/tplr/compress.py rename to src/tplr/compress/topk.py index 206f3cd8d..4d5d01a7f 100644 --- a/src/tplr/compress.py +++ b/src/tplr/compress/topk.py @@ -1,14 +1,14 @@ # The MIT License (MIT) # © 2025 tplr.ai - +# # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated # documentation files (the "Software"), to deal in the Software without restriction, including without limitation # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - +# # The above copyright notice and this permission notice shall be included in all copies or substantial portions of # the Software. - +# # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION @@ -23,212 +23,63 @@ import math from typing import Generic, Literal, Sequence, TypeAlias, TypeVar, cast, overload +import numpy as np import torch -import torch.fft from einops import rearrange from torch.distributed.tensor import DTensor as DT import tplr -# ─────────── type aliases ──────────────────────────────────────────────── -# primitive shapes -ShapeT: TypeAlias = tuple[int, ...] # original dense tensor shape -Shape4D = tuple[int, int, int, int] # y, x, h, w -TotK: TypeAlias = int # size of the last dim - -# 12‑bit packed representation - just the uint8 buffer, no tuple -IdxT: TypeAlias = torch.Tensor # 12-bit packed indices (stored as uint8 tensor) +from .bits import decode_batch_rows, encode_batch_rows +# ─────────── type aliases ──────────────────────────────────────────────── +ShapeT: TypeAlias = tuple[int, ...] +Shape4D = tuple[int, int, int, int] +TotK: TypeAlias = int +IdxT: TypeAlias = torch.Tensor # stored as uint8 byte-stream (new codec) QuantParamsT: TypeAlias = tuple[torch.Tensor, float, int, torch.Tensor, torch.dtype] - -# For historical names kept elsewhere in the code ValT: TypeAlias = torch.Tensor +_DEFAULT_B_CHOICES: tuple[int, ...] = (64, 128) + # Boolean flag that propagates the chosen quantisation mode Q = TypeVar("Q", Literal[True], Literal[False]) -def pack_12bit_indices(indices: torch.Tensor) -> torch.Tensor: - """ - Pack int64 indices into 12-bit representation. - Every 2 indices (24 bits) are packed into 3 uint8 values. - Assumes even number of indices (topk is always even). 
- - Args: - indices: Tensor with values < 4096 (12-bit max), must have even number of elements - - Returns: - packed_tensor as uint8 - """ - # Ensure indices fit in 12 bits - max_idx = indices.max().item() if indices.numel() > 0 else 0 - if max_idx >= 4096: - raise ValueError(f"Index {max_idx} exceeds 12-bit limit (4095)") - - # Flatten the tensor - indices_flat = indices.flatten() - n_indices = indices_flat.numel() - - # Ensure we have even number of indices - if n_indices % 2 != 0: - raise ValueError(f"Number of indices must be even, got {n_indices}") - - # Convert to int32 for bit manipulation - indices_flat = indices_flat.to(torch.int32) - - # Process all as pairs - indices_pairs = indices_flat - n_pairs = n_indices // 2 - - # Calculate packed size - packed_size = n_pairs * 3 - packed = torch.zeros(packed_size, dtype=torch.uint8, device=indices.device) - - # Vectorized packing for pairs - if n_pairs > 0: - idx_pairs = indices_pairs.reshape(-1, 2) - idx1 = idx_pairs[:, 0] - idx2 = idx_pairs[:, 1] - - # Pack pairs: idx1 uses byte0 + lower 4 bits of byte1 - # idx2 uses upper 4 bits of byte1 + byte2 - packed[0::3] = (idx1 & 0xFF).to(torch.uint8) # Lower 8 bits of idx1 - packed[1::3] = (((idx1 >> 8) & 0x0F) | ((idx2 & 0x0F) << 4)).to(torch.uint8) - packed[2::3] = ((idx2 >> 4) & 0xFF).to(torch.uint8) # Upper 8 bits of idx2 - - return packed - - -def unpack_12bit_indices(packed: torch.Tensor, values_shape: ShapeT) -> torch.Tensor: - """ - Unpack 12-bit packed indices back to int64. - Assumes even number of indices. 
- - Args: - packed: Packed uint8 tensor - values_shape: Shape of the values tensor (same as original indices shape) - - Returns: - Unpacked indices as int64 tensor with original shape - """ - n_indices = int(torch.prod(torch.tensor(values_shape)).item()) - - if n_indices == 0: - return torch.zeros(values_shape, dtype=torch.int64, device=packed.device) - - # Ensure even number of indices - if n_indices % 2 != 0: - raise ValueError(f"Number of indices must be even, got {n_indices}") - - # Prepare output - indices = torch.zeros(n_indices, dtype=torch.int64, device=packed.device) - - # All indices are paired - n_pairs = n_indices // 2 - - if n_pairs > 0: - # Vectorized unpacking - byte0 = packed[0::3].to(torch.int64) - byte1 = packed[1::3].to(torch.int64) - byte2 = packed[2::3].to(torch.int64) - - # Reconstruct indices - indices[0::2] = byte0 | ((byte1 & 0x0F) << 8) # idx1 - indices[1::2] = ((byte1 >> 4) & 0x0F) | (byte2 << 4) # idx2 - - # Reshape to match values shape - indices = indices.reshape(values_shape) - - return indices - - class ChunkingTransformer: """ A transformer for chunking tensors to enable more efficient gradient processing. - - This class handles the chunking of tensors into smaller blocks, which can be - processed more efficiently. It pre-calculates Discrete Cosine Transform (DCT) - basis matrices for various tensor sizes to speed up the transformation process. """ @torch.no_grad() - def __init__(self, model, target_chunk, norm="ortho"): + def __init__(self, model, target_chunk): """ Initialise the ChunkingTransformer. Args: model: The model whose parameters will be processed. target_chunk (int): The target size for tensor chunks. - norm (str): The normalization to be used for DCT ('ortho' or None). 
""" self.target_chunk = target_chunk self.shape_dict = dict() - self.f_dict = dict() - self.b_dict = dict() # Get all variants of model tensor sizes - # Generate all possible valid DCT sizes for model tensors for _, p in model.named_parameters(): if not p.requires_grad: continue for s in p.shape: - # Get the closest smallest divisor to the targeted DCT size + # Get the closest smallest divisor to the target chunk size sc = _get_smaller_split(s, self.target_chunk) self.shape_dict[s] = sc - # Pregenerate DCT basis matrices - if sc not in self.f_dict: - I = torch.eye(sc) # noqa: E741 - self.f_dict[sc] = _dct(I, norm=norm).to(p.dtype).to(p.device) - self.b_dict[sc] = _idct(I, norm=norm).to(p.dtype).to(p.device) - - @torch.no_grad() - def einsum_2d(self, x, b, d=None) -> torch.Tensor: - """ - Apply a 2D einsum operation for encoding. - - Args: - x (torch.Tensor): The input tensor. - b (torch.Tensor): The first basis matrix. - d (torch.Tensor, optional): The second basis matrix. Defaults to None. - - Returns: - torch.Tensor: The transformed tensor. - """ - if d is None: - return torch.einsum("...ij, jb -> ...ib", x, b) - else: - # Note: b-c axis output is transposed to chunk DCT in 2D - return torch.einsum("...ijkl, kb, ld -> ...ijbd", x, b, d) - - @torch.no_grad() - def einsum_2d_t(self, x, b, d=None) -> torch.Tensor: - """ - Apply a 2D einsum operation for decoding (transpose). - - Args: - x (torch.Tensor): The input tensor. - b (torch.Tensor): The first basis matrix. - d (torch.Tensor, optional): The second basis matrix. Defaults to None. - - Returns: - torch.Tensor: The transformed tensor. 
- """ - if d is None: - return torch.einsum("...ij, jb -> ...ib", x, b) - else: - # Note: b-c axis output is transposed to chunk DCT in 2D - return torch.einsum("...ijbd, bk, dl -> ...ijkl", x, b, d) - @torch.no_grad() - def encode(self, x: torch.Tensor, *, use_dct: bool = False) -> torch.Tensor: + def encode(self, x: torch.Tensor) -> torch.Tensor: """ - Encode a tensor by chunking and optionally applying DCT. + Encode a tensor by chunking. Args: x (torch.Tensor): The input tensor to encode. - use_dct (bool): Whether to apply the Discrete Cosine Transform. Returns: torch.Tensor: The encoded tensor. @@ -236,57 +87,27 @@ def encode(self, x: torch.Tensor, *, use_dct: bool = False) -> torch.Tensor: if len(x.shape) > 1: # 2D weights n1 = self.shape_dict[x.shape[0]] n2 = self.shape_dict[x.shape[1]] - n1w = self.f_dict[n1].to(x.device) - n2w = self.f_dict[n2].to(x.device) - self.f_dict[n1] = n1w - self.f_dict[n2] = n2w - x = rearrange(x, "(y h) (x w) -> y x h w", h=n1, w=n2) - if use_dct: - x = self.einsum_2d(x, n1w, n2w) - else: # 1D weights n1 = self.shape_dict[x.shape[0]] - n1w = self.f_dict[n1].to(x.device) - self.f_dict[n1] = n1w - x = rearrange(x, "(x w) -> x w", w=n1) - if use_dct: - x = self.einsum_2d(x, n1w) return x @torch.no_grad() - def decode(self, x: torch.Tensor, *, use_dct: bool = False) -> torch.Tensor: + def decode(self, x: torch.Tensor) -> torch.Tensor: """ - Decode a tensor by un-chunking and optionally applying inverse DCT. + Decode a tensor by un-chunking. Args: x (torch.Tensor): The input tensor to decode. - use_dct (bool): Whether to apply the inverse Discrete Cosine Transform. Returns: torch.Tensor: The decoded tensor. 
""" if len(x.shape) > 2: # 2D weights - if use_dct: - n1 = x.shape[2] - n2 = x.shape[3] - n1w = self.b_dict[n1].to(x.device) - n2w = self.b_dict[n2].to(x.device) - self.b_dict[n1] = n1w - self.b_dict[n2] = n2w - - x = self.einsum_2d_t(x, n1w, n2w) x = rearrange(x, "y x h w -> (y h) (x w)") - else: # 1D weights - if use_dct: - n1 = x.shape[1] - n1w = self.b_dict[n1].to(x.device) - self.b_dict[n1] = n1w - - x = self.einsum_2d_t(x, n1w) x = rearrange(x, "x w -> (x w)") return x @@ -362,12 +183,12 @@ def _clamp_topk(self, x, topk) -> int: """ topk = min(topk, x.shape[-1]) topk = max(topk, 2) - # Ensure topk is even for 12-bit packing efficiency + # Keep even by default (matches broader system expectations). topk = topk - (topk % 2) return int(topk) # ------------------------------------------------------------------ # - # compress – returns a 5-tuple *or* a 4-tuple, depending on the mode + # compress – returns a 5‑tuple (quant) or 4‑tuple (no quant) # ------------------------------------------------------------------ # @overload def compress( @@ -395,32 +216,60 @@ def compress(self, x: torch.Tensor, topk: int): # type: ignore[override] A tuple containing the compressed data. The format depends on whether quantization is used. 
""" + import time + if isinstance(x, DT): # check for dtensors x = x.to_local() xshape = x.shape - + + # Log the shape we're compressing + shape_start = time.time() + original_shape = xshape + if len(x.shape) > 2: # 2D weights x = rearrange(x, "y x h w -> y x (h w)") + + reshape_time = time.time() - shape_start # Limit topk to max size totalk = x.shape[-1] topk = self._clamp_topk(x, topk) + # Top‑K + topk_start = time.time() idx_int64 = torch.topk( x.abs(), k=topk, dim=-1, largest=True, sorted=False ).indices val = torch.gather(x, dim=-1, index=idx_int64) + topk_time = time.time() - topk_start + + # Flatten to [rows, k] for the codec + encode_start = time.time() + idx2d = idx_int64.reshape(-1, topk).contiguous() + # GPU‑accelerated encode → bytes + payload, _meta = encode_batch_rows( + idx2d, C=totalk, B_choices=_DEFAULT_B_CHOICES + ) - # Pack indices into 12-bit representation for efficient storage - # This reduces storage by 25% compared to int16 - idx = pack_12bit_indices(idx_int64) + idx_bytes = torch.tensor( + np.frombuffer(payload, dtype=np.uint8).copy(), + dtype=torch.uint8, + device="cpu", + ) + encode_time = time.time() - encode_start + + # Debug logging for timing + if hasattr(self, '_debug_timing') and self._debug_timing: + import tplr + tplr.logger.info( + f"[TOPK COMPRESS] shape={original_shape}, totalk={totalk}, topk={topk}, " + f"reshape={reshape_time:.3f}s, topk_select={topk_time:.3f}s, encode={encode_time:.3f}s" + ) - # Apply 8-bit quantization if enabled if self.use_quantization: - val, quant_params = self._quantize_values(val) - return idx, val, xshape, totalk, quant_params - - return idx, val, xshape, totalk + val, qparams = self._quantize_values(val) + return idx_bytes, val, xshape, totalk, qparams + return idx_bytes, val, xshape, totalk @torch.no_grad() def decompress( @@ -454,18 +303,23 @@ def decompress( if len(xshape) > 2: # 2D weights x = rearrange(x, "y x h w -> y x (h w)") - # Unpack 12-bit indices using val shape (if needed) + # Decode 
indices if idx.dtype == torch.uint8: - # 12-bit packed format - unpack it - idx_int64 = unpack_12bit_indices(idx, val.shape) + payload_bytes = idx.detach().cpu().numpy().tobytes() + rows_list, C, _N = decode_batch_rows(payload_bytes) + if C != totalk: + raise ValueError(f"Index payload C={C} but expected {totalk}") + k = val.shape[-1] + if any(len(r) != k for r in rows_list): + raise ValueError("Row-wise topk size mismatch in index payload") + idx_int64 = torch.tensor( + rows_list, dtype=torch.int64, device=p.device + ).view(*val.shape) elif idx.dtype in (torch.int64, torch.long): - # Already unpacked (from batch_decompress) - idx_int64 = idx + idx_int64 = idx.to(p.device) else: - raise ValueError( - f"Expected uint8 (packed) or int64 (unpacked) indices, got {idx.dtype}" - ) - # Ensure val has the same dtype as x for scatter operation + raise ValueError(f"Unsupported index tensor dtype: {idx.dtype}") + if val.dtype != x.dtype: val = val.to(dtype=x.dtype) @@ -562,13 +416,22 @@ def batch_decompress( idx_list = idx if isinstance(idx, Sequence) else [idx] for i, i_data in enumerate(idx_list): - if i_data.dtype != torch.uint8: - raise ValueError( - f"Expected uint8 for 12-bit packed indices, got {i_data.dtype}" - ) - # Unpack 12-bit format using corresponding values shape v_data = val_list[i] - idx_unpacked = unpack_12bit_indices(i_data.to(p.device), v_data.shape) + if i_data.dtype == torch.uint8: + rows, C, _N = decode_batch_rows(i_data.detach().cpu().numpy().tobytes()) + if C != totalk: + raise ValueError(f"Index payload C={C} but expected {totalk}") + if any(len(r) != v_data.shape[-1] for r in rows): + raise ValueError( + "Row-wise topk size mismatch in index payload (batch)" + ) + idx_unpacked = torch.tensor( + rows, dtype=torch.int64, device=p.device + ).view(*v_data.shape) + elif i_data.dtype in (torch.int64, torch.long): + idx_unpacked = i_data.to(p.device) + else: + raise ValueError(f"Unsupported index dtype in batch: {i_data.dtype}") 
unpacked_indices.append(idx_unpacked) idx_concat = torch.cat(unpacked_indices, dim=-1) @@ -579,6 +442,7 @@ def batch_decompress( p, idx_concat, val_concat, xshape, totalk, quantize_params=None ) + # -------------------- quantisation helpers --------------------------- @torch.no_grad() def _quantize_values(self, val: torch.Tensor) -> tuple[torch.Tensor, QuantParamsT]: """ @@ -596,17 +460,18 @@ def _quantize_values(self, val: torch.Tensor) -> tuple[torch.Tensor, QuantParams std = centered.norm() / math.sqrt(centered.numel() - 1) scale = self.range_in_sigmas * std / self.n_bins - if scale == 0 or torch.isnan(scale) or torch.isinf(scale): + if ( + isinstance(scale, torch.Tensor) + and (scale == 0 or torch.isnan(scale) or torch.isinf(scale)) + ) or ( + not isinstance(scale, torch.Tensor) + and (scale == 0 or not math.isfinite(float(scale))) + ): scale = torch.tensor(1.0, dtype=centered.dtype, device=val.device) - centered_fp32 = centered.to(torch.float32) - qval = ( - (centered_fp32 / scale + offset) - .round() - .clamp(0, self.n_bins - 1) - .to(torch.uint8) + qval = ((centered_fp32 / scale + offset).round().clamp(0, self.n_bins - 1)).to( + torch.uint8 ) - device = qval.device sums = torch.zeros(self.n_bins, dtype=torch.float32, device=device) counts = torch.zeros(self.n_bins, dtype=torch.float32, device=device) @@ -617,7 +482,7 @@ def _quantize_values(self, val: torch.Tensor) -> tuple[torch.Tensor, QuantParams ) lookup = torch.where(counts > 0, sums / counts, torch.zeros_like(sums)) - qparams: QuantParamsT = (shift, float(scale), offset, lookup, val.dtype) + qparams: QuantParamsT = (shift, float(scale), int(offset), lookup, val.dtype) return qval, qparams @torch.no_grad() @@ -635,10 +500,8 @@ def _dequantize_values( torch.Tensor: The dequantized values. 
""" if val.dtype == torch.uint8: - shift, _, _, lookup, orig_dtype = qparams - lookup = ( - lookup.to(val.device) if isinstance(lookup, torch.Tensor) else lookup - ) + shift, _scale, _offset, lookup, orig_dtype = qparams + lookup = lookup.to(val.device) deq = lookup[val.long()] + shift val = deq.to(orig_dtype) return val @@ -696,98 +559,6 @@ def maybe_dequantize_values( return vals_f32 -# Code modified and sourced from https://github.com/zh217/torch-dct -def _dct_fft_impl(v) -> torch.Tensor: - """FFT-based implementation of the DCT.""" - return torch.view_as_real(torch.fft.fft(v, dim=1)) - - -def _idct_irfft_impl(V) -> torch.Tensor: - """IRFFT-based implementation of the IDCT.""" - return torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1) - - -def _dct(x, norm=None) -> torch.Tensor: - """ - Discrete Cosine Transform, Type II (a.k.a. the DCT) - - For the meaning of the parameter `norm`, see: - https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html - - :param x: the input signal - :param norm: the normalization, None or 'ortho' - :return: the DCT-II of the signal over the last dimension - """ - x_shape = x.shape - N = x_shape[-1] - x = x.contiguous().view(-1, N) - - v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1) - - Vc = _dct_fft_impl(v) - - k = -torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * math.pi / (2 * N) - W_r = torch.cos(k) - W_i = torch.sin(k) - - V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i - - if norm == "ortho": - V[:, 0] /= math.sqrt(N) * 2 - V[:, 1:] /= math.sqrt(N / 2) * 2 - - V = 2 * V.view(*x_shape) - - return V - - -def _idct(X, norm=None) -> torch.Tensor: - """ - The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III - - Our definition of idct is that idct(dct(x)) == x - - For the meaning of the parameter `norm`, see: - https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html - - :param X: the input signal - :param norm: the normalization, None 
or 'ortho' - :return: the inverse DCT-II of the signal over the last dimension - """ - - x_shape = X.shape - N = x_shape[-1] - - X_v = X.contiguous().view(-1, x_shape[-1]) / 2 - - if norm == "ortho": - X_v[:, 0] *= math.sqrt(N) * 2 - X_v[:, 1:] *= math.sqrt(N / 2) * 2 - - k = ( - torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] - * math.pi - / (2 * N) - ) - W_r = torch.cos(k) - W_i = torch.sin(k) - - V_t_r = X_v - V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1) - - V_r = V_t_r * W_r - V_t_i * W_i - V_i = V_t_r * W_i + V_t_i * W_r - - V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2) - - v = _idct_irfft_impl(V) - x = v.new_zeros(v.shape) - x[:, ::2] += v[:, : N - (N // 2)] - x[:, 1::2] += v.flip([1])[:, : N // 2] - - return x.view(*x_shape) - - def _get_prime_divisors(n: int) -> list[int]: """ Get the prime divisors of a number. diff --git a/src/tplr/neurons.py b/src/tplr/neurons.py index 2bdce9aeb..63b8de756 100644 --- a/src/tplr/neurons.py +++ b/src/tplr/neurons.py @@ -30,10 +30,10 @@ from torch.distributed.tensor import distribute_tensor from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler -from wandb.sdk.wandb_run import Run import tplr -from tplr.compress import unpack_12bit_indices +from tplr.compress import decode_batch_rows +from wandb.sdk.wandb_run import Run if TYPE_CHECKING: from neurons.miner import Miner @@ -87,7 +87,6 @@ def barrier(group=None): # ------------ start ------------ gradient, xshapes, totalks = {}, {}, {} lr = float(miner.hparams.outer_learning_rate) - use_dct = getattr(miner.hparams, "use_dct", False) topk = getattr(miner.hparams, "topk_compression", 32) if isinstance(miner.model, torch.nn.parallel.DistributedDataParallel): @@ -161,7 +160,7 @@ def barrier(group=None): error_feedback.add_(grad_full, alpha=lr) # --- 4) Encode & compress (owner only) --- - encoded = miner.transformer.encode(error_feedback, use_dct=use_dct) + encoded = 
miner.transformer.encode(error_feedback) idxs, vals, xshape, totalk, quant_params = miner.compressor.compress( encoded, topk @@ -180,7 +179,7 @@ def barrier(group=None): ) # --- 6) Decode & error-feedback update (owner only) --- - transmit_grad = miner.transformer.decode(decompressed, use_dct=use_dct) + transmit_grad = miner.transformer.decode(decompressed) del decompressed error_feedback.sub_(transmit_grad) # Keep error feedback on GPU for now, batch offload later @@ -232,7 +231,6 @@ def outer_step( device: str, is_master: bool, world_size: int, - use_dct: bool = False, wandb_run: Run | None = None, global_step: int | None = None, ) -> None: @@ -327,7 +325,7 @@ def _bcast_flag(v: int) -> int: clip_norm=True, ) - full_grad_src = transformer.decode(decompressed, use_dct=use_dct) + full_grad_src = transformer.decode(decompressed) # Single conversion to target dtype+device to avoid extra temporaries full_grad_src = full_grad_src.to( dtype=p.dtype, device=p.device, non_blocking=True @@ -702,7 +700,6 @@ async def catchup_with_aggregation_server( device=instance.config.device, is_master=instance.is_master, # rank-0 handles logging world_size=instance.world_size, - use_dct=instance.hparams.use_dct, wandb_run=instance.wandb if instance.is_master else None, global_step=instance.global_step, ) @@ -946,30 +943,45 @@ async def check_uid_index_overlap( if idxs_all is None: continue - # Get values for unpacking shape - vals_key = pname + "vals" - vals_all = getattr(gather_result.state_dict, vals_key, None) - if vals_all is None: - continue + def _as_bytes(x) -> bytes: + if isinstance(x, (bytes, bytearray)): + return bytes(x) + if isinstance(x, torch.Tensor): + if x.dtype != torch.uint8: + raise ValueError( + f"Expected torch.uint8 for Rice payload, got {x.dtype}" + ) + return x.detach().cpu().contiguous().numpy().tobytes() + raise TypeError(f"Unsupported idx payload type: {type(x)}") - # Unpack all 12-bit packed indices using values shape - unpacked_indices = [] + 
decoded_per_peer: list[torch.Tensor] = [] for i in range(Ptot): - idx_data = idxs_all[i] if isinstance(idxs_all, list) else idxs_all - val_data = vals_all[i] if isinstance(vals_all, list) else vals_all + idx_data = idxs_all[i] if isinstance(idxs_all, (list, tuple)) else idxs_all + payload = _as_bytes(idx_data) - # 12-bit packed format - use values shape for unpacking - unpacked = unpack_12bit_indices( - idx_data.to(neuron.config.device), val_data.shape - ) - unpacked_indices.append(unpacked) + rows_i, _C_codec, N_rows = decode_batch_rows( + payload + ) # rows_i: list[list[int]] + if N_rows == 0: + # no rows for this param/peer → skip param entirely + decoded_per_peer = [] + break + + # ensure rectangular (constant k) + k0 = len(rows_i[0]) + if not all(len(r) == k0 for r in rows_i): + raise ValueError("Rice payload has variable k per row; unsupported.") + + decoded_per_peer.append(torch.tensor(rows_i, dtype=torch.int64)) + + if not decoded_per_peer: + continue - idxs_tensor = torch.stack(unpacked_indices, dim=0) - P, *chunk_dims, k = idxs_tensor.shape - C = int(torch.prod(torch.tensor(chunk_dims))) # num chunks - idxs_flat = idxs_tensor.reshape(P, C, k) + idxs_tensor = torch.stack(decoded_per_peer, dim=0) # [P, C, K] + P, C_chunks, k = idxs_tensor.shape + idxs_flat = idxs_tensor # already [P, C, K] - param_weight = C * k # size weight + param_weight = C_chunks * k # size weight for i in range(P): for j in range(i + 1, P): diff --git a/tests/test_comms.py b/tests/test_comms.py index 3263c4cf6..154c1ac5b 100644 --- a/tests/test_comms.py +++ b/tests/test_comms.py @@ -9,13 +9,12 @@ import pytest import torch from types import SimpleNamespace -from dotenv import load_dotenv import asyncio -from dataclasses import dataclass from datetime import datetime, timedelta, timezone from tplr import load_hparams -from tplr.compress import pack_12bit_indices +from tplr.compress import pack_12bit_indices, encode_batch_rows +import numpy as np hparams = load_hparams() @@ -50,7 +49,7 
@@ def create_xshapes_totalks(model): def create_valid_state_dict(model): state_dict = {} for name, _ in model.named_parameters(): - # Create 12-bit packed format + # Create legacy 12-bit packed format (for backwards compatibility test) indices = torch.tensor([0, 1], dtype=torch.long) packed_data = pack_12bit_indices(indices) state_dict[name + "idxs"] = packed_data @@ -67,9 +66,9 @@ def create_missing_idxs(model): def create_packed_indices(indices_list): - """Helper function to create 12-bit packed indices from a list""" + """Helper function to create legacy 12-bit packed indices from a list""" indices = torch.tensor(indices_list, dtype=torch.long) - # Ensure even number of indices for 12-bit packing + # Ensure even number of indices for legacy 12-bit packing if len(indices_list) % 2 != 0: indices = torch.cat([indices, torch.tensor([0], dtype=torch.long)]) packed_data = pack_12bit_indices(indices) @@ -246,7 +245,7 @@ async def test_gather_basic_functionality(comms_instance, dummy_compressor): data={ "0.weightidxs": create_packed_indices( [0, 1, 2, 3] - ), # Even count for 12-bit + ), # Even count for legacy format "0.weightvals": torch.tensor([0.4, 0.5, 0.6, 0.7]), "totalks": {"0.weight": totalk_value}, }, @@ -257,7 +256,7 @@ async def test_gather_basic_functionality(comms_instance, dummy_compressor): data={ "0.weightidxs": create_packed_indices( [0, 1, 2, 3] - ), # Even count for 12-bit + ), # Even count for legacy format "0.weightvals": torch.tensor([0.7, 0.8, 0.9, 1.0]), "totalks": {"0.weight": totalk_value}, }, @@ -317,7 +316,7 @@ async def test_gather_normalization(comms_instance, dummy_compressor): data={ "0.weightidxs": create_packed_indices( [0, 1, 2, 3] - ), # Even count for 12-bit + ), # Even count for legacy format "0.weightvals": torch.tensor([0.4, 0.5, 0.6, 0.7]), "totalks": {"0.weight": totalk_value}, }, @@ -482,7 +481,7 @@ async def test_gather_averaging(comms_instance, dummy_compressor): data={ "0.weightidxs": create_packed_indices( [0, 1, 2, 3] - 
), # Even count for 12-bit + ), # Even count for legacy format "0.weightvals": torch.tensor([0.4, 0.5, 0.6, 0.7]), "totalks": {"0.weight": totalk_value}, }, @@ -493,7 +492,7 @@ async def test_gather_averaging(comms_instance, dummy_compressor): data={ "0.weightidxs": create_packed_indices( [0, 1, 2, 3] - ), # Even count for 12-bit + ), # Even count for legacy format "0.weightvals": torch.tensor([0.8, 0.9, 1.0, 1.1]), "totalks": {"0.weight": totalk_value}, }, @@ -1555,36 +1554,46 @@ def __init__(self): ) -def test_valid_12bit_packed_indices(): +def test_valid_rice_bitmap_encoded_indices(): """ - Test Case: test_valid_12bit_packed_indices - - Input: 12-bit packed indices with correct topk dimension + Test Case: test_valid_rice_bitmap_encoded_indices + - Input: Rice/bitmap encoded indices with correct topk dimension - Valid indices (all indices within [0, totalk-1]) - Expected Outcome: The function should complete without raising an error. """ dummy_comms = DummyComms() - # totalk is set to 10; allowed_topk is min(4, 10) == 4. - totalk = 10 - valid_indices = torch.tensor([1, 5, 9, 3], dtype=torch.long) - packed_data = pack_12bit_indices(valid_indices) - vals = torch.randn_like(valid_indices, dtype=torch.float32) + # totalk is set to 64; allowed_topk is min(4, 64) == 4. + totalk = 64 + valid_indices = torch.tensor( + [[1, 5, 9, 3]], dtype=torch.long + ) # Shape [1, 4] for one row + # Use the new encoder format + payload, _ = encode_batch_rows(valid_indices, C=totalk) + packed_data = torch.tensor( + np.frombuffer(payload, dtype=np.uint8), dtype=torch.uint8 + ) + vals = torch.randn(1, 4, dtype=torch.float32) # Match the shape [rows, k] # This call should complete without any error. 
dummy_comms.check_compressed_indices("test_param", packed_data, totalk, vals=vals) -def test_valid_12bit_packed_multi_dim(): +def test_valid_rice_bitmap_encoded_multi_dim(): """ - Test 12-bit packed indices from multi-dimensional tensor where the last dimension + Test Rice/bitmap encoded indices from multi-dimensional tensor where the last dimension equals min(hparams.topk_compression, totalk) and all indices are within valid range. """ dummy_comms = DummyComms() - totalk = 20 # allowed_topk = min(4, 20) = 4 + totalk = 128 # allowed_topk = min(4, 128) = 4 # Create a valid 2D tensor (shape: 2 x 4) with valid indices. valid_indices = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=torch.long) - packed_data = pack_12bit_indices(valid_indices) - vals = torch.randn_like(valid_indices, dtype=torch.float32) + # Use the new encoder format + payload, _ = encode_batch_rows(valid_indices, C=totalk) + packed_data = torch.tensor( + np.frombuffer(payload, dtype=np.uint8), dtype=torch.uint8 + ) + vals = torch.randn(2, 4, dtype=torch.float32) # Match the shape dummy_comms.check_compressed_indices("param", packed_data, totalk, vals=vals) @@ -1593,18 +1602,18 @@ def test_invalid_not_packed_format(): Test that non-packed formats (like regular tensors or lists) are rejected. 
""" dummy_comms = DummyComms() - totalk = 20 + totalk = 128 # Test with regular tensor (not packed) - should fail because it's not uint8 invalid_tensor = torch.tensor([0, 1, 2, 3], dtype=torch.long) vals = torch.randn(4, dtype=torch.float32) - # This should fail since only uint8 12-bit packed format is supported - with pytest.raises(ValueError, match="Expected uint8 for 12-bit packed indices"): + # This should fail since only uint8 Rice/bitmap encoded format is supported + with pytest.raises(ValueError, match="Expected uint8.*Rice/bitmap payload"): dummy_comms.check_compressed_indices("param", invalid_tensor, totalk, vals=vals) # Test with list (not a tensor) invalid_list = torch.tensor([0, 1, 2, 3]) - with pytest.raises(ValueError, match="Expected uint8 for 12-bit packed indices"): + with pytest.raises(ValueError, match="Expected uint8.*Rice/bitmap payload"): dummy_comms.check_compressed_indices("param", invalid_list, totalk, vals=vals) @@ -1613,124 +1622,148 @@ def test_invalid_wrong_dtype(): Test that packed data with wrong dtype is handled correctly. """ dummy_comms = DummyComms() - totalk = 20 + totalk = 128 # int32 tensor is not uint8, so it should fail fake_packed = torch.tensor([0, 1, 2, 3], dtype=torch.int32) vals = torch.randn(4, dtype=torch.float32) # Should fail since only uint8 format is supported - with pytest.raises(ValueError, match="Expected uint8 for 12-bit packed indices"): + with pytest.raises(ValueError, match="Expected uint8.*Rice/bitmap payload"): dummy_comms.check_compressed_indices("param", fake_packed, totalk, vals=vals) -def test_invalid_12bit_packed_wrong_topk(): +def test_invalid_rice_bitmap_wrong_topk(): """ - Test that 12-bit packed indices with wrong topk dimension raises ValueError. + Test that Rice/bitmap encoded indices with wrong topk dimension raises ValueError. 
""" dummy_comms = DummyComms() - totalk = 10 # allowed_topk = min(4, 10) = 4 + totalk = 64 # allowed_topk = min(4, 64) = 4 # Create packed indices with wrong topk (2 instead of 4) - invalid_indices = torch.tensor([0, 1], dtype=torch.long) - packed_data = pack_12bit_indices(invalid_indices) - vals = torch.randn(2, dtype=torch.float32) # Wrong shape - should be 4 - with pytest.raises(ValueError, match="Invalid topk dimension"): + invalid_indices = torch.tensor([[0, 1]], dtype=torch.long) + payload, _ = encode_batch_rows(invalid_indices, C=totalk) + packed_data = torch.tensor( + np.frombuffer(payload, dtype=np.uint8), dtype=torch.uint8 + ) + vals = torch.randn(1, 2, dtype=torch.float32) # Wrong shape - should be 4 + with pytest.raises(ValueError, match="Values top.*k=2 but allowed_topk=4"): dummy_comms.check_compressed_indices("param", packed_data, totalk, vals=vals) -def test_invalid_12bit_packed_multi_dim_wrong_topk(): +def test_invalid_rice_bitmap_multi_dim_wrong_topk(): """ - Test that 12-bit packed indices from multi-dimensional tensor with wrong last dimension + Test that Rice/bitmap encoded indices from multi-dimensional tensor with wrong last dimension raises ValueError indicating invalid topk dimension. 
""" dummy_comms = DummyComms() - totalk = 20 # allowed_topk = min(4, 20) = 4 + totalk = 128 # allowed_topk = min(4, 128) = 4 # Create a 2D tensor with last dimension size 6 (should be 4) invalid_indices = torch.tensor( [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]], dtype=torch.long ) - packed_data = pack_12bit_indices(invalid_indices) + payload, _ = encode_batch_rows(invalid_indices, C=totalk) + packed_data = torch.tensor( + np.frombuffer(payload, dtype=np.uint8), dtype=torch.uint8 + ) vals = torch.randn(2, 6, dtype=torch.float32) # Wrong shape - should be (2, 4) - with pytest.raises(ValueError, match="Invalid topk dimension"): + with pytest.raises(ValueError, match="Values top.*k=6 but allowed_topk=4"): dummy_comms.check_compressed_indices("param", packed_data, totalk, vals=vals) -# Removed test_invalid_12bit_packed_negative_index as pack_12bit_indices validates input +# Removed test_invalid_rice_bitmap_negative_index as encoder validates input -def test_invalid_12bit_packed_out_of_bounds(): +def test_invalid_rice_bitmap_out_of_bounds(): """ - Test that 12-bit packed indices with out-of-bounds values raise ValueError. + Test that Rice/bitmap encoded indices with out-of-bounds values raise ValueError. """ dummy_comms = DummyComms() - totalk = 10 # allowed_topk = min(4, 10) = 4 - # Index 10 is out-of-range because valid indices are 0 to 9. - invalid_indices = torch.tensor([0, 1, 10, 3], dtype=torch.long) - packed_data = pack_12bit_indices(invalid_indices) - vals = torch.randn(4, dtype=torch.float32) - with pytest.raises(ValueError, match="Index 10 out of bounds"): + totalk = 64 # allowed_topk = min(4, 64) = 4 + # Index 64 is out-of-range because valid indices are 0 to 63. 
+ # But the encoder will fail before we can test - so let's test with valid encode but wrong totalk + invalid_indices = torch.tensor([[0, 1, 9, 3]], dtype=torch.long) + # Encode with a larger C to make it work + payload, _ = encode_batch_rows(invalid_indices, C=128) + packed_data = torch.tensor( + np.frombuffer(payload, dtype=np.uint8), dtype=torch.uint8 + ) + vals = torch.randn(1, 4, dtype=torch.float32) + # Now check with smaller totalk=64, so index 9 is valid but payload says C=128 + with pytest.raises(ValueError, match="Payload column size C=128 but expected 64"): dummy_comms.check_compressed_indices("param", packed_data, totalk, vals=vals) # Removed test_invalid_flat_list_wrong_length - covered by test_invalid_not_packed_format -# Removed test_valid_single_value - not applicable to 12-bit packed format +# Removed test_valid_single_value - not applicable to Rice/bitmap encoded format -# Removed test_invalid_single_value_out_of_bounds - not applicable to 12-bit packed format +# Removed test_invalid_single_value_out_of_bounds - not applicable to Rice/bitmap encoded format -def test_override_allowed_topk_12bit(): +def test_override_allowed_topk_rice_bitmap(): """ - Test using the optional allowed_topk parameter with 12-bit packed format. + Test using the optional allowed_topk parameter with Rice/bitmap encoded format. """ dummy_comms = DummyComms() - totalk = 10 + totalk = 64 # Override allowed_topk to 2. valid_indices = torch.tensor( - [0, 9], dtype=torch.long + [[0, 9]], dtype=torch.long ) # Correct length: 2 elements. 
- packed_data = pack_12bit_indices(valid_indices) - vals = torch.randn(2, dtype=torch.float32) + payload, _ = encode_batch_rows(valid_indices, C=totalk) + packed_data = torch.tensor( + np.frombuffer(payload, dtype=np.uint8), dtype=torch.uint8 + ) + vals = torch.randn(1, 2, dtype=torch.float32) dummy_comms.check_compressed_indices( "param", packed_data, totalk, allowed_topk=2, vals=vals ) # Test with wrong topk invalid_indices = torch.tensor( - [0, 1, 2, 3], dtype=torch.long + [[0, 1, 2, 3]], dtype=torch.long ) # 4 elements instead of 2. - packed_data = pack_12bit_indices(invalid_indices) - vals = torch.randn(4, dtype=torch.float32) # Wrong shape for allowed_topk=2 - with pytest.raises(ValueError, match="Invalid topk dimension"): + payload, _ = encode_batch_rows(invalid_indices, C=totalk) + packed_data = torch.tensor( + np.frombuffer(payload, dtype=np.uint8), dtype=torch.uint8 + ) + vals = torch.randn(1, 4, dtype=torch.float32) # Wrong shape for allowed_topk=2 + with pytest.raises(ValueError, match="Values top.*k=4 but allowed_topk=2"): dummy_comms.check_compressed_indices( "param", packed_data, totalk, allowed_topk=2, vals=vals ) -def test_topk_auto_adjust_when_totalk_is_lower_12bit(): +def test_topk_auto_adjust_when_totalk_is_lower_rice_bitmap(): """ - Test scenario where totalk is less than hparams.topk_compression with 12-bit packed format. + Test scenario where totalk is less than hparams.topk_compression with Rice/bitmap encoded format. """ dummy_comms = DummyComms() - totalk = 2 # Now allowed_topk becomes min(hparams.topk_compression, totalk) = min(4,2) = 2. + totalk = 64 # Now allowed_topk becomes min(hparams.topk_compression, totalk) = min(4,64) = 4. valid_indices = torch.tensor( - [0, 1], dtype=torch.long - ) # Valid: length matches allowed_topk (which is 2). - packed_data = pack_12bit_indices(valid_indices) - vals = torch.randn(2, dtype=torch.float32) + [[0, 1, 2, 3]], dtype=torch.long + ) # Valid: length matches allowed_topk (which is 4). 
+ payload, _ = encode_batch_rows(valid_indices, C=totalk) + packed_data = torch.tensor( + np.frombuffer(payload, dtype=np.uint8), dtype=torch.uint8 + ) + vals = torch.randn(1, 4, dtype=torch.float32) dummy_comms.check_compressed_indices("param", packed_data, totalk, vals=vals) - # Note: Can't test with 1 element as pack_12bit_indices requires even number of indices - # Test with 4 elements (wrong topk) + # Note: Can't test with 1 element as encoder requires even number of indices + # Test with 6 elements (wrong topk) invalid_indices = torch.tensor( - [0, 1, 0, 1], dtype=torch.long - ) # 4 elements instead of 2. - packed_data = pack_12bit_indices(invalid_indices) - vals = torch.randn(4, dtype=torch.float32) # Wrong shape for allowed_topk=2 - with pytest.raises(ValueError, match="Invalid topk dimension"): + [[0, 1, 2, 3, 4, 5]], dtype=torch.long + ) # 6 elements instead of 4. + payload, _ = encode_batch_rows(invalid_indices, C=totalk) + packed_data = torch.tensor( + np.frombuffer(payload, dtype=np.uint8), dtype=torch.uint8 + ) + vals = torch.randn(1, 6, dtype=torch.float32) # Wrong shape for allowed_topk=4 + with pytest.raises(ValueError, match="Values top.*k=6 but allowed_topk=4"): dummy_comms.check_compressed_indices("param", packed_data, totalk, vals=vals) diff --git a/tests/test_prepare_gradient_dict.py b/tests/test_prepare_gradient_dict.py index 144503975..c88d4270c 100644 --- a/tests/test_prepare_gradient_dict.py +++ b/tests/test_prepare_gradient_dict.py @@ -11,7 +11,6 @@ def __init__(self): self.momentum_decay = 0.9 self.topk_compression = 5 self.outer_learning_rate = 0.9 - self.use_dct = False class DummyCompressor: @@ -31,10 +30,10 @@ def decompress(self, p, idxs, vals, xshape, totalk, quant_params): class DummyTransformer: - def encode(self, tensor, use_dct): + def encode(self, tensor): return tensor - def decode(self, tensor, use_dct): + def decode(self, tensor): return torch.tensor([0.1, 0.1]) @@ -165,10 +164,10 @@ def decompress(self, p, idxs, vals, 
xshape, totalk, quant_params): return torch.zeros_like(p) class DummyPassThroughTransformer: - def encode(self, tensor, use_dct): # identity + def encode(self, tensor): # identity return tensor - def decode(self, tensor, use_dct): # returns tensor as-is + def decode(self, tensor): # returns tensor as-is return tensor # ------------------------------------------------------------------ # @@ -258,11 +257,11 @@ class DummyRecordingTransformer: def __init__(self): self.decode_called_with = None - def encode(self, tensor, use_dct): + def encode(self, tensor): # Identity for easier reasoning return tensor - def decode(self, tensor, use_dct): + def decode(self, tensor): self.decode_called_with = tensor.clone() return torch.tensor([0.1, 0.1]) # value not important for this test @@ -468,7 +467,7 @@ def test_propagation_of_transformer_failure(): miner = DummyMiner() # Override transformer.decode to throw an exception. - def failing_decode(tensor, use_dct): + def failing_decode(tensor): raise RuntimeError("Transformer error") miner.transformer.decode = failing_decode diff --git a/tests/unit/test_bits_codec.py b/tests/unit/test_bits_codec.py new file mode 100644 index 000000000..21cfc725a --- /dev/null +++ b/tests/unit/test_bits_codec.py @@ -0,0 +1,335 @@ +import math + +import numpy as np +import pytest +import torch + +# Import from your module (adjust the import path if your file lives elsewhere) +from tplr.compress.bits import ( + BitReader, + decode_batch_rows, + encode_batch_rows, + encode_batch_rows_cpu, +) + +# ------------------------------------------------------------------------- +# Helpers +# ------------------------------------------------------------------------- + + +def device_params(): + devs = ["cpu"] + if torch.cuda.is_available(): + devs.append("cuda") + return devs + + +def make_even_k(k: int) -> int: + return k if k % 2 == 0 else k - 1 if k > 0 else 0 + + +def scatter2d(indices: torch.Tensor, values: torch.Tensor, C: int) -> torch.Tensor: + """ + 
Scatter-add helper: indices [N, K], values [N, K] -> dense [N, C]. + Uses sum; for unique indices per row, mean == sum (used in your decompressor). + """ + N, K = indices.shape + out = torch.zeros((N, C), dtype=values.dtype, device=values.device) + out.scatter_add_(1, indices.long(), values) + return out + + +def parse_first_row_header(payload: bytes): + """ + Read container header and the first row header to extract: + - C, N + - first row byte length + - (lb, k_param, use_bitmap) for row 0 + """ + br = BitReader(payload) + C = br.read_bits(12) + 1 + N = br.read_bits(16) + _ = br.read_bits(1) # reserved + + if N == 0: + return C, N, 0, 0, 0, 0 + + row_len = br.read_bits(16) + row_bytes = br.read_bytes(row_len) + rr = BitReader(row_bytes) + lb = rr.read_bits(5) + k_param = rr.read_bits(4) + use_bitmap = rr.read_bits(1) + return C, N, row_len, lb, k_param, use_bitmap + + +# ------------------------------------------------------------------------- +# Core correctness tests +# ------------------------------------------------------------------------- + + +@pytest.mark.parametrize("device", device_params()) +@pytest.mark.parametrize( + "N,C,K", + [ + (1, 64, 6), # C=64 is divisible by 64 + (4, 128, 8), # C=128 is divisible by both 64 and 128 + (8, 256, 12), # C=256 is divisible by both 64 and 128 + (3, 192, 10), # C=192 is divisible by 64 + ], +) +def test_roundtrip_decode_matches_original_permutation(device, N, C, K): + """ + The strongest property: for each row, + decoded_indices == original_indices[ perm ]. + Also the sets match (ignoring ordering). 
+ """ + K = make_even_k(K) + # Generate unique indices per row (as topk would produce) + # Use a simpler approach: create ascending indices then shuffle per row + idx = torch.zeros((N, K), device=device, dtype=torch.int64) + for i in range(N): + # Create a unique set of K indices for this row + all_indices = torch.arange(C, device=device, dtype=torch.int64) + shuffled = all_indices[torch.randperm(C, device=device)][:K] + idx[i] = shuffled + + payload, meta = encode_batch_rows(idx, C=C) + rows, C2, N2 = decode_batch_rows(payload) + + assert C2 == C + assert N2 == N + # Check decoded indices set equality + for i in range(N): + decoded = rows[i] + assert len(decoded) == K + orig = idx[i].detach().cpu().tolist() + # set equality - decoded values should match original, though order may differ + assert sorted(decoded) == sorted(orig) + + # meta sanity + assert isinstance(meta, dict) + assert "total_bits" in meta and meta["total_bits"] > 0 + assert "avg_bits_per_row" in meta and meta["avg_bits_per_row"] >= 0 + assert "B_hist" in meta and isinstance(meta["B_hist"], dict) + assert sum(meta["B_hist"].values()) == N + + +@pytest.mark.parametrize("device", device_params()) +def test_decode_preserves_indices(device): + """ + Test that decoded indices preserve the same set of values as original. 
+ """ + N, C, K = 5, 128, 8 # C=128 is divisible by both 64 and 128 + K = make_even_k(K) + # Generate unique indices per row (as topk would produce) + idx = torch.zeros((N, K), device=device, dtype=torch.int64) + for i in range(N): + idx[i] = torch.randperm(C, device=device, dtype=torch.int64)[:K] + + payload, _ = encode_batch_rows(idx, C=C) + rows, C2, N2 = decode_batch_rows(payload) + assert C2 == C and N2 == N + + # Check that decoded indices match original (set equality) + for i in range(N): + orig = idx[i].detach().cpu().tolist() + decoded = rows[i] + assert sorted(orig) == sorted(decoded), f"Row {i}: indices don't match" + + +@pytest.mark.parametrize("device", device_params()) +def test_cpu_reference_decoder_equivalence(device): + """ + The CPU reference encoder should decode to the same per-row indices + as the new encode_batch_rows (not necessarily byte-identical payload). + """ + N, C, K = 6, 128, 10 # C=128 is divisible by both 64 and 128 + K = make_even_k(K) + # Generate unique indices per row + idx = torch.zeros((N, K), device=device, dtype=torch.int64) + for i in range(N): + idx[i] = torch.randperm(C, device=device, dtype=torch.int64)[:K] + + # new path + payload_new, _ = encode_batch_rows(idx, C=C) + rows_new, Cn, Nn = decode_batch_rows(payload_new) + assert Cn == C and Nn == N + # ref path + payload_ref, _meta_ref = encode_batch_rows_cpu( + idx.detach().cpu().numpy().astype(np.int64), C=C + ) + rows_ref, Cr, Nr = decode_batch_rows(payload_ref) + assert Cr == C and Nr == N + + # compare decoded rows - check set equality since order may differ + for i in range(N): + assert sorted(rows_ref[i]) == sorted(rows_new[i]), ( + f"row {i} decode differs (CPU ref vs new)" + ) + # check that decoded values match original + for i in range(N): + orig = idx[i].detach().cpu().tolist() + assert sorted(orig) == sorted(rows_new[i]) + + +# ------------------------------------------------------------------------- +# Edge cases & error handling +# 
------------------------------------------------------------------------- + + +@pytest.mark.parametrize("device", device_params()) +def test_zero_rows(device): + C, K = 64, 6 # C=64 is divisible by 64 + K = make_even_k(K) + idx = torch.empty(0, K, dtype=torch.int64, device=device) + + payload, meta = encode_batch_rows(idx, C=C) + rows, C2, N2 = decode_batch_rows(payload) + assert C2 == C and N2 == 0 + assert rows == [] + assert "B_hist" in meta and sum(meta["B_hist"].values()) == 0 + + +@pytest.mark.parametrize("device", device_params()) +def test_zero_k(device): + """ + k == 0 should still produce a valid payload and 0-length rows; + permutation is [N, 0]. + """ + N, C, K = 3, 128, 0 # C=128 is divisible by both 64 and 128 + idx = torch.empty(N, K, dtype=torch.int64, device=device) + + payload, _ = encode_batch_rows(idx, C=C) + rows, C2, N2 = decode_batch_rows(payload) + assert C2 == C and N2 == N + for i in range(N): + assert rows[i] == [] + + +@pytest.mark.parametrize("device", device_params()) +def test_non_int64_indices_cast_ok(device): + """ + encode_batch_rows should accept integer tensors not strictly int64 + and cast internally without error. + """ + N, C, K = 4, 128, 6 # C=128 is divisible by both 64 and 128 + K = make_even_k(K) + # Generate unique indices per row + idx_64 = torch.zeros((N, K), device=device, dtype=torch.int64) + for i in range(N): + idx_64[i] = torch.randperm(C, device=device, dtype=torch.int64)[:K] + idx = idx_64.to(torch.int32) + + payload, _ = encode_batch_rows(idx, C=C) + rows, C2, N2 = decode_batch_rows(payload) + assert C2 == C and N2 == N + for i in range(N): + assert len(rows[i]) == K + + +def test_invalid_b_choices_raise_for_new_encoder(): + """ + New encoder returns ValueError when no valid B in B_choices. + (CPU reference falls back to power-of-two divisors, tested below.) 
+ """ + N, C, K = 2, 10, 4 # C=10 is not divisible by 64 or 128 + # Generate unique indices per row + idx = torch.zeros((N, K), dtype=torch.int64) + for i in range(N): + idx[i] = torch.randperm(C, dtype=torch.int64)[:K] + with pytest.raises(ValueError, match="No valid B choices for C"): + encode_batch_rows( + idx, C=C, B_choices=(3, 6, 12) + ) # none is a power-of-two divisor of 10 + + +def test_cpu_reference_fallback_works_with_invalid_b_choices(): + """ + CPU reference should still work (it falls back to power-of-two divisors). + """ + N, C, K = 2, 10, 4 # C=10 is not divisible by 64 or 128 + rows_np = np.random.randint(0, C, size=(N, K), dtype=np.int64) + payload, meta = encode_batch_rows_cpu(rows_np, C=C, B_choices=(3, 6, 12)) + rows, C2, N2 = decode_batch_rows(payload) + assert C2 == C and N2 == N + assert "B_hist" in meta and sum(meta["B_hist"].values()) == N + + +# ------------------------------------------------------------------------- +# Bitmap vs local payload path selection +# ------------------------------------------------------------------------- + + +def test_uses_bitmap_when_dense_within_subbucket(): + """ + Construct a case where k is large within one B=64 sub-bucket, + so bitmap (B bits) is cheaper than emitting locs (k * lb). + We verify 'use_bitmap' bit in the row header. 
+ """ + N, C, B = 1, 128, 64 + # put many positions inside sub 0 of B=64 (enough to make bitmap worthwhile) + idx = torch.tensor( + [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]], dtype=torch.int64 + ) + payload, _ = encode_batch_rows(idx, C=C, B_choices=(B,)) + C2, N2, row_len, lb, k_param, use_bitmap = parse_first_row_header(payload) + assert ( + C2 == C and N2 == 1 and lb == int(math.ceil(math.log2(B))) + ) # lb should be 6 for B=64 + assert use_bitmap == 1, "Expected bitmap path for dense sub-bucket" + + +def test_uses_local_when_sparse_within_subbucket(): + """ + Construct a case where very few locs within a B=64 block + makes local payload (k * lb) cheaper than bitmap (B bits). + """ + N, C, B = 1, 128, 64 + idx = torch.tensor([[0, 63]], dtype=torch.int64) # very sparse within the block + payload, _ = encode_batch_rows(idx, C=C, B_choices=(B,)) + C2, N2, row_len, lb, k_param, use_bitmap = parse_first_row_header(payload) + assert ( + C2 == C and N2 == 1 and lb == int(math.ceil(math.log2(B))) + ) # lb should be 6 for B=64 + assert use_bitmap == 0, "Expected local (loc-stream) path for sparse sub-bucket" + + +# ------------------------------------------------------------------------- +# Cross-device parity (optional, only when CUDA is available) +# ------------------------------------------------------------------------- + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_cuda_vs_cpu_decode_equivalence(): + """ + If CUDA is available, the CPU and CUDA encodes should decode equivalently. 
+    """
+    torch.manual_seed(0)
+    N, C, K = 7, 128, 10  # C=128 is divisible by both 64 and 128
+    K = make_even_k(K)
+
+    # Generate unique indices per row
+    idx_cpu = torch.zeros((N, K), device="cpu", dtype=torch.int64)
+    for i in range(N):
+        idx_cpu[i] = torch.randperm(C, device="cpu", dtype=torch.int64)[:K]
+    idx_gpu = idx_cpu.to("cuda")
+
+    payload_cpu, _ = encode_batch_rows(idx_cpu, C=C)
+    payload_gpu, _ = encode_batch_rows(idx_gpu, C=C)
+
+    rows_cpu, Cc, Nc = decode_batch_rows(payload_cpu)
+    rows_gpu, Cg, Ng = decode_batch_rows(payload_gpu)
+
+    assert (Cc, Nc) == (C, N) and (Cg, Ng) == (C, N)
+
+    # decoded rows must match exactly
+    assert rows_cpu == rows_gpu
+
+    # decoded rows must contain exactly the original indices on both devices
+    # (encode_batch_rows returns (payload, meta) and exposes no permutation,
+    # so compare index sets rather than reordering via an undefined perm)
+    for i in range(N):
+        orig = idx_cpu[i].tolist()
+        assert sorted(rows_cpu[i]) == sorted(orig)
+        assert sorted(rows_gpu[i]) == sorted(orig)
diff --git a/tests/unit/test_compress.py b/tests/unit/test_compress.py
index 2eaab4fcd..eecb773b0 100644
--- a/tests/unit/test_compress.py
+++ b/tests/unit/test_compress.py
@@ -1,5 +1,6 @@
 from typing import Literal
 
+import numpy as np
 import pytest
 import torch
 import torch.nn as nn
@@ -7,12 +8,13 @@
 from tplr.compress import (
     ChunkingTransformer,
     TopKCompressor,
-    _dct,
-    _get_smaller_split,
-    _idct,
+    encode_batch_rows,
     pack_12bit_indices,
     unpack_12bit_indices,
 )
+from tplr.compress.topk import (
+    _get_smaller_split,
+)
 
 
 class TestTopKCompressor:
@@ -30,19 +32,19 @@ def compress_instance_quantized(self) -> TopKCompressor[Literal[True]]:
             use_quantization=True, quantization_bins=256, quantization_range=6
         )
 
-    def test_compress_produces_int16_indices(
+    def test_compress_produces_rice_bitmap_indices(
         self, compress_instance: TopKCompressor[Literal[False]]
     ):
-        """Test that compress() produces 12-bit packed indices"""
+        """Test that compress() produces Rice/bitmap encoded indices"""
         # Create test tensor
-        x = 
torch.randn(10, 10) + x = torch.randn(8, 64) # 512 elements total, last dim=64 topk = 10 # Compress using actual method idx, val, xshape, totalk = compress_instance.compress(x, topk) - # Verify index format - should be uint8 tensor for 12-bit packed - assert idx.dtype == torch.uint8, f"Expected uint8 packed data, got {idx.dtype}" + # Verify index format - should be uint8 tensor for Rice/bitmap codec + assert idx.dtype == torch.uint8, f"Expected uint8 encoded data, got {idx.dtype}" assert val.shape[-1] == topk assert xshape == x.shape # totalk is the size of the last dimension after rearranging @@ -53,7 +55,7 @@ def test_compress_with_quantization( self, compress_instance_quantized: TopKCompressor[Literal[True]] ): """Test compression with quantization enabled""" - x = torch.randn(10, 10) + x = torch.randn(8, 64) # 512 elements total, last dim=64 topk = 20 # Compress with quantization @@ -63,27 +65,28 @@ def test_compress_with_quantization( assert len(result) == 5 idx, val, _, _, qparams = result - # idx should be uint8 tensor for 12-bit packed format + # idx should be uint8 tensor for Rice/bitmap encoded format assert idx.dtype == torch.uint8 assert val.dtype == torch.uint8 # Quantized values assert qparams is not None assert len(qparams) == 5 # shift, scale, offset, lookup, orig_dtype - def test_decompress_with_12bit_tuple_format( + def test_decompress_with_rice_bitmap_format( self, compress_instance: TopKCompressor[Literal[False]] ): - """Test that decompress can handle 12-bit packed tuple format""" + """Test that decompress can handle Rice/bitmap encoded format""" # Setup - p = torch.zeros(10, 10) - xshape = (10, 10) - totalk = 100 + p = torch.zeros(8, 64) # 512 elements total, last dim=64 + xshape = (8, 64) + totalk = 64 - # Create proper 12-bit packed format using the actual packing function - # Create indices that are within valid range for a 10x10 tensor (even count) + # Create proper Rice/bitmap encoded format using the encoder + # Create indices that are 
within valid range for a 8x64 tensor (even count) original_indices = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=torch.int64) - # Pack using the actual function - idx = pack_12bit_indices(original_indices) + # Pack using the new encoder format + payload, _ = encode_batch_rows(original_indices, C=totalk) + idx = torch.tensor(np.frombuffer(payload, dtype=np.uint8), dtype=torch.uint8) val = torch.tensor( [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]], dtype=torch.float32 @@ -94,22 +97,29 @@ def test_decompress_with_12bit_tuple_format( assert result.shape == xshape assert result.dtype == p.dtype - def test_batch_decompress_multiple_12bit_formats( + def test_batch_decompress_multiple_rice_bitmap_formats( self, compress_instance: TopKCompressor[Literal[False]] ): - """Test batch_decompress with multiple 12-bit packed indices""" + """Test batch_decompress with multiple Rice/bitmap encoded indices""" # Setup - p = torch.zeros(10, 10) - xshape = (10, 10) - totalk = 100 + p = torch.zeros(8, 64) # 512 elements total, last dim=64 + xshape = (8, 64) + totalk = 64 - # Create multiple 12-bit packed indices + # Create multiple Rice/bitmap encoded indices idx1_orig = torch.tensor([[0, 1], [2, 3]], dtype=torch.int64) idx2_orig = torch.tensor([[4, 5], [6, 7]], dtype=torch.int64) - # Pack them using the 12-bit format - idx1_packed = pack_12bit_indices(idx1_orig) - idx2_packed = pack_12bit_indices(idx2_orig) + # Pack them using the new encoder format + payload1, _ = encode_batch_rows(idx1_orig, C=totalk) + idx1_packed = torch.tensor( + np.frombuffer(payload1, dtype=np.uint8), dtype=torch.uint8 + ) + + payload2, _ = encode_batch_rows(idx2_orig, C=totalk) + idx2_packed = torch.tensor( + np.frombuffer(payload2, dtype=np.uint8), dtype=torch.uint8 + ) idx_list = [idx1_packed, idx2_packed] @@ -128,7 +138,7 @@ def test_compress_decompress_round_trip( self, compress_instance: TopKCompressor[Literal[False]] ): """Test full compress-decompress round trip""" - x = torch.zeros(10, 10) + x = 
torch.zeros(8, 64) # 512 elements total, last dim=64 x[0, 0] = 1.0 x[1, 1] = 2.0 x[2, 2] = 3.0 @@ -139,7 +149,9 @@ def test_compress_decompress_round_trip( idx, val, xshape, totalk = compress_instance.compress(x, topk) # Verify we got the top-k values - assert idx.dtype == torch.uint8, "Expected uint8 for 12-bit packed indices" + assert idx.dtype == torch.uint8, ( + "Expected uint8 for Rice/bitmap encoded indices" + ) assert val.shape[-1] == topk # Decompress @@ -157,46 +169,47 @@ def test_compress_decompress_round_trip( expected_vals = torch.tensor([4.0, 3.0, 2.0, 1.0]) assert torch.allclose(top_vals, expected_vals, atol=1e-5) - def test_12bit_index_value_range( + def test_rice_bitmap_index_value_range( self, compress_instance: TopKCompressor[Literal[False]] ): - """Test that indices can represent values appropriate for 12-bit range""" + """Test that Rice/bitmap codec can handle large index ranges efficiently""" # Create a large tensor that would have indices beyond 8-bit range - x = torch.randn(100, 100) # 10,000 elements + x = torch.randn(128, 128) # 16,384 elements topk = 100 # Compress idx, val, _, totalk = compress_instance.compress(x, topk) - # Check that indices are 12-bit packed format - assert idx.dtype == torch.uint8, "Expected uint8 for 12-bit packed indices" + # Check that indices are in the new codec format (uint8 bytes) + assert idx.dtype == torch.uint8, "Expected uint8 for Rice/bitmap codec" - # Since idx is packed, we can't directly check max values - # Instead verify the packing worked correctly - # Use val.shape since it has the same shape as the original indices - unpacked = unpack_12bit_indices(idx, val.shape) + # Since idx is a byte stream payload, we can't directly check max values + # Instead verify round-trip works correctly + p = torch.zeros_like(x) + result = compress_instance.decompress(p, idx, val, x.shape, totalk) - # Verify some indices might be larger than 255 (8-bit max) - max_idx = unpacked.max().item() - assert max_idx < 10000, 
f"Index {max_idx} exceeds tensor size" + # Check that decompression succeeded + assert result.shape == x.shape - # If tensor is large enough, we should have indices > 255 - if totalk > 256: - assert unpacked.max() > 255, ( - "Large tensor should have indices beyond 8-bit range" - ) + # For a 2D tensor, totalk is the size of the last dimension + assert totalk == 128, ( + f"Expected totalk=128 for 128x128 tensor (last dim), got {totalk}" + ) def test_batch_decompress_with_norm_options( self, compress_instance: TopKCompressor[Literal[False]] ): """Test batch_decompress with normalisation and clip_norm options""" - p = torch.zeros(10, 10) - xshape = (10, 10) - totalk = 100 + p = torch.zeros(8, 64) # 512 elements total, last dim=64 + xshape = (8, 64) + totalk = 64 - # Create test data with 12-bit packed format + # Create test data with Rice/bitmap encoded format idx_orig = torch.tensor([[0, 1, 2, 3]], dtype=torch.int64) # Even count - idx_packed = pack_12bit_indices(idx_orig) + payload, _ = encode_batch_rows(idx_orig, C=totalk) + idx_packed = torch.tensor( + np.frombuffer(payload, dtype=np.uint8), dtype=torch.uint8 + ) idx = [idx_packed] val = [torch.tensor([[10.0, 20.0, 30.0, 40.0]], dtype=torch.float32)] @@ -243,8 +256,6 @@ def test_transform_init(self, mock_model): # Check that dictionaries were populated assert len(transform.shape_dict) > 0 - assert len(transform.f_dict) > 0 - assert len(transform.b_dict) > 0 # Check that shape_dict contains parameter dimensions for param in mock_model.parameters(): @@ -261,34 +272,18 @@ def test_encode_decode_real_tensors(self, mock_model): param = next(mock_model.parameters()) # Test encoding - encoded = transform.encode(param, use_dct=False) + encoded = transform.encode(param) assert encoded.numel() == param.numel() # Test decoding - decoded = transform.decode(encoded, use_dct=False) + decoded = transform.decode(encoded) assert decoded.shape == param.shape assert torch.allclose(decoded, param.reshape(decoded.shape)) - # Test with 
DCT - encoded_dct = transform.encode(param, use_dct=True) - decoded_dct = transform.decode(encoded_dct, use_dct=True) - assert decoded_dct.shape == param.shape - class TestUtilityFunctions: """Test utility functions using actual implementations""" - def test_dct_idct_round_trip(self): - """Test DCT and IDCT implementations""" - x = torch.randn(4, 8) - - # Apply DCT then IDCT - X = _dct(x, norm="ortho") - x_reconstructed = _idct(X, norm="ortho") - - # Should reconstruct original - assert torch.allclose(x, x_reconstructed, atol=1e-6) - def test_get_smaller_split(self): """Test _get_smaller_split function""" # Test with actual use case diff --git a/tests/unit/test_neurons.py b/tests/unit/test_neurons.py index 48b18a018..b85dc35e8 100644 --- a/tests/unit/test_neurons.py +++ b/tests/unit/test_neurons.py @@ -161,7 +161,6 @@ def test_outer_step_master_node( device=self.device, is_master=True, world_size=2, - use_dct=False, wandb_run=self.wandb_run, global_step=1, ) @@ -187,7 +186,6 @@ def test_outer_step_worker_node( device=self.device, is_master=False, world_size=2, - use_dct=False, ) self.optimizer.step.assert_not_called() @@ -199,7 +197,6 @@ def setUp(self): self.miner = MagicMock() self.miner.hparams.outer_learning_rate = 0.01 self.miner.hparams.momentum_decay = 0.9 - self.miner.hparams.use_dct = False self.miner.hparams.topk_compression = 0.1 self.miner.model = MagicMock() self.miner.owned_params = {"param1", "param2"} @@ -273,7 +270,6 @@ def setUp(self): self.instance.xshapes = {} self.instance.totalks = {} self.instance.config.device = "cpu" - self.instance.hparams.use_dct = False self.instance.hparams.inner_steps = 10 self.instance.hparams.time_window_delta_seconds = 10 self.instance.loop = MagicMock()