feat/efficient compress #553

Draft

joellidin wants to merge 12 commits into dev from feat/efficient-compress
Conversation


joellidin (Collaborator) commented Aug 31, 2025

Description

  • (compress) Refactor into package structure
  • (compress) Add Rice/bitmap codec for indices
  • (tests) Add unit tests for Rice/bitmap codec
  • (compress) Migrate to Rice/bitmap codec
  • (compress) Remove DCT from ChunkingTransformer
  • (hparams) Increase topk compression from 32 to 128

Related Issue(s)

  • Closes #[issue number]

Type of Change

  • Feature (adding new functionality)
  • Fix (resolving a bug or issue)
  • Docs (documentation updates)
  • Refactor (code changes that don't affect functionality)
  • Maintenance (dependency updates or other maintenance)
  • Tests (adding or improving tests)
  • Breaking change (fix or feature with incompatible API changes)
  • Other: _____

Branch Naming

  • My branch follows the project's naming convention (e.g., feature/add-new-capability)

Commit Messages

  • My commits are small, atomic, and have proper commit messages
  • Commit messages are in imperative mood with a capitalized summary under 50 chars

Code Quality

  • I've performed a self-review of my code
  • I've added appropriate docstrings following the project's conventions
  • I've added proper logging where necessary (without trailing periods)
  • I've applied linting and formatting with Ruff
  • My code generates no new warnings

Testing

  • I've added tests for new functionality or bug fixes
  • All tests pass locally with my changes
  • Test coverage has not decreased

Documentation

  • I've updated documentation to reflect my changes
  • I've updated comments in hard-to-understand areas

If this is a breaking change

Screenshots/Examples

Additional Notes

Summary by CodeRabbit

  • New Features

    • New Rice/bitmap top‑k codec with CPU/GPU encoders/decoders that emit compact byte payloads and permutation metadata; unified compression exports.
  • Changes

    • Default top‑k compression increased 32 → 128.
    • Transformer API simplified: DCT/use_dct toggle removed; encode/decode use default behavior.
    • Initialization now logs per‑layer encode/compress timing and totals.
  • Bug Fixes

    • Stricter payload validation, clearer error messages, and explicit input/type checks.
  • Refactor

    • Legacy 12‑bit packing and DCT-based paths replaced by the new codec.
  • Tests

    • Added/updated comprehensive tests for the new codec and interfaces.

Reorganize the compress module into a proper package structure to
improve maintainability and separate concerns.

- Move compress.py to compress/topk.py for TopK-specific functionality
- Create compress/__init__.py with clean public API exports
- Extract 12-bit packing functions to compress/pack12.py
- Update test imports to use new package structure
- Maintain backward compatibility through __init__.py exports

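The pack12 scheme referenced above (two 12-bit indices per 3 bytes) can be sketched in pure Python. The function names echo the package's pack_12bit_indices/unpack_12bit_indices, but this is an illustrative byte-level version under that assumption, not the repo's torch-based implementation:

```python
# Minimal sketch of 12-bit index packing: each pair of indices (< 4096)
# is packed into 3 bytes. Hypothetical helpers; the repo's pack12.py
# operates on torch tensors and may use a different layout.

def pack_12bit(indices: list[int]) -> bytes:
    """Pack an even-length list of 12-bit integers into bytes (2 -> 3)."""
    assert len(indices) % 2 == 0, "pad to an even count first"
    out = bytearray()
    for a, b in zip(indices[0::2], indices[1::2]):
        assert 0 <= a < 4096 and 0 <= b < 4096
        out.append(a >> 4)                       # high 8 bits of a
        out.append(((a & 0xF) << 4) | (b >> 8))  # low 4 of a | high 4 of b
        out.append(b & 0xFF)                     # low 8 bits of b
    return bytes(out)

def unpack_12bit(data: bytes) -> list[int]:
    """Inverse of pack_12bit: 3 bytes -> two 12-bit integers."""
    out = []
    for i in range(0, len(data), 3):
        b0, b1, b2 = data[i], data[i + 1], data[i + 2]
        out.append((b0 << 4) | (b1 >> 4))
        out.append(((b1 & 0xF) << 8) | b2)
    return out
```

The even-count requirement is why 12-bit paths historically forced an even top-k; the Rice/bitmap codec introduced below removes that constraint.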
Implement new GPU-accelerated compression codec for gradient indices
using Rice coding with adaptive parameters and bitmap encoding for dense
blocks.

- Add BitWriter/BitReader classes for efficient bit-level I/O
- Implement encode_batch_rows with GPU-accelerated path and CPU fallback
- Add decode_batch_rows for CPU-based decompression
- Support adaptive Rice parameters and bitmap vs local encoding
- Export new codec functions in compress package __init__.py
- Maintain backward compatibility with existing pack12/topk modules

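As background on the Rice-coding building block named above, here is a minimal pure-Python sketch of bit-level I/O plus Rice encode/decode with a fixed parameter k. The repo's BitWriter/BitReader and encode_batch_rows are more elaborate (adaptive parameters, bitmap blocks, a GPU path), so treat this as an illustration of the scheme, not the actual codec:

```python
# Rice coding sketch: value v with parameter k is written as q = v >> k
# in unary (q ones then a zero), followed by the low k bits of v.

class BitWriter:
    def __init__(self):
        self.bits: list[int] = []

    def write_bits(self, value: int, n: int) -> None:
        for i in reversed(range(n)):          # MSB first
            self.bits.append((value >> i) & 1)

    def to_bytes(self) -> bytes:
        bits = self.bits + [0] * (-len(self.bits) % 8)  # pad to a byte boundary
        return bytes(
            int("".join(map(str, bits[i:i + 8])), 2)
            for i in range(0, len(bits), 8)
        )

class BitReader:
    def __init__(self, data: bytes):
        self.data, self.pos = data, 0

    def read_bits(self, n: int) -> int:
        v = 0
        for _ in range(n):
            byte = self.data[self.pos >> 3]
            v = (v << 1) | ((byte >> (7 - (self.pos & 7))) & 1)
            self.pos += 1
        return v

def rice_write(bw: BitWriter, v: int, k: int) -> None:
    for _ in range(v >> k):
        bw.write_bits(1, 1)                   # unary quotient
    bw.write_bits(0, 1)                       # terminator
    bw.write_bits(v & ((1 << k) - 1), k)      # remainder, k bits

def rice_read(br: BitReader, k: int) -> int:
    q = 0
    while br.read_bits(1) == 1:
        q += 1
    return (q << k) | br.read_bits(k)
```

In an index codec, Rice coding is typically applied to small nonnegative quantities such as gaps between sorted indices or per-block counts, where most values are near zero and the unary quotient stays short.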
Add extensive test coverage for the bits compression codec including
round-trip encode/decode validation, permutation tracking, and
cross-device compatibility testing.

- Test round-trip operations ensuring data integrity
- Verify permutation tracking for correct value reordering
- Cover edge cases: zero rows, zero K values, empty tensors
- Test CPU and GPU device parity when CUDA available
- Validate bitmap vs local encoding path selection
- Ensure compatibility with C values divisible by B_choices
- Test both GPU-accelerated and CPU reference implementations
- Verify codec handles unique indices from TopK operations
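The bitmap-vs-local path selection exercised by these tests comes down to a per-sub-block cost comparison. Assuming the standard trade-off (a bitmap always costs B bits, while listing s set positions locally costs s * ceil(log2(B)) bits), a sketch with hypothetical helper names:

```python
# For one sub-block of width B containing s of the row's indices:
# a bitmap spends one bit per slot; a local list spends ceil(log2(B))
# bits per index. The actual codec in bits.py adds Rice-coded counts
# and headers on top of this choice.
import math

def subblock_cost_bits(B: int, s: int) -> tuple[int, int]:
    lb = max(1, math.ceil(math.log2(B)))  # bits per local position
    return B, s * lb                      # (bitmap cost, local cost)

def choose_bitmap(B: int, s: int) -> bool:
    bitmap, local = subblock_cost_bits(B, s)
    return bitmap <= local                # dense blocks favor the bitmap
```

With B = 16 (lb = 4), one set position costs 4 bits locally against a 16-bit bitmap, so sparse blocks go local; at s = 4 the costs tie (this sketch then picks the bitmap), and denser blocks favor the bitmap outright.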
coderabbitai bot commented Aug 31, 2025

Warning

Rate limit exceeded

@joellidin has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 11 minutes and 46 seconds before requesting another review.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.


📥 Commits

Reviewing files that changed from the base of the PR and between 65167f1 and 51fed2b.

📒 Files selected for processing (2)
  • src/tplr/compress/__init__.py (1 hunks)
  • src/tplr/compress/bits.py (1 hunks)

Walkthrough

Replaces 12‑bit packed indices with a Rice/bitmap bits codec and GPU encoder/CPU decoder, centralizes compress exports, removes use_dct from transformer and outer_step APIs, updates comms validation to use decode_batch_rows, raises topk_compression 32→128, and adds encoder/compressor timing instrumentation and tests.

Changes

  • Config (hparams/hparams.json): topk_compression changed 32 → 128; removed use_dct field.
  • Compress public API & topk (src/tplr/compress/__init__.py, src/tplr/compress/topk.py): Added package initializer re-exporting the public API; replaced 12‑bit packing with the Rice/bitmap batch codec (encode_batch_rows/decode_batch_rows); removed DCT helpers and use_dct toggles; updated ChunkingTransformer signatures and TopKCompressor compress/decompress contracts; added _DEFAULT_B_CHOICES and QuantParamsT.
  • Bits codec & pack12 utilities (src/tplr/compress/bits.py, src/tplr/compress/pack12.py): New Triton-enabled GPU encoder and CPU decoder for the Rice/bitmap bitstream (encode_batch_rows, decode_batch_rows), plus bitstream helpers and readers; added legacy 12‑bit pack/unpack utilities in pack12.py.
  • Comms validation (src/tplr/comms.py, tests/test_comms.py): Validation moved to a decoder-driven flow using decode_batch_rows; stricter dtype/shape checks and explicit ValueError messages; tests updated to accept byte payload → torch.uint8 tensors and new error strings.
  • Neurons & callsites (src/tplr/neurons.py, neurons/miner.py, neurons/trainer.py, neurons/validator.py): Removed the use_dct parameter from encode/decode call sites and the outer_step signature; replaced 12‑bit unpacking with decode_batch_rows; validator now collects per-layer encode/compress timing and logs totals; normalized payload bytes handling.
  • Compressor internals (src/tplr/compress/*): TopK compressor emits index payload bytes and permutation metadata; decompress/batch_decompress updated to consume the new payload format and validate totalk; quantization qparams adjusted.
  • Evaluator import (neurons/evaluator.py): Added from lm_eval import simple_evaluate import.
  • Tests (tests/unit/test_bits_codec.py, tests/unit/test_compress.py, tests/test_prepare_gradient_dict.py, tests/unit/test_neurons.py, tests/test_comms.py): Added comprehensive bits codec tests and adapted compress/comms/neuron tests to the new codec; removed DCT/use_dct from test doubles and call sites; retained legacy 12‑bit checks where relevant.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Outer as outer_step / Trainer
  participant Transformer
  participant Compressor
  participant Comms
  participant Peer

  rect rgb(240,248,255)
    note over Outer: Outer step (no use_dct)
    Outer->>Transformer: encode(tensor)
    Transformer-->>Outer: encoded_chunks
    Outer->>Compressor: compress(encoded_chunks, K)
    Compressor-->>Outer: (idx_bytes, perm2d, vals, xshape, totalk[, qparams])
  end

  rect rgb(245,255,240)
    note over Outer,Comms: transmit & validate
    Outer->>Comms: check_compressed_indices(idxs=uint8_payload, vals, totalk, k)
    Comms->>Comms: rows_list, C, N = decode_batch_rows(payload_bytes)
    Comms-->>Outer: validation_result / ValueError
    Outer-->>Peer: send payload bytes + vals
  end

  rect rgb(255,248,240)
    note over Peer: decompress & reconstruct
    Peer->>Compressor: batch_decompress(payload_bytes, vals, ...)
    Compressor->>Transformer: decode(decompressed)
    Transformer-->>Peer: reconstructed_tensor
  end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Suggested reviewers

  • AlexanderLavelle
  • amiiir-sarfi
  • distributedstatemachine

Poem

I nibble bytes and hop through rows,
Rice seeds tumble where the top‑K grows.
No DCT hat, just bits that play —
Perms and payloads lead the way.
A rabbit cheers: compress, decode, hooray! 🐇✨



codecov bot commented Aug 31, 2025

Codecov Report

❌ Patch coverage is 89.08686% with 49 lines in your changes missing coverage. Please review.

Files with missing lines (patch %, lines missing):
  • src/tplr/compress/bits.py: 90.85% patch, 30 missing ⚠️
  • src/tplr/compress/topk.py: 77.77% patch, 10 missing ⚠️
  • src/tplr/comms.py: 70.96% patch, 9 missing ⚠️

Impacted file tree graph

@@            Coverage Diff             @@
##              dev     #553      +/-   ##
==========================================
+ Coverage   61.07%   62.92%   +1.85%     
==========================================
  Files          25       28       +3     
  Lines        4056     4340     +284     
==========================================
+ Hits         2477     2731     +254     
- Misses       1579     1609      +30     
Files with missing lines (coverage, Δ):
  • src/tplr/compress/__init__.py: 100.00% <100.00%> (ø)
  • src/tplr/compress/pack12.py: 100.00% <100.00%> (ø)
  • src/tplr/neurons.py: 75.37% <100.00%> (-0.06%) ⬇️
  • src/tplr/comms.py: 69.18% <70.96%> (-0.20%) ⬇️
  • src/tplr/compress/topk.py: 75.96% <77.77%> (ø)
  • src/tplr/compress/bits.py: 90.85% <90.85%> (ø)

coderabbitai bot left a comment

Actionable comments posted: 7

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (8)
tests/test_comms.py (2)

779-791: Fix async client context in s3_put_small_file test (missing __aenter__/__aexit__).

Currently new_client.__aenter__() returns a fresh AsyncMock without put_object, causing an AttributeError.

Apply:

-    mock_client = AsyncMock()
-    mock_client.put_object = AsyncMock()
+    mock_client = AsyncMock()
+    mock_client.put_object = AsyncMock()
+    # ensure _get_s3_client awaits __aenter__ and gets this same client
+    mock_client.__aenter__.return_value = mock_client
+    mock_client.__aexit__.return_value = AsyncMock()
     comms_instance.session.create_client = MagicMock(return_value=mock_client)

809-817: Same async context fix for the multipart upload test.

Without __aenter__/__aexit__, create_multipart_upload and friends may not be found on the awaited client.

Apply:

-    mock_client = AsyncMock()
+    mock_client = AsyncMock()
     mock_client.create_multipart_upload = AsyncMock(
         return_value={"UploadId": "test_id"}
     )
     mock_client.upload_part = AsyncMock(return_value={"ETag": "test_etag"})
     mock_client.complete_multipart_upload = AsyncMock()
     mock_client.abort_multipart_upload = AsyncMock()
+    mock_client.__aenter__.return_value = mock_client
+    mock_client.__aexit__.return_value = AsyncMock()
     comms_instance.session.create_client = MagicMock(return_value=mock_client)
src/tplr/neurons.py (2)

36-36: Stop importing 12‑bit unpacker; use Rice/bitmap decoder.

check_uid_index_overlap still assumes 12‑bit. Replace with decode_batch_rows to match the new codec; keep a fallback for legacy int64 tensors.

Apply:

-from tplr.compress import unpack_12bit_indices
+from tplr.compress import decode_batch_rows

952-964: Overlap detector decodes indices with 12‑bit unpacker (wrong format).

This will miscompute overlaps for Rice/bitmap payloads. Decode bytes via decode_batch_rows and reshape to vals shape; keep int64 fallback.

Apply:

-        # Unpack all 12-bit packed indices using values shape
+        # Decode indices (Rice/bitmap bytes → rows → reshape) or accept int64 legacy
         unpacked_indices = []
         for i in range(Ptot):
             idx_data = idxs_all[i] if isinstance(idxs_all, list) else idxs_all
             val_data = vals_all[i] if isinstance(vals_all, list) else vals_all
-
-            # 12-bit packed format - use values shape for unpacking
-            unpacked = unpack_12bit_indices(
-                idx_data.to(neuron.config.device), val_data.shape
-            )
+            if isinstance(idx_data, torch.Tensor) and idx_data.dtype == torch.uint8:
+                rows, _C, _N = decode_batch_rows(idx_data.detach().cpu().numpy().tobytes())
+                if any(len(r) != val_data.shape[-1] for r in rows):
+                    raise ValueError("Row-wise topk size mismatch in index payload (overlap check)")
+                unpacked = torch.tensor(rows, dtype=torch.int64, device=neuron.config.device).view(*val_data.shape)
+            elif isinstance(idx_data, torch.Tensor) and idx_data.dtype in (torch.int64, torch.long):
+                unpacked = idx_data.to(neuron.config.device).view(*val_data.shape)
+            else:
+                raise ValueError(f"Unsupported index dtype in overlap check: {getattr(idx_data, 'dtype', type(idx_data))}")
             unpacked_indices.append(unpacked)
tests/unit/test_compress.py (2)

102-141: Same issue: each payload encodes 2 rows; xshape has 8 rows.

Both payloads must encode N=8 (or set xshape=(2,64)).

Apply:

-        p = torch.zeros(8, 64)
-        xshape = (8, 64)
-        totalk = 64
-        idx1_orig = torch.tensor([[0, 1], [2, 3]], dtype=torch.int64)
-        idx2_orig = torch.tensor([[4, 5], [6, 7]], dtype=torch.int64)
+        p = torch.zeros(8, 64)
+        xshape = (8, 64)
+        totalk = 64
+        K = 2
+        idx1_orig = torch.stack(
+            [torch.tensor([0, 1], dtype=torch.int64) for _ in range(8)]
+        )
+        idx2_orig = torch.stack(
+            [torch.tensor([2, 3], dtype=torch.int64) for _ in range(8)]
+        )
@@
-        val1 = torch.tensor([[0.1, 0.2], [0.3, 0.4]], dtype=torch.float32)
-        val2 = torch.tensor([[0.5, 0.6], [0.7, 0.8]], dtype=torch.float32)
+        val1 = torch.randn(8, K, dtype=torch.float32)
+        val2 = torch.randn(8, K, dtype=torch.float32)
@@
-        val1 = torch.gather(val1, dim=1, index=perm1)
-        val2 = torch.gather(val2, dim=1, index=perm2)
+        val1 = torch.gather(val1, dim=1, index=perm1)
+        val2 = torch.gather(val2, dim=1, index=perm2)

205-241: Again: payload encodes 1 row; xshape has 8 rows.

Align N in payload with xshape’s row count.

Apply:

-        p = torch.zeros(8, 64)
-        xshape = (8, 64)
-        totalk = 64
-        idx_orig = torch.tensor([[0, 1, 2, 3]], dtype=torch.int64)
+        p = torch.zeros(8, 64)
+        xshape = (8, 64)
+        totalk = 64
+        K = 4
+        idx_orig = torch.stack(
+            [torch.arange(K, dtype=torch.int64) for _ in range(8)]
+        )
@@
-        val_orig = torch.tensor([[10.0, 20.0, 30.0, 40.0]], dtype=torch.float32)
-        val = [torch.gather(val_orig, dim=1, index=perm)]
+        val_orig = torch.full((8, K), 10.0, dtype=torch.float32)
+        val = [torch.gather(val_orig, dim=1, index=perm)]
src/tplr/compress/topk.py (2)

430-471: Bug: boolean tests on tensors in _quantize_values will raise.

Using a plain Python or-expression with torch tensors is invalid (the truth value of a tensor is ambiguous). Also ensure scale is finite via tensor ops.

Apply:

-        scale = self.range_in_sigmas * std / self.n_bins
-        if (
-            isinstance(scale, torch.Tensor)
-            and (scale == 0 or torch.isnan(scale) or torch.isinf(scale))
-        ) or (
-            not isinstance(scale, torch.Tensor)
-            and (scale == 0 or not math.isfinite(float(scale)))
-        ):
-            scale = torch.tensor(1.0, dtype=centered.dtype, device=val.device)
+        scale_t = (self.range_in_sigmas * std) / float(self.n_bins)
+        if not isinstance(scale_t, torch.Tensor):
+            scale_t = torch.tensor(float(scale_t), dtype=centered.dtype, device=val.device)
+        if torch.isnan(scale_t) or torch.isinf(scale_t) or (scale_t == 0):
+            scale_t = torch.tensor(1.0, dtype=centered.dtype, device=val.device)
@@
-        centered_fp32 = centered.to(torch.float32)
-        qval = ((centered_fp32 / scale + offset).round().clamp(0, self.n_bins - 1)).to(
+        centered_fp32 = centered.to(torch.float32)
+        qval = ((centered_fp32 / scale_t + offset).round().clamp(0, self.n_bins - 1)).to(
             torch.uint8
         )
@@
-        qparams: QuantParamsT = (shift, float(scale), int(offset), lookup, val.dtype)
+        qparams: QuantParamsT = (shift, float(scale_t.item()), int(offset), lookup, val.dtype)

494-545: maybe_dequantize_values: tuple/list detection is unsafe and can raise TypeError.

Calling len() on elements of the tuple (e.g., a float) will crash. Simplify to a straightforward normalisation.

Apply:

-        if (
-            isinstance(qparams, tuple)
-            and len(qparams) == 5  # potentially single or already 5 elements
-            and not all([len(q) == 5 for q in qparams])  # already correctly formatted
-        ):
-            qparams = [qparams]
-        if not isinstance(qparams, list):
-            qparams = [qparams]
+        if isinstance(qparams, tuple):
+            qparams = [qparams]
+        elif not isinstance(qparams, list):
+            qparams = [qparams]  # type: ignore[list-item]
🧹 Nitpick comments (19)
hparams/hparams.json (3)

13-15: Check semantic consistency: topk_compression (128) vs target_chunk (64).

If target_chunk bounds the number of active entries per chunk, having topk_compression > target_chunk could be inconsistent or wasteful. Validate intended relationship or add an assertion.


1-3: Consider bumping spec_version or adding a migration note.

Given the removal of use_dct across the codebase and the new codec, keeping spec_version at 5 may confuse consumers parsing configs. Either bump spec_version or document that spec v5 now implies Rice/bitmap-only.


13-13: Add tests and performance checks for topk_compression=128

  • Verify the Rice/bitmap codec’s dynamic bit-length encoding correctly handles k=128 (no hard-coded 32-value limits were found) by adding round-trip/property tests at k=128 and extreme sparse/dense scenarios.
  • Benchmark RPC/message payload sizes and per-batch memory overhead at target_batch_size=1024 to quantify any increases versus k=32.
  • Update or add tests that currently assume or reference 32 (e.g., in tests/…, local_miner_test.py) to cover the new default of 128.
neurons/evaluator.py (1)

63-63: Remove unused import and avoid shelling out; also fix variable shadowing in result parsing.

  • The added import from lm_eval import simple_evaluate is unused.
  • Prefer subprocess.run(..., check=True) over os.system for better error handling.
  • In _process_results, the loop variable task_name shadows the function argument; rename the inner variable.

Apply these diffs:

- from lm_eval import simple_evaluate
+ # (optional) import only if switching to in-process evaluation
+ # from lm_eval import simple_evaluate
-        # TODO: Consider replacing os.system with subprocess.run for better control and error handling.
-        exit_code = os.system(command)
+        import subprocess
+        try:
+            completed = subprocess.run(command, shell=True, check=False)
+            exit_code = completed.returncode
+        except Exception as e:
+            tplr.logger.error(f"Failed to launch benchmark: {e}")
+            exit_code = 1
-        for task_name, task_results in results.get("results", {}).items():
+        for task, task_results in results.get("results", {}).items():
@@
-            if metric_value is not None:
+            if metric_value is not None:
                 tplr.logger.info(
-                    f"Benchmark for {task_name} ({used_metric}): {metric_value}"
+                    f"Benchmark for {task} ({used_metric}): {metric_value}"
                 )
                 self.metrics_logger.log(
                     measurement="benchmark_task",
                     tags={
-                        "task": task_name,
+                        "task": task,

Also applies to: 476-478, 538-556

src/tplr/compress/bits.py (4)

224-229: Validate header bounds (C, N) before writing.

Header allocates 12 bits for C-1 and 16 bits for N. Guard overflows early.

-    bw = BitWriter()
+    bw = BitWriter()
+    if C <= 0 or C > (1 << 12):
+        raise ValueError(f"C={C} out of range for 12-bit header (1..4096)")
+    if N > 0xFFFF:
+        raise ValueError(f"N={N} exceeds 16-bit row count limit ({N} > 65535)")

268-279: Add the same header bounds check to the GPU path.

Mirror CPU encoder constraints to avoid producing undecodable payloads.

     rows, k = idx.shape
     device = idx.device
+    if C <= 0 or C > (1 << 12):
+        raise ValueError(f"C={C} out of range for 12-bit header (1..4096)")
+    if rows > 0xFFFF:
+        raise ValueError(f"rows={rows} exceeds 16-bit row count limit (65535)")

252-318: Avoid F.one_hot blow-up for large (rows, k, C/Bmin).

The one-hot path can allocate rows×k×(C/Bmin) temporarily. Consider using a bincount-based aggregation for all B values to reduce memory.
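The suggested bincount-style aggregation can be sketched in pure Python for one row; the helper name and exact interface are hypothetical, and on GPU this would map to torch.bincount over idx // B with per-row offsets:

```python
# Count how many of a row's indices land in each sub-block of width B
# without materializing a rows x k x (C // B) one-hot tensor. This is
# an illustrative sketch, not the aggregation bits.py actually uses.

def subblock_counts(row: list[int], C: int, B: int) -> list[int]:
    """Counts per sub-block for one row of indices in [0, C)."""
    nblocks = (C + B - 1) // B
    counts = [0] * nblocks
    for i in row:
        counts[i // B] += 1       # bincount of i // B
    return counts
```

Memory per row is O(C / B) for the counts instead of the one-hot path's O(k * C / B) temporary, which is the point of the nitpick above.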


452-489: Decode robustness: handle truncated/corrupt row payloads cleanly.

Wrap per‑row decode to surface EOFs as ValueError and avoid partial state.

-    for _i in range(N):
-        row_len = br.read_bits(16)
-        row_bytes = br.read_bytes(row_len)
-        rr = BitReader(row_bytes)
+    for _i in range(N):
+        try:
+            row_len = br.read_bits(16)
+            row_bytes = br.read_bytes(row_len)
+            rr = BitReader(row_bytes)
+        except EOFError as e:
+            raise ValueError(f"Truncated payload while reading row header #{_i}: {e}") from e
@@
-            s_len = _rice_read(rr, k_param)
+            try:
+                s_len = _rice_read(rr, k_param)
+            except EOFError as e:
+                raise ValueError(f"Truncated payload in row #{_i}, sub-block {j}: {e}") from e
@@
-                for loc in range(B):
-                    if (bitmask >> loc) & 1:
-                        indices.append(j * B + loc)
+                for loc in range(B):
+                    if (bitmask >> loc) & 1:
+                        indices.append(j * B + loc)
             else:
-                for _ in range(s_len):
-                    loc = rr.read_bits(lb)
-                    indices.append(j * B + loc)
+                for _ in range(s_len):
+                    try:
+                        loc = rr.read_bits(lb)
+                    except EOFError as e:
+                        raise ValueError(f"Truncated payload reading loc in row #{_i}, sub-block {j}") from e
+                    indices.append(j * B + loc)
src/tplr/comms.py (2)

2663-2682: Micro: guard edge cases and cut tiny allocations.

  • If vals.ndim == 1, expected_rows is 1; add an explicit k>0 check to fail early for empty last dim.
  • Use math.prod on Python ints instead of np.prod to avoid creating a NumPy array for small shapes.

Apply:

-        expected_rows = int(np.prod(vals.shape[:-1])) if vals.ndim > 1 else 1
+        expected_rows = int(math.prod([int(d) for d in vals.shape[:-1]])) if vals.ndim > 1 else 1
+        if vals.shape[-1] <= 0:
+            raise ValueError(f"[{param_name}] Values top‑k dimension must be > 0")

1606-1756: Metrics: upload_bytes remains unused.

Only download_bytes is updated. If you intend to track egress in this path, either remove upload_bytes or populate it where appropriate.

Also applies to: 1729-1756

tests/test_comms.py (2)

16-18: Tests now cover Rice/bitmap path; keep numpy import local to those tests.

You can import numpy only where needed (the Rice/bitmap tests) to keep module import time down, but it’s minor.


1609-1617: Comment mismatch (minor).

The “Test with list (not a tensor)” block actually uses a tensor; either change the wording or use a list to avoid confusion.

Apply:

-    # Test with list (not a tensor)
-    invalid_list = torch.tensor([0, 1, 2, 3])
+    # Test with wrong dtype tensor (int64)
+    invalid_list = torch.tensor([0, 1, 2, 3], dtype=torch.int64)
tests/unit/test_bits_codec.py (3)

71-124: Round‑trip tests look solid; add determinism to avoid flakiness.

Seed the RNG to make failures reproducible across devices/backends.

Apply:

-    K = make_even_k(K)
+    torch.manual_seed(0)
+    K = make_even_k(K)

157-191: CPU vs new encoder equality is brittle; compare sets or relax ordering.

Different B choices can yield different emission orders while still being correct. Prefer set equality or sort both rows before comparison.

Apply:

-    for i in range(N):
-        assert rows_ref[i] == rows_new[i], f"row {i} decode differs (CPU ref vs new)"
+    for i in range(N):
+        assert sorted(rows_ref[i]) == sorted(rows_new[i]), \
+            f"row {i}: decoded sets differ (CPU ref vs new)"

282-314: Good coverage of bitmap/local path bit; consider asserting row_len > 0.

Sanity‑check that the row actually carries payload in these cases.

Apply:

-    C2, N2, row_len, lb, k_param, use_bitmap = parse_first_row_header(payload)
+    C2, N2, row_len, lb, k_param, use_bitmap = parse_first_row_header(payload)
+    assert row_len > 0
tests/unit/test_compress.py (1)

177-203: Nice large‑C coverage. Minor nit: add seed.

Seed once to stabilise.

Apply:

-        x = torch.randn(128, 128)  # 16,384 elements
+        torch.manual_seed(0)
+        x = torch.randn(128, 128)  # 16,384 elements
src/tplr/compress/topk.py (3)

173-189: Even‑forcing top‑k is legacy; consider allowing odd k.

Rice/bitmap no longer requires even k (12‑bit path did). Enforcing even may reduce fidelity unnecessarily.

Apply:

-        topk = max(topk, 2)
-        # Keep even by default (matches broader system expectations).
-        topk = topk - (topk % 2)
+        topk = max(topk, 1)

If other callers still assume even, keep as‑is.


248-253: Avoid unnecessary CPU pinning for idx bytes.

Allow device‑agnostic storage; it still decodes via .cpu().

Apply:

-        idx_bytes = torch.tensor(
+        idx_bytes = torch.as_tensor(
             np.frombuffer(payload, dtype=np.uint8).copy(),
-            dtype=torch.uint8,
-            device="cpu",
+            dtype=torch.uint8,
         )

21-34: Optional: guard DTensor import for environments without DTensor.

Unconditional DT import can fail on some installs.

Apply:

-from torch.distributed.tensor import DTensor as DT
+try:
+    from torch.distributed.tensor import DTensor as DT  # type: ignore
+except Exception:  # pragma: no cover
+    DT = tuple()  # sentinel for isinstance checks

And update the check:

-        if isinstance(x, DT):  # check for dtensors
+        if DT and isinstance(x, DT):  # check for dtensors
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro


📥 Commits

Reviewing files that changed from the base of the PR and between 6bd6871 and c1adcac.

📒 Files selected for processing (16)
  • hparams/hparams.json (1 hunks)
  • neurons/evaluator.py (1 hunks)
  • neurons/miner.py (1 hunks)
  • neurons/trainer.py (0 hunks)
  • neurons/validator.py (2 hunks)
  • src/tplr/comms.py (3 hunks)
  • src/tplr/compress/__init__.py (1 hunks)
  • src/tplr/compress/bits.py (1 hunks)
  • src/tplr/compress/pack12.py (1 hunks)
  • src/tplr/compress/topk.py (10 hunks)
  • src/tplr/neurons.py (3 hunks)
  • tests/test_comms.py (11 hunks)
  • tests/test_prepare_gradient_dict.py (4 hunks)
  • tests/unit/test_bits_codec.py (1 hunks)
  • tests/unit/test_compress.py (8 hunks)
  • tests/unit/test_neurons.py (0 hunks)
💤 Files with no reviewable changes (2)
  • neurons/trainer.py
  • tests/unit/test_neurons.py
🧰 Additional context used
🧬 Code graph analysis (10)
src/tplr/compress/bits.py (1)
src/tplr/comms.py (2)
  • get (1248-1364)
  • gather (1466-1787)
tests/unit/test_bits_codec.py (1)
src/tplr/compress/bits.py (6)
  • BitReader (62-98)
  • decode_batch_rows (452-488)
  • encode_batch_rows (253-444)
  • encode_batch_rows_cpu (129-249)
  • read_bits (75-83)
  • read_bytes (94-98)
tests/test_comms.py (3)
src/tplr/compress/pack12.py (1)
  • pack_12bit_indices (21-47)
src/tplr/compress/bits.py (1)
  • encode_batch_rows (253-444)
src/tplr/comms.py (1)
  • check_compressed_indices (2615-2701)
src/tplr/compress/__init__.py (2)
src/tplr/compress/bits.py (3)
  • decode_batch_rows (452-488)
  • encode_batch_rows (253-444)
  • encode_batch_rows_cpu (129-249)
src/tplr/compress/pack12.py (2)
  • pack_12bit_indices (21-47)
  • unpack_12bit_indices (50-75)
src/tplr/comms.py (2)
src/tplr/compress/topk.py (3)
  • compress (194-198)
  • compress (200-204)
  • compress (207-257)
src/tplr/compress/bits.py (1)
  • decode_batch_rows (452-488)
src/tplr/neurons.py (3)
src/tplr/compress/topk.py (2)
  • encode (77-95)
  • decode (98-113)
tests/test_prepare_gradient_dict.py (6)
  • encode (33-34)
  • encode (167-168)
  • encode (260-262)
  • decode (36-37)
  • decode (170-171)
  • decode (264-266)
tests/test_comms.py (2)
  • encode (1026-1028)
  • decode (1030-1032)
tests/test_prepare_gradient_dict.py (2)
src/tplr/compress/topk.py (2)
  • encode (77-95)
  • decode (98-113)
tests/test_comms.py (2)
  • encode (1026-1028)
  • decode (1030-1032)
neurons/validator.py (3)
src/tplr/compress/topk.py (1)
  • decode (98-113)
tests/test_prepare_gradient_dict.py (3)
  • decode (36-37)
  • decode (170-171)
  • decode (264-266)
tests/test_comms.py (1)
  • decode (1030-1032)
tests/unit/test_compress.py (1)
src/tplr/compress/bits.py (1)
  • encode_batch_rows (253-444)
src/tplr/compress/topk.py (1)
src/tplr/compress/bits.py (2)
  • decode_batch_rows (452-488)
  • encode_batch_rows (253-444)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.11)
  • GitHub Check: test (3.12)
🔇 Additional comments (12)
neurons/validator.py (2)

265-266: Encode call updated to new signature — looks good.

ChunkingTransformer.encode(x) change is correctly applied.


2805-2806: Decode call updated to new signature — looks good.

self.transformer.decode(decompressed) matches the refactor.

neurons/miner.py (1)

261-262: Encode call updated to new signature — looks good.

Miner now calls ChunkingTransformer.encode(x) without use_dct. No other changes required here.

src/tplr/comms.py (1)

52-52: Good switch to decoder-driven validation path.

Importing decode_batch_rows here aligns comms validation with the new codec and keeps the surface consistent with tplr.compress.

tests/test_comms.py (1)

1557-1579: Nice coverage of Rice/bitmap encode/decode validation.

Good positive/negative tests around shapes, k, and C. This will catch most regressions in the new codec path.

Also applies to: 1582-1598, 1635-1650, 1652-1670, 1675-1693, 1705-1737, 1739-1767

src/tplr/compress/__init__.py (1)

18-35: Public surface looks clean and complete.

Re‑exporting encode/decode and keeping pack12 legacy helpers under the same namespace is a sensible transition strategy.

src/tplr/neurons.py (1)

163-183: DCT toggle removal looks correct.

encode/decode calls now use the unified transformer path; outer_step also switched to transformer.decode without the old flag.

Also applies to: 328-333

tests/test_prepare_gradient_dict.py (3)

32-38: Interface alignment with transformer encode/decode (no use_dct).

These changes look correct and consistent with the new API.


166-172: Local dummy passthrough matches new transformer signature.

Good test double; no issues.


470-477: Updated failing_decode signature is correct.

Matches the simplified transformer API.

tests/unit/test_compress.py (1)

35-54: OK: validates Rice/bitmap output type.

Asserting uint8 for indices is appropriate.

src/tplr/compress/topk.py (1)

230-247: Value reordering matches emitted index order — good.

This ensures decode alignment and avoids extra shuffles downstream.

Comment on lines +2626 to 2702
This ensures indices are within bounds and that the **new Rice/bitmap**
codec payload matches the provided values tensor shape (top‑k).

Args:
param_name (str): The name of the parameter being checked.
idxs (torch.Tensor): The tensor of indices.
totalk (int): The total number of elements in the original uncompressed tensor.
allowed_topk (int | None, optional): The expected number of top-k values.
Defaults to the hparams configuration.
vals (torch.Tensor | None, optional): The corresponding values tensor.

Raises:
ValueError: If any validation check fails, such as out-of-bounds
indices, incorrect data types, or malformed payload.
"""
allowed_topk = (
min(self.hparams.topk_compression, totalk)
if allowed_topk is None
else min(allowed_topk, totalk)
)

def _bounds_check(t: torch.Tensor):
"""fast min/max bounds check"""
if t.numel() == 0:
raise ValueError(f"[{param_name}] empty index list")
if t.min().item() < 0 or t.max().item() >= totalk:
bad = t[(t < 0) | (t >= totalk)][0].item()
raise ValueError(
f"[{param_name}] Index {bad} out of bounds (totalk = {totalk})"
)
if not isinstance(idxs, torch.Tensor):
raise ValueError(
f"[{param_name}] Expected tensor for indices, got {type(idxs)}"
)
if vals is None:
raise ValueError(
f"[{param_name}] Values tensor required for index validation"
)
if idxs.dtype != torch.uint8:
raise ValueError(
f"[{param_name}] Expected uint8 (Rice/bitmap payload), got {idxs.dtype}"
)
if idxs.numel() == 0:
raise ValueError(f"[{param_name}] Empty indices payload")

# Decode (CPU) and perform structural checks
try:
payload_bytes = idxs.detach().cpu().numpy().tobytes()
rows_list, C, N = decode_batch_rows(payload_bytes)
except Exception as e:
raise ValueError(f"[{param_name}] Failed to decode indices payload: {e}")
if C != totalk:
raise ValueError(
f"[{param_name}] Payload column size C={C} but expected {totalk}"
)

# compute expected rows from values shape (flatten all but last dim)
if vals.ndim == 0:
raise ValueError(f"[{param_name}] Values tensor has no top‑k dimension")
expected_rows = int(np.prod(vals.shape[:-1])) if vals.ndim > 1 else 1
if N != expected_rows:
raise ValueError(
f"[{param_name}] Payload rows N={N} but values imply {expected_rows}"
)

k = vals.shape[-1]
if k != allowed_topk:
raise ValueError(
f"[{param_name}] Values top‑k={k} but allowed_topk={allowed_topk}"
)
if any(len(r) != k for r in rows_list):
raise ValueError(
f"[{param_name}] At least one row has mismatched top‑k size"
)

# bounds check without materialising full tensor
max_idx = max((max(r) if len(r) > 0 else -1) for r in rows_list)
min_idx = (
min((min(r) if len(r) > 0 else 0) for r in rows_list) if rows_list else 0
)
if min_idx < 0 or max_idx >= totalk:
raise ValueError(
f"[{param_name}] Index out of bounds (min={min_idx}, max={max_idx}, totalk={totalk})"
)


@coderabbitai coderabbitai bot Aug 31, 2025


🛠️ Refactor suggestion

check_compressed_indices rejects legacy/int64 indices; add backward‑compat path.

Today this function hard‑fails unless idxs is uint8 (Rice/bitmap). Older miners or historical payloads may still send int64 index tensors. Decompression paths already handle both; validation should too to avoid skipping otherwise valid peers.

Apply:

@@
-        if idxs.dtype != torch.uint8:
-            raise ValueError(
-                f"[{param_name}] Expected uint8 (Rice/bitmap payload), got {idxs.dtype}"
-            )
+        if idxs.dtype == torch.uint8:
+            pass  # Rice/bitmap fast-path below
+        elif idxs.dtype in (torch.int64, torch.long):
+            # Legacy/unpacked path: shape and bounds checks only
+            if vals is None:
+                raise ValueError(f"[{param_name}] Values tensor required for index validation")
+            if tuple(idxs.shape) != tuple(vals.shape):
+                raise ValueError(f"[{param_name}] Indices shape {tuple(idxs.shape)} != values shape {tuple(vals.shape)}")
+            k = vals.shape[-1] if vals.ndim > 0 else 0
+            if k != allowed_topk:
+                raise ValueError(f"[{param_name}] Values top‑k={k} but allowed_topk={allowed_topk}")
+            if idxs.numel() > 0:
+                i_min = int(idxs.min().item())
+                i_max = int(idxs.max().item())
+                if i_min < 0 or i_max >= totalk:
+                    raise ValueError(f"[{param_name}] Index out of bounds (min={i_min}, max={i_max}, totalk={totalk})")
+            return  # legacy path validated
+        else:
+            raise ValueError(
+                f"[{param_name}] Expected uint8 (Rice/bitmap payload), got {idxs.dtype}"
+            )
@@
-            rows_list, C, N = decode_batch_rows(payload_bytes)
+            rows_list, C, N = decode_batch_rows(payload_bytes)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
This ensures indices are within bounds and that the **new Rice/bitmap**
codec payload matches the provided values tensor shape (topk).
Args:
param_name (str): The name of the parameter being checked.
idxs (torch.Tensor): The tensor of indices.
totalk (int): The total number of elements in the original uncompressed tensor.
allowed_topk (int | None, optional): The expected number of top-k values.
Defaults to the hparams configuration.
vals (torch.Tensor | None, optional): The corresponding values tensor,
required for validating 12-bit packed indices. Defaults to None.
vals (torch.Tensor | None, optional): The corresponding values tensor.
Raises:
ValueError: If any validation check fails, such as out-of-bounds
indices, incorrect data types, or malformed packed data.
indices, incorrect data types, or malformed payload.
"""
allowed_topk = (
min(self.hparams.topk_compression, totalk)
if allowed_topk is None
else min(allowed_topk, totalk)
)
def _bounds_check(t: torch.Tensor):
"""fast min/max bounds check"""
if t.numel() == 0:
raise ValueError(f"[{param_name}] empty index list")
if t.min().item() < 0 or t.max().item() >= totalk:
bad = t[(t < 0) | (t >= totalk)][0].item()
raise ValueError(
f"[{param_name}] Index {bad} out of bounds (totalk = {totalk})"
)
if not isinstance(idxs, torch.Tensor):
raise ValueError(
f"[{param_name}] Expected tensor for indices, got {type(idxs)}"
)
if vals is None:
raise ValueError(
f"[{param_name}] Values tensor required for index validation"
)
if idxs.dtype != torch.uint8:
raise ValueError(
f"[{param_name}] Expected uint8 (Rice/bitmap payload), got {idxs.dtype}"
)
if idxs.numel() == 0:
raise ValueError(f"[{param_name}] Empty indices payload")
# Handle 12-bit packed index format only
if isinstance(idxs, torch.Tensor):
if idxs.dtype != torch.uint8:
raise ValueError(
f"[{param_name}] Expected uint8 for 12-bit packed indices, got {idxs.dtype}"
)
# 12-bit packed format is the only supported format
if vals is None:
raise ValueError(
f"[{param_name}] Values tensor required to validate 12-bit packed indices"
)
if idxs.numel() == 0:
raise ValueError(f"[{param_name}] Empty packed indices tensor")
# Decode (CPU) and perform structural checks
try:
payload_bytes = idxs.detach().cpu().numpy().tobytes()
rows_list, C, N = decode_batch_rows(payload_bytes)
except Exception as e:
raise ValueError(f"[{param_name}] Failed to decode indices payload: {e}")
# Unpack using the values shape
try:
unpacked = unpack_12bit_indices(idxs, vals.shape)
# Validate that the last dimension matches allowed_topk
if unpacked.shape[-1] != allowed_topk:
raise ValueError(
f"[{param_name}] Invalid topk dimension: "
f"shape[-1]={unpacked.shape[-1]} but expected {allowed_topk}"
)
_bounds_check(unpacked)
except Exception as e:
raise ValueError(f"[{param_name}] Failed to unpack 12-bit indices: {e}")
else:
raise ValueError(f"[{param_name}] Expected tensor but got {type(idxs)}")
if C != totalk:
raise ValueError(
f"[{param_name}] Payload column size C={C} but expected {totalk}"
)
# compute expected rows from values shape (flatten all but last dim)
if vals.ndim == 0:
raise ValueError(f"[{param_name}] Values tensor has no top‑k dimension")
expected_rows = int(np.prod(vals.shape[:-1])) if vals.ndim > 1 else 1
if N != expected_rows:
raise ValueError(
f"[{param_name}] Payload rows N={N} but values imply {expected_rows}"
)
k = vals.shape[-1]
if k != allowed_topk:
raise ValueError(
f"[{param_name}] Values top‑k={k} but allowed_topk={allowed_topk}"
)
if any(len(r) != k for r in rows_list):
raise ValueError(
f"[{param_name}] At least one row has mismatched top‑k size"
)
# bounds check without materialising full tensor
max_idx = max((max(r) if len(r) > 0 else -1) for r in rows_list)
min_idx = (
min((min(r) if len(r) > 0 else 0) for r in rows_list) if rows_list else 0
)
if min_idx < 0 or max_idx >= totalk:
raise ValueError(
f"[{param_name}] Index out of bounds (min={min_idx}, max={max_idx}, totalk={totalk})"
)
if not isinstance(idxs, torch.Tensor):
raise ValueError(
f"[{param_name}] Expected tensor for indices, got {type(idxs)}"
)
if vals is None:
raise ValueError(
f"[{param_name}] Values tensor required for index validation"
)
if idxs.dtype == torch.uint8:
pass # Rice/bitmap fast-path below
elif idxs.dtype in (torch.int64, torch.long):
# Legacy/unpacked path: shape and bounds checks only
if vals is None:
raise ValueError(f"[{param_name}] Values tensor required for index validation")
if tuple(idxs.shape) != tuple(vals.shape):
raise ValueError(f"[{param_name}] Indices shape {tuple(idxs.shape)} != values shape {tuple(vals.shape)}")
k = vals.shape[-1] if vals.ndim > 0 else 0
if k != allowed_topk:
raise ValueError(f"[{param_name}] Values top-k={k} but allowed_topk={allowed_topk}")
if idxs.numel() > 0:
i_min = int(idxs.min().item())
i_max = int(idxs.max().item())
if i_min < 0 or i_max >= totalk:
raise ValueError(f"[{param_name}] Index out of bounds (min={i_min}, max={i_max}, totalk={totalk})")
return # legacy path validated
else:
raise ValueError(
f"[{param_name}] Expected uint8 (Rice/bitmap payload), got {idxs.dtype}"
)
if idxs.numel() == 0:
raise ValueError(f"[{param_name}] Empty indices payload")
# Decode (CPU) and perform structural checks
try:
payload_bytes = idxs.detach().cpu().numpy().tobytes()
rows_list, C, N = decode_batch_rows(payload_bytes)
except Exception as e:
raise ValueError(f"[{param_name}] Failed to decode indices payload: {e}")
🤖 Prompt for AI Agents
In src/tplr/comms.py around lines 2626 to 2702, the validation currently rejects
idxs unless dtype is torch.uint8; update it to accept legacy torch.int64 index
tensors as well by adding a backward-compat branch: if idxs.dtype is
torch.int64, skip the Rice/bitmap decode and instead interpret the tensor as raw
integer indices (reshape/flatten consistent with vals shape and top-k), build
rows_list from those indices, compute C, N, min/max bounds from that legacy
format, and run the same consistency checks (C vs totalk, N vs expected_rows, k
vs allowed_topk, per-row length, bounds). Keep the existing uint8 path unchanged
and only raise for other dtypes, ensuring all error messages remain
param-prefixed.

Collaborator Author


I don't want the backwards compatibility


@joellidin, understood. I'll respect your decision to not include backward compatibility for the legacy index formats. The current implementation correctly enforces that all peers must use the new Rice/bitmap encoding format (uint8 payload), which keeps the validation logic clean and ensures consistency across the network.


✏️ Learnings added
Learnt from: joellidin
PR: tplr-ai/templar#553
File: src/tplr/comms.py:2626-2702
Timestamp: 2025-08-31T15:02:03.568Z
Learning: The user joellidin has explicitly decided not to implement backward compatibility for legacy int64 index tensors in the check_compressed_indices function in src/tplr/comms.py. The function should only accept uint8 Rice/bitmap payload format and reject other formats to enforce network-wide migration to the new compression scheme.

Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.

Comment on lines +75 to +83
def read_bits(self, n: int) -> int:
if n <= 0:
return 0
self._fill(n)
mask = (1 << n) - 1
out = self.cur & mask
self.cur >>= n
self.nbits -= n
return out

🛠️ Refactor suggestion

Prevent bit underflow on truncated input.

read_bits can drive nbits negative if the stream is shorter than requested. Raise an explicit EOF to avoid corrupted state.

     def read_bits(self, n: int) -> int:
         if n <= 0:
             return 0
         self._fill(n)
+        if self.nbits < n:
+            raise EOFError(f"Not enough bits in stream to read {n} bits")
         mask = (1 << n) - 1
         out = self.cur & mask
         self.cur >>= n
         self.nbits -= n
         return out
📝 Committable suggestion


Suggested change
def read_bits(self, n: int) -> int:
if n <= 0:
return 0
self._fill(n)
mask = (1 << n) - 1
out = self.cur & mask
self.cur >>= n
self.nbits -= n
return out
def read_bits(self, n: int) -> int:
if n <= 0:
return 0
self._fill(n)
if self.nbits < n:
raise EOFError(f"Not enough bits in stream to read {n} bits")
mask = (1 << n) - 1
out = self.cur & mask
self.cur >>= n
self.nbits -= n
return out
🤖 Prompt for AI Agents
In src/tplr/compress/bits.py around lines 75 to 83, read_bits currently calls
self._fill(n) then unconditionally consumes n bits which can make self.nbits
negative on truncated input; change it so that after calling self._fill(n) you
check whether there are at least n bits available and if not raise EOFError (or
a suitable custom exception) without mutating self.cur or self.nbits, otherwise
proceed to compute mask/out and update self.cur and self.nbits; ensure the
function returns the bits only when the full n bits are available.
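For context, the suggested guard can be exercised end to end with a minimal LSB-first reader. This is an illustrative sketch, not the repository's `BitReader`; the `_fill`/`read_bits` names follow the suggestion, everything else is assumed:

```python
class MiniBitReader:
    """Minimal LSB-first bit reader used to illustrate the EOF guard."""

    def __init__(self, data: bytes):
        self.data = data
        self.pos = 0    # next byte to consume
        self.cur = 0    # bit buffer
        self.nbits = 0  # bits available in `cur`

    def _fill(self, n: int) -> None:
        # Pull whole bytes until n bits are buffered or input is exhausted.
        while self.nbits < n and self.pos < len(self.data):
            self.cur |= self.data[self.pos] << self.nbits
            self.pos += 1
            self.nbits += 8

    def read_bits(self, n: int) -> int:
        if n <= 0:
            return 0
        self._fill(n)
        if self.nbits < n:
            # Truncated stream: fail loudly instead of driving nbits negative.
            raise EOFError(f"Not enough bits in stream to read {n} bits")
        out = self.cur & ((1 << n) - 1)
        self.cur >>= n
        self.nbits -= n
        return out
```

Reading past the end of a one-byte stream now raises `EOFError` rather than leaving `nbits` negative and silently returning corrupted values.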

Comment on lines +26 to +33
max_idx = indices.max().item() if indices.numel() > 0 else 0
if max_idx >= 4096:
raise ValueError(f"Index {max_idx} exceeds 12-bit limit (4095)")

flat = indices.flatten()
n = flat.numel()
if n % 2 != 0:
raise ValueError(f"Number of indices must be even, got {n}")

🛠️ Refactor suggestion

Guard against negative indices.

Packing negative values produces incorrect bit patterns. Add a lower-bound check.

-    max_idx = indices.max().item() if indices.numel() > 0 else 0
+    max_idx = indices.max().item() if indices.numel() > 0 else 0
+    min_idx = indices.min().item() if indices.numel() > 0 else 0
+    if min_idx < 0:
+        raise ValueError(f"Index {min_idx} is negative; only non-negative indices are supported")
📝 Committable suggestion


Suggested change
max_idx = indices.max().item() if indices.numel() > 0 else 0
if max_idx >= 4096:
raise ValueError(f"Index {max_idx} exceeds 12-bit limit (4095)")
flat = indices.flatten()
n = flat.numel()
if n % 2 != 0:
raise ValueError(f"Number of indices must be even, got {n}")
max_idx = indices.max().item() if indices.numel() > 0 else 0
min_idx = indices.min().item() if indices.numel() > 0 else 0
if min_idx < 0:
raise ValueError(f"Index {min_idx} is negative; only non-negative indices are supported")
if max_idx >= 4096:
raise ValueError(f"Index {max_idx} exceeds 12-bit limit (4095)")
flat = indices.flatten()
n = flat.numel()
if n % 2 != 0:
raise ValueError(f"Number of indices must be even, got {n}")
🤖 Prompt for AI Agents
In src/tplr/compress/pack12.py around lines 26 to 33, add a lower-bound check
for negative indices before packing: compute the minimum index (only when
indices.numel() > 0), and if min_idx < 0 raise ValueError("Index {min_idx} is
negative; cannot pack negative indices") (or similar message), keeping the
existing upper-bound check for >=4096 and the even-count check; ensure you use
the same tensor-to-python conversion pattern (.min().item()) as used for max.

Comment on lines +57 to +75
device = packed.device
n_indices = 1
for d in values_shape:
n_indices *= int(d)
if n_indices == 0:
return torch.zeros(values_shape, dtype=torch.int64, device=device)
if n_indices % 2 != 0:
raise ValueError(f"Number of indices must be even, got {n_indices}")

out = torch.zeros(n_indices, dtype=torch.int64, device=device)
n_pairs = n_indices // 2
if n_pairs > 0:
b0 = packed[0::3].to(torch.int64)
b1 = packed[1::3].to(torch.int64)
b2 = packed[2::3].to(torch.int64)

out[0::2] = b0 | ((b1 & 0x0F) << 8)
out[1::2] = ((b1 >> 4) & 0x0F) | (b2 << 4)
return out.view(*values_shape)

🛠️ Refactor suggestion

Validate packed length and dtype before unpacking.

Mismatched or non-uint8 buffers can raise shape errors or silently decode wrong values.

-    device = packed.device
+    # Ensure correct dtype/device
+    if packed.dtype != torch.uint8:
+        packed = packed.to(torch.uint8)
+    device = packed.device
@@
-    if n_indices % 2 != 0:
+    if n_indices % 2 != 0:
         raise ValueError(f"Number of indices must be even, got {n_indices}")
+    expected_bytes = (n_indices // 2) * 3
+    if packed.numel() != expected_bytes:
+        raise ValueError(
+            f"Packed buffer length mismatch: expected {expected_bytes} bytes for "
+            f"{n_indices} indices, got {packed.numel()}"
+        )
📝 Committable suggestion


Suggested change
device = packed.device
n_indices = 1
for d in values_shape:
n_indices *= int(d)
if n_indices == 0:
return torch.zeros(values_shape, dtype=torch.int64, device=device)
if n_indices % 2 != 0:
raise ValueError(f"Number of indices must be even, got {n_indices}")
out = torch.zeros(n_indices, dtype=torch.int64, device=device)
n_pairs = n_indices // 2
if n_pairs > 0:
b0 = packed[0::3].to(torch.int64)
b1 = packed[1::3].to(torch.int64)
b2 = packed[2::3].to(torch.int64)
out[0::2] = b0 | ((b1 & 0x0F) << 8)
out[1::2] = ((b1 >> 4) & 0x0F) | (b2 << 4)
return out.view(*values_shape)
# Ensure correct dtype/device
if packed.dtype != torch.uint8:
packed = packed.to(torch.uint8)
device = packed.device
n_indices = 1
for d in values_shape:
n_indices *= int(d)
if n_indices == 0:
return torch.zeros(values_shape, dtype=torch.int64, device=device)
if n_indices % 2 != 0:
raise ValueError(f"Number of indices must be even, got {n_indices}")
expected_bytes = (n_indices // 2) * 3
if packed.numel() != expected_bytes:
raise ValueError(
f"Packed buffer length mismatch: expected {expected_bytes} bytes for "
f"{n_indices} indices, got {packed.numel()}"
)
out = torch.zeros(n_indices, dtype=torch.int64, device=device)
n_pairs = n_indices // 2
if n_pairs > 0:
b0 = packed[0::3].to(torch.int64)
b1 = packed[1::3].to(torch.int64)
b2 = packed[2::3].to(torch.int64)
out[0::2] = b0 | ((b1 & 0x0F) << 8)
out[1::2] = ((b1 >> 4) & 0x0F) | (b2 << 4)
return out.view(*values_shape)
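Both pack-side and unpack-side suggestions above validate the same layout: each pair of 12-bit indices occupies three bytes. A standalone sketch of that layout with the proposed guards, in plain Python without torch (the `pack12`/`unpack12` names are illustrative, not the repository's API):

```python
def pack12(indices):
    """Pack pairs of 12-bit ints into 3 bytes each."""
    if any(i < 0 for i in indices):
        raise ValueError("negative index cannot be packed")
    if any(i >= 4096 for i in indices):
        raise ValueError("index exceeds 12-bit limit (4095)")
    if len(indices) % 2 != 0:
        raise ValueError("number of indices must be even")
    out = bytearray()
    for a, b in zip(indices[0::2], indices[1::2]):
        out.append(a & 0xFF)                      # low 8 bits of a
        out.append((a >> 8) | ((b & 0x0F) << 4))  # high nibble of a + low nibble of b
        out.append(b >> 4)                        # high 8 bits of b
    return bytes(out)


def unpack12(packed, n_indices):
    """Inverse of pack12, validating buffer length first."""
    expected = (n_indices // 2) * 3
    if len(packed) != expected:
        raise ValueError(f"expected {expected} bytes, got {len(packed)}")
    out = []
    for i in range(0, len(packed), 3):
        b0, b1, b2 = packed[i], packed[i + 1], packed[i + 2]
        out.append(b0 | ((b1 & 0x0F) << 8))
        out.append(((b1 >> 4) & 0x0F) | (b2 << 4))
    return out
```

Round-tripping `[0, 1, 4095, 2048]` yields six bytes and recovers the original list, while a negative index or a truncated buffer fails fast instead of decoding garbage.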

Replace the 12-bit packed index format with the new Rice/bitmap
compression codec throughout the codebase. This migration affects
compression, communication, and neurons components.

- Update TopKCompressor to use encode_batch_rows/decode_batch_rows
- Migrate check_compressed_indices validation to Rice/bitmap format
- Reorder values to match codec permutation during compression
- Update all test cases from 12-bit to Rice/bitmap encoding
- Handle uint8 payload format and structural validation
- Ensure compatibility with B_choices=(64, 128) tensor dimensions
- Extend Rice/bitmap codec support to neurons module

Remove Discrete Cosine Transform functionality from the compression
system while retaining the ChunkingTransformer for tensor chunking.

- Remove DCT transformation methods from ChunkingTransformer
- Remove f_dict and b_dict (DCT basis matrices)
- Remove norm parameter and einsum helper methods
- Remove use_dct parameter from encode/decode methods throughout
- Remove DCT helper functions (_dct, _idct, etc.)
- Update all callers to remove use_dct parameter
- Remove use_dct from hparams configuration
- Update tests to remove DCT-related test cases

ChunkingTransformer now only handles tensor reshaping and chunking,
simplifying the implementation by removing unused DCT functionality.

With the improved Rice/bitmap compression codec, we can now transmit 4x
more gradient values (128 vs 32) with acceptable overhead. This should
improve gradient quality and convergence while the new codec keeps
communication costs manageable.
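The bandwidth claim can be sanity-checked with a toy Golomb-Rice cost model over the gaps of a sorted index row. This is a sketch of the general technique only; the real `encode_batch_rows` codec, its Rice parameter, and the column size `C = 4096` used here are assumptions:

```python
import random


def rice_bits(value: int, k: int) -> int:
    # Golomb-Rice cost: unary quotient (value >> k, plus stop bit) + k remainder bits.
    return (value >> k) + 1 + k


def rice_cost_for_row(indices, k: int) -> int:
    # Code a sorted row as its first index followed by successive gaps minus one.
    s = sorted(indices)
    gaps = [s[0]] + [b - a - 1 for a, b in zip(s, s[1:])]
    return sum(rice_bits(g, k) for g in gaps)


random.seed(0)
C, topk = 4096, 128                  # illustrative column size and top-k
row = random.sample(range(C), topk)

fixed_12bit = topk * 12              # fixed-width 12-bit packing
rice = rice_cost_for_row(row, k=4)   # Rice parameter near log2(average gap)
print(fixed_12bit, rice)             # Rice over gaps is substantially smaller
```

With an average gap around C/topk = 32, each gap costs roughly 6 to 7 bits here versus a fixed 12, which is the kind of headroom that makes a top-k of 128 affordable.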
@joellidin joellidin force-pushed the feat/efficient-compress branch from c1adcac to db48e62 on August 31, 2025 16:12

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (5)
tests/unit/test_compress.py (2)

102-141: Each encoded payload must cover all rows (N == xshape[0]).

Both idx1_orig and idx2_orig encode only 2 rows while xshape is 8. This will break batch_decompress scatter.

Apply:

-        # Create multiple Rice/bitmap encoded indices
-        idx1_orig = torch.tensor([[0, 1], [2, 3]], dtype=torch.int64)
-        idx2_orig = torch.tensor([[4, 5], [6, 7]], dtype=torch.int64)
+        # Create per-row indices for all rows so N matches xshape[0]
+        rows = xshape[0]
+        base = torch.arange(rows, dtype=torch.int64) * 4
+        idx1_orig = torch.stack([(base + 0) % totalk, (base + 1) % totalk], dim=1)
+        idx2_orig = torch.stack([(base + 2) % totalk, (base + 3) % totalk], dim=1)
@@
-        val1 = torch.tensor([[0.1, 0.2], [0.3, 0.4]], dtype=torch.float32)
-        val2 = torch.tensor([[0.5, 0.6], [0.7, 0.8]], dtype=torch.float32)
+        val1 = torch.randn(rows, 2, dtype=torch.float32)
+        val2 = torch.randn(rows, 2, dtype=torch.float32)

208-241: Row-count mismatch again in norm/clip tests; idx encodes 1 row while xshape has 8.

Align rows in the encoded payload and values with xshape[0].

Apply:

-        # Create test data with Rice/bitmap encoded format
-        idx_orig = torch.tensor([[0, 1, 2, 3]], dtype=torch.int64)  # Even count
-        payload, perm, _ = encode_batch_rows(idx_orig, C=totalk)
+        # Create test data with Rice/bitmap encoded format for all rows
+        K = 4
+        rows = xshape[0]
+        idx_orig = torch.stack(
+            [torch.arange(i * K, i * K + K, dtype=torch.int64) % totalk for i in range(rows)]
+        )
+        payload, perm, _ = encode_batch_rows(idx_orig, C=totalk)
@@
-        val_orig = torch.tensor([[10.0, 20.0, 30.0, 40.0]], dtype=torch.float32)
-        # Reorder values to match permutation
-        val = [torch.gather(val_orig, dim=1, index=perm)]
+        val_orig = torch.randn(rows, K, dtype=torch.float32)
+        # Reorder values to match permutation
+        val = [torch.gather(val_orig, dim=1, index=perm)]
tests/test_comms.py (1)

52-57: Use Rice/bitmap payload in the “valid state dict” helper.

This aligns fixtures with the enforced format and avoids silently producing incompatible index payloads.

-    for name, _ in model.named_parameters():
-        # Create legacy 12-bit packed format (for backwards compatibility test)
-        indices = torch.tensor([0, 1], dtype=torch.long)
-        packed_data = pack_12bit_indices(indices)
-        state_dict[name + "idxs"] = packed_data
-        state_dict[name + "vals"] = torch.tensor([0.1, 0.2], dtype=torch.float32)
+    for name, param in model.named_parameters():
+        # Produce Rice/bitmap payload matching totalk
+        totalk = param.numel()
+        idx2d = torch.tensor([[0, 1]], dtype=torch.long)
+        payload, _perm, _ = encode_batch_rows(idx2d, C=totalk)
+        packed = torch.tensor(np.frombuffer(payload, dtype=np.uint8), dtype=torch.uint8)
+        state_dict[name + "idxs"] = packed
+        state_dict[name + "vals"] = torch.tensor([0.1, 0.2], dtype=torch.float32)
neurons/validator.py (1)

2686-2694: Remove legacy 12‑bit tuple handling; enforce Rice/bitmap only.

This path contradicts the migration objective and will silently accept legacy payloads. Per the project decision, only uint8 Rice/bitmap payloads must be accepted. Proactively reject tuples here and let check_compressed_indices validate the Rice/bitmap format.

[pulled_from_retrieved_learnings]
Apply this diff:

-                # Handle 12-bit packed format: (packed_tensor, original_shape)
-                if isinstance(idxs, tuple) and len(idxs) == 2:
-                    packed_data, original_shape = idxs
-                    # Move packed data to device
-                    packed_data = packed_data.to(self.device)
-                    idxs = (packed_data, original_shape)
+                # Reject legacy 12-bit formats explicitly
+                if isinstance(idxs, tuple):
+                    raise ValueError(
+                        f"Invalid gradient data from peer {eval_uid}: legacy 12-bit indices are no longer supported"
+                    )
src/tplr/neurons.py (1)

324-336: NameError at runtime: deleting undefined variables vals/qps.

After refactor, vals and qps are not defined in outer_step. The del will crash on master.

Apply this diff:

-            # Free intermediate pieces ASAP
-            del vals_f32, idxs, vals, qps, ref, decompressed
+            # Free intermediate pieces ASAP
+            del vals_f32, idxs, ref, decompressed
♻️ Duplicate comments (1)
tests/unit/test_compress.py (1)

74-101: Row-count mismatch: payload encodes 2 rows while xshape is 8 → decompress will fail.

encode_batch_rows(original_indices, C=totalk) here emits rows=2 but p/xshape expect rows=8. Make payload rows match xshape[0].

Apply:

-        # Create proper Rice/bitmap encoded format using the encoder
-        # Create indices that are within valid range for a 8x64 tensor (even count)
-        original_indices = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=torch.int64)
-
-        # Pack using the new encoder format
-        payload, perm, _ = encode_batch_rows(original_indices, C=totalk)
-        idx = torch.tensor(np.frombuffer(payload, dtype=np.uint8), dtype=torch.uint8)
-
-        val = torch.tensor(
-            [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]], dtype=torch.float32
-        )
-        # Reorder values to match permutation
-        val = torch.gather(val, dim=1, index=perm)
+        # Build per-row K indices matching xshape[0] rows
+        K = 4
+        rows = xshape[0]
+        original_indices = torch.stack(
+            [torch.arange(i * K, i * K + K, dtype=torch.int64) % totalk for i in range(rows)]
+        )
+        # Encode and permute corresponding values
+        payload, perm, _ = encode_batch_rows(original_indices, C=totalk)
+        idx = torch.tensor(np.frombuffer(payload, dtype=np.uint8), dtype=torch.uint8)
+        val = torch.randn(rows, K, dtype=torch.float32)
+        val = torch.gather(val, dim=1, index=perm)
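The failure mode can be illustrated with a plain-Python sketch of row-wise scatter (`scatter_rows` is a hypothetical stand-in for the codec's decompress path, not the repo's implementation): the payload's row count must match the target shape's leading dimension.

```python
def scatter_rows(xshape, row_indices, row_values):
    """Scatter per-row top-k (index, value) pairs into a dense rows x cols buffer."""
    rows, cols = xshape
    if len(row_indices) != rows:
        raise ValueError(f"payload has {len(row_indices)} rows, target expects {rows}")
    out = [[0.0] * cols for _ in range(rows)]
    for r, (idxs, vals) in enumerate(zip(row_indices, row_values)):
        for i, v in zip(idxs, vals):
            out[r][i] = v
    return out

# Two encoded rows into a 2x8 target: fine
dense = scatter_rows((2, 8), [[0, 3], [5, 7]], [[1.0, 2.0], [3.0, 4.0]])
assert dense[0][3] == 2.0 and dense[1][5] == 3.0

# Two encoded rows against an 8-row target: the mismatch this comment flags
try:
    scatter_rows((8, 8), [[0, 3], [5, 7]], [[1.0, 2.0], [3.0, 4.0]])
except ValueError:
    pass
```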
🧹 Nitpick comments (9)
tests/unit/test_compress.py (3)

54-73: Quantized path checks mostly fine; consider asserting qparams structure.

If qparams is a dict, assert expected keys to guard regressions.

Apply this minimal guard:

-        assert qparams is not None
-        assert len(qparams) == 5  # shift, scale, offset, lookup, orig_dtype
+        assert qparams is not None
+        if isinstance(qparams, dict):
+            for k in ("shift", "scale", "offset", "lookup", "orig_dtype"):
+                assert k in qparams

83-85: Remove stale “even count” comments (not required for Rice/bitmap).

Rice/bitmap has no even-count constraint; these comments are legacy from 12‑bit tests.

-        # Create indices that are within valid range for a 8x64 tensor (even count)
+        # Create indices within valid range for an 8x64 tensor

and similarly in the norm/clip test.

Also applies to: 212-214


54-73: Optional: factor a small helper to generate per-row indices/values to reduce duplication.

Keeps tests short and consistent.

Example helper to add near the top of the file:

def build_rowwise_indices_values(rows: int, K: int, C: int):
    idx = torch.stack(
        [torch.arange(i * K, i * K + K, dtype=torch.int64) % C for i in range(rows)]
    )
    val = torch.randn(rows, K, dtype=torch.float32)
    payload, perm, _ = encode_batch_rows(idx, C=C)
    idx_bytes = torch.tensor(np.frombuffer(payload, dtype=np.uint8), dtype=torch.uint8)
    val = torch.gather(val, dim=1, index=perm)
    return idx_bytes, val

Also applies to: 102-141, 208-241

src/tplr/comms.py (1)

1656-1657: Update stale comments: replace “12‑bit” wording with Rice/bitmap.

Minor clarity fix to match the new format used across the codebase.

-                            # Get corresponding vals tensor for 12-bit unpacking
+                            # Get corresponding vals tensor to validate decoder-aligned top‑k

@@
-                            # Handle 12-bit packed format (uint8 tensor)
+                            # Account bytes size from Rice/bitmap (uint8 payload)

Also applies to: 1737-1741

tests/test_comms.py (4)

16-18: Prefer new codec in tests; keep legacy helpers clearly scoped.

Importing encode_batch_rows and numpy is fine. If legacy pack_12bit_indices remains, limit its use to explicitly marked legacy tests to avoid confusion with the enforced Rice/bitmap-only path.


69-76: Clarify helper’s legacy intent.

Docstring says “legacy 12‑bit” already; consider renaming to create_legacy_packed_indices to make call sites self‑documenting. Not blocking.


1572-1575: Silence unused variable warnings for perm.

Ruff/flake8 may flag ‘perm’ as unused; bind to underscore.

-    payload, perm, _ = encode_batch_rows(valid_indices, C=totalk)
+    payload, _perm, _ = encode_batch_rows(valid_indices, C=totalk)

(Apply similarly in the other tests listed in the line ranges.)

Also applies to: 1592-1595, 1644-1646, 1664-1666, 1716-1719, 1729-1731, 1749-1751, 1762-1764


1756-1757: Correct comment: Rice/bitmap doesn’t require an even number of indices.

Update the note to avoid implying 12‑bit constraints.

-    # Note: Can't test with 1 element as encoder requires even number of indices
+    # Note: Rice/bitmap supports any k; we use k=6 here to deliberately trigger a top‑k mismatch.
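For context on why there is no even-count constraint: Rice (Golomb-Rice) coding emits each non-negative integer as a unary quotient plus a fixed k-bit remainder, so any number of indices per row round-trips. The sketch below is illustrative only — `rice_encode`/`rice_decode` are hypothetical helpers, not the repo's `bits.py` implementation:

```python
def rice_encode(values, k):
    """Encode non-negative ints as unary(quotient) + fixed k-bit remainder (MSB-first)."""
    bits = []
    for v in values:
        q, r = v >> k, v & ((1 << k) - 1)
        bits.extend([1] * q)  # unary quotient
        bits.append(0)        # terminator
        bits.extend((r >> i) & 1 for i in reversed(range(k)))
    return bits

def rice_decode(bits, k, n):
    """Decode n values from a Rice-coded bit list."""
    out, pos = [], 0
    for _ in range(n):
        q = 0
        while bits[pos] == 1:  # read unary quotient
            q += 1
            pos += 1
        pos += 1               # skip terminator
        r = 0
        for _ in range(k):
            r = (r << 1) | bits[pos]
            pos += 1
        out.append((q << k) | r)
    return out

# A sorted top-k index row is typically delta-coded first so values stay small
row = [3, 17, 42, 99]
gaps = [row[0]] + [b - a for a, b in zip(row, row[1:])]
bits = rice_encode(gaps, k=3)
decoded_gaps = rice_decode(bits, k=3, n=len(gaps))
restored, acc = [], 0
for g in decoded_gaps:
    acc += g
    restored.append(acc)
assert restored == row
```

Any row length works here, which is why the parity restriction from the 12-bit packer does not carry over.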
src/tplr/neurons.py (1)

962-976: Optional sanity: validate codec’s C against expected param size (when available).

decode_batch_rows returns C but it’s unused. If neuron has totalks[pname], assert equality to catch mismatched payloads early.

Example:

-            rows_i, _C_codec, N_rows = decode_batch_rows(
+            rows_i, _C_codec, N_rows = decode_batch_rows(
                 payload
             )  # rows_i: list[list[int]]
+            # Optional guard: ensure codec header matches expected C
+            # if hasattr(neuron, "totalks") and pname in neuron.totalks:
+            #     expected_C = int(neuron.totalks[pname])
+            #     if _C_codec != expected_C:
+            #         raise ValueError(f"Codec C mismatch for {pname}: {_C_codec} != {expected_C}")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between c1adcac and db48e62.

📒 Files selected for processing (12)
  • hparams/hparams.json (1 hunks)
  • neurons/evaluator.py (1 hunks)
  • neurons/miner.py (1 hunks)
  • neurons/trainer.py (0 hunks)
  • neurons/validator.py (2 hunks)
  • src/tplr/comms.py (3 hunks)
  • src/tplr/compress/topk.py (10 hunks)
  • src/tplr/neurons.py (5 hunks)
  • tests/test_comms.py (11 hunks)
  • tests/test_prepare_gradient_dict.py (4 hunks)
  • tests/unit/test_compress.py (8 hunks)
  • tests/unit/test_neurons.py (0 hunks)
💤 Files with no reviewable changes (2)
  • tests/unit/test_neurons.py
  • neurons/trainer.py
🚧 Files skipped from review as they are similar to previous changes (5)
  • neurons/miner.py
  • neurons/evaluator.py
  • tests/test_prepare_gradient_dict.py
  • src/tplr/compress/topk.py
  • hparams/hparams.json
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: joellidin
PR: tplr-ai/templar#553
File: src/tplr/comms.py:2626-2702
Timestamp: 2025-08-31T15:02:03.568Z
Learning: The user joellidin has explicitly decided not to implement backward compatibility for legacy int64 index tensors in the check_compressed_indices function in src/tplr/comms.py. The function should only accept uint8 Rice/bitmap payload format and reject other formats to enforce network-wide migration to the new compression scheme.

Applied to files:

  • src/tplr/comms.py
  • tests/test_comms.py
🧬 Code graph analysis (5)
src/tplr/neurons.py (2)
src/tplr/compress/topk.py (5)
  • compress (194-198)
  • compress (200-204)
  • compress (207-257)
  • encode (77-95)
  • decode (98-113)
src/tplr/compress/bits.py (1)
  • decode_batch_rows (452-488)
src/tplr/comms.py (2)
src/tplr/compress/topk.py (3)
  • compress (194-198)
  • compress (200-204)
  • compress (207-257)
src/tplr/compress/bits.py (1)
  • decode_batch_rows (452-488)
tests/unit/test_compress.py (3)
src/tplr/compress/topk.py (5)
  • compress (194-198)
  • compress (200-204)
  • compress (207-257)
  • _get_smaller_split (609-628)
  • encode (77-95)
src/tplr/compress/bits.py (1)
  • encode_batch_rows (253-444)
src/tplr/compress/pack12.py (2)
  • pack_12bit_indices (21-47)
  • unpack_12bit_indices (50-75)
tests/test_comms.py (3)
src/tplr/compress/pack12.py (1)
  • pack_12bit_indices (21-47)
src/tplr/compress/bits.py (1)
  • encode_batch_rows (253-444)
src/tplr/comms.py (1)
  • check_compressed_indices (2615-2701)
neurons/validator.py (3)
src/tplr/compress/topk.py (1)
  • decode (98-113)
tests/test_prepare_gradient_dict.py (3)
  • decode (36-37)
  • decode (170-171)
  • decode (264-266)
tests/test_comms.py (1)
  • decode (1030-1032)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.12)
  • GitHub Check: test (3.11)
🔇 Additional comments (11)
tests/unit/test_compress.py (6)

3-3: Import of numpy is appropriate.

Used for frombuffer on codec payload; no issues.


11-17: Public API imports look correct.

Shifts tests to Rice/bitmap path and the new helper; aligns with the refactor and DCT removal.


35-53: Good assertion of Rice/bitmap payload type.

Asserting uint8 for indices matches the new codec and the migration plan to drop legacy int64 formats.


146-176: Round-trip test looks solid.

Confirms new path end-to-end and asserts dtype/shape.


177-203: Good coverage for large index ranges.

Asserting totalk equals the last dim for 2D tensors and verifying decompress result shape is sufficient here.


282-289: Transformer encode/decode tests align with DCT removal.

API usage matches the updated signatures.

src/tplr/comms.py (1)

52-52: Decoder-driven validation import looks good.

Switching to decode_batch_rows via the compress package aligns validation with the new codec.

src/tplr/neurons.py (4)

163-169: Switched encode path to new transformer API — LGTM.

Owner-only encode aligns with the refactor and keeps EF on-device.


182-187: Switched decode path to new transformer API — LGTM.

Post-decompress decode matches updated interface.


980-985: Weighting metric clarified — LGTM.

Using param_weight = C_chunks * k sensibly size‑weights by rows×k after the codec change.


35-36: Import decode_batch_rows correctly or expose it in tplr.compress
tplr.compress/__init__.py doesn’t export decode_batch_rows, so from tplr.compress import decode_batch_rows (src/tplr/neurons.py:35) will fail. Either change the import to:

from tplr.compress.bits import decode_batch_rows

or re-export it in tplr/compress/__init__.py by adding:

from .bits import decode_batch_rows
__all__.append("decode_batch_rows")

Likely an incorrect or invalid review comment.

Comment on lines 264 to 266
enc = self.transformer.encode(
    torch.empty(p.shape, dtype=torch.float16, device=self.device)
)

💡 Verification agent

🧩 Analysis chain

API update to drop use_dct — LGTM.

ChunkingTransformer.encode/decode now called without use_dct; matches the new signatures.

Run to confirm no stale use_dct remains:

Also applies to: 2805-2810


🏁 Script executed:

#!/bin/bash
rg -nP --hidden --glob '!*site-packages*' '\buse_dct\b|outer_step\s*\([^)]*use_dct' -C2

Length of output: 390


Remove stale use_dct reference in test config

  • In scripts/local_miner_test.py (line 179), remove or update the "use_dct": False entry so tests align with the updated ChunkingTransformer API.
🤖 Prompt for AI Agents
In neurons/validator.py around lines 264 to 266 (and the related test at
scripts/local_miner_test.py line 179), there's a stale test config entry
"use_dct": False that no longer matches the updated ChunkingTransformer API;
remove the "use_dct" key from the test config (or replace it with the current
API-equivalent option if the new transformer expects an explicit setting), then
run the tests and adjust any assertions that depended on the previous behavior
so they reflect the transformer’s current chunking behavior.

Comment on lines +2626 to 2702
        This ensures indices are within bounds and that the **new Rice/bitmap**
        codec payload matches the provided values tensor shape (top‑k).

        Args:
            param_name (str): The name of the parameter being checked.
            idxs (torch.Tensor): The tensor of indices.
            totalk (int): The total number of elements in the original uncompressed tensor.
            allowed_topk (int | None, optional): The expected number of top-k values.
                Defaults to the hparams configuration.
            vals (torch.Tensor | None, optional): The corresponding values tensor.

        Raises:
            ValueError: If any validation check fails, such as out-of-bounds
                indices, incorrect data types, or malformed payload.
        """
        allowed_topk = (
            min(self.hparams.topk_compression, totalk)
            if allowed_topk is None
            else min(allowed_topk, totalk)
        )

        def _bounds_check(t: torch.Tensor):
            """fast min/max bounds check"""
            if t.numel() == 0:
                raise ValueError(f"[{param_name}] empty index list")
            if t.min().item() < 0 or t.max().item() >= totalk:
                bad = t[(t < 0) | (t >= totalk)][0].item()
                raise ValueError(
                    f"[{param_name}] Index {bad} out of bounds (totalk = {totalk})"
                )

        if not isinstance(idxs, torch.Tensor):
            raise ValueError(
                f"[{param_name}] Expected tensor for indices, got {type(idxs)}"
            )
        if vals is None:
            raise ValueError(
                f"[{param_name}] Values tensor required for index validation"
            )
        if idxs.dtype != torch.uint8:
            raise ValueError(
                f"[{param_name}] Expected uint8 (Rice/bitmap payload), got {idxs.dtype}"
            )
        if idxs.numel() == 0:
            raise ValueError(f"[{param_name}] Empty indices payload")

        # Decode (CPU) and perform structural checks
        try:
            payload_bytes = idxs.detach().cpu().numpy().tobytes()
            rows_list, C, N = decode_batch_rows(payload_bytes)
        except Exception as e:
            raise ValueError(f"[{param_name}] Failed to decode indices payload: {e}")

        if C != totalk:
            raise ValueError(
                f"[{param_name}] Payload column size C={C} but expected {totalk}"
            )

        # compute expected rows from values shape (flatten all but last dim)
        if vals.ndim == 0:
            raise ValueError(f"[{param_name}] Values tensor has no top‑k dimension")
        expected_rows = int(np.prod(vals.shape[:-1])) if vals.ndim > 1 else 1
        if N != expected_rows:
            raise ValueError(
                f"[{param_name}] Payload rows N={N} but values imply {expected_rows}"
            )

        k = vals.shape[-1]
        if k != allowed_topk:
            raise ValueError(
                f"[{param_name}] Values top‑k={k} but allowed_topk={allowed_topk}"
            )
        if any(len(r) != k for r in rows_list):
            raise ValueError(
                f"[{param_name}] At least one row has mismatched top‑k size"
            )

        # bounds check without materialising full tensor
        max_idx = max((max(r) if len(r) > 0 else -1) for r in rows_list)
        min_idx = (
            min((min(r) if len(r) > 0 else 0) for r in rows_list) if rows_list else 0
        )
        if min_idx < 0 or max_idx >= totalk:
            raise ValueError(
                f"[{param_name}] Index out of bounds (min={min_idx}, max={max_idx}, totalk={totalk})"
            )


🛠️ Refactor suggestion

Tighten validation: int-cast topk, guard vals type, drop numpy by using math.prod.

Prevents surprising float topk from misconfigured hparams, avoids attribute errors if vals isn’t a tensor, and removes an unnecessary dependency.

Apply:

- import numpy as np
+ # (numpy no longer needed here)

@@
-        allowed_topk = (
-            min(self.hparams.topk_compression, totalk)
-            if allowed_topk is None
-            else min(allowed_topk, totalk)
-        )
+        allowed_topk = (
+            min(self.hparams.topk_compression, totalk)
+            if allowed_topk is None
+            else min(allowed_topk, totalk)
+        )
+        try:
+            allowed_topk = int(allowed_topk)
+        except Exception:
+            raise ValueError(f"[{param_name}] allowed_topk must be an integer-compatible value")

@@
-        if vals is None:
+        if vals is None:
             raise ValueError(
                 f"[{param_name}] Values tensor required for index validation"
             )
+        if not torch.is_tensor(vals):
+            raise ValueError(f"[{param_name}] Values must be a torch.Tensor, got {type(vals)}")

@@
-        expected_rows = int(np.prod(vals.shape[:-1])) if vals.ndim > 1 else 1
+        expected_rows = int(math.prod(vals.shape[:-1])) if vals.ndim > 1 else 1

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/tplr/comms.py around lines 2626 to 2702, tighten validation by (1)
enforcing vals is a torch.Tensor before using .ndim/.shape (raise a ValueError
if not), (2) casting allowed_topk/k to int to avoid float surprises, and (3)
replace numpy.prod usage with math.prod over vals.shape[:-1] (import math) to
remove the numpy dependency; keep all existing structural and bounds checks but
perform the type guard early and use int(...) when comparing top-k values.
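The rows-from-shape rule in the suggestion can be sketched in isolation (plain tuples stand in for `vals.shape`; `expected_rows` is a hypothetical helper name):

```python
import math

def expected_rows(shape: tuple) -> int:
    """Rows implied by a values shape: flatten all dims except the last (top-k) one."""
    if len(shape) == 0:
        raise ValueError("Values tensor has no top-k dimension")
    # math.prod replaces np.prod here, removing the numpy dependency
    return math.prod(shape[:-1]) if len(shape) > 1 else 1

assert expected_rows((8, 4)) == 8       # 2-D: 8 rows of k=4
assert expected_rows((2, 3, 4)) == 6    # leading dims multiply
assert expected_rows((4,)) == 1         # 1-D values imply a single row
```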

Comment on lines +959 to 961
idx_data = idxs_all[i] if isinstance(idxs_all, (list, tuple)) else idxs_all
payload = _as_bytes(idx_data)


🛠️ Refactor suggestion

Avoid duplicating a single idx payload across all peers.

If idxs_all isn’t a list/tuple, treating it as every peer’s indices can produce false positives in overlap. Prefer skipping that param or enforcing list/tuple.

Apply this diff:

-            idx_data = idxs_all[i] if isinstance(idxs_all, (list, tuple)) else idxs_all
+            if not isinstance(idxs_all, (list, tuple)):
+                # Not a per-peer payload; skip this parameter safely
+                decoded_per_peer = []
+                break
+            idx_data = idxs_all[i]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

            if not isinstance(idxs_all, (list, tuple)):
                # Not a per-peer payload; skip this parameter safely
                decoded_per_peer = []
                break
            idx_data = idxs_all[i]
            payload = _as_bytes(idx_data)
🤖 Prompt for AI Agents
In src/tplr/neurons.py around lines 959-961, the code treats a non-list/tuple
idxs_all as if it applies to every peer, causing duplicate idx payloads and
false overlap; change the logic so idxs_all must be a list/tuple to be used
per-peer—if not isinstance(idxs_all, (list, tuple)) then skip per-peer idx
payload (e.g., set idx_data = None and avoid creating payload) or explicitly
raise a TypeError/ValueError upstream; ensure the per-peer branch only uses
idxs_all[i] when idxs_all is a list/tuple to prevent duplicating the same idx
across all peers.

Comment on lines +1605 to 1617
    totalk = 128

    # Test with regular tensor (not packed) - should fail because it's not uint8
    invalid_tensor = torch.tensor([0, 1, 2, 3], dtype=torch.long)
    vals = torch.randn(4, dtype=torch.float32)
    # This should fail since only uint8 Rice/bitmap encoded format is supported
    with pytest.raises(ValueError, match="Expected uint8.*Rice/bitmap payload"):
        dummy_comms.check_compressed_indices("param", invalid_tensor, totalk, vals=vals)

    # Test with list (not a tensor)
    invalid_list = torch.tensor([0, 1, 2, 3])
    with pytest.raises(ValueError, match="Expected uint8.*Rice/bitmap payload"):
        dummy_comms.check_compressed_indices("param", invalid_list, totalk, vals=vals)

🛠️ Refactor suggestion

Fix test case intent: non‑tensor input vs wrong dtype.

The second assertion claims to test a non‑tensor, but passes a tensor and matches the wrong error. Either make it a true non‑tensor case or rename. Suggest splitting:

-    # Test with list (not a tensor)
-    invalid_list = torch.tensor([0, 1, 2, 3])
-    with pytest.raises(ValueError, match="Expected uint8.*Rice/bitmap payload"):
-        dummy_comms.check_compressed_indices("param", invalid_list, totalk, vals=vals)
+    # Test with list (not a tensor) → type error branch
+    invalid_list = [0, 1, 2, 3]
+    with pytest.raises(ValueError, match="Expected tensor for indices"):
+        dummy_comms.check_compressed_indices("param", invalid_list, totalk, vals=vals)
📝 Committable suggestion


Suggested change

    totalk = 128

    # Test with regular tensor (not packed) - should fail because it's not uint8
    invalid_tensor = torch.tensor([0, 1, 2, 3], dtype=torch.long)
    vals = torch.randn(4, dtype=torch.float32)
    # This should fail since only uint8 Rice/bitmap encoded format is supported
    with pytest.raises(ValueError, match="Expected uint8.*Rice/bitmap payload"):
        dummy_comms.check_compressed_indices("param", invalid_tensor, totalk, vals=vals)

    # Test with list (not a tensor) → type error branch
    invalid_list = [0, 1, 2, 3]
    with pytest.raises(ValueError, match="Expected tensor for indices"):
        dummy_comms.check_compressed_indices("param", invalid_list, totalk, vals=vals)
🤖 Prompt for AI Agents
In tests/test_comms.py around lines 1605 to 1617, the second assertion claims to
test a non-tensor but currently passes a torch.tensor; change invalid_list to be
an actual Python list (e.g. [0,1,2,3]) and update the with pytest.raises to
expect a TypeError (or the specific exception your function raises for
non-tensor inputs) instead of the uint8/dtype ValueError so the test truly
verifies the non-tensor input path.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
tests/unit/test_compress.py (3)

74-96: Align payload row count with target shape to avoid scatter errors.

Payload encodes 2 rows but xshape is (8, 64). Make indices/vals for 8 rows so N matches.

Apply:

-        p = torch.zeros(8, 64)  # 512 elements total, last dim=64
-        xshape = (8, 64)
-        totalk = 64
+        p = torch.zeros(8, 64)  # 512 elements total, last dim=64
+        xshape = (8, 64)
+        totalk = 64
@@
-        original_indices = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=torch.int64)
+        K = 4
+        original_indices = torch.stack(
+            [torch.arange(i * K, i * K + K, dtype=torch.int64) % totalk for i in range(8)]
+        )
@@
-        val = torch.tensor(
-            [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]], dtype=torch.float32
-        )
+        val = torch.randn(8, K, dtype=torch.float32)

105-136: Fix batch_decompress test shapes (rows must match xshape[0]).

Each packed payload encodes 2 rows; set xshape (and p) to 2×64.

Apply:

-        p = torch.zeros(8, 64)  # 512 elements total, last dim=64
-        xshape = (8, 64)
+        p = torch.zeros(2, 64)
+        xshape = (2, 64)
         totalk = 64

203-216: Fix norm/clip test shapes to match single-row payload.

Current payload encodes 1 row; xshape is 8×64 → mismatch.

Apply:

-        p = torch.zeros(8, 64)  # 512 elements total, last dim=64
-        xshape = (8, 64)
+        p = torch.zeros(1, 64)
+        xshape = (1, 64)
         totalk = 64
@@
-        block_norms = torch.tensor([15.0])
+        block_norms = torch.tensor([15.0])
src/tplr/compress/topk.py (1)

513-539: Fix tuple/list detection in maybe_dequantize_values (len(float) crash).

The current check iterates elements of a 5-tuple (shift, scale, …) and calls len() on floats/ints → TypeError.

Apply:

-        if (
-            isinstance(qparams, tuple)
-            and len(qparams) == 5  # potentially single or already 5 elements
-            and not all([len(q) == 5 for q in qparams])  # already correctly formatted
-        ):
-            qparams = [qparams]
-        if not isinstance(qparams, list):
-            qparams = [qparams]
+        if isinstance(qparams, tuple) and len(qparams) == 5:
+            qparams = [qparams]
+        if not isinstance(qparams, list):
+            qparams = [qparams]
♻️ Duplicate comments (3)
src/tplr/compress/bits.py (1)

75-83: Guard against truncated input in BitReader.read_bits.

read_bits mutates state even when there aren’t enough bits, causing silent corruption and possible negative nbits. Raise EOFError before consuming.

Apply:

 def read_bits(self, n: int) -> int:
     if n <= 0:
         return 0
     self._fill(n)
+    if self.nbits < n:
+        raise EOFError(f"Not enough bits in stream to read {n} bits")
     mask = (1 << n) - 1
     out = self.cur & mask
     self.cur >>= n
     self.nbits -= n
     return out
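To see why the guard must fire before any state changes, here is a minimal, self-contained bit reader in the same spirit (illustrative only; `TinyBitReader` is a hypothetical class, and the repo's BitReader internals may differ):

```python
class TinyBitReader:
    """LSB-first bit reader over bytes; raises EOFError instead of silently
    corrupting cur/nbits when the stream is truncated."""

    def __init__(self, data: bytes):
        self.data = data
        self.pos = 0    # next byte index
        self.cur = 0    # bit buffer
        self.nbits = 0  # bits available in cur

    def _fill(self, n: int) -> None:
        while self.nbits < n and self.pos < len(self.data):
            self.cur |= self.data[self.pos] << self.nbits
            self.pos += 1
            self.nbits += 8

    def read_bits(self, n: int) -> int:
        if n <= 0:
            return 0
        self._fill(n)
        if self.nbits < n:  # guard BEFORE mutating state
            raise EOFError(f"Not enough bits in stream to read {n} bits")
        out = self.cur & ((1 << n) - 1)
        self.cur >>= n
        self.nbits -= n
        return out

r = TinyBitReader(b"\xab")    # one byte = 8 bits, LSB-first
assert r.read_bits(4) == 0xB  # low nibble first
assert r.read_bits(4) == 0xA
try:
    r.read_bits(1)            # stream exhausted: fails loudly, state untouched
except EOFError:
    pass
```

Without the guard, the subtraction `self.nbits -= n` would drive `nbits` negative and every later read would return garbage, which is exactly the silent-corruption path the comment describes.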
tests/test_comms.py (1)

1614-1617: Fix non-tensor test to actually pass a list and assert the correct error.

Currently passes a tensor and asserts the dtype error; intent was non-tensor type.

Apply:

-    # Test with list (not a tensor)
-    invalid_list = torch.tensor([0, 1, 2, 3])
-    with pytest.raises(ValueError, match="Expected uint8.*Rice/bitmap payload"):
-        dummy_comms.check_compressed_indices("param", invalid_list, totalk, vals=vals)
+    # Test with list (not a tensor) → type branch
+    invalid_list = [0, 1, 2, 3]
+    with pytest.raises(ValueError, match="Expected tensor for indices"):
+        dummy_comms.check_compressed_indices("param", invalid_list, totalk, vals=vals)
src/tplr/compress/topk.py (1)

398-414: Mirror N/k checks in batch path.

Catch mismatched rows early with a clear error; current code only checks k.

Apply:

             if i_data.dtype == torch.uint8:
                 rows, C, _N = decode_batch_rows(i_data.detach().cpu().numpy().tobytes())
                 if C != totalk:
                     raise ValueError(f"Index payload C={C} but expected {totalk}")
-                if any(len(r) != v_data.shape[-1] for r in rows):
+                k = int(v_data.shape[-1])
+                if any(len(r) != k for r in rows):
                     raise ValueError(
                         "Row-wise topk size mismatch in index payload (batch)"
                     )
+                expected_rows = int(v_data.reshape(-1, k).shape[0])
+                if len(rows) != expected_rows:
+                    raise ValueError(
+                        f"Row count mismatch in batch: payload N={len(rows)} vs vals rows={expected_rows}"
+                    )
                 idx_unpacked = torch.tensor(
                     rows, dtype=torch.int64, device=p.device
                 ).view(*v_data.shape)
🧹 Nitpick comments (1)
tests/test_comms.py (1)

1756-1757: Correct comment: even-count constraint doesn’t apply to Rice/bitmap.

Parity restriction was for legacy 12-bit packing, not the new codec.

Apply:

-    # Note: Can't test with 1 element as encoder requires even number of indices
+    # Note: Rice/bitmap codec accepts any k; even-count was a 12-bit legacy constraint
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between db48e62 and 774088d.

📒 Files selected for processing (5)
  • src/tplr/compress/bits.py (1 hunks)
  • src/tplr/compress/topk.py (10 hunks)
  • tests/test_comms.py (11 hunks)
  • tests/unit/test_bits_codec.py (1 hunks)
  • tests/unit/test_compress.py (9 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/unit/test_bits_codec.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: joellidin
PR: tplr-ai/templar#553
File: src/tplr/comms.py:2626-2702
Timestamp: 2025-08-31T15:02:03.568Z
Learning: The user joellidin has explicitly decided not to implement backward compatibility for legacy int64 index tensors in the check_compressed_indices function in src/tplr/comms.py. The function should only accept uint8 Rice/bitmap payload format and reject other formats to enforce network-wide migration to the new compression scheme.

Applied to files:

  • tests/test_comms.py
  • tests/unit/test_compress.py
🧬 Code graph analysis (3)
tests/test_comms.py (3)
src/tplr/compress/pack12.py (1)
  • pack_12bit_indices (21-47)
src/tplr/compress/bits.py (1)
  • encode_batch_rows (253-411)
src/tplr/comms.py (1)
  • check_compressed_indices (2615-2701)
src/tplr/compress/topk.py (2)
src/tplr/compress/bits.py (2)
  • decode_batch_rows (419-455)
  • encode_batch_rows (253-411)
src/tplr/comms.py (1)
  • gather (1466-1787)
tests/unit/test_compress.py (3)
src/tplr/compress/topk.py (5)
  • compress (194-198)
  • compress (200-204)
  • compress (207-252)
  • _get_smaller_split (604-623)
  • encode (77-95)
src/tplr/compress/bits.py (1)
  • encode_batch_rows (253-411)
src/tplr/compress/pack12.py (2)
  • pack_12bit_indices (21-47)
  • unpack_12bit_indices (50-75)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.12)
  • GitHub Check: test (3.11)

Comment on lines +286 to 303
# Decode indices
if idx.dtype == torch.uint8:
# 12-bit packed format - unpack it
idx_int64 = unpack_12bit_indices(idx, val.shape)
payload_bytes = idx.detach().cpu().numpy().tobytes()
rows_list, C, _N = decode_batch_rows(payload_bytes)
if C != totalk:
raise ValueError(f"Index payload C={C} but expected {totalk}")
k = val.shape[-1]
if any(len(r) != k for r in rows_list):
raise ValueError("Row-wise topk size mismatch in index payload")
idx_int64 = torch.tensor(
rows_list, dtype=torch.int64, device=p.device
).view(*val.shape)
elif idx.dtype in (torch.int64, torch.long):
# Already unpacked (from batch_decompress)
idx_int64 = idx
idx_int64 = idx.to(p.device)
else:
raise ValueError(
f"Expected uint8 (packed) or int64 (unpacked) indices, got {idx.dtype}"
)
# Ensure val has the same dtype as x for scatter operation
raise ValueError(f"Unsupported index tensor dtype: {idx.dtype}")

if val.dtype != x.dtype:

🛠️ Refactor suggestion

Validate row count and remove no-op reshape.

decompress should fail early when payload N ≠ vals rows. Also the .reshape(xshape) return value is ignored.

Apply:

         if idx.dtype == torch.uint8:
             payload_bytes = idx.detach().cpu().numpy().tobytes()
             rows_list, C, _N = decode_batch_rows(payload_bytes)
             if C != totalk:
                 raise ValueError(f"Index payload C={C} but expected {totalk}")
             k = val.shape[-1]
             if any(len(r) != k for r in rows_list):
                 raise ValueError("Row-wise topk size mismatch in index payload")
+            expected_rows = int(val.reshape(-1, k).shape[0])
+            if len(rows_list) != expected_rows:
+                raise ValueError(
+                    f"Row count mismatch: payload N={len(rows_list)} vs vals rows={expected_rows}"
+                )
             idx_int64 = torch.tensor(
                 rows_list, dtype=torch.int64, device=p.device
             ).view(*val.shape)
@@
-        x.scatter_reduce_(
-            dim=-1, index=idx_int64, src=val, reduce="mean", include_self=False
-        ).reshape(xshape)
+        x.scatter_reduce_(dim=-1, index=idx_int64, src=val, reduce="mean", include_self=False)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Decode indices
if idx.dtype == torch.uint8:
# 12-bit packed format - unpack it
idx_int64 = unpack_12bit_indices(idx, val.shape)
payload_bytes = idx.detach().cpu().numpy().tobytes()
rows_list, C, _N = decode_batch_rows(payload_bytes)
if C != totalk:
raise ValueError(f"Index payload C={C} but expected {totalk}")
k = val.shape[-1]
if any(len(r) != k for r in rows_list):
raise ValueError("Row-wise topk size mismatch in index payload")
idx_int64 = torch.tensor(
rows_list, dtype=torch.int64, device=p.device
).view(*val.shape)
elif idx.dtype in (torch.int64, torch.long):
# Already unpacked (from batch_decompress)
idx_int64 = idx
idx_int64 = idx.to(p.device)
else:
raise ValueError(
f"Expected uint8 (packed) or int64 (unpacked) indices, got {idx.dtype}"
)
# Ensure val has the same dtype as x for scatter operation
raise ValueError(f"Unsupported index tensor dtype: {idx.dtype}")
if val.dtype != x.dtype:
# Decode indices
if idx.dtype == torch.uint8:
payload_bytes = idx.detach().cpu().numpy().tobytes()
rows_list, C, _N = decode_batch_rows(payload_bytes)
if C != totalk:
raise ValueError(f"Index payload C={C} but expected {totalk}")
k = val.shape[-1]
if any(len(r) != k for r in rows_list):
raise ValueError("Row-wise topk size mismatch in index payload")
# NEW: ensure the total number of rows matches
expected_rows = int(val.reshape(-1, k).shape[0])
if len(rows_list) != expected_rows:
raise ValueError(
f"Row count mismatch: payload N={len(rows_list)} vs vals rows={expected_rows}"
)
idx_int64 = torch.tensor(
rows_list, dtype=torch.int64, device=p.device
).view(*val.shape)
elif idx.dtype in (torch.int64, torch.long):
idx_int64 = idx.to(p.device)
else:
raise ValueError(f"Unsupported index tensor dtype: {idx.dtype}")
if val.dtype != x.dtype:
...
# Drop the no-op .reshape(xshape) since scatter_reduce_ is in-place
x.scatter_reduce_(
dim=-1,
index=idx_int64,
src=val,
reduce="mean",
include_self=False
)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
src/tplr/compress/topk.py (2)

173-189: Fix k clamping when last-dim < 2 (avoids topk(k>dim) runtime error).

Current logic can set k=2 even if dim=1. Clamp to [1..dim] first, then enforce evenness only when k>=2.

 def _clamp_topk(self, x, topk) -> int:
-        topk = min(topk, x.shape[-1])
-        topk = max(topk, 2)
-        # Keep even by default (matches broader system expectations).
-        topk = topk - (topk % 2)
-        return int(topk)
+        n = int(x.shape[-1])
+        k = int(topk)
+        k = max(1, min(k, n))  # stay within [1..n]
+        if k >= 2:
+            k -= (k % 2)  # even if possible
+        return k
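
The clamping rule above can be sketched as a standalone helper (hypothetical name `clamp_topk`; this mirrors the suggested logic, not the project's actual method):

```python
def clamp_topk(topk: int, last_dim: int) -> int:
    """Clamp k to [1, last_dim], then round down to even when k >= 2."""
    k = max(1, min(int(topk), int(last_dim)))
    if k >= 2:
        k -= k % 2  # keep even when possible
    return k
```

With this ordering, a last dimension of 1 yields k=1 instead of an invalid k=2, and larger requests are first bounded by the dimension and only then made even.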

533-542: Fix tuple/list detection; current code can raise TypeError on floats/dtypes.

The len(q) check iterates elements of a single QuantParamsT (includes float/int/dtype), causing TypeError. Just wrap a single 5‑tuple.

-        if (
-            isinstance(qparams, tuple)
-            and len(qparams) == 5  # potentially single or already 5 elements
-            and not all([len(q) == 5 for q in qparams])  # already correctly formatted
-        ):
-            qparams = [qparams]
+        if isinstance(qparams, tuple) and len(qparams) == 5:
+            qparams = [qparams]
♻️ Duplicate comments (2)
src/tplr/compress/topk.py (2)

306-321: Add N (row‑count) validation and drop no‑op reshape; mirror prior guidance.

         if idx.dtype == torch.uint8:
             payload_bytes = idx.detach().cpu().numpy().tobytes()
             rows_list, C, _N = decode_batch_rows(payload_bytes)
             if C != totalk:
                 raise ValueError(f"Index payload C={C} but expected {totalk}")
             k = val.shape[-1]
             if any(len(r) != k for r in rows_list):
                 raise ValueError("Row-wise topk size mismatch in index payload")
+            expected_rows = int(val.reshape(-1, k).shape[0])
+            if len(rows_list) != expected_rows:
+                raise ValueError(
+                    f"Row count mismatch: payload N={len(rows_list)} vs vals rows={expected_rows}"
+                )
             idx_int64 = torch.tensor(
                 rows_list, dtype=torch.int64, device=p.device
             ).view(*val.shape)
@@
-        x.scatter_reduce_(
-            dim=-1, index=idx_int64, src=val, reduce="mean", include_self=False
-        ).reshape(xshape)
+        x.scatter_reduce_(dim=-1, index=idx_int64, src=val, reduce="mean", include_self=False)

Also applies to: 326-329


418-435: Batch path: also validate N (row‑count) to fail early on shape mismatches.

             if i_data.dtype == torch.uint8:
                 rows, C, _N = decode_batch_rows(i_data.detach().cpu().numpy().tobytes())
                 if C != totalk:
                     raise ValueError(f"Index payload C={C} but expected {totalk}")
-                if any(len(r) != v_data.shape[-1] for r in rows):
+                k = int(v_data.shape[-1])
+                if any(len(r) != k for r in rows):
                     raise ValueError(
                         "Row-wise topk size mismatch in index payload (batch)"
                     )
+                if len(rows) != int(v_data.reshape(-1, k).shape[0]):
+                    raise ValueError(
+                        f"Row count mismatch in batch: payload N={len(rows)} "
+                        f"vs vals rows={int(v_data.reshape(-1, k).shape[0])}"
+                    )
                 idx_unpacked = torch.tensor(
-                    rows, dtype=torch.int64, device=p.device
-                ).view(*v_data.shape)
+                    rows, dtype=torch.int64, device=p.device
+                ).view(*v_data.shape)
🧹 Nitpick comments (2)
neurons/validator.py (2)

262-297: Avoid compressing uninitialized tensors; compute xshape/totalk cheaply.

torch.empty may contain NaN/Inf, tripping quantization fallbacks and wasting time. You only need xshape and totalk; skip full compress and use encode().shape.

-        import time
-        total_compress_time = 0.0
-        total_encode_time = 0.0
-        
-        # Enable debug timing in compressor
-        self.compressor._debug_timing = True
-        
+        import time
+        total_compress_time = 0.0
+        total_encode_time = 0.0
+        # Enable debug timing in compressor (optional; keep False to speed init)
+        self.compressor._debug_timing = False
@@
-            encode_start = time.time()
-            enc = self.transformer.encode(
-                torch.empty(p.shape, dtype=torch.float16, device=self.device)
-            )
-            encode_time = time.time() - encode_start
-            
-            compress_start = time.time()
-            _, _, xshape, totalk, _ = self.compressor.compress(
-                enc,
-                self.hparams.topk_compression,
-            )
-            compress_time = time.time() - compress_start
-            
-            self.xshapes[n] = xshape
-            self.totalks[n] = totalk
-            
-            total_encode_time += encode_time
-            total_compress_time += compress_time
-            
-            # Log timing for each layer
-            tplr.logger.info(f"[COMPRESS TIMING] {n}: encode={encode_time:.3f}s, compress={compress_time:.3f}s, shape={p.shape}")
+            encode_start = time.time()
+            enc = self.transformer.encode(torch.zeros(p.shape, dtype=torch.float16, device=self.device))
+            encode_time = time.time() - encode_start
+
+            # Derive shapes without running top‑k/codec
+            self.xshapes[n] = tuple(p.shape)
+            self.totalks[n] = int(enc.shape[-1])
+
+            total_encode_time += encode_time
+            tplr.logger.info(f"[COMPRESS TIMING] {n}: encode={encode_time:.3f}s, shape={p.shape}")
@@
-        tplr.logger.info(f"[COMPRESS TIMING TOTAL] encode={total_encode_time:.3f}s, compress={total_compress_time:.3f}s")
-        
-        # Disable debug timing after initialization
-        self.compressor._debug_timing = False
+        tplr.logger.info(f"[COMPRESS TIMING TOTAL] encode={total_encode_time:.3f}s")

2710-2717: Drop stale 12‑bit handling in validator; enforce new Rice/bitmap only.

Per your migration decision, keep only uint8 payloads and let Comms reject legacy formats. This removes dead paths and confusion.

📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between 774088d and 5c8e8d9.

📒 Files selected for processing (3)
  • neurons/validator.py (2 hunks)
  • src/tplr/compress/bits.py (1 hunks)
  • src/tplr/compress/topk.py (10 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/tplr/compress/bits.py
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: joellidin
PR: tplr-ai/templar#553
File: src/tplr/comms.py:2626-2702
Timestamp: 2025-08-31T15:02:03.568Z
Learning: The user joellidin has explicitly decided not to implement backward compatibility for legacy int64 index tensors in the check_compressed_indices function in src/tplr/comms.py. The function should only accept uint8 Rice/bitmap payload format and reject other formats to enforce network-wide migration to the new compression scheme.
🧬 Code graph analysis (2)
neurons/validator.py (3)
tests/test_comms.py (3)
  • encode (1026-1028)
  • compress (1036-1038)
  • decode (1030-1032)
src/tplr/compress/topk.py (5)
  • encode (77-95)
  • compress (194-198)
  • compress (200-204)
  • compress (207-272)
  • decode (98-113)
tests/test_prepare_gradient_dict.py (10)
  • encode (33-34)
  • encode (167-168)
  • encode (260-262)
  • compress (17-26)
  • compress (151-161)
  • compress (238-249)
  • info (44-45)
  • decode (36-37)
  • decode (170-171)
  • decode (264-266)
src/tplr/compress/topk.py (2)
src/tplr/compress/bits.py (2)
  • decode_batch_rows (439-475)
  • encode_batch_rows (253-431)
src/tplr/comms.py (1)
  • gather (1466-1787)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.11)
  • GitHub Check: test (3.12)

Comment on lines +246 to 253
# Flatten to [rows, k] for the codec
encode_start = time.time()
idx2d = idx_int64.reshape(-1, topk).contiguous()
# GPU‑accelerated encode → bytes
payload, _meta = encode_batch_rows(
idx2d, C=totalk, B_choices=_DEFAULT_B_CHOICES
)


⚠️ Potential issue

Make B choices dynamic per C to prevent “No valid B choices for C”.

Using a fixed (64,128) fails for small or non‑divisible C (e.g., 1D biases, odd chunk sizes). Build powers‑of‑two divisors of C at call‑site.

-        idx2d = idx_int64.reshape(-1, topk).contiguous()
-        # GPU‑accelerated encode → bytes
-        payload, _meta = encode_batch_rows(
-            idx2d, C=totalk, B_choices=_DEFAULT_B_CHOICES
-        )
+        idx2d = idx_int64.reshape(-1, topk).contiguous()
+        # Choose B dynamically: all powers of two that divide C (largest→smallest; include 1 as last‑resort)
+        _cand = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
+        b_choices = tuple(b for b in reversed(_cand) if b <= totalk and (totalk % b) == 0) or (1,)
+        # GPU‑accelerated encode → bytes
+        payload, _meta = encode_batch_rows(idx2d, C=totalk, B_choices=b_choices)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Flatten to [rows, k] for the codec
encode_start = time.time()
idx2d = idx_int64.reshape(-1, topk).contiguous()
# GPU‑accelerated encode → bytes
payload, _meta = encode_batch_rows(
idx2d, C=totalk, B_choices=_DEFAULT_B_CHOICES
)
# Flatten to [rows, k] for the codec
encode_start = time.time()
idx2d = idx_int64.reshape(-1, topk).contiguous()
# Choose B dynamically: all powers of two that divide C (largest→smallest; include 1 as last-resort)
_cand = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
b_choices = tuple(b for b in reversed(_cand) if b <= totalk and (totalk % b) == 0) or (1,)
# GPU-accelerated encode → bytes
payload, _meta = encode_batch_rows(
idx2d, C=totalk, B_choices=b_choices
)
🤖 Prompt for AI Agents
In src/tplr/compress/topk.py around lines 246 to 253, the call to
encode_batch_rows uses a fixed _DEFAULT_B_CHOICES which can be invalid for small
or non‑divisible C and causes "No valid B choices for C"; replace this by
computing B_choices at the call site as the powers‑of‑two divisors of totalk (C)
— generate all 2^n that divide C (or are <= C and C % B == 0), ensure the list
is sorted/unique, and if empty fallback to [1]; then pass that B_choices to
encode_batch_rows instead of _DEFAULT_B_CHOICES so encoding works for small or
odd C values.
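
The dynamic selection described above amounts to listing the power-of-two divisors of C. A minimal sketch (hypothetical helper; the candidate ceiling of 512 is taken from the suggested diff):

```python
def pow2_divisors(C: int, max_b: int = 512) -> tuple[int, ...]:
    """Powers of two that divide C, largest first; (1,) as a last resort."""
    cand = []
    b = 1
    while b <= min(C, max_b):
        if C % b == 0:
            cand.append(b)
        b <<= 1
    return tuple(reversed(cand)) or (1,)
```

For C=100 this yields (4, 2, 1), so encoding never hits the "No valid B choices for C" failure that a fixed (64, 128) would trigger.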

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

♻️ Duplicate comments (3)
src/tplr/compress/bits.py (3)

234-242: Row emission order fixed — thank you.

Writing rows in original order (per-row encode then emit) resolves the earlier B-bucket ordering bug.


75-83: Fix BitReader underflow (read past end corrupts state).

read_bits can drive nbits negative on truncated payloads. Raise EOFError and avoid mutating state when fewer than n bits are available.

Apply:

 def read_bits(self, n: int) -> int:
     if n <= 0:
         return 0
     self._fill(n)
+    if self.nbits < n:
+        raise EOFError(f"Not enough bits in stream to read {n} bits")
     mask = (1 << n) - 1
     out = self.cur & mask
     self.cur >>= n
     self.nbits -= n
     return out
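
A self-contained sketch of an LSB-first bit reader with the EOF guard applied (the field names `cur`/`nbits`/`idx` are assumed from the snippet above):

```python
class BitReader:
    def __init__(self, data: bytes) -> None:
        self.data = data
        self.idx = 0    # next unread byte
        self.cur = 0    # bit buffer, LSB-first
        self.nbits = 0  # bits currently buffered

    def _fill(self, n: int) -> None:
        while self.nbits < n and self.idx < len(self.data):
            self.cur |= self.data[self.idx] << self.nbits
            self.idx += 1
            self.nbits += 8

    def read_bits(self, n: int) -> int:
        if n <= 0:
            return 0
        self._fill(n)
        if self.nbits < n:  # truncated stream: fail instead of going negative
            raise EOFError(f"Not enough bits in stream to read {n} bits")
        out = self.cur & ((1 << n) - 1)
        self.cur >>= n
        self.nbits -= n
        return out
```

Reading b"\xab" four bits at a time yields 0xB then 0xA (low nibble first), and a further read raises EOFError rather than silently corrupting state.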

366-402: Harden decoder against malformed/untrusted payloads.

Validate header fields, row boundaries, and per-row params; ensure full row chunks are consumed. Prevent OOB reads and silent corruption.

Apply:

 def decode_batch_rows(payload: bytes) -> tuple[list[list[int]], int, int]:
@@
-    C = br.read_bits(12) + 1
-    N = br.read_bits(16)
+    C = br.read_bits(12) + 1
+    if not (1 <= C <= 4096):
+        raise ValueError(f"Invalid C={C}; expected 1..4096")
+    N = br.read_bits(16)
+    if N < 0:
+        raise ValueError(f"Invalid N={N}")
     _ = br.read_bits(1)  # reserved
@@
-    for _i in range(N):
-        row_len = br.read_bits(16)
+    for _i in range(N):
+        row_len = br.read_bits(16)
+        # Ensure enough bits remain for row_len bytes
+        remaining_bits = 8 * (len(payload) - br.idx) + br.nbits
+        if row_len * 8 > remaining_bits:
+            raise ValueError("Row length exceeds remaining payload")
         row_bytes = br.read_bytes(row_len)
         rr = BitReader(row_bytes)
         lb = rr.read_bits(5)
         k_param = rr.read_bits(4)
         use_bitmap = rr.read_bits(1)
         B = 1 << lb
-        n_sub = C // B
+        # Validate B
+        if B <= 0 or (B & (B - 1)) != 0 or (C % B) != 0:
+            raise ValueError(f"Invalid row block size B={B} for C={C}")
+        n_sub = C // B
@@
-            s_len = _rice_read(rr, k_param)
+            s_len = _rice_read(rr, k_param)
             if s_len == 0:
                 continue
             if use_bitmap:
                 bitmask = rr.read_bits(B)
                 for loc in range(B):
                     if (bitmask >> loc) & 1:
                         indices.append(j * B + loc)
             else:
-                for _ in range(s_len):
-                    loc = rr.read_bits(lb)
-                    indices.append(j * B + loc)
+                for _ in range(s_len):
+                    loc = rr.read_bits(lb)
+                    if loc >= B:
+                        raise ValueError("Location out of range")
+                    indices.append(j * B + loc)
         rows.append(indices)
+        # Ensure no trailing data per row
+        if rr.idx != len(row_bytes):
+            raise ValueError("Row payload has trailing or malformed data")
     return rows, C, N
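
For context, the Rice coding referenced by `_rice_read` splits a value into a unary quotient and a k-bit remainder. A generic roundtrip sketch over bit lists (this shows the general scheme, not necessarily the project's exact bit conventions):

```python
def rice_encode(n: int, k: int) -> list[int]:
    """Encode n as q ones plus a terminating zero, then k remainder bits (LSB-first)."""
    q, r = n >> k, n & ((1 << k) - 1)
    bits = [1] * q + [0]
    bits += [(r >> i) & 1 for i in range(k)]
    return bits

def rice_decode(bits: list[int], pos: int, k: int) -> tuple[int, int]:
    """Decode one value starting at pos; return (value, new_pos)."""
    q = 0
    while bits[pos]:
        q += 1
        pos += 1
    pos += 1  # skip the terminating zero
    r = 0
    for i in range(k):
        r |= bits[pos] << i
        pos += 1
    return (q << k) | r, pos
```

Small k favors small values (short unary part); the per-row k_param in the payload lets the encoder pick a k that fits each row's index gaps.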
📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between 5c8e8d9 and 375b0e3.

📒 Files selected for processing (1)
  • src/tplr/compress/bits.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: joellidin
PR: tplr-ai/templar#553
File: src/tplr/comms.py:2626-2702
Timestamp: 2025-08-31T15:02:03.568Z
Learning: The user joellidin has explicitly decided not to implement backward compatibility for legacy int64 index tensors in the check_compressed_indices function in src/tplr/comms.py. The function should only accept uint8 Rice/bitmap payload format and reject other formats to enforce network-wide migration to the new compression scheme.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.12)
  • GitHub Check: test (3.11)

Comment on lines +129 to +149
def encode_batch_rows_cpu(
rows_np: np.ndarray,
*,
C: int,
B_choices: tuple[int, ...] = (64, 128),
scheme: str = "per_row",
workers: int | None = None,
) -> tuple[bytes, dict]:
if scheme != "per_row":
raise ValueError("Only scheme='per_row' is implemented")
valid_B: list[int] = [
b for b in B_choices if b > 0 and (b & (b - 1)) == 0 and (C % b) == 0
]
if not valid_B:
b = 1
valid_B = []
while b <= C:
if C % b == 0 and (b & (b - 1)) == 0:
valid_B.append(b)
b <<= 1


🛠️ Refactor suggestion

Add input validation to prevent silent index drop/corruption.

Out-of-range indices (e.g., <0 or ≥C) are currently binned out of n_sub and silently discarded. Validate upfront and fail fast.

Apply:

 def encode_batch_rows_cpu(
@@
 ) -> tuple[bytes, dict]:
-    if scheme != "per_row":
+    if scheme != "per_row":
         raise ValueError("Only scheme='per_row' is implemented")
+    if not np.issubdtype(rows_np.dtype, np.integer):
+        raise TypeError("rows_np must be an integer array")
+    if rows_np.size:
+        mn = int(rows_np.min())
+        mx = int(rows_np.max())
+        if mn < 0 or mx >= C:
+            raise ValueError(f"Index out of range: allowed [0, {C - 1}], got min={mn}, max={mx}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def encode_batch_rows_cpu(
rows_np: np.ndarray,
*,
C: int,
B_choices: tuple[int, ...] = (64, 128),
scheme: str = "per_row",
workers: int | None = None,
) -> tuple[bytes, dict]:
if scheme != "per_row":
raise ValueError("Only scheme='per_row' is implemented")
valid_B: list[int] = [
b for b in B_choices if b > 0 and (b & (b - 1)) == 0 and (C % b) == 0
]
if not valid_B:
b = 1
valid_B = []
while b <= C:
if C % b == 0 and (b & (b - 1)) == 0:
valid_B.append(b)
b <<= 1
def encode_batch_rows_cpu(
rows_np: np.ndarray,
*,
C: int,
B_choices: tuple[int, ...] = (64, 128),
scheme: str = "per_row",
workers: int | None = None,
) -> tuple[bytes, dict]:
if scheme != "per_row":
raise ValueError("Only scheme='per_row' is implemented")
if not np.issubdtype(rows_np.dtype, np.integer):
raise TypeError("rows_np must be an integer array")
if rows_np.size:
mn = int(rows_np.min())
mx = int(rows_np.max())
if mn < 0 or mx >= C:
raise ValueError(
f"Index out of range: allowed [0, {C - 1}], got min={mn}, max={mx}"
)
valid_B: list[int] = [
b for b in B_choices if b > 0 and (b & (b - 1)) == 0 and (C % b) == 0
]
if not valid_B:
b = 1
valid_B = []
while b <= C:
if C % b == 0 and (b & (b - 1)) == 0:
valid_B.append(b)
b <<= 1
🤖 Prompt for AI Agents
In src/tplr/compress/bits.py around lines 129 to 149, input indexing into
buckets can silently discard out-of-range indices; add upfront validation on
rows_np to fail fast: ensure rows_np is an integer ndarray, non-empty, has shape
consistent with expected rows (e.g., 2D or 1D as required by this function), and
that every index value is >= 0 and < C; if any check fails raise a ValueError
with a clear message indicating invalid index values and the offending
range/shape so callers can correct their input rather than losing data silently.

Comment on lines +263 to +270
rows, k = idx_gpu.shape
device = idx_gpu.device
header = 5 + 4 + 1

best_bits = None # torch.Tensor [rows]
best_B = None # torch.Tensor [rows]

# Fast path: single histogram at smallest B, aggregate for larger B

🛠️ Refactor suggestion

Guard GPU path inputs (dtype/range).

Mirror CPU-side validation to avoid UB on GPU (e.g., invalid dtype or indices out of [0,C)).

Apply:

 @torch.no_grad()
 def encode_batch_rows(
     idx_gpu: torch.Tensor,  # [rows, k] int64
@@
 ) -> tuple[int, dict]:
@@
-    rows, k = idx_gpu.shape
+    if idx_gpu.dtype != torch.long:
+        raise TypeError("idx_gpu must be torch.int64 (long)")
+    if idx_gpu.numel():
+        if int(idx_gpu.min().item()) < 0 or int(idx_gpu.max().item()) >= C:
+            raise ValueError(f"Index out of range for C={C}")
+    rows, k = idx_gpu.shape
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
rows, k = idx_gpu.shape
device = idx_gpu.device
header = 5 + 4 + 1
best_bits = None # torch.Tensor [rows]
best_B = None # torch.Tensor [rows]
# Fast path: single histogram at smallest B, aggregate for larger B
if idx_gpu.dtype != torch.long:
raise TypeError("idx_gpu must be torch.int64 (long)")
if idx_gpu.numel():
if int(idx_gpu.min().item()) < 0 or int(idx_gpu.max().item()) >= C:
raise ValueError(f"Index out of range for C={C}")
rows, k = idx_gpu.shape
device = idx_gpu.device
header = 5 + 4 + 1
best_bits = None # torch.Tensor [rows]
best_B = None # torch.Tensor [rows]
# Fast path: single histogram at smallest B, aggregate for larger B

@joellidin joellidin force-pushed the feat/efficient-compress branch from 375b0e3 to 4f23a1e Compare August 31, 2025 22:03
@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (5)
src/tplr/compress/bits.py (5)

230-243: Row-order emission is correct now (fixes prior B-bucket misalignment).

Using ThreadPoolExecutor.map over range(N) preserves original row order and aligns with the decoder’s expectations.


75-83: Fix bit underflow in BitReader.read_bits (silent corruption on truncated input).

If the stream is shorter than requested, nbits can go negative and decoding silently produces corrupt output. Raise EOFError when not enough bits are available.

Apply:

 def read_bits(self, n: int) -> int:
     if n <= 0:
         return 0
     self._fill(n)
+    if self.nbits < n:
+        raise EOFError(f"Not enough bits in stream to read {n} bits")
     mask = (1 << n) - 1
     out = self.cur & mask
     self.cur >>= n
     self.nbits -= n
     return out

346-358: GPU size calculation undercounts: it omits the global header, the per-row 16-bit length fields, and per-row byte rounding.

This can cause buffer undersizing downstream. Compute bytes_per_row and include fixed overhead.

Apply:

-    total_bits = int(best_bits.sum().item())
-    bytes_len = (total_bits + 7) // 8
+    # Round each row to full bytes; outer stream stores per-row blobs.
+    bytes_per_row = (best_bits + 7) // 8  # [rows]
+    rows_bytes_total = int(bytes_per_row.sum().item())
+    header_bits = 12 + 16 + 1  # global C-1, N, reserved
+    length_bits = 16 * rows    # per-row byte lengths
+    total_bits = rows_bytes_total * 8 + header_bits + length_bits
+    bytes_len = (total_bits + 7) // 8
@@
-        "avg_bits_per_row": float(best_bits.float().mean().item()),
+        "avg_bits_per_row": float((bytes_per_row * 8).to(torch.float32).mean().item()),
         "B_hist": B_hist,
     }
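
The corrected accounting can be checked with a small sketch: per-row bit counts are rounded up to whole bytes, then the fixed overhead is added — the global header (12-bit C-1, 16-bit N, 1 reserved bit, matching the decoder's header reads quoted earlier) and one 16-bit length field per row:

```python
def payload_bytes(row_bits: list[int]) -> int:
    """Total payload size in bytes for per-row blobs plus fixed overhead."""
    bytes_per_row = [(b + 7) // 8 for b in row_bits]
    header_bits = 12 + 16 + 1          # C-1, N, reserved
    length_bits = 16 * len(row_bits)   # one 16-bit length per row
    total_bits = 8 * sum(bytes_per_row) + header_bits + length_bits
    return (total_bits + 7) // 8
```

For two rows of 10 and 17 bits this gives 13 bytes, whereas summing raw bits (27 bits, 4 bytes) would badly undersize the buffer.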

129-149: Validate CPU encoder inputs (dtype/range) to avoid silent index loss.

Reject non-integer arrays and out-of-range indices upfront.

Apply:

 def encode_batch_rows_cpu(
@@
 ) -> tuple[bytes, dict]:
     if scheme != "per_row":
         raise ValueError("Only scheme='per_row' is implemented")
+    if not np.issubdtype(rows_np.dtype, np.integer):
+        raise TypeError("rows_np must be an integer array")
+    if rows_np.size:
+        mn = int(rows_np.min())
+        mx = int(rows_np.max())
+        if mn < 0 or mx >= C:
+            raise ValueError(f"Index out of range: allowed [0, {C - 1}], got min={mn}, max={mx}")

263-270: Guard GPU path inputs (dtype/range) to prevent UB and wrong sizing.

Mirror CPU checks for dtype=int64 and indices in [0, C).

Apply:

 @torch.no_grad()
 def encode_batch_rows(
@@
 ) -> tuple[int, dict]:
@@
-    rows, k = idx_gpu.shape
+    if idx_gpu.dtype != torch.long:
+        raise TypeError("idx_gpu must be torch.int64 (long)")
+    if idx_gpu.numel():
+        if int(idx_gpu.min().item()) < 0 or int(idx_gpu.max().item()) >= C:
+            raise ValueError(f"Index out of range for C={C}")
+    rows, k = idx_gpu.shape
🧹 Nitpick comments (2)
src/tplr/compress/bits.py (2)

62-74: Add BitReader.available_bits() helper (used by decoder preflight checks).

Small utility to validate remaining capacity before reading fixed-size row chunks.

Apply:

 class BitReader:
     def __init__(self, data: bytes) -> None:
         self.data = data
         self.idx = 0
         self.cur = 0
         self.nbits = 0

     def _fill(self, n: int) -> None:
         while self.nbits < n and self.idx < len(self.data):
             self.cur |= int(self.data[self.idx]) << self.nbits
             self.idx += 1
             self.nbits += 8
+
+    def available_bits(self) -> int:
+        # Bits still buffered plus bits in unread bytes
+        return self.nbits + (len(self.data) - self.idx) * 8

243-250: Make CPU meta.avg_bits_per_row reflect allocated bytes (optional).

Average currently uses raw per-row estimates, not rounded bytes + headers. Consider computing from actual row_bytes.

Apply:

-    row_bits: list[int] = []
+    row_bits: list[int] = []
+    row_bytes_lens: list[int] = []
@@
             bw.write_bits(len(row_bytes), 16)
             for byte in row_bytes:
                 bw.write_bits(int(byte), 8)
             row_bits.append(bits_used)
             B_hist[B_used] = B_hist.get(B_used, 0) + 1
+            row_bytes_lens.append(len(row_bytes))
@@
     meta = {
         "total_bits": len(payload) * 8,
-        "avg_bits_per_row": (sum(row_bits) / max(1, N)) if N else 0.0,
+        "avg_bits_per_row": (8.0 * sum(row_bytes_lens) / max(1, N)) if N else 0.0,
         "B_hist": B_hist,
     }
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 375b0e3 and 4f23a1e.

📒 Files selected for processing (1)
  • src/tplr/compress/bits.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: joellidin
PR: tplr-ai/templar#553
File: src/tplr/comms.py:2626-2702
Timestamp: 2025-08-31T15:02:03.568Z
Learning: The user joellidin has explicitly decided not to implement backward compatibility for legacy int64 index tensors in the check_compressed_indices function in src/tplr/comms.py. The function should only accept uint8 Rice/bitmap payload format and reject other formats to enforce network-wide migration to the new compression scheme.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.12)
  • GitHub Check: test (3.11)

@joellidin joellidin force-pushed the feat/efficient-compress branch from 4f23a1e to 807b76a on August 31, 2025 22:30

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (4)
src/tplr/compress/bits.py (4)

35-43: Add EOF protection to prevent bit underflow.

The read_bits method can drive nbits negative if the stream is truncated, leading to corrupted state. This is critical for robust handling of potentially malformed payloads.

Apply this diff to prevent underflow:

 def read_bits(self, n: int) -> int:
     if n <= 0:
         return 0
     self._fill(n)
+    if self.nbits < n:
+        raise EOFError(f"Not enough bits in stream to read {n} bits")
     mask = (1 << n) - 1
     out = self.cur & mask
     self.cur >>= n
     self.nbits -= n
     return out
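
For clarity, the guarded reader can be exercised end to end. This is a standalone sketch (the class name `BitReader` here is illustrative; the module uses `_BitReader`) that reuses the same LSB-first `_fill` logic shown in the file:

```python
# Minimal LSB-first bit reader with the EOF guard suggested above.
class BitReader:
    def __init__(self, data: bytes) -> None:
        self.data = data
        self.idx = 0    # next unread byte
        self.cur = 0    # bit buffer, LSB-first
        self.nbits = 0  # bits currently buffered

    def _fill(self, n: int) -> None:
        while self.nbits < n and self.idx < len(self.data):
            self.cur |= int(self.data[self.idx]) << self.nbits
            self.idx += 1
            self.nbits += 8

    def read_bits(self, n: int) -> int:
        if n <= 0:
            return 0
        self._fill(n)
        if self.nbits < n:  # fail fast instead of driving nbits negative
            raise EOFError(f"Not enough bits in stream to read {n} bits")
        mask = (1 << n) - 1
        out = self.cur & mask
        self.cur >>= n
        self.nbits -= n
        return out

# Reading past the end now raises instead of silently corrupting state.
r = BitReader(bytes([0b10110100]))
assert r.read_bits(4) == 0b0100  # low nibble first (LSB-first order)
assert r.read_bits(4) == 0b1011
try:
    r.read_bits(1)
except EOFError:
    pass
```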

67-103: Add validation for untrusted payloads in decoder.

The decoder lacks validation for malformed payloads that could cause crashes or incorrect behavior. This is essential when processing data from untrusted sources.

Apply validation checks:

 def decode_batch_rows(payload: bytes) -> tuple[list[list[int]], int, int]:
     """
     Decode payload created by `encode_batch_rows(...)`.
     Returns (rows, C, N) with per-row global indices.
     """
     br = _BitReader(payload)
     C = br.read_bits(12) + 1
+    if C < 1 or C > 4096:  # Reasonable bounds
+        raise ValueError(f"Invalid C value: {C}")
     N = br.read_bits(16)
+    if N > 65536:  # Reasonable upper bound
+        raise ValueError(f"Invalid N value: {N}")
     _ = br.read_bits(1)  # reserved
 
     rows: list[list[int]] = []
     for _i in range(N):
         row_len = br.read_bits(16)
+        # Verify enough bytes remain
+        if br.idx + row_len > len(br.data):
+            raise EOFError(f"Truncated payload: need {row_len} bytes, have {len(br.data) - br.idx}")
         row_bytes = br.read_bytes(row_len)
         rr = _BitReader(row_bytes)
         lb = rr.read_bits(5)
         k_param = rr.read_bits(4)
         use_bitmap = rr.read_bits(1)
         B = 1 << lb
+        # Validate B
+        if B < 1 or B > C or (C % B) != 0:
+            raise ValueError(f"Invalid row block size B={B} for C={C}")
         n_sub = C // B
 
         indices: list[int] = []
         for j in range(n_sub):
             s_len = _rice_read(rr, k_param)
+            if s_len > B:
+                raise ValueError(f"Invalid s_len={s_len} > B={B}")
             if s_len == 0:
                 continue
             if use_bitmap:
                 bitmask = rr.read_bits(B)
                 for loc in range(B):
                     if (bitmask >> loc) & 1:
                         indices.append(j * B + loc)
             else:
                 for _ in range(s_len):
                     loc = rr.read_bits(lb)
+                    if loc >= B:
+                        raise ValueError(f"Invalid loc={loc} >= B={B}")
                     indices.append(j * B + loc)
         rows.append(indices)
+        # Ensure full row chunk consumed
+        if rr.idx != len(row_bytes):
+            raise ValueError("Row payload has trailing or malformed data")
     return rows, C, N
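
The `_rice_read` helper invoked above is not shown in this hunk. A minimal self-contained sketch of Rice coding with parameter k (quotient in unary, then k remainder bits) behaves like this; the unary polarity and LSB-first remainder order are assumptions for illustration, not necessarily the codec's exact wire format:

```python
# Rice coding sketch: v is split into quotient q = v >> k (written in unary
# as q ones plus a terminating zero) and the k low-order bits of v.
def rice_write(bits: list, v: int, k: int) -> None:
    q = v >> k
    bits.extend([1] * q)   # unary quotient
    bits.append(0)         # terminator
    for i in range(k):     # k-bit remainder, LSB first
        bits.append((v >> i) & 1)

def rice_read(bits: list, pos: int, k: int) -> tuple:
    q = 0
    while bits[pos] == 1:  # count unary ones
        q += 1
        pos += 1
    pos += 1               # skip terminator
    r = 0
    for i in range(k):
        r |= bits[pos] << i
        pos += 1
    return (q << k) | r, pos

# Round-trip a few values at k=2.
stream: list = []
for v in (0, 1, 5, 13):
    rice_write(stream, v, 2)
pos = 0
out = []
for _ in range(4):
    v, pos = rice_read(stream, pos, 2)
    out.append(v)
assert out == [0, 1, 5, 13]
```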

406-412: Correct total_bits calculation to include per-row byte rounding.

The current calculation underestimates the actual payload size as it doesn't account for row-level byte alignment.

Apply the corrected calculation:

-    row_bytes = ((row_bits + 7) // 8).to(torch.int32)
+    row_bytes = ((row_bits + 7) // 8).to(torch.int32)
 
     # 3) Allocate payload buffer
     B_hist = {int(b): int((best_B == b).sum().item()) for b in torch.unique(best_B)}
-    total_bits_rows = int((16 * rows + 8 * row_bytes.sum().item()))
-    total_bits = 12 + 16 + 1 + total_bits_rows
+    # Account for: global header (29 bits) + per-row length fields (16 bits each) + per-row payloads (bytes)
+    header_bits = 12 + 16 + 1  # C-1, N, reserved
+    length_bits = 16 * rows
+    payload_bytes_total = int(row_bytes.sum().item())
+    total_bits = header_bits + length_bits + (payload_bytes_total * 8)
     total_bytes = (total_bits + 7) // 8
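
As a quick sanity check of the corrected formula, with hypothetical per-row bit counts the 29-bit global header, the 16-bit per-row length fields, and the byte-rounded row payloads combine as follows:

```python
# Worked size check; the per-row bit counts are made-up sample values.
row_bits = [37, 52, 8]
row_bytes = [(b + 7) // 8 for b in row_bits]  # byte-rounded: [5, 7, 1]
header_bits = 12 + 16 + 1                     # C-1, N, reserved
total_bits = header_bits + 16 * len(row_bits) + 8 * sum(row_bytes)
total_bytes = (total_bits + 7) // 8
assert row_bytes == [5, 7, 1]
assert total_bits == 29 + 48 + 104  # 181
assert total_bytes == 23
```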

354-376: Add input validation for GPU encoding path.

The encoder should validate input indices to prevent silent data corruption or out-of-bounds access.

Apply input validation:

 @torch.no_grad()
 def encode_batch_rows(
     idx: torch.Tensor,  # [rows, k] int64 (CUDA strongly recommended)
     *,
     C: int,
     B_choices: tuple[int, ...] = (64, 128),
 ) -> tuple[bytes, dict]:
     """
     Triton encoder for per-row Rice/bitmap codec.
 
     Returns:
       payload: bytes
       meta:    {total_bits, avg_bits_per_row, B_hist}
     """
     if not TRITON_AVAILABLE:
         raise RuntimeError("Triton is not available. `pip install triton` and re-run.")
 
     if idx.dtype != torch.int64:
         idx = idx.to(torch.int64)
     if not idx.is_cuda:
         idx = idx.cuda()
+    
+    # Validate input indices
+    if idx.numel() > 0:
+        min_idx = int(idx.min().item())
+        max_idx = int(idx.max().item())
+        if min_idx < 0 or max_idx >= C:
+            raise ValueError(f"Index out of range for C={C}: found indices in [{min_idx}, {max_idx}]")
+    
     rows, k = idx.shape
     device = idx.device
🧹 Nitpick comments (2)
src/tplr/compress/bits.py (2)

311-312: Increase the buffer size for safety or add a dynamic check.

The fixed buffer size of 256 could be exceeded if k > 256. Consider making it dynamic or adding a guard.

Add a safety check or increase buffer size:

-            buf = tl.zeros([256], dtype=tl.int32)  # K<=256 guard; adjust if you need larger K
+            # Ensure buffer is large enough for all possible k values
+            buf = tl.zeros([512], dtype=tl.int32)  # Increased for safety
+            tl.device_assert(k <= 512, "k exceeds buffer size")

469-470: Improve avg_bits_per_row calculation accuracy.

The average should be computed after byte-rounding to reflect actual allocated space.

     meta = {
         "total_bits": total_bits,
-        "avg_bits_per_row": float(row_bits.float().mean().item()) if rows > 0 else 0.0,
+        "avg_bits_per_row": float((row_bytes * 8).float().mean().item()) if rows > 0 else 0.0,
         "B_hist": B_hist,
     }
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 4f23a1e and 807b76a.

📒 Files selected for processing (1)
  • src/tplr/compress/bits.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: joellidin
PR: tplr-ai/templar#553
File: src/tplr/comms.py:2626-2702
Timestamp: 2025-08-31T15:02:03.568Z
Learning: The user joellidin has explicitly decided not to implement backward compatibility for legacy int64 index tensors in the check_compressed_indices function in src/tplr/comms.py. The function should only accept uint8 Rice/bitmap payload format and reject other formats to enforce network-wide migration to the new compression scheme.
🪛 GitHub Actions: CI
src/tplr/compress/bits.py

[error] 5-5: uv run ruff check: I001 - Import block is un-sorted or un-formatted.

🔇 Additional comments (1)
src/tplr/compress/bits.py (1)

1-18: Review imports and exception handling in src/tplr/compress/bits.py

  • Add a blank line after from __future__ import annotations and group imports per PEP 8:
    from __future__ import annotations
    
    import math
    from typing import Sequence
    
    import torch
    import torch.nn.functional as F
  • Restrict the catch to ImportError so other issues aren’t masked:
    try:
        import triton
        import triton.language as tl
        TRITON_AVAILABLE = True
    -except Exception:
    +except ImportError:
        TRITON_AVAILABLE = False

Comment on lines +268 to +269
lb = tl.cdiv(tl.log2(B), 1) # ceil(log2(B)) ; B is power of two so exact
n_sub = C // B

🛠️ Refactor suggestion

Use tl.math.log2 for proper Triton compilation.

The tl.log2 function returns float values, which need an explicit ceiling computation in a Triton context.

Apply this fix for proper Triton compilation:

-    lb = tl.cdiv(tl.log2(B), 1)  # ceil(log2(B)) ; B is power of two so exact
+    lb = tl.math.log2(B.to(tl.float32)).to(tl.int32)  # B is power of two so exact
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
lb = tl.cdiv(tl.log2(B), 1) # ceil(log2(B)) ; B is power of two so exact
n_sub = C // B
lb = tl.math.log2(B.to(tl.float32)).to(tl.int32) # B is power of two so exact
n_sub = C // B
🤖 Prompt for AI Agents
In src/tplr/compress/bits.py around lines 268-269, replace the use of tl.log2
with Triton's math API and perform an explicit ceiling: compute lb using
tl.math.log2 and tl.math.ceil (e.g. lb = tl.math.ceil(tl.math.log2(B))) so the
value is computed with Triton-compatible functions; leave n_sub = C // B
unchanged.

Comment on lines +457 to +462
payload_bitpos.to(torch.int64), # starts_bits_ptr
payload, # payload buffer
torch.as_tensor(int(len_bits_base), dtype=torch.int64, device=device),
torch.as_tensor(int(payload_bits_base), dtype=torch.int64, device=device),
NSUB_MAX=128,
)

💡 Verification agent

🧩 Analysis chain

Consider validating NSUB_MAX against actual requirements.

The hardcoded NSUB_MAX=128 assumes C/B_min <= 128. This should be validated or made dynamic.


🏁 Script executed:

#!/bin/bash
# Check if there are any configurations where C/B_min might exceed 128
rg -n "C\s*=|C:" --type py | head -20
rg -n "B_choices|B_sorted" --type py | head -20

Length of output: 3317


🏁 Script executed:

#!/bin/bash
sed -n '430,500p' src/tplr/compress/bits.py

Length of output: 1717


Enforce C/B_min ≤ NSUB_MAX (128) or compute NSUB_MAX dynamically
Add a runtime check (or derive NSUB_MAX = C // B_min) before launching the Triton kernel in src/tplr/compress/bits.py (around lines 457–462) to ensure C/B_min never exceeds 128 and avoid out-of-bounds indexing.

@joellidin joellidin force-pushed the feat/efficient-compress branch from 807b76a to 2727ada on August 31, 2025 22:47
@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (4)
src/tplr/compress/bits.py (4)

50-58: Guard against truncated input in BitReader.read_bits.

Without a fill-then-check, nbits can go negative on short payloads.

     def read_bits(self, n: int) -> int:
         if n <= 0:
             return 0
         self._fill(n)
+        if self.nbits < n:
+            raise EOFError(f"Not enough bits in stream to read {n} bits")
         mask = (1 << n) - 1
         out = self.cur & mask
         self.cur >>= n
         self.nbits -= n
         return out

82-119: Harden decoder against malformed/untrusted payloads.

Add header, per-row length, B, and row-consumption validation to avoid OOB/DoS.

 def decode_batch_rows(payload: bytes) -> tuple[list[list[int]], int, int]:
@@
-    C = br.read_bits(12) + 1
-    N = br.read_bits(16)
+    C = br.read_bits(12) + 1
+    if C < 1:
+        raise ValueError(f"Invalid C={C}")
+    N = br.read_bits(16)
     _ = br.read_bits(1)  # reserved
@@
-    for _i in range(N):
-        row_len = br.read_bits(16)
+    for _i in range(N):
+        row_len = br.read_bits(16)
+        if hasattr(br, "available_bits") and br.available_bits() < row_len * 8:
+            raise EOFError("Truncated payload: not enough bytes for row")
         row_bytes = br.read_bytes(row_len)
         rr = _BitReader(row_bytes)
         lb = rr.read_bits(5)
         k_param = rr.read_bits(4)
         use_bitmap = rr.read_bits(1)
-        B = 1 << lb
-        n_sub = C // B
+        B = 1 << lb
+        # B must be a power of two ≤ C and divide C exactly.
+        if B < 1 or (B & (B - 1)) != 0 or (C % B) != 0:
+            raise ValueError(f"Invalid row block size B={B} for C={C}")
+        n_sub = C // B
@@
-            if use_bitmap:
+            if use_bitmap:
                 bitmask = rr.read_bits(B)
                 for loc in range(B):
                     if (bitmask >> loc) & 1:
                         indices.append(j * B + loc)
             else:
                 for _ in range(s_len):
-                    loc = rr.read_bits(lb)
+                    loc = rr.read_bits(lb)
+                    if loc >= B:
+                        raise ValueError(f"Decoded loc {loc} out of range for B={B}")
                     indices.append(j * B + loc)
         rows.append(indices)
+        # Ensure row chunk fully consumed
+        if rr.idx != len(row_bytes):
+            raise ValueError("Row payload has trailing or malformed data")
     return rows, C, N

301-304: Use Triton’s math API for log2 to avoid compilation issues.

tl.math.log2 is the supported path; cast explicitly.

-    lb = tl.log2(B.to(tl.float32)).to(tl.int32)
+    lb = tl.math.log2(B.to(tl.float32)).to(tl.int32)

370-393: Add dtype/range guards in the public encoder.

Fail fast on wrong dtype or indices outside [0, C).

 def encode_batch_rows(
     idx: torch.Tensor,  # [rows, k] int64 (CUDA strongly recommended)
@@
-    if idx.dtype != torch.int64:
-        idx = idx.to(torch.int64)
+    if idx.dtype != torch.int64:
+        idx = idx.to(torch.int64)
@@
-    rows, k = idx.shape
+    rows, k = idx.shape
+    if idx.numel():
+        mn = int(idx.min().item())
+        mx = int(idx.max().item())
+        if mn < 0 or mx >= C:
+            raise ValueError(f"encode_batch_rows: index out of range for C={C}: min={mn}, max={mx}")
🧹 Nitpick comments (4)
src/tplr/compress/__init__.py (2)

18-21: Fix misleading comment: encoder does not return a perm.

encode_batch_rows returns (bytes, meta). Drop “+ perm” to avoid confusion.

-from .bits import (
-    decode_batch_rows,  # decoder (CPU)
-    encode_batch_rows,  # GPU-accelerated encoder → bytes + perm + meta
-)
+from .bits import (
+    decode_batch_rows,  # decoder (CPU)
+    encode_batch_rows,  # GPU-accelerated encoder → bytes + meta
+)

22-33: Consider deprecating legacy pack12 exports in public API.

Given the migration policy to accept only uint8 Rice/bitmap on the wire, re-exporting pack_12bit_indices/unpack_12bit_indices can prolong legacy usage. Either mark them as deprecated with warnings or drop them from `__all__`.

Do you want a follow-up PR to emit DeprecationWarning on import?

src/tplr/compress/bits.py (2)

69-74: Expose available_bits() to enable safe length checks.

Small helper to validate remaining capacity before reads.

     def read_bytes(self, n: int) -> bytes:
         out = bytearray()
         for _ in range(n):
             out.append(self.read_bits(8))
         return bytes(out)
+
+    def available_bits(self) -> int:
+        # Bits left in cur plus remaining bytes.
+        return self.nbits + (len(self.data) - self.idx) * 8

445-451: Expose MAX_K and reflect the new top-k ceiling in error text.

Consider exporting MAX_K or documenting it to align with the topk=128 default.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 807b76a and 2727ada.

📒 Files selected for processing (2)
  • src/tplr/compress/__init__.py (1 hunks)
  • src/tplr/compress/bits.py (1 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: joellidin
PR: tplr-ai/templar#553
File: src/tplr/comms.py:2626-2702
Timestamp: 2025-08-31T15:02:03.568Z
Learning: The user joellidin has explicitly decided not to implement backward compatibility for legacy int64 index tensors in the check_compressed_indices function in src/tplr/comms.py. The function should only accept uint8 Rice/bitmap payload format and reject other formats to enforce network-wide migration to the new compression scheme.
📚 Learning: 2025-08-31T15:02:03.568Z
Learnt from: joellidin
PR: tplr-ai/templar#553
File: src/tplr/comms.py:2626-2702
Timestamp: 2025-08-31T15:02:03.568Z
Learning: The user joellidin has explicitly decided not to implement backward compatibility for legacy int64 index tensors in the check_compressed_indices function in src/tplr/comms.py. The function should only accept uint8 Rice/bitmap payload format and reject other formats to enforce network-wide migration to the new compression scheme.

Applied to files:

  • src/tplr/compress/bits.py
🧬 Code graph analysis (1)
src/tplr/compress/__init__.py (2)
src/tplr/compress/bits.py (2)
  • decode_batch_rows (82-118)
  • encode_batch_rows (371-479)
src/tplr/compress/pack12.py (2)
  • pack_12bit_indices (21-47)
  • unpack_12bit_indices (50-75)
🪛 GitHub Actions: CI
src/tplr/compress/bits.py

[error] 18-23: I001: Import block is un-sorted or un-formatted.

Comment on lines +18 to +31
from __future__ import annotations
import math
from typing import Sequence

import torch
import torch.nn.functional as F # only used if we pick the one-hot path

try:
import triton
import triton.language as tl

TRITON_AVAILABLE = True
except Exception:
TRITON_AVAILABLE = False

⚠️ Potential issue

Fix CI import-order error and remove unused F import.

Ruff I001 flags lines 18–23; also torch.nn.functional as F is unused.

-from __future__ import annotations
-import math
-from typing import Sequence
-
-import torch
-import torch.nn.functional as F  # only used if we pick the one-hot path
-
-try:
-    import triton
-    import triton.language as tl
-
-    TRITON_AVAILABLE = True
-except Exception:
-    TRITON_AVAILABLE = False
+from __future__ import annotations
+
+import math
+from typing import Sequence
+
+import torch
+
+try:
+    import triton
+    import triton.language as tl
+    TRITON_AVAILABLE = True
+except Exception:
+    TRITON_AVAILABLE = False
📝 Committable suggestion


Suggested change
from __future__ import annotations
import math
from typing import Sequence
import torch
import torch.nn.functional as F # only used if we pick the one-hot path
try:
import triton
import triton.language as tl
TRITON_AVAILABLE = True
except Exception:
TRITON_AVAILABLE = False
from __future__ import annotations
import math
from typing import Sequence
import torch
try:
import triton
import triton.language as tl
TRITON_AVAILABLE = True
except Exception:
TRITON_AVAILABLE = False
🧰 Tools
🪛 GitHub Actions: CI

[error] 18-23: I001: Import block is un-sorted or un-formatted.

🤖 Prompt for AI Agents
In src/tplr/compress/bits.py around lines 18 to 31, fix the import-order and
remove the unused alias: move stdlib imports (math, typing) directly after the
from __future__ line, then import third-party packages (torch) and remove
"import torch.nn.functional as F" since F is unused, and keep the triton
try/except block after torch; ensure grouping follows stdlib then third-party
order to satisfy Ruff I001.

@joellidin joellidin force-pushed the feat/efficient-compress branch from 2727ada to 65167f1 on August 31, 2025 22:56
@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

♻️ Duplicate comments (4)
src/tplr/compress/bits.py (4)

45-53: Prevent bit underflow on truncated input in BitReader.read_bits.

Guard after _fill to avoid corrupting state when not enough bits remain.

Apply:

 def read_bits(self, n: int) -> int:
     if n <= 0:
         return 0
     self._fill(n)
+    if self.nbits < n:
+        raise EOFError(f"Not enough bits in stream to read {n} bits")
     mask = (1 << n) - 1
     out = self.cur & mask
     self.cur >>= n
     self.nbits -= n
     return out

82-113: Harden decoder against malformed/untrusted payloads (header, per-row length, B, and trailing data).

Validate header and row chunks; fail fast on inconsistencies.

Apply:

     br = _BitReader(payload)
     C = br.read_bits(12) + 1
     N = br.read_bits(16)
     _ = br.read_bits(1)  # reserved
-
-    rows: list[list[int]] = []
+    if C < 1:
+        raise ValueError(f"Invalid C={C}")
+    rows: list[list[int]] = []
     for _i in range(N):
-        row_len = br.read_bits(16)
+        row_len = br.read_bits(16)
+        if row_len < 0:
+            raise ValueError("Negative row length")
+        if br.available_bits() < row_len * 8:
+            raise EOFError("Truncated payload: not enough bytes for row")
         row_bytes = br.read_bytes(row_len)
         rr = _BitReader(row_bytes)
         lb = rr.read_bits(5)
         k_param = rr.read_bits(4)
         use_bitmap = rr.read_bits(1)
         B = 1 << lb
-        n_sub = C // B
+        if B < 1 or (C % B) != 0:
+            raise ValueError(f"Invalid row block size B={B} for C={C}")
+        n_sub = C // B
 
         indices: list[int] = []
         for j in range(n_sub):
             s_len = _rice_read(rr, k_param)
             if s_len == 0:
                 continue
             if use_bitmap:
                 bitmask = rr.read_bits(B)
                 for loc in range(B):
                     if (bitmask >> loc) & 1:
-                        indices.append(j * B + loc)
+                        gidx = j * B + loc
+                        if gidx >= C:
+                            raise ValueError("Decoded index out of range")
+                        indices.append(gidx)
             else:
                 for _ in range(s_len):
                     loc = rr.read_bits(lb)
-                    indices.append(j * B + loc)
+                    if loc >= B:
+                        raise ValueError("Decoded loc out of range for block")
+                    gidx = j * B + loc
+                    if gidx >= C:
+                        raise ValueError("Decoded index out of range")
+                    indices.append(gidx)
         rows.append(indices)
+        if rr.idx != len(row_bytes):
+            raise ValueError("Row payload has trailing or malformed data")
     return rows, C, N
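
The header fields validated above (12-bit C-1, 16-bit N, one reserved bit) round-trip like this in a standalone sketch. The little-endian integer packing is an assumption chosen to mirror the LSB-first `_BitReader`; the helper names are illustrative, not the module's API:

```python
# Pack/unpack the 29-bit global header: 12 bits C-1, 16 bits N, 1 reserved.
def write_header(C: int, N: int) -> bytes:
    val = (C - 1) | (N << 12)  # reserved bit at position 28 left as 0
    return bytes((val >> (8 * i)) & 0xFF for i in range(4))

def read_header(data: bytes) -> tuple:
    val = int.from_bytes(data, "little")
    return (val & 0xFFF) + 1, (val >> 12) & 0xFFFF

C, N = read_header(write_header(4096, 1000))
assert (C, N) == (4096, 1000)
```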

310-311: Use Triton’s math API for log2; tl.log2 can miscompile.

Switch to tl.math.log2 to compute lb robustly.

Apply:

-    lb = tl.log2(B.to(tl.float32)).to(tl.int32)
+    lb = tl.math.log2(B.to(tl.float32)).to(tl.int32)

133-136: Mirror range validation in parameter estimator.

Keep GPU counts sane; reject out-of-range indices upfront.

Apply:

     assert idx.dtype == torch.int64
     rows, k = idx.shape
     device = idx.device
+    if idx.numel():
+        mn = int(idx.min().item())
+        mx = int(idx.max().item())
+        if mn < 0 or mx >= C:
+            raise ValueError(f"Index out of range for C={C}: min={mn}, max={mx}")
🧹 Nitpick comments (2)
src/tplr/compress/bits.py (2)

39-44: Expose available_bits() for safe decoder bounds checks.

Small helper enables validating remaining capacity before reads.

Apply:

     def _fill(self, n: int) -> None:
         while self.nbits < n and self.idx < len(self.data):
             self.cur |= int(self.data[self.idx]) << self.nbits
             self.idx += 1
             self.nbits += 8
+
+    def available_bits(self) -> int:
+        # Remaining bits including latched bits.
+        return (len(self.data) - self.idx) * 8 + self.nbits

481-485: Report allocated bits per row in meta (matches actual payload).

Use per-row byte-allocation rather than raw bit estimate.

Apply:

-        "avg_bits_per_row": float(row_bits.float().mean().item()) if rows > 0 else 0.0,
+        "avg_bits_per_row": float((row_bytes.to(torch.int64) * 8).float().mean().item()) if rows > 0 else 0.0,
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 2727ada and 65167f1.

📒 Files selected for processing (2)
  • src/tplr/compress/__init__.py (1 hunks)
  • src/tplr/compress/bits.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/tplr/compress/__init__.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: joellidin
PR: tplr-ai/templar#553
File: src/tplr/comms.py:2626-2702
Timestamp: 2025-08-31T15:02:03.568Z
Learning: The user joellidin has explicitly decided not to implement backward compatibility for legacy int64 index tensors in the check_compressed_indices function in src/tplr/comms.py. The function should only accept uint8 Rice/bitmap payload format and reject other formats to enforce network-wide migration to the new compression scheme.
📚 Learning: 2025-08-31T15:02:03.568Z
Learnt from: joellidin
PR: tplr-ai/templar#553
File: src/tplr/comms.py:2626-2702
Timestamp: 2025-08-31T15:02:03.568Z
Learning: The user joellidin has explicitly decided not to implement backward compatibility for legacy int64 index tensors in the check_compressed_indices function in src/tplr/comms.py. The function should only accept uint8 Rice/bitmap payload format and reject other formats to enforce network-wide migration to the new compression scheme.

Applied to files:

  • src/tplr/compress/bits.py
🪛 GitHub Actions: CI
src/tplr/compress/bits.py

[error] 13-13: I001: Import block is un-sorted or un-formatted.

Comment on lines +20 to +26
try:
import triton
import triton.language as tl

TRITON_AVAILABLE = True
except Exception:
TRITON_AVAILABLE = False

🛠️ Refactor suggestion

Fix import block to satisfy Ruff I001 (don’t assign inside import block).

Move TRITON_AVAILABLE assignment out of the try/except and avoid blank separation inside the block.

Apply:

-try:
-    import triton
-    import triton.language as tl
-
-    TRITON_AVAILABLE = True
-except Exception:
-    TRITON_AVAILABLE = False
+try:
+    import triton
+    import triton.language as tl
+except Exception:
+    triton = None
+    tl = None
+TRITON_AVAILABLE = triton is not None
📝 Committable suggestion


Suggested change
try:
import triton
import triton.language as tl
TRITON_AVAILABLE = True
except Exception:
TRITON_AVAILABLE = False
try:
import triton
import triton.language as tl
except Exception:
triton = None
tl = None
TRITON_AVAILABLE = triton is not None
🤖 Prompt for AI Agents
In src/tplr/compress/bits.py around lines 20 to 26, the import block assigns
TRITON_AVAILABLE inside the try which triggers Ruff I001; move the assignment
out of the try/except: initialize TRITON_AVAILABLE = False before the try, keep
the try block containing only the import statements with no extra blank line,
use except Exception: pass, and after the try/except set TRITON_AVAILABLE = True
only if the imports succeeded (for example by checking that the imported name
exists in locals() or by using a try/except/else and assigning True in the
else).
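
The optional-import pattern suggested above (imports only inside the try block, availability flag derived afterwards) can be checked in isolation. `importlib` stands in for `triton` here, and the deliberately missing module name is made up:

```python
# Guard pattern: try block contains only imports, flag is set after it,
# which keeps Ruff's import-sorting rule (I001) satisfied.
try:
    import importlib.util as _ilu
except Exception:
    _ilu = None
ILU_AVAILABLE = _ilu is not None

try:
    import no_such_module_xyz as _missing  # deliberately absent module
except Exception:
    _missing = None
MISSING_AVAILABLE = _missing is not None

assert ILU_AVAILABLE is True
assert MISSING_AVAILABLE is False
```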

Comment on lines +386 to +391
if idx.dtype != torch.int64:
idx = idx.to(torch.int64)
if not idx.is_cuda:
idx = idx.cuda()
idx = idx.contiguous()


🛠️ Refactor suggestion

Validate index range early to prevent silent corruption (encode path).

Fail fast if any idx < 0 or ≥ C.

Apply:

     if idx.dtype != torch.int64:
         idx = idx.to(torch.int64)
     if not idx.is_cuda:
         idx = idx.cuda()
     idx = idx.contiguous()
+    if idx.numel():
+        mn = int(idx.min().item())
+        mx = int(idx.max().item())
+        if mn < 0 or mx >= C:
+            raise ValueError(f"Index out of range for C={C}: min={mn}, max={mx}")
📝 Committable suggestion


Suggested change
if idx.dtype != torch.int64:
idx = idx.to(torch.int64)
if not idx.is_cuda:
idx = idx.cuda()
idx = idx.contiguous()
if idx.dtype != torch.int64:
idx = idx.to(torch.int64)
if not idx.is_cuda:
idx = idx.cuda()
idx = idx.contiguous()
# Fail fast on any out-of-bounds index
if idx.numel():
mn = int(idx.min().item())
mx = int(idx.max().item())
if mn < 0 or mx >= C:
raise ValueError(f"Index out of range for C={C}: min={mn}, max={mx}")
🤖 Prompt for AI Agents
In src/tplr/compress/bits.py around lines 386–391, add explicit range validation
for idx before it can silently corrupt encoded data: after converting idx to an
integer tensor (and before using it on CUDA), move a copy to CPU if needed,
compute min_val = int(idx.min().item()) and max_val = int(idx.max().item()), and
if min_val < 0 or max_val >= C raise a ValueError (or assert) with a clear
message containing min_val, max_val and C; this ensures we fail fast on
out-of-range indices and avoids silent corruption.

@joellidin joellidin force-pushed the feat/efficient-compress branch from 65167f1 to 62815c0 on August 31, 2025 23:08
@joellidin joellidin force-pushed the feat/efficient-compress branch from 62815c0 to 51fed2b on August 31, 2025 23:11
@joellidin joellidin marked this pull request as draft September 4, 2025 20:31