Conversation
Reorganize the compress module into a proper package structure to improve maintainability and separate concerns.
- Move compress.py to compress/topk.py for TopK-specific functionality
- Create compress/__init__.py with clean public API exports
- Extract 12-bit packing functions to compress/pack12.py
- Update test imports to use new package structure
- Maintain backward compatibility through __init__.py exports
Implement new GPU-accelerated compression codec for gradient indices using Rice coding with adaptive parameters and bitmap encoding for dense blocks.
- Add BitWriter/BitReader classes for efficient bit-level I/O
- Implement encode_batch_rows with GPU-accelerated path and CPU fallback
- Add decode_batch_rows for CPU-based decompression
- Support adaptive Rice parameters and bitmap vs local encoding
- Export new codec functions in compress package __init__.py
- Maintain backward compatibility with existing pack12/topk modules
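As background, Rice coding splits each non-negative integer by a parameter k into a unary-coded quotient and a k-bit binary remainder. A minimal pure-Python sketch (illustrative only, not the repository's bit-level implementation):

```python
def rice_encode(x: int, k: int) -> str:
    """Encode non-negative x as unary(quotient) + k-bit remainder."""
    q, r = x >> k, x & ((1 << k) - 1)
    rem = format(r, f"0{k}b") if k > 0 else ""
    return "1" * q + "0" + rem  # q ones, a terminating zero, then remainder

def rice_decode(bits: str, k: int) -> int:
    """Invert rice_encode: count leading ones, then read k remainder bits."""
    q = bits.index("0")
    r = int(bits[q + 1 : q + 1 + k], 2) if k > 0 else 0
    return (q << k) | r
```

Small values (e.g., gaps between sorted indices) yield short codes, which is why adapting k to the observed distribution pays off.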
Add extensive test coverage for the bits compression codec including round-trip encode/decode validation, permutation tracking, and cross-device compatibility testing.
- Test round-trip operations ensuring data integrity
- Verify permutation tracking for correct value reordering
- Cover edge cases: zero rows, zero K values, empty tensors
- Test CPU and GPU device parity when CUDA available
- Validate bitmap vs local encoding path selection
- Ensure compatibility with C values divisible by B_choices
- Test both GPU-accelerated and CPU reference implementations
- Verify codec handles unique indices from TopK operations
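The bitmap-vs-local path selection these tests exercise can be illustrated with a toy per-sub-block encoder (hypothetical helper names; the real codec works at the bit level): a width-B block stores either a B-bit occupancy bitmap or explicit ceil(log2 B)-bit offsets, whichever is smaller.

```python
import math

def encode_block(locs: list[int], B: int) -> tuple[str, str]:
    """Pick bitmap vs local-offset encoding for one sub-block of width B."""
    lb = max(1, math.ceil(math.log2(B)))
    if B <= len(locs) * lb:  # dense: the bitmap is no larger
        mask = 0
        for loc in locs:
            mask |= 1 << loc
        return "bitmap", format(mask, f"0{B}b")
    return "local", "".join(format(loc, f"0{lb}b") for loc in locs)

def decode_block(kind: str, bits: str, B: int, s: int) -> list[int]:
    """Invert encode_block; s is the number of stored offsets."""
    lb = max(1, math.ceil(math.log2(B)))
    if kind == "bitmap":
        mask = int(bits, 2)
        return [loc for loc in range(B) if (mask >> loc) & 1]
    return sorted(int(bits[i * lb : (i + 1) * lb], 2) for i in range(s))
```

A round-trip property test over random sparse and dense blocks, like the ones described above, catches most path-selection regressions.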
Walkthrough

Replaces 12‑bit packed indices with a Rice/bitmap bits codec and GPU encoder/CPU decoder, centralizes compress exports, removes …

Changes
Sequence Diagram(s)

sequenceDiagram
autonumber
participant Outer as outer_step / Trainer
participant Transformer
participant Compressor
participant Comms
participant Peer
rect rgb(240,248,255)
note over Outer: Outer step (no use_dct)
Outer->>Transformer: encode(tensor)
Transformer-->>Outer: encoded_chunks
Outer->>Compressor: compress(encoded_chunks, K)
Compressor-->>Outer: (idx_bytes, perm2d, vals, xshape, totalk[, qparams])
end
rect rgb(245,255,240)
note over Outer,Comms: transmit & validate
Outer->>Comms: check_compressed_indices(idxs=uint8_payload, vals, totalk, k)
Comms->>Comms: rows_list, C, N = decode_batch_rows(payload_bytes)
Comms-->>Outer: validation_result / ValueError
Outer-->>Peer: send payload bytes + vals
end
rect rgb(255,248,240)
note over Peer: decompress & reconstruct
Peer->>Compressor: batch_decompress(payload_bytes, vals, ...)
Compressor->>Transformer: decode(decompressed)
Transformer-->>Peer: reconstructed_tensor
end
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
Codecov Report

❌ Patch coverage is …
```diff
@@            Coverage Diff             @@
##              dev     #553      +/-   ##
==========================================
+ Coverage   61.07%   62.92%    +1.85%
==========================================
  Files          25       28        +3
  Lines        4056     4340      +284
==========================================
+ Hits         2477     2731      +254
- Misses       1579     1609       +30
```
Actionable comments posted: 7
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (8)
tests/test_comms.py (2)
779-791: Fix async client context in s3_put_small_file test (missing `__aenter__`/`__aexit__`).
Currently `new_client.__aenter__()` returns a fresh AsyncMock without `put_object`, causing AttributeError.
Apply:
```diff
-        mock_client = AsyncMock()
-        mock_client.put_object = AsyncMock()
+        mock_client = AsyncMock()
+        mock_client.put_object = AsyncMock()
+        # ensure _get_s3_client awaits __aenter__ and gets this same client
+        mock_client.__aenter__.return_value = mock_client
+        mock_client.__aexit__.return_value = AsyncMock()
         comms_instance.session.create_client = MagicMock(return_value=mock_client)
```
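For reference, here is why the fix is needed: by default an `AsyncMock` used as an async context manager yields a fresh child mock from `__aenter__`, not the configured client. A minimal stdlib sketch:

```python
import asyncio
from unittest.mock import AsyncMock

async def demo() -> tuple[bool, bool]:
    client = AsyncMock()
    client.put_object = AsyncMock(return_value="ok")

    # Default: __aenter__ returns a different child mock, so the
    # configured put_object is not what code under test will see.
    async with client as entered:
        default_same = entered is client

    # Fix: route __aenter__ back to the configured mock.
    client.__aenter__.return_value = client
    async with client as entered:
        fixed_same = entered is client
    return default_same, fixed_same
```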
809-817: Same async context fix for multipart upload test.
Without `__aenter__`/`__aexit__`, `create_multipart_upload` etc. may not be found on the awaited client.
Apply:
```diff
-        mock_client = AsyncMock()
+        mock_client = AsyncMock()
         mock_client.create_multipart_upload = AsyncMock(
             return_value={"UploadId": "test_id"}
         )
         mock_client.upload_part = AsyncMock(return_value={"ETag": "test_etag"})
         mock_client.complete_multipart_upload = AsyncMock()
         mock_client.abort_multipart_upload = AsyncMock()
+        mock_client.__aenter__.return_value = mock_client
+        mock_client.__aexit__.return_value = AsyncMock()
         comms_instance.session.create_client = MagicMock(return_value=mock_client)
```

src/tplr/neurons.py (2)
36-36: Stop importing 12‑bit unpacker; use Rice/bitmap decoder.
`check_uid_index_overlap` still assumes 12‑bit. Replace with `decode_batch_rows` to match the new codec; keep a fallback for legacy int64 tensors.
Apply:
```diff
-from tplr.compress import unpack_12bit_indices
+from tplr.compress import decode_batch_rows
```
952-964: Overlap detector decodes indices with 12‑bit unpacker (wrong format).
This will miscompute overlaps for Rice/bitmap payloads. Decode bytes via `decode_batch_rows` and reshape to vals shape; keep int64 fallback.
Apply:
```diff
-            # Unpack all 12-bit packed indices using values shape
+            # Decode indices (Rice/bitmap bytes → rows → reshape) or accept int64 legacy
             unpacked_indices = []
             for i in range(Ptot):
                 idx_data = idxs_all[i] if isinstance(idxs_all, list) else idxs_all
                 val_data = vals_all[i] if isinstance(vals_all, list) else vals_all
-
-                # 12-bit packed format - use values shape for unpacking
-                unpacked = unpack_12bit_indices(
-                    idx_data.to(neuron.config.device), val_data.shape
-                )
+                if isinstance(idx_data, torch.Tensor) and idx_data.dtype == torch.uint8:
+                    rows, _C, _N = decode_batch_rows(idx_data.detach().cpu().numpy().tobytes())
+                    if any(len(r) != val_data.shape[-1] for r in rows):
+                        raise ValueError("Row-wise topk size mismatch in index payload (overlap check)")
+                    unpacked = torch.tensor(rows, dtype=torch.int64, device=neuron.config.device).view(*val_data.shape)
+                elif isinstance(idx_data, torch.Tensor) and idx_data.dtype in (torch.int64, torch.long):
+                    unpacked = idx_data.to(neuron.config.device).view(*val_data.shape)
+                else:
+                    raise ValueError(f"Unsupported index dtype in overlap check: {getattr(idx_data, 'dtype', type(idx_data))}")
                 unpacked_indices.append(unpacked)
```

tests/unit/test_compress.py (2)
102-141: Same issue: each payload encodes 2 rows; xshape has 8 rows.
Both payloads must encode N=8 (or set xshape=(2, 64)).
Apply:
```diff
-    p = torch.zeros(8, 64)
-    xshape = (8, 64)
-    totalk = 64
-    idx1_orig = torch.tensor([[0, 1], [2, 3]], dtype=torch.int64)
-    idx2_orig = torch.tensor([[4, 5], [6, 7]], dtype=torch.int64)
+    p = torch.zeros(8, 64)
+    xshape = (8, 64)
+    totalk = 64
+    K = 2
+    idx1_orig = torch.stack(
+        [torch.tensor([0, 1], dtype=torch.int64) for _ in range(8)]
+    )
+    idx2_orig = torch.stack(
+        [torch.tensor([2, 3], dtype=torch.int64) for _ in range(8)]
+    )
@@
-    val1 = torch.tensor([[0.1, 0.2], [0.3, 0.4]], dtype=torch.float32)
-    val2 = torch.tensor([[0.5, 0.6], [0.7, 0.8]], dtype=torch.float32)
+    val1 = torch.randn(8, K, dtype=torch.float32)
+    val2 = torch.randn(8, K, dtype=torch.float32)
@@
-    val1 = torch.gather(val1, dim=1, index=perm1)
-    val2 = torch.gather(val2, dim=1, index=perm2)
+    val1 = torch.gather(val1, dim=1, index=perm1)
+    val2 = torch.gather(val2, dim=1, index=perm2)
```
205-241: Again: payload encodes 1 row; xshape has 8 rows.
Align N in payload with xshape’s row count.
Apply:
```diff
-    p = torch.zeros(8, 64)
-    xshape = (8, 64)
-    totalk = 64
-    idx_orig = torch.tensor([[0, 1, 2, 3]], dtype=torch.int64)
+    p = torch.zeros(8, 64)
+    xshape = (8, 64)
+    totalk = 64
+    K = 4
+    idx_orig = torch.stack(
+        [torch.arange(K, dtype=torch.int64) for _ in range(8)]
+    )
@@
-    val_orig = torch.tensor([[10.0, 20.0, 30.0, 40.0]], dtype=torch.float32)
-    val = [torch.gather(val_orig, dim=1, index=perm)]
+    val_orig = torch.full((8, K), 10.0, dtype=torch.float32)
+    val = [torch.gather(val_orig, dim=1, index=perm)]
```
430-471: Bug: boolean tests on tensors in `_quantize_values` will raise.
Using Python `or` with torch tensors is invalid. Also ensure scale is finite via tensor ops.
Apply:
```diff
-        scale = self.range_in_sigmas * std / self.n_bins
-        if (
-            isinstance(scale, torch.Tensor)
-            and (scale == 0 or torch.isnan(scale) or torch.isinf(scale))
-        ) or (
-            not isinstance(scale, torch.Tensor)
-            and (scale == 0 or not math.isfinite(float(scale)))
-        ):
-            scale = torch.tensor(1.0, dtype=centered.dtype, device=val.device)
+        scale_t = (self.range_in_sigmas * std) / float(self.n_bins)
+        if not isinstance(scale_t, torch.Tensor):
+            scale_t = torch.tensor(float(scale_t), dtype=centered.dtype, device=val.device)
+        if torch.isnan(scale_t) or torch.isinf(scale_t) or (scale_t == 0):
+            scale_t = torch.tensor(1.0, dtype=centered.dtype, device=val.device)
@@
-        centered_fp32 = centered.to(torch.float32)
-        qval = ((centered_fp32 / scale + offset).round().clamp(0, self.n_bins - 1)).to(
-            torch.uint8
-        )
+        centered_fp32 = centered.to(torch.float32)
+        qval = ((centered_fp32 / scale_t + offset).round().clamp(0, self.n_bins - 1)).to(
+            torch.uint8
+        )
@@
-        qparams: QuantParamsT = (shift, float(scale), int(offset), lookup, val.dtype)
+        qparams: QuantParamsT = (shift, float(scale_t.item()), int(offset), lookup, val.dtype)
```
494-545: `maybe_dequantize_values`: tuple/list detection is unsafe and can raise TypeError.
`len()` on elements of the tuple (e.g., float) will crash. Simplify to a straightforward normalisation.
Apply:
```diff
-    if (
-        isinstance(qparams, tuple)
-        and len(qparams) == 5  # potentially single or already 5 elements
-        and not all([len(q) == 5 for q in qparams])  # already correctly formatted
-    ):
-        qparams = [qparams]
-    if not isinstance(qparams, list):
-        qparams = [qparams]
+    if isinstance(qparams, tuple):
+        qparams = [qparams]
+    elif not isinstance(qparams, list):
+        qparams = [qparams]  # type: ignore[list-item]
```
🧹 Nitpick comments (19)
hparams/hparams.json (3)
13-15: Check semantic consistency: topk_compression (128) vs target_chunk (64).
If target_chunk bounds the number of active entries per chunk, having topk_compression > target_chunk could be inconsistent or wasteful. Validate intended relationship or add an assertion.
1-3: Consider bumping spec_version or adding a migration note.
Given the removal of use_dct across the codebase and the new codec, keeping spec_version at 5 may confuse consumers parsing configs. Either bump spec_version or document that spec v5 now implies Rice/bitmap-only.
13-13: Add tests and performance checks for `topk_compression=128`
- Verify the Rice/bitmap codec’s dynamic bit-length encoding correctly handles `k=128` (no hard-coded 32-value limits were found) by adding round-trip/property tests at `k=128` and extreme sparse/dense scenarios.
- Benchmark RPC/message payload sizes and per-batch memory overhead at `target_batch_size=1024` to quantify any increases versus `k=32`.
- Update or add tests that currently assume or reference `32` (e.g., in `tests/…`, `local_miner_test.py`) to cover the new default of `128`.

neurons/evaluator.py (1)
63-63: Remove unused import and avoid shelling out; also fix variable shadowing in result parsing.
- The added import `from lm_eval import simple_evaluate` is unused.
- Prefer `subprocess.run(..., check=True)` over `os.system` for better error handling.
- In `_process_results`, the loop variable `task_name` shadows the function argument; rename the inner variable.

Apply these diffs:
```diff
-    from lm_eval import simple_evaluate
+    # (optional) import only if switching to in-process evaluation
+    # from lm_eval import simple_evaluate
```

```diff
-        # TODO: Consider replacing os.system with subprocess.run for better control and error handling.
-        exit_code = os.system(command)
+        import subprocess
+        try:
+            completed = subprocess.run(command, shell=True, check=False)
+            exit_code = completed.returncode
+        except Exception as e:
+            tplr.logger.error(f"Failed to launch benchmark: {e}")
+            exit_code = 1
```

```diff
-        for task_name, task_results in results.get("results", {}).items():
+        for task, task_results in results.get("results", {}).items():
@@
-            if metric_value is not None:
+            if metric_value is not None:
                 tplr.logger.info(
-                    f"Benchmark for {task_name} ({used_metric}): {metric_value}"
+                    f"Benchmark for {task} ({used_metric}): {metric_value}"
                 )
                 self.metrics_logger.log(
                     measurement="benchmark_task",
                     tags={
-                        "task": task_name,
+                        "task": task,
```

Also applies to: 476-478, 538-556
src/tplr/compress/bits.py (4)
224-229: Validate header bounds (C, N) before writing.
Header allocates 12 bits for C-1 and 16 bits for N. Guard overflows early.
```diff
-        bw = BitWriter()
+        bw = BitWriter()
+        if C <= 0 or C > (1 << 12):
+            raise ValueError(f"C={C} out of range for 12-bit header (1..4096)")
+        if N > 0xFFFF:
+            raise ValueError(f"N={N} exceeds 16-bit row count limit ({N} > 65535)")
```
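The guard values follow directly from the header layout described in the comment (assuming, as stated there, that 12 bits store C−1 and 16 bits store N):

```python
# Header field widths imply these bounds (assumption: 12-bit C-1, 16-bit N,
# as described in the review comment above).
C_MAX = 1 << 12        # C - 1 fits in 12 bits, so C ranges over 1..4096
N_MAX = (1 << 16) - 1  # N fits in 16 bits, so N ranges over 0..65535
```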
268-279: Add the same header bounds check to the GPU path.
Mirror CPU encoder constraints to avoid producing undecodable payloads.
```diff
         rows, k = idx.shape
         device = idx.device
+        if C <= 0 or C > (1 << 12):
+            raise ValueError(f"C={C} out of range for 12-bit header (1..4096)")
+        if rows > 0xFFFF:
+            raise ValueError(f"rows={rows} exceeds 16-bit row count limit (65535)")
```
252-318: Avoid F.one_hot blow-up for large (rows, k, C/Bmin).
The one-hot path can allocate rows×k×(C/Bmin) temporarily. Consider using a bincount-based aggregation for all B values to reduce memory.
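The bincount-style alternative can be sketched in plain Python (hypothetical function; the real code would use `torch.bincount` on flattened (row, sub-block) ids rather than loops):

```python
def subblock_counts(rows: list[list[int]], C: int, B: int) -> list[list[int]]:
    """Per-row occupancy counts for each width-B sub-block, without
    materialising a rows × k × (C // B) one-hot intermediate."""
    n_blocks = C // B
    counts = [[0] * n_blocks for _ in rows]
    for r, idxs in enumerate(rows):
        for i in idxs:
            counts[r][i // B] += 1  # scatter-add, like bincount
    return counts
```

The key point is that the accumulator is only rows × (C // B), independent of k.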
452-489: Decode robustness: handle truncated/corrupt row payloads cleanly.
Wrap per‑row decode to surface EOFs as ValueError and avoid partial state.
```diff
-    for _i in range(N):
-        row_len = br.read_bits(16)
-        row_bytes = br.read_bytes(row_len)
-        rr = BitReader(row_bytes)
+    for _i in range(N):
+        try:
+            row_len = br.read_bits(16)
+            row_bytes = br.read_bytes(row_len)
+            rr = BitReader(row_bytes)
+        except EOFError as e:
+            raise ValueError(f"Truncated payload while reading row header #{_i}: {e}") from e
@@
-            s_len = _rice_read(rr, k_param)
+            try:
+                s_len = _rice_read(rr, k_param)
+            except EOFError as e:
+                raise ValueError(f"Truncated payload in row #{_i}, sub-block {j}: {e}") from e
@@
-                for loc in range(B):
-                    if (bitmask >> loc) & 1:
-                        indices.append(j * B + loc)
+                for loc in range(B):
+                    if (bitmask >> loc) & 1:
+                        indices.append(j * B + loc)
             else:
-                for _ in range(s_len):
-                    loc = rr.read_bits(lb)
-                    indices.append(j * B + loc)
+                for _ in range(s_len):
+                    try:
+                        loc = rr.read_bits(lb)
+                    except EOFError as e:
+                        raise ValueError(f"Truncated payload reading loc in row #{_i}, sub-block {j}") from e
+                    indices.append(j * B + loc)
```

src/tplr/comms.py (2)
2663-2682: Micro: guard edge cases and cut tiny allocations.
- If vals.ndim == 1, expected_rows is 1; add an explicit k>0 check to fail early for empty last dim.
- Use math.prod on Python ints instead of np.prod to avoid creating a NumPy array for small shapes.
Apply:
```diff
-        expected_rows = int(np.prod(vals.shape[:-1])) if vals.ndim > 1 else 1
+        expected_rows = int(math.prod([int(d) for d in vals.shape[:-1]])) if vals.ndim > 1 else 1
+        if vals.shape[-1] <= 0:
+            raise ValueError(f"[{param_name}] Values top‑k dimension must be > 0")
```
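For context, `math.prod` over plain ints gives the same row count without allocating the small NumPy array that `np.prod` builds:

```python
import math

shape = (4, 3, 7)  # stand-in for vals.shape
# identical result to int(np.prod(shape[:-1])), but stays in pure Python
expected_rows = math.prod(int(d) for d in shape[:-1])
```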
1606-1756: Metrics: upload_bytes remains unused.
Only download_bytes is updated. If you intend to track egress in this path, either remove upload_bytes or populate it where appropriate.
Also applies to: 1729-1756
tests/test_comms.py (2)
16-18: Tests now cover Rice/bitmap path; keep numpy import local to those tests.
You can import numpy only where needed (the Rice/bitmap tests) to keep module import time down, but it’s minor.
1609-1617: Comment mismatch (minor).
The “Test with list (not a tensor)” block actually uses a tensor; either change the wording or use a list to avoid confusion.
Apply:
```diff
-        # Test with list (not a tensor)
-        invalid_list = torch.tensor([0, 1, 2, 3])
+        # Test with wrong dtype tensor (int64)
+        invalid_list = torch.tensor([0, 1, 2, 3], dtype=torch.int64)
```

tests/unit/test_bits_codec.py (3)
71-124: Round‑trip tests look solid; add determinism to avoid flakiness.
Seed the RNG to make failures reproducible across devices/backends.
Apply:
```diff
-    K = make_even_k(K)
+    torch.manual_seed(0)
+    K = make_even_k(K)
```
157-191: CPU vs new encoder equality is brittle; compare sets or relax ordering.
Different B choices can yield different emission orders while still being correct. Prefer set equality or sort both rows before comparison.
Apply:
```diff
-    for i in range(N):
-        assert rows_ref[i] == rows_new[i], f"row {i} decode differs (CPU ref vs new)"
+    for i in range(N):
+        assert sorted(rows_ref[i]) == sorted(rows_new[i]), \
+            f"row {i}: decoded sets differ (CPU ref vs new)"
```
282-314: Good coverage of bitmap/local path bit; consider asserting row_len > 0.
Sanity‑check that the row actually carries payload in these cases.
Apply:
```diff
-    C2, N2, row_len, lb, k_param, use_bitmap = parse_first_row_header(payload)
+    C2, N2, row_len, lb, k_param, use_bitmap = parse_first_row_header(payload)
+    assert row_len > 0
```

tests/unit/test_compress.py (1)
177-203: Nice large‑C coverage. Minor nit: add seed.
Seed once to stabilise.
Apply:
```diff
-    x = torch.randn(128, 128)  # 16,384 elements
+    torch.manual_seed(0)
+    x = torch.randn(128, 128)  # 16,384 elements
```

src/tplr/compress/topk.py (3)
173-189: Even‑forcing top‑k is legacy; consider allowing odd k.
Rice/bitmap no longer requires even k (12‑bit path did). Enforcing even may reduce fidelity unnecessarily.
Apply:
```diff
-    topk = max(topk, 2)
-    # Keep even by default (matches broader system expectations).
-    topk = topk - (topk % 2)
+    topk = max(topk, 1)
```

If other callers still assume even, keep as‑is.
248-253: Avoid unnecessary CPU pinning for idx bytes.
Allow device‑agnostic storage; it still decodes via .cpu().
Apply:
```diff
-        idx_bytes = torch.tensor(
+        idx_bytes = torch.as_tensor(
             np.frombuffer(payload, dtype=np.uint8).copy(),
-            dtype=torch.uint8,
-            device="cpu",
+            dtype=torch.uint8,
         )
```
21-34: Optional: guard DTensor import for environments without DTensor.
Unconditional DT import can fail on some installs.
Apply:
```diff
-from torch.distributed.tensor import DTensor as DT
+try:
+    from torch.distributed.tensor import DTensor as DT  # type: ignore
+except Exception:  # pragma: no cover
+    DT = tuple()  # sentinel for isinstance checks
```

And update the check:
```diff
-    if isinstance(x, DT):  # check for dtensors
+    if DT and isinstance(x, DT):  # check for dtensors
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (16)
- hparams/hparams.json (1 hunks)
- neurons/evaluator.py (1 hunks)
- neurons/miner.py (1 hunks)
- neurons/trainer.py (0 hunks)
- neurons/validator.py (2 hunks)
- src/tplr/comms.py (3 hunks)
- src/tplr/compress/__init__.py (1 hunks)
- src/tplr/compress/bits.py (1 hunks)
- src/tplr/compress/pack12.py (1 hunks)
- src/tplr/compress/topk.py (10 hunks)
- src/tplr/neurons.py (3 hunks)
- tests/test_comms.py (11 hunks)
- tests/test_prepare_gradient_dict.py (4 hunks)
- tests/unit/test_bits_codec.py (1 hunks)
- tests/unit/test_compress.py (8 hunks)
- tests/unit/test_neurons.py (0 hunks)
💤 Files with no reviewable changes (2)
- neurons/trainer.py
- tests/unit/test_neurons.py
🧰 Additional context used
🧬 Code graph analysis (10)

src/tplr/compress/bits.py (1)
- src/tplr/comms.py (2): `get` (1248-1364), `gather` (1466-1787)

tests/unit/test_bits_codec.py (1)
- src/tplr/compress/bits.py (6): `BitReader` (62-98), `decode_batch_rows` (452-488), `encode_batch_rows` (253-444), `encode_batch_rows_cpu` (129-249), `read_bits` (75-83), `read_bytes` (94-98)

tests/test_comms.py (3)
- src/tplr/compress/pack12.py (1): `pack_12bit_indices` (21-47)
- src/tplr/compress/bits.py (1): `encode_batch_rows` (253-444)
- src/tplr/comms.py (1): `check_compressed_indices` (2615-2701)

src/tplr/compress/__init__.py (2)
- src/tplr/compress/bits.py (3): `decode_batch_rows` (452-488), `encode_batch_rows` (253-444), `encode_batch_rows_cpu` (129-249)
- src/tplr/compress/pack12.py (2): `pack_12bit_indices` (21-47), `unpack_12bit_indices` (50-75)

src/tplr/comms.py (2)
- src/tplr/compress/topk.py (3): `compress` (194-198), `compress` (200-204), `compress` (207-257)
- src/tplr/compress/bits.py (1): `decode_batch_rows` (452-488)

src/tplr/neurons.py (3)
- src/tplr/compress/topk.py (2): `encode` (77-95), `decode` (98-113)
- tests/test_prepare_gradient_dict.py (6): `encode` (33-34), `encode` (167-168), `encode` (260-262), `decode` (36-37), `decode` (170-171), `decode` (264-266)
- tests/test_comms.py (2): `encode` (1026-1028), `decode` (1030-1032)

tests/test_prepare_gradient_dict.py (2)
- src/tplr/compress/topk.py (2): `encode` (77-95), `decode` (98-113)
- tests/test_comms.py (2): `encode` (1026-1028), `decode` (1030-1032)

neurons/validator.py (3)
- src/tplr/compress/topk.py (1): `decode` (98-113)
- tests/test_prepare_gradient_dict.py (3): `decode` (36-37), `decode` (170-171), `decode` (264-266)
- tests/test_comms.py (1): `decode` (1030-1032)

tests/unit/test_compress.py (1)
- src/tplr/compress/bits.py (1): `encode_batch_rows` (253-444)

src/tplr/compress/topk.py (1)
- src/tplr/compress/bits.py (2): `decode_batch_rows` (452-488), `encode_batch_rows` (253-444)
⏰ Context from checks skipped due to timeout of 90000ms (2)
- GitHub Check: test (3.11)
- GitHub Check: test (3.12)
🔇 Additional comments (12)
neurons/validator.py (2)
265-266: Encode call updated to new signature — looks good.
`ChunkingTransformer.encode(x)` change is correctly applied.
2805-2806: Decode call updated to new signature — looks good.
`self.transformer.decode(decompressed)` matches the refactor.

neurons/miner.py (1)
261-262: Encode call updated to new signature — looks good.
Miner now calls `ChunkingTransformer.encode(x)` without `use_dct`. No other changes required here.

src/tplr/comms.py (1)
52-52: Good switch to decoder-driven validation path.
Importing decode_batch_rows here aligns comms validation with the new codec and keeps the surface consistent with tplr.compress.
tests/test_comms.py (1)
1557-1579: Nice coverage of Rice/bitmap encode/decode validation.
Good positive/negative tests around shapes, k, and C. This will catch most regressions in the new codec path.
Also applies to: 1582-1598, 1635-1650, 1652-1670, 1675-1693, 1705-1737, 1739-1767
src/tplr/compress/__init__.py (1)
18-35: Public surface looks clean and complete.
Re‑exporting encode/decode and keeping pack12 legacy helpers under the same namespace is a sensible transition strategy.
src/tplr/neurons.py (1)
163-183: DCT toggle removal looks correct.
encode/decode calls now use the unified transformer path; outer_step also switched to transformer.decode without the old flag.
Also applies to: 328-333
tests/test_prepare_gradient_dict.py (3)
32-38: Interface alignment with transformer encode/decode (no use_dct).
These changes look correct and consistent with the new API.
166-172: Local dummy passthrough matches new transformer signature.
Good test double; no issues.
470-477: Updated failing_decode signature is correct.
Matches the simplified transformer API.
tests/unit/test_compress.py (1)
35-54: OK: validates Rice/bitmap output type.
Asserting uint8 for indices is appropriate.
src/tplr/compress/topk.py (1)
230-247: Value reordering matches emitted index order — good.
This ensures decode alignment and avoids extra shuffles downstream.
```diff
+        This ensures indices are within bounds and that the **new Rice/bitmap**
+        codec payload matches the provided values tensor shape (top‑k).
 
         Args:
             param_name (str): The name of the parameter being checked.
             idxs (torch.Tensor): The tensor of indices.
             totalk (int): The total number of elements in the original uncompressed tensor.
             allowed_topk (int | None, optional): The expected number of top-k values.
                 Defaults to the hparams configuration.
-            vals (torch.Tensor | None, optional): The corresponding values tensor,
-                required for validating 12-bit packed indices. Defaults to None.
+            vals (torch.Tensor | None, optional): The corresponding values tensor.
 
         Raises:
             ValueError: If any validation check fails, such as out-of-bounds
-                indices, incorrect data types, or malformed packed data.
+                indices, incorrect data types, or malformed payload.
         """
         allowed_topk = (
             min(self.hparams.topk_compression, totalk)
             if allowed_topk is None
             else min(allowed_topk, totalk)
         )
 
-        def _bounds_check(t: torch.Tensor):
-            """fast min/max bounds check"""
-            if t.numel() == 0:
-                raise ValueError(f"[{param_name}] empty index list")
-            if t.min().item() < 0 or t.max().item() >= totalk:
-                bad = t[(t < 0) | (t >= totalk)][0].item()
-                raise ValueError(
-                    f"[{param_name}] Index {bad} out of bounds (totalk = {totalk})"
-                )
+        if not isinstance(idxs, torch.Tensor):
+            raise ValueError(
+                f"[{param_name}] Expected tensor for indices, got {type(idxs)}"
+            )
+        if vals is None:
+            raise ValueError(
+                f"[{param_name}] Values tensor required for index validation"
+            )
+        if idxs.dtype != torch.uint8:
+            raise ValueError(
+                f"[{param_name}] Expected uint8 (Rice/bitmap payload), got {idxs.dtype}"
+            )
+        if idxs.numel() == 0:
+            raise ValueError(f"[{param_name}] Empty indices payload")
 
-        # Handle 12-bit packed index format only
-        if isinstance(idxs, torch.Tensor):
-            if idxs.dtype != torch.uint8:
-                raise ValueError(
-                    f"[{param_name}] Expected uint8 for 12-bit packed indices, got {idxs.dtype}"
-                )
-            # 12-bit packed format is the only supported format
-            if vals is None:
-                raise ValueError(
-                    f"[{param_name}] Values tensor required to validate 12-bit packed indices"
-                )
-            if idxs.numel() == 0:
-                raise ValueError(f"[{param_name}] Empty packed indices tensor")
+        # Decode (CPU) and perform structural checks
+        try:
+            payload_bytes = idxs.detach().cpu().numpy().tobytes()
+            rows_list, C, N = decode_batch_rows(payload_bytes)
+        except Exception as e:
+            raise ValueError(f"[{param_name}] Failed to decode indices payload: {e}")
 
-            # Unpack using the values shape
-            try:
-                unpacked = unpack_12bit_indices(idxs, vals.shape)
-                # Validate that the last dimension matches allowed_topk
-                if unpacked.shape[-1] != allowed_topk:
-                    raise ValueError(
-                        f"[{param_name}] Invalid topk dimension: "
-                        f"shape[-1]={unpacked.shape[-1]} but expected {allowed_topk}"
-                    )
-                _bounds_check(unpacked)
-            except Exception as e:
-                raise ValueError(f"[{param_name}] Failed to unpack 12-bit indices: {e}")
-        else:
-            raise ValueError(f"[{param_name}] Expected tensor but got {type(idxs)}")
+        if C != totalk:
+            raise ValueError(
+                f"[{param_name}] Payload column size C={C} but expected {totalk}"
+            )
 
+        # compute expected rows from values shape (flatten all but last dim)
+        if vals.ndim == 0:
+            raise ValueError(f"[{param_name}] Values tensor has no top‑k dimension")
+        expected_rows = int(np.prod(vals.shape[:-1])) if vals.ndim > 1 else 1
+        if N != expected_rows:
+            raise ValueError(
+                f"[{param_name}] Payload rows N={N} but values imply {expected_rows}"
+            )
 
+        k = vals.shape[-1]
+        if k != allowed_topk:
+            raise ValueError(
+                f"[{param_name}] Values top‑k={k} but allowed_topk={allowed_topk}"
+            )
+        if any(len(r) != k for r in rows_list):
+            raise ValueError(
+                f"[{param_name}] At least one row has mismatched top‑k size"
+            )
 
+        # bounds check without materialising full tensor
+        max_idx = max((max(r) if len(r) > 0 else -1) for r in rows_list)
+        min_idx = (
+            min((min(r) if len(r) > 0 else 0) for r in rows_list) if rows_list else 0
+        )
+        if min_idx < 0 or max_idx >= totalk:
+            raise ValueError(
+                f"[{param_name}] Index out of bounds (min={min_idx}, max={max_idx}, totalk={totalk})"
+            )
```
🛠️ Refactor suggestion
check_compressed_indices rejects legacy/int64 indices; add backward‑compat path.
Today this function hard‑fails unless idxs is uint8 (Rice/bitmap). Older miners or historical payloads may still send int64 index tensors. Decompression paths already handle both; validation should too to avoid skipping otherwise valid peers.
Apply:
@@
- if idxs.dtype != torch.uint8:
- raise ValueError(
- f"[{param_name}] Expected uint8 (Rice/bitmap payload), got {idxs.dtype}"
- )
+ if idxs.dtype == torch.uint8:
+ pass # Rice/bitmap fast-path below
+ elif idxs.dtype in (torch.int64, torch.long):
+ # Legacy/unpacked path: shape and bounds checks only
+ if vals is None:
+ raise ValueError(f"[{param_name}] Values tensor required for index validation")
+ if tuple(idxs.shape) != tuple(vals.shape):
+ raise ValueError(f"[{param_name}] Indices shape {tuple(idxs.shape)} != values shape {tuple(vals.shape)}")
+ k = vals.shape[-1] if vals.ndim > 0 else 0
+ if k != allowed_topk:
+ raise ValueError(f"[{param_name}] Values top‑k={k} but allowed_topk={allowed_topk}")
+ if idxs.numel() > 0:
+ i_min = int(idxs.min().item())
+ i_max = int(idxs.max().item())
+ if i_min < 0 or i_max >= totalk:
+ raise ValueError(f"[{param_name}] Index out of bounds (min={i_min}, max={i_max}, totalk={totalk})")
+ return # legacy path validated
+ else:
+ raise ValueError(
+ f"[{param_name}] Expected uint8 (Rice/bitmap payload), got {idxs.dtype}"
+ )
@@
- rows_list, C, N = decode_batch_rows(payload_bytes)
+ rows_list, C, N = decode_batch_rows(payload_bytes)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
    This ensures indices are within bounds and that the **new Rice/bitmap**
    codec payload matches the provided values tensor shape (top‑k).

    Args:
        param_name (str): The name of the parameter being checked.
        idxs (torch.Tensor): The tensor of indices.
        totalk (int): The total number of elements in the original uncompressed tensor.
        allowed_topk (int | None, optional): The expected number of top-k values.
            Defaults to the hparams configuration.
        vals (torch.Tensor | None, optional): The corresponding values tensor.

    Raises:
        ValueError: If any validation check fails, such as out-of-bounds
            indices, incorrect data types, or malformed payload.
    """
    allowed_topk = (
        min(self.hparams.topk_compression, totalk)
        if allowed_topk is None
        else min(allowed_topk, totalk)
    )

    if not isinstance(idxs, torch.Tensor):
        raise ValueError(
            f"[{param_name}] Expected tensor for indices, got {type(idxs)}"
        )
    if vals is None:
        raise ValueError(
            f"[{param_name}] Values tensor required for index validation"
        )
    if idxs.dtype == torch.uint8:
        pass  # Rice/bitmap fast-path below
    elif idxs.dtype in (torch.int64, torch.long):
        # Legacy/unpacked path: shape and bounds checks only
        if tuple(idxs.shape) != tuple(vals.shape):
            raise ValueError(
                f"[{param_name}] Indices shape {tuple(idxs.shape)} != values shape {tuple(vals.shape)}"
            )
        k = vals.shape[-1] if vals.ndim > 0 else 0
        if k != allowed_topk:
            raise ValueError(
                f"[{param_name}] Values top-k={k} but allowed_topk={allowed_topk}"
            )
        if idxs.numel() > 0:
            i_min = int(idxs.min().item())
            i_max = int(idxs.max().item())
            if i_min < 0 or i_max >= totalk:
                raise ValueError(
                    f"[{param_name}] Index out of bounds (min={i_min}, max={i_max}, totalk={totalk})"
                )
        return  # legacy path validated
    else:
        raise ValueError(
            f"[{param_name}] Expected uint8 (Rice/bitmap payload), got {idxs.dtype}"
        )
    if idxs.numel() == 0:
        raise ValueError(f"[{param_name}] Empty indices payload")
    # Decode (CPU) and perform structural checks
    try:
        payload_bytes = idxs.detach().cpu().numpy().tobytes()
        rows_list, C, N = decode_batch_rows(payload_bytes)
    except Exception as e:
        raise ValueError(f"[{param_name}] Failed to decode indices payload: {e}")
    if C != totalk:
        raise ValueError(
            f"[{param_name}] Payload column size C={C} but expected {totalk}"
        )
    # compute expected rows from values shape (flatten all but last dim)
    if vals.ndim == 0:
        raise ValueError(f"[{param_name}] Values tensor has no top‑k dimension")
    expected_rows = int(np.prod(vals.shape[:-1])) if vals.ndim > 1 else 1
    if N != expected_rows:
        raise ValueError(
            f"[{param_name}] Payload rows N={N} but values imply {expected_rows}"
        )
    k = vals.shape[-1]
    if k != allowed_topk:
        raise ValueError(
            f"[{param_name}] Values top‑k={k} but allowed_topk={allowed_topk}"
        )
    if any(len(r) != k for r in rows_list):
        raise ValueError(
            f"[{param_name}] At least one row has mismatched top‑k size"
        )
    # bounds check without materialising full tensor
    max_idx = max((max(r) if len(r) > 0 else -1) for r in rows_list)
    min_idx = (
        min((min(r) if len(r) > 0 else 0) for r in rows_list) if rows_list else 0
    )
    if min_idx < 0 or max_idx >= totalk:
        raise ValueError(
            f"[{param_name}] Index out of bounds (min={min_idx}, max={max_idx}, totalk={totalk})"
        )
```
🤖 Prompt for AI Agents
In src/tplr/comms.py around lines 2626 to 2702, the validation currently rejects
idxs unless dtype is torch.uint8; update it to accept legacy torch.int64 index
tensors as well by adding a backward-compat branch: if idxs.dtype is
torch.int64, skip the Rice/bitmap decode and instead interpret the tensor as raw
integer indices (reshape/flatten consistent with vals shape and top-k), build
rows_list from those indices, compute C, N, min/max bounds from that legacy
format, and run the same consistency checks (C vs totalk, N vs expected_rows, k
vs allowed_topk, per-row length, bounds). Keep the existing uint8 path unchanged
and only raise for other dtypes, ensuring all error messages remain
param-prefixed.
I don't want the backwards compatibility
@joellidin, understood. I'll respect your decision to not include backward compatibility for the legacy index formats. The current implementation correctly enforces that all peers must use the new Rice/bitmap encoding format (uint8 payload), which keeps the validation logic clean and ensures consistency across the network.
✏️ Learnings added
Learnt from: joellidin
PR: tplr-ai/templar#553
File: src/tplr/comms.py:2626-2702
Timestamp: 2025-08-31T15:02:03.568Z
Learning: The user joellidin has explicitly decided not to implement backward compatibility for legacy int64 index tensors in the check_compressed_indices function in src/tplr/comms.py. The function should only accept uint8 Rice/bitmap payload format and reject other formats to enforce network-wide migration to the new compression scheme.
Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.
```python
def read_bits(self, n: int) -> int:
    if n <= 0:
        return 0
    self._fill(n)
    mask = (1 << n) - 1
    out = self.cur & mask
    self.cur >>= n
    self.nbits -= n
    return out
```
🛠️ Refactor suggestion
Prevent bit underflow on truncated input.
read_bits can drive nbits negative if the stream is shorter than requested. Raise an explicit EOF to avoid corrupted state.
```diff
 def read_bits(self, n: int) -> int:
     if n <= 0:
         return 0
     self._fill(n)
+    if self.nbits < n:
+        raise EOFError(f"Not enough bits in stream to read {n} bits")
     mask = (1 << n) - 1
     out = self.cur & mask
     self.cur >>= n
     self.nbits -= n
     return out
```

📝 Committable suggestion
```python
def read_bits(self, n: int) -> int:
    if n <= 0:
        return 0
    self._fill(n)
    if self.nbits < n:
        raise EOFError(f"Not enough bits in stream to read {n} bits")
    mask = (1 << n) - 1
    out = self.cur & mask
    self.cur >>= n
    self.nbits -= n
    return out
```
🤖 Prompt for AI Agents
In src/tplr/compress/bits.py around lines 75 to 83, read_bits currently calls
self._fill(n) then unconditionally consumes n bits which can make self.nbits
negative on truncated input; change it so that after calling self._fill(n) you
check whether there are at least n bits available and if not raise EOFError (or
a suitable custom exception) without mutating self.cur or self.nbits, otherwise
proceed to compute mask/out and update self.cur and self.nbits; ensure the
function returns the bits only when the full n bits are available.
```python
max_idx = indices.max().item() if indices.numel() > 0 else 0
if max_idx >= 4096:
    raise ValueError(f"Index {max_idx} exceeds 12-bit limit (4095)")

flat = indices.flatten()
n = flat.numel()
if n % 2 != 0:
    raise ValueError(f"Number of indices must be even, got {n}")
```
🛠️ Refactor suggestion
Guard against negative indices.
Packing negative values produces incorrect bit patterns. Add a lower-bound check.
```diff
 max_idx = indices.max().item() if indices.numel() > 0 else 0
+min_idx = indices.min().item() if indices.numel() > 0 else 0
+if min_idx < 0:
+    raise ValueError(f"Index {min_idx} is negative; only non-negative indices are supported")
```

📝 Committable suggestion
```python
max_idx = indices.max().item() if indices.numel() > 0 else 0
min_idx = indices.min().item() if indices.numel() > 0 else 0
if min_idx < 0:
    raise ValueError(f"Index {min_idx} is negative; only non-negative indices are supported")
if max_idx >= 4096:
    raise ValueError(f"Index {max_idx} exceeds 12-bit limit (4095)")

flat = indices.flatten()
n = flat.numel()
if n % 2 != 0:
    raise ValueError(f"Number of indices must be even, got {n}")
```
🤖 Prompt for AI Agents
In src/tplr/compress/pack12.py around lines 26 to 33, add a lower-bound check
for negative indices before packing: compute the minimum index (only when
indices.numel() > 0), and if min_idx < 0 raise ValueError("Index {min_idx} is
negative; cannot pack negative indices") (or similar message), keeping the
existing upper-bound check for >=4096 and the even-count check; ensure you use
the same tensor-to-python conversion pattern (.min().item()) as used for max.
```python
device = packed.device
n_indices = 1
for d in values_shape:
    n_indices *= int(d)
if n_indices == 0:
    return torch.zeros(values_shape, dtype=torch.int64, device=device)
if n_indices % 2 != 0:
    raise ValueError(f"Number of indices must be even, got {n_indices}")

out = torch.zeros(n_indices, dtype=torch.int64, device=device)
n_pairs = n_indices // 2
if n_pairs > 0:
    b0 = packed[0::3].to(torch.int64)
    b1 = packed[1::3].to(torch.int64)
    b2 = packed[2::3].to(torch.int64)

    out[0::2] = b0 | ((b1 & 0x0F) << 8)
    out[1::2] = ((b1 >> 4) & 0x0F) | (b2 << 4)
return out.view(*values_shape)
```
🛠️ Refactor suggestion
Validate packed length and dtype before unpacking.
Mismatched or non-uint8 buffers can raise shape errors or silently decode wrong values.
```diff
+# Ensure correct dtype/device
+if packed.dtype != torch.uint8:
+    packed = packed.to(torch.uint8)
 device = packed.device
@@
 if n_indices % 2 != 0:
     raise ValueError(f"Number of indices must be even, got {n_indices}")
+expected_bytes = (n_indices // 2) * 3
+if packed.numel() != expected_bytes:
+    raise ValueError(
+        f"Packed buffer length mismatch: expected {expected_bytes} bytes for "
+        f"{n_indices} indices, got {packed.numel()}"
+    )
```

📝 Committable suggestion
```python
# Ensure correct dtype/device
if packed.dtype != torch.uint8:
    packed = packed.to(torch.uint8)
device = packed.device
n_indices = 1
for d in values_shape:
    n_indices *= int(d)
if n_indices == 0:
    return torch.zeros(values_shape, dtype=torch.int64, device=device)
if n_indices % 2 != 0:
    raise ValueError(f"Number of indices must be even, got {n_indices}")
expected_bytes = (n_indices // 2) * 3
if packed.numel() != expected_bytes:
    raise ValueError(
        f"Packed buffer length mismatch: expected {expected_bytes} bytes for "
        f"{n_indices} indices, got {packed.numel()}"
    )

out = torch.zeros(n_indices, dtype=torch.int64, device=device)
n_pairs = n_indices // 2
if n_pairs > 0:
    b0 = packed[0::3].to(torch.int64)
    b1 = packed[1::3].to(torch.int64)
    b2 = packed[2::3].to(torch.int64)

    out[0::2] = b0 | ((b1 & 0x0F) << 8)
    out[1::2] = ((b1 >> 4) & 0x0F) | (b2 << 4)
return out.view(*values_shape)
```
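A list-based sketch of the suggested length contract: n indices occupy exactly `(n // 2) * 3` bytes, and any other buffer length is rejected before decoding. The `pack12`/`unpack12` names are illustrative stand-ins for the tensor-based functions, not the real API:

```python
def pack12(indices: list[int]) -> bytes:
    # Pack pairs of 12-bit values into 3 bytes each.
    if len(indices) % 2 != 0:
        raise ValueError(f"Number of indices must be even, got {len(indices)}")
    out = bytearray()
    for a, b in zip(indices[0::2], indices[1::2]):
        out.append(a & 0xFF)
        out.append(((a >> 8) & 0x0F) | ((b & 0x0F) << 4))
        out.append((b >> 4) & 0xFF)
    return bytes(out)

def unpack12(packed: bytes, n_indices: int) -> list[int]:
    # Reject mismatched buffers up front instead of decoding garbage.
    expected = (n_indices // 2) * 3
    if len(packed) != expected:
        raise ValueError(
            f"Packed buffer length mismatch: expected {expected} bytes "
            f"for {n_indices} indices, got {len(packed)}"
        )
    out = []
    for i in range(0, expected, 3):
        b0, b1, b2 = packed[i], packed[i + 1], packed[i + 2]
        out.append(b0 | ((b1 & 0x0F) << 8))
        out.append(((b1 >> 4) & 0x0F) | (b2 << 4))
    return out

idx = [0, 1, 2049, 4095]
assert unpack12(pack12(idx), len(idx)) == idx
```

With the length check in place, a truncated or padded payload fails loudly with `ValueError` rather than being sliced into wrong triples.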
Replace the 12-bit packed index format with the new Rice/bitmap compression codec throughout the codebase. This migration affects compression, communication, and neurons components. - Update TopKCompressor to use encode_batch_rows/decode_batch_rows - Migrate check_compressed_indices validation to Rice/bitmap format - Reorder values to match codec permutation during compression - Update all test cases from 12-bit to Rice/bitmap encoding - Handle uint8 payload format and structural validation - Ensure compatibility with B_choices=(64, 128) tensor dimensions - Extend Rice/bitmap codec support to neurons module
Remove Discrete Cosine Transform functionality from the compression system while retaining the ChunkingTransformer for tensor chunking. - Remove DCT transformation methods from ChunkingTransformer - Remove f_dict and b_dict (DCT basis matrices) - Remove norm parameter and einsum helper methods - Remove use_dct parameter from encode/decode methods throughout - Remove DCT helper functions (_dct, _idct, etc.) - Update all callers to remove use_dct parameter - Remove use_dct from hparams configuration - Update tests to remove DCT-related test cases ChunkingTransformer now only handles tensor reshaping and chunking, simplifying the implementation by removing unused DCT functionality.
With the improved Rice/bitmap compression codec, we can now transmit 4x more gradient values (128 vs 32) with acceptable overhead. This should improve gradient quality and convergence while the new codec keeps communication costs manageable.
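For background on why Rice coding keeps this overhead manageable: with parameter k it writes a value n as a unary quotient `n >> k` (q one-bits then a zero) followed by the k low bits of n, which is very cheap when the gaps between sorted indices are small. The bit-list helpers below are an illustrative sketch only, unrelated to the GPU implementation or its adaptive parameter selection:

```python
def rice_encode(values, k):
    # Encode each non-negative integer as unary(q) + k remainder bits (LSB-first).
    bits = []
    for n in values:
        q = n >> k
        bits.extend([1] * q)
        bits.append(0)  # unary terminator
        bits.extend((n >> i) & 1 for i in range(k))
    return bits

def rice_decode(bits, count, k):
    out, pos = [], 0
    for _ in range(count):
        q = 0
        while bits[pos] == 1:  # count unary ones
            q += 1
            pos += 1
        pos += 1  # skip the terminating zero
        r = 0
        for i in range(k):
            r |= bits[pos] << i
            pos += 1
        out.append((q << k) | r)
    return out

gaps = [3, 0, 17, 5]
assert rice_decode(rice_encode(gaps, k=3), len(gaps), k=3) == gaps
```

Each small gap costs roughly k + 1 bits (e.g. 17 with k=3 takes 6 bits), which is the regime where Rice coding beats fixed-width index storage; very dense rows are better served by the bitmap path.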
Force-pushed from c1adcac to db48e62 (Compare)
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (5)
tests/unit/test_compress.py (2)
102-141: Each encoded payload must cover all rows (N == xshape[0]).

Both idx1_orig and idx2_orig encode only 2 rows while xshape is 8. This will break batch_decompress scatter.
Apply:
```diff
-        # Create multiple Rice/bitmap encoded indices
-        idx1_orig = torch.tensor([[0, 1], [2, 3]], dtype=torch.int64)
-        idx2_orig = torch.tensor([[4, 5], [6, 7]], dtype=torch.int64)
+        # Create per-row indices for all rows so N matches xshape[0]
+        rows = xshape[0]
+        base = torch.arange(rows, dtype=torch.int64) * 4
+        idx1_orig = torch.stack([(base + 0) % totalk, (base + 1) % totalk], dim=1)
+        idx2_orig = torch.stack([(base + 2) % totalk, (base + 3) % totalk], dim=1)
@@
-        val1 = torch.tensor([[0.1, 0.2], [0.3, 0.4]], dtype=torch.float32)
-        val2 = torch.tensor([[0.5, 0.6], [0.7, 0.8]], dtype=torch.float32)
+        val1 = torch.randn(rows, 2, dtype=torch.float32)
+        val2 = torch.randn(rows, 2, dtype=torch.float32)
```
208-241: Row-count mismatch again in norm/clip tests; idx encodes 1 row while xshape has 8.

Align rows in the encoded payload and values with xshape[0].
Apply:
```diff
-        # Create test data with Rice/bitmap encoded format
-        idx_orig = torch.tensor([[0, 1, 2, 3]], dtype=torch.int64)  # Even count
-        payload, perm, _ = encode_batch_rows(idx_orig, C=totalk)
+        # Create test data with Rice/bitmap encoded format for all rows
+        K = 4
+        rows = xshape[0]
+        idx_orig = torch.stack(
+            [torch.arange(i * K, i * K + K, dtype=torch.int64) % totalk for i in range(rows)]
+        )
+        payload, perm, _ = encode_batch_rows(idx_orig, C=totalk)
@@
-        val_orig = torch.tensor([[10.0, 20.0, 30.0, 40.0]], dtype=torch.float32)
-        # Reorder values to match permutation
-        val = [torch.gather(val_orig, dim=1, index=perm)]
+        val_orig = torch.randn(rows, K, dtype=torch.float32)
+        # Reorder values to match permutation
+        val = [torch.gather(val_orig, dim=1, index=perm)]
```

tests/test_comms.py (1)
52-57: Use Rice/bitmap payload in the "valid state dict" helper.

This aligns fixtures with the enforced format and avoids silently producing incompatible index payloads.
```diff
-    for name, _ in model.named_parameters():
-        # Create legacy 12-bit packed format (for backwards compatibility test)
-        indices = torch.tensor([0, 1], dtype=torch.long)
-        packed_data = pack_12bit_indices(indices)
-        state_dict[name + "idxs"] = packed_data
-        state_dict[name + "vals"] = torch.tensor([0.1, 0.2], dtype=torch.float32)
+    for name, param in model.named_parameters():
+        # Produce Rice/bitmap payload matching totalk
+        totalk = param.numel()
+        idx2d = torch.tensor([[0, 1]], dtype=torch.long)
+        payload, _perm, _ = encode_batch_rows(idx2d, C=totalk)
+        packed = torch.tensor(np.frombuffer(payload, dtype=np.uint8), dtype=torch.uint8)
+        state_dict[name + "idxs"] = packed
+        state_dict[name + "vals"] = torch.tensor([0.1, 0.2], dtype=torch.float32)
```

neurons/validator.py (1)
2686-2694: Remove legacy 12-bit tuple handling; enforce Rice/bitmap only.

This path contradicts the migration objective and will silently accept legacy payloads. Per the project decision, only uint8 Rice/bitmap payloads must be accepted. Proactively reject tuples here and let check_compressed_indices validate the Rice/bitmap format. [pulled_from_retrieved_learnings]

Apply this diff:

```diff
-            # Handle 12-bit packed format: (packed_tensor, original_shape)
-            if isinstance(idxs, tuple) and len(idxs) == 2:
-                packed_data, original_shape = idxs
-                # Move packed data to device
-                packed_data = packed_data.to(self.device)
-                idxs = (packed_data, original_shape)
+            # Reject legacy 12-bit formats explicitly
+            if isinstance(idxs, tuple):
+                raise ValueError(
+                    f"Invalid gradient data from peer {eval_uid}: legacy 12-bit indices are no longer supported"
+                )
```

src/tplr/neurons.py (1)
324-336: NameError at runtime: deleting undefined variables vals/qps.

After refactor, vals and qps are not defined in outer_step. The del will crash on master.

Apply this diff:

```diff
-        # Free intermediate pieces ASAP
-        del vals_f32, idxs, vals, qps, ref, decompressed
+        # Free intermediate pieces ASAP
+        del vals_f32, idxs, ref, decompressed
```
♻️ Duplicate comments (1)
tests/unit/test_compress.py (1)
74-101: Row-count mismatch: payload encodes 2 rows while xshape is 8 → decompress will fail.

encode_batch_rows(original_indices, C=totalk) here emits rows=2 but p/xshape expect rows=8. Make payload rows match xshape[0].
Apply:
```diff
-        # Create proper Rice/bitmap encoded format using the encoder
-        # Create indices that are within valid range for a 8x64 tensor (even count)
-        original_indices = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=torch.int64)
-
-        # Pack using the new encoder format
-        payload, perm, _ = encode_batch_rows(original_indices, C=totalk)
-        idx = torch.tensor(np.frombuffer(payload, dtype=np.uint8), dtype=torch.uint8)
-
-        val = torch.tensor(
-            [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]], dtype=torch.float32
-        )
-        # Reorder values to match permutation
-        val = torch.gather(val, dim=1, index=perm)
+        # Build per-row K indices matching xshape[0] rows
+        K = 4
+        rows = xshape[0]
+        original_indices = torch.stack(
+            [torch.arange(i * K, i * K + K, dtype=torch.int64) % totalk for i in range(rows)]
+        )
+        # Encode and permute corresponding values
+        payload, perm, _ = encode_batch_rows(original_indices, C=totalk)
+        idx = torch.tensor(np.frombuffer(payload, dtype=np.uint8), dtype=torch.uint8)
+        val = torch.randn(rows, K, dtype=torch.float32)
+        val = torch.gather(val, dim=1, index=perm)
```
🧹 Nitpick comments (9)
tests/unit/test_compress.py (3)
54-73: Quantized path checks mostly fine; consider asserting qparams structure.

If qparams is a dict, assert expected keys to guard regressions.
Apply this minimal guard:
```diff
-        assert qparams is not None
-        assert len(qparams) == 5  # shift, scale, offset, lookup, orig_dtype
+        assert qparams is not None
+        if isinstance(qparams, dict):
+            for k in ("shift", "scale", "offset", "lookup", "orig_dtype"):
+                assert k in qparams
```
83-85: Remove stale "even count" comments (not required for Rice/bitmap).

Rice/bitmap has no even-count constraint; these comments are legacy from 12-bit tests.
```diff
-        # Create indices that are within valid range for a 8x64 tensor (even count)
+        # Create indices within valid range for an 8x64 tensor
```

and similarly in the norm/clip test.
Also applies to: 212-214
54-73: Optional: factor a small helper to generate per-row indices/values to reduce duplication.

Keeps tests short and consistent.
Example helper to add near the top of the file:
```python
def build_rowwise_indices_values(rows: int, K: int, C: int):
    idx = torch.stack(
        [torch.arange(i * K, i * K + K, dtype=torch.int64) % C for i in range(rows)]
    )
    val = torch.randn(rows, K, dtype=torch.float32)
    payload, perm, _ = encode_batch_rows(idx, C=C)
    idx_bytes = torch.tensor(np.frombuffer(payload, dtype=np.uint8), dtype=torch.uint8)
    val = torch.gather(val, dim=1, index=perm)
    return idx_bytes, val
```

Also applies to: 102-141, 208-241
src/tplr/comms.py (1)
1656-1657: Update stale comments: replace "12-bit" wording with Rice/bitmap.

Minor clarity fix to match the new format used across the codebase.
```diff
-        # Get corresponding vals tensor for 12-bit unpacking
+        # Get corresponding vals tensor to validate decoder-aligned top‑k
@@
-        # Handle 12-bit packed format (uint8 tensor)
+        # Account bytes size from Rice/bitmap (uint8 payload)
```

Also applies to: 1737-1741
tests/test_comms.py (4)
16-18: Prefer new codec in tests; keep legacy helpers clearly scoped.

Importing encode_batch_rows and numpy is fine. If legacy pack_12bit_indices remains, limit its use to explicitly marked legacy tests to avoid confusion with the enforced Rice/bitmap-only path.
69-76: Clarify helper's legacy intent.

Docstring says "legacy 12-bit" already; consider renaming to create_legacy_packed_indices to make call sites self-documenting. Not blocking.
1572-1575: Silence unused variable warnings for perm.

Ruff/flake8 may flag 'perm' as unused; bind to underscore.
```diff
-        payload, perm, _ = encode_batch_rows(valid_indices, C=totalk)
+        payload, _perm, _ = encode_batch_rows(valid_indices, C=totalk)
```

(Apply similarly in the other tests listed in the line ranges.)
Also applies to: 1592-1595, 1644-1646, 1664-1666, 1716-1719, 1729-1731, 1749-1751, 1762-1764
1756-1757: Correct comment: Rice/bitmap doesn't require an even number of indices.

Update the note to avoid implying 12-bit constraints.
```diff
-        # Note: Can't test with 1 element as encoder requires even number of indices
+        # Note: Rice/bitmap supports any k; we use k=6 here to deliberately trigger a top‑k mismatch.
```

src/tplr/neurons.py (1)
962-976: Optional sanity: validate codec’s C against expected param size (when available).
decode_batch_rows returns C but it's unused. If neuron has totalks[pname], assert equality to catch mismatched payloads early.

Example:
```diff
         rows_i, _C_codec, N_rows = decode_batch_rows(
             payload
         )  # rows_i: list[list[int]]
+        # Optional guard: ensure codec header matches expected C
+        # if hasattr(neuron, "totalks") and pname in neuron.totalks:
+        #     expected_C = int(neuron.totalks[pname])
+        #     if _C_codec != expected_C:
+        #         raise ValueError(f"Codec C mismatch for {pname}: {_C_codec} != {expected_C}")
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (12)
- hparams/hparams.json (1 hunks)
- neurons/evaluator.py (1 hunks)
- neurons/miner.py (1 hunks)
- neurons/trainer.py (0 hunks)
- neurons/validator.py (2 hunks)
- src/tplr/comms.py (3 hunks)
- src/tplr/compress/topk.py (10 hunks)
- src/tplr/neurons.py (5 hunks)
- tests/test_comms.py (11 hunks)
- tests/test_prepare_gradient_dict.py (4 hunks)
- tests/unit/test_compress.py (8 hunks)
- tests/unit/test_neurons.py (0 hunks)
💤 Files with no reviewable changes (2)
- tests/unit/test_neurons.py
- neurons/trainer.py
🚧 Files skipped from review as they are similar to previous changes (5)
- neurons/miner.py
- neurons/evaluator.py
- tests/test_prepare_gradient_dict.py
- src/tplr/compress/topk.py
- hparams/hparams.json
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: joellidin
PR: tplr-ai/templar#553
File: src/tplr/comms.py:2626-2702
Timestamp: 2025-08-31T15:02:03.568Z
Learning: The user joellidin has explicitly decided not to implement backward compatibility for legacy int64 index tensors in the check_compressed_indices function in src/tplr/comms.py. The function should only accept uint8 Rice/bitmap payload format and reject other formats to enforce network-wide migration to the new compression scheme.
📚 Learning: 2025-08-31T15:02:03.568Z
Learnt from: joellidin
PR: tplr-ai/templar#553
File: src/tplr/comms.py:2626-2702
Timestamp: 2025-08-31T15:02:03.568Z
Learning: The user joellidin has explicitly decided not to implement backward compatibility for legacy int64 index tensors in the check_compressed_indices function in src/tplr/comms.py. The function should only accept uint8 Rice/bitmap payload format and reject other formats to enforce network-wide migration to the new compression scheme.
Applied to files:
- src/tplr/comms.py
- tests/test_comms.py
🧬 Code graph analysis (5)
src/tplr/neurons.py (2)
- src/tplr/compress/topk.py (5): compress (194-198), compress (200-204), compress (207-257), encode (77-95), decode (98-113)
- src/tplr/compress/bits.py (1): decode_batch_rows (452-488)

src/tplr/comms.py (2)
- src/tplr/compress/topk.py (3): compress (194-198), compress (200-204), compress (207-257)
- src/tplr/compress/bits.py (1): decode_batch_rows (452-488)

tests/unit/test_compress.py (3)
- src/tplr/compress/topk.py (5): compress (194-198), compress (200-204), compress (207-257), _get_smaller_split (609-628), encode (77-95)
- src/tplr/compress/bits.py (1): encode_batch_rows (253-444)
- src/tplr/compress/pack12.py (2): pack_12bit_indices (21-47), unpack_12bit_indices (50-75)

tests/test_comms.py (3)
- src/tplr/compress/pack12.py (1): pack_12bit_indices (21-47)
- src/tplr/compress/bits.py (1): encode_batch_rows (253-444)
- src/tplr/comms.py (1): check_compressed_indices (2615-2701)

neurons/validator.py (3)
- src/tplr/compress/topk.py (1): decode (98-113)
- tests/test_prepare_gradient_dict.py (3): decode (36-37), decode (170-171), decode (264-266)
- tests/test_comms.py (1): decode (1030-1032)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.12)
- GitHub Check: test (3.11)
🔇 Additional comments (11)
tests/unit/test_compress.py (6)
3-3: Import of numpy is appropriate.

Used for frombuffer on codec payload; no issues.
11-17: Public API imports look correct.

Shifts tests to Rice/bitmap path and the new helper; aligns with the refactor and DCT removal.
35-53: Good assertion of Rice/bitmap payload type.

Asserting uint8 for indices matches the new codec and the migration plan to drop legacy int64 formats.
146-176: Round-trip test looks solid.

Confirms new path end-to-end and asserts dtype/shape.
177-203: Good coverage for large index ranges.

Asserting totalk equals the last dim for 2D tensors and verifying decompress result shape is sufficient here.
282-289: Transformer encode/decode tests align with DCT removal.

API usage matches the updated signatures.
src/tplr/comms.py (1)
52-52: Decoder-driven validation import looks good.

Switching to decode_batch_rows via the compress package aligns validation with the new codec.
src/tplr/neurons.py (4)
163-169: Switched encode path to new transformer API — LGTM.

Owner-only encode aligns with the refactor and keeps EF on-device.
182-187: Switched decode path to new transformer API — LGTM.

Post-decompress decode matches updated interface.
980-985: Weighting metric clarified — LGTM.

Using `param_weight = C_chunks * k` sensibly size-weights by rows×k after the codec change.
35-36: Import decode_batch_rows correctly or expose it in tplr.compress

tplr.compress/__init__.py doesn't export decode_batch_rows, so `from tplr.compress import decode_batch_rows` (src/tplr/neurons.py:35) will fail. Either change the import to:

```python
from tplr.compress.bits import decode_batch_rows
```

or re-export it in tplr/compress/__init__.py by adding:

```python
from .bits import decode_batch_rows
__all__.append("decode_batch_rows")
```

Likely an incorrect or invalid review comment.
```diff
             enc = self.transformer.encode(
-                torch.empty(p.shape, dtype=torch.float16, device=self.device),
-                use_dct=self.hparams.use_dct,
+                torch.empty(p.shape, dtype=torch.float16, device=self.device)
             )
```
💡 Verification agent
🧩 Analysis chain
API update to drop use_dct — LGTM.
ChunkingTransformer.encode/decode now called without use_dct; matches the new signatures.
Run to confirm no stale use_dct remains:
Also applies to: 2805-2810
🏁 Script executed:
```bash
#!/bin/bash
rg -nP --hidden --glob '!*site-packages*' '\buse_dct\b|outer_step\s*\([^)]*use_dct' -C2
```

Length of output: 390
Remove stale use_dct reference in test config
- In scripts/local_miner_test.py (line 179), remove or update the `"use_dct": False` entry so tests align with the updated ChunkingTransformer API.
🤖 Prompt for AI Agents
In neurons/validator.py around lines 264 to 266 (and the related test at
scripts/local_miner_test.py line 179), there's a stale test config entry
"use_dct": False that no longer matches the updated ChunkingTransformer API;
remove the "use_dct" key from the test config (or replace it with the current
API-equivalent option if the new transformer expects an explicit setting), then
run the tests and adjust any assertions that depended on the previous behavior
so they reflect the transformer’s current chunking behavior.
+        This ensures indices are within bounds and that the new Rice/bitmap
+        codec payload matches the provided values tensor shape (top-k).

         Args:
             param_name (str): The name of the parameter being checked.
             idxs (torch.Tensor): The tensor of indices.
             totalk (int): The total number of elements in the original uncompressed tensor.
             allowed_topk (int | None, optional): The expected number of top-k values.
                 Defaults to the hparams configuration.
-            vals (torch.Tensor | None, optional): The corresponding values tensor,
-                required for validating 12-bit packed indices. Defaults to None.
+            vals (torch.Tensor | None, optional): The corresponding values tensor.

         Raises:
             ValueError: If any validation check fails, such as out-of-bounds
-                indices, incorrect data types, or malformed packed data.
+                indices, incorrect data types, or malformed payload.
         """
         allowed_topk = (
             min(self.hparams.topk_compression, totalk)
             if allowed_topk is None
             else min(allowed_topk, totalk)
         )

         def _bounds_check(t: torch.Tensor):
             """fast min/max bounds check"""
             if t.numel() == 0:
                 raise ValueError(f"[{param_name}] empty index list")
             if t.min().item() < 0 or t.max().item() >= totalk:
                 bad = t[(t < 0) | (t >= totalk)][0].item()
                 raise ValueError(
                     f"[{param_name}] Index {bad} out of bounds (totalk = {totalk})"
                 )

+        if not isinstance(idxs, torch.Tensor):
+            raise ValueError(
+                f"[{param_name}] Expected tensor for indices, got {type(idxs)}"
+            )
+        if vals is None:
+            raise ValueError(
+                f"[{param_name}] Values tensor required for index validation"
+            )
+        if idxs.dtype != torch.uint8:
+            raise ValueError(
+                f"[{param_name}] Expected uint8 (Rice/bitmap payload), got {idxs.dtype}"
+            )
+        if idxs.numel() == 0:
+            raise ValueError(f"[{param_name}] Empty indices payload")

-        # Handle 12-bit packed index format only
-        if isinstance(idxs, torch.Tensor):
-            if idxs.dtype != torch.uint8:
-                raise ValueError(
-                    f"[{param_name}] Expected uint8 for 12-bit packed indices, got {idxs.dtype}"
-                )
-            # 12-bit packed format is the only supported format
-            if vals is None:
-                raise ValueError(
-                    f"[{param_name}] Values tensor required to validate 12-bit packed indices"
-                )
-            if idxs.numel() == 0:
-                raise ValueError(f"[{param_name}] Empty packed indices tensor")
+        # Decode (CPU) and perform structural checks
+        try:
+            payload_bytes = idxs.detach().cpu().numpy().tobytes()
+            rows_list, C, N = decode_batch_rows(payload_bytes)
+        except Exception as e:
+            raise ValueError(f"[{param_name}] Failed to decode indices payload: {e}")

-            # Unpack using the values shape
-            try:
-                unpacked = unpack_12bit_indices(idxs, vals.shape)
-                # Validate that the last dimension matches allowed_topk
-                if unpacked.shape[-1] != allowed_topk:
-                    raise ValueError(
-                        f"[{param_name}] Invalid topk dimension: "
-                        f"shape[-1]={unpacked.shape[-1]} but expected {allowed_topk}"
-                    )
-                _bounds_check(unpacked)
-            except Exception as e:
-                raise ValueError(f"[{param_name}] Failed to unpack 12-bit indices: {e}")
-        else:
-            raise ValueError(f"[{param_name}] Expected tensor but got {type(idxs)}")
+        if C != totalk:
+            raise ValueError(
+                f"[{param_name}] Payload column size C={C} but expected {totalk}"
+            )

+        # compute expected rows from values shape (flatten all but last dim)
+        if vals.ndim == 0:
+            raise ValueError(f"[{param_name}] Values tensor has no top-k dimension")
+        expected_rows = int(np.prod(vals.shape[:-1])) if vals.ndim > 1 else 1
+        if N != expected_rows:
+            raise ValueError(
+                f"[{param_name}] Payload rows N={N} but values imply {expected_rows}"
+            )

+        k = vals.shape[-1]
+        if k != allowed_topk:
+            raise ValueError(
+                f"[{param_name}] Values top-k={k} but allowed_topk={allowed_topk}"
+            )
+        if any(len(r) != k for r in rows_list):
+            raise ValueError(
+                f"[{param_name}] At least one row has mismatched top-k size"
+            )

+        # bounds check without materialising full tensor
+        max_idx = max((max(r) if len(r) > 0 else -1) for r in rows_list)
+        min_idx = (
+            min((min(r) if len(r) > 0 else 0) for r in rows_list) if rows_list else 0
+        )
+        if min_idx < 0 or max_idx >= totalk:
+            raise ValueError(
+                f"[{param_name}] Index out of bounds (min={min_idx}, max={max_idx}, totalk={totalk})"
+            )
🛠️ Refactor suggestion
Tighten validation: int-cast topk, guard vals type, drop numpy by using math.prod.
Prevents surprising float topk from misconfigured hparams, avoids attribute errors if vals isn’t a tensor, and removes an unnecessary dependency.
Apply:

-import numpy as np
+import math  # numpy no longer needed here
@@
 allowed_topk = (
     min(self.hparams.topk_compression, totalk)
     if allowed_topk is None
     else min(allowed_topk, totalk)
 )
+try:
+    allowed_topk = int(allowed_topk)
+except Exception:
+    raise ValueError(f"[{param_name}] allowed_topk must be an integer-compatible value")
@@
 if vals is None:
     raise ValueError(
         f"[{param_name}] Values tensor required for index validation"
     )
+if not torch.is_tensor(vals):
+    raise ValueError(f"[{param_name}] Values must be a torch.Tensor, got {type(vals)}")
@@
-expected_rows = int(np.prod(vals.shape[:-1])) if vals.ndim > 1 else 1
+expected_rows = int(math.prod(vals.shape[:-1])) if vals.ndim > 1 else 1

Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In src/tplr/comms.py around lines 2626 to 2702, tighten validation by (1)
enforcing vals is a torch.Tensor before using .ndim/.shape (raise a ValueError
if not), (2) casting allowed_topk/k to int to avoid float surprises, and (3)
replace numpy.prod usage with math.prod over vals.shape[:-1] (import math) to
remove the numpy dependency; keep all existing structural and bounds checks but
perform the type guard early and use int(...) when comparing top-k values.
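The `math.prod` swap is easy to sanity-check in isolation; a minimal sketch with a hypothetical `vals` shape:

```python
import math

# Hypothetical vals.shape: every dim except the last (top-k) counts as a row.
shape = (4, 3, 8)
expected_rows = math.prod(shape[:-1]) if len(shape) > 1 else 1
print(expected_rows)  # 12 rows, each carrying top-k = 8 values
```

Note that `math.prod` of an empty sequence is 1, which also matches the 1-D fallback branch.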
    idx_data = idxs_all[i] if isinstance(idxs_all, (list, tuple)) else idxs_all
    payload = _as_bytes(idx_data)
|
🛠️ Refactor suggestion
Avoid duplicating a single idx payload across all peers.
If idxs_all isn’t a list/tuple, treating it as every peer’s indices can produce false positives in overlap. Prefer skipping that param or enforcing list/tuple.
Apply this diff:
-    idx_data = idxs_all[i] if isinstance(idxs_all, (list, tuple)) else idxs_all
+    if not isinstance(idxs_all, (list, tuple)):
+        # Not a per-peer payload; skip this parameter safely
+        decoded_per_peer = []
+        break
+    idx_data = idxs_all[i]

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
    if not isinstance(idxs_all, (list, tuple)):
        # Not a per-peer payload; skip this parameter safely
        decoded_per_peer = []
        break
    idx_data = idxs_all[i]
    payload = _as_bytes(idx_data)
🤖 Prompt for AI Agents
In src/tplr/neurons.py around lines 959-961, the code treats a non-list/tuple
idxs_all as if it applies to every peer, causing duplicate idx payloads and
false overlap; change the logic so idxs_all must be a list/tuple to be used
per-peer—if not isinstance(idxs_all, (list, tuple)) then skip per-peer idx
payload (e.g., set idx_data = None and avoid creating payload) or explicitly
raise a TypeError/ValueError upstream; ensure the per-peer branch only uses
idxs_all[i] when idxs_all is a list/tuple to prevent duplicating the same idx
across all peers.
 totalk = 128

 # Test with regular tensor (not packed) - should fail because it's not uint8
 invalid_tensor = torch.tensor([0, 1, 2, 3], dtype=torch.long)
 vals = torch.randn(4, dtype=torch.float32)
-# This should fail since only uint8 12-bit packed format is supported
-with pytest.raises(ValueError, match="Expected uint8 for 12-bit packed indices"):
+# This should fail since only uint8 Rice/bitmap encoded format is supported
+with pytest.raises(ValueError, match="Expected uint8.*Rice/bitmap payload"):
     dummy_comms.check_compressed_indices("param", invalid_tensor, totalk, vals=vals)

 # Test with list (not a tensor)
 invalid_list = torch.tensor([0, 1, 2, 3])
-with pytest.raises(ValueError, match="Expected uint8 for 12-bit packed indices"):
+with pytest.raises(ValueError, match="Expected uint8.*Rice/bitmap payload"):
     dummy_comms.check_compressed_indices("param", invalid_list, totalk, vals=vals)
🛠️ Refactor suggestion
Fix test case intent: non‑tensor input vs wrong dtype.
The second assertion claims to test a non‑tensor, but passes a tensor and matches the wrong error. Either make it a true non‑tensor case or rename. Suggest splitting:
-    # Test with list (not a tensor)
-    invalid_list = torch.tensor([0, 1, 2, 3])
-    with pytest.raises(ValueError, match="Expected uint8.*Rice/bitmap payload"):
-        dummy_comms.check_compressed_indices("param", invalid_list, totalk, vals=vals)
+    # Test with list (not a tensor) → type error branch
+    invalid_list = [0, 1, 2, 3]
+    with pytest.raises(ValueError, match="Expected tensor for indices"):
+        dummy_comms.check_compressed_indices("param", invalid_list, totalk, vals=vals)

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
    totalk = 128

    # Test with regular tensor (not packed) - should fail because it's not uint8
    invalid_tensor = torch.tensor([0, 1, 2, 3], dtype=torch.long)
    vals = torch.randn(4, dtype=torch.float32)
    # This should fail since only uint8 Rice/bitmap encoded format is supported
    with pytest.raises(ValueError, match="Expected uint8.*Rice/bitmap payload"):
        dummy_comms.check_compressed_indices("param", invalid_tensor, totalk, vals=vals)

    # Test with list (not a tensor) → type error branch
    invalid_list = [0, 1, 2, 3]
    with pytest.raises(ValueError, match="Expected tensor for indices"):
        dummy_comms.check_compressed_indices("param", invalid_list, totalk, vals=vals)
🤖 Prompt for AI Agents
In tests/test_comms.py around lines 1605 to 1617, the second assertion claims to
test a non-tensor but currently passes a torch.tensor; change invalid_list to be
an actual Python list (e.g. [0,1,2,3]) and update the with pytest.raises to
expect a TypeError (or the specific exception your function raises for
non-tensor inputs) instead of the uint8/dtype ValueError so the test truly
verifies the non-tensor input path.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
tests/unit/test_compress.py (3)
74-96: Align payload row count with target shape to avoid scatter errors. Payload encodes 2 rows but xshape is (8, 64). Make indices/vals for 8 rows so N matches.
Apply:
-    p = torch.zeros(8, 64)  # 512 elements total, last dim=64
-    xshape = (8, 64)
-    totalk = 64
+    p = torch.zeros(8, 64)  # 512 elements total, last dim=64
+    xshape = (8, 64)
+    totalk = 64
@@
-    original_indices = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=torch.int64)
+    K = 4
+    original_indices = torch.stack(
+        [torch.arange(i * K, i * K + K, dtype=torch.int64) % totalk for i in range(8)]
+    )
@@
-    val = torch.tensor(
-        [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]], dtype=torch.float32
-    )
+    val = torch.randn(8, K, dtype=torch.float32)
105-136: Fix batch_decompress test shapes (rows must match xshape[0]). Each packed payload encodes 2 rows; set xshape (and p) to 2×64.
Apply:
-    p = torch.zeros(8, 64)  # 512 elements total, last dim=64
-    xshape = (8, 64)
+    p = torch.zeros(2, 64)
+    xshape = (2, 64)
     totalk = 64
203-216: Fix norm/clip test shapes to match single-row payload. Current payload encodes 1 row; xshape is 8×64 → mismatch.
Apply:
-    p = torch.zeros(8, 64)  # 512 elements total, last dim=64
-    xshape = (8, 64)
+    p = torch.zeros(1, 64)
+    xshape = (1, 64)
     totalk = 64
@@
-    block_norms = torch.tensor([15.0])
+    block_norms = torch.tensor([15.0])

src/tplr/compress/topk.py (1)
513-539: Fix tuple/list detection in maybe_dequantize_values (len(float) crash). The current check iterates elements of a 5-tuple (shift, scale, …) and calls len() on floats/ints → TypeError.
Apply:
-    if (
-        isinstance(qparams, tuple)
-        and len(qparams) == 5  # potentially single or already 5 elements
-        and not all([len(q) == 5 for q in qparams])  # already correctly formatted
-    ):
-        qparams = [qparams]
-    if not isinstance(qparams, list):
-        qparams = [qparams]
+    if isinstance(qparams, tuple) and len(qparams) == 5:
+        qparams = [qparams]
+    if not isinstance(qparams, list):
+        qparams = [qparams]
♻️ Duplicate comments (3)
src/tplr/compress/bits.py (1)
75-83: Guard against truncated input in BitReader.read_bits. read_bits mutates state even when there aren't enough bits, causing silent corruption and possible negative nbits. Raise EOFError before consuming.
Apply:
 def read_bits(self, n: int) -> int:
     if n <= 0:
         return 0
     self._fill(n)
+    if self.nbits < n:
+        raise EOFError(f"Not enough bits in stream to read {n} bits")
     mask = (1 << n) - 1
     out = self.cur & mask
     self.cur >>= n
     self.nbits -= n
     return out

tests/test_comms.py (1)
1614-1617: Fix non-tensor test to actually pass a list and assert the correct error. Currently passes a tensor and asserts the dtype error; intent was non-tensor type.
Apply:
-    # Test with list (not a tensor)
-    invalid_list = torch.tensor([0, 1, 2, 3])
-    with pytest.raises(ValueError, match="Expected uint8.*Rice/bitmap payload"):
-        dummy_comms.check_compressed_indices("param", invalid_list, totalk, vals=vals)
+    # Test with list (not a tensor) → type branch
+    invalid_list = [0, 1, 2, 3]
+    with pytest.raises(ValueError, match="Expected tensor for indices"):
+        dummy_comms.check_compressed_indices("param", invalid_list, totalk, vals=vals)

src/tplr/compress/topk.py (1)
398-414: Mirror N/k checks in batch path. Catch mismatched rows early with a clear error; current code only checks k.
Apply:
 if i_data.dtype == torch.uint8:
     rows, C, _N = decode_batch_rows(i_data.detach().cpu().numpy().tobytes())
     if C != totalk:
         raise ValueError(f"Index payload C={C} but expected {totalk}")
-    if any(len(r) != v_data.shape[-1] for r in rows):
+    k = int(v_data.shape[-1])
+    if any(len(r) != k for r in rows):
         raise ValueError(
             "Row-wise topk size mismatch in index payload (batch)"
         )
+    if len(rows) != int(v_data.reshape(-1, k).shape[0]):
+        raise ValueError(
+            f"Row count mismatch in batch: payload N={len(rows)} "
+            f"vs vals rows={int(v_data.reshape(-1, k).shape[0])}"
+        )
     idx_unpacked = torch.tensor(
         rows, dtype=torch.int64, device=p.device
     ).view(*v_data.shape)
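Looping back to the `read_bits` guard suggested above for `bits.py`: a self-contained toy reader (an LSB-first sketch; the real `BitReader`'s buffering details may differ) shows what the guard protects against. Without it, a truncated stream would hand back garbage bits and leave `nbits` negative.

```python
class BitReader:
    """Minimal LSB-first bit reader with the EOF guard from the review (illustrative only)."""

    def __init__(self, data: bytes):
        self.data = data
        self.pos = 0      # next byte to consume
        self.cur = 0      # bit accumulator
        self.nbits = 0    # bits currently buffered in `cur`

    def _fill(self, n: int) -> None:
        # Pull whole bytes until at least n bits are buffered or input is exhausted.
        while self.nbits < n and self.pos < len(self.data):
            self.cur |= self.data[self.pos] << self.nbits
            self.pos += 1
            self.nbits += 8

    def read_bits(self, n: int) -> int:
        if n <= 0:
            return 0
        self._fill(n)
        if self.nbits < n:  # the guard: fail *before* mutating accumulator state
            raise EOFError(f"Not enough bits in stream to read {n} bits")
        out = self.cur & ((1 << n) - 1)
        self.cur >>= n
        self.nbits -= n
        return out


r = BitReader(bytes([0b10110001]))
print(r.read_bits(4))  # 1  (low nibble first)
print(r.read_bits(4))  # 11 (high nibble)
# r.read_bits(1) would now raise EOFError instead of returning junk
```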
🧹 Nitpick comments (1)
tests/test_comms.py (1)
1756-1757: Correct comment: even-count constraint doesn't apply to Rice/bitmap. Parity restriction was for legacy 12-bit packing, not the new codec.
Apply:
-    # Note: Can't test with 1 element as encoder requires even number of indices
+    # Note: Rice/bitmap codec accepts any k; even-count was a 12-bit legacy constraint
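For background on why the codec has no parity requirement: a Rice code is emitted per value, so any row length works. A toy bit-string coder (illustrative only; the real `tplr` `bits.py` writes packed bits and adapts the parameter per block) makes this concrete:

```python
def rice_encode(v: int, k: int) -> str:
    """Rice code (k >= 1) for one non-negative int: unary quotient, then k-bit remainder."""
    q, r = v >> k, v & ((1 << k) - 1)
    return "1" * q + "0" + format(r, f"0{k}b")


def rice_decode(bits: str, k: int) -> int:
    q = bits.index("0")                   # leading 1s encode the quotient in unary
    r = int(bits[q + 1 : q + 1 + k], 2)   # next k bits are the remainder
    return (q << k) | r


print(rice_encode(9, 2))  # "11001": quotient 2 in unary, then remainder 01
```

Small values near the expected gap size cost only a few bits each, which is why delta-coded top-k indices compress well under this scheme.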
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (5)
- src/tplr/compress/bits.py (1 hunks)
- src/tplr/compress/topk.py (10 hunks)
- tests/test_comms.py (11 hunks)
- tests/unit/test_bits_codec.py (1 hunks)
- tests/unit/test_compress.py (9 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/unit/test_bits_codec.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: joellidin
PR: tplr-ai/templar#553
File: src/tplr/comms.py:2626-2702
Timestamp: 2025-08-31T15:02:03.568Z
Learning: The user joellidin has explicitly decided not to implement backward compatibility for legacy int64 index tensors in the check_compressed_indices function in src/tplr/comms.py. The function should only accept uint8 Rice/bitmap payload format and reject other formats to enforce network-wide migration to the new compression scheme.
📚 Learning: 2025-08-31T15:02:03.568Z
Learnt from: joellidin
PR: tplr-ai/templar#553
File: src/tplr/comms.py:2626-2702
Timestamp: 2025-08-31T15:02:03.568Z
Learning: The user joellidin has explicitly decided not to implement backward compatibility for legacy int64 index tensors in the check_compressed_indices function in src/tplr/comms.py. The function should only accept uint8 Rice/bitmap payload format and reject other formats to enforce network-wide migration to the new compression scheme.
Applied to files:
- tests/test_comms.py
- tests/unit/test_compress.py
🧬 Code graph analysis (3)
tests/test_comms.py (3)
- src/tplr/compress/pack12.py (1): pack_12bit_indices (21-47)
- src/tplr/compress/bits.py (1): encode_batch_rows (253-411)
- src/tplr/comms.py (1): check_compressed_indices (2615-2701)

src/tplr/compress/topk.py (2)
- src/tplr/compress/bits.py (2): decode_batch_rows (419-455), encode_batch_rows (253-411)
- src/tplr/comms.py (1): gather (1466-1787)

tests/unit/test_compress.py (3)
- src/tplr/compress/topk.py (5): compress (194-198), compress (200-204), compress (207-252), _get_smaller_split (604-623), encode (77-95)
- src/tplr/compress/bits.py (1): encode_batch_rows (253-411)
- src/tplr/compress/pack12.py (2): pack_12bit_indices (21-47), unpack_12bit_indices (50-75)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.12)
- GitHub Check: test (3.11)
 # Decode indices
 if idx.dtype == torch.uint8:
-    # 12-bit packed format - unpack it
-    idx_int64 = unpack_12bit_indices(idx, val.shape)
+    payload_bytes = idx.detach().cpu().numpy().tobytes()
+    rows_list, C, _N = decode_batch_rows(payload_bytes)
+    if C != totalk:
+        raise ValueError(f"Index payload C={C} but expected {totalk}")
+    k = val.shape[-1]
+    if any(len(r) != k for r in rows_list):
+        raise ValueError("Row-wise topk size mismatch in index payload")
+    idx_int64 = torch.tensor(
+        rows_list, dtype=torch.int64, device=p.device
+    ).view(*val.shape)
 elif idx.dtype in (torch.int64, torch.long):
     # Already unpacked (from batch_decompress)
-    idx_int64 = idx
+    idx_int64 = idx.to(p.device)
 else:
-    raise ValueError(
-        f"Expected uint8 (packed) or int64 (unpacked) indices, got {idx.dtype}"
-    )
-    # Ensure val has the same dtype as x for scatter operation
+    raise ValueError(f"Unsupported index tensor dtype: {idx.dtype}")

 if val.dtype != x.dtype:
🛠️ Refactor suggestion
Validate row count and remove no-op reshape.
decompress should fail early when payload N ≠ vals rows. Also the .reshape(xshape) return value is ignored.
Apply:
if idx.dtype == torch.uint8:
payload_bytes = idx.detach().cpu().numpy().tobytes()
rows_list, C, _N = decode_batch_rows(payload_bytes)
if C != totalk:
raise ValueError(f"Index payload C={C} but expected {totalk}")
k = val.shape[-1]
if any(len(r) != k for r in rows_list):
raise ValueError("Row-wise topk size mismatch in index payload")
+ expected_rows = int(val.reshape(-1, k).shape[0])
+ if len(rows_list) != expected_rows:
+ raise ValueError(
+ f"Row count mismatch: payload N={len(rows_list)} vs vals rows={expected_rows}"
+ )
idx_int64 = torch.tensor(
rows_list, dtype=torch.int64, device=p.device
).view(*val.shape)
@@
- x.scatter_reduce_(
- dim=-1, index=idx_int64, src=val, reduce="mean", include_self=False
- ).reshape(xshape)
+ x.scatter_reduce_(dim=-1, index=idx_int64, src=val, reduce="mean", include_self=False)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
# Decode indices
if idx.dtype == torch.uint8:
    payload_bytes = idx.detach().cpu().numpy().tobytes()
    rows_list, C, _N = decode_batch_rows(payload_bytes)
    if C != totalk:
        raise ValueError(f"Index payload C={C} but expected {totalk}")
    k = val.shape[-1]
    if any(len(r) != k for r in rows_list):
        raise ValueError("Row-wise topk size mismatch in index payload")
    # NEW: ensure the total number of rows matches
    expected_rows = int(val.reshape(-1, k).shape[0])
    if len(rows_list) != expected_rows:
        raise ValueError(
            f"Row count mismatch: payload N={len(rows_list)} vs vals rows={expected_rows}"
        )
    idx_int64 = torch.tensor(
        rows_list, dtype=torch.int64, device=p.device
    ).view(*val.shape)
elif idx.dtype in (torch.int64, torch.long):
    idx_int64 = idx.to(p.device)
else:
    raise ValueError(f"Unsupported index tensor dtype: {idx.dtype}")

if val.dtype != x.dtype:
    ...

# Drop the no-op .reshape(xshape) since scatter_reduce_ is in-place
x.scatter_reduce_(
    dim=-1,
    index=idx_int64,
    src=val,
    reduce="mean",
    include_self=False
)
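The `reduce="mean"` semantics the suggestion relies on can be mimicked in plain Python (a toy flat-list stand-in for torch's `scatter_reduce_` with `include_self=False`, just to illustrate the aggregation; the real call also works in place, which is why the trailing `.reshape(xshape)` on its return value was a no-op):

```python
def scatter_mean(dest, index, src):
    """Toy scatter with reduce='mean': dest[i] becomes the mean of all src values routed to i."""
    sums, counts = {}, {}
    for i, v in zip(index, src):
        sums[i] = sums.get(i, 0.0) + v
        counts[i] = counts.get(i, 0) + 1
    for i in sums:  # include_self=False: slots receiving no value keep their old contents
        dest[i] = sums[i] / counts[i]
    return dest


print(scatter_mean([0.0] * 6, [1, 1, 4], [2.0, 4.0, 5.0]))
# [0.0, 3.0, 0.0, 0.0, 5.0, 0.0]: slot 1 gets mean(2.0, 4.0), slot 4 gets 5.0
```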
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
src/tplr/compress/topk.py (2)
173-189: Fix k clamping when last-dim < 2 (avoids topk(k>dim) runtime error). Current logic can set k=2 even if dim=1. Clamp to [1..dim] first, then enforce evenness only when k>=2.

 def _clamp_topk(self, x, topk) -> int:
-    topk = min(topk, x.shape[-1])
-    topk = max(topk, 2)
-    # Keep even by default (matches broader system expectations).
-    topk = topk - (topk % 2)
-    return int(topk)
+    n = int(x.shape[-1])
+    k = int(topk)
+    k = max(1, min(k, n))  # stay within [1..n]
+    if k >= 2:
+        k -= (k % 2)  # even if possible
+    return k
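The proposed clamp can be exercised standalone (same logic as the diff, lifted out of the class for illustration):

```python
def clamp_topk(dim: int, topk: int) -> int:
    """Clamp k to [1..dim] first; only enforce evenness when k >= 2."""
    k = max(1, min(int(topk), dim))
    if k >= 2:
        k -= k % 2
    return k


print(clamp_topk(1, 32))   # 1 (the old logic would force k=2, and topk(k > dim) would raise)
print(clamp_topk(64, 7))   # 6 (evenness still enforced when room allows)
print(clamp_topk(64, 100)) # 64 (never exceeds the last dimension)
```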
533-542: Fix tuple/list detection; current code can raise TypeError on floats/dtypes. The len(q) check iterates elements of a single QuantParamsT (includes float/int/dtype), causing TypeError. Just wrap a single 5-tuple.

-    if (
-        isinstance(qparams, tuple)
-        and len(qparams) == 5  # potentially single or already 5 elements
-        and not all([len(q) == 5 for q in qparams])  # already correctly formatted
-    ):
-        qparams = [qparams]
-    if not isinstance(qparams, list):
-        qparams = [qparams]
+    if isinstance(qparams, tuple) and len(qparams) == 5:
+        qparams = [qparams]
+    if not isinstance(qparams, list):
+        qparams = [qparams]
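The crash mode is easy to reproduce without torch: calling `len()` on the float inside a single 5-tuple raises `TypeError`, while the fixed check simply wraps the tuple. A sketch (the tuple contents are placeholders, not the real QuantParamsT fields):

```python
def normalize_qparams(qparams):
    """Fixed wrapping: a single 5-tuple becomes a one-element list; lists pass through."""
    if isinstance(qparams, tuple) and len(qparams) == 5:
        qparams = [qparams]
    if not isinstance(qparams, list):
        qparams = [qparams]
    return qparams


single = (0.5, 2.0, 8, None, "u8")  # placeholder (shift, scale, ...) values
print(normalize_qparams(single))     # [(0.5, 2.0, 8, None, 'u8')]

# The old check iterated the tuple's *elements* and called len() on each:
# all(len(q) == 5 for q in single)  ->  TypeError: object of type 'float' has no len()
```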
♻️ Duplicate comments (2)
src/tplr/compress/topk.py (2)
306-321: Add N (row-count) validation and drop no-op reshape; mirror prior guidance.

 if idx.dtype == torch.uint8:
     payload_bytes = idx.detach().cpu().numpy().tobytes()
     rows_list, C, _N = decode_batch_rows(payload_bytes)
     if C != totalk:
         raise ValueError(f"Index payload C={C} but expected {totalk}")
     k = val.shape[-1]
     if any(len(r) != k for r in rows_list):
         raise ValueError("Row-wise topk size mismatch in index payload")
+    expected_rows = int(val.reshape(-1, k).shape[0])
+    if len(rows_list) != expected_rows:
+        raise ValueError(
+            f"Row count mismatch: payload N={len(rows_list)} vs vals rows={expected_rows}"
+        )
     idx_int64 = torch.tensor(
         rows_list, dtype=torch.int64, device=p.device
     ).view(*val.shape)
@@
-x.scatter_reduce_(
-    dim=-1, index=idx_int64, src=val, reduce="mean", include_self=False
-).reshape(xshape)
+x.scatter_reduce_(dim=-1, index=idx_int64, src=val, reduce="mean", include_self=False)

Also applies to: 326-329
418-435: Batch path: also validate N (row-count) to fail early on shape mismatches.

 if i_data.dtype == torch.uint8:
     rows, C, _N = decode_batch_rows(i_data.detach().cpu().numpy().tobytes())
     if C != totalk:
         raise ValueError(f"Index payload C={C} but expected {totalk}")
-    if any(len(r) != v_data.shape[-1] for r in rows):
+    k = int(v_data.shape[-1])
+    if any(len(r) != k for r in rows):
         raise ValueError(
             "Row-wise topk size mismatch in index payload (batch)"
         )
+    if len(rows) != int(v_data.reshape(-1, k).shape[0]):
+        raise ValueError(
+            f"Row count mismatch in batch: payload N={len(rows)} "
+            f"vs vals rows={int(v_data.reshape(-1, k).shape[0])}"
+        )
     idx_unpacked = torch.tensor(
-        rows, dtype=torch.int64, device=p.device
-    ).view(*v_data.shape)
+        rows, dtype=torch.int64, device=p.device
+    ).view(*v_data.shape)
🧹 Nitpick comments (2)
neurons/validator.py (2)
262-297: Avoid compressing uninitialized tensors; compute xshape/totalk cheaply. torch.empty may contain NaN/Inf, tripping quantization fallbacks and wasting time. You only need xshape and totalk; skip full compress and use encode().shape.

-import time
-total_compress_time = 0.0
-total_encode_time = 0.0
-
-# Enable debug timing in compressor
-self.compressor._debug_timing = True
-
+import time
+total_compress_time = 0.0
+total_encode_time = 0.0
+# Enable debug timing in compressor (optional; keep False to speed init)
+self.compressor._debug_timing = False
@@
-encode_start = time.time()
-enc = self.transformer.encode(
-    torch.empty(p.shape, dtype=torch.float16, device=self.device)
-)
-encode_time = time.time() - encode_start
-
-compress_start = time.time()
-_, _, xshape, totalk, _ = self.compressor.compress(
-    enc,
-    self.hparams.topk_compression,
-)
-compress_time = time.time() - compress_start
-
-self.xshapes[n] = xshape
-self.totalks[n] = totalk
-
-total_encode_time += encode_time
-total_compress_time += compress_time
-
-# Log timing for each layer
-tplr.logger.info(f"[COMPRESS TIMING] {n}: encode={encode_time:.3f}s, compress={compress_time:.3f}s, shape={p.shape}")
+encode_start = time.time()
+enc = self.transformer.encode(torch.zeros(p.shape, dtype=torch.float16, device=self.device))
+encode_time = time.time() - encode_start
+
+# Derive shapes without running top-k/codec
+self.xshapes[n] = tuple(p.shape)
+self.totalks[n] = int(enc.shape[-1])
+
+total_encode_time += encode_time
+tplr.logger.info(f"[COMPRESS TIMING] {n}: encode={encode_time:.3f}s, shape={p.shape}")
@@
-tplr.logger.info(f"[COMPRESS TIMING TOTAL] encode={total_encode_time:.3f}s, compress={total_compress_time:.3f}s")
-
-# Disable debug timing after initialization
-self.compressor._debug_timing = False
+tplr.logger.info(f"[COMPRESS TIMING TOTAL] encode={total_encode_time:.3f}s")
2710-2717: Drop stale 12-bit handling in validator; enforce new Rice/bitmap only. Per your migration decision, keep only uint8 payloads and let Comms reject legacy formats. This removes dead paths and confusion.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (3)
- neurons/validator.py (2 hunks)
- src/tplr/compress/bits.py (1 hunks)
- src/tplr/compress/topk.py (10 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/tplr/compress/bits.py
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: joellidin
PR: tplr-ai/templar#553
File: src/tplr/comms.py:2626-2702
Timestamp: 2025-08-31T15:02:03.568Z
Learning: The user joellidin has explicitly decided not to implement backward compatibility for legacy int64 index tensors in the check_compressed_indices function in src/tplr/comms.py. The function should only accept uint8 Rice/bitmap payload format and reject other formats to enforce network-wide migration to the new compression scheme.
🧬 Code graph analysis (2)
neurons/validator.py (3)
- tests/test_comms.py (3): encode (1026-1028), compress (1036-1038), decode (1030-1032)
- src/tplr/compress/topk.py (5): encode (77-95), compress (194-198), compress (200-204), compress (207-272), decode (98-113)
- tests/test_prepare_gradient_dict.py (10): encode (33-34), encode (167-168), encode (260-262), compress (17-26), compress (151-161), compress (238-249), info (44-45), decode (36-37), decode (170-171), decode (264-266)

src/tplr/compress/topk.py (2)
- src/tplr/compress/bits.py (2): decode_batch_rows (439-475), encode_batch_rows (253-431)
- src/tplr/comms.py (1): gather (1466-1787)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.11)
- GitHub Check: test (3.12)
    # Flatten to [rows, k] for the codec
    encode_start = time.time()
    idx2d = idx_int64.reshape(-1, topk).contiguous()
    # GPU-accelerated encode → bytes
    payload, _meta = encode_batch_rows(
        idx2d, C=totalk, B_choices=_DEFAULT_B_CHOICES
    )
Make B choices dynamic per C to prevent “No valid B choices for C”.
Using a fixed (64,128) fails for small or non-divisible C (e.g., 1D biases, odd chunk sizes). Build powers-of-two divisors of C at the call site.
```diff
-    idx2d = idx_int64.reshape(-1, topk).contiguous()
-    # GPU-accelerated encode → bytes
-    payload, _meta = encode_batch_rows(
-        idx2d, C=totalk, B_choices=_DEFAULT_B_CHOICES
-    )
+    idx2d = idx_int64.reshape(-1, topk).contiguous()
+    # Choose B dynamically: all powers of two that divide C (largest→smallest; include 1 as last resort)
+    _cand = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
+    b_choices = tuple(b for b in reversed(_cand) if b <= totalk and (totalk % b) == 0) or (1,)
+    # GPU-accelerated encode → bytes
+    payload, _meta = encode_batch_rows(idx2d, C=totalk, B_choices=b_choices)
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
 # Flatten to [rows, k] for the codec
 encode_start = time.time()
 idx2d = idx_int64.reshape(-1, topk).contiguous()
-# GPU-accelerated encode → bytes
-payload, _meta = encode_batch_rows(
-    idx2d, C=totalk, B_choices=_DEFAULT_B_CHOICES
-)
+# Choose B dynamically: all powers of two that divide C (largest→smallest; include 1 as last resort)
+_cand = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
+b_choices = tuple(b for b in reversed(_cand) if b <= totalk and (totalk % b) == 0) or (1,)
+# GPU-accelerated encode → bytes
+payload, _meta = encode_batch_rows(
+    idx2d, C=totalk, B_choices=b_choices
+)
```
🤖 Prompt for AI Agents
In src/tplr/compress/topk.py around lines 246 to 253, the call to
encode_batch_rows uses a fixed _DEFAULT_B_CHOICES which can be invalid for small
or non‑divisible C and causes "No valid B choices for C"; replace this by
computing B_choices at the call site as the powers‑of‑two divisors of totalk (C)
— generate all 2^n that divide C (or are <= C and C % B == 0), ensure the list
is sorted/unique, and if empty fallback to [1]; then pass that B_choices to
encode_batch_rows instead of _DEFAULT_B_CHOICES so encoding works for small or
odd C values.
Actionable comments posted: 3
♻️ Duplicate comments (3)
src/tplr/compress/bits.py (3)
234-242: Row emission order fixed — thank you. Writing rows in original order (per-row encode then emit) resolves the earlier B-bucket ordering bug.
75-83: Fix BitReader underflow (read past end corrupts state).
`read_bits` can drive `nbits` negative on truncated payloads. Raise EOFError and avoid mutating state when fewer than `n` bits are available. Apply:
```diff
 def read_bits(self, n: int) -> int:
     if n <= 0:
         return 0
     self._fill(n)
+    if self.nbits < n:
+        raise EOFError(f"Not enough bits in stream to read {n} bits")
     mask = (1 << n) - 1
     out = self.cur & mask
     self.cur >>= n
     self.nbits -= n
     return out
```
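For context, a self-contained sketch of an LSB-first bit reader with this guard. Field names follow the review diff, but this is an illustrative standalone version, not the module's actual class:

```python
class BitReader:
    """Minimal LSB-first bit reader with the proposed EOF guard (sketch)."""

    def __init__(self, data: bytes) -> None:
        self.data = data
        self.idx = 0      # next byte to consume
        self.cur = 0      # bit buffer
        self.nbits = 0    # bits currently buffered

    def _fill(self, n: int) -> None:
        while self.nbits < n and self.idx < len(self.data):
            self.cur |= self.data[self.idx] << self.nbits
            self.idx += 1
            self.nbits += 8

    def read_bits(self, n: int) -> int:
        if n <= 0:
            return 0
        self._fill(n)
        if self.nbits < n:  # without this guard, nbits would go negative
            raise EOFError(f"Not enough bits in stream to read {n} bits")
        out = self.cur & ((1 << n) - 1)
        self.cur >>= n
        self.nbits -= n
        return out

br = BitReader(b"\xab")   # one byte, eight bits available
print(br.read_bits(4))    # low nibble first: 0xb == 11
print(br.read_bits(4))    # then the high nibble: 0xa == 10
```

Reading a ninth bit from this stream raises EOFError instead of silently corrupting state.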
366-402: Harden decoder against malformed/untrusted payloads. Validate header fields, row boundaries, and per-row params; ensure full row chunks are consumed. Prevent OOB reads and silent corruption.
Apply:
```diff
 def decode_batch_rows(payload: bytes) -> tuple[list[list[int]], int, int]:
@@
-    C = br.read_bits(12) + 1
-    N = br.read_bits(16)
+    C = br.read_bits(12) + 1
+    if not (1 <= C <= 4096):
+        raise ValueError(f"Invalid C={C}; expected 1..4096")
+    N = br.read_bits(16)
+    if N < 0:
+        raise ValueError(f"Invalid N={N}")
     _ = br.read_bits(1)  # reserved
@@
-    for _i in range(N):
-        row_len = br.read_bits(16)
+    for _i in range(N):
+        row_len = br.read_bits(16)
+        # Ensure enough bits remain for row_len bytes
+        remaining_bits = 8 * (len(payload) - br.idx) + br.nbits
+        if row_len * 8 > remaining_bits:
+            raise ValueError("Row length exceeds remaining payload")
         row_bytes = br.read_bytes(row_len)
         rr = BitReader(row_bytes)
         lb = rr.read_bits(5)
         k_param = rr.read_bits(4)
         use_bitmap = rr.read_bits(1)
         B = 1 << lb
-        n_sub = C // B
+        # Validate B
+        if B <= 0 or (B & (B - 1)) != 0 or (C % B) != 0:
+            raise ValueError(f"Invalid row block size B={B} for C={C}")
+        n_sub = C // B
@@
             s_len = _rice_read(rr, k_param)
             if s_len == 0:
                 continue
             if use_bitmap:
                 bitmask = rr.read_bits(B)
                 for loc in range(B):
                     if (bitmask >> loc) & 1:
                         indices.append(j * B + loc)
             else:
-                for _ in range(s_len):
-                    loc = rr.read_bits(lb)
-                    indices.append(j * B + loc)
+                for _ in range(s_len):
+                    loc = rr.read_bits(lb)
+                    if loc >= B:
+                        raise ValueError("Location out of range")
+                    indices.append(j * B + loc)
         rows.append(indices)
+        # Ensure no trailing data per row
+        if rr.idx != len(row_bytes):
+            raise ValueError("Row payload has trailing or malformed data")
     return rows, C, N
```
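The `_rice_read` helper the decoder relies on is not shown in the review. As background, a minimal Rice-coding round trip under the common convention (unary quotient terminated by a zero, then a k-bit remainder, LSB-first); the module's exact bit order may differ:

```python
def rice_write(bits: list[int], v: int, k: int) -> None:
    # Unary quotient (q ones, then a zero), followed by k remainder bits LSB-first.
    q, r = v >> k, v & ((1 << k) - 1)
    bits.extend([1] * q)
    bits.append(0)
    for i in range(k):
        bits.append((r >> i) & 1)

def rice_read(bits: list[int], pos: int, k: int) -> tuple[int, int]:
    q = 0
    while bits[pos] == 1:
        q += 1
        pos += 1
    pos += 1  # skip the terminating zero
    r = 0
    for i in range(k):
        r |= bits[pos] << i
        pos += 1
    return (q << k) | r, pos

stream: list[int] = []
for v in (0, 3, 17):
    rice_write(stream, v, k=2)
pos, decoded = 0, []
for _ in range(3):
    v, pos = rice_read(stream, pos, k=2)
    decoded.append(v)
print(decoded)  # [0, 3, 17]
```

Small values near 2^k encode compactly, which is why the codec picks the Rice parameter k adaptively per row.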
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
src/tplr/compress/bits.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: joellidin
PR: tplr-ai/templar#553
File: src/tplr/comms.py:2626-2702
Timestamp: 2025-08-31T15:02:03.568Z
Learning: The user joellidin has explicitly decided not to implement backward compatibility for legacy int64 index tensors in the check_compressed_indices function in src/tplr/comms.py. The function should only accept uint8 Rice/bitmap payload format and reject other formats to enforce network-wide migration to the new compression scheme.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.12)
- GitHub Check: test (3.11)
src/tplr/compress/bits.py
Outdated
```python
def encode_batch_rows_cpu(
    rows_np: np.ndarray,
    *,
    C: int,
    B_choices: tuple[int, ...] = (64, 128),
    scheme: str = "per_row",
    workers: int | None = None,
) -> tuple[bytes, dict]:
    if scheme != "per_row":
        raise ValueError("Only scheme='per_row' is implemented")
    valid_B: list[int] = [
        b for b in B_choices if b > 0 and (b & (b - 1)) == 0 and (C % b) == 0
    ]
    if not valid_B:
        b = 1
        valid_B = []
        while b <= C:
            if C % b == 0 and (b & (b - 1)) == 0:
                valid_B.append(b)
            b <<= 1
```
🛠️ Refactor suggestion
Add input validation to prevent silent index drop/corruption.
Out-of-range indices (e.g., <0 or ≥C) are currently binned out of n_sub and silently discarded. Validate upfront and fail fast.
Apply:
```diff
 def encode_batch_rows_cpu(
@@
 ) -> tuple[bytes, dict]:
     if scheme != "per_row":
         raise ValueError("Only scheme='per_row' is implemented")
+    if not np.issubdtype(rows_np.dtype, np.integer):
+        raise TypeError("rows_np must be an integer array")
+    if rows_np.size:
+        mn = int(rows_np.min())
+        mx = int(rows_np.max())
+        if mn < 0 or mx >= C:
+            raise ValueError(f"Index out of range: allowed [0, {C - 1}], got min={mn}, max={mx}")
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
 def encode_batch_rows_cpu(
     rows_np: np.ndarray,
     *,
     C: int,
     B_choices: tuple[int, ...] = (64, 128),
     scheme: str = "per_row",
     workers: int | None = None,
 ) -> tuple[bytes, dict]:
     if scheme != "per_row":
         raise ValueError("Only scheme='per_row' is implemented")
+    if not np.issubdtype(rows_np.dtype, np.integer):
+        raise TypeError("rows_np must be an integer array")
+    if rows_np.size:
+        mn = int(rows_np.min())
+        mx = int(rows_np.max())
+        if mn < 0 or mx >= C:
+            raise ValueError(
+                f"Index out of range: allowed [0, {C - 1}], got min={mn}, max={mx}"
+            )
     valid_B: list[int] = [
         b for b in B_choices if b > 0 and (b & (b - 1)) == 0 and (C % b) == 0
     ]
     if not valid_B:
         b = 1
         valid_B = []
         while b <= C:
             if C % b == 0 and (b & (b - 1)) == 0:
                 valid_B.append(b)
             b <<= 1
```
🤖 Prompt for AI Agents
In src/tplr/compress/bits.py around lines 129 to 149, input indexing into
buckets can silently discard out-of-range indices; add upfront validation on
rows_np to fail fast: ensure rows_np is an integer ndarray, non-empty, has shape
consistent with expected rows (e.g., 2D or 1D as required by this function), and
that every index value is >= 0 and < C; if any check fails raise a ValueError
with a clear message indicating invalid index values and the offending
range/shape so callers can correct their input rather than losing data silently.
src/tplr/compress/bits.py
Outdated
```python
rows, k = idx_gpu.shape
device = idx_gpu.device
header = 5 + 4 + 1

best_bits = None  # torch.Tensor [rows]
best_B = None  # torch.Tensor [rows]

# Fast path: single histogram at smallest B, aggregate for larger B
```
🛠️ Refactor suggestion
Guard GPU path inputs (dtype/range).
Mirror CPU-side validation to avoid UB on GPU (e.g., invalid dtype or indices out of [0,C)).
Apply:
```diff
 @torch.no_grad()
 def encode_batch_rows(
     idx_gpu: torch.Tensor,  # [rows, k] int64
@@
 ) -> tuple[int, dict]:
@@
-    rows, k = idx_gpu.shape
+    if idx_gpu.dtype != torch.long:
+        raise TypeError("idx_gpu must be torch.int64 (long)")
+    if idx_gpu.numel():
+        if int(idx_gpu.min().item()) < 0 or int(idx_gpu.max().item()) >= C:
+            raise ValueError(f"Index out of range for C={C}")
+    rows, k = idx_gpu.shape
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
if idx_gpu.dtype != torch.long:
    raise TypeError("idx_gpu must be torch.int64 (long)")
if idx_gpu.numel():
    if int(idx_gpu.min().item()) < 0 or int(idx_gpu.max().item()) >= C:
        raise ValueError(f"Index out of range for C={C}")
rows, k = idx_gpu.shape
device = idx_gpu.device
header = 5 + 4 + 1

best_bits = None  # torch.Tensor [rows]
best_B = None  # torch.Tensor [rows]

# Fast path: single histogram at smallest B, aggregate for larger B
```
375b0e3 to 4f23a1e (Compare)
Actionable comments posted: 1
♻️ Duplicate comments (5)
src/tplr/compress/bits.py (5)
230-243: Row-order emission is correct now (fixes prior B-bucket misalignment). Using ThreadPoolExecutor.map over range(N) preserves original row order and aligns with the decoder’s expectations.
75-83: Fix bit underflow in BitReader.read_bits (silent corruption on truncated input). If the stream is shorter than requested, nbits can go negative and decoding silently corrupts. Raise EOFError when not enough bits are available.
Apply:
```diff
 def read_bits(self, n: int) -> int:
     if n <= 0:
         return 0
     self._fill(n)
+    if self.nbits < n:
+        raise EOFError(f"Not enough bits in stream to read {n} bits")
     mask = (1 << n) - 1
     out = self.cur & mask
     self.cur >>= n
     self.nbits -= n
     return out
```
346-358: GPU size calculation undercounts: it omits the global header, the per-row 16-bit lengths, and per-row byte rounding. This can cause buffer undersizing downstream. Compute bytes_per_row and include the fixed overhead.
Apply:
```diff
-    total_bits = int(best_bits.sum().item())
-    bytes_len = (total_bits + 7) // 8
+    # Round each row to full bytes; outer stream stores per-row blobs.
+    bytes_per_row = (best_bits + 7) // 8  # [rows]
+    rows_bytes_total = int(bytes_per_row.sum().item())
+    header_bits = 12 + 16 + 1  # global C-1, N, reserved
+    length_bits = 16 * rows  # per-row byte lengths
+    total_bits = rows_bytes_total * 8 + header_bits + length_bits
+    bytes_len = (total_bits + 7) // 8
@@
-    "avg_bits_per_row": float(best_bits.float().mean().item()),
+    "avg_bits_per_row": float((bytes_per_row * 8).to(torch.float32).mean().item()),
     "B_hist": B_hist,
 }
```
129-149: Validate CPU encoder inputs (dtype/range) to avoid silent index loss. Reject non-integer arrays and out-of-range indices upfront.
Apply:
```diff
 def encode_batch_rows_cpu(
@@
 ) -> tuple[bytes, dict]:
     if scheme != "per_row":
         raise ValueError("Only scheme='per_row' is implemented")
+    if not np.issubdtype(rows_np.dtype, np.integer):
+        raise TypeError("rows_np must be an integer array")
+    if rows_np.size:
+        mn = int(rows_np.min())
+        mx = int(rows_np.max())
+        if mn < 0 or mx >= C:
+            raise ValueError(f"Index out of range: allowed [0, {C - 1}], got min={mn}, max={mx}")
```
263-270: Guard GPU path inputs (dtype/range) to prevent UB and wrong sizing. Mirror the CPU checks for dtype=int64 and indices in [0, C).
Apply:
```diff
 @torch.no_grad()
 def encode_batch_rows(
@@
 ) -> tuple[int, dict]:
@@
-    rows, k = idx_gpu.shape
+    if idx_gpu.dtype != torch.long:
+        raise TypeError("idx_gpu must be torch.int64 (long)")
+    if idx_gpu.numel():
+        if int(idx_gpu.min().item()) < 0 or int(idx_gpu.max().item()) >= C:
+            raise ValueError(f"Index out of range for C={C}")
+    rows, k = idx_gpu.shape
```
🧹 Nitpick comments (2)
src/tplr/compress/bits.py (2)
62-74: Add BitReader.available_bits() helper (used by decoder preflight checks). Small utility to validate remaining capacity before reading fixed-size row chunks.
Apply:
```diff
 class BitReader:
     def __init__(self, data: bytes) -> None:
         self.data = data
         self.idx = 0
         self.cur = 0
         self.nbits = 0

     def _fill(self, n: int) -> None:
         while self.nbits < n and self.idx < len(self.data):
             self.cur |= int(self.data[self.idx]) << self.nbits
             self.idx += 1
             self.nbits += 8
+
+    def available_bits(self) -> int:
+        # Bits still buffered plus bits in unread bytes
+        return self.nbits + (len(self.data) - self.idx) * 8
```
243-250: Make CPU meta.avg_bits_per_row reflect allocated bytes (optional). The average currently uses raw per-row estimates, not rounded bytes plus headers. Consider computing it from the actual row_bytes.
Apply:
```diff
-    row_bits: list[int] = []
+    row_bits: list[int] = []
+    row_bytes_lens: list[int] = []
@@
     bw.write_bits(len(row_bytes), 16)
     for byte in row_bytes:
         bw.write_bits(int(byte), 8)
     row_bits.append(bits_used)
     B_hist[B_used] = B_hist.get(B_used, 0) + 1
+    row_bytes_lens.append(len(row_bytes))
@@
 meta = {
     "total_bits": len(payload) * 8,
-    "avg_bits_per_row": (sum(row_bits) / max(1, N)) if N else 0.0,
+    "avg_bits_per_row": (8.0 * sum(row_bytes_lens) / max(1, N)) if N else 0.0,
     "B_hist": B_hist,
 }
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
src/tplr/compress/bits.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: joellidin
PR: tplr-ai/templar#553
File: src/tplr/comms.py:2626-2702
Timestamp: 2025-08-31T15:02:03.568Z
Learning: The user joellidin has explicitly decided not to implement backward compatibility for legacy int64 index tensors in the check_compressed_indices function in src/tplr/comms.py. The function should only accept uint8 Rice/bitmap payload format and reject other formats to enforce network-wide migration to the new compression scheme.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.12)
- GitHub Check: test (3.11)
4f23a1e to 807b76a (Compare)
Actionable comments posted: 2
♻️ Duplicate comments (4)
src/tplr/compress/bits.py (4)
35-43: Add EOF protection to prevent bit underflow. The `read_bits` method can drive `nbits` negative if the stream is truncated, leading to corrupted state. This is critical for robust handling of potentially malformed payloads. Apply this diff to prevent underflow:
```diff
 def read_bits(self, n: int) -> int:
     if n <= 0:
         return 0
     self._fill(n)
+    if self.nbits < n:
+        raise EOFError(f"Not enough bits in stream to read {n} bits")
     mask = (1 << n) - 1
     out = self.cur & mask
     self.cur >>= n
     self.nbits -= n
     return out
```
67-103: Add validation for untrusted payloads in the decoder. The decoder lacks validation for malformed payloads that could cause crashes or incorrect behavior. This is essential when processing data from untrusted sources.
Apply validation checks:
```diff
 def decode_batch_rows(payload: bytes) -> tuple[list[list[int]], int, int]:
     """
     Decode payload created by `encode_batch_rows(...)`.
     Returns (rows, C, N) with per-row global indices.
     """
     br = _BitReader(payload)
     C = br.read_bits(12) + 1
+    if C < 1 or C > 4096:  # Reasonable bounds
+        raise ValueError(f"Invalid C value: {C}")
     N = br.read_bits(16)
+    if N > 65536:  # Reasonable upper bound
+        raise ValueError(f"Invalid N value: {N}")
     _ = br.read_bits(1)  # reserved
     rows: list[list[int]] = []
     for _i in range(N):
         row_len = br.read_bits(16)
+        # Verify enough bytes remain
+        if br.idx + row_len > len(br.data):
+            raise EOFError(f"Truncated payload: need {row_len} bytes, have {len(br.data) - br.idx}")
         row_bytes = br.read_bytes(row_len)
         rr = _BitReader(row_bytes)
         lb = rr.read_bits(5)
         k_param = rr.read_bits(4)
         use_bitmap = rr.read_bits(1)
         B = 1 << lb
+        # Validate B
+        if B < 1 or B > C or (C % B) != 0:
+            raise ValueError(f"Invalid row block size B={B} for C={C}")
         n_sub = C // B
         indices: list[int] = []
         for j in range(n_sub):
             s_len = _rice_read(rr, k_param)
+            if s_len > B:
+                raise ValueError(f"Invalid s_len={s_len} > B={B}")
             if s_len == 0:
                 continue
             if use_bitmap:
                 bitmask = rr.read_bits(B)
                 for loc in range(B):
                     if (bitmask >> loc) & 1:
                         indices.append(j * B + loc)
             else:
                 for _ in range(s_len):
                     loc = rr.read_bits(lb)
+                    if loc >= B:
+                        raise ValueError(f"Invalid loc={loc} >= B={B}")
                     indices.append(j * B + loc)
         rows.append(indices)
+        # Ensure full row chunk consumed
+        if rr.idx != len(row_bytes):
+            raise ValueError("Row payload has trailing or malformed data")
     return rows, C, N
```
406-412: Correct the total_bits calculation to include per-row byte rounding. The current calculation underestimates the actual payload size, as it doesn't account for row-level byte alignment.
Apply the corrected calculation:
```diff
-    row_bytes = ((row_bits + 7) // 8).to(torch.int32)
+    row_bytes = ((row_bits + 7) // 8).to(torch.int32)
     # 3) Allocate payload buffer
     B_hist = {int(b): int((best_B == b).sum().item()) for b in torch.unique(best_B)}
-    total_bits_rows = int((16 * rows + 8 * row_bytes.sum().item()))
-    total_bits = 12 + 16 + 1 + total_bits_rows
+    # Account for: global header (29 bits) + per-row length fields (16 bits each) + per-row payloads (bytes)
+    header_bits = 12 + 16 + 1  # C-1, N, reserved
+    length_bits = 16 * rows
+    payload_bytes_total = int(row_bytes.sum().item())
+    total_bits = header_bits + length_bits + (payload_bytes_total * 8)
     total_bytes = (total_bits + 7) // 8
```
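Plugging illustrative numbers into the corrected formula shows the effect of byte rounding plus the fixed overhead (the per-row bit counts below are made up for the example):

```python
# Worked sizing example under the corrected accounting.
row_bits = [37, 85, 12]                       # pre-rounding bit counts per row
row_bytes = [(b + 7) // 8 for b in row_bits]  # byte-rounded: [5, 11, 2]
rows = len(row_bits)

header_bits = 12 + 16 + 1                     # C-1, N, reserved
length_bits = 16 * rows                       # one 16-bit length field per row
total_bits = header_bits + length_bits + 8 * sum(row_bytes)
total_bytes = (total_bits + 7) // 8
print(total_bits, total_bytes)                # 221 bits → 28 bytes
```

Summing the raw `row_bits` alone would give 134 bits, so the naive formula undercounts by more than a third here.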
354-376: Add input validation for the GPU encoding path. The encoder should validate input indices to prevent silent data corruption or out-of-bounds access.
Apply input validation:
```diff
 @torch.no_grad()
 def encode_batch_rows(
     idx: torch.Tensor,  # [rows, k] int64 (CUDA strongly recommended)
     *,
     C: int,
     B_choices: tuple[int, ...] = (64, 128),
 ) -> tuple[bytes, dict]:
     """
     Triton encoder for per-row Rice/bitmap codec.
     Returns:
       payload: bytes
       meta: {total_bits, avg_bits_per_row, B_hist}
     """
     if not TRITON_AVAILABLE:
         raise RuntimeError("Triton is not available. `pip install triton` and re-run.")
     if idx.dtype != torch.int64:
         idx = idx.to(torch.int64)
     if not idx.is_cuda:
         idx = idx.cuda()
+
+    # Validate input indices
+    if idx.numel() > 0:
+        min_idx = int(idx.min().item())
+        max_idx = int(idx.max().item())
+        if min_idx < 0 or max_idx >= C:
+            raise ValueError(f"Index out of range for C={C}: found indices in [{min_idx}, {max_idx}]")
+
     rows, k = idx.shape
     device = idx.device
```
🧹 Nitpick comments (2)
src/tplr/compress/bits.py (2)
311-312: Increase buffer size for safety or add a dynamic check. The fixed buffer size of 256 could be exceeded if k > 256. Consider making it dynamic or adding a guard.
Add a safety check or increase buffer size:
```diff
-    buf = tl.zeros([256], dtype=tl.int32)  # K<=256 guard; adjust if you need larger K
+    # Ensure buffer is large enough for all possible k values
+    buf = tl.zeros([512], dtype=tl.int32)  # Increased for safety
+    tl.device_assert(k <= 512, "k exceeds buffer size")
```
469-470: Improve avg_bits_per_row calculation accuracy. The average should be computed after byte-rounding to reflect the actual allocated space.
```diff
 meta = {
     "total_bits": total_bits,
-    "avg_bits_per_row": float(row_bits.float().mean().item()) if rows > 0 else 0.0,
+    "avg_bits_per_row": float((row_bytes * 8).float().mean().item()) if rows > 0 else 0.0,
     "B_hist": B_hist,
 }
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
src/tplr/compress/bits.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: joellidin
PR: tplr-ai/templar#553
File: src/tplr/comms.py:2626-2702
Timestamp: 2025-08-31T15:02:03.568Z
Learning: The user joellidin has explicitly decided not to implement backward compatibility for legacy int64 index tensors in the check_compressed_indices function in src/tplr/comms.py. The function should only accept uint8 Rice/bitmap payload format and reject other formats to enforce network-wide migration to the new compression scheme.
🪛 GitHub Actions: CI
src/tplr/compress/bits.py
[error] 5-5: uv run ruff check: I001 - Import block is un-sorted or un-formatted.
🔇 Additional comments (1)
src/tplr/compress/bits.py (1)
1-18: Review imports and exception handling in src/tplr/compress/bits.py
- Add a blank line after `from __future__ import annotations` and group imports per PEP 8:

```python
from __future__ import annotations

import math
from typing import Sequence

import torch
import torch.nn.functional as F
```

- Restrict the catch to `ImportError` so other issues aren’t masked:

```diff
 try:
     import triton
     import triton.language as tl

     TRITON_AVAILABLE = True
-except Exception:
+except ImportError:
     TRITON_AVAILABLE = False
```
src/tplr/compress/bits.py
Outdated
```python
lb = tl.cdiv(tl.log2(B), 1)  # ceil(log2(B)); B is power of two so exact
n_sub = C // B
```
🛠️ Refactor suggestion
Use tl.math.log2 for proper Triton compilation.
The tl.log2 function returns float values which need proper ceiling computation in Triton context.
Apply this fix for proper Triton compilation:
```diff
-    lb = tl.cdiv(tl.log2(B), 1)  # ceil(log2(B)); B is power of two so exact
+    lb = tl.math.log2(B.to(tl.float32)).to(tl.int32)  # B is power of two so exact
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
-lb = tl.cdiv(tl.log2(B), 1)  # ceil(log2(B)); B is power of two so exact
+lb = tl.math.log2(B.to(tl.float32)).to(tl.int32)  # B is power of two so exact
 n_sub = C // B
```
🤖 Prompt for AI Agents
In src/tplr/compress/bits.py around lines 268-269, replace the use of tl.log2
with Triton's math API and perform an explicit ceiling: compute lb using
tl.math.log2 and tl.math.ceil (e.g. lb = tl.math.ceil(tl.math.log2(B))) so the
value is computed with Triton-compatible functions; leave n_sub = C // B
unchanged.
src/tplr/compress/bits.py
Outdated
```python
    payload_bitpos.to(torch.int64),  # starts_bits_ptr
    payload,  # payload buffer
    torch.as_tensor(int(len_bits_base), dtype=torch.int64, device=device),
    torch.as_tensor(int(payload_bits_base), dtype=torch.int64, device=device),
    NSUB_MAX=128,
)
```
💡 Verification agent
🧩 Analysis chain
Consider validating NSUB_MAX against actual requirements.
The hardcoded NSUB_MAX=128 assumes C/B_min <= 128. This should be validated or made dynamic.
🏁 Script executed:
```shell
#!/bin/bash
# Check if there are any configurations where C/B_min might exceed 128
rg -n "C\s*=|C:" --type py | head -20
rg -n "B_choices|B_sorted" --type py | head -20
```

Length of output: 3317
🏁 Script executed:
```shell
#!/bin/bash
sed -n '430,500p' src/tplr/compress/bits.py
```

Length of output: 1717
Enforce C/B_min ≤ NSUB_MAX (128) or compute NSUB_MAX dynamically
Add a runtime check (or derive NSUB_MAX = C // B_min) before launching the Triton kernel in src/tplr/compress/bits.py (around lines 457–462) to ensure C/B_min never exceeds 128 and avoid out-of-bounds indexing.
807b76a to 2727ada (Compare)
Actionable comments posted: 2
♻️ Duplicate comments (4)
src/tplr/compress/bits.py (4)
50-58: Guard against truncated input in BitReader.read_bits. Without a fill-then-check, nbits can go negative on short payloads.
```diff
 def read_bits(self, n: int) -> int:
     if n <= 0:
         return 0
     self._fill(n)
+    if self.nbits < n:
+        raise EOFError(f"Not enough bits in stream to read {n} bits")
     mask = (1 << n) - 1
     out = self.cur & mask
     self.cur >>= n
     self.nbits -= n
     return out
```
82-119: Harden decoder against malformed/untrusted payloads. Add header, per-row length, B, and row-consumption validation to avoid OOB/DoS.
```diff
 def decode_batch_rows(payload: bytes) -> tuple[list[list[int]], int, int]:
@@
-    C = br.read_bits(12) + 1
-    N = br.read_bits(16)
+    C = br.read_bits(12) + 1
+    if C < 1:
+        raise ValueError(f"Invalid C={C}")
+    N = br.read_bits(16)
     _ = br.read_bits(1)  # reserved
@@
-    for _i in range(N):
-        row_len = br.read_bits(16)
+    for _i in range(N):
+        row_len = br.read_bits(16)
+        if hasattr(br, "available_bits") and br.available_bits() < row_len * 8:
+            raise EOFError("Truncated payload: not enough bytes for row")
         row_bytes = br.read_bytes(row_len)
         rr = _BitReader(row_bytes)
         lb = rr.read_bits(5)
         k_param = rr.read_bits(4)
         use_bitmap = rr.read_bits(1)
-        B = 1 << lb
-        n_sub = C // B
+        B = 1 << lb
+        # B must be a power of two ≤ C and divide C exactly.
+        if B < 1 or (B & (B - 1)) != 0 or (C % B) != 0:
+            raise ValueError(f"Invalid row block size B={B} for C={C}")
+        n_sub = C // B
@@
-        if use_bitmap:
+        if use_bitmap:
             bitmask = rr.read_bits(B)
             for loc in range(B):
                 if (bitmask >> loc) & 1:
                     indices.append(j * B + loc)
         else:
             for _ in range(s_len):
-                loc = rr.read_bits(lb)
+                loc = rr.read_bits(lb)
+                if loc >= B:
+                    raise ValueError(f"Decoded loc {loc} out of range for B={B}")
                 indices.append(j * B + loc)
         rows.append(indices)
+        # Ensure row chunk fully consumed
+        if rr.idx != len(row_bytes):
+            raise ValueError("Row payload has trailing or malformed data")
     return rows, C, N
```
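As background on the bitmap-vs-local paths the decoder validates: per sub-block, a bitmap always costs B bits, while local indices cost s_len · log2(B) bits, so bitmap wins only for dense sub-blocks. A sketch of that trade-off (the encoder's actual selection rule is not shown in the review):

```python
import math

def subblock_cost_bits(B: int, s_len: int) -> tuple[int, int]:
    """Return (bitmap cost, local-index cost) in bits for one sub-block."""
    lb = int(math.log2(B))   # bits per local index; B is a power of two
    return B, s_len * lb

print(subblock_cost_bits(64, 3))   # (64, 18): 3 indices, local encoding wins
print(subblock_cost_bits(64, 20))  # (64, 120): 20 indices, bitmap wins
```

The break-even point is s_len = B / log2(B), e.g. roughly 10 occupied slots for B=64.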
301-304: Use Triton’s math API for log2 to avoid compilation issues. tl.math.log2 is the supported path; cast explicitly.
```diff
-    lb = tl.log2(B.to(tl.float32)).to(tl.int32)
+    lb = tl.math.log2(B.to(tl.float32)).to(tl.int32)
```
370-393: Add dtype/range guards in the public encoder. Fail fast on wrong dtype or indices outside [0, C).
```diff
 def encode_batch_rows(
     idx: torch.Tensor,  # [rows, k] int64 (CUDA strongly recommended)
@@
-    if idx.dtype != torch.int64:
-        idx = idx.to(torch.int64)
+    if idx.dtype != torch.int64:
+        idx = idx.to(torch.int64)
@@
-    rows, k = idx.shape
+    rows, k = idx.shape
+    if idx.numel():
+        mn = int(idx.min().item())
+        mx = int(idx.max().item())
+        if mn < 0 or mx >= C:
+            raise ValueError(f"encode_batch_rows: index out of range for C={C}: min={mn}, max={mx}")
```
🧹 Nitpick comments (4)
src/tplr/compress/__init__.py (2)
18-21: Fix misleading comment: the encoder does not return a perm. encode_batch_rows returns (bytes, meta). Drop “+ perm” to avoid confusion.
```diff
-from .bits import (
-    decode_batch_rows,  # decoder (CPU)
-    encode_batch_rows,  # GPU-accelerated encoder → bytes + perm + meta
-)
+from .bits import (
+    decode_batch_rows,  # decoder (CPU)
+    encode_batch_rows,  # GPU-accelerated encoder → bytes + meta
+)
```
22-33: Consider deprecating legacy pack12 exports in the public API. Given the migration policy to accept only uint8 Rice/bitmap on the wire, re-exporting pack_12bit_indices/unpack_12bit_indices can prolong legacy usage. Either mark them as deprecated with warnings or drop them from `__all__`.
Do you want a follow-up PR to emit DeprecationWarning on import?
src/tplr/compress/bits.py (2)
62-74: Expose available_bits() to enable safe length checks. Small helper to validate remaining capacity before reads.
```diff
     def read_bytes(self, n: int) -> bytes:
         out = bytearray()
         for _ in range(n):
             out.append(self.read_bits(8))
         return bytes(out)
+
+    def available_bits(self) -> int:
+        # Bits left in cur plus remaining bytes.
+        return self.nbits + (len(self.data) - self.idx) * 8
```
445-451: Expose MAX_K and reflect the new top-k ceiling in the error text. Consider exporting MAX_K or documenting it to align with the topk=128 default.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (2)
- src/tplr/compress/__init__.py (1 hunks)
- src/tplr/compress/bits.py (1 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: joellidin
PR: tplr-ai/templar#553
File: src/tplr/comms.py:2626-2702
Timestamp: 2025-08-31T15:02:03.568Z
Learning: The user joellidin has explicitly decided not to implement backward compatibility for legacy int64 index tensors in the check_compressed_indices function in src/tplr/comms.py. The function should only accept uint8 Rice/bitmap payload format and reject other formats to enforce network-wide migration to the new compression scheme.
📚 Learning: 2025-08-31T15:02:03.568Z
Learnt from: joellidin
PR: tplr-ai/templar#553
File: src/tplr/comms.py:2626-2702
Timestamp: 2025-08-31T15:02:03.568Z
Learning: The user joellidin has explicitly decided not to implement backward compatibility for legacy int64 index tensors in the check_compressed_indices function in src/tplr/comms.py. The function should only accept uint8 Rice/bitmap payload format and reject other formats to enforce network-wide migration to the new compression scheme.
Applied to files:
src/tplr/compress/bits.py
🧬 Code graph analysis (1)
src/tplr/compress/__init__.py (2)
src/tplr/compress/bits.py (2)
- decode_batch_rows (82-118), encode_batch_rows (371-479)

src/tplr/compress/pack12.py (2)
- pack_12bit_indices (21-47), unpack_12bit_indices (50-75)
🪛 GitHub Actions: CI
src/tplr/compress/bits.py
[error] 18-23: I001: Import block is un-sorted or un-formatted.
from __future__ import annotations
import math
from typing import Sequence

import torch
import torch.nn.functional as F  # only used if we pick the one-hot path

try:
    import triton
    import triton.language as tl

    TRITON_AVAILABLE = True
except Exception:
    TRITON_AVAILABLE = False
Fix CI import-order error and remove unused F import.
Ruff I001 flags lines 18–23; also torch.nn.functional as F is unused.
-from __future__ import annotations
-import math
-from typing import Sequence
-
-import torch
-import torch.nn.functional as F # only used if we pick the one-hot path
-
-try:
- import triton
- import triton.language as tl
-
- TRITON_AVAILABLE = True
-except Exception:
- TRITON_AVAILABLE = False
+from __future__ import annotations
+
+import math
+from typing import Sequence
+
+import torch
+
+try:
+ import triton
+ import triton.language as tl
+ TRITON_AVAILABLE = True
+except Exception:
+    TRITON_AVAILABLE = False

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
Before:
from __future__ import annotations
import math
from typing import Sequence

import torch
import torch.nn.functional as F  # only used if we pick the one-hot path

try:
    import triton
    import triton.language as tl

    TRITON_AVAILABLE = True
except Exception:
    TRITON_AVAILABLE = False

After:
from __future__ import annotations

import math
from typing import Sequence

import torch

try:
    import triton
    import triton.language as tl
    TRITON_AVAILABLE = True
except Exception:
    TRITON_AVAILABLE = False
🧰 Tools
🪛 GitHub Actions: CI
[error] 18-23: I001: Import block is un-sorted or un-formatted.
🤖 Prompt for AI Agents
In src/tplr/compress/bits.py around lines 18 to 31, fix the import-order and
remove the unused alias: move stdlib imports (math, typing) directly after the
from __future__ line, then import third-party packages (torch) and remove
"import torch.nn.functional as F" since F is unused, and keep the triton
try/except block after torch; ensure grouping follows stdlib then third-party
order to satisfy Ruff I001.
Force-pushed 2727ada to 65167f1 (Compare)
Actionable comments posted: 3
♻️ Duplicate comments (4)
src/tplr/compress/bits.py (4)
45-53: Prevent bit underflow on truncated input in BitReader.read_bits.Guard after _fill to avoid corrupting state when not enough bits remain.
Apply:
 def read_bits(self, n: int) -> int:
     if n <= 0:
         return 0
     self._fill(n)
+    if self.nbits < n:
+        raise EOFError(f"Not enough bits in stream to read {n} bits")
     mask = (1 << n) - 1
     out = self.cur & mask
     self.cur >>= n
     self.nbits -= n
     return out
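To make the guard concrete, here is a self-contained sketch of an LSB-first bit reader with the suggested underflow check. Names and bit order are assumed from the diff; this is illustrative, not the repo's actual class.

```python
# Minimal sketch (assumed from the review diff): an LSB-first bit reader
# that raises instead of silently returning short reads.
class BitReader:
    def __init__(self, data: bytes):
        self.data = data
        self.idx = 0    # next byte to consume
        self.cur = 0    # bit accumulator
        self.nbits = 0  # bits currently latched in cur

    def _fill(self, n: int) -> None:
        # Latch whole bytes until n bits are available or input is exhausted.
        while self.nbits < n and self.idx < len(self.data):
            self.cur |= self.data[self.idx] << self.nbits
            self.idx += 1
            self.nbits += 8

    def available_bits(self) -> int:
        # Latched bits plus bits in the bytes not yet consumed.
        return self.nbits + (len(self.data) - self.idx) * 8

    def read_bits(self, n: int) -> int:
        if n <= 0:
            return 0
        self._fill(n)
        if self.nbits < n:
            raise EOFError(f"Not enough bits in stream to read {n} bits")
        out = self.cur & ((1 << n) - 1)
        self.cur >>= n
        self.nbits -= n
        return out
```

With LSB-first order, reading two nibbles of `0xAB` yields the low nibble first, and a further read on the exhausted stream raises `EOFError` rather than corrupting state.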
82-113: Harden decoder against malformed/untrusted payloads (header, per-row length, B, and trailing data).Validate header and row chunks; fail fast on inconsistencies.
Apply:
 br = _BitReader(payload)
 C = br.read_bits(12) + 1
 N = br.read_bits(16)
 _ = br.read_bits(1)  # reserved
-
-rows: list[list[int]] = []
+if C < 1:
+    raise ValueError(f"Invalid C={C}")
+rows: list[list[int]] = []
 for _i in range(N):
-    row_len = br.read_bits(16)
+    row_len = br.read_bits(16)
+    if row_len < 0:
+        raise ValueError("Negative row length")
+    if br.available_bits() < row_len * 8:
+        raise EOFError("Truncated payload: not enough bytes for row")
     row_bytes = br.read_bytes(row_len)
     rr = _BitReader(row_bytes)
     lb = rr.read_bits(5)
     k_param = rr.read_bits(4)
     use_bitmap = rr.read_bits(1)
     B = 1 << lb
-    n_sub = C // B
+    if B < 1 or (C % B) != 0:
+        raise ValueError(f"Invalid row block size B={B} for C={C}")
+    n_sub = C // B
     indices: list[int] = []
     for j in range(n_sub):
         s_len = _rice_read(rr, k_param)
         if s_len == 0:
             continue
         if use_bitmap:
             bitmask = rr.read_bits(B)
             for loc in range(B):
                 if (bitmask >> loc) & 1:
-                    indices.append(j * B + loc)
+                    gidx = j * B + loc
+                    if gidx >= C:
+                        raise ValueError("Decoded index out of range")
+                    indices.append(gidx)
         else:
             for _ in range(s_len):
                 loc = rr.read_bits(lb)
-                indices.append(j * B + loc)
+                if loc >= B:
+                    raise ValueError("Decoded loc out of range for block")
+                gidx = j * B + loc
+                if gidx >= C:
+                    raise ValueError("Decoded index out of range")
+                indices.append(gidx)
     rows.append(indices)
+    if rr.idx != len(row_bytes):
+        raise ValueError("Row payload has trailing or malformed data")
 return rows, C, N
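For readers unfamiliar with the `_rice_read` calls above: Rice coding with parameter k splits a value into a unary quotient (value >> k) and a k-bit remainder. The codec's exact bit convention is not visible in the diff, so this round-trip sketch assumes a common one (quotient as ones terminated by a zero, remainder LSB-first); it is illustrative, not the repo's implementation.

```python
# Hedged sketch of Rice coding with parameter k (the review's k_param).
def rice_write(bits: list[int], value: int, k: int) -> None:
    q, r = value >> k, value & ((1 << k) - 1)
    bits.extend([1] * q)   # unary quotient: q ones
    bits.append(0)         # terminator
    for i in range(k):     # k-bit remainder, LSB first
        bits.append((r >> i) & 1)

def rice_read(bits: list[int], pos: int, k: int) -> tuple[int, int]:
    q = 0
    while bits[pos] == 1:  # count unary ones
        q += 1
        pos += 1
    pos += 1               # skip the terminating zero
    r = 0
    for i in range(k):     # reassemble remainder, LSB first
        r |= bits[pos] << i
        pos += 1
    return (q << k) | r, pos
```

Small values near 2^k encode in roughly k+1 bits, which is why the review describes the k parameter as adaptive: it is tuned to the typical gap size between indices.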
310-311: Use Triton’s math API for log2; tl.log2 can miscompile.Switch to tl.math.log2 to compute lb robustly.
Apply:
-lb = tl.log2(B.to(tl.float32)).to(tl.int32)
+lb = tl.math.log2(B.to(tl.float32)).to(tl.int32)
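Since `lb` is just the exponent of a power-of-two block size B, the float log2 round-trip can also be sidestepped entirely with integer arithmetic. A plain-Python sketch of the idea (the kernel itself would need the equivalent Triton integer ops):

```python
# For power-of-two B, log2 is exact integer arithmetic: no float conversion,
# no rounding hazards.
def log2_pow2(B: int) -> int:
    assert B > 0 and (B & (B - 1)) == 0, "B must be a power of two"
    return B.bit_length() - 1
```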
133-136: Mirror range validation in parameter estimator.Keep GPU counts sane; reject out-of-range indices upfront.
Apply:
 assert idx.dtype == torch.int64
 rows, k = idx.shape
 device = idx.device
+if idx.numel():
+    mn = int(idx.min().item())
+    mx = int(idx.max().item())
+    if mn < 0 or mx >= C:
+        raise ValueError(f"Index out of range for C={C}: min={mn}, max={mx}")
🧹 Nitpick comments (2)
src/tplr/compress/bits.py (2)
39-44: Expose available_bits() for safe decoder bounds checks.Small helper enables validating remaining capacity before reads.
Apply:
 def _fill(self, n: int) -> None:
     while self.nbits < n and self.idx < len(self.data):
         self.cur |= int(self.data[self.idx]) << self.nbits
         self.idx += 1
         self.nbits += 8
+
+def available_bits(self) -> int:
+    # Remaining bits including latched bits.
+    return (len(self.data) - self.idx) * 8 + self.nbits
481-485: Report allocated bits per row in meta (matches actual payload).Use per-row byte-allocation rather than raw bit estimate.
Apply:
-"avg_bits_per_row": float(row_bits.float().mean().item()) if rows > 0 else 0.0,
+"avg_bits_per_row": float((row_bytes.to(torch.int64) * 8).float().mean().item()) if rows > 0 else 0.0,
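The per-row bit accounting above is driven by the bitmap-vs-local choice the decoder handles: a non-empty block of size B costs B bits as a bitmap, or s_len * log2(B) bits when the s_len occupied slots are listed individually. A sketch of that trade-off, assuming B is a power of two (function name is illustrative):

```python
# Per-block cost comparison behind the codec's bitmap-vs-local path selection.
def cheaper_encoding(s_len: int, B: int) -> str:
    lb = B.bit_length() - 1  # log2(B) for power-of-two B
    bitmap_bits = B          # one presence bit per slot
    local_bits = s_len * lb  # one lb-bit offset per occupied slot
    return "bitmap" if bitmap_bits <= local_bits else "local"
```

Sparse blocks favor the local list; once a block holds more than B / log2(B) indices, the fixed-size bitmap wins.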
🚧 Files skipped from review as they are similar to previous changes (1)
- src/tplr/compress/__init__.py
🪛 GitHub Actions: CI
src/tplr/compress/bits.py
[error] 13-13: I001: Import block is un-sorted or un-formatted.
try:
    import triton
    import triton.language as tl

    TRITON_AVAILABLE = True
except Exception:
    TRITON_AVAILABLE = False
🛠️ Refactor suggestion
Fix import block to satisfy Ruff I001 (don’t assign inside import block).
Move TRITON_AVAILABLE assignment out of the try/except and avoid blank separation inside the block.
Apply:
-try:
- import triton
- import triton.language as tl
-
- TRITON_AVAILABLE = True
-except Exception:
- TRITON_AVAILABLE = False
+try:
+ import triton
+ import triton.language as tl
+except Exception:
+ triton = None
+ tl = None
+TRITON_AVAILABLE = triton is not None

📝 Committable suggestion
Before:
try:
    import triton
    import triton.language as tl

    TRITON_AVAILABLE = True
except Exception:
    TRITON_AVAILABLE = False

After:
try:
    import triton
    import triton.language as tl
except Exception:
    triton = None
    tl = None
TRITON_AVAILABLE = triton is not None
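The suggestion above keeps the import block pure and derives the flag afterwards. A more general sketch of the same optional-dependency pattern, using `importlib` so it can be demonstrated without triton installed (module names here are illustrative):

```python
import importlib
import importlib.util

def optional_import(name: str):
    # Probe for the module without triggering an ImportError; return the
    # module object when present, None otherwise.
    if importlib.util.find_spec(name) is None:
        return None
    return importlib.import_module(name)

# Same shape as the review's suggestion: import first, derive the flag after.
triton = optional_import("triton")  # None when triton isn't installed
TRITON_AVAILABLE = triton is not None
```

Keeping the availability flag outside the try/except means linters see a clean import block, and the flag's definition site is unconditional.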
🤖 Prompt for AI Agents
In src/tplr/compress/bits.py around lines 20 to 26, the import block assigns
TRITON_AVAILABLE inside the try which triggers Ruff I001; move the assignment
out of the try/except: initialize TRITON_AVAILABLE = False before the try, keep
the try block containing only the import statements with no extra blank line,
use except Exception: pass, and after the try/except set TRITON_AVAILABLE = True
only if the imports succeeded (for example by checking that the imported name
exists in locals() or by using a try/except/else and assigning True in the
else).
if idx.dtype != torch.int64:
    idx = idx.to(torch.int64)
if not idx.is_cuda:
    idx = idx.cuda()
idx = idx.contiguous()
🛠️ Refactor suggestion
Validate index range early to prevent silent corruption (encode path).
Fail fast if any idx < 0 or ≥ C.
Apply:
if idx.dtype != torch.int64:
idx = idx.to(torch.int64)
if not idx.is_cuda:
idx = idx.cuda()
idx = idx.contiguous()
+ if idx.numel():
+ mn = int(idx.min().item())
+ mx = int(idx.max().item())
+ if mn < 0 or mx >= C:
+        raise ValueError(f"Index out of range for C={C}: min={mn}, max={mx}")

📝 Committable suggestion
Before:
if idx.dtype != torch.int64:
    idx = idx.to(torch.int64)
if not idx.is_cuda:
    idx = idx.cuda()
idx = idx.contiguous()

After:
if idx.dtype != torch.int64:
    idx = idx.to(torch.int64)
if not idx.is_cuda:
    idx = idx.cuda()
idx = idx.contiguous()
# Fail fast on any out-of-bounds index
if idx.numel():
    mn = int(idx.min().item())
    mx = int(idx.max().item())
    if mn < 0 or mx >= C:
        raise ValueError(f"Index out of range for C={C}: min={mn}, max={mx}")
🤖 Prompt for AI Agents
In src/tplr/compress/bits.py around lines 386–391, add explicit range validation
for idx before it can silently corrupt encoded data: after converting idx to an
integer tensor (and before using it on CUDA), move a copy to CPU if needed,
compute min_val = int(idx.min().item()) and max_val = int(idx.max().item()), and
if min_val < 0 or max_val >= C raise a ValueError (or assert) with a clear
message containing min_val, max_val and C; this ensures we fail fast on
out-of-range indices and avoids silent corruption.
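The fail-fast check itself is device-agnostic; a pure-Python sketch of the same logic (the real code operates on a torch tensor, lists are used here so the sketch runs anywhere):

```python
# Reject any index outside [0, C) before encoding; empty input is valid.
def validate_indices(idx: list[int], C: int) -> None:
    if not idx:
        return
    mn, mx = min(idx), max(idx)
    if mn < 0 or mx >= C:
        raise ValueError(f"Index out of range for C={C}: min={mn}, max={mx}")
```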
Force-pushed 65167f1 to 62815c0 (Compare)
Force-pushed 62815c0 to 51fed2b (Compare)
Description
ChunkingTransformer
Related Issue(s)
Type of Change
Branch Naming
Commit Messages
Code Quality
Testing
Documentation
If this is a breaking change
Screenshots/Examples
Additional Notes
Summary by CodeRabbit
New Features
Changes
Bug Fixes
Refactor
Tests