Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions hparams/hparams.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
"blocks_per_window": 65,
"windows_per_weights": 5,
"momentum_decay": 0.95,
"topk_compression": 32,
"topk_compression": 128,
"target_chunk": 64,
"use_dct": false,
"binary_score_ma_alpha": 0.05,
"moving_average_window": 5,
"model_size": "70B",
Expand Down
1 change: 1 addition & 0 deletions neurons/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import bittensor as bt
import torch
import torch.distributed as dist
from lm_eval import simple_evaluate
from torch.cuda import device_count as _cuda_device_count
from torch.utils.data import DataLoader
from torchtitan.components.loss import cross_entropy_loss
Expand Down
3 changes: 1 addition & 2 deletions neurons/miner.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,7 @@ def __init__(self):
)

enc = self.transformer.encode(
torch.empty(p.shape, dtype=torch.float16, device=self.device),
use_dct=self.hparams.use_dct,
torch.empty(p.shape, dtype=torch.float16, device=self.device)
)
_, _, xshape, totalk, _ = self.compressor.compress(
enc,
Expand Down
1 change: 0 additions & 1 deletion neurons/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,6 @@ def outer_step(self, gather_result):
device=str(self.device),
is_master=self.is_master,
world_size=self.world_size,
use_dct=self.hparams.use_dct,
)
return

Expand Down
32 changes: 26 additions & 6 deletions neurons/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,18 +259,41 @@ def __init__(self):
self.xshapes = {}
self.totalks = {}
# Use bare_model like the miner does to ensure consistent parameter iteration
import time
total_compress_time = 0.0
total_encode_time = 0.0

# Enable debug timing in compressor
self.compressor._debug_timing = True

for n, p in self.model.named_parameters():
# Use the same approach as miner for creating xshapes and totalks
encode_start = time.time()
enc = self.transformer.encode(
torch.empty(p.shape, dtype=torch.float16, device=self.device),
use_dct=self.hparams.use_dct,
torch.empty(p.shape, dtype=torch.float16, device=self.device)
)
Comment on lines 272 to 274
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

API update to drop use_dct — LGTM.

ChunkingTransformer.encode/decode now called without use_dct; matches the new signatures.

Run to confirm no stale use_dct remains:

Also applies to: 2805-2810


🏁 Script executed:

#!/bin/bash
rg -nP --hidden --glob '!*site-packages*' '\buse_dct\b|outer_step\s*\([^)]*use_dct' -C2

Length of output: 390


Remove stale use_dct reference in test config

  • In scripts/local_miner_test.py (line 179), remove or update the "use_dct": False entry so tests align with the updated ChunkingTransformer API.
🤖 Prompt for AI Agents
In neurons/validator.py around lines 264 to 266 (and the related test at
scripts/local_miner_test.py line 179), there's a stale test config entry
"use_dct": False that no longer matches the updated ChunkingTransformer API;
remove the "use_dct" key from the test config (or replace it with the current
API-equivalent option if the new transformer expects an explicit setting), then
run the tests and adjust any assertions that depended on the previous behavior
so they reflect the transformer’s current chunking behavior.

encode_time = time.time() - encode_start

compress_start = time.time()
_, _, xshape, totalk, _ = self.compressor.compress(
enc,
self.hparams.topk_compression,
)
compress_time = time.time() - compress_start

self.xshapes[n] = xshape
self.totalks[n] = totalk

total_encode_time += encode_time
total_compress_time += compress_time

# Log timing for each layer
tplr.logger.info(f"[COMPRESS TIMING] {n}: encode={encode_time:.3f}s, compress={compress_time:.3f}s, shape={p.shape}")

tplr.logger.info(f"[COMPRESS TIMING TOTAL] encode={total_encode_time:.3f}s, compress={total_compress_time:.3f}s")

# Disable debug timing after initialization
self.compressor._debug_timing = False

self.openskill_model = PlackettLuce(
beta=self.hparams.openskill_beta, tau=self.hparams.openskill_tau
Expand Down Expand Up @@ -1698,7 +1721,6 @@ async def run(self):
device=cast(str, self.device),
is_master=self.is_master,
world_size=self.world_size,
use_dct=self.hparams.use_dct,
wandb_run=self.wandb if self.is_master else None,
global_step=self.global_step,
)
Expand Down Expand Up @@ -2804,9 +2826,7 @@ def update_model_with_gradient(
quant_params,
)

full_grad_src = self.transformer.decode(
decompressed, use_dct=self.hparams.use_dct
)
full_grad_src = self.transformer.decode(decompressed)
# Single conversion to target dtype+device to avoid extra temporaries
full_grad_src = full_grad_src.to(
dtype=p.dtype, device=p.device, non_blocking=True
Expand Down
103 changes: 59 additions & 44 deletions src/tplr/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import bittensor as bt
import boto3
import botocore
import numpy as np
import torch
import torch.distributed as dist
from aiobotocore.client import AioBaseClient
Expand All @@ -48,7 +49,7 @@

import tplr
from tplr.chain import ChainManager
from tplr.compress import TopKCompressor, unpack_12bit_indices
from tplr.compress import TopKCompressor, decode_batch_rows
from tplr.config import BUCKET_SECRETS, client_config
from tplr.schemas import Bucket, CommsGetResult

Expand Down Expand Up @@ -2622,68 +2623,82 @@ def check_compressed_indices(
"""
Validates the integrity and format of compressed gradient indices.

This is a crucial security and stability check to ensure that gradients
received from peers are well-formed. It verifies that indices are within
the expected bounds and that the compression format (e.g., 12-bit packing)
is correctly applied.
This ensures indices are within bounds and that the **new Rice/bitmap**
codec payload matches the provided values tensor shape (top‑k).

Args:
param_name (str): The name of the parameter being checked.
idxs (torch.Tensor): The tensor of indices.
totalk (int): The total number of elements in the original uncompressed tensor.
allowed_topk (int | None, optional): The expected number of top-k values.
Defaults to the hparams configuration.
vals (torch.Tensor | None, optional): The corresponding values tensor,
required for validating 12-bit packed indices. Defaults to None.
vals (torch.Tensor | None, optional): The corresponding values tensor.

Raises:
ValueError: If any validation check fails, such as out-of-bounds
indices, incorrect data types, or malformed packed data.
indices, incorrect data types, or malformed payload.
"""
allowed_topk = (
min(self.hparams.topk_compression, totalk)
if allowed_topk is None
else min(allowed_topk, totalk)
)

def _bounds_check(t: torch.Tensor):
"""fast min/max bounds check"""
if t.numel() == 0:
raise ValueError(f"[{param_name}] empty index list")
if t.min().item() < 0 or t.max().item() >= totalk:
bad = t[(t < 0) | (t >= totalk)][0].item()
raise ValueError(
f"[{param_name}] Index {bad} out of bounds (totalk = {totalk})"
)
if not isinstance(idxs, torch.Tensor):
raise ValueError(
f"[{param_name}] Expected tensor for indices, got {type(idxs)}"
)
if vals is None:
raise ValueError(
f"[{param_name}] Values tensor required for index validation"
)
if idxs.dtype != torch.uint8:
raise ValueError(
f"[{param_name}] Expected uint8 (Rice/bitmap payload), got {idxs.dtype}"
)
if idxs.numel() == 0:
raise ValueError(f"[{param_name}] Empty indices payload")

# Handle 12-bit packed index format only
if isinstance(idxs, torch.Tensor):
if idxs.dtype != torch.uint8:
raise ValueError(
f"[{param_name}] Expected uint8 for 12-bit packed indices, got {idxs.dtype}"
)
# 12-bit packed format is the only supported format
if vals is None:
raise ValueError(
f"[{param_name}] Values tensor required to validate 12-bit packed indices"
)
if idxs.numel() == 0:
raise ValueError(f"[{param_name}] Empty packed indices tensor")
# Decode (CPU) and perform structural checks
try:
payload_bytes = idxs.detach().cpu().numpy().tobytes()
rows_list, C, N = decode_batch_rows(payload_bytes)
except Exception as e:
raise ValueError(f"[{param_name}] Failed to decode indices payload: {e}")

# Unpack using the values shape
try:
unpacked = unpack_12bit_indices(idxs, vals.shape)
# Validate that the last dimension matches allowed_topk
if unpacked.shape[-1] != allowed_topk:
raise ValueError(
f"[{param_name}] Invalid topk dimension: "
f"shape[-1]={unpacked.shape[-1]} but expected {allowed_topk}"
)
_bounds_check(unpacked)
except Exception as e:
raise ValueError(f"[{param_name}] Failed to unpack 12-bit indices: {e}")
else:
raise ValueError(f"[{param_name}] Expected tensor but got {type(idxs)}")
if C != totalk:
raise ValueError(
f"[{param_name}] Payload column size C={C} but expected {totalk}"
)

# compute expected rows from values shape (flatten all but last dim)
if vals.ndim == 0:
raise ValueError(f"[{param_name}] Values tensor has no top‑k dimension")
expected_rows = int(np.prod(vals.shape[:-1])) if vals.ndim > 1 else 1
if N != expected_rows:
raise ValueError(
f"[{param_name}] Payload rows N={N} but values imply {expected_rows}"
)

k = vals.shape[-1]
if k != allowed_topk:
raise ValueError(
f"[{param_name}] Values top‑k={k} but allowed_topk={allowed_topk}"
)
if any(len(r) != k for r in rows_list):
raise ValueError(
f"[{param_name}] At least one row has mismatched top‑k size"
)

# bounds check without materialising full tensor
max_idx = max((max(r) if len(r) > 0 else -1) for r in rows_list)
min_idx = (
min((min(r) if len(r) > 0 else 0) for r in rows_list) if rows_list else 0
)
if min_idx < 0 or max_idx >= totalk:
raise ValueError(
f"[{param_name}] Index out of bounds (min={min_idx}, max={max_idx}, totalk={totalk})"
)

Comment on lines +2626 to 2702
Copy link
Copy Markdown

@coderabbitai coderabbitai bot Aug 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

check_compressed_indices rejects legacy/int64 indices; add backward‑compat path.

Today this function hard‑fails unless idxs is uint8 (Rice/bitmap). Older miners or historical payloads may still send int64 index tensors. Decompression paths already handle both; validation should too to avoid skipping otherwise valid peers.

Apply:

@@
-        if idxs.dtype != torch.uint8:
-            raise ValueError(
-                f"[{param_name}] Expected uint8 (Rice/bitmap payload), got {idxs.dtype}"
-            )
+        if idxs.dtype == torch.uint8:
+            pass  # Rice/bitmap fast-path below
+        elif idxs.dtype in (torch.int64, torch.long):
+            # Legacy/unpacked path: shape and bounds checks only
+            if vals is None:
+                raise ValueError(f"[{param_name}] Values tensor required for index validation")
+            if tuple(idxs.shape) != tuple(vals.shape):
+                raise ValueError(f"[{param_name}] Indices shape {tuple(idxs.shape)} != values shape {tuple(vals.shape)}")
+            k = vals.shape[-1] if vals.ndim > 0 else 0
+            if k != allowed_topk:
+                raise ValueError(f"[{param_name}] Values top‑k={k} but allowed_topk={allowed_topk}")
+            if idxs.numel() > 0:
+                i_min = int(idxs.min().item())
+                i_max = int(idxs.max().item())
+                if i_min < 0 or i_max >= totalk:
+                    raise ValueError(f"[{param_name}] Index out of bounds (min={i_min}, max={i_max}, totalk={totalk})")
+            return  # legacy path validated
+        else:
+            raise ValueError(
+                f"[{param_name}] Expected uint8 (Rice/bitmap payload), got {idxs.dtype}"
+            )
@@
-            rows_list, C, N = decode_batch_rows(payload_bytes)
+            rows_list, C, N = decode_batch_rows(payload_bytes)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
This ensures indices are within bounds and that the **new Rice/bitmap**
codec payload matches the provided values tensor shape (top-k).
Args:
param_name (str): The name of the parameter being checked.
idxs (torch.Tensor): The tensor of indices.
totalk (int): The total number of elements in the original uncompressed tensor.
allowed_topk (int | None, optional): The expected number of top-k values.
Defaults to the hparams configuration.
vals (torch.Tensor | None, optional): The corresponding values tensor,
required for validating 12-bit packed indices. Defaults to None.
vals (torch.Tensor | None, optional): The corresponding values tensor.
Raises:
ValueError: If any validation check fails, such as out-of-bounds
indices, incorrect data types, or malformed packed data.
indices, incorrect data types, or malformed payload.
"""
allowed_topk = (
min(self.hparams.topk_compression, totalk)
if allowed_topk is None
else min(allowed_topk, totalk)
)
def _bounds_check(t: torch.Tensor):
"""fast min/max bounds check"""
if t.numel() == 0:
raise ValueError(f"[{param_name}] empty index list")
if t.min().item() < 0 or t.max().item() >= totalk:
bad = t[(t < 0) | (t >= totalk)][0].item()
raise ValueError(
f"[{param_name}] Index {bad} out of bounds (totalk = {totalk})"
)
if not isinstance(idxs, torch.Tensor):
raise ValueError(
f"[{param_name}] Expected tensor for indices, got {type(idxs)}"
)
if vals is None:
raise ValueError(
f"[{param_name}] Values tensor required for index validation"
)
if idxs.dtype != torch.uint8:
raise ValueError(
f"[{param_name}] Expected uint8 (Rice/bitmap payload), got {idxs.dtype}"
)
if idxs.numel() == 0:
raise ValueError(f"[{param_name}] Empty indices payload")
# Handle 12-bit packed index format only
if isinstance(idxs, torch.Tensor):
if idxs.dtype != torch.uint8:
raise ValueError(
f"[{param_name}] Expected uint8 for 12-bit packed indices, got {idxs.dtype}"
)
# 12-bit packed format is the only supported format
if vals is None:
raise ValueError(
f"[{param_name}] Values tensor required to validate 12-bit packed indices"
)
if idxs.numel() == 0:
raise ValueError(f"[{param_name}] Empty packed indices tensor")
# Decode (CPU) and perform structural checks
try:
payload_bytes = idxs.detach().cpu().numpy().tobytes()
rows_list, C, N = decode_batch_rows(payload_bytes)
except Exception as e:
raise ValueError(f"[{param_name}] Failed to decode indices payload: {e}")
# Unpack using the values shape
try:
unpacked = unpack_12bit_indices(idxs, vals.shape)
# Validate that the last dimension matches allowed_topk
if unpacked.shape[-1] != allowed_topk:
raise ValueError(
f"[{param_name}] Invalid topk dimension: "
f"shape[-1]={unpacked.shape[-1]} but expected {allowed_topk}"
)
_bounds_check(unpacked)
except Exception as e:
raise ValueError(f"[{param_name}] Failed to unpack 12-bit indices: {e}")
else:
raise ValueError(f"[{param_name}] Expected tensor but got {type(idxs)}")
if C != totalk:
raise ValueError(
f"[{param_name}] Payload column size C={C} but expected {totalk}"
)
# compute expected rows from values shape (flatten all but last dim)
if vals.ndim == 0:
raise ValueError(f"[{param_name}] Values tensor has no top‑k dimension")
expected_rows = int(np.prod(vals.shape[:-1])) if vals.ndim > 1 else 1
if N != expected_rows:
raise ValueError(
f"[{param_name}] Payload rows N={N} but values imply {expected_rows}"
)
k = vals.shape[-1]
if k != allowed_topk:
raise ValueError(
f"[{param_name}] Values top‑k={k} but allowed_topk={allowed_topk}"
)
if any(len(r) != k for r in rows_list):
raise ValueError(
f"[{param_name}] At least one row has mismatched top‑k size"
)
# bounds check without materialising full tensor
max_idx = max((max(r) if len(r) > 0 else -1) for r in rows_list)
min_idx = (
min((min(r) if len(r) > 0 else 0) for r in rows_list) if rows_list else 0
)
if min_idx < 0 or max_idx >= totalk:
raise ValueError(
f"[{param_name}] Index out of bounds (min={min_idx}, max={max_idx}, totalk={totalk})"
)
if not isinstance(idxs, torch.Tensor):
raise ValueError(
f"[{param_name}] Expected tensor for indices, got {type(idxs)}"
)
if vals is None:
raise ValueError(
f"[{param_name}] Values tensor required for index validation"
)
if idxs.dtype == torch.uint8:
pass # Rice/bitmap fast-path below
elif idxs.dtype in (torch.int64, torch.long):
# Legacy/unpacked path: shape and bounds checks only
if vals is None:
raise ValueError(f"[{param_name}] Values tensor required for index validation")
if tuple(idxs.shape) != tuple(vals.shape):
raise ValueError(f"[{param_name}] Indices shape {tuple(idxs.shape)} != values shape {tuple(vals.shape)}")
k = vals.shape[-1] if vals.ndim > 0 else 0
if k != allowed_topk:
raise ValueError(f"[{param_name}] Values top-k={k} but allowed_topk={allowed_topk}")
if idxs.numel() > 0:
i_min = int(idxs.min().item())
i_max = int(idxs.max().item())
if i_min < 0 or i_max >= totalk:
raise ValueError(f"[{param_name}] Index out of bounds (min={i_min}, max={i_max}, totalk={totalk})")
return # legacy path validated
else:
raise ValueError(
f"[{param_name}] Expected uint8 (Rice/bitmap payload), got {idxs.dtype}"
)
if idxs.numel() == 0:
raise ValueError(f"[{param_name}] Empty indices payload")
# Decode (CPU) and perform structural checks
try:
payload_bytes = idxs.detach().cpu().numpy().tobytes()
rows_list, C, N = decode_batch_rows(payload_bytes)
except Exception as e:
raise ValueError(f"[{param_name}] Failed to decode indices payload: {e}")
🤖 Prompt for AI Agents
In src/tplr/comms.py around lines 2626 to 2702, the validation currently rejects
idxs unless dtype is torch.uint8; update it to accept legacy torch.int64 index
tensors as well by adding a backward-compat branch: if idxs.dtype is
torch.int64, skip the Rice/bitmap decode and instead interpret the tensor as raw
integer indices (reshape/flatten consistent with vals shape and top-k), build
rows_list from those indices, compute C, N, min/max bounds from that legacy
format, and run the same consistency checks (C vs totalk, N vs expected_rows, k
vs allowed_topk, per-row length, bounds). Keep the existing uint8 path unchanged
and only raise for other dtypes, ensuring all error messages remain
param-prefixed.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't want the backwards compability

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@joellidin, understood. I'll respect your decision not to include backward compatibility for the legacy index formats. The current implementation correctly enforces that all peers must use the new Rice/bitmap encoding format (uint8 payload), which keeps the validation logic clean and ensures consistency across the network.


✏️ Learnings added
Learnt from: joellidin
PR: tplr-ai/templar#553
File: src/tplr/comms.py:2626-2702
Timestamp: 2025-08-31T15:02:03.568Z
Learning: The user joellidin has explicitly decided not to implement backward compatibility for legacy int64 index tensors in the check_compressed_indices function in src/tplr/comms.py. The function should only accept uint8 Rice/bitmap payload format and reject other formats to enforce network-wide migration to the new compression scheme.

Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.

Comment on lines +2626 to 2702
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Tighten validation: int-cast topk, guard vals type, drop numpy by using math.prod.

Prevents surprising float topk from misconfigured hparams, avoids attribute errors if vals isn’t a tensor, and removes an unnecessary dependency.

Apply:

- import numpy as np
+ # (numpy no longer needed here)

@@
-        allowed_topk = (
-            min(self.hparams.topk_compression, totalk)
-            if allowed_topk is None
-            else min(allowed_topk, totalk)
-        )
+        allowed_topk = (
+            min(self.hparams.topk_compression, totalk)
+            if allowed_topk is None
+            else min(allowed_topk, totalk)
+        )
+        try:
+            allowed_topk = int(allowed_topk)
+        except Exception:
+            raise ValueError(f"[{param_name}] allowed_topk must be an integer-compatible value")

@@
-        if vals is None:
+        if vals is None:
             raise ValueError(
                 f"[{param_name}] Values tensor required for index validation"
             )
+        if not torch.is_tensor(vals):
+            raise ValueError(f"[{param_name}] Values must be a torch.Tensor, got {type(vals)}")

@@
-        expected_rows = int(np.prod(vals.shape[:-1])) if vals.ndim > 1 else 1
+        expected_rows = int(math.prod(vals.shape[:-1])) if vals.ndim > 1 else 1

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/tplr/comms.py around lines 2626 to 2702, tighten validation by (1)
enforcing vals is a torch.Tensor before using .ndim/.shape (raise a ValueError
if not), (2) casting allowed_topk/k to int to avoid float surprises, and (3)
replace numpy.prod usage with math.prod over vals.shape[:-1] (import math) to
remove the numpy dependency; keep all existing structural and bounds checks but
perform the type guard early and use int(...) when comparing top-k values.

async def s3_get_object_size(self, bucket: Bucket, key: str) -> int | None:
"""
Expand Down
33 changes: 33 additions & 0 deletions src/tplr/compress/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# The MIT License (MIT)
# © 2025 tplr.ai
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
# the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

from .bits import (
decode_batch_rows, # decoder (CPU)
encode_batch_rows, # GPU-accelerated encoder → bytes + perm + meta
)
from .pack12 import pack_12bit_indices, unpack_12bit_indices # legacy
from .topk import ChunkingTransformer, TopKCompressor

__all__ = [
# High level
"TopKCompressor",
"ChunkingTransformer",
"encode_batch_rows",
"decode_batch_rows",
"pack_12bit_indices",
"unpack_12bit_indices",
]
Loading
Loading