Refactor and fix rollout routing replay (R3)#575
Conversation
Summary of Changes

This pull request significantly refactors the rollout routing replay (R3) system, moving away from global environment-variable-based control to a more object-oriented and extensible manager pattern. The changes improve the maintainability and flexibility of the replay mechanism, enabling support for new data formats such as the bshd (non-packing) layout.
Code Review
This pull request introduces a significant and well-executed refactoring of the rollout routing replay functionality. Replacing the environment variable-driven approach with a BaseReplayManager pattern is a major improvement for code clarity, maintainability, and extensibility. The generalization of fill_routing_replay and other utilities is also a positive change. My review includes a few suggestions to further enhance robustness and code consistency.
```python
elif stage == "record":
    result = old_topk_fn(scores, topk, *args, **kwargs)
    if return_probs:
        probs, top_indices = result
    else:
        top_indices = result
    replay.record(top_indices)
    return result
elif stage == "replay_forward":
    replay_top_indices = replay.pop_forward()
    shape_sanity_check(replay_top_indices, scores, topk)
    top_indices = replay_top_indices[..., :topk].view(scores.shape[:-1] + (topk,))
    self.check_replay_result(old_topk_fn, scores, topk, top_indices, **kwargs)
    return get_probs_and_top_indices(top_indices, return_probs)
elif stage == "replay_backward":
    replay_top_indices = replay.pop_backward()
    shape_sanity_check(replay_top_indices, scores, topk)
    top_indices = replay_top_indices[..., :topk].view(scores.shape[:-1] + (topk,))
    self.check_replay_result(old_topk_fn, scores, topk, top_indices, **kwargs)
    return get_probs_and_top_indices(top_indices, return_probs)
else:
    return old_topk_fn(scores, topk, *args, **kwargs)
```
The manager.get_current() call can return None, but the subsequent logic in the record and replay_* stages unconditionally calls methods on the replay object. This could lead to a runtime AttributeError if the current replay object is not set. Adding assertions to ensure replay is not None in these stages will make the code more robust.
```python
elif stage == "record":
    assert replay is not None, "Replay object not set in 'record' stage."
    result = old_topk_fn(scores, topk, *args, **kwargs)
    if return_probs:
        probs, top_indices = result
    else:
        top_indices = result
    replay.record(top_indices)
    return result
elif stage == "replay_forward":
    assert replay is not None, "Replay object not set in 'replay_forward' stage."
    replay_top_indices = replay.pop_forward()
    shape_sanity_check(replay_top_indices, scores, topk)
    top_indices = replay_top_indices[..., :topk].view(scores.shape[:-1] + (topk,))
    self.check_replay_result(old_topk_fn, scores, topk, top_indices, **kwargs)
    return get_probs_and_top_indices(top_indices, return_probs)
elif stage == "replay_backward":
    assert replay is not None, "Replay object not set in 'replay_backward' stage."
    replay_top_indices = replay.pop_backward()
    shape_sanity_check(replay_top_indices, scores, topk)
    top_indices = replay_top_indices[..., :topk].view(scores.shape[:-1] + (topk,))
    self.check_replay_result(old_topk_fn, scores, topk, top_indices, **kwargs)
    return get_probs_and_top_indices(top_indices, return_probs)
else:
    return old_topk_fn(scores, topk, *args, **kwargs)
```

```diff
-assert len(rollout_routed_experts) == len(tokens)
-for a, b in zip(rollout_routed_experts, tokens, strict=False):
+assert len(replay_data) == len(tokens)
+for a, b in zip(replay_data, tokens, strict=False):
```
For consistency and to leverage Python's built-in safety checks, it's recommended to use strict=True with zip. You have an assertion on the preceding line to ensure the lengths match, so this change makes the intent explicit and aligns with other zip calls in the codebase that use this parameter.
```diff
-for a, b in zip(replay_data, tokens, strict=False):
+for a, b in zip(replay_data, tokens, strict=True):
```
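For reference, a quick standalone illustration of the difference (plain Python 3.10+; the list contents are made up for the demo): strict=False silently truncates on a length mismatch, while strict=True raises a ValueError.

```python
replay_data = [[1, 2], [3, 4], [5, 6]]
tokens = [10, 11]  # deliberately one element short

pairs = list(zip(replay_data, tokens, strict=False))
print(len(pairs))  # 2 -- the mismatch is silently swallowed

try:
    list(zip(replay_data, tokens, strict=True))
except ValueError as err:
    print(err)  # zip() argument 2 is shorter than argument 1
```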
```diff
-if os.environ.get("ENABLE_ROUTING_REPLAY", "0") == "1":
-    old_stage = os.environ["ROUTING_REPLAY_STAGE"]
-    os.environ["ROUTING_REPLAY_STAGE"] = "replay_forward"
+from miles.utils.replay_base import all_replay_managers
```

```python
def _register_replay_list_moe(replay_list, replay_data, models):
    layer_indices = []
    replay_idx = 0
```
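For readers unfamiliar with the registration helper shown above, here is a hedged sketch of what a `_register_replay_list_moe`-style function might do: distribute per-layer replay data across a model's MoE layers. The layer-walking details (`model.layers`, the `is_moe_layer` attribute) are assumptions for illustration, not the PR's actual implementation.

```python
def register_replay_list_moe_sketch(replay_list, replay_data, models):
    """Illustrative only: assign one recorded routing entry to each MoE layer."""
    layer_indices = []
    replay_idx = 0
    for model in models:
        # `model.layers` and `is_moe_layer` are assumed attributes.
        for layer_idx, layer in enumerate(getattr(model, "layers", [])):
            if getattr(layer, "is_moe_layer", False):
                replay_list.append((layer_idx, replay_data[replay_idx]))
                layer_indices.append(layer_idx)
                replay_idx += 1
    assert replay_idx == len(replay_data), "replay_data does not match MoE layer count"
    return layer_indices
```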
fzyzcjy
left a comment
TODO(tom): I am still not very satisfied with my refactor because the time was urgent when I did it, but anyway this is not a blocker and feel free to merge it; I will work on it later when I have time.
nit: btw, when I named this file I originally thought there would be several files like routing_replay and indexer_replay, but later it turned out there is no need for heavyweight files, so should this be renamed?
EDIT: @yueming-yuan already pinged me offline about further modifying this PR; I will do it as soon as I have time.
Replace the old env-var-driven RoutingReplay with a cleaner BaseReplayManager / RoutingReplayManager pattern. This removes os.environ usage for replay stage control in favor of direct manager state, generalizes fill_routing_replay into _fill_replay_data, and extracts layer registration logic into replay_utils. Co-authored-by: fzyzcjy <ch271828n@outlook.com> Co-authored-by: Yueming Yuan <yy28@illinois.edu>
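As a rough mental model of the description above, the pattern might look something like the sketch below. The class and method names (Replay, BaseReplayManager, RoutingReplayManager, record, pop_forward, pop_backward, get_current, all_replay_managers) are taken from the diff snippets and file list in this PR, but the bodies are a hedged sketch, not the actual contents of miles/utils/replay_base.py.

```python
class Replay:
    """Holds recorded top-k routing decisions for one replay stream (sketch)."""

    def __init__(self):
        self._records = []
        self._fwd_idx = 0
        self._bwd_idx = 0

    def record(self, top_indices):
        self._records.append(top_indices)

    def pop_forward(self):
        # Forward replay walks the records in recording order.
        item = self._records[self._fwd_idx]
        self._fwd_idx += 1
        return item

    def pop_backward(self):
        # Backward replay walks them in reverse (backprop visits layers last-first).
        self._bwd_idx += 1
        return self._records[-self._bwd_idx]


class BaseReplayManager:
    """Stage and current replay live on the object instead of os.environ."""

    def __init__(self):
        self.stage = None  # e.g. "record", "replay_forward", "replay_backward"
        self._current = None

    def set_current(self, replay):
        self._current = replay

    def get_current(self):
        return self._current


class RoutingReplayManager(BaseReplayManager):
    pass


# Global registry so backends can iterate over every active replay manager.
all_replay_managers = [RoutingReplayManager()]
```

Under this sketch, a training loop would flip the manager's stage directly (e.g. `manager.stage = "replay_forward"`) instead of writing ROUTING_REPLAY_STAGE into os.environ.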
Force-pushed from d2db38b to 6cad330.
The Megatron patch needs a small fix, like guapisolo/Megatron-LM@0a21682.
miles/utils/replay_base.py
Outdated
```python
replay_top_indices = replay.pop_backward()

shape_sanity_check(replay_top_indices, scores, topk)
top_indices = replay_top_indices[..., :topk].view(scores.shape[:-1] + (topk,))
```
Sorry, I am a little confused by this line. I think replay_top_indices has shape (num_tokens, topk), so why do we need to change the view?
Oh, it's related to another feature; I changed it to this form for generalization. Maybe I can change it back.
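To make the generalization concrete, here is a small PyTorch snippet (the shapes are illustrative assumptions) showing why `view(scores.shape[:-1] + (topk,))` also covers a 3-D bshd-style scores layout, whereas a hard-coded `(num_tokens, topk)` view would not:

```python
import torch

topk = 6
scores_2d = torch.randn(128, 64)        # (num_tokens, num_experts)
scores_3d = torch.randn(4, 32, 64)      # (batch, seq, num_experts), bshd-style

replay_2d = torch.randint(0, 64, (128, 8))     # recorded with extra columns
replay_3d = torch.randint(0, 64, (4 * 32, 8))  # flattened over batch * seq

top_2d = replay_2d[..., :topk].view(scores_2d.shape[:-1] + (topk,))
top_3d = replay_3d[..., :topk].view(scores_3d.shape[:-1] + (topk,))
print(top_2d.shape, top_3d.shape)  # torch.Size([128, 6]) torch.Size([4, 32, 6])
```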
miles/utils/replay_base.py
Outdated
```python
module.register_forward_pre_hook(pre_forward_hook)


def check_replay_result(self, old_topk_fn, scores, topk, top_indices, **kwargs):
```
I ran test_moonlight_16B_A3B_r3.py with the env var MILES_CHECK_REPLAY_RESULT=1, but the replay check failed frequently, reporting `token xxx failed replay check, len(orig_set & replay_set)=2 len(replay_set)=6`. Maybe the default thresh_check_replay_result=0.5 is too strict?
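For context, the check being discussed is an overlap ratio per token; a hedged, illustrative version (the real check_replay_result in replay_base.py may differ) looks roughly like this:

```python
def passes_replay_check(orig_top_indices, replay_top_indices, thresh=0.5):
    """Illustrative sketch: a token passes if enough replayed experts reappear
    in a freshly computed top-k. Inputs are per-token lists of expert ids."""
    failures = []
    for token_id, (orig_row, replay_row) in enumerate(
        zip(orig_top_indices, replay_top_indices, strict=True)
    ):
        orig_set, replay_set = set(orig_row), set(replay_row)
        overlap = len(orig_set & replay_set) / len(replay_set)
        if overlap < thresh:
            failures.append(
                f"token {token_id} failed replay check, "
                f"len(orig_set & replay_set)={len(orig_set & replay_set)} "
                f"len(replay_set)={len(replay_set)}"
            )
    return failures
```

With thresh=0.5, a token like the one quoted above (2 of 6 replayed experts matching, ratio of roughly 0.33) fails, which is why a looser or configurable threshold is being requested.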
Could we make the threshold an env var, so the R3 CI can override this value?
> Could we make the threshold an env var, so the R3 CI can override this value?

Yes, I plan to add this in CI.
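A minimal way to do that, assuming a hypothetical env var name (MILES_THRESH_CHECK_REPLAY_RESULT is illustrative, not necessarily what gets added):

```python
import os

# Hypothetical override knob for CI; falls back to the current default of 0.5.
thresh_check_replay_result = float(
    os.environ.get("MILES_THRESH_CHECK_REPLAY_RESULT", "0.5")
)
```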
fzyzcjy
left a comment
LGTM, since I discussed offline with yueming and confirmed no core logic is changed compared with the implementation we did originally.
```yaml
shell: bash
run: |
  cd /sgl-workspace/sglang && git fetch origin sglang-miles && git checkout FETCH_HEAD && git log --oneline -1 && pip install -e python --no-deps --break-system-packages
  cd /root/Megatron-LM && git reset --hard HEAD && git log --oneline -1 && git apply $GITHUB_WORKSPACE/docker/patch/dev/megatron.patch && pip install -e . --no-deps --break-system-packages
```
Discuss with Tom to see whose PR will be merged first. If Tom's is merged first, this might need to change.
Note: remember to apply dev/megatron.patch before running this PR.

Summary

- Support qkv_format=bshd (non-packing)
- Replace the old env-var-driven RoutingReplay with a cleaner BaseReplayManager / RoutingReplayManager pattern
- Remove os.environ usage for replay stage control in favor of direct manager state
- Generalize fill_routing_replay into _fill_replay_data, with extracted layer registration logic in replay_utils
- Delete routing_replay.py, replaced by replay_base.py

Files changed

- miles/utils/replay_base.py — New: replay framework (Replay, BaseReplayManager, RoutingReplayManager)
- miles/utils/routing_replay.py — Deleted: old implementation
- miles/backends/megatron_utils/replay_utils.py — New: MoE layer replay data distribution
- miles/backends/megatron_utils/actor.py — Support bshd format; use all_replay_managers instead of RoutingReplay + os.environ
- miles/backends/megatron_utils/model.py — Use all_replay_managers instead of os.environ
- miles/ray/actor_group.py — Remove old ENABLE_ROUTING_REPLAY env var
- miles/rollout/generate_utils/generate_endpoint_utils.py — Generalize _get_rollout_topk_from_response

CI changed

- Apply dev/megatron.patch each time before running CI.