Refactor TitansMAG forward logic and fix memory state handling #32
base: main
Conversation
- Added `self.memory.reset_memory()` to `TitansMAG::forward` to ensure consistent state initialization, matching the assumption in `compute_gradients` that execution starts from `init_memory`. This prevents state leakage across independent forward calls (e.g., between batches); see the sketch after this list.
- Refactored `src/memory/titans/mag.rs` to use `NeuralMemory::mlp_forward` instead of a duplicated local helper function, improving code maintainability.
- Exposed `NeuralMemory::mlp_forward` as `pub(crate)` in `src/memory/titans/neural.rs` to support the refactoring.
- Added `test_titans_mag_deterministic_forward` to verify that repeated forward passes on the same input produce identical outputs, confirming the state reset logic works.

Co-authored-by: ryancinsight <55164720+ryancinsight@users.noreply.github.com>
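A minimal sketch of the reset-on-entry pattern the first bullet describes. The types below (`TitansMAG`, `NeuralMemory`, plain `Vec<f32>` tensors, the toy elementwise mix) are illustrative stand-ins under assumed signatures, not the crate's actual API:

```rust
// Illustrative stand-ins only; the real crate's types and signatures differ.
struct NeuralMemory {
    init_memory: Vec<f32>,
    state: Vec<f32>,
}

impl NeuralMemory {
    fn reset_memory(&mut self) {
        // Restore the fast weights to their initial values, matching the
        // starting state that `compute_gradients` assumes.
        self.state = self.init_memory.clone();
    }
}

struct TitansMAG {
    memory: NeuralMemory,
}

impl TitansMAG {
    fn forward(&mut self, input: &[f32]) -> Vec<f32> {
        // The fix: reset before any computation so independent calls
        // (e.g., separate batches) cannot leak state into one another.
        self.memory.reset_memory();
        // Rest of the forward pass elided; a toy elementwise mix stands in.
        input
            .iter()
            .zip(&self.memory.state)
            .map(|(x, m)| x + m)
            .collect()
    }
}
```

The design point is simply that `forward` touches no memory state before restoring it, so two calls with the same input are guaranteed to see identical fast weights.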
Review by RecurseML
🔍 Review performed on a004bfd..7f24bef
✨ No bugs found, your code is sparkling clean
✅ Files analyzed, no issues (2)
• src/memory/titans/mag.rs
• src/memory/titans/neural.rs
Code Review
This pull request introduces a critical fix to the TitansMAG forward pass by ensuring the memory state is reset between calls. This prevents state leakage across batches and aligns the forward pass with the assumptions of the backward pass, resolving a potential bug in training loops. The changes also include a valuable refactoring that removes duplicated code by centralizing the mlp_forward logic into the NeuralMemory implementation. A new regression test has been added to verify that the forward pass is now deterministic. The changes are well-implemented, clearly explained, and improve both the correctness and maintainability of the code.
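A plausible shape for that regression test, assuming a hypothetical `new_for_test` constructor and a `forward(&[f32]) -> Vec<f32>` signature; the real `test_titans_mag_deterministic_forward` may differ:

```rust
#[test]
fn test_titans_mag_deterministic_forward() {
    // Hypothetical constructor; the real test presumably builds the
    // module from whatever config the crate uses.
    let mut mag = TitansMAG::new_for_test();
    let input = vec![0.5_f32; 16];

    // Two independent forward calls on identical input must agree; without
    // the reset, state leaked from the first call would perturb the second.
    let first = mag.forward(&input);
    let second = mag.forward(&input);
    assert_eq!(first, second);
}
```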
This PR refactors the `TitansMAG` forward pass implementation to ensure correctness and maintainability.

Key changes:

- Added `reset_memory()` at the start of `forward`. In the Titans architecture, the neural memory is typically treated as "fast weights" learned over the context window. The gradient computation (`compute_gradients`) assumes the memory starts at `init_memory`. Without an explicit reset in `forward`, the memory state would persist across calls (e.g., across training batches), leading to a disconnect between the forward pass state and the backward pass assumptions. This fixes potential bugs in training loops.
- Removed the duplicated `mlp_forward` helper function in `mag.rs` and exposed the existing implementation in `NeuralMemory` (`neural.rs`) to the crate, eliminating code duplication; see the sketch after this list.
- Added `test_titans_mag_deterministic_forward` to confirm that the module is stateless across calls, as expected for a standard layer in this context. All existing tests passed.

PR created automatically by Jules for task 6565573961390046446 started by @ryancinsight
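A hedged sketch of the visibility change from the second bullet. The signature and the toy MLP body are assumptions; only the `pub(crate)` pattern mirrors what the PR describes:

```rust
// In src/memory/titans/neural.rs (assumed shape):
pub struct NeuralMemory {
    layers: Vec<Vec<f32>>,
}

impl NeuralMemory {
    // Previously a private helper; `pub(crate)` makes it callable from
    // sibling modules such as mag.rs without widening the public API.
    pub(crate) fn mlp_forward(&self, input: &[f32]) -> Vec<f32> {
        // Toy stand-in for the real MLP: elementwise weight + ReLU per layer.
        self.layers.iter().fold(input.to_vec(), |acc, layer| {
            acc.iter()
                .zip(layer)
                .map(|(a, w)| (a * w).max(0.0))
                .collect()
        })
    }
}

// In src/memory/titans/mag.rs, the duplicated local helper is deleted and
// call sites use the shared implementation, e.g. (hypothetical call site):
// let out = self.memory.mlp_forward(&gated_input);
```

Using `pub(crate)` rather than `pub` keeps the helper an internal detail while still letting `mag.rs` reuse the single implementation.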
High-level PR Summary
This PR refactors the `TitansMAG` forward pass to fix memory state handling and eliminate code duplication. The key fix adds `reset_memory()` at the start of `forward` to ensure the memory state is reset between calls, preventing state persistence across training batches that would violate backward pass assumptions. The refactor also removes a duplicated `mlp_forward` helper function and exposes the existing implementation from `NeuralMemory` as `pub(crate)` instead. A new test verifies that forward passes are now deterministic across calls.

⏱️ Estimated Review Time: 5-15 minutes
💡 Review Order Suggestion
1. src/memory/titans/neural.rs
2. src/memory/titans/mag.rs