
[MUON] add Muon optimizer #408

Open
mayank31398 wants to merge 19 commits into main from muon

Conversation

@mayank31398 (Collaborator)

No description provided.

Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
@gemini-code-assist (Contributor) bot left a comment


Code Review

This pull request introduces a parameter-splitting mechanism to support the Muon optimizer, allowing specific tensors like QKV or MLP gates to be partitioned before optimization. It includes a new _MuonWithAdamW wrapper to handle different parameter types and updates model blocks to register split functions. Key feedback includes correcting the target weight for splitting in the attention module, fixing an invalid assertion type check, ensuring AdamW hyperparameters are correctly propagated in the Muon wrapper, and improving the robustness of state dictionary loading.
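
For orientation, here is a minimal sketch of what such a wrapper might look like, inferred from the constructor call and the load_state_dict snippet visible in this PR; the method bodies beyond those snippets are assumptions, not the PR's actual code:

import torch

class _MuonWithAdamW:
    # Hedged sketch: routes matrix parameters to Muon and everything else
    # (1D params, embeddings) to AdamW, delegating to whichever inner
    # optimizer exists. Inferred from this PR's visible snippets.

    def __init__(self, muon: torch.optim.Optimizer | None, adamw: torch.optim.Optimizer | None) -> None:
        self.muon = muon
        self.adamw = adamw

    def step(self) -> None:
        if self.muon is not None:
            self.muon.step()
        if self.adamw is not None:
            self.adamw.step()

    def zero_grad(self, set_to_none: bool = True) -> None:
        if self.muon is not None:
            self.muon.zero_grad(set_to_none=set_to_none)
        if self.adamw is not None:
            self.adamw.zero_grad(set_to_none=set_to_none)

    def state_dict(self) -> dict:
        # namespaced so load_state_dict can route each piece back
        return {
            "muon": None if self.muon is None else self.muon.state_dict(),
            "adamw": None if self.adamw is None else self.adamw.state_dict(),
        }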

for model, params_groups in zip(model_container, params_groups_list):
    for param_name, param in model.named_parameters():
        if get_optimizer_split_function(param) is not None:
            assert optimizer_class not in _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS
gemini-code-assist (Contributor) commented:


Severity: high

The assertion check will always pass because optimizer_class is a class object, while _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS is a list of strings (e.g., ["TorchMuon"]). You should check against optimizer_class_name instead.

Suggested change
- assert optimizer_class not in _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS
+ assert optimizer_class_name not in _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS
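
A standalone repro of the type mismatch (the stub class is purely for illustration):

class TorchMuon:  # stub standing in for the real optimizer class
    pass

_SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS = ["TorchMuon"]

# A class object never compares equal to a string, so the original check is a no-op:
print(TorchMuon in _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS)           # always False
print(TorchMuon.__name__ in _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS)  # True: the intended trip-wire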


@mayank31398 - this also looks like a bug: on line 45 you have _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS = ["TorchMuon"], and on line 155 you have optimizer_class_name == "TorchMuon".

Comment thread on lm_engine/hf_models/parameter.py (outdated)

    optimizer_list = BackwardHookOptimizerContainer([None] * len(model_container))
elif optimizer_class_name == "TorchMuon":
    adamw_args = {"lr": optimizer_class_args.get("lr", 1e-3)}
gemini-code-assist (Contributor) commented:


Severity: medium

When using Muon, the AdamW optimizer used for 1D parameters and embeddings currently only inherits the learning rate. Other important hyperparameters like weight_decay, betas, and eps provided in optimizer_class_args are ignored for the AdamW part. It is recommended to pass these parameters to ensure consistent optimization behavior for non-Muon parameters.

Suggested change
- adamw_args = {"lr": optimizer_class_args.get("lr", 1e-3)}
+ adamw_args = {k: v for k, v in optimizer_class_args.items() if k in ("lr", "betas", "eps", "weight_decay")}
+ adamw_args.setdefault("lr", 1e-3)


@mayank31398 -

optimizer_list_entries.append(
    _MuonWithAdamW(
        muon=TorchMuon(muon_groups, **optimizer_class_args) if muon_groups else None,
        adamw=TorchAdamW(adamw_groups, **adamw_args) if adamw_groups else None,
    )
)

We should make sure we can pass all the AdamW configs through here.
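
One hedged way to do that without hard-coding a key list is to filter optimizer_class_args against AdamW's own signature; the sample dict below is illustrative, not the repo's actual config:

import inspect
from torch.optim import AdamW

optimizer_class_args = {"lr": 3e-4, "betas": (0.9, 0.95), "weight_decay": 0.1, "momentum": 0.95}

# Keep only the keys AdamW actually accepts; Muon-specific keys
# (e.g. momentum here) fall through untouched.
_adamw_keys = set(inspect.signature(AdamW.__init__).parameters) - {"self", "params"}
adamw_args = {k: v for k, v in optimizer_class_args.items() if k in _adamw_keys}
adamw_args.setdefault("lr", 1e-3)
print(adamw_args)  # {'lr': 0.0003, 'betas': (0.9, 0.95), 'weight_decay': 0.1}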

Comment on lines +96 to +100
def load_state_dict(self, state_dict: dict) -> None:
    if self.muon is not None and state_dict["muon"] is not None:
        self.muon.load_state_dict(state_dict["muon"])
    if self.adamw is not None and state_dict["adamw"] is not None:
        self.adamw.load_state_dict(state_dict["adamw"])
gemini-code-assist (Contributor) commented:


Severity: medium

The load_state_dict method in the _MuonWithAdamW wrapper assumes the provided state_dict always contains "muon" and "adamw" keys. If a standard optimizer state dict is passed (e.g., during a transition or from a different checkpoint format), this will raise a KeyError. Using .get() would make this more robust.

Suggested change
- def load_state_dict(self, state_dict: dict) -> None:
-     if self.muon is not None and state_dict["muon"] is not None:
-         self.muon.load_state_dict(state_dict["muon"])
-     if self.adamw is not None and state_dict["adamw"] is not None:
-         self.adamw.load_state_dict(state_dict["adamw"])
+ def load_state_dict(self, state_dict: dict) -> None:
+     if self.muon is not None and state_dict.get("muon") is not None:
+         self.muon.load_state_dict(state_dict["muon"])
+     if self.adamw is not None and state_dict.get("adamw") is not None:
+         self.adamw.load_state_dict(state_dict["adamw"])
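
Concretely, a checkpoint written by a plain torch optimizer has no "muon"/"adamw" namespacing, so the original code raises on the first key lookup. A tiny repro (the dict shape is illustrative):

plain_sd = {"state": {}, "param_groups": []}  # shape of a standard torch.optim state dict

print(plain_sd.get("muon"))  # None: the suggested code skips loading gracefully
try:
    plain_sd["muon"]  # the lookup the original code performs
except KeyError:
    print("KeyError: plain checkpoints break the original load_state_dict")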

Signed-off-by: Mayank Mishra <mayank31398@gmail.com>

@jyopari left a comment


Took a first pass - maybe we should have a test to make sure the Torch implementation of distributed Muon works as intended?
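
A hedged skeleton for such a test, assuming a single-rank gloo process group is enough to exercise the distributed code path; the TorchMuon import path and constructor arguments are guesses, not this repo's confirmed API:

import os
import torch
import torch.distributed as dist

def test_distributed_muon_matches_single_device() -> None:
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)
    try:
        torch.manual_seed(0)
        w_dist = torch.nn.Parameter(torch.randn(16, 16))
        w_ref = torch.nn.Parameter(w_dist.detach().clone())
        grad = torch.randn(16, 16)
        w_dist.grad, w_ref.grad = grad.clone(), grad.clone()

        # hypothetical import path; replace with the actual module from this PR
        from lm_engine.optimizers import TorchMuon

        TorchMuon([w_dist], lr=1e-3).step()
        # reference: a trusted non-distributed Muon step on identical inputs
        # reference_muon_step(w_ref, lr=1e-3)
        # torch.testing.assert_close(w_dist, w_ref)
    finally:
        dist.destroy_process_group()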

mayank31398 and others added 3 commits April 15, 2026 11:40
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>