Enable pre-trained weight loading for model training #24
Conversation
Replace temporary random initialization with proper weight loading from safetensors files. This ensures training uses actual pre-trained Llama-8B weights for fine-tuning instead of starting from random initialization.

Changes:
- Uncommented weight loading logic in DistrustTrainer::new()
- ModelLoader now attempts to load weights from the model directory
- Falls back to random initialization only if loading fails
- Fixes unused ModelLoader import warning

This resolves the issue where training was using random weights despite having pre-trained model weights available.
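For orientation, here is a minimal, self-contained sketch of the load-then-fall-back flow this description refers to. The ModelLoader, load_safetensors, and load_model_with_weights names appear in this PR; the Weights alias, the stub bodies, and the models/llama-8b path are illustrative stand-ins, not the repository's actual code:

```rust
use std::collections::HashMap;
use std::path::PathBuf;

// Stand-in for the real tensor type; the actual code uses mlx-rs arrays.
type Weights = HashMap<String, Vec<f32>>;

struct ModelLoader {
    model_dir: PathBuf,
}

impl ModelLoader {
    fn new(model_dir: PathBuf) -> Self {
        Self { model_dir }
    }

    // Stub: the real loader parses .safetensors files found under model_dir.
    fn load_safetensors(&self) -> Result<Weights, String> {
        if !self.model_dir.exists() {
            return Err(format!("model directory {:?} not found", self.model_dir));
        }
        Ok(Weights::new())
    }
}

struct LlamaModel;

impl LlamaModel {
    fn new_random() -> Self {
        LlamaModel
    }

    // Stub standing in for load_model_with_weights -> load_weights_into_model.
    fn load_model_with_weights(_weights: Weights) -> Result<Self, String> {
        Ok(LlamaModel)
    }
}

fn build_model(model_dir: PathBuf) -> LlamaModel {
    let loader = ModelLoader::new(model_dir);

    // Try to read pre-trained weights; on error, log and continue with an empty map.
    let weights = loader.load_safetensors().unwrap_or_else(|err| {
        eprintln!("Failed to load safetensors weights: {err}");
        Weights::new()
    });

    if weights.is_empty() {
        // Fall back to random initialization only when nothing was loaded.
        println!("No pre-trained weights found - using random initialization");
        LlamaModel::new_random()
    } else {
        LlamaModel::load_model_with_weights(weights).unwrap_or_else(|err| {
            eprintln!("Applying weights failed: {err}; using random initialization");
            LlamaModel::new_random()
        })
    }
}

fn main() {
    let _model = build_model(PathBuf::from("models/llama-8b"));
}
```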
Walkthrough

Whitespace/formatting change in the loader. Trainer now tries to load safetensors weights via a ModelLoader and falls back to random init only if loading fails or yields no weights. Llama weight-loading was implemented to match incoming weight keys to model parameters with shape checks, reporting, and failure if no weights were applied.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Trainer
    participant ModelLoader
    participant Safetensors as Safetensors (disk/IO)
    participant LlamaModel
    Trainer->>ModelLoader: construct with model path
    Trainer->>ModelLoader: request load safetensors weights
    ModelLoader->>Safetensors: read/parse .safetensors (async/IO)
    Safetensors-->>ModelLoader: returns weights map (keys -> arrays) or error/empty
    ModelLoader-->>Trainer: returns weights map or empty/error
    alt weights returned and non-empty
        Trainer->>LlamaModel: call load_weights_into_model(weights)
        LlamaModel->>LlamaModel: flatten parameters, match names, validate shapes
        LlamaModel->>LlamaModel: set parameter values (materialize to GPU), count loaded
        LlamaModel-->>Trainer: result OK (loaded_count > 0)
    else no weights or load failed
        Trainer->>Trainer: fall back to random initialization
        Trainer-->>Trainer: log fallback decision
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- rust/src/model/loader.rs (1 hunks)
- rust/src/training/trainer.rs (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
rust/src/training/trainer.rs (2)
rust/src/model/loader.rs (1)
- new(61-65)

rust/src/model/llama.rs (8)
- new(158-182), new(293-303), new(333-345), new(373-389), new(434-439), config(446-448), load_model_with_weights(605-616), load_weights_into_model(551-600)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Build and Test
🔇 Additional comments (1)
rust/src/model/loader.rs (1)
286-286: Formatting change - no functional impact. This blank line addition improves readability but has no effect on behavior.
Previous implementation only loaded weights from safetensors but never applied them to the model. This commit implements the actual weight loading by using parameters_mut() to get mutable parameter references and setting values using a double dereference (**p = weight_array).

Changes:
- Import ModuleParameters trait for parameters_mut() access
- Implement actual weight loading loop that sets parameter values
- Add shape validation to catch mismatches
- Report loaded/missing/extra keys for debugging
- Use same parameter update pattern as trainer.rs

Now pre-trained weights are actually loaded into the model structure.
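To make the parameters_mut() / double-dereference pattern described above concrete, here is a hedged, self-contained sketch. The Tensor type, String errors, and the flattened HashMap<String, &mut Tensor> are stand-ins for the real mlx-rs types returned by parameters_mut().flatten(); only the overall pattern (match by key, check shape, assign through **param) mirrors the description:

```rust
use std::collections::HashMap;

// Stand-in tensor; the real code uses mlx-rs arrays behind ModuleParameters.
#[derive(Clone, Debug)]
struct Tensor {
    shape: Vec<usize>,
    data: Vec<f32>,
}

fn load_weights_into_model(
    // Flattened view of the model: parameter name -> mutable reference to a parameter.
    mut params: HashMap<String, &mut Tensor>,
    weights: &HashMap<String, Tensor>,
) -> Result<usize, String> {
    let mut loaded_count = 0;
    let mut missing_keys = Vec::new();

    for (name, param) in params.iter_mut() {
        match weights.get(name) {
            Some(w) if w.shape == param.shape => {
                // Double dereference: param is &mut &mut Tensor, so **param is the Tensor itself.
                **param = w.clone();
                loaded_count += 1;
            }
            Some(w) => {
                return Err(format!(
                    "shape mismatch for {name}: model {:?} vs safetensors {:?}",
                    param.shape, w.shape
                ));
            }
            None => missing_keys.push(name.clone()),
        }
    }

    if loaded_count == 0 {
        return Err("no weights applied - parameter names may not match safetensors keys".into());
    }
    println!("Loaded {loaded_count} tensors; {} missing", missing_keys.len());
    Ok(loaded_count)
}

fn main() {
    let mut p = Tensor { shape: vec![2, 2], data: vec![0.0; 4] };
    let weights = HashMap::from([(
        "model.embed_tokens.weight".to_string(),
        Tensor { shape: vec![2, 2], data: vec![1.0; 4] },
    )]);
    let params = HashMap::from([("model.embed_tokens.weight".to_string(), &mut p)]);
    println!("{:?}", load_weights_into_model(params, &weights));
}
```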
Actionable comments posted: 0
🧹 Nitpick comments (2)
rust/src/model/llama.rs (2)
565-588: Parameter name matching may fail due to naming convention differences. Safetensors files from HuggingFace typically use naming conventions like model.layers.0.self_attn.q_proj.weight, while parameters_mut().flatten() may produce different key formats. The error message at line 620 acknowledges this, but it would be helpful to add debug logging of actual parameter names vs. weight keys to assist troubleshooting.

Consider adding verbose logging when loaded_count == 0 to aid debugging:

```diff
 if loaded_count == 0 {
+    eprintln!("Model parameter names (first 5): {:?}",
+        parameters.keys().take(5).collect::<Vec<_>>());
+    eprintln!("Safetensors keys (first 5): {:?}",
+        weights.keys().take(5).collect::<Vec<_>>());
     anyhow::bail!(
         "Failed to load any weights - parameter names may not match safetensors keys"
     );
 }
```
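If the keys really do disagree only by a prefix, one possible remedy (a hypothetical helper, not code from this PR) is to normalize the safetensors keys before looking them up; the "model." prefix rule below is an illustrative assumption about how the two naming schemes might differ:

```rust
/// Hypothetical key normalization (illustrative, not part of this PR):
/// strip the HuggingFace-style "model." prefix so that keys like
/// "model.layers.0.self_attn.q_proj.weight" can be looked up against a
/// flattened parameter map that omits the prefix. The real prefix and
/// separator depend on how parameters_mut().flatten() names keys.
fn normalize_key(safetensors_key: &str) -> &str {
    safetensors_key
        .strip_prefix("model.")
        .unwrap_or(safetensors_key)
}

fn main() {
    assert_eq!(
        normalize_key("model.layers.0.self_attn.q_proj.weight"),
        "layers.0.self_attn.q_proj.weight"
    );
    println!("normalization example ok");
}
```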
604-616: Bounded logging for missing/extra keys prevents log flooding. The condition && missing_keys.len() < 10 and && extra_keys.len() < 10 suppresses output when there are many mismatches, which is reasonable. However, the slicing [..missing_keys.len().min(10)] is redundant since the condition already ensures len() < 10.

Simplify the redundant slicing:

```diff
 if !missing_keys.is_empty() && missing_keys.len() < 10 {
     println!(
         "Missing keys (first 10): {:?}",
-        &missing_keys[..missing_keys.len().min(10)]
+        &missing_keys
     );
 }
 if !extra_keys.is_empty() && extra_keys.len() < 10 {
     println!(
         "Extra keys in safetensors (first 10): {:?}",
-        &extra_keys[..extra_keys.len().min(10)]
+        &extra_keys
     );
 }
```

Alternatively, if you want to show the first 10 even when there are more, remove the && len() < 10 condition and keep the slicing.
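For the alternative mentioned above (always printing up to the first 10 keys even when there are more), a small sketch; report_keys is a hypothetical helper, not part of the PR:

```rust
// Hypothetical helper: always print up to the first 10 keys, plus a total count.
fn report_keys(label: &str, keys: &[String]) {
    if !keys.is_empty() {
        let preview: Vec<&String> = keys.iter().take(10).collect();
        println!("{label} ({} total, showing {}): {:?}", keys.len(), preview.len(), preview);
    }
}

fn main() {
    let missing_keys: Vec<String> = (0..25)
        .map(|i| format!("layers.{i}.mlp.gate_proj.weight"))
        .collect();
    report_keys("Missing keys", &missing_keys);
}
```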
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- rust/src/model/llama.rs (3 hunks)
- rust/src/model/loader.rs (2 hunks)
- rust/src/training/trainer.rs (5 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- rust/src/model/loader.rs
🧰 Additional context used
🧬 Code graph analysis (1)
rust/src/training/trainer.rs (2)
rust/src/model/llama.rs (7)
- new(158-182), new(293-303), new(333-345), new(373-389), new(434-439), config(446-448), load_model_with_weights(630-641)

rust/src/model/loader.rs (1)
- new(61-65)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Build and Test
- GitHub Check: Lint
🔇 Additional comments (8)
rust/src/training/trainer.rs (4)
857-858: Formatting change only - LGTM. No functional change to the gradient accumulation logic.

904-908: Formatting change only - LGTM. No functional change to parameter name extraction.

934-940: Formatting change only - LGTM. No functional change to gradient data extraction.
133-149: Weight loading flow is correct with appropriate fallback and proper weight application. The implementation properly:
- Creates a ModelLoader from the model path
- Attempts to load safetensors weights with an unwrap_or_else fallback
- Conditionally initializes with weights or random init based on whether weights were loaded
- load_model_with_weights creates a model and calls load_weights_into_model, which iterates through parameters and applies weights via **param = weight_array.clone() with shape validation

One minor observation: the error from load_safetensors() is logged but silently converted to empty weights. Consider whether certain errors (e.g., file corruption vs. file not found) should be treated differently; a sketch of one way to distinguish these cases follows the llama.rs comments below.

rust/src/model/llama.rs (4)
4-4: Import required for parameters_mut() access - LGTM. The ModuleParameters trait is now properly imported to enable mutable parameter access via model.parameters_mut().flatten().

551-589: Weight loading now properly applies values - addresses previous critical issue. The implementation correctly:
- Uses **param = weight_array.clone() to actually assign weights (Line 583)
- Validates shapes before assignment
- Materializes on GPU via eval()

This resolves the previous critical issue where weights were loaded into a HashMap but never applied to the model.

591-596: Extra keys detection iterates after parameters were mutably borrowed. The logic correctly identifies weight keys that don't correspond to any model parameter. The iteration over parameters after the mutable borrow ends is valid Rust.

618-622: Fail-fast when no weights loaded is appropriate. Returning an error when loaded_count == 0 prevents silent failures where the model appears to load but actually uses random weights - exactly the issue this PR aims to fix.
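As referenced in the trainer.rs comment above, here is a hedged sketch of treating a missing safetensors file differently from other load errors. It assumes the loader surfaces a std::io::Error, which may not match this repository's actual error types:

```rust
use std::io;
use std::path::Path;

// Stand-in for the real loader: just read the raw bytes of a safetensors file.
fn load_safetensors(path: &Path) -> Result<Vec<u8>, io::Error> {
    std::fs::read(path)
}

fn main() {
    let path = Path::new("models/llama-8b/model.safetensors");
    match load_safetensors(path) {
        Ok(bytes) => println!("Loaded {} bytes of safetensors data", bytes.len()),
        // Missing file: a legitimate reason to fall back to random initialization.
        Err(e) if e.kind() == io::ErrorKind::NotFound => {
            eprintln!("No safetensors file at {path:?}; falling back to random init");
        }
        // Anything else (corruption, permissions, partial read) is probably a real
        // problem worth surfacing rather than silently training from random weights.
        Err(e) => panic!("Failed to read safetensors file {path:?}: {e}"),
    }
}
```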
@coderabbitai should I merge this, are there any unresolved review comments? if so provide an ai prompt to address and improve the pr
Tip: For best results, initiate chat on the files or code changes.

@arosboro: Your PR is ready to merge! ✅

Summary
Optional Improvements (Nitpicks)

There are 2 non-blocking nitpick suggestions from the last review that you could optionally address for extra polish:
- Add debug logging of model parameter names vs. safetensors keys when loaded_count == 0 (llama.rs, lines 565-588)
- Remove the redundant [..len().min(10)] slicing in the missing/extra key reporting (llama.rs, lines 604-616)
AI Prompt (if you want to address the nitpicks):

My recommendation: Merge as-is since the core functionality is solid and all tests pass. The nitpicks can be addressed in a future PR if needed.
Summary by CodeRabbit
New Features
Bug Fixes / Improvements