Improve stability. #25
Conversation
Walkthrough: Split the Llama model into a frozen Backbone and a TrainableHead, and renamed a decoder-layer field from `attention` to `self_attn`.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant LlamaForCausalLM
    participant Backbone
    participant Head
    participant LM_Head
    Client->>LlamaForCausalLM: forward(input_ids)
    LlamaForCausalLM->>Backbone: forward_backbone(input_ids)
    Backbone-->>LlamaForCausalLM: hidden_states
    LlamaForCausalLM->>Head: forward_head(hidden_states)
    Head->>LM_Head: lm_head projection
    LM_Head-->>Client: logits
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
rust/src/training/trainer.rs (2)
998-1056: Critical: gradient “accumulation” currently overwrites instead of accumulating when `grad_accum_steps > 1`.
`self.accumulated_gradients.insert(...)` replaces prior micro-step data; the optimizer then uses only the last micro-step's gradients (scaled), which is incorrect.

```diff
- for (param_name, (grad_data, grad_shape)) in trainable_grad_data {
-     self.accumulated_gradients.insert(param_name, (grad_data, grad_shape));
- }
+ for (param_name, (grad_data, grad_shape)) in trainable_grad_data {
+     self.accumulated_gradients
+         .entry(param_name)
+         .and_modify(|(acc, _shape)| {
+             // elementwise sum
+             for (a, g) in acc.iter_mut().zip(grad_data.iter()) {
+                 *a += *g;
+             }
+         })
+         .or_insert((grad_data, grad_shape));
+ }
```

Also consider clearing `accumulated_gradients` if you early-return due to an error mid-accumulation, to avoid stale state.
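A standalone sketch of the accumulate-then-sum pattern the diff above suggests; the `Vec<f32>`/`Vec<i32>` gradient representation and the free-function shape are assumptions for illustration only.

```rust
use std::collections::HashMap;

/// Sum per-parameter gradients across micro-steps instead of overwriting them.
fn accumulate(
    accumulated: &mut HashMap<String, (Vec<f32>, Vec<i32>)>,
    micro_step_grads: Vec<(String, (Vec<f32>, Vec<i32>))>,
) {
    for (param_name, (grad_data, grad_shape)) in micro_step_grads {
        accumulated
            .entry(param_name)
            .and_modify(|(acc, _shape)| {
                // Elementwise sum with whatever has been accumulated so far.
                for (a, g) in acc.iter_mut().zip(grad_data.iter()) {
                    *a += *g;
                }
            })
            .or_insert((grad_data, grad_shape));
    }
}

fn main() {
    let mut acc = HashMap::new();
    // Two micro-steps touching the same parameter.
    accumulate(&mut acc, vec![("head.norm.weight".into(), (vec![1.0, 2.0], vec![2]))]);
    accumulate(&mut acc, vec![("head.norm.weight".into(), (vec![0.5, 0.5], vec![2]))]);
    assert_eq!(acc["head.norm.weight"].0, vec![1.5, 2.5]);
}
```

The optimizer step would then divide the summed gradients by `grad_accum_steps` (or scale the learning rate accordingly) and clear the map afterwards.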
1084-1162: Major: Adam hyperparams are hardcoded, but `TrainingConfig` exposes `adam_beta1`/`adam_beta2`/`adam_epsilon`.
This makes the config values no-ops and will surprise users tuning stability. Use `self.config.training.adam_beta1`, `adam_beta2`, and `adam_epsilon` here (and consider wiring `max_grad_norm` too if intended).

```diff
- let beta1 = 0.9f32;
- let beta2 = 0.999f32;
+ let beta1 = self.config.training.adam_beta1;
+ let beta2 = self.config.training.adam_beta2;
@@
- let eps = 1e-8f32;
+ let eps = self.config.training.adam_epsilon;
```
🧹 Nitpick comments (2)
rust/src/training/trainer.rs (2)
134-151: Good move switching warnings to `stderr`, but consider failing fast (or elevating) when weights are expected.
Right now a `load_safetensors()` failure silently falls back to random init after a warning; that's fine for experimentation, but risky for "fine-tune" flows. Consider a config flag like `allow_random_init` (default false) to prevent accidental garbage runs.
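A minimal sketch of the suggested guard. Only the flag name `allow_random_init` comes from the suggestion above; the surrounding struct, helper, and `anyhow`-based error path are assumptions.

```rust
/// Hypothetical config flag (not in the current codebase).
#[derive(serde::Deserialize)]
pub struct WeightLoadingOptions {
    /// Permit falling back to randomly initialized weights when loading fails.
    #[serde(default)] // defaults to false
    pub allow_random_init: bool,
}

fn require_weights_or_fail(
    opts: &WeightLoadingOptions,
    load_result: anyhow::Result<()>,
) -> anyhow::Result<()> {
    match load_result {
        Ok(()) => Ok(()),
        Err(e) if opts.allow_random_init => {
            eprintln!("Warning: weight loading failed ({e}); continuing with random init");
            Ok(())
        }
        // Fail fast instead of silently "fine-tuning" random weights.
        Err(e) => Err(e.context("weight loading failed and allow_random_init is false")),
    }
}
```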
1212-1233: Cache clear + dummy eval for Metal sync makes sense, but don’t fully ignore clear-cache failures.
Right now `let _ = crate::utils::mlx_memory::clear_cache();` discards errors; if `mlx_clear_cache()` starts failing, you'll miss an important stability signal. Consider logging once-per-N on failure (to avoid spam).
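A sketch of the once-per-N idea as a small helper; the failure-counter field and the threshold value are assumptions.

```rust
/// Run a fallible cleanup call, warning only on the first failure and then
/// every `warn_every`-th one so the hot loop does not spam stderr.
fn run_throttled<E: std::fmt::Display>(
    failures: &mut u64,
    warn_every: u64,
    op: impl FnOnce() -> Result<(), E>,
) {
    if let Err(e) = op() {
        *failures += 1;
        if *failures == 1 || *failures % warn_every == 0 {
            eprintln!("Warning: cache clear failed ({} failures so far): {e}", failures);
        }
    }
}
```

At the call site this could look like `run_throttled(&mut self.clear_cache_failures, 100, || crate::utils::mlx_memory::clear_cache())`, where `clear_cache_failures` is an assumed new counter field on the trainer.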
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- rust/src/config/training.rs (1 hunks)
- rust/src/training/trainer.rs (14 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
rust/src/training/trainer.rs (1)
rust/src/utils/mlx_memory.rs (1)
clear_cache(86-92)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Build and Test
🔇 Additional comments (1)
rust/src/training/trainer.rs (1)
977-985: Dropping `loss` and the input Arrays early is a sensible memory-pressure reduction.
This should help unified-memory stability, especially with large intermediate graphs.
rust/src/config/training.rs (Outdated)

```diff
  batch_size: 1, // Reduced from 2 for better memory efficiency
- gradient_accumulation_steps: 8,
+ gradient_accumulation_steps: 1,
  max_steps: 5000,
```
Default gradient_accumulation_steps change alters effective batch/training dynamics; consider surfacing in docs/CLI help.
Going from 8 → 1 changes the effective batch size (and can change loss scale, LR sensitivity, and checkpoint/metric comparability). If this is primarily for OOM avoidance, consider also documenting a recommended “old behavior” equivalent (e.g., keep batch size small but set accumulation back to 8).
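For example, with the default `batch_size: 1` shown above, the old settings gave an effective batch of 1 × 8 = 8 samples per optimizer step, while the new default yields 1 × 1 = 1; users who want the old dynamics back can keep `batch_size: 1` and set `gradient_accumulation_steps: 8` explicitly.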
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
rust/src/training/trainer.rs (1)
1271-1304: Checkpoint save still has TODO for model state extraction.

The `model_state: std::collections::HashMap::new()` on line 1286 means checkpoints don't actually save model weights. Combined with the similar TODO on line 717 in `save_best_checkpoint_impl`, this means checkpoints are incomplete. Consider either:
- Implementing proper model state extraction
- Adding a prominent warning when saving checkpoints that weights aren't included
Do you want me to help implement the model state extraction for checkpoints?
♻️ Duplicate comments (5)
rust/src/config/training.rs (1)
57-59: Default changes noted; `gradient_accumulation_steps = 1` alters training dynamics.

The change from 8 to 1 affects effective batch size. This was flagged in a previous review.
rust/src/training/trainer.rs (4)
373-388: Hardcoded debug log path breaks portability.

This was flagged in a previous review. The hardcoded `/Users/arosboro/your_ai/.cursor/debug.log` path will fail on CI/Linux/other machines and leaks local filesystem structure.
584-612: Checkpoint debug logging has the same hardcoded path issue.

Same issue as the other debug log locations: the hardcoded path breaks portability.
804-852: The `log_debug` helper has the same hardcoded path issue.

The `log_debug` function centralizes the debug logging but still uses the hardcoded path. This is the right place to apply the fix suggested in the past review (environment variable or configurable path).
1068-1071: Hardcoded `seq_len = 16` overrides config.

This was flagged in a previous review. `TrainingConfig.max_seq_length` (defaults to 1024) is effectively unused.
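A minimal sketch of deriving the training sequence length from config instead of hardcoding 16; the 512 cap mirrors the `max_seq_length.min(512)` default mentioned later in this PR, and the helper name is illustrative.

```rust
/// Respect the configured max_seq_length, but cap it for memory headroom and
/// never exceed the actual tokenized input length.
fn effective_seq_len(configured_max: usize, cap: usize, input_len: usize) -> usize {
    configured_max.min(cap).min(input_len).max(1)
}

fn main() {
    assert_eq!(effective_seq_len(1024, 512, 2048), 512); // capped
    assert_eq!(effective_seq_len(1024, 512, 300), 300);  // short input
}
```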
🧹 Nitpick comments (4)
rust/src/model/llama.rs (1)
504-546: Split architecture composition is well-implemented.
`LlamaForCausalLM` now composes `LlamaBackbone` and `TrainableHead` with clear separation of concerns. The `forward_backbone` and `forward_head` methods enable the gradient isolation pattern used in the trainer.

One minor note: `lora_rank: 0` is hardcoded in `new()`, but it's later set externally in trainer.rs (line 179). Consider accepting it as a parameter for consistency.

rust/src/config/training.rs (1)
3-25: Well-designed TrainingMode enum for gradient strategy selection.

The three variants (LoRA, FullFineTune, Frozen) cover the expected training scenarios. The `from_lora_rank` factory method provides clean auto-detection.

Minor consideration: the hardcoded default targets `["head.lm_head", "head.norm"]` in `FullFineTune` couple this enum to the specific model architecture. Consider moving these defaults closer to the model definition or accepting them as a parameter.

rust/src/training/trainer.rs (2)
425-470: Memory leak detection logic is well-designed but the threshold is problematic.

The baseline capture at step 5 and subsequent verification is a good pattern. However, with `memory_leak_threshold_mb: 2200.0`, this check will rarely trigger meaningful warnings for actual regressions since the baseline already assumes a ~2 GB/step leak.

Consider adding an "ideal" threshold check that warns (without bailing) when the leak exceeds a more reasonable threshold like 100 MB/step:

```diff
+ // Warn about concerning but not fatal leak rates
+ if leak_per_step_mb > 100.0 && leak_per_step_mb <= self.memory_leak_threshold_mb {
+     eprintln!(
+         "⚠️ Memory leak warning: {:.1} MB/step (ideal: <100 MB)",
+         leak_per_step_mb
+     );
+ }
+
  if leak_per_step_mb > self.memory_leak_threshold_mb {
```
854-968: GPU-only AdamW optimizer implementation is well-reasoned.

The approach of keeping momentum on GPU and using `add(0.0)` to break computation graphs is clever. The explicit `drop()` calls and cache clearing aim to reduce memory pressure. However, there are several concerns:

- Lines 888-893: Creating `Array::from_f32(beta1)` repeatedly inside the loop creates temporary arrays. Consider hoisting scalar constants outside the loop.
- Lines 936-954: Calling `clear_cache()` 3 times per parameter update is aggressive and may hurt performance.

```diff
+ // Pre-compute scalar constants outside loop
+ let beta1_arr = Array::from_f32(beta1);
+ let beta2_arr = Array::from_f32(beta2);
+ let one_minus_beta1 = Array::from_f32(1.0 - beta1);
+ let one_minus_beta2 = Array::from_f32(1.0 - beta2);
+
  for (param_name, grad) in grads.iter() {
      // ... use pre-computed arrays
-     m.multiply(Array::from_f32(beta1))?
+     m.multiply(&beta1_arr)?
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- rust/src/config/training.rs (2 hunks)
- rust/src/model/llama.rs (6 hunks)
- rust/src/training/trainer.rs (17 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-30T23:47:27.584Z
Learnt from: CR
Repo: arosboro/your_ai PR: 0
File: GEMINI.md:0-0
Timestamp: 2025-11-30T23:47:27.584Z
Learning: Applies to src/distrust_loss.py : Implement the core `empirical_distrust_loss` function that penalizes high-authority, low-entropy sources and rewards primary empirical sources
Applied to files:
rust/src/training/trainer.rs
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Build and Test
🔇 Additional comments (15)
rust/src/model/llama.rs (5)
323-350: LGTM! Field rename from `attention` to `self_attn` is consistent.

The rename aligns with common transformer naming conventions (matching HuggingFace/PyTorch patterns) and is consistently applied in the struct definition, constructor, and forward method.
424-477: Well-structured frozen backbone implementation.

The `LlamaBackbone` struct correctly encapsulates the frozen components (embeddings + decoder layers) without the final norm. The `create_causal_mask` duplication from `LlamaModel` is acceptable given the split architecture goal.
479-502: TrainableHead correctly isolates gradient-enabled parameters.

The head contains only `norm` and `lm_head`, which is appropriate for head-only fine-tuning. This design enables memory-efficient gradient computation.
728-740: Improved diagnostics for missing/extra keys.

The always-print behavior for missing and extra keys improves debuggability during model loading.
742-768: Enhanced error path with actionable debugging info.

The detailed error output with sample parameter names and safetensors keys will significantly help diagnose weight loading failures.
rust/src/config/training.rs (1)
30-31: `#[serde(skip)]` prevents training_mode serialization.

This is intentional since `training_mode` is auto-detected at runtime from `lora_rank`. However, this means checkpoint configs won't record the actual training mode used. Consider whether this could cause issues when resuming training.

rust/src/training/trainer.rs (9)
21-22: Type aliases clarify optimizer state storage strategy.

The separation between CPU storage (for checkpointing) and GPU storage (for training) is clearly documented.
28-34: GPU-first optimizer state design addresses memory leak concerns.

Keeping momentum on GPU (`adam_m_gpu`, `adam_v_gpu`) while maintaining CPU caches only for checkpointing is a sound approach to avoid repeated CPU↔GPU transfers.
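A sketch of the storage split described here; the field names follow the review (`adam_m_gpu`, `adam_v_gpu`, `adam_step`), while the CPU tensor representation and struct shape are assumptions.

```rust
use std::collections::HashMap;
use mlx_rs::Array;

/// CPU-side copy of a tensor (raw values plus shape), used only for checkpoint files.
type CpuTensor = (Vec<f32>, Vec<i32>);

struct OptimizerState {
    /// Adam moments kept as GPU Arrays for the whole run,
    /// so no per-step CPU↔GPU transfer is needed.
    adam_m_gpu: HashMap<String, Array>,
    adam_v_gpu: HashMap<String, Array>,
    /// CPU caches populated only when a checkpoint is written.
    adam_m: HashMap<String, CpuTensor>,
    adam_v: HashMap<String, CpuTensor>,
    adam_step: usize,
}
```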
138-142: Improved error handling with stderr for weight loading warnings.

Using `eprintln!` for warnings is appropriate, as it separates error output from normal progress output.
144-176: LoRA application during model loading is well-implemented.

The conditional LoRA injection based on `lora_rank > 0` with proper config construction is correct. The emphatic stderr warnings for random initialization clearly communicate the issue.
189-198: Training mode auto-detection and logging.

The initialization correctly separates GPU and CPU optimizer state and logs the detected training mode.
970-987: Momentum extraction for checkpointing is correct.

The `extract_momentum_for_checkpoint` function properly transfers GPU state to CPU only when needed (during checkpoint save), avoiding the per-step CPU extraction that caused memory leaks.
1144-1174: Loss function closure correctly scoped to TrainableHead.

The loss function signature `|head: &mut TrainableHead, ...| -> Result<Array, Exception>` ensures `value_and_grad` only tracks gradients for head parameters. This is the key to the "zero-leak" architecture claim.
1204-1211: Gradient computation correctly uses only head parameters.

The call to `value_and_grad` with `&mut self.model.head` instead of the full model ensures only 2 parameters receive gradients instead of 128+.
1241-1248: Post-gradient cleanup is thorough.

Explicit drops and cache clearing after gradient application help manage memory, though the effectiveness depends on MLX's garbage collection behavior.
```diff
+ // ========== ZERO-LEAK ARCHITECTURE ==========
+ // Key insight: Only put TRAINABLE parameters in computation graph
+ // This prevents MLX from allocating 128 gradient Arrays we don't use
+
+ let batch_size = input_ids.dim(0);
+ let seq_len = input_ids.dim(1);
+
+ // Step 1: Forward through FROZEN backbone (outside gradient graph)
+ // This prevents MLX from computing gradients for 126 frozen parameters
+ let hidden_states_detached = {
+     let hidden = self.model.forward_backbone(&input_ids)?;
+     let _ = hidden.eval();
+
+     // CRITICAL: Create a detached copy on GPU using add(0) trick
+     // This breaks the computation graph without CPU extraction (no as_slice leak!)
+     // The add operation creates a new Array not connected to backbone parameters
+     let detached = hidden.add(Array::from_f32(0.0))?;
+     let _ = detached.eval();
+
+     // Explicitly drop the original hidden Array
+     drop(hidden);
+
+     // CRITICAL: Force MLX to release ALL activation memory from forward pass
+     mlx_rs::transforms::compile::clear_cache();
+     let _ = crate::utils::mlx_memory::clear_cache();
+
-     // Forward pass
-     let logits = model.forward(input_ids)?;
+     detached
+ };
```
🧩 Analysis chain
🌐 Web query:
MLX framework detach stop_gradient operation Rust
💡 Result:
Short answer: MLX provides a stop_gradient (detach) op — in MLX Python it's mlx.core.stop_gradient(), in other language bindings it's exposed by the native bindings (or the C API mlx_stop_gradient). In Rust you should use the mlx-rs bindings (or call the generated mlx-sys binding) — look for the stop_gradient/stopGradient/mlx_stop_gradient symbol in the crate (or call the low-level mlx C API). References:
- MLX Python docs: mlx.core.stop_gradient() (operation that returns the same array but blocks gradients). [1]
- MLX docs (function transforms) describing stop_gradient usage. [2]
- mlx-rs (unofficial Rust bindings) repository (use this crate or mlx-sys bindings to access MLX ops from Rust). [3]
- R mlx docs showing mlx_stop_gradient() name in other bindings / API. [4]
References
- mlx.core.stop_gradient — MLX docs. [1] (ml-explore.github.io)
- MLX Function Transforms documentation (stop_gradient description). [2] (ml-explore.github.io)
- mlx-rs GitHub (Rust bindings). [3] (github.com)
- R mlx reference index (mlx_stop_gradient). [4] (hughjonesd.github.io)
If you want, I can: (a) search the mlx-rs source for the exact Rust function name and show example code, or (b) give an example calling the low-level mlx C API (mlx_stop_gradient) from Rust. Which do you prefer?
Citations:
- 1: https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.stop_gradient.html?utm_source=openai
- 2: https://ml-explore.github.io/mlx/build/html/usage/function_transforms.html?utm_source=openai
- 3: https://github.com/oxideai/mlx-rs?utm_source=openai
- 4: https://hughjonesd.github.io/Rmlx/reference/index.html?utm_source=openai
🏁 Script executed:
# Check if mlx-rs has stop_gradient exposed
cd /tmp && git clone --depth 1 https://github.com/oxideai/mlx-rs 2>/dev/null || echo "Clone attempt completed"
# Search for stop_gradient or stopGradient in mlx-rs
if [ -d mlx-rs ]; then
echo "=== Searching for stop_gradient in mlx-rs ==="
find mlx-rs -type f \( -name "*.rs" -o -name "*.toml" \) | head -20
rg -i "stop_gradient|stopGradient|stop.*gradient" mlx-rs --type rust -A 2 | head -50
fi

Repository: arosboro/your_ai
Length of output: 1423
🏁 Script executed:
# Get full function signature and context
rg "stop_gradient" mlx-rs/mlx-rs/src/array/mod.rs -B 5 -A 10
echo "=== Checking for add(0.0) pattern in the codebase under review ==="
rg "add\(Array::from_f32\(0\.0\)\)" rust/src --type rust -B 2 -A 2

Repository: arosboro/your_ai
Length of output: 641
Replace add(0.0) workaround with native stop_gradient_device() API.
MLX provides an explicit stop_gradient_device() function in mlx-rs for detaching arrays from the computation graph. Replace:
```rust
let detached = hidden.add(Array::from_f32(0.0))?;
```

with:

```rust
let detached = mlx_rs::array::stop_gradient_device(&hidden, stream)?;
```

This avoids reliance on implementation-dependent behavior and uses the intended, stable API.
🤖 Prompt for AI Agents
In rust/src/training/trainer.rs around lines 1113 to 1141, replace the
hidden.add(Array::from_f32(0.0)) workaround with the stable MLX API by calling
mlx_rs::array::stop_gradient_device(&hidden, stream) (or the exact
stop_gradient_device path available in mlx-rs) and propagate/handle its Result
into detached; ensure a valid stream variable is in scope (create or pass one as
needed), remove the Array::from_f32(0.0) usage, and keep the subsequent
drop(hidden) and cache clear calls unchanged.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
rust/src/config/training.rs (1)
29-55: `training_mode` is `#[serde(skip)]` (not user-configurable / not in checkpoints); ensure this is intentional. If you expect users to choose `Frozen` / custom targets, this field won't round-trip.

rust/src/training/trainer.rs (1)

196-225: Trainer overwrites `config.training.training_mode` internally; combined with `#[serde(skip)]`, the training mode can't be configured externally. If that's intended, OK; otherwise it's a UX trap.
♻️ Duplicate comments (4)
rust/src/config/training.rs (1)
57-84: Default `gradient_accumulation_steps: 8 -> 1` changes effective batch/training dynamics; call it out in docs/CLI help.

rust/src/training/trainer.rs (3)
373-423: Blocker: hardcoded absolute debug-log path + per-step file I/O in the hot loop (and in `log_debug`). This will break on CI/other machines and can blow up disk/perf.

```diff
 impl DistrustTrainer {
+    fn debug_log_path() -> Option<std::path::PathBuf> {
+        std::env::var_os("YOUR_AI_DEBUG_LOG").map(std::path::PathBuf::from)
+    }
+
     fn log_debug(&mut self, location: &str, message: &str, step: usize, phase: &str) {
         use std::io::Write;
-        if let Ok(mut file) = std::fs::OpenOptions::new()
-            .create(true)
-            .append(true)
-            .open("/Users/arosboro/your_ai/.cursor/debug.log")
-        {
+        let Some(path) = Self::debug_log_path() else { return; };
+        if let Ok(mut file) = std::fs::OpenOptions::new().create(true).append(true).open(path) {
             ...
         }
     }
 }
```

Then delete the inline `OpenOptions::...open("/Users/...")` blocks and replace them with `self.log_debug(...)` (or gate them behind the same env var).

Also applies to: 623-652, 654-700, 843-891
1178-1182: Major: `seq_len = 16` silently ignores `TrainingConfig.max_seq_length`. Either derive from config (and cap), or add an explicit `train_seq_length` field.
1230-1250: Replace the `hidden.add(0.0)` detach hack with the native MLX stop-gradient API (if available in mlx-rs 0.25.2).

```diff
- let detached = hidden.add(Array::from_f32(0.0))?;
+ let detached = mlx_rs::array::stop_gradient_device(&hidden /*, stream? */)?;
```

mlx-rs 0.25.2: what is the correct Rust API name/signature for stop-gradient / detach (e.g., stop_gradient_device)?
🧹 Nitpick comments (2)
rust/src/config/training.rs (1)
81-83: Reload defaults (`reload_interval_steps = 40`, `reload_memory_threshold_gb = 80.0`) may cause frequent reloads (interval) while the threshold is likely unreachable on many hosts. Consider making the threshold relative to detected/system memory or `max_memory_gb` if present.

rust/src/training/trainer.rs (1)
893-1007: GPU-only optimizer path is directionally good; but cache-clearing 3× per-param update is likely very expensive. Since you only update ~2 params, it may be fine—still worth measuring and consolidating clears to once-per-step if possible.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- rust/src/config/training.rs (3 hunks)
- rust/src/training/trainer.rs (20 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-30T23:47:27.584Z
Learnt from: CR
Repo: arosboro/your_ai PR: 0
File: GEMINI.md:0-0
Timestamp: 2025-11-30T23:47:27.584Z
Learning: Applies to src/distrust_loss.py : Implement the core `empirical_distrust_loss` function that penalizes high-authority, low-entropy sources and rewards primary empirical sources
Applied to files:
rust/src/training/trainer.rs
🧬 Code graph analysis (2)
rust/src/config/training.rs (2)
- rust/src/config/mod.rs (1): default (30-41)
- rust/src/config/model.rs (1): default (20-37)
rust/src/training/trainer.rs (7)
- rust/src/model/llama.rs (10): new (158-182), new (293-303), new (333-345), new (373-389), new (436-449), new (490-495), new (517-526), config (544-546), mlx_rs (410-410), mlx_rs (469-469)
- rust/src/training/lora.rs (6): new (93-116), apply_lora_to_model (27-79), mlx_rs (51-51), mlx_rs (58-58), mlx_rs (105-105), mlx_rs (108-108)
- rust/src/checkpoints/state.rs (1): new (25-45)
- rust/src/config/training.rs (1): from_lora_rank (16-24)
- rust/src/utils/mlx_memory.rs (2): get_active_memory (47-54), clear_cache (86-92)
- rust/src/distrust_loss.rs (2): mlx_rs (116-116), batch_empirical_distrust_loss (106-140)
- rust/src/training/scheduler.rs (2): get_lr (6-6), get_lr (26-37)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Build and Test
🔇 Additional comments (1)
rust/src/config/training.rs (1)
3-25: `TrainingMode::from_lora_rank` default targets look aligned with the new head naming (`head.*`).
Actionable comments posted: 3
♻️ Duplicate comments (4)
rust/src/config/training.rs (1)
58-85: Default `gradient_accumulation_steps: 1` changes effective batch/training dynamics; surface this in docs/CLI help.

This is a behavior change vs. the prior default and impacts comparability and LR sensitivity.

rust/src/training/trainer.rs (3)
404-417: Blocker: hardcoded absolute debug-log path + per-step file I/O in hot loop.
This will break on CI/other machines and is a major perf + disk-growth risk.

Also applies to: 423-453, 669-751, 894-942
1080-1145: Critical: `reload_from_checkpoint()` reloads a fresh model but doesn't restore head params / `adam_step` / momentum from the checkpoint file.

It parses `checkpoint_json` but never uses it; it repopulates GPU momentum from in-memory `self.adam_m` / `self.adam_v` (which may be empty after restart), and head weights are left at fresh init.
1282-1302: Replace the `add(0.0)` detach workaround with the MLX stop-gradient API exposed by mlx-rs.

The current workaround is implementation-dependent and undermines the "zero-leak/stability" goal.

mlx-rs 0.25.2: what is the correct Rust API/function name for stop_gradient / detach (e.g., stop_gradient_device), and what module path exports it?
🧹 Nitpick comments (2)
rust/src/config/training.rs (2)
14-25: Consider making the `FullFineTune` default targets a named constant and documenting how these names map to actual parameter keys.

Right now the target strings are "magic" and it's easy for them to drift from `TrainableHead.parameters().flatten()` naming; a sketch follows below.
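A sketch of the named-constant idea, reusing the enum shape described in this review; the variant payloads and exact field types are assumptions.

```rust
/// Parameter-name prefixes the head-only fine-tune path updates. Keeping them in
/// one named constant documents the coupling to the model's `head.*` naming and
/// gives a single place to update if the head changes.
pub const FULL_FINE_TUNE_TARGETS: &[&str] = &["head.lm_head", "head.norm"];

#[derive(Debug, Clone)]
pub enum TrainingMode {
    /// Train LoRA adapters of the given rank.
    LoRA { rank: usize },
    /// Train only the listed parameter targets (defaults to the head).
    FullFineTune { targets: Vec<String> },
    /// Everything frozen (no trainable parameters).
    Frozen,
}

impl TrainingMode {
    /// Auto-detect the mode from the configured LoRA rank.
    pub fn from_lora_rank(rank: usize) -> Self {
        if rank > 0 {
            TrainingMode::LoRA { rank }
        } else {
            TrainingMode::FullFineTune {
                targets: FULL_FINE_TUNE_TARGETS.iter().map(|s| s.to_string()).collect(),
            }
        }
    }
}
```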
29-32: `training_mode` is `#[serde(skip)]`: config/CLI can't set `Frozen`/`FullFineTune` explicitly.

If that's intentional, please add an explicit doc comment that this is runtime-only (derived from `lora_rank`) and not user-configurable via config files.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- rust/src/config/training.rs (3 hunks)
- rust/src/training/trainer.rs (20 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-30T23:47:27.584Z
Learnt from: CR
Repo: arosboro/your_ai PR: 0
File: GEMINI.md:0-0
Timestamp: 2025-11-30T23:47:27.584Z
Learning: Applies to src/distrust_loss.py : Implement the core `empirical_distrust_loss` function that penalizes high-authority, low-entropy sources and rewards primary empirical sources
Applied to files:
rust/src/training/trainer.rs
🧬 Code graph analysis (1)
rust/src/config/training.rs (5)
- rust/src/config/distrust.rs (1): default (16-21)
- rust/src/config/paths.rs (1): default (15-23)
- rust/src/config/mod.rs (1): default (30-41)
- rust/src/config/model.rs (1): default (20-37)
- rust/src/config/performance.rs (1): default (37-57)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Build and Test
🔇 Additional comments (1)
rust/src/training/trainer.rs (1)
1226-1234: Seq-len now respects config with caps; good fix.

This makes `TrainingConfig.max_seq_length` / `train_seq_length` meaningful and avoids the prior "silent override" behavior.
```rust
// Check if model reload needed to reset MLX memory
let reload_interval = self.config.training.reload_interval_steps;
let reload_threshold_gb = self.config.training.reload_memory_threshold_gb;
let should_reload = if reload_interval > 0
    && self.global_step > 0
    && self.global_step.is_multiple_of(reload_interval)
{
    true
} else if let Ok(current_mem) = crate::utils::mlx_memory::get_active_memory() {
    let current_mem_gb = current_mem as f64 / 1024.0 / 1024.0 / 1024.0;
    current_mem_gb > reload_threshold_gb && self.global_step > 0
} else {
    false
};

if should_reload {
    // Save checkpoint before reload
    let checkpoint_path = PathBuf::from(&self.config.paths.output_dir)
        .join(format!("checkpoint-step-{}.json", self.global_step));

    if let Err(e) = self.save_checkpoint(self.global_step, false) {
        eprintln!("Warning: Failed to save checkpoint before reload: {}", e);
    } else {
        // Reload model to reset MLX memory
        match self.reload_from_checkpoint(&checkpoint_path) {
            Ok(()) => {
                if let Ok(mem) = crate::utils::mlx_memory::get_active_memory() {
                    let mem_gb = mem as f64 / 1024.0 / 1024.0 / 1024.0;
                    println!("   Current MLX memory after reload: {:.2} GB", mem_gb);
                }
            }
            Err(e) => {
                eprintln!("Warning: Model reload failed: {}", e);
                eprintln!("Continuing training without reload...");
            }
        }
    }
}
```
Critical: reload path can try to load a checkpoint file that was never written when checkpointing is disabled.
`save_checkpoint()` is gated by `if let Some(..) = self.checkpoint_manager` and otherwise returns `Ok(())`, so the reload block can proceed to `read_to_string(checkpoint_path)` on a non-existent file.
```diff
diff --git a/rust/src/training/trainer.rs b/rust/src/training/trainer.rs
@@
-        if should_reload {
+        if should_reload {
+            if self.checkpoint_manager.is_none() {
+                eprintln!("Warning: reload requested but checkpointing is disabled; skipping reload.");
+                // Alternatively: anyhow::bail!(...) to make this explicit.
+                goto_continue_training_without_reload;
+            }
             // Save checkpoint before reload
             let checkpoint_path = PathBuf::from(&self.config.paths.output_dir)
                 .join(format!("checkpoint-step-{}.json", self.global_step));
@@
             if let Err(e) = self.save_checkpoint(self.global_step, false) {
                 eprintln!("Warning: Failed to save checkpoint before reload: {}", e);
             } else {
                 // Reload model to reset MLX memory
                 match self.reload_from_checkpoint(&checkpoint_path) {
@@
             }
         }
```

(If you don't want a label, the simpler alternative is: make `save_checkpoint()` return an error when `checkpoint_manager.is_none()` and it's called.)
Also applies to: 1433-1526
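A minimal sketch of that simpler alternative; the signature is inferred from the call `self.save_checkpoint(self.global_step, false)` above, and the second parameter's name is an assumption.

```rust
pub fn save_checkpoint(&mut self, step: usize, is_best: bool) -> anyhow::Result<()> {
    // Refuse to "succeed" when there is nothing that can actually be saved.
    let Some(manager) = self.checkpoint_manager.as_mut() else {
        anyhow::bail!("checkpointing disabled: no checkpoint_manager configured (step {step})");
    };
    // ... existing save logic via `manager` and `is_best`, unchanged ...
    Ok(())
}
```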
🤖 Prompt for AI Agents
In rust/src/training/trainer.rs around lines 537 to 574, the reload logic may
attempt to read a checkpoint file that was never written when checkpointing is
disabled because save_checkpoint() is a no-op returning Ok(()). Either (A)
update the reload block to check self.checkpoint_manager.is_some() (and skip
reload / log a warning if None) before calling save_checkpoint and constructing
checkpoint_path, or (B) change save_checkpoint to return an Err when
checkpoint_manager.is_none() and propagate that error here so reload is not
attempted; implement one of these fixes consistently (also apply same fix to the
other occurrence at ~1433-1526).
@coderabbitai Should I merge it? If not, give me an AI prompt to fix the remaining discussions.
Tip: For best results, initiate chat on the files or code changes.

@arosboro Unfortunately, you should NOT merge this PR yet. There are still 4 Critical and 1 Major unresolved issues that could cause serious problems:

🔴 Critical Issues Remaining:
AI Prompt to Fix All Remaining Issues:

Once these issues are fixed, particularly the 4 critical ones, the PR will be ready to merge. The most urgent are #1 (portability), #2 (correctness), and #4 (stability).
Actionable comments posted: 7
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
rust/src/training/trainer.rs (1)
426-491: Progress bar length should match the enforced step limit (`calculated_max_steps`).

Right now the bar length uses `config.training.max_steps` but the loop runs to `calculated_max_steps`.
♻️ Duplicate comments (7)
rust/src/config/training.rs (1)
58-85: Behavior change: default `gradient_accumulation_steps = 1` (effective batch changes).

This impacts training dynamics and metrics comparability; worth surfacing in docs/CLI help.

rust/src/training/trainer.rs (6)
491-552: Blocker: absolute debug-log path + repeated file opens in the hot loop.
This will break on CI/other machines and adds avoidable per-step I/O.

Suggested direction (minimal): gate behind an env var (e.g., `YOUR_AI_DEBUG_LOG`) and centralize all writes via `log_debug()`; also update `log_debug()` to use the same env var and return early when unset.

```bash
#!/bin/bash
# Find all hardcoded debug log paths that need to be replaced/gated.
rg -n '/Users/arosboro/your_ai/\.cursor/debug\.log' rust/src
```

Also applies to: 811-913, 1056-1104
694-716: Blocker: reload path can try to reload a checkpoint that was never written when checkpointing is disabled.

```diff
     if should_reload {
+        if self.checkpoint_manager.is_none() {
+            eprintln!("Warning: reload requested but checkpointing is disabled; skipping reload.");
+        } else {
         // Save checkpoint before reload
         let checkpoint_path = PathBuf::from(&self.config.paths.output_dir)
             .join(format!("checkpoint-step-{}.json", self.global_step));

         if let Err(e) = self.save_checkpoint(self.global_step, false) {
             eprintln!("Warning: Failed to save checkpoint before reload: {}", e);
         } else {
             // Reload model to reset MLX memory
             match self.reload_from_checkpoint(&checkpoint_path) {
                 ...
             }
         }
+        }
     }
```
1117-1122: AdamW hyperparams are hardcoded (config values are ignored).

```diff
- let beta1 = 0.9f32;
- let beta2 = 0.999f32;
- let eps = 1e-8f32;
+ let beta1 = self.config.training.adam_beta1;
+ let beta2 = self.config.training.adam_beta2;
+ let eps = self.config.training.adam_epsilon;
```
1242-1309: Blocker: `reload_from_checkpoint()` drops state (head params + `adam_step`) and doesn't actually restore from the checkpoint file.

Reloading should resume training, not "warm up" from fresh random head weights. At minimum (a minimal sketch follows this list):

- Parse `checkpoint_json["optimizer_state"]`
- Restore `self.adam_step`
- Restore `param.<name>` tensors into `self.model.head` parameters
- Restore `*.m` / `*.v` into `adam_m_gpu` / `adam_v_gpu` (and clear CPU caches or keep them consistent)
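A hedged sketch of the first two bullets, assuming the checkpoint JSON carries an `optimizer_state` object with an `adam_step` field (the exact schema is not shown in this review):

```rust
use serde_json::Value;

/// Pull the optimizer step back out of a saved checkpoint JSON string.
/// Returns Ok(None) if the file doesn't carry the expected fields.
fn restore_adam_step(checkpoint_json: &str) -> anyhow::Result<Option<u64>> {
    let v: Value = serde_json::from_str(checkpoint_json)?;
    Ok(v.get("optimizer_state")
        .and_then(|s| s.get("adam_step"))
        .and_then(Value::as_u64))
}
```

In `reload_from_checkpoint()` this would feed `self.adam_step`, with analogous loops deserializing the saved head tensors and the `*.m` / `*.v` moments back into `self.model.head` and `adam_m_gpu` / `adam_v_gpu`.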
1446-1465: Blocker: backbone "detach" relies on `crate::utils::mlx_memory::stop_gradient()`, which currently uses `add(0)` (likely not a real detach).

If detach fails, gradients can flow into the backbone and you lose the "head-only grads" guarantee.

Confirm the correct mlx-rs 0.25.2 API to stop gradients (detach) and update trainer.rs to call it directly (or fix `crate::utils::mlx_memory::stop_gradient()` to delegate to the native op).
176-181: LoRA dropout is still hardcoded to `0.0` (ignores config).

```diff
  let lora_config = crate::training::lora::LoraConfig {
      rank: lora_rank,
      alpha: config.model.lora_alpha,
-     dropout: 0.0,
+     dropout: config.model.lora_dropout,
      target_modules,
  };
```
🧹 Nitpick comments (1)
rust/src/config/training.rs (1)
29-56: Consider serializing `training_mode` (or documenting clearly that it's derived-only).

Right now it's `#[serde(skip)]`, so configs/checkpoints won't reflect the actual selected mode.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- TRAINING_SUCCESS_SUMMARY.md (1 hunks)
- VALIDATION_GUIDE.md (1 hunks)
- rust/TEST_STATUS.md (1 hunks)
- rust/src/config/training.rs (3 hunks)
- rust/src/training/trainer.rs (22 hunks)
- rust/src/utils/memory.rs (2 hunks)
- rust/src/utils/mlx_memory.rs (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- rust/src/utils/memory.rs
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-30T23:47:27.584Z
Learnt from: CR
Repo: arosboro/your_ai PR: 0
File: GEMINI.md:0-0
Timestamp: 2025-11-30T23:47:27.584Z
Learning: Applies to src/distrust_loss.py : Implement the core `empirical_distrust_loss` function that penalizes high-authority, low-entropy sources and rewards primary empirical sources
Applied to files:
rust/src/training/trainer.rs
🧬 Code graph analysis (1)
rust/src/utils/mlx_memory.rs (1)
rust/src/utils/mod.rs (1)
mlx_memory(2-2)
🪛 LanguageTool
TRAINING_SUCCESS_SUMMARY.md
[grammar] ~162-~162: Ensure spelling is correct
Context: ...-------------- | ------- | | Llama 8B abliterated | 100% | 100% ...
(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)
🪛 markdownlint-cli2 (0.18.1)
VALIDATION_GUIDE.md
54-54: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
59-59: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
69-69: Emphasis used instead of a heading
(MD036, no-emphasis-as-heading)
78-78: Emphasis used instead of a heading
(MD036, no-emphasis-as-heading)
319-319: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
329-329: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
TRAINING_SUCCESS_SUMMARY.md
39-39: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
133-133: Emphasis used instead of a heading
(MD036, no-emphasis-as-heading)
147-147: Emphasis used instead of a heading
(MD036, no-emphasis-as-heading)
183-183: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
rust/TEST_STATUS.md
19-19: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
36-36: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
46-46: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
82-82: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Build and Test
- GitHub Check: Lint
```rust
/// Stop gradient on an Array (detach from computation graph)
///
/// Prevents gradients from flowing back through this Array during backward pass.
///
/// # Implementation Note
/// MLX C API has `mlx_stop_gradient` (mlx/c/ops.h:994) but mlx-rs doesn't expose it.
/// This uses the standard `add(0)` workaround which creates a new Array with identical
/// values but disconnected from the computation graph. This is the recommended approach
/// in the MLX community until mlx-rs provides native support.
///
/// # Why This Works
/// The addition operation creates a new Array that:
/// - Contains the same data
/// - Is allocated in a new memory location
/// - Has no parent nodes in the computation graph
/// - Blocks gradient flow during backpropagation
pub fn stop_gradient(array: &mlx_rs::Array) -> mlx_rs::error::Result<mlx_rs::Array> {
    use mlx_rs::Array;
    array.add(Array::from_f32(0.0))
}
```
🧩 Analysis chain
🌐 Web query:
mlx-rs 0.25.2 stop_gradient API Array detach Rust
💡 Result:
Use MLX's stop_gradient op — in Rust (mlx-rs) call the ops stop_gradient on your Array to "detach" it. Example (pseudo-Rust):
let detached = mlx_rs::ops::stop_gradient(&array, None)?; // same identity value but blocks gradients
MLX documents the stop_gradient operation and its semantics in the ops API, and mlx-rs exposes MLX operations via the ops module. See MLX ops docs and the mlx-rs repository for the Rust bindings. [1][2]
Sources:
- MLX ops: stop_gradient (operation stops gradient flow). [1]
- mlx-rs repo / README (Rust bindings + ops usage notes). [2]
References:
[1] MLX operations docs (stop_gradient).
[2] oxideai/mlx-rs GitHub (README and bindings).
🌐 Web query:
mlx_stop_gradient mlx C API how it works computation graph
💡 Result:
Short answer: mlx_stop_gradient is an identity operation that marks an array as non-differentiable in MLX’s autograd. It returns the same array (no value change) but the autograd/computation graph treats that output as having no backward rule, so no gradients are propagated through it (effectively zeros any upstream gradient flow). Use cases: freeze parameters, cut gradient flow inside a model, or create non-differentiable inputs for grad computations. [1][2][3]
Notes / where documented:
- MLX C overview (API shape, transforms and autograd exist in the C bindings). [1]
- mlx.core.stop_gradient doc (Python): “Stop gradients from being computed. The operation is the identity but it prevents gradients from flowing through the array.”. [2]
- Language bindings (R/CLI) expose mlx_stop_gradient in their API lists (shows it’s part of core autograd API). [3]
🏁 Script executed:
# Check mlx-rs version and look for stop_gradient references
rg "stop_gradient" --type rust
# Also check for any mlx-rs version specification
fd "Cargo.toml" | head -5 | xargs rg "mlx-rs" -A 2

Repository: arosboro/your_ai
Length of output: 849
🏁 Script executed:
# Inspect the trainer.rs file to understand the dependency on stop_gradient
fd "trainer.rs" -t f | xargs cat -n | head -100

Repository: arosboro/your_ai
Length of output: 5099
🌐 Web query:
mlx-rs 0.25.2 ops::stop_gradient API signature parameters
💡 Result:
I couldn't find a direct source file in the mlx-rs repo showing ops::stop_gradient for v0.25.2. The MLX API's stop_gradient (Python) is: stop_gradient(a: array, *, stream: None | Stream | Device = None) -> array. (ml-explore.github.io)
In mlx-rs the equivalent is exposed as a simple unary op that takes an Array and returns an Array — i.e. the Rust signature you should expect is roughly:
- pub fn stop_gradient(a: &Array) -> Result<Array, Exception>
(mlx-rs wraps the MLX unary operator; any optional stream/device is handled by higher-level device/stream APIs in the crate.) This is inferred from the mlx docs and the mlx-rs crate documentation. (ml-explore.github.io)
If you want, I can (a) search the mlx-rs repo for the exact function/line if you want me to try again, or (b) open the crate source in your local environment and show the exact signature.
Citations:
- 1: https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.stop_gradient.html?utm_source=openai
- 2: https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.stop_gradient.html?utm_source=openai
🌐 Web query:
mlx-rs oxideai stop_gradient ops module 0.25.2 documentation
💡 Result:
I couldn't find a dedicated mlx-rs (Rust) docs page for an ops::stop_gradient function in v0.25.2, but relevant documentation and references:
- mlx (Python) docs for stop_gradient — description and signature: stop_gradient(a: array, *, stream=None) -> array (identity that blocks gradients). [1]
- mlx-rs crate entry (docs.rs) and note that full docs are hosted on the project's GitHub Pages; check the repo for ops source in src/ops (mlx-rs v0.25.2). [2]
- oxideai/mlx-rs GitHub repo (source code / ops modules) — search src/ops in the repo for stop_gradient or related gradient ops. [3]
If you want, I can:
- Search the mlx-rs source tree on GitHub now and locate the exact Rust function (module path, signature, and docstring) for v0.25.2 and return the exact code/doc excerpt.
🏁 Script executed:
# Check if mlx_rs::ops is already imported or used anywhere
rg "mlx_rs::ops" --type rust
# Check for any existing ops usage pattern
rg "ops::" --type rust | head -20

Repository: arosboro/your_ai
Length of output: 3317
🏁 Script executed:
# Look at the actual mlx_memory.rs file to see full context
sed -n '85,120p' rust/src/utils/mlx_memory.rs

Repository: arosboro/your_ai
Length of output: 1161
🏁 Script executed:
# Search for stop_gradient in mlx-rs if it's vendored or in dependencies
find . -name "*.rs" -path "*/mlx*" | xargs rg "stop_gradient" 2>/dev/null | head -20
# Check if mlx-rs is vendored or in patches
ls -la rust/patches/

Repository: arosboro/your_ai
Length of output: 458
🏁 Script executed:
# Try to find mlx-rs source or check what ops are available
rg "pub fn stop_gradient" --type rust
# Also search for references to "stop_gradient" anywhere
rg "stop_gradient" . --type rust

Repository: arosboro/your_ai
Length of output: 640
🌐 Web query:
site:github.com oxideai mlx-rs stop_gradient ops v0.25.2 source code
💡 Result:
I found the mlx-rs repo and the MLX stop_gradient operation documentation. Relevant sources:
- mlx-rs GitHub repository (oxideai/mlx-rs). [1]
- MLX Python docs for stop_gradient (description). [2]
- MLX C++ ops docs showing stop_gradient signature. [3]
If you want, I can fetch the exact mlx-rs source file (ops wrapper) that exposes stop_gradient and show the implementation for a specific tag (e.g., v0.25.2) — tell me which you prefer.
Sources:
[1] https://github.com/oxideai/mlx-rs.
[2] https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.stop_gradient.html.
[3] https://ml-explore.github.io/mlx/build/html/cpp/ops.html.
🌐 Web query:
"mlx-rs" "stop_gradient" github oxideai rust binding
💡 Result:
Yes — mlx (the upstream library) provides stop_gradient, and the oxideai/mlx-rs project is the Rust binding for MLX. The MLX docs show stop_gradient as an operation, and the mlx-rs repo (oxideai/mlx-rs) is the Rust wrapper that exposes MLX functionality. [1][2]
If you want the exact Rust symbol/path (e.g., ops::stop_gradient or mlx::core::stop_gradient) I can search the mlx-rs source tree and give the precise module/function name — would you like me to do that?
Sources:
- oxideai/mlx-rs (GitHub repo). [1]
- MLX docs: mlx.core.stop_gradient / C++ ops listing. [2]
🏁 Script executed:
# Check mlx-sys patches to see what's available there
cat rust/patches/mlx-sys/Cargo.toml | head -30
# Look for any C bindings or FFI definitions related to stop_gradient
find rust/patches -name "*.rs" -o -name "*.h" | xargs rg "stop_gradient" 2>/dev/null

Repository: arosboro/your_ai
Length of output: 1061
🏁 Script executed:
# Check if mlx-rs exposes stop_gradient by looking at what operations are wrapped
# First, let's see what the actual pattern is for accessing ops in the codebase
# Look at how mlx_rs::ops functions are called
rg "mlx_rs::ops::" --type rust -A 1 | head -40
# Also check imports to see if there's a pattern
rg "use mlx_rs" --type rust | sort | uniq

Repository: arosboro/your_ai
Length of output: 3588
🏁 Script executed:
# Try to infer from Cargo.lock if mlx-rs 0.25.2 source is available
if [ -f "Cargo.lock" ]; then
grep -A 10 "^name = \"mlx-rs\"" Cargo.lock | head -20
else
echo "No Cargo.lock found"
fi
# Check what version of mlx-rs is actually locked
grep -r "mlx-rs" rust/Cargo.lock 2>/dev/null | head -5

Repository: arosboro/your_ai
Length of output: 79
Replace the add(0) workaround with the native MLX stop_gradient operation.
The current implementation using array.add(Array::from_f32(0.0)) does not actually stop gradients in MLX's autograd system. MLX's stop_gradient is an identity operation that explicitly marks an array as non-differentiable; add(0) may create computation graph nodes that allow gradients to flow through, breaking the "frozen backbone" design in trainer.rs.
Use mlx_rs::ops::stop_gradient(&array) (matching the pattern of other ops like softmax_axis, expand_dims, etc. already in the codebase), which wraps the native MLX C API (mlx_stop_gradient). Update the doc comment to reflect the correct behavior and remove the misleading claim that mlx-rs doesn't expose it.
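A minimal sketch of the suggested replacement for the helper in `rust/src/utils/mlx_memory.rs`, assuming (as this review concludes, but without independent confirmation here) that mlx-rs 0.25.2 exposes the native op as `mlx_rs::ops::stop_gradient`:

```rust
/// Detach an Array from the computation graph so no gradients flow through it.
/// Identity on values; only autograd behavior changes.
pub fn stop_gradient(array: &mlx_rs::Array) -> mlx_rs::error::Result<mlx_rs::Array> {
    // Assumed API path and signature; verify against the locked mlx-rs version.
    mlx_rs::ops::stop_gradient(array)
}
```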
🤖 Prompt for AI Agents
rust/src/utils/mlx_memory.rs around lines 94 to 113: the function currently uses
array.add(Array::from_f32(0.0)) which does not reliably stop gradients; replace
that call with the native op mlx_rs::ops::stop_gradient(array) (importing
mlx_rs::ops::stop_gradient or calling it fully-qualified) so the array is
explicitly marked non-differentiable, keep the same return type, and update the
doc comment to remove the incorrect claim that mlx-rs lacks stop_gradient and to
state that this uses the native mlx_stop_gradient via
mlx_rs::ops::stop_gradient.
### Unit Tests: 14/16 PASSING (87.5%)

**Passing Tests (14):**
```
✅ distrust_loss::tests::test_basic_calculation
✅ distrust_loss::tests::test_invalid_alpha
✅ distrust_loss::tests::test_invalid_authority_weight
✅ distrust_loss::tests::test_invalid_provenance_entropy
✅ hardware::detection::tests::test_get_gpu_cores
✅ hardware::scaling::tests::test_memory_estimation
✅ hardware::scaling::tests::test_detect_model_size
✅ model::loader::tests::test_model_loader_creation
✅ training::scheduler::tests::test_warmup_cosine_schedule
✅ utils::memory::tests::test_format_bytes
✅ citation_scorer::tests::test_extract_year
✅ citation_scorer::tests::test_count_citations
✅ (+ 2 more utility tests)
```

**Failing Tests (2):**
```
❌ utils::memory::tests::test_memory_info
❌ utils::memory::tests::test_memory_monitor
```

### Root Cause of Test Failures

**Issue:** MLX Metal device initialization crash in test environment

**Error:**
```
NSRangeException: '*** -[__NSArray0 objectAtIndex:]: index 0 beyond bounds for empty array'
at mlx::core::metal::Device::Device()
```

**Explanation:**
- MLX tries to enumerate Metal GPU devices when test binary loads
- In test/CI environments, Metal framework may not be fully initialized
- This is a **known MLX-rs limitation**, not a bug in our code
- Tests crash before they even run

**Impact:**
- Memory tests use system calls (ps, sysctl), not MLX
- They work fine in production (verified via 50-step training run)
- Crash is environmental, not functional

**Mitigation:**
- Tests marked with `#[ignore]` to skip in automated runs
- Can be run individually with `--ignored` flag when Metal is available
- Production training fully validated (6m 23s run, all functionality verified)
Test numbers are internally inconsistent (passing vs skipped vs crashing).
Please reconcile the “14/16 passing” claim with the later cargo test note (“12 pass, 2 skip, 2 crash”) and describe the expected CI command/result unambiguously.
Also applies to: 149-161
🧰 Tools
🪛 markdownlint-cli2 (0.18.1)
19-19: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
36-36: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
46-46: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
🤖 Prompt for AI Agents
In rust/TEST_STATUS.md around lines 16 to 66 (and also lines 149 to 161), the
test-summary is inconsistent: it claims "Unit Tests: 14/16 PASSING (87.5%)" but
later notes "12 pass, 2 skip, 2 crash" and the crash/skip counts are unclear;
update the file to present a single, unambiguous test-result summary and the
exact CI invocation to reproduce expected behavior. Specifically, change the top
summary to accurately reflect counts (e.g., "12 passed, 2 ignored, 2 crashed —
12/16 passing (75%)" or whatever the true numbers are), mark the two
Metal-dependent tests as ignored in the listing and explain that crashes are due
to MLX Metal initialization, and add the exact commands to run tests in CI and
locally (e.g., `cargo test --tests -- --ignored` to run ignored tests when Metal
is available) plus the expected exit code/behavior for CI (successful run with
ignored tests skipped). Ensure both locations (lines 16–66 and 149–161) are
updated to match verbatim.
Add language identifiers to fenced blocks (markdownlint MD040).
Also applies to: 82-90
🧰 Tools
🪛 markdownlint-cli2 (0.18.1)
19-19: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
36-36: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
46-46: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
🤖 Prompt for AI Agents
In rust/TEST_STATUS.md around lines 19 to 50 (and also lines 82 to 90), the
fenced code blocks lack language identifiers which triggers markdownlint MD040;
update each triple-backtick fence to include an appropriate language tag (e.g.,
"text", "none", or "bash") so the blocks are explicitly marked, ensuring the
linter passes and rendering stays correct.
```
LlamaForCausalLM
├── LlamaBackbone (frozen, 514 params)
│   ├── embed_tokens
│   └── layers[0-31]
└── TrainableHead (gradients, 2-3 params)
    ├── norm
    └── lm_head
```
Fix markdownlint MD040/MD036 (language tags + real headings).
Also double-check whether the model name spelling (“abliterated”) is intentional (LanguageTool flagged it).
Also applies to: 133-145, 183-191
🧰 Tools
🪛 markdownlint-cli2 (0.18.1)
39-39: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
🤖 Prompt for AI Agents
In TRAINING_SUCCESS_SUMMARY.md around lines 39-47 (and also apply the same
changes to 133-145 and 183-191), the fenced block lacks a language tag and the
document uses non-standard/incorrect heading styling; add an explicit language
tag after each ``` (e.g., ```text or ```dot if these are tree diagrams) and
replace any faux or emphasized headings with proper markdown headings (#+ space
+ heading text) so they are recognized as real headings; additionally verify and
correct the model name spelling “abliterated” to the intended term (e.g.,
“obliterated”) everywhere it appears.
**Optimizations:**
- Detached backbone forward using `add(0)` trick (no CPU extraction)
- GPU-only AdamW optimizer (momentum stored as GPU Arrays)
- No `as_slice()` calls during training (eliminates CPU transfer leaks)
- Configurable sequence length (default: max_seq_length.min(512))
Doc should match the implementation: avoid claiming add(0) detach + avoid hardcoded log locations.
Once code is fixed to use the real stop-gradient op and configurable logging, update this section accordingly.
Also applies to: 254-255
🤖 Prompt for AI Agents
In TRAINING_SUCCESS_SUMMARY.md around lines 57-61 (and also update lines
254-255), the documentation incorrectly claims use of the `add(0)` detach trick
and hardcoded log locations; update the text to reflect the actual
implementation: remove the `add(0)`/no-CPU-extraction claim and instead state
that a proper stop-gradient operation is used, and replace any mention of fixed
log paths with a note that logging locations are configurable (mention the
config/env var key used). Keep wording concise and consistent with real code
behavior.
```bash
cd /Users/arosboro/your_ai
source venv/bin/activate  # or create new venv
pip install mlx mlx-lm transformers
```
Replace user-specific absolute paths with portable commands.
E.g. /Users/arosboro/your_ai should be something like cd <repo-root> or cd "$(git rev-parse --show-toplevel)".
Also applies to: 69-86, 259-272
🤖 Prompt for AI Agents
In VALIDATION_GUIDE.md around lines 45 to 49 (and similarly at 69-86 and
259-272), the example uses a user-specific absolute path
(/Users/arosboro/your_ai); replace that with a portable repository-root command
such as cd <repo-root> or cd "$(git rev-parse --show-toplevel)" and use
relative/neutral instructions for activating/creating the virtualenv and
installing packages so the guide works on any machine (update the three affected
sections consistently).
```
rust/models/distrust-llama-8b/checkpoint-step-50.json
```

But validation requires full MLX model directory:

```
models/distrust-llama-8b/
├── config.json
├── tokenizer.json
├── model.safetensors (or sharded)
└── (LoRA adapters if applicable)
```
Fix markdownlint issues: add languages to fenced blocks + use real headings.
This should address MD040/MD036 and improves readability.
Also applies to: 319-333
🧰 Tools
🪛 markdownlint-cli2 (0.18.1)
54-54: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
59-59: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
🤖 Prompt for AI Agents
In VALIDATION_GUIDE.md around lines 54 to 65 (and also apply same changes to
319-333), the markdown uses fenced code blocks without language identifiers and
non-standard heading styling; update each triple-backtick fence to include an
appropriate language tag (e.g., ```text or ```bash or ```json as appropriate)
and replace any fake headings or underlined text with proper Markdown ATX
headings (e.g., use `##`/`###`) so the document satisfies markdownlint rules
MD040 and MD036 and improves readability.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
rust/src/model/llama.rs (1)
719-724: `extra_keys` will be wrong when legacy-name mapping loads weights.
A key loaded via `legacy_name` (e.g., `model.norm.weight`) won't be in `parameters`, so it'll be flagged "extra" even if it was used. Track "matched weight keys" during loading and exclude those here.

```diff
 pub fn load_weights_into_model(
     model: &mut LlamaForCausalLM,
     weights: HashMap<String, Array>,
 ) -> anyhow::Result<()> {
 @@
     let mut loaded_count = 0;
     let mut missing_keys: Vec<String> = Vec::new();
     let mut extra_keys: Vec<String> = Vec::new();
+    let mut matched_weight_keys: std::collections::HashSet<String> = std::collections::HashSet::new();
 @@
-    if let Some(weight_array) = weights.get(&param_name_str) {
+    if let Some(weight_array) = weights.get(&param_name_str) {
         if weight_array.shape() == param.shape() {
             **param = weight_array.clone();
             let _ = param.eval();
             loaded_count += 1;
+            matched_weight_keys.insert(param_name_str.clone());
             continue;
         }
     }
 @@
-    if let Some(weight_array) = weights.get(&legacy_name) {
+    if let Some(weight_array) = weights.get(&legacy_name) {
         if weight_array.shape() == param.shape() {
             **param = weight_array.clone();
             let _ = param.eval();
             loaded_count += 1;
+            matched_weight_keys.insert(legacy_name.clone());
             continue;
         } else {
 @@
     for weight_key in weights.keys() {
-        if !parameters.contains_key(weight_key.as_str()) {
+        if !parameters.contains_key(weight_key.as_str()) && !matched_weight_keys.contains(weight_key) {
             extra_keys.push(weight_key.clone());
         }
     }
```
♻️ Duplicate comments (5)
rust/src/model/llama.rs (1)
663-717: Legacy name mapping now correctly handles flattened params (.weight/.bias).
This fixes the earlier exact-string matching issue (e.g., `head.norm.weight`).

rust/src/training/trainer.rs (4)
492-557: Debug logging is still doing hot-loop file I/O (and `/dev/null` makes it unconditional + non-portable).
Several blocks open/write every step even when `YOUR_AI_DEBUG_LOG` is unset, via `unwrap_or_else(|| PathBuf::from("/dev/null"))` (e.g., Lines 544, 825, 864, 884, 904). This is both a perf hit and breaks on non-Unix. Prefer routing all these blocks through `log_debug()` (which already gates on `debug_log_path()`), and delete the `/dev/null` fallbacks entirely.

```diff
-if let Ok(mut file) = std::fs::OpenOptions::new()
-    .create(true)
-    .append(true)
-    .open(debug_log_path().unwrap_or_else(|| PathBuf::from("/dev/null")))
-{
-    let json = serde_json::json!({ ... });
-    let _ = writeln!(file, "{}", json);
-}
+self.log_debug(
+    "trainer.rs:after_training_step",
+    "training_step returned successfully",
+    self.global_step,
+    "main_loop",
+);
```

Also applies to: 816-918, 1061-1111
699-721: Reload can attempt to read a checkpoint file that was never written when checkpointing is disabled.
`save_checkpoint()` is a no-op when `self.checkpoint_manager.is_none()` (Line 1609), but the reload path still proceeds to `reload_from_checkpoint(&checkpoint_path)` (Line 708), which reads the file (Line 1256). Gate reload on `checkpoint_manager.is_some()` or make `save_checkpoint()` error when disabled.

```diff
 if should_reload {
+    if self.checkpoint_manager.is_none() {
+        eprintln!("Warning: reload requested but checkpointing is disabled; skipping reload.");
+        // Or: anyhow::bail!("reload requested but checkpointing is disabled");
+    } else {
         // Save checkpoint before reload
         let checkpoint_path = PathBuf::from(&self.config.paths.output_dir)
             .join(format!("checkpoint-step-{}.json", self.global_step));
 @@
         match self.reload_from_checkpoint(&checkpoint_path) {
 @@
         }
     }
+    }
 }
```

Also applies to: 1608-1698
1250-1317: `reload_from_checkpoint()` does not restore head params / optimizer state from the checkpoint file.
You serialize head params into `optimizer_state` as `param.<name>` and momentum as `<name>.m`/`.v` (Lines 1624-1659), but reload currently parses the JSON and then ignores it, instead relying on in-memory `self.adam_m`/`self.adam_v` (Lines 1285-1304). That makes reload incorrect (and it won't work after restart).

```diff
 fn reload_from_checkpoint(&mut self, checkpoint_path: &PathBuf) -> anyhow::Result<()> {
 @@
-    // Step 2: Drop current model to free ALL MLX Arrays
+    let optimizer_state = checkpoint_json["optimizer_state"]
+        .as_object()
+        .ok_or_else(|| anyhow::anyhow!("Invalid checkpoint: missing optimizer_state"))?;
+
+    // Step 2: Drop current model to free ALL MLX Arrays
     let config_clone = self.model.config().clone();
     let lora_rank = self.model.lora_rank;
 @@
-    // Step 5: Restore trainable head weights from CPU cache (self.adam_m/v already have the data)
-    // We rely on the fact that parameters were just updated, so we copy from current head
-    // This avoids complex deserialization - simple approach for MVP
-
-    self.model = fresh_model;
+    self.model = fresh_model;
+
+    // Restore adam_step
+    if let Some(v) = optimizer_state.get("adam_step").and_then(|v| v.as_u64()) {
+        self.adam_step = v as usize;
+    }
+
+    // Restore head params
+    {
+        let mut head_params = self.model.head.parameters_mut().flatten();
+        for (k, v) in optimizer_state.iter() {
+            let Some(name) = k.strip_prefix("param.") else { continue };
+            let Some(obj) = v.as_object() else { continue };
+            let data: Vec<f32> = obj.get("data")
+                .and_then(|d| d.as_array())
+                .ok_or_else(|| anyhow::anyhow!("param {name} missing data"))?
+                .iter().filter_map(|x| x.as_f64()).map(|x| x as f32).collect();
+            let shape: Vec<i32> = obj.get("shape")
+                .and_then(|s| s.as_array())
+                .ok_or_else(|| anyhow::anyhow!("param {name} missing shape"))?
+                .iter().filter_map(|x| x.as_i64()).map(|x| x as i32).collect();
+            if let Some(p) = head_params.get_mut(name) {
+                let new_param = Array::from_slice(&data, &shape);
+                let _old = std::mem::replace(&mut **p, new_param);
+                drop(_old);
+            }
+        }
+    }
 @@
-    // Step 6: Restore optimizer momentum to GPU from CPU cache
-    for (param_name, (data, shape)) in &self.adam_m {
-        let m_array = Array::from_slice(data, shape);
-        let _ = m_array.eval();
-        self.adam_m_gpu.insert(param_name.clone(), m_array);
-    }
-    for (param_name, (data, shape)) in &self.adam_v {
-        let v_array = Array::from_slice(data, shape);
-        let _ = v_array.eval();
-        self.adam_v_gpu.insert(param_name.clone(), v_array);
-    }
+    // Restore momentum from checkpoint
+    self.adam_m.clear();
+    self.adam_v.clear();
+    for (k, v) in optimizer_state.iter() {
+        let Some(obj) = v.as_object() else { continue };
+        let data: Vec<f32> = obj.get("data").and_then(|d| d.as_array())
+            .into_iter().flatten()
+            .filter_map(|x| x.as_f64()).map(|x| x as f32).collect();
+        let shape: Vec<i32> = obj.get("shape").and_then(|s| s.as_array())
+            .into_iter().flatten()
+            .filter_map(|x| x.as_i64()).map(|x| x as i32).collect();
+        if let Some(param) = k.strip_suffix(".m") {
+            let m = Array::from_slice(&data, &shape);
+            let _ = m.eval();
+            self.adam_m_gpu.insert(param.to_string(), m);
+        } else if let Some(param) = k.strip_suffix(".v") {
+            let vv = Array::from_slice(&data, &shape);
+            let _ = vv.eval();
+            self.adam_v_gpu.insert(param.to_string(), vv);
+        }
```

Also applies to: 1624-1659
1456-1473: Use mlx-rs native stop-gradient API instead of a custom workaround.
Right now you call `crate::utils::mlx_memory::stop_gradient(&hidden)` with a comment that it's a workaround (Lines 1461-1463). In mlx-rs 0.25.2 there may already be a supported `stop_gradient(_device)` function; prefer that to avoid subtle graph/memory behavior changes.

mlx-rs 0.25.2 stop_gradient API: what is the correct function name/path (e.g., `mlx_rs::array::stop_gradient_device`), and what arguments does it take?
🧹 Nitpick comments (1)
rust/src/training/trainer.rs (1)
87-91: Prefer `var_os` for env paths (avoid UTF-8 assumptions).
`std::env::var("YOUR_AI_DEBUG_LOG")` can fail on non-UTF-8 paths; `var_os` is safer for filesystem paths.

```diff
 fn debug_log_path() -> Option<PathBuf> {
-    std::env::var("YOUR_AI_DEBUG_LOG").ok().map(PathBuf::from)
+    std::env::var_os("YOUR_AI_DEBUG_LOG").map(PathBuf::from)
 }
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- rust/src/model/llama.rs (6 hunks)
- rust/src/training/trainer.rs (23 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-30T23:47:27.584Z
Learnt from: CR
Repo: arosboro/your_ai PR: 0
File: GEMINI.md:0-0
Timestamp: 2025-11-30T23:47:27.584Z
Learning: Applies to src/distrust_loss.py : Implement the core `empirical_distrust_loss` function that penalizes high-authority, low-entropy sources and rewards primary empirical sources
Applied to files:
rust/src/training/trainer.rs
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Build and Test
- GitHub Check: Lint
🔇 Additional comments (3)
rust/src/model/llama.rs (2)
323-357: Rename to `self_attn` looks consistent.
This is a clean mechanical rename in ctor + forward and should be compiler-enforced across call sites.
424-546: Backbone/head split is clear and keeps config access consistent.
Nice separation (`config()` now returns `&self.backbone.config`), and the backbone forward returning hidden states matches the intended "head-only gradients" design.

rust/src/training/trainer.rs (1)
1124-1130: Good: AdamW hyperparams + seq_len are now config-driven (with safety caps).
This reduces surprising behavior and makes "stability" tunable.

Also applies to: 1398-1405
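As a small illustration of what "config-driven" means here, below is a sketch of a bias-corrected AdamW step that reads its hyperparameters from a config struct rather than hardcoding them. The field names (`adam_beta1`, `adam_beta2`, `adam_epsilon`, `weight_decay`) are assumptions echoing the review discussion, not the project's actual config type.

```rust
// Sketch: scalar AdamW step with hyperparameters taken from config, not hardcoded.
struct TrainingHyperparams {
    learning_rate: f32,
    adam_beta1: f32,   // e.g. 0.9
    adam_beta2: f32,   // e.g. 0.999
    adam_epsilon: f32, // e.g. 1e-8
    weight_decay: f32, // decoupled decay, the "W" in AdamW
}

fn adamw_update(
    hp: &TrainingHyperparams,
    param: f32,
    grad: f32,
    m: &mut f32, // first-moment estimate, carried across steps
    v: &mut f32, // second-moment estimate, carried across steps
    step: i32,   // 1-based step counter for bias correction
) -> f32 {
    // Update moment estimates.
    *m = hp.adam_beta1 * *m + (1.0 - hp.adam_beta1) * grad;
    *v = hp.adam_beta2 * *v + (1.0 - hp.adam_beta2) * grad * grad;
    // Bias correction.
    let m_hat = *m / (1.0 - hp.adam_beta1.powi(step));
    let v_hat = *v / (1.0 - hp.adam_beta2.powi(step));
    // Decoupled weight decay applied directly to the parameter, then the Adam step.
    let decayed = param - hp.learning_rate * hp.weight_decay * param;
    decayed - hp.learning_rate * m_hat / (v_hat.sqrt() + hp.adam_epsilon)
}
```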
@coderabbitai is it ready to merge, or give me an AI prompt to make it work
@arosboro, I have good news and one remaining critical issue to address.

📊 Current Status

PR State: The PR is CLOSED (not merged). It was closed on Dec 12, 2025 at 22:58:12Z without being merged.

Fixed Issues ✅ (5 of 6):
Remaining Critical Issue ❌ (1 of 6):

🔴 Critical Bug: Checkpoint Reload Without Guards

Location: rust/src/training/trainer.rs
Problem: When checkpointing is disabled (`checkpoint_manager` is `None`), the reload path can still attempt to read a checkpoint file that was never written.
🛠️ AI Prompt to Fix the Last Issue

This ensures reload is only attempted when `checkpoint_manager` is available to save/load checkpoint files. Run
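For reference, here is a minimal sketch of the guard being asked for, using simplified stand-in types. The field and method names (`checkpoint_manager`, `save_checkpoint`, `reload_from_checkpoint`, `global_step`) mirror the review comments above and are assumptions about the real trainer, not its actual API.

```rust
use std::path::PathBuf;

// Simplified stand-ins for the trainer pieces named in the review.
struct Trainer {
    checkpoint_manager: Option<()>, // None => checkpointing disabled
    output_dir: String,
    global_step: usize,
}

impl Trainer {
    fn save_checkpoint(&self) -> anyhow::Result<()> {
        Ok(()) // placeholder: real code writes checkpoint-step-N.json
    }

    fn reload_from_checkpoint(&mut self, _path: &PathBuf) -> anyhow::Result<()> {
        Ok(()) // placeholder: real code reads the file back
    }

    fn maybe_reload(&mut self, should_reload: bool) -> anyhow::Result<()> {
        if !should_reload {
            return Ok(());
        }
        // Guard: only save + reload when checkpointing is actually enabled,
        // so we never try to read a file that was never written.
        if self.checkpoint_manager.is_none() {
            eprintln!(
                "Warning: reload requested at step {} but checkpointing is disabled; skipping.",
                self.global_step
            );
            return Ok(());
        }
        let checkpoint_path = PathBuf::from(&self.output_dir)
            .join(format!("checkpoint-step-{}.json", self.global_step));
        self.save_checkpoint()?;
        self.reload_from_checkpoint(&checkpoint_path)
    }
}
```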
Summary by CodeRabbit