[MUON] add Muon optimizer #408
Conversation
Code Review
This pull request introduces a parameter-splitting mechanism to support the Muon optimizer, allowing specific tensors like QKV or MLP gates to be partitioned before optimization. It includes a new _MuonWithAdamW wrapper to handle different parameter types and updates model blocks to register split functions. Key feedback includes correcting the target weight for splitting in the attention module, fixing an invalid assertion type check, ensuring AdamW hyperparameters are correctly propagated in the Muon wrapper, and improving the robustness of state dictionary loading.
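For context, the splitting idea can be pictured roughly like this. A minimal sketch only: `split_qkv` and the equal-chunk layout are assumptions for illustration, not the PR's actual API.

```python
import torch

# Hypothetical sketch: a fused QKV weight is paired with a split function so
# the optimizer can treat Q, K and V as separate matrices.
def split_qkv(weight: torch.Tensor) -> list[torch.Tensor]:
    # Assumes Q, K and V are stacked in equal parts along dim 0.
    return list(weight.chunk(3, dim=0))

qkv_weight = torch.nn.Parameter(torch.randn(3 * 512, 512))
# In the PR, get_optimizer_split_function(param) would presumably return a
# function like this for registered parameters; here we call it directly.
q, k, v = split_qkv(qkv_weight)
assert q.shape == (512, 512)
```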
```python
for model, params_groups in zip(model_container, params_groups_list):
    for param_name, param in model.named_parameters():
        if get_optimizer_split_function(param) is not None:
            assert optimizer_class not in _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS
```
The assertion check will always pass because `optimizer_class` is a class object, while `_SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS` is a list of strings (e.g., `["TorchMuon"]`). You should check against `optimizer_class_name` instead.
```diff
- assert optimizer_class not in _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS
+ assert optimizer_class_name not in _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS
```
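To see why the original assertion is a no-op, compare membership of the class object against its name (illustrative stand-in, not the real class):

```python
class TorchMuon:  # stand-in for the real optimizer class
    pass

_SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS = ["TorchMuon"]

optimizer_class = TorchMuon
optimizer_class_name = optimizer_class.__name__

# A class object never equals a string, so the original assert always passes:
assert optimizer_class not in _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS
# Comparing names is the intended check, and it would actually fire here:
assert optimizer_class_name in _SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS
```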
@mayank31398 - this also looks like a bug: on line 45 you have `_SPLIT_FUNCTION_INCOMPATIBLE_OPTIMIZERS = ["TorchMuon"]`, while on line 155 you compare `optimizer_class_name == "TorchMuon"`.
```python
    optimizer_list = BackwardHookOptimizerContainer([None] * len(model_container))
elif optimizer_class_name == "TorchMuon":
    adamw_args = {"lr": optimizer_class_args.get("lr", 1e-3)}
```
When using Muon, the AdamW optimizer used for 1D parameters and embeddings currently only inherits the learning rate. Other important hyperparameters like `weight_decay`, `betas`, and `eps` provided in `optimizer_class_args` are ignored for the AdamW part. It is recommended to pass these parameters through to ensure consistent optimization behavior for non-Muon parameters.
| adamw_args = {"lr": optimizer_class_args.get("lr", 1e-3)} | |
| adamw_args = {k: v for k, v in optimizer_class_args.items() if k in ("lr", "betas", "eps", "weight_decay")} | |
| adamw_args.setdefault("lr", 1e-3) |
```python
optimizer_list_entries.append(
    _MuonWithAdamW(
        muon=TorchMuon(muon_groups, **optimizer_class_args) if muon_groups else None,
        adamw=TorchAdamW(adamw_groups, **adamw_args) if adamw_groups else None,
    )
)
```
We should make sure we can pass all of the AdamW configs through.
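One way to do that, sketched under the assumption that only `lr`, `betas`, `eps`, and `weight_decay` are shared with AdamW, and that keys like `momentum` and `ns_steps` (hypothetical names here) are Muon-only:

```python
# Sketch: send the shared keys to AdamW, hand the full dict to Muon.
_ADAMW_KEYS = ("lr", "betas", "eps", "weight_decay")

def split_optimizer_args(optimizer_class_args: dict) -> tuple[dict, dict]:
    adamw_args = {k: v for k, v in optimizer_class_args.items() if k in _ADAMW_KEYS}
    adamw_args.setdefault("lr", 1e-3)
    return optimizer_class_args, adamw_args

# "momentum" and "ns_steps" are assumed Muon-only knobs for this example.
muon_args, adamw_args = split_optimizer_args(
    {"lr": 2e-2, "weight_decay": 0.1, "momentum": 0.95, "ns_steps": 5}
)
assert adamw_args == {"lr": 2e-2, "weight_decay": 0.1}
```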
```python
def load_state_dict(self, state_dict: dict) -> None:
    if self.muon is not None and state_dict["muon"] is not None:
        self.muon.load_state_dict(state_dict["muon"])
    if self.adamw is not None and state_dict["adamw"] is not None:
        self.adamw.load_state_dict(state_dict["adamw"])
```
The `load_state_dict` method in the `_MuonWithAdamW` wrapper assumes the provided state dict always contains `"muon"` and `"adamw"` keys. If a standard optimizer state dict is passed (e.g., during a transition or from a different checkpoint format), this will raise a `KeyError`. Using `.get()` would make this more robust.
```diff
  def load_state_dict(self, state_dict: dict) -> None:
-     if self.muon is not None and state_dict["muon"] is not None:
+     if self.muon is not None and state_dict.get("muon") is not None:
          self.muon.load_state_dict(state_dict["muon"])
-     if self.adamw is not None and state_dict["adamw"] is not None:
+     if self.adamw is not None and state_dict.get("adamw") is not None:
          self.adamw.load_state_dict(state_dict["adamw"])
```
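For the round trip to work, the wrapper's `state_dict` presumably emits the same nested layout. A minimal sketch of the pair; the class body here is an assumption, not copied from the PR:

```python
class _MuonWithAdamW:
    def __init__(self, muon=None, adamw=None):
        self.muon = muon
        self.adamw = adamw

    def state_dict(self) -> dict:
        # Nested layout that load_state_dict below expects.
        return {
            "muon": self.muon.state_dict() if self.muon is not None else None,
            "adamw": self.adamw.state_dict() if self.adamw is not None else None,
        }

    def load_state_dict(self, state_dict: dict) -> None:
        # .get() tolerates checkpoints that carry only one of the two keys.
        if self.muon is not None and state_dict.get("muon") is not None:
            self.muon.load_state_dict(state_dict["muon"])
        if self.adamw is not None and state_dict.get("adamw") is not None:
            self.adamw.load_state_dict(state_dict["adamw"])
```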
jyopari left a comment
Took a first pass - maybe we should add a test to make sure the Torch implementation of distributed Muon works as intended?
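A single-process smoke test could be a starting point before the distributed comparison. This sketch takes the optimizer class as a parameter, since the real import path isn't shown in this thread:

```python
import torch

def muon_smoke_test(optimizer_cls) -> None:
    # A few steps on a tiny 2D weight should reduce the loss if the update is
    # sane; a distributed test would additionally compare multi-rank results
    # against this single-process run.
    torch.manual_seed(0)
    model = torch.nn.Linear(8, 8, bias=False)
    opt = optimizer_cls(model.parameters(), lr=0.02)
    x, y = torch.randn(32, 8), torch.randn(32, 8)
    losses = []
    for _ in range(10):
        opt.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        opt.step()
        losses.append(loss.item())
    assert losses[-1] < losses[0]
```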
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
No description provided.