
feat: distributed gradient gathering for multi-GPU parallelism#649

Open
shivam-MBZUAI wants to merge 1 commit into dev from feat/distribute_gather

Conversation

Contributor

@shivam-MBZUAI commented Nov 6, 2025

Summary

Implements distributed gradient gathering to enable parallel gradient downloads across multiple GPU ranks.

Key Features

Distributed Gather Implementation

  • Parallel downloads: UIDs partitioned deterministically across GPU ranks
  • All-gather synchronization: Gradients collected and merged from all ranks
  • Memory optimization: CPU intermediate storage prevents GPU OOM on 70B models
  • Auto-detection: Automatically uses distributed gather when world_size > 1
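The deterministic partitioning can be sketched as a sorted round-robin slice, so every rank derives the same assignment with no inter-rank communication (a simplified sketch of what `partition_uids_across_ranks` does, per the review notes below; the exact signature is assumed):

```python
def partition_uids_across_ranks(uids, world_size, rank):
    """Deterministically assign a slice of UIDs to each rank.

    Sorting first guarantees every rank sees the same order, so the
    round-robin slice needs no coordination between ranks.
    """
    ordered = sorted(uids)
    return ordered[rank::world_size]


# Every UID lands on exactly one rank and the split stays balanced:
# partition_uids_across_ranks([7, 3, 9, 1, 5], 2, 0) -> [1, 5, 9]
# partition_uids_across_ranks([7, 3, 9, 1, 5], 2, 1) -> [3, 7]
```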

Technical Details

Memory Management

  • Move to CPU before all-gather (prevents 64GB GPU usage on 8 ranks)
  • Local merge optimization before all-gather
  • Aggressive cleanup after gradient application
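The CPU-offload step can be illustrated with a small recursive helper. This is a sketch, not the repo's code; it duck-types on a `.to(...)` method so it applies to torch tensors without importing torch here:

```python
def offload_to_cpu(obj):
    """Recursively move anything with a .to(...) method (e.g. torch tensors) to CPU.

    Doing this before the all-gather keeps all ranks' gathered gradients
    out of GPU memory, which is what prevents OOM on large models.
    """
    if hasattr(obj, "to"):
        return obj.to("cpu")
    if isinstance(obj, dict):
        return {k: offload_to_cpu(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(offload_to_cpu(v) for v in obj)
    return obj
```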

Synchronization

  • All ranks participate in barriers for consistency
  • Deterministic UID partitioning (cached per window)
  • Broadcast commitments and peers to all ranks
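The broadcast step might look like the following sketch (the helper name is assumed; it uses `torch.distributed.broadcast_object_list`, which the review below confirms is in use, and degrades to a no-op outside distributed runs):

```python
import torch.distributed as dist


def broadcast_from_master(obj, src=0):
    """Ship a picklable object (peer lists, commitments, a timestamp) from rank 0 to all ranks."""
    if not (dist.is_available() and dist.is_initialized()):
        return obj  # single-process run: nothing to synchronize
    payload = [obj if dist.get_rank() == src else None]
    dist.broadcast_object_list(payload, src=src)
    return payload[0]
```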

Changes by File

  • neurons/miner.py: Distributed gather integration for miners
  • neurons/validator.py: Distributed gather + test mode fixes
  • src/tplr/comms.py: Core gather_distributed() and merge_gather_results()
  • src/tplr/neurons.py: Simplified outer_step(), fixed sync score calculation
  • src/tplr/compress.py: Minor quantization device handling
  • neurons/trainer.py: Gradient norm logging
  • neurons/base_node.py: Minor improvements

Summary by CodeRabbit

  • New Features

    • True distributed gather with partial/merged results, deterministic UID partitioning, device-aware merging, and improved storage/object checks.
  • Bug Fixes

    • Better CUDA memory handling and aggressive cleanup, safer post-gather state clears, consistent device handling, and grad-norm captured and logged.
  • Chores

    • Expanded metrics/logging (success rates, skipped peers), master-driven time sync, longer gather timeouts, and more resilient download/offload paths.



coderabbitai bot commented Nov 6, 2025

Walkthrough

Adds deterministic per-rank distributed gather (with optional partial returns and merging), device-aware result movement, resilient S3 I/O with threaded loads, and distributed coordination (broadcasts/timestamps). Propagates partial-aware handling through the outer step, overlap checks, validation, and the miner flow; improves CUDA memory hygiene, gradient logging, and gather metrics.

Changes

Cohort / File(s) Change Summary
Distributed Gather & I/O
src/tplr/comms.py
Added deterministic UID partitioning (partition_uids_across_ranks), device movement helper (_move_gather_result_to_device), merge_gather_results, and gather_distributed for per-rank fetch + master merge; caching of partitions; extended gather_with_reserve to support return_partials; added async s3_object_exists; offloaded torch.load to threads; added NOT_FOUND retry limiter and improved logging; gradient_timestamp now defaults version to tplr.version.
Miner: gather control & memory hygiene
neurons/miner.py
Sets PYTORCH_CUDA_ALLOC_CONF early; master broadcasts peers/reserve/commitments and a timestamp; introduced use_distributed_gather gating and distributed-aware gather flow; aggressive GC/CUDA empty_cache/sync and explicit gradient clearing between inner/outer steps; clears gather_result.state_dict when present; logs successful/skipped peers, gather success rate, and propagates update decisions via dist helper.
Validator: partial-aware metrics & norms
neurons/validator.py
Initializes openskill ratings for new UIDs; master-driven timestamp broadcast and unified time window; uses distributed gather by default with increased timeout; added helpers _get_uids_from_gather_result, _get_skipped_uids_from_gather_result, _get_success_rate_from_gather_result; expanded compute_peer_val_norms to accept a `SimpleNamespace` or a list of partial results
Neurons: outer-step, overlap & device helpers
src/tplr/neurons.py
Added _resolve_cuda_device(device) to pick CUDA device safely; introduced _outer_step_single(...) as a memory-minimizing single-result outer step; outer_step now delegates to _outer_step_single; check_uid_index_overlap updated to accept merged or partial gather_result and to map per-UID state_dicts for index alignment.
Compressor fix
src/tplr/compress.py
In _dequantize_values(), when val.dtype == torch.uint8, move both lookup and shift tensors to val.device before dequantization.
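The device fix amounts to moving the lookup and shift tensors onto the values' device before indexing; a minimal illustration (hypothetical tensor names, not the repo's exact `_dequantize_values`):

```python
import torch


def dequantize_uint8(val: torch.Tensor, lookup: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
    # The fix: both helper tensors must live on val.device, otherwise
    # indexing raises a cross-device error when val sits on a CUDA rank.
    lookup = lookup.to(val.device)
    shift = shift.to(val.device)
    return lookup[val.long()] + shift


vals = torch.tensor([0, 2, 1], dtype=torch.uint8)
out = dequantize_uint8(vals, lookup=torch.tensor([0.0, 0.5, 1.0]), shift=torch.tensor(0.25))
# out == tensor([0.2500, 1.2500, 0.7500])
```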
Trainer: gradient norm logging
neurons/trainer.py
Capture total_grad_norm from clip_grad_norm_(...).item() during inner steps and log it; set grad_norm = 0.0 in spin-up branch (no optimizer step).
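Capturing the norm is a one-liner, since `clip_grad_norm_` returns the total gradient norm (computed before clipping); a minimal sketch with a toy model, not the repo's trainer:

```python
import torch
from torch.nn.utils import clip_grad_norm_

model = torch.nn.Linear(4, 2)
loss = model(torch.randn(8, 4)).sum()
loss.backward()

# clip_grad_norm_ clips in place and returns the pre-clip total norm,
# which is the value worth logging at each inner step.
total_grad_norm = clip_grad_norm_(model.parameters(), max_norm=1.0).item()
```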
BaseNode: init chain
neurons/base_node.py
Call super().__init__() inside BaseNode.__init__() to continue MRO initialization.
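Calling `super().__init__()` matters whenever `BaseNode` participates in multiple inheritance: without it, Python stops walking the MRO and sibling initializers are silently skipped. A minimal illustration with generic class names (not the repo's actual hierarchy):

```python
class Comms:
    def __init__(self):
        super().__init__()
        self.comms_ready = True


class BaseNode:
    def __init__(self):
        super().__init__()  # the fix: keep walking the MRO
        self.node_ready = True


class Miner(BaseNode, Comms):
    def __init__(self):
        super().__init__()


# MRO is Miner -> BaseNode -> Comms -> object, so both
# initializers run and both flags are set.
m = Miner()
```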

Sequence Diagram(s)

sequenceDiagram
    participant Miner
    participant Comms
    participant S3
    participant Neurons
    participant Validator

    rect rgba(230,240,250,0.9)
    note right of Miner: Initiate gather (distributed-capable)
    Miner->>Comms: gather_with_reserve(..., return_partials?)
    alt distributed path (world_size>1)
        Comms->>Comms: partition_uids_across_ranks()
        par per-rank fetch (concurrent)
            Comms->>S3: s3_get_object(uids_for_rank)
            S3-->>Comms: partial_result
        end
        Comms->>Comms: merge_gather_results(partials) [master]
        Comms-->>Miner: merged_result or [partials]
    else sequential path
        Comms->>S3: fetch all UIDs
        S3-->>Comms: merged_result
        Comms-->>Miner: merged_result
    end
    end

    rect rgba(245,240,230,0.9)
    note right of Miner: Post-gather processing and validation
    Miner->>Neurons: outer_step(gather_result, device=_resolve_cuda_device(...))
    alt gather_result is list (partials)
        Neurons->>Neurons: map UID -> partial.state_dict -> _outer_step_single
    else merged result
        Neurons->>Neurons: _outer_step_single()
    end
    Neurons->>Validator: compute_peer_val_norms(gather_result)
    Validator-->>Neurons: norms + gather metrics (success/skips)
    Miner->>Comms: broadcast decisions (master) / receive (workers)
    end
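The distributed path in the diagram above can be sketched end-to-end in plain Python; `fetch_fn`, `all_gather_fn`, and `merge_fn` are injected placeholders standing in for the real S3 fetch, `dist_helper.all_gather_object`, and `merge_gather_results`:

```python
def gather_distributed(uids, rank, world_size, fetch_fn, all_gather_fn, merge_fn):
    # 1) deterministic slice: every rank derives the same partition
    my_uids = sorted(uids)[rank::world_size]
    # 2) each rank downloads only its slice (concurrent across ranks)
    partial = fetch_fn(my_uids)
    # 3) a blocking all-gather doubles as a barrier; all ranks get all partials
    partials = all_gather_fn(partial)
    # 4) master merges; workers keep the partials for incremental processing
    return merge_fn(partials) if rank == 0 else partials


# Single-process demo with trivial stand-ins:
result = gather_distributed(
    [3, 1, 2], rank=0, world_size=1,
    fetch_fn=lambda uids: {"uids": uids},
    all_gather_fn=lambda p: [p],
    merge_fn=lambda ps: [u for p in ps for u in p["uids"]],
)
# result == [1, 2, 3]
```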

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Files/areas needing extra attention:
    • src/tplr/comms.py: partitioning determinism, partial merge semantics, device transfers, async/threaded S3 loads, retry logic and fallbacks.
    • neurons/miner.py & neurons/validator.py: broadcast/timestamp ordering, barrier usage, and correct propagation/interpretation of partial vs merged gather results.
    • src/tplr/neurons.py: per-UID state_dict mapping and index alignment for partials.
    • Memory-management sequences: correctness and potential performance/race impacts from aggressive GC/CUDA cache/sync and explicit gradient clearing.

Poem

🐇 I hop through ranks with UIDs in tow,
I fetch in parts and stitch them slow.
Caches swept and gradients light,
Partials merged — the training's bright.
A rabbit cheers: distributed flow!

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
Check name Status Explanation
Title check ✅ Passed The title 'feat: distributed gradient gathering for multi-GPU parallelism' directly and clearly describes the main feature introduced in this PR, which is enabling parallel gradient downloads across GPU ranks.
Description check ✅ Passed The PR description is comprehensive and well-structured with a clear summary, key features, technical details, and a breakdown of changes by file, though it does not follow the repository's description template exactly.
Docstring Coverage ✅ Passed Docstring coverage is 87.50% which is sufficient. The required threshold is 80.00%.



codecov bot commented Nov 6, 2025

Codecov Report

❌ Patch coverage is 20.32967% with 290 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/tplr/comms.py 16.77% 248 Missing ⚠️
src/tplr/neurons.py 38.70% 38 Missing ⚠️
src/tplr/compress.py 0.00% 4 Missing ⚠️

❌ Your patch check has failed because the patch coverage (20.32%) is below the target coverage (85.00%). You can increase the patch coverage or adjust the target coverage.
❌ Your project check has failed because the head coverage (55.08%) is below the target coverage (85.00%). You can increase the head coverage or adjust the target coverage.

Impacted file tree graph

@@            Coverage Diff             @@
##              dev     #649      +/-   ##
==========================================
- Coverage   57.89%   55.08%   -2.82%     
==========================================
  Files          27       27              
  Lines        4895     5223     +328     
==========================================
+ Hits         2834     2877      +43     
- Misses       2061     2346     +285     
Files with missing lines Coverage Δ
src/tplr/compress.py 80.28% <0.00%> (-0.69%) ⬇️
src/tplr/neurons.py 73.37% <38.70%> (-3.83%) ⬇️
src/tplr/comms.py 54.31% <16.77%> (-10.78%) ⬇️


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/tplr/comms.py (1)

2364-2380: Fix potential NameError when fallback is None.

Lines 2373-2374 reference fallback outside the if fallback: block (lines 2364-2372). If fallback is None, accessing fallback.upload_bytes and fallback.download_bytes will raise an AttributeError.

Apply this diff to fix the indentation:

                     if fallback:
                         # merge tensor‑lists inside the nested state_dict
                         for k, v in vars(fallback.state_dict).items():
                             merged = getattr(primary.state_dict, k, []) + v
                             setattr(primary.state_dict, k, merged)
 
                         primary.uids.extend(fallback.uids)
                         primary.global_steps.extend(fallback.global_steps)
                         primary.skipped_uids.extend(fallback.skipped_uids)
-                    primary.upload_bytes += fallback.upload_bytes
-                    primary.download_bytes += fallback.download_bytes
+                        primary.upload_bytes += fallback.upload_bytes
+                        primary.download_bytes += fallback.download_bytes
 
                     context_log(
                         message=f"[gather_with_reserve] ✅ reserve gather "
♻️ Duplicate comments (1)
src/tplr/comms.py (1)

2234-2234: Move dist_helper import to module level.

Same as the issue in gather_distributed: dist_helper should be imported at the module level rather than inside the function.

🧹 Nitpick comments (5)
hparams/2B.json (1)

1-26: Add trailing newline to JSON file.

The file is missing a trailing newline, which is a common convention for text files and may cause issues with some tooling.

Add a newline after the closing brace on line 26.

src/tplr/neurons.py (2)

226-293: Consider documenting the memory_budget_mb asymmetry.

The memory_budget_mb parameter is passed to _outer_step_incremental but not to _outer_step_single. While this is intentional (single-result processing doesn't need budget tracking), it would be helpful to document this in the docstring to clarify when the budget is actually enforced.

Update the docstring to clarify:

     Args:
         gather_result: Either a single SimpleNamespace or list of partial results
-        memory_budget_mb: Optional memory budget in MB (not enforced yet, future work)
+        memory_budget_mb: Optional memory budget in MB (enforced only for list of partials, ignored for single results)

471-473: Consider element-weighted averaging for means.

The means are averaged uniformly across partials, but this may not accurately represent the combined mean if partials have different numbers of elements per parameter. Consider weighting by element count for more accurate statistics.

Track element counts per partial and compute weighted average:

             for param_name, mean in partial_fingerprint.get("param_means", {}).items():
                 if param_name not in combined_fingerprint["param_means"]:
-                    combined_fingerprint["param_means"][param_name] = []
-                combined_fingerprint["param_means"][param_name].append(mean)
+                    combined_fingerprint["param_means"][param_name] = {"sum": 0.0, "count": 0}
+                elem_count = partial_fingerprint.get("param_elem_counts", {}).get(param_name, 1)
+                combined_fingerprint["param_means"][param_name]["sum"] += mean * elem_count
+                combined_fingerprint["param_means"][param_name]["count"] += elem_count

Then finalize:

         # Average the means
         for param_name in combined_fingerprint["param_means"]:
-            means_list = combined_fingerprint["param_means"][param_name]
-            combined_fingerprint["param_means"][param_name] = sum(means_list) / len(means_list)
+            stats = combined_fingerprint["param_means"][param_name]
+            combined_fingerprint["param_means"][param_name] = stats["sum"] / stats["count"] if stats["count"] > 0 else 0.0
neurons/validator.py (1)

91-132: Consider adding input validation to helper methods.

The helper methods assume gather_result is either a list or has attributes like uids, skipped_uids, success_rate. If an unexpected type is passed (e.g., None when it shouldn't be, or a different object type), the methods will fail with unclear errors or return incorrect results.

Consider adding type validation at the start of each helper:

 @staticmethod
 def _get_uids_from_gather_result(gather_result) -> list[int]:
     """Extract UIDs from either partials (list) or merged result (SimpleNamespace)."""
+    if gather_result is None:
+        return []
+    if not isinstance(gather_result, (list, SimpleNamespace)):
+        tplr.logger.warning(f"Unexpected gather_result type: {type(gather_result)}")
+        return []
     if isinstance(gather_result, list):

Apply similar guards to the other two helpers.

src/tplr/comms.py (1)

1982-1984: Consider using config device instead of hardcoding cuda:0.

In multi-GPU setups, the primary device may not be cuda:0. Consider defaulting to self.config.device if available, or requiring the caller to specify the target device explicitly.

Apply this diff to use a more flexible default:

         # Determine target device for tensor consolidation
         if target_device is None:
-            target_device = "cuda:0" if torch.cuda.is_available() else "cpu"
+            # Default to primary CUDA device or CPU
+            target_device = f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu"
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 080905b and 98d3bb3.

📒 Files selected for processing (5)
  • hparams/2B.json (1 hunks)
  • neurons/miner.py (2 hunks)
  • neurons/validator.py (10 hunks)
  • src/tplr/comms.py (8 hunks)
  • src/tplr/neurons.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
neurons/miner.py (1)
src/tplr/comms.py (1)
  • gather (1607-1929)
src/tplr/neurons.py (1)
src/tplr/compress.py (5)
  • compress (373-377)
  • compress (379-383)
  • compress (386-423)
  • ChunkingTransformer (145-292)
  • TopKCompressor (295-696)
src/tplr/comms.py (2)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • all_gather_object (307-325)
neurons/validator.py (5)
src/tplr/distributed.py (1)
  • is_distributed (75-77)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
neurons/base_node.py (1)
  • query_block_timestamp (202-247)
src/tplr/comms.py (2)
  • gather_with_reserve (2199-2423)
  • merge_gather_results (1961-2070)
src/tplr/neurons.py (1)
  • check_uid_index_overlap (1493-1636)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.11)
  • GitHub Check: test (3.12)
🔇 Additional comments (12)
neurons/miner.py (2)

707-725: LGTM: Sequential gather with comprehensive parameters.

The updated gather call correctly uses the expanded parameter set that aligns with the new distributed gathering capabilities. The logging provides clear visibility into the gather process.


728-738: Good: Enhanced logging for gather results.

The post-gather logging provides valuable insights including success rate, counts, and timing, which will aid in monitoring and debugging the distributed gathering process.

src/tplr/neurons.py (4)

427-431: LGTM: Correct L2 norm accumulation.

The squared norms are correctly accumulated and then square-rooted at the end (line 466-468) to compute the combined L2 norm. This is the mathematically correct way to combine norms from multiple sources.


516-769: LGTM: Original logic correctly preserved in separate function.

The _outer_step_single function correctly preserves the original outer step implementation. The refactoring cleanly separates single-result and incremental processing without changing the core gradient application logic.


488-501: Potential division by zero in utilization calculation.

Line 490 calculates utilization_pct but doesn't guard against memory_budget_mb being zero or negative before the division.

Apply this fix:

             final_memory_mb = torch.cuda.memory_allocated(device) / (1024 ** 2)
-            utilization_pct = (final_memory_mb / memory_budget_mb * 100) if memory_budget_mb > 0 else 0
+            utilization_pct = (final_memory_mb / memory_budget_mb * 100) if memory_budget_mb > 0 else 0.0
             
             if final_memory_mb > memory_budget_mb:

Actually, the current code already has the guard if memory_budget_mb > 0, so this is already correct. However, consider also handling negative budgets:

-            utilization_pct = (final_memory_mb / memory_budget_mb * 100) if memory_budget_mb > 0 else 0
+            utilization_pct = (final_memory_mb / memory_budget_mb * 100) if memory_budget_mb > 0 else 0.0

Likely an incorrect or invalid review comment.


386-403: Consider whether defensive programming for circular references and deep recursion is necessary for this specific use case.

The review comment raises theoretically valid concerns, but verification shows they are not grounded in observed issues:

  • Circular references: No circular-structure patterns found in test payloads or real usage. PyTorch state_dicts are typically OrderedDicts mapping to Tensors/buffers where circular references are uncommon, and the actual codebase uses flat SimpleNamespace attributes populated from dict payloads.
  • Deep nesting: Real data shows max 1-2 levels of nesting (e.g., [torch.tensor(...)]), not the deeply nested structures that would trigger recursion limits.
  • Custom objects: Already handled safely via the else clause, which returns objects unchanged.

The suggested safeguards are defensive but add complexity. The current code is safe for the documented use case. If the structure is guaranteed to remain shallow and contain only basic types (tensors, lists, dicts, metadata), the addition is unnecessary. Verify with the maintainers whether this assumption holds before adding the overhead.

hparams/2B.json (1)

22-22: This review comment is incorrect and based on misunderstandings of the codebase.

The gradient_memory_budget_mb parameter is marked as optional and currently unused. The function docstring in src/tplr/neurons.py:323 explicitly states: "Optional memory budget in MB (future: early stopping if exceeded)", indicating this is placeholder functionality that hasn't been implemented yet. There are no budget violation warnings or enforcement mechanisms in the code.

Additionally, the memory calculation in the review is based on incorrect assumptions about compression. The compression uses a count-based topk parameter (absolute element count per layer), not the 3% sparsity ratio assumed in the calculation. The parameter appears only in the 2B model configuration and is not present in any other hparams files (150M.json, 1B.json, etc.), suggesting it's experimental or model-specific rather than a standard configuration concern.

Likely an incorrect or invalid review comment.

neurons/validator.py (3)

1382-1396: LGTM: Proper peer list broadcasting for distributed gather.

The peer list broadcast from master to all ranks is correctly implemented. The use of dist.broadcast_object_list ensures all ranks have the same peer lists for coordinated distributed gathering.


1530-1577: LGTM: Clean distributed vs sequential gather implementation.

The conditional gather execution properly handles both distributed and sequential modes with appropriate rank coordination and logging. The distinction between modes is clear and well-documented in the logs.


2501-2501: LGTM: Proper memory budget parameter propagation.

The memory_budget_mb parameter is correctly passed from hparams to outer_step, enabling memory-bounded incremental processing. The use of getattr with a None default ensures backward compatibility.

src/tplr/comms.py (2)

1354-1375: LGTM: Backward-compatible API improvement.

The optional version parameter with sensible default improves ergonomics while maintaining backward compatibility.


1934-1958: LGTM: Deterministic round-robin partitioning.

The implementation correctly uses sorting for determinism and round-robin slicing for balanced distribution across ranks.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
src/tplr/comms.py (1)

2100-2101: Move imports to module level.

Importing modules inside a function is an anti-pattern that impacts performance and readability. The time module is already imported at line 26, and dist_helper should be imported at the module level for consistency.

This issue was previously flagged in past review comments but has not been addressed.

🧹 Nitpick comments (2)
neurons/validator.py (1)

1439-1454: Time window calculation runs on all ranks.

The time window calculation (including query_block_timestamp call) runs on all ranks because distributed gather requires all ranks to participate with consistent time bounds. While query_block_timestamp might involve network calls, this is necessary for correctness.

If query_block_timestamp becomes a performance bottleneck, consider calculating on master and broadcasting:

# Master calculates timestamp
if self.is_master:
    ts_value = self.query_block_timestamp(sync_block)
    if ts_value is None:
        tplr.log_with_context(...)
        ts_value = time.time()
    
# Broadcast to all ranks
ts_tensor = torch.tensor([ts_value], dtype=torch.float64, device=self.device)
dist_helper.broadcast(ts_tensor, src=0)
ts_value = ts_tensor.item()

# All ranks compute time window from broadcasted timestamp
time_min = datetime.fromtimestamp(ts_value, tz=timezone.utc)
time_max = time_min + timedelta(seconds=self.hparams.time_window_delta_seconds)
src/tplr/neurons.py (1)

1598-1652: Consider caching partial-to-UID mapping for efficiency.

The nested loop (Lines 1615-1619) searches for the partial containing uids[i] on every iteration, resulting in O(Ptot × num_partials) complexity per parameter. While correct, this could be optimized by building a uid_to_partial_info mapping once during the initial traversal at Lines 1540-1554.

For typical deployments with 2-4 GPUs, the performance impact is negligible, so this optimization is optional. The current implementation prioritizes clarity over micro-optimization.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 98d3bb3 and 0ce09f2.

📒 Files selected for processing (4)
  • hparams/2B.json (1 hunks)
  • neurons/validator.py (11 hunks)
  • src/tplr/comms.py (8 hunks)
  • src/tplr/neurons.py (8 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • hparams/2B.json
🧰 Additional context used
🧬 Code graph analysis (3)
src/tplr/comms.py (2)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • all_gather_object (307-325)
src/tplr/neurons.py (1)
src/tplr/compress.py (6)
  • compress (373-377)
  • compress (379-383)
  • compress (386-423)
  • ChunkingTransformer (145-292)
  • TopKCompressor (295-696)
  • unpack_12bit_indices (102-142)
neurons/validator.py (6)
src/tplr/distributed.py (1)
  • is_distributed (75-77)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
neurons/base_node.py (1)
  • query_block_timestamp (202-247)
src/tplr/comms.py (2)
  • gather_with_reserve (2199-2455)
  • merge_gather_results (1961-2070)
src/tplr/neurons.py (1)
  • check_uid_index_overlap (1524-1733)
src/tplr/compress.py (1)
  • maybe_dequantize_values (646-696)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.11)
  • GitHub Check: test (3.12)
🔇 Additional comments (16)
neurons/validator.py (5)

91-132: LGTM: Clean abstraction for partial vs merged gather results.

These static helper methods provide a clean, maintainable way to extract UIDs, skipped UIDs, and success rate from either partial or merged gather results. The implementation correctly handles both data structures and simplifies downstream code.

Minor note: Line 116 deduplicates skipped UIDs using set(), which loses ordering. This is acceptable since skipped UIDs don't require specific ordering.


1382-1396: Correct distributed synchronization for peer lists.

Broadcasting the peer lists from master to all ranks is necessary for distributed gather where each rank needs to know which peers to fetch from. The synchronization happens before gather is invoked (line 1534), ensuring all ranks have consistent peer information.


1507-1583: Well-structured distributed gather integration.

The distributed gather mode detection and handling is clean and consistent. Distributed gather returns partials for memory-efficient incremental processing, while sequential gather returns merged results for traditional processing. The logging appropriately handles both cases using the helper methods.


1618-1640: Excellent improvement over past review comments.

The past review comments flagged three separate merges (upload, overlap check, norms) creating unnecessary memory overhead. The current implementation addresses this beautifully:

  • Overlap check (line 1632): Now handles partials directly, no merge needed ✅
  • Norms computation (line 1840): Now handles partials directly, no merge needed ✅
  • Upload (lines 1618-1623): Still requires merge for artifact format

By making overlap check and norms partial-aware, you've eliminated 2 out of 3 merges. The remaining upload merge is acceptable since it's async, happens in the background, and artifacts legitimately need merged format. This is actually better than the suggested fix in the past comments!


2494-2494: Clean integration of memory budget configuration.

Passing gradient_memory_budget_mb from hparams to outer_step enables memory-bounded incremental processing of partials, as described in the PR objectives. The getattr with None default provides a safe fallback if the configuration isn't present.

src/tplr/comms.py (6)

1354-1387: LGTM!

The signature change to make version optional with explicit None handling maintains backward compatibility while providing flexibility for callers.


1931-1959: LGTM!

The round-robin partitioning logic is correct and deterministic. Sorting the UIDs ensures consistent ordering across all ranks, and the slice notation correctly implements the round-robin assignment.


1960-2071: Verify parallel time tracking semantics.

The merge logic correctly uses max(total_time, ...) on Line 2044, which is appropriate for distributed parallel execution where partials are fetched simultaneously. However, ensure downstream consumers understand this represents wall-clock time, not cumulative processing time.

The state_dict merging logic (Lines 2028-2040) correctly handles both list and non-list param_data, and the device consolidation recursively moves all tensors to the target device.
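A simplified sketch of the merge semantics described here (the `SimpleNamespace` result shape is assumed from the surrounding review; note `max` for wall-clock time and list concatenation for per-UID tensor lists):

```python
from types import SimpleNamespace


def merge_gather_results(partials):
    """Merge per-rank partial gather results into one (simplified sketch)."""
    if not partials:
        return None
    merged = SimpleNamespace(uids=[], skipped_uids=[], total_time=0.0,
                             state_dict=SimpleNamespace())
    for p in sorted(partials, key=lambda p: p.rank):  # deterministic order
        merged.uids.extend(p.uids)
        merged.skipped_uids.extend(p.skipped_uids)
        # Ranks fetched concurrently, so elapsed time is the max, not the sum.
        merged.total_time = max(merged.total_time, p.total_time)
        for key, tensors in vars(p.state_dict).items():
            setattr(merged.state_dict, key,
                    getattr(merged.state_dict, key, []) + tensors)
    return merged
```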


2072-2195: Verify all-gather synchronization semantics.

The distributed gather coordination is well-structured. Line 2153 correctly uses synchronous all_gather_object as a blocking barrier operation, ensuring all ranks have consistent partial results before proceeding.

Confirm that the return_partials flag behavior is correctly understood by callers: when True, each rank receives a list of all partials (not just its own), enabling incremental processing.


2255-2310: Robust distributed gather with fallback.

The fallback logic is well-designed:

  1. Tries distributed gather when configured and available
  2. Falls back to sequential on master rank if distributed fails or returns no gradients (Lines 2278-2299)
  3. Provides clear logging for debugging

The check on Line 2278 correctly handles both None results and empty partial lists, ensuring the fallback triggers when no gradients are collected.


2352-2456: Verify reserve peer handling for partials.

The reserve peer fallback correctly handles both partials and merged results:

  • For partials (Lines 2383-2395): appends fallback as a new partial with proper rank assignment
  • For merged results (Lines 2397-2412): merges state_dict lists and extends metadata

Lines 2417-2440 compute aggregate success rate across all partials. Verify that this matches the expected semantics for downstream consumers who may expect per-partial success rates.

src/tplr/neurons.py (5)

226-293: LGTM!

The dispatcher pattern cleanly separates incremental and single-result processing paths. The conditional routing based on isinstance(gather_result, list) is clear and the memory_budget_mb parameter is correctly passed only to the incremental path that uses it.


295-545: Thorough incremental processing with robust error recovery.

The implementation correctly handles:

  1. Deterministic ordering (Line 359): Sorts partials by rank for reproducibility
  2. Multi-GPU tensor consolidation (Lines 391-407): Critical fix that moves tensors from different GPU devices (cuda:0, cuda:1, etc.) to the target model device
  3. Duplicate prevention (Lines 375-381): Tracks applied UIDs to skip duplicates
  4. L2 norm accumulation (Lines 431-435, 484-487): Correctly accumulates squared norms and takes sqrt at the end
  5. Graceful degradation (Lines 465-479): Catches exceptions, logs failures, continues processing remaining partials
  6. Memory management (Lines 451-464): Aggressive cleanup with budget monitoring

The error recovery logs failures (Lines 473-477, 507-516) without breaking the entire batch, which is appropriate for production resilience.
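Point 4's norm accumulation is worth spelling out, since it is only correct with squared norms: concatenating gradient shards satisfies ||concat(g_1, …, g_k)||₂ = sqrt(Σᵢ ||gᵢ||₂²). A minimal sketch:

```python
# Accumulate squared norms per partial, take sqrt once at the end —
# the pattern credited above for lines 431-435 and 484-487.
import math

def accumulate_grad_norm(partial_grads):
    sq_sum = 0.0
    for grads in partial_grads:  # one flat list of gradient values per partial
        sq_sum += sum(g * g for g in grads)
    return math.sqrt(sq_sum)

# A 3-4-5 check: per-partial norms 3 and 4 combine to 5, not 7.
assert accumulate_grad_norm([[3.0], [4.0]]) == 5.0
```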


547-801: LGTM!

The extraction of the original single-result logic into _outer_step_single is clean. The functionality remains unchanged, and the updated docstring correctly describes it as the "original single-result outer step implementation."


1539-1559: Correct partial vs. merged result handling.

The logic correctly distinguishes between partials and merged results:

  • Partials (Lines 1540-1554): Each UID maps to its partial's state_dict, enabling per-UID data extraction
  • Merged (Lines 1556-1559): All UIDs share the same merged state_dict, which already contains aggregated gradients

The design maintains backward compatibility while supporting the new incremental partial processing path.
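The two modes can be captured in one small mapping helper (structures are illustrative; the real code works with `SimpleNamespace` results, mirrored here):

```python
# Partials mode: each UID maps to its own partial's state_dict.
# Merged mode: every UID shares the single aggregated state_dict.
from types import SimpleNamespace

def uid_state_dicts(gather_result):
    if isinstance(gather_result, list):          # partials mode
        mapping = {}
        for partial in gather_result:
            if partial is None:
                continue
            for uid in getattr(partial, "uids", []):
                mapping[uid] = partial.state_dict
        return mapping
    # merged mode: all UIDs point at the same aggregated state_dict
    return {uid: gather_result.state_dict
            for uid in getattr(gather_result, "uids", [])}

p0 = SimpleNamespace(uids=[1, 5], state_dict={"src": "rank0"})
p1 = SimpleNamespace(uids=[2], state_dict={"src": "rank1"})
merged = SimpleNamespace(uids=[1, 2, 5], state_dict={"src": "merged"})

assert uid_state_dicts([p0, p1])[5] == {"src": "rank0"}
sd = uid_state_dicts(merged)
assert sd[1] is sd[5]  # shared object, not a copy
```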


1662-1677: Verify pairwise comparison uses correct UID indices.

The comparison correctly uses valid_uids_for_param (Lines 1673-1675) to map local parameter indices back to the original UID indices in the global uids list. This ensures pair keys like (uid_idx_i, uid_idx_j) correctly identify peer pairs even when some peers are missing the parameter.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
scripts/cleanup_bucket.py (1)

33-33: Update the docstring to reflect all deleted prefixes.

The docstring mentions only three prefixes but the implementation now deletes six types of files.

Apply this diff to update the docstring:

-    """Delete objects in the R2 bucket that start with 'checkpoint', 'gradient', or 'start_window'"""
+    """Delete objects in the R2 bucket that start with 'checkpoint', 'gradient', 'start_window', 'aggregator', 'debug', or 'peers'"""
neurons/validator.py (1)

4398-4407: Fix slash multiplier when gather success rate is low to preserve scores instead of zeroing them.

The slash_multiplier is applied multiplicatively to scores (line 4420: self.final_scores[uid] *= slash_multiplier). When success_rate is low, setting slash_multiplier to 0.0 wipes positive scores to zero. Since the goal is to avoid slashing when gather is unreliable, use 1.0 instead to preserve scores.

                if success_rate < self.hparams.gather_peers_slash_threshold:
-                    slash_multiplier = 0.0  # Set to 0 if success rate is low
+                    slash_multiplier = 1.0  # No slashing when gather reliability is low
                 else:
♻️ Duplicate comments (3)
neurons/validator.py (1)

3850-3888: Fix vals/quant_params misalignment in partials path (can IndexError).

When a partial lacks quant_params (or lengths differ), all_vals and all_quant_params diverge, breaking maybe_dequantize_values. Pad/trim qparams per-partial to match appended vals.

-                    if vals is not None:
-                        # vals is a list of tensors, one per peer in this partial
-                        # Move all tensors (including nested) to target device
-                        for v in vals:
-                            all_vals.append(move_to_device_recursive(v))
-                        # quant_params is also a list if present
-                        if quant_params is not None:
-                            for qp in quant_params:
-                                all_quant_params.append(move_to_device_recursive(qp))
+                    if vals is not None:
+                        num_vals = len(vals)
+                        for v in vals:
+                            all_vals.append(move_to_device_recursive(v))
+                        # Align qparams count with vals
+                        if quant_params is None:
+                            all_quant_params.extend([None] * num_vals)
+                        else:
+                            qps = list(quant_params)
+                            if len(qps) < num_vals:
+                                qps = qps + [None] * (num_vals - len(qps))
+                            elif len(qps) > num_vals:
+                                qps = qps[:num_vals]
+                            for qp in qps:
+                                all_quant_params.append(move_to_device_recursive(qp))
@@
-                vals_f32 = self.compressor.maybe_dequantize_values(
-                    all_vals,
-                    all_quant_params if all_quant_params else None,
-                    target_device,
-                )
+                vals_f32 = self.compressor.maybe_dequantize_values(
+                    all_vals, all_quant_params, target_device
+                )
src/tplr/comms.py (2)

2181-2182: Move imports to module level.

Importing modules inside a function impacts performance and readability. The time module is already imported at line 26, and dist_helper should be imported at the module level for consistency.

Apply this diff:

 from types import SimpleNamespace
 from typing import Any, Literal, cast
 
 import aiofiles
 import bittensor as bt
 import boto3
 import botocore
 import torch
 from aiobotocore.client import AioBaseClient
 from aiobotocore.session import get_session
 from boto3.s3.transfer import TransferConfig
 from botocore.exceptions import ClientError, ConnectionClosedError
 from tqdm import tqdm as std_tqdm
 
 import tplr
 from tplr.chain import ChainManager
 from tplr.compress import TopKCompressor, unpack_12bit_indices
 from tplr.config import BUCKET_SECRETS, client_config
 from tplr.schemas import Bucket, CommsGetResult
+from tplr.distributed import dist_helper

Then remove the import from lines 2181-2182:

-        import time
-        from tplr.distributed import dist_helper
-        
         window = kwargs.get("window", None)

Based on learnings


2317-2317: Move dist_helper import to module level.

Similar to line 2182, dist_helper is imported locally here. This should be at module level for consistency and performance, as noted in the previous comment and past review.

🧹 Nitpick comments (8)
scripts/cleanup_bucket.py (1)

93-93: Verify that "training files" accurately describes all deleted prefixes.

The updated messages use the generic term "training files" instead of enumerating specific prefixes, which is more maintainable. However, ensure that all six prefixes ("checkpoint", "gradient", "start_window", "aggregator", "debug", "peers") are actually training-related files. Files with "debug" or "peers" prefixes may serve different purposes.

Also applies to: 114-114, 123-123

neurons/validator.py (4)

95-134: Deduplicate UIDs when aggregating across partials (affects success_rate).

Duplicate UIDs across partials will inflate counts. Dedup before computing metrics.

Apply:

@@
-            all_uids = []
+            all_uids: list[int] = []
@@
-            return all_uids
+            return list(set(all_uids))
@@
-            total_uids = len(Validator._get_uids_from_gather_result(gather_result))
-            total_skipped = len(
+            total_uids = len(set(Validator._get_uids_from_gather_result(gather_result)))
+            total_skipped = len(set(
                 Validator._get_skipped_uids_from_gather_result(gather_result)
-            )
+            ))

138-202: Handle list-of-partials in offload_gather_results to actually save GPU RAM.

The method is currently a no-op for list inputs; extend it to offload each partial's state_dict.

 def offload_gather_results(self, gather_result, *, log: bool = True) -> None:
@@
-        if gather_result is None:
+        if gather_result is None:
             return
@@
-        if getattr(gather_result, "_offloaded", None) == "cpu":
-            return
+        if isinstance(gather_result, list):
+            # Offload each partial independently
+            moved_total = 0
+            for part in gather_result:
+                if part is None or getattr(part, "_offloaded", None) == "cpu":
+                    continue
+                moved_total += self._offload_one_namespace(part)
+            if moved_total and log:
+                tplr.log_with_context(
+                    level="info",
+                    message=f"[GatherOffload] partials → CPU pinned | moved ~{moved_total / 1e6:.1f} MB",
+                    sync_window=self.sync_window,
+                    current_window=self.current_window,
+                )
+            return
+        if getattr(gather_result, "_offloaded", None) == "cpu":
+            return
@@
-        moved_bytes = 0
+        moved_bytes = 0
@@
-        state = getattr(gather_result, "state_dict", None)
+        state = getattr(gather_result, "state_dict", None)
         if state is None:
             return
@@
-        setattr(gather_result, "_offloaded", "cpu")
+        setattr(gather_result, "_offloaded", "cpu")
@@
         if log:
             tplr.log_with_context(
                 level="info",
                 message=(
                     "[GatherOffload] aggregator → CPU pinned | moved ~"
                     f"{moved_bytes / 1e6:.1f} MB in {time.time() - t0:.3f}s"
                 ),
                 sync_window=self.sync_window,
                 current_window=self.current_window,
             )
+
+    def _offload_one_namespace(self, ns: SimpleNamespace) -> int:
+        moved_bytes = 0
+        state = getattr(ns, "state_dict", None)
+        if state is None:
+            return 0
+        def _to_pinned_cpu(obj):
+            nonlocal moved_bytes
+            if torch.is_tensor(obj):
+                if obj.device.type == "cpu":
+                    return obj.pin_memory() if not obj.is_pinned() else obj
+                dst = torch.empty_like(obj, device="cpu", pin_memory=True)
+                dst.copy_(obj, non_blocking=True)
+                moved_bytes += obj.element_size() * obj.numel()
+                return dst
+            if isinstance(obj, list):
+                return [_to_pinned_cpu(x) for x in obj]
+            if isinstance(obj, tuple):
+                return tuple(_to_pinned_cpu(x) for x in obj)
+            if isinstance(obj, dict):
+                return {k: _to_pinned_cpu(v) for k, v in obj.items()}
+            if isinstance(obj, SimpleNamespace):
+                for k, v in vars(obj).items():
+                    setattr(obj, k, _to_pinned_cpu(v))
+                return obj
+            return obj
+        if isinstance(state, SimpleNamespace):
+            _to_pinned_cpu(state)
+        elif isinstance(state, dict):
+            for k, v in list(state.items()):
+                state[k] = _to_pinned_cpu(v)
+        setattr(ns, "_offloaded", "cpu")
+        return moved_bytes

2490-2507: Standardize device argument to string for outer_step.

Align with Miner and Comms usage for predictability across loaders/mappers.

-                device=cast(str, self.device),
+                device=str(self.device),

1539-1555: Use str(self.device) instead of cast(str, self.device) for consistency with Miner.

The gather() method at src/tplr/comms.py:1655 explicitly types the device parameter as str. Miner uses str(self.device) when calling the same method, making this a consistency issue. The cast() form is a type-checking construct and doesn't perform runtime conversion; str() is more explicit and Pythonic.

-                    device=cast(str, self.device),
+                    device=str(self.device),
src/tplr/neurons.py (3)

477-484: Memory budget check is informational only.

The memory budget check logs a warning but continues processing. This aligns with the docstring stating "not enforced yet, future work", but consider adding a configuration flag to make this check fail-fast when memory constraints are critical.
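One way such a flag could look (names and shape are hypothetical, not the project's config): the check stays a warning by default and raises only when strict mode is enabled.

```python
# Sketch of a fail-fast memory budget check behind a `strict` flag.
def check_memory_budget(used_mb, budget_mb, *, strict=False, log=print):
    if budget_mb is None or used_mb <= budget_mb:
        return True
    msg = f"memory budget exceeded: {used_mb:.0f} MB > {budget_mb:.0f} MB"
    if strict:
        raise MemoryError(msg)
    log(f"WARNING: {msg} (continuing; strict=False)")
    return False

assert check_memory_budget(400, 512) is True
assert check_memory_budget(600, 512) is False   # warns, keeps processing
try:
    check_memory_budget(600, 512, strict=True)
except MemoryError:
    pass
else:
    raise AssertionError("expected MemoryError in strict mode")
```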


1571-1591: Optimize state_dict storage for partials.

For partials, the code stores state_dict references for each UID (lines 1582-1583), creating O(n) duplicate references where n is the total number of UIDs. Consider building a uid_to_partial mapping once upfront to avoid redundant storage and improve the lookup performance in the unpacking phase (lines 1645-1654).

Apply this pattern:

# Build mapping once upfront
uid_to_partial: dict[int, tuple[SimpleNamespace, int]] = {}
for partial in gather_result:
    if partial is not None:
        partial_uids = getattr(partial, "uids", [])
        for idx, uid in enumerate(partial_uids):
            uid_to_partial[uid] = (partial, idx)

uids = list(uid_to_partial.keys())

# Then in unpacking loop (around line 1645):
partial_for_uid, uid_idx_in_partial = uid_to_partial.get(uids[i], (None, None))
if partial_for_uid is None:
    continue

1645-1667: Nested search is O(partials) per UID per parameter.

The code searches through all partials for each UID on each parameter (lines 1645-1654). With many partials, parameters, and UIDs, this becomes expensive. The suggested uid_to_partial mapping in the previous comment would eliminate this repeated search.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0ce09f2 and e73a41b.

📒 Files selected for processing (5)
  • neurons/miner.py (2 hunks)
  • neurons/validator.py (11 hunks)
  • scripts/cleanup_bucket.py (3 hunks)
  • src/tplr/comms.py (10 hunks)
  • src/tplr/neurons.py (8 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
neurons/miner.py (1)
src/tplr/comms.py (1)
  • gather (1648-2010)
neurons/validator.py (6)
src/tplr/distributed.py (1)
  • is_distributed (75-77)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
neurons/base_node.py (1)
  • query_block_timestamp (202-247)
src/tplr/comms.py (2)
  • gather_with_reserve (2282-2547)
  • merge_gather_results (2042-2151)
src/tplr/neurons.py (1)
  • check_uid_index_overlap (1556-1769)
src/tplr/compress.py (1)
  • maybe_dequantize_values (646-696)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
  • Bucket (24-46)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • all_gather_object (307-325)
src/tplr/neurons.py (1)
src/tplr/compress.py (6)
  • compress (373-377)
  • compress (379-383)
  • compress (386-423)
  • ChunkingTransformer (145-292)
  • TopKCompressor (295-696)
  • unpack_12bit_indices (102-142)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.11)
  • GitHub Check: test (3.12)
🔇 Additional comments (19)
neurons/miner.py (3)

707-709: LGTM: clearer gather start log.

The updated sequential-gather log is helpful and low-cost.


728-739: LGTM: post-gather summary logging.

Concise success/skip breakdown improves ops visibility.


711-725: Original review comment is incorrect—expected_compressed_params IS initialized.

Miner inherits Trainer.init_model(), which is called at line 200 of Miner.__init__() with meta=True. Trainer.init_model() unconditionally executes self.expected_compressed_params = self.get_expected_params() at line 123, regardless of the meta parameter. The attribute is properly initialized before the gather() call at line 724, so no AttributeError will occur.

neurons/validator.py (3)

1557-1593: LGTM: distributed/seq gather completion logs.

Good differentiation between partials vs merged with clear KPIs.


1386-1399: LGTM: broadcast peers to all ranks.

Keeps ranks aligned for distributed gather without extra RPCs.


1446-1460: LGTM: shared time window computation.

Consistent time_min/max across ranks; sensible fallback to current time.

src/tplr/neurons.py (3)

226-293: LGTM – Clean delegation pattern.

The refactoring to delegate based on input type (single vs. list of partials) is well-structured. All parameters are correctly forwarded, and the delegation logic is clear.


486-501: LGTM – Graceful error recovery.

The error handling approach is robust: logs the error, tracks failed partials with metadata, and continues processing remaining partials. This prevents a single bad partial from blocking the entire gradient application.


443-465: Verify approach to averaging parameter means across partials.

The code accumulates param_means from each partial as a list, then computes a simple arithmetic average (lines 512–516). This is mathematically correct only if all partials contribute parameters with equal weight. If partials can have different parameter distributions or varying numbers of elements per parameter, the mean should be weighted by parameter size. However, the fingerprint structure doesn't track per-parameter element counts across partials, making weighted averaging difficult to implement with the current design.

Clarify whether:

  • Partials always include the complete parameter set (non-sparse)
  • If sparse, whether weighted averaging by total_elements is needed or if simple average is semantically correct for your use case
src/tplr/comms.py (10)

477-517: LGTM – Conservative error handling prevents false slashing.

The implementation correctly uses HEAD requests to check existence without downloading. The conservative approach of returning True for non-404 errors (lines 511-516) is appropriate to prevent falsely penalizing miners due to transient network issues.
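The conservative decision rule can be sketched independently of S3 (the helper name and plain status-code input are illustrative, not the project's actual API): only a definitive 404 counts as "missing"; every other outcome assumes the object exists, so transient failures never penalize a miner.

```python
# Conservative interpretation of a HEAD-style existence probe.
def object_exists_from_head(status_code):
    if status_code == 200:
        return True    # object definitely exists
    if status_code == 404:
        return False   # definitively missing -> attributable to the miner
    return True        # 403/5xx/timeout etc.: assume it exists, don't slash

assert object_exists_from_head(200) is True
assert object_exists_from_head(404) is False
assert object_exists_from_head(503) is True   # transient error: no penalty
```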


1395-1427: LGTM – Flexible version parameter.

The signature change from version: str = tplr.__version__ to version: str | None = None with runtime defaulting (lines 1415-1416) is more flexible and maintains backward compatibility.


2015-2040: LGTM – Simple and deterministic partitioning.

The round-robin assignment with deterministic sorting ensures all ranks agree on the partition. The implementation is straightforward and correct.
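A standalone sketch of that scheme (function name is illustrative): sort once, then assign by index modulo world_size, so every rank derives the same disjoint split with no communication.

```python
# Deterministic round-robin UID partitioning across ranks.
def partition_uids(uids, rank, world_size):
    ordered = sorted(uids)
    return [uid for i, uid in enumerate(ordered) if i % world_size == rank]

uids = [9, 2, 5, 1]
parts = [partition_uids(uids, r, 2) for r in range(2)]
assert parts == [[1, 5], [2, 9]]
# Disjoint and complete: the shards cover every UID exactly once.
assert sorted(parts[0] + parts[1]) == sorted(uids)
```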


2093-2121: LGTM – Device consolidation during merge is correct.

The merge logic correctly concatenates state_dict lists from all partials while moving tensors to the target device (lines 2115-2116). This device consolidation is essential for multi-GPU scenarios where different ranks may have tensors on different devices.


2132-2134: Success rate calculation is partial-aware.

The success rate correctly considers both all_uids (successful) and all_skipped_uids (failed) from all partials, giving an accurate overall success rate rather than just the rate from individual partials.
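In sketch form (dict-shaped partials stand in for the real namespaces): count successes and skips across all partials before dividing, rather than rating each partial in isolation.

```python
# Aggregate success rate over all partials.
def aggregate_success_rate(partials):
    ok = sum(len(p.get("uids", [])) for p in partials)
    skipped = sum(len(p.get("skipped_uids", [])) for p in partials)
    total = ok + skipped
    return ok / total if total else 0.0

partials = [
    {"uids": [1, 5], "skipped_uids": []},
    {"uids": [2], "skipped_uids": [7]},
]
# 3 successes out of 4 attempts overall, even though the second
# partial alone would report only 50%.
assert aggregate_success_rate(partials) == 0.75
```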


2244-2277: LGTM – Return path logic handles both merged and partial modes.

The conditional return based on return_partials flag is clear and correct. When returning partials (lines 2244-2260), rank metadata is added for downstream deterministic sorting. When merging (lines 2262-2277), the merge happens on all ranks for consistency, with appropriate logging on rank 0.


2340-2397: LGTM – Robust distributed gather with fallback.

The implementation correctly attempts distributed gather (lines 2352-2375), detects failures (lines 2362-2370), and falls back to sequential gather on the master rank (lines 2377-2387). The fallback logic handles both exceptions and empty results, providing resilience against distributed coordination issues.


2473-2502: LGTM – Reserve fallback correctly handles both merged and partial modes.

The code correctly appends reserve results as a new partial (lines 2473-2485) when using partials mode, and merges tensor lists directly (lines 2486-2502) for merged mode. The rank assignment for the new partial (line 2476) ensures deterministic ordering.


2507-2546: LGTM – Success rate computation is mode-aware.

For partials (lines 2507-2528), the code computes an aggregate success rate and propagates it to all partials. For merged results (lines 2534-2541), it computes a simple rate. Both approaches are correct for their respective modes.


1754-1803: LGTM – File existence check prevents false slashing.

The approach of checking S3 object existence to distinguish between network failures and missing uploads is sound. When a file exists but download fails, the miner is not penalized (lines 1765-1772, 1790-1796), correctly attributing the failure to network issues rather than miner negligence.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (2)
src/tplr/comms.py (2)

2181-2182: Move imports to module level.

In-function imports are an anti-pattern. The time module is already imported at line 26, and dist_helper should be imported at the module level for consistency.

This issue was flagged in previous reviews. Apply this diff to move the import:

 from tplr.chain import ChainManager
 from tplr.compress import TopKCompressor, unpack_12bit_indices
 from tplr.config import BUCKET_SECRETS, client_config
+from tplr.distributed import dist_helper
 from tplr.schemas import Bucket, CommsGetResult

Then remove the import from line 2182:

-        import time
-        from tplr.distributed import dist_helper
-        
         window = kwargs.get("window", None)

Based on learnings


2317-2317: Move import to module level.

Same issue as line 2182: dist_helper should be imported at the module level, not inside the function.

Remove this import after moving dist_helper to the module-level imports as suggested in the previous comment.

Based on learnings

🧹 Nitpick comments (1)
src/tplr/comms.py (1)

2412-2546: Complex but correct handling of partial vs merged results.

The code correctly handles both return_partials=True (returning a list) and return_partials=False (returning merged namespace) throughout the reserve gather logic. The success rate calculations for both cases (lines 2507-2545) are mathematically correct.

The complexity is necessary for the dual-mode operation. Consider extracting the success rate calculation into a helper method if similar logic is needed elsewhere.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e73a41b and 128afb8.

📒 Files selected for processing (1)
  • src/tplr/comms.py (10 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
  • Bucket (24-46)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • all_gather_object (307-325)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.12)
  • GitHub Check: test (3.11)
🔇 Additional comments (5)
src/tplr/comms.py (5)

477-517: LGTM: Conservative existence check implementation.

The method correctly uses HEAD requests to check object existence without downloading. The conservative approach of returning True on non-404 errors (lines 512-513, 515-516) is appropriate to prevent false slashing due to transient network issues.


1395-1416: LGTM: Proper None-handling for optional version parameter.

The signature change and None-handling logic (lines 1415-1416) correctly default to tplr.__version__ when version is not provided, maintaining backward compatibility.


2015-2039: LGTM: Correct deterministic partitioning.

The round-robin assignment with sorted UIDs ensures deterministic, balanced distribution across ranks.


2340-2397: LGTM: Robust fallback mechanism for distributed gather.

The logic properly attempts distributed gather first, detects failures (lines 2362-2370), and falls back to sequential gather on the master rank (lines 2378-2387). The error handling and logging provide good observability.


1754-1803: Critical: Undefined attributes cause runtime failures.

Lines 1759 and 1784 reference self.version, which is never defined in the Comms class. Lines 1762 and 1787 call self.peer_bucket(str(uid)), which does not exist. These were flagged in previous reviews but remain unfixed.

Apply this diff:

                        )
-                        # Check if file exists on R2 to distinguish miner vs network failure
-                        gradient_key = f"{key}-{window}-{uid}-v{self.version}.pt"
+                        # Check if file exists on R2 to distinguish miner vs network failure  
+                        gradient_key = f"{key}-{window}-{uid}-v{tplr.__version__}.pt"
                         file_exists = await self.s3_object_exists(
                             key=gradient_key,
-                            bucket=self.peer_bucket(str(uid)),
+                            bucket=self.commitments.get(int(uid)),
                             timeout=5,
                         )

Also apply to lines 1784-1787:

                     if response is None:
-                        # Similar check for None responses
-                        gradient_key = f"{key}-{window}-{uid}-v{self.version}.pt"
+                        # Similar check for None responses
+                        gradient_key = f"{key}-{window}-{uid}-v{tplr.__version__}.pt"
                         file_exists = await self.s3_object_exists(
                             key=gradient_key,
-                            bucket=self.peer_bucket(str(uid)),
+                            bucket=self.commitments.get(int(uid)),
                             timeout=5,
                         )

Likely an incorrect or invalid review comment.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (2)
src/tplr/neurons.py (2)

328-329: Remove redundant imports.

Both gc and torch.cuda are already imported at the module level (lines 20 and 28). These redundant imports inside the function are unnecessary.

Apply this diff:

-    import gc
-    import torch.cuda
-

1584-1603: Consider clearer data structure for partial-to-UID mapping.

The current approach builds all_state_dicts by appending the same partial.state_dict reference multiple times (once per UID in that partial). While functionally correct, this creates a list where all_state_dicts[i] corresponds to uids[i]. However, the subsequent code (lines 1655-1663) then re-searches through all partials to find which partial each UID belongs to, which is inefficient and confusing.

Consider building a direct UID-to-partial mapping upfront:

# Build UID -> partial mapping once
uid_to_partial = {}
if isinstance(gather_result, list):
    for partial in gather_result:
        if partial is not None:
            for uid in getattr(partial, "uids", []):
                uid_to_partial[uid] = partial
    uids = list(uid_to_partial.keys())
    all_state_dicts = [uid_to_partial[uid].state_dict for uid in uids]
else:
    uids = list(getattr(gather_result, "uids", []))
    all_state_dicts = [gather_result.state_dict] * len(uids)

This eliminates the need to search for partials later and makes the relationship explicit.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 128afb8 and 000c12a.

📒 Files selected for processing (1)
  • src/tplr/neurons.py (8 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/tplr/neurons.py (3)
tests/test_comms.py (1)
  • compress (946-948)
tests/test_prepare_gradient_dict.py (3)
  • compress (18-27)
  • compress (155-165)
  • compress (242-253)
src/tplr/compress.py (6)
  • compress (373-377)
  • compress (379-383)
  • compress (386-423)
  • ChunkingTransformer (145-292)
  • TopKCompressor (295-696)
  • unpack_12bit_indices (102-142)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.11)
  • GitHub Check: test (3.12)
🔇 Additional comments (3)
src/tplr/neurons.py (3)

225-292: LGTM: Clean dispatcher pattern.

The refactoring to split single and incremental processing paths improves code organization and makes the incremental processing logic easier to maintain.


591-844: LGTM: Fingerprint statistics properly integrated.

The addition of per-parameter fingerprint tracking (lines 742-747) correctly accumulates norms, means, and element counts for monitoring gradient statistics.


512-516: Verify mean averaging is correct across partials with varying element counts.

The code computes a simple average of means across partials, but if different partials have different numbers of elements for the same parameter, this should be a weighted average. For example, if partial 1 contributes 100 elements with mean=0.5 and partial 2 contributes 10 elements with mean=1.0, the combined mean should be (100*0.5 + 10*1.0)/110 = 0.545, not (0.5 + 1.0)/2 = 0.75.
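The worked numbers above, as code: an element-weighted mean differs from the naive average of per-partial means whenever element counts differ.

```python
# Element-weighted vs naive averaging of per-partial means.
def weighted_mean(counts_and_means):
    total = sum(n for n, _ in counts_and_means)
    return sum(n * m for n, m in counts_and_means) / total

parts = [(100, 0.5), (10, 1.0)]
combined = weighted_mean(parts)
naive = sum(m for _, m in parts) / len(parts)

assert abs(combined - 60 / 110) < 1e-12   # ~0.545, the correct value
assert naive == 0.75                      # overweights the small partial
```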

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (4)
src/tplr/comms.py (3)

2197-2198: Move imports to module level.

The time module is already imported at line 26, and dist_helper should be imported at the module level for consistency with the rest of the file. In-function imports impact performance and readability.

Move the import to the top of the file:

 from types import SimpleNamespace
 from typing import Any, Literal, cast
 
 import aiofiles
 import bittensor as bt
 import boto3
 import botocore
 import torch
 from aiobotocore.client import AioBaseClient
 from aiobotocore.session import get_session
 from boto3.s3.transfer import TransferConfig
 from botocore.exceptions import ClientError, ConnectionClosedError
 from tqdm import tqdm as std_tqdm
 
 import tplr
 from tplr.chain import ChainManager
 from tplr.compress import TopKCompressor, unpack_12bit_indices
 from tplr.config import BUCKET_SECRETS, client_config
 from tplr.schemas import Bucket, CommsGetResult
+from tplr.distributed import dist_helper

Then remove the local import at line 2198.


2333-2333: Move dist_helper import to module level.

Another instance of the local dist_helper import. This should be consolidated with the module-level imports for consistency.

Use the same module-level import suggested in the previous comment and remove this local import.


2080-2081: Avoid hardcoding cuda:0 as the default device.

Line 2081 defaults to "cuda:0" when CUDA is available, but in a multi-process setup each rank is typically bound to its own device, so defaulting to device 0 can place tensors on a GPU owned by another rank. Default to the current device instead.

Use a safer default:

         # Determine target device for tensor consolidation
         if target_device is None:
-            target_device = "cuda:0" if torch.cuda.is_available() else "cpu"
+            if torch.cuda.is_available() and torch.cuda.device_count() > 0:
+                target_device = f"cuda:{torch.cuda.current_device()}"
+            else:
+                target_device = "cpu"
src/tplr/neurons.py (1)

1644-1698: Optimize O(n×m) partial lookup in inner loop.

The nested loop at lines 1661-1665 searches through all partials for each UID, creating O(n×m) complexity where n is the number of UIDs and m is the number of partials. This issue was previously identified but remains unresolved.

As suggested in the previous review, build a UID-to-partial mapping once before the loop:

# Before line 1644, add:
if isinstance(gather_result, list):
    uid_to_partial = {}
    for partial in gather_result:
        if partial is not None:
            for uid in getattr(partial, "uids", []):
                uid_to_partial[uid] = partial

Then replace lines 1644-1698 with direct lookups using this mapping, eliminating the inner search loop.
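The payoff of the mapping can be sketched standalone (a hypothetical miniature: `tag` and the sample UIDs are illustrative, not from the real partials):

```python
from types import SimpleNamespace

partials = [
    SimpleNamespace(uids=[1, 3], tag="rank0"),
    SimpleNamespace(uids=[2, 4], tag="rank1"),
    None,  # a failed rank contributes nothing
]

# Build the UID-to-partial index once: O(m) over all partials
uid_to_partial = {}
for partial in partials:
    if partial is not None:
        for uid in getattr(partial, "uids", []):
            uid_to_partial[uid] = partial

# Each per-UID lookup is now O(1) instead of scanning every partial
owners = {uid: uid_to_partial[uid].tag for uid in [1, 2, 3, 4] if uid in uid_to_partial}
```

A UID absent from every partial simply never appears in the map, so callers can branch on `uid in uid_to_partial` without a search loop.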

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 000c12a and 4c36955.

📒 Files selected for processing (2)
  • src/tplr/comms.py (10 hunks)
  • src/tplr/neurons.py (8 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/tplr/neurons.py (1)
src/tplr/compress.py (6)
  • compress (373-377)
  • compress (379-383)
  • compress (386-423)
  • ChunkingTransformer (145-292)
  • TopKCompressor (295-696)
  • unpack_12bit_indices (102-142)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
  • Bucket (24-46)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • all_gather_object (307-325)
🔇 Additional comments (8)
src/tplr/neurons.py (4)

226-293: LGTM! Clean dispatch pattern.

The refactored outer_step function provides a clear separation between incremental and single-result processing paths. The documentation accurately describes the memory-minimizing approach and the conditional dispatch logic is straightforward.


451-456: Verify mean averaging is statistically correct.

The code averages parameter means across partials using simple arithmetic mean. This is only correct if each partial contributes the same number of elements for each parameter. If partials have different numbers of UIDs or elements, the mean should be weighted by the number of elements each partial contributes.

Consider tracking element counts per parameter and computing a weighted average:

 for param_name, mean in partial_fingerprint.get(
     "param_means", {}
 ).items():
     if param_name not in combined_fingerprint["param_means"]:
-        combined_fingerprint["param_means"][param_name] = []
-    combined_fingerprint["param_means"][param_name].append(mean)
+        combined_fingerprint["param_means"][param_name] = {"sum": 0.0, "count": 0}
+    # Need to track element count from partial_fingerprint
+    elem_count = partial_fingerprint.get("total_elements", 1)
+    combined_fingerprint["param_means"][param_name]["sum"] += mean * elem_count
+    combined_fingerprint["param_means"][param_name]["count"] += elem_count

Then at finalization:

 # Average the means
 for param_name in combined_fingerprint["param_means"]:
-    means_list = combined_fingerprint["param_means"][param_name]
-    combined_fingerprint["param_means"][param_name] = sum(means_list) / len(
-        means_list
-    )
+    stats = combined_fingerprint["param_means"][param_name]
+    combined_fingerprint["param_means"][param_name] = stats["sum"] / stats["count"]

470-484: LGTM! Proper memory management.

The aggressive memory cleanup after each partial (including garbage collection, CUDA cache clearing, and synchronization) is appropriate for memory-bounded incremental processing. The budget check after cleanup ensures accurate memory measurements.


384-501: LGTM! Robust error recovery.

The error handling implements proper graceful degradation: failures are logged with context (rank, UIDs, error message), collected for final reporting, and processing continues with remaining partials. This ensures maximum gradient application even when some partials fail.

src/tplr/comms.py (4)

477-516: LGTM! Conservative existence check.

The s3_object_exists method correctly implements a conservative approach by returning True on non-404 errors (line 513). This prevents false slashing due to transient network issues, which is the right design choice for a production system.


1395-1416: LGTM! Proper version handling.

The updated gradient_timestamp method now accepts an optional version parameter with a sensible default to tplr.__version__ when None. This addresses potential version-handling issues.


2032-2055: LGTM! Deterministic partitioning.

The round-robin UID partitioning with deterministic sorting (line 2052) ensures all ranks agree on the assignment and achieve balanced load distribution. The implementation is correct and efficient.


1754-1818: LGTM! Proper failure differentiation.

The enhanced error handling correctly distinguishes between network failures (file exists but download failed) and miner failures (file doesn't exist). This prevents false slashing due to transient network issues while still catching miners who didn't upload gradients.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
neurons/validator.py (2)

1392-1409: Consider adding error handling for peer broadcast.

The dist.broadcast_object_list() call lacks error handling. If the broadcast fails or times out, non-master ranks would retain None values for self.comms.peers and self.comms.reserve_peers, potentially causing issues in the subsequent distributed gather.

Consider wrapping the broadcast in a try-except block with a fallback:

         # Broadcast peers list from master to all ranks (needed for distributed gather)
         if dist_helper.world_size > 1 and dist_helper.is_distributed():
-            peers_list = (
-                [self.comms.peers, self.comms.reserve_peers]
-                if self.is_master
-                else [None, None]
-            )
-            dist.broadcast_object_list(peers_list, src=0)
-            self.comms.peers = peers_list[0]
-            self.comms.reserve_peers = peers_list[1]
+            try:
+                peers_list = (
+                    [self.comms.peers, self.comms.reserve_peers]
+                    if self.is_master
+                    else [None, None]
+                )
+                dist.broadcast_object_list(peers_list, src=0)
+                
+                if peers_list[0] is None or peers_list[1] is None:
+                    raise RuntimeError("Broadcast resulted in None values")
+                    
+                self.comms.peers = peers_list[0]
+                self.comms.reserve_peers = peers_list[1]
+            except Exception as e:
+                tplr.log_with_context(
+                    level="error",
+                    message=f"Failed to broadcast peers: {e}",
+                    sync_window=self.sync_window,
+                    current_window=self.current_window,
+                )
+                # Skip this window if broadcast fails
+                continue

1565-1592: Simplify logging - isinstance check is redundant.

Since return_partials=use_distributed_gather (line 1551), distributed gather always returns a list (partials) and sequential gather always returns SimpleNamespace (merged). The isinstance(gather_result, list) check at line 1577 is therefore redundant within the if use_distributed_gather block—it will always be True.

Simplify the logging:

             # Log gather completion
             if use_distributed_gather:
                 if gather_result is not None:
-                    # Handle both partials (list) and merged result (SimpleNamespace)
-                    # Show quality metrics for both partials and merged results
                     uids = self._get_uids_from_gather_result(gather_result)
                     skipped_uids = self._get_skipped_uids_from_gather_result(
                         gather_result
                     )
                     success_rate = self._get_success_rate_from_gather_result(
                         gather_result
                     )

-                    if isinstance(gather_result, list):
-                        tplr.log_with_context(
-                            level="info",
-                            message=f"[Rank {dist_helper.rank}] Distributed gather complete: {len(uids)}/{len(self.comms.peers)} successful, "
-                            f"{len(skipped_uids)} skipped, success_rate={success_rate:.2%} across {len(gather_result)} partials",
-                            sync_window=self.sync_window,
-                            current_window=self.current_window,
-                        )
-                    else:
-                        tplr.log_with_context(
-                            level="info",
-                            message=f"[Rank {dist_helper.rank}] Distributed gather complete: {len(uids)}/{len(self.comms.peers)} successful, "
-                            f"{len(skipped_uids)} skipped, success_rate={success_rate:.2%}",
-                            sync_window=self.sync_window,
-                            current_window=self.current_window,
-                        )
+                    # Distributed gather always returns partials (list)
+                    tplr.log_with_context(
+                        level="info",
+                        message=f"[Rank {dist_helper.rank}] Distributed gather complete: {len(uids)}/{len(self.comms.peers)} successful, "
+                        f"{len(skipped_uids)} skipped, success_rate={success_rate:.2%} across {len(gather_result)} partials",
+                        sync_window=self.sync_window,
+                        current_window=self.current_window,
+                    )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4c36955 and a6a7c3b.

📒 Files selected for processing (1)
  • neurons/validator.py (12 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
neurons/validator.py (6)
src/tplr/distributed.py (1)
  • is_distributed (75-77)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
neurons/base_node.py (1)
  • query_block_timestamp (202-247)
src/tplr/comms.py (2)
  • gather_with_reserve (2298-2563)
  • merge_gather_results (2058-2167)
src/tplr/neurons.py (1)
  • check_uid_index_overlap (1570-1783)
src/tplr/compress.py (1)
  • maybe_dequantize_values (646-696)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.12)
  • GitHub Check: test (3.11)
🔇 Additional comments (5)
neurons/validator.py (5)

94-134: LGTM! Well-designed helper methods for dual format support.

The three static helper methods cleanly abstract the differences between partials (list) and merged results (SimpleNamespace). The success rate calculation correctly handles the total attempted count, and the deduplication of skipped_uids prevents double-counting across partials.
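A standalone sketch of the aggregation and dedup logic (the real `_get_*` helpers live in `neurons/validator.py`; the attribute names below mirror them but the sample data is hypothetical):

```python
from types import SimpleNamespace

def success_rate(gather_result):
    """Aggregate success rate over partials (list) or a single merged result."""
    partials = gather_result if isinstance(gather_result, list) else [gather_result]
    uids, skipped = [], []
    for p in partials:
        if p is None:
            continue
        uids.extend(getattr(p, "uids", []))
        # Deduplicate skipped UIDs so a UID skipped on several ranks
        # is not double-counted in the attempted total
        for uid in getattr(p, "skipped_uids", []):
            if uid not in skipped:
                skipped.append(uid)
    attempted = len(uids) + len(skipped)
    return len(uids) / attempted if attempted else 0.0

partials = [
    SimpleNamespace(uids=[1, 2], skipped_uids=[9]),
    SimpleNamespace(uids=[3], skipped_uids=[9]),  # UID 9 skipped on both ranks
]
rate = success_rate(partials)  # 3 successful out of 4 attempted
```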


975-981: Good defensive programming to prevent rating errors.

Initializing OpenSkill ratings for all UIDs before calling rate() prevents errors when new peers are encountered. This ensures the rating update always has valid input.


1635-1650: Good memory optimization - single merge for upload only.

The code now performs only one merge (for upload artifacts, lines 1636-1640) and passes gather_result directly to check_uid_index_overlap (line 1649), which handles both partials and merged formats. This addresses the past review concern about multiple temporary merges creating unnecessary memory overhead.


2511-2513: LGTM! Proper integration of memory budget parameter.

Using getattr with None default maintains backward compatibility with older hyperparameter files that don't include gradient_memory_budget_mb, while enabling the new memory-bounded incremental processing feature for configurations that do.


3820-3918: Edge case is properly handled - no changes required.

The refactored code correctly handles the empty vals edge case. When compute_peer_val_norms skips setting clip_norm_dict[vals_key] for a parameter (line 3882), the update_model_with_gradient function safely retrieves it via .get(vals_key, eval_norm) at line 3687. The fallback uses the evaluation's own norm as the clip threshold, which is a reasonable design: if no peer norms are available, clip based on the evaluation's own magnitude. The debug logging suggestion is optional and not necessary.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (3)
src/tplr/comms.py (3)

1754-1820: Extract duplicate file existence check logic into a helper method.

Lines 1759-1789 and 1793-1819 contain nearly identical logic for checking file existence and determining whether to slash a miner. This includes duplicate gradient key construction, bucket retrieval, existence checks, and logging patterns.

Consider extracting this into a helper method:

async def _should_skip_uid_for_missing_gradient(
    self, uid: int, key: str, window: int
) -> tuple[bool, str]:
    """
    Check if a UID should be skipped due to missing gradient.
    
    Returns:
        tuple[bool, str]: (should_skip, reason)
          - (True, "not_registered") if no bucket committed
          - (False, "network_issue") if file exists but wasn't retrieved
          - (True, "not_uploaded") if file doesn't exist
    """
    gradient_key = f"{key}-{window}-{uid}-v{tplr.__version__}.pt"
    peer_bucket = self.commitments.get(int(uid))
    
    if not peer_bucket:
        tplr.logger.info(f"Skipped UID {uid} - no bucket committed")
        return True, "not_registered"
    
    file_exists = await self.s3_object_exists(
        key=gradient_key, bucket=peer_bucket, timeout=5
    )
    
    if file_exists:
        tplr.log_with_context(
            level="warning",
            message=f"UID {uid}: File exists but not retrieved - network issue",
            current_window=window,
        )
        return False, "network_issue"
    else:
        tplr.logger.info(f"Skipped UID {uid} - file not found")
        return True, "not_uploaded"

Then replace both blocks with:

should_skip, reason = await self._should_skip_uid_for_missing_gradient(uid, key, window)
if should_skip:
    skipped_uids.append(uid)
continue

2080-2086: Device selection improved but could be more robust.

Line 2084 now uses torch.cuda.current_device() instead of hardcoded cuda:0, which addresses the previous review concern. However, current_device() can raise RuntimeError if called before any CUDA context is established.

Consider wrapping in try-except for maximum safety:

         if target_device is None:
             # Use the current device if CUDA is available, otherwise CPU
             if torch.cuda.is_available():
-                target_device = f"cuda:{torch.cuda.current_device()}"
+                try:
+                    target_device = f"cuda:{torch.cuda.current_device()}"
+                except RuntimeError:
+                    # No CUDA context yet, use device 0
+                    target_device = "cuda:0" if torch.cuda.device_count() > 0 else "cpu"
             else:
                 target_device = "cpu"

2482-2488: Reserve gather always uses sequential mode.

Line 2483 uses sequential gather() for reserve UIDs even when distributed mode is available. This is a reasonable design choice since reserve sets are typically smaller, but it should be documented in the method docstring or code comment.

Add a comment explaining the decision:

# Use sequential gather for reserve (not distributed)
# Reserve sets are typically small, so the distributed overhead isn't worth it
fallback = await self.gather(
    my_uid=my_uid,
    uids=replacements,
    expected_compressed_params=expected_compressed_params,
    **kwargs,
)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e51df9f and 64bf414.

📒 Files selected for processing (1)
  • src/tplr/comms.py (11 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
  • Bucket (24-46)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • all_gather_object (307-325)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.11)
  • GitHub Check: test (3.12)
🔇 Additional comments (5)
src/tplr/comms.py (5)

49-49: Import moved to module level as recommended.

This addresses the previous review comment about in-function imports. The placement is appropriate with other module-level imports.


1396-1417: Properly handles version parameter with None default.

The signature change from version: str = tplr.__version__ to version: str | None = None with runtime initialization correctly addresses the previous review feedback. This pattern avoids evaluation-time binding issues while maintaining backward compatibility.


2033-2057: Clean round-robin partitioning implementation.

The deterministic sorting followed by slice-based round-robin assignment (sorted_uids[rank::world_size]) is simple, efficient, and correct. Edge cases are properly handled.


2276-2278: Rank metadata guard is already in place.

Lines 2276-2278 attach rank metadata to each partial for deterministic sorting. The code already checks `not hasattr(partial, "rank")` before assigning, so a partial that passes through this loop more than once cannot have its rank overwritten:

             # Add rank metadata to each partial for deterministic sorting
             for idx, partial in enumerate(all_partials):
                 if partial is not None and not hasattr(partial, "rank"):
                     partial.rank = idx

No change needed.


2244-2260: Verify behavior when distributed setup fails.

If world_size > 1 but dist_helper.is_distributed() returns False (e.g., distributed initialization failed), line 2260 returns [partial_result], which contains only this rank's subset. This means other ranks' data is lost.

Consider adding a warning or error when this mismatch occurs:

if world_size > 1 and dist_helper.is_distributed():
    # ... all-gather logic ...
else:
    if world_size > 1:
        tplr.logger.error(
            f"[DISTRIBUTED_GATHER] world_size={world_size} but distributed not initialized! "
            f"Only returning rank {rank} data - other ranks' data will be lost"
        )
    all_partials = [partial_result]

This would help catch configuration issues where distributed gathering is expected but the environment isn't properly set up.

@shivam-MBZUAI shivam-MBZUAI changed the title feat: implement incremental gradient processing with memory-efficient… feat: Distributed Gather Nov 7, 2025

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (2)
src/tplr/neurons.py (1)

557-557: Fix inverted budget_exceeded flag logic (still not addressed).

The logic remains inverted from the previous review: the flag is set to 0 when budget IS exceeded and 1 when it's NOT exceeded. This contradicts the variable name and will cause incorrect metrics.

Apply this diff:

-            budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+            budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0

The same issue exists at line 854 in _outer_step_single.

src/tplr/comms.py (1)

2395-2442: Master-only fallback still causes rank inconsistency (duplicate concern).

Lines 2433-2442 perform sequential gather fallback only on the master rank when distributed gather fails. This leaves other ranks with None while master has data, causing rank divergence.

As flagged in past reviews, consider:

  1. Broadcasting master's fallback result to all ranks using dist_helper.broadcast_object(), or
  2. Having all ranks perform the sequential fallback (if network is shared), or
  3. Explicitly failing if distributed gather was required

The current implementation will cause downstream issues when non-master ranks expect valid gather results.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 96df9d6 and f60714e.

📒 Files selected for processing (2)
  • src/tplr/comms.py (11 hunks)
  • src/tplr/neurons.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/tplr/neurons.py (1)
src/tplr/compress.py (6)
  • compress (373-377)
  • compress (379-383)
  • compress (386-423)
  • ChunkingTransformer (145-292)
  • TopKCompressor (295-696)
  • unpack_12bit_indices (102-142)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
  • Bucket (24-46)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • all_gather_object (307-325)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.11)
  • GitHub Check: test (3.12)
🔇 Additional comments (7)
src/tplr/neurons.py (2)

296-592: LGTM: Robust incremental partial processing.

The implementation correctly:

  • Sorts partials deterministically by rank for reproducible processing
  • Applies sequential device consolidation to handle multi-GPU tensors
  • Implements graceful error recovery (logs and continues on partial failures)
  • Aggregates statistics correctly (squared norms summed, then sqrt for L2)
  • Performs aggressive memory cleanup after each partial

The error handling at lines 486-502 provides good resilience without breaking the entire gather on individual partial failures.
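The norm aggregation noted above relies on the identity that, for disjoint chunks, the squared L2 norms sum; a quick numeric check with hypothetical values:

```python
import math

# Per-partial gradient chunks for one parameter (disjoint slices)
chunks = [[3.0, 4.0], [0.0, 12.0]]

# Incremental path: accumulate squared norms, take sqrt once at the end
sq_sum = sum(v * v for chunk in chunks for v in chunk)
incremental_l2 = math.sqrt(sq_sum)

# Reference: L2 norm of the concatenated vector
flat = [v for chunk in chunks for v in chunk]
reference_l2 = math.sqrt(sum(v * v for v in flat))
```

Both paths agree (13.0 for these values), which is why the incremental code can discard each partial's tensors immediately after accumulating its squared norm.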


1620-1809: LGTM: Efficient partial-aware overlap checking.

The updates correctly support both merged and partial gather results:

  • Lines 1625-1635: Creates O(1) UID-to-partial mapping (addresses past O(n×m) concern)
  • Lines 1687-1707: Uses direct lookup instead of linear search through partials
  • Lines 1737-1753: Correctly maps pair keys using original UID indices

The optimization from the past review is properly implemented.

src/tplr/comms.py (5)

478-555: LGTM: Conservative existence check with proper error classification.

The implementation correctly distinguishes between:

  • Definitive absence (404) → False
  • Configuration errors (403, NoSuchBucket) → False with ERROR-level logging
  • Transient issues (timeouts, 5xx) → True (conservative to avoid false slashing)

The ERROR-level logging for config issues (lines 524-534) helps surface legitimate setup problems that need attention.


1433-1465: LGTM: Clean optional version parameter.

The signature update correctly:

  • Uses str | None with default None instead of the problematic self.version
  • Applies runtime default tplr.__version__ at lines 1453-1454 when None is passed

This addresses the past review concern about undefined self.version.


2066-2334: LGTM: Well-structured distributed gather with proper device handling.

The implementation addresses past review concerns and is production-ready:

  • Lines 2118-2123: Uses torch.cuda.current_device() instead of hardcoded cuda:0 (addresses past review)
  • Lines 2070-2094: Deterministic round-robin partitioning ensures consistent UID assignment across ranks
  • Lines 2151-2163: Recursive device movement handles nested structures correctly
  • Lines 2283-2290: Proper all-gather synchronization with blocking semantics

The return_partials flag cleanly supports both merged and incremental processing modes.
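The recursive device movement follows a common pattern for nested gradient containers; a minimal sketch with a placeholder `move` callable standing in for `tensor.to(target_device)` (the real code walks `torch.Tensor` leaves):

```python
def move_nested(obj, move):
    """Recursively apply `move` to every leaf in dicts/lists/tuples,
    preserving the container structure along the way."""
    if isinstance(obj, dict):
        return {k: move_nested(v, move) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)([move_nested(v, move) for v in obj])
    return move(obj)

# Placeholder: tag int leaves instead of calling tensor.to("cuda:0")
state = {"idxs": [1, 2], "vals": {"layer.0": 3}}
moved = move_nested(state, lambda leaf: ("moved", leaf) if isinstance(leaf, int) else leaf)
```

Because only leaves are transformed, dict keys and nesting depth survive intact, which is what lets merged results keep the original `state_dict` layout after consolidation.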


1792-1857: LGTM: Improved error handling prevents false slashing.

The updated error handling correctly uses s3_object_exists to distinguish:

  • Lines 1806-1818: File exists but download failed → network issue, don't slash
  • Lines 1819-1826: File doesn't exist → miner didn't upload, legitimate slash

Same logic applied for None responses at lines 1828-1857. This prevents false negatives due to transient network issues while still catching miners who don't upload.

The 5-second timeout at lines 1809 and 1842 is reasonable for existence checks.


2468-2600: LGTM: Clean unified handling of partials and merged results.

The implementation elegantly handles both modes:

  • Lines 2468-2495: Aggregate metrics computed correctly for both partials and merged
  • Lines 2531-2540: Fallback appended as new partial with rank metadata (supports incremental processing)
  • Lines 2542-2557: Fallback merged in-place for merged results (supports traditional flow)
  • Lines 2562-2583: Success rate computed correctly as aggregate across all partials

The conditional logic cleanly branches on isinstance(primary, list) without code duplication.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (2)
src/tplr/neurons.py (2)

558-558: CRITICAL: Fix inverted budget_exceeded flag.

This bug was flagged in past review comments but was NOT fixed. The flag logic is inverted:

  • Currently: budget_exceeded = 0 when memory exceeds budget (WRONG)
  • Should be: budget_exceeded = 1 when memory exceeds budget

This makes the WandB metric outer_step/memory_budget_violation report the opposite of reality.

Apply this fix:

-            budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+            budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0

Also applies to line 855 in _outer_step_single.


849-878: CRITICAL: Same inverted budget_exceeded flag bug.

This is the same bug as line 558. The flag logic is inverted at line 855:

budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1

This makes the WandB metric outer_step/memory_budget_violation incorrect.

Apply this fix:

-        budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+        budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0
🧹 Nitpick comments (2)
src/tplr/comms.py (1)

2118-2124: Add error handling for CUDA device operations.

The device consolidation uses torch.cuda.current_device() which is an improvement over the previous "cuda:0" hardcoding, but there's no error handling if CUDA operations fail.

Consider wrapping the device determination in a try-except:

         if target_device is None:
             # Use the current device if CUDA is available, otherwise CPU
             if torch.cuda.is_available():
-                target_device = f"cuda:{torch.cuda.current_device()}"
+                try:
+                    target_device = f"cuda:{torch.cuda.current_device()}"
+                except RuntimeError as e:
+                    tplr.logger.warning(f"Failed to get current CUDA device: {e}, falling back to cpu")
+                    target_device = "cpu"
             else:
                 target_device = "cpu"
src/tplr/neurons.py (1)

329-330: Remove redundant imports.

Both gc and torch.cuda imports inside the function are redundant:

  • gc is already imported at module level (line 20)
  • torch is already imported at module level (line 28), so torch.cuda is available

Remove these lines:

-    import gc
-
-    import torch.cuda
-
     # Track applied UIDs to prevent duplicates
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2a50f2a and e13eb9c.

📒 Files selected for processing (2)
  • src/tplr/comms.py (11 hunks)
  • src/tplr/neurons.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/tplr/neurons.py (1)
src/tplr/compress.py (5)
  • compress (373-377)
  • compress (379-383)
  • compress (386-423)
  • ChunkingTransformer (145-292)
  • unpack_12bit_indices (102-142)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
  • Bucket (24-46)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • all_gather_object (307-325)
🔇 Additional comments (10)
src/tplr/comms.py (5)

478-555: LGTM - Conservative error handling prevents false slashing.

The implementation correctly distinguishes between transient failures (timeouts, 5xx errors) and permanent failures (404, 403, invalid bucket). The conservative approach of returning True for transient errors prevents false slashing of miners due to network issues.

The logging levels are appropriate: ERROR for configuration issues that need attention, WARNING for transient issues.


1792-1857: Verify latency impact of HEAD requests in high-failure scenarios.

The file existence checks prevent false slashing by distinguishing network failures from missing files, which is critical. However, each failed download now triggers an additional HEAD request (5s timeout).

In distributed gather scenarios with many UIDs and high network failure rates, this could add significant cumulative latency: with the 5-second timeout, 20 failed downloads could cost up to 100 seconds of additional existence checks in the worst case.

Consider monitoring the impact in production and potentially batching HEAD requests if latency becomes an issue:

# Potential optimization: batch HEAD requests
failed_uids_to_check = [(uid, gradient_key, peer_bucket) for ...]
results = await asyncio.gather(*[
    self.s3_object_exists(key=k, bucket=b, timeout=5)
    for _, k, b in failed_uids_to_check
])

2211-2333: LGTM - Well-structured distributed gather implementation.

The distributed gather correctly:

  • Partitions UIDs deterministically using round-robin (ensures reproducibility)
  • Processes subsets in parallel across ranks (achieves ~2x speedup as reported)
  • Collects results via all-gather (all ranks get consistent view)
  • Supports both merged and partial return modes for memory-bounded processing

The detailed per-rank logging will be valuable for debugging distributed issues.


2432-2461: LGTM - Consistent fallback ensures all ranks stay synchronized.

The fallback logic correctly ensures that all ranks fall back together when distributed gather fails, preventing rank inconsistency. The comment at line 2444 helpfully explains that sequential gather (R2 downloads) can be done independently by each rank without requiring coordination.

This addresses the past review concern about master-only fallback causing inconsistency.


2475-2609: LGTM - Consistent handling of merged and partial results.

The implementation correctly:

  • Computes aggregate metrics across partials when return_partials=True
  • Uses traditional success rate calculation for merged results
  • Returns None when no gradients were successfully gathered (for both cases)
  • Always uses sequential gather for reserve (reasonable since reserve set is smaller)

The logging clearly distinguishes between partials and merged results, which aids debugging.

src/tplr/neurons.py (5)

226-293: LGTM - Clean separation of incremental and single-result processing.

The expanded signature and branching logic clearly separate the incremental partial processing path from the original single-result path. The type hints accurately reflect the supported input types.


480-486: Memory budget is logged but never enforced.

The code tracks memory usage against memory_budget_mb and logs warnings when exceeded, but never enforces the budget (no early stopping, no rejection of additional partials). The comment at line 325 says "future: early stopping if exceeded" but the implementation only logs.

Clarify the intended behavior:

  1. Is this intentional (logging only for monitoring)?
  2. Should the code stop processing additional partials when budget is exceeded?
  3. Should there be a hard limit that fails the entire operation?

If enforcement is needed, consider:

if current_memory_mb > memory_budget_mb:
    tplr.logger.error(
        f"Memory budget exceeded: {current_memory_mb:.0f}MB > {memory_budget_mb:.0f}MB. "
        f"Stopping partial processing at {partial_idx + 1}/{len(partials)}"
    )
    break  # Stop processing additional partials

487-503: Verify error recovery strategy for failed partials.

When a partial fails to process, the code logs the error and continues processing remaining partials. This graceful degradation might be acceptable, but it could lead to incomplete gradient application.

Confirm that validators correctly handle partial application in weight updates. Specifically:

  • Do validators track which UIDs had gradients successfully applied?
  • Are miners penalized if their gradient was in a failed partial?
  • Should the entire operation fail if any partial fails?

If complete application is required, consider collecting failures and raising after all partials are attempted:

if failed_partials:
    raise RuntimeError(
        f"Failed to process {len(failed_partials)} partials. "
        f"Errors: {[fp['error'] for fp in failed_partials]}"
    )

1621-1755: LGTM - Efficient partials support with O(1) lookup optimization.

The implementation correctly:

  • Builds a uid_to_partial mapping for O(1) lookup (addresses past review about O(n×m) complexity)
  • Handles both merged results (shared state_dict) and partials (per-UID state_dict)
  • Uses original UID indices for consistent pair keys across comparisons
  • Falls back gracefully when state_dicts are missing

The optimization reduces complexity from O(n×m) to O(n) as intended.


392-417: LGTM - Device consolidation is critical for multi-GPU support.

The recursive tensor movement to target device is essential for multi-GPU setups where different ranks produce tensors on different devices (e.g., rank 0 on cuda:0, rank 1 on cuda:1). Without this consolidation, applying gradients would fail with device mismatch errors.

The comment at line 391 helpfully explains this requirement.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

♻️ Duplicate comments (1)
src/tplr/neurons.py (1)

485-486: Fix torch.cuda APIs called with string device and correct memory budget flag.

  • torch.cuda.memory_allocated/synchronize expect a device index or torch.device, not a str.
  • budget_exceeded is inverted; should be 1 when budget is exceeded, 0 otherwise. (Duplicate of earlier note.)

Apply this diff:

@@
-    if memory_budget_mb is not None and torch.cuda.is_available():
-        initial_memory_mb = torch.cuda.memory_allocated(device) / (1024**2)
+    if memory_budget_mb is not None and torch.cuda.is_available():
+        dev_obj = torch.device(device)
+        initial_memory_mb = torch.cuda.memory_allocated(dev_obj) / (1024**2)
@@
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-                torch.cuda.synchronize(device)
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+                torch.cuda.synchronize(torch.device(device))
@@
-        if memory_budget_mb is not None and torch.cuda.is_available():
-            final_memory_mb = torch.cuda.memory_allocated(device) / (1024**2)
+        if memory_budget_mb is not None and torch.cuda.is_available():
+            dev_obj = torch.device(device)
+            final_memory_mb = torch.cuda.memory_allocated(dev_obj) / (1024**2)
@@
-            budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+            budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0
@@
-    if on_src and memory_budget_mb is not None and torch.cuda.is_available():
-        final_memory_mb = torch.cuda.memory_allocated(device) / (1024**2)
+    if on_src and memory_budget_mb is not None and torch.cuda.is_available():
+        dev_obj = torch.device(device)
+        final_memory_mb = torch.cuda.memory_allocated(dev_obj) / (1024**2)
@@
-        budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+        budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0

Also applies to: 591-596, 681-689, 981-986

🧹 Nitpick comments (3)
src/tplr/comms.py (2)

1792-1857: Optional: parallelize R2 existence checks to avoid serial awaits on slow paths.

Both exception and None branches await s3_object_exists per UID sequentially. Consider a small bounded-concurrency gather for these checks to reduce tail latency under widespread timeouts.
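A bounded-concurrency version of those checks might look like the sketch below. `s3_object_exists` here is a stand-in coroutine, not the real comms method; the semaphore pattern is the point.

```python
import asyncio

async def check_existence_bounded(uids, s3_object_exists, max_concurrency=8):
    """Run existence checks concurrently but capped, instead of serially.

    The semaphore keeps at most `max_concurrency` requests in flight,
    bounding load while cutting tail latency versus one-at-a-time awaits.
    """
    sem = asyncio.Semaphore(max_concurrency)

    async def check(uid):
        async with sem:
            return uid, await s3_object_exists(uid)

    results = await asyncio.gather(*(check(uid) for uid in uids))
    return dict(results)

async def fake_exists(uid):  # stand-in for the real R2 check
    await asyncio.sleep(0)
    return uid % 2 == 0

found = asyncio.run(check_existence_bounded(range(5), fake_exists))
```

`asyncio.gather` preserves per-UID results while the semaphore caps concurrency, so slow or timed-out checks no longer serialize the whole loop.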


2096-2210: Default merge target_device to CPU to reduce OOM risk when merging large partials.

Merging moves all tensors to the current CUDA device by default; on large peer sets this can spike GPU memory. Prefer CPU by default; let callers opt-in to CUDA explicitly.

Apply this diff:

-        if target_device is None:
-            # Use the current device if CUDA is available, otherwise CPU
-            if torch.cuda.is_available():
-                target_device = f"cuda:{torch.cuda.current_device()}"
-            else:
-                target_device = "cpu"
+        if target_device is None:
+            # Safer default: consolidate on CPU to avoid GPU OOM during merges.
+            target_device = "cpu"
src/tplr/neurons.py (1)

388-396: Move function-scope import to module level to avoid repeated import and potential cycles.

The import inside merge_partial_batch is unnecessary. Import Comms at module level for consistency and performance (unless the local import was added to break a cycle).

Apply this diff:

-    from tplr.comms import Comms
+    # at module top:
+    # from tplr.comms import Comms

And remove the in-function import.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e13eb9c and 0609f6a.

📒 Files selected for processing (4)
  • hparams/2B.json (1 hunks)
  • neurons/validator.py (13 hunks)
  • src/tplr/comms.py (11 hunks)
  • src/tplr/neurons.py (9 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • neurons/validator.py
  • hparams/2B.json
🧰 Additional context used
🧬 Code graph analysis (2)
src/tplr/neurons.py (2)
src/tplr/compress.py (4)
  • compress (373-377)
  • compress (379-383)
  • compress (386-423)
  • unpack_12bit_indices (102-142)
src/tplr/comms.py (2)
  • merge_gather_results (2096-2209)
  • get (1467-1582)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
  • Bucket (24-46)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • all_gather_object (307-325)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.11)
  • GitHub Check: test (3.12)
🔇 Additional comments (1)
src/tplr/comms.py (1)

2432-2451: Good fix: all-ranks sequential fallback preserves consistency.

Sequential fallback no longer runs only on master; all ranks perform the same path. This avoids rank divergence.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
neurons/miner.py (1)

876-882: Fix AttributeError when gather_result is a list of partials.

When distributed gather returns a list of partials, there's no top-level skipped_uids attribute. Accessing gather_result.skipped_uids at line 879 and 881 will raise an AttributeError.

Apply this diff to handle both merged and partial results:

                # Add successful peers information
                if gather_result is not None:
-                    debug_dict["successful_peers"] = sorted(
-                        list(set(self.comms.peers) - set(gather_result.skipped_uids))
-                    )
-                    debug_dict["skipped_peers"] = sorted(
-                        list(gather_result.skipped_uids)
-                    )
+                    # Handle both merged results and list of partials
+                    if isinstance(gather_result, list):
+                        # For partials, collect all UIDs that were successfully gathered
+                        successful_uids = set()
+                        for partial in gather_result:
+                            if partial is not None:
+                                successful_uids.update(getattr(partial, "uids", []))
+                        debug_dict["successful_peers"] = sorted(list(successful_uids))
+                        debug_dict["skipped_peers"] = sorted(list(set(self.comms.peers) - successful_uids))
+                    else:
+                        # Merged result has skipped_uids attribute
+                        debug_dict["successful_peers"] = sorted(
+                            list(set(self.comms.peers) - set(gather_result.skipped_uids))
+                        )
+                        debug_dict["skipped_peers"] = sorted(
+                            list(gather_result.skipped_uids)
+                        )
♻️ Duplicate comments (2)
src/tplr/neurons.py (2)

686-686: Fix inverted budget_exceeded flag logic.

The budget_exceeded flag is inverted: it's set to 0 when memory exceeds the budget and 1 when within budget. This will cause incorrect metrics and monitoring.

Apply this diff:

-            budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+            budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0

Based on learnings from past review comments (lines 686 flagged previously).


977-1006: Fix inverted budget_exceeded flag in single-result path.

The budget_exceeded flag has the same inversion issue as the incremental path (line 686): it's 0 when exceeded and 1 when within budget.

Apply this diff:

-        budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+        budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0

Based on learnings from past review comments (line 854 in old numbering).

🧹 Nitpick comments (1)
ecosystem.config.js (1)

5-11: Remove unused RANDOM_SUFFIX generation.

The RANDOM_SUFFIX is generated but no longer used after switching to a static PROJECT_NAME. Remove lines 5-9 to eliminate the dead code.

Apply this diff:

-const { execSync } = require('child_process');
-const RANDOM_SUFFIX = execSync(
-  "cat /dev/urandom | tr -dc 'a-z0-9' | fold -w 4 | head -n 1"
-)
-  .toString()
-  .trim();
-
 const PROJECT_NAME = `distribute-gather`;

Additionally, verify whether the static project name could cause conflicts if multiple test environments run concurrently.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0609f6a and 2aa28e3.

📒 Files selected for processing (6)
  • ecosystem.config.js (6 hunks)
  • hparams/hparams.json (1 hunks)
  • neurons/miner.py (1 hunks)
  • neurons/trainer.py (1 hunks)
  • src/tplr/model_factory.py (1 hunks)
  • src/tplr/neurons.py (9 hunks)
✅ Files skipped from review due to trivial changes (1)
  • hparams/hparams.json
🧰 Additional context used
🧬 Code graph analysis (2)
neurons/miner.py (3)
src/tplr/distributed.py (1)
  • is_distributed (75-77)
src/tplr/logging.py (1)
  • T (41-48)
src/tplr/comms.py (2)
  • gather_with_reserve (2338-2610)
  • gather (1686-2064)
src/tplr/neurons.py (2)
src/tplr/compress.py (5)
  • compress (373-377)
  • compress (379-383)
  • compress (386-423)
  • TopKCompressor (295-696)
  • unpack_12bit_indices (102-142)
src/tplr/comms.py (2)
  • merge_gather_results (2096-2209)
  • get (1467-1582)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.11)
  • GitHub Check: test (3.12)
🔇 Additional comments (12)
src/tplr/model_factory.py (1)

256-256: LGTM: Validator dp_shard constraint relaxation.

Changing dp_shard from 4 to 1 for validators removes divisibility requirements, allowing validators to run with any world size. This aligns with the distributed gather enhancements in this PR.

neurons/trainer.py (1)

1004-1009: LGTM: Proper integration of memory budgeting parameters.

The trainer correctly forwards the new memory_budget_mb and batch_size_override parameters to outer_step, using safe getattr with None defaults for backward compatibility.

neurons/miner.py (1)

705-762: LGTM: Distributed gather implementation.

The distributed gather path is well-structured with proper rank logging, error handling, and fallback to sequential mode on failure. The use of return_partials=True enables memory-bounded incremental processing.

src/tplr/neurons.py (6)

22-22: LGTM: Time import added for batch timing.

The time module import is used for batch processing timing measurements in the incremental path.


226-297: LGTM: Clean dispatch pattern for incremental vs single-result processing.

The refactored outer_step properly dispatches to incremental or single-result implementations based on whether gather_result is a list. The extended signature with memory_budget_mb and batch_size_override enables memory-bounded processing.


300-396: LGTM: Helper functions for memory-bounded batch processing.

The helper functions provide reasonable heuristics for memory estimation (~40MB per UID) and batch size calculation with a 0.7 safety factor. The merge helper correctly delegates to existing merge logic.


516-627: LGTM: Robust batch processing with graceful error recovery.

The batch processing loop properly:

  • Checks for duplicate UIDs
  • Handles exceptions gracefully without breaking the loop
  • Performs aggressive memory cleanup after each batch
  • Tracks failed batches for comprehensive error reporting

643-720: LGTM: Proper statistics aggregation across partials.

The fingerprint finalization correctly:

  • Computes final norms via sqrt of accumulated squared norms
  • Averages means across batches
  • Logs comprehensive summaries including failed batches
  • Integrates memory metrics with WandB

1749-1883: LGTM: Efficient partial support in overlap checking.

The refactored overlap logic properly:

  • Handles both merged and partial results without unnecessary merging
  • Uses O(1) uid_to_partial mapping to avoid O(n×m) complexity
  • Maintains original UID indices for consistent pair keys
  • Correctly unpacks 12-bit indices per UID

The O(n×m) optimization from past reviews is properly implemented.

ecosystem.config.js (3)

20-20: Verify virtual environment path across all deployment targets.

The script path now hardcodes .venv/bin/torchrun (also on lines 44, 68, 94). Ensure this path exists and is consistent across all environments where this configuration will be deployed.


64-87: LGTM!

The new TM3 miner app is properly configured and follows the same structure as TM1 and TM2, which aligns with the PR's goal of supporting more miners.


37-111: Remove unsubstantiated GPU testing claim; verify target environment has sufficient GPUs.

The ecosystem.config.js configuration allocates 7 GPUs (TM1: 0,1; TM2: 2,3; TM3: 4,5; TV1: 6). While this GPU requirement is accurately reflected in the code, the PR commit message does not mention testing with any specific GPU count. Ensure your target deployment environment has at least 7 GPUs available, or adjust the configuration to match your hardware.

Likely an incorrect or invalid review comment.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2aa28e3 and b8162c0.

📒 Files selected for processing (2)
  • hparams/2B.json (1 hunks)
  • neurons/miner.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • hparams/2B.json
🧰 Additional context used
🧬 Code graph analysis (1)
neurons/miner.py (3)
src/tplr/distributed.py (1)
  • is_distributed (75-77)
src/tplr/logging.py (1)
  • T (41-48)
src/tplr/comms.py (2)
  • gather_with_reserve (2338-2610)
  • gather (1686-2064)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.11)
  • GitHub Check: test (3.12)
🔇 Additional comments (2)
neurons/miner.py (2)

705-710: LGTM: Robust distributed-mode detection.

The three-way check (config flag, world_size, and runtime state) is defensive and correct.


803-805: LGTM: Proper synchronization.

Broadcasting should_update ensures all ranks agree before proceeding to outer_step, preventing divergence.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (3)
neurons/miner.py (1)

935-1045: Fix distributed gather_result attribute access.

When return_partials=True, gather_result is a list[SimpleNamespace]. The blocks starting at Line 937 and Line 978 still treat it as a merged namespace (gather_result.skipped_uids, gather_result.success_rate, etc.), so the first distributed gather will raise AttributeError on every rank. Merge the partials (e.g., via self.comms.merge_gather_results(...)) or add safe accessors before computing debug/metrics fields.

src/tplr/neurons.py (2)

686-686: Inverted budget_exceeded flag persists (duplicate issue).

This line still has the inverted logic flagged in previous reviews: the flag is set to 0 when budget is exceeded and 1 when within budget.

The same inverted flag is then logged to WandB at line 714, propagating incorrect metrics.


983-983: Inverted budget_exceeded flag persists (duplicate issue).

This line has the same inverted logic flagged in previous reviews.

The inverted flag is then logged to WandB at line 1002.

🧹 Nitpick comments (1)
src/tplr/neurons.py (1)

315-318: Consider making the memory estimation heuristic configurable.

The hard-coded 40MB per UID estimate may not fit all model sizes. Consider adding a hyperparameter like gradient_mb_per_uid to allow tuning without code changes.

Apply this diff to make it configurable:

 def _estimate_partial_memory_mb(partial: SimpleNamespace) -> float:
     """
     Estimate memory usage of a partial result in MB.
 
-    This is a heuristic based on the number of UIDs and typical gradient size.
+    This is a heuristic based on the number of UIDs and typical gradient size.
+    The estimate can be tuned via hparams.gradient_mb_per_uid if needed.
 
     Args:
         partial: Partial result from distributed gather
 
     Returns:
         Estimated memory in MB
     """
     if partial is None or not hasattr(partial, "uids"):
         return 0.0
 
-    # Rough heuristic: ~30-50MB per UID for typical gradients
-    # This can be tuned based on actual measurements
     num_uids = len(getattr(partial, "uids", []))
-    avg_mb_per_uid = 40.0  # Conservative estimate
+    # Default 40MB, can be overridden via hparams if needed
+    avg_mb_per_uid = 40.0
 
     return num_uids * avg_mb_per_uid

Then pass it as a parameter if the hparam exists.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b8162c0 and bb147b3.

📒 Files selected for processing (5)
  • hparams/2B.json (1 hunks)
  • neurons/miner.py (2 hunks)
  • neurons/trainer.py (1 hunks)
  • src/tplr/comms.py (13 hunks)
  • src/tplr/neurons.py (9 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • neurons/trainer.py
  • hparams/2B.json
🧰 Additional context used
🧬 Code graph analysis (3)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
  • Bucket (24-46)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • all_gather_object (307-325)
src/tplr/neurons.py (2)
src/tplr/compress.py (6)
  • compress (373-377)
  • compress (379-383)
  • compress (386-423)
  • ChunkingTransformer (145-292)
  • TopKCompressor (295-696)
  • unpack_12bit_indices (102-142)
src/tplr/comms.py (2)
  • merge_gather_results (2102-2215)
  • get (1473-1588)
neurons/miner.py (4)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/logging.py (1)
  • T (41-48)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • broadcast (273-282)
src/tplr/comms.py (2)
  • gather_with_reserve (2344-2616)
  • gather (1692-2070)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.12)
  • GitHub Check: test (3.11)
🔇 Additional comments (1)
src/tplr/neurons.py (1)

1749-1846: LGTM: Efficient partial handling implementation.

The O(1) uid_to_partial lookup (lines 1754-1764, 1822) successfully addresses the O(n×m) complexity concern from previous reviews. The dual-path logic cleanly handles both partials and merged results.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bb147b3 and 6ce2c6d.

📒 Files selected for processing (4)
  • hparams/hparams.json (1 hunks)
  • neurons/base_node.py (1 hunks)
  • neurons/miner.py (5 hunks)
  • src/tplr/comms.py (13 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
neurons/miner.py (4)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/logging.py (1)
  • T (41-48)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • broadcast (273-282)
src/tplr/comms.py (2)
  • gather_with_reserve (2369-2641)
  • gather (1714-2092)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
  • Bucket (24-46)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • all_gather_object (307-325)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.11)
  • GitHub Check: test (3.12)
🔇 Additional comments (2)
hparams/hparams.json (1)

22-22: Clarify the intent of changing the default model size from 70B to 2B.

This change updates the default model configuration from 70B to 2B, which is a significant architectural change. The PR also introduces a new hparams/2B.json file. Please clarify:

  1. Is this change intentional?
  2. Should the default remain 70B with 2B as an alternative configuration file (rather than replacing the default)?
  3. Will this break existing workflows, training jobs, or downstream code that assumes the 70B model?
  4. Are there backward compatibility concerns?
neurons/base_node.py (1)

61-62: Verified: Cooperative multiple inheritance correctly implemented.

The addition of super().__init__() at the end of BaseNode.__init__() properly completes the MRO chain for multiple inheritance scenarios. Verification confirms:

  • Miner(BaseNode, Trainer) calls super().__init__() at its end (line 197)
  • Validator(BaseNode, Trainer) calls super().__init__() at its end (line 377)
  • BaseNode now calls super().__init__() at line 62, delegating to Trainer
  • Trainer intentionally stops the chain (no super() call), serving as the terminus

This follows the proper cooperative multiple inheritance pattern in Python and integrates correctly with the existing subclass implementations.
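The cooperative chain described above can be illustrated in miniature. Class names mirror the review's description but the bodies are simplified stand-ins, not the project's real classes.

```python
class Trainer:
    def __init__(self):
        # Terminus of the chain: intentionally no super().__init__()
        self.calls.append("Trainer")

class BaseNode:
    def __init__(self):
        self.calls.append("BaseNode")
        super().__init__()  # delegates to the next class in the MRO: Trainer

class Miner(BaseNode, Trainer):
    def __init__(self):
        self.calls = []       # set up shared state before the chain runs
        super().__init__()    # enters BaseNode, which then reaches Trainer
        self.calls.append("Miner")

m = Miner()
```

With the MRO `Miner -> BaseNode -> Trainer -> object`, each `super().__init__()` call advances one step along the chain, so every base class initializes exactly once.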


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
neurons/miner.py (1)

707-745: LGTM: Sequential gather correctly addresses previous critical issues.

The sequential gather implementation properly resolves the issues flagged in previous reviews:

  • Returns a single SimpleNamespace instead of a list of partials, eliminating AttributeError risks
  • Includes local=False and stale_retention=100 to pull from remote storage (R2)
  • All attribute accesses on gather_result are properly guarded with None checks
  • Appropriate master-only execution with broadcast to other ranks

Optional: Defensive coding for expected_compressed_params

While the past review indicates the initialization issue was fixed via MRO chain repair, consider defensive coding to guard against potential AttributeError:

                         time_min=time_min,
                         time_max=time_max,
-                        expected_compressed_params=self.expected_compressed_params,
+                        expected_compressed_params=getattr(self, 'expected_compressed_params', None),
                     )

This ensures the code won't crash if the attribute is somehow missing, since gather() handles None gracefully (defaults to empty set).

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6ce2c6d and 8c12640.

📒 Files selected for processing (1)
  • neurons/miner.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
neurons/miner.py (3)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/logging.py (1)
  • T (41-48)
src/tplr/comms.py (1)
  • gather (1714-2092)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.12)
  • GitHub Check: test (3.11)
🔇 Additional comments (1)
neurons/miner.py (1)

468-469: LGTM: Commitment refresh ensures up-to-date bucket info.

Refreshing commitments after updating peers is the correct approach to ensure bucket information is current before the gather operation.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
neurons/validator.py (1)

3911-3933: Indent merged‑result branch (syntax error).

state_dict and subsequent lines must be inside the else block.

@@
-        else:
-            # Merged result: original logic
-        state_dict = gather_result.state_dict
-        if not state_dict:
-            raise ValueError("Must have gather_result.state_dict to compute norms")
+        else:
+            # Merged result: original logic
+            state_dict = gather_result.state_dict
+            if not state_dict:
+                raise ValueError("Must have gather_result.state_dict to compute norms")
@@
-        for n, p in self.model.named_parameters():
+            for n, p in self.model.named_parameters():
@@
-            vals = getattr(state_dict, vals_key, None)
-            quant_params = getattr(state_dict, quant_key, None)
+                vals = getattr(state_dict, vals_key, None)
+                quant_params = getattr(state_dict, quant_key, None)
@@
-            if vals is None:
-                continue
+                if vals is None:
+                    continue
@@
-            vals_f32 = self.compressor.maybe_dequantize_values(
-                vals, quant_params, p.device
-            )
+                vals_f32 = self.compressor.maybe_dequantize_values(
+                    vals, quant_params, p.device
+                )
@@
-            norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
-            clip_norm_dict[vals_key] = torch.median(norms)
+                norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
+                clip_norm_dict[vals_key] = torch.median(norms)
♻️ Duplicate comments (4)
neurons/validator.py (1)

3837-3910: Align quant_params with vals across partials (prevents IndexError).

If a partial has vals but missing/short quant_params, later dequantization can index past the end.

@@
-                if vals is not None:
-                    # vals is a list of tensors, one per peer in this partial
-                    # Move all tensors (including nested) to target device
-                    for v in vals:
-                        all_vals.append(move_to_device_recursive(v))
-                    # quant_params is also a list if present
-                    if quant_params is not None:
-                        for qp in quant_params:
-                            all_quant_params.append(move_to_device_recursive(qp))
+                if vals is not None:
+                    num_vals = len(vals)
+                    # Move all tensors (including nested) to target device
+                    for v in vals:
+                        all_vals.append(move_to_device_recursive(v))
+                    # Ensure quant_params list stays index‑aligned with vals
+                    if quant_params is not None and len(quant_params) >= num_vals:
+                        for qp in quant_params[:num_vals]:
+                            all_quant_params.append(move_to_device_recursive(qp))
+                    elif quant_params is not None and 0 < len(quant_params) < num_vals:
+                        for qp in quant_params:
+                            all_quant_params.append(move_to_device_recursive(qp))
+                        # pad with None for the remainder
+                        all_quant_params.extend([None] * (num_vals - len(quant_params)))
+                        tplr.logger.warning(f"Inconsistent quant_params length for {vals_key}; padding with None")
+                    else:
+                        # no quant_params for this partial → pad
+                        all_quant_params.extend([None] * num_vals)

Based on learnings

src/tplr/neurons.py (3)

697-705: Invert memory_budget_violation flag (1 = violation, 0 = OK).

Same inverted logic previously flagged; fix to avoid misleading monitoring.

-            budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+            budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0

1001-1021: Invert memory_budget_violation flag here as well.

Mirror the fix in the single‑result path.

-        budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+        budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0

300-339: Support SimpleNamespace and lists in partial memory estimator.

Current code only handles dict; partial.state_dict is often SimpleNamespace with list values.

@@
-    # Measure actual tensor sizes in state_dict
-    if hasattr(partial, "state_dict") and isinstance(partial.state_dict, dict):
-        for value in partial.state_dict.values():
-            if isinstance(value, torch.Tensor):
-                total_bytes += value.element_size() * value.nelement()
-            elif isinstance(value, dict):
-                # Handle nested dicts (e.g., compressed gradients)
-                for nested_val in value.values():
-                    if isinstance(nested_val, torch.Tensor):
-                        total_bytes += nested_val.element_size() * nested_val.nelement()
+    # Measure actual tensor sizes in state_dict (SimpleNamespace or dict)
+    if hasattr(partial, "state_dict"):
+        sd = partial.state_dict
+        values = (
+            vars(sd).values()
+            if isinstance(sd, SimpleNamespace)
+            else (sd.values() if isinstance(sd, dict) else [])
+        )
+        for value in values:
+            if isinstance(value, torch.Tensor):
+                total_bytes += value.element_size() * value.nelement()
+            elif isinstance(value, (list, tuple)):
+                for v in value:
+                    if isinstance(v, torch.Tensor):
+                        total_bytes += v.element_size() * v.nelement()
+            elif isinstance(value, dict):
+                for v in value.values():
+                    if isinstance(v, torch.Tensor):
+                        total_bytes += v.element_size() * v.nelement()
🧹 Nitpick comments (1)
src/tplr/comms.py (1)

2473-2527: Distributed gather with proper fallback logic.

The fallback logic correctly ensures all ranks fall back together (lines 2511-2527) to maintain consistency when distributed gather fails. This addresses the previous review comment about rank consistency.

The logging clearly distinguishes between distributed and sequential modes, and the error handling is comprehensive.

Consider adding a comment explaining the consistency requirement:

# Fallback to sequential if distributed failed
# ALL ranks must fall back together to maintain consistency
if primary is None:
    # Sequential gather can be done independently by each rank
    # since it fetches from R2, not inter-rank communication
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8c12640 and c608455.

📒 Files selected for processing (3)
  • neurons/validator.py (12 hunks)
  • src/tplr/comms.py (16 hunks)
  • src/tplr/neurons.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
  • Bucket (24-46)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • all_gather_object (307-325)
src/tplr/neurons.py (2)
src/tplr/compress.py (6)
  • compress (373-377)
  • compress (379-383)
  • compress (386-423)
  • ChunkingTransformer (145-292)
  • TopKCompressor (295-696)
  • unpack_12bit_indices (102-142)
src/tplr/comms.py (3)
  • Comms (63-3392)
  • merge_gather_results (2170-2283)
  • get (1499-1615)
neurons/validator.py (6)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • broadcast (273-282)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
neurons/base_node.py (1)
  • query_block_timestamp (205-250)
src/tplr/comms.py (2)
  • gather_with_reserve (2415-2687)
  • merge_gather_results (2170-2283)
src/tplr/neurons.py (1)
  • check_uid_index_overlap (1751-1956)
src/tplr/compress.py (1)
  • maybe_dequantize_values (646-696)
🪛 GitHub Actions: CI
neurons/validator.py

[error] 1457-1457: invalid-syntax: Expected an indented block after if statement


[error] 1465-1465: invalid-syntax: unindent does not match any outer indentation level


[error] 1466-1466: invalid-syntax: Expected a statement


[error] 1590-1591: invalid-syntax: Expected an indented block after if statement


[error] 1597-1599: invalid-syntax: Unexpected indentation around else block


[error] 3912-3914: invalid-syntax: Expected an indented block after else clause


[error] 4537-4537: Syntax issue detected at function definition

🔇 Additional comments (18)
src/tplr/comms.py (13)

48-48: Import moved to module level - good fix!

This addresses the previous review comment about moving dist_helper import from inside a function to module level, improving performance and readability.


478-577: Excellent implementation of s3_object_exists with structured error handling.

This method properly distinguishes between different error types using status codes and error codes instead of string matching. The conservative approach (returning True for transient errors) is appropriate to avoid false slashing. Good logging at ERROR level for configuration issues that need attention.


1465-1497: Good fix for gradient_timestamp signature.

Adding an optional version parameter that defaults to tplr.__version__ addresses the previous review comment about undefined self.version. This allows flexibility while maintaining backward compatibility.


703-717: Critical performance fix: offloading torch.load() to thread pool.

This prevents blocking the event loop during CPU decompression, which can take 1-2s per file and would otherwise serialize parallel downloads. The fix is correctly applied using asyncio.to_thread().
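The pattern generalizes beyond torch.load(); here is a self-contained sketch with a stand-in for the CPU-bound deserialization (the sleep and names are illustrative, not the actual tplr code):

```python
import asyncio
import time

def slow_load(path: str) -> str:
    # Stand-in for torch.load(): CPU-bound work that would otherwise
    # block the event loop and serialize concurrent downloads.
    time.sleep(0.05)
    return f"loaded:{path}"

async def fetch_all(paths: list[str]) -> list[str]:
    # asyncio.to_thread runs each load in a worker thread, so the
    # event loop stays free to drive the other coroutines in parallel.
    return await asyncio.gather(
        *(asyncio.to_thread(slow_load, p) for p in paths)
    )

results = asyncio.run(fetch_all(["a.pt", "b.pt", "c.pt"]))
# → ["loaded:a.pt", "loaded:b.pt", "loaded:c.pt"]
```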


1557-1558: torch.load() offloading also applied to local storage path.

Consistent with the remote storage fix, ensuring non-blocking behavior across all code paths.


1836-1867: Excellent optimization: parallel S3 existence checks for failed downloads.

Batching all existence checks and running them in parallel avoids the serial API calls that would add ~5s per failure. The cached results are properly reused later in the error handling logic. Good use of asyncio.gather with return_exceptions=True.
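The return_exceptions=True behavior can be sketched with a toy checker (hypothetical names; the real code heads S3 objects):

```python
import asyncio

async def exists(key: str) -> bool:
    if key == "bad":
        raise TimeoutError(key)  # simulate a transient network failure
    return True

async def batch_exists(keys: list[str]) -> list:
    # return_exceptions=True: one failing check does not cancel the
    # others; exceptions come back as values to be handled per-key.
    results = await asyncio.gather(
        *(exists(k) for k in keys), return_exceptions=True
    )
    return [None if isinstance(r, BaseException) else r for r in results]

out = asyncio.run(batch_exists(["a", "bad", "c"]))
# → [True, None, True]
```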


1869-1931: Well-structured error handling that prevents false slashing.

The code correctly distinguishes between network/timeout issues (file exists but download failed) and miner failures (file doesn't exist). Uses the cached existence check results efficiently and provides clear logging with appropriate levels.


2144-2167: Correct round-robin partitioning implementation.

The deterministic sorting ensures all ranks agree on the assignment, and the slicing approach sorted_uids[rank::world_size] correctly distributes UIDs evenly across ranks.


2198-2283: Comprehensive merging logic with proper device consolidation.

The method correctly:

  • Filters out None results
  • Recursively moves tensors to target device via move_to_device helper
  • Merges state_dict by concatenating lists for each parameter
  • Aggregates metrics (max time, sum bytes, extend arrays)
  • Calculates accurate success rate as len(all_uids) / total_attempted

Good logging of the merge summary.
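A minimal sketch of that merge shape, using hypothetical partial objects (the real merge_gather_results additionally handles device placement and metric aggregation):

```python
from types import SimpleNamespace

def merge_partials(partials, total_attempted):
    # Concatenate per-parameter lists from each non-None partial and
    # compute the success rate against the number of UIDs attempted.
    merged = {}
    all_uids = []
    for part in partials:
        if part is None:  # a rank may have gathered 0 peers
            continue
        all_uids.extend(part.uids)
        for key, vals in vars(part.state_dict).items():
            merged.setdefault(key, []).extend(vals)
    return SimpleNamespace(
        uids=all_uids,
        state_dict=SimpleNamespace(**merged),
        success_rate=len(all_uids) / total_attempted if total_attempted else 0.0,
    )

p0 = SimpleNamespace(uids=[1, 3], state_dict=SimpleNamespace(w_vals=[10, 30]))
p1 = SimpleNamespace(uids=[2], state_dict=SimpleNamespace(w_vals=[20]))
out = merge_partials([p0, None, p1], total_attempted=4)
```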


2285-2410: Well-implemented distributed gather with proper synchronization.

The method correctly:

  • Partitions UIDs deterministically across ranks for parallel fetching
  • Each rank independently gathers its assigned subset
  • Uses all_gather_object for barrier synchronization (correctly called synchronously)
  • Adds rank metadata to partials for deterministic ordering
  • Supports both incremental (return_partials=True) and merged modes

Good contextual logging with rank and window information throughout.


2552-2642: Reserve fallback logic correctly handles both partials and merged results.

The code properly:

  • Detects missing UIDs for both result types (lines 2552-2580)
  • Uses sequential gather for reserve peers (not distributed, which is correct)
  • Appends reserve as new partial with proper rank metadata (lines 2613-2625)
  • Merges reserve into merged result by extending arrays (lines 2627-2642)

Good contextual logging for both paths.


2647-2686: Success rate calculation and final validation are correct.

The method properly:

  • Calculates success rate relative to original gather_uids (not including reserves)
  • Handles both partials list and merged SimpleNamespace
  • Updates success_rate attribute on partials for consistency (lines 2654-2657)
  • Returns None if no gradients were gathered (lines 2671-2685)

The aggregate success rate calculation for partials (lines 2649-2652) is correct.


1661-1728: Good optimization: early termination for repeated NOT_FOUND errors.

Tracking consecutive NOT_FOUND errors and giving up after 5 attempts (2.5s total) prevents unnecessary retries when a file likely doesn't exist. The counter is properly reset on other error types that might be transient network issues.
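The reset-on-other-errors behavior can be modeled as a pure function over the error sequence (a hypothetical sketch, not the actual retry loop):

```python
def gives_up(errors: list[str], threshold: int = 5) -> bool:
    # Give up only after `threshold` consecutive NOT_FOUND results;
    # any other error type (e.g. a timeout) resets the streak, since
    # it might be a transient network issue rather than a missing file.
    streak = 0
    for err in errors:
        if err == "NOT_FOUND":
            streak += 1
            if streak >= threshold:
                return True
        else:
            streak = 0
    return False
```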

neurons/validator.py (2)

91-134: Helpers for UIDs/skips/success-rate look good.

Uniform handling across partials vs merged is clear and side‑effect free.


1391-1410: Confirm broadcast_object_list compatibility across ranks.

All ranks must supply lists of identical length/types; also large peer lists can bloat pickles. If size grows, consider broadcasting only IDs and reconstructing from metagraph.

Would you like a follow‑up patch to broadcast only UID tensors to cut overhead?

src/tplr/neurons.py (3)

416-451: Incremental batching pipeline reads well.

Batch sizing, merging via Comms.merge_gather_results, duplicate‑UID guard, and per‑batch cleanup are solid.

Consider logging per‑batch GPU mem deltas to validate budget sizing heuristics in production.


741-771: Single‑result path is efficient and DTensor‑aware.

Param‑at‑a‑time update, direct broadcast into p.grad, and fingerprinting are clean.

Also applies to: 789-899, 941-971


1753-1956: Overlap checker partials path avoids O(n×m) lookup—nice.

UID→partial map and consistent pair accounting improve scalability.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
neurons/validator.py (1)

3835-3933: Fix UnboundLocalError and align quant_params in compute_peer_val_norms()

When gather_result is a list, state_dict is undefined but used later, causing a crash. Also, vals and quant_params can become misaligned across partials.

Apply:

@@
-    def compute_peer_val_norms(
-        self,
-        gather_result: SimpleNamespace | list[SimpleNamespace],
-    ) -> dict[str, torch.Tensor]:
+    def compute_peer_val_norms(
+        self,
+        gather_result: SimpleNamespace | list[SimpleNamespace],
+    ) -> dict[str, torch.Tensor]:
@@
-        clip_norm_dict = {}
-        # Handle both merged results and partials
-        if isinstance(gather_result, list):
-            # Partials: collect vals from all partials for each parameter
-            # Use model's first parameter device as target (typically cuda:0)
-            target_device = next(self.model.parameters()).device
-            # Helper to recursively move all tensors to target device
-            def move_to_device_recursive(obj):
-                if torch.is_tensor(obj):
-                    return obj.to(target_device)
-                elif isinstance(obj, dict):
-                    return {k: move_to_device_recursive(v) for k, v in obj.items()}
-                elif isinstance(obj, (list, tuple)):
-                    return type(obj)(move_to_device_recursive(item) for item in obj)
-                else:
-                    return obj
-            for n, p in self.model.named_parameters():
-                vals_key = n + "vals"
-                quant_key = n + "quant_params"
-                # Collect vals and quant_params from all partials
-                all_vals = []
-                all_quant_params = []
-                for partial in gather_result:
-                    # Skip None partials (ranks that gathered 0 peers)
-                    if partial is None:
-                        continue
-                    state_dict = partial.state_dict
-                    if not state_dict:
-                        continue
-                    vals = getattr(state_dict, vals_key, None)
-                    quant_params = getattr(state_dict, quant_key, None)
-                    if vals is not None:
-                        # vals is a list of tensors, one per peer in this partial
-                        # Move all tensors (including nested) to target device
-                        for v in vals:
-                            all_vals.append(move_to_device_recursive(v))
-                        # quant_params is also a list if present
-                        if quant_params is not None:
-                            for qp in quant_params:
-                                all_quant_params.append(move_to_device_recursive(qp))
-                if not all_vals:
-                    continue
-                # Dequantize all collected values (now all on same device)
-                vals_f32 = self.compressor.maybe_dequantize_values(
-                    all_vals,
-                    all_quant_params if all_quant_params else None,
-                    target_device,
-                )
-                # Compute median norm across all peers from all partials
-                norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
-                clip_norm_dict[vals_key] = torch.median(norms)
-        else:
-            # Merged result: original logic
-            state_dict = gather_result.state_dict
-            if not state_dict:
-                raise ValueError("Must have gather_result.state_dict to compute norms")
-        for n, p in self.model.named_parameters():
-            vals_key = n + "vals"
-            quant_key = n + "quant_params"
-            vals = getattr(state_dict, vals_key, None)
-            quant_params = getattr(state_dict, quant_key, None)
-            if vals is None:
-                continue
-            vals_f32 = self.compressor.maybe_dequantize_values(
-                vals, quant_params, p.device
-            )
-            norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
-            clip_norm_dict[vals_key] = torch.median(norms)
-        return clip_norm_dict
+        clip_norm_dict: dict[str, torch.Tensor] = {}
+
+        # Common helper
+        target_device = next(self.model.parameters()).device
+        def move_to_device_recursive(obj):
+            if torch.is_tensor(obj):
+                return obj.to(target_device)
+            if isinstance(obj, dict):
+                return {k: move_to_device_recursive(v) for k, v in obj.items()}
+            if isinstance(obj, (list, tuple)):
+                return type(obj)(move_to_device_recursive(x) for x in obj)
+            return obj
+
+        if isinstance(gather_result, list):
+            # Partials: aggregate across partials
+            for n, p in self.model.named_parameters():
+                vals_key = n + "vals"
+                quant_key = n + "quant_params"
+                all_vals: list[torch.Tensor] = []
+                all_qp: list[QuantParamsT | None] = []
+                for partial in gather_result:
+                    if partial is None or not getattr(partial, "state_dict", None):
+                        continue
+                    state_dict = partial.state_dict
+                    vals = getattr(state_dict, vals_key, None)
+                    qps = getattr(state_dict, quant_key, None)
+                    if vals is None:
+                        continue
+                    # Append vals
+                    for v in vals:
+                        all_vals.append(move_to_device_recursive(v))
+                    # Align quant_params length with vals length
+                    num_vals = len(vals)
+                    if qps is None:
+                        all_qp.extend([None] * num_vals)
+                    else:
+                        qps_moved = [move_to_device_recursive(q) for q in qps[:num_vals]]
+                        if len(qps_moved) < num_vals:
+                            qps_moved.extend([None] * (num_vals - len(qps_moved)))
+                        all_qp.extend(qps_moved)
+                if not all_vals:
+                    continue
+                vals_f32 = self.compressor.maybe_dequantize_values(all_vals, all_qp, target_device)
+                norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
+                clip_norm_dict[vals_key] = torch.median(norms)
+            return clip_norm_dict
+
+        # Merged result path
+        state_dict = getattr(gather_result, "state_dict", None)
+        if not state_dict:
+            raise ValueError("Must have gather_result.state_dict to compute norms")
+        for n, p in self.model.named_parameters():
+            vals_key = n + "vals"
+            quant_key = n + "quant_params"
+            vals = getattr(state_dict, vals_key, None)
+            if vals is None:
+                continue
+            qps = getattr(state_dict, quant_key, None)
+            vals_f32 = self.compressor.maybe_dequantize_values(vals, qps, p.device)
+            norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
+            clip_norm_dict[vals_key] = torch.median(norms)
+        return clip_norm_dict

Based on learnings.

src/tplr/comms.py (1)

578-607: Docstring says bucket defaults to self.bucket, but code doesn't; add fallback.

Without a fallback, callers passing bucket=None will crash. Enforce the documented default before creating the client.

 async def s3_get_object(
     self,
     key: str,
-    bucket: Bucket | None = None,
+    bucket: Bucket | None = None,
     timeout: int = 30,
@@
-    s3_client = await self._get_s3_client(bucket)
+    if bucket is None:
+        bucket = self.bucket
+    s3_client = await self._get_s3_client(bucket)
♻️ Duplicate comments (3)
src/tplr/comms.py (1)

2191-2197: Guard torch.cuda.current_device() with try/except and CPU fallback.

current_device() can raise if no context is set. Use a safe fallback.

-        if target_device is None:
-            # Use the current device if CUDA is available, otherwise CPU
-            if torch.cuda.is_available():
-                target_device = f"cuda:{torch.cuda.current_device()}"
-            else:
-                target_device = "cpu"
+        if target_device is None:
+            if torch.cuda.is_available():
+                try:
+                    target_device = f"cuda:{torch.cuda.current_device()}"
+                except RuntimeError:
+                    target_device = "cuda:0" if torch.cuda.device_count() > 0 else "cpu"
+            else:
+                target_device = "cpu"
src/tplr/neurons.py (2)

501-504: Do not pass string device to CUDA APIs; use default or torch.device.

torch.cuda.memory_allocated()/synchronize() with "cuda:0" (str) can raise. Omit the arg or pass torch.device.

-    if memory_budget_mb is not None and torch.cuda.is_available():
-        initial_memory_mb = torch.cuda.memory_allocated(device) / (1024**2)
+    if memory_budget_mb is not None and torch.cuda.is_available():
+        initial_memory_mb = torch.cuda.memory_allocated() / (1024**2)
@@
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-                torch.cuda.synchronize(device)
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+                torch.cuda.synchronize()
@@
-            final_memory_mb = torch.cuda.memory_allocated(device) / (1024**2)
+            final_memory_mb = torch.cuda.memory_allocated() / (1024**2)
@@
-        final_memory_mb = torch.cuda.memory_allocated(device) / (1024**2)
+        final_memory_mb = torch.cuda.memory_allocated() / (1024**2)
@@
-            tplr.logger.warning(
+            tplr.logger.warning(
                 f"⚠️  GPU Memory: {final_memory_mb:.0f}MB / {memory_budget_mb:.0f}MB "
                 f"({utilization_pct:.0f}% - EXCEEDED by {final_memory_mb - memory_budget_mb:.0f}MB)"
             )
-        else:
+        else:
             tplr.logger.info(
                 f"GPU Memory: {final_memory_mb:.0f}MB / {memory_budget_mb:.0f}MB "
                 f"({utilization_pct:.0f}% utilized)"
             )

Also applies to: 617-623, 697-705, 996-1002, 1003-1014


704-705: Flip memory_budget_violation flag: 1 = violation, 0 = OK.

Current logic is inverted and misleads dashboards.

-            budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+            budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0

Also applies to: 1001-1002

🧹 Nitpick comments (7)
neurons/miner.py (1)

559-561: Guard CUDA-only calls (optional)

empty_cache/synchronize should be gated to CUDA to avoid surprises in CPU/testing.

-            torch.cuda.empty_cache()
-            torch.cuda.synchronize()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+                torch.cuda.synchronize(self.device)
neurons/validator.py (2)

138-203: Handle partial lists in offload/load helpers to actually free GPU memory

offload_gather_results/load_gather_results_to_device operate only on a merged SimpleNamespace. With distributed gather (partials), tensors remain on GPU.

@@ def offload_gather_results(self, gather_result, *, log: bool = True) -> None:
-        if gather_result is None:
+        if gather_result is None:
             return
+        # NEW: handle list of partials
+        if isinstance(gather_result, list):
+            count = 0
+            for partial in gather_result:
+                if partial is not None:
+                    self.offload_gather_results(partial, log=False)
+                    count += 1
+            if log and count:
+                tplr.log_with_context(
+                    level="info",
+                    message=f"[GatherOffload] Offloaded {count} partial(s) to CPU pinned",
+                    sync_window=self.sync_window,
+                    current_window=self.current_window,
+                )
+            return
@@ def load_gather_results_to_device(self, gather_result, device: torch.device, *, log: bool = True) -> None:
-        if gather_result is None:
+        if gather_result is None:
             return
+        # NEW: handle list of partials
+        if isinstance(gather_result, list):
+            count = 0
+            for partial in gather_result:
+                if partial is not None:
+                    self.load_gather_results_to_device(partial, device, log=False)
+                    count += 1
+            if log and count:
+                tplr.log_with_context(
+                    level="info",
+                    message=f"[GatherOffload] Loaded {count} partial(s) back to {device.type}",
+                    sync_window=self.sync_window,
+                    current_window=self.current_window,
+                )
+            return

1563-1574: Nit: pass device as string for consistency

gather_with_reserve accepts device used in map_location/to(). Passing str(self.device) matches miner and avoids type surprises with custom loaders.

-                    device=cast(str, self.device),
+                    device=str(self.device),
src/tplr/comms.py (3)

650-655: Prefer structured ClientError inspection over string matching in HEAD error path.

Replace "404" string checks with HTTPStatusCode/Error.Code parsing (consistent with s3_object_exists), and keep purging on client errors.

-            except (ConnectionClosedError, ClientError) as e:
-                await self._purge_s3_client(bucket)
-                if "404" in str(e):
-                    tplr.logger.debug(f"Object {key} not found in bucket {bucket.name}")
-                    return None
+            except (ConnectionClosedError, ClientError) as e:
+                await self._purge_s3_client(bucket)
+                status = None
+                code = None
+                if isinstance(e, ClientError) and hasattr(e, "response"):
+                    status = e.response.get("ResponseMetadata", {}).get("HTTPStatusCode")
+                    code = e.response.get("Error", {}).get("Code")
+                if status == 404 or code in {"404", "NoSuchKey", "NotFound"}:
+                    tplr.logger.debug(f"Object {key} not found in bucket {bucket.name}")
+                    return None
+                if status in {401, 403} or code in {"AccessDenied", "InvalidAccessKeyId", "SignatureDoesNotMatch"}:
+                    tplr.logger.error(f"Permission error checking {key} in {bucket.name}: {e}")
+                    return None
+                # For transient 5xx/network errors, let caller retry
+                tplr.logger.warning(f"Transient error on HEAD {key}: {e}")
+                return None

478-577: Purge broken S3 client on connection close to avoid reuse of a bad pool.

If the underlying connection is closed, purge the cached client so next call recreates it. Keeps semantics unchanged.

-        except (ConnectionClosedError, ClientError) as e:
+        except (ConnectionClosedError, ClientError) as e:
+            # Drop broken clients so future calls re-create a fresh one
+            if isinstance(e, ConnectionClosedError):
+                try:
+                    await self._purge_s3_client(bucket)
+                except Exception:
+                    pass

1850-1867: Throttle parallel existence checks to avoid overwhelming R2/S3.

Bound concurrency with a semaphore (reuse self.client_semaphore or a local one).

-                if failed_checks:
-                    existence_tasks = [
-                        self.s3_object_exists(key=check[1], bucket=check[2], timeout=5)
-                        for check in failed_checks
-                    ]
-                    existence_responses = await asyncio.gather(
-                        *existence_tasks, return_exceptions=True
-                    )
+                if failed_checks:
+                    sem = asyncio.Semaphore(min(32, len(failed_checks)))
+                    async def _checked(key, bucket):
+                        async with sem:
+                            return await self.s3_object_exists(key=key, bucket=bucket, timeout=5)
+                    existence_tasks = [_checked(check[1], check[2]) for check in failed_checks]
+                    existence_responses = await asyncio.gather(*existence_tasks, return_exceptions=True)
src/tplr/neurons.py (1)

452-456: Remove duplicate imports inside function.

gc and torch.cuda are already imported at module level; keep imports top-level.

-    import gc
-
-    import torch.cuda
+    # (imports are module‑level; no per‑call imports needed)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c608455 and 0cbc278.

📒 Files selected for processing (6)
  • hparams/2B.json (1 hunks)
  • neurons/miner.py (3 hunks)
  • neurons/validator.py (13 hunks)
  • src/tplr/comms.py (16 hunks)
  • src/tplr/compress.py (1 hunks)
  • src/tplr/neurons.py (9 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • hparams/2B.json
🧰 Additional context used
🧬 Code graph analysis (3)
src/tplr/neurons.py (2)
src/tplr/compress.py (5)
  • compress (373-377)
  • compress (379-383)
  • compress (386-423)
  • TopKCompressor (295-697)
  • unpack_12bit_indices (102-142)
src/tplr/comms.py (2)
  • merge_gather_results (2170-2283)
  • get (1499-1617)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
  • Bucket (24-46)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • all_gather_object (307-325)
neurons/miner.py (3)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/logging.py (1)
  • T (41-48)
src/tplr/comms.py (1)
  • gather (1736-2138)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.11)
  • GitHub Check: test (3.12)
🔇 Additional comments (2)
src/tplr/compress.py (1)

639-642: LGTM! Device alignment for multi-GPU consistency.

The addition of device alignment for shift mirrors the existing pattern for lookup and ensures all operands in the dequantization calculation (line 643) reside on the same device. This prevents device mismatch errors in multi-GPU scenarios and aligns with the PR's goal of automatic multi-GPU tensor consolidation.

neurons/miner.py (1)

716-731: Claimed AttributeError does not occur; the attribute is initialized via the MRO.

The review comment is incorrect. Miner inherits from BaseNode and Trainer, and calls super().__init__() at line 197. BaseNode.__init__ explicitly calls super().__init__() at line 62 to continue the MRO chain and initialize parent classes like Trainer. Trainer.__init__ initializes self.expected_compressed_params at line 123. The attribute is properly initialized through normal inheritance; no AttributeError will occur, and the proposed diff is unnecessary.

Likely an incorrect or invalid review comment.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
src/tplr/comms.py (1)

2379-2395: All-gather logic is correct, but logging could be clearer about synchronization point.

The synchronous call to dist_helper.all_gather_object (line 2388) is correctly documented as a blocking barrier operation. All ranks execute this regardless of their individual partial_result, which maintains distributed consistency.

Consider enhancing the log message at lines 2382-2386 to explicitly note this is a synchronization barrier, as developers debugging hangs might benefit from knowing this is an expected all-ranks rendezvous point.

             tplr.log_with_context(
                 level="info",
-                message=f"[DISTRIBUTED_GATHER] Rank {rank} performing all-gather of partial results",
+                message=f"[DISTRIBUTED_GATHER] Rank {rank} entering all-gather barrier (synchronization point for all ranks)",
                 window=window,
             )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 951b40b and 9502397.

📒 Files selected for processing (1)
  • src/tplr/comms.py (18 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
  • Bucket (24-46)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • all_gather_object (307-325)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.11)
  • GitHub Check: test (3.12)
🔇 Additional comments (8)
src/tplr/comms.py (8)

482-581: Well-implemented conservative existence check with structured error handling.

The method correctly prioritizes avoiding false negatives for slashing decisions. The structured ClientError inspection (lines 524-527) is a significant improvement over string matching, and the tiered error handling (404 → False, transient → True) aligns with the stated design goal.

One minor observation: line 511-514's asyncio.wait_for wrapper around head_object may leave the underlying HTTP connection in an indeterminate state if timeout fires mid-request. Consider whether the existing purge/retry logic in callers adequately handles this, or if explicit connection cleanup is needed on timeout.
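The tiered policy can be captured as a small classifier; this is a hypothetical helper under the stated design (definite absence → False, conservative True otherwise), while the real method inspects botocore's ClientError response dict:

```python
def object_probably_exists(status, code) -> bool:
    # Definite absence -> False; auth/config problems -> False (these
    # need operator attention rather than retries); anything transient
    # (5xx, network) -> True, conservatively assuming the object exists
    # so that a flaky check cannot trigger false slashing.
    if status == 404 or code in {"404", "NoSuchKey", "NotFound"}:
        return False
    if status in {401, 403} or code in {"AccessDenied", "InvalidAccessKeyId"}:
        return False
    return True
```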


707-721: Excellent fix: offloading torch.load prevents event loop blocking.

The use of asyncio.to_thread for CPU-bound decompression is the correct solution for parallel downloads, as noted in the comment. Both remote (lines 707-721) and local (lines 1561-1564) paths are now consistent.

The weights_only=True parameter is security-conscious. Verify that no legitimate gradient/checkpoint files in this system require custom unpicklers or arbitrary objects; if they do, this will fail silently (returning None on line 721) and may be difficult to debug.


1469-1490: LGTM: version parameter correctly defaults to tplr.version.

This addresses the previous review concern about undefined self.version. The optional parameter provides flexibility while maintaining backward compatibility.


1667-1738: Sensible optimization: early exit on repeated NOT_FOUND reduces wasted retries.

The 5-attempt threshold (2.5s total) provides a reasonable balance between accommodating transient issues and avoiding lengthy waits for genuinely missing files. The counter reset on non-NOT_FOUND errors (line 1733) correctly distinguishes persistent absence from intermittent network problems.

Verify that this early-exit logic plays nicely with the time_max grace period (lines 1680-1697). If a file is uploaded slightly late and arrives between attempt 4 and 5, the early exit might prevent a successful fetch that would have occurred on the next retry.
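The early-exit rule reduces to a small state machine: count consecutive NOT_FOUND outcomes, stop once the threshold is hit, and reset the streak on any other error so transient failures keep their full retry budget. A self-contained sketch (names are illustrative, not the tplr identifiers):

```python
MAX_CONSECUTIVE_NOT_FOUND = 5  # ~2.5s at a 0.5s retry interval

def should_keep_retrying(outcomes: list[str]) -> bool:
    """Given the retry outcomes so far, decide whether to keep trying."""
    consecutive_not_found = 0
    for outcome in outcomes:
        if outcome == "NOT_FOUND":
            consecutive_not_found += 1
            if consecutive_not_found >= MAX_CONSECUTIVE_NOT_FOUND:
                return False  # file is genuinely absent; stop retrying
        else:
            consecutive_not_found = 0  # transient error: reset the streak
    return True
```

Note how a single timeout interleaved among NOT_FOUNDs restarts the count, which is exactly the behavior the review highlights for line 1733.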


2148-2171: LGTM: clean, deterministic UID partitioning for distributed gather.

The round-robin approach (line 2171) is simple and provides good load balance. The sort (line 2168) ensures deterministic assignments across ranks, which is critical for correctness.
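The sort-then-stride idiom the review points at can be shown in a few lines. This is a standalone sketch mirroring the described `[rank::world_size]` slice, not the tplr implementation:

```python
def partition_uids(uids: list[int], rank: int, world_size: int) -> list[int]:
    """Deterministic round-robin shard: sort so every rank computes the
    same assignment, then take every world_size-th UID starting at rank."""
    if world_size <= 1:
        return sorted(uids)
    return sorted(uids)[rank::world_size]
```

Every UID lands on exactly one rank, shard sizes differ by at most one, and because the input is sorted first, all ranks agree on the partition without any communication.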


2537-2555: Fallback logic maintains consistency: all ranks perform sequential gather.

Lines 2537-2555 correctly implement the requirement from past review: when distributed gather fails, all ranks fall back to sequential gather (line 2550), not just the master. The if dist_helper.is_master: check (line 2540) only controls logging, not execution flow. This maintains consistency and avoids divergent state across ranks.

The comment at line 2548-2549 explicitly documents this design choice, which is helpful for future maintainers.


2633-2638: Reserve gather uses sequential mode: verify this is intentional.

Line 2633 explicitly uses self.gather() (sequential) rather than self.gather_distributed() for the reserve tier. If reserve peers are typically a small subset (lines 2622-2624 show [:len(missing)]), sequential mode may be acceptable. However, if the reserve pool is large and network latency is the bottleneck, using distributed mode for reserve could provide additional speedup.

Confirm this design choice aligns with expected reserve pool sizes and performance goals. If reserves are consistently small (e.g., <10 UIDs), sequential is fine; if they can be large, consider distributed mode for consistency.


1842-1908: Parallelization and slashing mechanism verified—no issues found.

The parallel existence checks (lines 1854–1870) correctly avoid serial bottlenecks. The caching strategy properly eliminates redundant API calls. The distinction between network failures (file exists but download failed, lines 1893–1900) and miner no-shows (file missing, lines 1901–1908) is sound and correctly implemented.

Verification confirms skipped_uids is the proper slashing trigger: it flows from gather_new_data() to slash_for_missing_gradients() (neurons/validator.py line 1769), which iterates through skipped UIDs and applies penalties. Network failures are correctly excluded (line 1900 ensures they bypass skipped_uids), preventing false slashing. No edge case exists where a UID escapes both valid and skipped tracking.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
src/tplr/comms.py (1)

2196-2211: Consider avoiding global device state mutation during device selection.

Line 2206 calls torch.cuda.set_device(0), which mutates global CUDA state as a side effect of probing device availability. A cleaner approach would validate device 0 without changing the current device:

         if torch.cuda.is_available():
             try:
                 target_device = f"cuda:{torch.cuda.current_device()}"
             except RuntimeError:
-                # Fall back to first available device or CPU
+                # Fall back to device 0 if available, else CPU
                 if torch.cuda.device_count() > 0:
-                    # Use device 0, but could be extended to find first available
-                    try:
-                        torch.cuda.set_device(0)
-                        target_device = "cuda:0"
-                    except RuntimeError:
-                        target_device = "cpu"
+                    # Validate device 0 exists without mutating state
+                    target_device = "cuda:0"
                 else:
                     target_device = "cpu"
         else:
             target_device = "cpu"

Note: Tensors will fail on .to(target_device) if the device doesn't exist, which is caught downstream.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9502397 and 085574a.

📒 Files selected for processing (1)
  • src/tplr/comms.py (18 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
  • Bucket (24-46)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • all_gather_object (307-325)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.12)
  • GitHub Check: test (3.11)
🔇 Additional comments (10)
src/tplr/comms.py (10)

28-28: LGTM: Imports and partition cache are well-designed.

The OrderedDict and dist_helper imports are properly used throughout the file. The _partition_cache field provides efficient caching of UID partitions with clear type annotations.

Also applies to: 49-49, 141-143


482-581: LGTM: Robust existence check with conservative error handling.

The structured error inspection (lines 521-527) and branching logic properly distinguish between missing objects (404/NoSuchKey), configuration errors (403/AccessDenied), and transient service errors (5xx). The conservative approach (returning True for transient errors) prevents false slashing.


707-721: LGTM: Critical fix for parallel download performance.

Offloading torch.load() to a thread pool (via asyncio.to_thread) prevents CPU-intensive decompression from blocking the event loop during parallel downloads. This enables true parallelism and should significantly improve throughput.

Also applies to: 1561-1564


1470-1490: LGTM: Version parameter properly defaults to tplr.__version__.

The version parameter with runtime defaulting addresses previous concerns about undefined self.version. This provides flexibility for version-specific queries while maintaining backward compatibility.


1667-1734: LGTM: Smart early exit for non-existent files.

The consecutive NOT_FOUND tracking (lines 1667-1668, 1723-1734) provides intelligent early termination after 5 consecutive failures (~2.5s), avoiding wasted retries while resetting the counter for transient errors. This balances robustness with efficiency.


1841-1935: LGTM: Parallel S3 existence checks provide major performance improvement.

The parallel execution of S3 HEAD requests (lines 1854-1870) is a significant optimization that should eliminate the ~5s per-failure overhead mentioned in the PR objectives. The logic correctly distinguishes between network failures (file exists but download failed → don't slash) and missing uploads (file doesn't exist → legitimate slash).


2148-2172: LGTM: Clean round-robin UID partitioning.

The deterministic partitioning logic is straightforward and correct. Sorting UIDs (line 2168) ensures consistency across ranks, and the round-robin slice notation [rank::world_size] (line 2171) evenly distributes the load.


2213-2297: LGTM: Comprehensive merging logic with proper device consolidation.

The recursive move_to_device helper (lines 2240-2251) and state dict merging (lines 2253-2267) correctly consolidate partial results onto a single device. Metrics accumulation (lines 2269-2290) and success rate calculation (lines 2278-2280) properly aggregate statistics across all partials.
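The traversal the review describes can be sketched independently of torch: walk nested dicts/lists/tuples and apply a move function at each leaf. Here `move` abstracts `tensor.to(device)` as a plain callable so the recursion logic is testable without a GPU; the structure-handling mirrors what the review attributes to lines 2240-2251.

```python
def move_to_device_recursive(obj, move):
    """Apply `move` to every leaf of a nested container, preserving shape.
    In the real code the leaves are tensors and `move` is .to(device)."""
    if isinstance(obj, dict):
        return {k: move_to_device_recursive(v, move) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        moved = [move_to_device_recursive(v, move) for v in obj]
        return tuple(moved) if isinstance(obj, tuple) else moved
    return move(obj)  # leaf
```

Preserving list-vs-tuple types matters because downstream merging concatenates per-peer value lists positionally.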


2299-2448: LGTM: Well-structured distributed gather with efficient caching.

The partition caching (lines 2334-2349) with LRU eviction, per-rank gathering (lines 2351-2368), and all-gather synchronization (lines 2388-2404) form a robust distributed gather implementation. The dual return path (lines 2406-2447) supports both incremental processing (partials) and traditional merged results, with rank metadata (lines 2426-2429) enabling deterministic downstream processing.


2452-2724: LGTM: Distributed gather integration with proper fallback consistency.

The distributed gather integration (lines 2509-2575) correctly determines when to use distributed mode and handles exceptions gracefully. The fallback logic (lines 2546-2564) ensures all ranks fall back together, maintaining consistency as noted in previous reviews. The dual-mode handling (lines 2589-2723) properly supports both merged results and partials, with reserve tier integration working correctly for both paths.

@shivam-MBZUAI shivam-MBZUAI changed the base branch from main to dev November 9, 2025 06:50
@shivam-MBZUAI shivam-MBZUAI force-pushed the feat/distribute_gather branch from 085574a to 146a54c Compare November 9, 2025 06:56

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
neurons/validator.py (1)

3837-3933: Critical: UnboundLocalError and qparams/vals misalignment in compute_peer_val_norms.

  • For list (partials) path, state_dict is undefined but used later → UnboundLocalError.
  • all_vals/all_quant_params can become misaligned when some partials miss quant_params, causing dequant IndexError or norms on uint8.

Apply this fix:

  • Return early after computing norms for partials.
  • Keep qparams aligned with vals (pad/truncate).
  • Warn on inconsistencies.
 def compute_peer_val_norms(
     self,
-    gather_result: SimpleNamespace | list[SimpleNamespace],
+    gather_result: SimpleNamespace | list[SimpleNamespace],
 ) -> dict[str, torch.Tensor]:
@@
-    # Handle both merged results and partials
-    if isinstance(gather_result, list):
+    # Handle both merged results and partials
+    if isinstance(gather_result, list):
@@
-        for n, p in self.model.named_parameters():
+        for n, p in self.model.named_parameters():
             vals_key = n + "vals"
             quant_key = n + "quant_params"
@@
-            for partial in gather_result:
+            for partial in gather_result:
                 # Skip None partials (ranks that gathered 0 peers)
                 if partial is None:
                     continue
@@
-                vals = getattr(state_dict, vals_key, None)
-                quant_params = getattr(state_dict, quant_key, None)
+                vals = getattr(state_dict, vals_key, None)
+                quant_params = getattr(state_dict, quant_key, None)
@@
-                if vals is not None:
-                    # vals is a list of tensors, one per peer in this partial
-                    # Move all tensors (including nested) to target device
-                    for v in vals:
-                        all_vals.append(move_to_device_recursive(v))
-                    # quant_params is also a list if present
-                    if quant_params is not None:
-                        for qp in quant_params:
-                            all_quant_params.append(move_to_device_recursive(qp))
+                if vals is not None:
+                    # vals: list[Tensor] per peer in this partial
+                    num_vals = len(vals)
+                    for v in vals:
+                        all_vals.append(move_to_device_recursive(v))
+                    # Align quant_params with vals length
+                    if quant_params is None:
+                        all_quant_params.extend([None] * num_vals)
+                    else:
+                        qps = [move_to_device_recursive(qp) for qp in quant_params]
+                        if len(qps) < num_vals:
+                            # pad with None to keep indices aligned
+                            qps.extend([None] * (num_vals - len(qps)))
+                            tplr.logger.warning(
+                                f"Inconsistent quant_params for {vals_key}; padding {num_vals - len(quant_params)} None(s)"
+                            )
+                        elif len(qps) > num_vals:
+                            qps = qps[:num_vals]
+                        all_quant_params.extend(qps)
@@
-            if not all_vals:
-                continue
+            if not all_vals:
+                continue
@@
-            vals_f32 = self.compressor.maybe_dequantize_values(
-                all_vals,
-                all_quant_params if all_quant_params else None,
-                target_device,
-            )
+            vals_f32 = self.compressor.maybe_dequantize_values(
+                all_vals, all_quant_params, target_device
+            )
@@
-            norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
+            norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
             clip_norm_dict[vals_key] = torch.median(norms)
-        else:
-            # Merged result: original logic
-            state_dict = gather_result.state_dict
-            if not state_dict:
-                raise ValueError("Must have gather_result.state_dict to compute norms")
+        # Early‑return for partials; avoid falling through to merged path
+        return clip_norm_dict
+    else:
+        # Merged result: original logic
+        state_dict = gather_result.state_dict
+        if not state_dict:
+            raise ValueError("Must have gather_result.state_dict to compute norms")
@@
-    for n, p in self.model.named_parameters():
+    for n, p in self.model.named_parameters():
         vals_key = n + "vals"
         quant_key = n + "quant_params"
@@
-    return clip_norm_dict
+    return clip_norm_dict

This removes the UnboundLocalError and keeps dequantization robust against inconsistent partials.

src/tplr/comms.py (2)

21-21: Fix missing imports used in MPU and network error handling.

base64 (MD5 header), EndpointConnectionError, and aiohttp exceptions are referenced but never imported; will raise at runtime on failures.

Apply this diff near the top import block:

@@
 import json
+import base64
 import math
@@
-from botocore.exceptions import ClientError, ConnectionClosedError
+from botocore.exceptions import ClientError, ConnectionClosedError, EndpointConnectionError
+from aiohttp.client_exceptions import ClientOSError, ServerDisconnectedError

Also applies to: 37-43


582-592: Normalize optional bucket parameter in s3_get_object.

Doc says it defaults to self.bucket, but code uses it directly and may be None.

Apply:

 async def s3_get_object(
@@
-        bucket: Bucket | None = None,
+        bucket: Bucket | None = None,
@@
-    ):
+    ):
@@
-        """Download object from S3 using asynchronous streaming.
+        """Download object from S3 using asynchronous streaming.
@@
-        s3_client = await self._get_s3_client(bucket)
+        # Default to instance bucket unless explicitly provided
+        if bucket is None:
+            bucket = self.bucket
+        s3_client = await self._get_s3_client(bucket)
♻️ Duplicate comments (1)
src/tplr/neurons.py (1)

734-742: Fix inverted memory_budget_violation flag (1 should indicate violation).

Currently sets 0 when exceeding and 1 otherwise; invert in both incremental and single paths.

@@
-            budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+            budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0
@@
-        budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+        budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0

Also applies to: 1034-1040

🧹 Nitpick comments (8)
neurons/validator.py (3)

95-106: Deduplicate UIDs and base success rate on unique attempts.

Across partials, a UID could appear twice (fallbacks/retries). Use sets to avoid inflation.

 def _get_uids_from_gather_result(gather_result) -> list[int]:
@@
-    if isinstance(gather_result, list):
-        all_uids = []
-        for partial in gather_result:
-            if partial is not None:
-                all_uids.extend(getattr(partial, "uids", []))
-        return all_uids
+    if isinstance(gather_result, list):
+        uids = []
+        for partial in gather_result:
+            if partial is not None:
+                uids.extend(getattr(partial, "uids", []))
+        return sorted(set(uids))
@@
 def _get_success_rate_from_gather_result(gather_result) -> float:
@@
-    if isinstance(gather_result, list):
-        total_uids = len(Validator._get_uids_from_gather_result(gather_result))
-        total_skipped = len(
-            Validator._get_skipped_uids_from_gather_result(gather_result)
-        )
-        total_attempted = total_uids + total_skipped
-        if total_attempted == 0:
-            return 0.0
-        return total_uids / total_attempted
+    if isinstance(gather_result, list):
+        unique_uids = set(Validator._get_uids_from_gather_result(gather_result))
+        unique_skipped = set(
+            Validator._get_skipped_uids_from_gather_result(gather_result)
+        )
+        total_attempted = len(unique_uids | unique_skipped)
+        return (len(unique_uids) / total_attempted) if total_attempted else 0.0

Also applies to: 120-134


1556-1574: Pass device as string, not via cast().

cast() doesn’t convert at runtime; prefer explicit str for consistency with Comms API.

-                    device=cast(str, self.device),
+                    device=str(self.device),

1684-1689: Guard offload for partials.

offload_gather_results expects a merged result; if gather_result is a list, it no-ops. Avoid the call for clarity, or offload each partial.

-                self.offload_gather_results(gather_result, log=True)
+                if not isinstance(gather_result, list):
+                    self.offload_gather_results(gather_result, log=True)
src/tplr/comms.py (2)

655-661: Prefer structured error inspection in HEAD path (avoid brittle string checks).

Use e.response['Error']['Code'] and HTTP status like in s3_object_exists; also guard bucket before purge.

Apply:

-            except (ConnectionClosedError, ClientError) as e:
-                await self._purge_s3_client(bucket)
-                if "404" in str(e):
-                    tplr.logger.debug(f"Object {key} not found in bucket {bucket.name}")
-                    return None
+            except (ConnectionClosedError, ClientError) as e:
+                if bucket is not None:
+                    await self._purge_s3_client(bucket)
+                status = None
+                code = None
+                if isinstance(e, ClientError) and hasattr(e, "response"):
+                    status = e.response.get("ResponseMetadata", {}).get("HTTPStatusCode")
+                    code = e.response.get("Error", {}).get("Code")
+                if status == 404 or code in {"404", "NoSuchKey", "NotFound"}:
+                    tplr.logger.debug(f"Object {key} not found in bucket {bucket.name if bucket else '<?>'}")
+                    return None
+                if status in {401, 403} or code in {"AccessDenied", "Forbidden"}:
+                    tplr.logger.error(f"Permission error checking {key}: {e}")
+                    return None
+                tplr.logger.warning(f"HEAD error for {key}: {e}")
+                return None

2196-2211: Avoid mutating global CUDA device in merge_gather_results.

Calling torch.cuda.set_device(0) here is a side effect; choose a device string without changing global state.

-            if torch.cuda.is_available():
-                try:
-                    target_device = f"cuda:{torch.cuda.current_device()}"
-                except RuntimeError:
-                    # Fall back to first available device or CPU
-                    target_device = "cpu"
-                    if torch.cuda.device_count() > 0:
-                        # Use device 0, but could be extended to find first available
-                        try:
-                            torch.cuda.set_device(0)
-                            target_device = "cuda:0"
-                        except RuntimeError:
-                            target_device = "cpu"
+            if torch.cuda.is_available():
+                try:
+                    target_device = f"cuda:{torch.cuda.current_device()}"
+                except RuntimeError:
+                    target_device = "cuda:0" if torch.cuda.device_count() > 0 else "cpu"
src/tplr/neurons.py (3)

557-563: Don’t skip the entire batch on duplicate UIDs; prune duplicates instead.

Skipping the whole batch wastes valid gradients. Drop only the partials containing already-applied UIDs.

-        duplicates = applied_uids & batch_uids
-        if duplicates:
-            tplr.logger.warning(
-                f"Skipping batch {batch_idx + 1}: {len(duplicates)} duplicate UIDs"
-            )
-            continue
+        duplicates = applied_uids & batch_uids
+        if duplicates:
+            before = len(batch_partials)
+            batch_partials = [
+                p for p in batch_partials
+                if not (set(getattr(p, "uids", [])) & applied_uids)
+            ]
+            pruned = before - len(batch_partials)
+            if pruned > 0:
+                tplr.logger.warning(
+                    f"Pruned {pruned} partial(s) with duplicate UIDs; continuing with {len(batch_partials)}"
+                )
+            if not batch_partials:
+                continue

482-485: Remove redundant inner imports.

gc is already imported at module top, and the torch.cuda namespace is reachable through the existing top-level torch import, so the separate import torch.cuda is unnecessary.

-    import gc
-
-    import torch.cuda
+    # (redundant imports removed; top-level gc and torch already cover these)

1805-1823: Minor: avoid building all_state_dicts for partials path.

For partials, you re-derive state_dict via uid_to_partial; the all_state_dicts.append(...) list is unused and increases memory.

Replace the partials branch to only build all_uids and uid_to_partial; defer state_dict access to the main loop. No behavior change.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 085574a and 146a54c.

📒 Files selected for processing (9)
  • hparams/2B.json (1 hunks)
  • hparams/hparams.json (1 hunks)
  • neurons/base_node.py (1 hunks)
  • neurons/miner.py (3 hunks)
  • neurons/trainer.py (1 hunks)
  • neurons/validator.py (13 hunks)
  • src/tplr/comms.py (18 hunks)
  • src/tplr/compress.py (1 hunks)
  • src/tplr/neurons.py (12 hunks)
🚧 Files skipped from review as they are similar to previous changes (5)
  • src/tplr/compress.py
  • hparams/hparams.json
  • neurons/trainer.py
  • neurons/base_node.py
  • hparams/2B.json
🧰 Additional context used
🧬 Code graph analysis (4)
neurons/validator.py (5)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • broadcast (273-282)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/comms.py (2)
  • gather_with_reserve (2452-2724)
  • merge_gather_results (2174-2297)
src/tplr/neurons.py (1)
  • check_uid_index_overlap (1789-1994)
src/tplr/compress.py (1)
  • maybe_dequantize_values (647-697)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
  • Bucket (24-46)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • all_gather_object (307-325)
src/tplr/neurons.py (2)
src/tplr/compress.py (6)
  • compress (373-377)
  • compress (379-383)
  • compress (386-423)
  • ChunkingTransformer (145-292)
  • TopKCompressor (295-697)
  • unpack_12bit_indices (102-142)
src/tplr/comms.py (3)
  • Comms (64-3429)
  • merge_gather_results (2174-2297)
  • get (1503-1621)
neurons/miner.py (3)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/logging.py (1)
  • T (41-48)
src/tplr/comms.py (1)
  • gather (1740-2142)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.11)
  • GitHub Check: test (3.12)
🔇 Additional comments (8)
neurons/miner.py (3)

468-470: Good call refreshing commitments before gather.

Ensures R2 existence checks use current bucket info.


707-731: Sequential gather params look correct and production-safe.

Passing local=False and stale_retention=100 avoids local FS misses; expected_compressed_params forwarded.


733-744: Clear, partial-aware logging.

Logs cover successes, skips, rate, and timing; guarded for None.

neurons/validator.py (4)

1392-1410: Peers broadcast across ranks looks right.

Using broadcast_object_list aligns shards before distributed gather.


1452-1479: Timestamp: master query + broadcast is correct.

Prevents rank skew in time window selection.


1659-1670: Overlap check handles partials/merged.

Good reuse of a single API for both paths.


975-981: Rating initialization before OpenSkill.rate is correct.

Prevents KeyErrors for new peers.

src/tplr/comms.py (1)

708-717: Good fix: offload torch.load to a thread.

Avoids blocking the event loop on CPU decompression. Solid improvement.

@shivam-MBZUAI shivam-MBZUAI changed the title feat: Distributed Gather feat: Distributed Gather with Incremental Processing Nov 9, 2025

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (2)
neurons/miner.py (1)

469-481: Consider running get_commitments() only on master rank to reduce chain queries.

Line 470 calls get_commitments() on all ranks when only the master rank needs the updated commitments to broadcast. The non-master ranks receive the peer list via broadcast and don't use commitments at this point.

Apply this diff to reduce redundant chain queries:

         peer_start = tplr.T()
         if self.is_master:
             await tplr.neurons.update_peers(
                 instance=self, window=step_window, peer_start=peer_start
             )
-        # Refresh commitments to get updated bucket info for all peers
-        self.comms.commitments = await self.comms.get_commitments()
+            # Refresh commitments to get updated bucket info for all peers
+            self.comms.commitments = await self.comms.get_commitments()

         # Broadcast peer list from master to all ranks for distributed gather
         if dist_helper.world_size > 1 and dist_helper.is_distributed():
src/tplr/comms.py (1)

2196-2210: Device selection fallback could fail when device 0 is unavailable.

The fallback logic (lines 2200-2209) assumes cuda:0 exists when device_count() > 0, but device 0 might be unavailable on some systems. While torch.cuda.set_device(0) is wrapped in try/except (lines 2205-2209), a more robust approach would probe available devices sequentially.

Consider this alternative that finds the first usable device:

             try:
                 target_device = f"cuda:{torch.cuda.current_device()}"
             except RuntimeError:
-                # Fall back to first available device or CPU
                 target_device = "cpu"
-                if torch.cuda.device_count() > 0:
-                    # Use device 0, but could be extended to find first available
-                    try:
-                        torch.cuda.set_device(0)
-                        target_device = "cuda:0"
-                    except RuntimeError:
-                        target_device = "cpu"
+                # Probe for first available CUDA device
+                for device_id in range(torch.cuda.device_count()):
+                    try:
+                        torch.cuda.set_device(device_id)
+                        target_device = f"cuda:{device_id}"
+                        break
+                    except RuntimeError:
+                        continue
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 146a54c and c3b1627.

📒 Files selected for processing (2)
  • neurons/miner.py (6 hunks)
  • src/tplr/comms.py (18 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
  • Bucket (24-46)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • all_gather_object (307-325)
neurons/miner.py (3)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/distributed.py (1)
  • is_distributed (75-77)
src/tplr/comms.py (1)
  • gather_with_reserve (2439-2713)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.12)
  • GitHub Check: test (3.11)
🔇 Additional comments (18)
neurons/miner.py (3)

762-792: LGTM!

The dual-path handling for partials (list) vs. merged results (SimpleNamespace) is implemented correctly. The code safely extracts UIDs and skipped UIDs using getattr with defaults, aggregates metrics across partials, and logs appropriate per-rank context for distributed mode.


885-904: LGTM!

The debug dictionary construction correctly handles both partials (list) and merged results (SimpleNamespace). The use of sets for deduplication when computing successful_peers and skipped_peers is appropriate and prevents duplicate UIDs in the debug output.


940-963: LGTM!

The metrics calculation correctly aggregates statistics from both partials (list) and merged results (SimpleNamespace). The success rate computation accounts for both paths, and the extraction of skipped_uids_list uses safe defaults.

src/tplr/comms.py (15)

516-519: Timeout handling is conservative and appropriate.

Returning True on asyncio.TimeoutError (line 516) follows the conservative approach for transient network issues, which is correct for preventing false slashing. The debug-level log is suitable for a timeout that's likely transient.


707-721: LGTM!

Offloading torch.load() to a thread pool (line 711) prevents blocking the event loop during CPU decompression, which is critical for parallel downloads. The error handling appropriately logs failures and returns None to signal the issue to callers.


1469-1490: LGTM!

The version parameter now defaults to None and is normalized to tplr.__version__ when not provided (lines 1489-1490). This is a clean API design that allows callers to omit the parameter while maintaining backwards compatibility.


1561-1564: LGTM!

Offloading torch.load() for local files (line 1562) maintains consistency with the remote loading path and prevents event loop blocking even for local storage. This is a good defensive practice.


1667-1734: LGTM!

The consecutive NOT_FOUND tracking (lines 1667-1668, 1723-1734) is a smart optimization that prevents long retry loops when a file clearly doesn't exist. The threshold of 5 consecutive failures (2.5 seconds) is reasonable, and resetting the counter on other error types correctly distinguishes transient network issues from permanent absence.


1842-1871: LGTM!

Parallelizing S3 existence checks (lines 1854-1870) is an excellent optimization that avoids serial API calls. Running all checks concurrently with asyncio.gather and caching results in existence_results prevents the 5-second-per-failure delay mentioned in the comment. Treating exceptions as False (line 1865) is conservative and appropriate.
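The fan-out pattern being praised here is straightforward to sketch: launch every probe concurrently with `asyncio.gather(..., return_exceptions=True)` so one failed HEAD request neither aborts the batch nor serializes the rest. The `probe` coroutine below is a stand-in for an S3 HEAD call; UIDs and outcomes are invented for illustration.

```python
import asyncio

async def check_all(uids, probe):
    """Run all existence probes concurrently; exceptions and non-True
    results are conservatively recorded as 'not found'."""
    results = await asyncio.gather(
        *(probe(uid) for uid in uids), return_exceptions=True
    )
    return {uid: (res is True) for uid, res in zip(uids, results)}

async def main():
    async def probe(uid):
        await asyncio.sleep(0)  # stands in for the network round-trip
        if uid == 13:
            raise TimeoutError("probe failed")
        return uid % 2 == 0  # pretend even UIDs have uploaded

    return await check_all([2, 3, 13], probe)

existence = asyncio.run(main())
```

With N peers this costs one round-trip instead of N, which is where the claimed elimination of the ~5s per-failure overhead comes from.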


1873-1935: LGTM!

The failure attribution logic correctly distinguishes between network issues and miner failures using the cached existence checks from lines 1892 and 1921. Network failures (file exists but download failed) don't add to skipped_uids to prevent false slashing (lines 1893-1900, 1922-1928), while genuine miner failures (file doesn't exist) are correctly flagged for slashing (lines 1902-1908, 1931-1934). The logging levels are appropriate for each case.


2147-2171: LGTM!

The partition_uids_across_ranks method implements deterministic round-robin UID assignment (line 2171) with proper sorting (line 2168) to ensure all ranks agree on the partition. The early return for world_size <= 1 (lines 2164-2165) is a sensible optimization.
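The round-robin scheme reduces to a single strided slice once UIDs are sorted; this sketch mirrors the described behavior but is not the actual method body:

```python
def partition_uids_across_ranks(uids, rank, world_size):
    """Deterministic round-robin split: every rank computes the same
    partition from the same sorted UID list."""
    if world_size <= 1:
        return list(uids)  # single rank handles everything
    ordered = sorted(uids)            # all ranks must agree on ordering
    return ordered[rank::world_size]  # round-robin assignment

assert partition_uids_across_ranks([5, 1, 9, 3], rank=0, world_size=2) == [1, 5]
assert partition_uids_across_ranks([5, 1, 9, 3], rank=1, world_size=2) == [3, 9]
```

Because the split depends only on the sorted UID list and `world_size`, every rank can compute its own slice locally with no coordination, and the slices are disjoint and cover all UIDs.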


2240-2297: LGTM!

The merge_gather_results implementation is thorough and correct:

  • The recursive move_to_device helper (lines 2240-2251) handles all common Python data structures
  • State dict merging (lines 2254-2267) correctly concatenates parameter lists while moving tensors to the target device
  • Metrics aggregation (lines 2270-2290) uses appropriate operations: max for time (parallel execution), sum for bytes, extend for lists
  • The aggregate success rate calculation (lines 2278-2280) correctly divides total successful by total attempted
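The aggregation rules above (max for time, sum for bytes, extend for lists) can be sketched on plain dicts; the field names here are illustrative, not the actual metric keys:

```python
def merge_metrics(partials):
    """Aggregate per-rank gather metrics: max for wall time (ranks ran in
    parallel), sum for bytes, concatenation for UID lists."""
    merged = {
        "time_s": max(p["time_s"] for p in partials),  # parallel: slowest rank wins
        "bytes": sum(p["bytes"] for p in partials),
        "success_uids": [],
        "skipped_uids": [],
        "attempted": 0,
    }
    for p in partials:
        merged["success_uids"].extend(p["success_uids"])
        merged["skipped_uids"].extend(p["skipped_uids"])
        merged["attempted"] += p["attempted"]
    merged["success_rate"] = (
        len(merged["success_uids"]) / merged["attempted"]
        if merged["attempted"] else 0.0
    )
    return merged

p1 = {"time_s": 2.0, "bytes": 100, "success_uids": [1, 2], "skipped_uids": [], "attempted": 3}
p2 = {"time_s": 3.0, "bytes": 50, "success_uids": [4], "skipped_uids": [5], "attempted": 2}
merged = merge_metrics([p1, p2])
assert merged["time_s"] == 3.0 and merged["bytes"] == 150
assert merged["success_rate"] == 0.6  # 3 successes over 5 attempts
```

Using max rather than sum for time matters: ranks download concurrently, so total wall time is bounded by the slowest rank, not the sum of all ranks.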

2334-2349: LGTM!

The partition caching strategy (lines 2334-2349) is well-designed:

  • Cache key includes window, sorted UIDs, and world_size for deterministic lookups
  • LRU eviction maintains a bounded cache size (32 entries)
  • Moving accessed entries to the end (line 2340) implements LRU correctly
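The caching scheme maps directly onto `collections.OrderedDict`; this sketch assumes the described key shape and the 32-entry bound, but it is not the actual implementation:

```python
from collections import OrderedDict

class PartitionCache:
    """Bounded LRU cache keyed on (window, sorted UIDs, world_size)."""

    def __init__(self, max_entries: int = 32):
        self._cache: OrderedDict[tuple, list] = OrderedDict()
        self._max = max_entries

    def _key(self, window, uids, world_size):
        return (window, tuple(sorted(uids)), world_size)

    def get(self, window, uids, world_size):
        key = self._key(window, uids, world_size)
        if key in self._cache:
            self._cache.move_to_end(key)  # mark as most recently used
            return self._cache[key]
        return None

    def put(self, window, uids, world_size, partition):
        key = self._key(window, uids, world_size)
        self._cache[key] = partition
        self._cache.move_to_end(key)
        while len(self._cache) > self._max:
            self._cache.popitem(last=False)  # evict least recently used
```

Sorting the UIDs inside the key means two ranks asking about the same peer set in different orders hit the same cache entry.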

2382-2392: LGTM!

The all-gather synchronization (lines 2382-2392) correctly uses dist_helper.all_gather_object to collect partials from all ranks. The synchronous call is appropriate here as it's a blocking collective operation by design. The early normalization of empty payloads to None (lines 2395-2399) simplifies downstream filtering.


2402-2434: LGTM!

The dual return path (lines 2402-2434) correctly implements the return_partials flag:

  • When True: returns unmerged list with rank metadata (lines 2414-2416) for incremental processing
  • When False: merges on all ranks (line 2420) for consistency
    The rank-0-only logging (lines 2408-2412, 2424-2432) avoids duplicate logs while providing comprehensive metrics.

2505-2553: Distributed fallback correctly runs sequential gather on all ranks.

Lines 2537-2553 implement the fallback by having all ranks run sequential gather (line 2548) when distributed gather fails. This maintains consistency across ranks as noted in the comment (line 2540). The comment at lines 2546-2547 correctly states that sequential gather (R2 downloads) can be done independently by each rank.


2630-2668: Reserve retry correctly uses sequential gather and handles both result types.

Line 2631 appropriately uses sequential gather() for reserve UIDs rather than distributed, as reserve retries are typically smaller and don't benefit from distribution overhead. The dual handling (lines 2639-2668) correctly:

  • Appends fallback as a new partial with rank metadata (lines 2641-2651) when primary is a list
  • Merges fallback into the existing SimpleNamespace (lines 2653-2668) when primary is merged

2673-2711: LGTM!

The final success rate computation and dual logging (lines 2673-2711) correctly handles both return types:

  • For partials: aggregates across all partials, adds success_rate to each (lines 2681-2683), and returns None if no gradients (lines 2697-2698)
  • For merged: computes simple rate and returns None if empty (lines 2700-2711)

The early return of None for zero gradients (lines 2697, 2710) ensures callers receive consistent failure signals.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
src/tplr/comms.py (1)

2196-2211: Device fallback could still fail on multi-GPU systems.

The fallback logic at lines 2201-2210 attempts to use device 0 but doesn't verify its actual availability before setting it. If device 0 is unavailable (e.g., on systems with only devices 1-3), torch.cuda.set_device(0) will raise a RuntimeError and fall back to CPU, which might not be intended.

Consider iterating through available devices:

         if torch.cuda.is_available():
             try:
                 target_device = f"cuda:{torch.cuda.current_device()}"
             except RuntimeError:
-                # Fall back to first available device or CPU
                 target_device = "cpu"
-                if torch.cuda.device_count() > 0:
-                    # Use device 0, but could be extended to find first available
-                    try:
-                        torch.cuda.set_device(0)
-                        target_device = "cuda:0"
-                    except RuntimeError:
-                        target_device = "cpu"
+                # Try to find first actually available device
+                for device_id in range(torch.cuda.device_count()):
+                    try:
+                        with torch.cuda.device(device_id):
+                            target_device = f"cuda:{device_id}"
+                            break
+                    except RuntimeError:
+                        continue
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c3b1627 and d59bde7.

📒 Files selected for processing (3)
  • hparams/hparams.json (1 hunks)
  • neurons/trainer.py (1 hunks)
  • src/tplr/comms.py (18 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • hparams/hparams.json
  • neurons/trainer.py
🧰 Additional context used
🧬 Code graph analysis (1)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
  • Bucket (24-46)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • all_gather_object (307-325)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.12)
  • GitHub Check: test (3.11)
🔇 Additional comments (14)
src/tplr/comms.py (14)

28-28: LGTM: Imports and cache field properly structured.

The module-level imports and partition cache structure are well-designed. The OrderedDict cache with tuple keys enables efficient LRU eviction for deterministic UID partitioning across windows.

Also applies to: 49-49, 141-143


482-581: LGTM: Robust existence check with conservative error handling.

The structured error inspection and conservative assumptions (returning True for transient errors) correctly prevent false slashing of miners due to network issues. Permission errors appropriately logged at ERROR level.


707-721: LGTM: Critical fix for event-loop blocking.

Offloading torch.load() to a thread pool via asyncio.to_thread is essential for enabling true parallel downloads. The 1-2s CPU decompression time per file would otherwise serialize concurrent operations. Error handling and weights_only=True are appropriate.


1561-1564: LGTM: Consistent thread-pool offloading for local files.

Local file loading also correctly uses asyncio.to_thread to prevent event-loop blocking, maintaining consistency with remote S3 downloads.


1469-1490: LGTM: Fixed undefined version reference.

The signature now properly defaults version to None and falls back to tplr.__version__, resolving the previous undefined self.version issue. This pattern allows optional version specification while maintaining backward compatibility.


1667-1734: LGTM: Smart early-exit optimization for missing files.

Tracking consecutive NOT_FOUND responses enables early termination (2.5s) when a file clearly doesn't exist, avoiding full timeout waits. Resetting the counter on other error types correctly distinguishes persistent absence from transient network issues.


1841-1870: LGTM: Critical performance optimization for failure handling.

Parallelizing S3 existence checks dramatically reduces failure-handling latency from O(N * 5s) to O(5s) for N failures. The cached results enable efficient O(1) lookups during response processing, preventing cascading timeouts in distributed scenarios.


1873-1935: LGTM: Sophisticated failure attribution prevents false slashing.

The enhanced error handling correctly distinguishes between miner failures (no upload) and network/timeout issues using the pre-computed existence checks. Only legitimate upload failures are added to skipped_uids, preventing false slashing while maintaining accountability.


2148-2172: LGTM: Clean deterministic partitioning.

The round-robin partitioning strategy with sorted UIDs ensures deterministic, balanced distribution across ranks. The implementation is simple and correct.


2239-2297: LGTM: Robust merge implementation with recursive device movement.

The merge logic correctly handles nested structures via recursive move_to_device, consolidates per-rank state dictionaries by extending parameter lists, and properly aggregates metrics. Moving tensors to the target device during merge (rather than afterward) is more efficient.


2334-2350: LGTM: Efficient partition caching with LRU eviction.

The partition cache uses window-aware keys and OrderedDict for LRU eviction (keeping 32 most recent). The move_to_end call on cache hits maintains LRU ordering, and popitem(last=False) correctly evicts oldest entries.


2380-2434: LGTM: Correct distributed synchronization and dual return modes.

The all-gather barrier correctly blocks for synchronization, empty partials are normalized to None, and rank metadata enables deterministic ordering. The dual return mode (partials vs merged) supports both incremental and immediate processing patterns. All-ranks merging maintains consistency.


2609-2713: LGTM: Reserve retry handles both result modes correctly.

The reserve fallback logic correctly branches on result type (partials vs merged), appending fallback as a new partial with rank metadata in one case and extending state dict lists in the other. Success rate is properly recomputed with respect to the original gather target, and None is returned when no gradients are collected.


2535-2553: The actual code structure from the script outputs shows, via indentation, that the earlier review comment's analysis is incorrect.

The code at lines 2546-2553 is NOT inside the if dist_helper.is_master: block. The key evidence:

  1. Line 2538: if dist_helper.is_master: (with specific indentation)
  2. Lines 2539-2545: Logging code (indented deeper, inside master block)
  3. Lines 2546-2547: Comments with same indentation as the if statement (outside master block)
  4. Lines 2548-2553: primary = await self.gather(...) with same indentation as the if statement (outside master block)

The actual control flow is:

if primary is None:
    if dist_helper.is_master:      # Only master logs
        context_log(...)
        tplr.logger.info(...)
    # ALL ranks execute this:
    primary = await self.gather(...)

This means all ranks already perform the sequential gather independently, which is correct behavior and matches the intent of the comment. No broadcast is needed because each rank downloads independently (R2 downloads are rank-local operations).

The review comment's proposed diff is unnecessary—the code already implements the intended behavior.

Likely an incorrect or invalid review comment.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (4)
neurons/validator.py (2)

1452-1479: Timestamp broadcast: consistent windowing across ranks. LGTM.

The master queries the timestamp and broadcasts it to all ranks, so every rank derives the same time_min/time_max window. This fixes the prior inconsistency; indentation and device usage are correct.


1647-1667: Single merge for upload; overlap uses partial-aware path.

The prior duplicate merges are removed; now you merge once for upload and pass original gather_result to overlap. Memory win.

Also applies to: 1657-1664

src/tplr/neurons.py (2)

741-741: Still inverted: budget_exceeded flag logic remains incorrect.

This was flagged in previous reviews but not corrected. The flag is set to 0 when budget IS exceeded and 1 when it's NOT exceeded, which is counter-intuitive and breaks metric semantics.

Apply this diff:

-            budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+            budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0

Note: The same issue exists at line 1039 in _outer_step_single.


1039-1039: Still inverted: budget_exceeded flag logic incorrect here too.

Same issue as line 741. The flag should be 1 when budget is exceeded, not 0.

Apply this diff:

-        budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+        budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0
🧹 Nitpick comments (3)
neurons/validator.py (2)

138-202: Extend offload/load helpers to handle list-of-partials (memory relief in distributed mode).

offload_gather_results/load_gather_results_to_device currently no-op for partials, leaving tensors on GPU. Add list handling.

 def offload_gather_results(self, gather_result, *, log: bool = True) -> None:
@@
-        state = getattr(gather_result, "state_dict", None)
+        # Handle list-of-partials early
+        if isinstance(gather_result, list):
+            for partial in gather_result:
+                if partial is None or getattr(partial, "_offloaded", None) == "cpu":
+                    continue
+                state = getattr(partial, "state_dict", None)
+                if state is None:
+                    continue
+                if isinstance(state, SimpleNamespace):
+                    _to_pinned_cpu(state)
+                elif isinstance(state, dict):
+                    for k, v in list(state.items()):
+                        state[k] = _to_pinned_cpu(v)
+                setattr(partial, "_offloaded", "cpu")
+            if torch.cuda.is_available():
+                torch.cuda.synchronize(getattr(self, "device", None))
+                torch.cuda.empty_cache()
+            return
+
+        state = getattr(gather_result, "state_dict", None)
@@
 def load_gather_results_to_device(
         self, gather_result, device: torch.device, *, log: bool = True
     ) -> None:
@@
-        state = getattr(gather_result, "state_dict", None)
+        if isinstance(gather_result, list):
+            if torch.cuda.is_available():
+                torch.cuda.synchronize(device)
+            for partial in gather_result:
+                if partial is None or getattr(partial, "_offloaded", None) != "cpu":
+                    continue
+                state = getattr(partial, "state_dict", None)
+                if state is None:
+                    continue
+                if isinstance(state, SimpleNamespace):
+                    _to_device(state)
+                elif isinstance(state, dict):
+                    for k, v in list(state.items()):
+                        state[k] = _to_device(v)
+                setattr(partial, "_offloaded", None)
+            if torch.cuda.is_available():
+                torch.cuda.synchronize(device)
+            return
+
+        state = getattr(gather_result, "state_dict", None)

If you want, I can add a micro benchmark to compare VRAM before/after offloading partials.

Also applies to: 203-264


1392-1410: Trim log volume when broadcasting peers to non-master ranks.

Logging entire peer lists every window can spam logs and increase costs. Prefer counts.

-                    tplr.log_with_context(
-                        level="info",
-                        message=f"[Rank {dist_helper.rank}] Received peers from master: gather={self.comms.peers}, reserve={self.comms.reserve_peers}",
+                    tplr.log_with_context(
+                        level="info",
+                        message=f"[Rank {dist_helper.rank}] Received peers from master: gather={len(self.comms.peers)} peers, reserve={len(self.comms.reserve_peers)} peers",
                         sync_window=self.sync_window,
                         current_window=self.current_window,
                     )
src/tplr/neurons.py (1)

1740-1755: Consider clarifying the comment.

The implementation uses torch.maximum to take the larger of param_avg_change and learning_rate, which is good for stability. However, the comment at lines 1740-1741 says "Use learning_rate as the baseline", which doesn't fully capture that you're using the maximum of both values. Consider updating the comment to describe the approach more precisely.

Example:

-        # Use learning_rate as the baseline step size for more stable comparison
-        # param_avg_change can be very small when model is converged, leading to inflated step ratios
+        # Use the maximum of param_avg_change and learning_rate for stable step ratio calculation
+        # This avoids division by tiny numbers when the model has converged
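The scalar analogue of the `torch.maximum` clamp can be sketched as follows (a simplified stand-in, not the actual helper in `src/tplr/neurons.py`):

```python
def stable_step_ratio(step_size: float, param_avg_change: float, learning_rate: float) -> float:
    """Clamp the denominator to at least the learning rate so the ratio
    does not explode when the model has converged and avg change ~ 0."""
    denom = max(param_avg_change, learning_rate)
    return step_size / denom

assert stable_step_ratio(2e-4, 1e-9, 1e-4) == 2.0  # tiny avg change: clamped by lr
assert stable_step_ratio(4e-4, 2e-4, 1e-4) == 2.0  # avg change dominates
```

The clamp trades a slightly understated ratio near convergence for immunity to division by near-zero denominators, which is the stability property the review praises.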
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d59bde7 and 0afef3f.

📒 Files selected for processing (2)
  • neurons/validator.py (14 hunks)
  • src/tplr/neurons.py (13 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/tplr/neurons.py (2)
src/tplr/compress.py (5)
  • compress (373-377)
  • compress (379-383)
  • compress (386-423)
  • TopKCompressor (295-697)
  • unpack_12bit_indices (102-142)
src/tplr/comms.py (2)
  • merge_gather_results (2174-2297)
  • get (1503-1621)
neurons/validator.py (5)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • broadcast (273-282)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/comms.py (3)
  • gather_with_reserve (2439-2713)
  • merge_gather_results (2174-2297)
  • get_with_retry (1623-1738)
src/tplr/neurons.py (1)
  • check_uid_index_overlap (1795-2000)
src/tplr/compress.py (1)
  • maybe_dequantize_values (647-697)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.12)
  • GitHub Check: test (3.11)
🔇 Additional comments (6)
src/tplr/neurons.py (6)

48-70: LGTM: Robust CUDA device resolution with proper fallbacks.

The helper correctly handles edge cases (unavailable CUDA, invalid device strings, RuntimeErrors) and provides sensible fallbacks. Good defensive programming.


250-322: LGTM: Clean dispatcher implementation.

The function signature and dispatch logic correctly handle both single merged results and lists of partials, forwarding all parameters appropriately to the specialized implementations.


324-369: LGTM: Recursive memory estimation correctly implemented.

The recursive _iter_tensors helper properly traverses all container types (dict, list, tuple, SimpleNamespace) to sum tensor memory, addressing the previous issue where memory estimates were often 0 MB.

Based on learnings


371-419: LGTM: Sound batch size calculation with safety margins.

The sampling approach (first 3 partials) and 0.7 safety factor provide reasonable estimates while avoiding OOM risks. Edge cases are handled properly.


665-696: LGTM: Excellent error recovery and batch failure logging.

The per-batch try/except with detailed failure tracking allows processing to continue even when individual batches fail. The corrected logging using batch_idx (line 693) properly addresses the previous KeyError issue.

Based on learnings


1795-1945: LGTM: Efficient overlap detection with O(n) optimization.

The uid_to_partial mapping (lines 1816-1826) eliminates the O(n×m) nested loop search that was flagged previously, achieving O(1) lookup per UID (line 1881). The function now correctly handles both merged and partial results.

Based on learnings


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9880455 and f51dc0b.

📒 Files selected for processing (2)
  • hparams/hparams.json (2 hunks)
  • src/tplr/model_factory.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • hparams/hparams.json

@shivam-MBZUAI shivam-MBZUAI force-pushed the feat/distribute_gather branch 3 times, most recently from 956d520 to ca50d92, on November 19, 2025 19:14

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
neurons/miner.py (1)

843-877: Critical: gather_result is cleared/deleted before later use, causing errors and losing metrics.

Inside the if should_update: block you now do:

  • For list results: delete each partial’s state_dict and then gather_result.clear().
  • For merged results: delete gather_result.state_dict.
  • In both cases: del gather_result.

Later in the same loop you still rely on gather_result in several places:

  • Building debug_dict["successful_peers"] and debug_dict["skipped_peers"] (Lines 927–949).
  • Computing gather_success_rate and skipped_uids_list for wandb and Influx metrics (Lines 985–1007, 1060–1072).
  • Final cleanup at the very end of the loop (if gather_result is not None: del gather_result at Lines 1088–1092).

When should_update is True (the common case), the earlier del gather_result causes all of these later references to raise UnboundLocalError. For the list case, even without the explicit del, gather_result.clear() destroys the partials needed to compute the success and skipped peer sets.

You can keep the memory savings while fixing the lifecycle by only dropping the heavy state_dict payloads and deferring deletion of the container until after metrics/debug logging. For example:

@@
-            if should_update:
-                gradient_fingerprint = self.outer_step(gather_result)
-                self.global_step += (
-                    1  # Increment only when we actually do an outer step
-                )
-
-                # Aggressive cleanup after outer step to prevent memory leaks
-                # Clear gradients and free any tensors from the merge/apply operation
-                if hasattr(self, "model") and self.model is not None:
-                    self.model.zero_grad(set_to_none=True)
-
-                # Clear gather_result to free memory
-                if gather_result is not None:
-                    if isinstance(gather_result, list):
-                        for partial in gather_result:
-                            if partial is not None and hasattr(partial, "state_dict"):
-                                delattr(partial, "state_dict")
-                        gather_result.clear()
-                    elif hasattr(gather_result, "state_dict"):
-                        delattr(gather_result, "state_dict")
-                    del gather_result
+            if should_update:
+                gradient_fingerprint = self.outer_step(gather_result)
+                self.global_step += 1  # Increment only when we actually do an outer step
+
+                # Aggressive cleanup after outer step to prevent memory leaks
+                # Clear gradients and free any tensors from the merge/apply operation
+                if hasattr(self, "model") and self.model is not None:
+                    self.model.zero_grad(set_to_none=True)
+
+                # Drop heavy tensor payloads from gather_result but keep metadata
+                if gather_result is not None:
+                    if isinstance(gather_result, list):
+                        for partial in gather_result:
+                            if partial is not None and hasattr(partial, "state_dict"):
+                                delattr(partial, "state_dict")
+                    elif hasattr(gather_result, "state_dict"):
+                        delattr(gather_result, "state_dict")
@@
-            # Delete any remaining local variables to clear up memory
-            del shard_gradient
-            if gather_result is not None:
-                del gather_result
+            # Delete any remaining local variables to clear up memory
+            del shard_gradient
+            if gather_result is not None:
+                del gather_result

This way:

  • gather_result (and any partials) remain available for debug and metrics, but without their large state_dict tensors.
  • Final deletion still happens once at the end of the loop, after all uses.

Apply an equivalent fix if you shuffle code, but the key is: don’t clear/delete gather_result before its last use.

Also applies to: 927-949, 985-1007, 1060-1072, 1088-1092

♻️ Duplicate comments (3)
neurons/validator.py (2)

1459-1489: Timestamp broadcast fix resolves prior rank inconsistency

Only querying query_block_timestamp on master, then broadcasting ts_value with dist_helper.broadcast ensures every rank uses the same [time_min, time_max] window. This addresses the earlier concern where some ranks could fall back to time.time() while others used chain timestamps.


3873-3971: Bug: compute_peer_val_norms still falls through for partials, using undefined state_dict

When gather_result is a list[SimpleNamespace] (the partials case used by distributed gather):

  • The if isinstance(gather_result, list): block builds clip_norm_dict by aggregating and dequantizing vals from all partials.
  • After that block, execution continues into the shared for n, p in self.model.named_parameters(): loop, which references state_dict—but state_dict is only defined in the else (merged) branch.

This re‑introduces the earlier UnboundLocalError/NameError for state_dict in the partials path and will blow up whenever you call compute_peer_val_norms with a list (i.e., in distributed gather mode). It also means you’re recomputing norms twice for merged results if the list branch wasn’t intended to fall through.

Recommend returning early for the partials case and scoping the merged‑result loop to the else branch:

     def compute_peer_val_norms(
         self,
-        gather_result: SimpleNamespace | list[SimpleNamespace],
+        gather_result: SimpleNamespace | list[SimpleNamespace],
     ) -> dict[str, torch.Tensor]:
@@
-        # Handle both merged results and partials
-        if isinstance(gather_result, list):
+        # Handle both merged results and partials
+        if isinstance(gather_result, list):
@@
-                norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
-                clip_norm_dict[vals_key] = torch.median(norms)
-        else:
-            # Merged result: original logic
-            state_dict = gather_result.state_dict
-            if not state_dict:
-                raise ValueError("Must have gather_result.state_dict to compute norms")
-
-        for n, p in self.model.named_parameters():
-            vals_key = n + "vals"
-            quant_key = n + "quant_params"
-
-            vals = getattr(state_dict, vals_key, None)
-            quant_params = getattr(state_dict, quant_key, None)
-
-            if vals is None:
-                continue
-
-            vals_f32 = self.compressor.maybe_dequantize_values(
-                vals, quant_params, p.device
-            )
-
-            norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
-            clip_norm_dict[vals_key] = torch.median(norms)
-
-        return clip_norm_dict
+                norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
+                clip_norm_dict[vals_key] = torch.median(norms)
+            # Partials path fully handled
+            return clip_norm_dict
+        else:
+            # Merged result: original logic
+            state_dict = gather_result.state_dict
+            if not state_dict:
+                raise ValueError("Must have gather_result.state_dict to compute norms")
+
+            for n, p in self.model.named_parameters():
+                vals_key = n + "vals"
+                quant_key = n + "quant_params"
+
+                vals = getattr(state_dict, vals_key, None)
+                quant_params = getattr(state_dict, quant_key, None)
+
+                if vals is None:
+                    continue
+
+                vals_f32 = self.compressor.maybe_dequantize_values(
+                    vals, quant_params, p.device
+                )
+
+                norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
+                clip_norm_dict[vals_key] = torch.median(norms)
+
+            return clip_norm_dict

You may also want to re‑introduce the additional defensive alignment logic for vals / quant_params that was suggested previously if you expect inconsistent partials, but the scoping/early‑return above is the immediate correctness fix.

#!/bin/bash
# Quick check: show full compute_peer_val_norms and verify state_dict is only used in merged path
rg -n "def compute_peer_val_norms" -n -C15 neurons/validator.py
src/tplr/neurons.py (1)

593-609: Memory budget violation flag still inverted in logging/metrics

The new use of _resolve_cuda_device(device) and torch.cuda.memory_allocated(cuda_device) is correct and fixes the earlier string-device usage. However, the budget_exceeded flag at line 854 is still inverted:

budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1

Given the metric name "outer_step/memory_budget_violation", this should be 1 when memory exceeds the budget and 0 otherwise; currently it reports 0 on violation and 1 when within budget, which is the opposite of what the name implies and will mislead monitoring.

A minimal fix:

-        budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+        budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0

Also applies to: 629-629, 848-877

🧹 Nitpick comments (10)
neurons/validator.py (1)

1399-1417: Distributed peers broadcast keeps ranks aligned, but watch payload size

Broadcasting [self.comms.peers, self.comms.reserve_peers] from master via dist.broadcast_object_list is a clean way to ensure all ranks share an identical peer set before distributed gather. On large subnets this object payload could become sizeable; if that ever shows up in profiling, consider broadcasting a compact, fixed‑shape representation (e.g., tensors of UIDs) instead.

src/tplr/comms.py (4)

483-582: s3_object_exists: consider purging bad clients and broadening error code coverage.

The structured handling of ClientError vs transient issues looks solid and the conservative defaults are appropriate. Two refinements would harden this further:

  • On ConnectionClosedError and similar low-level failures, consider calling _purge_s3_client(bucket) so you don’t keep reusing a broken client instance, mirroring the behavior in s3_get_object and s3_put_object.
  • When checking code for “not found” and bucket/config cases, you might also include "NotFound" (and possibly other provider-specific aliases) to avoid misclassifying some 404-style responses.

Both are incremental reliability tweaks; current behavior is functionally correct.


640-683: s3_get_object: HEAD logging, size-based path, and torch.load offload look good.

  • HEAD-time logging around time_min/time_max and the 500 MB threshold for switching to the multipart downloader are reasonable and align with the large‑model use case.
  • Offloading torch.load to a thread via asyncio.to_thread is the right move to avoid blocking the event loop during CPU decompression.

One thing to double‑check is that the deployed PyTorch version supports weights_only=True in torch.load; if not, this will raise at runtime and should be guarded or dropped for older versions.


Also applies to: 699-723
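One way to gate the argument, as a sketch (the helper name is an assumption; the grounded fact is that `weights_only` landed in PyTorch 1.13, so older runtimes would raise TypeError if it is passed):

```python
import torch


def safe_load_kwargs() -> dict:
    """Return torch.load kwargs appropriate for the running torch version.

    Version parsing is deliberately crude ("2.1.0+cu118" -> (2, 1)) but
    sufficient for a major.minor gate on the weights_only argument.
    """
    major, minor = (int(p) for p in torch.__version__.split(".")[:2])
    return {"weights_only": True} if (major, minor) >= (1, 13) else {}


# state = torch.load(path, map_location="cpu", **safe_load_kwargs())
```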


1563-1569: Local torch.load offload is correct; just mind weights_only compatibility.

Offloading the local torch.load to a thread matches the remote path and avoids event‑loop blocking. As with s3_get_object, ensure the runtime PyTorch version supports weights_only=True when loading these saved dicts, or gate that argument based on version.



2178-2211: Distributed merge helpers are structurally sound; consider minor device-policy tweaks.

Overall:

  • _move_gather_result_to_device correctly walks state_dict and moves tensors (and nested containers) without touching metadata.
  • merge_gather_results properly filters None partials, aggregates metrics, concatenates per‑param lists, and recomputes a sensible global success_rate.

Two small considerations:

  • The default target_device path will often move everything back to CUDA (via current_device/device 0). In gather_distributed you already run the final device move explicitly; if you’d rather keep the merge purely on CPU for low‑memory environments, you could pass target_device="cpu" from the caller instead of relying on the default.
  • The duplicated move_to_device helper (here and in _move_gather_result_to_device) could be factored once to reduce drift, though this is cosmetic.

Functionally, this looks correct.

Also applies to: 2214-2337
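The factored helper could look like this sketch (name and placement are assumptions); it mirrors the recursive walk both call sites already perform:

```python
import torch
from types import SimpleNamespace


def move_to_device(obj, device):
    """Recursively move tensors in nested containers; leave metadata alone."""
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_to_device(v, device) for v in obj)
    if isinstance(obj, SimpleNamespace):
        return SimpleNamespace(
            **{k: move_to_device(v, device) for k, v in vars(obj).items()}
        )
    return obj  # ints, strs, None, etc. pass through untouched
```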

neurons/miner.py (2)

548-556: Pre‑training cleanup is effective; consider guarding CUDA calls.

Flushing the CUDA cache, synchronizing, and zeroing gradients before inner_steps helps avoid OOMs from prior windows. To make this robust to CPU‑only or --device cpu runs, you may want to wrap torch.cuda.empty_cache() and torch.cuda.synchronize() in if torch.cuda.is_available(): or check self.device.type == "cuda" here (and similarly in other cleanup sites).
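A guarded variant, as a sketch (the helper name is an assumption):

```python
import torch


def flush_cuda() -> None:
    """CUDA cleanup that degrades to a no-op on CPU-only hosts.

    torch.cuda.synchronize() raises when no CUDA device is present, so
    the availability check is what makes --device cpu runs safe.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
```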


1103-1134: cleanup_window is aggressive and effective; guard CUDA for broader compatibility.

The between‑window cleanup:

  • Zeroes gradients on model and inner optimizer.
  • Drops error feedback for non‑owned params.
  • Clears autocast cache, empties CUDA cache multiple times, synchronizes, and runs GC.
  • Reuses check_memory_threshold to ensure post‑cleanup headroom.

Functionally this is fine for GPU miners. For robustness in environments where CUDA may not be available or --device cpu is used, you may want to wrap the torch.cuda.* calls in availability checks (similar to the earlier suggestion) to avoid runtime errors.

src/tplr/neurons.py (3)

371-419: Batch-size helper currently unused; consider wiring it or removing it

_calculate_optimal_batch_size is a reasonable heuristic (sampling a few partials, applying a safety factor, clamping) but it isn’t invoked from _outer_step_incremental or elsewhere in this file. If incremental per‑batch processing is no longer planned in this path, consider either:

  • Calling this helper from _outer_step_incremental when you reintroduce true batching, or
  • Dropping the function to avoid dead code and reduce cognitive load.

446-591: Incremental path now “merge-then-apply once”; docstring still suggests true batching

_outer_step_incremental now deterministically sorts partials by rank, merges all valid partials into a single merged_result, and then applies it once via _outer_step_single, with good cleanup and optional memory logging. That matches the catch‑up semantics, but:

  • The docstring and argument docs still talk about “processing partials by merging them first, then applying once” (accurate) but also mention batch_size_override and budget‑based batching, which are no longer used.
  • _calculate_optimal_batch_size / _estimate_partial_memory_mb are not integrated, so memory_budget_mb here is effectively “for telemetry only”.

To avoid confusion, consider either wiring in true batch‑wise merging using the helpers, or updating the docstring and parameter descriptions to match the current “merge all once” behavior and noting that the budget is only logged, not enforced.


1554-1569: Step-size denominator now stabilized by learning_rate; consider min_step_size param

Using learning_rate as the baseline step size and clamping param_avg_change with torch.maximum is a sensible way to avoid huge step ratios when the EMA is tiny. Note that min_step_size in the function signature is now effectively unused; if you still want a global floor, you could clamp the denominator with it, or else remove/rename that parameter to avoid confusion.
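If the floor is kept, wiring it in could be as small as this sketch (names are illustrative, not the actual tplr function):

```python
import torch


def stabilized_step_ratio(
    param_change: torch.Tensor,
    avg_change_ema: torch.Tensor,
    learning_rate: float,
    min_step_size: float = 1e-8,
) -> torch.Tensor:
    """Clamp the denominator by both the LR baseline and the explicit floor,
    so min_step_size is no longer dead weight in the signature."""
    floor = torch.tensor(
        max(learning_rate, min_step_size), dtype=avg_change_ema.dtype
    )
    return param_change / torch.maximum(avg_change_ema, floor)
```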

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 242ae7e and ca50d92.

📒 Files selected for processing (8)
  • hparams/hparams.json (1 hunks)
  • neurons/base_node.py (1 hunks)
  • neurons/miner.py (8 hunks)
  • neurons/trainer.py (4 hunks)
  • neurons/validator.py (14 hunks)
  • src/tplr/comms.py (20 hunks)
  • src/tplr/compress.py (1 hunks)
  • src/tplr/neurons.py (13 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • hparams/hparams.json
  • neurons/trainer.py
🧰 Additional context used
🧬 Code graph analysis (4)
neurons/miner.py (3)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/distributed.py (1)
  • is_distributed (75-77)
src/tplr/comms.py (1)
  • gather_with_reserve (2556-2836)
src/tplr/neurons.py (2)
src/tplr/compress.py (4)
  • compress (373-377)
  • compress (379-383)
  • compress (386-423)
  • unpack_12bit_indices (102-142)
src/tplr/comms.py (1)
  • merge_gather_results (2214-2337)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
  • Bucket (24-46)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • all_gather_object (307-325)
neurons/validator.py (4)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • broadcast (273-282)
src/tplr/comms.py (3)
  • gather_with_reserve (2556-2836)
  • merge_gather_results (2214-2337)
  • get_with_retry (1625-1742)
src/tplr/neurons.py (1)
  • check_uid_index_overlap (1610-1815)
src/tplr/compress.py (1)
  • maybe_dequantize_values (647-697)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.12)
  • GitHub Check: test (3.11)
🔇 Additional comments (20)
neurons/base_node.py (1)

61-62: Verification confirms the code change is correct and the comment is accurate.

The Trainer class exists at neurons/trainer.py:41 with an __init__ method at line 47. Both Miner (neurons/miner.py:64) and Validator (neurons/validator.py:90) inherit from (BaseNode, Trainer), making the super().__init__() call necessary for proper MRO chain initialization. The comment accurately reflects this multiple inheritance pattern.

src/tplr/compress.py (1)

637-645: Device alignment fix in _dequantize_values looks correct

Moving both lookup and shift to val.device eliminates cross-device indexing issues when val is on GPU. The guards keep it safe for any non-tensor qparam shapes. No further changes needed here.

neurons/validator.py (4)

90-141: Helper methods correctly normalize gather results across partial and merged modes

The new static helpers (_get_uids_from_gather_result, _get_skipped_uids_from_gather_result, _get_success_rate_from_gather_result) nicely encapsulate list‑vs‑merged handling and dedupe UIDs so success rate and metrics are robust even when partials include retries or fallbacks. This centralization reduces duplication and makes later logging/metric code much clearer.


3281-3295: Switch to get_with_retry for miner debug dicts improves robustness

Using self.comms.get_with_retry with a bounded timeout and explicit handling of None/success=False gives evaluate_miner_sync a clean failure path and avoids hanging on slow or missing debug uploads. Returning a structured failure result with sync_score=0.0 integrates smoothly with slash_for_poor_sync.


1679-1722: All verification checks passed; code changes are coherent and safe to approve

The verification confirms all reviewer assumptions:

  1. ✓ outer_step signature (src/tplr/neurons.py:250-267) accepts memory_budget_mb and batch_size_override as optional parameters
  2. ✓ Call site (neurons/validator.py:2542-2560) correctly passes both via getattr(self.hparams, ...) for optional feature support
  3. ✓ Helper methods _get_uids_from_gather_result, _get_skipped_uids_from_gather_result, and _get_success_rate_from_gather_result are consistently used across referenced lines (1745-1763, 2556-2561)
  4. ✓ Design properly handles both partial and merged gather_result inputs with minimal memory overhead

1539-1565: Design is sound; heterogeneous gather_result states are safe due to rank-aware implementation

The distributed/sequential gather flow correctly gates gather_with_reserve execution. The edge case you identified—non-master rank with gather_result=None while master succeeds—is handled safely by design:

  • _outer_step_single (called by all ranks) is rank-aware: only master reads gather_result (line 634-640)
  • In distributed mode, non-master ranks skip reading gather_result entirely (on_src=False when ddp=True and not master)
  • Master broadcasts per-parameter has_update flags (via _bcast_flag); non-master follows the broadcast regardless of local state
  • Therefore, a non-master rank can safely have gather_result=None—it's never dereferenced

No changes needed; the current behavior is correct. The design relies on explicit rank-awareness in _outer_step_single rather than requiring all-rank synchronization before the call.

src/tplr/comms.py (4)

1668-1673: Early exit on repeated NOT_FOUND plus extended grace window is sensible.

The increased time_max_grace_period and the consecutive_not_found cutoff prevent spinning for the full timeout when the object almost certainly doesn’t exist, while still giving some leeway for late uploads and clock skew. The logic correctly resets the counter on non‑NOT_FOUND statuses, so transient network errors still get full retry behavior.

Also applies to: 1727-1739
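The control flow described above reduces to roughly this sketch (names, the poll interval, and the signature are illustrative, not the actual get_with_retry):

```python
import time


def get_with_retry_sketch(fetch, timeout_s, max_consecutive_not_found=3, poll_s=0.05):
    """Retry loop with an early cutoff after repeated NOT_FOUND statuses."""
    deadline = time.monotonic() + timeout_s
    consecutive_not_found = 0
    while time.monotonic() < deadline:
        status, payload = fetch()
        if status == "OK":
            return payload
        if status == "NOT_FOUND":
            consecutive_not_found += 1
            if consecutive_not_found >= max_consecutive_not_found:
                return None  # object almost certainly was never uploaded
        else:
            consecutive_not_found = 0  # transient errors keep full retries
        time.sleep(poll_s)
    return None
```

The counter reset on non-NOT_FOUND statuses is the detail that preserves full retry behavior for flaky networks while still cutting short the hopeless case.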


1846-1875: Parallel existence checks and miner vs network failure attribution look correct.

The new block that:

  • collects failed downloads (Exception/None),
  • runs s3_object_exists for them in parallel, and
  • distinguishes: no bucket → “not registered”, object exists → network/timing issue (no slashing), object missing → miner didn’t upload (skippable/slashable),

is well structured and uses the cached existence results correctly in both the Exception and None paths. This should remove the previous N×5 s serial penalty and give clearer attribution without changing success semantics.

Also applies to: 1880-1939


2371-2476: gather_distributed: partitioning, CPU all_gather, and device moves behave as intended.

  • UID partitioning via a cached round‑robin over sorted(uids) is deterministic and respects world_size/rank.
  • Local gather → optional local merge (when not returning partials) → CPU move before all_gather_object is a reasonable sequence for 70B‑scale gradients.
  • All ranks participate in the barrier/all‑gather when world_size>1 and dist_helper.is_distributed(), and non‑distributed cases correctly degenerate to a single‑element all_partials list.
  • The two return paths (return_partials=True vs merged result) are consistent with the declared return type, and partials are normalized (uids empty → None) before downstream use.

I don’t see logic errors here; remaining choices (e.g., doing the local merge on GPU then moving to CPU) are trade‑offs rather than bugs.

Also applies to: 2479-2514, 2516-2551
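The deterministic partition described above amounts to a strided slice over the sorted UID list; a sketch:

```python
def partition_uids(uids: list[int], rank: int, world_size: int) -> list[int]:
    """Round-robin split: every rank computes the same answer independently,
    so no coordination is needed before the parallel downloads begin."""
    return sorted(uids)[rank::world_size]
```

Across all ranks the slices are disjoint and cover every UID, which is what lets the later all-gather reassemble the full peer set.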


2556-2700: gather_with_reserve: distributed integration and type consistency look good.

This refactor:

  • Decides use_distributed exactly once based on hparams.distributed_gather, world_size, and dist_helper.is_distributed().
  • Uses gather_distributed when enabled, with a clear “all ranks fall back together” path to sequential gather() if distributed returns no gradients or raises.
  • Preserves the contract that return_partials=True yields list[SimpleNamespace | None] and otherwise a merged SimpleNamespace, but returns None when there are truly no successful gradients.
  • Handles reserve retries for both partial and merged primaries, avoiding duplicate UIDs and aggregating metrics correctly.
  • Recomputes success_rate relative to gather_uids and returns None if even after reserves there are zero successes, which upstream callers already check.

Semantics are consistent across distributed and non‑distributed modes; no correctness issues spotted.

Also applies to: 2701-2836

neurons/miner.py (5)

32-35: CUDA allocator env var set correctly before torch import.

Setting PYTORCH_CUDA_ALLOC_CONF before importing torch, and only when it’s absent, is the right pattern; it gives a sane default (expandable_segments:True) without overriding explicit user configuration.
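The pattern boils down to os.environ.setdefault before the torch import; a sketch:

```python
import os

# Must run before `import torch`: the CUDA caching allocator reads this
# variable once at context initialization. setdefault preserves any value
# the operator already exported, matching the "only when absent" behavior.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
```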


165-177: Distributed init and device wiring look coherent.

Using dist_helper.init_process_group(...) once in __init__ and then sourcing rank, world_size, local_rank, is_master, and device from dist_helper keeps all distributed state in one place. Falling back to self.config.device and wrapping it in torch.device is standard.

Just ensure dist_helper.init_process_group handles single‑process and CPU‑only cases gracefully with the chosen backend.


471-476: Peer refresh and broadcast across ranks are correct and necessary for distributed gather.

  • Refreshing self.comms.commitments on master after update_peers ensures bucket metadata is current before a gather.
  • Broadcasting self.comms.peers from rank 0 to all ranks via dist.broadcast_object_list under world_size > 1 and dist_helper.is_distributed() correctly synchronizes peer selection, and non‑master logging confirms reception.

This keeps all ranks aligned on the same peer set for subsequent distributed gathers.

Confirm that torch.distributed is initialized in exactly the same cases where dist_helper.is_distributed() returns True so the broadcast can’t be called before init.

Also applies to: 478-486


927-949: Debug successful/skipped peers: logic is correct given a live gather_result.

Assuming the lifecycle fix above is applied, the debug computation:

  • Aggregates uids and skipped_uids across partials (list case).
  • Falls back to the merged SimpleNamespace’s uids/skipped_uids otherwise.
  • Derives successful_peers as peers − skipped and logs both.

This correctly reflects “who actually contributed” independent of distributed vs sequential gather. No changes needed beyond ensuring gather_result still exists here.


985-1007: Gather success-rate and skipped-peers metrics handle partial vs merged results correctly.

The metrics block:

  • Treats gather_result is None as 0% success, with an empty skipped list.
  • For partial lists, aggregates successes and skipped across all partials and computes success_rate as total_uids / (total_uids + total_skipped) * 100.
  • For merged results, reuses gather_result.success_rate * 100 and its skipped_uids.

Logging skipped_peers to Influx as json.dumps(skipped_uids_list) gives a clear diagnostics trail. As above, this relies on gather_result not having been deleted earlier in the loop.

Also applies to: 1060-1072

src/tplr/neurons.py (5)

48-69: CUDA device resolver avoids string-device CUDA pitfalls

The helper cleanly normalizes incoming device values and guarantees that CUDA APIs only see a torch.device('cuda:X') or None, which fixes prior string-device issues and degrades safely to None on CPU-only or invalid inputs. No changes needed here.


249-321: outer_step dispatch for merged vs partial results is straightforward

Routing list[SimpleNamespace] to _outer_step_incremental and everything else to _outer_step_single keeps the public API stable while making the new distributed path opt‑in and explicit. The keyword‑only arguments and propagation of memory_budget_mb / batch_size_override look consistent.


324-369: Recursive partial memory estimator correctly handles nested namespaces/containers

The new _estimate_partial_memory_mb walks tensors through SimpleNamespace, dicts, lists, and tuples, and applies a safety overhead before converting to MB, which aligns with the earlier review feedback about under‑estimating memory for nested payloads. This looks correct and robust.


421-444: _merge_partial_batch cleanly delegates to Comms.merge_gather_results

Using a thin wrapper around Comms.merge_gather_results with a local import avoids cycles and keeps the merge semantics centralized in comms.py. Behavior for empty batches (None) is clear. No issues here.


1612-1651: Overlap checker’s partial-support logic and UID mapping look correct

The updated check_uid_index_overlap correctly:

  • Accepts either a merged SimpleNamespace or list[SimpleNamespace] of partials.
  • Builds a uid_to_partial map so each UID is resolved in O(1) instead of scanning all partials (fixing the prior O(n×m) inner‑loop issue).
  • Tracks valid_uids_for_param so pair keys are stored in terms of the original UID indices, ensuring summary stats still index into the global uids list consistently.

The per‑param filtering (len(unpacked_indices) < 2) and the subsequent pair accumulation logic all line up with the existing reporting. This looks good.

Also applies to: 1685-1714, 1730-1735, 1743-1759


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
neurons/trainer.py (1)

872-929: Fix incorrect double gradient clipping with GradScaler

You're clipping gradients twice: once before scaler.unscale_() at line 872, and again after at lines 887–889 when capturing total_grad_norm. PyTorch's official pattern with AMP is unscale_() → clip_grad_norm_() → step() → update(). Clipping before unscale operates on scaled gradients, violating the intended gradient magnitude constraint.

Remove the pre-unscale clip at line 872. The post-unscale clip at lines 887–889 is correct and properly captures the norm. The null_round branch setting total_grad_norm = 0.0 at line 912 remains appropriate: spin-up steps log 0.0, training steps log the actual clipped norm.

♻️ Duplicate comments (4)
neurons/validator.py (2)

1459-1486: Timestamp query + broadcast fixes prior time window inconsistency

Only querying query_block_timestamp on master and broadcasting ts_value to all ranks ensures a consistent (time_min, time_max) on every rank, including the fallback to time.time() when the query fails. This addresses the earlier inconsistency concerns while keeping the device handling reasonable.

Also applies to: 1488-1499


3875-3971: compute_peer_val_norms has a partials‑path bug (UnboundLocalError) and fragile quant_params handling

Two issues here:

  1. UnboundLocalError when gather_result is a list (partials)

    • In the if isinstance(gather_result, list): branch you never define state_dict.
    • Execution then falls through to the merged‑result loop below, which unconditionally references state_dict (initialized only in the else branch).
    • In distributed gather mode (where gather_result is a list), this will raise at the first parameter.
  2. Quantization metadata can be misaligned across partials

    • For partials, all_vals is extended for every vals entry, but all_quant_params is only extended when quant_params is not None.
    • If some partials have vals but missing or shorter quant_params, you end up with len(all_vals) != len(all_quant_params).
    • Passing those into self.compressor.maybe_dequantize_values(all_vals, all_quant_params, ...) can raise IndexError or silently compute norms on raw uint8 tensors instead of dequantized values.

Both issues can surface exactly in the new distributed/partials flow, making this a blocker for that path. You likely want the partials branch to:

  • fully compute norms and return clip_norm_dict early, and
  • only feed dequantization with aligned (vals, qparams) pairs, skipping any entries that are quantized but lack valid qparams.

Here is a concrete refactor that addresses both problems while keeping semantics close to what you intend:

 def compute_peer_val_norms(
-        self,
-        gather_result: SimpleNamespace | list[SimpleNamespace],
-    ) -> dict[str, torch.Tensor]:
-        """
-        Compute median L2 norms across all peers for each parameter.
-
-        Args:
-            gather_result: Either a merged SimpleNamespace or list of partials
-
-        Returns:
-            Dictionary mapping parameter names to median norms
-        """
-        clip_norm_dict = {}
-
-        # Handle both merged results and partials
-        if isinstance(gather_result, list):
-            # Partials: collect vals from all partials for each parameter
-            # Use model's first parameter device as target (typically cuda:0)
-            target_device = next(self.model.parameters()).device
-
-            # Helper to recursively move all tensors to target device
-            def move_to_device_recursive(obj):
-                if torch.is_tensor(obj):
-                    return obj.to(target_device)
-                elif isinstance(obj, dict):
-                    return {k: move_to_device_recursive(v) for k, v in obj.items()}
-                elif isinstance(obj, (list, tuple)):
-                    return type(obj)(move_to_device_recursive(item) for item in obj)
-                else:
-                    return obj
-
-            for n, p in self.model.named_parameters():
-                vals_key = n + "vals"
-                quant_key = n + "quant_params"
-
-                # Collect vals and quant_params from all partials
-                all_vals = []
-                all_quant_params = []
-
-                for partial in gather_result:
-                    # Skip None partials (ranks that gathered 0 peers)
-                    if partial is None:
-                        continue
-
-                    state_dict = partial.state_dict
-                    if not state_dict:
-                        continue
-
-                    vals = getattr(state_dict, vals_key, None)
-                    quant_params = getattr(state_dict, quant_key, None)
-
-                    if vals is not None:
-                        # vals is a list of tensors, one per peer in this partial
-                        # Move all tensors (including nested) to target device
-                        for v in vals:
-                            all_vals.append(move_to_device_recursive(v))
-                        # quant_params is also a list if present
-                        if quant_params is not None:
-                            for qp in quant_params:
-                                all_quant_params.append(move_to_device_recursive(qp))
-
-                if not all_vals:
-                    continue
-
-                # Dequantize all collected values (now all on same device)
-                vals_f32 = self.compressor.maybe_dequantize_values(
-                    all_vals,
-                    all_quant_params if all_quant_params else None,
-                    target_device,
-                )
-
-                # Compute median norm across all peers from all partials
-                norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
-                clip_norm_dict[vals_key] = torch.median(norms)
-        else:
-            # Merged result: original logic
-            state_dict = gather_result.state_dict
-            if not state_dict:
-                raise ValueError("Must have gather_result.state_dict to compute norms")
+        self,
+        gather_result: SimpleNamespace | list[SimpleNamespace | None],
+    ) -> dict[str, torch.Tensor]:
+        """
+        Compute median L2 norms across all peers for each parameter.
+
+        Args:
+            gather_result: Either a merged SimpleNamespace or list of partials.
+
+        Returns:
+            Dictionary mapping parameter names to median norms.
+        """
+        clip_norm_dict: dict[str, torch.Tensor] = {}
+
+        # Handle both merged results and partials
+        if isinstance(gather_result, list):
+            # Partials: collect vals from all partials for each parameter.
+            # Use model's first parameter device as target (typically cuda:0)
+            target_device = next(self.model.parameters()).device
+
+            def move_to_device_recursive(obj):
+                if torch.is_tensor(obj):
+                    return obj.to(target_device)
+                if isinstance(obj, dict):
+                    return {k: move_to_device_recursive(v) for k, v in obj.items()}
+                if isinstance(obj, (list, tuple)):
+                    return type(obj)(move_to_device_recursive(v) for v in obj)
+                return obj
+
+            for n, p in self.model.named_parameters():
+                vals_key = n + "vals"
+                quant_key = n + "quant_params"
+
+                # Separate unquantized and quantized entries to keep qparams aligned
+                all_unquantized_vals: list[torch.Tensor] = []
+                all_quantized_vals: list[torch.Tensor] = []
+                all_quantized_qparams: list[QuantParamsT] = []
+
+                for partial in gather_result:
+                    if partial is None:
+                        continue
+                    state_dict = getattr(partial, "state_dict", None)
+                    if not state_dict:
+                        continue
+
+                    vals = getattr(state_dict, vals_key, None)
+                    quant_params = getattr(state_dict, quant_key, None)
+                    if vals is None:
+                        continue
+
+                    # vals is expected to be a list-like (one per peer in this partial)
+                    for idx, v in enumerate(vals):
+                        v_dev = move_to_device_recursive(v)
+                        if torch.is_tensor(v_dev) and v_dev.dtype == torch.uint8:
+                            # Quantized entry – require matching qparams
+                            if (
+                                quant_params is None
+                                or idx >= len(quant_params)
+                                or quant_params[idx] is None
+                            ):
+                                tplr.logger.warning(
+                                    f"Missing quant_params for {vals_key}[{idx}] in a partial; "
+                                    "skipping for norm computation."
+                                )
+                                continue
+                            qp = move_to_device_recursive(quant_params[idx])
+                            all_quantized_vals.append(v_dev)
+                            all_quantized_qparams.append(qp)
+                        else:
+                            # Treat non‑uint8 as already in float space
+                            all_unquantized_vals.append(v_dev)
+
+                if not all_unquantized_vals and not all_quantized_vals:
+                    continue
+
+                norms_list: list[torch.Tensor] = []
+
+                # Dequantize quantized entries
+                if all_quantized_vals:
+                    vals_f32_q = self.compressor.maybe_dequantize_values(
+                        all_quantized_vals,
+                        all_quantized_qparams,
+                        target_device,
+                    )
+                    norms_list.extend(torch.norm(v, p=2) for v in vals_f32_q)
+
+                # Unquantized entries: just compute norms
+                if all_unquantized_vals:
+                    norms_list.extend(
+                        torch.norm(v.to(target_device), p=2)
+                        for v in all_unquantized_vals
+                    )
+
+                if norms_list:
+                    norms = torch.stack(norms_list).to(p.device)
+                    clip_norm_dict[vals_key] = torch.median(norms)
+
+            # Partials handled entirely above
+            return clip_norm_dict
+
+        # Merged result: original logic
+        state_dict = gather_result.state_dict
+        if not state_dict:
+            raise ValueError("Must have gather_result.state_dict to compute norms")
@@
-        for n, p in self.model.named_parameters():
+        for n, p in self.model.named_parameters():
             vals_key = n + "vals"
             quant_key = n + "quant_params"
@@
-            vals_f32 = self.compressor.maybe_dequantize_values(
-                vals, quant_params, p.device
-            )
+            vals_f32 = self.compressor.maybe_dequantize_values(
+                vals, quant_params, p.device
+            )
@@
-        return clip_norm_dict
+        return clip_norm_dict

This keeps the merged‑result behavior unchanged, makes the partials path self‑contained (no state_dict leakage), and defensively skips any quantized entries that don’t have usable qparams instead of crashing.

src/tplr/comms.py (1)

2562-2836: Sequential fallback after distributed failure remains per-rank; confirm you’re okay with potential cross-rank divergence.

The new gather_with_reserve logic integrates distributed vs sequential paths and fixes the previous type-contract issue when return_partials=True:

  • When use_distributed is True, all ranks participate in gather_distributed, and you correctly normalize primary to a list when return_partials=True.
  • If distributed gather returns no gradients or throws, all ranks fall back to sequential self.gather(...), again wrapping in a single-element list when return_partials=True.
  • Logging and final success_rate computation now treat both merged and partials cases consistently.

However, the sequential fallback path still has each rank call gather() independently against the full gather_uids set, and the results are not synchronized afterward. In steady state, all ranks will most likely see the same gradients (same S3 objects), but under asymmetric network errors or transient bucket issues, different ranks could end up with slightly different primary contents and success rates.

Previous reviews already raised the “rank consistency after distributed fallback” concern; this remains conceptually the same, just now with list-wrapping fixed. Unless downstream code is explicitly robust to cross-rank differences here, you may want to:

  • Either broadcast one rank’s sequential fallback result (merged or partials) to all ranks after the fallback, or
  • Explicitly document that fallback windows are allowed to be slightly inconsistent across ranks.
src/tplr/neurons.py (1)

593-609: Memory budget metric still uses inverted flag logic (budget_exceeded).

Within _outer_step_single you correctly:

  • Resolve a CUDA device for memory accounting via _resolve_cuda_device.
  • Compute final_memory_mb and utilization_pct relative to memory_budget_mb.

But:

budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1

means:

  • budget_exceeded == 0 when the budget is exceeded.
  • budget_exceeded == 1 when you are within budget.

For a metric named memory_budget_violation, this is inverted and will be quite confusing for monitoring and WandB dashboards.

This was called out in earlier reviews and the issue is still present here; you probably want:

budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0

so that 1 consistently means “violation”.

Also applies to: 629-630, 848-877

🧹 Nitpick comments (9)
neurons/miner.py (1)

959-982: Metrics success‑rate math is sound; consider reusing upstream rate and deduping skipped peers

The recomputation of gather_success_rate for partials (using total_uids / (total_uids + total_skipped) * 100) is reasonable and matches the intended semantics, and the merged case reuses gather_result.success_rate. Two small polish ideas:

  • For partials, you could simply reuse the aggregate success rate already attached to each partial by gather_with_reserve instead of recomputing it here.
  • Before logging skipped_uids_list to Influx, consider deduping:
    skipped_uids_list = sorted(set(skipped_uids_list))
    to avoid duplicates if any partials share skipped_uids.

Both are optional, but would slightly simplify the code and make skipped_peers metrics cleaner.

Also applies to: 1046-1046

neurons/validator.py (1)

1539-1565: Consider adding all-rank consensus after gather_with_reserve for strict "all-or-nothing" behavior

The concern is valid: in distributed mode, if gather_with_reserve fails on a non-master rank but succeeds on the master, the current one-way broadcast from master will reset skip_window from the non-master's local True to the master's False. The non-master then continues with gather_result=None and skip_window=False, creating an inconsistent state.

To enforce strict consistency, add an all_ok() consensus step after the gather attempt:

# After gather_with_reserve try/except block (around line 1595)
gather_ok = dist_helper.all_ok(gather_result is not None, self.device, tag="gather_with_reserve")
if use_distributed_gather and not gather_ok:
    skip_window = True

This ensures that if any rank's gather fails, all ranks skip the window. The all_ok() helper already exists at src/tplr/distributed.py:185 and is used in similar contexts elsewhere (save/eval/apply operations at lines 1981, 2111, 2177, 2209, 2286).

Also applies to: 1569-1596, 1597-1633

src/tplr/comms.py (3)

483-582: S3 existence helper looks solid; consider handling additional “NotFound”-style codes.

The new s3_object_exists implementation is conservative and aligns well with the slashing requirements (only returning False for clear 404/config errors, and treating network/service issues as transient). One small robustness gap: some S3-compatible backends may emit "NotFound" as the error code for missing keys. Right now you only treat {404, "NoSuchKey"} as “not found”; adding "NotFound" to that set would avoid misclassifying a missing object as a transient network case.

This is minor and could be done later; overall the structure and logging semantics here look good.
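The suggested classification can be sketched as follows. The helper and names here are hypothetical, not the actual `s3_object_exists` internals; the point is only that `"NotFound"` joins the "definitely missing" set while everything else stays transient (non-slashable).

```python
# Hypothetical classification helper; the real s3_object_exists in
# src/tplr/comms.py uses its own structure and names.
NOT_FOUND_CODES = {"404", "NoSuchKey", "NotFound"}  # "NotFound" covers some
                                                    # S3-compatible backends

def classify_head_error(status_code, error_code) -> str:
    """Map an S3 HEAD failure to 'missing' (safe to treat as not uploaded)
    or 'transient' (network/service issue; do not slash the peer)."""
    if str(status_code) in NOT_FOUND_CODES or str(error_code) in NOT_FOUND_CODES:
        return "missing"
    return "transient"
```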


2213-2337: merge_gather_results semantics match gather() layout; minor duplication with _move_gather_result_to_device.

Functionally this does what we need:

  • Filters out None partials and bails cleanly when nothing usable is present.
  • Computes a device for consolidation (target_device) with a reasonable CUDA/CPU fallback strategy.
  • For each param key, concatenates per-partial lists while recursively moving tensors to target_device.
  • Aggregates metrics (time, bytes, uids, skipped_uids) and computes a global success_rate.

Two small follow-ups you might consider:

  1. You already have _move_gather_result_to_device above; the local move_to_device here is effectively re-implementing the same recursion over containers. Reusing the static helper would remove duplication and centralize behavior.

  2. When torch.cuda.is_available() is True but current_device() fails, you optimistically try to set_device(0). In practice that’s fine, but wrapping that in a single helper (or the _resolve_cuda_device pattern you use in neurons.py) would make behavior more predictable and easier to test.

Functionally this looks correct; the above are just refactors.
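For reference, the recursion that both `_move_gather_result_to_device` and the local `move_to_device` duplicate has roughly this shape (a sketch, not the file's actual implementation):

```python
import torch
from types import SimpleNamespace

def move_to_device(obj, device):
    """Recursively move every tensor found in nested containers
    (dict / list / tuple / SimpleNamespace) to `device`; non-tensor
    leaves pass through unchanged."""
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, tuple):
        return tuple(move_to_device(v, device) for v in obj)
    if isinstance(obj, list):
        return [move_to_device(v, device) for v in obj]
    if isinstance(obj, SimpleNamespace):
        return SimpleNamespace(
            **{k: move_to_device(v, device) for k, v in vars(obj).items()}
        )
    return obj
```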


2339-2552: Distributed gather flow is coherent; watch memory when merging all partials on GPU.

The overall gather_distributed flow is solid:

  • Deterministic round‑robin partitioning with a per-window cache (_partition_cache), so all ranks agree on my_uids.
  • Local gather per rank on its subset, with logging and a local pre-merge (merge_gather_results([partial_result])) when return_partials=False.
  • Explicit move of each rank’s partial to CPU before all_gather_object to avoid GPU OOM.
  • After all‑gather, either:
    • Return a normalized list of partials with rank metadata when return_partials=True, or
    • Merge all partials into one merged_result (on CPU) and then move back to the requested device.

Two things to keep in mind:

  • In the return_partials=False path you locally pre‑merge on each rank and then globally merge all_partials, which can temporarily double CPU footprint. That’s unavoidable if you want a fully merged namespace, but it’s worth being aware of when you scale peers×ranks up.
  • In the return_partials=True path, downstream code (e.g. _outer_step_incremental) currently merges all partials again into a single result on the target device, which largely reverts memory savings from keeping them as partials. If you intend to reintroduce true incremental/batched application later, you already have the helper scaffolding (_estimate_partial_memory_mb, _calculate_optimal_batch_size) to do so.

No immediate correctness bug here, but future tuning for very large peer sets might want to avoid “merge-all-partials-twice” patterns.

src/tplr/neurons.py (3)

22-23: CUDA device resolver is reasonable; consider catching broader errors for robustness.

_resolve_cuda_device gives you a clean way to:

  • Short-circuit when CUDA isn’t available.
  • Prefer current_device() when possible.
  • Fall back to "cuda:0" when the provided device is invalid.

One minor hardening tweak you might consider is broadening the exception handling around torch.cuda.current_device() and torch.device(device) to catch Exception rather than only RuntimeError/TypeError/ValueError, to cover any assertion-style failures from CUDA initialization in some PyTorch builds. Not critical, but would make this helper more bulletproof.

Also applies to: 48-69
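The hardened resolver could look like this sketch (illustrative name; the real helper is `_resolve_cuda_device` with its own signature), where the broad `except Exception` covers assertion-style CUDA init failures:

```python
import torch

def resolve_cuda_device(device):
    """Return a usable CUDA device, or None when CUDA is unavailable:
    prefer the current device, fall back to validating the caller's
    `device`, and finally to cuda:0."""
    if not torch.cuda.is_available():
        return None
    try:
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    except Exception:
        pass
    try:
        return torch.device(device)
    except Exception:
        return torch.device("cuda:0")
```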


249-321: outer_step dispatch between single vs list results is clean; note that batching knobs are currently a no-op.

The updated outer_step signature and dispatch logic look good:

  • Accepts gather_result as either SimpleNamespace, list of partials, or None.
  • For lists, delegates to _outer_step_incremental; otherwise falls back to _outer_step_single.
  • Propagates memory_budget_mb and batch_size_override through.

However, in the current implementation:

  • _outer_step_incremental does not use batch_size_override, and
  • There is no call to _calculate_optimal_batch_size for true incremental/batched processing.

So from a caller’s perspective, the incremental path always merges all partials into a single merged_result before applying, and the “batching” knobs are effectively ignored (only memory metrics are logged).

If you intend this API to provide real memory-bounded batching, consider either wiring _calculate_optimal_batch_size into _outer_step_incremental or, for now, documenting that memory_budget_mb / batch_size_override are advisory/logging only.


324-369: Per-partial memory estimator is correct for current layouts but currently unused.

_estimate_partial_memory_mb correctly:

  • Recurses through SimpleNamespace, dicts, lists, tuples to find tensors.
  • Sums element_size() * nelement() and applies a 1.3× overhead factor.
  • Logs MB-per-UID when it can.

Right now this helper isn’t used anywhere in the file (including _outer_step_incremental), which means it adds maintenance cost without delivering behavior. If you don’t plan to hook it up to dynamic batch sizing soon, it might be better to either:

  • Wire it into _calculate_optimal_batch_size and then use that in _outer_step_incremental, or
  • Drop it entirely until you reintroduce true incremental batching.
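Should you wire it up, the estimator's core is just a container walk summing tensor bytes — a sketch under the review's description (1.3× overhead factor), not the file's exact code:

```python
import torch
from types import SimpleNamespace

def estimate_memory_mb(obj, overhead: float = 1.3) -> float:
    """Sum element_size() * nelement() over every tensor reachable through
    dicts, lists, tuples, and SimpleNamespaces, then apply an overhead
    factor and convert to MiB."""
    total = 0
    stack = [obj]
    while stack:
        cur = stack.pop()
        if torch.is_tensor(cur):
            total += cur.element_size() * cur.nelement()
        elif isinstance(cur, dict):
            stack.extend(cur.values())
        elif isinstance(cur, (list, tuple)):
            stack.extend(cur)
        elif isinstance(cur, SimpleNamespace):
            stack.extend(vars(cur).values())
    return total * overhead / (1024 * 1024)
```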
neurons/trainer.py (1)

972-1021: Align returned grad/Adam metrics with new per-step grad_norm logging

You now log grad_norm: {total_grad_norm:.4f} per inner step (Line 928), but the summary metrics at the end of inner_steps (adam_metrics, global_grad_norm, global_weight_norm, Lines 977–1021) are computed after self.inner_optimizer.zero_grad(set_to_none=True) has run in the last step, so p.grad is typically None and all grad-based metrics will be ~0.

Docstring on compute_adam_metrics explicitly says it “should be called BEFORE gradients are cleared,” which isn’t currently true here.

Consider one of:

  • Move the metrics computation into the if step_now: block, right before zero_grad, and cache the last-step metrics to return; or
  • Defer zero_grad until after metrics are computed for the final step; or
  • If you only care about per-step total_grad_norm, drop or rename the returned global_grad_norm / grad-based Adam metrics to avoid confusion.

This will keep aggregated metrics consistent with the newly logged per-step grad_norm.
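The second option (defer `zero_grad` until after metrics) can be sketched like this — illustrative names, not the trainer's actual API:

```python
import torch

def step_with_metrics(model, optimizer):
    """Compute the gradient norm BEFORE gradients are cleared, per the
    compute_adam_metrics docstring, then step and zero the grads."""
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    grad_norm = (
        torch.norm(torch.stack([g.norm() for g in grads])).item()
        if grads else 0.0
    )
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)  # clear only after metrics
    return grad_norm
```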

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ca50d92 and 8488c13.

📒 Files selected for processing (6)
  • neurons/miner.py (9 hunks)
  • neurons/trainer.py (4 hunks)
  • neurons/validator.py (14 hunks)
  • src/tplr/comms.py (20 hunks)
  • src/tplr/compress.py (1 hunks)
  • src/tplr/neurons.py (13 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/tplr/compress.py
🧰 Additional context used
🧬 Code graph analysis (4)
neurons/miner.py (3)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/distributed.py (1)
  • is_distributed (75-77)
src/tplr/comms.py (1)
  • gather_with_reserve (2556-2836)
src/tplr/neurons.py (2)
src/tplr/compress.py (6)
  • compress (373-377)
  • compress (379-383)
  • compress (386-423)
  • ChunkingTransformer (145-292)
  • TopKCompressor (295-697)
  • unpack_12bit_indices (102-142)
src/tplr/comms.py (1)
  • merge_gather_results (2214-2337)
neurons/validator.py (5)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • broadcast (273-282)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/comms.py (3)
  • gather_with_reserve (2556-2836)
  • merge_gather_results (2214-2337)
  • get_with_retry (1625-1742)
src/tplr/neurons.py (1)
  • check_uid_index_overlap (1610-1815)
src/tplr/compress.py (1)
  • maybe_dequantize_values (647-697)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
  • Bucket (24-46)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • all_gather_object (307-325)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.12)
  • GitHub Check: test (3.11)
🔇 Additional comments (22)
neurons/miner.py (6)

31-35: CUDA allocator pre-config and dist import look solid

Setting PYTORCH_CUDA_ALLOC_CONF before importing torch and pulling in torch.distributed as dist for the later broadcast_object_list is correct and low‑risk; no changes needed here.

Also applies to: 42-42


548-556: Pre‑inner‑step GPU cleanup is a good guardrail against OOM

The aggressive torch.cuda.empty_cache()/synchronize() plus model.zero_grad(set_to_none=True) before inner_steps is a sensible addition to shed residue from the previous outer step and reduce OOM risk; the conditional on self.model also avoids NPEs.


734-755: Distributed vs sequential gather selection is consistent; double‑check sequential semantics across ranks

The use_distributed_gather flag and the if self.is_master or use_distributed_gather gate correctly ensure all ranks participate in distributed mode and only rank 0 runs the sequential gather, with return_partials wired as expected. The completion logging also correctly handles both list‑of‑partials and merged SimpleNamespace results.

One thing to verify: in sequential mode, only master sets gather_result, but all ranks later call outer_step(gather_result) guarded by should_update coming from dist_helper.all_ok(...). Please confirm that all_ok and outer_step are implemented such that non‑master ranks never end up calling outer_step(None) (e.g., by syncing should_update appropriately or broadcasting a non‑None result), to avoid per‑rank divergence.

Also applies to: 759-760, 774-820


849-852: Post‑outer‑step gradient clearing is appropriate

Zeroing model gradients right after a successful outer step is a good addition and keeps the next window from accidentally reusing stale grads; this aligns with the aggressive cleanup strategy elsewhere.


904-924: Successful/skipped peers aggregation now correctly supports partial and merged gather results

The updated debug block that derives successful_peers/skipped_peers by aggregating uids and skipped_uids across a list of partials (or reading directly from a merged result) removes the previous incompatibility when return_partials=True. Using set arithmetic against self.comms.peers keeps the debug output robust even if partials differ.


1066-1074: gather_result cleanup now happens at the right time and is safely scoped

Clearing state_dict.__dict__ for each partial (or merged) result, then deleting gather_result only after all debug/metrics usage, is a good fix for the earlier lifecycle issue and should materially reduce peak memory. The defensive hasattr(..., "__dict__") checks avoid breaking if state_dict ever changes type.

neurons/validator.py (7)

94-141: Static helpers for gather_result look correct and centralize logic nicely

The new _get_uids_from_gather_result, _get_skipped_uids_from_gather_result, and _get_success_rate_from_gather_result correctly handle both partial lists and merged namespaces, including deduplication for partials. This should make downstream code simpler and less error‑prone.


982-988: OpenSkill rating initialization before rate call is sound

Ensuring all window_uids have initialized openskill_ratings before building teams avoids key/shape inconsistencies and makes the rating update path robust for first‑time peers.


1399-1416: Peers/reserve broadcast to all ranks is appropriate for distributed gather

Broadcasting [self.comms.peers, self.comms.reserve_peers] from master ensures all ranks use the same peer sets when distributed_gather is enabled. The guard on world_size > 1 and dist_helper.is_distributed() avoids misuse in non‑distributed runs.


1690-1703: Reuse of helpers for overlap logging and gather metrics is consistent

Using _get_uids_from_gather_result, _get_skipped_uids_from_gather_result, and _get_success_rate_from_gather_result for:

  • overlap logging (UIDs checked),
  • skipped UID logging,
  • and gather‑quality metrics

keeps behavior consistent between distributed (partials) and sequential (merged) modes and avoids hand‑rolled counting in multiple places.

Also applies to: 1718-1722, 1744-1746


1888-1903: Passing raw gather_result into compute_peer_val_norms aligns with its new signature

The call now correctly relies on compute_peer_val_norms to handle both merged namespaces and partial lists without pre‑merging on the validator side, which is what you want in distributed gather mode. The remaining caveats are inside compute_peer_val_norms itself (see separate comment).


2556-2561: Wiring gradient memory budget and incremental batch size into outer_step

Plumbing gradient_memory_budget_mb and incremental_batch_size from hparams into tplr.neurons.outer_step via optional kwargs is a clean way to make incremental processing opt‑in without changing existing behavior when these fields are absent.


3281-3295: Switch to comms.get_with_retry for debug dict fetching improves robustness

Using get_with_retry with a bounded timeout and explicitly handling the None case avoids tight retry loops and clarifies failure handling. The debug_result is None or not debug_result.success guard correctly covers both timeout and error responses, and the warning log is defensive even when debug_result is None.

src/tplr/comms.py (4)

1668-1673: get_with_retry early-exit semantics for repeated NOT_FOUND look reasonable.

The new time_max_grace_period and consecutive NOT_FOUND tracking are well thought out:

  • Extending the grace period to 10s should reduce false negatives from clock skew.
  • Cutting off after 5 consecutive NOT_FOUND responses avoids spinning uselessly when a file clearly isn’t going to show up, without affecting TOO_LATE/TOO_EARLY handling.

No changes needed here from my side.

Also applies to: 1727-1739


1846-1875: Parallel existence checks + slashing attribution logic look correct and materially safer.

The new block that:

  • Collects all (uid, key, peer_bucket) for errored/None responses,
  • Runs s3_object_exists for them in parallel, and
  • Distinguishes “file exists but download/get failed” vs “file truly missing” vs “no bucket committed”

fixes the previous over-slashing risk and removes the sequential 5s-per-failure bottleneck.

Behavioral details look right:

  • Unregistered (peer_bucket is None) UIDs are treated as miners that didn’t upload and go straight to skipped_uids.
  • If file_exists is True, you log a warning but do not append to skipped_uids, explicitly avoiding slashing for network/time-window issues.
  • Only confirmed-missing files (file_exists is False) get tagged as legitimate upload failures and added to skipped_uids.

I don’t see correctness issues here; this is a clear improvement.

Also applies to: 1882-1913, 1915-1939


142-145: Partition cache and deterministic UID assignment across ranks are well-structured.

The _partition_cache keyed by (window, tuple(sorted_uids), world_size) with an LRU cap of 32 entries gives you deterministic round‑robin assignments while avoiding recomputing partitions for the same (window, peer set, world size).

A few minor observations:

  • Using sorted_uids in the key guarantees stability even if call sites pass different orderings.
  • Falling back to effective_world_size = max(world_size, 1) avoids division issues when world_size is misconfigured.
  • For typical validator workloads, a cache size of 32 windows should be more than enough; if you ever do very long runs with frequent re-partitioning you might consider making this limit a config knob, but it’s not necessary now.

No action required; just confirming the design looks sound.

Also applies to: 2375-2393
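For readers unfamiliar with the `sorted_uids[r::world_size]` pattern, the partitioning boils down to (cache omitted; a sketch, not the cached implementation):

```python
def partition_uids(uids, world_size: int, rank: int):
    """Deterministic round-robin partition: sort for a stable order, then
    take every world_size-th UID starting at `rank`. Every rank computes
    the same partition for the same inputs."""
    effective_world_size = max(world_size, 1)
    sorted_uids = sorted(uids)
    return sorted_uids[rank::effective_world_size]
```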


640-649: Based on my verification:

PyTorch version compatibility confirmed: The project requires torch>=2.7.1, which fully supports the weights_only argument in torch.load.

Saved format compatibility confirmed: All checkpoint files saved in the codebase use plain Python dicts containing tensors and primitives (validator state, comms data, etc.), which are fully compatible with weights_only=True according to PyTorch's serialization design.

Code locations verified:

  • Lines 714-717 (src/tplr/comms.py): Remote file loading with weights_only=True
  • Lines 1564-1565 (src/tplr/comms.py): Local file loading with weights_only=True

Both locations extract state_dict and global_step from a plain dict format that is compatible with weights_only=True. No runtime errors should occur from the weights_only parameter itself.

src/tplr/neurons.py (4)

371-419: Incremental outer step now “merge-then-apply”; batching knobs are ignored but semantics match catch‑up.

_outer_step_incremental:

  • Filters and deterministically sorts partials by partial.rank (or metadata) for reproducible behavior.
  • Merges all valid partials into a single merged_result via Comms.merge_gather_results, then calls _outer_step_single once.
  • Aggressively frees partial state_dicts and clears CUDA cache before and after application.
  • Logs memory deltas when memory_budget_mb + usable CUDA device are present.

This ensures the application semantics match the catch‑up path (one merged gradient application) and is easy to reason about.

Trade-offs to be aware of:

  • It no longer processes partials in smaller batches, so peak memory is that of a fully merged gather on the target device, plus whatever overhead merge_gather_results needs.
  • batch_size_override is explicitly documented as unused, and _calculate_optimal_batch_size is not invoked.

If this is an intentional simplification for now, the code is consistent; just be aware that the “incremental” path is currently about how you construct the merged result (via distributed gather/partials), not about memory-bounded incremental application.

Also applies to: 421-444, 446-587


1413-1477: Catch-up outer_step integration and debug-dict comparison flow look coherent.

The catch-up loop now:

  • Uses outer_step uniformly for both aggregated objects and live gather fallbacks, so the same application path is exercised as in normal training.
  • Adds a per-window debug-dict comparison (when available) to estimate “steps behind” and log L2 norms, while guarding errors so catch-up continues even if debug data is missing or malformed.
  • Cleans up GPU and CPU memory aggressively after each window.

This wiring looks correct and should make catch-up behavior much easier to reason about and monitor.


1554-1569: Step-size denominator change in compare_model_with_debug_dict is a good stabilization.

In compare_model_with_debug_dict, switching to:

  • Use learning_rate as a baseline step size, and
  • Take max(param_avg_change[name], learning_rate) elementwise

before computing step_ratio = abs_vec / step_vec is a sensible improvement:

  • It avoids exploding step ratios when the EMA of parameter changes becomes extremely small in converged regimes.
  • The learning_rate lower bound keeps ratios on a more interpretable scale.

Implementation details (shape guard + fallback when stored slice length mismatches) look correct.
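The stabilized ratio reduces to an elementwise clamp before the division — a sketch of the arithmetic described above, not the function's exact code:

```python
import torch

def step_ratios(abs_change, ema_change, learning_rate: float):
    """Floor the EMA of per-parameter change at the learning rate before
    dividing, so ratios stay bounded when the EMA collapses in converged
    regimes."""
    step_vec = torch.clamp(ema_change, min=learning_rate)
    return abs_change / step_vec
```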


1609-1651: Overlap checker’s support for partial results and O(n) UID lookup look correct; pair indexing is consistent.

The updated check_uid_index_overlap:

  • Accepts either a merged SimpleNamespace or a list of partials.
  • For partials, builds:
    • uids as a flat list of all UIDs across partials,
    • uid_to_partial[uid] = (partial, idx_in_partial) for O(1) lookup,
    • all_state_dicts as a parallel list, though it’s no longer used on the partial path.
  • When unpacking indices:
    • Uses per-UID state_dict slices (uid_idx_in_partial) for partials, and
    • The original merged indexing for single-result cases.

The valid_uids_for_param bookkeeping plus:

uid_idx_i = valid_uids_for_param[i]
uid_idx_j = valid_uids_for_param[j]
acc = pair_acc[(uid_idx_i, uid_idx_j)]
...
uids[i], uids[j]

maintains a consistent mapping from pair-accumulator keys back to global uids indices, so offender attribution and summary logs still refer to the right peer IDs.

This looks like a correct and more scalable implementation than the original O(n×m) search.

Also applies to: 1684-1731, 1743-1760

neurons/trainer.py (1)

1023-1040: Threading new memory_budget_mb / batch_size_override into outer_step looks good—verify signature & config

Passing

  • memory_budget_mb=getattr(self.hparams, "gradient_memory_budget_mb", None) and
  • batch_size_override=getattr(self.hparams, "incremental_batch_size", None)

is backward-compatible and keeps trainer decoupled from whether these fields exist in all configs.

Two checks to confirm:

  1. Ensure tplr.neurons.outer_step has matching keyword params with sensible defaults so older call sites (if any) remain valid.
  2. Decide how to handle non-positive values (e.g., 0 or <0) in these hparams—treat as “no limit” or validate earlier and keep only None reaching outer_step.

If both are already handled in outer_step, this wiring is solid.

Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (2)
src/tplr/neurons.py (1)

842-870: Fix inverted budget_exceeded flag logic.

Line 848 has inverted logic: budget_exceeded is set to 0 when memory exceeds budget and 1 when within budget. This makes the memory_budget_violation metric (line 867) report the opposite of reality.

Apply this diff:

-        budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+        budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0
neurons/validator.py (1)

1902-1906: Fix compute_peer_val_norms for partials: state_dict undefined + qparams alignment risk

When gather_result is a list (partials path):

  • After the if isinstance(gather_result, list): ... block completes, execution falls through to the merged‑result logic. state_dict is only defined in the else branch, so the subsequent loop
for n, p in self.model.named_parameters():
    vals_key = n + "vals"
    ...
    vals = getattr(state_dict, vals_key, None)

will raise UnboundLocalError: local variable 'state_dict' referenced before assignment. This will trigger whenever distributed_gather is enabled and compute_peer_val_norms is invoked on partials.

  • Inside the partials branch, all_vals and all_quant_params can become length‑misaligned when some partials have vals but no quant_params, or mismatched list lengths. Passing such arrays into self.compressor.maybe_dequantize_values(all_vals, all_quant_params if all_quant_params else None, ...) can cause IndexError or incorrectly treat quantized uint8 tensors as unquantized floats, echoing the earlier misalignment concern.

Please:

  1. Return early from the partials branch after populating clip_norm_dict, so the merged‑result loop is not executed for lists.
  2. Align all_vals and all_quant_params defensively:
    • Only append qparams when present and matching the number of vals.
    • For unquantized vals, append None sentinel entries so vals and qparams stay index‑aligned.
    • Optionally filter out any residual (v, qp) pairs where v.dtype == torch.uint8 and qp is None before calling maybe_dequantize_values.

An example structural fix:

-        if isinstance(gather_result, list):
+        if isinstance(gather_result, list):
             ...
-                if not all_vals:
-                    continue
+                if not all_vals:
+                    continue
                 ...
-                clip_norm_dict[vals_key] = torch.median(norms)
-        else:
+                clip_norm_dict[vals_key] = torch.median(norms)
+            return clip_norm_dict  # <-- avoid falling through to merged path
+        else:
             # Merged result: original logic
             state_dict = gather_result.state_dict

and tighten the all_vals / all_quant_params accumulation as per the earlier suggestion to keep them length‑aligned.

Also applies to: 3876-3974

🧹 Nitpick comments (4)
src/tplr/comms.py (4)

483-582: s3_object_exists error handling is solid; consider purging broken clients on connection errors.

The structured use of HTTPStatusCode/Error.Code and conservative defaults looks good and should reduce false slashing. One minor improvement: unlike other S3 helpers in this class, this function never calls _purge_s3_client on ConnectionClosedError/ClientError, which means a bad client can remain in the pool and keep failing future HEADs. You could call await self._purge_s3_client(bucket) when you detect persistent client-side issues (e.g., ConnectionClosedError or non‑404/403 4xx), while keeping the same boolean return semantics.


2151-2176: partition_uids_across_ranks helper is currently unused; consider wiring it in or dropping it.

You now have both:

  • partition_uids_across_ranks() (round‑robin helper), and
  • Inline partitioning in gather_distributed using the same sorted_uids[r::world_size] pattern plus _partition_cache.

Because the static helper is not used, it may confuse future readers and drift from the cached implementation. Either:

  • Reuse the helper inside gather_distributed (e.g., to build each partitions[r]), or
  • Remove the helper if the cached implementation is the only intended path.

This keeps the public surface lean and avoids duplication.

Also applies to: 2375-2390


2178-2338: Avoid redundant device moves in merge_gather_results and gather_distributed.

When return_partials=False:

  • merge_gather_results defaults target_device to cuda:{current_device} when CUDA is available, and moves all tensors there.
  • Later, gather_distributed may move merged_result again to device (which can be a different GPU index).

On multi‑GPU setups this can cause two to() passes and cross‑GPU copies when device differs from cuda:{current_device}. To tighten this:

  • Pass the caller’s device explicitly into merge_gather_results from gather_distributed, and
  • Skip the final _move_gather_result_to_device(merged_result, device) when you already merged onto that device.

This keeps semantics but avoids unnecessary transfers and surprises around which GPU is used.

Also applies to: 2484-2504


2339-2518: Sequential fallback after distributed gather can still yield divergent results across ranks.

In the distributed path of gather_with_reserve:

  • When distributed gather fails or returns no gradients, every rank independently calls self.gather(...) as a sequential fallback and uses its own result (primary), without any subsequent broadcast or synchronization.
  • Because gather() internally does S3 I/O with retries and time‑window filtering, different ranks can legitimately end up with different subsets of successful UIDs and different skipped_uids, while downstream code may be assuming cross‑rank consistency.

For stricter consistency, consider:

  • Having only dist_helper.is_master perform the sequential gather(), then broadcasting the result (merged or partials) to all ranks via dist_helper.all_gather_object/broadcast_object, or
  • Explicitly documenting that in fallback mode each rank may see a different primary and that downstream consumers must not assume identical results.

Also in partials mode:

  • Sequential fallback currently returns a list with a single partial (or [None]), and those partials do not have a rank attribute set (unlike the distributed case). Ensure incremental consumers handle this gracefully (e.g., don’t rely on partial.rank always being present or list length == world size).

These are behavioral clarifications rather than correctness bugs, but worth tightening given this is core training infrastructure.

Also applies to: 2523-2803

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8488c13 and 7526353.

📒 Files selected for processing (4)
  • neurons/miner.py (9 hunks)
  • neurons/validator.py (15 hunks)
  • src/tplr/comms.py (20 hunks)
  • src/tplr/neurons.py (13 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
neurons/miner.py (3)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/distributed.py (1)
  • is_distributed (75-77)
src/tplr/comms.py (1)
  • gather_with_reserve (2523-2803)
neurons/validator.py (4)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • broadcast (273-282)
src/tplr/comms.py (3)
  • gather_with_reserve (2523-2803)
  • merge_gather_results (2214-2337)
  • get_with_retry (1625-1742)
src/tplr/neurons.py (1)
  • check_uid_index_overlap (1604-1809)
src/tplr/compress.py (1)
  • maybe_dequantize_values (647-697)
src/tplr/comms.py (2)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • all_gather_object (307-325)
src/tplr/neurons.py (2)
src/tplr/compress.py (6)
  • compress (373-377)
  • compress (379-383)
  • compress (386-423)
  • ChunkingTransformer (145-292)
  • TopKCompressor (295-697)
  • unpack_12bit_indices (102-142)
src/tplr/comms.py (1)
  • merge_gather_results (2214-2337)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.11)
  • GitHub Check: test (3.12)
🔇 Additional comments (26)
src/tplr/comms.py (2)

666-669: Offloading torch.load is a good change; verify weights_only=True matches your PyTorch floor.

Moving both remote and local torch.load calls into asyncio.to_thread is exactly what you want to avoid blocking the event loop on CPU decompression. The only compatibility concern is the use of weights_only=True, which is only supported on newer PyTorch versions; on older releases it will raise a TypeError. If you still support 1.x/early 2.x, consider guarding this flag behind a version check or kwargs construction; otherwise, document that the minimum PyTorch version includes weights_only.

Also applies to: 710-723, 1564-1566
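One way to guard the flag, sketched without importing torch: build the `torch.load` kwargs from the runtime version string and only pass `weights_only` when it is supported (it was added around PyTorch 1.13 — treat that floor as an assumption to verify against your minimum supported version). `torch_load_kwargs` is a hypothetical helper name.

```python
def torch_load_kwargs(torch_version: str, map_location: str = "cpu") -> dict:
    """Build kwargs for torch.load, passing weights_only=True only on
    versions that accept it (assumed floor: 1.13)."""
    parts = torch_version.split("+")[0].split(".")  # drop local tags like +cu121
    major, minor = int(parts[0]), int(parts[1])
    kwargs = {"map_location": map_location}
    if (major, minor) >= (1, 13):
        kwargs["weights_only"] = True
    return kwargs


assert "weights_only" not in torch_load_kwargs("1.12.1")
assert torch_load_kwargs("2.1.0+cu121")["weights_only"] is True
```

Usage would then be `torch.load(path, **torch_load_kwargs(torch.__version__))`, keeping older releases from hitting a `TypeError`.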


1846-1939: New parallel existence checks and slashing classification in gather look correct.

The two‑phase approach—first collecting failed UIDs, then fanning out s3_object_exists checks in parallel and using those results to distinguish “no bucket”, “file truly missing”, and “file exists but download failed”—matches the intended slashing semantics and avoids the previous serial HEAD bottleneck. The reuse of cached existence_results in both the Exception and None branches is consistent, and skipping skipped_uids updates when the file exists prevents penalizing miners for network/timeout issues. No changes needed here.

neurons/miner.py (8)

31-34: LGTM: CUDA allocator configured correctly.

Setting expandable_segments:True before importing torch is a recommended practice for better memory management with PyTorch CUDA operations.


474-490: Well implemented: Commitments now synchronized across all ranks.

The master rank refreshes commitments and broadcasts them along with peers, ensuring all ranks have consistent bucket mappings for distributed gather. This addresses the stale commitment concern from previous reviews.


552-560: LGTM: Proactive memory cleanup before training.

The aggressive memory cleanup before inner_steps ensures that memory from the previous outer step is fully released, helping prevent OOM errors.


738-824: Well designed: Distributed vs sequential gather mode selection and logging.

The code correctly:

  • Determines gather mode based on hparams and distributed state
  • Ensures all ranks participate in distributed gather while only master performs sequential gather
  • Handles both partial results (list) and merged results (SimpleNamespace) in logging
  • Includes proper remote fetch parameters (local=False, stale_retention=100) that were flagged in past reviews

852-856: LGTM: Gradients cleared after outer step.

Properly clears gradients after the outer step is applied, preventing memory accumulation.


908-927: LGTM: Robust handling of both partial and merged gather results.

The code correctly extracts successful and skipped UIDs from both partial lists and merged results, using sets to deduplicate when aggregating across partials.


963-986: LGTM: Comprehensive success rate calculation for all gather result types.

The code correctly handles None, partial lists, and merged results when calculating success rates and extracting skipped UIDs for metrics logging.
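The three-way handling (None, partial list, merged result) can be condensed into one normalizing helper; this is an illustrative sketch with a hypothetical name, not the miner's actual function, and it assumes results expose `uids` / `skipped_uids` as in the review.

```python
from types import SimpleNamespace


def gather_success_rate(gather_result) -> float:
    """Normalize None, a list of per-rank partials, or a merged result
    into one success rate: successful / (successful + skipped)."""
    if gather_result is None:
        return 0.0
    if isinstance(gather_result, list):
        uids, skipped = set(), set()
        for partial in gather_result:
            if partial is None:
                continue
            uids.update(getattr(partial, "uids", []) or [])
            skipped.update(getattr(partial, "skipped_uids", []) or [])
        total = len(uids) + len(skipped)
        return len(uids) / total if total else 0.0
    total = len(gather_result.uids) + len(gather_result.skipped_uids)
    return len(gather_result.uids) / total if total else 0.0


merged = SimpleNamespace(uids=[1, 2, 3], skipped_uids=[4])
partials = [SimpleNamespace(uids=[1, 2], skipped_uids=[4]), None,
            SimpleNamespace(uids=[2, 3], skipped_uids=[])]
assert gather_success_rate(merged) == 0.75
assert gather_success_rate(partials) == 0.75   # sets dedupe UID 2
assert gather_success_rate(None) == 0.0
```

Note how sets deduplicate UIDs that appear in more than one partial, matching the aggregation the review describes.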


1067-1081: LGTM: Thorough cleanup after gather_result consumption.

The code properly clears state_dict contents from both partial and merged gather results after they've been fully used, followed by garbage collection and CUDA cache cleanup to reclaim memory.

src/tplr/neurons.py (8)

48-69: LGTM: Robust CUDA device resolution helper.

This helper correctly converts various device representations to torch.device objects, preventing the string-device errors flagged in previous reviews. The fallback logic is sound.


250-321: LGTM: Clean dispatch pattern for incremental vs single outer step.

The outer_step function now dispatches to the appropriate implementation based on whether gather_result is a list of partials or a single merged result, with consistent parameter passing to both paths.


324-368: Well implemented: Memory estimation now handles all container types.

The recursive tensor iterator correctly walks through SimpleNamespace, dicts, lists, and tuples to accurately measure memory, addressing the previous concern about underestimation.
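The recursive walk over mixed containers might look like the sketch below. Names are hypothetical, and tensors are duck-typed on `.nbytes` (exposed by numpy arrays and recent torch tensors) so the example stays dependency-free; `FakeTensor` stands in for a real tensor.

```python
from types import SimpleNamespace


def iter_tensors(obj):
    """Recursively yield tensor-like leaves (anything exposing .nbytes)
    from SimpleNamespace, dict, list, and tuple containers."""
    if hasattr(obj, "nbytes"):
        yield obj
    elif isinstance(obj, SimpleNamespace):
        yield from iter_tensors(vars(obj))
    elif isinstance(obj, dict):
        for v in obj.values():
            yield from iter_tensors(v)
    elif isinstance(obj, (list, tuple)):
        for v in obj:
            yield from iter_tensors(v)


def estimate_mb(obj) -> float:
    """Total size of all tensor leaves, in MiB."""
    return sum(t.nbytes for t in iter_tensors(obj)) / (1024 * 1024)


class FakeTensor:
    def __init__(self, nbytes):
        self.nbytes = nbytes


partial = SimpleNamespace(
    state_dict=SimpleNamespace(w_vals=[FakeTensor(2**20), FakeTensor(2**20)]),
    meta={"idxs": (FakeTensor(2**19),)},
)
assert estimate_mb(partial) == 2.5   # 1 MiB + 1 MiB + 0.5 MiB
```

Walking every container type is what prevents the underestimation the earlier review flagged: tensors nested inside tuples or namespace attributes still get counted.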


371-418: LGTM: Sensible batch size calculation based on memory budget.

The function estimates average partial size from samples and calculates how many partials fit in the available budget, with appropriate safety margin and bounds checking.
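In outline, the budget-driven batch sizing reduces to arithmetic like this. A sketch under assumed defaults (the 0.8 safety margin and 1–64 bounds are illustrative, not the project's actual values):

```python
def compute_batch_size(
    sample_sizes_mb: list[float],
    budget_mb: float,
    safety_margin: float = 0.8,
    min_batch: int = 1,
    max_batch: int = 64,
) -> int:
    """How many partials fit in the memory budget, based on the average
    size of a few sampled partials, with a safety margin and clamping."""
    if not sample_sizes_mb:
        return min_batch
    avg_mb = sum(sample_sizes_mb) / len(sample_sizes_mb)
    fit = int((budget_mb * safety_margin) / max(avg_mb, 1e-6))
    return max(min_batch, min(max_batch, fit))


assert compute_batch_size([100.0, 120.0, 80.0], budget_mb=1000.0) == 8
assert compute_batch_size([5000.0], budget_mb=1000.0) == 1   # never below 1
assert compute_batch_size([1.0], budget_mb=10**6) == 64      # clamped above
```

The lower bound guarantees forward progress even when a single partial exceeds the budget, which is the important edge case for 70B-scale gradients.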


421-443: LGTM: Simple wrapper delegating to existing merge logic.

Appropriately delegates to Comms.merge_gather_results for consistency with the rest of the codebase.


446-584: Excellent: Incremental path merges all partials before applying once.

The implementation correctly merges all partials into a single result before applying it via _outer_step_single. This ensures consistency with catch-up behavior and prevents gradient double-application. The deterministic rank-based sorting ensures reproducibility.


1549-1563: LGTM: More stable step ratio calculation.

Using the maximum of param_avg_change and learning_rate prevents division by tiny numbers when the model has converged, resulting in more meaningful step-behind estimates.
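The denominator clamp is simple enough to show inline; `steps_behind` is a hypothetical name for the ratio the review describes:

```python
def steps_behind(l2_diff: float, param_avg_change: float, learning_rate: float) -> float:
    """Estimate how many optimizer steps a peer lags, clamping the
    denominator so a near-zero average change can't blow up the ratio."""
    denom = max(param_avg_change, learning_rate)
    return l2_diff / denom


# converged model: a tiny per-step change would otherwise inflate the estimate
assert abs(steps_behind(0.01, param_avg_change=1e-9, learning_rate=1e-4) - 100.0) < 1e-6
# normal regime: the observed average change dominates the learning rate
assert abs(steps_behind(0.5, param_avg_change=0.05, learning_rate=1e-4) - 10.0) < 1e-6
```

Without the clamp, the first case would report ten million steps behind instead of one hundred.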


1604-1809: Well optimized: O(n) lookup complexity for partial results.

The function now supports both partial and merged gather results. The uid_to_partial mapping (line 1625-1635) enables O(1) lookups when processing partials (line 1690-1693), addressing the O(n×m) complexity concern from previous reviews.
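The lookup-map pattern behind the O(n×m) → O(n) improvement can be sketched as below; `build_uid_to_partial` is an illustrative name, and the assumed shape is each partial carrying a `uids` list, as in the review.

```python
from types import SimpleNamespace


def build_uid_to_partial(gather_result):
    """Map uid -> (partial, local_index) so per-UID lookups are O(1)
    instead of scanning every partial for every UID."""
    if not isinstance(gather_result, list):   # merged result: one "partial"
        return {uid: (gather_result, i) for i, uid in enumerate(gather_result.uids)}
    mapping = {}
    for partial in gather_result:
        if partial is None:
            continue
        for i, uid in enumerate(partial.uids):
            mapping.setdefault(uid, (partial, i))  # first occurrence wins
    return mapping


p0 = SimpleNamespace(uids=[3, 7], rank=0)
p1 = SimpleNamespace(uids=[5], rank=1)
lookup = build_uid_to_partial([p0, None, p1])
assert lookup[7] == (p0, 1) and lookup[5] == (p1, 0)
```

The local index is what lets the overlap check unpack the right per-UID slice (e.g. 12-bit indices) out of whichever partial holds that UID.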

neurons/validator.py (8)

90-141: Helpers for unified gather-result handling look solid

The new static helpers correctly normalize partial vs merged gather results, dedupe UIDs to avoid double-counting, and provide a consistent success-rate view for downstream logging and slashing logic. The implementations are straightforward and type‑safe for both list and SimpleNamespace inputs.


982-988: OpenSkill rating initialization avoids uninitialized entries

Pre‑initializing self.openskill_ratings[uid] for all window_uids before calling self.openskill_model.rate prevents missing-rating edge cases and keeps the rating update path consistent for new peers.


1388-1419: Peer/commitment broadcast for distributed ranks is appropriate

Refreshing self.comms.commitments on master and broadcasting [peers, reserve_peers, commitments] via dist.broadcast_object_list ensures all ranks share a consistent view of gather participants and buckets, which is crucial for distributed gather correctness.


1462-1499: Timestamp centralization and broadcast fixes prior rank divergence

Computing ts_value only on master with fallback logging, then broadcasting it as a tensor so all ranks derive the same time_min/time_max window, resolves the earlier inconsistency risk between ranks during time‑windowed gathers.


1542-1578: Distributed gather mode selection and logging are well‑structured (verify sequential path semantics)

The use_distributed_gather predicate, gather‑mode logging, and the if self.is_master or use_distributed_gather condition cleanly separate sequential (master‑only) and distributed (all ranks) execution, while return_partials=use_distributed_gather ties neatly into incremental processing.

Since sequential mode now runs gather_with_reserve only on the master when world_size > 1 and distributed_gather is disabled, please double‑check that tplr.neurons.outer_step is designed to be called on non‑master ranks with gather_result=None (i.e., it must not dereference gather_result when is_master is False).

Also applies to: 1600-1636


1682-1725: Master‑only gather post‑processing and metrics reuse helpers correctly

The master‑only block now avoids redundant merges by using the partial‑aware overlap check directly, uploads a single merged result, and offloads gather results to CPU when applicable. Reusing _get_uids_from_gather_result, _get_skipped_uids_from_gather_result, and _get_success_rate_from_gather_result for overlap logging, gather‑quality metrics, and later slashing keeps the handling of partial vs merged results consistent.

Also applies to: 1702-1706, 1723-1725, 1748-1759


2559-2564: outer_step incremental args: confirm signature and multi‑rank behaviour

Passing memory_budget_mb=self.hparams.gradient_memory_budget_mb and batch_size_override=self.hparams.incremental_batch_size into tplr.neurons.outer_step is a clean hook for incremental, memory‑bounded processing, and the getattr(..., None) guards are safe.

Please verify at the call site that:

  • The outer_step signature actually accepts memory_budget_mb and batch_size_override keyword arguments.
  • Its implementation behaves correctly when these are None and when running in distributed vs single‑rank modes.

3283-3299: Using get_with_retry for miner debug dicts improves robustness

Swapping to self.comms.get_with_retry with a 60‑second timeout, and treating both None and success == False as failure, gives a more resilient path for fetching miner debug dictionaries while cleanly falling back to a neutral sync score on errors. The getattr(debug_result, "error", "unknown") access is safe even when debug_result is None.

Comment on lines -639 to +642
-                lookup = (
-                    lookup.to(val.device) if isinstance(lookup, torch.Tensor) else lookup
-                )
+                if isinstance(lookup, torch.Tensor):
+                    lookup = lookup.to(val.device)
+                if isinstance(shift, torch.Tensor):
+                    shift = shift.to(val.device)
Collaborator

What is the reason for these changes?

Contributor Author

I was getting device mismatch errors here. The old code only moved lookup to val.device, but not shift. This ensures both lookup and shift are moved to the same device as val before the computation.

Collaborator

There are many more paths on the outer step right now. Why do we need that? I think this complicates the code. I think it is better to keep outer step untouched for this one since the distributed gather would make master rank have the full gather results exactly like before. Can you explain a bit your changes?

Contributor Author

Yes, I think I fixed this now. The initial implementation had multiple code paths (outer_step_incremental for partials, _outer_step_single for merged results), but I have simplified it to a single path, exactly like before.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
neurons/trainer.py (1)

872-872: Critical: Duplicate gradient clipping—remove line 872.

Gradients are being clipped twice: once at line 872 (before unscaling) and again at lines 887-889 (after unscaling). When using GradScaler for mixed precision training, gradients are scaled during the backward pass and must be unscaled before clipping. The clip at line 872 operates on scaled gradients, which is incorrect.

Apply this diff to remove the duplicate clip:

                        log_loss = dist_helper.ddp_reduce(
                            loss_item, op=ReduceOp.AVG, device=self.device
                        )
-                       torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

                        if not null_round:

Also applies to: 887-889

neurons/miner.py (2)

470-487: Broadcast missing reserve_peers on miner ranks—inconsistent gather behavior.

The verification confirms the issue. In neurons/validator.py (lines 1370–1378), the broadcast correctly includes reserve_peers:

payload = (
    [self.comms.peers, self.comms.reserve_peers, self.comms.commitments]
    if self.is_master
    else [None, None, None]
)
dist.broadcast_object_list(payload, src=0)
self.comms.peers = payload[0]
self.comms.reserve_peers = payload[1]
self.comms.commitments = payload[2]

However, in neurons/miner.py (lines 470–487), only peers and commitments are broadcast. Since neurons/miner.py line 762 calls gather_with_reserve with reserve_uids=self.comms.reserve_peers, non-master ranks will use a stale or default [] value, diverging from the master's reserves. Apply the suggested fix to broadcast reserve_peers alongside peers and commitments:

-            # Broadcast peer list (and optionally commitments) from master to all ranks for distributed gather
+            # Broadcast peer list, reserve peers and commitments from master to all ranks for distributed gather
             if dist_helper.world_size > 1 and dist_helper.is_distributed():
-                payload = (
-                    [self.comms.peers, self.comms.commitments]
-                    if self.is_master
-                    else [None, None]
-                )
-                dist.broadcast_object_list(payload, src=0)
-                self.comms.peers, self.comms.commitments = payload
+                payload = (
+                    [self.comms.peers, self.comms.reserve_peers, self.comms.commitments]
+                    if self.is_master
+                    else [None, None, None]
+                )
+                dist.broadcast_object_list(payload, src=0)
+                self.comms.peers, self.comms.reserve_peers, self.comms.commitments = payload

380-385: Start window must be sourced from chain via get_start_window()—not local current_window.

The mining start window must come from the chain/R2 (via get_start_window()) because validators publish a global start_window to the chain. Miners must use this global value to correctly calculate global_step = current_window - start_window in sync with validators. Using a local current_window risks desynchronization and incorrect global_step calculation.

Re-enable the get_start_window() call at neurons/miner.py line 381, or gate it behind a config flag if this change was intentional for testing only.

src/tplr/neurons.py (1)

1447-1488: check_uid_index_overlap: partial‑aware logic and UID mapping look correct

The updated overlap checker:

  • Accepts either a merged SimpleNamespace or list[SimpleNamespace | None] and builds a flat uids list plus a uid_to_partial map for O(1) lookup, avoiding the previous O(n×m) search across partials.
  • For each parameter, unpacks 12‑bit indices per UID using that map, keeping track of which global UID indices are valid for the parameter in valid_uids_for_param.
  • Computes pairwise overlaps in the packed index domain, and records pair statistics keyed by global UID indices, later mapped back to uids[i], uids[j] for logging and slashing decisions.

The use of valid_uids_for_param[i] / [j] ensures that pairs are always reported in terms of the original global UID ordering, even when some peers lack a given parameter. This matches the intended semantics and should behave correctly for both merged and distributed‑partials inputs.

Also applies to: 1523-1561, 1580-1595

♻️ Duplicate comments (2)
neurons/validator.py (1)

3821-3919: compute_peer_val_norms: partials path can raise UnboundLocalError and is misleadingly structured.

When gather_result is a list[SimpleNamespace], the function fills clip_norm_dict inside the if isinstance(gather_result, list) block but then falls through to the merged‑result section and the final loop, where state_dict is undefined. If this path is ever exercised, it will raise UnboundLocalError at Line 3906. The structure also suggests both paths run, which is confusing.

Minimal fix: return immediately after computing norms for partials so the merged branch and final loop are skipped in that case:

@@
-        # Handle both merged results and partials
-        if isinstance(gather_result, list):
+        # Handle both merged results and partials
+        if isinstance(gather_result, list):
@@
-                norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
-                clip_norm_dict[vals_key] = torch.median(norms)
-        else:
+                norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
+                clip_norm_dict[vals_key] = torch.median(norms)
+            # For partials, we've fully populated clip_norm_dict and can return early
+            return clip_norm_dict
+        else:
             # Merged result: original logic
             state_dict = gather_result.state_dict
             if not state_dict:
                 raise ValueError("Must have gather_result.state_dict to compute norms")
@@
-        for n, p in self.model.named_parameters():
+        for n, p in self.model.named_parameters():
             vals_key = n + "vals"
             quant_key = n + "quant_params"

You may also want to re‑introduce the stricter vals/quant_params alignment checks you previously had here to avoid IndexError in maybe_dequantize_values when partials have inconsistent quantization layouts.

src/tplr/comms.py (1)

1669-1671: Early cutoff on NOT_FOUND still breaks timeout semantics (duplicate of prior concern)

The consecutive_not_found logic still gives up after 5 NOT_FOUNDs (~2.5 s at 0.5 s sleep) regardless of the caller’s timeout, and get() currently collapses timeouts, network glitches, and real 404s into status="NOT_FOUND". That means slow or flaky peers can be treated as “likely doesn’t exist” well before timeout expires, contradicting the contract of get_with_retry. This is the same issue raised in the earlier review; it remains unresolved in the current code. Consider either:

  • Distinguishing true 404s vs transient/network errors in get() and only incrementing consecutive_not_found for confirmed 404s, or
  • Deriving max_consecutive_not_found from timeout (or dropping the early cutoff entirely) so that transient errors are retried for the full timeout budget.

Also applies to: 1725-1737
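The second suggested remedy — deriving the cutoff from the caller's timeout rather than hard-coding 5 — would look roughly like this sketch (`max_consecutive_not_found` is a hypothetical helper; the 0.5 s poll interval matches the sleep mentioned above):

```python
def max_consecutive_not_found(timeout_s: float, sleep_s: float = 0.5,
                              floor: int = 5) -> int:
    """Derive the NOT_FOUND give-up threshold from the caller's timeout
    so transient errors keep being retried for the full budget."""
    return max(floor, int(timeout_s / sleep_s))


assert max_consecutive_not_found(60.0) == 120   # ~full timeout at 0.5 s/poll
assert max_consecutive_not_found(1.0) == 5      # never below the floor
```

This preserves the fast-exit behavior for genuinely short timeouts while honoring `get_with_retry`'s contract for longer ones; the cleaner fix remains distinguishing true 404s in `get()`.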

🧹 Nitpick comments (13)
neurons/validator.py (1)

90-108: Helpers currently assume merged gather results only (naming/docs still reference partials).

_get_uids_from_gather_result, _get_skipped_uids_from_gather_result, and _get_success_rate_from_gather_result operate only on merged SimpleNamespace results, while the surrounding comment mentions “partials vs merged gather results”. That’s fine for current call sites (they always pass merged results), but either updating the comment/docstrings to match, or extending these helpers to handle list[SimpleNamespace] defensively, would reduce the chance of misuse if partials are re‑introduced later.

neurons/miner.py (2)

32-35: CUDA allocator env var placement is fine but only effective when this is the entrypoint.

Setting PYTORCH_CUDA_ALLOC_CONF before importing torch in this module is correct. Just be aware this won’t affect runs where another module imports torch first; if that matters, consider setting it in the launcher or a shared bootstrap module.


1018-1027: End‑of‑window gather_result cleanup is safe but only clears SimpleNamespace state_dicts.

Clearing gather_result.state_dict.__dict__ and then deleting gather_result after all uses (outer_step, debug, metrics) is a good way to reduce peak memory before the next window. Note that this only affects SimpleNamespace state dicts; if gather_with_reserve ever returns a plain dict for state_dict, it will remain for GC to handle, which is acceptable but slightly less aggressive.

src/tplr/comms.py (5)

20-21: Partition cache for UID splits looks sound; minor reuse opportunity

The _partition_cache keyed by (window, sorted_uids, world_size) with LRU eviction gives deterministic, bounded partitioning across calls, which is appropriate here. Since partition_uids_across_ranks() implements the same round‑robin logic, consider reusing that helper inside the cache‑fill branch to avoid duplicating the slicing scheme in two places.

Also applies to: 29-29, 142-144
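The round-robin scheme duplicated in both places reduces to strided slicing over the sorted UIDs; this sketch uses `functools.lru_cache` as a stand-in for the `(window, sorted_uids, world_size)` cache (the window component is omitted for brevity, and the function name is illustrative):

```python
from functools import lru_cache


@lru_cache(maxsize=16)
def partition_uids(uids: tuple, world_size: int) -> tuple:
    """Deterministic round-robin split of sorted UIDs across ranks.
    Hashable tuple inputs let lru_cache memoize repeated calls."""
    ordered = sorted(uids)
    return tuple(tuple(ordered[r::world_size]) for r in range(world_size))


parts = partition_uids((9, 1, 4, 7, 2), world_size=2)
assert parts == ((1, 4, 9), (2, 7))                       # rank 0, rank 1
assert sorted(u for p in parts for u in p) == [1, 2, 4, 7, 9]  # no UID lost
```

Sorting before striding is what makes the split identical on every rank regardless of input order — the property the distributed gather relies on.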


483-582: s3_object_exists semantics are conservative and aligned with slashing requirements

The new s3_object_exists implementation cleanly separates 404/NoSuchKey, permission/config errors, and transient network/service issues using structured ClientError.response fields, and errs on the side of “exists” for transient failures, which is consistent with avoiding false slashing. One minor follow‑up you could consider is purging the S3 client (_purge_s3_client) on persistent config errors (e.g., invalid bucket) to avoid reusing a known‑bad client, though this is an optimization rather than a correctness fix.


666-669: Verify torch.load(weights_only=True) behavior and map_location choices

Offloading torch.load to asyncio.to_thread is a good fix for event‑loop blocking, and the higher 500 MB threshold for chunked downloads is reasonable. Two things to double‑check:

  • weights_only=True is only available in newer PyTorch versions and may behave differently for arbitrary dict payloads; please confirm the project’s minimum torch version and that these gradient/aggregator payloads still round‑trip as expected.
  • For remote loads you honor map_location (or self.config.device), but for local loads you now use the default device. If local gradient/aggregator files can be large and validators often run with GPUs attached, you may want to thread a map_location through get(..., local=True, ...) as well to avoid surprises/OOMs on unintended devices.

Also applies to: 709-723


2224-2356: Distributed gather core flow looks correct; note unused target_device and list element Nullability

The distributed gather path (per‑rank gather → optional local merge → all_gather_object → global merge or partials) is coherent and matches the intended API:

  • Partitioning via cached round‑robin slices is deterministic across ranks.
  • all_partials normalization to None for empty uids avoids downstream surprises.
  • When return_partials=True, preserving the per‑rank structure (with optional partial.rank) is useful for incremental consumers.

Two minor nits:

  • merge_gather_results computes target_device but never uses it when building the merged state; either wire this into tensor placement or drop the parameter to avoid confusion.
  • The return type for gather_distributed allows list[SimpleNamespace | None], but callers like _outer_step_incremental treat the list as list[SimpleNamespace] and filter out None themselves. That’s safe, but consider tightening the type hints or documenting that None entries are expected placeholders.

2396-2521: gather_with_reserve: distributed fallback and type normalization mostly consistent; confirm call‑site expectations

The refactored gather_with_reserve cleanly separates:

  • Primary distributed gather (with gather_distributed) when world_size>1, including detection of an “all empty” result across ranks.
  • Shared fallback to sequential gather on all ranks when distributed gather fails, with wrapping into a list when return_partials=True to keep shapes consistent.
  • Sequential path for non‑distributed runs, again normalizing to [SimpleNamespace] when returning partials.

Two points to verify:

  1. Cross‑rank consistency on fallback – Each rank independently calls gather(...) in the distributed‑fallback case. That’s fine if gather_with_reserve is only ever invoked on the master rank (non‑masters pass gather_result=None into outer_step and rely on broadcast), but would yield per‑rank divergence if non‑master ranks also consume primary. Please confirm master‑only usage at call sites or add a broadcast step after sequential fallback if other ranks might rely on the value.

  2. Partials semantics – When falling back to sequential in distributed mode with return_partials=True, you now return a single‑element list rather than one entry per rank. That’s acceptable for consumers that only treat “partials” as arbitrary chunks of UIDs (as _outer_step_incremental does), but it no longer encodes rank topology; clarifying this in the docstring would help future readers.

src/tplr/neurons.py (5)

22-22: _resolve_cuda_device helper is reasonable but currently unused

The _resolve_cuda_device helper handles various string/device/None inputs and CUDA initialization edge‑cases sensibly. However, its result (cuda_device) is computed in _outer_step_single and then never used, and all CUDA API calls now operate on the default device only. Either wire cuda_device into the few CUDA calls that still need an explicit device (if any), or remove the helper and local variable to keep the code simpler.

Also applies to: 48-69


250-313: outer_step increment/merge flow: dispatcher is correct; note behavior vs “incremental” goal and inner import

The new outer_step dispatcher cleanly routes:

  • gather_result: list[...] → _outer_step_incremental
  • gather_result: SimpleNamespace | None → _outer_step_single

and keeps the public API uniform across call sites like catchup_with_aggregation_server, which is good.

Within _outer_step_incremental:

  • None partials are ignored via the valid_partials filter, so it can safely consume the lists returned from gather_with_reserve(return_partials=True).
  • Partials are sorted deterministically by partial.rank (or metadata) before merging, and then merged once via _merge_partial_batch → Comms.merge_gather_results, after which _outer_step_single is executed once. This “merge then apply once” design matches catch‑up semantics and avoids per‑partial re‑application.

Two small points:

  • The function name and docstring still describe this as “incremental” processing, but the implementation now merges all partials in one go before applying them, so peak memory will look similar to a single merged gather. If the original intent was memory‑bounded, batched merging, consider either adjusting the description or re‑introducing batched processing later.
  • _merge_partial_batch performs an in‑function from tplr.comms import Comms import. For consistency with the rest of the module (and to avoid repeated imports under heavy training loops), it would be cleaner to move that import to module scope.

Also applies to: 341-459


461-477: _outer_step_single: core logic preserved; minor dead code

The refactored _outer_step_single preserves the original behavior:

  • Only the “source rank” (is_master or non‑DDP) consumes gather_result and builds src_sd; other ranks rely on broadcast or distribute_tensor.
  • Per‑parameter decompression, decode, and immediate optimizer.step() with aggressive freeing remain intact.
  • Fingerprint accumulation and optional wandb logging are still gated on the source rank.

The only notable leftover is cuda_device = _resolve_cuda_device(device), which is no longer used in any CUDA calls (they now all use default‑device operations). This can be safely dropped to avoid confusion.

Also applies to: 493-515


1393-1407: compare_model_with_debug_dict: step‑ratio denominator change improves stability

Switching the step‑ratio denominator to max(param_avg_change, learning_rate) prevents extremely small param_avg_change values from inflating “steps behind” metrics when the model has mostly converged. The fallback to a constant learning_rate when shapes mismatch or no EMA is available is a sensible conservative choice. Just note that the min_step_size argument is now effectively unused; if you don’t expect to bring it back, consider removing it from the signature to avoid misleading callers.


1447-1488: check_uid_index_overlap: partial‑aware logic and UID mapping look correct

The updated overlap checker:

  • Accepts either a merged SimpleNamespace or list[SimpleNamespace | None] and builds a flat uids list plus a uid_to_partial map for O(1) lookup, avoiding the previous O(n×m) scan across partials.
  • For each parameter, unpacks 12‑bit indices per UID using that map, keeping track of which global UID indices are valid for that parameter in valid_uids_for_param.
  • Computes pairwise overlaps on these unpacked indices and stores stats keyed by global UID indices, later mapped back to uids[i], uids[j] for logging and slashing.

This matches the intended semantics for both merged and distributed‑partials inputs and keeps complexity under control.

Also applies to: 1523-1561, 1580-1595

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7526353 and 22a4606.

📒 Files selected for processing (5)
  • neurons/miner.py (10 hunks)
  • neurons/trainer.py (3 hunks)
  • neurons/validator.py (14 hunks)
  • src/tplr/comms.py (19 hunks)
  • src/tplr/neurons.py (13 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
neurons/miner.py (3)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/distributed.py (1)
  • is_distributed (75-77)
src/tplr/comms.py (1)
  • gather_with_reserve (2396-2672)
src/tplr/neurons.py (2)
src/tplr/compress.py (5)
  • compress (373-377)
  • compress (379-383)
  • compress (386-423)
  • TopKCompressor (295-697)
  • unpack_12bit_indices (102-142)
src/tplr/comms.py (1)
  • merge_gather_results (2114-2222)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
  • Bucket (24-46)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • all_gather_object (307-325)
neurons/validator.py (6)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • broadcast (273-282)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/comms.py (2)
  • gather_with_reserve (2396-2672)
  • get_with_retry (1625-1740)
src/tplr/neurons.py (1)
  • check_uid_index_overlap (1447-1652)
src/tplr/compress.py (1)
  • maybe_dequantize_values (647-697)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.11)
  • GitHub Check: test (3.12)
🔇 Additional comments (9)
neurons/trainer.py (2)

911-911: LGTM: Correct handling of grad_norm during spin-up.

Setting total_grad_norm to 0.0 during null rounds (spin-up phase) is appropriate since no optimizer step is taken.


928-928: LGTM: Useful logging enhancement.

Adding gradient norm to the inner step logs provides valuable insight into training dynamics.

neurons/validator.py (2)

1355-1387: Distributed gather orchestration and timestamp/peer synchronization look consistent.

Broadcasting peers, reserve_peers, and commitments; using a single timestamp window across ranks; and gating gather_with_reserve via use_distributed_gather with master‑driven skip decisions keep all ranks aligned and avoid rank divergence or deadlocks in the new distributed gather path. The logging/use of the new _get_*_from_gather_result helpers on the merged result is also coherent with return_partials=False.

Also applies to: 1429-1472, 1509-1592, 1638-1676, 1698-1710


3230-3247: Switch to get_with_retry for miner debug dict improves robustness.

Using self.comms.get_with_retry with a bounded timeout and handling the None case before dereferencing the result prevents hangs and attribute errors when miners don’t publish debug data in time. The warning log with a safe getattr(debug_result, "error", "unknown") is also defensive.

neurons/miner.py (4)

553-560: Pre‑training CUDA cleanup and gradient reset are aggressive but appropriate for OOM mitigation.

Calling torch.cuda.empty_cache(), torch.cuda.synchronize(), and model.zero_grad(set_to_none=True) before inner_steps is a reasonable trade‑off to reduce fragmentation/VRAM pressure between windows, given this code assumes CUDA is available.


832-835: Post‑outer‑step gradient clearing is good hygiene.

Zeroing gradients with set_to_none=True after a successful outer step avoids gradient accumulation and can help allocator reuse; this matches how you clear grads before training and in cleanup_window().


885-895: Gather success and skipped peer metrics wiring is consistent.

Using the merged gather_result’s success_rate and skipped_uids for:

  • Debug fields (successful_peers / skipped_peers), and
  • Per‑window metrics (miner/gather/success_rate and skipped_peers JSON in Influx),

keeps miner metrics aligned with validator‑side gather reporting. This assumes gather_result is always merged (which you enforce via return_partials=False).

Also applies to: 930-937, 999-1002


739-776: Review comment is accurate; no changes required

Verification confirms all claims in the review comment are correct:

  1. gather_with_reserve return type: When called with return_partials=False (line 764), the function returns a merged SimpleNamespace, not a list. The implementation respects this parameter throughout all code paths (comms.py lines 2508, 2520, 2535, 2672).

  2. Usage consistency: The code at lines 771–797 correctly accesses gather_result.uids, gather_result.skipped_uids, and gather_result.success_rate—all attributes of a SimpleNamespace. Line 826 passes gather_result to self.outer_step() as expected for a merged result object.

  3. Distributed gather coherence: The conditional logic at line 759 correctly ensures all ranks participate when use_distributed_gather=True, while sequential mode remains master-only.

  4. Synchronization: The should_update reduction via dist_helper.all_ok() at lines 806–809 properly keeps all ranks synchronized on the decision to apply an outer step.

src/tplr/comms.py (1)

1472-1493: gradient_timestamp default version handling is straightforward

Allowing an explicit version while defaulting to tplr.__version__ when None makes this helper usable for both current and historical gradient versions without duplicating call sites. The error path still safely purges the S3 client on failure and returns 0.0, which matches existing semantics.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/tplr/comms.py (1)

1815-1875: Major semantic change in gather() return value breaks multiple call sites relying on None checks

The verification confirms the review comment is accurate. The new gather() implementation always returns a SimpleNamespace object, even when no valid gradients are received—it never returns None. This breaks control flow in at least two critical locations:

  1. neurons.py:1189-1199 – The catchup fallback decides whether to skip the window using if gather_ns is None. With the new behavior, this condition will always be false, and windows with zero valid gradients will proceed through outer_step (advancing schedulers and global_step) instead of being skipped.

  2. comms.py:2645 – The check if primary is None: in gather_with_reserve() is now unreachable, breaking the fallback logic for empty results.

  3. comms.py:2391, 2621, 2633, 2710 – Other gather() call sites have similar None checks or truthiness checks that will now behave unexpectedly.

Action required:

  • Either return None from gather() when valid_uids is empty and not in a distributed-merge scenario, or
  • Update all gather() call sites to check len(result.uids) instead of result is None, ensuring callers properly handle empty-but-non-None results.

At minimum, audit all listed call sites to confirm whether the new "silent no-op window" behavior is intentional.
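The second remedy, checking the UID list instead of None, could be centralized in a small guard (helper name is hypothetical):

```python
from types import SimpleNamespace

def has_valid_gradients(result) -> bool:
    # Works under both contracts: a None result and an
    # empty-but-non-None merged gather result are treated the same.
    return result is not None and len(getattr(result, "uids", [])) > 0
```

Call sites would then replace `if gather_ns is None:` with `if not has_valid_gradients(gather_ns):`, which stays correct regardless of which contract gather() settles on.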

♻️ Duplicate comments (1)
src/tplr/comms.py (1)

1665-1671: Early NOT_FOUND cutoff still conflates 404 with timeouts/network errors.

get_with_retry now stops after max_consecutive_not_found = 5, but get() still returns status="NOT_FOUND" for all s3_get_object None returns, including timeouts and transient client/network errors. This preserves the earlier problem: slow/flaky peers can be treated as “likely doesn’t exist” and given up on after ~2.5s (5 × 0.5s), independent of the caller’s timeout.

To align semantics with the timeout parameter and the new structured S3 helpers, consider:

  • Distinguishing true 404 “not found” from transient errors in get() (e.g., propagate a status enum or error type instead of collapsing to NOT_FOUND).
  • Only incrementing consecutive_not_found for confirmed 404s, not for timeouts or generic errors.
  • Optionally deriving max_consecutive_not_found from timeout (e.g., a fraction of it) rather than a fixed 5.

This keeps the fast‑fail optimization for non-existent objects without prematurely aborting legitimate slow uploads.

Also applies to: 1725-1737
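The last suggestion, deriving the NOT_FOUND cutoff from the caller's timeout, could be sketched as follows (the 0.5s retry delay comes from the comment above; the function name and the fraction are assumptions):

```python
def max_not_found_for(timeout: float, retry_delay: float = 0.5,
                      fraction: float = 0.5) -> int:
    # Allow consecutive NOT_FOUND responses for a fraction of the caller's
    # overall timeout instead of a fixed count of 5.
    return max(1, int((timeout * fraction) / retry_delay))
```

With a 5s timeout this yields 5 attempts (matching today's behavior), while a 30s timeout would tolerate 30 consecutive misses before giving up.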

🧹 Nitpick comments (5)
src/tplr/comms.py (3)

483-582: s3_object_exists error handling looks solid; consider also purging bad clients.

The structured handling of 404/403/5xx vs transient errors is good and matches the “don’t slash on network issues” goal. One improvement: for ConnectionClosedError and ClientError cases where you decide to treat the error as transient (5xx or generic network issues), you currently don’t call _purge_s3_client, unlike other S3 helpers in this class. That means a poisoned client (closed connection, stale DNS, etc.) may be reused repeatedly.

It would be safer to purge and recreate the client on these non-404/-403 branches as well, similar to s3_get_object and download_large_file.


2096-2327: Distributed helpers and merge_gather_results behaviour are mostly sound; minor cleanups possible.

The helpers encapsulate the distributed flow well:

  • partition_uids_across_ranks uses deterministic sort + round‑robin; good for reproducibility.
  • _move_gather_result_to_device and _convert_result_dtype correctly recurse through nested containers and avoid breaking 12‑bit idxs.
  • merge_gather_results correctly concatenates per‑param lists, aggregates metrics, and returns None when no UIDs succeeded.

Two minor points:

  1. merge_gather_results computes a target_device but never actually uses it to move tensors; callers instead explicitly move results via _move_gather_result_to_device. Either:

    • Remove target_device from merge_gather_results (and the discovery logic) to avoid confusion, or
    • Apply it before returning, e.g. by calling _move_gather_result_to_device(merged, target_device).
  2. When called with a list of partials where all have empty uids, merge_gather_results will log and return None (good), but the earlier normalisation in gather_distributed (treating empty‑uids partials as None) is now redundant. You could rely on one of these mechanisms to simplify the code.

Neither issue is correctness‑critical, just opportunities to tighten the implementation.
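The deterministic sort + round-robin scheme described above can be sketched as a standalone stand-in for partition_uids_across_ranks (the real method also caches per window):

```python
def partition_uids(uids, world_size):
    # Deterministic: sort first so every rank computes identical partitions,
    # then deal UIDs round-robin so load is balanced across ranks.
    ordered = sorted(uids)
    return [ordered[rank::world_size] for rank in range(world_size)]
```

Every rank can call this independently with the same inputs and pick its own slice by rank index, which is what makes the cached-per-window optimization safe.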


2522-2792: gather_with_reserve distributed fallback and partial handling look consistent; be aware sequential fallback can still diverge cross‑rank.

The refactored gather_with_reserve nicely integrates distributed gather:

  • In distributed mode all ranks participate in gather_distributed, even if they have no UIDs, avoiding deadlocks at the all‑gather barrier.
  • The return_partials flag is propagated through to both the distributed path and the sequential fallback, and you now wrap sequential results as a single‑element list to keep type contracts consistent.
  • The final success‑rate recomputation for both merged and partials paths keeps metrics aligned with the original primary tier.

One caveat: when distributed gather fails and you fall back to sequential, each rank independently runs self.gather(...) with the full gather_uids list. That preserves type contracts but can still result in different primary results per rank under asymmetric S3/network conditions (timeouts, transient errors), because there’s no subsequent broadcast or re‑synchronization of the sequential result.

If strict cross‑rank consistency is required even in this fallback path, consider:

  • Having only the master rank run the sequential gather, then broadcasting the resulting merged object (or partial list) via dist_helper.broadcast_object, or
  • Ensuring all ranks perform the sequential gather but then choose a canonical one (e.g., rank 0’s) to apply, again via broadcast.

If your downstream logic is already robust to per‑rank variation in the “all ranks fallback” case, the current approach is acceptable; otherwise, a broadcast step would harden guarantees.
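The broadcast-a-canonical-result option could look like the sketch below. broadcast_object stands in for dist_helper.broadcast_object / dist.broadcast_object_list and is injected so the sketch is testable without a process group:

```python
def pick_canonical(result, rank: int, broadcast_object):
    # All ranks call this after the sequential fallback; rank 0's result wins,
    # so downstream outer_step sees identical gradients on every rank.
    payload = [result if rank == 0 else None]
    broadcast_object(payload, src=0)  # fills payload[0] from rank 0 in-place
    return payload[0]
```

In the real code path this would run immediately after each rank's sequential self.gather(...), before the should_update reduction.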

src/tplr/neurons.py (2)

48-70: outer_step dispatch and device resolution are fine, but type hints and unused cuda_device could be tidied.

The new outer_step dispatcher cleanly separates the “list of partials” case (incremental path) from the single SimpleNamespace case, and _outer_step_incremental correctly funnels back into the original single‑result implementation. That’s a good way to keep semantics aligned with catch‑up.

Two small nits:

  1. The signature gather_result: SimpleNamespace | list[SimpleNamespace] | None doesn’t reflect that partial lists can (and do) contain None entries (from gather_with_reserve and gather_distributed). _outer_step_incremental handles None safely by filtering them out, but the type hint is misleading. Consider updating it to SimpleNamespace | list[SimpleNamespace | None] | None.

  2. _resolve_cuda_device is invoked in _outer_step_single, but the returned cuda_device isn’t used. Unless you plan to plug it into CUDA‑specific calls soon, this is dead code and could be removed, or the helper’s result should be wired into the operations that actually require a torch.device instead of a string.

These are minor maintainability issues; behaviour looks correct.

Also applies to: 250-314, 461-497


1029-1101: Catch‑up debug comparison and memory logging are valuable but add overhead; ensure they’re acceptable for your deployment.

The catch‑up loop now:

  • Logs detailed CUDA memory stats before, during, and after replay.
  • Optionally verifies the initial checkpoint against the leader’s debug dict.
  • For each window, fetches and compares debug dicts, and tracks an EMA of small parameter slices to estimate “steps behind”.

Functionally this is all sound and gated on instance.is_master, but it does add:

  • Extra S3 GETs for debug objects at each window.
  • Per‑window per‑parameter CPU work (slicing, EMAs, diff stats).
  • Occasional compare_model_with_debug_dict calls.

If you run very long catch‑up sequences or have large models, this may noticeably extend catch‑up time. If that’s a concern, consider adding a simple flag/hparam to enable/disable the per‑window debug‑dict comparison (keeping the memory logging, which is cheaper) so operators can turn it off in production if needed.

Also applies to: 1217-1271, 1337-1356

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 22a4606 and 698402d.

📒 Files selected for processing (3)
  • neurons/miner.py (9 hunks)
  • src/tplr/comms.py (22 hunks)
  • src/tplr/neurons.py (14 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
neurons/miner.py (3)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/distributed.py (1)
  • is_distributed (75-77)
src/tplr/comms.py (1)
  • gather_with_reserve (2516-2792)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
  • Bucket (24-46)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • all_gather_object (307-325)
src/tplr/neurons.py (2)
src/tplr/compress.py (5)
  • compress (373-377)
  • compress (379-383)
  • compress (386-423)
  • maybe_dequantize_values (647-697)
  • unpack_12bit_indices (102-142)
src/tplr/comms.py (1)
  • merge_gather_results (2208-2326)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.12)
  • GitHub Check: test (3.11)
🔇 Additional comments (15)
src/tplr/comms.py (1)

666-677: Offloading torch.load and raising the size threshold is a good fix.

Moving torch.load into asyncio.to_thread and increasing the “simple download” threshold to 500MB is a solid improvement: it prevents the event loop from being blocked on CPU decompression while still using the chunked downloader for truly large objects. The error logging around failed torch.load calls is also appropriate and keeps failures non-fatal for the calling code.

Also applies to: 709-723

src/tplr/neurons.py (4)

316-459: Incremental outer step merging via Comms.merge_gather_results is consistent and robust.

The incremental path does the right thing:

  • Filters out None partials and sorts by partial.rank (or embedded metadata) for deterministic ordering.
  • Uses Comms.merge_gather_results to merge all partials once, then applies a single outer step via _outer_step_single, matching catch‑up semantics.
  • Wraps both merge and apply in try/except blocks with clear logging and aggressive gc.collect()/torch.cuda.empty_cache() cleanup.

Given that merge_gather_results already computes aggregate metrics and handles the “no UIDs at all” case by returning None, this reuse avoids duplicating intricate merge logic in neurons.py.

Overall this is a good design choice and looks correct as long as all partials are Comms‑style gather results (which they are in the current flow).

Also applies to: 579-587


1413-1425: compare_model_with_debug_dict now uses param_avg_change more meaningfully.

The revised logic that, when param_avg_change is available, uses it as the denominator (with a min_step_size floor) instead of the raw learning rate gives a much more realistic estimate of “steps behind”:

  • It respects per‑parameter observed change magnitudes.
  • It still falls back safely to learning_rate when stored slices are mismatched or unavailable.

This should make the reported avg_steps_behind metric significantly more informative during catch‑up and debugging.
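In isolation, the described heuristic might look like the scalar sketch below (the real compare_model_with_debug_dict operates on tensor slices, not scalars; names are illustrative):

```python
def estimate_steps_behind(param_diff: float,
                          param_avg_change,
                          learning_rate: float,
                          min_step_size: float = 1e-9) -> float:
    # Prefer the observed per-parameter change magnitude as the denominator,
    # floored at min_step_size; fall back to the learning rate when the
    # stored slice is unavailable or mismatched.
    if param_avg_change is not None:
        denom = max(param_avg_change, min_step_size)
    else:
        denom = learning_rate
    return abs(param_diff) / denom
```

Using the observed change magnitude rather than the raw learning rate accounts for momentum, weight decay, and per-parameter scaling that the optimizer applies on top of the lr.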


1468-1510: check_uid_index_overlap: partial‑aware path looks correct; index bookkeeping is subtle but consistent.

The extension to support both merged and partial gather results is well thought out:

  • For partials, you build a global uids list and a uid_to_partial map (partial, index_in_partial), avoiding the previous O(n×m) scan across partials.
  • For merged results, you retain the original logic by sharing one state_dict across all UIDs.

Inside the per‑parameter loop:

  • You correctly skip UIDs that lack idxs/vals entries, and only append to unpacked_indices plus valid_uids_for_param for those that do.
  • You then work in the compact P‑sized space (valid peers for that param), but store pair‑accumulation keys in terms of the original UID indices via uid_idx_i = valid_uids_for_param[i] / uid_idx_j = valid_uids_for_param[j].
  • Later, when summarizing, you map back through uids[i] / uids[j] for human‑readable pairs, which stays consistent with the keys you used in pair_acc.

Overall, the indexing is non‑trivial but appears internally consistent and should produce correct overlap and slashing decisions for both merged and partial inputs.

Also applies to: 1542-1593, 1601-1617


563-587: Device string support verified.

PyTorch methods that take a device argument accept properly formatted device strings. The codebase requires torch>=2.7.1, and tensor.to() accepts device strings including 'cpu', 'cuda', 'cuda:0', and 'cuda:1'. The guard and dequantization logic in the code is correctly designed for this API.

neurons/miner.py (10)

32-34: LGTM: CUDA memory allocator configuration.

Setting expandable_segments:True before importing torch is a good practice to reduce memory fragmentation in distributed training. The check for an existing environment variable ensures user configuration is not overwritten.
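The pattern being praised, reduced to its essentials (setdefault preserves any user-supplied value; the allocator reads the variable at torch import time):

```python
import os

# Must run before `import torch`: the CUDA caching allocator reads this
# environment variable once, at import time.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

# import torch  # imported afterwards in the real module
```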


42-42: LGTM: Distributed training import.

Required import for the distributed gather functionality implemented in this PR.


552-559: LGTM: Pre-training memory hygiene.

Aggressively clearing CUDA cache and lingering gradients before training is a sound practice for preventing OOM in multi-GPU setups. The defensive check for model existence is appropriate.


738-802: LGTM: Distributed gather implementation with proper fallback and logging.

The gather mode selection and logging are well-structured:

  • use_distributed_gather is correctly determined from world size and distributed status
  • return_partials=False ensures all ranks receive a merged result (not a list of partials), consistent with the design intent from past reviews
  • local=False correctly pulls gradients from R2 remote storage
  • Detailed logging distinguishes distributed vs sequential modes with appropriate metrics
  • All attribute accesses (.uids, .skipped_uids, .success_rate) are consistent with the merged result type
  • should_update is broadcast to ensure all ranks agree on whether to perform outer step

Based on past reviews and PR objectives.


831-833: LGTM: Post outer-step gradient clearing.

Clearing gradients after the outer step with set_to_none=True is appropriate for memory efficiency. The defensive check ensures safety.


886-893: LGTM: Safe UID extraction from merged gather result.

Uses getattr with defaults to safely extract UIDs and compute successful/skipped peer sets. The logic correctly derives successful peers as the set difference between all peers and skipped UIDs.


929-936: LGTM: Consistent success rate calculation.

Safely handles the None case explicitly and accesses success_rate directly from the merged result. The use of getattr for skipped_uids adds defensive robustness even though the merged result type guarantees the attribute exists.


1000-1000: LGTM: Metrics logging with skipped peers.

Correctly uses the safely extracted skipped_uids_list from the merged gather result. JSON encoding is appropriate for InfluxDB storage.


1017-1027: LGTM: Thorough cleanup of gather artifacts.

The cleanup correctly occurs after all uses of gather_result (outer step, debug dict, metrics logging). Clearing state_dict.__dict__ is aggressive but appropriate for reclaiming memory in distributed training, as the comment confirms this is the only remaining reference.


474-490: Broadcast reserve_peers along with peers and commitments.

The code correctly broadcasts peers and commitments to all ranks, addressing the past concern about stale bucket mappings. However, self.comms.reserve_peers is used on line 761 but is not included in the broadcast payload. Since update_peers only runs on the master rank, non-master ranks will have stale or uninitialized reserve_peers values.

Apply this diff to include reserve_peers in the broadcast:

             # Broadcast peer list (and optionally commitments) from master to all ranks for distributed gather
             if dist_helper.world_size > 1 and dist_helper.is_distributed():
                 payload = (
-                    [self.comms.peers, self.comms.commitments]
+                    [self.comms.peers, self.comms.commitments, self.comms.reserve_peers]
                     if self.is_master
-                    else [None, None]
+                    else [None, None, None]
                 )
                 dist.broadcast_object_list(payload, src=0)
-                self.comms.peers, self.comms.commitments = payload
+                self.comms.peers, self.comms.commitments, self.comms.reserve_peers = payload
 
                 if not self.is_master:
                     tplr.logger.info(
                         f"[Rank {dist_helper.rank}] Received peers from master: {self.comms.peers}"
                     )

Likely an incorrect or invalid review comment.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/tplr/comms.py (1)

1551-1567: Local torch.load path should honor map_location for symmetry and OOM avoidance.

The local branch ignores the map_location parameter while the remote S3 branch correctly uses it. This asymmetry means callers requesting map_location="cpu" to avoid GPU OOM will only get that behavior for remote loads, not local ones.

The remote implementation pattern (lines 705-718) already shows the correct approach:

device_location = (
    map_location if map_location is not None else self.config.device
)
loaded_data = await asyncio.to_thread(
    torch.load,
    temp_file_path,
    map_location=device_location,
    weights_only=True,
)

Apply the same fix to the local branch (around line 1564) for consistency and to honor the caller's explicit device mapping requests.

🧹 Nitpick comments (19)
neurons/miner.py (3)

552-560: Aggressive CUDA cleanup is effective but may be overly heavy on every window

Calling torch.cuda.empty_cache() + torch.cuda.synchronize() before every inner loop, plus a full model.zero_grad(set_to_none=True), is a strong guardrail against OOM but adds non‑trivial latency per window.

If you find throughput becomes an issue, consider gating this behind a hyperparameter (e.g. hparams.aggressive_cleanup) or a less frequent cadence (every N windows) while keeping the current behavior as the safe default.
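The suggested gating could be as simple as the sketch below (the hparams name aggressive_cleanup and the cadence are hypothetical):

```python
def should_run_heavy_cleanup(window_idx: int, aggressive: bool = True,
                             every_n: int = 4) -> bool:
    # Always run the empty_cache/synchronize/zero_grad block in aggressive
    # mode (the safe default); otherwise only every N windows.
    return aggressive or (window_idx % every_n == 0)
```

The miner would call this at the top of each window and skip the cleanup block when it returns False.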


884-892: Gather debug and metrics fields line up with merged gather_result

The debug and metrics paths:

  • Use gather_result.uids / gather_result.skipped_uids from the merged result.
  • Derive successful_peers as peers − skipped, which matches the comms contract where skipped_uids is defined relative to the original peer list.
  • Log gather_success_rate = gather_result.success_rate * 100 and JSON‑encode skipped_uids_list for Influx.

This is consistent with the merged‑result semantics enforced by return_partials=False and should give accurate visibility into gather quality.

Also applies to: 927-933, 997-997


1014-1023: End‑of‑window gather_result cleanup is safe and targeted

The cleanup block:

  • Deletes shard_gradient.
  • For a non‑None gather_result, clears gather_result.state_dict.__dict__ when it is a SimpleNamespace, then deletes gather_result.
  • Forces gc.collect() and torch.cuda.empty_cache().

Given that all uses of gather_result (outer_step, debug, metrics) are complete by this point and no background tasks hold references to it, this is a reasonable way to aggressively release memory between windows.

neurons/validator.py (3)

1491-1561: Distributed gather orchestration is correct; comment about partials is now misleading

The gather block now:

  • Forces use_distributed_gather = True so all ranks participate in gather_with_reserve, letting the comms layer decide between distributed vs sequential internally.
  • Calls gather_with_reserve with:
    • gather_uids=self.comms.peers
    • reserve_uids=self.comms.reserve_peers
    • return_partials=False
    • local=False, extended timeout=150, time_min/time_max, and expected_compressed_params.
  • Logs distributed gather completion using gather_result.uids, gather_result.skipped_uids, and gather_result.success_rate.

That wiring is consistent with the comms implementation and the later uses of gather_result (upload, overlap, slashing, outer_step).

The inline comment at Line 1517, though, still says “Request partials for incremental processing in outer_step” while return_partials=False ensures a merged result. To avoid confusion for future refactors (especially if partials are re‑enabled), consider updating or removing that comment:

-                    # Request partials for incremental processing in outer_step
                     gather_result = await self.comms.gather_with_reserve(
@@
-                        return_partials=False,  # Always return merged result (distributed gather merges internally)
+                        return_partials=False,  # Always return merged result; partials are handled inside comms when enabled

1823-1825: compute_peer_val_norms currently only supports merged results despite broader type hint

compute_peer_val_norms is now typed as:

def compute_peer_val_norms(
    self,
    gather_result: SimpleNamespace | list[SimpleNamespace],
) -> dict[str, torch.Tensor]:

but the implementation:

  • Assumes gather_result has a state_dict attribute.
  • Treats state_dict as an attribute container (getattr(state_dict, vals_key, None)), i.e., a SimpleNamespace or similar.
  • Does not branch on the list‑of‑partials case.

Given that this call site passes return_partials=False to gather_with_reserve, the function is currently only invoked with a merged SimpleNamespace, so behavior is correct today. The signature, docstring, and nearby comment (“handles both partials and merged results”) are misleading, though, and increase the risk of someone later enabling return_partials=True and getting a hard failure.

Two options:

  1. Narrow the contract to merged results only (simplest, current behavior):
-    def compute_peer_val_norms(
-        self,
-        gather_result: SimpleNamespace | list[SimpleNamespace],
-    ) -> dict[str, torch.Tensor]:
+    def compute_peer_val_norms(
+        self,
+        gather_result: SimpleNamespace,
+    ) -> dict[str, torch.Tensor]:
@@
-        """
-        Compute median L2 norms across all peers for each parameter.
-        ...
-        """
+        """Compute median L2 norms across peers for each parameter using a merged gather_result."""
  2. Or, if you want real partial support, add a branch that merges partials (or reuses self.comms.merge_gather_results) before the existing body.

Given the current usage, (1) is enough and avoids accidental misuse.

Also applies to: 3792-3804


1412-1439: Minor note: local use_distributed_gather flag is now purely for logging

With:

use_distributed_gather = True
...
if self.is_master or use_distributed_gather:
    gather_result = await self.comms.gather_with_reserve(...)

the local use_distributed_gather no longer reflects the actual distributed/sequential mode (that decision is made inside gather_with_reserve based on dist_helper.is_distributed()). It is effectively a “call on all ranks” flag, and the if use_distributed_gather: log message will say “Starting distributed gather …” even in the single‑rank case.

This is harmless but a bit confusing. Consider either:

  • Renaming the local flag (e.g., all_ranks_participate = True), or
  • Restoring the previous conditional definition that mirrors gather_with_reserve’s use_distributed calculation if you want logs to distinguish true distributed from single‑rank gathers.

Functionally everything is correct; this is just about log clarity.

Also applies to: 1491-1492

src/tplr/neurons.py (4)

48-70: _resolve_cuda_device is correct but currently unused downstream.

The helper cleanly normalizes string / torch.device / None into a safe CUDA device or None, but _outer_step_single only assigns cuda_device and never uses it. Either wire this into future CUDA calls or drop the local variable to avoid confusion.


297-333: Minor clean‑up in _outer_step_single.

cuda_device = _resolve_cuda_device(device) is currently dead code in this function. If you don't plan to use it for CUDA‑only operations (e.g., memory stats, sync), consider removing the assignment or using cuda_device in any future CUDA calls to keep intent clear.


520-526: Duplicate comment line before optimizer.step.

Lines 520–521 repeat the same comment twice. Safe to drop one of them to reduce noise.

-        # ---- apply update immediately for THIS param and free its grad ----
-        # ---- apply update immediately for THIS param and free its grad ----
+        # ---- apply update immediately for THIS param and free its grad ----

868-912: Checkpoint verification path is useful but can be expensive.

The added debug‑dict comparison in catch‑up provides strong safety checks on checkpoint correctness, but it does an extra get(...) and comparison before replaying any window. For very frequent restarts or large models this might add noticeable latency. If this becomes an issue operationally, consider gating the verification behind a config flag.

src/tplr/comms.py (9)

142-145: Partition cache is fine; consider documenting eviction policy.

The _partition_cache keyed by (window, sorted_uids, world_size) with an LRU cap of 32 entries is reasonable. A short comment about expected cache size vs typical window counts would help future maintainers understand why 32 was chosen and when eviction might occur.
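For reference, a minimal LRU cache of the kind described, built on OrderedDict (the actual _partition_cache implementation may differ):

```python
from collections import OrderedDict

class LRUCache:
    """Bounded cache: least-recently-used entries are evicted first."""

    def __init__(self, maxsize: int = 32):
        self._d: OrderedDict = OrderedDict()
        self.maxsize = maxsize

    def get(self, key):
        if key in self._d:
            self._d.move_to_end(key)  # mark as most-recently used
            return self._d[key]
        return None

    def put(self, key, value):
        self._d[key] = value
        self._d.move_to_end(key)
        if len(self._d) > self.maxsize:
            self._d.popitem(last=False)  # evict the oldest entry
```

With one entry per (window, uids, world_size) key, a cap of 32 covers roughly 32 recent windows before eviction, which is the kind of rationale a comment could capture.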


483-582: s3_object_exists error classification looks solid; consider purging bad clients on hard errors.

The structured handling of 404, auth/bucket config errors, and transient 5xx/network issues is a good improvement over string matching. For consistency with other S3 paths, you might also call _purge_s3_client(bucket) on non‑transient ClientErrors (e.g., invalid credentials/bucket) to force recreation on the next call, though this is optional.


2071-2096: partition_uids_across_ranks is correct but currently unused.

The static round‑robin helper does exactly what gather_distributed’s inline partitioning does. To reduce duplication, consider having gather_distributed call partition_uids_across_ranks instead of reimplementing the slice logic with its own cache (or remove the helper if you intend to rely solely on the cached path).


2133-2257: merge_gather_results: semantics look right; minor device/perf nits.

  • Correctly filters out None partials, aggregates all param keys, concatenates per‑param lists, and recomputes metrics and success_rate.
  • Device selection tries current CUDA device, with a safe CPU fallback.

Two small suggestions:

  • There’s a second move_to_device here that largely duplicates _move_gather_result_to_device; you could reuse the static helper for consistency.
  • When gather_distributed’s local‑merge path calls merge_gather_results([partial_result]) and then immediately moves to CPU, you end up moving tensors twice (GPU→GPU via to(target_device) then GPU→CPU). Passing target_device="cpu" from that call would avoid an extra copy.

2318-2364: Local merge + CPU move before all_gather mostly achieves the memory goal.

Merging per‑rank results locally when return_partials=False and then moving that partial to CPU before all_gather_object addresses the multi‑rank GPU OOM scenario. To minimize device churn, consider invoking merge_gather_results([partial_result], target_device="cpu") directly so tensors are moved only once (GPU→CPU) instead of GPU→GPU then GPU→CPU.


2404-2438: Merged‑result path is correct; minor duplication in logging.

Merging all_partials on every rank ensures consistent behavior; moving the final result back to device only when non‑CPU and CUDA‑available is reasonable. The log_with_context call on rank 0 provides a good summary. Note: gather_time and merge_time referenced in the final log come from the local rank and may differ slightly across ranks; that’s acceptable but could be renamed to clarify they’re per‑rank.


2479-2494: Entry logging is very verbose; consider truncating UID lists.

[gather_with_reserve][Rank ...] ENTRY | gather_uids={gather_uids}, reserve_uids={reserve_uids} at INFO can produce very long lines when many peers are involved. If logs become noisy, consider logging only the counts plus the first few UIDs.
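A tiny helper along these lines (hypothetical, not part of the PR) would keep the entry log informative without dumping full UID lists:

```python
def summarize_uids(uids: list[int], max_shown: int = 5) -> str:
    """Render a UID list as a count plus a short prefix,
    e.g. '10 uids [0, 1, 2, 3, 4, ...]'."""
    shown = ", ".join(str(u) for u in uids[:max_shown])
    suffix = ", ..." if len(uids) > max_shown else ""
    return f"{len(uids)} uids [{shown}{suffix}]"
```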


2507-2557: Distributed integration and sequential fallback logic are mostly sound; note potential per‑rank divergence.

  • Using gather_distributed when world_size > 1 and distributed is initialized is the right default.
  • The post‑call check treating None and all‑empty partials as failure is good.

However, when the distributed path fails, each rank independently runs the sequential gather(...) with the full gather_uids list. Because these are separate S3 read patterns, ranks can in theory end up with slightly different results under flaky network conditions. If strict cross‑rank consistency is required, consider broadcasting a single primary result from the master rank after sequential fallback.


2571-2585: Empty-primary normalization is correct; partials vs merged behavior is slightly asymmetric.

When primary is None, you create an empty shell and, if return_partials=True, wrap it in a list to maintain type consistency. For the case where return_partials=True and sequential_result is [None], primary skips this path and remains [None]—later logic converts that into None via total_uids == 0. Behavior is still correct but slightly harder to reason about; a brief comment explaining the two “no gradients” paths would help.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 698402d and 7fccf1f.

📒 Files selected for processing (4)
  • neurons/miner.py (9 hunks)
  • neurons/validator.py (13 hunks)
  • src/tplr/comms.py (19 hunks)
  • src/tplr/neurons.py (11 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
neurons/validator.py (6)
tests/test_state_loading.py (1)
  • rating (59-62)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • broadcast (273-282)
src/tplr/comms.py (2)
  • gather_with_reserve (2443-2721)
  • get_with_retry (1625-1740)
src/tplr/neurons.py (1)
  • check_uid_index_overlap (1281-1486)
src/tplr/schemas.py (1)
  • success (63-65)
neurons/miner.py (3)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/distributed.py (1)
  • is_distributed (75-77)
src/tplr/comms.py (1)
  • gather_with_reserve (2443-2721)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
  • Bucket (24-46)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • all_gather_object (307-325)
src/tplr/neurons.py (1)
src/tplr/compress.py (4)
  • compress (373-377)
  • compress (379-383)
  • compress (386-423)
  • unpack_12bit_indices (102-142)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.11)
  • GitHub Check: test (3.12)
🔇 Additional comments (26)
neurons/miner.py (2)

830-833: Post‑outer‑step gradient clearing is in the right place

Zeroing gradients with model.zero_grad(set_to_none=True) immediately after a successful outer step is a good safeguard against stale grads leaking into the next window, especially now that outer steps are conditional on should_update.

This fits cleanly with the additional cleanup in cleanup_window().


738-775: Verification confirms distributed gather wiring is consistent across all call sites

The sanity check confirms:

  • Miner.py (lines 758–771): correctly uses local=False, passes time_min/time_max, includes expected_compressed_params, sets return_partials=False with timeout=90
  • Validator.py (lines 1517–1532): identically configured with local=False, time_min/time_max, expected_compressed_params, return_partials=False but with timeout=150 (appropriate for validator workload)
  • Test call appropriately uses minimal parameters for mocked context

Both production call sites implement the same parameter discipline and avoid partial-list vs merged-result pitfalls.

neurons/validator.py (5)

1338-1370: Peers, reserve_peers, and commitments broadcast keeps all ranks aligned for distributed gather

The new logic:

  • Refreshes self.comms.commitments on the master each window.
  • Broadcasts [peers, reserve_peers, commitments] from master to all ranks using dist.broadcast_object_list.
  • Updates self.comms.peers, self.comms.reserve_peers, and self.comms.commitments on non‑master ranks and logs what was received.

This directly addresses the stale‑commitments / stale‑peer‑set risk for distributed gather and ensures all ranks use the same bucket mappings and reserve lists when calling gather_with_reserve.


1412-1439: Timestamp broadcast ensures consistent time_min/time_max across ranks

Only the master now:

  • Calls self.query_block_timestamp(sync_block).
  • Falls back to time.time() on failure with a warning.

Then ts_value is broadcast via a torch.float64 tensor to all ranks, and time_min / time_max are derived from the shared value.

This removes the previous possibility of rank‑dependent time windows for gather_with_reserve(time_min=time_min, time_max=time_max) and is the right pattern for distributed timing semantics.


1607-1645: Master‑only post‑gather processing and offload flow looks coherent

In the master‑only block when gather_result is not None:

  • You upload the already‑merged gather results via upload_gather_results(gather_result).
  • Run check_uid_index_overlap on the merged result and apply slash_from_overlap.
  • Call offload_gather_results(gather_result, log=True) to move its state_dict to pinned CPU, freeing GPU memory before evaluation.
  • Extract skipped_uids and success_rate once, then barrier on "post_gather".

This sequencing:

  • Ensures all heavy gather‑side work is done before evaluation and outer‑step.
  • Leaves gather_result valid (albeit offloaded to CPU) for later uses: compute_peer_val_norms, slashing paths, outer_step, and metrics.

Given offload_gather_results handles both SimpleNamespace and dict state_dict layouts, this looks safe.


3198-3213: evaluate_miner_sync: get_with_retry integration and error handling look solid

The new evaluate_miner_sync logic:

  • Uses self.comms.get_with_retry(..., timeout=60, local=False, stale_retention=10) to fetch the miner’s debug dict, which is a better fit for transient S3/R2 hiccups than a single get.
  • Correctly handles both timeout (debug_result is None) and unsuccessful responses (not debug_result.success), logging a warning that includes getattr(debug_result, "error", "unknown") without risking an AttributeError when debug_result is None.
  • Falls back to a simple {success: False, error: "...", sync_score: 0.0} payload, which upstream slashing logic can safely consume.

This should reduce spurious sync‑slash events due to transient infra issues without introducing new failure modes.


932-938: OpenSkill rating initialization is required by the PlackettLuce model

Initializing self.openskill_ratings[uid] for all window_uids before calling self.openskill_model.rate(...) is not just an optimization—it is a requirement. The PlackettLuce.rate function expects teams to be lists of PlackettLuceRating objects (i.e. pre-initialized rating objects), and passing raw player values instead will raise a TypeError.

This change correctly ensures:

  • Every UID has a properly initialized rating object before being passed to the model.
  • The code avoids runtime TypeErrors that would otherwise occur.
  • Ratings are stable and predictable as new peers appear.
src/tplr/neurons.py (5)

249-295: outer_step wrapper looks good and preserves original semantics.

Delegating to _outer_step_single keeps the public API stable while centralizing the implementation. Given that distributed gather now always delivers a merged result here, the updated docstring is accurate.


1084-1147: Per‑window debug‑dict comparison logic is sound.

The EMA of parameter‑slice deltas (param_avg_change) combined with compare_model_with_debug_dict gives a reasonable estimate of “steps behind” during catch‑up. The handling of DTensor vs non‑DTensor parameters and the index slicing guards (numel() < 2) look correct.


1225-1240: Stabilized step‑ratio calculation in compare_model_with_debug_dict is a good fix.

Using max(param_avg_change, learning_rate) as the denominator avoids explosive step ratios when the EMA of parameter change becomes tiny, while still falling back to learning_rate if the stored slice has mismatched length. This should make the “steps behind” metric much more robust.
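The stabilized ratio reduces to a one‑line floor on the denominator; a sketch (function and argument names are illustrative, not the PR's exact signature):

```python
def steps_behind(slice_delta: float, param_avg_change: float, learning_rate: float) -> float:
    """Estimate how many optimizer steps a peer lags by dividing the observed
    parameter-slice delta by a floored per-step change, so a near-zero EMA
    cannot inflate the ratio."""
    denom = max(param_avg_change, learning_rate)
    return slice_delta / denom
```

With `param_avg_change` collapsing toward zero, the estimate is bounded by `slice_delta / learning_rate` instead of diverging.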


1281-1322: Partial‑aware handling in check_uid_index_overlap looks correct and fixes O(n×m) scan.

The new uid_to_partial: dict[int, tuple[SimpleNamespace, int]] map lets you locate a UID’s partial and index in O(1) instead of scanning all partials per UID. The merged‑vs‑partial branching and all_state_dicts construction are consistent, and Ptot remains the full UID count in both cases.


1355-1431: Overlap computation for partial results is logically consistent.

Using valid_uids_for_param to map back from the compact per‑parameter index (0..len(unpacked_indices)-1) to the original UID index, and then keying pair_acc by those original indices, preserves the semantics of later reporting via uids[i], uids[j]. Early continues when idxs_list/vals_list are missing or too short correctly skip malformed entries.

src/tplr/comms.py (14)

666-695: 500MB threshold for multipart downloads is reasonable.

Using a 500MB cutoff to switch from simple get_object to the parallel download_large_file implementation aligns with the goal of handling large model gradients robustly, and the dynamic download_timeout scaling by size is sensible.
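The size‑scaled timeout is essentially a base budget plus transfer time at an assumed worst‑case throughput; a sketch with illustrative constants (the PR's actual base and rate may differ):

```python
def download_timeout(size_bytes: int, base_s: float = 60.0,
                     min_rate_bps: float = 10 * 1024 * 1024) -> float:
    """Scale the allowed download time with object size: a fixed base budget
    plus the transfer time at an assumed floor throughput (10 MB/s here)."""
    return base_s + size_bytes / min_rate_bps
```

This keeps small objects on a tight deadline while giving multi‑GB gradient blobs room to finish.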


1714-1737: Resetting consecutive_not_found on non‑NOT_FOUND statuses is correct.

Provided that confirmed 404s are eventually distinguished from transient failures (see the NOT_FOUND cutoff concern raised separately), resetting consecutive_not_found on any other status keeps the early‑exit optimization from triggering on mixed error patterns.


1815-1990: gather’s validation and logging additions look good.

The added per‑UID logging for exceptions (log_with_context at warning level) and the NaN/Inf checks on non‑quantized vals (moving only for validation) strengthen robustness without changing semantics. The expected_compressed_params cross‑check and quant_params presence check remain consistent.


2097-2131: _move_gather_result_to_device is straightforward and matches usages.

The recursive move_to_device helper correctly handles tensors, lists/tuples, and dicts under result.state_dict. In‑place mutation of the namespace is fine here and fits how gather_distributed moves results to CPU/GPU before/after collectives.
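The recursion shape such a helper uses is a generic structure‑preserving map over nested containers; a device‑free sketch of the pattern (with a doubling function standing in for `tensor.to(device)`):

```python
def map_nested(fn, obj):
    """Apply fn to every leaf of a nested dict/list/tuple structure,
    preserving container types -- the same recursion a move-to-device
    helper performs over result.state_dict."""
    if isinstance(obj, dict):
        return {k: map_nested(fn, v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(map_nested(fn, v) for v in obj)
    return fn(obj)  # leaf: tensor.to(device) in the real helper
```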


2259-2313: Distributed UID partitioning and caching in gather_distributed are reasonable.

The (window, sorted_uids, world_size) cache key plus LRU eviction keeps partitions deterministic and avoids recomputing slices every step. Using sorted UIDs ensures all ranks see the same partitioning, and the rank < len(partitions) guard handles any mismatch gracefully.
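The cached‑partition idea can be sketched as a small bounded LRU keyed only by values every rank agrees on (a stand‑alone illustration, not the PR's exact cache; constants are assumptions):

```python
from collections import OrderedDict

class PartitionCache:
    """Bounded LRU cache for UID partitions. The key uses only values that
    are identical on every rank, so cached results stay group-consistent."""

    def __init__(self, max_entries: int = 8):
        self._cache: OrderedDict[tuple, list[list[int]]] = OrderedDict()
        self._max = max_entries

    def get(self, window: int, uids: list[int], world_size: int) -> list[list[int]]:
        key = (window, tuple(sorted(uids)), world_size)
        if key in self._cache:
            self._cache.move_to_end(key)  # refresh LRU position
            return self._cache[key]
        parts: list[list[int]] = [[] for _ in range(world_size)]
        for i, uid in enumerate(sorted(uids)):
            parts[i % world_size].append(uid)
        self._cache[key] = parts
        if len(self._cache) > self._max:  # evict least recently used
            self._cache.popitem(last=False)
        return parts
```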


2365-2385: all_gather_object usage and normalization look correct.

Using dist_helper.all_gather_object(partial_result) guarded by both world_size > 1 and dist_helper.is_distributed() is appropriate. Normalizing empty payloads to None based on missing uids simplifies downstream checks when deciding whether any rank produced gradients.


2387-2402: return_partials path behavior is consistent and type‑safe.

Returning the full all_partials list (possibly containing Nones) with a rank attribute set for each non‑None entry gives incremental consumers enough metadata to sort or bucket by source rank. Logging only on rank 0 avoids redundant logs.


2443-2477: gather_with_reserve signature/docs match new behavior.

Adding return_partials and updating the return type docstring to allow list[SimpleNamespace | None] is aligned with gather_distributed’s contract and with the incremental processing use‑case.


2586-2615: Primary‑gather logging for partials vs merged results is clear and correct.

The split handling of primary as list vs namespace, including computing missing differently, matches the two modes. Logging via log_with_context with both counts and skipped UIDs should be useful when debugging gather quality.


2618-2676: Reserve retry logic integrates well with partials and merged paths.

  • Filtering reserve_uids to exclude already‑used primary UIDs avoids duplicates.
  • Using sequential gather for the reserve tier is a reasonable compromise in complexity.
  • For partials, appending fallback as a new partial and assigning fallback.rank = len(primary) preserves deterministic ordering; for merged results, in‑place tensor‑list merging and metric aggregation look correct.

2678-2721: Final success‑rate recomputation and “no gradients” handling are correct.

Recomputing success_rate against the original gather_uids length and propagating it to each partial (or to the merged result) makes metrics easier to interpret. Returning None when total_uids == 0 (both in partial and merged branches) ensures downstream callers don’t have to special‑case an empty structure.


1471-1493: gradient_timestamp’s version arg extension is safe and backwards compatible.

Defaulting version to tplr.__version__ when None preserves existing behavior while allowing callers to probe other versions explicitly. Exception handling still purges the S3 client and returns 0.0 on failure.


3071-3116: get_start_window version support aligns with gradient_timestamp changes.

Using target_version = version if version else tplr.__version__ keeps existing behavior and lets callers request bootstrap/alternate versions. The call into s3_get_object is consistent with other JSON fetch paths.


697-724: No issues detected—the weights_only fix is valid and compatible.

The weights_only parameter was introduced in PyTorch 2.6, and the codebase requires torch>=2.7.1 (per pyproject.toml). A dict like {"state_dict": <state_dict>, "global_step": 123} (where state_dict is a mapping of tensors) loads correctly with weights_only=True, which matches exactly what the code saves and loads. Errors are already handled gracefully (logged and result in None).


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/tplr/comms.py (1)

1551-1570: Apply map_location to local torch.load path (lines 1564-1566)

The local storage branch does not respect the map_location parameter documented in the function signature and docstring. While the remote path (via s3_get_object()) correctly computes device_location and passes it to torch.load() (lines 705-718), the local branch omits map_location entirely. This creates a behavioral inconsistency where local loads ignore the caller's device preference and device fallback logic.

-                loaded_data = await asyncio.to_thread(
-                    torch.load, local_path, weights_only=True
-                )
+                device_location = (
+                    map_location if map_location is not None else self.config.device
+                )
+                loaded_data = await asyncio.to_thread(
+                    torch.load,
+                    local_path,
+                    map_location=device_location,
+                    weights_only=True,
+                )
♻️ Duplicate comments (1)
src/tplr/comms.py (1)

1665-1737: NOT_FOUND retry cutoff still conflates 404 with transient failures

get_with_retry now tracks consecutive_not_found and aborts after 5 occurrences (~2.5 s at 0.5 s sleep), but get() still maps all s3_get_object None returns (404, timeouts, and some ClientError cases) to status="NOT_FOUND". This means slow or flaky peers can be treated as “likely doesn’t exist” long before the caller’s timeout or time_max window ends.

To avoid violating the timeout semantics, you likely still need to either:

  • Distinguish true 404s from transient errors in get()/s3_get_object and only increment consecutive_not_found for confirmed 404s, or
  • Derive max_consecutive_not_found from timeout and/or time_max and apply an early cutoff only once you’re reasonably sure the gradient should already have appeared.

This aligns with the earlier review concern about premature NOT_FOUND cutoffs.
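The first option reduces to counting only confirmed 404s toward the cutoff while any other status resets the counter; a sketch over illustrative status strings (the real code tracks this across awaited `get()` calls rather than a list):

```python
def should_abort_as_missing(statuses: list[str], max_consecutive_404: int = 5) -> bool:
    """Abort early only after a run of *confirmed* 404s; timeouts and other
    transient statuses reset the counter instead of counting toward it."""
    consecutive = 0
    for status in statuses:
        if status == "NOT_FOUND_404":  # confirmed miss from S3
            consecutive += 1
            if consecutive >= max_consecutive_404:
                return True
        else:
            consecutive = 0  # transient failure: keep retrying
    return False
```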

🧹 Nitpick comments (7)
neurons/validator.py (4)

1491-1493: Consider making validator distributed‑gather opt‑in via config

Right now use_distributed_gather is hard‑coded to False, so validators will never exercise the distributed gather path even when running multi‑GPU. If you foresee wanting to experiment with distributed gather for validators, consider wiring this to a config/hparam while keeping sequential as the default, e.g.:

-            # Validators use sequential gather (master only) for stability
-            # Distributed gather is only for miners  
-            use_distributed_gather = False
+            # Validators default to sequential gather (master only) for stability.
+            # Allow optional opt‑in to distributed gather when running multi‑GPU.
+            use_distributed_gather = bool(
+                dist_helper.is_distributed()
+                and getattr(self.hparams, "validator_distributed_gather", False)
+            )

If validators are intentionally constrained to sequential gather long‑term, this can be left as‑is.


1517-1525: Comment no longer matches return_partials=False behavior

The inline comment says “Request partials for incremental processing in outer_step”, but the call now explicitly uses return_partials=False and expects a merged gather_result. To avoid confusion for future maintainers, it would be better to align the comment with the actual behavior, for example:

-                    # Request partials for incremental processing in outer_step
+                    # Request a merged result; incremental processing happens inside outer_step
@@
-                        return_partials=False,  # Always return merged result (distributed gather merges internally)
+                        return_partials=False,  # Always return merged result

The increased timeout to 150s for gradient gathering looks reasonable.


1545-1562: Distributed‑gather completion logging is clear; sequential path relies on later metrics

The new logging for the distributed gather branch (success vs failure with uids/skipped/success_rate) is well structured. Note that sequential gather only logs start and later “Skipped UIDs” plus numeric metrics in the evaluation section; if you want symmetry in textual logs, you could optionally add a similar “gather complete” summary in the sequential path using gather_result.{uids,skipped_uids,success_rate}.


1823-1826: compute_peer_val_norms type hint/docstring vs implementation

At the call site and in practice you always pass a merged SimpleNamespace from gather_with_reserve(return_partials=False), but the signature advertises gather_result: SimpleNamespace | list[SimpleNamespace] and the nearby comment says it “handles both partials and merged results” while the body assumes a merged object with .state_dict.

To avoid misleading callers (and accidental AttributeError if someone passes a list of partials later), I’d suggest either:

  • Narrowing the type hint and docstring to merged‑only, matching current behavior:
-    def compute_peer_val_norms(
-        self,
-        gather_result: SimpleNamespace | list[SimpleNamespace],
-    ) -> dict[str, torch.Tensor]:
-        """
-        Compute median L2 norms across all peers for each parameter.
- 
-        Args:
-            gather_result: Merged gather result SimpleNamespace
-        """
+    def compute_peer_val_norms(
+        self,
+        gather_result: SimpleNamespace,
+    ) -> dict[str, torch.Tensor]:
+        """
+        Compute median L2 norms across all peers for each parameter
+        from a merged gather result.
+        """

or

  • Adding an explicit if isinstance(gather_result, list): ... branch that merges/normalizes partials before the current logic.

Given the current usage in this file, tightening the annotation/docstring is the minimal, safe change.

Also applies to: 3791-3803

src/tplr/neurons.py (1)

48-70: CUDA device resolver is robust but currently unused in this path

The _resolve_cuda_device helper correctly normalizes str | torch.device | None and handles lack of CUDA gracefully. In _outer_step_single the computed cuda_device is not used, so this helper is effectively dead code here. Either wire it into places that need a normalized CUDA device or drop the local cuda_device assignment to avoid confusion.

src/tplr/comms.py (2)

2068-2257: Distributed gather helpers maintain UID alignment; minor reuse/cleanup opportunities

The new helpers look sound overall:

  • partition_uids_across_ranks and the _partition_cache (bounded LRU) give deterministic, cached round‑robin UID splits.
  • _move_gather_result_to_device and merge_gather_results correctly recurse through tensors/lists/tuples/dicts and preserve the invariant that merged.state_dict[<param>][i] corresponds to merged.uids[i] across concatenated partials.
  • target_device selection handles current_device() failures and falls back to CPU.

Two small cleanups you might consider:

  • gather_distributed currently re‑implements the round‑robin partitioning instead of using partition_uids_across_ranks; calling the static helper here would reduce duplication and ensure a single source of truth for partitioning.
  • move_to_device is implemented twice (once here and once in _move_gather_result_to_device); factoring it into a shared utility would avoid divergence in future changes.

Functionally this looks correct and matches the new distributed gather design.


2443-2719: gather_with_reserve correctly integrates distributed + partial flows; sequential fallback is coherent

The extended gather_with_reserve logic appears consistent:

  • In distributed mode it defers to gather_distributed and handles both merged and partial return shapes via return_partials.
  • Sequential fallback is triggered only when distributed gather yields no gradients or raises; for return_partials=True it wraps results in a one‑element list to preserve the type contract.
  • Primary and reserve results are merged correctly for both merged and partial modes (reserve partial appended as a new element vs. in‑place list concatenation for merged).
  • Final success rates are recomputed against the original gather_uids, and the function returns None when no gradients were successfully gathered.

One trade‑off is that in distributed fallback, every rank independently performs the sequential gather even though only rank 0’s result is actually consumed by outer_step (non‑masters rely on broadcast there). If sequential fallback becomes common, you might later consider an optimization where only the master gathers and then broadcasts the result to other ranks via dist_helper.broadcast_object. For now, the behavior is correct, just somewhat redundant.

📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 7fccf1f and d4ea81a.

📒 Files selected for processing (3)
  • neurons/validator.py (12 hunks)
  • src/tplr/comms.py (19 hunks)
  • src/tplr/neurons.py (10 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
src/tplr/neurons.py (1)
src/tplr/compress.py (4)
  • compress (373-377)
  • compress (379-383)
  • compress (386-423)
  • unpack_12bit_indices (102-142)
neurons/validator.py (6)
tests/test_state_loading.py (1)
  • rating (59-62)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • broadcast (273-282)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/comms.py (2)
  • gather_with_reserve (2443-2719)
  • get_with_retry (1625-1740)
src/tplr/neurons.py (1)
  • check_uid_index_overlap (1275-1480)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
  • Bucket (24-46)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • all_gather_object (307-325)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.12)
  • GitHub Check: test (3.11)
🔇 Additional comments (10)
neurons/validator.py (5)

931-937: Initialize OpenSkill ratings before rate() – good defensive change

Pre‑initializing self.openskill_ratings[uid] for all window_uids avoids missing‑rating issues and keeps the later teams/rate() call robust. Implementation is straightforward and side‑effect free for already‑rated peers.


1337-1361: Peer/commitment broadcast across ranks is correctly guarded

Refreshing self.comms.commitments on master and broadcasting [peers, reserve_peers, commitments] via dist.broadcast_object_list only when world_size > 1 and dist_helper.is_distributed() ensures all ranks share a consistent view for distributed operations without touching the non‑distributed path. This wiring looks correct.


1411-1438: Timestamp broadcast keeps gather time window rank‑consistent

Centralizing the query_block_timestamp(sync_block) call on master, falling back to time.time() once, and then broadcasting ts_value to all ranks before computing time_min/time_max removes the prior possibility of divergent time windows between ranks. The control flow and device placement for ts_tensor look sound.

Also applies to: 1440-1447


3199-3213: Switch to get_with_retry for debug dict fetch improves robustness

Using self.comms.get_with_retry(..., timeout=60, ...) and handling both None and unsuccessful results in one branch is a solid improvement over a single get call. The getattr(debug_result, "error", "unknown") usage safely covers the None case without raising, and returning a {"success": False, "sync_score": 0.0} stub integrates cleanly with slash_for_poor_sync’s existing avg_steps_behind/success handling.


3791-3827: Merged‑result norm computation logic looks correct

Given a merged gather_result.state_dict holding per‑peer vals/quant_params lists, dequantizing via maybe_dequantize_values, stacking per‑peer L2 norms, and taking the median per parameter is a reasonable and numerically stable way to build clip_norm_dict. This integrates cleanly with the later fallback clip_norm_dict.get(vals_key, eval_norm) in update_model_with_gradient.

src/tplr/neurons.py (2)

249-552: outer_step➝_outer_step_single refactor preserves single-result semantics

The split of outer_step into a thin wrapper and _outer_step_single keeps the previous behavior: only the source rank consumes gather_result, gradients are built and applied per-param to minimize memory, and non-source ranks receive grads via DTensor scatter or broadcast. Fingerprint accumulation and W&B logging remain gated on on_src. I don’t see new correctness or DDP edge cases introduced by this refactor.


1274-1480: Overlap checker now correctly supports partial results and fixes UID/index mapping

The changes to check_uid_index_overlap look solid:

  • Accepting gather_result: SimpleNamespace | list[SimpleNamespace] matches the new distributed/partial gather API.
  • The uid_to_partial map and valid_uids_for_param array remove the O(n×m) search and correctly track which original UID each unpacked index tensor corresponds to.
  • Using (uid_idx_i, uid_idx_j) as keys into pair_acc ensures overlap stats and offender selection are computed in the global UID index space, even when some peers lack a given parameter.

This should resolve the earlier complexity and misalignment issues without changing the external metrics contract.

src/tplr/comms.py (3)

483-582: s3_object_exists implements conservative, structured error handling correctly

The new s3_object_exists helper cleanly distinguishes 404/NoSuchKey, permission/config errors, and transient/network/service issues using ClientError.response. Returning False only for “definitely missing or misconfigured” and True for transient failures matches the intended conservative semantics for slashing and availability checks. No functional issues from this implementation.


666-724: Large-file handling and async torch.load offloading look correct

Raising the simple-download threshold to 500 MB and delegating larger objects to download_large_file with a size-scaled timeout is reasonable for 70B-class gradients. Offloading torch.load to a thread via asyncio.to_thread (with map_location=device_location and weights_only=True) avoids blocking the event loop during CPU decompression without changing load semantics. This block looks correct.
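The `asyncio.to_thread` pattern is worth a minimal illustration; here a sleep stands in for the CPU‑bound `torch.load` call (the stand‑in function is hypothetical):

```python
import asyncio
import time

def blocking_load(path: str) -> dict:
    """Stand-in for torch.load: a CPU-bound, blocking deserialization."""
    time.sleep(0.05)
    return {"path": path, "ok": True}

async def load_without_blocking_loop(path: str) -> dict:
    # Run the blocking call on a worker thread; the event loop stays free
    # to service other coroutines (heartbeats, concurrent downloads).
    return await asyncio.to_thread(blocking_load, path)

result = asyncio.run(load_without_blocking_loop("ckpt.pt"))
```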


2259-2439: gather_distributed flow and type contracts look consistent for merged and partial results

gather_distributed’s behavior matches its signature and the downstream expectations:

  • Deterministic, cached partitioning of sorted UIDs across ranks.
  • Per‑rank gather() only when my_uids is non‑empty, with local pre‑merge when return_partials=False to reduce per‑rank memory.
  • Moving partials to CPU before all_gather_object to avoid multi‑GPU OOMs, then optionally moving the final merged result back to the requested device.
  • When return_partials=True, returning a list[SimpleNamespace | None] with rank metadata set per entry; when False, returning a single merged SimpleNamespace or None when no UIDs were fetched.

The normalization of empty partials to None and the summary logging on rank 0 are also consistent with gather_with_reserve’s expectations. I don’t see correctness issues in this orchestration.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
neurons/validator.py (1)

3790-3826: Type signature doesn't match implementation.

The signature indicates gather_result can be SimpleNamespace | list[SimpleNamespace] (line 3792), but the implementation only handles the merged case (lines 3805-3824). There's no check for isinstance(gather_result, list) before accessing gather_result.state_dict at line 3805.

While return_partials=False at line 1520 means this function currently only receives merged results, the type signature is misleading.

Either update the implementation to handle both cases or correct the type signature:

Option 1: Match implementation to signature

 def compute_peer_val_norms(
     self,
     gather_result: SimpleNamespace | list[SimpleNamespace],
 ) -> dict[str, torch.Tensor]:
+    # Handle list of partials by merging first
+    if isinstance(gather_result, list):
+        gather_result = self.comms.merge_gather_results(gather_result)
+    
     clip_norm_dict = {}
     state_dict = gather_result.state_dict

Option 2: Correct the signature to match implementation

 def compute_peer_val_norms(
     self,
-    gather_result: SimpleNamespace | list[SimpleNamespace],
+    gather_result: SimpleNamespace,
 ) -> dict[str, torch.Tensor]:
🧹 Nitpick comments (2)
neurons/validator.py (2)

1350-1360: Consider logging commitments dict size for monitoring.

Broadcasting the commitments dict (line 1353) to all ranks is necessary for distributed gather, but the dict can be large (up to 256 UIDs × bucket metadata). While broadcast_object_list uses pickle serialization, which can be slow for large objects, this operation happens once per window, so it's not a hot path.

Consider adding debug logging to monitor the size:

                payload = (
                    [self.comms.peers, self.comms.reserve_peers, self.comms.commitments]
                    if self.is_master
                    else [None, None, None]
                )
+                if self.is_master:
+                    import sys
+                    size_mb = sys.getsizeof(payload) / 1e6
+                    tplr.logger.debug(f"Broadcasting payload size: {size_mb:.2f} MB")
                dist.broadcast_object_list(payload, src=0)

1490-1506: Consider adding a fallback mechanism for distributed gather failures.

use_distributed_gather is hardcoded to True (line 1490), and the elif self.is_master branch at line 1500 for sequential gather is unreachable. If distributed gather encounters issues (e.g., rank failures, communication timeouts), there's no automatic fallback to sequential mode.

Consider adding a configuration option or automatic fallback:

-            use_distributed_gather = True
+            use_distributed_gather = getattr(self.hparams, "use_distributed_gather", True)

Or add error handling to fall back to sequential on distributed failure (though this would require re-architecting the gather flow to retry).

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d4ea81a and 2f412b0.

📒 Files selected for processing (1)
  • neurons/validator.py (12 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
neurons/validator.py (7)
tests/test_state_loading.py (1)
  • rating (59-62)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • broadcast (273-282)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/comms.py (2)
  • gather_with_reserve (2443-2719)
  • get_with_retry (1625-1740)
src/tplr/neurons.py (1)
  • check_uid_index_overlap (1275-1480)
src/tplr/schemas.py (1)
  • success (63-65)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.12)
  • GitHub Check: test (3.11)
🔇 Additional comments (4)
neurons/validator.py (4)

931-936: LGTM! Defensive initialization prevents KeyError.

Ensuring all UIDs have OpenSkill ratings before calling rate() prevents potential KeyError when accessing self.openskill_ratings[uid] in the teams list creation at line 939. This is good defensive programming.


1414-1433: LGTM! Timestamp broadcasting ensures consistent time windows across ranks.

The timestamp is queried by master (with fallback to time.time() on failure), then broadcast to all ranks using torch.tensor with float64 precision. This ensures all ranks use identical time_min and time_max values for the distributed gather operation, preventing rank inconsistencies.


1607-1624: LGTM! Gather result processing correctly uses merged result.

The code correctly handles the merged gather_result (not a list of partials) as returned by gather_with_reserve with return_partials=False. Direct access to gather_result.uids and use in upload_gather_results and check_uid_index_overlap is appropriate.


3197-3212: LGTM! Correct usage of get_with_retry with proper error handling.

The change from comms.get() to comms.get_with_retry() with timeout=60 adds retry logic for transient failures. The error handling correctly checks for None (timeout) and not debug_result.success (retrieval failure) before proceeding.

@shivam-MBZUAI shivam-MBZUAI changed the title feat: Distributed Gather with Incremental Processing feat: distributed gradient gathering for multi-GPU parallelism Nov 21, 2025

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
neurons/validator.py (1)

1798-1825: Narrow the type annotation for compute_peer_val_norms to reflect actual implementation

The review comment is accurate. The method signature at line 3790 declares gather_result: SimpleNamespace | list[SimpleNamespace], but the implementation at line 3806 directly accesses state_dict = gather_result.state_dict without checking whether gather_result is a list, which would cause an AttributeError if called with a list of partials.

Currently, the code is safe because the only gather_with_reserve call at line 1516 explicitly uses return_partials=False (line 1521). However, the misleading type annotation and docstring ("Merged gather result SimpleNamespace") create risk if the method is reused or if return_partials changes.

Recommended action: Apply option 1 from the review—narrow the type annotation to SimpleNamespace only and update the docstring to remove the "handles both partials and merged results" comment at the call site. This avoids a latent runtime bug with minimal changes and keeps behavior as-is.

src/tplr/comms.py (1)

1665-1737: Major issue: Early NOT_FOUND cutoff conflates transient errors with true 404s.

The max_consecutive_not_found = 5 threshold (2.5 seconds) causes premature exits for transient network issues and timeouts, which are all mapped to status="NOT_FOUND" at line 1600. This breaks the intended timeout semantics (defaulting to 30 seconds).

Root cause: The get() method returns CommsGetResult(status="NOT_FOUND") for both:

  • True 404 responses (where early exit is appropriate)
  • Transient failures like asyncio.TimeoutError and network errors (which should retry until the caller's timeout expires)

Recommendation: Modify get() and s3_get_object() to return distinct status codes (e.g., "TRANSIENT_ERROR", "TIMEOUT", "NOT_FOUND") so the retry logic can distinguish true 404s from transient issues. Only increment consecutive_not_found for confirmed 404 responses.

Based on learnings
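
A minimal sketch of the retry loop the recommendation describes, assuming hypothetical status strings (`"OK"`, `"NOT_FOUND"`, `"TRANSIENT_ERROR"`, `"TIMEOUT"`) and a generic async `fetch` callable standing in for `get()`. Only confirmed 404s advance the consecutive-not-found counter; transient failures reset it so the loop keeps honouring the caller's timeout:

```python
import asyncio

async def get_with_retry(fetch, timeout=30.0, retry_interval=0.5,
                         max_consecutive_not_found=5):
    """Retry `fetch` until success, a run of confirmed 404s, or timeout."""
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    consecutive_not_found = 0
    while loop.time() < deadline:
        result = await fetch()
        if result.status == "OK":
            return result
        if result.status == "NOT_FOUND":      # confirmed 404: early-exit path
            consecutive_not_found += 1
            if consecutive_not_found >= max_consecutive_not_found:
                return None                   # object genuinely absent
        else:                                 # "TRANSIENT_ERROR" / "TIMEOUT"
            consecutive_not_found = 0         # keep retrying until deadline
        await asyncio.sleep(retry_interval)
    return None
```

With this split, a flaky network no longer trips the 5-attempt cutoff; only five back-to-back true 404s do.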

♻️ Duplicate comments (1)
neurons/miner.py (1)

474-490: Broadcast reserve_peers alongside peers and commitments for consistent fallback behavior.

Currently only peers and commitments are broadcast to non-master ranks. However, gather_with_reserve is invoked with reserve_uids=self.comms.reserve_peers (line 761). If the primary gather fails and falls back to reserves, non-master ranks may have stale reserve_peers, causing divergent behavior across ranks.

Apply this diff to include reserve_peers in the broadcast:

             # Broadcast peer list (and optionally commitments) from master to all ranks for distributed gather
             if dist_helper.world_size > 1 and dist_helper.is_distributed():
                 payload = (
-                    [self.comms.peers, self.comms.commitments]
+                    [self.comms.peers, self.comms.reserve_peers, self.comms.commitments]
                     if self.is_master
-                    else [None, None]
+                    else [None, None, None]
                 )
                 dist.broadcast_object_list(payload, src=0)
-                self.comms.peers, self.comms.commitments = payload
+                self.comms.peers, self.comms.reserve_peers, self.comms.commitments = payload
 
                 if not self.is_master:
                     tplr.logger.info(
-                        f"[Rank {dist_helper.rank}] Received peers from master: {self.comms.peers}"
+                        f"[Rank {dist_helper.rank}] Received peers from master: gather={self.comms.peers}, reserve={self.comms.reserve_peers}"
                     )

Based on learnings

🧹 Nitpick comments (2)
neurons/validator.py (1)

1491-1524: Clarify use_distributed_gather semantics and logging

Right now use_distributed_gather is hard‑coded to True, so:

  • The comment “For distributed gather, all ranks must participate / For sequential gather, only master rank performs the operation” is only accurate if someone later wires use_distributed_gather to an actual condition.
  • Logs always say “Starting distributed gather …” and “Distributed gather complete …” even in effectively sequential/single‑rank cases.

Functionally it’s fine because Comms.gather_with_reserve independently decides whether to use distributed gather (world_size > 1 & is_distributed) and falls back to sequential as needed, but wiring the flag would make intent clearer and future changes safer.

Consider:

-            use_distributed_gather = True
+            use_distributed_gather = (
+                dist_helper.world_size > 1 and dist_helper.is_distributed()
+            )

and keep the current control flow / logging structure. That preserves behavior while making the mode selection explicit.

Also applies to: 1543-1560

src/tplr/comms.py (1)

20-20: Unused import: gc is never used.

The gc module is imported but not referenced anywhere in the file. Remove it to keep imports clean.

-import gc
 import json
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2f412b0 and 6728bed.

📒 Files selected for processing (7)
  • neurons/base_node.py (1 hunks)
  • neurons/miner.py (9 hunks)
  • neurons/trainer.py (3 hunks)
  • neurons/validator.py (12 hunks)
  • src/tplr/comms.py (19 hunks)
  • src/tplr/compress.py (1 hunks)
  • src/tplr/neurons.py (10 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/tplr/compress.py
🧰 Additional context used
🧬 Code graph analysis (4)
src/tplr/neurons.py (1)
src/tplr/compress.py (5)
  • compress (373-377)
  • compress (379-383)
  • compress (386-423)
  • TopKCompressor (295-697)
  • unpack_12bit_indices (102-142)
neurons/miner.py (4)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/distributed.py (1)
  • is_distributed (75-77)
src/tplr/comms.py (1)
  • gather_with_reserve (2443-2719)
src/tplr/dcp_checkpoint.py (2)
  • state_dict (60-61)
  • state_dict (82-83)
src/tplr/comms.py (2)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • all_gather_object (307-325)
neurons/validator.py (5)
tests/test_state_loading.py (1)
  • rating (59-62)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • broadcast (273-282)
src/tplr/comms.py (2)
  • gather_with_reserve (2443-2719)
  • get_with_retry (1625-1740)
src/tplr/neurons.py (1)
  • check_uid_index_overlap (1275-1480)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.12)
  • GitHub Check: test (3.11)
🔇 Additional comments (22)
neurons/base_node.py (1)

61-62: LGTM! MRO chain properly initialized.

The super().__init__() call ensures parent classes (e.g., Trainer) are initialized correctly, fixing the MRO chain issue. The inline comment clearly documents the intent.

neurons/trainer.py (1)

887-889: LGTM! Gradient norm now properly captured and logged.

The changes correctly capture total_grad_norm from clip_grad_norm_().item() and handle the null round case by setting it to 0.0. This improves observability of the training process.

Also applies to: 911-911, 928-928

src/tplr/neurons.py (3)

48-70: LGTM! Robust CUDA device resolution.

The _resolve_cuda_device helper properly handles:

  • CUDA unavailability (returns None)
  • Missing/invalid device inputs (fallback to cuda:0)
  • Exception cases during device resolution
  • Type validation to ensure CUDA device

This centralizes device resolution logic and prevents string-device issues in CUDA APIs.


250-294: LGTM! Clean delegation to internal single-step implementation.

The refactor separates the public API (outer_step) from the internal implementation (_outer_step_single). The docstring correctly documents that distributed gather merges internally, so the function always receives a single merged result, not a list of partials.

Also applies to: 297-312


1277-1277: LGTM! Efficient dual-path overlap detection.

The updated check_uid_index_overlap correctly handles both merged and partial results:

  • Partial path (lines 1291-1306): Builds uid_to_partial mapping for O(1) lookup, eliminating the O(n×m) complexity flagged in past reviews
  • Merged path (lines 1311-1315): Uses traditional shared state_dict indexing
  • Index retrieval (lines 1354-1396): Uses the mapping or direct indexing based on result type
  • Pair keys (lines 1419-1422): Correctly uses valid_uids_for_param to reference original UID indices

This optimization and dual-path support align with the PR objectives for distributed gather.

Also applies to: 1291-1315, 1354-1396

neurons/miner.py (4)

31-34: LGTM! CUDA allocator configuration properly set.

Setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True before importing torch helps reduce memory fragmentation. The placement is correct (must be set before torch import).


552-559: LGTM! Proactive memory hygiene before training.

The aggressive cleanup (empty_cache, synchronize, zero_grad) ensures memory from the previous outer step is fully released before accumulating new gradients. This helps prevent OOM issues in multi-GPU setups.


738-762: LGTM! Gather mode selection and logging correctly implemented.

The code properly:

  • Selects distributed gather when world_size > 1 and is_distributed() (lines 739-741)
  • Ensures all ranks participate in distributed mode (line 757)
  • Sets return_partials=False to get merged results internally (line 762)
  • Logs mode-specific completion details with success rates and timing (lines 777-800)

Since return_partials=False, gather_result will be a merged SimpleNamespace (not a list), making subsequent attribute access safe.

Also applies to: 777-800


830-832: LGTM! Gradient clearing and gather result handling are correct.

The changes properly:

  • Clear gradients after outer step to free memory (lines 830-832)
  • Extract UIDs and calculate success metrics from gather_result (lines 885-891, 927-933) – safe because return_partials=False ensures a merged SimpleNamespace
  • Perform aggressive cleanup with state_dict clearing and gc.collect() (lines 1014-1025)

These improvements enhance memory management in multi-GPU distributed training.

Also applies to: 885-891, 927-933, 1014-1025

neurons/validator.py (6)

931-947: Pre‑initializing OpenSkill ratings for window_uids looks correct

Initializing self.openskill_ratings[uid] for all window_uids before building teams guarantees teams = [[self.openskill_ratings[uid]] for uid in window_uids] never hits missing keys. This removes a subtle failure mode for first‑time peers and aligns the rating update flow cleanly.


1337-1369: Broadcasting peers and commitments keeps all ranks aligned

Refreshing self.comms.commitments on master and then broadcasting [peers, reserve_peers, commitments] via dist.broadcast_object_list ensures every rank sees exactly the same peer lists and bucket metadata before distributed gather. This avoids per‑rank divergence in peer selection and is consistent with the requirement that all ranks participate in the same distributed gather configuration.


1411-1439: Timestamp query + broadcast fixes time‑window divergence across ranks

Computing sync_block once, querying the timestamp only on self.is_master, falling back to time.time() with a warning, and then broadcasting ts_value to all ranks before deriving time_min / time_max ensures every rank uses an identical time window for both gather_with_reserve and subsequent gradient GETs. This removes the earlier risk where some ranks might silently use time.time() while others used the chain timestamp.


1605-1644: Master‑only post‑gather processing with merged result is consistent

Restricting upload, overlap checking, CPU offload, and extraction of skipped_uids / success_rate to:

if self.is_master and gather_result is not None:
    ...

is appropriate now that gather_with_reserve is always called with return_partials=False. The master operates on a single merged gather_result, which:

  • Avoids redundant merge_gather_results allocations on multiple ranks.
  • Ensures check_uid_index_overlap and upload_gather_results see the same merged view.
  • Leaves non‑master ranks free of extra memory pressure while still participating in the distributed gather itself.

The later metrics/slashing paths only touch these values under self.is_master, so this reshaping is coherent.


3197-3212: Using get_with_retry for debug dict fetch is robust

Switching evaluate_miner_sync to:

debug_result = await self.comms.get_with_retry(
    uid=str(eval_uid),
    window=self.sync_window - 1,
    key="debug",
    timeout=60,
    local=False,
    stale_retention=10,
)

and then treating both None and success == False as failure:

if debug_result is None or not debug_result.success:
    ...
    return {"success": False, "error": "...", "sync_score": 0.0}

gives you:

  • Bounded, retrying reads of the miner’s debug artifact, tolerant of transient NOT_FOUND / ERROR conditions.
  • A clean “unsuccessful” result shape that downstream code already handles by defaulting metrics to sentinel values.

Using getattr(debug_result, "error", "unknown") is also safe even when debug_result is None thanks to the default argument.


3803-3826: Norm computation logic itself is straightforward and appropriate

Within the (merged) path, iterating model.named_parameters(), extracting vals_key / quant_key pairs, dequantizing with maybe_dequantize_values, and then taking the median of per‑peer L2 norms to populate clip_norm_dict[vals_key] is a reasonable and robust aggregation strategy for per‑parameter clipping thresholds. Given update_model_with_gradient uses these as upper bounds (falling back to the peer’s own norm if missing), the behavior is as intended.

src/tplr/comms.py (7)

483-582: Excellent structured error handling.

The s3_object_exists method correctly distinguishes between permanent failures (404, config errors) and transient issues (timeouts, 5xx). The conservative approach of returning True for transient errors prevents false positives in availability checks. The use of structured fields (status, code) instead of string matching makes the error handling robust.


709-723: Good optimization: offloading torch.load to thread pool.

The use of asyncio.to_thread() prevents blocking the event loop during tensor decompression, enabling true parallel downloads. The consistent application to both remote (line 713) and local storage (line 1564) is good for maintainability.
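
The pattern praised here can be sketched independently of tplr: offload a blocking deserialization call (`torch.load` in the real code; `pickle.loads` plus a sleep as a stand-in) to the default thread pool with `asyncio.to_thread`, so concurrent downloads keep progressing while one payload decompresses:

```python
import asyncio
import pickle
import time

def blocking_load(buf: bytes):
    """Stand-in for torch.load: CPU-bound work that would block the loop."""
    time.sleep(0.05)  # simulate decompression cost
    return pickle.loads(buf)

async def load_many(buffers):
    # Each load runs in a worker thread; asyncio.gather preserves order.
    return await asyncio.gather(
        *(asyncio.to_thread(blocking_load, b) for b in buffers)
    )

payloads = [pickle.dumps({"uid": i}) for i in range(4)]
results = asyncio.run(load_many(payloads))
```

Without `to_thread`, the four loads would serialize on the event loop; with it, they overlap on the thread pool.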


1471-1493: Good: version parameter enables multi-version gradient checks.

The addition of the version parameter with a sensible default maintains backward compatibility while enabling flexibility. The implementation correctly defaults to tplr.__version__ when None is provided.


2072-2095: LGTM: Deterministic round-robin partitioning.

The implementation correctly sorts UIDs and uses Python's slice notation [rank::world_size] for clean round-robin assignment. The single-rank early return is appropriate.
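
The round-robin scheme is small enough to restate as a self-contained sketch (function name hypothetical): sort once so every rank sees the same order, then take the stride-slice `[rank::world_size]`:

```python
def partition_uids(uids, rank, world_size):
    """Deterministic round-robin partition of UIDs across ranks."""
    if world_size <= 1:
        return list(uids)       # single rank: no partitioning needed
    ordered = sorted(uids)      # identical ordering on every rank
    return ordered[rank::world_size]

uids = [17, 3, 42, 8, 25, 11]
parts = [partition_uids(uids, r, 3) for r in range(3)]
# Every UID lands on exactly one rank; no coordination is required
# beyond agreeing on (uids, world_size).
```

Because the partition is a pure function of the sorted UID list, caching it per window (as the PR does) is safe: all ranks recompute the same assignment.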


2156-2171: Device selection is reasonable for most scenarios.

The fallback logic at lines 2158-2169 handles RuntimeError from torch.cuda.current_device() and conservatively attempts cuda:0 only after verifying device count. While this doesn't guard against device 0 being out of memory or otherwise unavailable at the hardware level, such edge cases are rare and would manifest as explicit CUDA errors downstream.


2443-2719: Complex but correct handling of partials vs. merged results.

The method correctly handles two return paths:

  • return_partials=True: Returns list of unmerged partials
  • return_partials=False: Returns merged SimpleNamespace

The wrapping logic (lines 2554-2555, 2566-2567, 2581-2582) consistently converts sequential results to single-element lists when return_partials=True, maintaining type consistency.

Minor note: The aggregate success rate (lines 2687-2689) overwrites per-partial success rates. This is intentional—success rate is computed relative to the original gather_uids list, not per-partial—but may be surprising to downstream consumers expecting local rates.


2291-2292: The device parameter is properly passed by all callers—no issue found.

Verification shows that actual callers of gather_with_reserve() in neurons/miner.py (line 766) and neurons/validator.py (line 1524) both explicitly pass device through kwargs. The device flows correctly through the call chain to gather_distributed(), where it's extracted at line 2292 and used at lines 2414-2424 to move results back to GPU as intended. The default of "cpu" applies only when device is omitted, which does not occur in production code paths.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
src/tplr/comms.py (2)

618-626: Restore default bucket fallback in s3_get_object.

bucket remains optional in the signature/docstring, but the new implementation no longer substitutes self.bucket when callers omit it. Any call such as await self.s3_get_object(key=...) now raises AttributeError when _get_s3_client dereferences None.access_key_id. Please reintroduce the fallback before _get_s3_client() is invoked.

+        if bucket is None:
+            bucket = self.bucket
         s3_client = await self._get_s3_client(bucket)

1668-1737: Do not abort get_with_retry after 2.5 seconds of NOT_FOUND.

The new max_consecutive_not_found = 5 cap forces an early exit (~2.5 s) even when callers supplied a much larger timeout (30 s by default). Because get() returns NOT_FOUND for transient S3 timeouts/network hiccups, we now stop long before the object has a chance to appear, skipping gradients under normal latency. Please either tie the cutoff to timeout (e.g., derive it from timeout / 0.5) or only increment the counter when get() can confirm a real 404 so the loop continues honouring the caller’s timeout.

🧹 Nitpick comments (5)
neurons/miner.py (3)

554-562: Consider guarding the pre‑train CUDA cleanup for robustness

The aggressive torch.cuda.empty_cache() + torch.cuda.synchronize() before every inner step is fine for GPU‑only deployments, but it will throw on CPU‑only builds and adds some overhead each window.

You can keep the behavior while making it safer:

-            # Aggressive memory cleanup before training to prevent OOM
-            # This ensures memory from previous outer step is fully released
-            torch.cuda.empty_cache()
-            torch.cuda.synchronize()
+            # Aggressive memory cleanup before training to prevent OOM
+            # This ensures memory from previous outer step is fully released
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+                torch.cuda.synchronize(self.device)

This matches how check_memory_threshold() already guards other CUDA calls.


711-775: Align timestamp handling across ranks with validator-style broadcast

With the new distributed gather path, each rank passes time_min/time_max into gather_with_reserve. Today, every rank queries query_block_timestamp() independently and may fall back to time.time() on failures. If that fallback happens on only some ranks, different partitions can see slightly different time windows, which makes behavior harder to reason about and diverges from the validator’s now-unified timestamp logic.

You can mirror the validator pattern so all ranks share the same ts_value while still keeping the non‑blocking executor call:

-            sync_block = self.current_window * self.hparams.blocks_per_window
-            ts_value = await self.loop.run_in_executor(
-                None, self.query_block_timestamp, sync_block
-            )
-            if ts_value is None:
-                tplr.logger.warning(
-                    f"Could not get timestamp for sync block {sync_block}. Using current time as fall back.",
-                )
-                ts_value = time.time()
+            sync_block = self.current_window * self.hparams.blocks_per_window
+
+            # Only master queries timestamp, then broadcasts to all ranks
+            if self.is_master:
+                ts_value = await self.loop.run_in_executor(
+                    None, self.query_block_timestamp, sync_block
+                )
+                if ts_value is None:
+                    tplr.logger.warning(
+                        f"Could not get timestamp for sync block {sync_block}. Using current time as fall back.",
+                    )
+                    ts_value = time.time()
+            else:
+                ts_value = 0.0
+
+            # Broadcast timestamp from master to all ranks
+            ts_tensor = torch.tensor(
+                [ts_value], dtype=torch.float64, device=self.device
+            )
+            dist_helper.broadcast(ts_tensor, src=0)
+            ts_value = float(ts_tensor.item())

time_min / time_max computation below can stay as-is, now based on a single shared timestamp.


885-894: Tighten successful/skipped peer bookkeeping in debug + metrics

Right now successful_uids is computed but unused, and successful_peers is derived as peers - skipped_uids. Since gather_result.uids already encodes the actual successful peers, you can simplify and avoid the unused local:

-                # Add successful peers information
-                if gather_result is not None:
-                    successful_uids = set(gather_result.uids)
-                    skipped_uids = set(gather_result.skipped_uids)
-
-                    debug_dict["successful_peers"] = sorted(
-                        list(set(self.comms.peers) - skipped_uids)
-                    )
-                    debug_dict["skipped_peers"] = sorted(list(skipped_uids))
+                # Add successful peers information
+                if gather_result is not None:
+                    successful_uids = set(gather_result.uids)
+                    skipped_uids = set(gather_result.skipped_uids)
+
+                    debug_dict["successful_peers"] = sorted(successful_uids)
+                    debug_dict["skipped_peers"] = sorted(skipped_uids)

The metrics section below (gather_success_rate and the skipped_uids list / skipped_peers field) already looks consistent with a merged gather_result, so no change is needed there.

Also applies to: 929-999

neurons/validator.py (2)

1490-1524: Make gather mode detection dynamic and fix device typing for gather_with_reserve

Two small issues here:

  1. use_distributed_gather = True makes the sequential logging branch dead and causes all ranks to enter the gather path even in non‑distributed setups. It’s safer and clearer to mirror the miner:
-            gather_start = tplr.T()
-            skipped_uids: list[int] = []
-            success_rate = 0.0
-            gather_result = None
-            skip_window = False
-            use_distributed_gather = True
+            gather_start = tplr.T()
+            skipped_uids: list[int] = []
+            success_rate = 0.0
+            gather_result = None
+            skip_window = False
+            # Use distributed gather automatically when running with multiple ranks
+            use_distributed_gather = (
+                dist_helper.world_size > 1 and dist_helper.is_distributed()
+            )
  2. In the gather_with_reserve call, device=cast(str, self.device) is a runtime no‑op; the actual object passed is still a torch.device. Miner passes device=str(self.device), and Comms is written to accept a string device identifier.

For consistency and to avoid surprises inside Comms:

-                        timeout=150,  # Increased from 90 to allow more time for gradient gathering
-                        device=cast(str, self.device),
+                        timeout=150,  # Increased from 90 to allow more time for gradient gathering
+                        device=str(self.device),

These tweaks keep the control flow honest and the device argument uniform across callsites.


1821-1824: Clarify compute_peer_val_norms contract to “merged only”

compute_peer_val_norms currently assumes gather_result.state_dict is a merged object with per‑parameter vals/quant_params; it doesn’t handle the list-of-partials case suggested by its type hint and the comment at Line 1823. Given this run path now always uses return_partials=False, it’s better to make the contract explicit rather than suggest partial support that doesn’t exist.

You can align the comment and signature:

-                clip_norm_dict = {}
-                if gather_result is not None:
-                    # compute_peer_val_norms handles both partials and merged results
-                    clip_norm_dict = self.compute_peer_val_norms(gather_result)
+                clip_norm_dict = {}
+                if gather_result is not None:
+                    # compute_peer_val_norms expects a merged gather result
+                    clip_norm_dict = self.compute_peer_val_norms(gather_result)

and:

-    def compute_peer_val_norms(
-        self,
-        gather_result: SimpleNamespace | list[SimpleNamespace],
-    ) -> dict[str, torch.Tensor]:
+    def compute_peer_val_norms(
+        self,
+        gather_result: SimpleNamespace,
+    ) -> dict[str, torch.Tensor]:

The implementation below then accurately matches the type and docstring (“Merged gather result SimpleNamespace”), avoiding confusion for future partial‑support work.

Also applies to: 3790-3826

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6728bed and 8e9b8c7.

📒 Files selected for processing (7)
  • neurons/base_node.py (1 hunks)
  • neurons/miner.py (9 hunks)
  • neurons/trainer.py (3 hunks)
  • neurons/validator.py (12 hunks)
  • src/tplr/comms.py (19 hunks)
  • src/tplr/compress.py (1 hunks)
  • src/tplr/neurons.py (10 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • neurons/trainer.py
  • src/tplr/compress.py
🧰 Additional context used
🧬 Code graph analysis (4)
neurons/miner.py (4)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/distributed.py (1)
  • is_distributed (75-77)
src/tplr/comms.py (1)
  • gather_with_reserve (2443-2719)
src/tplr/dcp_checkpoint.py (2)
  • state_dict (60-61)
  • state_dict (82-83)
neurons/validator.py (6)
tests/test_state_loading.py (1)
  • rating (59-62)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • broadcast (273-282)
src/tplr/comms.py (2)
  • gather_with_reserve (2443-2719)
  • get_with_retry (1625-1740)
src/tplr/neurons.py (1)
  • check_uid_index_overlap (1275-1480)
src/tplr/schemas.py (1)
  • success (63-65)
src/tplr/neurons.py (1)
src/tplr/compress.py (4)
  • compress (373-377)
  • compress (379-383)
  • compress (386-423)
  • unpack_12bit_indices (102-142)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
  • Bucket (24-46)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • all_gather_object (307-325)
🔇 Additional comments (9)
neurons/miner.py (4)

31-42: CUDA allocator env + dist import look solid

Setting PYTORCH_CUDA_ALLOC_CONF before importing torch (while honoring any pre‑set value) and adding the explicit torch.distributed import are appropriate for the new distributed gather + memory‑hygiene flow. No issues here.


469-493: Peer / reserve / commitment broadcast keeps ranks aligned

Broadcasting self.comms.peers, self.comms.reserve_peers, and self.comms.commitments from the master to all ranks before gather ensures every rank has a consistent view of primary and reserve peers plus bucket metadata, which is exactly what gather_with_reserve expects. This also mirrors the validator path, reducing the risk of subtle per‑rank divergence.


740-803: Distributed vs sequential gather wiring matches Comms semantics

Deriving use_distributed_gather from dist_helper and then:

  • having all ranks call gather_with_reserve in distributed mode,
  • restricting the sequential path to the (single) master rank when not distributed, and
  • forcing return_partials=False so miners always see a merged SimpleNamespace,

is consistent with the documented behavior of gather_with_reserve and with how downstream code accesses uids, skipped_uids, and success_rate. This addresses the earlier partial‑vs‑merged mismatch cleanly.


1016-1025: Gather-result cleanup timing is safe but future uses should beware

Clearing gather_result.state_dict.__dict__ (when present) and deleting gather_result after all logging/outer-step work is finished is a reasonable way to encourage GC and reduce memory pressure.

Just be aware that if you later start any background tasks that capture gather_result (or its state_dict) by reference, they will see this cleared state; in that case this cleanup would need to move after those tasks complete or be narrowed to temporary structures only. As written, with no async users of gather_result, the lifecycle is coherent.

neurons/validator.py (5)

931-937: OpenSkill rating pre-initialization avoids missing-rating edge cases

Initializing self.openskill_ratings[uid] for every uid in window_uids before calling self.openskill_model.rate(...) prevents gaps where a peer participates in rating updates without an existing rating object. This keeps the OpenSkill flow consistent with how you hydrate ratings from saved state and avoids ad-hoc defaults.
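The pre-initialization pattern can be sketched as follows; the `Rating` dataclass here is a hypothetical stand-in for the real OpenSkill rating object, whose constructor differs:

```python
from dataclasses import dataclass


@dataclass
class Rating:
    # Default mu/sigma mirror common TrueSkill-style priors; illustrative only.
    mu: float = 25.0
    sigma: float = 25.0 / 3


def ensure_ratings(ratings: dict[int, Rating], window_uids: list[int]) -> None:
    """Guarantee every participating UID has a rating before rate() runs."""
    for uid in window_uids:
        # setdefault preserves ratings hydrated from saved state
        ratings.setdefault(uid, Rating())


ratings: dict[int, Rating] = {7: Rating(mu=30.0, sigma=5.0)}
ensure_ratings(ratings, [3, 7, 11])
# UID 7 keeps its hydrated rating; 3 and 11 receive fresh defaults.
```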


1337-1369: Commitment refresh + peer/reserve broadcast correctly sync gather state

Refreshing self.comms.commitments = await self.comms.get_commitments() on the master and then broadcasting [peers, reserve_peers, commitments] to all ranks ensures:

  • all ranks use the same primary + reserve peer lists, and
  • everyone sees the same bucket config when calling gather_with_reserve.

The logging for non‑master ranks also helps debug any divergence. This is aligned with the miner side and looks good.


1411-1439: Timestamp broadcast eliminates per-rank time window skew

The new pattern—master querying query_block_timestamp(sync_block) (with a time.time() fallback), broadcasting ts_value via dist_helper.broadcast, then deriving time_min/time_max from the shared value—fixes the earlier risk where some ranks could fall back to local time while others used chain time.

Given time_min/time_max feed into gather_with_reserve and later comms.get(...) calls, this is a solid synchronization improvement for distributed validator runs.
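The shared-timestamp derivation can be sketched like this; the window and grace values are illustrative, and the broadcast itself is elided (in the real code `dist_helper.broadcast` distributes `ts_value` before any rank derives the bounds):

```python
import time


def derive_time_window(
    ts_value: float, window_seconds: float = 300.0, grace: float = 30.0
) -> tuple[float, float]:
    """Derive identical (time_min, time_max) on every rank from one shared
    timestamp, instead of each rank consulting its own clock."""
    return ts_value, ts_value + window_seconds + grace


# Master resolves the timestamp (chain query in the real code, with a
# time.time() fallback) and broadcasts it; every rank then derives the
# window from the *broadcast* value:
ts_value = time.time()  # stand-in for query_block_timestamp(sync_block)
time_min, time_max = derive_time_window(ts_value)
```

Because all ranks start from the same `ts_value`, the bounds fed into `gather_with_reserve` and later `comms.get(...)` calls cannot diverge.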


1607-1643: Master-only post-gather processing and offload look coherent

Using the master rank to:

  • upload the merged gather_result,
  • run check_uid_index_overlap(...),
  • offload aggregated results to pinned CPU via offload_gather_results, and
  • snapshot skipped_uids = list(gather_result.skipped_uids) and success_rate = gather_result.success_rate

before the barrier keeps heavy processing centralized while still letting all ranks synchronize on the same skipped_uids/success_rate for later slashing and metrics. The ordering (overlap → offload → slashing metadata capture) looks correct.


3197-3212: get_with_retry-based debug fetch handles timeouts/invalid data safely

Switching evaluate_miner_sync to use self.comms.get_with_retry(..., timeout=60, stale_retention=10) and then treating both None and not debug_result.success as failure is a good improvement:

  • Timeouts and transient issues are absorbed as None.
  • The getattr(debug_result, "error", "unknown") usage keeps logging robust even when debug_result is None.
  • Invalid or missing debug payloads cleanly return a {"success": False, "sync_score": 0.0} result, which the slashing logic can consume.

This keeps sync evaluation resilient without raising on infrastructure glitches.
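The failure-collapsing behavior described above can be sketched as follows; field names mirror the review's description rather than the exact validator code:

```python
from types import SimpleNamespace


def evaluate_debug_fetch(debug_result) -> dict:
    """Collapse both None (timeout / transient failure) and unsuccessful
    payloads into one failure shape the slashing logic can consume."""
    if debug_result is None or not getattr(debug_result, "success", False):
        # getattr stays safe even when debug_result is None
        error = getattr(debug_result, "error", "unknown")
        return {"success": False, "sync_score": 0.0, "error": error}
    return {"success": True, "sync_score": debug_result.sync_score}


assert evaluate_debug_fetch(None)["success"] is False
ok = evaluate_debug_fetch(SimpleNamespace(success=True, sync_score=0.9))
```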


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
src/tplr/comms.py (1)

1665-1671: Early NOT_FOUND cutoff still risks violating timeout semantics

get_with_retry now tracks consecutive_not_found and exits after 5 NOT_FOUND results (~2.5s at 0.5s sleep), regardless of the caller’s timeout and even before time_max is reached. Because get() maps many transient S3/network failures (e.g., timeouts, some ClientErrors) to status="NOT_FOUND" via s3_get_object returning None, slow or flaky peers can still be treated as “likely doesn’t exist” long before the configured timeout window elapses. This is essentially the same concern raised earlier in past reviews.

A minimal change that keeps the optimization but respects time_max would be to only start the “give up early” logic once we’re sure the object should have appeared (i.e., no time_max, or we are already past time_max + grace):

-        consecutive_not_found = 0  # Track consecutive NOT_FOUND errors
-        max_consecutive_not_found = 5  # Give up after 5 consecutive NOT_FOUND (2.5s)
+        consecutive_not_found = 0  # Track consecutive NOT_FOUND errors
+        max_consecutive_not_found = 5  # Only used once we're sure gradient should exist

@@
-            # For NOT_FOUND, track consecutive failures
-            if result.status == "NOT_FOUND":
-                consecutive_not_found += 1
-                if consecutive_not_found >= max_consecutive_not_found:
-                    tplr.logger.debug(
-                        f"GET {uid}/{window}/{key} - {consecutive_not_found} consecutive NOT_FOUND. "
-                        f"Giving up early (file likely doesn't exist)."
-                    )
-                    return None
-            else:
-                # Reset counter on other errors (might be transient network issues)
-                consecutive_not_found = 0
+            # For NOT_FOUND, track consecutive failures, but only after the file
+            # should reasonably exist (no time_max, or we've passed time_max+grace).
+            if result.status == "NOT_FOUND":
+                if time_max is None or past_time_max:
+                    consecutive_not_found += 1
+                    if consecutive_not_found >= max_consecutive_not_found:
+                        tplr.logger.debug(
+                            f"GET {uid}/{window}/{key} - {consecutive_not_found} consecutive NOT_FOUND. "
+                            f"Giving up early (file likely doesn't exist)."
+                        )
+                        return None
+            else:
+                # Reset counter on other errors (might be transient network issues)
+                consecutive_not_found = 0

This way the early cutoff only triggers after we either have no time_max hint or are already beyond time_max + time_max_grace_period, which aligns better with the intent described in earlier discussion.

If you prefer to instead distinguish true 404s from transient errors inside get()/s3_get_object, that would also resolve the issue; please double-check that mapping as well.

Also applies to: 1725-1737

🧹 Nitpick comments (11)
src/tplr/comms.py (4)

667-682: Large-object branching and async torch.load are implemented correctly; consider tightening bucket handling

The 500MB size threshold and switch to download_large_file with a scaled timeout look reasonable, and offloading torch.load to asyncio.to_thread (with error logging) is a good fix for event-loop blocking.

One minor robustness nit: s3_get_object’s signature/docstring say bucket defaults to self.bucket, but the implementation assumes bucket is non-None (e.g., _get_s3_client(bucket) and bucket.name). If this method is ever called with bucket=None, it will fail. You may want to add a small guard at the top:

 async def s3_get_object(
-    self,
-    key: str,
-    bucket: Bucket | None = None,
+    self,
+    key: str,
+    bucket: Bucket | None = None,
@@
-    s3_client = await self._get_s3_client(bucket)
+    if bucket is None:
+        bucket = self.bucket
+    s3_client = await self._get_s3_client(bucket)

This keeps the behavior aligned with the docstring.

Also applies to: 709-723


2072-2257: merge_gather_results behavior is sound; minor device-choice nuance

The new distributed helpers are generally well-structured:

  • partition_uids_across_ranks uses deterministic round-robin over a sorted UID list, which is important for reproducibility.
  • merge_gather_results correctly:
    • Filters out None partials.
    • Selects a reasonable default target_device (current CUDA device when available, otherwise CPU).
    • Recursively moves tensors to target_device.
    • Concatenates lists per parameter key across partials, and aggregates metrics/UIDs/steps and skipped_uids.
    • Computes a global success_rate based on successes vs skipped.

One nuance to be aware of: when target_device is left as None, the default path will move all tensors in the merged result to the current CUDA device. In contexts where you intend to keep intermediate compressed tensors on CPU to save GPU memory, make sure callers pass target_device="cpu" explicitly (as you already do in some other paths) to avoid surprising memory pressure.
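The deterministic round-robin partition can be sketched in a few lines; the module-level dict stands in for the capped cache on the Comms instance, and the function name mirrors the one described above:

```python
def partition_uids_across_ranks(uids, window, world_size, _cache={}):
    """Round-robin over the *sorted* UID list, so every rank independently
    computes the same assignment. Cache key mirrors the review's
    (window, sorted_uids, world_size) tuple."""
    key = (window, tuple(sorted(uids)), world_size)
    if key not in _cache:
        ordered = sorted(uids)
        # rank r takes every world_size-th UID starting at offset r
        _cache[key] = [ordered[r::world_size] for r in range(world_size)]
    return _cache[key]


# Rank r downloads only its slice; the master later merges the partials.
parts = partition_uids_across_ranks([11, 3, 7, 42, 5], window=10, world_size=2)
# sorted order [3, 5, 7, 11, 42] → rank 0 gets [3, 7, 42], rank 1 gets [5, 11]
```

Sorting before slicing is the step that makes the partition reproducible regardless of the order in which UIDs arrive on each rank.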


2415-2545: Sequential vs distributed fallback messaging could be clearer

The gather_with_reserve orchestration looks logically consistent:

  • In distributed mode, all ranks participate in gather_distributed, but only master receives merged/partial results; non-master ranks explicitly return None.
  • Master-only sequential fallback uses self.gather(...) and wraps the result in a list when return_partials=True, maintaining the return-type contract.
  • Non-distributed mode always uses sequential gather and applies the same wrapping behavior.

One minor confusion: the log message

context_log(
    message=f"[gather_with_reserve] 🔄 Falling back to SEQUENTIAL gather on ALL ranks (consistency requirement)"
)

is emitted inside if primary is None and dist_helper.is_master:, but only the master actually runs the sequential gather() that follows—other ranks just see primary is None and return None. If the intent is truly master-only fallback (with subsequent DDP broadcast/coordination elsewhere), consider updating the log text to avoid implying that all ranks are performing the sequential gather. Otherwise, you’d need to invoke gather() on all ranks to match the current message.

This is cosmetic but helps future readers understand the actual behavior.


2567-2701: Partials handling in gather_with_reserve is consistent with type contract

The downstream handling of primary once it may be a list of partials is well thought out:

  • Logging aggregates total_uids and total skipped across all partials.
  • Missing slots are filled from reserve_uids without duplicating already-used UIDs.
  • Reserve fallback is appended as an extra partial when return_partials=True, and merged into the single namespace otherwise.
  • Success-rate recomputation correctly normalizes by the original gather_uids length and propagates the final success_rate to each partial.

Once gather_distributed respects return_partials=True, this machinery will give you consistent behavior between distributed and sequential modes.

src/tplr/neurons.py (3)

22-22: Outer-step refactor to _outer_step_single preserves behavior

The changes here:

  • Introduce _resolve_cuda_device to centralize CUDA-device derivation.
  • Turn outer_step into a thin wrapper that simply forwards all arguments to _outer_step_single.
  • Keep the original, memory-minimizing outer-step logic entirely inside _outer_step_single (per-param decode/decompress, DDP broadcast, optimizer.step per parameter, fingerprint accumulation).

Functionally, this keeps the previous behavior while making the outer-step surface cleaner and easier to extend. There are no new CUDA API misuses (no string devices passed to torch.cuda.* here), and the DDP logic still gates all heavy work on on_src = is_master or not ddp.

Also applies to: 48-69, 249-295, 297-551


816-883: Catch-up path enhancements are coherent and safe

In catchup_with_aggregation_server:

  • The initial memory logging helper and periodic log_memory_usage calls provide useful visibility without affecting control flow.
  • Checkpoint verification against a validator debug payload at checkpoint_current_window is guarded by instance.is_master and only proceeds when debug_fetch contains a dict, so non-master ranks aren’t impacted.
  • When an aggregated aggregator object is missing or malformed on the final window, the live comms.gather(...) fallback is only attempted once, and its result is wrapped into the same gather_ns schema expected by outer_step.
  • Skip decisions are broadcast via dist_helper.broadcast, and both pre- and post-outer-step barriers (safe_barrier) ensure all ranks stay in lock-step through the catch-up loop.
  • After each window, you aggressively clear gather_ns, run gc.collect(), and flush CUDA caches, which is appropriate for 70B-scale models in a replay loop.

Overall this is a solid and carefully synchronized catch-up flow; I don’t see correctness or ordering issues in the new logic.

Also applies to: 912-969, 985-1005, 1008-1079, 1084-1148, 1149-1168


1275-1316: check_uid_index_overlap partial support and complexity reduction are correct

The updated check_uid_index_overlap now:

  • Accepts either a merged SimpleNamespace or list[SimpleNamespace] partials.
  • For partials, flattens all UIDs into a single uids list and builds a uid_to_partial mapping from UID → (partial, index_in_partial), giving O(n) lookup instead of scanning each partial per UID.
  • Tracks valid_uids_for_param (indices into the global uids list) for each parameter, so the pair accumulator pair_acc still keys pairs by global UID indices, and final reporting (uids[i], uids[j]) remains correct.
  • Only includes peers that actually provided compressed indices for a parameter; if fewer than 2 peers have data for a given param, it skips that parameter cleanly.
  • Leaves the slashing decision, overlap aggregation, and summary metrics (mean/min/max overlaps, offenders) unchanged.

This is a nice improvement in both flexibility and complexity. One small cleanup you could consider (non-blocking): in the partials branch, all_state_dicts is populated but then immediately overridden via uid_to_partial lookups, so it can likely be removed entirely from that branch to reduce confusion.

Also applies to: 1351-1399, 1408-1424, 1431-1471

neurons/miner.py (2)

554-562: Consider guarding CUDA cleanup for CPU‑only runs

Unconditionally calling torch.cuda.empty_cache() and torch.cuda.synchronize() is fine on CUDA builds, but will raise on CPU‑only PyTorch builds. If you expect any CPU‑only miner runs, consider wrapping this block with if torch.cuda.is_available(): to avoid surprises.


740-745: Distributed gather / outer‑step semantics rely on merged result on all ranks

This section assumes that in distributed mode:

  • gather_with_reserve(..., return_partials=False, ...) returns a merged SimpleNamespace with .uids, .skipped_uids, .success_rate, and .state_dict on every rank, not just on the master.
  • should_update is determined by the master (gather_result is not None) and then broadcast via dist_helper.all_ok, while non‑masters still pass their local gather_result into self.outer_step(gather_result).

Given that, the logging, metrics (gather_success_rate, skipped_uids_list), debug info, and cleanup (state_dict.__dict__.clear()) are all consistent and safe.

If, however, gather_with_reserve only returns a non‑None result on the master in some paths, non‑master ranks would end up calling outer_step(None) and the state_dict cleanup would be a no‑op on those ranks. Please double‑check Comms.gather_with_reserve/gather_distributed to confirm that a merged result is indeed broadcast (or otherwise made available) to all ranks in distributed mode, or alternatively gate outer_step and the cleanup on self.is_master and broadcast updated model weights instead.

Also applies to: 747-756, 757-757, 778-813, 840-842, 895-902, 937-943, 1007-1007, 1024-1033

neurons/validator.py (2)

1490-1515: Distributed gather flow is coherent; consider clarifying mode flag and assumptions

This block wires up the validator’s distributed gather path:

  • use_distributed_gather = True forces all ranks into the distributed path, with gather_with_reserve(..., return_partials=False, ...) responsible for choosing distributed vs sequential and merging results.
  • Gather completion logs and downstream metrics (skipped_uids, success_rate, actual_gather_uids, reserve usage, quality stats) all treat gather_result as a merged SimpleNamespace, not a list of partials.
  • skip_window is decided on the master based on gather_result is None or exceptions and then broadcast, keeping all ranks aligned.

Functionally this looks correct assuming:

  1. gather_with_reserve always returns a merged result (or None) rather than partials for this call site, and
  2. In distributed mode, that merged result is available on all ranks (or outer_step/norm code gracefully handles None on non‑masters).

Two small suggestions:

  • Rename use_distributed_gather or derive it from dist_helper.is_distributed() again, so the variable name matches its intent (currently it’s always True, even in effectively sequential runs).
  • Optionally add an assertion or comment at this call site documenting that this path is “merged‑result only” and will break if return_partials=True is ever reintroduced here.

Also applies to: 1516-1531, 1543-1560, 1561-1590, 1606-1643, 1667-1705, 1721-1733


1821-1824: compute_peer_val_norms type hint/docstring are broader than the implementation

The signature and nearby comment say gather_result: SimpleNamespace | list[SimpleNamespace] and “handles both partials and merged results”, but the implementation assumes:

  • gather_result has a .state_dict attribute, and
  • state_dict exposes attributes like "{param}vals" / "{param}quant_params" (i.e., SimpleNamespace layout), not a list of partials.

Given the current call site (compute_peer_val_norms(gather_result) right after gather_with_reserve(..., return_partials=False, ...)), this works because gather_result is always merged. However, the type hint and docstring are misleading and could encourage incorrect future use.

Consider either:

  • Narrowing the annotation and docstring to gather_result: SimpleNamespace and “merged result only”, or
  • Re‑introducing a partials‑aware implementation (e.g., aggregate per‑param vals/quant_params across a list of namespaces, with proper alignment and dequantization) if you intend to support partials here again.

Also applies to: 3791-3802

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8e9b8c7 and 20fec9f.

📒 Files selected for processing (7)
  • neurons/base_node.py (1 hunks)
  • neurons/miner.py (9 hunks)
  • neurons/trainer.py (3 hunks)
  • neurons/validator.py (12 hunks)
  • src/tplr/comms.py (19 hunks)
  • src/tplr/compress.py (1 hunks)
  • src/tplr/neurons.py (10 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • neurons/trainer.py
  • src/tplr/compress.py
🧰 Additional context used
🧬 Code graph analysis (4)
src/tplr/neurons.py (1)
src/tplr/compress.py (6)
  • compress (373-377)
  • compress (379-383)
  • compress (386-423)
  • ChunkingTransformer (145-292)
  • TopKCompressor (295-697)
  • unpack_12bit_indices (102-142)
neurons/miner.py (3)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/distributed.py (1)
  • is_distributed (75-77)
src/tplr/comms.py (1)
  • gather_with_reserve (2408-2701)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
  • Bucket (24-46)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • gather_object (284-305)
neurons/validator.py (6)
tests/test_state_loading.py (1)
  • rating (59-62)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • broadcast (273-282)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/comms.py (2)
  • gather_with_reserve (2408-2701)
  • get_with_retry (1625-1740)
src/tplr/neurons.py (1)
  • check_uid_index_overlap (1275-1480)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.12)
  • GitHub Check: test (3.11)
🔇 Additional comments (12)
neurons/base_node.py (1)

61-62: super().__init__() at the end of BaseNode.__init__() is correct for the cooperative MRO

The placement of super().__init__() at the end of BaseNode.__init__() is correct and intentional for the cooperative multiple inheritance pattern used here.

Verification confirms:

  • Trainer.__init__() only initializes one attribute (inner_scheduler_step_count) and has no dependencies on BaseNode attributes
  • Trainer methods assume BaseNode attributes (uid, rank, world_size, hparams, etc.) are already initialized
  • For both Miner(BaseNode, Trainer) and Validator(BaseNode, Trainer), the MRO places BaseNode first, followed by Trainer
  • The end-of-initialization call to super().__init__() ensures BaseNode completes its initialization before delegating to Trainer

This is a well-designed cooperative multiple inheritance pattern where parent-class initialization properly follows the method resolution order. No changes are needed.

src/tplr/comms.py (5)

20-21: Partition cache introduction looks good

Using OrderedDict with a hard cap (32 entries) keyed by (window, sorted_uids, world_size) is a sensible way to reuse UID partitions without unbounded growth. No functional or concurrency issues spotted here.

Also applies to: 29-29, 142-144
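A bounded cache of this shape can be sketched with `OrderedDict`; the class and method names are illustrative, and the recency refresh shown here may differ from the PR's eviction policy (a plain FIFO cap would also satisfy the 32-entry bound):

```python
from collections import OrderedDict


class PartitionCache:
    """Bounded cache keyed by (window, sorted_uids, world_size); evicts the
    oldest entry once the cap is exceeded."""

    def __init__(self, max_entries: int = 32):
        self._cache: OrderedDict = OrderedDict()
        self._max = max_entries

    def get_or_compute(self, key, compute):
        if key in self._cache:
            self._cache.move_to_end(key)  # refresh recency on hit
            return self._cache[key]
        value = compute()
        self._cache[key] = value
        if len(self._cache) > self._max:
            self._cache.popitem(last=False)  # drop the oldest entry
        return value


cache = PartitionCache(max_entries=2)
cache.get_or_compute(("w1",), lambda: 1)
cache.get_or_compute(("w2",), lambda: 2)
cache.get_or_compute(("w3",), lambda: 3)  # evicts ("w1",)
```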


483-582: HEAD-based s3_object_exists is robust and conservative

The new s3_object_exists helper correctly:

  • Uses head_object with a timeout.
  • Distinguishes 404/not-found and config/permission errors (404/403, NoSuchBucket, invalid bucket).
  • Treats timeouts and 5xx/network issues as transient and returns True to avoid false slashing.

This is a solid, structured improvement over string-based error checks and matches the intended conservative semantics.


1563-1566: Local-path torch.load offload is appropriate

Running the local torch.load via asyncio.to_thread in the local=True code path mirrors the remote path change and avoids blocking the event loop during deserialization. The subsequent handling of state_dict and global_step is unchanged, so this is a safe performance/latency improvement.
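The offload pattern looks like this; `blocking_deserialize` is a stand-in for `torch.load(...)` so the sketch stays runnable without torch:

```python
import asyncio
import time


def blocking_deserialize(path: str) -> dict:
    """Stand-in for torch.load(...): CPU-bound work that would otherwise
    block the event loop for the whole deserialization."""
    time.sleep(0.05)
    return {"state_dict": {}, "global_step": 7, "path": path}


async def load_checkpoint(path: str) -> dict:
    # Run the blocking call in a worker thread; the event loop stays free to
    # service other coroutines (heartbeats, concurrent downloads) meanwhile.
    return await asyncio.to_thread(blocking_deserialize, path)


result = asyncio.run(load_checkpoint("ckpt.pt"))
```

`asyncio.to_thread` awaits the threaded call and re-raises any exception in the coroutine, which is why the error logging around the real `torch.load` still works unchanged.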


1472-1493: gradient_timestamp version parameterization is correct

Allowing an optional version override and defaulting to tplr.__version__ fixes the earlier undefined self.version issue and aligns with how other paths construct gradient keys. The rest of the logic (using commitments, head_object, and returning 0.0 on error) remains intact.


1844-2006: Gather response processing and validation look consistent

The updated response-processing loop in gather():

  • Logs per-UID download exceptions with log_with_context (warning level).
  • Treats None responses from get_with_retry as “skipped” UIDs with an info log.
  • Adds tighter validation for quantization params and NaN/Inf checks using a temporary tensor_to_check on the target device and cleans it up immediately.

These changes preserve the earlier semantics (only valid, fully-formed compressed payloads contribute to aggregation) while improving observability and keeping GPU temporaries short-lived. No correctness issues spotted here.

neurons/miner.py (2)

32-35: CUDA allocator config before torch import looks correct

Setting PYTORCH_CUDA_ALLOC_CONF before importing torch is the right place to influence the allocator; the conditional guard also preserves caller‑provided settings. No changes needed.


474-493: Peer / reserve / commitment broadcast keeps all ranks consistent

Broadcasting [peers, reserve_peers, commitments] from the master to all ranks aligns miner behavior with the validator and with gather_with_reserve semantics; the logging on non‑master ranks is useful for debugging. This block looks sound.

neurons/validator.py (4)

931-937: OpenSkill rating initialization before rate() is correct

Ensuring every UID in window_uids has an entry in self.openskill_ratings before building teams avoids model errors and keeps rating updates stable across windows. This change looks good.


1337-1339: Commitment refresh + peers/reserve broadcast matches miner and comms usage

Refreshing self.comms.commitments on the master after update_peers, then broadcasting [peers, reserve_peers, commitments] to all ranks, keeps gather configuration aligned across the validator group and mirrors the miner’s behavior. Logging on non‑master ranks provides helpful visibility. No issues here.

Also applies to: 1350-1369


1411-1437: Timestamp broadcast for gather time window fixes rank divergence

Having only the master call query_block_timestamp(), falling back to time.time() if needed, and then broadcasting ts_value to all ranks ensures time_min/time_max are identical across the distributed gather. This addresses the earlier risk of per‑rank time‑window skew. Looks good.

Also applies to: 1441-1454


3197-3205: Robust handling of missing validator debug dict

Switching to self.comms.get_with_retry(...) and explicitly handling None / unsuccessful results gives clearer behavior when the miner’s debug payload is missing or unreachable and avoids attribute errors. The warning log that includes getattr(debug_result, "error", "unknown") is also safe for the None case. This looks correct.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
neurons/trainer.py (1)

872-872: Remove duplicate gradient clipping before unscaling.

This line clips gradients before scaler.unscale_() is called (line 886), which is incorrect for mixed precision training. When using GradScaler, gradients must be unscaled before clipping; otherwise, you're clipping scaled gradients rather than true gradients. Additionally, lines 887-889 correctly clip gradients after unscaling, making this line redundant.

Impact: Gradients are clipped twice in the normal training path, which incorrectly reduces gradient magnitudes and affects convergence.

Apply this diff to remove the duplicate clipping:

-                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
-
                         if not null_round:
src/tplr/comms.py (1)

1665-1736: Early NOT_FOUND cutoff may still violate timeout semantics for transient errors.

The past review comment correctly identified that max_consecutive_not_found = 5 (≈2.5s) causes early exit independent of the timeout parameter (default 30s). The current implementation still has this issue:

  1. Line 1600 converts all None returns from s3_get_object to status="NOT_FOUND" without distinguishing the cause
  2. s3_get_object returns None for both true 404s (line 658) AND asyncio timeouts (line 653)
  3. The early cutoff at 2.5s will abort downloads that are experiencing slow/flaky connections but should continue retrying until the caller's timeout expires

Suggested fix: Either:

  • Distinguish true 404s from transient errors in the status (e.g., return "TIMEOUT" vs "NOT_FOUND")
  • Only increment consecutive_not_found for confirmed 404 responses
  • Derive the cutoff threshold from the timeout parameter: max(5, int(timeout / 1.0))

This ensures transient network issues respect the caller's timeout window while still optimizing away persistent 404s.

🧹 Nitpick comments (4)
src/tplr/neurons.py (1)

332-332: Optional cleanup: Unused device resolution.

The cuda_device variable is computed but never used in the function body. The string device parameter is used throughout instead.

If CUDA device resolution isn't needed here, consider removing this line:

-    cuda_device = _resolve_cuda_device(device)

If it was intended for future use (e.g., for CUDA API calls), consider adding a comment explaining why it's computed but unused.

src/tplr/comms.py (1)

2156-2171: Device fallback logic handles edge cases but has a subtle issue.

The device selection correctly handles RuntimeError from torch.cuda.current_device() and falls back gracefully. However, lines 2165-2169 attempt to set device 0, which could fail if device 0 is unavailable:

try:
    torch.cuda.set_device(0)
    target_device = "cuda:0"
except RuntimeError:
    target_device = "cpu"

This is a minor issue since it eventually falls back to CPU, but it could be cleaner to probe for the first available device index rather than assuming device 0 exists. That said, the current implementation is safe and will not crash.

neurons/validator.py (2)

1490-1490: Auto-detect distributed gather instead of hardcoding.

Line 1490 unconditionally sets use_distributed_gather = True, but the PR objectives specify "auto-detection to use distributed gather when world_size > 1." This reduces clarity and doesn't match the stated behavior.

Apply this diff to auto-detect based on actual distributed state:

-            use_distributed_gather = True
+            use_distributed_gather = dist_helper.world_size > 1 and dist_helper.is_distributed()

This aligns with the PR objectives and makes the gather mode selection explicit and conditional on actual multi-GPU setup.


3790-3826: Function signature claims to accept list but implementation doesn't handle it.

The signature on line 3792 declares gather_result: SimpleNamespace | list[SimpleNamespace], but the implementation at line 3805 directly accesses gather_result.state_dict without checking if it's a list first. If a list were passed, this would raise AttributeError.

While this isn't a runtime bug (line 1520 ensures return_partials=False so only merged results are passed), it violates the interface contract and could mislead future developers.

Choose one of these options:

Option 1: Remove list from signature (simpler, matches actual usage):

     def compute_peer_val_norms(
         self,
-        gather_result: SimpleNamespace | list[SimpleNamespace],
+        gather_result: SimpleNamespace,
     ) -> dict[str, torch.Tensor]:
         """
         Compute median L2 norms across all peers for each parameter.
 
         Args:
-            gather_result: Merged gather result SimpleNamespace
+            gather_result: Merged gather result (must be SimpleNamespace, not partials list)

Option 2: Add list handling (if partials support is needed):

     def compute_peer_val_norms(
         self,
         gather_result: SimpleNamespace | list[SimpleNamespace],
     ) -> dict[str, torch.Tensor]:
         """
         Compute median L2 norms across all peers for each parameter.
 
         Args:
-            gather_result: Merged gather result SimpleNamespace
+            gather_result: Either merged result (SimpleNamespace) or list of partials
 
         Returns:
             Dictionary mapping parameter names to median norms
         """
         clip_norm_dict = {}
 
+        # If partials, merge first
+        if isinstance(gather_result, list):
+            gather_result = self.comms.merge_gather_results(gather_result)
+
         state_dict = gather_result.state_dict

Option 1 is recommended since the code always uses merged results.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 20fec9f and 66f8365.

📒 Files selected for processing (7)
  • neurons/base_node.py (1 hunks)
  • neurons/miner.py (9 hunks)
  • neurons/trainer.py (3 hunks)
  • neurons/validator.py (12 hunks)
  • src/tplr/comms.py (19 hunks)
  • src/tplr/compress.py (1 hunks)
  • src/tplr/neurons.py (10 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • src/tplr/compress.py
  • neurons/base_node.py
🧰 Additional context used
🧬 Code graph analysis (4)
neurons/miner.py (3)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/distributed.py (1)
  • is_distributed (75-77)
src/tplr/comms.py (1)
  • gather_with_reserve (2395-2688)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
  • Bucket (24-46)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • gather_object (284-305)
neurons/validator.py (6)
tests/test_state_loading.py (1)
  • rating (59-62)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • broadcast (273-282)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/comms.py (2)
  • gather_with_reserve (2395-2688)
  • get_with_retry (1625-1740)
src/tplr/neurons.py (1)
  • check_uid_index_overlap (1275-1480)
src/tplr/neurons.py (1)
src/tplr/compress.py (4)
  • compress (373-377)
  • compress (379-383)
  • compress (386-423)
  • unpack_12bit_indices (102-142)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.11)
  • GitHub Check: test (3.12)
🔇 Additional comments (22)
neurons/trainer.py (3)

887-889: LGTM: Correct gradient clipping and norm capture.

This correctly captures the gradient norm after unscaling and before the optimizer step, following the proper pattern for mixed precision training with GradScaler.
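The pattern the review refers to can be sketched as follows. The model, optimizer, and loss here are illustrative stand-ins, not the PR's trainer code; with CUDA unavailable, `GradScaler(enabled=False)` degrades to a no-op passthrough, so the ordering still holds:

```python
import torch

# Mixed-precision step order: backward -> unscale_ -> capture/clip norm -> step.
model = torch.nn.Linear(8, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

loss = model(torch.randn(2, 8)).sum()
scaler.scale(loss).backward()

scaler.unscale_(optimizer)  # gradients are now in true (unscaled) units
total_grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

scaler.step(optimizer)  # skips the step if inf/nan gradients were detected
scaler.update()
print(float(total_grad_norm) >= 0.0)
```

Capturing the norm after `unscale_` but before `step` is what makes the logged value meaningful: before unscaling it would include the loss-scale factor.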


911-911: LGTM: Appropriate default for spin-up phase.

Setting total_grad_norm to 0.0 during the spin-up phase (when no optimizer step is taken) provides a consistent value for logging and is appropriate for this branch.


928-928: LGTM: Useful diagnostic logging.

Adding gradient norm to the inner step logging provides valuable insight into training dynamics and helps monitor for gradient explosion or vanishing issues.

src/tplr/neurons.py (5)

48-69: LGTM: Well-designed device resolution helper.

The function properly handles all edge cases (None device, string conversion, no CUDA devices) with appropriate fallbacks. The defensive error handling ensures robustness.


280-294: LGTM: Clean delegation pattern.

The function correctly delegates to _outer_step_single() with all parameters forwarded, maintaining a clean separation of concerns.


1277-1315: LGTM: Well-structured partial vs. merged handling.

The function properly extracts UIDs and state_dicts from both partial and merged gather results. The uid_to_partial mapping at line 1296 provides O(1) lookup, successfully addressing the O(n×m) complexity concern from previous reviews.


1354-1395: LGTM: Correct partial/merged indexing with defensive checks.

The code properly handles both partial and merged cases:

  • Partials use the O(1) uid_to_partial lookup (lines 1361-1365)
  • Merged results use direct list indexing (lines 1381-1388)
  • Defensive continue at line 1362 handles missing UIDs gracefully

The optimization successfully reduces complexity from O(n×m) to O(n).


1408-1424: LGTM: Correct pair key construction.

The code properly uses valid_uids_for_param to map back to original UID indices (lines 1420-1422), ensuring pair keys remain consistent whether using partials or merged results.

neurons/miner.py (7)

32-34: LGTM: Correct CUDA allocator configuration.

The environment variable is properly set before importing torch (required for it to take effect). The conditional check prevents overwriting any existing configuration.


474-492: LGTM: Complete peer synchronization across ranks.

The broadcast properly synchronizes peers, reserve_peers, and commitments from master to all ranks. This ensures consistent gather behavior across distributed ranks and addresses the previous review comment about broadcasting reserve peers.

The payload structure (3-item list) and unpacking logic are correct.


554-561: LGTM: Thorough pre-training memory cleanup.

The aggressive cleanup (cache clear, synchronize, gradient clear) helps prevent OOM at the start of training. The defensive check at line 560 safely handles the model attribute.


740-816: LGTM: Correct distributed/sequential gather orchestration.

The gather mode selection logic properly handles both scenarios:

  1. Distributed mode (lines 741-743, 759):

    • All ranks participate in collective operations
    • Master checks result, non-master ranks defer (lines 806-810)
  2. Sequential mode (lines 752-755, 759):

    • Only master performs gather
  3. Consistent configuration (lines 763-774):

    • return_partials=False: Always returns merged result, aligning with outer_step expectations
    • reserve_uids=self.comms.reserve_peers: Matches validator behavior
    • local=False: Correctly fetches from remote storage

The final broadcast at line 815 ensures all ranks agree on whether to update.


840-842: LGTM: Proper post-outer-step gradient cleanup.

The defensive check and set_to_none=True ensure memory is efficiently released after the outer step.


895-901: LGTM: Correct metrics extraction from merged gather result.

Since return_partials=False is set at line 764, gather_result is always a merged SimpleNamespace or None (never a list). The attribute accesses (.uids, .skipped_uids, .success_rate) are correct for this format, and the defensive None check at lines 938-943 prevents errors when no gradients are gathered.

Also applies to: 937-943


1024-1035: LGTM: Comprehensive post-window memory cleanup.

The cleanup thoroughly releases gather result memory (lines 1028-1030) with defensive attribute checks, followed by garbage collection and CUDA cache clearing. This aggressive memory management helps prevent OOM issues across windows, particularly important for 70B-model scenarios mentioned in the PR objectives.

src/tplr/comms.py (7)

142-144: Verify the partition cache key includes all relevant factors.

The cache key (window, tuple(sorted_uids), world_size) correctly captures the inputs that determine partitioning. However, consider whether my_uid or rank should also be part of the key if they affect partition assignment in the future. Currently, the cache stores all partitions for all ranks (as a list), so this is correct.


483-582: LGTM! Conservative error handling prevents false slashing.

The structured error inspection correctly distinguishes between confirmed 404s (returns False), permission/configuration errors (returns False with ERROR logging to surface setup issues), and transient network/service errors (returns True conservatively). This approach aligns with the goal of avoiding false slashing due to temporary network issues.

The fallback to string matching provides backward compatibility while the structured checks take precedence.
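The classification described above can be sketched roughly as below. The error-code sets and function name are assumptions for illustration; the real implementation inspects `botocore.exceptions.ClientError.response`, which has the same `{"Error": {"Code": ...}}` shape:

```python
# Confirmed 404 -> False; permission/config error -> False (logged loudly);
# anything else is treated as transient -> True, to avoid false slashing.
NOT_FOUND_CODES = {"404", "NoSuchKey", "NoSuchBucket"}
PERMISSION_CODES = {"403", "AccessDenied", "InvalidAccessKeyId", "SignatureDoesNotMatch"}

def object_likely_exists(error_response: dict) -> bool:
    code = str(error_response.get("Error", {}).get("Code", ""))
    if code in NOT_FOUND_CODES:
        return False  # confirmed missing
    if code in PERMISSION_CODES:
        print(f"ERROR: bucket configuration problem: {code}")
        return False  # surface setup issues rather than retrying forever
    return True  # transient network/service error: assume the object exists

print(object_likely_exists({"Error": {"Code": "NoSuchKey"}}))  # False
print(object_likely_exists({"Error": {"Code": "SlowDown"}}))   # True
```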


709-723: Critical fix: Offloading torch.load enables true parallel downloads.

The asyncio.to_thread wrapper correctly moves CPU-bound decompression off the event loop, allowing parallel gradient downloads to proceed concurrently instead of being serialized during decompression. The 1-2s decompression time per file would otherwise be a major bottleneck.

The exception handling at lines 719-723 properly catches and logs load errors without crashing the download process.
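The effect of the `asyncio.to_thread` wrapper can be demonstrated with a stand-in for the blocking `torch.load` call (the function and timings below are illustrative, not from the PR):

```python
import asyncio
import time

def blocking_load(path):
    time.sleep(0.1)  # stands in for CPU-bound torch.load decompression
    return {"path": path}

async def load_many(paths):
    # asyncio.to_thread runs each load in the default thread pool, so the
    # event loop stays free to drive the other downloads concurrently.
    return await asyncio.gather(
        *(asyncio.to_thread(blocking_load, p) for p in paths)
    )

start = time.monotonic()
results = asyncio.run(load_many(["a.pt", "b.pt", "c.pt"]))
elapsed = time.monotonic() - start
print(len(results))  # 3
```

Three 0.1s loads complete in roughly 0.1s total rather than 0.3s; calling `blocking_load` directly inside the coroutine would serialize them on the event loop.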


1471-1492: LGTM! Clean addition of optional version parameter.

The version parameter defaults to None and falls back to tplr.__version__, preserving backward compatibility. The implementation is straightforward and allows querying gradient timestamps across different versions when needed.


2072-2095: LGTM! Deterministic round-robin partitioning ensures consistency.

The partition_uids_across_ranks method correctly:

  • Sorts UIDs for deterministic ordering across all ranks
  • Uses round-robin assignment: rank i gets indices i, i+world_size, i+2*world_size, ...
  • Returns all UIDs when world_size <= 1 (single-rank case)

This ensures all ranks see the same partition mapping and no UID is assigned to multiple ranks.
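The partitioning scheme reduces to one slice expression. This is a sketch inferred from the description above, not the actual method body:

```python
def partition_uids_across_ranks(uids, rank, world_size):
    """Deterministic round-robin partition: rank i gets indices i, i+ws, i+2*ws, ..."""
    if world_size <= 1:
        return sorted(uids)  # single-rank case: one rank handles everything
    sorted_uids = sorted(uids)  # identical ordering on every rank
    return sorted_uids[rank::world_size]

parts = [partition_uids_across_ranks(range(10), r, 3) for r in range(3)]
print(parts)  # [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
```

Because every rank sorts the same UID set, the partitions are disjoint and cover all UIDs without any coordination.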


2500-2551: Verify type consistency when wrapping sequential fallback results.

The code wraps sequential gather results in a list when return_partials=True:

  • Line 2519: primary = [sequential_result] if return_partials else sequential_result
  • Line 2531: primary = [sequential_result] if return_partials else sequential_result
  • Line 2551: primary = [empty_result] if return_partials else empty_result

However, there's a potential issue: when the distributed gather succeeds but returns a merged result instead of partials (due to the bug in gather_distributed), the caller will receive a SimpleNamespace even though they requested return_partials=True. This inconsistency could cause type errors downstream.

Once the gather_distributed bug is fixed (see previous comment), verify that all code paths consistently return the correct type based on return_partials.


2584-2643: Reserve peer logic correctly handles both partials and merged results.

The reserve peer retry logic properly branches based on whether primary is a list (partials) or SimpleNamespace (merged):

  • Partials path (lines 2614-2626): Appends the fallback as a new partial entry, assigns it a rank index, and logs appropriately
  • Merged path (lines 2627-2643): Merges the fallback's state_dict tensors, UIDs, and metrics into the primary result

This dual handling ensures the reserve mechanism works correctly regardless of the return format.

@coderabbitai coderabbitai bot left a comment
Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/tplr/comms.py (1)

1665-1740: Early NOT_FOUND cutoff still violates timeout semantics.

The consecutive NOT_FOUND counter (lines 1669-1670, 1725-1736) exits after 5 occurrences (~2.5s), but get() at line 1600 returns status="NOT_FOUND" for ALL None returns from s3_get_object, including transient asyncio.TimeoutError (line 653). This causes premature exit for slow/flaky peers long before the caller's timeout (default 30s) expires, breaking the intended retry semantics.

Fix: Either:

  1. Make s3_get_object distinguish true 404s from transient errors in its return value, or
  2. Derive max_consecutive_not_found from timeout (e.g., int(timeout / 0.5)) so it scales with the caller's intent, or
  3. Only apply the early exit after time_max has been exceeded (when we know the file should exist)

Based on learnings
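Option 2 can be sketched as follows; names and the polling shape are assumptions, but it shows how deriving the cutoff from `timeout` keeps the retry budget consistent with the caller's intent:

```python
import time

def get_with_retry(fetch, timeout, poll_interval=0.5):
    """Retry until deadline; give up early only after ~timeout worth of NOT_FOUNDs."""
    deadline = time.monotonic() + timeout
    max_consecutive_not_found = max(1, int(timeout / poll_interval))
    consecutive = 0
    while time.monotonic() < deadline:
        result, status = fetch()
        if status == "OK":
            return result
        if status == "NOT_FOUND":
            consecutive += 1
            if consecutive >= max_consecutive_not_found:
                return None  # scaled cutoff, not a fixed 5-attempt limit
        else:
            consecutive = 0  # a transient error resets the counter
        time.sleep(poll_interval)
    return None
```

With the fixed cutoff of 5, a 30s `timeout` was effectively capped at ~2.5s; here the cutoff only fires once the full budget has been consumed by misses.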

🧹 Nitpick comments (3)
neurons/validator.py (2)

1490-1506: Hardcoded flag contradicts auto-detection intent.

Line 1490 sets use_distributed_gather = True unconditionally, but the PR objectives state "Auto-detection enables distributed gather when world_size > 1." The actual distributed vs. sequential decision happens inside gather_with_reserve (src/tplr/comms.py:2424 checks dist_helper.world_size > 1), so the runtime behavior is correct. However, the hardcoded flag here makes the logging and control flow misleading.

Consider aligning with the actual auto-detection:

-            use_distributed_gather = True
+            use_distributed_gather = dist_helper.world_size > 1 and dist_helper.is_distributed()

This makes the logging accurately reflect whether distributed or sequential gather will actually be used.


3790-3826: Type signature broader than implementation.

The function signature declares gather_result: SimpleNamespace | list[SimpleNamespace], but the implementation at line 3805 directly accesses gather_result.state_dict, which only works for SimpleNamespace, not list[SimpleNamespace]. There's no isinstance() check or branching to handle the list case.

In practice, this isn't a runtime bug because gather_result is always a merged SimpleNamespace (see line 1520: return_partials=False with comment "Always return merged result"). However, the type hint misleadingly suggests the function can handle both forms.

Consider narrowing the type hint to match actual usage:

     def compute_peer_val_norms(
         self,
-        gather_result: SimpleNamespace | list[SimpleNamespace],
+        gather_result: SimpleNamespace,
     ) -> dict[str, torch.Tensor]:

Alternatively, if future flexibility is desired, add explicit handling:

if isinstance(gather_result, list):
    raise NotImplementedError("Partials path not needed since return_partials=False")
src/tplr/comms.py (1)

2133-2257: LGTM: Comprehensive merge with robust device handling.

The merge correctly:

  • Concatenates per-rank tensor lists for each parameter
  • Aggregates metrics (max time for parallel, sum bytes, extend UIDs)
  • Calculates overall success rate

Device fallback logic gracefully handles edge cases with nested try/except, though it assumes device 0 is available when multiple devices exist. For systems where device 0 is missing but others exist, this will fall back to CPU, which is safe but suboptimal. Consider iterating through range(torch.cuda.device_count()) to find the first working device if this becomes an issue.
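The merge semantics described above (concatenate per-parameter lists, take the max time across parallel ranks, sum bytes, extend UIDs) can be sketched with plain lists in place of tensors; all field names here are assumptions, and device handling is omitted:

```python
from types import SimpleNamespace

def merge_gather_results(partials):
    merged_state, uids = {}, []
    download_time, total_bytes, attempted = 0.0, 0, 0
    for p in partials:
        for name, values in vars(p.state_dict).items():
            merged_state.setdefault(name, []).extend(values)
        uids.extend(p.uids)
        download_time = max(download_time, p.download_time)  # ranks ran in parallel
        total_bytes += p.total_bytes
        attempted += p.attempted
    return SimpleNamespace(
        state_dict=SimpleNamespace(**merged_state),
        uids=uids,
        download_time=download_time,
        total_bytes=total_bytes,
        success_rate=len(uids) / max(attempted, 1),
    )

p1 = SimpleNamespace(state_dict=SimpleNamespace(wvals=[[1]]), uids=[1],
                     download_time=2.0, total_bytes=10, attempted=2)
p2 = SimpleNamespace(state_dict=SimpleNamespace(wvals=[[2]]), uids=[2, 3],
                     download_time=3.0, total_bytes=5, attempted=2)
merged = merge_gather_results([p1, p2])
print(merged.uids)  # [1, 2, 3]
```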

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 66f8365 and fdce702.

📒 Files selected for processing (7)
  • neurons/base_node.py (1 hunks)
  • neurons/miner.py (9 hunks)
  • neurons/trainer.py (3 hunks)
  • neurons/validator.py (12 hunks)
  • src/tplr/comms.py (19 hunks)
  • src/tplr/compress.py (1 hunks)
  • src/tplr/neurons.py (10 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • neurons/trainer.py
  • src/tplr/compress.py
🧰 Additional context used
🧬 Code graph analysis (4)
src/tplr/neurons.py (1)
src/tplr/compress.py (5)
  • compress (373-377)
  • compress (379-383)
  • compress (386-423)
  • TopKCompressor (295-697)
  • unpack_12bit_indices (102-142)
neurons/miner.py (3)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/distributed.py (1)
  • is_distributed (75-77)
src/tplr/comms.py (1)
  • gather_with_reserve (2401-2694)
neurons/validator.py (5)
tests/test_state_loading.py (1)
  • rating (59-62)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • broadcast (273-282)
src/tplr/comms.py (2)
  • gather_with_reserve (2401-2694)
  • get_with_retry (1625-1740)
src/tplr/neurons.py (1)
  • check_uid_index_overlap (1275-1480)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
  • Bucket (24-46)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • gather_object (284-305)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.12)
  • GitHub Check: test (3.11)
🔇 Additional comments (28)
neurons/base_node.py (1)

61-62: LGTM! MRO chain properly initialized.

The super().__init__() call correctly continues the Method Resolution Order chain, ensuring parent classes (e.g., Trainer) are properly initialized. This resolves the previously flagged issue where expected_compressed_params and other parent class attributes were not being initialized.

src/tplr/neurons.py (4)

48-69: LGTM! Robust CUDA device resolution.

The _resolve_cuda_device helper properly handles all edge cases: CUDA unavailability, None inputs, invalid device strings, and type coercion. This addresses the previously flagged issue where CUDA APIs were called with string devices instead of torch.device objects.


250-294: LGTM! Simplified outer_step delegation.

The refactor simplifies outer_step to pure delegation while maintaining the single-path behavior requested in previous reviews. The updated docstring clearly documents that distributed gather merges partials internally, so callers always receive a single merged result.


297-551: LGTM! Device resolution properly integrated.

The _outer_step_single implementation correctly uses _resolve_cuda_device at line 332, ensuring CUDA operations receive proper torch.device objects instead of strings. The rest of the memory-minimizing outer step logic remains sound.


1275-1480: LGTM! Efficient dual-path overlap checking.

The updated check_uid_index_overlap correctly supports both merged results and partial results. The UID-to-partial mapping (lines 1296-1306) provides O(1) lookup, reducing complexity from O(n×m) to O(n) as requested in previous reviews. The unpacking logic (lines 1358-1388) properly handles both paths, and pair keys use original UID indices for consistency.

neurons/miner.py (9)

32-34: LGTM! Early CUDA allocator configuration.

Setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True before importing torch is the correct approach for memory optimization. The conditional check respects existing user environment settings.


474-492: LGTM! Complete peer state synchronization.

The broadcast now correctly includes peers, reserve_peers, and commitments, ensuring all ranks have consistent gather state. This addresses the previous review comment about missing reserve_peers in the broadcast payload.


554-561: LGTM! Proactive memory management.

The aggressive memory cleanup (cache clearing, synchronization, and gradient zeroing) before training helps prevent OOM issues, especially important for distributed training with large models mentioned in the PR objectives (70B-model scenarios).


740-812: LGTM! Distributed gather properly integrated.

The distributed gather implementation correctly:

  • Auto-detects when to use distributed mode based on world_size and dist_helper.is_distributed() (lines 741-743)
  • Includes reserve_uids=self.comms.reserve_peers as requested in previous reviews (line 763)
  • Sets local=False to pull from remote storage (line 769), fixing the previously flagged issue
  • Always returns merged results (return_partials=False at line 764), consistent with the outer_step docstring
  • Provides comprehensive logging for both distributed and sequential paths (lines 778-802)
  • Correctly handles master/non-master decision logic for should_update (lines 806-812)

840-842: LGTM! Explicit gradient cleanup after outer step.

The gradient clearing with set_to_none=True after the outer step ensures no stale gradients remain, and is more memory-efficient than zeroing. The defensive hasattr check is good practice.


895-901: LGTM! Correct attribute access for merged results.

Direct access to gather_result.uids and gather_result.skipped_uids is correct since gather_with_reserve is called with return_partials=False at line 764, ensuring a merged SimpleNamespace is always returned. This addresses the previously flagged incompatible attribute access issue.


937-943: LGTM! Safe metric extraction with None handling.

The success rate and skipped UIDs extraction properly handles the None case and correctly accesses attributes from the merged result. The success_rate * 100 conversion to percentage is appropriate for logging.


1024-1033: LGTM! Aggressive memory cleanup strategy.

The enhanced cleanup explicitly clears state_dict contents (lines 1027-1030), deletes objects, forces garbage collection, and empties CUDA cache. This aggressive approach aligns with the PR objective of avoiding GPU OOM in 70B-model scenarios. The defensive hasattr checks prevent errors if structures are already cleaned up.


774-774: expected_compressed_params is properly initialized in parent class.

The verification confirms that Trainer.__init__() at line 123 sets self.expected_compressed_params = self.get_expected_params(). With the super().__init__() fix in BaseNode, the MRO chain ensures this attribute is initialized before line 774 accesses it. No action needed.

neurons/validator.py (5)

931-937: LGTM! OpenSkill rating initialization prevents KeyError.

Properly initializes ratings for new UIDs before the rating operation. This ensures all peers in window_uids have valid OpenSkill rating objects before self.openskill_model.rate() is called.


1337-1368: LGTM! Peers and commitments broadcast enables distributed gather across ranks.

The master rank fetches updated commitments and broadcasts peer lists plus bucket commitments to all ranks, enabling them to participate in distributed gather. The second commitment fetch at line 1338 (after the initial fetch at line 1171) ensures fresh bucket information immediately before gathering.


1411-1438: LGTM! Timestamp broadcast ensures all ranks use consistent time window.

Correctly addresses the previous critical review concern about potential rank inconsistency. The master rank queries the timestamp (with fallback to time.time() on failure), then broadcasts it to all ranks via dist_helper.broadcast(). This guarantees all ranks use identical time_min and time_max values for the distributed gather operation.


1607-1643: LGTM! Simplified gather result handling eliminates duplicate merges.

Since return_partials=False (line 1520), gather_result is always a merged SimpleNamespace. The code correctly uses it directly for upload (line 1608), overlap checking (line 1614), and metric extraction (lines 1621, 1642) without redundant merging operations. This addresses the memory overhead concerns from previous reviews.


3197-3217: LGTM! get_with_retry improves reliability with better timeout and retry logic.

The change from self.comms.get() to self.comms.get_with_retry() adds robust retry mechanisms including exponential backoff, consecutive NOT_FOUND tracking, and graceful timeout handling (returns None instead of raising). The error handling correctly checks for both None (timeout) and unsuccessful results before proceeding.

src/tplr/comms.py (9)

20-20: LGTM: Clean imports and cache initialization for distributed gather.

The additions are well-structured:

  • gc import supports aggressive memory cleanup for large models
  • OrderedDict enables LRU-like partition caching
  • dist_helper import provides distributed primitives
  • _partition_cache type annotation correctly captures the cache structure

Also applies to: 29-29, 50-50, 142-144


483-582: LGTM: Robust existence check with structured error handling.

The implementation properly distinguishes between different error types using structured ClientError.response inspection rather than brittle string matching. The conservative approach (returning True for transient errors) is appropriate to avoid false negatives that could lead to incorrect slashing decisions.


667-668: LGTM: Critical fix for parallel download performance.

Offloading torch.load() to the thread pool prevents blocking the event loop during CPU-intensive decompression (1-2s per file). This enables true parallel downloads. The 500MB threshold increase is appropriate for 70B model gradients.

Also applies to: 709-723


1471-1493: LGTM: Clean version parameter with sensible default.

The optional version parameter with tplr.__version__ as the default is well-designed and properly documented. This enables checking gradients from different versions while maintaining backward compatibility.


1563-1566: LGTM: Consistent thread pool offloading for local loads.

Applying the same asyncio.to_thread fix to local file loading ensures the event loop isn't blocked during decompression regardless of storage location.


2071-2095: LGTM: Clean deterministic round-robin partitioning.

The implementation correctly sorts UIDs for determinism and uses idiomatic Python slicing [rank::world_size] for round-robin assignment. Edge case handling for single-rank execution is appropriate.


2097-2131: LGTM: Robust recursive device movement.

The recursive move_to_device helper correctly handles nested structures (lists, tuples, dicts) and tensors. In-place modification of result.state_dict is acceptable since this is typically called during final result processing.


2259-2396: LGTM: Well-structured distributed gather with proper return_partials support.

The implementation correctly:

  • Caches partition patterns per window with LRU eviction
  • Gathers partials to master only (memory-efficient)
  • Returns normalized partial list when return_partials=True (line 2365-2370)
  • Merges and returns consolidated result when return_partials=False (line 2372-2389)

Non-master ranks receiving None is expected with the gather-to-master approach and must be handled by callers via the all_ok broadcast pattern.


2401-2694: LGTM: Comprehensive gather with proper distributed fallback and type consistency.

The refactor correctly:

  • Auto-detects when to use distributed gather (line 2442)
  • Falls back to sequential gather when distributed fails (lines 2506-2525)
  • Maintains type consistency: wraps sequential/empty results in list when return_partials=True (lines 2525, 2537, 2557)
  • Handles reserve retry by either appending as a new partial (lines 2620-2632) or merging into the result (lines 2633-2649)
  • Computes aggregate success rate for partials (lines 2654-2675)
  • Returns None when no gradients are gathered (lines 2678-2679, 2691-2692)

The sequential fallback for reserve retry (line 2612) is appropriate since it's typically a small set of replacement UIDs.

Core implementation:
- Add gather_distributed() for parallel gradient downloads across ranks
- Partition UIDs deterministically across GPU ranks
- Each rank downloads subset in parallel, then all-gather to merge
- CPU/GPU memory management to prevent OOM on large models
- Fallback to sequential gather if distributed fails
- Auto-detect based on world_size (no manual config needed)

Miners:
- Use distributed gather when world_size > 1
- Parallel gradient downloads from peer buckets
- Use reserve peers for fallback (same as validators)
- Proper synchronization across all ranks

Validators:
- Support distributed gather for multi-GPU validators
- Improved error handling with get_with_retry
- Better timeout handling (150s for 70B models)
- Test mode support for local networks

Code quality:
- Removed dead code paths (~450 lines)
- Simplified outer_step (single path)
- Removed unused parameters
- Clean gradient averaging via scatter_reduce(reduce="mean")

Net: -339 lines while adding full distributed gather support
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (2)
src/tplr/comms.py (2)

2425-2718: gather_with_reserve distributed path is mostly sound but master‑only fallback still risks divergence

The overall structure of gather_with_reserve is solid:

  • Decides use_distributed via dist_helper and logs entry with rank and UID sets.
  • In distributed mode, calls gather_distributed with return_partials propagated.
  • Normalizes “no gradients” on master to trigger a sequential fallback, while treating None on non‑masters as expected.
  • Wraps sequential/fallback results into a list when return_partials=True to maintain type consistency.
  • For reserve retries, appends a new partial in partials mode or merges into the merged namespace.
  • Recomputes success_rate against the original gather tier and emits detailed context logs.

However, in the distributed branch the sequential fallback still runs only on the master rank:

if primary is None and dist_helper.is_master:
    sequential_result = await self.gather(...)
    primary = [sequential_result] if return_partials else sequential_result

Non‑master ranks keep primary=None and immediately return None later in the function. This duplicates a prior concern: if callers expect all ranks to see the same post‑fallback gradients (e.g. miner/validator outer steps running on every rank), you may get cross‑rank inconsistency unless:

  • outer_step (or equivalent) is explicitly guarded to only use the result on master and broadcast updated model state, or
  • the master’s primary is broadcast to other ranks after the fallback.

Please verify downstream usage (e.g., in neurons/miner.py and validator code) to ensure either master‑only update semantics are intentional and correctly implemented, or adjust this fallback to synchronize results.


1668-1671: Fixed NOT_FOUND cutoff still short‑circuits retries relative to timeout

The get_with_retry loop still gives up after max_consecutive_not_found = 5 with 0.5 s sleeps, i.e. ≈2.5 s total, regardless of the timeout argument. Because get() currently maps all “no data” conditions (404, timeouts, various network/client errors that cause s3_get_object to return None) to status="NOT_FOUND", slow or flaky peers can be treated as “likely doesn’t exist” well before the caller’s timeout budget is used.

This duplicates an earlier concern: either get()/s3_get_object should distinguish true 404s from transient failures, or the NOT_FOUND cutoff should be derived from timeout and/or only applied once we’re clearly past time_max.

Also applies to: 1725-1737

🧹 Nitpick comments (8)
neurons/validator.py (3)

1337-1363: Peers/commitments broadcast keeps ranks aligned but may be heavier than needed

Refreshing self.comms.commitments on the master each window and broadcasting [peers, reserve_peers, commitments] via dist.broadcast_object_list guarantees all ranks share an identical view before gather, which is important for distributed partitioning.

Two minor points to consider:

  • get_commitments() is called after update_peers_with_buckets(), so the freshly fetched commitments only take effect on the next window; if you want current‑window buckets, you could invert the order.
  • Broadcasting the full commitments dict every window can become expensive as the network scales; if this shows up in profiling, consider broadcasting a lighter representation (e.g., UIDs plus a version or hash) and letting non‑master ranks lazily refresh from storage when the version changes.

1490-1524: Gather mode flag is hard‑coded; sequential-branch comments are now misleading

use_distributed_gather is set to True unconditionally, so:

  • The “sequential gather” logging branch (elif self.is_master) is never taken.
  • All ranks always call self.comms.gather_with_reserve(...), relying on Comms’ internal logic (world_size/is_distributed()) to decide whether the actual implementation is distributed or sequential.

This is functionally fine (single‑GPU runs still behave sequentially, and multi‑GPU runs use the distributed path), but the flag no longer conveys real mode information and the surrounding comments become confusing.

Consider either:

  • Deriving the flag from the actual distributed condition, e.g.:
    use_distributed_gather = dist_helper.is_distributed()
    and keeping the dual logging branches; or
  • Dropping use_distributed_gather entirely and simplifying the conditionals/log messages around the gather to reflect that Comms auto‑selects the mode.

Also applies to: 1543-1560


3790-3826: compute_peer_val_norms signature suggests list support but implementation only handles merged results

The signature and type hint allow gather_result: SimpleNamespace | list[SimpleNamespace], but the implementation assumes a single merged result:

  • It immediately does state_dict = gather_result.state_dict and uses getattr(state_dict, vals_key, ...), which will fail for a list input.
  • The docstring also still describes gather_result as “Merged gather result SimpleNamespace”.

Given the current call site in this file always passes a merged result (because gather_with_reserve is invoked with return_partials=False), this isn’t a runtime bug today, but it is misleading for future callers.

I’d either:

  • Narrow the type to SimpleNamespace in both signature and docstring; or
  • Add an explicit isinstance(gather_result, list) branch that merges partials or forwards them through a compatible path before computing norms.
src/tplr/neurons.py (2)

48-69: _resolve_cuda_device result is computed but never used

_resolve_cuda_device is a reasonable helper for normalizing CUDA devices, but in _outer_step_single you assign:

cuda_device = _resolve_cuda_device(device)

and then never use cuda_device — all downstream calls still use the original device argument.

To keep outer_step lean and avoid confusion, consider either:

  • Wiring cuda_device into the places that actually need a CUDA torch.device (e.g., when constructing tensors meant for CUDA‑only APIs), or
  • Removing the helper and the assignment until you need it.

Also applies to: 332-333


969-1000: Avoid coupling catch-up to instance.loop; use the current event loop instead

In catchup_with_aggregation_server, the fallback path uses:

ts_value = await instance.loop.run_in_executor(
    None, instance.query_block_timestamp, sync_block
)

This assumes every Miner/Validator has an instance.loop attribute set before catch‑up runs. That’s true for the current validator path, but it’s a fragile coupling and easy to break if catch‑up is ever invoked from a different context.

You can make this more robust and self‑contained by using the running loop inside the coroutine:

loop = asyncio.get_running_loop()
ts_value = await loop.run_in_executor(
    None, instance.query_block_timestamp, sync_block
)

This keeps behavior the same while removing the implicit dependency on instance.loop.

neurons/miner.py (1)

1024-1033: Final gather_result cleanup is safe but note reliance on state_dict.__dict__

Clearing gather_result.state_dict.__dict__ and then deleting gather_result before gc.collect() and torch.cuda.empty_cache() is an effective way to drop large gradient payloads between windows. This block runs after all usages of gather_result in the loop, so lifecycle is correct.

Only caveat: this assumes state_dict is a SimpleNamespace (as produced by Comms). If future refactors change that type, this direct __dict__ access may need adjustment.
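For reference, the cleanup pattern described here can be sketched in isolation (pure Python; the real CUDA path would also call torch.cuda.empty_cache() afterwards):

```python
import gc
from types import SimpleNamespace

def release_gather_result(gather_result):
    """Drop the tensor references held by the SimpleNamespace payload in place,
    release the local binding, then force a collection pass."""
    if gather_result is not None and hasattr(gather_result, "state_dict"):
        gather_result.state_dict.__dict__.clear()  # in-place, visible to all holders
    del gather_result
    gc.collect()
```

Because the clear is in place, any other reference to the same namespace also sees an empty payload, which is exactly why the block is safe only after all usages in the loop.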

src/tplr/comms.py (2)

2098-2131: _move_gather_result_to_device is straightforward and safe

The recursive move_to_device helper handles tensors nested in lists/tuples/dicts and rewrites result.state_dict in place. Given callers only use it on gather results with SimpleNamespace state_dict, this is an appropriate utility. Just be aware of potential extra device transfers if called redundantly.
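The traversal behind such a helper can be sketched generically — here `fn` stands in for `lambda t: t.to(device)`, so the structure is testable without torch (illustration only, not the PR's code):

```python
def map_nested(fn, obj):
    """Apply fn to every leaf nested inside dicts, lists, and tuples,
    preserving container types; non-container values are treated as leaves."""
    if isinstance(obj, dict):
        return {k: map_nested(fn, v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(map_nested(fn, v) for v in obj)
    return fn(obj)
```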


3213-3247: s3_get_object_size remains simple but usable

head_object + ContentLength for object size, with 404 detection and purge on client errors, is a reasonable utility. If you later centralize ClientError handling (similar to s3_object_exists), this could share that logic, but current behavior is acceptable.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fdce702 and 89caddd.

📒 Files selected for processing (7)
  • neurons/base_node.py (1 hunks)
  • neurons/miner.py (9 hunks)
  • neurons/trainer.py (3 hunks)
  • neurons/validator.py (12 hunks)
  • src/tplr/comms.py (19 hunks)
  • src/tplr/compress.py (1 hunks)
  • src/tplr/neurons.py (10 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
  • src/tplr/compress.py
  • neurons/base_node.py
  • neurons/trainer.py
🧰 Additional context used
🧬 Code graph analysis (4)
src/tplr/neurons.py (1)
src/tplr/compress.py (4)
  • compress (373-377)
  • compress (379-383)
  • compress (386-423)
  • unpack_12bit_indices (102-142)
neurons/validator.py (6)
tests/test_state_loading.py (1)
  • rating (59-62)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • broadcast (273-282)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/comms.py (2)
  • gather_with_reserve (2425-2718)
  • get_with_retry (1625-1740)
src/tplr/neurons.py (1)
  • check_uid_index_overlap (1275-1480)
neurons/miner.py (4)
src/tplr/chain.py (1)
  • get_commitments (264-322)
src/tplr/distributed.py (1)
  • is_distributed (75-77)
src/tplr/comms.py (1)
  • gather_with_reserve (2425-2718)
src/tplr/dcp_checkpoint.py (2)
  • state_dict (60-61)
  • state_dict (82-83)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
  • Bucket (24-46)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/distributed.py (2)
  • is_distributed (75-77)
  • gather_object (284-305)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: test (3.11)
  • GitHub Check: test (3.12)
🔇 Additional comments (26)
neurons/validator.py (4)

931-937: OpenSkill rating initialization is robust and avoids missing-state edge cases

Eagerly instantiating self.openskill_ratings[uid] for all window_uids before calling rate() prevents lookups on missing ratings and keeps the update logic below straightforward. This block is correct and well‑scoped.
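The eager-initialization pattern can be sketched without the OpenSkill dependency — `make_rating` stands in for the model's rating constructor, and the helper name is illustrative:

```python
def ensure_ratings(openskill_ratings: dict, window_uids, make_rating):
    """Create a rating for every UID up front so later rate()/lookup logic
    never hits a missing key; existing ratings are left untouched."""
    for uid in window_uids:
        openskill_ratings.setdefault(uid, make_rating())
    return openskill_ratings
```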


1411-1438: Timestamp calculation and broadcast now correctly synchronize all ranks

Only the master queries query_block_timestamp(sync_block) and falls back to time.time() with logging, while all ranks receive the final ts_value via dist_helper.broadcast and derive time_min/time_max from that shared scalar. This removes the prior risk of rank‑dependent time windows and looks correct.
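The key property is that every rank derives its window from the single broadcast scalar. A sketch of that derivation (offset_s is an illustrative tolerance, not the PR's actual value):

```python
from datetime import datetime, timedelta, timezone

def derive_time_window(ts_value: float, offset_s: float = 2.0):
    """Given the broadcast timestamp, every rank computes an identical
    (time_min, time_max) pair — no rank-local clock reads involved."""
    ts = datetime.fromtimestamp(ts_value, tz=timezone.utc)
    return ts - timedelta(seconds=offset_s), ts + timedelta(seconds=offset_s)
```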


1606-1643: Post-gather pipeline (upload → overlap → offload → norms) is consistent

On master, you:

  • Upload the already‑merged gather_result asynchronously.
  • Run check_uid_index_overlap(self, gather_result, ...) while gradients are still on the original device.
  • Then offload gather_result.state_dict to pinned CPU via offload_gather_results.
  • Later, compute clip_norm_dict = self.compute_peer_val_norms(gather_result) (which will pull tensors back to the target device as needed).

This ordering keeps overlap detection on-device while still aggressively freeing GPU memory before evaluation and outer_step, and it matches the assumption that gather_with_reserve(..., return_partials=False) always returns a merged SimpleNamespace. The control flow here looks sound.

Also applies to: 1821-1825


3197-3217: Switching to get_with_retry makes sync-eval more resilient

Using self.comms.get_with_retry with an explicit timeout and handling the None case cleanly:

  • Avoids tight retry loops inside evaluate_miner_sync.
  • Ensures failures (timeout or non‑success) result in a well‑formed {success: False, sync_score: 0.0} payload, with a defensive getattr(debug_result, "error", "unknown") log message.

The control flow and error handling here look correct.

src/tplr/neurons.py (2)

249-295: Outer-step refactor keeps semantics while improving memory safety

The new outer_step delegating to _outer_step_single:

  • Treats gather_result as a single merged SimpleNamespace (matching the updated validator path with return_partials=False).
  • Builds and applies one parameter’s dense grad at a time, calling optimizer.step() per param and immediately freeing p.grad plus intermediates, which is appropriate for tight memory budgets.
  • Correctly coordinates DDP/DTensor:
    • Master (or single‑rank) computes src_sd and dense grads.
    • _bcast_flag tells other ranks whether a param has an update.
    • Non‑DTensor params use dist.broadcast into p.grad.
    • DTensor params use distribute_tensor with src_data_rank=0, then sanity‑check local shards.

The broadcasting and collective calls are ordered identically on all ranks, so there’s no obvious deadlock risk, and invalid gradients are guarded by finite checks before being applied.

Overall this refactor looks correct and aligns with the PR’s “memory‑minimizing outer step” goal.

Also applies to: 297-552


1290-1479: Partial-aware check_uid_index_overlap logic is consistent with merged-path semantics

The updated check_uid_index_overlap now:

  • Accepts either a merged SimpleNamespace or a list[SimpleNamespace] of partials.
  • For partials, flattens all partial.uids into a single uids list and builds a uid_to_partial[uid] = (partial, idx_in_partial) map so that per‑UID state can be located in O(1) rather than scanning all partials.
  • For each parameter:
    • Builds unpacked_indices only for UIDs that actually have idxs/vals present in their per‑UID state_dict.
    • Tracks the original UID index via valid_uids_for_param, and uses those indices as keys into pair_acc, so the summary stage can still safely map back to uids[i]/uids[j] when computing offenders and logging min/max overlaps.

This preserves the original merged‑result behavior while extending it cleanly to partials, and the index bookkeeping between valid_uids_for_param, pair_acc, and the final uids[...] lookups is internally consistent.

neurons/miner.py (7)

32-35: CUDA allocator env setup is in the right place

Setting PYTORCH_CUDA_ALLOC_CONF before importing torch is correct and avoids allocator reconfiguration warnings at runtime. No issues here.


474-493: Distributed broadcast keeps peers/reserves/commitments consistent across ranks

Broadcasting self.comms.peers, self.comms.reserve_peers, and self.comms.commitments from master ensures all ranks use identical gather inputs, which is critical for deterministic distributed gather. The dist.broadcast_object_list usage and post‑assignment look correct.


554-562: Pre‑training CUDA cleanup is aggressive but appropriate for OOM mitigation

Calling torch.cuda.empty_cache(), torch.cuda.synchronize(), and model.zero_grad(set_to_none=True) before inner_steps is a reasonable tradeoff to reduce VRAM fragmentation and stale gradients for large models. Given this code assumes CUDA, this is acceptable.


740-813: Distributed vs sequential gather selection and logging look consistent

Using use_distributed_gather = dist_helper.world_size > 1 and dist_helper.is_distributed() and then:

  • Running gather_with_reserve on all ranks in distributed mode, and
  • Logging separate paths for distributed vs sequential gathers

is coherent with the comms-side gather_with_reserve/gather_distributed behavior. The should_update logic (master decides, broadcast via all_ok) is also sound.

840-843: Post‑outer‑step gradient clearing is correct

Zeroing gradients with set_to_none=True immediately after the outer step is good hygiene and helps reduce memory usage, especially with AMP and large models.


893-902: Debug peer bookkeeping aligns with gather_result structure

Using gather_result.uids / gather_result.skipped_uids and deriving:

  • successful_peers = peers - skipped_uids
  • skipped_peers = skipped_uids

matches the semantics of the merged gather result from gather_with_reserve. This should give accurate debug traces for which peers contributed gradients.


937-945: Gather success rate and skipped UID logging integrate cleanly into WandB/metrics

The computation of gather_success_rate and skipped_uids_list and their inclusion both in WandB metrics and Influx fields is consistent and safe when gather_result is None (falls back to 0 / empty list).

src/tplr/comms.py (13)

20-21: gc import is appropriate for explicit cleanup sites

Importing gc at module level matches the later explicit gc.collect() calls and is a reasonable addition for aggressive cleanup paths.


29-50: Ordering new imports (OrderedDict, dist_helper) is consistent with usage

Adding OrderedDict and dist_helper at module level is correct and matches where they are used (_partition_cache, distributed helpers). No issues here.


142-145: _partition_cache design is reasonable for per‑window UID partitions

Using an OrderedDict keyed by (window, tuple(sorted_uids), world_size) to cache UID partitions and trimming it to 32 entries is a good balance between determinism and memory usage. This avoids recomputing round‑robin partitions each window and is safe given small value sizes.
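A sketch of this caching scheme (the cap of 32 and the key shape are taken from the description above; helper and variable names are illustrative):

```python
from collections import OrderedDict

_MAX_ENTRIES = 32  # cap described in the review; oldest entries are evicted first

def cache_partition(cache: OrderedDict, window: int, uids, world_size: int, compute):
    """Memoize compute(sorted_uids, world_size) per (window, uids, world_size)
    key, trimming the cache FIFO-style once it exceeds the cap."""
    key = (window, tuple(sorted(uids)), world_size)
    if key not in cache:
        cache[key] = compute(sorted(uids), world_size)
        while len(cache) > _MAX_ENTRIES:
            cache.popitem(last=False)  # drop oldest insertion
    return cache[key]
```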


483-582: s3_object_exists implements robust, conservative existence checks

The new s3_object_exists:

  • Uses head_object with a per‑call timeout.
  • Distinguishes 404 / NoSuchKey, 403 / access issues, and bucket misconfig separately.
  • Treats 5xx and generic network errors as transient and conservatively returns True.

This is a solid improvement over brittle string matching and aligns with the intent to avoid false slashing on transient infrastructure issues.
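The classification policy can be sketched as pure status-code logic (the aioboto3/head_object plumbing is omitted and the function name is illustrative):

```python
def classify_head_status(status: int) -> str:
    """Map a HEAD-object HTTP status onto the conservative policy described
    above: definite miss on 404, definite denial on 403, and 'assume exists'
    on 5xx so transient infrastructure errors never look like missing data."""
    if status == 404:
        return "missing"
    if status == 403:
        return "forbidden"
    if 500 <= status < 600:
        return "assume-exists"  # transient: conservatively treat as present
    return "exists" if status == 200 else "unknown"
```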


666-693: Large‑file download path and size‑based timeout scaling are sensible

Raising the simple‑download threshold to 500 MB and using download_large_file for larger objects, with download_timeout = max(300, file_size / (25 * 1024 * 1024)), is appropriate for huge gradient files. This should materially reduce timeouts on >GB objects while still bounding total wall‑clock time.
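The scaling rule is easy to sanity-check in isolation (the 25 MiB/s figure is the assumed throughput from the formula above):

```python
def download_timeout(file_size_bytes: int) -> float:
    """Floor of 300 s, otherwise scale linearly assuming ~25 MiB/s throughput."""
    return max(300.0, file_size_bytes / (25 * 1024 * 1024))
```

A 1 MiB object keeps the 300 s floor, while a 15 GiB object gets roughly ten minutes.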


699-723: Offloading torch.load to a worker thread is the right fix for event‑loop blocking

Wrapping torch.load in asyncio.to_thread avoids blocking the asyncio event loop during CPU decompression, and weights_only=True is appropriate for these tensor‑only payloads. The error logging path that falls back to loaded_data = None is also safe for callers.
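The pattern generalizes beyond torch; a runnable sketch with json.load standing in for torch.load(path, weights_only=True):

```python
import asyncio
import json
import os
import tempfile

async def load_off_loop(path: str):
    """Run a blocking deserializer in a worker thread via asyncio.to_thread,
    so the event loop stays responsive during CPU-heavy decompression."""
    def _blocking_load(p):
        with open(p) as f:
            return json.load(f)
    return await asyncio.to_thread(_blocking_load, path)
```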


1563-1566: Local torch.load offload mirrors remote behavior appropriately

Using await asyncio.to_thread(torch.load, local_path, weights_only=True) for local gradient/debug loads keeps local and remote paths consistent and prevents blocking the loop on large local files. This change looks good.


1844-2056: Gather response processing and validation improvements are solid

The reworked gather loop:

  • Logs per‑UID download errors and None responses clearly.
  • Validates quantization params, packed indices, and value tensors without dequantizing (for memory efficiency).
  • Enforces presence of all expected_compressed_params before accepting a UID.
  • Accumulates metrics and builds a structured SimpleNamespace result.

This is a robust design for validating and aggregating compressed gradients.


2071-2096: partition_uids_across_ranks is deterministic and matches round‑robin spec

Sorting UIDs and then using sorted_uids[rank::world_size] is a clear, deterministic round‑robin partitioning scheme. This aligns with the documented behavior and is easy to reason about.
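The scheme is small enough to state directly (a sketch mirroring the described slicing, not the PR's exact code):

```python
def partition_uids_across_ranks(uids, rank: int, world_size: int):
    """Deterministic round-robin split: sort for stability across ranks,
    then stride by world_size starting at this rank's offset."""
    sorted_uids = sorted(uids)
    return sorted_uids[rank::world_size]
```

Every UID lands on exactly one rank, and all ranks agree on the assignment without communication.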


2133-2281: merge_gather_results correctly rebuilds a deterministic merged result

This helper does the right things:

  • Filters out None partials and bails out early if none.
  • Picks a safe target_device (current CUDA device if available with fallbacks).
  • Reconstructs merged_state_dict by iterating UIDs in sorted order and indexing into each partial’s per‑param lists, ensuring the merged representation matches sequential gather ordering.
  • Aggregates global_steps, skipped_uids, and byte/time metrics, computing a global success rate.

Overall, the logic matches the semantics expected from merging per‑rank outputs of gather().


2283-2420: gather_distributed semantics (master‑only merge, return_partials handling) look correct

The distributed gather flow:

  • Caches deterministic UID partitions per (window, sorted_uids, world_size).
  • Has each rank run gather() on its assigned subset (or no‑op if empty).
  • Uses dist_helper.gather_object to bring all partial results only to rank 0.
  • On rank 0, normalizes empty partials to None, and:
    • returns the normalized list when return_partials=True, or
    • calls merge_gather_results(..., target_device=device) and returns the merged result.
  • Non‑master ranks always return None.

This matches the documented contract: in distributed mode only the master has data, and return_partials now genuinely switches between merged vs per‑rank partials.


2937-3058: Peer list retrieval logic and parsing are consistent

Fetching peer lists from the highest‑stake validator, filtering by versioned peers_*.json keys, selecting most‑recent or previous windows, and parsing {peers, reserve_peers, first_effective_window} into a tuple is coherent. Given other hardening already present in get_peer_list, this path looks fine.


3068-3137: get_start_window version handling and retries are correct

Allowing an optional version override while defaulting to tplr.__version__, and retrying with backoff via _get_highest_stake_validator_bucket plus s3_get_object, matches expectations for bootstrap synchronization. No issues here.

Comment on lines 708 to +813
gather_result = await self.comms.gather_with_reserve(
    my_uid=self.uid,
    gather_uids=self.comms.peers,
    reserve_uids=self.comms.reserve_peers,  # Miners gather same way as validator
    return_partials=False,  # Always return merged result (distributed gather merges internally)
    window=step_window,
    key="gradient",
    timeout=90,
    device=str(self.device),
    local=False,
    stale_retention=100,
    totalks=self.totalks,
    compressor=self.compressor,
    time_min=time_min,
    time_max=time_max,
    expected_compressed_params=self.expected_compressed_params,
)
tplr.logger.info("Gather task completed!")
gather_time = tplr.T() - gather_start

# Log gather completion
if use_distributed_gather:
    if gather_result is not None:
        tplr.logger.info(
            f"[Rank {dist_helper.rank}] Distributed gather complete: {len(gather_result.uids)}/{len(self.comms.peers)} successful, "
            f"{len(gather_result.skipped_uids)} skipped, success_rate={gather_result.success_rate:.2%}, "
            f"time={gather_time:.2f}s"
        )
    else:
        tplr.logger.warning(
            f"[Rank {dist_helper.rank}] Distributed gather failed - no gradients collected from peers"
        )
else:
    # Sequential gather logging (master only)
    if gather_result is not None:
        tplr.logger.info(
            f"Sequential gather complete: {len(gather_result.uids)}/{len(self.comms.peers)} successful, "
            f"{len(gather_result.skipped_uids)} skipped, "
            f"success_rate={gather_result.success_rate:.2%}, "
            f"time={gather_time:.2f}s"
        )
    else:
        tplr.logger.warning(
            "Sequential gather failed - no gradients collected from peers"
        )

# For distributed gather: only master checks result, others will follow master's decision
# For sequential gather: only master has result anyway
if use_distributed_gather:
    # Master checks if gather succeeded, non-master ranks defer to master
    should_update = (
        gather_result is not None if self.is_master else True
    )
else:
    should_update = gather_result is not None


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

Ensure outer_step can handle gather_result=None on non‑master ranks in distributed mode

In distributed runs, Comms.gather_with_reserve returns a merged SimpleNamespace only on the master rank; non‑master ranks receive None. After should_update is broadcast via dist_helper.all_ok, all ranks enter:

if should_update:
    gradient_fingerprint = self.outer_step(gather_result)

On non‑master ranks this calls outer_step(None). If outer_step assumes a non‑None, merged gather result, this will raise at runtime or silently skip gradient application on non‑master ranks, leading to model drift across ranks.

Please verify that tplr.neurons.outer_step:

  • Explicitly handles gather_result is None (e.g., by no‑op on non‑master and broadcasting updated weights from master), or
  • Is only executed on the master rank with a guard around this call.

If not, you likely need either to broadcast the merged gather_result (or updated model state) from master to all ranks or gate outer_step on self.is_master.


I'll verify this review comment by examining the outer_step method implementation and how it's called after the code snippet.


---

🏁 Script executed:

```shell
#!/bin/bash
# Find outer_step definition
rg -n "def outer_step" --type python

Length of output: 30


🏁 Script executed:

#!/bin/bash
# Look for outer_step calls with context
rg -n -A5 -B5 "outer_step" --type python | head -100

Length of output: 30


🏁 Script executed:

#!/bin/bash
# Check what happens after the should_update is set - look for the actual call
rg -n "\.outer_step\(" --type python -C5

Length of output: 30



---

🏁 Script executed:

```shell
#!/bin/bash
# Search for outer_step definition and calls without type constraint
rg -n "def outer_step|\.outer_step\(" -C5

Length of output: 2811


🏁 Script executed:

#!/bin/bash
# View the miner.py file from line 810 onwards to see what happens after should_update
rg -n "should_update" neurons/miner.py -A20

Length of output: 3405


🏁 Script executed:

#!/bin/bash
# Search for where outer_step is actually called in the context shown
rg -n "outer_step" neurons/miner.py -B3 -A3

Length of output: 1847



---

🏁 Script executed:

```shell
#!/bin/bash
# View the outer_step implementation in detail
sed -n '250,350p' src/tplr/neurons.py

Length of output: 3157



---

🏁 Script executed:

```shell
#!/bin/bash
# Continue viewing _outer_step_single to see what happens when src_sd is None
sed -n '330,450p' src/tplr/neurons.py

Length of output: 4739


🏁 Script executed:

#!/bin/bash
# View the complete trainer.py outer_step wrapper
sed -n '1023,1050p' neurons/trainer.py

Length of output: 639


🏁 Script executed:

#!/bin/bash
# Check if miner.py has is_master property and how it's used
rg -n "is_master" neurons/miner.py -B2 -A2

Length of output: 3900


Critical issue confirmed: non-master ranks will fail when attempting to unpack None gradients in distributed mode

The review comment accurately identifies a major bug in the current implementation. Here's what happens:

  1. In distributed gradient gathering, on non-master ranks, on_src = False (since on_src = is_master or not ddp), which means src_sd = None.

  2. When the master decides to update, it broadcasts has_update=1 to all ranks via _bcast_flag.

  3. Non-master ranks receive this flag but have payload=None, since they never extracted data from gather_result (which is None on non-master ranks anyway).

  4. The unpacking will fail: when flag_result == 1 and the code attempts to execute idxs_dev, vals_f32 = payload # type: ignore[misc] on a non-master rank, it will crash because payload is None.

The trainer wrapper correctly passes is_master and world_size to outer_step, and miner.py line 835 calls self.outer_step(gather_result) without any guard. All ranks execute this call after the broadcast, making this a runtime failure waiting to happen.

Resolution required: Either gate the outer_step call on self.is_master, or broadcast the merged gradient state from master to all ranks before the call.

🤖 Prompt for AI Agents
In neurons/miner.py around lines 760-813, non-master ranks can end up with
gather_result/payload==None but still call outer_step and try to unpack gradient
payload, causing a crash; fix by either (A) gating the outer_step call so only
the master executes it (wrap the call in if self.is_master:
self.outer_step(...)) or (B) have the master broadcast the merged gradient
payload (not just a has_update flag) to all ranks before outer_step so
non-masters receive a valid payload to unpack; implement one of these two
approaches and ensure the broadcasted data and control flow are consistent
(i.e., non-masters use received payload and do not try to unpack None).
