feat: distributed gradient gathering for multi-GPU parallelism #649
shivam-MBZUAI wants to merge 1 commit into dev from
Conversation
Walkthrough

Adds deterministic per-rank distributed gather (with optional partial returns and merging), device-aware result movement, resilient S3/io with threaded loads, and distributed coordination (broadcasts/timestamps). Propagates partial-aware handling through outer-step, overlap checks, validation, and miner flow; improves CUDA memory hygiene, gradient logging, and gather metrics.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Miner
    participant Comms
    participant S3
    participant Neurons
    participant Validator
    rect rgba(230,240,250,0.9)
        note right of Miner: Initiate gather (distributed-capable)
        Miner->>Comms: gather_with_reserve(..., return_partials?)
        alt distributed path (world_size>1)
            Comms->>Comms: partition_uids_across_ranks()
            par per-rank fetch (concurrent)
                Comms->>S3: s3_get_object(uids_for_rank)
                S3-->>Comms: partial_result
            end
            Comms->>Comms: merge_gather_results(partials) [master]
            Comms-->>Miner: merged_result or [partials]
        else sequential path
            Comms->>S3: fetch all UIDs
            S3-->>Comms: merged_result
            Comms-->>Miner: merged_result
        end
    end
    rect rgba(245,240,230,0.9)
        note right of Miner: Post-gather processing and validation
        Miner->>Neurons: outer_step(gather_result, device=_resolve_cuda_device(...))
        alt gather_result is list (partials)
            Neurons->>Neurons: map UID -> partial.state_dict -> _outer_step_single
        else merged result
            Neurons->>Neurons: _outer_step_single()
        end
        Neurons->>Validator: compute_peer_val_norms(gather_result)
        Validator-->>Neurons: norms + gather metrics (success/skips)
        Miner->>Comms: broadcast decisions (master) / receive (workers)
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
Codecov Report

❌ Patch coverage is 20.32%.
❌ Your patch check has failed because the patch coverage (20.32%) is below the target coverage (85.00%). You can increase the patch coverage or adjust the target coverage.

```
@@            Coverage Diff             @@
##              dev     #649      +/-   ##
==========================================
- Coverage   57.89%   55.08%    -2.82%
==========================================
  Files          27       27
  Lines        4895     5223     +328
==========================================
+ Hits         2834     2877      +43
- Misses       2061     2346     +285
```
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/tplr/comms.py (1)
2364-2380: Fix potential AttributeError when fallback is None.

Lines 2373-2374 reference `fallback` outside the `if fallback:` block (lines 2364-2372). If `fallback` is `None`, accessing `fallback.upload_bytes` and `fallback.download_bytes` will raise an `AttributeError`. Apply this diff to fix the indentation:

```diff
 if fallback:
     # merge tensor‑lists inside the nested state_dict
     for k, v in vars(fallback.state_dict).items():
         merged = getattr(primary.state_dict, k, []) + v
         setattr(primary.state_dict, k, merged)
     primary.uids.extend(fallback.uids)
     primary.global_steps.extend(fallback.global_steps)
     primary.skipped_uids.extend(fallback.skipped_uids)
-primary.upload_bytes += fallback.upload_bytes
-primary.download_bytes += fallback.download_bytes
+    primary.upload_bytes += fallback.upload_bytes
+    primary.download_bytes += fallback.download_bytes
 context_log(
     message=f"[gather_with_reserve] ✅ reserve gather "
```
♻️ Duplicate comments (1)
src/tplr/comms.py (1)
2234-2234: Move dist_helper import to module level.

Same as the issue in `gather_distributed`: `dist_helper` should be imported at the module level rather than inside the function.
🧹 Nitpick comments (5)
hparams/2B.json (1)
1-26: Add trailing newline to JSON file.

The file is missing a trailing newline; ending text files with a newline is a common convention, and its absence may cause issues with some tooling. Add a newline after the closing brace on line 26.
src/tplr/neurons.py (2)
226-293: Consider documenting the memory_budget_mb asymmetry.

The `memory_budget_mb` parameter is passed to `_outer_step_incremental` but not to `_outer_step_single`. While this is intentional (single-result processing doesn't need budget tracking), it would be helpful to document this in the docstring to clarify when the budget is actually enforced. Update the docstring to clarify:

```diff
 Args:
     gather_result: Either a single SimpleNamespace or list of partial results
-    memory_budget_mb: Optional memory budget in MB (not enforced yet, future work)
+    memory_budget_mb: Optional memory budget in MB (enforced only for list of partials, ignored for single results)
```
471-473: Consider element-weighted averaging for means.

The means are averaged uniformly across partials, but this may not accurately represent the combined mean if partials have different numbers of elements per parameter. Consider weighting by element count for more accurate statistics. Track element counts per partial and compute a weighted average:

```diff
 for param_name, mean in partial_fingerprint.get("param_means", {}).items():
     if param_name not in combined_fingerprint["param_means"]:
-        combined_fingerprint["param_means"][param_name] = []
-    combined_fingerprint["param_means"][param_name].append(mean)
+        combined_fingerprint["param_means"][param_name] = {"sum": 0.0, "count": 0}
+    elem_count = partial_fingerprint.get("param_elem_counts", {}).get(param_name, 1)
+    combined_fingerprint["param_means"][param_name]["sum"] += mean * elem_count
+    combined_fingerprint["param_means"][param_name]["count"] += elem_count
```

Then finalize:

```diff
 # Average the means
 for param_name in combined_fingerprint["param_means"]:
-    means_list = combined_fingerprint["param_means"][param_name]
-    combined_fingerprint["param_means"][param_name] = sum(means_list) / len(means_list)
+    stats = combined_fingerprint["param_means"][param_name]
+    combined_fingerprint["param_means"][param_name] = stats["sum"] / stats["count"] if stats["count"] > 0 else 0.0
```

neurons/validator.py (1)
91-132: Consider adding input validation to helper methods.

The helper methods assume `gather_result` is either a `list` or has attributes like `uids`, `skipped_uids`, `success_rate`. If an unexpected type is passed (e.g., `None` when it shouldn't be, or a different object type), the methods will fail with unclear errors or return incorrect results. Consider adding type validation at the start of each helper:

```diff
 @staticmethod
 def _get_uids_from_gather_result(gather_result) -> list[int]:
     """Extract UIDs from either partials (list) or merged result (SimpleNamespace)."""
+    if gather_result is None:
+        return []
+    if not isinstance(gather_result, (list, SimpleNamespace)):
+        tplr.logger.warning(f"Unexpected gather_result type: {type(gather_result)}")
+        return []
     if isinstance(gather_result, list):
```

Apply similar guards to the other two helpers.
src/tplr/comms.py (1)
1982-1984: Consider using config device instead of hardcoding cuda:0.

In multi-GPU setups, the primary device may not be `cuda:0`. Consider defaulting to `self.config.device` if available, or requiring the caller to specify the target device explicitly. Apply this diff to use a more flexible default:

```diff
 # Determine target device for tensor consolidation
 if target_device is None:
-    target_device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    # Default to primary CUDA device or CPU
+    target_device = f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu"
```
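The suggested default can be factored into a small helper; this is a hypothetical sketch (the `resolve_target_device` name and its parameters are not from the PR), with CUDA availability injected so it runs anywhere:

```python
def resolve_target_device(config_device=None, cuda_available=False, current_index=0):
    # Prefer an explicitly configured device; otherwise fall back to the
    # current CUDA device, and finally to CPU.
    if config_device is not None:
        return config_device
    return f"cuda:{current_index}" if cuda_available else "cpu"

# A configured device always wins over auto-detection.
assert resolve_target_device("cuda:3", cuda_available=True) == "cuda:3"
# Without config, the current CUDA device is used when available.
assert resolve_target_device(None, cuda_available=True, current_index=1) == "cuda:1"
# CPU fallback when CUDA is absent.
assert resolve_target_device() == "cpu"
```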
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- hparams/2B.json (1 hunks)
- neurons/miner.py (2 hunks)
- neurons/validator.py (10 hunks)
- src/tplr/comms.py (8 hunks)
- src/tplr/neurons.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
neurons/miner.py (1)
- src/tplr/comms.py (1)
  - gather (1607-1929)

src/tplr/neurons.py (1)
- src/tplr/compress.py (5)
  - compress (373-377)
  - compress (379-383)
  - compress (386-423)
  - ChunkingTransformer (145-292)
  - TopKCompressor (295-696)

src/tplr/comms.py (2)
- src/tplr/logging.py (2)
  - log_with_context (207-223)
  - log_with_context (290-309)
- src/tplr/distributed.py (2)
  - is_distributed (75-77)
  - all_gather_object (307-325)

neurons/validator.py (5)
- src/tplr/distributed.py (1)
  - is_distributed (75-77)
- src/tplr/logging.py (2)
  - log_with_context (207-223)
  - log_with_context (290-309)
- neurons/base_node.py (1)
  - query_block_timestamp (202-247)
- src/tplr/comms.py (2)
  - gather_with_reserve (2199-2423)
  - merge_gather_results (1961-2070)
- src/tplr/neurons.py (1)
  - check_uid_index_overlap (1493-1636)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.11)
- GitHub Check: test (3.12)
🔇 Additional comments (12)
neurons/miner.py (2)
707-725: LGTM: Sequential gather with comprehensive parameters.

The updated gather call correctly uses the expanded parameter set that aligns with the new distributed gathering capabilities. The logging provides clear visibility into the gather process.
728-738: Good: Enhanced logging for gather results.

The post-gather logging provides valuable insights including success rate, counts, and timing, which will aid in monitoring and debugging the distributed gathering process.
src/tplr/neurons.py (4)
427-431: LGTM: Correct L2 norm accumulation.

The squared norms are correctly accumulated and then square-rooted at the end (lines 466-468) to compute the combined L2 norm. This is the mathematically correct way to combine norms from multiple sources.
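The identity behind this pattern (the sum of squared partial norms equals the squared norm of the concatenation) can be checked with plain Python, no torch required:

```python
import math

# Two partials' gradient values; each inner list stands in for one rank's tensors.
chunks = [[3.0, 4.0], [5.0, 12.0]]

# Accumulate squared norms per partial, take sqrt once at the end.
sq_sum = sum(v * v for part in chunks for v in part)
combined = math.sqrt(sq_sum)

# Direct L2 norm over the flattened values matches.
flat = [v for part in chunks for v in part]
direct = math.sqrt(sum(v * v for v in flat))
assert abs(combined - direct) < 1e-12
```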
516-769: LGTM: Original logic correctly preserved in separate function.

The `_outer_step_single` function correctly preserves the original outer step implementation. The refactoring cleanly separates single-result and incremental processing without changing the core gradient application logic.
488-501: Potential division by zero in utilization calculation.

Line 490 calculates `utilization_pct` but doesn't guard against `memory_budget_mb` being zero or negative before the division. Apply this fix:

```diff
 final_memory_mb = torch.cuda.memory_allocated(device) / (1024 ** 2)
-utilization_pct = (final_memory_mb / memory_budget_mb * 100) if memory_budget_mb > 0 else 0
+utilization_pct = (final_memory_mb / memory_budget_mb * 100) if memory_budget_mb > 0 else 0.0
 if final_memory_mb > memory_budget_mb:
```

Actually, the current code already has the guard `if memory_budget_mb > 0`, so this is already correct. However, consider also handling negative budgets:

```diff
-utilization_pct = (final_memory_mb / memory_budget_mb * 100) if memory_budget_mb > 0 else 0
+utilization_pct = (final_memory_mb / memory_budget_mb * 100) if memory_budget_mb > 0 else 0.0
```

Likely an incorrect or invalid review comment.
386-403: Consider whether defensive programming for circular references and deep recursion is necessary for this specific use case.

The review comment raises theoretically valid concerns, but verification shows they are not grounded in observed issues:

- Circular references: No circular-structure patterns found in test payloads or real usage. PyTorch state_dicts are typically OrderedDicts mapping to Tensors/buffers where circular references are uncommon, and the actual codebase uses flat SimpleNamespace attributes populated from dict payloads.
- Deep nesting: Real data shows max 1-2 levels of nesting (e.g., `[torch.tensor(...)]`), not the deeply nested structures that would trigger recursion limits.
- Custom objects: Already handled safely via the else clause, which returns objects unchanged.

The suggested safeguards are defensive but add complexity. The current code is safe for the documented use case. If the structure is guaranteed to remain shallow and contain only basic types (tensors, lists, dicts, metadata), the addition is unnecessary. Verify with the maintainers whether this assumption holds before adding the overhead.
hparams/2B.json (1)
22-22: This review comment is incorrect and based on misunderstandings of the codebase.

The `gradient_memory_budget_mb` parameter is marked as optional and currently unused. The function docstring in `src/tplr/neurons.py:323` explicitly states: "Optional memory budget in MB (future: early stopping if exceeded)", indicating this is placeholder functionality that hasn't been implemented yet. There are no budget violation warnings or enforcement mechanisms in the code.

Additionally, the memory calculation in the review is based on incorrect assumptions about compression. The compression uses a count-based `topk` parameter (absolute element count per layer), not the 3% sparsity ratio assumed in the calculation. The parameter appears only in the 2B model configuration and is not present in any other hparams files (150M.json, 1B.json, etc.), suggesting it's experimental or model-specific rather than a standard configuration concern.

Likely an incorrect or invalid review comment.
neurons/validator.py (3)
1382-1396: LGTM: Proper peer list broadcasting for distributed gather.

The peer list broadcast from master to all ranks is correctly implemented. The use of `dist.broadcast_object_list` ensures all ranks have the same peer lists for coordinated distributed gathering.
1530-1577: LGTM: Clean distributed vs sequential gather implementation.

The conditional gather execution properly handles both distributed and sequential modes with appropriate rank coordination and logging. The distinction between modes is clear and well-documented in the logs.
2501-2501: LGTM: Proper memory budget parameter propagation.

The `memory_budget_mb` parameter is correctly passed from hparams to `outer_step`, enabling memory-bounded incremental processing. The use of `getattr` with a `None` default ensures backward compatibility.

src/tplr/comms.py (2)
1354-1375: LGTM: Backward-compatible API improvement.

The optional version parameter with a sensible default improves ergonomics while maintaining backward compatibility.
1934-1958: LGTM: Deterministic round-robin partitioning.

The implementation correctly uses sorting for determinism and round-robin slicing for balanced distribution across ranks.
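The sorted round-robin scheme can be sketched in a few lines (the function name `partition_uids_across_ranks` mirrors the PR's helper, but this body is an illustrative guess, not the actual implementation):

```python
def partition_uids_across_ranks(uids, world_size, rank):
    # Sort first so every rank sees the same ordering, then take
    # every world_size-th element starting at this rank's offset.
    ordered = sorted(uids)
    return ordered[rank::world_size]

uids = [7, 3, 11, 5, 2, 9]
parts = [partition_uids_across_ranks(uids, 3, r) for r in range(3)]
assert parts == [[2, 7], [3, 9], [5, 11]]
# Partitions are disjoint and cover every UID.
assert sorted(u for p in parts for u in p) == sorted(uids)
```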
Actionable comments posted: 1
♻️ Duplicate comments (1)
src/tplr/comms.py (1)
2100-2101: Move imports to module level.

Importing modules inside a function is an anti-pattern that impacts performance and readability. The `time` module is already imported at line 26, and `dist_helper` should be imported at the module level for consistency.

This issue was previously flagged in past review comments but has not been addressed.
🧹 Nitpick comments (2)
neurons/validator.py (1)
1439-1454: Time window calculation runs on all ranks.

The time window calculation (including the `query_block_timestamp` call) runs on all ranks because distributed gather requires all ranks to participate with consistent time bounds. While `query_block_timestamp` might involve network calls, this is necessary for correctness.

If `query_block_timestamp` becomes a performance bottleneck, consider calculating on master and broadcasting:

```python
# Master calculates timestamp; workers receive it via broadcast
if self.is_master:
    ts_value = self.query_block_timestamp(sync_block)
    if ts_value is None:
        tplr.log_with_context(...)
        ts_value = time.time()
else:
    ts_value = 0.0  # placeholder, overwritten by the broadcast below

# Broadcast to all ranks
ts_tensor = torch.tensor([ts_value], dtype=torch.float64, device=self.device)
dist_helper.broadcast(ts_tensor, src=0)
ts_value = ts_tensor.item()

# All ranks compute time window from the broadcast timestamp
time_min = datetime.fromtimestamp(ts_value, tz=timezone.utc)
time_max = time_min + timedelta(seconds=self.hparams.time_window_delta_seconds)
```

src/tplr/neurons.py (1)
1598-1652: Consider caching partial-to-UID mapping for efficiency.

The nested loop (Lines 1615-1619) searches for the partial containing `uids[i]` on every iteration, resulting in O(Ptot × num_partials) complexity per parameter. While correct, this could be optimized by building a `uid_to_partial_info` mapping once during the initial traversal at Lines 1540-1554.

For typical deployments with 2-4 GPUs, the performance impact is negligible, so this optimization is optional. The current implementation prioritizes clarity over micro-optimization.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- hparams/2B.json (1 hunks)
- neurons/validator.py (11 hunks)
- src/tplr/comms.py (8 hunks)
- src/tplr/neurons.py (8 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- hparams/2B.json
🧰 Additional context used
🧬 Code graph analysis (3)
src/tplr/comms.py (2)
- src/tplr/logging.py (2)
  - log_with_context (207-223)
  - log_with_context (290-309)
- src/tplr/distributed.py (2)
  - is_distributed (75-77)
  - all_gather_object (307-325)

src/tplr/neurons.py (1)
- src/tplr/compress.py (6)
  - compress (373-377)
  - compress (379-383)
  - compress (386-423)
  - ChunkingTransformer (145-292)
  - TopKCompressor (295-696)
  - unpack_12bit_indices (102-142)

neurons/validator.py (6)
- src/tplr/distributed.py (1)
  - is_distributed (75-77)
- src/tplr/logging.py (2)
  - log_with_context (207-223)
  - log_with_context (290-309)
- neurons/base_node.py (1)
  - query_block_timestamp (202-247)
- src/tplr/comms.py (2)
  - gather_with_reserve (2199-2455)
  - merge_gather_results (1961-2070)
- src/tplr/neurons.py (1)
  - check_uid_index_overlap (1524-1733)
- src/tplr/compress.py (1)
  - maybe_dequantize_values (646-696)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.11)
- GitHub Check: test (3.12)
🔇 Additional comments (16)
neurons/validator.py (5)
91-132: LGTM: Clean abstraction for partial vs merged gather results.

These static helper methods provide a clean, maintainable way to extract UIDs, skipped UIDs, and success rate from either partial or merged gather results. The implementation correctly handles both data structures and simplifies downstream code.

Minor note: Line 116 deduplicates skipped UIDs using `set()`, which loses ordering. This is acceptable since skipped UIDs don't require specific ordering.
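If ordering ever does matter, `dict.fromkeys` deduplicates while preserving first-seen order; a small standalone comparison:

```python
skipped = [12, 5, 12, 9, 5]

# set() deduplicates but makes no ordering promise.
unordered = list(set(skipped))

# dict.fromkeys keeps first-seen order (dicts preserve insertion order).
ordered = list(dict.fromkeys(skipped))
assert ordered == [12, 5, 9]
assert sorted(unordered) == sorted(ordered)
```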
1382-1396: Correct distributed synchronization for peer lists.

Broadcasting the peer lists from master to all ranks is necessary for distributed gather where each rank needs to know which peers to fetch from. The synchronization happens before gather is invoked (line 1534), ensuring all ranks have consistent peer information.
1507-1583: Well-structured distributed gather integration.

The distributed gather mode detection and handling is clean and consistent. Distributed gather returns partials for memory-efficient incremental processing, while sequential gather returns merged results for traditional processing. The logging appropriately handles both cases using the helper methods.
1618-1640: Excellent improvement over past review comments.

The past review comments flagged three separate merges (upload, overlap check, norms) creating unnecessary memory overhead. The current implementation addresses this well:

- Overlap check (line 1632): now handles partials directly, no merge needed ✅
- Norms computation (line 1840): now handles partials directly, no merge needed ✅
- Upload (lines 1618-1623): still requires merge for artifact format

By making the overlap check and norms partial-aware, you've eliminated two of the three merges. The remaining upload merge is acceptable since it's async, happens in the background, and artifacts legitimately need the merged format. This is actually better than the fix suggested in the past comments.
2494-2494: Clean integration of memory budget configuration.

Passing `gradient_memory_budget_mb` from hparams to `outer_step` enables memory-bounded incremental processing of partials, as described in the PR objectives. The `getattr` with a `None` default provides a safe fallback if the configuration isn't present.

src/tplr/comms.py (6)
1354-1387: LGTM!

The signature change to make `version` optional with explicit None handling maintains backward compatibility while providing flexibility for callers.
1931-1959: LGTM!The round-robin partitioning logic is correct and deterministic. Sorting the UIDs ensures consistent ordering across all ranks, and the slice notation correctly implements the round-robin assignment.
1960-2071: Verify parallel time tracking semantics.

The merge logic correctly uses `max(total_time, ...)` on Line 2044, which is appropriate for distributed parallel execution where partials are fetched simultaneously. However, ensure downstream consumers understand this represents wall-clock time, not cumulative processing time.

The state_dict merging logic (Lines 2028-2040) correctly handles both list and non-list param_data, and the device consolidation recursively moves all tensors to the target device.
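The recursive consolidation pattern can be illustrated with a torch-free sketch; `FakeTensor` is a stand-in for `torch.Tensor` so the example runs anywhere, and `move_recursive` is a hypothetical analogue of the PR's device-moving helper, not its actual code:

```python
class FakeTensor:
    """Stand-in for torch.Tensor: records a device and supports .to()."""
    def __init__(self, device="cuda:1"):
        self.device = device
    def to(self, device):
        return FakeTensor(device)

def move_recursive(obj, device):
    # Move tensor-like leaves, recursing through lists/tuples/dicts;
    # non-tensor leaves (strings, numbers) pass through unchanged.
    if hasattr(obj, "to"):
        return obj.to(device)
    if isinstance(obj, list):
        return [move_recursive(x, device) for x in obj]
    if isinstance(obj, tuple):
        return tuple(move_recursive(x, device) for x in obj)
    if isinstance(obj, dict):
        return {k: move_recursive(v, device) for k, v in obj.items()}
    return obj

nested = {"grads": [FakeTensor(), (FakeTensor(), "meta")]}
moved = move_recursive(nested, "cuda:0")
assert moved["grads"][0].device == "cuda:0"
assert moved["grads"][1][0].device == "cuda:0"
assert moved["grads"][1][1] == "meta"
```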
2072-2195: Verify all-gather synchronization semantics.

The distributed gather coordination is well-structured. Line 2153 correctly uses synchronous `all_gather_object` as a blocking barrier operation, ensuring all ranks have consistent partial results before proceeding.

Confirm that the `return_partials` flag behavior is correctly understood by callers: when `True`, each rank receives a list of all partials (not just its own), enabling incremental processing.
2255-2310: Robust distributed gather with fallback.

The fallback logic is well-designed:

- Tries distributed gather when configured and available
- Falls back to sequential on master rank if distributed fails or returns no gradients (Lines 2278-2299)
- Provides clear logging for debugging

The check on Line 2278 correctly handles both None results and empty partial lists, ensuring the fallback triggers when no gradients are collected.
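The fallback control flow described above can be reduced to a small sketch (the function name and gather callables are hypothetical; real code would catch the specific distributed errors it expects):

```python
def gather_with_fallback(distributed_gather, sequential_gather, use_distributed, is_master):
    # Try the distributed path first; fall back to a sequential gather on
    # the master rank when it fails or returns no gradients.
    if use_distributed:
        try:
            partials = distributed_gather()
        except RuntimeError:
            partials = None
        if partials:  # falsy for both None and an empty partial list
            return partials
    return sequential_gather() if is_master else None

# Empty partials trigger the sequential fallback on master.
assert gather_with_fallback(lambda: [], lambda: "merged", True, True) == "merged"
# Non-empty partials are returned as-is.
assert gather_with_fallback(lambda: ["p0"], lambda: "merged", True, True) == ["p0"]
# Distributed disabled: master goes straight to sequential.
assert gather_with_fallback(lambda: [], lambda: "merged", False, True) == "merged"
```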
2352-2456: Verify reserve peer handling for partials.

The reserve peer fallback correctly handles both partials and merged results:

- For partials (Lines 2383-2395): appends fallback as a new partial with proper rank assignment
- For merged results (Lines 2397-2412): merges state_dict lists and extends metadata

Lines 2417-2440 compute the aggregate success rate across all partials. Verify that this matches the expected semantics for downstream consumers who may expect per-partial success rates.
src/tplr/neurons.py (5)
226-293: LGTM!

The dispatcher pattern cleanly separates incremental and single-result processing paths. The conditional routing based on `isinstance(gather_result, list)` is clear, and the memory_budget_mb parameter is correctly passed only to the incremental path that uses it.
295-545: Thorough incremental processing with robust error recovery.

The implementation correctly handles:

- Deterministic ordering (Line 359): sorts partials by rank for reproducibility
- Multi-GPU tensor consolidation (Lines 391-407): critical fix that moves tensors from different GPU devices (cuda:0, cuda:1, etc.) to the target model device
- Duplicate prevention (Lines 375-381): tracks applied UIDs to skip duplicates
- L2 norm accumulation (Lines 431-435, 484-487): correctly accumulates squared norms and takes sqrt at the end
- Graceful degradation (Lines 465-479): catches exceptions, logs failures, continues processing remaining partials
- Memory management (Lines 451-464): aggressive cleanup with budget monitoring

The error recovery logs failures (Lines 473-477, 507-516) without breaking the entire batch, which is appropriate for production resilience.
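The bullet points above (rank-sorted ordering, duplicate skipping, squared-norm accumulation, per-partial error recovery) can be condensed into a toy sketch; the dict-based partials and the `apply_partials` helper are illustrative, not the PR's actual structures:

```python
import math

def apply_partials(partials):
    # Process partials in rank order, skip UIDs already applied,
    # accumulate squared norms, and survive a malformed partial.
    applied, failed, sq_norm = set(), [], 0.0
    for part in sorted(partials, key=lambda p: p["rank"]):
        try:
            for uid, norm in part["grads"]:
                if uid in applied:
                    continue  # duplicate prevention across partials
                applied.add(uid)
                sq_norm += norm ** 2
        except (KeyError, TypeError):
            failed.append(part.get("rank"))  # log-and-continue in spirit
    return applied, failed, math.sqrt(sq_norm)

parts = [
    {"rank": 1, "grads": [(2, 4.0), (1, 99.0)]},  # uid 1 is a duplicate here
    {"rank": 0, "grads": [(1, 3.0)]},
    {"rank": 2},  # malformed partial: no "grads" key
]
applied, failed, total = apply_partials(parts)
assert applied == {1, 2}
assert failed == [2]
assert abs(total - 5.0) < 1e-12
```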
547-801: LGTM!

The extraction of the original single-result logic into `_outer_step_single` is clean. The functionality remains unchanged, and the updated docstring correctly describes it as the "original single-result outer step implementation."
1539-1559: Correct partial vs. merged result handling.

The logic correctly distinguishes between partials and merged results:

- Partials (Lines 1540-1554): each UID maps to its partial's state_dict, enabling per-UID data extraction
- Merged (Lines 1556-1559): all UIDs share the same merged state_dict, which already contains aggregated gradients

The design maintains backward compatibility while supporting the new incremental partial processing path.
1662-1677: Verify pairwise comparison uses correct UID indices.

The comparison correctly uses `valid_uids_for_param` (Lines 1673-1675) to map local parameter indices back to the original UID indices in the global `uids` list. This ensures pair keys like `(uid_idx_i, uid_idx_j)` correctly identify peer pairs even when some peers are missing the parameter.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
scripts/cleanup_bucket.py (1)
33-33: Update the docstring to reflect all deleted prefixes.

The docstring mentions only three prefixes but the implementation now deletes six types of files. Apply this diff to update the docstring:

```diff
-    """Delete objects in the R2 bucket that start with 'checkpoint', 'gradient', or 'start_window'"""
+    """Delete objects in the R2 bucket that start with 'checkpoint', 'gradient', 'start_window', 'aggregator', 'debug', or 'peers'"""
```

neurons/validator.py (1)
4398-4407: Fix slash multiplier when gather success rate is low to preserve scores instead of zeroing them.

The slash_multiplier is applied multiplicatively to scores (line 4420: `self.final_scores[uid] *= slash_multiplier`). When success_rate is low, setting slash_multiplier to 0.0 wipes positive scores to zero. Since the goal is to avoid slashing when gather is unreliable, use 1.0 instead to preserve scores.

```diff
 if success_rate < self.hparams.gather_peers_slash_threshold:
-    slash_multiplier = 0.0  # Set to 0 if success rate is low
+    slash_multiplier = 1.0  # No slashing when gather reliability is low
 else:
```
♻️ Duplicate comments (3)
neurons/validator.py (1)
3850-3888: Fix vals/quant_params misalignment in partials path (can IndexError).

When a partial lacks quant_params (or lengths differ), all_vals and all_quant_params diverge, breaking maybe_dequantize_values. Pad/trim qparams per-partial to match appended vals.

```diff
-if vals is not None:
-    # vals is a list of tensors, one per peer in this partial
-    # Move all tensors (including nested) to target device
-    for v in vals:
-        all_vals.append(move_to_device_recursive(v))
-    # quant_params is also a list if present
-    if quant_params is not None:
-        for qp in quant_params:
-            all_quant_params.append(move_to_device_recursive(qp))
+if vals is not None:
+    num_vals = len(vals)
+    for v in vals:
+        all_vals.append(move_to_device_recursive(v))
+    # Align qparams count with vals
+    if quant_params is None:
+        all_quant_params.extend([None] * num_vals)
+    else:
+        qps = list(quant_params)
+        if len(qps) < num_vals:
+            qps = qps + [None] * (num_vals - len(qps))
+        elif len(qps) > num_vals:
+            qps = qps[:num_vals]
+        for qp in qps:
+            all_quant_params.append(move_to_device_recursive(qp))
@@
-vals_f32 = self.compressor.maybe_dequantize_values(
-    all_vals,
-    all_quant_params if all_quant_params else None,
-    target_device,
-)
+vals_f32 = self.compressor.maybe_dequantize_values(
+    all_vals, all_quant_params, target_device
+)
```
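The pad/trim rule in the diff can be isolated into a tiny helper; `align_qparams` is a hypothetical name used only for this sketch:

```python
def align_qparams(vals, qparams):
    # Keep quant params index-aligned with vals: pad missing entries
    # with None and trim any surplus.
    n = len(vals)
    if qparams is None:
        return [None] * n
    qps = list(qparams)
    if len(qps) < n:
        qps += [None] * (n - len(qps))
    return qps[:n]

assert align_qparams([1, 2, 3], None) == [None, None, None]
assert align_qparams([1, 2, 3], ["a"]) == ["a", None, None]
assert align_qparams([1, 2], ["a", "b", "c"]) == ["a", "b"]
```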
2181-2182: Move imports to module level.

Importing modules inside a function impacts performance and readability. The `time` module is already imported at line 26, and `dist_helper` should be imported at the module level for consistency. Apply this diff:

```diff
 from types import SimpleNamespace
 from typing import Any, Literal, cast

 import aiofiles
 import bittensor as bt
 import boto3
 import botocore
 import torch
 from aiobotocore.client import AioBaseClient
 from aiobotocore.session import get_session
 from boto3.s3.transfer import TransferConfig
 from botocore.exceptions import ClientError, ConnectionClosedError
 from tqdm import tqdm as std_tqdm

 import tplr
 from tplr.chain import ChainManager
 from tplr.compress import TopKCompressor, unpack_12bit_indices
 from tplr.config import BUCKET_SECRETS, client_config
 from tplr.schemas import Bucket, CommsGetResult
+from tplr.distributed import dist_helper
```

Then remove the import from lines 2181-2182:

```diff
-import time
-from tplr.distributed import dist_helper
 window = kwargs.get("window", None)
```

Based on learnings
2317-2317: Move dist_helper import to module level.

Similar to line 2182, `dist_helper` is imported locally here. This should be at module level for consistency and performance, as noted in the previous comment and past review.
🧹 Nitpick comments (8)
scripts/cleanup_bucket.py (1)
93-93: Verify that "training files" accurately describes all deleted prefixes.

The updated messages use the generic term "training files" instead of enumerating specific prefixes, which is more maintainable. However, ensure that all six prefixes ("checkpoint", "gradient", "start_window", "aggregator", "debug", "peers") are actually training-related files. Files with "debug" or "peers" prefixes may serve different purposes.

Also applies to: 114-114, 123-123
neurons/validator.py (4)
95-134: Deduplicate UIDs when aggregating across partials (affects success_rate).

Duplicate UIDs across partials will inflate counts. Dedup before computing metrics. Apply:

```diff
@@
-all_uids = []
+all_uids: list[int] = []
@@
-return all_uids
+return list(set(all_uids))
@@
-total_uids = len(Validator._get_uids_from_gather_result(gather_result))
-total_skipped = len(
+total_uids = len(set(Validator._get_uids_from_gather_result(gather_result)))
+total_skipped = len(set(
     Validator._get_skipped_uids_from_gather_result(gather_result)
-)
+))
```
138-202: Handle list-of-partials in offload_gather_results to actually save GPU RAM.

Currently a no-op for list inputs; extend to offload each partial's state_dict.

```diff
 def offload_gather_results(self, gather_result, *, log: bool = True) -> None:
@@
     if gather_result is None:
         return
@@
-    if getattr(gather_result, "_offloaded", None) == "cpu":
-        return
+    if isinstance(gather_result, list):
+        # Offload each partial independently
+        moved_total = 0
+        for part in gather_result:
+            if part is None or getattr(part, "_offloaded", None) == "cpu":
+                continue
+            moved_total += self._offload_one_namespace(part)
+        if moved_total and log:
+            tplr.log_with_context(
+                level="info",
+                message=f"[GatherOffload] partials → CPU pinned | moved ~{moved_total / 1e6:.1f} MB",
+                sync_window=self.sync_window,
+                current_window=self.current_window,
+            )
+        return
+    if getattr(gather_result, "_offloaded", None) == "cpu":
+        return
@@
     moved_bytes = 0
@@
     state = getattr(gather_result, "state_dict", None)
     if state is None:
         return
@@
     setattr(gather_result, "_offloaded", "cpu")
@@
     if log:
         tplr.log_with_context(
             level="info",
             message=(
                 "[GatherOffload] aggregator → CPU pinned | moved ~"
                 f"{moved_bytes / 1e6:.1f} MB in {time.time() - t0:.3f}s"
             ),
             sync_window=self.sync_window,
             current_window=self.current_window,
         )
+
+    def _offload_one_namespace(self, ns: SimpleNamespace) -> int:
+        moved_bytes = 0
+        state = getattr(ns, "state_dict", None)
+        if state is None:
+            return 0
+
+        def _to_pinned_cpu(obj):
+            nonlocal moved_bytes
+            if torch.is_tensor(obj):
+                if obj.device.type == "cpu":
+                    return obj.pin_memory() if not obj.is_pinned() else obj
+                dst = torch.empty_like(obj, device="cpu", pin_memory=True)
+                dst.copy_(obj, non_blocking=True)
+                moved_bytes += obj.element_size() * obj.numel()
+                return dst
+            if isinstance(obj, list):
+                return [_to_pinned_cpu(x) for x in obj]
+            if isinstance(obj, tuple):
+                return tuple(_to_pinned_cpu(x) for x in obj)
+            if isinstance(obj, dict):
+                return {k: _to_pinned_cpu(v) for k, v in obj.items()}
+            if isinstance(obj, SimpleNamespace):
+                for k, v in vars(obj).items():
+                    setattr(obj, k, _to_pinned_cpu(v))
+                return obj
+            return obj
+
+        if isinstance(state, SimpleNamespace):
+            _to_pinned_cpu(state)
+        elif isinstance(state, dict):
+            for k, v in list(state.items()):
+                state[k] = _to_pinned_cpu(v)
+        setattr(ns, "_offloaded", "cpu")
+        return moved_bytes
```
2490-2507: Standardize device argument to string for outer_step.

Align with Miner and Comms usage for predictability across loaders/mappers.

```diff
-            device=cast(str, self.device),
+            device=str(self.device),
```
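A side note on why `str()` is the right call here: `typing.cast` is purely a type-checker annotation and performs no runtime conversion. A minimal self-contained sketch (a stand-in class is used instead of `torch.device`):

```python
from typing import cast


class FakeDevice:
    """Stand-in for torch.device, purely for illustration."""

    def __str__(self) -> str:
        return "cuda:0"


device = FakeDevice()

# cast() only informs the type checker; the object passes
# through unchanged and is still a FakeDevice at runtime.
as_cast = cast(str, device)

# str() performs the actual conversion the callee expects.
as_str = str(device)

assert type(as_cast) is FakeDevice
assert as_str == "cuda:0" and type(as_str) is str
```

So a parameter annotated `device: str` would still receive a non-string object under `cast`; only `str()` guarantees the declared type at runtime.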
1539-1555: Use `str(self.device)` instead of `cast(str, self.device)` for consistency with Miner.

The `gather()` method at src/tplr/comms.py:1655 explicitly types the device parameter as `str`. Miner uses `str(self.device)` when calling the same method, making this a consistency issue. The `cast()` form is a type-checking construct and doesn't perform runtime conversion; `str()` is more explicit and Pythonic.

```diff
-            device=cast(str, self.device),
+            device=str(self.device),
```

src/tplr/neurons.py (3)
477-484: Memory budget check is informational only.

The memory budget check logs a warning but continues processing. This aligns with the docstring stating "not enforced yet, future work", but consider adding a configuration flag to make this check fail-fast when memory constraints are critical.
1571-1591: Optimize state_dict storage for partials.

For partials, the code stores `state_dict` references for each UID (lines 1582-1583), creating O(n) duplicate references where n is the total number of UIDs. Consider building a `uid_to_partial` mapping once upfront to avoid redundant storage and improve lookup performance in the unpacking phase (lines 1645-1654).

Apply this pattern:

```python
# Build mapping once upfront
uid_to_partial: dict[int, tuple[SimpleNamespace, int]] = {}
for partial in gather_result:
    if partial is not None:
        partial_uids = getattr(partial, "uids", [])
        for idx, uid in enumerate(partial_uids):
            uid_to_partial[uid] = (partial, idx)
uids = list(uid_to_partial.keys())

# Then in the unpacking loop (around line 1645):
partial_for_uid, uid_idx_in_partial = uid_to_partial.get(uids[i], (None, None))
if partial_for_uid is None:
    continue
```
1645-1667: Nested search is O(partials) per UID per parameter.

The code searches through all partials for each UID on each parameter (lines 1645-1654). With many partials, parameters, and UIDs, this becomes expensive. The suggested `uid_to_partial` mapping in the previous comment would eliminate this repeated search.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- neurons/miner.py (2 hunks)
- neurons/validator.py (11 hunks)
- scripts/cleanup_bucket.py (3 hunks)
- src/tplr/comms.py (10 hunks)
- src/tplr/neurons.py (8 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
neurons/miner.py (1)
src/tplr/comms.py (1)
- gather (1648-2010)
neurons/validator.py (6)
src/tplr/distributed.py (1)
- is_distributed (75-77)

src/tplr/logging.py (2)
- log_with_context (207-223)
- log_with_context (290-309)

neurons/base_node.py (1)
- query_block_timestamp (202-247)

src/tplr/comms.py (2)
- gather_with_reserve (2282-2547)
- merge_gather_results (2042-2151)

src/tplr/neurons.py (1)
- check_uid_index_overlap (1556-1769)

src/tplr/compress.py (1)
- maybe_dequantize_values (646-696)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
- Bucket (24-46)

src/tplr/logging.py (2)
- log_with_context (207-223)
- log_with_context (290-309)

src/tplr/distributed.py (2)
- is_distributed (75-77)
- all_gather_object (307-325)
src/tplr/neurons.py (1)
src/tplr/compress.py (6)
- compress (373-377)
- compress (379-383)
- compress (386-423)
- ChunkingTransformer (145-292)
- TopKCompressor (295-696)
- unpack_12bit_indices (102-142)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.11)
- GitHub Check: test (3.12)
🔇 Additional comments (19)
neurons/miner.py (3)
707-709: LGTM: clearer gather start log.

The updated sequential-gather log is helpful and low-cost.

728-739: LGTM: post-gather summary logging.

Concise success/skip breakdown improves ops visibility.

711-725: Original review comment is incorrect: expected_compressed_params IS initialized.

Miner inherits Trainer.init_model(), which is called at line 200 of Miner.__init__() with meta=True. Trainer.init_model() unconditionally executes `self.expected_compressed_params = self.get_expected_params()` at line 123, regardless of the meta parameter. The attribute is properly initialized before the gather() call at line 724, so no AttributeError will occur.

Likely an incorrect or invalid review comment.
neurons/validator.py (3)
1557-1593: LGTM: distributed/seq gather completion logs.

Good differentiation between partials vs merged with clear KPIs.

1386-1399: LGTM: broadcast peers to all ranks.

Keeps ranks aligned for distributed gather without extra RPCs.

1446-1460: LGTM: shared time window computation.

Consistent time_min/max across ranks; sensible fallback to current time.
src/tplr/neurons.py (3)
226-293: LGTM – Clean delegation pattern.

The refactoring to delegate based on input type (single vs. list of partials) is well-structured. All parameters are correctly forwarded, and the delegation logic is clear.

486-501: LGTM – Graceful error recovery.

The error handling approach is robust: logs the error, tracks failed partials with metadata, and continues processing remaining partials. This prevents a single bad partial from blocking the entire gradient application.

443-465: Verify approach to averaging parameter means across partials.

The code accumulates `param_means` from each partial as a list, then computes a simple arithmetic average (lines 512-516). This is mathematically correct only if all partials contribute parameters with equal weight. If partials can have different parameter distributions or varying numbers of elements per parameter, the mean should be weighted by parameter size. However, the fingerprint structure doesn't track per-parameter element counts across partials, making weighted averaging difficult to implement with the current design.

Clarify whether:
- Partials always include the complete parameter set (non-sparse)
- If sparse, whether weighted averaging by `total_elements` is needed or if simple average is semantically correct for your use case

src/tplr/comms.py (10)
477-517: LGTM – Conservative error handling prevents false slashing.

The implementation correctly uses HEAD requests to check existence without downloading. The conservative approach of returning `True` for non-404 errors (lines 511-516) is appropriate to prevent falsely penalizing miners due to transient network issues.
1395-1427: LGTM – Flexible version parameter.

The signature change from `version: str = tplr.__version__` to `version: str | None = None` with runtime defaulting (lines 1415-1416) is more flexible and maintains backward compatibility.
2015-2040: LGTM – Simple and deterministic partitioning.

The round-robin assignment with deterministic sorting ensures all ranks agree on the partition. The implementation is straightforward and correct.

2093-2121: LGTM – Device consolidation during merge is correct.

The merge logic correctly concatenates state_dict lists from all partials while moving tensors to the target device (lines 2115-2116). This device consolidation is essential for multi-GPU scenarios where different ranks may have tensors on different devices.
2132-2134: Success rate calculation is partial-aware.

The success rate correctly considers both `all_uids` (successful) and `all_skipped_uids` (failed) from all partials, giving an accurate overall success rate rather than just the rate from individual partials.
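The partial-aware rate reduces to a small pure function over the per-rank results; a sketch assuming each partial exposes `uids` and `skipped_uids` lists (attribute names follow the review text):

```python
from types import SimpleNamespace


def aggregate_success_rate(partials: list) -> float:
    """Overall rate across per-rank partials:
    successes / (successes + deduplicated skips)."""
    all_uids: set[int] = set()
    all_skipped: set[int] = set()
    for part in partials:
        if part is None:  # a failed rank contributes nothing
            continue
        all_uids.update(getattr(part, "uids", []))
        all_skipped.update(getattr(part, "skipped_uids", []))
    attempted = len(all_uids) + len(all_skipped)
    return len(all_uids) / attempted if attempted else 0.0


parts = [
    SimpleNamespace(uids=[1, 2], skipped_uids=[9]),
    None,
    SimpleNamespace(uids=[3], skipped_uids=[9]),  # duplicate skip deduped
]
assert aggregate_success_rate(parts) == 0.75  # 3 successes / 4 attempted
```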
2244-2277: LGTM – Return path logic handles both merged and partial modes.

The conditional return based on the `return_partials` flag is clear and correct. When returning partials (lines 2244-2260), rank metadata is added for downstream deterministic sorting. When merging (lines 2262-2277), the merge happens on all ranks for consistency, with appropriate logging on rank 0.

2340-2397: LGTM – Robust distributed gather with fallback.

The implementation correctly attempts distributed gather (lines 2352-2375), detects failures (lines 2362-2370), and falls back to sequential gather on the master rank (lines 2377-2387). The fallback logic handles both exceptions and empty results, providing resilience against distributed coordination issues.

2473-2502: LGTM – Reserve fallback correctly handles both merged and partial modes.

The code correctly appends reserve results as a new partial (lines 2473-2485) when using partials mode, and merges tensor lists directly (lines 2486-2502) for merged mode. The rank assignment for the new partial (line 2476) ensures deterministic ordering.

2507-2546: LGTM – Success rate computation is mode-aware.

For partials (lines 2507-2528), the code computes an aggregate success rate and propagates it to all partials. For merged results (lines 2534-2541), it computes a simple rate. Both approaches are correct for their respective modes.

1754-1803: LGTM – File existence check prevents false slashing.

The approach of checking S3 object existence to distinguish between network failures and missing uploads is sound. When a file exists but download fails, the miner is not penalized (lines 1765-1772, 1790-1796), correctly attributing the failure to network issues rather than miner negligence.
Actionable comments posted: 1
♻️ Duplicate comments (2)
src/tplr/comms.py (2)
2181-2182: Move imports to module level.

In-function imports are an anti-pattern. The `time` module is already imported at line 26, and `dist_helper` should be imported at the module level for consistency.

This issue was flagged in previous reviews. Apply this diff to move the import:

```diff
 from tplr.chain import ChainManager
 from tplr.compress import TopKCompressor, unpack_12bit_indices
 from tplr.config import BUCKET_SECRETS, client_config
+from tplr.distributed import dist_helper
 from tplr.schemas import Bucket, CommsGetResult
```

Then remove the import from line 2182:

```diff
-    import time
-    from tplr.distributed import dist_helper
-
     window = kwargs.get("window", None)
```

Based on learnings
2317-2317: Move import to module level.

Same issue as line 2182: `dist_helper` should be imported at the module level, not inside the function.

Remove this import after moving `dist_helper` to the module-level imports as suggested in the previous comment.

Based on learnings
🧹 Nitpick comments (1)
src/tplr/comms.py (1)
2412-2546: Complex but correct handling of partial vs merged results.

The code correctly handles both `return_partials=True` (returning a list) and `return_partials=False` (returning a merged namespace) throughout the reserve gather logic. The success rate calculations for both cases (lines 2507-2545) are mathematically correct.

The complexity is necessary for the dual-mode operation. Consider extracting the success rate calculation into a helper method if similar logic is needed elsewhere.
📜 Review details
📒 Files selected for processing (1)
src/tplr/comms.py (10 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
- Bucket (24-46)

src/tplr/logging.py (2)
- log_with_context (207-223)
- log_with_context (290-309)

src/tplr/distributed.py (2)
- is_distributed (75-77)
- all_gather_object (307-325)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.12)
- GitHub Check: test (3.11)
🔇 Additional comments (5)
src/tplr/comms.py (5)
477-517: LGTM: Conservative existence check implementation.

The method correctly uses HEAD requests to check object existence without downloading. The conservative approach of returning `True` on non-404 errors (lines 512-513, 515-516) is appropriate to prevent false slashing due to transient network issues.
1395-1416: LGTM: Proper None-handling for optional version parameter.

The signature change and None-handling logic (lines 1415-1416) correctly default to `tplr.__version__` when version is not provided, maintaining backward compatibility.
2015-2039: LGTM: Correct deterministic partitioning.

The round-robin assignment with sorted UIDs ensures deterministic, balanced distribution across ranks.

2340-2397: LGTM: Robust fallback mechanism for distributed gather.

The logic properly attempts distributed gather first, detects failures (lines 2362-2370), and falls back to sequential gather on the master rank (lines 2378-2387). The error handling and logging provide good observability.
1754-1803: Critical: Undefined attributes cause runtime failures.

Lines 1759 and 1784 reference `self.version`, which is never defined in the `Comms` class. Lines 1762 and 1787 call `self.peer_bucket(str(uid))`, which does not exist. These were flagged in previous reviews but remain unfixed.

Apply this diff:

```diff
             )
-            # Check if file exists on R2 to distinguish miner vs network failure
-            gradient_key = f"{key}-{window}-{uid}-v{self.version}.pt"
+            # Check if file exists on R2 to distinguish miner vs network failure
+            gradient_key = f"{key}-{window}-{uid}-v{tplr.__version__}.pt"
             file_exists = await self.s3_object_exists(
                 key=gradient_key,
-                bucket=self.peer_bucket(str(uid)),
+                bucket=self.commitments.get(int(uid)),
                 timeout=5,
             )
```

Also apply to lines 1784-1787:

```diff
         if response is None:
-            # Similar check for None responses
-            gradient_key = f"{key}-{window}-{uid}-v{self.version}.pt"
+            # Similar check for None responses
+            gradient_key = f"{key}-{window}-{uid}-v{tplr.__version__}.pt"
             file_exists = await self.s3_object_exists(
                 key=gradient_key,
-                bucket=self.peer_bucket(str(uid)),
+                bucket=self.commitments.get(int(uid)),
                 timeout=5,
             )
```

Likely an incorrect or invalid review comment.
Actionable comments posted: 1
🧹 Nitpick comments (2)
src/tplr/neurons.py (2)
328-329: Remove redundant imports.

Both `gc` and `torch.cuda` are already imported at the module level (lines 20 and 28). These redundant imports inside the function are unnecessary.

Apply this diff:

```diff
-    import gc
-    import torch.cuda
-
```
1584-1603: Consider clearer data structure for partial-to-UID mapping.

The current approach builds `all_state_dicts` by appending the same `partial.state_dict` reference multiple times (once per UID in that partial). While functionally correct, this creates a list where `all_state_dicts[i]` corresponds to `uids[i]`. However, the subsequent code (lines 1655-1663) then re-searches through all partials to find which partial each UID belongs to, which is inefficient and confusing.

Consider building a direct UID-to-partial mapping upfront:

```python
# Build UID -> partial mapping once
uid_to_partial = {}
if isinstance(gather_result, list):
    for partial in gather_result:
        if partial is not None:
            for uid in getattr(partial, "uids", []):
                uid_to_partial[uid] = partial
    uids = list(uid_to_partial.keys())
    all_state_dicts = [uid_to_partial[uid].state_dict for uid in uids]
else:
    uids = list(getattr(gather_result, "uids", []))
    all_state_dicts = [gather_result.state_dict] * len(uids)
```

This eliminates the need to search for partials later and makes the relationship explicit.
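An executable toy version of the suggested mapping, with stand-in partials, showing that later lookups become O(1):

```python
from types import SimpleNamespace

# Stand-in partials; real ones carry compressed gradient payloads.
gather_result = [
    SimpleNamespace(uids=[4, 7], state_dict={"w_vals": "rank0"}),
    None,  # a rank that returned nothing
    SimpleNamespace(uids=[2], state_dict={"w_vals": "rank2"}),
]

# Build the UID -> partial mapping once, O(total UIDs).
uid_to_partial = {}
for partial in gather_result:
    if partial is not None:
        for uid in getattr(partial, "uids", []):
            uid_to_partial[uid] = partial

uids = list(uid_to_partial.keys())
all_state_dicts = [uid_to_partial[uid].state_dict for uid in uids]

# Each later lookup is a dict hit instead of a scan over partials.
assert uids == [4, 7, 2]
assert uid_to_partial[7].state_dict["w_vals"] == "rank0"
assert len(all_state_dicts) == 3
```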
📜 Review details
📒 Files selected for processing (1)
src/tplr/neurons.py (8 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/tplr/neurons.py (3)
tests/test_comms.py (1)
- compress (946-948)

tests/test_prepare_gradient_dict.py (3)
- compress (18-27)
- compress (155-165)
- compress (242-253)

src/tplr/compress.py (6)
- compress (373-377)
- compress (379-383)
- compress (386-423)
- ChunkingTransformer (145-292)
- TopKCompressor (295-696)
- unpack_12bit_indices (102-142)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.11)
- GitHub Check: test (3.12)
🔇 Additional comments (3)
src/tplr/neurons.py (3)
225-292: LGTM: Clean dispatcher pattern.

The refactoring to split single and incremental processing paths improves code organization and makes the incremental processing logic easier to maintain.

591-844: LGTM: Fingerprint statistics properly integrated.

The addition of per-parameter fingerprint tracking (lines 742-747) correctly accumulates norms, means, and element counts for monitoring gradient statistics.

512-516: Verify mean averaging is correct across partials with varying element counts.

The code computes a simple average of means across partials, but if different partials have different numbers of elements for the same parameter, this should be a weighted average. For example, if partial 1 contributes 100 elements with mean=0.5 and partial 2 contributes 10 elements with mean=1.0, the combined mean should be (100*0.5 + 10*1.0)/110 ≈ 0.545, not (0.5 + 1.0)/2 = 0.75.
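The arithmetic above is easy to check directly; a short sketch contrasting the element-count-weighted mean with the simple average the current code computes:

```python
def weighted_mean(means_and_counts: list[tuple[float, int]]) -> float:
    """Combine per-partial means weighted by element count."""
    total = sum(count for _, count in means_and_counts)
    return sum(mean * count for mean, count in means_and_counts) / total


parts = [(0.5, 100), (1.0, 10)]  # the example from the comment above
combined = weighted_mean(parts)
naive = sum(m for m, _ in parts) / len(parts)

assert abs(combined - 60 / 110) < 1e-12  # (100*0.5 + 10*1.0) / 110
assert naive == 0.75  # simple average overweights the small partial
```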
Actionable comments posted: 1
♻️ Duplicate comments (4)
src/tplr/comms.py (3)
2197-2198: Move imports to module level.

The `time` module is already imported at line 26, and `dist_helper` should be imported at the module level for consistency with the rest of the file. In-function imports impact performance and readability.

Move the import to the top of the file:

```diff
 from types import SimpleNamespace
 from typing import Any, Literal, cast

 import aiofiles
 import bittensor as bt
 import boto3
 import botocore
 import torch
 from aiobotocore.client import AioBaseClient
 from aiobotocore.session import get_session
 from boto3.s3.transfer import TransferConfig
 from botocore.exceptions import ClientError, ConnectionClosedError
 from tqdm import tqdm as std_tqdm

 import tplr
 from tplr.chain import ChainManager
 from tplr.compress import TopKCompressor, unpack_12bit_indices
 from tplr.config import BUCKET_SECRETS, client_config
 from tplr.schemas import Bucket, CommsGetResult
+from tplr.distributed import dist_helper
```

Then remove the local import at line 2198.
2333-2333: Move dist_helper import to module level.

Another instance of the local `dist_helper` import. This should be consolidated with the module-level imports for consistency.

Use the same module-level import suggested in the previous comment and remove this local import.
2080-2081: Verify cuda:0 availability before defaulting.

Line 2081 defaults to `"cuda:0"` when CUDA is available, but doesn't verify that device 0 specifically exists. On multi-GPU systems, device 0 might not be available.

Use a safer default:

```diff
         # Determine target device for tensor consolidation
         if target_device is None:
-            target_device = "cuda:0" if torch.cuda.is_available() else "cpu"
+            if torch.cuda.is_available() and torch.cuda.device_count() > 0:
+                target_device = f"cuda:{torch.cuda.current_device()}"
+            else:
+                target_device = "cpu"
```
1644-1698: Optimize O(n×m) partial lookup in inner loop.

The nested loop at lines 1661-1665 searches through all partials for each UID, creating O(n×m) complexity where n is the number of UIDs and m is the number of partials. This issue was previously identified but remains unresolved.

As suggested in the previous review, build a UID-to-partial mapping once before the loop:

```python
# Before line 1644, add:
if isinstance(gather_result, list):
    uid_to_partial = {}
    for partial in gather_result:
        if partial is not None:
            for uid in getattr(partial, "uids", []):
                uid_to_partial[uid] = partial
```

Then replace lines 1644-1698 with direct lookups using this mapping, eliminating the inner search loop.
📜 Review details
📒 Files selected for processing (2)
- src/tplr/comms.py (10 hunks)
- src/tplr/neurons.py (8 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/tplr/neurons.py (1)
src/tplr/compress.py (6)
- compress (373-377)
- compress (379-383)
- compress (386-423)
- ChunkingTransformer (145-292)
- TopKCompressor (295-696)
- unpack_12bit_indices (102-142)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
- Bucket (24-46)

src/tplr/logging.py (2)
- log_with_context (207-223)
- log_with_context (290-309)

src/tplr/distributed.py (2)
- is_distributed (75-77)
- all_gather_object (307-325)
🔇 Additional comments (8)
src/tplr/neurons.py (4)
226-293: LGTM! Clean dispatch pattern.

The refactored `outer_step` function provides a clear separation between incremental and single-result processing paths. The documentation accurately describes the memory-minimizing approach and the conditional dispatch logic is straightforward.
451-456: Verify mean averaging is statistically correct.

The code averages parameter means across partials using a simple arithmetic mean. This is only correct if each partial contributes the same number of elements for each parameter. If partials have different numbers of UIDs or elements, the mean should be weighted by the number of elements each partial contributes.

Consider tracking element counts per parameter and computing a weighted average:

```diff
             for param_name, mean in partial_fingerprint.get(
                 "param_means", {}
             ).items():
                 if param_name not in combined_fingerprint["param_means"]:
-                    combined_fingerprint["param_means"][param_name] = []
-                combined_fingerprint["param_means"][param_name].append(mean)
+                    combined_fingerprint["param_means"][param_name] = {"sum": 0.0, "count": 0}
+                # Need to track element count from partial_fingerprint
+                elem_count = partial_fingerprint.get("total_elements", 1)
+                combined_fingerprint["param_means"][param_name]["sum"] += mean * elem_count
+                combined_fingerprint["param_means"][param_name]["count"] += elem_count
```

Then at finalization:

```diff
         # Average the means
         for param_name in combined_fingerprint["param_means"]:
-            means_list = combined_fingerprint["param_means"][param_name]
-            combined_fingerprint["param_means"][param_name] = sum(means_list) / len(
-                means_list
-            )
+            stats = combined_fingerprint["param_means"][param_name]
+            combined_fingerprint["param_means"][param_name] = stats["sum"] / stats["count"]
```
470-484: LGTM! Proper memory management.

The aggressive memory cleanup after each partial (including garbage collection, CUDA cache clearing, and synchronization) is appropriate for memory-bounded incremental processing. The budget check after cleanup ensures accurate memory measurements.

384-501: LGTM! Robust error recovery.

The error handling implements proper graceful degradation: failures are logged with context (rank, UIDs, error message), collected for final reporting, and processing continues with remaining partials. This ensures maximum gradient application even when some partials fail.
src/tplr/comms.py (4)
477-516: LGTM! Conservative existence check.

The `s3_object_exists` method correctly implements a conservative approach by returning `True` on non-404 errors (line 513). This prevents false slashing due to transient network issues, which is the right design choice for a production system.
1395-1416: LGTM! Proper version handling.

The updated `gradient_timestamp` method now accepts an optional `version` parameter with a sensible default to `tplr.__version__` when `None`. This addresses potential version-handling issues.
2032-2055: LGTM! Deterministic partitioning.

The round-robin UID partitioning with deterministic sorting (line 2052) ensures all ranks agree on the assignment and achieve balanced load distribution. The implementation is correct and efficient.

1754-1818: LGTM! Proper failure differentiation.

The enhanced error handling correctly distinguishes between network failures (file exists but download failed) and miner failures (file doesn't exist). This prevents false slashing due to transient network issues while still catching miners who didn't upload gradients.
Actionable comments posted: 1
🧹 Nitpick comments (2)
neurons/validator.py (2)
1392-1409: Consider adding error handling for peer broadcast.

The `dist.broadcast_object_list()` call lacks error handling. If the broadcast fails or times out, non-master ranks would retain `None` values for `self.comms.peers` and `self.comms.reserve_peers`, potentially causing issues in the subsequent distributed gather.

Consider wrapping the broadcast in a try-except block with a fallback:

```diff
         # Broadcast peers list from master to all ranks (needed for distributed gather)
         if dist_helper.world_size > 1 and dist_helper.is_distributed():
-            peers_list = (
-                [self.comms.peers, self.comms.reserve_peers]
-                if self.is_master
-                else [None, None]
-            )
-            dist.broadcast_object_list(peers_list, src=0)
-            self.comms.peers = peers_list[0]
-            self.comms.reserve_peers = peers_list[1]
+            try:
+                peers_list = (
+                    [self.comms.peers, self.comms.reserve_peers]
+                    if self.is_master
+                    else [None, None]
+                )
+                dist.broadcast_object_list(peers_list, src=0)
+
+                if peers_list[0] is None or peers_list[1] is None:
+                    raise RuntimeError("Broadcast resulted in None values")
+
+                self.comms.peers = peers_list[0]
+                self.comms.reserve_peers = peers_list[1]
+            except Exception as e:
+                tplr.log_with_context(
+                    level="error",
+                    message=f"Failed to broadcast peers: {e}",
+                    sync_window=self.sync_window,
+                    current_window=self.current_window,
+                )
+                # Skip this window if broadcast fails
+                continue
```
1565-1592: Simplify logging - isinstance check is redundant.

Since `return_partials=use_distributed_gather` (line 1551), distributed gather always returns a list (partials) and sequential gather always returns SimpleNamespace (merged). The `isinstance(gather_result, list)` check at line 1577 is therefore redundant within the `if use_distributed_gather` block; it will always be True.

Simplify the logging:

```diff
             # Log gather completion
             if use_distributed_gather:
                 if gather_result is not None:
-                    # Handle both partials (list) and merged result (SimpleNamespace)
-                    # Show quality metrics for both partials and merged results
                     uids = self._get_uids_from_gather_result(gather_result)
                     skipped_uids = self._get_skipped_uids_from_gather_result(
                         gather_result
                     )
                     success_rate = self._get_success_rate_from_gather_result(
                         gather_result
                     )
-                    if isinstance(gather_result, list):
-                        tplr.log_with_context(
-                            level="info",
-                            message=f"[Rank {dist_helper.rank}] Distributed gather complete: {len(uids)}/{len(self.comms.peers)} successful, "
-                            f"{len(skipped_uids)} skipped, success_rate={success_rate:.2%} across {len(gather_result)} partials",
-                            sync_window=self.sync_window,
-                            current_window=self.current_window,
-                        )
-                    else:
-                        tplr.log_with_context(
-                            level="info",
-                            message=f"[Rank {dist_helper.rank}] Distributed gather complete: {len(uids)}/{len(self.comms.peers)} successful, "
-                            f"{len(skipped_uids)} skipped, success_rate={success_rate:.2%}",
-                            sync_window=self.sync_window,
-                            current_window=self.current_window,
-                        )
+                    # Distributed gather always returns partials (list)
+                    tplr.log_with_context(
+                        level="info",
+                        message=f"[Rank {dist_helper.rank}] Distributed gather complete: {len(uids)}/{len(self.comms.peers)} successful, "
+                        f"{len(skipped_uids)} skipped, success_rate={success_rate:.2%} across {len(gather_result)} partials",
+                        sync_window=self.sync_window,
+                        current_window=self.current_window,
+                    )
```
📜 Review details
📒 Files selected for processing (1)
neurons/validator.py (12 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
neurons/validator.py (6)
src/tplr/distributed.py (1)
- is_distributed (75-77)

src/tplr/logging.py (2)
- log_with_context (207-223)
- log_with_context (290-309)

neurons/base_node.py (1)
- query_block_timestamp (202-247)

src/tplr/comms.py (2)
- gather_with_reserve (2298-2563)
- merge_gather_results (2058-2167)

src/tplr/neurons.py (1)
- check_uid_index_overlap (1570-1783)

src/tplr/compress.py (1)
- maybe_dequantize_values (646-696)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.12)
- GitHub Check: test (3.11)
🔇 Additional comments (5)
neurons/validator.py (5)
94-134: LGTM! Well-designed helper methods for dual format support.

The three static helper methods cleanly abstract the differences between partials (list) and merged results (SimpleNamespace). The success rate calculation correctly handles the total attempted count, and the deduplication of skipped_uids prevents double-counting across partials.

975-981: Good defensive programming to prevent rating errors.

Initializing OpenSkill ratings for all UIDs before calling `rate()` prevents errors when new peers are encountered. This ensures the rating update always has valid input.

1635-1650: Good memory optimization - single merge for upload only.

The code now performs only one merge (for upload artifacts, lines 1636-1640) and passes `gather_result` directly to `check_uid_index_overlap` (line 1649), which handles both partials and merged formats. This addresses the past review concern about multiple temporary merges creating unnecessary memory overhead.

2511-2513: LGTM! Proper integration of memory budget parameter.

Using `getattr` with a `None` default maintains backward compatibility with older hyperparameter files that don't include `gradient_memory_budget_mb`, while enabling the new memory-bounded incremental processing feature for configurations that do.

3820-3918: Edge case is properly handled - no changes required.

The refactored code correctly handles the empty vals edge case. When `compute_peer_val_norms` skips setting `clip_norm_dict[vals_key]` for a parameter (line 3882), the `update_model_with_gradient` function safely retrieves it via `.get(vals_key, eval_norm)` at line 3687. The fallback uses the evaluation's own norm as the clip threshold, which is a reasonable design: if no peer norms are available, clip based on the evaluation's own magnitude. The debug logging suggestion is optional and not necessary.
Actionable comments posted: 2
🧹 Nitpick comments (3)
src/tplr/comms.py (3)
1754-1820: Extract duplicate file existence check logic into a helper method.

Lines 1759-1789 and 1793-1819 contain nearly identical logic for checking file existence and determining whether to slash a miner. This includes duplicate gradient key construction, bucket retrieval, existence checks, and logging patterns.

Consider extracting this into a helper method:

```python
async def _should_skip_uid_for_missing_gradient(
    self, uid: int, key: str, window: int
) -> tuple[bool, str]:
    """
    Check if a UID should be skipped due to missing gradient.

    Returns:
        tuple[bool, str]: (should_skip, reason)
        - (True, "not_registered") if no bucket committed
        - (False, "network_issue") if file exists but wasn't retrieved
        - (True, "not_uploaded") if file doesn't exist
    """
    gradient_key = f"{key}-{window}-{uid}-v{tplr.__version__}.pt"
    peer_bucket = self.commitments.get(int(uid))
    if not peer_bucket:
        tplr.logger.info(f"Skipped UID {uid} - no bucket committed")
        return True, "not_registered"

    file_exists = await self.s3_object_exists(
        key=gradient_key, bucket=peer_bucket, timeout=5
    )
    if file_exists:
        tplr.log_with_context(
            level="warning",
            message=f"UID {uid}: File exists but not retrieved - network issue",
            current_window=window,
        )
        return False, "network_issue"
    else:
        tplr.logger.info(f"Skipped UID {uid} - file not found")
        return True, "not_uploaded"
```

Then replace both blocks with:

```python
should_skip, reason = await self._should_skip_uid_for_missing_gradient(uid, key, window)
if should_skip:
    skipped_uids.append(uid)
    continue
```
2080-2086: Device selection improved but could be more robust.

Line 2084 now uses `torch.cuda.current_device()` instead of hardcoded `cuda:0`, which addresses the previous review concern. However, `current_device()` can raise RuntimeError if called before any CUDA context is established.

Consider wrapping in try-except for maximum safety:

```diff
         if target_device is None:
             # Use the current device if CUDA is available, otherwise CPU
             if torch.cuda.is_available():
-                target_device = f"cuda:{torch.cuda.current_device()}"
+                try:
+                    target_device = f"cuda:{torch.cuda.current_device()}"
+                except RuntimeError:
+                    # No CUDA context yet, use device 0
+                    target_device = "cuda:0" if torch.cuda.device_count() > 0 else "cpu"
             else:
                 target_device = "cpu"
```
2482-2488: Reserve gather always uses sequential mode.

Line 2483 uses sequential `gather()` for reserve UIDs even when distributed mode is available. This is a reasonable design choice since reserve sets are typically smaller, but it should be documented in the method docstring or code comment.

Add a comment explaining the decision:

```python
# Use sequential gather for reserve (not distributed)
# Reserve sets are typically small, so the distributed overhead isn't worth it
fallback = await self.gather(
    my_uid=my_uid,
    uids=replacements,
    expected_compressed_params=expected_compressed_params,
    **kwargs,
)
```
📜 Review details
📒 Files selected for processing (1)
src/tplr/comms.py (11 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
`Bucket` (24-46)
src/tplr/logging.py (2)
`log_with_context` (207-223), `log_with_context` (290-309)
src/tplr/distributed.py (2)
`is_distributed` (75-77), `all_gather_object` (307-325)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.11)
- GitHub Check: test (3.12)
🔇 Additional comments (5)
src/tplr/comms.py (5)
49-49: Import moved to module level as recommended.

This addresses the previous review comment about in-function imports. The placement is appropriate with other module-level imports.
1396-1417: Properly handles version parameter with None default.

The signature change from `version: str = tplr.__version__` to `version: str | None = None` with runtime initialization correctly addresses the previous review feedback. This pattern avoids evaluation-time binding issues while maintaining backward compatibility.
2033-2057: Clean round-robin partitioning implementation.

The deterministic sorting followed by slice-based round-robin assignment (`sorted_uids[rank::world_size]`) is simple, efficient, and correct. Edge cases are properly handled.
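As a standalone illustration of why this is correct, the stride-slice pattern can be sketched like this (function name is illustrative, not the exact helper in `src/tplr/comms.py`):

```python
def partition_uids_across_ranks(uids, rank, world_size):
    """Deterministically assign each rank a disjoint round-robin slice of UIDs.

    Sorting first guarantees every rank sees the same order, so the
    stride-slice below yields non-overlapping, exhaustive subsets.
    """
    sorted_uids = sorted(uids)
    return sorted_uids[rank::world_size]
```

Every UID lands on exactly one rank and the union over all ranks recovers the full set, which is what makes the later all-gather merge lossless.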
2276-2278: Check for existing rank attribute before setting.

Lines 2276-2278 add rank metadata to partials without checking if the attribute already exists. While unlikely in practice, this could overwrite existing rank information if a partial was processed multiple times.
Add a guard:
```diff
 # Add rank metadata to each partial for deterministic sorting
 for idx, partial in enumerate(all_partials):
-    if partial is not None and not hasattr(partial, "rank"):
+    if partial is not None and not hasattr(partial, "rank"):
         partial.rank = idx
```

Actually, the code already has `and not hasattr(partial, "rank")`, so it is correct. No change needed.
2244-2260: Verify behavior when distributed setup fails.

If `world_size > 1` but `dist_helper.is_distributed()` returns False (e.g., distributed initialization failed), line 2260 returns `[partial_result]`, which contains only this rank's subset. This means other ranks' data is lost.

Consider adding a warning or error when this mismatch occurs:

```python
if world_size > 1 and dist_helper.is_distributed():
    # ... all-gather logic ...
else:
    if world_size > 1:
        tplr.logger.error(
            f"[DISTRIBUTED_GATHER] world_size={world_size} but distributed not initialized! "
            f"Only returning rank {rank} data - other ranks' data will be lost"
        )
    all_partials = [partial_result]
```

This would help catch configuration issues where distributed gathering is expected but the environment isn't properly set up.
Actionable comments posted: 1
♻️ Duplicate comments (2)
src/tplr/neurons.py (1)
557-557: Fix inverted budget_exceeded flag logic (still not addressed).

The logic remains inverted from the previous review: the flag is set to `0` when the budget IS exceeded and `1` when it's NOT exceeded. This contradicts the variable name and will cause incorrect metrics.

Apply this diff:

```diff
- budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+ budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0
```

The same issue exists at line 854 in `_outer_step_single`.

src/tplr/comms.py (1)
2395-2442: Master-only fallback still causes rank inconsistency (duplicate concern).

Lines 2433-2442 perform the sequential gather fallback only on the master rank when distributed gather fails. This leaves other ranks with `None` while master has data, causing rank divergence.

As flagged in past reviews, consider:

- Broadcasting master's fallback result to all ranks using `dist_helper.broadcast_object()`, or
- Having all ranks perform the sequential fallback (if network is shared), or
- Explicitly failing if distributed gather was required
The current implementation will cause downstream issues when non-master ranks expect valid gather results.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
src/tplr/comms.py (11 hunks)
src/tplr/neurons.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/tplr/neurons.py (1)
src/tplr/compress.py (6)
`compress` (373-377), `compress` (379-383), `compress` (386-423), `ChunkingTransformer` (145-292), `TopKCompressor` (295-696), `unpack_12bit_indices` (102-142)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
`Bucket` (24-46)
src/tplr/logging.py (2)
`log_with_context` (207-223), `log_with_context` (290-309)
src/tplr/distributed.py (2)
`is_distributed` (75-77), `all_gather_object` (307-325)
🔇 Additional comments (7)
src/tplr/neurons.py (2)
296-592: LGTM: Robust incremental partial processing.

The implementation correctly:
- Sorts partials deterministically by rank for reproducible processing
- Applies sequential device consolidation to handle multi-GPU tensors
- Implements graceful error recovery (logs and continues on partial failures)
- Aggregates statistics correctly (squared norms summed, then sqrt for L2)
- Performs aggressive memory cleanup after each partial
The error handling at lines 486-502 provides good resilience without breaking the entire gather on individual partial failures.
1620-1809: LGTM: Efficient partial-aware overlap checking.

The updates correctly support both merged and partial gather results:
- Lines 1625-1635: Creates O(1) UID-to-partial mapping (addresses past O(n×m) concern)
- Lines 1687-1707: Uses direct lookup instead of linear search through partials
- Lines 1737-1753: Correctly maps pair keys using original UID indices
The optimization from the past review is properly implemented.
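The O(1) lookup pattern described above can be sketched in isolation (the helper name is hypothetical; the real mapping lives inside `src/tplr/neurons.py`):

```python
from types import SimpleNamespace


def build_uid_to_partial(partials):
    """Build a dict mapping each UID to the partial that carries its data.

    One O(n) pass over the partials replaces a per-UID linear scan,
    turning the overall lookup cost from O(n*m) into O(n).
    """
    uid_to_partial = {}
    for partial in partials:
        if partial is None:
            continue
        for uid in getattr(partial, "uids", []):
            uid_to_partial[uid] = partial
    return uid_to_partial
```

Each subsequent per-UID access is then a single dictionary lookup instead of a scan through the partials list.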
src/tplr/comms.py (5)
478-555: LGTM: Conservative existence check with proper error classification.

The implementation correctly distinguishes between:

- Definitive absence (404) → `False`
- Configuration errors (403, NoSuchBucket) → `False` with ERROR-level logging
- Transient issues (timeouts, 5xx) → `True` (conservative to avoid false slashing)

The ERROR-level logging for config issues (lines 524-534) helps surface legitimate setup problems that need attention.
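The decision table above can be condensed into a small sketch; the function name and plain status codes are illustrative, not the real botocore error types used in `s3_object_exists`:

```python
def classify_existence_check(status_code=None, timed_out=False):
    """Classify an S3 HEAD result into (assume_exists, reason).

    Transient failures (timeouts, 5xx) conservatively report the object
    as existing so a miner is never slashed over a flaky network; 404 is
    a definitive absence; 403 is a config error surfaced at ERROR level
    but still treated as absent.
    """
    if timed_out or (status_code is not None and status_code >= 500):
        return True, "transient"      # conservative: don't slash on flaky network
    if status_code == 404:
        return False, "not_found"     # definitive absence
    if status_code == 403:
        return False, "config_error"  # logged at ERROR level in the real code
    return True, "exists"
```

The asymmetry is deliberate: a false "exists" only delays a slash, while a false "missing" would slash an honest miner.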
1433-1465: LGTM: Clean optional version parameter.

The signature update correctly:

- Uses `str | None` with default `None` instead of the problematic `self.version`
- Applies the runtime default `tplr.__version__` at lines 1453-1454 when `None` is passed

This addresses the past review concern about undefined `self.version`.
2066-2334: LGTM: Well-structured distributed gather with proper device handling.

The implementation addresses past review concerns and is production-ready:

- Lines 2118-2123: Uses `torch.cuda.current_device()` instead of hardcoded `cuda:0` (addresses past review)
- Lines 2070-2094: Deterministic round-robin partitioning ensures consistent UID assignment across ranks
- Lines 2151-2163: Recursive device movement handles nested structures correctly
- Lines 2283-2290: Proper all-gather synchronization with blocking semantics
The `return_partials` flag cleanly supports both merged and incremental processing modes.
1792-1857: LGTM: Improved error handling prevents false slashing.

The updated error handling correctly uses `s3_object_exists` to distinguish:
- Lines 1806-1818: File exists but download failed → network issue, don't slash
- Lines 1819-1826: File doesn't exist → miner didn't upload, legitimate slash
The same logic is applied for `None` responses at lines 1828-1857. This prevents false negatives due to transient network issues while still catching miners who don't upload.

The 5-second timeout at lines 1809 and 1842 is reasonable for existence checks.
2468-2600: LGTM: Clean unified handling of partials and merged results.

The implementation elegantly handles both modes:
- Lines 2468-2495: Aggregate metrics computed correctly for both partials and merged
- Lines 2531-2540: Fallback appended as new partial with rank metadata (supports incremental processing)
- Lines 2542-2557: Fallback merged in-place for merged results (supports traditional flow)
- Lines 2562-2583: Success rate computed correctly as aggregate across all partials
The conditional logic cleanly branches on `isinstance(primary, list)` without code duplication.
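The aggregate success-rate computation over partials reduces to a set union; a minimal sketch under that assumption (helper name hypothetical):

```python
from types import SimpleNamespace


def aggregate_success_rate(partials, requested_uids):
    """Fraction of requested UIDs gathered across all partials combined."""
    gathered = set()
    for partial in partials:
        if partial is not None:
            gathered.update(getattr(partial, "uids", []))
    if not requested_uids:
        return 0.0
    return len(gathered & set(requested_uids)) / len(requested_uids)
```

Using a set makes the rate robust to a UID appearing in more than one partial, which would otherwise inflate the count.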
Actionable comments posted: 0
♻️ Duplicate comments (2)
src/tplr/neurons.py (2)
558-558: CRITICAL: Fix inverted budget_exceeded flag.

This bug was flagged in past review comments but was NOT fixed. The flag logic is inverted:

- Currently: `budget_exceeded = 0` when memory exceeds the budget (WRONG)
- Should be: `budget_exceeded = 1` when memory exceeds the budget

This makes the WandB metric `outer_step/memory_budget_violation` report the opposite of reality.

Apply this fix:

```diff
- budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+ budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0
```

This also applies to line 855 in `_outer_step_single`.
849-878: CRITICAL: Same inverted budget_exceeded flag bug.

This is the same bug as line 558. The flag logic is inverted at line 855:

```python
budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
```

This makes the WandB metric `outer_step/memory_budget_violation` incorrect.

Apply this fix:

```diff
- budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+ budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0
```
🧹 Nitpick comments (2)
src/tplr/comms.py (1)
2118-2124: Add error handling for CUDA device operations.

The device consolidation uses `torch.cuda.current_device()`, which is an improvement over the previous `"cuda:0"` hardcoding, but there's no error handling if CUDA operations fail.

Consider wrapping the device determination in a try-except:

```diff
 if target_device is None:
     # Use the current device if CUDA is available, otherwise CPU
     if torch.cuda.is_available():
-        target_device = f"cuda:{torch.cuda.current_device()}"
+        try:
+            target_device = f"cuda:{torch.cuda.current_device()}"
+        except RuntimeError as e:
+            tplr.logger.warning(f"Failed to get current CUDA device: {e}, falling back to cpu")
+            target_device = "cpu"
     else:
         target_device = "cpu"
```

src/tplr/neurons.py (1)
329-330: Remove redundant imports.

Both the `gc` and `torch.cuda` imports inside the function are redundant:

- `gc` is already imported at module level (line 20)
- `torch` is already imported at module level (line 28), so `torch.cuda` is available

Remove these lines:

```diff
- import gc
-
- import torch.cuda
-
 # Track applied UIDs to prevent duplicates
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
src/tplr/comms.py (11 hunks)
src/tplr/neurons.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/tplr/neurons.py (1)
src/tplr/compress.py (5)
`compress` (373-377), `compress` (379-383), `compress` (386-423), `ChunkingTransformer` (145-292), `unpack_12bit_indices` (102-142)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
`Bucket` (24-46)
src/tplr/logging.py (2)
`log_with_context` (207-223), `log_with_context` (290-309)
src/tplr/distributed.py (2)
`is_distributed` (75-77), `all_gather_object` (307-325)
🔇 Additional comments (10)
src/tplr/comms.py (5)
478-555: LGTM - Conservative error handling prevents false slashing.

The implementation correctly distinguishes between transient failures (timeouts, 5xx errors) and permanent failures (404, 403, invalid bucket). The conservative approach of returning `True` for transient errors prevents false slashing of miners due to network issues.

The logging levels are appropriate: ERROR for configuration issues that need attention, WARNING for transient issues.
1792-1857: Verify latency impact of HEAD requests in high-failure scenarios.

The file existence checks prevent false slashing by distinguishing network failures from missing files, which is critical. However, each failed download now triggers an additional HEAD request (5s timeout).
In distributed gather scenarios with many UIDs and high network failure rates, this could add significant cumulative latency (e.g., 20 failed downloads = 100s of additional checks even if all return immediately).
Consider monitoring the impact in production and potentially batching HEAD requests if latency becomes an issue:
```python
# Potential optimization: batch HEAD requests
failed_uids_to_check = [(uid, gradient_key, peer_bucket) for ...]
results = await asyncio.gather(*[
    self.s3_object_exists(key=k, bucket=b, timeout=5)
    for _, k, b in failed_uids_to_check
])
```
2211-2333: LGTM - Well-structured distributed gather implementation.

The distributed gather correctly:
- Partitions UIDs deterministically using round-robin (ensures reproducibility)
- Processes subsets in parallel across ranks (achieves ~2x speedup as reported)
- Collects results via all-gather (all ranks get consistent view)
- Supports both merged and partial return modes for memory-bounded processing
The detailed per-rank logging will be valuable for debugging distributed issues.
2432-2461: LGTM - Consistent fallback ensures all ranks stay synchronized.

The fallback logic correctly ensures that all ranks fall back together when distributed gather fails, preventing rank inconsistency. The comment at line 2444 helpfully explains that sequential gather (R2 downloads) can be done independently by each rank without requiring coordination.
This addresses the past review concern about master-only fallback causing inconsistency.
2475-2609: LGTM - Consistent handling of merged and partial results.

The implementation correctly:

- Computes aggregate metrics across partials when `return_partials=True`
- Uses the traditional success rate calculation for merged results
- Returns `None` when no gradients were successfully gathered (for both cases)
- Always uses sequential gather for reserve (reasonable since the reserve set is smaller)
The logging clearly distinguishes between partials and merged results, which aids debugging.
src/tplr/neurons.py (5)
226-293: LGTM - Clean separation of incremental and single-result processing.

The expanded signature and branching logic clearly separate the incremental partial processing path from the original single-result path. The type hints accurately reflect the supported input types.
480-486: Memory budget is logged but never enforced.

The code tracks memory usage against `memory_budget_mb` and logs warnings when exceeded, but never enforces the budget (no early stopping, no rejection of additional partials). The comment at line 325 says "future: early stopping if exceeded" but the implementation only logs.

Clarify the intended behavior:
- Is this intentional (logging only for monitoring)?
- Should the code stop processing additional partials when budget is exceeded?
- Should there be a hard limit that fails the entire operation?
If enforcement is needed, consider:
```python
if current_memory_mb > memory_budget_mb:
    tplr.logger.error(
        f"Memory budget exceeded: {current_memory_mb:.0f}MB > {memory_budget_mb:.0f}MB. "
        f"Stopping partial processing at {partial_idx + 1}/{len(partials)}"
    )
    break  # Stop processing additional partials
```
487-503: Verify error recovery strategy for failed partials.

When a partial fails to process, the code logs the error and continues processing the remaining partials. This graceful degradation might be acceptable, but it could lead to incomplete gradient application.
Confirm that validators correctly handle partial application in weight updates. Specifically:
- Do validators track which UIDs had gradients successfully applied?
- Are miners penalized if their gradient was in a failed partial?
- Should the entire operation fail if any partial fails?
If complete application is required, consider collecting failures and raising after all partials are attempted:
```python
if failed_partials:
    raise RuntimeError(
        f"Failed to process {len(failed_partials)} partials. "
        f"Errors: {[fp['error'] for fp in failed_partials]}"
    )
```
1621-1755: LGTM - Efficient partials support with O(1) lookup optimization.

The implementation correctly:

- Builds a `uid_to_partial` mapping for O(1) lookup (addresses the past review about O(n×m) complexity)
- Handles both merged results (shared state_dict) and partials (per-UID state_dict)
- Uses original UID indices for consistent pair keys across comparisons
- Falls back gracefully when state_dicts are missing
The optimization reduces complexity from O(n×m) to O(n) as intended.
392-417: LGTM - Device consolidation is critical for multi-GPU support.

The recursive tensor movement to the target device is essential for multi-GPU setups where different ranks produce tensors on different devices (e.g., rank 0 on cuda:0, rank 1 on cuda:1). Without this consolidation, applying gradients would fail with device mismatch errors.
The comment at line 391 helpfully explains this requirement.
Actionable comments posted: 3
♻️ Duplicate comments (1)
src/tplr/neurons.py (1)
485-486: Fix torch.cuda APIs called with string device and correct memory budget flag.
- torch.cuda.memory_allocated/synchronize expect a device index or torch.device, not a str.
- budget_exceeded is inverted; should be 1 when budget is exceeded, 0 otherwise. (Duplicate of earlier note.)
Apply this diff:
```diff
@@
-    if memory_budget_mb is not None and torch.cuda.is_available():
-        initial_memory_mb = torch.cuda.memory_allocated(device) / (1024**2)
+    if memory_budget_mb is not None and torch.cuda.is_available():
+        dev_obj = torch.device(device)
+        initial_memory_mb = torch.cuda.memory_allocated(dev_obj) / (1024**2)
@@
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        torch.cuda.synchronize(device)
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize(torch.device(device))
@@
-    if memory_budget_mb is not None and torch.cuda.is_available():
-        final_memory_mb = torch.cuda.memory_allocated(device) / (1024**2)
+    if memory_budget_mb is not None and torch.cuda.is_available():
+        dev_obj = torch.device(device)
+        final_memory_mb = torch.cuda.memory_allocated(dev_obj) / (1024**2)
@@
-    budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+    budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0
@@
-    if on_src and memory_budget_mb is not None and torch.cuda.is_available():
-        final_memory_mb = torch.cuda.memory_allocated(device) / (1024**2)
+    if on_src and memory_budget_mb is not None and torch.cuda.is_available():
+        dev_obj = torch.device(device)
+        final_memory_mb = torch.cuda.memory_allocated(dev_obj) / (1024**2)
@@
-    budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+    budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0
```

Also applies to: 591-596, 681-689, 981-986
🧹 Nitpick comments (3)
src/tplr/comms.py (2)
1792-1857: Optional: parallelize R2 existence checks to avoid serial awaits on slow paths.

Both the exception and None branches await s3_object_exists per UID sequentially. Consider a small bounded-concurrency gather for these checks to reduce tail latency under widespread timeouts.
2096-2210: Default merge target_device to CPU to reduce OOM risk when merging large partials.

Merging moves all tensors to the current CUDA device by default; on large peer sets this can spike GPU memory. Prefer CPU by default; let callers opt in to CUDA explicitly.
Apply this diff:
```diff
-    if target_device is None:
-        # Use the current device if CUDA is available, otherwise CPU
-        if torch.cuda.is_available():
-            target_device = f"cuda:{torch.cuda.current_device()}"
-        else:
-            target_device = "cpu"
+    if target_device is None:
+        # Safer default: consolidate on CPU to avoid GPU OOM during merges.
+        target_device = "cpu"
```

src/tplr/neurons.py (1)
388-396: Move function-scope import to module level to avoid repeated import and potential cycles.

The `Comms` import inside merge_partial_batch is unnecessary. Import `Comms` at top level for consistency/perf (unless this was added to break a cycle).
Apply this diff:
```diff
- from tplr.comms import Comms
+ # at module top:
+ # from tplr.comms import Comms
```

And remove the in-function import.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
hparams/2B.json (1 hunks)
neurons/validator.py (13 hunks)
src/tplr/comms.py (11 hunks)
src/tplr/neurons.py (9 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- neurons/validator.py
- hparams/2B.json
🧰 Additional context used
🧬 Code graph analysis (2)
src/tplr/neurons.py (2)
src/tplr/compress.py (4)
`compress` (373-377), `compress` (379-383), `compress` (386-423), `unpack_12bit_indices` (102-142)
src/tplr/comms.py (2)
`merge_gather_results` (2096-2209), `get` (1467-1582)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
`Bucket` (24-46)
src/tplr/logging.py (2)
`log_with_context` (207-223), `log_with_context` (290-309)
src/tplr/distributed.py (2)
`is_distributed` (75-77), `all_gather_object` (307-325)
🔇 Additional comments (1)
src/tplr/comms.py (1)
2432-2451: Good fix: all-ranks sequential fallback preserves consistency.

Sequential fallback no longer runs only on master; all ranks perform the same path. This avoids rank divergence.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
neurons/miner.py (1)
876-882: Fix AttributeError when gather_result is a list of partials.

When distributed gather returns a list of partials, there's no top-level `skipped_uids` attribute. Accessing `gather_result.skipped_uids` at lines 879 and 881 will raise an `AttributeError`.

Apply this diff to handle both merged and partial results:

```diff
 # Add successful peers information
 if gather_result is not None:
-    debug_dict["successful_peers"] = sorted(
-        list(set(self.comms.peers) - set(gather_result.skipped_uids))
-    )
-    debug_dict["skipped_peers"] = sorted(
-        list(gather_result.skipped_uids)
-    )
+    # Handle both merged results and list of partials
+    if isinstance(gather_result, list):
+        # For partials, collect all UIDs that were successfully gathered
+        successful_uids = set()
+        for partial in gather_result:
+            if partial is not None:
+                successful_uids.update(getattr(partial, "uids", []))
+        debug_dict["successful_peers"] = sorted(list(successful_uids))
+        debug_dict["skipped_peers"] = sorted(list(set(self.comms.peers) - successful_uids))
+    else:
+        # Merged result has skipped_uids attribute
+        debug_dict["successful_peers"] = sorted(
+            list(set(self.comms.peers) - set(gather_result.skipped_uids))
+        )
+        debug_dict["skipped_peers"] = sorted(
+            list(gather_result.skipped_uids)
+        )
```
♻️ Duplicate comments (2)
src/tplr/neurons.py (2)
686-686: Fix inverted budget_exceeded flag logic.

The `budget_exceeded` flag is inverted: it's set to `0` when memory exceeds the budget and `1` when within budget. This will cause incorrect metrics and monitoring.

Apply this diff:

```diff
- budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+ budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0
```

Based on learnings from past review comments (line 686 was flagged previously).
977-1006: Fix inverted budget_exceeded flag in single-result path.

The `budget_exceeded` flag has the same inversion issue as the incremental path (line 686): it's `0` when exceeded and `1` when within budget.

Apply this diff:

```diff
- budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+ budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0
```

Based on learnings from past review comments (line 854 in the old numbering).
🧹 Nitpick comments (1)
ecosystem.config.js (1)
5-11: Remove unused RANDOM_SUFFIX generation.

The `RANDOM_SUFFIX` is generated but no longer used after switching to a static `PROJECT_NAME`. Consider removing lines 5-9 to eliminate dead code.

Apply this diff:

```diff
-const { execSync } = require('child_process');
-const RANDOM_SUFFIX = execSync(
-  "cat /dev/urandom | tr -dc 'a-z0-9' | fold -w 4 | head -n 1"
-)
-  .toString()
-  .trim();
-
 const PROJECT_NAME = `distribute-gather`;
```

Additionally, verify whether the static project name could cause conflicts if multiple test environments run concurrently.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
ecosystem.config.js (6 hunks)
hparams/hparams.json (1 hunks)
neurons/miner.py (1 hunks)
neurons/trainer.py (1 hunks)
src/tplr/model_factory.py (1 hunks)
src/tplr/neurons.py (9 hunks)
✅ Files skipped from review due to trivial changes (1)
- hparams/hparams.json
🧰 Additional context used
🧬 Code graph analysis (2)
neurons/miner.py (3)
src/tplr/distributed.py (1)
`is_distributed` (75-77)
src/tplr/logging.py (1)
`T` (41-48)
src/tplr/comms.py (2)
`gather_with_reserve` (2338-2610), `gather` (1686-2064)
src/tplr/neurons.py (2)
src/tplr/compress.py (5)
`compress` (373-377), `compress` (379-383), `compress` (386-423), `TopKCompressor` (295-696), `unpack_12bit_indices` (102-142)
src/tplr/comms.py (2)
`merge_gather_results` (2096-2209), `get` (1467-1582)
🔇 Additional comments (12)
src/tplr/model_factory.py (1)
256-256: LGTM: Validator dp_shard constraint relaxation.

Changing `dp_shard` from 4 to 1 for validators removes divisibility requirements, allowing validators to run with any world size. This aligns with the distributed gather enhancements in this PR.
1004-1009: LGTM: Proper integration of memory budgeting parameters.

The trainer correctly forwards the new `memory_budget_mb` and `batch_size_override` parameters to `outer_step`, using safe `getattr` with `None` defaults for backward compatibility.
705-762: LGTM: Distributed gather implementation.

The distributed gather path is well-structured with proper rank logging, error handling, and fallback to sequential mode on failure. The use of `return_partials=True` enables memory-bounded incremental processing.
22-22: LGTM: Time import added for batch timing.

The `time` module import is used for batch processing timing measurements in the incremental path.
226-297: LGTM: Clean dispatch pattern for incremental vs single-result processing.

The refactored `outer_step` properly dispatches to incremental or single-result implementations based on whether `gather_result` is a list. The extended signature with `memory_budget_mb` and `batch_size_override` enables memory-bounded processing.
300-396: LGTM: Helper functions for memory-bounded batch processing.

The helper functions provide reasonable heuristics for memory estimation (~40MB per UID) and batch size calculation with a 0.7 safety factor. The merge helper correctly delegates to existing merge logic.
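Under those stated heuristics (~40MB per UID, 0.7 safety factor), the calculation reduces to the following sketch; the constants mirror the review's numbers and are assumptions, not measured values:

```python
def estimate_memory_mb(num_uids, mb_per_uid=40.0):
    """Heuristic footprint of a partial holding num_uids gradients."""
    return num_uids * mb_per_uid


def batch_size_for_budget(memory_budget_mb, mb_per_uid=40.0, safety_factor=0.7):
    """How many UIDs fit in the budget, keeping 30% headroom; at least 1."""
    usable_mb = memory_budget_mb * safety_factor
    return max(1, int(usable_mb // mb_per_uid))
```

The `max(1, ...)` floor ensures progress even when the budget is smaller than a single estimated gradient.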
516-627: LGTM: Robust batch processing with graceful error recovery.

The batch processing loop properly:
- Checks for duplicate UIDs
- Handles exceptions gracefully without breaking the loop
- Performs aggressive memory cleanup after each batch
- Tracks failed batches for comprehensive error reporting
643-720: LGTM: Proper statistics aggregation across partials.

The fingerprint finalization correctly:
- Computes final norms via sqrt of accumulated squared norms
- Averages means across batches
- Logs comprehensive summaries including failed batches
- Integrates memory metrics with WandB
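The norm finalization above works because squared L2 norms are additive across disjoint batches; a one-function sketch of that identity:

```python
import math


def finalize_l2_norm(batch_squared_norms):
    """Sum per-batch squared norms, take the sqrt once at the end.

    For disjoint slices x_1..x_k of a vector x,
    ||x||^2 = sum_i ||x_i||^2, so deferring the sqrt is exact,
    not an approximation.
    """
    return math.sqrt(sum(batch_squared_norms))
```

Summing raw (un-squared) norms instead would overestimate the true L2 norm, which is why the accumulation keeps squared values.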
1749-1883: LGTM: Efficient partial support in overlap checking.

The refactored overlap logic properly:
- Handles both merged and partial results without unnecessary merging
- Uses an O(1) `uid_to_partial` mapping to avoid O(n×m) complexity
- Maintains original UID indices for consistent pair keys
- Correctly unpacks 12-bit indices per UID
The O(n×m) optimization from past reviews is properly implemented.
ecosystem.config.js (3)
20-20: Verify virtual environment path across all deployment targets.

The script path now hardcodes `.venv/bin/torchrun` (also on lines 44, 68, 94). Ensure this path exists and is consistent across all environments where this configuration will be deployed.
64-87: LGTM!

The new TM3 miner app is properly configured and follows the same structure as TM1 and TM2, which aligns with the PR's goal of supporting more miners.
37-111: Remove unsubstantiated GPU testing claim; verify the target environment has sufficient GPUs.

The ecosystem.config.js configuration allocates 7 GPUs (TM1: 0,1; TM2: 2,3; TM3: 4,5; TV1: 6). While this GPU requirement is accurately reflected in the code, the PR commit message does not mention testing with any specific GPU count. Ensure your target deployment environment has at least 7 GPUs available, or adjust the configuration to match your hardware.
Likely an incorrect or invalid review comment.
Actionable comments posted: 2
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
hparams/2B.json (1 hunks)
neurons/miner.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- hparams/2B.json
🧰 Additional context used
🧬 Code graph analysis (1)
neurons/miner.py (3)
src/tplr/distributed.py (1)
`is_distributed` (75-77)
src/tplr/logging.py (1)
`T` (41-48)
src/tplr/comms.py (2)
`gather_with_reserve` (2338-2610), `gather` (1686-2064)
🔇 Additional comments (2)
neurons/miner.py (2)
705-710: LGTM: Robust distributed-mode detection.

The three-way check (config flag, world_size, and runtime state) is defensive and correct.
803-805: LGTM: Proper synchronization.

Broadcasting `should_update` ensures all ranks agree before proceeding to `outer_step`, preventing divergence.
Actionable comments posted: 0
♻️ Duplicate comments (3)
neurons/miner.py (1)
935-1045: Fix distributed gather_result attribute access.

When `return_partials=True`, `gather_result` is a `list[SimpleNamespace]`. The blocks starting at lines 937 and 978 still treat it as a merged namespace (`gather_result.skipped_uids`, `gather_result.success_rate`, etc.), so the first distributed gather will raise `AttributeError` on every rank. Merge the partials (e.g., via `self.comms.merge_gather_results(...)`) or add safe accessors before computing debug/metrics fields.
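One possible safe-accessor shape (illustrative only, not the project's API) that tolerates both a merged namespace and a list of partials:

```python
from types import SimpleNamespace


def get_skipped_uids(gather_result):
    """Collect skipped UIDs whether gather_result is merged or a partials list."""
    if gather_result is None:
        return []
    if isinstance(gather_result, list):
        skipped = set()
        for partial in gather_result:
            if partial is not None:
                skipped.update(getattr(partial, "skipped_uids", []))
        return sorted(skipped)
    return sorted(getattr(gather_result, "skipped_uids", []))
```

Centralizing the `isinstance` branch in one accessor keeps the debug/metrics code free of repeated shape checks.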
686-686: Inverted budget_exceeded flag persists (duplicate issue).

This line still has the inverted logic flagged in previous reviews: the flag is set to `0` when the budget is exceeded and `1` when within budget.

The same inverted flag is then logged to WandB at line 714, propagating incorrect metrics.
983-983: Inverted budget_exceeded flag persists (duplicate issue).

This line has the same inverted logic flagged in previous reviews.
The inverted flag is then logged to WandB at line 1002.
🧹 Nitpick comments (1)
src/tplr/neurons.py (1)
315-318: Consider making the memory estimation heuristic configurable.

The hard-coded 40MB per UID estimate may not fit all model sizes. Consider adding a hyperparameter like `gradient_mb_per_uid` to allow tuning without code changes.

Apply this diff to make it configurable:

```diff
 def _estimate_partial_memory_mb(partial: SimpleNamespace) -> float:
     """
     Estimate memory usage of a partial result in MB.

-    This is a heuristic based on the number of UIDs and typical gradient size.
+    This is a heuristic based on the number of UIDs and typical gradient size.
+    The estimate can be tuned via hparams.gradient_mb_per_uid if needed.

     Args:
         partial: Partial result from distributed gather

     Returns:
         Estimated memory in MB
     """
     if partial is None or not hasattr(partial, "uids"):
         return 0.0

-    # Rough heuristic: ~30-50MB per UID for typical gradients
-    # This can be tuned based on actual measurements
     num_uids = len(getattr(partial, "uids", []))
-    avg_mb_per_uid = 40.0  # Conservative estimate
+    # Default 40MB, can be overridden via hparams if needed
+    avg_mb_per_uid = 40.0
     return num_uids * avg_mb_per_uid
```

Then pass it as a parameter if the hparam exists.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- hparams/2B.json (1 hunks)
- neurons/miner.py (2 hunks)
- neurons/trainer.py (1 hunks)
- src/tplr/comms.py (13 hunks)
- src/tplr/neurons.py (9 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- neurons/trainer.py
- hparams/2B.json
🧰 Additional context used
🧬 Code graph analysis (3)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
- `Bucket` (24-46)

src/tplr/logging.py (2)
- `log_with_context` (207-223)
- `log_with_context` (290-309)

src/tplr/distributed.py (2)
- `is_distributed` (75-77)
- `all_gather_object` (307-325)
src/tplr/neurons.py (2)
src/tplr/compress.py (6)
- `compress` (373-377)
- `compress` (379-383)
- `compress` (386-423)
- `ChunkingTransformer` (145-292)
- `TopKCompressor` (295-696)
- `unpack_12bit_indices` (102-142)

src/tplr/comms.py (2)
- `merge_gather_results` (2102-2215)
- `get` (1473-1588)
neurons/miner.py (4)
src/tplr/chain.py (1)
- `get_commitments` (264-322)

src/tplr/logging.py (1)
- `T` (41-48)

src/tplr/distributed.py (2)
- `is_distributed` (75-77)
- `broadcast` (273-282)

src/tplr/comms.py (2)
- `gather_with_reserve` (2344-2616)
- `gather` (1692-2070)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.12)
- GitHub Check: test (3.11)
🔇 Additional comments (1)
src/tplr/neurons.py (1)
1749-1846: LGTM: Efficient partial handling implementation. The O(1) `uid_to_partial` lookup (lines 1754-1764, 1822) successfully addresses the O(n×m) complexity concern from previous reviews. The dual-path logic cleanly handles both partials and merged results.
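The lookup pattern being praised here is a standard one-pass index build; a minimal runnable sketch (the `SimpleNamespace` partials and the helper name are illustrative, not the repo's actual code):

```python
from types import SimpleNamespace

def build_uid_to_partial(partials):
    """Map each UID to the partial result that carries its gradient.

    One pass over the partials gives O(1) lookups afterwards, instead of
    scanning every partial for every UID (the O(n*m) pattern flagged earlier).
    """
    uid_to_partial = {}
    for partial in partials:
        if partial is None:  # ranks that gathered zero peers return None
            continue
        for uid in getattr(partial, "uids", []):
            uid_to_partial[uid] = partial
    return uid_to_partial

# Hypothetical partials as produced by a distributed gather
p0 = SimpleNamespace(uids=[1, 5], rank=0)
p1 = SimpleNamespace(uids=[2, 9], rank=1)
lookup = build_uid_to_partial([p0, None, p1])
assert lookup[9] is p1 and lookup[1] is p0
```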
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- hparams/hparams.json (1 hunks)
- neurons/base_node.py (1 hunks)
- neurons/miner.py (5 hunks)
- src/tplr/comms.py (13 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
neurons/miner.py (4)
src/tplr/chain.py (1)
- `get_commitments` (264-322)

src/tplr/logging.py (1)
- `T` (41-48)

src/tplr/distributed.py (2)
- `is_distributed` (75-77)
- `broadcast` (273-282)

src/tplr/comms.py (2)
- `gather_with_reserve` (2369-2641)
- `gather` (1714-2092)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
- `Bucket` (24-46)

src/tplr/logging.py (2)
- `log_with_context` (207-223)
- `log_with_context` (290-309)

src/tplr/distributed.py (2)
- `is_distributed` (75-77)
- `all_gather_object` (307-325)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.11)
- GitHub Check: test (3.12)
🔇 Additional comments (2)
hparams/hparams.json (1)
22-22: Clarify the intent of changing the default model size from 70B to 2B. This change updates the default model configuration from 70B to 2B, which is a significant architectural change. The PR also introduces a new `hparams/2B.json` file. Please clarify:
- Is this change intentional?
- Should the default remain 70B with 2B as an alternative configuration file (rather than replacing the default)?
- Will this break existing workflows, training jobs, or downstream code that assumes the 70B model?
- Are there backward compatibility concerns?
neurons/base_node.py (1)
61-62: Verified: Cooperative multiple inheritance correctly implemented. The addition of `super().__init__()` at the end of `BaseNode.__init__()` properly completes the MRO chain for multiple inheritance scenarios. Verification confirms:

- `Miner(BaseNode, Trainer)` calls `super().__init__()` at its end (line 197)
- `Validator(BaseNode, Trainer)` calls `super().__init__()` at its end (line 377)
- `BaseNode` now calls `super().__init__()` at line 62, delegating to `Trainer`
- `Trainer` intentionally stops the chain (no `super()` call), serving as the terminus

This follows the proper cooperative multiple inheritance pattern in Python and integrates correctly with the existing subclass implementations.
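The chain described above can be sketched in isolation; the class bodies below are stand-ins for the real ones, keeping only the `super().__init__()` structure the review verifies:

```python
class Trainer:
    """Terminus of the chain: intentionally does NOT call super().__init__()."""
    def __init__(self):
        self.trainer_ready = True

class BaseNode:
    def __init__(self):
        self.node_ready = True
        super().__init__()  # delegates to the next class in the MRO (Trainer here)

class Miner(BaseNode, Trainer):
    def __init__(self):
        self.config = "miner"
        super().__init__()  # starts the cooperative chain: BaseNode -> Trainer

m = Miner()
assert m.node_ready and m.trainer_ready  # both parent initializers ran
```

Because `Miner`'s MRO is `Miner -> BaseNode -> Trainer -> object`, dropping the `super().__init__()` in `BaseNode` would silently skip `Trainer.__init__`, which is exactly the bug class the verification rules out.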
Actionable comments posted: 0
🧹 Nitpick comments (1)
neurons/miner.py (1)
707-745: LGTM: Sequential gather correctly addresses previous critical issues. The sequential gather implementation properly resolves the issues flagged in previous reviews:

- Returns a single `SimpleNamespace` instead of a list of partials, eliminating `AttributeError` risks
- Includes `local=False` and `stale_retention=100` to pull from remote storage (R2)
- All attribute accesses on `gather_result` are properly guarded with `None` checks
- Appropriate master-only execution with broadcast to other ranks
Optional: Defensive coding for `expected_compressed_params`. While the past review indicates the initialization issue was fixed via MRO chain repair, consider defensive coding to guard against a potential `AttributeError`:

```diff
     time_min=time_min,
     time_max=time_max,
-    expected_compressed_params=self.expected_compressed_params,
+    expected_compressed_params=getattr(self, 'expected_compressed_params', None),
 )
```

This ensures the code won't crash if the attribute is somehow missing, since `gather()` handles `None` gracefully (defaults to an empty set).
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- neurons/miner.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
neurons/miner.py (3)
src/tplr/chain.py (1)
- `get_commitments` (264-322)

src/tplr/logging.py (1)
- `T` (41-48)

src/tplr/comms.py (1)
- `gather` (1714-2092)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.12)
- GitHub Check: test (3.11)
🔇 Additional comments (1)
neurons/miner.py (1)
468-469: LGTM: Commitment refresh ensures up-to-date bucket info. Refreshing commitments after updating peers is the correct approach to ensure bucket information is current before the gather operation.
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
neurons/validator.py (1)
3911-3933: Indent merged-result branch (syntax error). `state_dict` and the subsequent lines must be inside the `else` block.

```diff
@@
-    else:
-    # Merged result: original logic
-    state_dict = gather_result.state_dict
-    if not state_dict:
-        raise ValueError("Must have gather_result.state_dict to compute norms")
+    else:
+        # Merged result: original logic
+        state_dict = gather_result.state_dict
+        if not state_dict:
+            raise ValueError("Must have gather_result.state_dict to compute norms")
@@
-    for n, p in self.model.named_parameters():
+        for n, p in self.model.named_parameters():
@@
-        vals = getattr(state_dict, vals_key, None)
-        quant_params = getattr(state_dict, quant_key, None)
+            vals = getattr(state_dict, vals_key, None)
+            quant_params = getattr(state_dict, quant_key, None)
@@
-        if vals is None:
-            continue
+            if vals is None:
+                continue
@@
-        vals_f32 = self.compressor.maybe_dequantize_values(
-            vals, quant_params, p.device
-        )
+            vals_f32 = self.compressor.maybe_dequantize_values(
+                vals, quant_params, p.device
+            )
@@
-        norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
-        clip_norm_dict[vals_key] = torch.median(norms)
+            norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
+            clip_norm_dict[vals_key] = torch.median(norms)
```
♻️ Duplicate comments (4)
neurons/validator.py (1)
3837-3910: Align quant_params with vals across partials (prevents IndexError). If a partial has vals but missing/short quant_params, later dequantization can index past the end.

```diff
@@
-                    if vals is not None:
-                        # vals is a list of tensors, one per peer in this partial
-                        # Move all tensors (including nested) to target device
-                        for v in vals:
-                            all_vals.append(move_to_device_recursive(v))
-                        # quant_params is also a list if present
-                        if quant_params is not None:
-                            for qp in quant_params:
-                                all_quant_params.append(move_to_device_recursive(qp))
+                    if vals is not None:
+                        num_vals = len(vals)
+                        # Move all tensors (including nested) to target device
+                        for v in vals:
+                            all_vals.append(move_to_device_recursive(v))
+                        # Ensure quant_params list stays index-aligned with vals
+                        if quant_params is not None and len(quant_params) >= num_vals:
+                            for qp in quant_params[:num_vals]:
+                                all_quant_params.append(move_to_device_recursive(qp))
+                        elif quant_params is not None and 0 < len(quant_params) < num_vals:
+                            for qp in quant_params:
+                                all_quant_params.append(move_to_device_recursive(qp))
+                            # pad with None for the remainder
+                            all_quant_params.extend([None] * (num_vals - len(quant_params)))
+                            tplr.logger.warning(f"Inconsistent quant_params length for {vals_key}; padding with None")
+                        else:
+                            # no quant_params for this partial -> pad
+                            all_quant_params.extend([None] * num_vals)
```

Based on learnings
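The alignment invariant the diff enforces can be isolated into a tiny pure function; a runnable sketch under the same assumptions (per-peer `vals` and `quant_params` lists, `None` meaning "no quantization metadata" — the helper name is made up for illustration):

```python
def align_quant_params(vals, quant_params):
    """Pad (or trim) quant_params so it is always index-aligned with vals.

    Guarantees len(result) == len(vals), so dequantization can zip the two
    lists without ever indexing past the end.
    """
    num_vals = len(vals)
    if quant_params is None:
        return [None] * num_vals
    aligned = list(quant_params[:num_vals])        # trim any excess
    aligned.extend([None] * (num_vals - len(aligned)))  # pad any shortfall
    return aligned

assert align_quant_params([1, 2, 3], None) == [None, None, None]
assert align_quant_params([1, 2, 3], ["q0"]) == ["q0", None, None]
assert align_quant_params([1, 2], ["q0", "q1", "q2"]) == ["q0", "q1"]
```

Keeping the invariant in one place makes the "pad with None and warn" branch in the diff easy to unit-test independently of torch.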
src/tplr/neurons.py (3)
697-705: Invert memory_budget_violation flag (1 = violation, 0 = OK). Same inverted logic previously flagged; fix to avoid misleading monitoring.

```diff
-        budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+        budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0
```
1001-1021: Invert memory_budget_violation flag here as well. Mirror the fix in the single-result path.

```diff
-        budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+        budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0
```
300-339: Support SimpleNamespace and lists in partial memory estimator. Current code only handles dict; partial.state_dict is often SimpleNamespace with list values.

```diff
@@
-    # Measure actual tensor sizes in state_dict
-    if hasattr(partial, "state_dict") and isinstance(partial.state_dict, dict):
-        for value in partial.state_dict.values():
-            if isinstance(value, torch.Tensor):
-                total_bytes += value.element_size() * value.nelement()
-            elif isinstance(value, dict):
-                # Handle nested dicts (e.g., compressed gradients)
-                for nested_val in value.values():
-                    if isinstance(nested_val, torch.Tensor):
-                        total_bytes += nested_val.element_size() * nested_val.nelement()
+    # Measure actual tensor sizes in state_dict (SimpleNamespace or dict)
+    if hasattr(partial, "state_dict"):
+        sd = partial.state_dict
+        values = (
+            vars(sd).values()
+            if isinstance(sd, SimpleNamespace)
+            else (sd.values() if isinstance(sd, dict) else [])
+        )
+        for value in values:
+            if isinstance(value, torch.Tensor):
+                total_bytes += value.element_size() * value.nelement()
+            elif isinstance(value, (list, tuple)):
+                for v in value:
+                    if isinstance(v, torch.Tensor):
+                        total_bytes += v.element_size() * v.nelement()
+            elif isinstance(value, dict):
+                for v in value.values():
+                    if isinstance(v, torch.Tensor):
+                        total_bytes += v.element_size() * v.nelement()
```
🧹 Nitpick comments (1)
src/tplr/comms.py (1)
2473-2527: Distributed gather with proper fallback logic. The fallback logic correctly ensures all ranks fall back together (lines 2511-2527) to maintain consistency when distributed gather fails. This addresses the previous review comment about rank consistency.
The logging clearly distinguishes between distributed and sequential modes, and the error handling is comprehensive.
Consider adding a comment explaining the consistency requirement:
```python
# Fallback to sequential if distributed failed
# ALL ranks must fall back together to maintain consistency
if primary is None:
    # Sequential gather can be done independently by each rank
    # since it fetches from R2, not inter-rank communication
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- neurons/validator.py (12 hunks)
- src/tplr/comms.py (16 hunks)
- src/tplr/neurons.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
- `Bucket` (24-46)

src/tplr/logging.py (2)
- `log_with_context` (207-223)
- `log_with_context` (290-309)

src/tplr/distributed.py (2)
- `is_distributed` (75-77)
- `all_gather_object` (307-325)
src/tplr/neurons.py (2)
src/tplr/compress.py (6)
- `compress` (373-377)
- `compress` (379-383)
- `compress` (386-423)
- `ChunkingTransformer` (145-292)
- `TopKCompressor` (295-696)
- `unpack_12bit_indices` (102-142)

src/tplr/comms.py (3)
- `Comms` (63-3392)
- `merge_gather_results` (2170-2283)
- `get` (1499-1615)
neurons/validator.py (6)
src/tplr/distributed.py (2)
- `is_distributed` (75-77)
- `broadcast` (273-282)

src/tplr/logging.py (2)
- `log_with_context` (207-223)
- `log_with_context` (290-309)

neurons/base_node.py (1)
- `query_block_timestamp` (205-250)

src/tplr/comms.py (2)
- `gather_with_reserve` (2415-2687)
- `merge_gather_results` (2170-2283)

src/tplr/neurons.py (1)
- `check_uid_index_overlap` (1751-1956)

src/tplr/compress.py (1)
- `maybe_dequantize_values` (646-696)
🪛 GitHub Actions: CI
neurons/validator.py
[error] 1457-1457: invalid-syntax: Expected an indented block after if statement
[error] 1465-1465: invalid-syntax: unindent does not match any outer indentation level
[error] 1466-1466: invalid-syntax: Expected a statement
[error] 1590-1591: invalid-syntax: Expected an indented block after if statement
[error] 1597-1599: invalid-syntax: Unexpected indentation around else block
[error] 3912-3914: invalid-syntax: Expected an indented block after else clause
[error] 4537-4537: Syntax issue detected at function definition
🔇 Additional comments (18)
src/tplr/comms.py (13)
48-48: Import moved to module level - good fix! This addresses the previous review comment about moving the `dist_helper` import from inside a function to module level, improving performance and readability.
478-577: Excellent implementation of `s3_object_exists` with structured error handling. This method properly distinguishes between different error types using status codes and error codes instead of string matching. The conservative approach (returning True for transient errors) is appropriate to avoid false slashing. Good logging at ERROR level for configuration issues that need attention.
1465-1497: Good fix for gradient_timestamp signature. Adding an optional `version` parameter that defaults to `tplr.__version__` addresses the previous review comment about the undefined `self.version`. This allows flexibility while maintaining backward compatibility.
703-717: Critical performance fix: offloading torch.load() to a thread pool. This prevents blocking the event loop during CPU decompression, which can take 1-2s per file and would otherwise serialize parallel downloads. The fix is correctly applied using `asyncio.to_thread()`.
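The pattern being praised generalizes to any CPU-bound call inside an async path; a minimal runnable sketch, with a `time.sleep` stand-in for the blocking `torch.load()` (function and path names are illustrative):

```python
import asyncio
import time

def blocking_load(path: str):
    """Stand-in for torch.load(): CPU-bound work that would block the event loop."""
    time.sleep(0.05)  # simulate ~50ms of decompression
    return {"path": path}

async def fetch_many(paths):
    # Each load runs in the default thread pool via asyncio.to_thread, so
    # multiple loads proceed concurrently instead of serializing on the loop.
    return await asyncio.gather(
        *(asyncio.to_thread(blocking_load, p) for p in paths)
    )

results = asyncio.run(fetch_many(["a.pt", "b.pt", "c.pt"]))
assert [r["path"] for r in results] == ["a.pt", "b.pt", "c.pt"]
```

Calling `blocking_load` directly inside the coroutine would stall every other pending download for the duration of the load; `asyncio.to_thread` keeps the loop responsive.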
1557-1558: torch.load() offloading also applied to local storage path. Consistent with the remote storage fix, ensuring non-blocking behavior across all code paths.
1836-1867: Excellent optimization: parallel S3 existence checks for failed downloads. Batching all existence checks and running them in parallel avoids the serial API calls that would add ~5s per failure. The cached results are properly reused later in the error handling logic. Good use of asyncio.gather with return_exceptions=True.
1869-1931: Well-structured error handling that prevents false slashing. The code correctly distinguishes between network/timeout issues (file exists but download failed) and miner failures (file doesn't exist). Uses the cached existence check results efficiently and provides clear logging with appropriate levels.
2144-2167: Correct round-robin partitioning implementation. The deterministic sorting ensures all ranks agree on the assignment, and the slicing approach `sorted_uids[rank::world_size]` correctly distributes UIDs evenly across ranks.
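The slicing idiom is worth seeing end to end: every rank sorts the same input, then takes every `world_size`-th element starting at its own rank, so the partitions are disjoint and cover all UIDs. A runnable sketch (the function name mirrors the review; the body is a hypothetical reconstruction):

```python
def partition_uids_across_ranks(uids, rank, world_size):
    """Deterministic round-robin split of UIDs across ranks.

    All ranks sort identically, so they agree on the assignment without
    any communication; rank r takes sorted indices r, r+W, r+2W, ...
    """
    sorted_uids = sorted(uids)
    return sorted_uids[rank::world_size]

uids = [7, 3, 11, 5, 2, 9]
parts = [partition_uids_across_ranks(uids, r, 3) for r in range(3)]
assert parts == [[2, 7], [3, 9], [5, 11]]
# Disjoint cover: every UID appears exactly once across ranks
assert sorted(u for p in parts for u in p) == sorted(uids)
```

Because the split is a pure function of `(uids, rank, world_size)`, it also stays balanced: partition sizes differ by at most one.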
2198-2283: Comprehensive merging logic with proper device consolidation. The method correctly:

- Filters out None results
- Recursively moves tensors to the target device via the `move_to_device` helper
- Merges state_dict by concatenating lists for each parameter
- Aggregates metrics (max time, sum bytes, extend arrays)
- Calculates an accurate success rate as `len(all_uids) / total_attempted`

Good logging of the merge summary.
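The merge bookkeeping above (drop `None` partials, deterministic rank order, success rate as gathered/attempted) can be sketched without any torch machinery — the shapes below are simplified stand-ins for the real partial namespaces:

```python
from types import SimpleNamespace

def merge_partials(partials, total_attempted):
    """Concatenate per-rank partials in deterministic rank order and compute
    the aggregate success rate as len(all_uids) / total_attempted."""
    alive = sorted(
        (p for p in partials if p is not None),  # ranks with 0 peers return None
        key=lambda p: p.rank,                    # rank metadata fixes the ordering
    )
    all_uids = [uid for p in alive for uid in p.uids]
    return SimpleNamespace(
        uids=all_uids,
        success_rate=len(all_uids) / total_attempted if total_attempted else 0.0,
    )

merged = merge_partials(
    [SimpleNamespace(rank=1, uids=[4]), SimpleNamespace(rank=0, uids=[1, 2]), None],
    total_attempted=4,
)
assert merged.uids == [1, 2, 4]       # rank 0's UIDs come first
assert merged.success_rate == 0.75    # 3 gathered out of 4 attempted
```

The rank-keyed sort is what makes the merged UID order identical on every rank, regardless of which fetches finished first.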
2285-2410: Well-implemented distributed gather with proper synchronization. The method correctly:

- Partitions UIDs deterministically across ranks for parallel fetching
- Each rank independently gathers its assigned subset
- Uses `all_gather_object` for barrier synchronization (correctly called synchronously)
- Adds rank metadata to partials for deterministic ordering
- Supports both incremental (`return_partials=True`) and merged modes
Good contextual logging with rank and window information throughout.
2552-2642: Reserve fallback logic correctly handles both partials and merged results. The code properly:
- Detects missing UIDs for both result types (lines 2552-2580)
- Uses sequential gather for reserve peers (not distributed, which is correct)
- Appends reserve as new partial with proper rank metadata (lines 2613-2625)
- Merges reserve into merged result by extending arrays (lines 2627-2642)
Good contextual logging for both paths.
2647-2686: Success rate calculation and final validation are correct. The method properly:
- Calculates success rate relative to original gather_uids (not including reserves)
- Handles both partials list and merged SimpleNamespace
- Updates success_rate attribute on partials for consistency (lines 2654-2657)
- Returns None if no gradients were gathered (lines 2671-2685)
The aggregate success rate calculation for partials (lines 2649-2652) is correct.
1661-1728: Good optimization: early termination for repeated NOT_FOUND errors. Tracking consecutive NOT_FOUND errors and giving up after 5 attempts (2.5s total) prevents unnecessary retries when a file likely doesn't exist. The counter is properly reset on other error types that might be transient network issues.
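The streak-counting policy described here (give up on a run of NOT_FOUND, reset on anything else) can be expressed as a small pure predicate; this is an illustrative sketch of the policy, not the repo's retry loop:

```python
def should_retry(history, max_consecutive_not_found=5):
    """Decide whether to keep retrying given the attempt history.

    Gives up once the *tail* of the history is an unbroken run of
    NOT_FOUND of length >= max_consecutive_not_found; any other outcome
    (e.g. a timeout) resets the streak because it may be transient.
    """
    streak = 0
    for outcome in reversed(history):
        if outcome == "NOT_FOUND":
            streak += 1
        else:
            break  # a non-NOT_FOUND outcome interrupts the streak
    return streak < max_consecutive_not_found

assert should_retry(["NOT_FOUND"] * 4) is True            # one more chance
assert should_retry(["NOT_FOUND"] * 5) is False           # likely missing, stop
assert should_retry(["NOT_FOUND"] * 4 + ["TIMEOUT"]) is True  # streak reset
```

With a 0.5s backoff per attempt, five consecutive misses corresponds to the ~2.5s ceiling the comment mentions.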
neurons/validator.py (2)
91-134: Helpers for UIDs/skips/success-rate look good. Uniform handling across partials vs merged is clear and side-effect free.
1391-1410: Confirm broadcast_object_list compatibility across ranks. All ranks must supply lists of identical length/types; also, large peer lists can bloat pickles. If size grows, consider broadcasting only IDs and reconstructing from the metagraph.
Would you like a follow‑up patch to broadcast only UID tensors to cut overhead?
src/tplr/neurons.py (3)
416-451: Incremental batching pipeline reads well. Batch sizing, merging via `Comms.merge_gather_results`, the duplicate-UID guard, and per-batch cleanup are solid.
Consider logging per‑batch GPU mem deltas to validate budget sizing heuristics in production.
741-771: Single-result path is efficient and DTensor-aware. Param-at-a-time update, direct broadcast into `p.grad`, and fingerprinting are clean.
Also applies to: 789-899, 941-971
1753-1956: Overlap checker partials path avoids the O(n×m) lookup, nice. The UID→partial map and consistent pair accounting improve scalability.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
neurons/validator.py (1)
3835-3933: Fix UnboundLocalError and align quant_params in compute_peer_val_norms(). When gather_result is a list, state_dict is undefined but used later, causing a crash. Also, vals and quant_params can become misaligned across partials.

Apply:

```diff
@@
-    def compute_peer_val_norms(
-        self,
-        gather_result: SimpleNamespace | list[SimpleNamespace],
-    ) -> dict[str, torch.Tensor]:
+    def compute_peer_val_norms(
+        self,
+        gather_result: SimpleNamespace | list[SimpleNamespace],
+    ) -> dict[str, torch.Tensor]:
@@
-        clip_norm_dict = {}
-        # Handle both merged results and partials
-        if isinstance(gather_result, list):
-            # Partials: collect vals from all partials for each parameter
-            # Use model's first parameter device as target (typically cuda:0)
-            target_device = next(self.model.parameters()).device
-            # Helper to recursively move all tensors to target device
-            def move_to_device_recursive(obj):
-                if torch.is_tensor(obj):
-                    return obj.to(target_device)
-                elif isinstance(obj, dict):
-                    return {k: move_to_device_recursive(v) for k, v in obj.items()}
-                elif isinstance(obj, (list, tuple)):
-                    return type(obj)(move_to_device_recursive(item) for item in obj)
-                else:
-                    return obj
-            for n, p in self.model.named_parameters():
-                vals_key = n + "vals"
-                quant_key = n + "quant_params"
-                # Collect vals and quant_params from all partials
-                all_vals = []
-                all_quant_params = []
-                for partial in gather_result:
-                    # Skip None partials (ranks that gathered 0 peers)
-                    if partial is None:
-                        continue
-                    state_dict = partial.state_dict
-                    if not state_dict:
-                        continue
-                    vals = getattr(state_dict, vals_key, None)
-                    quant_params = getattr(state_dict, quant_key, None)
-                    if vals is not None:
-                        # vals is a list of tensors, one per peer in this partial
-                        # Move all tensors (including nested) to target device
-                        for v in vals:
-                            all_vals.append(move_to_device_recursive(v))
-                        # quant_params is also a list if present
-                        if quant_params is not None:
-                            for qp in quant_params:
-                                all_quant_params.append(move_to_device_recursive(qp))
-                if not all_vals:
-                    continue
-                # Dequantize all collected values (now all on same device)
-                vals_f32 = self.compressor.maybe_dequantize_values(
-                    all_vals,
-                    all_quant_params if all_quant_params else None,
-                    target_device,
-                )
-                # Compute median norm across all peers from all partials
-                norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
-                clip_norm_dict[vals_key] = torch.median(norms)
-        else:
-            # Merged result: original logic
-            state_dict = gather_result.state_dict
-            if not state_dict:
-                raise ValueError("Must have gather_result.state_dict to compute norms")
-            for n, p in self.model.named_parameters():
-                vals_key = n + "vals"
-                quant_key = n + "quant_params"
-                vals = getattr(state_dict, vals_key, None)
-                quant_params = getattr(state_dict, quant_key, None)
-                if vals is None:
-                    continue
-                vals_f32 = self.compressor.maybe_dequantize_values(
-                    vals, quant_params, p.device
-                )
-                norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
-                clip_norm_dict[vals_key] = torch.median(norms)
-        return clip_norm_dict
+        clip_norm_dict: dict[str, torch.Tensor] = {}
+
+        # Common helper
+        target_device = next(self.model.parameters()).device
+        def move_to_device_recursive(obj):
+            if torch.is_tensor(obj):
+                return obj.to(target_device)
+            if isinstance(obj, dict):
+                return {k: move_to_device_recursive(v) for k, v in obj.items()}
+            if isinstance(obj, (list, tuple)):
+                return type(obj)(move_to_device_recursive(x) for x in obj)
+            return obj
+
+        if isinstance(gather_result, list):
+            # Partials: aggregate across partials
+            for n, p in self.model.named_parameters():
+                vals_key = n + "vals"
+                quant_key = n + "quant_params"
+                all_vals: list[torch.Tensor] = []
+                all_qp: list[QuantParamsT | None] = []
+                for partial in gather_result:
+                    if partial is None or not getattr(partial, "state_dict", None):
+                        continue
+                    state_dict = partial.state_dict
+                    vals = getattr(state_dict, vals_key, None)
+                    qps = getattr(state_dict, quant_key, None)
+                    if vals is None:
+                        continue
+                    # Append vals
+                    for v in vals:
+                        all_vals.append(move_to_device_recursive(v))
+                    # Align quant_params length with vals length
+                    num_vals = len(vals)
+                    if qps is None:
+                        all_qp.extend([None] * num_vals)
+                    else:
+                        qps_moved = [move_to_device_recursive(q) for q in qps[:num_vals]]
+                        if len(qps_moved) < num_vals:
+                            qps_moved.extend([None] * (num_vals - len(qps_moved)))
+                        all_qp.extend(qps_moved)
+                if not all_vals:
+                    continue
+                vals_f32 = self.compressor.maybe_dequantize_values(all_vals, all_qp, target_device)
+                norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
+                clip_norm_dict[vals_key] = torch.median(norms)
+            return clip_norm_dict
+
+        # Merged result path
+        state_dict = getattr(gather_result, "state_dict", None)
+        if not state_dict:
+            raise ValueError("Must have gather_result.state_dict to compute norms")
+        for n, p in self.model.named_parameters():
+            vals_key = n + "vals"
+            quant_key = n + "quant_params"
+            vals = getattr(state_dict, vals_key, None)
+            if vals is None:
+                continue
+            qps = getattr(state_dict, quant_key, None)
+            vals_f32 = self.compressor.maybe_dequantize_values(vals, qps, p.device)
+            norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
+            clip_norm_dict[vals_key] = torch.median(norms)
+        return clip_norm_dict
```

Based on learnings.
src/tplr/comms.py (1)
578-607: Docstring says bucket defaults to self.bucket, but code doesn't; add fallback. Without a fallback, callers passing `bucket=None` will crash. Enforce the documented default before creating the client.

```diff
     async def s3_get_object(
         self,
         key: str,
-        bucket: Bucket | None = None,
+        bucket: Bucket | None = None,
         timeout: int = 30,
@@
-        s3_client = await self._get_s3_client(bucket)
+        if bucket is None:
+            bucket = self.bucket
+        s3_client = await self._get_s3_client(bucket)
```
♻️ Duplicate comments (3)
src/tplr/comms.py (1)
2191-2197: Guard torch.cuda.current_device() with try/except and CPU fallback. `current_device()` can raise if no context is set. Use a safe fallback.

```diff
-        if target_device is None:
-            # Use the current device if CUDA is available, otherwise CPU
-            if torch.cuda.is_available():
-                target_device = f"cuda:{torch.cuda.current_device()}"
-            else:
-                target_device = "cpu"
+        if target_device is None:
+            if torch.cuda.is_available():
+                try:
+                    target_device = f"cuda:{torch.cuda.current_device()}"
+                except RuntimeError:
+                    target_device = "cuda:0" if torch.cuda.device_count() > 0 else "cpu"
+            else:
+                target_device = "cpu"
```

src/tplr/neurons.py (2)
501-504: Do not pass string device to CUDA APIs; use default or torch.device. torch.cuda.memory_allocated()/synchronize() with "cuda:0" (str) can raise. Omit the arg or pass torch.device.

```diff
-    if memory_budget_mb is not None and torch.cuda.is_available():
-        initial_memory_mb = torch.cuda.memory_allocated(device) / (1024**2)
+    if memory_budget_mb is not None and torch.cuda.is_available():
+        initial_memory_mb = torch.cuda.memory_allocated() / (1024**2)
@@
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        torch.cuda.synchronize(device)
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
@@
-        final_memory_mb = torch.cuda.memory_allocated(device) / (1024**2)
+        final_memory_mb = torch.cuda.memory_allocated() / (1024**2)
@@
-        final_memory_mb = torch.cuda.memory_allocated(device) / (1024**2)
+        final_memory_mb = torch.cuda.memory_allocated() / (1024**2)
@@
-        tplr.logger.warning(
+        tplr.logger.warning(
             f"⚠️ GPU Memory: {final_memory_mb:.0f}MB / {memory_budget_mb:.0f}MB "
             f"({utilization_pct:.0f}% - EXCEEDED by {final_memory_mb - memory_budget_mb:.0f}MB)"
         )
-    else:
+    else:
         tplr.logger.info(
             f"GPU Memory: {final_memory_mb:.0f}MB / {memory_budget_mb:.0f}MB "
             f"({utilization_pct:.0f}% utilized)"
         )
```

Also applies to: 617-623, 697-705, 996-1002, 1003-1014
704-705: Flip memory_violation flag: 1 = violation, 0 = OK. Current logic is inverted and misleads dashboards.

```diff
-        budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+        budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0
```

Also applies to: 1001-1002
🧹 Nitpick comments (7)
neurons/miner.py (1)
559-561: Guard CUDA-only calls (optional). `empty_cache`/`synchronize` should be gated to CUDA to avoid surprises in CPU/testing.

```diff
-        torch.cuda.empty_cache()
-        torch.cuda.synchronize()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize(self.device)
```

neurons/validator.py (2)
138-203: Handle partial lists in offload/load helpers to actually free GPU memory. `offload_gather_results`/`load_gather_results_to_device` operate only on a merged SimpleNamespace. With distributed gather (partials), tensors remain on GPU.

```diff
@@
 def offload_gather_results(self, gather_result, *, log: bool = True) -> None:
     if gather_result is None:
         return
+    # NEW: handle list of partials
+    if isinstance(gather_result, list):
+        count = 0
+        for partial in gather_result:
+            if partial is not None:
+                self.offload_gather_results(partial, log=False)
+                count += 1
+        if log and count:
+            tplr.log_with_context(
+                level="info",
+                message=f"[GatherOffload] Offloaded {count} partial(s) to CPU pinned",
+                sync_window=self.sync_window,
+                current_window=self.current_window,
+            )
+        return
@@
 def load_gather_results_to_device(self, gather_result, device: torch.device, *, log: bool = True) -> None:
     if gather_result is None:
         return
+    # NEW: handle list of partials
+    if isinstance(gather_result, list):
+        count = 0
+        for partial in gather_result:
+            if partial is not None:
+                self.load_gather_results_to_device(partial, device, log=False)
+                count += 1
+        if log and count:
+            tplr.log_with_context(
+                level="info",
+                message=f"[GatherOffload] Loaded {count} partial(s) back to {device.type}",
+                sync_window=self.sync_window,
+                current_window=self.current_window,
+            )
+        return
```
1563-1574: Nit: pass device as string for consistency. `gather_with_reserve` accepts a device used in `map_location`/`to()`. Passing `str(self.device)` matches the miner and avoids type surprises with custom loaders.

```diff
-    device=cast(str, self.device),
+    device=str(self.device),
```

src/tplr/comms.py (3)
650-655: Prefer structured ClientError inspection over string matching in HEAD error path. Replace "404" string checks with HTTPStatusCode/Error.Code parsing (consistent with s3_object_exists), and keep purging on client errors.

```diff
-        except (ConnectionClosedError, ClientError) as e:
-            await self._purge_s3_client(bucket)
-            if "404" in str(e):
-                tplr.logger.debug(f"Object {key} not found in bucket {bucket.name}")
-            return None
+        except (ConnectionClosedError, ClientError) as e:
+            await self._purge_s3_client(bucket)
+            status = None
+            code = None
+            if isinstance(e, ClientError) and hasattr(e, "response"):
+                status = e.response.get("ResponseMetadata", {}).get("HTTPStatusCode")
+                code = e.response.get("Error", {}).get("Code")
+            if status == 404 or code in {"404", "NoSuchKey", "NotFound"}:
+                tplr.logger.debug(f"Object {key} not found in bucket {bucket.name}")
+                return None
+            if status in {401, 403} or code in {"AccessDenied", "InvalidAccessKeyId", "SignatureDoesNotMatch"}:
+                tplr.logger.error(f"Permission error checking {key} in {bucket.name}: {e}")
+                return None
+            # For transient 5xx/network errors, let caller retry
+            tplr.logger.warning(f"Transient error on HEAD {key}: {e}")
+            return None
```
478-577: Purge broken S3 client on connection close to avoid reuse of a bad pool. If the underlying connection is closed, purge the cached client so the next call recreates it. Keeps semantics unchanged.

```diff
-        except (ConnectionClosedError, ClientError) as e:
+        except (ConnectionClosedError, ClientError) as e:
+            # Drop broken clients so future calls re-create a fresh one
+            if isinstance(e, ConnectionClosedError):
+                try:
+                    await self._purge_s3_client(bucket)
+                except Exception:
+                    pass
```
1850-1867: Throttle parallel existence checks to avoid overwhelming R2/S3. Bound concurrency with a semaphore (reuse self.client_semaphore or a local one).

```diff
-        if failed_checks:
-            existence_tasks = [
-                self.s3_object_exists(key=check[1], bucket=check[2], timeout=5)
-                for check in failed_checks
-            ]
-            existence_responses = await asyncio.gather(
-                *existence_tasks, return_exceptions=True
-            )
+        if failed_checks:
+            sem = asyncio.Semaphore(min(32, len(failed_checks)))
+            async def _checked(key, bucket):
+                async with sem:
+                    return await self.s3_object_exists(key=key, bucket=bucket, timeout=5)
+            existence_tasks = [_checked(check[1], check[2]) for check in failed_checks]
+            existence_responses = await asyncio.gather(*existence_tasks, return_exceptions=True)
```
452-456: Remove duplicate imports inside function. `gc` and `torch.cuda` are already imported at module level; keep imports top-level.

```diff
-    import gc
-
-    import torch.cuda
+    # (imports are module-level; no per-call imports needed)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- hparams/2B.json (1 hunks)
- neurons/miner.py (3 hunks)
- neurons/validator.py (13 hunks)
- src/tplr/comms.py (16 hunks)
- src/tplr/compress.py (1 hunks)
- src/tplr/neurons.py (9 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- hparams/2B.json
🧰 Additional context used
🧬 Code graph analysis (3)
src/tplr/neurons.py (2)
src/tplr/compress.py (5)
- `compress` (373-377)
- `compress` (379-383)
- `compress` (386-423)
- `TopKCompressor` (295-697)
- `unpack_12bit_indices` (102-142)

src/tplr/comms.py (2)
- `merge_gather_results` (2170-2283)
- `get` (1499-1617)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
- Bucket (24-46)

src/tplr/logging.py (2)
- log_with_context (207-223)
- log_with_context (290-309)

src/tplr/distributed.py (2)
- is_distributed (75-77)
- all_gather_object (307-325)
neurons/miner.py (3)
src/tplr/chain.py (1)
- get_commitments (264-322)

src/tplr/logging.py (1)
- T (41-48)

src/tplr/comms.py (1)
- gather (1736-2138)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.11)
- GitHub Check: test (3.12)
🔇 Additional comments (2)
src/tplr/compress.py (1)
639-642: LGTM! Device alignment for multi-GPU consistency.

The addition of device alignment for `shift` mirrors the existing pattern for `lookup` and ensures all operands in the dequantization calculation (line 643) reside on the same device. This prevents device-mismatch errors in multi-GPU scenarios and aligns with the PR's goal of automatic multi-GPU tensor consolidation.

neurons/miner.py (1)
716-731: The review comment is incorrect. Miner inherits from BaseNode and Trainer, and calls super().__init__() at line 197. BaseNode.__init__ explicitly calls super().__init__() at line 62 to continue the MRO chain and initialize parent classes like Trainer. Trainer.__init__ initializes self.expected_compressed_params at line 123. The attribute is properly initialized through normal inheritance; no AttributeError will occur, and the proposed diff is unnecessary.
Likely an incorrect or invalid review comment.
Actionable comments posted: 1
🧹 Nitpick comments (1)
src/tplr/comms.py (1)
2379-2395: All-gather logic is correct, but logging could be clearer about the synchronization point.

The synchronous call to `dist_helper.all_gather_object` (line 2388) is correctly documented as a blocking barrier operation. All ranks execute this regardless of their individual `partial_result`, which maintains distributed consistency.

Consider enhancing the log message at lines 2382-2386 to explicitly note this is a synchronization barrier, as developers debugging hangs might benefit from knowing this is an expected all-ranks rendezvous point.
```diff
 tplr.log_with_context(
     level="info",
-    message=f"[DISTRIBUTED_GATHER] Rank {rank} performing all-gather of partial results",
+    message=f"[DISTRIBUTED_GATHER] Rank {rank} entering all-gather barrier (synchronization point for all ranks)",
     window=window,
 )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- src/tplr/comms.py (18 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
- Bucket (24-46)

src/tplr/logging.py (2)
- log_with_context (207-223)
- log_with_context (290-309)

src/tplr/distributed.py (2)
- is_distributed (75-77)
- all_gather_object (307-325)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.11)
- GitHub Check: test (3.12)
🔇 Additional comments (8)
src/tplr/comms.py (8)
482-581: Well-implemented conservative existence check with structured error handling.

The method correctly prioritizes avoiding false negatives for slashing decisions. The structured ClientError inspection (lines 524-527) is a significant improvement over string matching, and the tiered error handling (404 → False, transient → True) aligns with the stated design goal.
One minor observation: the `asyncio.wait_for` wrapper around `head_object` (lines 511-514) may leave the underlying HTTP connection in an indeterminate state if the timeout fires mid-request. Consider whether the existing purge/retry logic in callers adequately handles this, or if explicit connection cleanup is needed on timeout.
707-721: Excellent fix: offloading torch.load prevents event-loop blocking.

The use of `asyncio.to_thread` for CPU-bound decompression is the correct solution for parallel downloads, as noted in the comment. Both remote (lines 707-721) and local (lines 1561-1564) paths are now consistent.

The `weights_only=True` parameter is security-conscious. Verify that no legitimate gradient/checkpoint files in this system require custom unpicklers or arbitrary objects; if they do, this will fail silently (returning `None` on line 721) and may be difficult to debug.
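The offload pattern is easy to demonstrate without torch; `blocking_load` below is a CPU-bound stand-in for `torch.load`:

```python
import asyncio
import hashlib

def blocking_load(blob: bytes) -> str:
    # CPU-bound stand-in for torch.load(): work that would otherwise
    # stall the event loop and serialize "parallel" downloads.
    return hashlib.sha256(blob).hexdigest()

async def load_all(blobs: list[bytes]) -> list[str]:
    # asyncio.to_thread runs each load in the default executor, so the
    # loop stays free to drive other coroutines concurrently.
    return list(await asyncio.gather(
        *(asyncio.to_thread(blocking_load, b) for b in blobs)
    ))

digests = asyncio.run(load_all([b"grad-a", b"grad-b"]))
```

With plain `blocking_load(b)` calls inside the coroutine, the downloads would complete but the decompression would still run back-to-back on the loop thread; `to_thread` is what makes the phases overlap.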
1469-1490: LGTM: version parameter correctly defaults to `tplr.__version__`.

This addresses the previous review concern about undefined `self.version`. The optional parameter provides flexibility while maintaining backward compatibility.
1667-1738: Sensible optimization: early exit on repeated NOT_FOUND reduces wasted retries.

The 5-attempt threshold (2.5s total) provides a reasonable balance between accommodating transient issues and avoiding lengthy waits for genuinely missing files. The counter reset on non-NOT_FOUND errors (line 1733) correctly distinguishes persistent absence from intermittent network problems.

Verify that this early-exit logic plays nicely with the `time_max` grace period (lines 1680-1697). If a file is uploaded slightly late and arrives between attempt 4 and 5, the early exit might prevent a successful fetch that would have occurred on the next retry.
2148-2171: LGTM: clean, deterministic UID partitioning for distributed gather.

The round-robin approach (line 2171) is simple and provides good load balance. The sort (line 2168) ensures deterministic assignments across ranks, which is critical for correctness.
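A minimal sketch of the described partitioning; the function name mirrors the one under review, but the body is inferred from the review comment rather than copied from the source:

```python
def partition_uids_across_ranks(uids, rank, world_size):
    # Sort first so every rank derives the identical ordering, then take
    # a round-robin slice: rank r gets elements r, r+world_size, ...
    if world_size <= 1:
        return sorted(uids)
    return sorted(uids)[rank::world_size]

# Every UID lands on exactly one rank, deterministically.
parts = [partition_uids_across_ranks([7, 3, 5, 1, 9], r, 2) for r in range(2)]
```

Because each rank computes its slice independently from the same sorted input, no coordination round is needed to agree on the assignment.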
2537-2555: Fallback logic maintains consistency: all ranks perform sequential gather.

Lines 2537-2555 correctly implement the requirement from past review: when distributed gather fails, all ranks fall back to sequential gather (line 2550), not just the master. The `if dist_helper.is_master:` check (line 2540) only controls logging, not execution flow. This maintains consistency and avoids divergent state across ranks.

The comment at lines 2548-2549 explicitly documents this design choice, which is helpful for future maintainers.
2633-2638: Reserve gather uses sequential mode: verify this is intentional.

Line 2633 explicitly uses `self.gather()` (sequential) rather than `self.gather_distributed()` for the reserve tier. If reserve peers are typically a small subset (lines 2622-2624 show `[:len(missing)]`), sequential mode may be acceptable. However, if the reserve pool is large and network latency is the bottleneck, using distributed mode for reserve could provide additional speedup.

Confirm this design choice aligns with expected reserve pool sizes and performance goals. If reserves are consistently small (e.g., <10 UIDs), sequential is fine; if they can be large, consider distributed mode for consistency.
1842-1908: Parallelization and slashing mechanism verified; no issues found.

The parallel existence checks (lines 1854-1870) correctly avoid serial bottlenecks. The caching strategy properly eliminates redundant API calls. The distinction between network failures (file exists but download failed, lines 1893-1900) and miner no-shows (file missing, lines 1901-1908) is sound and correctly implemented.

Verification confirms `skipped_uids` is the proper slashing trigger: it flows from `gather_new_data()` to `slash_for_missing_gradients()` (neurons/validator.py line 1769), which iterates through skipped UIDs and applies penalties. Network failures are correctly excluded (line 1900 ensures they bypass `skipped_uids`), preventing false slashing. No edge case exists where a UID escapes both valid and skipped tracking.
Actionable comments posted: 0
🧹 Nitpick comments (1)
src/tplr/comms.py (1)
2196-2211: Consider avoiding global device-state mutation during device selection.

Line 2206 calls `torch.cuda.set_device(0)`, which mutates global CUDA state as a side effect of probing device availability. A cleaner approach would validate device 0 without changing the current device:

```diff
 if torch.cuda.is_available():
     try:
         target_device = f"cuda:{torch.cuda.current_device()}"
     except RuntimeError:
-        # Fall back to first available device or CPU
+        # Fall back to device 0 if available, else CPU
         if torch.cuda.device_count() > 0:
-            # Use device 0, but could be extended to find first available
-            try:
-                torch.cuda.set_device(0)
-                target_device = "cuda:0"
-            except RuntimeError:
-                target_device = "cpu"
+            # Validate device 0 exists without mutating state
+            target_device = "cuda:0"
         else:
             target_device = "cpu"
 else:
     target_device = "cpu"
```

Note: Tensors will fail on `.to(target_device)` if the device doesn't exist, which is caught downstream.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- src/tplr/comms.py (18 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
- Bucket (24-46)

src/tplr/logging.py (2)
- log_with_context (207-223)
- log_with_context (290-309)

src/tplr/distributed.py (2)
- is_distributed (75-77)
- all_gather_object (307-325)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.12)
- GitHub Check: test (3.11)
🔇 Additional comments (10)
src/tplr/comms.py (10)
28-28: LGTM: Imports and partition cache are well-designed.

The `OrderedDict` and `dist_helper` imports are properly used throughout the file. The `_partition_cache` field provides efficient caching of UID partitions with clear type annotations.

Also applies to: 49-49, 141-143
482-581: LGTM: Robust existence check with conservative error handling.

The structured error inspection (lines 521-527) and branching logic properly distinguish between missing objects (404/NoSuchKey), configuration errors (403/AccessDenied), and transient service errors (5xx). The conservative approach (returning `True` for transient errors) prevents false slashing.
707-721: LGTM: Critical fix for parallel download performance.

Offloading `torch.load()` to a thread pool (via `asyncio.to_thread`) prevents CPU-intensive decompression from blocking the event loop during parallel downloads. This enables true parallelism and should significantly improve throughput.

Also applies to: 1561-1564
1470-1490: LGTM: Version parameter properly defaults to `tplr.__version__`.

The `version` parameter with runtime defaulting addresses previous concerns about undefined `self.version`. This provides flexibility for version-specific queries while maintaining backward compatibility.
1667-1734: LGTM: Smart early exit for non-existent files.

The consecutive NOT_FOUND tracking (lines 1667-1668, 1723-1734) provides intelligent early termination after 5 consecutive failures (~2.5s), avoiding wasted retries while resetting the counter for transient errors. This balances robustness with efficiency.
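The counter logic generalizes to this small sketch, where `outcomes` stands in for the per-attempt results of the download loop:

```python
def fetch_with_early_exit(outcomes, max_not_found=5):
    # outcomes: iterable of "ok", "not_found", or "transient" per attempt.
    consecutive_not_found = 0
    attempt = 0
    for attempt, outcome in enumerate(outcomes, start=1):
        if outcome == "ok":
            return ("success", attempt)
        if outcome == "not_found":
            consecutive_not_found += 1
            if consecutive_not_found >= max_not_found:
                # Almost certainly missing: stop retrying early.
                return ("missing", attempt)
        else:
            # Transient errors say nothing about existence: reset.
            consecutive_not_found = 0
    return ("exhausted", attempt)
```

The reset on transient errors is the key property: a flaky network interleaved with NOT_FOUNDs keeps the loop alive, while an uninterrupted NOT_FOUND streak terminates it.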
1841-1935: LGTM: Parallel S3 existence checks provide a major performance improvement.

The parallel execution of S3 HEAD requests (lines 1854-1870) is a significant optimization that should eliminate the ~5s per-failure overhead mentioned in the PR objectives. The logic correctly distinguishes between network failures (file exists but download failed → don't slash) and missing uploads (file doesn't exist → legitimate slash).
2148-2172: LGTM: Clean round-robin UID partitioning.

The deterministic partitioning logic is straightforward and correct. Sorting UIDs (line 2168) ensures consistency across ranks, and the round-robin slice notation `[rank::world_size]` (line 2171) evenly distributes the load.
2213-2297: LGTM: Comprehensive merging logic with proper device consolidation.

The recursive `move_to_device` helper (lines 2240-2251) and state-dict merging (lines 2253-2267) correctly consolidate partial results onto a single device. Metrics accumulation (lines 2269-2290) and success-rate calculation (lines 2278-2280) properly aggregate statistics across all partials.
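The traversal such a helper performs can be shown device-free: `map_nested` below applies a function to every leaf of a nested container, the same shape-preserving recursion a `move_to_device` helper would use with `tensor.to(device)` at the leaves (the function name and the doubling example are illustrative):

```python
def map_nested(fn, obj):
    # Recurse through dicts, lists, and tuples; apply fn at the leaves.
    if isinstance(obj, dict):
        return {k: map_nested(fn, v) for k, v in obj.items()}
    if isinstance(obj, tuple):
        return tuple(map_nested(fn, v) for v in obj)
    if isinstance(obj, list):
        return [map_nested(fn, v) for v in obj]
    return fn(obj)  # leaf: e.g. tensor.to(target_device)

moved = map_nested(lambda x: x * 2, {"vals": [1, (2, 3)], "count": 4})
```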
2299-2448: LGTM: Well-structured distributed gather with efficient caching.

The partition caching (lines 2334-2349) with LRU eviction, per-rank gathering (lines 2351-2368), and all-gather synchronization (lines 2388-2404) form a robust distributed gather implementation. The dual return path (lines 2406-2447) supports both incremental processing (partials) and traditional merged results, with rank metadata (lines 2426-2429) enabling deterministic downstream processing.
2452-2724: LGTM: Distributed gather integration with proper fallback consistency.

The distributed gather integration (lines 2509-2575) correctly determines when to use distributed mode and handles exceptions gracefully. The fallback logic (lines 2546-2564) ensures all ranks fall back together, maintaining consistency as noted in previous reviews. The dual-mode handling (lines 2589-2723) properly supports both merged results and partials, with reserve-tier integration working correctly for both paths.
Force-pushed from 085574a to 146a54c.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
neurons/validator.py (1)
3837-3933: Critical: UnboundLocalError and qparams/vals misalignment in compute_peer_val_norms.
- For list (partials) path, state_dict is undefined but used later → UnboundLocalError.
- all_vals/all_quant_params can become misaligned when some partials miss quant_params, causing dequant IndexError or norms on uint8.
Apply this fix:
- Return early after computing norms for partials.
- Keep qparams aligned with vals (pad/truncate).
- Warn on inconsistencies.
```diff
 def compute_peer_val_norms(
     self,
-    gather_result: SimpleNamespace | list[SimpleNamespace],
+    gather_result: SimpleNamespace | list[SimpleNamespace],
 ) -> dict[str, torch.Tensor]:
@@
-    # Handle both merged results and partials
-    if isinstance(gather_result, list):
+    # Handle both merged results and partials
+    if isinstance(gather_result, list):
@@
-        for n, p in self.model.named_parameters():
+        for n, p in self.model.named_parameters():
             vals_key = n + "vals"
             quant_key = n + "quant_params"
@@
-            for partial in gather_result:
+            for partial in gather_result:
                 # Skip None partials (ranks that gathered 0 peers)
                 if partial is None:
                     continue
@@
-                vals = getattr(state_dict, vals_key, None)
-                quant_params = getattr(state_dict, quant_key, None)
+                vals = getattr(state_dict, vals_key, None)
+                quant_params = getattr(state_dict, quant_key, None)
@@
-                if vals is not None:
-                    # vals is a list of tensors, one per peer in this partial
-                    # Move all tensors (including nested) to target device
-                    for v in vals:
-                        all_vals.append(move_to_device_recursive(v))
-                    # quant_params is also a list if present
-                    if quant_params is not None:
-                        for qp in quant_params:
-                            all_quant_params.append(move_to_device_recursive(qp))
+                if vals is not None:
+                    # vals: list[Tensor] per peer in this partial
+                    num_vals = len(vals)
+                    for v in vals:
+                        all_vals.append(move_to_device_recursive(v))
+                    # Align quant_params with vals length
+                    if quant_params is None:
+                        all_quant_params.extend([None] * num_vals)
+                    else:
+                        qps = [move_to_device_recursive(qp) for qp in quant_params]
+                        if len(qps) < num_vals:
+                            # pad with None to keep indices aligned
+                            qps.extend([None] * (num_vals - len(qps)))
+                            tplr.logger.warning(
+                                f"Inconsistent quant_params for {vals_key}; padding {num_vals - len(quant_params)} None(s)"
+                            )
+                        elif len(qps) > num_vals:
+                            qps = qps[:num_vals]
+                        all_quant_params.extend(qps)
@@
-            if not all_vals:
-                continue
+            if not all_vals:
+                continue
@@
-            vals_f32 = self.compressor.maybe_dequantize_values(
-                all_vals,
-                all_quant_params if all_quant_params else None,
-                target_device,
-            )
+            vals_f32 = self.compressor.maybe_dequantize_values(
+                all_vals, all_quant_params, target_device
+            )
@@
-            norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
+            norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
             clip_norm_dict[vals_key] = torch.median(norms)
-        else:
-            # Merged result: original logic
-            state_dict = gather_result.state_dict
-            if not state_dict:
-                raise ValueError("Must have gather_result.state_dict to compute norms")
+        # Early-return for partials; avoid falling through to merged path
+        return clip_norm_dict
+    else:
+        # Merged result: original logic
+        state_dict = gather_result.state_dict
+        if not state_dict:
+            raise ValueError("Must have gather_result.state_dict to compute norms")
@@
-        for n, p in self.model.named_parameters():
+        for n, p in self.model.named_parameters():
             vals_key = n + "vals"
             quant_key = n + "quant_params"
@@
-        return clip_norm_dict
+        return clip_norm_dict
```

This removes the UnboundLocalError and keeps dequantization robust against inconsistent partials.
src/tplr/comms.py (2)
21-21: Fix missing imports used in MPU and network error handling.
`base64` (MD5 header), `EndpointConnectionError`, and `aiohttp` exceptions are referenced but never imported; they will raise at runtime on failures.

Apply this diff near the top import block:

```diff
@@
 import json
+import base64
 import math
@@
-from botocore.exceptions import ClientError, ConnectionClosedError
+from botocore.exceptions import ClientError, ConnectionClosedError, EndpointConnectionError
+from aiohttp.client_exceptions import ClientOSError, ServerDisconnectedError
```

Also applies to: 37-43
582-592: Normalize the optional `bucket` parameter in `s3_get_object`.

The docstring says it defaults to `self.bucket`, but the code uses it directly and it may be `None`.

Apply:

```diff
 async def s3_get_object(
@@
-    bucket: Bucket | None = None,
+    bucket: Bucket | None = None,
@@
-):
+):
@@
-    """Download object from S3 using asynchronous streaming.
+    """Download object from S3 using asynchronous streaming.
@@
-    s3_client = await self._get_s3_client(bucket)
+    # Default to instance bucket unless explicitly provided
+    if bucket is None:
+        bucket = self.bucket
+    s3_client = await self._get_s3_client(bucket)
```
♻️ Duplicate comments (1)
src/tplr/neurons.py (1)
734-742: Fix inverted memory_budget_violation flag (1 should indicate a violation).

Currently the code sets 0 when exceeding the budget and 1 otherwise; invert it in both the incremental and single paths.

```diff
@@
-    budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+    budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0
@@
-    budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+    budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0
```

Also applies to: 1034-1040
🧹 Nitpick comments (8)
neurons/validator.py (3)
95-106: Deduplicate UIDs and base the success rate on unique attempts.

Across partials, a UID could appear twice (fallbacks/retries). Use sets to avoid inflation.

```diff
 def _get_uids_from_gather_result(gather_result) -> list[int]:
@@
-    if isinstance(gather_result, list):
-        all_uids = []
-        for partial in gather_result:
-            if partial is not None:
-                all_uids.extend(getattr(partial, "uids", []))
-        return all_uids
+    if isinstance(gather_result, list):
+        uids = []
+        for partial in gather_result:
+            if partial is not None:
+                uids.extend(getattr(partial, "uids", []))
+        return sorted(set(uids))
@@
 def _get_success_rate_from_gather_result(gather_result) -> float:
@@
-    if isinstance(gather_result, list):
-        total_uids = len(Validator._get_uids_from_gather_result(gather_result))
-        total_skipped = len(
-            Validator._get_skipped_uids_from_gather_result(gather_result)
-        )
-        total_attempted = total_uids + total_skipped
-        if total_attempted == 0:
-            return 0.0
-        return total_uids / total_attempted
+    if isinstance(gather_result, list):
+        unique_uids = set(Validator._get_uids_from_gather_result(gather_result))
+        unique_skipped = set(
+            Validator._get_skipped_uids_from_gather_result(gather_result)
+        )
+        total_attempted = len(unique_uids | unique_skipped)
+        return (len(unique_uids) / total_attempted) if total_attempted else 0.0
```

Also applies to: 120-134
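The set-based accounting can be illustrated standalone; plain dicts stand in for the partial namespaces:

```python
def pooled_success_rate(partials):
    # Deduplicate across partials: a UID fetched by two ranks (e.g. via
    # retries) must not inflate the numerator or the denominator.
    uids, skipped = set(), set()
    for p in partials:
        if p is None:  # ranks that gathered zero peers
            continue
        uids.update(p.get("uids", []))
        skipped.update(p.get("skipped_uids", []))
    attempted = uids | skipped
    return len(uids) / len(attempted) if attempted else 0.0

rate = pooled_success_rate([
    {"uids": [1, 2], "skipped_uids": [3]},
    None,
    {"uids": [2, 4], "skipped_uids": []},
])
```

Here UID 2 appears in two partials but is counted once, giving 3 unique successes out of 4 unique attempts.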
1556-1574: Pass the device as a string, not via cast().

cast() doesn't convert at runtime; prefer an explicit str for consistency with the Comms API.

```diff
-    device=cast(str, self.device),
+    device=str(self.device),
```
1684-1689: Guard the offload call for partials.

offload_gather_results expects a merged result; if gather_result is a list, it no-ops. Avoid the call for clarity, or offload each partial.

```diff
-    self.offload_gather_results(gather_result, log=True)
+    if not isinstance(gather_result, list):
+        self.offload_gather_results(gather_result, log=True)
```

src/tplr/comms.py (2)
655-661: Prefer structured error inspection in the HEAD path (avoid brittle string checks).

Use `e.response['Error']['Code']` and the HTTP status, as in `s3_object_exists`; also guard `bucket` before the purge.

Apply:

```diff
-    except (ConnectionClosedError, ClientError) as e:
-        await self._purge_s3_client(bucket)
-        if "404" in str(e):
-            tplr.logger.debug(f"Object {key} not found in bucket {bucket.name}")
-            return None
+    except (ConnectionClosedError, ClientError) as e:
+        if bucket is not None:
+            await self._purge_s3_client(bucket)
+        status = None
+        code = None
+        if isinstance(e, ClientError) and hasattr(e, "response"):
+            status = e.response.get("ResponseMetadata", {}).get("HTTPStatusCode")
+            code = e.response.get("Error", {}).get("Code")
+        if status == 404 or code in {"404", "NoSuchKey", "NotFound"}:
+            tplr.logger.debug(f"Object {key} not found in bucket {bucket.name if bucket else '<?>'}")
+            return None
+        if status in {401, 403} or code in {"AccessDenied", "Forbidden"}:
+            tplr.logger.error(f"Permission error checking {key}: {e}")
+            return None
+        tplr.logger.warning(f"HEAD error for {key}: {e}")
+        return None
```
2196-2211: Avoid mutating the global CUDA device in merge_gather_results.

Calling `torch.cuda.set_device(0)` here is a side effect; choose a device string without changing global state.

```diff
-    if torch.cuda.is_available():
-        try:
-            target_device = f"cuda:{torch.cuda.current_device()}"
-        except RuntimeError:
-            # Fall back to first available device or CPU
-            target_device = "cpu"
-            if torch.cuda.device_count() > 0:
-                # Use device 0, but could be extended to find first available
-                try:
-                    torch.cuda.set_device(0)
-                    target_device = "cuda:0"
-                except RuntimeError:
-                    target_device = "cpu"
+    if torch.cuda.is_available():
+        try:
+            target_device = f"cuda:{torch.cuda.current_device()}"
+        except RuntimeError:
+            target_device = "cuda:0" if torch.cuda.device_count() > 0 else "cpu"
```

src/tplr/neurons.py (3)
557-563: Don't skip the entire batch on duplicate UIDs; prune duplicates instead.

Skipping the whole batch wastes valid gradients. Drop only the partials containing already-applied UIDs.

```diff
-    duplicates = applied_uids & batch_uids
-    if duplicates:
-        tplr.logger.warning(
-            f"Skipping batch {batch_idx + 1}: {len(duplicates)} duplicate UIDs"
-        )
-        continue
+    duplicates = applied_uids & batch_uids
+    if duplicates:
+        before = len(batch_partials)
+        batch_partials = [
+            p for p in batch_partials
+            if not (set(getattr(p, "uids", [])) & applied_uids)
+        ]
+        pruned = before - len(batch_partials)
+        if pruned > 0:
+            tplr.logger.warning(
+                f"Pruned {pruned} partial(s) with duplicate UIDs; continuing with {len(batch_partials)}"
+            )
+        if not batch_partials:
+            continue
```
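The pruning alternative, reduced to its core; dicts stand in for the partial objects:

```python
def prune_applied(batch_partials, applied_uids):
    # Keep only partials whose UID sets are disjoint from what was
    # already applied; those are the only safe ones to process.
    return [
        p for p in batch_partials
        if not (set(p.get("uids", [])) & applied_uids)
    ]

kept = prune_applied(
    [{"uids": [1, 2]}, {"uids": [3]}, {"uids": [2, 4]}],
    applied_uids={2},
)
```

Two of the three partials overlap the already-applied UID 2 and are dropped; the batch still contributes the remaining one instead of being skipped wholesale.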
482-485: Remove redundant inner imports.

`gc` is already imported at the module top; `torch.cuda` is accessible via `torch.cuda` without a separate import.

```diff
-    import gc
-
-    import torch.cuda
+    # (imports removed; use top-level gc and torch.cuda via torch.cuda)
```
1805-1823: Minor: avoid building `all_state_dicts` for the partials path.

For partials, you re-derive `state_dict` via `uid_to_partial`; the `all_state_dicts.append(...)` list is unused and increases memory.

Replace the partials branch to only build `all_uids` and `uid_to_partial`; defer `state_dict` access to the main loop. No behavior change.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
- hparams/2B.json (1 hunks)
- hparams/hparams.json (1 hunks)
- neurons/base_node.py (1 hunks)
- neurons/miner.py (3 hunks)
- neurons/trainer.py (1 hunks)
- neurons/validator.py (13 hunks)
- src/tplr/comms.py (18 hunks)
- src/tplr/compress.py (1 hunks)
- src/tplr/neurons.py (12 hunks)
🚧 Files skipped from review as they are similar to previous changes (5)
- src/tplr/compress.py
- hparams/hparams.json
- neurons/trainer.py
- neurons/base_node.py
- hparams/2B.json
🧰 Additional context used
🧬 Code graph analysis (4)
neurons/validator.py (5)
src/tplr/distributed.py (2)
- is_distributed (75-77)
- broadcast (273-282)

src/tplr/logging.py (2)
- log_with_context (207-223)
- log_with_context (290-309)

src/tplr/comms.py (2)
- gather_with_reserve (2452-2724)
- merge_gather_results (2174-2297)

src/tplr/neurons.py (1)
- check_uid_index_overlap (1789-1994)

src/tplr/compress.py (1)
- maybe_dequantize_values (647-697)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
- Bucket (24-46)

src/tplr/logging.py (2)
- log_with_context (207-223)
- log_with_context (290-309)

src/tplr/distributed.py (2)
- is_distributed (75-77)
- all_gather_object (307-325)
src/tplr/neurons.py (2)
src/tplr/compress.py (6)
- compress (373-377)
- compress (379-383)
- compress (386-423)
- ChunkingTransformer (145-292)
- TopKCompressor (295-697)
- unpack_12bit_indices (102-142)

src/tplr/comms.py (3)
- Comms (64-3429)
- merge_gather_results (2174-2297)
- get (1503-1621)
neurons/miner.py (3)
src/tplr/chain.py (1)
- get_commitments (264-322)

src/tplr/logging.py (1)
- T (41-48)

src/tplr/comms.py (1)
- gather (1740-2142)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.11)
- GitHub Check: test (3.12)
🔇 Additional comments (8)
neurons/miner.py (3)
468-470: Good call refreshing commitments before gather. Ensures R2 existence checks use current bucket info.

707-731: Sequential gather params look correct and production-safe. Passing local=False and stale_retention=100 avoids local FS misses; expected_compressed_params is forwarded.

733-744: Clear, partial-aware logging. Logs cover successes, skips, rate, and timing; guarded for None.
neurons/validator.py (4)
1392-1410: Peers broadcast across ranks looks right. Using broadcast_object_list aligns shards before distributed gather.

1452-1479: Timestamp: master query + broadcast is correct. Prevents rank skew in time-window selection.

1659-1670: Overlap check handles partials/merged. Good reuse of a single API for both paths.

975-981: Rating initialization before OpenSkill.rate is correct. Prevents KeyErrors for new peers.
src/tplr/comms.py (1)
708-717: Good fix: offload torch.load to a thread. Avoids blocking the event loop on CPU decompression. Solid improvement.
Actionable comments posted: 0
🧹 Nitpick comments (2)
neurons/miner.py (1)
469-481: Consider running `get_commitments()` only on the master rank to reduce chain queries.

Line 470 calls `get_commitments()` on all ranks when only the master rank needs the updated commitments to broadcast. The non-master ranks receive the peer list via broadcast and don't use commitments at this point.

Apply this diff to reduce redundant chain queries:

```diff
 peer_start = tplr.T()
 if self.is_master:
     await tplr.neurons.update_peers(
         instance=self, window=step_window, peer_start=peer_start
     )
-# Refresh commitments to get updated bucket info for all peers
-self.comms.commitments = await self.comms.get_commitments()
+    # Refresh commitments to get updated bucket info for all peers
+    self.comms.commitments = await self.comms.get_commitments()

 # Broadcast peer list from master to all ranks for distributed gather
 if dist_helper.world_size > 1 and dist_helper.is_distributed():
```

src/tplr/comms.py (1)
2196-2210: Device selection fallback could fail when device 0 is unavailable.

The fallback logic (lines 2200-2209) assumes `cuda:0` exists when `device_count() > 0`, but device 0 might be unavailable on some systems. While `torch.cuda.set_device(0)` is wrapped in try/except (lines 2205-2209), a more robust approach would probe available devices sequentially.

Consider this alternative that finds the first usable device:

```diff
 try:
     target_device = f"cuda:{torch.cuda.current_device()}"
 except RuntimeError:
-    # Fall back to first available device or CPU
     target_device = "cpu"
-    if torch.cuda.device_count() > 0:
-        # Use device 0, but could be extended to find first available
-        try:
-            torch.cuda.set_device(0)
-            target_device = "cuda:0"
-        except RuntimeError:
-            target_device = "cpu"
+    # Probe for first available CUDA device
+    for device_id in range(torch.cuda.device_count()):
+        try:
+            torch.cuda.set_device(device_id)
+            target_device = f"cuda:{device_id}"
+            break
+        except RuntimeError:
+            continue
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- neurons/miner.py (6 hunks)
- src/tplr/comms.py (18 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
- Bucket (24-46)

src/tplr/logging.py (2)
- log_with_context (207-223)
- log_with_context (290-309)

src/tplr/distributed.py (2)
- is_distributed (75-77)
- all_gather_object (307-325)
neurons/miner.py (3)
src/tplr/chain.py (1)
- get_commitments (264-322)

src/tplr/distributed.py (1)
- is_distributed (75-77)

src/tplr/comms.py (1)
- gather_with_reserve (2439-2713)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.12)
- GitHub Check: test (3.11)
🔇 Additional comments (18)
neurons/miner.py (3)
762-792: LGTM!

The dual-path handling for partials (list) vs. merged results (SimpleNamespace) is implemented correctly. The code safely extracts UIDs and skipped UIDs using `getattr` with defaults, aggregates metrics across partials, and logs appropriate per-rank context for distributed mode.

885-904: LGTM!

The debug dictionary construction correctly handles both partials (list) and merged results (SimpleNamespace). The use of sets for deduplication when computing `successful_peers` and `skipped_peers` is appropriate and prevents duplicate UIDs in the debug output.

940-963: LGTM!

The metrics calculation correctly aggregates statistics from both partials (list) and merged results (SimpleNamespace). The success-rate computation accounts for both paths, and the extraction of `skipped_uids_list` uses safe defaults.

src/tplr/comms.py (15)
516-519: Timeout handling is conservative and appropriate.Returning
Trueonasyncio.TimeoutError(line 516) follows the conservative approach for transient network issues, which is correct for preventing false slashing. The debug-level log is suitable for a timeout that's likely transient.
707-721: LGTM!Offloading
torch.load()to a thread pool (line 711) prevents blocking the event loop during CPU decompression, which is critical for parallel downloads. The error handling appropriately logs failures and returnsNoneto signal the issue to callers.
1469-1490: LGTM!The
versionparameter now defaults toNoneand is normalized totplr.__version__when not provided (lines 1489-1490). This is a clean API design that allows callers to omit the parameter while maintaining backwards compatibility.
1561-1564: LGTM!Offloading
torch.load()for local files (line 1562) maintains consistency with the remote loading path and prevents event loop blocking even for local storage. This is a good defensive practice.
1667-1734: LGTM!The consecutive NOT_FOUND tracking (lines 1667-1668, 1723-1734) is a smart optimization that prevents long retry loops when a file clearly doesn't exist. The threshold of 5 consecutive failures (2.5 seconds) is reasonable, and resetting the counter on other error types correctly distinguishes transient network issues from permanent absence.
1842-1871: LGTM!Parallelizing S3 existence checks (lines 1854-1870) is an excellent optimization that avoids serial API calls. Running all checks concurrently with
asyncio.gatherand caching results inexistence_resultsprevents the 5-second-per-failure delay mentioned in the comment. Treating exceptions asFalse(line 1865) is conservative and appropriate.
1873-1935: LGTM!The failure attribution logic correctly distinguishes between network issues and miner failures using the cached existence checks from lines 1892 and 1921. Network failures (file exists but download failed) don't add to
skipped_uidsto prevent false slashing (lines 1893-1900, 1922-1928), while genuine miner failures (file doesn't exist) are correctly flagged for slashing (lines 1902-1908, 1931-1934). The logging levels are appropriate for each case.
2147-2171: LGTM! The `partition_uids_across_ranks` method implements deterministic round-robin UID assignment (line 2171) with proper sorting (line 2168) to ensure all ranks agree on the partition. The early return for `world_size <= 1` (lines 2164-2165) is a sensible optimization.
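The sort-then-round-robin scheme described here can be illustrated with a small stand-alone sketch (illustrative names, not the actual method):

```python
def partition_uids_across_ranks(uids, world_size, rank):
    # Sort first so every rank computes the identical assignment, then deal
    # UIDs out round-robin: rank r takes positions r, r + world_size, ...
    if world_size <= 1:
        return list(uids)
    ordered = sorted(uids)
    return ordered[rank::world_size]
```

Sorting is what makes the partition deterministic: every rank derives the same ordering from the same UID set, so no coordination message is needed to agree on who downloads what.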
2240-2297: LGTM! The `merge_gather_results` implementation is thorough and correct:
- The recursive `move_to_device` helper (lines 2240-2251) handles all common Python data structures
- State dict merging (lines 2254-2267) correctly concatenates parameter lists while moving tensors to the target device
- Metrics aggregation (lines 2270-2290) uses appropriate operations: max for time (parallel execution), sum for bytes, extend for lists
- The aggregate success rate calculation (lines 2278-2280) correctly divides total successful by total attempted
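A toy version of this merge logic, using plain dicts and lists in place of the repo's namespace-based state (all field names here are assumptions for illustration):

```python
from types import SimpleNamespace

def merge_gather_results(partials):
    """Sketch of merging per-rank partial gathers; names are illustrative."""
    valid = [p for p in partials if p is not None]
    if not valid:
        return None
    merged_state, uids = {}, []
    total_bytes, max_time = 0, 0.0
    success_num = success_den = 0
    for p in valid:
        for key, tensors in p.state_dict.items():
            merged_state.setdefault(key, []).extend(tensors)  # concatenate per-param lists
        uids.extend(p.uids)
        total_bytes += p.total_bytes          # sum: bytes add across ranks
        max_time = max(max_time, p.elapsed)   # max: ranks ran in parallel
        success_num += p.n_success
        success_den += p.n_attempted
    return SimpleNamespace(
        state_dict=merged_state,
        uids=uids,
        total_bytes=total_bytes,
        elapsed=max_time,
        success_rate=success_num / success_den if success_den else 0.0,
    )
```

Note the metric choices the review highlights: elapsed time is aggregated with `max` because ranks fetch concurrently, while byte counts and UID lists simply accumulate.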
2334-2349: LGTM! The partition caching strategy (lines 2334-2349) is well-designed:
- Cache key includes window, sorted UIDs, and world_size for deterministic lookups
- LRU eviction maintains a bounded cache size (32 entries)
- Moving accessed entries to the end (line 2340) implements LRU correctly
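The OrderedDict-based LRU described above looks roughly like this sketch (a hypothetical cache class; the real code keeps 32 entries, shrunk here for testing):

```python
from collections import OrderedDict

_MAX_ENTRIES = 4  # the review mentions 32 in the real code

class PartitionCache:
    def __init__(self):
        self._cache = OrderedDict()

    @staticmethod
    def _key(window, uids, world_size):
        # Sorted-UID tuple makes lookups deterministic across ranks.
        return (window, tuple(sorted(uids)), world_size)

    def get(self, window, uids, world_size):
        key = self._key(window, uids, world_size)
        if key in self._cache:
            self._cache.move_to_end(key)  # refresh LRU position on hit
            return self._cache[key]
        return None

    def put(self, window, uids, world_size, partition):
        key = self._key(window, uids, world_size)
        self._cache[key] = partition
        self._cache.move_to_end(key)
        while len(self._cache) > _MAX_ENTRIES:
            self._cache.popitem(last=False)  # evict oldest entry
```

`move_to_end` on hits plus `popitem(last=False)` on overflow is the standard OrderedDict idiom for a bounded LRU.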
2382-2392: LGTM! The all-gather synchronization (lines 2382-2392) correctly uses `dist_helper.all_gather_object` to collect partials from all ranks. The synchronous call is appropriate here as it's a blocking collective operation by design. The early normalization of empty payloads to `None` (lines 2395-2399) simplifies downstream filtering.
2402-2434: LGTM! The dual return path (lines 2402-2434) correctly implements the `return_partials` flag:
- When `True`: returns unmerged list with rank metadata (lines 2414-2416) for incremental processing
- When `False`: merges on all ranks (line 2420) for consistency

The rank-0-only logging (lines 2408-2412, 2424-2432) avoids duplicate logs while providing comprehensive metrics.
2505-2553: Distributed fallback correctly runs sequential gather on all ranks. Lines 2537-2553 implement the fallback by having all ranks run sequential gather (line 2548) when distributed gather fails. This maintains consistency across ranks as noted in the comment (line 2540). The comment at lines 2546-2547 correctly states that sequential gather (R2 downloads) can be done independently by each rank.
2630-2668: Reserve retry correctly uses sequential gather and handles both result types. Line 2631 appropriately uses sequential `gather()` for reserve UIDs rather than distributed, as reserve retries are typically smaller and don't benefit from distribution overhead. The dual handling (lines 2639-2668) correctly:
- Appends fallback as a new partial with rank metadata (lines 2641-2651) when `primary` is a list
- Merges fallback into the existing SimpleNamespace (lines 2653-2668) when `primary` is merged
2673-2711: LGTM! The final success rate computation and dual logging (lines 2673-2711) correctly handles both return types:
- For partials: aggregates across all partials, adds `success_rate` to each (lines 2681-2683), and returns `None` if no gradients (lines 2697-2698)
- For merged: computes simple rate and returns `None` if empty (lines 2700-2711)

The early return of `None` for zero gradients (lines 2697, 2710) ensures callers receive consistent failure signals.
Actionable comments posted: 0
🧹 Nitpick comments (1)
src/tplr/comms.py (1)
2196-2211: Device fallback could still fail on multi-GPU systems.The fallback logic at lines 2201-2210 attempts to use device 0 but doesn't verify its actual availability before setting it. If device 0 is unavailable (e.g., on systems with only devices 1-3),
`torch.cuda.set_device(0)` will raise a RuntimeError and fall back to CPU, which might not be intended. Consider iterating through available devices:
```diff
 if torch.cuda.is_available():
     try:
         target_device = f"cuda:{torch.cuda.current_device()}"
     except RuntimeError:
-        # Fall back to first available device or CPU
         target_device = "cpu"
-        if torch.cuda.device_count() > 0:
-            # Use device 0, but could be extended to find first available
-            try:
-                torch.cuda.set_device(0)
-                target_device = "cuda:0"
-            except RuntimeError:
-                target_device = "cpu"
+        # Try to find first actually available device
+        for device_id in range(torch.cuda.device_count()):
+            try:
+                with torch.cuda.device(device_id):
+                    target_device = f"cuda:{device_id}"
+                    break
+            except RuntimeError:
+                continue
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- hparams/hparams.json (1 hunks)
- neurons/trainer.py (1 hunks)
- src/tplr/comms.py (18 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- hparams/hparams.json
- neurons/trainer.py
🧰 Additional context used
🧬 Code graph analysis (1)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
- Bucket (24-46)
src/tplr/logging.py (2)
- log_with_context (207-223)
- log_with_context (290-309)
src/tplr/distributed.py (2)
- is_distributed (75-77)
- all_gather_object (307-325)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.12)
- GitHub Check: test (3.11)
🔇 Additional comments (14)
src/tplr/comms.py (14)
28-28: LGTM: Imports and cache field properly structured. The module-level imports and partition cache structure are well-designed. The OrderedDict cache with tuple keys enables efficient LRU eviction for deterministic UID partitioning across windows.
Also applies to: 49-49, 141-143
482-581: LGTM: Robust existence check with conservative error handling. The structured error inspection and conservative assumptions (returning True for transient errors) correctly prevent false slashing of miners due to network issues. Permission errors appropriately logged at ERROR level.
707-721: LGTM: Critical fix for event-loop blocking. Offloading `torch.load()` to a thread pool via `asyncio.to_thread` is essential for enabling true parallel downloads. The 1-2s CPU decompression time per file would otherwise serialize concurrent operations. Error handling and `weights_only=True` are appropriate.
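The offload pattern can be sketched with a generic blocking function standing in for `torch.load` (everything here is illustrative, not the tplr code):

```python
import asyncio
import time

def blocking_load(path):
    # Stand-in for torch.load(...): CPU-bound work that would otherwise
    # block the event loop and serialize concurrent downloads.
    time.sleep(0.01)
    return {"path": path}

async def load_one(path):
    try:
        return await asyncio.to_thread(blocking_load, path)
    except Exception:
        return None  # signal failure to the caller, as the review describes

async def load_many(paths):
    # Each load runs in a worker thread, so the event loop stays free to
    # drive other downloads concurrently.
    return await asyncio.gather(*(load_one(p) for p in paths))

results = asyncio.run(load_many(["a.pt", "b.pt", "c.pt"]))
```

Without `asyncio.to_thread`, each blocking call would stall the loop for its full duration, turning N parallel fetches into a serial chain.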
1561-1564: LGTM: Consistent thread-pool offloading for local files. Local file loading also correctly uses `asyncio.to_thread` to prevent event-loop blocking, maintaining consistency with remote S3 downloads.
1469-1490: LGTM: Fixed undefined version reference. The signature now properly defaults `version` to `None` and falls back to `tplr.__version__`, resolving the previous undefined `self.version` issue. This pattern allows optional version specification while maintaining backward compatibility.
1667-1734: LGTM: Smart early-exit optimization for missing files. Tracking consecutive NOT_FOUND responses enables early termination (2.5s) when a file clearly doesn't exist, avoiding full timeout waits. Resetting the counter on other error types correctly distinguishes persistent absence from transient network issues.
1841-1870: LGTM: Critical performance optimization for failure handling. Parallelizing S3 existence checks dramatically reduces failure-handling latency from O(N * 5s) to O(5s) for N failures. The cached results enable efficient O(1) lookups during response processing, preventing cascading timeouts in distributed scenarios.
1873-1935: LGTM: Sophisticated failure attribution prevents false slashing. The enhanced error handling correctly distinguishes between miner failures (no upload) and network/timeout issues using the pre-computed existence checks. Only legitimate upload failures are added to `skipped_uids`, preventing false slashing while maintaining accountability.
2148-2172: LGTM: Clean deterministic partitioning. The round-robin partitioning strategy with sorted UIDs ensures deterministic, balanced distribution across ranks. The implementation is simple and correct.
2239-2297: LGTM: Robust merge implementation with recursive device movement. The merge logic correctly handles nested structures via recursive `move_to_device`, consolidates per-rank state dictionaries by extending parameter lists, and properly aggregates metrics. Moving tensors to the target device during merge (rather than afterward) is more efficient.
2334-2350: LGTM: Efficient partition caching with LRU eviction. The partition cache uses window-aware keys and OrderedDict for LRU eviction (keeping 32 most recent). The `move_to_end` call on cache hits maintains LRU ordering, and `popitem(last=False)` correctly evicts oldest entries.
2380-2434: LGTM: Correct distributed synchronization and dual return modes. The all-gather barrier correctly blocks for synchronization, empty partials are normalized to None, and rank metadata enables deterministic ordering. The dual return mode (partials vs merged) supports both incremental and immediate processing patterns. All-ranks merging maintains consistency.
2609-2713: LGTM: Reserve retry handles both result modes correctly. The reserve fallback logic correctly branches on result type (partials vs merged), appending fallback as a new partial with rank metadata in one case and extending state dict lists in the other. Success rate is properly recomputed with respect to the original gather target, and None is returned when no gradients are collected.
2535-2553: Looking at the actual code structure from the script outputs, the indentation reveals that the review comment's analysis is incorrect. The code at lines 2546-2553 is NOT inside the `if dist_helper.is_master:` block. The key evidence:
- Line 2538: `if dist_helper.is_master:` (with specific indentation)
- Lines 2539-2545: Logging code (indented deeper, inside master block)
- Lines 2546-2547: Comments with same indentation as the if statement (outside master block)
- Lines 2548-2553: `primary = await self.gather(...)` with same indentation as the if statement (outside master block)

The actual control flow is:

```python
if primary is None:
    if dist_helper.is_master:
        # Only master logs
        context_log(...)
        tplr.logger.info(...)
    # ALL ranks execute this:
    primary = await self.gather(...)
```

This means all ranks already perform the sequential gather independently, which is correct behavior and matches the intent of the comment. No broadcast is needed because each rank downloads independently (R2 downloads are rank-local operations).
The review comment's proposed diff is unnecessary—the code already implements the intended behavior.
Likely an incorrect or invalid review comment.
Actionable comments posted: 2
♻️ Duplicate comments (4)
neurons/validator.py (2)
1452-1479: Timestamp broadcast: consistent windowing across ranks. LGTM. Master queries, broadcasts to all ranks, then shared time_min/time_max. This fixes prior inconsistency; indentation and device usage are correct.
1647-1667: Single merge for upload; overlap uses partial-aware path. The prior duplicate merges are removed; now you merge once for upload and pass original gather_result to overlap. Memory win.
Also applies to: 1657-1664
src/tplr/neurons.py (2)
741-741: Still inverted: budget_exceeded flag logic remains incorrect. This was flagged in previous reviews but not corrected. The flag is set to `0` when budget IS exceeded and `1` when it's NOT exceeded, which is counter-intuitive and breaks metric semantics. Apply this diff:

```diff
- budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+ budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0
```

Note: The same issue exists at line 1039 in `_outer_step_single`.
1039-1039: Still inverted: budget_exceeded flag logic incorrect here too. Same issue as line 741. The flag should be `1` when budget is exceeded, not `0`. Apply this diff:

```diff
- budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+ budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0
```
🧹 Nitpick comments (3)
neurons/validator.py (2)
138-202: Extend offload/load helpers to handle list-of-partials (memory relief in distributed mode). `offload_gather_results`/`load_gather_results_to_device` currently no-op for partials, leaving tensors on GPU. Add list handling.

```diff
 def offload_gather_results(self, gather_result, *, log: bool = True) -> None:
 @@
-    state = getattr(gather_result, "state_dict", None)
+    # Handle list-of-partials early
+    if isinstance(gather_result, list):
+        for partial in gather_result:
+            if partial is None or getattr(partial, "_offloaded", None) == "cpu":
+                continue
+            state = getattr(partial, "state_dict", None)
+            if state is None:
+                continue
+            if isinstance(state, SimpleNamespace):
+                _to_pinned_cpu(state)
+            elif isinstance(state, dict):
+                for k, v in list(state.items()):
+                    state[k] = _to_pinned_cpu(v)
+            setattr(partial, "_offloaded", "cpu")
+        if torch.cuda.is_available():
+            torch.cuda.synchronize(getattr(self, "device", None))
+            torch.cuda.empty_cache()
+        return
+
+    state = getattr(gather_result, "state_dict", None)
 @@
 def load_gather_results_to_device(
     self, gather_result, device: torch.device, *, log: bool = True
 ) -> None:
 @@
-    state = getattr(gather_result, "state_dict", None)
+    if isinstance(gather_result, list):
+        if torch.cuda.is_available():
+            torch.cuda.synchronize(device)
+        for partial in gather_result:
+            if partial is None or getattr(partial, "_offloaded", None) != "cpu":
+                continue
+            state = getattr(partial, "state_dict", None)
+            if state is None:
+                continue
+            if isinstance(state, SimpleNamespace):
+                _to_device(state)
+            elif isinstance(state, dict):
+                for k, v in list(state.items()):
+                    state[k] = _to_device(v)
+            setattr(partial, "_offloaded", None)
+        if torch.cuda.is_available():
+            torch.cuda.synchronize(device)
+        return
+
+    state = getattr(gather_result, "state_dict", None)
```

If you want, I can add a micro benchmark to compare VRAM before/after offloading partials.
Also applies to: 203-264
1392-1410: Trim log volume when broadcasting peers to non-master ranks. Logging entire peer lists every window can spam logs and increase costs. Prefer counts.

```diff
- tplr.log_with_context(
-     level="info",
-     message=f"[Rank {dist_helper.rank}] Received peers from master: gather={self.comms.peers}, reserve={self.comms.reserve_peers}",
+ tplr.log_with_context(
+     level="info",
+     message=f"[Rank {dist_helper.rank}] Received peers from master: gather={len(self.comms.peers)} peers, reserve={len(self.comms.reserve_peers)} peers",
      sync_window=self.sync_window,
      current_window=self.current_window,
  )
```

src/tplr/neurons.py (1)
1740-1755: Consider clarifying the comment. The implementation uses `torch.maximum` to take the larger of `param_avg_change` and `learning_rate`, which is good for stability. However, the comment at lines 1740-1741 says "Use learning_rate as the baseline" which doesn't fully capture that you're using the maximum of both values. Consider updating the comment to be more precise about the approach. Example:

```diff
- # Use learning_rate as the baseline step size for more stable comparison
- # param_avg_change can be very small when model is converged, leading to inflated step ratios
+ # Use the maximum of param_avg_change and learning_rate for stable step ratio calculation
+ # This avoids division by tiny numbers when the model has converged
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- neurons/validator.py (14 hunks)
- src/tplr/neurons.py (13 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/tplr/neurons.py (2)
src/tplr/compress.py (5)
- compress (373-377)
- compress (379-383)
- compress (386-423)
- TopKCompressor (295-697)
- unpack_12bit_indices (102-142)
src/tplr/comms.py (2)
- merge_gather_results (2174-2297)
- get (1503-1621)
neurons/validator.py (5)
src/tplr/distributed.py (2)
- is_distributed (75-77)
- broadcast (273-282)
src/tplr/logging.py (2)
- log_with_context (207-223)
- log_with_context (290-309)
src/tplr/comms.py (3)
- gather_with_reserve (2439-2713)
- merge_gather_results (2174-2297)
- get_with_retry (1623-1738)
src/tplr/neurons.py (1)
- check_uid_index_overlap (1795-2000)
src/tplr/compress.py (1)
- maybe_dequantize_values (647-697)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.12)
- GitHub Check: test (3.11)
🔇 Additional comments (6)
src/tplr/neurons.py (6)
48-70: LGTM: Robust CUDA device resolution with proper fallbacks. The helper correctly handles edge cases (unavailable CUDA, invalid device strings, RuntimeErrors) and provides sensible fallbacks. Good defensive programming.
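The shape of such a resolver can be sketched without a GPU by passing availability in explicitly (a hypothetical function, not tplr's `_resolve_cuda_device`, which queries `torch.cuda` directly):

```python
def resolve_cuda_device(requested, cuda_available, device_count):
    # Hypothetical resolver: availability probes are parameters so the
    # sketch runs anywhere; the real helper asks torch.cuda itself.
    if not cuda_available or device_count == 0 or requested == "cpu":
        return "cpu"
    if requested and requested.startswith("cuda:"):
        idx = requested.split(":", 1)[1]
        if idx.isdigit() and int(idx) < device_count:
            return f"cuda:{int(idx)}"
    return "cuda:0"  # invalid or unspecified index: fall back to first device
```

The key property the review praises is that every malformed input degrades to a usable device instead of raising.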
250-322: LGTM: Clean dispatcher implementation. The function signature and dispatch logic correctly handle both single merged results and lists of partials, forwarding all parameters appropriately to the specialized implementations.
324-369: LGTM: Recursive memory estimation correctly implemented. The recursive `_iter_tensors` helper properly traverses all container types (dict, list, tuple, SimpleNamespace) to sum tensor memory, addressing the previous issue where memory estimates were often 0 MB. Based on learnings
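The recursive-traversal idea, sketched with a stand-in tensor type so it runs anywhere (illustrative only, not the actual `_iter_tensors`):

```python
from types import SimpleNamespace

class FakeTensor:
    """Stand-in for torch.Tensor carrying only a byte count."""
    def __init__(self, nbytes):
        self.nbytes = nbytes

def iter_tensors(obj):
    # Recursively yield tensors nested in dicts, lists, tuples, namespaces.
    if isinstance(obj, FakeTensor):
        yield obj
    elif isinstance(obj, dict):
        for value in obj.values():
            yield from iter_tensors(value)
    elif isinstance(obj, (list, tuple)):
        for value in obj:
            yield from iter_tensors(value)
    elif isinstance(obj, SimpleNamespace):
        yield from iter_tensors(vars(obj))

def estimate_mb(obj):
    return sum(t.nbytes for t in iter_tensors(obj)) / (1024 * 1024)
```

The earlier 0 MB estimates came from only inspecting the top level; the recursion is what surfaces tensors buried inside nested containers.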
371-419: LGTM: Sound batch size calculation with safety margins. The sampling approach (first 3 partials) and 0.7 safety factor provide reasonable estimates while avoiding OOM risks. Edge cases are handled properly.
665-696: LGTM: Excellent error recovery and batch failure logging. The per-batch try/except with detailed failure tracking allows processing to continue even when individual batches fail. The corrected logging using `batch_idx` (line 693) properly addresses the previous KeyError issue. Based on learnings
1795-1945: LGTM: Efficient overlap detection with O(n) optimization. The uid_to_partial mapping (lines 1816-1826) eliminates the O(n×m) nested loop search that was flagged previously, achieving O(1) lookup per UID (line 1881). The function now correctly handles both merged and partial results.
Based on learnings
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- hparams/hparams.json (2 hunks)
- src/tplr/model_factory.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- hparams/hparams.json
Force-pushed: 0d669d2 to 5060b03, then 956d520 to ca50d92.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
neurons/miner.py (1)
843-877: Critical: `gather_result` is cleared/deleted before later use, causing errors and losing metrics. Inside the `if should_update:` block you now do:
- For list results: delete each partial's `state_dict` and then `gather_result.clear()`.
- For merged results: delete `gather_result.state_dict`.
- In both cases: `del gather_result`.

Later in the same loop you still rely on `gather_result` in several places:
- Building `debug_dict["successful_peers"]` and `debug_dict["skipped_peers"]` (Lines 927–949).
- Computing `gather_success_rate` and `skipped_uids_list` for wandb and Influx metrics (Lines 985–1007, 1060–1072).
- Final cleanup at the very end of the loop (`if gather_result is not None: del gather_result` at Lines 1088–1092).

When `should_update` is True (the common case), the earlier `del gather_result` makes all these later references an `UnboundLocalError`. For the list case, even without the explicit `del`, `gather_result.clear()` destroys the partials needed to compute success and skipped peer sets. You can keep the memory savings while fixing the lifecycle by only dropping the heavy `state_dict` payloads and deferring deletion of the container until after metrics/debug logging. For example:

```diff
@@
-            if should_update:
-                gradient_fingerprint = self.outer_step(gather_result)
-                self.global_step += (
-                    1  # Increment only when we actually do an outer step
-                )
-
-                # Aggressive cleanup after outer step to prevent memory leaks
-                # Clear gradients and free any tensors from the merge/apply operation
-                if hasattr(self, "model") and self.model is not None:
-                    self.model.zero_grad(set_to_none=True)
-
-                # Clear gather_result to free memory
-                if gather_result is not None:
-                    if isinstance(gather_result, list):
-                        for partial in gather_result:
-                            if partial is not None and hasattr(partial, "state_dict"):
-                                delattr(partial, "state_dict")
-                        gather_result.clear()
-                    elif hasattr(gather_result, "state_dict"):
-                        delattr(gather_result, "state_dict")
-                del gather_result
+            if should_update:
+                gradient_fingerprint = self.outer_step(gather_result)
+                self.global_step += 1  # Increment only when we actually do an outer step
+
+                # Aggressive cleanup after outer step to prevent memory leaks
+                # Clear gradients and free any tensors from the merge/apply operation
+                if hasattr(self, "model") and self.model is not None:
+                    self.model.zero_grad(set_to_none=True)
+
+                # Drop heavy tensor payloads from gather_result but keep metadata
+                if gather_result is not None:
+                    if isinstance(gather_result, list):
+                        for partial in gather_result:
+                            if partial is not None and hasattr(partial, "state_dict"):
+                                delattr(partial, "state_dict")
+                    elif hasattr(gather_result, "state_dict"):
+                        delattr(gather_result, "state_dict")
@@
-            # Delete any remaining local variables to clear up memory
-            del shard_gradient
-            if gather_result is not None:
-                del gather_result
+            # Delete any remaining local variables to clear up memory
+            del shard_gradient
+            if gather_result is not None:
+                del gather_result
```

This way:
- `gather_result` (and any partials) remain available for debug and metrics, but without their large `state_dict` tensors.
- Final deletion still happens once at the end of the loop, after all uses.

Apply an equivalent fix if you shuffle code, but the key is: don't clear/delete `gather_result` before its last use.
Also applies to: 927-949, 985-1007, 1060-1072, 1088-1092
♻️ Duplicate comments (3)
neurons/validator.py (2)
1459-1489: Timestamp broadcast fix resolves prior rank inconsistency. Only querying `query_block_timestamp` on master, then broadcasting `ts_value` with `dist_helper.broadcast`, ensures every rank uses the same `[time_min, time_max]` window. This addresses the earlier concern where some ranks could fall back to `time.time()` while others used chain timestamps.
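The master-queries-then-broadcasts pattern can be simulated without torch.distributed (the shared-dict "channel" stands in for `dist.broadcast_object_list`, and the ±30 s window bound is an assumed value for illustration):

```python
def broadcast_object(obj_list, src_rank, my_rank, channel):
    # Toy stand-in for dist.broadcast_object_list using a shared dict.
    if my_rank == src_rank:
        channel["payload"] = list(obj_list)
    else:
        obj_list[:] = channel["payload"]

def resolve_time_window(my_rank, channel, query_chain):
    # Only the master queries the chain; workers receive the broadcast value,
    # so every rank ends up with the identical [time_min, time_max] window.
    payload = [query_chain()] if my_rank == 0 else [None]
    broadcast_object(payload, src_rank=0, my_rank=my_rank, channel=channel)
    ts = payload[0]
    return (ts - 30, ts + 30)  # assumed +/-30 s tolerance, illustrative only
```

Because workers never call `query_chain` themselves, there is no path where one rank uses a chain timestamp while another silently falls back to its local clock.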
3873-3971: Bug: `compute_peer_val_norms` still falls through for partials, using undefined `state_dict`. When `gather_result` is a `list[SimpleNamespace]` (the partials case used by distributed gather):
- The `if isinstance(gather_result, list):` block builds `clip_norm_dict` by aggregating and dequantizing vals from all partials.
- After that block, execution continues into the shared `for n, p in self.model.named_parameters():` loop, which references `state_dict`, but `state_dict` is only defined in the `else` (merged) branch.

This re-introduces the earlier `UnboundLocalError`/`NameError` for `state_dict` in the partials path and will blow up whenever you call `compute_peer_val_norms` with a list (i.e., in distributed gather mode). It also means you're recomputing norms twice for merged results if the list branch wasn't intended to fall through. Recommend returning early for the partials case and scoping the merged-result loop to the `else` branch:

```diff
 def compute_peer_val_norms(
     self,
     gather_result: SimpleNamespace | list[SimpleNamespace],
 ) -> dict[str, torch.Tensor]:
 @@
     # Handle both merged results and partials
     if isinstance(gather_result, list):
 @@
-        norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
-        clip_norm_dict[vals_key] = torch.median(norms)
-    else:
-        # Merged result: original logic
-        state_dict = gather_result.state_dict
-        if not state_dict:
-            raise ValueError("Must have gather_result.state_dict to compute norms")
-
-    for n, p in self.model.named_parameters():
-        vals_key = n + "vals"
-        quant_key = n + "quant_params"
-
-        vals = getattr(state_dict, vals_key, None)
-        quant_params = getattr(state_dict, quant_key, None)
-
-        if vals is None:
-            continue
-
-        vals_f32 = self.compressor.maybe_dequantize_values(
-            vals, quant_params, p.device
-        )
-
-        norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
-        clip_norm_dict[vals_key] = torch.median(norms)
-
-    return clip_norm_dict
+        norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
+        clip_norm_dict[vals_key] = torch.median(norms)
+        # Partials path fully handled
+        return clip_norm_dict
+    else:
+        # Merged result: original logic
+        state_dict = gather_result.state_dict
+        if not state_dict:
+            raise ValueError("Must have gather_result.state_dict to compute norms")
+
+        for n, p in self.model.named_parameters():
+            vals_key = n + "vals"
+            quant_key = n + "quant_params"
+
+            vals = getattr(state_dict, vals_key, None)
+            quant_params = getattr(state_dict, quant_key, None)
+
+            if vals is None:
+                continue
+
+            vals_f32 = self.compressor.maybe_dequantize_values(
+                vals, quant_params, p.device
+            )
+
+            norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
+            clip_norm_dict[vals_key] = torch.median(norms)
+
+        return clip_norm_dict
```

You may also want to re-introduce the additional defensive alignment logic for `vals`/`quant_params` that was suggested previously if you expect inconsistent partials, but the scoping/early-return above is the immediate correctness fix.

```bash
#!/bin/bash
# Quick check: show full compute_peer_val_norms and verify state_dict is only used in merged path
rg -n "def compute_peer_val_norms" -C15 neurons/validator.py
```

src/tplr/neurons.py (1)
593-609: Memory budget violation flag still inverted in logging/metrics. The new use of `_resolve_cuda_device(device)` and `torch.cuda.memory_allocated(cuda_device)` is correct and fixes the earlier string-device usage. However, the `budget_exceeded` flag at Line [854] is still inverted: `budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1`. Given the metric name `"outer_step/memory_budget_violation"`, this should be `1` when memory exceeds the budget and `0` otherwise; currently it reports `0` on violation and `1` when within budget, which is the opposite of what the name implies and will mislead monitoring. A minimal fix:

```diff
- budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+ budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0
```

Also applies to: 629-629, 848-877
🧹 Nitpick comments (10)
neurons/validator.py (1)
1399-1417: Distributed peers broadcast keeps ranks aligned, but watch payload size. Broadcasting
`[self.comms.peers, self.comms.reserve_peers]` from master via `dist.broadcast_object_list` is a clean way to ensure all ranks share an identical peer set before distributed gather. On large subnets this object payload could become sizeable; if that ever shows up in profiling, consider broadcasting a compact, fixed-shape representation (e.g., tensors of UIDs) instead.

src/tplr/comms.py (4)
483-582: s3_object_exists: consider purging bad clients and broadening error code coverage. The structured handling of
`ClientError` vs transient issues looks solid and the conservative defaults are appropriate. Two refinements would harden this further:
- On
`ConnectionClosedError` and similar low-level failures, consider calling `_purge_s3_client(bucket)` so you don't keep reusing a broken client instance, mirroring the behavior in `s3_get_object` and `s3_put_object`.
codefor “not found” and bucket/config cases, you might also include"NotFound"(and possibly other provider-specific aliases) to avoid misclassifying some 404-style responses.Both are incremental reliability tweaks; current behavior is functionally correct.
-->
640-683: s3_get_object: HEAD logging, size-based path, and torch.load offload look good.
- HEAD-time logging around
time_min/time_maxand the 500 MB threshold for switching to the multipart downloader are reasonable and align with the large‑model use case.- Offloading
torch.loadto a thread viaasyncio.to_threadis the right move to avoid blocking the event loop during CPU decompression.One thing to double‑check is that the deployed PyTorch version supports
weights_only=Trueintorch.load; if not, this will raise at runtime and should be guarded or dropped for older versions.-->
Also applies to: 699-723
1563-1569: Local torch.load offload is correct; just mindweights_onlycompatibility.Offloading the local
torch.loadto a thread matches the remote path and avoids event‑loop blocking. As withs3_get_object, ensure the runtime PyTorch version supportsweights_only=Truewhen loading these saved dicts, or gate that argument based on version.-->
2178-2211: Distributed merge helpers are structurally sound; consider minor device-policy tweaks.Overall:
_move_gather_result_to_devicecorrectly walksstate_dictand moves tensors (and nested containers) without touching metadata.merge_gather_resultsproperly filtersNonepartials, aggregates metrics, concatenates per‑param lists, and recomputes a sensible global success_rate.Two small considerations:
- The default
target_devicepath will often move everything back to CUDA (viacurrent_device/device 0). Ingather_distributedyou already run the final device move explicitly; if you’d rather keep the merge purely on CPU for low‑memory environments, you could passtarget_device="cpu"from the caller instead of relying on the default.- The duplicated
move_to_devicehelper (here and in_move_gather_result_to_device) could be factored once to reduce drift, though this is cosmetic.Functionally, this looks correct.
-->Also applies to: 2214-2337
neurons/miner.py (2)
548-556: Pre-training cleanup is effective; consider guarding CUDA calls. Flushing the CUDA cache, synchronizing, and zeroing gradients before
`inner_steps` helps avoid OOMs from prior windows. To make this robust to CPU-only or `--device cpu` runs, you may want to wrap `torch.cuda.empty_cache()` and `torch.cuda.synchronize()` in `if torch.cuda.is_available():` or check `self.device.type == "cuda"` here (and similarly in other cleanup sites).
1103-1134: cleanup_window is aggressive and effective; guard CUDA for broader compatibility. The between-window cleanup:
- Zeroes gradients on model and inner optimizer.
- Drops error feedback for non‑owned params.
- Clears autocast cache, empties CUDA cache multiple times, synchronizes, and runs GC.
- Reuses
`check_memory_threshold` to ensure post-cleanup headroom. Functionally this is fine for GPU miners. For robustness in environments where CUDA may not be available or
`--device cpu` is used, you may want to wrap the `torch.cuda.*` calls in availability checks (similar to the earlier suggestion) to avoid runtime errors.
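A guarded cleanup helper along the lines both miner comments suggest (a hypothetical sketch, not the miner's actual code; it returns whether a CUDA flush actually ran):

```python
def flush_accelerator(device_type="cuda"):
    """Guarded cleanup: only touch torch.cuda when it is actually usable."""
    import gc

    gc.collect()
    try:
        import torch
    except ImportError:
        return False  # no torch at all: nothing accelerator-side to flush
    if device_type == "cuda" and torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        return True
    return False
```

The point is that every `torch.cuda.*` call sits behind `torch.cuda.is_available()`, so the same cleanup path is safe on CPU-only hosts.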
371-419: Batch-size helper currently unused; consider wiring it or removing it
`_calculate_optimal_batch_size` is a reasonable heuristic (sampling a few partials, applying a safety factor, clamping) but it isn't invoked from `_outer_step_incremental` or elsewhere in this file. If incremental per-batch processing is no longer planned in this path, consider either:
- Calling this helper from
`_outer_step_incremental` when you reintroduce true batching, or
446-591: Incremental path now “merge-then-apply once”; docstring still suggests true batching
_outer_step_incrementalnow deterministically sorts partials by rank, merges all valid partials into a singlemerged_result, and then applies it once via_outer_step_single, with good cleanup and optional memory logging. That matches the catch‑up semantics, but:
- The docstring and argument docs still talk about “processing partials by merging them first, then applying once” (accurate) but also mention
batch_size_overrideand budget‑based batching, which are no longer used._calculate_optimal_batch_size/_estimate_partial_memory_mbare not integrated, somemory_budget_mbhere is effectively “for telemetry only”.To avoid confusion, consider either wiring in true batch‑wise merging using the helpers, or updating the docstring and parameter descriptions to match the current “merge all once” behavior and noting that the budget is only logged, not enforced.
1554-1569: Step-size denominator now stabilized by learning_rate; consider min_step_size param. Using
`learning_rate` as the baseline step size and clamping `param_avg_change` with `torch.maximum` is a sensible way to avoid huge step ratios when the EMA is tiny. Note that `min_step_size` in the function signature is now effectively unused; if you still want a global floor, you could clamp the denominator with it, or else remove/rename that parameter to avoid confusion.
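The clamped-denominator idea in miniature (plain floats instead of tensors; `max` plays the role of `torch.maximum`; names are illustrative):

```python
def step_ratio(param_change, avg_change, learning_rate):
    # Clamp the denominator so a near-zero EMA of parameter change cannot
    # inflate the ratio; learning_rate acts as the baseline step size.
    denom = max(avg_change, learning_rate)
    return param_change / denom
```

When the model has converged and `avg_change` collapses toward zero, the learning rate floor keeps the ratio bounded instead of exploding.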
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- hparams/hparams.json (1 hunk)
- neurons/base_node.py (1 hunk)
- neurons/miner.py (8 hunks)
- neurons/trainer.py (4 hunks)
- neurons/validator.py (14 hunks)
- src/tplr/comms.py (20 hunks)
- src/tplr/compress.py (1 hunk)
- src/tplr/neurons.py (13 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- hparams/hparams.json
- neurons/trainer.py
🧰 Additional context used
🧬 Code graph analysis (4)
neurons/miner.py (3)
- src/tplr/chain.py (1): get_commitments (264-322)
- src/tplr/distributed.py (1): is_distributed (75-77)
- src/tplr/comms.py (1): gather_with_reserve (2556-2836)
src/tplr/neurons.py (2)
- src/tplr/compress.py (4): compress (373-377), compress (379-383), compress (386-423), unpack_12bit_indices (102-142)
- src/tplr/comms.py (1): merge_gather_results (2214-2337)
src/tplr/comms.py (3)
- src/tplr/schemas.py (1): Bucket (24-46)
- src/tplr/logging.py (2): log_with_context (207-223), log_with_context (290-309)
- src/tplr/distributed.py (2): is_distributed (75-77), all_gather_object (307-325)
neurons/validator.py (4)
- src/tplr/distributed.py (2): is_distributed (75-77), broadcast (273-282)
- src/tplr/comms.py (3): gather_with_reserve (2556-2836), merge_gather_results (2214-2337), get_with_retry (1625-1742)
- src/tplr/neurons.py (1): check_uid_index_overlap (1610-1815)
- src/tplr/compress.py (1): maybe_dequantize_values (647-697)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.12)
- GitHub Check: test (3.11)
🔇 Additional comments (20)
neurons/base_node.py (1)
61-62: Verification confirms the code change is correct and the comment is accurate.
The `Trainer` class exists at `neurons/trainer.py:41` with an `__init__` method at line 47. Both `Miner` (`neurons/miner.py:64`) and `Validator` (`neurons/validator.py:90`) inherit from `(BaseNode, Trainer)`, making the `super().__init__()` call necessary for proper MRO chain initialization. The comment accurately reflects this multiple-inheritance pattern.
src/tplr/compress.py (1)
637-645: Device alignment fix in `_dequantize_values` looks correct
Moving both `lookup` and `shift` to `val.device` eliminates cross-device indexing issues when `val` is on GPU. The guards keep it safe for any non-tensor qparam shapes. No further changes needed here.
neurons/validator.py (4)
90-141: Helper methods correctly normalize gather results across partial and merged modes
The new static helpers (`_get_uids_from_gather_result`, `_get_skipped_uids_from_gather_result`, `_get_success_rate_from_gather_result`) neatly encapsulate list-vs-merged handling and dedupe UIDs, so the success rate and metrics stay robust even when partials include retries or fallbacks. This centralization reduces duplication and makes later logging/metric code much clearer.
3281-3295: Switch to `get_with_retry` for miner debug dicts improves robustness
Using `self.comms.get_with_retry` with a bounded timeout and explicit handling of `None`/`success=False` gives `evaluate_miner_sync` a clean failure path and avoids hanging on slow or missing debug uploads. Returning a structured failure result with `sync_score=0.0` integrates smoothly with `slash_for_poor_sync`.
1679-1722: All verification checks passed; code changes are coherent and safe to approve
The verification confirms all reviewer assumptions:
- ✓ The `outer_step` signature (src/tplr/neurons.py:250-267) accepts `memory_budget_mb` and `batch_size_override` as optional parameters
- ✓ The call site (neurons/validator.py:2542-2560) correctly passes both via `getattr(self.hparams, ...)` for optional feature support
- ✓ The helper methods `_get_uids_from_gather_result`, `_get_skipped_uids_from_gather_result`, and `_get_success_rate_from_gather_result` are used consistently across the referenced lines (1745-1763, 2556-2561)
- ✓ The design properly handles both partial and merged `gather_result` inputs with minimal memory overhead
1539-1565: Design is sound; heterogeneous `gather_result` states are safe due to rank-aware implementation
The distributed/sequential gather flow correctly gates `gather_with_reserve` execution. The edge case you identified (a non-master rank with `gather_result=None` while the master succeeds) is handled safely by design:
- `_outer_step_single` (called by all ranks) is rank-aware: only the master reads `gather_result` (lines 634-640)
- In distributed mode, non-master ranks skip reading `gather_result` entirely (`on_src=False` when `ddp=True` and not master)
- The master broadcasts per-parameter `has_update` flags (via `_bcast_flag`); non-master ranks follow the broadcast regardless of local state
- Therefore, a non-master rank can safely have `gather_result=None`; it is never dereferenced

No changes needed; the current behavior is correct. The design relies on explicit rank-awareness in `_outer_step_single` rather than requiring all-rank synchronization before the call.
src/tplr/comms.py (4)
1668-1673: Early exit on repeated NOT_FOUND plus extended grace window is sensible.
The increased `time_max_grace_period` and the `consecutive_not_found` cutoff prevent spinning for the full timeout when the object almost certainly doesn't exist, while still giving some leeway for late uploads and clock skew. The logic correctly resets the counter on non-NOT_FOUND statuses, so transient network errors still get full retry behavior.
Also applies to: 1727-1739
1846-1875: Parallel existence checks and miner-vs-network failure attribution look correct.
The new block that:
- collects failed downloads (`Exception`/`None`),
- runs `s3_object_exists` for them in parallel, and
- distinguishes the cases: no bucket means "not registered", object exists means a network/timing issue (no slashing), object missing means the miner didn't upload (skippable/slashable),

is well structured and uses the cached existence results correctly in both the `Exception` and `None` paths. This should remove the previous N×5 s serial penalty and give clearer attribution without changing success semantics.
Also applies to: 1880-1939
2371-2476: gather_distributed: partitioning, CPU all_gather, and device moves behave as intended.
- UID partitioning via a cached round-robin over `sorted(uids)` is deterministic and respects `world_size`/`rank`.
- Local gather, then an optional local merge (when not returning partials), then a CPU move before `all_gather_object` is a reasonable sequence for 70B-scale gradients.
- All ranks participate in the barrier/all-gather when `world_size > 1 and dist_helper.is_distributed()`, and non-distributed cases correctly degenerate to a single-element `all_partials` list.
- The two return paths (`return_partials=True` vs merged result) are consistent with the declared return type, and partials are normalized (`uids` empty → `None`) before downstream use.

I don't see logic errors here; the remaining choices (e.g., doing the local merge on GPU then moving to CPU) are trade-offs rather than bugs.
Also applies to: 2479-2514, 2516-2551
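The key property of the round-robin partitioning described above is that every rank can compute the whole assignment independently and arrive at the same answer. A minimal sketch of the idea (function name and shape are illustrative, not the repository's actual API):

```python
def partition_uids_across_ranks(uids, world_size):
    """Deterministic round-robin partitioning sketch.

    Every rank sorts the same UID list and takes every world_size-th
    entry, so all ranks agree on the assignment with no communication.
    """
    ordered = sorted(uids)
    return [ordered[rank::world_size] for rank in range(world_size)]

parts = partition_uids_across_ranks([7, 3, 12, 5, 9, 1], world_size=2)
print(parts)  # [[1, 5, 9], [3, 7, 12]]
```

In the real code each rank would only keep `parts[rank]` as its `my_uids`, and the result can be cached per window since it is a pure function of the sorted UID list and `world_size`.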
2556-2700: gather_with_reserve: distributed integration and type consistency look good.
This refactor:
- Decides `use_distributed` exactly once based on `hparams.distributed_gather`, `world_size`, and `dist_helper.is_distributed()`.
- Uses `gather_distributed` when enabled, with a clear "all ranks fall back together" path to sequential `gather()` if distributed returns no gradients or raises.
- Preserves the contract that `return_partials=True` yields `list[SimpleNamespace | None]` and otherwise a merged `SimpleNamespace`, but returns `None` when there are truly no successful gradients.
- Handles reserve retries for both partial and merged primaries, avoiding duplicate UIDs and aggregating metrics correctly.
- Recomputes `success_rate` relative to `gather_uids` and returns `None` if even after reserves there are zero successes, which upstream callers already check.

Semantics are consistent across distributed and non-distributed modes; no correctness issues spotted.
Also applies to: 2701-2836
neurons/miner.py (5)
32-35: CUDA allocator env var set correctly before torch import.
Setting `PYTORCH_CUDA_ALLOC_CONF` before importing torch, and only when it's absent, is the right pattern; it gives a sane default (`expandable_segments:True`) without overriding explicit user configuration.
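The "set only when absent" pattern boils down to `os.environ.setdefault`, which must run before `import torch` for the allocator to pick it up. A minimal sketch:

```python
import os

# Provide a default allocator config without clobbering a user's
# explicit setting; torch reads this env var at import time, so this
# must execute before `import torch`.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

print(os.environ["PYTORCH_CUDA_ALLOC_CONF"])
```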
165-177: Distributed init and device wiring look coherent.
Using `dist_helper.init_process_group(...)` once in `__init__` and then sourcing `rank`, `world_size`, `local_rank`, `is_master`, and `device` from `dist_helper` keeps all distributed state in one place. Falling back to `self.config.device` and wrapping it in `torch.device` is standard.
Just ensure `dist_helper.init_process_group` handles single-process and CPU-only cases gracefully with the chosen backend.
471-476: Peer refresh and broadcast across ranks are correct and necessary for distributed gather.
- Refreshing `self.comms.commitments` on master after `update_peers` ensures bucket metadata is current before a gather.
- Broadcasting `self.comms.peers` from rank 0 to all ranks via `dist.broadcast_object_list` under `world_size > 1 and dist_helper.is_distributed()` correctly synchronizes peer selection, and non-master logging confirms reception.

This keeps all ranks aligned on the same peer set for subsequent distributed gathers.
Confirm that `torch.distributed` is initialized in exactly the same cases where `dist_helper.is_distributed()` returns True, so the broadcast can't be called before init.
Also applies to: 478-486
927-949: Debug successful/skipped peers: logic is correct given a live gather_result.
Assuming the lifecycle fix above is applied, the debug computation:
- Aggregates `uids` and `skipped_uids` across partials (list case).
- Falls back to the merged `SimpleNamespace`'s `uids`/`skipped_uids` otherwise.
- Derives `successful_peers` as `peers − skipped` and logs both.

This correctly reflects "who actually contributed" independent of distributed vs sequential gather. No changes needed beyond ensuring `gather_result` still exists here.
985-1007: Gather success-rate and skipped-peers metrics handle partial vs merged results correctly.
The metrics block:
- Treats `gather_result is None` as 0% success, with an empty skipped list.
- For partial lists, aggregates successes and skipped across all partials and computes the success rate as `total_uids / (total_uids + total_skipped) * 100`.
- For merged results, reuses `gather_result.success_rate * 100` and its `skipped_uids`.

Logging `skipped_peers` to Influx as `json.dumps(skipped_uids_list)` gives a clear diagnostics trail. As above, this relies on `gather_result` not having been deleted earlier in the loop.
Also applies to: 1060-1072
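The three cases above can be sketched as one small function. This mirrors the reviewed semantics but is an illustration, not the miner's actual code; the function name is assumed.

```python
from types import SimpleNamespace

def gather_success_rate(gather_result):
    """Sketch of the None / partials-list / merged-result handling.

    Returns (success_rate_percent, skipped_uids).
    """
    if gather_result is None:
        return 0.0, []
    if isinstance(gather_result, list):
        total_uids, skipped = 0, []
        for partial in gather_result:
            if partial is None:  # rank that gathered zero peers
                continue
            total_uids += len(partial.uids)
            skipped.extend(partial.skipped_uids)
        denom = total_uids + len(skipped)
        rate = total_uids * 100.0 / denom if denom else 0.0
        return rate, skipped
    return gather_result.success_rate * 100, list(gather_result.skipped_uids)

partials = [
    SimpleNamespace(uids=[1, 2, 3], skipped_uids=[9]),
    None,
    SimpleNamespace(uids=[4], skipped_uids=[]),
]
print(gather_success_rate(partials))  # (80.0, [9])
```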
src/tplr/neurons.py (5)
48-69: CUDA device resolver avoids string-device CUDA pitfalls
The helper cleanly normalizes incoming `device` values and guarantees that CUDA APIs only see a `torch.device('cuda:X')` or `None`, which fixes prior string-device issues and degrades safely to `None` on CPU-only or invalid inputs. No changes needed here.
249-321: outer_step dispatch for merged vs partial results is straightforward
Routing `list[SimpleNamespace]` to `_outer_step_incremental` and everything else to `_outer_step_single` keeps the public API stable while making the new distributed path opt-in and explicit. The keyword-only arguments and propagation of `memory_budget_mb`/`batch_size_override` look consistent.
324-369: Recursive partial memory estimator correctly handles nested namespaces/containers
The new `_estimate_partial_memory_mb` walks tensors through `SimpleNamespace`, dicts, lists, and tuples, and applies a safety overhead before converting to MB, which aligns with the earlier review feedback about under-estimating memory for nested payloads. This looks correct and robust.
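The recursive walk described above can be sketched without torch by duck-typing on the `element_size()`/`nelement()` tensor interface. This is an illustrative reconstruction under stated assumptions (a 1.3× overhead factor, a stand-in tensor class), not the file's actual implementation.

```python
from types import SimpleNamespace

class FakeTensor:
    """Stand-in for torch.Tensor so the walk is testable without torch."""
    def __init__(self, numel, itemsize=4):
        self._numel, self._itemsize = numel, itemsize
    def element_size(self):
        return self._itemsize
    def nelement(self):
        return self._numel

def estimate_partial_memory_mb(obj, overhead=1.3):
    """Recursively sum element_size() * nelement() over tensor-like
    leaves inside namespaces/dicts/lists/tuples, then apply overhead."""
    def walk(o):
        if hasattr(o, "element_size") and hasattr(o, "nelement"):
            return o.element_size() * o.nelement()
        if isinstance(o, SimpleNamespace):
            return walk(vars(o))
        if isinstance(o, dict):
            return sum(walk(v) for v in o.values())
        if isinstance(o, (list, tuple)):
            return sum(walk(v) for v in o)
        return 0  # non-tensor leaves contribute nothing
    return walk(obj) * overhead / (1024 * 1024)

# One 1M-element float32 tensor nested two levels deep: 4 MiB * 1.3
partial = SimpleNamespace(state_dict={"w.vals": [FakeTensor(1024 * 1024)]})
print(round(estimate_partial_memory_mb(partial), 2))  # 5.2
```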
421-444: _merge_partial_batch cleanly delegates to Comms.merge_gather_results
Using a thin wrapper around `Comms.merge_gather_results` with a local import avoids cycles and keeps the merge semantics centralized in `comms.py`. Behavior for empty batches (`None`) is clear. No issues here.
1612-1651: Overlap checker's partial-support logic and UID mapping look correct
The updated `check_uid_index_overlap` correctly:
- Accepts either a merged `SimpleNamespace` or a `list[SimpleNamespace]` of partials.
- Builds a `uid_to_partial` map so each UID is resolved in O(1) instead of scanning all partials (fixing the prior O(n×m) inner-loop issue).
- Tracks `valid_uids_for_param` so pair keys are stored in terms of the original UID indices, ensuring summary stats still index into the global `uids` list consistently.

The per-param filtering (`len(unpacked_indices) < 2`) and the subsequent pair-accumulation logic all line up with the existing reporting. This looks good.
Also applies to: 1685-1714, 1730-1735, 1743-1759
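The O(1) lookup mentioned above is just a one-pass index build over the partials. A minimal sketch (names are illustrative, not the file's actual helpers):

```python
from types import SimpleNamespace

def build_uid_to_partial(partials):
    """One pass over the partials builds a uid -> partial map, replacing
    a nested scan of every partial for every UID (O(n*m) -> O(n+m))."""
    uid_to_partial = {}
    for partial in partials:
        if partial is None:
            continue
        for uid in partial.uids:
            # First partial wins if a UID somehow appears twice.
            uid_to_partial.setdefault(uid, partial)
    return uid_to_partial

p0 = SimpleNamespace(uids=[1, 4])
p1 = SimpleNamespace(uids=[2])
mapping = build_uid_to_partial([p0, None, p1])
print(sorted(mapping))  # [1, 2, 4]
```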
ca50d92 to 8488c13
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
neurons/trainer.py (1)
872-929: Fix incorrect double gradient clipping with GradScaler
You're clipping gradients twice: once before `scaler.unscale_()` at line 872, and again after, at lines 887-889, when capturing `total_grad_norm`. PyTorch's official AMP pattern is `unscale_()` → `clip_grad_norm_()` → `step()` → `update()`. Clipping before unscale operates on scaled gradients, violating the intended gradient-magnitude constraint.
Remove the pre-unscale clip at line 872. The post-unscale clip at lines 887-889 is correct and properly captures the norm. The `null_round` branch setting `total_grad_norm = 0.0` at line 912 remains appropriate: spin-up steps log 0.0, training steps log the actual clipped norm.
♻️ Duplicate comments (4)
neurons/validator.py (2)
1459-1486: Timestamp query + broadcast fixes prior time-window inconsistency
Only querying `query_block_timestamp` on master and broadcasting `ts_value` to all ranks ensures a consistent `(time_min, time_max)` on every rank, including the fallback to `time.time()` when the query fails. This addresses the earlier inconsistency concerns while keeping the device handling reasonable.
Also applies to: 1488-1499
3875-3971: `compute_peer_val_norms` has a partials-path bug (UnboundLocalError) and fragile quant_params handling
Two issues here:

1. UnboundLocalError when `gather_result` is a list (partials)
- In the `if isinstance(gather_result, list):` branch you never define `state_dict`.
- Execution then falls through to the merged-result loop below, which unconditionally references `state_dict` (initialized only in the `else` branch).
- In distributed gather mode (where `gather_result` is a list), this will raise at the first parameter.

2. Quantization metadata can be misaligned across partials
- For partials, `all_vals` is extended for every `vals` entry, but `all_quant_params` is only extended when `quant_params` is not `None`.
- If some partials have `vals` but missing or shorter `quant_params`, you end up with `len(all_vals) != len(all_quant_params)`.
- Passing those into `self.compressor.maybe_dequantize_values(all_vals, all_quant_params, ...)` can raise `IndexError` or silently compute norms on raw `uint8` tensors instead of dequantized values.

Both issues can surface exactly in the new distributed/partials flow, making this a blocker for that path. You likely want the partials branch to:
- fully compute norms and `return clip_norm_dict` early, and
- only feed dequantization with aligned `(vals, qparams)` pairs, skipping any entries that are quantized but lack valid qparams.

Here is a concrete refactor that addresses both problems while keeping semantics close to what you intend:
def compute_peer_val_norms( - self, - gather_result: SimpleNamespace | list[SimpleNamespace], - ) -> dict[str, torch.Tensor]: - """ - Compute median L2 norms across all peers for each parameter. - - Args: - gather_result: Either a merged SimpleNamespace or list of partials - - Returns: - Dictionary mapping parameter names to median norms - """ - clip_norm_dict = {} - - # Handle both merged results and partials - if isinstance(gather_result, list): - # Partials: collect vals from all partials for each parameter - # Use model's first parameter device as target (typically cuda:0) - target_device = next(self.model.parameters()).device - - # Helper to recursively move all tensors to target device - def move_to_device_recursive(obj): - if torch.is_tensor(obj): - return obj.to(target_device) - elif isinstance(obj, dict): - return {k: move_to_device_recursive(v) for k, v in obj.items()} - elif isinstance(obj, (list, tuple)): - return type(obj)(move_to_device_recursive(item) for item in obj) - else: - return obj - - for n, p in self.model.named_parameters(): - vals_key = n + "vals" - quant_key = n + "quant_params" - - # Collect vals and quant_params from all partials - all_vals = [] - all_quant_params = [] - - for partial in gather_result: - # Skip None partials (ranks that gathered 0 peers) - if partial is None: - continue - - state_dict = partial.state_dict - if not state_dict: - continue - - vals = getattr(state_dict, vals_key, None) - quant_params = getattr(state_dict, quant_key, None) - - if vals is not None: - # vals is a list of tensors, one per peer in this partial - # Move all tensors (including nested) to target device - for v in vals: - all_vals.append(move_to_device_recursive(v)) - # quant_params is also a list if present - if quant_params is not None: - for qp in quant_params: - all_quant_params.append(move_to_device_recursive(qp)) - - if not all_vals: - continue - - # Dequantize all collected values (now all on same device) - vals_f32 = 
self.compressor.maybe_dequantize_values( - all_vals, - all_quant_params if all_quant_params else None, - target_device, - ) - - # Compute median norm across all peers from all partials - norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device) - clip_norm_dict[vals_key] = torch.median(norms) - else: - # Merged result: original logic - state_dict = gather_result.state_dict - if not state_dict: - raise ValueError("Must have gather_result.state_dict to compute norms") + self, + gather_result: SimpleNamespace | list[SimpleNamespace | None], + ) -> dict[str, torch.Tensor]: + """ + Compute median L2 norms across all peers for each parameter. + + Args: + gather_result: Either a merged SimpleNamespace or list of partials. + + Returns: + Dictionary mapping parameter names to median norms. + """ + clip_norm_dict: dict[str, torch.Tensor] = {} + + # Handle both merged results and partials + if isinstance(gather_result, list): + # Partials: collect vals from all partials for each parameter. 
+ # Use model's first parameter device as target (typically cuda:0) + target_device = next(self.model.parameters()).device + + def move_to_device_recursive(obj): + if torch.is_tensor(obj): + return obj.to(target_device) + if isinstance(obj, dict): + return {k: move_to_device_recursive(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return type(obj)(move_to_device_recursive(v) for v in obj) + return obj + + for n, p in self.model.named_parameters(): + vals_key = n + "vals" + quant_key = n + "quant_params" + + # Separate unquantized and quantized entries to keep qparams aligned + all_unquantized_vals: list[torch.Tensor] = [] + all_quantized_vals: list[torch.Tensor] = [] + all_quantized_qparams: list[QuantParamsT] = [] + + for partial in gather_result: + if partial is None: + continue + state_dict = getattr(partial, "state_dict", None) + if not state_dict: + continue + + vals = getattr(state_dict, vals_key, None) + quant_params = getattr(state_dict, quant_key, None) + if vals is None: + continue + + # vals is expected to be a list-like (one per peer in this partial) + for idx, v in enumerate(vals): + v_dev = move_to_device_recursive(v) + if torch.is_tensor(v_dev) and v_dev.dtype == torch.uint8: + # Quantized entry – require matching qparams + if ( + quant_params is None + or idx >= len(quant_params) + or quant_params[idx] is None + ): + tplr.logger.warning( + f"Missing quant_params for {vals_key}[{idx}] in a partial; " + "skipping for norm computation." 
+ ) + continue + qp = move_to_device_recursive(quant_params[idx]) + all_quantized_vals.append(v_dev) + all_quantized_qparams.append(qp) + else: + # Treat non‑uint8 as already in float space + all_unquantized_vals.append(v_dev) + + if not all_unquantized_vals and not all_quantized_vals: + continue + + norms_list: list[torch.Tensor] = [] + + # Dequantize quantized entries + if all_quantized_vals: + vals_f32_q = self.compressor.maybe_dequantize_values( + all_quantized_vals, + all_quantized_qparams, + target_device, + ) + norms_list.extend(torch.norm(v, p=2) for v in vals_f32_q) + + # Unquantized entries: just compute norms + if all_unquantized_vals: + norms_list.extend( + torch.norm(v.to(target_device), p=2) + for v in all_unquantized_vals + ) + + if norms_list: + norms = torch.stack(norms_list).to(p.device) + clip_norm_dict[vals_key] = torch.median(norms) + + # Partials handled entirely above + return clip_norm_dict + + # Merged result: original logic + state_dict = gather_result.state_dict + if not state_dict: + raise ValueError("Must have gather_result.state_dict to compute norms") @@ - for n, p in self.model.named_parameters(): + for n, p in self.model.named_parameters(): vals_key = n + "vals" quant_key = n + "quant_params" @@ - vals_f32 = self.compressor.maybe_dequantize_values( - vals, quant_params, p.device - ) + vals_f32 = self.compressor.maybe_dequantize_values( + vals, quant_params, p.device + ) @@ - return clip_norm_dict + return clip_norm_dictThis keeps the merged‑result behavior unchanged, makes the partials path self‑contained (no
state_dictleakage), and defensively skips any quantized entries that don’t have usable qparams instead of crashing.src/tplr/comms.py (1)
2562-2836: Sequential fallback after distributed failure remains per-rank; confirm you're okay with potential cross-rank divergence.
The new `gather_with_reserve` logic integrates distributed vs sequential paths and fixes the previous type-contract issue when `return_partials=True`:
- When `use_distributed` is True, all ranks participate in `gather_distributed`, and you correctly normalize `primary` to a list when `return_partials=True`.
- If distributed gather returns no gradients or throws, all ranks fall back to sequential `self.gather(...)`, again wrapping the result in a single-element list when `return_partials=True`.
- Logging and the final `success_rate` computation now treat both merged and partials cases consistently.

However, the sequential fallback path still has each rank call `gather()` independently against the full `gather_uids` set, and the results are not synchronized afterward. In steady state, all ranks will most likely see the same gradients (same S3 objects), but under asymmetric network errors or transient bucket issues, different ranks could end up with slightly different `primary` contents and success rates.
Previous reviews already raised the "rank consistency after distributed fallback" concern; this remains conceptually the same, just now with the list-wrapping fixed. Unless downstream code is explicitly robust to cross-rank differences here, you may want to:
- Either broadcast one rank's sequential fallback result (merged or partials) to all ranks after the fallback, or
- Explicitly document that fallback windows are allowed to be slightly inconsistent across ranks.
src/tplr/neurons.py (1)
593-609: Memory budget metric still uses inverted flag logic (budget_exceeded).
Within `_outer_step_single` you correctly:
- Resolve a CUDA device for memory accounting via `_resolve_cuda_device`.
- Compute `final_memory_mb` and `utilization_pct` relative to `memory_budget_mb`.

But `budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1` means:
- `budget_exceeded == 0` when the budget is exceeded.
- `budget_exceeded == 1` when you are within budget.

For a metric named `memory_budget_violation`, this is inverted and will be quite confusing for monitoring and WandB dashboards.
This was called out in earlier reviews and the issue is still present here; you probably want `budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0` so that `1` consistently means "violation".
Also applies to: 629-630, 848-877
🧹 Nitpick comments (9)
neurons/miner.py (1)
959-982: Metrics success-rate math is sound; consider reusing the upstream rate and deduping skipped peers
The recomputation of `gather_success_rate` for partials (using `total_uids / (total_uids + total_skipped) * 100`) is reasonable and matches the intended semantics, and the merged case reuses `gather_result.success_rate`. Two small polish ideas:
- For partials, you could simply reuse the aggregate success rate already attached to each partial by `gather_with_reserve` instead of recomputing it here.
- Before logging `skipped_uids_list` to Influx, consider deduping with `skipped_uids_list = sorted(set(skipped_uids_list))` to avoid duplicates if any partials share `skipped_uids`.

Both are optional, but would slightly simplify the code and make the `skipped_peers` metrics cleaner.
Also applies to: 1046-1046
neurons/validator.py (1)
1539-1565: Consider adding all-rank consensus after `gather_with_reserve` for strict "all-or-nothing" behavior
The concern is valid: in distributed mode, if `gather_with_reserve` fails on a non-master rank but succeeds on the master, the current one-way broadcast from master will reset `skip_window` from the non-master's local `True` to the master's `False`. The non-master then continues with `gather_result=None` and `skip_window=False`, creating an inconsistent state.
To enforce strict consistency, add an `all_ok()` consensus step after the gather attempt:

```python
# After gather_with_reserve try/except block (around line 1595)
gather_ok = dist_helper.all_ok(
    gather_result is not None, self.device, tag="gather_with_reserve"
)
if use_distributed_gather and not gather_ok:
    skip_window = True
```

This ensures that if any rank's gather fails, all ranks skip the window. The `all_ok()` helper already exists at `src/tplr/distributed.py:185` and is used in similar contexts elsewhere (save/eval/apply operations at lines 1981, 2111, 2177, 2209, 2286).
Also applies to: 1569-1596, 1597-1633
src/tplr/comms.py (3)
483-582: S3 existence helper looks solid; consider handling additional "NotFound"-style codes.
The new `s3_object_exists` implementation is conservative and aligns well with the slashing requirements (only returning `False` for clear 404/config errors, and treating network/service issues as transient). One small robustness gap: some S3-compatible backends may emit `"NotFound"` as the error code for missing keys. Right now you only treat `{404, "NoSuchKey"}` as "not found"; adding `"NotFound"` to that set would avoid misclassifying a missing object as a transient network case.
This is minor and could be done later; overall the structure and logging semantics here look good.
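The classification being suggested can be sketched as a small code-to-outcome mapping. The bucket names, code sets, and outcome labels below are illustrative assumptions, not the actual `s3_object_exists` internals; the point is only that `"NotFound"` should land in the same bucket as 404/`"NoSuchKey"`.

```python
def classify_s3_head_error(code):
    """Illustrative sketch: map an S3 HeadObject error code to an outcome.

    Includes the suggested "NotFound" alias alongside 404/"NoSuchKey".
    """
    not_found_codes = {404, "404", "NoSuchKey", "NotFound"}
    if code in not_found_codes:
        return "missing"    # object definitively absent -> safe to treat as not uploaded
    return "transient"      # anything else: retry, don't attribute blame

print(classify_s3_head_error("NotFound"))  # missing
print(classify_s3_head_error("SlowDown"))  # transient
```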
2213-2337: `merge_gather_results` semantics match `gather()` layout; minor duplication with `_move_gather_result_to_device`.
Functionally this does what we need:
- Filters out `None` partials and bails cleanly when nothing usable is present.
- Computes a device for consolidation (`target_device`) with a reasonable CUDA/CPU fallback strategy.
- For each param key, concatenates per-partial lists while recursively moving tensors to `target_device`.
- Aggregates metrics (`time`, `bytes`, `uids`, `skipped_uids`) and computes a global `success_rate`.

Two small follow-ups you might consider:
- You already have `_move_gather_result_to_device` above; the local `move_to_device` here effectively re-implements the same recursion over containers. Reusing the static helper would remove duplication and centralize behavior.
- When `torch.cuda.is_available()` is True but `current_device()` fails, you optimistically try `set_device(0)`. In practice that's fine, but wrapping it in a single helper (or the `_resolve_cuda_device` pattern you use in `neurons.py`) would make behavior more predictable and easier to test.

Functionally this looks correct; the above are just refactors.
2339-2552: Distributed gather flow is coherent; watch memory when merging all partials on GPU.
The overall `gather_distributed` flow is solid:
- Deterministic round-robin partitioning with a per-window cache (`_partition_cache`), so all ranks agree on `my_uids`.
- Local `gather` per rank on its subset, with logging and a local pre-merge (`merge_gather_results([partial_result])`) when `return_partials=False`.
- Explicit move of each rank's partial to CPU before `all_gather_object` to avoid GPU OOM.
- After the all-gather, either:
  - Return a normalized list of partials with `rank` metadata when `return_partials=True`, or
  - Merge all partials into one `merged_result` (on CPU) and then move back to the requested `device`.

Two things to keep in mind:
- In the `return_partials=False` path you locally pre-merge on each rank and then globally merge `all_partials`, which can temporarily double the CPU footprint. That's unavoidable if you want a fully merged namespace, but worth remembering as you scale peers×ranks up.
- In the `return_partials=True` path, downstream code (e.g. `_outer_step_incremental`) currently merges all partials again into a single result on the target device, which largely reverts the memory savings from keeping them as partials. If you intend to reintroduce true incremental/batched application later, you already have the helper scaffolding (`_estimate_partial_memory_mb`, `_calculate_optimal_batch_size`) to do so.

No immediate correctness bug here, but future tuning for very large peer sets might want to avoid "merge-all-partials-twice" patterns.
src/tplr/neurons.py (3)
22-23: CUDA device resolver is reasonable; consider catching broader errors for robustness.
`_resolve_cuda_device` gives you a clean way to:
- Short-circuit when CUDA isn't available.
- Prefer `current_device()` when possible.
- Fall back to `"cuda:0"` when the provided `device` is invalid.

One minor hardening tweak you might consider is broadening the exception handling around `torch.cuda.current_device()` and `torch.device(device)` to catch `Exception` rather than only `RuntimeError`/`TypeError`/`ValueError`, to cover any assertion-style failures from CUDA initialization in some PyTorch builds. Not critical, but it would make this helper more bulletproof.
Also applies to: 48-69
249-321: outer_step dispatch between single vs list results is clean; note that the batching knobs are currently a no-op.
The updated `outer_step` signature and dispatch logic look good:
- Accepts `gather_result` as either a `SimpleNamespace`, a list of partials, or `None`.
- For lists, delegates to `_outer_step_incremental`; otherwise falls back to `_outer_step_single`.
- Propagates `memory_budget_mb` and `batch_size_override` through.

However, in the current implementation:
- `_outer_step_incremental` does not use `batch_size_override`, and
- There is no call to `_calculate_optimal_batch_size` for true incremental/batched processing.

So from a caller's perspective, the incremental path always merges all partials into a single `merged_result` before applying, and the "batching" knobs are effectively ignored (only memory metrics are logged).
If you intend this API to provide real memory-bounded batching, consider either wiring `_calculate_optimal_batch_size` into `_outer_step_incremental` or, for now, documenting that `memory_budget_mb`/`batch_size_override` are advisory/logging only.
324-369: Per-partial memory estimator is correct for current layouts but currently unused.
`_estimate_partial_memory_mb` correctly:
- Recurses through `SimpleNamespace`, dicts, lists, and tuples to find tensors.
- Sums `element_size() * nelement()` and applies a 1.3× overhead factor.
- Logs MB-per-UID when it can.

Right now this helper isn't used anywhere in the file (including `_outer_step_incremental`), which means it adds maintenance cost without delivering behavior. If you don't plan to hook it up to dynamic batch sizing soon, it might be better to either:
- Wire it into `_calculate_optimal_batch_size` and then use that in `_outer_step_incremental`, or
- Drop it entirely until you reintroduce true incremental batching.
neurons/trainer.py (1)
972-1021: Align returned grad/Adam metrics with new per-step `grad_norm` logging

You now log `grad_norm: {total_grad_norm:.4f}` per inner step (line 928), but the summary metrics at the end of `inner_steps` (`adam_metrics`, `global_grad_norm`, `global_weight_norm`, lines 977–1021) are computed after `self.inner_optimizer.zero_grad(set_to_none=True)` has run in the last step, so `p.grad` is typically `None` and all grad-based metrics will be ~0.

The docstring on `compute_adam_metrics` explicitly says it "should be called BEFORE gradients are cleared," which isn't currently true here.

Consider one of:

- Move the metrics computation into the `if step_now:` block, right before `zero_grad`, and cache the last-step metrics to return; or
- Defer `zero_grad` until after metrics are computed for the final step; or
- If you only care about per-step `total_grad_norm`, drop or rename the returned `global_grad_norm` / grad-based Adam metrics to avoid confusion.

This will keep aggregated metrics consistent with the newly logged per-step `grad_norm`.
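The ordering pitfall above can be reproduced without torch (all names here are toy stand-ins for the real optimizer/metrics code, not tplr's API):

```python
class Param:
    """Stand-in for a parameter holding an optional gradient."""

    def __init__(self, grad):
        self.grad = grad


def grad_norm(params):
    # Treat cleared (None) gradients as contributing nothing.
    return sum((p.grad or 0.0) ** 2 for p in params) ** 0.5


def zero_grad(params):
    # Mimics zero_grad(set_to_none=True): grads become None, not 0.
    for p in params:
        p.grad = None


params = [Param(3.0), Param(4.0)]

norm_before = grad_norm(params)  # computed BEFORE clearing
zero_grad(params)
norm_after = grad_norm(params)   # computed AFTER clearing: collapses to 0
```

Any metric read from `p.grad` after the clear sees `None` and silently reports zero, which is exactly the failure mode flagged for the end-of-`inner_steps` summary.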
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- neurons/miner.py (9 hunks)
- neurons/trainer.py (4 hunks)
- neurons/validator.py (14 hunks)
- src/tplr/comms.py (20 hunks)
- src/tplr/compress.py (1 hunk)
- src/tplr/neurons.py (13 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/tplr/compress.py
🧰 Additional context used
🧬 Code graph analysis (4)
neurons/miner.py (3)
- src/tplr/chain.py (1): `get_commitments` (264-322)
- src/tplr/distributed.py (1): `is_distributed` (75-77)
- src/tplr/comms.py (1): `gather_with_reserve` (2556-2836)
src/tplr/neurons.py (2)
- src/tplr/compress.py (6): `compress` (373-377, 379-383, 386-423), `ChunkingTransformer` (145-292), `TopKCompressor` (295-697), `unpack_12bit_indices` (102-142)
- src/tplr/comms.py (1): `merge_gather_results` (2214-2337)
neurons/validator.py (5)
- src/tplr/distributed.py (2): `is_distributed` (75-77), `broadcast` (273-282)
- src/tplr/logging.py (2): `log_with_context` (207-223, 290-309)
- src/tplr/comms.py (3): `gather_with_reserve` (2556-2836), `merge_gather_results` (2214-2337), `get_with_retry` (1625-1742)
- src/tplr/neurons.py (1): `check_uid_index_overlap` (1610-1815)
- src/tplr/compress.py (1): `maybe_dequantize_values` (647-697)
src/tplr/comms.py (3)
- src/tplr/schemas.py (1): `Bucket` (24-46)
- src/tplr/logging.py (2): `log_with_context` (207-223, 290-309)
- src/tplr/distributed.py (2): `is_distributed` (75-77), `all_gather_object` (307-325)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.12)
- GitHub Check: test (3.11)
🔇 Additional comments (22)
neurons/miner.py (6)
31-35: CUDA allocator pre-config and dist import look solid

Setting `PYTORCH_CUDA_ALLOC_CONF` before importing torch and pulling in `torch.distributed as dist` for the later `broadcast_object_list` is correct and low-risk; no changes needed here.

Also applies to: 42-42
548-556: Pre-inner-step GPU cleanup is a good guardrail against OOM

The aggressive `torch.cuda.empty_cache()` / `synchronize()` plus `model.zero_grad(set_to_none=True)` before `inner_steps` is a sensible addition to shed residue from the previous outer step and reduce OOM risk; the conditional on `self.model` also avoids NPEs.
734-755: Distributed vs sequential gather selection is consistent; double-check sequential semantics across ranks

The `use_distributed_gather` flag and the `if self.is_master or use_distributed_gather` gate correctly ensure all ranks participate in distributed mode and only rank 0 runs the sequential gather, with `return_partials` wired as expected. The completion logging also correctly handles both list-of-partials and merged `SimpleNamespace` results.

One thing to verify: in sequential mode, only master sets `gather_result`, but all ranks later call `outer_step(gather_result)` guarded by `should_update` coming from `dist_helper.all_ok(...)`. Please confirm that `all_ok` and `outer_step` are implemented such that non-master ranks never end up calling `outer_step(None)` (e.g., by syncing `should_update` appropriately or broadcasting a non-None result), to avoid per-rank divergence.

Also applies to: 759-760, 774-820
849-852: Post-outer-step gradient clearing is appropriate

Zeroing model gradients right after a successful outer step is a good addition and keeps the next window from accidentally reusing stale grads; this aligns with the aggressive cleanup strategy elsewhere.
904-924: Successful/skipped peers aggregation now correctly supports partial and merged gather results

The updated debug block that derives `successful_peers` / `skipped_peers` by aggregating `uids` and `skipped_uids` across a list of partials (or reading directly from a merged result) removes the previous incompatibility when `return_partials=True`. Using set arithmetic against `self.comms.peers` keeps the debug output robust even if partials differ.
1066-1074: gather_result cleanup now happens at the right time and is safely scoped

Clearing `state_dict.__dict__` for each partial (or merged) result, then deleting `gather_result` only after all debug/metrics usage, is a good fix for the earlier lifecycle issue and should materially reduce peak memory. The defensive `hasattr(..., "__dict__")` checks avoid breaking if `state_dict` ever changes type.

neurons/validator.py (7)
94-141: Static helpers for gather_result look correct and centralize logic nicely

The new `_get_uids_from_gather_result`, `_get_skipped_uids_from_gather_result`, and `_get_success_rate_from_gather_result` correctly handle both partial lists and merged namespaces, including deduplication for partials. This should make downstream code simpler and less error-prone.
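One way the partial-vs-merged normalization can work, as a standalone sketch (this is a hypothetical re-implementation for illustration; the validator's actual helper may differ in detail):

```python
from types import SimpleNamespace


def get_uids_from_gather_result(gather_result):
    """Return UIDs from either a list of per-rank partials or a merged result."""
    if gather_result is None:
        return []
    if isinstance(gather_result, list):
        seen, uids = set(), []
        for partial in gather_result:
            for uid in (getattr(partial, "uids", []) if partial else []):
                if uid not in seen:  # dedupe across partials
                    seen.add(uid)
                    uids.append(uid)
        return uids
    # Merged SimpleNamespace path
    return list(getattr(gather_result, "uids", []))


partials = [SimpleNamespace(uids=[1, 3]), SimpleNamespace(uids=[3, 5]), None]
merged = SimpleNamespace(uids=[1, 3, 5])
```

Both shapes (and a failed `None` gather) collapse to the same UID list, which is what lets downstream logging and slashing code stay mode-agnostic.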
982-988: OpenSkill rating initialization before `rate` call is sound

Ensuring all `window_uids` have initialized `openskill_ratings` before building `teams` avoids key/shape inconsistencies and makes the rating update path robust for first-time peers.
1399-1416: Peers/reserve broadcast to all ranks is appropriate for distributed gather

Broadcasting `[self.comms.peers, self.comms.reserve_peers]` from master ensures all ranks use the same peer sets when `distributed_gather` is enabled. The guard on `world_size > 1 and dist_helper.is_distributed()` avoids misuse in non-distributed runs.
1690-1703: Reuse of helpers for overlap logging and gather metrics is consistent

Using `_get_uids_from_gather_result`, `_get_skipped_uids_from_gather_result`, and `_get_success_rate_from_gather_result` for:

- overlap logging ("UIDs checked"),
- skipped UID logging,
- and gather-quality metrics

keeps behavior consistent between distributed (partials) and sequential (merged) modes and avoids hand-rolled counting in multiple places.
Also applies to: 1718-1722, 1744-1746
1888-1903: Passing raw `gather_result` into `compute_peer_val_norms` aligns with its new signature

The call now correctly relies on `compute_peer_val_norms` to handle both merged namespaces and partial lists without pre-merging on the validator side, which is what you want in distributed gather mode. The remaining caveats are inside `compute_peer_val_norms` itself (see separate comment).
2556-2561: Wiring gradient memory budget and incremental batch size into `outer_step`

Plumbing `gradient_memory_budget_mb` and `incremental_batch_size` from `hparams` into `tplr.neurons.outer_step` via optional kwargs is a clean way to make incremental processing opt-in without changing existing behavior when these fields are absent.
3281-3295: Switch to `comms.get_with_retry` for debug dict fetching improves robustness

Using `get_with_retry` with a bounded timeout and explicitly handling the `None` case avoids tight retry loops and clarifies failure handling. The `debug_result is None or not debug_result.success` guard correctly covers both timeout and error responses, and the warning log is defensive even when `debug_result` is `None`.

src/tplr/comms.py (4)
1668-1673: `get_with_retry` early-exit semantics for repeated NOT_FOUND look reasonable.

The new `time_max_grace_period` and consecutive `NOT_FOUND` tracking are well thought out:

- Extending the grace period to 10s should reduce false negatives from clock skew.
- Cutting off after 5 consecutive `NOT_FOUND` responses avoids spinning uselessly when a file clearly isn't going to show up, without affecting `TOO_LATE` / `TOO_EARLY` handling.

No changes needed here from my side.
Also applies to: 1727-1739
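The early-exit behavior can be sketched as a plain retry loop (the status strings, the cutoff constant, and the `fetch` signature are illustrative stand-ins, not the real `get_with_retry` API):

```python
import time

NOT_FOUND_CUTOFF = 5  # assumed cutoff, mirroring the review's description


def get_with_retry_sketch(fetch, deadline_s=10.0):
    """Retry until success, deadline, or too many consecutive NOT_FOUNDs."""
    consecutive_not_found = 0
    start = time.monotonic()
    while time.monotonic() - start < deadline_s:
        status, payload = fetch()
        if status == "OK":
            return payload
        if status == "NOT_FOUND":
            consecutive_not_found += 1
            if consecutive_not_found >= NOT_FOUND_CUTOFF:
                return None  # file clearly isn't coming; stop spinning
        else:
            # TOO_EARLY / TOO_LATE reset the streak and keep retrying.
            consecutive_not_found = 0
    return None


missing = iter(["NOT_FOUND"] * 5)
result = get_with_retry_sketch(lambda: (next(missing), None))

ok_calls = iter([("NOT_FOUND", None), ("OK", "data")])
got = get_with_retry_sketch(lambda: next(ok_calls))
```

The key detail is that only *consecutive* NOT_FOUND responses count toward the cutoff, so transient time-window statuses never trigger the early exit.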
1846-1875: Parallel existence checks + slashing attribution logic look correct and materially safer.

The new block that:

- Collects all `(uid, key, peer_bucket)` for errored/None responses,
- Runs `s3_object_exists` for them in parallel, and
- Distinguishes "file exists but download/get failed" vs "file truly missing" vs "no bucket committed"

fixes the previous over-slashing risk and removes the sequential 5s-per-failure bottleneck.

Behavioral details look right:

- Unregistered (`peer_bucket is None`) UIDs are treated as miners that didn't upload and go straight to `skipped_uids`.
- If `file_exists` is True, you log a warning but do not append to `skipped_uids`, explicitly avoiding slashing for network/time-window issues.
- Only confirmed-missing files (`file_exists` is False) get tagged as legitimate upload failures and added to `skipped_uids`.

I don't see correctness issues here; this is a clear improvement.
Also applies to: 1882-1913, 1915-1939
142-145: Partition cache and deterministic UID assignment across ranks are well-structured.

The `_partition_cache` keyed by `(window, tuple(sorted_uids), world_size)` with an LRU cap of 32 entries gives you deterministic round-robin assignments while avoiding recomputing partitions for the same `(window, peer set, world size)`.

A few minor observations:

- Using `sorted_uids` in the key guarantees stability even if call sites pass different orderings.
- Falling back to `effective_world_size = max(world_size, 1)` avoids division issues when `world_size` is misconfigured.
- For typical validator workloads, a cache size of 32 windows should be more than enough; if you ever do very long runs with frequent re-partitioning you might consider making this limit a config knob, but it's not necessary now.
No action required; just confirming the design looks sound.
Also applies to: 2375-2393
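The cached round-robin partitioning described above can be sketched in isolation (class and method names are illustrative; the real `_partition_cache` lives inside `Comms` and uses the `sorted_uids[r::world_size]` pattern noted elsewhere in this review):

```python
from collections import OrderedDict


class PartitionCache:
    """LRU cache of round-robin UID partitions, keyed as the review
    describes: (window, tuple(sorted_uids), world_size)."""

    def __init__(self, max_entries=32):
        self._cache = OrderedDict()
        self._max = max_entries

    def get_partitions(self, window, uids, world_size):
        effective_ws = max(world_size, 1)  # guard against misconfiguration
        key = (window, tuple(sorted(uids)), effective_ws)
        if key in self._cache:
            self._cache.move_to_end(key)  # refresh LRU position
            return self._cache[key]
        sorted_uids = sorted(uids)
        # Rank r takes every world_size-th UID starting at offset r.
        partitions = [sorted_uids[r::effective_ws] for r in range(effective_ws)]
        self._cache[key] = partitions
        if len(self._cache) > self._max:
            self._cache.popitem(last=False)  # evict least recently used
        return partitions


cache = PartitionCache()
parts = cache.get_partitions(window=10, uids=[7, 2, 5, 9, 1], world_size=2)
```

Sorting inside the key is what makes two call sites with differently ordered peer lists hit the same cache entry.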
640-649: Based on my verification:

PyTorch version compatibility confirmed: The project requires `torch>=2.7.1`, which fully supports the `weights_only` argument in `torch.load`.

Saved format compatibility confirmed: All checkpoint files saved in the codebase use plain Python dicts containing tensors and primitives (validator state, comms data, etc.), which are fully compatible with `weights_only=True` according to PyTorch's serialization design.

Code locations verified:

- Lines 714-717 (src/tplr/comms.py): Remote file loading with `weights_only=True` ✓
- Lines 1564-1565 (src/tplr/comms.py): Local file loading with `weights_only=True` ✓

Both locations extract `state_dict` and `global_step` from a plain dict format that is compatible with `weights_only=True`. No runtime errors should occur from the `weights_only` parameter itself.

src/tplr/neurons.py (4)
371-419: Incremental outer step now "merge-then-apply"; batching knobs are ignored but semantics match catch-up.

`_outer_step_incremental`:

- Filters and deterministically sorts partials by `partial.rank` (or metadata) for reproducible behavior.
- Merges all valid partials into a single `merged_result` via `Comms.merge_gather_results`, then calls `_outer_step_single` once.
- Aggressively frees partial state_dicts and clears CUDA cache before and after application.
- Logs memory deltas when `memory_budget_mb` + a usable CUDA device are present.

This ensures the application semantics match the catch-up path (one merged gradient application) and is easy to reason about.

Trade-offs to be aware of:

- It no longer processes partials in smaller batches, so peak memory is that of a fully merged gather on the target device, plus whatever overhead `merge_gather_results` needs.
- `batch_size_override` is explicitly documented as unused, and `_calculate_optimal_batch_size` is not invoked.

If this is an intentional simplification for now, the code is consistent; just be aware that the "incremental" path is currently about how you construct the merged result (via distributed gather/partials), not about memory-bounded incremental application.
Also applies to: 421-444, 446-587
1413-1477: Catch-up outer_step integration and debug-dict comparison flow look coherent.

The catch-up loop now:

- Uses `outer_step` uniformly for both aggregated objects and live gather fallbacks, so the same application path is exercised as in normal training.
- Adds a per-window debug-dict comparison (when available) to estimate "steps behind" and log L2 norms, while guarding errors so catch-up continues even if debug data is missing or malformed.
- Cleans up GPU and CPU memory aggressively after each window.
This wiring looks correct and should make catch-up behavior much easier to reason about and monitor.
1554-1569: Step-size denominator change in `compare_model_with_debug_dict` is a good stabilization.

In `compare_model_with_debug_dict`, switching to:

- Use `learning_rate` as a baseline step size, and
- Take `max(param_avg_change[name], learning_rate)` elementwise before computing `step_ratio = abs_vec / step_vec`

is a sensible improvement:

- It avoids exploding step ratios when the EMA of parameter changes becomes extremely small in converged regimes.
- The `learning_rate` lower bound keeps ratios on a more interpretable scale.

Implementation details (shape guard + fallback when stored slice length mismatches) look correct.
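The stabilized ratio can be illustrated with plain numbers (the function and variable names here are illustrative, not the actual `compare_model_with_debug_dict` internals):

```python
def steps_behind(abs_changes, avg_changes, learning_rate):
    """Per-parameter 'steps behind' estimate: observed absolute change
    divided by max(EMA of per-step change, learning_rate), so a near-zero
    EMA in a converged regime cannot blow the ratio up."""
    return [
        abs_c / max(avg_c, learning_rate)
        for abs_c, avg_c in zip(abs_changes, avg_changes)
    ]


# Second parameter's EMA has collapsed to 1e-9; without the learning-rate
# floor its ratio would be 8e9 instead of a sane 32.
ratios = steps_behind([1.0, 8.0], [0.5, 1e-9], learning_rate=0.25)
```

The `max(..., learning_rate)` floor is the whole fix: the first parameter uses its healthy EMA (0.5), while the second falls back to the learning rate (0.25) as the denominator.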
1609-1651: Overlap checker's support for partial results and O(n) UID lookup looks correct; pair indexing is consistent.

The updated `check_uid_index_overlap`:

- Accepts either a merged `SimpleNamespace` or a list of partials.
- For partials, builds:
  - `uids` as a flat list of all UIDs across partials,
  - `uid_to_partial[uid] = (partial, idx_in_partial)` for O(1) lookup,
  - `all_state_dicts` as a parallel list, though it's no longer used on the partial path.
- When unpacking indices:
  - Uses per-UID state_dict slices (`uid_idx_in_partial`) for partials, and
  - The original merged indexing for single-result cases.

The `valid_uids_for_param` bookkeeping plus:

```python
uid_idx_i = valid_uids_for_param[i]
uid_idx_j = valid_uids_for_param[j]
acc = pair_acc[(uid_idx_i, uid_idx_j)]
...
uids[i], uids[j]
```

maintains a consistent mapping from pair-accumulator keys back to global `uids` indices, so offender attribution and summary logs still refer to the right peer IDs.

This looks like a correct and more scalable implementation than the original O(n×m) search.
Also applies to: 1684-1731, 1743-1760
neurons/trainer.py (1)
1023-1040: Threading new `memory_budget_mb` / `batch_size_override` into `outer_step` looks good—verify signature & config

Passing `memory_budget_mb=getattr(self.hparams, "gradient_memory_budget_mb", None)` and `batch_size_override=getattr(self.hparams, "incremental_batch_size", None)` is backward-compatible and keeps the trainer decoupled from whether these fields exist in all configs.

Two checks to confirm:

- Ensure `tplr.neurons.outer_step` has matching keyword params with sensible defaults so older call sites (if any) remain valid.
- Decide how to handle non-positive values (e.g., `0` or `< 0`) in these hparams—treat as "no limit" or validate earlier and keep only `None` reaching `outer_step`.

If both are already handled in `outer_step`, this wiring is solid.
Actionable comments posted: 1
♻️ Duplicate comments (2)
src/tplr/neurons.py (1)
842-870: Fix inverted `budget_exceeded` flag logic.

Line 848 has inverted logic: `budget_exceeded` is set to `0` when memory exceeds budget and `1` when within budget. This makes the `memory_budget_violation` metric (line 867) report the opposite of reality.

Apply this diff:

```diff
-    budget_exceeded = 0 if final_memory_mb > memory_budget_mb else 1
+    budget_exceeded = 1 if final_memory_mb > memory_budget_mb else 0
```

neurons/validator.py (1)
1902-1906: Fix `compute_peer_val_norms` for partials: `state_dict` undefined + qparams alignment risk

When `gather_result` is a list (partials path):

- After the `if isinstance(gather_result, list): ...` block completes, execution falls through to the merged-result logic. `state_dict` is only defined in the `else` branch, so the subsequent loop

  ```python
  for n, p in self.model.named_parameters():
      vals_key = n + "vals"
      ...
      vals = getattr(state_dict, vals_key, None)
  ```

  will raise `UnboundLocalError: local variable 'state_dict' referenced before assignment`. This will trigger whenever `distributed_gather` is enabled and `compute_peer_val_norms` is invoked on partials.
- Inside the partials branch, `all_vals` and `all_quant_params` can become length-misaligned when some partials have `vals` but no `quant_params`, or mismatched list lengths. Passing such arrays into `self.compressor.maybe_dequantize_values(all_vals, all_quant_params if all_quant_params else None, ...)` can cause `IndexError` or incorrectly treat quantized `uint8` tensors as unquantized floats, echoing the earlier misalignment concern.

Please:

- Return early from the partials branch after populating `clip_norm_dict`, so the merged-result loop is not executed for lists.
- Align `all_vals` and `all_quant_params` defensively:
  - Only append qparams when present and matching the number of vals.
  - For unquantized vals, append `None` sentinel entries so `vals` and `qparams` stay index-aligned.
- Optionally filter out any residual `(v, qp)` pairs where `v.dtype == torch.uint8` and `qp is None` before calling `maybe_dequantize_values`.

An example structural fix:

```diff
-        if isinstance(gather_result, list):
+        if isinstance(gather_result, list):
             ...
-            if not all_vals:
-                continue
+            if not all_vals:
+                continue
             ...
-            clip_norm_dict[vals_key] = torch.median(norms)
-        else:
+            clip_norm_dict[vals_key] = torch.median(norms)
+            return clip_norm_dict  # <-- avoid falling through to merged path
+        else:
             # Merged result: original logic
             state_dict = gather_result.state_dict
```

and tighten the `all_vals` / `all_quant_params` accumulation as per the earlier suggestion to keep them length-aligned.

Also applies to: 3876-3974
🧹 Nitpick comments (4)
src/tplr/comms.py (4)
483-582: `s3_object_exists` error handling is solid; consider purging broken clients on connection errors.

The structured use of `HTTPStatusCode` / `Error.Code` and conservative defaults looks good and should reduce false slashing. One minor improvement: unlike other S3 helpers in this class, this function never calls `_purge_s3_client` on `ConnectionClosedError` / `ClientError`, which means a bad client can remain in the pool and keep failing future HEADs. You could call `await self._purge_s3_client(bucket)` when you detect persistent client-side issues (e.g., `ConnectionClosedError` or non-404/403 4xx), while keeping the same boolean return semantics.
2151-2176: `partition_uids_across_ranks` helper is currently unused; consider wiring it in or dropping it.

You now have both:

- `partition_uids_across_ranks()` (round-robin helper), and
- Inline partitioning in `gather_distributed` using the same `sorted_uids[r::world_size]` pattern plus `_partition_cache`.

Because the static helper is not used, it may confuse future readers and drift from the cached implementation. Either:

- Reuse the helper inside `gather_distributed` (e.g., to build each `partitions[r]`), or
- Remove the helper if the cached implementation is the only intended path.

This keeps the public surface lean and avoids duplication.
Also applies to: 2375-2390
2178-2338: Avoid redundant device moves in `merge_gather_results` and `gather_distributed`.

When `return_partials=False`:

- `merge_gather_results` defaults `target_device` to `cuda:{current_device}` when CUDA is available, and moves all tensors there.
- Later, `gather_distributed` may move `merged_result` again to `device` (which can be a different GPU index).

On multi-GPU setups this can cause two `to()` passes and cross-GPU copies if `device` differs from `cuda:{current_device}`. To tighten this:

- Pass the caller's `device` explicitly into `merge_gather_results` from `gather_distributed`, and
- Skip the final `_move_gather_result_to_device(merged_result, device)` when you already merged onto that device.

This keeps semantics but avoids unnecessary transfers and surprises around which GPU is used.
Also applies to: 2484-2504
2339-2518: Sequential fallback after distributed gather can still yield divergent results across ranks.

In the distributed path of `gather_with_reserve`:

- When distributed gather fails or returns no gradients, every rank independently calls `self.gather(...)` as a sequential fallback and uses its own result (`primary`), without any subsequent broadcast or synchronization.
- Because `gather()` internally does S3 I/O with retries and time-window filtering, different ranks can legitimately end up with different subsets of successful UIDs and different `skipped_uids`, while downstream code may be assuming cross-rank consistency.

For stricter consistency, consider:

- Having only `dist_helper.is_master` perform the sequential `gather()`, then broadcasting the result (merged or partials) to all ranks via `dist_helper.all_gather_object` / `broadcast_object`, or
- Explicitly documenting that in fallback mode each rank may see a different `primary` and that downstream consumers must not assume identical results.

Also in partials mode: sequential fallback currently returns a list with a single partial (or `[None]`), and those partials do not have a `rank` attribute set (unlike the distributed case). Ensure incremental consumers handle this gracefully (e.g., don't rely on `partial.rank` always being present or list length == world size).

These are behavioral clarifications rather than correctness bugs, but worth tightening given this is core training infrastructure.
Also applies to: 2523-2803
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- neurons/miner.py (9 hunks)
- neurons/validator.py (15 hunks)
- src/tplr/comms.py (20 hunks)
- src/tplr/neurons.py (13 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
neurons/miner.py (3)
- src/tplr/chain.py (1): `get_commitments` (264-322)
- src/tplr/distributed.py (1): `is_distributed` (75-77)
- src/tplr/comms.py (1): `gather_with_reserve` (2523-2803)
neurons/validator.py (4)
- src/tplr/distributed.py (2): `is_distributed` (75-77), `broadcast` (273-282)
- src/tplr/comms.py (3): `gather_with_reserve` (2523-2803), `merge_gather_results` (2214-2337), `get_with_retry` (1625-1742)
- src/tplr/neurons.py (1): `check_uid_index_overlap` (1604-1809)
- src/tplr/compress.py (1): `maybe_dequantize_values` (647-697)
src/tplr/comms.py (2)
- src/tplr/logging.py (2): `log_with_context` (207-223, 290-309)
- src/tplr/distributed.py (2): `is_distributed` (75-77), `all_gather_object` (307-325)
src/tplr/neurons.py (2)
- src/tplr/compress.py (6): `compress` (373-377, 379-383, 386-423), `ChunkingTransformer` (145-292), `TopKCompressor` (295-697), `unpack_12bit_indices` (102-142)
- src/tplr/comms.py (1): `merge_gather_results` (2214-2337)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.11)
- GitHub Check: test (3.12)
🔇 Additional comments (26)
src/tplr/comms.py (2)
666-669: Offloading `torch.load` is a good change; verify `weights_only=True` matches your PyTorch floor.

Moving both remote and local `torch.load` calls into `asyncio.to_thread` is exactly what you want to avoid blocking the event loop on CPU decompression. The only compatibility concern is the use of `weights_only=True`, which is only supported on newer PyTorch versions; on older releases it will raise a `TypeError`. If you still support 1.x/early 2.x, consider guarding this flag behind a version check or kwargs construction; otherwise, document that the minimum PyTorch version includes `weights_only`.

Also applies to: 710-723, 1564-1566
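The offloading pattern itself, independent of torch, looks like this (`blocking_load` is a stand-in for `torch.load(..., weights_only=True)`; the checkpoint shape mirrors the plain-dict format described elsewhere in this review):

```python
import asyncio
import time


def blocking_load(path):
    """Stand-in for a CPU-bound, blocking checkpoint load."""
    time.sleep(0.05)  # simulate decompression/deserialization work
    return {"state_dict": {"w": [1.0, 2.0]}, "global_step": 7, "path": path}


async def load_checkpoint(path):
    # Offload the blocking call to a worker thread so the event loop
    # keeps servicing other coroutines (heartbeats, S3 requests, ...).
    return await asyncio.to_thread(blocking_load, path)


async def main():
    # Two loads overlap in threads instead of serializing on the loop.
    return await asyncio.gather(load_checkpoint("a.pt"), load_checkpoint("b.pt"))


results = asyncio.run(main())
```

`asyncio.to_thread` (Python 3.9+) is equivalent to `loop.run_in_executor(None, ...)` with a nicer signature, which is why it suits a drop-in wrap around an existing synchronous call.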
1846-1939: New parallel existence checks and slashing classification in `gather` look correct.

The two-phase approach—first collecting failed UIDs, then fanning out `s3_object_exists` checks in parallel and using those results to distinguish "no bucket", "file truly missing", and "file exists but download failed"—matches the intended slashing semantics and avoids the previous serial HEAD bottleneck. The reuse of cached `existence_results` in both the `Exception` and `None` branches is consistent, and skipping `skipped_uids` updates when the file exists prevents penalizing miners for network/timeout issues. No changes needed here.

neurons/miner.py (8)
31-34: LGTM: CUDA allocator configured correctly.

Setting `expandable_segments:True` before importing torch is a recommended practice for better memory management with PyTorch CUDA operations.
474-490: Well implemented: Commitments now synchronized across all ranks.

The master rank refreshes commitments and broadcasts them along with peers, ensuring all ranks have consistent bucket mappings for distributed gather. This addresses the stale commitment concern from previous reviews.
552-560: LGTM: Proactive memory cleanup before training.

The aggressive memory cleanup before `inner_steps` ensures that memory from the previous outer step is fully released, helping prevent OOM errors.
738-824: Well designed: Distributed vs sequential gather mode selection and logging.

The code correctly:

- Determines gather mode based on hparams and distributed state
- Ensures all ranks participate in distributed gather while only master performs sequential gather
- Handles both partial results (list) and merged results (`SimpleNamespace`) in logging
- Includes proper remote fetch parameters (`local=False`, `stale_retention=100`) that were flagged in past reviews
852-856: LGTM: Gradients cleared after outer step.

Properly clears gradients after the outer step is applied, preventing memory accumulation.

908-927: LGTM: Robust handling of both partial and merged gather results.

The code correctly extracts successful and skipped UIDs from both partial lists and merged results, using sets to deduplicate when aggregating across partials.

963-986: LGTM: Comprehensive success rate calculation for all gather result types.

The code correctly handles None, partial lists, and merged results when calculating success rates and extracting skipped UIDs for metrics logging.

1067-1081: LGTM: Thorough cleanup after gather_result consumption.

The code properly clears state_dict contents from both partial and merged gather results after they've been fully used, followed by garbage collection and CUDA cache cleanup to reclaim memory.
src/tplr/neurons.py (8)
48-69: LGTM: Robust CUDA device resolution helper.

This helper correctly converts various device representations to `torch.device` objects, preventing the string-device errors flagged in previous reviews. The fallback logic is sound.

250-321: LGTM: Clean dispatch pattern for incremental vs single outer step.

The `outer_step` function now dispatches to the appropriate implementation based on whether `gather_result` is a list of partials or a single merged result, with consistent parameter passing to both paths.

324-368: Well implemented: Memory estimation now handles all container types.

The recursive tensor iterator correctly walks through `SimpleNamespace`, dicts, lists, and tuples to accurately measure memory, addressing the previous concern about underestimation.

371-418: LGTM: Sensible batch size calculation based on memory budget.

The function estimates average partial size from samples and calculates how many partials fit in the available budget, with appropriate safety margin and bounds checking.
421-443: LGTM: Simple wrapper delegating to existing merge logic.

Appropriately delegates to `Comms.merge_gather_results` for consistency with the rest of the codebase.

446-584: Excellent: Incremental path merges all partials before applying once.

The implementation correctly merges all partials into a single result before applying it via `_outer_step_single`. This ensures consistency with catch-up behavior and prevents gradient double-application. The deterministic rank-based sorting ensures reproducibility.
1549-1563: LGTM: More stable step ratio calculation.

Using the maximum of `param_avg_change` and `learning_rate` prevents division by tiny numbers when the model has converged, resulting in more meaningful steps-behind estimates.

1604-1809: Well optimized: O(n) lookup complexity for partial results.

The function now supports both partial and merged gather results. The `uid_to_partial` mapping (lines 1625-1635) enables O(1) lookups when processing partials (lines 1690-1693), addressing the O(n×m) complexity concern from previous reviews.
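The O(1) lookup structure described above can be sketched like this (a hypothetical `build_uid_index` helper; the real `check_uid_index_overlap` tracks more per-UID state):

```python
from types import SimpleNamespace


def build_uid_index(partials):
    """Map each UID to (its partial, its index within that partial).

    One pass over all partials builds the index, so later per-UID access
    is O(1) instead of scanning every partial's uids list each time."""
    uid_to_partial = {}
    for partial in partials:
        if partial is None:  # failed ranks may contribute None
            continue
        for idx, uid in enumerate(partial.uids):
            uid_to_partial.setdefault(uid, (partial, idx))  # first sighting wins
    return uid_to_partial


p0 = SimpleNamespace(uids=[4, 8], rank=0)
p1 = SimpleNamespace(uids=[2], rank=1)
index = build_uid_index([p0, None, p1])
```

With the index in hand, resolving a UID's state-dict slice is a dict lookup plus a positional index, which is what removes the O(n×m) search.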
90-141: Helpers for unified gather-result handling look solid

The new static helpers correctly normalize partial vs merged gather results, dedupe UIDs to avoid double-counting, and provide a consistent success-rate view for downstream logging and slashing logic. The implementations are straightforward and type-safe for both list and `SimpleNamespace` inputs.
982-988: OpenSkill rating initialization avoids uninitialized entries

Pre-initializing `self.openskill_ratings[uid]` for all `window_uids` before calling `self.openskill_model.rate` prevents missing-rating edge cases and keeps the rating update path consistent for new peers.
1388-1419: Peer/commitment broadcast for distributed ranks is appropriate

Refreshing `self.comms.commitments` on master and broadcasting `[peers, reserve_peers, commitments]` via `dist.broadcast_object_list` ensures all ranks share a consistent view of gather participants and buckets, which is crucial for distributed gather correctness.
1462-1499: Timestamp centralization and broadcast fixes prior rank divergence

Computing `ts_value` only on master with fallback logging, then broadcasting it as a tensor so all ranks derive the same `time_min` / `time_max` window, resolves the earlier inconsistency risk between ranks during time-windowed gathers.
1542-1578: Distributed gather mode selection and logging are well-structured (verify sequential path semantics)

The `use_distributed_gather` predicate, gather-mode logging, and the `if self.is_master or use_distributed_gather` condition cleanly separate sequential (master-only) and distributed (all ranks) execution, while `return_partials=use_distributed_gather` ties neatly into incremental processing.

Since sequential mode now runs `gather_with_reserve` only on the master when `world_size > 1` and `distributed_gather` is disabled, please double-check that `tplr.neurons.outer_step` is designed to be called on non-master ranks with `gather_result=None` (i.e., it must not dereference `gather_result` when `is_master` is False).

Also applies to: 1600-1636
1682-1725: Master-only gather post-processing and metrics reuse helpers correctly

The master-only block now avoids redundant merges by using the partial-aware overlap check directly, uploads a single merged result, and offloads gather results to CPU when applicable. Reusing `_get_uids_from_gather_result`, `_get_skipped_uids_from_gather_result`, and `_get_success_rate_from_gather_result` for overlap logging, gather-quality metrics, and later slashing keeps the handling of partial vs merged results consistent.

Also applies to: 1702-1706, 1723-1725, 1748-1759
2559-2564: outer_step incremental args: confirm signature and multi-rank behaviour

Passing `memory_budget_mb=self.hparams.gradient_memory_budget_mb` and `batch_size_override=self.hparams.incremental_batch_size` into `tplr.neurons.outer_step` is a clean hook for incremental, memory-bounded processing, and the `getattr(..., None)` guards are safe.

Please verify at the call site that:

- The `outer_step` signature actually accepts `memory_budget_mb` and `batch_size_override` keyword arguments.
- Its implementation behaves correctly when these are `None` and when running in distributed vs single-rank modes.
3283-3299: Using `get_with_retry` for miner debug dicts improves robustness

Swapping to `self.comms.get_with_retry` with a 60-second timeout, and treating both `None` and `success == False` as failure, gives a more resilient path for fetching miner debug dictionaries while cleanly falling back to a neutral sync score on errors. The `getattr(debug_result, "error", "unknown")` access is safe even when `debug_result` is `None`.
```diff
-        lookup = (
-            lookup.to(val.device) if isinstance(lookup, torch.Tensor) else lookup
-        )
+        if isinstance(lookup, torch.Tensor):
+            lookup = lookup.to(val.device)
+        if isinstance(shift, torch.Tensor):
+            shift = shift.to(val.device)
```
What is the reason for these changes?
I was getting device mismatch errors here. The old code only moved `lookup` to `val.device`, but not `shift`. This change ensures both `lookup` and `shift` are moved to the same device as `val` before the computation.
There are many more paths on the outer step right now. Why do we need that? I think this complicates the code. I think it is better to keep outer step untouched for this one since the distributed gather would make master rank have the full gather results exactly like before. Can you explain a bit your changes?
Yes, I think I fixed this now. The initial implementation had multiple code paths (outer_step_incremental for partials, _outer_step_single for merged results), but I have simplified it to a single path, exactly like before.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
neurons/trainer.py (1)
872-872: Critical: Duplicate gradient clipping—remove line 872.

Gradients are being clipped twice: once at line 872 (before unscaling) and again at lines 887-889 (after unscaling). When using `GradScaler` for mixed precision training, gradients are scaled during the backward pass and must be unscaled before clipping. The clip at line 872 operates on scaled gradients, which is incorrect.

Apply this diff to remove the duplicate clip:

```diff
 log_loss = dist_helper.ddp_reduce(
     loss_item, op=ReduceOp.AVG, device=self.device
 )
-torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
 if not null_round:
```

Also applies to: 887-889
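The unscale-before-clip ordering can be demonstrated with a tiny torch-free sketch; the scale factor and gradient values below are made up, and real code would call `GradScaler.unscale_()` before `clip_grad_norm_`:

```python
def l2_norm(g):
    return sum(x * x for x in g) ** 0.5

def clip_to_max_norm(g, max_norm):
    # Mimics clip_grad_norm_: rescale so the L2 norm is at most max_norm.
    n = l2_norm(g)
    return [x * (max_norm / n) for x in g] if n > max_norm else list(g)

S = 65536.0                      # loss-scale factor applied during backward
true_grad = [0.6, 0.8]           # unscaled gradient, L2 norm 1.0
scaled = [S * x for x in true_grad]

# Wrong order: clip the scaled gradient, then unscale -> effective norm ~1/S.
wrong = [x / S for x in clip_to_max_norm(scaled, 1.0)]
# Right order: unscale first, then clip -> the true norm is preserved.
right = clip_to_max_norm([x / S for x in scaled], 1.0)

print(l2_norm(wrong) < 1e-4, abs(l2_norm(right) - 1.0) < 1e-6)  # -> True True
```

Clipping the scaled gradients silently shrinks the effective update by the loss-scale factor, which is why the pre-unscale clip at line 872 is harmful rather than redundant.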
neurons/miner.py (2)
470-487: Broadcast missing `reserve_peers` on miner ranks—inconsistent gather behavior.

The verification confirms the issue. In `neurons/validator.py` (lines 1370–1378), the broadcast correctly includes `reserve_peers`:

```python
payload = (
    [self.comms.peers, self.comms.reserve_peers, self.comms.commitments]
    if self.is_master
    else [None, None, None]
)
dist.broadcast_object_list(payload, src=0)
self.comms.peers = payload[0]
self.comms.reserve_peers = payload[1]
self.comms.commitments = payload[2]
```

However, in `neurons/miner.py` (lines 470–487), only `peers` and `commitments` are broadcast. Since `neurons/miner.py` line 762 calls `gather_with_reserve` with `reserve_uids=self.comms.reserve_peers`, non-master ranks will use a stale or default `[]` value, diverging from the master's reserves. Apply the suggested fix to broadcast `reserve_peers` alongside `peers` and `commitments`:

```diff
-        # Broadcast peer list (and optionally commitments) from master to all ranks for distributed gather
+        # Broadcast peer list, reserve peers and commitments from master to all ranks for distributed gather
         if dist_helper.world_size > 1 and dist_helper.is_distributed():
-            payload = (
-                [self.comms.peers, self.comms.commitments]
-                if self.is_master
-                else [None, None]
-            )
-            dist.broadcast_object_list(payload, src=0)
-            self.comms.peers, self.comms.commitments = payload
+            payload = (
+                [self.comms.peers, self.comms.reserve_peers, self.comms.commitments]
+                if self.is_master
+                else [None, None, None]
+            )
+            dist.broadcast_object_list(payload, src=0)
+            self.comms.peers, self.comms.reserve_peers, self.comms.commitments = payload
```
380-385: Start window must be sourced from chain via `get_start_window()`—not local `current_window`.

The mining start window must come from the chain/R2 (via `get_start_window()`) because validators publish a global `start_window` to the chain. Miners must use this global value to correctly calculate `global_step = current_window - start_window` in sync with validators. Using a local `current_window` risks desynchronization and incorrect `global_step` calculation.

Re-enable the `get_start_window()` call at `neurons/miner.py` line 381, or gate it behind a config flag if this change was intentional for testing only.

src/tplr/neurons.py (1)
1447-1488: check_uid_index_overlap: partial-aware logic and UID mapping look correct

The updated overlap checker:

- Accepts either a merged `SimpleNamespace` or `list[SimpleNamespace | None]` and builds a flat `uids` list plus a `uid_to_partial` map for O(1) lookup, avoiding the previous O(n×m) search across partials.
- For each parameter, unpacks 12-bit indices per UID using that map, keeping track of which global UID indices are valid for the parameter in `valid_uids_for_param`.
- Computes pairwise overlaps in the packed index domain, and records pair statistics keyed by global UID indices, later mapped back to `uids[i], uids[j]` for logging and slashing decisions.

The use of `valid_uids_for_param[i]`/`[j]` ensures that pairs are always reported in terms of the original global UID ordering, even when some peers lack a given parameter. This matches the intended semantics and should behave correctly for both merged and distributed-partials inputs.

Also applies to: 1523-1561, 1580-1595
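For intuition, the pairwise overlap statistic reduces to a set intersection over two peers' submitted index sets; this is a simplified stand-in for the packed 12-bit index comparison, with made-up index values:

```python
def overlap_fraction(idx_a, idx_b):
    # Fraction of shared indices, normalised by the smaller submission
    # so that a strict subset still registers as full overlap.
    a, b = set(idx_a), set(idx_b)
    return len(a & b) / min(len(a), len(b))

peer_i = [1, 5, 9, 12]
peer_j = [5, 9, 30, 41]
print(overlap_fraction(peer_i, peer_j))  # -> 0.5
```

Two honest peers compressing different local gradients should rarely submit near-identical top-k index sets, which is why a high overlap fraction is treated as copying evidence.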
♻️ Duplicate comments (2)
neurons/validator.py (1)
3821-3919: compute_peer_val_norms: partials path can raise UnboundLocalError and is misleadingly structured.

When `gather_result` is a `list[SimpleNamespace]`, the function fills `clip_norm_dict` inside the `if isinstance(gather_result, list)` block but then falls through to the merged-result section and the final loop, where `state_dict` is undefined. If this path is ever exercised, it will raise `UnboundLocalError` at Line 3906. The structure also suggests both paths run, which is confusing.

Minimal fix: return immediately after computing norms for partials so the merged branch and final loop are skipped in that case:

```diff
     # Handle both merged results and partials
     if isinstance(gather_result, list):
@@
             norms = torch.stack([torch.norm(v, p=2) for v in vals_f32]).to(p.device)
             clip_norm_dict[vals_key] = torch.median(norms)
-    else:
+        # For partials, we've fully populated clip_norm_dict and can return early
+        return clip_norm_dict
+    else:
         # Merged result: original logic
         state_dict = gather_result.state_dict
         if not state_dict:
             raise ValueError("Must have gather_result.state_dict to compute norms")
@@
     for n, p in self.model.named_parameters():
         vals_key = n + "vals"
         quant_key = n + "quant_params"
```

You may also want to re-introduce the stricter `vals`/`quant_params` alignment checks you previously had here to avoid `IndexError` in `maybe_dequantize_values` when partials have inconsistent quantization layouts.

src/tplr/comms.py (1)
1669-1671: Early cutoff on NOT_FOUND still breaks timeout semantics (duplicate of prior concern)

The `consecutive_not_found` logic still gives up after 5 NOT_FOUNDs (~2.5 s at 0.5 s sleep) regardless of the caller's `timeout`, and `get()` currently collapses timeouts, network glitches, and real 404s into `status="NOT_FOUND"`. That means slow or flaky peers can be treated as "likely doesn't exist" well before `timeout` expires, contradicting the contract of `get_with_retry`. This is the same issue raised in the earlier review; it remains unresolved in the current code. Consider either:

- Distinguishing true 404s vs transient/network errors in `get()` and only incrementing `consecutive_not_found` for confirmed 404s, or
- Deriving `max_consecutive_not_found` from `timeout` (or dropping the early cutoff entirely) so that transient errors are retried for the full timeout budget.

Also applies to: 1725-1737
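One way to implement the first suggestion (only counting confirmed 404s toward the early-exit budget) could look like this sketch; the `fetch` callable and the status strings are hypothetical, not the real `get()` API:

```python
import time

def get_with_retry(fetch, timeout, max_consecutive_not_found=5, sleep_s=0.01):
    """fetch() -> (status, payload), status in {"OK", "NOT_FOUND", "TRANSIENT"}."""
    deadline = time.monotonic() + timeout
    consecutive_not_found = 0
    while time.monotonic() < deadline:
        status, payload = fetch()
        if status == "OK":
            return payload
        if status == "NOT_FOUND":
            consecutive_not_found += 1
            if consecutive_not_found >= max_consecutive_not_found:
                return None          # confirmed-missing fast path
        else:
            # Transient errors don't burn the 404 budget; keep retrying
            # until the caller's timeout actually expires.
            consecutive_not_found = 0
        time.sleep(sleep_s)
    return None

# Transient failures keep retrying until the payload arrives.
attempts = iter([("TRANSIENT", None), ("TRANSIENT", None), ("OK", "grad")])
print(get_with_retry(lambda: next(attempts), timeout=1.0))  # -> grad
```

This keeps the fast-fail behavior for objects that truly return 404 while letting flaky-network peers use the full timeout budget.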
🧹 Nitpick comments (13)
neurons/validator.py (1)
90-108: Helpers currently assume merged gather results only (naming/docs still reference partials).

`_get_uids_from_gather_result`, `_get_skipped_uids_from_gather_result`, and `_get_success_rate_from_gather_result` operate only on merged `SimpleNamespace` results, while the surrounding comment mentions "partials vs merged gather results". That's fine for current call sites (they always pass merged results), but either updating the comment/docstrings to match, or extending these helpers to handle `list[SimpleNamespace]` defensively, would reduce the chance of misuse if partials are re-introduced later.

neurons/miner.py (2)
32-35: CUDA allocator env var placement is fine but only effective when this is the entrypoint.

Setting `PYTORCH_CUDA_ALLOC_CONF` before importing `torch` in this module is correct. Just be aware this won't affect runs where another module imports `torch` first; if that matters, consider setting it in the launcher or a shared bootstrap module.
1018-1027: End-of-window gather_result cleanup is safe but only clears SimpleNamespace state_dicts.

Clearing `gather_result.state_dict.__dict__` and then deleting `gather_result` after all uses (outer_step, debug, metrics) is a good way to reduce peak memory before the next window. Note that this only affects `SimpleNamespace` state dicts; if `gather_with_reserve` ever returns a plain `dict` for `state_dict`, it will remain for GC to handle, which is acceptable but slightly less aggressive.

src/tplr/comms.py (5)
20-21: Partition cache for UID splits looks sound; minor reuse opportunity

The `_partition_cache` keyed by `(window, sorted_uids, world_size)` with LRU eviction gives deterministic, bounded partitioning across calls, which is appropriate here. Since `partition_uids_across_ranks()` implements the same round-robin logic, consider reusing that helper inside the cache-fill branch to avoid duplicating the slicing scheme in two places.

Also applies to: 29-29, 142-144
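The round-robin scheme the helper implements can be sketched in a few lines; this is assumed behaviour based on the review's description, and the real `partition_uids_across_ranks()` may differ in details:

```python
def partition_uids_across_ranks(uids, world_size):
    # Sort once so every rank derives the same ordering, then stride:
    # rank r takes every world_size-th UID starting at offset r.
    ordered = sorted(uids)
    return [ordered[rank::world_size] for rank in range(world_size)]

parts = partition_uids_across_ranks([7, 3, 11, 1, 5, 9], 3)
print(parts)  # -> [[1, 7], [3, 9], [5, 11]]
```

Because the split depends only on the sorted UID list and the world size, every rank can compute its own slice independently without any communication, which is what makes the `(window, sorted_uids, world_size)` cache key valid.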
483-582: s3_object_exists semantics are conservative and aligned with slashing requirements

The new `s3_object_exists` implementation cleanly separates 404/NoSuchKey, permission/config errors, and transient network/service issues using structured `ClientError.response` fields, and errs on the side of "exists" for transient failures, which is consistent with avoiding false slashing. One minor follow-up you could consider is purging the S3 client (`_purge_s3_client`) on persistent config errors (e.g., invalid bucket) to avoid reusing a known-bad client, though this is an optimization rather than a correctness fix.
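The status taxonomy described above can be sketched as a small classifier over the structured error-response dict; the field names follow botocore's `ClientError.response` shape, and the category names are illustrative:

```python
def classify_s3_error(response):
    code = response.get("Error", {}).get("Code", "")
    status = response.get("ResponseMetadata", {}).get("HTTPStatusCode", 0)
    if code in ("404", "NoSuchKey") or status == 404:
        return "missing"            # definitively does not exist
    if code in ("403", "AccessDenied") or status == 403:
        return "config_error"       # permissions / bucket misconfiguration
    # 5xx and anything else: err on the side of "exists" to avoid false slashing.
    return "transient"

print(classify_s3_error({"Error": {"Code": "NoSuchKey"}}),
      classify_s3_error({"ResponseMetadata": {"HTTPStatusCode": 503}}))
# -> missing transient
```

Only the "missing" category should ever feed a slashing decision; both other categories say nothing about whether the peer actually uploaded.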
666-669: Verify torch.load(weights_only=True) behavior and map_location choices

Offloading `torch.load` to `asyncio.to_thread` is a good fix for event-loop blocking, and the higher 500 MB threshold for chunked downloads is reasonable. Two things to double-check:

- `weights_only=True` is only available in newer PyTorch versions and may behave differently for arbitrary dict payloads; please confirm the project's minimum torch version and that these gradient/aggregator payloads still round-trip as expected.
- For remote loads you honor `map_location` (or `self.config.device`), but for local loads you now use the default device. If local gradient/aggregator files can be large and validators often run with GPUs attached, you may want to thread a `map_location` through `get(..., local=True, ...)` as well to avoid surprises/OOMs on unintended devices.

Also applies to: 709-723
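The `asyncio.to_thread` offload pattern itself is straightforward; here is a minimal sketch with a sleep standing in for the CPU-heavy `torch.load` call:

```python
import asyncio
import time

def blocking_load(path):
    time.sleep(0.05)              # stand-in for torch.load deserialization
    return {"path": path}

async def fetch(path):
    # The deserialization runs in a worker thread, so the event loop
    # stays free to service other coroutines (e.g. concurrent S3 fetches).
    return await asyncio.to_thread(blocking_load, path)

result = asyncio.run(fetch("gradient-0.pt"))
print(result["path"])  # -> gradient-0.pt
```

`asyncio.to_thread` (Python 3.9+) forwards args to the callable and returns an awaitable, so the call site change from `torch.load(...)` to `await asyncio.to_thread(torch.load, ...)` is mechanical.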
2224-2356: Distributed gather core flow looks correct; note unused target_device and list element nullability

The distributed gather path (per-rank `gather` → optional local merge → `all_gather_object` → global merge or partials) is coherent and matches the intended API:

- Partitioning via cached round-robin slices is deterministic across ranks.
- `all_partials` normalization to `None` for empty `uids` avoids downstream surprises.
- When `return_partials=True`, preserving the per-rank structure (with optional `partial.rank`) is useful for incremental consumers.

Two minor nits:

- `merge_gather_results` computes `target_device` but never uses it when building the merged state; either wire this into tensor placement or drop the parameter to avoid confusion.
- The return type for `gather_distributed` allows `list[SimpleNamespace | None]`, but callers like `_outer_step_incremental` treat the list as `list[SimpleNamespace]` and filter out `None` themselves. That's safe, but consider tightening the type hints or documenting that `None` entries are expected placeholders.
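Conceptually, merging per-rank partials reduces to concatenating `uids` and per-parameter lists while skipping `None` placeholders. A simplified stand-in (not the real `merge_gather_results` signature or metric handling):

```python
from types import SimpleNamespace

def merge_partials(partials):
    valid = [p for p in partials if p is not None and p.uids]
    if not valid:
        return None                      # mirror "no UIDs succeeded" -> None
    merged = SimpleNamespace(uids=[], state_dict=SimpleNamespace())
    for p in valid:
        merged.uids.extend(p.uids)
        # Concatenate each per-parameter list of values across ranks.
        for key, vals in vars(p.state_dict).items():
            vars(merged.state_dict).setdefault(key, []).extend(vals)
    return merged

a = SimpleNamespace(uids=[1, 2], state_dict=SimpleNamespace(w_vals=[[0.1], [0.2]]))
b = SimpleNamespace(uids=[3], state_dict=SimpleNamespace(w_vals=[[0.3]]))
m = merge_partials([a, None, b])
print(m.uids, len(m.state_dict.w_vals))  # -> [1, 2, 3] 3
```

Because partials are sorted by rank before merging, the merged `uids` ordering is deterministic across runs, which downstream overlap checks and metrics rely on.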
2396-2521: gather_with_reserve: distributed fallback and type normalization mostly consistent; confirm call-site expectations

The refactored `gather_with_reserve` cleanly separates:

- Primary distributed gather (with `gather_distributed`) when `world_size > 1`, including detection of an "all empty" result across ranks.
- Shared fallback to sequential `gather` on all ranks when distributed gather fails, with wrapping into a list when `return_partials=True` to keep shapes consistent.
- Sequential path for non-distributed runs, again normalizing to `[SimpleNamespace]` when returning partials.

Two points to verify:

1. Cross-rank consistency on fallback – Each rank independently calls `gather(...)` in the distributed-fallback case. That's fine if `gather_with_reserve` is only ever invoked on the master rank (non-masters pass `gather_result=None` into `outer_step` and rely on broadcast), but would yield per-rank divergence if non-master ranks also consume `primary`. Please confirm master-only usage at call sites or add a broadcast step after sequential fallback if other ranks might rely on the value.
2. Partials semantics – When falling back to sequential in distributed mode with `return_partials=True`, you now return a single-element list rather than one entry per rank. That's acceptable for consumers that only treat "partials" as arbitrary chunks of UIDs (as `_outer_step_incremental` does), but it no longer encodes rank topology; clarifying this in the docstring would help future readers.

src/tplr/neurons.py (5)
22-22: _resolve_cuda_device helper is reasonable but currently unused

The `_resolve_cuda_device` helper handles various string/device/None inputs and CUDA initialization edge-cases sensibly. However, its result (`cuda_device`) is computed in `_outer_step_single` and then never used, and all CUDA API calls now operate on the default device only. Either wire `cuda_device` into the few CUDA calls that still need an explicit device (if any), or remove the helper and local variable to keep the code simpler.

Also applies to: 48-69
250-313: outer_step increment/merge flow: dispatcher is correct; note behavior vs "incremental" goal and inner import

The new `outer_step` dispatcher cleanly routes:

- `gather_result: list[...]` → `_outer_step_incremental`
- `gather_result: SimpleNamespace | None` → `_outer_step_single`

and keeps the public API uniform across call sites like `catchup_with_aggregation_server`, which is good.

Within `_outer_step_incremental`:

- `None` partials are ignored via the `valid_partials` filter, so it can safely consume the lists returned from `gather_with_reserve(return_partials=True)`.
- Partials are sorted deterministically by `partial.rank` (or metadata) before merging, and then merged once via `_merge_partial_batch` → `Comms.merge_gather_results`, after which `_outer_step_single` is executed once. This "merge then apply once" design matches catch-up semantics and avoids per-partial re-application.

Two small points:

- The function name/docstring still call this "incremental" processing, but the implementation now merges all partials in one go before applying. Peak memory will thus look similar to a single merged gather. If the original intent was memory-bounded, batched merging, it may be worth either adjusting the description or re-introducing batched processing later.
- `_merge_partial_batch` performs an in-function `from tplr.comms import Comms` import. For consistency with the rest of the module (and to avoid repeated imports under heavy training loops), it would be cleaner to move that import to module scope.

Also applies to: 341-459
461-477: _outer_step_single: core logic preserved; minor dead code

The refactored `_outer_step_single` preserves the original behavior:

- Only the "source rank" (`is_master` or non-DDP) consumes `gather_result` and builds `src_sd`; other ranks rely on broadcast or `distribute_tensor`.
- Per-parameter decompression, decode, and immediate `optimizer.step()` with aggressive freeing remain intact.
- Fingerprint accumulation and optional wandb logging are still gated on the source rank.

The only notable leftover is `cuda_device = _resolve_cuda_device(device)`, which is no longer used in any CUDA calls (they now all use default-device operations). This can be safely dropped to avoid confusion.

Also applies to: 493-515
1393-1407: compare_model_with_debug_dict: step-ratio denominator change improves stability

Switching the step-ratio denominator to `max(param_avg_change, learning_rate)` prevents extremely small `param_avg_change` values from inflating "steps behind" metrics when the model has mostly converged. The fallback to a constant `learning_rate` when shapes mismatch or no EMA is available is a sensible conservative choice. Just note that the `min_step_size` argument is now effectively unused; if you don't expect to bring it back, consider removing it from the signature to avoid misleading callers.
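Numerically, the floor matters most near convergence; a tiny sketch with made-up magnitudes shows the effect:

```python
def steps_behind(param_diff, param_avg_change, learning_rate):
    # Floor the denominator so a near-zero avg change can't blow up the ratio.
    return param_diff / max(param_avg_change, learning_rate)

old_estimate = 1e-3 / 1e-7                       # raw avg-change denominator
new_estimate = steps_behind(1e-3, 1e-7, 1e-4)    # floored by the learning rate
print(round(old_estimate), round(new_estimate))  # -> 10000 10
```

With a converged model the per-step parameter change can be orders of magnitude below the learning rate, so the unfloored ratio reports thousands of "steps behind" for a node that is effectively in sync.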
1447-1488: check_uid_index_overlap: partial-aware logic and UID mapping look correct

The updated overlap checker:

- Accepts either a merged `SimpleNamespace` or `list[SimpleNamespace | None]` and builds a flat `uids` list plus a `uid_to_partial` map for O(1) lookup, avoiding the previous O(n×m) scan across partials.
- For each parameter, unpacks 12-bit indices per UID using that map, keeping track of which global UID indices are valid for that parameter in `valid_uids_for_param`.
- Computes pairwise overlaps on these unpacked indices and stores stats keyed by global UID indices, later mapped back to `uids[i], uids[j]` for logging and slashing.

This matches the intended semantics for both merged and distributed-partials inputs and keeps complexity under control.
Also applies to: 1523-1561, 1580-1595
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- neurons/miner.py (10 hunks)
- neurons/trainer.py (3 hunks)
- neurons/validator.py (14 hunks)
- src/tplr/comms.py (19 hunks)
- src/tplr/neurons.py (13 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
neurons/miner.py (3)
- src/tplr/chain.py (1)
  - get_commitments (264-322)
- src/tplr/distributed.py (1)
  - is_distributed (75-77)
- src/tplr/comms.py (1)
  - gather_with_reserve (2396-2672)
src/tplr/neurons.py (2)
- src/tplr/compress.py (5)
  - compress (373-377)
  - compress (379-383)
  - compress (386-423)
  - TopKCompressor (295-697)
  - unpack_12bit_indices (102-142)
- src/tplr/comms.py (1)
  - merge_gather_results (2114-2222)
src/tplr/comms.py (3)
- src/tplr/schemas.py (1)
  - Bucket (24-46)
- src/tplr/logging.py (2)
  - log_with_context (207-223)
  - log_with_context (290-309)
- src/tplr/distributed.py (2)
  - is_distributed (75-77)
  - all_gather_object (307-325)
neurons/validator.py (6)
- src/tplr/chain.py (1)
  - get_commitments (264-322)
- src/tplr/distributed.py (2)
  - is_distributed (75-77)
  - broadcast (273-282)
- src/tplr/logging.py (2)
  - log_with_context (207-223)
  - log_with_context (290-309)
- src/tplr/comms.py (2)
  - gather_with_reserve (2396-2672)
  - get_with_retry (1625-1740)
- src/tplr/neurons.py (1)
  - check_uid_index_overlap (1447-1652)
- src/tplr/compress.py (1)
  - maybe_dequantize_values (647-697)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.11)
- GitHub Check: test (3.12)
🔇 Additional comments (9)
neurons/trainer.py (2)
911-911: LGTM: Correct handling of grad_norm during spin-up.

Setting `total_grad_norm` to 0.0 during null rounds (spin-up phase) is appropriate since no optimizer step is taken.

928-928: LGTM: Useful logging enhancement.

Adding gradient norm to the inner step logs provides valuable insight into training dynamics.
neurons/validator.py (2)
1355-1387: Distributed gather orchestration and timestamp/peer synchronization look consistent.

Broadcasting peers, reserve_peers, and commitments; using a single timestamp window across ranks; and gating gather_with_reserve via `use_distributed_gather` with master-driven skip decisions keeps all ranks aligned and avoids rank divergence or deadlocks in the new distributed gather path. The logging/use of the new `_get_*_from_gather_result` helpers on the merged result is also coherent with `return_partials=False`.

Also applies to: 1429-1472, 1509-1592, 1638-1676, 1698-1710
3230-3247: Switch to get_with_retry for miner debug dict improves robustness.

Using `self.comms.get_with_retry` with a bounded timeout and handling the `None` case before dereferencing the result prevents hangs and attribute errors when miners don't publish debug data in time. The warning log with a safe `getattr(debug_result, "error", "unknown")` is also defensive.

neurons/miner.py (4)
553-560: Pre-training CUDA cleanup and gradient reset are aggressive but appropriate for OOM mitigation.

Calling `torch.cuda.empty_cache()`, `torch.cuda.synchronize()`, and `model.zero_grad(set_to_none=True)` before `inner_steps` is a reasonable trade-off to reduce fragmentation/VRAM pressure between windows, given this code assumes CUDA is available.
832-835: Post-outer-step gradient clearing is good hygiene.

Zeroing gradients with `set_to_none=True` after a successful outer step avoids gradient accumulation and can help allocator reuse; this matches how you clear grads before training and in `cleanup_window()`.
885-895: Gather success and skipped peer metrics wiring is consistent.

Using the merged `gather_result`'s `success_rate` and `skipped_uids` for:

- Debug fields (`successful_peers` / `skipped_peers`), and
- Per-window metrics (`miner/gather/success_rate` and `skipped_peers` JSON in Influx),

keeps miner metrics aligned with validator-side gather reporting. This assumes `gather_result` is always merged (which you enforce via `return_partials=False`).

Also applies to: 930-937, 999-1002
739-776: Review comment is accurate; no changes required

Verification confirms all claims in the review comment are correct:

- gather_with_reserve return type: When called with `return_partials=False` (line 764), the function returns a merged `SimpleNamespace`, not a list. The implementation respects this parameter throughout all code paths (comms.py lines 2508, 2520, 2535, 2672).
- Usage consistency: The code at lines 771–797 correctly accesses `gather_result.uids`, `gather_result.skipped_uids`, and `gather_result.success_rate`—all attributes of a `SimpleNamespace`. Line 826 passes `gather_result` to `self.outer_step()` as expected for a merged result object.
- Distributed gather coherence: The conditional logic at line 759 correctly ensures all ranks participate when `use_distributed_gather=True`, while sequential mode remains master-only.
- Synchronization: The `should_update` reduction via `dist_helper.all_ok()` at lines 806–809 properly keeps all ranks synchronized on the decision to apply an outer step.

src/tplr/comms.py (1)
1472-1493: gradient_timestamp default version handling is straightforward

Allowing an explicit `version` while defaulting to `tplr.__version__` when `None` makes this helper usable for both current and historical gradient versions without duplicating call sites. The error path still safely purges the S3 client on failure and returns `0.0`, which matches existing semantics.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/tplr/comms.py (1)
1815-1875: Major semantic change in `gather()` return value breaks multiple call sites relying on `None` checks

The verification confirms the review comment is accurate. The new `gather()` implementation always returns a `SimpleNamespace` object, even when no valid gradients are received—it never returns `None`. This breaks control flow in at least two critical locations:

- neurons.py:1189-1199 – The catchup fallback decides whether to skip the window using `if gather_ns is None`. With the new behavior, this condition will always be false, and windows with zero valid gradients will proceed through `outer_step` (advancing schedulers and `global_step`) instead of being skipped.
- comms.py:2645 – The check `if primary is None:` in `gather_with_reserve()` is now unreachable, breaking the fallback logic for empty results.
- comms.py:2391, 2621, 2633, 2710 – Other `gather()` call sites have similar `None` checks or truthiness checks that will now behave unexpectedly.

Action required:

- Either return `None` from `gather()` when `valid_uids` is empty and not in a distributed-merge scenario, or
- Update all `gather()` call sites to check `len(result.uids)` instead of `result is None`, ensuring callers properly handle empty-but-non-None results.

At minimum, audit all listed call sites to confirm whether the new "silent no-op window" behavior is intentional.
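The suggested call-site change can be sketched as a helper that treats empty-but-non-None results as skippable; the helper name is illustrative, not from the codebase:

```python
from types import SimpleNamespace

def should_skip_window(gather_result):
    # Handles both the old contract (None on failure) and the new one
    # (always a namespace, possibly with an empty uids list).
    if gather_result is None:
        return True
    return len(getattr(gather_result, "uids", [])) == 0

print(should_skip_window(None),
      should_skip_window(SimpleNamespace(uids=[])),
      should_skip_window(SimpleNamespace(uids=[3, 7])))
# -> True True False
```

Routing every `if gather_result is None` check through one helper like this makes the semantic change a single-point fix instead of an audit across all call sites.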
♻️ Duplicate comments (1)
src/tplr/comms.py (1)
1665-1671: Early NOT_FOUND cutoff still conflates 404 with timeouts/network errors.

`get_with_retry` now stops after `max_consecutive_not_found = 5`, but `get()` still returns `status="NOT_FOUND"` for all `s3_get_object` `None` returns, including timeouts and transient client/network errors. This preserves the earlier problem: slow/flaky peers can be treated as "likely doesn't exist" and given up on after ~2.5s (5 × 0.5s), independent of the caller's `timeout`.

To align semantics with the `timeout` parameter and the new structured S3 helpers, consider:

- Distinguishing true 404 "not found" from transient errors in `get()` (e.g., propagate a status enum or error type instead of collapsing to NOT_FOUND).
- Only incrementing `consecutive_not_found` for confirmed 404s, not for timeouts or generic errors.
- Optionally deriving `max_consecutive_not_found` from `timeout` (e.g., a fraction of it) rather than a fixed 5.

This keeps the fast-fail optimization for non-existent objects without prematurely aborting legitimate slow uploads.
Also applies to: 1725-1737
🧹 Nitpick comments (5)
src/tplr/comms.py (3)
483-582: s3_object_exists error handling looks solid; consider also purging bad clients.

The structured handling of 404/403/5xx vs transient errors is good and matches the "don't slash on network issues" goal. One improvement: for `ConnectionClosedError` and `ClientError` cases where you decide to treat the error as transient (5xx or generic network issues), you currently don't call `_purge_s3_client`, unlike other S3 helpers in this class. That means a poisoned client (closed connection, stale DNS, etc.) may be reused repeatedly.

It would be safer to purge and recreate the client on these non-404/-403 branches as well, similar to `s3_get_object` and `download_large_file`.
2096-2327: Distributed helpers and merge_gather_results behaviour are mostly sound; minor cleanups possible.

The helpers encapsulate the distributed flow well:

- `partition_uids_across_ranks` uses deterministic sort + round-robin; good for reproducibility.
- `_move_gather_result_to_device` and `_convert_result_dtype` correctly recurse through nested containers and avoid breaking 12-bit `idxs`.
- `merge_gather_results` correctly concatenates per-param lists, aggregates metrics, and returns `None` when no UIDs succeeded.

Two minor points:

1. `merge_gather_results` computes a `target_device` but never actually uses it to move tensors; callers instead explicitly move results via `_move_gather_result_to_device`. Either:
   - Remove `target_device` from `merge_gather_results` (and the discovery logic) to avoid confusion, or
   - Apply it before returning, e.g. by calling `_move_gather_result_to_device(merged, target_device)`.
2. When called with a list of partials where all have empty `uids`, `merge_gather_results` will log and return `None` (good), but the earlier normalisation in `gather_distributed` (treating empty-`uids` partials as `None`) is now redundant. You could rely on one of these mechanisms to simplify the code.

Neither issue is correctness-critical, just opportunities to tighten the implementation.
2522-2792: gather_with_reserve distributed fallback and partial handling look consistent; be aware sequential fallback can still diverge cross-rank.

The refactored `gather_with_reserve` nicely integrates distributed gather:

- In distributed mode all ranks participate in `gather_distributed`, even if they have no UIDs, avoiding deadlocks at the all-gather barrier.
- The `return_partials` flag is propagated through to both the distributed path and the sequential fallback, and you now wrap sequential results as a single-element list to keep type contracts consistent.
- The final success-rate recomputation for both merged and partials paths keeps metrics aligned with the original primary tier.

One caveat: when distributed gather fails and you fall back to sequential, each rank independently runs `self.gather(...)` with the full `gather_uids` list. That preserves type contracts but can still result in different primary results per rank under asymmetric S3/network conditions (timeouts, transient errors), because there's no subsequent broadcast or re-synchronization of the sequential result.

If strict cross-rank consistency is required even in this fallback path, consider:

- Having only the master rank run the sequential gather, then broadcasting the resulting merged object (or partial list) via `dist_helper.broadcast_object`, or
- Ensuring all ranks perform the sequential gather but then choose a canonical one (e.g., rank 0's) to apply, again via broadcast.

If your downstream logic is already robust to per-rank variation in the "all ranks fallback" case, the current approach is acceptable; otherwise, a broadcast step would harden guarantees.
src/tplr/neurons.py (2)
48-70: outer_step dispatch and device resolution are fine, but type hints and unused cuda_device could be tidied.

The new `outer_step` dispatcher cleanly separates the "list of partials" case (incremental path) from the single `SimpleNamespace` case, and `_outer_step_incremental` correctly funnels back into the original single-result implementation. That's a good way to keep semantics aligned with catch-up.

Two small nits:

1. The signature `gather_result: SimpleNamespace | list[SimpleNamespace] | None` doesn't reflect that partial lists can (and do) contain `None` entries (from `gather_with_reserve` and `gather_distributed`). `_outer_step_incremental` handles `None` safely by filtering them out, but the type hint is misleading. Consider updating it to `SimpleNamespace | list[SimpleNamespace | None] | None`.
2. `_resolve_cuda_device` is invoked in `_outer_step_single`, but the returned `cuda_device` isn't used. Unless you plan to plug it into CUDA-specific calls soon, this is dead code and could be removed, or the helper's result should be wired into the operations that actually require a `torch.device` instead of a string.

These are minor maintainability issues; behaviour looks correct.
Also applies to: 250-314, 461-497
1029-1101: Catch-up debug comparison and memory logging are valuable but add overhead; ensure they're acceptable for your deployment.

The catch-up loop now:

- Logs detailed CUDA memory stats before, during, and after replay.
- Optionally verifies the initial checkpoint against the leader's debug dict.
- For each window, fetches and compares debug dicts, and tracks an EMA of small parameter slices to estimate "steps behind".

Functionally this is all sound and gated on `instance.is_master`, but it does add:

- Extra S3 GETs for `debug` objects at each window.
- Per-window per-parameter CPU work (slicing, EMAs, diff stats).
- Occasional `compare_model_with_debug_dict` calls.

If you run very long catch-up sequences or have large models, this may noticeably extend catch-up time. If that's a concern, consider adding a simple flag/hparam to enable/disable the per-window debug-dict comparison (keeping the memory logging, which is cheaper) so operators can turn it off in production if needed.
Also applies to: 1217-1271, 1337-1356
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- neurons/miner.py (9 hunks)
- src/tplr/comms.py (22 hunks)
- src/tplr/neurons.py (14 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
neurons/miner.py (3)
- src/tplr/chain.py (1)
  - get_commitments (264-322)
- src/tplr/distributed.py (1)
  - is_distributed (75-77)
- src/tplr/comms.py (1)
  - gather_with_reserve (2516-2792)
src/tplr/comms.py (3)
- src/tplr/schemas.py (1)
  - Bucket (24-46)
- src/tplr/logging.py (2)
  - log_with_context (207-223)
  - log_with_context (290-309)
- src/tplr/distributed.py (2)
  - is_distributed (75-77)
  - all_gather_object (307-325)
src/tplr/neurons.py (2)
- src/tplr/compress.py (5)
  - compress (373-377)
  - compress (379-383)
  - compress (386-423)
  - maybe_dequantize_values (647-697)
  - unpack_12bit_indices (102-142)
- src/tplr/comms.py (1)
  - merge_gather_results (2208-2326)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.12)
- GitHub Check: test (3.11)
🔇 Additional comments (15)
src/tplr/comms.py (1)
666-677: Offloading torch.load and raising the size threshold is a good fix.

Moving `torch.load` into `asyncio.to_thread` and increasing the "simple download" threshold to 500MB is a solid improvement: it prevents the event loop from being blocked on CPU decompression while still using the chunked downloader for truly large objects. The error logging around failed `torch.load` calls is also appropriate and keeps failures non-fatal for the calling code.

Also applies to: 709-723
src/tplr/neurons.py (4)
316-459: Incremental outer step merging via Comms.merge_gather_results is consistent and robust.

The incremental path does the right thing:

- Filters out `None` partials and sorts by `partial.rank` (or embedded metadata) for deterministic ordering.
- Uses `Comms.merge_gather_results` to merge all partials once, then applies a single outer step via `_outer_step_single`, matching catch-up semantics.
- Wraps both merge and apply in try/except blocks with clear logging and aggressive `gc.collect()` / `torch.cuda.empty_cache()` cleanup.

Given that `merge_gather_results` already computes aggregate metrics and handles the "no UIDs at all" case by returning `None`, this reuse avoids duplicating intricate merge logic in neurons.py.

Overall this is a good design choice and looks correct as long as all partials are Comms-style gather results (which they are in the current flow).
Also applies to: 579-587
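The filter-and-sort step can be illustrated with plain namespaces; `order_partials` is a hypothetical helper standing in for the pre-merge normalization described above (the actual merge lives in `Comms.merge_gather_results`):

```python
from types import SimpleNamespace

def order_partials(partials: list) -> list:
    """Drop None partials and sort the rest by rank for a deterministic merge order."""
    valid = [p for p in partials if p is not None]
    return sorted(valid, key=lambda p: getattr(p, "rank", 0))

partials = [
    SimpleNamespace(rank=1, uids=[5, 7]),
    None,  # a rank that gathered nothing
    SimpleNamespace(rank=0, uids=[2, 3]),
]
ordered = order_partials(partials)
merged_uids = [u for p in ordered for u in p.uids]  # [2, 3, 5, 7]
```

Sorting before merging is what makes the merged UID order identical regardless of which rank's partial arrived first.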
1413-1425: compare_model_with_debug_dict now uses param_avg_change more meaningfully.

The revised logic that, when `param_avg_change` is available, uses it as the denominator (with a `min_step_size` floor) instead of the raw learning rate gives a much more realistic estimate of “steps behind”:

- It respects per‑parameter observed change magnitudes.
- It still falls back safely to `learning_rate` when stored slices are mismatched or unavailable.

This should make the reported `avg_steps_behind` metric significantly more informative during catch‑up and debugging.
1468-1510: check_uid_index_overlap: partial‑aware path looks correct; index bookkeeping is subtle but consistent.

The extension to support both merged and partial gather results is well thought out:

- For partials, you build a global `uids` list and a `uid_to_partial` map of `(partial, index_in_partial)`, avoiding the previous O(n×m) scan across partials.
- For merged results, you retain the original logic by sharing one `state_dict` across all UIDs.

Inside the per‑parameter loop:

- You correctly skip UIDs that lack `idxs`/`vals` entries, and only append to `unpacked_indices` plus `valid_uids_for_param` for those that do.
- You then work in the compact P‑sized space (valid peers for that param), but store pair‑accumulation keys in terms of the original UID indices via `uid_idx_i = valid_uids_for_param[i]` / `uid_idx_j = valid_uids_for_param[j]`.
- Later, when summarizing, you map back through `uids[i]`/`uids[j]` for human‑readable pairs, which stays consistent with the keys you used in `pair_acc`.

Overall, the indexing is non‑trivial but appears internally consistent and should produce correct overlap and slashing decisions for both merged and partial inputs.

Also applies to: 1542-1593, 1601-1617
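A minimal sketch of the O(1) lookup structure described above, using plain namespaces and hypothetical field names:

```python
from types import SimpleNamespace

def build_uid_lookup(partials):
    """Build a global UID list plus a uid -> (partial, index_in_partial) map."""
    uids: list[int] = []
    uid_to_partial: dict[int, tuple[SimpleNamespace, int]] = {}
    for partial in partials:
        if partial is None:
            continue
        for i, uid in enumerate(partial.uids):
            uid_to_partial[uid] = (partial, i)
            uids.append(uid)
    return uids, uid_to_partial

p0 = SimpleNamespace(rank=0, uids=[2, 3])
p1 = SimpleNamespace(rank=1, uids=[5, 7])
uids, lookup = build_uid_lookup([p0, None, p1])
owner, idx = lookup[7]  # O(1): the owning partial and the position within it
```

One linear pass replaces the per-UID scan over all partials, which is what removes the O(n×m) cost.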
563-587: Review verified: device string support confirmed.

The codebase requires `torch>=2.7.1`, and PyTorch's device handling accepts properly formatted device strings: `tensor.to()` takes 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc. The guard and dequantization logic in the code is correctly designed for this API.

neurons/miner.py (10)
32-34: LGTM: CUDA memory allocator configuration.

Setting `expandable_segments:True` before importing torch is a good practice to reduce memory fragmentation in distributed training. The check for an existing environment variable ensures user configurations are not overwritten.
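The pattern is simply an environment write that must precede the first torch import; a sketch (the torch import itself is shown as a comment so the snippet stands alone):

```python
import os

# Must run before the first `import torch`; once the CUDA caching
# allocator initializes, this setting is ignored.
# setdefault preserves any value the operator already exported.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

# import torch  # the allocator picks up the config on CUDA initialization

configured = os.environ["PYTORCH_CUDA_ALLOC_CONF"]
```

Using `setdefault` rather than plain assignment is what keeps operator-supplied allocator configs intact.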
42-42: LGTM: Distributed training import.

Required import for the distributed gather functionality implemented in this PR.
552-559: LGTM: Pre-training memory hygiene.

Aggressively clearing CUDA cache and lingering gradients before training is a sound practice for preventing OOM in multi-GPU setups. The defensive check for model existence is appropriate.
738-802: LGTM: Distributed gather implementation with proper fallback and logging.

The gather mode selection and logging are well-structured:

- `use_distributed_gather` is correctly determined from world size and distributed status
- `return_partials=False` ensures all ranks receive a merged result (not a list of partials), consistent with the design intent from past reviews
- `local=False` correctly pulls gradients from R2 remote storage
- Detailed logging distinguishes distributed vs sequential modes with appropriate metrics
- All attribute accesses (`.uids`, `.skipped_uids`, `.success_rate`) are consistent with the merged result type
- `should_update` is broadcast to ensure all ranks agree on whether to perform the outer step

Based on past reviews and PR objectives.
831-833: LGTM: Post outer-step gradient clearing.

Clearing gradients after the outer step with `set_to_none=True` is appropriate for memory efficiency. The defensive check ensures safety.
886-893: LGTM: Safe UID extraction from merged gather result.

Uses `getattr` with defaults to safely extract UIDs and compute successful/skipped peer sets. The logic correctly derives successful peers as the set difference between all peers and skipped UIDs.
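The extraction pattern can be shown with a plain namespace; `peer_outcomes` is a hypothetical name for illustration:

```python
from types import SimpleNamespace

def peer_outcomes(gather_result, peers: list[int]):
    """Safely pull UID bookkeeping out of a merged gather result."""
    uids = getattr(gather_result, "uids", []) or []
    skipped = set(getattr(gather_result, "skipped_uids", []) or [])
    # Successful peers = requested peers minus those that were skipped.
    successful = [u for u in peers if u not in skipped]
    return uids, skipped, successful

merged = SimpleNamespace(uids=[1, 2, 4], skipped_uids=[3])
uids, skipped, successful = peer_outcomes(merged, peers=[1, 2, 3, 4])
```

The `getattr(..., default) or []` shape tolerates both missing attributes and explicit `None` values without branching.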
929-936: LGTM: Consistent success rate calculation.

Safely handles the None case explicitly and accesses `success_rate` directly from the merged result. The use of `getattr` for `skipped_uids` adds defensive robustness even though the merged result type guarantees the attribute exists.
1000-1000: LGTM: Metrics logging with skipped peers.

Correctly uses the safely extracted `skipped_uids_list` from the merged gather result. JSON encoding is appropriate for InfluxDB storage.
1017-1027: LGTM: Thorough cleanup of gather artifacts.

The cleanup correctly occurs after all uses of `gather_result` (outer step, debug dict, metrics logging). Clearing `state_dict.__dict__` is aggressive but appropriate for reclaiming memory in distributed training, as the comment confirms this is the only remaining reference.
474-490: Broadcast reserve_peers along with peers and commitments.

The code correctly broadcasts `peers` and `commitments` to all ranks, addressing the past concern about stale bucket mappings. However, `self.comms.reserve_peers` is used on line 761 but is not included in the broadcast payload. Since `update_peers` only runs on the master rank, non-master ranks will have stale or uninitialized `reserve_peers` values.

Apply this diff to include reserve_peers in the broadcast:

```diff
 # Broadcast peer list (and optionally commitments) from master to all ranks for distributed gather
 if dist_helper.world_size > 1 and dist_helper.is_distributed():
     payload = (
-        [self.comms.peers, self.comms.commitments]
+        [self.comms.peers, self.comms.commitments, self.comms.reserve_peers]
         if self.is_master
-        else [None, None]
+        else [None, None, None]
     )
     dist.broadcast_object_list(payload, src=0)
-    self.comms.peers, self.comms.commitments = payload
+    self.comms.peers, self.comms.commitments, self.comms.reserve_peers = payload
     if not self.is_master:
         tplr.logger.info(
             f"[Rank {dist_helper.rank}] Received peers from master: {self.comms.peers}"
         )
```

Likely an incorrect or invalid review comment.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/tplr/comms.py (1)
1551-1567: Local torch.load path should honor map_location for symmetry and OOM avoidance.

The local branch ignores the `map_location` parameter while the remote S3 branch correctly uses it. This asymmetry means callers requesting `map_location="cpu"` to avoid GPU OOM will only get that behavior for remote loads, not local ones.

The remote implementation pattern (lines 705-718) already shows the correct approach:

```python
device_location = (
    map_location if map_location is not None else self.config.device
)
loaded_data = await asyncio.to_thread(
    torch.load,
    temp_file_path,
    map_location=device_location,
    weights_only=True,
)
```

Apply the same fix to the local branch (around line 1564) for consistency and to honor the caller's explicit device mapping requests.
🧹 Nitpick comments (19)
neurons/miner.py (3)
552-560: Aggressive CUDA cleanup is effective but may be overly heavy on every window

Calling `torch.cuda.empty_cache()` + `torch.cuda.synchronize()` before every inner loop, plus a full `model.zero_grad(set_to_none=True)`, is a strong guardrail against OOM but adds non‑trivial latency per window.

If you find throughput becomes an issue, consider gating this behind a hyperparameter (e.g. `hparams.aggressive_cleanup`) or a less frequent cadence (every N windows) while keeping the current behavior as the safe default.
884-892: Gather debug and metrics fields line up with merged gather_result

The debug and metrics paths:

- Use `gather_result.uids` / `gather_result.skipped_uids` from the merged result.
- Derive `successful_peers` as peers − skipped, which matches the comms contract where `skipped_uids` is defined relative to the original peer list.
- Log `gather_success_rate = gather_result.success_rate * 100` and JSON‑encode `skipped_uids_list` for Influx.

This is consistent with the merged‑result semantics enforced by `return_partials=False` and should give accurate visibility into gather quality.

Also applies to: 927-933, 997-997
1014-1023: End‑of‑window gather_result cleanup is safe and targeted

The cleanup block:

- Deletes `shard_gradient`.
- For a non‑None `gather_result`, clears `gather_result.state_dict.__dict__` when it is a `SimpleNamespace`, then deletes `gather_result`.
- Forces `gc.collect()` and `torch.cuda.empty_cache()`.

Given that all uses of `gather_result` (outer_step, debug, metrics) are complete by this point and no background tasks hold references to it, this is a reasonable way to aggressively release memory between windows.

neurons/validator.py (3)
1491-1561: Distributed gather orchestration is correct; comment about partials is now misleading

The gather block now:

- Forces `use_distributed_gather = True` so all ranks participate in `gather_with_reserve`, letting the comms layer decide between distributed vs sequential internally.
- Calls `gather_with_reserve` with `gather_uids=self.comms.peers`, `reserve_uids=self.comms.reserve_peers`, `return_partials=False`, `local=False`, an extended `timeout=150`, `time_min`/`time_max`, and `expected_compressed_params`.
- Logs distributed gather completion using `gather_result.uids`, `gather_result.skipped_uids`, and `gather_result.success_rate`.

That wiring is consistent with the comms implementation and the later uses of `gather_result` (upload, overlap, slashing, outer_step).

The inline comment at Line 1517, though, still says “Request partials for incremental processing in outer_step” while `return_partials=False` ensures a merged result. To avoid confusion for future refactors (especially if partials are re‑enabled), consider updating or removing that comment:

```diff
-    # Request partials for incremental processing in outer_step
     gather_result = await self.comms.gather_with_reserve(
 @@
-        return_partials=False,  # Always return merged result (distributed gather merges internally)
+        return_partials=False,  # Always return merged result; partials are handled inside comms when enabled
```
1823-1825: compute_peer_val_norms currently only supports merged results despite broader type hint

`compute_peer_val_norms` is now typed as:

```python
def compute_peer_val_norms(
    self,
    gather_result: SimpleNamespace | list[SimpleNamespace],
) -> dict[str, torch.Tensor]:
```

but the implementation:

- Assumes `gather_result` has a `state_dict` attribute.
- Treats `state_dict` as an attribute container (`getattr(state_dict, vals_key, None)`), i.e., a `SimpleNamespace` or similar.
- Does not branch on the list‑of‑partials case.

Given that this call site passes `return_partials=False` to `gather_with_reserve`, the function is currently only invoked with a merged `SimpleNamespace`, so behavior is correct today. The signature, docstring, and nearby comment (“handles both partials and merged results”) are misleading, though, and increase the risk of someone later enabling `return_partials=True` and getting a hard failure.

Two options:

1. Narrow the contract to merged results only (simplest, current behavior):

```diff
-    def compute_peer_val_norms(
-        self,
-        gather_result: SimpleNamespace | list[SimpleNamespace],
-    ) -> dict[str, torch.Tensor]:
+    def compute_peer_val_norms(
+        self,
+        gather_result: SimpleNamespace,
+    ) -> dict[str, torch.Tensor]:
 @@
-        """
-        Compute median L2 norms across all peers for each parameter.
-        ...
-        """
+        """Compute median L2 norms across peers for each parameter using a merged gather_result."""
```

2. Or, if you want real partial support, add a branch that merges partials (or reuses `self.comms.merge_gather_results`) before the existing body.

Given the current usage, (1) is enough and avoids accidental misuse.

Also applies to: 3792-3804
1412-1439: Minor note: local use_distributed_gather flag is now purely for logging

With:

```python
use_distributed_gather = True
...
if self.is_master or use_distributed_gather:
    gather_result = await self.comms.gather_with_reserve(...)
```

the local `use_distributed_gather` no longer reflects the actual distributed/sequential mode (that decision is made inside `gather_with_reserve` based on `dist_helper.is_distributed()`). It is effectively a “call on all ranks” flag, and the `if use_distributed_gather:` log message will say “Starting distributed gather …” even in the single‑rank case.

This is harmless but a bit confusing. Consider either:

- Renaming the local flag (e.g., `all_ranks_participate = True`), or
- Restoring the previous conditional definition that mirrors `gather_with_reserve`’s `use_distributed` calculation if you want logs to distinguish true distributed from single‑rank gathers.

Functionally everything is correct; this is just about log clarity.

Also applies to: 1491-1492
src/tplr/neurons.py (4)
48-70: _resolve_cuda_device is correct but currently unused downstream.

The helper cleanly normalizes string / torch.device / None into a safe CUDA device or None, but `_outer_step_single` only assigns `cuda_device` and never uses it. Either wire this into future CUDA calls or drop the local variable to avoid confusion.

297-333: Minor clean‑up in _outer_step_single.

`cuda_device = _resolve_cuda_device(device)` is currently dead code in this function. If you don't plan to use it for CUDA‑only operations (e.g., memory stats, sync), consider removing the assignment or using `cuda_device` in any future CUDA calls to keep intent clear.

520-526: Duplicate comment line before optimizer.step.

Lines 520–521 repeat the same comment twice. Safe to drop one of them to reduce noise.

```diff
-    # ---- apply update immediately for THIS param and free its grad ----
-    # ---- apply update immediately for THIS param and free its grad ----
+    # ---- apply update immediately for THIS param and free its grad ----
```

868-912: Checkpoint verification path is useful but can be expensive.

The added debug‑dict comparison in catch‑up provides strong safety checks on checkpoint correctness, but it does an extra `get(...)` and comparison before replaying any window. For very frequent restarts or large models this might add noticeable latency. If this becomes an issue operationally, consider gating the verification behind a config flag.

src/tplr/comms.py (9)
142-145: Partition cache is fine; consider documenting eviction policy.

The `_partition_cache` keyed by `(window, sorted_uids, world_size)` with an LRU cap of 32 entries is reasonable. A short comment about expected cache size vs typical window counts would help future maintainers understand why 32 was chosen and when eviction might occur.
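One common way such a capped LRU cache is implemented (a sketch, not the actual `_partition_cache` code) is with an `OrderedDict`:

```python
from collections import OrderedDict

class PartitionCache:
    """Tiny LRU cache keyed by (window, sorted_uids, world_size)."""

    def __init__(self, max_entries: int = 32):
        self.max_entries = max_entries
        self._store: "OrderedDict[tuple, list]" = OrderedDict()

    def get_or_compute(self, key: tuple, compute):
        if key in self._store:
            self._store.move_to_end(key)  # mark as most recently used
            return self._store[key]
        value = compute()
        self._store[key] = value
        if len(self._store) > self.max_entries:
            self._store.popitem(last=False)  # evict least recently used
        return value

cache = PartitionCache(max_entries=2)
cache.get_or_compute(("w1",), lambda: [1])
cache.get_or_compute(("w2",), lambda: [2])
cache.get_or_compute(("w3",), lambda: [3])  # evicts ("w1",)
```

Since each window typically touches one key, a cap of 32 roughly bounds the cache to the last few dozen windows.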
483-582: s3_object_exists error classification looks solid; consider purging bad clients on hard errors.

The structured handling of 404, auth/bucket config errors, and transient 5xx/network issues is a good improvement over string matching. For consistency with other S3 paths, you might also call `_purge_s3_client(bucket)` on non‑transient `ClientError`s (e.g., invalid credentials/bucket) to force recreation on the next call, though this is optional.
2071-2096: partition_uids_across_ranks is correct but currently unused.

The static round‑robin helper does exactly what `gather_distributed`'s inline partitioning does. To reduce duplication, consider having `gather_distributed` call `partition_uids_across_ranks` instead of reimplementing the slice logic with its own cache (or remove the helper if you intend to rely solely on the cached path).
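Deterministic round-robin partitioning over sorted UIDs can be expressed in a few lines; this sketch mirrors what the review describes rather than quoting the actual implementation:

```python
def partition_uids_across_ranks(uids: list[int], world_size: int) -> list[list[int]]:
    """Round-robin UIDs over ranks. Sorting first means every rank
    computes the identical partitioning without any communication."""
    ordered = sorted(uids)
    return [ordered[rank::world_size] for rank in range(world_size)]

parts = partition_uids_across_ranks([7, 1, 4, 9, 2], world_size=2)
# rank 0 -> [1, 4, 9], rank 1 -> [2, 7]
```

Because the slice `ordered[rank::world_size]` depends only on the sorted input and the rank, the partition is both deterministic and collision-free across ranks.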
2133-2257: merge_gather_results: semantics look right; minor device/perf nits.

- Correctly filters out `None` partials, aggregates all param keys, concatenates per‑param lists, and recomputes metrics and success_rate.
- Device selection tries the current CUDA device, with a safe CPU fallback.

Two small suggestions:

- There's a second `move_to_device` here that largely duplicates `_move_gather_result_to_device`; you could reuse the static helper for consistency.
- When `gather_distributed`'s local‑merge path calls `merge_gather_results([partial_result])` and then immediately moves to CPU, you end up moving tensors twice (GPU→GPU via `to(target_device)` then GPU→CPU). Passing `target_device="cpu"` from that call would avoid an extra copy.
2318-2364: Local merge + CPU move before all_gather mostly achieves the memory goal.

Merging per‑rank results locally when `return_partials=False` and then moving that partial to CPU before `all_gather_object` addresses the multi‑rank GPU OOM scenario. To minimize device churn, consider invoking `merge_gather_results([partial_result], target_device="cpu")` directly so tensors are moved only once (GPU→CPU) instead of GPU→GPU then GPU→CPU.
2404-2438: Merged‑result path is correct; minor duplication in logging.

Merging `all_partials` on every rank ensures consistent behavior; moving the final result back to `device` only when non‑CPU and CUDA‑available is reasonable. The `log_with_context` call on rank 0 provides a good summary. Note: `gather_time` and `merge_time` referenced in the final log come from the local rank and may differ slightly across ranks; that's acceptable but could be renamed to clarify they're per‑rank.
2479-2494: Entry logging is very verbose; consider truncating UID lists.

Logging `[gather_with_reserve][Rank ...] ENTRY | gather_uids={gather_uids}, reserve_uids={reserve_uids}` at INFO can be quite large when many peers are involved. If logs become noisy, you might want to limit this to counts and maybe the first few UIDs.
2507-2557: Distributed integration and sequential fallback logic are mostly sound; note potential per‑rank divergence.

- Using `gather_distributed` when `world_size > 1` and distributed is initialized is the right default.
- The post‑call check treating `None` and all‑empty partials as failure is good.

However, when the distributed path fails, each rank independently runs the sequential `gather(...)` with the full `gather_uids` list. Because these are separate S3 read patterns, ranks can in theory end up with slightly different results under flaky network conditions. If strict cross‑rank consistency is required, consider broadcasting a single primary result from the master rank after sequential fallback.
2571-2585: Empty-primary normalization is correct; partials vs merged behavior is slightly asymmetric.

When `primary is None`, you create an empty shell and, if `return_partials=True`, wrap it in a list to maintain type consistency. For the case where `return_partials=True` and `sequential_result` is `[None]`, `primary` skips this path and remains `[None]`; later logic converts that into `None` via `total_uids == 0`. Behavior is still correct but slightly harder to reason about; a brief comment explaining the two “no gradients” paths would help.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- neurons/miner.py (9 hunks)
- neurons/validator.py (13 hunks)
- src/tplr/comms.py (19 hunks)
- src/tplr/neurons.py (11 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
neurons/validator.py (6)
- tests/test_state_loading.py (1): rating (59-62)
- src/tplr/chain.py (1): get_commitments (264-322)
- src/tplr/distributed.py (2): is_distributed (75-77), broadcast (273-282)
- src/tplr/comms.py (2): gather_with_reserve (2443-2721), get_with_retry (1625-1740)
- src/tplr/neurons.py (1): check_uid_index_overlap (1281-1486)
- src/tplr/schemas.py (1): success (63-65)

neurons/miner.py (3)
- src/tplr/chain.py (1): get_commitments (264-322)
- src/tplr/distributed.py (1): is_distributed (75-77)
- src/tplr/comms.py (1): gather_with_reserve (2443-2721)

src/tplr/comms.py (3)
- src/tplr/schemas.py (1): Bucket (24-46)
- src/tplr/logging.py (2): log_with_context (207-223), log_with_context (290-309)
- src/tplr/distributed.py (2): is_distributed (75-77), all_gather_object (307-325)

src/tplr/neurons.py (1)
- src/tplr/compress.py (4): compress (373-377), compress (379-383), compress (386-423), unpack_12bit_indices (102-142)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.11)
- GitHub Check: test (3.12)
🔇 Additional comments (26)
neurons/miner.py (2)
830-833: Post‑outer‑step gradient clearing is in the right place

Zeroing gradients with `model.zero_grad(set_to_none=True)` immediately after a successful outer step is a good safeguard against stale grads leaking into the next window, especially now that outer steps are conditional on `should_update`. This fits cleanly with the additional cleanup in `cleanup_window()`.
738-775: Verification confirms distributed gather wiring is consistent across all call sites

The sanity check confirms:

- Miner.py (lines 758–771): correctly uses `local=False`, passes `time_min`/`time_max`, includes `expected_compressed_params`, sets `return_partials=False` with `timeout=90`
- Validator.py (lines 1517–1532): identically configured with `local=False`, `time_min`/`time_max`, `expected_compressed_params`, `return_partials=False` but with `timeout=150` (appropriate for validator workload)
- The test call appropriately uses minimal parameters for its mocked context

Both production call sites implement the same parameter discipline and avoid partial-list vs merged-result pitfalls.
neurons/validator.py (5)
1338-1370: Peers, reserve_peers, and commitments broadcast keeps all ranks aligned for distributed gather

The new logic:

- Refreshes `self.comms.commitments` on the master each window.
- Broadcasts `[peers, reserve_peers, commitments]` from master to all ranks using `dist.broadcast_object_list`.
- Updates `self.comms.peers`, `self.comms.reserve_peers`, and `self.comms.commitments` on non‑master ranks and logs what was received.

This directly addresses the stale‑commitments / stale‑peer‑set risk for distributed gather and ensures all ranks use the same bucket mappings and reserve lists when calling `gather_with_reserve`.
1412-1439: Timestamp broadcast ensures consistent time_min/time_max across ranks

Only the master now:

- Calls `self.query_block_timestamp(sync_block)`.
- Falls back to `time.time()` on failure with a warning.

Then `ts_value` is broadcast via a `torch.float64` tensor to all ranks, and `time_min`/`time_max` are derived from the shared value.

This removes the previous possibility of rank‑dependent time windows for `gather_with_reserve(time_min=time_min, time_max=time_max)` and is the right pattern for distributed timing semantics.
1607-1645: Master‑only post‑gather processing and offload flow looks coherent

In the master‑only block when `gather_result` is not `None`:

- You upload the already‑merged gather results via `upload_gather_results(gather_result)`.
- Run `check_uid_index_overlap` on the merged result and apply `slash_from_overlap`.
- Call `offload_gather_results(gather_result, log=True)` to move its `state_dict` to pinned CPU, freeing GPU memory before evaluation.
- Extract `skipped_uids` and `success_rate` once, then barrier on `"post_gather"`.

This sequencing:

- Ensures all heavy gather‑side work is done before evaluation and outer‑step.
- Leaves `gather_result` valid (albeit offloaded to CPU) for later uses: `compute_peer_val_norms`, slashing paths, outer_step, and metrics.

Given `offload_gather_results` handles both `SimpleNamespace` and dict `state_dict` layouts, this looks safe.
3198-3213: evaluate_miner_sync: get_with_retry integration and error handling look solid

The new `evaluate_miner_sync` logic:

- Uses `self.comms.get_with_retry(..., timeout=60, local=False, stale_retention=10)` to fetch the miner's `debug` dict, which is a better fit for transient S3/R2 hiccups than a single `get`.
- Correctly handles both timeout (`debug_result is None`) and unsuccessful responses (`not debug_result.success`), logging a warning that includes `getattr(debug_result, "error", "unknown")` without risking an AttributeError when `debug_result` is `None`.
- Falls back to a simple `{success: False, error: "...", sync_score: 0.0}` payload, which upstream slashing logic can safely consume.

This should reduce spurious sync‑slash events due to transient infra issues without introducing new failure modes.
932-938: OpenSkill rating initialization is required by the PlackettLuce model

Initializing `self.openskill_ratings[uid]` for all `window_uids` before calling `self.openskill_model.rate(...)` is not just an optimization; it is a requirement. The PlackettLuce.rate function expects teams to be lists of PlackettLuceRating objects (i.e. pre-initialized rating objects), and passing raw player values instead will raise a TypeError.

This change correctly ensures:
- Every UID has a properly initialized rating object before being passed to the model.
- The code avoids runtime TypeErrors that would otherwise occur.
- Ratings are stable and predictable as new peers appear.
src/tplr/neurons.py (5)
249-295: outer_step wrapper looks good and preserves original semantics.

Delegating to `_outer_step_single` keeps the public API stable while centralizing the implementation. Given that distributed gather now always delivers a merged result here, the updated docstring is accurate.
1084-1147: Per‑window debug‑dict comparison logic is sound.

The EMA of parameter‑slice deltas (`param_avg_change`) combined with `compare_model_with_debug_dict` gives a reasonable estimate of “steps behind” during catch‑up. The handling of DTensor vs non‑DTensor parameters and the index slicing guards (`numel() < 2`) look correct.
1225-1240: Stabilized step‑ratio calculation in compare_model_with_debug_dict is a good fix.

Using `max(param_avg_change, learning_rate)` as the denominator avoids explosive step ratios when the EMA of parameter change becomes tiny, while still falling back to `learning_rate` if the stored slice has mismatched length. This should make the “steps behind” metric much more robust.
1281-1322: Partial‑aware handling in check_uid_index_overlap looks correct and fixes O(n×m) scan.

The new `uid_to_partial: dict[int, tuple[SimpleNamespace, int]]` map lets you locate a UID's partial and index in O(1) instead of scanning all partials per UID. The merged‑vs‑partial branching and `all_state_dicts` construction are consistent, and `Ptot` remains the full UID count in both cases.
1355-1431: Overlap computation for partial results is logically consistent.

Using `valid_uids_for_param` to map back from the compact per‑parameter index (0..len(unpacked_indices)-1) to the original UID index, and then keying `pair_acc` by those original indices, preserves the semantics of later reporting via `uids[i], uids[j]`. Early continues when `idxs_list`/`vals_list` are missing or too short correctly skip malformed entries.

src/tplr/comms.py (14)
666-695: 500MB threshold for multipart downloads is reasonable.

Using a 500MB cutoff to switch from a simple `get_object` to the parallel `download_large_file` implementation aligns with the goal of handling large model gradients robustly, and the dynamic `download_timeout` scaling by size is sensible.
1714-1737: Resetting consecutive_not_found on non‑NOT_FOUND statuses is correct.

Once the above distinction is addressed, the logic to reset `consecutive_not_found` on other statuses should keep the early‑exit optimization from triggering on mixed error patterns.
1815-1990: gather's validation and logging additions look good.

The added per‑UID logging for exceptions (`log_with_context` at warning level) and the NaN/Inf checks on non‑quantized `vals` (moving only for validation) strengthen robustness without changing semantics. The `expected_compressed_params` cross‑check and quant_params presence check remain consistent.
2097-2131: _move_gather_result_to_device is straightforward and matches usages.

The recursive `move_to_device` helper correctly handles tensors, lists/tuples, and dicts under `result.state_dict`. In‑place mutation of the namespace is fine here and fits how `gather_distributed` moves results to CPU/GPU before/after collectives.
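Structurally, such a recursive move boils down to a nested map; in this sketch the tensor move is abstracted as a leaf function (with real tensors, `fn` would be something like `lambda t: t.to(device) if torch.is_tensor(t) else t`):

```python
def map_nested(obj, fn):
    """Apply fn to every leaf of nested dicts/lists/tuples, preserving structure."""
    if isinstance(obj, dict):
        return {k: map_nested(v, fn) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(map_nested(v, fn) for v in obj)
    return fn(obj)  # leaf: tensors would be moved here

state = {"layer.vals": [1, 2], "meta": ("a", {"n": 3})}
doubled = map_nested(state, lambda x: x * 2)
```

Keeping list-vs-tuple types intact matters because downstream consumers may index or unpack the containers by shape.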
2259-2313: Distributed UID partitioning and caching in gather_distributed are reasonable.

The `(window, sorted_uids, world_size)` cache key plus LRU eviction keeps partitions deterministic and avoids recomputing slices every step. Using sorted UIDs ensures all ranks see the same partitioning, and the `rank < len(partitions)` guard handles any mismatch gracefully.
2365-2385: all_gather_object usage and normalization look correct.

Using `dist_helper.all_gather_object(partial_result)` guarded by both `world_size > 1` and `dist_helper.is_distributed()` is appropriate. Normalizing empty payloads to `None` based on missing `uids` simplifies downstream checks when deciding whether any rank produced gradients.
2387-2402: return_partials path behavior is consistent and type‑safe.

Returning the full `all_partials` list (possibly containing `None`s) with a `rank` attribute set for each non‑None entry gives incremental consumers enough metadata to sort or bucket by source rank. Logging only on rank 0 avoids redundant logs.
2443-2477: gather_with_reserve signature/docs match new behavior.

Adding `return_partials` and updating the return type docstring to allow `list[SimpleNamespace | None]` is aligned with `gather_distributed`'s contract and with the incremental processing use‑case.
2586-2615: Primary‑gather logging for partials vs merged results is clear and correct.

The split handling of `primary` as list vs namespace, including computing `missing` differently, matches the two modes. Logging via `log_with_context` with both counts and skipped UIDs should be useful when debugging gather quality.
2618-2676: Reserve retry logic integrates well with partials and merged paths.

- Filtering `reserve_uids` to exclude already‑used primary UIDs avoids duplicates.
- Using sequential `gather` for the reserve tier is a reasonable compromise in complexity.
- For partials, appending `fallback` as a new partial and assigning `fallback.rank = len(primary)` preserves deterministic ordering; for merged results, in‑place tensor‑list merging and metric aggregation look correct.
2678-2721: Final success‑rate recomputation and “no gradients” handling are correct.

Recomputing success_rate against the original `gather_uids` length and propagating it to each partial (or to the merged result) makes metrics easier to interpret. Returning `None` when `total_uids == 0` (both in partial and merged branches) ensures downstream callers don't have to special‑case an empty structure.
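The recomputation amounts to dividing successes by the originally requested UID count; a hypothetical helper:

```python
def recompute_success_rate(successful_uids: list[int], gather_uids: list[int]) -> float:
    """Success rate relative to the original request, not the surviving set."""
    if not gather_uids:
        return 0.0
    return len(successful_uids) / len(gather_uids)

rate = recompute_success_rate([1, 2, 3], gather_uids=[1, 2, 3, 4])  # → 0.75
```

Normalizing by the original request is what keeps the metric comparable across windows even when reserve peers back-fill some of the misses.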
1471-1493: gradient_timestamp's version arg extension is safe and backwards compatible.

Defaulting `version` to `tplr.__version__` when `None` preserves existing behavior while allowing callers to probe other versions explicitly. Exception handling still purges the S3 client and returns 0.0 on failure.
3071-3116: get_start_window version support aligns with gradient_timestamp changes.

Using `target_version = version if version else tplr.__version__` keeps existing behavior and lets callers request bootstrap/alternate versions. The call into `s3_get_object` is consistent with other JSON fetch paths.
697-724: No issues detected: the weights_only fix is valid and compatible.

The `weights_only` parameter has long been available in `torch.load` (and became the default in PyTorch 2.6), and the codebase requires `torch>=2.7.1` (per `pyproject.toml`). A dict like `{"state_dict": <state_dict>, "global_step": 123}` (where state_dict is a mapping of tensors) loads correctly with `weights_only=True`, which matches exactly what the code saves and loads. Errors are already handled gracefully (logged and result in `None`).
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/tplr/comms.py (1)
1551-1570: Apply map_location to local torch.load path (lines 1564-1566)

The local storage branch does not respect the `map_location` parameter documented in the function signature and docstring. While the remote path (via `s3_get_object()`) correctly computes `device_location` and passes it to `torch.load()` (lines 705-718), the local branch omits `map_location` entirely. This creates a behavioral inconsistency where local loads ignore the caller's device preference and device fallback logic.

```diff
-            loaded_data = await asyncio.to_thread(
-                torch.load, local_path, weights_only=True
-            )
+            device_location = (
+                map_location if map_location is not None else self.config.device
+            )
+            loaded_data = await asyncio.to_thread(
+                torch.load,
+                local_path,
+                map_location=device_location,
+                weights_only=True,
+            )
```
♻️ Duplicate comments (1)
src/tplr/comms.py (1)
1665-1737: NOT_FOUND retry cutoff still conflates 404 with transient failures

`get_with_retry` now tracks `consecutive_not_found` and aborts after 5 occurrences (~2.5 s at 0.5 s sleep), but `get()` still maps all `s3_get_object` `None` returns (404, timeouts, and some `ClientError` cases) to `status="NOT_FOUND"`. This means slow or flaky peers can be treated as "likely doesn't exist" long before the caller's `timeout` or `time_max` window ends.

To avoid violating the timeout semantics, you likely still need to either:

- Distinguish true 404s from transient errors in `get()`/`s3_get_object` and only increment `consecutive_not_found` for confirmed 404s, or
- Derive `max_consecutive_not_found` from `timeout` and/or `time_max` and apply an early cutoff only once you're reasonably sure the gradient should already have appeared.

This aligns with the earlier review concern about premature NOT_FOUND cutoffs.
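The suggested separation can be sketched as a status-aware retry loop. This is illustrative only: the `GetStatus` enum and `fetch` callable are hypothetical stand-ins, not the repo's actual `CommsGetResult` API; the point is that only confirmed 404s feed the early cutoff while transient errors keep retrying until the deadline.

```python
import time
from enum import Enum, auto


class GetStatus(Enum):
    OK = auto()
    NOT_FOUND = auto()   # confirmed 404: the object really is absent
    TRANSIENT = auto()   # timeout / 5xx / network error: keep retrying


def get_with_retry(fetch, timeout=30.0, max_consecutive_not_found=5, sleep=0.5):
    """Retry until success, a run of confirmed 404s, or the timeout expires."""
    deadline = time.monotonic() + timeout
    consecutive_not_found = 0
    while time.monotonic() < deadline:
        status, payload = fetch()
        if status is GetStatus.OK:
            return payload
        if status is GetStatus.NOT_FOUND:
            consecutive_not_found += 1
            if consecutive_not_found >= max_consecutive_not_found:
                return None  # object very likely does not exist; stop early
        else:
            # Transient errors do NOT count toward the 404 cutoff.
            consecutive_not_found = 0
        time.sleep(sleep)
    return None
```

With this split, a flaky peer that alternates between timeouts and misses never trips the early cutoff, which preserves the caller's `timeout` contract.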
🧹 Nitpick comments (7)
neurons/validator.py (4)
1491-1493: Consider making validator distributed‑gather opt‑in via config

Right now `use_distributed_gather` is hard‑coded to `False`, so validators will never exercise the distributed gather path even when running multi‑GPU. If you foresee wanting to experiment with distributed gather for validators, consider wiring this to a config/hparam while keeping sequential as the default, e.g.:

```diff
-        # Validators use sequential gather (master only) for stability
-        # Distributed gather is only for miners
-        use_distributed_gather = False
+        # Validators default to sequential gather (master only) for stability.
+        # Allow optional opt‑in to distributed gather when running multi‑GPU.
+        use_distributed_gather = bool(
+            dist_helper.is_distributed()
+            and getattr(self.hparams, "validator_distributed_gather", False)
+        )
```

If validators are intentionally constrained to sequential gather long‑term, this can be left as‑is.
1517-1525: Comment no longer matches `return_partials=False` behavior

The inline comment says "Request partials for incremental processing in outer_step", but the call now explicitly uses `return_partials=False` and expects a merged `gather_result`. To avoid confusion for future maintainers, it would be better to align the comment with the actual behavior, for example:

```diff
-            # Request partials for incremental processing in outer_step
+            # Request a merged result; incremental processing happens inside outer_step
@@
-            return_partials=False,  # Always return merged result (distributed gather merges internally)
+            return_partials=False,  # Always return merged result
```

The increased timeout to 150s for gradient gathering looks reasonable.
1545-1562: Distributed‑gather completion logging is clear; sequential path relies on later metrics

The new logging for the distributed gather branch (success vs failure with uids/skipped/success_rate) is well structured. Note that sequential gather only logs start and later "Skipped UIDs" plus numeric metrics in the evaluation section; if you want symmetry in textual logs, you could optionally add a similar "gather complete" summary in the sequential path using `gather_result.{uids,skipped_uids,success_rate}`.
1823-1826: `compute_peer_val_norms` type hint/docstring vs implementation

At the call site and in practice you always pass a merged `SimpleNamespace` from `gather_with_reserve(return_partials=False)`, but the signature advertises `gather_result: SimpleNamespace | list[SimpleNamespace]` and the nearby comment says it "handles both partials and merged results" while the body assumes a merged object with `.state_dict`.

To avoid misleading callers (and accidental `AttributeError` if someone passes a list of partials later), I'd suggest either:

- Narrowing the type hint and docstring to merged‑only, matching current behavior:

```diff
-    def compute_peer_val_norms(
-        self,
-        gather_result: SimpleNamespace | list[SimpleNamespace],
-    ) -> dict[str, torch.Tensor]:
-        """
-        Compute median L2 norms across all peers for each parameter.
-
-        Args:
-            gather_result: Merged gather result SimpleNamespace
-        """
+    def compute_peer_val_norms(
+        self,
+        gather_result: SimpleNamespace,
+    ) -> dict[str, torch.Tensor]:
+        """
+        Compute median L2 norms across all peers for each parameter
+        from a merged gather result.
+        """
```

or

- Adding an explicit `if isinstance(gather_result, list): ...` branch that merges/normalizes partials before the current logic.

Given the current usage in this file, tightening the annotation/docstring is the minimal, safe change.
Also applies to: 3791-3803
src/tplr/neurons.py (1)
48-70: CUDA device resolver is robust but currently unused in this path

The `_resolve_cuda_device` helper correctly normalizes `str | torch.device | None` and handles lack of CUDA gracefully. In `_outer_step_single` the computed `cuda_device` is not used, so this helper is effectively dead code here. Either wire it into places that need a normalized CUDA device or drop the local `cuda_device` assignment to avoid confusion.

src/tplr/comms.py (2)
2068-2257: Distributed gather helpers maintain UID alignment; minor reuse/cleanup opportunities

The new helpers look sound overall:

- `partition_uids_across_ranks` and the `_partition_cache` (bounded LRU) give deterministic, cached round‑robin UID splits.
- `_move_gather_result_to_device` and `merge_gather_results` correctly recurse through tensors/lists/tuples/dicts and preserve the invariant that `merged.state_dict[<param>][i]` corresponds to `merged.uids[i]` across concatenated partials.
- `target_device` selection handles `current_device()` failures and falls back to CPU.

Two small cleanups you might consider:

- `gather_distributed` currently re‑implements the round‑robin partitioning instead of using `partition_uids_across_ranks`; calling the static helper here would reduce duplication and ensure a single source of truth for partitioning.
- `move_to_device` is implemented twice (once here and once in `_move_gather_result_to_device`); factoring it into a shared utility would avoid divergence in future changes.

Functionally this looks correct and matches the new distributed gather design.
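The alignment invariant described above can be illustrated with a minimal merge over plain Python lists, as a simplified stand-in for the tensor-valued state dicts. `merge_partials` here is illustrative only, not the repo's `merge_gather_results`:

```python
from types import SimpleNamespace


def merge_partials(partials):
    """Concatenate per-rank partial gather results, keeping each state_dict
    entry positionally aligned with the merged uids list."""
    merged_uids, merged_state = [], {}
    for partial in partials:
        if partial is None:  # empty partials are normalized to None
            continue
        merged_uids.extend(partial.uids)
        for key, values in partial.state_dict.items():
            # invariant: the value appended for partial.uids[i] lands at the
            # same offset in merged_state[key] as that uid in merged_uids
            merged_state.setdefault(key, []).extend(values)
    return SimpleNamespace(uids=merged_uids, state_dict=merged_state)
```

A quick usage check: after merging, `merged.state_dict[key][merged.uids.index(uid)]` still returns the gradient that `uid` contributed, regardless of which rank fetched it.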
2443-2719: `gather_with_reserve` correctly integrates distributed + partial flows; sequential fallback is coherent

The extended `gather_with_reserve` logic appears consistent:

- In distributed mode it defers to `gather_distributed` and handles both merged and partial return shapes via `return_partials`.
- Sequential fallback is triggered only when distributed gather yields no gradients or raises; for `return_partials=True` it wraps results in a one‑element list to preserve the type contract.
- Primary and reserve results are merged correctly for both merged and partial modes (reserve partial appended as a new element vs. in‑place list concatenation for merged).
- Final success rates are recomputed against the original `gather_uids`, and the function returns `None` when no gradients were successfully gathered.

One trade‑off is that in distributed fallback, every rank independently performs the sequential gather even though only rank 0's result is actually consumed by `outer_step` (non‑masters rely on broadcast there). If sequential fallback becomes common, you might later consider an optimization where only the master gathers and then broadcasts the result to other ranks via `dist_helper.broadcast_object`. For now, the behavior is correct, just somewhat redundant.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- `neurons/validator.py` (12 hunks)
- `src/tplr/comms.py` (19 hunks)
- `src/tplr/neurons.py` (10 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
src/tplr/neurons.py (1)
src/tplr/compress.py (4)
- `compress` (373-377)
- `compress` (379-383)
- `compress` (386-423)
- `unpack_12bit_indices` (102-142)
neurons/validator.py (6)
tests/test_state_loading.py (1)
- `rating` (59-62)
src/tplr/chain.py (1)
- `get_commitments` (264-322)
src/tplr/distributed.py (2)
- `is_distributed` (75-77)
- `broadcast` (273-282)
src/tplr/logging.py (2)
- `log_with_context` (207-223)
- `log_with_context` (290-309)
src/tplr/comms.py (2)
- `gather_with_reserve` (2443-2719)
- `get_with_retry` (1625-1740)
src/tplr/neurons.py (1)
- `check_uid_index_overlap` (1275-1480)
src/tplr/comms.py (3)
src/tplr/schemas.py (1)
- `Bucket` (24-46)
src/tplr/logging.py (2)
- `log_with_context` (207-223)
- `log_with_context` (290-309)
src/tplr/distributed.py (2)
- `is_distributed` (75-77)
- `all_gather_object` (307-325)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.12)
- GitHub Check: test (3.11)
🔇 Additional comments (10)
neurons/validator.py (5)
931-937: Initialize OpenSkill ratings before `rate()` – good defensive change

Pre‑initializing `self.openskill_ratings[uid]` for all `window_uids` avoids missing‑rating issues and keeps the later `teams`/`rate()` call robust. Implementation is straightforward and side‑effect free for already‑rated peers.
1337-1361: Peer/commitment broadcast across ranks is correctly guarded

Refreshing `self.comms.commitments` on master and broadcasting `[peers, reserve_peers, commitments]` via `dist.broadcast_object_list` only when `world_size > 1 and dist_helper.is_distributed()` ensures all ranks share a consistent view for distributed operations without touching the non‑distributed path. This wiring looks correct.
1411-1438: Timestamp broadcast keeps gather time window rank‑consistent

Centralizing the `query_block_timestamp(sync_block)` call on master, falling back to `time.time()` once, and then broadcasting `ts_value` to all ranks before computing `time_min`/`time_max` removes the prior possibility of divergent time windows between ranks. The control flow and device placement for `ts_tensor` look sound.

Also applies to: 1440-1447
3199-3213: Switch to `get_with_retry` for debug dict fetch improves robustness

Using `self.comms.get_with_retry(..., timeout=60, ...)` and handling both `None` and unsuccessful results in one branch is a solid improvement over a single `get` call. The `getattr(debug_result, "error", "unknown")` usage safely covers the `None` case without raising, and returning a `{"success": False, "sync_score": 0.0}` stub integrates cleanly with `slash_for_poor_sync`'s existing `avg_steps_behind`/`success` handling.
3791-3827: Merged‑result norm computation logic looks correct

Given a merged `gather_result.state_dict` holding per‑peer `vals`/`quant_params` lists, dequantizing via `maybe_dequantize_values`, stacking per‑peer L2 norms, and taking the median per parameter is a reasonable and numerically stable way to build `clip_norm_dict`. This integrates cleanly with the later fallback `clip_norm_dict.get(vals_key, eval_norm)` in `update_model_with_gradient`.

src/tplr/neurons.py (2)
249-552: `outer_step` ➝ `_outer_step_single` refactor preserves single-result semantics

The split of `outer_step` into a thin wrapper and `_outer_step_single` keeps the previous behavior: only the source rank consumes `gather_result`, gradients are built and applied per-param to minimize memory, and non-source ranks receive grads via DTensor scatter or broadcast. Fingerprint accumulation and W&B logging remain gated on `on_src`. I don't see new correctness or DDP edge cases introduced by this refactor.
1274-1480: Overlap checker now correctly supports partial results and fixes UID/index mapping

The changes to `check_uid_index_overlap` look solid:

- Accepting `gather_result: SimpleNamespace | list[SimpleNamespace]` matches the new distributed/partial gather API.
- The `uid_to_partial` map and `valid_uids_for_param` array remove the O(n×m) search and correctly track which original UID each unpacked index tensor corresponds to.
- Using `(uid_idx_i, uid_idx_j)` as keys into `pair_acc` ensures overlap stats and offender selection are computed in the global UID index space, even when some peers lack a given parameter.

This should resolve the earlier complexity and misalignment issues without changing the external metrics contract.
src/tplr/comms.py (3)
483-582: `s3_object_exists` implements conservative, structured error handling correctly

The new `s3_object_exists` helper cleanly distinguishes 404/NoSuchKey, permission/config errors, and transient/network/service issues using `ClientError.response`. Returning False only for "definitely missing or misconfigured" and True for transient failures matches the intended conservative semantics for slashing and availability checks. No functional issues from this implementation.
666-724: Large-file handling and async `torch.load` offloading look correct

Raising the simple-download threshold to 500 MB and delegating larger objects to `download_large_file` with a size-scaled timeout is reasonable for 70B-class gradients. Offloading `torch.load` to a thread via `asyncio.to_thread` (with `map_location=device_location` and `weights_only=True`) avoids blocking the event loop during CPU decompression without changing load semantics. This block looks correct.
2259-2439: gather_distributed flow and type contracts look consistent for merged and partial results
`gather_distributed`'s behavior matches its signature and the downstream expectations:

- Deterministic, cached partitioning of sorted UIDs across ranks.
- Per‑rank `gather()` only when `my_uids` is non‑empty, with local pre‑merge when `return_partials=False` to reduce per‑rank memory.
- Moving partials to CPU before `all_gather_object` to avoid multi‑GPU OOMs, then optionally moving the final merged result back to the requested device.
- When `return_partials=True`, returning a `list[SimpleNamespace | None]` with `rank` metadata set per entry; when False, returning a single merged `SimpleNamespace` or `None` when no UIDs were fetched.

The normalization of empty partials to `None` and the summary logging on rank 0 are also consistent with `gather_with_reserve`'s expectations. I don't see correctness issues in this orchestration.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
neurons/validator.py (1)
3790-3826: Type signature doesn't match implementation.

The signature indicates `gather_result` can be `SimpleNamespace | list[SimpleNamespace]` (line 3792), but the implementation only handles the merged case (lines 3805-3824). There's no check for `isinstance(gather_result, list)` before accessing `gather_result.state_dict` at line 3805.

While `return_partials=False` at line 1520 means this function currently only receives merged results, the type signature is misleading.

Either update the implementation to handle both cases or correct the type signature:

Option 1: Match implementation to signature

```diff
     def compute_peer_val_norms(
         self,
         gather_result: SimpleNamespace | list[SimpleNamespace],
     ) -> dict[str, torch.Tensor]:
+        # Handle list of partials by merging first
+        if isinstance(gather_result, list):
+            gather_result = self.comms.merge_gather_results(gather_result)
+
         clip_norm_dict = {}
         state_dict = gather_result.state_dict
```

Option 2: Correct the signature to match implementation

```diff
     def compute_peer_val_norms(
         self,
-        gather_result: SimpleNamespace | list[SimpleNamespace],
+        gather_result: SimpleNamespace,
     ) -> dict[str, torch.Tensor]:
```
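The median-of-per-peer-norms aggregation that `compute_peer_val_norms` performs can be illustrated with a plain-Python analog, using `statistics.median` and lists in place of torch tensors. The key naming (`f"{name}vals"`) and the helper itself are illustrative assumptions, not the repo's exact code:

```python
import math
from statistics import median


def peer_val_norms(state_dict, param_names):
    """Median of per-peer L2 norms for each parameter's `vals` entry."""
    clip_norm_dict = {}
    for name in param_names:
        vals_key = f"{name}vals"          # maps to a per-peer list of value vectors
        peer_vals = state_dict.get(vals_key)
        if not peer_vals:
            continue  # no peer contributed this parameter
        # One L2 norm per peer, then the median across peers as the clip bound.
        norms = [math.sqrt(sum(v * v for v in vals)) for vals in peer_vals]
        clip_norm_dict[vals_key] = median(norms)
    return clip_norm_dict
```

Because the median is robust to a single peer uploading an inflated gradient, it is a natural choice for a per-parameter clipping threshold.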
🧹 Nitpick comments (2)
neurons/validator.py (2)
1350-1360: Consider logging commitments dict size for monitoring.

Broadcasting the commitments dict (line 1353) to all ranks is necessary for distributed gather, but the dict can be large (up to 256 UIDs × bucket metadata). While `broadcast_object_list` uses pickle serialization which can be slow for large objects, this operation happens once per window so it's not a hot path.

Consider adding debug logging to monitor the size:

```diff
         payload = (
             [self.comms.peers, self.comms.reserve_peers, self.comms.commitments]
             if self.is_master
             else [None, None, None]
         )
+        if self.is_master:
+            import sys
+            size_mb = sys.getsizeof(payload) / 1e6
+            tplr.logger.debug(f"Broadcasting payload size: {size_mb:.2f} MB")
         dist.broadcast_object_list(payload, src=0)
```
1490-1506: Consider adding a fallback mechanism for distributed gather failures.

`use_distributed_gather` is hardcoded to `True` (line 1490), and the `elif self.is_master` branch at line 1500 for sequential gather is unreachable. If distributed gather encounters issues (e.g., rank failures, communication timeouts), there's no automatic fallback to sequential mode.

Consider adding a configuration option or automatic fallback:

```diff
-        use_distributed_gather = True
+        use_distributed_gather = getattr(self.hparams, "use_distributed_gather", True)
```

Or add error handling to fall back to sequential on distributed failure (though this would require re-architecting the gather flow to retry).
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- `neurons/validator.py` (12 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
neurons/validator.py (7)
tests/test_state_loading.py (1)
- `rating` (59-62)
src/tplr/chain.py (1)
- `get_commitments` (264-322)
src/tplr/distributed.py (2)
- `is_distributed` (75-77)
- `broadcast` (273-282)
src/tplr/logging.py (2)
- `log_with_context` (207-223)
- `log_with_context` (290-309)
src/tplr/comms.py (2)
- `gather_with_reserve` (2443-2719)
- `get_with_retry` (1625-1740)
src/tplr/neurons.py (1)
- `check_uid_index_overlap` (1275-1480)
src/tplr/schemas.py (1)
- `success` (63-65)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.12)
- GitHub Check: test (3.11)
🔇 Additional comments (4)
neurons/validator.py (4)
931-936: LGTM! Defensive initialization prevents KeyError.

Ensuring all UIDs have OpenSkill ratings before calling `rate()` prevents potential KeyError when accessing `self.openskill_ratings[uid]` in the teams list creation at line 939. This is good defensive programming.
1414-1433: LGTM! Timestamp broadcasting ensures consistent time windows across ranks.

The timestamp is queried by master (with fallback to `time.time()` on failure), then broadcast to all ranks using `torch.tensor` with `float64` precision. This ensures all ranks use identical `time_min` and `time_max` values for the distributed gather operation, preventing rank inconsistencies.
1607-1624: LGTM! Gather result processing correctly uses merged result.

The code correctly handles the merged `gather_result` (not a list of partials) as returned by `gather_with_reserve` with `return_partials=False`. Direct access to `gather_result.uids` and use in `upload_gather_results` and `check_uid_index_overlap` is appropriate.
3197-3212: LGTM! Correct usage of `get_with_retry` with proper error handling.

The change from `comms.get()` to `comms.get_with_retry()` with `timeout=60` adds retry logic for transient failures. The error handling correctly checks for `None` (timeout) and `not debug_result.success` (retrieval failure) before proceeding.
Force-pushed from 2f412b0 to 6728bed (Compare)
Force-pushed from 6728bed to 8e9b8c7 (Compare)
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
neurons/validator.py (1)
1798-1825: Narrow the type annotation forcompute_peer_val_normsto reflect actual implementationThe review comment is accurate. The method signature at line 3790 declares
gather_result: SimpleNamespace | list[SimpleNamespace], but the implementation at line 3806 directly accessesstate_dict = gather_result.state_dictwithout checking whethergather_resultis a list, which would cause anAttributeErrorif called with a list of partials.Currently, the code is safe because the only
gather_with_reservecall at line 1516 explicitly usesreturn_partials=False(line 1521). However, the misleading type annotation and docstring ("Merged gather result SimpleNamespace") create risk if the method is reused or ifreturn_partialschanges.Recommended action: Apply option 1 from the review—narrow the type annotation to
SimpleNamespaceonly and update the docstring to remove the "handles both partials and merged results" comment at the call site. This avoids a latent runtime bug with minimal changes and keeps behavior as-is.src/tplr/comms.py (1)
1665-1737: Major issue: Early NOT_FOUND cutoff conflates transient errors with true 404s.The
max_consecutive_not_found = 5threshold (2.5 seconds) causes premature exits for transient network issues and timeouts, which are all mapped tostatus="NOT_FOUND"at line 1600. This breaks the intendedtimeoutsemantics (defaulting to 30 seconds).Root cause: The
get()method returnsCommsGetResult(status="NOT_FOUND")for both:
- True 404 responses (where early exit is appropriate)
- Transient failures like
asyncio.TimeoutErrorand network errors (which should retry until the caller'stimeoutexpires)Recommendation: Modify
get()ands3_get_object()to return distinct status codes (e.g.,"TRANSIENT_ERROR","TIMEOUT","NOT_FOUND") so the retry logic can distinguish true 404s from transient issues. Only incrementconsecutive_not_foundfor confirmed 404 responses.Based on learnings
♻️ Duplicate comments (1)
neurons/miner.py (1)
474-490: Broadcast `reserve_peers` alongside `peers` and `commitments` for consistent fallback behavior.

Currently only `peers` and `commitments` are broadcast to non-master ranks. However, `gather_with_reserve` is invoked with `reserve_uids=self.comms.reserve_peers` (line 761). If the primary gather fails and falls back to reserves, non-master ranks may have stale `reserve_peers`, causing divergent behavior across ranks.

Apply this diff to include `reserve_peers` in the broadcast:

```diff
 # Broadcast peer list (and optionally commitments) from master to all ranks for distributed gather
 if dist_helper.world_size > 1 and dist_helper.is_distributed():
     payload = (
-        [self.comms.peers, self.comms.commitments]
+        [self.comms.peers, self.comms.reserve_peers, self.comms.commitments]
         if self.is_master
-        else [None, None]
+        else [None, None, None]
     )
     dist.broadcast_object_list(payload, src=0)
-    self.comms.peers, self.comms.commitments = payload
+    self.comms.peers, self.comms.reserve_peers, self.comms.commitments = payload
     if not self.is_master:
         tplr.logger.info(
-            f"[Rank {dist_helper.rank}] Received peers from master: {self.comms.peers}"
+            f"[Rank {dist_helper.rank}] Received peers from master: gather={self.comms.peers}, reserve={self.comms.reserve_peers}"
         )
```

Based on learnings
🧹 Nitpick comments (2)
neurons/validator.py (1)
1491-1524: Clarify `use_distributed_gather` semantics and logging

Right now `use_distributed_gather` is hard‑coded to `True`, so:

- The comment "For distributed gather, all ranks must participate / For sequential gather, only master rank performs the operation" is only accurate if someone later wires `use_distributed_gather` to an actual condition.
- Logs always say "Starting distributed gather …" and "Distributed gather complete …" even in effectively sequential/single‑rank cases.

Functionally it's fine because `Comms.gather_with_reserve` independently decides whether to use distributed gather (`world_size > 1` & `is_distributed`) and falls back to sequential as needed, but wiring the flag would make intent clearer and future changes safer.

Consider:

```diff
-        use_distributed_gather = True
+        use_distributed_gather = (
+            dist_helper.world_size > 1 and dist_helper.is_distributed()
+        )
```

and keep the current control flow / logging structure. That preserves behavior while making the mode selection explicit.
Also applies to: 1543-1560
src/tplr/comms.py (1)
20-20: Unused import: `gc` is never used.

The `gc` module is imported but not referenced anywhere in the file. Remove it to keep imports clean.

```diff
-import gc
 import json
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- `neurons/base_node.py` (1 hunks)
- `neurons/miner.py` (9 hunks)
- `neurons/trainer.py` (3 hunks)
- `neurons/validator.py` (12 hunks)
- `src/tplr/comms.py` (19 hunks)
- `src/tplr/compress.py` (1 hunks)
- `src/tplr/neurons.py` (10 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/tplr/compress.py
🧰 Additional context used
🧬 Code graph analysis (4)
src/tplr/neurons.py (1)
src/tplr/compress.py (5)
- `compress` (373-377)
- `compress` (379-383)
- `compress` (386-423)
- `TopKCompressor` (295-697)
- `unpack_12bit_indices` (102-142)
neurons/miner.py (4)
src/tplr/chain.py (1)
- `get_commitments` (264-322)
src/tplr/distributed.py (1)
- `is_distributed` (75-77)
src/tplr/comms.py (1)
- `gather_with_reserve` (2443-2719)
src/tplr/dcp_checkpoint.py (2)
- `state_dict` (60-61)
- `state_dict` (82-83)
src/tplr/comms.py (2)
src/tplr/logging.py (2)
- `log_with_context` (207-223)
- `log_with_context` (290-309)
src/tplr/distributed.py (2)
- `is_distributed` (75-77)
- `all_gather_object` (307-325)
neurons/validator.py (5)
tests/test_state_loading.py (1)
- `rating` (59-62)
src/tplr/chain.py (1)
- `get_commitments` (264-322)
src/tplr/distributed.py (2)
- `is_distributed` (75-77)
- `broadcast` (273-282)
src/tplr/comms.py (2)
- `gather_with_reserve` (2443-2719)
- `get_with_retry` (1625-1740)
src/tplr/neurons.py (1)
- `check_uid_index_overlap` (1275-1480)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.12)
- GitHub Check: test (3.11)
🔇 Additional comments (22)
neurons/base_node.py (1)
61-62: LGTM! MRO chain properly initialized.

The `super().__init__()` call ensures parent classes (e.g., Trainer) are initialized correctly, fixing the MRO chain issue. The inline comment clearly documents the intent.

neurons/trainer.py (1)
887-889: LGTM! Gradient norm now properly captured and logged.

The changes correctly capture `total_grad_norm` from `clip_grad_norm_().item()` and handle the null round case by setting it to 0.0. This improves observability of the training process.

Also applies to: 911-911, 928-928
src/tplr/neurons.py (3)
48-70: LGTM! Robust CUDA device resolution.

The `_resolve_cuda_device` helper properly handles:
- CUDA unavailability (returns None)
- Missing/invalid device inputs (fallback to cuda:0)
- Exception cases during device resolution
- Type validation to ensure CUDA device
This centralizes device resolution logic and prevents string-device issues in CUDA APIs.
250-294: LGTM! Clean delegation to internal single-step implementation.

The refactor separates the public API (`outer_step`) from the internal implementation (`_outer_step_single`). The docstring correctly documents that distributed gather merges internally, so the function always receives a single merged result, not a list of partials.

Also applies to: 297-312
1277-1277: LGTM! Efficient dual-path overlap detection.

The updated `check_uid_index_overlap` correctly handles both merged and partial results:

- Partial path (lines 1291-1306): Builds `uid_to_partial` mapping for O(1) lookup, eliminating the O(n×m) complexity flagged in past reviews
- Merged path (lines 1311-1315): Uses traditional shared state_dict indexing
- Index retrieval (lines 1354-1396): Uses the mapping or direct indexing based on result type
- Pair keys (lines 1419-1422): Correctly uses `valid_uids_for_param` to reference original UID indices

This optimization and dual-path support align with the PR objectives for distributed gather.
Also applies to: 1291-1315, 1354-1396
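The O(1) lookup behind the dual-path handling can be sketched as a small normalization helper. The names here (`build_uid_lookup`, the `(source, position)` tuple) are illustrative; the real function additionally unpacks the 12-bit index tensors per parameter:

```python
from types import SimpleNamespace


def build_uid_lookup(gather_result):
    """Normalize merged-vs-partial gather results into a uid -> (source, position) map."""
    if isinstance(gather_result, list):
        # Partial path: one dict build up front, then O(1) lookups per (uid, param),
        # instead of scanning every partial for every uid.
        return {
            uid: (partial, i)
            for partial in gather_result
            if partial is not None
            for i, uid in enumerate(partial.uids)
        }
    # Merged path: every uid indexes directly into the shared state_dict.
    return {uid: (gather_result, i) for i, uid in enumerate(gather_result.uids)}
```

The returned position is exactly the index into that source's per-parameter lists, which is what keeps overlap statistics keyed in the global UID index space.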
neurons/miner.py (4)
31-34: LGTM! CUDA allocator configuration properly set.

Setting `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` before importing torch helps reduce memory fragmentation. The placement is correct (must be set before torch import).
552-559: LGTM! Proactive memory hygiene before training.

The aggressive cleanup (`empty_cache`, `synchronize`, `zero_grad`) ensures memory from the previous outer step is fully released before accumulating new gradients. This helps prevent OOM issues in multi-GPU setups.
738-762: LGTM! Gather mode selection and logging correctly implemented.

The code properly:

- Selects distributed gather when `world_size > 1` and `is_distributed()` (lines 739-741)
- Ensures all ranks participate in distributed mode (line 757)
- Sets `return_partials=False` to get merged results internally (line 762)
- Logs mode-specific completion details with success rates and timing (lines 777-800)

Since `return_partials=False`, `gather_result` will be a merged `SimpleNamespace` (not a list), making subsequent attribute access safe.

Also applies to: 777-800
830-832: LGTM! Gradient clearing and gather result handling are correct.

The changes properly:

- Clear gradients after outer step to free memory (lines 830-832)
- Extract UIDs and calculate success metrics from `gather_result` (lines 885-891, 927-933), safe because `return_partials=False` ensures a merged SimpleNamespace
- Perform aggressive cleanup with state_dict clearing and `gc.collect()` (lines 1014-1025)
These improvements enhance memory management in multi-GPU distributed training.
Also applies to: 885-891, 927-933, 1014-1025
neurons/validator.py (6)
931-947: Pre‑initializing OpenSkill ratings for window_uids looks correct

Initializing `self.openskill_ratings[uid]` for all `window_uids` before building `teams` guarantees `teams = [[self.openskill_ratings[uid]] for uid in window_uids]` never hits missing keys. This removes a subtle failure mode for first‑time peers and aligns the rating update flow cleanly.
1337-1369: Broadcasting peers and commitments keeps all ranks aligned

Refreshing `self.comms.commitments` on master and then broadcasting `[peers, reserve_peers, commitments]` via `dist.broadcast_object_list` ensures every rank sees exactly the same peer lists and bucket metadata before distributed gather. This avoids per‑rank divergence in peer selection and is consistent with the requirement that all ranks participate in the same distributed gather configuration.
1411-1439: Timestamp query + broadcast fixes time‑window divergence across ranks

Computing `sync_block` once, querying the timestamp only on `self.is_master`, falling back to `time.time()` with a warning, and then broadcasting `ts_value` to all ranks before deriving `time_min`/`time_max` ensures every rank uses an identical time window for both `gather_with_reserve` and subsequent gradient GETs. This removes the earlier risk where some ranks might silently use `time.time()` while others used the chain timestamp.
1605-1644: Master‑only post‑gather processing with merged result is consistent

Restricting upload, overlap checking, CPU offload, and extraction of `skipped_uids`/`success_rate` to:

```python
if self.is_master and gather_result is not None:
    ...
```

is appropriate now that `gather_with_reserve` is always called with `return_partials=False`. The master operates on a single merged `gather_result`, which:

- Avoids redundant `merge_gather_results` allocations on multiple ranks.
- Ensures `check_uid_index_overlap` and `upload_gather_results` see the same merged view.
- Leaves non‑master ranks free of extra memory pressure while still participating in the distributed gather itself.

The later metrics/slashing paths only touch these values under `self.is_master`, so this reshaping is coherent.
3197-3212: Using `get_with_retry` for debug dict fetch is robust

Switching `evaluate_miner_sync` to:

```python
debug_result = await self.comms.get_with_retry(
    uid=str(eval_uid),
    window=self.sync_window - 1,
    key="debug",
    timeout=60,
    local=False,
    stale_retention=10,
)
```

and then treating both `None` and `success == False` as failure:

```python
if debug_result is None or not debug_result.success:
    ...
    return {"success": False, "error": "...", "sync_score": 0.0}
```

gives you:
- Bounded, retrying reads of the miner’s debug artifact, tolerant of transient NOT_FOUND / ERROR conditions.
- A clean “unsuccessful” result shape that downstream code already handles by defaulting metrics to sentinel values.
Using `getattr(debug_result, "error", "unknown")` is also safe even when `debug_result` is `None` thanks to the default argument.
3803-3826: Norm computation logic itself is straightforward and appropriate

Within the (merged) path, iterating `model.named_parameters()`, extracting `vals_key`/`quant_key` pairs, dequantizing with `maybe_dequantize_values`, and then taking the median of per‑peer L2 norms to populate `clip_norm_dict[vals_key]` is a reasonable and robust aggregation strategy for per‑parameter clipping thresholds. Given `update_model_with_gradient` uses these as upper bounds (falling back to the peer's own norm if missing), the behavior is as intended.

src/tplr/comms.py (3)
483-582: Excellent structured error handling.

The `s3_object_exists` method correctly distinguishes between permanent failures (404, config errors) and transient issues (timeouts, 5xx). The conservative approach of returning `True` for transient errors prevents false positives in availability checks. The use of structured fields (`status`, `code`) instead of string matching makes the error handling robust.
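The classification strategy can be sketched as a pure function over a botocore-style error payload. The `"Error"."Code"` and `"ResponseMetadata"."HTTPStatusCode"` fields mirror what `ClientError.response` carries; the tri-state return value and the specific code lists are assumptions for illustration, not the repo's exact implementation:

```python
def classify_s3_error(response):
    """Map a botocore-style error response to 'missing', 'denied', or 'transient'."""
    code = response.get("Error", {}).get("Code", "")
    status = response.get("ResponseMetadata", {}).get("HTTPStatusCode", 0)
    if code in ("404", "NoSuchKey") or status == 404:
        return "missing"    # definitely absent: safe to report "does not exist"
    if code in ("403", "AccessDenied", "InvalidAccessKeyId"):
        return "denied"     # misconfiguration: also report "not available"
    return "transient"      # 5xx / throttling / unknown: assume it may still exist
```

The conservative mapping matters for slashing: only a confirmed 404 or a permission problem should count against a peer, while a throttled or flaky endpoint should not.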
709-723: Good optimization: offloading torch.load to thread pool.

The use of `asyncio.to_thread()` prevents blocking the event loop during tensor decompression, enabling true parallel downloads. The consistent application to both remote (line 713) and local storage (line 1564) is good for maintainability.
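The pattern is the standard `asyncio.to_thread` offload; a minimal standalone sketch, using `pickle` purely as a stand-in for the CPU-bound `torch.load` step:

```python
import asyncio
import io
import pickle

def blocking_load(raw: bytes):
    # Stand-in for torch.load: a CPU-bound deserialization step that
    # would otherwise stall the event loop if called directly.
    return pickle.load(io.BytesIO(raw))

async def load_many(blobs):
    # Each load runs in the default thread pool; the event loop stays
    # free to drive other downloads concurrently.
    return await asyncio.gather(
        *(asyncio.to_thread(blocking_load, b) for b in blobs)
    )

blobs = [pickle.dumps({"step": i}) for i in range(3)]
results = asyncio.run(load_many(blobs))
```

The same shape applies whether the bytes came from S3 or local disk, which is why applying it to both paths keeps the code uniform.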
1471-1493: Good: version parameter enables multi-version gradient checks.

The addition of the `version` parameter with a sensible default maintains backward compatibility while enabling flexibility. The implementation correctly defaults to `tplr.__version__` when `None` is provided.
2072-2095: LGTM: Deterministic round-robin partitioning.

The implementation correctly sorts UIDs and uses Python's slice notation `[rank::world_size]` for clean round-robin assignment. The single-rank early return is appropriate.
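The slice-based round-robin described above is easy to demonstrate standalone (the function name here is illustrative, not the actual method):

```python
def partition_uids(uids, rank, world_size):
    """Deterministic round-robin split: sort once, then stride by rank.

    Sorting first guarantees every rank derives the same global order,
    so the partition is reproducible across processes.
    """
    if world_size <= 1:
        return list(uids)  # single-rank early return: no split needed
    return sorted(uids)[rank::world_size]

# With 5 UIDs over 2 ranks, rank 0 takes sorted indices 0, 2, 4
# and rank 1 takes indices 1, 3; together they cover every UID once.
uids = [42, 7, 19, 3, 88]
```

Note that the partitions are disjoint and their union is the full UID set, which is what makes the per-rank fetches safe to merge later.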
2156-2171: Device selection is reasonable for most scenarios.

The fallback logic at lines 2158-2169 handles `RuntimeError` from `torch.cuda.current_device()` and conservatively attempts `cuda:0` only after verifying device count. While this doesn't guard against device 0 being out of memory or otherwise unavailable at the hardware level, such edge cases are rare and would manifest as explicit CUDA errors downstream.
2443-2719: Complex but correct handling of partials vs. merged results.

The method correctly handles two return paths:

- `return_partials=True`: returns a list of unmerged partials.
- `return_partials=False`: returns a merged `SimpleNamespace`.

The wrapping logic (lines 2554-2555, 2566-2567, 2581-2582) consistently converts sequential results to single-element lists when `return_partials=True`, maintaining type consistency.

Minor note: the aggregate success rate (lines 2687-2689) overwrites per-partial success rates. This is intentional—success rate is computed relative to the original `gather_uids` list, not per-partial—but may be surprising to downstream consumers expecting local rates.
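That global normalization can be illustrated with a small sketch (the function name is hypothetical; only the "relative to the original request" semantics are taken from the review):

```python
def global_success_rate(gather_uids, skipped_uids):
    """Success rate relative to the ORIGINAL requested UID list.

    Every partial later carries this same aggregate value, which is why
    any per-partial "local" rate gets overwritten by the merge step.
    """
    if not gather_uids:
        return 0.0
    succeeded = len(set(gather_uids) - set(skipped_uids))
    return succeeded / len(gather_uids)
```

A consumer expecting a local rate per partial would instead divide by that partial's own UID count, which is exactly the mismatch the review flags as potentially surprising.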
2291-2292: The device parameter is properly passed by all callers—no issue found.

Verification shows that actual callers of `gather_with_reserve()` in `neurons/miner.py` (line 766) and `neurons/validator.py` (line 1524) both explicitly pass `device` through kwargs. The device flows correctly through the call chain to `gather_distributed()`, where it's extracted at line 2292 and used at lines 2414-2424 to move results back to GPU as intended. The default of `"cpu"` applies only when device is omitted, which does not occur in production code paths.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
src/tplr/comms.py (2)
618-626: Restore default bucket fallback in s3_get_object.
`bucket` remains optional in the signature/docstring, but the new implementation no longer substitutes `self.bucket` when callers omit it. Any call such as `await self.s3_get_object(key=...)` now raises `AttributeError` when `_get_s3_client` dereferences `None.access_key_id`. Please reintroduce the fallback before `_get_s3_client()` is invoked.

```diff
+ if bucket is None:
+     bucket = self.bucket
  s3_client = await self._get_s3_client(bucket)
```
1668-1737: Do not abort get_with_retry after 2.5 seconds of NOT_FOUND.

The new `max_consecutive_not_found = 5` cap forces an early exit (~2.5 s) even when callers supplied a much larger timeout (30 s by default). Because `get()` returns `NOT_FOUND` for transient S3 timeouts/network hiccups, we now stop long before the object has a chance to appear, skipping gradients under normal latency. Please either tie the cutoff to `timeout` (e.g., derive it from `timeout / 0.5`) or only increment the counter when `get()` can confirm a real 404 so the loop continues honouring the caller's timeout.
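The "tie the cutoff to `timeout`" suggestion amounts to a one-line derivation; a sketch, assuming the 0.5 s retry sleep mentioned in the review (the helper itself is hypothetical):

```python
RETRY_SLEEP_S = 0.5  # sleep between retry attempts, per the review

def max_not_found_for(timeout_s):
    """Allow NOT_FOUND results for the full caller-supplied timeout.

    With a 30 s timeout and 0.5 s between attempts this yields 60
    consecutive NOT_FOUNDs before giving up, instead of a fixed 5
    (which corresponds to only ~2.5 s of waiting).
    """
    return max(1, int(timeout_s / RETRY_SLEEP_S))
```

This keeps the early-exit optimization for genuinely missing objects while honouring whatever timeout the caller asked for.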
🧹 Nitpick comments (5)
neurons/miner.py (3)
554-562: Consider guarding the pre-train CUDA cleanup for robustness

The aggressive `torch.cuda.empty_cache()` + `torch.cuda.synchronize()` before every inner step is fine for GPU-only deployments, but it will throw on CPU-only builds and adds some overhead each window.

You can keep the behavior while making it safer:

```diff
- # Aggressive memory cleanup before training to prevent OOM
- # This ensures memory from previous outer step is fully released
- torch.cuda.empty_cache()
- torch.cuda.synchronize()
+ # Aggressive memory cleanup before training to prevent OOM
+ # This ensures memory from previous outer step is fully released
+ if torch.cuda.is_available():
+     torch.cuda.empty_cache()
+     torch.cuda.synchronize(self.device)
```

This matches how `check_memory_threshold()` already guards other CUDA calls.
711-775: Align timestamp handling across ranks with validator-style broadcast

With the new distributed gather path, each rank passes `time_min`/`time_max` into `gather_with_reserve`. Today, every rank queries `query_block_timestamp()` independently and may fall back to `time.time()` on failures. If that fallback happens on only some ranks, different partitions can see slightly different time windows, which makes behavior harder to reason about and diverges from the validator's now-unified timestamp logic.

You can mirror the validator pattern so all ranks share the same `ts_value` while still keeping the non-blocking executor call:

```diff
- sync_block = self.current_window * self.hparams.blocks_per_window
- ts_value = await self.loop.run_in_executor(
-     None, self.query_block_timestamp, sync_block
- )
- if ts_value is None:
-     tplr.logger.warning(
-         f"Could not get timestamp for sync block {sync_block}. Using current time as fall back.",
-     )
-     ts_value = time.time()
+ sync_block = self.current_window * self.hparams.blocks_per_window
+
+ # Only master queries timestamp, then broadcasts to all ranks
+ if self.is_master:
+     ts_value = await self.loop.run_in_executor(
+         None, self.query_block_timestamp, sync_block
+     )
+     if ts_value is None:
+         tplr.logger.warning(
+             f"Could not get timestamp for sync block {sync_block}. Using current time as fall back.",
+         )
+         ts_value = time.time()
+ else:
+     ts_value = 0.0
+
+ # Broadcast timestamp from master to all ranks
+ ts_tensor = torch.tensor(
+     [ts_value], dtype=torch.float64, device=self.device
+ )
+ dist_helper.broadcast(ts_tensor, src=0)
+ ts_value = float(ts_tensor.item())
```

The `time_min`/`time_max` computation below can stay as-is, now based on a single shared timestamp.
885-894: Tighten successful/skipped peer bookkeeping in debug + metrics

Right now `successful_uids` is computed but unused, and `successful_peers` is derived as `peers - skipped_uids`. Since `gather_result.uids` already encodes the actual successful peers, you can simplify and avoid the unused local:

```diff
  # Add successful peers information
  if gather_result is not None:
      successful_uids = set(gather_result.uids)
      skipped_uids = set(gather_result.skipped_uids)

-     debug_dict["successful_peers"] = sorted(
-         list(set(self.comms.peers) - skipped_uids)
-     )
-     debug_dict["skipped_peers"] = sorted(list(skipped_uids))
+     debug_dict["successful_peers"] = sorted(successful_uids)
+     debug_dict["skipped_peers"] = sorted(skipped_uids)
```

The metrics section below (`gather_success_rate` and `skipped_uids_list` → `skipped_peers` field) already looks consistent with a merged `gather_result`, so no change needed there.

Also applies to: 929-999
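Why the distinction matters: subtracting skips from the configured peer list silently counts any peer that neither succeeded nor was recorded as skipped. A minimal illustration with plain values standing in for the real objects:

```python
# Illustrative only: plain lists stand in for the real gather_result.
peers = [1, 2, 3, 4, 5]      # configured gather peers
gathered_uids = [1, 3]       # peers that actually returned gradients
skipped = [2, 5]             # peers explicitly recorded as skipped

# Old derivation: anything not skipped counts as "successful" -- this
# wrongly includes peer 4, which neither succeeded nor was skipped.
old_successful = sorted(set(peers) - set(skipped))

# Suggested derivation: trust gather_result.uids directly.
new_successful = sorted(set(gathered_uids))
```

Using `gather_result.uids` makes `successful_peers` an exact record rather than an inference from the skip list.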
neurons/validator.py (2)
1490-1524: Make gather mode detection dynamic and fix device typing for gather_with_reserve

Two small issues here:

- `use_distributed_gather = True` makes the sequential logging branch dead and causes all ranks to enter the gather path even in non-distributed setups. It's safer and clearer to mirror the miner:

```diff
- gather_start = tplr.T()
- skipped_uids: list[int] = []
- success_rate = 0.0
- gather_result = None
- skip_window = False
- use_distributed_gather = True
+ gather_start = tplr.T()
+ skipped_uids: list[int] = []
+ success_rate = 0.0
+ gather_result = None
+ skip_window = False
+ # Use distributed gather automatically when running with multiple ranks
+ use_distributed_gather = (
+     dist_helper.world_size > 1 and dist_helper.is_distributed()
+ )
```

- In the `gather_with_reserve` call, `device=cast(str, self.device)` is a runtime no-op; the actual object passed is still a `torch.device`. Miner passes `device=str(self.device)`, and `Comms` is written to accept a string device identifier. For consistency and to avoid surprises inside `Comms`:

```diff
- timeout=150,  # Increased from 90 to allow more time for gradient gathering
- device=cast(str, self.device),
+ timeout=150,  # Increased from 90 to allow more time for gradient gathering
+ device=str(self.device),
```

These tweaks keep the control flow honest and the `device` argument uniform across callsites.
1821-1824: Clarify compute_peer_val_norms contract to "merged only"

`compute_peer_val_norms` currently assumes `gather_result.state_dict` is a merged object with per-parameter `vals`/`quant_params`; it doesn't handle the list-of-partials case suggested by its type hint and the comment at Line 1823. Given this run path now always uses `return_partials=False`, it's better to make the contract explicit rather than suggest partial support that doesn't exist.

You can align the comment and signature:

```diff
- clip_norm_dict = {}
- if gather_result is not None:
-     # compute_peer_val_norms handles both partials and merged results
-     clip_norm_dict = self.compute_peer_val_norms(gather_result)
+ clip_norm_dict = {}
+ if gather_result is not None:
+     # compute_peer_val_norms expects a merged gather result
+     clip_norm_dict = self.compute_peer_val_norms(gather_result)
```

and:

```diff
- def compute_peer_val_norms(
-     self,
-     gather_result: SimpleNamespace | list[SimpleNamespace],
- ) -> dict[str, torch.Tensor]:
+ def compute_peer_val_norms(
+     self,
+     gather_result: SimpleNamespace,
+ ) -> dict[str, torch.Tensor]:
```

The implementation below then accurately matches the type and docstring ("Merged gather result SimpleNamespace"), avoiding confusion for future partial-support work.
Also applies to: 3790-3826
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- neurons/base_node.py (1 hunks)
- neurons/miner.py (9 hunks)
- neurons/trainer.py (3 hunks)
- neurons/validator.py (12 hunks)
- src/tplr/comms.py (19 hunks)
- src/tplr/compress.py (1 hunks)
- src/tplr/neurons.py (10 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- neurons/trainer.py
- src/tplr/compress.py
🧰 Additional context used
🧬 Code graph analysis (4)
neurons/miner.py (4)
- src/tplr/chain.py (1)
  - `get_commitments` (264-322)
- src/tplr/distributed.py (1)
  - `is_distributed` (75-77)
- src/tplr/comms.py (1)
  - `gather_with_reserve` (2443-2719)
- src/tplr/dcp_checkpoint.py (2)
  - `state_dict` (60-61)
  - `state_dict` (82-83)

neurons/validator.py (6)
- tests/test_state_loading.py (1)
  - `rating` (59-62)
- src/tplr/chain.py (1)
  - `get_commitments` (264-322)
- src/tplr/distributed.py (2)
  - `is_distributed` (75-77)
  - `broadcast` (273-282)
- src/tplr/comms.py (2)
  - `gather_with_reserve` (2443-2719)
  - `get_with_retry` (1625-1740)
- src/tplr/neurons.py (1)
  - `check_uid_index_overlap` (1275-1480)
- src/tplr/schemas.py (1)
  - `success` (63-65)

src/tplr/neurons.py (1)
- src/tplr/compress.py (4)
  - `compress` (373-377)
  - `compress` (379-383)
  - `compress` (386-423)
  - `unpack_12bit_indices` (102-142)

src/tplr/comms.py (3)
- src/tplr/schemas.py (1)
  - `Bucket` (24-46)
- src/tplr/logging.py (2)
  - `log_with_context` (207-223)
  - `log_with_context` (290-309)
- src/tplr/distributed.py (2)
  - `is_distributed` (75-77)
  - `all_gather_object` (307-325)
🔇 Additional comments (9)
neurons/miner.py (4)
31-42: CUDA allocator env + dist import look solid

Setting `PYTORCH_CUDA_ALLOC_CONF` before importing torch (while honoring any pre-set value) and adding the explicit `torch.distributed` import are appropriate for the new distributed gather + memory-hygiene flow. No issues here.
469-493: Peer / reserve / commitment broadcast keeps ranks aligned

Broadcasting `self.comms.peers`, `self.comms.reserve_peers`, and `self.comms.commitments` from the master to all ranks before gather ensures every rank has a consistent view of primary and reserve peers plus bucket metadata, which is exactly what `gather_with_reserve` expects. This also mirrors the validator path, reducing the risk of subtle per-rank divergence.
740-803: Distributed vs sequential gather wiring matches Comms semantics

Deriving `use_distributed_gather` from `dist_helper` and then:

- having all ranks call `gather_with_reserve` in distributed mode,
- restricting the sequential path to the (single) master rank when not distributed, and
- forcing `return_partials=False` so miners always see a merged `SimpleNamespace`,

is consistent with the documented behavior of `gather_with_reserve` and with how downstream code accesses `uids`, `skipped_uids`, and `success_rate`. This addresses the earlier partial-vs-merged mismatch cleanly.
1016-1025: Gather-result cleanup timing is safe but future uses should beware

Clearing `gather_result.state_dict.__dict__` (when present) and deleting `gather_result` after all logging/outer-step work is finished is a reasonable way to encourage GC and reduce memory pressure.

Just be aware that if you later start any background tasks that capture `gather_result` (or its `state_dict`) by reference, they will see this cleared state; in that case this cleanup would need to move after those tasks complete or be narrowed to temporary structures only. As written, with no async users of `gather_result`, the lifecycle is coherent.

neurons/validator.py (5)
931-937: OpenSkill rating pre-initialization avoids missing-rating edge cases

Initializing `self.openskill_ratings[uid]` for every `uid` in `window_uids` before calling `self.openskill_model.rate(...)` prevents gaps where a peer participates in rating updates without an existing rating object. This keeps the OpenSkill flow consistent with how you hydrate ratings from saved state and avoids ad-hoc defaults.
1337-1369: Commitment refresh + peer/reserve broadcast correctly sync gather state

Refreshing `self.comms.commitments = await self.comms.get_commitments()` on the master and then broadcasting `[peers, reserve_peers, commitments]` to all ranks ensures:

- all ranks use the same primary + reserve peer lists, and
- everyone sees the same bucket config when calling `gather_with_reserve`.

The logging for non-master ranks also helps debug any divergence. This is aligned with the miner side and looks good.
1411-1439: Timestamp broadcast eliminates per-rank time window skew

The new pattern—master querying `query_block_timestamp(sync_block)` (with a `time.time()` fallback), broadcasting `ts_value` via `dist_helper.broadcast`, then deriving `time_min`/`time_max` from the shared value—fixes the earlier risk where some ranks could fall back to local time while others used chain time.

Given `time_min`/`time_max` feed into `gather_with_reserve` and later `comms.get(...)` calls, this is a solid synchronization improvement for distributed validator runs.
1607-1643: Master-only post-gather processing and offload look coherent

Using the master rank to:

- upload the merged `gather_result`,
- run `check_uid_index_overlap(...)`,
- offload aggregated results to pinned CPU via `offload_gather_results`, and
- snapshot `skipped_uids = list(gather_result.skipped_uids)` and `success_rate = gather_result.success_rate` before the barrier

keeps heavy processing centralized while still letting all ranks synchronize on the same `skipped_uids`/`success_rate` for later slashing and metrics. The ordering (overlap → offload → slashing metadata capture) looks correct.
3197-3212: get_with_retry-based debug fetch handles timeouts/invalid data safely

Switching `evaluate_miner_sync` to use `self.comms.get_with_retry(..., timeout=60, stale_retention=10)` and then treating both `None` and `not debug_result.success` as failure is a good improvement:

- Timeouts and transient issues are absorbed as `None`.
- The `getattr(debug_result, "error", "unknown")` usage keeps logging robust even when `debug_result` is `None`.
- Invalid or missing debug payloads cleanly return a `{"success": False, "sync_score": 0.0}` result, which the slashing logic can consume.

This keeps sync evaluation resilient without raising on infrastructure glitches.
Force-pushed from `8e9b8c7` to `20fec9f`.
Actionable comments posted: 1
♻️ Duplicate comments (1)
src/tplr/comms.py (1)
1665-1671: Early NOT_FOUND cutoff still risks violating timeout semantics

`get_with_retry` now tracks `consecutive_not_found` and exits after 5 NOT_FOUND results (~2.5s at 0.5s sleep), regardless of the caller's `timeout` and even before `time_max` is reached. Because `get()` maps many transient S3/network failures (e.g., timeouts, some `ClientError`s) to `status="NOT_FOUND"` via `s3_get_object` returning `None`, slow or flaky peers can still be treated as "likely doesn't exist" long before the configured timeout window elapses. This is essentially the same concern raised earlier in past reviews.

A minimal change that keeps the optimization but respects `time_max` would be to only start the "give up early" logic once we're sure the object should have appeared (i.e., no `time_max`, or we are already past `time_max` + grace):

```diff
- consecutive_not_found = 0  # Track consecutive NOT_FOUND errors
- max_consecutive_not_found = 5  # Give up after 5 consecutive NOT_FOUND (2.5s)
+ consecutive_not_found = 0  # Track consecutive NOT_FOUND errors
+ max_consecutive_not_found = 5  # Only used once we're sure gradient should exist
@@
- # For NOT_FOUND, track consecutive failures
- if result.status == "NOT_FOUND":
-     consecutive_not_found += 1
-     if consecutive_not_found >= max_consecutive_not_found:
-         tplr.logger.debug(
-             f"GET {uid}/{window}/{key} - {consecutive_not_found} consecutive NOT_FOUND. "
-             f"Giving up early (file likely doesn't exist)."
-         )
-         return None
- else:
-     # Reset counter on other errors (might be transient network issues)
-     consecutive_not_found = 0
+ # For NOT_FOUND, track consecutive failures, but only after the file
+ # should reasonably exist (no time_max, or we've passed time_max+grace).
+ if result.status == "NOT_FOUND":
+     if time_max is None or past_time_max:
+         consecutive_not_found += 1
+         if consecutive_not_found >= max_consecutive_not_found:
+             tplr.logger.debug(
+                 f"GET {uid}/{window}/{key} - {consecutive_not_found} consecutive NOT_FOUND. "
+                 f"Giving up early (file likely doesn't exist)."
+             )
+             return None
+ else:
+     # Reset counter on other errors (might be transient network issues)
+     consecutive_not_found = 0
```

This way the early cutoff only triggers after we either have no `time_max` hint or are already beyond `time_max + time_max_grace_period`, which aligns better with the intent described in earlier discussion.

If you prefer to instead distinguish true 404s from transient errors inside `get()`/`s3_get_object`, that would also resolve the issue; please double-check that mapping as well.

Also applies to: 1725-1737
🧹 Nitpick comments (11)
src/tplr/comms.py (4)
667-682: Large-object branching and async `torch.load` are implemented correctly; consider tightening bucket handling

The 500MB size threshold and switch to `download_large_file` with a scaled timeout look reasonable, and offloading `torch.load` to `asyncio.to_thread` (with error logging) is a good fix for event-loop blocking.

One minor robustness nit: `s3_get_object`'s signature/docstring say `bucket` defaults to `self.bucket`, but the implementation assumes `bucket` is non-None (e.g., `_get_s3_client(bucket)` and `bucket.name`). If this method is ever called with `bucket=None`, it will fail. You may want to add a small guard at the top:

```diff
  async def s3_get_object(
      self,
      key: str,
      bucket: Bucket | None = None,
@@
-     s3_client = await self._get_s3_client(bucket)
+     if bucket is None:
+         bucket = self.bucket
+     s3_client = await self._get_s3_client(bucket)
```

This keeps the behavior aligned with the docstring.
Also applies to: 709-723
2072-2257: `merge_gather_results` behavior is sound; minor device-choice nuance

The new distributed helpers are generally well-structured:

- `partition_uids_across_ranks` uses deterministic round-robin over a sorted UID list, which is important for reproducibility.
- `merge_gather_results` correctly:
  - Filters out `None` partials.
  - Selects a reasonable default `target_device` (current CUDA device when available, otherwise CPU).
  - Recursively moves tensors to `target_device`.
  - Concatenates lists per parameter key across partials, and aggregates metrics/UIDs/steps and skipped_uids.
  - Computes a global success_rate based on successes vs skipped.

One nuance to be aware of: when `target_device` is left as `None`, the default path will move all tensors in the merged result to the current CUDA device. In contexts where you intend to keep intermediate compressed tensors on CPU to save GPU memory, make sure callers pass `target_device="cpu"` explicitly (as you already do in some other paths) to avoid surprising memory pressure.
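The per-key list concatenation at the heart of that merge can be sketched without any torch machinery (simplified: partials here are plain dicts of lists, not the real `SimpleNamespace` objects, and tensor movement is omitted):

```python
from collections import defaultdict

def merge_partials(partials):
    """Concatenate per-parameter lists across per-rank partial results.

    Each partial maps a parameter key to the list of per-peer values
    for that key; None entries (ranks that produced nothing) are skipped.
    """
    merged = defaultdict(list)
    for partial in partials:
        if partial is None:
            continue
        for key, values in partial.items():
            merged[key].extend(values)
    return dict(merged)

# Rank 0 fetched peers {1, 3}; rank 1 fetched peer {2}; rank 2 got nothing.
partials = [
    {"layer0.vals": ["g1", "g3"], "uids": [1, 3]},
    {"layer0.vals": ["g2"], "uids": [2]},
    None,
]
merged = merge_partials(partials)
```

Because the round-robin partition preserves a deterministic per-rank order, concatenating in rank order yields a stable merged view across runs.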
2415-2545: Sequential vs distributed fallback messaging could be clearer

The `gather_with_reserve` orchestration looks logically consistent:

- In distributed mode, all ranks participate in `gather_distributed`, but only master receives merged/partial results; non-master ranks explicitly return `None`.
- Master-only sequential fallback uses `self.gather(...)` and wraps the result in a list when `return_partials=True`, maintaining the return-type contract.
- Non-distributed mode always uses sequential gather and applies the same wrapping behavior.

One minor confusion: the log message

```python
context_log(
    message=f"[gather_with_reserve] 🔄 Falling back to SEQUENTIAL gather on ALL ranks (consistency requirement)"
)
```

is emitted inside `if primary is None and dist_helper.is_master:`, but only the master actually runs the sequential `gather()` that follows—other ranks just see `primary is None` and return `None`. If the intent is truly master-only fallback (with subsequent DDP broadcast/coordination elsewhere), consider updating the log text to avoid implying that all ranks are performing the sequential gather. Otherwise, you'd need to invoke `gather()` on all ranks to match the current message.

This is cosmetic but helps future readers understand the actual behavior.
2567-2701: Partials handling in `gather_with_reserve` is consistent with type contract

The downstream handling of `primary` once it may be a list of partials is well thought out:

- Logging aggregates `total_uids` and total skipped across all partials.
- Missing slots are filled from `reserve_uids` without duplicating already-used UIDs.
- Reserve fallback is appended as an extra partial when `return_partials=True`, and merged into the single namespace otherwise.
- Success-rate recomputation correctly normalizes by the original `gather_uids` length and propagates the final success_rate to each partial.

Once `gather_distributed` respects `return_partials=True`, this machinery will give you consistent behavior between distributed and sequential modes.

src/tplr/neurons.py (3)
22-22: Outer-step refactor to `_outer_step_single` preserves behavior

The changes here:

- Introduce `_resolve_cuda_device` to centralize CUDA-device derivation.
- Turn `outer_step` into a thin wrapper that simply forwards all arguments to `_outer_step_single`.
- Keep the original, memory-minimizing outer-step logic entirely inside `_outer_step_single` (per-param decode/decompress, DDP broadcast, optimizer.step per parameter, fingerprint accumulation).

Functionally, this keeps the previous behavior while making the outer-step surface cleaner and easier to extend. There are no new CUDA API misuses (no string devices passed to `torch.cuda.*` here), and the DDP logic still gates all heavy work on `on_src = is_master or not ddp`.

Also applies to: 48-69, 249-295, 297-551
816-883: Catch-up path enhancements are coherent and safe

In `catchup_with_aggregation_server`:

- The initial memory logging helper and periodic `log_memory_usage` calls provide useful visibility without affecting control flow.
- Checkpoint verification against a validator `debug` payload at `checkpoint_current_window` is guarded by `instance.is_master` and only proceeds when `debug_fetch` contains a dict, so non-master ranks aren't impacted.
- When an aggregated `aggregator` object is missing or malformed on the final window, the live `comms.gather(...)` fallback is only attempted once, and its result is wrapped into the same `gather_ns` schema expected by `outer_step`.
- Skip decisions are broadcast via `dist_helper.broadcast`, and both pre- and post-outer-step barriers (`safe_barrier`) ensure all ranks stay in lock-step through the catch-up loop.
- After each window, you aggressively clear `gather_ns`, run `gc.collect()`, and flush CUDA caches, which is appropriate for 70B-scale models in a replay loop.

Overall this is a solid and carefully synchronized catch-up flow; I don't see correctness or ordering issues in the new logic.
Also applies to: 912-969, 985-1005, 1008-1079, 1084-1148, 1149-1168
1275-1316: `check_uid_index_overlap` partial support and complexity reduction are correct

The updated `check_uid_index_overlap` now:

- Accepts either a merged `SimpleNamespace` or `list[SimpleNamespace]` partials.
- For partials, flattens all UIDs into a single `uids` list and builds a `uid_to_partial` mapping from UID → (partial, index_in_partial), giving O(n) lookup instead of scanning each partial per UID.
- Tracks `valid_uids_for_param` (indices into the global `uids` list) for each parameter, so the pair accumulator `pair_acc` still keys pairs by global UID indices, and final reporting (`uids[i]`, `uids[j]`) remains correct.
- Only includes peers that actually provided compressed indices for a parameter; if fewer than 2 peers have data for a given param, it skips that parameter cleanly.
- Leaves the slashing decision, overlap aggregation, and summary metrics (mean/min/max overlaps, offenders) unchanged.

This is a nice improvement in both flexibility and complexity. One small cleanup you could consider (non-blocking): in the partials branch, `all_state_dicts` is populated but then immediately overridden via `uid_to_partial` lookups, so it can likely be removed entirely from that branch to reduce confusion.

Also applies to: 1351-1399, 1408-1424, 1431-1471
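The UID → (partial, local index) mapping described above can be sketched with plain dicts (the shapes are illustrative, not the real `SimpleNamespace` layout):

```python
def build_uid_to_partial(partials):
    """Flatten UIDs across partials and index each UID exactly once.

    Returns the flattened uid list plus a map
    uid -> (partial_index, index_within_partial), so later per-UID
    lookups are O(1) instead of scanning every partial each time.
    """
    uids = []
    uid_to_partial = {}
    for p_idx, partial in enumerate(partials):
        for local_idx, uid in enumerate(partial["uids"]):
            uid_to_partial[uid] = (p_idx, local_idx)
            uids.append(uid)
    return uids, uid_to_partial

# Two partials: the first holds peers 3 and 19, the second holds peer 7.
partials = [{"uids": [3, 19]}, {"uids": [7]}]
uids, uid_map = build_uid_to_partial(partials)
```

The flattened `uids` list keeps the global indexing stable, which is what lets the pair accumulator keep keying overlap pairs by `(i, j)` positions regardless of which partial each peer came from.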
neurons/miner.py (2)
554-562: Consider guarding CUDA cleanup for CPU-only runs

Unconditionally calling `torch.cuda.empty_cache()` and `torch.cuda.synchronize()` is fine on CUDA builds, but will raise on CPU-only PyTorch builds. If you expect any CPU-only miner runs, consider wrapping this block with `if torch.cuda.is_available():` to avoid surprises.
740-745: Distributed gather / outer-step semantics rely on merged result on all ranks

This section assumes that in distributed mode:

- `gather_with_reserve(..., return_partials=False, ...)` returns a merged `SimpleNamespace` with `.uids`, `.skipped_uids`, `.success_rate`, and `.state_dict` on every rank, not just on the master.
- `should_update` is determined by the master (`gather_result is not None`) and then broadcast via `dist_helper.all_ok`, while non-masters still pass their local `gather_result` into `self.outer_step(gather_result)`.

Given that, the logging, metrics (`gather_success_rate`, `skipped_uids_list`), debug info, and cleanup (`state_dict.__dict__.clear()`) are all consistent and safe.

If, however, `gather_with_reserve` only returns a non-`None` result on the master in some paths, non-master ranks would end up calling `outer_step(None)` and the `state_dict` cleanup would be a no-op on those ranks. Please double-check `Comms.gather_with_reserve`/`gather_distributed` to confirm that a merged result is indeed broadcast (or otherwise made available) to all ranks in distributed mode, or alternatively gate `outer_step` and the cleanup on `self.is_master` and broadcast updated model weights instead.

Also applies to: 747-756, 757-757, 778-813, 840-842, 895-902, 937-943, 1007-1007, 1024-1033
neurons/validator.py (2)
1490-1515: Distributed gather flow is coherent; consider clarifying mode flag and assumptions

This block wires up the validator's distributed gather path:

- `use_distributed_gather = True` forces all ranks into the distributed path, with `gather_with_reserve(..., return_partials=False, ...)` responsible for choosing distributed vs sequential and merging results.
- Gather completion logs and downstream metrics (`skipped_uids`, `success_rate`, `actual_gather_uids`, reserve usage, quality stats) all treat `gather_result` as a merged `SimpleNamespace`, not a list of partials.
- `skip_window` is decided on the master based on `gather_result is None` or exceptions and then broadcast, keeping all ranks aligned.

Functionally this looks correct assuming:

- `gather_with_reserve` always returns a merged result (or `None`) rather than partials for this call site, and
- In distributed mode, that merged result is available on all ranks (or `outer_step`/norm code gracefully handles `None` on non-masters).

Two small suggestions:

- Rename `use_distributed_gather` or derive it from `dist_helper.is_distributed()` again, so the variable name matches its intent (currently it's always `True`, even in effectively sequential runs).
- Optionally add an assertion or comment at this call site documenting that this path is "merged-result only" and will break if `return_partials=True` is ever reintroduced here.

Also applies to: 1516-1531, 1543-1560, 1561-1590, 1606-1643, 1667-1705, 1721-1733
1821-1824: `compute_peer_val_norms` type hint/docstring are broader than the implementation

The signature and nearby comment say `gather_result: SimpleNamespace | list[SimpleNamespace]` and "handles both partials and merged results", but the implementation assumes:

- `gather_result` has a `.state_dict` attribute, and
- `state_dict` exposes attributes like `"{param}vals"`/`"{param}quant_params"` (i.e., `SimpleNamespace` layout), not a list of partials.

Given the current call site (`compute_peer_val_norms(gather_result)` right after `gather_with_reserve(..., return_partials=False, ...)`), this works because `gather_result` is always merged. However, the type hint and docstring are misleading and could encourage incorrect future use.

Consider either:

- Narrowing the annotation and docstring to `gather_result: SimpleNamespace` and "merged result only", or
- Re-introducing a partials-aware implementation (e.g., aggregate per-param `vals`/`quant_params` across a list of namespaces, with proper alignment and dequantization) if you intend to support partials here again.

Also applies to: 3791-3802
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- neurons/base_node.py (1 hunks)
- neurons/miner.py (9 hunks)
- neurons/trainer.py (3 hunks)
- neurons/validator.py (12 hunks)
- src/tplr/comms.py (19 hunks)
- src/tplr/compress.py (1 hunks)
- src/tplr/neurons.py (10 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- neurons/trainer.py
- src/tplr/compress.py
🧰 Additional context used
🧬 Code graph analysis (4)
src/tplr/neurons.py (1)
- src/tplr/compress.py (6)
  - `compress` (373-377)
  - `compress` (379-383)
  - `compress` (386-423)
  - `ChunkingTransformer` (145-292)
  - `TopKCompressor` (295-697)
  - `unpack_12bit_indices` (102-142)

neurons/miner.py (3)
- src/tplr/chain.py (1)
  - `get_commitments` (264-322)
- src/tplr/distributed.py (1)
  - `is_distributed` (75-77)
- src/tplr/comms.py (1)
  - `gather_with_reserve` (2408-2701)

src/tplr/comms.py (3)
- src/tplr/schemas.py (1)
  - `Bucket` (24-46)
- src/tplr/logging.py (2)
  - `log_with_context` (207-223)
  - `log_with_context` (290-309)
- src/tplr/distributed.py (2)
  - `is_distributed` (75-77)
  - `gather_object` (284-305)

neurons/validator.py (6)
- tests/test_state_loading.py (1)
  - `rating` (59-62)
- src/tplr/chain.py (1)
  - `get_commitments` (264-322)
- src/tplr/distributed.py (2)
  - `is_distributed` (75-77)
  - `broadcast` (273-282)
- src/tplr/logging.py (2)
  - `log_with_context` (207-223)
  - `log_with_context` (290-309)
- src/tplr/comms.py (2)
  - `gather_with_reserve` (2408-2701)
  - `get_with_retry` (1625-1740)
- src/tplr/neurons.py (1)
  - `check_uid_index_overlap` (1275-1480)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.12)
- GitHub Check: test (3.11)
🔇 Additional comments (12)
neurons/base_node.py (1)
61-62: The placement of `super().__init__()` at the end of `BaseNode.__init__()` is correct and intentional for the cooperative multiple inheritance pattern used here. Verification confirms:
- `Trainer.__init__()` only initializes one attribute (`inner_scheduler_step_count`) and has no dependencies on `BaseNode` attributes
- `Trainer` methods assume `BaseNode` attributes (`uid`, `rank`, `world_size`, `hparams`, etc.) are already initialized
- For both `Miner(BaseNode, Trainer)` and `Validator(BaseNode, Trainer)`, the MRO places `BaseNode` first, followed by `Trainer`
- The end-of-initialization call to `super().__init__()` ensures `BaseNode` completes its initialization before delegating to `Trainer`

This is a well-designed cooperative multiple inheritance pattern where parent-class initialization properly follows the method resolution order. No changes are needed.
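The cooperative pattern described above can be sketched with simplified stand-ins (the class bodies below are illustrative placeholders, not the project's actual code):

```python
class Trainer:
    def __init__(self):
        # Runs after BaseNode in the MRO; BaseNode attributes already exist here.
        self.inner_scheduler_step_count = 0
        super().__init__()


class BaseNode:
    def __init__(self):
        self.uid, self.rank, self.world_size = 0, 0, 1
        # Placed at the end: BaseNode finishes its own setup first, then
        # cooperatively delegates to the next class in the MRO (Trainer).
        super().__init__()


class Miner(BaseNode, Trainer):
    pass


m = Miner()
```

Because every `__init__` ends with `super().__init__()`, constructing `Miner` runs `BaseNode.__init__` fully, then `Trainer.__init__`, matching the MRO order the review describes.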
src/tplr/comms.py (5)
20-21: Partition cache introduction looks good

Using `OrderedDict` with a hard cap (32 entries) keyed by `(window, sorted_uids, world_size)` is a sensible way to reuse UID partitions without unbounded growth. No functional or concurrency issues spotted here.

Also applies to: 29-29, 142-144
483-582: HEAD-based `s3_object_exists` is robust and conservative

The new `s3_object_exists` helper correctly:
- Uses `head_object` with a timeout.
- Distinguishes 404/not-found and config/permission errors (404/403, NoSuchBucket, invalid bucket).
- Treats timeouts and 5xx/network issues as transient and returns `True` to avoid false slashing.

This is a solid, structured improvement over string-based error checks and matches the intended conservative semantics.
1563-1566: Local-path `torch.load` offload is appropriate

Running the local `torch.load` via `asyncio.to_thread` in the `local=True` code path mirrors the remote path change and avoids blocking the event loop during deserialization. The subsequent handling of `state_dict` and `global_step` is unchanged, so this is a safe performance/latency improvement.
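The benefit of the `asyncio.to_thread` offload can be shown with a generic blocking loader standing in for `torch.load` (the sleep simulates CPU-bound deserialization; names are illustrative):

```python
import asyncio
import time


def blocking_load(path: str) -> dict:
    # Stand-in for torch.load: work that would otherwise block the event loop.
    time.sleep(0.2)
    return {"path": path}


async def load_many(paths: list[str]) -> list[dict]:
    # Each load runs in a worker thread, so the three loads overlap
    # instead of serializing to ~0.6s on the event loop.
    return await asyncio.gather(
        *(asyncio.to_thread(blocking_load, p) for p in paths)
    )


start = time.perf_counter()
results = asyncio.run(load_many(["a.pt", "b.pt", "c.pt"]))
elapsed = time.perf_counter() - start  # roughly one load's duration, not three
```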
1472-1493: `gradient_timestamp` version parameterization is correct

Allowing an optional `version` override and defaulting to `tplr.__version__` fixes the earlier undefined `self.version` issue and aligns with how other paths construct gradient keys. The rest of the logic (using `commitments`, `head_object`, and returning `0.0` on error) remains intact.
1844-2006: Gather response processing and validation look consistent

The updated response-processing loop in `gather()`:
- Logs per-UID download exceptions with `log_with_context` (warning level).
- Treats `None` responses from `get_with_retry` as "skipped" UIDs with an info log.
- Adds tighter validation for quantization params and NaN/Inf checks using a temporary `tensor_to_check` on the target device and cleans it up immediately.

These changes preserve the earlier semantics (only valid, fully-formed compressed payloads contribute to aggregation) while improving observability and keeping GPU temporaries short-lived. No correctness issues spotted here.
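The short-lived NaN/Inf check described above can be sketched like this (a hypothetical helper; the project's actual validation covers more fields):

```python
import torch


def is_valid_payload(vals: torch.Tensor, device: str = "cpu") -> bool:
    """Move the tensor to the target device, check finiteness, free it."""
    tensor_to_check = vals.to(device)
    try:
        # Rejects both NaN and +/-Inf in one pass.
        return bool(torch.isfinite(tensor_to_check).all())
    finally:
        del tensor_to_check  # keep device temporaries short-lived


ok = is_valid_payload(torch.tensor([1.0, 2.0]))
```

Deleting the temporary in a `finally` block mirrors the "cleans it up immediately" behavior the review calls out, so GPU memory is not held across iterations of the response loop.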
neurons/miner.py (2)
32-35: CUDA allocator config before torch import looks correct

Setting `PYTORCH_CUDA_ALLOC_CONF` before importing `torch` is the right place to influence the allocator; the conditional guard also preserves caller-provided settings. No changes needed.
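The ordering constraint is the whole point here, and `os.environ.setdefault` gives the "preserve caller-provided settings" guard in one call (a minimal sketch of the pattern, not the miner's exact lines):

```python
import os

# Must run before `import torch`: once torch initializes the CUDA allocator,
# changing this variable has no effect. setdefault keeps any value the
# caller already exported.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import torch  # noqa: E402  (deliberately imported after the allocator config)
```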
474-493: Peer / reserve / commitment broadcast keeps all ranks consistent

Broadcasting `[peers, reserve_peers, commitments]` from the master to all ranks aligns miner behavior with the validator and with `gather_with_reserve` semantics; the logging on non-master ranks is useful for debugging. This block looks sound.

neurons/validator.py (4)
931-937: OpenSkill rating initialization before `rate()` is correct

Ensuring every UID in `window_uids` has an entry in `self.openskill_ratings` before building `teams` avoids model errors and keeps rating updates stable across windows. This change looks good.
1337-1339: Commitment refresh + peers/reserve broadcast matches miner and comms usage

Refreshing `self.comms.commitments` on the master after `update_peers`, then broadcasting `[peers, reserve_peers, commitments]` to all ranks, keeps gather configuration aligned across the validator group and mirrors the miner's behavior. Logging on non-master ranks provides helpful visibility. No issues here.

Also applies to: 1350-1369
1411-1437: Timestamp broadcast for gather time window fixes rank divergence

Having only the master call `query_block_timestamp()`, falling back to `time.time()` if needed, and then broadcasting `ts_value` to all ranks ensures `time_min`/`time_max` are identical across the distributed gather. This addresses the earlier risk of per-rank time-window skew. Looks good.

Also applies to: 1441-1454
3197-3205: Robust handling of missing validator debug dict

Switching to `self.comms.get_with_retry(...)` and explicitly handling `None` / unsuccessful results gives clearer behavior when the miner's debug payload is missing or unreachable and avoids attribute errors. The warning log that includes `getattr(debug_result, "error", "unknown")` is also safe for the `None` case. This looks correct.
Force-pushed from 20fec9f to 66f8365.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
neurons/trainer.py (1)
872-872: Remove duplicate gradient clipping before unscaling.

This line clips gradients before `scaler.unscale_()` is called (line 886), which is incorrect for mixed precision training. When using `GradScaler`, gradients must be unscaled before clipping; otherwise, you're clipping scaled gradients rather than true gradients. Additionally, lines 887-889 correctly clip gradients after unscaling, making this line redundant. Impact: gradients are clipped twice in the normal training path, which incorrectly reduces gradient magnitudes and affects convergence.

Apply this diff to remove the duplicate clipping:

```diff
-        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
         if not null_round:
```

src/tplr/comms.py (1)
1665-1736: Early NOT_FOUND cutoff may still violate timeout semantics for transient errors.

The past review comment correctly identified that `max_consecutive_not_found = 5` (≈2.5s) causes early exit independent of the `timeout` parameter (default 30s). The current implementation still has this issue:
- Line 1600 converts all `None` returns from `s3_get_object` to `status="NOT_FOUND"` without distinguishing the cause
- `s3_get_object` returns `None` for both true 404s (line 658) AND asyncio timeouts (line 653)
- The early cutoff at 2.5s will abort downloads that are experiencing slow/flaky connections but should continue retrying until the caller's timeout expires

Suggested fix, either:
- Distinguish true 404s from transient errors in the status (e.g., return `"TIMEOUT"` vs `"NOT_FOUND"`)
- Only increment `consecutive_not_found` for confirmed 404 responses
- Derive the cutoff threshold from the `timeout` parameter: `max(5, int(timeout / 1.0))`

This ensures transient network issues respect the caller's timeout window while still optimizing away persistent 404s.
🧹 Nitpick comments (4)
src/tplr/neurons.py (1)
332-332: Optional cleanup: Unused device resolution.

The `cuda_device` variable is computed but never used in the function body. The string `device` parameter is used throughout instead. If CUDA device resolution isn't needed here, consider removing this line:

```diff
- cuda_device = _resolve_cuda_device(device)
```

If it was intended for future use (e.g., for CUDA API calls), consider adding a comment explaining why it's computed but unused.
src/tplr/comms.py (1)
2156-2171: Device fallback logic handles edge cases but has a subtle issue.

The device selection correctly handles `RuntimeError` from `torch.cuda.current_device()` and falls back gracefully. However, lines 2165-2169 attempt to set device 0, which could fail if device 0 is unavailable:

```python
try:
    torch.cuda.set_device(0)
    target_device = "cuda:0"
except RuntimeError:
    target_device = "cpu"
```

This is a minor issue since it eventually falls back to CPU, but it could be cleaner to probe for the first available device index rather than assuming device 0 exists. That said, the current implementation is safe and will not crash.
neurons/validator.py (2)
1490-1490: Auto-detect distributed gather instead of hardcoding.

Line 1490 unconditionally sets `use_distributed_gather = True`, but the PR objectives specify "auto-detection to use distributed gather when world_size > 1." This reduces clarity and doesn't match the stated behavior.

Apply this diff to auto-detect based on actual distributed state:

```diff
- use_distributed_gather = True
+ use_distributed_gather = dist_helper.world_size > 1 and dist_helper.is_distributed()
```

This aligns with the PR objectives and makes the gather mode selection explicit and conditional on actual multi-GPU setup.
3790-3826: Function signature claims to accept list but implementation doesn't handle it.

The signature on line 3792 declares `gather_result: SimpleNamespace | list[SimpleNamespace]`, but the implementation at line 3805 directly accesses `gather_result.state_dict` without checking if it's a list first. If a list were passed, this would raise `AttributeError`. While this isn't a runtime bug (line 1520 ensures `return_partials=False` so only merged results are passed), it violates the interface contract and could mislead future developers.

Choose one of these options:

Option 1: Remove list from signature (simpler, matches actual usage):

```diff
 def compute_peer_val_norms(
     self,
-    gather_result: SimpleNamespace | list[SimpleNamespace],
+    gather_result: SimpleNamespace,
 ) -> dict[str, torch.Tensor]:
     """
     Compute median L2 norms across all peers for each parameter.

     Args:
-        gather_result: Merged gather result SimpleNamespace
+        gather_result: Merged gather result (must be SimpleNamespace, not partials list)
```

Option 2: Add list handling (if partials support is needed):

```diff
 def compute_peer_val_norms(
     self,
     gather_result: SimpleNamespace | list[SimpleNamespace],
 ) -> dict[str, torch.Tensor]:
     """
     Compute median L2 norms across all peers for each parameter.

     Args:
-        gather_result: Merged gather result SimpleNamespace
+        gather_result: Either merged result (SimpleNamespace) or list of partials

     Returns:
         Dictionary mapping parameter names to median norms
     """
     clip_norm_dict = {}
+    # If partials, merge first
+    if isinstance(gather_result, list):
+        gather_result = self.comms.merge_gather_results(gather_result)
     state_dict = gather_result.state_dict
```

Option 1 is recommended since the code always uses merged results.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
neurons/base_node.py (1 hunk), neurons/miner.py (9 hunks), neurons/trainer.py (3 hunks), neurons/validator.py (12 hunks), src/tplr/comms.py (19 hunks), src/tplr/compress.py (1 hunk), src/tplr/neurons.py (10 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- src/tplr/compress.py
- neurons/base_node.py
🧰 Additional context used
🧬 Code graph analysis (4)
neurons/miner.py (3)
- src/tplr/chain.py (1): get_commitments (264-322)
- src/tplr/distributed.py (1): is_distributed (75-77)
- src/tplr/comms.py (1): gather_with_reserve (2395-2688)

src/tplr/comms.py (3)
- src/tplr/schemas.py (1): Bucket (24-46)
- src/tplr/logging.py (2): log_with_context (207-223), log_with_context (290-309)
- src/tplr/distributed.py (2): is_distributed (75-77), gather_object (284-305)

neurons/validator.py (6)
- tests/test_state_loading.py (1): rating (59-62)
- src/tplr/chain.py (1): get_commitments (264-322)
- src/tplr/distributed.py (2): is_distributed (75-77), broadcast (273-282)
- src/tplr/logging.py (2): log_with_context (207-223), log_with_context (290-309)
- src/tplr/comms.py (2): gather_with_reserve (2395-2688), get_with_retry (1625-1740)
- src/tplr/neurons.py (1): check_uid_index_overlap (1275-1480)

src/tplr/neurons.py (1)
- src/tplr/compress.py (4): compress (373-377), compress (379-383), compress (386-423), unpack_12bit_indices (102-142)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.11)
- GitHub Check: test (3.12)
🔇 Additional comments (22)
neurons/trainer.py (3)
887-889: LGTM: Correct gradient clipping and norm capture.

This correctly captures the gradient norm after unscaling and before the optimizer step, following the proper pattern for mixed precision training with `GradScaler`.
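The approved ordering can be shown end to end with a minimal, self-contained example (the model and loss below are placeholders, not the trainer's code; `enabled=torch.cuda.is_available()` lets it run on CPU as a no-op scaler):

```python
import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

loss = model(torch.randn(8, 4)).pow(2).mean()
scaler.scale(loss).backward()   # backward on the scaled loss
scaler.unscale_(opt)            # gradients now hold their true magnitudes
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # clip *true* grads
scaler.step(opt)                # skipped automatically on inf/nan gradients
scaler.update()

# After clipping, the total gradient norm is at most the clip threshold.
total = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))
```

Clipping before `unscale_()` would operate on scaled gradients, which is exactly the bug flagged at line 872.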
911-911: LGTM: Appropriate default for spin-up phase.

Setting `total_grad_norm` to 0.0 during the spin-up phase (when no optimizer step is taken) provides a consistent value for logging and is appropriate for this branch.
928-928: LGTM: Useful diagnostic logging.

Adding gradient norm to the inner step logging provides valuable insight into training dynamics and helps monitor for gradient explosion or vanishing issues.
src/tplr/neurons.py (5)
48-69: LGTM: Well-designed device resolution helper.

The function properly handles all edge cases (None device, string conversion, no CUDA devices) with appropriate fallbacks. The defensive error handling ensures robustness.

280-294: LGTM: Clean delegation pattern.

The function correctly delegates to `_outer_step_single()` with all parameters forwarded, maintaining a clean separation of concerns.
1277-1315: LGTM: Well-structured partial vs. merged handling.

The function properly extracts UIDs and state_dicts from both partial and merged gather results. The `uid_to_partial` mapping at line 1296 provides O(1) lookup, successfully addressing the O(n×m) complexity concern from previous reviews.
1354-1395: LGTM: Correct partial/merged indexing with defensive checks.

The code properly handles both partial and merged cases:
- Partials use the O(1) `uid_to_partial` lookup (lines 1361-1365)
- Merged results use direct list indexing (lines 1381-1388)
- Defensive `continue` at line 1362 handles missing UIDs gracefully

The optimization successfully reduces complexity from O(n×m) to O(n).

1408-1424: LGTM: Correct pair key construction.

The code properly uses `valid_uids_for_param` to map back to original UID indices (lines 1420-1422), ensuring pair keys remain consistent whether using partials or merged results.

neurons/miner.py (7)
32-34: LGTM: Correct CUDA allocator configuration.

The environment variable is properly set before importing torch (required for it to take effect). The conditional check prevents overwriting any existing configuration.

474-492: LGTM: Complete peer synchronization across ranks.

The broadcast properly synchronizes peers, reserve_peers, and commitments from master to all ranks. This ensures consistent gather behavior across distributed ranks and addresses the previous review comment about broadcasting reserve peers. The payload structure (3-item list) and unpacking logic are correct.
554-561: LGTM: Thorough pre-training memory cleanup.

The aggressive cleanup (cache clear, synchronize, gradient clear) helps prevent OOM at the start of training. The defensive check at line 560 safely handles the model attribute.
740-816: LGTM: Correct distributed/sequential gather orchestration.

The gather mode selection logic properly handles both scenarios:

Distributed mode (lines 741-743, 759):
- All ranks participate in collective operations
- Master checks result, non-master ranks defer (lines 806-810)

Sequential mode (lines 752-755, 759):
- Only master performs gather

Consistent configuration (lines 763-774):
- `return_partials=False`: Always returns merged result, aligning with `outer_step` expectations
- `reserve_uids=self.comms.reserve_peers`: Matches validator behavior
- `local=False`: Correctly fetches from remote storage

The final broadcast at line 815 ensures all ranks agree on whether to update.
840-842: LGTM: Proper post-outer-step gradient cleanup.

The defensive check and `set_to_none=True` ensure memory is efficiently released after the outer step.
895-901: LGTM: Correct metrics extraction from merged gather result.

Since `return_partials=False` is set at line 764, `gather_result` is always a merged `SimpleNamespace` or `None` (never a list). The attribute accesses (`.uids`, `.skipped_uids`, `.success_rate`) are correct for this format, and the defensive `None` check at lines 938-943 prevents errors when no gradients are gathered.

Also applies to: 937-943

1024-1035: LGTM: Comprehensive post-window memory cleanup.

The cleanup thoroughly releases gather result memory (lines 1028-1030) with defensive attribute checks, followed by garbage collection and CUDA cache clearing. This aggressive memory management helps prevent OOM issues across windows, particularly important for 70B-model scenarios mentioned in the PR objectives.
src/tplr/comms.py (7)
142-144: Verify the partition cache key includes all relevant factors.

The cache key `(window, tuple(sorted_uids), world_size)` correctly captures the inputs that determine partitioning. However, consider whether `my_uid` or `rank` should also be part of the key if they affect partition assignment in the future. Currently, the cache stores all partitions for all ranks (as a list), so this is correct.
483-582: LGTM! Conservative error handling prevents false slashing.

The structured error inspection correctly distinguishes between confirmed 404s (returns False), permission/configuration errors (returns False with ERROR logging to surface setup issues), and transient network/service errors (returns True conservatively). This approach aligns with the goal of avoiding false slashing due to temporary network issues. The fallback to string matching provides backward compatibility while the structured checks take precedence.
709-723: Critical fix: Offloading torch.load enables true parallel downloads.

The `asyncio.to_thread` wrapper correctly moves CPU-bound decompression off the event loop, allowing parallel gradient downloads to proceed concurrently instead of being serialized during decompression. The 1-2s decompression time per file would otherwise be a major bottleneck. The exception handling at lines 719-723 properly catches and logs load errors without crashing the download process.
1471-1492: LGTM! Clean addition of optional version parameter.

The `version` parameter defaults to `None` and falls back to `tplr.__version__`, preserving backward compatibility. The implementation is straightforward and allows querying gradient timestamps across different versions when needed.
2072-2095: LGTM! Deterministic round-robin partitioning ensures consistency.

The `partition_uids_across_ranks` method correctly:
- Sorts UIDs for deterministic ordering across all ranks
- Uses round-robin assignment: rank i gets indices i, i+world_size, i+2*world_size, ...
- Returns all UIDs when world_size <= 1 (single-rank case)

This ensures all ranks see the same partition mapping and no UID is assigned to multiple ranks.
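The round-robin scheme described above reduces to a sort plus stride slicing (a standalone sketch; the real method lives on the comms object):

```python
def partition_round_robin(uids: list[int], world_size: int) -> list[list[int]]:
    """Deterministic: every rank computes the same partitioning."""
    if world_size <= 1:
        return [sorted(uids)]  # single-rank case: one rank owns everything
    ordered = sorted(uids)
    # Rank r takes indices r, r + world_size, r + 2*world_size, ...
    return [ordered[rank::world_size] for rank in range(world_size)]


parts = partition_round_robin([7, 3, 5, 1, 9], 2)
# ordered = [1, 3, 5, 7, 9] -> rank 0 gets [1, 5, 9], rank 1 gets [3, 7]
```

Because the slices are disjoint and together cover every index, no UID is fetched twice and none is dropped, which is the invariant the review checks.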
2500-2551: Verify type consistency when wrapping sequential fallback results.

The code wraps sequential gather results in a list when `return_partials=True`:
- Line 2519: `primary = [sequential_result] if return_partials else sequential_result`
- Line 2531: `primary = [sequential_result] if return_partials else sequential_result`
- Line 2551: `primary = [empty_result] if return_partials else empty_result`

However, there's a potential issue: when the distributed gather succeeds but returns a merged result instead of partials (due to the bug in `gather_distributed`), the caller will receive a `SimpleNamespace` even though they requested `return_partials=True`. This inconsistency could cause type errors downstream. Once the `gather_distributed` bug is fixed (see previous comment), verify that all code paths consistently return the correct type based on `return_partials`.
2584-2643: Reserve peer logic correctly handles both partials and merged results.

The reserve peer retry logic properly branches based on whether `primary` is a list (partials) or SimpleNamespace (merged):
- Partials path (lines 2614-2626): Appends the fallback as a new partial entry, assigns it a rank index, and logs appropriately
- Merged path (lines 2627-2643): Merges the fallback's state_dict tensors, UIDs, and metrics into the primary result

This dual handling ensures the reserve mechanism works correctly regardless of the return format.
Force-pushed from 66f8365 to fdce702.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/tplr/comms.py (1)
1665-1740: Early NOT_FOUND cutoff still violates timeout semantics.

The consecutive NOT_FOUND counter (lines 1669-1670, 1725-1736) exits after 5 occurrences (~2.5s), but `get()` at line 1600 returns `status="NOT_FOUND"` for ALL `None` returns from `s3_get_object`, including transient `asyncio.TimeoutError` (line 653). This causes premature exit for slow/flaky peers long before the caller's `timeout` (default 30s) expires, breaking the intended retry semantics.

Fix: Either:
- Make `s3_get_object` distinguish true 404s from transient errors in its return value, or
- Derive `max_consecutive_not_found` from `timeout` (e.g., `int(timeout / 0.5)`) so it scales with the caller's intent, or
- Only apply the early exit after `time_max` has been exceeded (when we know the file should exist)

Based on learnings
🧹 Nitpick comments (3)
neurons/validator.py (2)
1490-1506: Hardcoded flag contradicts auto-detection intent.

Line 1490 sets `use_distributed_gather = True` unconditionally, but the PR objectives state "Auto-detection enables distributed gather when world_size > 1." The actual distributed vs. sequential decision happens inside `gather_with_reserve` (src/tplr/comms.py:2424 checks `dist_helper.world_size > 1`), so the runtime behavior is correct. However, the hardcoded flag here makes the logging and control flow misleading.

Consider aligning with the actual auto-detection:

```diff
- use_distributed_gather = True
+ use_distributed_gather = dist_helper.world_size > 1 and dist_helper.is_distributed()
```

This makes the logging accurately reflect whether distributed or sequential gather will actually be used.
3790-3826: Type signature broader than implementation.

The function signature declares `gather_result: SimpleNamespace | list[SimpleNamespace]`, but the implementation at line 3805 directly accesses `gather_result.state_dict`, which only works for `SimpleNamespace`, not `list[SimpleNamespace]`. There's no `isinstance()` check or branching to handle the list case. In practice, this isn't a runtime bug because `gather_result` is always a merged `SimpleNamespace` (see line 1520: `return_partials=False` with comment "Always return merged result"). However, the type hint misleadingly suggests the function can handle both forms.

Consider narrowing the type hint to match actual usage:

```diff
 def compute_peer_val_norms(
     self,
-    gather_result: SimpleNamespace | list[SimpleNamespace],
+    gather_result: SimpleNamespace,
 ) -> dict[str, torch.Tensor]:
```

Alternatively, if future flexibility is desired, add explicit handling:

```python
if isinstance(gather_result, list):
    raise NotImplementedError("Partials path not needed since return_partials=False")
```

src/tplr/comms.py (1)
2133-2257: LGTM: Comprehensive merge with robust device handling.

The merge correctly:
- Concatenates per-rank tensor lists for each parameter
- Aggregates metrics (max time for parallel, sum bytes, extend UIDs)
- Calculates overall success rate

Device fallback logic gracefully handles edge cases with nested try/except, though it assumes device 0 is available when multiple devices exist. For systems where device 0 is missing but others exist, this will fall back to CPU, which is safe but suboptimal. Consider iterating through `range(torch.cuda.device_count())` to find the first working device if this becomes an issue.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
neurons/base_node.py (1 hunk), neurons/miner.py (9 hunks), neurons/trainer.py (3 hunks), neurons/validator.py (12 hunks), src/tplr/comms.py (19 hunks), src/tplr/compress.py (1 hunk), src/tplr/neurons.py (10 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- neurons/trainer.py
- src/tplr/compress.py
🧰 Additional context used
🧬 Code graph analysis (4)
src/tplr/neurons.py (1)
- src/tplr/compress.py (5): compress (373-377), compress (379-383), compress (386-423), TopKCompressor (295-697), unpack_12bit_indices (102-142)

neurons/miner.py (3)
- src/tplr/chain.py (1): get_commitments (264-322)
- src/tplr/distributed.py (1): is_distributed (75-77)
- src/tplr/comms.py (1): gather_with_reserve (2401-2694)

neurons/validator.py (5)
- tests/test_state_loading.py (1): rating (59-62)
- src/tplr/chain.py (1): get_commitments (264-322)
- src/tplr/distributed.py (2): is_distributed (75-77), broadcast (273-282)
- src/tplr/logging.py (2): log_with_context (207-223), log_with_context (290-309)
- src/tplr/comms.py (2): gather_with_reserve (2401-2694), get_with_retry (1625-1740)
- src/tplr/neurons.py (1): check_uid_index_overlap (1275-1480)

src/tplr/comms.py (3)
- src/tplr/schemas.py (1): Bucket (24-46)
- src/tplr/logging.py (2): log_with_context (207-223), log_with_context (290-309)
- src/tplr/distributed.py (2): is_distributed (75-77), gather_object (284-305)
- GitHub Check: test (3.12)
- GitHub Check: test (3.11)
🔇 Additional comments (28)
neurons/base_node.py (1)
61-62: LGTM! MRO chain properly initialized.

The `super().__init__()` call correctly continues the Method Resolution Order chain, ensuring parent classes (e.g., `Trainer`) are properly initialized. This resolves the previously flagged issue where `expected_compressed_params` and other parent class attributes were not being initialized.

src/tplr/neurons.py (4)
48-69: LGTM! Robust CUDA device resolution.

The `_resolve_cuda_device` helper properly handles all edge cases: CUDA unavailability, None inputs, invalid device strings, and type coercion. This addresses the previously flagged issue where CUDA APIs were called with string devices instead of `torch.device` objects.

250-294: LGTM! Simplified outer_step delegation.

The refactor simplifies `outer_step` to pure delegation while maintaining the single-path behavior requested in previous reviews. The updated docstring clearly documents that distributed gather merges partials internally, so callers always receive a single merged result.

297-551: LGTM! Device resolution properly integrated.

The `_outer_step_single` implementation correctly uses `_resolve_cuda_device` at line 332, ensuring CUDA operations receive proper `torch.device` objects instead of strings. The rest of the memory-minimizing outer step logic remains sound.

1275-1480: LGTM! Efficient dual-path overlap checking.

The updated `check_uid_index_overlap` correctly supports both merged results and partial results. The UID-to-partial mapping (lines 1296-1306) provides O(1) lookup, reducing complexity from O(n×m) to O(n) as requested in previous reviews. The unpacking logic (lines 1358-1388) properly handles both paths, and pair keys use original UID indices for consistency.

neurons/miner.py (9)
32-34: LGTM! Early CUDA allocator configuration.

Setting `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` before importing torch is the correct approach for memory optimization. The conditional check respects existing user environment settings.
474-492: LGTM! Complete peer state synchronization.

The broadcast now correctly includes `peers`, `reserve_peers`, and `commitments`, ensuring all ranks have consistent gather state. This addresses the previous review comment about missing `reserve_peers` in the broadcast payload.

554-561: LGTM! Proactive memory management.

The aggressive memory cleanup (cache clearing, synchronization, and gradient zeroing) before training helps prevent OOM issues, especially important for distributed training with large models mentioned in the PR objectives (70B-model scenarios).
740-812: LGTM! Distributed gather properly integrated.

The distributed gather implementation correctly:
- Auto-detects when to use distributed mode based on `world_size` and `dist_helper.is_distributed()` (lines 741-743)
- Includes `reserve_uids=self.comms.reserve_peers` as requested in previous reviews (line 763)
- Sets `local=False` to pull from remote storage (line 769), fixing the previously flagged issue
- Always returns merged results (`return_partials=False` at line 764), consistent with the `outer_step` docstring
- Provides comprehensive logging for both distributed and sequential paths (lines 778-802)
- Correctly handles master/non-master decision logic for `should_update` (lines 806-812)
840-842: LGTM! Explicit gradient cleanup after outer step.

The gradient clearing with `set_to_none=True` after the outer step ensures no stale gradients remain, and is more memory-efficient than zeroing. The defensive `hasattr` check is good practice.
895-901: LGTM! Correct attribute access for merged results.

Direct access to `gather_result.uids` and `gather_result.skipped_uids` is correct since `gather_with_reserve` is called with `return_partials=False` at line 764, ensuring a merged `SimpleNamespace` is always returned. This addresses the previously flagged incompatible attribute access issue.

937-943: LGTM! Safe metric extraction with None handling.

The success rate and skipped UIDs extraction properly handles the `None` case and correctly accesses attributes from the merged result. The `success_rate * 100` conversion to percentage is appropriate for logging.
1024-1033: LGTM! Aggressive memory cleanup strategy.

The enhanced cleanup explicitly clears `state_dict` contents (lines 1027-1030), deletes objects, forces garbage collection, and empties CUDA cache. This aggressive approach aligns with the PR objective of avoiding GPU OOM in 70B-model scenarios. The defensive `hasattr` checks prevent errors if structures are already cleaned up.

774-774: `expected_compressed_params` is properly initialized in parent class.

The verification confirms that `Trainer.__init__()` at line 123 sets `self.expected_compressed_params = self.get_expected_params()`. With the `super().__init__()` fix in `BaseNode`, the MRO chain ensures this attribute is initialized before line 774 accesses it. No action needed.

neurons/validator.py (5)
931-937: LGTM! OpenSkill rating initialization prevents KeyError.

Properly initializes ratings for new UIDs before the rating operation. This ensures all peers in `window_uids` have valid OpenSkill rating objects before `self.openskill_model.rate()` is called.
1337-1368: LGTM! Peers and commitments broadcast enables distributed gather across ranks.

The master rank fetches updated commitments and broadcasts peer lists plus bucket commitments to all ranks, enabling them to participate in distributed gather. The second commitment fetch at line 1338 (after the initial fetch at line 1171) ensures fresh bucket information immediately before gathering.
1411-1438: LGTM! Timestamp broadcast ensures all ranks use a consistent time window.

Correctly addresses the previous critical review concern about potential rank inconsistency. The master rank queries the timestamp (with fallback to `time.time()` on failure), then broadcasts it to all ranks via `dist_helper.broadcast()`. This guarantees all ranks use identical `time_min` and `time_max` values for the distributed gather operation.
1607-1643: LGTM! Simplified gather result handling eliminates duplicate merges.

Since `return_partials=False` (line 1520), `gather_result` is always a merged `SimpleNamespace`. The code correctly uses it directly for upload (line 1608), overlap checking (line 1614), and metric extraction (lines 1621, 1642) without redundant merging operations. This addresses the memory overhead concerns from previous reviews.
3197-3217: LGTM! get_with_retry improves reliability with better timeout and retry logic.

The change from `self.comms.get()` to `self.comms.get_with_retry()` adds robust retry mechanisms including exponential backoff, consecutive NOT_FOUND tracking, and graceful timeout handling (returns `None` instead of raising). The error handling correctly checks for both `None` (timeout) and unsuccessful results before proceeding.

src/tplr/comms.py (9)
20-20: LGTM: Clean imports and cache initialization for distributed gather.

The additions are well-structured:
- The `gc` import supports aggressive memory cleanup for large models
- `OrderedDict` enables LRU-like partition caching
- The `dist_helper` import provides distributed primitives
- The `_partition_cache` type annotation correctly captures the cache structure

Also applies to: 29-29, 50-50, 142-144
483-582: LGTM: Robust existence check with structured error handling.

The implementation properly distinguishes between different error types using structured `ClientError.response` inspection rather than brittle string matching. The conservative approach (returning True for transient errors) is appropriate to avoid false negatives that could lead to incorrect slashing decisions. The fallback to string matching provides backward compatibility while the structured checks take precedence.
667-668: LGTM: Critical fix for parallel download performance.

Offloading `torch.load()` to the thread pool prevents blocking the event loop during CPU-intensive decompression (1-2s per file). This enables true parallel downloads. The 500MB threshold increase is appropriate for 70B model gradients.

Also applies to: 709-723
1471-1493: LGTM: Clean version parameter with sensible default.

The optional `version` parameter with `tplr.__version__` as the default is well-designed and properly documented. This enables checking gradients from different versions while maintaining backward compatibility.
1563-1566: LGTM: Consistent thread pool offloading for local loads.

Applying the same `asyncio.to_thread` fix to local file loading ensures the event loop isn't blocked during decompression regardless of storage location.
2071-2095: LGTM: Clean deterministic round-robin partitioning.

The implementation correctly sorts UIDs for determinism and uses idiomatic Python slicing `[rank::world_size]` for round-robin assignment. Edge case handling for single-rank execution is appropriate.
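As a hedged illustration of the round-robin scheme described above (the function name is illustrative, not the real `Comms` API):

```python
def partition_uids(uids, rank, world_size):
    """Deterministic round-robin partition: sort, then stride-slice per rank."""
    sorted_uids = sorted(uids)
    if world_size <= 1:
        return sorted_uids  # single-rank edge case: one rank takes everything
    return sorted_uids[rank::world_size]
```

Because every rank sorts the same UID set, the union of the per-rank slices covers each UID exactly once, with no coordination needed between ranks.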
2097-2131: LGTM: Robust recursive device movement.

The recursive `move_to_device` helper correctly handles nested structures (lists, tuples, dicts) and tensors. In-place modification of `result.state_dict` is acceptable since this is typically called during final result processing.
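A minimal sketch of such a recursive mover; this is an assumption about the helper's shape, not the actual implementation, and `FakeTensor` stands in for `torch.Tensor` so the sketch stays self-contained:

```python
class FakeTensor:
    """Stand-in for torch.Tensor, just enough to demo the recursion."""
    def __init__(self, device="cpu"):
        self.device = device
    def to(self, device):
        return FakeTensor(device)

def move_to_device(obj, device):
    """Recursively move tensor-like leaves (anything with .to) to `device`."""
    if hasattr(obj, "to"):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        moved = [move_to_device(v, device) for v in obj]
        return tuple(moved) if isinstance(obj, tuple) else moved
    return obj  # other leaves (ints, strings, None) pass through unchanged
```

The key property is that container structure is preserved while only tensor leaves are rewritten, which is what lets a gather result's nested `state_dict` be relocated in one call.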
2259-2396: LGTM: Well-structured distributed gather with proper return_partials support.

The implementation correctly:
- Caches partition patterns per window with LRU eviction
- Gathers partials to master only (memory-efficient)
- Returns a normalized partial list when `return_partials=True` (lines 2365-2370)
- Merges and returns a consolidated result when `return_partials=False` (lines 2372-2389)

Non-master ranks receiving `None` is expected with the gather-to-master approach and must be handled by callers via the `all_ok` broadcast pattern.
2401-2694: LGTM: Comprehensive gather with proper distributed fallback and type consistency.

The refactor correctly:
- Auto-detects when to use distributed gather (line 2442)
- Falls back to sequential gather when distributed fails (lines 2506-2525)
- Maintains type consistency: wraps sequential/empty results in a list when `return_partials=True` (lines 2525, 2537, 2557)
- Handles reserve retry by either appending as a new partial (lines 2620-2632) or merging into the result (lines 2633-2649)
- Computes aggregate success rate for partials (lines 2654-2675)
- Returns `None` when no gradients are gathered (lines 2678-2679, 2691-2692)

The sequential fallback for reserve retry (line 2612) is appropriate since it's typically a small set of replacement UIDs.
Core implementation:
- Add gather_distributed() for parallel gradient downloads across ranks
- Partition UIDs deterministically across GPU ranks
- Each rank downloads subset in parallel, then all-gather to merge
- CPU/GPU memory management to prevent OOM on large models
- Fallback to sequential gather if distributed fails
- Auto-detect based on world_size (no manual config needed)

Miners:
- Use distributed gather when world_size > 1
- Parallel gradient downloads from peer buckets
- Use reserve peers for fallback (same as validators)
- Proper synchronization across all ranks

Validators:
- Support distributed gather for multi-GPU validators
- Improved error handling with get_with_retry
- Better timeout handling (150s for 70B models)
- Test mode support for local networks

Code quality:
- Removed dead code paths (~450 lines)
- Simplified outer_step (single path)
- Removed unused parameters
- Clean gradient averaging via scatter_reduce(reduce="mean")

Net: -339 lines while adding full distributed gather support
Force-pushed from fdce702 to 89caddd
Actionable comments posted: 1
♻️ Duplicate comments (2)
src/tplr/comms.py (2)
2425-2718: gather_with_reserve distributed path is mostly sound but master-only fallback still risks divergence

The overall structure of `gather_with_reserve` is solid:
- Decides `use_distributed` via `dist_helper` and logs entry with rank and UID sets.
- In distributed mode, calls `gather_distributed` with `return_partials` propagated.
- Normalizes "no gradients" on master to trigger a sequential fallback, while treating `None` on non-masters as expected.
- Wraps sequential/fallback results into a list when `return_partials=True` to maintain type consistency.
- For reserve retries, appends a new partial in partials mode or merges into the merged namespace.
- Recomputes `success_rate` against the original gather tier and emits detailed context logs.

However, in the distributed branch the sequential fallback still runs only on the master rank:

```python
if primary is None and dist_helper.is_master:
    sequential_result = await self.gather(...)
    primary = [sequential_result] if return_partials else sequential_result
```

Non-master ranks keep `primary=None` and immediately return `None` later in the function. This duplicates a prior concern: if callers expect all ranks to see the same post-fallback gradients (e.g. miner/validator outer steps running on every rank), you may get cross-rank inconsistency unless:
- `outer_step` (or equivalent) is explicitly guarded to only use the result on master and broadcast updated model state, or
- the master's `primary` is broadcast to other ranks after the fallback.

Please verify downstream usage (e.g., in `neurons/miner.py` and validator code) to ensure either master-only update semantics are intentional and correctly implemented, or adjust this fallback to synchronize results.
1668-1671: Fixed NOT_FOUND cutoff still short-circuits retries relative to timeout

The `get_with_retry` loop still gives up after `max_consecutive_not_found = 5` with 0.5 s sleeps, i.e. ≈2.5 s total, regardless of the `timeout` argument. Because `get()` currently maps all "no data" conditions (404, timeouts, various network/client errors that cause `s3_get_object` to return `None`) to `status="NOT_FOUND"`, slow or flaky peers can be treated as "likely doesn't exist" well before the caller's timeout budget is used.

This duplicates an earlier concern: either `get()`/`s3_get_object` should distinguish true 404s from transient failures, or the NOT_FOUND cutoff should be derived from `timeout` and/or only applied once we're clearly past `time_max`.

Also applies to: 1725-1737
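One hedged way to implement the "derive the cutoff from `timeout`" suggestion (the function and parameter names here are made up for illustration):

```python
def not_found_cutoff(timeout_s: float, sleep_s: float = 0.5, floor: int = 5) -> int:
    """How many consecutive NOT_FOUNDs to tolerate before giving up.

    Scales with the caller's timeout budget so slow peers are not declared
    missing after a fixed ~2.5 s, while never dropping below the old floor.
    """
    return max(floor, int(timeout_s / sleep_s))
```

With a 150 s timeout this allows up to 300 consecutive NOT_FOUNDs, so the retry loop can actually consume the budget it was given.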
🧹 Nitpick comments (8)
neurons/validator.py (3)
1337-1363: Peers/commitments broadcast keeps ranks aligned but may be heavier than needed

Refreshing `self.comms.commitments` on the master each window and broadcasting `[peers, reserve_peers, commitments]` via `dist.broadcast_object_list` guarantees all ranks share an identical view before gather, which is important for distributed partitioning.

Two minor points to consider:
- `get_commitments()` is called after `update_peers_with_buckets()`, so the freshly fetched commitments only take effect on the next window; if you want current-window buckets, you could invert the order.
- Broadcasting the full `commitments` dict every window can become expensive as the network scales; if this shows up in profiling, consider broadcasting a lighter representation (e.g., UIDs plus a version or hash) and letting non-master ranks lazily refresh from storage when the version changes.
1490-1524: Gather mode flag is hard-coded; sequential-branch comments are now misleading

`use_distributed_gather` is set to `True` unconditionally, so:
- The "sequential gather" logging branch (`elif self.is_master`) is never taken.
- All ranks always call `self.comms.gather_with_reserve(...)`, relying on Comms' internal logic (world_size / `is_distributed()`) to decide whether the actual implementation is distributed or sequential.

This is functionally fine (single-GPU runs still behave sequentially, and multi-GPU runs use the distributed path), but the flag no longer conveys real mode information and the surrounding comments become confusing.

Consider either:
- Deriving the flag from the actual distributed condition, e.g. `use_distributed_gather = dist_helper.is_distributed()`, and keeping the dual logging branches; or
- Dropping `use_distributed_gather` entirely and simplifying the conditionals/log messages around the gather to reflect that Comms auto-selects the mode.

Also applies to: 1543-1560
3790-3826: compute_peer_val_norms signature suggests list support but implementation only handles merged results

The signature and type hint allow `gather_result: SimpleNamespace | list[SimpleNamespace]`, but the implementation assumes a single merged result:
- It immediately does `state_dict = gather_result.state_dict` and uses `getattr(state_dict, vals_key, ...)`, which will fail for a `list` input.
- The docstring also still describes `gather_result` as "Merged gather result SimpleNamespace".

Given the current call site in this file always passes a merged result (because `gather_with_reserve` is invoked with `return_partials=False`), this isn't a runtime bug today, but it is misleading for future callers.

I'd either:
- Narrow the type to `SimpleNamespace` in both signature and docstring; or
- Add an explicit `isinstance(gather_result, list)` branch that merges partials or forwards them through a compatible path before computing norms.

src/tplr/neurons.py (2)
48-69: _resolve_cuda_device result is computed but never used

`_resolve_cuda_device` is a reasonable helper for normalizing CUDA devices, but in `_outer_step_single` you assign:

```python
cuda_device = _resolve_cuda_device(device)
```

and then never use `cuda_device`: all downstream calls still use the original `device` argument.

To keep `outer_step` lean and avoid confusion, consider either:
- Wiring `cuda_device` into the places that actually need a CUDA `torch.device` (e.g., when constructing tensors meant for CUDA-only APIs), or
- Removing the helper and the assignment until you need it.

Also applies to: 332-333
969-1000: Avoid coupling catch-up to instance.loop; use the current event loop instead

In `catchup_with_aggregation_server`, the fallback path uses:

```python
ts_value = await instance.loop.run_in_executor(
    None, instance.query_block_timestamp, sync_block
)
```

This assumes every `Miner`/`Validator` has an `instance.loop` attribute set before catch-up runs. That's true for the current validator path, but it's a fragile coupling and easy to break if catch-up is ever invoked from a different context.

You can make this more robust and self-contained by using the running loop inside the coroutine:

```python
loop = asyncio.get_running_loop()
ts_value = await loop.run_in_executor(
    None, instance.query_block_timestamp, sync_block
)
```

This keeps behavior the same while removing the implicit dependency on `instance.loop`.

neurons/miner.py (1)
1024-1033: Final gather_result cleanup is safe but note reliance on state_dict.__dict__

Clearing `gather_result.state_dict.__dict__` and then deleting `gather_result` before `gc.collect()` and `torch.cuda.empty_cache()` is an effective way to drop large gradient payloads between windows. This block runs after all usages of `gather_result` in the loop, so lifecycle is correct.

Only caveat: this assumes `state_dict` is a `SimpleNamespace` (as produced by Comms). If future refactors change that type, this direct `__dict__` access may need adjustment.

src/tplr/comms.py (2)
2098-2131: _move_gather_result_to_device is straightforward and safe

The recursive `move_to_device` helper handles tensors nested in lists/tuples/dicts and rewrites `result.state_dict` in place. Given callers only use it on gather results with SimpleNamespace `state_dict`, this is an appropriate utility. Just be aware of potential extra device transfers if called redundantly.
3213-3247: s3_get_object_size remains simple but usable

`head_object` + `ContentLength` for object size, with 404 detection and purge on client errors, is a reasonable utility. If you later centralize ClientError handling (similar to `s3_object_exists`), this could share that logic, but current behavior is acceptable.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- neurons/base_node.py (1 hunks)
- neurons/miner.py (9 hunks)
- neurons/trainer.py (3 hunks)
- neurons/validator.py (12 hunks)
- src/tplr/comms.py (19 hunks)
- src/tplr/compress.py (1 hunks)
- src/tplr/neurons.py (10 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- src/tplr/compress.py
- neurons/base_node.py
- neurons/trainer.py
🧰 Additional context used
🧬 Code graph analysis (4)

src/tplr/neurons.py (1)
- src/tplr/compress.py (4): compress (373-377), compress (379-383), compress (386-423), unpack_12bit_indices (102-142)

neurons/validator.py (6)
- tests/test_state_loading.py (1): rating (59-62)
- src/tplr/chain.py (1): get_commitments (264-322)
- src/tplr/distributed.py (2): is_distributed (75-77), broadcast (273-282)
- src/tplr/logging.py (2): log_with_context (207-223), log_with_context (290-309)
- src/tplr/comms.py (2): gather_with_reserve (2425-2718), get_with_retry (1625-1740)
- src/tplr/neurons.py (1): check_uid_index_overlap (1275-1480)

neurons/miner.py (4)
- src/tplr/chain.py (1): get_commitments (264-322)
- src/tplr/distributed.py (1): is_distributed (75-77)
- src/tplr/comms.py (1): gather_with_reserve (2425-2718)
- src/tplr/dcp_checkpoint.py (2): state_dict (60-61), state_dict (82-83)

src/tplr/comms.py (3)
- src/tplr/schemas.py (1): Bucket (24-46)
- src/tplr/logging.py (2): log_with_context (207-223), log_with_context (290-309)
- src/tplr/distributed.py (2): is_distributed (75-77), gather_object (284-305)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: test (3.11)
- GitHub Check: test (3.12)
🔇 Additional comments (26)
neurons/validator.py (4)
931-937: OpenSkill rating initialization is robust and avoids missing-state edge cases

Eagerly instantiating `self.openskill_ratings[uid]` for all `window_uids` before calling `rate()` prevents lookups on missing ratings and keeps the update logic below straightforward. This block is correct and well-scoped.
1411-1438: Timestamp calculation and broadcast now correctly synchronize all ranks

Only the master queries `query_block_timestamp(sync_block)` and falls back to `time.time()` with logging, while all ranks receive the final `ts_value` via `dist_helper.broadcast` and derive `time_min`/`time_max` from that shared scalar. This removes the prior risk of rank-dependent time windows and looks correct.
1606-1643: Post-gather pipeline (upload → overlap → offload → norms) is consistent

On master, you:
- Upload the already-merged `gather_result` asynchronously.
- Run `check_uid_index_overlap(self, gather_result, ...)` while gradients are still on the original device.
- Then offload `gather_result.state_dict` to pinned CPU via `offload_gather_results`.
- Later, compute `clip_norm_dict = self.compute_peer_val_norms(gather_result)` (which will pull tensors back to the target device as needed).

This ordering keeps overlap detection on-device while still aggressively freeing GPU memory before evaluation and outer_step, and it matches the assumption that `gather_with_reserve(..., return_partials=False)` always returns a merged `SimpleNamespace`. The control flow here looks sound.

Also applies to: 1821-1825
3197-3217: Switching to get_with_retry makes sync-eval more resilient

Using `self.comms.get_with_retry` with an explicit timeout and handling the `None` case cleanly:
- Avoids tight retry loops inside `evaluate_miner_sync`.
- Ensures failures (timeout or non-success) result in a well-formed `{success: False, sync_score: 0.0}` payload, with a defensive `getattr(debug_result, "error", "unknown")` log message.

The control flow and error handling here look correct.
src/tplr/neurons.py (2)
249-295: Outer-step refactor keeps semantics while improving memory safety

The new `outer_step` delegating to `_outer_step_single`:
- Treats `gather_result` as a single merged SimpleNamespace (matching the updated validator path with `return_partials=False`).
- Builds and applies one parameter's dense grad at a time, calling `optimizer.step()` per param and immediately freeing `p.grad` plus intermediates, which is appropriate for tight memory budgets.
- Correctly coordinates DDP/DTensor:
  - Master (or single-rank) computes `src_sd` and dense grads.
  - `_bcast_flag` tells other ranks whether a param has an update.
  - Non-DTensor params use `dist.broadcast` into `p.grad`.
  - DTensor params use `distribute_tensor` with `src_data_rank=0`, then sanity-check local shards.

The broadcasting and collective calls are ordered identically on all ranks, so there's no obvious deadlock risk, and invalid gradients are guarded by finite checks before being applied.

Overall this refactor looks correct and aligns with the PR's "memory-minimizing outer step" goal.

Also applies to: 297-552
1290-1479: Partial-aware check_uid_index_overlap logic is consistent with merged-path semantics

The updated `check_uid_index_overlap` now:
- Accepts either a merged `SimpleNamespace` or a `list[SimpleNamespace]` of partials.
- For partials, flattens all `partial.uids` into a single `uids` list and builds a `uid_to_partial[uid] = (partial, idx_in_partial)` map so that per-UID state can be located in O(1) rather than scanning all partials.
- For each parameter:
  - Builds `unpacked_indices` only for UIDs that actually have `idxs`/`vals` present in their per-UID state_dict.
  - Tracks the original UID index via `valid_uids_for_param`, and uses those indices as keys into `pair_acc`, so the summary stage can still safely map back to `uids[i]`/`uids[j]` when computing offenders and logging min/max overlaps.

This preserves the original merged-result behavior while extending it cleanly to partials, and the index bookkeeping between `valid_uids_for_param`, `pair_acc`, and the final `uids[...]` lookups is internally consistent.

neurons/miner.py (7)
32-35: CUDA allocator env setup is in the right place

Setting `PYTORCH_CUDA_ALLOC_CONF` before importing `torch` is correct and avoids allocator reconfiguration warnings at runtime. No issues here.
474-493: Distributed broadcast keeps peers/reserves/commitments consistent across ranks

Broadcasting `self.comms.peers`, `self.comms.reserve_peers`, and `self.comms.commitments` from master ensures all ranks use identical gather inputs, which is critical for deterministic distributed gather. The `dist.broadcast_object_list` usage and post-assignment look correct.
554-562: Pre-training CUDA cleanup is aggressive but appropriate for OOM mitigation

Calling `torch.cuda.empty_cache()`, `torch.cuda.synchronize()`, and `model.zero_grad(set_to_none=True)` before `inner_steps` is a reasonable tradeoff to reduce VRAM fragmentation and stale gradients for large models. Given this code assumes CUDA, this is acceptable.
740-813: Distributed vs sequential gather selection and logging look consistent

Using `use_distributed_gather = dist_helper.world_size > 1 and dist_helper.is_distributed()` and then:
- Running `gather_with_reserve` on all ranks in distributed mode, and
- Logging separate paths for distributed vs sequential gathers,

is coherent with the comms-side `gather_with_reserve`/`gather_distributed` behavior. The `should_update` logic (master decides, broadcast via `all_ok`) is also sound.
840-843: Post-outer-step gradient clearing is correct

Zeroing gradients with `set_to_none=True` immediately after the outer step is good hygiene and helps reduce memory usage, especially with AMP and large models.
893-902: Debug peer bookkeeping aligns with gather_result structure

Using `gather_result.uids`/`gather_result.skipped_uids` and deriving:
- `successful_peers = peers - skipped_uids`
- `skipped_peers = skipped_uids`

matches the semantics of the merged gather result from `gather_with_reserve`. This should give accurate debug traces for which peers contributed gradients.
937-945: Gather success rate and skipped UID logging integrate cleanly into WandB/metrics

The computation of `gather_success_rate` and `skipped_uids_list` and their inclusion both in WandB metrics and Influx fields is consistent and safe when `gather_result` is `None` (falls back to 0 / empty list).

src/tplr/comms.py (13)
20-21: gc import is appropriate for explicit cleanup sites

Importing `gc` at module level matches the later explicit `gc.collect()` calls and is a reasonable addition for aggressive cleanup paths.
29-50: Ordering new imports (OrderedDict, dist_helper) is consistent with usage

Adding `OrderedDict` and `dist_helper` at module level is correct and matches where they are used (`_partition_cache`, distributed helpers). No issues here.
142-145: _partition_cache design is reasonable for per-window UID partitions

Using an `OrderedDict` keyed by `(window, tuple(sorted_uids), world_size)` to cache UID partitions and trimming it to 32 entries is a good balance between determinism and memory usage. This avoids recomputing round-robin partitions each window and is safe given small value sizes.
483-582: s3_object_exists implements robust, conservative existence checks

The new `s3_object_exists`:
- Uses `head_object` with a per-call timeout.
- Distinguishes 404 / `NoSuchKey`, 403 / access issues, and bucket misconfig separately.
- Treats 5xx and generic network errors as transient and conservatively returns `True`.

This is a solid improvement over brittle string matching and aligns with the intent to avoid false slashing on transient infrastructure issues.
666-693: Large-file download path and size-based timeout scaling are sensible

Raising the simple-download threshold to 500 MB and using `download_large_file` for larger objects, with `download_timeout = max(300, file_size / (25 * 1024 * 1024))`, is appropriate for huge gradient files. This should materially reduce timeouts on >GB objects while still bounding total wall-clock time.
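Reproducing the quoted timeout rule as a standalone function (the ~25 MiB/s figure is implied by the formula, taken here as an assumed throughput floor):

```python
def download_timeout(file_size_bytes: int) -> float:
    """Timeout in seconds: a 300 s floor, else file size at ~25 MiB/s."""
    return max(300.0, file_size_bytes / (25 * 1024 * 1024))
```

For example, a 10 GiB object gets 10 * 1024 / 25 = 409.6 s rather than the 300 s floor, so bigger gradients automatically get more time.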
699-723: Offloading torch.load to a worker thread is the right fix for event-loop blocking

Wrapping `torch.load` in `asyncio.to_thread` avoids blocking the asyncio event loop during CPU decompression, and `weights_only=True` is appropriate for these tensor-only payloads. The error logging path that falls back to `loaded_data = None` is also safe for callers.
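The pattern can be demonstrated without torch; here `blocking_load` is a stand-in for the CPU-bound `torch.load` call:

```python
import asyncio
import time

def blocking_load(path: str) -> dict:
    """Stand-in for torch.load: blocks its worker thread, not the event loop."""
    time.sleep(0.05)  # simulate decompression work
    return {"path": path}

async def load_many(paths: list[str]) -> list[dict]:
    # Each load runs in a worker thread via asyncio.to_thread, so the loads
    # overlap instead of serializing on the event loop.
    return await asyncio.gather(*(asyncio.to_thread(blocking_load, p) for p in paths))

results = asyncio.run(load_many(["a.pt", "b.pt", "c.pt"]))
```

`asyncio.gather` preserves input order, so downstream code can zip results back to the UIDs that requested them.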
1563-1566: Local torch.load offload mirrors remote behavior appropriately

Using `await asyncio.to_thread(torch.load, local_path, weights_only=True)` for local gradient/debug loads keeps local and remote paths consistent and prevents blocking the loop on large local files. This change looks good.
1844-2056: Gather response processing and validation improvements are solid

The reworked `gather` loop:
- Logs per-UID download errors and `None` responses clearly.
- Validates quantization params, packed indices, and value tensors without dequantizing (for memory efficiency).
- Enforces presence of all `expected_compressed_params` before accepting a UID.
- Accumulates metrics and builds a structured `SimpleNamespace` result.

This is a robust design for validating and aggregating compressed gradients.
2071-2096: partition_uids_across_ranks is deterministic and matches round-robin spec

Sorting UIDs and then using `sorted_uids[rank::world_size]` is a clear, deterministic round-robin partitioning scheme. This aligns with the documented behavior and is easy to reason about.
2133-2281: merge_gather_results correctly rebuilds a deterministic merged result

This helper does the right things:
- Filters out `None` partials and bails out early if none remain.
- Picks a safe `target_device` (current CUDA device if available, with fallbacks).
- Reconstructs `merged_state_dict` by iterating UIDs in sorted order and indexing into each partial's per-param lists, ensuring the merged representation matches sequential gather ordering.
- Aggregates `global_steps`, `skipped_uids`, and byte/time metrics, computing a global success rate.

Overall, the logic matches the semantics expected from merging per-rank outputs of `gather()`.
2283-2420: gather_distributed semantics (master-only merge, return_partials handling) look correct

The distributed gather flow:
- Caches deterministic UID partitions per (window, sorted_uids, world_size).
- Has each rank run `gather()` on its assigned subset (or no-op if empty).
- Uses `dist_helper.gather_object` to bring all partial results only to rank 0.
- On rank 0, normalizes empty partials to `None`, and:
  - returns the normalized list when `return_partials=True`, or
  - calls `merge_gather_results(..., target_device=device)` and returns the merged result.
- Non-master ranks always return `None`.

This matches the documented contract: in distributed mode only the master has data, and `return_partials` now genuinely switches between merged vs per-rank partials.
2937-3058: Peer list retrieval logic and parsing are consistent

Fetching peer lists from the highest-stake validator, filtering by versioned `peers_*.json` keys, selecting most-recent or previous windows, and parsing `{peers, reserve_peers, first_effective_window}` into a tuple is coherent. Given other hardening already present in `get_peer_list`, this path looks fine.
3068-3137: get_start_window version handling and retries are correct

Allowing an optional `version` override while defaulting to `tplr.__version__`, and retrying with backoff via `_get_highest_stake_validator_bucket` plus `s3_get_object`, matches expectations for bootstrap synchronization. No issues here.
```python
gather_result = await self.comms.gather_with_reserve(
    my_uid=self.uid,
    gather_uids=self.comms.peers,
    reserve_uids=self.comms.reserve_peers,  # Miners gather same way as validator
    return_partials=False,  # Always return merged result (distributed gather merges internally)
    window=step_window,
    key="gradient",
    timeout=90,
    device=str(self.device),
    local=False,
    stale_retention=100,
    totalks=self.totalks,
    compressor=self.compressor,
    time_min=time_min,
    time_max=time_max,
    expected_compressed_params=self.expected_compressed_params,
)
tplr.logger.info("Gather task completed!")
gather_time = tplr.T() - gather_start
should_update = gather_result is not None

# Log gather completion
if use_distributed_gather:
    if gather_result is not None:
        tplr.logger.info(
            f"[Rank {dist_helper.rank}] Distributed gather complete: {len(gather_result.uids)}/{len(self.comms.peers)} successful, "
            f"{len(gather_result.skipped_uids)} skipped, success_rate={gather_result.success_rate:.2%}, "
            f"time={gather_time:.2f}s"
        )
    else:
        tplr.logger.warning(
            f"[Rank {dist_helper.rank}] Distributed gather failed - no gradients collected from peers"
        )
else:
    # Sequential gather logging (master only)
    if gather_result is not None:
        tplr.logger.info(
            f"Sequential gather complete: {len(gather_result.uids)}/{len(self.comms.peers)} successful, "
            f"{len(gather_result.skipped_uids)} skipped, "
            f"success_rate={gather_result.success_rate:.2%}, "
            f"time={gather_time:.2f}s"
        )
    else:
        tplr.logger.warning(
            "Sequential gather failed - no gradients collected from peers"
        )

# For distributed gather: only master checks result, others will follow master's decision
# For sequential gather: only master has result anyway
if use_distributed_gather:
    # Master checks if gather succeeded, non-master ranks defer to master
    should_update = (
        gather_result is not None if self.is_master else True
    )
else:
    should_update = gather_result is not None
```
🧩 Analysis chain
Ensure outer_step can handle gather_result=None on non‑master ranks in distributed mode
In distributed runs, Comms.gather_with_reserve returns a merged SimpleNamespace only on the master rank; non‑master ranks receive None. After should_update is broadcast via dist_helper.all_ok, all ranks enter:
```python
if should_update:
    gradient_fingerprint = self.outer_step(gather_result)
```

On non-master ranks this calls `outer_step(None)`. If `outer_step` assumes a non-None, merged gather result, this will raise at runtime or silently skip gradient application on non-master ranks, leading to model drift across ranks.
Please verify that `tplr.neurons.outer_step`:
- Explicitly handles `gather_result is None` (e.g., by no-op on non-master and broadcasting updated weights from master), or
- Is only executed on the master rank with a guard around this call.

If not, you likely need either to broadcast the merged `gather_result` (or updated model state) from master to all ranks, or gate `outer_step` on `self.is_master`.
I'll verify this review comment by examining the outer_step method implementation and how it's called after the code snippet.
---
🏁 Script executed:
```shell
#!/bin/bash
# Find outer_step definition
rg -n "def outer_step" --type python
```

Length of output: 30
🏁 Script executed:
```shell
#!/bin/bash
# Look for outer_step calls with context
rg -n -A5 -B5 "outer_step" --type python | head -100
```

Length of output: 30
🏁 Script executed:
```shell
#!/bin/bash
# Check what happens after the should_update is set - look for the actual call
rg -n "\.outer_step\(" --type python -C5
```

Length of output: 30
---
🏁 Script executed:
```shell
#!/bin/bash
# Search for outer_step definition and calls without type constraint
rg -n "def outer_step|\.outer_step\(" -C5
```

Length of output: 2811
🏁 Script executed:
```shell
#!/bin/bash
# View the miner.py file from line 810 onwards to see what happens after should_update
rg -n "should_update" neurons/miner.py -A20
```

Length of output: 3405
🏁 Script executed:
```shell
#!/bin/bash
# Search for where outer_step is actually called in the context shown
rg -n "outer_step" neurons/miner.py -B3 -A3
```

Length of output: 1847
---
🏁 Script executed:
```shell
#!/bin/bash
# View the outer_step implementation in detail
sed -n '250,350p' src/tplr/neurons.py
```

Length of output: 3157
---
🏁 Script executed:
```shell
#!/bin/bash
# Continue viewing _outer_step_single to see what happens when src_sd is None
sed -n '330,450p' src/tplr/neurons.py
```

Length of output: 4739
🏁 Script executed:
```shell
#!/bin/bash
# View the complete trainer.py outer_step wrapper
sed -n '1023,1050p' neurons/trainer.py
```

Length of output: 639
🏁 Script executed:
```shell
#!/bin/bash
# Check if miner.py has is_master property and how it's used
rg -n "is_master" neurons/miner.py -B2 -A2
```

Length of output: 3900
Critical issue confirmed: non-master ranks will fail when attempting to unpack None gradients in distributed mode
The review comment accurately identifies a major bug in the current implementation. Here's what happens:

1. In distributed gradient gathering, on non-master ranks, `on_src = False` (since `on_src = is_master or not ddp`), which means `src_sd = None`.
2. When the master decides to update, it broadcasts `has_update=1` to all ranks via `_bcast_flag`.
3. Non-master ranks receive this flag but have `payload=None`, since they never extracted data from `gather_result` (which is `None` on non-master ranks anyway).
4. The unpacking will fail: when `flag_result == 1` and the code attempts to execute `idxs_dev, vals_f32 = payload  # type: ignore[misc]` on a non-master rank, it will crash because `payload` is `None`.

The trainer wrapper correctly passes `is_master` and `world_size` to `outer_step`, and miner.py line 835 calls `self.outer_step(gather_result)` without any guard. All ranks execute this call after the broadcast, making this a runtime failure waiting to happen.

Resolution required: Either gate the `outer_step` call on `self.is_master`, or broadcast the merged gradient state from master to all ranks before the call.
🤖 Prompt for AI Agents
In neurons/miner.py around lines 760-813, non-master ranks can end up with
gather_result/payload==None but still call outer_step and try to unpack gradient
payload, causing a crash; fix by either (A) gating the outer_step call so only
the master executes it (wrap the call in if self.is_master:
self.outer_step(...)) or (B) have the master broadcast the merged gradient
payload (not just a has_update flag) to all ranks before outer_step so
non-masters receive a valid payload to unpack; implement one of these two
approaches and ensure the broadcasted data and control flow are consistent
(i.e., non-masters use received payload and do not try to unpack None).
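The master-decides/workers-follow control flow at the heart of the suggested fix can be sketched in isolation (all names below are illustrative stand-ins, not the project's real API):

```python
def decide_update(is_master: bool, gather_result, master_flag: bool) -> bool:
    """Master decides from its own (possibly None) gather_result.

    Workers adopt the flag the master broadcast instead of inspecting their
    local gather_result, which is None by design in the distributed path.
    """
    if is_master:
        return gather_result is not None
    return master_flag  # value received via the master's broadcast
```

The crucial point the review makes is that workers must also receive the gradient payload (or the resulting weights) before acting on a `True` flag; otherwise they reach the unpack step with `payload=None` and crash.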
Summary
Implements distributed gradient gathering to enable parallel gradient downloads across multiple GPU ranks.
Key Features
Distributed Gather Implementation
- Auto-detect based on `world_size > 1`

Technical Details
Memory Management
Synchronization
Changes by File
- `gather_distributed()` and `merge_gather_results()`
- `outer_step()`, fixed sync score calculation

Summary by CodeRabbit
New Features
Bug Fixes
Chores
✏️ Tip: You can customize this high-level summary in your review settings.