Refactor and fix rollout routing replay (R3) #575

Merged
yueming-yuan merged 15 commits into main from feature/r3-upstream
Feb 18, 2026

@yueming-yuan
Collaborator

@yueming-yuan yueming-yuan commented Feb 9, 2026

Note: remember to apply dev/megatron.patch before running this PR.

Summary

  • Support rollout replay for qkv_format=bshd (non-packing)
  • Replace env-var-driven RoutingReplay with a cleaner BaseReplayManager / RoutingReplayManager pattern
  • Remove os.environ usage for replay stage control in favor of direct manager state
  • Generalize fill_routing_replay into _fill_replay_data with extracted layer registration logic in replay_utils
  • Delete old routing_replay.py, replaced by replay_base.py
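The move from os.environ to manager state can be pictured with a minimal sketch. The class below is hypothetical and is not the code in replay_base.py: the names SketchReplayManager and use_stage, and the "fallthrough" default stage, are invented for illustration. It only shows the idea of holding the replay stage on an object and restoring it afterwards, rather than mutating a process-wide environment variable.

```python
from contextlib import contextmanager


class SketchReplayManager:
    """Hypothetical stand-in for a replay manager; not the miles API."""

    def __init__(self):
        # The stage lives on the object instead of in os.environ.
        # "fallthrough" is an invented name for "no replay active".
        self.stage = "fallthrough"

    @contextmanager
    def use_stage(self, stage):
        # Temporarily switch stage and always restore the previous one,
        # even if the body raises.
        old_stage = self.stage
        self.stage = stage
        try:
            yield self
        finally:
            self.stage = old_stage


manager = SketchReplayManager()
with manager.use_stage("replay_forward"):
    inside = manager.stage
after = manager.stage
```

Because the stage is ordinary object state, it can be inspected and restored without string lookups in the environment, and separate managers can hold separate stages.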

Files changed

  • miles/utils/replay_base.py — New: replay framework (Replay, BaseReplayManager, RoutingReplayManager)
  • miles/utils/routing_replay.py — Deleted: old implementation
  • miles/backends/megatron_utils/replay_utils.py — New: MoE layer replay data distribution
  • miles/backends/megatron_utils/actor.py — Support bshd format; use all_replay_managers instead of RoutingReplay + os.environ
  • miles/backends/megatron_utils/model.py — Use all_replay_managers instead of os.environ
  • miles/ray/actor_group.py — Remove old ENABLE_ROUTING_REPLAY env var
  • miles/rollout/generate_utils/generate_endpoint_utils.py — Generalize _get_rollout_topk_from_response

CI changed

  • Re-apply dev/megatron.patch each time before running CI.

@gemini-code-assist
Contributor

Summary of Changes

This pull request significantly refactors the rollout routing replay (R3) system, moving away from a global environment variable-based control to a more object-oriented and extensible manager pattern. The changes improve the maintainability and flexibility of the replay mechanism, enabling support for new data formats like bshd and streamlining the process of recording and replaying expert routing decisions during model training.

Highlights

  • Replay System Refactor: The RoutingReplay mechanism has been refactored from an environment-variable-driven approach to a more modular BaseReplayManager and RoutingReplayManager pattern, improving extensibility and maintainability.
  • Environment Variable Removal: Usage of os.environ for controlling replay stages has been eliminated, leading to cleaner and more direct state management within the replay managers.
  • Generalized Replay Data Handling: The fill_routing_replay function has been generalized into _fill_replay_data, allowing for more flexible replay data population and introducing extracted layer registration logic in replay_utils.py.
  • QKV Format Support: Rollout replay now includes support for the qkv_format=bshd (non-packing) configuration, expanding its compatibility.
  • New Replay Framework: A new replay_base.py file introduces the core Replay class and the BaseReplayManager framework, completely replacing the old routing_replay.py.

Changelog

  • miles/backends/megatron_utils/actor.py
    • Removed direct os import and RoutingReplay import.
    • Integrated all_replay_managers for centralized replay management.
    • Replaced environment variable checks with direct manager state access for replay stages.
    • Generalized fill_routing_replay to _fill_replay_data for broader applicability.
  • miles/backends/megatron_utils/model.py
    • Removed direct os import.
    • Updated replay stage management to use all_replay_managers instead of os.environ.
  • miles/backends/megatron_utils/replay_utils.py
    • Added new file to encapsulate logic for registering replay lists, specifically for MoE layers.
  • miles/ray/actor_group.py
    • Removed the ENABLE_ROUTING_REPLAY environment variable setting.
  • miles/rollout/generate_utils/generate_endpoint_utils.py
    • Refactored _get_rollout_routed_experts_from_response into a more generic _get_rollout_topk_from_response to handle various top-k data types.
  • miles/utils/replay_base.py
    • Added new file introducing Replay class for managing replay data.
    • Introduced BaseReplayManager and RoutingReplayManager for a structured approach to replay logic.
    • Implemented get_topk_fn to inject replay behavior into top-k computations.
  • miles/utils/routing_replay.py
    • Deleted old implementation of the routing replay system.

Activity

  • No human activity (comments, reviews, approvals) has been detected on this pull request yet.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant and well-executed refactoring of the rollout routing replay functionality. Replacing the environment variable-driven approach with a BaseReplayManager pattern is a major improvement for code clarity, maintainability, and extensibility. The generalization of fill_routing_replay and other utilities is also a positive change. My review includes a few suggestions to further enhance robustness and code consistency.

Comment on lines 113 to 142
            elif stage == "record":
                result = old_topk_fn(scores, topk, *args, **kwargs)
                if return_probs:
                    probs, top_indices = result
                else:
                    top_indices = result
                replay.record(top_indices)
                return result

            elif stage == "replay_forward":
                replay_top_indices = replay.pop_forward()

                shape_sanity_check(replay_top_indices, scores, topk)
                top_indices = replay_top_indices[..., :topk].view(scores.shape[:-1] + (topk,))

                self.check_replay_result(old_topk_fn, scores, topk, top_indices, **kwargs)

                return get_probs_and_top_indices(top_indices, return_probs)

            elif stage == "replay_backward":
                replay_top_indices = replay.pop_backward()

                shape_sanity_check(replay_top_indices, scores, topk)
                top_indices = replay_top_indices[..., :topk].view(scores.shape[:-1] + (topk,))

                self.check_replay_result(old_topk_fn, scores, topk, top_indices, **kwargs)

                return get_probs_and_top_indices(top_indices, return_probs)
            else:
                return old_topk_fn(scores, topk, *args, **kwargs)
Contributor

high

The manager.get_current() call can return None, but the subsequent logic in the record and replay_* stages unconditionally calls methods on the replay object. This could lead to a runtime AttributeError if the current replay object is not set. Adding assertions to ensure replay is not None in these stages will make the code more robust.

            elif stage == "record":
                assert replay is not None, "Replay object not set in 'record' stage."
                result = old_topk_fn(scores, topk, *args, **kwargs)
                if return_probs:
                    probs, top_indices = result
                else:
                    top_indices = result
                replay.record(top_indices)
                return result

            elif stage == "replay_forward":
                assert replay is not None, "Replay object not set in 'replay_forward' stage."
                replay_top_indices = replay.pop_forward()

                shape_sanity_check(replay_top_indices, scores, topk)
                top_indices = replay_top_indices[..., :topk].view(scores.shape[:-1] + (topk,))

                self.check_replay_result(old_topk_fn, scores, topk, top_indices, **kwargs)

                return get_probs_and_top_indices(top_indices, return_probs)

            elif stage == "replay_backward":
                assert replay is not None, "Replay object not set in 'replay_backward' stage."
                replay_top_indices = replay.pop_backward()

                shape_sanity_check(replay_top_indices, scores, topk)
                top_indices = replay_top_indices[..., :topk].view(scores.shape[:-1] + (topk,))

                self.check_replay_result(old_topk_fn, scores, topk, top_indices, **kwargs)

                return get_probs_and_top_indices(top_indices, return_probs)
            else:
                return old_topk_fn(scores, topk, *args, **kwargs)
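
The stage dispatch and the suggested guard can be exercised in isolation with a toy version. This is a pure-Python sketch under stated assumptions, not the code in replay_base.py: ToyReplay, ToyManager, make_replay_topk, the "fallthrough" default stage, and the first-in, first-out order of pop_forward are all invented here, and the real get_topk_fn operates on tensors.

```python
from collections import deque


def plain_topk(scores, topk):
    """Return indices of the topk largest scores (ties broken by index)."""
    order = sorted(range(len(scores)), key=lambda i: (-scores[i], i))
    return order[:topk]


class ToyReplay:
    """Toy record/replay buffer; assumes first-in, first-out replay order."""

    def __init__(self):
        self._forward = deque()

    def record(self, top_indices):
        self._forward.append(list(top_indices))

    def pop_forward(self):
        return self._forward.popleft()


class ToyManager:
    """Hypothetical manager holding the stage and the current replay."""

    def __init__(self):
        self.stage = "fallthrough"  # invented name for "no replay active"
        self.current = None

    def get_current(self):
        return self.current


def make_replay_topk(old_topk_fn, manager):
    """Wrap old_topk_fn so the manager's stage selects record or replay."""

    def new_topk_fn(scores, topk):
        replay = manager.get_current()
        if manager.stage == "record":
            assert replay is not None, "Replay object not set in 'record' stage."
            top_indices = old_topk_fn(scores, topk)
            replay.record(top_indices)
            return top_indices
        if manager.stage == "replay_forward":
            assert replay is not None, "Replay object not set in 'replay_forward' stage."
            # Ignore the current scores and reuse the recorded routing.
            return replay.pop_forward()[:topk]
        return old_topk_fn(scores, topk)

    return new_topk_fn


manager = ToyManager()
topk_fn = make_replay_topk(plain_topk, manager)

manager.current = ToyReplay()
manager.stage = "record"
recorded = topk_fn([0.1, 0.7, 0.2, 0.9], 2)   # rollout-time scores

manager.stage = "replay_forward"
replayed = topk_fn([0.9, 0.1, 0.8, 0.2], 2)   # training-time scores differ

manager.current = None
try:
    topk_fn([0.9, 0.1, 0.8, 0.2], 2)
    guard_tripped = False
except AssertionError:
    guard_tripped = True
```

With the guard in place, a missing replay object fails with a message naming the stage instead of an AttributeError on None.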

- assert len(rollout_routed_experts) == len(tokens)
- for a, b in zip(rollout_routed_experts, tokens, strict=False):
+ assert len(replay_data) == len(tokens)
+ for a, b in zip(replay_data, tokens, strict=False):
Contributor

medium

For consistency and to leverage Python's built-in safety checks, it's recommended to use strict=True with zip. You have an assertion on the preceding line to ensure the lengths match, so this change makes the intent explicit and aligns with other zip calls in the codebase that use this parameter.

Suggested change
- for a, b in zip(replay_data, tokens, strict=False):
+ for a, b in zip(replay_data, tokens, strict=True):

- if os.environ.get("ENABLE_ROUTING_REPLAY", "0") == "1":
-     old_stage = os.environ["ROUTING_REPLAY_STAGE"]
-     os.environ["ROUTING_REPLAY_STAGE"] = "replay_forward"
+ from miles.utils.replay_base import all_replay_managers
Contributor

medium

Per PEP 8 guidelines, imports should be placed at the top of the file. Moving this import to the top-level scope will improve readability and maintainability. There doesn't appear to be a circular dependency that would necessitate a local import here.


def _register_replay_list_moe(replay_list, replay_data, models):
    layer_indices = []
    replay_idx = 0
Contributor

medium

The replay_idx variable initialized here is unused and is later shadowed by the loop variable of the same name on line 24. It can be safely removed to improve code clarity.

Collaborator

@fzyzcjy fzyzcjy left a comment

TODO(tom): I am still not very satisfied w/ my refactor b/c time was tight when I did it, but anyway this is not a blocker, so feel free to merge it, and I will work on it later when I have time

Collaborator

@fzyzcjy fzyzcjy Feb 9, 2026

nit: btw when I named this file I originally thought there would be several files like routing_replay and indexer_replay, but later it seemed there was no need for heavyweight files, so should this be renamed?

EDIT: @yueming-yuan already pinged me offline about further modifying this PR; I will do it as soon as I have time

Replace the old env-var-driven RoutingReplay with a cleaner
BaseReplayManager / RoutingReplayManager pattern. This removes
os.environ usage for replay stage control in favor of direct
manager state, generalizes fill_routing_replay into _fill_replay_data,
and extracts layer registration logic into replay_utils.

Co-authored-by: fzyzcjy <ch271828n@outlook.com>
Co-authored-by: Yueming Yuan <yy28@illinois.edu>
@guapisolo
Collaborator

guapisolo commented Feb 14, 2026

The Megatron patch needs a small fix, like: guapisolo/Megatron-LM@0a21682

replay_top_indices = replay.pop_backward()

shape_sanity_check(replay_top_indices, scores, topk)
top_indices = replay_top_indices[..., :topk].view(scores.shape[:-1] + (topk,))
Collaborator

Sorry, I am a little bit confused by this line. I think replay_top_indices has shape (num_tokens, topk), so why do we need to change the view?

Collaborator Author

Oh, it's related to another feature; I changed it to this form for generalization. Maybe I can change it back.
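
To make the shape question concrete, here is a small sketch of the arithmetic behind replay_top_indices[..., :topk].view(scores.shape[:-1] + (topk,)). It works on shape tuples only and does not use torch. The helper name replay_view_shape and the example shapes are assumptions; the thread does not say which feature needs the extra leading dimensions.

```python
from math import prod


def replay_view_shape(replay_shape, scores_shape, topk):
    """Shape produced by replay[..., :topk].view(scores_shape[:-1] + (topk,))."""
    if replay_shape[-1] < topk:
        raise ValueError("recorded fewer than topk indices per token")
    sliced = replay_shape[:-1] + (topk,)
    target = scores_shape[:-1] + (topk,)
    # A view can only succeed when the element counts match.
    if prod(sliced) != prod(target):
        raise ValueError(f"cannot view {sliced} as {target}")
    return target


# Flat token layout: the view is a no-op, as the question assumes.
flat = replay_view_shape((8, 6), (8, 64), 6)

# Hypothetical layout with extra leading dims: the view restores them.
batched = replay_view_shape((8, 6), (2, 4, 64), 6)
```

In the flat case the result equals the input shape, which matches the reviewer's point that the view changes nothing there.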


module.register_forward_pre_hook(pre_forward_hook)

def check_replay_result(self, old_topk_fn, scores, topk, top_indices, **kwargs):
Collaborator

@guapisolo guapisolo Feb 14, 2026

I ran test_moonlight_16B_A3B_r3.py with env var MILES_CHECK_REPLAY_RESULT=1, but the replay check failed frequently, reporting: token xxx failed replay check, len(orig_set & replay_set)=2 len(replay_set)=6. The default thresh_check_replay_result=0.5 may be too strict?

Collaborator

Could we make the thresh an env var? So the r3 CI can hack this value.

Collaborator Author

@yueming-yuan yueming-yuan Feb 14, 2026

Could we make the thresh an env var? So the r3 CI can hack this value.

Yes plan to add this in CI
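
A minimal sketch of the kind of check being discussed, with the threshold read from the environment. It assumes the check compares len(orig_set & replay_set) / len(replay_set) against the threshold, which is inferred from the error message and not confirmed by the thread. The function name replay_overlap_ok and the variable name SKETCH_REPLAY_CHECK_THRESH are hypothetical.

```python
import os


def replay_overlap_ok(orig_indices, replay_indices, thresh):
    """True if at least `thresh` of the replayed experts also appear in orig."""
    orig_set, replay_set = set(orig_indices), set(replay_indices)
    if not replay_set:
        return True
    return len(orig_set & replay_set) / len(replay_set) >= thresh


# Hypothetical env var name; the thread only proposes making the
# threshold configurable and does not name the variable.
thresh = float(os.environ.get("SKETCH_REPLAY_CHECK_THRESH", "0.5"))

# Overlap of 2 out of 6, as in the reported failure: ratio 1/3.
orig = [0, 1, 2, 3, 4, 5]
replayed = [0, 1, 6, 7, 8, 9]
ok_default = replay_overlap_ok(orig, replayed, 0.5)
ok_relaxed = replay_overlap_ok(orig, replayed, 0.3)
```

Under this reading, the reported case has an overlap ratio of 1/3, so it fails at 0.5 and passes once the threshold drops to 0.3.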

Collaborator

@fzyzcjy fzyzcjy left a comment

LGTM since we discussed offline w/ yueming and confirmed no core logic is changed compared w/ the impl we did originally

shell: bash
run: |
  cd /sgl-workspace/sglang && git fetch origin sglang-miles && git checkout FETCH_HEAD && git log --oneline -1 && pip install -e python --no-deps --break-system-packages
  cd /root/Megatron-LM && git reset --hard HEAD && git log --oneline -1 && git apply $GITHUB_WORKSPACE/docker/patch/dev/megatron.patch && pip install -e . --no-deps --break-system-packages
Collaborator

Discuss with Tom to see whose PR will be merged first. If Tom's merges first, we might need to change this here?

@yueming-yuan yueming-yuan merged commit 6d15e48 into main Feb 18, 2026
22 checks passed
@yueming-yuan yueming-yuan deleted the feature/r3-upstream branch February 18, 2026 00:18
4 participants