
Conversation

@arosboro
Owner

@arosboro arosboro commented Dec 12, 2025

Replace temporary random initialization with proper weight loading from safetensors files. This ensures training uses actual pre-trained Llama-8B weights for fine-tuning instead of starting from random initialization.

Changes:

  • Uncommented weight loading logic in DistrustTrainer::new()
  • ModelLoader now attempts to load weights from model directory
  • Falls back to random initialization only if loading fails
  • Fixes unused ModelLoader import warning

This resolves the issue where training was using random weights despite having pre-trained model weights available.
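
For context, here is a minimal sketch of the load-or-fallback flow described above. Only ModelLoader, load_safetensors, load_model_with_weights, and LlamaForCausalLM are names from this repo; the surrounding signatures and config handling are illustrative, not the exact trainer code.

// Sketch only: approximates the flow in DistrustTrainer::new(), not the
// verbatim implementation. Assumes the crate's ModelLoader, LlamaForCausalLM,
// and config types are already in scope.
use std::collections::HashMap;

let loader = ModelLoader::new(&model_path);
let weights = loader.load_safetensors().unwrap_or_else(|e| {
    eprintln!("Weight loading failed ({e}); falling back to random initialization");
    HashMap::new()
});

let model = if weights.is_empty() {
    LlamaForCausalLM::new(&config)?            // random initialization
} else {
    load_model_with_weights(&config, weights)? // pre-trained Llama-8B weights
};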

Summary by CodeRabbit

  • New Features

    • Trainer now attempts to load pretrained model weights from the model path and falls back to random initialization only if loading fails; logging reflects which path was used.
    • Model loading now validates shapes, applies matched weights to parameters, and reports loaded vs. total parameters.
  • Bug Fixes / Improvements

    • Enhanced diagnostics for missing or extra weight entries to aid troubleshooting.


@coderabbitai
Contributor

coderabbitai bot commented Dec 12, 2025

Walkthrough

A whitespace/formatting change in the loader. The trainer now tries to load safetensors weights via a ModelLoader and falls back to random init only if loading fails or yields no weights. Llama weight loading is now implemented: incoming weight keys are matched to model parameters with shape checks and reporting, and the load fails if no weights were applied.

Changes

  • Formatting (rust/src/model/loader.rs): Adds a blank line after _init_test initialization and reformats a print! into a multi-line statement; no behavioral change.
  • Trainer, weight-loading flow (rust/src/training/trainer.rs): Replaces hard-coded random initialization with dynamic weight loading: constructs a ModelLoader for the model path, attempts to load safetensors, initializes the model with loaded weights if present, and otherwise falls back to random initialization. Adjusts logging and minor formatting around accumulation-map and parameter-name handling.
  • Llama, weight application and diagnostics (rust/src/model/llama.rs): Adds the ModuleParameters import and implements load_weights_into_model(model: &mut LlamaForCausalLM, weights: HashMap<String, Array>) -> anyhow::Result<()>: flattens model parameters, matches parameter names to weight keys, validates shapes, writes values (materializing to GPU), counts loaded parameters, collects extra/missing keys, prints a summary, and errors if no weights were loaded.

Sequence Diagram(s)

sequenceDiagram
    participant Trainer
    participant ModelLoader
    participant Safetensors (disk/IO)
    participant LlamaModel

    Trainer->>ModelLoader: construct with model path
    Trainer->>ModelLoader: request load safetensors weights
    ModelLoader->>Safetensors: reads/parses .safetensors (async/IO)
    Safetensors-->>ModelLoader: returns weights map (keys -> arrays) or error/empty
    ModelLoader-->>Trainer: returns weights map or empty/error
    alt weights returned and non-empty
        Trainer->>LlamaModel: call load_weights_into_model(weights)
        LlamaModel->>LlamaModel: flatten parameters, match names, validate shapes
        LlamaModel->>LlamaModel: set parameter values (materialize to GPU), count loaded
        LlamaModel-->>Trainer: result OK (loaded_count > 0)
    else no weights or load failed
        Trainer->>Trainer: fall back to random initialization
        Trainer-->>Trainer: log fallback decision
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Areas to focus:
    • rust/src/model/llama.rs: verify name-matching logic, shape validation, double-dereference / GPU materialization correctness, and the behavior when extra/missing keys are present.
    • rust/src/training/trainer.rs: review ModelLoader integration, correct detection of empty weights vs. error, and fallback path correctness.
    • Tests/logging: ensure diagnostics are informative but not overly verbose.
    • rust/src/model/loader.rs: quick check (trivial) for formatting-only changes.

Poem

🐰
I hopped through bytes and tensor fields,
I matched each name, unrolled their shields.
Safetensors first, then random tune,
A rabbit's hop beneath the moon. 🥕

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit’s high-level summary is enabled.
  • Title check: ✅ Passed. The title 'Enable pre-trained weight loading for model training' clearly and concisely summarizes the main objective of the pull request: replacing random initialization with proper weight loading from safetensors files for Llama-8B model fine-tuning.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which is sufficient; the required threshold is 80.00%.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fa1cc55 and 1fc2138.

📒 Files selected for processing (2)
  • rust/src/model/loader.rs (1 hunks)
  • rust/src/training/trainer.rs (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
rust/src/training/trainer.rs (2)
rust/src/model/loader.rs (1)
  • new (61-65)
rust/src/model/llama.rs (8)
  • new (158-182)
  • new (293-303)
  • new (333-345)
  • new (373-389)
  • new (434-439)
  • config (446-448)
  • load_model_with_weights (605-616)
  • load_weights_into_model (551-600)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Build and Test
🔇 Additional comments (1)
rust/src/model/loader.rs (1)

286-286: Formatting change - no functional impact.

This blank line addition improves readability but has no effect on behavior.

Previous implementation only loaded weights from safetensors but never
applied them to the model. This commit implements the actual weight
loading by using parameters_mut() to get mutable parameter references
and setting values using double dereference (**p = weight_array).

Changes:
- Import ModuleParameters trait for parameters_mut() access
- Implement actual weight loading loop that sets parameter values
- Add shape validation to catch mismatches
- Report loaded/missing/extra keys for debugging
- Use same parameter update pattern as trainer.rs

Now pre-trained weights are actually loaded into the model structure.
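
Condensed, the pattern that commit describes looks roughly like the following. This is a sketch, not the verbatim rust/src/model/llama.rs implementation: the mlx-rs import path, the key type produced by parameters_mut().flatten(), and the Param-to-Array deref are assumptions taken from the commit message, and the extra-keys bookkeeping is omitted.

use std::collections::HashMap;

use mlx_rs::{module::ModuleParameters, Array};

// Condensed sketch of the commit's pattern. Assumes parameters_mut().flatten()
// yields (name, &mut Param<Array>) pairs and that Param derefs to Array, as the
// commit message's "**p = weight_array" pattern implies.
pub fn load_weights_into_model(
    model: &mut LlamaForCausalLM,
    weights: HashMap<String, Array>,
) -> anyhow::Result<()> {
    let mut loaded_count = 0usize;
    let mut missing_keys: Vec<String> = Vec::new();

    for (name, param) in model.parameters_mut().flatten() {
        match weights.get(&*name) {
            Some(w) if w.shape() == param.shape() => {
                **param = w.clone(); // double dereference writes the value
                loaded_count += 1;   // the real code also eval()s to materialize on GPU
            }
            Some(w) => anyhow::bail!(
                "shape mismatch for {name}: model {:?} vs safetensors {:?}",
                param.shape(),
                w.shape()
            ),
            None => missing_keys.push(name.to_string()),
        }
    }

    println!(
        "Loaded {loaded_count} parameters; {} model keys had no match",
        missing_keys.len()
    );
    if loaded_count == 0 {
        anyhow::bail!("Failed to load any weights - parameter names may not match safetensors keys");
    }
    Ok(())
}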
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (2)
rust/src/model/llama.rs (2)

565-588: Parameter name matching may fail due to naming convention differences.

Safetensors files from HuggingFace typically use naming conventions like model.layers.0.self_attn.q_proj.weight, while parameters_mut().flatten() may produce different key formats. The error message at line 620 acknowledges this, but it would be helpful to add debug logging of actual parameter names vs. weight keys to assist troubleshooting.

Consider adding verbose logging when loaded_count == 0 to aid debugging:

     if loaded_count == 0 {
+        eprintln!("Model parameter names (first 5): {:?}", 
+            parameters.keys().take(5).collect::<Vec<_>>());
+        eprintln!("Safetensors keys (first 5): {:?}", 
+            weights.keys().take(5).collect::<Vec<_>>());
         anyhow::bail!(
             "Failed to load any weights - parameter names may not match safetensors keys"
         );
     }

604-616: Bounded logging for missing/extra keys prevents log flooding.

The condition && missing_keys.len() < 10 and && extra_keys.len() < 10 suppresses output when there are many mismatches, which is reasonable. However, the slicing [..missing_keys.len().min(10)] is redundant since the condition already ensures len() < 10.

Simplify the redundant slicing:

     if !missing_keys.is_empty() && missing_keys.len() < 10 {
         println!(
             "Missing keys (first 10): {:?}",
-            &missing_keys[..missing_keys.len().min(10)]
+            &missing_keys
         );
     }
 
     if !extra_keys.is_empty() && extra_keys.len() < 10 {
         println!(
             "Extra keys in safetensors (first 10): {:?}",
-            &extra_keys[..extra_keys.len().min(10)]
+            &extra_keys
         );
     }

Alternatively, if you want to show first 10 even when there are more, remove the && len() < 10 condition and keep the slicing.
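
For reference, that alternative (keep the slice, drop the length guard) would look roughly like:

if !missing_keys.is_empty() {
    println!(
        "Missing keys (first 10): {:?}",
        &missing_keys[..missing_keys.len().min(10)]
    );
}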

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1fc2138 and 710f0a5.

📒 Files selected for processing (3)
  • rust/src/model/llama.rs (3 hunks)
  • rust/src/model/loader.rs (2 hunks)
  • rust/src/training/trainer.rs (5 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • rust/src/model/loader.rs
🧰 Additional context used
🧬 Code graph analysis (1)
rust/src/training/trainer.rs (2)
rust/src/model/llama.rs (7)
  • new (158-182)
  • new (293-303)
  • new (333-345)
  • new (373-389)
  • new (434-439)
  • config (446-448)
  • load_model_with_weights (630-641)
rust/src/model/loader.rs (1)
  • new (61-65)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Build and Test
  • GitHub Check: Lint
🔇 Additional comments (8)
rust/src/training/trainer.rs (4)

857-858: Formatting change only - LGTM.

No functional change to the gradient accumulation logic.


904-908: Formatting change only - LGTM.

No functional change to parameter name extraction.


934-940: Formatting change only - LGTM.

No functional change to gradient data extraction.


133-149: Weight loading flow is correct with appropriate fallback and proper weight application.

The implementation properly:

  1. Creates a ModelLoader from the model path
  2. Attempts to load safetensors weights with unwrap_or_else fallback
  3. Conditionally initializes with weights or random init based on whether weights were loaded
  4. load_model_with_weights creates a model and calls load_weights_into_model, which iterates through parameters and applies weights via **param = weight_array.clone() with shape validation

One minor observation: the error from load_safetensors() is logged but silently converted to empty weights. Consider whether certain errors (e.g., file corruption vs. file not found) should be treated differently.
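
One possible shape for that distinction, sketched under the assumption that load_safetensors() returns anyhow::Result<HashMap<String, Array>>; none of this is in the current diff:

// Hedged sketch: only treat a missing weights file as a reason to fall back;
// surface every other error instead of silently training on random weights.
let weights = match loader.load_safetensors() {
    Ok(w) => w,
    Err(e) => {
        let not_found = e
            .downcast_ref::<std::io::Error>()
            .map(|io| io.kind() == std::io::ErrorKind::NotFound)
            .unwrap_or(false);
        if not_found {
            eprintln!("No safetensors weights found; using random initialization");
            HashMap::new()
        } else {
            return Err(e.context("failed to read safetensors weights"));
        }
    }
};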

rust/src/model/llama.rs (4)

4-4: Import required for parameters_mut() access - LGTM.

The ModuleParameters trait is now properly imported to enable mutable parameter access via model.parameters_mut().flatten().


551-589: Weight loading now properly applies values - addresses previous critical issue.

The implementation correctly:

  1. Uses **param = weight_array.clone() to actually assign weights (Line 583)
  2. Validates shapes before assignment
  3. Materializes on GPU via eval()

This resolves the previous critical issue where weights were loaded into a HashMap but never applied to the model.


591-596: Extra keys detection iterates after parameters were mutably borrowed.

The logic correctly identifies weight keys that don't correspond to any model parameter. The iteration over parameters after the mutable borrow ends is valid Rust.


618-622: Fail-fast when no weights loaded is appropriate.

Returning an error when loaded_count == 0 prevents silent failures where the model appears to load but actually uses random weights - exactly the issue this PR aims to fix.

@arosboro
Owner Author

@coderabbitai should I merge this? Are there any unresolved review comments? If so, provide an AI prompt to address them and improve the PR.

@coderabbitai
Contributor

coderabbitai bot commented Dec 12, 2025

Tip

For best results, initiate chat on the files or code changes.

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash

# Get PR details including status and checks
echo "=== PR Details ==="
gh pr view 24 --json number,title,state,isDraft,mergeable,reviewDecision,reviews,statusCheckRollup

echo ""
echo "=== PR Review Comments ==="
# Get all review comments
gh pr view 24 --json comments --jq '.comments[] | {author: .author.login, body: .body, createdAt: .createdAt}'

echo ""
echo "=== PR Review Threads ==="
# Get review threads to check for unresolved conversations
gh api graphql -f query='
query($owner: String!, $repo: String!, $number: Int!) {
  repository(owner: $owner, name: $repo) {
    pullRequest(number: $number) {
      reviewThreads(first: 100) {
        nodes {
          isResolved
          isOutdated
          comments(first: 1) {
            nodes {
              author {
                login
              }
              body
            }
          }
        }
      }
    }
  }
}' -f owner='arosboro' -f repo='your_ai' -F number=24

echo ""
echo "=== CI/Check Status ==="
gh pr checks 24

Length of output: 25789


@arosboro: Your PR is ready to merge! ✅

Summary

  • Mergeable: Yes, no conflicts
  • CI Status: All checks passing (Build, Test, Lint)
  • Critical Issues: None - the previous critical issue about weights not being applied was successfully resolved in commit e57ba2c
  • Unresolved Comments: No blocking issues

Optional Improvements (Nitpicks)

There are 2 non-blocking nitpick suggestions from the last review that you could optionally address for extra polish:

  1. Add debug logging for weight loading failures (rust/src/model/llama.rs, lines 565-588):

    • When loaded_count == 0, log sample parameter names vs. safetensors keys to help debug name mismatches
    • Useful for troubleshooting if weight loading fails
  2. Simplify redundant slicing (rust/src/model/llama.rs, lines 604-616):

    • The condition && len() < 10 already ensures small size, so .min(10) slicing is redundant
    • Minor code cleanup

AI Prompt (if you want to address the nitpicks):

Review the nitpick suggestions in rust/src/model/llama.rs:
1. Add debug logging at line ~620 to print sample parameter names and safetensors keys when loaded_count == 0
2. Simplify the redundant slicing at lines ~604-616 for missing_keys and extra_keys logging

Implement these improvements to enhance debugging capabilities and clean up the code.

My recommendation: Merge as-is since the core functionality is solid and all tests pass. The nitpicks can be addressed in a future PR if needed.

@arosboro arosboro merged commit 7a5ace6 into main Dec 12, 2025
4 checks passed