Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
730588e
wip
kasper-piskorski Nov 10, 2025
646c268
update compression tests and pass
kasper-piskorski Nov 11, 2025
6f6d2e5
sync idx and val
kasper-piskorski Nov 13, 2025
a17e1cb
cleanup
kasper-piskorski Nov 14, 2025
ddf7448
cleanup impl
kasper-piskorski Nov 14, 2025
a152b24
full gpu encoder plus optional delta
kasper-piskorski Nov 15, 2025
f156743
remove use_delta arg
kasper-piskorski Nov 17, 2025
f0a3531
layout v2 - faster decoding
kasper-piskorski Nov 17, 2025
3d406e1
track compression time
kasper-piskorski Nov 17, 2025
d61e830
trace compression time
kasper-piskorski Nov 17, 2025
42a7281
Merge branch 'main' of github.com:one-covenant/templar into feat/inde…
kasper-piskorski Nov 17, 2025
69a867a
add timings
kasper-piskorski Nov 18, 2025
4c81409
add fallback for batch decompress
kasper-piskorski Nov 18, 2025
4fc3e59
fix legacy path for decompress
kasper-piskorski Nov 19, 2025
de3ae06
adjust condition
kasper-piskorski Nov 19, 2025
37b240f
fix legacy condition
kasper-piskorski Nov 19, 2025
56c8d7a
added register buffering, fast unary decoding and simple header parsing
kasper-piskorski Nov 19, 2025
616e210
add offload time tracking
kasper-piskorski Nov 19, 2025
d2a05b9
stablish
kasper-piskorski Nov 19, 2025
628e020
Revert "stablish"
kasper-piskorski Nov 19, 2025
28d3ebc
Revert "add offload time tracking"
kasper-piskorski Nov 19, 2025
bb68d63
Revert "added register buffering, fast unary decoding and simple head…
kasper-piskorski Nov 19, 2025
0120eb5
make sure dst and gather idx are aligned
kasper-piskorski Nov 19, 2025
dba1fb6
v2
kasper-piskorski Nov 19, 2025
db49bb3
log exception
kasper-piskorski Nov 19, 2025
8a77ca1
cleanup
kasper-piskorski Nov 19, 2025
044c89b
remove GPU→CPU→GPU copies in decompress and batch_decompress
kasper-piskorski Nov 21, 2025
8577338
update decompression checks in comms and cleanup
kasper-piskorski Nov 21, 2025
436b2eb
make b choices a param
kasper-piskorski Nov 21, 2025
93ab9cc
fix linting errors
kasper-piskorski Nov 21, 2025
95b281a
revert validator
kasper-piskorski Nov 21, 2025
030c7f8
fix missing import
kasper-piskorski Nov 21, 2025
aa66e2d
fix formatting
kasper-piskorski Nov 21, 2025
74d092f
update comms test
kasper-piskorski Nov 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions hparams/hparams.json
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
"checkpoint_init_version": "2.1.15",
"checkpoint_init_window": 59637,
"num_evaluation_bins": 5,
"compression_b_choices": [32, 64, 128],
"quantization_bins": 4,
"quantization_range": 6,
"burn_rate": 0.5,
Expand Down
1 change: 1 addition & 0 deletions neurons/miner.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def __init__(self):
use_quantization=True,
quantization_bins=self.hparams.quantization_bins,
quantization_range=self.hparams.quantization_range,
b_choices=self.hparams.compression_b_choices,
)
tplr.logger.info("[Init] compression pipeline ready")

Expand Down
1 change: 1 addition & 0 deletions neurons/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ def __init__(self):
use_quantization=True,
quantization_bins=self.hparams.quantization_bins,
quantization_range=self.hparams.quantization_range,
b_choices=self.hparams.compression_b_choices,
)

# Init optimizer
Expand Down
1 change: 1 addition & 0 deletions scripts/analyser.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(self):
use_quantization=True,
quantization_bins=self.hparams.quantization_bins,
quantization_range=self.hparams.quantization_range,
b_choices=self.hparams.compression_b_choices,
)

# Initialize shapes for each parameter (like in miner/validator)
Expand Down
84 changes: 56 additions & 28 deletions src/tplr/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import bittensor as bt
import boto3
import botocore
import numpy as np
import torch
from aiobotocore.client import AioBaseClient
from aiobotocore.session import get_session
Expand All @@ -43,7 +44,8 @@

import tplr
from tplr.chain import ChainManager
from tplr.compress import TopKCompressor, unpack_12bit_indices
from tplr.compress import TopKCompressor
from tplr.compression import decode_batch_rows, unpack_12bit_indices
from tplr.config import BUCKET_SECRETS, client_config
from tplr.schemas import Bucket, CommsGetResult

Expand Down Expand Up @@ -2478,10 +2480,8 @@ def check_compressed_indices(
"""
Validates the integrity and format of compressed gradient indices.

This is a crucial security and stability check to ensure that gradients
received from peers are well-formed. It verifies that indices are within
the expected bounds and that the compression format (e.g., 12-bit packing)
is correctly applied.
This ensures indices are within bounds and that the new Rice/bitmap
codec payload matches the provided values tensor shape (top-k).

Args:
param_name (str): The name of the parameter being checked.
Expand All @@ -2494,52 +2494,80 @@ def check_compressed_indices(

Raises:
ValueError: If any validation check fails, such as out-of-bounds
indices, incorrect data types, or malformed packed data.
indices, incorrect data types, or malformed payload.
"""
allowed_topk = (
min(self.hparams.topk_compression, totalk)
if allowed_topk is None
else min(allowed_topk, totalk)
)

def _bounds_check(t: torch.Tensor):
"""fast min/max bounds check"""
if t.numel() == 0:
raise ValueError(f"[{param_name}] empty index list")
if t.min().item() < 0 or t.max().item() >= totalk:
bad = t[(t < 0) | (t >= totalk)][0].item()
if not isinstance(idxs, torch.Tensor):
raise ValueError(
f"[{param_name}] Expected tensor for indices, got {type(idxs)}"
)
if vals is None:
raise ValueError(
f"[{param_name}] Values tensor required for index validation"
)
if idxs.dtype != torch.uint8:
raise ValueError(
f"[{param_name}] Expected uint8 (Rice/bitmap payload), got {idxs.dtype}"
)
if idxs.numel() == 0:
raise ValueError(f"[{param_name}] Empty indices payload")

try:
rows, C, N = decode_batch_rows(idxs)
if C != totalk:
raise ValueError(
f"[{param_name}] Index {bad} out of bounds (totalk = {totalk})"
f"[{param_name}] Payload column size C={C} but expected {totalk}"
)

# Handle 12-bit packed index format only
if isinstance(idxs, torch.Tensor):
if idxs.dtype != torch.uint8:
# compute expected rows from values shape (flatten all but last dim)
if vals.ndim == 0:
raise ValueError(f"[{param_name}] Values tensor has no top‑k dimension")
expected_rows = int(np.prod(vals.shape[:-1])) if vals.ndim > 1 else 1
if N != expected_rows:
raise ValueError(
f"[{param_name}] Expected uint8 for 12-bit packed indices, got {idxs.dtype}"
f"[{param_name}] Payload rows N={N} but values imply {expected_rows}"
)
# 12-bit packed format is the only supported format
if vals is None:

k = vals.shape[-1]
if k != allowed_topk:
raise ValueError(
f"[{param_name}] Values tensor required to validate 12-bit packed indices"
f"[{param_name}] Payload K={rows.shape[-1]} but values top-k={k}"
)
if idxs.numel() == 0:
raise ValueError(f"[{param_name}] Empty packed indices tensor")

# Unpack using the values shape
# Bounds check
if rows.numel() > 0:
min_idx = int(rows.min().item())
max_idx = int(rows.max().item())
else:
min_idx, max_idx = 0, -1
if min_idx < 0 or max_idx >= totalk:
raise ValueError(
f"[{param_name}] Index out of bounds (min={min_idx}, max={max_idx}, totalk={totalk})"
)
except ValueError as e:
# NB: legacy path
tplr.logger.warning(
f"Failed to unpack: {e} Falling back to legacy uncompress."
)
# Fallback: likely old format -> try legacy decoder
try:
unpacked = unpack_12bit_indices(idxs, vals.shape)
# Validate that the last dimension matches allowed_topk
if unpacked.shape[-1] != allowed_topk:
raise ValueError(
f"[{param_name}] Invalid topk dimension: "
f"shape[-1]={unpacked.shape[-1]} but expected {allowed_topk}"
)
_bounds_check(unpacked)
except Exception as e:
raise ValueError(f"[{param_name}] Failed to unpack 12-bit indices: {e}")
else:
raise ValueError(f"[{param_name}] Expected tensor but got {type(idxs)}")
raise ValueError(
f"[{param_name}] Failed to unpack legacy 12-bit indices: {e}"
)
except Exception as e:
raise ValueError(f"[{param_name}] Failed to decode indices payload: {e}")

async def s3_get_object_size(self, bucket: Bucket, key: str) -> int | None:
"""
Expand Down
Loading
Loading