From 730588efee0178ca6d1fdde76ef62e11534f4d2e Mon Sep 17 00:00:00 2001 From: Kasper Date: Mon, 10 Nov 2025 16:57:41 +0400 Subject: [PATCH 01/33] wip --- src/tplr/compression/__init__.py | 27 ++ src/tplr/compression/bits.py | 657 +++++++++++++++++++++++++++++++ src/tplr/compression/hybrid.py | 639 ++++++++++++++++++++++++++++++ 3 files changed, 1323 insertions(+) create mode 100644 src/tplr/compression/__init__.py create mode 100644 src/tplr/compression/bits.py create mode 100644 src/tplr/compression/hybrid.py diff --git a/src/tplr/compression/__init__.py b/src/tplr/compression/__init__.py new file mode 100644 index 000000000..d823aa0f5 --- /dev/null +++ b/src/tplr/compression/__init__.py @@ -0,0 +1,27 @@ + +# The MIT License (MIT) +# © 2025 tplr.ai +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of +# the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +from .bits import ( + decode_batch_rows, # decoder (CPU) + encode_batch_rows, # GPU-accelerated encoder → bytes + perm + meta +) +__all__ = [ + # High level + "encode_batch_rows", + "decode_batch_rows", +] \ No newline at end of file diff --git a/src/tplr/compression/bits.py b/src/tplr/compression/bits.py new file mode 100644 index 000000000..21e87a563 --- /dev/null +++ b/src/tplr/compression/bits.py @@ -0,0 +1,657 @@ +# bits.py +# The MIT License (MIT) +# © 2025 tplr.ai +# +# Triton-powered Rice/bitmap encoder (per-row scheme) with a CPU decoder. +# Payload layout matches the CPU reference: +# [ C-1 : 12b ][ N : 16b ][ reserved : 1b ] +# then, for each row r=0..N-1: +# [ row_len_bytes[r] : 16b ][ row_payload_bits[r] ] +# +# Dependencies: torch, triton (runtime), numpy (only for decode consumer code elsewhere if needed) + +from __future__ import annotations +import math +from typing import Sequence, Tuple + +import torch +import torch.nn.functional as F + +try: + import triton + import triton.language as tl + + TRITON_AVAILABLE = True +except Exception: + TRITON_AVAILABLE = False + + +# -------------------------- CPU decoder (unchanged format) -------------------------- + + +class _BitReader: + def __init__(self, data: bytes) -> None: + self.data = data + self.idx = 0 + self.cur = 0 + self.nbits = 0 + + def _fill(self, n: int) -> None: + while self.nbits < n and self.idx < len(self.data): + self.cur |= int(self.data[self.idx]) << self.nbits + self.idx += 1 + self.nbits += 8 + + def read_bits(self, n: int) -> int: + if n <= 0: + return 0 + self._fill(n) + mask = (1 << n) - 1 + out = self.cur & mask + self.cur >>= n + self.nbits -= n + return out + + def read_unary(self) -> int: + q = 0 + while True: + bit = self.read_bits(1) + if bit == 0: + break + q += 1 + return q + + def read_bytes(self, n: int) -> bytes: + out = bytearray() + for _ in range(n): + out.append(self.read_bits(8)) + return bytes(out) + + +def _rice_read(br: _BitReader, k: int) -> int: + q = br.read_unary() + r = br.read_bits(k) if k > 0 else 0 + return (q << k) + r + + +def decode_batch_rows(payload: bytes) -> tuple[list[list[int]], int, int]: + """ + Decode payload created by encode_batch_rows(...). + Returns (rows, C, N) where `rows` is a list of per-row global indices. + """ + br = _BitReader(payload) + C = br.read_bits(12) + 1 + N = br.read_bits(16) + _ = br.read_bits(1) # reserved + + rows: list[list[int]] = [] + for _i in range(N): + row_len = br.read_bits(16) + row_bytes = br.read_bytes(row_len) + rr = _BitReader(row_bytes) + lb = rr.read_bits(5) + k_param = rr.read_bits(4) + use_bitmap = rr.read_bits(1) + B = 1 << lb + n_sub = C // B + + indices: list[int] = [] + for j in range(n_sub): + s_len = _rice_read(rr, k_param) + if s_len == 0: + continue + if use_bitmap: + bitmask = rr.read_bits(B) + for loc in range(B): + if (bitmask >> loc) & 1: + indices.append(j * B + loc) + else: + for _ in range(s_len): + loc = rr.read_bits(lb) + indices.append(j * B + loc) + rows.append(indices) + return rows, C, N + + +# --------------------------- GPU-side param selection --------------------------- + + +def _rice_k_from_mean(lmbda: float) -> int: + if lmbda <= 0.0: + return 0 + return max(0, round(math.log2(max(lmbda, 1e-9)))) + + +@torch.no_grad() +def _estimate_best_params_per_row( + idx: torch.Tensor, C: int, B_choices: Sequence[int] +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Torch (GPU) estimate of best B, use_bitmap, and k_param per row. + Mirrors your previous vectorised selector. + """ + assert idx.dtype == torch.int64 + rows, k = idx.shape + device = idx.device + + B_sorted = tuple( + sorted([b for b in B_choices if b > 0 and (C % b) == 0 and (b & (b - 1)) == 0]) + ) + if not B_sorted: + raise ValueError("No valid B choices for C") + + header = 5 + 4 + 1 + best_bits = torch.full((rows,), 1 << 60, device=device, dtype=torch.int64) + best_B = torch.full((rows,), B_sorted[0], device=device, dtype=torch.int64) + best_use_bitmap = torch.zeros((rows,), device=device, dtype=torch.bool) + + Bmin = B_sorted[0] + if all((b % Bmin) == 0 for b in B_sorted): + n_sub_min = C // Bmin + js_min = (idx // Bmin).to(torch.int64) # [rows, k] + counts_min = ( + F.one_hot(js_min, num_classes=n_sub_min).sum(dim=1).to(torch.int64) + ) # [rows, n_sub_min] + lmbda_base = k / max(1, C) + + for B in B_sorted: + g = B // Bmin + counts_B = ( + counts_min + if g == 1 + else counts_min.reshape(rows, n_sub_min // g, g).sum(dim=2) + ) + lb = int(math.ceil(math.log2(B))) + n_sub = C // B + k_param = int(max(0, round(math.log2(max(lmbda_base * B, 1e-9))))) + m = 1 << k_param + q = counts_B // m + rb_sum = q.sum(dim=1) + (1 + k_param) * n_sub + nonzero = (counts_B > 0).sum(dim=1) + bits_local = header + rb_sum + lb * k + bits_bitmap = header + rb_sum + B * nonzero + cur_bits = torch.minimum(bits_local, bits_bitmap).to(torch.int64) + use_bitmap = bits_bitmap < bits_local + update = cur_bits < best_bits + best_bits = torch.where(update, cur_bits, best_bits) + best_B = torch.where(update, torch.full_like(best_B, B), best_B) + best_use_bitmap = torch.where(update, use_bitmap, best_use_bitmap) + else: + for B in B_sorted: + lb = int(math.ceil(math.log2(B))) + n_sub = C // B + js = (idx // B).to(torch.int64) + # Bincount per row with a constant bin width M=max_n_sub to avoid collisions + M = C // min(B_sorted) + row_ids = torch.arange(rows, device=device, dtype=torch.int64).unsqueeze(1) + flat = (row_ids * M + js).reshape(-1) + counts = torch.bincount(flat, minlength=rows * M).reshape(rows, M)[ + :, :n_sub + ] + lmbda = (k / max(1, C)) * B + k_param = int(max(0, round(math.log2(max(lmbda, 1e-9))))) + m = 1 << k_param + q = counts // m + rb_sum = q.sum(dim=1) + (1 + k_param) * n_sub + nonzero = (counts > 0).sum(dim=1) + bits_local = header + rb_sum + lb * k + bits_bitmap = header + rb_sum + B * nonzero + cur_bits = torch.minimum(bits_local, bits_bitmap).to(torch.int64) + use_bitmap = bits_bitmap < bits_local + update = cur_bits < best_bits + best_bits = torch.where(update, cur_bits, best_bits) + best_B = torch.where(update, torch.full_like(best_B, B), best_B) + best_use_bitmap = torch.where(update, use_bitmap, best_use_bitmap) + + # Rice k for chosen B + lmbda = (idx.shape[1] / max(1, C)) * best_B.float() + k_param = torch.clamp((lmbda.clamp_min(1e-9).log2().round()).to(torch.int64), min=0) + + return ( + best_B.to(torch.int32), + best_use_bitmap.to(torch.uint8), + k_param.to(torch.int32), + ) + + +# --------------------------- Triton helpers (no tl.uint32 calls) --------------------------- + + +@triton.jit +def _write_bits_u32(buf_ptr, bitpos, value, nbits): + # Write 'nbits' LSBs from value, advance bitpos, and return the new bitpos. + i = tl.zeros((), dtype=tl.int32) + bp = bitpos.to(tl.int64) + v = value.to(tl.int32) + while i < nbits: + bit = (v >> i) & 1 + byte_idx = bp // 8 + off = bp % 8 + p = buf_ptr + byte_idx + old = tl.load(p, mask=True, other=0).to(tl.int32) + newv = old | (bit << off) + tl.store(p, newv.to(tl.uint8)) + bp += 1 + i += 1 + return bp + + +@triton.jit +def _write_unary(buf_ptr, bitpos, q): + # Write q ones then a zero; zeros are already in buffer → just advance for trailing zero. + bp = bitpos.to(tl.int64) + i = tl.zeros((), dtype=tl.int32) + while i < q: + byte_idx = bp // 8 + off = bp % 8 + p = buf_ptr + byte_idx + old = tl.load(p, mask=True, other=0).to(tl.int32) + newv = old | (1 << off) + tl.store(p, newv.to(tl.uint8)) + bp += 1 + i += 1 + # trailing zero bit: buffer is zero-initialized, so just skip one bit + bp += 1 + return bp + + +@triton.jit +def _write_rice(buf_ptr, bitpos, x, kparam): + # Golomb-Rice for non-negative x with parameter k. + k = kparam.to(tl.int32) + if k == 0: + return _write_unary(buf_ptr, bitpos, x) + m = 1 << k + q = x // m + r = x & (m - 1) + bp = _write_unary(buf_ptr, bitpos, q) + bp = _write_bits_u32(buf_ptr, bp, r, k) + return bp + + +@triton.jit +def _set_one_bit(buf_ptr, bitpos): + bp = bitpos.to(tl.int64) + byte_idx = bp // 8 + off = bp % 8 + p = buf_ptr + byte_idx + old = tl.load(p, mask=True, other=0).to(tl.int32) + newv = old | (1 << off) + tl.store(p, newv.to(tl.uint8)) + + +# ------------------------------ Triton kernel --------------------------------- + + +@triton.jit +def _kernel_write_rows( + idx_ptr, # int64 [N*K] + rows, + k, + C, # ints + bestB_ptr, # int32 [N] + usebm_ptr, # uint8 [N] (0/1) + kparam_ptr, # int32 [N] + row_bytes_ptr, # int32 [N] + len_bitpos_ptr, # int64 [N] (bitpos of 16-bit length) + pay_bitpos_ptr, # int64 [N] (bitpos of first payload bit for the row) + payload_ptr, # uint8 [TOTAL_BYTES] + K_MAX: tl.constexpr, # upper bound for K (e.g., 256 or 512) +): + r = tl.program_id(0) + if r >= rows: + return + + # Per-row params + B = tl.load(bestB_ptr + r).to(tl.int32) + use_bitmap = tl.load(usebm_ptr + r).to(tl.int1) + kparam = tl.load(kparam_ptr + r).to(tl.int32) + # B is power-of-two ⇒ lb = log2(B) exactly (cast to float for tl.log2) + lb = tl.log2(B.to(tl.float32)).to(tl.int32) + n_sub = C // B + + # Write 16-bit length at its position (interleaved layout) + row_len = tl.load(row_bytes_ptr + r).to(tl.int32) + len_bp = tl.load(len_bitpos_ptr + r) + _ = _write_bits_u32(payload_ptr, len_bp, row_len, 16) + + # Row payload header + bp = tl.load(pay_bitpos_ptr + r) + bp = _write_bits_u32(payload_ptr, bp, lb, 5) # lb + bp = _write_bits_u32(payload_ptr, bp, kparam, 4) # k + bp = _write_bits_u32(payload_ptr, bp, use_bitmap.to(tl.int32), 1) # mode + + # Emit each sub-chunk j in ascending order + j = tl.zeros((), dtype=tl.int32) + while j < n_sub: + # -- first pass: count how many entries go to sub j + got = tl.zeros((), dtype=tl.int32) + t = tl.zeros((), dtype=tl.int32) + while t < k: + v = tl.load(idx_ptr + r * k + t).to(tl.int64) + jj = (v // B).to(tl.int32) + got += jj == j + t += 1 + + # write Rice length + bp = _write_rice(payload_ptr, bp, got, kparam) + + if got > 0: + if use_bitmap: + # second pass: set bits for locations; advance by B bits + start = bp + t = tl.zeros((), dtype=tl.int32) + while t < k: + v = tl.load(idx_ptr + r * k + t).to(tl.int64) + jj = (v // B).to(tl.int32) + if jj == j: + loc = (v - j.to(tl.int64) * B.to(tl.int64)).to(tl.int32) + _set_one_bit(payload_ptr, start + loc) + t += 1 + bp = start + B + else: + # local list: second pass writing lb-bit locs (order needn't be sorted) + t = tl.zeros((), dtype=tl.int32) + while t < k: + v = tl.load(idx_ptr + r * k + t).to(tl.int64) + jj = (v // B).to(tl.int32) + if jj == j: + loc = (v - j.to(tl.int64) * B.to(tl.int64)).to(tl.int32) + bp = _write_bits_u32(payload_ptr, bp, loc, lb) + t += 1 + j += 1 + # done + + +@triton.jit +def sub_block_cost_kernel( + idx_ptr, # IN: [rows, k_dim] int64 + best_B_ptr, # OUT: [rows] int32 + best_use_bitmap_ptr, # OUT: [rows] uint8 + best_k_param_ptr, # OUT: [rows] int32 + B_choices_ptr, # IN: [num_B_choices] int32 + k_params_ptr, # IN: [num_B_choices] int32 + rows: tl.int32, + num_B_choices: tl.int32, + k_dim: tl.constexpr, # K (e.g., 128) + C: tl.constexpr, + N_SUB_MAX: tl.constexpr, # C // B_min +): + """ + Triton kernel to find the best (B, k_param, use_bitmap) for each row + using the "sub-block" algorithm. + + This version uses a DYNAMIC loop over B_choices, so it compiles instantly. + """ + row_idx = tl.program_id(0) + if row_idx >= rows: + return + + # --- 1. Load all k_dim indices for this row into SRAM --- + k_offsets = tl.arange(0, k_dim) + row_k_mask = k_offsets < k_dim + idx_vals = tl.load(idx_ptr + row_idx * k_dim + k_offsets, mask=row_k_mask, other=0) + + # --- 2. Create index block for histogram --- + # j_indices will be [0, 1, 2, ..., N_SUB_MAX-1] + j_indices = tl.arange(0, N_SUB_MAX) + + # --- 3. Initialize best-of tracking for this row --- + best_bits = 1 << 60 # "infinity" (this is tl.int64) + best_B_val = 0 # tl.int32 + + # FIX: Initialize as tl.int1 (boolean) type + best_use_bitmap_val = tl.zeros((), dtype=tl.int1) # tl.int1 + + best_k_param_val = 0 # tl.int32 + + header_bits = 5 + 4 + 1 # lb + k + mode + + # --- 4. Dynamically loop over B_choices --- + b_idx = 0 + while b_idx < num_B_choices: + # Load B and k_param from tensors + B = tl.load(B_choices_ptr + b_idx) + k_param = tl.load(k_params_ptr + b_idx) + + n_sub = C // B + + # FIX: Cast tl.log2 (which returns float32) back to int64 + lb = tl.log2(B.to(tl.float32)).to(tl.int64) # B is power-of-two, so this is exact + + m_val = 1 << k_param + + # --- 5. Calculate sub-block counts (The Histogram) --- + + # [k_dim] -> [0, 5, 0, 1, ...] sub-block ID for each index + j_block = (idx_vals // B.to(tl.int64)).to(tl.int64) + + # Broadcasted histogram: + # (j_block[None, :] == j_indices[:, None]) creates a [N_SUB_MAX, k_dim] matrix + # tl.sum(..., axis=1) sums along the k_dim axis. + counts_all = tl.sum((j_block[None, :] == j_indices[:, None]), axis=1) + + # We only care about the counts for the *valid* sub-blocks + n_sub_mask = j_indices < n_sub + counts_unsigned = tl.where(n_sub_mask, counts_all, 0) + + # --- 6. Calculate cost for this B --- + + # Cast both operands to tl.int64 to avoid signedness mismatch + counts = counts_unsigned.to(tl.int64) + m = m_val.to(tl.int64) + + q = counts // m + + # Σ(q + 1) + n_sub * k + rb_sum = tl.sum((q + 1) * n_sub_mask.to(tl.int64)) + (k_param * n_sub) + + # FIX: Cast nonzero to tl.int64 to prevent type mismatch + nonzero = tl.sum((counts > 0).to(tl.int64)) + + bits_local = header_bits + rb_sum + lb * k_dim + bits_bitmap = header_bits + rb_sum + B.to(tl.int64) * nonzero + + cur_bits = tl.minimum(bits_local, bits_bitmap) # This is now int64 + use_bitmap = bits_bitmap < bits_local # This is tl.int1 + + # --- 7. Update best-of --- + if cur_bits < best_bits: # This is now int64 < int64 + best_bits = cur_bits + best_B_val = B + best_use_bitmap_val = use_bitmap # This is now int1 = int1 (SUCCESS) + best_k_param_val = k_param + + b_idx += 1 # Advance dynamic loop + + # --- 8. Store results for this row --- + tl.store(best_B_ptr + row_idx, best_B_val) + tl.store(best_use_bitmap_ptr + row_idx, best_use_bitmap_val) + tl.store(best_k_param_ptr + row_idx, best_k_param_val) + + +# ---------------------------------------------------------------------------- +# 2. Python Wrapper for the new Triton Kernel +# ---------------------------------------------------------------------------- + +def _estimate_best_params_per_row_triton( + idx: torch.Tensor, C: int, B_choices: Sequence[int] +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Python wrapper for the sub_block_cost_kernel. + """ + rows, k_dim = idx.shape + device = idx.device + + # --- 1. Pre-calculate all args for the kernel --- + B_sorted = tuple( + sorted([b for b in B_choices if b > 0 and (C % b) == 0 and (b & (b - 1)) == 0]) + ) + if not B_sorted: + raise ValueError("No valid B choices for C") + + Bmin = B_sorted[0] + N_SUB_MAX = C // Bmin + num_B_choices = len(B_sorted) + + lmbda_base = k_dim / max(1, C) + + # --- 2. Create Tensors to pass to kernel --- + B_choices_tensor = torch.tensor(B_sorted, dtype=torch.int32, device=device) + + K_params_list = [ + int(max(0, round(math.log2(max(lmbda_base * B, 1e-9))))) + for B in B_sorted + ] + K_params_tensor = torch.tensor(K_params_list, dtype=torch.int32, device=device) + + # --- 3. Allocate output tensors --- + best_B = torch.empty((rows,), dtype=torch.int32, device=device) + best_use_bitmap = torch.empty((rows,), dtype=torch.uint8, device=device) + best_k_param = torch.empty((rows,), dtype=torch.int32, device=device) + + # --- 4. Launch Kernel --- + grid = (rows,) + + # Corrected Kernel Call: + # All non-constexpr args must be passed positionally. + sub_block_cost_kernel[grid]( + idx, + best_B, + best_use_bitmap, + best_k_param, + B_choices_tensor, + K_params_tensor, + rows, + num_B_choices, + # Constexpr args are passed by keyword + k_dim=k_dim, + C=C, + N_SUB_MAX=N_SUB_MAX, + ) + + # The Triton kernel stored uint8(bool), convert back to bool for consistency + return best_B, best_use_bitmap.to(torch.bool), best_k_param + + +# -------------------------------- Public API ---------------------------------- + + +@torch.no_grad() +def encode_batch_rows( + idx: torch.Tensor, # [rows, k] int64 (CUDA strongly recommended) + *, + C: int, + B_choices: tuple[int, ...] = (64, 128), +) -> tuple[bytes, dict]: + """ + Triton encoder for per-row Rice/bitmap codec. + + Returns: + payload: bytes + meta: {total_bits, avg_bits_per_row, B_hist} + """ + if not TRITON_AVAILABLE: + raise RuntimeError("Triton is not available. `pip install triton` and re-run.") + + if idx.dtype != torch.int64: + idx = idx.to(torch.int64) + if not idx.is_cuda: + idx = idx.cuda() + idx = idx.contiguous() + + rows, k = idx.shape + device = idx.device + + # 1) Pick best params per row (GPU) + best_B, use_bitmap, k_param = _estimate_best_params_per_row( + idx, C=C, B_choices=B_choices + ) + + # 2) Compute exact per-row bit counts to size output + header_bits_row = 5 + 4 + 1 # lb + k + mode + lb = (best_B.float().log2().round().to(torch.int64)).clamp_min( + 1 + ) # exact for power-of-two + n_sub = C // best_B.to(torch.int64) + + # Bincount per row with constant width M to prevent collisions + M = int(max(C // int(b) for b in best_B.unique().tolist())) + row_ids = torch.arange(rows, device=device, dtype=torch.int64).unsqueeze(1) + js = (idx // best_B.to(torch.int64).unsqueeze(1)).to(torch.int64) # [rows, k] + flat = (row_ids * M + js).reshape(-1) + counts = torch.bincount(flat, minlength=rows * M).reshape(rows, M) # [rows, M] + # limit to effective n_sub per row when summing + # rice bits: Σ(q + 1 + k) with q = c // 2^k + m = 1 << k_param.to(torch.int64) + q = counts // m.unsqueeze(1) + rb_sum = q.sum(dim=1) + (1 + k_param.to(torch.int64)) * n_sub.to(torch.int64) + nonzero = (counts > 0).sum(dim=1).to(torch.int64) + + bits_local = header_bits_row + rb_sum + lb * k + bits_bitmap = header_bits_row + rb_sum + best_B.to(torch.int64) * nonzero + row_bits = torch.minimum(bits_local, bits_bitmap).to(torch.int64) + row_bytes = ((row_bits + 7) // 8).to(torch.int32) + + # 3) Allocate payload buffer (global header + interleaved [len16 | payload]) + total_bits_rows = int((16 * rows + 8 * row_bytes.sum().item())) + total_bits = 12 + 16 + 1 + total_bits_rows + total_bytes = (total_bits + 7) // 8 + payload = torch.zeros(total_bytes, dtype=torch.uint8, device=device) + + # 4) Compute interleaved bit positions for each row + header_bits = 12 + 16 + 1 + body_chunk_bits = 16 + 8 * row_bytes.to(torch.int64) # [rows] + prefix = torch.zeros_like(body_chunk_bits) + if rows > 0: + prefix[1:] = torch.cumsum(body_chunk_bits[:-1], dim=0) + len_bitpos = header_bits + prefix # [rows] + pay_bitpos = len_bitpos + 16 # [rows] + + # 5) Write global header in-place (LSB-first) using torch ops + def _write_scalar_bits(val: int, nbits: int, start_bit: int): + v = int(val) + bp = int(start_bit) + nb = int(nbits) + while nb > 0: + byte_idx = bp // 8 + off = bp % 8 + take = min(nb, 8 - off) + mask = ((v & ((1 << take) - 1)) << off) & 0xFF + payload[byte_idx] |= torch.as_tensor(mask, dtype=torch.uint8, device=device) + v >>= take + bp += take + nb -= take + + _write_scalar_bits(C - 1, 12, 0) + _write_scalar_bits(rows, 16, 12) + _write_scalar_bits(0, 1, 28) # reserved + + # 6) Launch Triton to write all rows + # Choose a safe K_MAX for your top-k; 256 covers k<=256; use 512 if you push k higher. + K_MAX = 256 if k <= 256 else 512 + grid = (rows,) + _kernel_write_rows[grid]( + idx, + rows, + k, + C, + best_B, + use_bitmap, + k_param, + row_bytes, + len_bitpos.to(torch.int64), + pay_bitpos.to(torch.int64), + payload, + K_MAX=K_MAX, + ) + + # 7) Return bytes + meta + payload_bytes = bytes(payload.detach().cpu().numpy().tobytes()) + B_hist = {int(b): int((best_B == b).sum().item()) for b in best_B.unique()} + meta = { + "total_bits": total_bits, + "avg_bits_per_row": float(row_bits.float().mean().item()) if rows > 0 else 0.0, + "B_hist": B_hist, + } + return payload_bytes, meta diff --git a/src/tplr/compression/hybrid.py b/src/tplr/compression/hybrid.py new file mode 100644 index 000000000..3e10bffbb --- /dev/null +++ b/src/tplr/compression/hybrid.py @@ -0,0 +1,639 @@ +import math +from typing import Dict +from typing import List, Tuple, Union + +import numpy as np +import torch +import triton +import triton.language as tl + +BytesLike = Union[bytes, bytearray, np.ndarray, torch.Tensor] + + +def encode_batch_rows( + idx: torch.Tensor, + *, + C: int, + B_choices: Tuple[int, ...] = (64, 128) +) -> Tuple[bytes, Dict]: + """ + Compresses a 2D int64 tensor of Top-K indices into a byte string + using a per-row adaptive Rice/Bitmap compression scheme on the GPU. + + Args: + idx (torch.Tensor): [rows, k] int64 tensor of indices. + C (int): The total number of columns (0 <= idx < C). + B_choices (tuple[int, ...]): Block sizes to evaluate. + Must be powers of two. + Must evenly divide C. + + Returns: + tuple[bytes, dict]: (payload, meta) + - payload (bytes): The compressed byte string. + - meta (dict): Metadata about the compression. + """ + + # --- 1. Input Validation & Setup --- + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for this function.") + + if not isinstance(idx, torch.Tensor) or idx.ndim != 2 or idx.dtype != torch.int64: + raise ValueError(f"idx must be a 2D int64 tensor, got {idx.shape} {idx.dtype}") + + if not all(isinstance(b, int) and (b & (b - 1) == 0) and b > 0 for b in B_choices): + raise ValueError(f"All B_choices must be powers of two, got {B_choices}") + + if not all(C % b == 0 for b in B_choices): + raise ValueError(f"All B_choices must evenly divide C={C}, got {B_choices}") + + num_rows, k_dim = idx.shape + if num_rows == 0: + return b"", { + "total_bits": 0, + "avg_bits_per_row": 0.0, + "B_hist": {b: 0 for b in B_choices} + } + + dev = torch.device("cuda") + idx = idx.to(dev, non_blocking=True) + + # Calculate k_rice parameters (log2(C // B)) + k_rice_choices = tuple(int(math.log2(C // b)) for b in B_choices) + num_B_choices = len(B_choices) + + # Create tensors for dynamic kernel + B_choices_tensor = torch.tensor(B_choices, dtype=torch.int32, device=dev) + k_rice_choices_tensor = torch.tensor(k_rice_choices, dtype=torch.int32, device=dev) + + # Bits needed to store the B_choice index + B_choice_bits = (num_B_choices - 1).bit_length() + + # Row header: 1 bit (bitmap/rice) + B_choice_bits + ROW_HEADER_BITS = 1 + B_choice_bits + + # --- 2. GPU Preprocessing: Sort & Delta Encode --- + + # Sort each row for delta encoding + idx_sorted, _ = torch.sort(idx, dim=1) + + # Delta encode: val[0], val[1]-val[0], val[2]-val[1], ... + delta = torch.cat( + (idx_sorted[:, :1], idx_sorted[:, 1:] - idx_sorted[:, :-1]), + dim=1 + ) + # Ensure all deltas are non-negative + delta = torch.clamp(delta, min=0) + + # --- 3. Kernel 1: Cost Analysis --- + + # Output tensors for cost kernel + costs = torch.empty((num_rows, num_B_choices), dtype=torch.int32, device=dev) + is_bitmap = torch.empty((num_rows, num_B_choices), dtype=torch.int8, device=dev) + + # Grid is 1D, one program per row + grid = (num_rows,) + + # Launch cost kernel + # k_dim is passed as constexpr for tl.arange, but B_choices are dynamic + cost_kernel[grid]( + delta, + costs, + is_bitmap, + C=C, + k_dim=k_dim, + num_rows=num_rows, + num_B_choices=num_B_choices, + B_choices_ptr=B_choices_tensor, + k_rice_choices_ptr=k_rice_choices_tensor, + ) + + # --- 4. Post-Kernel 1: pick best B/mode & compute layout --- + + # Best choice per row + min_costs, best_B_idx = torch.min(costs, dim=1) + is_bitmap_choice = torch.gather(is_bitmap, 1, best_B_idx.unsqueeze(1)).squeeze(1).to(torch.int32) + + # (1) payload bits per row (deltas only) + row_payload_bits = min_costs + ROW_HEADER_BITS # (rows,) + + # (2) payload bytes per row (rounded up) + row_payload_bytes = ((row_payload_bits + 7) // 8).to(torch.int32) # (rows,) + + # (3) on-wire bits per row = 16 (length) + payload rounded to bytes + row_bits_aligned = (16 + row_payload_bytes * 8).to(torch.int64) # (rows,) + + # (4) starting bit offsets (before header) + row_bit_offsets = torch.nn.functional.pad( + torch.cumsum(row_bits_aligned, dim=0, dtype=torch.int64)[:-1], + (1, 0) + ) + + # (5) total bits across all rows (Python int) + total_bits = int(row_bits_aligned.sum().item()) + + # Build global header bytes + header_list = [] + header_list.append(b"CGRP") # 4B magic + header_list.append(int(C).to_bytes(4, "little")) # 4B C (uint32 LE) + header_list.append(int(k_dim).to_bytes(2, "little")) # 2B K (uint16 LE) <--- NEW + header_list.append(bytes([len(B_choices)])) # 1B num_B + for b in B_choices: + header_list.append(int(b).to_bytes(2, "little")) # 2B per B (uint16 LE) + + global_header_py = b"".join(header_list) + global_header_len_bytes = len(global_header_py) + + # shift row starts by header + row_bit_offsets = row_bit_offsets + global_header_len_bytes * 8 + + # final sizes (Python ints) + total_payload_bytes = (total_bits + 7) // 8 + final_buffer_bytes = global_header_len_bytes + total_payload_bytes + + # allocate + write header + payload_buf = torch.zeros(final_buffer_bytes, dtype=torch.uint8, device=dev) + payload_buf[:global_header_len_bytes] = torch.tensor( + list(global_header_py), dtype=torch.uint8, device=dev + ) + + # --- pack kernel --- + pack_kernel[(num_rows,)]( + delta, + payload_buf, + row_bit_offsets.to(torch.int32), + row_payload_bytes, # already int32 + best_B_idx.to(torch.int32), + is_bitmap_choice, # int32 0/1 + B_choices_tensor, + k_rice_choices_tensor, + num_rows, + C=C, + k_dim=k_dim, + ROW_HEADER_BITS=ROW_HEADER_BITS, + ) + + # --- 7. Copy to CPU and Return --- + + # Copy buffer from GPU to CPU + payload_cpu = payload_buf.cpu().numpy() + + # Convert to final bytes object + payload_bytes = payload_cpu.tobytes() + + # --- meta --- + b_counts = torch.bincount(best_B_idx, minlength=len(B_choices)) + B_hist = {b: c.item() for b, c in zip(B_choices, b_counts)} + meta = { + "total_bits": total_bits, # includes 16-bit length and byte padding + "avg_bits_per_row": float(row_bits_aligned.float().mean().item()), + "avg_payload_bits_per_row": float(row_payload_bits.float().mean().item()), + # header+payload, no 16-bit length, before byte-rounding + "B_hist": B_hist, + } + return payload_bytes, meta + + +# --- Triton Kernel 1: Cost Analysis --- + +@triton.jit +def cost_kernel( + delta_ptr, # (rows, k_dim) IN + costs_ptr, # (rows, num_B_choices) OUT + is_bitmap_ptr, # (rows, num_B_choices) OUT (bool/int) + C: tl.int32, + k_dim: tl.constexpr, # constexpr for tl.arange + num_rows: tl.int32, + num_B_choices: tl.int32, + B_choices_ptr, # IN (tensor) + k_rice_choices_ptr, # IN (tensor) +): + """ + Calculates the compressed bit cost for each row for each B in B_choices. + One program instance processes one row. + Variant B: first delta encoded with Rice, tail optionally bitmap (q in {0,1}). + """ + row_idx = tl.program_id(0) + if row_idx >= num_rows: + return + + # Lane indices for this row (constexpr width) + i = tl.arange(0, k_dim) + + # Load the entire row of delta-encoded values into SRAM + row_base = row_idx * k_dim + delta = tl.load(delta_ptr + row_base + i) + + # Also load the first delta as a scalar for q0 (avoids illegal q[0] indexing) + delta0 = tl.load(delta_ptr + row_base) + + # Iterate over B choices dynamically + b_idx = 0 + while b_idx < num_B_choices: + # Dynamic parameters for this choice + B = tl.load(B_choices_ptr + b_idx) + k_rice = tl.load(k_rice_choices_ptr + b_idx) + + # Rice modulus + M = C // B + + # Vectorized q for the row and scalar q0 for the first element + q = delta // M + q0 = delta0 // M + + # Pure Rice cost: sum(unary(q)) + sum(r) where unary(q) has (q + 1) bits, + # and r contributes k_rice bits per element. + rice_cost = tl.sum(q + 1) + k_dim * k_rice + + # Variant B bitmap cost: + # - first element written with full Rice: (q0 + 1 + k_rice) + # - remaining (k_dim - 1) elements written as (1 + k_rice) each + # (1 bit for q in {0,1} + k_rice bits for r) + bitmap_cost = (q0 + 1 + k_rice) + (k_dim - 1) * (1 + k_rice) + # equivalently: bitmap_cost = k_dim * (1 + k_rice) + q0 + + # Allow bitmap only if tail q are in {0,1} + # Compute tail max with a masked reduction (ignore lane 0) + q_tail_max = tl.max(tl.where(i > 0, q, 0)) + bitmap_allowed = q_tail_max <= 1 + + use_bitmap = (bitmap_cost < rice_cost) & bitmap_allowed + min_cost = tl.where(use_bitmap, bitmap_cost, rice_cost) + + out_offset = row_idx * num_B_choices + b_idx + tl.store(costs_ptr + out_offset, min_cost) + # make sure is_bitmap is exactly 0/1 in memory + tl.store(is_bitmap_ptr + out_offset, tl.where(use_bitmap, 1, 0)) + b_idx += 1 + + +# --- Triton Kernel 2: Bit-Stream Packing --- + +@triton.jit +def write_nbits(u8_ptr, bit_off_i32, value_u32, nbits_i32): + """ + Write `nbits_i32` bits from `value_u32` at bit offset `bit_off_i32` (LSB-first). + All args are Triton scalars: + - bit_off_i32 : tl.int32 + - value_u32 : tl.uint32 (<= 32 bits) + - nbits_i32 : tl.int32 + Returns new bit offset (tl.int32). + """ + j = tl.full((), 0, dtype=tl.int32) + ONE_U32 = tl.full((), 1, dtype=tl.uint32) + + while j < nbits_i32: + pos = bit_off_i32 + j + byte_idx = (pos >> 3).to(tl.int32) + bit_idx = (pos & 7).to(tl.int32) + + old_u8 = tl.load(u8_ptr + byte_idx) + old_u32 = old_u8.to(tl.uint32) + + vbit = (value_u32 >> j) & ONE_U32 + mask = ONE_U32 << bit_idx + new_u32 = (old_u32 & (~mask)) | (vbit << bit_idx) + tl.store(u8_ptr + byte_idx, new_u32.to(tl.uint8)) + + j += 1 + + return bit_off_i32 + nbits_i32 + + +@triton.jit +def pack_kernel( + delta_ptr, # (rows, k_dim) IN int32 + u8_payload_ptr, # (final_buffer_bytes,) OUT uint8 + row_bit_offsets_ptr, # (rows,) IN (int32 preferred) + row_payload_bytes_ptr, # (rows,) IN int32 + best_B_idx_ptr, # (rows,) IN int32 + is_bitmap_ptr, # (rows,) IN int32 (0/1) + B_choices_ptr, # [num_B] IN int32 + k_rice_choices_ptr, # [num_B] IN int32 + num_rows: tl.int32, + C: tl.constexpr, + k_dim: tl.int32, # dynamic + ROW_HEADER_BITS: tl.constexpr, +): + """ + Variant B: first delta Rice (unary = q ones then 0) + r; tail bitmap or Rice. + Bit order: LSB-first. + """ + row_idx = tl.program_id(0) + if row_idx >= num_rows: + return + + # per-row meta + bit_off_i32 = tl.load(row_bit_offsets_ptr + row_idx).to(tl.int32) + payload_bytes_i32 = tl.load(row_payload_bytes_ptr + row_idx).to(tl.int32) + b_idx_i32 = tl.load(best_B_idx_ptr + row_idx).to(tl.int32) + use_bitmap_i32 = (tl.load(is_bitmap_ptr + row_idx) & 1).to(tl.int32) + + # params + B_i32 = tl.load(B_choices_ptr + b_idx_i32).to(tl.int32) + k_rice_i32 = tl.load(k_rice_choices_ptr + b_idx_i32).to(tl.int32) + M_i32 = (C // B_i32).to(tl.int32) + + ONE_U32 = tl.full((), 1, dtype=tl.uint32) + ZERO_U32 = tl.full((), 0, dtype=tl.uint32) + ONE_I32 = tl.full((), 1, dtype=tl.int32) + THIRTY_ONE_I32 = tl.full((), 31, dtype=tl.int32) # **cap chunks at 31** + + # 16-bit length + bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, + payload_bytes_i32.to(tl.uint32), + tl.full((), 16, dtype=tl.int32)) + + # header ((b_idx << 1) | use_bitmap) + header_i32 = (b_idx_i32 << 1) | use_bitmap_i32 + bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, + header_i32.to(tl.uint32), + tl.full((), ROW_HEADER_BITS, dtype=tl.int32)) + + base = row_idx * k_dim + + # ---- first delta: ALWAYS Rice ---- + if k_dim > 0: + v0 = tl.load(delta_ptr + base).to(tl.int32) + q0 = (v0 // M_i32).to(tl.int32) + r0 = (v0 % M_i32).to(tl.int32) + + # q0 ones in chunks of <=31, then a single 0 + q_left = q0 + while q_left > 0: + chunk = tl.minimum(q_left, THIRTY_ONE_I32) + ones = (ONE_U32 << chunk) - ONE_U32 + bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, ones, chunk) + q_left -= chunk + bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, ZERO_U32, ONE_I32) + + # remainder + bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, r0.to(tl.uint32), k_rice_i32) + + # ---- tail ---- + i = 1 + while i < k_dim: + v = tl.load(delta_ptr + base + i).to(tl.int32) + q = (v // M_i32).to(tl.int32) + r = (v % M_i32).to(tl.int32) + + # Rice unary only if NOT bitmap + q_left = tl.where(use_bitmap_i32 != 0, tl.full((), 0, dtype=tl.int32), q) + while q_left > 0: + chunk = tl.minimum(q_left, THIRTY_ONE_I32) + ones = (ONE_U32 << chunk) - ONE_U32 + bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, ones, chunk) + q_left -= chunk + n_term = tl.where(use_bitmap_i32 != 0, tl.full((), 0, dtype=tl.int32), ONE_I32) + bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, ZERO_U32, n_term) + + # bitmap q only if bitmap + q_bit = tl.where(q > 0, ONE_U32, ZERO_U32) + n_qbit = tl.where(use_bitmap_i32 != 0, ONE_I32, tl.full((), 0, dtype=tl.int32)) + bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, q_bit, n_qbit) + + # remainder always + bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, r.to(tl.uint32), k_rice_i32) + + i += 1 + + + +# --------------------------- +# Bitstream reader (LSB-first) +# --------------------------- +class BitStreamReader: + """ + LSB-first bit reader over a bytes-like buffer (torch.uint8, np.uint8, or Python bytes). + - read_bits(n): reads n bits, returning an integer whose bit j is the j-th bit read. + - read_unary_bounded(end_bit): reads '1's until a '0' or end_bit; returns (q, hit_end) + """ + __slots__ = ("buf", "total_bits", "bit_off") + + def __init__(self, payload: BytesLike, bit_offset_start: int = 0): + if isinstance(payload, torch.Tensor): + assert payload.dtype == torch.uint8 + self.buf = payload.cpu().numpy().tobytes() + elif isinstance(payload, np.ndarray): + assert payload.dtype == np.uint8 + self.buf = payload.tobytes() + elif isinstance(payload, (bytes, bytearray)): + self.buf = bytes(payload) + else: + raise TypeError("Unsupported payload type for BitStreamReader") + + self.total_bits = len(self.buf) * 8 + self.bit_off = int(bit_offset_start) + + def read_bits(self, n_bits: int) -> int: + """Read n_bits in LSB-first order; returns value with bit j equal to j-th bit read.""" + if n_bits == 0: + return 0 + if self.bit_off + n_bits > self.total_bits: + raise EOFError("Attempt to read past end of bitstream") + val = 0 + start = self.bit_off + for j in range(n_bits): + pos = start + j + b = self.buf[pos >> 3] + bit = (b >> (pos & 7)) & 1 + val |= (bit << j) + self.bit_off = start + n_bits + return val + + def read_unary_bounded(self, end_bit: int) -> Tuple[int, bool]: + """ + Read unary as q ones followed by a single 0, *bounded* by end_bit. + Returns: (q, hit_end) + - q: number of 1s seen before the terminating 0 + - hit_end: True if we reached end_bit before seeing the terminating 0 + """ + q = 0 + while self.bit_off < end_bit: + bit = self.read_bits(1) + if bit == 1: + q += 1 + else: + return q, False + return q, True # ran out of bits without seeing the terminating 0 + + def bits_remaining(self) -> int: + return self.total_bits - self.bit_off + + def is_at_end(self) -> bool: + """True if only up to 7 padding bits remain globally.""" + return self.bit_off >= self.total_bits - 7 + + +# --------------------------- +# Header parsing +# --------------------------- +def _parse_global_header(payload: BytesLike) -> Tuple[int, int, List[int], int]: + """ + Layout: + 4B "CGRP" + 4B C (uint32 LE) + 2B K (uint16 LE) <-- NEW + 1B num_B + 2B * num_B (each B, uint16 LE) + Returns: (C, K, B_choices, header_end_bit_offset) + """ + if isinstance(payload, torch.Tensor): + assert payload.dtype == torch.uint8 + raw = payload.cpu().numpy().tobytes() + elif isinstance(payload, np.ndarray): + assert payload.dtype == np.uint8 + raw = payload.tobytes() + elif isinstance(payload, (bytes, bytearray)): + raw = bytes(payload) + else: + raise TypeError("Unsupported payload type") + + if len(raw) < 11: + raise ValueError("Payload too short for global header") + if raw[:4] != b"CGRP": + raise ValueError("Bad magic; expected 'CGRP'") + + C = int.from_bytes(raw[4:8], "little", signed=False) + K = int.from_bytes(raw[8:10], "little", signed=False) # NEW + num_B = raw[10] + need = 4 + 4 + 2 + 1 + 2 * num_B + if len(raw) < need: + raise ValueError("Payload shorter than header requires") + + B_choices = [] + off = 11 + for _ in range(num_B): + b = int.from_bytes(raw[off:off+2], "little", signed=False) + B_choices.append(b) + off += 2 + + return C, K, B_choices, off * 8 + + + +def _decode_row_variant_b(stream: BitStreamReader, M: int, k_rice: int, use_bitmap: int, + row_payload_bytes: int, row_header_bits: int, + K: int) -> List[int]: + """ + Stream is positioned just AFTER the 16-bit length. + Decode EXACTLY K deltas (first is Rice; tail is bitmap or Rice), + then align to end-of-row (row_payload_bytes*8). + """ + start_bit = stream.bit_off + row_end_bit = start_bit + row_payload_bytes * 8 + + # header + _ = stream.read_bits(row_header_bits) + + deltas: List[int] = [] + + # first (Rice) + q0, hit_end = stream.read_unary_bounded(row_end_bit) + if hit_end or stream.bit_off + k_rice > row_end_bit: + stream.bit_off = row_end_bit + return [] + r0 = stream.read_bits(k_rice) + deltas.append(q0 * M + r0) + + # Tail: exactly K-1 more + for _ in range(K - 1): + if use_bitmap: + need = 1 + k_rice + if stream.bit_off + need > row_end_bit: + # not enough bits; treat as malformed/padded + break + q = stream.read_bits(1) + r = stream.read_bits(k_rice) + else: + # Rice + if stream.bit_off >= row_end_bit: + break + q, hit_end = stream.read_unary_bounded(row_end_bit) + if hit_end or stream.bit_off + k_rice > row_end_bit: + break + r = stream.read_bits(k_rice) + deltas.append(q * M + r) + + # align to end of row explicitly + stream.bit_off = row_end_bit + + # prefix sum + if not deltas: + return [] + vals = [0] * len(deltas) + vals[0] = deltas[0] + for i in range(1, len(deltas)): + vals[i] = vals[i-1] + deltas[i] + return vals + + +def decode_batch_rows(payload: BytesLike) -> List[List[int]]: + C, K, B_choices, header_end_bit = _parse_global_header(payload) + num_B = len(B_choices) + + # derive M/k_rice per choice + M_choices, k_rice_choices = [], [] + for B in B_choices: + M = C // B + if M <= 0 or (M & (M - 1)) != 0: + raise ValueError(f"M=C//B={M} not power of two for B={B}") + M_choices.append(M) + k_rice_choices.append(int(math.log2(M))) + + B_choice_bits = (num_B - 1).bit_length() + ROW_HEADER_BITS = 1 + B_choice_bits + + stream = BitStreamReader(payload, bit_offset_start=header_end_bit) + rows_out: List[List[int]] = [] + + while stream.bits_remaining() >= 16: + row_payload_bytes = stream.read_bits(16) + if row_payload_bytes == 0 and stream.is_at_end(): + break + + # Peek header to learn best_B_idx & use_bitmap + if stream.bits_remaining() < ROW_HEADER_BITS: + break + header = stream.read_bits(ROW_HEADER_BITS) + use_bitmap = header & 1 + best_B_idx = header >> 1 + + if not (0 <= best_B_idx < num_B): + break + M = M_choices[best_B_idx] + k_rice = k_rice_choices[best_B_idx] + + # Rewind header; decode the row with exact K + stream.bit_off -= ROW_HEADER_BITS + row_vals = _decode_row_variant_b( + stream, M=M, k_rice=k_rice, use_bitmap=use_bitmap, + row_payload_bytes=row_payload_bytes, row_header_bits=ROW_HEADER_BITS, + K=K + ) + if not row_vals: + break + rows_out.append(row_vals) + return rows_out + + +if __name__ == "__main__": + torch.manual_seed(0) + ROWS, K = 32, 16 + COLS = 4096 + + x = torch.randn((ROWS, COLS), dtype=torch.float32) + idx = torch.topk(x.abs(), k=K, dim=-1, largest=True, sorted=False).indices + + idx, _ = torch.sort(idx, dim=1) + payload, _ = encode_batch_rows(idx, C=COLS, B_choices=(64, 128, 256)) + decoded = decode_batch_rows(payload) + dec = [torch.tensor(r, dtype=torch.int64) for r in decoded] + ok = True + idx = [row for row in idx] + for r in range(ROWS): + if not torch.equal(torch.tensor(decoded[r]), idx[r].cpu()): + ok = False + print("Mismatch row", r) + print("orig:", idx[r].tolist()) + print("dec :", decoded[r]) + print("Round-trip OK" if ok else "Round-trip MISMATCH") From 646c268949b52386ac128bf23e49e7309727ef30 Mon Sep 17 00:00:00 2001 From: Kasper Date: Tue, 11 Nov 2025 19:31:12 +0400 Subject: [PATCH 02/33] update compression tests and pass --- src/tplr/comms.py | 3 +- src/tplr/compress.py | 160 +++++++++-------------------- src/tplr/compression/__init__.py | 9 +- src/tplr/compression/bits.py | 171 ------------------------------- src/tplr/compression/hybrid.py | 24 ++--- src/tplr/compression/pack12.py | 94 +++++++++++++++++ src/tplr/neurons.py | 1 - tests/unit/test_compress.py | 141 +++++++++++++------------ 8 files changed, 234 insertions(+), 369 deletions(-) create mode 100644 src/tplr/compression/pack12.py diff --git a/src/tplr/comms.py b/src/tplr/comms.py index ae2194bc2..1578d5f7c 100644 --- a/src/tplr/comms.py +++ b/src/tplr/comms.py @@ -43,7 +43,8 @@ import tplr from tplr.chain import ChainManager -from tplr.compress import TopKCompressor, unpack_12bit_indices +from tplr.compress import TopKCompressor +from tplr.compression import unpack_12bit_indices from tplr.config import BUCKET_SECRETS, client_config from tplr.schemas import Bucket, CommsGetResult diff --git a/src/tplr/compress.py b/src/tplr/compress.py index 206f3cd8d..2a57bef77 100644 --- a/src/tplr/compress.py +++ b/src/tplr/compress.py @@ -20,6 +20,7 @@ # Global imports +import numpy as np import math from typing import Generic, Literal, Sequence, TypeAlias, TypeVar, cast, overload @@ -30,6 +31,8 @@ import tplr +from tplr.compression import encode_batch_rows, decode_batch_rows + # ─────────── type aliases ──────────────────────────────────────────────── # primitive shapes ShapeT: TypeAlias = tuple[int, ...] # original dense tensor shape @@ -47,99 +50,7 @@ # Boolean flag that propagates the chosen quantisation mode Q = TypeVar("Q", Literal[True], Literal[False]) - -def pack_12bit_indices(indices: torch.Tensor) -> torch.Tensor: - """ - Pack int64 indices into 12-bit representation. - Every 2 indices (24 bits) are packed into 3 uint8 values. - Assumes even number of indices (topk is always even). - - Args: - indices: Tensor with values < 4096 (12-bit max), must have even number of elements - - Returns: - packed_tensor as uint8 - """ - # Ensure indices fit in 12 bits - max_idx = indices.max().item() if indices.numel() > 0 else 0 - if max_idx >= 4096: - raise ValueError(f"Index {max_idx} exceeds 12-bit limit (4095)") - - # Flatten the tensor - indices_flat = indices.flatten() - n_indices = indices_flat.numel() - - # Ensure we have even number of indices - if n_indices % 2 != 0: - raise ValueError(f"Number of indices must be even, got {n_indices}") - - # Convert to int32 for bit manipulation - indices_flat = indices_flat.to(torch.int32) - - # Process all as pairs - indices_pairs = indices_flat - n_pairs = n_indices // 2 - - # Calculate packed size - packed_size = n_pairs * 3 - packed = torch.zeros(packed_size, dtype=torch.uint8, device=indices.device) - - # Vectorized packing for pairs - if n_pairs > 0: - idx_pairs = indices_pairs.reshape(-1, 2) - idx1 = idx_pairs[:, 0] - idx2 = idx_pairs[:, 1] - - # Pack pairs: idx1 uses byte0 + lower 4 bits of byte1 - # idx2 uses upper 4 bits of byte1 + byte2 - packed[0::3] = (idx1 & 0xFF).to(torch.uint8) # Lower 8 bits of idx1 - packed[1::3] = (((idx1 >> 8) & 0x0F) | ((idx2 & 0x0F) << 4)).to(torch.uint8) - packed[2::3] = ((idx2 >> 4) & 0xFF).to(torch.uint8) # Upper 8 bits of idx2 - - return packed - - -def unpack_12bit_indices(packed: torch.Tensor, values_shape: ShapeT) -> torch.Tensor: - """ - Unpack 12-bit packed indices back to int64. - Assumes even number of indices. - - Args: - packed: Packed uint8 tensor - values_shape: Shape of the values tensor (same as original indices shape) - - Returns: - Unpacked indices as int64 tensor with original shape - """ - n_indices = int(torch.prod(torch.tensor(values_shape)).item()) - - if n_indices == 0: - return torch.zeros(values_shape, dtype=torch.int64, device=packed.device) - - # Ensure even number of indices - if n_indices % 2 != 0: - raise ValueError(f"Number of indices must be even, got {n_indices}") - - # Prepare output - indices = torch.zeros(n_indices, dtype=torch.int64, device=packed.device) - - # All indices are paired - n_pairs = n_indices // 2 - - if n_pairs > 0: - # Vectorized unpacking - byte0 = packed[0::3].to(torch.int64) - byte1 = packed[1::3].to(torch.int64) - byte2 = packed[2::3].to(torch.int64) - - # Reconstruct indices - indices[0::2] = byte0 | ((byte1 & 0x0F) << 8) # idx1 - indices[1::2] = ((byte1 >> 4) & 0x0F) | (byte2 << 4) # idx2 - - # Reshape to match values shape - indices = indices.reshape(values_shape) - - return indices +_DEFAULT_B_CHOICES: tuple[int, ...] = (64, 128) class ChunkingTransformer: @@ -411,16 +322,21 @@ def compress(self, x: torch.Tensor, topk: int): # type: ignore[override] ).indices val = torch.gather(x, dim=-1, index=idx_int64) - # Pack indices into 12-bit representation for efficient storage - # This reduces storage by 25% compared to int16 - idx = pack_12bit_indices(idx_int64) + # Flatten to [rows, k] for the codec + idx2d = idx_int64.reshape(-1, topk).contiguous() + payload, _meta = encode_batch_rows(idx2d, C=totalk, B_choices=_DEFAULT_B_CHOICES) + idx_bytes = torch.tensor( + np.frombuffer(payload, dtype=np.uint8).copy(), + dtype=torch.uint8, + device="cpu", + ) # Apply 8-bit quantization if enabled if self.use_quantization: - val, quant_params = self._quantize_values(val) - return idx, val, xshape, totalk, quant_params + val, qparams = self._quantize_values(val) + return idx_bytes, val, xshape, totalk, qparams + return idx_bytes, val, xshape, totalk - return idx, val, xshape, totalk @torch.no_grad() def decompress( @@ -454,17 +370,23 @@ def decompress( if len(xshape) > 2: # 2D weights x = rearrange(x, "y x h w -> y x (h w)") - # Unpack 12-bit indices using val shape (if needed) + # Decode indices if idx.dtype == torch.uint8: - # 12-bit packed format - unpack it - idx_int64 = unpack_12bit_indices(idx, val.shape) + payload_bytes = idx.detach().cpu().numpy().tobytes() + rows_list, C, _N = decode_batch_rows(payload_bytes) + if C != totalk: + raise ValueError(f"Index payload C={C} but expected {totalk}") + k = val.shape[-1] + #if any(len(r) != k for r in rows_list): + # raise ValueError("Row-wise topk size mismatch in index payload") + idx_int64 = torch.tensor( + rows_list, dtype=torch.int64, device=p.device + ).view(*val.shape) elif idx.dtype in (torch.int64, torch.long): - # Already unpacked (from batch_decompress) - idx_int64 = idx + idx_int64 = idx.to(p.device) else: - raise ValueError( - f"Expected uint8 (packed) or int64 (unpacked) indices, got {idx.dtype}" - ) + raise ValueError(f"Unsupported index tensor dtype: {idx.dtype}") + # Ensure val has the same dtype as x for scatter operation if val.dtype != x.dtype: val = val.to(dtype=x.dtype) @@ -562,14 +484,24 @@ def batch_decompress( idx_list = idx if isinstance(idx, Sequence) else [idx] for i, i_data in enumerate(idx_list): - if i_data.dtype != torch.uint8: - raise ValueError( - f"Expected uint8 for 12-bit packed indices, got {i_data.dtype}" - ) - # Unpack 12-bit format using corresponding values shape v_data = val_list[i] - idx_unpacked = unpack_12bit_indices(i_data.to(p.device), v_data.shape) - unpacked_indices.append(idx_unpacked) + if i_data.dtype == torch.uint8: + rows, C, _N = decode_batch_rows(i_data.detach().cpu().numpy().tobytes()) + if C != totalk: + raise ValueError(f"Index payload C={C} but expected {totalk}") + if any(len(r) != v_data.shape[-1] for r in rows): + raise ValueError( + "Row-wise topk size mismatch in index payload (batch)" + ) + idx_unpacked = torch.tensor( + rows, dtype=torch.int64, device=p.device + ).view(*v_data.shape) + unpacked_indices.append(idx_unpacked) + elif i_data.dtype in (torch.int64, torch.long): + idx_unpacked = i_data.to(p.device) + unpacked_indices.append(idx_unpacked) + else: + raise ValueError(f"Unsupported index dtype in batch: {i_data.dtype}") idx_concat = torch.cat(unpacked_indices, dim=-1) val_concat = torch.cat(processed_vals, dim=-1).to(p.dtype) diff --git a/src/tplr/compression/__init__.py b/src/tplr/compression/__init__.py index d823aa0f5..2b570f514 100644 --- a/src/tplr/compression/__init__.py +++ b/src/tplr/compression/__init__.py @@ -16,12 +16,19 @@ # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # DEALINGS IN THE SOFTWARE. -from .bits import ( +from .hybrid import ( decode_batch_rows, # decoder (CPU) encode_batch_rows, # GPU-accelerated encoder → bytes + perm + meta ) + +from .pack12 import ( + pack_12bit_indices, + unpack_12bit_indices +) __all__ = [ # High level "encode_batch_rows", "decode_batch_rows", + "pack_12bit_indices", + "unpack_12bit_indices" ] \ No newline at end of file diff --git a/src/tplr/compression/bits.py b/src/tplr/compression/bits.py index 21e87a563..0ce31380f 100644 --- a/src/tplr/compression/bits.py +++ b/src/tplr/compression/bits.py @@ -363,177 +363,6 @@ def _kernel_write_rows( # done -@triton.jit -def sub_block_cost_kernel( - idx_ptr, # IN: [rows, k_dim] int64 - best_B_ptr, # OUT: [rows] int32 - best_use_bitmap_ptr, # OUT: [rows] uint8 - best_k_param_ptr, # OUT: [rows] int32 - B_choices_ptr, # IN: [num_B_choices] int32 - k_params_ptr, # IN: [num_B_choices] int32 - rows: tl.int32, - num_B_choices: tl.int32, - k_dim: tl.constexpr, # K (e.g., 128) - C: tl.constexpr, - N_SUB_MAX: tl.constexpr, # C // B_min -): - """ - Triton kernel to find the best (B, k_param, use_bitmap) for each row - using the "sub-block" algorithm. - - This version uses a DYNAMIC loop over B_choices, so it compiles instantly. - """ - row_idx = tl.program_id(0) - if row_idx >= rows: - return - - # --- 1. Load all k_dim indices for this row into SRAM --- - k_offsets = tl.arange(0, k_dim) - row_k_mask = k_offsets < k_dim - idx_vals = tl.load(idx_ptr + row_idx * k_dim + k_offsets, mask=row_k_mask, other=0) - - # --- 2. Create index block for histogram --- - # j_indices will be [0, 1, 2, ..., N_SUB_MAX-1] - j_indices = tl.arange(0, N_SUB_MAX) - - # --- 3. Initialize best-of tracking for this row --- - best_bits = 1 << 60 # "infinity" (this is tl.int64) - best_B_val = 0 # tl.int32 - - # FIX: Initialize as tl.int1 (boolean) type - best_use_bitmap_val = tl.zeros((), dtype=tl.int1) # tl.int1 - - best_k_param_val = 0 # tl.int32 - - header_bits = 5 + 4 + 1 # lb + k + mode - - # --- 4. Dynamically loop over B_choices --- - b_idx = 0 - while b_idx < num_B_choices: - # Load B and k_param from tensors - B = tl.load(B_choices_ptr + b_idx) - k_param = tl.load(k_params_ptr + b_idx) - - n_sub = C // B - - # FIX: Cast tl.log2 (which returns float32) back to int64 - lb = tl.log2(B.to(tl.float32)).to(tl.int64) # B is power-of-two, so this is exact - - m_val = 1 << k_param - - # --- 5. Calculate sub-block counts (The Histogram) --- - - # [k_dim] -> [0, 5, 0, 1, ...] sub-block ID for each index - j_block = (idx_vals // B.to(tl.int64)).to(tl.int64) - - # Broadcasted histogram: - # (j_block[None, :] == j_indices[:, None]) creates a [N_SUB_MAX, k_dim] matrix - # tl.sum(..., axis=1) sums along the k_dim axis. - counts_all = tl.sum((j_block[None, :] == j_indices[:, None]), axis=1) - - # We only care about the counts for the *valid* sub-blocks - n_sub_mask = j_indices < n_sub - counts_unsigned = tl.where(n_sub_mask, counts_all, 0) - - # --- 6. Calculate cost for this B --- - - # Cast both operands to tl.int64 to avoid signedness mismatch - counts = counts_unsigned.to(tl.int64) - m = m_val.to(tl.int64) - - q = counts // m - - # Σ(q + 1) + n_sub * k - rb_sum = tl.sum((q + 1) * n_sub_mask.to(tl.int64)) + (k_param * n_sub) - - # FIX: Cast nonzero to tl.int64 to prevent type mismatch - nonzero = tl.sum((counts > 0).to(tl.int64)) - - bits_local = header_bits + rb_sum + lb * k_dim - bits_bitmap = header_bits + rb_sum + B.to(tl.int64) * nonzero - - cur_bits = tl.minimum(bits_local, bits_bitmap) # This is now int64 - use_bitmap = bits_bitmap < bits_local # This is tl.int1 - - # --- 7. Update best-of --- - if cur_bits < best_bits: # This is now int64 < int64 - best_bits = cur_bits - best_B_val = B - best_use_bitmap_val = use_bitmap # This is now int1 = int1 (SUCCESS) - best_k_param_val = k_param - - b_idx += 1 # Advance dynamic loop - - # --- 8. Store results for this row --- - tl.store(best_B_ptr + row_idx, best_B_val) - tl.store(best_use_bitmap_ptr + row_idx, best_use_bitmap_val) - tl.store(best_k_param_ptr + row_idx, best_k_param_val) - - -# ---------------------------------------------------------------------------- -# 2. Python Wrapper for the new Triton Kernel -# ---------------------------------------------------------------------------- - -def _estimate_best_params_per_row_triton( - idx: torch.Tensor, C: int, B_choices: Sequence[int] -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Python wrapper for the sub_block_cost_kernel. - """ - rows, k_dim = idx.shape - device = idx.device - - # --- 1. Pre-calculate all args for the kernel --- - B_sorted = tuple( - sorted([b for b in B_choices if b > 0 and (C % b) == 0 and (b & (b - 1)) == 0]) - ) - if not B_sorted: - raise ValueError("No valid B choices for C") - - Bmin = B_sorted[0] - N_SUB_MAX = C // Bmin - num_B_choices = len(B_sorted) - - lmbda_base = k_dim / max(1, C) - - # --- 2. Create Tensors to pass to kernel --- - B_choices_tensor = torch.tensor(B_sorted, dtype=torch.int32, device=device) - - K_params_list = [ - int(max(0, round(math.log2(max(lmbda_base * B, 1e-9))))) - for B in B_sorted - ] - K_params_tensor = torch.tensor(K_params_list, dtype=torch.int32, device=device) - - # --- 3. Allocate output tensors --- - best_B = torch.empty((rows,), dtype=torch.int32, device=device) - best_use_bitmap = torch.empty((rows,), dtype=torch.uint8, device=device) - best_k_param = torch.empty((rows,), dtype=torch.int32, device=device) - - # --- 4. Launch Kernel --- - grid = (rows,) - - # Corrected Kernel Call: - # All non-constexpr args must be passed positionally. - sub_block_cost_kernel[grid]( - idx, - best_B, - best_use_bitmap, - best_k_param, - B_choices_tensor, - K_params_tensor, - rows, - num_B_choices, - # Constexpr args are passed by keyword - k_dim=k_dim, - C=C, - N_SUB_MAX=N_SUB_MAX, - ) - - # The Triton kernel stored uint8(bool), convert back to bool for consistency - return best_B, best_use_bitmap.to(torch.bool), best_k_param - - # -------------------------------- Public API ---------------------------------- diff --git a/src/tplr/compression/hybrid.py b/src/tplr/compression/hybrid.py index 3e10bffbb..9032dbaf4 100644 --- a/src/tplr/compression/hybrid.py +++ b/src/tplr/compression/hybrid.py @@ -464,10 +464,7 @@ def is_at_end(self) -> bool: return self.bit_off >= self.total_bits - 7 -# --------------------------- -# Header parsing -# --------------------------- -def _parse_global_header(payload: BytesLike) -> Tuple[int, int, List[int], int]: +def _parse_global_header(payload: BytesLike) -> Tuple[int, int, list[int], int]: """ Layout: 4B "CGRP" @@ -506,14 +503,12 @@ def _parse_global_header(payload: BytesLike) -> Tuple[int, int, List[int], int]: b = int.from_bytes(raw[off:off+2], "little", signed=False) B_choices.append(b) off += 2 - return C, K, B_choices, off * 8 - -def _decode_row_variant_b(stream: BitStreamReader, M: int, k_rice: int, use_bitmap: int, - row_payload_bytes: int, row_header_bits: int, - K: int) -> List[int]: +def _decode_row(stream: BitStreamReader, M: int, k_rice: int, use_bitmap: int, + row_payload_bytes: int, row_header_bits: int, + K: int) -> list[int]: """ Stream is positioned just AFTER the 16-bit length. Decode EXACTLY K deltas (first is Rice; tail is bitmap or Rice), @@ -525,7 +520,7 @@ def _decode_row_variant_b(stream: BitStreamReader, M: int, k_rice: int, use_bitm # header _ = stream.read_bits(row_header_bits) - deltas: List[int] = [] + deltas: list[int] = [] # first (Rice) q0, hit_end = stream.read_unary_bounded(row_end_bit) @@ -567,7 +562,7 @@ def _decode_row_variant_b(stream: BitStreamReader, M: int, k_rice: int, use_bitm return vals -def decode_batch_rows(payload: BytesLike) -> List[List[int]]: +def decode_batch_rows(payload: BytesLike) -> tuple[list[list[int]], int, int]: C, K, B_choices, header_end_bit = _parse_global_header(payload) num_B = len(B_choices) @@ -584,8 +579,7 @@ def decode_batch_rows(payload: BytesLike) -> List[List[int]]: ROW_HEADER_BITS = 1 + B_choice_bits stream = BitStreamReader(payload, bit_offset_start=header_end_bit) - rows_out: List[List[int]] = [] - + rows_out: list[list[int]] = [] while stream.bits_remaining() >= 16: row_payload_bytes = stream.read_bits(16) if row_payload_bytes == 0 and stream.is_at_end(): @@ -605,7 +599,7 @@ def decode_batch_rows(payload: BytesLike) -> List[List[int]]: # Rewind header; decode the row with exact K stream.bit_off -= ROW_HEADER_BITS - row_vals = _decode_row_variant_b( + row_vals = _decode_row( stream, M=M, k_rice=k_rice, use_bitmap=use_bitmap, row_payload_bytes=row_payload_bytes, row_header_bits=ROW_HEADER_BITS, K=K @@ -613,7 +607,7 @@ def decode_batch_rows(payload: BytesLike) -> List[List[int]]: if not row_vals: break rows_out.append(row_vals) - return rows_out + return rows_out, C, num_B if __name__ == "__main__": diff --git a/src/tplr/compression/pack12.py b/src/tplr/compression/pack12.py new file mode 100644 index 000000000..ecc2617ad --- /dev/null +++ b/src/tplr/compression/pack12.py @@ -0,0 +1,94 @@ +import torch + +def pack_12bit_indices(indices: torch.Tensor) -> torch.Tensor: + """ + Pack int64 indices into 12-bit representation. + Every 2 indices (24 bits) are packed into 3 uint8 values. + Assumes even number of indices (topk is always even). + + Args: + indices: Tensor with values < 4096 (12-bit max), must have even number of elements + + Returns: + packed_tensor as uint8 + """ + # Ensure indices fit in 12 bits + max_idx = indices.max().item() if indices.numel() > 0 else 0 + if max_idx >= 4096: + raise ValueError(f"Index {max_idx} exceeds 12-bit limit (4095)") + + # Flatten the tensor + indices_flat = indices.flatten() + n_indices = indices_flat.numel() + + # Ensure we have even number of indices + if n_indices % 2 != 0: + raise ValueError(f"Number of indices must be even, got {n_indices}") + + # Convert to int32 for bit manipulation + indices_flat = indices_flat.to(torch.int32) + + # Process all as pairs + indices_pairs = indices_flat + n_pairs = n_indices // 2 + + # Calculate packed size + packed_size = n_pairs * 3 + packed = torch.zeros(packed_size, dtype=torch.uint8, device=indices.device) + + # Vectorized packing for pairs + if n_pairs > 0: + idx_pairs = indices_pairs.reshape(-1, 2) + idx1 = idx_pairs[:, 0] + idx2 = idx_pairs[:, 1] + + # Pack pairs: idx1 uses byte0 + lower 4 bits of byte1 + # idx2 uses upper 4 bits of byte1 + byte2 + packed[0::3] = (idx1 & 0xFF).to(torch.uint8) # Lower 8 bits of idx1 + packed[1::3] = (((idx1 >> 8) & 0x0F) | ((idx2 & 0x0F) << 4)).to(torch.uint8) + packed[2::3] = ((idx2 >> 4) & 0xFF).to(torch.uint8) # Upper 8 bits of idx2 + + return packed + + +def unpack_12bit_indices(packed: torch.Tensor, values_shape: tuple[int, ...] ) -> torch.Tensor: + """ + Unpack 12-bit packed indices back to int64. + Assumes even number of indices. + + Args: + packed: Packed uint8 tensor + values_shape: Shape of the values tensor (same as original indices shape) + + Returns: + Unpacked indices as int64 tensor with original shape + """ + n_indices = int(torch.prod(torch.tensor(values_shape)).item()) + + if n_indices == 0: + return torch.zeros(values_shape, dtype=torch.int64, device=packed.device) + + # Ensure even number of indices + if n_indices % 2 != 0: + raise ValueError(f"Number of indices must be even, got {n_indices}") + + # Prepare output + indices = torch.zeros(n_indices, dtype=torch.int64, device=packed.device) + + # All indices are paired + n_pairs = n_indices // 2 + + if n_pairs > 0: + # Vectorized unpacking + byte0 = packed[0::3].to(torch.int64) + byte1 = packed[1::3].to(torch.int64) + byte2 = packed[2::3].to(torch.int64) + + # Reconstruct indices + indices[0::2] = byte0 | ((byte1 & 0x0F) << 8) # idx1 + indices[1::2] = ((byte1 >> 4) & 0x0F) | (byte2 << 4) # idx2 + + # Reshape to match values shape + indices = indices.reshape(values_shape) + + return indices \ No newline at end of file diff --git a/src/tplr/neurons.py b/src/tplr/neurons.py index f74b25dab..c4f6d8852 100644 --- a/src/tplr/neurons.py +++ b/src/tplr/neurons.py @@ -34,7 +34,6 @@ from wandb.sdk.wandb_run import Run import tplr -from tplr.compress import unpack_12bit_indices from tplr.distributed import dist_helper if TYPE_CHECKING: diff --git a/tests/unit/test_compress.py b/tests/unit/test_compress.py index 2eaab4fcd..b9f26c70b 100644 --- a/tests/unit/test_compress.py +++ b/tests/unit/test_compress.py @@ -1,5 +1,6 @@ from typing import Literal +import numpy as np import pytest import torch import torch.nn as nn @@ -9,11 +10,11 @@ TopKCompressor, _dct, _get_smaller_split, - _idct, - pack_12bit_indices, - unpack_12bit_indices, + _idct ) +from tplr.compression import encode_batch_rows, pack_12bit_indices, unpack_12bit_indices + class TestTopKCompressor: """Test TopKCompressor class using actual implementation""" @@ -30,19 +31,19 @@ def compress_instance_quantized(self) -> TopKCompressor[Literal[True]]: use_quantization=True, quantization_bins=256, quantization_range=6 ) - def test_compress_produces_int16_indices( - self, compress_instance: TopKCompressor[Literal[False]] + def test_compress_produces_rice_bitmap_indices( + self, compress_instance: TopKCompressor[Literal[False]] ): - """Test that compress() produces 12-bit packed indices""" + """Test that compress() produces Rice/bitmap encoded indices""" # Create test tensor - x = torch.randn(10, 10) - topk = 10 + x = torch.randn(8, 128) # 1024 elements total, last dim=64 + topk = 16 # Compress using actual method idx, val, xshape, totalk = compress_instance.compress(x, topk) - # Verify index format - should be uint8 tensor for 12-bit packed - assert idx.dtype == torch.uint8, f"Expected uint8 packed data, got {idx.dtype}" + # Verify index format - should be uint8 tensor for Rice/bitmap codec + assert idx.dtype == torch.uint8, f"Expected uint8 encoded data, got {idx.dtype}" assert val.shape[-1] == topk assert xshape == x.shape # totalk is the size of the last dimension after rearranging @@ -50,11 +51,11 @@ def test_compress_produces_int16_indices( assert totalk == x.shape[-1] # For 2D tensor, it's the last dimension def test_compress_with_quantization( - self, compress_instance_quantized: TopKCompressor[Literal[True]] + self, compress_instance_quantized: TopKCompressor[Literal[True]] ): """Test compression with quantization enabled""" - x = torch.randn(10, 10) - topk = 20 + x = torch.randn(8, 128) # 1024 elements total, last dim=64 + topk = 32 # Compress with quantization result = compress_instance_quantized.compress(x, topk) @@ -63,53 +64,58 @@ def test_compress_with_quantization( assert len(result) == 5 idx, val, _, _, qparams = result - # idx should be uint8 tensor for 12-bit packed format + # idx should be uint8 tensor for Rice/bitmap encoded format assert idx.dtype == torch.uint8 assert val.dtype == torch.uint8 # Quantized values assert qparams is not None assert len(qparams) == 5 # shift, scale, offset, lookup, orig_dtype - def test_decompress_with_12bit_tuple_format( - self, compress_instance: TopKCompressor[Literal[False]] + def test_decompress_with_rice_bitmap_format( + self, compress_instance: TopKCompressor[Literal[False]] ): - """Test that decompress can handle 12-bit packed tuple format""" + """Test that decompress can handle Rice/bitmap encoded format""" # Setup - p = torch.zeros(10, 10) - xshape = (10, 10) - totalk = 100 + p = torch.zeros(8, 128) # 1024 elements total, last dim=128 + xshape = (8, 128) + totalk = 128 - # Create proper 12-bit packed format using the actual packing function - # Create indices that are within valid range for a 10x10 tensor (even count) + # Create proper Rice/bitmap encoded format using the encoder + # Create indices that are within valid range for a 8x64 tensor (even count) original_indices = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=torch.int64) - # Pack using the actual function - idx = pack_12bit_indices(original_indices) - - val = torch.tensor( - [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]], dtype=torch.float32 - ) + # Pack using the new encoder format + payload, _ = encode_batch_rows(original_indices, C=totalk) + idx = torch.tensor(np.frombuffer(payload, dtype=np.uint8), dtype=torch.uint8) + val = torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]], dtype=torch.float32) # Test decompression with packed format result = compress_instance.decompress(p, idx, val, xshape, totalk) assert result.shape == xshape assert result.dtype == p.dtype - def test_batch_decompress_multiple_12bit_formats( - self, compress_instance: TopKCompressor[Literal[False]] + def test_batch_decompress_multiple_rice_bitmap_formats( + self, compress_instance: TopKCompressor[Literal[False]] ): - """Test batch_decompress with multiple 12-bit packed indices""" + """Test batch_decompress with multiple Rice/bitmap encoded indices""" # Setup - p = torch.zeros(10, 10) - xshape = (10, 10) - totalk = 100 + p = torch.zeros(8, 128) # 1024 elements total, last dim=128 + xshape = (8, 128) + totalk = 128 - # Create multiple 12-bit packed indices + # Create multiple Rice/bitmap encoded indices idx1_orig = torch.tensor([[0, 1], [2, 3]], dtype=torch.int64) idx2_orig = torch.tensor([[4, 5], [6, 7]], dtype=torch.int64) - # Pack them using the 12-bit format - idx1_packed = pack_12bit_indices(idx1_orig) - idx2_packed = pack_12bit_indices(idx2_orig) + # Pack them using the new encoder format + payload1, _ = encode_batch_rows(idx1_orig, C=totalk) + idx1_packed = torch.tensor( + np.frombuffer(payload1, dtype=np.uint8), dtype=torch.uint8 + ) + + payload2, _ = encode_batch_rows(idx2_orig, C=totalk) + idx2_packed = torch.tensor( + np.frombuffer(payload2, dtype=np.uint8), dtype=torch.uint8 + ) idx_list = [idx1_packed, idx2_packed] @@ -125,10 +131,10 @@ def test_batch_decompress_multiple_12bit_formats( assert result.dtype == p.dtype def test_compress_decompress_round_trip( - self, compress_instance: TopKCompressor[Literal[False]] + self, compress_instance: TopKCompressor[Literal[False]] ): """Test full compress-decompress round trip""" - x = torch.zeros(10, 10) + x = torch.zeros(8, 128) # 1024 elements total, last dim=128 x[0, 0] = 1.0 x[1, 1] = 2.0 x[2, 2] = 3.0 @@ -139,7 +145,9 @@ def test_compress_decompress_round_trip( idx, val, xshape, totalk = compress_instance.compress(x, topk) # Verify we got the top-k values - assert idx.dtype == torch.uint8, "Expected uint8 for 12-bit packed indices" + assert idx.dtype == torch.uint8, ( + "Expected uint8 for Rice/bitmap encoded indices" + ) assert val.shape[-1] == topk # Decompress @@ -157,46 +165,47 @@ def test_compress_decompress_round_trip( expected_vals = torch.tensor([4.0, 3.0, 2.0, 1.0]) assert torch.allclose(top_vals, expected_vals, atol=1e-5) - def test_12bit_index_value_range( - self, compress_instance: TopKCompressor[Literal[False]] + def test_rice_bitmap_index_value_range( + self, compress_instance: TopKCompressor[Literal[False]] ): - """Test that indices can represent values appropriate for 12-bit range""" + """Test that Rice/bitmap codec can handle large index ranges efficiently""" # Create a large tensor that would have indices beyond 8-bit range - x = torch.randn(100, 100) # 10,000 elements - topk = 100 + x = torch.randn(128, 128) # 16,384 elements + topk = 64 # Compress idx, val, _, totalk = compress_instance.compress(x, topk) - # Check that indices are 12-bit packed format - assert idx.dtype == torch.uint8, "Expected uint8 for 12-bit packed indices" + # Check that indices are in the new codec format (uint8 bytes) + assert idx.dtype == torch.uint8, "Expected uint8 for Rice/bitmap codec" - # Since idx is packed, we can't directly check max values - # Instead verify the packing worked correctly - # Use val.shape since it has the same shape as the original indices - unpacked = unpack_12bit_indices(idx, val.shape) + # Since idx is a byte stream payload, we can't directly check max values + # Instead verify round-trip works correctly + p = torch.zeros_like(x) + result = compress_instance.decompress(p, idx, val, x.shape, totalk) - # Verify some indices might be larger than 255 (8-bit max) - max_idx = unpacked.max().item() - assert max_idx < 10000, f"Index {max_idx} exceeds tensor size" + # Check that decompression succeeded + assert result.shape == x.shape - # If tensor is large enough, we should have indices > 255 - if totalk > 256: - assert unpacked.max() > 255, ( - "Large tensor should have indices beyond 8-bit range" - ) + # For a 2D tensor, totalk is the size of the last dimension + assert totalk == 128, ( + f"Expected totalk=128 for 128x128 tensor (last dim), got {totalk}" + ) def test_batch_decompress_with_norm_options( - self, compress_instance: TopKCompressor[Literal[False]] + self, compress_instance: TopKCompressor[Literal[False]] ): """Test batch_decompress with normalisation and clip_norm options""" - p = torch.zeros(10, 10) - xshape = (10, 10) - totalk = 100 + p = torch.zeros(8, 128) # 1024 elements total, last dim=128 + xshape = (8, 128) + totalk = 128 - # Create test data with 12-bit packed format + # Create test data with Rice/bitmap encoded format idx_orig = torch.tensor([[0, 1, 2, 3]], dtype=torch.int64) # Even count - idx_packed = pack_12bit_indices(idx_orig) + payload, _ = encode_batch_rows(idx_orig, C=totalk) + idx_packed = torch.tensor( + np.frombuffer(payload, dtype=np.uint8), dtype=torch.uint8 + ) idx = [idx_packed] val = [torch.tensor([[10.0, 20.0, 30.0, 40.0]], dtype=torch.float32)] From 6f6d2e5f3e207304f62a773fa8d9c8260a60636d Mon Sep 17 00:00:00 2001 From: Kasper Date: Thu, 13 Nov 2025 13:29:24 +0400 Subject: [PATCH 03/33] sync idx and val --- neurons/validator.py | 19 ++++++++++++++ src/tplr/compress.py | 20 +++++++------- src/tplr/compression/__init__.py | 4 ++- src/tplr/compression/hybrid.py | 31 ++++++++++++++++------ tests/unit/test_compress.py | 45 +++++++++++++++++++++++++++----- 5 files changed, 94 insertions(+), 25 deletions(-) diff --git a/neurons/validator.py b/neurons/validator.py index 3f432f35e..a3310c8b5 100644 --- a/neurons/validator.py +++ b/neurons/validator.py @@ -391,18 +391,37 @@ def __init__(self): self.xshapes = {} self.totalks = {} + + import time + total_compress_time = 0.0 + total_encode_time = 0.0 for n, p in self.model.named_parameters(): + tplr.logger.info(f"[COMPRESS START] {n}: shape={p.shape}") + + encode_start = time.time() enc = self.transformer.encode( torch.empty(p.shape, dtype=torch.float16, device=self.device), use_dct=self.hparams.use_dct, ) + encode_time = time.time() - encode_start + + compress_start = time.time() _, _, xshape, totalk, _ = self.compressor.compress( enc, self.hparams.topk_compression, ) + compress_time = time.time() - compress_start + self.xshapes[n] = xshape self.totalks[n] = totalk + total_encode_time += encode_time + total_compress_time += compress_time + + tplr.logger.info(f"[COMPRESS TIMING] {n}: encode={encode_time:.3f}s, compress={compress_time:.3f}s, shape={p.shape}") + + tplr.logger.info(f"[COMPRESS TIMING TOTAL] encode={total_encode_time:.3f}s, compress={total_compress_time:.3f}s") + self.openskill_model = PlackettLuce( beta=self.hparams.openskill_beta, tau=self.hparams.openskill_tau ) diff --git a/src/tplr/compress.py b/src/tplr/compress.py index 2a57bef77..6b25d1072 100644 --- a/src/tplr/compress.py +++ b/src/tplr/compress.py @@ -31,7 +31,7 @@ import tplr -from tplr.compression import encode_batch_rows, decode_batch_rows +from tplr.compression import encode_batch_rows, decode_batch_rows, encode_sparsification # ─────────── type aliases ──────────────────────────────────────────────── # primitive shapes @@ -50,7 +50,7 @@ # Boolean flag that propagates the chosen quantisation mode Q = TypeVar("Q", Literal[True], Literal[False]) -_DEFAULT_B_CHOICES: tuple[int, ...] = (64, 128) +_DEFAULT_B_CHOICES: tuple[int, ...] = (32, 64) class ChunkingTransformer: @@ -324,12 +324,8 @@ def compress(self, x: torch.Tensor, topk: int): # type: ignore[override] # Flatten to [rows, k] for the codec idx2d = idx_int64.reshape(-1, topk).contiguous() - payload, _meta = encode_batch_rows(idx2d, C=totalk, B_choices=_DEFAULT_B_CHOICES) - idx_bytes = torch.tensor( - np.frombuffer(payload, dtype=np.uint8).copy(), - dtype=torch.uint8, - device="cpu", - ) + val2d = val.reshape(-1, topk).contiguous() + idx_bytes, val, _meta = encode_sparsification(idx2d, val2d, C=totalk, B_choices=_DEFAULT_B_CHOICES) # Apply 8-bit quantization if enabled if self.use_quantization: @@ -377,8 +373,8 @@ def decompress( if C != totalk: raise ValueError(f"Index payload C={C} but expected {totalk}") k = val.shape[-1] - #if any(len(r) != k for r in rows_list): - # raise ValueError("Row-wise topk size mismatch in index payload") + if any(len(r) != k for r in rows_list): + raise ValueError("Row-wise topk size mismatch in index payload") idx_int64 = torch.tensor( rows_list, dtype=torch.int64, device=p.device ).view(*val.shape) @@ -391,6 +387,10 @@ def decompress( if val.dtype != x.dtype: val = val.to(dtype=x.dtype) + if x.ndim == 1: + idx_int64 = idx_int64.flatten() + val = val.flatten() + x.scatter_reduce_( dim=-1, index=idx_int64, src=val, reduce="mean", include_self=False ).reshape(xshape) diff --git a/src/tplr/compression/__init__.py b/src/tplr/compression/__init__.py index 2b570f514..4c6ac0b3f 100644 --- a/src/tplr/compression/__init__.py +++ b/src/tplr/compression/__init__.py @@ -17,6 +17,7 @@ # DEALINGS IN THE SOFTWARE. from .hybrid import ( + encode_sparsification, decode_batch_rows, # decoder (CPU) encode_batch_rows, # GPU-accelerated encoder → bytes + perm + meta ) @@ -30,5 +31,6 @@ "encode_batch_rows", "decode_batch_rows", "pack_12bit_indices", - "unpack_12bit_indices" + "unpack_12bit_indices", + "encode_sparsification" ] \ No newline at end of file diff --git a/src/tplr/compression/hybrid.py b/src/tplr/compression/hybrid.py index 9032dbaf4..5a32a347e 100644 --- a/src/tplr/compression/hybrid.py +++ b/src/tplr/compression/hybrid.py @@ -10,12 +10,28 @@ BytesLike = Union[bytes, bytearray, np.ndarray, torch.Tensor] +def encode_sparsification( + idx: torch.Tensor, + vals: torch.Tensor, + C: int, + B_choices: Tuple[int, ...] = (64, 128) +) -> Tuple[bytes, torch.Tensor, Dict]: + dev = idx.device + payload, meta, perm = encode_batch_rows(idx, C=C, B_choices=B_choices) + vals = torch.gather(vals, dim=1, index=perm.to(dev)) + idx_bytes = torch.tensor( + np.frombuffer(payload, dtype=np.uint8).copy(), + dtype=torch.uint8, + device="cpu", + ) + return idx_bytes, vals, meta + def encode_batch_rows( idx: torch.Tensor, *, C: int, B_choices: Tuple[int, ...] = (64, 128) -) -> Tuple[bytes, Dict]: +) -> Tuple[bytes, Dict, torch.Tensor]: """ Compresses a 2D int64 tensor of Top-K indices into a byte string using a per-row adaptive Rice/Bitmap compression scheme on the GPU. @@ -33,7 +49,6 @@ def encode_batch_rows( - meta (dict): Metadata about the compression. """ - # --- 1. Input Validation & Setup --- if not torch.cuda.is_available(): raise RuntimeError("CUDA is required for this function.") @@ -54,8 +69,10 @@ def encode_batch_rows( "B_hist": {b: 0 for b in B_choices} } - dev = torch.device("cuda") - idx = idx.to(dev, non_blocking=True) + if not idx.is_cuda: + idx = idx.cuda() + idx = idx.contiguous() + dev = idx.device # Calculate k_rice parameters (log2(C // B)) k_rice_choices = tuple(int(math.log2(C // b)) for b in B_choices) @@ -73,10 +90,8 @@ def encode_batch_rows( # --- 2. GPU Preprocessing: Sort & Delta Encode --- - # Sort each row for delta encoding - idx_sorted, _ = torch.sort(idx, dim=1) - # Delta encode: val[0], val[1]-val[0], val[2]-val[1], ... + idx_sorted, idx_perm = torch.sort(idx, dim=1) delta = torch.cat( (idx_sorted[:, :1], idx_sorted[:, 1:] - idx_sorted[:, :-1]), dim=1 @@ -190,7 +205,7 @@ def encode_batch_rows( # header+payload, no 16-bit length, before byte-rounding "B_hist": B_hist, } - return payload_bytes, meta + return payload_bytes, meta, idx_perm # --- Triton Kernel 1: Cost Analysis --- diff --git a/tests/unit/test_compress.py b/tests/unit/test_compress.py index b9f26c70b..c4f6933c3 100644 --- a/tests/unit/test_compress.py +++ b/tests/unit/test_compress.py @@ -70,7 +70,7 @@ def test_compress_with_quantization( assert qparams is not None assert len(qparams) == 5 # shift, scale, offset, lookup, orig_dtype - def test_decompress_with_rice_bitmap_format( + def test_decompress_2d_with_rice_bitmap_format( self, compress_instance: TopKCompressor[Literal[False]] ): """Test that decompress can handle Rice/bitmap encoded format""" @@ -84,7 +84,7 @@ def test_decompress_with_rice_bitmap_format( original_indices = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=torch.int64) # Pack using the new encoder format - payload, _ = encode_batch_rows(original_indices, C=totalk) + payload, _, _ = encode_batch_rows(original_indices, C=totalk) idx = torch.tensor(np.frombuffer(payload, dtype=np.uint8), dtype=torch.uint8) val = torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]], dtype=torch.float32) @@ -107,12 +107,12 @@ def test_batch_decompress_multiple_rice_bitmap_formats( idx2_orig = torch.tensor([[4, 5], [6, 7]], dtype=torch.int64) # Pack them using the new encoder format - payload1, _ = encode_batch_rows(idx1_orig, C=totalk) + payload1, _, _ = encode_batch_rows(idx1_orig, C=totalk) idx1_packed = torch.tensor( np.frombuffer(payload1, dtype=np.uint8), dtype=torch.uint8 ) - payload2, _ = encode_batch_rows(idx2_orig, C=totalk) + payload2, _, _ = encode_batch_rows(idx2_orig, C=totalk) idx2_packed = torch.tensor( np.frombuffer(payload2, dtype=np.uint8), dtype=torch.uint8 ) @@ -130,7 +130,40 @@ def test_batch_decompress_multiple_rice_bitmap_formats( assert result.shape == xshape assert result.dtype == p.dtype - def test_compress_decompress_round_trip( + def test_compress_decompress_round_trip_1d( + self, compress_instance: TopKCompressor[Literal[False]] + ): + x = torch.zeros(128,) # 1024 elements total, last dim=128 + x[0] = 1.0 + x[1] = 2.0 + x[2] = 3.0 + x[3] = 4.0 + topk = 4 + + idx, val, xshape, totalk = compress_instance.compress(x, topk) + + # Verify we got the top-k values + assert idx.dtype == torch.uint8, ( + "Expected uint8 for Rice/bitmap encoded indices" + ) + assert val.shape[-1] == topk + + # Decompress + p = torch.zeros_like(x) + result = compress_instance.decompress(p, idx, val, xshape, totalk) + + # Verify shape + assert result.shape == x.shape + + # Verify the top values were preserved + assert result.abs().max() > 0, "Decompressed tensor should have non-zero values" + + # The top 4 values should be approximately 4, 3, 2, 1 + top_vals = torch.topk(result.abs().flatten(), k=4).values + expected_vals = torch.tensor([4.0, 3.0, 2.0, 1.0]) + assert torch.allclose(top_vals, expected_vals, atol=1e-5) + + def test_compress_decompress_round_trip_2d( self, compress_instance: TopKCompressor[Literal[False]] ): """Test full compress-decompress round trip""" @@ -202,7 +235,7 @@ def test_batch_decompress_with_norm_options( # Create test data with Rice/bitmap encoded format idx_orig = torch.tensor([[0, 1, 2, 3]], dtype=torch.int64) # Even count - payload, _ = encode_batch_rows(idx_orig, C=totalk) + payload, _, _ = encode_batch_rows(idx_orig, C=totalk) idx_packed = torch.tensor( np.frombuffer(payload, dtype=np.uint8), dtype=torch.uint8 ) From a17e1cb929212c89f09ff49779e3f4856dd55856 Mon Sep 17 00:00:00 2001 From: Kasper Date: Fri, 14 Nov 2025 14:47:40 +0400 Subject: [PATCH 04/33] cleanup --- src/tplr/comms.py | 100 ++++++++++++++++++------------- src/tplr/compress.py | 18 +++--- src/tplr/compression/__init__.py | 6 +- src/tplr/compression/hybrid.py | 67 ++++++--------------- src/tplr/neurons.py | 55 +++++++++++------ tests/unit/test_compress.py | 78 ++++++++++-------------- 6 files changed, 156 insertions(+), 168 deletions(-) diff --git a/src/tplr/comms.py b/src/tplr/comms.py index 1578d5f7c..9b90c2c9e 100644 --- a/src/tplr/comms.py +++ b/src/tplr/comms.py @@ -29,6 +29,7 @@ from functools import partial from types import SimpleNamespace from typing import Any, Literal, cast +import numpy as np import aiofiles import bittensor as bt @@ -44,7 +45,7 @@ import tplr from tplr.chain import ChainManager from tplr.compress import TopKCompressor -from tplr.compression import unpack_12bit_indices +from tplr.compression import decode_batch_rows from tplr.config import BUCKET_SECRETS, client_config from tplr.schemas import Bucket, CommsGetResult @@ -2479,10 +2480,8 @@ def check_compressed_indices( """ Validates the integrity and format of compressed gradient indices. - This is a crucial security and stability check to ensure that gradients - received from peers are well-formed. It verifies that indices are within - the expected bounds and that the compression format (e.g., 12-bit packing) - is correctly applied. + This ensures indices are within bounds and that the **new Rice/bitmap** + codec payload matches the provided values tensor shape (top‑k). Args: param_name (str): The name of the parameter being checked. @@ -2495,7 +2494,7 @@ def check_compressed_indices( Raises: ValueError: If any validation check fails, such as out-of-bounds - indices, incorrect data types, or malformed packed data. + indices, incorrect data types, or malformed payload. """ allowed_topk = ( min(self.hparams.topk_compression, totalk) @@ -2503,44 +2502,61 @@ def check_compressed_indices( else min(allowed_topk, totalk) ) - def _bounds_check(t: torch.Tensor): - """fast min/max bounds check""" - if t.numel() == 0: - raise ValueError(f"[{param_name}] empty index list") - if t.min().item() < 0 or t.max().item() >= totalk: - bad = t[(t < 0) | (t >= totalk)][0].item() - raise ValueError( - f"[{param_name}] Index {bad} out of bounds (totalk = {totalk})" - ) + if not isinstance(idxs, torch.Tensor): + raise ValueError( + f"[{param_name}] Expected tensor for indices, got {type(idxs)}" + ) + if vals is None: + raise ValueError( + f"[{param_name}] Values tensor required for index validation" + ) + if idxs.dtype != torch.uint8: + raise ValueError( + f"[{param_name}] Expected uint8 (Rice/bitmap payload), got {idxs.dtype}" + ) + if idxs.numel() == 0: + raise ValueError(f"[{param_name}] Empty indices payload") - # Handle 12-bit packed index format only - if isinstance(idxs, torch.Tensor): - if idxs.dtype != torch.uint8: - raise ValueError( - f"[{param_name}] Expected uint8 for 12-bit packed indices, got {idxs.dtype}" - ) - # 12-bit packed format is the only supported format - if vals is None: - raise ValueError( - f"[{param_name}] Values tensor required to validate 12-bit packed indices" - ) - if idxs.numel() == 0: - raise ValueError(f"[{param_name}] Empty packed indices tensor") + # Decode (CPU) and perform structural checks + try: + payload_bytes = idxs.detach().cpu().numpy().tobytes() + rows_list, C, N = decode_batch_rows(payload_bytes) + except Exception as e: + raise ValueError(f"[{param_name}] Failed to decode indices payload: {e}") - # Unpack using the values shape - try: - unpacked = unpack_12bit_indices(idxs, vals.shape) - # Validate that the last dimension matches allowed_topk - if unpacked.shape[-1] != allowed_topk: - raise ValueError( - f"[{param_name}] Invalid topk dimension: " - f"shape[-1]={unpacked.shape[-1]} but expected {allowed_topk}" - ) - _bounds_check(unpacked) - except Exception as e: - raise ValueError(f"[{param_name}] Failed to unpack 12-bit indices: {e}") - else: - raise ValueError(f"[{param_name}] Expected tensor but got {type(idxs)}") + if C != totalk: + raise ValueError( + f"[{param_name}] Payload column size C={C} but expected {totalk}" + ) + + # compute expected rows from values shape (flatten all but last dim) + if vals.ndim == 0: + raise ValueError(f"[{param_name}] Values tensor has no top‑k dimension") + expected_rows = int(np.prod(vals.shape[:-1])) if vals.ndim > 1 else 1 + if N != expected_rows: + raise ValueError( + f"[{param_name}] Payload rows N={N} but values imply {expected_rows}" + ) + + k = vals.shape[-1] + if k != allowed_topk: + raise ValueError( + f"[{param_name}] Values top‑k={k} but allowed_topk={allowed_topk}" + ) + if any(len(r) != k for r in rows_list): + raise ValueError( + f"[{param_name}] At least one row has mismatched top‑k size" + ) + + # bounds check without materialising full tensor + max_idx = max((max(r) if len(r) > 0 else -1) for r in rows_list) + min_idx = ( + min((min(r) if len(r) > 0 else 0) for r in rows_list) if rows_list else 0 + ) + if min_idx < 0 or max_idx >= totalk: + raise ValueError( + f"[{param_name}] Index out of bounds (min={min_idx}, max={max_idx}, totalk={totalk})" + ) async def s3_get_object_size(self, bucket: Bucket, key: str) -> int | None: """ diff --git a/src/tplr/compress.py b/src/tplr/compress.py index 6b25d1072..41927497a 100644 --- a/src/tplr/compress.py +++ b/src/tplr/compress.py @@ -31,7 +31,7 @@ import tplr -from tplr.compression import encode_batch_rows, decode_batch_rows, encode_sparsification +from tplr.compression import encode_batch_rows_sorted, decode_batch_rows # ─────────── type aliases ──────────────────────────────────────────────── # primitive shapes @@ -323,9 +323,13 @@ def compress(self, x: torch.Tensor, topk: int): # type: ignore[override] val = torch.gather(x, dim=-1, index=idx_int64) # Flatten to [rows, k] for the codec - idx2d = idx_int64.reshape(-1, topk).contiguous() - val2d = val.reshape(-1, topk).contiguous() - idx_bytes, val, _meta = encode_sparsification(idx2d, val2d, C=totalk, B_choices=_DEFAULT_B_CHOICES) + idx2d = idx_int64.reshape(-1, topk).to(torch.int32) + val2d = val.reshape(-1, topk) + + # sort indices and apply same perm to values + idx_sorted, perm = torch.sort(idx2d, dim=1) + val = torch.gather(val2d, dim=1, index=perm) + idx_bytes, _meta = encode_batch_rows_sorted(idx_sorted, C=totalk, B_choices=_DEFAULT_B_CHOICES) # Apply 8-bit quantization if enabled if self.use_quantization: @@ -387,9 +391,9 @@ def decompress( if val.dtype != x.dtype: val = val.to(dtype=x.dtype) - if x.ndim == 1: - idx_int64 = idx_int64.flatten() - val = val.flatten() + if len(xshape) > 2: + idx_int64 = rearrange(idx_int64, "(y x) h -> y x h", y=xshape[0]) + val = rearrange(val, "(y x) h -> y x h", y=xshape[0]) x.scatter_reduce_( dim=-1, index=idx_int64, src=val, reduce="mean", include_self=False diff --git a/src/tplr/compression/__init__.py b/src/tplr/compression/__init__.py index 4c6ac0b3f..732c8483e 100644 --- a/src/tplr/compression/__init__.py +++ b/src/tplr/compression/__init__.py @@ -17,9 +17,8 @@ # DEALINGS IN THE SOFTWARE. from .hybrid import ( - encode_sparsification, decode_batch_rows, # decoder (CPU) - encode_batch_rows, # GPU-accelerated encoder → bytes + perm + meta + encode_batch_rows_sorted, # GPU-accelerated encoder → bytes + perm + meta ) from .pack12 import ( @@ -28,9 +27,8 @@ ) __all__ = [ # High level - "encode_batch_rows", + "encode_batch_rows_sorted", "decode_batch_rows", "pack_12bit_indices", "unpack_12bit_indices", - "encode_sparsification" ] \ No newline at end of file diff --git a/src/tplr/compression/hybrid.py b/src/tplr/compression/hybrid.py index 5a32a347e..9535d1fae 100644 --- a/src/tplr/compression/hybrid.py +++ b/src/tplr/compression/hybrid.py @@ -10,34 +10,18 @@ BytesLike = Union[bytes, bytearray, np.ndarray, torch.Tensor] -def encode_sparsification( - idx: torch.Tensor, - vals: torch.Tensor, - C: int, - B_choices: Tuple[int, ...] = (64, 128) -) -> Tuple[bytes, torch.Tensor, Dict]: - dev = idx.device - payload, meta, perm = encode_batch_rows(idx, C=C, B_choices=B_choices) - vals = torch.gather(vals, dim=1, index=perm.to(dev)) - idx_bytes = torch.tensor( - np.frombuffer(payload, dtype=np.uint8).copy(), - dtype=torch.uint8, - device="cpu", - ) - return idx_bytes, vals, meta - -def encode_batch_rows( - idx: torch.Tensor, +def encode_batch_rows_sorted( + idx_sorted: torch.Tensor, *, C: int, B_choices: Tuple[int, ...] = (64, 128) -) -> Tuple[bytes, Dict, torch.Tensor]: +) -> Tuple[bytes, Dict]: """ - Compresses a 2D int64 tensor of Top-K indices into a byte string + Compresses a 2D tensor of Top-K indices into a byte string using a per-row adaptive Rice/Bitmap compression scheme on the GPU. Args: - idx (torch.Tensor): [rows, k] int64 tensor of indices. + idx_sorted (torch.Tensor): [rows, k] sorted tensor of indices. C (int): The total number of columns (0 <= idx < C). B_choices (tuple[int, ...]): Block sizes to evaluate. Must be powers of two. @@ -52,8 +36,8 @@ def encode_batch_rows( if not torch.cuda.is_available(): raise RuntimeError("CUDA is required for this function.") - if not isinstance(idx, torch.Tensor) or idx.ndim != 2 or idx.dtype != torch.int64: - raise ValueError(f"idx must be a 2D int64 tensor, got {idx.shape} {idx.dtype}") + if not isinstance(idx_sorted, torch.Tensor) or idx_sorted.ndim != 2: + raise ValueError(f"idx must be a 2D int64 tensor, got {idx_sorted.shape}") if not all(isinstance(b, int) and (b & (b - 1) == 0) and b > 0 for b in B_choices): raise ValueError(f"All B_choices must be powers of two, got {B_choices}") @@ -61,7 +45,7 @@ def encode_batch_rows( if not all(C % b == 0 for b in B_choices): raise ValueError(f"All B_choices must evenly divide C={C}, got {B_choices}") - num_rows, k_dim = idx.shape + num_rows, k_dim = idx_sorted.shape if num_rows == 0: return b"", { "total_bits": 0, @@ -69,10 +53,10 @@ def encode_batch_rows( "B_hist": {b: 0 for b in B_choices} } - if not idx.is_cuda: - idx = idx.cuda() - idx = idx.contiguous() - dev = idx.device + if not idx_sorted.is_cuda: + idx_sorted = idx_sorted.cuda() + idx_sorted = idx_sorted.contiguous() + dev = idx_sorted.device # Calculate k_rice parameters (log2(C // B)) k_rice_choices = tuple(int(math.log2(C // b)) for b in B_choices) @@ -88,27 +72,19 @@ def encode_batch_rows( # Row header: 1 bit (bitmap/rice) + B_choice_bits ROW_HEADER_BITS = 1 + B_choice_bits - # --- 2. GPU Preprocessing: Sort & Delta Encode --- - # Delta encode: val[0], val[1]-val[0], val[2]-val[1], ... - idx_sorted, idx_perm = torch.sort(idx, dim=1) delta = torch.cat( (idx_sorted[:, :1], idx_sorted[:, 1:] - idx_sorted[:, :-1]), dim=1 ) - # Ensure all deltas are non-negative - delta = torch.clamp(delta, min=0) # --- 3. Kernel 1: Cost Analysis --- # Output tensors for cost kernel costs = torch.empty((num_rows, num_B_choices), dtype=torch.int32, device=dev) is_bitmap = torch.empty((num_rows, num_B_choices), dtype=torch.int8, device=dev) - - # Grid is 1D, one program per row grid = (num_rows,) - # Launch cost kernel # k_dim is passed as constexpr for tl.arange, but B_choices are dynamic cost_kernel[grid]( delta, @@ -187,15 +163,8 @@ def encode_batch_rows( ROW_HEADER_BITS=ROW_HEADER_BITS, ) - # --- 7. Copy to CPU and Return --- + payload_cpu = payload_buf.cpu() - # Copy buffer from GPU to CPU - payload_cpu = payload_buf.cpu().numpy() - - # Convert to final bytes object - payload_bytes = payload_cpu.tobytes() - - # --- meta --- b_counts = torch.bincount(best_B_idx, minlength=len(B_choices)) B_hist = {b: c.item() for b, c in zip(B_choices, b_counts)} meta = { @@ -205,10 +174,8 @@ def encode_batch_rows( # header+payload, no 16-bit length, before byte-rounding "B_hist": B_hist, } - return payload_bytes, meta, idx_perm - + return payload_cpu, meta -# --- Triton Kernel 1: Cost Analysis --- @triton.jit def cost_kernel( @@ -484,7 +451,7 @@ def _parse_global_header(payload: BytesLike) -> Tuple[int, int, list[int], int]: Layout: 4B "CGRP" 4B C (uint32 LE) - 2B K (uint16 LE) <-- NEW + 2B K (uint16 LE) 1B num_B 2B * num_B (each B, uint16 LE) Returns: (C, K, B_choices, header_end_bit_offset) @@ -506,7 +473,7 @@ def _parse_global_header(payload: BytesLike) -> Tuple[int, int, list[int], int]: raise ValueError("Bad magic; expected 'CGRP'") C = int.from_bytes(raw[4:8], "little", signed=False) - K = int.from_bytes(raw[8:10], "little", signed=False) # NEW + K = int.from_bytes(raw[8:10], "little", signed=False) num_B = raw[10] need = 4 + 4 + 2 + 1 + 2 * num_B if len(raw) < need: @@ -622,7 +589,7 @@ def decode_batch_rows(payload: BytesLike) -> tuple[list[list[int]], int, int]: if not row_vals: break rows_out.append(row_vals) - return rows_out, C, num_B + return rows_out, C, len(rows_out) if __name__ == "__main__": diff --git a/src/tplr/neurons.py b/src/tplr/neurons.py index c4f6d8852..4f9bcdc98 100644 --- a/src/tplr/neurons.py +++ b/src/tplr/neurons.py @@ -35,6 +35,7 @@ import tplr from tplr.distributed import dist_helper +from tplr.compression import decode_batch_rows if TYPE_CHECKING: from neurons.miner import Miner @@ -1246,30 +1247,46 @@ async def check_uid_index_overlap( if idxs_all is None: continue - # Get values for unpacking shape - vals_key = pname + "vals" - vals_all = getattr(gather_result.state_dict, vals_key, None) - if vals_all is None: - continue + def _as_bytes(x) -> bytes: + if isinstance(x, (bytes, bytearray)): + return bytes(x) + if isinstance(x, torch.Tensor): + if x.dtype != torch.uint8: + raise ValueError( + f"Expected torch.uint8 for Rice payload, got {x.dtype}" + ) + return x.detach().cpu().contiguous().numpy().tobytes() + raise TypeError(f"Unsupported idx payload type: {type(x)}") + + decoded_per_peer: list[torch.Tensor] = [] - # Unpack all 12-bit packed indices using values shape - unpacked_indices = [] for i in range(Ptot): - idx_data = idxs_all[i] if isinstance(idxs_all, list) else idxs_all - val_data = vals_all[i] if isinstance(vals_all, list) else vals_all + idx_data = idxs_all[i] if isinstance(idxs_all, (list, tuple)) else idxs_all + payload = _as_bytes(idx_data) - # 12-bit packed format - use values shape for unpacking - unpacked = unpack_12bit_indices( - idx_data.to(neuron.config.device), val_data.shape - ) - unpacked_indices.append(unpacked) + rows_i, _C_codec, N_rows = decode_batch_rows( + payload + ) # rows_i: list[list[int]] + if N_rows == 0: + # no rows for this param/peer → skip param entirely + decoded_per_peer = [] + break + + # ensure rectangular (constant k) + k0 = len(rows_i[0]) + if not all(len(r) == k0 for r in rows_i): + raise ValueError("Rice payload has variable k per row; unsupported.") + + decoded_per_peer.append(torch.tensor(rows_i, dtype=torch.int64)) + + if not decoded_per_peer: + continue - idxs_tensor = torch.stack(unpacked_indices, dim=0) - P, *chunk_dims, k = idxs_tensor.shape - C = int(torch.prod(torch.tensor(chunk_dims))) # num chunks - idxs_flat = idxs_tensor.reshape(P, C, k) + idxs_tensor = torch.stack(decoded_per_peer, dim=0) # [P, C, K] + P, C_chunks, k = idxs_tensor.shape + idxs_flat = idxs_tensor # already [P, C, K] - param_weight = C * k # size weight + param_weight = C_chunks * k # size weight for i in range(P): for j in range(i + 1, P): diff --git a/tests/unit/test_compress.py b/tests/unit/test_compress.py index c4f6933c3..1363436bf 100644 --- a/tests/unit/test_compress.py +++ b/tests/unit/test_compress.py @@ -13,7 +13,7 @@ _idct ) -from tplr.compression import encode_batch_rows, pack_12bit_indices, unpack_12bit_indices +from tplr.compression import encode_batch_rows_sorted, pack_12bit_indices, unpack_12bit_indices class TestTopKCompressor: @@ -70,7 +70,7 @@ def test_compress_with_quantization( assert qparams is not None assert len(qparams) == 5 # shift, scale, offset, lookup, orig_dtype - def test_decompress_2d_with_rice_bitmap_format( + def test_decompress_with_rice_bitmap_format( self, compress_instance: TopKCompressor[Literal[False]] ): """Test that decompress can handle Rice/bitmap encoded format""" @@ -84,12 +84,11 @@ def test_decompress_2d_with_rice_bitmap_format( original_indices = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=torch.int64) # Pack using the new encoder format - payload, _, _ = encode_batch_rows(original_indices, C=totalk) - idx = torch.tensor(np.frombuffer(payload, dtype=np.uint8), dtype=torch.uint8) + idx_bytes, _ = encode_batch_rows_sorted(original_indices, C=totalk) val = torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]], dtype=torch.float32) # Test decompression with packed format - result = compress_instance.decompress(p, idx, val, xshape, totalk) + result = compress_instance.decompress(p, idx_bytes, val, xshape, totalk) assert result.shape == xshape assert result.dtype == p.dtype @@ -107,15 +106,8 @@ def test_batch_decompress_multiple_rice_bitmap_formats( idx2_orig = torch.tensor([[4, 5], [6, 7]], dtype=torch.int64) # Pack them using the new encoder format - payload1, _, _ = encode_batch_rows(idx1_orig, C=totalk) - idx1_packed = torch.tensor( - np.frombuffer(payload1, dtype=np.uint8), dtype=torch.uint8 - ) - - payload2, _, _ = encode_batch_rows(idx2_orig, C=totalk) - idx2_packed = torch.tensor( - np.frombuffer(payload2, dtype=np.uint8), dtype=torch.uint8 - ) + idx1_packed, _ = encode_batch_rows_sorted(idx1_orig, C=totalk) + idx2_packed, _ = encode_batch_rows_sorted(idx2_orig, C=totalk) idx_list = [idx1_packed, idx2_packed] @@ -130,14 +122,16 @@ def test_batch_decompress_multiple_rice_bitmap_formats( assert result.shape == xshape assert result.dtype == p.dtype - def test_compress_decompress_round_trip_1d( + def test_compress_decompress_round_trip( self, compress_instance: TopKCompressor[Literal[False]] ): - x = torch.zeros(128,) # 1024 elements total, last dim=128 - x[0] = 1.0 - x[1] = 2.0 - x[2] = 3.0 - x[3] = 4.0 + """Test full compress-decompress round trip""" + x = torch.zeros(8, 128) # 1024 elements total, last dim=128 + x[0, 0] = 1.0 + x[1, 1] = 2.0 + x[2, 2] = 3.0 + x[3, 3] = 4.0 + topk = 4 idx, val, xshape, totalk = compress_instance.compress(x, topk) @@ -163,40 +157,35 @@ def test_compress_decompress_round_trip_1d( expected_vals = torch.tensor([4.0, 3.0, 2.0, 1.0]) assert torch.allclose(top_vals, expected_vals, atol=1e-5) - def test_compress_decompress_round_trip_2d( + def test_encode_compress_decompress_round_trip( self, compress_instance: TopKCompressor[Literal[False]] ): - """Test full compress-decompress round trip""" - x = torch.zeros(8, 128) # 1024 elements total, last dim=128 + class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self.layer1 = nn.Linear(32, 64) + self.layer2 = nn.Linear(64, 256) + + target_chunk = 16 + transform = ChunkingTransformer( SimpleModel(), target_chunk) + x = torch.zeros(256, 64) x[0, 0] = 1.0 x[1, 1] = 2.0 x[2, 2] = 3.0 x[3, 3] = 4.0 topk = 4 - - idx, val, xshape, totalk = compress_instance.compress(x, topk) - - # Verify we got the top-k values - assert idx.dtype == torch.uint8, ( - "Expected uint8 for Rice/bitmap encoded indices" + encoded = transform.encode(x) + idxs, vals, xshape, totalk = compress_instance.compress( + encoded, topk ) - assert val.shape[-1] == topk - # Decompress p = torch.zeros_like(x) - result = compress_instance.decompress(p, idx, val, xshape, totalk) - - # Verify shape - assert result.shape == x.shape - - # Verify the top values were preserved - assert result.abs().max() > 0, "Decompressed tensor should have non-zero values" + decompressed = compress_instance.decompress( + p, idxs, vals, xshape, totalk + ) + assert torch.allclose(encoded, decompressed, atol=1e-5) - # The top 4 values should be approximately 4, 3, 2, 1 - top_vals = torch.topk(result.abs().flatten(), k=4).values - expected_vals = torch.tensor([4.0, 3.0, 2.0, 1.0]) - assert torch.allclose(top_vals, expected_vals, atol=1e-5) def test_rice_bitmap_index_value_range( self, compress_instance: TopKCompressor[Literal[False]] @@ -235,10 +224,7 @@ def test_batch_decompress_with_norm_options( # Create test data with Rice/bitmap encoded format idx_orig = torch.tensor([[0, 1, 2, 3]], dtype=torch.int64) # Even count - payload, _, _ = encode_batch_rows(idx_orig, C=totalk) - idx_packed = torch.tensor( - np.frombuffer(payload, dtype=np.uint8), dtype=torch.uint8 - ) + idx_packed, _ = encode_batch_rows_sorted(idx_orig, C=totalk) idx = [idx_packed] val = [torch.tensor([[10.0, 20.0, 30.0, 40.0]], dtype=torch.float32)] From ddf7448f1bd21d055f1a64442788cbff9e4de4a0 Mon Sep 17 00:00:00 2001 From: Kasper Date: Fri, 14 Nov 2025 16:06:37 +0400 Subject: [PATCH 05/33] cleanup impl --- src/tplr/compression/hybrid.py | 105 +++++++++++---------------------- 1 file changed, 34 insertions(+), 71 deletions(-) diff --git a/src/tplr/compression/hybrid.py b/src/tplr/compression/hybrid.py index 9535d1fae..99cb3e7e1 100644 --- a/src/tplr/compression/hybrid.py +++ b/src/tplr/compression/hybrid.py @@ -15,7 +15,7 @@ def encode_batch_rows_sorted( *, C: int, B_choices: Tuple[int, ...] = (64, 128) -) -> Tuple[bytes, Dict]: +) -> Tuple[BytesLike, Dict]: """ Compresses a 2D tensor of Top-K indices into a byte string using a per-row adaptive Rice/Bitmap compression scheme on the GPU. @@ -61,12 +61,7 @@ def encode_batch_rows_sorted( # Calculate k_rice parameters (log2(C // B)) k_rice_choices = tuple(int(math.log2(C // b)) for b in B_choices) num_B_choices = len(B_choices) - - # Create tensors for dynamic kernel - B_choices_tensor = torch.tensor(B_choices, dtype=torch.int32, device=dev) k_rice_choices_tensor = torch.tensor(k_rice_choices, dtype=torch.int32, device=dev) - - # Bits needed to store the B_choice index B_choice_bits = (num_B_choices - 1).bit_length() # Row header: 1 bit (bitmap/rice) + B_choice_bits @@ -78,8 +73,6 @@ def encode_batch_rows_sorted( dim=1 ) - # --- 3. Kernel 1: Cost Analysis --- - # Output tensors for cost kernel costs = torch.empty((num_rows, num_B_choices), dtype=torch.int32, device=dev) is_bitmap = torch.empty((num_rows, num_B_choices), dtype=torch.int8, device=dev) @@ -90,15 +83,13 @@ def encode_batch_rows_sorted( delta, costs, is_bitmap, - C=C, k_dim=k_dim, num_rows=num_rows, num_B_choices=num_B_choices, - B_choices_ptr=B_choices_tensor, k_rice_choices_ptr=k_rice_choices_tensor, ) - # --- 4. Post-Kernel 1: pick best B/mode & compute layout --- + # pick best B/mode & compute layout # Best choice per row min_costs, best_B_idx = torch.min(costs, dim=1) @@ -147,7 +138,6 @@ def encode_batch_rows_sorted( list(global_header_py), dtype=torch.uint8, device=dev ) - # --- pack kernel --- pack_kernel[(num_rows,)]( delta, payload_buf, @@ -155,23 +145,19 @@ def encode_batch_rows_sorted( row_payload_bytes, # already int32 best_B_idx.to(torch.int32), is_bitmap_choice, # int32 0/1 - B_choices_tensor, k_rice_choices_tensor, num_rows, - C=C, k_dim=k_dim, ROW_HEADER_BITS=ROW_HEADER_BITS, ) payload_cpu = payload_buf.cpu() - b_counts = torch.bincount(best_B_idx, minlength=len(B_choices)) B_hist = {b: c.item() for b, c in zip(B_choices, b_counts)} meta = { "total_bits": total_bits, # includes 16-bit length and byte padding "avg_bits_per_row": float(row_bits_aligned.float().mean().item()), "avg_payload_bits_per_row": float(row_payload_bits.float().mean().item()), - # header+payload, no 16-bit length, before byte-rounding "B_hist": B_hist, } return payload_cpu, meta @@ -182,12 +168,10 @@ def cost_kernel( delta_ptr, # (rows, k_dim) IN costs_ptr, # (rows, num_B_choices) OUT is_bitmap_ptr, # (rows, num_B_choices) OUT (bool/int) - C: tl.int32, k_dim: tl.constexpr, # constexpr for tl.arange num_rows: tl.int32, num_B_choices: tl.int32, - B_choices_ptr, # IN (tensor) - k_rice_choices_ptr, # IN (tensor) + k_rice_choices_ptr, # (num_B_choices,) int32 ): """ Calculates the compressed bit cost for each row for each B in B_choices. @@ -204,31 +188,22 @@ def cost_kernel( # Load the entire row of delta-encoded values into SRAM row_base = row_idx * k_dim delta = tl.load(delta_ptr + row_base + i) - - # Also load the first delta as a scalar for q0 (avoids illegal q[0] indexing) delta0 = tl.load(delta_ptr + row_base) - # Iterate over B choices dynamically b_idx = 0 while b_idx < num_B_choices: - # Dynamic parameters for this choice - B = tl.load(B_choices_ptr + b_idx) + # k_rice and M = 1 << k_rice k_rice = tl.load(k_rice_choices_ptr + b_idx) - # Rice modulus - M = C // B - - # Vectorized q for the row and scalar q0 for the first element - q = delta // M - q0 = delta0 // M + # q via shift, r via mask + q = delta >> k_rice + q0 = delta0 >> k_rice # Pure Rice cost: sum(unary(q)) + sum(r) where unary(q) has (q + 1) bits, # and r contributes k_rice bits per element. rice_cost = tl.sum(q + 1) + k_dim * k_rice - # Variant B bitmap cost: - # - first element written with full Rice: (q0 + 1 + k_rice) - # - remaining (k_dim - 1) elements written as (1 + k_rice) each + # Bitmap cost: first element full Rice, tail has (1 + k_rice) bits # (1 bit for q in {0,1} + k_rice bits for r) bitmap_cost = (q0 + 1 + k_rice) + (k_dim - 1) * (1 + k_rice) # equivalently: bitmap_cost = k_dim * (1 + k_rice) + q0 @@ -244,12 +219,13 @@ def cost_kernel( out_offset = row_idx * num_B_choices + b_idx tl.store(costs_ptr + out_offset, min_cost) # make sure is_bitmap is exactly 0/1 in memory - tl.store(is_bitmap_ptr + out_offset, tl.where(use_bitmap, 1, 0)) + tl.store( + is_bitmap_ptr + out_offset, + tl.where(use_bitmap, 1, 0).to(tl.int32), + ) b_idx += 1 -# --- Triton Kernel 2: Bit-Stream Packing --- - @triton.jit def write_nbits(u8_ptr, bit_off_i32, value_u32, nbits_i32): """ @@ -264,20 +240,18 @@ def write_nbits(u8_ptr, bit_off_i32, value_u32, nbits_i32): ONE_U32 = tl.full((), 1, dtype=tl.uint32) while j < nbits_i32: - pos = bit_off_i32 + j + pos = bit_off_i32 + j byte_idx = (pos >> 3).to(tl.int32) - bit_idx = (pos & 7).to(tl.int32) + bit_idx = (pos & 7).to(tl.int32) - old_u8 = tl.load(u8_ptr + byte_idx) - old_u32 = old_u8.to(tl.uint32) + old_u8 = tl.load(u8_ptr + byte_idx) + old_u32 = old_u8.to(tl.uint32) - vbit = (value_u32 >> j) & ONE_U32 - mask = ONE_U32 << bit_idx - new_u32 = (old_u32 & (~mask)) | (vbit << bit_idx) + vbit = (value_u32 >> j) & ONE_U32 + mask = ONE_U32 << bit_idx + new_u32 = (old_u32 & (~mask)) | (vbit << bit_idx) tl.store(u8_ptr + byte_idx, new_u32.to(tl.uint8)) - j += 1 - return bit_off_i32 + nbits_i32 @@ -289,10 +263,8 @@ def pack_kernel( row_payload_bytes_ptr, # (rows,) IN int32 best_B_idx_ptr, # (rows,) IN int32 is_bitmap_ptr, # (rows,) IN int32 (0/1) - B_choices_ptr, # [num_B] IN int32 k_rice_choices_ptr, # [num_B] IN int32 num_rows: tl.int32, - C: tl.constexpr, k_dim: tl.int32, # dynamic ROW_HEADER_BITS: tl.constexpr, ): @@ -305,20 +277,19 @@ def pack_kernel( return # per-row meta - bit_off_i32 = tl.load(row_bit_offsets_ptr + row_idx).to(tl.int32) + bit_off_i32 = tl.load(row_bit_offsets_ptr + row_idx).to(tl.int32) payload_bytes_i32 = tl.load(row_payload_bytes_ptr + row_idx).to(tl.int32) - b_idx_i32 = tl.load(best_B_idx_ptr + row_idx).to(tl.int32) - use_bitmap_i32 = (tl.load(is_bitmap_ptr + row_idx) & 1).to(tl.int32) + b_idx_i32 = tl.load(best_B_idx_ptr + row_idx).to(tl.int32) + use_bitmap_i32 = (tl.load(is_bitmap_ptr + row_idx) & 1).to(tl.int32) # params - B_i32 = tl.load(B_choices_ptr + b_idx_i32).to(tl.int32) k_rice_i32 = tl.load(k_rice_choices_ptr + b_idx_i32).to(tl.int32) - M_i32 = (C // B_i32).to(tl.int32) + M_i32 = (tl.full((), 1, dtype=tl.int32) << k_rice_i32) ONE_U32 = tl.full((), 1, dtype=tl.uint32) ZERO_U32 = tl.full((), 0, dtype=tl.uint32) - ONE_I32 = tl.full((), 1, dtype=tl.int32) - THIRTY_ONE_I32 = tl.full((), 31, dtype=tl.int32) # **cap chunks at 31** + ONE_I32 = tl.full((), 1, dtype=tl.int32) + THIRTY_ONE_I32 = tl.full((), 31, dtype=tl.int32) # 16-bit length bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, @@ -336,8 +307,8 @@ def pack_kernel( # ---- first delta: ALWAYS Rice ---- if k_dim > 0: v0 = tl.load(delta_ptr + base).to(tl.int32) - q0 = (v0 // M_i32).to(tl.int32) - r0 = (v0 % M_i32).to(tl.int32) + q0 = (v0 >> k_rice_i32).to(tl.int32) + r0 = (v0 & (M_i32 - 1)).to(tl.int32) # q0 ones in chunks of <=31, then a single 0 q_left = q0 @@ -346,17 +317,15 @@ def pack_kernel( ones = (ONE_U32 << chunk) - ONE_U32 bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, ones, chunk) q_left -= chunk - bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, ZERO_U32, ONE_I32) + bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, ZERO_U32, ONE_I32) # terminating 0 + bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, r0.to(tl.uint32), k_rice_i32) # remainder - # remainder - bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, r0.to(tl.uint32), k_rice_i32) - - # ---- tail ---- + # tail deltas i = 1 while i < k_dim: v = tl.load(delta_ptr + base + i).to(tl.int32) - q = (v // M_i32).to(tl.int32) - r = (v % M_i32).to(tl.int32) + q = (v >> k_rice_i32).to(tl.int32) + r = (v & (M_i32 - 1)).to(tl.int32) # Rice unary only if NOT bitmap q_left = tl.where(use_bitmap_i32 != 0, tl.full((), 0, dtype=tl.int32), q) @@ -375,14 +344,9 @@ def pack_kernel( # remainder always bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, r.to(tl.uint32), k_rice_i32) - i += 1 - -# --------------------------- -# Bitstream reader (LSB-first) -# --------------------------- class BitStreamReader: """ LSB-first bit reader over a bytes-like buffer (torch.uint8, np.uint8, or Python bytes). @@ -601,9 +565,8 @@ def decode_batch_rows(payload: BytesLike) -> tuple[list[list[int]], int, int]: idx = torch.topk(x.abs(), k=K, dim=-1, largest=True, sorted=False).indices idx, _ = torch.sort(idx, dim=1) - payload, _ = encode_batch_rows(idx, C=COLS, B_choices=(64, 128, 256)) - decoded = decode_batch_rows(payload) - dec = [torch.tensor(r, dtype=torch.int64) for r in decoded] + payload, _ = encode_batch_rows_sorted(idx, C=COLS, B_choices=(64, 128, 256)) + decoded, _, _ = decode_batch_rows(payload) ok = True idx = [row for row in idx] for r in range(ROWS): From a152b24e746a3e539c792c2135a14db7d8061dc2 Mon Sep 17 00:00:00 2001 From: Kasper Date: Sat, 15 Nov 2025 11:34:07 +0400 Subject: [PATCH 06/33] full gpu encoder plus optional delta --- src/tplr/compress.py | 7 +- src/tplr/compression/__init__.py | 4 +- src/tplr/compression/hybrid.py | 769 +++++++++++++++++++++---------- tests/unit/test_compress.py | 10 +- 4 files changed, 538 insertions(+), 252 deletions(-) diff --git a/src/tplr/compress.py b/src/tplr/compress.py index 41927497a..de3b7e9af 100644 --- a/src/tplr/compress.py +++ b/src/tplr/compress.py @@ -31,7 +31,7 @@ import tplr -from tplr.compression import encode_batch_rows_sorted, decode_batch_rows +from tplr.compression import encode_batch_rows, decode_batch_rows # ─────────── type aliases ──────────────────────────────────────────────── # primitive shapes @@ -329,7 +329,7 @@ def compress(self, x: torch.Tensor, topk: int): # type: ignore[override] # sort indices and apply same perm to values idx_sorted, perm = torch.sort(idx2d, dim=1) val = torch.gather(val2d, dim=1, index=perm) - idx_bytes, _meta = encode_batch_rows_sorted(idx_sorted, C=totalk, B_choices=_DEFAULT_B_CHOICES) + idx_bytes, _meta = encode_batch_rows(idx_sorted, C=totalk, B_choices=_DEFAULT_B_CHOICES) # Apply 8-bit quantization if enabled if self.use_quantization: @@ -372,8 +372,7 @@ def decompress( # Decode indices if idx.dtype == torch.uint8: - payload_bytes = idx.detach().cpu().numpy().tobytes() - rows_list, C, _N = decode_batch_rows(payload_bytes) + rows_list, C, _N = decode_batch_rows(idx) if C != totalk: raise ValueError(f"Index payload C={C} but expected {totalk}") k = val.shape[-1] diff --git a/src/tplr/compression/__init__.py b/src/tplr/compression/__init__.py index 732c8483e..3f0f3a875 100644 --- a/src/tplr/compression/__init__.py +++ b/src/tplr/compression/__init__.py @@ -18,7 +18,7 @@ from .hybrid import ( decode_batch_rows, # decoder (CPU) - encode_batch_rows_sorted, # GPU-accelerated encoder → bytes + perm + meta + encode_batch_rows, # GPU-accelerated encoder → bytes + perm + meta ) from .pack12 import ( @@ -27,7 +27,7 @@ ) __all__ = [ # High level - "encode_batch_rows_sorted", + "encode_batch_rows", "decode_batch_rows", "pack_12bit_indices", "unpack_12bit_indices", diff --git a/src/tplr/compression/hybrid.py b/src/tplr/compression/hybrid.py index 99cb3e7e1..22c065ca3 100644 --- a/src/tplr/compression/hybrid.py +++ b/src/tplr/compression/hybrid.py @@ -1,6 +1,6 @@ import math from typing import Dict -from typing import List, Tuple, Union +from typing import Tuple, Union import numpy as np import torch @@ -10,18 +10,28 @@ BytesLike = Union[bytes, bytearray, np.ndarray, torch.Tensor] -def encode_batch_rows_sorted( - idx_sorted: torch.Tensor, +@torch.no_grad() +def encode_batch_rows( + idx: torch.Tensor, *, C: int, + use_delta: bool = True, B_choices: Tuple[int, ...] = (64, 128) ) -> Tuple[BytesLike, Dict]: """ - Compresses a 2D tensor of Top-K indices into a byte string + Compresses a 2D int64 tensor of Top-K indices into a byte string using a per-row adaptive Rice/Bitmap compression scheme on the GPU. + Layout: + 0..3 : "CGRP" (magic) + 4..7 : C (uint32 LE) + 8..9 : K (uint16 LE) + 10..13 : R (uint32 LE, num_rows) + 14 : num_B (uint8) + 15.. : B_choices (num_B * uint16 LE) + Args: - idx_sorted (torch.Tensor): [rows, k] sorted tensor of indices. + idx (torch.Tensor): [rows, k] int64 tensor of indices. C (int): The total number of columns (0 <= idx < C). B_choices (tuple[int, ...]): Block sizes to evaluate. Must be powers of two. @@ -36,8 +46,8 @@ def encode_batch_rows_sorted( if not torch.cuda.is_available(): raise RuntimeError("CUDA is required for this function.") - if not isinstance(idx_sorted, torch.Tensor) or idx_sorted.ndim != 2: - raise ValueError(f"idx must be a 2D int64 tensor, got {idx_sorted.shape}") + if not isinstance(idx, torch.Tensor) or idx.ndim != 2: + raise ValueError(f"idx must be a 2D int64 tensor, got {idx.shape} {idx.dtype}") if not all(isinstance(b, int) and (b & (b - 1) == 0) and b > 0 for b in B_choices): raise ValueError(f"All B_choices must be powers of two, got {B_choices}") @@ -45,7 +55,7 @@ def encode_batch_rows_sorted( if not all(C % b == 0 for b in B_choices): raise ValueError(f"All B_choices must evenly divide C={C}, got {B_choices}") - num_rows, k_dim = idx_sorted.shape + num_rows, k_dim = idx.shape if num_rows == 0: return b"", { "total_bits": 0, @@ -53,10 +63,22 @@ def encode_batch_rows_sorted( "B_hist": {b: 0 for b in B_choices} } - if not idx_sorted.is_cuda: - idx_sorted = idx_sorted.cuda() - idx_sorted = idx_sorted.contiguous() - dev = idx_sorted.device + if not idx.is_cuda: + idx = idx.cuda() + idx = idx.contiguous() + dev = idx.device + + if use_delta: + # v[0], v[1]-v[0], v[2]-v[1], ... + vals = torch.cat( + (idx[:, :1], idx[:, 1:] - idx[:, :-1]), + dim=1, + ) + else: + vals = idx + + # Cast to int32 for Triton kernels + vals = vals.to(torch.int32) # Calculate k_rice parameters (log2(C // B)) k_rice_choices = tuple(int(math.log2(C // b)) for b in B_choices) @@ -67,20 +89,15 @@ def encode_batch_rows_sorted( # Row header: 1 bit (bitmap/rice) + B_choice_bits ROW_HEADER_BITS = 1 + B_choice_bits - # Delta encode: val[0], val[1]-val[0], val[2]-val[1], ... - delta = torch.cat( - (idx_sorted[:, :1], idx_sorted[:, 1:] - idx_sorted[:, :-1]), - dim=1 - ) - # Output tensors for cost kernel costs = torch.empty((num_rows, num_B_choices), dtype=torch.int32, device=dev) is_bitmap = torch.empty((num_rows, num_B_choices), dtype=torch.int8, device=dev) grid = (num_rows,) + # Launch cost kernel # k_dim is passed as constexpr for tl.arange, but B_choices are dynamic cost_kernel[grid]( - delta, + vals, costs, is_bitmap, k_dim=k_dim, @@ -116,14 +133,16 @@ def encode_batch_rows_sorted( # Build global header bytes header_list = [] header_list.append(b"CGRP") # 4B magic - header_list.append(int(C).to_bytes(4, "little")) # 4B C (uint32 LE) - header_list.append(int(k_dim).to_bytes(2, "little")) # 2B K (uint16 LE) <--- NEW - header_list.append(bytes([len(B_choices)])) # 1B num_B + header_list.append(int(C).to_bytes(4, "little")) # 4B C (uint32 LE) + header_list.append(int(k_dim).to_bytes(2, "little")) # 2B K (uint16 LE) + header_list.append(int(num_rows).to_bytes(4, "little")) # 4B R (uint32 LE) NEW + header_list.append(bytes([len(B_choices)])) # 1B num_B for b in B_choices: - header_list.append(int(b).to_bytes(2, "little")) # 2B per B (uint16 LE) + header_list.append(int(b).to_bytes(2, "little")) # 2B per B (uint16 LE) global_header_py = b"".join(header_list) global_header_len_bytes = len(global_header_py) + # this is 15 + 2 * len(B_choices) # shift row starts by header row_bit_offsets = row_bit_offsets + global_header_len_bytes * 8 @@ -139,7 +158,7 @@ def encode_batch_rows_sorted( ) pack_kernel[(num_rows,)]( - delta, + vals, payload_buf, row_bit_offsets.to(torch.int32), row_payload_bytes, # already int32 @@ -151,16 +170,16 @@ def encode_batch_rows_sorted( ROW_HEADER_BITS=ROW_HEADER_BITS, ) - payload_cpu = payload_buf.cpu() b_counts = torch.bincount(best_B_idx, minlength=len(B_choices)) B_hist = {b: c.item() for b, c in zip(B_choices, b_counts)} meta = { "total_bits": total_bits, # includes 16-bit length and byte padding "avg_bits_per_row": float(row_bits_aligned.float().mean().item()), "avg_payload_bits_per_row": float(row_payload_bits.float().mean().item()), + # header+payload, no 16-bit length, before byte-rounding "B_hist": B_hist, } - return payload_cpu, meta + return payload_buf, meta @triton.jit @@ -185,7 +204,7 @@ def cost_kernel( # Lane indices for this row (constexpr width) i = tl.arange(0, k_dim) - # Load the entire row of delta-encoded values into SRAM + # Load entire row of delta-encoded values into SRAM row_base = row_idx * k_dim delta = tl.load(delta_ptr + row_base + i) delta0 = tl.load(delta_ptr + row_base) @@ -199,17 +218,13 @@ def cost_kernel( q = delta >> k_rice q0 = delta0 >> k_rice - # Pure Rice cost: sum(unary(q)) + sum(r) where unary(q) has (q + 1) bits, - # and r contributes k_rice bits per element. + # Pure Rice cost: sum(q + 1) + k_dim * k_rice rice_cost = tl.sum(q + 1) + k_dim * k_rice # Bitmap cost: first element full Rice, tail has (1 + k_rice) bits - # (1 bit for q in {0,1} + k_rice bits for r) bitmap_cost = (q0 + 1 + k_rice) + (k_dim - 1) * (1 + k_rice) - # equivalently: bitmap_cost = k_dim * (1 + k_rice) + q0 # Allow bitmap only if tail q are in {0,1} - # Compute tail max with a masked reduction (ignore lane 0) q_tail_max = tl.max(tl.where(i > 0, q, 0)) bitmap_allowed = q_tail_max <= 1 @@ -227,14 +242,18 @@ def cost_kernel( @triton.jit -def write_nbits(u8_ptr, bit_off_i32, value_u32, nbits_i32): +def write_nbits( + u8_ptr, # uint8* global buffer + bit_off_i32, # scalar tl.int32 bit offset + value_u32, # scalar tl.uint32, up to 32 bits used + nbits_i32, # scalar tl.int32, number of bits to write +): """ - Write `nbits_i32` bits from `value_u32` at bit offset `bit_off_i32` (LSB-first). - All args are Triton scalars: - - bit_off_i32 : tl.int32 - - value_u32 : tl.uint32 (<= 32 bits) - - nbits_i32 : tl.int32 - Returns new bit offset (tl.int32). + Writes `nbits_i32` least-significant bits of `value_u32` into `u8_ptr` + starting at bit offset `bit_off_i32` in LSB-first order. + + This is still a bit-at-a-time writer; higher-level kernels have been + adjusted to use int32 + shift/mask ahead of time. """ j = tl.full((), 0, dtype=tl.int32) ONE_U32 = tl.full((), 1, dtype=tl.uint32) @@ -269,7 +288,7 @@ def pack_kernel( ROW_HEADER_BITS: tl.constexpr, ): """ - Variant B: first delta Rice (unary = q ones then 0) + r; tail bitmap or Rice. + First delta Rice (unary = q ones then 0) + r; tail bitmap or Rice. Bit order: LSB-first. """ row_idx = tl.program_id(0) @@ -304,7 +323,7 @@ def pack_kernel( base = row_idx * k_dim - # ---- first delta: ALWAYS Rice ---- + # first delta: ALWAYS Rice if k_dim > 0: v0 = tl.load(delta_ptr + base).to(tl.int32) q0 = (v0 >> k_rice_i32).to(tl.int32) @@ -314,7 +333,7 @@ def pack_kernel( q_left = q0 while q_left > 0: chunk = tl.minimum(q_left, THIRTY_ONE_I32) - ones = (ONE_U32 << chunk) - ONE_U32 + ones = (ONE_U32 << chunk) - ONE_U32 bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, ones, chunk) q_left -= chunk bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, ZERO_U32, ONE_I32) # terminating 0 @@ -331,14 +350,14 @@ def pack_kernel( q_left = tl.where(use_bitmap_i32 != 0, tl.full((), 0, dtype=tl.int32), q) while q_left > 0: chunk = tl.minimum(q_left, THIRTY_ONE_I32) - ones = (ONE_U32 << chunk) - ONE_U32 + ones = (ONE_U32 << chunk) - ONE_U32 bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, ones, chunk) q_left -= chunk n_term = tl.where(use_bitmap_i32 != 0, tl.full((), 0, dtype=tl.int32), ONE_I32) bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, ZERO_U32, n_term) # bitmap q only if bitmap - q_bit = tl.where(q > 0, ONE_U32, ZERO_U32) + q_bit = tl.where(q > 0, ONE_U32, ZERO_U32) n_qbit = tl.where(use_bitmap_i32 != 0, ONE_I32, tl.full((), 0, dtype=tl.int32)) bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, q_bit, n_qbit) @@ -347,213 +366,480 @@ def pack_kernel( i += 1 -class BitStreamReader: +@triton.jit +def read_nbits_triton(u8_ptr, bit_off_i32, nbits_i32, limit_bit_i32): """ - LSB-first bit reader over a bytes-like buffer (torch.uint8, np.uint8, or Python bytes). - - read_bits(n): reads n bits, returning an integer whose bit j is the j-th bit read. - - read_unary_bounded(end_bit): reads '1's until a '0' or end_bit; returns (q, hit_end) + GPU version of BitStreamReader.read_bits (LSB-first), but bounds-safe. + + Reads `nbits_i32` bits starting at `bit_off_i32`, but never loads beyond + bit index `limit_bit_i32` (masked loads return 0 out-of-bounds). + + Returns: (value_u32, new_bit_off_i32) """ - __slots__ = ("buf", "total_bits", "bit_off") - - def __init__(self, payload: BytesLike, bit_offset_start: int = 0): - if isinstance(payload, torch.Tensor): - assert payload.dtype == torch.uint8 - self.buf = payload.cpu().numpy().tobytes() - elif isinstance(payload, np.ndarray): - assert payload.dtype == np.uint8 - self.buf = payload.tobytes() - elif isinstance(payload, (bytes, bytearray)): - self.buf = bytes(payload) - else: - raise TypeError("Unsupported payload type for BitStreamReader") - - self.total_bits = len(self.buf) * 8 - self.bit_off = int(bit_offset_start) - - def read_bits(self, n_bits: int) -> int: - """Read n_bits in LSB-first order; returns value with bit j equal to j-th bit read.""" - if n_bits == 0: - return 0 - if self.bit_off + n_bits > self.total_bits: - raise EOFError("Attempt to read past end of bitstream") - val = 0 - start = self.bit_off - for j in range(n_bits): - pos = start + j - b = self.buf[pos >> 3] - bit = (b >> (pos & 7)) & 1 - val |= (bit << j) - self.bit_off = start + n_bits - return val - - def read_unary_bounded(self, end_bit: int) -> Tuple[int, bool]: - """ - Read unary as q ones followed by a single 0, *bounded* by end_bit. - Returns: (q, hit_end) - - q: number of 1s seen before the terminating 0 - - hit_end: True if we reached end_bit before seeing the terminating 0 - """ - q = 0 - while self.bit_off < end_bit: - bit = self.read_bits(1) - if bit == 1: - q += 1 - else: - return q, False - return q, True # ran out of bits without seeing the terminating 0 + j = tl.full((), 0, dtype=tl.int32) + val_u32 = tl.full((), 0, dtype=tl.uint32) + ONE_U32 = tl.full((), 1, dtype=tl.uint32) + ZERO_U8 = tl.full((), 0, dtype=tl.uint8) + + while j < nbits_i32: + pos = bit_off_i32 + j + in_bounds = pos < limit_bit_i32 + + byte_idx = (pos >> 3).to(tl.int32) + bit_idx = (pos & 7).to(tl.int32) + + # Masked load: if in_bounds==0, we load ZERO_U8 instead of touching memory. + u8 = tl.load(u8_ptr + byte_idx, mask=in_bounds, other=ZERO_U8) + u32 = u8.to(tl.uint32) + bit = (u32 >> bit_idx) & ONE_U32 + + val_u32 |= (bit << j) + j += 1 + + new_bit_off = bit_off_i32 + nbits_i32 + return val_u32, new_bit_off + + +@triton.jit +def read_unary_bounded_triton(u8_ptr, bit_off_i32, end_bit_i32): + """ + GPU version of BitStreamReader.read_unary_bounded(end_bit). + Reads '1's until a '0' or end_bit. + Returns: (q_i32, new_bit_off_i32, hit_end_i32) + - q_i32: number of 1s before the terminating 0 + - hit_end_i32: 1 if we reached end_bit without seeing 0 + 0 if we saw a terminating 0 + """ + ONE_U32 = tl.full((), 1, dtype=tl.uint32) + q_i32 = tl.full((), 0, dtype=tl.int32) + hit_end_i32 = tl.full((), 1, dtype=tl.int32) + + cond = bit_off_i32 < end_bit_i32 + while cond: + pos = bit_off_i32 + byte_idx = (pos >> 3).to(tl.int32) + bit_idx = (pos & 7).to(tl.int32) + + u8 = tl.load(u8_ptr + byte_idx) + u32 = u8.to(tl.uint32) + bit = (u32 >> bit_idx) & ONE_U32 - def bits_remaining(self) -> int: - return self.total_bits - self.bit_off + bit_off_i32 += 1 - def is_at_end(self) -> bool: - """True if only up to 7 padding bits remain globally.""" - return self.bit_off >= self.total_bits - 7 + is_one = (bit == ONE_U32) + q_i32 += is_one.to(tl.int32) + # If bit is 0, we did NOT hit end + hit_end_i32 = tl.where(is_one, hit_end_i32, + tl.full((), 0, dtype=tl.int32)) -def _parse_global_header(payload: BytesLike) -> Tuple[int, int, list[int], int]: + # Continue only if we are still inside the row and last bit was 1 + cond = (bit_off_i32 < end_bit_i32) & is_one + + return q_i32, bit_off_i32, hit_end_i32 + + +@triton.jit +def parse_header_kernel( + u8_payload_ptr, # (total_bytes,) uint8 + C_out_ptr, # (1,) int32 + K_out_ptr, # (1,) int32 + R_out_ptr, # (1,) int32 NEW: num_rows + num_B_out_ptr, # (1,) int32 + B_choices_out_ptr, # (MAX_B_CHOICES,) int32 + header_bytes_out_ptr, # (1,) int32 + error_flag_ptr, # (1,) int32 + total_bytes: tl.int32, + MAX_B_CHOICES: tl.constexpr, +): """ + Parse the global header entirely on GPU. + Layout: - 4B "CGRP" - 4B C (uint32 LE) - 2B K (uint16 LE) - 1B num_B - 2B * num_B (each B, uint16 LE) - Returns: (C, K, B_choices, header_end_bit_offset) + 0..3 : "CGRP" + 4..7 : C (uint32 LE) + 8..9 : K (uint16 LE) + 10..13 : R (uint32 LE, num_rows) + 14 : num_B (uint8) + 15.. : B_choices (num_B * 2 bytes, uint16 LE) + """ + + pid = tl.program_id(0) + if pid != 0: + return + + # ---- init outputs / error ---- + C_val = tl.full((), 0, dtype=tl.int32) + K_val = tl.full((), 0, dtype=tl.int32) + R_val = tl.full((), 0, dtype=tl.int32) + num_B_val = tl.full((), 0, dtype=tl.int32) + header_bytes_i32 = tl.full((), 0, dtype=tl.int32) + err = tl.full((), 0, dtype=tl.int32) + + # ---- basic size + magic checks ---- + # Minimum header size: 15 bytes (without B_choices) + if total_bytes < 15: + err = 1 + else: + # Magic "CGRP" = [67, 71, 82, 80] + m0 = tl.load(u8_payload_ptr + 0) + m1 = tl.load(u8_payload_ptr + 1) + m2 = tl.load(u8_payload_ptr + 2) + m3 = tl.load(u8_payload_ptr + 3) + cond_magic = (m0 == 67) & (m1 == 71) & (m2 == 82) & (m3 == 80) + bad_magic = cond_magic == 0 + err = tl.where(bad_magic, tl.full((), 2, dtype=tl.int32), err) + + # ---- C, K, R, num_B ---- + if err == 0: + # C (uint32 LE at bytes 4..7) + b4 = tl.load(u8_payload_ptr + 4).to(tl.int32) + b5 = tl.load(u8_payload_ptr + 5).to(tl.int32) + b6 = tl.load(u8_payload_ptr + 6).to(tl.int32) + b7 = tl.load(u8_payload_ptr + 7).to(tl.int32) + C_val = b4 | (b5 << 8) | (b6 << 16) | (b7 << 24) + + # K (uint16 LE at bytes 8..9) + b8 = tl.load(u8_payload_ptr + 8).to(tl.int32) + b9 = tl.load(u8_payload_ptr + 9).to(tl.int32) + K_val = b8 | (b9 << 8) + + # R (uint32 LE at bytes 10..13) + b10 = tl.load(u8_payload_ptr + 10).to(tl.int32) + b11 = tl.load(u8_payload_ptr + 11).to(tl.int32) + b12 = tl.load(u8_payload_ptr + 12).to(tl.int32) + b13 = tl.load(u8_payload_ptr + 13).to(tl.int32) + R_val = b10 | (b11 << 8) | (b12 << 16) | (b13 << 24) + + # num_B at byte 14 + num_B_val = tl.load(u8_payload_ptr + 14).to(tl.int32) + invalid_num_B = (num_B_val <= 0) | (num_B_val > MAX_B_CHOICES) + err = tl.where(invalid_num_B, tl.full((), 3, dtype=tl.int32), err) + + # ---- read B_choices in a structured loop (no break/return) ---- + off = tl.full((), 15, dtype=tl.int32) # B_choices start at byte 15 + i = tl.full((), 0, dtype=tl.int32) + + while i < MAX_B_CHOICES: + need_this = (i < num_B_val) & (err == 0) + + if need_this: + cond_in_bounds = (off + 1) < total_bytes + if cond_in_bounds: + lo = tl.load(u8_payload_ptr + off).to(tl.int32) + hi = tl.load(u8_payload_ptr + off + 1).to(tl.int32) + B_val = lo | (hi << 8) + tl.store(B_choices_out_ptr + i, B_val) + off += 2 + else: + err = tl.full((), 4, dtype=tl.int32) + tl.store(B_choices_out_ptr + i, tl.full((), 0, dtype=tl.int32)) + else: + tl.store(B_choices_out_ptr + i, tl.full((), 0, dtype=tl.int32)) + + i += 1 + + # header_bytes = 15 + 2 * num_B (only meaningful if err == 0) + if err == 0: + header_bytes_i32 = 15 + (num_B_val * 2) + + # ---- store outputs ---- + tl.store(C_out_ptr, C_val) + tl.store(K_out_ptr, K_val) + tl.store(R_out_ptr, R_val) + tl.store(num_B_out_ptr, num_B_val) + tl.store(header_bytes_out_ptr, header_bytes_i32) + tl.store(error_flag_ptr, err) + + +@triton.jit +def scan_rows_kernel( + u8_payload_ptr, # (total_bytes,) uint8 + row_bit_offsets_ptr, # (num_rows,) int32 (bit offset of 16-bit length) + row_payload_bytes_ptr, # (num_rows,) int32 + best_B_idx_ptr, # (num_rows,) int32 + use_bitmap_ptr, # (num_rows,) int32 (0/1) + header_end_bit: tl.int32, + total_bits: tl.int32, + num_rows: tl.int32, + ROW_HEADER_BITS: tl.constexpr, +): + """ + Sequential scan of all rows (1 program). For each row r: + + bit_off: bit offset of 16-bit payload length + length: row_payload_bytes[r] + header: ((b_idx << 1) | use_bitmap) in ROW_HEADER_BITS bits + rest: payload_bits - ROW_HEADER_BITS bits + + Assumes the bitstream is valid and has enough bits for num_rows rows. """ + pid = tl.program_id(0) + if pid != 0: + return + + bit_off_i32 = header_end_bit + r = tl.full((), 0, dtype=tl.int32) + SIXTEEN_I32 = tl.full((), 16, dtype=tl.int32) + + while r < num_rows: + # bit offset at the start of the 16-bit length for this row + tl.store(row_bit_offsets_ptr + r, bit_off_i32) + + # read 16-bit payload length (bytes) + length_u32, bit_off_after_len = read_nbits_triton( + u8_payload_ptr, bit_off_i32, SIXTEEN_I32, total_bits + ) + length_i32 = length_u32.to(tl.int32) + tl.store(row_payload_bytes_ptr + r, length_i32) + + # read row header bits: ((b_idx << 1) | use_bitmap) + header_u32, bit_off_after_header = read_nbits_triton( + u8_payload_ptr, + bit_off_after_len, + tl.full((), ROW_HEADER_BITS, dtype=tl.int32), + total_bits, + ) + header_i32 = header_u32.to(tl.int32) + use_bitmap_i32 = header_i32 & 1 + best_B_idx_i32 = header_i32 >> 1 + + tl.store(best_B_idx_ptr + r, best_B_idx_i32) + tl.store(use_bitmap_ptr + r, use_bitmap_i32) + + # skip remainder of this row's payload + payload_bits_i32 = length_i32 * 8 + rem_bits_i32 = payload_bits_i32 - ROW_HEADER_BITS + bit_off_i32 = bit_off_after_header + rem_bits_i32 + r += 1 + + +@triton.jit +def decode_rows_kernel( + u8_payload_ptr, # (total_bytes,) uint8 + out_vals_ptr, # (num_rows * K,) int32 + row_bit_offsets_ptr, # (num_rows,) int32 (bit offset of 16-bit length) + row_payload_bytes_ptr, # (num_rows,) int32 + best_B_idx_ptr, # (num_rows,) int32 + use_bitmap_ptr, # (num_rows,) int32 + k_rice_choices_ptr, # (num_B,) int32 + num_rows: tl.int32, + K: tl.int32, + ROW_HEADER_BITS: tl.constexpr, +): + """ + Fully GPU decode of Rice/bitmap rows. + + For each row: + - Start at bit offset of 16-bit length + - Skip 16-bit length + - Skip header bits (we already know b_idx/use_bitmap from scan) + - First value: full Rice (unary + remainder) + - Tail: Rice or bitmap+remainder + """ + row_idx = tl.program_id(0) + if row_idx >= num_rows: + return + + # Per-row metadata + row_start_bit_i32 = tl.load(row_bit_offsets_ptr + row_idx).to(tl.int32) + payload_bytes_i32 = tl.load(row_payload_bytes_ptr + row_idx).to(tl.int32) + best_B_idx_i32 = tl.load(best_B_idx_ptr + row_idx).to(tl.int32) + use_bitmap_i32 = (tl.load(use_bitmap_ptr + row_idx) & 1).to(tl.int32) + + # k_rice and M for this row + k_rice_i32 = tl.load(k_rice_choices_ptr + best_B_idx_i32).to(tl.int32) + M_i32 = (tl.full((), 1, dtype=tl.int32) << k_rice_i32) + + # Bit range of this row + bit_after_len_i32 = row_start_bit_i32 + 16 + row_end_bit_i32 = bit_after_len_i32 + payload_bytes_i32 * 8 + + # Skip header bits (we already know the contents) + header_dummy_u32, bit_off_i32 = read_nbits_triton( + u8_payload_ptr, + bit_after_len_i32, + tl.full((), ROW_HEADER_BITS, dtype=tl.int32), + row_end_bit_i32, # limit = end of this row + ) + + base_out = row_idx * K + ONE_I32 = tl.full((), 1, dtype=tl.int32) + + # ---- first value: ALWAYS full Rice ---- + if K > 0: + q0_i32, bit_off_i32, hit_end0_i32 = read_unary_bounded_triton( + u8_payload_ptr, + bit_off_i32, + row_end_bit_i32, + ) + r0_u32, bit_off_i32 = read_nbits_triton( + u8_payload_ptr, + bit_off_i32, + k_rice_i32, + row_end_bit_i32, # limit + ) + r0_i32 = r0_u32.to(tl.int32) + v0_i32 = q0_i32 * M_i32 + r0_i32 + tl.store(out_vals_ptr + base_out, v0_i32) + + # ---- tail values ---- + i = tl.full((), 1, dtype=tl.int32) + while i < K: + if use_bitmap_i32 != 0: + # Bitmap mode: q is 1 bit in {0,1} + q_bit_u32, bit_off_i32 = read_nbits_triton( + u8_payload_ptr, + bit_off_i32, + ONE_I32, + row_end_bit_i32, + ) + q_i32 = q_bit_u32.to(tl.int32) + + r_u32, bit_off_i32 = read_nbits_triton( + u8_payload_ptr, + bit_off_i32, + k_rice_i32, + row_end_bit_i32, + ) + r_i32 = r_u32.to(tl.int32) + else: + # Full Rice mode + q_i32, bit_off_i32, hit_end_i32 = read_unary_bounded_triton( + u8_payload_ptr, + bit_off_i32, + row_end_bit_i32, + ) + r_u32, bit_off_i32 = read_nbits_triton( + u8_payload_ptr, + bit_off_i32, + k_rice_i32, + row_end_bit_i32, + ) + r_i32 = r_u32.to(tl.int32) + + v_i32 = q_i32 * M_i32 + r_i32 + tl.store(out_vals_ptr + base_out + i, v_i32) + i += 1 + + +def decode_batch_rows( + payload: BytesLike, + use_delta: bool = True, + max_num_B: int = 16, +) -> tuple[torch.Tensor, int, int]: + + if not torch.cuda.is_available(): + raise RuntimeError("decode_batch_rows_gpu requires CUDA") + + # --- Move payload to CUDA (if needed) --- if isinstance(payload, torch.Tensor): assert payload.dtype == torch.uint8 - raw = payload.cpu().numpy().tobytes() + payload_gpu = payload if payload.is_cuda else payload.cuda() elif isinstance(payload, np.ndarray): assert payload.dtype == np.uint8 - raw = payload.tobytes() + payload_gpu = torch.from_numpy(payload).to("cuda", dtype=torch.uint8) elif isinstance(payload, (bytes, bytearray)): - raw = bytes(payload) + arr = np.frombuffer(bytes(payload), dtype=np.uint8) + payload_gpu = torch.from_numpy(arr).to("cuda", dtype=torch.uint8) else: raise TypeError("Unsupported payload type") - if len(raw) < 11: - raise ValueError("Payload too short for global header") - if raw[:4] != b"CGRP": - raise ValueError("Bad magic; expected 'CGRP'") - - C = int.from_bytes(raw[4:8], "little", signed=False) - K = int.from_bytes(raw[8:10], "little", signed=False) - num_B = raw[10] - need = 4 + 4 + 2 + 1 + 2 * num_B - if len(raw) < need: - raise ValueError("Payload shorter than header requires") - - B_choices = [] - off = 11 - for _ in range(num_B): - b = int.from_bytes(raw[off:off+2], "little", signed=False) - B_choices.append(b) - off += 2 - return C, K, B_choices, off * 8 - - -def _decode_row(stream: BitStreamReader, M: int, k_rice: int, use_bitmap: int, - row_payload_bytes: int, row_header_bits: int, - K: int) -> list[int]: - """ - Stream is positioned just AFTER the 16-bit length. - Decode EXACTLY K deltas (first is Rice; tail is bitmap or Rice), - then align to end-of-row (row_payload_bytes*8). - """ - start_bit = stream.bit_off - row_end_bit = start_bit + row_payload_bytes * 8 - - # header - _ = stream.read_bits(row_header_bits) - - deltas: list[int] = [] - - # first (Rice) - q0, hit_end = stream.read_unary_bounded(row_end_bit) - if hit_end or stream.bit_off + k_rice > row_end_bit: - stream.bit_off = row_end_bit - return [] - r0 = stream.read_bits(k_rice) - deltas.append(q0 * M + r0) - - # Tail: exactly K-1 more - for _ in range(K - 1): - if use_bitmap: - need = 1 + k_rice - if stream.bit_off + need > row_end_bit: - # not enough bits; treat as malformed/padded - break - q = stream.read_bits(1) - r = stream.read_bits(k_rice) - else: - # Rice - if stream.bit_off >= row_end_bit: - break - q, hit_end = stream.read_unary_bounded(row_end_bit) - if hit_end or stream.bit_off + k_rice > row_end_bit: - break - r = stream.read_bits(k_rice) - deltas.append(q * M + r) - - # align to end of row explicitly - stream.bit_off = row_end_bit - - # prefix sum - if not deltas: - return [] - vals = [0] * len(deltas) - vals[0] = deltas[0] - for i in range(1, len(deltas)): - vals[i] = vals[i-1] + deltas[i] - return vals - - -def decode_batch_rows(payload: BytesLike) -> tuple[list[list[int]], int, int]: - C, K, B_choices, header_end_bit = _parse_global_header(payload) - num_B = len(B_choices) - - # derive M/k_rice per choice - M_choices, k_rice_choices = [], [] - for B in B_choices: + payload_gpu = payload_gpu.contiguous() + dev = payload_gpu.device + total_bytes = int(payload_gpu.numel()) + if total_bytes == 0: + empty = torch.empty((0, 0), dtype=torch.int64, device=dev) + return empty, 0, 0 + + total_bits = total_bytes * 8 + + # --- 1) Parse global header on GPU (now also gets num_rows = R) --- + C_out = torch.empty(1, dtype=torch.int32, device=dev) + K_out = torch.empty(1, dtype=torch.int32, device=dev) + R_out = torch.empty(1, dtype=torch.int32, device=dev) # NEW + num_B_out = torch.empty(1, dtype=torch.int32, device=dev) + B_choices_out = torch.empty(max_num_B, dtype=torch.int32, device=dev) + header_bytes_out = torch.empty(1, dtype=torch.int32, device=dev) + err_flag = torch.zeros(1, dtype=torch.int32, device=dev) + + parse_header_kernel[(1,)]( + payload_gpu, + C_out, + K_out, + R_out, + num_B_out, + B_choices_out, + header_bytes_out, + err_flag, + total_bytes, + MAX_B_CHOICES=max_num_B, + ) + + torch.cuda.synchronize() + err = int(err_flag.cpu().item()) + if err != 0: + raise ValueError(f"parse_header_kernel failed with error code {err}") + + C = int(C_out.cpu().item()) + K = int(K_out.cpu().item()) + num_rows = int(R_out.cpu().item()) # NEW + num_B = int(num_B_out.cpu().item()) + header_bytes = int(header_bytes_out.cpu().item()) + B_choices_list = [int(x) for x in B_choices_out[:num_B].cpu().tolist()] + header_end_bit = header_bytes * 8 + + # --- 2) Build k_rice choices on CPU -> move to GPU --- + k_rice_choices = [] + for B in B_choices_list: M = C // B if M <= 0 or (M & (M - 1)) != 0: raise ValueError(f"M=C//B={M} not power of two for B={B}") - M_choices.append(M) k_rice_choices.append(int(math.log2(M))) + k_rice_choices_tensor = torch.tensor( + k_rice_choices, dtype=torch.int32, device=dev + ) - B_choice_bits = (num_B - 1).bit_length() + B_choice_bits = (num_B - 1).bit_length() ROW_HEADER_BITS = 1 + B_choice_bits - stream = BitStreamReader(payload, bit_offset_start=header_end_bit) - rows_out: list[list[int]] = [] - while stream.bits_remaining() >= 16: - row_payload_bytes = stream.read_bits(16) - if row_payload_bytes == 0 and stream.is_at_end(): - break - - # Peek header to learn best_B_idx & use_bitmap - if stream.bits_remaining() < ROW_HEADER_BITS: - break - header = stream.read_bits(ROW_HEADER_BITS) - use_bitmap = header & 1 - best_B_idx = header >> 1 - - if not (0 <= best_B_idx < num_B): - break - M = M_choices[best_B_idx] - k_rice = k_rice_choices[best_B_idx] - - # Rewind header; decode the row with exact K - stream.bit_off -= ROW_HEADER_BITS - row_vals = _decode_row( - stream, M=M, k_rice=k_rice, use_bitmap=use_bitmap, - row_payload_bytes=row_payload_bytes, row_header_bits=ROW_HEADER_BITS, - K=K - ) - if not row_vals: - break - rows_out.append(row_vals) - return rows_out, C, len(rows_out) + # --- 3) Scan rows on GPU to get per-row metadata --- + row_bit_offsets = torch.empty(num_rows, dtype=torch.int32, device=dev) + row_payload_bytes = torch.empty(num_rows, dtype=torch.int32, device=dev) + best_B_idx = torch.empty(num_rows, dtype=torch.int32, device=dev) + use_bitmap = torch.empty(num_rows, dtype=torch.int32, device=dev) + + scan_rows_kernel[(1,)]( + payload_gpu, + row_bit_offsets, + row_payload_bytes, + best_B_idx, + use_bitmap, + header_end_bit, + int(total_bits), + int(num_rows), + ROW_HEADER_BITS=ROW_HEADER_BITS, + ) + + # --- 4) Decode rows in parallel on GPU --- + out_vals = torch.empty((num_rows, K), dtype=torch.int32, device=dev) + decode_rows_kernel[(num_rows,)]( + payload_gpu, + out_vals, + row_bit_offsets, + row_payload_bytes, + best_B_idx, + use_bitmap, + k_rice_choices_tensor, + int(num_rows), + int(K), + ROW_HEADER_BITS=ROW_HEADER_BITS, + ) + + # --- undo delta on-GPU if needed --- + if use_delta: + out_vals = torch.cumsum(out_vals, dim=1) + return out_vals.to(torch.int64), C, num_rows if __name__ == "__main__": @@ -561,18 +847,19 @@ def decode_batch_rows(payload: BytesLike) -> tuple[list[list[int]], int, int]: ROWS, K = 32, 16 COLS = 4096 - x = torch.randn((ROWS, COLS), dtype=torch.float32) + x = torch.randn((ROWS, COLS), dtype=torch.float32, device="cuda") idx = torch.topk(x.abs(), k=K, dim=-1, largest=True, sorted=False).indices - - idx, _ = torch.sort(idx, dim=1) - payload, _ = encode_batch_rows_sorted(idx, C=COLS, B_choices=(64, 128, 256)) - decoded, _, _ = decode_batch_rows(payload) - ok = True - idx = [row for row in idx] - for r in range(ROWS): - if not torch.equal(torch.tensor(decoded[r]), idx[r].cpu()): - ok = False - print("Mismatch row", r) - print("orig:", idx[r].tolist()) - print("dec :", decoded[r]) - print("Round-trip OK" if ok else "Round-trip MISMATCH") + for use_delta in [False, True]: + if use_delta: + idx, _ = torch.sort(idx, dim=1) + payload, _ = encode_batch_rows(idx, C=COLS, use_delta=use_delta, B_choices=(64, 128, 256)) + decoded, _, _ = decode_batch_rows(payload, use_delta=use_delta) + dec = [torch.tensor(r, dtype=torch.int64) for r in decoded] + ok = True + for r in range(ROWS): + if not torch.equal(torch.tensor(decoded[r]), idx[r]): + ok = False + print("Mismatch row", r) + print("orig:", idx[r].tolist()) + print("dec :", decoded[r]) + print("Round-trip OK" if ok else "Round-trip MISMATCH") diff --git a/tests/unit/test_compress.py b/tests/unit/test_compress.py index 1363436bf..934697e33 100644 --- a/tests/unit/test_compress.py +++ b/tests/unit/test_compress.py @@ -13,7 +13,7 @@ _idct ) -from tplr.compression import encode_batch_rows_sorted, pack_12bit_indices, unpack_12bit_indices +from tplr.compression import encode_batch_rows, pack_12bit_indices, unpack_12bit_indices class TestTopKCompressor: @@ -84,7 +84,7 @@ def test_decompress_with_rice_bitmap_format( original_indices = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=torch.int64) # Pack using the new encoder format - idx_bytes, _ = encode_batch_rows_sorted(original_indices, C=totalk) + idx_bytes, _ = encode_batch_rows(original_indices, C=totalk) val = torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]], dtype=torch.float32) # Test decompression with packed format @@ -106,8 +106,8 @@ def test_batch_decompress_multiple_rice_bitmap_formats( idx2_orig = torch.tensor([[4, 5], [6, 7]], dtype=torch.int64) # Pack them using the new encoder format - idx1_packed, _ = encode_batch_rows_sorted(idx1_orig, C=totalk) - idx2_packed, _ = encode_batch_rows_sorted(idx2_orig, C=totalk) + idx1_packed, _ = encode_batch_rows(idx1_orig, C=totalk) + idx2_packed, _ = encode_batch_rows(idx2_orig, C=totalk) idx_list = [idx1_packed, idx2_packed] @@ -224,7 +224,7 @@ def test_batch_decompress_with_norm_options( # Create test data with Rice/bitmap encoded format idx_orig = torch.tensor([[0, 1, 2, 3]], dtype=torch.int64) # Even count - idx_packed, _ = encode_batch_rows_sorted(idx_orig, C=totalk) + idx_packed, _ = encode_batch_rows(idx_orig, C=totalk) idx = [idx_packed] val = [torch.tensor([[10.0, 20.0, 30.0, 40.0]], dtype=torch.float32)] From f156743c981c615b60a6a38476ad62d0e7d70ecf Mon Sep 17 00:00:00 2001 From: Kasper Date: Mon, 17 Nov 2025 14:10:20 +0400 Subject: [PATCH 07/33] remove use_delta arg --- src/tplr/compression/hybrid.py | 43 ++++++++++++++-------------------- 1 file changed, 17 insertions(+), 26 deletions(-) diff --git a/src/tplr/compression/hybrid.py b/src/tplr/compression/hybrid.py index 22c065ca3..8d6730e21 100644 --- a/src/tplr/compression/hybrid.py +++ b/src/tplr/compression/hybrid.py @@ -15,7 +15,6 @@ def encode_batch_rows( idx: torch.Tensor, *, C: int, - use_delta: bool = True, B_choices: Tuple[int, ...] = (64, 128) ) -> Tuple[BytesLike, Dict]: """ @@ -68,14 +67,11 @@ def encode_batch_rows( idx = idx.contiguous() dev = idx.device - if use_delta: - # v[0], v[1]-v[0], v[2]-v[1], ... - vals = torch.cat( - (idx[:, :1], idx[:, 1:] - idx[:, :-1]), - dim=1, - ) - else: - vals = idx + # v[0], v[1]-v[0], v[2]-v[1], ... + vals = torch.cat( + (idx[:, :1], idx[:, 1:] - idx[:, :-1]), + dim=1, + ) # Cast to int32 for Triton kernels vals = vals.to(torch.int32) @@ -725,7 +721,6 @@ def decode_rows_kernel( def decode_batch_rows( payload: BytesLike, - use_delta: bool = True, max_num_B: int = 16, ) -> tuple[torch.Tensor, int, int]: @@ -837,8 +832,7 @@ def decode_batch_rows( ) # --- undo delta on-GPU if needed --- - if use_delta: - out_vals = torch.cumsum(out_vals, dim=1) + out_vals = torch.cumsum(out_vals, dim=1) return out_vals.to(torch.int64), C, num_rows @@ -849,17 +843,14 @@ def decode_batch_rows( x = torch.randn((ROWS, COLS), dtype=torch.float32, device="cuda") idx = torch.topk(x.abs(), k=K, dim=-1, largest=True, sorted=False).indices - for use_delta in [False, True]: - if use_delta: - idx, _ = torch.sort(idx, dim=1) - payload, _ = encode_batch_rows(idx, C=COLS, use_delta=use_delta, B_choices=(64, 128, 256)) - decoded, _, _ = decode_batch_rows(payload, use_delta=use_delta) - dec = [torch.tensor(r, dtype=torch.int64) for r in decoded] - ok = True - for r in range(ROWS): - if not torch.equal(torch.tensor(decoded[r]), idx[r]): - ok = False - print("Mismatch row", r) - print("orig:", idx[r].tolist()) - print("dec :", decoded[r]) - print("Round-trip OK" if ok else "Round-trip MISMATCH") + idx, _ = torch.sort(idx, dim=1) + payload, _ = encode_batch_rows(idx, C=COLS, B_choices=(64, 128, 256)) + decoded, _, _ = decode_batch_rows(payload) + ok = True + for r in range(ROWS): + if not torch.equal(torch.tensor(decoded[r]), idx[r]): + ok = False + print("Mismatch row", r) + print("orig:", idx[r].tolist()) + print("dec :", decoded[r]) + print("Round-trip OK" if ok else "Round-trip MISMATCH") From f0a3531b8f749038d36432441c2958cf14ac3493 Mon Sep 17 00:00:00 2001 From: Kasper Date: Mon, 17 Nov 2025 18:20:09 +0400 Subject: [PATCH 08/33] layout v2 - faster decoding --- src/tplr/compression/bitops.py | 250 +++++++++++++++++ src/tplr/compression/bits.py | 486 --------------------------------- src/tplr/compression/hybrid.py | 470 ++++++++++++------------------- 3 files changed, 427 insertions(+), 779 deletions(-) create mode 100644 src/tplr/compression/bitops.py delete mode 100644 src/tplr/compression/bits.py diff --git a/src/tplr/compression/bitops.py b/src/tplr/compression/bitops.py new file mode 100644 index 000000000..96b1c2d88 --- /dev/null +++ b/src/tplr/compression/bitops.py @@ -0,0 +1,250 @@ +from typing import Union + +import numpy as np +import torch +import triton +import triton.language as tl + +BytesLike = Union[bytes, bytearray, np.ndarray, torch.Tensor] + +@triton.jit +def write_nbits( + u8_ptr, # uint8* global buffer + bit_off_i32, # scalar tl.int32 bit offset + value_u32, # scalar tl.uint32, up to 32 bits used + nbits_i32, # scalar tl.int32, number of bits to write +): + """ + Writes `nbits_i32` least-significant bits of `value_u32` into `u8_ptr` + starting at bit offset `bit_off_i32` in LSB-first order. + + This is still a bit-at-a-time writer; higher-level kernels have been + adjusted to use int32 + shift/mask ahead of time. + """ + j = tl.full((), 0, dtype=tl.int32) + ONE_U32 = tl.full((), 1, dtype=tl.uint32) + + while j < nbits_i32: + pos = bit_off_i32 + j + byte_idx = (pos >> 3).to(tl.int32) + bit_idx = (pos & 7).to(tl.int32) + + old_u8 = tl.load(u8_ptr + byte_idx) + old_u32 = old_u8.to(tl.uint32) + + vbit = (value_u32 >> j) & ONE_U32 + mask = ONE_U32 << bit_idx + new_u32 = (old_u32 & (~mask)) | (vbit << bit_idx) + tl.store(u8_ptr + byte_idx, new_u32.to(tl.uint8)) + j += 1 + return bit_off_i32 + nbits_i32 + + +@triton.jit +def write_nbits_fast( + u8_ptr, + bit_off_i32, # start bit + value_u32, # LSB-first payload bits + nbits_i32, # 0..32 +): + # If nothing to write + if nbits_i32 <= 0: + return bit_off_i32 + + start_bit = bit_off_i32 + first_byte = (start_bit >> 3).to(tl.int32) + first_bit = (start_bit & 7).to(tl.int32) + + # How many bits fit in the first byte + bits_in_first = tl.minimum( + nbits_i32, + tl.full((), 8, dtype=tl.int32) - first_bit, + ) + + # -------- leading partial byte -------- + if bits_in_first > 0: + old_u8 = tl.load(u8_ptr + first_byte).to(tl.uint32) + + # mask for the bits we overwrite inside that byte + mask_u32 = ((tl.full((), 1, tl.uint32) << bits_in_first) - 1) \ + << first_bit + + # extract those bits from value_u32 + bits_u32 = (value_u32 & ((tl.full((), 1, tl.uint32) << bits_in_first) - 1)) \ + << first_bit + + new_u8 = ((old_u8 & ~mask_u32) | bits_u32).to(tl.uint8) + tl.store(u8_ptr + first_byte, new_u8) + + bit_off_i32 += bits_in_first + value_u32 >>= bits_in_first + nbits_i32 -= bits_in_first + + # Now bit_off_i32 is byte aligned (or nbits_i32 == 0) + if nbits_i32 <= 0: + return bit_off_i32 + + cur_byte = (bit_off_i32 >> 3).to(tl.int32) + + # full bytes we can write + full_bytes = (nbits_i32 >> 3).to(tl.int32) # nbits_i32 // 8 + rem_bits = (nbits_i32 & 7).to(tl.int32) + + # -------- full bytes -------- + jb = tl.full((), 0, dtype=tl.int32) + while jb < full_bytes: + # take lowest 8 bits from value_u32 + byte_val = (value_u32 & tl.full((), 0xFF, tl.uint32)).to(tl.uint8) + tl.store(u8_ptr + cur_byte + jb, byte_val) + value_u32 >>= 8 + jb += 1 + + bit_off_i32 += full_bytes * 8 + + # -------- trailing partial byte -------- + if rem_bits > 0: + byte_idx = (bit_off_i32 >> 3).to(tl.int32) + old_u8 = tl.load(u8_ptr + byte_idx).to(tl.uint32) + + mask_u32 = ( (tl.full((), 1, tl.uint32) << rem_bits) - 1 ) + bits_u32 = ( value_u32 & mask_u32 ) + + new_u8 = ((old_u8 & ~mask_u32) | bits_u32).to(tl.uint8) + tl.store(u8_ptr + byte_idx, new_u8) + + bit_off_i32 += rem_bits + return bit_off_i32 + + +@triton.jit +def read_nbits(u8_ptr, bit_off_i32, nbits_i32, limit_bit_i32): + """ + GPU version of BitStreamReader.read_bits (LSB-first), but bounds-safe. + + Reads `nbits_i32` bits starting at `bit_off_i32`, but never loads beyond + bit index `limit_bit_i32` (masked loads return 0 out-of-bounds). + + Returns: (value_u32, new_bit_off_i32) + """ + j = tl.full((), 0, dtype=tl.int32) + val_u32 = tl.full((), 0, dtype=tl.uint32) + ONE_U32 = tl.full((), 1, dtype=tl.uint32) + ZERO_U8 = tl.full((), 0, dtype=tl.uint8) + + while j < nbits_i32: + pos = bit_off_i32 + j + in_bounds = pos < limit_bit_i32 + + byte_idx = (pos >> 3).to(tl.int32) + bit_idx = (pos & 7).to(tl.int32) + + # Masked load: if in_bounds==0, we load ZERO_U8 instead of touching memory. + u8 = tl.load(u8_ptr + byte_idx, mask=in_bounds, other=ZERO_U8) + u32 = u8.to(tl.uint32) + bit = (u32 >> bit_idx) & ONE_U32 + + val_u32 |= (bit << j) + j += 1 + + new_bit_off = bit_off_i32 + nbits_i32 + return val_u32, new_bit_off + + +@triton.jit +def read_nbits_fast(u8_ptr, bit_off_i32, nbits_i32, limit_bit_i32): + if nbits_i32 <= 0: + return tl.full((), 0, tl.uint32), bit_off_i32 + + # clamp to limit if you want to keep the defensive behavior + max_bits = limit_bit_i32 - bit_off_i32 + nbits_i32 = tl.minimum(nbits_i32, max_bits) + + start_bit = bit_off_i32 + end_bit = bit_off_i32 + nbits_i32 + + first_byte = (start_bit >> 3).to(tl.int32) + first_bit = (start_bit & 7).to(tl.int32) + + bits_in_first = tl.minimum( + nbits_i32, + tl.full((), 8, dtype=tl.int32) - first_bit, + ) + + val_u32 = tl.full((), 0, dtype=tl.uint32) + shift = tl.full((), 0, dtype=tl.int32) + + # -------- leading partial byte -------- + if bits_in_first > 0: + byte = tl.load(u8_ptr + first_byte).to(tl.uint32) + mask = ((tl.full((), 1, tl.uint32) << bits_in_first) - 1) << first_bit + chunk = (byte & mask) >> first_bit + val_u32 |= (chunk << shift) + + bit_off_i32 += bits_in_first + shift += bits_in_first + nbits_i32 -= bits_in_first + + if nbits_i32 <= 0: + return val_u32, bit_off_i32 + + cur_byte = (bit_off_i32 >> 3).to(tl.int32) + full_bytes = (nbits_i32 >> 3).to(tl.int32) + rem_bits = (nbits_i32 & 7).to(tl.int32) + + # -------- full bytes -------- + jb = tl.full((), 0, dtype=tl.int32) + while jb < full_bytes: + byte = tl.load(u8_ptr + cur_byte + jb).to(tl.uint32) + val_u32 |= (byte << shift) + shift += 8 + jb += 1 + + bit_off_i32 += full_bytes * 8 + + # -------- trailing partial byte -------- + if rem_bits > 0: + byte = tl.load(u8_ptr + (bit_off_i32 >> 3).to(tl.int32)).to(tl.uint32) + mask = (tl.full((), 1, tl.uint32) << rem_bits) - 1 + chunk = byte & mask + val_u32 |= (chunk << shift) + bit_off_i32 += rem_bits + return val_u32, bit_off_i32 + + +@triton.jit +def read_unary_bounded_triton(u8_ptr, bit_off_i32, end_bit_i32): + """ + GPU version of BitStreamReader.read_unary_bounded(end_bit). + Reads '1's until a '0' or end_bit. + Returns: (q_i32, new_bit_off_i32, hit_end_i32) + - q_i32: number of 1s before the terminating 0 + - hit_end_i32: 1 if we reached end_bit without seeing 0 + 0 if we saw a terminating 0 + """ + ONE_U32 = tl.full((), 1, dtype=tl.uint32) + q_i32 = tl.full((), 0, dtype=tl.int32) + hit_end_i32 = tl.full((), 1, dtype=tl.int32) + + cond = bit_off_i32 < end_bit_i32 + while cond: + pos = bit_off_i32 + byte_idx = (pos >> 3).to(tl.int32) + bit_idx = (pos & 7).to(tl.int32) + + u8 = tl.load(u8_ptr + byte_idx) + u32 = u8.to(tl.uint32) + bit = (u32 >> bit_idx) & ONE_U32 + + bit_off_i32 += 1 + + is_one = (bit == ONE_U32) + q_i32 += is_one.to(tl.int32) + + # If bit is 0, we did NOT hit end + hit_end_i32 = tl.where(is_one, hit_end_i32, + tl.full((), 0, dtype=tl.int32)) + + # Continue only if we are still inside the row and last bit was 1 + cond = (bit_off_i32 < end_bit_i32) & is_one + + return q_i32, bit_off_i32, hit_end_i32 \ No newline at end of file diff --git a/src/tplr/compression/bits.py b/src/tplr/compression/bits.py deleted file mode 100644 index 0ce31380f..000000000 --- a/src/tplr/compression/bits.py +++ /dev/null @@ -1,486 +0,0 @@ -# bits.py -# The MIT License (MIT) -# © 2025 tplr.ai -# -# Triton-powered Rice/bitmap encoder (per-row scheme) with a CPU decoder. -# Payload layout matches the CPU reference: -# [ C-1 : 12b ][ N : 16b ][ reserved : 1b ] -# then, for each row r=0..N-1: -# [ row_len_bytes[r] : 16b ][ row_payload_bits[r] ] -# -# Dependencies: torch, triton (runtime), numpy (only for decode consumer code elsewhere if needed) - -from __future__ import annotations -import math -from typing import Sequence, Tuple - -import torch -import torch.nn.functional as F - -try: - import triton - import triton.language as tl - - TRITON_AVAILABLE = True -except Exception: - TRITON_AVAILABLE = False - - -# -------------------------- CPU decoder (unchanged format) -------------------------- - - -class _BitReader: - def __init__(self, data: bytes) -> None: - self.data = data - self.idx = 0 - self.cur = 0 - self.nbits = 0 - - def _fill(self, n: int) -> None: - while self.nbits < n and self.idx < len(self.data): - self.cur |= int(self.data[self.idx]) << self.nbits - self.idx += 1 - self.nbits += 8 - - def read_bits(self, n: int) -> int: - if n <= 0: - return 0 - self._fill(n) - mask = (1 << n) - 1 - out = self.cur & mask - self.cur >>= n - self.nbits -= n - return out - - def read_unary(self) -> int: - q = 0 - while True: - bit = self.read_bits(1) - if bit == 0: - break - q += 1 - return q - - def read_bytes(self, n: int) -> bytes: - out = bytearray() - for _ in range(n): - out.append(self.read_bits(8)) - return bytes(out) - - -def _rice_read(br: _BitReader, k: int) -> int: - q = br.read_unary() - r = br.read_bits(k) if k > 0 else 0 - return (q << k) + r - - -def decode_batch_rows(payload: bytes) -> tuple[list[list[int]], int, int]: - """ - Decode payload created by encode_batch_rows(...). - Returns (rows, C, N) where `rows` is a list of per-row global indices. - """ - br = _BitReader(payload) - C = br.read_bits(12) + 1 - N = br.read_bits(16) - _ = br.read_bits(1) # reserved - - rows: list[list[int]] = [] - for _i in range(N): - row_len = br.read_bits(16) - row_bytes = br.read_bytes(row_len) - rr = _BitReader(row_bytes) - lb = rr.read_bits(5) - k_param = rr.read_bits(4) - use_bitmap = rr.read_bits(1) - B = 1 << lb - n_sub = C // B - - indices: list[int] = [] - for j in range(n_sub): - s_len = _rice_read(rr, k_param) - if s_len == 0: - continue - if use_bitmap: - bitmask = rr.read_bits(B) - for loc in range(B): - if (bitmask >> loc) & 1: - indices.append(j * B + loc) - else: - for _ in range(s_len): - loc = rr.read_bits(lb) - indices.append(j * B + loc) - rows.append(indices) - return rows, C, N - - -# --------------------------- GPU-side param selection --------------------------- - - -def _rice_k_from_mean(lmbda: float) -> int: - if lmbda <= 0.0: - return 0 - return max(0, round(math.log2(max(lmbda, 1e-9)))) - - -@torch.no_grad() -def _estimate_best_params_per_row( - idx: torch.Tensor, C: int, B_choices: Sequence[int] -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Torch (GPU) estimate of best B, use_bitmap, and k_param per row. - Mirrors your previous vectorised selector. - """ - assert idx.dtype == torch.int64 - rows, k = idx.shape - device = idx.device - - B_sorted = tuple( - sorted([b for b in B_choices if b > 0 and (C % b) == 0 and (b & (b - 1)) == 0]) - ) - if not B_sorted: - raise ValueError("No valid B choices for C") - - header = 5 + 4 + 1 - best_bits = torch.full((rows,), 1 << 60, device=device, dtype=torch.int64) - best_B = torch.full((rows,), B_sorted[0], device=device, dtype=torch.int64) - best_use_bitmap = torch.zeros((rows,), device=device, dtype=torch.bool) - - Bmin = B_sorted[0] - if all((b % Bmin) == 0 for b in B_sorted): - n_sub_min = C // Bmin - js_min = (idx // Bmin).to(torch.int64) # [rows, k] - counts_min = ( - F.one_hot(js_min, num_classes=n_sub_min).sum(dim=1).to(torch.int64) - ) # [rows, n_sub_min] - lmbda_base = k / max(1, C) - - for B in B_sorted: - g = B // Bmin - counts_B = ( - counts_min - if g == 1 - else counts_min.reshape(rows, n_sub_min // g, g).sum(dim=2) - ) - lb = int(math.ceil(math.log2(B))) - n_sub = C // B - k_param = int(max(0, round(math.log2(max(lmbda_base * B, 1e-9))))) - m = 1 << k_param - q = counts_B // m - rb_sum = q.sum(dim=1) + (1 + k_param) * n_sub - nonzero = (counts_B > 0).sum(dim=1) - bits_local = header + rb_sum + lb * k - bits_bitmap = header + rb_sum + B * nonzero - cur_bits = torch.minimum(bits_local, bits_bitmap).to(torch.int64) - use_bitmap = bits_bitmap < bits_local - update = cur_bits < best_bits - best_bits = torch.where(update, cur_bits, best_bits) - best_B = torch.where(update, torch.full_like(best_B, B), best_B) - best_use_bitmap = torch.where(update, use_bitmap, best_use_bitmap) - else: - for B in B_sorted: - lb = int(math.ceil(math.log2(B))) - n_sub = C // B - js = (idx // B).to(torch.int64) - # Bincount per row with a constant bin width M=max_n_sub to avoid collisions - M = C // min(B_sorted) - row_ids = torch.arange(rows, device=device, dtype=torch.int64).unsqueeze(1) - flat = (row_ids * M + js).reshape(-1) - counts = torch.bincount(flat, minlength=rows * M).reshape(rows, M)[ - :, :n_sub - ] - lmbda = (k / max(1, C)) * B - k_param = int(max(0, round(math.log2(max(lmbda, 1e-9))))) - m = 1 << k_param - q = counts // m - rb_sum = q.sum(dim=1) + (1 + k_param) * n_sub - nonzero = (counts > 0).sum(dim=1) - bits_local = header + rb_sum + lb * k - bits_bitmap = header + rb_sum + B * nonzero - cur_bits = torch.minimum(bits_local, bits_bitmap).to(torch.int64) - use_bitmap = bits_bitmap < bits_local - update = cur_bits < best_bits - best_bits = torch.where(update, cur_bits, best_bits) - best_B = torch.where(update, torch.full_like(best_B, B), best_B) - best_use_bitmap = torch.where(update, use_bitmap, best_use_bitmap) - - # Rice k for chosen B - lmbda = (idx.shape[1] / max(1, C)) * best_B.float() - k_param = torch.clamp((lmbda.clamp_min(1e-9).log2().round()).to(torch.int64), min=0) - - return ( - best_B.to(torch.int32), - best_use_bitmap.to(torch.uint8), - k_param.to(torch.int32), - ) - - -# --------------------------- Triton helpers (no tl.uint32 calls) --------------------------- - - -@triton.jit -def _write_bits_u32(buf_ptr, bitpos, value, nbits): - # Write 'nbits' LSBs from value, advance bitpos, and return the new bitpos. - i = tl.zeros((), dtype=tl.int32) - bp = bitpos.to(tl.int64) - v = value.to(tl.int32) - while i < nbits: - bit = (v >> i) & 1 - byte_idx = bp // 8 - off = bp % 8 - p = buf_ptr + byte_idx - old = tl.load(p, mask=True, other=0).to(tl.int32) - newv = old | (bit << off) - tl.store(p, newv.to(tl.uint8)) - bp += 1 - i += 1 - return bp - - -@triton.jit -def _write_unary(buf_ptr, bitpos, q): - # Write q ones then a zero; zeros are already in buffer → just advance for trailing zero. - bp = bitpos.to(tl.int64) - i = tl.zeros((), dtype=tl.int32) - while i < q: - byte_idx = bp // 8 - off = bp % 8 - p = buf_ptr + byte_idx - old = tl.load(p, mask=True, other=0).to(tl.int32) - newv = old | (1 << off) - tl.store(p, newv.to(tl.uint8)) - bp += 1 - i += 1 - # trailing zero bit: buffer is zero-initialized, so just skip one bit - bp += 1 - return bp - - -@triton.jit -def _write_rice(buf_ptr, bitpos, x, kparam): - # Golomb-Rice for non-negative x with parameter k. - k = kparam.to(tl.int32) - if k == 0: - return _write_unary(buf_ptr, bitpos, x) - m = 1 << k - q = x // m - r = x & (m - 1) - bp = _write_unary(buf_ptr, bitpos, q) - bp = _write_bits_u32(buf_ptr, bp, r, k) - return bp - - -@triton.jit -def _set_one_bit(buf_ptr, bitpos): - bp = bitpos.to(tl.int64) - byte_idx = bp // 8 - off = bp % 8 - p = buf_ptr + byte_idx - old = tl.load(p, mask=True, other=0).to(tl.int32) - newv = old | (1 << off) - tl.store(p, newv.to(tl.uint8)) - - -# ------------------------------ Triton kernel --------------------------------- - - -@triton.jit -def _kernel_write_rows( - idx_ptr, # int64 [N*K] - rows, - k, - C, # ints - bestB_ptr, # int32 [N] - usebm_ptr, # uint8 [N] (0/1) - kparam_ptr, # int32 [N] - row_bytes_ptr, # int32 [N] - len_bitpos_ptr, # int64 [N] (bitpos of 16-bit length) - pay_bitpos_ptr, # int64 [N] (bitpos of first payload bit for the row) - payload_ptr, # uint8 [TOTAL_BYTES] - K_MAX: tl.constexpr, # upper bound for K (e.g., 256 or 512) -): - r = tl.program_id(0) - if r >= rows: - return - - # Per-row params - B = tl.load(bestB_ptr + r).to(tl.int32) - use_bitmap = tl.load(usebm_ptr + r).to(tl.int1) - kparam = tl.load(kparam_ptr + r).to(tl.int32) - # B is power-of-two ⇒ lb = log2(B) exactly (cast to float for tl.log2) - lb = tl.log2(B.to(tl.float32)).to(tl.int32) - n_sub = C // B - - # Write 16-bit length at its position (interleaved layout) - row_len = tl.load(row_bytes_ptr + r).to(tl.int32) - len_bp = tl.load(len_bitpos_ptr + r) - _ = _write_bits_u32(payload_ptr, len_bp, row_len, 16) - - # Row payload header - bp = tl.load(pay_bitpos_ptr + r) - bp = _write_bits_u32(payload_ptr, bp, lb, 5) # lb - bp = _write_bits_u32(payload_ptr, bp, kparam, 4) # k - bp = _write_bits_u32(payload_ptr, bp, use_bitmap.to(tl.int32), 1) # mode - - # Emit each sub-chunk j in ascending order - j = tl.zeros((), dtype=tl.int32) - while j < n_sub: - # -- first pass: count how many entries go to sub j - got = tl.zeros((), dtype=tl.int32) - t = tl.zeros((), dtype=tl.int32) - while t < k: - v = tl.load(idx_ptr + r * k + t).to(tl.int64) - jj = (v // B).to(tl.int32) - got += jj == j - t += 1 - - # write Rice length - bp = _write_rice(payload_ptr, bp, got, kparam) - - if got > 0: - if use_bitmap: - # second pass: set bits for locations; advance by B bits - start = bp - t = tl.zeros((), dtype=tl.int32) - while t < k: - v = tl.load(idx_ptr + r * k + t).to(tl.int64) - jj = (v // B).to(tl.int32) - if jj == j: - loc = (v - j.to(tl.int64) * B.to(tl.int64)).to(tl.int32) - _set_one_bit(payload_ptr, start + loc) - t += 1 - bp = start + B - else: - # local list: second pass writing lb-bit locs (order needn't be sorted) - t = tl.zeros((), dtype=tl.int32) - while t < k: - v = tl.load(idx_ptr + r * k + t).to(tl.int64) - jj = (v // B).to(tl.int32) - if jj == j: - loc = (v - j.to(tl.int64) * B.to(tl.int64)).to(tl.int32) - bp = _write_bits_u32(payload_ptr, bp, loc, lb) - t += 1 - j += 1 - # done - - -# -------------------------------- Public API ---------------------------------- - - -@torch.no_grad() -def encode_batch_rows( - idx: torch.Tensor, # [rows, k] int64 (CUDA strongly recommended) - *, - C: int, - B_choices: tuple[int, ...] = (64, 128), -) -> tuple[bytes, dict]: - """ - Triton encoder for per-row Rice/bitmap codec. - - Returns: - payload: bytes - meta: {total_bits, avg_bits_per_row, B_hist} - """ - if not TRITON_AVAILABLE: - raise RuntimeError("Triton is not available. `pip install triton` and re-run.") - - if idx.dtype != torch.int64: - idx = idx.to(torch.int64) - if not idx.is_cuda: - idx = idx.cuda() - idx = idx.contiguous() - - rows, k = idx.shape - device = idx.device - - # 1) Pick best params per row (GPU) - best_B, use_bitmap, k_param = _estimate_best_params_per_row( - idx, C=C, B_choices=B_choices - ) - - # 2) Compute exact per-row bit counts to size output - header_bits_row = 5 + 4 + 1 # lb + k + mode - lb = (best_B.float().log2().round().to(torch.int64)).clamp_min( - 1 - ) # exact for power-of-two - n_sub = C // best_B.to(torch.int64) - - # Bincount per row with constant width M to prevent collisions - M = int(max(C // int(b) for b in best_B.unique().tolist())) - row_ids = torch.arange(rows, device=device, dtype=torch.int64).unsqueeze(1) - js = (idx // best_B.to(torch.int64).unsqueeze(1)).to(torch.int64) # [rows, k] - flat = (row_ids * M + js).reshape(-1) - counts = torch.bincount(flat, minlength=rows * M).reshape(rows, M) # [rows, M] - # limit to effective n_sub per row when summing - # rice bits: Σ(q + 1 + k) with q = c // 2^k - m = 1 << k_param.to(torch.int64) - q = counts // m.unsqueeze(1) - rb_sum = q.sum(dim=1) + (1 + k_param.to(torch.int64)) * n_sub.to(torch.int64) - nonzero = (counts > 0).sum(dim=1).to(torch.int64) - - bits_local = header_bits_row + rb_sum + lb * k - bits_bitmap = header_bits_row + rb_sum + best_B.to(torch.int64) * nonzero - row_bits = torch.minimum(bits_local, bits_bitmap).to(torch.int64) - row_bytes = ((row_bits + 7) // 8).to(torch.int32) - - # 3) Allocate payload buffer (global header + interleaved [len16 | payload]) - total_bits_rows = int((16 * rows + 8 * row_bytes.sum().item())) - total_bits = 12 + 16 + 1 + total_bits_rows - total_bytes = (total_bits + 7) // 8 - payload = torch.zeros(total_bytes, dtype=torch.uint8, device=device) - - # 4) Compute interleaved bit positions for each row - header_bits = 12 + 16 + 1 - body_chunk_bits = 16 + 8 * row_bytes.to(torch.int64) # [rows] - prefix = torch.zeros_like(body_chunk_bits) - if rows > 0: - prefix[1:] = torch.cumsum(body_chunk_bits[:-1], dim=0) - len_bitpos = header_bits + prefix # [rows] - pay_bitpos = len_bitpos + 16 # [rows] - - # 5) Write global header in-place (LSB-first) using torch ops - def _write_scalar_bits(val: int, nbits: int, start_bit: int): - v = int(val) - bp = int(start_bit) - nb = int(nbits) - while nb > 0: - byte_idx = bp // 8 - off = bp % 8 - take = min(nb, 8 - off) - mask = ((v & ((1 << take) - 1)) << off) & 0xFF - payload[byte_idx] |= torch.as_tensor(mask, dtype=torch.uint8, device=device) - v >>= take - bp += take - nb -= take - - _write_scalar_bits(C - 1, 12, 0) - _write_scalar_bits(rows, 16, 12) - _write_scalar_bits(0, 1, 28) # reserved - - # 6) Launch Triton to write all rows - # Choose a safe K_MAX for your top-k; 256 covers k<=256; use 512 if you push k higher. - K_MAX = 256 if k <= 256 else 512 - grid = (rows,) - _kernel_write_rows[grid]( - idx, - rows, - k, - C, - best_B, - use_bitmap, - k_param, - row_bytes, - len_bitpos.to(torch.int64), - pay_bitpos.to(torch.int64), - payload, - K_MAX=K_MAX, - ) - - # 7) Return bytes + meta - payload_bytes = bytes(payload.detach().cpu().numpy().tobytes()) - B_hist = {int(b): int((best_B == b).sum().item()) for b in best_B.unique()} - meta = { - "total_bits": total_bits, - "avg_bits_per_row": float(row_bits.float().mean().item()) if rows > 0 else 0.0, - "B_hist": B_hist, - } - return payload_bytes, meta diff --git a/src/tplr/compression/hybrid.py b/src/tplr/compression/hybrid.py index 8d6730e21..0d2dc5777 100644 --- a/src/tplr/compression/hybrid.py +++ b/src/tplr/compression/hybrid.py @@ -7,6 +7,8 @@ import triton import triton.language as tl +from .bitops import write_nbits_fast, read_unary_bounded_triton, read_nbits_fast + BytesLike = Union[bytes, bytearray, np.ndarray, torch.Tensor] @@ -22,24 +24,23 @@ def encode_batch_rows( using a per-row adaptive Rice/Bitmap compression scheme on the GPU. Layout: - 0..3 : "CGRP" (magic) - 4..7 : C (uint32 LE) - 8..9 : K (uint16 LE) - 10..13 : R (uint32 LE, num_rows) - 14 : num_B (uint8) - 15.. : B_choices (num_B * uint16 LE) - - Args: - idx (torch.Tensor): [rows, k] int64 tensor of indices. - C (int): The total number of columns (0 <= idx < C). - B_choices (tuple[int, ...]): Block sizes to evaluate. - Must be powers of two. - Must evenly divide C. - - Returns: - tuple[bytes, dict]: (payload, meta) - - payload (bytes): The compressed byte string. - - meta (dict): Metadata about the compression. + + [global header] + 0..3 : "CGRP" (magic) + 4..7 : C (uint32 LE) + 8..9 : K (uint16 LE) + 10..13 : R (uint32 LE, num_rows) + 14 : num_B (uint8) + 15.. : B_choices (num_B * uint16 LE) + + [row table] (num_rows entries, 3 bytes each) + - uint16 length_bytes[r] (payload size in BYTES for row r) + - uint8 header[r] ((best_B_idx << 1) | use_bitmap) + + [payload region] + - concatenated bitstreams, one per row, each length_bytes[r] bytes, + byte-aligned, containing ONLY the Rice/bitmap-coded deltas + (no per-row length or header in-band). """ if not torch.cuda.is_available(): @@ -67,7 +68,6 @@ def encode_batch_rows( idx = idx.contiguous() dev = idx.device - # v[0], v[1]-v[0], v[2]-v[1], ... vals = torch.cat( (idx[:, :1], idx[:, 1:] - idx[:, :-1]), dim=1, @@ -76,22 +76,21 @@ def encode_batch_rows( # Cast to int32 for Triton kernels vals = vals.to(torch.int32) - # Calculate k_rice parameters (log2(C // B)) + # k_rice parameters (log2(C // B)) k_rice_choices = tuple(int(math.log2(C // b)) for b in B_choices) num_B_choices = len(B_choices) k_rice_choices_tensor = torch.tensor(k_rice_choices, dtype=torch.int32, device=dev) - B_choice_bits = (num_B_choices - 1).bit_length() - # Row header: 1 bit (bitmap/rice) + B_choice_bits - ROW_HEADER_BITS = 1 + B_choice_bits + # Row header bits (only used for packing row-table header byte) + B_choice_bits = (num_B_choices - 1).bit_length() + ROW_HEADER_BITS = 1 + B_choice_bits # (best_B_idx << 1) | use_bitmap # Output tensors for cost kernel costs = torch.empty((num_rows, num_B_choices), dtype=torch.int32, device=dev) is_bitmap = torch.empty((num_rows, num_B_choices), dtype=torch.int8, device=dev) grid = (num_rows,) - # Launch cost kernel - # k_dim is passed as constexpr for tl.arange, but B_choices are dynamic + # cost kernel: bits required for deltas only (no header bits) cost_kernel[grid]( vals, costs, @@ -102,77 +101,99 @@ def encode_batch_rows( k_rice_choices_ptr=k_rice_choices_tensor, ) - # pick best B/mode & compute layout - # Best choice per row min_costs, best_B_idx = torch.min(costs, dim=1) is_bitmap_choice = torch.gather(is_bitmap, 1, best_B_idx.unsqueeze(1)).squeeze(1).to(torch.int32) - # (1) payload bits per row (deltas only) - row_payload_bits = min_costs + ROW_HEADER_BITS # (rows,) + # (1) payload bits per row = bits for deltas only + row_payload_bits = min_costs # (rows,) # (2) payload bytes per row (rounded up) row_payload_bytes = ((row_payload_bits + 7) // 8).to(torch.int32) # (rows,) - # (3) on-wire bits per row = 16 (length) + payload rounded to bytes - row_bits_aligned = (16 + row_payload_bytes * 8).to(torch.int64) # (rows,) + # ensure fit in uint16 for the row table + if torch.any(row_payload_bytes > 0xFFFF): + raise ValueError("Row payload length exceeds 65535 bytes; cannot store in uint16.") - # (4) starting bit offsets (before header) - row_bit_offsets = torch.nn.functional.pad( - torch.cumsum(row_bits_aligned, dim=0, dtype=torch.int64)[:-1], - (1, 0) - ) - - # (5) total bits across all rows (Python int) - total_bits = int(row_bits_aligned.sum().item()) + # byte offsets within the payload region (no gaps) + # row_byte_offsets[r] = sum_{i> 8) & 0xFF).to(torch.uint8) + + # Only the low ROW_HEADER_BITS bits are meaningful, but we just store the byte. + row_table[:, 2] = (headers_i32 & ((1 << ROW_HEADER_BITS) - 1)).to(torch.uint8) + + payload_buf[ + global_header_len_bytes : global_header_len_bytes + row_table_bytes + ] = row_table.view(-1) + + # compute bit offsets for each row's payload (no per-row length/header in-band) + row_bit_offsets = (payload_region_start + row_byte_offsets).to(torch.int32) * 8 + + # pack payloads pack_kernel[(num_rows,)]( vals, payload_buf, - row_bit_offsets.to(torch.int32), - row_payload_bytes, # already int32 + row_bit_offsets, best_B_idx.to(torch.int32), is_bitmap_choice, # int32 0/1 k_rice_choices_tensor, num_rows, k_dim=k_dim, - ROW_HEADER_BITS=ROW_HEADER_BITS, ) + # meta b_counts = torch.bincount(best_B_idx, minlength=len(B_choices)) B_hist = {b: c.item() for b, c in zip(B_choices, b_counts)} + total_row_bytes = total_payload_bytes + row_entry_bytes * num_rows + total_bits = int(total_row_bytes * 8) + meta = { - "total_bits": total_bits, # includes 16-bit length and byte padding - "avg_bits_per_row": float(row_bits_aligned.float().mean().item()), + "total_bits": total_bits, + "avg_bits_per_row": float(total_bits / num_rows), "avg_payload_bits_per_row": float(row_payload_bits.float().mean().item()), - # header+payload, no 16-bit length, before byte-rounding "B_hist": B_hist, } return payload_buf, meta @@ -237,63 +258,28 @@ def cost_kernel( b_idx += 1 -@triton.jit -def write_nbits( - u8_ptr, # uint8* global buffer - bit_off_i32, # scalar tl.int32 bit offset - value_u32, # scalar tl.uint32, up to 32 bits used - nbits_i32, # scalar tl.int32, number of bits to write -): - """ - Writes `nbits_i32` least-significant bits of `value_u32` into `u8_ptr` - starting at bit offset `bit_off_i32` in LSB-first order. - - This is still a bit-at-a-time writer; higher-level kernels have been - adjusted to use int32 + shift/mask ahead of time. - """ - j = tl.full((), 0, dtype=tl.int32) - ONE_U32 = tl.full((), 1, dtype=tl.uint32) - - while j < nbits_i32: - pos = bit_off_i32 + j - byte_idx = (pos >> 3).to(tl.int32) - bit_idx = (pos & 7).to(tl.int32) - - old_u8 = tl.load(u8_ptr + byte_idx) - old_u32 = old_u8.to(tl.uint32) - - vbit = (value_u32 >> j) & ONE_U32 - mask = ONE_U32 << bit_idx - new_u32 = (old_u32 & (~mask)) | (vbit << bit_idx) - tl.store(u8_ptr + byte_idx, new_u32.to(tl.uint8)) - j += 1 - return bit_off_i32 + nbits_i32 - - @triton.jit def pack_kernel( delta_ptr, # (rows, k_dim) IN int32 u8_payload_ptr, # (final_buffer_bytes,) OUT uint8 - row_bit_offsets_ptr, # (rows,) IN (int32 preferred) - row_payload_bytes_ptr, # (rows,) IN int32 + row_bit_offsets_ptr, # (rows,) IN int32 (bit offset where payload starts) best_B_idx_ptr, # (rows,) IN int32 is_bitmap_ptr, # (rows,) IN int32 (0/1) k_rice_choices_ptr, # [num_B] IN int32 num_rows: tl.int32, k_dim: tl.int32, # dynamic - ROW_HEADER_BITS: tl.constexpr, ): """ - First delta Rice (unary = q ones then 0) + r; tail bitmap or Rice. - Bit order: LSB-first. + Writes only the Rice/bitmap-coded payload bits for each row. + + Each program instance handles one row. Bit order is LSB-first. """ row_idx = tl.program_id(0) if row_idx >= num_rows: return - # per-row meta + # Per-row meta bit_off_i32 = tl.load(row_bit_offsets_ptr + row_idx).to(tl.int32) - payload_bytes_i32 = tl.load(row_payload_bytes_ptr + row_idx).to(tl.int32) b_idx_i32 = tl.load(best_B_idx_ptr + row_idx).to(tl.int32) use_bitmap_i32 = (tl.load(is_bitmap_ptr + row_idx) & 1).to(tl.int32) @@ -306,36 +292,30 @@ def pack_kernel( ONE_I32 = tl.full((), 1, dtype=tl.int32) THIRTY_ONE_I32 = tl.full((), 31, dtype=tl.int32) - # 16-bit length - bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, - payload_bytes_i32.to(tl.uint32), - tl.full((), 16, dtype=tl.int32)) - - # header ((b_idx << 1) | use_bitmap) - header_i32 = (b_idx_i32 << 1) | use_bitmap_i32 - bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, - header_i32.to(tl.uint32), - tl.full((), ROW_HEADER_BITS, dtype=tl.int32)) - base = row_idx * k_dim - # first delta: ALWAYS Rice + # ---- first delta: ALWAYS full Rice (unary + remainder) ---- if k_dim > 0: v0 = tl.load(delta_ptr + base).to(tl.int32) q0 = (v0 >> k_rice_i32).to(tl.int32) r0 = (v0 & (M_i32 - 1)).to(tl.int32) - # q0 ones in chunks of <=31, then a single 0 + # q0 ones in chunks of <= 31, then a single 0 q_left = q0 while q_left > 0: chunk = tl.minimum(q_left, THIRTY_ONE_I32) ones = (ONE_U32 << chunk) - ONE_U32 - bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, ones, chunk) + bit_off_i32 = write_nbits_fast(u8_payload_ptr, bit_off_i32, ones, chunk) q_left -= chunk - bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, ZERO_U32, ONE_I32) # terminating 0 - bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, r0.to(tl.uint32), k_rice_i32) # remainder - # tail deltas + # terminating 0 bit + bit_off_i32 = write_nbits_fast(u8_payload_ptr, bit_off_i32, ZERO_U32, ONE_I32) + # remainder + bit_off_i32 = write_nbits_fast( + u8_payload_ptr, bit_off_i32, r0.to(tl.uint32), k_rice_i32 + ) + + # ---- tail deltas ---- i = 1 while i < k_dim: v = tl.load(delta_ptr + base + i).to(tl.int32) @@ -347,94 +327,23 @@ def pack_kernel( while q_left > 0: chunk = tl.minimum(q_left, THIRTY_ONE_I32) ones = (ONE_U32 << chunk) - ONE_U32 - bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, ones, chunk) + bit_off_i32 = write_nbits_fast(u8_payload_ptr, bit_off_i32, ones, chunk) q_left -= chunk + + # terminating 0 bit only in full-Rice mode n_term = tl.where(use_bitmap_i32 != 0, tl.full((), 0, dtype=tl.int32), ONE_I32) - bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, ZERO_U32, n_term) + bit_off_i32 = write_nbits_fast(u8_payload_ptr, bit_off_i32, ZERO_U32, n_term) # bitmap q only if bitmap q_bit = tl.where(q > 0, ONE_U32, ZERO_U32) n_qbit = tl.where(use_bitmap_i32 != 0, ONE_I32, tl.full((), 0, dtype=tl.int32)) - bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, q_bit, n_qbit) + bit_off_i32 = write_nbits_fast(u8_payload_ptr, bit_off_i32, q_bit, n_qbit) # remainder always - bit_off_i32 = write_nbits(u8_payload_ptr, bit_off_i32, r.to(tl.uint32), k_rice_i32) + bit_off_i32 = write_nbits_fast(u8_payload_ptr, bit_off_i32, r.to(tl.uint32), k_rice_i32) i += 1 -@triton.jit -def read_nbits_triton(u8_ptr, bit_off_i32, nbits_i32, limit_bit_i32): - """ - GPU version of BitStreamReader.read_bits (LSB-first), but bounds-safe. - - Reads `nbits_i32` bits starting at `bit_off_i32`, but never loads beyond - bit index `limit_bit_i32` (masked loads return 0 out-of-bounds). - - Returns: (value_u32, new_bit_off_i32) - """ - j = tl.full((), 0, dtype=tl.int32) - val_u32 = tl.full((), 0, dtype=tl.uint32) - ONE_U32 = tl.full((), 1, dtype=tl.uint32) - ZERO_U8 = tl.full((), 0, dtype=tl.uint8) - - while j < nbits_i32: - pos = bit_off_i32 + j - in_bounds = pos < limit_bit_i32 - - byte_idx = (pos >> 3).to(tl.int32) - bit_idx = (pos & 7).to(tl.int32) - - # Masked load: if in_bounds==0, we load ZERO_U8 instead of touching memory. - u8 = tl.load(u8_ptr + byte_idx, mask=in_bounds, other=ZERO_U8) - u32 = u8.to(tl.uint32) - bit = (u32 >> bit_idx) & ONE_U32 - - val_u32 |= (bit << j) - j += 1 - - new_bit_off = bit_off_i32 + nbits_i32 - return val_u32, new_bit_off - - -@triton.jit -def read_unary_bounded_triton(u8_ptr, bit_off_i32, end_bit_i32): - """ - GPU version of BitStreamReader.read_unary_bounded(end_bit). - Reads '1's until a '0' or end_bit. - Returns: (q_i32, new_bit_off_i32, hit_end_i32) - - q_i32: number of 1s before the terminating 0 - - hit_end_i32: 1 if we reached end_bit without seeing 0 - 0 if we saw a terminating 0 - """ - ONE_U32 = tl.full((), 1, dtype=tl.uint32) - q_i32 = tl.full((), 0, dtype=tl.int32) - hit_end_i32 = tl.full((), 1, dtype=tl.int32) - - cond = bit_off_i32 < end_bit_i32 - while cond: - pos = bit_off_i32 - byte_idx = (pos >> 3).to(tl.int32) - bit_idx = (pos & 7).to(tl.int32) - - u8 = tl.load(u8_ptr + byte_idx) - u32 = u8.to(tl.uint32) - bit = (u32 >> bit_idx) & ONE_U32 - - bit_off_i32 += 1 - - is_one = (bit == ONE_U32) - q_i32 += is_one.to(tl.int32) - - # If bit is 0, we did NOT hit end - hit_end_i32 = tl.where(is_one, hit_end_i32, - tl.full((), 0, dtype=tl.int32)) - - # Continue only if we are still inside the row and last bit was 1 - cond = (bit_off_i32 < end_bit_i32) & is_one - - return q_i32, bit_off_i32, hit_end_i32 - - @triton.jit def parse_header_kernel( u8_payload_ptr, # (total_bytes,) uint8 @@ -450,7 +359,6 @@ def parse_header_kernel( ): """ Parse the global header entirely on GPU. - Layout: 0..3 : "CGRP" 4..7 : C (uint32 LE) @@ -549,89 +457,72 @@ def parse_header_kernel( @triton.jit -def scan_rows_kernel( +def parse_row_table_kernel( u8_payload_ptr, # (total_bytes,) uint8 - row_bit_offsets_ptr, # (num_rows,) int32 (bit offset of 16-bit length) row_payload_bytes_ptr, # (num_rows,) int32 best_B_idx_ptr, # (num_rows,) int32 - use_bitmap_ptr, # (num_rows,) int32 (0/1) - header_end_bit: tl.int32, - total_bits: tl.int32, + use_bitmap_ptr, # (num_rows,) int32 + row_table_start: tl.int32, num_rows: tl.int32, ROW_HEADER_BITS: tl.constexpr, ): """ - Sequential scan of all rows (1 program). For each row r: - - bit_off: bit offset of 16-bit payload length - length: row_payload_bytes[r] - header: ((b_idx << 1) | use_bitmap) in ROW_HEADER_BITS bits - rest: payload_bits - ROW_HEADER_BITS bits - - Assumes the bitstream is valid and has enough bits for num_rows rows. + Parse the row table: + + For each row r: + offset = row_table_start + r * 3 + length_bytes[r] = uint16 LE at offset + header_byte = uint8 at offset + 2 + header_bits = header_byte & ((1 << ROW_HEADER_BITS) - 1) + use_bitmap[r] = header_bits & 1 + best_B_idx[r] = header_bits >> 1 """ pid = tl.program_id(0) - if pid != 0: + if pid >= num_rows: return - bit_off_i32 = header_end_bit - r = tl.full((), 0, dtype=tl.int32) - SIXTEEN_I32 = tl.full((), 16, dtype=tl.int32) + entry_offset = row_table_start + pid * 3 - while r < num_rows: - # bit offset at the start of the 16-bit length for this row - tl.store(row_bit_offsets_ptr + r, bit_off_i32) + # length_bytes: uint16 LE + b0 = tl.load(u8_payload_ptr + entry_offset).to(tl.int32) + b1 = tl.load(u8_payload_ptr + entry_offset + 1).to(tl.int32) + length_i32 = b0 | (b1 << 8) + tl.store(row_payload_bytes_ptr + pid, length_i32) - # read 16-bit payload length (bytes) - length_u32, bit_off_after_len = read_nbits_triton( - u8_payload_ptr, bit_off_i32, SIXTEEN_I32, total_bits - ) - length_i32 = length_u32.to(tl.int32) - tl.store(row_payload_bytes_ptr + r, length_i32) + # header byte + header_byte = tl.load(u8_payload_ptr + entry_offset + 2).to(tl.int32) + header_mask = (tl.full((), 1, dtype=tl.int32) << ROW_HEADER_BITS) - 1 + header_i32 = header_byte & header_mask - # read row header bits: ((b_idx << 1) | use_bitmap) - header_u32, bit_off_after_header = read_nbits_triton( - u8_payload_ptr, - bit_off_after_len, - tl.full((), ROW_HEADER_BITS, dtype=tl.int32), - total_bits, - ) - header_i32 = header_u32.to(tl.int32) - use_bitmap_i32 = header_i32 & 1 - best_B_idx_i32 = header_i32 >> 1 + use_bitmap_i32 = header_i32 & 1 + best_B_idx_i32 = header_i32 >> 1 - tl.store(best_B_idx_ptr + r, best_B_idx_i32) - tl.store(use_bitmap_ptr + r, use_bitmap_i32) + tl.store(use_bitmap_ptr + pid, use_bitmap_i32) + tl.store(best_B_idx_ptr + pid, best_B_idx_i32) - # skip remainder of this row's payload - payload_bits_i32 = length_i32 * 8 - rem_bits_i32 = payload_bits_i32 - ROW_HEADER_BITS - bit_off_i32 = bit_off_after_header + rem_bits_i32 - r += 1 @triton.jit def decode_rows_kernel( u8_payload_ptr, # (total_bytes,) uint8 out_vals_ptr, # (num_rows * K,) int32 - row_bit_offsets_ptr, # (num_rows,) int32 (bit offset of 16-bit length) + row_bit_offsets_ptr, # (num_rows,) int32 (bit offset of first encoded bit) row_payload_bytes_ptr, # (num_rows,) int32 best_B_idx_ptr, # (num_rows,) int32 use_bitmap_ptr, # (num_rows,) int32 k_rice_choices_ptr, # (num_B,) int32 num_rows: tl.int32, K: tl.int32, - ROW_HEADER_BITS: tl.constexpr, ): """ Fully GPU decode of Rice/bitmap rows. - For each row: - - Start at bit offset of 16-bit length - - Skip 16-bit length - - Skip header bits (we already know b_idx/use_bitmap from scan) + For each row r: + - Bit range: + start_bit = row_bit_offsets[r] + end_bit = start_bit + row_payload_bytes[r] * 8 - First value: full Rice (unary + remainder) - - Tail: Rice or bitmap+remainder + - Tail: Rice or bitmap+remainder depending on use_bitmap[r]. """ row_idx = tl.program_id(0) if row_idx >= num_rows: @@ -648,20 +539,13 @@ def decode_rows_kernel( M_i32 = (tl.full((), 1, dtype=tl.int32) << k_rice_i32) # Bit range of this row - bit_after_len_i32 = row_start_bit_i32 + 16 - row_end_bit_i32 = bit_after_len_i32 + payload_bytes_i32 * 8 - - # Skip header bits (we already know the contents) - header_dummy_u32, bit_off_i32 = read_nbits_triton( - u8_payload_ptr, - bit_after_len_i32, - tl.full((), ROW_HEADER_BITS, dtype=tl.int32), - row_end_bit_i32, # limit = end of this row - ) + row_end_bit_i32 = row_start_bit_i32 + payload_bytes_i32 * 8 base_out = row_idx * K ONE_I32 = tl.full((), 1, dtype=tl.int32) + bit_off_i32 = row_start_bit_i32 + # ---- first value: ALWAYS full Rice ---- if K > 0: q0_i32, bit_off_i32, hit_end0_i32 = read_unary_bounded_triton( @@ -669,7 +553,7 @@ def decode_rows_kernel( bit_off_i32, row_end_bit_i32, ) - r0_u32, bit_off_i32 = read_nbits_triton( + r0_u32, bit_off_i32 = read_nbits_fast( u8_payload_ptr, bit_off_i32, k_rice_i32, @@ -684,7 +568,7 @@ def decode_rows_kernel( while i < K: if use_bitmap_i32 != 0: # Bitmap mode: q is 1 bit in {0,1} - q_bit_u32, bit_off_i32 = read_nbits_triton( + q_bit_u32, bit_off_i32 = read_nbits_fast( u8_payload_ptr, bit_off_i32, ONE_I32, @@ -692,7 +576,7 @@ def decode_rows_kernel( ) q_i32 = q_bit_u32.to(tl.int32) - r_u32, bit_off_i32 = read_nbits_triton( + r_u32, bit_off_i32 = read_nbits_fast( u8_payload_ptr, bit_off_i32, k_rice_i32, @@ -706,7 +590,7 @@ def decode_rows_kernel( bit_off_i32, row_end_bit_i32, ) - r_u32, bit_off_i32 = read_nbits_triton( + r_u32, bit_off_i32 = read_nbits_fast( u8_payload_ptr, bit_off_i32, k_rice_i32, @@ -747,12 +631,10 @@ def decode_batch_rows( empty = torch.empty((0, 0), dtype=torch.int64, device=dev) return empty, 0, 0 - total_bits = total_bytes * 8 - - # --- 1) Parse global header on GPU (now also gets num_rows = R) --- + # --- 1) Parse global header on GPU --- C_out = torch.empty(1, dtype=torch.int32, device=dev) K_out = torch.empty(1, dtype=torch.int32, device=dev) - R_out = torch.empty(1, dtype=torch.int32, device=dev) # NEW + R_out = torch.empty(1, dtype=torch.int32, device=dev) num_B_out = torch.empty(1, dtype=torch.int32, device=dev) B_choices_out = torch.empty(max_num_B, dtype=torch.int32, device=dev) header_bytes_out = torch.empty(1, dtype=torch.int32, device=dev) @@ -778,11 +660,10 @@ def decode_batch_rows( C = int(C_out.cpu().item()) K = int(K_out.cpu().item()) - num_rows = int(R_out.cpu().item()) # NEW + num_rows = int(R_out.cpu().item()) num_B = int(num_B_out.cpu().item()) header_bytes = int(header_bytes_out.cpu().item()) B_choices_list = [int(x) for x in B_choices_out[:num_B].cpu().tolist()] - header_end_bit = header_bytes * 8 # --- 2) Build k_rice choices on CPU -> move to GPU --- k_rice_choices = [] @@ -798,25 +679,49 @@ def decode_batch_rows( B_choice_bits = (num_B - 1).bit_length() ROW_HEADER_BITS = 1 + B_choice_bits - # --- 3) Scan rows on GPU to get per-row metadata --- - row_bit_offsets = torch.empty(num_rows, dtype=torch.int32, device=dev) + # --- 3) Parse row table on GPU --- + row_entry_bytes = 3 + row_table_bytes = num_rows * row_entry_bytes + if header_bytes + row_table_bytes > total_bytes: + raise ValueError("Truncated payload: row table exceeds payload length") + row_payload_bytes = torch.empty(num_rows, dtype=torch.int32, device=dev) best_B_idx = torch.empty(num_rows, dtype=torch.int32, device=dev) use_bitmap = torch.empty(num_rows, dtype=torch.int32, device=dev) - scan_rows_kernel[(1,)]( + parse_row_table_kernel[(num_rows,)]( payload_gpu, - row_bit_offsets, row_payload_bytes, best_B_idx, use_bitmap, - header_end_bit, - int(total_bits), + int(header_bytes), int(num_rows), ROW_HEADER_BITS=ROW_HEADER_BITS, ) - # --- 4) Decode rows in parallel on GPU --- + # --- 4) Compute per-row bit offsets into payload region --- + payload_region_start_byte = header_bytes + row_table_bytes + if payload_region_start_byte > total_bytes: + raise ValueError("Truncated payload: missing payload region") + + # byte offsets within the payload region + row_payload_bytes_64 = row_payload_bytes.to(torch.int64) + row_byte_offsets = torch.cumsum(row_payload_bytes_64, dim=0) - row_payload_bytes_64 + + # Sanity check: last row must end within the buffer + last_end = int( + payload_region_start_byte + + row_byte_offsets[-1].item() + + row_payload_bytes_64[-1].item() + ) + if last_end > total_bytes: + raise ValueError("Truncated payload: row payload bytes exceed buffer length") + + row_bit_offsets = ( + payload_region_start_byte + row_byte_offsets + ).to(torch.int32) * 8 + + # --- 5) Decode rows in parallel on GPU --- out_vals = torch.empty((num_rows, K), dtype=torch.int32, device=dev) decode_rows_kernel[(num_rows,)]( payload_gpu, @@ -828,29 +733,8 @@ def decode_batch_rows( k_rice_choices_tensor, int(num_rows), int(K), - ROW_HEADER_BITS=ROW_HEADER_BITS, ) - # --- undo delta on-GPU if needed --- out_vals = torch.cumsum(out_vals, dim=1) return out_vals.to(torch.int64), C, num_rows - -if __name__ == "__main__": - torch.manual_seed(0) - ROWS, K = 32, 16 - COLS = 4096 - - x = torch.randn((ROWS, COLS), dtype=torch.float32, device="cuda") - idx = torch.topk(x.abs(), k=K, dim=-1, largest=True, sorted=False).indices - idx, _ = torch.sort(idx, dim=1) - payload, _ = encode_batch_rows(idx, C=COLS, B_choices=(64, 128, 256)) - decoded, _, _ = decode_batch_rows(payload) - ok = True - for r in range(ROWS): - if not torch.equal(torch.tensor(decoded[r]), idx[r]): - ok = False - print("Mismatch row", r) - print("orig:", idx[r].tolist()) - print("dec :", decoded[r]) - print("Round-trip OK" if ok else "Round-trip MISMATCH") From 3d406e1155d0d7639268068bbf41254d42ed5630 Mon Sep 17 00:00:00 2001 From: Kasper Date: Mon, 17 Nov 2025 18:37:07 +0400 Subject: [PATCH 09/33] track compression time --- src/tplr/neurons.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/tplr/neurons.py b/src/tplr/neurons.py index 4f9bcdc98..1c3d5da97 100644 --- a/src/tplr/neurons.py +++ b/src/tplr/neurons.py @@ -154,6 +154,7 @@ def barrier(group=None): # --- 4) Encode & compress (owner only) --- encoded = miner.transformer.encode(error_feedback, use_dct=use_dct) + compress_start = tplr.T() idxs, vals, xshape, totalk, quant_params = miner.compressor.compress( encoded, topk ) @@ -164,6 +165,8 @@ def barrier(group=None): decompressed = miner.compressor.decompress( p, idxs, vals, xshape, totalk, quant_params ) + compression_time = tplr.T() - compress_start + tplr.logger.info(f"Compression time: {compression_time}") # --- 6) Decode & error-feedback update (owner only) --- transmit_grad = miner.transformer.decode(decompressed, use_dct=use_dct) From d61e8305113cb023acbd2eac6cac50c30a51183c Mon Sep 17 00:00:00 2001 From: Kasper Date: Mon, 17 Nov 2025 21:43:36 +0400 Subject: [PATCH 10/33] trace compression time --- src/tplr/neurons.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/tplr/neurons.py b/src/tplr/neurons.py index 1c3d5da97..cee53ed7d 100644 --- a/src/tplr/neurons.py +++ b/src/tplr/neurons.py @@ -110,7 +110,7 @@ def barrier(group=None): miner.error_feedback[n] = miner.error_feedback[n].to( param.device, non_blocking=True ) - + compression_time = 0 for _, (n, p) in enumerate(model_iterator, 1): owned = n in miner.owned_params p_is_dt = is_dtensor(p) @@ -165,8 +165,7 @@ def barrier(group=None): decompressed = miner.compressor.decompress( p, idxs, vals, xshape, totalk, quant_params ) - compression_time = tplr.T() - compress_start - tplr.logger.info(f"Compression time: {compression_time}") + compression_time += tplr.T() - compress_start # --- 6) Decode & error-feedback update (owner only) --- transmit_grad = miner.transformer.decode(decompressed, use_dct=use_dct) @@ -204,6 +203,8 @@ def barrier(group=None): # Clear per-param grad p.grad = None + tplr.logger.info(f"Compression time: {compression_time}") + # Batch offload all error feedback tensors to CPU with pinned memory for name in miner.error_feedback: if ( From 69a867a157bd0b5d3d57ab61490c8eac871bc387 Mon Sep 17 00:00:00 2001 From: Kasper Date: Tue, 18 Nov 2025 08:45:17 +0400 Subject: [PATCH 11/33] add timings --- src/tplr/neurons.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/tplr/neurons.py b/src/tplr/neurons.py index cee53ed7d..410a18268 100644 --- a/src/tplr/neurons.py +++ b/src/tplr/neurons.py @@ -111,6 +111,8 @@ def barrier(group=None): param.device, non_blocking=True ) compression_time = 0 + copy_time = 0 + encode_time = 0 for _, (n, p) in enumerate(model_iterator, 1): owned = n in miner.owned_params p_is_dt = is_dtensor(p) @@ -152,7 +154,9 @@ def barrier(group=None): error_feedback.add_(grad_full) # --- 4) Encode & compress (owner only) --- + encode_start = tplr.T() encoded = miner.transformer.encode(error_feedback, use_dct=use_dct) + encode_time += tplr.T() - encode_start compress_start = tplr.T() idxs, vals, xshape, totalk, quant_params = miner.compressor.compress( @@ -168,15 +172,18 @@ def barrier(group=None): compression_time += tplr.T() - compress_start # --- 6) Decode & error-feedback update (owner only) --- + encode_start = tplr.T() transmit_grad = miner.transformer.decode(decompressed, use_dct=use_dct) del decompressed error_feedback.sub_(transmit_grad) # Keep error feedback on GPU for now, batch offload later miner.error_feedback[n] = error_feedback del transmit_grad, error_feedback + encode_time += tplr.T() - encode_start # --- 7) Pack outputs (move compressed artifacts to CPU asynchronously) --- # Using non_blocking=True for async D2H transfers when CUDA is available + copy_start = tplr.T() if isinstance(idxs, torch.Tensor): if torch.cuda.is_available(): cpu_idxs = torch.empty_like(idxs, device="cpu", pin_memory=True) @@ -202,8 +209,9 @@ def barrier(group=None): # Clear per-param grad p.grad = None + copy_time += tplr.T() - copy_start - tplr.logger.info(f"Compression time: {compression_time}") + tplr.logger.info(f"times: {encode_time}, {compression_time}, {copy_time}") # Batch offload all error feedback tensors to CPU with pinned memory for name in miner.error_feedback: From 4c81409fa1515906df4a20a1907ffae2ef351219 Mon Sep 17 00:00:00 2001 From: Kasper Date: Tue, 18 Nov 2025 11:00:23 +0400 Subject: [PATCH 12/33] add fallback for batch decompress --- src/tplr/compress.py | 27 ++++++++++++++++----------- src/tplr/compression/__init__.py | 1 - 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/src/tplr/compress.py b/src/tplr/compress.py index de3b7e9af..ca3c71cb5 100644 --- a/src/tplr/compress.py +++ b/src/tplr/compress.py @@ -31,7 +31,7 @@ import tplr -from tplr.compression import encode_batch_rows, decode_batch_rows +from tplr.compression import encode_batch_rows, decode_batch_rows, unpack_12bit_indices # ─────────── type aliases ──────────────────────────────────────────────── # primitive shapes @@ -489,16 +489,21 @@ def batch_decompress( for i, i_data in enumerate(idx_list): v_data = val_list[i] if i_data.dtype == torch.uint8: - rows, C, _N = decode_batch_rows(i_data.detach().cpu().numpy().tobytes()) - if C != totalk: - raise ValueError(f"Index payload C={C} but expected {totalk}") - if any(len(r) != v_data.shape[-1] for r in rows): - raise ValueError( - "Row-wise topk size mismatch in index payload (batch)" - ) - idx_unpacked = torch.tensor( - rows, dtype=torch.int64, device=p.device - ).view(*v_data.shape) + try: + rows, C, _N = decode_batch_rows(i_data.detach().cpu().numpy().tobytes()) + if C != totalk: + raise ValueError(f"Index payload C={C} but expected {totalk}") + if any(len(r) != v_data.shape[-1] for r in rows): + raise ValueError( + "Row-wise topk size mismatch in index payload (batch)" + ) + idx_unpacked = torch.tensor( + rows, dtype=torch.int64, device=p.device + ).view(*v_data.shape) + except ValueError as e: + # Fallback: likely old format -> try legacy decoder + idx_unpacked = unpack_12bit_indices(i_data.to(p.device), v_data.shape) + unpacked_indices.append(idx_unpacked) elif i_data.dtype in (torch.int64, torch.long): idx_unpacked = i_data.to(p.device) diff --git a/src/tplr/compression/__init__.py b/src/tplr/compression/__init__.py index 3f0f3a875..0c9666af6 100644 --- a/src/tplr/compression/__init__.py +++ b/src/tplr/compression/__init__.py @@ -26,7 +26,6 @@ unpack_12bit_indices ) __all__ = [ - # High level "encode_batch_rows", "decode_batch_rows", "pack_12bit_indices", From 4fc3e5902840c2e6f20075ec2ed36b0b69f09bc9 Mon Sep 17 00:00:00 2001 From: Kasper Date: Wed, 19 Nov 2025 11:23:36 +0400 Subject: [PATCH 13/33] fix legacy path for decompress --- src/tplr/compress.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tplr/compress.py b/src/tplr/compress.py index ca3c71cb5..c903f044e 100644 --- a/src/tplr/compress.py +++ b/src/tplr/compress.py @@ -390,7 +390,8 @@ def decompress( if val.dtype != x.dtype: val = val.to(dtype=x.dtype) - if len(xshape) > 2: + # second condition for legacy decompress + if len(xshape) > 2 and len(idx_int64) == 2: idx_int64 = rearrange(idx_int64, "(y x) h -> y x h", y=xshape[0]) val = rearrange(val, "(y x) h -> y x h", y=xshape[0]) From de3ae06eaaf5fca4c47e052a81dafd17cb3fbaca Mon Sep 17 00:00:00 2001 From: Kasper Date: Wed, 19 Nov 2025 12:02:38 +0400 Subject: [PATCH 14/33] adjust condition --- src/tplr/compress.py | 2 +- src/tplr/compression/bitops.py | 1 - src/tplr/compression/hybrid.py | 18 +++++++++--------- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/tplr/compress.py b/src/tplr/compress.py index c903f044e..2bbba0844 100644 --- a/src/tplr/compress.py +++ b/src/tplr/compress.py @@ -391,7 +391,7 @@ def decompress( val = val.to(dtype=x.dtype) # second condition for legacy decompress - if len(xshape) > 2 and len(idx_int64) == 2: + if len(xshape) > 2 and len(idx_int64) > 1: idx_int64 = rearrange(idx_int64, "(y x) h -> y x h", y=xshape[0]) val = rearrange(val, "(y x) h -> y x h", y=xshape[0]) diff --git a/src/tplr/compression/bitops.py b/src/tplr/compression/bitops.py index 96b1c2d88..6302e5692 100644 --- a/src/tplr/compression/bitops.py +++ b/src/tplr/compression/bitops.py @@ -246,5 +246,4 @@ def read_unary_bounded_triton(u8_ptr, bit_off_i32, end_bit_i32): # Continue only if we are still inside the row and last bit was 1 cond = (bit_off_i32 < end_bit_i32) & is_one - return q_i32, bit_off_i32, hit_end_i32 \ No newline at end of file diff --git a/src/tplr/compression/hybrid.py b/src/tplr/compression/hybrid.py index 0d2dc5777..abc7bcc24 100644 --- a/src/tplr/compression/hybrid.py +++ b/src/tplr/compression/hybrid.py @@ -14,7 +14,7 @@ @torch.no_grad() def encode_batch_rows( - idx: torch.Tensor, + idx_sorted: torch.Tensor, *, C: int, B_choices: Tuple[int, ...] = (64, 128) @@ -46,8 +46,8 @@ def encode_batch_rows( if not torch.cuda.is_available(): raise RuntimeError("CUDA is required for this function.") - if not isinstance(idx, torch.Tensor) or idx.ndim != 2: - raise ValueError(f"idx must be a 2D int64 tensor, got {idx.shape} {idx.dtype}") + if not isinstance(idx_sorted, torch.Tensor) or idx_sorted.ndim != 2: + raise ValueError(f"idx must be a 2D int64 tensor, got {idx_sorted.shape} {idx_sorted.dtype}") if not all(isinstance(b, int) and (b & (b - 1) == 0) and b > 0 for b in B_choices): raise ValueError(f"All B_choices must be powers of two, got {B_choices}") @@ -55,7 +55,7 @@ def encode_batch_rows( if not all(C % b == 0 for b in B_choices): raise ValueError(f"All B_choices must evenly divide C={C}, got {B_choices}") - num_rows, k_dim = idx.shape + num_rows, k_dim = idx_sorted.shape if num_rows == 0: return b"", { "total_bits": 0, @@ -63,13 +63,13 @@ def encode_batch_rows( "B_hist": {b: 0 for b in B_choices} } - if not idx.is_cuda: - idx = idx.cuda() - idx = idx.contiguous() - dev = idx.device + if not idx_sorted.is_cuda: + idx_sorted = idx_sorted.cuda() + idx_sorted = idx_sorted.contiguous() + dev = idx_sorted.device vals = torch.cat( - (idx[:, :1], idx[:, 1:] - idx[:, :-1]), + (idx_sorted[:, :1], idx_sorted[:, 1:] - idx_sorted[:, :-1]), dim=1, ) From 37b240fa5210aec085f28c4adaa4f24ee894e0d0 Mon Sep 17 00:00:00 2001 From: Kasper Date: Wed, 19 Nov 2025 12:28:04 +0400 Subject: [PATCH 15/33] fix legacy condition --- src/tplr/compress.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tplr/compress.py b/src/tplr/compress.py index 2bbba0844..b52094fc3 100644 --- a/src/tplr/compress.py +++ b/src/tplr/compress.py @@ -391,7 +391,7 @@ def decompress( val = val.to(dtype=x.dtype) # second condition for legacy decompress - if len(xshape) > 2 and len(idx_int64) > 1: + if len(xshape) > 2 and len(xshape) != len(idx_int64): idx_int64 = rearrange(idx_int64, "(y x) h -> y x h", y=xshape[0]) val = rearrange(val, "(y x) h -> y x h", y=xshape[0]) From 56c8d7a02e442e151e85c1eb0c9ef96a7cb82e72 Mon Sep 17 00:00:00 2001 From: Kasper Date: Wed, 19 Nov 2025 13:01:17 +0400 Subject: [PATCH 16/33] added register buffering, fast unary decoding and simple header parsing --- src/tplr/compress.py | 2 +- src/tplr/compression/bitops.py | 249 ------------ src/tplr/compression/hybrid.py | 700 +++++++++++++-------------------- 3 files changed, 280 insertions(+), 671 deletions(-) delete mode 100644 src/tplr/compression/bitops.py diff --git a/src/tplr/compress.py b/src/tplr/compress.py index b52094fc3..b6cc544be 100644 --- a/src/tplr/compress.py +++ b/src/tplr/compress.py @@ -391,7 +391,7 @@ def decompress( val = val.to(dtype=x.dtype) # second condition for legacy decompress - if len(xshape) > 2 and len(xshape) != len(idx_int64): + if len(xshape) > 2 and len(x) != len(idx_int64): idx_int64 = rearrange(idx_int64, "(y x) h -> y x h", y=xshape[0]) val = rearrange(val, "(y x) h -> y x h", y=xshape[0]) diff --git a/src/tplr/compression/bitops.py b/src/tplr/compression/bitops.py deleted file mode 100644 index 6302e5692..000000000 --- a/src/tplr/compression/bitops.py +++ /dev/null @@ -1,249 +0,0 @@ -from typing import Union - -import numpy as np -import torch -import triton -import triton.language as tl - -BytesLike = Union[bytes, bytearray, np.ndarray, torch.Tensor] - -@triton.jit -def write_nbits( - u8_ptr, # uint8* global buffer - bit_off_i32, # scalar tl.int32 bit offset - value_u32, # scalar tl.uint32, up to 32 bits used - nbits_i32, # scalar tl.int32, number of bits to write -): - """ - Writes `nbits_i32` least-significant bits of `value_u32` into `u8_ptr` - starting at bit offset `bit_off_i32` in LSB-first order. - - This is still a bit-at-a-time writer; higher-level kernels have been - adjusted to use int32 + shift/mask ahead of time. - """ - j = tl.full((), 0, dtype=tl.int32) - ONE_U32 = tl.full((), 1, dtype=tl.uint32) - - while j < nbits_i32: - pos = bit_off_i32 + j - byte_idx = (pos >> 3).to(tl.int32) - bit_idx = (pos & 7).to(tl.int32) - - old_u8 = tl.load(u8_ptr + byte_idx) - old_u32 = old_u8.to(tl.uint32) - - vbit = (value_u32 >> j) & ONE_U32 - mask = ONE_U32 << bit_idx - new_u32 = (old_u32 & (~mask)) | (vbit << bit_idx) - tl.store(u8_ptr + byte_idx, new_u32.to(tl.uint8)) - j += 1 - return bit_off_i32 + nbits_i32 - - -@triton.jit -def write_nbits_fast( - u8_ptr, - bit_off_i32, # start bit - value_u32, # LSB-first payload bits - nbits_i32, # 0..32 -): - # If nothing to write - if nbits_i32 <= 0: - return bit_off_i32 - - start_bit = bit_off_i32 - first_byte = (start_bit >> 3).to(tl.int32) - first_bit = (start_bit & 7).to(tl.int32) - - # How many bits fit in the first byte - bits_in_first = tl.minimum( - nbits_i32, - tl.full((), 8, dtype=tl.int32) - first_bit, - ) - - # -------- leading partial byte -------- - if bits_in_first > 0: - old_u8 = tl.load(u8_ptr + first_byte).to(tl.uint32) - - # mask for the bits we overwrite inside that byte - mask_u32 = ((tl.full((), 1, tl.uint32) << bits_in_first) - 1) \ - << first_bit - - # extract those bits from value_u32 - bits_u32 = (value_u32 & ((tl.full((), 1, tl.uint32) << bits_in_first) - 1)) \ - << first_bit - - new_u8 = ((old_u8 & ~mask_u32) | bits_u32).to(tl.uint8) - tl.store(u8_ptr + first_byte, new_u8) - - bit_off_i32 += bits_in_first - value_u32 >>= bits_in_first - nbits_i32 -= bits_in_first - - # Now bit_off_i32 is byte aligned (or nbits_i32 == 0) - if nbits_i32 <= 0: - return bit_off_i32 - - cur_byte = (bit_off_i32 >> 3).to(tl.int32) - - # full bytes we can write - full_bytes = (nbits_i32 >> 3).to(tl.int32) # nbits_i32 // 8 - rem_bits = (nbits_i32 & 7).to(tl.int32) - - # -------- full bytes -------- - jb = tl.full((), 0, dtype=tl.int32) - while jb < full_bytes: - # take lowest 8 bits from value_u32 - byte_val = (value_u32 & tl.full((), 0xFF, tl.uint32)).to(tl.uint8) - tl.store(u8_ptr + cur_byte + jb, byte_val) - value_u32 >>= 8 - jb += 1 - - bit_off_i32 += full_bytes * 8 - - # -------- trailing partial byte -------- - if rem_bits > 0: - byte_idx = (bit_off_i32 >> 3).to(tl.int32) - old_u8 = tl.load(u8_ptr + byte_idx).to(tl.uint32) - - mask_u32 = ( (tl.full((), 1, tl.uint32) << rem_bits) - 1 ) - bits_u32 = ( value_u32 & mask_u32 ) - - new_u8 = ((old_u8 & ~mask_u32) | bits_u32).to(tl.uint8) - tl.store(u8_ptr + byte_idx, new_u8) - - bit_off_i32 += rem_bits - return bit_off_i32 - - -@triton.jit -def read_nbits(u8_ptr, bit_off_i32, nbits_i32, limit_bit_i32): - """ - GPU version of BitStreamReader.read_bits (LSB-first), but bounds-safe. - - Reads `nbits_i32` bits starting at `bit_off_i32`, but never loads beyond - bit index `limit_bit_i32` (masked loads return 0 out-of-bounds). - - Returns: (value_u32, new_bit_off_i32) - """ - j = tl.full((), 0, dtype=tl.int32) - val_u32 = tl.full((), 0, dtype=tl.uint32) - ONE_U32 = tl.full((), 1, dtype=tl.uint32) - ZERO_U8 = tl.full((), 0, dtype=tl.uint8) - - while j < nbits_i32: - pos = bit_off_i32 + j - in_bounds = pos < limit_bit_i32 - - byte_idx = (pos >> 3).to(tl.int32) - bit_idx = (pos & 7).to(tl.int32) - - # Masked load: if in_bounds==0, we load ZERO_U8 instead of touching memory. - u8 = tl.load(u8_ptr + byte_idx, mask=in_bounds, other=ZERO_U8) - u32 = u8.to(tl.uint32) - bit = (u32 >> bit_idx) & ONE_U32 - - val_u32 |= (bit << j) - j += 1 - - new_bit_off = bit_off_i32 + nbits_i32 - return val_u32, new_bit_off - - -@triton.jit -def read_nbits_fast(u8_ptr, bit_off_i32, nbits_i32, limit_bit_i32): - if nbits_i32 <= 0: - return tl.full((), 0, tl.uint32), bit_off_i32 - - # clamp to limit if you want to keep the defensive behavior - max_bits = limit_bit_i32 - bit_off_i32 - nbits_i32 = tl.minimum(nbits_i32, max_bits) - - start_bit = bit_off_i32 - end_bit = bit_off_i32 + nbits_i32 - - first_byte = (start_bit >> 3).to(tl.int32) - first_bit = (start_bit & 7).to(tl.int32) - - bits_in_first = tl.minimum( - nbits_i32, - tl.full((), 8, dtype=tl.int32) - first_bit, - ) - - val_u32 = tl.full((), 0, dtype=tl.uint32) - shift = tl.full((), 0, dtype=tl.int32) - - # -------- leading partial byte -------- - if bits_in_first > 0: - byte = tl.load(u8_ptr + first_byte).to(tl.uint32) - mask = ((tl.full((), 1, tl.uint32) << bits_in_first) - 1) << first_bit - chunk = (byte & mask) >> first_bit - val_u32 |= (chunk << shift) - - bit_off_i32 += bits_in_first - shift += bits_in_first - nbits_i32 -= bits_in_first - - if nbits_i32 <= 0: - return val_u32, bit_off_i32 - - cur_byte = (bit_off_i32 >> 3).to(tl.int32) - full_bytes = (nbits_i32 >> 3).to(tl.int32) - rem_bits = (nbits_i32 & 7).to(tl.int32) - - # -------- full bytes -------- - jb = tl.full((), 0, dtype=tl.int32) - while jb < full_bytes: - byte = tl.load(u8_ptr + cur_byte + jb).to(tl.uint32) - val_u32 |= (byte << shift) - shift += 8 - jb += 1 - - bit_off_i32 += full_bytes * 8 - - # -------- trailing partial byte -------- - if rem_bits > 0: - byte = tl.load(u8_ptr + (bit_off_i32 >> 3).to(tl.int32)).to(tl.uint32) - mask = (tl.full((), 1, tl.uint32) << rem_bits) - 1 - chunk = byte & mask - val_u32 |= (chunk << shift) - bit_off_i32 += rem_bits - return val_u32, bit_off_i32 - - -@triton.jit -def read_unary_bounded_triton(u8_ptr, bit_off_i32, end_bit_i32): - """ - GPU version of BitStreamReader.read_unary_bounded(end_bit). - Reads '1's until a '0' or end_bit. - Returns: (q_i32, new_bit_off_i32, hit_end_i32) - - q_i32: number of 1s before the terminating 0 - - hit_end_i32: 1 if we reached end_bit without seeing 0 - 0 if we saw a terminating 0 - """ - ONE_U32 = tl.full((), 1, dtype=tl.uint32) - q_i32 = tl.full((), 0, dtype=tl.int32) - hit_end_i32 = tl.full((), 1, dtype=tl.int32) - - cond = bit_off_i32 < end_bit_i32 - while cond: - pos = bit_off_i32 - byte_idx = (pos >> 3).to(tl.int32) - bit_idx = (pos & 7).to(tl.int32) - - u8 = tl.load(u8_ptr + byte_idx) - u32 = u8.to(tl.uint32) - bit = (u32 >> bit_idx) & ONE_U32 - - bit_off_i32 += 1 - - is_one = (bit == ONE_U32) - q_i32 += is_one.to(tl.int32) - - # If bit is 0, we did NOT hit end - hit_end_i32 = tl.where(is_one, hit_end_i32, - tl.full((), 0, dtype=tl.int32)) - - # Continue only if we are still inside the row and last bit was 1 - cond = (bit_off_i32 < end_bit_i32) & is_one - return q_i32, bit_off_i32, hit_end_i32 \ No newline at end of file diff --git a/src/tplr/compression/hybrid.py b/src/tplr/compression/hybrid.py index abc7bcc24..273de408c 100644 --- a/src/tplr/compression/hybrid.py +++ b/src/tplr/compression/hybrid.py @@ -1,13 +1,12 @@ import math -from typing import Dict -from typing import Tuple, Union +import struct +from typing import Dict, Tuple, Union import numpy as np import torch import triton import triton.language as tl -from .bitops import write_nbits_fast, read_unary_bounded_triton, read_nbits_fast BytesLike = Union[bytes, bytearray, np.ndarray, torch.Tensor] @@ -20,7 +19,7 @@ def encode_batch_rows( B_choices: Tuple[int, ...] = (64, 128) ) -> Tuple[BytesLike, Dict]: """ - Compresses a 2D int64 tensor of Top-K indices into a byte string + Compresses a 2D sorted int tensor of Top-K indices into a byte string using a per-row adaptive Rice/Bitmap compression scheme on the GPU. Layout: @@ -68,29 +67,28 @@ def encode_batch_rows( idx_sorted = idx_sorted.contiguous() dev = idx_sorted.device + # delta-encoding vals = torch.cat( (idx_sorted[:, :1], idx_sorted[:, 1:] - idx_sorted[:, :-1]), dim=1, - ) - - # Cast to int32 for Triton kernels - vals = vals.to(torch.int32) + ).to(torch.int32) # k_rice parameters (log2(C // B)) k_rice_choices = tuple(int(math.log2(C // b)) for b in B_choices) num_B_choices = len(B_choices) k_rice_choices_tensor = torch.tensor(k_rice_choices, dtype=torch.int32, device=dev) - # Row header bits (only used for packing row-table header byte) + # Row header bits B_choice_bits = (num_B_choices - 1).bit_length() - ROW_HEADER_BITS = 1 + B_choice_bits # (best_B_idx << 1) | use_bitmap + ROW_HEADER_BITS = 1 + B_choice_bits # Output tensors for cost kernel costs = torch.empty((num_rows, num_B_choices), dtype=torch.int32, device=dev) is_bitmap = torch.empty((num_rows, num_B_choices), dtype=torch.int8, device=dev) + + # Calculate grid for cost kernel grid = (num_rows,) - # cost kernel: bits required for deltas only (no header bits) cost_kernel[grid]( vals, costs, @@ -105,18 +103,14 @@ def encode_batch_rows( min_costs, best_B_idx = torch.min(costs, dim=1) is_bitmap_choice = torch.gather(is_bitmap, 1, best_B_idx.unsqueeze(1)).squeeze(1).to(torch.int32) - # (1) payload bits per row = bits for deltas only - row_payload_bits = min_costs # (rows,) - - # (2) payload bytes per row (rounded up) - row_payload_bytes = ((row_payload_bits + 7) // 8).to(torch.int32) # (rows,) + # Payload sizing + row_payload_bits = min_costs + row_payload_bytes = ((row_payload_bits + 7) // 8).to(torch.int32) - # ensure fit in uint16 for the row table if torch.any(row_payload_bytes > 0xFFFF): raise ValueError("Row payload length exceeds 65535 bytes; cannot store in uint16.") - # byte offsets within the payload region (no gaps) - # row_byte_offsets[r] = sum_{i> 8) & 0xFF).to(torch.uint8) + # Vectorized row table construction + row_table_flat = torch.empty((num_rows, 3), dtype=torch.uint8, device=dev) + row_table_flat[:, 0] = (lengths_i32 & 0xFF).to(torch.uint8) + row_table_flat[:, 1] = ((lengths_i32 >> 8) & 0xFF).to(torch.uint8) + row_table_flat[:, 2] = (headers_i32 & ((1 << ROW_HEADER_BITS) - 1)).to(torch.uint8) - # Only the low ROW_HEADER_BITS bits are meaningful, but we just store the byte. - row_table[:, 2] = (headers_i32 & ((1 << ROW_HEADER_BITS) - 1)).to(torch.uint8) + payload_buf[global_header_len_bytes: global_header_len_bytes + row_table_bytes] = row_table_flat.view(-1) - payload_buf[ - global_header_len_bytes : global_header_len_bytes + row_table_bytes - ] = row_table.view(-1) + # Calculate absolute byte offsets for pack kernel + row_abs_byte_offsets = (payload_region_start + row_byte_offsets).to(torch.int32) - # compute bit offsets for each row's payload (no per-row length/header in-band) - row_bit_offsets = (payload_region_start + row_byte_offsets).to(torch.int32) * 8 - - # pack payloads + # Pack payloads (Optimized Kernel) pack_kernel[(num_rows,)]( vals, payload_buf, - row_bit_offsets, + row_abs_byte_offsets, best_B_idx.to(torch.int32), - is_bitmap_choice, # int32 0/1 + is_bitmap_choice, k_rice_choices_tensor, num_rows, k_dim=k_dim, ) - # meta + # Meta stats b_counts = torch.bincount(best_B_idx, minlength=len(B_choices)) B_hist = {b: c.item() for b, c in zip(B_choices, b_counts)} total_row_bytes = total_payload_bytes + row_entry_bytes * num_rows @@ -201,44 +191,36 @@ def encode_batch_rows( @triton.jit def cost_kernel( - delta_ptr, # (rows, k_dim) IN - costs_ptr, # (rows, num_B_choices) OUT - is_bitmap_ptr, # (rows, num_B_choices) OUT (bool/int) - k_dim: tl.constexpr, # constexpr for tl.arange + delta_ptr, + costs_ptr, + is_bitmap_ptr, + k_dim: tl.constexpr, num_rows: tl.int32, num_B_choices: tl.int32, - k_rice_choices_ptr, # (num_B_choices,) int32 + k_rice_choices_ptr, ): """ - Calculates the compressed bit cost for each row for each B in B_choices. - One program instance processes one row. - Variant B: first delta encoded with Rice, tail optionally bitmap (q in {0,1}). + Calculates bit cost. One row per program instance. """ row_idx = tl.program_id(0) if row_idx >= num_rows: return - # Lane indices for this row (constexpr width) i = tl.arange(0, k_dim) - - # Load entire row of delta-encoded values into SRAM row_base = row_idx * k_dim delta = tl.load(delta_ptr + row_base + i) delta0 = tl.load(delta_ptr + row_base) b_idx = 0 while b_idx < num_B_choices: - # k_rice and M = 1 << k_rice k_rice = tl.load(k_rice_choices_ptr + b_idx) - # q via shift, r via mask q = delta >> k_rice q0 = delta0 >> k_rice - # Pure Rice cost: sum(q + 1) + k_dim * k_rice rice_cost = tl.sum(q + 1) + k_dim * k_rice - # Bitmap cost: first element full Rice, tail has (1 + k_rice) bits + # Bitmap cost: head is Rice, tail is (1 + k_rice) bitmap_cost = (q0 + 1 + k_rice) + (k_dim - 1) * (1 + k_rice) # Allow bitmap only if tail q are in {0,1} @@ -250,246 +232,152 @@ def cost_kernel( out_offset = row_idx * num_B_choices + b_idx tl.store(costs_ptr + out_offset, min_cost) - # make sure is_bitmap is exactly 0/1 in memory - tl.store( - is_bitmap_ptr + out_offset, - tl.where(use_bitmap, 1, 0).to(tl.int32), - ) + tl.store(is_bitmap_ptr + out_offset, tl.where(use_bitmap, 1, 0).to(tl.int32)) b_idx += 1 @triton.jit def pack_kernel( - delta_ptr, # (rows, k_dim) IN int32 - u8_payload_ptr, # (final_buffer_bytes,) OUT uint8 - row_bit_offsets_ptr, # (rows,) IN int32 (bit offset where payload starts) - best_B_idx_ptr, # (rows,) IN int32 - is_bitmap_ptr, # (rows,) IN int32 (0/1) - k_rice_choices_ptr, # [num_B] IN int32 - num_rows: tl.int32, - k_dim: tl.int32, # dynamic + delta_ptr, # (rows, k_dim) IN int32 + u8_payload_ptr, # OUT uint8 + row_abs_byte_offsets_ptr, # (rows,) IN int32 (byte offset where payload starts) + best_B_idx_ptr, # (rows,) IN + is_bitmap_ptr, # (rows,) IN + k_rice_choices_ptr, # [num_B] IN + num_rows: tl.int32, + k_dim: tl.int32, # dynamic ): """ - Writes only the Rice/bitmap-coded payload bits for each row. - - Each program instance handles one row. Bit order is LSB-first. + Writes payload bits using a 64-bit register accumulator. + Modified to use unaligned byte stores to prevent cudaErrorMisalignedAddress. """ row_idx = tl.program_id(0) if row_idx >= num_rows: return - # Per-row meta - bit_off_i32 = tl.load(row_bit_offsets_ptr + row_idx).to(tl.int32) + # Load row params + out_byte_off = tl.load(row_abs_byte_offsets_ptr + row_idx).to(tl.int32) b_idx_i32 = tl.load(best_B_idx_ptr + row_idx).to(tl.int32) use_bitmap_i32 = (tl.load(is_bitmap_ptr + row_idx) & 1).to(tl.int32) - - # params k_rice_i32 = tl.load(k_rice_choices_ptr + b_idx_i32).to(tl.int32) M_i32 = (tl.full((), 1, dtype=tl.int32) << k_rice_i32) - ONE_U32 = tl.full((), 1, dtype=tl.uint32) - ZERO_U32 = tl.full((), 0, dtype=tl.uint32) - ONE_I32 = tl.full((), 1, dtype=tl.int32) - THIRTY_ONE_I32 = tl.full((), 31, dtype=tl.int32) - - base = row_idx * k_dim - - # ---- first delta: ALWAYS full Rice (unary + remainder) ---- - if k_dim > 0: - v0 = tl.load(delta_ptr + base).to(tl.int32) - q0 = (v0 >> k_rice_i32).to(tl.int32) - r0 = (v0 & (M_i32 - 1)).to(tl.int32) - - # q0 ones in chunks of <= 31, then a single 0 - q_left = q0 - while q_left > 0: - chunk = tl.minimum(q_left, THIRTY_ONE_I32) - ones = (ONE_U32 << chunk) - ONE_U32 - bit_off_i32 = write_nbits_fast(u8_payload_ptr, bit_off_i32, ones, chunk) - q_left -= chunk - - # terminating 0 bit - bit_off_i32 = write_nbits_fast(u8_payload_ptr, bit_off_i32, ZERO_U32, ONE_I32) - # remainder - bit_off_i32 = write_nbits_fast( - u8_payload_ptr, bit_off_i32, r0.to(tl.uint32), k_rice_i32 - ) + # Accumulator state + acc_data = tl.full((), 0, dtype=tl.uint64) + acc_bits = tl.full((), 0, dtype=tl.int32) + + # Output pointer (byte-aligned) + out_ptr_base = u8_payload_ptr + out_byte_off - # ---- tail deltas ---- - i = 1 + base_idx = row_idx * k_dim + + # ------------------------------------------------------------------ + # PROCESS LOOP + # ------------------------------------------------------------------ + i = 0 while i < k_dim: - v = tl.load(delta_ptr + base + i).to(tl.int32) - q = (v >> k_rice_i32).to(tl.int32) - r = (v & (M_i32 - 1)).to(tl.int32) - - # Rice unary only if NOT bitmap - q_left = tl.where(use_bitmap_i32 != 0, tl.full((), 0, dtype=tl.int32), q) - while q_left > 0: - chunk = tl.minimum(q_left, THIRTY_ONE_I32) - ones = (ONE_U32 << chunk) - ONE_U32 - bit_off_i32 = write_nbits_fast(u8_payload_ptr, bit_off_i32, ones, chunk) - q_left -= chunk - - # terminating 0 bit only in full-Rice mode - n_term = tl.where(use_bitmap_i32 != 0, tl.full((), 0, dtype=tl.int32), ONE_I32) - bit_off_i32 = write_nbits_fast(u8_payload_ptr, bit_off_i32, ZERO_U32, n_term) - - # bitmap q only if bitmap - q_bit = tl.where(q > 0, ONE_U32, ZERO_U32) - n_qbit = tl.where(use_bitmap_i32 != 0, ONE_I32, tl.full((), 0, dtype=tl.int32)) - bit_off_i32 = write_nbits_fast(u8_payload_ptr, bit_off_i32, q_bit, n_qbit) - - # remainder always - bit_off_i32 = write_nbits_fast(u8_payload_ptr, bit_off_i32, r.to(tl.uint32), k_rice_i32) - i += 1 + val = tl.load(delta_ptr + base_idx + i).to(tl.int32) + # Compute q, r + q = (val >> k_rice_i32).to(tl.uint64) + r = (val & (M_i32 - 1)).to(tl.uint64) -@triton.jit -def parse_header_kernel( - u8_payload_ptr, # (total_bytes,) uint8 - C_out_ptr, # (1,) int32 - K_out_ptr, # (1,) int32 - R_out_ptr, # (1,) int32 NEW: num_rows - num_B_out_ptr, # (1,) int32 - B_choices_out_ptr, # (MAX_B_CHOICES,) int32 - header_bytes_out_ptr, # (1,) int32 - error_flag_ptr, # (1,) int32 - total_bytes: tl.int32, - MAX_B_CHOICES: tl.constexpr, -): - """ - Parse the global header entirely on GPU. - Layout: - 0..3 : "CGRP" - 4..7 : C (uint32 LE) - 8..9 : K (uint16 LE) - 10..13 : R (uint32 LE, num_rows) - 14 : num_B (uint8) - 15.. : B_choices (num_B * 2 bytes, uint16 LE) - """ + is_rice = (i == 0) | (use_bitmap_i32 == 0) - pid = tl.program_id(0) - if pid != 0: - return + if is_rice: + # Rice: q '1's, then '0', then k_rice bits of r + + # Append Unary (q ones) + q_count = q.to(tl.int32) + while q_count > 0: + acc_data |= (tl.full((), 1, dtype=tl.uint64) << acc_bits) + acc_bits += 1 + q_count -= 1 + + # Flush Check + if acc_bits >= 32: + # Unaligned Store (4 bytes separately) + val_u32 = acc_data.to(tl.uint32) + tl.store(out_ptr_base + 0, (val_u32 & 0xFF).to(tl.uint8)) + tl.store(out_ptr_base + 1, ((val_u32 >> 8) & 0xFF).to(tl.uint8)) + tl.store(out_ptr_base + 2, ((val_u32 >> 16) & 0xFF).to(tl.uint8)) + tl.store(out_ptr_base + 3, ((val_u32 >> 24) & 0xFF).to(tl.uint8)) + + out_ptr_base += 4 + acc_data >>= 32 + acc_bits -= 32 + + # Append Separator '0' + acc_bits += 1 - # ---- init outputs / error ---- - C_val = tl.full((), 0, dtype=tl.int32) - K_val = tl.full((), 0, dtype=tl.int32) - R_val = tl.full((), 0, dtype=tl.int32) - num_B_val = tl.full((), 0, dtype=tl.int32) - header_bytes_i32 = tl.full((), 0, dtype=tl.int32) - err = tl.full((), 0, dtype=tl.int32) - - # ---- basic size + magic checks ---- - # Minimum header size: 15 bytes (without B_choices) - if total_bytes < 15: - err = 1 - else: - # Magic "CGRP" = [67, 71, 82, 80] - m0 = tl.load(u8_payload_ptr + 0) - m1 = tl.load(u8_payload_ptr + 1) - m2 = tl.load(u8_payload_ptr + 2) - m3 = tl.load(u8_payload_ptr + 3) - cond_magic = (m0 == 67) & (m1 == 71) & (m2 == 82) & (m3 == 80) - bad_magic = cond_magic == 0 - err = tl.where(bad_magic, tl.full((), 2, dtype=tl.int32), err) - - # ---- C, K, R, num_B ---- - if err == 0: - # C (uint32 LE at bytes 4..7) - b4 = tl.load(u8_payload_ptr + 4).to(tl.int32) - b5 = tl.load(u8_payload_ptr + 5).to(tl.int32) - b6 = tl.load(u8_payload_ptr + 6).to(tl.int32) - b7 = tl.load(u8_payload_ptr + 7).to(tl.int32) - C_val = b4 | (b5 << 8) | (b6 << 16) | (b7 << 24) - - # K (uint16 LE at bytes 8..9) - b8 = tl.load(u8_payload_ptr + 8).to(tl.int32) - b9 = tl.load(u8_payload_ptr + 9).to(tl.int32) - K_val = b8 | (b9 << 8) - - # R (uint32 LE at bytes 10..13) - b10 = tl.load(u8_payload_ptr + 10).to(tl.int32) - b11 = tl.load(u8_payload_ptr + 11).to(tl.int32) - b12 = tl.load(u8_payload_ptr + 12).to(tl.int32) - b13 = tl.load(u8_payload_ptr + 13).to(tl.int32) - R_val = b10 | (b11 << 8) | (b12 << 16) | (b13 << 24) - - # num_B at byte 14 - num_B_val = tl.load(u8_payload_ptr + 14).to(tl.int32) - invalid_num_B = (num_B_val <= 0) | (num_B_val > MAX_B_CHOICES) - err = tl.where(invalid_num_B, tl.full((), 3, dtype=tl.int32), err) - - # ---- read B_choices in a structured loop (no break/return) ---- - off = tl.full((), 15, dtype=tl.int32) # B_choices start at byte 15 - i = tl.full((), 0, dtype=tl.int32) - - while i < MAX_B_CHOICES: - need_this = (i < num_B_val) & (err == 0) - - if need_this: - cond_in_bounds = (off + 1) < total_bytes - if cond_in_bounds: - lo = tl.load(u8_payload_ptr + off).to(tl.int32) - hi = tl.load(u8_payload_ptr + off + 1).to(tl.int32) - B_val = lo | (hi << 8) - tl.store(B_choices_out_ptr + i, B_val) - off += 2 - else: - err = tl.full((), 4, dtype=tl.int32) - tl.store(B_choices_out_ptr + i, tl.full((), 0, dtype=tl.int32)) else: - tl.store(B_choices_out_ptr + i, tl.full((), 0, dtype=tl.int32)) + # Bitmap: q is 1 bit + q_bit = tl.where(q > 0, 1, 0).to(tl.uint64) + acc_data |= (q_bit << acc_bits) + acc_bits += 1 + + # Flush Check + if acc_bits >= 32: + val_u32 = acc_data.to(tl.uint32) + tl.store(out_ptr_base + 0, (val_u32 & 0xFF).to(tl.uint8)) + tl.store(out_ptr_base + 1, ((val_u32 >> 8) & 0xFF).to(tl.uint8)) + tl.store(out_ptr_base + 2, ((val_u32 >> 16) & 0xFF).to(tl.uint8)) + tl.store(out_ptr_base + 3, ((val_u32 >> 24) & 0xFF).to(tl.uint8)) + + out_ptr_base += 4 + acc_data >>= 32 + acc_bits -= 32 + + # Append Remainder + acc_data |= (r << acc_bits) + acc_bits += k_rice_i32 + + # Flush Check + if acc_bits >= 32: + val_u32 = acc_data.to(tl.uint32) + tl.store(out_ptr_base + 0, (val_u32 & 0xFF).to(tl.uint8)) + tl.store(out_ptr_base + 1, ((val_u32 >> 8) & 0xFF).to(tl.uint8)) + tl.store(out_ptr_base + 2, ((val_u32 >> 16) & 0xFF).to(tl.uint8)) + tl.store(out_ptr_base + 3, ((val_u32 >> 24) & 0xFF).to(tl.uint8)) + + out_ptr_base += 4 + acc_data >>= 32 + acc_bits -= 32 i += 1 - # header_bytes = 15 + 2 * num_B (only meaningful if err == 0) - if err == 0: - header_bytes_i32 = 15 + (num_B_val * 2) - - # ---- store outputs ---- - tl.store(C_out_ptr, C_val) - tl.store(K_out_ptr, K_val) - tl.store(R_out_ptr, R_val) - tl.store(num_B_out_ptr, num_B_val) - tl.store(header_bytes_out_ptr, header_bytes_i32) - tl.store(error_flag_ptr, err) + # ------------------------------------------------------------------ + # FINAL FLUSH + # ------------------------------------------------------------------ + # We might have 1..31 bits left. Write byte-by-byte. + while acc_bits > 0: + tl.store(out_ptr_base, (acc_data & 0xFF).to(tl.uint8)) + out_ptr_base += 1 + acc_data >>= 8 + acc_bits -= 8 @triton.jit def parse_row_table_kernel( - u8_payload_ptr, # (total_bytes,) uint8 - row_payload_bytes_ptr, # (num_rows,) int32 - best_B_idx_ptr, # (num_rows,) int32 - use_bitmap_ptr, # (num_rows,) int32 - row_table_start: tl.int32, - num_rows: tl.int32, - ROW_HEADER_BITS: tl.constexpr, + u8_payload_ptr, + row_payload_bytes_ptr, + best_B_idx_ptr, + use_bitmap_ptr, + row_table_start: tl.int32, + num_rows: tl.int32, + ROW_HEADER_BITS: tl.constexpr, ): - """ - Parse the row table: - - For each row r: - offset = row_table_start + r * 3 - length_bytes[r] = uint16 LE at offset - header_byte = uint8 at offset + 2 - header_bits = header_byte & ((1 << ROW_HEADER_BITS) - 1) - use_bitmap[r] = header_bits & 1 - best_B_idx[r] = header_bits >> 1 - """ pid = tl.program_id(0) if pid >= num_rows: return entry_offset = row_table_start + pid * 3 - # length_bytes: uint16 LE b0 = tl.load(u8_payload_ptr + entry_offset).to(tl.int32) b1 = tl.load(u8_payload_ptr + entry_offset + 1).to(tl.int32) length_i32 = b0 | (b1 << 8) tl.store(row_payload_bytes_ptr + pid, length_i32) - # header byte header_byte = tl.load(u8_payload_ptr + entry_offset + 2).to(tl.int32) header_mask = (tl.full((), 1, dtype=tl.int32) << ROW_HEADER_BITS) - 1 header_i32 = header_byte & header_mask @@ -501,189 +389,172 @@ def parse_row_table_kernel( tl.store(best_B_idx_ptr + pid, best_B_idx_i32) +@triton.jit +def count_ones_in_word(word_u64): + """ + Counts trailing ones in a 64-bit word (register level). + Used for fast unary decoding without global memory access. + """ + cnt = tl.full((), 0, dtype=tl.int32) + ONE_U64 = tl.full((), 1, dtype=tl.uint64) + + check = word_u64 + cond = ((check & ONE_U64) == ONE_U64) & (cnt < 64) + while cond: + cnt += 1 + check >>= 1 + # Update condition for next iteration + cond = ((check & ONE_U64) == ONE_U64) & (cnt < 64) + return cnt + @triton.jit def decode_rows_kernel( - u8_payload_ptr, # (total_bytes,) uint8 - out_vals_ptr, # (num_rows * K,) int32 - row_bit_offsets_ptr, # (num_rows,) int32 (bit offset of first encoded bit) - row_payload_bytes_ptr, # (num_rows,) int32 - best_B_idx_ptr, # (num_rows,) int32 - use_bitmap_ptr, # (num_rows,) int32 - k_rice_choices_ptr, # (num_B,) int32 - num_rows: tl.int32, - K: tl.int32, + u8_payload_ptr, + out_vals_ptr, + row_bit_offsets_ptr, # (rows,) + row_payload_bytes_ptr, # (rows,) + best_B_idx_ptr, # (rows,) + use_bitmap_ptr, # (rows,) + k_rice_choices_ptr, # (num_B,) + num_rows: tl.int32, + K: tl.int32, ): """ - Fully GPU decode of Rice/bitmap rows. - - For each row r: - - Bit range: - start_bit = row_bit_offsets[r] - end_bit = start_bit + row_payload_bytes[r] * 8 - - First value: full Rice (unary + remainder) - - Tail: Rice or bitmap+remainder depending on use_bitmap[r]. + Decodes rows using unaligned-safe 64-bit reads (via byte loads). """ row_idx = tl.program_id(0) if row_idx >= num_rows: return - # Per-row metadata - row_start_bit_i32 = tl.load(row_bit_offsets_ptr + row_idx).to(tl.int32) - payload_bytes_i32 = tl.load(row_payload_bytes_ptr + row_idx).to(tl.int32) - best_B_idx_i32 = tl.load(best_B_idx_ptr + row_idx).to(tl.int32) - use_bitmap_i32 = (tl.load(use_bitmap_ptr + row_idx) & 1).to(tl.int32) - - # k_rice and M for this row - k_rice_i32 = tl.load(k_rice_choices_ptr + best_B_idx_i32).to(tl.int32) - M_i32 = (tl.full((), 1, dtype=tl.int32) << k_rice_i32) + # Row params + start_bit = tl.load(row_bit_offsets_ptr + row_idx).to(tl.int32) + payload_bytes = tl.load(row_payload_bytes_ptr + row_idx).to(tl.int32) + b_idx = tl.load(best_B_idx_ptr + row_idx).to(tl.int32) + use_bitmap = (tl.load(use_bitmap_ptr + row_idx) & 1).to(tl.int32) - # Bit range of this row - row_end_bit_i32 = row_start_bit_i32 + payload_bytes_i32 * 8 + k_rice = tl.load(k_rice_choices_ptr + b_idx).to(tl.int32) + M = (tl.full((), 1, dtype=tl.int32) << k_rice) + current_bit = start_bit base_out = row_idx * K - ONE_I32 = tl.full((), 1, dtype=tl.int32) - bit_off_i32 = row_start_bit_i32 - - # ---- first value: ALWAYS full Rice ---- - if K > 0: - q0_i32, bit_off_i32, hit_end0_i32 = read_unary_bounded_triton( - u8_payload_ptr, - bit_off_i32, - row_end_bit_i32, - ) - r0_u32, bit_off_i32 = read_nbits_fast( - u8_payload_ptr, - bit_off_i32, - k_rice_i32, - row_end_bit_i32, # limit - ) - r0_i32 = r0_u32.to(tl.int32) - v0_i32 = q0_i32 * M_i32 + r0_i32 - tl.store(out_vals_ptr + base_out, v0_i32) - - # ---- tail values ---- - i = tl.full((), 1, dtype=tl.int32) + i = 0 while i < K: - if use_bitmap_i32 != 0: - # Bitmap mode: q is 1 bit in {0,1} - q_bit_u32, bit_off_i32 = read_nbits_fast( - u8_payload_ptr, - bit_off_i32, - ONE_I32, - row_end_bit_i32, - ) - q_i32 = q_bit_u32.to(tl.int32) - - r_u32, bit_off_i32 = read_nbits_fast( - u8_payload_ptr, - bit_off_i32, - k_rice_i32, - row_end_bit_i32, - ) - r_i32 = r_u32.to(tl.int32) + # ------------------------------------------------ + # BUFFERED LOAD (Unaligned Safe) + # ------------------------------------------------ + byte_idx = current_bit // 8 + bit_in_byte = current_bit % 8 + + # Manually load 8 bytes to form uint64. + # This prevents misaligned access on all GPUs. + b0 = tl.load(u8_payload_ptr + byte_idx + 0).to(tl.uint64) + b1 = tl.load(u8_payload_ptr + byte_idx + 1).to(tl.uint64) + b2 = tl.load(u8_payload_ptr + byte_idx + 2).to(tl.uint64) + b3 = tl.load(u8_payload_ptr + byte_idx + 3).to(tl.uint64) + b4 = tl.load(u8_payload_ptr + byte_idx + 4).to(tl.uint64) + b5 = tl.load(u8_payload_ptr + byte_idx + 5).to(tl.uint64) + b6 = tl.load(u8_payload_ptr + byte_idx + 6).to(tl.uint64) + b7 = tl.load(u8_payload_ptr + byte_idx + 7).to(tl.uint64) + + word_u64 = b0 | (b1 << 8) | (b2 << 16) | (b3 << 24) | \ + (b4 << 32) | (b5 << 40) | (b6 << 48) | (b7 << 56) + + # Shift out consumed bits + stream = word_u64 >> bit_in_byte + + # ------------------------------------------------ + # DECODE LOGIC + # ------------------------------------------------ + q = 0 + r = 0 + + is_rice = (i == 0) | (use_bitmap == 0) + + bits_consumed = 0 + + if is_rice: + # Decode Unary q + q = count_ones_in_word(stream) + bits_consumed += (q + 1) + stream >>= (q + 1) else: - # Full Rice mode - q_i32, bit_off_i32, hit_end_i32 = read_unary_bounded_triton( - u8_payload_ptr, - bit_off_i32, - row_end_bit_i32, - ) - r_u32, bit_off_i32 = read_nbits_fast( - u8_payload_ptr, - bit_off_i32, - k_rice_i32, - row_end_bit_i32, - ) - r_i32 = r_u32.to(tl.int32) - - v_i32 = q_i32 * M_i32 + r_i32 - tl.store(out_vals_ptr + base_out + i, v_i32) + # Bitmap: q is single bit + q = (stream & 1).to(tl.int32) + bits_consumed += 1 + stream >>= 1 + + # Decode Remainder r + mask = (tl.full((), 1, dtype=tl.uint64) << k_rice) - 1 + r = (stream & mask).to(tl.int32) + bits_consumed += k_rice + + # ------------------------------------------------ + # STORE + # ------------------------------------------------ + val = q * M + r + tl.store(out_vals_ptr + base_out + i, val) + + current_bit += bits_consumed i += 1 def decode_batch_rows( - payload: BytesLike, - max_num_B: int = 16, + payload: BytesLike, + max_num_B: int = 16, ) -> tuple[torch.Tensor, int, int]: - if not torch.cuda.is_available(): - raise RuntimeError("decode_batch_rows_gpu requires CUDA") + raise RuntimeError("CUDA required") - # --- Move payload to CUDA (if needed) --- + # Move to GPU/Tensor if isinstance(payload, torch.Tensor): - assert payload.dtype == torch.uint8 payload_gpu = payload if payload.is_cuda else payload.cuda() elif isinstance(payload, np.ndarray): - assert payload.dtype == np.uint8 payload_gpu = torch.from_numpy(payload).to("cuda", dtype=torch.uint8) - elif isinstance(payload, (bytes, bytearray)): + else: arr = np.frombuffer(bytes(payload), dtype=np.uint8) payload_gpu = torch.from_numpy(arr).to("cuda", dtype=torch.uint8) - else: - raise TypeError("Unsupported payload type") payload_gpu = payload_gpu.contiguous() dev = payload_gpu.device total_bytes = int(payload_gpu.numel()) - if total_bytes == 0: - empty = torch.empty((0, 0), dtype=torch.int64, device=dev) - return empty, 0, 0 - - # --- 1) Parse global header on GPU --- - C_out = torch.empty(1, dtype=torch.int32, device=dev) - K_out = torch.empty(1, dtype=torch.int32, device=dev) - R_out = torch.empty(1, dtype=torch.int32, device=dev) - num_B_out = torch.empty(1, dtype=torch.int32, device=dev) - B_choices_out = torch.empty(max_num_B, dtype=torch.int32, device=dev) - header_bytes_out = torch.empty(1, dtype=torch.int32, device=dev) - err_flag = torch.zeros(1, dtype=torch.int32, device=dev) - - parse_header_kernel[(1,)]( - payload_gpu, - C_out, - K_out, - R_out, - num_B_out, - B_choices_out, - header_bytes_out, - err_flag, - total_bytes, - MAX_B_CHOICES=max_num_B, - ) - - torch.cuda.synchronize() - err = int(err_flag.cpu().item()) - if err != 0: - raise ValueError(f"parse_header_kernel failed with error code {err}") - - C = int(C_out.cpu().item()) - K = int(K_out.cpu().item()) - num_rows = int(R_out.cpu().item()) - num_B = int(num_B_out.cpu().item()) - header_bytes = int(header_bytes_out.cpu().item()) - B_choices_list = [int(x) for x in B_choices_out[:num_B].cpu().tolist()] - # --- 2) Build k_rice choices on CPU -> move to GPU --- + if total_bytes == 0: + return torch.empty((0, 0), dtype=torch.int64, device=dev), 0, 0 + + # --- 1) Parse Global Header (CPU) --- + header_size_min = 15 + header_cpu = payload_gpu[:64].cpu().numpy().tobytes() + + try: + # Fixed format string to match 15 bytes + magic, C, K, num_rows, num_B = struct.unpack("<4sIHIB", header_cpu[:15]) + except struct.error: + raise ValueError("Payload too short for header") + + if magic != b"CGRP": + raise ValueError("Invalid magic bytes") + + offset = 15 + B_choices = [] + for _ in range(num_B): + b_val = struct.unpack(" total_bytes: - raise ValueError("Truncated payload: row table exceeds payload length") + # --- 3) Parse Row Table (GPU) --- + row_table_bytes = num_rows * 3 row_payload_bytes = torch.empty(num_rows, dtype=torch.int32, device=dev) best_B_idx = torch.empty(num_rows, dtype=torch.int32, device=dev) @@ -699,30 +570,18 @@ def decode_batch_rows( ROW_HEADER_BITS=ROW_HEADER_BITS, ) - # --- 4) Compute per-row bit offsets into payload region --- - payload_region_start_byte = header_bytes + row_table_bytes - if payload_region_start_byte > total_bytes: - raise ValueError("Truncated payload: missing payload region") + # --- 4) Offsets --- + payload_region_start = header_bytes + row_table_bytes - # byte offsets within the payload region row_payload_bytes_64 = row_payload_bytes.to(torch.int64) - row_byte_offsets = torch.cumsum(row_payload_bytes_64, dim=0) - row_payload_bytes_64 - # Sanity check: last row must end within the buffer - last_end = int( - payload_region_start_byte - + row_byte_offsets[-1].item() - + row_payload_bytes_64[-1].item() - ) - if last_end > total_bytes: - raise ValueError("Truncated payload: row payload bytes exceed buffer length") + row_byte_offsets = torch.cumsum(row_payload_bytes_64, dim=0) - row_payload_bytes_64 - row_bit_offsets = ( - payload_region_start_byte + row_byte_offsets - ).to(torch.int32) * 8 + row_bit_offsets = (payload_region_start + row_byte_offsets).to(torch.int32) * 8 - # --- 5) Decode rows in parallel on GPU --- + # --- 5) Decode (GPU Optimized) --- out_vals = torch.empty((num_rows, K), dtype=torch.int32, device=dev) + decode_rows_kernel[(num_rows,)]( payload_gpu, out_vals, @@ -736,5 +595,4 @@ def decode_batch_rows( ) out_vals = torch.cumsum(out_vals, dim=1) - return out_vals.to(torch.int64), C, num_rows - + return out_vals.to(torch.int64), C, num_rows \ No newline at end of file From 616e21031209e3482595d09da5f8f3d2c0455553 Mon Sep 17 00:00:00 2001 From: Kasper Date: Wed, 19 Nov 2025 13:04:40 +0400 Subject: [PATCH 17/33] add offload time tracking --- src/tplr/neurons.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/tplr/neurons.py b/src/tplr/neurons.py index 410a18268..fddfee509 100644 --- a/src/tplr/neurons.py +++ b/src/tplr/neurons.py @@ -113,6 +113,7 @@ def barrier(group=None): compression_time = 0 copy_time = 0 encode_time = 0 + offload_time = 0 for _, (n, p) in enumerate(model_iterator, 1): owned = n in miner.owned_params p_is_dt = is_dtensor(p) @@ -211,8 +212,7 @@ def barrier(group=None): p.grad = None copy_time += tplr.T() - copy_start - tplr.logger.info(f"times: {encode_time}, {compression_time}, {copy_time}") - + offload_start = tplr.T() # Batch offload all error feedback tensors to CPU with pinned memory for name in miner.error_feedback: if ( @@ -224,6 +224,8 @@ def barrier(group=None): miner.error_feedback[name], non_blocking=True ) miner.error_feedback[name] = miner.error_feedback_cpu_buffers[name] + offload_time += tplr.T() - offload_start + tplr.logger.info(f"times: {encode_time}, {compression_time}, {copy_time}, {offload_time}") # Single synchronization at the end for all async operations if torch.cuda.is_available(): From d2a05b96730b27b4d30ccf34d07d3d564c59d332 Mon Sep 17 00:00:00 2001 From: Kasper Date: Wed, 19 Nov 2025 15:22:06 +0400 Subject: [PATCH 18/33] stablish --- src/tplr/compression/hybrid.py | 361 +++++++++++++++++++++------------ src/tplr/neurons.py | 8 +- 2 files changed, 237 insertions(+), 132 deletions(-) diff --git a/src/tplr/compression/hybrid.py b/src/tplr/compression/hybrid.py index 273de408c..4757089b9 100644 --- a/src/tplr/compression/hybrid.py +++ b/src/tplr/compression/hybrid.py @@ -120,7 +120,7 @@ def encode_batch_rows( ) total_payload_bytes = int(row_payload_bytes.sum().item()) - # Global header Construction (CPU is faster for this small structured data) + # Global header Construction header_list = [] header_list.append(b"CGRP") # 4B magic header_list.append(struct.pack("> 8) & 0xFF).to(torch.uint8) @@ -162,7 +163,7 @@ def encode_batch_rows( # Calculate absolute byte offsets for pack kernel row_abs_byte_offsets = (payload_region_start + row_byte_offsets).to(torch.int32) - # Pack payloads (Optimized Kernel) + # Pack payloads pack_kernel[(num_rows,)]( vals, payload_buf, @@ -249,7 +250,7 @@ def pack_kernel( ): """ Writes payload bits using a 64-bit register accumulator. - Modified to use unaligned byte stores to prevent cudaErrorMisalignedAddress. + Uses unaligned byte stores (split into 4 bytes) to prevent cudaErrorMisalignedAddress. """ row_idx = tl.program_id(0) if row_idx >= num_rows: @@ -271,9 +272,6 @@ def pack_kernel( base_idx = row_idx * k_dim - # ------------------------------------------------------------------ - # PROCESS LOOP - # ------------------------------------------------------------------ i = 0 while i < k_dim: val = tl.load(delta_ptr + base_idx + i).to(tl.int32) @@ -286,8 +284,6 @@ def pack_kernel( if is_rice: # Rice: q '1's, then '0', then k_rice bits of r - - # Append Unary (q ones) q_count = q.to(tl.int32) while q_count > 0: acc_data |= (tl.full((), 1, dtype=tl.uint64) << acc_bits) @@ -316,7 +312,7 @@ def pack_kernel( acc_data |= (q_bit << acc_bits) acc_bits += 1 - # Flush Check + # Flush Check (after separator/bitmap bit) if acc_bits >= 32: val_u32 = acc_data.to(tl.uint32) tl.store(out_ptr_base + 0, (val_u32 & 0xFF).to(tl.uint8)) @@ -346,10 +342,7 @@ def pack_kernel( i += 1 - # ------------------------------------------------------------------ - # FINAL FLUSH - # ------------------------------------------------------------------ - # We might have 1..31 bits left. Write byte-by-byte. + # Final Flush while acc_bits > 0: tl.store(out_ptr_base, (acc_data & 0xFF).to(tl.uint8)) out_ptr_base += 1 @@ -357,20 +350,84 @@ def pack_kernel( acc_bits -= 8 +@triton.jit +def parse_header_kernel( + u8_payload_ptr, # (total_bytes,) uint8 + C_out_ptr, # (1,) int32 + K_out_ptr, # (1,) int32 + R_out_ptr, # (1,) int32 + num_B_out_ptr, # (1,) int32 + B_choices_out_ptr, # (MAX_B_CHOICES,) int32 + header_bytes_out_ptr, # (1,) int32 + max_num_B: tl.constexpr, +): + """ + Simple GPU kernel to parse the global header. + Replaces CPU struct.unpack to avoid CPU<->GPU synchronization. + """ + pid = tl.program_id(0) + if pid != 0: + return + + # Magic "CGRP" checked implicitly or skipped for speed + + # C (uint32 LE at bytes 4..7) + b4 = tl.load(u8_payload_ptr + 4).to(tl.int32) + b5 = tl.load(u8_payload_ptr + 5).to(tl.int32) + b6 = tl.load(u8_payload_ptr + 6).to(tl.int32) + b7 = tl.load(u8_payload_ptr + 7).to(tl.int32) + C_val = b4 | (b5 << 8) | (b6 << 16) | (b7 << 24) + tl.store(C_out_ptr, C_val) + + # K (uint16 LE at bytes 8..9) + b8 = tl.load(u8_payload_ptr + 8).to(tl.int32) + b9 = tl.load(u8_payload_ptr + 9).to(tl.int32) + K_val = b8 | (b9 << 8) + tl.store(K_out_ptr, K_val) + + # R (uint32 LE at bytes 10..13) + b10 = tl.load(u8_payload_ptr + 10).to(tl.int32) + b11 = tl.load(u8_payload_ptr + 11).to(tl.int32) + b12 = tl.load(u8_payload_ptr + 12).to(tl.int32) + b13 = tl.load(u8_payload_ptr + 13).to(tl.int32) + R_val = b10 | (b11 << 8) | (b12 << 16) | (b13 << 24) + tl.store(R_out_ptr, R_val) + + # num_B at byte 14 + num_B_val = tl.load(u8_payload_ptr + 14).to(tl.int32) + tl.store(num_B_out_ptr, num_B_val) + + # Read B_choices (start at 15) + off = 15 + i = 0 + while i < max_num_B: + if i < num_B_val: + lo = tl.load(u8_payload_ptr + off).to(tl.int32) + hi = tl.load(u8_payload_ptr + off + 1).to(tl.int32) + B_val = lo | (hi << 8) + tl.store(B_choices_out_ptr + i, B_val) + off += 2 + i += 1 + + tl.store(header_bytes_out_ptr, off) + + @triton.jit def parse_row_table_kernel( u8_payload_ptr, row_payload_bytes_ptr, best_B_idx_ptr, use_bitmap_ptr, - row_table_start: tl.int32, - num_rows: tl.int32, + row_table_start_ptr, # int32* from header kernel + num_rows_ptr, # int32* from header kernel ROW_HEADER_BITS: tl.constexpr, ): pid = tl.program_id(0) + num_rows = tl.load(num_rows_ptr) if pid >= num_rows: return + row_table_start = tl.load(row_table_start_ptr) entry_offset = row_table_start + pid * 3 b0 = tl.load(u8_payload_ptr + entry_offset).to(tl.int32) @@ -390,46 +447,45 @@ def parse_row_table_kernel( @triton.jit -def count_ones_in_word(word_u64): +def popc64(x): """ - Counts trailing ones in a 64-bit word (register level). - Used for fast unary decoding without global memory access. + Population count (Hamming weight) for uint64 using SWAR. + Safe for Triton (no loops). """ - cnt = tl.full((), 0, dtype=tl.int32) - ONE_U64 = tl.full((), 1, dtype=tl.uint64) - - check = word_u64 - cond = ((check & ONE_U64) == ONE_U64) & (cnt < 64) - while cond: - cnt += 1 - check >>= 1 - # Update condition for next iteration - cond = ((check & ONE_U64) == ONE_U64) & (cnt < 64) - return cnt + x = x - ((x >> 1) & 0x5555555555555555) + x = (x & 0x3333333333333333) + ((x >> 2) & 0x3333333333333333) + x = (x + (x >> 4)) & 0x0F0F0F0F0F0F0F0F + x = x + (x >> 8) + x = x + (x >> 16) + x = x + (x >> 32) + return (x & 0x7F).to(tl.int32) @triton.jit def decode_rows_kernel( u8_payload_ptr, out_vals_ptr, - row_bit_offsets_ptr, # (rows,) - row_payload_bytes_ptr, # (rows,) - best_B_idx_ptr, # (rows,) - use_bitmap_ptr, # (rows,) - k_rice_choices_ptr, # (num_B,) - num_rows: tl.int32, - K: tl.int32, + row_bit_offsets_ptr, + best_B_idx_ptr, + use_bitmap_ptr, + k_rice_choices_ptr, + num_rows_ptr, + K_ptr, + total_payload_bytes: tl.int32, ): - """ - Decodes rows using unaligned-safe 64-bit reads (via byte loads). - """ row_idx = tl.program_id(0) + num_rows = tl.load(num_rows_ptr) if row_idx >= num_rows: return + K = tl.load(K_ptr) + + # Safety limits + limit_ptr = u8_payload_ptr + total_payload_bytes + total_bits = total_payload_bytes * 8 + # Row params start_bit = tl.load(row_bit_offsets_ptr + row_idx).to(tl.int32) - payload_bytes = tl.load(row_payload_bytes_ptr + row_idx).to(tl.int32) b_idx = tl.load(best_B_idx_ptr + row_idx).to(tl.int32) use_bitmap = (tl.load(use_bitmap_ptr + row_idx) & 1).to(tl.int32) @@ -441,65 +497,110 @@ def decode_rows_kernel( i = 0 while i < K: - # ------------------------------------------------ - # BUFFERED LOAD (Unaligned Safe) - # ------------------------------------------------ - byte_idx = current_bit // 8 - bit_in_byte = current_bit % 8 - - # Manually load 8 bytes to form uint64. - # This prevents misaligned access on all GPUs. - b0 = tl.load(u8_payload_ptr + byte_idx + 0).to(tl.uint64) - b1 = tl.load(u8_payload_ptr + byte_idx + 1).to(tl.uint64) - b2 = tl.load(u8_payload_ptr + byte_idx + 2).to(tl.uint64) - b3 = tl.load(u8_payload_ptr + byte_idx + 3).to(tl.uint64) - b4 = tl.load(u8_payload_ptr + byte_idx + 4).to(tl.uint64) - b5 = tl.load(u8_payload_ptr + byte_idx + 5).to(tl.uint64) - b6 = tl.load(u8_payload_ptr + byte_idx + 6).to(tl.uint64) - b7 = tl.load(u8_payload_ptr + byte_idx + 7).to(tl.uint64) - - word_u64 = b0 | (b1 << 8) | (b2 << 16) | (b3 << 24) | \ - (b4 << 32) | (b5 << 40) | (b6 << 48) | (b7 << 56) - - # Shift out consumed bits - stream = word_u64 >> bit_in_byte - - # ------------------------------------------------ - # DECODE LOGIC - # ------------------------------------------------ - q = 0 - r = 0 - - is_rice = (i == 0) | (use_bitmap == 0) - - bits_consumed = 0 - - if is_rice: - # Decode Unary q - q = count_ones_in_word(stream) - bits_consumed += (q + 1) - stream >>= (q + 1) + # FIX: Use if/else instead of 'continue' + if current_bit >= total_bits: + # OOB: Fill with 0 + tl.store(out_vals_ptr + base_out + i, 0) else: - # Bitmap: q is single bit - q = (stream & 1).to(tl.int32) - bits_consumed += 1 - stream >>= 1 - - # Decode Remainder r - mask = (tl.full((), 1, dtype=tl.uint64) << k_rice) - 1 - r = (stream & mask).to(tl.int32) - bits_consumed += k_rice - - # ------------------------------------------------ - # STORE - # ------------------------------------------------ - val = q * M + r - tl.store(out_vals_ptr + base_out + i, val) - - current_bit += bits_consumed + # NORMAL PATH + is_rice = (i == 0) | (use_bitmap == 0) + + # Accumulate q here + q = 0 + + if is_rice: + # --- RICE DECODING (Unary Part) --- + reading_unary = 1 + + while reading_unary > 0: + # 1. Load Window + byte_idx = current_bit // 8 + bit_in_byte = current_bit % 8 + + base_ptr = u8_payload_ptr + byte_idx + ptr_int = base_ptr.to(tl.uint64) + + # Align + mask_aligned = tl.full((), 0xFFFFFFFFFFFFFFF8, dtype=tl.uint64) + aligned_ptr_int = ptr_int & mask_aligned + offset = (ptr_int & 7).to(tl.int32) + aligned_ptr = aligned_ptr_int.to(tl.pointer_type(tl.uint64)) + + w0 = tl.load(aligned_ptr) + w1_addr_int = aligned_ptr_int + 8 + limit_addr_int = limit_ptr.to(tl.uint64) + w1 = tl.load(aligned_ptr + 1, mask=w1_addr_int < limit_addr_int, other=0) + + # Construct 64-bit word + shift = offset * 8 + lo = w0 >> shift + shift_up = (64 - shift) & 63 + hi = w1 << shift_up + word_u64 = tl.where(offset == 0, w0, lo | hi) + + stream = word_u64 >> bit_in_byte + + # 2. Find First Zero + inv_stream = ~stream + lsb = inv_stream & (0 - inv_stream) + dist = popc64(lsb - 1) + + valid_bits_in_window = 64 - bit_in_byte + + if dist < valid_bits_in_window: + q += dist + current_bit += (dist + 1) + reading_unary = 0 + else: + q += valid_bits_in_window + current_bit += valid_bits_in_window + + if current_bit >= total_bits: + reading_unary = 0 + else: + # --- BITMAP DECODING --- + byte_idx = current_bit // 8 + bit_in_byte = current_bit % 8 + b0 = tl.load(u8_payload_ptr + byte_idx).to(tl.uint64) + q = ((b0 >> bit_in_byte) & 1).to(tl.int32) + current_bit += 1 + + # --- REMAINDER DECODING --- + # Safety check for remainder bits + if current_bit + k_rice > total_bits + 8: + val = q * M + else: + byte_idx = current_bit // 8 + bit_in_byte = current_bit % 8 + + base_ptr = u8_payload_ptr + byte_idx + ptr_int = base_ptr.to(tl.uint64) + mask_aligned = tl.full((), 0xFFFFFFFFFFFFFFF8, dtype=tl.uint64) + aligned_ptr_int = ptr_int & mask_aligned + offset = (ptr_int & 7).to(tl.int32) + aligned_ptr = aligned_ptr_int.to(tl.pointer_type(tl.uint64)) + + w0 = tl.load(aligned_ptr) + w1_addr_int = aligned_ptr_int + 8 + limit_addr_int = limit_ptr.to(tl.uint64) + w1 = tl.load(aligned_ptr + 1, mask=w1_addr_int < limit_addr_int, other=0) + + shift = offset * 8 + lo = w0 >> shift + shift_up = (64 - shift) & 63 + hi = w1 << shift_up + word_u64 = tl.where(offset == 0, w0, lo | hi) + + stream = word_u64 >> bit_in_byte + mask = (tl.full((), 1, dtype=tl.uint64) << k_rice) - 1 + r = (stream & mask).to(tl.int32) + + current_bit += k_rice + val = q * M + r + + tl.store(out_vals_ptr + base_out + i, val) i += 1 - def decode_batch_rows( payload: BytesLike, max_num_B: int = 16, @@ -523,30 +624,35 @@ def decode_batch_rows( if total_bytes == 0: return torch.empty((0, 0), dtype=torch.int64, device=dev), 0, 0 - # --- 1) Parse Global Header (CPU) --- - header_size_min = 15 - header_cpu = payload_gpu[:64].cpu().numpy().tobytes() - - try: - # Fixed format string to match 15 bytes - magic, C, K, num_rows, num_B = struct.unpack("<4sIHIB", header_cpu[:15]) - except struct.error: - raise ValueError("Payload too short for header") + # --- 1) Parse Header (GPU Kernel) --- + # Avoid CPU synchronization by parsing on GPU + C_out = torch.empty(1, dtype=torch.int32, device=dev) + K_out = torch.empty(1, dtype=torch.int32, device=dev) + R_out = torch.empty(1, dtype=torch.int32, device=dev) + num_B_out = torch.empty(1, dtype=torch.int32, device=dev) + B_choices_out = torch.empty(max_num_B, dtype=torch.int32, device=dev) + header_bytes_out = torch.empty(1, dtype=torch.int32, device=dev) - if magic != b"CGRP": - raise ValueError("Invalid magic bytes") + parse_header_kernel[(1,)]( + payload_gpu, + C_out, + K_out, + R_out, + num_B_out, + B_choices_out, + header_bytes_out, + max_num_B=max_num_B + ) - offset = 15 - B_choices = [] - for _ in range(num_B): - b_val = struct.unpack(" Date: Wed, 19 Nov 2025 15:46:10 +0400 Subject: [PATCH 19/33] Revert "stablish" This reverts commit d2a05b96730b27b4d30ccf34d07d3d564c59d332. --- src/tplr/compression/hybrid.py | 361 ++++++++++++--------------------- src/tplr/neurons.py | 8 +- 2 files changed, 132 insertions(+), 237 deletions(-) diff --git a/src/tplr/compression/hybrid.py b/src/tplr/compression/hybrid.py index 4757089b9..273de408c 100644 --- a/src/tplr/compression/hybrid.py +++ b/src/tplr/compression/hybrid.py @@ -120,7 +120,7 @@ def encode_batch_rows( ) total_payload_bytes = int(row_payload_bytes.sum().item()) - # Global header Construction + # Global header Construction (CPU is faster for this small structured data) header_list = [] header_list.append(b"CGRP") # 4B magic header_list.append(struct.pack("> 8) & 0xFF).to(torch.uint8) @@ -163,7 +162,7 @@ def encode_batch_rows( # Calculate absolute byte offsets for pack kernel row_abs_byte_offsets = (payload_region_start + row_byte_offsets).to(torch.int32) - # Pack payloads + # Pack payloads (Optimized Kernel) pack_kernel[(num_rows,)]( vals, payload_buf, @@ -250,7 +249,7 @@ def pack_kernel( ): """ Writes payload bits using a 64-bit register accumulator. - Uses unaligned byte stores (split into 4 bytes) to prevent cudaErrorMisalignedAddress. + Modified to use unaligned byte stores to prevent cudaErrorMisalignedAddress. """ row_idx = tl.program_id(0) if row_idx >= num_rows: @@ -272,6 +271,9 @@ def pack_kernel( base_idx = row_idx * k_dim + # ------------------------------------------------------------------ + # PROCESS LOOP + # ------------------------------------------------------------------ i = 0 while i < k_dim: val = tl.load(delta_ptr + base_idx + i).to(tl.int32) @@ -284,6 +286,8 @@ def pack_kernel( if is_rice: # Rice: q '1's, then '0', then k_rice bits of r + + # Append Unary (q ones) q_count = q.to(tl.int32) while q_count > 0: acc_data |= (tl.full((), 1, dtype=tl.uint64) << acc_bits) @@ -312,7 +316,7 @@ def pack_kernel( acc_data |= (q_bit << acc_bits) acc_bits += 1 - # Flush Check (after separator/bitmap bit) + # Flush Check if acc_bits >= 32: val_u32 = acc_data.to(tl.uint32) tl.store(out_ptr_base + 0, (val_u32 & 0xFF).to(tl.uint8)) @@ -342,7 +346,10 @@ def pack_kernel( i += 1 - # Final Flush + # ------------------------------------------------------------------ + # FINAL FLUSH + # ------------------------------------------------------------------ + # We might have 1..31 bits left. Write byte-by-byte. while acc_bits > 0: tl.store(out_ptr_base, (acc_data & 0xFF).to(tl.uint8)) out_ptr_base += 1 @@ -350,84 +357,20 @@ def pack_kernel( acc_bits -= 8 -@triton.jit -def parse_header_kernel( - u8_payload_ptr, # (total_bytes,) uint8 - C_out_ptr, # (1,) int32 - K_out_ptr, # (1,) int32 - R_out_ptr, # (1,) int32 - num_B_out_ptr, # (1,) int32 - B_choices_out_ptr, # (MAX_B_CHOICES,) int32 - header_bytes_out_ptr, # (1,) int32 - max_num_B: tl.constexpr, -): - """ - Simple GPU kernel to parse the global header. - Replaces CPU struct.unpack to avoid CPU<->GPU synchronization. - """ - pid = tl.program_id(0) - if pid != 0: - return - - # Magic "CGRP" checked implicitly or skipped for speed - - # C (uint32 LE at bytes 4..7) - b4 = tl.load(u8_payload_ptr + 4).to(tl.int32) - b5 = tl.load(u8_payload_ptr + 5).to(tl.int32) - b6 = tl.load(u8_payload_ptr + 6).to(tl.int32) - b7 = tl.load(u8_payload_ptr + 7).to(tl.int32) - C_val = b4 | (b5 << 8) | (b6 << 16) | (b7 << 24) - tl.store(C_out_ptr, C_val) - - # K (uint16 LE at bytes 8..9) - b8 = tl.load(u8_payload_ptr + 8).to(tl.int32) - b9 = tl.load(u8_payload_ptr + 9).to(tl.int32) - K_val = b8 | (b9 << 8) - tl.store(K_out_ptr, K_val) - - # R (uint32 LE at bytes 10..13) - b10 = tl.load(u8_payload_ptr + 10).to(tl.int32) - b11 = tl.load(u8_payload_ptr + 11).to(tl.int32) - b12 = tl.load(u8_payload_ptr + 12).to(tl.int32) - b13 = tl.load(u8_payload_ptr + 13).to(tl.int32) - R_val = b10 | (b11 << 8) | (b12 << 16) | (b13 << 24) - tl.store(R_out_ptr, R_val) - - # num_B at byte 14 - num_B_val = tl.load(u8_payload_ptr + 14).to(tl.int32) - tl.store(num_B_out_ptr, num_B_val) - - # Read B_choices (start at 15) - off = 15 - i = 0 - while i < max_num_B: - if i < num_B_val: - lo = tl.load(u8_payload_ptr + off).to(tl.int32) - hi = tl.load(u8_payload_ptr + off + 1).to(tl.int32) - B_val = lo | (hi << 8) - tl.store(B_choices_out_ptr + i, B_val) - off += 2 - i += 1 - - tl.store(header_bytes_out_ptr, off) - - @triton.jit def parse_row_table_kernel( u8_payload_ptr, row_payload_bytes_ptr, best_B_idx_ptr, use_bitmap_ptr, - row_table_start_ptr, # int32* from header kernel - num_rows_ptr, # int32* from header kernel + row_table_start: tl.int32, + num_rows: tl.int32, ROW_HEADER_BITS: tl.constexpr, ): pid = tl.program_id(0) - num_rows = tl.load(num_rows_ptr) if pid >= num_rows: return - row_table_start = tl.load(row_table_start_ptr) entry_offset = row_table_start + pid * 3 b0 = tl.load(u8_payload_ptr + entry_offset).to(tl.int32) @@ -447,45 +390,46 @@ def parse_row_table_kernel( @triton.jit -def popc64(x): +def count_ones_in_word(word_u64): """ - Population count (Hamming weight) for uint64 using SWAR. - Safe for Triton (no loops). + Counts trailing ones in a 64-bit word (register level). + Used for fast unary decoding without global memory access. """ - x = x - ((x >> 1) & 0x5555555555555555) - x = (x & 0x3333333333333333) + ((x >> 2) & 0x3333333333333333) - x = (x + (x >> 4)) & 0x0F0F0F0F0F0F0F0F - x = x + (x >> 8) - x = x + (x >> 16) - x = x + (x >> 32) - return (x & 0x7F).to(tl.int32) + cnt = tl.full((), 0, dtype=tl.int32) + ONE_U64 = tl.full((), 1, dtype=tl.uint64) + + check = word_u64 + cond = ((check & ONE_U64) == ONE_U64) & (cnt < 64) + while cond: + cnt += 1 + check >>= 1 + # Update condition for next iteration + cond = ((check & ONE_U64) == ONE_U64) & (cnt < 64) + return cnt @triton.jit def decode_rows_kernel( u8_payload_ptr, out_vals_ptr, - row_bit_offsets_ptr, - best_B_idx_ptr, - use_bitmap_ptr, - k_rice_choices_ptr, - num_rows_ptr, - K_ptr, - total_payload_bytes: tl.int32, + row_bit_offsets_ptr, # (rows,) + row_payload_bytes_ptr, # (rows,) + best_B_idx_ptr, # (rows,) + use_bitmap_ptr, # (rows,) + k_rice_choices_ptr, # (num_B,) + num_rows: tl.int32, + K: tl.int32, ): + """ + Decodes rows using unaligned-safe 64-bit reads (via byte loads). + """ row_idx = tl.program_id(0) - num_rows = tl.load(num_rows_ptr) if row_idx >= num_rows: return - K = tl.load(K_ptr) - - # Safety limits - limit_ptr = u8_payload_ptr + total_payload_bytes - total_bits = total_payload_bytes * 8 - # Row params start_bit = tl.load(row_bit_offsets_ptr + row_idx).to(tl.int32) + payload_bytes = tl.load(row_payload_bytes_ptr + row_idx).to(tl.int32) b_idx = tl.load(best_B_idx_ptr + row_idx).to(tl.int32) use_bitmap = (tl.load(use_bitmap_ptr + row_idx) & 1).to(tl.int32) @@ -497,110 +441,65 @@ def decode_rows_kernel( i = 0 while i < K: - # FIX: Use if/else instead of 'continue' - if current_bit >= total_bits: - # OOB: Fill with 0 - tl.store(out_vals_ptr + base_out + i, 0) + # ------------------------------------------------ + # BUFFERED LOAD (Unaligned Safe) + # ------------------------------------------------ + byte_idx = current_bit // 8 + bit_in_byte = current_bit % 8 + + # Manually load 8 bytes to form uint64. + # This prevents misaligned access on all GPUs. + b0 = tl.load(u8_payload_ptr + byte_idx + 0).to(tl.uint64) + b1 = tl.load(u8_payload_ptr + byte_idx + 1).to(tl.uint64) + b2 = tl.load(u8_payload_ptr + byte_idx + 2).to(tl.uint64) + b3 = tl.load(u8_payload_ptr + byte_idx + 3).to(tl.uint64) + b4 = tl.load(u8_payload_ptr + byte_idx + 4).to(tl.uint64) + b5 = tl.load(u8_payload_ptr + byte_idx + 5).to(tl.uint64) + b6 = tl.load(u8_payload_ptr + byte_idx + 6).to(tl.uint64) + b7 = tl.load(u8_payload_ptr + byte_idx + 7).to(tl.uint64) + + word_u64 = b0 | (b1 << 8) | (b2 << 16) | (b3 << 24) | \ + (b4 << 32) | (b5 << 40) | (b6 << 48) | (b7 << 56) + + # Shift out consumed bits + stream = word_u64 >> bit_in_byte + + # ------------------------------------------------ + # DECODE LOGIC + # ------------------------------------------------ + q = 0 + r = 0 + + is_rice = (i == 0) | (use_bitmap == 0) + + bits_consumed = 0 + + if is_rice: + # Decode Unary q + q = count_ones_in_word(stream) + bits_consumed += (q + 1) + stream >>= (q + 1) else: - # NORMAL PATH - is_rice = (i == 0) | (use_bitmap == 0) - - # Accumulate q here - q = 0 - - if is_rice: - # --- RICE DECODING (Unary Part) --- - reading_unary = 1 - - while reading_unary > 0: - # 1. Load Window - byte_idx = current_bit // 8 - bit_in_byte = current_bit % 8 - - base_ptr = u8_payload_ptr + byte_idx - ptr_int = base_ptr.to(tl.uint64) - - # Align - mask_aligned = tl.full((), 0xFFFFFFFFFFFFFFF8, dtype=tl.uint64) - aligned_ptr_int = ptr_int & mask_aligned - offset = (ptr_int & 7).to(tl.int32) - aligned_ptr = aligned_ptr_int.to(tl.pointer_type(tl.uint64)) - - w0 = tl.load(aligned_ptr) - w1_addr_int = aligned_ptr_int + 8 - limit_addr_int = limit_ptr.to(tl.uint64) - w1 = tl.load(aligned_ptr + 1, mask=w1_addr_int < limit_addr_int, other=0) - - # Construct 64-bit word - shift = offset * 8 - lo = w0 >> shift - shift_up = (64 - shift) & 63 - hi = w1 << shift_up - word_u64 = tl.where(offset == 0, w0, lo | hi) - - stream = word_u64 >> bit_in_byte - - # 2. Find First Zero - inv_stream = ~stream - lsb = inv_stream & (0 - inv_stream) - dist = popc64(lsb - 1) - - valid_bits_in_window = 64 - bit_in_byte - - if dist < valid_bits_in_window: - q += dist - current_bit += (dist + 1) - reading_unary = 0 - else: - q += valid_bits_in_window - current_bit += valid_bits_in_window - - if current_bit >= total_bits: - reading_unary = 0 - else: - # --- BITMAP DECODING --- - byte_idx = current_bit // 8 - bit_in_byte = current_bit % 8 - b0 = tl.load(u8_payload_ptr + byte_idx).to(tl.uint64) - q = ((b0 >> bit_in_byte) & 1).to(tl.int32) - current_bit += 1 - - # --- REMAINDER DECODING --- - # Safety check for remainder bits - if current_bit + k_rice > total_bits + 8: - val = q * M - else: - byte_idx = current_bit // 8 - bit_in_byte = current_bit % 8 - - base_ptr = u8_payload_ptr + byte_idx - ptr_int = base_ptr.to(tl.uint64) - mask_aligned = tl.full((), 0xFFFFFFFFFFFFFFF8, dtype=tl.uint64) - aligned_ptr_int = ptr_int & mask_aligned - offset = (ptr_int & 7).to(tl.int32) - aligned_ptr = aligned_ptr_int.to(tl.pointer_type(tl.uint64)) - - w0 = tl.load(aligned_ptr) - w1_addr_int = aligned_ptr_int + 8 - limit_addr_int = limit_ptr.to(tl.uint64) - w1 = tl.load(aligned_ptr + 1, mask=w1_addr_int < limit_addr_int, other=0) - - shift = offset * 8 - lo = w0 >> shift - shift_up = (64 - shift) & 63 - hi = w1 << shift_up - word_u64 = tl.where(offset == 0, w0, lo | hi) - - stream = word_u64 >> bit_in_byte - mask = (tl.full((), 1, dtype=tl.uint64) << k_rice) - 1 - r = (stream & mask).to(tl.int32) - - current_bit += k_rice - val = q * M + r - - tl.store(out_vals_ptr + base_out + i, val) + # Bitmap: q is single bit + q = (stream & 1).to(tl.int32) + bits_consumed += 1 + stream >>= 1 + + # Decode Remainder r + mask = (tl.full((), 1, dtype=tl.uint64) << k_rice) - 1 + r = (stream & mask).to(tl.int32) + bits_consumed += k_rice + + # ------------------------------------------------ + # STORE + # ------------------------------------------------ + val = q * M + r + tl.store(out_vals_ptr + base_out + i, val) + + current_bit += bits_consumed i += 1 + def decode_batch_rows( payload: BytesLike, max_num_B: int = 16, @@ -624,35 +523,30 @@ def decode_batch_rows( if total_bytes == 0: return torch.empty((0, 0), dtype=torch.int64, device=dev), 0, 0 - # --- 1) Parse Header (GPU Kernel) --- - # Avoid CPU synchronization by parsing on GPU - C_out = torch.empty(1, dtype=torch.int32, device=dev) - K_out = torch.empty(1, dtype=torch.int32, device=dev) - R_out = torch.empty(1, dtype=torch.int32, device=dev) - num_B_out = torch.empty(1, dtype=torch.int32, device=dev) - B_choices_out = torch.empty(max_num_B, dtype=torch.int32, device=dev) - header_bytes_out = torch.empty(1, dtype=torch.int32, device=dev) + # --- 1) Parse Global Header (CPU) --- + header_size_min = 15 + header_cpu = payload_gpu[:64].cpu().numpy().tobytes() - parse_header_kernel[(1,)]( - payload_gpu, - C_out, - K_out, - R_out, - num_B_out, - B_choices_out, - header_bytes_out, - max_num_B=max_num_B - ) + try: + # Fixed format string to match 15 bytes + magic, C, K, num_rows, num_B = struct.unpack("<4sIHIB", header_cpu[:15]) + except struct.error: + raise ValueError("Payload too short for header") + + if magic != b"CGRP": + raise ValueError("Invalid magic bytes") - # Minimal sync to get scalar values needed for kernel setup - C = int(C_out.item()) - num_B = int(num_B_out.item()) - num_rows = int(R_out.item()) - B_choices_list = B_choices_out[:num_B].cpu().tolist() + offset = 15 + B_choices = [] + for _ in range(num_B): + b_val = struct.unpack(" Date: Wed, 19 Nov 2025 15:46:14 +0400 Subject: [PATCH 20/33] Revert "add offload time tracking" This reverts commit 616e21031209e3482595d09da5f8f3d2c0455553. --- src/tplr/neurons.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/tplr/neurons.py b/src/tplr/neurons.py index fddfee509..410a18268 100644 --- a/src/tplr/neurons.py +++ b/src/tplr/neurons.py @@ -113,7 +113,6 @@ def barrier(group=None): compression_time = 0 copy_time = 0 encode_time = 0 - offload_time = 0 for _, (n, p) in enumerate(model_iterator, 1): owned = n in miner.owned_params p_is_dt = is_dtensor(p) @@ -212,7 +211,8 @@ def barrier(group=None): p.grad = None copy_time += tplr.T() - copy_start - offload_start = tplr.T() + tplr.logger.info(f"times: {encode_time}, {compression_time}, {copy_time}") + # Batch offload all error feedback tensors to CPU with pinned memory for name in miner.error_feedback: if ( @@ -224,8 +224,6 @@ def barrier(group=None): miner.error_feedback[name], non_blocking=True ) miner.error_feedback[name] = miner.error_feedback_cpu_buffers[name] - offload_time += tplr.T() - offload_start - tplr.logger.info(f"times: {encode_time}, {compression_time}, {copy_time}, {offload_time}") # Single synchronization at the end for all async operations if torch.cuda.is_available(): From bb68d639ed692dff10890381094cfd7a76158015 Mon Sep 17 00:00:00 2001 From: Kasper Date: Wed, 19 Nov 2025 15:46:16 +0400 Subject: [PATCH 21/33] Revert "added register buffering, fast unary decoding and simple header parsing" This reverts commit 56c8d7a02e442e151e85c1eb0c9ef96a7cb82e72. --- src/tplr/compress.py | 2 +- src/tplr/compression/bitops.py | 249 ++++++++++++ src/tplr/compression/hybrid.py | 700 ++++++++++++++++++++------------- 3 files changed, 671 insertions(+), 280 deletions(-) create mode 100644 src/tplr/compression/bitops.py diff --git a/src/tplr/compress.py b/src/tplr/compress.py index b6cc544be..b52094fc3 100644 --- a/src/tplr/compress.py +++ b/src/tplr/compress.py @@ -391,7 +391,7 @@ def decompress( val = val.to(dtype=x.dtype) # second condition for legacy decompress - if len(xshape) > 2 and len(x) != len(idx_int64): + if len(xshape) > 2 and len(xshape) != len(idx_int64): idx_int64 = rearrange(idx_int64, "(y x) h -> y x h", y=xshape[0]) val = rearrange(val, "(y x) h -> y x h", y=xshape[0]) diff --git a/src/tplr/compression/bitops.py b/src/tplr/compression/bitops.py new file mode 100644 index 000000000..6302e5692 --- /dev/null +++ b/src/tplr/compression/bitops.py @@ -0,0 +1,249 @@ +from typing import Union + +import numpy as np +import torch +import triton +import triton.language as tl + +BytesLike = Union[bytes, bytearray, np.ndarray, torch.Tensor] + +@triton.jit +def write_nbits( + u8_ptr, # uint8* global buffer + bit_off_i32, # scalar tl.int32 bit offset + value_u32, # scalar tl.uint32, up to 32 bits used + nbits_i32, # scalar tl.int32, number of bits to write +): + """ + Writes `nbits_i32` least-significant bits of `value_u32` into `u8_ptr` + starting at bit offset `bit_off_i32` in LSB-first order. + + This is still a bit-at-a-time writer; higher-level kernels have been + adjusted to use int32 + shift/mask ahead of time. + """ + j = tl.full((), 0, dtype=tl.int32) + ONE_U32 = tl.full((), 1, dtype=tl.uint32) + + while j < nbits_i32: + pos = bit_off_i32 + j + byte_idx = (pos >> 3).to(tl.int32) + bit_idx = (pos & 7).to(tl.int32) + + old_u8 = tl.load(u8_ptr + byte_idx) + old_u32 = old_u8.to(tl.uint32) + + vbit = (value_u32 >> j) & ONE_U32 + mask = ONE_U32 << bit_idx + new_u32 = (old_u32 & (~mask)) | (vbit << bit_idx) + tl.store(u8_ptr + byte_idx, new_u32.to(tl.uint8)) + j += 1 + return bit_off_i32 + nbits_i32 + + +@triton.jit +def write_nbits_fast( + u8_ptr, + bit_off_i32, # start bit + value_u32, # LSB-first payload bits + nbits_i32, # 0..32 +): + # If nothing to write + if nbits_i32 <= 0: + return bit_off_i32 + + start_bit = bit_off_i32 + first_byte = (start_bit >> 3).to(tl.int32) + first_bit = (start_bit & 7).to(tl.int32) + + # How many bits fit in the first byte + bits_in_first = tl.minimum( + nbits_i32, + tl.full((), 8, dtype=tl.int32) - first_bit, + ) + + # -------- leading partial byte -------- + if bits_in_first > 0: + old_u8 = tl.load(u8_ptr + first_byte).to(tl.uint32) + + # mask for the bits we overwrite inside that byte + mask_u32 = ((tl.full((), 1, tl.uint32) << bits_in_first) - 1) \ + << first_bit + + # extract those bits from value_u32 + bits_u32 = (value_u32 & ((tl.full((), 1, tl.uint32) << bits_in_first) - 1)) \ + << first_bit + + new_u8 = ((old_u8 & ~mask_u32) | bits_u32).to(tl.uint8) + tl.store(u8_ptr + first_byte, new_u8) + + bit_off_i32 += bits_in_first + value_u32 >>= bits_in_first + nbits_i32 -= bits_in_first + + # Now bit_off_i32 is byte aligned (or nbits_i32 == 0) + if nbits_i32 <= 0: + return bit_off_i32 + + cur_byte = (bit_off_i32 >> 3).to(tl.int32) + + # full bytes we can write + full_bytes = (nbits_i32 >> 3).to(tl.int32) # nbits_i32 // 8 + rem_bits = (nbits_i32 & 7).to(tl.int32) + + # -------- full bytes -------- + jb = tl.full((), 0, dtype=tl.int32) + while jb < full_bytes: + # take lowest 8 bits from value_u32 + byte_val = (value_u32 & tl.full((), 0xFF, tl.uint32)).to(tl.uint8) + tl.store(u8_ptr + cur_byte + jb, byte_val) + value_u32 >>= 8 + jb += 1 + + bit_off_i32 += full_bytes * 8 + + # -------- trailing partial byte -------- + if rem_bits > 0: + byte_idx = (bit_off_i32 >> 3).to(tl.int32) + old_u8 = tl.load(u8_ptr + byte_idx).to(tl.uint32) + + mask_u32 = ( (tl.full((), 1, tl.uint32) << rem_bits) - 1 ) + bits_u32 = ( value_u32 & mask_u32 ) + + new_u8 = ((old_u8 & ~mask_u32) | bits_u32).to(tl.uint8) + tl.store(u8_ptr + byte_idx, new_u8) + + bit_off_i32 += rem_bits + return bit_off_i32 + + +@triton.jit +def read_nbits(u8_ptr, bit_off_i32, nbits_i32, limit_bit_i32): + """ + GPU version of BitStreamReader.read_bits (LSB-first), but bounds-safe. + + Reads `nbits_i32` bits starting at `bit_off_i32`, but never loads beyond + bit index `limit_bit_i32` (masked loads return 0 out-of-bounds). + + Returns: (value_u32, new_bit_off_i32) + """ + j = tl.full((), 0, dtype=tl.int32) + val_u32 = tl.full((), 0, dtype=tl.uint32) + ONE_U32 = tl.full((), 1, dtype=tl.uint32) + ZERO_U8 = tl.full((), 0, dtype=tl.uint8) + + while j < nbits_i32: + pos = bit_off_i32 + j + in_bounds = pos < limit_bit_i32 + + byte_idx = (pos >> 3).to(tl.int32) + bit_idx = (pos & 7).to(tl.int32) + + # Masked load: if in_bounds==0, we load ZERO_U8 instead of touching memory. + u8 = tl.load(u8_ptr + byte_idx, mask=in_bounds, other=ZERO_U8) + u32 = u8.to(tl.uint32) + bit = (u32 >> bit_idx) & ONE_U32 + + val_u32 |= (bit << j) + j += 1 + + new_bit_off = bit_off_i32 + nbits_i32 + return val_u32, new_bit_off + + +@triton.jit +def read_nbits_fast(u8_ptr, bit_off_i32, nbits_i32, limit_bit_i32): + if nbits_i32 <= 0: + return tl.full((), 0, tl.uint32), bit_off_i32 + + # clamp to limit if you want to keep the defensive behavior + max_bits = limit_bit_i32 - bit_off_i32 + nbits_i32 = tl.minimum(nbits_i32, max_bits) + + start_bit = bit_off_i32 + end_bit = bit_off_i32 + nbits_i32 + + first_byte = (start_bit >> 3).to(tl.int32) + first_bit = (start_bit & 7).to(tl.int32) + + bits_in_first = tl.minimum( + nbits_i32, + tl.full((), 8, dtype=tl.int32) - first_bit, + ) + + val_u32 = tl.full((), 0, dtype=tl.uint32) + shift = tl.full((), 0, dtype=tl.int32) + + # -------- leading partial byte -------- + if bits_in_first > 0: + byte = tl.load(u8_ptr + first_byte).to(tl.uint32) + mask = ((tl.full((), 1, tl.uint32) << bits_in_first) - 1) << first_bit + chunk = (byte & mask) >> first_bit + val_u32 |= (chunk << shift) + + bit_off_i32 += bits_in_first + shift += bits_in_first + nbits_i32 -= bits_in_first + + if nbits_i32 <= 0: + return val_u32, bit_off_i32 + + cur_byte = (bit_off_i32 >> 3).to(tl.int32) + full_bytes = (nbits_i32 >> 3).to(tl.int32) + rem_bits = (nbits_i32 & 7).to(tl.int32) + + # -------- full bytes -------- + jb = tl.full((), 0, dtype=tl.int32) + while jb < full_bytes: + byte = tl.load(u8_ptr + cur_byte + jb).to(tl.uint32) + val_u32 |= (byte << shift) + shift += 8 + jb += 1 + + bit_off_i32 += full_bytes * 8 + + # -------- trailing partial byte -------- + if rem_bits > 0: + byte = tl.load(u8_ptr + (bit_off_i32 >> 3).to(tl.int32)).to(tl.uint32) + mask = (tl.full((), 1, tl.uint32) << rem_bits) - 1 + chunk = byte & mask + val_u32 |= (chunk << shift) + bit_off_i32 += rem_bits + return val_u32, bit_off_i32 + + +@triton.jit +def read_unary_bounded_triton(u8_ptr, bit_off_i32, end_bit_i32): + """ + GPU version of BitStreamReader.read_unary_bounded(end_bit). + Reads '1's until a '0' or end_bit. + Returns: (q_i32, new_bit_off_i32, hit_end_i32) + - q_i32: number of 1s before the terminating 0 + - hit_end_i32: 1 if we reached end_bit without seeing 0 + 0 if we saw a terminating 0 + """ + ONE_U32 = tl.full((), 1, dtype=tl.uint32) + q_i32 = tl.full((), 0, dtype=tl.int32) + hit_end_i32 = tl.full((), 1, dtype=tl.int32) + + cond = bit_off_i32 < end_bit_i32 + while cond: + pos = bit_off_i32 + byte_idx = (pos >> 3).to(tl.int32) + bit_idx = (pos & 7).to(tl.int32) + + u8 = tl.load(u8_ptr + byte_idx) + u32 = u8.to(tl.uint32) + bit = (u32 >> bit_idx) & ONE_U32 + + bit_off_i32 += 1 + + is_one = (bit == ONE_U32) + q_i32 += is_one.to(tl.int32) + + # If bit is 0, we did NOT hit end + hit_end_i32 = tl.where(is_one, hit_end_i32, + tl.full((), 0, dtype=tl.int32)) + + # Continue only if we are still inside the row and last bit was 1 + cond = (bit_off_i32 < end_bit_i32) & is_one + return q_i32, bit_off_i32, hit_end_i32 \ No newline at end of file diff --git a/src/tplr/compression/hybrid.py b/src/tplr/compression/hybrid.py index 273de408c..abc7bcc24 100644 --- a/src/tplr/compression/hybrid.py +++ b/src/tplr/compression/hybrid.py @@ -1,12 +1,13 @@ import math -import struct -from typing import Dict, Tuple, Union +from typing import Dict +from typing import Tuple, Union import numpy as np import torch import triton import triton.language as tl +from .bitops import write_nbits_fast, read_unary_bounded_triton, read_nbits_fast BytesLike = Union[bytes, bytearray, np.ndarray, torch.Tensor] @@ -19,7 +20,7 @@ def encode_batch_rows( B_choices: Tuple[int, ...] = (64, 128) ) -> Tuple[BytesLike, Dict]: """ - Compresses a 2D sorted int tensor of Top-K indices into a byte string + Compresses a 2D int64 tensor of Top-K indices into a byte string using a per-row adaptive Rice/Bitmap compression scheme on the GPU. Layout: @@ -67,28 +68,29 @@ def encode_batch_rows( idx_sorted = idx_sorted.contiguous() dev = idx_sorted.device - # delta-encoding vals = torch.cat( (idx_sorted[:, :1], idx_sorted[:, 1:] - idx_sorted[:, :-1]), dim=1, - ).to(torch.int32) + ) + + # Cast to int32 for Triton kernels + vals = vals.to(torch.int32) # k_rice parameters (log2(C // B)) k_rice_choices = tuple(int(math.log2(C // b)) for b in B_choices) num_B_choices = len(B_choices) k_rice_choices_tensor = torch.tensor(k_rice_choices, dtype=torch.int32, device=dev) - # Row header bits + # Row header bits (only used for packing row-table header byte) B_choice_bits = (num_B_choices - 1).bit_length() - ROW_HEADER_BITS = 1 + B_choice_bits + ROW_HEADER_BITS = 1 + B_choice_bits # (best_B_idx << 1) | use_bitmap # Output tensors for cost kernel costs = torch.empty((num_rows, num_B_choices), dtype=torch.int32, device=dev) is_bitmap = torch.empty((num_rows, num_B_choices), dtype=torch.int8, device=dev) - - # Calculate grid for cost kernel grid = (num_rows,) + # cost kernel: bits required for deltas only (no header bits) cost_kernel[grid]( vals, costs, @@ -103,14 +105,18 @@ def encode_batch_rows( min_costs, best_B_idx = torch.min(costs, dim=1) is_bitmap_choice = torch.gather(is_bitmap, 1, best_B_idx.unsqueeze(1)).squeeze(1).to(torch.int32) - # Payload sizing - row_payload_bits = min_costs - row_payload_bytes = ((row_payload_bits + 7) // 8).to(torch.int32) + # (1) payload bits per row = bits for deltas only + row_payload_bits = min_costs # (rows,) + + # (2) payload bytes per row (rounded up) + row_payload_bytes = ((row_payload_bits + 7) // 8).to(torch.int32) # (rows,) + # ensure fit in uint16 for the row table if torch.any(row_payload_bytes > 0xFFFF): raise ValueError("Row payload length exceeds 65535 bytes; cannot store in uint16.") - # Byte offsets + # byte offsets within the payload region (no gaps) + # row_byte_offsets[r] = sum_{i> 8) & 0xFF).to(torch.uint8) - row_table_flat[:, 2] = (headers_i32 & ((1 << ROW_HEADER_BITS) - 1)).to(torch.uint8) + row_table = torch.empty((num_rows, row_entry_bytes), dtype=torch.uint8, device=dev) + row_table[:, 0] = (lengths_i32 & 0xFF).to(torch.uint8) + row_table[:, 1] = ((lengths_i32 >> 8) & 0xFF).to(torch.uint8) - payload_buf[global_header_len_bytes: global_header_len_bytes + row_table_bytes] = row_table_flat.view(-1) + # Only the low ROW_HEADER_BITS bits are meaningful, but we just store the byte. + row_table[:, 2] = (headers_i32 & ((1 << ROW_HEADER_BITS) - 1)).to(torch.uint8) - # Calculate absolute byte offsets for pack kernel - row_abs_byte_offsets = (payload_region_start + row_byte_offsets).to(torch.int32) + payload_buf[ + global_header_len_bytes : global_header_len_bytes + row_table_bytes + ] = row_table.view(-1) - # Pack payloads (Optimized Kernel) + # compute bit offsets for each row's payload (no per-row length/header in-band) + row_bit_offsets = (payload_region_start + row_byte_offsets).to(torch.int32) * 8 + + # pack payloads pack_kernel[(num_rows,)]( vals, payload_buf, - row_abs_byte_offsets, + row_bit_offsets, best_B_idx.to(torch.int32), - is_bitmap_choice, + is_bitmap_choice, # int32 0/1 k_rice_choices_tensor, num_rows, k_dim=k_dim, ) - # Meta stats + # meta b_counts = torch.bincount(best_B_idx, minlength=len(B_choices)) B_hist = {b: c.item() for b, c in zip(B_choices, b_counts)} total_row_bytes = total_payload_bytes + row_entry_bytes * num_rows @@ -191,36 +201,44 @@ def encode_batch_rows( @triton.jit def cost_kernel( - delta_ptr, - costs_ptr, - is_bitmap_ptr, - k_dim: tl.constexpr, + delta_ptr, # (rows, k_dim) IN + costs_ptr, # (rows, num_B_choices) OUT + is_bitmap_ptr, # (rows, num_B_choices) OUT (bool/int) + k_dim: tl.constexpr, # constexpr for tl.arange num_rows: tl.int32, num_B_choices: tl.int32, - k_rice_choices_ptr, + k_rice_choices_ptr, # (num_B_choices,) int32 ): """ - Calculates bit cost. One row per program instance. + Calculates the compressed bit cost for each row for each B in B_choices. + One program instance processes one row. + Variant B: first delta encoded with Rice, tail optionally bitmap (q in {0,1}). """ row_idx = tl.program_id(0) if row_idx >= num_rows: return + # Lane indices for this row (constexpr width) i = tl.arange(0, k_dim) + + # Load entire row of delta-encoded values into SRAM row_base = row_idx * k_dim delta = tl.load(delta_ptr + row_base + i) delta0 = tl.load(delta_ptr + row_base) b_idx = 0 while b_idx < num_B_choices: + # k_rice and M = 1 << k_rice k_rice = tl.load(k_rice_choices_ptr + b_idx) + # q via shift, r via mask q = delta >> k_rice q0 = delta0 >> k_rice + # Pure Rice cost: sum(q + 1) + k_dim * k_rice rice_cost = tl.sum(q + 1) + k_dim * k_rice - # Bitmap cost: head is Rice, tail is (1 + k_rice) + # Bitmap cost: first element full Rice, tail has (1 + k_rice) bits bitmap_cost = (q0 + 1 + k_rice) + (k_dim - 1) * (1 + k_rice) # Allow bitmap only if tail q are in {0,1} @@ -232,152 +250,246 @@ def cost_kernel( out_offset = row_idx * num_B_choices + b_idx tl.store(costs_ptr + out_offset, min_cost) - tl.store(is_bitmap_ptr + out_offset, tl.where(use_bitmap, 1, 0).to(tl.int32)) + # make sure is_bitmap is exactly 0/1 in memory + tl.store( + is_bitmap_ptr + out_offset, + tl.where(use_bitmap, 1, 0).to(tl.int32), + ) b_idx += 1 @triton.jit def pack_kernel( - delta_ptr, # (rows, k_dim) IN int32 - u8_payload_ptr, # OUT uint8 - row_abs_byte_offsets_ptr, # (rows,) IN int32 (byte offset where payload starts) - best_B_idx_ptr, # (rows,) IN - is_bitmap_ptr, # (rows,) IN - k_rice_choices_ptr, # [num_B] IN - num_rows: tl.int32, - k_dim: tl.int32, # dynamic + delta_ptr, # (rows, k_dim) IN int32 + u8_payload_ptr, # (final_buffer_bytes,) OUT uint8 + row_bit_offsets_ptr, # (rows,) IN int32 (bit offset where payload starts) + best_B_idx_ptr, # (rows,) IN int32 + is_bitmap_ptr, # (rows,) IN int32 (0/1) + k_rice_choices_ptr, # [num_B] IN int32 + num_rows: tl.int32, + k_dim: tl.int32, # dynamic ): """ - Writes payload bits using a 64-bit register accumulator. - Modified to use unaligned byte stores to prevent cudaErrorMisalignedAddress. + Writes only the Rice/bitmap-coded payload bits for each row. + + Each program instance handles one row. Bit order is LSB-first. """ row_idx = tl.program_id(0) if row_idx >= num_rows: return - # Load row params - out_byte_off = tl.load(row_abs_byte_offsets_ptr + row_idx).to(tl.int32) + # Per-row meta + bit_off_i32 = tl.load(row_bit_offsets_ptr + row_idx).to(tl.int32) b_idx_i32 = tl.load(best_B_idx_ptr + row_idx).to(tl.int32) use_bitmap_i32 = (tl.load(is_bitmap_ptr + row_idx) & 1).to(tl.int32) + + # params k_rice_i32 = tl.load(k_rice_choices_ptr + b_idx_i32).to(tl.int32) M_i32 = (tl.full((), 1, dtype=tl.int32) << k_rice_i32) - # Accumulator state - acc_data = tl.full((), 0, dtype=tl.uint64) - acc_bits = tl.full((), 0, dtype=tl.int32) - - # Output pointer (byte-aligned) - out_ptr_base = u8_payload_ptr + out_byte_off - - base_idx = row_idx * k_dim + ONE_U32 = tl.full((), 1, dtype=tl.uint32) + ZERO_U32 = tl.full((), 0, dtype=tl.uint32) + ONE_I32 = tl.full((), 1, dtype=tl.int32) + THIRTY_ONE_I32 = tl.full((), 31, dtype=tl.int32) + + base = row_idx * k_dim + + # ---- first delta: ALWAYS full Rice (unary + remainder) ---- + if k_dim > 0: + v0 = tl.load(delta_ptr + base).to(tl.int32) + q0 = (v0 >> k_rice_i32).to(tl.int32) + r0 = (v0 & (M_i32 - 1)).to(tl.int32) + + # q0 ones in chunks of <= 31, then a single 0 + q_left = q0 + while q_left > 0: + chunk = tl.minimum(q_left, THIRTY_ONE_I32) + ones = (ONE_U32 << chunk) - ONE_U32 + bit_off_i32 = write_nbits_fast(u8_payload_ptr, bit_off_i32, ones, chunk) + q_left -= chunk + + # terminating 0 bit + bit_off_i32 = write_nbits_fast(u8_payload_ptr, bit_off_i32, ZERO_U32, ONE_I32) + # remainder + bit_off_i32 = write_nbits_fast( + u8_payload_ptr, bit_off_i32, r0.to(tl.uint32), k_rice_i32 + ) - # ------------------------------------------------------------------ - # PROCESS LOOP - # ------------------------------------------------------------------ - i = 0 + # ---- tail deltas ---- + i = 1 while i < k_dim: - val = tl.load(delta_ptr + base_idx + i).to(tl.int32) - - # Compute q, r - q = (val >> k_rice_i32).to(tl.uint64) - r = (val & (M_i32 - 1)).to(tl.uint64) - - is_rice = (i == 0) | (use_bitmap_i32 == 0) - - if is_rice: - # Rice: q '1's, then '0', then k_rice bits of r - - # Append Unary (q ones) - q_count = q.to(tl.int32) - while q_count > 0: - acc_data |= (tl.full((), 1, dtype=tl.uint64) << acc_bits) - acc_bits += 1 - q_count -= 1 + v = tl.load(delta_ptr + base + i).to(tl.int32) + q = (v >> k_rice_i32).to(tl.int32) + r = (v & (M_i32 - 1)).to(tl.int32) + + # Rice unary only if NOT bitmap + q_left = tl.where(use_bitmap_i32 != 0, tl.full((), 0, dtype=tl.int32), q) + while q_left > 0: + chunk = tl.minimum(q_left, THIRTY_ONE_I32) + ones = (ONE_U32 << chunk) - ONE_U32 + bit_off_i32 = write_nbits_fast(u8_payload_ptr, bit_off_i32, ones, chunk) + q_left -= chunk + + # terminating 0 bit only in full-Rice mode + n_term = tl.where(use_bitmap_i32 != 0, tl.full((), 0, dtype=tl.int32), ONE_I32) + bit_off_i32 = write_nbits_fast(u8_payload_ptr, bit_off_i32, ZERO_U32, n_term) + + # bitmap q only if bitmap + q_bit = tl.where(q > 0, ONE_U32, ZERO_U32) + n_qbit = tl.where(use_bitmap_i32 != 0, ONE_I32, tl.full((), 0, dtype=tl.int32)) + bit_off_i32 = write_nbits_fast(u8_payload_ptr, bit_off_i32, q_bit, n_qbit) + + # remainder always + bit_off_i32 = write_nbits_fast(u8_payload_ptr, bit_off_i32, r.to(tl.uint32), k_rice_i32) + i += 1 - # Flush Check - if acc_bits >= 32: - # Unaligned Store (4 bytes separately) - val_u32 = acc_data.to(tl.uint32) - tl.store(out_ptr_base + 0, (val_u32 & 0xFF).to(tl.uint8)) - tl.store(out_ptr_base + 1, ((val_u32 >> 8) & 0xFF).to(tl.uint8)) - tl.store(out_ptr_base + 2, ((val_u32 >> 16) & 0xFF).to(tl.uint8)) - tl.store(out_ptr_base + 3, ((val_u32 >> 24) & 0xFF).to(tl.uint8)) - out_ptr_base += 4 - acc_data >>= 32 - acc_bits -= 32 +@triton.jit +def parse_header_kernel( + u8_payload_ptr, # (total_bytes,) uint8 + C_out_ptr, # (1,) int32 + K_out_ptr, # (1,) int32 + R_out_ptr, # (1,) int32 NEW: num_rows + num_B_out_ptr, # (1,) int32 + B_choices_out_ptr, # (MAX_B_CHOICES,) int32 + header_bytes_out_ptr, # (1,) int32 + error_flag_ptr, # (1,) int32 + total_bytes: tl.int32, + MAX_B_CHOICES: tl.constexpr, +): + """ + Parse the global header entirely on GPU. + Layout: + 0..3 : "CGRP" + 4..7 : C (uint32 LE) + 8..9 : K (uint16 LE) + 10..13 : R (uint32 LE, num_rows) + 14 : num_B (uint8) + 15.. : B_choices (num_B * 2 bytes, uint16 LE) + """ - # Append Separator '0' - acc_bits += 1 + pid = tl.program_id(0) + if pid != 0: + return + # ---- init outputs / error ---- + C_val = tl.full((), 0, dtype=tl.int32) + K_val = tl.full((), 0, dtype=tl.int32) + R_val = tl.full((), 0, dtype=tl.int32) + num_B_val = tl.full((), 0, dtype=tl.int32) + header_bytes_i32 = tl.full((), 0, dtype=tl.int32) + err = tl.full((), 0, dtype=tl.int32) + + # ---- basic size + magic checks ---- + # Minimum header size: 15 bytes (without B_choices) + if total_bytes < 15: + err = 1 + else: + # Magic "CGRP" = [67, 71, 82, 80] + m0 = tl.load(u8_payload_ptr + 0) + m1 = tl.load(u8_payload_ptr + 1) + m2 = tl.load(u8_payload_ptr + 2) + m3 = tl.load(u8_payload_ptr + 3) + cond_magic = (m0 == 67) & (m1 == 71) & (m2 == 82) & (m3 == 80) + bad_magic = cond_magic == 0 + err = tl.where(bad_magic, tl.full((), 2, dtype=tl.int32), err) + + # ---- C, K, R, num_B ---- + if err == 0: + # C (uint32 LE at bytes 4..7) + b4 = tl.load(u8_payload_ptr + 4).to(tl.int32) + b5 = tl.load(u8_payload_ptr + 5).to(tl.int32) + b6 = tl.load(u8_payload_ptr + 6).to(tl.int32) + b7 = tl.load(u8_payload_ptr + 7).to(tl.int32) + C_val = b4 | (b5 << 8) | (b6 << 16) | (b7 << 24) + + # K (uint16 LE at bytes 8..9) + b8 = tl.load(u8_payload_ptr + 8).to(tl.int32) + b9 = tl.load(u8_payload_ptr + 9).to(tl.int32) + K_val = b8 | (b9 << 8) + + # R (uint32 LE at bytes 10..13) + b10 = tl.load(u8_payload_ptr + 10).to(tl.int32) + b11 = tl.load(u8_payload_ptr + 11).to(tl.int32) + b12 = tl.load(u8_payload_ptr + 12).to(tl.int32) + b13 = tl.load(u8_payload_ptr + 13).to(tl.int32) + R_val = b10 | (b11 << 8) | (b12 << 16) | (b13 << 24) + + # num_B at byte 14 + num_B_val = tl.load(u8_payload_ptr + 14).to(tl.int32) + invalid_num_B = (num_B_val <= 0) | (num_B_val > MAX_B_CHOICES) + err = tl.where(invalid_num_B, tl.full((), 3, dtype=tl.int32), err) + + # ---- read B_choices in a structured loop (no break/return) ---- + off = tl.full((), 15, dtype=tl.int32) # B_choices start at byte 15 + i = tl.full((), 0, dtype=tl.int32) + + while i < MAX_B_CHOICES: + need_this = (i < num_B_val) & (err == 0) + + if need_this: + cond_in_bounds = (off + 1) < total_bytes + if cond_in_bounds: + lo = tl.load(u8_payload_ptr + off).to(tl.int32) + hi = tl.load(u8_payload_ptr + off + 1).to(tl.int32) + B_val = lo | (hi << 8) + tl.store(B_choices_out_ptr + i, B_val) + off += 2 + else: + err = tl.full((), 4, dtype=tl.int32) + tl.store(B_choices_out_ptr + i, tl.full((), 0, dtype=tl.int32)) else: - # Bitmap: q is 1 bit - q_bit = tl.where(q > 0, 1, 0).to(tl.uint64) - acc_data |= (q_bit << acc_bits) - acc_bits += 1 - - # Flush Check - if acc_bits >= 32: - val_u32 = acc_data.to(tl.uint32) - tl.store(out_ptr_base + 0, (val_u32 & 0xFF).to(tl.uint8)) - tl.store(out_ptr_base + 1, ((val_u32 >> 8) & 0xFF).to(tl.uint8)) - tl.store(out_ptr_base + 2, ((val_u32 >> 16) & 0xFF).to(tl.uint8)) - tl.store(out_ptr_base + 3, ((val_u32 >> 24) & 0xFF).to(tl.uint8)) - - out_ptr_base += 4 - acc_data >>= 32 - acc_bits -= 32 - - # Append Remainder - acc_data |= (r << acc_bits) - acc_bits += k_rice_i32 - - # Flush Check - if acc_bits >= 32: - val_u32 = acc_data.to(tl.uint32) - tl.store(out_ptr_base + 0, (val_u32 & 0xFF).to(tl.uint8)) - tl.store(out_ptr_base + 1, ((val_u32 >> 8) & 0xFF).to(tl.uint8)) - tl.store(out_ptr_base + 2, ((val_u32 >> 16) & 0xFF).to(tl.uint8)) - tl.store(out_ptr_base + 3, ((val_u32 >> 24) & 0xFF).to(tl.uint8)) - - out_ptr_base += 4 - acc_data >>= 32 - acc_bits -= 32 + tl.store(B_choices_out_ptr + i, tl.full((), 0, dtype=tl.int32)) i += 1 - # ------------------------------------------------------------------ - # FINAL FLUSH - # ------------------------------------------------------------------ - # We might have 1..31 bits left. Write byte-by-byte. - while acc_bits > 0: - tl.store(out_ptr_base, (acc_data & 0xFF).to(tl.uint8)) - out_ptr_base += 1 - acc_data >>= 8 - acc_bits -= 8 + # header_bytes = 15 + 2 * num_B (only meaningful if err == 0) + if err == 0: + header_bytes_i32 = 15 + (num_B_val * 2) + + # ---- store outputs ---- + tl.store(C_out_ptr, C_val) + tl.store(K_out_ptr, K_val) + tl.store(R_out_ptr, R_val) + tl.store(num_B_out_ptr, num_B_val) + tl.store(header_bytes_out_ptr, header_bytes_i32) + tl.store(error_flag_ptr, err) @triton.jit def parse_row_table_kernel( - u8_payload_ptr, - row_payload_bytes_ptr, - best_B_idx_ptr, - use_bitmap_ptr, - row_table_start: tl.int32, - num_rows: tl.int32, - ROW_HEADER_BITS: tl.constexpr, + u8_payload_ptr, # (total_bytes,) uint8 + row_payload_bytes_ptr, # (num_rows,) int32 + best_B_idx_ptr, # (num_rows,) int32 + use_bitmap_ptr, # (num_rows,) int32 + row_table_start: tl.int32, + num_rows: tl.int32, + ROW_HEADER_BITS: tl.constexpr, ): + """ + Parse the row table: + + For each row r: + offset = row_table_start + r * 3 + length_bytes[r] = uint16 LE at offset + header_byte = uint8 at offset + 2 + header_bits = header_byte & ((1 << ROW_HEADER_BITS) - 1) + use_bitmap[r] = header_bits & 1 + best_B_idx[r] = header_bits >> 1 + """ pid = tl.program_id(0) if pid >= num_rows: return entry_offset = row_table_start + pid * 3 + # length_bytes: uint16 LE b0 = tl.load(u8_payload_ptr + entry_offset).to(tl.int32) b1 = tl.load(u8_payload_ptr + entry_offset + 1).to(tl.int32) length_i32 = b0 | (b1 << 8) tl.store(row_payload_bytes_ptr + pid, length_i32) + # header byte header_byte = tl.load(u8_payload_ptr + entry_offset + 2).to(tl.int32) header_mask = (tl.full((), 1, dtype=tl.int32) << ROW_HEADER_BITS) - 1 header_i32 = header_byte & header_mask @@ -389,172 +501,189 @@ def parse_row_table_kernel( tl.store(best_B_idx_ptr + pid, best_B_idx_i32) -@triton.jit -def count_ones_in_word(word_u64): - """ - Counts trailing ones in a 64-bit word (register level). - Used for fast unary decoding without global memory access. - """ - cnt = tl.full((), 0, dtype=tl.int32) - ONE_U64 = tl.full((), 1, dtype=tl.uint64) - - check = word_u64 - cond = ((check & ONE_U64) == ONE_U64) & (cnt < 64) - while cond: - cnt += 1 - check >>= 1 - # Update condition for next iteration - cond = ((check & ONE_U64) == ONE_U64) & (cnt < 64) - return cnt - @triton.jit def decode_rows_kernel( - u8_payload_ptr, - out_vals_ptr, - row_bit_offsets_ptr, # (rows,) - row_payload_bytes_ptr, # (rows,) - best_B_idx_ptr, # (rows,) - use_bitmap_ptr, # (rows,) - k_rice_choices_ptr, # (num_B,) - num_rows: tl.int32, - K: tl.int32, + u8_payload_ptr, # (total_bytes,) uint8 + out_vals_ptr, # (num_rows * K,) int32 + row_bit_offsets_ptr, # (num_rows,) int32 (bit offset of first encoded bit) + row_payload_bytes_ptr, # (num_rows,) int32 + best_B_idx_ptr, # (num_rows,) int32 + use_bitmap_ptr, # (num_rows,) int32 + k_rice_choices_ptr, # (num_B,) int32 + num_rows: tl.int32, + K: tl.int32, ): """ - Decodes rows using unaligned-safe 64-bit reads (via byte loads). + Fully GPU decode of Rice/bitmap rows. + + For each row r: + - Bit range: + start_bit = row_bit_offsets[r] + end_bit = start_bit + row_payload_bytes[r] * 8 + - First value: full Rice (unary + remainder) + - Tail: Rice or bitmap+remainder depending on use_bitmap[r]. """ row_idx = tl.program_id(0) if row_idx >= num_rows: return - # Row params - start_bit = tl.load(row_bit_offsets_ptr + row_idx).to(tl.int32) - payload_bytes = tl.load(row_payload_bytes_ptr + row_idx).to(tl.int32) - b_idx = tl.load(best_B_idx_ptr + row_idx).to(tl.int32) - use_bitmap = (tl.load(use_bitmap_ptr + row_idx) & 1).to(tl.int32) + # Per-row metadata + row_start_bit_i32 = tl.load(row_bit_offsets_ptr + row_idx).to(tl.int32) + payload_bytes_i32 = tl.load(row_payload_bytes_ptr + row_idx).to(tl.int32) + best_B_idx_i32 = tl.load(best_B_idx_ptr + row_idx).to(tl.int32) + use_bitmap_i32 = (tl.load(use_bitmap_ptr + row_idx) & 1).to(tl.int32) + + # k_rice and M for this row + k_rice_i32 = tl.load(k_rice_choices_ptr + best_B_idx_i32).to(tl.int32) + M_i32 = (tl.full((), 1, dtype=tl.int32) << k_rice_i32) - k_rice = tl.load(k_rice_choices_ptr + b_idx).to(tl.int32) - M = (tl.full((), 1, dtype=tl.int32) << k_rice) + # Bit range of this row + row_end_bit_i32 = row_start_bit_i32 + payload_bytes_i32 * 8 - current_bit = start_bit base_out = row_idx * K + ONE_I32 = tl.full((), 1, dtype=tl.int32) - i = 0 + bit_off_i32 = row_start_bit_i32 + + # ---- first value: ALWAYS full Rice ---- + if K > 0: + q0_i32, bit_off_i32, hit_end0_i32 = read_unary_bounded_triton( + u8_payload_ptr, + bit_off_i32, + row_end_bit_i32, + ) + r0_u32, bit_off_i32 = read_nbits_fast( + u8_payload_ptr, + bit_off_i32, + k_rice_i32, + row_end_bit_i32, # limit + ) + r0_i32 = r0_u32.to(tl.int32) + v0_i32 = q0_i32 * M_i32 + r0_i32 + tl.store(out_vals_ptr + base_out, v0_i32) + + # ---- tail values ---- + i = tl.full((), 1, dtype=tl.int32) while i < K: - # ------------------------------------------------ - # BUFFERED LOAD (Unaligned Safe) - # ------------------------------------------------ - byte_idx = current_bit // 8 - bit_in_byte = current_bit % 8 - - # Manually load 8 bytes to form uint64. - # This prevents misaligned access on all GPUs. - b0 = tl.load(u8_payload_ptr + byte_idx + 0).to(tl.uint64) - b1 = tl.load(u8_payload_ptr + byte_idx + 1).to(tl.uint64) - b2 = tl.load(u8_payload_ptr + byte_idx + 2).to(tl.uint64) - b3 = tl.load(u8_payload_ptr + byte_idx + 3).to(tl.uint64) - b4 = tl.load(u8_payload_ptr + byte_idx + 4).to(tl.uint64) - b5 = tl.load(u8_payload_ptr + byte_idx + 5).to(tl.uint64) - b6 = tl.load(u8_payload_ptr + byte_idx + 6).to(tl.uint64) - b7 = tl.load(u8_payload_ptr + byte_idx + 7).to(tl.uint64) - - word_u64 = b0 | (b1 << 8) | (b2 << 16) | (b3 << 24) | \ - (b4 << 32) | (b5 << 40) | (b6 << 48) | (b7 << 56) - - # Shift out consumed bits - stream = word_u64 >> bit_in_byte - - # ------------------------------------------------ - # DECODE LOGIC - # ------------------------------------------------ - q = 0 - r = 0 - - is_rice = (i == 0) | (use_bitmap == 0) - - bits_consumed = 0 - - if is_rice: - # Decode Unary q - q = count_ones_in_word(stream) - bits_consumed += (q + 1) - stream >>= (q + 1) + if use_bitmap_i32 != 0: + # Bitmap mode: q is 1 bit in {0,1} + q_bit_u32, bit_off_i32 = read_nbits_fast( + u8_payload_ptr, + bit_off_i32, + ONE_I32, + row_end_bit_i32, + ) + q_i32 = q_bit_u32.to(tl.int32) + + r_u32, bit_off_i32 = read_nbits_fast( + u8_payload_ptr, + bit_off_i32, + k_rice_i32, + row_end_bit_i32, + ) + r_i32 = r_u32.to(tl.int32) else: - # Bitmap: q is single bit - q = (stream & 1).to(tl.int32) - bits_consumed += 1 - stream >>= 1 - - # Decode Remainder r - mask = (tl.full((), 1, dtype=tl.uint64) << k_rice) - 1 - r = (stream & mask).to(tl.int32) - bits_consumed += k_rice - - # ------------------------------------------------ - # STORE - # ------------------------------------------------ - val = q * M + r - tl.store(out_vals_ptr + base_out + i, val) - - current_bit += bits_consumed + # Full Rice mode + q_i32, bit_off_i32, hit_end_i32 = read_unary_bounded_triton( + u8_payload_ptr, + bit_off_i32, + row_end_bit_i32, + ) + r_u32, bit_off_i32 = read_nbits_fast( + u8_payload_ptr, + bit_off_i32, + k_rice_i32, + row_end_bit_i32, + ) + r_i32 = r_u32.to(tl.int32) + + v_i32 = q_i32 * M_i32 + r_i32 + tl.store(out_vals_ptr + base_out + i, v_i32) i += 1 def decode_batch_rows( - payload: BytesLike, - max_num_B: int = 16, + payload: BytesLike, + max_num_B: int = 16, ) -> tuple[torch.Tensor, int, int]: + if not torch.cuda.is_available(): - raise RuntimeError("CUDA required") + raise RuntimeError("decode_batch_rows_gpu requires CUDA") - # Move to GPU/Tensor + # --- Move payload to CUDA (if needed) --- if isinstance(payload, torch.Tensor): + assert payload.dtype == torch.uint8 payload_gpu = payload if payload.is_cuda else payload.cuda() elif isinstance(payload, np.ndarray): + assert payload.dtype == np.uint8 payload_gpu = torch.from_numpy(payload).to("cuda", dtype=torch.uint8) - else: + elif isinstance(payload, (bytes, bytearray)): arr = np.frombuffer(bytes(payload), dtype=np.uint8) payload_gpu = torch.from_numpy(arr).to("cuda", dtype=torch.uint8) + else: + raise TypeError("Unsupported payload type") payload_gpu = payload_gpu.contiguous() dev = payload_gpu.device total_bytes = int(payload_gpu.numel()) - if total_bytes == 0: - return torch.empty((0, 0), dtype=torch.int64, device=dev), 0, 0 - - # --- 1) Parse Global Header (CPU) --- - header_size_min = 15 - header_cpu = payload_gpu[:64].cpu().numpy().tobytes() - - try: - # Fixed format string to match 15 bytes - magic, C, K, num_rows, num_B = struct.unpack("<4sIHIB", header_cpu[:15]) - except struct.error: - raise ValueError("Payload too short for header") - - if magic != b"CGRP": - raise ValueError("Invalid magic bytes") - - offset = 15 - B_choices = [] - for _ in range(num_B): - b_val = struct.unpack(" move to GPU --- k_rice_choices = [] - for B in B_choices: + for B in B_choices_list: M = C // B + if M <= 0 or (M & (M - 1)) != 0: + raise ValueError(f"M=C//B={M} not power of two for B={B}") k_rice_choices.append(int(math.log2(M))) - k_rice_choices_tensor = torch.tensor(k_rice_choices, dtype=torch.int32, device=dev) + k_rice_choices_tensor = torch.tensor( + k_rice_choices, dtype=torch.int32, device=dev + ) - ROW_HEADER_BITS = 1 + (num_B - 1).bit_length() + B_choice_bits = (num_B - 1).bit_length() + ROW_HEADER_BITS = 1 + B_choice_bits - # --- 3) Parse Row Table (GPU) --- - row_table_bytes = num_rows * 3 + # --- 3) Parse row table on GPU --- + row_entry_bytes = 3 + row_table_bytes = num_rows * row_entry_bytes + if header_bytes + row_table_bytes > total_bytes: + raise ValueError("Truncated payload: row table exceeds payload length") row_payload_bytes = torch.empty(num_rows, dtype=torch.int32, device=dev) best_B_idx = torch.empty(num_rows, dtype=torch.int32, device=dev) @@ -570,18 +699,30 @@ def decode_batch_rows( ROW_HEADER_BITS=ROW_HEADER_BITS, ) - # --- 4) Offsets --- - payload_region_start = header_bytes + row_table_bytes + # --- 4) Compute per-row bit offsets into payload region --- + payload_region_start_byte = header_bytes + row_table_bytes + if payload_region_start_byte > total_bytes: + raise ValueError("Truncated payload: missing payload region") + # byte offsets within the payload region row_payload_bytes_64 = row_payload_bytes.to(torch.int64) - row_byte_offsets = torch.cumsum(row_payload_bytes_64, dim=0) - row_payload_bytes_64 - row_bit_offsets = (payload_region_start + row_byte_offsets).to(torch.int32) * 8 + # Sanity check: last row must end within the buffer + last_end = int( + payload_region_start_byte + + row_byte_offsets[-1].item() + + row_payload_bytes_64[-1].item() + ) + if last_end > total_bytes: + raise ValueError("Truncated payload: row payload bytes exceed buffer length") - # --- 5) Decode (GPU Optimized) --- - out_vals = torch.empty((num_rows, K), dtype=torch.int32, device=dev) + row_bit_offsets = ( + payload_region_start_byte + row_byte_offsets + ).to(torch.int32) * 8 + # --- 5) Decode rows in parallel on GPU --- + out_vals = torch.empty((num_rows, K), dtype=torch.int32, device=dev) decode_rows_kernel[(num_rows,)]( payload_gpu, out_vals, @@ -595,4 +736,5 @@ def decode_batch_rows( ) out_vals = torch.cumsum(out_vals, dim=1) - return out_vals.to(torch.int64), C, num_rows \ No newline at end of file + return out_vals.to(torch.int64), C, num_rows + From 0120eb587d1cb57b67774a6e2aa65eeae8989b05 Mon Sep 17 00:00:00 2001 From: Kasper Date: Wed, 19 Nov 2025 15:46:47 +0400 Subject: [PATCH 22/33] make sure dst and gather idx are aligned --- src/tplr/compress.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tplr/compress.py b/src/tplr/compress.py index b52094fc3..b6cc544be 100644 --- a/src/tplr/compress.py +++ b/src/tplr/compress.py @@ -391,7 +391,7 @@ def decompress( val = val.to(dtype=x.dtype) # second condition for legacy decompress - if len(xshape) > 2 and len(xshape) != len(idx_int64): + if len(xshape) > 2 and len(x) != len(idx_int64): idx_int64 = rearrange(idx_int64, "(y x) h -> y x h", y=xshape[0]) val = rearrange(val, "(y x) h -> y x h", y=xshape[0]) From dba1fb6a7d5d2ae96f8ebff9a84f18460c6e63a5 Mon Sep 17 00:00:00 2001 From: Kasper Date: Wed, 19 Nov 2025 16:12:43 +0400 Subject: [PATCH 23/33] v2 --- src/tplr/compress.py | 4 +- src/tplr/compression/bitops.py | 249 ----------- src/tplr/compression/hybrid.py | 726 +++++++++++++++------------------ src/tplr/neurons.py | 5 +- 4 files changed, 326 insertions(+), 658 deletions(-) delete mode 100644 src/tplr/compression/bitops.py diff --git a/src/tplr/compress.py b/src/tplr/compress.py index b6cc544be..458097634 100644 --- a/src/tplr/compress.py +++ b/src/tplr/compress.py @@ -491,7 +491,7 @@ def batch_decompress( v_data = val_list[i] if i_data.dtype == torch.uint8: try: - rows, C, _N = decode_batch_rows(i_data.detach().cpu().numpy().tobytes()) + rows, C, _N = decode_batch_rows(i_data) if C != totalk: raise ValueError(f"Index payload C={C} but expected {totalk}") if any(len(r) != v_data.shape[-1] for r in rows): @@ -501,7 +501,7 @@ def batch_decompress( idx_unpacked = torch.tensor( rows, dtype=torch.int64, device=p.device ).view(*v_data.shape) - except ValueError as e: + except Exception as e: # Fallback: likely old format -> try legacy decoder idx_unpacked = unpack_12bit_indices(i_data.to(p.device), v_data.shape) diff --git a/src/tplr/compression/bitops.py b/src/tplr/compression/bitops.py deleted file mode 100644 index 6302e5692..000000000 --- a/src/tplr/compression/bitops.py +++ /dev/null @@ -1,249 +0,0 @@ -from typing import Union - -import numpy as np -import torch -import triton -import triton.language as tl - -BytesLike = Union[bytes, bytearray, np.ndarray, torch.Tensor] - -@triton.jit -def write_nbits( - u8_ptr, # uint8* global buffer - bit_off_i32, # scalar tl.int32 bit offset - value_u32, # scalar tl.uint32, up to 32 bits used - nbits_i32, # scalar tl.int32, number of bits to write -): - """ - Writes `nbits_i32` least-significant bits of `value_u32` into `u8_ptr` - starting at bit offset `bit_off_i32` in LSB-first order. - - This is still a bit-at-a-time writer; higher-level kernels have been - adjusted to use int32 + shift/mask ahead of time. - """ - j = tl.full((), 0, dtype=tl.int32) - ONE_U32 = tl.full((), 1, dtype=tl.uint32) - - while j < nbits_i32: - pos = bit_off_i32 + j - byte_idx = (pos >> 3).to(tl.int32) - bit_idx = (pos & 7).to(tl.int32) - - old_u8 = tl.load(u8_ptr + byte_idx) - old_u32 = old_u8.to(tl.uint32) - - vbit = (value_u32 >> j) & ONE_U32 - mask = ONE_U32 << bit_idx - new_u32 = (old_u32 & (~mask)) | (vbit << bit_idx) - tl.store(u8_ptr + byte_idx, new_u32.to(tl.uint8)) - j += 1 - return bit_off_i32 + nbits_i32 - - -@triton.jit -def write_nbits_fast( - u8_ptr, - bit_off_i32, # start bit - value_u32, # LSB-first payload bits - nbits_i32, # 0..32 -): - # If nothing to write - if nbits_i32 <= 0: - return bit_off_i32 - - start_bit = bit_off_i32 - first_byte = (start_bit >> 3).to(tl.int32) - first_bit = (start_bit & 7).to(tl.int32) - - # How many bits fit in the first byte - bits_in_first = tl.minimum( - nbits_i32, - tl.full((), 8, dtype=tl.int32) - first_bit, - ) - - # -------- leading partial byte -------- - if bits_in_first > 0: - old_u8 = tl.load(u8_ptr + first_byte).to(tl.uint32) - - # mask for the bits we overwrite inside that byte - mask_u32 = ((tl.full((), 1, tl.uint32) << bits_in_first) - 1) \ - << first_bit - - # extract those bits from value_u32 - bits_u32 = (value_u32 & ((tl.full((), 1, tl.uint32) << bits_in_first) - 1)) \ - << first_bit - - new_u8 = ((old_u8 & ~mask_u32) | bits_u32).to(tl.uint8) - tl.store(u8_ptr + first_byte, new_u8) - - bit_off_i32 += bits_in_first - value_u32 >>= bits_in_first - nbits_i32 -= bits_in_first - - # Now bit_off_i32 is byte aligned (or nbits_i32 == 0) - if nbits_i32 <= 0: - return bit_off_i32 - - cur_byte = (bit_off_i32 >> 3).to(tl.int32) - - # full bytes we can write - full_bytes = (nbits_i32 >> 3).to(tl.int32) # nbits_i32 // 8 - rem_bits = (nbits_i32 & 7).to(tl.int32) - - # -------- full bytes -------- - jb = tl.full((), 0, dtype=tl.int32) - while jb < full_bytes: - # take lowest 8 bits from value_u32 - byte_val = (value_u32 & tl.full((), 0xFF, tl.uint32)).to(tl.uint8) - tl.store(u8_ptr + cur_byte + jb, byte_val) - value_u32 >>= 8 - jb += 1 - - bit_off_i32 += full_bytes * 8 - - # -------- trailing partial byte -------- - if rem_bits > 0: - byte_idx = (bit_off_i32 >> 3).to(tl.int32) - old_u8 = tl.load(u8_ptr + byte_idx).to(tl.uint32) - - mask_u32 = ( (tl.full((), 1, tl.uint32) << rem_bits) - 1 ) - bits_u32 = ( value_u32 & mask_u32 ) - - new_u8 = ((old_u8 & ~mask_u32) | bits_u32).to(tl.uint8) - tl.store(u8_ptr + byte_idx, new_u8) - - bit_off_i32 += rem_bits - return bit_off_i32 - - -@triton.jit -def read_nbits(u8_ptr, bit_off_i32, nbits_i32, limit_bit_i32): - """ - GPU version of BitStreamReader.read_bits (LSB-first), but bounds-safe. - - Reads `nbits_i32` bits starting at `bit_off_i32`, but never loads beyond - bit index `limit_bit_i32` (masked loads return 0 out-of-bounds). - - Returns: (value_u32, new_bit_off_i32) - """ - j = tl.full((), 0, dtype=tl.int32) - val_u32 = tl.full((), 0, dtype=tl.uint32) - ONE_U32 = tl.full((), 1, dtype=tl.uint32) - ZERO_U8 = tl.full((), 0, dtype=tl.uint8) - - while j < nbits_i32: - pos = bit_off_i32 + j - in_bounds = pos < limit_bit_i32 - - byte_idx = (pos >> 3).to(tl.int32) - bit_idx = (pos & 7).to(tl.int32) - - # Masked load: if in_bounds==0, we load ZERO_U8 instead of touching memory. - u8 = tl.load(u8_ptr + byte_idx, mask=in_bounds, other=ZERO_U8) - u32 = u8.to(tl.uint32) - bit = (u32 >> bit_idx) & ONE_U32 - - val_u32 |= (bit << j) - j += 1 - - new_bit_off = bit_off_i32 + nbits_i32 - return val_u32, new_bit_off - - -@triton.jit -def read_nbits_fast(u8_ptr, bit_off_i32, nbits_i32, limit_bit_i32): - if nbits_i32 <= 0: - return tl.full((), 0, tl.uint32), bit_off_i32 - - # clamp to limit if you want to keep the defensive behavior - max_bits = limit_bit_i32 - bit_off_i32 - nbits_i32 = tl.minimum(nbits_i32, max_bits) - - start_bit = bit_off_i32 - end_bit = bit_off_i32 + nbits_i32 - - first_byte = (start_bit >> 3).to(tl.int32) - first_bit = (start_bit & 7).to(tl.int32) - - bits_in_first = tl.minimum( - nbits_i32, - tl.full((), 8, dtype=tl.int32) - first_bit, - ) - - val_u32 = tl.full((), 0, dtype=tl.uint32) - shift = tl.full((), 0, dtype=tl.int32) - - # -------- leading partial byte -------- - if bits_in_first > 0: - byte = tl.load(u8_ptr + first_byte).to(tl.uint32) - mask = ((tl.full((), 1, tl.uint32) << bits_in_first) - 1) << first_bit - chunk = (byte & mask) >> first_bit - val_u32 |= (chunk << shift) - - bit_off_i32 += bits_in_first - shift += bits_in_first - nbits_i32 -= bits_in_first - - if nbits_i32 <= 0: - return val_u32, bit_off_i32 - - cur_byte = (bit_off_i32 >> 3).to(tl.int32) - full_bytes = (nbits_i32 >> 3).to(tl.int32) - rem_bits = (nbits_i32 & 7).to(tl.int32) - - # -------- full bytes -------- - jb = tl.full((), 0, dtype=tl.int32) - while jb < full_bytes: - byte = tl.load(u8_ptr + cur_byte + jb).to(tl.uint32) - val_u32 |= (byte << shift) - shift += 8 - jb += 1 - - bit_off_i32 += full_bytes * 8 - - # -------- trailing partial byte -------- - if rem_bits > 0: - byte = tl.load(u8_ptr + (bit_off_i32 >> 3).to(tl.int32)).to(tl.uint32) - mask = (tl.full((), 1, tl.uint32) << rem_bits) - 1 - chunk = byte & mask - val_u32 |= (chunk << shift) - bit_off_i32 += rem_bits - return val_u32, bit_off_i32 - - -@triton.jit -def read_unary_bounded_triton(u8_ptr, bit_off_i32, end_bit_i32): - """ - GPU version of BitStreamReader.read_unary_bounded(end_bit). - Reads '1's until a '0' or end_bit. - Returns: (q_i32, new_bit_off_i32, hit_end_i32) - - q_i32: number of 1s before the terminating 0 - - hit_end_i32: 1 if we reached end_bit without seeing 0 - 0 if we saw a terminating 0 - """ - ONE_U32 = tl.full((), 1, dtype=tl.uint32) - q_i32 = tl.full((), 0, dtype=tl.int32) - hit_end_i32 = tl.full((), 1, dtype=tl.int32) - - cond = bit_off_i32 < end_bit_i32 - while cond: - pos = bit_off_i32 - byte_idx = (pos >> 3).to(tl.int32) - bit_idx = (pos & 7).to(tl.int32) - - u8 = tl.load(u8_ptr + byte_idx) - u32 = u8.to(tl.uint32) - bit = (u32 >> bit_idx) & ONE_U32 - - bit_off_i32 += 1 - - is_one = (bit == ONE_U32) - q_i32 += is_one.to(tl.int32) - - # If bit is 0, we did NOT hit end - hit_end_i32 = tl.where(is_one, hit_end_i32, - tl.full((), 0, dtype=tl.int32)) - - # Continue only if we are still inside the row and last bit was 1 - cond = (bit_off_i32 < end_bit_i32) & is_one - return q_i32, bit_off_i32, hit_end_i32 \ No newline at end of file diff --git a/src/tplr/compression/hybrid.py b/src/tplr/compression/hybrid.py index abc7bcc24..324dd520c 100644 --- a/src/tplr/compression/hybrid.py +++ b/src/tplr/compression/hybrid.py @@ -1,14 +1,12 @@ import math -from typing import Dict -from typing import Tuple, Union +import struct +from typing import Dict, Tuple, Union import numpy as np import torch import triton import triton.language as tl -from .bitops import write_nbits_fast, read_unary_bounded_triton, read_nbits_fast - BytesLike = Union[bytes, bytearray, np.ndarray, torch.Tensor] @@ -71,26 +69,24 @@ def encode_batch_rows( vals = torch.cat( (idx_sorted[:, :1], idx_sorted[:, 1:] - idx_sorted[:, :-1]), dim=1, - ) - - # Cast to int32 for Triton kernels - vals = vals.to(torch.int32) + ).to(torch.int32) # k_rice parameters (log2(C // B)) k_rice_choices = tuple(int(math.log2(C // b)) for b in B_choices) num_B_choices = len(B_choices) k_rice_choices_tensor = torch.tensor(k_rice_choices, dtype=torch.int32, device=dev) - # Row header bits (only used for packing row-table header byte) + # Row header bits B_choice_bits = (num_B_choices - 1).bit_length() - ROW_HEADER_BITS = 1 + B_choice_bits # (best_B_idx << 1) | use_bitmap + ROW_HEADER_BITS = 1 + B_choice_bits # Output tensors for cost kernel costs = torch.empty((num_rows, num_B_choices), dtype=torch.int32, device=dev) is_bitmap = torch.empty((num_rows, num_B_choices), dtype=torch.int8, device=dev) + + # Calculate grid for cost kernel grid = (num_rows,) - # cost kernel: bits required for deltas only (no header bits) cost_kernel[grid]( vals, costs, @@ -105,18 +101,14 @@ def encode_batch_rows( min_costs, best_B_idx = torch.min(costs, dim=1) is_bitmap_choice = torch.gather(is_bitmap, 1, best_B_idx.unsqueeze(1)).squeeze(1).to(torch.int32) - # (1) payload bits per row = bits for deltas only - row_payload_bits = min_costs # (rows,) + # Payload sizing + row_payload_bits = min_costs + row_payload_bytes = ((row_payload_bits + 7) // 8).to(torch.int32) - # (2) payload bytes per row (rounded up) - row_payload_bytes = ((row_payload_bits + 7) // 8).to(torch.int32) # (rows,) - - # ensure fit in uint16 for the row table if torch.any(row_payload_bytes > 0xFFFF): raise ValueError("Row payload length exceeds 65535 bytes; cannot store in uint16.") - # byte offsets within the payload region (no gaps) - # row_byte_offsets[r] = sum_{i> 8) & 0xFF).to(torch.uint8) - - # Only the low ROW_HEADER_BITS bits are meaningful, but we just store the byte. - row_table[:, 2] = (headers_i32 & ((1 << ROW_HEADER_BITS) - 1)).to(torch.uint8) + row_table_flat = torch.empty((num_rows, 3), dtype=torch.uint8, device=dev) + row_table_flat[:, 0] = (lengths_i32 & 0xFF).to(torch.uint8) + row_table_flat[:, 1] = ((lengths_i32 >> 8) & 0xFF).to(torch.uint8) + row_table_flat[:, 2] = (headers_i32 & ((1 << ROW_HEADER_BITS) - 1)).to(torch.uint8) - payload_buf[ - global_header_len_bytes : global_header_len_bytes + row_table_bytes - ] = row_table.view(-1) + payload_buf[global_header_len_bytes: global_header_len_bytes + row_table_bytes] = row_table_flat.view(-1) - # compute bit offsets for each row's payload (no per-row length/header in-band) - row_bit_offsets = (payload_region_start + row_byte_offsets).to(torch.int32) * 8 + # Calculate absolute byte offsets for pack kernel + row_abs_byte_offsets = (payload_region_start + row_byte_offsets).to(torch.int32) - # pack payloads + # Pack payloads pack_kernel[(num_rows,)]( vals, payload_buf, - row_bit_offsets, + row_abs_byte_offsets, best_B_idx.to(torch.int32), - is_bitmap_choice, # int32 0/1 + is_bitmap_choice, k_rice_choices_tensor, num_rows, k_dim=k_dim, ) - # meta + # Meta stats b_counts = torch.bincount(best_B_idx, minlength=len(B_choices)) B_hist = {b: c.item() for b, c in zip(B_choices, b_counts)} total_row_bytes = total_payload_bytes + row_entry_bytes * num_rows @@ -201,44 +188,36 @@ def encode_batch_rows( @triton.jit def cost_kernel( - delta_ptr, # (rows, k_dim) IN - costs_ptr, # (rows, num_B_choices) OUT - is_bitmap_ptr, # (rows, num_B_choices) OUT (bool/int) - k_dim: tl.constexpr, # constexpr for tl.arange + delta_ptr, + costs_ptr, + is_bitmap_ptr, + k_dim: tl.constexpr, num_rows: tl.int32, num_B_choices: tl.int32, - k_rice_choices_ptr, # (num_B_choices,) int32 + k_rice_choices_ptr, ): """ - Calculates the compressed bit cost for each row for each B in B_choices. - One program instance processes one row. - Variant B: first delta encoded with Rice, tail optionally bitmap (q in {0,1}). + Calculates bit cost. One row per program instance. """ row_idx = tl.program_id(0) if row_idx >= num_rows: return - # Lane indices for this row (constexpr width) i = tl.arange(0, k_dim) - - # Load entire row of delta-encoded values into SRAM row_base = row_idx * k_dim delta = tl.load(delta_ptr + row_base + i) delta0 = tl.load(delta_ptr + row_base) b_idx = 0 while b_idx < num_B_choices: - # k_rice and M = 1 << k_rice k_rice = tl.load(k_rice_choices_ptr + b_idx) - # q via shift, r via mask q = delta >> k_rice q0 = delta0 >> k_rice - # Pure Rice cost: sum(q + 1) + k_dim * k_rice rice_cost = tl.sum(q + 1) + k_dim * k_rice - # Bitmap cost: first element full Rice, tail has (1 + k_rice) bits + # Bitmap cost: head is Rice, tail is (1 + k_rice) bitmap_cost = (q0 + 1 + k_rice) + (k_dim - 1) * (1 + k_rice) # Allow bitmap only if tail q are in {0,1} @@ -250,246 +229,205 @@ def cost_kernel( out_offset = row_idx * num_B_choices + b_idx tl.store(costs_ptr + out_offset, min_cost) - # make sure is_bitmap is exactly 0/1 in memory - tl.store( - is_bitmap_ptr + out_offset, - tl.where(use_bitmap, 1, 0).to(tl.int32), - ) + tl.store(is_bitmap_ptr + out_offset, tl.where(use_bitmap, 1, 0).to(tl.int32)) b_idx += 1 @triton.jit def pack_kernel( - delta_ptr, # (rows, k_dim) IN int32 - u8_payload_ptr, # (final_buffer_bytes,) OUT uint8 - row_bit_offsets_ptr, # (rows,) IN int32 (bit offset where payload starts) - best_B_idx_ptr, # (rows,) IN int32 - is_bitmap_ptr, # (rows,) IN int32 (0/1) - k_rice_choices_ptr, # [num_B] IN int32 - num_rows: tl.int32, - k_dim: tl.int32, # dynamic + delta_ptr, # (rows, k_dim) IN int32 + u8_payload_ptr, # OUT uint8 + row_abs_byte_offsets_ptr, # (rows,) IN int32 (byte offset where payload starts) + best_B_idx_ptr, # (rows,) IN + is_bitmap_ptr, # (rows,) IN + k_rice_choices_ptr, # [num_B] IN + num_rows: tl.int32, + k_dim: tl.int32, # dynamic ): """ - Writes only the Rice/bitmap-coded payload bits for each row. - - Each program instance handles one row. Bit order is LSB-first. + Writes payload bits using a 64-bit register accumulator. + Uses unaligned byte stores (split into 4 bytes) to prevent cudaErrorMisalignedAddress. """ row_idx = tl.program_id(0) if row_idx >= num_rows: return - # Per-row meta - bit_off_i32 = tl.load(row_bit_offsets_ptr + row_idx).to(tl.int32) + # Load row params + out_byte_off = tl.load(row_abs_byte_offsets_ptr + row_idx).to(tl.int32) b_idx_i32 = tl.load(best_B_idx_ptr + row_idx).to(tl.int32) use_bitmap_i32 = (tl.load(is_bitmap_ptr + row_idx) & 1).to(tl.int32) - - # params k_rice_i32 = tl.load(k_rice_choices_ptr + b_idx_i32).to(tl.int32) M_i32 = (tl.full((), 1, dtype=tl.int32) << k_rice_i32) - ONE_U32 = tl.full((), 1, dtype=tl.uint32) - ZERO_U32 = tl.full((), 0, dtype=tl.uint32) - ONE_I32 = tl.full((), 1, dtype=tl.int32) - THIRTY_ONE_I32 = tl.full((), 31, dtype=tl.int32) - - base = row_idx * k_dim - - # ---- first delta: ALWAYS full Rice (unary + remainder) ---- - if k_dim > 0: - v0 = tl.load(delta_ptr + base).to(tl.int32) - q0 = (v0 >> k_rice_i32).to(tl.int32) - r0 = (v0 & (M_i32 - 1)).to(tl.int32) - - # q0 ones in chunks of <= 31, then a single 0 - q_left = q0 - while q_left > 0: - chunk = tl.minimum(q_left, THIRTY_ONE_I32) - ones = (ONE_U32 << chunk) - ONE_U32 - bit_off_i32 = write_nbits_fast(u8_payload_ptr, bit_off_i32, ones, chunk) - q_left -= chunk - - # terminating 0 bit - bit_off_i32 = write_nbits_fast(u8_payload_ptr, bit_off_i32, ZERO_U32, ONE_I32) - # remainder - bit_off_i32 = write_nbits_fast( - u8_payload_ptr, bit_off_i32, r0.to(tl.uint32), k_rice_i32 - ) + # Accumulator state + acc_data = tl.full((), 0, dtype=tl.uint64) + acc_bits = tl.full((), 0, dtype=tl.int32) - # ---- tail deltas ---- - i = 1 + # Output pointer (byte-aligned) + out_ptr_base = u8_payload_ptr + out_byte_off + + base_idx = row_idx * k_dim + + i = 0 while i < k_dim: - v = tl.load(delta_ptr + base + i).to(tl.int32) - q = (v >> k_rice_i32).to(tl.int32) - r = (v & (M_i32 - 1)).to(tl.int32) - - # Rice unary only if NOT bitmap - q_left = tl.where(use_bitmap_i32 != 0, tl.full((), 0, dtype=tl.int32), q) - while q_left > 0: - chunk = tl.minimum(q_left, THIRTY_ONE_I32) - ones = (ONE_U32 << chunk) - ONE_U32 - bit_off_i32 = write_nbits_fast(u8_payload_ptr, bit_off_i32, ones, chunk) - q_left -= chunk - - # terminating 0 bit only in full-Rice mode - n_term = tl.where(use_bitmap_i32 != 0, tl.full((), 0, dtype=tl.int32), ONE_I32) - bit_off_i32 = write_nbits_fast(u8_payload_ptr, bit_off_i32, ZERO_U32, n_term) - - # bitmap q only if bitmap - q_bit = tl.where(q > 0, ONE_U32, ZERO_U32) - n_qbit = tl.where(use_bitmap_i32 != 0, ONE_I32, tl.full((), 0, dtype=tl.int32)) - bit_off_i32 = write_nbits_fast(u8_payload_ptr, bit_off_i32, q_bit, n_qbit) - - # remainder always - bit_off_i32 = write_nbits_fast(u8_payload_ptr, bit_off_i32, r.to(tl.uint32), k_rice_i32) + val = tl.load(delta_ptr + base_idx + i).to(tl.int32) + + # Compute q, r + q = (val >> k_rice_i32).to(tl.uint64) + r = (val & (M_i32 - 1)).to(tl.uint64) + + is_rice = (i == 0) | (use_bitmap_i32 == 0) + + if is_rice: + # Rice: q '1's, then '0', then k_rice bits of r + q_count = q.to(tl.int32) + while q_count > 0: + acc_data |= (tl.full((), 1, dtype=tl.uint64) << acc_bits) + acc_bits += 1 + q_count -= 1 + + # Flush Check + if acc_bits >= 32: + # Unaligned Store (4 bytes separately) + val_u32 = acc_data.to(tl.uint32) + tl.store(out_ptr_base + 0, (val_u32 & 0xFF).to(tl.uint8)) + tl.store(out_ptr_base + 1, ((val_u32 >> 8) & 0xFF).to(tl.uint8)) + tl.store(out_ptr_base + 2, ((val_u32 >> 16) & 0xFF).to(tl.uint8)) + tl.store(out_ptr_base + 3, ((val_u32 >> 24) & 0xFF).to(tl.uint8)) + + out_ptr_base += 4 + acc_data >>= 32 + acc_bits -= 32 + + # Append Separator '0' + acc_bits += 1 + else: + # Bitmap: q is 1 bit + q_bit = tl.where(q > 0, 1, 0).to(tl.uint64) + acc_data |= (q_bit << acc_bits) + acc_bits += 1 + + # Flush Check (after separator/bitmap bit) + if acc_bits >= 32: + val_u32 = acc_data.to(tl.uint32) + tl.store(out_ptr_base + 0, (val_u32 & 0xFF).to(tl.uint8)) + tl.store(out_ptr_base + 1, ((val_u32 >> 8) & 0xFF).to(tl.uint8)) + tl.store(out_ptr_base + 2, ((val_u32 >> 16) & 0xFF).to(tl.uint8)) + tl.store(out_ptr_base + 3, ((val_u32 >> 24) & 0xFF).to(tl.uint8)) + + out_ptr_base += 4 + acc_data >>= 32 + acc_bits -= 32 + + # Append Remainder + acc_data |= (r << acc_bits) + acc_bits += k_rice_i32 + + # Flush Check + if acc_bits >= 32: + val_u32 = acc_data.to(tl.uint32) + tl.store(out_ptr_base + 0, (val_u32 & 0xFF).to(tl.uint8)) + tl.store(out_ptr_base + 1, ((val_u32 >> 8) & 0xFF).to(tl.uint8)) + tl.store(out_ptr_base + 2, ((val_u32 >> 16) & 0xFF).to(tl.uint8)) + tl.store(out_ptr_base + 3, ((val_u32 >> 24) & 0xFF).to(tl.uint8)) + + out_ptr_base += 4 + acc_data >>= 32 + acc_bits -= 32 + i += 1 + # Final Flush + while acc_bits > 0: + tl.store(out_ptr_base, (acc_data & 0xFF).to(tl.uint8)) + out_ptr_base += 1 + acc_data >>= 8 + acc_bits -= 8 + @triton.jit def parse_header_kernel( - u8_payload_ptr, # (total_bytes,) uint8 - C_out_ptr, # (1,) int32 - K_out_ptr, # (1,) int32 - R_out_ptr, # (1,) int32 NEW: num_rows - num_B_out_ptr, # (1,) int32 - B_choices_out_ptr, # (MAX_B_CHOICES,) int32 - header_bytes_out_ptr, # (1,) int32 - error_flag_ptr, # (1,) int32 - total_bytes: tl.int32, - MAX_B_CHOICES: tl.constexpr, + u8_payload_ptr, # (total_bytes,) uint8 + C_out_ptr, # (1,) int32 + K_out_ptr, # (1,) int32 + R_out_ptr, # (1,) int32 + num_B_out_ptr, # (1,) int32 + B_choices_out_ptr, # (MAX_B_CHOICES,) int32 + header_bytes_out_ptr, # (1,) int32 + max_num_B: tl.constexpr, ): """ - Parse the global header entirely on GPU. - Layout: - 0..3 : "CGRP" - 4..7 : C (uint32 LE) - 8..9 : K (uint16 LE) - 10..13 : R (uint32 LE, num_rows) - 14 : num_B (uint8) - 15.. : B_choices (num_B * 2 bytes, uint16 LE) + Simple GPU kernel to parse the global header. + Replaces CPU struct.unpack to avoid CPU<->GPU synchronization. """ - pid = tl.program_id(0) if pid != 0: return - # ---- init outputs / error ---- - C_val = tl.full((), 0, dtype=tl.int32) - K_val = tl.full((), 0, dtype=tl.int32) - R_val = tl.full((), 0, dtype=tl.int32) - num_B_val = tl.full((), 0, dtype=tl.int32) - header_bytes_i32 = tl.full((), 0, dtype=tl.int32) - err = tl.full((), 0, dtype=tl.int32) - - # ---- basic size + magic checks ---- - # Minimum header size: 15 bytes (without B_choices) - if total_bytes < 15: - err = 1 - else: - # Magic "CGRP" = [67, 71, 82, 80] - m0 = tl.load(u8_payload_ptr + 0) - m1 = tl.load(u8_payload_ptr + 1) - m2 = tl.load(u8_payload_ptr + 2) - m3 = tl.load(u8_payload_ptr + 3) - cond_magic = (m0 == 67) & (m1 == 71) & (m2 == 82) & (m3 == 80) - bad_magic = cond_magic == 0 - err = tl.where(bad_magic, tl.full((), 2, dtype=tl.int32), err) - - # ---- C, K, R, num_B ---- - if err == 0: - # C (uint32 LE at bytes 4..7) - b4 = tl.load(u8_payload_ptr + 4).to(tl.int32) - b5 = tl.load(u8_payload_ptr + 5).to(tl.int32) - b6 = tl.load(u8_payload_ptr + 6).to(tl.int32) - b7 = tl.load(u8_payload_ptr + 7).to(tl.int32) - C_val = b4 | (b5 << 8) | (b6 << 16) | (b7 << 24) - - # K (uint16 LE at bytes 8..9) - b8 = tl.load(u8_payload_ptr + 8).to(tl.int32) - b9 = tl.load(u8_payload_ptr + 9).to(tl.int32) - K_val = b8 | (b9 << 8) - - # R (uint32 LE at bytes 10..13) - b10 = tl.load(u8_payload_ptr + 10).to(tl.int32) - b11 = tl.load(u8_payload_ptr + 11).to(tl.int32) - b12 = tl.load(u8_payload_ptr + 12).to(tl.int32) - b13 = tl.load(u8_payload_ptr + 13).to(tl.int32) - R_val = b10 | (b11 << 8) | (b12 << 16) | (b13 << 24) - - # num_B at byte 14 - num_B_val = tl.load(u8_payload_ptr + 14).to(tl.int32) - invalid_num_B = (num_B_val <= 0) | (num_B_val > MAX_B_CHOICES) - err = tl.where(invalid_num_B, tl.full((), 3, dtype=tl.int32), err) - - # ---- read B_choices in a structured loop (no break/return) ---- - off = tl.full((), 15, dtype=tl.int32) # B_choices start at byte 15 - i = tl.full((), 0, dtype=tl.int32) - - while i < MAX_B_CHOICES: - need_this = (i < num_B_val) & (err == 0) - - if need_this: - cond_in_bounds = (off + 1) < total_bytes - if cond_in_bounds: - lo = tl.load(u8_payload_ptr + off).to(tl.int32) - hi = tl.load(u8_payload_ptr + off + 1).to(tl.int32) - B_val = lo | (hi << 8) - tl.store(B_choices_out_ptr + i, B_val) - off += 2 - else: - err = tl.full((), 4, dtype=tl.int32) - tl.store(B_choices_out_ptr + i, tl.full((), 0, dtype=tl.int32)) - else: - tl.store(B_choices_out_ptr + i, tl.full((), 0, dtype=tl.int32)) - - i += 1 - - # header_bytes = 15 + 2 * num_B (only meaningful if err == 0) - if err == 0: - header_bytes_i32 = 15 + (num_B_val * 2) - - # ---- store outputs ---- + # C (uint32 LE at bytes 4..7) + b4 = tl.load(u8_payload_ptr + 4).to(tl.int32) + b5 = tl.load(u8_payload_ptr + 5).to(tl.int32) + b6 = tl.load(u8_payload_ptr + 6).to(tl.int32) + b7 = tl.load(u8_payload_ptr + 7).to(tl.int32) + C_val = b4 | (b5 << 8) | (b6 << 16) | (b7 << 24) tl.store(C_out_ptr, C_val) + + # K (uint16 LE at bytes 8..9) + b8 = tl.load(u8_payload_ptr + 8).to(tl.int32) + b9 = tl.load(u8_payload_ptr + 9).to(tl.int32) + K_val = b8 | (b9 << 8) tl.store(K_out_ptr, K_val) + + # R (uint32 LE at bytes 10..13) + b10 = tl.load(u8_payload_ptr + 10).to(tl.int32) + b11 = tl.load(u8_payload_ptr + 11).to(tl.int32) + b12 = tl.load(u8_payload_ptr + 12).to(tl.int32) + b13 = tl.load(u8_payload_ptr + 13).to(tl.int32) + R_val = b10 | (b11 << 8) | (b12 << 16) | (b13 << 24) tl.store(R_out_ptr, R_val) + + # num_B at byte 14 + num_B_val = tl.load(u8_payload_ptr + 14).to(tl.int32) tl.store(num_B_out_ptr, num_B_val) - tl.store(header_bytes_out_ptr, header_bytes_i32) - tl.store(error_flag_ptr, err) + + # Read B_choices (start at 15) + off = 15 + i = 0 + while i < max_num_B: + if i < num_B_val: + lo = tl.load(u8_payload_ptr + off).to(tl.int32) + hi = tl.load(u8_payload_ptr + off + 1).to(tl.int32) + B_val = lo | (hi << 8) + tl.store(B_choices_out_ptr + i, B_val) + off += 2 + i += 1 + + tl.store(header_bytes_out_ptr, off) @triton.jit def parse_row_table_kernel( - u8_payload_ptr, # (total_bytes,) uint8 - row_payload_bytes_ptr, # (num_rows,) int32 - best_B_idx_ptr, # (num_rows,) int32 - use_bitmap_ptr, # (num_rows,) int32 - row_table_start: tl.int32, - num_rows: tl.int32, - ROW_HEADER_BITS: tl.constexpr, + u8_payload_ptr, + row_payload_bytes_ptr, + best_B_idx_ptr, + use_bitmap_ptr, + row_table_start_ptr, # int32* from header kernel + num_rows_ptr, # int32* from header kernel + ROW_HEADER_BITS: tl.constexpr, ): - """ - Parse the row table: - - For each row r: - offset = row_table_start + r * 3 - length_bytes[r] = uint16 LE at offset - header_byte = uint8 at offset + 2 - header_bits = header_byte & ((1 << ROW_HEADER_BITS) - 1) - use_bitmap[r] = header_bits & 1 - best_B_idx[r] = header_bits >> 1 - """ pid = tl.program_id(0) + num_rows = tl.load(num_rows_ptr) if pid >= num_rows: return + row_table_start = tl.load(row_table_start_ptr) entry_offset = row_table_start + pid * 3 - # length_bytes: uint16 LE b0 = tl.load(u8_payload_ptr + entry_offset).to(tl.int32) b1 = tl.load(u8_payload_ptr + entry_offset + 1).to(tl.int32) length_i32 = b0 | (b1 << 8) tl.store(row_payload_bytes_ptr + pid, length_i32) - # header byte header_byte = tl.load(u8_payload_ptr + entry_offset + 2).to(tl.int32) header_mask = (tl.full((), 1, dtype=tl.int32) << ROW_HEADER_BITS) - 1 header_i32 = header_byte & header_mask @@ -501,144 +439,155 @@ def parse_row_table_kernel( tl.store(best_B_idx_ptr + pid, best_B_idx_i32) - @triton.jit def decode_rows_kernel( - u8_payload_ptr, # (total_bytes,) uint8 - out_vals_ptr, # (num_rows * K,) int32 - row_bit_offsets_ptr, # (num_rows,) int32 (bit offset of first encoded bit) - row_payload_bytes_ptr, # (num_rows,) int32 - best_B_idx_ptr, # (num_rows,) int32 - use_bitmap_ptr, # (num_rows,) int32 - k_rice_choices_ptr, # (num_B,) int32 - num_rows: tl.int32, - K: tl.int32, + u8_payload_ptr, + out_vals_ptr, + row_byte_offsets_ptr, + best_B_idx_ptr, + use_bitmap_ptr, + k_rice_choices_ptr, + num_rows_ptr, + K_ptr, + total_payload_bytes: tl.int32, # kept for signature compatibility (unused) ): """ - Fully GPU decode of Rice/bitmap rows. - - For each row r: - - Bit range: - start_bit = row_bit_offsets[r] - end_bit = start_bit + row_payload_bytes[r] * 8 - - First value: full Rice (unary + remainder) - - Tail: Rice or bitmap+remainder depending on use_bitmap[r]. + Decodes each row's Rice/bitmap bitstream into *final* prefix-summed values. + + Uses a simple streaming bit-buffer over bytes: + - bitbuf: uint64 (low bits are next bits in the stream) + - bits_in_buf: int (# of valid bits in bitbuf) + - byte_offset: int (index into u8_payload_ptr) """ row_idx = tl.program_id(0) + num_rows = tl.load(num_rows_ptr) if row_idx >= num_rows: return - # Per-row metadata - row_start_bit_i32 = tl.load(row_bit_offsets_ptr + row_idx).to(tl.int32) - payload_bytes_i32 = tl.load(row_payload_bytes_ptr + row_idx).to(tl.int32) - best_B_idx_i32 = tl.load(best_B_idx_ptr + row_idx).to(tl.int32) - use_bitmap_i32 = (tl.load(use_bitmap_ptr + row_idx) & 1).to(tl.int32) + K = tl.load(K_ptr) - # k_rice and M for this row - k_rice_i32 = tl.load(k_rice_choices_ptr + best_B_idx_i32).to(tl.int32) - M_i32 = (tl.full((), 1, dtype=tl.int32) << k_rice_i32) + # Row params + start_byte = tl.load(row_byte_offsets_ptr + row_idx).to(tl.int32) + b_idx = tl.load(best_B_idx_ptr + row_idx).to(tl.int32) + use_bitmap = (tl.load(use_bitmap_ptr + row_idx) & 1).to(tl.int32) - # Bit range of this row - row_end_bit_i32 = row_start_bit_i32 + payload_bytes_i32 * 8 + k_rice = tl.load(k_rice_choices_ptr + b_idx).to(tl.int32) + M = (tl.full((), 1, dtype=tl.int32) << k_rice) - base_out = row_idx * K - ONE_I32 = tl.full((), 1, dtype=tl.int32) + # Streaming bit-buffer state + byte_offset = start_byte + bitbuf = tl.full((), 0, dtype=tl.uint64) + bits_in_buf = tl.full((), 0, dtype=tl.int32) - bit_off_i32 = row_start_bit_i32 - - # ---- first value: ALWAYS full Rice ---- - if K > 0: - q0_i32, bit_off_i32, hit_end0_i32 = read_unary_bounded_triton( - u8_payload_ptr, - bit_off_i32, - row_end_bit_i32, - ) - r0_u32, bit_off_i32 = read_nbits_fast( - u8_payload_ptr, - bit_off_i32, - k_rice_i32, - row_end_bit_i32, # limit - ) - r0_i32 = r0_u32.to(tl.int32) - v0_i32 = q0_i32 * M_i32 + r0_i32 - tl.store(out_vals_ptr + base_out, v0_i32) + base_out = row_idx * K + prev = tl.full((), 0, dtype=tl.int32) - # ---- tail values ---- - i = tl.full((), 1, dtype=tl.int32) + i = 0 while i < K: - if use_bitmap_i32 != 0: - # Bitmap mode: q is 1 bit in {0,1} - q_bit_u32, bit_off_i32 = read_nbits_fast( - u8_payload_ptr, - bit_off_i32, - ONE_I32, - row_end_bit_i32, - ) - q_i32 = q_bit_u32.to(tl.int32) - - r_u32, bit_off_i32 = read_nbits_fast( - u8_payload_ptr, - bit_off_i32, - k_rice_i32, - row_end_bit_i32, - ) - r_i32 = r_u32.to(tl.int32) + is_rice = (i == 0) | (use_bitmap == 0) + + # --- Decode q --- + q = tl.full((), 0, dtype=tl.int32) + + if is_rice: + # Unary code: q times '1' then a '0' + reading = 1 + while reading > 0: + # Ensure at least one bit in buffer + if bits_in_buf == 0: + next_byte = tl.load(u8_payload_ptr + byte_offset).to(tl.uint64) + byte_offset += 1 + bitbuf |= (next_byte << bits_in_buf) + bits_in_buf += 8 + + bit = (bitbuf & 1).to(tl.int32) + bitbuf >>= 1 + bits_in_buf -= 1 + + if bit == 1: + q += 1 + else: + reading = 0 else: - # Full Rice mode - q_i32, bit_off_i32, hit_end_i32 = read_unary_bounded_triton( - u8_payload_ptr, - bit_off_i32, - row_end_bit_i32, - ) - r_u32, bit_off_i32 = read_nbits_fast( - u8_payload_ptr, - bit_off_i32, - k_rice_i32, - row_end_bit_i32, - ) - r_i32 = r_u32.to(tl.int32) - - v_i32 = q_i32 * M_i32 + r_i32 - tl.store(out_vals_ptr + base_out + i, v_i32) + # Bitmap mode: q is 1 bit + if bits_in_buf == 0: + next_byte = tl.load(u8_payload_ptr + byte_offset).to(tl.uint64) + byte_offset += 1 + bitbuf |= (next_byte << bits_in_buf) + bits_in_buf += 8 + q = (bitbuf & 1).to(tl.int32) + bitbuf >>= 1 + bits_in_buf -= 1 + + # --- Decode remainder r (k_rice bits) --- + r = tl.full((), 0, dtype=tl.int32) + if k_rice > 0: + # Ensure enough bits for r + while bits_in_buf < k_rice: + next_byte = tl.load(u8_payload_ptr + byte_offset).to(tl.uint64) + byte_offset += 1 + bitbuf |= (next_byte << bits_in_buf) + bits_in_buf += 8 + + mask = (tl.full((), 1, dtype=tl.uint64) << k_rice) - 1 + r_u64 = (bitbuf & mask) + bitbuf >>= k_rice + bits_in_buf -= k_rice + r = r_u64.to(tl.int32) + + val = q * M + r + + # In-kernel prefix sum (delta decode) + prev += val + tl.store(out_vals_ptr + base_out + i, prev) + i += 1 def decode_batch_rows( - payload: BytesLike, - max_num_B: int = 16, + payload: BytesLike, + max_num_B: int = 16, ) -> tuple[torch.Tensor, int, int]: + """ + Decode a payload produced by encode_batch_rows. + Returns: + idx (torch.Tensor[int64]): (rows, K) sorted indices + C (int): vocabulary size parameter + R (int): number of rows + """ if not torch.cuda.is_available(): - raise RuntimeError("decode_batch_rows_gpu requires CUDA") + raise RuntimeError("CUDA required") - # --- Move payload to CUDA (if needed) --- + # Move to GPU/Tensor if isinstance(payload, torch.Tensor): - assert payload.dtype == torch.uint8 payload_gpu = payload if payload.is_cuda else payload.cuda() elif isinstance(payload, np.ndarray): - assert payload.dtype == np.uint8 payload_gpu = torch.from_numpy(payload).to("cuda", dtype=torch.uint8) - elif isinstance(payload, (bytes, bytearray)): + else: arr = np.frombuffer(bytes(payload), dtype=np.uint8) payload_gpu = torch.from_numpy(arr).to("cuda", dtype=torch.uint8) - else: - raise TypeError("Unsupported payload type") payload_gpu = payload_gpu.contiguous() dev = payload_gpu.device total_bytes = int(payload_gpu.numel()) + if total_bytes == 0: - empty = torch.empty((0, 0), dtype=torch.int64, device=dev) - return empty, 0, 0 + return torch.empty((0, 0), dtype=torch.int64, device=dev), 0, 0 - # --- 1) Parse global header on GPU --- + # Pad payload on GPU with a few zero bytes to make safe over-reads trivial + padded = torch.zeros(total_bytes + 8, dtype=torch.uint8, device=dev) + padded[:total_bytes].copy_(payload_gpu) + payload_gpu = padded + total_bytes_padded = int(payload_gpu.numel()) + + # --- 1) Parse Header (GPU Kernel) --- C_out = torch.empty(1, dtype=torch.int32, device=dev) K_out = torch.empty(1, dtype=torch.int32, device=dev) R_out = torch.empty(1, dtype=torch.int32, device=dev) num_B_out = torch.empty(1, dtype=torch.int32, device=dev) B_choices_out = torch.empty(max_num_B, dtype=torch.int32, device=dev) header_bytes_out = torch.empty(1, dtype=torch.int32, device=dev) - err_flag = torch.zeros(1, dtype=torch.int32, device=dev) parse_header_kernel[(1,)]( payload_gpu, @@ -648,43 +597,25 @@ def decode_batch_rows( num_B_out, B_choices_out, header_bytes_out, - err_flag, - total_bytes, - MAX_B_CHOICES=max_num_B, + max_num_B=max_num_B ) - torch.cuda.synchronize() - err = int(err_flag.cpu().item()) - if err != 0: - raise ValueError(f"parse_header_kernel failed with error code {err}") - - C = int(C_out.cpu().item()) - K = int(K_out.cpu().item()) - num_rows = int(R_out.cpu().item()) - num_B = int(num_B_out.cpu().item()) - header_bytes = int(header_bytes_out.cpu().item()) - B_choices_list = [int(x) for x in B_choices_out[:num_B].cpu().tolist()] + # Minimal sync to get scalar values needed for kernel setup + C = int(C_out.item()) + num_B = int(num_B_out.item()) + num_rows = int(R_out.item()) + B_choices_list = B_choices_out[:num_B].cpu().tolist() - # --- 2) Build k_rice choices on CPU -> move to GPU --- + # --- 2) Prepare k_rice --- k_rice_choices = [] for B in B_choices_list: M = C // B - if M <= 0 or (M & (M - 1)) != 0: - raise ValueError(f"M=C//B={M} not power of two for B={B}") k_rice_choices.append(int(math.log2(M))) - k_rice_choices_tensor = torch.tensor( - k_rice_choices, dtype=torch.int32, device=dev - ) - - B_choice_bits = (num_B - 1).bit_length() - ROW_HEADER_BITS = 1 + B_choice_bits + k_rice_choices_tensor = torch.tensor(k_rice_choices, dtype=torch.int32, device=dev) - # --- 3) Parse row table on GPU --- - row_entry_bytes = 3 - row_table_bytes = num_rows * row_entry_bytes - if header_bytes + row_table_bytes > total_bytes: - raise ValueError("Truncated payload: row table exceeds payload length") + ROW_HEADER_BITS = 1 + (num_B - 1).bit_length() + # --- 3) Parse Row Table (GPU) --- row_payload_bytes = torch.empty(num_rows, dtype=torch.int32, device=dev) best_B_idx = torch.empty(num_rows, dtype=torch.int32, device=dev) use_bitmap = torch.empty(num_rows, dtype=torch.int32, device=dev) @@ -694,47 +625,36 @@ def decode_batch_rows( row_payload_bytes, best_B_idx, use_bitmap, - int(header_bytes), - int(num_rows), + header_bytes_out, + R_out, ROW_HEADER_BITS=ROW_HEADER_BITS, ) - # --- 4) Compute per-row bit offsets into payload region --- - payload_region_start_byte = header_bytes + row_table_bytes - if payload_region_start_byte > total_bytes: - raise ValueError("Truncated payload: missing payload region") + # --- 4) Offsets (all on GPU) --- + # Row table starts immediately after global header + payload_region_start = header_bytes_out + (R_out * 3) # in BYTES, on GPU tensor - # byte offsets within the payload region row_payload_bytes_64 = row_payload_bytes.to(torch.int64) - row_byte_offsets = torch.cumsum(row_payload_bytes_64, dim=0) - row_payload_bytes_64 + # Exclusive prefix sum for offsets + row_byte_offsets_rel = torch.cumsum(row_payload_bytes_64, dim=0) - row_payload_bytes_64 + # Absolute byte offsets + row_byte_offsets = (payload_region_start + row_byte_offsets_rel).to(torch.int32) - # Sanity check: last row must end within the buffer - last_end = int( - payload_region_start_byte - + row_byte_offsets[-1].item() - + row_payload_bytes_64[-1].item() - ) - if last_end > total_bytes: - raise ValueError("Truncated payload: row payload bytes exceed buffer length") - - row_bit_offsets = ( - payload_region_start_byte + row_byte_offsets - ).to(torch.int32) * 8 - - # --- 5) Decode rows in parallel on GPU --- + # --- 5) Decode (GPU Optimized) --- + K = int(K_out.item()) out_vals = torch.empty((num_rows, K), dtype=torch.int32, device=dev) + decode_rows_kernel[(num_rows,)]( payload_gpu, out_vals, - row_bit_offsets, - row_payload_bytes, + row_byte_offsets, best_B_idx, use_bitmap, k_rice_choices_tensor, - int(num_rows), - int(K), + R_out, + K_out, + total_bytes_padded ) - out_vals = torch.cumsum(out_vals, dim=1) - return out_vals.to(torch.int64), C, num_rows - + # No host-side cumsum here: kernel already returns prefix sums + return out_vals.to(torch.int64), C, num_rows \ No newline at end of file diff --git a/src/tplr/neurons.py b/src/tplr/neurons.py index 410a18268..39d72b5d3 100644 --- a/src/tplr/neurons.py +++ b/src/tplr/neurons.py @@ -111,7 +111,6 @@ def barrier(group=None): param.device, non_blocking=True ) compression_time = 0 - copy_time = 0 encode_time = 0 for _, (n, p) in enumerate(model_iterator, 1): owned = n in miner.owned_params @@ -183,7 +182,6 @@ def barrier(group=None): # --- 7) Pack outputs (move compressed artifacts to CPU asynchronously) --- # Using non_blocking=True for async D2H transfers when CUDA is available - copy_start = tplr.T() if isinstance(idxs, torch.Tensor): if torch.cuda.is_available(): cpu_idxs = torch.empty_like(idxs, device="cpu", pin_memory=True) @@ -209,9 +207,8 @@ def barrier(group=None): # Clear per-param grad p.grad = None - copy_time += tplr.T() - copy_start - tplr.logger.info(f"times: {encode_time}, {compression_time}, {copy_time}") + tplr.logger.info(f"times: {encode_time}, {compression_time}") # Batch offload all error feedback tensors to CPU with pinned memory for name in miner.error_feedback: From db49bb334a102fb53754aa32f704c79292593504 Mon Sep 17 00:00:00 2001 From: Kasper Date: Wed, 19 Nov 2025 16:22:46 +0400 Subject: [PATCH 24/33] log exception --- src/tplr/compress.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/tplr/compress.py b/src/tplr/compress.py index 458097634..4d4746388 100644 --- a/src/tplr/compress.py +++ b/src/tplr/compress.py @@ -502,6 +502,7 @@ def batch_decompress( rows, dtype=torch.int64, device=p.device ).view(*v_data.shape) except Exception as e: + tplr.logger.warning(f"Failed to unpack: {e} Falling back to legacy uncompress.") # Fallback: likely old format -> try legacy decoder idx_unpacked = unpack_12bit_indices(i_data.to(p.device), v_data.shape) From 8a77ca186e590440ee453f401c6e5ec366879086 Mon Sep 17 00:00:00 2001 From: Kasper Date: Wed, 19 Nov 2025 19:21:13 +0400 Subject: [PATCH 25/33] cleanup --- src/tplr/compression/hybrid.py | 11 ++++++----- tests/unit/test_compress.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/src/tplr/compression/hybrid.py b/src/tplr/compression/hybrid.py index 324dd520c..bf4dae76b 100644 --- a/src/tplr/compression/hybrid.py +++ b/src/tplr/compression/hybrid.py @@ -358,7 +358,6 @@ def parse_header_kernel( ): """ Simple GPU kernel to parse the global header. - Replaces CPU struct.unpack to avoid CPU<->GPU synchronization. """ pid = tl.program_id(0) if pid != 0: @@ -449,7 +448,6 @@ def decode_rows_kernel( k_rice_choices_ptr, num_rows_ptr, K_ptr, - total_payload_bytes: tl.int32, # kept for signature compatibility (unused) ): """ Decodes each row's Rice/bitmap bitstream into *final* prefix-summed values. @@ -574,12 +572,16 @@ def decode_batch_rows( if total_bytes == 0: return torch.empty((0, 0), dtype=torch.int64, device=dev), 0, 0 + if total_bytes < 15: + raise ValueError("Malformed payload - too few bytes") + magic = bytes(payload_gpu[:4].cpu().tolist()) + if magic != b"CGRP": + raise ValueError("Invalid magic header") # Pad payload on GPU with a few zero bytes to make safe over-reads trivial padded = torch.zeros(total_bytes + 8, dtype=torch.uint8, device=dev) padded[:total_bytes].copy_(payload_gpu) payload_gpu = padded - total_bytes_padded = int(payload_gpu.numel()) # --- 1) Parse Header (GPU Kernel) --- C_out = torch.empty(1, dtype=torch.int32, device=dev) @@ -652,8 +654,7 @@ def decode_batch_rows( use_bitmap, k_rice_choices_tensor, R_out, - K_out, - total_bytes_padded + K_out ) # No host-side cumsum here: kernel already returns prefix sums diff --git a/tests/unit/test_compress.py b/tests/unit/test_compress.py index 934697e33..4f655ae91 100644 --- a/tests/unit/test_compress.py +++ b/tests/unit/test_compress.py @@ -247,6 +247,38 @@ def test_batch_decompress_with_norm_options( ) assert result_with_norms.shape == xshape + def test_batch_decompress_with_legacy_packing( + self, compress_instance: TopKCompressor[Literal[False]] + ): + p = torch.zeros(8, 128) # 1024 elements total, last dim=128 + xshape = (8, 128) + totalk = 128 + + # Create test data with Rice/bitmap encoded format + idx_orig = torch.tensor([[0, 1, 2, 3]], dtype=torch.int64) # Even count + idx_packed = pack_12bit_indices(idx_orig) + idx = [idx_packed] + val = [torch.tensor([[10.0, 20.0, 30.0, 40.0]], dtype=torch.float32)] + + # Test with normalisation + result_norm = compress_instance.batch_decompress( + p, idx, val, xshape, totalk, normalise=True + ) + assert result_norm.shape == xshape + + # Test with clip_norm + result_clip = compress_instance.batch_decompress( + p, idx, val, xshape, totalk, clip_norm=True + ) + assert result_clip.shape == xshape + + # Test with block_norms provided + block_norms = torch.tensor([15.0]) + result_with_norms = compress_instance.batch_decompress( + p, idx, val, xshape, totalk, block_norms=block_norms, clip_norm=True + ) + assert result_with_norms.shape == xshape + class TestChunkingTransformer: """Test ChunkingTransformer using actual implementation""" From 044c89b5b41ff74cd7458e7ad6f2d7c51cbdce5b Mon Sep 17 00:00:00 2001 From: Kasper Date: Fri, 21 Nov 2025 12:08:27 +0100 Subject: [PATCH 26/33] =?UTF-8?q?remove=20GPU=E2=86=92CPU=E2=86=92GPU=20co?= =?UTF-8?q?pies=20in=20decompress=20and=20batch=5Fdecompress?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/tplr/compress.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/tplr/compress.py b/src/tplr/compress.py index 4d4746388..dc629fdee 100644 --- a/src/tplr/compress.py +++ b/src/tplr/compress.py @@ -372,15 +372,15 @@ def decompress( # Decode indices if idx.dtype == torch.uint8: - rows_list, C, _N = decode_batch_rows(idx) + rows, C, _N = decode_batch_rows(idx) if C != totalk: raise ValueError(f"Index payload C={C} but expected {totalk}") - k = val.shape[-1] - if any(len(r) != k for r in rows_list): - raise ValueError("Row-wise topk size mismatch in index payload") - idx_int64 = torch.tensor( - rows_list, dtype=torch.int64, device=p.device - ).view(*val.shape) + if rows.shape[-1] != val.shape[-1]: + raise ValueError( + f"Row-wise topk size mismatch: decoded K={rows.shape[-1]}, val K={val.shape[-1]}" + ) + idx_int64 = rows.to(device=p.device, dtype=torch.int64) + idx_int64 = idx_int64.view_as(val) elif idx.dtype in (torch.int64, torch.long): idx_int64 = idx.to(p.device) else: @@ -494,13 +494,13 @@ def batch_decompress( rows, C, _N = decode_batch_rows(i_data) if C != totalk: raise ValueError(f"Index payload C={C} but expected {totalk}") - if any(len(r) != v_data.shape[-1] for r in rows): + if rows.shape[-1] != v_data.shape[-1]: raise ValueError( - "Row-wise topk size mismatch in index payload (batch)" + f"Row-wise topk size mismatch: decoded K={rows.shape[-1]}, " + f"val K={v_data.shape[-1]}" ) - idx_unpacked = torch.tensor( - rows, dtype=torch.int64, device=p.device - ).view(*v_data.shape) + idx_int64 = rows.to(device=p.device, dtype=torch.int64) + idx_unpacked = idx_int64.view_as(v_data) except Exception as e: tplr.logger.warning(f"Failed to unpack: {e} Falling back to legacy uncompress.") # Fallback: likely old format -> try legacy decoder From 8577338a29ce1a6ab83b697ddb5348766bf4eff3 Mon Sep 17 00:00:00 2001 From: Kasper Date: Fri, 21 Nov 2025 15:02:35 +0100 Subject: [PATCH 27/33] update decompression checks in comms and cleanup --- src/tplr/comms.py | 80 ++++++++++++++++++++++++-------------------- src/tplr/compress.py | 19 ++++++----- 2 files changed, 55 insertions(+), 44 deletions(-) diff --git a/src/tplr/comms.py b/src/tplr/comms.py index 9b90c2c9e..961924b33 100644 --- a/src/tplr/comms.py +++ b/src/tplr/comms.py @@ -2517,46 +2517,54 @@ def check_compressed_indices( if idxs.numel() == 0: raise ValueError(f"[{param_name}] Empty indices payload") - # Decode (CPU) and perform structural checks try: - payload_bytes = idxs.detach().cpu().numpy().tobytes() - rows_list, C, N = decode_batch_rows(payload_bytes) - except Exception as e: - raise ValueError(f"[{param_name}] Failed to decode indices payload: {e}") - - if C != totalk: - raise ValueError( - f"[{param_name}] Payload column size C={C} but expected {totalk}" - ) + rows, C, N = decode_batch_rows(idxs) + if C != totalk: + raise ValueError( + f"[{param_name}] Payload column size C={C} but expected {totalk}" + ) + # compute expected rows from values shape (flatten all but last dim) + if vals.ndim == 0: + raise ValueError(f"[{param_name}] Values tensor has no top‑k dimension") + expected_rows = int(np.prod(vals.shape[:-1])) if vals.ndim > 1 else 1 + if N != expected_rows: + raise ValueError( + f"[{param_name}] Payload rows N={N} but values imply {expected_rows}" + ) - # compute expected rows from values shape (flatten all but last dim) - if vals.ndim == 0: - raise ValueError(f"[{param_name}] Values tensor has no top‑k dimension") - expected_rows = int(np.prod(vals.shape[:-1])) if vals.ndim > 1 else 1 - if N != expected_rows: - raise ValueError( - f"[{param_name}] Payload rows N={N} but values imply {expected_rows}" - ) + k = vals.shape[-1] + if k != allowed_topk: + raise ValueError( + f"[{param_name}] Payload K={rows.shape[-1]} but values top-k={k}" + ) - k = vals.shape[-1] - if k != allowed_topk: - raise ValueError( - f"[{param_name}] Values top‑k={k} but allowed_topk={allowed_topk}" - ) - if any(len(r) != k for r in rows_list): - raise ValueError( - f"[{param_name}] At least one row has mismatched top‑k size" - ) + # Bounds check + if rows.numel() > 0: + min_idx = int(rows.min().item()) + max_idx = int(rows.max().item()) + else: + min_idx, max_idx = 0, -1 + if min_idx < 0 or max_idx >= totalk: + raise ValueError( + f"[{param_name}] Index out of bounds (min={min_idx}, max={max_idx}, totalk={totalk})" + ) + except ValueError as e: + # NB: legacy path + tplr.logger.warning(f"Failed to unpack: {e} Falling back to legacy uncompress.") + # Fallback: likely old format -> try legacy decoder + try: + unpacked = unpack_12bit_indices(idxs, vals.shape) + if unpacked.shape[-1] != allowed_topk: + raise ValueError( + f"[{param_name}] Invalid topk dimension: " + f"shape[-1]={unpacked.shape[-1]} but expected {allowed_topk}" + ) + _bounds_check(unpacked) + except Exception as e: + raise ValueError(f"[{param_name}] Failed to unpack legacy 12-bit indices: {e}") + except Exception as e: + raise ValueError(f"[{param_name}] Failed to decode indices payload: {e}") - # bounds check without materialising full tensor - max_idx = max((max(r) if len(r) > 0 else -1) for r in rows_list) - min_idx = ( - min((min(r) if len(r) > 0 else 0) for r in rows_list) if rows_list else 0 - ) - if min_idx < 0 or max_idx >= totalk: - raise ValueError( - f"[{param_name}] Index out of bounds (min={min_idx}, max={max_idx}, totalk={totalk})" - ) async def s3_get_object_size(self, bucket: Bucket, key: str) -> int | None: """ diff --git a/src/tplr/compress.py b/src/tplr/compress.py index dc629fdee..a6fe1014c 100644 --- a/src/tplr/compress.py +++ b/src/tplr/compress.py @@ -20,7 +20,6 @@ # Global imports -import numpy as np import math from typing import Generic, Literal, Sequence, TypeAlias, TypeVar, cast, overload @@ -317,14 +316,17 @@ def compress(self, x: torch.Tensor, topk: int): # type: ignore[override] totalk = x.shape[-1] topk = self._clamp_topk(x, topk) - idx_int64 = torch.topk( + topk_vals, idx = torch.topk( x.abs(), k=topk, dim=-1, largest=True, sorted=False - ).indices - val = torch.gather(x, dim=-1, index=idx_int64) + ) + del topk_vals + + idx = idx.to(torch.int32) + val = torch.gather(x, dim=-1, index=idx) # Flatten to [rows, k] for the codec - idx2d = idx_int64.reshape(-1, topk).to(torch.int32) - val2d = val.reshape(-1, topk) + idx2d = idx.view(-1, topk) + val2d = val.view(-1, topk) # sort indices and apply same perm to values idx_sorted, perm = torch.sort(idx2d, dim=1) @@ -501,8 +503,9 @@ def batch_decompress( ) idx_int64 = rows.to(device=p.device, dtype=torch.int64) idx_unpacked = idx_int64.view_as(v_data) - except Exception as e: - tplr.logger.warning(f"Failed to unpack: {e} Falling back to legacy uncompress.") + except ValueError as e: + # NB: legacy path + tplr.logger.warning(f"Failed to unpack: {e}. Falling back to legacy uncompress.") # Fallback: likely old format -> try legacy decoder idx_unpacked = unpack_12bit_indices(i_data.to(p.device), v_data.shape) From 436b2eb8540ded2a34d7e8ff6400cd68931cef1d Mon Sep 17 00:00:00 2001 From: Kasper Date: Fri, 21 Nov 2025 15:15:25 +0100 Subject: [PATCH 28/33] make b choices a param --- hparams/hparams.json | 1 + neurons/miner.py | 1 + neurons/validator.py | 1 + scripts/analyser.py | 1 + src/tplr/compress.py | 26 +++++++++++++++++++++++--- 5 files changed, 27 insertions(+), 3 deletions(-) diff --git a/hparams/hparams.json b/hparams/hparams.json index 6583e9ce0..8485f7f8c 100644 --- a/hparams/hparams.json +++ b/hparams/hparams.json @@ -46,6 +46,7 @@ "checkpoint_init_version": "2.1.15", "checkpoint_init_window": 59637, "num_evaluation_bins": 5, + "compression_b_choices": [32, 64, 128], "quantization_bins": 4, "quantization_range": 6, "burn_rate": 0.5, diff --git a/neurons/miner.py b/neurons/miner.py index 84bb43d15..f6d21c6e4 100644 --- a/neurons/miner.py +++ b/neurons/miner.py @@ -218,6 +218,7 @@ def __init__(self): use_quantization=True, quantization_bins=self.hparams.quantization_bins, quantization_range=self.hparams.quantization_range, + b_choices=self.hparams.compression_b_choices, ) tplr.logger.info("[Init] compression pipeline ready") diff --git a/neurons/validator.py b/neurons/validator.py index ce16a33be..961e12910 100644 --- a/neurons/validator.py +++ b/neurons/validator.py @@ -384,6 +384,7 @@ def __init__(self): use_quantization=True, quantization_bins=self.hparams.quantization_bins, quantization_range=self.hparams.quantization_range, + b_choices=self.hparams.compression_b_choices, ) # Init optimizer diff --git a/scripts/analyser.py b/scripts/analyser.py index d48167132..3e916ffc7 100644 --- a/scripts/analyser.py +++ b/scripts/analyser.py @@ -88,6 +88,7 @@ def __init__(self): use_quantization=True, quantization_bins=self.hparams.quantization_bins, quantization_range=self.hparams.quantization_range, + b_choices=self.hparams.compression_b_choices, ) # Initialize shapes for each parameter (like in miner/validator) diff --git a/src/tplr/compress.py b/src/tplr/compress.py index a6fe1014c..23e96e26f 100644 --- a/src/tplr/compress.py +++ b/src/tplr/compress.py @@ -49,8 +49,6 @@ # Boolean flag that propagates the chosen quantisation mode Q = TypeVar("Q", Literal[True], Literal[False]) -_DEFAULT_B_CHOICES: tuple[int, ...] = (32, 64) - class ChunkingTransformer: """ @@ -211,9 +209,12 @@ class TopKCompressor(Generic[Q]): It supports both 1D and 2D tensors. """ + DEFAULT_B_CHOICES: tuple[int, ...] = (64, 128) + use_quantization: Q n_bins: int range_in_sigmas: int + b_choices: tuple[int, ...] # for Rice/bitmap codec # ------------------------------------------------------------------ # # Constructor – two overloads so each instance "remembers" its mode @@ -225,6 +226,7 @@ def __init__( use_quantization: Literal[True] = True, quantization_bins: int = 256, quantization_range: int = 6, + b_choices: tuple[int, ...] | None = None, ) -> None: ... @overload @@ -234,6 +236,7 @@ def __init__( use_quantization: Literal[False] = False, quantization_bins: int = 256, quantization_range: int = 6, + b_choices: tuple[int, ...] | None = None, ) -> None: ... @torch.no_grad() @@ -243,6 +246,7 @@ def __init__( use_quantization: bool = False, quantization_bins: int = 256, quantization_range: int = 6, + b_choices: tuple[int, ...] | None = None, ) -> None: """ Initialise the TopKCompressor. @@ -259,6 +263,14 @@ def __init__( quantization_range # Quantization range in standard deviations ) + if b_choices is None: + b_choices = self.DEFAULT_B_CHOICES + b_choices = tuple(sorted(int(b) for b in b_choices)) + for b in b_choices: + if b <= 0 or (b & (b - 1)) != 0: + raise ValueError(f"b_choices must be powers of two > 0, got {b}") + self.b_choices = b_choices + def _clamp_topk(self, x, topk) -> int: """ Clamp the top-k value to be within the valid range and ensure it's even. @@ -331,7 +343,15 @@ def compress(self, x: torch.Tensor, topk: int): # type: ignore[override] # sort indices and apply same perm to values idx_sorted, perm = torch.sort(idx2d, dim=1) val = torch.gather(val2d, dim=1, index=perm) - idx_bytes, _meta = encode_batch_rows(idx_sorted, C=totalk, B_choices=_DEFAULT_B_CHOICES) + + # pick only B_choices that divide this layer's C + valid_B = tuple(b for b in self.b_choices if totalk % b == 0) + if not valid_B: + raise ValueError( + f"No valid b_choices for C={totalk}; " + f"b_choices={self.b_choices} must divide C" + ) + idx_bytes, _meta = encode_batch_rows(idx_sorted, C=totalk, B_choices=valid_B) # Apply 8-bit quantization if enabled if self.use_quantization: From 93ab9cc5e70f63d278e90f4bc869e5b9a8af1d15 Mon Sep 17 00:00:00 2001 From: Kasper Date: Fri, 21 Nov 2025 15:24:21 +0100 Subject: [PATCH 29/33] fix linting errors --- src/tplr/comms.py | 2 +- src/tplr/compress.py | 3 +-- src/tplr/compression/__init__.py | 5 +---- src/tplr/compression/pack12.py | 1 + src/tplr/neurons.py | 2 +- tests/unit/test_compress.py | 3 +-- 6 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/tplr/comms.py b/src/tplr/comms.py index 961924b33..7540b8426 100644 --- a/src/tplr/comms.py +++ b/src/tplr/comms.py @@ -29,12 +29,12 @@ from functools import partial from types import SimpleNamespace from typing import Any, Literal, cast -import numpy as np import aiofiles import bittensor as bt import boto3 import botocore +import numpy as np import torch from aiobotocore.client import AioBaseClient from aiobotocore.session import get_session diff --git a/src/tplr/compress.py b/src/tplr/compress.py index 23e96e26f..c8610ca81 100644 --- a/src/tplr/compress.py +++ b/src/tplr/compress.py @@ -29,8 +29,7 @@ from torch.distributed.tensor import DTensor as DT import tplr - -from tplr.compression import encode_batch_rows, decode_batch_rows, unpack_12bit_indices +from tplr.compression import decode_batch_rows, encode_batch_rows, unpack_12bit_indices # ─────────── type aliases ──────────────────────────────────────────────── # primitive shapes diff --git a/src/tplr/compression/__init__.py b/src/tplr/compression/__init__.py index 0c9666af6..c55ae998d 100644 --- a/src/tplr/compression/__init__.py +++ b/src/tplr/compression/__init__.py @@ -20,11 +20,8 @@ decode_batch_rows, # decoder (CPU) encode_batch_rows, # GPU-accelerated encoder → bytes + perm + meta ) +from .pack12 import pack_12bit_indices, unpack_12bit_indices -from .pack12 import ( - pack_12bit_indices, - unpack_12bit_indices -) __all__ = [ "encode_batch_rows", "decode_batch_rows", diff --git a/src/tplr/compression/pack12.py b/src/tplr/compression/pack12.py index ecc2617ad..b38a9219b 100644 --- a/src/tplr/compression/pack12.py +++ b/src/tplr/compression/pack12.py @@ -1,5 +1,6 @@ import torch + def pack_12bit_indices(indices: torch.Tensor) -> torch.Tensor: """ Pack int64 indices into 12-bit representation. diff --git a/src/tplr/neurons.py b/src/tplr/neurons.py index 39d72b5d3..88a1e7f5a 100644 --- a/src/tplr/neurons.py +++ b/src/tplr/neurons.py @@ -34,8 +34,8 @@ from wandb.sdk.wandb_run import Run import tplr -from tplr.distributed import dist_helper from tplr.compression import decode_batch_rows +from tplr.distributed import dist_helper if TYPE_CHECKING: from neurons.miner import Miner diff --git a/tests/unit/test_compress.py b/tests/unit/test_compress.py index 4f655ae91..5a49e2c3b 100644 --- a/tests/unit/test_compress.py +++ b/tests/unit/test_compress.py @@ -10,9 +10,8 @@ TopKCompressor, _dct, _get_smaller_split, - _idct + _idct, ) - from tplr.compression import encode_batch_rows, pack_12bit_indices, unpack_12bit_indices From 95b281ae316e892d78ce6a8ffadbb2299d5a3e4a Mon Sep 17 00:00:00 2001 From: Kasper Date: Fri, 21 Nov 2025 15:37:10 +0100 Subject: [PATCH 30/33] revert validator --- neurons/validator.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/neurons/validator.py b/neurons/validator.py index 961e12910..8e673510f 100644 --- a/neurons/validator.py +++ b/neurons/validator.py @@ -392,37 +392,18 @@ def __init__(self): self.xshapes = {} self.totalks = {} - - import time - total_compress_time = 0.0 - total_encode_time = 0.0 for n, p in self.model.named_parameters(): - tplr.logger.info(f"[COMPRESS START] {n}: shape={p.shape}") - - encode_start = time.time() enc = self.transformer.encode( torch.empty(p.shape, dtype=torch.float16, device=self.device), use_dct=self.hparams.use_dct, ) - encode_time = time.time() - encode_start - - compress_start = time.time() _, _, xshape, totalk, _ = self.compressor.compress( enc, self.hparams.topk_compression, ) - compress_time = time.time() - compress_start - self.xshapes[n] = xshape self.totalks[n] = totalk - total_encode_time += encode_time - total_compress_time += compress_time - - tplr.logger.info(f"[COMPRESS TIMING] {n}: encode={encode_time:.3f}s, compress={compress_time:.3f}s, shape={p.shape}") - - tplr.logger.info(f"[COMPRESS TIMING TOTAL] encode={total_encode_time:.3f}s, compress={total_compress_time:.3f}s") - self.openskill_model = PlackettLuce( beta=self.hparams.openskill_beta, tau=self.hparams.openskill_tau ) From 030c7f83c0a115821c78e3061689e6d6a676f97a Mon Sep 17 00:00:00 2001 From: Kasper Date: Fri, 21 Nov 2025 15:53:23 +0100 Subject: [PATCH 31/33] fix missing import --- src/tplr/comms.py | 2 +- src/tplr/neurons.py | 11 +---------- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/src/tplr/comms.py b/src/tplr/comms.py index 7540b8426..82e6c90e0 100644 --- a/src/tplr/comms.py +++ b/src/tplr/comms.py @@ -45,7 +45,7 @@ import tplr from tplr.chain import ChainManager from tplr.compress import TopKCompressor -from tplr.compression import decode_batch_rows +from tplr.compression import decode_batch_rows, unpack_12bit_indices from tplr.config import BUCKET_SECRETS, client_config from tplr.schemas import Bucket, CommsGetResult diff --git a/src/tplr/neurons.py b/src/tplr/neurons.py index 88a1e7f5a..dafb69787 100644 --- a/src/tplr/neurons.py +++ b/src/tplr/neurons.py @@ -110,8 +110,7 @@ def barrier(group=None): miner.error_feedback[n] = miner.error_feedback[n].to( param.device, non_blocking=True ) - compression_time = 0 - encode_time = 0 + for _, (n, p) in enumerate(model_iterator, 1): owned = n in miner.owned_params p_is_dt = is_dtensor(p) @@ -153,11 +152,8 @@ def barrier(group=None): error_feedback.add_(grad_full) # --- 4) Encode & compress (owner only) --- - encode_start = tplr.T() encoded = miner.transformer.encode(error_feedback, use_dct=use_dct) - encode_time += tplr.T() - encode_start - compress_start = tplr.T() idxs, vals, xshape, totalk, quant_params = miner.compressor.compress( encoded, topk ) @@ -168,17 +164,14 @@ def barrier(group=None): decompressed = miner.compressor.decompress( p, idxs, vals, xshape, totalk, quant_params ) - compression_time += tplr.T() - compress_start # --- 6) Decode & error-feedback update (owner only) --- - encode_start = tplr.T() transmit_grad = miner.transformer.decode(decompressed, use_dct=use_dct) del decompressed error_feedback.sub_(transmit_grad) # Keep error feedback on GPU for now, batch offload later miner.error_feedback[n] = error_feedback del transmit_grad, error_feedback - encode_time += tplr.T() - encode_start # --- 7) Pack outputs (move compressed artifacts to CPU asynchronously) --- # Using non_blocking=True for async D2H transfers when CUDA is available @@ -208,8 +201,6 @@ def barrier(group=None): # Clear per-param grad p.grad = None - tplr.logger.info(f"times: {encode_time}, {compression_time}") - # Batch offload all error feedback tensors to CPU with pinned memory for name in miner.error_feedback: if ( From aa66e2dc44d0af3123bc8505d213f727001242d4 Mon Sep 17 00:00:00 2001 From: Kasper Date: Fri, 21 Nov 2025 15:59:19 +0100 Subject: [PATCH 32/33] fix formatting --- src/tplr/comms.py | 9 ++- src/tplr/compress.py | 13 +-- src/tplr/compression/__init__.py | 3 +- src/tplr/compression/hybrid.py | 134 ++++++++++++++++--------------- src/tplr/compression/pack12.py | 6 +- tests/unit/test_compress.py | 33 ++++---- 6 files changed, 103 insertions(+), 95 deletions(-) diff --git a/src/tplr/comms.py b/src/tplr/comms.py index 82e6c90e0..0eeadf38e 100644 --- a/src/tplr/comms.py +++ b/src/tplr/comms.py @@ -2550,7 +2550,9 @@ def check_compressed_indices( ) except ValueError as e: # NB: legacy path - tplr.logger.warning(f"Failed to unpack: {e} Falling back to legacy uncompress.") + tplr.logger.warning( + f"Failed to unpack: {e} Falling back to legacy uncompress." + ) # Fallback: likely old format -> try legacy decoder try: unpacked = unpack_12bit_indices(idxs, vals.shape) @@ -2561,11 +2563,12 @@ def check_compressed_indices( ) _bounds_check(unpacked) except Exception as e: - raise ValueError(f"[{param_name}] Failed to unpack legacy 12-bit indices: {e}") + raise ValueError( + f"[{param_name}] Failed to unpack legacy 12-bit indices: {e}" + ) except Exception as e: raise ValueError(f"[{param_name}] Failed to decode indices payload: {e}") - async def s3_get_object_size(self, bucket: Bucket, key: str) -> int | None: """ Retrieves the size of an S3 object without downloading its content. diff --git a/src/tplr/compress.py b/src/tplr/compress.py index c8610ca81..b03decd13 100644 --- a/src/tplr/compress.py +++ b/src/tplr/compress.py @@ -327,9 +327,7 @@ def compress(self, x: torch.Tensor, topk: int): # type: ignore[override] totalk = x.shape[-1] topk = self._clamp_topk(x, topk) - topk_vals, idx = torch.topk( - x.abs(), k=topk, dim=-1, largest=True, sorted=False - ) + topk_vals, idx = torch.topk(x.abs(), k=topk, dim=-1, largest=True, sorted=False) del topk_vals idx = idx.to(torch.int32) @@ -358,7 +356,6 @@ def compress(self, x: torch.Tensor, topk: int): # type: ignore[override] return idx_bytes, val, xshape, totalk, qparams return idx_bytes, val, xshape, totalk - @torch.no_grad() def decompress( self, @@ -524,9 +521,13 @@ def batch_decompress( idx_unpacked = idx_int64.view_as(v_data) except ValueError as e: # NB: legacy path - tplr.logger.warning(f"Failed to unpack: {e}. Falling back to legacy uncompress.") + tplr.logger.warning( + f"Failed to unpack: {e}. Falling back to legacy uncompress." + ) # Fallback: likely old format -> try legacy decoder - idx_unpacked = unpack_12bit_indices(i_data.to(p.device), v_data.shape) + idx_unpacked = unpack_12bit_indices( + i_data.to(p.device), v_data.shape + ) unpacked_indices.append(idx_unpacked) elif i_data.dtype in (torch.int64, torch.long): diff --git a/src/tplr/compression/__init__.py b/src/tplr/compression/__init__.py index c55ae998d..31d94bfb8 100644 --- a/src/tplr/compression/__init__.py +++ b/src/tplr/compression/__init__.py @@ -1,4 +1,3 @@ - # The MIT License (MIT) # © 2025 tplr.ai # @@ -27,4 +26,4 @@ "decode_batch_rows", "pack_12bit_indices", "unpack_12bit_indices", -] \ No newline at end of file +] diff --git a/src/tplr/compression/hybrid.py b/src/tplr/compression/hybrid.py index bf4dae76b..e8e5fb27f 100644 --- a/src/tplr/compression/hybrid.py +++ b/src/tplr/compression/hybrid.py @@ -12,10 +12,7 @@ @torch.no_grad() def encode_batch_rows( - idx_sorted: torch.Tensor, - *, - C: int, - B_choices: Tuple[int, ...] = (64, 128) + idx_sorted: torch.Tensor, *, C: int, B_choices: Tuple[int, ...] = (64, 128) ) -> Tuple[BytesLike, Dict]: """ Compresses a 2D int64 tensor of Top-K indices into a byte string @@ -45,7 +42,9 @@ def encode_batch_rows( raise RuntimeError("CUDA is required for this function.") if not isinstance(idx_sorted, torch.Tensor) or idx_sorted.ndim != 2: - raise ValueError(f"idx must be a 2D int64 tensor, got {idx_sorted.shape} {idx_sorted.dtype}") + raise ValueError( + f"idx must be a 2D int64 tensor, got {idx_sorted.shape} {idx_sorted.dtype}" + ) if not all(isinstance(b, int) and (b & (b - 1) == 0) and b > 0 for b in B_choices): raise ValueError(f"All B_choices must be powers of two, got {B_choices}") @@ -58,7 +57,7 @@ def encode_batch_rows( return b"", { "total_bits": 0, "avg_bits_per_row": 0.0, - "B_hist": {b: 0 for b in B_choices} + "B_hist": {b: 0 for b in B_choices}, } if not idx_sorted.is_cuda: @@ -99,22 +98,25 @@ def encode_batch_rows( # Best choice per row min_costs, best_B_idx = torch.min(costs, dim=1) - is_bitmap_choice = torch.gather(is_bitmap, 1, best_B_idx.unsqueeze(1)).squeeze(1).to(torch.int32) + is_bitmap_choice = ( + torch.gather(is_bitmap, 1, best_B_idx.unsqueeze(1)).squeeze(1).to(torch.int32) + ) # Payload sizing row_payload_bits = min_costs row_payload_bytes = ((row_payload_bits + 7) // 8).to(torch.int32) if torch.any(row_payload_bytes > 0xFFFF): - raise ValueError("Row payload length exceeds 65535 bytes; cannot store in uint16.") + raise ValueError( + "Row payload length exceeds 65535 bytes; cannot store in uint16." + ) # Byte offsets if num_rows == 1: row_byte_offsets = torch.zeros(1, dtype=torch.int32, device=dev) else: row_byte_offsets = torch.nn.functional.pad( - torch.cumsum(row_payload_bytes, dim=0, dtype=torch.int32)[:-1], - (1, 0) + torch.cumsum(row_payload_bytes, dim=0, dtype=torch.int32)[:-1], (1, 0) ) total_payload_bytes = int(row_payload_bytes.sum().item()) @@ -154,7 +156,9 @@ def encode_batch_rows( row_table_flat[:, 1] = ((lengths_i32 >> 8) & 0xFF).to(torch.uint8) row_table_flat[:, 2] = (headers_i32 & ((1 << ROW_HEADER_BITS) - 1)).to(torch.uint8) - payload_buf[global_header_len_bytes: global_header_len_bytes + row_table_bytes] = row_table_flat.view(-1) + payload_buf[global_header_len_bytes : global_header_len_bytes + row_table_bytes] = ( + row_table_flat.view(-1) + ) # Calculate absolute byte offsets for pack kernel row_abs_byte_offsets = (payload_region_start + row_byte_offsets).to(torch.int32) @@ -188,13 +192,13 @@ def encode_batch_rows( @triton.jit def cost_kernel( - delta_ptr, - costs_ptr, - is_bitmap_ptr, - k_dim: tl.constexpr, - num_rows: tl.int32, - num_B_choices: tl.int32, - k_rice_choices_ptr, + delta_ptr, + costs_ptr, + is_bitmap_ptr, + k_dim: tl.constexpr, + num_rows: tl.int32, + num_B_choices: tl.int32, + k_rice_choices_ptr, ): """ Calculates bit cost. One row per program instance. @@ -235,14 +239,14 @@ def cost_kernel( @triton.jit def pack_kernel( - delta_ptr, # (rows, k_dim) IN int32 - u8_payload_ptr, # OUT uint8 - row_abs_byte_offsets_ptr, # (rows,) IN int32 (byte offset where payload starts) - best_B_idx_ptr, # (rows,) IN - is_bitmap_ptr, # (rows,) IN - k_rice_choices_ptr, # [num_B] IN - num_rows: tl.int32, - k_dim: tl.int32, # dynamic + delta_ptr, # (rows, k_dim) IN int32 + u8_payload_ptr, # OUT uint8 + row_abs_byte_offsets_ptr, # (rows,) IN int32 (byte offset where payload starts) + best_B_idx_ptr, # (rows,) IN + is_bitmap_ptr, # (rows,) IN + k_rice_choices_ptr, # [num_B] IN + num_rows: tl.int32, + k_dim: tl.int32, # dynamic ): """ Writes payload bits using a 64-bit register accumulator. @@ -257,7 +261,7 @@ def pack_kernel( b_idx_i32 = tl.load(best_B_idx_ptr + row_idx).to(tl.int32) use_bitmap_i32 = (tl.load(is_bitmap_ptr + row_idx) & 1).to(tl.int32) k_rice_i32 = tl.load(k_rice_choices_ptr + b_idx_i32).to(tl.int32) - M_i32 = (tl.full((), 1, dtype=tl.int32) << k_rice_i32) + M_i32 = tl.full((), 1, dtype=tl.int32) << k_rice_i32 # Accumulator state acc_data = tl.full((), 0, dtype=tl.uint64) @@ -282,7 +286,7 @@ def pack_kernel( # Rice: q '1's, then '0', then k_rice bits of r q_count = q.to(tl.int32) while q_count > 0: - acc_data |= (tl.full((), 1, dtype=tl.uint64) << acc_bits) + acc_data |= tl.full((), 1, dtype=tl.uint64) << acc_bits acc_bits += 1 q_count -= 1 @@ -304,7 +308,7 @@ def pack_kernel( else: # Bitmap: q is 1 bit q_bit = tl.where(q > 0, 1, 0).to(tl.uint64) - acc_data |= (q_bit << acc_bits) + acc_data |= q_bit << acc_bits acc_bits += 1 # Flush Check (after separator/bitmap bit) @@ -320,7 +324,7 @@ def pack_kernel( acc_bits -= 32 # Append Remainder - acc_data |= (r << acc_bits) + acc_data |= r << acc_bits acc_bits += k_rice_i32 # Flush Check @@ -347,14 +351,14 @@ def pack_kernel( @triton.jit def parse_header_kernel( - u8_payload_ptr, # (total_bytes,) uint8 - C_out_ptr, # (1,) int32 - K_out_ptr, # (1,) int32 - R_out_ptr, # (1,) int32 - num_B_out_ptr, # (1,) int32 - B_choices_out_ptr, # (MAX_B_CHOICES,) int32 - header_bytes_out_ptr, # (1,) int32 - max_num_B: tl.constexpr, + u8_payload_ptr, # (total_bytes,) uint8 + C_out_ptr, # (1,) int32 + K_out_ptr, # (1,) int32 + R_out_ptr, # (1,) int32 + num_B_out_ptr, # (1,) int32 + B_choices_out_ptr, # (MAX_B_CHOICES,) int32 + header_bytes_out_ptr, # (1,) int32 + max_num_B: tl.constexpr, ): """ Simple GPU kernel to parse the global header. @@ -406,13 +410,13 @@ def parse_header_kernel( @triton.jit def parse_row_table_kernel( - u8_payload_ptr, - row_payload_bytes_ptr, - best_B_idx_ptr, - use_bitmap_ptr, - row_table_start_ptr, # int32* from header kernel - num_rows_ptr, # int32* from header kernel - ROW_HEADER_BITS: tl.constexpr, + u8_payload_ptr, + row_payload_bytes_ptr, + best_B_idx_ptr, + use_bitmap_ptr, + row_table_start_ptr, # int32* from header kernel + num_rows_ptr, # int32* from header kernel + ROW_HEADER_BITS: tl.constexpr, ): pid = tl.program_id(0) num_rows = tl.load(num_rows_ptr) @@ -440,14 +444,14 @@ def parse_row_table_kernel( @triton.jit def decode_rows_kernel( - u8_payload_ptr, - out_vals_ptr, - row_byte_offsets_ptr, - best_B_idx_ptr, - use_bitmap_ptr, - k_rice_choices_ptr, - num_rows_ptr, - K_ptr, + u8_payload_ptr, + out_vals_ptr, + row_byte_offsets_ptr, + best_B_idx_ptr, + use_bitmap_ptr, + k_rice_choices_ptr, + num_rows_ptr, + K_ptr, ): """ Decodes each row's Rice/bitmap bitstream into *final* prefix-summed values. @@ -470,7 +474,7 @@ def decode_rows_kernel( use_bitmap = (tl.load(use_bitmap_ptr + row_idx) & 1).to(tl.int32) k_rice = tl.load(k_rice_choices_ptr + b_idx).to(tl.int32) - M = (tl.full((), 1, dtype=tl.int32) << k_rice) + M = tl.full((), 1, dtype=tl.int32) << k_rice # Streaming bit-buffer state byte_offset = start_byte @@ -495,7 +499,7 @@ def decode_rows_kernel( if bits_in_buf == 0: next_byte = tl.load(u8_payload_ptr + byte_offset).to(tl.uint64) byte_offset += 1 - bitbuf |= (next_byte << bits_in_buf) + bitbuf |= next_byte << bits_in_buf bits_in_buf += 8 bit = (bitbuf & 1).to(tl.int32) @@ -511,7 +515,7 @@ def decode_rows_kernel( if bits_in_buf == 0: next_byte = tl.load(u8_payload_ptr + byte_offset).to(tl.uint64) byte_offset += 1 - bitbuf |= (next_byte << bits_in_buf) + bitbuf |= next_byte << bits_in_buf bits_in_buf += 8 q = (bitbuf & 1).to(tl.int32) bitbuf >>= 1 @@ -524,11 +528,11 @@ def decode_rows_kernel( while bits_in_buf < k_rice: next_byte = tl.load(u8_payload_ptr + byte_offset).to(tl.uint64) byte_offset += 1 - bitbuf |= (next_byte << bits_in_buf) + bitbuf |= next_byte << bits_in_buf bits_in_buf += 8 mask = (tl.full((), 1, dtype=tl.uint64) << k_rice) - 1 - r_u64 = (bitbuf & mask) + r_u64 = bitbuf & mask bitbuf >>= k_rice bits_in_buf -= k_rice r = r_u64.to(tl.int32) @@ -543,8 +547,8 @@ def decode_rows_kernel( def decode_batch_rows( - payload: BytesLike, - max_num_B: int = 16, + payload: BytesLike, + max_num_B: int = 16, ) -> tuple[torch.Tensor, int, int]: """ Decode a payload produced by encode_batch_rows. @@ -599,7 +603,7 @@ def decode_batch_rows( num_B_out, B_choices_out, header_bytes_out, - max_num_B=max_num_B + max_num_B=max_num_B, ) # Minimal sync to get scalar values needed for kernel setup @@ -638,7 +642,9 @@ def decode_batch_rows( row_payload_bytes_64 = row_payload_bytes.to(torch.int64) # Exclusive prefix sum for offsets - row_byte_offsets_rel = torch.cumsum(row_payload_bytes_64, dim=0) - row_payload_bytes_64 + row_byte_offsets_rel = ( + torch.cumsum(row_payload_bytes_64, dim=0) - row_payload_bytes_64 + ) # Absolute byte offsets row_byte_offsets = (payload_region_start + row_byte_offsets_rel).to(torch.int32) @@ -654,8 +660,8 @@ def decode_batch_rows( use_bitmap, k_rice_choices_tensor, R_out, - K_out + K_out, ) # No host-side cumsum here: kernel already returns prefix sums - return out_vals.to(torch.int64), C, num_rows \ No newline at end of file + return out_vals.to(torch.int64), C, num_rows diff --git a/src/tplr/compression/pack12.py b/src/tplr/compression/pack12.py index b38a9219b..9f3de6dc7 100644 --- a/src/tplr/compression/pack12.py +++ b/src/tplr/compression/pack12.py @@ -52,7 +52,9 @@ def pack_12bit_indices(indices: torch.Tensor) -> torch.Tensor: return packed -def unpack_12bit_indices(packed: torch.Tensor, values_shape: tuple[int, ...] ) -> torch.Tensor: +def unpack_12bit_indices( + packed: torch.Tensor, values_shape: tuple[int, ...] +) -> torch.Tensor: """ Unpack 12-bit packed indices back to int64. Assumes even number of indices. @@ -92,4 +94,4 @@ def unpack_12bit_indices(packed: torch.Tensor, values_shape: tuple[int, ...] ) - # Reshape to match values shape indices = indices.reshape(values_shape) - return indices \ No newline at end of file + return indices diff --git a/tests/unit/test_compress.py b/tests/unit/test_compress.py index 5a49e2c3b..25d12f59f 100644 --- a/tests/unit/test_compress.py +++ b/tests/unit/test_compress.py @@ -31,7 +31,7 @@ def compress_instance_quantized(self) -> TopKCompressor[Literal[True]]: ) def test_compress_produces_rice_bitmap_indices( - self, compress_instance: TopKCompressor[Literal[False]] + self, compress_instance: TopKCompressor[Literal[False]] ): """Test that compress() produces Rice/bitmap encoded indices""" # Create test tensor @@ -50,7 +50,7 @@ def test_compress_produces_rice_bitmap_indices( assert totalk == x.shape[-1] # For 2D tensor, it's the last dimension def test_compress_with_quantization( - self, compress_instance_quantized: TopKCompressor[Literal[True]] + self, compress_instance_quantized: TopKCompressor[Literal[True]] ): """Test compression with quantization enabled""" x = torch.randn(8, 128) # 1024 elements total, last dim=64 @@ -70,7 +70,7 @@ def test_compress_with_quantization( assert len(qparams) == 5 # shift, scale, offset, lookup, orig_dtype def test_decompress_with_rice_bitmap_format( - self, compress_instance: TopKCompressor[Literal[False]] + self, compress_instance: TopKCompressor[Literal[False]] ): """Test that decompress can handle Rice/bitmap encoded format""" # Setup @@ -84,7 +84,9 @@ def test_decompress_with_rice_bitmap_format( # Pack using the new encoder format idx_bytes, _ = encode_batch_rows(original_indices, C=totalk) - val = torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]], dtype=torch.float32) + val = torch.tensor( + [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]], dtype=torch.float32 + ) # Test decompression with packed format result = compress_instance.decompress(p, idx_bytes, val, xshape, totalk) @@ -92,7 +94,7 @@ def test_decompress_with_rice_bitmap_format( assert result.dtype == p.dtype def test_batch_decompress_multiple_rice_bitmap_formats( - self, compress_instance: TopKCompressor[Literal[False]] + self, compress_instance: TopKCompressor[Literal[False]] ): """Test batch_decompress with multiple Rice/bitmap encoded indices""" # Setup @@ -122,7 +124,7 @@ def test_batch_decompress_multiple_rice_bitmap_formats( assert result.dtype == p.dtype def test_compress_decompress_round_trip( - self, compress_instance: TopKCompressor[Literal[False]] + self, compress_instance: TopKCompressor[Literal[False]] ): """Test full compress-decompress round trip""" x = torch.zeros(8, 128) # 1024 elements total, last dim=128 @@ -157,7 +159,7 @@ def test_compress_decompress_round_trip( assert torch.allclose(top_vals, expected_vals, atol=1e-5) def test_encode_compress_decompress_round_trip( - self, compress_instance: TopKCompressor[Literal[False]] + self, compress_instance: TopKCompressor[Literal[False]] ): class SimpleModel(nn.Module): def __init__(self): @@ -166,7 +168,7 @@ def __init__(self): self.layer2 = nn.Linear(64, 256) target_chunk = 16 - transform = ChunkingTransformer( SimpleModel(), target_chunk) + transform = ChunkingTransformer(SimpleModel(), target_chunk) x = torch.zeros(256, 64) x[0, 0] = 1.0 x[1, 1] = 2.0 @@ -175,19 +177,14 @@ def __init__(self): topk = 4 encoded = transform.encode(x) - idxs, vals, xshape, totalk = compress_instance.compress( - encoded, topk - ) + idxs, vals, xshape, totalk = compress_instance.compress(encoded, topk) p = torch.zeros_like(x) - decompressed = compress_instance.decompress( - p, idxs, vals, xshape, totalk - ) + decompressed = compress_instance.decompress(p, idxs, vals, xshape, totalk) assert torch.allclose(encoded, decompressed, atol=1e-5) - def test_rice_bitmap_index_value_range( - self, compress_instance: TopKCompressor[Literal[False]] + self, compress_instance: TopKCompressor[Literal[False]] ): """Test that Rice/bitmap codec can handle large index ranges efficiently""" # Create a large tensor that would have indices beyond 8-bit range @@ -214,7 +211,7 @@ def test_rice_bitmap_index_value_range( ) def test_batch_decompress_with_norm_options( - self, compress_instance: TopKCompressor[Literal[False]] + self, compress_instance: TopKCompressor[Literal[False]] ): """Test batch_decompress with normalisation and clip_norm options""" p = torch.zeros(8, 128) # 1024 elements total, last dim=128 @@ -247,7 +244,7 @@ def test_batch_decompress_with_norm_options( assert result_with_norms.shape == xshape def test_batch_decompress_with_legacy_packing( - self, compress_instance: TopKCompressor[Literal[False]] + self, compress_instance: TopKCompressor[Literal[False]] ): p = torch.zeros(8, 128) # 1024 elements total, last dim=128 xshape = (8, 128) From 74d092f75de3818f9ecf88bb1a7e0ac16c9dfef2 Mon Sep 17 00:00:00 2001 From: Kasper Date: Mon, 24 Nov 2025 11:45:23 +0100 Subject: [PATCH 33/33] update comms test --- tests/test_comms.py | 187 +++++++++++++++++++++++--------------------- 1 file changed, 99 insertions(+), 88 deletions(-) diff --git a/tests/test_comms.py b/tests/test_comms.py index 6aea4f417..388896d5e 100644 --- a/tests/test_comms.py +++ b/tests/test_comms.py @@ -13,7 +13,7 @@ from datetime import datetime, timedelta, timezone from tplr import load_hparams -from tplr.compress import pack_12bit_indices +from tplr.compression import encode_batch_rows, pack_12bit_indices hparams = load_hparams() @@ -48,7 +48,7 @@ def create_xshapes_totalks(model): def create_valid_state_dict(model): state_dict = {} for name, _ in model.named_parameters(): - # Create 12-bit packed format + # Create legacy 12-bit packed format indices = torch.tensor([0, 1], dtype=torch.long) packed_data = pack_12bit_indices(indices) state_dict[name + "idxs"] = packed_data @@ -65,9 +65,9 @@ def create_missing_idxs(model): def create_packed_indices(indices_list): - """Helper function to create 12-bit packed indices from a list""" + """Helper function to create legacy 12-bit packed indices from a list""" indices = torch.tensor(indices_list, dtype=torch.long) - # Ensure even number of indices for 12-bit packing + # Ensure even number of indices for legacy 12-bit packing if len(indices_list) % 2 != 0: indices = torch.cat([indices, torch.tensor([0], dtype=torch.long)]) packed_data = pack_12bit_indices(indices) @@ -244,7 +244,7 @@ async def test_gather_basic_functionality(comms_instance, dummy_compressor): data={ "0.weightidxs": create_packed_indices( [0, 1, 2, 3] - ), # Even count for 12-bit + ), # Even count for legacy 12-bit "0.weightvals": torch.tensor([0.4, 0.5, 0.6, 0.7]), "totalks": {"0.weight": totalk_value}, }, @@ -255,7 +255,7 @@ async def test_gather_basic_functionality(comms_instance, dummy_compressor): data={ "0.weightidxs": create_packed_indices( [0, 1, 2, 3] - ), # Even count for 12-bit + ), # Even count for legacy 12-bit "0.weightvals": torch.tensor([0.7, 0.8, 0.9, 1.0]), "totalks": {"0.weight": totalk_value}, }, @@ -315,7 +315,7 @@ async def test_gather_normalization(comms_instance, dummy_compressor): data={ "0.weightidxs": create_packed_indices( [0, 1, 2, 3] - ), # Even count for 12-bit + ), # Even count for legacy 12-bit "0.weightvals": torch.tensor([0.4, 0.5, 0.6, 0.7]), "totalks": {"0.weight": totalk_value}, }, @@ -360,7 +360,7 @@ async def test_gather_quant_params_validation(comms_instance, dummy_compressor): val_key = f"{param_base}vals" qp_key = f"{param_base}quant_params" - idxs = create_packed_indices([0, 1, 2, 3]) # Even count for 12-bit + idxs = create_packed_indices([0, 1, 2, 3]) # Even count for legacy 12-bit vals = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=torch.uint8) # still quantised lookup = torch.zeros(256, dtype=torch.float32) # dummy LUT @@ -480,7 +480,7 @@ async def test_gather_averaging(comms_instance, dummy_compressor): data={ "0.weightidxs": create_packed_indices( [0, 1, 2, 3] - ), # Even count for 12-bit + ), # Even count for legacy 12-bit "0.weightvals": torch.tensor([0.4, 0.5, 0.6, 0.7]), "totalks": {"0.weight": totalk_value}, }, @@ -491,7 +491,7 @@ async def test_gather_averaging(comms_instance, dummy_compressor): data={ "0.weightidxs": create_packed_indices( [0, 1, 2, 3] - ), # Even count for 12-bit + ), # Even count for legacy 12-bit "0.weightvals": torch.tensor([0.8, 0.9, 1.0, 1.1]), "totalks": {"0.weight": totalk_value}, }, @@ -649,7 +649,7 @@ async def test_gather_complex_normalization(comms_instance, dummy_compressor): # Include totalks in each peer response using the key "layer." (so that stripping "idxs"/"vals" returns the same base key). peer1_response = CommsGetResult( data={ - "layer.idxs": create_packed_indices([0, 1, 2, 3]), # Even count for 12-bit + "layer.idxs": create_packed_indices([0, 1, 2, 3]), # Even count for legacy 12-bit "layer.vals": torch.tensor([1.0, 2.0, 2.0, 3.0]), # norm ≈ 3 "totalks": {"layer.": totalk_value}, }, @@ -658,7 +658,7 @@ async def test_gather_complex_normalization(comms_instance, dummy_compressor): ) peer2_response = CommsGetResult( data={ - "layer.idxs": create_packed_indices([0, 1, 2, 3]), # Even count for 12-bit + "layer.idxs": create_packed_indices([0, 1, 2, 3]), # Even count for legacy 12-bit "layer.vals": torch.tensor([10.0, 20.0, 20.0, 30.0]), # Larger scale "totalks": {"layer.": totalk_value}, }, @@ -667,7 +667,7 @@ async def test_gather_complex_normalization(comms_instance, dummy_compressor): ) peer3_response = CommsGetResult( data={ - "layer.idxs": create_packed_indices([0, 1, 2, 3]), # Even count for 12-bit + "layer.idxs": create_packed_indices([0, 1, 2, 3]), # Even count for legacy 12-bit "layer.vals": torch.tensor([-5.0, 5.0, 5.0, 10.0]), # Different sign "totalks": {"layer.": totalk_value}, }, @@ -1464,37 +1464,41 @@ def __init__(self): ) -def test_valid_12bit_packed_indices(): +def test_valid_rice_bitmap_encoded_indices(): """ - Test Case: test_valid_12bit_packed_indices - - Input: 12-bit packed indices with correct topk dimension + Test Case: test_valid_rice_bitmap_encoded_indices + - Input: Rice/bitmap encoded indices with correct topk dimension - Valid indices (all indices within [0, totalk-1]) - Expected Outcome: The function should complete without raising an error. """ dummy_comms = DummyComms() - # totalk is set to 10; allowed_topk is min(4, 10) == 4. - totalk = 10 - valid_indices = torch.tensor([1, 5, 9, 3], dtype=torch.long) - packed_data = pack_12bit_indices(valid_indices) - vals = torch.randn_like(valid_indices, dtype=torch.float32) + # totalk is set to 64; allowed_topk is min(4, 64) == 4. + totalk = 64 + valid_indices = torch.tensor( + [[1, 5, 9, 3]], dtype=torch.long + ) # Shape [1, 4] for one row + # Use the new encoder format + idx, _ = encode_batch_rows(valid_indices, C=totalk) + vals = torch.randn(1, 4, dtype=torch.float32) # Match the shape [rows, k] # This call should complete without any error. - dummy_comms.check_compressed_indices("test_param", packed_data, totalk, vals=vals) + dummy_comms.check_compressed_indices("test_param", idx, totalk, vals=vals) -def test_valid_12bit_packed_multi_dim(): +def test_valid_rice_bitmap_encoded_multi_dim(): """ - Test 12-bit packed indices from multi-dimensional tensor where the last dimension + Test Rice/bitmap encoded indices from multi-dimensional tensor where the last dimension equals min(hparams.topk_compression, totalk) and all indices are within valid range. """ dummy_comms = DummyComms() - totalk = 20 # allowed_topk = min(4, 20) = 4 + totalk = 128 # allowed_topk = min(4, 128) = 4 # Create a valid 2D tensor (shape: 2 x 4) with valid indices. valid_indices = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=torch.long) - packed_data = pack_12bit_indices(valid_indices) - vals = torch.randn_like(valid_indices, dtype=torch.float32) - dummy_comms.check_compressed_indices("param", packed_data, totalk, vals=vals) + # Use the new encoder format + idx, _ = encode_batch_rows(valid_indices, C=totalk) + vals = torch.randn(2, 4, dtype=torch.float32) # Match the shape + dummy_comms.check_compressed_indices("param", idx, totalk, vals=vals) def test_invalid_not_packed_format(): @@ -1502,18 +1506,18 @@ def test_invalid_not_packed_format(): Test that non-packed formats (like regular tensors or lists) are rejected. """ dummy_comms = DummyComms() - totalk = 20 + totalk = 128 # Test with regular tensor (not packed) - should fail because it's not uint8 invalid_tensor = torch.tensor([0, 1, 2, 3], dtype=torch.long) vals = torch.randn(4, dtype=torch.float32) - # This should fail since only uint8 12-bit packed format is supported - with pytest.raises(ValueError, match="Expected uint8 for 12-bit packed indices"): + # This should fail since only uint8 Rice/bitmap encoded format is supported + with pytest.raises(ValueError, match="Expected uint8.*Rice/bitmap payload"): dummy_comms.check_compressed_indices("param", invalid_tensor, totalk, vals=vals) # Test with list (not a tensor) invalid_list = torch.tensor([0, 1, 2, 3]) - with pytest.raises(ValueError, match="Expected uint8 for 12-bit packed indices"): + with pytest.raises(ValueError, match="Expected uint8.*Rice/bitmap payload"): dummy_comms.check_compressed_indices("param", invalid_list, totalk, vals=vals) @@ -1522,125 +1526,132 @@ def test_invalid_wrong_dtype(): Test that packed data with wrong dtype is handled correctly. """ dummy_comms = DummyComms() - totalk = 20 + totalk = 128 # int32 tensor is not uint8, so it should fail fake_packed = torch.tensor([0, 1, 2, 3], dtype=torch.int32) vals = torch.randn(4, dtype=torch.float32) # Should fail since only uint8 format is supported - with pytest.raises(ValueError, match="Expected uint8 for 12-bit packed indices"): + with pytest.raises(ValueError, match="Expected uint8.*Rice/bitmap payload"): dummy_comms.check_compressed_indices("param", fake_packed, totalk, vals=vals) -def test_invalid_12bit_packed_wrong_topk(): +def test_invalid_rice_bitmap_wrong_topk(): """ - Test that 12-bit packed indices with wrong topk dimension raises ValueError. + Test that Rice/bitmap encoded indices with wrong topk dimension raises ValueError. """ dummy_comms = DummyComms() - totalk = 10 # allowed_topk = min(4, 10) = 4 + totalk = 64 # allowed_topk = min(4, 64) = 4 # Create packed indices with wrong topk (2 instead of 4) - invalid_indices = torch.tensor([0, 1], dtype=torch.long) - packed_data = pack_12bit_indices(invalid_indices) - vals = torch.randn(2, dtype=torch.float32) # Wrong shape - should be 4 - with pytest.raises(ValueError, match="Invalid topk dimension"): + invalid_indices = torch.tensor([[0, 1]], dtype=torch.long) + payload, _ = encode_batch_rows(invalid_indices, C=totalk) + packed_data = torch.tensor( + np.frombuffer(payload, dtype=np.uint8), dtype=torch.uint8 + ) + vals = torch.randn(1, 2, dtype=torch.float32) # Wrong shape - should be 4 + with pytest.raises(ValueError, match="Values top.*k=2 but allowed_topk=4"): dummy_comms.check_compressed_indices("param", packed_data, totalk, vals=vals) -def test_invalid_12bit_packed_multi_dim_wrong_topk(): +def test_invalid_rice_bitmap_multi_dim_wrong_topk(): """ - Test that 12-bit packed indices from multi-dimensional tensor with wrong last dimension + Test that Rice/bitmap encoded indices from multi-dimensional tensor with wrong last dimension raises ValueError indicating invalid topk dimension. """ dummy_comms = DummyComms() - totalk = 20 # allowed_topk = min(4, 20) = 4 + totalk = 128 # allowed_topk = min(4, 128) = 4 # Create a 2D tensor with last dimension size 6 (should be 4) invalid_indices = torch.tensor( [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]], dtype=torch.long ) - packed_data = pack_12bit_indices(invalid_indices) + idx, _ = encode_batch_rows(invalid_indices, C=totalk) vals = torch.randn(2, 6, dtype=torch.float32) # Wrong shape - should be (2, 4) - with pytest.raises(ValueError, match="Invalid topk dimension"): - dummy_comms.check_compressed_indices("param", packed_data, totalk, vals=vals) + with pytest.raises(ValueError, match="Values top.*k=6 but allowed_topk=4"): + dummy_comms.check_compressed_indices("param", idx, totalk, vals=vals) -# Removed test_invalid_12bit_packed_negative_index as pack_12bit_indices validates input +# Removed test_invalid_rice_bitmap_negative_index as encoder validates input -def test_invalid_12bit_packed_out_of_bounds(): +def test_invalid_rice_bitmap_out_of_bounds(): """ - Test that 12-bit packed indices with out-of-bounds values raise ValueError. + Test that Rice/bitmap encoded indices with out-of-bounds values raise ValueError. """ dummy_comms = DummyComms() - totalk = 10 # allowed_topk = min(4, 10) = 4 - # Index 10 is out-of-range because valid indices are 0 to 9. - invalid_indices = torch.tensor([0, 1, 10, 3], dtype=torch.long) - packed_data = pack_12bit_indices(invalid_indices) - vals = torch.randn(4, dtype=torch.float32) - with pytest.raises(ValueError, match="Index 10 out of bounds"): - dummy_comms.check_compressed_indices("param", packed_data, totalk, vals=vals) + totalk = 64 # allowed_topk = min(4, 64) = 4 + # Index 64 is out-of-range because valid indices are 0 to 63. + # But the encoder will fail before we can test - so let's test with valid encode but wrong totalk + invalid_indices = torch.tensor([[0, 1, 9, 3]], dtype=torch.long) + # Encode with a larger C to make it work + idx, _ = encode_batch_rows(invalid_indices, C=128, B_choices=(64)) + vals = torch.randn(1, 4, dtype=torch.float32) + # Now check with smaller totalk=64, so index 9 is valid but payload says C=128 + with pytest.raises(ValueError, match="Payload column size C=128 but expected 64"): + dummy_comms.check_compressed_indices("param", idx, totalk, vals=vals) # Removed test_invalid_flat_list_wrong_length - covered by test_invalid_not_packed_format -# Removed test_valid_single_value - not applicable to 12-bit packed format +# Removed test_valid_single_value - not applicable to Rice/bitmap encoded format -# Removed test_invalid_single_value_out_of_bounds - not applicable to 12-bit packed format +# Removed test_invalid_single_value_out_of_bounds - not applicable to Rice/bitmap encoded format -def test_override_allowed_topk_12bit(): +def test_override_allowed_topk_rice_bitmap(): """ - Test using the optional allowed_topk parameter with 12-bit packed format. + Test using the optional allowed_topk parameter with Rice/bitmap encoded format. """ dummy_comms = DummyComms() - totalk = 10 + totalk = 64 + B_choices = (64) # Override allowed_topk to 2. valid_indices = torch.tensor( - [0, 9], dtype=torch.long + [[0, 9]], dtype=torch.long ) # Correct length: 2 elements. - packed_data = pack_12bit_indices(valid_indices) - vals = torch.randn(2, dtype=torch.float32) + idx, _ = encode_batch_rows(valid_indices, C=totalk, B_choices=B_choices) + vals = torch.randn(1, 2, dtype=torch.float32) dummy_comms.check_compressed_indices( - "param", packed_data, totalk, allowed_topk=2, vals=vals + "param", idx, totalk, allowed_topk=2, vals=vals ) # Test with wrong topk invalid_indices = torch.tensor( - [0, 1, 2, 3], dtype=torch.long + [[0, 1, 2, 3]], dtype=torch.long ) # 4 elements instead of 2. - packed_data = pack_12bit_indices(invalid_indices) - vals = torch.randn(4, dtype=torch.float32) # Wrong shape for allowed_topk=2 - with pytest.raises(ValueError, match="Invalid topk dimension"): + idx, _ = encode_batch_rows(invalid_indices, C=totalk, B_choices=B_choices) + vals = torch.randn(1, 4, dtype=torch.float32) # Wrong shape for allowed_topk=2 + with pytest.raises(ValueError, match="Values top.*k=4 but allowed_topk=2"): dummy_comms.check_compressed_indices( - "param", packed_data, totalk, allowed_topk=2, vals=vals + "param", idx, totalk, allowed_topk=2, vals=vals ) -def test_topk_auto_adjust_when_totalk_is_lower_12bit(): +def test_topk_auto_adjust_when_totalk_is_lower_rice_bitmap(): """ - Test scenario where totalk is less than hparams.topk_compression with 12-bit packed format. + Test scenario where totalk is less than hparams.topk_compression with Rice/bitmap encoded format. """ dummy_comms = DummyComms() - totalk = 2 # Now allowed_topk becomes min(hparams.topk_compression, totalk) = min(4,2) = 2. + totalk = 64 # Now allowed_topk becomes min(hparams.topk_compression, totalk) = min(4,64) = 4. valid_indices = torch.tensor( - [0, 1], dtype=torch.long - ) # Valid: length matches allowed_topk (which is 2). - packed_data = pack_12bit_indices(valid_indices) - vals = torch.randn(2, dtype=torch.float32) - dummy_comms.check_compressed_indices("param", packed_data, totalk, vals=vals) - - # Note: Can't test with 1 element as pack_12bit_indices requires even number of indices - # Test with 4 elements (wrong topk) + [[0, 1, 2, 3]], dtype=torch.long + ) # Valid: length matches allowed_topk (which is 4). + idx, _ = encode_batch_rows(valid_indices, C=totalk) + vals = torch.randn(1, 4, dtype=torch.float32) + dummy_comms.check_compressed_indices("param", idx, totalk, vals=vals) + + # Note: Can't test with 1 element as encoder requires even number of indices + # Test with 6 elements (wrong topk) invalid_indices = torch.tensor( - [0, 1, 0, 1], dtype=torch.long - ) # 4 elements instead of 2. - packed_data = pack_12bit_indices(invalid_indices) - vals = torch.randn(4, dtype=torch.float32) # Wrong shape for allowed_topk=2 - with pytest.raises(ValueError, match="Invalid topk dimension"): - dummy_comms.check_compressed_indices("param", packed_data, totalk, vals=vals) + [[0, 1, 2, 3, 4, 5]], dtype=torch.long + ) # 6 elements instead of 4. + idx, _ = encode_batch_rows(invalid_indices, C=totalk) + vals = torch.randn(1, 6, dtype=torch.float32) # Wrong shape for allowed_topk=4 + with pytest.raises(ValueError, match="Values top.*k=6 but allowed_topk=4"): + dummy_comms.check_compressed_indices("param", idx, totalk, vals=vals) # Tests for `weighted_random_sample_no_replacement`